Migrate from aioredis to redis-py with asyncio support #233

Merged (3 commits) on Aug 10, 2022
1 change: 1 addition & 0 deletions aredis_om/__init__.py
@@ -1,3 +1,4 @@
from .async_redis import redis # isort:skip
from .checks import has_redis_json, has_redisearch
from .connections import get_redis_connection
from .model.migrations.migrator import MigrationError, Migrator
1 change: 1 addition & 0 deletions aredis_om/async_redis.py
@@ -0,0 +1 @@
from redis import asyncio as redis
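
A note on the shim: `async_redis.py` re-exports redis-py's asyncio namespace under the name `redis`, so the rest of the package keeps a single import point and the unasync build can later swap in `sync_redis.py`. A minimal usage sketch (not part of this diff; it assumes a Redis server on the default localhost port):

```python
import asyncio

from aredis_om import redis  # re-exported via aredis_om/async_redis.py


async def main():
    # `redis` is redis.asyncio here, so Redis() is the asyncio client.
    client = redis.Redis(decode_responses=True)
    await client.set("greeting", "hello")
    print(await client.get("greeting"))  # -> "hello"
    await client.close()


asyncio.run(main())
```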
8 changes: 4 additions & 4 deletions aredis_om/connections.py
@@ -1,19 +1,19 @@
import os

import aioredis
from . import redis


URL = os.environ.get("REDIS_OM_URL", None)


def get_redis_connection(**kwargs) -> aioredis.Redis:
def get_redis_connection(**kwargs) -> redis.Redis:
# If someone passed in a 'url' parameter, or specified a REDIS_OM_URL
# environment variable, we'll create the Redis client from the URL.
url = kwargs.pop("url", URL)
if url:
return aioredis.Redis.from_url(url, **kwargs)
return redis.Redis.from_url(url, **kwargs)

# Decode from UTF-8 by default
if "decode_responses" not in kwargs:
kwargs["decode_responses"] = True
return aioredis.Redis(**kwargs)
return redis.Redis(**kwargs)
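
With this change `get_redis_connection` returns a `redis.asyncio.Redis` client instead of an `aioredis.Redis` one; the call signature and the `REDIS_OM_URL` handling are unchanged. A hedged usage sketch, assuming a Redis instance at the URL shown:

```python
import asyncio

from aredis_om import get_redis_connection


async def main():
    # A url= kwarg (or the REDIS_OM_URL environment variable) takes
    # precedence over any other connection kwargs.
    db = get_redis_connection(url="redis://localhost:6379/0")
    print(await db.ping())  # -> True
    await db.close()


asyncio.run(main())
```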
39 changes: 20 additions & 19 deletions aredis_om/model/migrations/migrator.py
@@ -4,7 +4,7 @@
from enum import Enum
from typing import List, Optional

from aioredis import Redis, ResponseError
from ... import redis


log = logging.getLogger(__name__)
@@ -39,18 +39,19 @@ def schema_hash_key(index_name):
return f"{index_name}:hash"


async def create_index(redis: Redis, index_name, schema, current_hash):
db_number = redis.connection_pool.connection_kwargs.get("db")
async def create_index(conn: redis.Redis, index_name, schema, current_hash):
db_number = conn.connection_pool.connection_kwargs.get("db")
if db_number and db_number > 0:
raise MigrationError(
"Creating search indexes is only supported in database 0. "
f"You attempted to create an index in database {db_number}"
)
try:
await redis.execute_command(f"ft.info {index_name}")
except ResponseError:
await redis.execute_command(f"ft.create {index_name} {schema}")
await redis.set(schema_hash_key(index_name), current_hash)
await conn.execute_command(f"ft.info {index_name}")
except redis.ResponseError:
await conn.execute_command(f"ft.create {index_name} {schema}")
# TODO: remove "type: ignore" when type stubs will be fixed
await conn.set(schema_hash_key(index_name), current_hash) # type: ignore
else:
log.info("Index already exists, skipping. Index hash: %s", index_name)

@@ -67,7 +68,7 @@ class IndexMigration:
schema: str
hash: str
action: MigrationAction
redis: Redis
conn: redis.Redis
previous_hash: Optional[str] = None

async def run(self):
@@ -78,14 +79,14 @@ async def run(self):

async def create(self):
try:
await create_index(self.redis, self.index_name, self.schema, self.hash)
except ResponseError:
await create_index(self.conn, self.index_name, self.schema, self.hash)
except redis.ResponseError:
log.info("Index already exists: %s", self.index_name)

async def drop(self):
try:
await self.redis.execute_command(f"FT.DROPINDEX {self.index_name}")
except ResponseError:
await self.conn.execute_command(f"FT.DROPINDEX {self.index_name}")
except redis.ResponseError:
log.info("Index does not exist: %s", self.index_name)


@@ -105,7 +106,7 @@ async def detect_migrations(self):

for name, cls in model_registry.items():
hash_key = schema_hash_key(cls.Meta.index_name)
redis = cls.db()
conn = cls.db()
try:
schema = cls.redisearch_schema()
except NotImplementedError:
@@ -114,21 +115,21 @@ async def detect_migrations(self):
current_hash = hashlib.sha1(schema.encode("utf-8")).hexdigest() # nosec

try:
await redis.execute_command("ft.info", cls.Meta.index_name)
except ResponseError:
await conn.execute_command("ft.info", cls.Meta.index_name)
except redis.ResponseError:
self.migrations.append(
IndexMigration(
name,
cls.Meta.index_name,
schema,
current_hash,
MigrationAction.CREATE,
redis,
conn,
)
)
continue

stored_hash = await redis.get(hash_key)
stored_hash = await conn.get(hash_key)
schema_out_of_date = current_hash != stored_hash

if schema_out_of_date:
@@ -140,7 +141,7 @@ async def detect_migrations(self):
schema,
current_hash,
MigrationAction.DROP,
redis,
conn,
stored_hash,
)
)
@@ -151,7 +152,7 @@ async def detect_migrations(self):
schema,
current_hash,
MigrationAction.CREATE,
redis,
conn,
stored_hash,
)
)
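
The migrator changes are mostly renames: the client argument becomes `conn` so it no longer shadows the imported `redis` module, and `ResponseError` is reached through that module (redis-py's asyncio namespace re-exports it from `redis.exceptions`). The probe-then-create pattern itself is untouched; roughly, as a sketch with illustrative names, assuming the RediSearch module is loaded:

```python
from aredis_om import redis


async def ensure_index(conn: redis.Redis, index_name: str, schema: str) -> None:
    try:
        # FT.INFO raises ResponseError for an unknown index...
        await conn.execute_command(f"FT.INFO {index_name}")
    except redis.ResponseError:
        # ...so this branch creates it, mirroring create_index() above.
        await conn.execute_command(f"FT.CREATE {index_name} {schema}")
```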
27 changes: 17 additions & 10 deletions aredis_om/model/model.py
@@ -24,8 +24,6 @@
no_type_check,
)

import aioredis
from aioredis.client import Pipeline
from pydantic import BaseModel, validator
from pydantic.fields import FieldInfo as PydanticFieldInfo
from pydantic.fields import ModelField, Undefined, UndefinedType
@@ -35,9 +33,10 @@
from typing_extensions import Protocol, get_args, get_origin
from ulid import ULID

from .. import redis
from ..checks import has_redis_json, has_redisearch
from ..connections import get_redis_connection
from ..unasync_util import ASYNC_MODE
from ..util import ASYNC_MODE
from .encoders import jsonable_encoder
from .render_tree import render_tree
from .token_escaper import TokenEscaper
@@ -978,7 +977,7 @@ class BaseMeta(Protocol):
global_key_prefix: str
model_key_prefix: str
primary_key_pattern: str
database: aioredis.Redis
database: redis.Redis
primary_key: PrimaryKey
primary_key_creator_cls: Type[PrimaryKeyCreator]
index_name: str
@@ -997,7 +996,7 @@ class DefaultMeta:
global_key_prefix: Optional[str] = None
model_key_prefix: Optional[str] = None
primary_key_pattern: Optional[str] = None
database: Optional[aioredis.Redis] = None
database: Optional[redis.Redis] = None
primary_key: Optional[PrimaryKey] = None
primary_key_creator_cls: Optional[Type[PrimaryKeyCreator]] = None
index_name: Optional[str] = None
@@ -1130,10 +1129,14 @@ async def update(self, **field_values):
"""Update this model instance with the specified key-value pairs."""
raise NotImplementedError

async def save(self, pipeline: Optional[Pipeline] = None) -> "RedisModel":
async def save(
self, pipeline: Optional[redis.client.Pipeline] = None
) -> "RedisModel":
raise NotImplementedError

async def expire(self, num_seconds: int, pipeline: Optional[Pipeline] = None):
async def expire(
self, num_seconds: int, pipeline: Optional[redis.client.Pipeline] = None
):
if pipeline is None:
db = self.db()
else:
@@ -1226,7 +1229,7 @@ def get_annotations(cls):
async def add(
cls,
models: Sequence["RedisModel"],
pipeline: Optional[Pipeline] = None,
pipeline: Optional[redis.client.Pipeline] = None,
pipeline_verifier: Callable[..., Any] = verify_pipeline_response,
) -> Sequence["RedisModel"]:
if pipeline is None:
@@ -1286,7 +1289,7 @@ def __init_subclass__(cls, **kwargs):
f"HashModels cannot index dataclass fields. Field: {name}"
)

async def save(self, pipeline: Optional[Pipeline] = None) -> "HashModel":
async def save(
self, pipeline: Optional[redis.client.Pipeline] = None
) -> "HashModel":
self.check()
if pipeline is None:
db = self.db()
@@ -1458,7 +1463,9 @@ def __init__(self, *args, **kwargs):
)
super().__init__(*args, **kwargs)

async def save(self, pipeline: Optional[Pipeline] = None) -> "JsonModel":
async def save(
self, pipeline: Optional[redis.client.Pipeline] = None
) -> "JsonModel":
self.check()
if pipeline is None:
db = self.db()
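
In `model.py` the pipeline type hints now point at `redis.client.Pipeline` (i.e. `redis.asyncio.client.Pipeline` through the shim), and `save()` keeps accepting an explicit pipeline for batched writes. A hedged sketch, assuming the package's usual `HashModel` export and a running Redis server; the `Customer` model is purely illustrative:

```python
import asyncio

from aredis_om import HashModel, get_redis_connection


class Customer(HashModel):
    first_name: str
    last_name: str


async def save_two() -> None:
    db = get_redis_connection()
    async with db.pipeline(transaction=True) as pipe:
        await Customer(first_name="Ada", last_name="Lovelace").save(pipeline=pipe)
        await Customer(first_name="Alan", last_name="Turing").save(pipeline=pipe)
        await pipe.execute()  # both HSET commands go out in one round trip


asyncio.run(save_two())
```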
1 change: 1 addition & 0 deletions aredis_om/sync_redis.py
@@ -0,0 +1 @@
import redis
41 changes: 0 additions & 41 deletions aredis_om/unasync_util.py

This file was deleted.

12 changes: 12 additions & 0 deletions aredis_om/util.py
@@ -0,0 +1,12 @@
import inspect


def is_async_mode():
async def f():
"""Unasync transforms async functions in sync functions"""
return None

return inspect.iscoroutinefunction(f)


ASYNC_MODE = is_async_mode()
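
`util.py` replaces the deleted `unasync_util.py`: `is_async_mode` defines a throwaway `async` function and asks `inspect.iscoroutinefunction` whether it is still a coroutine. When unasync generates the sync package it strips the `async` keyword, so the same check evaluates to `False` there. For illustration, the generated sync counterpart would look roughly like this (a sketch of the build output, not a file in this diff):

```python
import inspect


def is_async_mode():
    def f():
        """Unasync transforms async functions in sync functions"""
        return None

    # With `async` stripped by unasync, f is a plain function.
    return inspect.iscoroutinefunction(f)


ASYNC_MODE = is_async_mode()  # False in the generated sync package
```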
11 changes: 6 additions & 5 deletions make_sync.py
@@ -5,7 +5,7 @@

ADDITIONAL_REPLACEMENTS = {
"aredis_om": "redis_om",
"aioredis": "redis",
"async_redis": "sync_redis",
":tests.": ":tests_sync.",
"pytest_asyncio": "pytest",
"py_test_mark_asyncio": "py_test_mark_sync",
@@ -26,11 +26,12 @@ def main():
),
]
filepaths = []
for root, _, filenames in os.walk(
Path(__file__).absolute().parent
):
for root, _, filenames in os.walk(Path(__file__).absolute().parent):
for filename in filenames:
if filename.rpartition(".")[-1] in ("py", "pyi",):
if filename.rpartition(".")[-1] in (
"py",
"pyi",
):
filepaths.append(os.path.join(root, filename))

unasync.unasync_files(filepaths, rules)
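
In `make_sync.py` the blanket `aioredis` -> `redis` text replacement is no longer needed; `async_redis` -> `sync_redis` now points the generated package at the synchronous shim instead. For context, a sketch of how these replacements typically feed into the unasync rules; the directory pairs below are assumptions based on the project layout, not part of this hunk:

```python
import unasync

# Same mapping as the ADDITIONAL_REPLACEMENTS dict defined above.
ADDITIONAL_REPLACEMENTS = {
    "aredis_om": "redis_om",
    "async_redis": "sync_redis",
    ":tests.": ":tests_sync.",
    "pytest_asyncio": "pytest",
    "py_test_mark_asyncio": "py_test_mark_sync",
}

rules = [
    unasync.Rule(
        fromdir="/aredis_om/",
        todir="/redis_om/",
        additional_replacements=ADDITIONAL_REPLACEMENTS,
    ),
    unasync.Rule(
        fromdir="/tests/",
        todir="/tests_sync/",
        additional_replacements=ADDITIONAL_REPLACEMENTS,
    ),
]
```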