Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

supporting literals as tag type #635

Merged
merged 2 commits into from
Aug 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 31 additions & 16 deletions aredis_om/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
ClassVar,
Dict,
List,
Literal,
Mapping,
Optional,
Sequence,
Expand Down Expand Up @@ -141,10 +142,10 @@ def embedded(cls):

def is_supported_container_type(typ: Optional[type]) -> bool:
# TODO: Wait, why don't we support indexing sets?
if typ == list or typ == tuple:
if typ == list or typ == tuple or typ == Literal:
return True
unwrapped = get_origin(typ)
return unwrapped == list or unwrapped == tuple
return unwrapped == list or unwrapped == tuple or unwrapped == Literal


def validate_model_fields(model: Type["RedisModel"], field_values: Dict[str, Any]):
Expand Down Expand Up @@ -1414,6 +1415,8 @@ def outer_type_or_annotation(field):
if not isinstance(field.annotation, type):
raise AttributeError(f"could not extract outer type from field {field}")
return field.annotation
elif get_origin(field.annotation) == Literal:
return str
else:
return field.annotation.__args__[0]

Expand Down Expand Up @@ -2057,21 +2060,33 @@ def schema_for_type(
# find any values marked as indexed.
if is_container_type and not is_vector:
field_type = get_origin(typ)
embedded_cls = get_args(typ)
if not embedded_cls:
log.warning(
"Model %s defined an empty list or tuple field: %s", cls, name
if field_type == Literal:
path = f"{json_path}.{name}"
return cls.schema_for_type(
path,
name,
name_prefix,
str,
field_info,
parent_type=field_type,
)
else:
embedded_cls = get_args(typ)
if not embedded_cls:
log.warning(
"Model %s defined an empty list or tuple field: %s", cls, name
)
return ""
path = f"{json_path}.{name}[*]"
embedded_cls = embedded_cls[0]
return cls.schema_for_type(
path,
name,
name_prefix,
embedded_cls,
field_info,
parent_type=field_type,
)
return ""
embedded_cls = embedded_cls[0]
return cls.schema_for_type(
f"{json_path}.{name}[*]",
name,
name_prefix,
embedded_cls,
field_info,
parent_type=field_type,
)
elif field_is_model:
name_prefix = f"{name_prefix}_{name}" if name_prefix else name
sub_fields = []
Expand Down
22 changes: 22 additions & 0 deletions tests/test_hash_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -917,3 +917,25 @@ class TestUpdate(HashModel):

rematerialized = await TestUpdate.find(TestUpdate.pk == t.pk).first()
assert rematerialized.age == 34


@py_test_mark_asyncio
async def test_literals():
from typing import Literal

class TestLiterals(HashModel):
flavor: Literal["apple", "pumpkin"] = Field(index=True, default="apple")

schema = TestLiterals.redisearch_schema()

key_prefix = TestLiterals.make_key(
TestLiterals._meta.primary_key_pattern.format(pk="")
)
assert schema == (
f"ON HASH PREFIX 1 {key_prefix} SCHEMA pk TAG SEPARATOR | flavor TAG SEPARATOR |"
)
await Migrator().run()
item = TestLiterals(flavor="pumpkin")
await item.save()
rematerialized = await TestLiterals.find(TestLiterals.flavor == "pumpkin").first()
assert rematerialized.pk == item.pk
24 changes: 24 additions & 0 deletions tests/test_json_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -1098,6 +1098,7 @@ class ModelWithIntPk(JsonModel):
m = await ModelWithIntPk.find(ModelWithIntPk.my_id == 42).first()
assert m.my_id == 42


@py_test_mark_asyncio
async def test_pagination():
class Test(JsonModel):
Expand All @@ -1121,3 +1122,26 @@ async def get_page(cls, offset, limit):
res = await Test.get_page(10, 30)
assert len(res) == 30
assert res[0].num == 10


@py_test_mark_asyncio
async def test_literals():
from typing import Literal

class TestLiterals(JsonModel):
flavor: Literal["apple", "pumpkin"] = Field(index=True, default="apple")

schema = TestLiterals.redisearch_schema()

key_prefix = TestLiterals.make_key(
TestLiterals._meta.primary_key_pattern.format(pk="")
)
assert schema == (
f"ON JSON PREFIX 1 {key_prefix} SCHEMA $.pk AS pk TAG SEPARATOR | "
"$.flavor AS flavor TAG SEPARATOR |"
)
await Migrator().run()
item = TestLiterals(flavor="pumpkin")
await item.save()
rematerialized = await TestLiterals.find(TestLiterals.flavor == "pumpkin").first()
assert rematerialized.pk == item.pk
Loading