diff --git a/aredis_om/model/model.py b/aredis_om/model/model.py index f6888a4f..e45e3939 100644 --- a/aredis_om/model/model.py +++ b/aredis_om/model/model.py @@ -1638,6 +1638,11 @@ def check(self): *_, validation_error = validate_model(self.__class__, self.__dict__) if validation_error: raise validation_error + else: + from pydantic import TypeAdapter + + adapter = TypeAdapter(self.__class__) + adapter.validate_python(self.__dict__) class HashModel(RedisModel, abc.ABC): diff --git a/tests/test_hash_model.py b/tests/test_hash_model.py index 185005f6..1870ec3d 100644 --- a/tests/test_hash_model.py +++ b/tests/test_hash_model.py @@ -875,3 +875,22 @@ async def test_xfix_queries(members, m): result = await m.Member.find(m.Member.bio % "*eat*").first() assert result.first_name == "Andrew" + + +@py_test_mark_asyncio +async def test_update_validation(): + class TestUpdate(HashModel): + name: str + age: int + + await Migrator().run() + t = TestUpdate(name="steve", age=34) + await t.save() + update_dict = dict() + update_dict["age"] = "cat" + + with pytest.raises(ValidationError): + await t.update(**update_dict) + + rematerialized = await TestUpdate.find(TestUpdate.pk == t.pk).first() + assert rematerialized.age == 34 diff --git a/tests/test_json_model.py b/tests/test_json_model.py index 6e5cead4..b1950b6e 100644 --- a/tests/test_json_model.py +++ b/tests/test_json_model.py @@ -974,6 +974,37 @@ async def test_xfix_queries(m): @py_test_mark_asyncio +async def test_update_validation(): + + class Embedded(EmbeddedJsonModel): + price: float + name: str = Field(index=True) + + class TestUpdatesClass(JsonModel): + name: str + age: int + embedded: Embedded + + await Migrator().run() + embedded = Embedded(price=3.14, name="foo") + t = TestUpdatesClass(name="str", age=42, embedded=embedded) + await t.save() + + update_dict = dict() + update_dict["age"] = "foo" + with pytest.raises(ValidationError): + await t.update(**update_dict) + + t.age = 42 + update_dict.clear() + update_dict["embedded"] = "hello" + with pytest.raises(ValidationError): + await t.update(**update_dict) + + rematerialized = await TestUpdatesClass.find(TestUpdatesClass.pk == t.pk).first() + assert rematerialized.age == 42 + + async def test_model_with_dict(): class EmbeddedJsonModelWithDict(EmbeddedJsonModel): dict: Dict