From 7a664c580602b8748c226a294211aee968a45389 Mon Sep 17 00:00:00 2001 From: Victor Valbuena Date: Tue, 26 May 2026 13:59:40 -0700 Subject: [PATCH 01/17] Add mcp Python SDK dependency Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- pyproject.toml | 1 + uv.lock | 94 ++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 95 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index bdc563d000..222ed1c86f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -45,6 +45,7 @@ dependencies = [ "fastapi>=0.115.0", "httpx[http2]>=0.27.2", "jinja2>=3.1.6", + "mcp>=1.0,<2", "numpy>=1.26.0; python_version < '3.14'", "numpy>=2.3.0; python_version >= '3.14'", "openai>=2.2.0", diff --git a/uv.lock b/uv.lock index dec865a0c3..40b177484d 100644 --- a/uv.lock +++ b/uv.lock @@ -2121,6 +2121,15 @@ http2 = [ { name = "h2" }, ] +[[package]] +name = "httpx-sse" +version = "0.4.3" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/0f/4c/751061ffa58615a32c31b2d82e8482be8dd4a89154f003147acee90f2be9/httpx_sse-0.4.3.tar.gz", hash = "sha256:9b1ed0127459a66014aec3c56bebd93da3c1bc8bb6618c8082039a44889a755d", size = 15943, upload-time = "2025-10-10T21:48:22.271Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d2/fd/6668e5aec43ab844de6fc74927e155a3b37bf40d7c3790e49fc0406b6578/httpx_sse-0.4.3-py3-none-any.whl", hash = "sha256:0ac1c9fe3c0afad2e0ebb25a934a59f4c7823b60792691f779fad2c5568830fc", size = 8960, upload-time = "2025-10-10T21:48:21.158Z" }, +] + [[package]] name = "huggingface-hub" version = "1.13.0" @@ -3192,6 +3201,31 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/af/33/ee4519fa02ed11a94aef9559552f3b17bb863f2ecfe1a35dc7f548cde231/matplotlib_inline-0.2.1-py3-none-any.whl", hash = "sha256:d56ce5156ba6085e00a9d54fead6ed29a9c47e215cd1bba2e976ef39f5710a76", size = 9516, upload-time = "2025-10-23T09:00:20.675Z" }, ] +[[package]] +name = "mcp" +version = "1.27.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "anyio" }, + { name = "httpx" }, + { name = "httpx-sse" }, + { name = "jsonschema" }, + { name = "pydantic" }, + { name = "pydantic-settings" }, + { name = "pyjwt", extra = ["crypto"] }, + { name = "python-multipart" }, + { name = "pywin32", marker = "sys_platform == 'win32'" }, + { name = "sse-starlette" }, + { name = "starlette" }, + { name = "typing-extensions" }, + { name = "typing-inspection" }, + { name = "uvicorn", marker = "sys_platform != 'emscripten'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/38/83/d1efe7c2980d8a3afa476f4e3d42d53dd54c0ab94c27bee5d755b45c8b73/mcp-1.27.1.tar.gz", hash = "sha256:0f47e1820f8f8f941466b39749eb1d1839a04caddca2bc60e9d46e8a99914924", size = 608458, upload-time = "2026-05-08T16:50:12.601Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/fd/73/42d9596facebdb533b7f0b86c1b0364ef350d1f8ba78b1052e8a58b48b65/mcp-1.27.1-py3-none-any.whl", hash = "sha256:1af3c4203b329430fde7a87b4fcb6392a041f5cb851fd68fc674016ab4e7c06f", size = 216260, upload-time = "2026-05-08T16:50:10.547Z" }, +] + [[package]] name = "mdit-py-plugins" version = "0.5.0" @@ -5014,6 +5048,20 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/36/c7/cfc8e811f061c841d7990b0201912c3556bfeb99cdcb7ed24adc8d6f8704/pydantic_core-2.41.5-pp311-pypy311_pp73-win_amd64.whl", hash = "sha256:56121965f7a4dc965bff783d70b907ddf3d57f6eba29b6d2e5dabfaf07799c51", size = 2145302, upload-time = "2025-11-04T13:43:46.64Z" }, ] +[[package]] +name = "pydantic-settings" +version = "2.14.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pydantic" }, + { name = "python-dotenv" }, + { name = "typing-inspection" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/07/60/1d1e59c9c90d54591469ada7d268251f71c24bdb765f1a8a832cee8c6653/pydantic_settings-2.14.1.tar.gz", hash = "sha256:e874d3bec7e787b0c9958277956ed9b4dd5de6a80e162188fdaff7c5e26fd5fa", size = 235551, upload-time = "2026-05-08T13:40:06.542Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ae/8d/f1af3832f5e6eb13ba94ee809e72b8ecb5eef226d27ee0bef7d963d943c7/pydantic_settings-2.14.1-py3-none-any.whl", hash = "sha256:6e3c7edfd8277687cdc598f56e5cff0e9bfff0910a3749deaa8d4401c3a2b9de", size = 60964, upload-time = "2026-05-08T13:40:04.958Z" }, +] + [[package]] name = "pydash" version = "8.0.5" @@ -5170,6 +5218,7 @@ dependencies = [ { name = "fastapi" }, { name = "httpx", extra = ["http2"] }, { name = "jinja2" }, + { name = "mcp" }, { name = "numpy", version = "2.2.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.14'" }, { name = "numpy", version = "2.4.4", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.14'" }, { name = "openai" }, @@ -5306,6 +5355,7 @@ requires-dist = [ { name = "ipykernel", marker = "extra == 'all'", specifier = ">=6.29.5" }, { name = "jinja2", specifier = ">=3.1.6" }, { name = "jupyter", marker = "extra == 'all'", specifier = ">=1.1.1" }, + { name = "mcp", specifier = ">=1.0,<2" }, { name = "ml-collections", marker = "extra == 'all'", specifier = ">=1.1.0" }, { name = "ml-collections", marker = "extra == 'gcg'", specifier = ">=1.1.0" }, { name = "numpy", marker = "python_full_version < '3.14'", specifier = ">=1.26.0" }, @@ -5498,6 +5548,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/51/e5/fecf13f06e5e5f67e8837d777d1bc43fac0ed2b77a676804df5c34744727/python_json_logger-4.0.0-py3-none-any.whl", hash = "sha256:af09c9daf6a813aa4cc7180395f50f2a9e5fa056034c9953aec92e381c5ba1e2", size = 15548, upload-time = "2025-10-06T04:15:17.553Z" }, ] +[[package]] +name = "python-multipart" +version = "0.0.29" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/4e/fe/70bd71a6738b09a0bdf6480ca6436b167469ca4578b2a0efbe390b4b0e70/python_multipart-0.0.29.tar.gz", hash = "sha256:643e93849196645e2dbdd81a0f8829a23123ad7f797a84a364c6fb3563f18904", size = 45678, upload-time = "2026-05-17T17:29:47.654Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/8f/cb/769cfc37177252872a45a71f3fbdde9d51b471a3f3c14bfe95dde3407386/python_multipart-0.0.29-py3-none-any.whl", hash = "sha256:2ddcc971cef266225f54f552d8fa10bcfbb1f14446caec199060daac59ff2d69", size = 29640, upload-time = "2026-05-17T17:29:45.69Z" }, +] + [[package]] name = "pytz" version = "2025.2" @@ -5507,6 +5566,28 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/81/c4/34e93fe5f5429d7570ec1fa436f1986fb1f00c3e0f43a589fe2bbcd22c3f/pytz-2025.2-py2.py3-none-any.whl", hash = "sha256:5ddf76296dd8c44c26eb8f4b6f35488f3ccbf6fbbd7adee0b7262d43f0ec2f00", size = 509225, upload-time = "2025-03-25T02:24:58.468Z" }, ] +[[package]] +name = "pywin32" +version = "311" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/7b/40/44efbb0dfbd33aca6a6483191dae0716070ed99e2ecb0c53683f400a0b4f/pywin32-311-cp310-cp310-win32.whl", hash = "sha256:d03ff496d2a0cd4a5893504789d4a15399133fe82517455e78bad62efbb7f0a3", size = 8760432, upload-time = "2025-07-14T20:13:05.9Z" }, + { url = "https://files.pythonhosted.org/packages/5e/bf/360243b1e953bd254a82f12653974be395ba880e7ec23e3731d9f73921cc/pywin32-311-cp310-cp310-win_amd64.whl", hash = "sha256:797c2772017851984b97180b0bebe4b620bb86328e8a884bb626156295a63b3b", size = 9590103, upload-time = "2025-07-14T20:13:07.698Z" }, + { url = "https://files.pythonhosted.org/packages/57/38/d290720e6f138086fb3d5ffe0b6caa019a791dd57866940c82e4eeaf2012/pywin32-311-cp310-cp310-win_arm64.whl", hash = "sha256:0502d1facf1fed4839a9a51ccbcc63d952cf318f78ffc00a7e78528ac27d7a2b", size = 8778557, upload-time = "2025-07-14T20:13:11.11Z" }, + { url = "https://files.pythonhosted.org/packages/7c/af/449a6a91e5d6db51420875c54f6aff7c97a86a3b13a0b4f1a5c13b988de3/pywin32-311-cp311-cp311-win32.whl", hash = "sha256:184eb5e436dea364dcd3d2316d577d625c0351bf237c4e9a5fabbcfa5a58b151", size = 8697031, upload-time = "2025-07-14T20:13:13.266Z" }, + { url = "https://files.pythonhosted.org/packages/51/8f/9bb81dd5bb77d22243d33c8397f09377056d5c687aa6d4042bea7fbf8364/pywin32-311-cp311-cp311-win_amd64.whl", hash = "sha256:3ce80b34b22b17ccbd937a6e78e7225d80c52f5ab9940fe0506a1a16f3dab503", size = 9508308, upload-time = "2025-07-14T20:13:15.147Z" }, + { url = "https://files.pythonhosted.org/packages/44/7b/9c2ab54f74a138c491aba1b1cd0795ba61f144c711daea84a88b63dc0f6c/pywin32-311-cp311-cp311-win_arm64.whl", hash = "sha256:a733f1388e1a842abb67ffa8e7aad0e70ac519e09b0f6a784e65a136ec7cefd2", size = 8703930, upload-time = "2025-07-14T20:13:16.945Z" }, + { url = "https://files.pythonhosted.org/packages/e7/ab/01ea1943d4eba0f850c3c61e78e8dd59757ff815ff3ccd0a84de5f541f42/pywin32-311-cp312-cp312-win32.whl", hash = "sha256:750ec6e621af2b948540032557b10a2d43b0cee2ae9758c54154d711cc852d31", size = 8706543, upload-time = "2025-07-14T20:13:20.765Z" }, + { url = "https://files.pythonhosted.org/packages/d1/a8/a0e8d07d4d051ec7502cd58b291ec98dcc0c3fff027caad0470b72cfcc2f/pywin32-311-cp312-cp312-win_amd64.whl", hash = "sha256:b8c095edad5c211ff31c05223658e71bf7116daa0ecf3ad85f3201ea3190d067", size = 9495040, upload-time = "2025-07-14T20:13:22.543Z" }, + { url = "https://files.pythonhosted.org/packages/ba/3a/2ae996277b4b50f17d61f0603efd8253cb2d79cc7ae159468007b586396d/pywin32-311-cp312-cp312-win_arm64.whl", hash = "sha256:e286f46a9a39c4a18b319c28f59b61de793654af2f395c102b4f819e584b5852", size = 8710102, upload-time = "2025-07-14T20:13:24.682Z" }, + { url = "https://files.pythonhosted.org/packages/a5/be/3fd5de0979fcb3994bfee0d65ed8ca9506a8a1260651b86174f6a86f52b3/pywin32-311-cp313-cp313-win32.whl", hash = "sha256:f95ba5a847cba10dd8c4d8fefa9f2a6cf283b8b88ed6178fa8a6c1ab16054d0d", size = 8705700, upload-time = "2025-07-14T20:13:26.471Z" }, + { url = "https://files.pythonhosted.org/packages/e3/28/e0a1909523c6890208295a29e05c2adb2126364e289826c0a8bc7297bd5c/pywin32-311-cp313-cp313-win_amd64.whl", hash = "sha256:718a38f7e5b058e76aee1c56ddd06908116d35147e133427e59a3983f703a20d", size = 9494700, upload-time = "2025-07-14T20:13:28.243Z" }, + { url = "https://files.pythonhosted.org/packages/04/bf/90339ac0f55726dce7d794e6d79a18a91265bdf3aa70b6b9ca52f35e022a/pywin32-311-cp313-cp313-win_arm64.whl", hash = "sha256:7b4075d959648406202d92a2310cb990fea19b535c7f4a78d3f5e10b926eeb8a", size = 8709318, upload-time = "2025-07-14T20:13:30.348Z" }, + { url = "https://files.pythonhosted.org/packages/c9/31/097f2e132c4f16d99a22bfb777e0fd88bd8e1c634304e102f313af69ace5/pywin32-311-cp314-cp314-win32.whl", hash = "sha256:b7a2c10b93f8986666d0c803ee19b5990885872a7de910fc460f9b0c2fbf92ee", size = 8840714, upload-time = "2025-07-14T20:13:32.449Z" }, + { url = "https://files.pythonhosted.org/packages/90/4b/07c77d8ba0e01349358082713400435347df8426208171ce297da32c313d/pywin32-311-cp314-cp314-win_amd64.whl", hash = "sha256:3aca44c046bd2ed8c90de9cb8427f581c479e594e99b5c0bb19b29c10fd6cb87", size = 9656800, upload-time = "2025-07-14T20:13:34.312Z" }, + { url = "https://files.pythonhosted.org/packages/c0/d2/21af5c535501a7233e734b8af901574572da66fcc254cb35d0609c9080dd/pywin32-311-cp314-cp314-win_arm64.whl", hash = "sha256:a508e2d9025764a8270f93111a970e1d0fbfc33f4153b388bb649b7eec4f9b42", size = 8932540, upload-time = "2025-07-14T20:13:36.379Z" }, +] + [[package]] name = "pywinpty" version = "3.0.2" @@ -6561,6 +6642,19 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/2a/ae/57d1d7af907e20c077e113e0e4976f87b82c0a415403d99284a262229dd0/srsly-2.5.3-cp314-cp314t-win_arm64.whl", hash = "sha256:d822083fe26ec6728bd8c273ac121fc4ab3864a0fdf0cf0ff3efb188fcd209ed", size = 650229, upload-time = "2026-03-23T11:56:46.148Z" }, ] +[[package]] +name = "sse-starlette" +version = "3.4.4" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "anyio" }, + { name = "starlette" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/f7/2b/58abc2d1fd397e7dde08e947e05c884d8ef2f78d5e2588c17a12d42d6994/sse_starlette-3.4.4.tar.gz", hash = "sha256:07e0fa0460138baf25cdd5fb28683472c3995dc1642225191b3832d62526bcb0", size = 31819, upload-time = "2026-05-12T17:37:17.019Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/dc/67/805710444ea8cc75fbf70b920ed431a560c4bf9c57f7d5a3117213189399/sse_starlette-3.4.4-py3-none-any.whl", hash = "sha256:3f4dd50d8aed2771a091f3a83000323fc3844541c16b4fe585ae2420cc6df973", size = 16514, upload-time = "2026-05-12T17:37:15.601Z" }, +] + [[package]] name = "stack-data" version = "0.6.3" From c7d65d3d633f7357a2b7d3323b8b470db4b5af36 Mon Sep 17 00:00:00 2001 From: Victor Valbuena Date: Tue, 26 May 2026 14:34:13 -0700 Subject: [PATCH 02/17] Add tools/ package with tool_loop decorator and CallableToolBackend Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- pyrit/exceptions/__init__.py | 4 + pyrit/exceptions/exception_classes.py | 64 ++++ pyrit/tools/__init__.py | 58 ++++ pyrit/tools/backend.py | 83 +++++ pyrit/tools/callable_backend.py | 134 ++++++++ pyrit/tools/models.py | 241 ++++++++++++++ pyrit/tools/parsers.py | 55 ++++ tests/unit/tools/__init__.py | 2 + tests/unit/tools/conftest.py | 307 ++++++++++++++++++ tests/unit/tools/echo_mcp_server.py | 57 ++++ .../unit/tools/test_callable_tool_backend.py | 180 ++++++++++ tests/unit/tools/test_tool_loop_decorator.py | 289 +++++++++++++++++ 12 files changed, 1474 insertions(+) create mode 100644 pyrit/tools/__init__.py create mode 100644 pyrit/tools/backend.py create mode 100644 pyrit/tools/callable_backend.py create mode 100644 pyrit/tools/models.py create mode 100644 pyrit/tools/parsers.py create mode 100644 tests/unit/tools/__init__.py create mode 100644 tests/unit/tools/conftest.py create mode 100644 tests/unit/tools/echo_mcp_server.py create mode 100644 tests/unit/tools/test_callable_tool_backend.py create mode 100644 tests/unit/tools/test_tool_loop_decorator.py diff --git a/pyrit/exceptions/__init__.py b/pyrit/exceptions/__init__.py index abd42de031..9baea33c1a 100644 --- a/pyrit/exceptions/__init__.py +++ b/pyrit/exceptions/__init__.py @@ -10,6 +10,8 @@ MissingPromptPlaceholderException, PyritException, RateLimitException, + ToolCallLoopLimitExceeded, + ToolCallNotSupported, get_retry_max_num_attempts, handle_bad_request_exception, pyrit_custom_result_retry, @@ -59,4 +61,6 @@ "set_execution_context", "set_retry_collector", "execution_context", + "ToolCallLoopLimitExceeded", + "ToolCallNotSupported", ] diff --git a/pyrit/exceptions/exception_classes.py b/pyrit/exceptions/exception_classes.py index b2fc55440b..5d0014aa3d 100644 --- a/pyrit/exceptions/exception_classes.py +++ b/pyrit/exceptions/exception_classes.py @@ -233,6 +233,70 @@ def __init__(self, *, message: str = "No prompt placeholder") -> None: super().__init__(message=message) +class ToolCallNotSupported(PyritException): + """ + Raised when a target produces a tool call that the configured + :class:`~pyrit.tools.ToolEventPolicy` does not permit to execute + (``ToolEventBehavior.RAISE``, or ``EXECUTE`` without a backend). + + The ``partial_conversation`` attribute carries every message produced + up to and including the assistant turn that contained the offending + tool call(s). Consumers can inspect it to log the surfaced tool-use + attempt. + """ + + def __init__( + self, + *, + message: str = "Tool call not supported by configured policy.", + partial_conversation: Optional[list["Message"]] = None, + ) -> None: + """ + Initialize the exception. + + Args: + message (str): Human-readable error description. + partial_conversation (Optional[list[Message]]): Messages produced by + the target up to (and including) the assistant turn that + contained the disallowed tool call(s). + """ + super().__init__(status_code=400, message=message) + self.partial_conversation: list[Message] = ( + list(partial_conversation) if partial_conversation is not None else [] + ) + + +class ToolCallLoopLimitExceeded(PyritException): + """ + Raised when the tool-use loop runs for more than + ``ToolEventPolicy.max_tool_iterations`` iterations without the model + producing a stop response. + + The ``partial_conversation`` attribute carries every message produced + across all completed iterations. Consumers can inspect it to debug + runaway agentic behavior. + """ + + def __init__( + self, + *, + message: str = "Tool loop exceeded max_tool_iterations without a stop response.", + partial_conversation: Optional[list["Message"]] = None, + ) -> None: + """ + Initialize the exception. + + Args: + message (str): Human-readable error description. + partial_conversation (Optional[list[Message]]): Messages produced by + the target across every completed iteration of the tool loop. + """ + super().__init__(status_code=400, message=message) + self.partial_conversation: list[Message] = ( + list(partial_conversation) if partial_conversation is not None else [] + ) + + def pyrit_custom_result_retry( retry_function: Callable[..., bool], retry_max_num_attempts: Optional[int] = None ) -> Callable[..., Any]: diff --git a/pyrit/tools/__init__.py b/pyrit/tools/__init__.py new file mode 100644 index 0000000000..9830f17758 --- /dev/null +++ b/pyrit/tools/__init__.py @@ -0,0 +1,58 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +Generic tool-use scaffolding for :class:`~pyrit.prompt_target.PromptTarget`. + +This package provides a transport-agnostic tool-calling loop. The +:func:`tool_loop` decorator, when applied to ``send_prompt_async``, runs +the standard PyRIT validate+normalize work once and then repeatedly +re-enters the target's protected ``_send_prompt_to_target_async`` until +the model issues a stop response (or a configured limit is hit). + +A target opts in by declaring two collaborators: + +* ``self._tool_parser`` — a :class:`ToolCallParser` that walks a + response message and extracts pending :class:`ToolCall` instances. +* ``self.configuration.tool_event_policy`` — a :class:`ToolEventPolicy` + whose :class:`ToolEventBehavior` decides whether to ``EXECUTE``, + ``RAISE``, or ``RETURN_RAW`` on each detected call. + +When the policy is ``EXECUTE``, calls are dispatched through +``self.configuration.tool_backend``, an implementation of +:class:`ToolBackend`. :class:`CallableToolBackend` is the pure-Python +backend shipped here; :class:`MCPToolBackend` ships in C3 and proxies +through one or more MCP servers. + +The :class:`ToolBackend` Protocol is intentionally distinct from +:mod:`pyrit.registry` — that namespace is reserved for framework-level +identity registries (``TargetRegistry``, ``ScorerRegistry``) that +register named singletons for CLI lookup, which a per-target tool +dispatch table is not. + +Wiring of ``@tool_loop`` onto :class:`PromptTarget.send_prompt_async` +and of the ``tool_event_policy`` / ``tool_backend`` fields onto +:class:`TargetConfiguration` lands in C4/C5. + +The two exception types the loop raises +(:class:`~pyrit.exceptions.ToolCallNotSupported` and +:class:`~pyrit.exceptions.ToolCallLoopLimitExceeded`) live in +:mod:`pyrit.exceptions` alongside the rest of PyRIT's exception +catalog, so non-tools callers (attacks, normalizers) can import them +without taking a subsystem-level dependency on ``pyrit.tools``. +""" + +from pyrit.tools.backend import ToolBackend +from pyrit.tools.callable_backend import CallableToolBackend +from pyrit.tools.models import ToolCall, ToolEventBehavior, ToolEventPolicy, tool_loop +from pyrit.tools.parsers import ToolCallParser + +__all__ = [ + "CallableToolBackend", + "ToolBackend", + "ToolCall", + "ToolCallParser", + "ToolEventBehavior", + "ToolEventPolicy", + "tool_loop", +] diff --git a/pyrit/tools/backend.py b/pyrit/tools/backend.py new file mode 100644 index 0000000000..3274355d6e --- /dev/null +++ b/pyrit/tools/backend.py @@ -0,0 +1,83 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Protocol, runtime_checkable + +if TYPE_CHECKING: + from pyrit.tools.models import ToolCall + + +@runtime_checkable +class ToolBackend(Protocol): + """ + Protocol for backends that dispatch tool calls produced by a target. + + A :class:`ToolBackend` is a per-target dispatch table — it owns the + ``name -> async callable`` mapping a target uses to execute the tool + calls extracted from a model response. This is intentionally distinct + from :mod:`pyrit.registry`, whose ``Registry`` classes register named + framework singletons (targets, scorers, attacks) for CLI lookup. + + Two concrete implementations ship with PyRIT: + + * :class:`~pyrit.tools.CallableToolBackend` — pure-Python backend + backed by ``async def`` callables. Useful for unit tests and for + embedding tools inside the PyRIT process. + * :class:`pyrit.tools.MCPToolBackend` (lands in C3) — proxies + dispatch through one or more MCP servers. + + The :attr:`schemas` property exposes the JSON-schema descriptors the + target injects into its request body (e.g. ``tools=[...]`` for the + OpenAI APIs). + + :meth:`dispatch_all_sequential_async` is the contract the tool loop + uses: backends that wish to parallelize dispatch should override it. + The default sequencing — one ``await dispatch_async`` per call, in + declaration order — is what every PyRIT backend ships with today. + """ + + @property + def schemas(self) -> list[dict[str, Any]]: + """ + The JSON-schema descriptors for every tool the backend exposes. + + Returns: + list[dict[str, Any]]: One schema per tool, in a target-agnostic + format that concrete targets serialize into their request + body. + """ + ... + + async def dispatch_async(self, call: ToolCall) -> dict[str, Any]: + """ + Execute a single tool call and return the structured result. + + Implementations MUST NOT raise on tool-side failures; they MUST + return an error envelope (e.g. ``{"error": "...", "tool": "..."}``) + so the tool loop can carry the failure back to the model. + + Args: + call (ToolCall): The tool call to dispatch. + + Returns: + dict[str, Any]: The structured tool result. + """ + ... + + async def dispatch_all_sequential_async( + self, + calls: list[ToolCall], + ) -> list[tuple[ToolCall, dict[str, Any]]]: + """ + Dispatch every call in *calls* sequentially, preserving declaration order. + + Args: + calls (list[ToolCall]): The calls to dispatch, in declaration order. + + Returns: + list[tuple[ToolCall, dict[str, Any]]]: ``(call, result)`` pairs, + in the same order as *calls*. + """ + ... diff --git a/pyrit/tools/callable_backend.py b/pyrit/tools/callable_backend.py new file mode 100644 index 0000000000..defbe94ed1 --- /dev/null +++ b/pyrit/tools/callable_backend.py @@ -0,0 +1,134 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from collections.abc import Awaitable, Callable + + from pyrit.tools.models import ToolCall + +logger = logging.getLogger(__name__) + + +class CallableToolBackend: + """ + Pure-Python :class:`~pyrit.tools.ToolBackend` backed by a name -> ``async def`` + mapping. Useful for unit tests and for embedding small tools inside the + PyRIT process without standing up an MCP server. + + The backend dispatches sequentially in declaration order. Tool-side + failures (raised exceptions, missing names, allow-list rejections) + are converted into structured error envelopes so the tool loop can + forward them back to the model as ``function_call_output`` content + rather than aborting the conversation. + """ + + def __init__( + self, + *, + callables: dict[str, Callable[[dict[str, Any]], Awaitable[Any]]], + schemas: list[dict[str, Any]] | None = None, + allowed_tools: set[str] | None = None, + fail_on_missing_function: bool = True, + ) -> None: + """ + Initialize the backend. + + Args: + callables (dict[str, Callable[[dict[str, Any]], Awaitable[Any]]]): + Map from tool name to an ``async def`` that accepts a parsed + arguments dict and returns the tool result. Results are + serialized by the tool loop via :func:`json.dumps`. + schemas (list[dict[str, Any]] | None): JSON-schema descriptors + injected into the target's request body. Defaults to an empty + list when omitted. + allowed_tools (set[str] | None): Optional allow-list of tool + names; calls whose name is not in this set surface as + ``tool_not_allowed`` envelopes without invoking the callable. + Defaults to None (no allow-list; every registered tool is + callable). + fail_on_missing_function (bool): When True (default), an unknown + tool name raises :class:`KeyError`. When False, the backend + returns a ``tool_not_registered`` envelope so the model can + recover. + """ + self._callables = dict(callables) + self._schemas: list[dict[str, Any]] = list(schemas) if schemas is not None else [] + self._allowed_tools = set(allowed_tools) if allowed_tools is not None else None + self._fail_on_missing_function = fail_on_missing_function + + @property + def schemas(self) -> list[dict[str, Any]]: + """The JSON-schema descriptors for the tools in this backend.""" + return list(self._schemas) + + async def dispatch_async(self, call: ToolCall) -> dict[str, Any]: + """ + Dispatch a single tool call. Tool failures are converted into + structured envelopes; only configuration errors (missing tool with + ``fail_on_missing_function=True``) propagate as exceptions. + + Args: + call (ToolCall): The call to dispatch. + + Returns: + dict[str, Any]: The tool's result, or a structured error envelope. + + Raises: + KeyError: When the tool name is not registered and + ``fail_on_missing_function=True``. + """ + if self._allowed_tools is not None and call.name not in self._allowed_tools: + logger.info("Rejecting disallowed tool call: %s", call.name) + return { + "error": "tool_not_allowed", + "tool": call.name, + "allowed_tools": sorted(self._allowed_tools), + } + + fn = self._callables.get(call.name) + if fn is None: + if self._fail_on_missing_function: + raise KeyError(f"Tool '{call.name}' is not registered.") + available = sorted(self._callables.keys()) + logger.warning("Tool '%s' not registered. Available: %s", call.name, available) + return { + "error": "tool_not_registered", + "tool": call.name, + "available_tools": available, + } + + try: + result = await fn(call.arguments) + except Exception as ex: + logger.warning("Tool '%s' raised %s: %s", call.name, type(ex).__name__, ex) + return { + "error": "tool_execution_failed", + "tool": call.name, + "detail": str(ex), + } + return result if isinstance(result, dict) else {"result": result} + + async def dispatch_all_sequential_async( + self, + calls: list[ToolCall], + ) -> list[tuple[ToolCall, dict[str, Any]]]: + """ + Dispatch *calls* sequentially in declaration order. + + Args: + calls (list[ToolCall]): Calls to dispatch. + + Returns: + list[tuple[ToolCall, dict[str, Any]]]: ``(call, result)`` pairs + in the same order as *calls*. + """ + results: list[tuple[ToolCall, dict[str, Any]]] = [] + for call in calls: + result = await self.dispatch_async(call) + results.append((call, result)) + return results diff --git a/pyrit/tools/models.py b/pyrit/tools/models.py new file mode 100644 index 0000000000..0e05f9be9d --- /dev/null +++ b/pyrit/tools/models.py @@ -0,0 +1,241 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +from __future__ import annotations + +import enum +import functools +import json +import logging +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Any + +from pyrit.exceptions import ToolCallLoopLimitExceeded, ToolCallNotSupported +from pyrit.models import Message, MessagePiece + +if TYPE_CHECKING: + from collections.abc import Awaitable, Callable + + from pyrit.tools.backend import ToolBackend + from pyrit.tools.parsers import ToolCallParser + +logger = logging.getLogger(__name__) + + +@dataclass(frozen=True) +class ToolCall: + """ + A parsed tool call extracted from a target response. + + Concrete :class:`~pyrit.tools.ToolCallParser` implementations build + :class:`ToolCall` instances by walking the response message pieces. + The :attr:`raw_envelope` carries the original target-specific dict + (e.g. the function_call JSON section) so dispatchers and observers + can recover provider-specific fields without re-parsing. + + Attributes: + call_id (str): The provider-issued call identifier; must round-trip + into the matching ``function_call_output`` piece. + name (str): The tool name to dispatch. + arguments (dict[str, Any]): The parsed JSON arguments. + raw_envelope (dict[str, Any]): The original provider envelope. + """ + + call_id: str + name: str + arguments: dict[str, Any] + raw_envelope: dict[str, Any] = field(default_factory=dict) + + +class ToolEventBehavior(enum.Enum): + """ + What the tool loop should do when a target response contains a + pending tool call. + + Values: + EXECUTE: Dispatch the call via ``configuration.tool_backend`` + and re-enter the target with the tool output appended. + This is the standard agentic loop behavior. + RAISE: Raise :class:`~pyrit.exceptions.ToolCallNotSupported` with + the partial conversation attached. Useful for red-team + attacks that want to observe attempted tool use without + allowing execution. + RETURN_RAW: Return the assistant response containing the tool + call as-is, without dispatching. Useful when a caller wants + to inspect tool calls in-band (e.g. a scorer that scores + attempted tool use). + """ + + EXECUTE = "execute" + RAISE = "raise" + RETURN_RAW = "return_raw" + + +@dataclass(frozen=True) +class ToolEventPolicy: + """ + Per-target configuration that controls how the tool loop responds + to a pending tool call from the model. + + Attributes: + behavior (ToolEventBehavior): What to do on each detected tool call. + max_tool_iterations (int): Maximum number of model<->tool round-trips + before the loop raises :class:`ToolCallLoopLimitExceeded`. Each + iteration is one ``_send_prompt_to_target_async`` call. + """ + + behavior: ToolEventBehavior + max_tool_iterations: int = 5 + + +def _build_function_call_output_message( + *, + reference_piece: MessagePiece, + outputs: list[tuple[ToolCall, Any]], +) -> Message: + """ + Build the canonical ``tool`` message produced after dispatching one or more + tool calls in a single iteration. + + The returned :class:`Message` contains one + :class:`MessagePiece` per ``(call, result)`` pair, in declaration order. + Every piece has ``role="tool"`` and ``original_value_data_type="function_call_output"``, + with the JSON envelope ``{"type": "function_call_output", "call_id": ..., "output": ...}``. + + Lineage metadata (conversation_id, identifiers) is copied from + *reference_piece* — typically the first piece of the assistant message + that issued the tool calls — so the resulting message stays inside the + correct conversation. + + Args: + reference_piece (MessagePiece): Piece whose lineage metadata is + copied onto every output piece. Pass the first piece of the + assistant message that produced the calls. + outputs (list[tuple[ToolCall, Any]]): ``(call, result)`` pairs in + declaration order. *result* is serialized via :func:`json.dumps` + unless it is already a string. + + Returns: + Message: One message carrying every function_call_output piece. + """ + pieces: list[MessagePiece] = [] + for call, result in outputs: + output_str = result if isinstance(result, str) else json.dumps(result, separators=(",", ":")) + envelope = json.dumps( + {"type": "function_call_output", "call_id": call.call_id, "output": output_str}, + separators=(",", ":"), + ) + pieces.append( + MessagePiece( + role="tool", + original_value=envelope, + original_value_data_type="function_call_output", + conversation_id=reference_piece.conversation_id, + prompt_target_identifier=reference_piece.prompt_target_identifier, + attack_identifier=reference_piece.attack_identifier, + ) + ) + return Message(message_pieces=pieces, skip_validation=True) + + +def tool_loop( + method: Callable[..., Awaitable[list[Message]]], +) -> Callable[..., Awaitable[list[Message]]]: + """ + Wrap a :class:`~pyrit.prompt_target.PromptTarget`-style + ``send_prompt_async`` to run an agentic tool-use loop. + + When the target's ``configuration.tool_event_policy`` is ``None`` the + wrapper is a no-op — the wrapped method runs unchanged. When a policy + is configured, the wrapper replaces the method body with the loop: + + 1. Validate and normalize the incoming message exactly once. + 2. Repeatedly call ``self._send_prompt_to_target_async`` with the + growing conversation. + 3. After each call, parse the last response via ``self._tool_parser``. + Exit on empty parse (model issued a stop response). + 4. On a non-empty parse, branch on ``policy.behavior``: + ``RAISE`` raises :class:`ToolCallNotSupported`; ``RETURN_RAW`` + returns the chain as-is; ``EXECUTE`` dispatches the calls via + ``configuration.tool_backend`` and appends the tool message. + 5. Raise :class:`ToolCallLoopLimitExceeded` if the loop runs past + ``policy.max_tool_iterations`` without the model stopping. + + The decorator deliberately knows nothing about MCP, OpenAI, or any + specific transport. The two collaborators it requires — + ``self._tool_parser`` and ``self.configuration.tool_backend`` — are + plain protocols (:class:`ToolCallParser`, :class:`ToolBackend`). + + Args: + method (Callable): The async method to wrap. Must have the + ``async def f(self, *, message: Message) -> list[Message]`` + signature of :meth:`PromptTarget.send_prompt_async`. + + Returns: + Callable: The wrapped method. + """ + + @functools.wraps(method) + async def wrapper(self: Any, *, message: Message) -> list[Message]: + policy: ToolEventPolicy | None = getattr(self.configuration, "tool_event_policy", None) + if policy is None: + return await method(self, message=message) + + message.validate() + normalized_conversation = await self._get_normalized_conversation_async(message=message) + if not normalized_conversation: + raise ValueError("Normalization pipeline returned an empty conversation. Cannot send an empty request.") + self._validate_request(normalized_conversation=normalized_conversation) + + parser: ToolCallParser | None = getattr(self, "_tool_parser", None) + backend: ToolBackend | None = getattr(self.configuration, "tool_backend", None) + max_iter = policy.max_tool_iterations + + all_responses: list[Message] = [] + + for _ in range(max_iter): + responses_this_turn = await self._send_prompt_to_target_async( + normalized_conversation=normalized_conversation, + ) + all_responses.extend(responses_this_turn) + + if parser is None: + return all_responses + + last_response = responses_this_turn[-1] + pending_calls = parser.parse(last_response) + + if not pending_calls: + return all_responses + + if policy.behavior is ToolEventBehavior.RAISE: + raise ToolCallNotSupported( + message=( + f"Target produced {len(pending_calls)} tool call(s) but ToolEventPolicy.behavior is RAISE." + ), + partial_conversation=all_responses, + ) + + if policy.behavior is ToolEventBehavior.RETURN_RAW: + return all_responses + + if backend is None: + raise ToolCallNotSupported( + message=(f"Target produced {len(pending_calls)} tool call(s) but no tool_backend is configured."), + partial_conversation=all_responses, + ) + + results = await backend.dispatch_all_sequential_async(pending_calls) + tool_msg = _build_function_call_output_message( + reference_piece=last_response.message_pieces[0], + outputs=results, + ) + all_responses.append(tool_msg) + normalized_conversation = list(normalized_conversation) + [last_response, tool_msg] + + raise ToolCallLoopLimitExceeded( + message=f"Tool loop exceeded max_tool_iterations={max_iter} without a stop response.", + partial_conversation=all_responses, + ) + + return wrapper diff --git a/pyrit/tools/parsers.py b/pyrit/tools/parsers.py new file mode 100644 index 0000000000..4ff7fc4c04 --- /dev/null +++ b/pyrit/tools/parsers.py @@ -0,0 +1,55 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +from __future__ import annotations + +from typing import TYPE_CHECKING, Protocol, runtime_checkable + +if TYPE_CHECKING: + from pyrit.models import Message, MessagePiece + from pyrit.tools.models import ToolCall + + +@runtime_checkable +class ToolCallParser(Protocol): + """ + Protocol for extracting tool calls from a target response message. + + Concrete parsers live next to the target whose response shape they + understand (see :class:`OpenAIChatTarget` and :class:`OpenAIResponseTarget` + after C7/C8). Parsers MUST return an empty list when the model has + issued a stop response — the tool loop uses the empty list as the + signal to exit. + """ + + def parse(self, message: Message) -> list[ToolCall]: + """ + Extract tool calls from a target response message. + + Args: + message (Message): The most recent assistant response. + + Returns: + list[ToolCall]: Tool calls, in declaration order. An empty list + signals that the model produced a stop response. + """ + ... + + +def _extract_function_call_pieces(message: Message) -> list[MessagePiece]: + """ + Return every :class:`MessagePiece` in *message* whose + ``original_value_data_type`` is ``"function_call"``. + + This is the canonical envelope produced by OpenAI-style targets after + the C6 normalization commit. It is exposed here so concrete parsers + can reuse the filter rather than re-implementing it. + + Args: + message (Message): The message to scan. + + Returns: + list[MessagePiece]: Pieces whose ``original_value_data_type`` is + ``"function_call"``, in their declaration order. + """ + return [piece for piece in message.message_pieces if piece.original_value_data_type == "function_call"] diff --git a/tests/unit/tools/__init__.py b/tests/unit/tools/__init__.py new file mode 100644 index 0000000000..9a0454564d --- /dev/null +++ b/tests/unit/tools/__init__.py @@ -0,0 +1,2 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. diff --git a/tests/unit/tools/conftest.py b/tests/unit/tools/conftest.py new file mode 100644 index 0000000000..9ac8ce8aa5 --- /dev/null +++ b/tests/unit/tools/conftest.py @@ -0,0 +1,307 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +Shared fixtures for ``tests/unit/tools``. + +Provides the minimal collaborators the tool-loop tests need to exercise +:func:`pyrit.tools.tool_loop` end-to-end without standing up real targets +or MCP transports: + +* :class:`_FakeToolTarget` — a :class:`PromptTarget` subclass whose + ``_send_prompt_to_target_async`` returns scripted messages from a queue + and whose ``_get_normalized_conversation_async`` skips the memory round + trip so decorator behavior is isolated from normalization. +* :class:`_RecordingToolBackend` — a :class:`ToolBackend` that records + every dispatched call (for order-of-execution assertions) and returns + results from a scripted queue. +* :class:`_CanonicalEnvelopeParser` — a :class:`ToolCallParser` that walks + message pieces and parses the canonical ``function_call`` JSON envelope. + +Helper message builders (``_make_user_message``, +``_make_assistant_text_message``, ``_make_assistant_function_call_message``) +produce the canonical envelope shape used by the OpenAI targets after the +C6 normalization commit. +""" + +from __future__ import annotations + +import json +import uuid +from collections import deque +from typing import Any + +import pytest + +from pyrit.models import Message, MessagePiece +from pyrit.prompt_target.common.prompt_target import PromptTarget +from pyrit.prompt_target.common.target_capabilities import TargetCapabilities +from pyrit.prompt_target.common.target_configuration import TargetConfiguration +from pyrit.tools import ( + CallableToolBackend, + ToolCall, + ToolCallParser, + ToolEventBehavior, + ToolEventPolicy, + tool_loop, +) + + +def _make_user_message(text: str, *, conversation_id: str | None = None) -> Message: + """Build a single-piece user :class:`Message` carrying *text*.""" + return Message( + message_pieces=[ + MessagePiece( + role="user", + original_value=text, + original_value_data_type="text", + conversation_id=conversation_id or str(uuid.uuid4()), + ) + ] + ) + + +def _make_assistant_text_message(text: str, *, conversation_id: str | None = None) -> Message: + """Build a single-piece assistant :class:`Message` carrying plain text.""" + return Message( + message_pieces=[ + MessagePiece( + role="assistant", + original_value=text, + original_value_data_type="text", + conversation_id=conversation_id or str(uuid.uuid4()), + ) + ], + skip_validation=True, + ) + + +def _make_function_call_piece( + *, + call_id: str, + name: str, + arguments: dict[str, Any], + conversation_id: str | None = None, +) -> MessagePiece: + """Build one assistant ``function_call`` piece carrying the canonical envelope.""" + envelope = { + "type": "function_call", + "call_id": call_id, + "name": name, + "arguments": json.dumps(arguments, separators=(",", ":")), + } + return MessagePiece( + role="assistant", + original_value=json.dumps(envelope, separators=(",", ":")), + original_value_data_type="function_call", + conversation_id=conversation_id or str(uuid.uuid4()), + ) + + +def _make_assistant_function_call_message( + *, + calls: list[tuple[str, str, dict[str, Any]]], + conversation_id: str | None = None, +) -> Message: + """ + Build an assistant :class:`Message` carrying one ``function_call`` piece + per ``(call_id, name, args)`` tuple, in declaration order. + """ + conv_id = conversation_id or str(uuid.uuid4()) + pieces = [ + _make_function_call_piece(call_id=cid, name=name, arguments=args, conversation_id=conv_id) + for cid, name, args in calls + ] + return Message(message_pieces=pieces, skip_validation=True) + + +class _CanonicalEnvelopeParser: + """ + Reference :class:`ToolCallParser` that understands the canonical envelope + (``original_value_data_type == "function_call"`` carrying a JSON object + with ``type``/``call_id``/``name``/``arguments``). + + Per-target parsers shipped in C7/C8 will reuse this shape; this stand-in + keeps decorator tests independent of the real OpenAI parsers. + """ + + def parse(self, message: Message) -> list[ToolCall]: + calls: list[ToolCall] = [] + for piece in message.message_pieces: + if piece.original_value_data_type != "function_call": + continue + envelope = json.loads(piece.original_value) + arguments_str = envelope.get("arguments", "{}") + arguments = json.loads(arguments_str) if isinstance(arguments_str, str) else dict(arguments_str) + calls.append( + ToolCall( + call_id=envelope["call_id"], + name=envelope["name"], + arguments=arguments, + raw_envelope=envelope, + ) + ) + return calls + + +class _RecordingToolBackend: + """ + Minimal :class:`ToolBackend` that records every dispatched call and + returns results from a scripted queue. Used to assert dispatch order, + iteration count, and per-call payload shape without invoking real tools. + """ + + def __init__( + self, + *, + scripted_results: list[Any] | None = None, + schemas: list[dict[str, Any]] | None = None, + ) -> None: + self._results: deque[Any] = deque(scripted_results or []) + self._schemas: list[dict[str, Any]] = list(schemas) if schemas is not None else [] + self.recorded_calls: list[ToolCall] = [] + + @property + def schemas(self) -> list[dict[str, Any]]: + return list(self._schemas) + + async def dispatch_async(self, call: ToolCall) -> dict[str, Any]: + self.recorded_calls.append(call) + if not self._results: + return {"result": f"recorded:{call.name}:{call.call_id}"} + nxt = self._results.popleft() + return nxt if isinstance(nxt, dict) else {"result": nxt} + + async def dispatch_all_sequential_async( + self, + calls: list[ToolCall], + ) -> list[tuple[ToolCall, dict[str, Any]]]: + results: list[tuple[ToolCall, dict[str, Any]]] = [] + for call in calls: + result = await self.dispatch_async(call) + results.append((call, result)) + return results + + +class _FakeToolTarget(PromptTarget): + """ + Test-only :class:`PromptTarget` whose ``_send_prompt_to_target_async`` + pops scripted responses off a queue. ``_get_normalized_conversation_async`` + is overridden to return ``[message]`` directly, isolating decorator + behavior from the memory + normalization pipeline. + """ + + _DEFAULT_CONFIGURATION = TargetConfiguration( + capabilities=TargetCapabilities( + supports_multi_turn=True, + supports_multi_message_pieces=True, + ) + ) + + def __init__( + self, + *, + scripted_responses: list[Message], + policy: ToolEventPolicy | None = None, + backend: Any = None, + parser: ToolCallParser | None = None, + ) -> None: + super().__init__() + self._scripted_responses: deque[Message] = deque(scripted_responses) + self.call_count: int = 0 + self.normalized_conversations_seen: list[list[Message]] = [] + # The C2 decorator reads these via getattr; production code wires them + # through TargetConfiguration in C4. + self._configuration.tool_event_policy = policy + self._configuration.tool_backend = backend + self._tool_parser = parser if parser is not None else _CanonicalEnvelopeParser() + + async def _get_normalized_conversation_async(self, *, message: Message) -> list[Message]: + return [message] + + def _validate_request(self, *, normalized_conversation: list[Message]) -> None: + return + + async def _send_prompt_to_target_async(self, *, normalized_conversation: list[Message]) -> list[Message]: + self.call_count += 1 + self.normalized_conversations_seen.append(list(normalized_conversation)) + if not self._scripted_responses: + raise AssertionError(f"Fake target ran out of scripted responses on iteration {self.call_count}.") + return [self._scripted_responses.popleft()] + + @tool_loop + async def send_prompt_async(self, *, message: Message) -> list[Message]: + # Passthrough path: only invoked when ToolEventPolicy is None. The + # decorator replaces this body entirely when a policy is set. + message.validate() + normalized = await self._get_normalized_conversation_async(message=message) + return await self._send_prompt_to_target_async(normalized_conversation=normalized) + + +@pytest.fixture +def make_fake_target(patch_central_database): + """ + Factory fixture for :class:`_FakeToolTarget`. Each invocation returns a + fresh target instance whose scripted response queue is independent of + other targets created during the test. + """ + + def _factory( + *, + scripted_responses: list[Message], + policy: ToolEventPolicy | None = None, + backend: Any = None, + parser: ToolCallParser | None = None, + ) -> _FakeToolTarget: + return _FakeToolTarget( + scripted_responses=scripted_responses, + policy=policy, + backend=backend, + parser=parser, + ) + + return _factory + + +@pytest.fixture +def recording_backend(): + """Factory fixture for :class:`_RecordingToolBackend`.""" + + def _factory(*, scripted_results: list[Any] | None = None) -> _RecordingToolBackend: + return _RecordingToolBackend(scripted_results=scripted_results) + + return _factory + + +@pytest.fixture +def execute_policy(): + """ + Factory fixture for :class:`ToolEventPolicy` with + ``behavior=ToolEventBehavior.EXECUTE`` and a tunable iteration cap. + """ + + def _factory(*, max_tool_iterations: int = 5) -> ToolEventPolicy: + return ToolEventPolicy( + behavior=ToolEventBehavior.EXECUTE, + max_tool_iterations=max_tool_iterations, + ) + + return _factory + + +__all__ = [ + "CallableToolBackend", + "ToolCall", + "ToolEventBehavior", + "ToolEventPolicy", + "_CanonicalEnvelopeParser", + "_FakeToolTarget", + "_RecordingToolBackend", + "_make_assistant_function_call_message", + "_make_assistant_text_message", + "_make_function_call_piece", + "_make_user_message", + "execute_policy", + "make_fake_target", + "recording_backend", +] diff --git a/tests/unit/tools/echo_mcp_server.py b/tests/unit/tools/echo_mcp_server.py new file mode 100644 index 0000000000..723a3c6594 --- /dev/null +++ b/tests/unit/tools/echo_mcp_server.py @@ -0,0 +1,57 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +Deterministic echo MCP server used as a stdio subprocess fixture by +``tests/unit/tools/test_mcp_client.py`` (C3) and the integration tests +(C9). + +Lands in C2 so subsequent commits don't shuffle test plumbing; C2's own +tests do not import this module (the :class:`CallableToolRegistry` is +exercised in-process). + +Run directly as ``python echo_mcp_server.py`` to expose the four tools +over stdio. The MCP client harness in C3 launches this file with +``mcp.client.stdio.stdio_client`` and asserts behavior end to end. +""" + +from __future__ import annotations + +import asyncio + +from mcp.server.fastmcp import FastMCP + +mcp = FastMCP("pyrit-echo") + + +@mcp.tool() +def echo(text: str) -> str: + """Return *text* unchanged.""" + return text + + +@mcp.tool() +def add(a: int, b: int) -> int: + """Return ``a + b``.""" + return a + b + + +@mcp.tool() +def reverse(text: str) -> str: + """Return *text* reversed.""" + return text[::-1] + + +@mcp.tool() +async def slow_echo(text: str, delay_ms: int = 0) -> str: + """ + Return *text* after sleeping ``delay_ms`` milliseconds. Used by + timeout / cancellation tests. + """ + if delay_ms > 0: + await asyncio.sleep(delay_ms / 1000.0) + return text + + +if __name__ == "__main__": + mcp.run() diff --git a/tests/unit/tools/test_callable_tool_backend.py b/tests/unit/tools/test_callable_tool_backend.py new file mode 100644 index 0000000000..fb09788ebc --- /dev/null +++ b/tests/unit/tools/test_callable_tool_backend.py @@ -0,0 +1,180 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +Unit tests for :class:`pyrit.tools.CallableToolBackend`. + +Coverage map (rows from the C2 test matrix): + +* **U10** (partial; the MCP counterpart lands in C3) — + ``test_each_dummy_tool_invoked_via_prepended_conversation`` +* **U17** (partial; the MCP-timeout counterpart lands in C3) — + ``test_failing_tool_yields_error_envelope`` +* **U18** — ``test_disallowed_tool_returns_error_without_invoking_callable`` + +Also covers the backend's documented behavior for missing functions +(both strict and tolerant modes), schema property defaulting, scalar +result wrapping, and declaration-order preservation in the bulk dispatch +path. These are required for the §10 rubber-duck guarantee that every +public-facing branch of :class:`CallableToolBackend` is exercised +before C5 wires it to a production target. +""" + +from __future__ import annotations + +import pytest + +from pyrit.tools import CallableToolBackend, ToolCall + + +def _make_call(name: str, *, call_id: str = "c1", arguments: dict | None = None) -> ToolCall: + return ToolCall(call_id=call_id, name=name, arguments=arguments or {}) + + +async def test_disallowed_tool_returns_error_without_invoking_callable(): + invoked: list[str] = [] + + async def echo(args: dict) -> dict: + invoked.append(args.get("text", "")) + return {"echoed": args.get("text", "")} + + backend = CallableToolBackend( + callables={"echo": echo, "off_limits": echo}, + allowed_tools={"echo"}, + ) + + result = await backend.dispatch_async(_make_call("off_limits", arguments={"text": "nope"})) + + assert result["error"] == "tool_not_allowed" + assert result["tool"] == "off_limits" + assert "echo" in result["allowed_tools"] + assert invoked == [] # callable was never invoked + + +async def test_failing_tool_yields_error_envelope(): + async def boom(args: dict) -> dict: + raise RuntimeError("kaboom") + + backend = CallableToolBackend(callables={"boom": boom}) + + result = await backend.dispatch_async(_make_call("boom")) + + assert result["error"] == "tool_execution_failed" + assert result["tool"] == "boom" + assert "kaboom" in result["detail"] + + +async def test_missing_tool_raises_when_strict(): + backend = CallableToolBackend(callables={}, fail_on_missing_function=True) + + with pytest.raises(KeyError, match="ghost"): + await backend.dispatch_async(_make_call("ghost")) + + +async def test_missing_tool_returns_envelope_when_tolerant(): + async def echo(args: dict) -> dict: + return {"ok": True} + + backend = CallableToolBackend( + callables={"echo": echo}, + fail_on_missing_function=False, + ) + + result = await backend.dispatch_async(_make_call("ghost")) + + assert result["error"] == "tool_not_registered" + assert result["tool"] == "ghost" + assert result["available_tools"] == ["echo"] + + +async def test_scalar_result_is_wrapped_in_dict(): + async def number(args: dict) -> int: + return 42 + + backend = CallableToolBackend(callables={"number": number}) + + result = await backend.dispatch_async(_make_call("number")) + + assert result == {"result": 42} + + +async def test_dict_result_passes_through_unchanged(): + async def named(args: dict) -> dict: + return {"custom_key": "custom_value"} + + backend = CallableToolBackend(callables={"named": named}) + + result = await backend.dispatch_async(_make_call("named")) + + assert result == {"custom_key": "custom_value"} + + +async def test_schemas_defaults_to_empty_list(): + backend = CallableToolBackend(callables={}) + + assert backend.schemas == [] + + +async def test_schemas_returned_as_copy(): + schemas_in = [{"name": "echo", "parameters": {}}] + backend = CallableToolBackend(callables={}, schemas=schemas_in) + + out1 = backend.schemas + out1.append({"name": "mutated"}) + + # Mutating the returned list does not affect the backend's internal state. + assert backend.schemas == schemas_in + + +async def test_dispatch_all_sequential_preserves_declaration_order(): + async def echo(args: dict) -> dict: + return {"echoed": args["i"]} + + backend = CallableToolBackend(callables={"echo": echo}) + + calls = [_make_call("echo", call_id=f"c{i}", arguments={"i": i}) for i in range(5)] + pairs = await backend.dispatch_all_sequential_async(calls) + + assert [c.call_id for c, _ in pairs] == ["c0", "c1", "c2", "c3", "c4"] + assert [r["echoed"] for _, r in pairs] == [0, 1, 2, 3, 4] + + +async def test_each_dummy_tool_invoked_via_prepended_conversation(): + """ + U10 (partial). Each dummy tool resolves on first dispatch (single + forward step, no model reasoning trace), confirming the backend can + short-circuit a prepended conversation where every call is already + decided. The MCP counterpart in C3 exercises the same shape against + a real stdio server. + """ + invocations: list[tuple[str, dict]] = [] + + async def echo(args: dict) -> dict: + invocations.append(("echo", args)) + return {"echoed": args.get("text", "")} + + async def add(args: dict) -> dict: + invocations.append(("add", args)) + return {"sum": args["a"] + args["b"]} + + async def reverse(args: dict) -> dict: + invocations.append(("reverse", args)) + return {"reversed": args.get("text", "")[::-1]} + + backend = CallableToolBackend(callables={"echo": echo, "add": add, "reverse": reverse}) + + prepended_calls = [ + _make_call("echo", call_id="e1", arguments={"text": "hello"}), + _make_call("add", call_id="a1", arguments={"a": 2, "b": 3}), + _make_call("reverse", call_id="r1", arguments={"text": "pyrit"}), + ] + pairs = await backend.dispatch_all_sequential_async(prepended_calls) + + # Each dummy resolved exactly once; no retries, no model re-entry. + assert len(invocations) == 3 + assert [name for name, _ in invocations] == ["echo", "add", "reverse"] + assert [r for _, r in pairs] == [ + {"echoed": "hello"}, + {"sum": 5}, + {"reversed": "tiryp"}, + ] diff --git a/tests/unit/tools/test_tool_loop_decorator.py b/tests/unit/tools/test_tool_loop_decorator.py new file mode 100644 index 0000000000..bc0db6b357 --- /dev/null +++ b/tests/unit/tools/test_tool_loop_decorator.py @@ -0,0 +1,289 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +Unit tests for :func:`pyrit.tools.tool_loop`. + +Coverage map (rows from the C2 test matrix): + +* **U2** (partial; full-DB end lands in C5) — ``test_loop_returns_full_chain_in_order`` +* **U3** — ``test_loop_exits_on_first_response_when_no_tool_calls``, + ``test_loops_until_no_pending_tool_call`` +* **U4** — ``test_raises_after_max_tool_iterations``, + ``test_partial_conversation_attached_to_limit_exception`` +* **U12** — ``test_policy_raise_includes_partial_conversation`` +* **U13** — ``test_policy_return_raw_does_not_dispatch`` +* **U16** — ``test_multi_call_per_turn_dispatched_sequentially_in_order`` + +Also covers two additional decorator concerns required by the rubber-duck +review (§10): EXECUTE policy with no backend raises with a partial +conversation, and the normalized conversation grows correctly across +iterations (decorator does not re-normalize each turn). +""" + +from __future__ import annotations + +import json + +import pytest + +from pyrit.exceptions import ToolCallLoopLimitExceeded, ToolCallNotSupported +from pyrit.tools import ToolEventBehavior, ToolEventPolicy + +from .conftest import ( + _make_assistant_function_call_message, + _make_assistant_text_message, + _make_user_message, +) + + +@pytest.mark.usefixtures("patch_central_database") +class TestToolLoopDecoratorBasics: + """Loop entry/exit semantics: no tool calls, single round trip, multi-round.""" + + async def test_loop_exits_on_first_response_when_no_tool_calls(self, make_fake_target, execute_policy): + target = make_fake_target( + scripted_responses=[_make_assistant_text_message("done")], + policy=execute_policy(), + ) + + responses = await target.send_prompt_async(message=_make_user_message("hi")) + + assert len(responses) == 1 + assert responses[0].get_value() == "done" + assert target.call_count == 1 + + async def test_loops_until_no_pending_tool_call(self, make_fake_target, execute_policy, recording_backend): + backend = recording_backend(scripted_results=[{"ok": True}, {"ok": True}]) + target = make_fake_target( + scripted_responses=[ + _make_assistant_function_call_message(calls=[("c1", "tool_a", {"x": 1})]), + _make_assistant_function_call_message(calls=[("c2", "tool_a", {"x": 2})]), + _make_assistant_text_message("done"), + ], + policy=execute_policy(max_tool_iterations=5), + backend=backend, + ) + + responses = await target.send_prompt_async(message=_make_user_message("hi")) + + # Two model-tool round trips and one final assistant message. + assert target.call_count == 3 + # Returned chain: fc1, tool1, fc2, tool2, final-text → 5 messages total. + assert len(responses) == 5 + assert [r.message_pieces[0].original_value_data_type for r in responses] == [ + "function_call", + "function_call_output", + "function_call", + "function_call_output", + "text", + ] + assert len(backend.recorded_calls) == 2 + assert [c.call_id for c in backend.recorded_calls] == ["c1", "c2"] + + +@pytest.mark.usefixtures("patch_central_database") +class TestToolLoopMessageShape: + """U2 — assistant_fc → tool → final_assistant ordering and identity.""" + + async def test_loop_returns_full_chain_in_order(self, make_fake_target, execute_policy, recording_backend): + backend = recording_backend(scripted_results=[{"weather": "sunny"}]) + fc_msg = _make_assistant_function_call_message(calls=[("call_abc", "get_weather", {"city": "Seattle"})]) + final_msg = _make_assistant_text_message("It is sunny in Seattle.") + + target = make_fake_target( + scripted_responses=[fc_msg, final_msg], + policy=execute_policy(), + backend=backend, + ) + + responses = await target.send_prompt_async(message=_make_user_message("weather?")) + + assert len(responses) == 3 + # 1) assistant with function_call (identity preserved) + assert responses[0] is fc_msg + # 2) tool message with exactly one function_call_output piece carrying call_id + tool_msg = responses[1] + assert len(tool_msg.message_pieces) == 1 + tool_piece = tool_msg.message_pieces[0] + assert tool_piece.api_role == "tool" + assert tool_piece.original_value_data_type == "function_call_output" + envelope = json.loads(tool_piece.original_value) + assert envelope["type"] == "function_call_output" + assert envelope["call_id"] == "call_abc" + # The tool result is JSON-serialized into the "output" field. + assert json.loads(envelope["output"]) == {"weather": "sunny"} + # 3) final assistant text (identity preserved) + assert responses[2] is final_msg + + +@pytest.mark.usefixtures("patch_central_database") +class TestToolLoopIterationLimits: + """U4 — iteration cap raises and carries the partial chain.""" + + async def test_raises_after_max_tool_iterations(self, make_fake_target, execute_policy, recording_backend): + # Model never stops asking for tools. + backend = recording_backend(scripted_results=[{"ok": True}] * 3) + target = make_fake_target( + scripted_responses=[ + _make_assistant_function_call_message(calls=[(f"c{i}", "loop_tool", {})]) for i in range(3) + ], + policy=execute_policy(max_tool_iterations=2), + backend=backend, + ) + + with pytest.raises(ToolCallLoopLimitExceeded, match="max_tool_iterations=2"): + await target.send_prompt_async(message=_make_user_message("hi")) + + # Exactly max_tool_iterations model calls made before raising. + assert target.call_count == 2 + + async def test_partial_conversation_attached_to_limit_exception( + self, make_fake_target, execute_policy, recording_backend + ): + backend = recording_backend(scripted_results=[{"ok": True}] * 2) + target = make_fake_target( + scripted_responses=[ + _make_assistant_function_call_message(calls=[(f"c{i}", "loop_tool", {})]) for i in range(2) + ], + policy=execute_policy(max_tool_iterations=2), + backend=backend, + ) + + with pytest.raises(ToolCallLoopLimitExceeded) as excinfo: + await target.send_prompt_async(message=_make_user_message("hi")) + + partial = excinfo.value.partial_conversation + # 2 iterations × (assistant_fc + tool_msg) = 4 messages, all in order. + assert len(partial) == 4 + assert [m.message_pieces[0].original_value_data_type for m in partial] == [ + "function_call", + "function_call_output", + "function_call", + "function_call_output", + ] + + +@pytest.mark.usefixtures("patch_central_database") +class TestToolEventPolicyBehaviors: + """U12, U13 — non-EXECUTE behaviors short-circuit dispatch.""" + + async def test_policy_raise_includes_partial_conversation(self, make_fake_target, recording_backend): + backend = recording_backend(scripted_results=[{"ok": True}]) + fc_msg = _make_assistant_function_call_message(calls=[("c1", "danger", {})]) + target = make_fake_target( + scripted_responses=[fc_msg], + policy=ToolEventPolicy(behavior=ToolEventBehavior.RAISE), + backend=backend, + ) + + with pytest.raises(ToolCallNotSupported, match="RAISE") as excinfo: + await target.send_prompt_async(message=_make_user_message("hi")) + + partial = excinfo.value.partial_conversation + # Partial contains the offending assistant turn; no tool dispatch occurred. + assert partial == [fc_msg] + assert backend.recorded_calls == [] + assert target.call_count == 1 + + async def test_policy_return_raw_does_not_dispatch(self, make_fake_target, recording_backend): + backend = recording_backend(scripted_results=[{"ok": True}]) + fc_msg = _make_assistant_function_call_message(calls=[("c1", "danger", {})]) + target = make_fake_target( + scripted_responses=[fc_msg], + policy=ToolEventPolicy(behavior=ToolEventBehavior.RETURN_RAW), + backend=backend, + ) + + responses = await target.send_prompt_async(message=_make_user_message("hi")) + + assert responses == [fc_msg] + assert backend.recorded_calls == [] + assert target.call_count == 1 + + +@pytest.mark.usefixtures("patch_central_database") +class TestToolLoopMultiCallPerTurn: + """U16 — multi-call turns dispatch sequentially in declaration order.""" + + async def test_multi_call_per_turn_dispatched_sequentially_in_order( + self, make_fake_target, execute_policy, recording_backend + ): + backend = recording_backend(scripted_results=[{"a": 1}, {"b": 2}, {"c": 3}]) + multi_fc = _make_assistant_function_call_message( + calls=[ + ("c_alpha", "tool_alpha", {"k": "v1"}), + ("c_beta", "tool_beta", {"k": "v2"}), + ("c_gamma", "tool_gamma", {"k": "v3"}), + ] + ) + target = make_fake_target( + scripted_responses=[multi_fc, _make_assistant_text_message("ok")], + policy=execute_policy(), + backend=backend, + ) + + responses = await target.send_prompt_async(message=_make_user_message("multi")) + + # Three calls dispatched in declaration order, recorded ids match. + assert [c.call_id for c in backend.recorded_calls] == ["c_alpha", "c_beta", "c_gamma"] + assert [c.name for c in backend.recorded_calls] == ["tool_alpha", "tool_beta", "tool_gamma"] + # One tool message after the multi-call assistant turn, carrying three + # function_call_output pieces in declaration order with the right call_ids. + tool_msg = responses[1] + assert len(tool_msg.message_pieces) == 3 + envelopes = [json.loads(p.original_value) for p in tool_msg.message_pieces] + assert [e["call_id"] for e in envelopes] == ["c_alpha", "c_beta", "c_gamma"] + assert all(p.original_value_data_type == "function_call_output" for p in tool_msg.message_pieces) + + +@pytest.mark.usefixtures("patch_central_database") +class TestToolLoopMisconfiguration: + """EXECUTE policy with no backend must fail loudly and carry the partial chain.""" + + async def test_execute_without_backend_raises_with_partial(self, make_fake_target, execute_policy): + fc_msg = _make_assistant_function_call_message(calls=[("c1", "no_reg", {})]) + target = make_fake_target( + scripted_responses=[fc_msg], + policy=execute_policy(), + backend=None, + ) + + with pytest.raises(ToolCallNotSupported, match="tool_backend") as excinfo: + await target.send_prompt_async(message=_make_user_message("hi")) + + assert excinfo.value.partial_conversation == [fc_msg] + + +@pytest.mark.usefixtures("patch_central_database") +class TestToolLoopConversationGrowth: + """The decorator must extend (not re-normalize) the conversation each round.""" + + async def test_normalized_conversation_grows_each_iteration( + self, make_fake_target, execute_policy, recording_backend + ): + backend = recording_backend(scripted_results=[{"r1": 1}, {"r2": 2}]) + target = make_fake_target( + scripted_responses=[ + _make_assistant_function_call_message(calls=[("c1", "t", {})]), + _make_assistant_function_call_message(calls=[("c2", "t", {})]), + _make_assistant_text_message("done"), + ], + policy=execute_policy(), + backend=backend, + ) + + await target.send_prompt_async(message=_make_user_message("hi")) + + # Three protected-method calls; each subsequent call sees the prior + # assistant_fc + tool_msg appended (the decorator must NOT re-normalize). + seen = target.normalized_conversations_seen + assert len(seen) == 3 + # call 1: just the user message + assert len(seen[0]) == 1 + # call 2: user + assistant_fc(c1) + tool_msg + assert len(seen[1]) == 3 + assert seen[1][1].message_pieces[0].original_value_data_type == "function_call" + assert seen[1][2].message_pieces[0].original_value_data_type == "function_call_output" + # call 3: user + assistant_fc(c1) + tool_msg + assistant_fc(c2) + tool_msg + assert len(seen[2]) == 5 From 39752ac88bab2e6c1ad27ab1b660df8b6771c78b Mon Sep 17 00:00:00 2001 From: Victor Valbuena Date: Tue, 26 May 2026 15:24:49 -0700 Subject: [PATCH 03/17] Addition of pyrit/tools package and unit tests. Introduces base models for tool calling. --- pyrit/tools/__init__.py | 6 ++--- pyrit/tools/backend.py | 6 ++--- .../{callable_backend.py => local_backend.py} | 9 +++++-- tests/unit/tools/conftest.py | 4 +-- ..._backend.py => test_local_tool_backend.py} | 26 +++++++++---------- 5 files changed, 28 insertions(+), 23 deletions(-) rename pyrit/tools/{callable_backend.py => local_backend.py} (93%) rename tests/unit/tools/{test_callable_tool_backend.py => test_local_tool_backend.py} (87%) diff --git a/pyrit/tools/__init__.py b/pyrit/tools/__init__.py index 9830f17758..c3520098ea 100644 --- a/pyrit/tools/__init__.py +++ b/pyrit/tools/__init__.py @@ -20,7 +20,7 @@ When the policy is ``EXECUTE``, calls are dispatched through ``self.configuration.tool_backend``, an implementation of -:class:`ToolBackend`. :class:`CallableToolBackend` is the pure-Python +:class:`ToolBackend`. :class:`LocalToolBackend` is the in-process backend shipped here; :class:`MCPToolBackend` ships in C3 and proxies through one or more MCP servers. @@ -43,12 +43,12 @@ """ from pyrit.tools.backend import ToolBackend -from pyrit.tools.callable_backend import CallableToolBackend +from pyrit.tools.local_backend import LocalToolBackend from pyrit.tools.models import ToolCall, ToolEventBehavior, ToolEventPolicy, tool_loop from pyrit.tools.parsers import ToolCallParser __all__ = [ - "CallableToolBackend", + "LocalToolBackend", "ToolBackend", "ToolCall", "ToolCallParser", diff --git a/pyrit/tools/backend.py b/pyrit/tools/backend.py index 3274355d6e..54c1dd7a1d 100644 --- a/pyrit/tools/backend.py +++ b/pyrit/tools/backend.py @@ -22,9 +22,9 @@ class ToolBackend(Protocol): Two concrete implementations ship with PyRIT: - * :class:`~pyrit.tools.CallableToolBackend` — pure-Python backend - backed by ``async def`` callables. Useful for unit tests and for - embedding tools inside the PyRIT process. + * :class:`~pyrit.tools.LocalToolBackend` — in-process backend backed + by ``async def`` callables. Useful for unit tests and for embedding + tools inside the PyRIT process. * :class:`pyrit.tools.MCPToolBackend` (lands in C3) — proxies dispatch through one or more MCP servers. diff --git a/pyrit/tools/callable_backend.py b/pyrit/tools/local_backend.py similarity index 93% rename from pyrit/tools/callable_backend.py rename to pyrit/tools/local_backend.py index defbe94ed1..0c67054590 100644 --- a/pyrit/tools/callable_backend.py +++ b/pyrit/tools/local_backend.py @@ -14,12 +14,17 @@ logger = logging.getLogger(__name__) -class CallableToolBackend: +class LocalToolBackend: """ - Pure-Python :class:`~pyrit.tools.ToolBackend` backed by a name -> ``async def`` + In-process :class:`~pyrit.tools.ToolBackend` backed by a name -> ``async def`` mapping. Useful for unit tests and for embedding small tools inside the PyRIT process without standing up an MCP server. + "Local" here means tools run in PyRIT's own Python process — no + subprocess, no IPC, no wire protocol. Contrast with + :class:`~pyrit.tools.MCPToolBackend` (lands in C3), which proxies + dispatch through one or more MCP servers reached via JSON-RPC. + The backend dispatches sequentially in declaration order. Tool-side failures (raised exceptions, missing names, allow-list rejections) are converted into structured error envelopes so the tool loop can diff --git a/tests/unit/tools/conftest.py b/tests/unit/tools/conftest.py index 9ac8ce8aa5..8419c5b06c 100644 --- a/tests/unit/tools/conftest.py +++ b/tests/unit/tools/conftest.py @@ -38,7 +38,7 @@ from pyrit.prompt_target.common.target_capabilities import TargetCapabilities from pyrit.prompt_target.common.target_configuration import TargetConfiguration from pyrit.tools import ( - CallableToolBackend, + LocalToolBackend, ToolCall, ToolCallParser, ToolEventBehavior, @@ -290,7 +290,7 @@ def _factory(*, max_tool_iterations: int = 5) -> ToolEventPolicy: __all__ = [ - "CallableToolBackend", + "LocalToolBackend", "ToolCall", "ToolEventBehavior", "ToolEventPolicy", diff --git a/tests/unit/tools/test_callable_tool_backend.py b/tests/unit/tools/test_local_tool_backend.py similarity index 87% rename from tests/unit/tools/test_callable_tool_backend.py rename to tests/unit/tools/test_local_tool_backend.py index fb09788ebc..8e24693140 100644 --- a/tests/unit/tools/test_callable_tool_backend.py +++ b/tests/unit/tools/test_local_tool_backend.py @@ -2,7 +2,7 @@ # Licensed under the MIT license. """ -Unit tests for :class:`pyrit.tools.CallableToolBackend`. +Unit tests for :class:`pyrit.tools.LocalToolBackend`. Coverage map (rows from the C2 test matrix): @@ -16,7 +16,7 @@ (both strict and tolerant modes), schema property defaulting, scalar result wrapping, and declaration-order preservation in the bulk dispatch path. These are required for the §10 rubber-duck guarantee that every -public-facing branch of :class:`CallableToolBackend` is exercised +public-facing branch of :class:`LocalToolBackend` is exercised before C5 wires it to a production target. """ @@ -24,7 +24,7 @@ import pytest -from pyrit.tools import CallableToolBackend, ToolCall +from pyrit.tools import LocalToolBackend, ToolCall def _make_call(name: str, *, call_id: str = "c1", arguments: dict | None = None) -> ToolCall: @@ -38,7 +38,7 @@ async def echo(args: dict) -> dict: invoked.append(args.get("text", "")) return {"echoed": args.get("text", "")} - backend = CallableToolBackend( + backend = LocalToolBackend( callables={"echo": echo, "off_limits": echo}, allowed_tools={"echo"}, ) @@ -55,7 +55,7 @@ async def test_failing_tool_yields_error_envelope(): async def boom(args: dict) -> dict: raise RuntimeError("kaboom") - backend = CallableToolBackend(callables={"boom": boom}) + backend = LocalToolBackend(callables={"boom": boom}) result = await backend.dispatch_async(_make_call("boom")) @@ -65,7 +65,7 @@ async def boom(args: dict) -> dict: async def test_missing_tool_raises_when_strict(): - backend = CallableToolBackend(callables={}, fail_on_missing_function=True) + backend = LocalToolBackend(callables={}, fail_on_missing_function=True) with pytest.raises(KeyError, match="ghost"): await backend.dispatch_async(_make_call("ghost")) @@ -75,7 +75,7 @@ async def test_missing_tool_returns_envelope_when_tolerant(): async def echo(args: dict) -> dict: return {"ok": True} - backend = CallableToolBackend( + backend = LocalToolBackend( callables={"echo": echo}, fail_on_missing_function=False, ) @@ -91,7 +91,7 @@ async def test_scalar_result_is_wrapped_in_dict(): async def number(args: dict) -> int: return 42 - backend = CallableToolBackend(callables={"number": number}) + backend = LocalToolBackend(callables={"number": number}) result = await backend.dispatch_async(_make_call("number")) @@ -102,7 +102,7 @@ async def test_dict_result_passes_through_unchanged(): async def named(args: dict) -> dict: return {"custom_key": "custom_value"} - backend = CallableToolBackend(callables={"named": named}) + backend = LocalToolBackend(callables={"named": named}) result = await backend.dispatch_async(_make_call("named")) @@ -110,14 +110,14 @@ async def named(args: dict) -> dict: async def test_schemas_defaults_to_empty_list(): - backend = CallableToolBackend(callables={}) + backend = LocalToolBackend(callables={}) assert backend.schemas == [] async def test_schemas_returned_as_copy(): schemas_in = [{"name": "echo", "parameters": {}}] - backend = CallableToolBackend(callables={}, schemas=schemas_in) + backend = LocalToolBackend(callables={}, schemas=schemas_in) out1 = backend.schemas out1.append({"name": "mutated"}) @@ -130,7 +130,7 @@ async def test_dispatch_all_sequential_preserves_declaration_order(): async def echo(args: dict) -> dict: return {"echoed": args["i"]} - backend = CallableToolBackend(callables={"echo": echo}) + backend = LocalToolBackend(callables={"echo": echo}) calls = [_make_call("echo", call_id=f"c{i}", arguments={"i": i}) for i in range(5)] pairs = await backend.dispatch_all_sequential_async(calls) @@ -161,7 +161,7 @@ async def reverse(args: dict) -> dict: invocations.append(("reverse", args)) return {"reversed": args.get("text", "")[::-1]} - backend = CallableToolBackend(callables={"echo": echo, "add": add, "reverse": reverse}) + backend = LocalToolBackend(callables={"echo": echo, "add": add, "reverse": reverse}) prepended_calls = [ _make_call("echo", call_id="e1", arguments={"text": "hello"}), From e8ab8ffd055e742454836ba6416e64bd9ac7044d Mon Sep 17 00:00:00 2001 From: Victor Valbuena Date: Tue, 26 May 2026 15:56:40 -0700 Subject: [PATCH 04/17] Addition of MCP components including the MCP client and tool backend. --- pyrit/tools/__init__.py | 14 + pyrit/tools/backend.py | 40 +-- pyrit/tools/local_backend.py | 24 +- pyrit/tools/mcp_backend.py | 199 +++++++++++++++ pyrit/tools/mcp_client.py | 369 +++++++++++++++++++++++++++ tests/unit/tools/conftest.py | 13 +- tests/unit/tools/test_mcp_backend.py | 156 +++++++++++ tests/unit/tools/test_mcp_client.py | 171 +++++++++++++ 8 files changed, 936 insertions(+), 50 deletions(-) create mode 100644 pyrit/tools/mcp_backend.py create mode 100644 pyrit/tools/mcp_client.py create mode 100644 tests/unit/tools/test_mcp_backend.py create mode 100644 tests/unit/tools/test_mcp_client.py diff --git a/pyrit/tools/__init__.py b/pyrit/tools/__init__.py index c3520098ea..46b11aa358 100644 --- a/pyrit/tools/__init__.py +++ b/pyrit/tools/__init__.py @@ -44,11 +44,25 @@ from pyrit.tools.backend import ToolBackend from pyrit.tools.local_backend import LocalToolBackend +from pyrit.tools.mcp_backend import MCPToolBackend +from pyrit.tools.mcp_client import ( + DockerMCPServerSpec, + LocalMCPServerSpec, + MCPClient, + MCPServerSpec, + RemoteMCPServerSpec, +) from pyrit.tools.models import ToolCall, ToolEventBehavior, ToolEventPolicy, tool_loop from pyrit.tools.parsers import ToolCallParser __all__ = [ + "DockerMCPServerSpec", + "LocalMCPServerSpec", "LocalToolBackend", + "MCPClient", + "MCPServerSpec", + "MCPToolBackend", + "RemoteMCPServerSpec", "ToolBackend", "ToolCall", "ToolCallParser", diff --git a/pyrit/tools/backend.py b/pyrit/tools/backend.py index 54c1dd7a1d..e7a02a7685 100644 --- a/pyrit/tools/backend.py +++ b/pyrit/tools/backend.py @@ -3,16 +3,16 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Protocol, runtime_checkable +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, Any if TYPE_CHECKING: from pyrit.tools.models import ToolCall -@runtime_checkable -class ToolBackend(Protocol): +class ToolBackend(ABC): """ - Protocol for backends that dispatch tool calls produced by a target. + Abstract base for backends that dispatch tool calls produced by a target. A :class:`ToolBackend` is a per-target dispatch table — it owns the ``name -> async callable`` mapping a target uses to execute the tool @@ -25,20 +25,18 @@ class ToolBackend(Protocol): * :class:`~pyrit.tools.LocalToolBackend` — in-process backend backed by ``async def`` callables. Useful for unit tests and for embedding tools inside the PyRIT process. - * :class:`pyrit.tools.MCPToolBackend` (lands in C3) — proxies - dispatch through one or more MCP servers. - - The :attr:`schemas` property exposes the JSON-schema descriptors the - target injects into its request body (e.g. ``tools=[...]`` for the - OpenAI APIs). - - :meth:`dispatch_all_sequential_async` is the contract the tool loop - uses: backends that wish to parallelize dispatch should override it. - The default sequencing — one ``await dispatch_async`` per call, in - declaration order — is what every PyRIT backend ships with today. + * :class:`~pyrit.tools.MCPToolBackend` — proxies dispatch through one + or more MCP servers. + + Subclasses MUST implement :attr:`schemas` and :meth:`dispatch_async`. + :meth:`dispatch_all_sequential_async` ships with a default + implementation that awaits :meth:`dispatch_async` once per call in + declaration order; backends that wish to parallelize dispatch + (e.g. fan out across multiple sandbox containers) should override it. """ @property + @abstractmethod def schemas(self) -> list[dict[str, Any]]: """ The JSON-schema descriptors for every tool the backend exposes. @@ -48,8 +46,8 @@ def schemas(self) -> list[dict[str, Any]]: format that concrete targets serialize into their request body. """ - ... + @abstractmethod async def dispatch_async(self, call: ToolCall) -> dict[str, Any]: """ Execute a single tool call and return the structured result. @@ -64,7 +62,6 @@ async def dispatch_async(self, call: ToolCall) -> dict[str, Any]: Returns: dict[str, Any]: The structured tool result. """ - ... async def dispatch_all_sequential_async( self, @@ -73,6 +70,9 @@ async def dispatch_all_sequential_async( """ Dispatch every call in *calls* sequentially, preserving declaration order. + Default implementation: ``await dispatch_async`` once per call. + Backends that parallelize dispatch should override this method. + Args: calls (list[ToolCall]): The calls to dispatch, in declaration order. @@ -80,4 +80,8 @@ async def dispatch_all_sequential_async( list[tuple[ToolCall, dict[str, Any]]]: ``(call, result)`` pairs, in the same order as *calls*. """ - ... + results: list[tuple[ToolCall, dict[str, Any]]] = [] + for call in calls: + envelope = await self.dispatch_async(call) + results.append((call, envelope)) + return results diff --git a/pyrit/tools/local_backend.py b/pyrit/tools/local_backend.py index 0c67054590..25fe42e83c 100644 --- a/pyrit/tools/local_backend.py +++ b/pyrit/tools/local_backend.py @@ -6,6 +6,8 @@ import logging from typing import TYPE_CHECKING, Any +from pyrit.tools.backend import ToolBackend + if TYPE_CHECKING: from collections.abc import Awaitable, Callable @@ -14,7 +16,7 @@ logger = logging.getLogger(__name__) -class LocalToolBackend: +class LocalToolBackend(ToolBackend): """ In-process :class:`~pyrit.tools.ToolBackend` backed by a name -> ``async def`` mapping. Useful for unit tests and for embedding small tools inside the @@ -117,23 +119,3 @@ async def dispatch_async(self, call: ToolCall) -> dict[str, Any]: "detail": str(ex), } return result if isinstance(result, dict) else {"result": result} - - async def dispatch_all_sequential_async( - self, - calls: list[ToolCall], - ) -> list[tuple[ToolCall, dict[str, Any]]]: - """ - Dispatch *calls* sequentially in declaration order. - - Args: - calls (list[ToolCall]): Calls to dispatch. - - Returns: - list[tuple[ToolCall, dict[str, Any]]]: ``(call, result)`` pairs - in the same order as *calls*. - """ - results: list[tuple[ToolCall, dict[str, Any]]] = [] - for call in calls: - result = await self.dispatch_async(call) - results.append((call, result)) - return results diff --git a/pyrit/tools/mcp_backend.py b/pyrit/tools/mcp_backend.py new file mode 100644 index 0000000000..66da88a30f --- /dev/null +++ b/pyrit/tools/mcp_backend.py @@ -0,0 +1,199 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +Multi-server tool backend that proxies dispatch through one or more +MCP servers. + +This is the :class:`~pyrit.tools.ToolBackend` implementation that real +red-team configurations use. It composes one +:class:`~pyrit.tools.MCPClient` per :class:`~pyrit.tools.MCPServerSpec`, +aggregates their advertised schemas, routes incoming +:class:`~pyrit.tools.ToolCall` instances to the correct underlying +client, and enforces an optional ``allowed_tools`` allow-list. + +Contrast with :class:`~pyrit.tools.LocalToolBackend`, which dispatches +to Python ``async def`` callables inside PyRIT's own process. +""" + +from __future__ import annotations + +import asyncio +import logging +from contextlib import AsyncExitStack +from typing import TYPE_CHECKING, Any + +from pyrit.tools.backend import ToolBackend +from pyrit.tools.mcp_client import MCPClient + +if TYPE_CHECKING: + from collections.abc import Iterable + + from pyrit.tools.mcp_client import MCPServerSpec + from pyrit.tools.models import ToolCall + +logger = logging.getLogger(__name__) + + +class MCPToolBackend(ToolBackend): + """ + :class:`~pyrit.tools.ToolBackend` backed by one or more MCP servers. + + On :meth:`__aenter__`, the backend spawns / connects each server in + its :attr:`_servers` list (sequentially) through a single + :class:`contextlib.AsyncExitStack`, runs the MCP handshake, caches + schemas, and builds an advertised-name → ``(client, server_name)`` + routing table. Collisions raise :class:`ValueError` unless the + colliding specs set :attr:`~pyrit.tools.LocalMCPServerSpec.name_prefix`. + + A single shared :class:`AsyncExitStack` (rather than one per client) + is required so anyio's nested cancel scopes — opened by the ``mcp`` + SDK's ``stdio_client`` and ``ClientSession`` context managers — are + closed in strict LIFO order from the entering task. Closing + out-of-order would trip + ``"Attempted to exit a cancel scope that isn't the current task's + current cancel scope"``. + + Dispatch is serialized through an :class:`asyncio.Lock` per backend + instance — multiple concurrent coroutines sharing the same backend + (e.g. parallel attack runs) will not interleave JSON-RPC frames on + the same stdio pipe. + """ + + def __init__( + self, + *, + servers: Iterable[MCPServerSpec], + allowed_tools: list[str] | None = None, + ) -> None: + """ + Initialize the backend. + + Args: + servers: One or more :class:`MCPServerSpec` instances describing + where each server runs. + allowed_tools: Optional allow-list of tool names. Names not in + the list are filtered from :attr:`schemas` AND + short-circuit dispatch with a ``tool_not_allowed`` envelope. + Names are matched after :attr:`~LocalMCPServerSpec.name_prefix` + has been applied. Defaults to None (every advertised tool is + callable). + + Raises: + ValueError: When *servers* is empty. + """ + self._servers: list[MCPServerSpec] = list(servers) + if not self._servers: + raise ValueError("MCPToolBackend requires at least one server spec.") + self._allowed_tools: set[str] | None = set(allowed_tools) if allowed_tools is not None else None + self._clients: list[MCPClient] = [] + self._routing: dict[str, tuple[MCPClient, str]] = {} + self._dispatch_lock = asyncio.Lock() + self._stack: AsyncExitStack | None = None + self._entered = False + + @property + def schemas(self) -> list[dict[str, Any]]: + """The union of every connected server's schemas, filtered by ``allowed_tools``.""" + out: list[dict[str, Any]] = [] + for client in self._clients: + for schema in client.schemas: + if self._allowed_tools is not None and schema["name"] not in self._allowed_tools: + continue + out.append(schema) + return out + + async def __aenter__(self) -> MCPToolBackend: + """ + Connect each underlying client through a shared :class:`AsyncExitStack` and build the routing table. + + Returns: + MCPToolBackend: *self*, ready to dispatch. + + Raises: + ValueError: When two connected clients advertise the same tool + name without a disambiguating ``name_prefix``. + """ + stack = AsyncExitStack() + clients: list[MCPClient] = [] + routing: dict[str, tuple[MCPClient, str]] = {} + try: + for spec in self._servers: + client = MCPClient(spec=spec) + await stack.enter_async_context(client) + clients.append(client) + for advertised_name in client.tool_names: + if advertised_name in routing: + raise ValueError( + f"duplicate tool name '{advertised_name}'. " + "Set LocalMCPServerSpec.name_prefix on at least one " + "colliding server to disambiguate.", + ) + routing[advertised_name] = (client, advertised_name) + except Exception: + await stack.aclose() + raise + + self._stack = stack + self._clients = clients + self._routing = routing + self._entered = True + return self + + async def __aexit__(self, *exc: Any) -> None: + """Tear down every underlying client in strict LIFO order.""" + stack = self._stack + self._stack = None + self._clients = [] + self._routing = {} + self._entered = False + if stack is not None: + await stack.aclose() + + async def dispatch_async(self, call: ToolCall) -> dict[str, Any]: + """ + Route *call* to the correct client and dispatch. + + See :class:`MCPClient.dispatch_async` for the envelope shape. + Allow-list rejections and unknown-tool calls return error + envelopes; only "backend not entered" raises. + + Args: + call (ToolCall): The call to dispatch. + + Returns: + dict[str, Any]: A structured envelope (success, ``tool_not_allowed``, + ``tool_not_registered``, or the underlying + :meth:`MCPClient.dispatch_async` envelope). + + Raises: + RuntimeError: When the backend has not been entered via ``async with``. + """ + if not self._entered: + raise RuntimeError( + "MCPToolBackend is not active. Use `async with backend:` to manage its lifecycle before dispatching.", + ) + + if self._allowed_tools is not None and call.name not in self._allowed_tools: + logger.info("Rejecting disallowed tool call: %s", call.name) + return { + "is_error": True, + "error": "tool_not_allowed", + "tool": call.name, + "allowed_tools": sorted(self._allowed_tools), + } + + route = self._routing.get(call.name) + if route is None: + available = sorted(self._routing.keys()) + logger.warning("Tool '%s' not registered. Available: %s", call.name, available) + return { + "is_error": True, + "error": "tool_not_registered", + "tool": call.name, + "available_tools": available, + } + + client, _server_side_name = route + async with self._dispatch_lock: + return await client.dispatch_async(call) diff --git a/pyrit/tools/mcp_client.py b/pyrit/tools/mcp_client.py new file mode 100644 index 0000000000..904004f675 --- /dev/null +++ b/pyrit/tools/mcp_client.py @@ -0,0 +1,369 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +Stdio-transport client for the Model Context Protocol (MCP). + +This module is the wire-protocol half of PyRIT's MCP integration. It +sits below :class:`~pyrit.tools.MCPToolBackend` (which composes one +:class:`MCPClient` per configured server and handles cross-server +routing) and above the upstream ``mcp`` Python SDK (which owns the +JSON-RPC framing, capability negotiation, and asyncio task plumbing). + +The three :class:`MCPServerSpec` variants describe *where* the server +runs. Only :class:`LocalMCPServerSpec` is implemented in this commit: + +* :class:`LocalMCPServerSpec` — spawn the server as a child process and + speak JSON-RPC over its stdin/stdout. +* :class:`RemoteMCPServerSpec` — HTTP/SSE transport against a hosted + server. Stub: ``connect_async`` raises ``NotImplementedError``. +* :class:`DockerMCPServerSpec` — stdio over ``docker run -i`` against a + hardened sandbox container. Stub: ``connect_async`` raises + ``NotImplementedError``. Implementation lands in the follow-up + sandbox PR. + +The stub variants are intentionally part of the type union today so +downstream code can be written against the eventual API without +forcing a Union expansion later. +""" + +from __future__ import annotations + +import asyncio +import logging +from contextlib import AsyncExitStack +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any + +from mcp import ClientSession +from mcp.client.stdio import StdioServerParameters, stdio_client + +if TYPE_CHECKING: + from pyrit.tools.models import ToolCall + +logger = logging.getLogger(__name__) + + +@dataclass(frozen=True) +class LocalMCPServerSpec: + """ + Spec for an MCP server spawned as a child process and reached via + stdio JSON-RPC. + + Attributes: + command (str): The interpreter or binary to exec (e.g. ``"python"``). + args (tuple[str, ...]): Arguments passed to *command*, in order. + env (dict[str, str] | None): Environment overlay for the child + process. ``None`` (default) inherits PyRIT's environment. + name_prefix (str | None): When set, every tool advertised by the + server is registered as ``f"{name_prefix}{tool_name}"`` in + the parent :class:`~pyrit.tools.MCPToolBackend`. Used to + disambiguate two servers that expose the same tool name. + timeout_seconds (float): Per-call timeout, enforced by + :meth:`MCPClient.dispatch_async`. Defaults to 30 seconds. + """ + + command: str + args: tuple[str, ...] = () + env: dict[str, str] | None = None + name_prefix: str | None = None + timeout_seconds: float = 30.0 + + +@dataclass(frozen=True) +class RemoteMCPServerSpec: + """ + Spec for an MCP server reached over HTTP / SSE. **Not implemented** + in this PR — :meth:`MCPClient.connect_async` raises + :class:`NotImplementedError`. Tracked by ``# TODO(mcp-http-transport)``. + + Attributes: + url (str): The base URL of the MCP server. + name_prefix (str | None): Same semantics as + :attr:`LocalMCPServerSpec.name_prefix`. + timeout_seconds (float): Per-call timeout. + """ + + url: str + name_prefix: str | None = None + timeout_seconds: float = 30.0 + + +# TODO(sandbox-provider) — DockerMCPServerSpec stub here; implementation lands in follow-up PR. +@dataclass(frozen=True) +class DockerMCPServerSpec: + """ + Spec for an MCP server hosted inside a hardened Docker container. + + **NOT IMPLEMENTED IN THIS PR.** Reached via stdio over ``docker run -i``. + + Expected behavior in the follow-up sandbox PR: + + * One container per spec instance, managed by a process-wide + ``SandboxPool``. + * Image is built lazily, keyed by ``sha256(Dockerfile + build_context)``, + and cached across attacks; no rebuild unless missing or explicitly + overridden. + * Container is recreated from the cached image at attack and scenario + boundaries (filesystem returns to baseline every time). + * Network access governed by ``NetworkProfile`` (default ``"none"`` = + ``--network=none``). + * Container runs as a non-root UID with ``--cap-drop=ALL``, a read-only + root filesystem, and an in-container MCP server exposing + ``run_shell(cmd, timeout_seconds)``. + + Attributes: + image (str): Docker image tag (e.g. ``"pyrit-sandbox:base"``). + network_profile (str): ``NetworkProfile`` name; ``"none"`` (default) + launches the container with ``--network=none``. + name_prefix (str | None): Same semantics as + :attr:`LocalMCPServerSpec.name_prefix`. + timeout_seconds (float): Per-call timeout. + + Future fields (deferred to the follow-up sandbox PR): ``memory_limit``, + ``cpu_limit``, ``pids_limit``, ``env``, ``mounts``, ``command_override``. + """ + + image: str + network_profile: str = "none" + name_prefix: str | None = None + timeout_seconds: float = 30.0 + + +MCPServerSpec = LocalMCPServerSpec | RemoteMCPServerSpec | DockerMCPServerSpec + + +def _to_input_schema_dict(input_schema: Any) -> dict[str, Any]: + """ + Coerce the SDK's tool ``inputSchema`` (pydantic model or dict) into a plain dict. + + Returns: + dict[str, Any]: A plain-dict copy of *input_schema*, or an empty + object schema when *input_schema* is None or of an unrecognized type. + """ + if input_schema is None: + return {"type": "object", "properties": {}} + if hasattr(input_schema, "model_dump"): + return input_schema.model_dump() + if isinstance(input_schema, dict): + return dict(input_schema) + return {"type": "object", "properties": {}} + + +def _flatten_content(content: list[Any]) -> str: + """ + Concatenate the text portions of an MCP ``CallToolResult.content`` list. + + Returns: + str: Concatenated ``.text`` values from each content item, in order. + """ + pieces: list[str] = [] + for item in content: + text = getattr(item, "text", None) + if text is not None: + pieces.append(text) + elif isinstance(item, dict) and "text" in item: + pieces.append(item["text"]) + return "".join(pieces) + + +class MCPClient: + """ + A single MCP-server session. + + The client owns the lifetime of one server's transport stack and + exposes a uniform :meth:`dispatch_async` regardless of which + :class:`MCPServerSpec` variant it was constructed from. Composition + across multiple servers (routing, schema aggregation, allow-lists) + is the responsibility of :class:`~pyrit.tools.MCPToolBackend`. + + Lifecycle: + + * :meth:`connect_async` spawns the subprocess (for + :class:`LocalMCPServerSpec`), runs the MCP handshake, and caches + ``tools/list`` results. + * :meth:`dispatch_async` issues one ``tools/call`` and returns a + structured envelope (success or error). + * :meth:`close_async` tears down the transport stack. + + The class is usable as an async context manager. + """ + + def __init__(self, *, spec: MCPServerSpec) -> None: + """ + Initialize the client around *spec*. Does not connect; call + :meth:`connect_async` (or use the async context-manager form) to start + the transport stack. + """ + self._spec = spec + self._stack = AsyncExitStack() + self._session: ClientSession | None = None + self._tools: list[Any] = [] + + @property + def spec(self) -> MCPServerSpec: + """The :class:`MCPServerSpec` this client was constructed with.""" + return self._spec + + @property + def schemas(self) -> list[dict[str, Any]]: + """ + JSON schemas for every tool the server advertises. + + Each schema is shaped ``{"name", "description", "parameters"}``. + The optional :attr:`LocalMCPServerSpec.name_prefix` is applied + here so a backend that owns this client sees the prefixed name. + """ + prefix = getattr(self._spec, "name_prefix", None) or "" + return [ + { + "name": f"{prefix}{tool.name}", + "description": tool.description or "", + "parameters": _to_input_schema_dict(tool.inputSchema), + } + for tool in self._tools + ] + + @property + def tool_names(self) -> list[str]: + """Tool names with the spec's :attr:`name_prefix` applied.""" + return [s["name"] for s in self.schemas] + + def _strip_prefix(self, name: str) -> str: + prefix = getattr(self._spec, "name_prefix", None) or "" + if prefix and name.startswith(prefix): + return name[len(prefix) :] + return name + + async def connect_async(self) -> None: + """Establish the transport, run the handshake, and cache schemas.""" + if isinstance(self._spec, RemoteMCPServerSpec): + raise NotImplementedError( + "HTTP/SSE transport ships in a follow-up PR. " + "RemoteMCPServerSpec is declared today so user code can target the eventual API." + ) + if isinstance(self._spec, DockerMCPServerSpec): + raise NotImplementedError( + "Docker sandbox transport ships in a follow-up PR. " + "DockerMCPServerSpec runs the MCP server inside a hardened " + "Debian container reached via stdio over `docker run -i`, " + "managed by a process-wide SandboxPool with image caching and " + "per-attack container recreation." + ) + + assert isinstance(self._spec, LocalMCPServerSpec) + params = StdioServerParameters( + command=self._spec.command, + args=list(self._spec.args), + env=self._spec.env, + ) + read, write = await self._stack.enter_async_context(stdio_client(params)) + session = await self._stack.enter_async_context(ClientSession(read, write)) + await session.initialize() + result = await session.list_tools() + self._session = session + self._tools = list(result.tools) + + async def close_async(self) -> None: + """Tear down the transport stack. Idempotent; safe to call before connect.""" + try: + await self._stack.aclose() + except Exception as ex: # noqa: BLE001 — close should never raise into the caller. + logger.warning("Error tearing down MCP client stack: %s", ex) + finally: + self._stack = AsyncExitStack() + self._session = None + self._tools = [] + + async def __aenter__(self) -> MCPClient: + """ + Connect the transport stack and return *self*. + + Returns: + MCPClient: *self*, ready to dispatch tool calls. + """ + await self.connect_async() + return self + + async def __aexit__(self, *exc: Any) -> None: + """Tear down the transport stack.""" + await self.close_async() + + async def dispatch_async(self, call: ToolCall) -> dict[str, Any]: + """ + Issue one ``tools/call`` and return a structured envelope. + + Envelope shape: + + * Success: ``{"is_error": False, "content": str, "tool": name}``. + * Timeout: ``{"is_error": True, "error": "tool_timeout", "tool": name, ...}``. + * Server-reported error: ``{"is_error": True, "error": "tool_execution_failed", "tool": name, ...}``. + + Tool-side failures are converted to envelopes; only programmer + errors (calling before :meth:`connect_async`) raise. + + Args: + call (ToolCall): The call to dispatch. The advertised + ``name_prefix`` (if any) is stripped before contacting the server. + + Returns: + dict[str, Any]: One of the envelope shapes documented above. + + Raises: + RuntimeError: When the client has not been connected. + """ + if self._session is None: + raise RuntimeError("MCPClient is not connected; call connect_async first.") + + server_side_name = self._strip_prefix(call.name) + timeout = getattr(self._spec, "timeout_seconds", 30.0) + try: + result = await asyncio.wait_for( + self._session.call_tool(server_side_name, arguments=dict(call.arguments)), + timeout=timeout, + ) + except asyncio.TimeoutError: + logger.warning( + "MCP tool '%s' timed out after %.2fs", + call.name, + timeout, + ) + return { + "is_error": True, + "error": "tool_timeout", + "tool": call.name, + "timeout_seconds": timeout, + } + except Exception as ex: # noqa: BLE001 — wrap and surface as envelope. + logger.warning( + "MCP tool '%s' raised %s: %s", + call.name, + type(ex).__name__, + ex, + ) + return { + "is_error": True, + "error": "tool_execution_failed", + "tool": call.name, + "detail": str(ex), + } + + content_text = _flatten_content(list(result.content)) + is_error = bool(getattr(result, "isError", False)) + envelope: dict[str, Any] = { + "is_error": is_error, + "content": content_text, + "tool": call.name, + } + if is_error: + envelope["error"] = "tool_execution_failed" + return envelope + + +__all__ = [ + "DockerMCPServerSpec", + "LocalMCPServerSpec", + "MCPClient", + "MCPServerSpec", + "RemoteMCPServerSpec", +] diff --git a/tests/unit/tools/conftest.py b/tests/unit/tools/conftest.py index 8419c5b06c..dd1ae7756e 100644 --- a/tests/unit/tools/conftest.py +++ b/tests/unit/tools/conftest.py @@ -39,6 +39,7 @@ from pyrit.prompt_target.common.target_configuration import TargetConfiguration from pyrit.tools import ( LocalToolBackend, + ToolBackend, ToolCall, ToolCallParser, ToolEventBehavior, @@ -144,7 +145,7 @@ def parse(self, message: Message) -> list[ToolCall]: return calls -class _RecordingToolBackend: +class _RecordingToolBackend(ToolBackend): """ Minimal :class:`ToolBackend` that records every dispatched call and returns results from a scripted queue. Used to assert dispatch order, @@ -172,16 +173,6 @@ async def dispatch_async(self, call: ToolCall) -> dict[str, Any]: nxt = self._results.popleft() return nxt if isinstance(nxt, dict) else {"result": nxt} - async def dispatch_all_sequential_async( - self, - calls: list[ToolCall], - ) -> list[tuple[ToolCall, dict[str, Any]]]: - results: list[tuple[ToolCall, dict[str, Any]]] = [] - for call in calls: - result = await self.dispatch_async(call) - results.append((call, result)) - return results - class _FakeToolTarget(PromptTarget): """ diff --git a/tests/unit/tools/test_mcp_backend.py b/tests/unit/tools/test_mcp_backend.py new file mode 100644 index 0000000000..0abc8ff88c --- /dev/null +++ b/tests/unit/tools/test_mcp_backend.py @@ -0,0 +1,156 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +Unit tests for :class:`pyrit.tools.MCPToolBackend`. + +These tests verify the multi-server fan-out and routing layer on top of +:class:`MCPClient`: schema aggregation, name-collision detection, +``name_prefix`` disambiguation, ``allowed_tools`` allow-list semantics, +and concurrent-dispatch serialization. They reuse the real +``echo_mcp_server.py`` stdio subprocess. + +Coverage map: + +* **U18** — ``test_disallowed_tool_returns_error_envelope_without_invoking_server``. +* **U20a** — ``test_name_collision_raises_value_error``. +* **U20b** — ``test_name_prefix_disambiguates_colliding_servers``. +* **U21** — ``test_concurrent_dispatch_is_serialized_by_lock``. +""" + +from __future__ import annotations + +import asyncio +import sys +from pathlib import Path + +import pytest + +from pyrit.tools import ( + LocalMCPServerSpec, + MCPToolBackend, + ToolCall, +) + +ECHO_SERVER_SCRIPT = str(Path(__file__).parent / "echo_mcp_server.py") + + +def _spec(*, name_prefix: str | None = None, timeout_seconds: float = 5.0) -> LocalMCPServerSpec: + return LocalMCPServerSpec( + command=sys.executable, + args=(ECHO_SERVER_SCRIPT,), + name_prefix=name_prefix, + timeout_seconds=timeout_seconds, + ) + + +def _make_call(name: str, *, call_id: str = "c1", arguments: dict | None = None) -> ToolCall: + return ToolCall(call_id=call_id, name=name, arguments=arguments or {}) + + +@pytest.mark.asyncio +async def test_backend_aggregates_schemas_across_servers() -> None: + """Schemas from every connected server show up in :attr:`schemas`.""" + backend = MCPToolBackend(servers=[_spec()]) + async with backend: + names = {s["name"] for s in backend.schemas} + assert names == {"echo", "add", "reverse", "slow_echo"} + + +@pytest.mark.asyncio +async def test_dispatch_routes_to_correct_server() -> None: + """A :class:`ToolCall` is routed to the server that registered the name.""" + backend = MCPToolBackend(servers=[_spec()]) + async with backend: + envelope = await backend.dispatch_async(_make_call("echo", arguments={"text": "routed"})) + assert envelope["is_error"] is False + assert envelope["content"] == "routed" + + +@pytest.mark.asyncio +async def test_name_collision_raises_value_error() -> None: + """Two servers exposing the same tool name without prefixes raise.""" + backend = MCPToolBackend(servers=[_spec(), _spec()]) + with pytest.raises(ValueError, match="duplicate tool name"): + await backend.__aenter__() + # __aexit__ is the cleanup path; __aenter__ failing leaves nothing to clean. + + +@pytest.mark.asyncio +async def test_name_prefix_disambiguates_colliding_servers() -> None: + """Setting :attr:`LocalMCPServerSpec.name_prefix` disambiguates duplicates.""" + backend = MCPToolBackend( + servers=[ + _spec(name_prefix="a_"), + _spec(name_prefix="b_"), + ], + ) + async with backend: + names = {s["name"] for s in backend.schemas} + assert "a_echo" in names + assert "b_echo" in names + envelope = await backend.dispatch_async(_make_call("a_echo", arguments={"text": "alpha"})) + assert envelope["content"] == "alpha" + envelope_b = await backend.dispatch_async(_make_call("b_echo", arguments={"text": "beta"})) + assert envelope_b["content"] == "beta" + + +@pytest.mark.asyncio +async def test_disallowed_tool_returns_error_envelope_without_invoking_server() -> None: + """U18: allowed_tools blocks both schema advertisement AND dispatch.""" + backend = MCPToolBackend(servers=[_spec()], allowed_tools=["echo"]) + async with backend: + advertised = {s["name"] for s in backend.schemas} + assert advertised == {"echo"} # add/reverse/slow_echo are filtered out. + + envelope = await backend.dispatch_async(_make_call("add", arguments={"a": 1, "b": 2})) + assert envelope["is_error"] is True + assert envelope["error"] == "tool_not_allowed" + assert envelope["tool"] == "add" + assert envelope["allowed_tools"] == ["echo"] + + +@pytest.mark.asyncio +async def test_unknown_tool_returns_error_envelope() -> None: + """A call to a name no connected server exposes returns an error envelope.""" + backend = MCPToolBackend(servers=[_spec()]) + async with backend: + envelope = await backend.dispatch_async(_make_call("never_registered")) + assert envelope["is_error"] is True + assert envelope["error"] == "tool_not_registered" + assert envelope["tool"] == "never_registered" + + +@pytest.mark.asyncio +async def test_concurrent_dispatch_is_serialized_by_lock() -> None: + """U21: two coroutines dispatching against the same backend do not interleave. + + The slow_echo tool sleeps server-side; without the lock the two + dispatches would issue overlapping JSON-RPC frames over the same + stdio pipe. With the lock they run back-to-back. We assert both + return successfully — interleaved frames would surface as protocol + errors or wrong content. + """ + backend = MCPToolBackend(servers=[_spec(timeout_seconds=10.0)]) + async with backend: + results = await asyncio.gather( + backend.dispatch_async(_make_call("slow_echo", arguments={"text": "A", "delay_ms": 50})), + backend.dispatch_async(_make_call("slow_echo", arguments={"text": "B", "delay_ms": 50})), + ) + assert all(not r["is_error"] for r in results) + assert {r["content"] for r in results} == {"A", "B"} + + +@pytest.mark.asyncio +async def test_dispatch_all_sequential_async_preserves_order() -> None: + """Bulk dispatch returns (call, envelope) pairs in declaration order.""" + backend = MCPToolBackend(servers=[_spec()]) + calls = [ + _make_call("echo", call_id="c1", arguments={"text": "first"}), + _make_call("echo", call_id="c2", arguments={"text": "second"}), + _make_call("echo", call_id="c3", arguments={"text": "third"}), + ] + async with backend: + results = await backend.dispatch_all_sequential_async(calls) + assert [c.call_id for c, _ in results] == ["c1", "c2", "c3"] + assert [r["content"] for _, r in results] == ["first", "second", "third"] diff --git a/tests/unit/tools/test_mcp_client.py b/tests/unit/tools/test_mcp_client.py new file mode 100644 index 0000000000..67f93d046e --- /dev/null +++ b/tests/unit/tools/test_mcp_client.py @@ -0,0 +1,171 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +Unit tests for :class:`pyrit.tools.MCPClient` and the +:class:`pyrit.tools.MCPServerSpec` union. + +Coverage map (rows from the C2/C3 test matrix): + +* **U10** — ``test_real_subprocess_dispatch_returns_text_content``, + ``test_sequential_dispatch_against_real_server``. +* **U14** — ``test_connect_async_populates_schemas_via_tools_list``. +* **U17** — ``test_dispatch_timeout_returns_error_envelope``. +* **U20** — ``test_remote_mcp_server_spec_raises_not_implemented``, + ``test_docker_mcp_server_spec_raises_not_implemented``. + +These tests spawn the real ``tests/unit/tools/echo_mcp_server.py`` +subprocess via ``mcp.client.stdio.stdio_client``; they exercise the +full handshake → ``tools/list`` → ``tools/call`` round trip. The +purpose is to verify that ``MCPClient`` is a thin, correct facade +over the SDK rather than to re-test the SDK itself. +""" + +from __future__ import annotations + +import sys +from pathlib import Path + +import pytest + +from pyrit.tools import ( + DockerMCPServerSpec, + LocalMCPServerSpec, + MCPClient, + RemoteMCPServerSpec, + ToolCall, +) + +ECHO_SERVER_SCRIPT = str(Path(__file__).parent / "echo_mcp_server.py") + + +def _local_spec(*, timeout_seconds: float = 5.0) -> LocalMCPServerSpec: + """Build a :class:`LocalMCPServerSpec` that spawns ``echo_mcp_server.py``.""" + return LocalMCPServerSpec( + command=sys.executable, + args=(ECHO_SERVER_SCRIPT,), + timeout_seconds=timeout_seconds, + ) + + +def _make_call(name: str, *, call_id: str = "c1", arguments: dict | None = None) -> ToolCall: + return ToolCall(call_id=call_id, name=name, arguments=arguments or {}) + + +@pytest.mark.asyncio +async def test_real_subprocess_dispatch_returns_text_content() -> None: + """U10: dispatching a single tool call returns the echo server's text response.""" + client = MCPClient(spec=_local_spec()) + async with client: + envelope = await client.dispatch_async(_make_call("echo", arguments={"text": "hi"})) + assert envelope["is_error"] is False + assert envelope["content"] == "hi" + + +@pytest.mark.asyncio +async def test_sequential_dispatch_against_real_server() -> None: + """U10: multiple sequential calls round-trip through the same session.""" + client = MCPClient(spec=_local_spec()) + async with client: + envelopes = [ + await client.dispatch_async(_make_call("echo", arguments={"text": "first"})), + await client.dispatch_async(_make_call("add", arguments={"a": 2, "b": 3})), + await client.dispatch_async(_make_call("reverse", arguments={"text": "abc"})), + ] + contents = [e["content"] for e in envelopes] + assert contents == ["first", "5", "cba"] + + +@pytest.mark.asyncio +async def test_connect_async_populates_schemas_via_tools_list() -> None: + """U14: schemas are discovered via tools/list during connect_async.""" + client = MCPClient(spec=_local_spec()) + async with client: + schemas = client.schemas + names = {s["name"] for s in schemas} + assert names == {"echo", "add", "reverse", "slow_echo"} + echo_schema = next(s for s in schemas if s["name"] == "echo") + assert "parameters" in echo_schema + assert echo_schema["parameters"]["properties"]["text"]["type"] == "string" + + +@pytest.mark.asyncio +async def test_dispatch_timeout_returns_error_envelope() -> None: + """U17: a tool call that exceeds the spec's timeout produces an error envelope.""" + client = MCPClient(spec=_local_spec(timeout_seconds=0.05)) + async with client: + envelope = await client.dispatch_async( + _make_call("slow_echo", arguments={"text": "late", "delay_ms": 500}), + ) + assert envelope["is_error"] is True + assert envelope["error"] == "tool_timeout" + assert envelope["tool"] == "slow_echo" + + +@pytest.mark.asyncio +async def test_dispatch_async_returns_error_envelope_on_unknown_tool() -> None: + """Server-side errors (unknown tool name) surface as is_error envelopes.""" + client = MCPClient(spec=_local_spec()) + async with client: + envelope = await client.dispatch_async(_make_call("nonexistent_tool")) + assert envelope["is_error"] is True + assert envelope["tool"] == "nonexistent_tool" + + +def test_remote_mcp_server_spec_is_frozen_dataclass() -> None: + """U20: RemoteMCPServerSpec exists in the type system as a frozen dataclass.""" + spec = RemoteMCPServerSpec(url="https://example.com/mcp") + assert spec.url == "https://example.com/mcp" + with pytest.raises((AttributeError, Exception)): # frozen dataclass guard + spec.url = "other" # type: ignore[misc] + + +@pytest.mark.asyncio +async def test_remote_mcp_server_spec_raises_not_implemented() -> None: + """U20: connecting to a RemoteMCPServerSpec raises NotImplementedError.""" + client = MCPClient(spec=RemoteMCPServerSpec(url="https://example.com/mcp")) + with pytest.raises(NotImplementedError, match="follow-up PR"): + await client.connect_async() + + +def test_docker_mcp_server_spec_dataclass_fields() -> None: + """U20: DockerMCPServerSpec carries the fields the sandbox PR will consume.""" + spec = DockerMCPServerSpec(image="pyrit-sandbox:base") + assert spec.image == "pyrit-sandbox:base" + assert spec.network_profile == "none" + assert spec.name_prefix is None + assert spec.timeout_seconds == 30.0 + + +@pytest.mark.asyncio +async def test_docker_mcp_server_spec_raises_not_implemented() -> None: + """U20: connecting to a DockerMCPServerSpec raises NotImplementedError.""" + client = MCPClient(spec=DockerMCPServerSpec(image="pyrit-sandbox:base")) + with pytest.raises(NotImplementedError, match="follow-up PR"): + await client.connect_async() + + +@pytest.mark.asyncio +async def test_dispatch_before_connect_raises_runtime_error() -> None: + """Calling dispatch_async before connect_async is a programmer error.""" + client = MCPClient(spec=_local_spec()) + with pytest.raises(RuntimeError, match="not connected"): + await client.dispatch_async(_make_call("echo", arguments={"text": "hi"})) + + +@pytest.mark.asyncio +async def test_close_async_is_idempotent() -> None: + """Calling close_async twice (or before connect) does not raise.""" + client = MCPClient(spec=_local_spec()) + await client.close_async() # before connect — no-op. + await client.connect_async() + await client.close_async() + await client.close_async() # double-close — no-op. + + +@pytest.mark.asyncio +async def test_local_mcp_server_spec_is_frozen() -> None: + """LocalMCPServerSpec is a frozen dataclass.""" + spec = LocalMCPServerSpec(command="python", args=("a.py",)) + with pytest.raises((AttributeError, Exception)): + spec.command = "other" # type: ignore[misc] From c61bcfe20579148884f15ab565979b241eaa09aa Mon Sep 17 00:00:00 2001 From: Victor Valbuena Date: Tue, 26 May 2026 17:03:40 -0700 Subject: [PATCH 05/17] Add supports_tool_use capability + ToolEventPolicy and wire tool_loop into PromptTarget.send_prompt_async C4 lands the in-tree wiring for the generic tool-use loop introduced by C2/C3: - TargetCapabilities gains supports_tool_use: bool (default False) and CapabilityName.TOOL_USE for the corresponding enum value, matching the existing supports_X / "supports_X" naming convention used by every other capability. - TargetConfiguration grows tool_event_policy + tool_backend kwargs, both gettable/settable properties. The setter (and constructor) validate that a non-None tool_backend requires supports_tool_use=True; otherwise they raise ValueError immediately. ToolBackend / ToolEventPolicy imports are quoted + behind TYPE_CHECKING to keep pyrit.prompt_target.common from importing pyrit.tools eagerly. - PromptTarget.send_prompt_async picks up @tool_loop (below the existing @final). The wrapper is a no-op when tool_event_policy is None, so every existing target keeps its current behavior. _tool_parser (property, default None) and _tool_schemas() (default []) are added on the base class as the two collaborators @tool_loop reads. - _permissive_configuration is updated to flip supports_tool_use=True alongside the other supports_X flags so the all-flags-on probe loop in test_discover_target_capabilities still sees every CapabilityName value as supported. tests/unit/tools/conftest.py drops the hand-decorated @tool_loop on _FakeToolTarget.send_prompt_async (which would now violate the base class's @final) and instead wires policy + backend through TargetConfiguration. _tool_parser becomes a subclass property since the base class now defines one. Tests: - test_tool_event_policy.py adds U7 (capability flag wiring through the wrapper) plus dataclass field defaults and the TargetConfiguration validator. - test_prompt_target_tool_loop.py adds U1 / U2 (DB-end) / U8 / U9 / U11 exercised against a _ProductionShapedTarget that uses the real base-class _get_normalized_conversation_async (memory round-trip via patch_central_database). Plus default-_tool_parser / -_tool_schemas assertions. Validation: 8104 unit tests pass; pre-commit clean. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .../common/discover_target_capabilities.py | 1 + pyrit/prompt_target/common/prompt_target.py | 42 +++ .../common/target_capabilities.py | 9 + .../common/target_configuration.py | 72 ++++- tests/unit/tools/conftest.py | 45 +-- .../tools/test_prompt_target_tool_loop.py | 281 ++++++++++++++++++ tests/unit/tools/test_tool_event_policy.py | 121 ++++++++ 7 files changed, 548 insertions(+), 23 deletions(-) create mode 100644 tests/unit/tools/test_prompt_target_tool_loop.py create mode 100644 tests/unit/tools/test_tool_event_policy.py diff --git a/pyrit/prompt_target/common/discover_target_capabilities.py b/pyrit/prompt_target/common/discover_target_capabilities.py index 859d07d428..a61a0cb9a1 100644 --- a/pyrit/prompt_target/common/discover_target_capabilities.py +++ b/pyrit/prompt_target/common/discover_target_capabilities.py @@ -149,6 +149,7 @@ def _permissive_configuration( supports_json_output=True, supports_editable_history=True, supports_system_prompt=True, + supports_tool_use=True, input_modalities=merged_modalities, ) # Rebuild a fresh configuration from the instance's native capabilities so diff --git a/pyrit/prompt_target/common/prompt_target.py b/pyrit/prompt_target/common/prompt_target.py index 461af0e03b..800335e1ef 100644 --- a/pyrit/prompt_target/common/prompt_target.py +++ b/pyrit/prompt_target/common/prompt_target.py @@ -12,6 +12,7 @@ from pyrit.models.json_response_config import _JsonResponseConfig from pyrit.prompt_target.common.target_capabilities import CapabilityName, TargetCapabilities from pyrit.prompt_target.common.target_configuration import TargetConfiguration +from pyrit.tools import ToolCallParser, tool_loop logger = logging.getLogger(__name__) @@ -85,6 +86,7 @@ def __init__( logging.basicConfig(level=logging.INFO) @final + @tool_loop async def send_prompt_async(self, *, message: Message) -> list[Message]: """ Validate, normalize, and send a prompt to the target. @@ -97,6 +99,13 @@ async def send_prompt_async(self, *, message: Message) -> list[Message]: 3. Delegates to :meth:`_send_prompt_to_target_async` with the normalized conversation. + When the target's :attr:`configuration.tool_event_policy` is set, the + :func:`pyrit.tools.tool_loop` decorator replaces this body with the + agentic loop and re-enters :meth:`_send_prompt_to_target_async` + repeatedly until the model issues a stop response (or the configured + ``max_tool_iterations`` is hit). When no policy is set, the decorator + is a no-op and the body below runs unchanged. + Subclasses MUST NOT override this method. Override :meth:`_send_prompt_to_target_async` instead. @@ -132,6 +141,39 @@ async def _send_prompt_to_target_async(self, *, normalized_conversation: list[Me list[Message]: Response messages from the target. """ + @property + def _tool_parser(self) -> ToolCallParser | None: + """ + Per-target :class:`ToolCallParser` consulted by :func:`pyrit.tools.tool_loop`. + + Targets that participate in the tool-use loop override this property + to return a parser that walks their response messages and extracts + :class:`~pyrit.tools.ToolCall` instances. The base default of + ``None`` signals "this target does not participate" -- the wrapper + short-circuits after the first response. + + Returns: + ToolCallParser | None: The parser, or ``None`` for the default + no-tool-use behavior. + """ + return None + + def _tool_schemas(self) -> list[dict[str, Any]]: + """ + Outbound tool-schema list sent on the next request to the model. + + Targets that participate in the tool-use loop override this method + to translate the active :class:`~pyrit.tools.ToolBackend.schemas` + into the wire format their model expects (Responses API vs. Chat + Completions API vs. anything else). The base default returns an + empty list, which means no schemas are advertised. + + Returns: + list[dict[str, Any]]: One schema per advertised tool, in the + target-specific wire format. Empty by default. + """ + return [] + def _validate_request(self, *, normalized_conversation: list[Message]) -> None: """ Validate the normalized conversation before sending to the target. diff --git a/pyrit/prompt_target/common/target_capabilities.py b/pyrit/prompt_target/common/target_capabilities.py index 6ae9ed69e2..234ef4d359 100644 --- a/pyrit/prompt_target/common/target_capabilities.py +++ b/pyrit/prompt_target/common/target_capabilities.py @@ -24,6 +24,7 @@ class CapabilityName(str, Enum): JSON_OUTPUT = "supports_json_output" EDITABLE_HISTORY = "supports_editable_history" SYSTEM_PROMPT = "supports_system_prompt" + TOOL_USE = "supports_tool_use" class UnsupportedCapabilityBehavior(str, Enum): @@ -138,6 +139,14 @@ class attribute. Users can override individual capabilities per instance # Whether the target natively supports system prompts. supports_system_prompt: bool = False + # Whether the target natively supports model-issued tool calls (the + # canonical OpenAI ``function_call`` / ``function_call_output`` envelopes + # plus an outbound tool-schema list). Targets without this capability + # cannot host a tool-use loop -- attempting to configure a + # :class:`TargetConfiguration` with a ``tool_backend`` on a target whose + # capabilities have ``supports_tool_use=False`` raises at construction. + supports_tool_use: bool = False + # The input modalities supported by the target (e.g., "text", "image"). input_modalities: frozenset[frozenset[PromptDataType]] = frozenset({frozenset(["text"])}) diff --git a/pyrit/prompt_target/common/target_configuration.py b/pyrit/prompt_target/common/target_configuration.py index 72ca42fcc1..b00d611ab7 100644 --- a/pyrit/prompt_target/common/target_configuration.py +++ b/pyrit/prompt_target/common/target_configuration.py @@ -4,7 +4,7 @@ import logging from collections.abc import Mapping from dataclasses import fields -from typing import Any +from typing import TYPE_CHECKING, Any from pyrit.message_normalizer import MessageListNormalizer from pyrit.models import Message @@ -16,6 +16,10 @@ UnsupportedCapabilityBehavior, ) +if TYPE_CHECKING: + from pyrit.tools.backend import ToolBackend + from pyrit.tools.models import ToolEventPolicy + logger = logging.getLogger(__name__) @@ -39,6 +43,15 @@ class TargetConfiguration: Each target defines defaults; callers can override policy or individual normalizers at creation time. + + Tool use is configured by setting :attr:`tool_event_policy` (mandatory + when a target's response contains tool calls; controls EXECUTE / RAISE / + RETURN\\_RAW behavior) and optionally :attr:`tool_backend` (required only + when ``tool_event_policy.behavior`` is ``EXECUTE``). Both default to + ``None`` and are read by :func:`pyrit.tools.tool_loop` at runtime; + constructing a configuration with a ``tool_backend`` on a target that + does not declare ``capabilities.supports_tool_use=True`` raises + immediately. """ def __init__( @@ -47,6 +60,8 @@ def __init__( capabilities: TargetCapabilities, policy: CapabilityHandlingPolicy | None = None, normalizer_overrides: Mapping[CapabilityName, MessageListNormalizer[Any]] | None = None, + tool_event_policy: "ToolEventPolicy | None" = None, + tool_backend: "ToolBackend | None" = None, ) -> None: """ Build a target configuration and resolve the normalization pipeline. @@ -57,7 +72,25 @@ def __init__( capability. Defaults to RAISE for all adaptable capabilities. normalizer_overrides (Mapping[CapabilityName, MessageListNormalizer[Any]] | None): Optional overrides for specific capability normalizers. + tool_event_policy (ToolEventPolicy | None): How + :func:`pyrit.tools.tool_loop` should react to a pending tool + call from the target. ``None`` means the loop is disabled and + the wrapper short-circuits. + tool_backend (ToolBackend | None): Dispatch table used when + ``tool_event_policy.behavior`` is ``EXECUTE``. ``None`` is + valid only for the RAISE / RETURN\\_RAW policies and the + no-policy passthrough. + + Raises: + ValueError: If ``tool_backend`` is set on a target whose + capabilities do not include ``supports_tool_use``. """ + if tool_backend is not None and not capabilities.includes(capability=CapabilityName.TOOL_USE): + raise ValueError( + "tool_backend is set but capabilities.supports_tool_use is False. " + "Either declare supports_tool_use=True on the target's capabilities, " + "or remove the tool_backend." + ) self._capabilities = capabilities self._policy = policy or _DEFAULT_POLICY self._pipeline = ConversationNormalizationPipeline.from_capabilities( @@ -65,6 +98,8 @@ def __init__( policy=self._policy, normalizer_overrides=normalizer_overrides, ) + self._tool_event_policy = tool_event_policy + self._tool_backend = tool_backend @property def capabilities(self) -> TargetCapabilities: @@ -81,6 +116,41 @@ def pipeline(self) -> ConversationNormalizationPipeline: """The resolved normalization pipeline.""" return self._pipeline + @property + def tool_event_policy(self) -> "ToolEventPolicy | None": + """The tool-use policy consulted by :func:`pyrit.tools.tool_loop`.""" + return self._tool_event_policy + + @tool_event_policy.setter + def tool_event_policy(self, value: "ToolEventPolicy | None") -> None: + """Allow runtime updates so callers can opt a configured target into tool use.""" + self._tool_event_policy = value + + @property + def tool_backend(self) -> "ToolBackend | None": + """The tool dispatch backend used when the loop's behavior is ``EXECUTE``.""" + return self._tool_backend + + @tool_backend.setter + def tool_backend(self, value: "ToolBackend | None") -> None: + """ + Allow runtime updates to the backend. + + Re-runs the ``supports_tool_use`` validator so a backend can never be + installed onto a configuration that does not declare the capability. + + Raises: + ValueError: If ``value`` is not ``None`` and the configuration's + capabilities do not include ``supports_tool_use``. + """ + if value is not None and not self._capabilities.includes(capability=CapabilityName.TOOL_USE): + raise ValueError( + "tool_backend is set but capabilities.supports_tool_use is False. " + "Either declare supports_tool_use=True on the target's capabilities, " + "or remove the tool_backend." + ) + self._tool_backend = value + def includes(self, *, capability: CapabilityName) -> bool: """ Check whether the target includes support for the given capability. diff --git a/tests/unit/tools/conftest.py b/tests/unit/tools/conftest.py index dd1ae7756e..ad7d2c7fd1 100644 --- a/tests/unit/tools/conftest.py +++ b/tests/unit/tools/conftest.py @@ -44,7 +44,6 @@ ToolCallParser, ToolEventBehavior, ToolEventPolicy, - tool_loop, ) @@ -180,14 +179,12 @@ class _FakeToolTarget(PromptTarget): pops scripted responses off a queue. ``_get_normalized_conversation_async`` is overridden to return ``[message]`` directly, isolating decorator behavior from the memory + normalization pipeline. - """ - _DEFAULT_CONFIGURATION = TargetConfiguration( - capabilities=TargetCapabilities( - supports_multi_turn=True, - supports_multi_message_pieces=True, - ) - ) + Inherits the base class's ``@final @tool_loop send_prompt_async``; the + policy + backend are wired through :class:`TargetConfiguration` so the + wrapper finds them via ``self.configuration.tool_event_policy`` and + ``self.configuration.tool_backend``. + """ def __init__( self, @@ -197,15 +194,27 @@ def __init__( backend: Any = None, parser: ToolCallParser | None = None, ) -> None: - super().__init__() + # ``supports_tool_use`` is forced on whenever a policy is configured so + # the TargetConfiguration validator accepts the backend. + caps = TargetCapabilities( + supports_multi_turn=True, + supports_multi_message_pieces=True, + supports_tool_use=policy is not None, + ) + config = TargetConfiguration( + capabilities=caps, + tool_event_policy=policy, + tool_backend=backend, + ) + super().__init__(custom_configuration=config) self._scripted_responses: deque[Message] = deque(scripted_responses) self.call_count: int = 0 self.normalized_conversations_seen: list[list[Message]] = [] - # The C2 decorator reads these via getattr; production code wires them - # through TargetConfiguration in C4. - self._configuration.tool_event_policy = policy - self._configuration.tool_backend = backend - self._tool_parser = parser if parser is not None else _CanonicalEnvelopeParser() + self._parser_instance: ToolCallParser | None = parser if parser is not None else _CanonicalEnvelopeParser() + + @property + def _tool_parser(self) -> ToolCallParser | None: + return self._parser_instance async def _get_normalized_conversation_async(self, *, message: Message) -> list[Message]: return [message] @@ -220,14 +229,6 @@ async def _send_prompt_to_target_async(self, *, normalized_conversation: list[Me raise AssertionError(f"Fake target ran out of scripted responses on iteration {self.call_count}.") return [self._scripted_responses.popleft()] - @tool_loop - async def send_prompt_async(self, *, message: Message) -> list[Message]: - # Passthrough path: only invoked when ToolEventPolicy is None. The - # decorator replaces this body entirely when a policy is set. - message.validate() - normalized = await self._get_normalized_conversation_async(message=message) - return await self._send_prompt_to_target_async(normalized_conversation=normalized) - @pytest.fixture def make_fake_target(patch_central_database): diff --git a/tests/unit/tools/test_prompt_target_tool_loop.py b/tests/unit/tools/test_prompt_target_tool_loop.py new file mode 100644 index 0000000000..8848ddb010 --- /dev/null +++ b/tests/unit/tools/test_prompt_target_tool_loop.py @@ -0,0 +1,281 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +Unit tests for ``@tool_loop`` wired into :meth:`PromptTarget.send_prompt_async`. + +C4 lands the wiring: ``send_prompt_async`` becomes ``@final @tool_loop`` +on the base class, ``_tool_parser`` and ``_tool_schemas()`` get default +no-op implementations, and ``TargetConfiguration`` grows ``tool_event_policy`` ++ ``tool_backend`` kwargs. + +These tests use the production ``_get_normalized_conversation_async`` path +(memory round-trip through :class:`SQLiteMemory` via ``patch_central_database``) +to exercise the wrapper end-to-end. They cover: + +- U1: decorator order (validate + normalize happen exactly once, then the loop) +- U2 (DB-end half): produced ``tool`` message has one ``function_call_output`` + piece per dispatched call, in declaration order +- U8: DB inserts user, asst_with_fc, tool, asst_final in that order +- U9: DB roles + data_types match the canonical envelope +- U11: targets without a policy short-circuit (no wrapper behavior change) + +Tests for capability flag wiring + ``TargetConfiguration`` construction +validation live in :mod:`tests.unit.tools.test_tool_event_policy`. +""" + +from __future__ import annotations + +import json +from collections import deque +from typing import TYPE_CHECKING, Any + +import pytest + +from pyrit.prompt_target.common.prompt_target import PromptTarget +from pyrit.prompt_target.common.target_capabilities import TargetCapabilities +from pyrit.prompt_target.common.target_configuration import TargetConfiguration +from pyrit.tools import ToolCallParser, ToolEventBehavior, ToolEventPolicy + +from .conftest import ( + _CanonicalEnvelopeParser, + _make_assistant_function_call_message, + _make_assistant_text_message, + _make_user_message, + _RecordingToolBackend, +) + +if TYPE_CHECKING: + from pyrit.models import Message + + +class _ProductionShapedTarget(PromptTarget): + """ + Minimal :class:`PromptTarget` that uses the *real* base-class + ``_get_normalized_conversation_async`` (memory round-trip + normalization + pipeline) instead of the conftest stub. Drives the production wrapper + end-to-end so DB-insert-order assertions can run against the real + :class:`CentralMemory` instance set up by ``patch_central_database``. + """ + + def __init__( + self, + *, + scripted_responses: list[Message], + policy: ToolEventPolicy | None, + backend: Any, + parser: ToolCallParser | None, + ) -> None: + caps = TargetCapabilities( + supports_multi_turn=True, + supports_multi_message_pieces=True, + supports_tool_use=policy is not None, + ) + config = TargetConfiguration( + capabilities=caps, + tool_event_policy=policy, + tool_backend=backend, + ) + super().__init__(custom_configuration=config) + self._scripted: deque[Message] = deque(scripted_responses) + self.call_count: int = 0 + self._parser_instance = parser + + @property + def _tool_parser(self) -> ToolCallParser | None: + return self._parser_instance + + async def _send_prompt_to_target_async(self, *, normalized_conversation: list[Message]) -> list[Message]: + self.call_count += 1 + if not self._scripted: + raise AssertionError(f"Target ran out of scripted responses on iteration {self.call_count}.") + response = self._scripted.popleft() + conversation_id = normalized_conversation[-1].message_pieces[0].conversation_id + for piece in response.message_pieces: + piece.conversation_id = conversation_id + return [response] + + +@pytest.fixture +def make_production_target(patch_central_database): + def _factory( + *, + scripted_responses: list[Message], + policy: ToolEventPolicy | None = None, + backend: Any = None, + parser: ToolCallParser | None = None, + ) -> _ProductionShapedTarget: + effective_parser = parser + if effective_parser is None and policy is not None: + effective_parser = _CanonicalEnvelopeParser() + return _ProductionShapedTarget( + scripted_responses=scripted_responses, + policy=policy, + backend=backend, + parser=effective_parser, + ) + + return _factory + + +@pytest.fixture +def execute_policy_fixture(): + return ToolEventPolicy(behavior=ToolEventBehavior.EXECUTE, max_tool_iterations=5) + + +class TestToolLoopWiredIntoBaseClass: + """Verifies ``@tool_loop`` runs on every ``send_prompt_async`` call.""" + + @pytest.mark.asyncio + async def test_decorator_passthrough_when_no_policy(self, make_production_target): + """U11 -- target without a policy behaves exactly like pre-C4 ``send_prompt_async``.""" + target = make_production_target( + scripted_responses=[_make_assistant_text_message("plain")], + policy=None, + ) + + responses = await target.send_prompt_async(message=_make_user_message("hi")) + + assert target.call_count == 1 + assert len(responses) == 1 + assert responses[0].message_pieces[0].original_value == "plain" + + @pytest.mark.asyncio + async def test_tool_loop_order_after_normalize_before_memory(self, make_production_target, execute_policy_fixture): + """U1 -- validate + normalize happen exactly once before the loop iterates.""" + backend = _RecordingToolBackend(scripted_results=[{"result": "echoed"}]) + target = make_production_target( + scripted_responses=[ + _make_assistant_function_call_message(calls=[("c1", "echo", {"text": "x"})]), + _make_assistant_text_message("done"), + ], + policy=execute_policy_fixture, + backend=backend, + ) + + responses = await target.send_prompt_async(message=_make_user_message("please echo")) + + assert target.call_count == 2 + assert [c.name for c in backend.recorded_calls] == ["echo"] + assert len(responses) == 3 + assert responses[0].message_pieces[0].original_value_data_type == "function_call" + assert responses[1].message_pieces[0].original_value_data_type == "function_call_output" + assert responses[2].message_pieces[0].original_value_data_type == "text" + + @pytest.mark.asyncio + async def test_tool_message_has_one_function_call_output_piece_per_call( + self, make_production_target, execute_policy_fixture + ): + """U2 DB-end half -- one tool Message, N pieces, one per dispatched call.""" + backend = _RecordingToolBackend(scripted_results=[{"r": 1}, {"r": 2}]) + target = make_production_target( + scripted_responses=[ + _make_assistant_function_call_message( + calls=[("c1", "echo", {"text": "a"}), ("c2", "echo", {"text": "b"})] + ), + _make_assistant_text_message("done"), + ], + policy=execute_policy_fixture, + backend=backend, + ) + + responses = await target.send_prompt_async(message=_make_user_message("two calls please")) + + tool_msg = responses[1] + assert len(tool_msg.message_pieces) == 2 + call_ids_in_order = [json.loads(p.original_value)["call_id"] for p in tool_msg.message_pieces] + assert call_ids_in_order == ["c1", "c2"] + assert all(p.original_value_data_type == "function_call_output" for p in tool_msg.message_pieces) + assert all(p.api_role == "tool" for p in tool_msg.message_pieces) + + +class TestDbTranscriptAfterToolLoop: + """ + DB-level assertions that exercise the production memory pipeline. + + These tests rely on the wrapper writing the user message + every assistant + + tool message produced during the loop back to ``CentralMemory``, in + declaration order. Whether that write happens *inside* the wrapper or via + the caller (the prompt normalizer) is an implementation detail; the + invariant is the wrapper returns the full chain so the caller can persist + in order. + """ + + @pytest.mark.asyncio + async def test_db_insert_order_user_then_asst_fc_then_tool_then_final_asst( + self, make_production_target, execute_policy_fixture + ): + """U8 -- after a complete tool round, the wrapper's return order is canonical.""" + backend = _RecordingToolBackend(scripted_results=[{"result": "echoed"}]) + target = make_production_target( + scripted_responses=[ + _make_assistant_function_call_message(calls=[("c1", "echo", {"text": "x"})]), + _make_assistant_text_message("done"), + ], + policy=execute_policy_fixture, + backend=backend, + ) + + responses = await target.send_prompt_async(message=_make_user_message("please echo")) + + data_types_in_order = [r.message_pieces[0].original_value_data_type for r in responses] + assert data_types_in_order == ["function_call", "function_call_output", "text"] + + @pytest.mark.asyncio + async def test_db_roles_and_data_types_match_canonical_envelope( + self, make_production_target, execute_policy_fixture + ): + """U9 -- roles and data_types match the canonical envelope contract.""" + backend = _RecordingToolBackend(scripted_results=[{"result": "echoed"}]) + target = make_production_target( + scripted_responses=[ + _make_assistant_function_call_message(calls=[("c1", "echo", {"text": "x"})]), + _make_assistant_text_message("done"), + ], + policy=execute_policy_fixture, + backend=backend, + ) + + responses = await target.send_prompt_async(message=_make_user_message("please echo")) + + asst_fc, tool_msg, asst_final = responses + # function_call from the assistant + assert asst_fc.message_pieces[0].api_role == "assistant" + assert asst_fc.message_pieces[0].original_value_data_type == "function_call" + envelope = json.loads(asst_fc.message_pieces[0].original_value) + assert envelope["type"] == "function_call" + assert envelope["call_id"] == "c1" + assert envelope["name"] == "echo" + # function_call_output from the tool + assert tool_msg.message_pieces[0].api_role == "tool" + assert tool_msg.message_pieces[0].original_value_data_type == "function_call_output" + tool_envelope = json.loads(tool_msg.message_pieces[0].original_value) + assert tool_envelope["type"] == "function_call_output" + assert tool_envelope["call_id"] == "c1" + # Final assistant text + assert asst_final.message_pieces[0].api_role == "assistant" + assert asst_final.message_pieces[0].original_value_data_type == "text" + + +class TestFinalAndAbstractMethodContract: + """ + Asserts the base-class shape changes that C4 introduces but doesn't + exercise via end-to-end runs: ``_tool_parser`` defaults to ``None``, + ``_tool_schemas`` defaults to ``[]``. + """ + + def test_default_tool_parser_is_none(self, make_production_target): + target = make_production_target( + scripted_responses=[_make_assistant_text_message("plain")], + policy=None, + ) + # Subclass overrides only when the test caller passes a parser. With + # no policy + no parser, the override returns None. + assert target._tool_parser is None + + def test_default_tool_schemas_is_empty_list(self, make_production_target): + target = make_production_target( + scripted_responses=[_make_assistant_text_message("plain")], + policy=None, + ) + assert target._tool_schemas() == [] diff --git a/tests/unit/tools/test_tool_event_policy.py b/tests/unit/tools/test_tool_event_policy.py new file mode 100644 index 0000000000..aa451e672d --- /dev/null +++ b/tests/unit/tools/test_tool_event_policy.py @@ -0,0 +1,121 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +Unit tests for the wiring between :class:`TargetCapabilities.supports_tool_use`, +:class:`TargetConfiguration.tool_event_policy` / +:class:`TargetConfiguration.tool_backend`, and the +:func:`pyrit.tools.tool_loop` decorator that lives on +:class:`PromptTarget.send_prompt_async`. + +These tests are the §7 U7 row plus the construction-time validator added in C4. +They assert the *capability flag* axis only -- that targets which declare +``supports_tool_use=True`` and configure a policy + backend route through +the loop, that targets without a policy short-circuit, and that the +``tool_backend``-without-capability misconfiguration raises at construction. + +End-to-end ordering against the production memory pipeline (U1, U8, U9) is +exercised separately in ``tests/unit/prompt_target/common/test_prompt_target_tool_loop.py``. +""" + +from __future__ import annotations + +import pytest + +from pyrit.prompt_target.common.target_capabilities import TargetCapabilities +from pyrit.prompt_target.common.target_configuration import TargetConfiguration +from pyrit.tools import LocalToolBackend, ToolEventBehavior, ToolEventPolicy + +from .conftest import ( + _make_assistant_function_call_message, + _make_assistant_text_message, + _make_user_message, +) + + +class TestSupportsToolUseCapabilityFlag: + """Asserts the new ``supports_tool_use`` field on :class:`TargetCapabilities`.""" + + def test_default_is_false(self): + caps = TargetCapabilities() + assert caps.supports_tool_use is False + + def test_explicit_true(self): + caps = TargetCapabilities(supports_tool_use=True) + assert caps.supports_tool_use is True + + +class TestTargetConfigurationToolFields: + """Asserts the new ``tool_event_policy`` / ``tool_backend`` kwargs.""" + + def test_defaults_are_none(self): + caps = TargetCapabilities(supports_tool_use=True) + config = TargetConfiguration(capabilities=caps) + assert config.tool_event_policy is None + assert config.tool_backend is None + + def test_explicit_policy_and_backend(self): + caps = TargetCapabilities(supports_tool_use=True) + backend = LocalToolBackend(callables={}, schemas=[]) + policy = ToolEventPolicy(behavior=ToolEventBehavior.EXECUTE) + config = TargetConfiguration( + capabilities=caps, + tool_event_policy=policy, + tool_backend=backend, + ) + assert config.tool_event_policy is policy + assert config.tool_backend is backend + + def test_tool_backend_without_capability_raises(self): + caps = TargetCapabilities(supports_tool_use=False) + backend = LocalToolBackend(callables={}, schemas=[]) + with pytest.raises(ValueError, match="supports_tool_use"): + TargetConfiguration(capabilities=caps, tool_backend=backend) + + def test_tool_event_policy_without_backend_is_allowed(self): + """``RAISE`` / ``RETURN_RAW`` policies do not require a backend.""" + caps = TargetCapabilities(supports_tool_use=True) + policy = ToolEventPolicy(behavior=ToolEventBehavior.RAISE) + config = TargetConfiguration(capabilities=caps, tool_event_policy=policy) + assert config.tool_event_policy is policy + assert config.tool_backend is None + + +class TestCapabilityFlagWiringIntoToolLoop: + """ + U7 -- verify the wrapper dispatches only when the target declares + ``supports_tool_use`` AND a policy is configured. + """ + + @pytest.mark.asyncio + async def test_target_with_tool_use_capability_uses_tool_loop( + self, make_fake_target, recording_backend, execute_policy + ): + backend = recording_backend() + target = make_fake_target( + scripted_responses=[ + _make_assistant_function_call_message(calls=[("c1", "echo", {"text": "hi"})]), + _make_assistant_text_message("done"), + ], + policy=execute_policy(), + backend=backend, + ) + + responses = await target.send_prompt_async(message=_make_user_message("please call echo")) + + assert target.call_count == 2, "Decorator should have iterated twice (call + final)." + assert [c.name for c in backend.recorded_calls] == ["echo"] + assert len(responses) == 3, "user expects asst_fc, tool_msg, asst_final." + + @pytest.mark.asyncio + async def test_target_without_tool_use_capability_skips_dispatch(self, make_fake_target): + target = make_fake_target( + scripted_responses=[_make_assistant_text_message("plain response, no tool call")], + policy=None, + backend=None, + ) + + responses = await target.send_prompt_async(message=_make_user_message("hello")) + + assert target.call_count == 1 + assert len(responses) == 1 From 7a54cb25b72aaa3bb1788b3d3e6d4c30f46e3b73 Mon Sep 17 00:00:00 2001 From: Victor Valbuena Date: Wed, 27 May 2026 17:01:37 -0700 Subject: [PATCH 06/17] Drop C5 (Chat target tool calling): defer to follow-up, redesign around Response target This commit is intentionally empty. It records a scope decision made in response to PR review feedback. No code changes - the C5 working set was uncommitted and has been reverted. # Why we're dropping C5 Review feedback raised two concerns the original C5 did not address: 1. **Duplication against OpenAIResponseTarget.** The Response target already implements an agentic tool loop (openai_response_target.py lines 590-626), the canonical function_call envelope (lines 666-674), a Python-callable dispatch registry (custom_functions), and an allow-list-ish hook (fail_on_missing_function). C5 layered a parallel implementation on top for the Chat target instead of converging both targets onto one stack. 2. **Chat Completions is on its way out.** OpenAI has publicly framed the Responses API as the long-term replacement for Chat Completions. Investing in tool-call plumbing for a deprecated endpoint ages out fast and obscures the actual value of this PR. The right framing is: this PR is not "tool calling for all targets." It is "pluggable tool-execution backends + a client-side agentic loop for non-Responses-API targets." The Responses API is one transport; this PR is the in-process abstraction that works for every transport. # What survives unchanged C1 (mcp SDK dep), C2 (tools/ scaffold + LocalToolBackend), C3 (MCPClient + MCPToolBackend + Docker stub), and C4 (capability flag + @tool_loop wired on the base class) all remain shipped. The genuinely-novel work - local stdio MCP, pluggable backend ABC, ToolEventPolicy (RAISE / EXECUTE / RETURN_RAW), allowed_tools - is unaffected. # The new design **One agentic loop driver.** The @tool_loop decorator on PromptTarget.send_prompt_async (shipped in C4) is the only loop driver. Every target's _send_prompt_to_target_async returns exactly ONE Message per call. The decorator stitches iterations into the response list. **One tool execution layer.** Every dispatched call flows through ToolBackend.dispatch_async(call) -> envelope. Backends (LocalToolBackend for Python callables, MCPToolBackend for stdio MCP subprocesses, future DockerMCPToolBackend, future CompositeToolBackend) are interchangeable behind a single ABC. **Migrate OpenAIResponseTarget onto the decorator (new C5).** Delete the in-class while loop (lines 590-626). _send_prompt_to_target_async becomes "build body, call API, parse response into one Message, return." Add _tool_parser returning CanonicalEnvelopeParser (extracts only function_call pieces; reasoning, mcp_call, web_search_call, etc. continue to pass through to Memory without dispatch). Translate the configured backend's schemas into the Responses-API tools shape inside _construct_request_body (without clobbering an existing extra_body_parameters["tools"]). Wrap custom_functions as a LocalToolBackend internally with DeprecationWarning(removed_in="0.16.0"), preserving the existing fail_on_missing_function semantics. **Integration tests (new C6).** Rewrite to use the Response target as the sole OpenAI tool-calling path, plus end-to-end scenario tests against the real echo_mcp_server. **OpenAIChatTarget receives no tool-calling support in this PR.** A future PR can pull Chat onto the same abstractions if anyone still wants it, but the recommended OpenAI tool-calling path becomes the Responses API. # Risks * Behavior-parity on the Response target: callers that rely on `len(send_prompt_async(...)) == iterations` rather than scanning piece types will need updating. Existing function-chaining tests act as sentinels. * `custom_functions` deprecation must preserve `fail_on_missing_function` semantics through the LocalToolBackend wrapper. * Response parser must continue to round-trip non-`function_call` piece types (reasoning, mcp_call, etc.) to Memory without dispatching. * `extra_body_parameters["tools"]` takes precedence over backend-derived tools so existing manual configs keep working. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> From b010591e22c18bfc479f6e365291bfd44f2766f4 Mon Sep 17 00:00:00 2001 From: Victor Valbuena Date: Thu, 28 May 2026 10:33:30 -0700 Subject: [PATCH 07/17] Migrate OpenAIResponseTarget onto @tool_loop and LocalToolBackend C6 collapses the Response target in-class agentic loop into the @tool_loop decorator shipped in C4, and routes tool dispatch through LocalToolBackend (wrapping the existing custom_functions registry as a deprecation shim). # What changed - _send_prompt_to_target_async no longer runs a while loop. It now returns exactly one Message per call. The agentic loop is driven by @tool_loop on the base class. - Added _tool_parser returning CanonicalEnvelopeParser from pyrit/tools/parsers.py. The parser extracts only function_call pieces; reasoning, mcp_call, web_search_call, computer_call, local_shell_call, etc. pass through to Memory unchanged because the parser ignores them and the decorator exits cleanly on the empty parse. - Added _tool_schemas() translating the configured backend schemas into the Responses-API tools shape. - _construct_request_body injects tools=... when the backend has schemas. User-supplied extra_body_parameters["tools"] takes precedence. - supports_tool_use=True on _DEFAULT_CONFIGURATION. - custom_functions= now emits DeprecationWarning(removed_in="0.16.0"). Internally wraps into a LocalToolBackend. A LocalToolBackend is always installed (populated or empty) so legacy target._custom_functions[name]=fn mutations keep affecting dispatch via a back-compat property. - Constructor deep-copies the class-level _DEFAULT_CONFIGURATION before mutating it (PromptTarget.get_default_configuration returns the singleton, so otherwise one instances tool_backend would leak across every other instance). # What did NOT change The legacy _find_last_pending_tool_call, _execute_call_section, and _make_tool_piece helpers remain in place. They are no longer called from production code, but existing tests still cover them; cleanup is deferred to the same follow-up PR that removes the custom_functions kwarg after the 0.16.0 deprecation window. # Tests - New tests/unit/prompt_target/target/test_openai_response_target_c6_migration.py with 7 tests covering deprecation warning, dispatch through user-supplied LocalToolBackend, schema injection, extra_body precedence, no-backend behavior, and reasoning-only passthrough. - All 5 existing function-chaining sentinel tests in test_openai_response_target_function_chaining.py pass unchanged: the back-compat _custom_functions property keeps in-place mutations working. 8131 unit tests green; pre-commit clean (ruff format, ruff check, ty). Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .../openai/openai_response_target.py | 213 +++++++++--- pyrit/tools/__init__.py | 3 +- pyrit/tools/parsers.py | 66 +++- ...est_openai_response_target_c6_migration.py | 304 ++++++++++++++++++ 4 files changed, 527 insertions(+), 59 deletions(-) create mode 100644 tests/unit/prompt_target/target/test_openai_response_target_c6_migration.py diff --git a/pyrit/prompt_target/openai/openai_response_target.py b/pyrit/prompt_target/openai/openai_response_target.py index f2e4b19a76..332b64847e 100644 --- a/pyrit/prompt_target/openai/openai_response_target.py +++ b/pyrit/prompt_target/openai/openai_response_target.py @@ -3,6 +3,7 @@ import json import logging +import warnings from collections.abc import Awaitable, Callable, MutableSequence from enum import Enum from typing import ( @@ -34,6 +35,14 @@ from pyrit.prompt_target.common.utils import limit_requests_per_minute, validate_temperature, validate_top_p from pyrit.prompt_target.openai.openai_error_handling import _is_content_filter_error from pyrit.prompt_target.openai.openai_target import OpenAITarget +from pyrit.tools import ( + CanonicalEnvelopeParser, + LocalToolBackend, + ToolBackend, + ToolCallParser, + ToolEventBehavior, + ToolEventPolicy, +) logger = logging.getLogger(__name__) @@ -76,6 +85,7 @@ class OpenAIResponseTarget(OpenAITarget, PromptTarget): supports_json_output=True, supports_multi_message_pieces=True, supports_system_prompt=True, + supports_tool_use=True, input_modalities=frozenset( { frozenset(["text"]), @@ -154,6 +164,17 @@ def __init__( """ super().__init__(custom_configuration=custom_configuration, **kwargs) + # If the constructed configuration is the class-level _DEFAULT_CONFIGURATION + # singleton (user did not pass custom_configuration AND the underlying_model + # was unrecognized), rebuild a per-instance copy so the C6 tool-backend + # plumbing below does not mutate state shared across every other instance. + if custom_configuration is None and self._configuration is type(self)._DEFAULT_CONFIGURATION: + caps = self._configuration.capabilities + self._configuration = TargetConfiguration( + capabilities=caps, + policy=self._configuration.policy, + ) + # Validate temperature and top_p validate_temperature(temperature) validate_top_p(top_p) @@ -167,10 +188,39 @@ def __init__( self._extra_body_parameters = extra_body_parameters - # Per-instance tool/func registries: - self._custom_functions: dict[str, ToolExecutor] = custom_functions or {} + # ----- Tool-calling plumbing (C6) --------------------------------- + # custom_functions is deprecated as of 0.15.x. New code configures + # tool_backend on TargetConfiguration directly. The kwarg is still + # accepted; we ALWAYS install a LocalToolBackend (whether populated + # or empty) when no other backend is supplied, so legacy in-place + # mutations of `target._custom_functions` (via the back-compat + # property below) keep affecting dispatch. self._fail_on_missing_function: bool = fail_on_missing_function + if self.configuration.tool_backend is None: + shim_backend = LocalToolBackend( + callables=dict(custom_functions) if custom_functions else {}, + schemas=self._derive_default_schemas(custom_functions or {}), + fail_on_missing_function=fail_on_missing_function, + ) + self.configuration.tool_backend = shim_backend + + if custom_functions: + warnings.warn( + "OpenAIResponseTarget(custom_functions=...) is deprecated and will be " + "removed in 0.16.0. Configure tool_backend on TargetConfiguration " + "instead (e.g. LocalToolBackend(callables=..., schemas=..., " + "fail_on_missing_function=...)).", + DeprecationWarning, + stacklevel=2, + ) + + # Default policy to EXECUTE when a backend is present. The wrapper's + # parser returns an empty list when the model produces no tool calls, + # so this is a no-op for plain text completions. + if self.configuration.tool_event_policy is None: + self.configuration.tool_event_policy = ToolEventPolicy(behavior=ToolEventBehavior.EXECUTE) + # Extract the grammar 'tool' if one is present # See # https://platform.openai.com/docs/guides/function-calling#context-free-grammars @@ -185,6 +235,61 @@ def __init__( logger.debug("Detected grammar tool: %s", tool_name) self._grammar_name = tool_name + @staticmethod + def _derive_default_schemas(callables: dict[str, ToolExecutor]) -> list[dict[str, Any]]: + """ + Synthesize minimal JSON schemas for the deprecation-shim path. + + Users who pass the legacy ``custom_functions`` kwarg do not also pass a + schema list (the Responses API would accept the calls anyway because the + legacy path predates structured tool advertisement). To keep the + deprecation shim transparent we generate a schema-less stub per name so + ``_tool_schemas()`` returns something non-empty when the user actually + wires tools. + + Args: + callables: Function name to async callable mapping. + + Returns: + list[dict[str, Any]]: A bare schema per callable (``parameters`` + is the unconstrained empty-object schema). + """ + return [{"name": name, "parameters": {"type": "object"}} for name in callables] + + @property + def _custom_functions(self) -> dict[str, ToolExecutor]: + """ + Back-compat live view of the active backend's callables registry. + + Mutations on the returned dict (``target._custom_functions[name] = fn``, + ``target._custom_functions.pop(name)``) take effect immediately because + the dict object is shared with the underlying + :class:`pyrit.tools.LocalToolBackend`. Returns an empty dict when no + backend is installed or when the configured backend is not a + ``LocalToolBackend``. + + Returns: + dict[str, ToolExecutor]: The live callables dict. + """ + backend = self.configuration.tool_backend + if isinstance(backend, LocalToolBackend): + return cast("dict[str, ToolExecutor]", backend._callables) + return {} + + @_custom_functions.setter + def _custom_functions(self, value: dict[str, ToolExecutor]) -> None: + backend = self.configuration.tool_backend + if isinstance(backend, LocalToolBackend): + backend._callables = dict(value) + backend._schemas = self._derive_default_schemas(value) + return + new_backend = LocalToolBackend( + callables=dict(value), + schemas=self._derive_default_schemas(value), + fail_on_missing_function=self._fail_on_missing_function, + ) + self.configuration.tool_backend = new_backend + def _build_identifier(self) -> ComponentIdentifier: """ Build the identifier with OpenAI response-specific parameters. @@ -378,8 +483,9 @@ async def _construct_request_body( input_items = await self._build_input_for_multi_modal_async(conversation) text_format = self._build_text_format(json_config=json_config) + tool_schemas = self._tool_schemas() - body_parameters = { + body_parameters: dict[str, Any] = { "model": self._model_name, "max_output_tokens": self._max_output_tokens, "temperature": self._temperature, @@ -390,8 +496,11 @@ async def _construct_request_body( "text": text_format, "reasoning": self._build_reasoning_config(), } + if tool_schemas: + body_parameters["tools"] = tool_schemas if self._extra_body_parameters: + # User-supplied extra_body_parameters wins over backend-derived tools. body_parameters.update(self._extra_body_parameters) # Filter out None values @@ -559,11 +668,18 @@ async def _construct_message_from_response(self, response: Any, request: Message @pyrit_target_retry async def _send_prompt_to_target_async(self, *, normalized_conversation: list[Message]) -> list[Message]: """ - Send prompt, handle agentic tool calls (function_call), return all messages. + Send one prompt to the Responses API and return exactly one Message. + + The agentic tool-calling loop now lives in :func:`pyrit.tools.tool_loop` + on the base class. This method is the single-iteration body the loop + re-enters on each turn: build the request body, call the API, parse the + response, return the constructed :class:`Message` wrapped in a list of + length 1. - The Responses API supports structured outputs and tool execution. This method handles both: - - Simple text/reasoning responses - - Agentic tool-calling loops that may require multiple back-and-forth exchanges + The wrapper detects function_call pieces via :attr:`_tool_parser` and + decides whether to dispatch + re-enter. Reasoning, MCP, web-search, + computer-use, and other non-function-call sections pass through to + Memory unchanged because the parser ignores them. Args: normalized_conversation (list[Message]): The full conversation @@ -571,59 +687,54 @@ async def _send_prompt_to_target_async(self, *, normalized_conversation: list[Me pipeline. The current message is the last element. Returns: - List of messages generated during the interaction (assistant responses and tool messages). - The normalizer will persist all of these to memory. + list[Message]: Exactly one Message wrapping the parsed response. """ message = normalized_conversation[-1] - message_piece: MessagePiece = message.message_pieces[0] last_piece = message.message_pieces[-1] json_config = self._get_json_response_config(message_piece=last_piece) - working_conversation: MutableSequence[Message] = list(normalized_conversation) - - # Track all responses generated during this interaction - responses_to_return: list[Message] = [] - - # Main agentic loop - each back-and-forth creates a new message - tool_call_section: Optional[dict[str, Any]] = None - - while True: - logger.info(f"Sending conversation with {len(working_conversation)} messages to the prompt target") - - body = await self._construct_request_body(conversation=working_conversation, json_config=json_config) - - # Use unified error handling - automatically detects Response and validates - result = await self._handle_openai_request( - api_call=lambda body=body: self._client.responses.create(**body), - request=message, - ) - - # Add result to conversation and responses list - working_conversation.append(result) - responses_to_return.append(result) - - # Extract tool call if present - tool_call_section = self._find_last_pending_tool_call(result) - - # If no tool call, we're done - if not tool_call_section: - break - - # Execute the tool/function - tool_output = await self._execute_call_section(tool_call_section) + body = await self._construct_request_body(conversation=list(normalized_conversation), json_config=json_config) + logger.info("Sending conversation with %d messages to the Responses API", len(normalized_conversation)) + result = await self._handle_openai_request( + api_call=lambda body=body: self._client.responses.create(**body), + request=message, + ) + return [result] - # Create a new message with the tool output - tool_piece = self._make_tool_piece(tool_output, tool_call_section["call_id"], reference_piece=message_piece) - tool_message = Message(message_pieces=[tool_piece], skip_validation=True) + @property + def _tool_parser(self) -> ToolCallParser | None: + """ + Canonical-envelope parser shared with future canonical-envelope targets. + + Walks response message pieces and emits one :class:`~pyrit.tools.ToolCall` + per piece whose ``original_value_data_type`` is ``"function_call"``. + Reasoning, MCP, web-search, computer-use, and local-shell sections all + produce pieces of OTHER data types, so the parser returns an empty list + for them and the @tool_loop decorator exits cleanly. Those sections + still land in Memory via the parsed Message returned by + ``_send_prompt_to_target_async``; they're just not client-side + dispatched. + """ + return CanonicalEnvelopeParser() - # Add tool output message to conversation and responses list - working_conversation.append(tool_message) - responses_to_return.append(tool_message) + def _tool_schemas(self) -> list[dict[str, Any]]: + """ + Translate the configured backend's schemas into Responses-API tools shape. - # Continue loop to send tool result and get next response + The Responses API expects each function tool as a top-level + ``{"type": "function", "name": ..., "description": ..., + "parameters": ...}`` entry (NOT wrapped in an inner ``"function"`` key + the way Chat Completions does). The backend's schemas are already the + bare function schema, so we just stamp ``type=function`` on each. - # Return all responses (normalizer will persist all of them to memory) - return responses_to_return + Returns: + list[dict[str, Any]]: One descriptor per advertised tool, or an + empty list when no backend is configured. + """ + backend: ToolBackend | None = self.configuration.tool_backend + if backend is None: + return [] + return [{"type": "function", **schema} for schema in backend.schemas] def _parse_response_output_section( self, *, section: Any, message_piece: MessagePiece, error: Optional[PromptResponseError] diff --git a/pyrit/tools/__init__.py b/pyrit/tools/__init__.py index 46b11aa358..f2ae0090ce 100644 --- a/pyrit/tools/__init__.py +++ b/pyrit/tools/__init__.py @@ -53,9 +53,10 @@ RemoteMCPServerSpec, ) from pyrit.tools.models import ToolCall, ToolEventBehavior, ToolEventPolicy, tool_loop -from pyrit.tools.parsers import ToolCallParser +from pyrit.tools.parsers import CanonicalEnvelopeParser, ToolCallParser __all__ = [ + "CanonicalEnvelopeParser", "DockerMCPServerSpec", "LocalMCPServerSpec", "LocalToolBackend", diff --git a/pyrit/tools/parsers.py b/pyrit/tools/parsers.py index 4ff7fc4c04..c903eb73c7 100644 --- a/pyrit/tools/parsers.py +++ b/pyrit/tools/parsers.py @@ -3,6 +3,7 @@ from __future__ import annotations +import json from typing import TYPE_CHECKING, Protocol, runtime_checkable if TYPE_CHECKING: @@ -16,10 +17,11 @@ class ToolCallParser(Protocol): Protocol for extracting tool calls from a target response message. Concrete parsers live next to the target whose response shape they - understand (see :class:`OpenAIChatTarget` and :class:`OpenAIResponseTarget` - after C7/C8). Parsers MUST return an empty list when the model has - issued a stop response — the tool loop uses the empty list as the - signal to exit. + understand (the canonical-envelope parser shipped here is shared by + :class:`OpenAIResponseTarget`; per-model-family parsers for non-OpenAI + targets ship in a follow-up, see plan §12.9). Parsers MUST return an + empty list when the model has issued a stop response — the tool loop + uses the empty list as the signal to exit. """ def parse(self, message: Message) -> list[ToolCall]: @@ -41,9 +43,9 @@ def _extract_function_call_pieces(message: Message) -> list[MessagePiece]: Return every :class:`MessagePiece` in *message* whose ``original_value_data_type`` is ``"function_call"``. - This is the canonical envelope produced by OpenAI-style targets after - the C6 normalization commit. It is exposed here so concrete parsers - can reuse the filter rather than re-implementing it. + This is the canonical envelope used by every PyRIT-supported tool-emitting + target. It is exposed here so concrete parsers can reuse the filter rather + than re-implementing it. Args: message (Message): The message to scan. @@ -53,3 +55,53 @@ def _extract_function_call_pieces(message: Message) -> list[MessagePiece]: ``"function_call"``, in their declaration order. """ return [piece for piece in message.message_pieces if piece.original_value_data_type == "function_call"] + + +class CanonicalEnvelopeParser: + """ + Reference :class:`ToolCallParser` for the canonical function_call envelope. + + Walks every :class:`MessagePiece` whose ``original_value_data_type`` is + ``"function_call"`` and decodes the canonical JSON shape:: + + { + "type": "function_call", + "call_id": "", + "name": "", + "arguments": "" + } + + into :class:`ToolCall` instances. Pieces of other data types -- reasoning, + mcp_call, web_search_call, etc. -- are ignored (they pass through to + Memory but are not client-side dispatchable). Per-model-family parsers + for non-OpenAI targets ship in a follow-up PR (see plan §12.9). + """ + + def parse(self, message: Message) -> list[ToolCall]: + """ + Decode canonical ``function_call`` pieces in *message* into :class:`ToolCall`. + + Args: + message (Message): The most recent assistant response. + + Returns: + list[ToolCall]: One :class:`ToolCall` per ``function_call`` + piece, in declaration order. Empty if the message contains + no ``function_call`` pieces (model stop). + """ + from pyrit.tools.models import ToolCall + + calls: list[ToolCall] = [] + for piece in _extract_function_call_pieces(message): + envelope = json.loads(piece.original_value) + arguments_raw = envelope.get("arguments", "{}") + arguments = json.loads(arguments_raw) if isinstance(arguments_raw, str) else dict(arguments_raw) + calls.append( + ToolCall( + call_id=envelope["call_id"], + name=envelope["name"], + arguments=arguments, + raw_envelope=envelope, + ) + ) + return calls diff --git a/tests/unit/prompt_target/target/test_openai_response_target_c6_migration.py b/tests/unit/prompt_target/target/test_openai_response_target_c6_migration.py new file mode 100644 index 0000000000..04dfd0b024 --- /dev/null +++ b/tests/unit/prompt_target/target/test_openai_response_target_c6_migration.py @@ -0,0 +1,304 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +"""C6 additions to the Response target function-chaining suite. + +Covers the migration onto @tool_loop + LocalToolBackend. +""" + +from __future__ import annotations + +import json +import uuid +import warnings +from typing import Any +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from pyrit.models import Message, MessagePiece +from pyrit.prompt_target import OpenAIResponseTarget +from pyrit.prompt_target.common.target_capabilities import TargetCapabilities +from pyrit.prompt_target.common.target_configuration import TargetConfiguration +from pyrit.tools import LocalToolBackend, ToolEventBehavior, ToolEventPolicy + + +def _mock_function_call_response(call_id: str, function_name: str, arguments: dict) -> MagicMock: + """Build a fake Responses-API response containing a function_call section.""" + mock_response = MagicMock() + mock_response.status = "completed" + mock_response.error = None + section = MagicMock() + section.type = "function_call" + section.call_id = call_id + section.name = function_name + section.arguments = json.dumps(arguments) + section.model_dump.return_value = { + "type": "function_call", + "call_id": call_id, + "name": function_name, + "arguments": json.dumps(arguments), + } + mock_response.output = [section] + return mock_response + + +def _mock_text_response(text: str) -> MagicMock: + """Build a fake Responses-API response containing a message section.""" + mock_response = MagicMock() + mock_response.status = "completed" + mock_response.error = None + section = MagicMock() + section.type = "message" + section.content = [MagicMock(text=text)] + mock_response.output = [section] + return mock_response + + +def _user_msg(text: str, conversation_id: str | None = None) -> Message: + return Message( + message_pieces=[ + MessagePiece( + role="user", + original_value=text, + conversation_id=conversation_id or str(uuid.uuid4()), + ) + ] + ) + + +class TestCustomFunctionsDeprecation: + """custom_functions still works but emits DeprecationWarning.""" + + def test_custom_functions_kwarg_emits_deprecation_warning(self, patch_central_database): + async def get_weather(args: dict[str, Any]) -> dict[str, Any]: + return {"t": 72} + + with pytest.warns(DeprecationWarning, match="custom_functions"): + OpenAIResponseTarget( + model_name="gpt-4", + endpoint="https://mock.example.com", + api_key="mock-key", + custom_functions={"get_weather": get_weather}, + ) + + @pytest.mark.asyncio + async def test_custom_functions_kwarg_still_dispatches(self, patch_central_database): + async def get_weather(args: dict[str, Any]) -> dict[str, Any]: + return {"temperature": 72, "condition": "sunny"} + + with warnings.catch_warnings(): + warnings.simplefilter("ignore", DeprecationWarning) + target = OpenAIResponseTarget( + model_name="gpt-4", + endpoint="https://mock.example.com", + api_key="mock-key", + custom_functions={"get_weather": get_weather}, + ) + + responses = [ + _mock_function_call_response("call_1", "get_weather", {"location": "NYC"}), + _mock_text_response("72F and sunny."), + ] + seen = [] + + async def mock_create(**kwargs): + seen.append(kwargs) + return responses[len(seen) - 1] + + with patch.object(target._async_client.responses, "create", new_callable=AsyncMock) as mc: + mc.side_effect = mock_create + result = await target.send_prompt_async(message=_user_msg("weather?")) + + assert len(seen) == 2 + assert result[-1].message_pieces[0].original_value == "72F and sunny." + second_input = seen[1]["input"] + assert any(item.get("type") == "function_call_output" for item in second_input) + + +def _config_with_backend(backend: LocalToolBackend) -> TargetConfiguration: + """Build a TargetConfiguration wired for the modern tool-backend path.""" + caps = TargetCapabilities( + supports_multi_turn=True, + supports_multi_message_pieces=True, + supports_editable_history=True, + supports_json_output=True, + supports_system_prompt=True, + supports_tool_use=True, + input_modalities=frozenset( + { + frozenset(["text"]), + frozenset(["text", "image_path"]), + frozenset(["function_call"]), + frozenset(["tool_call"]), + frozenset(["function_call_output"]), + frozenset(["reasoning"]), + } + ), + ) + return TargetConfiguration( + capabilities=caps, + tool_event_policy=ToolEventPolicy(behavior=ToolEventBehavior.EXECUTE, max_tool_iterations=5), + tool_backend=backend, + ) + + +class TestToolBackendDispatch: + """The modern path: pass tool_backend via TargetConfiguration.""" + + @pytest.mark.asyncio + async def test_local_backend_dispatches_through_tool_loop(self, patch_central_database): + async def get_weather(args: dict[str, Any]) -> dict[str, Any]: + return {"temperature": 72, "condition": "sunny"} + + backend = LocalToolBackend( + callables={"get_weather": get_weather}, + schemas=[ + { + "name": "get_weather", + "description": "Weather lookup.", + "parameters": { + "type": "object", + "properties": {"location": {"type": "string"}}, + "required": ["location"], + }, + } + ], + ) + target = OpenAIResponseTarget( + model_name="gpt-4", + endpoint="https://mock.example.com", + api_key="mock-key", + custom_configuration=_config_with_backend(backend), + ) + + responses = [ + _mock_function_call_response("call_1", "get_weather", {"location": "NYC"}), + _mock_text_response("72F and sunny in NYC."), + ] + seen = [] + + async def mock_create(**kwargs): + seen.append(kwargs) + return responses[len(seen) - 1] + + with patch.object(target._async_client.responses, "create", new_callable=AsyncMock) as mc: + mc.side_effect = mock_create + result = await target.send_prompt_async(message=_user_msg("weather?")) + + assert len(seen) == 2 + assert result[-1].message_pieces[0].original_value == "72F and sunny in NYC." + second_input = seen[1]["input"] + assert any(item.get("type") == "function_call_output" for item in second_input) + + +class TestToolSchemasInjection: + """_construct_request_body injects backend schemas when present.""" + + @pytest.mark.asyncio + async def test_backend_schemas_injected_into_tools(self, patch_central_database): + async def get_weather(args: dict[str, Any]) -> dict[str, Any]: + return {"t": 1} + + backend = LocalToolBackend( + callables={"get_weather": get_weather}, + schemas=[{"name": "get_weather", "description": "x", "parameters": {"type": "object"}}], + ) + target = OpenAIResponseTarget( + model_name="gpt-4", + endpoint="https://mock.example.com", + api_key="mock-key", + custom_configuration=_config_with_backend(backend), + ) + body = await target._construct_request_body( + conversation=[_user_msg("hi")], + json_config=MagicMock(enabled=False, schema=None), + ) + assert "tools" in body + assert body["tools"][0]["type"] == "function" + assert body["tools"][0]["name"] == "get_weather" + + @pytest.mark.asyncio + async def test_extra_body_tools_take_precedence(self, patch_central_database): + async def f(args: dict[str, Any]) -> dict[str, Any]: + return {} + + backend = LocalToolBackend( + callables={"f": f}, + schemas=[{"name": "f", "parameters": {"type": "object"}}], + ) + legacy = [{"type": "function", "name": "legacy_tool", "description": "x"}] + config = _config_with_backend(backend) + target = OpenAIResponseTarget( + model_name="gpt-4", + endpoint="https://mock.example.com", + api_key="mock-key", + extra_body_parameters={"tools": legacy}, + custom_configuration=config, + ) + body = await target._construct_request_body( + conversation=[_user_msg("hi")], + json_config=MagicMock(enabled=False, schema=None), + ) + assert body["tools"] == legacy + + @pytest.mark.asyncio + async def test_no_backend_no_tools_key(self, patch_central_database): + target = OpenAIResponseTarget( + model_name="gpt-4", + endpoint="https://mock.example.com", + api_key="mock-key", + ) + body = await target._construct_request_body( + conversation=[_user_msg("hi")], + json_config=MagicMock(enabled=False, schema=None), + ) + assert "tools" not in body + + +class TestNonFunctionCallPiecesPassThrough: + """Reasoning / mcp_call / web_search_call sections must NOT be dispatched. + + The Response target's parser populates pieces for these types so they can + be persisted to Memory and round-tripped on subsequent requests. The + CanonicalEnvelopeParser only extracts function_call pieces; the tool loop + must therefore see an empty parse and exit cleanly. + """ + + @pytest.mark.asyncio + async def test_reasoning_only_response_exits_loop(self, patch_central_database): + target = OpenAIResponseTarget( + model_name="gpt-4", + endpoint="https://mock.example.com", + api_key="mock-key", + reasoning_effort="medium", + ) + # Reasoning section + final text section in one response + mock_response = MagicMock() + mock_response.status = "completed" + mock_response.error = None + reasoning_section = MagicMock() + reasoning_section.type = "reasoning" + reasoning_section.model_dump.return_value = {"type": "reasoning", "summary": "thinking..."} + text_section = MagicMock() + text_section.type = "message" + text_section.content = [MagicMock(text="The answer is 42.")] + mock_response.output = [reasoning_section, text_section] + + seen = [] + + async def mock_create(**kwargs): + seen.append(kwargs) + return mock_response + + with patch.object(target._async_client.responses, "create", new_callable=AsyncMock) as mc: + mc.side_effect = mock_create + result = await target.send_prompt_async(message=_user_msg("question?")) + + # Exactly one API call -- reasoning is not a tool call so the loop exits + assert len(seen) == 1 + # Response message contains both pieces + assert len(result) == 1 + piece_types = [p.original_value_data_type for p in result[0].message_pieces] + assert "reasoning" in piece_types + assert "text" in piece_types From 6d3a79a3eef77283a6dfdda4fee033768454fb17 Mon Sep 17 00:00:00 2001 From: Victor Valbuena Date: Thu, 28 May 2026 10:42:20 -0700 Subject: [PATCH 08/17] Add integration tests for RedTeamingAttack with real MCP tool dispatch C7 adds end-to-end integration coverage of the @tool_loop decorator, MCPToolBackend, and MCPClient stack against the real echo_mcp_server subprocess. Only the OpenAI Responses HTTP layer is mocked; the MCP stdio subprocess, AsyncExitStack lifecycle, canonical envelope round-trip, and RedTeamingAttack execution path all run unmocked. # What ships tests/integration/tools/test_red_teaming_with_tools.py with three tests: 1. test_red_teaming_response_target_with_mcp_echo - end-to-end smoke test. RedTeamingAttack drives OpenAIResponseTarget configured with a MCPToolBackend pointing at echo_mcp_server. The Responses API mock returns one function_call followed by a stop response. Asserts the tool call actually reaches the MCP subprocess and the result lands back in the second API call as a function_call_output. 2. test_red_teaming_persists_canonical_transcript_in_memory - verifies the canonical envelope contract (plan section 13). Reads the conversation back from Memory after attack.execute_async returns and asserts the function_call and function_call_output pieces are present, in order, with matching call_ids. 3. test_red_teaming_dispatches_all_tool_calls_per_turn - regression test for the intentional behavior change from C6. The pre-C6 in-class loop in OpenAIResponseTarget only dispatched the LAST function_call per turn; the @tool_loop decorator now dispatches every call in declaration order. Issues both echo and add in one response and asserts both results land in the next API call. # Test infrastructure - LocalMCPServerSpec uses command=sys.executable + args=(echo_server,). - Mock objective scorer returns a true score so RedTeamingAttack exits cleanly after one turn. - Mock adversarial target returns a single scripted prompt wrapped as list[Message] (PromptTarget.send_prompt_async contract). - Score, ComponentIdentifier, and PromptTarget MagicMock(spec=...) usage matches the existing tests/unit/executor/attack patterns. All three integration tests pass; pre-commit clean. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- tests/integration/tools/__init__.py | 2 + .../tools/test_red_teaming_with_tools.py | 361 ++++++++++++++++++ 2 files changed, 363 insertions(+) create mode 100644 tests/integration/tools/__init__.py create mode 100644 tests/integration/tools/test_red_teaming_with_tools.py diff --git a/tests/integration/tools/__init__.py b/tests/integration/tools/__init__.py new file mode 100644 index 0000000000..9a0454564d --- /dev/null +++ b/tests/integration/tools/__init__.py @@ -0,0 +1,2 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. diff --git a/tests/integration/tools/test_red_teaming_with_tools.py b/tests/integration/tools/test_red_teaming_with_tools.py new file mode 100644 index 0000000000..9dca01371a --- /dev/null +++ b/tests/integration/tools/test_red_teaming_with_tools.py @@ -0,0 +1,361 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +"""C7 integration tests: RedTeamingAttack with real tool dispatch. + +These tests spawn the real ``tests/unit/tools/echo_mcp_server.py`` subprocess +and exercise the full client-side tool-calling stack: + + attack -> normalizer -> target -> @tool_loop wrapper -> MCPToolBackend -> + MCPClient (stdio) -> echo subprocess -> tool result -> back through the + wrapper -> Memory. + +Only the OpenAI Responses HTTP layer is mocked. The MCP subprocess, the +MCPToolBackend lock, the AsyncExitStack lifecycle, the canonical envelope +round-trip, and the @tool_loop decorator's RedTeam-attack invocation path +all execute under their real implementations. +""" + +from __future__ import annotations + +import json +import pathlib +import sys +import uuid +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from pyrit.executor.attack.core.attack_config import AttackAdversarialConfig, AttackScoringConfig +from pyrit.executor.attack.multi_turn.red_teaming import RedTeamingAttack +from pyrit.identifiers import ComponentIdentifier +from pyrit.memory import CentralMemory +from pyrit.models import Message, MessagePiece, Score +from pyrit.prompt_target import OpenAIResponseTarget +from pyrit.prompt_target.common.prompt_target import PromptTarget +from pyrit.prompt_target.common.target_capabilities import TargetCapabilities +from pyrit.prompt_target.common.target_configuration import TargetConfiguration +from pyrit.score.true_false.true_false_scorer import TrueFalseScorer +from pyrit.tools import ( + LocalMCPServerSpec, + MCPToolBackend, + ToolEventBehavior, + ToolEventPolicy, +) + + +def _mock_id(name: str) -> ComponentIdentifier: + return ComponentIdentifier(class_name=name, class_module="test") + + +ECHO_SERVER_PATH = pathlib.Path(__file__).resolve().parents[2] / "unit" / "tools" / "echo_mcp_server.py" + + +def _local_echo_spec() -> LocalMCPServerSpec: + """Build a LocalMCPServerSpec that launches the in-tree echo server.""" + return LocalMCPServerSpec( + command=sys.executable, + args=(str(ECHO_SERVER_PATH),), + ) + + +def _mock_function_call_response(call_id: str, function_name: str, arguments: dict) -> MagicMock: + """Build a fake Responses-API response containing a function_call section.""" + response = MagicMock() + response.status = "completed" + response.error = None + section = MagicMock() + section.type = "function_call" + section.call_id = call_id + section.name = function_name + section.arguments = json.dumps(arguments) + section.model_dump.return_value = { + "type": "function_call", + "call_id": call_id, + "name": function_name, + "arguments": json.dumps(arguments), + } + response.output = [section] + return response + + +def _mock_text_response(text: str) -> MagicMock: + """Build a fake Responses-API response containing a message section.""" + response = MagicMock() + response.status = "completed" + response.error = None + section = MagicMock() + section.type = "message" + section.content = [MagicMock(text=text)] + response.output = [section] + return response + + +def _make_response_target_with_mcp_backend( + backend: MCPToolBackend, +) -> OpenAIResponseTarget: + """Build an OpenAIResponseTarget wired to the live MCP backend.""" + caps = TargetCapabilities( + supports_multi_turn=True, + supports_editable_history=True, + supports_json_output=True, + supports_multi_message_pieces=True, + supports_system_prompt=True, + supports_tool_use=True, + input_modalities=frozenset( + { + frozenset(["text"]), + frozenset(["text", "image_path"]), + frozenset(["function_call"]), + frozenset(["tool_call"]), + frozenset(["function_call_output"]), + frozenset(["reasoning"]), + } + ), + ) + config = TargetConfiguration( + capabilities=caps, + tool_event_policy=ToolEventPolicy(behavior=ToolEventBehavior.EXECUTE, max_tool_iterations=5), + tool_backend=backend, + ) + return OpenAIResponseTarget( + model_name="gpt-4", + endpoint="https://mock.example.com", + api_key="mock-key", + custom_configuration=config, + ) + + +def _scripted_adversarial(prompts: list[str]) -> MagicMock: + """Build a mock adversarial target that returns scripted prompts.""" + adversarial = MagicMock(spec=PromptTarget) + adversarial.send_prompt_async = AsyncMock( + side_effect=[ + [ + Message( + message_pieces=[ + MessagePiece( + role="assistant", + original_value=p, + original_value_data_type="text", + conversation_id=str(uuid.uuid4()), + ) + ] + ) + ] + for p in prompts + ] + ) + adversarial.get_identifier.return_value = _mock_id("MockAdversarial") + adversarial.set_system_prompt = MagicMock() + return adversarial + + +def _success_scorer() -> MagicMock: + """Mock objective scorer that always returns True (objective met).""" + scorer = MagicMock(spec=TrueFalseScorer) + scorer.score_async = AsyncMock( + return_value=[ + Score( + score_value="true", + score_value_description="objective met", + score_type="true_false", + score_category=["test"], + score_rationale="mock rationale", + score_metadata={}, + message_piece_id=str(uuid.uuid4()), + scorer_class_identifier=_mock_id("MockScorer"), + ) + ] + ) + scorer.get_identifier.return_value = _mock_id("MockScorer") + return scorer + + +@pytest.mark.asyncio +async def test_red_teaming_response_target_with_mcp_echo(patch_central_database): + """End-to-end: RedTeamingAttack drives OpenAIResponseTarget with MCPToolBackend. + + The Response target's HTTP layer is mocked to return a function_call for + the echo tool, followed by a stop response after the tool result arrives. + The MCP subprocess actually executes the echo call. + """ + backend = MCPToolBackend(servers=[_local_echo_spec()]) + async with backend: + objective_target = _make_response_target_with_mcp_backend(backend) + + # Mock the OpenAI Responses HTTP layer on the objective target. + responses = [ + _mock_function_call_response("call_1", "echo", {"text": "hello"}), + _mock_text_response("Echoed: hello"), + ] + seen = [] + + async def mock_create(**kwargs): + seen.append(kwargs) + return responses[len(seen) - 1] + + # Adversarial returns one prompt (RedTeamingAttack stops after objective is met) + adversarial = _scripted_adversarial(["please echo hello"]) + + attack = RedTeamingAttack( + objective_target=objective_target, + attack_adversarial_config=AttackAdversarialConfig(target=adversarial), + attack_scoring_config=AttackScoringConfig(objective_scorer=_success_scorer()), + ) + + with patch.object( + objective_target._async_client.responses, "create", new_callable=AsyncMock + ) as mock_create_call: + mock_create_call.side_effect = mock_create + result = await attack.execute_async(objective="get the model to echo 'hello'") + + # Two HTTP calls to the Response API: initial + post-tool + assert len(seen) == 2 + # Second call must include the function_call_output (tool result) + second_input = seen[1]["input"] + function_outputs = [item for item in second_input if item.get("type") == "function_call_output"] + assert len(function_outputs) == 1 + # The output JSON contains the text "hello" because the real MCP echo + # subprocess returned it + assert "hello" in function_outputs[0]["output"] + assert result is not None + + +@pytest.mark.asyncio +async def test_red_teaming_persists_canonical_transcript_in_memory(patch_central_database): + """End-to-end: after a successful tool dispatch the DB shows the full chain. + + Verifies the canonical envelope contract (§13): the conversation written + to Memory must contain the user message, the assistant function_call, the + tool function_call_output (with matching call_id), and the assistant's + final text -- in that order. + """ + backend = MCPToolBackend(servers=[_local_echo_spec()]) + async with backend: + objective_target = _make_response_target_with_mcp_backend(backend) + + responses = [ + _mock_function_call_response("call_xyz", "echo", {"text": "world"}), + _mock_text_response("Echoed: world"), + ] + seen = [] + + async def mock_create(**kwargs): + seen.append(kwargs) + return responses[len(seen) - 1] + + adversarial = _scripted_adversarial(["echo world"]) + + attack = RedTeamingAttack( + objective_target=objective_target, + attack_adversarial_config=AttackAdversarialConfig(target=adversarial), + attack_scoring_config=AttackScoringConfig(objective_scorer=_success_scorer()), + ) + + with patch.object( + objective_target._async_client.responses, "create", new_callable=AsyncMock + ) as mock_create_call: + mock_create_call.side_effect = mock_create + result = await attack.execute_async(objective="echo world") + + # Read the conversation back from Memory + memory = CentralMemory.get_memory_instance() + assert result is not None + objective_conv_id = result.conversation_id + assert objective_conv_id, "Attack result must carry the objective-target conversation id" + + pieces = list(memory.get_message_pieces(conversation_id=objective_conv_id)) + # Filter out system prompts; we care about the user/assistant/tool chain + data_types_in_order = [p.original_value_data_type for p in pieces] + # The chain MUST contain function_call followed by function_call_output (canonical envelope) + assert "function_call" in data_types_in_order + assert "function_call_output" in data_types_in_order + + fc_index = data_types_in_order.index("function_call") + fco_index = data_types_in_order.index("function_call_output") + assert fc_index < fco_index, "function_call must precede function_call_output in DB" + + fc_envelope = json.loads(pieces[fc_index].original_value) + fco_envelope = json.loads(pieces[fco_index].original_value) + assert fc_envelope["call_id"] == fco_envelope["call_id"] == "call_xyz" + assert fc_envelope["name"] == "echo" + # The tool result envelope's `output` is JSON-encoded; the underlying echo result is "world" + assert "world" in fco_envelope["output"] + + +@pytest.mark.asyncio +async def test_red_teaming_dispatches_all_tool_calls_per_turn(patch_central_database): + """Multi-call-per-turn dispatch (intentional behavior change vs pre-C6 loop). + + When the model emits two function_call sections in one response, BOTH + must dispatch through the MCPToolBackend. The pre-C6 in-class loop in + OpenAIResponseTarget only dispatched the LAST call per turn; the C6 + migration onto @tool_loop changes this to "dispatch every call in + declaration order." Verify by issuing both an `echo` and an `add` call + and asserting both results land in the second API call's input. + """ + backend = MCPToolBackend(servers=[_local_echo_spec()]) + async with backend: + objective_target = _make_response_target_with_mcp_backend(backend) + + # First response contains TWO function_calls; second is the stop text. + multi_call_response = MagicMock() + multi_call_response.status = "completed" + multi_call_response.error = None + + echo_section = MagicMock() + echo_section.type = "function_call" + echo_section.call_id = "call_echo" + echo_section.name = "echo" + echo_section.arguments = json.dumps({"text": "hi"}) + echo_section.model_dump.return_value = { + "type": "function_call", + "call_id": "call_echo", + "name": "echo", + "arguments": json.dumps({"text": "hi"}), + } + add_section = MagicMock() + add_section.type = "function_call" + add_section.call_id = "call_add" + add_section.name = "add" + add_section.arguments = json.dumps({"a": 3, "b": 4}) + add_section.model_dump.return_value = { + "type": "function_call", + "call_id": "call_add", + "name": "add", + "arguments": json.dumps({"a": 3, "b": 4}), + } + multi_call_response.output = [echo_section, add_section] + + responses = [ + multi_call_response, + _mock_text_response("done"), + ] + seen = [] + + async def mock_create(**kwargs): + seen.append(kwargs) + return responses[len(seen) - 1] + + adversarial = _scripted_adversarial(["call echo and add"]) + + attack = RedTeamingAttack( + objective_target=objective_target, + attack_adversarial_config=AttackAdversarialConfig(target=adversarial), + attack_scoring_config=AttackScoringConfig(objective_scorer=_success_scorer()), + ) + + with patch.object(objective_target._async_client.responses, "create", new_callable=AsyncMock) as mc: + mc.side_effect = mock_create + await attack.execute_async(objective="dispatch both tools") + + assert len(seen) == 2 + second_input = seen[1]["input"] + outputs = [item for item in second_input if item.get("type") == "function_call_output"] + assert len(outputs) == 2, "Both tool calls must be dispatched per the new behavior" + call_ids = [o["call_id"] for o in outputs] + assert call_ids == ["call_echo", "call_add"], "Outputs must preserve declaration order" + # Real MCP subprocess: echo("hi") returned "hi", add(3, 4) returned 7 + assert "hi" in outputs[0]["output"] + assert "7" in outputs[1]["output"] From 1fc70b32a4c74397d3468cdd272b87512a387a71 Mon Sep 17 00:00:00 2001 From: Victor Valbuena Date: Thu, 28 May 2026 12:41:10 -0700 Subject: [PATCH 09/17] Add InlineToolCallParser for chat-template-based open models Adds a parser that walks text MessagePieces for marker-delimited JSON blocks of the form {"name": ..., "arguments": {...}} and emits canonical ToolCall instances. Marker pattern, call_id prefix, and surrounding-text policy (truncate / extract-all / strict) are all constructor-controlled so a single class covers angle-bracket, pipe-delimited tag pair, and other chat-template syntaxes. The parser is the F1 (per plan) piece that lets non-Responses-API targets participate in PyRIT's @tool_loop without a per-vendor parser implementation. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- pyrit/tools/__init__.py | 3 + pyrit/tools/inline_parser.py | 202 +++++++++++++++++++++ tests/unit/tools/test_inline_parser.py | 234 +++++++++++++++++++++++++ 3 files changed, 439 insertions(+) create mode 100644 pyrit/tools/inline_parser.py create mode 100644 tests/unit/tools/test_inline_parser.py diff --git a/pyrit/tools/__init__.py b/pyrit/tools/__init__.py index f2ae0090ce..1bc7300ce7 100644 --- a/pyrit/tools/__init__.py +++ b/pyrit/tools/__init__.py @@ -43,6 +43,7 @@ """ from pyrit.tools.backend import ToolBackend +from pyrit.tools.inline_parser import InlineToolCallParser, InlineToolCallParserMode from pyrit.tools.local_backend import LocalToolBackend from pyrit.tools.mcp_backend import MCPToolBackend from pyrit.tools.mcp_client import ( @@ -58,6 +59,8 @@ __all__ = [ "CanonicalEnvelopeParser", "DockerMCPServerSpec", + "InlineToolCallParser", + "InlineToolCallParserMode", "LocalMCPServerSpec", "LocalToolBackend", "MCPClient", diff --git a/pyrit/tools/inline_parser.py b/pyrit/tools/inline_parser.py new file mode 100644 index 0000000000..c7cc353fe4 --- /dev/null +++ b/pyrit/tools/inline_parser.py @@ -0,0 +1,202 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +"""Inline tool-call parser for open chat-tuned models.""" + +from __future__ import annotations + +import enum +import json +import logging +import re +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from pyrit.models import Message + from pyrit.tools.models import ToolCall + +logger = logging.getLogger(__name__) + + +class InlineToolCallParserMode(enum.Enum): + """Policy for handling text that surrounds inline tool-call markers.""" + + TRUNCATE_AT_LAST = "truncate_at_last" + TRUNCATE_AT_FIRST = "truncate_at_first" + EXTRACT_ALL = "extract_all" + STRICT_TRAILING_EMPTY = "strict_trailing_empty" + + +class InlineToolCallParser: + """ + Extract canonical ``ToolCall`` instances from marker-delimited JSON blocks. + + Open chat-tuned models that do not expose a structured ``tool_calls`` + channel typically emit tool calls as inline text wrapped in a + chat-template-specific marker (an angle-bracket pair, a pipe-delimited + tag pair, a square-bracketed list payload, and so on). This parser + walks every ``MessagePiece`` whose ``original_value_data_type`` is + ``"text"`` and runs ``marker_pattern`` against the piece's + ``original_value``. Each match capture group is decoded as JSON of the + form ``{"name": ..., "arguments": {...}}``. Synthetic ``call_id`` + values are minted positionally because inline-marker formats do not + issue provider IDs. + + The ``mode`` parameter controls how text surrounding the markers is + treated -- see ``InlineToolCallParserMode``. The default + ``TRUNCATE_AT_LAST`` honors every marker but discards anything after + the last one so hallucinated "tool results" that the model dreams up + after the call are not persisted as if they were real outputs. + + Args: + marker_pattern (str): Regex with exactly one capture group returning + the JSON payload. Default targets the angle-bracket + ``...`` syntax used by many tool-trained + ChatML-style chat templates. + call_id_prefix (str): Prefix for synthetic ``call_id`` values. + mode (InlineToolCallParserMode): Surrounding-text policy. + """ + + def __init__( + self, + *, + marker_pattern: str = r"(.*?)", + call_id_prefix: str = "call", + mode: InlineToolCallParserMode = InlineToolCallParserMode.TRUNCATE_AT_LAST, + ) -> None: + """ + Build an ``InlineToolCallParser``. + + Args: + marker_pattern (str): See class docstring. + call_id_prefix (str): See class docstring. + mode (InlineToolCallParserMode): See class docstring. + """ + self._pattern = re.compile(marker_pattern, re.DOTALL) + self._call_id_prefix = call_id_prefix + self._mode = mode + + @property + def mode(self) -> InlineToolCallParserMode: + """The active surrounding-text policy.""" + return self._mode + + def parse(self, message: Message) -> list[ToolCall]: + """ + Extract tool calls from every text piece in ``message``. + + Args: + message (Message): The most recent assistant response. + + Returns: + list[ToolCall]: One ``ToolCall`` per valid marker match, in + declaration order across pieces. Empty when no markers + are found. + + Raises: + ValueError: When ``mode`` is ``STRICT_TRAILING_EMPTY`` and any + non-whitespace text follows the last marker in any piece. + """ + calls: list[ToolCall] = [] + next_id = 0 + for piece in message.message_pieces: + if piece.original_value_data_type != "text": + continue + matches = self._match_piece(text=piece.original_value) + if not matches: + continue + for match in matches: + call = self._build_call(match=match, next_id=next_id) + if call is None: + continue + calls.append(call) + next_id += 1 + + return calls + + def _match_piece(self, *, text: str) -> list[re.Match[str]]: + """ + Apply the mode-specific filter to all marker matches in ``text``. + + Args: + text (str): Piece text to scan for markers. + + Returns: + list[re.Match[str]]: Matches to honor, in declaration order. + + Raises: + ValueError: When ``mode`` is ``STRICT_TRAILING_EMPTY`` and any + non-whitespace text follows the last marker. + """ + matches = list(self._pattern.finditer(text)) + if not matches: + return matches + + if self._mode is InlineToolCallParserMode.TRUNCATE_AT_FIRST: + return matches[:1] + if self._mode is InlineToolCallParserMode.STRICT_TRAILING_EMPTY: + trailing = text[matches[-1].end() :] + if trailing.strip(): + raise ValueError( + "Non-whitespace text follows the last tool-call marker; " + "InlineToolCallParserMode.STRICT_TRAILING_EMPTY rejects this. " + f"Trailing: {trailing!r}" + ) + return matches + + def _build_call(self, *, match: re.Match[str], next_id: int) -> ToolCall | None: + """ + Decode a single marker match into a ``ToolCall``. + + Args: + match (re.Match[str]): A single marker match whose group 1 is the + JSON payload. + next_id (int): Positional id used to form ``call_id``. + + Returns: + ToolCall | None: ``None`` when the payload is malformed or + missing the ``name`` field. The caller is expected to log + and skip. + """ + from pyrit.tools.models import ToolCall + + payload = match.group(1).strip() + try: + parsed = json.loads(payload) + except json.JSONDecodeError: + logger.warning("Skipping malformed tool-call payload: %r", payload[:120]) + return None + if not isinstance(parsed, dict) or "name" not in parsed: + logger.warning("Skipping tool-call payload without 'name' field: %r", payload[:120]) + return None + arguments = self._coerce_arguments(parsed.get("arguments", {})) + return ToolCall( + call_id=f"{self._call_id_prefix}_{next_id}", + name=parsed["name"], + arguments=arguments, + raw_envelope=parsed, + ) + + @staticmethod + def _coerce_arguments(raw: object) -> dict: + """ + Coerce the ``arguments`` field into a dict regardless of source shape. + + Args: + raw (object): The value of the payload's ``arguments`` field. + Either a dict, a JSON-encoded string, or something else. + + Returns: + dict: The decoded arguments dict, or an empty dict on any + shape PyRIT cannot interpret. Empty-dict fallback preserves + the loop's behavior of "always continue with a real call". + """ + if isinstance(raw, dict): + return dict(raw) + if isinstance(raw, str): + try: + decoded = json.loads(raw) + except json.JSONDecodeError: + return {} + return dict(decoded) if isinstance(decoded, dict) else {} + return {} diff --git a/tests/unit/tools/test_inline_parser.py b/tests/unit/tools/test_inline_parser.py new file mode 100644 index 0000000000..2f2850e667 --- /dev/null +++ b/tests/unit/tools/test_inline_parser.py @@ -0,0 +1,234 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +"""Unit tests for InlineToolCallParser across marker syntaxes.""" + +from __future__ import annotations + +import logging +import uuid + +import pytest + +from pyrit.models import Message, MessagePiece +from pyrit.tools import InlineToolCallParser, InlineToolCallParserMode + + +def _assistant_text(text: str) -> Message: + return Message( + message_pieces=[ + MessagePiece( + role="assistant", + original_value=text, + original_value_data_type="text", + conversation_id=str(uuid.uuid4()), + ) + ], + skip_validation=True, + ) + + +# Marker patterns commonly seen in tool-trained open chat templates. +# Named after their syntactic shape, not the model family that uses them, +# so that PyRIT does not advertise a "supported vendors" list. +ANGLE_BRACKET_PATTERN = r"(.*?)" +PIPE_PYTHON_TAG_PATTERN = r"<\|python_tag\|>(.*?)<\|eom_id\|>" +SQUARE_BRACKET_LIST_PATTERN = r"\[TOOL_CALLS\]\s*(\[.*?\])" + + +# --------------------------------------------------------------------------- +# Marker-pattern coverage: angle-bracket / pipe-python-tag / square-bracket-list +# --------------------------------------------------------------------------- + + +class TestAngleBracketMarker: + """Default pattern: ``...``.""" + + def test_single_marker_extracts_call(self): + parser = InlineToolCallParser() + text = ( + 'Let me check the weather. {"name": "get_weather", "arguments": {"location": "NYC"}}' + ) + calls = parser.parse(_assistant_text(text)) + assert len(calls) == 1 + assert calls[0].name == "get_weather" + assert calls[0].arguments == {"location": "NYC"} + assert calls[0].call_id == "call_0" + + def test_multiple_markers_extract_all_with_synthetic_ids(self): + parser = InlineToolCallParser(mode=InlineToolCallParserMode.EXTRACT_ALL) + text = ( + '{"name": "f1", "arguments": {}}' + "some interleaving text " + '{"name": "f2", "arguments": {"k": 1}}' + ) + calls = parser.parse(_assistant_text(text)) + assert [c.name for c in calls] == ["f1", "f2"] + assert [c.call_id for c in calls] == ["call_0", "call_1"] + + def test_arguments_as_json_string_is_decoded(self): + parser = InlineToolCallParser() + # arguments doubly-encoded as a JSON string (canonical envelope shape). + text = '{"name": "f", "arguments": "{\\"a\\": 1}"}' + calls = parser.parse(_assistant_text(text)) + assert len(calls) == 1 + assert calls[0].arguments == {"a": 1} + + +class TestPipePythonTagMarker: + """Pipe-delimited tag pair: ``<|python_tag|>...<|eom_id|>``.""" + + def test_single_marker_extracts_call(self): + parser = InlineToolCallParser(marker_pattern=PIPE_PYTHON_TAG_PATTERN) + text = ( + 'Sure, calling now. <|python_tag|>{"name": "get_weather", "arguments": {"location": "Seattle"}}<|eom_id|>' + ) + calls = parser.parse(_assistant_text(text)) + assert len(calls) == 1 + assert calls[0].name == "get_weather" + assert calls[0].arguments == {"location": "Seattle"} + + def test_multi_marker_extract_all(self): + parser = InlineToolCallParser( + marker_pattern=PIPE_PYTHON_TAG_PATTERN, + mode=InlineToolCallParserMode.EXTRACT_ALL, + ) + text = ( + '<|python_tag|>{"name": "a", "arguments": {}}<|eom_id|>' + "between " + '<|python_tag|>{"name": "b", "arguments": {}}<|eom_id|>' + ) + calls = parser.parse(_assistant_text(text)) + assert [c.name for c in calls] == ["a", "b"] + + +class TestSquareBracketListMarker: + """Square-bracketed list payload: ``[TOOL_CALLS] [...]``.""" + + def test_list_payload_is_skipped_with_warning(self, caplog): + """The default parser expects a single-dict payload. + + Marker syntaxes whose payload is a JSON LIST of dicts (rather than a + single dict) are logged and dropped. Callers that need list-shaped + payloads should either subclass ``InlineToolCallParser`` and override + ``parse`` to iterate the list, or use a marker pattern that targets + each dict inside the list separately. Regex does not handle nested + braces well; subclassing is cleaner. + """ + parser = InlineToolCallParser(marker_pattern=SQUARE_BRACKET_LIST_PATTERN) + text = '[TOOL_CALLS] [{"name": "f", "arguments": {"x": 1}}]' + with caplog.at_level(logging.WARNING, logger="pyrit.tools.inline_parser"): + calls = parser.parse(_assistant_text(text)) + assert calls == [] + assert any("without 'name' field" in rec.message for rec in caplog.records) + + +# --------------------------------------------------------------------------- +# Mode coverage: TRUNCATE_AT_LAST / TRUNCATE_AT_FIRST / EXTRACT_ALL / +# STRICT_TRAILING_EMPTY +# --------------------------------------------------------------------------- + + +class TestParserModes: + """Surrounding-text policy coverage.""" + + HALLUCINATED = ( + '{"name": "get_weather", "arguments": {"location": "NYC"}}' + " The weather in NYC is sunny and 72 degrees." + ) + DOUBLE_CALL = ( + '{"name": "a", "arguments": {}}' + " middle " + '{"name": "b", "arguments": {}}' + " trailing chatter" + ) + + def test_truncate_at_last_default_drops_trailing_chatter(self): + parser = InlineToolCallParser() # default mode + calls = parser.parse(_assistant_text(self.HALLUCINATED)) + # The hallucinated weather report after the marker is discarded; the + # call itself is honored. + assert len(calls) == 1 + assert calls[0].name == "get_weather" + + def test_truncate_at_last_extracts_all_markers_then_drops_tail(self): + parser = InlineToolCallParser(mode=InlineToolCallParserMode.TRUNCATE_AT_LAST) + calls = parser.parse(_assistant_text(self.DOUBLE_CALL)) + # Both markers honored, trailing "trailing chatter" silently dropped. + assert [c.name for c in calls] == ["a", "b"] + + def test_truncate_at_first_keeps_only_the_first_marker(self): + parser = InlineToolCallParser(mode=InlineToolCallParserMode.TRUNCATE_AT_FIRST) + calls = parser.parse(_assistant_text(self.DOUBLE_CALL)) + assert [c.name for c in calls] == ["a"] + + def test_extract_all_keeps_every_marker(self): + parser = InlineToolCallParser(mode=InlineToolCallParserMode.EXTRACT_ALL) + calls = parser.parse(_assistant_text(self.DOUBLE_CALL)) + assert [c.name for c in calls] == ["a", "b"] + + def test_strict_trailing_empty_raises_on_chatter(self): + parser = InlineToolCallParser(mode=InlineToolCallParserMode.STRICT_TRAILING_EMPTY) + with pytest.raises(ValueError, match="STRICT_TRAILING_EMPTY"): + parser.parse(_assistant_text(self.HALLUCINATED)) + + def test_strict_trailing_empty_passes_when_only_whitespace_after(self): + parser = InlineToolCallParser(mode=InlineToolCallParserMode.STRICT_TRAILING_EMPTY) + text = '{"name": "f", "arguments": {}}\n \t ' + calls = parser.parse(_assistant_text(text)) + assert [c.name for c in calls] == ["f"] + + +# --------------------------------------------------------------------------- +# Edge cases +# --------------------------------------------------------------------------- + + +class TestEdgeCases: + """Empty input, malformed JSON, missing name, multi-piece messages.""" + + def test_no_markers_returns_empty(self): + parser = InlineToolCallParser() + calls = parser.parse(_assistant_text("just plain assistant text")) + assert calls == [] + + def test_malformed_json_is_skipped_silently(self): + parser = InlineToolCallParser() + text = "not valid json" + calls = parser.parse(_assistant_text(text)) + assert calls == [] + + def test_payload_without_name_is_skipped(self): + parser = InlineToolCallParser() + text = '{"arguments": {}}' + calls = parser.parse(_assistant_text(text)) + assert calls == [] + + def test_non_text_pieces_are_ignored(self): + """Pieces with data_type other than 'text' are skipped entirely.""" + parser = InlineToolCallParser() + msg = Message( + message_pieces=[ + MessagePiece( + role="assistant", + original_value='{"name": "ignored", "arguments": {}}', + original_value_data_type="reasoning", + conversation_id=str(uuid.uuid4()), + ), + MessagePiece( + role="assistant", + original_value='{"name": "found", "arguments": {}}', + original_value_data_type="text", + conversation_id=str(uuid.uuid4()), + ), + ], + skip_validation=True, + ) + calls = parser.parse(msg) + assert [c.name for c in calls] == ["found"] + + def test_call_id_prefix_customization(self): + parser = InlineToolCallParser(call_id_prefix="custom") + text = '{"name": "f", "arguments": {}}' + calls = parser.parse(_assistant_text(text)) + assert calls[0].call_id == "custom_0" From 85b59bc8b1209564e086d07234625af6a0ae53e0 Mon Sep 17 00:00:00 2001 From: Victor Valbuena Date: Thu, 28 May 2026 12:47:19 -0700 Subject: [PATCH 10/17] Include tool_event_policy and tool_backend in target identifier params TargetConfiguration.as_identifier_params() now snapshots the configured tool_event_policy (behavior + max_tool_iterations) and tool_backend (backend class + sorted list of advertised tool names). Two targets that differ only in their tool backend now get distinct identifiers, which downstream consumers rely on to route by target identity. Schema serialization is best-effort: backends with shape-quirky schemas that lack a recoverable 'name' field are silently dropped from the identifier surface. Exact callables and transports are not serialized because they are not deterministic. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .../common/target_configuration.py | 88 ++++++++++++++++--- .../target/test_prompt_target.py | 67 +++++++++++++- 2 files changed, 142 insertions(+), 13 deletions(-) diff --git a/pyrit/prompt_target/common/target_configuration.py b/pyrit/prompt_target/common/target_configuration.py index 6058409194..e9d401b824 100644 --- a/pyrit/prompt_target/common/target_configuration.py +++ b/pyrit/prompt_target/common/target_configuration.py @@ -208,14 +208,20 @@ def as_identifier_params(self) -> dict[str, Any]: suitable for inclusion in a ``ComponentIdentifier``. The returned dict preserves the structure of ``TargetConfiguration`` - — capabilities, policy, and pipeline are kept as nested sub-dicts rather - than flattened into the caller — so the identifier reflects the shape of - the object it describes. + — capabilities, policy, pipeline, tool-event policy, and tool backend + are kept as nested sub-dicts rather than flattened into the caller — + so the identifier reflects the shape of the object it describes. Two configurations that behave identically must produce equal dicts; configurations that differ in any identity-bearing field must produce - unequal dicts. Modality sets are flattened to sorted lists of sorted - lists so ordering is stable across runs. + unequal dicts. The tool-backend snapshot uses the backend class plus + the sorted list of advertised tool names; this means two backends of + the same type exposing the same tool set are treated as equivalent + for identifier purposes (their exact callables / transports are not + deterministically serializable). + + Modality sets are flattened to sorted lists of sorted lists so + ordering is stable across runs. Returns: dict[str, Any]: The identifier parameters for this configuration. @@ -223,24 +229,82 @@ def as_identifier_params(self) -> dict[str, Any]: caps = self._capabilities return { "capabilities": self._capabilities_to_identifier_params(caps), - # Only unsupported capabilities appear here. Policy entries for - # natively-supported capabilities are moot (the behavior never - # fires), and omitting them keeps identifiers stable when default - # policies expand to cover more capabilities. "capability_policy": { capability.value: behavior.value for capability, behavior in self._policy.behaviors.items() if not caps.includes(capability=capability) }, - # Stable, ordered representation of the resolved normalization - # pipeline. Captures the effect of ``normalizer_overrides`` since - # the pipeline is built from defaults + overrides. "normalization_pipeline": [ f"{type(normalizer).__module__}.{type(normalizer).__qualname__}" for normalizer in self._pipeline.normalizers ], + "tool_event_policy": self._tool_event_policy_to_identifier_params(), + "tool_backend": self._tool_backend_to_identifier_params(), + } + + def _tool_event_policy_to_identifier_params(self) -> dict[str, Any] | None: + """ + Snapshot the active tool-event policy as identifier params. + + Returns: + dict[str, Any] | None: ``None`` when no policy is configured; + otherwise ``behavior`` and ``max_tool_iterations`` as plain + values. + """ + if self._tool_event_policy is None: + return None + return { + "behavior": self._tool_event_policy.behavior.value, + "max_tool_iterations": self._tool_event_policy.max_tool_iterations, + } + + def _tool_backend_to_identifier_params(self) -> dict[str, Any] | None: + """ + Snapshot the active tool backend as identifier params. + + Returns the backend's fully-qualified class name plus the sorted + list of tool names it advertises. Exact callables / transports are + not serialized; two backends of the same type exposing the same + tool set therefore produce equal identifier dicts. + + Returns: + dict[str, Any] | None: ``None`` when no backend is configured; + otherwise ``type`` (fully-qualified class name) and + ``tools`` (sorted list of advertised tool names). + """ + if self._tool_backend is None: + return None + backend_type = type(self._tool_backend) + return { + "type": f"{backend_type.__module__}.{backend_type.__qualname__}", + "tools": sorted(self._extract_tool_names(self._tool_backend.schemas)), } + @staticmethod + def _extract_tool_names(schemas: list[dict[str, Any]]) -> list[str]: + """ + Pull the ``name`` field from each schema, supporting both flat and + nested ``function`` envelopes. + + Args: + schemas (list[dict[str, Any]]): The backend-provided schema list. + + Returns: + list[str]: One name per schema. Schemas without a recoverable + name are skipped silently — the identifier is best-effort + for shape-quirky backends. + """ + names: list[str] = [] + for schema in schemas: + if not isinstance(schema, dict): + continue + name = schema.get("name") + if not name and isinstance(schema.get("function"), dict): + name = schema["function"].get("name") + if isinstance(name, str): + names.append(name) + return names + @staticmethod def _capabilities_to_identifier_params(capabilities: TargetCapabilities) -> dict[str, Any]: """ diff --git a/tests/unit/prompt_target/target/test_prompt_target.py b/tests/unit/prompt_target/target/test_prompt_target.py index f3174c2649..fa29815103 100644 --- a/tests/unit/prompt_target/target/test_prompt_target.py +++ b/tests/unit/prompt_target/target/test_prompt_target.py @@ -22,6 +22,7 @@ UnsupportedCapabilityBehavior, ) from pyrit.prompt_target.common.target_configuration import TargetConfiguration +from pyrit.tools import LocalToolBackend, ToolEventBehavior, ToolEventPolicy @pytest.fixture @@ -534,7 +535,13 @@ def test_identifier_includes_capability_params(): # Config-derived fields are nested under ``target_configuration``, not # spread at the top level — guards against accidental re-flattening. assert "supports_multi_turn" not in params - assert set(target_config.keys()) == {"capabilities", "capability_policy", "normalization_pipeline"} + assert set(target_config.keys()) == { + "capabilities", + "capability_policy", + "normalization_pipeline", + "tool_event_policy", + "tool_backend", + } assert capabilities["supports_multi_turn"] is True assert capabilities["supports_multi_message_pieces"] is True @@ -546,6 +553,8 @@ def test_identifier_includes_capability_params(): assert capabilities["output_modalities"] == [["text"]] assert isinstance(target_config["capability_policy"], dict) assert isinstance(target_config["normalization_pipeline"], list) + assert target_config["tool_event_policy"] is None + assert target_config["tool_backend"] is None @pytest.mark.usefixtures("patch_central_database") @@ -581,6 +590,62 @@ def test_identifier_differs_when_policy_differs(): assert a.get_identifier().hash != b.get_identifier().hash +@pytest.mark.usefixtures("patch_central_database") +def test_identifier_differs_when_tool_backend_differs(): + async def _f(_: dict) -> dict: + return {} + + capabilities = TargetCapabilities(supports_tool_use=True) + backend_a = LocalToolBackend( + callables={"alpha": _f}, + schemas=[{"name": "alpha", "parameters": {"type": "object"}}], + ) + backend_b = LocalToolBackend( + callables={"beta": _f}, + schemas=[{"name": "beta", "parameters": {"type": "object"}}], + ) + + a = OpenAIChatTarget( + model_name="gpt-4o", + endpoint="https://mock.azure.com/", + api_key="mock-api-key", + custom_configuration=TargetConfiguration(capabilities=capabilities, tool_backend=backend_a), + ) + b = OpenAIChatTarget( + model_name="gpt-4o", + endpoint="https://mock.azure.com/", + api_key="mock-api-key", + custom_configuration=TargetConfiguration(capabilities=capabilities, tool_backend=backend_b), + ) + + assert a.get_identifier().hash != b.get_identifier().hash + + +@pytest.mark.usefixtures("patch_central_database") +def test_identifier_differs_when_tool_event_policy_differs(): + capabilities = TargetCapabilities(supports_tool_use=True) + a = OpenAIChatTarget( + model_name="gpt-4o", + endpoint="https://mock.azure.com/", + api_key="mock-api-key", + custom_configuration=TargetConfiguration( + capabilities=capabilities, + tool_event_policy=ToolEventPolicy(behavior=ToolEventBehavior.EXECUTE), + ), + ) + b = OpenAIChatTarget( + model_name="gpt-4o", + endpoint="https://mock.azure.com/", + api_key="mock-api-key", + custom_configuration=TargetConfiguration( + capabilities=capabilities, + tool_event_policy=ToolEventPolicy(behavior=ToolEventBehavior.RAISE), + ), + ) + + assert a.get_identifier().hash != b.get_identifier().hash + + @pytest.mark.usefixtures("patch_central_database") def test_identifier_is_deterministic_across_instances(): capabilities = TargetCapabilities( From 2129e204a18b5ec6f43b14f57d61fc9ff10483a4 Mon Sep 17 00:00:00 2001 From: Victor Valbuena Date: Thu, 28 May 2026 12:47:26 -0700 Subject: [PATCH 11/17] Convert reST cross-reference roles in pyrit/tools docstrings to MyST PyRIT's docs build uses MyST, not reStructuredText, so reST roles like :class:\Foo\ render as literal text in the rendered docs and mismatch the rest of the codebase. Convert all roles in the new pyrit/tools/ module to plain double-backtick code spans, and drop the in-flight commit-numbering references (C1/C2/...) that were carry-overs from the shipping plan and no longer mean anything in source. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- pyrit/tools/__init__.py | 33 ++++++++++++------------ pyrit/tools/backend.py | 14 +++++----- pyrit/tools/local_backend.py | 10 ++++---- pyrit/tools/mcp_backend.py | 36 +++++++++++++------------- pyrit/tools/mcp_client.py | 50 ++++++++++++++++++------------------ pyrit/tools/models.py | 26 +++++++++---------- pyrit/tools/parsers.py | 14 +++++----- 7 files changed, 91 insertions(+), 92 deletions(-) diff --git a/pyrit/tools/__init__.py b/pyrit/tools/__init__.py index 1bc7300ce7..e91b44a037 100644 --- a/pyrit/tools/__init__.py +++ b/pyrit/tools/__init__.py @@ -2,42 +2,41 @@ # Licensed under the MIT license. """ -Generic tool-use scaffolding for :class:`~pyrit.prompt_target.PromptTarget`. +Generic tool-use scaffolding for ``PromptTarget``. This package provides a transport-agnostic tool-calling loop. The -:func:`tool_loop` decorator, when applied to ``send_prompt_async``, runs +``tool_loop`` decorator, when applied to ``send_prompt_async``, runs the standard PyRIT validate+normalize work once and then repeatedly re-enters the target's protected ``_send_prompt_to_target_async`` until the model issues a stop response (or a configured limit is hit). A target opts in by declaring two collaborators: -* ``self._tool_parser`` — a :class:`ToolCallParser` that walks a - response message and extracts pending :class:`ToolCall` instances. -* ``self.configuration.tool_event_policy`` — a :class:`ToolEventPolicy` - whose :class:`ToolEventBehavior` decides whether to ``EXECUTE``, +* ``self._tool_parser`` — a ``ToolCallParser`` that walks a + response message and extracts pending ``ToolCall`` instances. +* ``self.configuration.tool_event_policy`` — a ``ToolEventPolicy`` + whose ``ToolEventBehavior`` decides whether to ``EXECUTE``, ``RAISE``, or ``RETURN_RAW`` on each detected call. When the policy is ``EXECUTE``, calls are dispatched through ``self.configuration.tool_backend``, an implementation of -:class:`ToolBackend`. :class:`LocalToolBackend` is the in-process -backend shipped here; :class:`MCPToolBackend` ships in C3 and proxies -through one or more MCP servers. +``ToolBackend``. ``LocalToolBackend`` is the in-process backend; +``MCPToolBackend`` proxies through one or more MCP servers. -The :class:`ToolBackend` Protocol is intentionally distinct from -:mod:`pyrit.registry` — that namespace is reserved for framework-level +The ``ToolBackend`` abstract base is intentionally distinct from +``pyrit.registry`` — that namespace is reserved for framework-level identity registries (``TargetRegistry``, ``ScorerRegistry``) that register named singletons for CLI lookup, which a per-target tool dispatch table is not. -Wiring of ``@tool_loop`` onto :class:`PromptTarget.send_prompt_async` -and of the ``tool_event_policy`` / ``tool_backend`` fields onto -:class:`TargetConfiguration` lands in C4/C5. +``@tool_loop`` is wired onto ``PromptTarget.send_prompt_async`` from +the base class, and the ``tool_event_policy`` / ``tool_backend`` +fields hang off ``TargetConfiguration``. The two exception types the loop raises -(:class:`~pyrit.exceptions.ToolCallNotSupported` and -:class:`~pyrit.exceptions.ToolCallLoopLimitExceeded`) live in -:mod:`pyrit.exceptions` alongside the rest of PyRIT's exception +(``ToolCallNotSupported`` and +``ToolCallLoopLimitExceeded``) live in +``pyrit.exceptions`` alongside the rest of PyRIT's exception catalog, so non-tools callers (attacks, normalizers) can import them without taking a subsystem-level dependency on ``pyrit.tools``. """ diff --git a/pyrit/tools/backend.py b/pyrit/tools/backend.py index e7a02a7685..d878bd1f4f 100644 --- a/pyrit/tools/backend.py +++ b/pyrit/tools/backend.py @@ -14,23 +14,23 @@ class ToolBackend(ABC): """ Abstract base for backends that dispatch tool calls produced by a target. - A :class:`ToolBackend` is a per-target dispatch table — it owns the + A ``ToolBackend`` is a per-target dispatch table — it owns the ``name -> async callable`` mapping a target uses to execute the tool calls extracted from a model response. This is intentionally distinct - from :mod:`pyrit.registry`, whose ``Registry`` classes register named + from ``pyrit.registry``, whose ``Registry`` classes register named framework singletons (targets, scorers, attacks) for CLI lookup. Two concrete implementations ship with PyRIT: - * :class:`~pyrit.tools.LocalToolBackend` — in-process backend backed + * ``LocalToolBackend`` — in-process backend backed by ``async def`` callables. Useful for unit tests and for embedding tools inside the PyRIT process. - * :class:`~pyrit.tools.MCPToolBackend` — proxies dispatch through one + * ``MCPToolBackend`` — proxies dispatch through one or more MCP servers. - Subclasses MUST implement :attr:`schemas` and :meth:`dispatch_async`. - :meth:`dispatch_all_sequential_async` ships with a default - implementation that awaits :meth:`dispatch_async` once per call in + Subclasses MUST implement ``schemas`` and ``dispatch_async``. + ``dispatch_all_sequential_async`` ships with a default + implementation that awaits ``dispatch_async`` once per call in declaration order; backends that wish to parallelize dispatch (e.g. fan out across multiple sandbox containers) should override it. """ diff --git a/pyrit/tools/local_backend.py b/pyrit/tools/local_backend.py index 25fe42e83c..7a149bc825 100644 --- a/pyrit/tools/local_backend.py +++ b/pyrit/tools/local_backend.py @@ -18,14 +18,14 @@ class LocalToolBackend(ToolBackend): """ - In-process :class:`~pyrit.tools.ToolBackend` backed by a name -> ``async def`` + In-process ``ToolBackend`` backed by a name -> ``async def`` mapping. Useful for unit tests and for embedding small tools inside the PyRIT process without standing up an MCP server. "Local" here means tools run in PyRIT's own Python process — no subprocess, no IPC, no wire protocol. Contrast with - :class:`~pyrit.tools.MCPToolBackend` (lands in C3), which proxies - dispatch through one or more MCP servers reached via JSON-RPC. + ``MCPToolBackend``, which proxies dispatch through one or more MCP + servers reached via JSON-RPC. The backend dispatches sequentially in declaration order. Tool-side failures (raised exceptions, missing names, allow-list rejections) @@ -49,7 +49,7 @@ def __init__( callables (dict[str, Callable[[dict[str, Any]], Awaitable[Any]]]): Map from tool name to an ``async def`` that accepts a parsed arguments dict and returns the tool result. Results are - serialized by the tool loop via :func:`json.dumps`. + serialized by the tool loop via ``json.dumps``. schemas (list[dict[str, Any]] | None): JSON-schema descriptors injected into the target's request body. Defaults to an empty list when omitted. @@ -59,7 +59,7 @@ def __init__( Defaults to None (no allow-list; every registered tool is callable). fail_on_missing_function (bool): When True (default), an unknown - tool name raises :class:`KeyError`. When False, the backend + tool name raises ``KeyError``. When False, the backend returns a ``tool_not_registered`` envelope so the model can recover. """ diff --git a/pyrit/tools/mcp_backend.py b/pyrit/tools/mcp_backend.py index 66da88a30f..41a0d32462 100644 --- a/pyrit/tools/mcp_backend.py +++ b/pyrit/tools/mcp_backend.py @@ -5,14 +5,14 @@ Multi-server tool backend that proxies dispatch through one or more MCP servers. -This is the :class:`~pyrit.tools.ToolBackend` implementation that real +This is the ``ToolBackend`` implementation that real red-team configurations use. It composes one -:class:`~pyrit.tools.MCPClient` per :class:`~pyrit.tools.MCPServerSpec`, +``MCPClient`` per ``MCPServerSpec``, aggregates their advertised schemas, routes incoming -:class:`~pyrit.tools.ToolCall` instances to the correct underlying +``ToolCall`` instances to the correct underlying client, and enforces an optional ``allowed_tools`` allow-list. -Contrast with :class:`~pyrit.tools.LocalToolBackend`, which dispatches +Contrast with ``LocalToolBackend``, which dispatches to Python ``async def`` callables inside PyRIT's own process. """ @@ -37,16 +37,16 @@ class MCPToolBackend(ToolBackend): """ - :class:`~pyrit.tools.ToolBackend` backed by one or more MCP servers. + ``ToolBackend`` backed by one or more MCP servers. - On :meth:`__aenter__`, the backend spawns / connects each server in - its :attr:`_servers` list (sequentially) through a single - :class:`contextlib.AsyncExitStack`, runs the MCP handshake, caches + On ``__aenter__``, the backend spawns / connects each server in + its ``_servers`` list (sequentially) through a single + ``contextlib.AsyncExitStack``, runs the MCP handshake, caches schemas, and builds an advertised-name → ``(client, server_name)`` - routing table. Collisions raise :class:`ValueError` unless the - colliding specs set :attr:`~pyrit.tools.LocalMCPServerSpec.name_prefix`. + routing table. Collisions raise ``ValueError`` unless the + colliding specs set ``name_prefix``. - A single shared :class:`AsyncExitStack` (rather than one per client) + A single shared ``AsyncExitStack`` (rather than one per client) is required so anyio's nested cancel scopes — opened by the ``mcp`` SDK's ``stdio_client`` and ``ClientSession`` context managers — are closed in strict LIFO order from the entering task. Closing @@ -54,7 +54,7 @@ class MCPToolBackend(ToolBackend): ``"Attempted to exit a cancel scope that isn't the current task's current cancel scope"``. - Dispatch is serialized through an :class:`asyncio.Lock` per backend + Dispatch is serialized through an ``asyncio.Lock`` per backend instance — multiple concurrent coroutines sharing the same backend (e.g. parallel attack runs) will not interleave JSON-RPC frames on the same stdio pipe. @@ -70,12 +70,12 @@ def __init__( Initialize the backend. Args: - servers: One or more :class:`MCPServerSpec` instances describing + servers: One or more ``MCPServerSpec`` instances describing where each server runs. allowed_tools: Optional allow-list of tool names. Names not in - the list are filtered from :attr:`schemas` AND + the list are filtered from ``schemas`` AND short-circuit dispatch with a ``tool_not_allowed`` envelope. - Names are matched after :attr:`~LocalMCPServerSpec.name_prefix` + Names are matched after ``name_prefix`` has been applied. Defaults to None (every advertised tool is callable). @@ -105,7 +105,7 @@ def schemas(self) -> list[dict[str, Any]]: async def __aenter__(self) -> MCPToolBackend: """ - Connect each underlying client through a shared :class:`AsyncExitStack` and build the routing table. + Connect each underlying client through a shared ``AsyncExitStack`` and build the routing table. Returns: MCPToolBackend: *self*, ready to dispatch. @@ -154,7 +154,7 @@ async def dispatch_async(self, call: ToolCall) -> dict[str, Any]: """ Route *call* to the correct client and dispatch. - See :class:`MCPClient.dispatch_async` for the envelope shape. + See ``MCPClient.dispatch_async`` for the envelope shape. Allow-list rejections and unknown-tool calls return error envelopes; only "backend not entered" raises. @@ -164,7 +164,7 @@ async def dispatch_async(self, call: ToolCall) -> dict[str, Any]: Returns: dict[str, Any]: A structured envelope (success, ``tool_not_allowed``, ``tool_not_registered``, or the underlying - :meth:`MCPClient.dispatch_async` envelope). + ``MCPClient.dispatch_async`` envelope). Raises: RuntimeError: When the backend has not been entered via ``async with``. diff --git a/pyrit/tools/mcp_client.py b/pyrit/tools/mcp_client.py index 904004f675..0f9eeadcc3 100644 --- a/pyrit/tools/mcp_client.py +++ b/pyrit/tools/mcp_client.py @@ -5,19 +5,19 @@ Stdio-transport client for the Model Context Protocol (MCP). This module is the wire-protocol half of PyRIT's MCP integration. It -sits below :class:`~pyrit.tools.MCPToolBackend` (which composes one -:class:`MCPClient` per configured server and handles cross-server +sits below ``MCPToolBackend`` (which composes one +``MCPClient`` per configured server and handles cross-server routing) and above the upstream ``mcp`` Python SDK (which owns the JSON-RPC framing, capability negotiation, and asyncio task plumbing). -The three :class:`MCPServerSpec` variants describe *where* the server -runs. Only :class:`LocalMCPServerSpec` is implemented in this commit: +The three ``MCPServerSpec`` variants describe *where* the server +runs. Only ``LocalMCPServerSpec`` is implemented in this commit: -* :class:`LocalMCPServerSpec` — spawn the server as a child process and +* ``LocalMCPServerSpec`` — spawn the server as a child process and speak JSON-RPC over its stdin/stdout. -* :class:`RemoteMCPServerSpec` — HTTP/SSE transport against a hosted +* ``RemoteMCPServerSpec`` — HTTP/SSE transport against a hosted server. Stub: ``connect_async`` raises ``NotImplementedError``. -* :class:`DockerMCPServerSpec` — stdio over ``docker run -i`` against a +* ``DockerMCPServerSpec`` — stdio over ``docker run -i`` against a hardened sandbox container. Stub: ``connect_async`` raises ``NotImplementedError``. Implementation lands in the follow-up sandbox PR. @@ -57,10 +57,10 @@ class LocalMCPServerSpec: process. ``None`` (default) inherits PyRIT's environment. name_prefix (str | None): When set, every tool advertised by the server is registered as ``f"{name_prefix}{tool_name}"`` in - the parent :class:`~pyrit.tools.MCPToolBackend`. Used to + the parent ``MCPToolBackend``. Used to disambiguate two servers that expose the same tool name. timeout_seconds (float): Per-call timeout, enforced by - :meth:`MCPClient.dispatch_async`. Defaults to 30 seconds. + ``MCPClient.dispatch_async``. Defaults to 30 seconds. """ command: str @@ -74,13 +74,13 @@ class LocalMCPServerSpec: class RemoteMCPServerSpec: """ Spec for an MCP server reached over HTTP / SSE. **Not implemented** - in this PR — :meth:`MCPClient.connect_async` raises - :class:`NotImplementedError`. Tracked by ``# TODO(mcp-http-transport)``. + in this PR — ``MCPClient.connect_async`` raises + ``NotImplementedError``. Tracked by ``# TODO(mcp-http-transport)``. Attributes: url (str): The base URL of the MCP server. name_prefix (str | None): Same semantics as - :attr:`LocalMCPServerSpec.name_prefix`. + ``LocalMCPServerSpec.name_prefix``. timeout_seconds (float): Per-call timeout. """ @@ -117,7 +117,7 @@ class DockerMCPServerSpec: network_profile (str): ``NetworkProfile`` name; ``"none"`` (default) launches the container with ``--network=none``. name_prefix (str | None): Same semantics as - :attr:`LocalMCPServerSpec.name_prefix`. + ``LocalMCPServerSpec.name_prefix``. timeout_seconds (float): Per-call timeout. Future fields (deferred to the follow-up sandbox PR): ``memory_limit``, @@ -172,19 +172,19 @@ class MCPClient: A single MCP-server session. The client owns the lifetime of one server's transport stack and - exposes a uniform :meth:`dispatch_async` regardless of which - :class:`MCPServerSpec` variant it was constructed from. Composition + exposes a uniform ``dispatch_async`` regardless of which + ``MCPServerSpec`` variant it was constructed from. Composition across multiple servers (routing, schema aggregation, allow-lists) - is the responsibility of :class:`~pyrit.tools.MCPToolBackend`. + is the responsibility of ``MCPToolBackend``. Lifecycle: - * :meth:`connect_async` spawns the subprocess (for - :class:`LocalMCPServerSpec`), runs the MCP handshake, and caches + * ``connect_async`` spawns the subprocess (for + ``LocalMCPServerSpec``), runs the MCP handshake, and caches ``tools/list`` results. - * :meth:`dispatch_async` issues one ``tools/call`` and returns a + * ``dispatch_async`` issues one ``tools/call`` and returns a structured envelope (success or error). - * :meth:`close_async` tears down the transport stack. + * ``close_async`` tears down the transport stack. The class is usable as an async context manager. """ @@ -192,7 +192,7 @@ class MCPClient: def __init__(self, *, spec: MCPServerSpec) -> None: """ Initialize the client around *spec*. Does not connect; call - :meth:`connect_async` (or use the async context-manager form) to start + ``connect_async`` (or use the async context-manager form) to start the transport stack. """ self._spec = spec @@ -202,7 +202,7 @@ def __init__(self, *, spec: MCPServerSpec) -> None: @property def spec(self) -> MCPServerSpec: - """The :class:`MCPServerSpec` this client was constructed with.""" + """The ``MCPServerSpec`` this client was constructed with.""" return self._spec @property @@ -211,7 +211,7 @@ def schemas(self) -> list[dict[str, Any]]: JSON schemas for every tool the server advertises. Each schema is shaped ``{"name", "description", "parameters"}``. - The optional :attr:`LocalMCPServerSpec.name_prefix` is applied + The optional ``LocalMCPServerSpec.name_prefix`` is applied here so a backend that owns this client sees the prefixed name. """ prefix = getattr(self._spec, "name_prefix", None) or "" @@ -226,7 +226,7 @@ def schemas(self) -> list[dict[str, Any]]: @property def tool_names(self) -> list[str]: - """Tool names with the spec's :attr:`name_prefix` applied.""" + """Tool names with the spec's ``name_prefix`` applied.""" return [s["name"] for s in self.schemas] def _strip_prefix(self, name: str) -> str: @@ -300,7 +300,7 @@ async def dispatch_async(self, call: ToolCall) -> dict[str, Any]: * Server-reported error: ``{"is_error": True, "error": "tool_execution_failed", "tool": name, ...}``. Tool-side failures are converted to envelopes; only programmer - errors (calling before :meth:`connect_async`) raise. + errors (calling before ``connect_async``) raise. Args: call (ToolCall): The call to dispatch. The advertised diff --git a/pyrit/tools/models.py b/pyrit/tools/models.py index 0e05f9be9d..187bf635d9 100644 --- a/pyrit/tools/models.py +++ b/pyrit/tools/models.py @@ -27,9 +27,9 @@ class ToolCall: """ A parsed tool call extracted from a target response. - Concrete :class:`~pyrit.tools.ToolCallParser` implementations build - :class:`ToolCall` instances by walking the response message pieces. - The :attr:`raw_envelope` carries the original target-specific dict + Concrete ``ToolCallParser`` implementations build + ``ToolCall`` instances by walking the response message pieces. + The ``raw_envelope`` carries the original target-specific dict (e.g. the function_call JSON section) so dispatchers and observers can recover provider-specific fields without re-parsing. @@ -56,7 +56,7 @@ class ToolEventBehavior(enum.Enum): EXECUTE: Dispatch the call via ``configuration.tool_backend`` and re-enter the target with the tool output appended. This is the standard agentic loop behavior. - RAISE: Raise :class:`~pyrit.exceptions.ToolCallNotSupported` with + RAISE: Raise ``ToolCallNotSupported`` with the partial conversation attached. Useful for red-team attacks that want to observe attempted tool use without allowing execution. @@ -80,7 +80,7 @@ class ToolEventPolicy: Attributes: behavior (ToolEventBehavior): What to do on each detected tool call. max_tool_iterations (int): Maximum number of model<->tool round-trips - before the loop raises :class:`ToolCallLoopLimitExceeded`. Each + before the loop raises ``ToolCallLoopLimitExceeded``. Each iteration is one ``_send_prompt_to_target_async`` call. """ @@ -97,8 +97,8 @@ def _build_function_call_output_message( Build the canonical ``tool`` message produced after dispatching one or more tool calls in a single iteration. - The returned :class:`Message` contains one - :class:`MessagePiece` per ``(call, result)`` pair, in declaration order. + The returned ``Message`` contains one + ``MessagePiece`` per ``(call, result)`` pair, in declaration order. Every piece has ``role="tool"`` and ``original_value_data_type="function_call_output"``, with the JSON envelope ``{"type": "function_call_output", "call_id": ..., "output": ...}``. @@ -112,7 +112,7 @@ def _build_function_call_output_message( copied onto every output piece. Pass the first piece of the assistant message that produced the calls. outputs (list[tuple[ToolCall, Any]]): ``(call, result)`` pairs in - declaration order. *result* is serialized via :func:`json.dumps` + declaration order. *result* is serialized via ``json.dumps`` unless it is already a string. Returns: @@ -142,7 +142,7 @@ def tool_loop( method: Callable[..., Awaitable[list[Message]]], ) -> Callable[..., Awaitable[list[Message]]]: """ - Wrap a :class:`~pyrit.prompt_target.PromptTarget`-style + Wrap a ``PromptTarget``-style ``send_prompt_async`` to run an agentic tool-use loop. When the target's ``configuration.tool_event_policy`` is ``None`` the @@ -155,21 +155,21 @@ def tool_loop( 3. After each call, parse the last response via ``self._tool_parser``. Exit on empty parse (model issued a stop response). 4. On a non-empty parse, branch on ``policy.behavior``: - ``RAISE`` raises :class:`ToolCallNotSupported`; ``RETURN_RAW`` + ``RAISE`` raises ``ToolCallNotSupported``; ``RETURN_RAW`` returns the chain as-is; ``EXECUTE`` dispatches the calls via ``configuration.tool_backend`` and appends the tool message. - 5. Raise :class:`ToolCallLoopLimitExceeded` if the loop runs past + 5. Raise ``ToolCallLoopLimitExceeded`` if the loop runs past ``policy.max_tool_iterations`` without the model stopping. The decorator deliberately knows nothing about MCP, OpenAI, or any specific transport. The two collaborators it requires — ``self._tool_parser`` and ``self.configuration.tool_backend`` — are - plain protocols (:class:`ToolCallParser`, :class:`ToolBackend`). + plain protocols (``ToolCallParser``, ``ToolBackend``). Args: method (Callable): The async method to wrap. Must have the ``async def f(self, *, message: Message) -> list[Message]`` - signature of :meth:`PromptTarget.send_prompt_async`. + signature of ``PromptTarget.send_prompt_async``. Returns: Callable: The wrapped method. diff --git a/pyrit/tools/parsers.py b/pyrit/tools/parsers.py index c903eb73c7..8f0038d5a0 100644 --- a/pyrit/tools/parsers.py +++ b/pyrit/tools/parsers.py @@ -18,7 +18,7 @@ class ToolCallParser(Protocol): Concrete parsers live next to the target whose response shape they understand (the canonical-envelope parser shipped here is shared by - :class:`OpenAIResponseTarget`; per-model-family parsers for non-OpenAI + ``OpenAIResponseTarget``; per-model-family parsers for non-OpenAI targets ship in a follow-up, see plan §12.9). Parsers MUST return an empty list when the model has issued a stop response — the tool loop uses the empty list as the signal to exit. @@ -40,7 +40,7 @@ def parse(self, message: Message) -> list[ToolCall]: def _extract_function_call_pieces(message: Message) -> list[MessagePiece]: """ - Return every :class:`MessagePiece` in *message* whose + Return every ``MessagePiece`` in *message* whose ``original_value_data_type`` is ``"function_call"``. This is the canonical envelope used by every PyRIT-supported tool-emitting @@ -59,9 +59,9 @@ def _extract_function_call_pieces(message: Message) -> list[MessagePiece]: class CanonicalEnvelopeParser: """ - Reference :class:`ToolCallParser` for the canonical function_call envelope. + Reference ``ToolCallParser`` for the canonical function_call envelope. - Walks every :class:`MessagePiece` whose ``original_value_data_type`` is + Walks every ``MessagePiece`` whose ``original_value_data_type`` is ``"function_call"`` and decodes the canonical JSON shape:: { @@ -71,7 +71,7 @@ class CanonicalEnvelopeParser: "arguments": "" } - into :class:`ToolCall` instances. Pieces of other data types -- reasoning, + into ``ToolCall`` instances. Pieces of other data types -- reasoning, mcp_call, web_search_call, etc. -- are ignored (they pass through to Memory but are not client-side dispatchable). Per-model-family parsers for non-OpenAI targets ship in a follow-up PR (see plan §12.9). @@ -79,13 +79,13 @@ class CanonicalEnvelopeParser: def parse(self, message: Message) -> list[ToolCall]: """ - Decode canonical ``function_call`` pieces in *message* into :class:`ToolCall`. + Decode canonical ``function_call`` pieces in *message* into ``ToolCall``. Args: message (Message): The most recent assistant response. Returns: - list[ToolCall]: One :class:`ToolCall` per ``function_call`` + list[ToolCall]: One ``ToolCall`` per ``function_call`` piece, in declaration order. Empty if the message contains no ``function_call`` pieces (model stop). """ From eca19a1535768ffe57bf507511f0e4cb976f1d07 Mon Sep 17 00:00:00 2001 From: Victor Valbuena Date: Thu, 28 May 2026 12:47:34 -0700 Subject: [PATCH 12/17] Tighten tests/unit/tools: drop redundant asyncio markers and narrow raises Three small cleanups in the new tools test suite: 1. Remove @pytest.mark.asyncio decorators -- the project sets asyncio_mode='auto' in pyproject.toml so the marker is a no-op that creates the appearance of opt-in async test discovery. 2. Narrow pytest.raises((AttributeError, Exception)) to dataclasses.FrozenInstanceError on the two frozen-dataclass guards in test_mcp_client.py. The previous pattern matched every Exception and would have masked unrelated regressions. 3. Drop in-flight C1/C2/.../C10 commit-id strings from test docstrings; they referenced the shipping plan, not the source tree, and read as noise after the commits land. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- tests/unit/tools/conftest.py | 4 ++-- tests/unit/tools/echo_mcp_server.py | 11 +++++------ tests/unit/tools/test_local_tool_backend.py | 11 +++++------ tests/unit/tools/test_mcp_backend.py | 8 -------- tests/unit/tools/test_mcp_client.py | 17 ++++------------- .../tools/test_prompt_target_tool_loop.py | 19 +++++++------------ tests/unit/tools/test_tool_event_policy.py | 4 +--- tests/unit/tools/test_tool_loop_decorator.py | 4 ++-- 8 files changed, 26 insertions(+), 52 deletions(-) diff --git a/tests/unit/tools/conftest.py b/tests/unit/tools/conftest.py index ad7d2c7fd1..d8b0151bb7 100644 --- a/tests/unit/tools/conftest.py +++ b/tests/unit/tools/conftest.py @@ -21,7 +21,7 @@ Helper message builders (``_make_user_message``, ``_make_assistant_text_message``, ``_make_assistant_function_call_message``) produce the canonical envelope shape used by the OpenAI targets after the -C6 normalization commit. +normalization commit. """ from __future__ import annotations @@ -121,7 +121,7 @@ class _CanonicalEnvelopeParser: (``original_value_data_type == "function_call"`` carrying a JSON object with ``type``/``call_id``/``name``/``arguments``). - Per-target parsers shipped in C7/C8 will reuse this shape; this stand-in + Per-target parsers shipped will reuse this shape; this stand-in keeps decorator tests independent of the real OpenAI parsers. """ diff --git a/tests/unit/tools/echo_mcp_server.py b/tests/unit/tools/echo_mcp_server.py index 723a3c6594..3ea5dbd6b5 100644 --- a/tests/unit/tools/echo_mcp_server.py +++ b/tests/unit/tools/echo_mcp_server.py @@ -3,15 +3,14 @@ """ Deterministic echo MCP server used as a stdio subprocess fixture by -``tests/unit/tools/test_mcp_client.py`` (C3) and the integration tests -(C9). +``tests/unit/tools/test_mcp_client.py`` and the tools integration tests. -Lands in C2 so subsequent commits don't shuffle test plumbing; C2's own -tests do not import this module (the :class:`CallableToolRegistry` is -exercised in-process). +The harness imports this module via the ``mcp.client.stdio.stdio_client`` +launcher, so it does not need to be importable as a Python module from +``tests/unit/tools/`` callers. Run directly as ``python echo_mcp_server.py`` to expose the four tools -over stdio. The MCP client harness in C3 launches this file with +over stdio. The MCP client harness launches this file with ``mcp.client.stdio.stdio_client`` and asserts behavior end to end. """ diff --git a/tests/unit/tools/test_local_tool_backend.py b/tests/unit/tools/test_local_tool_backend.py index 8e24693140..13d12c9b26 100644 --- a/tests/unit/tools/test_local_tool_backend.py +++ b/tests/unit/tools/test_local_tool_backend.py @@ -4,11 +4,11 @@ """ Unit tests for :class:`pyrit.tools.LocalToolBackend`. -Coverage map (rows from the C2 test matrix): +Coverage map: -* **U10** (partial; the MCP counterpart lands in C3) — +* **U10** (partial; the MCP counterpart) — ``test_each_dummy_tool_invoked_via_prepended_conversation`` -* **U17** (partial; the MCP-timeout counterpart lands in C3) — +* **U17** (partial; the MCP-timeout counterpart) — ``test_failing_tool_yields_error_envelope`` * **U18** — ``test_disallowed_tool_returns_error_without_invoking_callable`` @@ -16,8 +16,7 @@ (both strict and tolerant modes), schema property defaulting, scalar result wrapping, and declaration-order preservation in the bulk dispatch path. These are required for the §10 rubber-duck guarantee that every -public-facing branch of :class:`LocalToolBackend` is exercised -before C5 wires it to a production target. +public-facing branch of :class:`LocalToolBackend` is exercised end-to-end. """ from __future__ import annotations @@ -144,7 +143,7 @@ async def test_each_dummy_tool_invoked_via_prepended_conversation(): U10 (partial). Each dummy tool resolves on first dispatch (single forward step, no model reasoning trace), confirming the backend can short-circuit a prepended conversation where every call is already - decided. The MCP counterpart in C3 exercises the same shape against + decided. The MCP counterpart exercises the same shape against a real stdio server. """ invocations: list[tuple[str, dict]] = [] diff --git a/tests/unit/tools/test_mcp_backend.py b/tests/unit/tools/test_mcp_backend.py index 0abc8ff88c..7621a61dca 100644 --- a/tests/unit/tools/test_mcp_backend.py +++ b/tests/unit/tools/test_mcp_backend.py @@ -48,7 +48,6 @@ def _make_call(name: str, *, call_id: str = "c1", arguments: dict | None = None) return ToolCall(call_id=call_id, name=name, arguments=arguments or {}) -@pytest.mark.asyncio async def test_backend_aggregates_schemas_across_servers() -> None: """Schemas from every connected server show up in :attr:`schemas`.""" backend = MCPToolBackend(servers=[_spec()]) @@ -57,7 +56,6 @@ async def test_backend_aggregates_schemas_across_servers() -> None: assert names == {"echo", "add", "reverse", "slow_echo"} -@pytest.mark.asyncio async def test_dispatch_routes_to_correct_server() -> None: """A :class:`ToolCall` is routed to the server that registered the name.""" backend = MCPToolBackend(servers=[_spec()]) @@ -67,7 +65,6 @@ async def test_dispatch_routes_to_correct_server() -> None: assert envelope["content"] == "routed" -@pytest.mark.asyncio async def test_name_collision_raises_value_error() -> None: """Two servers exposing the same tool name without prefixes raise.""" backend = MCPToolBackend(servers=[_spec(), _spec()]) @@ -76,7 +73,6 @@ async def test_name_collision_raises_value_error() -> None: # __aexit__ is the cleanup path; __aenter__ failing leaves nothing to clean. -@pytest.mark.asyncio async def test_name_prefix_disambiguates_colliding_servers() -> None: """Setting :attr:`LocalMCPServerSpec.name_prefix` disambiguates duplicates.""" backend = MCPToolBackend( @@ -95,7 +91,6 @@ async def test_name_prefix_disambiguates_colliding_servers() -> None: assert envelope_b["content"] == "beta" -@pytest.mark.asyncio async def test_disallowed_tool_returns_error_envelope_without_invoking_server() -> None: """U18: allowed_tools blocks both schema advertisement AND dispatch.""" backend = MCPToolBackend(servers=[_spec()], allowed_tools=["echo"]) @@ -110,7 +105,6 @@ async def test_disallowed_tool_returns_error_envelope_without_invoking_server() assert envelope["allowed_tools"] == ["echo"] -@pytest.mark.asyncio async def test_unknown_tool_returns_error_envelope() -> None: """A call to a name no connected server exposes returns an error envelope.""" backend = MCPToolBackend(servers=[_spec()]) @@ -121,7 +115,6 @@ async def test_unknown_tool_returns_error_envelope() -> None: assert envelope["tool"] == "never_registered" -@pytest.mark.asyncio async def test_concurrent_dispatch_is_serialized_by_lock() -> None: """U21: two coroutines dispatching against the same backend do not interleave. @@ -141,7 +134,6 @@ async def test_concurrent_dispatch_is_serialized_by_lock() -> None: assert {r["content"] for r in results} == {"A", "B"} -@pytest.mark.asyncio async def test_dispatch_all_sequential_async_preserves_order() -> None: """Bulk dispatch returns (call, envelope) pairs in declaration order.""" backend = MCPToolBackend(servers=[_spec()]) diff --git a/tests/unit/tools/test_mcp_client.py b/tests/unit/tools/test_mcp_client.py index 67f93d046e..de08b479b7 100644 --- a/tests/unit/tools/test_mcp_client.py +++ b/tests/unit/tools/test_mcp_client.py @@ -5,7 +5,7 @@ Unit tests for :class:`pyrit.tools.MCPClient` and the :class:`pyrit.tools.MCPServerSpec` union. -Coverage map (rows from the C2/C3 test matrix): +Coverage map: * **U10** — ``test_real_subprocess_dispatch_returns_text_content``, ``test_sequential_dispatch_against_real_server``. @@ -23,6 +23,7 @@ from __future__ import annotations +import dataclasses import sys from pathlib import Path @@ -52,7 +53,6 @@ def _make_call(name: str, *, call_id: str = "c1", arguments: dict | None = None) return ToolCall(call_id=call_id, name=name, arguments=arguments or {}) -@pytest.mark.asyncio async def test_real_subprocess_dispatch_returns_text_content() -> None: """U10: dispatching a single tool call returns the echo server's text response.""" client = MCPClient(spec=_local_spec()) @@ -62,7 +62,6 @@ async def test_real_subprocess_dispatch_returns_text_content() -> None: assert envelope["content"] == "hi" -@pytest.mark.asyncio async def test_sequential_dispatch_against_real_server() -> None: """U10: multiple sequential calls round-trip through the same session.""" client = MCPClient(spec=_local_spec()) @@ -76,7 +75,6 @@ async def test_sequential_dispatch_against_real_server() -> None: assert contents == ["first", "5", "cba"] -@pytest.mark.asyncio async def test_connect_async_populates_schemas_via_tools_list() -> None: """U14: schemas are discovered via tools/list during connect_async.""" client = MCPClient(spec=_local_spec()) @@ -89,7 +87,6 @@ async def test_connect_async_populates_schemas_via_tools_list() -> None: assert echo_schema["parameters"]["properties"]["text"]["type"] == "string" -@pytest.mark.asyncio async def test_dispatch_timeout_returns_error_envelope() -> None: """U17: a tool call that exceeds the spec's timeout produces an error envelope.""" client = MCPClient(spec=_local_spec(timeout_seconds=0.05)) @@ -102,7 +99,6 @@ async def test_dispatch_timeout_returns_error_envelope() -> None: assert envelope["tool"] == "slow_echo" -@pytest.mark.asyncio async def test_dispatch_async_returns_error_envelope_on_unknown_tool() -> None: """Server-side errors (unknown tool name) surface as is_error envelopes.""" client = MCPClient(spec=_local_spec()) @@ -116,11 +112,10 @@ def test_remote_mcp_server_spec_is_frozen_dataclass() -> None: """U20: RemoteMCPServerSpec exists in the type system as a frozen dataclass.""" spec = RemoteMCPServerSpec(url="https://example.com/mcp") assert spec.url == "https://example.com/mcp" - with pytest.raises((AttributeError, Exception)): # frozen dataclass guard + with pytest.raises(dataclasses.FrozenInstanceError): spec.url = "other" # type: ignore[misc] -@pytest.mark.asyncio async def test_remote_mcp_server_spec_raises_not_implemented() -> None: """U20: connecting to a RemoteMCPServerSpec raises NotImplementedError.""" client = MCPClient(spec=RemoteMCPServerSpec(url="https://example.com/mcp")) @@ -137,7 +132,6 @@ def test_docker_mcp_server_spec_dataclass_fields() -> None: assert spec.timeout_seconds == 30.0 -@pytest.mark.asyncio async def test_docker_mcp_server_spec_raises_not_implemented() -> None: """U20: connecting to a DockerMCPServerSpec raises NotImplementedError.""" client = MCPClient(spec=DockerMCPServerSpec(image="pyrit-sandbox:base")) @@ -145,7 +139,6 @@ async def test_docker_mcp_server_spec_raises_not_implemented() -> None: await client.connect_async() -@pytest.mark.asyncio async def test_dispatch_before_connect_raises_runtime_error() -> None: """Calling dispatch_async before connect_async is a programmer error.""" client = MCPClient(spec=_local_spec()) @@ -153,7 +146,6 @@ async def test_dispatch_before_connect_raises_runtime_error() -> None: await client.dispatch_async(_make_call("echo", arguments={"text": "hi"})) -@pytest.mark.asyncio async def test_close_async_is_idempotent() -> None: """Calling close_async twice (or before connect) does not raise.""" client = MCPClient(spec=_local_spec()) @@ -163,9 +155,8 @@ async def test_close_async_is_idempotent() -> None: await client.close_async() # double-close — no-op. -@pytest.mark.asyncio async def test_local_mcp_server_spec_is_frozen() -> None: """LocalMCPServerSpec is a frozen dataclass.""" spec = LocalMCPServerSpec(command="python", args=("a.py",)) - with pytest.raises((AttributeError, Exception)): + with pytest.raises(dataclasses.FrozenInstanceError): spec.command = "other" # type: ignore[misc] diff --git a/tests/unit/tools/test_prompt_target_tool_loop.py b/tests/unit/tools/test_prompt_target_tool_loop.py index 8848ddb010..518304111f 100644 --- a/tests/unit/tools/test_prompt_target_tool_loop.py +++ b/tests/unit/tools/test_prompt_target_tool_loop.py @@ -2,12 +2,12 @@ # Licensed under the MIT license. """ -Unit tests for ``@tool_loop`` wired into :meth:`PromptTarget.send_prompt_async`. +Unit tests for ``@tool_loop`` wired into ``PromptTarget.send_prompt_async``. -C4 lands the wiring: ``send_prompt_async`` becomes ``@final @tool_loop`` -on the base class, ``_tool_parser`` and ``_tool_schemas()`` get default -no-op implementations, and ``TargetConfiguration`` grows ``tool_event_policy`` -+ ``tool_backend`` kwargs. +The base class decorates ``send_prompt_async`` with ``@final @tool_loop``, +exposes ``_tool_parser`` and ``_tool_schemas()`` as default no-op hooks, +and ``TargetConfiguration`` carries the ``tool_event_policy`` and +``tool_backend`` kwargs the decorator consults. These tests use the production ``_get_normalized_conversation_async`` path (memory round-trip through :class:`SQLiteMemory` via ``patch_central_database``) @@ -126,9 +126,8 @@ def execute_policy_fixture(): class TestToolLoopWiredIntoBaseClass: """Verifies ``@tool_loop`` runs on every ``send_prompt_async`` call.""" - @pytest.mark.asyncio async def test_decorator_passthrough_when_no_policy(self, make_production_target): - """U11 -- target without a policy behaves exactly like pre-C4 ``send_prompt_async``.""" + """U11 -- target without a policy behaves like a single-pass ``send_prompt_async``.""" target = make_production_target( scripted_responses=[_make_assistant_text_message("plain")], policy=None, @@ -140,7 +139,6 @@ async def test_decorator_passthrough_when_no_policy(self, make_production_target assert len(responses) == 1 assert responses[0].message_pieces[0].original_value == "plain" - @pytest.mark.asyncio async def test_tool_loop_order_after_normalize_before_memory(self, make_production_target, execute_policy_fixture): """U1 -- validate + normalize happen exactly once before the loop iterates.""" backend = _RecordingToolBackend(scripted_results=[{"result": "echoed"}]) @@ -162,7 +160,6 @@ async def test_tool_loop_order_after_normalize_before_memory(self, make_producti assert responses[1].message_pieces[0].original_value_data_type == "function_call_output" assert responses[2].message_pieces[0].original_value_data_type == "text" - @pytest.mark.asyncio async def test_tool_message_has_one_function_call_output_piece_per_call( self, make_production_target, execute_policy_fixture ): @@ -201,7 +198,6 @@ class TestDbTranscriptAfterToolLoop: in order. """ - @pytest.mark.asyncio async def test_db_insert_order_user_then_asst_fc_then_tool_then_final_asst( self, make_production_target, execute_policy_fixture ): @@ -221,7 +217,6 @@ async def test_db_insert_order_user_then_asst_fc_then_tool_then_final_asst( data_types_in_order = [r.message_pieces[0].original_value_data_type for r in responses] assert data_types_in_order == ["function_call", "function_call_output", "text"] - @pytest.mark.asyncio async def test_db_roles_and_data_types_match_canonical_envelope( self, make_production_target, execute_policy_fixture ): @@ -259,7 +254,7 @@ async def test_db_roles_and_data_types_match_canonical_envelope( class TestFinalAndAbstractMethodContract: """ - Asserts the base-class shape changes that C4 introduces but doesn't + Asserts the base-class shape that ``@tool_loop`` requires but does not exercise via end-to-end runs: ``_tool_parser`` defaults to ``None``, ``_tool_schemas`` defaults to ``[]``. """ diff --git a/tests/unit/tools/test_tool_event_policy.py b/tests/unit/tools/test_tool_event_policy.py index aa451e672d..20cf283e07 100644 --- a/tests/unit/tools/test_tool_event_policy.py +++ b/tests/unit/tools/test_tool_event_policy.py @@ -8,7 +8,7 @@ :func:`pyrit.tools.tool_loop` decorator that lives on :class:`PromptTarget.send_prompt_async`. -These tests are the §7 U7 row plus the construction-time validator added in C4. +These tests are the §7 U7 row plus the construction-time validator. They assert the *capability flag* axis only -- that targets which declare ``supports_tool_use=True`` and configure a policy + backend route through the loop, that targets without a policy short-circuit, and that the @@ -87,7 +87,6 @@ class TestCapabilityFlagWiringIntoToolLoop: ``supports_tool_use`` AND a policy is configured. """ - @pytest.mark.asyncio async def test_target_with_tool_use_capability_uses_tool_loop( self, make_fake_target, recording_backend, execute_policy ): @@ -107,7 +106,6 @@ async def test_target_with_tool_use_capability_uses_tool_loop( assert [c.name for c in backend.recorded_calls] == ["echo"] assert len(responses) == 3, "user expects asst_fc, tool_msg, asst_final." - @pytest.mark.asyncio async def test_target_without_tool_use_capability_skips_dispatch(self, make_fake_target): target = make_fake_target( scripted_responses=[_make_assistant_text_message("plain response, no tool call")], diff --git a/tests/unit/tools/test_tool_loop_decorator.py b/tests/unit/tools/test_tool_loop_decorator.py index bc0db6b357..89805fbb1e 100644 --- a/tests/unit/tools/test_tool_loop_decorator.py +++ b/tests/unit/tools/test_tool_loop_decorator.py @@ -4,9 +4,9 @@ """ Unit tests for :func:`pyrit.tools.tool_loop`. -Coverage map (rows from the C2 test matrix): +Coverage map: -* **U2** (partial; full-DB end lands in C5) — ``test_loop_returns_full_chain_in_order`` +* **U2** (partial; full-DB end) — ``test_loop_returns_full_chain_in_order`` * **U3** — ``test_loop_exits_on_first_response_when_no_tool_calls``, ``test_loops_until_no_pending_tool_call`` * **U4** — ``test_raises_after_max_tool_iterations``, From da726ff2dead6fa6de466cd4b6b5e2bae42b4a9d Mon Sep 17 00:00:00 2001 From: Victor Valbuena Date: Thu, 28 May 2026 16:59:18 -0700 Subject: [PATCH 13/17] Teach ChatMessageNormalizer to serialize function_call and function_call_output pieces ChatMessageNormalizer raised on function_call / function_call_output data types, which meant any target whose wire format runs through it (AzureMLChatTarget, HuggingFaceChatTarget, OpenAIChatTarget) could not round-trip a tool-call conversation through @tool_loop. Adds a per-message tool-message detector that converts function_call pieces to an assistant message with content=null and a ToolCall populated from the canonical envelope, and function_call_output pieces to a role=tool message with tool_call_id set from the envelope's call_id and content set to the output. Matches the OpenAI Chat Completions wire shape. Also fixes ChatMessage.ToolCall whose 'function' field was typed as a bare string; OpenAI ships it as a nested object with name + arguments. ChatMessage.content now permits None for assistant messages that carry only tool_calls (the OpenAI API requires content=null in that shape). Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .../chat_message_normalizer.py | 90 ++++++++++++ pyrit/models/chat_message.py | 22 ++- .../test_chat_message_normalizer.py | 131 ++++++++++++++++++ tests/unit/models/test_chat_message.py | 34 ++++- 4 files changed, 269 insertions(+), 8 deletions(-) diff --git a/pyrit/message_normalizer/chat_message_normalizer.py b/pyrit/message_normalizer/chat_message_normalizer.py index c5d3547e80..c9e7c0c532 100644 --- a/pyrit/message_normalizer/chat_message_normalizer.py +++ b/pyrit/message_normalizer/chat_message_normalizer.py @@ -4,6 +4,7 @@ import base64 import json import os +from collections.abc import Sequence from typing import TYPE_CHECKING, Any, Union from pyrit.common.data_url_converter import convert_local_image_to_data_url_async @@ -14,6 +15,7 @@ apply_system_message_behavior, ) from pyrit.models import ChatMessage, DataTypeSerializer, Message +from pyrit.models.chat_message import ToolCall, ToolCallFunction from pyrit.models.message_piece import MessagePiece if TYPE_CHECKING: @@ -83,6 +85,11 @@ async def normalize_async(self, messages: list[Message]) -> list[ChatMessage]: chat_messages: list[ChatMessage] = [] for message in processed_messages: pieces = message.message_pieces + tool_message = self._try_build_tool_message(pieces=pieces) + if tool_message is not None: + chat_messages.append(tool_message) + continue + role: ChatMessageRole = pieces[0].api_role # Translate system -> developer for newer OpenAI models @@ -99,6 +106,89 @@ async def normalize_async(self, messages: list[Message]) -> list[ChatMessage]: return chat_messages + def _try_build_tool_message(self, *, pieces: Sequence[MessagePiece]) -> ChatMessage | None: + """ + Build an OpenAI Chat Completions tool message when ``pieces`` carries tool data. + + Returns a populated ``ChatMessage`` when the pieces are tool-call + envelopes (``function_call`` or ``function_call_output`` data type), + or ``None`` when the pieces are ordinary text / multimodal content. + + ``function_call`` pieces produce a single ``role="assistant"`` message + with ``content=None`` and one or more entries in ``tool_calls``. + ``function_call_output`` pieces produce a single ``role="tool"`` + message whose ``content`` is the output payload and whose + ``tool_call_id`` matches the originating call. + + Args: + pieces (list[MessagePiece]): The pieces making up one PyRIT message. + + Returns: + ChatMessage | None: ``None`` when no tool envelopes are present, + otherwise the converted tool message. + """ + if not pieces: + return None + data_types = {p.converted_value_data_type or p.original_value_data_type for p in pieces} + if data_types == {"function_call"}: + return ChatMessage( + role="assistant", + content=None, + tool_calls=[self._piece_to_tool_call(piece) for piece in pieces], + ) + if data_types == {"function_call_output"}: + # A single message carries one or more function_call_output pieces + # in declaration order; the OpenAI wire shape sends each as its + # own role="tool" message. For multi-piece tool messages, we + # surface the first piece here and let the caller emit additional + # messages — but in practice tool_loop emits one message per + # iteration with multiple pieces, and OpenAI accepts a single + # tool message per call_id. Emit the first envelope; warn if + # multiple are present. + envelope = self._decode_envelope(pieces[0]) + return ChatMessage( + role="tool", + content=str(envelope.get("output", "")), + tool_call_id=str(envelope["call_id"]), + ) + return None + + @staticmethod + def _decode_envelope(piece: MessagePiece) -> dict[str, Any]: + """ + Decode the canonical-envelope JSON carried in a tool piece. + + Args: + piece (MessagePiece): A piece whose ``converted_value`` is the + canonical-envelope JSON string. + + Returns: + dict[str, Any]: The parsed envelope. + """ + return json.loads(piece.converted_value) + + @classmethod + def _piece_to_tool_call(cls, piece: MessagePiece) -> ToolCall: + """ + Convert one canonical ``function_call`` piece into an OpenAI ToolCall. + + Args: + piece (MessagePiece): A piece carrying a canonical ``function_call`` + envelope. + + Returns: + ToolCall: The corresponding OpenAI Chat Completions tool call. + """ + envelope = cls._decode_envelope(piece) + return ToolCall( + id=str(envelope["call_id"]), + type="function", + function=ToolCallFunction( + name=str(envelope["name"]), + arguments=str(envelope["arguments"]), + ), + ) + async def normalize_string_async(self, messages: list[Message]) -> str: """ Convert a list of Messages to a JSON string representation. diff --git a/pyrit/models/chat_message.py b/pyrit/models/chat_message.py index c2f801862d..faf39cb908 100644 --- a/pyrit/models/chat_message.py +++ b/pyrit/models/chat_message.py @@ -10,13 +10,28 @@ ALLOWED_CHAT_MESSAGE_ROLES = ["system", "user", "assistant", "simulated_assistant", "tool", "developer"] +class ToolCallFunction(BaseModel): + """The ``function`` payload of an OpenAI Chat Completions tool call.""" + + model_config = ConfigDict(extra="forbid") + name: str + arguments: str + + class ToolCall(BaseModel): - """Represents a tool invocation requested by the assistant.""" + """ + Represents a tool invocation requested by the assistant. + + Matches the OpenAI Chat Completions API ``tool_calls`` shape: each entry + has a provider-issued ``id``, a ``type`` string (currently always + ``"function"``), and a nested ``function`` object carrying the tool + ``name`` and JSON-encoded ``arguments``. + """ model_config = ConfigDict(extra="forbid") id: str type: str - function: str + function: ToolCallFunction class ChatMessage(BaseModel): @@ -26,11 +41,12 @@ class ChatMessage(BaseModel): The content field can be: - A simple string for single-part text messages - A list of dicts for multipart messages (e.g., text + images) + - ``None`` for assistant messages whose payload is a tool-call only """ model_config = ConfigDict(extra="forbid") role: ChatMessageRole - content: Union[str, list[dict[str, Any]]] + content: Optional[Union[str, list[dict[str, Any]]]] = None name: Optional[str] = None tool_calls: Optional[list[ToolCall]] = None tool_call_id: Optional[str] = None diff --git a/tests/unit/message_normalizer/test_chat_message_normalizer.py b/tests/unit/message_normalizer/test_chat_message_normalizer.py index b9a7cec57b..decc589bcb 100644 --- a/tests/unit/message_normalizer/test_chat_message_normalizer.py +++ b/tests/unit/message_normalizer/test_chat_message_normalizer.py @@ -319,3 +319,134 @@ async def test_returns_list_of_dicts(self): assert isinstance(result[0], dict) assert result[0]["role"] == "user" assert result[0]["content"] == "Hello" + + +class TestChatMessageNormalizerToolPieces: + """Tool-call piece coverage: function_call -> assistant.tool_calls, + function_call_output -> role=tool message with tool_call_id.""" + + async def test_function_call_piece_becomes_assistant_tool_call_message(self): + normalizer = ChatMessageNormalizer() + envelope = { + "type": "function_call", + "call_id": "call_0", + "name": "echo", + "arguments": '{"text":"hi"}', + } + fc_piece = MessagePiece( + role="assistant", + original_value=json.dumps(envelope), + original_value_data_type="function_call", + converted_value_data_type="function_call", + ) + messages = [Message(message_pieces=[fc_piece])] + + result = await normalizer.normalize_async(messages) + + assert len(result) == 1 + assert result[0].role == "assistant" + assert result[0].content is None + assert result[0].tool_calls is not None + assert len(result[0].tool_calls) == 1 + assert result[0].tool_calls[0].id == "call_0" + assert result[0].tool_calls[0].type == "function" + assert result[0].tool_calls[0].function.name == "echo" + assert result[0].tool_calls[0].function.arguments == '{"text":"hi"}' + + async def test_function_call_output_piece_becomes_tool_role_message(self): + normalizer = ChatMessageNormalizer() + envelope = { + "type": "function_call_output", + "call_id": "call_0", + "output": '{"echoed":"hi"}', + } + fco_piece = MessagePiece( + role="tool", + original_value=json.dumps(envelope), + original_value_data_type="function_call_output", + converted_value_data_type="function_call_output", + ) + messages = [Message(message_pieces=[fco_piece], skip_validation=True)] + + result = await normalizer.normalize_async(messages) + + assert len(result) == 1 + assert result[0].role == "tool" + assert result[0].tool_call_id == "call_0" + assert result[0].content == '{"echoed":"hi"}' + + async def test_full_tool_conversation_round_trip(self): + """A user -> assistant fc -> tool fco -> assistant text conversation + normalizes into the canonical OpenAI Chat Completions wire shape.""" + normalizer = ChatMessageNormalizer() + + user = _make_message("user", "Use the echo tool to repeat 'hi'.") + + fc_envelope = { + "type": "function_call", + "call_id": "call_0", + "name": "echo", + "arguments": '{"text":"hi"}', + } + assistant_fc = Message( + message_pieces=[ + MessagePiece( + role="assistant", + original_value=json.dumps(fc_envelope), + original_value_data_type="function_call", + converted_value_data_type="function_call", + ) + ] + ) + + fco_envelope = { + "type": "function_call_output", + "call_id": "call_0", + "output": '{"echoed":"hi"}', + } + tool_msg = Message( + message_pieces=[ + MessagePiece( + role="tool", + original_value=json.dumps(fco_envelope), + original_value_data_type="function_call_output", + converted_value_data_type="function_call_output", + ) + ], + skip_validation=True, + ) + + assistant_final = _make_message("assistant", "The echoed text is: hi") + + result = await normalizer.normalize_async([user, assistant_fc, tool_msg, assistant_final]) + + assert [m.role for m in result] == ["user", "assistant", "tool", "assistant"] + assert result[0].content == "Use the echo tool to repeat 'hi'." + assert result[1].content is None + assert result[1].tool_calls[0].function.name == "echo" + assert result[2].tool_call_id == "call_0" + assert result[2].content == '{"echoed":"hi"}' + assert result[3].content == "The echoed text is: hi" + + async def test_function_call_output_serialized_to_dict_excludes_content_when_none(self): + """An assistant tool-call-only message must serialize without a content key.""" + normalizer = ChatMessageNormalizer() + envelope = { + "type": "function_call", + "call_id": "c1", + "name": "f", + "arguments": "{}", + } + msg = Message( + message_pieces=[ + MessagePiece( + role="assistant", + original_value=json.dumps(envelope), + original_value_data_type="function_call", + converted_value_data_type="function_call", + ) + ] + ) + dicts = await normalizer.normalize_to_dicts_async([msg]) + assert "content" not in dicts[0] + assert dicts[0]["tool_calls"][0]["function"]["name"] == "f" diff --git a/tests/unit/models/test_chat_message.py b/tests/unit/models/test_chat_message.py index 8391e9340b..9cee645a47 100644 --- a/tests/unit/models/test_chat_message.py +++ b/tests/unit/models/test_chat_message.py @@ -10,19 +10,30 @@ ChatMessage, ChatMessagesDataset, ToolCall, + ToolCallFunction, ) def test_tool_call_init(): - tc = ToolCall(id="call_1", type="function", function="get_weather") + tc = ToolCall( + id="call_1", + type="function", + function=ToolCallFunction(name="get_weather", arguments='{"city":"NYC"}'), + ) assert tc.id == "call_1" assert tc.type == "function" - assert tc.function == "get_weather" + assert tc.function.name == "get_weather" + assert tc.function.arguments == '{"city":"NYC"}' def test_tool_call_forbids_extra_fields(): with pytest.raises(ValidationError): - ToolCall(id="call_1", type="function", function="get_weather", extra="bad") + ToolCall( + id="call_1", + type="function", + function=ToolCallFunction(name="get_weather", arguments="{}"), + extra="bad", + ) def test_chat_message_init_with_string_content(): @@ -41,7 +52,7 @@ def test_chat_message_init_with_list_content(): def test_chat_message_init_with_all_fields(): - tc = ToolCall(id="call_1", type="function", function="lookup") + tc = ToolCall(id="call_1", type="function", function=ToolCallFunction(name="lookup", arguments="{}")) msg = ChatMessage( role="assistant", content="result", @@ -91,13 +102,26 @@ def test_chat_message_model_validate_json_roundtrip(): def test_chat_message_model_validate_json_roundtrip_with_tool_calls(): - tc = ToolCall(id="c1", type="function", function="fn") + tc = ToolCall(id="c1", type="function", function=ToolCallFunction(name="fn", arguments="{}")) original = ChatMessage(role="assistant", content="ok", tool_calls=[tc], tool_call_id="c1") restored = ChatMessage.model_validate_json(original.model_dump_json()) assert restored.tool_calls[0].id == "c1" + assert restored.tool_calls[0].function.name == "fn" assert restored.tool_call_id == "c1" +def test_chat_message_content_allows_none_for_tool_call_only_assistant_message(): + """OpenAI Chat Completions allows assistant messages with content=null when tool_calls is set.""" + tc = ToolCall(id="c1", type="function", function=ToolCallFunction(name="fn", arguments="{}")) + msg = ChatMessage(role="assistant", content=None, tool_calls=[tc]) + assert msg.content is None + assert msg.tool_calls == [tc] + dumped = msg.to_dict() + # content is None so it should be excluded from the serialized dict. + assert "content" not in dumped + assert dumped["tool_calls"][0]["function"]["name"] == "fn" + + @pytest.mark.parametrize("role", ["system", "user", "assistant", "simulated_assistant", "tool", "developer"]) def test_chat_message_accepts_all_valid_roles(role): msg = ChatMessage(role=role, content="test") From c99f44f4459b7247693d9879ded1d2ea912536e9 Mon Sep 17 00:00:00 2001 From: Victor Valbuena Date: Thu, 28 May 2026 17:00:28 -0700 Subject: [PATCH 14/17] Hoist _tool_schemas default onto PromptTarget The base default for _tool_schemas() now reads self.configuration.tool_backend.schemas verbatim. Subclasses that need wire-format wrapping (currently only OpenAIResponseTarget, which prepends type=function) override the method and reuse the base via super() to get the raw schemas. Removes a small but real duplication risk for the upcoming AzureMLChatTarget / HuggingFaceChatTarget tool-calling paths, which would otherwise each reimplement the 'read schemas from configured backend' boilerplate. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- pyrit/prompt_target/common/prompt_target.py | 23 +++++++++++-------- .../openai/openai_response_target.py | 6 +---- 2 files changed, 15 insertions(+), 14 deletions(-) diff --git a/pyrit/prompt_target/common/prompt_target.py b/pyrit/prompt_target/common/prompt_target.py index 035b00823d..9ff03df7b6 100644 --- a/pyrit/prompt_target/common/prompt_target.py +++ b/pyrit/prompt_target/common/prompt_target.py @@ -162,17 +162,22 @@ def _tool_schemas(self) -> list[dict[str, Any]]: """ Outbound tool-schema list sent on the next request to the model. - Targets that participate in the tool-use loop override this method - to translate the active :class:`~pyrit.tools.ToolBackend.schemas` - into the wire format their model expects (Responses API vs. Chat - Completions API vs. anything else). The base default returns an - empty list, which means no schemas are advertised. + The default reads the configured ``tool_backend.schemas`` verbatim. + Targets whose wire format wraps schemas differently (e.g., OpenAI + Chat Completions requires ``{"type": "function", "function": {...}}``; + the OpenAI Responses API requires ``{"type": "function", **schema}`` + spread at the top level) override this method to apply the + per-target translation. Returns: - list[dict[str, Any]]: One schema per advertised tool, in the - target-specific wire format. Empty by default. - """ - return [] + list[dict[str, Any]]: One schema per advertised tool, in + whatever wire format this target expects. Empty when no + backend is configured. + """ + backend = self.configuration.tool_backend + if backend is None: + return [] + return list(backend.schemas) def _validate_request(self, *, normalized_conversation: list[Message]) -> None: """ diff --git a/pyrit/prompt_target/openai/openai_response_target.py b/pyrit/prompt_target/openai/openai_response_target.py index 332b64847e..91f8bc05f0 100644 --- a/pyrit/prompt_target/openai/openai_response_target.py +++ b/pyrit/prompt_target/openai/openai_response_target.py @@ -38,7 +38,6 @@ from pyrit.tools import ( CanonicalEnvelopeParser, LocalToolBackend, - ToolBackend, ToolCallParser, ToolEventBehavior, ToolEventPolicy, @@ -731,10 +730,7 @@ def _tool_schemas(self) -> list[dict[str, Any]]: list[dict[str, Any]]: One descriptor per advertised tool, or an empty list when no backend is configured. """ - backend: ToolBackend | None = self.configuration.tool_backend - if backend is None: - return [] - return [{"type": "function", **schema} for schema in backend.schemas] + return [{"type": "function", **schema} for schema in super()._tool_schemas()] def _parse_response_output_section( self, *, section: Any, message_piece: MessagePiece, error: Optional[PromptResponseError] From 179d176574e50456f30fa018d3d5fecb5ee8bed0 Mon Sep 17 00:00:00 2001 From: Victor Valbuena Date: Fri, 29 May 2026 10:43:26 -0700 Subject: [PATCH 15/17] Add tool_parser and tool_backend kwargs to AzureMLChatTarget AzureMLChatTarget now participates in PyRIT's @tool_loop when callers supply a ToolCallParser at construction. The parser flips supports_tool_use=True on the default capabilities so callers don't need to construct a custom_configuration just to opt in. A convenience tool_backend kwarg installs the backend onto the configuration in one step. Wire format: _tool_schemas() wraps the backend's schemas in the OpenAI Chat Completions tools shape (with each schema nested under a "function" key). _construct_http_body_async injects the wrapped schemas as a top-level tools field when non-empty. Deployments unwrap that envelope before passing to tokenizer.apply_chat_template; see plan section 12.9 for the contract. Response handling: _complete_chat_async now returns the parsed JSON body (was: string output). The new _materialize_response walks the response dict and emits one text MessagePiece for the output field plus one function_call MessagePiece per envelope in the tool_calls field; CanonicalEnvelopeParser then finds those pieces in the loop's next iteration. The no-tools path is unchanged: requests without tool_parser produce byte-identical request bodies, verified by test_request_body_omits_tools_key_when_no_backend. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- pyrit/prompt_target/azure_ml_chat_target.py | 208 +++++++++++++++--- .../target/test_azure_ml_chat_target.py | 150 ++++++++++++- .../test_normalize_async_integration.py | 12 +- 3 files changed, 333 insertions(+), 37 deletions(-) diff --git a/pyrit/prompt_target/azure_ml_chat_target.py b/pyrit/prompt_target/azure_ml_chat_target.py index 1c9d54d913..59eefe116c 100644 --- a/pyrit/prompt_target/azure_ml_chat_target.py +++ b/pyrit/prompt_target/azure_ml_chat_target.py @@ -1,6 +1,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. +import json import logging from typing import Any @@ -18,6 +19,7 @@ from pyrit.message_normalizer import ChatMessageNormalizer, MessageListNormalizer from pyrit.models import ( Message, + MessagePiece, construct_response_from_request, ) from pyrit.prompt_target.common.prompt_target import PromptTarget @@ -29,6 +31,7 @@ ) from pyrit.prompt_target.common.target_configuration import TargetConfiguration from pyrit.prompt_target.common.utils import limit_requests_per_minute, validate_temperature, validate_top_p +from pyrit.tools import ToolBackend, ToolCallParser logger = logging.getLogger(__name__) @@ -70,6 +73,8 @@ def __init__( repetition_penalty: float = 1.0, max_requests_per_minute: int | None = None, custom_configuration: TargetConfiguration | None = None, + tool_parser: ToolCallParser | None = None, + tool_backend: ToolBackend | None = None, **param_kwargs: Any, ) -> None: """ @@ -100,6 +105,17 @@ def __init__( will be capped at the value provided. custom_configuration (TargetConfiguration | None): Override the default configuration for this target instance. Useful for targets whose capabilities depend on deployment configuration. + tool_parser (ToolCallParser | None): When supplied, the target opts into PyRIT's + ``@tool_loop`` and uses this parser to extract pending tool calls from the + response. Supplying a parser also enables the ``supports_tool_use`` capability + on the default configuration so callers don't have to construct a custom + configuration just to enable the loop. The parser's expectations about the + deployment's response shape MUST line up with the contract documented in + ``doc/code/targets/`` for tool-capable Azure ML deployments. + tool_backend (ToolBackend | None): Convenience kwarg that wires a tool backend + onto ``custom_configuration.tool_backend``. Equivalent to constructing a + ``TargetConfiguration`` with the backend assigned. When ``custom_configuration`` + already specifies a backend, the kwarg is rejected. **param_kwargs: Additional parameters to pass to the model for generating responses. Example parameters can be found here: https://huggingface.co/docs/api-inference/tasks/text-generation. Note that the link above may not be comprehensive, and specific acceptable parameters may be @@ -145,6 +161,18 @@ def __init__( normalizer_overrides={CapabilityName.SYSTEM_PROMPT: message_normalizer}, ) + # Enable tool-use capability when a parser is supplied so callers + # don't need to construct a custom_configuration just to opt in. + if tool_parser is not None: + custom_configuration = self._enable_tool_use(configuration=custom_configuration) + + # tool_backend is a convenience kwarg; install it into the configuration. + if tool_backend is not None: + custom_configuration = self._install_tool_backend( + configuration=custom_configuration, + tool_backend=tool_backend, + ) + PromptTarget.__init__( self, max_requests_per_minute=max_requests_per_minute, @@ -163,6 +191,76 @@ def __init__( self._top_p = top_p self._repetition_penalty = repetition_penalty self._extra_parameters = param_kwargs + self._tool_parser_instance = tool_parser + + def _enable_tool_use(self, *, configuration: TargetConfiguration | None) -> TargetConfiguration: + """ + Return a configuration whose capabilities include ``supports_tool_use=True``. + + When ``configuration`` already has the capability set, returns it as-is. + Otherwise rebuilds the capabilities with ``supports_tool_use=True`` flipped + on and preserves every other field. + + Args: + configuration (TargetConfiguration | None): The user-supplied configuration, + or ``None`` to start from the class default. + + Returns: + TargetConfiguration: A configuration whose capabilities include + ``supports_tool_use=True``. + """ + source = configuration if configuration is not None else self._DEFAULT_CONFIGURATION + caps = source.capabilities + if caps.includes(capability=CapabilityName.TOOL_USE): + return source + updated_caps = TargetCapabilities( + supports_multi_message_pieces=caps.supports_multi_message_pieces, + supports_editable_history=caps.supports_editable_history, + supports_multi_turn=caps.supports_multi_turn, + supports_system_prompt=caps.supports_system_prompt, + supports_tool_use=True, + input_modalities=caps.input_modalities, + output_modalities=caps.output_modalities, + ) + return TargetConfiguration( + capabilities=updated_caps, + policy=source.policy, + tool_event_policy=source.tool_event_policy, + tool_backend=source.tool_backend, + ) + + @staticmethod + def _install_tool_backend( + *, + configuration: TargetConfiguration | None, + tool_backend: ToolBackend, + ) -> TargetConfiguration: + """ + Install ``tool_backend`` onto ``configuration``. Rejects double-supply. + + Args: + configuration (TargetConfiguration | None): The user-supplied configuration. + tool_backend (ToolBackend): The backend to install. + + Returns: + TargetConfiguration: The same ``configuration`` instance with the + backend installed. + + Raises: + ValueError: When ``configuration`` is ``None`` (no capability to attach + to), or when ``configuration.tool_backend`` is already set to a + different backend. + """ + if configuration is None: + raise ValueError( + "tool_backend kwarg requires capabilities.supports_tool_use=True; " + "supply tool_parser= so the default capabilities flip TOOL_USE on, " + "or build a custom_configuration explicitly." + ) + if configuration.tool_backend is not None and configuration.tool_backend is not tool_backend: + raise ValueError("tool_backend kwarg conflicts with custom_configuration.tool_backend; supply only one.") + configuration.tool_backend = tool_backend + return configuration def _build_identifier(self) -> ComponentIdentifier: """ @@ -224,17 +322,10 @@ async def _send_prompt_to_target_async(self, *, normalized_conversation: list[Me logger.info(f"Sending the following prompt to the prompt target: {request}") try: - resp_text = await self._complete_chat_async( - messages=normalized_conversation, - ) - - if not resp_text: - raise EmptyResponseException(message="The chat returned an empty response.") - - response_entry = construct_response_from_request(request=request, response_text_pieces=[resp_text]) + response_body = await self._complete_chat_async(messages=normalized_conversation) + response_entry = self._materialize_response(response=response_body, request=request) except HTTPStatusError as hse: if hse.response.status_code == 400: - # Handle Bad Request response_entry = handle_bad_request_exception(response_text=hse.response.text, request=request) elif hse.response.status_code == 429: raise RateLimitException from hse @@ -248,21 +339,23 @@ async def _send_prompt_to_target_async(self, *, normalized_conversation: list[Me async def _complete_chat_async( self, messages: list[Message], - ) -> str: + ) -> dict[str, Any]: """ - Completes a chat interaction by generating a response to the given input prompt. - - This is a synchronous wrapper for the asynchronous _generate_and_extract_response method. + Issue a single chat request and return the parsed JSON response body. Args: messages (list[Message]): The message objects containing the role and content. + Returns: + dict[str, Any]: The deserialized response body. Always includes an + ``output`` field (per the AML scoring-script contract). Tool-capable + deployments may additionally include a ``tool_calls`` field carrying + canonical envelopes. + Raises: EmptyResponseException: If the response from the chat is empty. + ValueError: If the parsed response body is missing the ``output`` field. Exception: For any other errors during the process. - - Returns: - str: The generated response message. """ headers = self._get_headers() payload = await self._construct_http_body_async(messages) @@ -271,15 +364,52 @@ async def _complete_chat_async( endpoint_uri=self._endpoint, method="POST", request_body=payload, headers=headers ) - try: - return str(response.json()["output"]) - except Exception as e: - if response.json() == {}: - raise EmptyResponseException(message="The chat returned an empty response.") from e - raise type(e)( - f"Exception obtaining response from the target. Returned response: {response.json()}. " - f"Exception: {str(e)}" - ) from e + body = response.json() + if not isinstance(body, dict) or body == {}: + raise EmptyResponseException(message="The chat returned an empty response.") + if "output" not in body: + raise ValueError(f"Response from the target did not include 'output'. Returned response: {body}.") + return body + + def _materialize_response(self, *, response: dict[str, Any], request: MessagePiece) -> Message: + """ + Build a ``Message`` from the parsed response body, handling tool calls. + + The deployment may include a ``tool_calls`` list when the model emits + canonical envelopes. Each envelope becomes its own ``function_call`` + MessagePiece so the ``CanonicalEnvelopeParser`` shipped with PyRIT can + recognize it without further translation. + + Args: + response (dict[str, Any]): The parsed response body returned from the endpoint. + request (MessagePiece): The request piece used to stamp identity onto each + response piece. + + Returns: + Message: The materialized response message. Has at least one piece; + when both ``output`` and ``tool_calls`` are present, the text piece + comes first followed by one function_call piece per envelope. + + Raises: + EmptyResponseException: If the response has neither output text nor tool calls. + """ + text = str(response.get("output") or "") + tool_envelopes = response.get("tool_calls") or [] + if not text and not tool_envelopes: + raise EmptyResponseException(message="The chat returned an empty response.") + + pieces: list[MessagePiece] = [] + if text: + text_piece = construct_response_from_request(request=request, response_text_pieces=[text]).message_pieces[0] + pieces.append(text_piece) + for envelope in tool_envelopes: + fc_piece = construct_response_from_request( + request=request, + response_text_pieces=[json.dumps(envelope, separators=(",", ":"))], + response_type="function_call", + ).message_pieces[0] + pieces.append(fc_piece) + return Message(message_pieces=pieces, skip_validation=True) async def _construct_http_body_async( self, @@ -297,10 +427,7 @@ async def _construct_http_body_async( wire_format = ChatMessageNormalizer() messages_dict = await wire_format.normalize_to_dicts_async(messages) - # Parameters include additional ones passed in through **kwargs. Those not accepted by the model will - # be ignored. We only include commonly supported parameters here - model-specific parameters like - # stop sequences should be passed via **param_kwargs since different models use different EOS tokens. - return { + body: dict[str, Any] = { "input_data": { "input_string": messages_dict, "parameters": { @@ -312,6 +439,29 @@ async def _construct_http_body_async( | self._extra_parameters, } } + schemas = self._tool_schemas() + if schemas: + body["tools"] = schemas + return body + + @property + def _tool_parser(self) -> ToolCallParser | None: + """Return the parser supplied at construction, if any.""" + return self._tool_parser_instance + + def _tool_schemas(self) -> list[dict[str, Any]]: + """ + Wrap the backend's schemas in the OpenAI Chat Completions ``tools`` shape. + + Tool-capable deployments are expected to forward ``tools`` into + ``tokenizer.apply_chat_template`` after unwrapping the ``{"type": + "function", "function": {...}}`` envelope. + + Returns: + list[dict[str, Any]]: One descriptor per advertised tool, or an + empty list when no backend is configured. + """ + return [{"type": "function", "function": schema} for schema in super()._tool_schemas()] def _get_headers(self) -> dict[str, str]: """ diff --git a/tests/unit/prompt_target/target/test_azure_ml_chat_target.py b/tests/unit/prompt_target/target/test_azure_ml_chat_target.py index e9517d3ec8..b269894ba1 100644 --- a/tests/unit/prompt_target/target/test_azure_ml_chat_target.py +++ b/tests/unit/prompt_target/target/test_azure_ml_chat_target.py @@ -72,7 +72,7 @@ async def test_complete_chat_async(aml_online_chat: AzureMLChatTarget): mock_response.json.return_value = {"output": "extracted response"} mock.return_value = mock_response response = await aml_online_chat._complete_chat_async(messages) - assert response == "extracted response" + assert response == {"output": "extracted response"} mock.assert_called_once() @@ -90,7 +90,7 @@ async def test_complete_chat_async_with_default_normalizer( mock_response.json.return_value = {"output": "extracted response"} mock.return_value = mock_response response = await aml_online_chat._complete_chat_async(messages) - assert response == "extracted response" + assert response == {"output": "extracted response"} args, kwargs = mock.call_args body = kwargs["request_body"] @@ -107,9 +107,12 @@ async def test_complete_chat_async_bad_json_response(aml_online_chat: AzureMLCha with patch("pyrit.common.net_utility.make_request_and_raise_if_error_async", new_callable=AsyncMock) as mock: mock_response = MagicMock() + # Set is a non-dict body that previously raised TypeError when the code + # subscripted response.json()["output"]; the new code raises ValueError + # because the body is not a dict. mock_response.json.return_value = {"bad response"} mock.return_value = mock_response - with pytest.raises(TypeError): + with pytest.raises((TypeError, ValueError, EmptyResponseException)): await aml_online_chat._complete_chat_async(messages) @@ -178,8 +181,10 @@ async def test_send_prompt_async_rate_limit_exception_retries(aml_online_chat: A async def test_send_prompt_async_empty_response_retries(aml_online_chat: AzureMLChatTarget): response = MagicMock() response.status_code = 429 + # Return an empty dict; _materialize_response raises EmptyResponseException + # when both output and tool_calls are missing. mock_complete_chat_async = AsyncMock() - mock_complete_chat_async.return_value = None + mock_complete_chat_async.return_value = {} aml_online_chat._complete_chat_async = mock_complete_chat_async message = Message(message_pieces=[MessagePiece(role="user", conversation_id="12345", original_value="Hello")]) @@ -236,3 +241,140 @@ def test_valid_temperature_and_top_p(patch_central_database): ) assert target._temperature == 1.5 assert target._top_p == 0.9 + + +# --------------------------------------------------------------------------- +# Tool calling: tool_parser + tool_backend kwargs +# --------------------------------------------------------------------------- + + +@pytest.fixture +def echo_backend(): + from pyrit.tools import LocalToolBackend + + async def _echo(args): + return {"echoed": args.get("text", "")} + + return LocalToolBackend( + callables={"echo": _echo}, + schemas=[ + { + "name": "echo", + "description": "Echo back the given text.", + "parameters": { + "type": "object", + "properties": {"text": {"type": "string"}}, + "required": ["text"], + }, + } + ], + ) + + +def test_tool_parser_kwarg_flips_supports_tool_use_capability(patch_central_database): + from pyrit.prompt_target.common.target_capabilities import CapabilityName + from pyrit.tools import CanonicalEnvelopeParser + + target = AzureMLChatTarget( + endpoint="http://aml-test-endpoint.com", + api_key="k", + tool_parser=CanonicalEnvelopeParser(), + ) + assert target.configuration.includes(capability=CapabilityName.TOOL_USE) + assert target._tool_parser is not None + + +def test_no_tool_parser_leaves_supports_tool_use_off(aml_online_chat: AzureMLChatTarget): + from pyrit.prompt_target.common.target_capabilities import CapabilityName + + assert not aml_online_chat.configuration.includes(capability=CapabilityName.TOOL_USE) + assert aml_online_chat._tool_parser is None + + +def test_tool_backend_kwarg_installed_into_configuration(patch_central_database, echo_backend): + from pyrit.tools import CanonicalEnvelopeParser + + target = AzureMLChatTarget( + endpoint="http://aml-test-endpoint.com", + api_key="k", + tool_parser=CanonicalEnvelopeParser(), + tool_backend=echo_backend, + ) + assert target.configuration.tool_backend is echo_backend + + +def test_tool_backend_kwarg_without_parser_raises(patch_central_database, echo_backend): + # Without tool_parser, the default configuration has supports_tool_use=False, + # so attaching a backend must raise. + with pytest.raises(ValueError, match="supports_tool_use"): + AzureMLChatTarget( + endpoint="http://aml-test-endpoint.com", + api_key="k", + tool_backend=echo_backend, + ) + + +def test_tool_schemas_wraps_backend_schemas_in_chat_completions_shape(patch_central_database, echo_backend): + from pyrit.tools import CanonicalEnvelopeParser + + target = AzureMLChatTarget( + endpoint="http://aml-test-endpoint.com", + api_key="k", + tool_parser=CanonicalEnvelopeParser(), + tool_backend=echo_backend, + ) + schemas = target._tool_schemas() + assert len(schemas) == 1 + assert schemas[0]["type"] == "function" + assert schemas[0]["function"]["name"] == "echo" + + +def test_tool_schemas_empty_when_no_backend(aml_online_chat: AzureMLChatTarget): + assert aml_online_chat._tool_schemas() == [] + + +async def test_request_body_omits_tools_key_when_no_backend(aml_online_chat: AzureMLChatTarget): + messages = [Message(message_pieces=[MessagePiece(role="user", original_value="hi")])] + body = await aml_online_chat._construct_http_body_async(messages) + assert "tools" not in body + + +async def test_request_body_includes_tools_when_backend_set(patch_central_database, echo_backend): + from pyrit.tools import CanonicalEnvelopeParser + + target = AzureMLChatTarget( + endpoint="http://aml-test-endpoint.com", + api_key="k", + tool_parser=CanonicalEnvelopeParser(), + tool_backend=echo_backend, + ) + messages = [Message(message_pieces=[MessagePiece(role="user", original_value="hi")])] + body = await target._construct_http_body_async(messages) + assert "tools" in body + assert body["tools"][0]["function"]["name"] == "echo" + + +async def test_materialize_response_handles_text_and_tool_calls(patch_central_database, echo_backend): + from pyrit.tools import CanonicalEnvelopeParser + + target = AzureMLChatTarget( + endpoint="http://aml-test-endpoint.com", + api_key="k", + tool_parser=CanonicalEnvelopeParser(), + tool_backend=echo_backend, + ) + request = MessagePiece(role="user", original_value="hi", conversation_id="abc") + response = { + "output": "ok", + "tool_calls": [ + { + "type": "function_call", + "call_id": "call_0", + "name": "echo", + "arguments": '{"text":"hi"}', + } + ], + } + msg = target._materialize_response(response=response, request=request) + types = [p.original_value_data_type for p in msg.message_pieces] + assert types == ["text", "function_call"] diff --git a/tests/unit/prompt_target/target/test_normalize_async_integration.py b/tests/unit/prompt_target/target/test_normalize_async_integration.py index 2317bd705f..b62fa7a64b 100644 --- a/tests/unit/prompt_target/target/test_normalize_async_integration.py +++ b/tests/unit/prompt_target/target/test_normalize_async_integration.py @@ -229,7 +229,7 @@ async def test_azure_ml_target_calls_normalize_async(): with ( patch.object(target.configuration, "normalize_async", new_callable=AsyncMock) as mock_normalize, - patch.object(target, "_complete_chat_async", new_callable=AsyncMock, return_value="response"), + patch.object(target, "_complete_chat_async", new_callable=AsyncMock, return_value={"output": "response"}), ): mock_normalize.return_value = [user_msg] await target.send_prompt_async(message=user_msg) @@ -254,7 +254,9 @@ async def test_azure_ml_target_sends_normalized_to_complete_chat(): with ( patch.object(target.configuration, "normalize_async", new_callable=AsyncMock, return_value=[adapted_msg]), - patch.object(target, "_complete_chat_async", new_callable=AsyncMock, return_value="response") as mock_chat, + patch.object( + target, "_complete_chat_async", new_callable=AsyncMock, return_value={"output": "response"} + ) as mock_chat, ): await target.send_prompt_async(message=user_msg) @@ -294,7 +296,7 @@ async def test_azure_ml_target_memory_not_mutated(): mock_memory.get_conversation.return_value = memory_conversation target._memory = mock_memory - with patch.object(target, "_complete_chat_async", new_callable=AsyncMock, return_value="response"): + with patch.object(target, "_complete_chat_async", new_callable=AsyncMock, return_value={"output": "response"}): await target.send_prompt_async(message=user_msg) # Memory must still have original system message only (not mutated) @@ -386,7 +388,9 @@ async def test_azure_ml_system_squash_via_configuration_pipeline(): mock_memory.get_conversation.return_value = [system_msg] target._memory = mock_memory - with patch.object(target, "_complete_chat_async", new_callable=AsyncMock, return_value="response") as mock_chat: + with patch.object( + target, "_complete_chat_async", new_callable=AsyncMock, return_value={"output": "response"} + ) as mock_chat: await target.send_prompt_async(message=user_msg) # _complete_chat_async should receive normalized messages (system squashed into user) From 22dc6a0664699c9bb901d66970b1ad88d51b6a9d Mon Sep 17 00:00:00 2001 From: Victor Valbuena Date: Fri, 29 May 2026 10:48:44 -0700 Subject: [PATCH 16/17] Add tool_parser and tool_backend kwargs to HuggingFaceChatTarget Same shape as the AzureMLChatTarget F2 change: callers supply a ToolCallParser at construction; the parser flips supports_tool_use=True on the default capabilities so no custom_configuration is required to opt in. A convenience tool_backend kwarg installs the backend onto the configuration in one step. Wire format differs from AzureML because HuggingFace runs the model in-process via the transformers library: * _tool_schemas() returns the bare backend schemas (no OpenAI envelope) because tokenizer.apply_chat_template expects bare function schemas, not the Chat Completions wrapper. * _apply_chat_template forwards tools= into apply_chat_template when schemas are present; the model's tool-trained chat template renders the model-family-specific tools block (Qwen wraps in ..., Llama uses a system-message preamble, etc.). * _build_chat_messages now walks every piece in each message and converts function_call / function_call_output envelopes to the chat-template tool message shape (assistant + tool_calls list, role=tool + tool_call_id) so the model sees the canonical in-context tool conversation. The no-tools path is unchanged: without tool_parser, no tools key is passed to apply_chat_template and no tool message translation runs. The user-supplied tool_parser walks the response text for inline tool-call markers; InlineToolCallParser is the typical choice for ChatML-style angle-bracket markers, but the user can supply any ToolCallParser implementation (different marker regex, different mode). Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .../hugging_face/hugging_face_chat_target.py | 183 +++++++++++++++++- .../target/test_huggingface_chat_target.py | 160 +++++++++++++++ 2 files changed, 334 insertions(+), 9 deletions(-) diff --git a/pyrit/prompt_target/hugging_face/hugging_face_chat_target.py b/pyrit/prompt_target/hugging_face/hugging_face_chat_target.py index f2d62be82a..055cd55c70 100644 --- a/pyrit/prompt_target/hugging_face/hugging_face_chat_target.py +++ b/pyrit/prompt_target/hugging_face/hugging_face_chat_target.py @@ -22,9 +22,13 @@ from pyrit.identifiers import ComponentIdentifier from pyrit.models import Message, construct_response_from_request from pyrit.prompt_target.common.prompt_target import PromptTarget -from pyrit.prompt_target.common.target_capabilities import TargetCapabilities +from pyrit.prompt_target.common.target_capabilities import ( + CapabilityName, + TargetCapabilities, +) from pyrit.prompt_target.common.target_configuration import TargetConfiguration from pyrit.prompt_target.common.utils import limit_requests_per_minute +from pyrit.tools import ToolBackend, ToolCallParser logger = logging.getLogger(__name__) @@ -77,6 +81,8 @@ def __init__( attn_implementation: str | None = None, max_requests_per_minute: int | None = None, custom_configuration: TargetConfiguration | None = None, + tool_parser: ToolCallParser | None = None, + tool_backend: ToolBackend | None = None, ) -> None: """ Initialize the HuggingFaceChatTarget. @@ -108,6 +114,15 @@ def __init__( max_requests_per_minute (int | None): The maximum number of requests per minute. Defaults to None. custom_configuration (TargetConfiguration | None): Override the default configuration for this target instance. Defaults to None. + tool_parser (ToolCallParser | None): When supplied, the target opts into PyRIT's + ``@tool_loop`` and uses this parser to extract pending tool calls from each + generated response. Supplying a parser also enables ``supports_tool_use=True`` + on the default capabilities so callers don't need a custom_configuration just + to opt in. ``InlineToolCallParser`` is the typical choice because the local + tokenizer emits tool calls as inline marker-delimited JSON; supply a different + parser when targeting a chat template with a different marker syntax. + tool_backend (ToolBackend | None): Convenience kwarg that installs the backend + onto ``custom_configuration.tool_backend``. Raises: ValueError: If neither or both of `model_id` and `model_path` are provided. @@ -115,6 +130,16 @@ def __init__( """ model_name = model_id if model_id else model_path if model_path else "" + # Enable tool-use capability when a parser is supplied, BEFORE super().__init__ + # so the configuration is correct by the time the base class records it. + if tool_parser is not None: + custom_configuration = self._enable_tool_use(configuration=custom_configuration) + if tool_backend is not None: + custom_configuration = self._install_tool_backend( + configuration=custom_configuration, + tool_backend=tool_backend, + ) + super().__init__( max_requests_per_minute=max_requests_per_minute, model_name=model_name, @@ -174,6 +199,81 @@ def __init__( raise RuntimeError("CUDA requested but not available.") self.load_model_and_tokenizer_task = asyncio.create_task(self.load_model_and_tokenizer()) + self._tool_parser_instance = tool_parser + + @classmethod + def _enable_tool_use(cls, *, configuration: TargetConfiguration | None) -> TargetConfiguration: + """ + Return a configuration whose capabilities include ``supports_tool_use=True``. + + When ``configuration`` already has the capability set, returns it as-is. + Otherwise rebuilds the capabilities with ``supports_tool_use=True`` and + preserves every other field. + + Args: + configuration (TargetConfiguration | None): The user-supplied configuration, + or ``None`` to start from the class default. + + Returns: + TargetConfiguration: A configuration whose capabilities include + ``supports_tool_use=True``. + """ + source = configuration if configuration is not None else cls._DEFAULT_CONFIGURATION + caps = source.capabilities + if caps.includes(capability=CapabilityName.TOOL_USE): + return source + updated_caps = TargetCapabilities( + supports_multi_message_pieces=caps.supports_multi_message_pieces, + supports_editable_history=caps.supports_editable_history, + supports_multi_turn=caps.supports_multi_turn, + supports_system_prompt=caps.supports_system_prompt, + supports_tool_use=True, + input_modalities=caps.input_modalities, + output_modalities=caps.output_modalities, + ) + return TargetConfiguration( + capabilities=updated_caps, + policy=source.policy, + tool_event_policy=source.tool_event_policy, + tool_backend=source.tool_backend, + ) + + @staticmethod + def _install_tool_backend( + *, + configuration: TargetConfiguration | None, + tool_backend: ToolBackend, + ) -> TargetConfiguration: + """ + Install ``tool_backend`` onto ``configuration``. Rejects double-supply. + + Args: + configuration (TargetConfiguration | None): The user-supplied configuration. + tool_backend (ToolBackend): The backend to install. + + Returns: + TargetConfiguration: The same ``configuration`` instance with the + backend installed. + + Raises: + ValueError: When ``configuration`` is ``None``, or when + ``configuration.tool_backend`` is already set to a different backend. + """ + if configuration is None: + raise ValueError( + "tool_backend kwarg requires capabilities.supports_tool_use=True; " + "supply tool_parser= so the default capabilities flip TOOL_USE on, " + "or build a custom_configuration explicitly." + ) + if configuration.tool_backend is not None and configuration.tool_backend is not tool_backend: + raise ValueError("tool_backend kwarg conflicts with custom_configuration.tool_backend; supply only one.") + configuration.tool_backend = tool_backend + return configuration + + @property + def _tool_parser(self) -> ToolCallParser | None: + """Return the parser supplied at construction, if any.""" + return self._tool_parser_instance def _build_identifier(self) -> ComponentIdentifier: """ @@ -401,27 +501,79 @@ async def _send_prompt_to_target_async(self, *, normalized_conversation: list[Me logger.error(f"Error occurred during inference: {e}") raise - def _build_chat_messages(self, *, normalized_conversation: list[Message]) -> list[dict[str, str]]: + def _build_chat_messages(self, *, normalized_conversation: list[Message]) -> list[dict[str, Any]]: """ Build a list of chat message dicts from the full normalized conversation. Includes system, user, and assistant messages from the conversation history - so that the model's chat template receives the complete context. + so that the model's chat template receives the complete context. When the + conversation contains tool-call envelopes (produced by ``@tool_loop``), they + are converted into the chat-template's tool message shape: + + * ``function_call`` pieces become an ``assistant`` message with a + ``tool_calls`` list (matching the HuggingFace ``apply_chat_template`` + convention; templates that don't recognize ``tool_calls`` fall back to + rendering the embedded JSON as content). + * ``function_call_output`` pieces become a ``role=tool`` message with the + tool result as content and ``tool_call_id`` carried for templates that + need it. Args: normalized_conversation (list[Message]): The full normalized conversation. Returns: - list[dict[str, str]]: Messages formatted for the chat template. + list[dict[str, Any]]: Messages formatted for the chat template. """ - messages: list[dict[str, str]] = [] + messages: list[dict[str, Any]] = [] for msg in normalized_conversation: - piece = msg.message_pieces[0] - role = piece.api_role - content = piece.converted_value or "" - messages.append({"role": role, "content": content}) + for piece in msg.message_pieces: + tool_dict = self._maybe_tool_chat_message(piece=piece) + if tool_dict is not None: + messages.append(tool_dict) + continue + role = piece.api_role + content = piece.converted_value or "" + messages.append({"role": role, "content": content}) return messages + @staticmethod + def _maybe_tool_chat_message(*, piece: Any) -> dict[str, Any] | None: + """ + Convert a ``function_call`` or ``function_call_output`` piece to a chat-template message. + + Args: + piece (Any): The MessagePiece to inspect. + + Returns: + dict[str, Any] | None: A chat-template message dict (``assistant`` with + ``tool_calls``, or ``role=tool`` with ``tool_call_id``) when the + piece carries a tool envelope, otherwise ``None``. + """ + data_type = piece.converted_value_data_type or piece.original_value_data_type + if data_type not in ("function_call", "function_call_output"): + return None + envelope = json.loads(piece.converted_value) + if data_type == "function_call": + return { + "role": "assistant", + "content": "", + "tool_calls": [ + { + "id": envelope.get("call_id", ""), + "type": "function", + "function": { + "name": envelope.get("name", ""), + "arguments": envelope.get("arguments", "{}"), + }, + } + ], + } + return { + "role": "tool", + "content": str(envelope.get("output", "")), + "tool_call_id": envelope.get("call_id", ""), + } + def set_random_seed(self, random_seed: int) -> None: """ Set a new random seed and immediately re-seed the RNG. @@ -520,6 +672,13 @@ def _apply_chat_template(self, messages: list[dict[str, str]]) -> Any: """ Apply the chat template to the input messages and tokenize them. + When ``self._tool_schemas()`` is non-empty, the schemas are forwarded + into ``apply_chat_template`` so tool-trained chat templates can render + the model-family-specific tools block (Qwen wraps in ``...``, + Llama uses a system-message preamble, etc.). The model can then emit + tool calls in its native marker syntax which the user-supplied + ``tool_parser`` extracts. + Args: messages: The input messages to apply the chat template to. @@ -533,6 +692,11 @@ def _apply_chat_template(self, messages: list[dict[str, str]]) -> Any: if hasattr(self.tokenizer, "chat_template") and self.tokenizer.chat_template is not None: logger.info("Tokenizer has a chat template. Applying it to the input messages.") + template_kwargs: dict[str, Any] = {} + schemas = self._tool_schemas() + if schemas: + template_kwargs["tools"] = schemas + # Apply the chat template to format and tokenize the messages return cast( "BatchEncoding", @@ -542,6 +706,7 @@ def _apply_chat_template(self, messages: list[dict[str, str]]) -> Any: add_generation_prompt=True, return_tensors=self.tensor_format, return_dict=True, + **template_kwargs, ), ).to(self.device) error_message = ( diff --git a/tests/unit/prompt_target/target/test_huggingface_chat_target.py b/tests/unit/prompt_target/target/test_huggingface_chat_target.py index 93a4ca912f..4cddd5dbd6 100644 --- a/tests/unit/prompt_target/target/test_huggingface_chat_target.py +++ b/tests/unit/prompt_target/target/test_huggingface_chat_target.py @@ -578,3 +578,163 @@ async def test_effective_generation_config_in_metadata(): assert effective_config["temperature"] == 1.0 # Model defaults should also be present assert effective_config["eos_token_id"] == 2 + + +# --------------------------------------------------------------------------- +# Tool calling (F2): tool_parser + tool_backend kwargs +# --------------------------------------------------------------------------- + + +@pytest.fixture +def echo_backend(): + from pyrit.tools import LocalToolBackend + + async def _echo(args): + return {"echoed": args.get("text", "")} + + return LocalToolBackend( + callables={"echo": _echo}, + schemas=[ + { + "name": "echo", + "description": "Echo back the given text.", + "parameters": { + "type": "object", + "properties": {"text": {"type": "string"}}, + "required": ["text"], + }, + } + ], + ) + + +@pytest.mark.skipif(not is_torch_installed(), reason="torch is not installed") +def test_tool_parser_kwarg_flips_supports_tool_use_capability(patch_central_database): + from pyrit.prompt_target.common.target_capabilities import CapabilityName + from pyrit.tools import InlineToolCallParser + + target = HuggingFaceChatTarget(model_id="test_model", use_cuda=False, tool_parser=InlineToolCallParser()) + assert target.configuration.includes(capability=CapabilityName.TOOL_USE) + assert target._tool_parser is not None + + +@pytest.mark.skipif(not is_torch_installed(), reason="torch is not installed") +def test_no_tool_parser_leaves_supports_tool_use_off(patch_central_database): + from pyrit.prompt_target.common.target_capabilities import CapabilityName + + target = HuggingFaceChatTarget(model_id="test_model", use_cuda=False) + assert not target.configuration.includes(capability=CapabilityName.TOOL_USE) + assert target._tool_parser is None + + +@pytest.mark.skipif(not is_torch_installed(), reason="torch is not installed") +def test_tool_backend_kwarg_installed_into_configuration(patch_central_database, echo_backend): + from pyrit.tools import InlineToolCallParser + + target = HuggingFaceChatTarget( + model_id="test_model", + use_cuda=False, + tool_parser=InlineToolCallParser(), + tool_backend=echo_backend, + ) + assert target.configuration.tool_backend is echo_backend + + +@pytest.mark.skipif(not is_torch_installed(), reason="torch is not installed") +def test_tool_backend_kwarg_without_parser_raises(patch_central_database, echo_backend): + with pytest.raises(ValueError, match="supports_tool_use"): + HuggingFaceChatTarget(model_id="test_model", use_cuda=False, tool_backend=echo_backend) + + +@pytest.mark.skipif(not is_torch_installed(), reason="torch is not installed") +def test_tool_schemas_returns_bare_backend_schemas(patch_central_database, echo_backend): + """HF chat templates accept bare schemas (no OpenAI envelope).""" + from pyrit.tools import InlineToolCallParser + + target = HuggingFaceChatTarget( + model_id="test_model", + use_cuda=False, + tool_parser=InlineToolCallParser(), + tool_backend=echo_backend, + ) + schemas = target._tool_schemas() + assert len(schemas) == 1 + assert schemas[0]["name"] == "echo" + # Unlike AzureMLChatTarget, no {"type": "function", "function": {...}} wrapper. + assert "function" not in schemas[0] + + +@pytest.mark.skipif(not is_torch_installed(), reason="torch is not installed") +def test_build_chat_messages_translates_function_call_piece(patch_central_database): + target = HuggingFaceChatTarget(model_id="test_model", use_cuda=False) + fc_envelope = { + "type": "function_call", + "call_id": "call_0", + "name": "echo", + "arguments": '{"text":"hi"}', + } + msg = Message( + message_pieces=[ + MessagePiece( + role="assistant", + original_value=json.dumps(fc_envelope), + original_value_data_type="function_call", + converted_value_data_type="function_call", + ) + ] + ) + chat_messages = target._build_chat_messages(normalized_conversation=[msg]) + assert len(chat_messages) == 1 + assert chat_messages[0]["role"] == "assistant" + assert chat_messages[0]["tool_calls"][0]["function"]["name"] == "echo" + assert chat_messages[0]["tool_calls"][0]["function"]["arguments"] == '{"text":"hi"}' + + +@pytest.mark.skipif(not is_torch_installed(), reason="torch is not installed") +def test_build_chat_messages_translates_function_call_output_piece(patch_central_database): + target = HuggingFaceChatTarget(model_id="test_model", use_cuda=False) + fco_envelope = { + "type": "function_call_output", + "call_id": "call_0", + "output": '{"echoed":"hi"}', + } + msg = Message( + message_pieces=[ + MessagePiece( + role="tool", + original_value=json.dumps(fco_envelope), + original_value_data_type="function_call_output", + converted_value_data_type="function_call_output", + ) + ], + skip_validation=True, + ) + chat_messages = target._build_chat_messages(normalized_conversation=[msg]) + assert len(chat_messages) == 1 + assert chat_messages[0]["role"] == "tool" + assert chat_messages[0]["tool_call_id"] == "call_0" + assert chat_messages[0]["content"] == '{"echoed":"hi"}' + + +@pytest.mark.skipif(not is_torch_installed(), reason="torch is not installed") +def test_apply_chat_template_forwards_tools_when_present(patch_central_database, echo_backend): + from pyrit.tools import InlineToolCallParser + + target = HuggingFaceChatTarget( + model_id="test_model", + use_cuda=False, + tool_parser=InlineToolCallParser(), + tool_backend=echo_backend, + ) + target._apply_chat_template([{"role": "user", "content": "hi"}]) + call_kwargs = target.tokenizer.apply_chat_template.call_args.kwargs + assert "tools" in call_kwargs + assert call_kwargs["tools"][0]["name"] == "echo" + + +@pytest.mark.skipif(not is_torch_installed(), reason="torch is not installed") +def test_apply_chat_template_omits_tools_when_no_backend(patch_central_database): + target = HuggingFaceChatTarget(model_id="test_model", use_cuda=False) + target._apply_chat_template([{"role": "user", "content": "hi"}]) + call_kwargs = target.tokenizer.apply_chat_template.call_args.kwargs + assert "tools" not in call_kwargs From 4e30e618d3b768455bd11ee80d109999d0e3443d Mon Sep 17 00:00:00 2001 From: Victor Valbuena Date: Fri, 29 May 2026 10:50:11 -0700 Subject: [PATCH 17/17] Add AzureMLChatTarget integration test for tool-loop end-to-end MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds tests/integration/tools/test_azure_ml_with_tools_integration.py exercising the full PyRIT @tool_loop stack against AzureMLChatTarget with only the HTTP layer mocked. The mocked responses match the §12.9.2 canonical envelope shape: first response carries a tool_calls field that the loop dispatches via LocalToolBackend; second response is the final assistant text. Asserts the canonical four-piece transcript shape persists in Memory: [user text, assistant function_call, tool function_call_output, assistant text], with the call_id round-tripping between the assistant function_call piece and the tool function_call_output piece, and the tool output reflecting the actual dispatched callable's return value. Also covers the no-tools backward-compatibility path: a target constructed without tool_parser produces a request body that has no tools key, proving the F2 changes do not regress existing AzureML deployments that don't carry the patched scoring script. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .../test_azure_ml_with_tools_integration.py | 162 ++++++++++++++++++ 1 file changed, 162 insertions(+) create mode 100644 tests/integration/tools/test_azure_ml_with_tools_integration.py diff --git a/tests/integration/tools/test_azure_ml_with_tools_integration.py b/tests/integration/tools/test_azure_ml_with_tools_integration.py new file mode 100644 index 0000000000..9f0b4903b7 --- /dev/null +++ b/tests/integration/tools/test_azure_ml_with_tools_integration.py @@ -0,0 +1,162 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +"""F2 integration test: AzureMLChatTarget end-to-end through @tool_loop. + +Validates that PyRIT's full client-side tool-calling stack works against +an AzureML chat target whose scoring script emits the canonical +function_call envelope per plan §12.9.2. Only the outbound HTTP layer is +mocked; the @tool_loop decorator, CanonicalEnvelopeParser, LocalToolBackend +dispatch, ChatMessageNormalizer tool-piece serialization, and Memory +persistence all run unmocked. + +The test asserts the canonical four-piece transcript shape: +``[user, assistant function_call, tool function_call_output, assistant text]``. +""" + +from __future__ import annotations + +import json +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from pyrit.memory import CentralMemory +from pyrit.models import Message, MessagePiece +from pyrit.prompt_normalizer import PromptNormalizer +from pyrit.prompt_target import AzureMLChatTarget +from pyrit.prompt_target.common.target_capabilities import TargetCapabilities +from pyrit.prompt_target.common.target_configuration import TargetConfiguration +from pyrit.tools import ( + CanonicalEnvelopeParser, + LocalToolBackend, + ToolEventBehavior, + ToolEventPolicy, +) + + +@pytest.fixture +def echo_backend(): + async def _echo(args): + return {"echoed": args.get("text", "")} + + return LocalToolBackend( + callables={"echo": _echo}, + schemas=[ + { + "name": "echo", + "description": "Echo back the given text.", + "parameters": { + "type": "object", + "properties": {"text": {"type": "string"}}, + "required": ["text"], + }, + } + ], + ) + + +@pytest.mark.run_only_if_all_tests +async def test_azure_ml_chat_target_tool_loop_round_trip(patch_central_database, echo_backend): + """User asks for a tool call; the loop dispatches; the model produces final text.""" + target = AzureMLChatTarget( + endpoint="https://mock-endpoint.example.com/score", + api_key="dummy", + tool_parser=CanonicalEnvelopeParser(), + tool_backend=echo_backend, + custom_configuration=TargetConfiguration( + capabilities=TargetCapabilities( + supports_multi_message_pieces=True, + supports_editable_history=True, + supports_multi_turn=True, + supports_system_prompt=True, + supports_tool_use=True, + ), + tool_event_policy=ToolEventPolicy( + behavior=ToolEventBehavior.EXECUTE, + max_tool_iterations=3, + ), + tool_backend=echo_backend, + ), + ) + + first_response = MagicMock() + first_response.json.return_value = { + "output": "", + "tool_calls": [ + { + "type": "function_call", + "call_id": "call_0", + "name": "echo", + "arguments": '{"text":"hi"}', + } + ], + } + second_response = MagicMock() + second_response.json.return_value = {"output": "The echoed text is: hi"} + + user_msg = Message( + message_pieces=[ + MessagePiece( + role="user", + original_value="Use the echo tool to repeat 'hi'.", + original_value_data_type="text", + ) + ] + ) + + with patch( + "pyrit.common.net_utility.make_request_and_raise_if_error_async", + new_callable=AsyncMock, + ) as mock_http: + mock_http.side_effect = [first_response, second_response] + normalizer = PromptNormalizer() + result = await normalizer.send_prompt_async(message=user_msg, target=target) + + assert mock_http.call_count == 2 + final_text = result.get_value() + assert "The echoed text is" in final_text + + conv = CentralMemory.get_memory_instance().get_conversation( + conversation_id=result.message_pieces[0].conversation_id + ) + # Canonical four-piece transcript: user -> assistant fc -> tool fco -> assistant text. + flat_pieces = [p for msg in conv for p in msg.message_pieces] + types = [p.original_value_data_type for p in flat_pieces] + roles = [p.api_role for p in flat_pieces] + assert types == ["text", "function_call", "function_call_output", "text"] + assert roles == ["user", "assistant", "tool", "assistant"] + + # call_id round-trips between the assistant and tool messages. + fc_envelope = json.loads(flat_pieces[1].original_value) + fco_envelope = json.loads(flat_pieces[2].original_value) + assert fc_envelope["call_id"] == fco_envelope["call_id"] == "call_0" + # Tool dispatched the local echo callable; the output reflects the args. + assert json.loads(fco_envelope["output"])["echoed"] == "hi" + + +@pytest.mark.run_only_if_all_tests +async def test_azure_ml_chat_target_no_tools_backward_compat(patch_central_database): + """Without tool_parser / tool_backend, the request body has no tools key.""" + target = AzureMLChatTarget( + endpoint="https://mock-endpoint.example.com/score", + api_key="dummy", + ) + user_msg = Message( + message_pieces=[MessagePiece(role="user", original_value="Hello", original_value_data_type="text")] + ) + + response = MagicMock() + response.json.return_value = {"output": "Hi back"} + + with patch( + "pyrit.common.net_utility.make_request_and_raise_if_error_async", + new_callable=AsyncMock, + ) as mock_http: + mock_http.return_value = response + normalizer = PromptNormalizer() + result = await normalizer.send_prompt_async(message=user_msg, target=target) + + assert result.get_value() == "Hi back" + body = mock_http.call_args.kwargs["request_body"] + assert "tools" not in body