Skip to content

Commit 87b51d6

Browse files
committed
Use Counter and defaultdict in profiling
1 parent 7f27d45 commit 87b51d6

File tree

2 files changed

+21
-42
lines changed

2 files changed

+21
-42
lines changed

pytensor/compile/profiling.py

+21 -40
Original file line number | Diff line number | Diff line change
@@ -14,9 +14,9 @@
1414
import operator
1515
import sys
1616
import time
17-
from collections import defaultdict
17+
from collections import Counter, defaultdict
1818
from contextlib import contextmanager
19-
from typing import TYPE_CHECKING, Any, Union
19+
from typing import TYPE_CHECKING, Any
2020

2121
import numpy as np
2222

@@ -204,8 +204,8 @@ def reset(self):
204204
self.fct_call_time = 0.0
205205
self.fct_callcount = 0
206206
self.vm_call_time = 0.0
207-
self.apply_time = {}
208-
self.apply_callcount = {}
207+
self.apply_time = defaultdict(float)
208+
self.apply_callcount = Counter()
209209
# self.apply_cimpl = None
210210
# self.message = None
211211

@@ -234,9 +234,9 @@ def reset(self):
234234
# Total time spent in Function.vm.__call__
235235
#
236236

237-
apply_time: dict[Union["FunctionGraph", Variable], float] | None = None
237+
apply_time: dict[tuple["FunctionGraph", Apply], float] = defaultdict(float)
238238

239-
apply_callcount: dict[Union["FunctionGraph", Variable], int] | None = None
239+
apply_callcount: dict[tuple["FunctionGraph", Apply], int] = Counter()
240240

241241
apply_cimpl: dict[Apply, bool] | None = None
242242
# dict from node -> bool (1 if c, 0 if py)
@@ -292,10 +292,7 @@ def reset(self):
292292
# param is called flag_time_thunks because most other attributes with time
293293
# in the name are times *of* something, rather than configuration flags.
294294
def __init__(self, atexit_print=True, flag_time_thunks=None, **kwargs):
295-
self.apply_callcount = {}
296295
self.output_size = {}
297-
# Keys are `(FunctionGraph, Variable)`
298-
self.apply_time = {}
299296
self.apply_cimpl = {}
300297
self.variable_shape = {}
301298
self.variable_strides = {}
@@ -320,37 +317,29 @@ def class_time(self):
320317
321318
"""
322319
# timing is stored by node, we compute timing by class on demand
323-
rval = {}
324-
for (fgraph, node), t in self.apply_time.items():
325-
typ = type(node.op)
326-
rval.setdefault(typ, 0)
327-
rval[typ] += t
328-
return rval
320+
rval = defaultdict(float)
321+
for (_fgraph, node), t in self.apply_time.items():
322+
rval[type(node.op)] += t
323+
return dict(rval)
329324

330325
def class_callcount(self):
331326
"""
332327
dict op -> total number of thunk calls
333328
334329
"""
335330
# timing is stored by node, we compute timing by class on demand
336-
rval = {}
337-
for (fgraph, node), count in self.apply_callcount.items():
338-
typ = type(node.op)
339-
rval.setdefault(typ, 0)
340-
rval[typ] += count
331+
rval = Counter()
332+
for (_fgraph, node), count in self.apply_callcount.items():
333+
rval[type(node.op)] += count
341334
return rval
342335

343-
def class_nodes(self):
336+
def class_nodes(self) -> Counter:
344337
"""
345338
dict op -> total number of nodes
346339
347340
"""
348341
# timing is stored by node, we compute timing by class on demand
349-
rval = {}
350-
for (fgraph, node), count in self.apply_callcount.items():
351-
typ = type(node.op)
352-
rval.setdefault(typ, 0)
353-
rval[typ] += 1
342+
rval = Counter(type(node.op) for _fgraph, node in self.apply_callcount)
354343
return rval
355344

356345
def class_impl(self):
@@ -360,12 +349,9 @@ def class_impl(self):
360349
"""
361350
# timing is stored by node, we compute timing by class on demand
362351
rval = {}
363-
for fgraph, node in self.apply_callcount:
352+
for _fgraph, node in self.apply_callcount:
364353
typ = type(node.op)
365-
if self.apply_cimpl[node]:
366-
impl = "C "
367-
else:
368-
impl = "Py"
354+
impl = "C " if self.apply_cimpl[node] else "Py"
369355
rval.setdefault(typ, impl)
370356
if rval[typ] != impl and len(rval[typ]) == 2:
371357
rval[typ] += impl
@@ -377,11 +363,10 @@ def op_time(self):
377363
378364
"""
379365
# timing is stored by node, we compute timing by Op on demand
380-
rval = {}
366+
rval = defaultdict(float)
381367
for (fgraph, node), t in self.apply_time.items():
382-
rval.setdefault(node.op, 0)
383368
rval[node.op] += t
384-
return rval
369+
return dict(rval)
385370

386371
def fill_node_total_time(self, fgraph, node, total_times):
387372
"""
@@ -414,9 +399,8 @@ def op_callcount(self):
414399
415400
"""
416401
# timing is stored by node, we compute timing by Op on demand
417-
rval = {}
402+
rval = Counter()
418403
for (fgraph, node), count in self.apply_callcount.items():
419-
rval.setdefault(node.op, 0)
420404
rval[node.op] += count
421405
return rval
422406

@@ -426,10 +410,7 @@ def op_nodes(self):
426410
427411
"""
428412
# timing is stored by node, we compute timing by Op on demand
429-
rval = {}
430-
for (fgraph, node), count in self.apply_callcount.items():
431-
rval.setdefault(node.op, 0)
432-
rval[node.op] += 1
413+
rval = Counter(node.op for _fgraph, node in self.apply_callcount)
433414
return rval
434415

435416
def op_impl(self):

pytensor/link/vm.py

-2
Original file line number | Diff line number | Diff line change
@@ -246,10 +246,8 @@ def update_profile(self, profile):
246246
for node, thunk, t, c in zip(
247247
self.nodes, self.thunks, self.call_times, self.call_counts
248248
):
249-
profile.apply_time.setdefault((self.fgraph, node), 0.0)
250249
profile.apply_time[(self.fgraph, node)] += t
251250

252-
profile.apply_callcount.setdefault((self.fgraph, node), 0)
253251
profile.apply_callcount[(self.fgraph, node)] += c
254252

255253
profile.apply_cimpl[node] = hasattr(thunk, "cthunk")

0 commit comments

Comments
 (0)