diff --git a/src/openai/_models.py b/src/openai/_models.py index 810e49dfc5..50773f193c 100644 --- a/src/openai/_models.py +++ b/src/openai/_models.py @@ -36,6 +36,13 @@ ) import pydantic + +try: + from pydantic.errors import PydanticUserError +except ImportError: # pydantic v1 does not export PydanticUserError + class PydanticUserError(Exception): # type: ignore[no-redef] + pass + from pydantic.fields import FieldInfo from ._types import ( @@ -442,7 +449,11 @@ def _get_extra_fields_type(cls: type[pydantic.BaseModel]) -> type | None: # TODO return None - schema = cls.__pydantic_core_schema__ + try: + schema = cls.__pydantic_core_schema__ + except PydanticUserError: + cls.model_rebuild(force=True, raise_errors=False) + schema = cls.__pydantic_core_schema__ if schema["type"] == "model": fields = schema["schema"] if fields["type"] == "model-fields": diff --git a/tests/test_models.py b/tests/test_models.py index 588869ee35..63e2ef0b5c 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -9,7 +9,7 @@ from openai._utils import PropertyInfo from openai._compat import PYDANTIC_V1, parse_obj, model_dump, model_json -from openai._models import DISCRIMINATOR_CACHE, BaseModel, construct_type +from openai._models import DISCRIMINATOR_CACHE, BaseModel, _get_extra_fields_type, construct_type class BasicModel(BaseModel): @@ -157,6 +157,43 @@ def test_unknown_fields() -> None: assert model_dump(m2) == {"foo": "foo", "unknown": {"foo_bar": True}} +@pytest.mark.skipif(PYDANTIC_V1, reason="pydantic v2 only") +def test_get_extra_fields_type_rebuilds_after_pydantic_user_error() -> None: + schema = { + "type": "model", + "schema": { + "type": "model-fields", + "extras_schema": {"cls": str}, + }, + } + + class Meta(type): + _schema_calls = 0 + _rebuild_calls = 0 + + @property + def __pydantic_core_schema__(cls) -> dict[str, object]: + type(cls)._schema_calls += 1 + if type(cls)._schema_calls == 1: + raise pydantic.errors.PydanticUserError( + "schema unavailable", + code="schema-unavailable", + ) + return schema + + def model_rebuild(cls, *, force: bool, raise_errors: bool) -> bool: + type(cls)._rebuild_calls += 1 + assert force is True + assert raise_errors is False + return True + + class FakeModel(metaclass=Meta): + pass + + assert _get_extra_fields_type(cast(Any, FakeModel)) is str + assert Meta._rebuild_calls == 1 + + def test_strict_validation_unknown_fields() -> None: class Model(BaseModel): foo: str