Skip to content

Commit f245488

Browse files
authored
supporting literals as tag type (#635)
* supporting literals as tag type * fixing key-prefix issue
1 parent b20e887 commit f245488

File tree

3 files changed

+77
-16
lines changed

3 files changed

+77
-16
lines changed

aredis_om/model/model.py

+31-16
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
ClassVar,
1515
Dict,
1616
List,
17+
Literal,
1718
Mapping,
1819
Optional,
1920
Sequence,
@@ -141,10 +142,10 @@ def embedded(cls):
141142

142143
def is_supported_container_type(typ: Optional[type]) -> bool:
143144
# TODO: Wait, why don't we support indexing sets?
144-
if typ == list or typ == tuple:
145+
if typ == list or typ == tuple or typ == Literal:
145146
return True
146147
unwrapped = get_origin(typ)
147-
return unwrapped == list or unwrapped == tuple
148+
return unwrapped == list or unwrapped == tuple or unwrapped == Literal
148149

149150

150151
def validate_model_fields(model: Type["RedisModel"], field_values: Dict[str, Any]):
@@ -1414,6 +1415,8 @@ def outer_type_or_annotation(field):
14141415
if not isinstance(field.annotation, type):
14151416
raise AttributeError(f"could not extract outer type from field {field}")
14161417
return field.annotation
1418+
elif get_origin(field.annotation) == Literal:
1419+
return str
14171420
else:
14181421
return field.annotation.__args__[0]
14191422

@@ -2057,21 +2060,33 @@ def schema_for_type(
20572060
# find any values marked as indexed.
20582061
if is_container_type and not is_vector:
20592062
field_type = get_origin(typ)
2060-
embedded_cls = get_args(typ)
2061-
if not embedded_cls:
2062-
log.warning(
2063-
"Model %s defined an empty list or tuple field: %s", cls, name
2063+
if field_type == Literal:
2064+
path = f"{json_path}.{name}"
2065+
return cls.schema_for_type(
2066+
path,
2067+
name,
2068+
name_prefix,
2069+
str,
2070+
field_info,
2071+
parent_type=field_type,
2072+
)
2073+
else:
2074+
embedded_cls = get_args(typ)
2075+
if not embedded_cls:
2076+
log.warning(
2077+
"Model %s defined an empty list or tuple field: %s", cls, name
2078+
)
2079+
return ""
2080+
path = f"{json_path}.{name}[*]"
2081+
embedded_cls = embedded_cls[0]
2082+
return cls.schema_for_type(
2083+
path,
2084+
name,
2085+
name_prefix,
2086+
embedded_cls,
2087+
field_info,
2088+
parent_type=field_type,
20642089
)
2065-
return ""
2066-
embedded_cls = embedded_cls[0]
2067-
return cls.schema_for_type(
2068-
f"{json_path}.{name}[*]",
2069-
name,
2070-
name_prefix,
2071-
embedded_cls,
2072-
field_info,
2073-
parent_type=field_type,
2074-
)
20752090
elif field_is_model:
20762091
name_prefix = f"{name_prefix}_{name}" if name_prefix else name
20772092
sub_fields = []

tests/test_hash_model.py

+22
Original file line numberDiff line numberDiff line change
@@ -917,3 +917,25 @@ class TestUpdate(HashModel):
917917

918918
rematerialized = await TestUpdate.find(TestUpdate.pk == t.pk).first()
919919
assert rematerialized.age == 34
920+
921+
922+
@py_test_mark_asyncio
923+
async def test_literals():
924+
from typing import Literal
925+
926+
class TestLiterals(HashModel):
927+
flavor: Literal["apple", "pumpkin"] = Field(index=True, default="apple")
928+
929+
schema = TestLiterals.redisearch_schema()
930+
931+
key_prefix = TestLiterals.make_key(
932+
TestLiterals._meta.primary_key_pattern.format(pk="")
933+
)
934+
assert schema == (
935+
f"ON HASH PREFIX 1 {key_prefix} SCHEMA pk TAG SEPARATOR | flavor TAG SEPARATOR |"
936+
)
937+
await Migrator().run()
938+
item = TestLiterals(flavor="pumpkin")
939+
await item.save()
940+
rematerialized = await TestLiterals.find(TestLiterals.flavor == "pumpkin").first()
941+
assert rematerialized.pk == item.pk

tests/test_json_model.py

+24
Original file line numberDiff line numberDiff line change
@@ -1098,6 +1098,7 @@ class ModelWithIntPk(JsonModel):
10981098
m = await ModelWithIntPk.find(ModelWithIntPk.my_id == 42).first()
10991099
assert m.my_id == 42
11001100

1101+
11011102
@py_test_mark_asyncio
11021103
async def test_pagination():
11031104
class Test(JsonModel):
@@ -1121,3 +1122,26 @@ async def get_page(cls, offset, limit):
11211122
res = await Test.get_page(10, 30)
11221123
assert len(res) == 30
11231124
assert res[0].num == 10
1125+
1126+
1127+
@py_test_mark_asyncio
1128+
async def test_literals():
1129+
from typing import Literal
1130+
1131+
class TestLiterals(JsonModel):
1132+
flavor: Literal["apple", "pumpkin"] = Field(index=True, default="apple")
1133+
1134+
schema = TestLiterals.redisearch_schema()
1135+
1136+
key_prefix = TestLiterals.make_key(
1137+
TestLiterals._meta.primary_key_pattern.format(pk="")
1138+
)
1139+
assert schema == (
1140+
f"ON JSON PREFIX 1 {key_prefix} SCHEMA $.pk AS pk TAG SEPARATOR | "
1141+
"$.flavor AS flavor TAG SEPARATOR |"
1142+
)
1143+
await Migrator().run()
1144+
item = TestLiterals(flavor="pumpkin")
1145+
await item.save()
1146+
rematerialized = await TestLiterals.find(TestLiterals.flavor == "pumpkin").first()
1147+
assert rematerialized.pk == item.pk

0 commit comments

Comments
 (0)