diff --git a/CHANGELOG.md b/CHANGELOG.md index e358b871f..ac792bc2a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +- fix(example): support multi-step Responses tool streaming by @abetlen in #2288 - fix(ci): Repair Linux accelerator wheels for manylinux publishing ## [0.3.28] diff --git a/examples/server/server.py b/examples/server/server.py index 28fc8f4eb..fb00501cf 100644 --- a/examples/server/server.py +++ b/examples/server/server.py @@ -2812,6 +2812,7 @@ def to_chat_template_tool(self) -> ChatTemplateTool: class ResponsesCustomToolFormat(BaseModel): model_config = ConfigDict(extra="ignore") + type: Optional[str] = None syntax: Optional[str] = None definition: Optional[str] = None @@ -2880,10 +2881,24 @@ class ResponsesWebSearchTool(BaseModel): type: Literal["web_search"] +class ResponsesNamespaceTool(BaseModel): + model_config = ConfigDict(extra="ignore") + + type: Literal["namespace"] + + +class ResponsesImageGenerationTool(BaseModel): + model_config = ConfigDict(extra="ignore") + + type: Literal["image_generation"] + + ResponsesToolDefinition = Union[ ResponsesFunctionTool, ResponsesCustomTool, ResponsesWebSearchTool, + ResponsesNamespaceTool, + ResponsesImageGenerationTool, ] @@ -5069,6 +5084,68 @@ def _tool_content_type(self, tool_name: str) -> Optional[str]: return content_type return None + def _raw_string_tool_arguments(self, tool_name: str, value: str) -> Optional[Dict[str, str]]: + if self._tools is None: + return None + for tool in self._tools: + if tool.get("type") != "function": + continue + function = tool.get("function", {}) + if function.get("name") != tool_name: + continue + parameters = function.get("parameters") + if not isinstance(parameters, dict): + return None + required = parameters.get("required") + if not isinstance(required, list) or len(required) != 1: + return None + argument_name = required[0] + if not isinstance(argument_name, str): + return None + properties = parameters.get("properties") + if not isinstance(properties, dict): + return None + argument_schema = properties.get(argument_name) + if not isinstance(argument_schema, dict): + return None + argument_type = argument_schema.get("type") + if argument_type == "string" or ( + isinstance(argument_type, list) and "string" in argument_type + ): + return {argument_name: value} + return None + return None + + @classmethod + def _raw_object_tool_arguments(cls, value: str) -> Optional[Dict[str, Any]]: + candidates = [value] + stripped = value.strip() + if stripped.startswith("{{") and stripped.endswith("}}"): + candidates.append(stripped[1:-1]) + for candidate in candidates: + normalized = cls._gemma4_tool_call_to_json(candidate) + for allow_partial in (False, True): + try: + parsed = from_json(normalized, allow_partial=allow_partial) + except ValueError: + continue + if isinstance(parsed, dict): + return { + key: cls._trim_partial_gemma_quote_marker(value) + if isinstance(value, str) + else value + for key, value in parsed.items() + } + return None + + @staticmethod + def _trim_partial_gemma_quote_marker(value: str) -> str: + quote_marker = '<|"|>' + for prefix_length in range(len(quote_marker) - 1, 0, -1): + if value.endswith(quote_marker[:prefix_length]): + return value[:-prefix_length] + return value + def _has_text_tools(self) -> bool: return any( isinstance(tool_schema, dict) and tool_schema.get("content_type") == "text" @@ -5637,6 +5714,18 @@ def _advance_direct_stream_state(self, text: str) -> Tuple[bool, List[Dict[str, self._direct.saw_tool_calls = saw_tool_calls self._direct.done = done return True, deltas + if leading_capture_field is not None: + if buffer.startswith(leading_capture_start): + buffer = buffer[len(leading_capture_start) :] + mode = self.DIRECT_MODE_LEADING_CAPTURE + continue + if leading_capture_start.startswith(buffer): + self._direct.pending = buffer + self._direct.mode = mode + self._direct.tool_call_count = tool_call_count + self._direct.saw_tool_calls = saw_tool_calls + self._direct.done = done + return True, deltas if buffer.startswith(iterator_start): saw_tool_calls = True self._start_direct_tool_call(tool_call_count) @@ -6302,6 +6391,16 @@ def _advance_stream_state(self, text: str) -> Tuple[bool, List[Dict[str, Any]]]: if not buffer: state.pending = "" return True, deltas + leading_capture = plan.get("leading_capture") + if leading_capture is not None: + capture_start = leading_capture["start"] + if buffer.startswith(capture_start): + buffer = buffer[len(capture_start) :] + state.mode = "leading-capture" + continue + if capture_start.startswith(buffer): + state.pending = buffer + return True, deltas if buffer.startswith(iterator_start): item_state = self._new_tool_call_state(plan["iterator"]["item"]) state.saw_tool_calls = True @@ -6866,6 +6965,10 @@ def _normalize_tool_call_item( }, } arguments = function.get("arguments", {}) + if isinstance(arguments, str): + arguments = self._raw_object_tool_arguments(arguments) or self._raw_string_tool_arguments( + tool_name, arguments + ) if not isinstance(arguments, (dict, ResponseParser.PartialJsonObject)): if partial: return None @@ -8009,7 +8112,14 @@ def _responses_tools_to_chat_tools( return None chat_tools: List[ChatTemplateTool] = [] for tool in tools: - if isinstance(tool, ResponsesWebSearchTool): + if isinstance( + tool, + ( + ResponsesWebSearchTool, + ResponsesNamespaceTool, + ResponsesImageGenerationTool, + ), + ): continue if isinstance(tool, ResponsesFunctionTool): chat_tools.append(tool.to_chat_template_tool())