Skip to content

Commit 4161c39

Browse files
committed
Refactor check_duplicate_key and remove mod.cu
1 parent f464583 commit 4161c39

File tree

2 files changed

+25
-38
lines changed

2 files changed

+25
-38
lines changed

pytensor/link/c/cmodule.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -641,7 +641,7 @@ class ModuleCache:
641641
The cache contains one directory for each module, containing:
642642
- the dynamic library file itself (e.g. ``.so/.pyd``),
643643
- an empty ``__init__.py`` file, so Python can import it,
644-
- a file containing the source code for the module (e.g. ``mod.cpp/mod.cu``),
644+
- a file containing the source code for the module (e.g. ``mod.cpp``),
645645
- a ``key.pkl`` file, containing a KeyData object with all the keys
646646
associated with that module,
647647
- possibly a ``delete.me`` file, meaning this directory has been marked

pytensor/misc/check_duplicate_key.py

Lines changed: 24 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import os
22
import pickle
33
import sys
4+
from collections import Counter
45

56
from pytensor.configdefaults import config
67

@@ -15,77 +16,63 @@
1516
else:
1617
dirs = os.listdir(config.compiledir)
1718
dirs = [os.path.join(config.compiledir, d) for d in dirs]
18-
keys: dict = {} # key -> nb seen
19+
keys: Counter[bytes] = Counter() # key -> nb seen
1920
mods: dict = {}
2021
for dir in dirs:
2122
key = None
2223
try:
23-
with open(os.path.join(dir, "key.pkl")) as f:
24+
with open(os.path.join(dir, "key.pkl"), "rb") as f:
2425
key = f.read()
25-
keys.setdefault(key, 0)
2626
keys[key] += 1
2727
del f
2828
except OSError:
2929
# print dir, "don't have a key.pkl file"
3030
pass
3131
try:
3232
path = os.path.join(dir, "mod.cpp")
33-
if not os.path.exists(path):
34-
path = os.path.join(dir, "mod.cu")
35-
with open(path) as f:
36-
mod = f.read()
33+
with open(path) as fmod:
34+
mod = fmod.read()
3735
mods.setdefault(mod, ())
3836
mods[mod] += (key,)
3937
del mod
40-
del f
38+
del fmod
4139
del path
4240
except OSError:
43-
print(dir, "don't have a mod.{cpp,cu} file")
41+
print(dir, "don't have a mod.cpp file")
4442

4543
if DISPLAY_DUPLICATE_KEYS:
4644
for k, v in keys.items():
4745
if v > 1:
4846
print("Duplicate key (%i copies): %s" % (v, pickle.loads(k)))
4947

50-
nbs_keys: dict = {} # nb seen -> now many key
51-
for val in keys.values():
52-
nbs_keys.setdefault(val, 0)
53-
nbs_keys[val] += 1
48+
# nb seen -> how many keys
49+
nbs_keys = Counter(val for val in keys.values())
5450

55-
nbs_mod: dict = {} # nb seen -> how many key
56-
nbs_mod_to_key = {} # nb seen -> keys
57-
more_than_one = 0
58-
for mod, kk in mods.items():
59-
val = len(kk)
60-
nbs_mod.setdefault(val, 0)
61-
nbs_mod[val] += 1
62-
if val > 1:
63-
more_than_one += 1
64-
nbs_mod_to_key[val] = kk
51+
# nb seen -> how many mods
52+
nbs_mod = Counter(len(kk) for kk in mods.values())
53+
# nb seen -> keys
54+
nbs_mod_to_key = {len(kk): kk for kk in mods.values()}
55+
more_than_one = sum(len(kk) > 1 for kk in mods.values())
6556

6657
if DISPLAY_MOST_FREQUENT_DUPLICATE_CCODE:
67-
m = max(nbs_mod.keys())
68-
print("The keys associated to the mod.{cpp,cu} with the most number of copy:")
58+
m = max(nbs_mod)
59+
print("The keys associated to the mod.cpp with the most number of copy:")
6960
for kk in nbs_mod_to_key[m]:
7061
kk = pickle.loads(kk)
7162
print(kk)
7263

7364
print("key.pkl histograph")
74-
l = list(nbs_keys.items())
75-
l.sort()
76-
print(l)
65+
print(sorted(nbs_keys.items()))
7766

78-
print("mod.{cpp,cu} histogram")
79-
l = list(nbs_mod.items())
80-
l.sort()
81-
print(l)
67+
print("mod.cpp histogram")
68+
print(sorted(nbs_mod.items()))
8269

83-
total = sum(len(k) for k in list(mods.values()))
70+
total = sum(len(k) for k in mods.values())
8471
uniq = len(mods)
8572
useless = total - uniq
86-
print("mod.{cpp,cu} total:", total)
87-
print("mod.{cpp,cu} uniq:", uniq)
88-
print("mod.{cpp,cu} with more than 1 copy:", more_than_one)
89-
print("mod.{cpp,cu} useless:", useless, float(useless) / total * 100, "%")
73+
print("mod.cpp total:", total)
74+
print("mod.cpp uniq:", uniq)
75+
print("mod.cpp with more than 1 copy:", more_than_one)
76+
print("mod.cpp useless:", useless, float(useless) / total * 100, "%")
9077

9178
print("nb directory", len(dirs))

0 commit comments

Comments
 (0)