Skip to content

Commit bf535a8

Browse files
committed
Use defaultdict and Counter in profiling.py
1 parent f1c3c15 commit bf535a8

File tree

2 files changed

+23
-42
lines changed

2 files changed

+23
-42
lines changed

pytensor/compile/profiling.py

+23-40
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,9 @@
1414
import operator
1515
import sys
1616
import time
17-
from collections import defaultdict
17+
from collections import Counter, defaultdict
1818
from contextlib import contextmanager
19-
from typing import TYPE_CHECKING, Any, Union
19+
from typing import TYPE_CHECKING, Any
2020

2121
import numpy as np
2222

@@ -204,8 +204,8 @@ def reset(self):
204204
self.fct_call_time = 0.0
205205
self.fct_callcount = 0
206206
self.vm_call_time = 0.0
207-
self.apply_time = {}
208-
self.apply_callcount = {}
207+
self.apply_time = defaultdict(float)
208+
self.apply_callcount = Counter()
209209
# self.apply_cimpl = None
210210
# self.message = None
211211

@@ -234,9 +234,9 @@ def reset(self):
234234
# Total time spent in Function.vm.__call__
235235
#
236236

237-
apply_time: dict[Union["FunctionGraph", Variable], float] | None = None
237+
apply_time: dict[tuple["FunctionGraph", Apply], float]
238238

239-
apply_callcount: dict[Union["FunctionGraph", Variable], int] | None = None
239+
apply_callcount: dict[tuple["FunctionGraph", Apply], int]
240240

241241
apply_cimpl: dict[Apply, bool] | None = None
242242
# dict from node -> bool (1 if c, 0 if py)
@@ -292,10 +292,9 @@ def reset(self):
292292
# param is called flag_time_thunks because most other attributes with time
293293
# in the name are times *of* something, rather than configuration flags.
294294
def __init__(self, atexit_print=True, flag_time_thunks=None, **kwargs):
295-
self.apply_callcount = {}
295+
self.apply_callcount = Counter()
296296
self.output_size = {}
297-
# Keys are `(FunctionGraph, Variable)`
298-
self.apply_time = {}
297+
self.apply_time = defaultdict(float)
299298
self.apply_cimpl = {}
300299
self.variable_shape = {}
301300
self.variable_strides = {}
@@ -320,37 +319,29 @@ def class_time(self):
320319
321320
"""
322321
# timing is stored by node, we compute timing by class on demand
323-
rval = {}
324-
for (fgraph, node), t in self.apply_time.items():
325-
typ = type(node.op)
326-
rval.setdefault(typ, 0)
327-
rval[typ] += t
328-
return rval
322+
rval = defaultdict(float)
323+
for (_fgraph, node), t in self.apply_time.items():
324+
rval[type(node.op)] += t
325+
return dict(rval)
329326

330327
def class_callcount(self):
331328
"""
332329
dict op -> total number of thunk calls
333330
334331
"""
335332
# timing is stored by node, we compute timing by class on demand
336-
rval = {}
337-
for (fgraph, node), count in self.apply_callcount.items():
338-
typ = type(node.op)
339-
rval.setdefault(typ, 0)
340-
rval[typ] += count
333+
rval = Counter()
334+
for (_fgraph, node), count in self.apply_callcount.items():
335+
rval[type(node.op)] += count
341336
return rval
342337

343-
def class_nodes(self):
338+
def class_nodes(self) -> Counter:
344339
"""
345340
dict op -> total number of nodes
346341
347342
"""
348343
# timing is stored by node, we compute timing by class on demand
349-
rval = {}
350-
for (fgraph, node), count in self.apply_callcount.items():
351-
typ = type(node.op)
352-
rval.setdefault(typ, 0)
353-
rval[typ] += 1
344+
rval = Counter(type(node.op) for _fgraph, node in self.apply_callcount)
354345
return rval
355346

356347
def class_impl(self):
@@ -360,12 +351,9 @@ def class_impl(self):
360351
"""
361352
# timing is stored by node, we compute timing by class on demand
362353
rval = {}
363-
for fgraph, node in self.apply_callcount:
354+
for _fgraph, node in self.apply_callcount:
364355
typ = type(node.op)
365-
if self.apply_cimpl[node]:
366-
impl = "C "
367-
else:
368-
impl = "Py"
356+
impl = "C " if self.apply_cimpl[node] else "Py"
369357
rval.setdefault(typ, impl)
370358
if rval[typ] != impl and len(rval[typ]) == 2:
371359
rval[typ] += impl
@@ -377,11 +365,10 @@ def op_time(self):
377365
378366
"""
379367
# timing is stored by node, we compute timing by Op on demand
380-
rval = {}
368+
rval = defaultdict(float)
381369
for (fgraph, node), t in self.apply_time.items():
382-
rval.setdefault(node.op, 0)
383370
rval[node.op] += t
384-
return rval
371+
return dict(rval)
385372

386373
def fill_node_total_time(self, fgraph, node, total_times):
387374
"""
@@ -414,9 +401,8 @@ def op_callcount(self):
414401
415402
"""
416403
# timing is stored by node, we compute timing by Op on demand
417-
rval = {}
404+
rval = Counter()
418405
for (fgraph, node), count in self.apply_callcount.items():
419-
rval.setdefault(node.op, 0)
420406
rval[node.op] += count
421407
return rval
422408

@@ -426,10 +412,7 @@ def op_nodes(self):
426412
427413
"""
428414
# timing is stored by node, we compute timing by Op on demand
429-
rval = {}
430-
for (fgraph, node), count in self.apply_callcount.items():
431-
rval.setdefault(node.op, 0)
432-
rval[node.op] += 1
415+
rval = Counter(node.op for _fgraph, node in self.apply_callcount)
433416
return rval
434417

435418
def op_impl(self):

pytensor/link/vm.py

-2
Original file line numberDiff line numberDiff line change
@@ -246,10 +246,8 @@ def update_profile(self, profile):
246246
for node, thunk, t, c in zip(
247247
self.nodes, self.thunks, self.call_times, self.call_counts
248248
):
249-
profile.apply_time.setdefault((self.fgraph, node), 0.0)
250249
profile.apply_time[(self.fgraph, node)] += t
251250

252-
profile.apply_callcount.setdefault((self.fgraph, node), 0)
253251
profile.apply_callcount[(self.fgraph, node)] += c
254252

255253
profile.apply_cimpl[node] = hasattr(thunk, "cthunk")

0 commit comments

Comments
 (0)