
Commit 4661459

Authored by moznuychayim and Chayim I. Kirshen

Migrate from aioredis to redis-py with asyncio support (#233)

* Migrate from aioredis to redis with asyncio support
* Add test for redis type
* Fix imports from wrong module (for tests_sync)
* Fix merge conflicts and update the lock file

Co-authored-by: Chayim I. Kirshen <[email protected]>

1 parent e5e8872 commit 4661459

15 files changed (+476 -486 lines)

aredis_om/__init__.py (+1)

@@ -1,3 +1,4 @@
+from .async_redis import redis  # isort:skip
 from .checks import has_redis_json, has_redisearch
 from .connections import get_redis_connection
 from .model.migrations.migrator import MigrationError, Migrator

aredis_om/async_redis.py (+1, new file)

@@ -0,0 +1 @@
+from redis import asyncio as redis
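
The alias matters downstream: re-exporting redis.asyncio under the name `redis` lets the rest of the package keep writing `redis.Redis` and `redis.ResponseError`. A minimal sketch of the pattern, assuming redis-py >= 4.2 and a reachable local Redis server (the index name is only illustrative):

import asyncio

from redis import asyncio as redis  # same alias async_redis.py now exports


async def main():
    client = redis.Redis(decode_responses=True)
    try:
        # redis.asyncio re-exports the exception types, so except clauses
        # only need the new module path instead of aioredis.
        await client.execute_command("ft.info", "some_index")
    except redis.ResponseError:
        print("index does not exist")
    await client.close()


asyncio.run(main())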

aredis_om/connections.py (+4 -4)

@@ -1,19 +1,19 @@
 import os
 
-import aioredis
+from . import redis
 
 
 URL = os.environ.get("REDIS_OM_URL", None)
 
 
-def get_redis_connection(**kwargs) -> aioredis.Redis:
+def get_redis_connection(**kwargs) -> redis.Redis:
     # If someone passed in a 'url' parameter, or specified a REDIS_OM_URL
     # environment variable, we'll create the Redis client from the URL.
     url = kwargs.pop("url", URL)
     if url:
-        return aioredis.Redis.from_url(url, **kwargs)
+        return redis.Redis.from_url(url, **kwargs)
 
     # Decode from UTF-8 by default
     if "decode_responses" not in kwargs:
         kwargs["decode_responses"] = True
-    return aioredis.Redis(**kwargs)
+    return redis.Redis(**kwargs)
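
A quick usage sketch of the updated helper, not part of this commit; the URL and the ping call are only illustrative, but the return value is now redis-py's asyncio client rather than aioredis.Redis:

import asyncio

from aredis_om import get_redis_connection


async def main():
    # An explicit url kwarg (or the REDIS_OM_URL env var) takes precedence;
    # otherwise the client is built from kwargs with decode_responses=True.
    conn = get_redis_connection(url="redis://localhost:6379")
    print(await conn.ping())


asyncio.run(main())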

aredis_om/model/migrations/migrator.py (+20 -19)

@@ -4,7 +4,7 @@
 from enum import Enum
 from typing import List, Optional
 
-from aioredis import Redis, ResponseError
+from ... import redis
 
 
 log = logging.getLogger(__name__)
@@ -39,18 +39,19 @@ def schema_hash_key(index_name):
     return f"{index_name}:hash"
 
 
-async def create_index(redis: Redis, index_name, schema, current_hash):
-    db_number = redis.connection_pool.connection_kwargs.get("db")
+async def create_index(conn: redis.Redis, index_name, schema, current_hash):
+    db_number = conn.connection_pool.connection_kwargs.get("db")
     if db_number and db_number > 0:
         raise MigrationError(
             "Creating search indexes is only supported in database 0. "
             f"You attempted to create an index in database {db_number}"
         )
     try:
-        await redis.execute_command(f"ft.info {index_name}")
-    except ResponseError:
-        await redis.execute_command(f"ft.create {index_name} {schema}")
-        await redis.set(schema_hash_key(index_name), current_hash)
+        await conn.execute_command(f"ft.info {index_name}")
+    except redis.ResponseError:
+        await conn.execute_command(f"ft.create {index_name} {schema}")
+        # TODO: remove "type: ignore" when type stubs will be fixed
+        await conn.set(schema_hash_key(index_name), current_hash)  # type: ignore
     else:
         log.info("Index already exists, skipping. Index hash: %s", index_name)
 
@@ -67,7 +68,7 @@ class IndexMigration:
     schema: str
     hash: str
     action: MigrationAction
-    redis: Redis
+    conn: redis.Redis
     previous_hash: Optional[str] = None
 
     async def run(self):
@@ -78,14 +79,14 @@ async def run(self):
 
     async def create(self):
        try:
-            await create_index(self.redis, self.index_name, self.schema, self.hash)
-        except ResponseError:
+            await create_index(self.conn, self.index_name, self.schema, self.hash)
+        except redis.ResponseError:
             log.info("Index already exists: %s", self.index_name)
 
     async def drop(self):
         try:
-            await self.redis.execute_command(f"FT.DROPINDEX {self.index_name}")
-        except ResponseError:
+            await self.conn.execute_command(f"FT.DROPINDEX {self.index_name}")
+        except redis.ResponseError:
             log.info("Index does not exist: %s", self.index_name)
 
 
@@ -105,7 +106,7 @@ async def detect_migrations(self):
 
         for name, cls in model_registry.items():
             hash_key = schema_hash_key(cls.Meta.index_name)
-            redis = cls.db()
+            conn = cls.db()
             try:
                 schema = cls.redisearch_schema()
             except NotImplementedError:
@@ -114,21 +115,21 @@ async def detect_migrations(self):
             current_hash = hashlib.sha1(schema.encode("utf-8")).hexdigest()  # nosec
 
             try:
-                await redis.execute_command("ft.info", cls.Meta.index_name)
-            except ResponseError:
+                await conn.execute_command("ft.info", cls.Meta.index_name)
+            except redis.ResponseError:
                 self.migrations.append(
                     IndexMigration(
                         name,
                         cls.Meta.index_name,
                         schema,
                         current_hash,
                         MigrationAction.CREATE,
-                        redis,
+                        conn,
                     )
                 )
                 continue
 
-            stored_hash = await redis.get(hash_key)
+            stored_hash = await conn.get(hash_key)
             schema_out_of_date = current_hash != stored_hash
 
             if schema_out_of_date:
@@ -140,7 +141,7 @@ async def detect_migrations(self):
                         schema,
                         current_hash,
                         MigrationAction.DROP,
-                        redis,
+                        conn,
                         stored_hash,
                     )
                 )
@@ -151,7 +152,7 @@ async def detect_migrations(self):
                         schema,
                         current_hash,
                         MigrationAction.CREATE,
-                        redis,
+                        conn,
                         stored_hash,
                     )
                 )
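
Renaming the local variable from `redis` to `conn` keeps the client from shadowing the imported `redis` module, which is now also where ResponseError comes from. From the caller's side nothing changes; a minimal sketch of running the migrator (the Customer model is hypothetical, added only so the registry has something to index):

import asyncio

from aredis_om import HashModel, Migrator


class Customer(HashModel):
    name: str


async def main():
    # Detects missing or out-of-date RediSearch indexes for registered
    # models and creates or drops them against database 0.
    await Migrator().run()


asyncio.run(main())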

aredis_om/model/model.py (+17 -10)

@@ -24,8 +24,6 @@
     no_type_check,
 )
 
-import aioredis
-from aioredis.client import Pipeline
 from pydantic import BaseModel, validator
 from pydantic.fields import FieldInfo as PydanticFieldInfo
 from pydantic.fields import ModelField, Undefined, UndefinedType
@@ -35,9 +33,10 @@
 from typing_extensions import Protocol, get_args, get_origin
 from ulid import ULID
 
+from .. import redis
 from ..checks import has_redis_json, has_redisearch
 from ..connections import get_redis_connection
-from ..unasync_util import ASYNC_MODE
+from ..util import ASYNC_MODE
 from .encoders import jsonable_encoder
 from .render_tree import render_tree
 from .token_escaper import TokenEscaper
@@ -978,7 +977,7 @@ class BaseMeta(Protocol):
     global_key_prefix: str
     model_key_prefix: str
     primary_key_pattern: str
-    database: aioredis.Redis
+    database: redis.Redis
     primary_key: PrimaryKey
     primary_key_creator_cls: Type[PrimaryKeyCreator]
     index_name: str
@@ -997,7 +996,7 @@ class DefaultMeta:
     global_key_prefix: Optional[str] = None
     model_key_prefix: Optional[str] = None
     primary_key_pattern: Optional[str] = None
-    database: Optional[aioredis.Redis] = None
+    database: Optional[redis.Redis] = None
     primary_key: Optional[PrimaryKey] = None
     primary_key_creator_cls: Optional[Type[PrimaryKeyCreator]] = None
     index_name: Optional[str] = None
@@ -1130,10 +1129,14 @@ async def update(self, **field_values):
         """Update this model instance with the specified key-value pairs."""
         raise NotImplementedError
 
-    async def save(self, pipeline: Optional[Pipeline] = None) -> "RedisModel":
+    async def save(
+        self, pipeline: Optional[redis.client.Pipeline] = None
+    ) -> "RedisModel":
         raise NotImplementedError
 
-    async def expire(self, num_seconds: int, pipeline: Optional[Pipeline] = None):
+    async def expire(
+        self, num_seconds: int, pipeline: Optional[redis.client.Pipeline] = None
+    ):
         if pipeline is None:
             db = self.db()
         else:
@@ -1226,7 +1229,7 @@ def get_annotations(cls):
     async def add(
         cls,
         models: Sequence["RedisModel"],
-        pipeline: Optional[Pipeline] = None,
+        pipeline: Optional[redis.client.Pipeline] = None,
         pipeline_verifier: Callable[..., Any] = verify_pipeline_response,
     ) -> Sequence["RedisModel"]:
         if pipeline is None:
@@ -1286,7 +1289,9 @@ def __init_subclass__(cls, **kwargs):
                     f"HashModels cannot index dataclass fields. Field: {name}"
                 )
 
-    async def save(self, pipeline: Optional[Pipeline] = None) -> "HashModel":
+    async def save(
+        self, pipeline: Optional[redis.client.Pipeline] = None
+    ) -> "HashModel":
         self.check()
         if pipeline is None:
             db = self.db()
@@ -1458,7 +1463,9 @@ def __init__(self, *args, **kwargs):
             )
         super().__init__(*args, **kwargs)
 
-    async def save(self, pipeline: Optional[Pipeline] = None) -> "JsonModel":
+    async def save(
+        self, pipeline: Optional[redis.client.Pipeline] = None
+    ) -> "JsonModel":
         self.check()
         if pipeline is None:
             db = self.db()
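
The pipeline annotations now point at redis.asyncio's Pipeline instead of aioredis.client.Pipeline, while save(), add(), and expire() keep the same call shape. A rough sketch of batching saves through a pipeline, assuming a HashModel subclass like the hypothetical Customer above; this is a sketch, not taken from the library's docs:

async def save_many(customers):
    db = Customer.db()  # redis.asyncio client bound to the model
    async with db.pipeline(transaction=True) as pipe:
        for customer in customers:
            # Commands are queued on the pipeline rather than sent one by one.
            await customer.save(pipeline=pipe)
        await pipe.execute()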

aredis_om/sync_redis.py (+1, new file)

@@ -0,0 +1 @@
+import redis

aredis_om/unasync_util.py (-41)

This file was deleted.

aredis_om/util.py (+12, new file)

@@ -0,0 +1,12 @@
+import inspect
+
+
+def is_async_mode():
+    async def f():
+        """Unasync transforms async functions in sync functions"""
+        return None
+
+    return inspect.iscoroutinefunction(f)
+
+
+ASYNC_MODE = is_async_mode()
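
The probe works because unasync rewrites async def into plain def when it generates the sync package, so inspect.iscoroutinefunction reports True in aredis_om and False in the generated redis_om. A tiny illustration of both outcomes (the sync variant approximates what the generated redis_om/util.py would contain):

import inspect


async def probe_async():  # as written in aredis_om
    return None


def probe_sync():  # same body after unasync strips "async"
    return None


assert inspect.iscoroutinefunction(probe_async) is True
assert inspect.iscoroutinefunction(probe_sync) is False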

make_sync.py (+6 -5)

@@ -5,7 +5,7 @@
 
 ADDITIONAL_REPLACEMENTS = {
     "aredis_om": "redis_om",
-    "aioredis": "redis",
+    "async_redis": "sync_redis",
     ":tests.": ":tests_sync.",
     "pytest_asyncio": "pytest",
     "py_test_mark_asyncio": "py_test_mark_sync",
@@ -26,11 +26,12 @@ def main():
         ),
     ]
     filepaths = []
-    for root, _, filenames in os.walk(
-        Path(__file__).absolute().parent
-    ):
+    for root, _, filenames in os.walk(Path(__file__).absolute().parent):
         for filename in filenames:
-            if filename.rpartition(".")[-1] in ("py", "pyi",):
+            if filename.rpartition(".")[-1] in (
+                "py",
+                "pyi",
+            ):
                 filepaths.append(os.path.join(root, filename))
 
     unasync.unasync_files(filepaths, rules)
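
With the aioredis-to-redis string rewrite gone, the replacement that now does the work is the module name: unasync turns "async_redis" into "sync_redis", so the generated package imports the synchronous client instead. An illustrative before/after, assuming the rules configured in main():

# aredis_om/__init__.py (async source)
from .async_redis import redis   # redis is redis.asyncio

# redis_om/__init__.py (generated by make_sync.py)
from .sync_redis import redis    # redis is the plain synchronous package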
