Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion src/openai/_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,13 @@
)

import pydantic

try:
from pydantic.errors import PydanticUserError
except ImportError: # pydantic v1 does not export PydanticUserError
class PydanticUserError(Exception): # type: ignore[no-redef]
pass

from pydantic.fields import FieldInfo

from ._types import (
Expand Down Expand Up @@ -442,7 +449,11 @@ def _get_extra_fields_type(cls: type[pydantic.BaseModel]) -> type | None:
# TODO
return None

schema = cls.__pydantic_core_schema__
try:
schema = cls.__pydantic_core_schema__
except PydanticUserError:
cls.model_rebuild(force=True, raise_errors=False)
schema = cls.__pydantic_core_schema__
if schema["type"] == "model":
fields = schema["schema"]
if fields["type"] == "model-fields":
Expand Down
39 changes: 38 additions & 1 deletion tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from openai._utils import PropertyInfo
from openai._compat import PYDANTIC_V1, parse_obj, model_dump, model_json
from openai._models import DISCRIMINATOR_CACHE, BaseModel, construct_type
from openai._models import DISCRIMINATOR_CACHE, BaseModel, _get_extra_fields_type, construct_type


class BasicModel(BaseModel):
Expand Down Expand Up @@ -157,6 +157,43 @@ def test_unknown_fields() -> None:
assert model_dump(m2) == {"foo": "foo", "unknown": {"foo_bar": True}}


@pytest.mark.skipif(PYDANTIC_V1, reason="pydantic v2 only")
def test_get_extra_fields_type_rebuilds_after_pydantic_user_error() -> None:
schema = {
"type": "model",
"schema": {
"type": "model-fields",
"extras_schema": {"cls": str},
},
}

class Meta(type):
_schema_calls = 0
_rebuild_calls = 0

@property
def __pydantic_core_schema__(cls) -> dict[str, object]:
type(cls)._schema_calls += 1
if type(cls)._schema_calls == 1:
raise pydantic.errors.PydanticUserError(
"schema unavailable",
code="schema-unavailable",
)
return schema

def model_rebuild(cls, *, force: bool, raise_errors: bool) -> bool:
type(cls)._rebuild_calls += 1
assert force is True
assert raise_errors is False
return True

class FakeModel(metaclass=Meta):
pass

assert _get_extra_fields_type(cast(Any, FakeModel)) is str
assert Meta._rebuild_calls == 1


def test_strict_validation_unknown_fields() -> None:
class Model(BaseModel):
foo: str
Expand Down