From 7238e99ae2d31a73c369a9dbcbd7a47d8664510c Mon Sep 17 00:00:00 2001 From: slorello89 Date: Tue, 7 May 2024 07:32:42 -0400 Subject: [PATCH] kick off type validation before update --- aredis_om/model/model.py | 5 +++++ tests/test_hash_model.py | 19 +++++++++++++++++++ tests/test_json_model.py | 32 ++++++++++++++++++++++++++++++++ 3 files changed, 56 insertions(+) diff --git a/aredis_om/model/model.py b/aredis_om/model/model.py index 56c5c90d..258c990a 100644 --- a/aredis_om/model/model.py +++ b/aredis_om/model/model.py @@ -1628,6 +1628,11 @@ def check(self): *_, validation_error = validate_model(self.__class__, self.__dict__) if validation_error: raise validation_error + else: + from pydantic import TypeAdapter + + adapter = TypeAdapter(self.__class__) + adapter.validate_python(self.__dict__) class HashModel(RedisModel, abc.ABC): diff --git a/tests/test_hash_model.py b/tests/test_hash_model.py index 185005f6..1870ec3d 100644 --- a/tests/test_hash_model.py +++ b/tests/test_hash_model.py @@ -875,3 +875,22 @@ async def test_xfix_queries(members, m): result = await m.Member.find(m.Member.bio % "*eat*").first() assert result.first_name == "Andrew" + + +@py_test_mark_asyncio +async def test_update_validation(): + class TestUpdate(HashModel): + name: str + age: int + + await Migrator().run() + t = TestUpdate(name="steve", age=34) + await t.save() + update_dict = dict() + update_dict["age"] = "cat" + + with pytest.raises(ValidationError): + await t.update(**update_dict) + + rematerialized = await TestUpdate.find(TestUpdate.pk == t.pk).first() + assert rematerialized.age == 34 diff --git a/tests/test_json_model.py b/tests/test_json_model.py index d5744858..602de97d 100644 --- a/tests/test_json_model.py +++ b/tests/test_json_model.py @@ -971,3 +971,35 @@ async def test_xfix_queries(m): result = await m.Member.find(m.Member.bio % "*ack*").first() assert result.first_name == "Steve" + + +@py_test_mark_asyncio +async def test_update_validation(): + + class Embedded(EmbeddedJsonModel): + price: float + name: str = Field(index=True) + + class TestUpdatesClass(JsonModel): + name: str + age: int + embedded: Embedded + + await Migrator().run() + embedded = Embedded(price=3.14, name="foo") + t = TestUpdatesClass(name="str", age=42, embedded=embedded) + await t.save() + + update_dict = dict() + update_dict["age"] = "foo" + with pytest.raises(ValidationError): + await t.update(**update_dict) + + t.age = 42 + update_dict.clear() + update_dict["embedded"] = "hello" + with pytest.raises(ValidationError): + await t.update(**update_dict) + + rematerialized = await TestUpdatesClass.find(TestUpdatesClass.pk == t.pk).first() + assert rematerialized.age == 42