14
14
import operator
15
15
import sys
16
16
import time
17
- from collections import defaultdict
17
+ from collections import Counter , defaultdict
18
18
from contextlib import contextmanager
19
- from typing import TYPE_CHECKING , Any , Union
19
+ from typing import TYPE_CHECKING , Any
20
20
21
21
import numpy as np
22
22
@@ -204,8 +204,8 @@ def reset(self):
204
204
self .fct_call_time = 0.0
205
205
self .fct_callcount = 0
206
206
self .vm_call_time = 0.0
207
- self .apply_time = {}
208
- self .apply_callcount = {}
207
+ self .apply_time = defaultdict ( float )
208
+ self .apply_callcount = Counter ()
209
209
# self.apply_cimpl = None
210
210
# self.message = None
211
211
@@ -234,9 +234,9 @@ def reset(self):
234
234
# Total time spent in Function.vm.__call__
235
235
#
236
236
237
- apply_time : dict [Union ["FunctionGraph" , Variable ], float ] | None = None
237
+ apply_time : dict [tuple ["FunctionGraph" , Apply ], float ]
238
238
239
- apply_callcount : dict [Union ["FunctionGraph" , Variable ], int ] | None = None
239
+ apply_callcount : dict [tuple ["FunctionGraph" , Apply ], int ]
240
240
241
241
apply_cimpl : dict [Apply , bool ] | None = None
242
242
# dict from node -> bool (1 if c, 0 if py)
@@ -292,10 +292,9 @@ def reset(self):
292
292
# param is called flag_time_thunks because most other attributes with time
293
293
# in the name are times *of* something, rather than configuration flags.
294
294
def __init__ (self , atexit_print = True , flag_time_thunks = None , ** kwargs ):
295
- self .apply_callcount = {}
295
+ self .apply_callcount = Counter ()
296
296
self .output_size = {}
297
- # Keys are `(FunctionGraph, Variable)`
298
- self .apply_time = {}
297
+ self .apply_time = defaultdict (float )
299
298
self .apply_cimpl = {}
300
299
self .variable_shape = {}
301
300
self .variable_strides = {}
@@ -320,37 +319,29 @@ def class_time(self):
320
319
321
320
"""
322
321
# timing is stored by node, we compute timing by class on demand
323
- rval = {}
324
- for (fgraph , node ), t in self .apply_time .items ():
325
- typ = type (node .op )
326
- rval .setdefault (typ , 0 )
327
- rval [typ ] += t
328
- return rval
322
+ rval = defaultdict (float )
323
+ for (_fgraph , node ), t in self .apply_time .items ():
324
+ rval [type (node .op )] += t
325
+ return dict (rval )
329
326
330
327
def class_callcount (self ):
331
328
"""
332
329
dict op -> total number of thunk calls
333
330
334
331
"""
335
332
# timing is stored by node, we compute timing by class on demand
336
- rval = {}
337
- for (fgraph , node ), count in self .apply_callcount .items ():
338
- typ = type (node .op )
339
- rval .setdefault (typ , 0 )
340
- rval [typ ] += count
333
+ rval = Counter ()
334
+ for (_fgraph , node ), count in self .apply_callcount .items ():
335
+ rval [type (node .op )] += count
341
336
return rval
342
337
343
- def class_nodes (self ):
338
+ def class_nodes (self ) -> Counter :
344
339
"""
345
340
dict op -> total number of nodes
346
341
347
342
"""
348
343
# timing is stored by node, we compute timing by class on demand
349
- rval = {}
350
- for (fgraph , node ), count in self .apply_callcount .items ():
351
- typ = type (node .op )
352
- rval .setdefault (typ , 0 )
353
- rval [typ ] += 1
344
+ rval = Counter (type (node .op ) for _fgraph , node in self .apply_callcount )
354
345
return rval
355
346
356
347
def class_impl (self ):
@@ -360,12 +351,9 @@ def class_impl(self):
360
351
"""
361
352
# timing is stored by node, we compute timing by class on demand
362
353
rval = {}
363
- for fgraph , node in self .apply_callcount :
354
+ for _fgraph , node in self .apply_callcount :
364
355
typ = type (node .op )
365
- if self .apply_cimpl [node ]:
366
- impl = "C "
367
- else :
368
- impl = "Py"
356
+ impl = "C " if self .apply_cimpl [node ] else "Py"
369
357
rval .setdefault (typ , impl )
370
358
if rval [typ ] != impl and len (rval [typ ]) == 2 :
371
359
rval [typ ] += impl
@@ -377,11 +365,10 @@ def op_time(self):
377
365
378
366
"""
379
367
# timing is stored by node, we compute timing by Op on demand
380
- rval = {}
368
+ rval = defaultdict ( float )
381
369
for (fgraph , node ), t in self .apply_time .items ():
382
- rval .setdefault (node .op , 0 )
383
370
rval [node .op ] += t
384
- return rval
371
+ return dict ( rval )
385
372
386
373
def fill_node_total_time (self , fgraph , node , total_times ):
387
374
"""
@@ -414,9 +401,8 @@ def op_callcount(self):
414
401
415
402
"""
416
403
# timing is stored by node, we compute timing by Op on demand
417
- rval = {}
404
+ rval = Counter ()
418
405
for (fgraph , node ), count in self .apply_callcount .items ():
419
- rval .setdefault (node .op , 0 )
420
406
rval [node .op ] += count
421
407
return rval
422
408
@@ -426,10 +412,7 @@ def op_nodes(self):
426
412
427
413
"""
428
414
# timing is stored by node, we compute timing by Op on demand
429
- rval = {}
430
- for (fgraph , node ), count in self .apply_callcount .items ():
431
- rval .setdefault (node .op , 0 )
432
- rval [node .op ] += 1
415
+ rval = Counter (node .op for _fgraph , node in self .apply_callcount )
433
416
return rval
434
417
435
418
def op_impl (self ):
0 commit comments