diff --git a/temporalio/contrib/openai_agents/_temporal_openai_agents.py b/temporalio/contrib/openai_agents/_temporal_openai_agents.py index 60b4b36ef..43594657f 100644 --- a/temporalio/contrib/openai_agents/_temporal_openai_agents.py +++ b/temporalio/contrib/openai_agents/_temporal_openai_agents.py @@ -1,16 +1,24 @@ """Initialize Temporal OpenAI Agents overrides.""" import dataclasses +import json import typing from collections.abc import AsyncIterator, Callable, Iterator, Sequence from contextlib import asynccontextmanager, contextmanager from datetime import timedelta +import pydantic from agents import ModelProvider, Trace, set_trace_provider from agents.run import get_default_agent_runner, set_default_agent_runner from agents.tracing import get_trace_provider from agents.tracing.provider import DefaultTraceProvider +# construct_type is OpenAI's lenient (non-validating) model builder, the same +# one the SDK uses to parse live API responses. It is in a private module but +# has no public alias. +from openai._models import construct_type + +import temporalio.api.common.v1 from temporalio.contrib.openai_agents._invoke_model_activity import ModelActivity from temporalio.contrib.openai_agents._model_parameters import ModelActivityParameters from temporalio.contrib.openai_agents._openai_runner import ( @@ -25,12 +33,14 @@ from temporalio.contrib.openai_agents.workflow import AgentsWorkflowError from temporalio.contrib.opentelemetry._tracer_provider import ReplaySafeTracerProvider from temporalio.contrib.pydantic import ( - PydanticPayloadConverter, + PydanticJSONPlainPayloadConverter, ToJsonOptions, ) from temporalio.converter import ( + CompositePayloadConverter, DataConverter, DefaultPayloadConverter, + JSONPlainPayloadConverter, ) from temporalio.plugin import SimplePlugin from temporalio.worker import WorkflowRunner @@ -64,12 +74,72 @@ def _set_open_ai_agent_temporal_overrides( set_trace_provider(previous_trace_provider or DefaultTraceProvider()) -class OpenAIPayloadConverter(PydanticPayloadConverter): +def _lenient_construct(type_: typing.Any, value: typing.Any) -> typing.Any: + """Build ``value`` into ``type_`` without enforcing required fields. + + OpenAI's ``construct_type`` handles its own response models (and the + unions/lists thereof), but not the ``agents`` dataclasses that wrap them + (e.g. ``ModelResponse``), so the dataclass layer is reconstructed here and + each field delegated to ``construct_type``. ``include_extras`` preserves the + ``Annotated`` discriminators the unions rely on. + """ + if ( + isinstance(type_, type) + and dataclasses.is_dataclass(type_) + and isinstance(value, dict) + ): + hints = typing.get_type_hints(type_, include_extras=True) + return type_( + **{ + field.name: _lenient_construct( + hints.get(field.name, object), value[field.name] + ) + for field in dataclasses.fields(type_) + if field.name in value + } + ) + return construct_type(type_=type_, value=value) + + +class _OpenAIJSONPlainPayloadConverter(PydanticJSONPlainPayloadConverter): + """Strict pydantic deserialization with a lenient fallback. + + OpenAI's response models can drift from live API payloads (e.g. a + deprecated-but-required field the API has stopped sending). The SDK tolerates + this when parsing responses, but strict ``validate_json`` on the workflow + side does not, so fall back to lenient construction when validation fails. + """ + + def from_payload( + self, + payload: temporalio.api.common.v1.Payload, + type_hint: type | None = None, + ) -> typing.Any: + """See base class.""" + try: + return super().from_payload(payload, type_hint) + except pydantic.ValidationError: + if type_hint is None: + raise + return _lenient_construct(type_hint, json.loads(payload.data)) + + +class OpenAIPayloadConverter(CompositePayloadConverter): """PayloadConverter for OpenAI agents.""" def __init__(self) -> None: """Initialize a payload converter.""" - super().__init__(ToJsonOptions(exclude_unset=True)) + json_payload_converter = _OpenAIJSONPlainPayloadConverter( + ToJsonOptions(exclude_unset=True) + ) + super().__init__( + *( + c + if not isinstance(c, JSONPlainPayloadConverter) + else json_payload_converter + for c in DefaultPayloadConverter.default_encoding_payload_converters + ) + ) def _data_converter(converter: DataConverter | None) -> DataConverter: diff --git a/tests/contrib/openai_agents/test_openai.py b/tests/contrib/openai_agents/test_openai.py index 294acc1d0..de0af3923 100644 --- a/tests/contrib/openai_agents/test_openai.py +++ b/tests/contrib/openai_agents/test_openai.py @@ -547,7 +547,9 @@ def research_mock_model(): id="", status="completed", type="web_search_call", - action=ActionSearch(query="", type="search"), + action=ActionSearch.model_construct( + type="search", queries=[""] + ), ), ResponseBuilders.response_output_message("Granada"), ],