Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 73 additions & 3 deletions temporalio/contrib/openai_agents/_temporal_openai_agents.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,24 @@
"""Initialize Temporal OpenAI Agents overrides."""

import dataclasses
import json
import typing
from collections.abc import AsyncIterator, Callable, Iterator, Sequence
from contextlib import asynccontextmanager, contextmanager
from datetime import timedelta

import pydantic
from agents import ModelProvider, Trace, set_trace_provider
from agents.run import get_default_agent_runner, set_default_agent_runner
from agents.tracing import get_trace_provider
from agents.tracing.provider import DefaultTraceProvider

# construct_type is OpenAI's lenient (non-validating) model builder, the same
# one the SDK uses to parse live API responses. It is in a private module but
# has no public alias.
from openai._models import construct_type

import temporalio.api.common.v1
from temporalio.contrib.openai_agents._invoke_model_activity import ModelActivity
from temporalio.contrib.openai_agents._model_parameters import ModelActivityParameters
from temporalio.contrib.openai_agents._openai_runner import (
Expand All @@ -25,12 +33,14 @@
from temporalio.contrib.openai_agents.workflow import AgentsWorkflowError
from temporalio.contrib.opentelemetry._tracer_provider import ReplaySafeTracerProvider
from temporalio.contrib.pydantic import (
PydanticPayloadConverter,
PydanticJSONPlainPayloadConverter,
ToJsonOptions,
)
from temporalio.converter import (
CompositePayloadConverter,
DataConverter,
DefaultPayloadConverter,
JSONPlainPayloadConverter,
)
from temporalio.plugin import SimplePlugin
from temporalio.worker import WorkflowRunner
Expand Down Expand Up @@ -64,12 +74,72 @@ def _set_open_ai_agent_temporal_overrides(
set_trace_provider(previous_trace_provider or DefaultTraceProvider())


class OpenAIPayloadConverter(PydanticPayloadConverter):
def _lenient_construct(type_: typing.Any, value: typing.Any) -> typing.Any:
"""Build ``value`` into ``type_`` without enforcing required fields.

OpenAI's ``construct_type`` handles its own response models (and the
unions/lists thereof), but not the ``agents`` dataclasses that wrap them
(e.g. ``ModelResponse``), so the dataclass layer is reconstructed here and
each field delegated to ``construct_type``. ``include_extras`` preserves the
``Annotated`` discriminators the unions rely on.
"""
if (
isinstance(type_, type)
and dataclasses.is_dataclass(type_)
and isinstance(value, dict)
):
hints = typing.get_type_hints(type_, include_extras=True)
return type_(
**{
field.name: _lenient_construct(
hints.get(field.name, object), value[field.name]
)
for field in dataclasses.fields(type_)
if field.name in value
}
)
return construct_type(type_=type_, value=value)


class _OpenAIJSONPlainPayloadConverter(PydanticJSONPlainPayloadConverter):
"""Strict pydantic deserialization with a lenient fallback.

OpenAI's response models can drift from live API payloads (e.g. a
deprecated-but-required field the API has stopped sending). The SDK tolerates
this when parsing responses, but strict ``validate_json`` on the workflow
side does not, so fall back to lenient construction when validation fails.
"""

def from_payload(
self,
payload: temporalio.api.common.v1.Payload,
type_hint: type | None = None,
) -> typing.Any:
"""See base class."""
try:
return super().from_payload(payload, type_hint)
except pydantic.ValidationError:
if type_hint is None:
raise
return _lenient_construct(type_hint, json.loads(payload.data))


class OpenAIPayloadConverter(CompositePayloadConverter):
"""PayloadConverter for OpenAI agents."""

def __init__(self) -> None:
"""Initialize a payload converter."""
super().__init__(ToJsonOptions(exclude_unset=True))
json_payload_converter = _OpenAIJSONPlainPayloadConverter(
ToJsonOptions(exclude_unset=True)
)
super().__init__(
*(
c
if not isinstance(c, JSONPlainPayloadConverter)
else json_payload_converter
for c in DefaultPayloadConverter.default_encoding_payload_converters
)
)


def _data_converter(converter: DataConverter | None) -> DataConverter:
Expand Down
4 changes: 3 additions & 1 deletion tests/contrib/openai_agents/test_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -547,7 +547,9 @@ def research_mock_model():
id="",
status="completed",
type="web_search_call",
action=ActionSearch(query="", type="search"),
action=ActionSearch.model_construct(
type="search", queries=[""]
),
),
ResponseBuilders.response_output_message("Granada"),
],
Expand Down
Loading