Skip to content

Commit 85540c0

Browse files
committed
Implement temporary workaround from unasync package
1 parent 490fcec commit 85540c0

File tree

5 files changed

+73
-96
lines changed

5 files changed

+73
-96
lines changed

make_sync.py

Lines changed: 9 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import os
22
from pathlib import Path
3-
from typing import Iterable, Optional, Union
43

54
import unasync
65

@@ -9,44 +8,9 @@
98
"aioredis": "redis",
109
":tests.": ":tests_sync.",
1110
"pytest_asyncio": "pytest",
11+
"py_test_mark_asyncio": "py_test_mark_sync",
1212
}
1313

14-
STRINGS_TO_REMOVE_FROM_SYNC_TESTS = {
15-
"@pytest.mark.asyncio",
16-
}
17-
18-
19-
def remove_strings_from_files(
20-
filepaths: Iterable[Union[bytes, str, os.PathLike]],
21-
strings_to_remove: Iterable[str],
22-
):
23-
for filepath in filepaths:
24-
tmp_filepath = f"{filepath}.tmp"
25-
with open(filepath, "r") as read_file, open(tmp_filepath, "w") as write_file:
26-
for line in read_file:
27-
if line.strip() in strings_to_remove:
28-
continue
29-
print(line, end="", file=write_file)
30-
os.replace(tmp_filepath, filepath)
31-
32-
33-
def get_source_filepaths(directory: Optional[Union[bytes, str, os.PathLike]] = None):
34-
walk_path = (
35-
Path(__file__).absolute().parent
36-
if directory is None
37-
else os.path.join(Path(__file__).absolute().parent, directory)
38-
)
39-
40-
filepaths = []
41-
for root, _, filenames in os.walk(walk_path):
42-
for filename in filenames:
43-
if filename.rpartition(".")[-1] in (
44-
"py",
45-
"pyi",
46-
):
47-
filepaths.append(os.path.join(root, filename))
48-
return filepaths
49-
5014

5115
def main():
5216
rules = [
@@ -61,11 +25,15 @@ def main():
6125
additional_replacements=ADDITIONAL_REPLACEMENTS,
6226
),
6327
]
28+
filepaths = []
29+
for root, _, filenames in os.walk(
30+
Path(__file__).absolute().parent
31+
):
32+
for filename in filenames:
33+
if filename.rpartition(".")[-1] in ("py", "pyi",):
34+
filepaths.append(os.path.join(root, filename))
6435

65-
unasync.unasync_files(get_source_filepaths(), rules)
66-
remove_strings_from_files(
67-
get_source_filepaths("tests_sync"), STRINGS_TO_REMOVE_FROM_SYNC_TESTS
68-
)
36+
unasync.unasync_files(filepaths, rules)
6937

7038

7139
if __name__ == "__main__":

tests/conftest.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,14 @@
99
TEST_PREFIX = "redis-om:testing"
1010

1111

12+
py_test_mark_asyncio = pytest.mark.asyncio
13+
14+
15+
# "pytest_mark_sync" causes problem in pytest
16+
def py_test_mark_sync(f):
17+
return f # no-op decorator
18+
19+
1220
@pytest.fixture(scope="session")
1321
def event_loop(request):
1422
loop = asyncio.get_event_loop_policy().new_event_loop()

tests/test_hash_model.py

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
# We need to run this check as sync code (during tests) even in async mode
2424
# because we call it in the top-level module scope.
2525
from redis_om import has_redisearch
26-
26+
from tests.conftest import py_test_mark_asyncio
2727

2828
if not has_redisearch():
2929
pytestmark = pytest.mark.skip
@@ -96,7 +96,7 @@ async def members(m):
9696
yield member1, member2, member3
9797

9898

99-
@pytest.mark.asyncio
99+
@py_test_mark_asyncio
100100
async def test_exact_match_queries(members, m):
101101
member1, member2, member3 = members
102102

@@ -130,7 +130,7 @@ async def test_exact_match_queries(members, m):
130130
assert actual == [member2]
131131

132132

133-
@pytest.mark.asyncio
133+
@py_test_mark_asyncio
134134
async def test_full_text_search_queries(members, m):
135135
member1, member2, member3 = members
136136

@@ -143,7 +143,7 @@ async def test_full_text_search_queries(members, m):
143143
assert actual == [member1, member3]
144144

145145

146-
@pytest.mark.asyncio
146+
@py_test_mark_asyncio
147147
async def test_recursive_query_resolution(members, m):
148148
member1, member2, member3 = members
149149

@@ -158,7 +158,7 @@ async def test_recursive_query_resolution(members, m):
158158
assert actual == [member2, member1, member3]
159159

160160

161-
@pytest.mark.asyncio
161+
@py_test_mark_asyncio
162162
async def test_tag_queries_boolean_logic(members, m):
163163
member1, member2, member3 = members
164164

@@ -173,7 +173,7 @@ async def test_tag_queries_boolean_logic(members, m):
173173
assert actual == [member1, member3]
174174

175175

176-
@pytest.mark.asyncio
176+
@py_test_mark_asyncio
177177
async def test_tag_queries_punctuation(m):
178178
member1 = m.Member(
179179
first_name="Andrew, the Michael",
@@ -211,7 +211,7 @@ async def test_tag_queries_punctuation(m):
211211
assert results == [member2]
212212

213213

214-
@pytest.mark.asyncio
214+
@py_test_mark_asyncio
215215
async def test_tag_queries_negation(members, m):
216216
member1, member2, member3 = members
217217

@@ -283,7 +283,7 @@ async def test_tag_queries_negation(members, m):
283283
assert actual == [member3]
284284

285285

286-
@pytest.mark.asyncio
286+
@py_test_mark_asyncio
287287
async def test_numeric_queries(members, m):
288288
member1, member2, member3 = members
289289

@@ -314,7 +314,7 @@ async def test_numeric_queries(members, m):
314314
assert actual == [member2, member1]
315315

316316

317-
@pytest.mark.asyncio
317+
@py_test_mark_asyncio
318318
async def test_sorting(members, m):
319319
member1, member2, member3 = members
320320

@@ -359,7 +359,7 @@ def test_validation_passes(m):
359359
assert member.first_name == "Andrew"
360360

361361

362-
@pytest.mark.asyncio
362+
@py_test_mark_asyncio
363363
async def test_retrieve_first(m):
364364
member = m.Member(
365365
first_name="Simon",
@@ -398,7 +398,7 @@ async def test_retrieve_first(m):
398398
assert first_one == member3
399399

400400

401-
@pytest.mark.asyncio
401+
@py_test_mark_asyncio
402402
async def test_saves_model_and_creates_pk(m):
403403
member = m.Member(
404404
first_name="Andrew",
@@ -415,7 +415,7 @@ async def test_saves_model_and_creates_pk(m):
415415
assert member2 == member
416416

417417

418-
@pytest.mark.asyncio
418+
@py_test_mark_asyncio
419419
async def test_all_pks(m):
420420
member = m.Member(
421421
first_name="Simon",
@@ -446,7 +446,7 @@ async def test_all_pks(m):
446446
assert len(pk_list) == 2
447447

448448

449-
@pytest.mark.asyncio
449+
@py_test_mark_asyncio
450450
async def test_delete(m):
451451
member = m.Member(
452452
first_name="Simon",
@@ -462,7 +462,7 @@ async def test_delete(m):
462462
assert response == 1
463463

464464

465-
@pytest.mark.asyncio
465+
@py_test_mark_asyncio
466466
async def test_expire(m):
467467
member = m.Member(
468468
first_name="Expire",
@@ -526,7 +526,7 @@ class InvalidMember(m.BaseHashModel):
526526
friend_ids: List[str]
527527

528528

529-
@pytest.mark.asyncio
529+
@py_test_mark_asyncio
530530
async def test_saves_many(m):
531531
member1 = m.Member(
532532
first_name="Andrew",
@@ -552,22 +552,22 @@ async def test_saves_many(m):
552552
assert await m.Member.get(pk=member2.pk) == member2
553553

554554

555-
@pytest.mark.asyncio
555+
@py_test_mark_asyncio
556556
async def test_updates_a_model(members, m):
557557
member1, member2, member3 = members
558558
await member1.update(last_name="Smith")
559559
member = await m.Member.get(member1.pk)
560560
assert member.last_name == "Smith"
561561

562562

563-
@pytest.mark.asyncio
563+
@py_test_mark_asyncio
564564
async def test_paginate_query(members, m):
565565
member1, member2, member3 = members
566566
actual = await m.Member.find().sort_by("age").all(batch_size=1)
567567
assert actual == [member2, member1, member3]
568568

569569

570-
@pytest.mark.asyncio
570+
@py_test_mark_asyncio
571571
async def test_access_result_by_index_cached(members, m):
572572
member1, member2, member3 = members
573573
query = m.Member.find().sort_by("age")
@@ -582,7 +582,7 @@ async def test_access_result_by_index_cached(members, m):
582582
assert not mock_db.called
583583

584584

585-
@pytest.mark.asyncio
585+
@py_test_mark_asyncio
586586
async def test_access_result_by_index_not_cached(members, m):
587587
member1, member2, member3 = members
588588
query = m.Member.find().sort_by("age")

0 commit comments

Comments
 (0)