1
1
import os
2
2
from pathlib import Path
3
+ from typing import Iterable , Optional , Union
3
4
4
5
import unasync
5
6
6
7
ADDITIONAL_REPLACEMENTS = {
7
8
"aredis_om" : "redis_om" ,
8
9
"aioredis" : "redis" ,
9
10
":tests." : ":tests_sync." ,
11
+ "pytest_asyncio" : "pytest" ,
10
12
}
11
13
14
+ STRINGS_TO_REMOVE_FROM_SYNC_TESTS = {
15
+ "@pytest.mark.asyncio" ,
16
+ }
17
+
18
+
19
+ def remove_strings_from_files (
20
+ filepaths : Iterable [Union [bytes , str , os .PathLike ]],
21
+ strings_to_remove : Iterable [str ],
22
+ ):
23
+ for filepath in filepaths :
24
+ tmp_filepath = f"{ filepath } .tmp"
25
+ with open (filepath , "r" ) as read_file , open (tmp_filepath , "w" ) as write_file :
26
+ for line in read_file :
27
+ if line .strip () in strings_to_remove :
28
+ continue
29
+ print (line , end = "" , file = write_file )
30
+ os .replace (tmp_filepath , filepath )
31
+
32
+
33
+ def get_source_filepaths (directory : Optional [Union [bytes , str , os .PathLike ]] = None ):
34
+ walk_path = (
35
+ Path (__file__ ).absolute ().parent
36
+ if directory is None
37
+ else os .path .join (Path (__file__ ).absolute ().parent , directory )
38
+ )
39
+
40
+ filepaths = []
41
+ for root , _ , filenames in os .walk (walk_path ):
42
+ for filename in filenames :
43
+ if filename .rpartition ("." )[- 1 ] in (
44
+ "py" ,
45
+ "pyi" ,
46
+ ):
47
+ filepaths .append (os .path .join (root , filename ))
48
+ return filepaths
49
+
12
50
13
51
def main ():
14
52
rules = [
@@ -23,15 +61,11 @@ def main():
23
61
additional_replacements = ADDITIONAL_REPLACEMENTS ,
24
62
),
25
63
]
26
- filepaths = []
27
- for root , _ , filenames in os .walk (
28
- Path (__file__ ).absolute ().parent
29
- ):
30
- for filename in filenames :
31
- if filename .rpartition ("." )[- 1 ] in ("py" , "pyi" ,):
32
- filepaths .append (os .path .join (root , filename ))
33
64
34
- unasync .unasync_files (filepaths , rules )
65
+ unasync .unasync_files (get_source_filepaths (), rules )
66
+ remove_strings_from_files (
67
+ get_source_filepaths ("tests_sync" ), STRINGS_TO_REMOVE_FROM_SYNC_TESTS
68
+ )
35
69
36
70
37
71
if __name__ == "__main__" :
0 commit comments