Skip to content

Commit b698a11

Browse files
committed
Migrate from aioredis to redis with asyncio support
Add test for redis type Fix imports from wrong module (for tests_sync)
1 parent 490d8c1 commit b698a11

15 files changed

+82
-104
lines changed

aredis_om/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from .async_redis import redis # isort:skip
12
from .checks import has_redis_json, has_redisearch
23
from .connections import get_redis_connection
34
from .model.migrations.migrator import MigrationError, Migrator

aredis_om/async_redis.py

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from redis import asyncio as redis

aredis_om/connections.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,19 @@
11
import os
22

3-
import aioredis
3+
from . import redis
44

55

66
URL = os.environ.get("REDIS_OM_URL", None)
77

88

9-
def get_redis_connection(**kwargs) -> aioredis.Redis:
9+
def get_redis_connection(**kwargs) -> redis.Redis:
1010
# If someone passed in a 'url' parameter, or specified a REDIS_OM_URL
1111
# environment variable, we'll create the Redis client from the URL.
1212
url = kwargs.pop("url", URL)
1313
if url:
14-
return aioredis.Redis.from_url(url, **kwargs)
14+
return redis.Redis.from_url(url, **kwargs)
1515

1616
# Decode from UTF-8 by default
1717
if "decode_responses" not in kwargs:
1818
kwargs["decode_responses"] = True
19-
return aioredis.Redis(**kwargs)
19+
return redis.Redis(**kwargs)

aredis_om/model/migrations/migrator.py

+20-19
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from enum import Enum
55
from typing import List, Optional
66

7-
from aioredis import Redis, ResponseError
7+
from ... import redis
88

99

1010
log = logging.getLogger(__name__)
@@ -39,18 +39,19 @@ def schema_hash_key(index_name):
3939
return f"{index_name}:hash"
4040

4141

42-
async def create_index(redis: Redis, index_name, schema, current_hash):
43-
db_number = redis.connection_pool.connection_kwargs.get("db")
42+
async def create_index(conn: redis.Redis, index_name, schema, current_hash):
43+
db_number = conn.connection_pool.connection_kwargs.get("db")
4444
if db_number and db_number > 0:
4545
raise MigrationError(
4646
"Creating search indexes is only supported in database 0. "
4747
f"You attempted to create an index in database {db_number}"
4848
)
4949
try:
50-
await redis.execute_command(f"ft.info {index_name}")
51-
except ResponseError:
52-
await redis.execute_command(f"ft.create {index_name} {schema}")
53-
await redis.set(schema_hash_key(index_name), current_hash)
50+
await conn.execute_command(f"ft.info {index_name}")
51+
except redis.ResponseError:
52+
await conn.execute_command(f"ft.create {index_name} {schema}")
53+
# TODO: remove "type: ignore" when type stubs will be fixed
54+
await conn.set(schema_hash_key(index_name), current_hash) # type: ignore
5455
else:
5556
log.info("Index already exists, skipping. Index hash: %s", index_name)
5657

@@ -67,7 +68,7 @@ class IndexMigration:
6768
schema: str
6869
hash: str
6970
action: MigrationAction
70-
redis: Redis
71+
conn: redis.Redis
7172
previous_hash: Optional[str] = None
7273

7374
async def run(self):
@@ -78,14 +79,14 @@ async def run(self):
7879

7980
async def create(self):
8081
try:
81-
await create_index(self.redis, self.index_name, self.schema, self.hash)
82-
except ResponseError:
82+
await create_index(self.conn, self.index_name, self.schema, self.hash)
83+
except redis.ResponseError:
8384
log.info("Index already exists: %s", self.index_name)
8485

8586
async def drop(self):
8687
try:
87-
await self.redis.execute_command(f"FT.DROPINDEX {self.index_name}")
88-
except ResponseError:
88+
await self.conn.execute_command(f"FT.DROPINDEX {self.index_name}")
89+
except redis.ResponseError:
8990
log.info("Index does not exist: %s", self.index_name)
9091

9192

@@ -105,7 +106,7 @@ async def detect_migrations(self):
105106

106107
for name, cls in model_registry.items():
107108
hash_key = schema_hash_key(cls.Meta.index_name)
108-
redis = cls.db()
109+
conn = cls.db()
109110
try:
110111
schema = cls.redisearch_schema()
111112
except NotImplementedError:
@@ -114,21 +115,21 @@ async def detect_migrations(self):
114115
current_hash = hashlib.sha1(schema.encode("utf-8")).hexdigest() # nosec
115116

116117
try:
117-
await redis.execute_command("ft.info", cls.Meta.index_name)
118-
except ResponseError:
118+
await conn.execute_command("ft.info", cls.Meta.index_name)
119+
except redis.ResponseError:
119120
self.migrations.append(
120121
IndexMigration(
121122
name,
122123
cls.Meta.index_name,
123124
schema,
124125
current_hash,
125126
MigrationAction.CREATE,
126-
redis,
127+
conn,
127128
)
128129
)
129130
continue
130131

131-
stored_hash = await redis.get(hash_key)
132+
stored_hash = await conn.get(hash_key)
132133
schema_out_of_date = current_hash != stored_hash
133134

134135
if schema_out_of_date:
@@ -140,7 +141,7 @@ async def detect_migrations(self):
140141
schema,
141142
current_hash,
142143
MigrationAction.DROP,
143-
redis,
144+
conn,
144145
stored_hash,
145146
)
146147
)
@@ -151,7 +152,7 @@ async def detect_migrations(self):
151152
schema,
152153
current_hash,
153154
MigrationAction.CREATE,
154-
redis,
155+
conn,
155156
stored_hash,
156157
)
157158
)

aredis_om/model/model.py

+17-10
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,6 @@
2424
no_type_check,
2525
)
2626

27-
import aioredis
28-
from aioredis.client import Pipeline
2927
from pydantic import BaseModel, validator
3028
from pydantic.fields import FieldInfo as PydanticFieldInfo
3129
from pydantic.fields import ModelField, Undefined, UndefinedType
@@ -35,9 +33,10 @@
3533
from typing_extensions import Protocol, get_args, get_origin
3634
from ulid import ULID
3735

36+
from .. import redis
3837
from ..checks import has_redis_json, has_redisearch
3938
from ..connections import get_redis_connection
40-
from ..unasync_util import ASYNC_MODE
39+
from ..util import ASYNC_MODE
4140
from .encoders import jsonable_encoder
4241
from .render_tree import render_tree
4342
from .token_escaper import TokenEscaper
@@ -975,7 +974,7 @@ class BaseMeta(Protocol):
975974
global_key_prefix: str
976975
model_key_prefix: str
977976
primary_key_pattern: str
978-
database: aioredis.Redis
977+
database: redis.Redis
979978
primary_key: PrimaryKey
980979
primary_key_creator_cls: Type[PrimaryKeyCreator]
981980
index_name: str
@@ -994,7 +993,7 @@ class DefaultMeta:
994993
global_key_prefix: Optional[str] = None
995994
model_key_prefix: Optional[str] = None
996995
primary_key_pattern: Optional[str] = None
997-
database: Optional[aioredis.Redis] = None
996+
database: Optional[redis.Redis] = None
998997
primary_key: Optional[PrimaryKey] = None
999998
primary_key_creator_cls: Optional[Type[PrimaryKeyCreator]] = None
1000999
index_name: Optional[str] = None
@@ -1127,10 +1126,14 @@ async def update(self, **field_values):
11271126
"""Update this model instance with the specified key-value pairs."""
11281127
raise NotImplementedError
11291128

1130-
async def save(self, pipeline: Optional[Pipeline] = None) -> "RedisModel":
1129+
async def save(
1130+
self, pipeline: Optional[redis.client.Pipeline] = None
1131+
) -> "RedisModel":
11311132
raise NotImplementedError
11321133

1133-
async def expire(self, num_seconds: int, pipeline: Optional[Pipeline] = None):
1134+
async def expire(
1135+
self, num_seconds: int, pipeline: Optional[redis.client.Pipeline] = None
1136+
):
11341137
if pipeline is None:
11351138
db = self.db()
11361139
else:
@@ -1241,7 +1244,7 @@ def get_annotations(cls):
12411244
async def add(
12421245
cls,
12431246
models: Sequence["RedisModel"],
1244-
pipeline: Optional[Pipeline] = None,
1247+
pipeline: Optional[redis.client.Pipeline] = None,
12451248
pipeline_verifier: Callable[..., Any] = verify_pipeline_response,
12461249
) -> Sequence["RedisModel"]:
12471250
if pipeline is None:
@@ -1301,7 +1304,9 @@ def __init_subclass__(cls, **kwargs):
13011304
f"HashModels cannot index dataclass fields. Field: {name}"
13021305
)
13031306

1304-
async def save(self, pipeline: Optional[Pipeline] = None) -> "HashModel":
1307+
async def save(
1308+
self, pipeline: Optional[redis.client.Pipeline] = None
1309+
) -> "HashModel":
13051310
self.check()
13061311
if pipeline is None:
13071312
db = self.db()
@@ -1473,7 +1478,9 @@ def __init__(self, *args, **kwargs):
14731478
)
14741479
super().__init__(*args, **kwargs)
14751480

1476-
async def save(self, pipeline: Optional[Pipeline] = None) -> "JsonModel":
1481+
async def save(
1482+
self, pipeline: Optional[redis.client.Pipeline] = None
1483+
) -> "JsonModel":
14771484
self.check()
14781485
if pipeline is None:
14791486
db = self.db()

aredis_om/sync_redis.py

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
import redis

aredis_om/unasync_util.py

-41
This file was deleted.

aredis_om/util.py

+12
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
import inspect
2+
3+
4+
def is_async_mode():
5+
async def f():
6+
"""Unasync transforms async functions in sync functions"""
7+
return None
8+
9+
return inspect.iscoroutinefunction(f)
10+
11+
12+
ASYNC_MODE = is_async_mode()

make_sync.py

+6-5
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
ADDITIONAL_REPLACEMENTS = {
77
"aredis_om": "redis_om",
8-
"aioredis": "redis",
8+
"async_redis": "sync_redis",
99
":tests.": ":tests_sync.",
1010
"pytest_asyncio": "pytest",
1111
"py_test_mark_asyncio": "py_test_mark_sync",
@@ -26,11 +26,12 @@ def main():
2626
),
2727
]
2828
filepaths = []
29-
for root, _, filenames in os.walk(
30-
Path(__file__).absolute().parent
31-
):
29+
for root, _, filenames in os.walk(Path(__file__).absolute().parent):
3230
for filename in filenames:
33-
if filename.rpartition(".")[-1] in ("py", "pyi",):
31+
if filename.rpartition(".")[-1] in (
32+
"py",
33+
"pyi",
34+
):
3435
filepaths.append(os.path.join(root, filename))
3536

3637
unasync.unasync_files(filepaths, rules)

poetry.lock

+1-20
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

+1-2
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,7 @@ include=[
2323

2424
[tool.poetry.dependencies]
2525
python = "^3.7"
26-
redis = ">=3.5.3,<5.0.0"
27-
aioredis = "^2.0.0"
26+
redis = ">=4.2.0,<5.0.0"
2827
pydantic = "^1.8.2"
2928
click = "^8.0.1"
3029
six = "^1.16.0"

tests/test_hash_model.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,9 @@
2323
# We need to run this check as sync code (during tests) even in async mode
2424
# because we call it in the top-level module scope.
2525
from redis_om import has_redisearch
26-
from tests.conftest import py_test_mark_asyncio
26+
27+
from .conftest import py_test_mark_asyncio
28+
2729

2830
if not has_redisearch():
2931
pytestmark = pytest.mark.skip

tests/test_json_model.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,9 @@
2525
# We need to run this check as sync code (during tests) even in async mode
2626
# because we call it in the top-level module scope.
2727
from redis_om import has_redis_json
28-
from tests.conftest import py_test_mark_asyncio
28+
29+
from .conftest import py_test_mark_asyncio
30+
2931

3032
if not has_redis_json():
3133
pytestmark = pytest.mark.skip

tests/test_oss_redis_features.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,8 @@
99
from pydantic import ValidationError
1010

1111
from aredis_om import HashModel, Migrator, NotFoundError, RedisModelError
12-
from tests.conftest import py_test_mark_asyncio
12+
13+
from .conftest import py_test_mark_asyncio
1314

1415

1516
today = datetime.date.today()

0 commit comments

Comments
 (0)