Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [Unreleased]

- fix(example): support multi-step Responses tool streaming by @abetlen in #2288
- fix(ci): Repair Linux accelerator wheels for manylinux publishing

## [0.3.28]
Expand Down
112 changes: 111 additions & 1 deletion examples/server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -2812,6 +2812,7 @@ def to_chat_template_tool(self) -> ChatTemplateTool:
class ResponsesCustomToolFormat(BaseModel):
model_config = ConfigDict(extra="ignore")

type: Optional[str] = None
syntax: Optional[str] = None
definition: Optional[str] = None

Expand Down Expand Up @@ -2880,10 +2881,24 @@ class ResponsesWebSearchTool(BaseModel):
type: Literal["web_search"]


class ResponsesNamespaceTool(BaseModel):
model_config = ConfigDict(extra="ignore")

type: Literal["namespace"]


class ResponsesImageGenerationTool(BaseModel):
model_config = ConfigDict(extra="ignore")

type: Literal["image_generation"]


ResponsesToolDefinition = Union[
ResponsesFunctionTool,
ResponsesCustomTool,
ResponsesWebSearchTool,
ResponsesNamespaceTool,
ResponsesImageGenerationTool,
]


Expand Down Expand Up @@ -5069,6 +5084,68 @@ def _tool_content_type(self, tool_name: str) -> Optional[str]:
return content_type
return None

def _raw_string_tool_arguments(self, tool_name: str, value: str) -> Optional[Dict[str, str]]:
if self._tools is None:
return None
for tool in self._tools:
if tool.get("type") != "function":
continue
function = tool.get("function", {})
if function.get("name") != tool_name:
continue
parameters = function.get("parameters")
if not isinstance(parameters, dict):
return None
required = parameters.get("required")
if not isinstance(required, list) or len(required) != 1:
return None
argument_name = required[0]
if not isinstance(argument_name, str):
return None
properties = parameters.get("properties")
if not isinstance(properties, dict):
return None
argument_schema = properties.get(argument_name)
if not isinstance(argument_schema, dict):
return None
argument_type = argument_schema.get("type")
if argument_type == "string" or (
isinstance(argument_type, list) and "string" in argument_type
):
return {argument_name: value}
return None
return None

@classmethod
def _raw_object_tool_arguments(cls, value: str) -> Optional[Dict[str, Any]]:
candidates = [value]
stripped = value.strip()
if stripped.startswith("{{") and stripped.endswith("}}"):
candidates.append(stripped[1:-1])
for candidate in candidates:
normalized = cls._gemma4_tool_call_to_json(candidate)
for allow_partial in (False, True):
try:
parsed = from_json(normalized, allow_partial=allow_partial)
except ValueError:
continue
if isinstance(parsed, dict):
return {
key: cls._trim_partial_gemma_quote_marker(value)
if isinstance(value, str)
else value
for key, value in parsed.items()
}
return None

@staticmethod
def _trim_partial_gemma_quote_marker(value: str) -> str:
quote_marker = '<|"|>'
for prefix_length in range(len(quote_marker) - 1, 0, -1):
if value.endswith(quote_marker[:prefix_length]):
return value[:-prefix_length]
return value

def _has_text_tools(self) -> bool:
return any(
isinstance(tool_schema, dict) and tool_schema.get("content_type") == "text"
Expand Down Expand Up @@ -5637,6 +5714,18 @@ def _advance_direct_stream_state(self, text: str) -> Tuple[bool, List[Dict[str,
self._direct.saw_tool_calls = saw_tool_calls
self._direct.done = done
return True, deltas
if leading_capture_field is not None:
if buffer.startswith(leading_capture_start):
buffer = buffer[len(leading_capture_start) :]
mode = self.DIRECT_MODE_LEADING_CAPTURE
continue
if leading_capture_start.startswith(buffer):
self._direct.pending = buffer
self._direct.mode = mode
self._direct.tool_call_count = tool_call_count
self._direct.saw_tool_calls = saw_tool_calls
self._direct.done = done
return True, deltas
if buffer.startswith(iterator_start):
saw_tool_calls = True
self._start_direct_tool_call(tool_call_count)
Expand Down Expand Up @@ -6302,6 +6391,16 @@ def _advance_stream_state(self, text: str) -> Tuple[bool, List[Dict[str, Any]]]:
if not buffer:
state.pending = ""
return True, deltas
leading_capture = plan.get("leading_capture")
if leading_capture is not None:
capture_start = leading_capture["start"]
if buffer.startswith(capture_start):
buffer = buffer[len(capture_start) :]
state.mode = "leading-capture"
continue
if capture_start.startswith(buffer):
state.pending = buffer
return True, deltas
if buffer.startswith(iterator_start):
item_state = self._new_tool_call_state(plan["iterator"]["item"])
state.saw_tool_calls = True
Expand Down Expand Up @@ -6866,6 +6965,10 @@ def _normalize_tool_call_item(
},
}
arguments = function.get("arguments", {})
if isinstance(arguments, str):
arguments = self._raw_object_tool_arguments(arguments) or self._raw_string_tool_arguments(
tool_name, arguments
)
if not isinstance(arguments, (dict, ResponseParser.PartialJsonObject)):
if partial:
return None
Expand Down Expand Up @@ -8009,7 +8112,14 @@ def _responses_tools_to_chat_tools(
return None
chat_tools: List[ChatTemplateTool] = []
for tool in tools:
if isinstance(tool, ResponsesWebSearchTool):
if isinstance(
tool,
(
ResponsesWebSearchTool,
ResponsesNamespaceTool,
ResponsesImageGenerationTool,
),
):
continue
if isinstance(tool, ResponsesFunctionTool):
chat_tools.append(tool.to_chat_template_tool())
Expand Down
Loading