14
14
import operator
15
15
import sys
16
16
import time
17
- from collections import defaultdict
17
+ from collections import Counter , defaultdict
18
18
from contextlib import contextmanager
19
- from typing import TYPE_CHECKING , Any , Union
19
+ from typing import TYPE_CHECKING , Any
20
20
21
21
import numpy as np
22
22
@@ -204,8 +204,8 @@ def reset(self):
204
204
self .fct_call_time = 0.0
205
205
self .fct_callcount = 0
206
206
self .vm_call_time = 0.0
207
- self .apply_time = {}
208
- self .apply_callcount = {}
207
+ self .apply_time = defaultdict ( float )
208
+ self .apply_callcount = Counter ()
209
209
# self.apply_cimpl = None
210
210
# self.message = None
211
211
@@ -234,9 +234,9 @@ def reset(self):
234
234
# Total time spent in Function.vm.__call__
235
235
#
236
236
237
- apply_time : dict [Union ["FunctionGraph" , Variable ], float ] | None = None
237
+ apply_time : dict [tuple ["FunctionGraph" , Apply ], float ] = defaultdict ( float )
238
238
239
- apply_callcount : dict [Union ["FunctionGraph" , Variable ], int ] | None = None
239
+ apply_callcount : dict [tuple ["FunctionGraph" , Apply ], int ] = Counter ()
240
240
241
241
apply_cimpl : dict [Apply , bool ] | None = None
242
242
# dict from node -> bool (1 if c, 0 if py)
@@ -292,10 +292,7 @@ def reset(self):
292
292
# param is called flag_time_thunks because most other attributes with time
293
293
# in the name are times *of* something, rather than configuration flags.
294
294
def __init__ (self , atexit_print = True , flag_time_thunks = None , ** kwargs ):
295
- self .apply_callcount = {}
296
295
self .output_size = {}
297
- # Keys are `(FunctionGraph, Variable)`
298
- self .apply_time = {}
299
296
self .apply_cimpl = {}
300
297
self .variable_shape = {}
301
298
self .variable_strides = {}
@@ -320,37 +317,29 @@ def class_time(self):
320
317
321
318
"""
322
319
# timing is stored by node, we compute timing by class on demand
323
- rval = {}
324
- for (fgraph , node ), t in self .apply_time .items ():
325
- typ = type (node .op )
326
- rval .setdefault (typ , 0 )
327
- rval [typ ] += t
328
- return rval
320
+ rval = defaultdict (float )
321
+ for (_fgraph , node ), t in self .apply_time .items ():
322
+ rval [type (node .op )] += t
323
+ return dict (rval )
329
324
330
325
def class_callcount (self ):
331
326
"""
332
327
dict op -> total number of thunk calls
333
328
334
329
"""
335
330
# timing is stored by node, we compute timing by class on demand
336
- rval = {}
337
- for (fgraph , node ), count in self .apply_callcount .items ():
338
- typ = type (node .op )
339
- rval .setdefault (typ , 0 )
340
- rval [typ ] += count
331
+ rval = Counter ()
332
+ for (_fgraph , node ), count in self .apply_callcount .items ():
333
+ rval [type (node .op )] += count
341
334
return rval
342
335
343
- def class_nodes (self ):
336
+ def class_nodes (self ) -> Counter :
344
337
"""
345
338
dict op -> total number of nodes
346
339
347
340
"""
348
341
# timing is stored by node, we compute timing by class on demand
349
- rval = {}
350
- for (fgraph , node ), count in self .apply_callcount .items ():
351
- typ = type (node .op )
352
- rval .setdefault (typ , 0 )
353
- rval [typ ] += 1
342
+ rval = Counter (type (node .op ) for _fgraph , node in self .apply_callcount )
354
343
return rval
355
344
356
345
def class_impl (self ):
@@ -360,12 +349,9 @@ def class_impl(self):
360
349
"""
361
350
# timing is stored by node, we compute timing by class on demand
362
351
rval = {}
363
- for fgraph , node in self .apply_callcount :
352
+ for _fgraph , node in self .apply_callcount :
364
353
typ = type (node .op )
365
- if self .apply_cimpl [node ]:
366
- impl = "C "
367
- else :
368
- impl = "Py"
354
+ impl = "C " if self .apply_cimpl [node ] else "Py"
369
355
rval .setdefault (typ , impl )
370
356
if rval [typ ] != impl and len (rval [typ ]) == 2 :
371
357
rval [typ ] += impl
@@ -377,11 +363,10 @@ def op_time(self):
377
363
378
364
"""
379
365
# timing is stored by node, we compute timing by Op on demand
380
- rval = {}
366
+ rval = defaultdict ( float )
381
367
for (fgraph , node ), t in self .apply_time .items ():
382
- rval .setdefault (node .op , 0 )
383
368
rval [node .op ] += t
384
- return rval
369
+ return dict ( rval )
385
370
386
371
def fill_node_total_time (self , fgraph , node , total_times ):
387
372
"""
@@ -414,9 +399,8 @@ def op_callcount(self):
414
399
415
400
"""
416
401
# timing is stored by node, we compute timing by Op on demand
417
- rval = {}
402
+ rval = Counter ()
418
403
for (fgraph , node ), count in self .apply_callcount .items ():
419
- rval .setdefault (node .op , 0 )
420
404
rval [node .op ] += count
421
405
return rval
422
406
@@ -426,10 +410,7 @@ def op_nodes(self):
426
410
427
411
"""
428
412
# timing is stored by node, we compute timing by Op on demand
429
- rval = {}
430
- for (fgraph , node ), count in self .apply_callcount .items ():
431
- rval .setdefault (node .op , 0 )
432
- rval [node .op ] += 1
413
+ rval = Counter (node .op for _fgraph , node in self .apply_callcount )
433
414
return rval
434
415
435
416
def op_impl (self ):
0 commit comments