Skip to content

Commit bfb72dc

Browse files
committed
Migrate from aioredis to redis with asyncio support
1 parent e2ff503 commit bfb72dc

File tree

10 files changed

+49
-65
lines changed

10 files changed

+49
-65
lines changed

aredis_om/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from ._util import redis
12
from .checks import has_redis_json, has_redisearch
23
from .connections import get_redis_connection
34
from .model.migrations.migrator import MigrationError, Migrator

aredis_om/_util.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
import inspect
2+
3+
4+
def is_async_mode():
5+
async def f():
6+
"""Unasync transforms async functions in sync functions"""
7+
return None
8+
9+
return inspect.iscoroutinefunction(f)
10+
11+
12+
ASYNC_MODE = is_async_mode()
13+
14+
if ASYNC_MODE:
15+
from redis import asyncio as redis
16+
else:
17+
import redis # type: ignore

aredis_om/connections.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,19 @@
11
import os
22

3-
import aioredis
3+
from . import redis
44

55

66
URL = os.environ.get("REDIS_OM_URL", None)
77

88

9-
def get_redis_connection(**kwargs) -> aioredis.Redis:
9+
def get_redis_connection(**kwargs) -> redis.Redis:
1010
# If someone passed in a 'url' parameter, or specified a REDIS_OM_URL
1111
# environment variable, we'll create the Redis client from the URL.
1212
url = kwargs.pop("url", URL)
1313
if url:
14-
return aioredis.Redis.from_url(url, **kwargs)
14+
return redis.Redis.from_url(url, **kwargs)
1515

1616
# Decode from UTF-8 by default
1717
if "decode_responses" not in kwargs:
1818
kwargs["decode_responses"] = True
19-
return aioredis.Redis(**kwargs)
19+
return redis.Redis(**kwargs)

aredis_om/model/migrations/migrator.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from enum import Enum
55
from typing import List, Optional
66

7-
from aioredis import Redis, ResponseError
7+
from ... import redis as redis_module
88

99

1010
log = logging.getLogger(__name__)
@@ -39,7 +39,7 @@ def schema_hash_key(index_name):
3939
return f"{index_name}:hash"
4040

4141

42-
async def create_index(redis: Redis, index_name, schema, current_hash):
42+
async def create_index(redis: redis_module.Redis, index_name, schema, current_hash):
4343
db_number = redis.connection_pool.connection_kwargs.get("db")
4444
if db_number and db_number > 0:
4545
raise MigrationError(
@@ -48,7 +48,7 @@ async def create_index(redis: Redis, index_name, schema, current_hash):
4848
)
4949
try:
5050
await redis.execute_command(f"ft.info {index_name}")
51-
except ResponseError:
51+
except redis_module.ResponseError:
5252
await redis.execute_command(f"ft.create {index_name} {schema}")
5353
await redis.set(schema_hash_key(index_name), current_hash)
5454
else:
@@ -67,7 +67,7 @@ class IndexMigration:
6767
schema: str
6868
hash: str
6969
action: MigrationAction
70-
redis: Redis
70+
redis: redis_module.Redis
7171
previous_hash: Optional[str] = None
7272

7373
async def run(self):
@@ -79,13 +79,13 @@ async def run(self):
7979
async def create(self):
8080
try:
8181
await create_index(self.redis, self.index_name, self.schema, self.hash)
82-
except ResponseError:
82+
except redis_module.ResponseError:
8383
log.info("Index already exists: %s", self.index_name)
8484

8585
async def drop(self):
8686
try:
8787
await self.redis.execute_command(f"FT.DROPINDEX {self.index_name}")
88-
except ResponseError:
88+
except redis_module.ResponseError:
8989
log.info("Index does not exist: %s", self.index_name)
9090

9191

@@ -115,7 +115,7 @@ async def detect_migrations(self):
115115

116116
try:
117117
await redis.execute_command("ft.info", cls.Meta.index_name)
118-
except ResponseError:
118+
except redis_module.ResponseError:
119119
self.migrations.append(
120120
IndexMigration(
121121
name,

aredis_om/model/model.py

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,6 @@
2424
no_type_check,
2525
)
2626

27-
import aioredis
28-
from aioredis.client import Pipeline
2927
from pydantic import BaseModel, validator
3028
from pydantic.fields import FieldInfo as PydanticFieldInfo
3129
from pydantic.fields import ModelField, Undefined, UndefinedType
@@ -35,9 +33,10 @@
3533
from typing_extensions import Protocol, get_args, get_origin
3634
from ulid import ULID
3735

36+
from .. import redis
37+
from .._util import ASYNC_MODE
3838
from ..checks import has_redis_json, has_redisearch
3939
from ..connections import get_redis_connection
40-
from ..unasync_util import ASYNC_MODE
4140
from .encoders import jsonable_encoder
4241
from .render_tree import render_tree
4342
from .token_escaper import TokenEscaper
@@ -975,7 +974,7 @@ class BaseMeta(Protocol):
975974
global_key_prefix: str
976975
model_key_prefix: str
977976
primary_key_pattern: str
978-
database: aioredis.Redis
977+
database: redis.Redis
979978
primary_key: PrimaryKey
980979
primary_key_creator_cls: Type[PrimaryKeyCreator]
981980
index_name: str
@@ -994,7 +993,7 @@ class DefaultMeta:
994993
global_key_prefix: Optional[str] = None
995994
model_key_prefix: Optional[str] = None
996995
primary_key_pattern: Optional[str] = None
997-
database: Optional[aioredis.Redis] = None
996+
database: Optional[redis.Redis] = None
998997
primary_key: Optional[PrimaryKey] = None
999998
primary_key_creator_cls: Optional[Type[PrimaryKeyCreator]] = None
1000999
index_name: Optional[str] = None
@@ -1127,10 +1126,14 @@ async def update(self, **field_values):
11271126
"""Update this model instance with the specified key-value pairs."""
11281127
raise NotImplementedError
11291128

1130-
async def save(self, pipeline: Optional[Pipeline] = None) -> "RedisModel":
1129+
async def save(
1130+
self, pipeline: Optional[redis.client.Pipeline] = None
1131+
) -> "RedisModel":
11311132
raise NotImplementedError
11321133

1133-
async def expire(self, num_seconds: int, pipeline: Optional[Pipeline] = None):
1134+
async def expire(
1135+
self, num_seconds: int, pipeline: Optional[redis.client.Pipeline] = None
1136+
):
11341137
if pipeline is None:
11351138
db = self.db()
11361139
else:
@@ -1241,7 +1244,7 @@ def get_annotations(cls):
12411244
async def add(
12421245
cls,
12431246
models: Sequence["RedisModel"],
1244-
pipeline: Optional[Pipeline] = None,
1247+
pipeline: Optional[redis.client.Pipeline] = None,
12451248
pipeline_verifier: Callable[..., Any] = verify_pipeline_response,
12461249
) -> Sequence["RedisModel"]:
12471250
if pipeline is None:
@@ -1301,7 +1304,9 @@ def __init_subclass__(cls, **kwargs):
13011304
f"HashModels cannot index dataclass fields. Field: {name}"
13021305
)
13031306

1304-
async def save(self, pipeline: Optional[Pipeline] = None) -> "HashModel":
1307+
async def save(
1308+
self, pipeline: Optional[redis.client.Pipeline] = None
1309+
) -> "HashModel":
13051310
self.check()
13061311
if pipeline is None:
13071312
db = self.db()
@@ -1473,7 +1478,9 @@ def __init__(self, *args, **kwargs):
14731478
)
14741479
super().__init__(*args, **kwargs)
14751480

1476-
async def save(self, pipeline: Optional[Pipeline] = None) -> "JsonModel":
1481+
async def save(
1482+
self, pipeline: Optional[redis.client.Pipeline] = None
1483+
) -> "JsonModel":
14771484
self.check()
14781485
if pipeline is None:
14791486
db = self.db()

aredis_om/unasync_util.py

Lines changed: 0 additions & 41 deletions
This file was deleted.

make_sync.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55

66
ADDITIONAL_REPLACEMENTS = {
77
"aredis_om": "redis_om",
8-
"aioredis": "redis",
98
":tests.": ":tests_sync.",
109
"pytest_asyncio": "pytest",
1110
"py_test_mark_asyncio": "py_test_mark_sync",

pyproject.toml

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,7 @@ include=[
2323

2424
[tool.poetry.dependencies]
2525
python = "^3.7"
26-
redis = ">=3.5.3,<5.0.0"
27-
aioredis = "^2.0.0"
26+
redis = ">=4.2.0,<5.0.0"
2827
pydantic = "^1.8.2"
2928
click = "^8.0.1"
3029
six = "^1.16.0"

tests/test_hash_model.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from redis_om import has_redisearch
2626
from tests.conftest import py_test_mark_asyncio
2727

28+
2829
if not has_redisearch():
2930
pytestmark = pytest.mark.skip
3031

tests/test_json_model.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
from redis_om import has_redis_json
2828
from tests.conftest import py_test_mark_asyncio
2929

30+
3031
if not has_redis_json():
3132
pytestmark = pytest.mark.skip
3233

0 commit comments

Comments
 (0)