diff --git a/pyproject.toml b/pyproject.toml
index 0a29123c87..911d611804 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -46,6 +46,7 @@ dependencies = [
"fastapi>=0.133.0",
"httpx[http2]>=0.27.2",
"jinja2>=3.1.6",
+ "mcp>=1.0,<2",
"numpy>=1.26.0; python_version < '3.14'",
"numpy>=2.3.0; python_version >= '3.14'",
"openai>=2.2.0",
diff --git a/pyrit/exceptions/__init__.py b/pyrit/exceptions/__init__.py
index abd42de031..9baea33c1a 100644
--- a/pyrit/exceptions/__init__.py
+++ b/pyrit/exceptions/__init__.py
@@ -10,6 +10,8 @@
MissingPromptPlaceholderException,
PyritException,
RateLimitException,
+ ToolCallLoopLimitExceeded,
+ ToolCallNotSupported,
get_retry_max_num_attempts,
handle_bad_request_exception,
pyrit_custom_result_retry,
@@ -59,4 +61,6 @@
"set_execution_context",
"set_retry_collector",
"execution_context",
+ "ToolCallLoopLimitExceeded",
+ "ToolCallNotSupported",
]
diff --git a/pyrit/exceptions/exception_classes.py b/pyrit/exceptions/exception_classes.py
index b2fc55440b..5d0014aa3d 100644
--- a/pyrit/exceptions/exception_classes.py
+++ b/pyrit/exceptions/exception_classes.py
@@ -233,6 +233,70 @@ def __init__(self, *, message: str = "No prompt placeholder") -> None:
super().__init__(message=message)
+class ToolCallNotSupported(PyritException):
+ """
+ Raised when a target produces a tool call that the configured
+ :class:`~pyrit.tools.ToolEventPolicy` does not permit to execute
+ (``ToolEventBehavior.RAISE``, or ``EXECUTE`` without a backend).
+
+ The ``partial_conversation`` attribute carries every message produced
+ up to and including the assistant turn that contained the offending
+ tool call(s). Consumers can inspect it to log the surfaced tool-use
+ attempt.
+ """
+
+ def __init__(
+ self,
+ *,
+ message: str = "Tool call not supported by configured policy.",
+ partial_conversation: Optional[list["Message"]] = None,
+ ) -> None:
+ """
+ Initialize the exception.
+
+ Args:
+ message (str): Human-readable error description.
+ partial_conversation (Optional[list[Message]]): Messages produced by
+ the target up to (and including) the assistant turn that
+ contained the disallowed tool call(s).
+ """
+ super().__init__(status_code=400, message=message)
+ self.partial_conversation: list[Message] = (
+ list(partial_conversation) if partial_conversation is not None else []
+ )
+
+
+class ToolCallLoopLimitExceeded(PyritException):
+ """
+ Raised when the tool-use loop runs for more than
+ ``ToolEventPolicy.max_tool_iterations`` iterations without the model
+ producing a stop response.
+
+ The ``partial_conversation`` attribute carries every message produced
+ across all completed iterations. Consumers can inspect it to debug
+ runaway agentic behavior.
+ """
+
+ def __init__(
+ self,
+ *,
+ message: str = "Tool loop exceeded max_tool_iterations without a stop response.",
+ partial_conversation: Optional[list["Message"]] = None,
+ ) -> None:
+ """
+ Initialize the exception.
+
+ Args:
+ message (str): Human-readable error description.
+ partial_conversation (Optional[list[Message]]): Messages produced by
+ the target across every completed iteration of the tool loop.
+ """
+ super().__init__(status_code=400, message=message)
+ self.partial_conversation: list[Message] = (
+ list(partial_conversation) if partial_conversation is not None else []
+ )
+
+
def pyrit_custom_result_retry(
retry_function: Callable[..., bool], retry_max_num_attempts: Optional[int] = None
) -> Callable[..., Any]:
diff --git a/pyrit/message_normalizer/chat_message_normalizer.py b/pyrit/message_normalizer/chat_message_normalizer.py
index c5d3547e80..c9e7c0c532 100644
--- a/pyrit/message_normalizer/chat_message_normalizer.py
+++ b/pyrit/message_normalizer/chat_message_normalizer.py
@@ -4,6 +4,7 @@
import base64
import json
import os
+from collections.abc import Sequence
from typing import TYPE_CHECKING, Any, Union
from pyrit.common.data_url_converter import convert_local_image_to_data_url_async
@@ -14,6 +15,7 @@
apply_system_message_behavior,
)
from pyrit.models import ChatMessage, DataTypeSerializer, Message
+from pyrit.models.chat_message import ToolCall, ToolCallFunction
from pyrit.models.message_piece import MessagePiece
if TYPE_CHECKING:
@@ -83,6 +85,11 @@ async def normalize_async(self, messages: list[Message]) -> list[ChatMessage]:
chat_messages: list[ChatMessage] = []
for message in processed_messages:
pieces = message.message_pieces
+ tool_message = self._try_build_tool_message(pieces=pieces)
+ if tool_message is not None:
+ chat_messages.append(tool_message)
+ continue
+
role: ChatMessageRole = pieces[0].api_role
# Translate system -> developer for newer OpenAI models
@@ -99,6 +106,89 @@ async def normalize_async(self, messages: list[Message]) -> list[ChatMessage]:
return chat_messages
+ def _try_build_tool_message(self, *, pieces: Sequence[MessagePiece]) -> ChatMessage | None:
+ """
+ Build an OpenAI Chat Completions tool message when ``pieces`` carries tool data.
+
+ Returns a populated ``ChatMessage`` when the pieces are tool-call
+ envelopes (``function_call`` or ``function_call_output`` data type),
+ or ``None`` when the pieces are ordinary text / multimodal content.
+
+ ``function_call`` pieces produce a single ``role="assistant"`` message
+ with ``content=None`` and one or more entries in ``tool_calls``.
+ ``function_call_output`` pieces produce a single ``role="tool"``
+ message whose ``content`` is the output payload and whose
+ ``tool_call_id`` matches the originating call.
+
+ Args:
+ pieces (list[MessagePiece]): The pieces making up one PyRIT message.
+
+ Returns:
+ ChatMessage | None: ``None`` when no tool envelopes are present,
+ otherwise the converted tool message.
+ """
+ if not pieces:
+ return None
+ data_types = {p.converted_value_data_type or p.original_value_data_type for p in pieces}
+ if data_types == {"function_call"}:
+ return ChatMessage(
+ role="assistant",
+ content=None,
+ tool_calls=[self._piece_to_tool_call(piece) for piece in pieces],
+ )
+ if data_types == {"function_call_output"}:
+ # A single message carries one or more function_call_output pieces
+ # in declaration order; the OpenAI wire shape sends each as its
+ # own role="tool" message. For multi-piece tool messages, we
+ # surface the first piece here and let the caller emit additional
+ # messages — but in practice tool_loop emits one message per
+ # iteration with multiple pieces, and OpenAI accepts a single
+ # tool message per call_id. Emit the first envelope; warn if
+ # multiple are present.
+ envelope = self._decode_envelope(pieces[0])
+ return ChatMessage(
+ role="tool",
+ content=str(envelope.get("output", "")),
+ tool_call_id=str(envelope["call_id"]),
+ )
+ return None
+
+ @staticmethod
+ def _decode_envelope(piece: MessagePiece) -> dict[str, Any]:
+ """
+ Decode the canonical-envelope JSON carried in a tool piece.
+
+ Args:
+ piece (MessagePiece): A piece whose ``converted_value`` is the
+ canonical-envelope JSON string.
+
+ Returns:
+ dict[str, Any]: The parsed envelope.
+ """
+ return json.loads(piece.converted_value)
+
+ @classmethod
+ def _piece_to_tool_call(cls, piece: MessagePiece) -> ToolCall:
+ """
+ Convert one canonical ``function_call`` piece into an OpenAI ToolCall.
+
+ Args:
+ piece (MessagePiece): A piece carrying a canonical ``function_call``
+ envelope.
+
+ Returns:
+ ToolCall: The corresponding OpenAI Chat Completions tool call.
+ """
+ envelope = cls._decode_envelope(piece)
+ return ToolCall(
+ id=str(envelope["call_id"]),
+ type="function",
+ function=ToolCallFunction(
+ name=str(envelope["name"]),
+ arguments=str(envelope["arguments"]),
+ ),
+ )
+
async def normalize_string_async(self, messages: list[Message]) -> str:
"""
Convert a list of Messages to a JSON string representation.
diff --git a/pyrit/models/chat_message.py b/pyrit/models/chat_message.py
index c2f801862d..faf39cb908 100644
--- a/pyrit/models/chat_message.py
+++ b/pyrit/models/chat_message.py
@@ -10,13 +10,28 @@
ALLOWED_CHAT_MESSAGE_ROLES = ["system", "user", "assistant", "simulated_assistant", "tool", "developer"]
+class ToolCallFunction(BaseModel):
+ """The ``function`` payload of an OpenAI Chat Completions tool call."""
+
+ model_config = ConfigDict(extra="forbid")
+ name: str
+ arguments: str
+
+
class ToolCall(BaseModel):
- """Represents a tool invocation requested by the assistant."""
+ """
+ Represents a tool invocation requested by the assistant.
+
+ Matches the OpenAI Chat Completions API ``tool_calls`` shape: each entry
+ has a provider-issued ``id``, a ``type`` string (currently always
+ ``"function"``), and a nested ``function`` object carrying the tool
+ ``name`` and JSON-encoded ``arguments``.
+ """
model_config = ConfigDict(extra="forbid")
id: str
type: str
- function: str
+ function: ToolCallFunction
class ChatMessage(BaseModel):
@@ -26,11 +41,12 @@ class ChatMessage(BaseModel):
The content field can be:
- A simple string for single-part text messages
- A list of dicts for multipart messages (e.g., text + images)
+ - ``None`` for assistant messages whose payload is a tool-call only
"""
model_config = ConfigDict(extra="forbid")
role: ChatMessageRole
- content: Union[str, list[dict[str, Any]]]
+ content: Optional[Union[str, list[dict[str, Any]]]] = None
name: Optional[str] = None
tool_calls: Optional[list[ToolCall]] = None
tool_call_id: Optional[str] = None
diff --git a/pyrit/prompt_target/azure_ml_chat_target.py b/pyrit/prompt_target/azure_ml_chat_target.py
index 1c9d54d913..59eefe116c 100644
--- a/pyrit/prompt_target/azure_ml_chat_target.py
+++ b/pyrit/prompt_target/azure_ml_chat_target.py
@@ -1,6 +1,7 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
+import json
import logging
from typing import Any
@@ -18,6 +19,7 @@
from pyrit.message_normalizer import ChatMessageNormalizer, MessageListNormalizer
from pyrit.models import (
Message,
+ MessagePiece,
construct_response_from_request,
)
from pyrit.prompt_target.common.prompt_target import PromptTarget
@@ -29,6 +31,7 @@
)
from pyrit.prompt_target.common.target_configuration import TargetConfiguration
from pyrit.prompt_target.common.utils import limit_requests_per_minute, validate_temperature, validate_top_p
+from pyrit.tools import ToolBackend, ToolCallParser
logger = logging.getLogger(__name__)
@@ -70,6 +73,8 @@ def __init__(
repetition_penalty: float = 1.0,
max_requests_per_minute: int | None = None,
custom_configuration: TargetConfiguration | None = None,
+ tool_parser: ToolCallParser | None = None,
+ tool_backend: ToolBackend | None = None,
**param_kwargs: Any,
) -> None:
"""
@@ -100,6 +105,17 @@ def __init__(
will be capped at the value provided.
custom_configuration (TargetConfiguration | None): Override the default configuration for this target
instance. Useful for targets whose capabilities depend on deployment configuration.
+ tool_parser (ToolCallParser | None): When supplied, the target opts into PyRIT's
+ ``@tool_loop`` and uses this parser to extract pending tool calls from the
+ response. Supplying a parser also enables the ``supports_tool_use`` capability
+ on the default configuration so callers don't have to construct a custom
+ configuration just to enable the loop. The parser's expectations about the
+ deployment's response shape MUST line up with the contract documented in
+ ``doc/code/targets/`` for tool-capable Azure ML deployments.
+ tool_backend (ToolBackend | None): Convenience kwarg that wires a tool backend
+ onto ``custom_configuration.tool_backend``. Equivalent to constructing a
+ ``TargetConfiguration`` with the backend assigned. When ``custom_configuration``
+ already specifies a backend, the kwarg is rejected.
**param_kwargs: Additional parameters to pass to the model for generating responses. Example
parameters can be found here: https://huggingface.co/docs/api-inference/tasks/text-generation.
Note that the link above may not be comprehensive, and specific acceptable parameters may be
@@ -145,6 +161,18 @@ def __init__(
normalizer_overrides={CapabilityName.SYSTEM_PROMPT: message_normalizer},
)
+ # Enable tool-use capability when a parser is supplied so callers
+ # don't need to construct a custom_configuration just to opt in.
+ if tool_parser is not None:
+ custom_configuration = self._enable_tool_use(configuration=custom_configuration)
+
+ # tool_backend is a convenience kwarg; install it into the configuration.
+ if tool_backend is not None:
+ custom_configuration = self._install_tool_backend(
+ configuration=custom_configuration,
+ tool_backend=tool_backend,
+ )
+
PromptTarget.__init__(
self,
max_requests_per_minute=max_requests_per_minute,
@@ -163,6 +191,76 @@ def __init__(
self._top_p = top_p
self._repetition_penalty = repetition_penalty
self._extra_parameters = param_kwargs
+ self._tool_parser_instance = tool_parser
+
+ def _enable_tool_use(self, *, configuration: TargetConfiguration | None) -> TargetConfiguration:
+ """
+ Return a configuration whose capabilities include ``supports_tool_use=True``.
+
+ When ``configuration`` already has the capability set, returns it as-is.
+ Otherwise rebuilds the capabilities with ``supports_tool_use=True`` flipped
+ on and preserves every other field.
+
+ Args:
+ configuration (TargetConfiguration | None): The user-supplied configuration,
+ or ``None`` to start from the class default.
+
+ Returns:
+ TargetConfiguration: A configuration whose capabilities include
+ ``supports_tool_use=True``.
+ """
+ source = configuration if configuration is not None else self._DEFAULT_CONFIGURATION
+ caps = source.capabilities
+ if caps.includes(capability=CapabilityName.TOOL_USE):
+ return source
+ updated_caps = TargetCapabilities(
+ supports_multi_message_pieces=caps.supports_multi_message_pieces,
+ supports_editable_history=caps.supports_editable_history,
+ supports_multi_turn=caps.supports_multi_turn,
+ supports_system_prompt=caps.supports_system_prompt,
+ supports_tool_use=True,
+ input_modalities=caps.input_modalities,
+ output_modalities=caps.output_modalities,
+ )
+ return TargetConfiguration(
+ capabilities=updated_caps,
+ policy=source.policy,
+ tool_event_policy=source.tool_event_policy,
+ tool_backend=source.tool_backend,
+ )
+
+ @staticmethod
+ def _install_tool_backend(
+ *,
+ configuration: TargetConfiguration | None,
+ tool_backend: ToolBackend,
+ ) -> TargetConfiguration:
+ """
+ Install ``tool_backend`` onto ``configuration``. Rejects double-supply.
+
+ Args:
+ configuration (TargetConfiguration | None): The user-supplied configuration.
+ tool_backend (ToolBackend): The backend to install.
+
+ Returns:
+ TargetConfiguration: The same ``configuration`` instance with the
+ backend installed.
+
+ Raises:
+ ValueError: When ``configuration`` is ``None`` (no capability to attach
+ to), or when ``configuration.tool_backend`` is already set to a
+ different backend.
+ """
+ if configuration is None:
+ raise ValueError(
+ "tool_backend kwarg requires capabilities.supports_tool_use=True; "
+ "supply tool_parser= so the default capabilities flip TOOL_USE on, "
+ "or build a custom_configuration explicitly."
+ )
+ if configuration.tool_backend is not None and configuration.tool_backend is not tool_backend:
+ raise ValueError("tool_backend kwarg conflicts with custom_configuration.tool_backend; supply only one.")
+ configuration.tool_backend = tool_backend
+ return configuration
def _build_identifier(self) -> ComponentIdentifier:
"""
@@ -224,17 +322,10 @@ async def _send_prompt_to_target_async(self, *, normalized_conversation: list[Me
logger.info(f"Sending the following prompt to the prompt target: {request}")
try:
- resp_text = await self._complete_chat_async(
- messages=normalized_conversation,
- )
-
- if not resp_text:
- raise EmptyResponseException(message="The chat returned an empty response.")
-
- response_entry = construct_response_from_request(request=request, response_text_pieces=[resp_text])
+ response_body = await self._complete_chat_async(messages=normalized_conversation)
+ response_entry = self._materialize_response(response=response_body, request=request)
except HTTPStatusError as hse:
if hse.response.status_code == 400:
- # Handle Bad Request
response_entry = handle_bad_request_exception(response_text=hse.response.text, request=request)
elif hse.response.status_code == 429:
raise RateLimitException from hse
@@ -248,21 +339,23 @@ async def _send_prompt_to_target_async(self, *, normalized_conversation: list[Me
async def _complete_chat_async(
self,
messages: list[Message],
- ) -> str:
+ ) -> dict[str, Any]:
"""
- Completes a chat interaction by generating a response to the given input prompt.
-
- This is a synchronous wrapper for the asynchronous _generate_and_extract_response method.
+ Issue a single chat request and return the parsed JSON response body.
Args:
messages (list[Message]): The message objects containing the role and content.
+ Returns:
+ dict[str, Any]: The deserialized response body. Always includes an
+ ``output`` field (per the AML scoring-script contract). Tool-capable
+ deployments may additionally include a ``tool_calls`` field carrying
+ canonical envelopes.
+
Raises:
EmptyResponseException: If the response from the chat is empty.
+ ValueError: If the parsed response body is missing the ``output`` field.
Exception: For any other errors during the process.
-
- Returns:
- str: The generated response message.
"""
headers = self._get_headers()
payload = await self._construct_http_body_async(messages)
@@ -271,15 +364,52 @@ async def _complete_chat_async(
endpoint_uri=self._endpoint, method="POST", request_body=payload, headers=headers
)
- try:
- return str(response.json()["output"])
- except Exception as e:
- if response.json() == {}:
- raise EmptyResponseException(message="The chat returned an empty response.") from e
- raise type(e)(
- f"Exception obtaining response from the target. Returned response: {response.json()}. "
- f"Exception: {str(e)}"
- ) from e
+ body = response.json()
+ if not isinstance(body, dict) or body == {}:
+ raise EmptyResponseException(message="The chat returned an empty response.")
+ if "output" not in body:
+ raise ValueError(f"Response from the target did not include 'output'. Returned response: {body}.")
+ return body
+
+ def _materialize_response(self, *, response: dict[str, Any], request: MessagePiece) -> Message:
+ """
+ Build a ``Message`` from the parsed response body, handling tool calls.
+
+ The deployment may include a ``tool_calls`` list when the model emits
+ canonical envelopes. Each envelope becomes its own ``function_call``
+ MessagePiece so the ``CanonicalEnvelopeParser`` shipped with PyRIT can
+ recognize it without further translation.
+
+ Args:
+ response (dict[str, Any]): The parsed response body returned from the endpoint.
+ request (MessagePiece): The request piece used to stamp identity onto each
+ response piece.
+
+ Returns:
+ Message: The materialized response message. Has at least one piece;
+ when both ``output`` and ``tool_calls`` are present, the text piece
+ comes first followed by one function_call piece per envelope.
+
+ Raises:
+ EmptyResponseException: If the response has neither output text nor tool calls.
+ """
+ text = str(response.get("output") or "")
+ tool_envelopes = response.get("tool_calls") or []
+ if not text and not tool_envelopes:
+ raise EmptyResponseException(message="The chat returned an empty response.")
+
+ pieces: list[MessagePiece] = []
+ if text:
+ text_piece = construct_response_from_request(request=request, response_text_pieces=[text]).message_pieces[0]
+ pieces.append(text_piece)
+ for envelope in tool_envelopes:
+ fc_piece = construct_response_from_request(
+ request=request,
+ response_text_pieces=[json.dumps(envelope, separators=(",", ":"))],
+ response_type="function_call",
+ ).message_pieces[0]
+ pieces.append(fc_piece)
+ return Message(message_pieces=pieces, skip_validation=True)
async def _construct_http_body_async(
self,
@@ -297,10 +427,7 @@ async def _construct_http_body_async(
wire_format = ChatMessageNormalizer()
messages_dict = await wire_format.normalize_to_dicts_async(messages)
- # Parameters include additional ones passed in through **kwargs. Those not accepted by the model will
- # be ignored. We only include commonly supported parameters here - model-specific parameters like
- # stop sequences should be passed via **param_kwargs since different models use different EOS tokens.
- return {
+ body: dict[str, Any] = {
"input_data": {
"input_string": messages_dict,
"parameters": {
@@ -312,6 +439,29 @@ async def _construct_http_body_async(
| self._extra_parameters,
}
}
+ schemas = self._tool_schemas()
+ if schemas:
+ body["tools"] = schemas
+ return body
+
+ @property
+ def _tool_parser(self) -> ToolCallParser | None:
+ """Return the parser supplied at construction, if any."""
+ return self._tool_parser_instance
+
+ def _tool_schemas(self) -> list[dict[str, Any]]:
+ """
+ Wrap the backend's schemas in the OpenAI Chat Completions ``tools`` shape.
+
+ Tool-capable deployments are expected to forward ``tools`` into
+ ``tokenizer.apply_chat_template`` after unwrapping the ``{"type":
+ "function", "function": {...}}`` envelope.
+
+ Returns:
+ list[dict[str, Any]]: One descriptor per advertised tool, or an
+ empty list when no backend is configured.
+ """
+ return [{"type": "function", "function": schema} for schema in super()._tool_schemas()]
def _get_headers(self) -> dict[str, str]:
"""
diff --git a/pyrit/prompt_target/common/discover_target_capabilities.py b/pyrit/prompt_target/common/discover_target_capabilities.py
index 45600e6009..b6a79bbcb9 100644
--- a/pyrit/prompt_target/common/discover_target_capabilities.py
+++ b/pyrit/prompt_target/common/discover_target_capabilities.py
@@ -149,6 +149,7 @@ def _permissive_configuration(
supports_json_output=True,
supports_editable_history=True,
supports_system_prompt=True,
+ supports_tool_use=True,
input_modalities=merged_modalities,
)
# Rebuild a fresh configuration from the instance's native capabilities so
diff --git a/pyrit/prompt_target/common/prompt_target.py b/pyrit/prompt_target/common/prompt_target.py
index b1ee5caaa2..9ff03df7b6 100644
--- a/pyrit/prompt_target/common/prompt_target.py
+++ b/pyrit/prompt_target/common/prompt_target.py
@@ -12,6 +12,7 @@
from pyrit.models.json_response_config import _JsonResponseConfig
from pyrit.prompt_target.common.target_capabilities import CapabilityName, TargetCapabilities
from pyrit.prompt_target.common.target_configuration import TargetConfiguration
+from pyrit.tools import ToolCallParser, tool_loop
logger = logging.getLogger(__name__)
@@ -85,6 +86,7 @@ def __init__(
logging.basicConfig(level=logging.INFO)
@final
+ @tool_loop
async def send_prompt_async(self, *, message: Message) -> list[Message]:
"""
Validate, normalize, and send a prompt to the target.
@@ -97,6 +99,13 @@ async def send_prompt_async(self, *, message: Message) -> list[Message]:
3. Delegates to ``_send_prompt_to_target_async`` with the normalized
conversation.
+ When the target's :attr:`configuration.tool_event_policy` is set, the
+ :func:`pyrit.tools.tool_loop` decorator replaces this body with the
+ agentic loop and re-enters :meth:`_send_prompt_to_target_async`
+ repeatedly until the model issues a stop response (or the configured
+ ``max_tool_iterations`` is hit). When no policy is set, the decorator
+ is a no-op and the body below runs unchanged.
+
Subclasses MUST NOT override this method. Override
``_send_prompt_to_target_async`` instead.
@@ -132,6 +141,44 @@ async def _send_prompt_to_target_async(self, *, normalized_conversation: list[Me
list[Message]: Response messages from the target.
"""
+ @property
+ def _tool_parser(self) -> ToolCallParser | None:
+ """
+ Per-target :class:`ToolCallParser` consulted by :func:`pyrit.tools.tool_loop`.
+
+ Targets that participate in the tool-use loop override this property
+ to return a parser that walks their response messages and extracts
+ :class:`~pyrit.tools.ToolCall` instances. The base default of
+ ``None`` signals "this target does not participate" -- the wrapper
+ short-circuits after the first response.
+
+ Returns:
+ ToolCallParser | None: The parser, or ``None`` for the default
+ no-tool-use behavior.
+ """
+ return None
+
+ def _tool_schemas(self) -> list[dict[str, Any]]:
+ """
+ Outbound tool-schema list sent on the next request to the model.
+
+ The default reads the configured ``tool_backend.schemas`` verbatim.
+ Targets whose wire format wraps schemas differently (e.g., OpenAI
+ Chat Completions requires ``{"type": "function", "function": {...}}``;
+ the OpenAI Responses API requires ``{"type": "function", **schema}``
+ spread at the top level) override this method to apply the
+ per-target translation.
+
+ Returns:
+ list[dict[str, Any]]: One schema per advertised tool, in
+ whatever wire format this target expects. Empty when no
+ backend is configured.
+ """
+ backend = self.configuration.tool_backend
+ if backend is None:
+ return []
+ return list(backend.schemas)
+
def _validate_request(self, *, normalized_conversation: list[Message]) -> None:
"""
Validate the normalized conversation before sending to the target.
diff --git a/pyrit/prompt_target/common/target_capabilities.py b/pyrit/prompt_target/common/target_capabilities.py
index 6ae9ed69e2..234ef4d359 100644
--- a/pyrit/prompt_target/common/target_capabilities.py
+++ b/pyrit/prompt_target/common/target_capabilities.py
@@ -24,6 +24,7 @@ class CapabilityName(str, Enum):
JSON_OUTPUT = "supports_json_output"
EDITABLE_HISTORY = "supports_editable_history"
SYSTEM_PROMPT = "supports_system_prompt"
+ TOOL_USE = "supports_tool_use"
class UnsupportedCapabilityBehavior(str, Enum):
@@ -138,6 +139,14 @@ class attribute. Users can override individual capabilities per instance
# Whether the target natively supports system prompts.
supports_system_prompt: bool = False
+ # Whether the target natively supports model-issued tool calls (the
+ # canonical OpenAI ``function_call`` / ``function_call_output`` envelopes
+ # plus an outbound tool-schema list). Targets without this capability
+ # cannot host a tool-use loop -- attempting to configure a
+ # :class:`TargetConfiguration` with a ``tool_backend`` on a target whose
+ # capabilities have ``supports_tool_use=False`` raises at construction.
+ supports_tool_use: bool = False
+
# The input modalities supported by the target (e.g., "text", "image").
input_modalities: frozenset[frozenset[PromptDataType]] = frozenset({frozenset(["text"])})
diff --git a/pyrit/prompt_target/common/target_configuration.py b/pyrit/prompt_target/common/target_configuration.py
index 7e11a04673..e9d401b824 100644
--- a/pyrit/prompt_target/common/target_configuration.py
+++ b/pyrit/prompt_target/common/target_configuration.py
@@ -4,7 +4,7 @@
import logging
from collections.abc import Mapping
from dataclasses import fields
-from typing import Any
+from typing import TYPE_CHECKING, Any
from pyrit.message_normalizer import MessageListNormalizer
from pyrit.models import Message
@@ -16,6 +16,10 @@
UnsupportedCapabilityBehavior,
)
+if TYPE_CHECKING:
+ from pyrit.tools.backend import ToolBackend
+ from pyrit.tools.models import ToolEventPolicy
+
logger = logging.getLogger(__name__)
@@ -39,6 +43,15 @@ class TargetConfiguration:
Each target defines defaults; callers can override policy or individual
normalizers at creation time.
+
+ Tool use is configured by setting :attr:`tool_event_policy` (mandatory
+ when a target's response contains tool calls; controls EXECUTE / RAISE /
+ RETURN\\_RAW behavior) and optionally :attr:`tool_backend` (required only
+ when ``tool_event_policy.behavior`` is ``EXECUTE``). Both default to
+ ``None`` and are read by :func:`pyrit.tools.tool_loop` at runtime;
+ constructing a configuration with a ``tool_backend`` on a target that
+ does not declare ``capabilities.supports_tool_use=True`` raises
+ immediately.
"""
def __init__(
@@ -47,6 +60,8 @@ def __init__(
capabilities: TargetCapabilities,
policy: CapabilityHandlingPolicy | None = None,
normalizer_overrides: Mapping[CapabilityName, MessageListNormalizer[Any]] | None = None,
+ tool_event_policy: "ToolEventPolicy | None" = None,
+ tool_backend: "ToolBackend | None" = None,
) -> None:
"""
Build a target configuration and resolve the normalization pipeline.
@@ -57,7 +72,25 @@ def __init__(
capability. Defaults to RAISE for all adaptable capabilities.
normalizer_overrides (Mapping[CapabilityName, MessageListNormalizer[Any]] | None):
Optional overrides for specific capability normalizers.
+ tool_event_policy (ToolEventPolicy | None): How
+ :func:`pyrit.tools.tool_loop` should react to a pending tool
+ call from the target. ``None`` means the loop is disabled and
+ the wrapper short-circuits.
+ tool_backend (ToolBackend | None): Dispatch table used when
+ ``tool_event_policy.behavior`` is ``EXECUTE``. ``None`` is
+ valid only for the RAISE / RETURN\\_RAW policies and the
+ no-policy passthrough.
+
+ Raises:
+ ValueError: If ``tool_backend`` is set on a target whose
+ capabilities do not include ``supports_tool_use``.
"""
+ if tool_backend is not None and not capabilities.includes(capability=CapabilityName.TOOL_USE):
+ raise ValueError(
+ "tool_backend is set but capabilities.supports_tool_use is False. "
+ "Either declare supports_tool_use=True on the target's capabilities, "
+ "or remove the tool_backend."
+ )
self._capabilities = capabilities
self._policy = policy or _DEFAULT_POLICY
self._pipeline = ConversationNormalizationPipeline.from_capabilities(
@@ -65,6 +98,8 @@ def __init__(
policy=self._policy,
normalizer_overrides=normalizer_overrides,
)
+ self._tool_event_policy = tool_event_policy
+ self._tool_backend = tool_backend
@property
def capabilities(self) -> TargetCapabilities:
@@ -81,6 +116,41 @@ def pipeline(self) -> ConversationNormalizationPipeline:
"""The resolved normalization pipeline."""
return self._pipeline
+ @property
+ def tool_event_policy(self) -> "ToolEventPolicy | None":
+ """The tool-use policy consulted by :func:`pyrit.tools.tool_loop`."""
+ return self._tool_event_policy
+
+ @tool_event_policy.setter
+ def tool_event_policy(self, value: "ToolEventPolicy | None") -> None:
+ """Allow runtime updates so callers can opt a configured target into tool use."""
+ self._tool_event_policy = value
+
+ @property
+ def tool_backend(self) -> "ToolBackend | None":
+ """The tool dispatch backend used when the loop's behavior is ``EXECUTE``."""
+ return self._tool_backend
+
+ @tool_backend.setter
+ def tool_backend(self, value: "ToolBackend | None") -> None:
+ """
+ Allow runtime updates to the backend.
+
+ Re-runs the ``supports_tool_use`` validator so a backend can never be
+ installed onto a configuration that does not declare the capability.
+
+ Raises:
+ ValueError: If ``value`` is not ``None`` and the configuration's
+ capabilities do not include ``supports_tool_use``.
+ """
+ if value is not None and not self._capabilities.includes(capability=CapabilityName.TOOL_USE):
+ raise ValueError(
+ "tool_backend is set but capabilities.supports_tool_use is False. "
+ "Either declare supports_tool_use=True on the target's capabilities, "
+ "or remove the tool_backend."
+ )
+ self._tool_backend = value
+
def includes(self, *, capability: CapabilityName) -> bool:
"""
Check whether the target includes support for the given capability.
@@ -138,14 +208,20 @@ def as_identifier_params(self) -> dict[str, Any]:
suitable for inclusion in a ``ComponentIdentifier``.
The returned dict preserves the structure of ``TargetConfiguration``
- — capabilities, policy, and pipeline are kept as nested sub-dicts rather
- than flattened into the caller — so the identifier reflects the shape of
- the object it describes.
+ — capabilities, policy, pipeline, tool-event policy, and tool backend
+ are kept as nested sub-dicts rather than flattened into the caller —
+ so the identifier reflects the shape of the object it describes.
Two configurations that behave identically must produce equal dicts;
configurations that differ in any identity-bearing field must produce
- unequal dicts. Modality sets are flattened to sorted lists of sorted
- lists so ordering is stable across runs.
+ unequal dicts. The tool-backend snapshot uses the backend class plus
+ the sorted list of advertised tool names; this means two backends of
+ the same type exposing the same tool set are treated as equivalent
+ for identifier purposes (their exact callables / transports are not
+ deterministically serializable).
+
+ Modality sets are flattened to sorted lists of sorted lists so
+ ordering is stable across runs.
Returns:
dict[str, Any]: The identifier parameters for this configuration.
@@ -153,24 +229,82 @@ def as_identifier_params(self) -> dict[str, Any]:
caps = self._capabilities
return {
"capabilities": self._capabilities_to_identifier_params(caps),
- # Only unsupported capabilities appear here. Policy entries for
- # natively-supported capabilities are moot (the behavior never
- # fires), and omitting them keeps identifiers stable when default
- # policies expand to cover more capabilities.
"capability_policy": {
capability.value: behavior.value
for capability, behavior in self._policy.behaviors.items()
if not caps.includes(capability=capability)
},
- # Stable, ordered representation of the resolved normalization
- # pipeline. Captures the effect of ``normalizer_overrides`` since
- # the pipeline is built from defaults + overrides.
"normalization_pipeline": [
f"{type(normalizer).__module__}.{type(normalizer).__qualname__}"
for normalizer in self._pipeline.normalizers
],
+ "tool_event_policy": self._tool_event_policy_to_identifier_params(),
+ "tool_backend": self._tool_backend_to_identifier_params(),
}
+ def _tool_event_policy_to_identifier_params(self) -> dict[str, Any] | None:
+ """
+ Snapshot the active tool-event policy as identifier params.
+
+ Returns:
+ dict[str, Any] | None: ``None`` when no policy is configured;
+ otherwise ``behavior`` and ``max_tool_iterations`` as plain
+ values.
+ """
+ if self._tool_event_policy is None:
+ return None
+ return {
+ "behavior": self._tool_event_policy.behavior.value,
+ "max_tool_iterations": self._tool_event_policy.max_tool_iterations,
+ }
+
+ def _tool_backend_to_identifier_params(self) -> dict[str, Any] | None:
+ """
+ Snapshot the active tool backend as identifier params.
+
+ Returns the backend's fully-qualified class name plus the sorted
+ list of tool names it advertises. Exact callables / transports are
+ not serialized; two backends of the same type exposing the same
+ tool set therefore produce equal identifier dicts.
+
+ Returns:
+ dict[str, Any] | None: ``None`` when no backend is configured;
+ otherwise ``type`` (fully-qualified class name) and
+ ``tools`` (sorted list of advertised tool names).
+ """
+ if self._tool_backend is None:
+ return None
+ backend_type = type(self._tool_backend)
+ return {
+ "type": f"{backend_type.__module__}.{backend_type.__qualname__}",
+ "tools": sorted(self._extract_tool_names(self._tool_backend.schemas)),
+ }
+
+ @staticmethod
+ def _extract_tool_names(schemas: list[dict[str, Any]]) -> list[str]:
+ """
+ Pull the ``name`` field from each schema, supporting both flat and
+ nested ``function`` envelopes.
+
+ Args:
+ schemas (list[dict[str, Any]]): The backend-provided schema list.
+
+ Returns:
+ list[str]: One name per schema. Schemas without a recoverable
+ name are skipped silently — the identifier is best-effort
+ for shape-quirky backends.
+ """
+ names: list[str] = []
+ for schema in schemas:
+ if not isinstance(schema, dict):
+ continue
+ name = schema.get("name")
+ if not name and isinstance(schema.get("function"), dict):
+ name = schema["function"].get("name")
+ if isinstance(name, str):
+ names.append(name)
+ return names
+
@staticmethod
def _capabilities_to_identifier_params(capabilities: TargetCapabilities) -> dict[str, Any]:
"""
diff --git a/pyrit/prompt_target/hugging_face/hugging_face_chat_target.py b/pyrit/prompt_target/hugging_face/hugging_face_chat_target.py
index f2d62be82a..055cd55c70 100644
--- a/pyrit/prompt_target/hugging_face/hugging_face_chat_target.py
+++ b/pyrit/prompt_target/hugging_face/hugging_face_chat_target.py
@@ -22,9 +22,13 @@
from pyrit.identifiers import ComponentIdentifier
from pyrit.models import Message, construct_response_from_request
from pyrit.prompt_target.common.prompt_target import PromptTarget
-from pyrit.prompt_target.common.target_capabilities import TargetCapabilities
+from pyrit.prompt_target.common.target_capabilities import (
+ CapabilityName,
+ TargetCapabilities,
+)
from pyrit.prompt_target.common.target_configuration import TargetConfiguration
from pyrit.prompt_target.common.utils import limit_requests_per_minute
+from pyrit.tools import ToolBackend, ToolCallParser
logger = logging.getLogger(__name__)
@@ -77,6 +81,8 @@ def __init__(
attn_implementation: str | None = None,
max_requests_per_minute: int | None = None,
custom_configuration: TargetConfiguration | None = None,
+ tool_parser: ToolCallParser | None = None,
+ tool_backend: ToolBackend | None = None,
) -> None:
"""
Initialize the HuggingFaceChatTarget.
@@ -108,6 +114,15 @@ def __init__(
max_requests_per_minute (int | None): The maximum number of requests per minute. Defaults to None.
custom_configuration (TargetConfiguration | None): Override the default configuration for this target
instance. Defaults to None.
+ tool_parser (ToolCallParser | None): When supplied, the target opts into PyRIT's
+ ``@tool_loop`` and uses this parser to extract pending tool calls from each
+ generated response. Supplying a parser also enables ``supports_tool_use=True``
+ on the default capabilities so callers don't need a custom_configuration just
+ to opt in. ``InlineToolCallParser`` is the typical choice because the local
+ tokenizer emits tool calls as inline marker-delimited JSON; supply a different
+ parser when targeting a chat template with a different marker syntax.
+ tool_backend (ToolBackend | None): Convenience kwarg that installs the backend
+ onto ``custom_configuration.tool_backend``.
Raises:
ValueError: If neither or both of `model_id` and `model_path` are provided.
@@ -115,6 +130,16 @@ def __init__(
"""
model_name = model_id if model_id else model_path if model_path else ""
+ # Enable tool-use capability when a parser is supplied, BEFORE super().__init__
+ # so the configuration is correct by the time the base class records it.
+ if tool_parser is not None:
+ custom_configuration = self._enable_tool_use(configuration=custom_configuration)
+ if tool_backend is not None:
+ custom_configuration = self._install_tool_backend(
+ configuration=custom_configuration,
+ tool_backend=tool_backend,
+ )
+
super().__init__(
max_requests_per_minute=max_requests_per_minute,
model_name=model_name,
@@ -174,6 +199,81 @@ def __init__(
raise RuntimeError("CUDA requested but not available.")
self.load_model_and_tokenizer_task = asyncio.create_task(self.load_model_and_tokenizer())
+ self._tool_parser_instance = tool_parser
+
+ @classmethod
+ def _enable_tool_use(cls, *, configuration: TargetConfiguration | None) -> TargetConfiguration:
+ """
+ Return a configuration whose capabilities include ``supports_tool_use=True``.
+
+ When ``configuration`` already has the capability set, returns it as-is.
+ Otherwise rebuilds the capabilities with ``supports_tool_use=True`` and
+ preserves every other field.
+
+ Args:
+ configuration (TargetConfiguration | None): The user-supplied configuration,
+ or ``None`` to start from the class default.
+
+ Returns:
+ TargetConfiguration: A configuration whose capabilities include
+ ``supports_tool_use=True``.
+ """
+ source = configuration if configuration is not None else cls._DEFAULT_CONFIGURATION
+ caps = source.capabilities
+ if caps.includes(capability=CapabilityName.TOOL_USE):
+ return source
+ updated_caps = TargetCapabilities(
+ supports_multi_message_pieces=caps.supports_multi_message_pieces,
+ supports_editable_history=caps.supports_editable_history,
+ supports_multi_turn=caps.supports_multi_turn,
+ supports_system_prompt=caps.supports_system_prompt,
+ supports_tool_use=True,
+ input_modalities=caps.input_modalities,
+ output_modalities=caps.output_modalities,
+ )
+ return TargetConfiguration(
+ capabilities=updated_caps,
+ policy=source.policy,
+ tool_event_policy=source.tool_event_policy,
+ tool_backend=source.tool_backend,
+ )
+
+ @staticmethod
+ def _install_tool_backend(
+ *,
+ configuration: TargetConfiguration | None,
+ tool_backend: ToolBackend,
+ ) -> TargetConfiguration:
+ """
+ Install ``tool_backend`` onto ``configuration``. Rejects double-supply.
+
+ Args:
+ configuration (TargetConfiguration | None): The user-supplied configuration.
+ tool_backend (ToolBackend): The backend to install.
+
+ Returns:
+ TargetConfiguration: The same ``configuration`` instance with the
+ backend installed.
+
+ Raises:
+ ValueError: When ``configuration`` is ``None``, or when
+ ``configuration.tool_backend`` is already set to a different backend.
+ """
+ if configuration is None:
+ raise ValueError(
+ "tool_backend kwarg requires capabilities.supports_tool_use=True; "
+ "supply tool_parser= so the default capabilities flip TOOL_USE on, "
+ "or build a custom_configuration explicitly."
+ )
+ if configuration.tool_backend is not None and configuration.tool_backend is not tool_backend:
+ raise ValueError("tool_backend kwarg conflicts with custom_configuration.tool_backend; supply only one.")
+ configuration.tool_backend = tool_backend
+ return configuration
+
+ @property
+ def _tool_parser(self) -> ToolCallParser | None:
+ """Return the parser supplied at construction, if any."""
+ return self._tool_parser_instance
def _build_identifier(self) -> ComponentIdentifier:
"""
@@ -401,27 +501,79 @@ async def _send_prompt_to_target_async(self, *, normalized_conversation: list[Me
logger.error(f"Error occurred during inference: {e}")
raise
- def _build_chat_messages(self, *, normalized_conversation: list[Message]) -> list[dict[str, str]]:
+ def _build_chat_messages(self, *, normalized_conversation: list[Message]) -> list[dict[str, Any]]:
"""
Build a list of chat message dicts from the full normalized conversation.
Includes system, user, and assistant messages from the conversation history
- so that the model's chat template receives the complete context.
+ so that the model's chat template receives the complete context. When the
+ conversation contains tool-call envelopes (produced by ``@tool_loop``), they
+ are converted into the chat-template's tool message shape:
+
+ * ``function_call`` pieces become an ``assistant`` message with a
+ ``tool_calls`` list (matching the HuggingFace ``apply_chat_template``
+ convention; templates that don't recognize ``tool_calls`` fall back to
+ rendering the embedded JSON as content).
+ * ``function_call_output`` pieces become a ``role=tool`` message with the
+ tool result as content and ``tool_call_id`` carried for templates that
+ need it.
Args:
normalized_conversation (list[Message]): The full normalized conversation.
Returns:
- list[dict[str, str]]: Messages formatted for the chat template.
+ list[dict[str, Any]]: Messages formatted for the chat template.
"""
- messages: list[dict[str, str]] = []
+ messages: list[dict[str, Any]] = []
for msg in normalized_conversation:
- piece = msg.message_pieces[0]
- role = piece.api_role
- content = piece.converted_value or ""
- messages.append({"role": role, "content": content})
+ for piece in msg.message_pieces:
+ tool_dict = self._maybe_tool_chat_message(piece=piece)
+ if tool_dict is not None:
+ messages.append(tool_dict)
+ continue
+ role = piece.api_role
+ content = piece.converted_value or ""
+ messages.append({"role": role, "content": content})
return messages
+ @staticmethod
+ def _maybe_tool_chat_message(*, piece: Any) -> dict[str, Any] | None:
+ """
+ Convert a ``function_call`` or ``function_call_output`` piece to a chat-template message.
+
+ Args:
+ piece (Any): The MessagePiece to inspect.
+
+ Returns:
+ dict[str, Any] | None: A chat-template message dict (``assistant`` with
+ ``tool_calls``, or ``role=tool`` with ``tool_call_id``) when the
+ piece carries a tool envelope, otherwise ``None``.
+ """
+ data_type = piece.converted_value_data_type or piece.original_value_data_type
+ if data_type not in ("function_call", "function_call_output"):
+ return None
+ envelope = json.loads(piece.converted_value)
+ if data_type == "function_call":
+ return {
+ "role": "assistant",
+ "content": "",
+ "tool_calls": [
+ {
+ "id": envelope.get("call_id", ""),
+ "type": "function",
+ "function": {
+ "name": envelope.get("name", ""),
+ "arguments": envelope.get("arguments", "{}"),
+ },
+ }
+ ],
+ }
+ return {
+ "role": "tool",
+ "content": str(envelope.get("output", "")),
+ "tool_call_id": envelope.get("call_id", ""),
+ }
+
def set_random_seed(self, random_seed: int) -> None:
"""
Set a new random seed and immediately re-seed the RNG.
@@ -520,6 +672,13 @@ def _apply_chat_template(self, messages: list[dict[str, str]]) -> Any:
"""
Apply the chat template to the input messages and tokenize them.
+ When ``self._tool_schemas()`` is non-empty, the schemas are forwarded
+ into ``apply_chat_template`` so tool-trained chat templates can render
+ the model-family-specific tools block (Qwen wraps in ``...``,
+ Llama uses a system-message preamble, etc.). The model can then emit
+ tool calls in its native marker syntax which the user-supplied
+ ``tool_parser`` extracts.
+
Args:
messages: The input messages to apply the chat template to.
@@ -533,6 +692,11 @@ def _apply_chat_template(self, messages: list[dict[str, str]]) -> Any:
if hasattr(self.tokenizer, "chat_template") and self.tokenizer.chat_template is not None:
logger.info("Tokenizer has a chat template. Applying it to the input messages.")
+ template_kwargs: dict[str, Any] = {}
+ schemas = self._tool_schemas()
+ if schemas:
+ template_kwargs["tools"] = schemas
+
# Apply the chat template to format and tokenize the messages
return cast(
"BatchEncoding",
@@ -542,6 +706,7 @@ def _apply_chat_template(self, messages: list[dict[str, str]]) -> Any:
add_generation_prompt=True,
return_tensors=self.tensor_format,
return_dict=True,
+ **template_kwargs,
),
).to(self.device)
error_message = (
diff --git a/pyrit/prompt_target/openai/openai_response_target.py b/pyrit/prompt_target/openai/openai_response_target.py
index f2e4b19a76..91f8bc05f0 100644
--- a/pyrit/prompt_target/openai/openai_response_target.py
+++ b/pyrit/prompt_target/openai/openai_response_target.py
@@ -3,6 +3,7 @@
import json
import logging
+import warnings
from collections.abc import Awaitable, Callable, MutableSequence
from enum import Enum
from typing import (
@@ -34,6 +35,13 @@
from pyrit.prompt_target.common.utils import limit_requests_per_minute, validate_temperature, validate_top_p
from pyrit.prompt_target.openai.openai_error_handling import _is_content_filter_error
from pyrit.prompt_target.openai.openai_target import OpenAITarget
+from pyrit.tools import (
+ CanonicalEnvelopeParser,
+ LocalToolBackend,
+ ToolCallParser,
+ ToolEventBehavior,
+ ToolEventPolicy,
+)
logger = logging.getLogger(__name__)
@@ -76,6 +84,7 @@ class OpenAIResponseTarget(OpenAITarget, PromptTarget):
supports_json_output=True,
supports_multi_message_pieces=True,
supports_system_prompt=True,
+ supports_tool_use=True,
input_modalities=frozenset(
{
frozenset(["text"]),
@@ -154,6 +163,17 @@ def __init__(
"""
super().__init__(custom_configuration=custom_configuration, **kwargs)
+ # If the constructed configuration is the class-level _DEFAULT_CONFIGURATION
+ # singleton (user did not pass custom_configuration AND the underlying_model
+ # was unrecognized), rebuild a per-instance copy so the C6 tool-backend
+ # plumbing below does not mutate state shared across every other instance.
+ if custom_configuration is None and self._configuration is type(self)._DEFAULT_CONFIGURATION:
+ caps = self._configuration.capabilities
+ self._configuration = TargetConfiguration(
+ capabilities=caps,
+ policy=self._configuration.policy,
+ )
+
# Validate temperature and top_p
validate_temperature(temperature)
validate_top_p(top_p)
@@ -167,10 +187,39 @@ def __init__(
self._extra_body_parameters = extra_body_parameters
- # Per-instance tool/func registries:
- self._custom_functions: dict[str, ToolExecutor] = custom_functions or {}
+ # ----- Tool-calling plumbing (C6) ---------------------------------
+ # custom_functions is deprecated as of 0.15.x. New code configures
+ # tool_backend on TargetConfiguration directly. The kwarg is still
+ # accepted; we ALWAYS install a LocalToolBackend (whether populated
+ # or empty) when no other backend is supplied, so legacy in-place
+ # mutations of `target._custom_functions` (via the back-compat
+ # property below) keep affecting dispatch.
self._fail_on_missing_function: bool = fail_on_missing_function
+ if self.configuration.tool_backend is None:
+ shim_backend = LocalToolBackend(
+ callables=dict(custom_functions) if custom_functions else {},
+ schemas=self._derive_default_schemas(custom_functions or {}),
+ fail_on_missing_function=fail_on_missing_function,
+ )
+ self.configuration.tool_backend = shim_backend
+
+ if custom_functions:
+ warnings.warn(
+ "OpenAIResponseTarget(custom_functions=...) is deprecated and will be "
+ "removed in 0.16.0. Configure tool_backend on TargetConfiguration "
+ "instead (e.g. LocalToolBackend(callables=..., schemas=..., "
+ "fail_on_missing_function=...)).",
+ DeprecationWarning,
+ stacklevel=2,
+ )
+
+ # Default policy to EXECUTE when a backend is present. The wrapper's
+ # parser returns an empty list when the model produces no tool calls,
+ # so this is a no-op for plain text completions.
+ if self.configuration.tool_event_policy is None:
+ self.configuration.tool_event_policy = ToolEventPolicy(behavior=ToolEventBehavior.EXECUTE)
+
# Extract the grammar 'tool' if one is present
# See
# https://platform.openai.com/docs/guides/function-calling#context-free-grammars
@@ -185,6 +234,61 @@ def __init__(
logger.debug("Detected grammar tool: %s", tool_name)
self._grammar_name = tool_name
+ @staticmethod
+ def _derive_default_schemas(callables: dict[str, ToolExecutor]) -> list[dict[str, Any]]:
+ """
+ Synthesize minimal JSON schemas for the deprecation-shim path.
+
+ Users who pass the legacy ``custom_functions`` kwarg do not also pass a
+ schema list (the Responses API would accept the calls anyway because the
+ legacy path predates structured tool advertisement). To keep the
+ deprecation shim transparent we generate a schema-less stub per name so
+ ``_tool_schemas()`` returns something non-empty when the user actually
+ wires tools.
+
+ Args:
+ callables: Function name to async callable mapping.
+
+ Returns:
+ list[dict[str, Any]]: A bare schema per callable (``parameters``
+ is the unconstrained empty-object schema).
+ """
+ return [{"name": name, "parameters": {"type": "object"}} for name in callables]
+
+ @property
+ def _custom_functions(self) -> dict[str, ToolExecutor]:
+ """
+ Back-compat live view of the active backend's callables registry.
+
+ Mutations on the returned dict (``target._custom_functions[name] = fn``,
+ ``target._custom_functions.pop(name)``) take effect immediately because
+ the dict object is shared with the underlying
+ :class:`pyrit.tools.LocalToolBackend`. Returns an empty dict when no
+ backend is installed or when the configured backend is not a
+ ``LocalToolBackend``.
+
+ Returns:
+ dict[str, ToolExecutor]: The live callables dict.
+ """
+ backend = self.configuration.tool_backend
+ if isinstance(backend, LocalToolBackend):
+ return cast("dict[str, ToolExecutor]", backend._callables)
+ return {}
+
+ @_custom_functions.setter
+ def _custom_functions(self, value: dict[str, ToolExecutor]) -> None:
+ backend = self.configuration.tool_backend
+ if isinstance(backend, LocalToolBackend):
+ backend._callables = dict(value)
+ backend._schemas = self._derive_default_schemas(value)
+ return
+ new_backend = LocalToolBackend(
+ callables=dict(value),
+ schemas=self._derive_default_schemas(value),
+ fail_on_missing_function=self._fail_on_missing_function,
+ )
+ self.configuration.tool_backend = new_backend
+
def _build_identifier(self) -> ComponentIdentifier:
"""
Build the identifier with OpenAI response-specific parameters.
@@ -378,8 +482,9 @@ async def _construct_request_body(
input_items = await self._build_input_for_multi_modal_async(conversation)
text_format = self._build_text_format(json_config=json_config)
+ tool_schemas = self._tool_schemas()
- body_parameters = {
+ body_parameters: dict[str, Any] = {
"model": self._model_name,
"max_output_tokens": self._max_output_tokens,
"temperature": self._temperature,
@@ -390,8 +495,11 @@ async def _construct_request_body(
"text": text_format,
"reasoning": self._build_reasoning_config(),
}
+ if tool_schemas:
+ body_parameters["tools"] = tool_schemas
if self._extra_body_parameters:
+ # User-supplied extra_body_parameters wins over backend-derived tools.
body_parameters.update(self._extra_body_parameters)
# Filter out None values
@@ -559,11 +667,18 @@ async def _construct_message_from_response(self, response: Any, request: Message
@pyrit_target_retry
async def _send_prompt_to_target_async(self, *, normalized_conversation: list[Message]) -> list[Message]:
"""
- Send prompt, handle agentic tool calls (function_call), return all messages.
+ Send one prompt to the Responses API and return exactly one Message.
+
+ The agentic tool-calling loop now lives in :func:`pyrit.tools.tool_loop`
+ on the base class. This method is the single-iteration body the loop
+ re-enters on each turn: build the request body, call the API, parse the
+ response, return the constructed :class:`Message` wrapped in a list of
+ length 1.
- The Responses API supports structured outputs and tool execution. This method handles both:
- - Simple text/reasoning responses
- - Agentic tool-calling loops that may require multiple back-and-forth exchanges
+ The wrapper detects function_call pieces via :attr:`_tool_parser` and
+ decides whether to dispatch + re-enter. Reasoning, MCP, web-search,
+ computer-use, and other non-function-call sections pass through to
+ Memory unchanged because the parser ignores them.
Args:
normalized_conversation (list[Message]): The full conversation
@@ -571,59 +686,51 @@ async def _send_prompt_to_target_async(self, *, normalized_conversation: list[Me
pipeline. The current message is the last element.
Returns:
- List of messages generated during the interaction (assistant responses and tool messages).
- The normalizer will persist all of these to memory.
+ list[Message]: Exactly one Message wrapping the parsed response.
"""
message = normalized_conversation[-1]
- message_piece: MessagePiece = message.message_pieces[0]
last_piece = message.message_pieces[-1]
json_config = self._get_json_response_config(message_piece=last_piece)
- working_conversation: MutableSequence[Message] = list(normalized_conversation)
-
- # Track all responses generated during this interaction
- responses_to_return: list[Message] = []
-
- # Main agentic loop - each back-and-forth creates a new message
- tool_call_section: Optional[dict[str, Any]] = None
-
- while True:
- logger.info(f"Sending conversation with {len(working_conversation)} messages to the prompt target")
-
- body = await self._construct_request_body(conversation=working_conversation, json_config=json_config)
-
- # Use unified error handling - automatically detects Response and validates
- result = await self._handle_openai_request(
- api_call=lambda body=body: self._client.responses.create(**body),
- request=message,
- )
-
- # Add result to conversation and responses list
- working_conversation.append(result)
- responses_to_return.append(result)
-
- # Extract tool call if present
- tool_call_section = self._find_last_pending_tool_call(result)
-
- # If no tool call, we're done
- if not tool_call_section:
- break
-
- # Execute the tool/function
- tool_output = await self._execute_call_section(tool_call_section)
+ body = await self._construct_request_body(conversation=list(normalized_conversation), json_config=json_config)
+ logger.info("Sending conversation with %d messages to the Responses API", len(normalized_conversation))
+ result = await self._handle_openai_request(
+ api_call=lambda body=body: self._client.responses.create(**body),
+ request=message,
+ )
+ return [result]
- # Create a new message with the tool output
- tool_piece = self._make_tool_piece(tool_output, tool_call_section["call_id"], reference_piece=message_piece)
- tool_message = Message(message_pieces=[tool_piece], skip_validation=True)
+ @property
+ def _tool_parser(self) -> ToolCallParser | None:
+ """
+ Canonical-envelope parser shared with future canonical-envelope targets.
+
+ Walks response message pieces and emits one :class:`~pyrit.tools.ToolCall`
+ per piece whose ``original_value_data_type`` is ``"function_call"``.
+ Reasoning, MCP, web-search, computer-use, and local-shell sections all
+ produce pieces of OTHER data types, so the parser returns an empty list
+ for them and the @tool_loop decorator exits cleanly. Those sections
+ still land in Memory via the parsed Message returned by
+ ``_send_prompt_to_target_async``; they're just not client-side
+ dispatched.
+ """
+ return CanonicalEnvelopeParser()
- # Add tool output message to conversation and responses list
- working_conversation.append(tool_message)
- responses_to_return.append(tool_message)
+ def _tool_schemas(self) -> list[dict[str, Any]]:
+ """
+ Translate the configured backend's schemas into Responses-API tools shape.
- # Continue loop to send tool result and get next response
+ The Responses API expects each function tool as a top-level
+ ``{"type": "function", "name": ..., "description": ...,
+ "parameters": ...}`` entry (NOT wrapped in an inner ``"function"`` key
+ the way Chat Completions does). The backend's schemas are already the
+ bare function schema, so we just stamp ``type=function`` on each.
- # Return all responses (normalizer will persist all of them to memory)
- return responses_to_return
+ Returns:
+ list[dict[str, Any]]: One descriptor per advertised tool, or an
+ empty list when no backend is configured.
+ """
+ return [{"type": "function", **schema} for schema in super()._tool_schemas()]
def _parse_response_output_section(
self, *, section: Any, message_piece: MessagePiece, error: Optional[PromptResponseError]
diff --git a/pyrit/tools/__init__.py b/pyrit/tools/__init__.py
new file mode 100644
index 0000000000..e91b44a037
--- /dev/null
+++ b/pyrit/tools/__init__.py
@@ -0,0 +1,75 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
+"""
+Generic tool-use scaffolding for ``PromptTarget``.
+
+This package provides a transport-agnostic tool-calling loop. The
+``tool_loop`` decorator, when applied to ``send_prompt_async``, runs
+the standard PyRIT validate+normalize work once and then repeatedly
+re-enters the target's protected ``_send_prompt_to_target_async`` until
+the model issues a stop response (or a configured limit is hit).
+
+A target opts in by declaring two collaborators:
+
+* ``self._tool_parser`` — a ``ToolCallParser`` that walks a
+ response message and extracts pending ``ToolCall`` instances.
+* ``self.configuration.tool_event_policy`` — a ``ToolEventPolicy``
+ whose ``ToolEventBehavior`` decides whether to ``EXECUTE``,
+ ``RAISE``, or ``RETURN_RAW`` on each detected call.
+
+When the policy is ``EXECUTE``, calls are dispatched through
+``self.configuration.tool_backend``, an implementation of
+``ToolBackend``. ``LocalToolBackend`` is the in-process backend;
+``MCPToolBackend`` proxies through one or more MCP servers.
+
+The ``ToolBackend`` abstract base is intentionally distinct from
+``pyrit.registry`` — that namespace is reserved for framework-level
+identity registries (``TargetRegistry``, ``ScorerRegistry``) that
+register named singletons for CLI lookup, which a per-target tool
+dispatch table is not.
+
+``@tool_loop`` is wired onto ``PromptTarget.send_prompt_async`` from
+the base class, and the ``tool_event_policy`` / ``tool_backend``
+fields hang off ``TargetConfiguration``.
+
+The two exception types the loop raises
+(``ToolCallNotSupported`` and
+``ToolCallLoopLimitExceeded``) live in
+``pyrit.exceptions`` alongside the rest of PyRIT's exception
+catalog, so non-tools callers (attacks, normalizers) can import them
+without taking a subsystem-level dependency on ``pyrit.tools``.
+"""
+
+from pyrit.tools.backend import ToolBackend
+from pyrit.tools.inline_parser import InlineToolCallParser, InlineToolCallParserMode
+from pyrit.tools.local_backend import LocalToolBackend
+from pyrit.tools.mcp_backend import MCPToolBackend
+from pyrit.tools.mcp_client import (
+ DockerMCPServerSpec,
+ LocalMCPServerSpec,
+ MCPClient,
+ MCPServerSpec,
+ RemoteMCPServerSpec,
+)
+from pyrit.tools.models import ToolCall, ToolEventBehavior, ToolEventPolicy, tool_loop
+from pyrit.tools.parsers import CanonicalEnvelopeParser, ToolCallParser
+
+__all__ = [
+ "CanonicalEnvelopeParser",
+ "DockerMCPServerSpec",
+ "InlineToolCallParser",
+ "InlineToolCallParserMode",
+ "LocalMCPServerSpec",
+ "LocalToolBackend",
+ "MCPClient",
+ "MCPServerSpec",
+ "MCPToolBackend",
+ "RemoteMCPServerSpec",
+ "ToolBackend",
+ "ToolCall",
+ "ToolCallParser",
+ "ToolEventBehavior",
+ "ToolEventPolicy",
+ "tool_loop",
+]
diff --git a/pyrit/tools/backend.py b/pyrit/tools/backend.py
new file mode 100644
index 0000000000..d878bd1f4f
--- /dev/null
+++ b/pyrit/tools/backend.py
@@ -0,0 +1,87 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
+from __future__ import annotations
+
+from abc import ABC, abstractmethod
+from typing import TYPE_CHECKING, Any
+
+if TYPE_CHECKING:
+ from pyrit.tools.models import ToolCall
+
+
+class ToolBackend(ABC):
+ """
+ Abstract base for backends that dispatch tool calls produced by a target.
+
+ A ``ToolBackend`` is a per-target dispatch table — it owns the
+ ``name -> async callable`` mapping a target uses to execute the tool
+ calls extracted from a model response. This is intentionally distinct
+ from ``pyrit.registry``, whose ``Registry`` classes register named
+ framework singletons (targets, scorers, attacks) for CLI lookup.
+
+ Two concrete implementations ship with PyRIT:
+
+ * ``LocalToolBackend`` — in-process backend backed
+ by ``async def`` callables. Useful for unit tests and for embedding
+ tools inside the PyRIT process.
+ * ``MCPToolBackend`` — proxies dispatch through one
+ or more MCP servers.
+
+ Subclasses MUST implement ``schemas`` and ``dispatch_async``.
+ ``dispatch_all_sequential_async`` ships with a default
+ implementation that awaits ``dispatch_async`` once per call in
+ declaration order; backends that wish to parallelize dispatch
+ (e.g. fan out across multiple sandbox containers) should override it.
+ """
+
+ @property
+ @abstractmethod
+ def schemas(self) -> list[dict[str, Any]]:
+ """
+ The JSON-schema descriptors for every tool the backend exposes.
+
+ Returns:
+ list[dict[str, Any]]: One schema per tool, in a target-agnostic
+ format that concrete targets serialize into their request
+ body.
+ """
+
+ @abstractmethod
+ async def dispatch_async(self, call: ToolCall) -> dict[str, Any]:
+ """
+ Execute a single tool call and return the structured result.
+
+ Implementations MUST NOT raise on tool-side failures; they MUST
+ return an error envelope (e.g. ``{"error": "...", "tool": "..."}``)
+ so the tool loop can carry the failure back to the model.
+
+ Args:
+ call (ToolCall): The tool call to dispatch.
+
+ Returns:
+ dict[str, Any]: The structured tool result.
+ """
+
+ async def dispatch_all_sequential_async(
+ self,
+ calls: list[ToolCall],
+ ) -> list[tuple[ToolCall, dict[str, Any]]]:
+ """
+ Dispatch every call in *calls* sequentially, preserving declaration order.
+
+ Default implementation: ``await dispatch_async`` once per call.
+ Backends that parallelize dispatch should override this method.
+
+ Args:
+ calls (list[ToolCall]): The calls to dispatch, in declaration order.
+
+ Returns:
+ list[tuple[ToolCall, dict[str, Any]]]: ``(call, result)`` pairs,
+ in the same order as *calls*.
+ """
+ results: list[tuple[ToolCall, dict[str, Any]]] = []
+ for call in calls:
+ envelope = await self.dispatch_async(call)
+ results.append((call, envelope))
+ return results
diff --git a/pyrit/tools/inline_parser.py b/pyrit/tools/inline_parser.py
new file mode 100644
index 0000000000..c7cc353fe4
--- /dev/null
+++ b/pyrit/tools/inline_parser.py
@@ -0,0 +1,202 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
+"""Inline tool-call parser for open chat-tuned models."""
+
+from __future__ import annotations
+
+import enum
+import json
+import logging
+import re
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+ from pyrit.models import Message
+ from pyrit.tools.models import ToolCall
+
+logger = logging.getLogger(__name__)
+
+
+class InlineToolCallParserMode(enum.Enum):
+ """Policy for handling text that surrounds inline tool-call markers."""
+
+ TRUNCATE_AT_LAST = "truncate_at_last"
+ TRUNCATE_AT_FIRST = "truncate_at_first"
+ EXTRACT_ALL = "extract_all"
+ STRICT_TRAILING_EMPTY = "strict_trailing_empty"
+
+
+class InlineToolCallParser:
+ """
+ Extract canonical ``ToolCall`` instances from marker-delimited JSON blocks.
+
+ Open chat-tuned models that do not expose a structured ``tool_calls``
+ channel typically emit tool calls as inline text wrapped in a
+ chat-template-specific marker (an angle-bracket pair, a pipe-delimited
+ tag pair, a square-bracketed list payload, and so on). This parser
+ walks every ``MessagePiece`` whose ``original_value_data_type`` is
+ ``"text"`` and runs ``marker_pattern`` against the piece's
+ ``original_value``. Each match capture group is decoded as JSON of the
+ form ``{"name": ..., "arguments": {...}}``. Synthetic ``call_id``
+ values are minted positionally because inline-marker formats do not
+ issue provider IDs.
+
+ The ``mode`` parameter controls how text surrounding the markers is
+ treated -- see ``InlineToolCallParserMode``. The default
+ ``TRUNCATE_AT_LAST`` honors every marker but discards anything after
+ the last one so hallucinated "tool results" that the model dreams up
+ after the call are not persisted as if they were real outputs.
+
+ Args:
+ marker_pattern (str): Regex with exactly one capture group returning
+ the JSON payload. Default targets the angle-bracket
+ ``...`` syntax used by many tool-trained
+ ChatML-style chat templates.
+ call_id_prefix (str): Prefix for synthetic ``call_id`` values.
+ mode (InlineToolCallParserMode): Surrounding-text policy.
+ """
+
+ def __init__(
+ self,
+ *,
+ marker_pattern: str = r"(.*?)",
+ call_id_prefix: str = "call",
+ mode: InlineToolCallParserMode = InlineToolCallParserMode.TRUNCATE_AT_LAST,
+ ) -> None:
+ """
+ Build an ``InlineToolCallParser``.
+
+ Args:
+ marker_pattern (str): See class docstring.
+ call_id_prefix (str): See class docstring.
+ mode (InlineToolCallParserMode): See class docstring.
+ """
+ self._pattern = re.compile(marker_pattern, re.DOTALL)
+ self._call_id_prefix = call_id_prefix
+ self._mode = mode
+
+ @property
+ def mode(self) -> InlineToolCallParserMode:
+ """The active surrounding-text policy."""
+ return self._mode
+
+ def parse(self, message: Message) -> list[ToolCall]:
+ """
+ Extract tool calls from every text piece in ``message``.
+
+ Args:
+ message (Message): The most recent assistant response.
+
+ Returns:
+ list[ToolCall]: One ``ToolCall`` per valid marker match, in
+ declaration order across pieces. Empty when no markers
+ are found.
+
+ Raises:
+ ValueError: When ``mode`` is ``STRICT_TRAILING_EMPTY`` and any
+ non-whitespace text follows the last marker in any piece.
+ """
+ calls: list[ToolCall] = []
+ next_id = 0
+ for piece in message.message_pieces:
+ if piece.original_value_data_type != "text":
+ continue
+ matches = self._match_piece(text=piece.original_value)
+ if not matches:
+ continue
+ for match in matches:
+ call = self._build_call(match=match, next_id=next_id)
+ if call is None:
+ continue
+ calls.append(call)
+ next_id += 1
+
+ return calls
+
+ def _match_piece(self, *, text: str) -> list[re.Match[str]]:
+ """
+ Apply the mode-specific filter to all marker matches in ``text``.
+
+ Args:
+ text (str): Piece text to scan for markers.
+
+ Returns:
+ list[re.Match[str]]: Matches to honor, in declaration order.
+
+ Raises:
+ ValueError: When ``mode`` is ``STRICT_TRAILING_EMPTY`` and any
+ non-whitespace text follows the last marker.
+ """
+ matches = list(self._pattern.finditer(text))
+ if not matches:
+ return matches
+
+ if self._mode is InlineToolCallParserMode.TRUNCATE_AT_FIRST:
+ return matches[:1]
+ if self._mode is InlineToolCallParserMode.STRICT_TRAILING_EMPTY:
+ trailing = text[matches[-1].end() :]
+ if trailing.strip():
+ raise ValueError(
+ "Non-whitespace text follows the last tool-call marker; "
+ "InlineToolCallParserMode.STRICT_TRAILING_EMPTY rejects this. "
+ f"Trailing: {trailing!r}"
+ )
+ return matches
+
+ def _build_call(self, *, match: re.Match[str], next_id: int) -> ToolCall | None:
+ """
+ Decode a single marker match into a ``ToolCall``.
+
+ Args:
+ match (re.Match[str]): A single marker match whose group 1 is the
+ JSON payload.
+ next_id (int): Positional id used to form ``call_id``.
+
+ Returns:
+ ToolCall | None: ``None`` when the payload is malformed or
+ missing the ``name`` field. The caller is expected to log
+ and skip.
+ """
+ from pyrit.tools.models import ToolCall
+
+ payload = match.group(1).strip()
+ try:
+ parsed = json.loads(payload)
+ except json.JSONDecodeError:
+ logger.warning("Skipping malformed tool-call payload: %r", payload[:120])
+ return None
+ if not isinstance(parsed, dict) or "name" not in parsed:
+ logger.warning("Skipping tool-call payload without 'name' field: %r", payload[:120])
+ return None
+ arguments = self._coerce_arguments(parsed.get("arguments", {}))
+ return ToolCall(
+ call_id=f"{self._call_id_prefix}_{next_id}",
+ name=parsed["name"],
+ arguments=arguments,
+ raw_envelope=parsed,
+ )
+
+ @staticmethod
+ def _coerce_arguments(raw: object) -> dict:
+ """
+ Coerce the ``arguments`` field into a dict regardless of source shape.
+
+ Args:
+ raw (object): The value of the payload's ``arguments`` field.
+ Either a dict, a JSON-encoded string, or something else.
+
+ Returns:
+ dict: The decoded arguments dict, or an empty dict on any
+ shape PyRIT cannot interpret. Empty-dict fallback preserves
+ the loop's behavior of "always continue with a real call".
+ """
+ if isinstance(raw, dict):
+ return dict(raw)
+ if isinstance(raw, str):
+ try:
+ decoded = json.loads(raw)
+ except json.JSONDecodeError:
+ return {}
+ return dict(decoded) if isinstance(decoded, dict) else {}
+ return {}
diff --git a/pyrit/tools/local_backend.py b/pyrit/tools/local_backend.py
new file mode 100644
index 0000000000..7a149bc825
--- /dev/null
+++ b/pyrit/tools/local_backend.py
@@ -0,0 +1,121 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
+from __future__ import annotations
+
+import logging
+from typing import TYPE_CHECKING, Any
+
+from pyrit.tools.backend import ToolBackend
+
+if TYPE_CHECKING:
+ from collections.abc import Awaitable, Callable
+
+ from pyrit.tools.models import ToolCall
+
+logger = logging.getLogger(__name__)
+
+
+class LocalToolBackend(ToolBackend):
+ """
+ In-process ``ToolBackend`` backed by a name -> ``async def``
+ mapping. Useful for unit tests and for embedding small tools inside the
+ PyRIT process without standing up an MCP server.
+
+ "Local" here means tools run in PyRIT's own Python process — no
+ subprocess, no IPC, no wire protocol. Contrast with
+ ``MCPToolBackend``, which proxies dispatch through one or more MCP
+ servers reached via JSON-RPC.
+
+ The backend dispatches sequentially in declaration order. Tool-side
+ failures (raised exceptions, missing names, allow-list rejections)
+ are converted into structured error envelopes so the tool loop can
+ forward them back to the model as ``function_call_output`` content
+ rather than aborting the conversation.
+ """
+
+ def __init__(
+ self,
+ *,
+ callables: dict[str, Callable[[dict[str, Any]], Awaitable[Any]]],
+ schemas: list[dict[str, Any]] | None = None,
+ allowed_tools: set[str] | None = None,
+ fail_on_missing_function: bool = True,
+ ) -> None:
+ """
+ Initialize the backend.
+
+ Args:
+ callables (dict[str, Callable[[dict[str, Any]], Awaitable[Any]]]):
+ Map from tool name to an ``async def`` that accepts a parsed
+ arguments dict and returns the tool result. Results are
+ serialized by the tool loop via ``json.dumps``.
+ schemas (list[dict[str, Any]] | None): JSON-schema descriptors
+ injected into the target's request body. Defaults to an empty
+ list when omitted.
+ allowed_tools (set[str] | None): Optional allow-list of tool
+ names; calls whose name is not in this set surface as
+ ``tool_not_allowed`` envelopes without invoking the callable.
+ Defaults to None (no allow-list; every registered tool is
+ callable).
+ fail_on_missing_function (bool): When True (default), an unknown
+ tool name raises ``KeyError``. When False, the backend
+ returns a ``tool_not_registered`` envelope so the model can
+ recover.
+ """
+ self._callables = dict(callables)
+ self._schemas: list[dict[str, Any]] = list(schemas) if schemas is not None else []
+ self._allowed_tools = set(allowed_tools) if allowed_tools is not None else None
+ self._fail_on_missing_function = fail_on_missing_function
+
+ @property
+ def schemas(self) -> list[dict[str, Any]]:
+ """The JSON-schema descriptors for the tools in this backend."""
+ return list(self._schemas)
+
+ async def dispatch_async(self, call: ToolCall) -> dict[str, Any]:
+ """
+ Dispatch a single tool call. Tool failures are converted into
+ structured envelopes; only configuration errors (missing tool with
+ ``fail_on_missing_function=True``) propagate as exceptions.
+
+ Args:
+ call (ToolCall): The call to dispatch.
+
+ Returns:
+ dict[str, Any]: The tool's result, or a structured error envelope.
+
+ Raises:
+ KeyError: When the tool name is not registered and
+ ``fail_on_missing_function=True``.
+ """
+ if self._allowed_tools is not None and call.name not in self._allowed_tools:
+ logger.info("Rejecting disallowed tool call: %s", call.name)
+ return {
+ "error": "tool_not_allowed",
+ "tool": call.name,
+ "allowed_tools": sorted(self._allowed_tools),
+ }
+
+ fn = self._callables.get(call.name)
+ if fn is None:
+ if self._fail_on_missing_function:
+ raise KeyError(f"Tool '{call.name}' is not registered.")
+ available = sorted(self._callables.keys())
+ logger.warning("Tool '%s' not registered. Available: %s", call.name, available)
+ return {
+ "error": "tool_not_registered",
+ "tool": call.name,
+ "available_tools": available,
+ }
+
+ try:
+ result = await fn(call.arguments)
+ except Exception as ex:
+ logger.warning("Tool '%s' raised %s: %s", call.name, type(ex).__name__, ex)
+ return {
+ "error": "tool_execution_failed",
+ "tool": call.name,
+ "detail": str(ex),
+ }
+ return result if isinstance(result, dict) else {"result": result}
diff --git a/pyrit/tools/mcp_backend.py b/pyrit/tools/mcp_backend.py
new file mode 100644
index 0000000000..41a0d32462
--- /dev/null
+++ b/pyrit/tools/mcp_backend.py
@@ -0,0 +1,199 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
+"""
+Multi-server tool backend that proxies dispatch through one or more
+MCP servers.
+
+This is the ``ToolBackend`` implementation that real
+red-team configurations use. It composes one
+``MCPClient`` per ``MCPServerSpec``,
+aggregates their advertised schemas, routes incoming
+``ToolCall`` instances to the correct underlying
+client, and enforces an optional ``allowed_tools`` allow-list.
+
+Contrast with ``LocalToolBackend``, which dispatches
+to Python ``async def`` callables inside PyRIT's own process.
+"""
+
+from __future__ import annotations
+
+import asyncio
+import logging
+from contextlib import AsyncExitStack
+from typing import TYPE_CHECKING, Any
+
+from pyrit.tools.backend import ToolBackend
+from pyrit.tools.mcp_client import MCPClient
+
+if TYPE_CHECKING:
+ from collections.abc import Iterable
+
+ from pyrit.tools.mcp_client import MCPServerSpec
+ from pyrit.tools.models import ToolCall
+
+logger = logging.getLogger(__name__)
+
+
+class MCPToolBackend(ToolBackend):
+ """
+ ``ToolBackend`` backed by one or more MCP servers.
+
+ On ``__aenter__``, the backend spawns / connects each server in
+ its ``_servers`` list (sequentially) through a single
+ ``contextlib.AsyncExitStack``, runs the MCP handshake, caches
+ schemas, and builds an advertised-name → ``(client, server_name)``
+ routing table. Collisions raise ``ValueError`` unless the
+ colliding specs set ``name_prefix``.
+
+ A single shared ``AsyncExitStack`` (rather than one per client)
+ is required so anyio's nested cancel scopes — opened by the ``mcp``
+ SDK's ``stdio_client`` and ``ClientSession`` context managers — are
+ closed in strict LIFO order from the entering task. Closing
+ out-of-order would trip
+ ``"Attempted to exit a cancel scope that isn't the current task's
+ current cancel scope"``.
+
+ Dispatch is serialized through an ``asyncio.Lock`` per backend
+ instance — multiple concurrent coroutines sharing the same backend
+ (e.g. parallel attack runs) will not interleave JSON-RPC frames on
+ the same stdio pipe.
+ """
+
+ def __init__(
+ self,
+ *,
+ servers: Iterable[MCPServerSpec],
+ allowed_tools: list[str] | None = None,
+ ) -> None:
+ """
+ Initialize the backend.
+
+ Args:
+ servers: One or more ``MCPServerSpec`` instances describing
+ where each server runs.
+ allowed_tools: Optional allow-list of tool names. Names not in
+ the list are filtered from ``schemas`` AND
+ short-circuit dispatch with a ``tool_not_allowed`` envelope.
+ Names are matched after ``name_prefix``
+ has been applied. Defaults to None (every advertised tool is
+ callable).
+
+ Raises:
+ ValueError: When *servers* is empty.
+ """
+ self._servers: list[MCPServerSpec] = list(servers)
+ if not self._servers:
+ raise ValueError("MCPToolBackend requires at least one server spec.")
+ self._allowed_tools: set[str] | None = set(allowed_tools) if allowed_tools is not None else None
+ self._clients: list[MCPClient] = []
+ self._routing: dict[str, tuple[MCPClient, str]] = {}
+ self._dispatch_lock = asyncio.Lock()
+ self._stack: AsyncExitStack | None = None
+ self._entered = False
+
+ @property
+ def schemas(self) -> list[dict[str, Any]]:
+ """The union of every connected server's schemas, filtered by ``allowed_tools``."""
+ out: list[dict[str, Any]] = []
+ for client in self._clients:
+ for schema in client.schemas:
+ if self._allowed_tools is not None and schema["name"] not in self._allowed_tools:
+ continue
+ out.append(schema)
+ return out
+
+ async def __aenter__(self) -> MCPToolBackend:
+ """
+ Connect each underlying client through a shared ``AsyncExitStack`` and build the routing table.
+
+ Returns:
+ MCPToolBackend: *self*, ready to dispatch.
+
+ Raises:
+ ValueError: When two connected clients advertise the same tool
+ name without a disambiguating ``name_prefix``.
+ """
+ stack = AsyncExitStack()
+ clients: list[MCPClient] = []
+ routing: dict[str, tuple[MCPClient, str]] = {}
+ try:
+ for spec in self._servers:
+ client = MCPClient(spec=spec)
+ await stack.enter_async_context(client)
+ clients.append(client)
+ for advertised_name in client.tool_names:
+ if advertised_name in routing:
+ raise ValueError(
+ f"duplicate tool name '{advertised_name}'. "
+ "Set LocalMCPServerSpec.name_prefix on at least one "
+ "colliding server to disambiguate.",
+ )
+ routing[advertised_name] = (client, advertised_name)
+ except Exception:
+ await stack.aclose()
+ raise
+
+ self._stack = stack
+ self._clients = clients
+ self._routing = routing
+ self._entered = True
+ return self
+
+ async def __aexit__(self, *exc: Any) -> None:
+ """Tear down every underlying client in strict LIFO order."""
+ stack = self._stack
+ self._stack = None
+ self._clients = []
+ self._routing = {}
+ self._entered = False
+ if stack is not None:
+ await stack.aclose()
+
+ async def dispatch_async(self, call: ToolCall) -> dict[str, Any]:
+ """
+ Route *call* to the correct client and dispatch.
+
+ See ``MCPClient.dispatch_async`` for the envelope shape.
+ Allow-list rejections and unknown-tool calls return error
+ envelopes; only "backend not entered" raises.
+
+ Args:
+ call (ToolCall): The call to dispatch.
+
+ Returns:
+ dict[str, Any]: A structured envelope (success, ``tool_not_allowed``,
+ ``tool_not_registered``, or the underlying
+ ``MCPClient.dispatch_async`` envelope).
+
+ Raises:
+ RuntimeError: When the backend has not been entered via ``async with``.
+ """
+ if not self._entered:
+ raise RuntimeError(
+ "MCPToolBackend is not active. Use `async with backend:` to manage its lifecycle before dispatching.",
+ )
+
+ if self._allowed_tools is not None and call.name not in self._allowed_tools:
+ logger.info("Rejecting disallowed tool call: %s", call.name)
+ return {
+ "is_error": True,
+ "error": "tool_not_allowed",
+ "tool": call.name,
+ "allowed_tools": sorted(self._allowed_tools),
+ }
+
+ route = self._routing.get(call.name)
+ if route is None:
+ available = sorted(self._routing.keys())
+ logger.warning("Tool '%s' not registered. Available: %s", call.name, available)
+ return {
+ "is_error": True,
+ "error": "tool_not_registered",
+ "tool": call.name,
+ "available_tools": available,
+ }
+
+ client, _server_side_name = route
+ async with self._dispatch_lock:
+ return await client.dispatch_async(call)
diff --git a/pyrit/tools/mcp_client.py b/pyrit/tools/mcp_client.py
new file mode 100644
index 0000000000..0f9eeadcc3
--- /dev/null
+++ b/pyrit/tools/mcp_client.py
@@ -0,0 +1,369 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
+"""
+Stdio-transport client for the Model Context Protocol (MCP).
+
+This module is the wire-protocol half of PyRIT's MCP integration. It
+sits below ``MCPToolBackend`` (which composes one
+``MCPClient`` per configured server and handles cross-server
+routing) and above the upstream ``mcp`` Python SDK (which owns the
+JSON-RPC framing, capability negotiation, and asyncio task plumbing).
+
+The three ``MCPServerSpec`` variants describe *where* the server
+runs. Only ``LocalMCPServerSpec`` is implemented in this commit:
+
+* ``LocalMCPServerSpec`` — spawn the server as a child process and
+ speak JSON-RPC over its stdin/stdout.
+* ``RemoteMCPServerSpec`` — HTTP/SSE transport against a hosted
+ server. Stub: ``connect_async`` raises ``NotImplementedError``.
+* ``DockerMCPServerSpec`` — stdio over ``docker run -i`` against a
+ hardened sandbox container. Stub: ``connect_async`` raises
+ ``NotImplementedError``. Implementation lands in the follow-up
+ sandbox PR.
+
+The stub variants are intentionally part of the type union today so
+downstream code can be written against the eventual API without
+forcing a Union expansion later.
+"""
+
+from __future__ import annotations
+
+import asyncio
+import logging
+from contextlib import AsyncExitStack
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, Any
+
+from mcp import ClientSession
+from mcp.client.stdio import StdioServerParameters, stdio_client
+
+if TYPE_CHECKING:
+ from pyrit.tools.models import ToolCall
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass(frozen=True)
+class LocalMCPServerSpec:
+ """
+ Spec for an MCP server spawned as a child process and reached via
+ stdio JSON-RPC.
+
+ Attributes:
+ command (str): The interpreter or binary to exec (e.g. ``"python"``).
+ args (tuple[str, ...]): Arguments passed to *command*, in order.
+ env (dict[str, str] | None): Environment overlay for the child
+ process. ``None`` (default) inherits PyRIT's environment.
+ name_prefix (str | None): When set, every tool advertised by the
+ server is registered as ``f"{name_prefix}{tool_name}"`` in
+ the parent ``MCPToolBackend``. Used to
+ disambiguate two servers that expose the same tool name.
+ timeout_seconds (float): Per-call timeout, enforced by
+ ``MCPClient.dispatch_async``. Defaults to 30 seconds.
+ """
+
+ command: str
+ args: tuple[str, ...] = ()
+ env: dict[str, str] | None = None
+ name_prefix: str | None = None
+ timeout_seconds: float = 30.0
+
+
+@dataclass(frozen=True)
+class RemoteMCPServerSpec:
+ """
+ Spec for an MCP server reached over HTTP / SSE. **Not implemented**
+ in this PR — ``MCPClient.connect_async`` raises
+ ``NotImplementedError``. Tracked by ``# TODO(mcp-http-transport)``.
+
+ Attributes:
+ url (str): The base URL of the MCP server.
+ name_prefix (str | None): Same semantics as
+ ``LocalMCPServerSpec.name_prefix``.
+ timeout_seconds (float): Per-call timeout.
+ """
+
+ url: str
+ name_prefix: str | None = None
+ timeout_seconds: float = 30.0
+
+
+# TODO(sandbox-provider) — DockerMCPServerSpec stub here; implementation lands in follow-up PR.
+@dataclass(frozen=True)
+class DockerMCPServerSpec:
+ """
+ Spec for an MCP server hosted inside a hardened Docker container.
+
+ **NOT IMPLEMENTED IN THIS PR.** Reached via stdio over ``docker run -i``.
+
+ Expected behavior in the follow-up sandbox PR:
+
+ * One container per spec instance, managed by a process-wide
+ ``SandboxPool``.
+ * Image is built lazily, keyed by ``sha256(Dockerfile + build_context)``,
+ and cached across attacks; no rebuild unless missing or explicitly
+ overridden.
+ * Container is recreated from the cached image at attack and scenario
+ boundaries (filesystem returns to baseline every time).
+ * Network access governed by ``NetworkProfile`` (default ``"none"`` =
+ ``--network=none``).
+ * Container runs as a non-root UID with ``--cap-drop=ALL``, a read-only
+ root filesystem, and an in-container MCP server exposing
+ ``run_shell(cmd, timeout_seconds)``.
+
+ Attributes:
+ image (str): Docker image tag (e.g. ``"pyrit-sandbox:base"``).
+ network_profile (str): ``NetworkProfile`` name; ``"none"`` (default)
+ launches the container with ``--network=none``.
+ name_prefix (str | None): Same semantics as
+ ``LocalMCPServerSpec.name_prefix``.
+ timeout_seconds (float): Per-call timeout.
+
+ Future fields (deferred to the follow-up sandbox PR): ``memory_limit``,
+ ``cpu_limit``, ``pids_limit``, ``env``, ``mounts``, ``command_override``.
+ """
+
+ image: str
+ network_profile: str = "none"
+ name_prefix: str | None = None
+ timeout_seconds: float = 30.0
+
+
+MCPServerSpec = LocalMCPServerSpec | RemoteMCPServerSpec | DockerMCPServerSpec
+
+
+def _to_input_schema_dict(input_schema: Any) -> dict[str, Any]:
+ """
+ Coerce the SDK's tool ``inputSchema`` (pydantic model or dict) into a plain dict.
+
+ Returns:
+ dict[str, Any]: A plain-dict copy of *input_schema*, or an empty
+ object schema when *input_schema* is None or of an unrecognized type.
+ """
+ if input_schema is None:
+ return {"type": "object", "properties": {}}
+ if hasattr(input_schema, "model_dump"):
+ return input_schema.model_dump()
+ if isinstance(input_schema, dict):
+ return dict(input_schema)
+ return {"type": "object", "properties": {}}
+
+
+def _flatten_content(content: list[Any]) -> str:
+ """
+ Concatenate the text portions of an MCP ``CallToolResult.content`` list.
+
+ Returns:
+ str: Concatenated ``.text`` values from each content item, in order.
+ """
+ pieces: list[str] = []
+ for item in content:
+ text = getattr(item, "text", None)
+ if text is not None:
+ pieces.append(text)
+ elif isinstance(item, dict) and "text" in item:
+ pieces.append(item["text"])
+ return "".join(pieces)
+
+
+class MCPClient:
+ """
+ A single MCP-server session.
+
+ The client owns the lifetime of one server's transport stack and
+ exposes a uniform ``dispatch_async`` regardless of which
+ ``MCPServerSpec`` variant it was constructed from. Composition
+ across multiple servers (routing, schema aggregation, allow-lists)
+ is the responsibility of ``MCPToolBackend``.
+
+ Lifecycle:
+
+ * ``connect_async`` spawns the subprocess (for
+ ``LocalMCPServerSpec``), runs the MCP handshake, and caches
+ ``tools/list`` results.
+ * ``dispatch_async`` issues one ``tools/call`` and returns a
+ structured envelope (success or error).
+ * ``close_async`` tears down the transport stack.
+
+ The class is usable as an async context manager.
+ """
+
+ def __init__(self, *, spec: MCPServerSpec) -> None:
+ """
+ Initialize the client around *spec*. Does not connect; call
+ ``connect_async`` (or use the async context-manager form) to start
+ the transport stack.
+ """
+ self._spec = spec
+ self._stack = AsyncExitStack()
+ self._session: ClientSession | None = None
+ self._tools: list[Any] = []
+
+ @property
+ def spec(self) -> MCPServerSpec:
+ """The ``MCPServerSpec`` this client was constructed with."""
+ return self._spec
+
+ @property
+ def schemas(self) -> list[dict[str, Any]]:
+ """
+ JSON schemas for every tool the server advertises.
+
+ Each schema is shaped ``{"name", "description", "parameters"}``.
+ The optional ``LocalMCPServerSpec.name_prefix`` is applied
+ here so a backend that owns this client sees the prefixed name.
+ """
+ prefix = getattr(self._spec, "name_prefix", None) or ""
+ return [
+ {
+ "name": f"{prefix}{tool.name}",
+ "description": tool.description or "",
+ "parameters": _to_input_schema_dict(tool.inputSchema),
+ }
+ for tool in self._tools
+ ]
+
+ @property
+ def tool_names(self) -> list[str]:
+ """Tool names with the spec's ``name_prefix`` applied."""
+ return [s["name"] for s in self.schemas]
+
+ def _strip_prefix(self, name: str) -> str:
+ prefix = getattr(self._spec, "name_prefix", None) or ""
+ if prefix and name.startswith(prefix):
+ return name[len(prefix) :]
+ return name
+
+ async def connect_async(self) -> None:
+ """Establish the transport, run the handshake, and cache schemas."""
+ if isinstance(self._spec, RemoteMCPServerSpec):
+ raise NotImplementedError(
+ "HTTP/SSE transport ships in a follow-up PR. "
+ "RemoteMCPServerSpec is declared today so user code can target the eventual API."
+ )
+ if isinstance(self._spec, DockerMCPServerSpec):
+ raise NotImplementedError(
+ "Docker sandbox transport ships in a follow-up PR. "
+ "DockerMCPServerSpec runs the MCP server inside a hardened "
+ "Debian container reached via stdio over `docker run -i`, "
+ "managed by a process-wide SandboxPool with image caching and "
+ "per-attack container recreation."
+ )
+
+ assert isinstance(self._spec, LocalMCPServerSpec)
+ params = StdioServerParameters(
+ command=self._spec.command,
+ args=list(self._spec.args),
+ env=self._spec.env,
+ )
+ read, write = await self._stack.enter_async_context(stdio_client(params))
+ session = await self._stack.enter_async_context(ClientSession(read, write))
+ await session.initialize()
+ result = await session.list_tools()
+ self._session = session
+ self._tools = list(result.tools)
+
+ async def close_async(self) -> None:
+ """Tear down the transport stack. Idempotent; safe to call before connect."""
+ try:
+ await self._stack.aclose()
+ except Exception as ex: # noqa: BLE001 — close should never raise into the caller.
+ logger.warning("Error tearing down MCP client stack: %s", ex)
+ finally:
+ self._stack = AsyncExitStack()
+ self._session = None
+ self._tools = []
+
+ async def __aenter__(self) -> MCPClient:
+ """
+ Connect the transport stack and return *self*.
+
+ Returns:
+ MCPClient: *self*, ready to dispatch tool calls.
+ """
+ await self.connect_async()
+ return self
+
+ async def __aexit__(self, *exc: Any) -> None:
+ """Tear down the transport stack."""
+ await self.close_async()
+
+ async def dispatch_async(self, call: ToolCall) -> dict[str, Any]:
+ """
+ Issue one ``tools/call`` and return a structured envelope.
+
+ Envelope shape:
+
+ * Success: ``{"is_error": False, "content": str, "tool": name}``.
+ * Timeout: ``{"is_error": True, "error": "tool_timeout", "tool": name, ...}``.
+ * Server-reported error: ``{"is_error": True, "error": "tool_execution_failed", "tool": name, ...}``.
+
+ Tool-side failures are converted to envelopes; only programmer
+ errors (calling before ``connect_async``) raise.
+
+ Args:
+ call (ToolCall): The call to dispatch. The advertised
+ ``name_prefix`` (if any) is stripped before contacting the server.
+
+ Returns:
+ dict[str, Any]: One of the envelope shapes documented above.
+
+ Raises:
+ RuntimeError: When the client has not been connected.
+ """
+ if self._session is None:
+ raise RuntimeError("MCPClient is not connected; call connect_async first.")
+
+ server_side_name = self._strip_prefix(call.name)
+ timeout = getattr(self._spec, "timeout_seconds", 30.0)
+ try:
+ result = await asyncio.wait_for(
+ self._session.call_tool(server_side_name, arguments=dict(call.arguments)),
+ timeout=timeout,
+ )
+ except asyncio.TimeoutError:
+ logger.warning(
+ "MCP tool '%s' timed out after %.2fs",
+ call.name,
+ timeout,
+ )
+ return {
+ "is_error": True,
+ "error": "tool_timeout",
+ "tool": call.name,
+ "timeout_seconds": timeout,
+ }
+ except Exception as ex: # noqa: BLE001 — wrap and surface as envelope.
+ logger.warning(
+ "MCP tool '%s' raised %s: %s",
+ call.name,
+ type(ex).__name__,
+ ex,
+ )
+ return {
+ "is_error": True,
+ "error": "tool_execution_failed",
+ "tool": call.name,
+ "detail": str(ex),
+ }
+
+ content_text = _flatten_content(list(result.content))
+ is_error = bool(getattr(result, "isError", False))
+ envelope: dict[str, Any] = {
+ "is_error": is_error,
+ "content": content_text,
+ "tool": call.name,
+ }
+ if is_error:
+ envelope["error"] = "tool_execution_failed"
+ return envelope
+
+
+__all__ = [
+ "DockerMCPServerSpec",
+ "LocalMCPServerSpec",
+ "MCPClient",
+ "MCPServerSpec",
+ "RemoteMCPServerSpec",
+]
diff --git a/pyrit/tools/models.py b/pyrit/tools/models.py
new file mode 100644
index 0000000000..187bf635d9
--- /dev/null
+++ b/pyrit/tools/models.py
@@ -0,0 +1,241 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
+from __future__ import annotations
+
+import enum
+import functools
+import json
+import logging
+from dataclasses import dataclass, field
+from typing import TYPE_CHECKING, Any
+
+from pyrit.exceptions import ToolCallLoopLimitExceeded, ToolCallNotSupported
+from pyrit.models import Message, MessagePiece
+
+if TYPE_CHECKING:
+ from collections.abc import Awaitable, Callable
+
+ from pyrit.tools.backend import ToolBackend
+ from pyrit.tools.parsers import ToolCallParser
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass(frozen=True)
+class ToolCall:
+ """
+ A parsed tool call extracted from a target response.
+
+ Concrete ``ToolCallParser`` implementations build
+ ``ToolCall`` instances by walking the response message pieces.
+ The ``raw_envelope`` carries the original target-specific dict
+ (e.g. the function_call JSON section) so dispatchers and observers
+ can recover provider-specific fields without re-parsing.
+
+ Attributes:
+ call_id (str): The provider-issued call identifier; must round-trip
+ into the matching ``function_call_output`` piece.
+ name (str): The tool name to dispatch.
+ arguments (dict[str, Any]): The parsed JSON arguments.
+ raw_envelope (dict[str, Any]): The original provider envelope.
+ """
+
+ call_id: str
+ name: str
+ arguments: dict[str, Any]
+ raw_envelope: dict[str, Any] = field(default_factory=dict)
+
+
+class ToolEventBehavior(enum.Enum):
+ """
+ What the tool loop should do when a target response contains a
+ pending tool call.
+
+ Values:
+ EXECUTE: Dispatch the call via ``configuration.tool_backend``
+ and re-enter the target with the tool output appended.
+ This is the standard agentic loop behavior.
+ RAISE: Raise ``ToolCallNotSupported`` with
+ the partial conversation attached. Useful for red-team
+ attacks that want to observe attempted tool use without
+ allowing execution.
+ RETURN_RAW: Return the assistant response containing the tool
+ call as-is, without dispatching. Useful when a caller wants
+ to inspect tool calls in-band (e.g. a scorer that scores
+ attempted tool use).
+ """
+
+ EXECUTE = "execute"
+ RAISE = "raise"
+ RETURN_RAW = "return_raw"
+
+
+@dataclass(frozen=True)
+class ToolEventPolicy:
+ """
+ Per-target configuration that controls how the tool loop responds
+ to a pending tool call from the model.
+
+ Attributes:
+ behavior (ToolEventBehavior): What to do on each detected tool call.
+ max_tool_iterations (int): Maximum number of model<->tool round-trips
+ before the loop raises ``ToolCallLoopLimitExceeded``. Each
+ iteration is one ``_send_prompt_to_target_async`` call.
+ """
+
+ behavior: ToolEventBehavior
+ max_tool_iterations: int = 5
+
+
+def _build_function_call_output_message(
+ *,
+ reference_piece: MessagePiece,
+ outputs: list[tuple[ToolCall, Any]],
+) -> Message:
+ """
+ Build the canonical ``tool`` message produced after dispatching one or more
+ tool calls in a single iteration.
+
+ The returned ``Message`` contains one
+ ``MessagePiece`` per ``(call, result)`` pair, in declaration order.
+ Every piece has ``role="tool"`` and ``original_value_data_type="function_call_output"``,
+ with the JSON envelope ``{"type": "function_call_output", "call_id": ..., "output": ...}``.
+
+ Lineage metadata (conversation_id, identifiers) is copied from
+ *reference_piece* — typically the first piece of the assistant message
+ that issued the tool calls — so the resulting message stays inside the
+ correct conversation.
+
+ Args:
+ reference_piece (MessagePiece): Piece whose lineage metadata is
+ copied onto every output piece. Pass the first piece of the
+ assistant message that produced the calls.
+ outputs (list[tuple[ToolCall, Any]]): ``(call, result)`` pairs in
+ declaration order. *result* is serialized via ``json.dumps``
+ unless it is already a string.
+
+ Returns:
+ Message: One message carrying every function_call_output piece.
+ """
+ pieces: list[MessagePiece] = []
+ for call, result in outputs:
+ output_str = result if isinstance(result, str) else json.dumps(result, separators=(",", ":"))
+ envelope = json.dumps(
+ {"type": "function_call_output", "call_id": call.call_id, "output": output_str},
+ separators=(",", ":"),
+ )
+ pieces.append(
+ MessagePiece(
+ role="tool",
+ original_value=envelope,
+ original_value_data_type="function_call_output",
+ conversation_id=reference_piece.conversation_id,
+ prompt_target_identifier=reference_piece.prompt_target_identifier,
+ attack_identifier=reference_piece.attack_identifier,
+ )
+ )
+ return Message(message_pieces=pieces, skip_validation=True)
+
+
+def tool_loop(
+ method: Callable[..., Awaitable[list[Message]]],
+) -> Callable[..., Awaitable[list[Message]]]:
+ """
+ Wrap a ``PromptTarget``-style
+ ``send_prompt_async`` to run an agentic tool-use loop.
+
+ When the target's ``configuration.tool_event_policy`` is ``None`` the
+ wrapper is a no-op — the wrapped method runs unchanged. When a policy
+ is configured, the wrapper replaces the method body with the loop:
+
+ 1. Validate and normalize the incoming message exactly once.
+ 2. Repeatedly call ``self._send_prompt_to_target_async`` with the
+ growing conversation.
+ 3. After each call, parse the last response via ``self._tool_parser``.
+ Exit on empty parse (model issued a stop response).
+ 4. On a non-empty parse, branch on ``policy.behavior``:
+ ``RAISE`` raises ``ToolCallNotSupported``; ``RETURN_RAW``
+ returns the chain as-is; ``EXECUTE`` dispatches the calls via
+ ``configuration.tool_backend`` and appends the tool message.
+ 5. Raise ``ToolCallLoopLimitExceeded`` if the loop runs past
+ ``policy.max_tool_iterations`` without the model stopping.
+
+ The decorator deliberately knows nothing about MCP, OpenAI, or any
+ specific transport. The two collaborators it requires —
+ ``self._tool_parser`` and ``self.configuration.tool_backend`` — are
+ plain protocols (``ToolCallParser``, ``ToolBackend``).
+
+ Args:
+ method (Callable): The async method to wrap. Must have the
+ ``async def f(self, *, message: Message) -> list[Message]``
+ signature of ``PromptTarget.send_prompt_async``.
+
+ Returns:
+ Callable: The wrapped method.
+ """
+
+ @functools.wraps(method)
+ async def wrapper(self: Any, *, message: Message) -> list[Message]:
+ policy: ToolEventPolicy | None = getattr(self.configuration, "tool_event_policy", None)
+ if policy is None:
+ return await method(self, message=message)
+
+ message.validate()
+ normalized_conversation = await self._get_normalized_conversation_async(message=message)
+ if not normalized_conversation:
+ raise ValueError("Normalization pipeline returned an empty conversation. Cannot send an empty request.")
+ self._validate_request(normalized_conversation=normalized_conversation)
+
+ parser: ToolCallParser | None = getattr(self, "_tool_parser", None)
+ backend: ToolBackend | None = getattr(self.configuration, "tool_backend", None)
+ max_iter = policy.max_tool_iterations
+
+ all_responses: list[Message] = []
+
+ for _ in range(max_iter):
+ responses_this_turn = await self._send_prompt_to_target_async(
+ normalized_conversation=normalized_conversation,
+ )
+ all_responses.extend(responses_this_turn)
+
+ if parser is None:
+ return all_responses
+
+ last_response = responses_this_turn[-1]
+ pending_calls = parser.parse(last_response)
+
+ if not pending_calls:
+ return all_responses
+
+ if policy.behavior is ToolEventBehavior.RAISE:
+ raise ToolCallNotSupported(
+ message=(
+ f"Target produced {len(pending_calls)} tool call(s) but ToolEventPolicy.behavior is RAISE."
+ ),
+ partial_conversation=all_responses,
+ )
+
+ if policy.behavior is ToolEventBehavior.RETURN_RAW:
+ return all_responses
+
+ if backend is None:
+ raise ToolCallNotSupported(
+ message=(f"Target produced {len(pending_calls)} tool call(s) but no tool_backend is configured."),
+ partial_conversation=all_responses,
+ )
+
+ results = await backend.dispatch_all_sequential_async(pending_calls)
+ tool_msg = _build_function_call_output_message(
+ reference_piece=last_response.message_pieces[0],
+ outputs=results,
+ )
+ all_responses.append(tool_msg)
+ normalized_conversation = list(normalized_conversation) + [last_response, tool_msg]
+
+ raise ToolCallLoopLimitExceeded(
+ message=f"Tool loop exceeded max_tool_iterations={max_iter} without a stop response.",
+ partial_conversation=all_responses,
+ )
+
+ return wrapper
diff --git a/pyrit/tools/parsers.py b/pyrit/tools/parsers.py
new file mode 100644
index 0000000000..8f0038d5a0
--- /dev/null
+++ b/pyrit/tools/parsers.py
@@ -0,0 +1,107 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
+from __future__ import annotations
+
+import json
+from typing import TYPE_CHECKING, Protocol, runtime_checkable
+
+if TYPE_CHECKING:
+ from pyrit.models import Message, MessagePiece
+ from pyrit.tools.models import ToolCall
+
+
+@runtime_checkable
+class ToolCallParser(Protocol):
+ """
+ Protocol for extracting tool calls from a target response message.
+
+ Concrete parsers live next to the target whose response shape they
+ understand (the canonical-envelope parser shipped here is shared by
+ ``OpenAIResponseTarget``; per-model-family parsers for non-OpenAI
+ targets ship in a follow-up, see plan §12.9). Parsers MUST return an
+ empty list when the model has issued a stop response — the tool loop
+ uses the empty list as the signal to exit.
+ """
+
+ def parse(self, message: Message) -> list[ToolCall]:
+ """
+ Extract tool calls from a target response message.
+
+ Args:
+ message (Message): The most recent assistant response.
+
+ Returns:
+ list[ToolCall]: Tool calls, in declaration order. An empty list
+ signals that the model produced a stop response.
+ """
+ ...
+
+
+def _extract_function_call_pieces(message: Message) -> list[MessagePiece]:
+ """
+ Return every ``MessagePiece`` in *message* whose
+ ``original_value_data_type`` is ``"function_call"``.
+
+ This is the canonical envelope used by every PyRIT-supported tool-emitting
+ target. It is exposed here so concrete parsers can reuse the filter rather
+ than re-implementing it.
+
+ Args:
+ message (Message): The message to scan.
+
+ Returns:
+ list[MessagePiece]: Pieces whose ``original_value_data_type`` is
+ ``"function_call"``, in their declaration order.
+ """
+ return [piece for piece in message.message_pieces if piece.original_value_data_type == "function_call"]
+
+
+class CanonicalEnvelopeParser:
+ """
+ Reference ``ToolCallParser`` for the canonical function_call envelope.
+
+ Walks every ``MessagePiece`` whose ``original_value_data_type`` is
+ ``"function_call"`` and decodes the canonical JSON shape::
+
+ {
+ "type": "function_call",
+ "call_id": "",
+ "name": "",
+ "arguments": ""
+ }
+
+ into ``ToolCall`` instances. Pieces of other data types -- reasoning,
+ mcp_call, web_search_call, etc. -- are ignored (they pass through to
+ Memory but are not client-side dispatchable). Per-model-family parsers
+ for non-OpenAI targets ship in a follow-up PR (see plan §12.9).
+ """
+
+ def parse(self, message: Message) -> list[ToolCall]:
+ """
+ Decode canonical ``function_call`` pieces in *message* into ``ToolCall``.
+
+ Args:
+ message (Message): The most recent assistant response.
+
+ Returns:
+ list[ToolCall]: One ``ToolCall`` per ``function_call``
+ piece, in declaration order. Empty if the message contains
+ no ``function_call`` pieces (model stop).
+ """
+ from pyrit.tools.models import ToolCall
+
+ calls: list[ToolCall] = []
+ for piece in _extract_function_call_pieces(message):
+ envelope = json.loads(piece.original_value)
+ arguments_raw = envelope.get("arguments", "{}")
+ arguments = json.loads(arguments_raw) if isinstance(arguments_raw, str) else dict(arguments_raw)
+ calls.append(
+ ToolCall(
+ call_id=envelope["call_id"],
+ name=envelope["name"],
+ arguments=arguments,
+ raw_envelope=envelope,
+ )
+ )
+ return calls
diff --git a/tests/integration/tools/__init__.py b/tests/integration/tools/__init__.py
new file mode 100644
index 0000000000..9a0454564d
--- /dev/null
+++ b/tests/integration/tools/__init__.py
@@ -0,0 +1,2 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
diff --git a/tests/integration/tools/test_azure_ml_with_tools_integration.py b/tests/integration/tools/test_azure_ml_with_tools_integration.py
new file mode 100644
index 0000000000..9f0b4903b7
--- /dev/null
+++ b/tests/integration/tools/test_azure_ml_with_tools_integration.py
@@ -0,0 +1,162 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
+"""F2 integration test: AzureMLChatTarget end-to-end through @tool_loop.
+
+Validates that PyRIT's full client-side tool-calling stack works against
+an AzureML chat target whose scoring script emits the canonical
+function_call envelope per plan §12.9.2. Only the outbound HTTP layer is
+mocked; the @tool_loop decorator, CanonicalEnvelopeParser, LocalToolBackend
+dispatch, ChatMessageNormalizer tool-piece serialization, and Memory
+persistence all run unmocked.
+
+The test asserts the canonical four-piece transcript shape:
+``[user, assistant function_call, tool function_call_output, assistant text]``.
+"""
+
+from __future__ import annotations
+
+import json
+from unittest.mock import AsyncMock, MagicMock, patch
+
+import pytest
+
+from pyrit.memory import CentralMemory
+from pyrit.models import Message, MessagePiece
+from pyrit.prompt_normalizer import PromptNormalizer
+from pyrit.prompt_target import AzureMLChatTarget
+from pyrit.prompt_target.common.target_capabilities import TargetCapabilities
+from pyrit.prompt_target.common.target_configuration import TargetConfiguration
+from pyrit.tools import (
+ CanonicalEnvelopeParser,
+ LocalToolBackend,
+ ToolEventBehavior,
+ ToolEventPolicy,
+)
+
+
+@pytest.fixture
+def echo_backend():
+ async def _echo(args):
+ return {"echoed": args.get("text", "")}
+
+ return LocalToolBackend(
+ callables={"echo": _echo},
+ schemas=[
+ {
+ "name": "echo",
+ "description": "Echo back the given text.",
+ "parameters": {
+ "type": "object",
+ "properties": {"text": {"type": "string"}},
+ "required": ["text"],
+ },
+ }
+ ],
+ )
+
+
+@pytest.mark.run_only_if_all_tests
+async def test_azure_ml_chat_target_tool_loop_round_trip(patch_central_database, echo_backend):
+ """User asks for a tool call; the loop dispatches; the model produces final text."""
+ target = AzureMLChatTarget(
+ endpoint="https://mock-endpoint.example.com/score",
+ api_key="dummy",
+ tool_parser=CanonicalEnvelopeParser(),
+ tool_backend=echo_backend,
+ custom_configuration=TargetConfiguration(
+ capabilities=TargetCapabilities(
+ supports_multi_message_pieces=True,
+ supports_editable_history=True,
+ supports_multi_turn=True,
+ supports_system_prompt=True,
+ supports_tool_use=True,
+ ),
+ tool_event_policy=ToolEventPolicy(
+ behavior=ToolEventBehavior.EXECUTE,
+ max_tool_iterations=3,
+ ),
+ tool_backend=echo_backend,
+ ),
+ )
+
+ first_response = MagicMock()
+ first_response.json.return_value = {
+ "output": "",
+ "tool_calls": [
+ {
+ "type": "function_call",
+ "call_id": "call_0",
+ "name": "echo",
+ "arguments": '{"text":"hi"}',
+ }
+ ],
+ }
+ second_response = MagicMock()
+ second_response.json.return_value = {"output": "The echoed text is: hi"}
+
+ user_msg = Message(
+ message_pieces=[
+ MessagePiece(
+ role="user",
+ original_value="Use the echo tool to repeat 'hi'.",
+ original_value_data_type="text",
+ )
+ ]
+ )
+
+ with patch(
+ "pyrit.common.net_utility.make_request_and_raise_if_error_async",
+ new_callable=AsyncMock,
+ ) as mock_http:
+ mock_http.side_effect = [first_response, second_response]
+ normalizer = PromptNormalizer()
+ result = await normalizer.send_prompt_async(message=user_msg, target=target)
+
+ assert mock_http.call_count == 2
+ final_text = result.get_value()
+ assert "The echoed text is" in final_text
+
+ conv = CentralMemory.get_memory_instance().get_conversation(
+ conversation_id=result.message_pieces[0].conversation_id
+ )
+ # Canonical four-piece transcript: user -> assistant fc -> tool fco -> assistant text.
+ flat_pieces = [p for msg in conv for p in msg.message_pieces]
+ types = [p.original_value_data_type for p in flat_pieces]
+ roles = [p.api_role for p in flat_pieces]
+ assert types == ["text", "function_call", "function_call_output", "text"]
+ assert roles == ["user", "assistant", "tool", "assistant"]
+
+ # call_id round-trips between the assistant and tool messages.
+ fc_envelope = json.loads(flat_pieces[1].original_value)
+ fco_envelope = json.loads(flat_pieces[2].original_value)
+ assert fc_envelope["call_id"] == fco_envelope["call_id"] == "call_0"
+ # Tool dispatched the local echo callable; the output reflects the args.
+ assert json.loads(fco_envelope["output"])["echoed"] == "hi"
+
+
+@pytest.mark.run_only_if_all_tests
+async def test_azure_ml_chat_target_no_tools_backward_compat(patch_central_database):
+ """Without tool_parser / tool_backend, the request body has no tools key."""
+ target = AzureMLChatTarget(
+ endpoint="https://mock-endpoint.example.com/score",
+ api_key="dummy",
+ )
+ user_msg = Message(
+ message_pieces=[MessagePiece(role="user", original_value="Hello", original_value_data_type="text")]
+ )
+
+ response = MagicMock()
+ response.json.return_value = {"output": "Hi back"}
+
+ with patch(
+ "pyrit.common.net_utility.make_request_and_raise_if_error_async",
+ new_callable=AsyncMock,
+ ) as mock_http:
+ mock_http.return_value = response
+ normalizer = PromptNormalizer()
+ result = await normalizer.send_prompt_async(message=user_msg, target=target)
+
+ assert result.get_value() == "Hi back"
+ body = mock_http.call_args.kwargs["request_body"]
+ assert "tools" not in body
diff --git a/tests/integration/tools/test_red_teaming_with_tools.py b/tests/integration/tools/test_red_teaming_with_tools.py
new file mode 100644
index 0000000000..9dca01371a
--- /dev/null
+++ b/tests/integration/tools/test_red_teaming_with_tools.py
@@ -0,0 +1,361 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
+"""C7 integration tests: RedTeamingAttack with real tool dispatch.
+
+These tests spawn the real ``tests/unit/tools/echo_mcp_server.py`` subprocess
+and exercise the full client-side tool-calling stack:
+
+ attack -> normalizer -> target -> @tool_loop wrapper -> MCPToolBackend ->
+ MCPClient (stdio) -> echo subprocess -> tool result -> back through the
+ wrapper -> Memory.
+
+Only the OpenAI Responses HTTP layer is mocked. The MCP subprocess, the
+MCPToolBackend lock, the AsyncExitStack lifecycle, the canonical envelope
+round-trip, and the @tool_loop decorator's RedTeam-attack invocation path
+all execute under their real implementations.
+"""
+
+from __future__ import annotations
+
+import json
+import pathlib
+import sys
+import uuid
+from unittest.mock import AsyncMock, MagicMock, patch
+
+import pytest
+
+from pyrit.executor.attack.core.attack_config import AttackAdversarialConfig, AttackScoringConfig
+from pyrit.executor.attack.multi_turn.red_teaming import RedTeamingAttack
+from pyrit.identifiers import ComponentIdentifier
+from pyrit.memory import CentralMemory
+from pyrit.models import Message, MessagePiece, Score
+from pyrit.prompt_target import OpenAIResponseTarget
+from pyrit.prompt_target.common.prompt_target import PromptTarget
+from pyrit.prompt_target.common.target_capabilities import TargetCapabilities
+from pyrit.prompt_target.common.target_configuration import TargetConfiguration
+from pyrit.score.true_false.true_false_scorer import TrueFalseScorer
+from pyrit.tools import (
+ LocalMCPServerSpec,
+ MCPToolBackend,
+ ToolEventBehavior,
+ ToolEventPolicy,
+)
+
+
+def _mock_id(name: str) -> ComponentIdentifier:
+ return ComponentIdentifier(class_name=name, class_module="test")
+
+
+ECHO_SERVER_PATH = pathlib.Path(__file__).resolve().parents[2] / "unit" / "tools" / "echo_mcp_server.py"
+
+
+def _local_echo_spec() -> LocalMCPServerSpec:
+ """Build a LocalMCPServerSpec that launches the in-tree echo server."""
+ return LocalMCPServerSpec(
+ command=sys.executable,
+ args=(str(ECHO_SERVER_PATH),),
+ )
+
+
+def _mock_function_call_response(call_id: str, function_name: str, arguments: dict) -> MagicMock:
+ """Build a fake Responses-API response containing a function_call section."""
+ response = MagicMock()
+ response.status = "completed"
+ response.error = None
+ section = MagicMock()
+ section.type = "function_call"
+ section.call_id = call_id
+ section.name = function_name
+ section.arguments = json.dumps(arguments)
+ section.model_dump.return_value = {
+ "type": "function_call",
+ "call_id": call_id,
+ "name": function_name,
+ "arguments": json.dumps(arguments),
+ }
+ response.output = [section]
+ return response
+
+
+def _mock_text_response(text: str) -> MagicMock:
+ """Build a fake Responses-API response containing a message section."""
+ response = MagicMock()
+ response.status = "completed"
+ response.error = None
+ section = MagicMock()
+ section.type = "message"
+ section.content = [MagicMock(text=text)]
+ response.output = [section]
+ return response
+
+
+def _make_response_target_with_mcp_backend(
+ backend: MCPToolBackend,
+) -> OpenAIResponseTarget:
+ """Build an OpenAIResponseTarget wired to the live MCP backend."""
+ caps = TargetCapabilities(
+ supports_multi_turn=True,
+ supports_editable_history=True,
+ supports_json_output=True,
+ supports_multi_message_pieces=True,
+ supports_system_prompt=True,
+ supports_tool_use=True,
+ input_modalities=frozenset(
+ {
+ frozenset(["text"]),
+ frozenset(["text", "image_path"]),
+ frozenset(["function_call"]),
+ frozenset(["tool_call"]),
+ frozenset(["function_call_output"]),
+ frozenset(["reasoning"]),
+ }
+ ),
+ )
+ config = TargetConfiguration(
+ capabilities=caps,
+ tool_event_policy=ToolEventPolicy(behavior=ToolEventBehavior.EXECUTE, max_tool_iterations=5),
+ tool_backend=backend,
+ )
+ return OpenAIResponseTarget(
+ model_name="gpt-4",
+ endpoint="https://mock.example.com",
+ api_key="mock-key",
+ custom_configuration=config,
+ )
+
+
+def _scripted_adversarial(prompts: list[str]) -> MagicMock:
+ """Build a mock adversarial target that returns scripted prompts."""
+ adversarial = MagicMock(spec=PromptTarget)
+ adversarial.send_prompt_async = AsyncMock(
+ side_effect=[
+ [
+ Message(
+ message_pieces=[
+ MessagePiece(
+ role="assistant",
+ original_value=p,
+ original_value_data_type="text",
+ conversation_id=str(uuid.uuid4()),
+ )
+ ]
+ )
+ ]
+ for p in prompts
+ ]
+ )
+ adversarial.get_identifier.return_value = _mock_id("MockAdversarial")
+ adversarial.set_system_prompt = MagicMock()
+ return adversarial
+
+
+def _success_scorer() -> MagicMock:
+ """Mock objective scorer that always returns True (objective met)."""
+ scorer = MagicMock(spec=TrueFalseScorer)
+ scorer.score_async = AsyncMock(
+ return_value=[
+ Score(
+ score_value="true",
+ score_value_description="objective met",
+ score_type="true_false",
+ score_category=["test"],
+ score_rationale="mock rationale",
+ score_metadata={},
+ message_piece_id=str(uuid.uuid4()),
+ scorer_class_identifier=_mock_id("MockScorer"),
+ )
+ ]
+ )
+ scorer.get_identifier.return_value = _mock_id("MockScorer")
+ return scorer
+
+
+@pytest.mark.asyncio
+async def test_red_teaming_response_target_with_mcp_echo(patch_central_database):
+ """End-to-end: RedTeamingAttack drives OpenAIResponseTarget with MCPToolBackend.
+
+ The Response target's HTTP layer is mocked to return a function_call for
+ the echo tool, followed by a stop response after the tool result arrives.
+ The MCP subprocess actually executes the echo call.
+ """
+ backend = MCPToolBackend(servers=[_local_echo_spec()])
+ async with backend:
+ objective_target = _make_response_target_with_mcp_backend(backend)
+
+ # Mock the OpenAI Responses HTTP layer on the objective target.
+ responses = [
+ _mock_function_call_response("call_1", "echo", {"text": "hello"}),
+ _mock_text_response("Echoed: hello"),
+ ]
+ seen = []
+
+ async def mock_create(**kwargs):
+ seen.append(kwargs)
+ return responses[len(seen) - 1]
+
+ # Adversarial returns one prompt (RedTeamingAttack stops after objective is met)
+ adversarial = _scripted_adversarial(["please echo hello"])
+
+ attack = RedTeamingAttack(
+ objective_target=objective_target,
+ attack_adversarial_config=AttackAdversarialConfig(target=adversarial),
+ attack_scoring_config=AttackScoringConfig(objective_scorer=_success_scorer()),
+ )
+
+ with patch.object(
+ objective_target._async_client.responses, "create", new_callable=AsyncMock
+ ) as mock_create_call:
+ mock_create_call.side_effect = mock_create
+ result = await attack.execute_async(objective="get the model to echo 'hello'")
+
+ # Two HTTP calls to the Response API: initial + post-tool
+ assert len(seen) == 2
+ # Second call must include the function_call_output (tool result)
+ second_input = seen[1]["input"]
+ function_outputs = [item for item in second_input if item.get("type") == "function_call_output"]
+ assert len(function_outputs) == 1
+ # The output JSON contains the text "hello" because the real MCP echo
+ # subprocess returned it
+ assert "hello" in function_outputs[0]["output"]
+ assert result is not None
+
+
+@pytest.mark.asyncio
+async def test_red_teaming_persists_canonical_transcript_in_memory(patch_central_database):
+ """End-to-end: after a successful tool dispatch the DB shows the full chain.
+
+ Verifies the canonical envelope contract (§13): the conversation written
+ to Memory must contain the user message, the assistant function_call, the
+ tool function_call_output (with matching call_id), and the assistant's
+ final text -- in that order.
+ """
+ backend = MCPToolBackend(servers=[_local_echo_spec()])
+ async with backend:
+ objective_target = _make_response_target_with_mcp_backend(backend)
+
+ responses = [
+ _mock_function_call_response("call_xyz", "echo", {"text": "world"}),
+ _mock_text_response("Echoed: world"),
+ ]
+ seen = []
+
+ async def mock_create(**kwargs):
+ seen.append(kwargs)
+ return responses[len(seen) - 1]
+
+ adversarial = _scripted_adversarial(["echo world"])
+
+ attack = RedTeamingAttack(
+ objective_target=objective_target,
+ attack_adversarial_config=AttackAdversarialConfig(target=adversarial),
+ attack_scoring_config=AttackScoringConfig(objective_scorer=_success_scorer()),
+ )
+
+ with patch.object(
+ objective_target._async_client.responses, "create", new_callable=AsyncMock
+ ) as mock_create_call:
+ mock_create_call.side_effect = mock_create
+ result = await attack.execute_async(objective="echo world")
+
+ # Read the conversation back from Memory
+ memory = CentralMemory.get_memory_instance()
+ assert result is not None
+ objective_conv_id = result.conversation_id
+ assert objective_conv_id, "Attack result must carry the objective-target conversation id"
+
+ pieces = list(memory.get_message_pieces(conversation_id=objective_conv_id))
+ # Filter out system prompts; we care about the user/assistant/tool chain
+ data_types_in_order = [p.original_value_data_type for p in pieces]
+ # The chain MUST contain function_call followed by function_call_output (canonical envelope)
+ assert "function_call" in data_types_in_order
+ assert "function_call_output" in data_types_in_order
+
+ fc_index = data_types_in_order.index("function_call")
+ fco_index = data_types_in_order.index("function_call_output")
+ assert fc_index < fco_index, "function_call must precede function_call_output in DB"
+
+ fc_envelope = json.loads(pieces[fc_index].original_value)
+ fco_envelope = json.loads(pieces[fco_index].original_value)
+ assert fc_envelope["call_id"] == fco_envelope["call_id"] == "call_xyz"
+ assert fc_envelope["name"] == "echo"
+ # The tool result envelope's `output` is JSON-encoded; the underlying echo result is "world"
+ assert "world" in fco_envelope["output"]
+
+
+@pytest.mark.asyncio
+async def test_red_teaming_dispatches_all_tool_calls_per_turn(patch_central_database):
+ """Multi-call-per-turn dispatch (intentional behavior change vs pre-C6 loop).
+
+ When the model emits two function_call sections in one response, BOTH
+ must dispatch through the MCPToolBackend. The pre-C6 in-class loop in
+ OpenAIResponseTarget only dispatched the LAST call per turn; the C6
+ migration onto @tool_loop changes this to "dispatch every call in
+ declaration order." Verify by issuing both an `echo` and an `add` call
+ and asserting both results land in the second API call's input.
+ """
+ backend = MCPToolBackend(servers=[_local_echo_spec()])
+ async with backend:
+ objective_target = _make_response_target_with_mcp_backend(backend)
+
+ # First response contains TWO function_calls; second is the stop text.
+ multi_call_response = MagicMock()
+ multi_call_response.status = "completed"
+ multi_call_response.error = None
+
+ echo_section = MagicMock()
+ echo_section.type = "function_call"
+ echo_section.call_id = "call_echo"
+ echo_section.name = "echo"
+ echo_section.arguments = json.dumps({"text": "hi"})
+ echo_section.model_dump.return_value = {
+ "type": "function_call",
+ "call_id": "call_echo",
+ "name": "echo",
+ "arguments": json.dumps({"text": "hi"}),
+ }
+ add_section = MagicMock()
+ add_section.type = "function_call"
+ add_section.call_id = "call_add"
+ add_section.name = "add"
+ add_section.arguments = json.dumps({"a": 3, "b": 4})
+ add_section.model_dump.return_value = {
+ "type": "function_call",
+ "call_id": "call_add",
+ "name": "add",
+ "arguments": json.dumps({"a": 3, "b": 4}),
+ }
+ multi_call_response.output = [echo_section, add_section]
+
+ responses = [
+ multi_call_response,
+ _mock_text_response("done"),
+ ]
+ seen = []
+
+ async def mock_create(**kwargs):
+ seen.append(kwargs)
+ return responses[len(seen) - 1]
+
+ adversarial = _scripted_adversarial(["call echo and add"])
+
+ attack = RedTeamingAttack(
+ objective_target=objective_target,
+ attack_adversarial_config=AttackAdversarialConfig(target=adversarial),
+ attack_scoring_config=AttackScoringConfig(objective_scorer=_success_scorer()),
+ )
+
+ with patch.object(objective_target._async_client.responses, "create", new_callable=AsyncMock) as mc:
+ mc.side_effect = mock_create
+ await attack.execute_async(objective="dispatch both tools")
+
+ assert len(seen) == 2
+ second_input = seen[1]["input"]
+ outputs = [item for item in second_input if item.get("type") == "function_call_output"]
+ assert len(outputs) == 2, "Both tool calls must be dispatched per the new behavior"
+ call_ids = [o["call_id"] for o in outputs]
+ assert call_ids == ["call_echo", "call_add"], "Outputs must preserve declaration order"
+ # Real MCP subprocess: echo("hi") returned "hi", add(3, 4) returned 7
+ assert "hi" in outputs[0]["output"]
+ assert "7" in outputs[1]["output"]
diff --git a/tests/unit/message_normalizer/test_chat_message_normalizer.py b/tests/unit/message_normalizer/test_chat_message_normalizer.py
index b9a7cec57b..decc589bcb 100644
--- a/tests/unit/message_normalizer/test_chat_message_normalizer.py
+++ b/tests/unit/message_normalizer/test_chat_message_normalizer.py
@@ -319,3 +319,134 @@ async def test_returns_list_of_dicts(self):
assert isinstance(result[0], dict)
assert result[0]["role"] == "user"
assert result[0]["content"] == "Hello"
+
+
+class TestChatMessageNormalizerToolPieces:
+ """Tool-call piece coverage: function_call -> assistant.tool_calls,
+ function_call_output -> role=tool message with tool_call_id."""
+
+ async def test_function_call_piece_becomes_assistant_tool_call_message(self):
+ normalizer = ChatMessageNormalizer()
+ envelope = {
+ "type": "function_call",
+ "call_id": "call_0",
+ "name": "echo",
+ "arguments": '{"text":"hi"}',
+ }
+ fc_piece = MessagePiece(
+ role="assistant",
+ original_value=json.dumps(envelope),
+ original_value_data_type="function_call",
+ converted_value_data_type="function_call",
+ )
+ messages = [Message(message_pieces=[fc_piece])]
+
+ result = await normalizer.normalize_async(messages)
+
+ assert len(result) == 1
+ assert result[0].role == "assistant"
+ assert result[0].content is None
+ assert result[0].tool_calls is not None
+ assert len(result[0].tool_calls) == 1
+ assert result[0].tool_calls[0].id == "call_0"
+ assert result[0].tool_calls[0].type == "function"
+ assert result[0].tool_calls[0].function.name == "echo"
+ assert result[0].tool_calls[0].function.arguments == '{"text":"hi"}'
+
+ async def test_function_call_output_piece_becomes_tool_role_message(self):
+ normalizer = ChatMessageNormalizer()
+ envelope = {
+ "type": "function_call_output",
+ "call_id": "call_0",
+ "output": '{"echoed":"hi"}',
+ }
+ fco_piece = MessagePiece(
+ role="tool",
+ original_value=json.dumps(envelope),
+ original_value_data_type="function_call_output",
+ converted_value_data_type="function_call_output",
+ )
+ messages = [Message(message_pieces=[fco_piece], skip_validation=True)]
+
+ result = await normalizer.normalize_async(messages)
+
+ assert len(result) == 1
+ assert result[0].role == "tool"
+ assert result[0].tool_call_id == "call_0"
+ assert result[0].content == '{"echoed":"hi"}'
+
+ async def test_full_tool_conversation_round_trip(self):
+ """A user -> assistant fc -> tool fco -> assistant text conversation
+ normalizes into the canonical OpenAI Chat Completions wire shape."""
+ normalizer = ChatMessageNormalizer()
+
+ user = _make_message("user", "Use the echo tool to repeat 'hi'.")
+
+ fc_envelope = {
+ "type": "function_call",
+ "call_id": "call_0",
+ "name": "echo",
+ "arguments": '{"text":"hi"}',
+ }
+ assistant_fc = Message(
+ message_pieces=[
+ MessagePiece(
+ role="assistant",
+ original_value=json.dumps(fc_envelope),
+ original_value_data_type="function_call",
+ converted_value_data_type="function_call",
+ )
+ ]
+ )
+
+ fco_envelope = {
+ "type": "function_call_output",
+ "call_id": "call_0",
+ "output": '{"echoed":"hi"}',
+ }
+ tool_msg = Message(
+ message_pieces=[
+ MessagePiece(
+ role="tool",
+ original_value=json.dumps(fco_envelope),
+ original_value_data_type="function_call_output",
+ converted_value_data_type="function_call_output",
+ )
+ ],
+ skip_validation=True,
+ )
+
+ assistant_final = _make_message("assistant", "The echoed text is: hi")
+
+ result = await normalizer.normalize_async([user, assistant_fc, tool_msg, assistant_final])
+
+ assert [m.role for m in result] == ["user", "assistant", "tool", "assistant"]
+ assert result[0].content == "Use the echo tool to repeat 'hi'."
+ assert result[1].content is None
+ assert result[1].tool_calls[0].function.name == "echo"
+ assert result[2].tool_call_id == "call_0"
+ assert result[2].content == '{"echoed":"hi"}'
+ assert result[3].content == "The echoed text is: hi"
+
+ async def test_function_call_output_serialized_to_dict_excludes_content_when_none(self):
+ """An assistant tool-call-only message must serialize without a content key."""
+ normalizer = ChatMessageNormalizer()
+ envelope = {
+ "type": "function_call",
+ "call_id": "c1",
+ "name": "f",
+ "arguments": "{}",
+ }
+ msg = Message(
+ message_pieces=[
+ MessagePiece(
+ role="assistant",
+ original_value=json.dumps(envelope),
+ original_value_data_type="function_call",
+ converted_value_data_type="function_call",
+ )
+ ]
+ )
+ dicts = await normalizer.normalize_to_dicts_async([msg])
+ assert "content" not in dicts[0]
+ assert dicts[0]["tool_calls"][0]["function"]["name"] == "f"
diff --git a/tests/unit/models/test_chat_message.py b/tests/unit/models/test_chat_message.py
index 8391e9340b..9cee645a47 100644
--- a/tests/unit/models/test_chat_message.py
+++ b/tests/unit/models/test_chat_message.py
@@ -10,19 +10,30 @@
ChatMessage,
ChatMessagesDataset,
ToolCall,
+ ToolCallFunction,
)
def test_tool_call_init():
- tc = ToolCall(id="call_1", type="function", function="get_weather")
+ tc = ToolCall(
+ id="call_1",
+ type="function",
+ function=ToolCallFunction(name="get_weather", arguments='{"city":"NYC"}'),
+ )
assert tc.id == "call_1"
assert tc.type == "function"
- assert tc.function == "get_weather"
+ assert tc.function.name == "get_weather"
+ assert tc.function.arguments == '{"city":"NYC"}'
def test_tool_call_forbids_extra_fields():
with pytest.raises(ValidationError):
- ToolCall(id="call_1", type="function", function="get_weather", extra="bad")
+ ToolCall(
+ id="call_1",
+ type="function",
+ function=ToolCallFunction(name="get_weather", arguments="{}"),
+ extra="bad",
+ )
def test_chat_message_init_with_string_content():
@@ -41,7 +52,7 @@ def test_chat_message_init_with_list_content():
def test_chat_message_init_with_all_fields():
- tc = ToolCall(id="call_1", type="function", function="lookup")
+ tc = ToolCall(id="call_1", type="function", function=ToolCallFunction(name="lookup", arguments="{}"))
msg = ChatMessage(
role="assistant",
content="result",
@@ -91,13 +102,26 @@ def test_chat_message_model_validate_json_roundtrip():
def test_chat_message_model_validate_json_roundtrip_with_tool_calls():
- tc = ToolCall(id="c1", type="function", function="fn")
+ tc = ToolCall(id="c1", type="function", function=ToolCallFunction(name="fn", arguments="{}"))
original = ChatMessage(role="assistant", content="ok", tool_calls=[tc], tool_call_id="c1")
restored = ChatMessage.model_validate_json(original.model_dump_json())
assert restored.tool_calls[0].id == "c1"
+ assert restored.tool_calls[0].function.name == "fn"
assert restored.tool_call_id == "c1"
+def test_chat_message_content_allows_none_for_tool_call_only_assistant_message():
+ """OpenAI Chat Completions allows assistant messages with content=null when tool_calls is set."""
+ tc = ToolCall(id="c1", type="function", function=ToolCallFunction(name="fn", arguments="{}"))
+ msg = ChatMessage(role="assistant", content=None, tool_calls=[tc])
+ assert msg.content is None
+ assert msg.tool_calls == [tc]
+ dumped = msg.to_dict()
+ # content is None so it should be excluded from the serialized dict.
+ assert "content" not in dumped
+ assert dumped["tool_calls"][0]["function"]["name"] == "fn"
+
+
@pytest.mark.parametrize("role", ["system", "user", "assistant", "simulated_assistant", "tool", "developer"])
def test_chat_message_accepts_all_valid_roles(role):
msg = ChatMessage(role=role, content="test")
diff --git a/tests/unit/prompt_target/target/test_azure_ml_chat_target.py b/tests/unit/prompt_target/target/test_azure_ml_chat_target.py
index e9517d3ec8..b269894ba1 100644
--- a/tests/unit/prompt_target/target/test_azure_ml_chat_target.py
+++ b/tests/unit/prompt_target/target/test_azure_ml_chat_target.py
@@ -72,7 +72,7 @@ async def test_complete_chat_async(aml_online_chat: AzureMLChatTarget):
mock_response.json.return_value = {"output": "extracted response"}
mock.return_value = mock_response
response = await aml_online_chat._complete_chat_async(messages)
- assert response == "extracted response"
+ assert response == {"output": "extracted response"}
mock.assert_called_once()
@@ -90,7 +90,7 @@ async def test_complete_chat_async_with_default_normalizer(
mock_response.json.return_value = {"output": "extracted response"}
mock.return_value = mock_response
response = await aml_online_chat._complete_chat_async(messages)
- assert response == "extracted response"
+ assert response == {"output": "extracted response"}
args, kwargs = mock.call_args
body = kwargs["request_body"]
@@ -107,9 +107,12 @@ async def test_complete_chat_async_bad_json_response(aml_online_chat: AzureMLCha
with patch("pyrit.common.net_utility.make_request_and_raise_if_error_async", new_callable=AsyncMock) as mock:
mock_response = MagicMock()
+ # Set is a non-dict body that previously raised TypeError when the code
+ # subscripted response.json()["output"]; the new code raises ValueError
+ # because the body is not a dict.
mock_response.json.return_value = {"bad response"}
mock.return_value = mock_response
- with pytest.raises(TypeError):
+ with pytest.raises((TypeError, ValueError, EmptyResponseException)):
await aml_online_chat._complete_chat_async(messages)
@@ -178,8 +181,10 @@ async def test_send_prompt_async_rate_limit_exception_retries(aml_online_chat: A
async def test_send_prompt_async_empty_response_retries(aml_online_chat: AzureMLChatTarget):
response = MagicMock()
response.status_code = 429
+ # Return an empty dict; _materialize_response raises EmptyResponseException
+ # when both output and tool_calls are missing.
mock_complete_chat_async = AsyncMock()
- mock_complete_chat_async.return_value = None
+ mock_complete_chat_async.return_value = {}
aml_online_chat._complete_chat_async = mock_complete_chat_async
message = Message(message_pieces=[MessagePiece(role="user", conversation_id="12345", original_value="Hello")])
@@ -236,3 +241,140 @@ def test_valid_temperature_and_top_p(patch_central_database):
)
assert target._temperature == 1.5
assert target._top_p == 0.9
+
+
+# ---------------------------------------------------------------------------
+# Tool calling: tool_parser + tool_backend kwargs
+# ---------------------------------------------------------------------------
+
+
+@pytest.fixture
+def echo_backend():
+ from pyrit.tools import LocalToolBackend
+
+ async def _echo(args):
+ return {"echoed": args.get("text", "")}
+
+ return LocalToolBackend(
+ callables={"echo": _echo},
+ schemas=[
+ {
+ "name": "echo",
+ "description": "Echo back the given text.",
+ "parameters": {
+ "type": "object",
+ "properties": {"text": {"type": "string"}},
+ "required": ["text"],
+ },
+ }
+ ],
+ )
+
+
+def test_tool_parser_kwarg_flips_supports_tool_use_capability(patch_central_database):
+ from pyrit.prompt_target.common.target_capabilities import CapabilityName
+ from pyrit.tools import CanonicalEnvelopeParser
+
+ target = AzureMLChatTarget(
+ endpoint="http://aml-test-endpoint.com",
+ api_key="k",
+ tool_parser=CanonicalEnvelopeParser(),
+ )
+ assert target.configuration.includes(capability=CapabilityName.TOOL_USE)
+ assert target._tool_parser is not None
+
+
+def test_no_tool_parser_leaves_supports_tool_use_off(aml_online_chat: AzureMLChatTarget):
+ from pyrit.prompt_target.common.target_capabilities import CapabilityName
+
+ assert not aml_online_chat.configuration.includes(capability=CapabilityName.TOOL_USE)
+ assert aml_online_chat._tool_parser is None
+
+
+def test_tool_backend_kwarg_installed_into_configuration(patch_central_database, echo_backend):
+ from pyrit.tools import CanonicalEnvelopeParser
+
+ target = AzureMLChatTarget(
+ endpoint="http://aml-test-endpoint.com",
+ api_key="k",
+ tool_parser=CanonicalEnvelopeParser(),
+ tool_backend=echo_backend,
+ )
+ assert target.configuration.tool_backend is echo_backend
+
+
+def test_tool_backend_kwarg_without_parser_raises(patch_central_database, echo_backend):
+ # Without tool_parser, the default configuration has supports_tool_use=False,
+ # so attaching a backend must raise.
+ with pytest.raises(ValueError, match="supports_tool_use"):
+ AzureMLChatTarget(
+ endpoint="http://aml-test-endpoint.com",
+ api_key="k",
+ tool_backend=echo_backend,
+ )
+
+
+def test_tool_schemas_wraps_backend_schemas_in_chat_completions_shape(patch_central_database, echo_backend):
+ from pyrit.tools import CanonicalEnvelopeParser
+
+ target = AzureMLChatTarget(
+ endpoint="http://aml-test-endpoint.com",
+ api_key="k",
+ tool_parser=CanonicalEnvelopeParser(),
+ tool_backend=echo_backend,
+ )
+ schemas = target._tool_schemas()
+ assert len(schemas) == 1
+ assert schemas[0]["type"] == "function"
+ assert schemas[0]["function"]["name"] == "echo"
+
+
+def test_tool_schemas_empty_when_no_backend(aml_online_chat: AzureMLChatTarget):
+ assert aml_online_chat._tool_schemas() == []
+
+
+async def test_request_body_omits_tools_key_when_no_backend(aml_online_chat: AzureMLChatTarget):
+ messages = [Message(message_pieces=[MessagePiece(role="user", original_value="hi")])]
+ body = await aml_online_chat._construct_http_body_async(messages)
+ assert "tools" not in body
+
+
+async def test_request_body_includes_tools_when_backend_set(patch_central_database, echo_backend):
+ from pyrit.tools import CanonicalEnvelopeParser
+
+ target = AzureMLChatTarget(
+ endpoint="http://aml-test-endpoint.com",
+ api_key="k",
+ tool_parser=CanonicalEnvelopeParser(),
+ tool_backend=echo_backend,
+ )
+ messages = [Message(message_pieces=[MessagePiece(role="user", original_value="hi")])]
+ body = await target._construct_http_body_async(messages)
+ assert "tools" in body
+ assert body["tools"][0]["function"]["name"] == "echo"
+
+
+async def test_materialize_response_handles_text_and_tool_calls(patch_central_database, echo_backend):
+ from pyrit.tools import CanonicalEnvelopeParser
+
+ target = AzureMLChatTarget(
+ endpoint="http://aml-test-endpoint.com",
+ api_key="k",
+ tool_parser=CanonicalEnvelopeParser(),
+ tool_backend=echo_backend,
+ )
+ request = MessagePiece(role="user", original_value="hi", conversation_id="abc")
+ response = {
+ "output": "ok",
+ "tool_calls": [
+ {
+ "type": "function_call",
+ "call_id": "call_0",
+ "name": "echo",
+ "arguments": '{"text":"hi"}',
+ }
+ ],
+ }
+ msg = target._materialize_response(response=response, request=request)
+ types = [p.original_value_data_type for p in msg.message_pieces]
+ assert types == ["text", "function_call"]
diff --git a/tests/unit/prompt_target/target/test_huggingface_chat_target.py b/tests/unit/prompt_target/target/test_huggingface_chat_target.py
index 93a4ca912f..4cddd5dbd6 100644
--- a/tests/unit/prompt_target/target/test_huggingface_chat_target.py
+++ b/tests/unit/prompt_target/target/test_huggingface_chat_target.py
@@ -578,3 +578,163 @@ async def test_effective_generation_config_in_metadata():
assert effective_config["temperature"] == 1.0
# Model defaults should also be present
assert effective_config["eos_token_id"] == 2
+
+
+# ---------------------------------------------------------------------------
+# Tool calling (F2): tool_parser + tool_backend kwargs
+# ---------------------------------------------------------------------------
+
+
+@pytest.fixture
+def echo_backend():
+ from pyrit.tools import LocalToolBackend
+
+ async def _echo(args):
+ return {"echoed": args.get("text", "")}
+
+ return LocalToolBackend(
+ callables={"echo": _echo},
+ schemas=[
+ {
+ "name": "echo",
+ "description": "Echo back the given text.",
+ "parameters": {
+ "type": "object",
+ "properties": {"text": {"type": "string"}},
+ "required": ["text"],
+ },
+ }
+ ],
+ )
+
+
+@pytest.mark.skipif(not is_torch_installed(), reason="torch is not installed")
+def test_tool_parser_kwarg_flips_supports_tool_use_capability(patch_central_database):
+ from pyrit.prompt_target.common.target_capabilities import CapabilityName
+ from pyrit.tools import InlineToolCallParser
+
+ target = HuggingFaceChatTarget(model_id="test_model", use_cuda=False, tool_parser=InlineToolCallParser())
+ assert target.configuration.includes(capability=CapabilityName.TOOL_USE)
+ assert target._tool_parser is not None
+
+
+@pytest.mark.skipif(not is_torch_installed(), reason="torch is not installed")
+def test_no_tool_parser_leaves_supports_tool_use_off(patch_central_database):
+ from pyrit.prompt_target.common.target_capabilities import CapabilityName
+
+ target = HuggingFaceChatTarget(model_id="test_model", use_cuda=False)
+ assert not target.configuration.includes(capability=CapabilityName.TOOL_USE)
+ assert target._tool_parser is None
+
+
+@pytest.mark.skipif(not is_torch_installed(), reason="torch is not installed")
+def test_tool_backend_kwarg_installed_into_configuration(patch_central_database, echo_backend):
+ from pyrit.tools import InlineToolCallParser
+
+ target = HuggingFaceChatTarget(
+ model_id="test_model",
+ use_cuda=False,
+ tool_parser=InlineToolCallParser(),
+ tool_backend=echo_backend,
+ )
+ assert target.configuration.tool_backend is echo_backend
+
+
+@pytest.mark.skipif(not is_torch_installed(), reason="torch is not installed")
+def test_tool_backend_kwarg_without_parser_raises(patch_central_database, echo_backend):
+ with pytest.raises(ValueError, match="supports_tool_use"):
+ HuggingFaceChatTarget(model_id="test_model", use_cuda=False, tool_backend=echo_backend)
+
+
+@pytest.mark.skipif(not is_torch_installed(), reason="torch is not installed")
+def test_tool_schemas_returns_bare_backend_schemas(patch_central_database, echo_backend):
+ """HF chat templates accept bare schemas (no OpenAI envelope)."""
+ from pyrit.tools import InlineToolCallParser
+
+ target = HuggingFaceChatTarget(
+ model_id="test_model",
+ use_cuda=False,
+ tool_parser=InlineToolCallParser(),
+ tool_backend=echo_backend,
+ )
+ schemas = target._tool_schemas()
+ assert len(schemas) == 1
+ assert schemas[0]["name"] == "echo"
+ # Unlike AzureMLChatTarget, no {"type": "function", "function": {...}} wrapper.
+ assert "function" not in schemas[0]
+
+
+@pytest.mark.skipif(not is_torch_installed(), reason="torch is not installed")
+def test_build_chat_messages_translates_function_call_piece(patch_central_database):
+ target = HuggingFaceChatTarget(model_id="test_model", use_cuda=False)
+ fc_envelope = {
+ "type": "function_call",
+ "call_id": "call_0",
+ "name": "echo",
+ "arguments": '{"text":"hi"}',
+ }
+ msg = Message(
+ message_pieces=[
+ MessagePiece(
+ role="assistant",
+ original_value=json.dumps(fc_envelope),
+ original_value_data_type="function_call",
+ converted_value_data_type="function_call",
+ )
+ ]
+ )
+ chat_messages = target._build_chat_messages(normalized_conversation=[msg])
+ assert len(chat_messages) == 1
+ assert chat_messages[0]["role"] == "assistant"
+ assert chat_messages[0]["tool_calls"][0]["function"]["name"] == "echo"
+ assert chat_messages[0]["tool_calls"][0]["function"]["arguments"] == '{"text":"hi"}'
+
+
+@pytest.mark.skipif(not is_torch_installed(), reason="torch is not installed")
+def test_build_chat_messages_translates_function_call_output_piece(patch_central_database):
+ target = HuggingFaceChatTarget(model_id="test_model", use_cuda=False)
+ fco_envelope = {
+ "type": "function_call_output",
+ "call_id": "call_0",
+ "output": '{"echoed":"hi"}',
+ }
+ msg = Message(
+ message_pieces=[
+ MessagePiece(
+ role="tool",
+ original_value=json.dumps(fco_envelope),
+ original_value_data_type="function_call_output",
+ converted_value_data_type="function_call_output",
+ )
+ ],
+ skip_validation=True,
+ )
+ chat_messages = target._build_chat_messages(normalized_conversation=[msg])
+ assert len(chat_messages) == 1
+ assert chat_messages[0]["role"] == "tool"
+ assert chat_messages[0]["tool_call_id"] == "call_0"
+ assert chat_messages[0]["content"] == '{"echoed":"hi"}'
+
+
+@pytest.mark.skipif(not is_torch_installed(), reason="torch is not installed")
+def test_apply_chat_template_forwards_tools_when_present(patch_central_database, echo_backend):
+ from pyrit.tools import InlineToolCallParser
+
+ target = HuggingFaceChatTarget(
+ model_id="test_model",
+ use_cuda=False,
+ tool_parser=InlineToolCallParser(),
+ tool_backend=echo_backend,
+ )
+ target._apply_chat_template([{"role": "user", "content": "hi"}])
+ call_kwargs = target.tokenizer.apply_chat_template.call_args.kwargs
+ assert "tools" in call_kwargs
+ assert call_kwargs["tools"][0]["name"] == "echo"
+
+
+@pytest.mark.skipif(not is_torch_installed(), reason="torch is not installed")
+def test_apply_chat_template_omits_tools_when_no_backend(patch_central_database):
+ target = HuggingFaceChatTarget(model_id="test_model", use_cuda=False)
+ target._apply_chat_template([{"role": "user", "content": "hi"}])
+ call_kwargs = target.tokenizer.apply_chat_template.call_args.kwargs
+ assert "tools" not in call_kwargs
diff --git a/tests/unit/prompt_target/target/test_normalize_async_integration.py b/tests/unit/prompt_target/target/test_normalize_async_integration.py
index 2317bd705f..b62fa7a64b 100644
--- a/tests/unit/prompt_target/target/test_normalize_async_integration.py
+++ b/tests/unit/prompt_target/target/test_normalize_async_integration.py
@@ -229,7 +229,7 @@ async def test_azure_ml_target_calls_normalize_async():
with (
patch.object(target.configuration, "normalize_async", new_callable=AsyncMock) as mock_normalize,
- patch.object(target, "_complete_chat_async", new_callable=AsyncMock, return_value="response"),
+ patch.object(target, "_complete_chat_async", new_callable=AsyncMock, return_value={"output": "response"}),
):
mock_normalize.return_value = [user_msg]
await target.send_prompt_async(message=user_msg)
@@ -254,7 +254,9 @@ async def test_azure_ml_target_sends_normalized_to_complete_chat():
with (
patch.object(target.configuration, "normalize_async", new_callable=AsyncMock, return_value=[adapted_msg]),
- patch.object(target, "_complete_chat_async", new_callable=AsyncMock, return_value="response") as mock_chat,
+ patch.object(
+ target, "_complete_chat_async", new_callable=AsyncMock, return_value={"output": "response"}
+ ) as mock_chat,
):
await target.send_prompt_async(message=user_msg)
@@ -294,7 +296,7 @@ async def test_azure_ml_target_memory_not_mutated():
mock_memory.get_conversation.return_value = memory_conversation
target._memory = mock_memory
- with patch.object(target, "_complete_chat_async", new_callable=AsyncMock, return_value="response"):
+ with patch.object(target, "_complete_chat_async", new_callable=AsyncMock, return_value={"output": "response"}):
await target.send_prompt_async(message=user_msg)
# Memory must still have original system message only (not mutated)
@@ -386,7 +388,9 @@ async def test_azure_ml_system_squash_via_configuration_pipeline():
mock_memory.get_conversation.return_value = [system_msg]
target._memory = mock_memory
- with patch.object(target, "_complete_chat_async", new_callable=AsyncMock, return_value="response") as mock_chat:
+ with patch.object(
+ target, "_complete_chat_async", new_callable=AsyncMock, return_value={"output": "response"}
+ ) as mock_chat:
await target.send_prompt_async(message=user_msg)
# _complete_chat_async should receive normalized messages (system squashed into user)
diff --git a/tests/unit/prompt_target/target/test_openai_response_target_c6_migration.py b/tests/unit/prompt_target/target/test_openai_response_target_c6_migration.py
new file mode 100644
index 0000000000..04dfd0b024
--- /dev/null
+++ b/tests/unit/prompt_target/target/test_openai_response_target_c6_migration.py
@@ -0,0 +1,304 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
+"""C6 additions to the Response target function-chaining suite.
+
+Covers the migration onto @tool_loop + LocalToolBackend.
+"""
+
+from __future__ import annotations
+
+import json
+import uuid
+import warnings
+from typing import Any
+from unittest.mock import AsyncMock, MagicMock, patch
+
+import pytest
+
+from pyrit.models import Message, MessagePiece
+from pyrit.prompt_target import OpenAIResponseTarget
+from pyrit.prompt_target.common.target_capabilities import TargetCapabilities
+from pyrit.prompt_target.common.target_configuration import TargetConfiguration
+from pyrit.tools import LocalToolBackend, ToolEventBehavior, ToolEventPolicy
+
+
+def _mock_function_call_response(call_id: str, function_name: str, arguments: dict) -> MagicMock:
+ """Build a fake Responses-API response containing a function_call section."""
+ mock_response = MagicMock()
+ mock_response.status = "completed"
+ mock_response.error = None
+ section = MagicMock()
+ section.type = "function_call"
+ section.call_id = call_id
+ section.name = function_name
+ section.arguments = json.dumps(arguments)
+ section.model_dump.return_value = {
+ "type": "function_call",
+ "call_id": call_id,
+ "name": function_name,
+ "arguments": json.dumps(arguments),
+ }
+ mock_response.output = [section]
+ return mock_response
+
+
+def _mock_text_response(text: str) -> MagicMock:
+ """Build a fake Responses-API response containing a message section."""
+ mock_response = MagicMock()
+ mock_response.status = "completed"
+ mock_response.error = None
+ section = MagicMock()
+ section.type = "message"
+ section.content = [MagicMock(text=text)]
+ mock_response.output = [section]
+ return mock_response
+
+
+def _user_msg(text: str, conversation_id: str | None = None) -> Message:
+ return Message(
+ message_pieces=[
+ MessagePiece(
+ role="user",
+ original_value=text,
+ conversation_id=conversation_id or str(uuid.uuid4()),
+ )
+ ]
+ )
+
+
+class TestCustomFunctionsDeprecation:
+ """custom_functions still works but emits DeprecationWarning."""
+
+ def test_custom_functions_kwarg_emits_deprecation_warning(self, patch_central_database):
+ async def get_weather(args: dict[str, Any]) -> dict[str, Any]:
+ return {"t": 72}
+
+ with pytest.warns(DeprecationWarning, match="custom_functions"):
+ OpenAIResponseTarget(
+ model_name="gpt-4",
+ endpoint="https://mock.example.com",
+ api_key="mock-key",
+ custom_functions={"get_weather": get_weather},
+ )
+
+ @pytest.mark.asyncio
+ async def test_custom_functions_kwarg_still_dispatches(self, patch_central_database):
+ async def get_weather(args: dict[str, Any]) -> dict[str, Any]:
+ return {"temperature": 72, "condition": "sunny"}
+
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore", DeprecationWarning)
+ target = OpenAIResponseTarget(
+ model_name="gpt-4",
+ endpoint="https://mock.example.com",
+ api_key="mock-key",
+ custom_functions={"get_weather": get_weather},
+ )
+
+ responses = [
+ _mock_function_call_response("call_1", "get_weather", {"location": "NYC"}),
+ _mock_text_response("72F and sunny."),
+ ]
+ seen = []
+
+ async def mock_create(**kwargs):
+ seen.append(kwargs)
+ return responses[len(seen) - 1]
+
+ with patch.object(target._async_client.responses, "create", new_callable=AsyncMock) as mc:
+ mc.side_effect = mock_create
+ result = await target.send_prompt_async(message=_user_msg("weather?"))
+
+ assert len(seen) == 2
+ assert result[-1].message_pieces[0].original_value == "72F and sunny."
+ second_input = seen[1]["input"]
+ assert any(item.get("type") == "function_call_output" for item in second_input)
+
+
+def _config_with_backend(backend: LocalToolBackend) -> TargetConfiguration:
+ """Build a TargetConfiguration wired for the modern tool-backend path."""
+ caps = TargetCapabilities(
+ supports_multi_turn=True,
+ supports_multi_message_pieces=True,
+ supports_editable_history=True,
+ supports_json_output=True,
+ supports_system_prompt=True,
+ supports_tool_use=True,
+ input_modalities=frozenset(
+ {
+ frozenset(["text"]),
+ frozenset(["text", "image_path"]),
+ frozenset(["function_call"]),
+ frozenset(["tool_call"]),
+ frozenset(["function_call_output"]),
+ frozenset(["reasoning"]),
+ }
+ ),
+ )
+ return TargetConfiguration(
+ capabilities=caps,
+ tool_event_policy=ToolEventPolicy(behavior=ToolEventBehavior.EXECUTE, max_tool_iterations=5),
+ tool_backend=backend,
+ )
+
+
+class TestToolBackendDispatch:
+ """The modern path: pass tool_backend via TargetConfiguration."""
+
+ @pytest.mark.asyncio
+ async def test_local_backend_dispatches_through_tool_loop(self, patch_central_database):
+ async def get_weather(args: dict[str, Any]) -> dict[str, Any]:
+ return {"temperature": 72, "condition": "sunny"}
+
+ backend = LocalToolBackend(
+ callables={"get_weather": get_weather},
+ schemas=[
+ {
+ "name": "get_weather",
+ "description": "Weather lookup.",
+ "parameters": {
+ "type": "object",
+ "properties": {"location": {"type": "string"}},
+ "required": ["location"],
+ },
+ }
+ ],
+ )
+ target = OpenAIResponseTarget(
+ model_name="gpt-4",
+ endpoint="https://mock.example.com",
+ api_key="mock-key",
+ custom_configuration=_config_with_backend(backend),
+ )
+
+ responses = [
+ _mock_function_call_response("call_1", "get_weather", {"location": "NYC"}),
+ _mock_text_response("72F and sunny in NYC."),
+ ]
+ seen = []
+
+ async def mock_create(**kwargs):
+ seen.append(kwargs)
+ return responses[len(seen) - 1]
+
+ with patch.object(target._async_client.responses, "create", new_callable=AsyncMock) as mc:
+ mc.side_effect = mock_create
+ result = await target.send_prompt_async(message=_user_msg("weather?"))
+
+ assert len(seen) == 2
+ assert result[-1].message_pieces[0].original_value == "72F and sunny in NYC."
+ second_input = seen[1]["input"]
+ assert any(item.get("type") == "function_call_output" for item in second_input)
+
+
+class TestToolSchemasInjection:
+ """_construct_request_body injects backend schemas when present."""
+
+ @pytest.mark.asyncio
+ async def test_backend_schemas_injected_into_tools(self, patch_central_database):
+ async def get_weather(args: dict[str, Any]) -> dict[str, Any]:
+ return {"t": 1}
+
+ backend = LocalToolBackend(
+ callables={"get_weather": get_weather},
+ schemas=[{"name": "get_weather", "description": "x", "parameters": {"type": "object"}}],
+ )
+ target = OpenAIResponseTarget(
+ model_name="gpt-4",
+ endpoint="https://mock.example.com",
+ api_key="mock-key",
+ custom_configuration=_config_with_backend(backend),
+ )
+ body = await target._construct_request_body(
+ conversation=[_user_msg("hi")],
+ json_config=MagicMock(enabled=False, schema=None),
+ )
+ assert "tools" in body
+ assert body["tools"][0]["type"] == "function"
+ assert body["tools"][0]["name"] == "get_weather"
+
+ @pytest.mark.asyncio
+ async def test_extra_body_tools_take_precedence(self, patch_central_database):
+ async def f(args: dict[str, Any]) -> dict[str, Any]:
+ return {}
+
+ backend = LocalToolBackend(
+ callables={"f": f},
+ schemas=[{"name": "f", "parameters": {"type": "object"}}],
+ )
+ legacy = [{"type": "function", "name": "legacy_tool", "description": "x"}]
+ config = _config_with_backend(backend)
+ target = OpenAIResponseTarget(
+ model_name="gpt-4",
+ endpoint="https://mock.example.com",
+ api_key="mock-key",
+ extra_body_parameters={"tools": legacy},
+ custom_configuration=config,
+ )
+ body = await target._construct_request_body(
+ conversation=[_user_msg("hi")],
+ json_config=MagicMock(enabled=False, schema=None),
+ )
+ assert body["tools"] == legacy
+
+ @pytest.mark.asyncio
+ async def test_no_backend_no_tools_key(self, patch_central_database):
+ target = OpenAIResponseTarget(
+ model_name="gpt-4",
+ endpoint="https://mock.example.com",
+ api_key="mock-key",
+ )
+ body = await target._construct_request_body(
+ conversation=[_user_msg("hi")],
+ json_config=MagicMock(enabled=False, schema=None),
+ )
+ assert "tools" not in body
+
+
+class TestNonFunctionCallPiecesPassThrough:
+ """Reasoning / mcp_call / web_search_call sections must NOT be dispatched.
+
+ The Response target's parser populates pieces for these types so they can
+ be persisted to Memory and round-tripped on subsequent requests. The
+ CanonicalEnvelopeParser only extracts function_call pieces; the tool loop
+ must therefore see an empty parse and exit cleanly.
+ """
+
+ @pytest.mark.asyncio
+ async def test_reasoning_only_response_exits_loop(self, patch_central_database):
+ target = OpenAIResponseTarget(
+ model_name="gpt-4",
+ endpoint="https://mock.example.com",
+ api_key="mock-key",
+ reasoning_effort="medium",
+ )
+ # Reasoning section + final text section in one response
+ mock_response = MagicMock()
+ mock_response.status = "completed"
+ mock_response.error = None
+ reasoning_section = MagicMock()
+ reasoning_section.type = "reasoning"
+ reasoning_section.model_dump.return_value = {"type": "reasoning", "summary": "thinking..."}
+ text_section = MagicMock()
+ text_section.type = "message"
+ text_section.content = [MagicMock(text="The answer is 42.")]
+ mock_response.output = [reasoning_section, text_section]
+
+ seen = []
+
+ async def mock_create(**kwargs):
+ seen.append(kwargs)
+ return mock_response
+
+ with patch.object(target._async_client.responses, "create", new_callable=AsyncMock) as mc:
+ mc.side_effect = mock_create
+ result = await target.send_prompt_async(message=_user_msg("question?"))
+
+ # Exactly one API call -- reasoning is not a tool call so the loop exits
+ assert len(seen) == 1
+ # Response message contains both pieces
+ assert len(result) == 1
+ piece_types = [p.original_value_data_type for p in result[0].message_pieces]
+ assert "reasoning" in piece_types
+ assert "text" in piece_types
diff --git a/tests/unit/prompt_target/target/test_prompt_target.py b/tests/unit/prompt_target/target/test_prompt_target.py
index f3174c2649..fa29815103 100644
--- a/tests/unit/prompt_target/target/test_prompt_target.py
+++ b/tests/unit/prompt_target/target/test_prompt_target.py
@@ -22,6 +22,7 @@
UnsupportedCapabilityBehavior,
)
from pyrit.prompt_target.common.target_configuration import TargetConfiguration
+from pyrit.tools import LocalToolBackend, ToolEventBehavior, ToolEventPolicy
@pytest.fixture
@@ -534,7 +535,13 @@ def test_identifier_includes_capability_params():
# Config-derived fields are nested under ``target_configuration``, not
# spread at the top level — guards against accidental re-flattening.
assert "supports_multi_turn" not in params
- assert set(target_config.keys()) == {"capabilities", "capability_policy", "normalization_pipeline"}
+ assert set(target_config.keys()) == {
+ "capabilities",
+ "capability_policy",
+ "normalization_pipeline",
+ "tool_event_policy",
+ "tool_backend",
+ }
assert capabilities["supports_multi_turn"] is True
assert capabilities["supports_multi_message_pieces"] is True
@@ -546,6 +553,8 @@ def test_identifier_includes_capability_params():
assert capabilities["output_modalities"] == [["text"]]
assert isinstance(target_config["capability_policy"], dict)
assert isinstance(target_config["normalization_pipeline"], list)
+ assert target_config["tool_event_policy"] is None
+ assert target_config["tool_backend"] is None
@pytest.mark.usefixtures("patch_central_database")
@@ -581,6 +590,62 @@ def test_identifier_differs_when_policy_differs():
assert a.get_identifier().hash != b.get_identifier().hash
+@pytest.mark.usefixtures("patch_central_database")
+def test_identifier_differs_when_tool_backend_differs():
+ async def _f(_: dict) -> dict:
+ return {}
+
+ capabilities = TargetCapabilities(supports_tool_use=True)
+ backend_a = LocalToolBackend(
+ callables={"alpha": _f},
+ schemas=[{"name": "alpha", "parameters": {"type": "object"}}],
+ )
+ backend_b = LocalToolBackend(
+ callables={"beta": _f},
+ schemas=[{"name": "beta", "parameters": {"type": "object"}}],
+ )
+
+ a = OpenAIChatTarget(
+ model_name="gpt-4o",
+ endpoint="https://mock.azure.com/",
+ api_key="mock-api-key",
+ custom_configuration=TargetConfiguration(capabilities=capabilities, tool_backend=backend_a),
+ )
+ b = OpenAIChatTarget(
+ model_name="gpt-4o",
+ endpoint="https://mock.azure.com/",
+ api_key="mock-api-key",
+ custom_configuration=TargetConfiguration(capabilities=capabilities, tool_backend=backend_b),
+ )
+
+ assert a.get_identifier().hash != b.get_identifier().hash
+
+
+@pytest.mark.usefixtures("patch_central_database")
+def test_identifier_differs_when_tool_event_policy_differs():
+ capabilities = TargetCapabilities(supports_tool_use=True)
+ a = OpenAIChatTarget(
+ model_name="gpt-4o",
+ endpoint="https://mock.azure.com/",
+ api_key="mock-api-key",
+ custom_configuration=TargetConfiguration(
+ capabilities=capabilities,
+ tool_event_policy=ToolEventPolicy(behavior=ToolEventBehavior.EXECUTE),
+ ),
+ )
+ b = OpenAIChatTarget(
+ model_name="gpt-4o",
+ endpoint="https://mock.azure.com/",
+ api_key="mock-api-key",
+ custom_configuration=TargetConfiguration(
+ capabilities=capabilities,
+ tool_event_policy=ToolEventPolicy(behavior=ToolEventBehavior.RAISE),
+ ),
+ )
+
+ assert a.get_identifier().hash != b.get_identifier().hash
+
+
@pytest.mark.usefixtures("patch_central_database")
def test_identifier_is_deterministic_across_instances():
capabilities = TargetCapabilities(
diff --git a/tests/unit/tools/__init__.py b/tests/unit/tools/__init__.py
new file mode 100644
index 0000000000..9a0454564d
--- /dev/null
+++ b/tests/unit/tools/__init__.py
@@ -0,0 +1,2 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
diff --git a/tests/unit/tools/conftest.py b/tests/unit/tools/conftest.py
new file mode 100644
index 0000000000..d8b0151bb7
--- /dev/null
+++ b/tests/unit/tools/conftest.py
@@ -0,0 +1,299 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
+"""
+Shared fixtures for ``tests/unit/tools``.
+
+Provides the minimal collaborators the tool-loop tests need to exercise
+:func:`pyrit.tools.tool_loop` end-to-end without standing up real targets
+or MCP transports:
+
+* :class:`_FakeToolTarget` — a :class:`PromptTarget` subclass whose
+ ``_send_prompt_to_target_async`` returns scripted messages from a queue
+ and whose ``_get_normalized_conversation_async`` skips the memory round
+ trip so decorator behavior is isolated from normalization.
+* :class:`_RecordingToolBackend` — a :class:`ToolBackend` that records
+ every dispatched call (for order-of-execution assertions) and returns
+ results from a scripted queue.
+* :class:`_CanonicalEnvelopeParser` — a :class:`ToolCallParser` that walks
+ message pieces and parses the canonical ``function_call`` JSON envelope.
+
+Helper message builders (``_make_user_message``,
+``_make_assistant_text_message``, ``_make_assistant_function_call_message``)
+produce the canonical envelope shape used by the OpenAI targets after the
+normalization commit.
+"""
+
+from __future__ import annotations
+
+import json
+import uuid
+from collections import deque
+from typing import Any
+
+import pytest
+
+from pyrit.models import Message, MessagePiece
+from pyrit.prompt_target.common.prompt_target import PromptTarget
+from pyrit.prompt_target.common.target_capabilities import TargetCapabilities
+from pyrit.prompt_target.common.target_configuration import TargetConfiguration
+from pyrit.tools import (
+ LocalToolBackend,
+ ToolBackend,
+ ToolCall,
+ ToolCallParser,
+ ToolEventBehavior,
+ ToolEventPolicy,
+)
+
+
+def _make_user_message(text: str, *, conversation_id: str | None = None) -> Message:
+ """Build a single-piece user :class:`Message` carrying *text*."""
+ return Message(
+ message_pieces=[
+ MessagePiece(
+ role="user",
+ original_value=text,
+ original_value_data_type="text",
+ conversation_id=conversation_id or str(uuid.uuid4()),
+ )
+ ]
+ )
+
+
+def _make_assistant_text_message(text: str, *, conversation_id: str | None = None) -> Message:
+ """Build a single-piece assistant :class:`Message` carrying plain text."""
+ return Message(
+ message_pieces=[
+ MessagePiece(
+ role="assistant",
+ original_value=text,
+ original_value_data_type="text",
+ conversation_id=conversation_id or str(uuid.uuid4()),
+ )
+ ],
+ skip_validation=True,
+ )
+
+
+def _make_function_call_piece(
+ *,
+ call_id: str,
+ name: str,
+ arguments: dict[str, Any],
+ conversation_id: str | None = None,
+) -> MessagePiece:
+ """Build one assistant ``function_call`` piece carrying the canonical envelope."""
+ envelope = {
+ "type": "function_call",
+ "call_id": call_id,
+ "name": name,
+ "arguments": json.dumps(arguments, separators=(",", ":")),
+ }
+ return MessagePiece(
+ role="assistant",
+ original_value=json.dumps(envelope, separators=(",", ":")),
+ original_value_data_type="function_call",
+ conversation_id=conversation_id or str(uuid.uuid4()),
+ )
+
+
+def _make_assistant_function_call_message(
+ *,
+ calls: list[tuple[str, str, dict[str, Any]]],
+ conversation_id: str | None = None,
+) -> Message:
+ """
+ Build an assistant :class:`Message` carrying one ``function_call`` piece
+ per ``(call_id, name, args)`` tuple, in declaration order.
+ """
+ conv_id = conversation_id or str(uuid.uuid4())
+ pieces = [
+ _make_function_call_piece(call_id=cid, name=name, arguments=args, conversation_id=conv_id)
+ for cid, name, args in calls
+ ]
+ return Message(message_pieces=pieces, skip_validation=True)
+
+
+class _CanonicalEnvelopeParser:
+ """
+ Reference :class:`ToolCallParser` that understands the canonical envelope
+ (``original_value_data_type == "function_call"`` carrying a JSON object
+ with ``type``/``call_id``/``name``/``arguments``).
+
+ Per-target parsers shipped will reuse this shape; this stand-in
+ keeps decorator tests independent of the real OpenAI parsers.
+ """
+
+ def parse(self, message: Message) -> list[ToolCall]:
+ calls: list[ToolCall] = []
+ for piece in message.message_pieces:
+ if piece.original_value_data_type != "function_call":
+ continue
+ envelope = json.loads(piece.original_value)
+ arguments_str = envelope.get("arguments", "{}")
+ arguments = json.loads(arguments_str) if isinstance(arguments_str, str) else dict(arguments_str)
+ calls.append(
+ ToolCall(
+ call_id=envelope["call_id"],
+ name=envelope["name"],
+ arguments=arguments,
+ raw_envelope=envelope,
+ )
+ )
+ return calls
+
+
+class _RecordingToolBackend(ToolBackend):
+ """
+ Minimal :class:`ToolBackend` that records every dispatched call and
+ returns results from a scripted queue. Used to assert dispatch order,
+ iteration count, and per-call payload shape without invoking real tools.
+ """
+
+ def __init__(
+ self,
+ *,
+ scripted_results: list[Any] | None = None,
+ schemas: list[dict[str, Any]] | None = None,
+ ) -> None:
+ self._results: deque[Any] = deque(scripted_results or [])
+ self._schemas: list[dict[str, Any]] = list(schemas) if schemas is not None else []
+ self.recorded_calls: list[ToolCall] = []
+
+ @property
+ def schemas(self) -> list[dict[str, Any]]:
+ return list(self._schemas)
+
+ async def dispatch_async(self, call: ToolCall) -> dict[str, Any]:
+ self.recorded_calls.append(call)
+ if not self._results:
+ return {"result": f"recorded:{call.name}:{call.call_id}"}
+ nxt = self._results.popleft()
+ return nxt if isinstance(nxt, dict) else {"result": nxt}
+
+
+class _FakeToolTarget(PromptTarget):
+ """
+ Test-only :class:`PromptTarget` whose ``_send_prompt_to_target_async``
+ pops scripted responses off a queue. ``_get_normalized_conversation_async``
+ is overridden to return ``[message]`` directly, isolating decorator
+ behavior from the memory + normalization pipeline.
+
+ Inherits the base class's ``@final @tool_loop send_prompt_async``; the
+ policy + backend are wired through :class:`TargetConfiguration` so the
+ wrapper finds them via ``self.configuration.tool_event_policy`` and
+ ``self.configuration.tool_backend``.
+ """
+
+ def __init__(
+ self,
+ *,
+ scripted_responses: list[Message],
+ policy: ToolEventPolicy | None = None,
+ backend: Any = None,
+ parser: ToolCallParser | None = None,
+ ) -> None:
+ # ``supports_tool_use`` is forced on whenever a policy is configured so
+ # the TargetConfiguration validator accepts the backend.
+ caps = TargetCapabilities(
+ supports_multi_turn=True,
+ supports_multi_message_pieces=True,
+ supports_tool_use=policy is not None,
+ )
+ config = TargetConfiguration(
+ capabilities=caps,
+ tool_event_policy=policy,
+ tool_backend=backend,
+ )
+ super().__init__(custom_configuration=config)
+ self._scripted_responses: deque[Message] = deque(scripted_responses)
+ self.call_count: int = 0
+ self.normalized_conversations_seen: list[list[Message]] = []
+ self._parser_instance: ToolCallParser | None = parser if parser is not None else _CanonicalEnvelopeParser()
+
+ @property
+ def _tool_parser(self) -> ToolCallParser | None:
+ return self._parser_instance
+
+ async def _get_normalized_conversation_async(self, *, message: Message) -> list[Message]:
+ return [message]
+
+ def _validate_request(self, *, normalized_conversation: list[Message]) -> None:
+ return
+
+ async def _send_prompt_to_target_async(self, *, normalized_conversation: list[Message]) -> list[Message]:
+ self.call_count += 1
+ self.normalized_conversations_seen.append(list(normalized_conversation))
+ if not self._scripted_responses:
+ raise AssertionError(f"Fake target ran out of scripted responses on iteration {self.call_count}.")
+ return [self._scripted_responses.popleft()]
+
+
+@pytest.fixture
+def make_fake_target(patch_central_database):
+ """
+ Factory fixture for :class:`_FakeToolTarget`. Each invocation returns a
+ fresh target instance whose scripted response queue is independent of
+ other targets created during the test.
+ """
+
+ def _factory(
+ *,
+ scripted_responses: list[Message],
+ policy: ToolEventPolicy | None = None,
+ backend: Any = None,
+ parser: ToolCallParser | None = None,
+ ) -> _FakeToolTarget:
+ return _FakeToolTarget(
+ scripted_responses=scripted_responses,
+ policy=policy,
+ backend=backend,
+ parser=parser,
+ )
+
+ return _factory
+
+
+@pytest.fixture
+def recording_backend():
+ """Factory fixture for :class:`_RecordingToolBackend`."""
+
+ def _factory(*, scripted_results: list[Any] | None = None) -> _RecordingToolBackend:
+ return _RecordingToolBackend(scripted_results=scripted_results)
+
+ return _factory
+
+
+@pytest.fixture
+def execute_policy():
+ """
+ Factory fixture for :class:`ToolEventPolicy` with
+ ``behavior=ToolEventBehavior.EXECUTE`` and a tunable iteration cap.
+ """
+
+ def _factory(*, max_tool_iterations: int = 5) -> ToolEventPolicy:
+ return ToolEventPolicy(
+ behavior=ToolEventBehavior.EXECUTE,
+ max_tool_iterations=max_tool_iterations,
+ )
+
+ return _factory
+
+
+__all__ = [
+ "LocalToolBackend",
+ "ToolCall",
+ "ToolEventBehavior",
+ "ToolEventPolicy",
+ "_CanonicalEnvelopeParser",
+ "_FakeToolTarget",
+ "_RecordingToolBackend",
+ "_make_assistant_function_call_message",
+ "_make_assistant_text_message",
+ "_make_function_call_piece",
+ "_make_user_message",
+ "execute_policy",
+ "make_fake_target",
+ "recording_backend",
+]
diff --git a/tests/unit/tools/echo_mcp_server.py b/tests/unit/tools/echo_mcp_server.py
new file mode 100644
index 0000000000..3ea5dbd6b5
--- /dev/null
+++ b/tests/unit/tools/echo_mcp_server.py
@@ -0,0 +1,56 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
+"""
+Deterministic echo MCP server used as a stdio subprocess fixture by
+``tests/unit/tools/test_mcp_client.py`` and the tools integration tests.
+
+The harness imports this module via the ``mcp.client.stdio.stdio_client``
+launcher, so it does not need to be importable as a Python module from
+``tests/unit/tools/`` callers.
+
+Run directly as ``python echo_mcp_server.py`` to expose the four tools
+over stdio. The MCP client harness launches this file with
+``mcp.client.stdio.stdio_client`` and asserts behavior end to end.
+"""
+
+from __future__ import annotations
+
+import asyncio
+
+from mcp.server.fastmcp import FastMCP
+
+mcp = FastMCP("pyrit-echo")
+
+
+@mcp.tool()
+def echo(text: str) -> str:
+ """Return *text* unchanged."""
+ return text
+
+
+@mcp.tool()
+def add(a: int, b: int) -> int:
+ """Return ``a + b``."""
+ return a + b
+
+
+@mcp.tool()
+def reverse(text: str) -> str:
+ """Return *text* reversed."""
+ return text[::-1]
+
+
+@mcp.tool()
+async def slow_echo(text: str, delay_ms: int = 0) -> str:
+ """
+ Return *text* after sleeping ``delay_ms`` milliseconds. Used by
+ timeout / cancellation tests.
+ """
+ if delay_ms > 0:
+ await asyncio.sleep(delay_ms / 1000.0)
+ return text
+
+
+if __name__ == "__main__":
+ mcp.run()
diff --git a/tests/unit/tools/test_inline_parser.py b/tests/unit/tools/test_inline_parser.py
new file mode 100644
index 0000000000..2f2850e667
--- /dev/null
+++ b/tests/unit/tools/test_inline_parser.py
@@ -0,0 +1,234 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
+"""Unit tests for InlineToolCallParser across marker syntaxes."""
+
+from __future__ import annotations
+
+import logging
+import uuid
+
+import pytest
+
+from pyrit.models import Message, MessagePiece
+from pyrit.tools import InlineToolCallParser, InlineToolCallParserMode
+
+
+def _assistant_text(text: str) -> Message:
+ return Message(
+ message_pieces=[
+ MessagePiece(
+ role="assistant",
+ original_value=text,
+ original_value_data_type="text",
+ conversation_id=str(uuid.uuid4()),
+ )
+ ],
+ skip_validation=True,
+ )
+
+
+# Marker patterns commonly seen in tool-trained open chat templates.
+# Named after their syntactic shape, not the model family that uses them,
+# so that PyRIT does not advertise a "supported vendors" list.
+ANGLE_BRACKET_PATTERN = r"(.*?)"
+PIPE_PYTHON_TAG_PATTERN = r"<\|python_tag\|>(.*?)<\|eom_id\|>"
+SQUARE_BRACKET_LIST_PATTERN = r"\[TOOL_CALLS\]\s*(\[.*?\])"
+
+
+# ---------------------------------------------------------------------------
+# Marker-pattern coverage: angle-bracket / pipe-python-tag / square-bracket-list
+# ---------------------------------------------------------------------------
+
+
+class TestAngleBracketMarker:
+ """Default pattern: ``...``."""
+
+ def test_single_marker_extracts_call(self):
+ parser = InlineToolCallParser()
+ text = (
+ 'Let me check the weather. {"name": "get_weather", "arguments": {"location": "NYC"}}'
+ )
+ calls = parser.parse(_assistant_text(text))
+ assert len(calls) == 1
+ assert calls[0].name == "get_weather"
+ assert calls[0].arguments == {"location": "NYC"}
+ assert calls[0].call_id == "call_0"
+
+ def test_multiple_markers_extract_all_with_synthetic_ids(self):
+ parser = InlineToolCallParser(mode=InlineToolCallParserMode.EXTRACT_ALL)
+ text = (
+ '{"name": "f1", "arguments": {}}'
+ "some interleaving text "
+ '{"name": "f2", "arguments": {"k": 1}}'
+ )
+ calls = parser.parse(_assistant_text(text))
+ assert [c.name for c in calls] == ["f1", "f2"]
+ assert [c.call_id for c in calls] == ["call_0", "call_1"]
+
+ def test_arguments_as_json_string_is_decoded(self):
+ parser = InlineToolCallParser()
+ # arguments doubly-encoded as a JSON string (canonical envelope shape).
+ text = '{"name": "f", "arguments": "{\\"a\\": 1}"}'
+ calls = parser.parse(_assistant_text(text))
+ assert len(calls) == 1
+ assert calls[0].arguments == {"a": 1}
+
+
+class TestPipePythonTagMarker:
+ """Pipe-delimited tag pair: ``<|python_tag|>...<|eom_id|>``."""
+
+ def test_single_marker_extracts_call(self):
+ parser = InlineToolCallParser(marker_pattern=PIPE_PYTHON_TAG_PATTERN)
+ text = (
+ 'Sure, calling now. <|python_tag|>{"name": "get_weather", "arguments": {"location": "Seattle"}}<|eom_id|>'
+ )
+ calls = parser.parse(_assistant_text(text))
+ assert len(calls) == 1
+ assert calls[0].name == "get_weather"
+ assert calls[0].arguments == {"location": "Seattle"}
+
+ def test_multi_marker_extract_all(self):
+ parser = InlineToolCallParser(
+ marker_pattern=PIPE_PYTHON_TAG_PATTERN,
+ mode=InlineToolCallParserMode.EXTRACT_ALL,
+ )
+ text = (
+ '<|python_tag|>{"name": "a", "arguments": {}}<|eom_id|>'
+ "between "
+ '<|python_tag|>{"name": "b", "arguments": {}}<|eom_id|>'
+ )
+ calls = parser.parse(_assistant_text(text))
+ assert [c.name for c in calls] == ["a", "b"]
+
+
+class TestSquareBracketListMarker:
+ """Square-bracketed list payload: ``[TOOL_CALLS] [...]``."""
+
+ def test_list_payload_is_skipped_with_warning(self, caplog):
+ """The default parser expects a single-dict payload.
+
+ Marker syntaxes whose payload is a JSON LIST of dicts (rather than a
+ single dict) are logged and dropped. Callers that need list-shaped
+ payloads should either subclass ``InlineToolCallParser`` and override
+ ``parse`` to iterate the list, or use a marker pattern that targets
+ each dict inside the list separately. Regex does not handle nested
+ braces well; subclassing is cleaner.
+ """
+ parser = InlineToolCallParser(marker_pattern=SQUARE_BRACKET_LIST_PATTERN)
+ text = '[TOOL_CALLS] [{"name": "f", "arguments": {"x": 1}}]'
+ with caplog.at_level(logging.WARNING, logger="pyrit.tools.inline_parser"):
+ calls = parser.parse(_assistant_text(text))
+ assert calls == []
+ assert any("without 'name' field" in rec.message for rec in caplog.records)
+
+
+# ---------------------------------------------------------------------------
+# Mode coverage: TRUNCATE_AT_LAST / TRUNCATE_AT_FIRST / EXTRACT_ALL /
+# STRICT_TRAILING_EMPTY
+# ---------------------------------------------------------------------------
+
+
+class TestParserModes:
+ """Surrounding-text policy coverage."""
+
+ HALLUCINATED = (
+ '{"name": "get_weather", "arguments": {"location": "NYC"}}'
+ " The weather in NYC is sunny and 72 degrees."
+ )
+ DOUBLE_CALL = (
+ '{"name": "a", "arguments": {}}'
+ " middle "
+ '{"name": "b", "arguments": {}}'
+ " trailing chatter"
+ )
+
+ def test_truncate_at_last_default_drops_trailing_chatter(self):
+ parser = InlineToolCallParser() # default mode
+ calls = parser.parse(_assistant_text(self.HALLUCINATED))
+ # The hallucinated weather report after the marker is discarded; the
+ # call itself is honored.
+ assert len(calls) == 1
+ assert calls[0].name == "get_weather"
+
+ def test_truncate_at_last_extracts_all_markers_then_drops_tail(self):
+ parser = InlineToolCallParser(mode=InlineToolCallParserMode.TRUNCATE_AT_LAST)
+ calls = parser.parse(_assistant_text(self.DOUBLE_CALL))
+ # Both markers honored, trailing "trailing chatter" silently dropped.
+ assert [c.name for c in calls] == ["a", "b"]
+
+ def test_truncate_at_first_keeps_only_the_first_marker(self):
+ parser = InlineToolCallParser(mode=InlineToolCallParserMode.TRUNCATE_AT_FIRST)
+ calls = parser.parse(_assistant_text(self.DOUBLE_CALL))
+ assert [c.name for c in calls] == ["a"]
+
+ def test_extract_all_keeps_every_marker(self):
+ parser = InlineToolCallParser(mode=InlineToolCallParserMode.EXTRACT_ALL)
+ calls = parser.parse(_assistant_text(self.DOUBLE_CALL))
+ assert [c.name for c in calls] == ["a", "b"]
+
+ def test_strict_trailing_empty_raises_on_chatter(self):
+ parser = InlineToolCallParser(mode=InlineToolCallParserMode.STRICT_TRAILING_EMPTY)
+ with pytest.raises(ValueError, match="STRICT_TRAILING_EMPTY"):
+ parser.parse(_assistant_text(self.HALLUCINATED))
+
+ def test_strict_trailing_empty_passes_when_only_whitespace_after(self):
+ parser = InlineToolCallParser(mode=InlineToolCallParserMode.STRICT_TRAILING_EMPTY)
+ text = '{"name": "f", "arguments": {}}\n \t '
+ calls = parser.parse(_assistant_text(text))
+ assert [c.name for c in calls] == ["f"]
+
+
+# ---------------------------------------------------------------------------
+# Edge cases
+# ---------------------------------------------------------------------------
+
+
+class TestEdgeCases:
+ """Empty input, malformed JSON, missing name, multi-piece messages."""
+
+ def test_no_markers_returns_empty(self):
+ parser = InlineToolCallParser()
+ calls = parser.parse(_assistant_text("just plain assistant text"))
+ assert calls == []
+
+ def test_malformed_json_is_skipped_silently(self):
+ parser = InlineToolCallParser()
+ text = "not valid json"
+ calls = parser.parse(_assistant_text(text))
+ assert calls == []
+
+ def test_payload_without_name_is_skipped(self):
+ parser = InlineToolCallParser()
+ text = '{"arguments": {}}'
+ calls = parser.parse(_assistant_text(text))
+ assert calls == []
+
+ def test_non_text_pieces_are_ignored(self):
+ """Pieces with data_type other than 'text' are skipped entirely."""
+ parser = InlineToolCallParser()
+ msg = Message(
+ message_pieces=[
+ MessagePiece(
+ role="assistant",
+ original_value='{"name": "ignored", "arguments": {}}',
+ original_value_data_type="reasoning",
+ conversation_id=str(uuid.uuid4()),
+ ),
+ MessagePiece(
+ role="assistant",
+ original_value='{"name": "found", "arguments": {}}',
+ original_value_data_type="text",
+ conversation_id=str(uuid.uuid4()),
+ ),
+ ],
+ skip_validation=True,
+ )
+ calls = parser.parse(msg)
+ assert [c.name for c in calls] == ["found"]
+
+ def test_call_id_prefix_customization(self):
+ parser = InlineToolCallParser(call_id_prefix="custom")
+ text = '{"name": "f", "arguments": {}}'
+ calls = parser.parse(_assistant_text(text))
+ assert calls[0].call_id == "custom_0"
diff --git a/tests/unit/tools/test_local_tool_backend.py b/tests/unit/tools/test_local_tool_backend.py
new file mode 100644
index 0000000000..13d12c9b26
--- /dev/null
+++ b/tests/unit/tools/test_local_tool_backend.py
@@ -0,0 +1,179 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
+"""
+Unit tests for :class:`pyrit.tools.LocalToolBackend`.
+
+Coverage map:
+
+* **U10** (partial; the MCP counterpart) —
+ ``test_each_dummy_tool_invoked_via_prepended_conversation``
+* **U17** (partial; the MCP-timeout counterpart) —
+ ``test_failing_tool_yields_error_envelope``
+* **U18** — ``test_disallowed_tool_returns_error_without_invoking_callable``
+
+Also covers the backend's documented behavior for missing functions
+(both strict and tolerant modes), schema property defaulting, scalar
+result wrapping, and declaration-order preservation in the bulk dispatch
+path. These are required for the §10 rubber-duck guarantee that every
+public-facing branch of :class:`LocalToolBackend` is exercised end-to-end.
+"""
+
+from __future__ import annotations
+
+import pytest
+
+from pyrit.tools import LocalToolBackend, ToolCall
+
+
+def _make_call(name: str, *, call_id: str = "c1", arguments: dict | None = None) -> ToolCall:
+ return ToolCall(call_id=call_id, name=name, arguments=arguments or {})
+
+
+async def test_disallowed_tool_returns_error_without_invoking_callable():
+ invoked: list[str] = []
+
+ async def echo(args: dict) -> dict:
+ invoked.append(args.get("text", ""))
+ return {"echoed": args.get("text", "")}
+
+ backend = LocalToolBackend(
+ callables={"echo": echo, "off_limits": echo},
+ allowed_tools={"echo"},
+ )
+
+ result = await backend.dispatch_async(_make_call("off_limits", arguments={"text": "nope"}))
+
+ assert result["error"] == "tool_not_allowed"
+ assert result["tool"] == "off_limits"
+ assert "echo" in result["allowed_tools"]
+ assert invoked == [] # callable was never invoked
+
+
+async def test_failing_tool_yields_error_envelope():
+ async def boom(args: dict) -> dict:
+ raise RuntimeError("kaboom")
+
+ backend = LocalToolBackend(callables={"boom": boom})
+
+ result = await backend.dispatch_async(_make_call("boom"))
+
+ assert result["error"] == "tool_execution_failed"
+ assert result["tool"] == "boom"
+ assert "kaboom" in result["detail"]
+
+
+async def test_missing_tool_raises_when_strict():
+ backend = LocalToolBackend(callables={}, fail_on_missing_function=True)
+
+ with pytest.raises(KeyError, match="ghost"):
+ await backend.dispatch_async(_make_call("ghost"))
+
+
+async def test_missing_tool_returns_envelope_when_tolerant():
+ async def echo(args: dict) -> dict:
+ return {"ok": True}
+
+ backend = LocalToolBackend(
+ callables={"echo": echo},
+ fail_on_missing_function=False,
+ )
+
+ result = await backend.dispatch_async(_make_call("ghost"))
+
+ assert result["error"] == "tool_not_registered"
+ assert result["tool"] == "ghost"
+ assert result["available_tools"] == ["echo"]
+
+
+async def test_scalar_result_is_wrapped_in_dict():
+ async def number(args: dict) -> int:
+ return 42
+
+ backend = LocalToolBackend(callables={"number": number})
+
+ result = await backend.dispatch_async(_make_call("number"))
+
+ assert result == {"result": 42}
+
+
+async def test_dict_result_passes_through_unchanged():
+ async def named(args: dict) -> dict:
+ return {"custom_key": "custom_value"}
+
+ backend = LocalToolBackend(callables={"named": named})
+
+ result = await backend.dispatch_async(_make_call("named"))
+
+ assert result == {"custom_key": "custom_value"}
+
+
+async def test_schemas_defaults_to_empty_list():
+ backend = LocalToolBackend(callables={})
+
+ assert backend.schemas == []
+
+
+async def test_schemas_returned_as_copy():
+ schemas_in = [{"name": "echo", "parameters": {}}]
+ backend = LocalToolBackend(callables={}, schemas=schemas_in)
+
+ out1 = backend.schemas
+ out1.append({"name": "mutated"})
+
+ # Mutating the returned list does not affect the backend's internal state.
+ assert backend.schemas == schemas_in
+
+
+async def test_dispatch_all_sequential_preserves_declaration_order():
+ async def echo(args: dict) -> dict:
+ return {"echoed": args["i"]}
+
+ backend = LocalToolBackend(callables={"echo": echo})
+
+ calls = [_make_call("echo", call_id=f"c{i}", arguments={"i": i}) for i in range(5)]
+ pairs = await backend.dispatch_all_sequential_async(calls)
+
+ assert [c.call_id for c, _ in pairs] == ["c0", "c1", "c2", "c3", "c4"]
+ assert [r["echoed"] for _, r in pairs] == [0, 1, 2, 3, 4]
+
+
+async def test_each_dummy_tool_invoked_via_prepended_conversation():
+ """
+ U10 (partial). Each dummy tool resolves on first dispatch (single
+ forward step, no model reasoning trace), confirming the backend can
+ short-circuit a prepended conversation where every call is already
+ decided. The MCP counterpart exercises the same shape against
+ a real stdio server.
+ """
+ invocations: list[tuple[str, dict]] = []
+
+ async def echo(args: dict) -> dict:
+ invocations.append(("echo", args))
+ return {"echoed": args.get("text", "")}
+
+ async def add(args: dict) -> dict:
+ invocations.append(("add", args))
+ return {"sum": args["a"] + args["b"]}
+
+ async def reverse(args: dict) -> dict:
+ invocations.append(("reverse", args))
+ return {"reversed": args.get("text", "")[::-1]}
+
+ backend = LocalToolBackend(callables={"echo": echo, "add": add, "reverse": reverse})
+
+ prepended_calls = [
+ _make_call("echo", call_id="e1", arguments={"text": "hello"}),
+ _make_call("add", call_id="a1", arguments={"a": 2, "b": 3}),
+ _make_call("reverse", call_id="r1", arguments={"text": "pyrit"}),
+ ]
+ pairs = await backend.dispatch_all_sequential_async(prepended_calls)
+
+ # Each dummy resolved exactly once; no retries, no model re-entry.
+ assert len(invocations) == 3
+ assert [name for name, _ in invocations] == ["echo", "add", "reverse"]
+ assert [r for _, r in pairs] == [
+ {"echoed": "hello"},
+ {"sum": 5},
+ {"reversed": "tiryp"},
+ ]
diff --git a/tests/unit/tools/test_mcp_backend.py b/tests/unit/tools/test_mcp_backend.py
new file mode 100644
index 0000000000..7621a61dca
--- /dev/null
+++ b/tests/unit/tools/test_mcp_backend.py
@@ -0,0 +1,148 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
+"""
+Unit tests for :class:`pyrit.tools.MCPToolBackend`.
+
+These tests verify the multi-server fan-out and routing layer on top of
+:class:`MCPClient`: schema aggregation, name-collision detection,
+``name_prefix`` disambiguation, ``allowed_tools`` allow-list semantics,
+and concurrent-dispatch serialization. They reuse the real
+``echo_mcp_server.py`` stdio subprocess.
+
+Coverage map:
+
+* **U18** — ``test_disallowed_tool_returns_error_envelope_without_invoking_server``.
+* **U20a** — ``test_name_collision_raises_value_error``.
+* **U20b** — ``test_name_prefix_disambiguates_colliding_servers``.
+* **U21** — ``test_concurrent_dispatch_is_serialized_by_lock``.
+"""
+
+from __future__ import annotations
+
+import asyncio
+import sys
+from pathlib import Path
+
+import pytest
+
+from pyrit.tools import (
+ LocalMCPServerSpec,
+ MCPToolBackend,
+ ToolCall,
+)
+
+ECHO_SERVER_SCRIPT = str(Path(__file__).parent / "echo_mcp_server.py")
+
+
+def _spec(*, name_prefix: str | None = None, timeout_seconds: float = 5.0) -> LocalMCPServerSpec:
+ return LocalMCPServerSpec(
+ command=sys.executable,
+ args=(ECHO_SERVER_SCRIPT,),
+ name_prefix=name_prefix,
+ timeout_seconds=timeout_seconds,
+ )
+
+
+def _make_call(name: str, *, call_id: str = "c1", arguments: dict | None = None) -> ToolCall:
+ return ToolCall(call_id=call_id, name=name, arguments=arguments or {})
+
+
+async def test_backend_aggregates_schemas_across_servers() -> None:
+ """Schemas from every connected server show up in :attr:`schemas`."""
+ backend = MCPToolBackend(servers=[_spec()])
+ async with backend:
+ names = {s["name"] for s in backend.schemas}
+ assert names == {"echo", "add", "reverse", "slow_echo"}
+
+
+async def test_dispatch_routes_to_correct_server() -> None:
+ """A :class:`ToolCall` is routed to the server that registered the name."""
+ backend = MCPToolBackend(servers=[_spec()])
+ async with backend:
+ envelope = await backend.dispatch_async(_make_call("echo", arguments={"text": "routed"}))
+ assert envelope["is_error"] is False
+ assert envelope["content"] == "routed"
+
+
+async def test_name_collision_raises_value_error() -> None:
+ """Two servers exposing the same tool name without prefixes raise."""
+ backend = MCPToolBackend(servers=[_spec(), _spec()])
+ with pytest.raises(ValueError, match="duplicate tool name"):
+ await backend.__aenter__()
+ # __aexit__ is the cleanup path; __aenter__ failing leaves nothing to clean.
+
+
+async def test_name_prefix_disambiguates_colliding_servers() -> None:
+ """Setting :attr:`LocalMCPServerSpec.name_prefix` disambiguates duplicates."""
+ backend = MCPToolBackend(
+ servers=[
+ _spec(name_prefix="a_"),
+ _spec(name_prefix="b_"),
+ ],
+ )
+ async with backend:
+ names = {s["name"] for s in backend.schemas}
+ assert "a_echo" in names
+ assert "b_echo" in names
+ envelope = await backend.dispatch_async(_make_call("a_echo", arguments={"text": "alpha"}))
+ assert envelope["content"] == "alpha"
+ envelope_b = await backend.dispatch_async(_make_call("b_echo", arguments={"text": "beta"}))
+ assert envelope_b["content"] == "beta"
+
+
+async def test_disallowed_tool_returns_error_envelope_without_invoking_server() -> None:
+ """U18: allowed_tools blocks both schema advertisement AND dispatch."""
+ backend = MCPToolBackend(servers=[_spec()], allowed_tools=["echo"])
+ async with backend:
+ advertised = {s["name"] for s in backend.schemas}
+ assert advertised == {"echo"} # add/reverse/slow_echo are filtered out.
+
+ envelope = await backend.dispatch_async(_make_call("add", arguments={"a": 1, "b": 2}))
+ assert envelope["is_error"] is True
+ assert envelope["error"] == "tool_not_allowed"
+ assert envelope["tool"] == "add"
+ assert envelope["allowed_tools"] == ["echo"]
+
+
+async def test_unknown_tool_returns_error_envelope() -> None:
+ """A call to a name no connected server exposes returns an error envelope."""
+ backend = MCPToolBackend(servers=[_spec()])
+ async with backend:
+ envelope = await backend.dispatch_async(_make_call("never_registered"))
+ assert envelope["is_error"] is True
+ assert envelope["error"] == "tool_not_registered"
+ assert envelope["tool"] == "never_registered"
+
+
+async def test_concurrent_dispatch_is_serialized_by_lock() -> None:
+ """U21: two coroutines dispatching against the same backend do not interleave.
+
+ The slow_echo tool sleeps server-side; without the lock the two
+ dispatches would issue overlapping JSON-RPC frames over the same
+ stdio pipe. With the lock they run back-to-back. We assert both
+ return successfully — interleaved frames would surface as protocol
+ errors or wrong content.
+ """
+ backend = MCPToolBackend(servers=[_spec(timeout_seconds=10.0)])
+ async with backend:
+ results = await asyncio.gather(
+ backend.dispatch_async(_make_call("slow_echo", arguments={"text": "A", "delay_ms": 50})),
+ backend.dispatch_async(_make_call("slow_echo", arguments={"text": "B", "delay_ms": 50})),
+ )
+ assert all(not r["is_error"] for r in results)
+ assert {r["content"] for r in results} == {"A", "B"}
+
+
+async def test_dispatch_all_sequential_async_preserves_order() -> None:
+ """Bulk dispatch returns (call, envelope) pairs in declaration order."""
+ backend = MCPToolBackend(servers=[_spec()])
+ calls = [
+ _make_call("echo", call_id="c1", arguments={"text": "first"}),
+ _make_call("echo", call_id="c2", arguments={"text": "second"}),
+ _make_call("echo", call_id="c3", arguments={"text": "third"}),
+ ]
+ async with backend:
+ results = await backend.dispatch_all_sequential_async(calls)
+ assert [c.call_id for c, _ in results] == ["c1", "c2", "c3"]
+ assert [r["content"] for _, r in results] == ["first", "second", "third"]
diff --git a/tests/unit/tools/test_mcp_client.py b/tests/unit/tools/test_mcp_client.py
new file mode 100644
index 0000000000..de08b479b7
--- /dev/null
+++ b/tests/unit/tools/test_mcp_client.py
@@ -0,0 +1,162 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
+"""
+Unit tests for :class:`pyrit.tools.MCPClient` and the
+:class:`pyrit.tools.MCPServerSpec` union.
+
+Coverage map:
+
+* **U10** — ``test_real_subprocess_dispatch_returns_text_content``,
+ ``test_sequential_dispatch_against_real_server``.
+* **U14** — ``test_connect_async_populates_schemas_via_tools_list``.
+* **U17** — ``test_dispatch_timeout_returns_error_envelope``.
+* **U20** — ``test_remote_mcp_server_spec_raises_not_implemented``,
+ ``test_docker_mcp_server_spec_raises_not_implemented``.
+
+These tests spawn the real ``tests/unit/tools/echo_mcp_server.py``
+subprocess via ``mcp.client.stdio.stdio_client``; they exercise the
+full handshake → ``tools/list`` → ``tools/call`` round trip. The
+purpose is to verify that ``MCPClient`` is a thin, correct facade
+over the SDK rather than to re-test the SDK itself.
+"""
+
+from __future__ import annotations
+
+import dataclasses
+import sys
+from pathlib import Path
+
+import pytest
+
+from pyrit.tools import (
+ DockerMCPServerSpec,
+ LocalMCPServerSpec,
+ MCPClient,
+ RemoteMCPServerSpec,
+ ToolCall,
+)
+
+ECHO_SERVER_SCRIPT = str(Path(__file__).parent / "echo_mcp_server.py")
+
+
+def _local_spec(*, timeout_seconds: float = 5.0) -> LocalMCPServerSpec:
+ """Build a :class:`LocalMCPServerSpec` that spawns ``echo_mcp_server.py``."""
+ return LocalMCPServerSpec(
+ command=sys.executable,
+ args=(ECHO_SERVER_SCRIPT,),
+ timeout_seconds=timeout_seconds,
+ )
+
+
+def _make_call(name: str, *, call_id: str = "c1", arguments: dict | None = None) -> ToolCall:
+ return ToolCall(call_id=call_id, name=name, arguments=arguments or {})
+
+
+async def test_real_subprocess_dispatch_returns_text_content() -> None:
+ """U10: dispatching a single tool call returns the echo server's text response."""
+ client = MCPClient(spec=_local_spec())
+ async with client:
+ envelope = await client.dispatch_async(_make_call("echo", arguments={"text": "hi"}))
+ assert envelope["is_error"] is False
+ assert envelope["content"] == "hi"
+
+
+async def test_sequential_dispatch_against_real_server() -> None:
+ """U10: multiple sequential calls round-trip through the same session."""
+ client = MCPClient(spec=_local_spec())
+ async with client:
+ envelopes = [
+ await client.dispatch_async(_make_call("echo", arguments={"text": "first"})),
+ await client.dispatch_async(_make_call("add", arguments={"a": 2, "b": 3})),
+ await client.dispatch_async(_make_call("reverse", arguments={"text": "abc"})),
+ ]
+ contents = [e["content"] for e in envelopes]
+ assert contents == ["first", "5", "cba"]
+
+
+async def test_connect_async_populates_schemas_via_tools_list() -> None:
+ """U14: schemas are discovered via tools/list during connect_async."""
+ client = MCPClient(spec=_local_spec())
+ async with client:
+ schemas = client.schemas
+ names = {s["name"] for s in schemas}
+ assert names == {"echo", "add", "reverse", "slow_echo"}
+ echo_schema = next(s for s in schemas if s["name"] == "echo")
+ assert "parameters" in echo_schema
+ assert echo_schema["parameters"]["properties"]["text"]["type"] == "string"
+
+
+async def test_dispatch_timeout_returns_error_envelope() -> None:
+ """U17: a tool call that exceeds the spec's timeout produces an error envelope."""
+ client = MCPClient(spec=_local_spec(timeout_seconds=0.05))
+ async with client:
+ envelope = await client.dispatch_async(
+ _make_call("slow_echo", arguments={"text": "late", "delay_ms": 500}),
+ )
+ assert envelope["is_error"] is True
+ assert envelope["error"] == "tool_timeout"
+ assert envelope["tool"] == "slow_echo"
+
+
+async def test_dispatch_async_returns_error_envelope_on_unknown_tool() -> None:
+ """Server-side errors (unknown tool name) surface as is_error envelopes."""
+ client = MCPClient(spec=_local_spec())
+ async with client:
+ envelope = await client.dispatch_async(_make_call("nonexistent_tool"))
+ assert envelope["is_error"] is True
+ assert envelope["tool"] == "nonexistent_tool"
+
+
+def test_remote_mcp_server_spec_is_frozen_dataclass() -> None:
+ """U20: RemoteMCPServerSpec exists in the type system as a frozen dataclass."""
+ spec = RemoteMCPServerSpec(url="https://example.com/mcp")
+ assert spec.url == "https://example.com/mcp"
+ with pytest.raises(dataclasses.FrozenInstanceError):
+ spec.url = "other" # type: ignore[misc]
+
+
+async def test_remote_mcp_server_spec_raises_not_implemented() -> None:
+ """U20: connecting to a RemoteMCPServerSpec raises NotImplementedError."""
+ client = MCPClient(spec=RemoteMCPServerSpec(url="https://example.com/mcp"))
+ with pytest.raises(NotImplementedError, match="follow-up PR"):
+ await client.connect_async()
+
+
+def test_docker_mcp_server_spec_dataclass_fields() -> None:
+ """U20: DockerMCPServerSpec carries the fields the sandbox PR will consume."""
+ spec = DockerMCPServerSpec(image="pyrit-sandbox:base")
+ assert spec.image == "pyrit-sandbox:base"
+ assert spec.network_profile == "none"
+ assert spec.name_prefix is None
+ assert spec.timeout_seconds == 30.0
+
+
+async def test_docker_mcp_server_spec_raises_not_implemented() -> None:
+ """U20: connecting to a DockerMCPServerSpec raises NotImplementedError."""
+ client = MCPClient(spec=DockerMCPServerSpec(image="pyrit-sandbox:base"))
+ with pytest.raises(NotImplementedError, match="follow-up PR"):
+ await client.connect_async()
+
+
+async def test_dispatch_before_connect_raises_runtime_error() -> None:
+ """Calling dispatch_async before connect_async is a programmer error."""
+ client = MCPClient(spec=_local_spec())
+ with pytest.raises(RuntimeError, match="not connected"):
+ await client.dispatch_async(_make_call("echo", arguments={"text": "hi"}))
+
+
+async def test_close_async_is_idempotent() -> None:
+ """Calling close_async twice (or before connect) does not raise."""
+ client = MCPClient(spec=_local_spec())
+ await client.close_async() # before connect — no-op.
+ await client.connect_async()
+ await client.close_async()
+ await client.close_async() # double-close — no-op.
+
+
+async def test_local_mcp_server_spec_is_frozen() -> None:
+ """LocalMCPServerSpec is a frozen dataclass."""
+ spec = LocalMCPServerSpec(command="python", args=("a.py",))
+ with pytest.raises(dataclasses.FrozenInstanceError):
+ spec.command = "other" # type: ignore[misc]
diff --git a/tests/unit/tools/test_prompt_target_tool_loop.py b/tests/unit/tools/test_prompt_target_tool_loop.py
new file mode 100644
index 0000000000..518304111f
--- /dev/null
+++ b/tests/unit/tools/test_prompt_target_tool_loop.py
@@ -0,0 +1,276 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
+"""
+Unit tests for ``@tool_loop`` wired into ``PromptTarget.send_prompt_async``.
+
+The base class decorates ``send_prompt_async`` with ``@final @tool_loop``,
+exposes ``_tool_parser`` and ``_tool_schemas()`` as default no-op hooks,
+and ``TargetConfiguration`` carries the ``tool_event_policy`` and
+``tool_backend`` kwargs the decorator consults.
+
+These tests use the production ``_get_normalized_conversation_async`` path
+(memory round-trip through :class:`SQLiteMemory` via ``patch_central_database``)
+to exercise the wrapper end-to-end. They cover:
+
+- U1: decorator order (validate + normalize happen exactly once, then the loop)
+- U2 (DB-end half): produced ``tool`` message has one ``function_call_output``
+ piece per dispatched call, in declaration order
+- U8: DB inserts user, asst_with_fc, tool, asst_final in that order
+- U9: DB roles + data_types match the canonical envelope
+- U11: targets without a policy short-circuit (no wrapper behavior change)
+
+Tests for capability flag wiring + ``TargetConfiguration`` construction
+validation live in :mod:`tests.unit.tools.test_tool_event_policy`.
+"""
+
+from __future__ import annotations
+
+import json
+from collections import deque
+from typing import TYPE_CHECKING, Any
+
+import pytest
+
+from pyrit.prompt_target.common.prompt_target import PromptTarget
+from pyrit.prompt_target.common.target_capabilities import TargetCapabilities
+from pyrit.prompt_target.common.target_configuration import TargetConfiguration
+from pyrit.tools import ToolCallParser, ToolEventBehavior, ToolEventPolicy
+
+from .conftest import (
+ _CanonicalEnvelopeParser,
+ _make_assistant_function_call_message,
+ _make_assistant_text_message,
+ _make_user_message,
+ _RecordingToolBackend,
+)
+
+if TYPE_CHECKING:
+ from pyrit.models import Message
+
+
+class _ProductionShapedTarget(PromptTarget):
+ """
+ Minimal :class:`PromptTarget` that uses the *real* base-class
+ ``_get_normalized_conversation_async`` (memory round-trip + normalization
+ pipeline) instead of the conftest stub. Drives the production wrapper
+ end-to-end so DB-insert-order assertions can run against the real
+ :class:`CentralMemory` instance set up by ``patch_central_database``.
+ """
+
+ def __init__(
+ self,
+ *,
+ scripted_responses: list[Message],
+ policy: ToolEventPolicy | None,
+ backend: Any,
+ parser: ToolCallParser | None,
+ ) -> None:
+ caps = TargetCapabilities(
+ supports_multi_turn=True,
+ supports_multi_message_pieces=True,
+ supports_tool_use=policy is not None,
+ )
+ config = TargetConfiguration(
+ capabilities=caps,
+ tool_event_policy=policy,
+ tool_backend=backend,
+ )
+ super().__init__(custom_configuration=config)
+ self._scripted: deque[Message] = deque(scripted_responses)
+ self.call_count: int = 0
+ self._parser_instance = parser
+
+ @property
+ def _tool_parser(self) -> ToolCallParser | None:
+ return self._parser_instance
+
+ async def _send_prompt_to_target_async(self, *, normalized_conversation: list[Message]) -> list[Message]:
+ self.call_count += 1
+ if not self._scripted:
+ raise AssertionError(f"Target ran out of scripted responses on iteration {self.call_count}.")
+ response = self._scripted.popleft()
+ conversation_id = normalized_conversation[-1].message_pieces[0].conversation_id
+ for piece in response.message_pieces:
+ piece.conversation_id = conversation_id
+ return [response]
+
+
+@pytest.fixture
+def make_production_target(patch_central_database):
+ def _factory(
+ *,
+ scripted_responses: list[Message],
+ policy: ToolEventPolicy | None = None,
+ backend: Any = None,
+ parser: ToolCallParser | None = None,
+ ) -> _ProductionShapedTarget:
+ effective_parser = parser
+ if effective_parser is None and policy is not None:
+ effective_parser = _CanonicalEnvelopeParser()
+ return _ProductionShapedTarget(
+ scripted_responses=scripted_responses,
+ policy=policy,
+ backend=backend,
+ parser=effective_parser,
+ )
+
+ return _factory
+
+
+@pytest.fixture
+def execute_policy_fixture():
+ return ToolEventPolicy(behavior=ToolEventBehavior.EXECUTE, max_tool_iterations=5)
+
+
+class TestToolLoopWiredIntoBaseClass:
+ """Verifies ``@tool_loop`` runs on every ``send_prompt_async`` call."""
+
+ async def test_decorator_passthrough_when_no_policy(self, make_production_target):
+ """U11 -- target without a policy behaves like a single-pass ``send_prompt_async``."""
+ target = make_production_target(
+ scripted_responses=[_make_assistant_text_message("plain")],
+ policy=None,
+ )
+
+ responses = await target.send_prompt_async(message=_make_user_message("hi"))
+
+ assert target.call_count == 1
+ assert len(responses) == 1
+ assert responses[0].message_pieces[0].original_value == "plain"
+
+ async def test_tool_loop_order_after_normalize_before_memory(self, make_production_target, execute_policy_fixture):
+ """U1 -- validate + normalize happen exactly once before the loop iterates."""
+ backend = _RecordingToolBackend(scripted_results=[{"result": "echoed"}])
+ target = make_production_target(
+ scripted_responses=[
+ _make_assistant_function_call_message(calls=[("c1", "echo", {"text": "x"})]),
+ _make_assistant_text_message("done"),
+ ],
+ policy=execute_policy_fixture,
+ backend=backend,
+ )
+
+ responses = await target.send_prompt_async(message=_make_user_message("please echo"))
+
+ assert target.call_count == 2
+ assert [c.name for c in backend.recorded_calls] == ["echo"]
+ assert len(responses) == 3
+ assert responses[0].message_pieces[0].original_value_data_type == "function_call"
+ assert responses[1].message_pieces[0].original_value_data_type == "function_call_output"
+ assert responses[2].message_pieces[0].original_value_data_type == "text"
+
+ async def test_tool_message_has_one_function_call_output_piece_per_call(
+ self, make_production_target, execute_policy_fixture
+ ):
+ """U2 DB-end half -- one tool Message, N pieces, one per dispatched call."""
+ backend = _RecordingToolBackend(scripted_results=[{"r": 1}, {"r": 2}])
+ target = make_production_target(
+ scripted_responses=[
+ _make_assistant_function_call_message(
+ calls=[("c1", "echo", {"text": "a"}), ("c2", "echo", {"text": "b"})]
+ ),
+ _make_assistant_text_message("done"),
+ ],
+ policy=execute_policy_fixture,
+ backend=backend,
+ )
+
+ responses = await target.send_prompt_async(message=_make_user_message("two calls please"))
+
+ tool_msg = responses[1]
+ assert len(tool_msg.message_pieces) == 2
+ call_ids_in_order = [json.loads(p.original_value)["call_id"] for p in tool_msg.message_pieces]
+ assert call_ids_in_order == ["c1", "c2"]
+ assert all(p.original_value_data_type == "function_call_output" for p in tool_msg.message_pieces)
+ assert all(p.api_role == "tool" for p in tool_msg.message_pieces)
+
+
+class TestDbTranscriptAfterToolLoop:
+ """
+ DB-level assertions that exercise the production memory pipeline.
+
+ These tests rely on the wrapper writing the user message + every assistant
+ + tool message produced during the loop back to ``CentralMemory``, in
+ declaration order. Whether that write happens *inside* the wrapper or via
+ the caller (the prompt normalizer) is an implementation detail; the
+ invariant is the wrapper returns the full chain so the caller can persist
+ in order.
+ """
+
+ async def test_db_insert_order_user_then_asst_fc_then_tool_then_final_asst(
+ self, make_production_target, execute_policy_fixture
+ ):
+ """U8 -- after a complete tool round, the wrapper's return order is canonical."""
+ backend = _RecordingToolBackend(scripted_results=[{"result": "echoed"}])
+ target = make_production_target(
+ scripted_responses=[
+ _make_assistant_function_call_message(calls=[("c1", "echo", {"text": "x"})]),
+ _make_assistant_text_message("done"),
+ ],
+ policy=execute_policy_fixture,
+ backend=backend,
+ )
+
+ responses = await target.send_prompt_async(message=_make_user_message("please echo"))
+
+ data_types_in_order = [r.message_pieces[0].original_value_data_type for r in responses]
+ assert data_types_in_order == ["function_call", "function_call_output", "text"]
+
+ async def test_db_roles_and_data_types_match_canonical_envelope(
+ self, make_production_target, execute_policy_fixture
+ ):
+ """U9 -- roles and data_types match the canonical envelope contract."""
+ backend = _RecordingToolBackend(scripted_results=[{"result": "echoed"}])
+ target = make_production_target(
+ scripted_responses=[
+ _make_assistant_function_call_message(calls=[("c1", "echo", {"text": "x"})]),
+ _make_assistant_text_message("done"),
+ ],
+ policy=execute_policy_fixture,
+ backend=backend,
+ )
+
+ responses = await target.send_prompt_async(message=_make_user_message("please echo"))
+
+ asst_fc, tool_msg, asst_final = responses
+ # function_call from the assistant
+ assert asst_fc.message_pieces[0].api_role == "assistant"
+ assert asst_fc.message_pieces[0].original_value_data_type == "function_call"
+ envelope = json.loads(asst_fc.message_pieces[0].original_value)
+ assert envelope["type"] == "function_call"
+ assert envelope["call_id"] == "c1"
+ assert envelope["name"] == "echo"
+ # function_call_output from the tool
+ assert tool_msg.message_pieces[0].api_role == "tool"
+ assert tool_msg.message_pieces[0].original_value_data_type == "function_call_output"
+ tool_envelope = json.loads(tool_msg.message_pieces[0].original_value)
+ assert tool_envelope["type"] == "function_call_output"
+ assert tool_envelope["call_id"] == "c1"
+ # Final assistant text
+ assert asst_final.message_pieces[0].api_role == "assistant"
+ assert asst_final.message_pieces[0].original_value_data_type == "text"
+
+
+class TestFinalAndAbstractMethodContract:
+ """
+ Asserts the base-class shape that ``@tool_loop`` requires but does not
+ exercise via end-to-end runs: ``_tool_parser`` defaults to ``None``,
+ ``_tool_schemas`` defaults to ``[]``.
+ """
+
+ def test_default_tool_parser_is_none(self, make_production_target):
+ target = make_production_target(
+ scripted_responses=[_make_assistant_text_message("plain")],
+ policy=None,
+ )
+ # Subclass overrides only when the test caller passes a parser. With
+ # no policy + no parser, the override returns None.
+ assert target._tool_parser is None
+
+ def test_default_tool_schemas_is_empty_list(self, make_production_target):
+ target = make_production_target(
+ scripted_responses=[_make_assistant_text_message("plain")],
+ policy=None,
+ )
+ assert target._tool_schemas() == []
diff --git a/tests/unit/tools/test_tool_event_policy.py b/tests/unit/tools/test_tool_event_policy.py
new file mode 100644
index 0000000000..20cf283e07
--- /dev/null
+++ b/tests/unit/tools/test_tool_event_policy.py
@@ -0,0 +1,119 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
+"""
+Unit tests for the wiring between :class:`TargetCapabilities.supports_tool_use`,
+:class:`TargetConfiguration.tool_event_policy` /
+:class:`TargetConfiguration.tool_backend`, and the
+:func:`pyrit.tools.tool_loop` decorator that lives on
+:class:`PromptTarget.send_prompt_async`.
+
+These tests are the §7 U7 row plus the construction-time validator.
+They assert the *capability flag* axis only -- that targets which declare
+``supports_tool_use=True`` and configure a policy + backend route through
+the loop, that targets without a policy short-circuit, and that the
+``tool_backend``-without-capability misconfiguration raises at construction.
+
+End-to-end ordering against the production memory pipeline (U1, U8, U9) is
+exercised separately in ``tests/unit/prompt_target/common/test_prompt_target_tool_loop.py``.
+"""
+
+from __future__ import annotations
+
+import pytest
+
+from pyrit.prompt_target.common.target_capabilities import TargetCapabilities
+from pyrit.prompt_target.common.target_configuration import TargetConfiguration
+from pyrit.tools import LocalToolBackend, ToolEventBehavior, ToolEventPolicy
+
+from .conftest import (
+ _make_assistant_function_call_message,
+ _make_assistant_text_message,
+ _make_user_message,
+)
+
+
+class TestSupportsToolUseCapabilityFlag:
+ """Asserts the new ``supports_tool_use`` field on :class:`TargetCapabilities`."""
+
+ def test_default_is_false(self):
+ caps = TargetCapabilities()
+ assert caps.supports_tool_use is False
+
+ def test_explicit_true(self):
+ caps = TargetCapabilities(supports_tool_use=True)
+ assert caps.supports_tool_use is True
+
+
+class TestTargetConfigurationToolFields:
+ """Asserts the new ``tool_event_policy`` / ``tool_backend`` kwargs."""
+
+ def test_defaults_are_none(self):
+ caps = TargetCapabilities(supports_tool_use=True)
+ config = TargetConfiguration(capabilities=caps)
+ assert config.tool_event_policy is None
+ assert config.tool_backend is None
+
+ def test_explicit_policy_and_backend(self):
+ caps = TargetCapabilities(supports_tool_use=True)
+ backend = LocalToolBackend(callables={}, schemas=[])
+ policy = ToolEventPolicy(behavior=ToolEventBehavior.EXECUTE)
+ config = TargetConfiguration(
+ capabilities=caps,
+ tool_event_policy=policy,
+ tool_backend=backend,
+ )
+ assert config.tool_event_policy is policy
+ assert config.tool_backend is backend
+
+ def test_tool_backend_without_capability_raises(self):
+ caps = TargetCapabilities(supports_tool_use=False)
+ backend = LocalToolBackend(callables={}, schemas=[])
+ with pytest.raises(ValueError, match="supports_tool_use"):
+ TargetConfiguration(capabilities=caps, tool_backend=backend)
+
+ def test_tool_event_policy_without_backend_is_allowed(self):
+ """``RAISE`` / ``RETURN_RAW`` policies do not require a backend."""
+ caps = TargetCapabilities(supports_tool_use=True)
+ policy = ToolEventPolicy(behavior=ToolEventBehavior.RAISE)
+ config = TargetConfiguration(capabilities=caps, tool_event_policy=policy)
+ assert config.tool_event_policy is policy
+ assert config.tool_backend is None
+
+
+class TestCapabilityFlagWiringIntoToolLoop:
+ """
+ U7 -- verify the wrapper dispatches only when the target declares
+ ``supports_tool_use`` AND a policy is configured.
+ """
+
+ async def test_target_with_tool_use_capability_uses_tool_loop(
+ self, make_fake_target, recording_backend, execute_policy
+ ):
+ backend = recording_backend()
+ target = make_fake_target(
+ scripted_responses=[
+ _make_assistant_function_call_message(calls=[("c1", "echo", {"text": "hi"})]),
+ _make_assistant_text_message("done"),
+ ],
+ policy=execute_policy(),
+ backend=backend,
+ )
+
+ responses = await target.send_prompt_async(message=_make_user_message("please call echo"))
+
+ assert target.call_count == 2, "Decorator should have iterated twice (call + final)."
+ assert [c.name for c in backend.recorded_calls] == ["echo"]
+ assert len(responses) == 3, "user expects asst_fc, tool_msg, asst_final."
+
+ async def test_target_without_tool_use_capability_skips_dispatch(self, make_fake_target):
+ target = make_fake_target(
+ scripted_responses=[_make_assistant_text_message("plain response, no tool call")],
+ policy=None,
+ backend=None,
+ )
+
+ responses = await target.send_prompt_async(message=_make_user_message("hello"))
+
+ assert target.call_count == 1
+ assert len(responses) == 1
diff --git a/tests/unit/tools/test_tool_loop_decorator.py b/tests/unit/tools/test_tool_loop_decorator.py
new file mode 100644
index 0000000000..89805fbb1e
--- /dev/null
+++ b/tests/unit/tools/test_tool_loop_decorator.py
@@ -0,0 +1,289 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
+"""
+Unit tests for :func:`pyrit.tools.tool_loop`.
+
+Coverage map:
+
+* **U2** (partial; full-DB end) — ``test_loop_returns_full_chain_in_order``
+* **U3** — ``test_loop_exits_on_first_response_when_no_tool_calls``,
+ ``test_loops_until_no_pending_tool_call``
+* **U4** — ``test_raises_after_max_tool_iterations``,
+ ``test_partial_conversation_attached_to_limit_exception``
+* **U12** — ``test_policy_raise_includes_partial_conversation``
+* **U13** — ``test_policy_return_raw_does_not_dispatch``
+* **U16** — ``test_multi_call_per_turn_dispatched_sequentially_in_order``
+
+Also covers two additional decorator concerns required by the rubber-duck
+review (§10): EXECUTE policy with no backend raises with a partial
+conversation, and the normalized conversation grows correctly across
+iterations (decorator does not re-normalize each turn).
+"""
+
+from __future__ import annotations
+
+import json
+
+import pytest
+
+from pyrit.exceptions import ToolCallLoopLimitExceeded, ToolCallNotSupported
+from pyrit.tools import ToolEventBehavior, ToolEventPolicy
+
+from .conftest import (
+ _make_assistant_function_call_message,
+ _make_assistant_text_message,
+ _make_user_message,
+)
+
+
+@pytest.mark.usefixtures("patch_central_database")
+class TestToolLoopDecoratorBasics:
+ """Loop entry/exit semantics: no tool calls, single round trip, multi-round."""
+
+ async def test_loop_exits_on_first_response_when_no_tool_calls(self, make_fake_target, execute_policy):
+ target = make_fake_target(
+ scripted_responses=[_make_assistant_text_message("done")],
+ policy=execute_policy(),
+ )
+
+ responses = await target.send_prompt_async(message=_make_user_message("hi"))
+
+ assert len(responses) == 1
+ assert responses[0].get_value() == "done"
+ assert target.call_count == 1
+
+ async def test_loops_until_no_pending_tool_call(self, make_fake_target, execute_policy, recording_backend):
+ backend = recording_backend(scripted_results=[{"ok": True}, {"ok": True}])
+ target = make_fake_target(
+ scripted_responses=[
+ _make_assistant_function_call_message(calls=[("c1", "tool_a", {"x": 1})]),
+ _make_assistant_function_call_message(calls=[("c2", "tool_a", {"x": 2})]),
+ _make_assistant_text_message("done"),
+ ],
+ policy=execute_policy(max_tool_iterations=5),
+ backend=backend,
+ )
+
+ responses = await target.send_prompt_async(message=_make_user_message("hi"))
+
+ # Two model-tool round trips and one final assistant message.
+ assert target.call_count == 3
+ # Returned chain: fc1, tool1, fc2, tool2, final-text → 5 messages total.
+ assert len(responses) == 5
+ assert [r.message_pieces[0].original_value_data_type for r in responses] == [
+ "function_call",
+ "function_call_output",
+ "function_call",
+ "function_call_output",
+ "text",
+ ]
+ assert len(backend.recorded_calls) == 2
+ assert [c.call_id for c in backend.recorded_calls] == ["c1", "c2"]
+
+
+@pytest.mark.usefixtures("patch_central_database")
+class TestToolLoopMessageShape:
+ """U2 — assistant_fc → tool → final_assistant ordering and identity."""
+
+ async def test_loop_returns_full_chain_in_order(self, make_fake_target, execute_policy, recording_backend):
+ backend = recording_backend(scripted_results=[{"weather": "sunny"}])
+ fc_msg = _make_assistant_function_call_message(calls=[("call_abc", "get_weather", {"city": "Seattle"})])
+ final_msg = _make_assistant_text_message("It is sunny in Seattle.")
+
+ target = make_fake_target(
+ scripted_responses=[fc_msg, final_msg],
+ policy=execute_policy(),
+ backend=backend,
+ )
+
+ responses = await target.send_prompt_async(message=_make_user_message("weather?"))
+
+ assert len(responses) == 3
+ # 1) assistant with function_call (identity preserved)
+ assert responses[0] is fc_msg
+ # 2) tool message with exactly one function_call_output piece carrying call_id
+ tool_msg = responses[1]
+ assert len(tool_msg.message_pieces) == 1
+ tool_piece = tool_msg.message_pieces[0]
+ assert tool_piece.api_role == "tool"
+ assert tool_piece.original_value_data_type == "function_call_output"
+ envelope = json.loads(tool_piece.original_value)
+ assert envelope["type"] == "function_call_output"
+ assert envelope["call_id"] == "call_abc"
+ # The tool result is JSON-serialized into the "output" field.
+ assert json.loads(envelope["output"]) == {"weather": "sunny"}
+ # 3) final assistant text (identity preserved)
+ assert responses[2] is final_msg
+
+
+@pytest.mark.usefixtures("patch_central_database")
+class TestToolLoopIterationLimits:
+ """U4 — iteration cap raises and carries the partial chain."""
+
+ async def test_raises_after_max_tool_iterations(self, make_fake_target, execute_policy, recording_backend):
+ # Model never stops asking for tools.
+ backend = recording_backend(scripted_results=[{"ok": True}] * 3)
+ target = make_fake_target(
+ scripted_responses=[
+ _make_assistant_function_call_message(calls=[(f"c{i}", "loop_tool", {})]) for i in range(3)
+ ],
+ policy=execute_policy(max_tool_iterations=2),
+ backend=backend,
+ )
+
+ with pytest.raises(ToolCallLoopLimitExceeded, match="max_tool_iterations=2"):
+ await target.send_prompt_async(message=_make_user_message("hi"))
+
+ # Exactly max_tool_iterations model calls made before raising.
+ assert target.call_count == 2
+
+ async def test_partial_conversation_attached_to_limit_exception(
+ self, make_fake_target, execute_policy, recording_backend
+ ):
+ backend = recording_backend(scripted_results=[{"ok": True}] * 2)
+ target = make_fake_target(
+ scripted_responses=[
+ _make_assistant_function_call_message(calls=[(f"c{i}", "loop_tool", {})]) for i in range(2)
+ ],
+ policy=execute_policy(max_tool_iterations=2),
+ backend=backend,
+ )
+
+ with pytest.raises(ToolCallLoopLimitExceeded) as excinfo:
+ await target.send_prompt_async(message=_make_user_message("hi"))
+
+ partial = excinfo.value.partial_conversation
+ # 2 iterations × (assistant_fc + tool_msg) = 4 messages, all in order.
+ assert len(partial) == 4
+ assert [m.message_pieces[0].original_value_data_type for m in partial] == [
+ "function_call",
+ "function_call_output",
+ "function_call",
+ "function_call_output",
+ ]
+
+
+@pytest.mark.usefixtures("patch_central_database")
+class TestToolEventPolicyBehaviors:
+ """U12, U13 — non-EXECUTE behaviors short-circuit dispatch."""
+
+ async def test_policy_raise_includes_partial_conversation(self, make_fake_target, recording_backend):
+ backend = recording_backend(scripted_results=[{"ok": True}])
+ fc_msg = _make_assistant_function_call_message(calls=[("c1", "danger", {})])
+ target = make_fake_target(
+ scripted_responses=[fc_msg],
+ policy=ToolEventPolicy(behavior=ToolEventBehavior.RAISE),
+ backend=backend,
+ )
+
+ with pytest.raises(ToolCallNotSupported, match="RAISE") as excinfo:
+ await target.send_prompt_async(message=_make_user_message("hi"))
+
+ partial = excinfo.value.partial_conversation
+ # Partial contains the offending assistant turn; no tool dispatch occurred.
+ assert partial == [fc_msg]
+ assert backend.recorded_calls == []
+ assert target.call_count == 1
+
+ async def test_policy_return_raw_does_not_dispatch(self, make_fake_target, recording_backend):
+ backend = recording_backend(scripted_results=[{"ok": True}])
+ fc_msg = _make_assistant_function_call_message(calls=[("c1", "danger", {})])
+ target = make_fake_target(
+ scripted_responses=[fc_msg],
+ policy=ToolEventPolicy(behavior=ToolEventBehavior.RETURN_RAW),
+ backend=backend,
+ )
+
+ responses = await target.send_prompt_async(message=_make_user_message("hi"))
+
+ assert responses == [fc_msg]
+ assert backend.recorded_calls == []
+ assert target.call_count == 1
+
+
+@pytest.mark.usefixtures("patch_central_database")
+class TestToolLoopMultiCallPerTurn:
+ """U16 — multi-call turns dispatch sequentially in declaration order."""
+
+ async def test_multi_call_per_turn_dispatched_sequentially_in_order(
+ self, make_fake_target, execute_policy, recording_backend
+ ):
+ backend = recording_backend(scripted_results=[{"a": 1}, {"b": 2}, {"c": 3}])
+ multi_fc = _make_assistant_function_call_message(
+ calls=[
+ ("c_alpha", "tool_alpha", {"k": "v1"}),
+ ("c_beta", "tool_beta", {"k": "v2"}),
+ ("c_gamma", "tool_gamma", {"k": "v3"}),
+ ]
+ )
+ target = make_fake_target(
+ scripted_responses=[multi_fc, _make_assistant_text_message("ok")],
+ policy=execute_policy(),
+ backend=backend,
+ )
+
+ responses = await target.send_prompt_async(message=_make_user_message("multi"))
+
+ # Three calls dispatched in declaration order, recorded ids match.
+ assert [c.call_id for c in backend.recorded_calls] == ["c_alpha", "c_beta", "c_gamma"]
+ assert [c.name for c in backend.recorded_calls] == ["tool_alpha", "tool_beta", "tool_gamma"]
+ # One tool message after the multi-call assistant turn, carrying three
+ # function_call_output pieces in declaration order with the right call_ids.
+ tool_msg = responses[1]
+ assert len(tool_msg.message_pieces) == 3
+ envelopes = [json.loads(p.original_value) for p in tool_msg.message_pieces]
+ assert [e["call_id"] for e in envelopes] == ["c_alpha", "c_beta", "c_gamma"]
+ assert all(p.original_value_data_type == "function_call_output" for p in tool_msg.message_pieces)
+
+
+@pytest.mark.usefixtures("patch_central_database")
+class TestToolLoopMisconfiguration:
+ """EXECUTE policy with no backend must fail loudly and carry the partial chain."""
+
+ async def test_execute_without_backend_raises_with_partial(self, make_fake_target, execute_policy):
+ fc_msg = _make_assistant_function_call_message(calls=[("c1", "no_reg", {})])
+ target = make_fake_target(
+ scripted_responses=[fc_msg],
+ policy=execute_policy(),
+ backend=None,
+ )
+
+ with pytest.raises(ToolCallNotSupported, match="tool_backend") as excinfo:
+ await target.send_prompt_async(message=_make_user_message("hi"))
+
+ assert excinfo.value.partial_conversation == [fc_msg]
+
+
+@pytest.mark.usefixtures("patch_central_database")
+class TestToolLoopConversationGrowth:
+ """The decorator must extend (not re-normalize) the conversation each round."""
+
+ async def test_normalized_conversation_grows_each_iteration(
+ self, make_fake_target, execute_policy, recording_backend
+ ):
+ backend = recording_backend(scripted_results=[{"r1": 1}, {"r2": 2}])
+ target = make_fake_target(
+ scripted_responses=[
+ _make_assistant_function_call_message(calls=[("c1", "t", {})]),
+ _make_assistant_function_call_message(calls=[("c2", "t", {})]),
+ _make_assistant_text_message("done"),
+ ],
+ policy=execute_policy(),
+ backend=backend,
+ )
+
+ await target.send_prompt_async(message=_make_user_message("hi"))
+
+ # Three protected-method calls; each subsequent call sees the prior
+ # assistant_fc + tool_msg appended (the decorator must NOT re-normalize).
+ seen = target.normalized_conversations_seen
+ assert len(seen) == 3
+ # call 1: just the user message
+ assert len(seen[0]) == 1
+ # call 2: user + assistant_fc(c1) + tool_msg
+ assert len(seen[1]) == 3
+ assert seen[1][1].message_pieces[0].original_value_data_type == "function_call"
+ assert seen[1][2].message_pieces[0].original_value_data_type == "function_call_output"
+ # call 3: user + assistant_fc(c1) + tool_msg + assistant_fc(c2) + tool_msg
+ assert len(seen[2]) == 5
diff --git a/uv.lock b/uv.lock
index 18a65c10f0..2adcb96aab 100644
--- a/uv.lock
+++ b/uv.lock
@@ -2122,6 +2122,15 @@ http2 = [
{ name = "h2" },
]
+[[package]]
+name = "httpx-sse"
+version = "0.4.3"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/0f/4c/751061ffa58615a32c31b2d82e8482be8dd4a89154f003147acee90f2be9/httpx_sse-0.4.3.tar.gz", hash = "sha256:9b1ed0127459a66014aec3c56bebd93da3c1bc8bb6618c8082039a44889a755d", size = 15943, upload-time = "2025-10-10T21:48:22.271Z" }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/d2/fd/6668e5aec43ab844de6fc74927e155a3b37bf40d7c3790e49fc0406b6578/httpx_sse-0.4.3-py3-none-any.whl", hash = "sha256:0ac1c9fe3c0afad2e0ebb25a934a59f4c7823b60792691f779fad2c5568830fc", size = 8960, upload-time = "2025-10-10T21:48:21.158Z" },
+]
+
[[package]]
name = "huggingface-hub"
version = "1.13.0"
@@ -3193,6 +3202,31 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/af/33/ee4519fa02ed11a94aef9559552f3b17bb863f2ecfe1a35dc7f548cde231/matplotlib_inline-0.2.1-py3-none-any.whl", hash = "sha256:d56ce5156ba6085e00a9d54fead6ed29a9c47e215cd1bba2e976ef39f5710a76", size = 9516, upload-time = "2025-10-23T09:00:20.675Z" },
]
+[[package]]
+name = "mcp"
+version = "1.27.1"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+ { name = "anyio" },
+ { name = "httpx" },
+ { name = "httpx-sse" },
+ { name = "jsonschema" },
+ { name = "pydantic" },
+ { name = "pydantic-settings" },
+ { name = "pyjwt", extra = ["crypto"] },
+ { name = "python-multipart" },
+ { name = "pywin32", marker = "sys_platform == 'win32'" },
+ { name = "sse-starlette" },
+ { name = "starlette" },
+ { name = "typing-extensions" },
+ { name = "typing-inspection" },
+ { name = "uvicorn", marker = "sys_platform != 'emscripten'" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/38/83/d1efe7c2980d8a3afa476f4e3d42d53dd54c0ab94c27bee5d755b45c8b73/mcp-1.27.1.tar.gz", hash = "sha256:0f47e1820f8f8f941466b39749eb1d1839a04caddca2bc60e9d46e8a99914924", size = 608458, upload-time = "2026-05-08T16:50:12.601Z" }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/fd/73/42d9596facebdb533b7f0b86c1b0364ef350d1f8ba78b1052e8a58b48b65/mcp-1.27.1-py3-none-any.whl", hash = "sha256:1af3c4203b329430fde7a87b4fcb6392a041f5cb851fd68fc674016ab4e7c06f", size = 216260, upload-time = "2026-05-08T16:50:10.547Z" },
+]
+
[[package]]
name = "mdit-py-plugins"
version = "0.5.0"
@@ -5015,6 +5049,20 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/36/c7/cfc8e811f061c841d7990b0201912c3556bfeb99cdcb7ed24adc8d6f8704/pydantic_core-2.41.5-pp311-pypy311_pp73-win_amd64.whl", hash = "sha256:56121965f7a4dc965bff783d70b907ddf3d57f6eba29b6d2e5dabfaf07799c51", size = 2145302, upload-time = "2025-11-04T13:43:46.64Z" },
]
+[[package]]
+name = "pydantic-settings"
+version = "2.14.1"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+ { name = "pydantic" },
+ { name = "python-dotenv" },
+ { name = "typing-inspection" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/07/60/1d1e59c9c90d54591469ada7d268251f71c24bdb765f1a8a832cee8c6653/pydantic_settings-2.14.1.tar.gz", hash = "sha256:e874d3bec7e787b0c9958277956ed9b4dd5de6a80e162188fdaff7c5e26fd5fa", size = 235551, upload-time = "2026-05-08T13:40:06.542Z" }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/ae/8d/f1af3832f5e6eb13ba94ee809e72b8ecb5eef226d27ee0bef7d963d943c7/pydantic_settings-2.14.1-py3-none-any.whl", hash = "sha256:6e3c7edfd8277687cdc598f56e5cff0e9bfff0910a3749deaa8d4401c3a2b9de", size = 60964, upload-time = "2026-05-08T13:40:04.958Z" },
+]
+
[[package]]
name = "pydash"
version = "8.0.5"
@@ -5171,6 +5219,7 @@ dependencies = [
{ name = "fastapi" },
{ name = "httpx", extra = ["http2"] },
{ name = "jinja2" },
+ { name = "mcp" },
{ name = "numpy", version = "2.2.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.14'" },
{ name = "numpy", version = "2.4.4", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.14'" },
{ name = "openai" },
@@ -5308,6 +5357,7 @@ requires-dist = [
{ name = "ipykernel", marker = "extra == 'all'", specifier = ">=6.29.5" },
{ name = "jinja2", specifier = ">=3.1.6" },
{ name = "jupyter", marker = "extra == 'all'", specifier = ">=1.1.1" },
+ { name = "mcp", specifier = ">=1.0,<2" },
{ name = "ml-collections", marker = "extra == 'all'", specifier = ">=1.1.0" },
{ name = "ml-collections", marker = "extra == 'gcg'", specifier = ">=1.1.0" },
{ name = "numpy", marker = "python_full_version < '3.14'", specifier = ">=1.26.0" },
@@ -5501,6 +5551,15 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/51/e5/fecf13f06e5e5f67e8837d777d1bc43fac0ed2b77a676804df5c34744727/python_json_logger-4.0.0-py3-none-any.whl", hash = "sha256:af09c9daf6a813aa4cc7180395f50f2a9e5fa056034c9953aec92e381c5ba1e2", size = 15548, upload-time = "2025-10-06T04:15:17.553Z" },
]
+[[package]]
+name = "python-multipart"
+version = "0.0.29"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/4e/fe/70bd71a6738b09a0bdf6480ca6436b167469ca4578b2a0efbe390b4b0e70/python_multipart-0.0.29.tar.gz", hash = "sha256:643e93849196645e2dbdd81a0f8829a23123ad7f797a84a364c6fb3563f18904", size = 45678, upload-time = "2026-05-17T17:29:47.654Z" }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/8f/cb/769cfc37177252872a45a71f3fbdde9d51b471a3f3c14bfe95dde3407386/python_multipart-0.0.29-py3-none-any.whl", hash = "sha256:2ddcc971cef266225f54f552d8fa10bcfbb1f14446caec199060daac59ff2d69", size = 29640, upload-time = "2026-05-17T17:29:45.69Z" },
+]
+
[[package]]
name = "pytz"
version = "2025.2"
@@ -5510,6 +5569,28 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/81/c4/34e93fe5f5429d7570ec1fa436f1986fb1f00c3e0f43a589fe2bbcd22c3f/pytz-2025.2-py2.py3-none-any.whl", hash = "sha256:5ddf76296dd8c44c26eb8f4b6f35488f3ccbf6fbbd7adee0b7262d43f0ec2f00", size = 509225, upload-time = "2025-03-25T02:24:58.468Z" },
]
+[[package]]
+name = "pywin32"
+version = "311"
+source = { registry = "https://pypi.org/simple" }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/7b/40/44efbb0dfbd33aca6a6483191dae0716070ed99e2ecb0c53683f400a0b4f/pywin32-311-cp310-cp310-win32.whl", hash = "sha256:d03ff496d2a0cd4a5893504789d4a15399133fe82517455e78bad62efbb7f0a3", size = 8760432, upload-time = "2025-07-14T20:13:05.9Z" },
+ { url = "https://files.pythonhosted.org/packages/5e/bf/360243b1e953bd254a82f12653974be395ba880e7ec23e3731d9f73921cc/pywin32-311-cp310-cp310-win_amd64.whl", hash = "sha256:797c2772017851984b97180b0bebe4b620bb86328e8a884bb626156295a63b3b", size = 9590103, upload-time = "2025-07-14T20:13:07.698Z" },
+ { url = "https://files.pythonhosted.org/packages/57/38/d290720e6f138086fb3d5ffe0b6caa019a791dd57866940c82e4eeaf2012/pywin32-311-cp310-cp310-win_arm64.whl", hash = "sha256:0502d1facf1fed4839a9a51ccbcc63d952cf318f78ffc00a7e78528ac27d7a2b", size = 8778557, upload-time = "2025-07-14T20:13:11.11Z" },
+ { url = "https://files.pythonhosted.org/packages/7c/af/449a6a91e5d6db51420875c54f6aff7c97a86a3b13a0b4f1a5c13b988de3/pywin32-311-cp311-cp311-win32.whl", hash = "sha256:184eb5e436dea364dcd3d2316d577d625c0351bf237c4e9a5fabbcfa5a58b151", size = 8697031, upload-time = "2025-07-14T20:13:13.266Z" },
+ { url = "https://files.pythonhosted.org/packages/51/8f/9bb81dd5bb77d22243d33c8397f09377056d5c687aa6d4042bea7fbf8364/pywin32-311-cp311-cp311-win_amd64.whl", hash = "sha256:3ce80b34b22b17ccbd937a6e78e7225d80c52f5ab9940fe0506a1a16f3dab503", size = 9508308, upload-time = "2025-07-14T20:13:15.147Z" },
+ { url = "https://files.pythonhosted.org/packages/44/7b/9c2ab54f74a138c491aba1b1cd0795ba61f144c711daea84a88b63dc0f6c/pywin32-311-cp311-cp311-win_arm64.whl", hash = "sha256:a733f1388e1a842abb67ffa8e7aad0e70ac519e09b0f6a784e65a136ec7cefd2", size = 8703930, upload-time = "2025-07-14T20:13:16.945Z" },
+ { url = "https://files.pythonhosted.org/packages/e7/ab/01ea1943d4eba0f850c3c61e78e8dd59757ff815ff3ccd0a84de5f541f42/pywin32-311-cp312-cp312-win32.whl", hash = "sha256:750ec6e621af2b948540032557b10a2d43b0cee2ae9758c54154d711cc852d31", size = 8706543, upload-time = "2025-07-14T20:13:20.765Z" },
+ { url = "https://files.pythonhosted.org/packages/d1/a8/a0e8d07d4d051ec7502cd58b291ec98dcc0c3fff027caad0470b72cfcc2f/pywin32-311-cp312-cp312-win_amd64.whl", hash = "sha256:b8c095edad5c211ff31c05223658e71bf7116daa0ecf3ad85f3201ea3190d067", size = 9495040, upload-time = "2025-07-14T20:13:22.543Z" },
+ { url = "https://files.pythonhosted.org/packages/ba/3a/2ae996277b4b50f17d61f0603efd8253cb2d79cc7ae159468007b586396d/pywin32-311-cp312-cp312-win_arm64.whl", hash = "sha256:e286f46a9a39c4a18b319c28f59b61de793654af2f395c102b4f819e584b5852", size = 8710102, upload-time = "2025-07-14T20:13:24.682Z" },
+ { url = "https://files.pythonhosted.org/packages/a5/be/3fd5de0979fcb3994bfee0d65ed8ca9506a8a1260651b86174f6a86f52b3/pywin32-311-cp313-cp313-win32.whl", hash = "sha256:f95ba5a847cba10dd8c4d8fefa9f2a6cf283b8b88ed6178fa8a6c1ab16054d0d", size = 8705700, upload-time = "2025-07-14T20:13:26.471Z" },
+ { url = "https://files.pythonhosted.org/packages/e3/28/e0a1909523c6890208295a29e05c2adb2126364e289826c0a8bc7297bd5c/pywin32-311-cp313-cp313-win_amd64.whl", hash = "sha256:718a38f7e5b058e76aee1c56ddd06908116d35147e133427e59a3983f703a20d", size = 9494700, upload-time = "2025-07-14T20:13:28.243Z" },
+ { url = "https://files.pythonhosted.org/packages/04/bf/90339ac0f55726dce7d794e6d79a18a91265bdf3aa70b6b9ca52f35e022a/pywin32-311-cp313-cp313-win_arm64.whl", hash = "sha256:7b4075d959648406202d92a2310cb990fea19b535c7f4a78d3f5e10b926eeb8a", size = 8709318, upload-time = "2025-07-14T20:13:30.348Z" },
+ { url = "https://files.pythonhosted.org/packages/c9/31/097f2e132c4f16d99a22bfb777e0fd88bd8e1c634304e102f313af69ace5/pywin32-311-cp314-cp314-win32.whl", hash = "sha256:b7a2c10b93f8986666d0c803ee19b5990885872a7de910fc460f9b0c2fbf92ee", size = 8840714, upload-time = "2025-07-14T20:13:32.449Z" },
+ { url = "https://files.pythonhosted.org/packages/90/4b/07c77d8ba0e01349358082713400435347df8426208171ce297da32c313d/pywin32-311-cp314-cp314-win_amd64.whl", hash = "sha256:3aca44c046bd2ed8c90de9cb8427f581c479e594e99b5c0bb19b29c10fd6cb87", size = 9656800, upload-time = "2025-07-14T20:13:34.312Z" },
+ { url = "https://files.pythonhosted.org/packages/c0/d2/21af5c535501a7233e734b8af901574572da66fcc254cb35d0609c9080dd/pywin32-311-cp314-cp314-win_arm64.whl", hash = "sha256:a508e2d9025764a8270f93111a970e1d0fbfc33f4153b388bb649b7eec4f9b42", size = 8932540, upload-time = "2025-07-14T20:13:36.379Z" },
+]
+
[[package]]
name = "pywinpty"
version = "3.0.2"
@@ -6564,6 +6645,19 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/2a/ae/57d1d7af907e20c077e113e0e4976f87b82c0a415403d99284a262229dd0/srsly-2.5.3-cp314-cp314t-win_arm64.whl", hash = "sha256:d822083fe26ec6728bd8c273ac121fc4ab3864a0fdf0cf0ff3efb188fcd209ed", size = 650229, upload-time = "2026-03-23T11:56:46.148Z" },
]
+[[package]]
+name = "sse-starlette"
+version = "3.4.4"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+ { name = "anyio" },
+ { name = "starlette" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/f7/2b/58abc2d1fd397e7dde08e947e05c884d8ef2f78d5e2588c17a12d42d6994/sse_starlette-3.4.4.tar.gz", hash = "sha256:07e0fa0460138baf25cdd5fb28683472c3995dc1642225191b3832d62526bcb0", size = 31819, upload-time = "2026-05-12T17:37:17.019Z" }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/dc/67/805710444ea8cc75fbf70b920ed431a560c4bf9c57f7d5a3117213189399/sse_starlette-3.4.4-py3-none-any.whl", hash = "sha256:3f4dd50d8aed2771a091f3a83000323fc3844541c16b4fe585ae2420cc6df973", size = 16514, upload-time = "2026-05-12T17:37:15.601Z" },
+]
+
[[package]]
name = "stack-data"
version = "0.6.3"