From 4b51f3c161866027ccd8476c00958c877a02b704 Mon Sep 17 00:00:00 2001 From: mon Date: Thu, 1 Aug 2024 10:58:20 +0800 Subject: [PATCH] =?UTF-8?q?=E2=9C=8F=EF=B8=8F=20Fix=20pydantic=20invalid?= =?UTF-8?q?=20when=20table=3DTrue(#1036)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- sqlmodel/_compat.py | 10 ++++++++++ tests/test_instance_no_args.py | 22 +--------------------- tests/test_validation.py | 33 ++++++++++++++++++++++++++++++++- 3 files changed, 43 insertions(+), 22 deletions(-) diff --git a/sqlmodel/_compat.py b/sqlmodel/_compat.py index 4018d1bb39..551978295d 100644 --- a/sqlmodel/_compat.py +++ b/sqlmodel/_compat.py @@ -346,6 +346,16 @@ def sqlmodel_init(*, self: "SQLModel", data: Dict[str, Any]) -> None: self_instance=self, ) else: + raw_self = self.model_copy() + pydantic_validated_model = self.__pydantic_validator__.validate_python( + data, + self_instance=raw_self, + ) + pydantic_dict = pydantic_validated_model.model_dump() + for k in pydantic_dict.keys(): + if k not in data.keys(): + continue + data[k] = pydantic_dict[k] sqlmodel_table_construct( self_instance=self, values=data, diff --git a/tests/test_instance_no_args.py b/tests/test_instance_no_args.py index 5c8ad77531..323cecce8f 100644 --- a/tests/test_instance_no_args.py +++ b/tests/test_instance_no_args.py @@ -2,27 +2,7 @@ import pytest from pydantic import ValidationError -from sqlmodel import Field, Session, SQLModel, create_engine, select - - -def test_allow_instantiation_without_arguments(clear_sqlmodel): - class Item(SQLModel, table=True): - id: Optional[int] = Field(default=None, primary_key=True) - name: str - description: Optional[str] = None - - engine = create_engine("sqlite:///:memory:") - SQLModel.metadata.create_all(engine) - with Session(engine) as db: - item = Item() - item.name = "Rick" - db.add(item) - db.commit() - statement = select(Item) - result = db.exec(statement).all() - assert len(result) == 1 - assert isinstance(item.id, int) - SQLModel.metadata.clear() +from sqlmodel import Field, SQLModel def test_not_allow_instantiation_without_arguments_if_not_table(): diff --git a/tests/test_validation.py b/tests/test_validation.py index 3265922070..1bebe0af1c 100644 --- a/tests/test_validation.py +++ b/tests/test_validation.py @@ -2,7 +2,7 @@ import pytest from pydantic.error_wrappers import ValidationError -from sqlmodel import SQLModel +from sqlmodel import Field, SQLModel from .conftest import needs_pydanticv1, needs_pydanticv2 @@ -63,3 +63,34 @@ def reject_none(cls, v): with pytest.raises(ValidationError): Hero.model_validate({"name": None, "age": 25}) + + +@needs_pydanticv2 +def test_validation_with_table_true(): + """Test validation with table=True.""" + from pydantic import field_validator + + class Hero(SQLModel, table=True): + name: Optional[str] = Field(default=None, primary_key=True) + secret_name: Optional[str] = None + age: Optional[int] = None + + @field_validator("age", mode="after") + @classmethod + def double_age(cls, v): + if v is not None: + return v * 2 + return v + + Hero(name="Deadpond", age=25) + Hero.model_validate({"name": "Deadpond", "age": 25}) + with pytest.raises(ValidationError): + Hero(name="Deadpond", secret_name="Dive Wilson", age="test") + with pytest.raises(ValidationError): + Hero.model_validate({"name": "Deadpond", "age": "test"}) + + double_age_hero = Hero(name="Deadpond", age=25) + assert double_age_hero.age == 50 + + double_age_hero = Hero.model_validate({"name": "Deadpond", "age": 25}) + assert double_age_hero.age == 50