diff --git a/pyproject.toml b/pyproject.toml index 0a29123c87..911d611804 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -46,6 +46,7 @@ dependencies = [ "fastapi>=0.133.0", "httpx[http2]>=0.27.2", "jinja2>=3.1.6", + "mcp>=1.0,<2", "numpy>=1.26.0; python_version < '3.14'", "numpy>=2.3.0; python_version >= '3.14'", "openai>=2.2.0", diff --git a/pyrit/exceptions/__init__.py b/pyrit/exceptions/__init__.py index abd42de031..9baea33c1a 100644 --- a/pyrit/exceptions/__init__.py +++ b/pyrit/exceptions/__init__.py @@ -10,6 +10,8 @@ MissingPromptPlaceholderException, PyritException, RateLimitException, + ToolCallLoopLimitExceeded, + ToolCallNotSupported, get_retry_max_num_attempts, handle_bad_request_exception, pyrit_custom_result_retry, @@ -59,4 +61,6 @@ "set_execution_context", "set_retry_collector", "execution_context", + "ToolCallLoopLimitExceeded", + "ToolCallNotSupported", ] diff --git a/pyrit/exceptions/exception_classes.py b/pyrit/exceptions/exception_classes.py index b2fc55440b..5d0014aa3d 100644 --- a/pyrit/exceptions/exception_classes.py +++ b/pyrit/exceptions/exception_classes.py @@ -233,6 +233,70 @@ def __init__(self, *, message: str = "No prompt placeholder") -> None: super().__init__(message=message) +class ToolCallNotSupported(PyritException): + """ + Raised when a target produces a tool call that the configured + :class:`~pyrit.tools.ToolEventPolicy` does not permit to execute + (``ToolEventBehavior.RAISE``, or ``EXECUTE`` without a backend). + + The ``partial_conversation`` attribute carries every message produced + up to and including the assistant turn that contained the offending + tool call(s). Consumers can inspect it to log the surfaced tool-use + attempt. + """ + + def __init__( + self, + *, + message: str = "Tool call not supported by configured policy.", + partial_conversation: Optional[list["Message"]] = None, + ) -> None: + """ + Initialize the exception. + + Args: + message (str): Human-readable error description. + partial_conversation (Optional[list[Message]]): Messages produced by + the target up to (and including) the assistant turn that + contained the disallowed tool call(s). + """ + super().__init__(status_code=400, message=message) + self.partial_conversation: list[Message] = ( + list(partial_conversation) if partial_conversation is not None else [] + ) + + +class ToolCallLoopLimitExceeded(PyritException): + """ + Raised when the tool-use loop runs for more than + ``ToolEventPolicy.max_tool_iterations`` iterations without the model + producing a stop response. + + The ``partial_conversation`` attribute carries every message produced + across all completed iterations. Consumers can inspect it to debug + runaway agentic behavior. + """ + + def __init__( + self, + *, + message: str = "Tool loop exceeded max_tool_iterations without a stop response.", + partial_conversation: Optional[list["Message"]] = None, + ) -> None: + """ + Initialize the exception. + + Args: + message (str): Human-readable error description. + partial_conversation (Optional[list[Message]]): Messages produced by + the target across every completed iteration of the tool loop. + """ + super().__init__(status_code=400, message=message) + self.partial_conversation: list[Message] = ( + list(partial_conversation) if partial_conversation is not None else [] + ) + + def pyrit_custom_result_retry( retry_function: Callable[..., bool], retry_max_num_attempts: Optional[int] = None ) -> Callable[..., Any]: diff --git a/pyrit/message_normalizer/chat_message_normalizer.py b/pyrit/message_normalizer/chat_message_normalizer.py index c5d3547e80..c9e7c0c532 100644 --- a/pyrit/message_normalizer/chat_message_normalizer.py +++ b/pyrit/message_normalizer/chat_message_normalizer.py @@ -4,6 +4,7 @@ import base64 import json import os +from collections.abc import Sequence from typing import TYPE_CHECKING, Any, Union from pyrit.common.data_url_converter import convert_local_image_to_data_url_async @@ -14,6 +15,7 @@ apply_system_message_behavior, ) from pyrit.models import ChatMessage, DataTypeSerializer, Message +from pyrit.models.chat_message import ToolCall, ToolCallFunction from pyrit.models.message_piece import MessagePiece if TYPE_CHECKING: @@ -83,6 +85,11 @@ async def normalize_async(self, messages: list[Message]) -> list[ChatMessage]: chat_messages: list[ChatMessage] = [] for message in processed_messages: pieces = message.message_pieces + tool_message = self._try_build_tool_message(pieces=pieces) + if tool_message is not None: + chat_messages.append(tool_message) + continue + role: ChatMessageRole = pieces[0].api_role # Translate system -> developer for newer OpenAI models @@ -99,6 +106,89 @@ async def normalize_async(self, messages: list[Message]) -> list[ChatMessage]: return chat_messages + def _try_build_tool_message(self, *, pieces: Sequence[MessagePiece]) -> ChatMessage | None: + """ + Build an OpenAI Chat Completions tool message when ``pieces`` carries tool data. + + Returns a populated ``ChatMessage`` when the pieces are tool-call + envelopes (``function_call`` or ``function_call_output`` data type), + or ``None`` when the pieces are ordinary text / multimodal content. + + ``function_call`` pieces produce a single ``role="assistant"`` message + with ``content=None`` and one or more entries in ``tool_calls``. + ``function_call_output`` pieces produce a single ``role="tool"`` + message whose ``content`` is the output payload and whose + ``tool_call_id`` matches the originating call. + + Args: + pieces (list[MessagePiece]): The pieces making up one PyRIT message. + + Returns: + ChatMessage | None: ``None`` when no tool envelopes are present, + otherwise the converted tool message. + """ + if not pieces: + return None + data_types = {p.converted_value_data_type or p.original_value_data_type for p in pieces} + if data_types == {"function_call"}: + return ChatMessage( + role="assistant", + content=None, + tool_calls=[self._piece_to_tool_call(piece) for piece in pieces], + ) + if data_types == {"function_call_output"}: + # A single message carries one or more function_call_output pieces + # in declaration order; the OpenAI wire shape sends each as its + # own role="tool" message. For multi-piece tool messages, we + # surface the first piece here and let the caller emit additional + # messages — but in practice tool_loop emits one message per + # iteration with multiple pieces, and OpenAI accepts a single + # tool message per call_id. Emit the first envelope; warn if + # multiple are present. + envelope = self._decode_envelope(pieces[0]) + return ChatMessage( + role="tool", + content=str(envelope.get("output", "")), + tool_call_id=str(envelope["call_id"]), + ) + return None + + @staticmethod + def _decode_envelope(piece: MessagePiece) -> dict[str, Any]: + """ + Decode the canonical-envelope JSON carried in a tool piece. + + Args: + piece (MessagePiece): A piece whose ``converted_value`` is the + canonical-envelope JSON string. + + Returns: + dict[str, Any]: The parsed envelope. + """ + return json.loads(piece.converted_value) + + @classmethod + def _piece_to_tool_call(cls, piece: MessagePiece) -> ToolCall: + """ + Convert one canonical ``function_call`` piece into an OpenAI ToolCall. + + Args: + piece (MessagePiece): A piece carrying a canonical ``function_call`` + envelope. + + Returns: + ToolCall: The corresponding OpenAI Chat Completions tool call. + """ + envelope = cls._decode_envelope(piece) + return ToolCall( + id=str(envelope["call_id"]), + type="function", + function=ToolCallFunction( + name=str(envelope["name"]), + arguments=str(envelope["arguments"]), + ), + ) + async def normalize_string_async(self, messages: list[Message]) -> str: """ Convert a list of Messages to a JSON string representation. diff --git a/pyrit/models/chat_message.py b/pyrit/models/chat_message.py index c2f801862d..faf39cb908 100644 --- a/pyrit/models/chat_message.py +++ b/pyrit/models/chat_message.py @@ -10,13 +10,28 @@ ALLOWED_CHAT_MESSAGE_ROLES = ["system", "user", "assistant", "simulated_assistant", "tool", "developer"] +class ToolCallFunction(BaseModel): + """The ``function`` payload of an OpenAI Chat Completions tool call.""" + + model_config = ConfigDict(extra="forbid") + name: str + arguments: str + + class ToolCall(BaseModel): - """Represents a tool invocation requested by the assistant.""" + """ + Represents a tool invocation requested by the assistant. + + Matches the OpenAI Chat Completions API ``tool_calls`` shape: each entry + has a provider-issued ``id``, a ``type`` string (currently always + ``"function"``), and a nested ``function`` object carrying the tool + ``name`` and JSON-encoded ``arguments``. + """ model_config = ConfigDict(extra="forbid") id: str type: str - function: str + function: ToolCallFunction class ChatMessage(BaseModel): @@ -26,11 +41,12 @@ class ChatMessage(BaseModel): The content field can be: - A simple string for single-part text messages - A list of dicts for multipart messages (e.g., text + images) + - ``None`` for assistant messages whose payload is a tool-call only """ model_config = ConfigDict(extra="forbid") role: ChatMessageRole - content: Union[str, list[dict[str, Any]]] + content: Optional[Union[str, list[dict[str, Any]]]] = None name: Optional[str] = None tool_calls: Optional[list[ToolCall]] = None tool_call_id: Optional[str] = None diff --git a/pyrit/prompt_target/azure_ml_chat_target.py b/pyrit/prompt_target/azure_ml_chat_target.py index 1c9d54d913..59eefe116c 100644 --- a/pyrit/prompt_target/azure_ml_chat_target.py +++ b/pyrit/prompt_target/azure_ml_chat_target.py @@ -1,6 +1,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. +import json import logging from typing import Any @@ -18,6 +19,7 @@ from pyrit.message_normalizer import ChatMessageNormalizer, MessageListNormalizer from pyrit.models import ( Message, + MessagePiece, construct_response_from_request, ) from pyrit.prompt_target.common.prompt_target import PromptTarget @@ -29,6 +31,7 @@ ) from pyrit.prompt_target.common.target_configuration import TargetConfiguration from pyrit.prompt_target.common.utils import limit_requests_per_minute, validate_temperature, validate_top_p +from pyrit.tools import ToolBackend, ToolCallParser logger = logging.getLogger(__name__) @@ -70,6 +73,8 @@ def __init__( repetition_penalty: float = 1.0, max_requests_per_minute: int | None = None, custom_configuration: TargetConfiguration | None = None, + tool_parser: ToolCallParser | None = None, + tool_backend: ToolBackend | None = None, **param_kwargs: Any, ) -> None: """ @@ -100,6 +105,17 @@ def __init__( will be capped at the value provided. custom_configuration (TargetConfiguration | None): Override the default configuration for this target instance. Useful for targets whose capabilities depend on deployment configuration. + tool_parser (ToolCallParser | None): When supplied, the target opts into PyRIT's + ``@tool_loop`` and uses this parser to extract pending tool calls from the + response. Supplying a parser also enables the ``supports_tool_use`` capability + on the default configuration so callers don't have to construct a custom + configuration just to enable the loop. The parser's expectations about the + deployment's response shape MUST line up with the contract documented in + ``doc/code/targets/`` for tool-capable Azure ML deployments. + tool_backend (ToolBackend | None): Convenience kwarg that wires a tool backend + onto ``custom_configuration.tool_backend``. Equivalent to constructing a + ``TargetConfiguration`` with the backend assigned. When ``custom_configuration`` + already specifies a backend, the kwarg is rejected. **param_kwargs: Additional parameters to pass to the model for generating responses. Example parameters can be found here: https://huggingface.co/docs/api-inference/tasks/text-generation. Note that the link above may not be comprehensive, and specific acceptable parameters may be @@ -145,6 +161,18 @@ def __init__( normalizer_overrides={CapabilityName.SYSTEM_PROMPT: message_normalizer}, ) + # Enable tool-use capability when a parser is supplied so callers + # don't need to construct a custom_configuration just to opt in. + if tool_parser is not None: + custom_configuration = self._enable_tool_use(configuration=custom_configuration) + + # tool_backend is a convenience kwarg; install it into the configuration. + if tool_backend is not None: + custom_configuration = self._install_tool_backend( + configuration=custom_configuration, + tool_backend=tool_backend, + ) + PromptTarget.__init__( self, max_requests_per_minute=max_requests_per_minute, @@ -163,6 +191,76 @@ def __init__( self._top_p = top_p self._repetition_penalty = repetition_penalty self._extra_parameters = param_kwargs + self._tool_parser_instance = tool_parser + + def _enable_tool_use(self, *, configuration: TargetConfiguration | None) -> TargetConfiguration: + """ + Return a configuration whose capabilities include ``supports_tool_use=True``. + + When ``configuration`` already has the capability set, returns it as-is. + Otherwise rebuilds the capabilities with ``supports_tool_use=True`` flipped + on and preserves every other field. + + Args: + configuration (TargetConfiguration | None): The user-supplied configuration, + or ``None`` to start from the class default. + + Returns: + TargetConfiguration: A configuration whose capabilities include + ``supports_tool_use=True``. + """ + source = configuration if configuration is not None else self._DEFAULT_CONFIGURATION + caps = source.capabilities + if caps.includes(capability=CapabilityName.TOOL_USE): + return source + updated_caps = TargetCapabilities( + supports_multi_message_pieces=caps.supports_multi_message_pieces, + supports_editable_history=caps.supports_editable_history, + supports_multi_turn=caps.supports_multi_turn, + supports_system_prompt=caps.supports_system_prompt, + supports_tool_use=True, + input_modalities=caps.input_modalities, + output_modalities=caps.output_modalities, + ) + return TargetConfiguration( + capabilities=updated_caps, + policy=source.policy, + tool_event_policy=source.tool_event_policy, + tool_backend=source.tool_backend, + ) + + @staticmethod + def _install_tool_backend( + *, + configuration: TargetConfiguration | None, + tool_backend: ToolBackend, + ) -> TargetConfiguration: + """ + Install ``tool_backend`` onto ``configuration``. Rejects double-supply. + + Args: + configuration (TargetConfiguration | None): The user-supplied configuration. + tool_backend (ToolBackend): The backend to install. + + Returns: + TargetConfiguration: The same ``configuration`` instance with the + backend installed. + + Raises: + ValueError: When ``configuration`` is ``None`` (no capability to attach + to), or when ``configuration.tool_backend`` is already set to a + different backend. + """ + if configuration is None: + raise ValueError( + "tool_backend kwarg requires capabilities.supports_tool_use=True; " + "supply tool_parser= so the default capabilities flip TOOL_USE on, " + "or build a custom_configuration explicitly." + ) + if configuration.tool_backend is not None and configuration.tool_backend is not tool_backend: + raise ValueError("tool_backend kwarg conflicts with custom_configuration.tool_backend; supply only one.") + configuration.tool_backend = tool_backend + return configuration def _build_identifier(self) -> ComponentIdentifier: """ @@ -224,17 +322,10 @@ async def _send_prompt_to_target_async(self, *, normalized_conversation: list[Me logger.info(f"Sending the following prompt to the prompt target: {request}") try: - resp_text = await self._complete_chat_async( - messages=normalized_conversation, - ) - - if not resp_text: - raise EmptyResponseException(message="The chat returned an empty response.") - - response_entry = construct_response_from_request(request=request, response_text_pieces=[resp_text]) + response_body = await self._complete_chat_async(messages=normalized_conversation) + response_entry = self._materialize_response(response=response_body, request=request) except HTTPStatusError as hse: if hse.response.status_code == 400: - # Handle Bad Request response_entry = handle_bad_request_exception(response_text=hse.response.text, request=request) elif hse.response.status_code == 429: raise RateLimitException from hse @@ -248,21 +339,23 @@ async def _send_prompt_to_target_async(self, *, normalized_conversation: list[Me async def _complete_chat_async( self, messages: list[Message], - ) -> str: + ) -> dict[str, Any]: """ - Completes a chat interaction by generating a response to the given input prompt. - - This is a synchronous wrapper for the asynchronous _generate_and_extract_response method. + Issue a single chat request and return the parsed JSON response body. Args: messages (list[Message]): The message objects containing the role and content. + Returns: + dict[str, Any]: The deserialized response body. Always includes an + ``output`` field (per the AML scoring-script contract). Tool-capable + deployments may additionally include a ``tool_calls`` field carrying + canonical envelopes. + Raises: EmptyResponseException: If the response from the chat is empty. + ValueError: If the parsed response body is missing the ``output`` field. Exception: For any other errors during the process. - - Returns: - str: The generated response message. """ headers = self._get_headers() payload = await self._construct_http_body_async(messages) @@ -271,15 +364,52 @@ async def _complete_chat_async( endpoint_uri=self._endpoint, method="POST", request_body=payload, headers=headers ) - try: - return str(response.json()["output"]) - except Exception as e: - if response.json() == {}: - raise EmptyResponseException(message="The chat returned an empty response.") from e - raise type(e)( - f"Exception obtaining response from the target. Returned response: {response.json()}. " - f"Exception: {str(e)}" - ) from e + body = response.json() + if not isinstance(body, dict) or body == {}: + raise EmptyResponseException(message="The chat returned an empty response.") + if "output" not in body: + raise ValueError(f"Response from the target did not include 'output'. Returned response: {body}.") + return body + + def _materialize_response(self, *, response: dict[str, Any], request: MessagePiece) -> Message: + """ + Build a ``Message`` from the parsed response body, handling tool calls. + + The deployment may include a ``tool_calls`` list when the model emits + canonical envelopes. Each envelope becomes its own ``function_call`` + MessagePiece so the ``CanonicalEnvelopeParser`` shipped with PyRIT can + recognize it without further translation. + + Args: + response (dict[str, Any]): The parsed response body returned from the endpoint. + request (MessagePiece): The request piece used to stamp identity onto each + response piece. + + Returns: + Message: The materialized response message. Has at least one piece; + when both ``output`` and ``tool_calls`` are present, the text piece + comes first followed by one function_call piece per envelope. + + Raises: + EmptyResponseException: If the response has neither output text nor tool calls. + """ + text = str(response.get("output") or "") + tool_envelopes = response.get("tool_calls") or [] + if not text and not tool_envelopes: + raise EmptyResponseException(message="The chat returned an empty response.") + + pieces: list[MessagePiece] = [] + if text: + text_piece = construct_response_from_request(request=request, response_text_pieces=[text]).message_pieces[0] + pieces.append(text_piece) + for envelope in tool_envelopes: + fc_piece = construct_response_from_request( + request=request, + response_text_pieces=[json.dumps(envelope, separators=(",", ":"))], + response_type="function_call", + ).message_pieces[0] + pieces.append(fc_piece) + return Message(message_pieces=pieces, skip_validation=True) async def _construct_http_body_async( self, @@ -297,10 +427,7 @@ async def _construct_http_body_async( wire_format = ChatMessageNormalizer() messages_dict = await wire_format.normalize_to_dicts_async(messages) - # Parameters include additional ones passed in through **kwargs. Those not accepted by the model will - # be ignored. We only include commonly supported parameters here - model-specific parameters like - # stop sequences should be passed via **param_kwargs since different models use different EOS tokens. - return { + body: dict[str, Any] = { "input_data": { "input_string": messages_dict, "parameters": { @@ -312,6 +439,29 @@ async def _construct_http_body_async( | self._extra_parameters, } } + schemas = self._tool_schemas() + if schemas: + body["tools"] = schemas + return body + + @property + def _tool_parser(self) -> ToolCallParser | None: + """Return the parser supplied at construction, if any.""" + return self._tool_parser_instance + + def _tool_schemas(self) -> list[dict[str, Any]]: + """ + Wrap the backend's schemas in the OpenAI Chat Completions ``tools`` shape. + + Tool-capable deployments are expected to forward ``tools`` into + ``tokenizer.apply_chat_template`` after unwrapping the ``{"type": + "function", "function": {...}}`` envelope. + + Returns: + list[dict[str, Any]]: One descriptor per advertised tool, or an + empty list when no backend is configured. + """ + return [{"type": "function", "function": schema} for schema in super()._tool_schemas()] def _get_headers(self) -> dict[str, str]: """ diff --git a/pyrit/prompt_target/common/discover_target_capabilities.py b/pyrit/prompt_target/common/discover_target_capabilities.py index 45600e6009..b6a79bbcb9 100644 --- a/pyrit/prompt_target/common/discover_target_capabilities.py +++ b/pyrit/prompt_target/common/discover_target_capabilities.py @@ -149,6 +149,7 @@ def _permissive_configuration( supports_json_output=True, supports_editable_history=True, supports_system_prompt=True, + supports_tool_use=True, input_modalities=merged_modalities, ) # Rebuild a fresh configuration from the instance's native capabilities so diff --git a/pyrit/prompt_target/common/prompt_target.py b/pyrit/prompt_target/common/prompt_target.py index b1ee5caaa2..9ff03df7b6 100644 --- a/pyrit/prompt_target/common/prompt_target.py +++ b/pyrit/prompt_target/common/prompt_target.py @@ -12,6 +12,7 @@ from pyrit.models.json_response_config import _JsonResponseConfig from pyrit.prompt_target.common.target_capabilities import CapabilityName, TargetCapabilities from pyrit.prompt_target.common.target_configuration import TargetConfiguration +from pyrit.tools import ToolCallParser, tool_loop logger = logging.getLogger(__name__) @@ -85,6 +86,7 @@ def __init__( logging.basicConfig(level=logging.INFO) @final + @tool_loop async def send_prompt_async(self, *, message: Message) -> list[Message]: """ Validate, normalize, and send a prompt to the target. @@ -97,6 +99,13 @@ async def send_prompt_async(self, *, message: Message) -> list[Message]: 3. Delegates to ``_send_prompt_to_target_async`` with the normalized conversation. + When the target's :attr:`configuration.tool_event_policy` is set, the + :func:`pyrit.tools.tool_loop` decorator replaces this body with the + agentic loop and re-enters :meth:`_send_prompt_to_target_async` + repeatedly until the model issues a stop response (or the configured + ``max_tool_iterations`` is hit). When no policy is set, the decorator + is a no-op and the body below runs unchanged. + Subclasses MUST NOT override this method. Override ``_send_prompt_to_target_async`` instead. @@ -132,6 +141,44 @@ async def _send_prompt_to_target_async(self, *, normalized_conversation: list[Me list[Message]: Response messages from the target. """ + @property + def _tool_parser(self) -> ToolCallParser | None: + """ + Per-target :class:`ToolCallParser` consulted by :func:`pyrit.tools.tool_loop`. + + Targets that participate in the tool-use loop override this property + to return a parser that walks their response messages and extracts + :class:`~pyrit.tools.ToolCall` instances. The base default of + ``None`` signals "this target does not participate" -- the wrapper + short-circuits after the first response. + + Returns: + ToolCallParser | None: The parser, or ``None`` for the default + no-tool-use behavior. + """ + return None + + def _tool_schemas(self) -> list[dict[str, Any]]: + """ + Outbound tool-schema list sent on the next request to the model. + + The default reads the configured ``tool_backend.schemas`` verbatim. + Targets whose wire format wraps schemas differently (e.g., OpenAI + Chat Completions requires ``{"type": "function", "function": {...}}``; + the OpenAI Responses API requires ``{"type": "function", **schema}`` + spread at the top level) override this method to apply the + per-target translation. + + Returns: + list[dict[str, Any]]: One schema per advertised tool, in + whatever wire format this target expects. Empty when no + backend is configured. + """ + backend = self.configuration.tool_backend + if backend is None: + return [] + return list(backend.schemas) + def _validate_request(self, *, normalized_conversation: list[Message]) -> None: """ Validate the normalized conversation before sending to the target. diff --git a/pyrit/prompt_target/common/target_capabilities.py b/pyrit/prompt_target/common/target_capabilities.py index 6ae9ed69e2..234ef4d359 100644 --- a/pyrit/prompt_target/common/target_capabilities.py +++ b/pyrit/prompt_target/common/target_capabilities.py @@ -24,6 +24,7 @@ class CapabilityName(str, Enum): JSON_OUTPUT = "supports_json_output" EDITABLE_HISTORY = "supports_editable_history" SYSTEM_PROMPT = "supports_system_prompt" + TOOL_USE = "supports_tool_use" class UnsupportedCapabilityBehavior(str, Enum): @@ -138,6 +139,14 @@ class attribute. Users can override individual capabilities per instance # Whether the target natively supports system prompts. supports_system_prompt: bool = False + # Whether the target natively supports model-issued tool calls (the + # canonical OpenAI ``function_call`` / ``function_call_output`` envelopes + # plus an outbound tool-schema list). Targets without this capability + # cannot host a tool-use loop -- attempting to configure a + # :class:`TargetConfiguration` with a ``tool_backend`` on a target whose + # capabilities have ``supports_tool_use=False`` raises at construction. + supports_tool_use: bool = False + # The input modalities supported by the target (e.g., "text", "image"). input_modalities: frozenset[frozenset[PromptDataType]] = frozenset({frozenset(["text"])}) diff --git a/pyrit/prompt_target/common/target_configuration.py b/pyrit/prompt_target/common/target_configuration.py index 7e11a04673..e9d401b824 100644 --- a/pyrit/prompt_target/common/target_configuration.py +++ b/pyrit/prompt_target/common/target_configuration.py @@ -4,7 +4,7 @@ import logging from collections.abc import Mapping from dataclasses import fields -from typing import Any +from typing import TYPE_CHECKING, Any from pyrit.message_normalizer import MessageListNormalizer from pyrit.models import Message @@ -16,6 +16,10 @@ UnsupportedCapabilityBehavior, ) +if TYPE_CHECKING: + from pyrit.tools.backend import ToolBackend + from pyrit.tools.models import ToolEventPolicy + logger = logging.getLogger(__name__) @@ -39,6 +43,15 @@ class TargetConfiguration: Each target defines defaults; callers can override policy or individual normalizers at creation time. + + Tool use is configured by setting :attr:`tool_event_policy` (mandatory + when a target's response contains tool calls; controls EXECUTE / RAISE / + RETURN\\_RAW behavior) and optionally :attr:`tool_backend` (required only + when ``tool_event_policy.behavior`` is ``EXECUTE``). Both default to + ``None`` and are read by :func:`pyrit.tools.tool_loop` at runtime; + constructing a configuration with a ``tool_backend`` on a target that + does not declare ``capabilities.supports_tool_use=True`` raises + immediately. """ def __init__( @@ -47,6 +60,8 @@ def __init__( capabilities: TargetCapabilities, policy: CapabilityHandlingPolicy | None = None, normalizer_overrides: Mapping[CapabilityName, MessageListNormalizer[Any]] | None = None, + tool_event_policy: "ToolEventPolicy | None" = None, + tool_backend: "ToolBackend | None" = None, ) -> None: """ Build a target configuration and resolve the normalization pipeline. @@ -57,7 +72,25 @@ def __init__( capability. Defaults to RAISE for all adaptable capabilities. normalizer_overrides (Mapping[CapabilityName, MessageListNormalizer[Any]] | None): Optional overrides for specific capability normalizers. + tool_event_policy (ToolEventPolicy | None): How + :func:`pyrit.tools.tool_loop` should react to a pending tool + call from the target. ``None`` means the loop is disabled and + the wrapper short-circuits. + tool_backend (ToolBackend | None): Dispatch table used when + ``tool_event_policy.behavior`` is ``EXECUTE``. ``None`` is + valid only for the RAISE / RETURN\\_RAW policies and the + no-policy passthrough. + + Raises: + ValueError: If ``tool_backend`` is set on a target whose + capabilities do not include ``supports_tool_use``. """ + if tool_backend is not None and not capabilities.includes(capability=CapabilityName.TOOL_USE): + raise ValueError( + "tool_backend is set but capabilities.supports_tool_use is False. " + "Either declare supports_tool_use=True on the target's capabilities, " + "or remove the tool_backend." + ) self._capabilities = capabilities self._policy = policy or _DEFAULT_POLICY self._pipeline = ConversationNormalizationPipeline.from_capabilities( @@ -65,6 +98,8 @@ def __init__( policy=self._policy, normalizer_overrides=normalizer_overrides, ) + self._tool_event_policy = tool_event_policy + self._tool_backend = tool_backend @property def capabilities(self) -> TargetCapabilities: @@ -81,6 +116,41 @@ def pipeline(self) -> ConversationNormalizationPipeline: """The resolved normalization pipeline.""" return self._pipeline + @property + def tool_event_policy(self) -> "ToolEventPolicy | None": + """The tool-use policy consulted by :func:`pyrit.tools.tool_loop`.""" + return self._tool_event_policy + + @tool_event_policy.setter + def tool_event_policy(self, value: "ToolEventPolicy | None") -> None: + """Allow runtime updates so callers can opt a configured target into tool use.""" + self._tool_event_policy = value + + @property + def tool_backend(self) -> "ToolBackend | None": + """The tool dispatch backend used when the loop's behavior is ``EXECUTE``.""" + return self._tool_backend + + @tool_backend.setter + def tool_backend(self, value: "ToolBackend | None") -> None: + """ + Allow runtime updates to the backend. + + Re-runs the ``supports_tool_use`` validator so a backend can never be + installed onto a configuration that does not declare the capability. + + Raises: + ValueError: If ``value`` is not ``None`` and the configuration's + capabilities do not include ``supports_tool_use``. + """ + if value is not None and not self._capabilities.includes(capability=CapabilityName.TOOL_USE): + raise ValueError( + "tool_backend is set but capabilities.supports_tool_use is False. " + "Either declare supports_tool_use=True on the target's capabilities, " + "or remove the tool_backend." + ) + self._tool_backend = value + def includes(self, *, capability: CapabilityName) -> bool: """ Check whether the target includes support for the given capability. @@ -138,14 +208,20 @@ def as_identifier_params(self) -> dict[str, Any]: suitable for inclusion in a ``ComponentIdentifier``. The returned dict preserves the structure of ``TargetConfiguration`` - — capabilities, policy, and pipeline are kept as nested sub-dicts rather - than flattened into the caller — so the identifier reflects the shape of - the object it describes. + — capabilities, policy, pipeline, tool-event policy, and tool backend + are kept as nested sub-dicts rather than flattened into the caller — + so the identifier reflects the shape of the object it describes. Two configurations that behave identically must produce equal dicts; configurations that differ in any identity-bearing field must produce - unequal dicts. Modality sets are flattened to sorted lists of sorted - lists so ordering is stable across runs. + unequal dicts. The tool-backend snapshot uses the backend class plus + the sorted list of advertised tool names; this means two backends of + the same type exposing the same tool set are treated as equivalent + for identifier purposes (their exact callables / transports are not + deterministically serializable). + + Modality sets are flattened to sorted lists of sorted lists so + ordering is stable across runs. Returns: dict[str, Any]: The identifier parameters for this configuration. @@ -153,24 +229,82 @@ def as_identifier_params(self) -> dict[str, Any]: caps = self._capabilities return { "capabilities": self._capabilities_to_identifier_params(caps), - # Only unsupported capabilities appear here. Policy entries for - # natively-supported capabilities are moot (the behavior never - # fires), and omitting them keeps identifiers stable when default - # policies expand to cover more capabilities. "capability_policy": { capability.value: behavior.value for capability, behavior in self._policy.behaviors.items() if not caps.includes(capability=capability) }, - # Stable, ordered representation of the resolved normalization - # pipeline. Captures the effect of ``normalizer_overrides`` since - # the pipeline is built from defaults + overrides. "normalization_pipeline": [ f"{type(normalizer).__module__}.{type(normalizer).__qualname__}" for normalizer in self._pipeline.normalizers ], + "tool_event_policy": self._tool_event_policy_to_identifier_params(), + "tool_backend": self._tool_backend_to_identifier_params(), } + def _tool_event_policy_to_identifier_params(self) -> dict[str, Any] | None: + """ + Snapshot the active tool-event policy as identifier params. + + Returns: + dict[str, Any] | None: ``None`` when no policy is configured; + otherwise ``behavior`` and ``max_tool_iterations`` as plain + values. + """ + if self._tool_event_policy is None: + return None + return { + "behavior": self._tool_event_policy.behavior.value, + "max_tool_iterations": self._tool_event_policy.max_tool_iterations, + } + + def _tool_backend_to_identifier_params(self) -> dict[str, Any] | None: + """ + Snapshot the active tool backend as identifier params. + + Returns the backend's fully-qualified class name plus the sorted + list of tool names it advertises. Exact callables / transports are + not serialized; two backends of the same type exposing the same + tool set therefore produce equal identifier dicts. + + Returns: + dict[str, Any] | None: ``None`` when no backend is configured; + otherwise ``type`` (fully-qualified class name) and + ``tools`` (sorted list of advertised tool names). + """ + if self._tool_backend is None: + return None + backend_type = type(self._tool_backend) + return { + "type": f"{backend_type.__module__}.{backend_type.__qualname__}", + "tools": sorted(self._extract_tool_names(self._tool_backend.schemas)), + } + + @staticmethod + def _extract_tool_names(schemas: list[dict[str, Any]]) -> list[str]: + """ + Pull the ``name`` field from each schema, supporting both flat and + nested ``function`` envelopes. + + Args: + schemas (list[dict[str, Any]]): The backend-provided schema list. + + Returns: + list[str]: One name per schema. Schemas without a recoverable + name are skipped silently — the identifier is best-effort + for shape-quirky backends. + """ + names: list[str] = [] + for schema in schemas: + if not isinstance(schema, dict): + continue + name = schema.get("name") + if not name and isinstance(schema.get("function"), dict): + name = schema["function"].get("name") + if isinstance(name, str): + names.append(name) + return names + @staticmethod def _capabilities_to_identifier_params(capabilities: TargetCapabilities) -> dict[str, Any]: """ diff --git a/pyrit/prompt_target/hugging_face/hugging_face_chat_target.py b/pyrit/prompt_target/hugging_face/hugging_face_chat_target.py index f2d62be82a..055cd55c70 100644 --- a/pyrit/prompt_target/hugging_face/hugging_face_chat_target.py +++ b/pyrit/prompt_target/hugging_face/hugging_face_chat_target.py @@ -22,9 +22,13 @@ from pyrit.identifiers import ComponentIdentifier from pyrit.models import Message, construct_response_from_request from pyrit.prompt_target.common.prompt_target import PromptTarget -from pyrit.prompt_target.common.target_capabilities import TargetCapabilities +from pyrit.prompt_target.common.target_capabilities import ( + CapabilityName, + TargetCapabilities, +) from pyrit.prompt_target.common.target_configuration import TargetConfiguration from pyrit.prompt_target.common.utils import limit_requests_per_minute +from pyrit.tools import ToolBackend, ToolCallParser logger = logging.getLogger(__name__) @@ -77,6 +81,8 @@ def __init__( attn_implementation: str | None = None, max_requests_per_minute: int | None = None, custom_configuration: TargetConfiguration | None = None, + tool_parser: ToolCallParser | None = None, + tool_backend: ToolBackend | None = None, ) -> None: """ Initialize the HuggingFaceChatTarget. @@ -108,6 +114,15 @@ def __init__( max_requests_per_minute (int | None): The maximum number of requests per minute. Defaults to None. custom_configuration (TargetConfiguration | None): Override the default configuration for this target instance. Defaults to None. + tool_parser (ToolCallParser | None): When supplied, the target opts into PyRIT's + ``@tool_loop`` and uses this parser to extract pending tool calls from each + generated response. Supplying a parser also enables ``supports_tool_use=True`` + on the default capabilities so callers don't need a custom_configuration just + to opt in. ``InlineToolCallParser`` is the typical choice because the local + tokenizer emits tool calls as inline marker-delimited JSON; supply a different + parser when targeting a chat template with a different marker syntax. + tool_backend (ToolBackend | None): Convenience kwarg that installs the backend + onto ``custom_configuration.tool_backend``. Raises: ValueError: If neither or both of `model_id` and `model_path` are provided. @@ -115,6 +130,16 @@ def __init__( """ model_name = model_id if model_id else model_path if model_path else "" + # Enable tool-use capability when a parser is supplied, BEFORE super().__init__ + # so the configuration is correct by the time the base class records it. + if tool_parser is not None: + custom_configuration = self._enable_tool_use(configuration=custom_configuration) + if tool_backend is not None: + custom_configuration = self._install_tool_backend( + configuration=custom_configuration, + tool_backend=tool_backend, + ) + super().__init__( max_requests_per_minute=max_requests_per_minute, model_name=model_name, @@ -174,6 +199,81 @@ def __init__( raise RuntimeError("CUDA requested but not available.") self.load_model_and_tokenizer_task = asyncio.create_task(self.load_model_and_tokenizer()) + self._tool_parser_instance = tool_parser + + @classmethod + def _enable_tool_use(cls, *, configuration: TargetConfiguration | None) -> TargetConfiguration: + """ + Return a configuration whose capabilities include ``supports_tool_use=True``. + + When ``configuration`` already has the capability set, returns it as-is. + Otherwise rebuilds the capabilities with ``supports_tool_use=True`` and + preserves every other field. + + Args: + configuration (TargetConfiguration | None): The user-supplied configuration, + or ``None`` to start from the class default. + + Returns: + TargetConfiguration: A configuration whose capabilities include + ``supports_tool_use=True``. + """ + source = configuration if configuration is not None else cls._DEFAULT_CONFIGURATION + caps = source.capabilities + if caps.includes(capability=CapabilityName.TOOL_USE): + return source + updated_caps = TargetCapabilities( + supports_multi_message_pieces=caps.supports_multi_message_pieces, + supports_editable_history=caps.supports_editable_history, + supports_multi_turn=caps.supports_multi_turn, + supports_system_prompt=caps.supports_system_prompt, + supports_tool_use=True, + input_modalities=caps.input_modalities, + output_modalities=caps.output_modalities, + ) + return TargetConfiguration( + capabilities=updated_caps, + policy=source.policy, + tool_event_policy=source.tool_event_policy, + tool_backend=source.tool_backend, + ) + + @staticmethod + def _install_tool_backend( + *, + configuration: TargetConfiguration | None, + tool_backend: ToolBackend, + ) -> TargetConfiguration: + """ + Install ``tool_backend`` onto ``configuration``. Rejects double-supply. + + Args: + configuration (TargetConfiguration | None): The user-supplied configuration. + tool_backend (ToolBackend): The backend to install. + + Returns: + TargetConfiguration: The same ``configuration`` instance with the + backend installed. + + Raises: + ValueError: When ``configuration`` is ``None``, or when + ``configuration.tool_backend`` is already set to a different backend. + """ + if configuration is None: + raise ValueError( + "tool_backend kwarg requires capabilities.supports_tool_use=True; " + "supply tool_parser= so the default capabilities flip TOOL_USE on, " + "or build a custom_configuration explicitly." + ) + if configuration.tool_backend is not None and configuration.tool_backend is not tool_backend: + raise ValueError("tool_backend kwarg conflicts with custom_configuration.tool_backend; supply only one.") + configuration.tool_backend = tool_backend + return configuration + + @property + def _tool_parser(self) -> ToolCallParser | None: + """Return the parser supplied at construction, if any.""" + return self._tool_parser_instance def _build_identifier(self) -> ComponentIdentifier: """ @@ -401,27 +501,79 @@ async def _send_prompt_to_target_async(self, *, normalized_conversation: list[Me logger.error(f"Error occurred during inference: {e}") raise - def _build_chat_messages(self, *, normalized_conversation: list[Message]) -> list[dict[str, str]]: + def _build_chat_messages(self, *, normalized_conversation: list[Message]) -> list[dict[str, Any]]: """ Build a list of chat message dicts from the full normalized conversation. Includes system, user, and assistant messages from the conversation history - so that the model's chat template receives the complete context. + so that the model's chat template receives the complete context. When the + conversation contains tool-call envelopes (produced by ``@tool_loop``), they + are converted into the chat-template's tool message shape: + + * ``function_call`` pieces become an ``assistant`` message with a + ``tool_calls`` list (matching the HuggingFace ``apply_chat_template`` + convention; templates that don't recognize ``tool_calls`` fall back to + rendering the embedded JSON as content). + * ``function_call_output`` pieces become a ``role=tool`` message with the + tool result as content and ``tool_call_id`` carried for templates that + need it. Args: normalized_conversation (list[Message]): The full normalized conversation. Returns: - list[dict[str, str]]: Messages formatted for the chat template. + list[dict[str, Any]]: Messages formatted for the chat template. """ - messages: list[dict[str, str]] = [] + messages: list[dict[str, Any]] = [] for msg in normalized_conversation: - piece = msg.message_pieces[0] - role = piece.api_role - content = piece.converted_value or "" - messages.append({"role": role, "content": content}) + for piece in msg.message_pieces: + tool_dict = self._maybe_tool_chat_message(piece=piece) + if tool_dict is not None: + messages.append(tool_dict) + continue + role = piece.api_role + content = piece.converted_value or "" + messages.append({"role": role, "content": content}) return messages + @staticmethod + def _maybe_tool_chat_message(*, piece: Any) -> dict[str, Any] | None: + """ + Convert a ``function_call`` or ``function_call_output`` piece to a chat-template message. + + Args: + piece (Any): The MessagePiece to inspect. + + Returns: + dict[str, Any] | None: A chat-template message dict (``assistant`` with + ``tool_calls``, or ``role=tool`` with ``tool_call_id``) when the + piece carries a tool envelope, otherwise ``None``. + """ + data_type = piece.converted_value_data_type or piece.original_value_data_type + if data_type not in ("function_call", "function_call_output"): + return None + envelope = json.loads(piece.converted_value) + if data_type == "function_call": + return { + "role": "assistant", + "content": "", + "tool_calls": [ + { + "id": envelope.get("call_id", ""), + "type": "function", + "function": { + "name": envelope.get("name", ""), + "arguments": envelope.get("arguments", "{}"), + }, + } + ], + } + return { + "role": "tool", + "content": str(envelope.get("output", "")), + "tool_call_id": envelope.get("call_id", ""), + } + def set_random_seed(self, random_seed: int) -> None: """ Set a new random seed and immediately re-seed the RNG. @@ -520,6 +672,13 @@ def _apply_chat_template(self, messages: list[dict[str, str]]) -> Any: """ Apply the chat template to the input messages and tokenize them. + When ``self._tool_schemas()`` is non-empty, the schemas are forwarded + into ``apply_chat_template`` so tool-trained chat templates can render + the model-family-specific tools block (Qwen wraps in ``...``, + Llama uses a system-message preamble, etc.). The model can then emit + tool calls in its native marker syntax which the user-supplied + ``tool_parser`` extracts. + Args: messages: The input messages to apply the chat template to. @@ -533,6 +692,11 @@ def _apply_chat_template(self, messages: list[dict[str, str]]) -> Any: if hasattr(self.tokenizer, "chat_template") and self.tokenizer.chat_template is not None: logger.info("Tokenizer has a chat template. Applying it to the input messages.") + template_kwargs: dict[str, Any] = {} + schemas = self._tool_schemas() + if schemas: + template_kwargs["tools"] = schemas + # Apply the chat template to format and tokenize the messages return cast( "BatchEncoding", @@ -542,6 +706,7 @@ def _apply_chat_template(self, messages: list[dict[str, str]]) -> Any: add_generation_prompt=True, return_tensors=self.tensor_format, return_dict=True, + **template_kwargs, ), ).to(self.device) error_message = ( diff --git a/pyrit/prompt_target/openai/openai_response_target.py b/pyrit/prompt_target/openai/openai_response_target.py index f2e4b19a76..91f8bc05f0 100644 --- a/pyrit/prompt_target/openai/openai_response_target.py +++ b/pyrit/prompt_target/openai/openai_response_target.py @@ -3,6 +3,7 @@ import json import logging +import warnings from collections.abc import Awaitable, Callable, MutableSequence from enum import Enum from typing import ( @@ -34,6 +35,13 @@ from pyrit.prompt_target.common.utils import limit_requests_per_minute, validate_temperature, validate_top_p from pyrit.prompt_target.openai.openai_error_handling import _is_content_filter_error from pyrit.prompt_target.openai.openai_target import OpenAITarget +from pyrit.tools import ( + CanonicalEnvelopeParser, + LocalToolBackend, + ToolCallParser, + ToolEventBehavior, + ToolEventPolicy, +) logger = logging.getLogger(__name__) @@ -76,6 +84,7 @@ class OpenAIResponseTarget(OpenAITarget, PromptTarget): supports_json_output=True, supports_multi_message_pieces=True, supports_system_prompt=True, + supports_tool_use=True, input_modalities=frozenset( { frozenset(["text"]), @@ -154,6 +163,17 @@ def __init__( """ super().__init__(custom_configuration=custom_configuration, **kwargs) + # If the constructed configuration is the class-level _DEFAULT_CONFIGURATION + # singleton (user did not pass custom_configuration AND the underlying_model + # was unrecognized), rebuild a per-instance copy so the C6 tool-backend + # plumbing below does not mutate state shared across every other instance. + if custom_configuration is None and self._configuration is type(self)._DEFAULT_CONFIGURATION: + caps = self._configuration.capabilities + self._configuration = TargetConfiguration( + capabilities=caps, + policy=self._configuration.policy, + ) + # Validate temperature and top_p validate_temperature(temperature) validate_top_p(top_p) @@ -167,10 +187,39 @@ def __init__( self._extra_body_parameters = extra_body_parameters - # Per-instance tool/func registries: - self._custom_functions: dict[str, ToolExecutor] = custom_functions or {} + # ----- Tool-calling plumbing (C6) --------------------------------- + # custom_functions is deprecated as of 0.15.x. New code configures + # tool_backend on TargetConfiguration directly. The kwarg is still + # accepted; we ALWAYS install a LocalToolBackend (whether populated + # or empty) when no other backend is supplied, so legacy in-place + # mutations of `target._custom_functions` (via the back-compat + # property below) keep affecting dispatch. self._fail_on_missing_function: bool = fail_on_missing_function + if self.configuration.tool_backend is None: + shim_backend = LocalToolBackend( + callables=dict(custom_functions) if custom_functions else {}, + schemas=self._derive_default_schemas(custom_functions or {}), + fail_on_missing_function=fail_on_missing_function, + ) + self.configuration.tool_backend = shim_backend + + if custom_functions: + warnings.warn( + "OpenAIResponseTarget(custom_functions=...) is deprecated and will be " + "removed in 0.16.0. Configure tool_backend on TargetConfiguration " + "instead (e.g. LocalToolBackend(callables=..., schemas=..., " + "fail_on_missing_function=...)).", + DeprecationWarning, + stacklevel=2, + ) + + # Default policy to EXECUTE when a backend is present. The wrapper's + # parser returns an empty list when the model produces no tool calls, + # so this is a no-op for plain text completions. + if self.configuration.tool_event_policy is None: + self.configuration.tool_event_policy = ToolEventPolicy(behavior=ToolEventBehavior.EXECUTE) + # Extract the grammar 'tool' if one is present # See # https://platform.openai.com/docs/guides/function-calling#context-free-grammars @@ -185,6 +234,61 @@ def __init__( logger.debug("Detected grammar tool: %s", tool_name) self._grammar_name = tool_name + @staticmethod + def _derive_default_schemas(callables: dict[str, ToolExecutor]) -> list[dict[str, Any]]: + """ + Synthesize minimal JSON schemas for the deprecation-shim path. + + Users who pass the legacy ``custom_functions`` kwarg do not also pass a + schema list (the Responses API would accept the calls anyway because the + legacy path predates structured tool advertisement). To keep the + deprecation shim transparent we generate a schema-less stub per name so + ``_tool_schemas()`` returns something non-empty when the user actually + wires tools. + + Args: + callables: Function name to async callable mapping. + + Returns: + list[dict[str, Any]]: A bare schema per callable (``parameters`` + is the unconstrained empty-object schema). + """ + return [{"name": name, "parameters": {"type": "object"}} for name in callables] + + @property + def _custom_functions(self) -> dict[str, ToolExecutor]: + """ + Back-compat live view of the active backend's callables registry. + + Mutations on the returned dict (``target._custom_functions[name] = fn``, + ``target._custom_functions.pop(name)``) take effect immediately because + the dict object is shared with the underlying + :class:`pyrit.tools.LocalToolBackend`. Returns an empty dict when no + backend is installed or when the configured backend is not a + ``LocalToolBackend``. + + Returns: + dict[str, ToolExecutor]: The live callables dict. + """ + backend = self.configuration.tool_backend + if isinstance(backend, LocalToolBackend): + return cast("dict[str, ToolExecutor]", backend._callables) + return {} + + @_custom_functions.setter + def _custom_functions(self, value: dict[str, ToolExecutor]) -> None: + backend = self.configuration.tool_backend + if isinstance(backend, LocalToolBackend): + backend._callables = dict(value) + backend._schemas = self._derive_default_schemas(value) + return + new_backend = LocalToolBackend( + callables=dict(value), + schemas=self._derive_default_schemas(value), + fail_on_missing_function=self._fail_on_missing_function, + ) + self.configuration.tool_backend = new_backend + def _build_identifier(self) -> ComponentIdentifier: """ Build the identifier with OpenAI response-specific parameters. @@ -378,8 +482,9 @@ async def _construct_request_body( input_items = await self._build_input_for_multi_modal_async(conversation) text_format = self._build_text_format(json_config=json_config) + tool_schemas = self._tool_schemas() - body_parameters = { + body_parameters: dict[str, Any] = { "model": self._model_name, "max_output_tokens": self._max_output_tokens, "temperature": self._temperature, @@ -390,8 +495,11 @@ async def _construct_request_body( "text": text_format, "reasoning": self._build_reasoning_config(), } + if tool_schemas: + body_parameters["tools"] = tool_schemas if self._extra_body_parameters: + # User-supplied extra_body_parameters wins over backend-derived tools. body_parameters.update(self._extra_body_parameters) # Filter out None values @@ -559,11 +667,18 @@ async def _construct_message_from_response(self, response: Any, request: Message @pyrit_target_retry async def _send_prompt_to_target_async(self, *, normalized_conversation: list[Message]) -> list[Message]: """ - Send prompt, handle agentic tool calls (function_call), return all messages. + Send one prompt to the Responses API and return exactly one Message. + + The agentic tool-calling loop now lives in :func:`pyrit.tools.tool_loop` + on the base class. This method is the single-iteration body the loop + re-enters on each turn: build the request body, call the API, parse the + response, return the constructed :class:`Message` wrapped in a list of + length 1. - The Responses API supports structured outputs and tool execution. This method handles both: - - Simple text/reasoning responses - - Agentic tool-calling loops that may require multiple back-and-forth exchanges + The wrapper detects function_call pieces via :attr:`_tool_parser` and + decides whether to dispatch + re-enter. Reasoning, MCP, web-search, + computer-use, and other non-function-call sections pass through to + Memory unchanged because the parser ignores them. Args: normalized_conversation (list[Message]): The full conversation @@ -571,59 +686,51 @@ async def _send_prompt_to_target_async(self, *, normalized_conversation: list[Me pipeline. The current message is the last element. Returns: - List of messages generated during the interaction (assistant responses and tool messages). - The normalizer will persist all of these to memory. + list[Message]: Exactly one Message wrapping the parsed response. """ message = normalized_conversation[-1] - message_piece: MessagePiece = message.message_pieces[0] last_piece = message.message_pieces[-1] json_config = self._get_json_response_config(message_piece=last_piece) - working_conversation: MutableSequence[Message] = list(normalized_conversation) - - # Track all responses generated during this interaction - responses_to_return: list[Message] = [] - - # Main agentic loop - each back-and-forth creates a new message - tool_call_section: Optional[dict[str, Any]] = None - - while True: - logger.info(f"Sending conversation with {len(working_conversation)} messages to the prompt target") - - body = await self._construct_request_body(conversation=working_conversation, json_config=json_config) - - # Use unified error handling - automatically detects Response and validates - result = await self._handle_openai_request( - api_call=lambda body=body: self._client.responses.create(**body), - request=message, - ) - - # Add result to conversation and responses list - working_conversation.append(result) - responses_to_return.append(result) - - # Extract tool call if present - tool_call_section = self._find_last_pending_tool_call(result) - - # If no tool call, we're done - if not tool_call_section: - break - - # Execute the tool/function - tool_output = await self._execute_call_section(tool_call_section) + body = await self._construct_request_body(conversation=list(normalized_conversation), json_config=json_config) + logger.info("Sending conversation with %d messages to the Responses API", len(normalized_conversation)) + result = await self._handle_openai_request( + api_call=lambda body=body: self._client.responses.create(**body), + request=message, + ) + return [result] - # Create a new message with the tool output - tool_piece = self._make_tool_piece(tool_output, tool_call_section["call_id"], reference_piece=message_piece) - tool_message = Message(message_pieces=[tool_piece], skip_validation=True) + @property + def _tool_parser(self) -> ToolCallParser | None: + """ + Canonical-envelope parser shared with future canonical-envelope targets. + + Walks response message pieces and emits one :class:`~pyrit.tools.ToolCall` + per piece whose ``original_value_data_type`` is ``"function_call"``. + Reasoning, MCP, web-search, computer-use, and local-shell sections all + produce pieces of OTHER data types, so the parser returns an empty list + for them and the @tool_loop decorator exits cleanly. Those sections + still land in Memory via the parsed Message returned by + ``_send_prompt_to_target_async``; they're just not client-side + dispatched. + """ + return CanonicalEnvelopeParser() - # Add tool output message to conversation and responses list - working_conversation.append(tool_message) - responses_to_return.append(tool_message) + def _tool_schemas(self) -> list[dict[str, Any]]: + """ + Translate the configured backend's schemas into Responses-API tools shape. - # Continue loop to send tool result and get next response + The Responses API expects each function tool as a top-level + ``{"type": "function", "name": ..., "description": ..., + "parameters": ...}`` entry (NOT wrapped in an inner ``"function"`` key + the way Chat Completions does). The backend's schemas are already the + bare function schema, so we just stamp ``type=function`` on each. - # Return all responses (normalizer will persist all of them to memory) - return responses_to_return + Returns: + list[dict[str, Any]]: One descriptor per advertised tool, or an + empty list when no backend is configured. + """ + return [{"type": "function", **schema} for schema in super()._tool_schemas()] def _parse_response_output_section( self, *, section: Any, message_piece: MessagePiece, error: Optional[PromptResponseError] diff --git a/pyrit/tools/__init__.py b/pyrit/tools/__init__.py new file mode 100644 index 0000000000..e91b44a037 --- /dev/null +++ b/pyrit/tools/__init__.py @@ -0,0 +1,75 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +Generic tool-use scaffolding for ``PromptTarget``. + +This package provides a transport-agnostic tool-calling loop. The +``tool_loop`` decorator, when applied to ``send_prompt_async``, runs +the standard PyRIT validate+normalize work once and then repeatedly +re-enters the target's protected ``_send_prompt_to_target_async`` until +the model issues a stop response (or a configured limit is hit). + +A target opts in by declaring two collaborators: + +* ``self._tool_parser`` — a ``ToolCallParser`` that walks a + response message and extracts pending ``ToolCall`` instances. +* ``self.configuration.tool_event_policy`` — a ``ToolEventPolicy`` + whose ``ToolEventBehavior`` decides whether to ``EXECUTE``, + ``RAISE``, or ``RETURN_RAW`` on each detected call. + +When the policy is ``EXECUTE``, calls are dispatched through +``self.configuration.tool_backend``, an implementation of +``ToolBackend``. ``LocalToolBackend`` is the in-process backend; +``MCPToolBackend`` proxies through one or more MCP servers. + +The ``ToolBackend`` abstract base is intentionally distinct from +``pyrit.registry`` — that namespace is reserved for framework-level +identity registries (``TargetRegistry``, ``ScorerRegistry``) that +register named singletons for CLI lookup, which a per-target tool +dispatch table is not. + +``@tool_loop`` is wired onto ``PromptTarget.send_prompt_async`` from +the base class, and the ``tool_event_policy`` / ``tool_backend`` +fields hang off ``TargetConfiguration``. + +The two exception types the loop raises +(``ToolCallNotSupported`` and +``ToolCallLoopLimitExceeded``) live in +``pyrit.exceptions`` alongside the rest of PyRIT's exception +catalog, so non-tools callers (attacks, normalizers) can import them +without taking a subsystem-level dependency on ``pyrit.tools``. +""" + +from pyrit.tools.backend import ToolBackend +from pyrit.tools.inline_parser import InlineToolCallParser, InlineToolCallParserMode +from pyrit.tools.local_backend import LocalToolBackend +from pyrit.tools.mcp_backend import MCPToolBackend +from pyrit.tools.mcp_client import ( + DockerMCPServerSpec, + LocalMCPServerSpec, + MCPClient, + MCPServerSpec, + RemoteMCPServerSpec, +) +from pyrit.tools.models import ToolCall, ToolEventBehavior, ToolEventPolicy, tool_loop +from pyrit.tools.parsers import CanonicalEnvelopeParser, ToolCallParser + +__all__ = [ + "CanonicalEnvelopeParser", + "DockerMCPServerSpec", + "InlineToolCallParser", + "InlineToolCallParserMode", + "LocalMCPServerSpec", + "LocalToolBackend", + "MCPClient", + "MCPServerSpec", + "MCPToolBackend", + "RemoteMCPServerSpec", + "ToolBackend", + "ToolCall", + "ToolCallParser", + "ToolEventBehavior", + "ToolEventPolicy", + "tool_loop", +] diff --git a/pyrit/tools/backend.py b/pyrit/tools/backend.py new file mode 100644 index 0000000000..d878bd1f4f --- /dev/null +++ b/pyrit/tools/backend.py @@ -0,0 +1,87 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from pyrit.tools.models import ToolCall + + +class ToolBackend(ABC): + """ + Abstract base for backends that dispatch tool calls produced by a target. + + A ``ToolBackend`` is a per-target dispatch table — it owns the + ``name -> async callable`` mapping a target uses to execute the tool + calls extracted from a model response. This is intentionally distinct + from ``pyrit.registry``, whose ``Registry`` classes register named + framework singletons (targets, scorers, attacks) for CLI lookup. + + Two concrete implementations ship with PyRIT: + + * ``LocalToolBackend`` — in-process backend backed + by ``async def`` callables. Useful for unit tests and for embedding + tools inside the PyRIT process. + * ``MCPToolBackend`` — proxies dispatch through one + or more MCP servers. + + Subclasses MUST implement ``schemas`` and ``dispatch_async``. + ``dispatch_all_sequential_async`` ships with a default + implementation that awaits ``dispatch_async`` once per call in + declaration order; backends that wish to parallelize dispatch + (e.g. fan out across multiple sandbox containers) should override it. + """ + + @property + @abstractmethod + def schemas(self) -> list[dict[str, Any]]: + """ + The JSON-schema descriptors for every tool the backend exposes. + + Returns: + list[dict[str, Any]]: One schema per tool, in a target-agnostic + format that concrete targets serialize into their request + body. + """ + + @abstractmethod + async def dispatch_async(self, call: ToolCall) -> dict[str, Any]: + """ + Execute a single tool call and return the structured result. + + Implementations MUST NOT raise on tool-side failures; they MUST + return an error envelope (e.g. ``{"error": "...", "tool": "..."}``) + so the tool loop can carry the failure back to the model. + + Args: + call (ToolCall): The tool call to dispatch. + + Returns: + dict[str, Any]: The structured tool result. + """ + + async def dispatch_all_sequential_async( + self, + calls: list[ToolCall], + ) -> list[tuple[ToolCall, dict[str, Any]]]: + """ + Dispatch every call in *calls* sequentially, preserving declaration order. + + Default implementation: ``await dispatch_async`` once per call. + Backends that parallelize dispatch should override this method. + + Args: + calls (list[ToolCall]): The calls to dispatch, in declaration order. + + Returns: + list[tuple[ToolCall, dict[str, Any]]]: ``(call, result)`` pairs, + in the same order as *calls*. + """ + results: list[tuple[ToolCall, dict[str, Any]]] = [] + for call in calls: + envelope = await self.dispatch_async(call) + results.append((call, envelope)) + return results diff --git a/pyrit/tools/inline_parser.py b/pyrit/tools/inline_parser.py new file mode 100644 index 0000000000..c7cc353fe4 --- /dev/null +++ b/pyrit/tools/inline_parser.py @@ -0,0 +1,202 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +"""Inline tool-call parser for open chat-tuned models.""" + +from __future__ import annotations + +import enum +import json +import logging +import re +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from pyrit.models import Message + from pyrit.tools.models import ToolCall + +logger = logging.getLogger(__name__) + + +class InlineToolCallParserMode(enum.Enum): + """Policy for handling text that surrounds inline tool-call markers.""" + + TRUNCATE_AT_LAST = "truncate_at_last" + TRUNCATE_AT_FIRST = "truncate_at_first" + EXTRACT_ALL = "extract_all" + STRICT_TRAILING_EMPTY = "strict_trailing_empty" + + +class InlineToolCallParser: + """ + Extract canonical ``ToolCall`` instances from marker-delimited JSON blocks. + + Open chat-tuned models that do not expose a structured ``tool_calls`` + channel typically emit tool calls as inline text wrapped in a + chat-template-specific marker (an angle-bracket pair, a pipe-delimited + tag pair, a square-bracketed list payload, and so on). This parser + walks every ``MessagePiece`` whose ``original_value_data_type`` is + ``"text"`` and runs ``marker_pattern`` against the piece's + ``original_value``. Each match capture group is decoded as JSON of the + form ``{"name": ..., "arguments": {...}}``. Synthetic ``call_id`` + values are minted positionally because inline-marker formats do not + issue provider IDs. + + The ``mode`` parameter controls how text surrounding the markers is + treated -- see ``InlineToolCallParserMode``. The default + ``TRUNCATE_AT_LAST`` honors every marker but discards anything after + the last one so hallucinated "tool results" that the model dreams up + after the call are not persisted as if they were real outputs. + + Args: + marker_pattern (str): Regex with exactly one capture group returning + the JSON payload. Default targets the angle-bracket + ``...`` syntax used by many tool-trained + ChatML-style chat templates. + call_id_prefix (str): Prefix for synthetic ``call_id`` values. + mode (InlineToolCallParserMode): Surrounding-text policy. + """ + + def __init__( + self, + *, + marker_pattern: str = r"(.*?)", + call_id_prefix: str = "call", + mode: InlineToolCallParserMode = InlineToolCallParserMode.TRUNCATE_AT_LAST, + ) -> None: + """ + Build an ``InlineToolCallParser``. + + Args: + marker_pattern (str): See class docstring. + call_id_prefix (str): See class docstring. + mode (InlineToolCallParserMode): See class docstring. + """ + self._pattern = re.compile(marker_pattern, re.DOTALL) + self._call_id_prefix = call_id_prefix + self._mode = mode + + @property + def mode(self) -> InlineToolCallParserMode: + """The active surrounding-text policy.""" + return self._mode + + def parse(self, message: Message) -> list[ToolCall]: + """ + Extract tool calls from every text piece in ``message``. + + Args: + message (Message): The most recent assistant response. + + Returns: + list[ToolCall]: One ``ToolCall`` per valid marker match, in + declaration order across pieces. Empty when no markers + are found. + + Raises: + ValueError: When ``mode`` is ``STRICT_TRAILING_EMPTY`` and any + non-whitespace text follows the last marker in any piece. + """ + calls: list[ToolCall] = [] + next_id = 0 + for piece in message.message_pieces: + if piece.original_value_data_type != "text": + continue + matches = self._match_piece(text=piece.original_value) + if not matches: + continue + for match in matches: + call = self._build_call(match=match, next_id=next_id) + if call is None: + continue + calls.append(call) + next_id += 1 + + return calls + + def _match_piece(self, *, text: str) -> list[re.Match[str]]: + """ + Apply the mode-specific filter to all marker matches in ``text``. + + Args: + text (str): Piece text to scan for markers. + + Returns: + list[re.Match[str]]: Matches to honor, in declaration order. + + Raises: + ValueError: When ``mode`` is ``STRICT_TRAILING_EMPTY`` and any + non-whitespace text follows the last marker. + """ + matches = list(self._pattern.finditer(text)) + if not matches: + return matches + + if self._mode is InlineToolCallParserMode.TRUNCATE_AT_FIRST: + return matches[:1] + if self._mode is InlineToolCallParserMode.STRICT_TRAILING_EMPTY: + trailing = text[matches[-1].end() :] + if trailing.strip(): + raise ValueError( + "Non-whitespace text follows the last tool-call marker; " + "InlineToolCallParserMode.STRICT_TRAILING_EMPTY rejects this. " + f"Trailing: {trailing!r}" + ) + return matches + + def _build_call(self, *, match: re.Match[str], next_id: int) -> ToolCall | None: + """ + Decode a single marker match into a ``ToolCall``. + + Args: + match (re.Match[str]): A single marker match whose group 1 is the + JSON payload. + next_id (int): Positional id used to form ``call_id``. + + Returns: + ToolCall | None: ``None`` when the payload is malformed or + missing the ``name`` field. The caller is expected to log + and skip. + """ + from pyrit.tools.models import ToolCall + + payload = match.group(1).strip() + try: + parsed = json.loads(payload) + except json.JSONDecodeError: + logger.warning("Skipping malformed tool-call payload: %r", payload[:120]) + return None + if not isinstance(parsed, dict) or "name" not in parsed: + logger.warning("Skipping tool-call payload without 'name' field: %r", payload[:120]) + return None + arguments = self._coerce_arguments(parsed.get("arguments", {})) + return ToolCall( + call_id=f"{self._call_id_prefix}_{next_id}", + name=parsed["name"], + arguments=arguments, + raw_envelope=parsed, + ) + + @staticmethod + def _coerce_arguments(raw: object) -> dict: + """ + Coerce the ``arguments`` field into a dict regardless of source shape. + + Args: + raw (object): The value of the payload's ``arguments`` field. + Either a dict, a JSON-encoded string, or something else. + + Returns: + dict: The decoded arguments dict, or an empty dict on any + shape PyRIT cannot interpret. Empty-dict fallback preserves + the loop's behavior of "always continue with a real call". + """ + if isinstance(raw, dict): + return dict(raw) + if isinstance(raw, str): + try: + decoded = json.loads(raw) + except json.JSONDecodeError: + return {} + return dict(decoded) if isinstance(decoded, dict) else {} + return {} diff --git a/pyrit/tools/local_backend.py b/pyrit/tools/local_backend.py new file mode 100644 index 0000000000..7a149bc825 --- /dev/null +++ b/pyrit/tools/local_backend.py @@ -0,0 +1,121 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING, Any + +from pyrit.tools.backend import ToolBackend + +if TYPE_CHECKING: + from collections.abc import Awaitable, Callable + + from pyrit.tools.models import ToolCall + +logger = logging.getLogger(__name__) + + +class LocalToolBackend(ToolBackend): + """ + In-process ``ToolBackend`` backed by a name -> ``async def`` + mapping. Useful for unit tests and for embedding small tools inside the + PyRIT process without standing up an MCP server. + + "Local" here means tools run in PyRIT's own Python process — no + subprocess, no IPC, no wire protocol. Contrast with + ``MCPToolBackend``, which proxies dispatch through one or more MCP + servers reached via JSON-RPC. + + The backend dispatches sequentially in declaration order. Tool-side + failures (raised exceptions, missing names, allow-list rejections) + are converted into structured error envelopes so the tool loop can + forward them back to the model as ``function_call_output`` content + rather than aborting the conversation. + """ + + def __init__( + self, + *, + callables: dict[str, Callable[[dict[str, Any]], Awaitable[Any]]], + schemas: list[dict[str, Any]] | None = None, + allowed_tools: set[str] | None = None, + fail_on_missing_function: bool = True, + ) -> None: + """ + Initialize the backend. + + Args: + callables (dict[str, Callable[[dict[str, Any]], Awaitable[Any]]]): + Map from tool name to an ``async def`` that accepts a parsed + arguments dict and returns the tool result. Results are + serialized by the tool loop via ``json.dumps``. + schemas (list[dict[str, Any]] | None): JSON-schema descriptors + injected into the target's request body. Defaults to an empty + list when omitted. + allowed_tools (set[str] | None): Optional allow-list of tool + names; calls whose name is not in this set surface as + ``tool_not_allowed`` envelopes without invoking the callable. + Defaults to None (no allow-list; every registered tool is + callable). + fail_on_missing_function (bool): When True (default), an unknown + tool name raises ``KeyError``. When False, the backend + returns a ``tool_not_registered`` envelope so the model can + recover. + """ + self._callables = dict(callables) + self._schemas: list[dict[str, Any]] = list(schemas) if schemas is not None else [] + self._allowed_tools = set(allowed_tools) if allowed_tools is not None else None + self._fail_on_missing_function = fail_on_missing_function + + @property + def schemas(self) -> list[dict[str, Any]]: + """The JSON-schema descriptors for the tools in this backend.""" + return list(self._schemas) + + async def dispatch_async(self, call: ToolCall) -> dict[str, Any]: + """ + Dispatch a single tool call. Tool failures are converted into + structured envelopes; only configuration errors (missing tool with + ``fail_on_missing_function=True``) propagate as exceptions. + + Args: + call (ToolCall): The call to dispatch. + + Returns: + dict[str, Any]: The tool's result, or a structured error envelope. + + Raises: + KeyError: When the tool name is not registered and + ``fail_on_missing_function=True``. + """ + if self._allowed_tools is not None and call.name not in self._allowed_tools: + logger.info("Rejecting disallowed tool call: %s", call.name) + return { + "error": "tool_not_allowed", + "tool": call.name, + "allowed_tools": sorted(self._allowed_tools), + } + + fn = self._callables.get(call.name) + if fn is None: + if self._fail_on_missing_function: + raise KeyError(f"Tool '{call.name}' is not registered.") + available = sorted(self._callables.keys()) + logger.warning("Tool '%s' not registered. Available: %s", call.name, available) + return { + "error": "tool_not_registered", + "tool": call.name, + "available_tools": available, + } + + try: + result = await fn(call.arguments) + except Exception as ex: + logger.warning("Tool '%s' raised %s: %s", call.name, type(ex).__name__, ex) + return { + "error": "tool_execution_failed", + "tool": call.name, + "detail": str(ex), + } + return result if isinstance(result, dict) else {"result": result} diff --git a/pyrit/tools/mcp_backend.py b/pyrit/tools/mcp_backend.py new file mode 100644 index 0000000000..41a0d32462 --- /dev/null +++ b/pyrit/tools/mcp_backend.py @@ -0,0 +1,199 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +Multi-server tool backend that proxies dispatch through one or more +MCP servers. + +This is the ``ToolBackend`` implementation that real +red-team configurations use. It composes one +``MCPClient`` per ``MCPServerSpec``, +aggregates their advertised schemas, routes incoming +``ToolCall`` instances to the correct underlying +client, and enforces an optional ``allowed_tools`` allow-list. + +Contrast with ``LocalToolBackend``, which dispatches +to Python ``async def`` callables inside PyRIT's own process. +""" + +from __future__ import annotations + +import asyncio +import logging +from contextlib import AsyncExitStack +from typing import TYPE_CHECKING, Any + +from pyrit.tools.backend import ToolBackend +from pyrit.tools.mcp_client import MCPClient + +if TYPE_CHECKING: + from collections.abc import Iterable + + from pyrit.tools.mcp_client import MCPServerSpec + from pyrit.tools.models import ToolCall + +logger = logging.getLogger(__name__) + + +class MCPToolBackend(ToolBackend): + """ + ``ToolBackend`` backed by one or more MCP servers. + + On ``__aenter__``, the backend spawns / connects each server in + its ``_servers`` list (sequentially) through a single + ``contextlib.AsyncExitStack``, runs the MCP handshake, caches + schemas, and builds an advertised-name → ``(client, server_name)`` + routing table. Collisions raise ``ValueError`` unless the + colliding specs set ``name_prefix``. + + A single shared ``AsyncExitStack`` (rather than one per client) + is required so anyio's nested cancel scopes — opened by the ``mcp`` + SDK's ``stdio_client`` and ``ClientSession`` context managers — are + closed in strict LIFO order from the entering task. Closing + out-of-order would trip + ``"Attempted to exit a cancel scope that isn't the current task's + current cancel scope"``. + + Dispatch is serialized through an ``asyncio.Lock`` per backend + instance — multiple concurrent coroutines sharing the same backend + (e.g. parallel attack runs) will not interleave JSON-RPC frames on + the same stdio pipe. + """ + + def __init__( + self, + *, + servers: Iterable[MCPServerSpec], + allowed_tools: list[str] | None = None, + ) -> None: + """ + Initialize the backend. + + Args: + servers: One or more ``MCPServerSpec`` instances describing + where each server runs. + allowed_tools: Optional allow-list of tool names. Names not in + the list are filtered from ``schemas`` AND + short-circuit dispatch with a ``tool_not_allowed`` envelope. + Names are matched after ``name_prefix`` + has been applied. Defaults to None (every advertised tool is + callable). + + Raises: + ValueError: When *servers* is empty. + """ + self._servers: list[MCPServerSpec] = list(servers) + if not self._servers: + raise ValueError("MCPToolBackend requires at least one server spec.") + self._allowed_tools: set[str] | None = set(allowed_tools) if allowed_tools is not None else None + self._clients: list[MCPClient] = [] + self._routing: dict[str, tuple[MCPClient, str]] = {} + self._dispatch_lock = asyncio.Lock() + self._stack: AsyncExitStack | None = None + self._entered = False + + @property + def schemas(self) -> list[dict[str, Any]]: + """The union of every connected server's schemas, filtered by ``allowed_tools``.""" + out: list[dict[str, Any]] = [] + for client in self._clients: + for schema in client.schemas: + if self._allowed_tools is not None and schema["name"] not in self._allowed_tools: + continue + out.append(schema) + return out + + async def __aenter__(self) -> MCPToolBackend: + """ + Connect each underlying client through a shared ``AsyncExitStack`` and build the routing table. + + Returns: + MCPToolBackend: *self*, ready to dispatch. + + Raises: + ValueError: When two connected clients advertise the same tool + name without a disambiguating ``name_prefix``. + """ + stack = AsyncExitStack() + clients: list[MCPClient] = [] + routing: dict[str, tuple[MCPClient, str]] = {} + try: + for spec in self._servers: + client = MCPClient(spec=spec) + await stack.enter_async_context(client) + clients.append(client) + for advertised_name in client.tool_names: + if advertised_name in routing: + raise ValueError( + f"duplicate tool name '{advertised_name}'. " + "Set LocalMCPServerSpec.name_prefix on at least one " + "colliding server to disambiguate.", + ) + routing[advertised_name] = (client, advertised_name) + except Exception: + await stack.aclose() + raise + + self._stack = stack + self._clients = clients + self._routing = routing + self._entered = True + return self + + async def __aexit__(self, *exc: Any) -> None: + """Tear down every underlying client in strict LIFO order.""" + stack = self._stack + self._stack = None + self._clients = [] + self._routing = {} + self._entered = False + if stack is not None: + await stack.aclose() + + async def dispatch_async(self, call: ToolCall) -> dict[str, Any]: + """ + Route *call* to the correct client and dispatch. + + See ``MCPClient.dispatch_async`` for the envelope shape. + Allow-list rejections and unknown-tool calls return error + envelopes; only "backend not entered" raises. + + Args: + call (ToolCall): The call to dispatch. + + Returns: + dict[str, Any]: A structured envelope (success, ``tool_not_allowed``, + ``tool_not_registered``, or the underlying + ``MCPClient.dispatch_async`` envelope). + + Raises: + RuntimeError: When the backend has not been entered via ``async with``. + """ + if not self._entered: + raise RuntimeError( + "MCPToolBackend is not active. Use `async with backend:` to manage its lifecycle before dispatching.", + ) + + if self._allowed_tools is not None and call.name not in self._allowed_tools: + logger.info("Rejecting disallowed tool call: %s", call.name) + return { + "is_error": True, + "error": "tool_not_allowed", + "tool": call.name, + "allowed_tools": sorted(self._allowed_tools), + } + + route = self._routing.get(call.name) + if route is None: + available = sorted(self._routing.keys()) + logger.warning("Tool '%s' not registered. Available: %s", call.name, available) + return { + "is_error": True, + "error": "tool_not_registered", + "tool": call.name, + "available_tools": available, + } + + client, _server_side_name = route + async with self._dispatch_lock: + return await client.dispatch_async(call) diff --git a/pyrit/tools/mcp_client.py b/pyrit/tools/mcp_client.py new file mode 100644 index 0000000000..0f9eeadcc3 --- /dev/null +++ b/pyrit/tools/mcp_client.py @@ -0,0 +1,369 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +Stdio-transport client for the Model Context Protocol (MCP). + +This module is the wire-protocol half of PyRIT's MCP integration. It +sits below ``MCPToolBackend`` (which composes one +``MCPClient`` per configured server and handles cross-server +routing) and above the upstream ``mcp`` Python SDK (which owns the +JSON-RPC framing, capability negotiation, and asyncio task plumbing). + +The three ``MCPServerSpec`` variants describe *where* the server +runs. Only ``LocalMCPServerSpec`` is implemented in this commit: + +* ``LocalMCPServerSpec`` — spawn the server as a child process and + speak JSON-RPC over its stdin/stdout. +* ``RemoteMCPServerSpec`` — HTTP/SSE transport against a hosted + server. Stub: ``connect_async`` raises ``NotImplementedError``. +* ``DockerMCPServerSpec`` — stdio over ``docker run -i`` against a + hardened sandbox container. Stub: ``connect_async`` raises + ``NotImplementedError``. Implementation lands in the follow-up + sandbox PR. + +The stub variants are intentionally part of the type union today so +downstream code can be written against the eventual API without +forcing a Union expansion later. +""" + +from __future__ import annotations + +import asyncio +import logging +from contextlib import AsyncExitStack +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any + +from mcp import ClientSession +from mcp.client.stdio import StdioServerParameters, stdio_client + +if TYPE_CHECKING: + from pyrit.tools.models import ToolCall + +logger = logging.getLogger(__name__) + + +@dataclass(frozen=True) +class LocalMCPServerSpec: + """ + Spec for an MCP server spawned as a child process and reached via + stdio JSON-RPC. + + Attributes: + command (str): The interpreter or binary to exec (e.g. ``"python"``). + args (tuple[str, ...]): Arguments passed to *command*, in order. + env (dict[str, str] | None): Environment overlay for the child + process. ``None`` (default) inherits PyRIT's environment. + name_prefix (str | None): When set, every tool advertised by the + server is registered as ``f"{name_prefix}{tool_name}"`` in + the parent ``MCPToolBackend``. Used to + disambiguate two servers that expose the same tool name. + timeout_seconds (float): Per-call timeout, enforced by + ``MCPClient.dispatch_async``. Defaults to 30 seconds. + """ + + command: str + args: tuple[str, ...] = () + env: dict[str, str] | None = None + name_prefix: str | None = None + timeout_seconds: float = 30.0 + + +@dataclass(frozen=True) +class RemoteMCPServerSpec: + """ + Spec for an MCP server reached over HTTP / SSE. **Not implemented** + in this PR — ``MCPClient.connect_async`` raises + ``NotImplementedError``. Tracked by ``# TODO(mcp-http-transport)``. + + Attributes: + url (str): The base URL of the MCP server. + name_prefix (str | None): Same semantics as + ``LocalMCPServerSpec.name_prefix``. + timeout_seconds (float): Per-call timeout. + """ + + url: str + name_prefix: str | None = None + timeout_seconds: float = 30.0 + + +# TODO(sandbox-provider) — DockerMCPServerSpec stub here; implementation lands in follow-up PR. +@dataclass(frozen=True) +class DockerMCPServerSpec: + """ + Spec for an MCP server hosted inside a hardened Docker container. + + **NOT IMPLEMENTED IN THIS PR.** Reached via stdio over ``docker run -i``. + + Expected behavior in the follow-up sandbox PR: + + * One container per spec instance, managed by a process-wide + ``SandboxPool``. + * Image is built lazily, keyed by ``sha256(Dockerfile + build_context)``, + and cached across attacks; no rebuild unless missing or explicitly + overridden. + * Container is recreated from the cached image at attack and scenario + boundaries (filesystem returns to baseline every time). + * Network access governed by ``NetworkProfile`` (default ``"none"`` = + ``--network=none``). + * Container runs as a non-root UID with ``--cap-drop=ALL``, a read-only + root filesystem, and an in-container MCP server exposing + ``run_shell(cmd, timeout_seconds)``. + + Attributes: + image (str): Docker image tag (e.g. ``"pyrit-sandbox:base"``). + network_profile (str): ``NetworkProfile`` name; ``"none"`` (default) + launches the container with ``--network=none``. + name_prefix (str | None): Same semantics as + ``LocalMCPServerSpec.name_prefix``. + timeout_seconds (float): Per-call timeout. + + Future fields (deferred to the follow-up sandbox PR): ``memory_limit``, + ``cpu_limit``, ``pids_limit``, ``env``, ``mounts``, ``command_override``. + """ + + image: str + network_profile: str = "none" + name_prefix: str | None = None + timeout_seconds: float = 30.0 + + +MCPServerSpec = LocalMCPServerSpec | RemoteMCPServerSpec | DockerMCPServerSpec + + +def _to_input_schema_dict(input_schema: Any) -> dict[str, Any]: + """ + Coerce the SDK's tool ``inputSchema`` (pydantic model or dict) into a plain dict. + + Returns: + dict[str, Any]: A plain-dict copy of *input_schema*, or an empty + object schema when *input_schema* is None or of an unrecognized type. + """ + if input_schema is None: + return {"type": "object", "properties": {}} + if hasattr(input_schema, "model_dump"): + return input_schema.model_dump() + if isinstance(input_schema, dict): + return dict(input_schema) + return {"type": "object", "properties": {}} + + +def _flatten_content(content: list[Any]) -> str: + """ + Concatenate the text portions of an MCP ``CallToolResult.content`` list. + + Returns: + str: Concatenated ``.text`` values from each content item, in order. + """ + pieces: list[str] = [] + for item in content: + text = getattr(item, "text", None) + if text is not None: + pieces.append(text) + elif isinstance(item, dict) and "text" in item: + pieces.append(item["text"]) + return "".join(pieces) + + +class MCPClient: + """ + A single MCP-server session. + + The client owns the lifetime of one server's transport stack and + exposes a uniform ``dispatch_async`` regardless of which + ``MCPServerSpec`` variant it was constructed from. Composition + across multiple servers (routing, schema aggregation, allow-lists) + is the responsibility of ``MCPToolBackend``. + + Lifecycle: + + * ``connect_async`` spawns the subprocess (for + ``LocalMCPServerSpec``), runs the MCP handshake, and caches + ``tools/list`` results. + * ``dispatch_async`` issues one ``tools/call`` and returns a + structured envelope (success or error). + * ``close_async`` tears down the transport stack. + + The class is usable as an async context manager. + """ + + def __init__(self, *, spec: MCPServerSpec) -> None: + """ + Initialize the client around *spec*. Does not connect; call + ``connect_async`` (or use the async context-manager form) to start + the transport stack. + """ + self._spec = spec + self._stack = AsyncExitStack() + self._session: ClientSession | None = None + self._tools: list[Any] = [] + + @property + def spec(self) -> MCPServerSpec: + """The ``MCPServerSpec`` this client was constructed with.""" + return self._spec + + @property + def schemas(self) -> list[dict[str, Any]]: + """ + JSON schemas for every tool the server advertises. + + Each schema is shaped ``{"name", "description", "parameters"}``. + The optional ``LocalMCPServerSpec.name_prefix`` is applied + here so a backend that owns this client sees the prefixed name. + """ + prefix = getattr(self._spec, "name_prefix", None) or "" + return [ + { + "name": f"{prefix}{tool.name}", + "description": tool.description or "", + "parameters": _to_input_schema_dict(tool.inputSchema), + } + for tool in self._tools + ] + + @property + def tool_names(self) -> list[str]: + """Tool names with the spec's ``name_prefix`` applied.""" + return [s["name"] for s in self.schemas] + + def _strip_prefix(self, name: str) -> str: + prefix = getattr(self._spec, "name_prefix", None) or "" + if prefix and name.startswith(prefix): + return name[len(prefix) :] + return name + + async def connect_async(self) -> None: + """Establish the transport, run the handshake, and cache schemas.""" + if isinstance(self._spec, RemoteMCPServerSpec): + raise NotImplementedError( + "HTTP/SSE transport ships in a follow-up PR. " + "RemoteMCPServerSpec is declared today so user code can target the eventual API." + ) + if isinstance(self._spec, DockerMCPServerSpec): + raise NotImplementedError( + "Docker sandbox transport ships in a follow-up PR. " + "DockerMCPServerSpec runs the MCP server inside a hardened " + "Debian container reached via stdio over `docker run -i`, " + "managed by a process-wide SandboxPool with image caching and " + "per-attack container recreation." + ) + + assert isinstance(self._spec, LocalMCPServerSpec) + params = StdioServerParameters( + command=self._spec.command, + args=list(self._spec.args), + env=self._spec.env, + ) + read, write = await self._stack.enter_async_context(stdio_client(params)) + session = await self._stack.enter_async_context(ClientSession(read, write)) + await session.initialize() + result = await session.list_tools() + self._session = session + self._tools = list(result.tools) + + async def close_async(self) -> None: + """Tear down the transport stack. Idempotent; safe to call before connect.""" + try: + await self._stack.aclose() + except Exception as ex: # noqa: BLE001 — close should never raise into the caller. + logger.warning("Error tearing down MCP client stack: %s", ex) + finally: + self._stack = AsyncExitStack() + self._session = None + self._tools = [] + + async def __aenter__(self) -> MCPClient: + """ + Connect the transport stack and return *self*. + + Returns: + MCPClient: *self*, ready to dispatch tool calls. + """ + await self.connect_async() + return self + + async def __aexit__(self, *exc: Any) -> None: + """Tear down the transport stack.""" + await self.close_async() + + async def dispatch_async(self, call: ToolCall) -> dict[str, Any]: + """ + Issue one ``tools/call`` and return a structured envelope. + + Envelope shape: + + * Success: ``{"is_error": False, "content": str, "tool": name}``. + * Timeout: ``{"is_error": True, "error": "tool_timeout", "tool": name, ...}``. + * Server-reported error: ``{"is_error": True, "error": "tool_execution_failed", "tool": name, ...}``. + + Tool-side failures are converted to envelopes; only programmer + errors (calling before ``connect_async``) raise. + + Args: + call (ToolCall): The call to dispatch. The advertised + ``name_prefix`` (if any) is stripped before contacting the server. + + Returns: + dict[str, Any]: One of the envelope shapes documented above. + + Raises: + RuntimeError: When the client has not been connected. + """ + if self._session is None: + raise RuntimeError("MCPClient is not connected; call connect_async first.") + + server_side_name = self._strip_prefix(call.name) + timeout = getattr(self._spec, "timeout_seconds", 30.0) + try: + result = await asyncio.wait_for( + self._session.call_tool(server_side_name, arguments=dict(call.arguments)), + timeout=timeout, + ) + except asyncio.TimeoutError: + logger.warning( + "MCP tool '%s' timed out after %.2fs", + call.name, + timeout, + ) + return { + "is_error": True, + "error": "tool_timeout", + "tool": call.name, + "timeout_seconds": timeout, + } + except Exception as ex: # noqa: BLE001 — wrap and surface as envelope. + logger.warning( + "MCP tool '%s' raised %s: %s", + call.name, + type(ex).__name__, + ex, + ) + return { + "is_error": True, + "error": "tool_execution_failed", + "tool": call.name, + "detail": str(ex), + } + + content_text = _flatten_content(list(result.content)) + is_error = bool(getattr(result, "isError", False)) + envelope: dict[str, Any] = { + "is_error": is_error, + "content": content_text, + "tool": call.name, + } + if is_error: + envelope["error"] = "tool_execution_failed" + return envelope + + +__all__ = [ + "DockerMCPServerSpec", + "LocalMCPServerSpec", + "MCPClient", + "MCPServerSpec", + "RemoteMCPServerSpec", +] diff --git a/pyrit/tools/models.py b/pyrit/tools/models.py new file mode 100644 index 0000000000..187bf635d9 --- /dev/null +++ b/pyrit/tools/models.py @@ -0,0 +1,241 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +from __future__ import annotations + +import enum +import functools +import json +import logging +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Any + +from pyrit.exceptions import ToolCallLoopLimitExceeded, ToolCallNotSupported +from pyrit.models import Message, MessagePiece + +if TYPE_CHECKING: + from collections.abc import Awaitable, Callable + + from pyrit.tools.backend import ToolBackend + from pyrit.tools.parsers import ToolCallParser + +logger = logging.getLogger(__name__) + + +@dataclass(frozen=True) +class ToolCall: + """ + A parsed tool call extracted from a target response. + + Concrete ``ToolCallParser`` implementations build + ``ToolCall`` instances by walking the response message pieces. + The ``raw_envelope`` carries the original target-specific dict + (e.g. the function_call JSON section) so dispatchers and observers + can recover provider-specific fields without re-parsing. + + Attributes: + call_id (str): The provider-issued call identifier; must round-trip + into the matching ``function_call_output`` piece. + name (str): The tool name to dispatch. + arguments (dict[str, Any]): The parsed JSON arguments. + raw_envelope (dict[str, Any]): The original provider envelope. + """ + + call_id: str + name: str + arguments: dict[str, Any] + raw_envelope: dict[str, Any] = field(default_factory=dict) + + +class ToolEventBehavior(enum.Enum): + """ + What the tool loop should do when a target response contains a + pending tool call. + + Values: + EXECUTE: Dispatch the call via ``configuration.tool_backend`` + and re-enter the target with the tool output appended. + This is the standard agentic loop behavior. + RAISE: Raise ``ToolCallNotSupported`` with + the partial conversation attached. Useful for red-team + attacks that want to observe attempted tool use without + allowing execution. + RETURN_RAW: Return the assistant response containing the tool + call as-is, without dispatching. Useful when a caller wants + to inspect tool calls in-band (e.g. a scorer that scores + attempted tool use). + """ + + EXECUTE = "execute" + RAISE = "raise" + RETURN_RAW = "return_raw" + + +@dataclass(frozen=True) +class ToolEventPolicy: + """ + Per-target configuration that controls how the tool loop responds + to a pending tool call from the model. + + Attributes: + behavior (ToolEventBehavior): What to do on each detected tool call. + max_tool_iterations (int): Maximum number of model<->tool round-trips + before the loop raises ``ToolCallLoopLimitExceeded``. Each + iteration is one ``_send_prompt_to_target_async`` call. + """ + + behavior: ToolEventBehavior + max_tool_iterations: int = 5 + + +def _build_function_call_output_message( + *, + reference_piece: MessagePiece, + outputs: list[tuple[ToolCall, Any]], +) -> Message: + """ + Build the canonical ``tool`` message produced after dispatching one or more + tool calls in a single iteration. + + The returned ``Message`` contains one + ``MessagePiece`` per ``(call, result)`` pair, in declaration order. + Every piece has ``role="tool"`` and ``original_value_data_type="function_call_output"``, + with the JSON envelope ``{"type": "function_call_output", "call_id": ..., "output": ...}``. + + Lineage metadata (conversation_id, identifiers) is copied from + *reference_piece* — typically the first piece of the assistant message + that issued the tool calls — so the resulting message stays inside the + correct conversation. + + Args: + reference_piece (MessagePiece): Piece whose lineage metadata is + copied onto every output piece. Pass the first piece of the + assistant message that produced the calls. + outputs (list[tuple[ToolCall, Any]]): ``(call, result)`` pairs in + declaration order. *result* is serialized via ``json.dumps`` + unless it is already a string. + + Returns: + Message: One message carrying every function_call_output piece. + """ + pieces: list[MessagePiece] = [] + for call, result in outputs: + output_str = result if isinstance(result, str) else json.dumps(result, separators=(",", ":")) + envelope = json.dumps( + {"type": "function_call_output", "call_id": call.call_id, "output": output_str}, + separators=(",", ":"), + ) + pieces.append( + MessagePiece( + role="tool", + original_value=envelope, + original_value_data_type="function_call_output", + conversation_id=reference_piece.conversation_id, + prompt_target_identifier=reference_piece.prompt_target_identifier, + attack_identifier=reference_piece.attack_identifier, + ) + ) + return Message(message_pieces=pieces, skip_validation=True) + + +def tool_loop( + method: Callable[..., Awaitable[list[Message]]], +) -> Callable[..., Awaitable[list[Message]]]: + """ + Wrap a ``PromptTarget``-style + ``send_prompt_async`` to run an agentic tool-use loop. + + When the target's ``configuration.tool_event_policy`` is ``None`` the + wrapper is a no-op — the wrapped method runs unchanged. When a policy + is configured, the wrapper replaces the method body with the loop: + + 1. Validate and normalize the incoming message exactly once. + 2. Repeatedly call ``self._send_prompt_to_target_async`` with the + growing conversation. + 3. After each call, parse the last response via ``self._tool_parser``. + Exit on empty parse (model issued a stop response). + 4. On a non-empty parse, branch on ``policy.behavior``: + ``RAISE`` raises ``ToolCallNotSupported``; ``RETURN_RAW`` + returns the chain as-is; ``EXECUTE`` dispatches the calls via + ``configuration.tool_backend`` and appends the tool message. + 5. Raise ``ToolCallLoopLimitExceeded`` if the loop runs past + ``policy.max_tool_iterations`` without the model stopping. + + The decorator deliberately knows nothing about MCP, OpenAI, or any + specific transport. The two collaborators it requires — + ``self._tool_parser`` and ``self.configuration.tool_backend`` — are + plain protocols (``ToolCallParser``, ``ToolBackend``). + + Args: + method (Callable): The async method to wrap. Must have the + ``async def f(self, *, message: Message) -> list[Message]`` + signature of ``PromptTarget.send_prompt_async``. + + Returns: + Callable: The wrapped method. + """ + + @functools.wraps(method) + async def wrapper(self: Any, *, message: Message) -> list[Message]: + policy: ToolEventPolicy | None = getattr(self.configuration, "tool_event_policy", None) + if policy is None: + return await method(self, message=message) + + message.validate() + normalized_conversation = await self._get_normalized_conversation_async(message=message) + if not normalized_conversation: + raise ValueError("Normalization pipeline returned an empty conversation. Cannot send an empty request.") + self._validate_request(normalized_conversation=normalized_conversation) + + parser: ToolCallParser | None = getattr(self, "_tool_parser", None) + backend: ToolBackend | None = getattr(self.configuration, "tool_backend", None) + max_iter = policy.max_tool_iterations + + all_responses: list[Message] = [] + + for _ in range(max_iter): + responses_this_turn = await self._send_prompt_to_target_async( + normalized_conversation=normalized_conversation, + ) + all_responses.extend(responses_this_turn) + + if parser is None: + return all_responses + + last_response = responses_this_turn[-1] + pending_calls = parser.parse(last_response) + + if not pending_calls: + return all_responses + + if policy.behavior is ToolEventBehavior.RAISE: + raise ToolCallNotSupported( + message=( + f"Target produced {len(pending_calls)} tool call(s) but ToolEventPolicy.behavior is RAISE." + ), + partial_conversation=all_responses, + ) + + if policy.behavior is ToolEventBehavior.RETURN_RAW: + return all_responses + + if backend is None: + raise ToolCallNotSupported( + message=(f"Target produced {len(pending_calls)} tool call(s) but no tool_backend is configured."), + partial_conversation=all_responses, + ) + + results = await backend.dispatch_all_sequential_async(pending_calls) + tool_msg = _build_function_call_output_message( + reference_piece=last_response.message_pieces[0], + outputs=results, + ) + all_responses.append(tool_msg) + normalized_conversation = list(normalized_conversation) + [last_response, tool_msg] + + raise ToolCallLoopLimitExceeded( + message=f"Tool loop exceeded max_tool_iterations={max_iter} without a stop response.", + partial_conversation=all_responses, + ) + + return wrapper diff --git a/pyrit/tools/parsers.py b/pyrit/tools/parsers.py new file mode 100644 index 0000000000..8f0038d5a0 --- /dev/null +++ b/pyrit/tools/parsers.py @@ -0,0 +1,107 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +from __future__ import annotations + +import json +from typing import TYPE_CHECKING, Protocol, runtime_checkable + +if TYPE_CHECKING: + from pyrit.models import Message, MessagePiece + from pyrit.tools.models import ToolCall + + +@runtime_checkable +class ToolCallParser(Protocol): + """ + Protocol for extracting tool calls from a target response message. + + Concrete parsers live next to the target whose response shape they + understand (the canonical-envelope parser shipped here is shared by + ``OpenAIResponseTarget``; per-model-family parsers for non-OpenAI + targets ship in a follow-up, see plan §12.9). Parsers MUST return an + empty list when the model has issued a stop response — the tool loop + uses the empty list as the signal to exit. + """ + + def parse(self, message: Message) -> list[ToolCall]: + """ + Extract tool calls from a target response message. + + Args: + message (Message): The most recent assistant response. + + Returns: + list[ToolCall]: Tool calls, in declaration order. An empty list + signals that the model produced a stop response. + """ + ... + + +def _extract_function_call_pieces(message: Message) -> list[MessagePiece]: + """ + Return every ``MessagePiece`` in *message* whose + ``original_value_data_type`` is ``"function_call"``. + + This is the canonical envelope used by every PyRIT-supported tool-emitting + target. It is exposed here so concrete parsers can reuse the filter rather + than re-implementing it. + + Args: + message (Message): The message to scan. + + Returns: + list[MessagePiece]: Pieces whose ``original_value_data_type`` is + ``"function_call"``, in their declaration order. + """ + return [piece for piece in message.message_pieces if piece.original_value_data_type == "function_call"] + + +class CanonicalEnvelopeParser: + """ + Reference ``ToolCallParser`` for the canonical function_call envelope. + + Walks every ``MessagePiece`` whose ``original_value_data_type`` is + ``"function_call"`` and decodes the canonical JSON shape:: + + { + "type": "function_call", + "call_id": "", + "name": "", + "arguments": "" + } + + into ``ToolCall`` instances. Pieces of other data types -- reasoning, + mcp_call, web_search_call, etc. -- are ignored (they pass through to + Memory but are not client-side dispatchable). Per-model-family parsers + for non-OpenAI targets ship in a follow-up PR (see plan §12.9). + """ + + def parse(self, message: Message) -> list[ToolCall]: + """ + Decode canonical ``function_call`` pieces in *message* into ``ToolCall``. + + Args: + message (Message): The most recent assistant response. + + Returns: + list[ToolCall]: One ``ToolCall`` per ``function_call`` + piece, in declaration order. Empty if the message contains + no ``function_call`` pieces (model stop). + """ + from pyrit.tools.models import ToolCall + + calls: list[ToolCall] = [] + for piece in _extract_function_call_pieces(message): + envelope = json.loads(piece.original_value) + arguments_raw = envelope.get("arguments", "{}") + arguments = json.loads(arguments_raw) if isinstance(arguments_raw, str) else dict(arguments_raw) + calls.append( + ToolCall( + call_id=envelope["call_id"], + name=envelope["name"], + arguments=arguments, + raw_envelope=envelope, + ) + ) + return calls diff --git a/tests/integration/tools/__init__.py b/tests/integration/tools/__init__.py new file mode 100644 index 0000000000..9a0454564d --- /dev/null +++ b/tests/integration/tools/__init__.py @@ -0,0 +1,2 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. diff --git a/tests/integration/tools/test_azure_ml_with_tools_integration.py b/tests/integration/tools/test_azure_ml_with_tools_integration.py new file mode 100644 index 0000000000..9f0b4903b7 --- /dev/null +++ b/tests/integration/tools/test_azure_ml_with_tools_integration.py @@ -0,0 +1,162 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +"""F2 integration test: AzureMLChatTarget end-to-end through @tool_loop. + +Validates that PyRIT's full client-side tool-calling stack works against +an AzureML chat target whose scoring script emits the canonical +function_call envelope per plan §12.9.2. Only the outbound HTTP layer is +mocked; the @tool_loop decorator, CanonicalEnvelopeParser, LocalToolBackend +dispatch, ChatMessageNormalizer tool-piece serialization, and Memory +persistence all run unmocked. + +The test asserts the canonical four-piece transcript shape: +``[user, assistant function_call, tool function_call_output, assistant text]``. +""" + +from __future__ import annotations + +import json +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from pyrit.memory import CentralMemory +from pyrit.models import Message, MessagePiece +from pyrit.prompt_normalizer import PromptNormalizer +from pyrit.prompt_target import AzureMLChatTarget +from pyrit.prompt_target.common.target_capabilities import TargetCapabilities +from pyrit.prompt_target.common.target_configuration import TargetConfiguration +from pyrit.tools import ( + CanonicalEnvelopeParser, + LocalToolBackend, + ToolEventBehavior, + ToolEventPolicy, +) + + +@pytest.fixture +def echo_backend(): + async def _echo(args): + return {"echoed": args.get("text", "")} + + return LocalToolBackend( + callables={"echo": _echo}, + schemas=[ + { + "name": "echo", + "description": "Echo back the given text.", + "parameters": { + "type": "object", + "properties": {"text": {"type": "string"}}, + "required": ["text"], + }, + } + ], + ) + + +@pytest.mark.run_only_if_all_tests +async def test_azure_ml_chat_target_tool_loop_round_trip(patch_central_database, echo_backend): + """User asks for a tool call; the loop dispatches; the model produces final text.""" + target = AzureMLChatTarget( + endpoint="https://mock-endpoint.example.com/score", + api_key="dummy", + tool_parser=CanonicalEnvelopeParser(), + tool_backend=echo_backend, + custom_configuration=TargetConfiguration( + capabilities=TargetCapabilities( + supports_multi_message_pieces=True, + supports_editable_history=True, + supports_multi_turn=True, + supports_system_prompt=True, + supports_tool_use=True, + ), + tool_event_policy=ToolEventPolicy( + behavior=ToolEventBehavior.EXECUTE, + max_tool_iterations=3, + ), + tool_backend=echo_backend, + ), + ) + + first_response = MagicMock() + first_response.json.return_value = { + "output": "", + "tool_calls": [ + { + "type": "function_call", + "call_id": "call_0", + "name": "echo", + "arguments": '{"text":"hi"}', + } + ], + } + second_response = MagicMock() + second_response.json.return_value = {"output": "The echoed text is: hi"} + + user_msg = Message( + message_pieces=[ + MessagePiece( + role="user", + original_value="Use the echo tool to repeat 'hi'.", + original_value_data_type="text", + ) + ] + ) + + with patch( + "pyrit.common.net_utility.make_request_and_raise_if_error_async", + new_callable=AsyncMock, + ) as mock_http: + mock_http.side_effect = [first_response, second_response] + normalizer = PromptNormalizer() + result = await normalizer.send_prompt_async(message=user_msg, target=target) + + assert mock_http.call_count == 2 + final_text = result.get_value() + assert "The echoed text is" in final_text + + conv = CentralMemory.get_memory_instance().get_conversation( + conversation_id=result.message_pieces[0].conversation_id + ) + # Canonical four-piece transcript: user -> assistant fc -> tool fco -> assistant text. + flat_pieces = [p for msg in conv for p in msg.message_pieces] + types = [p.original_value_data_type for p in flat_pieces] + roles = [p.api_role for p in flat_pieces] + assert types == ["text", "function_call", "function_call_output", "text"] + assert roles == ["user", "assistant", "tool", "assistant"] + + # call_id round-trips between the assistant and tool messages. + fc_envelope = json.loads(flat_pieces[1].original_value) + fco_envelope = json.loads(flat_pieces[2].original_value) + assert fc_envelope["call_id"] == fco_envelope["call_id"] == "call_0" + # Tool dispatched the local echo callable; the output reflects the args. + assert json.loads(fco_envelope["output"])["echoed"] == "hi" + + +@pytest.mark.run_only_if_all_tests +async def test_azure_ml_chat_target_no_tools_backward_compat(patch_central_database): + """Without tool_parser / tool_backend, the request body has no tools key.""" + target = AzureMLChatTarget( + endpoint="https://mock-endpoint.example.com/score", + api_key="dummy", + ) + user_msg = Message( + message_pieces=[MessagePiece(role="user", original_value="Hello", original_value_data_type="text")] + ) + + response = MagicMock() + response.json.return_value = {"output": "Hi back"} + + with patch( + "pyrit.common.net_utility.make_request_and_raise_if_error_async", + new_callable=AsyncMock, + ) as mock_http: + mock_http.return_value = response + normalizer = PromptNormalizer() + result = await normalizer.send_prompt_async(message=user_msg, target=target) + + assert result.get_value() == "Hi back" + body = mock_http.call_args.kwargs["request_body"] + assert "tools" not in body diff --git a/tests/integration/tools/test_red_teaming_with_tools.py b/tests/integration/tools/test_red_teaming_with_tools.py new file mode 100644 index 0000000000..9dca01371a --- /dev/null +++ b/tests/integration/tools/test_red_teaming_with_tools.py @@ -0,0 +1,361 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +"""C7 integration tests: RedTeamingAttack with real tool dispatch. + +These tests spawn the real ``tests/unit/tools/echo_mcp_server.py`` subprocess +and exercise the full client-side tool-calling stack: + + attack -> normalizer -> target -> @tool_loop wrapper -> MCPToolBackend -> + MCPClient (stdio) -> echo subprocess -> tool result -> back through the + wrapper -> Memory. + +Only the OpenAI Responses HTTP layer is mocked. The MCP subprocess, the +MCPToolBackend lock, the AsyncExitStack lifecycle, the canonical envelope +round-trip, and the @tool_loop decorator's RedTeam-attack invocation path +all execute under their real implementations. +""" + +from __future__ import annotations + +import json +import pathlib +import sys +import uuid +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from pyrit.executor.attack.core.attack_config import AttackAdversarialConfig, AttackScoringConfig +from pyrit.executor.attack.multi_turn.red_teaming import RedTeamingAttack +from pyrit.identifiers import ComponentIdentifier +from pyrit.memory import CentralMemory +from pyrit.models import Message, MessagePiece, Score +from pyrit.prompt_target import OpenAIResponseTarget +from pyrit.prompt_target.common.prompt_target import PromptTarget +from pyrit.prompt_target.common.target_capabilities import TargetCapabilities +from pyrit.prompt_target.common.target_configuration import TargetConfiguration +from pyrit.score.true_false.true_false_scorer import TrueFalseScorer +from pyrit.tools import ( + LocalMCPServerSpec, + MCPToolBackend, + ToolEventBehavior, + ToolEventPolicy, +) + + +def _mock_id(name: str) -> ComponentIdentifier: + return ComponentIdentifier(class_name=name, class_module="test") + + +ECHO_SERVER_PATH = pathlib.Path(__file__).resolve().parents[2] / "unit" / "tools" / "echo_mcp_server.py" + + +def _local_echo_spec() -> LocalMCPServerSpec: + """Build a LocalMCPServerSpec that launches the in-tree echo server.""" + return LocalMCPServerSpec( + command=sys.executable, + args=(str(ECHO_SERVER_PATH),), + ) + + +def _mock_function_call_response(call_id: str, function_name: str, arguments: dict) -> MagicMock: + """Build a fake Responses-API response containing a function_call section.""" + response = MagicMock() + response.status = "completed" + response.error = None + section = MagicMock() + section.type = "function_call" + section.call_id = call_id + section.name = function_name + section.arguments = json.dumps(arguments) + section.model_dump.return_value = { + "type": "function_call", + "call_id": call_id, + "name": function_name, + "arguments": json.dumps(arguments), + } + response.output = [section] + return response + + +def _mock_text_response(text: str) -> MagicMock: + """Build a fake Responses-API response containing a message section.""" + response = MagicMock() + response.status = "completed" + response.error = None + section = MagicMock() + section.type = "message" + section.content = [MagicMock(text=text)] + response.output = [section] + return response + + +def _make_response_target_with_mcp_backend( + backend: MCPToolBackend, +) -> OpenAIResponseTarget: + """Build an OpenAIResponseTarget wired to the live MCP backend.""" + caps = TargetCapabilities( + supports_multi_turn=True, + supports_editable_history=True, + supports_json_output=True, + supports_multi_message_pieces=True, + supports_system_prompt=True, + supports_tool_use=True, + input_modalities=frozenset( + { + frozenset(["text"]), + frozenset(["text", "image_path"]), + frozenset(["function_call"]), + frozenset(["tool_call"]), + frozenset(["function_call_output"]), + frozenset(["reasoning"]), + } + ), + ) + config = TargetConfiguration( + capabilities=caps, + tool_event_policy=ToolEventPolicy(behavior=ToolEventBehavior.EXECUTE, max_tool_iterations=5), + tool_backend=backend, + ) + return OpenAIResponseTarget( + model_name="gpt-4", + endpoint="https://mock.example.com", + api_key="mock-key", + custom_configuration=config, + ) + + +def _scripted_adversarial(prompts: list[str]) -> MagicMock: + """Build a mock adversarial target that returns scripted prompts.""" + adversarial = MagicMock(spec=PromptTarget) + adversarial.send_prompt_async = AsyncMock( + side_effect=[ + [ + Message( + message_pieces=[ + MessagePiece( + role="assistant", + original_value=p, + original_value_data_type="text", + conversation_id=str(uuid.uuid4()), + ) + ] + ) + ] + for p in prompts + ] + ) + adversarial.get_identifier.return_value = _mock_id("MockAdversarial") + adversarial.set_system_prompt = MagicMock() + return adversarial + + +def _success_scorer() -> MagicMock: + """Mock objective scorer that always returns True (objective met).""" + scorer = MagicMock(spec=TrueFalseScorer) + scorer.score_async = AsyncMock( + return_value=[ + Score( + score_value="true", + score_value_description="objective met", + score_type="true_false", + score_category=["test"], + score_rationale="mock rationale", + score_metadata={}, + message_piece_id=str(uuid.uuid4()), + scorer_class_identifier=_mock_id("MockScorer"), + ) + ] + ) + scorer.get_identifier.return_value = _mock_id("MockScorer") + return scorer + + +@pytest.mark.asyncio +async def test_red_teaming_response_target_with_mcp_echo(patch_central_database): + """End-to-end: RedTeamingAttack drives OpenAIResponseTarget with MCPToolBackend. + + The Response target's HTTP layer is mocked to return a function_call for + the echo tool, followed by a stop response after the tool result arrives. + The MCP subprocess actually executes the echo call. + """ + backend = MCPToolBackend(servers=[_local_echo_spec()]) + async with backend: + objective_target = _make_response_target_with_mcp_backend(backend) + + # Mock the OpenAI Responses HTTP layer on the objective target. + responses = [ + _mock_function_call_response("call_1", "echo", {"text": "hello"}), + _mock_text_response("Echoed: hello"), + ] + seen = [] + + async def mock_create(**kwargs): + seen.append(kwargs) + return responses[len(seen) - 1] + + # Adversarial returns one prompt (RedTeamingAttack stops after objective is met) + adversarial = _scripted_adversarial(["please echo hello"]) + + attack = RedTeamingAttack( + objective_target=objective_target, + attack_adversarial_config=AttackAdversarialConfig(target=adversarial), + attack_scoring_config=AttackScoringConfig(objective_scorer=_success_scorer()), + ) + + with patch.object( + objective_target._async_client.responses, "create", new_callable=AsyncMock + ) as mock_create_call: + mock_create_call.side_effect = mock_create + result = await attack.execute_async(objective="get the model to echo 'hello'") + + # Two HTTP calls to the Response API: initial + post-tool + assert len(seen) == 2 + # Second call must include the function_call_output (tool result) + second_input = seen[1]["input"] + function_outputs = [item for item in second_input if item.get("type") == "function_call_output"] + assert len(function_outputs) == 1 + # The output JSON contains the text "hello" because the real MCP echo + # subprocess returned it + assert "hello" in function_outputs[0]["output"] + assert result is not None + + +@pytest.mark.asyncio +async def test_red_teaming_persists_canonical_transcript_in_memory(patch_central_database): + """End-to-end: after a successful tool dispatch the DB shows the full chain. + + Verifies the canonical envelope contract (§13): the conversation written + to Memory must contain the user message, the assistant function_call, the + tool function_call_output (with matching call_id), and the assistant's + final text -- in that order. + """ + backend = MCPToolBackend(servers=[_local_echo_spec()]) + async with backend: + objective_target = _make_response_target_with_mcp_backend(backend) + + responses = [ + _mock_function_call_response("call_xyz", "echo", {"text": "world"}), + _mock_text_response("Echoed: world"), + ] + seen = [] + + async def mock_create(**kwargs): + seen.append(kwargs) + return responses[len(seen) - 1] + + adversarial = _scripted_adversarial(["echo world"]) + + attack = RedTeamingAttack( + objective_target=objective_target, + attack_adversarial_config=AttackAdversarialConfig(target=adversarial), + attack_scoring_config=AttackScoringConfig(objective_scorer=_success_scorer()), + ) + + with patch.object( + objective_target._async_client.responses, "create", new_callable=AsyncMock + ) as mock_create_call: + mock_create_call.side_effect = mock_create + result = await attack.execute_async(objective="echo world") + + # Read the conversation back from Memory + memory = CentralMemory.get_memory_instance() + assert result is not None + objective_conv_id = result.conversation_id + assert objective_conv_id, "Attack result must carry the objective-target conversation id" + + pieces = list(memory.get_message_pieces(conversation_id=objective_conv_id)) + # Filter out system prompts; we care about the user/assistant/tool chain + data_types_in_order = [p.original_value_data_type for p in pieces] + # The chain MUST contain function_call followed by function_call_output (canonical envelope) + assert "function_call" in data_types_in_order + assert "function_call_output" in data_types_in_order + + fc_index = data_types_in_order.index("function_call") + fco_index = data_types_in_order.index("function_call_output") + assert fc_index < fco_index, "function_call must precede function_call_output in DB" + + fc_envelope = json.loads(pieces[fc_index].original_value) + fco_envelope = json.loads(pieces[fco_index].original_value) + assert fc_envelope["call_id"] == fco_envelope["call_id"] == "call_xyz" + assert fc_envelope["name"] == "echo" + # The tool result envelope's `output` is JSON-encoded; the underlying echo result is "world" + assert "world" in fco_envelope["output"] + + +@pytest.mark.asyncio +async def test_red_teaming_dispatches_all_tool_calls_per_turn(patch_central_database): + """Multi-call-per-turn dispatch (intentional behavior change vs pre-C6 loop). + + When the model emits two function_call sections in one response, BOTH + must dispatch through the MCPToolBackend. The pre-C6 in-class loop in + OpenAIResponseTarget only dispatched the LAST call per turn; the C6 + migration onto @tool_loop changes this to "dispatch every call in + declaration order." Verify by issuing both an `echo` and an `add` call + and asserting both results land in the second API call's input. + """ + backend = MCPToolBackend(servers=[_local_echo_spec()]) + async with backend: + objective_target = _make_response_target_with_mcp_backend(backend) + + # First response contains TWO function_calls; second is the stop text. + multi_call_response = MagicMock() + multi_call_response.status = "completed" + multi_call_response.error = None + + echo_section = MagicMock() + echo_section.type = "function_call" + echo_section.call_id = "call_echo" + echo_section.name = "echo" + echo_section.arguments = json.dumps({"text": "hi"}) + echo_section.model_dump.return_value = { + "type": "function_call", + "call_id": "call_echo", + "name": "echo", + "arguments": json.dumps({"text": "hi"}), + } + add_section = MagicMock() + add_section.type = "function_call" + add_section.call_id = "call_add" + add_section.name = "add" + add_section.arguments = json.dumps({"a": 3, "b": 4}) + add_section.model_dump.return_value = { + "type": "function_call", + "call_id": "call_add", + "name": "add", + "arguments": json.dumps({"a": 3, "b": 4}), + } + multi_call_response.output = [echo_section, add_section] + + responses = [ + multi_call_response, + _mock_text_response("done"), + ] + seen = [] + + async def mock_create(**kwargs): + seen.append(kwargs) + return responses[len(seen) - 1] + + adversarial = _scripted_adversarial(["call echo and add"]) + + attack = RedTeamingAttack( + objective_target=objective_target, + attack_adversarial_config=AttackAdversarialConfig(target=adversarial), + attack_scoring_config=AttackScoringConfig(objective_scorer=_success_scorer()), + ) + + with patch.object(objective_target._async_client.responses, "create", new_callable=AsyncMock) as mc: + mc.side_effect = mock_create + await attack.execute_async(objective="dispatch both tools") + + assert len(seen) == 2 + second_input = seen[1]["input"] + outputs = [item for item in second_input if item.get("type") == "function_call_output"] + assert len(outputs) == 2, "Both tool calls must be dispatched per the new behavior" + call_ids = [o["call_id"] for o in outputs] + assert call_ids == ["call_echo", "call_add"], "Outputs must preserve declaration order" + # Real MCP subprocess: echo("hi") returned "hi", add(3, 4) returned 7 + assert "hi" in outputs[0]["output"] + assert "7" in outputs[1]["output"] diff --git a/tests/unit/message_normalizer/test_chat_message_normalizer.py b/tests/unit/message_normalizer/test_chat_message_normalizer.py index b9a7cec57b..decc589bcb 100644 --- a/tests/unit/message_normalizer/test_chat_message_normalizer.py +++ b/tests/unit/message_normalizer/test_chat_message_normalizer.py @@ -319,3 +319,134 @@ async def test_returns_list_of_dicts(self): assert isinstance(result[0], dict) assert result[0]["role"] == "user" assert result[0]["content"] == "Hello" + + +class TestChatMessageNormalizerToolPieces: + """Tool-call piece coverage: function_call -> assistant.tool_calls, + function_call_output -> role=tool message with tool_call_id.""" + + async def test_function_call_piece_becomes_assistant_tool_call_message(self): + normalizer = ChatMessageNormalizer() + envelope = { + "type": "function_call", + "call_id": "call_0", + "name": "echo", + "arguments": '{"text":"hi"}', + } + fc_piece = MessagePiece( + role="assistant", + original_value=json.dumps(envelope), + original_value_data_type="function_call", + converted_value_data_type="function_call", + ) + messages = [Message(message_pieces=[fc_piece])] + + result = await normalizer.normalize_async(messages) + + assert len(result) == 1 + assert result[0].role == "assistant" + assert result[0].content is None + assert result[0].tool_calls is not None + assert len(result[0].tool_calls) == 1 + assert result[0].tool_calls[0].id == "call_0" + assert result[0].tool_calls[0].type == "function" + assert result[0].tool_calls[0].function.name == "echo" + assert result[0].tool_calls[0].function.arguments == '{"text":"hi"}' + + async def test_function_call_output_piece_becomes_tool_role_message(self): + normalizer = ChatMessageNormalizer() + envelope = { + "type": "function_call_output", + "call_id": "call_0", + "output": '{"echoed":"hi"}', + } + fco_piece = MessagePiece( + role="tool", + original_value=json.dumps(envelope), + original_value_data_type="function_call_output", + converted_value_data_type="function_call_output", + ) + messages = [Message(message_pieces=[fco_piece], skip_validation=True)] + + result = await normalizer.normalize_async(messages) + + assert len(result) == 1 + assert result[0].role == "tool" + assert result[0].tool_call_id == "call_0" + assert result[0].content == '{"echoed":"hi"}' + + async def test_full_tool_conversation_round_trip(self): + """A user -> assistant fc -> tool fco -> assistant text conversation + normalizes into the canonical OpenAI Chat Completions wire shape.""" + normalizer = ChatMessageNormalizer() + + user = _make_message("user", "Use the echo tool to repeat 'hi'.") + + fc_envelope = { + "type": "function_call", + "call_id": "call_0", + "name": "echo", + "arguments": '{"text":"hi"}', + } + assistant_fc = Message( + message_pieces=[ + MessagePiece( + role="assistant", + original_value=json.dumps(fc_envelope), + original_value_data_type="function_call", + converted_value_data_type="function_call", + ) + ] + ) + + fco_envelope = { + "type": "function_call_output", + "call_id": "call_0", + "output": '{"echoed":"hi"}', + } + tool_msg = Message( + message_pieces=[ + MessagePiece( + role="tool", + original_value=json.dumps(fco_envelope), + original_value_data_type="function_call_output", + converted_value_data_type="function_call_output", + ) + ], + skip_validation=True, + ) + + assistant_final = _make_message("assistant", "The echoed text is: hi") + + result = await normalizer.normalize_async([user, assistant_fc, tool_msg, assistant_final]) + + assert [m.role for m in result] == ["user", "assistant", "tool", "assistant"] + assert result[0].content == "Use the echo tool to repeat 'hi'." + assert result[1].content is None + assert result[1].tool_calls[0].function.name == "echo" + assert result[2].tool_call_id == "call_0" + assert result[2].content == '{"echoed":"hi"}' + assert result[3].content == "The echoed text is: hi" + + async def test_function_call_output_serialized_to_dict_excludes_content_when_none(self): + """An assistant tool-call-only message must serialize without a content key.""" + normalizer = ChatMessageNormalizer() + envelope = { + "type": "function_call", + "call_id": "c1", + "name": "f", + "arguments": "{}", + } + msg = Message( + message_pieces=[ + MessagePiece( + role="assistant", + original_value=json.dumps(envelope), + original_value_data_type="function_call", + converted_value_data_type="function_call", + ) + ] + ) + dicts = await normalizer.normalize_to_dicts_async([msg]) + assert "content" not in dicts[0] + assert dicts[0]["tool_calls"][0]["function"]["name"] == "f" diff --git a/tests/unit/models/test_chat_message.py b/tests/unit/models/test_chat_message.py index 8391e9340b..9cee645a47 100644 --- a/tests/unit/models/test_chat_message.py +++ b/tests/unit/models/test_chat_message.py @@ -10,19 +10,30 @@ ChatMessage, ChatMessagesDataset, ToolCall, + ToolCallFunction, ) def test_tool_call_init(): - tc = ToolCall(id="call_1", type="function", function="get_weather") + tc = ToolCall( + id="call_1", + type="function", + function=ToolCallFunction(name="get_weather", arguments='{"city":"NYC"}'), + ) assert tc.id == "call_1" assert tc.type == "function" - assert tc.function == "get_weather" + assert tc.function.name == "get_weather" + assert tc.function.arguments == '{"city":"NYC"}' def test_tool_call_forbids_extra_fields(): with pytest.raises(ValidationError): - ToolCall(id="call_1", type="function", function="get_weather", extra="bad") + ToolCall( + id="call_1", + type="function", + function=ToolCallFunction(name="get_weather", arguments="{}"), + extra="bad", + ) def test_chat_message_init_with_string_content(): @@ -41,7 +52,7 @@ def test_chat_message_init_with_list_content(): def test_chat_message_init_with_all_fields(): - tc = ToolCall(id="call_1", type="function", function="lookup") + tc = ToolCall(id="call_1", type="function", function=ToolCallFunction(name="lookup", arguments="{}")) msg = ChatMessage( role="assistant", content="result", @@ -91,13 +102,26 @@ def test_chat_message_model_validate_json_roundtrip(): def test_chat_message_model_validate_json_roundtrip_with_tool_calls(): - tc = ToolCall(id="c1", type="function", function="fn") + tc = ToolCall(id="c1", type="function", function=ToolCallFunction(name="fn", arguments="{}")) original = ChatMessage(role="assistant", content="ok", tool_calls=[tc], tool_call_id="c1") restored = ChatMessage.model_validate_json(original.model_dump_json()) assert restored.tool_calls[0].id == "c1" + assert restored.tool_calls[0].function.name == "fn" assert restored.tool_call_id == "c1" +def test_chat_message_content_allows_none_for_tool_call_only_assistant_message(): + """OpenAI Chat Completions allows assistant messages with content=null when tool_calls is set.""" + tc = ToolCall(id="c1", type="function", function=ToolCallFunction(name="fn", arguments="{}")) + msg = ChatMessage(role="assistant", content=None, tool_calls=[tc]) + assert msg.content is None + assert msg.tool_calls == [tc] + dumped = msg.to_dict() + # content is None so it should be excluded from the serialized dict. + assert "content" not in dumped + assert dumped["tool_calls"][0]["function"]["name"] == "fn" + + @pytest.mark.parametrize("role", ["system", "user", "assistant", "simulated_assistant", "tool", "developer"]) def test_chat_message_accepts_all_valid_roles(role): msg = ChatMessage(role=role, content="test") diff --git a/tests/unit/prompt_target/target/test_azure_ml_chat_target.py b/tests/unit/prompt_target/target/test_azure_ml_chat_target.py index e9517d3ec8..b269894ba1 100644 --- a/tests/unit/prompt_target/target/test_azure_ml_chat_target.py +++ b/tests/unit/prompt_target/target/test_azure_ml_chat_target.py @@ -72,7 +72,7 @@ async def test_complete_chat_async(aml_online_chat: AzureMLChatTarget): mock_response.json.return_value = {"output": "extracted response"} mock.return_value = mock_response response = await aml_online_chat._complete_chat_async(messages) - assert response == "extracted response" + assert response == {"output": "extracted response"} mock.assert_called_once() @@ -90,7 +90,7 @@ async def test_complete_chat_async_with_default_normalizer( mock_response.json.return_value = {"output": "extracted response"} mock.return_value = mock_response response = await aml_online_chat._complete_chat_async(messages) - assert response == "extracted response" + assert response == {"output": "extracted response"} args, kwargs = mock.call_args body = kwargs["request_body"] @@ -107,9 +107,12 @@ async def test_complete_chat_async_bad_json_response(aml_online_chat: AzureMLCha with patch("pyrit.common.net_utility.make_request_and_raise_if_error_async", new_callable=AsyncMock) as mock: mock_response = MagicMock() + # Set is a non-dict body that previously raised TypeError when the code + # subscripted response.json()["output"]; the new code raises ValueError + # because the body is not a dict. mock_response.json.return_value = {"bad response"} mock.return_value = mock_response - with pytest.raises(TypeError): + with pytest.raises((TypeError, ValueError, EmptyResponseException)): await aml_online_chat._complete_chat_async(messages) @@ -178,8 +181,10 @@ async def test_send_prompt_async_rate_limit_exception_retries(aml_online_chat: A async def test_send_prompt_async_empty_response_retries(aml_online_chat: AzureMLChatTarget): response = MagicMock() response.status_code = 429 + # Return an empty dict; _materialize_response raises EmptyResponseException + # when both output and tool_calls are missing. mock_complete_chat_async = AsyncMock() - mock_complete_chat_async.return_value = None + mock_complete_chat_async.return_value = {} aml_online_chat._complete_chat_async = mock_complete_chat_async message = Message(message_pieces=[MessagePiece(role="user", conversation_id="12345", original_value="Hello")]) @@ -236,3 +241,140 @@ def test_valid_temperature_and_top_p(patch_central_database): ) assert target._temperature == 1.5 assert target._top_p == 0.9 + + +# --------------------------------------------------------------------------- +# Tool calling: tool_parser + tool_backend kwargs +# --------------------------------------------------------------------------- + + +@pytest.fixture +def echo_backend(): + from pyrit.tools import LocalToolBackend + + async def _echo(args): + return {"echoed": args.get("text", "")} + + return LocalToolBackend( + callables={"echo": _echo}, + schemas=[ + { + "name": "echo", + "description": "Echo back the given text.", + "parameters": { + "type": "object", + "properties": {"text": {"type": "string"}}, + "required": ["text"], + }, + } + ], + ) + + +def test_tool_parser_kwarg_flips_supports_tool_use_capability(patch_central_database): + from pyrit.prompt_target.common.target_capabilities import CapabilityName + from pyrit.tools import CanonicalEnvelopeParser + + target = AzureMLChatTarget( + endpoint="http://aml-test-endpoint.com", + api_key="k", + tool_parser=CanonicalEnvelopeParser(), + ) + assert target.configuration.includes(capability=CapabilityName.TOOL_USE) + assert target._tool_parser is not None + + +def test_no_tool_parser_leaves_supports_tool_use_off(aml_online_chat: AzureMLChatTarget): + from pyrit.prompt_target.common.target_capabilities import CapabilityName + + assert not aml_online_chat.configuration.includes(capability=CapabilityName.TOOL_USE) + assert aml_online_chat._tool_parser is None + + +def test_tool_backend_kwarg_installed_into_configuration(patch_central_database, echo_backend): + from pyrit.tools import CanonicalEnvelopeParser + + target = AzureMLChatTarget( + endpoint="http://aml-test-endpoint.com", + api_key="k", + tool_parser=CanonicalEnvelopeParser(), + tool_backend=echo_backend, + ) + assert target.configuration.tool_backend is echo_backend + + +def test_tool_backend_kwarg_without_parser_raises(patch_central_database, echo_backend): + # Without tool_parser, the default configuration has supports_tool_use=False, + # so attaching a backend must raise. + with pytest.raises(ValueError, match="supports_tool_use"): + AzureMLChatTarget( + endpoint="http://aml-test-endpoint.com", + api_key="k", + tool_backend=echo_backend, + ) + + +def test_tool_schemas_wraps_backend_schemas_in_chat_completions_shape(patch_central_database, echo_backend): + from pyrit.tools import CanonicalEnvelopeParser + + target = AzureMLChatTarget( + endpoint="http://aml-test-endpoint.com", + api_key="k", + tool_parser=CanonicalEnvelopeParser(), + tool_backend=echo_backend, + ) + schemas = target._tool_schemas() + assert len(schemas) == 1 + assert schemas[0]["type"] == "function" + assert schemas[0]["function"]["name"] == "echo" + + +def test_tool_schemas_empty_when_no_backend(aml_online_chat: AzureMLChatTarget): + assert aml_online_chat._tool_schemas() == [] + + +async def test_request_body_omits_tools_key_when_no_backend(aml_online_chat: AzureMLChatTarget): + messages = [Message(message_pieces=[MessagePiece(role="user", original_value="hi")])] + body = await aml_online_chat._construct_http_body_async(messages) + assert "tools" not in body + + +async def test_request_body_includes_tools_when_backend_set(patch_central_database, echo_backend): + from pyrit.tools import CanonicalEnvelopeParser + + target = AzureMLChatTarget( + endpoint="http://aml-test-endpoint.com", + api_key="k", + tool_parser=CanonicalEnvelopeParser(), + tool_backend=echo_backend, + ) + messages = [Message(message_pieces=[MessagePiece(role="user", original_value="hi")])] + body = await target._construct_http_body_async(messages) + assert "tools" in body + assert body["tools"][0]["function"]["name"] == "echo" + + +async def test_materialize_response_handles_text_and_tool_calls(patch_central_database, echo_backend): + from pyrit.tools import CanonicalEnvelopeParser + + target = AzureMLChatTarget( + endpoint="http://aml-test-endpoint.com", + api_key="k", + tool_parser=CanonicalEnvelopeParser(), + tool_backend=echo_backend, + ) + request = MessagePiece(role="user", original_value="hi", conversation_id="abc") + response = { + "output": "ok", + "tool_calls": [ + { + "type": "function_call", + "call_id": "call_0", + "name": "echo", + "arguments": '{"text":"hi"}', + } + ], + } + msg = target._materialize_response(response=response, request=request) + types = [p.original_value_data_type for p in msg.message_pieces] + assert types == ["text", "function_call"] diff --git a/tests/unit/prompt_target/target/test_huggingface_chat_target.py b/tests/unit/prompt_target/target/test_huggingface_chat_target.py index 93a4ca912f..4cddd5dbd6 100644 --- a/tests/unit/prompt_target/target/test_huggingface_chat_target.py +++ b/tests/unit/prompt_target/target/test_huggingface_chat_target.py @@ -578,3 +578,163 @@ async def test_effective_generation_config_in_metadata(): assert effective_config["temperature"] == 1.0 # Model defaults should also be present assert effective_config["eos_token_id"] == 2 + + +# --------------------------------------------------------------------------- +# Tool calling (F2): tool_parser + tool_backend kwargs +# --------------------------------------------------------------------------- + + +@pytest.fixture +def echo_backend(): + from pyrit.tools import LocalToolBackend + + async def _echo(args): + return {"echoed": args.get("text", "")} + + return LocalToolBackend( + callables={"echo": _echo}, + schemas=[ + { + "name": "echo", + "description": "Echo back the given text.", + "parameters": { + "type": "object", + "properties": {"text": {"type": "string"}}, + "required": ["text"], + }, + } + ], + ) + + +@pytest.mark.skipif(not is_torch_installed(), reason="torch is not installed") +def test_tool_parser_kwarg_flips_supports_tool_use_capability(patch_central_database): + from pyrit.prompt_target.common.target_capabilities import CapabilityName + from pyrit.tools import InlineToolCallParser + + target = HuggingFaceChatTarget(model_id="test_model", use_cuda=False, tool_parser=InlineToolCallParser()) + assert target.configuration.includes(capability=CapabilityName.TOOL_USE) + assert target._tool_parser is not None + + +@pytest.mark.skipif(not is_torch_installed(), reason="torch is not installed") +def test_no_tool_parser_leaves_supports_tool_use_off(patch_central_database): + from pyrit.prompt_target.common.target_capabilities import CapabilityName + + target = HuggingFaceChatTarget(model_id="test_model", use_cuda=False) + assert not target.configuration.includes(capability=CapabilityName.TOOL_USE) + assert target._tool_parser is None + + +@pytest.mark.skipif(not is_torch_installed(), reason="torch is not installed") +def test_tool_backend_kwarg_installed_into_configuration(patch_central_database, echo_backend): + from pyrit.tools import InlineToolCallParser + + target = HuggingFaceChatTarget( + model_id="test_model", + use_cuda=False, + tool_parser=InlineToolCallParser(), + tool_backend=echo_backend, + ) + assert target.configuration.tool_backend is echo_backend + + +@pytest.mark.skipif(not is_torch_installed(), reason="torch is not installed") +def test_tool_backend_kwarg_without_parser_raises(patch_central_database, echo_backend): + with pytest.raises(ValueError, match="supports_tool_use"): + HuggingFaceChatTarget(model_id="test_model", use_cuda=False, tool_backend=echo_backend) + + +@pytest.mark.skipif(not is_torch_installed(), reason="torch is not installed") +def test_tool_schemas_returns_bare_backend_schemas(patch_central_database, echo_backend): + """HF chat templates accept bare schemas (no OpenAI envelope).""" + from pyrit.tools import InlineToolCallParser + + target = HuggingFaceChatTarget( + model_id="test_model", + use_cuda=False, + tool_parser=InlineToolCallParser(), + tool_backend=echo_backend, + ) + schemas = target._tool_schemas() + assert len(schemas) == 1 + assert schemas[0]["name"] == "echo" + # Unlike AzureMLChatTarget, no {"type": "function", "function": {...}} wrapper. + assert "function" not in schemas[0] + + +@pytest.mark.skipif(not is_torch_installed(), reason="torch is not installed") +def test_build_chat_messages_translates_function_call_piece(patch_central_database): + target = HuggingFaceChatTarget(model_id="test_model", use_cuda=False) + fc_envelope = { + "type": "function_call", + "call_id": "call_0", + "name": "echo", + "arguments": '{"text":"hi"}', + } + msg = Message( + message_pieces=[ + MessagePiece( + role="assistant", + original_value=json.dumps(fc_envelope), + original_value_data_type="function_call", + converted_value_data_type="function_call", + ) + ] + ) + chat_messages = target._build_chat_messages(normalized_conversation=[msg]) + assert len(chat_messages) == 1 + assert chat_messages[0]["role"] == "assistant" + assert chat_messages[0]["tool_calls"][0]["function"]["name"] == "echo" + assert chat_messages[0]["tool_calls"][0]["function"]["arguments"] == '{"text":"hi"}' + + +@pytest.mark.skipif(not is_torch_installed(), reason="torch is not installed") +def test_build_chat_messages_translates_function_call_output_piece(patch_central_database): + target = HuggingFaceChatTarget(model_id="test_model", use_cuda=False) + fco_envelope = { + "type": "function_call_output", + "call_id": "call_0", + "output": '{"echoed":"hi"}', + } + msg = Message( + message_pieces=[ + MessagePiece( + role="tool", + original_value=json.dumps(fco_envelope), + original_value_data_type="function_call_output", + converted_value_data_type="function_call_output", + ) + ], + skip_validation=True, + ) + chat_messages = target._build_chat_messages(normalized_conversation=[msg]) + assert len(chat_messages) == 1 + assert chat_messages[0]["role"] == "tool" + assert chat_messages[0]["tool_call_id"] == "call_0" + assert chat_messages[0]["content"] == '{"echoed":"hi"}' + + +@pytest.mark.skipif(not is_torch_installed(), reason="torch is not installed") +def test_apply_chat_template_forwards_tools_when_present(patch_central_database, echo_backend): + from pyrit.tools import InlineToolCallParser + + target = HuggingFaceChatTarget( + model_id="test_model", + use_cuda=False, + tool_parser=InlineToolCallParser(), + tool_backend=echo_backend, + ) + target._apply_chat_template([{"role": "user", "content": "hi"}]) + call_kwargs = target.tokenizer.apply_chat_template.call_args.kwargs + assert "tools" in call_kwargs + assert call_kwargs["tools"][0]["name"] == "echo" + + +@pytest.mark.skipif(not is_torch_installed(), reason="torch is not installed") +def test_apply_chat_template_omits_tools_when_no_backend(patch_central_database): + target = HuggingFaceChatTarget(model_id="test_model", use_cuda=False) + target._apply_chat_template([{"role": "user", "content": "hi"}]) + call_kwargs = target.tokenizer.apply_chat_template.call_args.kwargs + assert "tools" not in call_kwargs diff --git a/tests/unit/prompt_target/target/test_normalize_async_integration.py b/tests/unit/prompt_target/target/test_normalize_async_integration.py index 2317bd705f..b62fa7a64b 100644 --- a/tests/unit/prompt_target/target/test_normalize_async_integration.py +++ b/tests/unit/prompt_target/target/test_normalize_async_integration.py @@ -229,7 +229,7 @@ async def test_azure_ml_target_calls_normalize_async(): with ( patch.object(target.configuration, "normalize_async", new_callable=AsyncMock) as mock_normalize, - patch.object(target, "_complete_chat_async", new_callable=AsyncMock, return_value="response"), + patch.object(target, "_complete_chat_async", new_callable=AsyncMock, return_value={"output": "response"}), ): mock_normalize.return_value = [user_msg] await target.send_prompt_async(message=user_msg) @@ -254,7 +254,9 @@ async def test_azure_ml_target_sends_normalized_to_complete_chat(): with ( patch.object(target.configuration, "normalize_async", new_callable=AsyncMock, return_value=[adapted_msg]), - patch.object(target, "_complete_chat_async", new_callable=AsyncMock, return_value="response") as mock_chat, + patch.object( + target, "_complete_chat_async", new_callable=AsyncMock, return_value={"output": "response"} + ) as mock_chat, ): await target.send_prompt_async(message=user_msg) @@ -294,7 +296,7 @@ async def test_azure_ml_target_memory_not_mutated(): mock_memory.get_conversation.return_value = memory_conversation target._memory = mock_memory - with patch.object(target, "_complete_chat_async", new_callable=AsyncMock, return_value="response"): + with patch.object(target, "_complete_chat_async", new_callable=AsyncMock, return_value={"output": "response"}): await target.send_prompt_async(message=user_msg) # Memory must still have original system message only (not mutated) @@ -386,7 +388,9 @@ async def test_azure_ml_system_squash_via_configuration_pipeline(): mock_memory.get_conversation.return_value = [system_msg] target._memory = mock_memory - with patch.object(target, "_complete_chat_async", new_callable=AsyncMock, return_value="response") as mock_chat: + with patch.object( + target, "_complete_chat_async", new_callable=AsyncMock, return_value={"output": "response"} + ) as mock_chat: await target.send_prompt_async(message=user_msg) # _complete_chat_async should receive normalized messages (system squashed into user) diff --git a/tests/unit/prompt_target/target/test_openai_response_target_c6_migration.py b/tests/unit/prompt_target/target/test_openai_response_target_c6_migration.py new file mode 100644 index 0000000000..04dfd0b024 --- /dev/null +++ b/tests/unit/prompt_target/target/test_openai_response_target_c6_migration.py @@ -0,0 +1,304 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +"""C6 additions to the Response target function-chaining suite. + +Covers the migration onto @tool_loop + LocalToolBackend. +""" + +from __future__ import annotations + +import json +import uuid +import warnings +from typing import Any +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from pyrit.models import Message, MessagePiece +from pyrit.prompt_target import OpenAIResponseTarget +from pyrit.prompt_target.common.target_capabilities import TargetCapabilities +from pyrit.prompt_target.common.target_configuration import TargetConfiguration +from pyrit.tools import LocalToolBackend, ToolEventBehavior, ToolEventPolicy + + +def _mock_function_call_response(call_id: str, function_name: str, arguments: dict) -> MagicMock: + """Build a fake Responses-API response containing a function_call section.""" + mock_response = MagicMock() + mock_response.status = "completed" + mock_response.error = None + section = MagicMock() + section.type = "function_call" + section.call_id = call_id + section.name = function_name + section.arguments = json.dumps(arguments) + section.model_dump.return_value = { + "type": "function_call", + "call_id": call_id, + "name": function_name, + "arguments": json.dumps(arguments), + } + mock_response.output = [section] + return mock_response + + +def _mock_text_response(text: str) -> MagicMock: + """Build a fake Responses-API response containing a message section.""" + mock_response = MagicMock() + mock_response.status = "completed" + mock_response.error = None + section = MagicMock() + section.type = "message" + section.content = [MagicMock(text=text)] + mock_response.output = [section] + return mock_response + + +def _user_msg(text: str, conversation_id: str | None = None) -> Message: + return Message( + message_pieces=[ + MessagePiece( + role="user", + original_value=text, + conversation_id=conversation_id or str(uuid.uuid4()), + ) + ] + ) + + +class TestCustomFunctionsDeprecation: + """custom_functions still works but emits DeprecationWarning.""" + + def test_custom_functions_kwarg_emits_deprecation_warning(self, patch_central_database): + async def get_weather(args: dict[str, Any]) -> dict[str, Any]: + return {"t": 72} + + with pytest.warns(DeprecationWarning, match="custom_functions"): + OpenAIResponseTarget( + model_name="gpt-4", + endpoint="https://mock.example.com", + api_key="mock-key", + custom_functions={"get_weather": get_weather}, + ) + + @pytest.mark.asyncio + async def test_custom_functions_kwarg_still_dispatches(self, patch_central_database): + async def get_weather(args: dict[str, Any]) -> dict[str, Any]: + return {"temperature": 72, "condition": "sunny"} + + with warnings.catch_warnings(): + warnings.simplefilter("ignore", DeprecationWarning) + target = OpenAIResponseTarget( + model_name="gpt-4", + endpoint="https://mock.example.com", + api_key="mock-key", + custom_functions={"get_weather": get_weather}, + ) + + responses = [ + _mock_function_call_response("call_1", "get_weather", {"location": "NYC"}), + _mock_text_response("72F and sunny."), + ] + seen = [] + + async def mock_create(**kwargs): + seen.append(kwargs) + return responses[len(seen) - 1] + + with patch.object(target._async_client.responses, "create", new_callable=AsyncMock) as mc: + mc.side_effect = mock_create + result = await target.send_prompt_async(message=_user_msg("weather?")) + + assert len(seen) == 2 + assert result[-1].message_pieces[0].original_value == "72F and sunny." + second_input = seen[1]["input"] + assert any(item.get("type") == "function_call_output" for item in second_input) + + +def _config_with_backend(backend: LocalToolBackend) -> TargetConfiguration: + """Build a TargetConfiguration wired for the modern tool-backend path.""" + caps = TargetCapabilities( + supports_multi_turn=True, + supports_multi_message_pieces=True, + supports_editable_history=True, + supports_json_output=True, + supports_system_prompt=True, + supports_tool_use=True, + input_modalities=frozenset( + { + frozenset(["text"]), + frozenset(["text", "image_path"]), + frozenset(["function_call"]), + frozenset(["tool_call"]), + frozenset(["function_call_output"]), + frozenset(["reasoning"]), + } + ), + ) + return TargetConfiguration( + capabilities=caps, + tool_event_policy=ToolEventPolicy(behavior=ToolEventBehavior.EXECUTE, max_tool_iterations=5), + tool_backend=backend, + ) + + +class TestToolBackendDispatch: + """The modern path: pass tool_backend via TargetConfiguration.""" + + @pytest.mark.asyncio + async def test_local_backend_dispatches_through_tool_loop(self, patch_central_database): + async def get_weather(args: dict[str, Any]) -> dict[str, Any]: + return {"temperature": 72, "condition": "sunny"} + + backend = LocalToolBackend( + callables={"get_weather": get_weather}, + schemas=[ + { + "name": "get_weather", + "description": "Weather lookup.", + "parameters": { + "type": "object", + "properties": {"location": {"type": "string"}}, + "required": ["location"], + }, + } + ], + ) + target = OpenAIResponseTarget( + model_name="gpt-4", + endpoint="https://mock.example.com", + api_key="mock-key", + custom_configuration=_config_with_backend(backend), + ) + + responses = [ + _mock_function_call_response("call_1", "get_weather", {"location": "NYC"}), + _mock_text_response("72F and sunny in NYC."), + ] + seen = [] + + async def mock_create(**kwargs): + seen.append(kwargs) + return responses[len(seen) - 1] + + with patch.object(target._async_client.responses, "create", new_callable=AsyncMock) as mc: + mc.side_effect = mock_create + result = await target.send_prompt_async(message=_user_msg("weather?")) + + assert len(seen) == 2 + assert result[-1].message_pieces[0].original_value == "72F and sunny in NYC." + second_input = seen[1]["input"] + assert any(item.get("type") == "function_call_output" for item in second_input) + + +class TestToolSchemasInjection: + """_construct_request_body injects backend schemas when present.""" + + @pytest.mark.asyncio + async def test_backend_schemas_injected_into_tools(self, patch_central_database): + async def get_weather(args: dict[str, Any]) -> dict[str, Any]: + return {"t": 1} + + backend = LocalToolBackend( + callables={"get_weather": get_weather}, + schemas=[{"name": "get_weather", "description": "x", "parameters": {"type": "object"}}], + ) + target = OpenAIResponseTarget( + model_name="gpt-4", + endpoint="https://mock.example.com", + api_key="mock-key", + custom_configuration=_config_with_backend(backend), + ) + body = await target._construct_request_body( + conversation=[_user_msg("hi")], + json_config=MagicMock(enabled=False, schema=None), + ) + assert "tools" in body + assert body["tools"][0]["type"] == "function" + assert body["tools"][0]["name"] == "get_weather" + + @pytest.mark.asyncio + async def test_extra_body_tools_take_precedence(self, patch_central_database): + async def f(args: dict[str, Any]) -> dict[str, Any]: + return {} + + backend = LocalToolBackend( + callables={"f": f}, + schemas=[{"name": "f", "parameters": {"type": "object"}}], + ) + legacy = [{"type": "function", "name": "legacy_tool", "description": "x"}] + config = _config_with_backend(backend) + target = OpenAIResponseTarget( + model_name="gpt-4", + endpoint="https://mock.example.com", + api_key="mock-key", + extra_body_parameters={"tools": legacy}, + custom_configuration=config, + ) + body = await target._construct_request_body( + conversation=[_user_msg("hi")], + json_config=MagicMock(enabled=False, schema=None), + ) + assert body["tools"] == legacy + + @pytest.mark.asyncio + async def test_no_backend_no_tools_key(self, patch_central_database): + target = OpenAIResponseTarget( + model_name="gpt-4", + endpoint="https://mock.example.com", + api_key="mock-key", + ) + body = await target._construct_request_body( + conversation=[_user_msg("hi")], + json_config=MagicMock(enabled=False, schema=None), + ) + assert "tools" not in body + + +class TestNonFunctionCallPiecesPassThrough: + """Reasoning / mcp_call / web_search_call sections must NOT be dispatched. + + The Response target's parser populates pieces for these types so they can + be persisted to Memory and round-tripped on subsequent requests. The + CanonicalEnvelopeParser only extracts function_call pieces; the tool loop + must therefore see an empty parse and exit cleanly. + """ + + @pytest.mark.asyncio + async def test_reasoning_only_response_exits_loop(self, patch_central_database): + target = OpenAIResponseTarget( + model_name="gpt-4", + endpoint="https://mock.example.com", + api_key="mock-key", + reasoning_effort="medium", + ) + # Reasoning section + final text section in one response + mock_response = MagicMock() + mock_response.status = "completed" + mock_response.error = None + reasoning_section = MagicMock() + reasoning_section.type = "reasoning" + reasoning_section.model_dump.return_value = {"type": "reasoning", "summary": "thinking..."} + text_section = MagicMock() + text_section.type = "message" + text_section.content = [MagicMock(text="The answer is 42.")] + mock_response.output = [reasoning_section, text_section] + + seen = [] + + async def mock_create(**kwargs): + seen.append(kwargs) + return mock_response + + with patch.object(target._async_client.responses, "create", new_callable=AsyncMock) as mc: + mc.side_effect = mock_create + result = await target.send_prompt_async(message=_user_msg("question?")) + + # Exactly one API call -- reasoning is not a tool call so the loop exits + assert len(seen) == 1 + # Response message contains both pieces + assert len(result) == 1 + piece_types = [p.original_value_data_type for p in result[0].message_pieces] + assert "reasoning" in piece_types + assert "text" in piece_types diff --git a/tests/unit/prompt_target/target/test_prompt_target.py b/tests/unit/prompt_target/target/test_prompt_target.py index f3174c2649..fa29815103 100644 --- a/tests/unit/prompt_target/target/test_prompt_target.py +++ b/tests/unit/prompt_target/target/test_prompt_target.py @@ -22,6 +22,7 @@ UnsupportedCapabilityBehavior, ) from pyrit.prompt_target.common.target_configuration import TargetConfiguration +from pyrit.tools import LocalToolBackend, ToolEventBehavior, ToolEventPolicy @pytest.fixture @@ -534,7 +535,13 @@ def test_identifier_includes_capability_params(): # Config-derived fields are nested under ``target_configuration``, not # spread at the top level — guards against accidental re-flattening. assert "supports_multi_turn" not in params - assert set(target_config.keys()) == {"capabilities", "capability_policy", "normalization_pipeline"} + assert set(target_config.keys()) == { + "capabilities", + "capability_policy", + "normalization_pipeline", + "tool_event_policy", + "tool_backend", + } assert capabilities["supports_multi_turn"] is True assert capabilities["supports_multi_message_pieces"] is True @@ -546,6 +553,8 @@ def test_identifier_includes_capability_params(): assert capabilities["output_modalities"] == [["text"]] assert isinstance(target_config["capability_policy"], dict) assert isinstance(target_config["normalization_pipeline"], list) + assert target_config["tool_event_policy"] is None + assert target_config["tool_backend"] is None @pytest.mark.usefixtures("patch_central_database") @@ -581,6 +590,62 @@ def test_identifier_differs_when_policy_differs(): assert a.get_identifier().hash != b.get_identifier().hash +@pytest.mark.usefixtures("patch_central_database") +def test_identifier_differs_when_tool_backend_differs(): + async def _f(_: dict) -> dict: + return {} + + capabilities = TargetCapabilities(supports_tool_use=True) + backend_a = LocalToolBackend( + callables={"alpha": _f}, + schemas=[{"name": "alpha", "parameters": {"type": "object"}}], + ) + backend_b = LocalToolBackend( + callables={"beta": _f}, + schemas=[{"name": "beta", "parameters": {"type": "object"}}], + ) + + a = OpenAIChatTarget( + model_name="gpt-4o", + endpoint="https://mock.azure.com/", + api_key="mock-api-key", + custom_configuration=TargetConfiguration(capabilities=capabilities, tool_backend=backend_a), + ) + b = OpenAIChatTarget( + model_name="gpt-4o", + endpoint="https://mock.azure.com/", + api_key="mock-api-key", + custom_configuration=TargetConfiguration(capabilities=capabilities, tool_backend=backend_b), + ) + + assert a.get_identifier().hash != b.get_identifier().hash + + +@pytest.mark.usefixtures("patch_central_database") +def test_identifier_differs_when_tool_event_policy_differs(): + capabilities = TargetCapabilities(supports_tool_use=True) + a = OpenAIChatTarget( + model_name="gpt-4o", + endpoint="https://mock.azure.com/", + api_key="mock-api-key", + custom_configuration=TargetConfiguration( + capabilities=capabilities, + tool_event_policy=ToolEventPolicy(behavior=ToolEventBehavior.EXECUTE), + ), + ) + b = OpenAIChatTarget( + model_name="gpt-4o", + endpoint="https://mock.azure.com/", + api_key="mock-api-key", + custom_configuration=TargetConfiguration( + capabilities=capabilities, + tool_event_policy=ToolEventPolicy(behavior=ToolEventBehavior.RAISE), + ), + ) + + assert a.get_identifier().hash != b.get_identifier().hash + + @pytest.mark.usefixtures("patch_central_database") def test_identifier_is_deterministic_across_instances(): capabilities = TargetCapabilities( diff --git a/tests/unit/tools/__init__.py b/tests/unit/tools/__init__.py new file mode 100644 index 0000000000..9a0454564d --- /dev/null +++ b/tests/unit/tools/__init__.py @@ -0,0 +1,2 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. diff --git a/tests/unit/tools/conftest.py b/tests/unit/tools/conftest.py new file mode 100644 index 0000000000..d8b0151bb7 --- /dev/null +++ b/tests/unit/tools/conftest.py @@ -0,0 +1,299 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +Shared fixtures for ``tests/unit/tools``. + +Provides the minimal collaborators the tool-loop tests need to exercise +:func:`pyrit.tools.tool_loop` end-to-end without standing up real targets +or MCP transports: + +* :class:`_FakeToolTarget` — a :class:`PromptTarget` subclass whose + ``_send_prompt_to_target_async`` returns scripted messages from a queue + and whose ``_get_normalized_conversation_async`` skips the memory round + trip so decorator behavior is isolated from normalization. +* :class:`_RecordingToolBackend` — a :class:`ToolBackend` that records + every dispatched call (for order-of-execution assertions) and returns + results from a scripted queue. +* :class:`_CanonicalEnvelopeParser` — a :class:`ToolCallParser` that walks + message pieces and parses the canonical ``function_call`` JSON envelope. + +Helper message builders (``_make_user_message``, +``_make_assistant_text_message``, ``_make_assistant_function_call_message``) +produce the canonical envelope shape used by the OpenAI targets after the +normalization commit. +""" + +from __future__ import annotations + +import json +import uuid +from collections import deque +from typing import Any + +import pytest + +from pyrit.models import Message, MessagePiece +from pyrit.prompt_target.common.prompt_target import PromptTarget +from pyrit.prompt_target.common.target_capabilities import TargetCapabilities +from pyrit.prompt_target.common.target_configuration import TargetConfiguration +from pyrit.tools import ( + LocalToolBackend, + ToolBackend, + ToolCall, + ToolCallParser, + ToolEventBehavior, + ToolEventPolicy, +) + + +def _make_user_message(text: str, *, conversation_id: str | None = None) -> Message: + """Build a single-piece user :class:`Message` carrying *text*.""" + return Message( + message_pieces=[ + MessagePiece( + role="user", + original_value=text, + original_value_data_type="text", + conversation_id=conversation_id or str(uuid.uuid4()), + ) + ] + ) + + +def _make_assistant_text_message(text: str, *, conversation_id: str | None = None) -> Message: + """Build a single-piece assistant :class:`Message` carrying plain text.""" + return Message( + message_pieces=[ + MessagePiece( + role="assistant", + original_value=text, + original_value_data_type="text", + conversation_id=conversation_id or str(uuid.uuid4()), + ) + ], + skip_validation=True, + ) + + +def _make_function_call_piece( + *, + call_id: str, + name: str, + arguments: dict[str, Any], + conversation_id: str | None = None, +) -> MessagePiece: + """Build one assistant ``function_call`` piece carrying the canonical envelope.""" + envelope = { + "type": "function_call", + "call_id": call_id, + "name": name, + "arguments": json.dumps(arguments, separators=(",", ":")), + } + return MessagePiece( + role="assistant", + original_value=json.dumps(envelope, separators=(",", ":")), + original_value_data_type="function_call", + conversation_id=conversation_id or str(uuid.uuid4()), + ) + + +def _make_assistant_function_call_message( + *, + calls: list[tuple[str, str, dict[str, Any]]], + conversation_id: str | None = None, +) -> Message: + """ + Build an assistant :class:`Message` carrying one ``function_call`` piece + per ``(call_id, name, args)`` tuple, in declaration order. + """ + conv_id = conversation_id or str(uuid.uuid4()) + pieces = [ + _make_function_call_piece(call_id=cid, name=name, arguments=args, conversation_id=conv_id) + for cid, name, args in calls + ] + return Message(message_pieces=pieces, skip_validation=True) + + +class _CanonicalEnvelopeParser: + """ + Reference :class:`ToolCallParser` that understands the canonical envelope + (``original_value_data_type == "function_call"`` carrying a JSON object + with ``type``/``call_id``/``name``/``arguments``). + + Per-target parsers shipped will reuse this shape; this stand-in + keeps decorator tests independent of the real OpenAI parsers. + """ + + def parse(self, message: Message) -> list[ToolCall]: + calls: list[ToolCall] = [] + for piece in message.message_pieces: + if piece.original_value_data_type != "function_call": + continue + envelope = json.loads(piece.original_value) + arguments_str = envelope.get("arguments", "{}") + arguments = json.loads(arguments_str) if isinstance(arguments_str, str) else dict(arguments_str) + calls.append( + ToolCall( + call_id=envelope["call_id"], + name=envelope["name"], + arguments=arguments, + raw_envelope=envelope, + ) + ) + return calls + + +class _RecordingToolBackend(ToolBackend): + """ + Minimal :class:`ToolBackend` that records every dispatched call and + returns results from a scripted queue. Used to assert dispatch order, + iteration count, and per-call payload shape without invoking real tools. + """ + + def __init__( + self, + *, + scripted_results: list[Any] | None = None, + schemas: list[dict[str, Any]] | None = None, + ) -> None: + self._results: deque[Any] = deque(scripted_results or []) + self._schemas: list[dict[str, Any]] = list(schemas) if schemas is not None else [] + self.recorded_calls: list[ToolCall] = [] + + @property + def schemas(self) -> list[dict[str, Any]]: + return list(self._schemas) + + async def dispatch_async(self, call: ToolCall) -> dict[str, Any]: + self.recorded_calls.append(call) + if not self._results: + return {"result": f"recorded:{call.name}:{call.call_id}"} + nxt = self._results.popleft() + return nxt if isinstance(nxt, dict) else {"result": nxt} + + +class _FakeToolTarget(PromptTarget): + """ + Test-only :class:`PromptTarget` whose ``_send_prompt_to_target_async`` + pops scripted responses off a queue. ``_get_normalized_conversation_async`` + is overridden to return ``[message]`` directly, isolating decorator + behavior from the memory + normalization pipeline. + + Inherits the base class's ``@final @tool_loop send_prompt_async``; the + policy + backend are wired through :class:`TargetConfiguration` so the + wrapper finds them via ``self.configuration.tool_event_policy`` and + ``self.configuration.tool_backend``. + """ + + def __init__( + self, + *, + scripted_responses: list[Message], + policy: ToolEventPolicy | None = None, + backend: Any = None, + parser: ToolCallParser | None = None, + ) -> None: + # ``supports_tool_use`` is forced on whenever a policy is configured so + # the TargetConfiguration validator accepts the backend. + caps = TargetCapabilities( + supports_multi_turn=True, + supports_multi_message_pieces=True, + supports_tool_use=policy is not None, + ) + config = TargetConfiguration( + capabilities=caps, + tool_event_policy=policy, + tool_backend=backend, + ) + super().__init__(custom_configuration=config) + self._scripted_responses: deque[Message] = deque(scripted_responses) + self.call_count: int = 0 + self.normalized_conversations_seen: list[list[Message]] = [] + self._parser_instance: ToolCallParser | None = parser if parser is not None else _CanonicalEnvelopeParser() + + @property + def _tool_parser(self) -> ToolCallParser | None: + return self._parser_instance + + async def _get_normalized_conversation_async(self, *, message: Message) -> list[Message]: + return [message] + + def _validate_request(self, *, normalized_conversation: list[Message]) -> None: + return + + async def _send_prompt_to_target_async(self, *, normalized_conversation: list[Message]) -> list[Message]: + self.call_count += 1 + self.normalized_conversations_seen.append(list(normalized_conversation)) + if not self._scripted_responses: + raise AssertionError(f"Fake target ran out of scripted responses on iteration {self.call_count}.") + return [self._scripted_responses.popleft()] + + +@pytest.fixture +def make_fake_target(patch_central_database): + """ + Factory fixture for :class:`_FakeToolTarget`. Each invocation returns a + fresh target instance whose scripted response queue is independent of + other targets created during the test. + """ + + def _factory( + *, + scripted_responses: list[Message], + policy: ToolEventPolicy | None = None, + backend: Any = None, + parser: ToolCallParser | None = None, + ) -> _FakeToolTarget: + return _FakeToolTarget( + scripted_responses=scripted_responses, + policy=policy, + backend=backend, + parser=parser, + ) + + return _factory + + +@pytest.fixture +def recording_backend(): + """Factory fixture for :class:`_RecordingToolBackend`.""" + + def _factory(*, scripted_results: list[Any] | None = None) -> _RecordingToolBackend: + return _RecordingToolBackend(scripted_results=scripted_results) + + return _factory + + +@pytest.fixture +def execute_policy(): + """ + Factory fixture for :class:`ToolEventPolicy` with + ``behavior=ToolEventBehavior.EXECUTE`` and a tunable iteration cap. + """ + + def _factory(*, max_tool_iterations: int = 5) -> ToolEventPolicy: + return ToolEventPolicy( + behavior=ToolEventBehavior.EXECUTE, + max_tool_iterations=max_tool_iterations, + ) + + return _factory + + +__all__ = [ + "LocalToolBackend", + "ToolCall", + "ToolEventBehavior", + "ToolEventPolicy", + "_CanonicalEnvelopeParser", + "_FakeToolTarget", + "_RecordingToolBackend", + "_make_assistant_function_call_message", + "_make_assistant_text_message", + "_make_function_call_piece", + "_make_user_message", + "execute_policy", + "make_fake_target", + "recording_backend", +] diff --git a/tests/unit/tools/echo_mcp_server.py b/tests/unit/tools/echo_mcp_server.py new file mode 100644 index 0000000000..3ea5dbd6b5 --- /dev/null +++ b/tests/unit/tools/echo_mcp_server.py @@ -0,0 +1,56 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +Deterministic echo MCP server used as a stdio subprocess fixture by +``tests/unit/tools/test_mcp_client.py`` and the tools integration tests. + +The harness imports this module via the ``mcp.client.stdio.stdio_client`` +launcher, so it does not need to be importable as a Python module from +``tests/unit/tools/`` callers. + +Run directly as ``python echo_mcp_server.py`` to expose the four tools +over stdio. The MCP client harness launches this file with +``mcp.client.stdio.stdio_client`` and asserts behavior end to end. +""" + +from __future__ import annotations + +import asyncio + +from mcp.server.fastmcp import FastMCP + +mcp = FastMCP("pyrit-echo") + + +@mcp.tool() +def echo(text: str) -> str: + """Return *text* unchanged.""" + return text + + +@mcp.tool() +def add(a: int, b: int) -> int: + """Return ``a + b``.""" + return a + b + + +@mcp.tool() +def reverse(text: str) -> str: + """Return *text* reversed.""" + return text[::-1] + + +@mcp.tool() +async def slow_echo(text: str, delay_ms: int = 0) -> str: + """ + Return *text* after sleeping ``delay_ms`` milliseconds. Used by + timeout / cancellation tests. + """ + if delay_ms > 0: + await asyncio.sleep(delay_ms / 1000.0) + return text + + +if __name__ == "__main__": + mcp.run() diff --git a/tests/unit/tools/test_inline_parser.py b/tests/unit/tools/test_inline_parser.py new file mode 100644 index 0000000000..2f2850e667 --- /dev/null +++ b/tests/unit/tools/test_inline_parser.py @@ -0,0 +1,234 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +"""Unit tests for InlineToolCallParser across marker syntaxes.""" + +from __future__ import annotations + +import logging +import uuid + +import pytest + +from pyrit.models import Message, MessagePiece +from pyrit.tools import InlineToolCallParser, InlineToolCallParserMode + + +def _assistant_text(text: str) -> Message: + return Message( + message_pieces=[ + MessagePiece( + role="assistant", + original_value=text, + original_value_data_type="text", + conversation_id=str(uuid.uuid4()), + ) + ], + skip_validation=True, + ) + + +# Marker patterns commonly seen in tool-trained open chat templates. +# Named after their syntactic shape, not the model family that uses them, +# so that PyRIT does not advertise a "supported vendors" list. +ANGLE_BRACKET_PATTERN = r"(.*?)" +PIPE_PYTHON_TAG_PATTERN = r"<\|python_tag\|>(.*?)<\|eom_id\|>" +SQUARE_BRACKET_LIST_PATTERN = r"\[TOOL_CALLS\]\s*(\[.*?\])" + + +# --------------------------------------------------------------------------- +# Marker-pattern coverage: angle-bracket / pipe-python-tag / square-bracket-list +# --------------------------------------------------------------------------- + + +class TestAngleBracketMarker: + """Default pattern: ``...``.""" + + def test_single_marker_extracts_call(self): + parser = InlineToolCallParser() + text = ( + 'Let me check the weather. {"name": "get_weather", "arguments": {"location": "NYC"}}' + ) + calls = parser.parse(_assistant_text(text)) + assert len(calls) == 1 + assert calls[0].name == "get_weather" + assert calls[0].arguments == {"location": "NYC"} + assert calls[0].call_id == "call_0" + + def test_multiple_markers_extract_all_with_synthetic_ids(self): + parser = InlineToolCallParser(mode=InlineToolCallParserMode.EXTRACT_ALL) + text = ( + '{"name": "f1", "arguments": {}}' + "some interleaving text " + '{"name": "f2", "arguments": {"k": 1}}' + ) + calls = parser.parse(_assistant_text(text)) + assert [c.name for c in calls] == ["f1", "f2"] + assert [c.call_id for c in calls] == ["call_0", "call_1"] + + def test_arguments_as_json_string_is_decoded(self): + parser = InlineToolCallParser() + # arguments doubly-encoded as a JSON string (canonical envelope shape). + text = '{"name": "f", "arguments": "{\\"a\\": 1}"}' + calls = parser.parse(_assistant_text(text)) + assert len(calls) == 1 + assert calls[0].arguments == {"a": 1} + + +class TestPipePythonTagMarker: + """Pipe-delimited tag pair: ``<|python_tag|>...<|eom_id|>``.""" + + def test_single_marker_extracts_call(self): + parser = InlineToolCallParser(marker_pattern=PIPE_PYTHON_TAG_PATTERN) + text = ( + 'Sure, calling now. <|python_tag|>{"name": "get_weather", "arguments": {"location": "Seattle"}}<|eom_id|>' + ) + calls = parser.parse(_assistant_text(text)) + assert len(calls) == 1 + assert calls[0].name == "get_weather" + assert calls[0].arguments == {"location": "Seattle"} + + def test_multi_marker_extract_all(self): + parser = InlineToolCallParser( + marker_pattern=PIPE_PYTHON_TAG_PATTERN, + mode=InlineToolCallParserMode.EXTRACT_ALL, + ) + text = ( + '<|python_tag|>{"name": "a", "arguments": {}}<|eom_id|>' + "between " + '<|python_tag|>{"name": "b", "arguments": {}}<|eom_id|>' + ) + calls = parser.parse(_assistant_text(text)) + assert [c.name for c in calls] == ["a", "b"] + + +class TestSquareBracketListMarker: + """Square-bracketed list payload: ``[TOOL_CALLS] [...]``.""" + + def test_list_payload_is_skipped_with_warning(self, caplog): + """The default parser expects a single-dict payload. + + Marker syntaxes whose payload is a JSON LIST of dicts (rather than a + single dict) are logged and dropped. Callers that need list-shaped + payloads should either subclass ``InlineToolCallParser`` and override + ``parse`` to iterate the list, or use a marker pattern that targets + each dict inside the list separately. Regex does not handle nested + braces well; subclassing is cleaner. + """ + parser = InlineToolCallParser(marker_pattern=SQUARE_BRACKET_LIST_PATTERN) + text = '[TOOL_CALLS] [{"name": "f", "arguments": {"x": 1}}]' + with caplog.at_level(logging.WARNING, logger="pyrit.tools.inline_parser"): + calls = parser.parse(_assistant_text(text)) + assert calls == [] + assert any("without 'name' field" in rec.message for rec in caplog.records) + + +# --------------------------------------------------------------------------- +# Mode coverage: TRUNCATE_AT_LAST / TRUNCATE_AT_FIRST / EXTRACT_ALL / +# STRICT_TRAILING_EMPTY +# --------------------------------------------------------------------------- + + +class TestParserModes: + """Surrounding-text policy coverage.""" + + HALLUCINATED = ( + '{"name": "get_weather", "arguments": {"location": "NYC"}}' + " The weather in NYC is sunny and 72 degrees." + ) + DOUBLE_CALL = ( + '{"name": "a", "arguments": {}}' + " middle " + '{"name": "b", "arguments": {}}' + " trailing chatter" + ) + + def test_truncate_at_last_default_drops_trailing_chatter(self): + parser = InlineToolCallParser() # default mode + calls = parser.parse(_assistant_text(self.HALLUCINATED)) + # The hallucinated weather report after the marker is discarded; the + # call itself is honored. + assert len(calls) == 1 + assert calls[0].name == "get_weather" + + def test_truncate_at_last_extracts_all_markers_then_drops_tail(self): + parser = InlineToolCallParser(mode=InlineToolCallParserMode.TRUNCATE_AT_LAST) + calls = parser.parse(_assistant_text(self.DOUBLE_CALL)) + # Both markers honored, trailing "trailing chatter" silently dropped. + assert [c.name for c in calls] == ["a", "b"] + + def test_truncate_at_first_keeps_only_the_first_marker(self): + parser = InlineToolCallParser(mode=InlineToolCallParserMode.TRUNCATE_AT_FIRST) + calls = parser.parse(_assistant_text(self.DOUBLE_CALL)) + assert [c.name for c in calls] == ["a"] + + def test_extract_all_keeps_every_marker(self): + parser = InlineToolCallParser(mode=InlineToolCallParserMode.EXTRACT_ALL) + calls = parser.parse(_assistant_text(self.DOUBLE_CALL)) + assert [c.name for c in calls] == ["a", "b"] + + def test_strict_trailing_empty_raises_on_chatter(self): + parser = InlineToolCallParser(mode=InlineToolCallParserMode.STRICT_TRAILING_EMPTY) + with pytest.raises(ValueError, match="STRICT_TRAILING_EMPTY"): + parser.parse(_assistant_text(self.HALLUCINATED)) + + def test_strict_trailing_empty_passes_when_only_whitespace_after(self): + parser = InlineToolCallParser(mode=InlineToolCallParserMode.STRICT_TRAILING_EMPTY) + text = '{"name": "f", "arguments": {}}\n \t ' + calls = parser.parse(_assistant_text(text)) + assert [c.name for c in calls] == ["f"] + + +# --------------------------------------------------------------------------- +# Edge cases +# --------------------------------------------------------------------------- + + +class TestEdgeCases: + """Empty input, malformed JSON, missing name, multi-piece messages.""" + + def test_no_markers_returns_empty(self): + parser = InlineToolCallParser() + calls = parser.parse(_assistant_text("just plain assistant text")) + assert calls == [] + + def test_malformed_json_is_skipped_silently(self): + parser = InlineToolCallParser() + text = "not valid json" + calls = parser.parse(_assistant_text(text)) + assert calls == [] + + def test_payload_without_name_is_skipped(self): + parser = InlineToolCallParser() + text = '{"arguments": {}}' + calls = parser.parse(_assistant_text(text)) + assert calls == [] + + def test_non_text_pieces_are_ignored(self): + """Pieces with data_type other than 'text' are skipped entirely.""" + parser = InlineToolCallParser() + msg = Message( + message_pieces=[ + MessagePiece( + role="assistant", + original_value='{"name": "ignored", "arguments": {}}', + original_value_data_type="reasoning", + conversation_id=str(uuid.uuid4()), + ), + MessagePiece( + role="assistant", + original_value='{"name": "found", "arguments": {}}', + original_value_data_type="text", + conversation_id=str(uuid.uuid4()), + ), + ], + skip_validation=True, + ) + calls = parser.parse(msg) + assert [c.name for c in calls] == ["found"] + + def test_call_id_prefix_customization(self): + parser = InlineToolCallParser(call_id_prefix="custom") + text = '{"name": "f", "arguments": {}}' + calls = parser.parse(_assistant_text(text)) + assert calls[0].call_id == "custom_0" diff --git a/tests/unit/tools/test_local_tool_backend.py b/tests/unit/tools/test_local_tool_backend.py new file mode 100644 index 0000000000..13d12c9b26 --- /dev/null +++ b/tests/unit/tools/test_local_tool_backend.py @@ -0,0 +1,179 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +Unit tests for :class:`pyrit.tools.LocalToolBackend`. + +Coverage map: + +* **U10** (partial; the MCP counterpart) — + ``test_each_dummy_tool_invoked_via_prepended_conversation`` +* **U17** (partial; the MCP-timeout counterpart) — + ``test_failing_tool_yields_error_envelope`` +* **U18** — ``test_disallowed_tool_returns_error_without_invoking_callable`` + +Also covers the backend's documented behavior for missing functions +(both strict and tolerant modes), schema property defaulting, scalar +result wrapping, and declaration-order preservation in the bulk dispatch +path. These are required for the §10 rubber-duck guarantee that every +public-facing branch of :class:`LocalToolBackend` is exercised end-to-end. +""" + +from __future__ import annotations + +import pytest + +from pyrit.tools import LocalToolBackend, ToolCall + + +def _make_call(name: str, *, call_id: str = "c1", arguments: dict | None = None) -> ToolCall: + return ToolCall(call_id=call_id, name=name, arguments=arguments or {}) + + +async def test_disallowed_tool_returns_error_without_invoking_callable(): + invoked: list[str] = [] + + async def echo(args: dict) -> dict: + invoked.append(args.get("text", "")) + return {"echoed": args.get("text", "")} + + backend = LocalToolBackend( + callables={"echo": echo, "off_limits": echo}, + allowed_tools={"echo"}, + ) + + result = await backend.dispatch_async(_make_call("off_limits", arguments={"text": "nope"})) + + assert result["error"] == "tool_not_allowed" + assert result["tool"] == "off_limits" + assert "echo" in result["allowed_tools"] + assert invoked == [] # callable was never invoked + + +async def test_failing_tool_yields_error_envelope(): + async def boom(args: dict) -> dict: + raise RuntimeError("kaboom") + + backend = LocalToolBackend(callables={"boom": boom}) + + result = await backend.dispatch_async(_make_call("boom")) + + assert result["error"] == "tool_execution_failed" + assert result["tool"] == "boom" + assert "kaboom" in result["detail"] + + +async def test_missing_tool_raises_when_strict(): + backend = LocalToolBackend(callables={}, fail_on_missing_function=True) + + with pytest.raises(KeyError, match="ghost"): + await backend.dispatch_async(_make_call("ghost")) + + +async def test_missing_tool_returns_envelope_when_tolerant(): + async def echo(args: dict) -> dict: + return {"ok": True} + + backend = LocalToolBackend( + callables={"echo": echo}, + fail_on_missing_function=False, + ) + + result = await backend.dispatch_async(_make_call("ghost")) + + assert result["error"] == "tool_not_registered" + assert result["tool"] == "ghost" + assert result["available_tools"] == ["echo"] + + +async def test_scalar_result_is_wrapped_in_dict(): + async def number(args: dict) -> int: + return 42 + + backend = LocalToolBackend(callables={"number": number}) + + result = await backend.dispatch_async(_make_call("number")) + + assert result == {"result": 42} + + +async def test_dict_result_passes_through_unchanged(): + async def named(args: dict) -> dict: + return {"custom_key": "custom_value"} + + backend = LocalToolBackend(callables={"named": named}) + + result = await backend.dispatch_async(_make_call("named")) + + assert result == {"custom_key": "custom_value"} + + +async def test_schemas_defaults_to_empty_list(): + backend = LocalToolBackend(callables={}) + + assert backend.schemas == [] + + +async def test_schemas_returned_as_copy(): + schemas_in = [{"name": "echo", "parameters": {}}] + backend = LocalToolBackend(callables={}, schemas=schemas_in) + + out1 = backend.schemas + out1.append({"name": "mutated"}) + + # Mutating the returned list does not affect the backend's internal state. + assert backend.schemas == schemas_in + + +async def test_dispatch_all_sequential_preserves_declaration_order(): + async def echo(args: dict) -> dict: + return {"echoed": args["i"]} + + backend = LocalToolBackend(callables={"echo": echo}) + + calls = [_make_call("echo", call_id=f"c{i}", arguments={"i": i}) for i in range(5)] + pairs = await backend.dispatch_all_sequential_async(calls) + + assert [c.call_id for c, _ in pairs] == ["c0", "c1", "c2", "c3", "c4"] + assert [r["echoed"] for _, r in pairs] == [0, 1, 2, 3, 4] + + +async def test_each_dummy_tool_invoked_via_prepended_conversation(): + """ + U10 (partial). Each dummy tool resolves on first dispatch (single + forward step, no model reasoning trace), confirming the backend can + short-circuit a prepended conversation where every call is already + decided. The MCP counterpart exercises the same shape against + a real stdio server. + """ + invocations: list[tuple[str, dict]] = [] + + async def echo(args: dict) -> dict: + invocations.append(("echo", args)) + return {"echoed": args.get("text", "")} + + async def add(args: dict) -> dict: + invocations.append(("add", args)) + return {"sum": args["a"] + args["b"]} + + async def reverse(args: dict) -> dict: + invocations.append(("reverse", args)) + return {"reversed": args.get("text", "")[::-1]} + + backend = LocalToolBackend(callables={"echo": echo, "add": add, "reverse": reverse}) + + prepended_calls = [ + _make_call("echo", call_id="e1", arguments={"text": "hello"}), + _make_call("add", call_id="a1", arguments={"a": 2, "b": 3}), + _make_call("reverse", call_id="r1", arguments={"text": "pyrit"}), + ] + pairs = await backend.dispatch_all_sequential_async(prepended_calls) + + # Each dummy resolved exactly once; no retries, no model re-entry. + assert len(invocations) == 3 + assert [name for name, _ in invocations] == ["echo", "add", "reverse"] + assert [r for _, r in pairs] == [ + {"echoed": "hello"}, + {"sum": 5}, + {"reversed": "tiryp"}, + ] diff --git a/tests/unit/tools/test_mcp_backend.py b/tests/unit/tools/test_mcp_backend.py new file mode 100644 index 0000000000..7621a61dca --- /dev/null +++ b/tests/unit/tools/test_mcp_backend.py @@ -0,0 +1,148 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +Unit tests for :class:`pyrit.tools.MCPToolBackend`. + +These tests verify the multi-server fan-out and routing layer on top of +:class:`MCPClient`: schema aggregation, name-collision detection, +``name_prefix`` disambiguation, ``allowed_tools`` allow-list semantics, +and concurrent-dispatch serialization. They reuse the real +``echo_mcp_server.py`` stdio subprocess. + +Coverage map: + +* **U18** — ``test_disallowed_tool_returns_error_envelope_without_invoking_server``. +* **U20a** — ``test_name_collision_raises_value_error``. +* **U20b** — ``test_name_prefix_disambiguates_colliding_servers``. +* **U21** — ``test_concurrent_dispatch_is_serialized_by_lock``. +""" + +from __future__ import annotations + +import asyncio +import sys +from pathlib import Path + +import pytest + +from pyrit.tools import ( + LocalMCPServerSpec, + MCPToolBackend, + ToolCall, +) + +ECHO_SERVER_SCRIPT = str(Path(__file__).parent / "echo_mcp_server.py") + + +def _spec(*, name_prefix: str | None = None, timeout_seconds: float = 5.0) -> LocalMCPServerSpec: + return LocalMCPServerSpec( + command=sys.executable, + args=(ECHO_SERVER_SCRIPT,), + name_prefix=name_prefix, + timeout_seconds=timeout_seconds, + ) + + +def _make_call(name: str, *, call_id: str = "c1", arguments: dict | None = None) -> ToolCall: + return ToolCall(call_id=call_id, name=name, arguments=arguments or {}) + + +async def test_backend_aggregates_schemas_across_servers() -> None: + """Schemas from every connected server show up in :attr:`schemas`.""" + backend = MCPToolBackend(servers=[_spec()]) + async with backend: + names = {s["name"] for s in backend.schemas} + assert names == {"echo", "add", "reverse", "slow_echo"} + + +async def test_dispatch_routes_to_correct_server() -> None: + """A :class:`ToolCall` is routed to the server that registered the name.""" + backend = MCPToolBackend(servers=[_spec()]) + async with backend: + envelope = await backend.dispatch_async(_make_call("echo", arguments={"text": "routed"})) + assert envelope["is_error"] is False + assert envelope["content"] == "routed" + + +async def test_name_collision_raises_value_error() -> None: + """Two servers exposing the same tool name without prefixes raise.""" + backend = MCPToolBackend(servers=[_spec(), _spec()]) + with pytest.raises(ValueError, match="duplicate tool name"): + await backend.__aenter__() + # __aexit__ is the cleanup path; __aenter__ failing leaves nothing to clean. + + +async def test_name_prefix_disambiguates_colliding_servers() -> None: + """Setting :attr:`LocalMCPServerSpec.name_prefix` disambiguates duplicates.""" + backend = MCPToolBackend( + servers=[ + _spec(name_prefix="a_"), + _spec(name_prefix="b_"), + ], + ) + async with backend: + names = {s["name"] for s in backend.schemas} + assert "a_echo" in names + assert "b_echo" in names + envelope = await backend.dispatch_async(_make_call("a_echo", arguments={"text": "alpha"})) + assert envelope["content"] == "alpha" + envelope_b = await backend.dispatch_async(_make_call("b_echo", arguments={"text": "beta"})) + assert envelope_b["content"] == "beta" + + +async def test_disallowed_tool_returns_error_envelope_without_invoking_server() -> None: + """U18: allowed_tools blocks both schema advertisement AND dispatch.""" + backend = MCPToolBackend(servers=[_spec()], allowed_tools=["echo"]) + async with backend: + advertised = {s["name"] for s in backend.schemas} + assert advertised == {"echo"} # add/reverse/slow_echo are filtered out. + + envelope = await backend.dispatch_async(_make_call("add", arguments={"a": 1, "b": 2})) + assert envelope["is_error"] is True + assert envelope["error"] == "tool_not_allowed" + assert envelope["tool"] == "add" + assert envelope["allowed_tools"] == ["echo"] + + +async def test_unknown_tool_returns_error_envelope() -> None: + """A call to a name no connected server exposes returns an error envelope.""" + backend = MCPToolBackend(servers=[_spec()]) + async with backend: + envelope = await backend.dispatch_async(_make_call("never_registered")) + assert envelope["is_error"] is True + assert envelope["error"] == "tool_not_registered" + assert envelope["tool"] == "never_registered" + + +async def test_concurrent_dispatch_is_serialized_by_lock() -> None: + """U21: two coroutines dispatching against the same backend do not interleave. + + The slow_echo tool sleeps server-side; without the lock the two + dispatches would issue overlapping JSON-RPC frames over the same + stdio pipe. With the lock they run back-to-back. We assert both + return successfully — interleaved frames would surface as protocol + errors or wrong content. + """ + backend = MCPToolBackend(servers=[_spec(timeout_seconds=10.0)]) + async with backend: + results = await asyncio.gather( + backend.dispatch_async(_make_call("slow_echo", arguments={"text": "A", "delay_ms": 50})), + backend.dispatch_async(_make_call("slow_echo", arguments={"text": "B", "delay_ms": 50})), + ) + assert all(not r["is_error"] for r in results) + assert {r["content"] for r in results} == {"A", "B"} + + +async def test_dispatch_all_sequential_async_preserves_order() -> None: + """Bulk dispatch returns (call, envelope) pairs in declaration order.""" + backend = MCPToolBackend(servers=[_spec()]) + calls = [ + _make_call("echo", call_id="c1", arguments={"text": "first"}), + _make_call("echo", call_id="c2", arguments={"text": "second"}), + _make_call("echo", call_id="c3", arguments={"text": "third"}), + ] + async with backend: + results = await backend.dispatch_all_sequential_async(calls) + assert [c.call_id for c, _ in results] == ["c1", "c2", "c3"] + assert [r["content"] for _, r in results] == ["first", "second", "third"] diff --git a/tests/unit/tools/test_mcp_client.py b/tests/unit/tools/test_mcp_client.py new file mode 100644 index 0000000000..de08b479b7 --- /dev/null +++ b/tests/unit/tools/test_mcp_client.py @@ -0,0 +1,162 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +Unit tests for :class:`pyrit.tools.MCPClient` and the +:class:`pyrit.tools.MCPServerSpec` union. + +Coverage map: + +* **U10** — ``test_real_subprocess_dispatch_returns_text_content``, + ``test_sequential_dispatch_against_real_server``. +* **U14** — ``test_connect_async_populates_schemas_via_tools_list``. +* **U17** — ``test_dispatch_timeout_returns_error_envelope``. +* **U20** — ``test_remote_mcp_server_spec_raises_not_implemented``, + ``test_docker_mcp_server_spec_raises_not_implemented``. + +These tests spawn the real ``tests/unit/tools/echo_mcp_server.py`` +subprocess via ``mcp.client.stdio.stdio_client``; they exercise the +full handshake → ``tools/list`` → ``tools/call`` round trip. The +purpose is to verify that ``MCPClient`` is a thin, correct facade +over the SDK rather than to re-test the SDK itself. +""" + +from __future__ import annotations + +import dataclasses +import sys +from pathlib import Path + +import pytest + +from pyrit.tools import ( + DockerMCPServerSpec, + LocalMCPServerSpec, + MCPClient, + RemoteMCPServerSpec, + ToolCall, +) + +ECHO_SERVER_SCRIPT = str(Path(__file__).parent / "echo_mcp_server.py") + + +def _local_spec(*, timeout_seconds: float = 5.0) -> LocalMCPServerSpec: + """Build a :class:`LocalMCPServerSpec` that spawns ``echo_mcp_server.py``.""" + return LocalMCPServerSpec( + command=sys.executable, + args=(ECHO_SERVER_SCRIPT,), + timeout_seconds=timeout_seconds, + ) + + +def _make_call(name: str, *, call_id: str = "c1", arguments: dict | None = None) -> ToolCall: + return ToolCall(call_id=call_id, name=name, arguments=arguments or {}) + + +async def test_real_subprocess_dispatch_returns_text_content() -> None: + """U10: dispatching a single tool call returns the echo server's text response.""" + client = MCPClient(spec=_local_spec()) + async with client: + envelope = await client.dispatch_async(_make_call("echo", arguments={"text": "hi"})) + assert envelope["is_error"] is False + assert envelope["content"] == "hi" + + +async def test_sequential_dispatch_against_real_server() -> None: + """U10: multiple sequential calls round-trip through the same session.""" + client = MCPClient(spec=_local_spec()) + async with client: + envelopes = [ + await client.dispatch_async(_make_call("echo", arguments={"text": "first"})), + await client.dispatch_async(_make_call("add", arguments={"a": 2, "b": 3})), + await client.dispatch_async(_make_call("reverse", arguments={"text": "abc"})), + ] + contents = [e["content"] for e in envelopes] + assert contents == ["first", "5", "cba"] + + +async def test_connect_async_populates_schemas_via_tools_list() -> None: + """U14: schemas are discovered via tools/list during connect_async.""" + client = MCPClient(spec=_local_spec()) + async with client: + schemas = client.schemas + names = {s["name"] for s in schemas} + assert names == {"echo", "add", "reverse", "slow_echo"} + echo_schema = next(s for s in schemas if s["name"] == "echo") + assert "parameters" in echo_schema + assert echo_schema["parameters"]["properties"]["text"]["type"] == "string" + + +async def test_dispatch_timeout_returns_error_envelope() -> None: + """U17: a tool call that exceeds the spec's timeout produces an error envelope.""" + client = MCPClient(spec=_local_spec(timeout_seconds=0.05)) + async with client: + envelope = await client.dispatch_async( + _make_call("slow_echo", arguments={"text": "late", "delay_ms": 500}), + ) + assert envelope["is_error"] is True + assert envelope["error"] == "tool_timeout" + assert envelope["tool"] == "slow_echo" + + +async def test_dispatch_async_returns_error_envelope_on_unknown_tool() -> None: + """Server-side errors (unknown tool name) surface as is_error envelopes.""" + client = MCPClient(spec=_local_spec()) + async with client: + envelope = await client.dispatch_async(_make_call("nonexistent_tool")) + assert envelope["is_error"] is True + assert envelope["tool"] == "nonexistent_tool" + + +def test_remote_mcp_server_spec_is_frozen_dataclass() -> None: + """U20: RemoteMCPServerSpec exists in the type system as a frozen dataclass.""" + spec = RemoteMCPServerSpec(url="https://example.com/mcp") + assert spec.url == "https://example.com/mcp" + with pytest.raises(dataclasses.FrozenInstanceError): + spec.url = "other" # type: ignore[misc] + + +async def test_remote_mcp_server_spec_raises_not_implemented() -> None: + """U20: connecting to a RemoteMCPServerSpec raises NotImplementedError.""" + client = MCPClient(spec=RemoteMCPServerSpec(url="https://example.com/mcp")) + with pytest.raises(NotImplementedError, match="follow-up PR"): + await client.connect_async() + + +def test_docker_mcp_server_spec_dataclass_fields() -> None: + """U20: DockerMCPServerSpec carries the fields the sandbox PR will consume.""" + spec = DockerMCPServerSpec(image="pyrit-sandbox:base") + assert spec.image == "pyrit-sandbox:base" + assert spec.network_profile == "none" + assert spec.name_prefix is None + assert spec.timeout_seconds == 30.0 + + +async def test_docker_mcp_server_spec_raises_not_implemented() -> None: + """U20: connecting to a DockerMCPServerSpec raises NotImplementedError.""" + client = MCPClient(spec=DockerMCPServerSpec(image="pyrit-sandbox:base")) + with pytest.raises(NotImplementedError, match="follow-up PR"): + await client.connect_async() + + +async def test_dispatch_before_connect_raises_runtime_error() -> None: + """Calling dispatch_async before connect_async is a programmer error.""" + client = MCPClient(spec=_local_spec()) + with pytest.raises(RuntimeError, match="not connected"): + await client.dispatch_async(_make_call("echo", arguments={"text": "hi"})) + + +async def test_close_async_is_idempotent() -> None: + """Calling close_async twice (or before connect) does not raise.""" + client = MCPClient(spec=_local_spec()) + await client.close_async() # before connect — no-op. + await client.connect_async() + await client.close_async() + await client.close_async() # double-close — no-op. + + +async def test_local_mcp_server_spec_is_frozen() -> None: + """LocalMCPServerSpec is a frozen dataclass.""" + spec = LocalMCPServerSpec(command="python", args=("a.py",)) + with pytest.raises(dataclasses.FrozenInstanceError): + spec.command = "other" # type: ignore[misc] diff --git a/tests/unit/tools/test_prompt_target_tool_loop.py b/tests/unit/tools/test_prompt_target_tool_loop.py new file mode 100644 index 0000000000..518304111f --- /dev/null +++ b/tests/unit/tools/test_prompt_target_tool_loop.py @@ -0,0 +1,276 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +Unit tests for ``@tool_loop`` wired into ``PromptTarget.send_prompt_async``. + +The base class decorates ``send_prompt_async`` with ``@final @tool_loop``, +exposes ``_tool_parser`` and ``_tool_schemas()`` as default no-op hooks, +and ``TargetConfiguration`` carries the ``tool_event_policy`` and +``tool_backend`` kwargs the decorator consults. + +These tests use the production ``_get_normalized_conversation_async`` path +(memory round-trip through :class:`SQLiteMemory` via ``patch_central_database``) +to exercise the wrapper end-to-end. They cover: + +- U1: decorator order (validate + normalize happen exactly once, then the loop) +- U2 (DB-end half): produced ``tool`` message has one ``function_call_output`` + piece per dispatched call, in declaration order +- U8: DB inserts user, asst_with_fc, tool, asst_final in that order +- U9: DB roles + data_types match the canonical envelope +- U11: targets without a policy short-circuit (no wrapper behavior change) + +Tests for capability flag wiring + ``TargetConfiguration`` construction +validation live in :mod:`tests.unit.tools.test_tool_event_policy`. +""" + +from __future__ import annotations + +import json +from collections import deque +from typing import TYPE_CHECKING, Any + +import pytest + +from pyrit.prompt_target.common.prompt_target import PromptTarget +from pyrit.prompt_target.common.target_capabilities import TargetCapabilities +from pyrit.prompt_target.common.target_configuration import TargetConfiguration +from pyrit.tools import ToolCallParser, ToolEventBehavior, ToolEventPolicy + +from .conftest import ( + _CanonicalEnvelopeParser, + _make_assistant_function_call_message, + _make_assistant_text_message, + _make_user_message, + _RecordingToolBackend, +) + +if TYPE_CHECKING: + from pyrit.models import Message + + +class _ProductionShapedTarget(PromptTarget): + """ + Minimal :class:`PromptTarget` that uses the *real* base-class + ``_get_normalized_conversation_async`` (memory round-trip + normalization + pipeline) instead of the conftest stub. Drives the production wrapper + end-to-end so DB-insert-order assertions can run against the real + :class:`CentralMemory` instance set up by ``patch_central_database``. + """ + + def __init__( + self, + *, + scripted_responses: list[Message], + policy: ToolEventPolicy | None, + backend: Any, + parser: ToolCallParser | None, + ) -> None: + caps = TargetCapabilities( + supports_multi_turn=True, + supports_multi_message_pieces=True, + supports_tool_use=policy is not None, + ) + config = TargetConfiguration( + capabilities=caps, + tool_event_policy=policy, + tool_backend=backend, + ) + super().__init__(custom_configuration=config) + self._scripted: deque[Message] = deque(scripted_responses) + self.call_count: int = 0 + self._parser_instance = parser + + @property + def _tool_parser(self) -> ToolCallParser | None: + return self._parser_instance + + async def _send_prompt_to_target_async(self, *, normalized_conversation: list[Message]) -> list[Message]: + self.call_count += 1 + if not self._scripted: + raise AssertionError(f"Target ran out of scripted responses on iteration {self.call_count}.") + response = self._scripted.popleft() + conversation_id = normalized_conversation[-1].message_pieces[0].conversation_id + for piece in response.message_pieces: + piece.conversation_id = conversation_id + return [response] + + +@pytest.fixture +def make_production_target(patch_central_database): + def _factory( + *, + scripted_responses: list[Message], + policy: ToolEventPolicy | None = None, + backend: Any = None, + parser: ToolCallParser | None = None, + ) -> _ProductionShapedTarget: + effective_parser = parser + if effective_parser is None and policy is not None: + effective_parser = _CanonicalEnvelopeParser() + return _ProductionShapedTarget( + scripted_responses=scripted_responses, + policy=policy, + backend=backend, + parser=effective_parser, + ) + + return _factory + + +@pytest.fixture +def execute_policy_fixture(): + return ToolEventPolicy(behavior=ToolEventBehavior.EXECUTE, max_tool_iterations=5) + + +class TestToolLoopWiredIntoBaseClass: + """Verifies ``@tool_loop`` runs on every ``send_prompt_async`` call.""" + + async def test_decorator_passthrough_when_no_policy(self, make_production_target): + """U11 -- target without a policy behaves like a single-pass ``send_prompt_async``.""" + target = make_production_target( + scripted_responses=[_make_assistant_text_message("plain")], + policy=None, + ) + + responses = await target.send_prompt_async(message=_make_user_message("hi")) + + assert target.call_count == 1 + assert len(responses) == 1 + assert responses[0].message_pieces[0].original_value == "plain" + + async def test_tool_loop_order_after_normalize_before_memory(self, make_production_target, execute_policy_fixture): + """U1 -- validate + normalize happen exactly once before the loop iterates.""" + backend = _RecordingToolBackend(scripted_results=[{"result": "echoed"}]) + target = make_production_target( + scripted_responses=[ + _make_assistant_function_call_message(calls=[("c1", "echo", {"text": "x"})]), + _make_assistant_text_message("done"), + ], + policy=execute_policy_fixture, + backend=backend, + ) + + responses = await target.send_prompt_async(message=_make_user_message("please echo")) + + assert target.call_count == 2 + assert [c.name for c in backend.recorded_calls] == ["echo"] + assert len(responses) == 3 + assert responses[0].message_pieces[0].original_value_data_type == "function_call" + assert responses[1].message_pieces[0].original_value_data_type == "function_call_output" + assert responses[2].message_pieces[0].original_value_data_type == "text" + + async def test_tool_message_has_one_function_call_output_piece_per_call( + self, make_production_target, execute_policy_fixture + ): + """U2 DB-end half -- one tool Message, N pieces, one per dispatched call.""" + backend = _RecordingToolBackend(scripted_results=[{"r": 1}, {"r": 2}]) + target = make_production_target( + scripted_responses=[ + _make_assistant_function_call_message( + calls=[("c1", "echo", {"text": "a"}), ("c2", "echo", {"text": "b"})] + ), + _make_assistant_text_message("done"), + ], + policy=execute_policy_fixture, + backend=backend, + ) + + responses = await target.send_prompt_async(message=_make_user_message("two calls please")) + + tool_msg = responses[1] + assert len(tool_msg.message_pieces) == 2 + call_ids_in_order = [json.loads(p.original_value)["call_id"] for p in tool_msg.message_pieces] + assert call_ids_in_order == ["c1", "c2"] + assert all(p.original_value_data_type == "function_call_output" for p in tool_msg.message_pieces) + assert all(p.api_role == "tool" for p in tool_msg.message_pieces) + + +class TestDbTranscriptAfterToolLoop: + """ + DB-level assertions that exercise the production memory pipeline. + + These tests rely on the wrapper writing the user message + every assistant + + tool message produced during the loop back to ``CentralMemory``, in + declaration order. Whether that write happens *inside* the wrapper or via + the caller (the prompt normalizer) is an implementation detail; the + invariant is the wrapper returns the full chain so the caller can persist + in order. + """ + + async def test_db_insert_order_user_then_asst_fc_then_tool_then_final_asst( + self, make_production_target, execute_policy_fixture + ): + """U8 -- after a complete tool round, the wrapper's return order is canonical.""" + backend = _RecordingToolBackend(scripted_results=[{"result": "echoed"}]) + target = make_production_target( + scripted_responses=[ + _make_assistant_function_call_message(calls=[("c1", "echo", {"text": "x"})]), + _make_assistant_text_message("done"), + ], + policy=execute_policy_fixture, + backend=backend, + ) + + responses = await target.send_prompt_async(message=_make_user_message("please echo")) + + data_types_in_order = [r.message_pieces[0].original_value_data_type for r in responses] + assert data_types_in_order == ["function_call", "function_call_output", "text"] + + async def test_db_roles_and_data_types_match_canonical_envelope( + self, make_production_target, execute_policy_fixture + ): + """U9 -- roles and data_types match the canonical envelope contract.""" + backend = _RecordingToolBackend(scripted_results=[{"result": "echoed"}]) + target = make_production_target( + scripted_responses=[ + _make_assistant_function_call_message(calls=[("c1", "echo", {"text": "x"})]), + _make_assistant_text_message("done"), + ], + policy=execute_policy_fixture, + backend=backend, + ) + + responses = await target.send_prompt_async(message=_make_user_message("please echo")) + + asst_fc, tool_msg, asst_final = responses + # function_call from the assistant + assert asst_fc.message_pieces[0].api_role == "assistant" + assert asst_fc.message_pieces[0].original_value_data_type == "function_call" + envelope = json.loads(asst_fc.message_pieces[0].original_value) + assert envelope["type"] == "function_call" + assert envelope["call_id"] == "c1" + assert envelope["name"] == "echo" + # function_call_output from the tool + assert tool_msg.message_pieces[0].api_role == "tool" + assert tool_msg.message_pieces[0].original_value_data_type == "function_call_output" + tool_envelope = json.loads(tool_msg.message_pieces[0].original_value) + assert tool_envelope["type"] == "function_call_output" + assert tool_envelope["call_id"] == "c1" + # Final assistant text + assert asst_final.message_pieces[0].api_role == "assistant" + assert asst_final.message_pieces[0].original_value_data_type == "text" + + +class TestFinalAndAbstractMethodContract: + """ + Asserts the base-class shape that ``@tool_loop`` requires but does not + exercise via end-to-end runs: ``_tool_parser`` defaults to ``None``, + ``_tool_schemas`` defaults to ``[]``. + """ + + def test_default_tool_parser_is_none(self, make_production_target): + target = make_production_target( + scripted_responses=[_make_assistant_text_message("plain")], + policy=None, + ) + # Subclass overrides only when the test caller passes a parser. With + # no policy + no parser, the override returns None. + assert target._tool_parser is None + + def test_default_tool_schemas_is_empty_list(self, make_production_target): + target = make_production_target( + scripted_responses=[_make_assistant_text_message("plain")], + policy=None, + ) + assert target._tool_schemas() == [] diff --git a/tests/unit/tools/test_tool_event_policy.py b/tests/unit/tools/test_tool_event_policy.py new file mode 100644 index 0000000000..20cf283e07 --- /dev/null +++ b/tests/unit/tools/test_tool_event_policy.py @@ -0,0 +1,119 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +Unit tests for the wiring between :class:`TargetCapabilities.supports_tool_use`, +:class:`TargetConfiguration.tool_event_policy` / +:class:`TargetConfiguration.tool_backend`, and the +:func:`pyrit.tools.tool_loop` decorator that lives on +:class:`PromptTarget.send_prompt_async`. + +These tests are the §7 U7 row plus the construction-time validator. +They assert the *capability flag* axis only -- that targets which declare +``supports_tool_use=True`` and configure a policy + backend route through +the loop, that targets without a policy short-circuit, and that the +``tool_backend``-without-capability misconfiguration raises at construction. + +End-to-end ordering against the production memory pipeline (U1, U8, U9) is +exercised separately in ``tests/unit/prompt_target/common/test_prompt_target_tool_loop.py``. +""" + +from __future__ import annotations + +import pytest + +from pyrit.prompt_target.common.target_capabilities import TargetCapabilities +from pyrit.prompt_target.common.target_configuration import TargetConfiguration +from pyrit.tools import LocalToolBackend, ToolEventBehavior, ToolEventPolicy + +from .conftest import ( + _make_assistant_function_call_message, + _make_assistant_text_message, + _make_user_message, +) + + +class TestSupportsToolUseCapabilityFlag: + """Asserts the new ``supports_tool_use`` field on :class:`TargetCapabilities`.""" + + def test_default_is_false(self): + caps = TargetCapabilities() + assert caps.supports_tool_use is False + + def test_explicit_true(self): + caps = TargetCapabilities(supports_tool_use=True) + assert caps.supports_tool_use is True + + +class TestTargetConfigurationToolFields: + """Asserts the new ``tool_event_policy`` / ``tool_backend`` kwargs.""" + + def test_defaults_are_none(self): + caps = TargetCapabilities(supports_tool_use=True) + config = TargetConfiguration(capabilities=caps) + assert config.tool_event_policy is None + assert config.tool_backend is None + + def test_explicit_policy_and_backend(self): + caps = TargetCapabilities(supports_tool_use=True) + backend = LocalToolBackend(callables={}, schemas=[]) + policy = ToolEventPolicy(behavior=ToolEventBehavior.EXECUTE) + config = TargetConfiguration( + capabilities=caps, + tool_event_policy=policy, + tool_backend=backend, + ) + assert config.tool_event_policy is policy + assert config.tool_backend is backend + + def test_tool_backend_without_capability_raises(self): + caps = TargetCapabilities(supports_tool_use=False) + backend = LocalToolBackend(callables={}, schemas=[]) + with pytest.raises(ValueError, match="supports_tool_use"): + TargetConfiguration(capabilities=caps, tool_backend=backend) + + def test_tool_event_policy_without_backend_is_allowed(self): + """``RAISE`` / ``RETURN_RAW`` policies do not require a backend.""" + caps = TargetCapabilities(supports_tool_use=True) + policy = ToolEventPolicy(behavior=ToolEventBehavior.RAISE) + config = TargetConfiguration(capabilities=caps, tool_event_policy=policy) + assert config.tool_event_policy is policy + assert config.tool_backend is None + + +class TestCapabilityFlagWiringIntoToolLoop: + """ + U7 -- verify the wrapper dispatches only when the target declares + ``supports_tool_use`` AND a policy is configured. + """ + + async def test_target_with_tool_use_capability_uses_tool_loop( + self, make_fake_target, recording_backend, execute_policy + ): + backend = recording_backend() + target = make_fake_target( + scripted_responses=[ + _make_assistant_function_call_message(calls=[("c1", "echo", {"text": "hi"})]), + _make_assistant_text_message("done"), + ], + policy=execute_policy(), + backend=backend, + ) + + responses = await target.send_prompt_async(message=_make_user_message("please call echo")) + + assert target.call_count == 2, "Decorator should have iterated twice (call + final)." + assert [c.name for c in backend.recorded_calls] == ["echo"] + assert len(responses) == 3, "user expects asst_fc, tool_msg, asst_final." + + async def test_target_without_tool_use_capability_skips_dispatch(self, make_fake_target): + target = make_fake_target( + scripted_responses=[_make_assistant_text_message("plain response, no tool call")], + policy=None, + backend=None, + ) + + responses = await target.send_prompt_async(message=_make_user_message("hello")) + + assert target.call_count == 1 + assert len(responses) == 1 diff --git a/tests/unit/tools/test_tool_loop_decorator.py b/tests/unit/tools/test_tool_loop_decorator.py new file mode 100644 index 0000000000..89805fbb1e --- /dev/null +++ b/tests/unit/tools/test_tool_loop_decorator.py @@ -0,0 +1,289 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +Unit tests for :func:`pyrit.tools.tool_loop`. + +Coverage map: + +* **U2** (partial; full-DB end) — ``test_loop_returns_full_chain_in_order`` +* **U3** — ``test_loop_exits_on_first_response_when_no_tool_calls``, + ``test_loops_until_no_pending_tool_call`` +* **U4** — ``test_raises_after_max_tool_iterations``, + ``test_partial_conversation_attached_to_limit_exception`` +* **U12** — ``test_policy_raise_includes_partial_conversation`` +* **U13** — ``test_policy_return_raw_does_not_dispatch`` +* **U16** — ``test_multi_call_per_turn_dispatched_sequentially_in_order`` + +Also covers two additional decorator concerns required by the rubber-duck +review (§10): EXECUTE policy with no backend raises with a partial +conversation, and the normalized conversation grows correctly across +iterations (decorator does not re-normalize each turn). +""" + +from __future__ import annotations + +import json + +import pytest + +from pyrit.exceptions import ToolCallLoopLimitExceeded, ToolCallNotSupported +from pyrit.tools import ToolEventBehavior, ToolEventPolicy + +from .conftest import ( + _make_assistant_function_call_message, + _make_assistant_text_message, + _make_user_message, +) + + +@pytest.mark.usefixtures("patch_central_database") +class TestToolLoopDecoratorBasics: + """Loop entry/exit semantics: no tool calls, single round trip, multi-round.""" + + async def test_loop_exits_on_first_response_when_no_tool_calls(self, make_fake_target, execute_policy): + target = make_fake_target( + scripted_responses=[_make_assistant_text_message("done")], + policy=execute_policy(), + ) + + responses = await target.send_prompt_async(message=_make_user_message("hi")) + + assert len(responses) == 1 + assert responses[0].get_value() == "done" + assert target.call_count == 1 + + async def test_loops_until_no_pending_tool_call(self, make_fake_target, execute_policy, recording_backend): + backend = recording_backend(scripted_results=[{"ok": True}, {"ok": True}]) + target = make_fake_target( + scripted_responses=[ + _make_assistant_function_call_message(calls=[("c1", "tool_a", {"x": 1})]), + _make_assistant_function_call_message(calls=[("c2", "tool_a", {"x": 2})]), + _make_assistant_text_message("done"), + ], + policy=execute_policy(max_tool_iterations=5), + backend=backend, + ) + + responses = await target.send_prompt_async(message=_make_user_message("hi")) + + # Two model-tool round trips and one final assistant message. + assert target.call_count == 3 + # Returned chain: fc1, tool1, fc2, tool2, final-text → 5 messages total. + assert len(responses) == 5 + assert [r.message_pieces[0].original_value_data_type for r in responses] == [ + "function_call", + "function_call_output", + "function_call", + "function_call_output", + "text", + ] + assert len(backend.recorded_calls) == 2 + assert [c.call_id for c in backend.recorded_calls] == ["c1", "c2"] + + +@pytest.mark.usefixtures("patch_central_database") +class TestToolLoopMessageShape: + """U2 — assistant_fc → tool → final_assistant ordering and identity.""" + + async def test_loop_returns_full_chain_in_order(self, make_fake_target, execute_policy, recording_backend): + backend = recording_backend(scripted_results=[{"weather": "sunny"}]) + fc_msg = _make_assistant_function_call_message(calls=[("call_abc", "get_weather", {"city": "Seattle"})]) + final_msg = _make_assistant_text_message("It is sunny in Seattle.") + + target = make_fake_target( + scripted_responses=[fc_msg, final_msg], + policy=execute_policy(), + backend=backend, + ) + + responses = await target.send_prompt_async(message=_make_user_message("weather?")) + + assert len(responses) == 3 + # 1) assistant with function_call (identity preserved) + assert responses[0] is fc_msg + # 2) tool message with exactly one function_call_output piece carrying call_id + tool_msg = responses[1] + assert len(tool_msg.message_pieces) == 1 + tool_piece = tool_msg.message_pieces[0] + assert tool_piece.api_role == "tool" + assert tool_piece.original_value_data_type == "function_call_output" + envelope = json.loads(tool_piece.original_value) + assert envelope["type"] == "function_call_output" + assert envelope["call_id"] == "call_abc" + # The tool result is JSON-serialized into the "output" field. + assert json.loads(envelope["output"]) == {"weather": "sunny"} + # 3) final assistant text (identity preserved) + assert responses[2] is final_msg + + +@pytest.mark.usefixtures("patch_central_database") +class TestToolLoopIterationLimits: + """U4 — iteration cap raises and carries the partial chain.""" + + async def test_raises_after_max_tool_iterations(self, make_fake_target, execute_policy, recording_backend): + # Model never stops asking for tools. + backend = recording_backend(scripted_results=[{"ok": True}] * 3) + target = make_fake_target( + scripted_responses=[ + _make_assistant_function_call_message(calls=[(f"c{i}", "loop_tool", {})]) for i in range(3) + ], + policy=execute_policy(max_tool_iterations=2), + backend=backend, + ) + + with pytest.raises(ToolCallLoopLimitExceeded, match="max_tool_iterations=2"): + await target.send_prompt_async(message=_make_user_message("hi")) + + # Exactly max_tool_iterations model calls made before raising. + assert target.call_count == 2 + + async def test_partial_conversation_attached_to_limit_exception( + self, make_fake_target, execute_policy, recording_backend + ): + backend = recording_backend(scripted_results=[{"ok": True}] * 2) + target = make_fake_target( + scripted_responses=[ + _make_assistant_function_call_message(calls=[(f"c{i}", "loop_tool", {})]) for i in range(2) + ], + policy=execute_policy(max_tool_iterations=2), + backend=backend, + ) + + with pytest.raises(ToolCallLoopLimitExceeded) as excinfo: + await target.send_prompt_async(message=_make_user_message("hi")) + + partial = excinfo.value.partial_conversation + # 2 iterations × (assistant_fc + tool_msg) = 4 messages, all in order. + assert len(partial) == 4 + assert [m.message_pieces[0].original_value_data_type for m in partial] == [ + "function_call", + "function_call_output", + "function_call", + "function_call_output", + ] + + +@pytest.mark.usefixtures("patch_central_database") +class TestToolEventPolicyBehaviors: + """U12, U13 — non-EXECUTE behaviors short-circuit dispatch.""" + + async def test_policy_raise_includes_partial_conversation(self, make_fake_target, recording_backend): + backend = recording_backend(scripted_results=[{"ok": True}]) + fc_msg = _make_assistant_function_call_message(calls=[("c1", "danger", {})]) + target = make_fake_target( + scripted_responses=[fc_msg], + policy=ToolEventPolicy(behavior=ToolEventBehavior.RAISE), + backend=backend, + ) + + with pytest.raises(ToolCallNotSupported, match="RAISE") as excinfo: + await target.send_prompt_async(message=_make_user_message("hi")) + + partial = excinfo.value.partial_conversation + # Partial contains the offending assistant turn; no tool dispatch occurred. + assert partial == [fc_msg] + assert backend.recorded_calls == [] + assert target.call_count == 1 + + async def test_policy_return_raw_does_not_dispatch(self, make_fake_target, recording_backend): + backend = recording_backend(scripted_results=[{"ok": True}]) + fc_msg = _make_assistant_function_call_message(calls=[("c1", "danger", {})]) + target = make_fake_target( + scripted_responses=[fc_msg], + policy=ToolEventPolicy(behavior=ToolEventBehavior.RETURN_RAW), + backend=backend, + ) + + responses = await target.send_prompt_async(message=_make_user_message("hi")) + + assert responses == [fc_msg] + assert backend.recorded_calls == [] + assert target.call_count == 1 + + +@pytest.mark.usefixtures("patch_central_database") +class TestToolLoopMultiCallPerTurn: + """U16 — multi-call turns dispatch sequentially in declaration order.""" + + async def test_multi_call_per_turn_dispatched_sequentially_in_order( + self, make_fake_target, execute_policy, recording_backend + ): + backend = recording_backend(scripted_results=[{"a": 1}, {"b": 2}, {"c": 3}]) + multi_fc = _make_assistant_function_call_message( + calls=[ + ("c_alpha", "tool_alpha", {"k": "v1"}), + ("c_beta", "tool_beta", {"k": "v2"}), + ("c_gamma", "tool_gamma", {"k": "v3"}), + ] + ) + target = make_fake_target( + scripted_responses=[multi_fc, _make_assistant_text_message("ok")], + policy=execute_policy(), + backend=backend, + ) + + responses = await target.send_prompt_async(message=_make_user_message("multi")) + + # Three calls dispatched in declaration order, recorded ids match. + assert [c.call_id for c in backend.recorded_calls] == ["c_alpha", "c_beta", "c_gamma"] + assert [c.name for c in backend.recorded_calls] == ["tool_alpha", "tool_beta", "tool_gamma"] + # One tool message after the multi-call assistant turn, carrying three + # function_call_output pieces in declaration order with the right call_ids. + tool_msg = responses[1] + assert len(tool_msg.message_pieces) == 3 + envelopes = [json.loads(p.original_value) for p in tool_msg.message_pieces] + assert [e["call_id"] for e in envelopes] == ["c_alpha", "c_beta", "c_gamma"] + assert all(p.original_value_data_type == "function_call_output" for p in tool_msg.message_pieces) + + +@pytest.mark.usefixtures("patch_central_database") +class TestToolLoopMisconfiguration: + """EXECUTE policy with no backend must fail loudly and carry the partial chain.""" + + async def test_execute_without_backend_raises_with_partial(self, make_fake_target, execute_policy): + fc_msg = _make_assistant_function_call_message(calls=[("c1", "no_reg", {})]) + target = make_fake_target( + scripted_responses=[fc_msg], + policy=execute_policy(), + backend=None, + ) + + with pytest.raises(ToolCallNotSupported, match="tool_backend") as excinfo: + await target.send_prompt_async(message=_make_user_message("hi")) + + assert excinfo.value.partial_conversation == [fc_msg] + + +@pytest.mark.usefixtures("patch_central_database") +class TestToolLoopConversationGrowth: + """The decorator must extend (not re-normalize) the conversation each round.""" + + async def test_normalized_conversation_grows_each_iteration( + self, make_fake_target, execute_policy, recording_backend + ): + backend = recording_backend(scripted_results=[{"r1": 1}, {"r2": 2}]) + target = make_fake_target( + scripted_responses=[ + _make_assistant_function_call_message(calls=[("c1", "t", {})]), + _make_assistant_function_call_message(calls=[("c2", "t", {})]), + _make_assistant_text_message("done"), + ], + policy=execute_policy(), + backend=backend, + ) + + await target.send_prompt_async(message=_make_user_message("hi")) + + # Three protected-method calls; each subsequent call sees the prior + # assistant_fc + tool_msg appended (the decorator must NOT re-normalize). + seen = target.normalized_conversations_seen + assert len(seen) == 3 + # call 1: just the user message + assert len(seen[0]) == 1 + # call 2: user + assistant_fc(c1) + tool_msg + assert len(seen[1]) == 3 + assert seen[1][1].message_pieces[0].original_value_data_type == "function_call" + assert seen[1][2].message_pieces[0].original_value_data_type == "function_call_output" + # call 3: user + assistant_fc(c1) + tool_msg + assistant_fc(c2) + tool_msg + assert len(seen[2]) == 5 diff --git a/uv.lock b/uv.lock index 18a65c10f0..2adcb96aab 100644 --- a/uv.lock +++ b/uv.lock @@ -2122,6 +2122,15 @@ http2 = [ { name = "h2" }, ] +[[package]] +name = "httpx-sse" +version = "0.4.3" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/0f/4c/751061ffa58615a32c31b2d82e8482be8dd4a89154f003147acee90f2be9/httpx_sse-0.4.3.tar.gz", hash = "sha256:9b1ed0127459a66014aec3c56bebd93da3c1bc8bb6618c8082039a44889a755d", size = 15943, upload-time = "2025-10-10T21:48:22.271Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d2/fd/6668e5aec43ab844de6fc74927e155a3b37bf40d7c3790e49fc0406b6578/httpx_sse-0.4.3-py3-none-any.whl", hash = "sha256:0ac1c9fe3c0afad2e0ebb25a934a59f4c7823b60792691f779fad2c5568830fc", size = 8960, upload-time = "2025-10-10T21:48:21.158Z" }, +] + [[package]] name = "huggingface-hub" version = "1.13.0" @@ -3193,6 +3202,31 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/af/33/ee4519fa02ed11a94aef9559552f3b17bb863f2ecfe1a35dc7f548cde231/matplotlib_inline-0.2.1-py3-none-any.whl", hash = "sha256:d56ce5156ba6085e00a9d54fead6ed29a9c47e215cd1bba2e976ef39f5710a76", size = 9516, upload-time = "2025-10-23T09:00:20.675Z" }, ] +[[package]] +name = "mcp" +version = "1.27.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "anyio" }, + { name = "httpx" }, + { name = "httpx-sse" }, + { name = "jsonschema" }, + { name = "pydantic" }, + { name = "pydantic-settings" }, + { name = "pyjwt", extra = ["crypto"] }, + { name = "python-multipart" }, + { name = "pywin32", marker = "sys_platform == 'win32'" }, + { name = "sse-starlette" }, + { name = "starlette" }, + { name = "typing-extensions" }, + { name = "typing-inspection" }, + { name = "uvicorn", marker = "sys_platform != 'emscripten'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/38/83/d1efe7c2980d8a3afa476f4e3d42d53dd54c0ab94c27bee5d755b45c8b73/mcp-1.27.1.tar.gz", hash = "sha256:0f47e1820f8f8f941466b39749eb1d1839a04caddca2bc60e9d46e8a99914924", size = 608458, upload-time = "2026-05-08T16:50:12.601Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/fd/73/42d9596facebdb533b7f0b86c1b0364ef350d1f8ba78b1052e8a58b48b65/mcp-1.27.1-py3-none-any.whl", hash = "sha256:1af3c4203b329430fde7a87b4fcb6392a041f5cb851fd68fc674016ab4e7c06f", size = 216260, upload-time = "2026-05-08T16:50:10.547Z" }, +] + [[package]] name = "mdit-py-plugins" version = "0.5.0" @@ -5015,6 +5049,20 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/36/c7/cfc8e811f061c841d7990b0201912c3556bfeb99cdcb7ed24adc8d6f8704/pydantic_core-2.41.5-pp311-pypy311_pp73-win_amd64.whl", hash = "sha256:56121965f7a4dc965bff783d70b907ddf3d57f6eba29b6d2e5dabfaf07799c51", size = 2145302, upload-time = "2025-11-04T13:43:46.64Z" }, ] +[[package]] +name = "pydantic-settings" +version = "2.14.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pydantic" }, + { name = "python-dotenv" }, + { name = "typing-inspection" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/07/60/1d1e59c9c90d54591469ada7d268251f71c24bdb765f1a8a832cee8c6653/pydantic_settings-2.14.1.tar.gz", hash = "sha256:e874d3bec7e787b0c9958277956ed9b4dd5de6a80e162188fdaff7c5e26fd5fa", size = 235551, upload-time = "2026-05-08T13:40:06.542Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ae/8d/f1af3832f5e6eb13ba94ee809e72b8ecb5eef226d27ee0bef7d963d943c7/pydantic_settings-2.14.1-py3-none-any.whl", hash = "sha256:6e3c7edfd8277687cdc598f56e5cff0e9bfff0910a3749deaa8d4401c3a2b9de", size = 60964, upload-time = "2026-05-08T13:40:04.958Z" }, +] + [[package]] name = "pydash" version = "8.0.5" @@ -5171,6 +5219,7 @@ dependencies = [ { name = "fastapi" }, { name = "httpx", extra = ["http2"] }, { name = "jinja2" }, + { name = "mcp" }, { name = "numpy", version = "2.2.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.14'" }, { name = "numpy", version = "2.4.4", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.14'" }, { name = "openai" }, @@ -5308,6 +5357,7 @@ requires-dist = [ { name = "ipykernel", marker = "extra == 'all'", specifier = ">=6.29.5" }, { name = "jinja2", specifier = ">=3.1.6" }, { name = "jupyter", marker = "extra == 'all'", specifier = ">=1.1.1" }, + { name = "mcp", specifier = ">=1.0,<2" }, { name = "ml-collections", marker = "extra == 'all'", specifier = ">=1.1.0" }, { name = "ml-collections", marker = "extra == 'gcg'", specifier = ">=1.1.0" }, { name = "numpy", marker = "python_full_version < '3.14'", specifier = ">=1.26.0" }, @@ -5501,6 +5551,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/51/e5/fecf13f06e5e5f67e8837d777d1bc43fac0ed2b77a676804df5c34744727/python_json_logger-4.0.0-py3-none-any.whl", hash = "sha256:af09c9daf6a813aa4cc7180395f50f2a9e5fa056034c9953aec92e381c5ba1e2", size = 15548, upload-time = "2025-10-06T04:15:17.553Z" }, ] +[[package]] +name = "python-multipart" +version = "0.0.29" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/4e/fe/70bd71a6738b09a0bdf6480ca6436b167469ca4578b2a0efbe390b4b0e70/python_multipart-0.0.29.tar.gz", hash = "sha256:643e93849196645e2dbdd81a0f8829a23123ad7f797a84a364c6fb3563f18904", size = 45678, upload-time = "2026-05-17T17:29:47.654Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/8f/cb/769cfc37177252872a45a71f3fbdde9d51b471a3f3c14bfe95dde3407386/python_multipart-0.0.29-py3-none-any.whl", hash = "sha256:2ddcc971cef266225f54f552d8fa10bcfbb1f14446caec199060daac59ff2d69", size = 29640, upload-time = "2026-05-17T17:29:45.69Z" }, +] + [[package]] name = "pytz" version = "2025.2" @@ -5510,6 +5569,28 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/81/c4/34e93fe5f5429d7570ec1fa436f1986fb1f00c3e0f43a589fe2bbcd22c3f/pytz-2025.2-py2.py3-none-any.whl", hash = "sha256:5ddf76296dd8c44c26eb8f4b6f35488f3ccbf6fbbd7adee0b7262d43f0ec2f00", size = 509225, upload-time = "2025-03-25T02:24:58.468Z" }, ] +[[package]] +name = "pywin32" +version = "311" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/7b/40/44efbb0dfbd33aca6a6483191dae0716070ed99e2ecb0c53683f400a0b4f/pywin32-311-cp310-cp310-win32.whl", hash = "sha256:d03ff496d2a0cd4a5893504789d4a15399133fe82517455e78bad62efbb7f0a3", size = 8760432, upload-time = "2025-07-14T20:13:05.9Z" }, + { url = "https://files.pythonhosted.org/packages/5e/bf/360243b1e953bd254a82f12653974be395ba880e7ec23e3731d9f73921cc/pywin32-311-cp310-cp310-win_amd64.whl", hash = "sha256:797c2772017851984b97180b0bebe4b620bb86328e8a884bb626156295a63b3b", size = 9590103, upload-time = "2025-07-14T20:13:07.698Z" }, + { url = "https://files.pythonhosted.org/packages/57/38/d290720e6f138086fb3d5ffe0b6caa019a791dd57866940c82e4eeaf2012/pywin32-311-cp310-cp310-win_arm64.whl", hash = "sha256:0502d1facf1fed4839a9a51ccbcc63d952cf318f78ffc00a7e78528ac27d7a2b", size = 8778557, upload-time = "2025-07-14T20:13:11.11Z" }, + { url = "https://files.pythonhosted.org/packages/7c/af/449a6a91e5d6db51420875c54f6aff7c97a86a3b13a0b4f1a5c13b988de3/pywin32-311-cp311-cp311-win32.whl", hash = "sha256:184eb5e436dea364dcd3d2316d577d625c0351bf237c4e9a5fabbcfa5a58b151", size = 8697031, upload-time = "2025-07-14T20:13:13.266Z" }, + { url = "https://files.pythonhosted.org/packages/51/8f/9bb81dd5bb77d22243d33c8397f09377056d5c687aa6d4042bea7fbf8364/pywin32-311-cp311-cp311-win_amd64.whl", hash = "sha256:3ce80b34b22b17ccbd937a6e78e7225d80c52f5ab9940fe0506a1a16f3dab503", size = 9508308, upload-time = "2025-07-14T20:13:15.147Z" }, + { url = "https://files.pythonhosted.org/packages/44/7b/9c2ab54f74a138c491aba1b1cd0795ba61f144c711daea84a88b63dc0f6c/pywin32-311-cp311-cp311-win_arm64.whl", hash = "sha256:a733f1388e1a842abb67ffa8e7aad0e70ac519e09b0f6a784e65a136ec7cefd2", size = 8703930, upload-time = "2025-07-14T20:13:16.945Z" }, + { url = "https://files.pythonhosted.org/packages/e7/ab/01ea1943d4eba0f850c3c61e78e8dd59757ff815ff3ccd0a84de5f541f42/pywin32-311-cp312-cp312-win32.whl", hash = "sha256:750ec6e621af2b948540032557b10a2d43b0cee2ae9758c54154d711cc852d31", size = 8706543, upload-time = "2025-07-14T20:13:20.765Z" }, + { url = "https://files.pythonhosted.org/packages/d1/a8/a0e8d07d4d051ec7502cd58b291ec98dcc0c3fff027caad0470b72cfcc2f/pywin32-311-cp312-cp312-win_amd64.whl", hash = "sha256:b8c095edad5c211ff31c05223658e71bf7116daa0ecf3ad85f3201ea3190d067", size = 9495040, upload-time = "2025-07-14T20:13:22.543Z" }, + { url = "https://files.pythonhosted.org/packages/ba/3a/2ae996277b4b50f17d61f0603efd8253cb2d79cc7ae159468007b586396d/pywin32-311-cp312-cp312-win_arm64.whl", hash = "sha256:e286f46a9a39c4a18b319c28f59b61de793654af2f395c102b4f819e584b5852", size = 8710102, upload-time = "2025-07-14T20:13:24.682Z" }, + { url = "https://files.pythonhosted.org/packages/a5/be/3fd5de0979fcb3994bfee0d65ed8ca9506a8a1260651b86174f6a86f52b3/pywin32-311-cp313-cp313-win32.whl", hash = "sha256:f95ba5a847cba10dd8c4d8fefa9f2a6cf283b8b88ed6178fa8a6c1ab16054d0d", size = 8705700, upload-time = "2025-07-14T20:13:26.471Z" }, + { url = "https://files.pythonhosted.org/packages/e3/28/e0a1909523c6890208295a29e05c2adb2126364e289826c0a8bc7297bd5c/pywin32-311-cp313-cp313-win_amd64.whl", hash = "sha256:718a38f7e5b058e76aee1c56ddd06908116d35147e133427e59a3983f703a20d", size = 9494700, upload-time = "2025-07-14T20:13:28.243Z" }, + { url = "https://files.pythonhosted.org/packages/04/bf/90339ac0f55726dce7d794e6d79a18a91265bdf3aa70b6b9ca52f35e022a/pywin32-311-cp313-cp313-win_arm64.whl", hash = "sha256:7b4075d959648406202d92a2310cb990fea19b535c7f4a78d3f5e10b926eeb8a", size = 8709318, upload-time = "2025-07-14T20:13:30.348Z" }, + { url = "https://files.pythonhosted.org/packages/c9/31/097f2e132c4f16d99a22bfb777e0fd88bd8e1c634304e102f313af69ace5/pywin32-311-cp314-cp314-win32.whl", hash = "sha256:b7a2c10b93f8986666d0c803ee19b5990885872a7de910fc460f9b0c2fbf92ee", size = 8840714, upload-time = "2025-07-14T20:13:32.449Z" }, + { url = "https://files.pythonhosted.org/packages/90/4b/07c77d8ba0e01349358082713400435347df8426208171ce297da32c313d/pywin32-311-cp314-cp314-win_amd64.whl", hash = "sha256:3aca44c046bd2ed8c90de9cb8427f581c479e594e99b5c0bb19b29c10fd6cb87", size = 9656800, upload-time = "2025-07-14T20:13:34.312Z" }, + { url = "https://files.pythonhosted.org/packages/c0/d2/21af5c535501a7233e734b8af901574572da66fcc254cb35d0609c9080dd/pywin32-311-cp314-cp314-win_arm64.whl", hash = "sha256:a508e2d9025764a8270f93111a970e1d0fbfc33f4153b388bb649b7eec4f9b42", size = 8932540, upload-time = "2025-07-14T20:13:36.379Z" }, +] + [[package]] name = "pywinpty" version = "3.0.2" @@ -6564,6 +6645,19 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/2a/ae/57d1d7af907e20c077e113e0e4976f87b82c0a415403d99284a262229dd0/srsly-2.5.3-cp314-cp314t-win_arm64.whl", hash = "sha256:d822083fe26ec6728bd8c273ac121fc4ab3864a0fdf0cf0ff3efb188fcd209ed", size = 650229, upload-time = "2026-03-23T11:56:46.148Z" }, ] +[[package]] +name = "sse-starlette" +version = "3.4.4" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "anyio" }, + { name = "starlette" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/f7/2b/58abc2d1fd397e7dde08e947e05c884d8ef2f78d5e2588c17a12d42d6994/sse_starlette-3.4.4.tar.gz", hash = "sha256:07e0fa0460138baf25cdd5fb28683472c3995dc1642225191b3832d62526bcb0", size = 31819, upload-time = "2026-05-12T17:37:17.019Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/dc/67/805710444ea8cc75fbf70b920ed431a560c4bf9c57f7d5a3117213189399/sse_starlette-3.4.4-py3-none-any.whl", hash = "sha256:3f4dd50d8aed2771a091f3a83000323fc3844541c16b4fe585ae2420cc6df973", size = 16514, upload-time = "2026-05-12T17:37:15.601Z" }, +] + [[package]] name = "stack-data" version = "0.6.3"