diff --git a/pytensor/compile/compiledir.py b/pytensor/compile/compiledir.py index 513053bd92..906f203251 100644 --- a/pytensor/compile/compiledir.py +++ b/pytensor/compile/compiledir.py @@ -7,6 +7,7 @@ import os import pickle import shutil +from collections import Counter import numpy as np @@ -111,11 +112,11 @@ def print_compiledir_content(): compiledir = config.compiledir table = [] table_multiple_ops = [] - table_op_class = {} + table_op_class = Counter() zeros_op = 0 big_key_files = [] total_key_sizes = 0 - nb_keys = {} + nb_keys = Counter() for dir in os.listdir(compiledir): filename = os.path.join(compiledir, dir, "key.pkl") if not os.path.exists(filename): @@ -125,9 +126,7 @@ def print_compiledir_content(): keydata = pickle.load(file) ops = list({x for x in flatten(keydata.keys) if isinstance(x, Op)}) # Whatever the case, we count compilations for OP classes. - for op_class in {op.__class__ for op in ops}: - table_op_class.setdefault(op_class, 0) - table_op_class[op_class] += 1 + table_op_class.update({op.__class__ for op in ops}) if len(ops) == 0: zeros_op += 1 else: @@ -159,7 +158,6 @@ def print_compiledir_content(): if size > max_key_file_size: big_key_files.append((dir, size, ops)) - nb_keys.setdefault(len(keydata.keys), 0) nb_keys[len(keydata.keys)] += 1 except OSError: pass @@ -198,8 +196,7 @@ def print_compiledir_content(): ), underline="+", ) - table_op_class = sorted(table_op_class.items(), key=lambda t: t[1]) - for op_class, nb in table_op_class: + for op_class, nb in reversed(table_op_class.most_common()): print(op_class, nb) if big_key_files: diff --git a/pytensor/compile/debugmode.py b/pytensor/compile/debugmode.py index 92f4865e69..bd664d2b2b 100644 --- a/pytensor/compile/debugmode.py +++ b/pytensor/compile/debugmode.py @@ -906,11 +906,10 @@ def _get_preallocated_maps( name = f"strided{tuple(steps)}" for r in considered_outputs: if r in init_strided: - strides = [] - shapes = [] - for i, size in enumerate(r_vals[r].shape): - shapes.append(slice(None, size, None)) - strides.append(slice(None, None, steps[i])) + shapes = [slice(None, size, None) for size in r_vals[r].shape] + strides = [ + slice(None, None, steps[i]) for i in range(r_vals[r].ndim) + ] r_buf = init_strided[r] diff --git a/pytensor/compile/function/__init__.py b/pytensor/compile/function/__init__.py index 0f4b6e00d4..0447aab5f4 100644 --- a/pytensor/compile/function/__init__.py +++ b/pytensor/compile/function/__init__.py @@ -247,18 +247,10 @@ def opt_log1p(node): """ if isinstance(outputs, dict): - output_items = list(outputs.items()) + assert all(isinstance(k, str) for k in outputs) - for item_pair in output_items: - assert isinstance(item_pair[0], str) - - output_items_sorted = sorted(output_items) - - output_keys = [] - outputs = [] - for pair in output_items_sorted: - output_keys.append(pair[0]) - outputs.append(pair[1]) + output_keys = sorted(outputs) + outputs = [outputs[key] for key in output_keys] else: output_keys = None diff --git a/pytensor/compile/function/types.py b/pytensor/compile/function/types.py index c221d7cf41..d9070831ff 100644 --- a/pytensor/compile/function/types.py +++ b/pytensor/compile/function/types.py @@ -14,6 +14,7 @@ import pytensor.compile.profiling from pytensor.compile.io import In, SymbolicInput, SymbolicOutput from pytensor.compile.ops import deep_copy_op, view_op +from pytensor.compile.profiling import ProfileStats from pytensor.configdefaults import config from pytensor.graph.basic import ( Constant, @@ -212,18 +213,14 @@ def std_fgraph( found_updates.extend(map(SymbolicOutput, updates)) elif fgraph is None: - input_vars = [] - # If one of the inputs is non-atomic (i.e. has a non-`None` `Variable.owner`), # then we need to create/clone the graph starting at these inputs. # The result will be atomic versions of the given inputs connected to # the same outputs. # Otherwise, when all the inputs are already atomic, there's no need to # clone the graph. - clone = force_clone - for spec in input_specs: - input_vars.append(spec.variable) - clone |= spec.variable.owner is not None + input_vars = [spec.variable for spec in input_specs] + clone = force_clone or any(var.owner is not None for var in input_vars) fgraph = FunctionGraph( input_vars, @@ -557,11 +554,11 @@ def __copy__(self): def copy( self, - share_memory=False, - swap=None, - delete_updates=False, - name=None, - profile=None, + share_memory: bool = False, + swap: dict | None = None, + delete_updates: bool = False, + name: str | None = None, + profile: bool | str | ProfileStats | None = None, ): """ Copy this function. Copied function will have separated maker and @@ -588,7 +585,7 @@ def copy( If provided, will be the name of the new Function. Otherwise, it will be old + " copy" - profile : + profile : bool | str | ProfileStats | None as pytensor.function profile parameter Returns @@ -727,14 +724,8 @@ def checkSV(sv_ori, sv_rpl): # reinitialize new maker and create new function if profile is None: profile = config.profile or config.print_global_stats - # profile -> True or False if profile is True: - if name: - message = name - else: - message = str(profile.message) + " copy" - profile = pytensor.compile.profiling.ProfileStats(message=message) - # profile -> object + profile = pytensor.compile.profiling.ProfileStats(message=name) elif isinstance(profile, str): profile = pytensor.compile.profiling.ProfileStats(message=profile) diff --git a/pytensor/compile/profiling.py b/pytensor/compile/profiling.py index 56a88ecfe3..8aa9795b40 100644 --- a/pytensor/compile/profiling.py +++ b/pytensor/compile/profiling.py @@ -14,9 +14,9 @@ import operator import sys import time -from collections import defaultdict +from collections import Counter, defaultdict from contextlib import contextmanager -from typing import TYPE_CHECKING, Any, Union +from typing import TYPE_CHECKING, Any import numpy as np @@ -204,8 +204,8 @@ def reset(self): self.fct_call_time = 0.0 self.fct_callcount = 0 self.vm_call_time = 0.0 - self.apply_time = {} - self.apply_callcount = {} + self.apply_time = defaultdict(float) + self.apply_callcount = Counter() # self.apply_cimpl = None # self.message = None @@ -234,9 +234,9 @@ def reset(self): # Total time spent in Function.vm.__call__ # - apply_time: dict[Union["FunctionGraph", Variable], float] | None = None + apply_time: dict[tuple["FunctionGraph", Apply], float] - apply_callcount: dict[Union["FunctionGraph", Variable], int] | None = None + apply_callcount: dict[tuple["FunctionGraph", Apply], int] apply_cimpl: dict[Apply, bool] | None = None # dict from node -> bool (1 if c, 0 if py) @@ -292,10 +292,9 @@ def reset(self): # param is called flag_time_thunks because most other attributes with time # in the name are times *of* something, rather than configuration flags. def __init__(self, atexit_print=True, flag_time_thunks=None, **kwargs): - self.apply_callcount = {} + self.apply_callcount = Counter() self.output_size = {} - # Keys are `(FunctionGraph, Variable)` - self.apply_time = {} + self.apply_time = defaultdict(float) self.apply_cimpl = {} self.variable_shape = {} self.variable_strides = {} @@ -320,12 +319,10 @@ def class_time(self): """ # timing is stored by node, we compute timing by class on demand - rval = {} - for (fgraph, node), t in self.apply_time.items(): - typ = type(node.op) - rval.setdefault(typ, 0) - rval[typ] += t - return rval + rval = defaultdict(float) + for (_fgraph, node), t in self.apply_time.items(): + rval[type(node.op)] += t + return dict(rval) def class_callcount(self): """ @@ -333,24 +330,18 @@ def class_callcount(self): """ # timing is stored by node, we compute timing by class on demand - rval = {} - for (fgraph, node), count in self.apply_callcount.items(): - typ = type(node.op) - rval.setdefault(typ, 0) - rval[typ] += count + rval = Counter() + for (_fgraph, node), count in self.apply_callcount.items(): + rval[type(node.op)] += count return rval - def class_nodes(self): + def class_nodes(self) -> Counter: """ dict op -> total number of nodes """ # timing is stored by node, we compute timing by class on demand - rval = {} - for (fgraph, node), count in self.apply_callcount.items(): - typ = type(node.op) - rval.setdefault(typ, 0) - rval[typ] += 1 + rval = Counter(type(node.op) for _fgraph, node in self.apply_callcount) return rval def class_impl(self): @@ -360,12 +351,9 @@ def class_impl(self): """ # timing is stored by node, we compute timing by class on demand rval = {} - for fgraph, node in self.apply_callcount: + for _fgraph, node in self.apply_callcount: typ = type(node.op) - if self.apply_cimpl[node]: - impl = "C " - else: - impl = "Py" + impl = "C " if self.apply_cimpl[node] else "Py" rval.setdefault(typ, impl) if rval[typ] != impl and len(rval[typ]) == 2: rval[typ] += impl @@ -377,11 +365,10 @@ def op_time(self): """ # timing is stored by node, we compute timing by Op on demand - rval = {} + rval = defaultdict(float) for (fgraph, node), t in self.apply_time.items(): - rval.setdefault(node.op, 0) rval[node.op] += t - return rval + return dict(rval) def fill_node_total_time(self, fgraph, node, total_times): """ @@ -414,9 +401,8 @@ def op_callcount(self): """ # timing is stored by node, we compute timing by Op on demand - rval = {} + rval = Counter() for (fgraph, node), count in self.apply_callcount.items(): - rval.setdefault(node.op, 0) rval[node.op] += count return rval @@ -426,10 +412,7 @@ def op_nodes(self): """ # timing is stored by node, we compute timing by Op on demand - rval = {} - for (fgraph, node), count in self.apply_callcount.items(): - rval.setdefault(node.op, 0) - rval[node.op] += 1 + rval = Counter(node.op for _fgraph, node in self.apply_callcount) return rval def op_impl(self): @@ -1204,8 +1187,7 @@ def min_memory_generator(executable_nodes, viewed_by, view_of): compute_map[var][0] = 0 for k_remove, v_remove in viewedby_remove.items(): - for i in v_remove: - viewed_by[k_remove].append(i) + viewed_by[k_remove].extend(v_remove) for k_add, v_add in viewedby_add.items(): for i in v_add: @@ -1215,15 +1197,16 @@ def min_memory_generator(executable_nodes, viewed_by, view_of): del view_of[k] # two data structure used to mimic Python gc - viewed_by = {} # {var1: [vars that view var1]} + # * {var1: [vars that view var1]} # The len of the list is the value of python ref # count. But we use a list, not just the ref count value. - # This is more safe to help detect potential bug in the algo - for var in fgraph.variables: - viewed_by[var] = [] - view_of = {} # {var1: original var viewed by var1} + # This is more safe to help detect potential bug in the algo + viewed_by = {var: [] for var in fgraph.variables} + + # * {var1: original var viewed by var1} # The original mean that we don't keep track of all the intermediate # relationship in the view. + view_of = {} min_memory_generator(executable_nodes, viewed_by, view_of) diff --git a/pytensor/gradient.py b/pytensor/gradient.py index 0ca6856eb1..298aea0c02 100644 --- a/pytensor/gradient.py +++ b/pytensor/gradient.py @@ -1041,13 +1041,12 @@ def access_term_cache(node): # list of bools indicating if each input is connected to the cost inputs_connected = [ ( - True - in [ + any( input_to_output and output_to_cost for input_to_output, output_to_cost in zip( input_to_outputs, outputs_connected ) - ] + ) ) for input_to_outputs in connection_pattern ] @@ -1067,25 +1066,24 @@ def access_term_cache(node): # List of bools indicating if each input only has NullType outputs only_connected_to_nan = [ ( - True - not in [ + not any( in_to_out and out_to_cost and not out_nan for in_to_out, out_to_cost, out_nan in zip( in_to_outs, outputs_connected, ograd_is_nan ) - ] + ) ) for in_to_outs in connection_pattern ] - if True not in inputs_connected: + if not any(inputs_connected): # All outputs of this op are disconnected so we can skip # Calling the op's grad method and report that the inputs # are disconnected # (The op's grad method could do this too, but this saves the # implementer the trouble of worrying about this case) input_grads = [disconnected_type() for ipt in inputs] - elif False not in only_connected_to_nan: + elif all(only_connected_to_nan): # All inputs are only connected to nan gradients, so we don't # need to bother calling the grad method. We know the gradient # with respect to all connected inputs is nan. diff --git a/pytensor/graph/basic.py b/pytensor/graph/basic.py index f71c591473..81c9a0f1a9 100644 --- a/pytensor/graph/basic.py +++ b/pytensor/graph/basic.py @@ -1474,9 +1474,8 @@ def _compute_deps_cache_(io): _clients: dict[T, list[T]] = {} sources: deque[T] = deque() - search_res_len: int = 0 + search_res_len = len(search_res) for snode, children in search_res: - search_res_len += 1 if children: for child in children: _clients.setdefault(child, []).append(snode) diff --git a/pytensor/graph/destroyhandler.py b/pytensor/graph/destroyhandler.py index e90dc01a26..3ad3fce7ae 100644 --- a/pytensor/graph/destroyhandler.py +++ b/pytensor/graph/destroyhandler.py @@ -5,7 +5,7 @@ """ import itertools -from collections import OrderedDict, deque +from collections import deque import pytensor from pytensor.configdefaults import config @@ -306,7 +306,7 @@ def __init__(self, do_imports_on_attach=True, algo=None): TODO: change name to var_to_vroot. """ - self.droot = OrderedDict() + self.droot = {} """ Maps a variable to all variables that are indirect or direct views of it @@ -317,7 +317,7 @@ def __init__(self, do_imports_on_attach=True, algo=None): TODO: rename to x_to_views after reverse engineering what x is """ - self.impact = OrderedDict() + self.impact = {} """ If a var is destroyed, then this dict will map @@ -325,11 +325,11 @@ def __init__(self, do_imports_on_attach=True, algo=None): TODO: rename to vroot_to_destroyer """ - self.root_destroyer = OrderedDict() + self.root_destroyer = {} if algo is None: algo = config.cycle_detection self.algo = algo - self.fail_validate = OrderedDict() + self.fail_validate = {} def clone(self): return type(self)(self.do_imports_on_attach, self.algo) @@ -370,7 +370,7 @@ def on_attach(self, fgraph): self.view_i = {} # variable -> variable used in calculation self.view_o = {} # variable -> set of variables that use this one as a direct input # clients: how many times does an apply use a given variable - self.clients = OrderedDict() # variable -> apply -> ninputs + self.clients = {} # variable -> apply -> ninputs self.stale_droot = True self.debug_all_apps = set() @@ -527,11 +527,11 @@ def on_import(self, fgraph, app, reason): # update self.clients for i, input in enumerate(app.inputs): - self.clients.setdefault(input, OrderedDict()).setdefault(app, 0) + self.clients.setdefault(input, {}).setdefault(app, 0) self.clients[input][app] += 1 for i, output in enumerate(app.outputs): - self.clients.setdefault(output, OrderedDict()) + self.clients.setdefault(output, {}) self.stale_droot = True @@ -591,7 +591,7 @@ def on_change_input(self, fgraph, app, i, old_r, new_r, reason): if self.clients[old_r][app] == 0: del self.clients[old_r][app] - self.clients.setdefault(new_r, OrderedDict()).setdefault(app, 0) + self.clients.setdefault(new_r, {}).setdefault(app, 0) self.clients[new_r][app] += 1 # UPDATE self.view_i, self.view_o @@ -632,7 +632,7 @@ def validate(self, fgraph): if self.algo == "fast": if self.fail_validate: app_err_pairs = self.fail_validate - self.fail_validate = OrderedDict() + self.fail_validate = {} # self.fail_validate can only be a hint that maybe/probably # there is a cycle.This is because inside replace() we could # record many reasons to not accept a change, but we don't @@ -674,12 +674,8 @@ def orderings(self, fgraph, ordered=True): c) an Apply destroys (illegally) one of its own inputs by aliasing """ - if ordered: - set_type = OrderedSet - rval = OrderedDict() - else: - set_type = set - rval = dict() + set_type = OrderedSet if ordered else set + rval = {} if self.destroyers: # BUILD DATA STRUCTURES diff --git a/pytensor/graph/features.py b/pytensor/graph/features.py index 2bce7f1748..93321fa61f 100644 --- a/pytensor/graph/features.py +++ b/pytensor/graph/features.py @@ -2,7 +2,6 @@ import sys import time import warnings -from collections import OrderedDict from functools import partial from io import StringIO @@ -324,7 +323,7 @@ def orderings(self, fgraph): might be broken for all intents and purposes. """ - return OrderedDict() + return {} def clone(self): """Create a clone that can be attached to a new `FunctionGraph`. diff --git a/pytensor/graph/fg.py b/pytensor/graph/fg.py index e06845324b..d55d55ff7c 100644 --- a/pytensor/graph/fg.py +++ b/pytensor/graph/fg.py @@ -1,7 +1,7 @@ """A container for specifying and manipulating a graph with distinct inputs and outputs.""" import time -from collections import OrderedDict +from collections import defaultdict from collections.abc import Iterable, Sequence from typing import TYPE_CHECKING, Any, Literal, Union, cast @@ -109,7 +109,7 @@ def __init__( inputs = [cast(Variable, _memo[i]) for i in inputs] self.execute_callbacks_time: float = 0.0 - self.execute_callbacks_times: dict[Feature, float] = {} + self.execute_callbacks_times: dict[Feature, float] = defaultdict(float) if features is None: features = [] @@ -270,8 +270,10 @@ def remove_client( self.execute_callbacks("on_prune", apply_node, reason) - for i, in_var in enumerate(apply_node.inputs): - removal_stack.append((in_var, (apply_node, i))) + removal_stack.extend( + (in_var, (apply_node, i)) + for i, in_var in enumerate(apply_node.inputs) + ) if remove_if_empty: del clients[var] @@ -671,7 +673,6 @@ def attach_feature(self, feature: Feature) -> None: attach(self) except AlreadyThere: return - self.execute_callbacks_times.setdefault(feature, 0.0) # It would be nice if we could require a specific class instead of # a "workalike" so we could do actual error checking # if not isinstance(feature, Feature): @@ -767,12 +768,12 @@ def orderings(self) -> dict[Apply, list[Apply]]: """ assert isinstance(self._features, list) - all_orderings: list[OrderedDict] = [] + all_orderings: list[dict] = [] for feature in self._features: if hasattr(feature, "orderings"): orderings = feature.orderings(self) - if not isinstance(orderings, OrderedDict): + if not isinstance(orderings, dict): raise TypeError( "Non-deterministic return value from " + str(feature.orderings) @@ -793,7 +794,7 @@ def orderings(self) -> dict[Apply, list[Apply]]: return all_orderings[0].copy() else: # If there is more than 1 ordering, combine them. - ords: dict[Apply, list[Apply]] = OrderedDict() + ords: dict[Apply, list[Apply]] = {} for orderings in all_orderings: for node, prereqs in orderings.items(): ords.setdefault(node, []).extend(prereqs) diff --git a/pytensor/graph/rewriting/basic.py b/pytensor/graph/rewriting/basic.py index aba9a3fa39..1de17e8680 100644 --- a/pytensor/graph/rewriting/basic.py +++ b/pytensor/graph/rewriting/basic.py @@ -10,7 +10,7 @@ import time import traceback import warnings -from collections import UserList, defaultdict, deque +from collections import Counter, UserList, defaultdict, deque from collections.abc import Callable, Iterable, Sequence from collections.abc import Iterable as IterableType from functools import _compose_mro, partial, reduce # type: ignore @@ -479,9 +479,9 @@ def merge_profile(prof1, prof2): new_sub_profile.append(p[6][idx]) new_rewrite = SequentialGraphRewriter(*new_l) - new_nb_nodes = [] - for p1, p2 in zip(prof1[8], prof2[8]): - new_nb_nodes.append((p1[0] + p2[0], p1[1] + p2[1])) + new_nb_nodes = [ + (p1[0] + p2[0], p1[1] + p2[1]) for p1, p2 in zip(prof1[8], prof2[8]) + ] new_nb_nodes.extend(prof1[8][len(new_nb_nodes) :]) new_nb_nodes.extend(prof2[8][len(new_nb_nodes) :]) @@ -960,9 +960,9 @@ def register(self, rewriter: NodeRewriter, tag_list: IterableType[str]): tracks = rewriter.tracks() if tracks: + self._tracks.extend(tracks) for c in tracks: self.track_dict[c].append(rewriter) - self._tracks.append(c) for tag in tag_list: self.tag_dict[tag].append(rewriter) @@ -1153,8 +1153,8 @@ class OpToRewriterTracker: r"""A container that maps `NodeRewriter`\s to `Op` instances and `Op`-type inheritance.""" def __init__(self): - self.tracked_instances: dict[Op, list[NodeRewriter]] = {} - self.tracked_types: dict[type, list[NodeRewriter]] = {} + self.tracked_instances: dict[Op, list[NodeRewriter]] = defaultdict(list) + self.tracked_types: dict[type, list[NodeRewriter]] = defaultdict(list) self.untracked_rewrites: list[NodeRewriter] = [] def add_tracker(self, rw: NodeRewriter): @@ -1166,9 +1166,9 @@ def add_tracker(self, rw: NodeRewriter): else: for c in tracks: if isinstance(c, type): - self.tracked_types.setdefault(c, []).append(rw) + self.tracked_types[c].append(rw) else: - self.tracked_instances.setdefault(c, []).append(rw) + self.tracked_instances[c].append(rw) def _find_impl(self, cls) -> list[NodeRewriter]: r"""Returns the `NodeRewriter`\s that apply to `cls` based on inheritance. @@ -1250,22 +1250,16 @@ def __init__( self.profile = profile if self.profile: - self.time_rewrites: dict[Rewriter, float] = {} - self.process_count: dict[Rewriter, int] = {} - self.applied_true: dict[Rewriter, int] = {} - self.node_created: dict[Rewriter, int] = {} + self.time_rewrites: dict[Rewriter, float] = defaultdict(float) + self.process_count: dict[Rewriter, int] = Counter() + self.applied_true: dict[Rewriter, int] = Counter() + self.node_created: dict[Rewriter, int] = Counter() self.tracker = OpToRewriterTracker() for o in self.rewrites: self.tracker.add_tracker(o) - if self.profile: - self.time_rewrites.setdefault(o, 0.0) - self.process_count.setdefault(o, 0) - self.applied_true.setdefault(o, 0) - self.node_created.setdefault(o, 0) - def __str__(self): return getattr( self, @@ -2316,7 +2310,7 @@ def apply(self, fgraph, start_from=None): changed = True max_use_abort = False rewriter_name = None - global_process_count = {} + global_process_count = Counter() start_nb_nodes = len(fgraph.apply_nodes) max_nb_nodes = len(fgraph.apply_nodes) max_use = max_nb_nodes * self.max_use_ratio @@ -2324,22 +2318,21 @@ def apply(self, fgraph, start_from=None): loop_timing = [] loop_process_count = [] global_rewriter_timing = [] - time_rewriters = {} + time_rewriters = defaultdict(float) io_toposort_timing = [] nb_nodes = [] - node_created = {} + node_created = Counter() global_sub_profs = [] final_sub_profs = [] cleanup_sub_profs = [] - for rewriter in ( - self.global_rewriters - + list(self.get_node_rewriters()) - + self.final_rewriters - + self.cleanup_rewriters - ): - global_process_count.setdefault(rewriter, 0) - time_rewriters.setdefault(rewriter, 0) - node_created.setdefault(rewriter, 0) + + for rewriter in [ + *self.global_rewriters, + *self.get_node_rewriters(), + *self.final_rewriters, + *self.cleanup_rewriters, + ]: + time_rewriters[rewriter] += 0 def apply_cleanup(profs_dict): changed = False @@ -2351,7 +2344,6 @@ def apply_cleanup(profs_dict): time_rewriters[crewriter] += time.perf_counter() - t_rewrite profs_dict[crewriter].append(sub_prof) if change_tracker.changed: - process_count.setdefault(crewriter, 0) process_count[crewriter] += 1 global_process_count[crewriter] += 1 changed = True @@ -2359,7 +2351,7 @@ def apply_cleanup(profs_dict): return changed while changed and not max_use_abort: - process_count = {} + process_count = Counter() t0 = time.perf_counter() changed = False iter_cleanup_sub_profs = {} @@ -2376,7 +2368,6 @@ def apply_cleanup(profs_dict): time_rewriters[grewrite] += time.perf_counter() - t_rewrite sub_profs.append(sub_prof) if change_tracker.changed: - process_count.setdefault(grewrite, 0) process_count[grewrite] += 1 global_process_count[grewrite] += 1 changed = True @@ -2431,7 +2422,6 @@ def chin_(node, i, r, new_r, reason): time_rewriters[node_rewriter] += time.perf_counter() - t_rewrite if not node_rewriter_change: continue - process_count.setdefault(node_rewriter, 0) process_count[node_rewriter] += 1 global_process_count[node_rewriter] += 1 changed = True @@ -2459,7 +2449,6 @@ def chin_(node, i, r, new_r, reason): time_rewriters[grewrite] += time.perf_counter() - t_rewrite sub_profs.append(sub_prof) if change_tracker.changed: - process_count.setdefault(grewrite, 0) process_count[grewrite] += 1 global_process_count[grewrite] += 1 changed = True @@ -2514,7 +2503,7 @@ def chin_(node, i, r, new_r, reason): (start_nb_nodes, end_nb_nodes, max_nb_nodes), global_rewriter_timing, nb_nodes, - time_rewriters, + dict(time_rewriters), io_toposort_timing, node_created, global_sub_profs, @@ -2597,14 +2586,7 @@ def print_profile(cls, stream, prof, level=0): count_rewrite = [] not_used = [] not_used_time = 0 - process_count = {} - for o in ( - rewrite.global_rewriters - + list(rewrite.get_node_rewriters()) - + list(rewrite.final_rewriters) - + list(rewrite.cleanup_rewriters) - ): - process_count.setdefault(o, 0) + process_count = Counter() for count in loop_process_count: for o, v in count.items(): process_count[o] += v diff --git a/pytensor/graph/rewriting/db.py b/pytensor/graph/rewriting/db.py index 65a90abb8b..a35ae6fe5d 100644 --- a/pytensor/graph/rewriting/db.py +++ b/pytensor/graph/rewriting/db.py @@ -1,6 +1,7 @@ import copy import math import sys +from collections import defaultdict from collections.abc import Iterable, Sequence from functools import cmp_to_key from io import StringIO @@ -9,7 +10,6 @@ from pytensor.configdefaults import config from pytensor.graph.rewriting import basic as pytensor_rewriting from pytensor.misc.ordered_set import OrderedSet -from pytensor.utils import DefaultOrderedDict RewritesType = pytensor_rewriting.GraphRewriter | pytensor_rewriting.NodeRewriter @@ -23,7 +23,7 @@ class RewriteDatabase: """ def __init__(self): - self.__db__ = DefaultOrderedDict(OrderedSet) + self.__db__ = defaultdict(OrderedSet) self._names = set() # This will be reset by `self.register` (via `obj.name` by the thing # doing the registering) diff --git a/pytensor/link/basic.py b/pytensor/link/basic.py index 767656e081..b78ed7e670 100644 --- a/pytensor/link/basic.py +++ b/pytensor/link/basic.py @@ -524,12 +524,13 @@ def make_thunk(self, **kwargs): thunk_groups = list(zip(*thunk_lists)) order = [x[0] for x in zip(*order_lists)] - to_reset = [] - for thunks, node in zip(thunk_groups, order): - for j, output in enumerate(node.outputs): - if output in no_recycling: - for thunk in thunks: - to_reset.append(thunk.outputs[j]) + to_reset = [ + thunk.outputs[j] + for thunks, node in zip(thunk_groups, order) + for j, output in enumerate(node.outputs) + if output in no_recycling + for thunk in thunks + ] wrapper = self.wrapper pre = self.pre @@ -696,18 +697,16 @@ def make_all(self, input_storage=None, output_storage=None, storage_map=None): computed, last_user = gc_helper(nodes) if self.allow_gc: - post_thunk_old_storage = [] - - for node in nodes: - post_thunk_old_storage.append( - [ - storage_map[input] - for input in node.inputs - if (input in computed) - and (input not in fgraph.outputs) - and (node == last_user[input]) - ] - ) + post_thunk_old_storage = [ + [ + storage_map[input] + for input in node.inputs + if (input in computed) + and (input not in fgraph.outputs) + and (node == last_user[input]) + ] + for node in nodes + ] else: post_thunk_old_storage = None diff --git a/pytensor/link/c/basic.py b/pytensor/link/c/basic.py index e11247c9b3..cfbacb9414 100644 --- a/pytensor/link/c/basic.py +++ b/pytensor/link/c/basic.py @@ -1129,19 +1129,18 @@ def __compile__( ) def get_init_tasks(self): - init_tasks = [] - tasks = [] + vars = [v for v in self.variables if v not in self.consts] id = 1 - for v in self.variables: - if v in self.consts: - continue - init_tasks.append((v, "init", id)) - tasks.append((v, "get", id + 1)) - id += 2 - for node in self.node_order: - tasks.append((node, "code", id)) - init_tasks.append((node, "init", id + 1)) - id += 2 + init_tasks = [(v, "init", id + 2 * i) for i, v in enumerate(vars)] + tasks = [(v, "get", id + 2 * i + 1) for i, v in enumerate(vars)] + + id += 2 * len(vars) + tasks.extend( + (node, "code", id + 2 * i) for i, node in enumerate(self.node_order) + ) + init_tasks.extend( + (node, "init", id + 2 * i + 1) for i, node in enumerate(self.node_order) + ) return init_tasks, tasks def make_thunk( @@ -1492,12 +1491,11 @@ def in_sig(i, topological_pos, i_idx): # graph's information used to compute the key. If we mistakenly # pretend that inputs with clients don't have any, were are only using # those inputs more than once to compute the key. - for ipos, var in [ - (i, var) - for i, var in enumerate(fgraph.inputs) + sig.extend( + (var.type, in_sig(var, -1, ipos)) + for ipos, var in enumerate(fgraph.inputs) if not len(fgraph.clients[var]) - ]: - sig.append((var.type, in_sig(var, -1, ipos))) + ) # crystalize the signature and version sig = tuple(sig) diff --git a/pytensor/link/c/cmodule.py b/pytensor/link/c/cmodule.py index 8ad07a04f5..1532835a3c 100644 --- a/pytensor/link/c/cmodule.py +++ b/pytensor/link/c/cmodule.py @@ -641,7 +641,7 @@ class ModuleCache: The cache contains one directory for each module, containing: - the dynamic library file itself (e.g. ``.so/.pyd``), - an empty ``__init__.py`` file, so Python can import it, - - a file containing the source code for the module (e.g. ``mod.cpp/mod.cu``), + - a file containing the source code for the module (e.g. ``mod.cpp``), - a ``key.pkl`` file, containing a KeyData object with all the keys associated with that module, - possibly a ``delete.me`` file, meaning this directory has been marked diff --git a/pytensor/link/c/op.py b/pytensor/link/c/op.py index c07f903de8..582947e45d 100644 --- a/pytensor/link/c/op.py +++ b/pytensor/link/c/op.py @@ -220,12 +220,7 @@ def prepare_node(self, node, storage_map, compute_map, impl): def lquote_macro(txt: str) -> str: """Turn the last line of text into a ``\\``-commented line.""" - res = [] - spl = txt.split("\n") - for l in spl[:-1]: - res.append(l + " \\") - res.append(spl[-1]) - return "\n".join(res) + return " \\\n".join(txt.split("\n")) def get_sub_macros(sub: dict[str, str]) -> tuple[str, str]: @@ -240,21 +235,17 @@ def get_sub_macros(sub: dict[str, str]) -> tuple[str, str]: return "\n".join(define_macros), "\n".join(undef_macros) -def get_io_macros( - inputs: list[str], outputs: list[str] -) -> tuple[list[str]] | tuple[str, str]: - define_macros = [] - undef_macros = [] +def get_io_macros(inputs: list[str], outputs: list[str]) -> tuple[str, str]: + define_inputs = [f"#define INPUT_{int(i)} {inp}" for i, inp in enumerate(inputs)] + define_outputs = [f"#define OUTPUT_{int(i)} {out}" for i, out in enumerate(outputs)] - for i, inp in enumerate(inputs): - define_macros.append(f"#define INPUT_{int(i)} {inp}") - undef_macros.append(f"#undef INPUT_{int(i)}") + undef_inputs = [f"#undef INPUT_{int(i)}" for i in range(len(inputs))] + undef_outputs = [f"#undef OUTPUT_{int(i)}" for i in range(len(outputs))] - for i, out in enumerate(outputs): - define_macros.append(f"#define OUTPUT_{int(i)} {out}") - undef_macros.append(f"#undef OUTPUT_{int(i)}") + define_all = "\n".join(define_inputs + define_outputs) + undef_all = "\n".join(undef_inputs + undef_outputs) - return "\n".join(define_macros), "\n".join(undef_macros) + return define_all, undef_all class ExternalCOp(COp): @@ -560,9 +551,10 @@ def get_c_macros( define_macros.append(define_template % ("APPLY_SPECIFIC(str)", f"str##_{name}")) undef_macros.append(undef_template % "APPLY_SPECIFIC") - for n, v in self.__get_op_params(): - define_macros.append(define_template % (n, v)) - undef_macros.append(undef_template % (n,)) + define_macros.extend( + define_template % (n, v) for n, v in self.__get_op_params() + ) + undef_macros.extend(undef_template % (n,) for n, _ in self.__get_op_params()) return "\n".join(define_macros), "\n".join(undef_macros) diff --git a/pytensor/link/numba/dispatch/basic.py b/pytensor/link/numba/dispatch/basic.py index a341231674..2b934d049c 100644 --- a/pytensor/link/numba/dispatch/basic.py +++ b/pytensor/link/numba/dispatch/basic.py @@ -131,21 +131,19 @@ def create_numba_signature( reduce_to_scalar: bool = False, ) -> numba.types.Type: """Create a Numba type for the signature of an `Apply` node or `FunctionGraph`.""" - input_types = [] - for inp in node_or_fgraph.inputs: - input_types.append( - get_numba_type( - inp.type, force_scalar=force_scalar, reduce_to_scalar=reduce_to_scalar - ) + input_types = [ + get_numba_type( + inp.type, force_scalar=force_scalar, reduce_to_scalar=reduce_to_scalar ) + for inp in node_or_fgraph.inputs + ] - output_types = [] - for out in node_or_fgraph.outputs: - output_types.append( - get_numba_type( - out.type, force_scalar=force_scalar, reduce_to_scalar=reduce_to_scalar - ) + output_types = [ + get_numba_type( + out.type, force_scalar=force_scalar, reduce_to_scalar=reduce_to_scalar ) + for out in node_or_fgraph.outputs + ] if len(output_types) > 1: return numba.types.Tuple(output_types)(*input_types) diff --git a/pytensor/link/numba/dispatch/elemwise.py b/pytensor/link/numba/dispatch/elemwise.py index e93df12ec6..f0820c3899 100644 --- a/pytensor/link/numba/dispatch/elemwise.py +++ b/pytensor/link/numba/dispatch/elemwise.py @@ -520,9 +520,7 @@ def elemwise(*inputs): if length == 1 and shape and iter_length != 1 and not allow_bc: raise ValueError("Broadcast not allowed.") - outputs = [] - for dtype in output_dtypes: - outputs.append(np.empty(shape, dtype=dtype)) + outputs = [np.empty(shape, dtype=dtype) for dtype in output_dtypes] for idx in np.ndindex(shape): vals = [input[idx] for input in inputs_bc] diff --git a/pytensor/link/numba/dispatch/scan.py b/pytensor/link/numba/dispatch/scan.py index 34f088fd54..94ecb107b6 100644 --- a/pytensor/link/numba/dispatch/scan.py +++ b/pytensor/link/numba/dispatch/scan.py @@ -268,15 +268,15 @@ def add_output_storage_post_proc_stmt( output_taps = inner_in_names_to_output_taps.get( outer_in_name, [tap_storage_size] ) - for out_tap in output_taps: - inner_out_to_outer_in_stmts.append( - idx_to_str( - storage_name, - out_tap, - size=storage_size_name, - allow_scalar=True, - ) + inner_out_to_outer_in_stmts.extend( + idx_to_str( + storage_name, + out_tap, + size=storage_size_name, + allow_scalar=True, ) + for out_tap in output_taps + ) add_output_storage_post_proc_stmt( storage_name, output_taps, storage_size_name diff --git a/pytensor/link/vm.py b/pytensor/link/vm.py index 1cb3c01265..2d0ef805ba 100644 --- a/pytensor/link/vm.py +++ b/pytensor/link/vm.py @@ -246,10 +246,8 @@ def update_profile(self, profile): for node, thunk, t, c in zip( self.nodes, self.thunks, self.call_times, self.call_counts ): - profile.apply_time.setdefault((self.fgraph, node), 0.0) profile.apply_time[(self.fgraph, node)] += t - profile.apply_callcount.setdefault((self.fgraph, node), 0) profile.apply_callcount[(self.fgraph, node)] += c profile.apply_cimpl[node] = hasattr(thunk, "cthunk") @@ -1111,9 +1109,8 @@ def make_vm( # builds the list of prereqs induced by e.g. destroy_handler ords = self.fgraph.orderings() node_prereqs = [] - node_output_size = [] + node_output_size = [0] * len(nodes) for i, node in enumerate(nodes): - node_output_size.append(0) prereq_var_idxs = [] for prereq_node in ords.get(node, []): prereq_var_idxs.extend([vars_idx[v] for v in prereq_node.outputs]) diff --git a/pytensor/misc/check_duplicate_key.py b/pytensor/misc/check_duplicate_key.py index fa42cb6558..c8accf100d 100644 --- a/pytensor/misc/check_duplicate_key.py +++ b/pytensor/misc/check_duplicate_key.py @@ -1,6 +1,7 @@ import os import pickle import sys +from collections import Counter from pytensor.configdefaults import config @@ -15,14 +16,13 @@ else: dirs = os.listdir(config.compiledir) dirs = [os.path.join(config.compiledir, d) for d in dirs] -keys: dict = {} # key -> nb seen +keys: Counter[bytes] = Counter() # key -> nb seen mods: dict = {} for dir in dirs: key = None try: - with open(os.path.join(dir, "key.pkl")) as f: + with open(os.path.join(dir, "key.pkl"), "rb") as f: key = f.read() - keys.setdefault(key, 0) keys[key] += 1 del f except OSError: @@ -30,62 +30,49 @@ pass try: path = os.path.join(dir, "mod.cpp") - if not os.path.exists(path): - path = os.path.join(dir, "mod.cu") - with open(path) as f: - mod = f.read() + with open(path) as fmod: + mod = fmod.read() mods.setdefault(mod, ()) mods[mod] += (key,) del mod - del f + del fmod del path except OSError: - print(dir, "don't have a mod.{cpp,cu} file") + print(dir, "don't have a mod.cpp file") if DISPLAY_DUPLICATE_KEYS: for k, v in keys.items(): if v > 1: print("Duplicate key (%i copies): %s" % (v, pickle.loads(k))) -nbs_keys: dict = {} # nb seen -> now many key -for val in keys.values(): - nbs_keys.setdefault(val, 0) - nbs_keys[val] += 1 +# nb seen -> how many keys +nbs_keys = Counter(val for val in keys.values()) -nbs_mod: dict = {} # nb seen -> how many key -nbs_mod_to_key = {} # nb seen -> keys -more_than_one = 0 -for mod, kk in mods.items(): - val = len(kk) - nbs_mod.setdefault(val, 0) - nbs_mod[val] += 1 - if val > 1: - more_than_one += 1 - nbs_mod_to_key[val] = kk +# nb seen -> how many keys +nbs_mod = Counter(len(kk) for kk in mods.values()) +# nb seen -> keys +nbs_mod_to_key = {len(kk): kk for kk in mods.values()} +more_than_one = sum(len(kk) > 1 for kk in mods.values()) if DISPLAY_MOST_FREQUENT_DUPLICATE_CCODE: - m = max(nbs_mod.keys()) - print("The keys associated to the mod.{cpp,cu} with the most number of copy:") + m = max(nbs_mod) + print("The keys associated to the mod.cpp with the most number of copy:") for kk in nbs_mod_to_key[m]: kk = pickle.loads(kk) print(kk) print("key.pkl histograph") -l = list(nbs_keys.items()) -l.sort() -print(l) +print(sorted(nbs_keys.items())) -print("mod.{cpp,cu} histogram") -l = list(nbs_mod.items()) -l.sort() -print(l) +print("mod.cpp histogram") +print(sorted(nbs_mod.items())) -total = sum(len(k) for k in list(mods.values())) +total = sum(len(k) for k in mods.values()) uniq = len(mods) useless = total - uniq -print("mod.{cpp,cu} total:", total) -print("mod.{cpp,cu} uniq:", uniq) -print("mod.{cpp,cu} with more than 1 copy:", more_than_one) -print("mod.{cpp,cu} useless:", useless, float(useless) / total * 100, "%") +print("mod.cpp total:", total) +print("mod.cpp uniq:", uniq) +print("mod.cpp with more than 1 copy:", more_than_one) +print("mod.cpp useless:", useless, float(useless) / total * 100, "%") print("nb directory", len(dirs)) diff --git a/pytensor/misc/frozendict.py b/pytensor/misc/frozendict.py index 9e87bdde24..909053ffbe 100644 --- a/pytensor/misc/frozendict.py +++ b/pytensor/misc/frozendict.py @@ -1,7 +1,6 @@ # License : https://github.com/slezica/python-frozendict/blob/master/LICENSE.txt -import collections import functools import operator from collections.abc import Mapping @@ -43,11 +42,3 @@ def __hash__(self): self._hash = functools.reduce(operator.xor, hashes, 0) return self._hash - - -class FrozenOrderedDict(frozendict): - """ - A FrozenDict subclass that maintains key order - """ - - dict_cls = collections.OrderedDict diff --git a/pytensor/misc/ordered_set.py b/pytensor/misc/ordered_set.py index 13b090cf65..c6b6566644 100644 --- a/pytensor/misc/ordered_set.py +++ b/pytensor/misc/ordered_set.py @@ -1,197 +1,44 @@ -import types -import weakref -from collections.abc import MutableSet - - -def check_deterministic(iterable): - # Most places where OrderedSet is used, pytensor interprets any exception - # whatsoever as a problem that an optimization introduced into the graph. - # If I raise a TypeError when the DestroyHandler tries to do something - # non-deterministic, it will just result in optimizations getting ignored. - # So I must use an assert here. In the long term we should fix the rest of - # pytensor to use exceptions correctly, so that this can be a TypeError. - if iterable is not None: - if not isinstance( - iterable, list | tuple | OrderedSet | types.GeneratorType | str | dict - ): - if len(iterable) > 1: - # We need to accept length 1 size to allow unpickle in tests. - raise AssertionError( - "Get an not ordered iterable when one was expected" - ) - - -# Copyright (C) 2009 Raymond Hettinger -# Permission is hereby granted, free of charge, to any person obtaining a -# copy of this software and associated documentation files (the -# "Software"), to deal in the Software without restriction, including -# without limitation the rights to use, copy, modify, merge, publish, -# distribute, sublicense, and/or sell copies of the Software, and to permit -# persons to whom the Software is furnished to do so, subject to the -# following conditions: - -# The above copyright notice and this permission notice shall be included -# in all copies or substantial portions of the Software. - -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS -# OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF -# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. -# IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY -# CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, -# TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE -# SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -# {{{ http://code.activestate.com/recipes/576696/ (r5) - - -class Link: - # This make that we need to use a different pickle protocol - # then the default. Otherwise, there is pickling errors - __slots__ = "prev", "next", "key", "__weakref__" - - def __getstate__(self): - # weakref.proxy don't pickle well, so we use weakref.ref - # manually and don't pickle the weakref. - # We restore the weakref when we unpickle. - ret = [self.prev(), self.next()] - try: - ret.append(self.key) - except AttributeError: - pass - return ret - - def __setstate__(self, state): - self.prev = weakref.ref(state[0]) - self.next = weakref.ref(state[1]) - if len(state) == 3: - self.key = state[2] +from collections.abc import Iterable, Iterator, MutableSet +from typing import Any class OrderedSet(MutableSet): - "Set the remembers the order elements were added" - - # Big-O running times for all methods are the same as for regular sets. - # The internal self.__map dictionary maps keys to links in a doubly linked list. - # The circular doubly linked list starts and ends with a sentinel element. - # The sentinel element never gets deleted (this simplifies the algorithm). - # The prev/next links are weakref proxies (to prevent circular references). - # Individual links are kept alive by the hard reference in self.__map. - # Those hard references disappear when a key is deleted from an OrderedSet. + values: dict[Any, None] - # Added by IG-- pre-existing pytensor code expected sets - # to have this method - def update(self, iterable): - check_deterministic(iterable) - self |= iterable - - def __init__(self, iterable=None): - # Checks added by IG - check_deterministic(iterable) - self.__root = root = Link() # sentinel node for doubly linked list - root.prev = root.next = weakref.ref(root) - self.__map = {} # key --> link - if iterable is not None: - self |= iterable - - def __len__(self): - return len(self.__map) - - def __contains__(self, key): - return key in self.__map - - def add(self, key): - # Store new key in a new link at the end of the linked list - if key not in self.__map: - self.__map[key] = link = Link() - root = self.__root - last = root.prev - link.prev, link.next, link.key = last, weakref.ref(root), key - last().next = root.prev = weakref.ref(link) - - def union(self, s): - check_deterministic(s) - n = self.copy() - for elem in s: - if elem not in n: - n.add(elem) - return n - - def intersection_update(self, s): - l = [] - for elem in self: - if elem not in s: - l.append(elem) - for elem in l: - self.remove(elem) - return self + def __init__(self, iterable: Iterable | None = None) -> None: + if iterable is None: + self.values = {} + else: + self.values = {value: None for value in iterable} - def difference_update(self, s): - check_deterministic(s) - for elem in s: - if elem in self: - self.remove(elem) - return self + def __contains__(self, value) -> bool: + return value in self.values - def copy(self): - n = OrderedSet() - n.update(self) - return n + def __iter__(self) -> Iterator: + yield from self.values - def discard(self, key): - # Remove an existing item using self.__map to find the link which is - # then removed by updating the links in the predecessor and successors. - if key in self.__map: - link = self.__map.pop(key) - link.prev().next = link.next - link.next().prev = link.prev + def __len__(self) -> int: + return len(self.values) - def __iter__(self): - # Traverse the linked list in order. - root = self.__root - curr = root.next() - while curr is not root: - yield curr.key - curr = curr.next() + def add(self, value) -> None: + self.values[value] = None - def __reversed__(self): - # Traverse the linked list in reverse order. - root = self.__root - curr = root.prev() - while curr is not root: - yield curr.key - curr = curr.prev() + def discard(self, value) -> None: + if value in self.values: + del self.values[value] - def pop(self, last=True): - if not self: - raise KeyError("set is empty") - if last: - key = next(reversed(self)) - else: - key = next(iter(self)) - self.discard(key) - return key + def copy(self) -> "OrderedSet": + return OrderedSet(self) - def __repr__(self): - if not self: - return f"{self.__class__.__name__}()" - return f"{self.__class__.__name__}({list(self)!r})" - - def __eq__(self, other): - # Note that we implement only the comparison to another - # `OrderedSet`, and not to a regular `set`, because otherwise we - # could have a non-symmetric equality relation like: - # my_ordered_set == my_set and my_set != my_ordered_set - if isinstance(other, OrderedSet): - return len(self) == len(other) and list(self) == list(other) - elif isinstance(other, set): - # Raise exception to avoid confusion. - raise TypeError( - "Cannot compare an `OrderedSet` to a `set` because " - "this comparison cannot be made symmetric: please " - "manually cast your `OrderedSet` into `set` before " - "performing this comparison." - ) - else: - return NotImplemented + def update(self, other: Iterable) -> None: + for value in other: + self.add(value) + def union(self, other: Iterable) -> "OrderedSet": + new_set = self.copy() + new_set.update(other) + return new_set -# end of http://code.activestate.com/recipes/576696/ }}} + def difference_update(self, other: Iterable) -> None: + for value in other: + self.discard(value) diff --git a/pytensor/misc/pkl_utils.py b/pytensor/misc/pkl_utils.py index ae3549f1cb..bd652d573d 100644 --- a/pytensor/misc/pkl_utils.py +++ b/pytensor/misc/pkl_utils.py @@ -10,7 +10,7 @@ import sys import tempfile import zipfile -from collections import defaultdict +from collections import Counter from contextlib import closing from io import BytesIO from pickle import HIGHEST_PROTOCOL @@ -144,7 +144,7 @@ class PersistentSharedVariableID(PersistentNdarrayID): def __init__(self, zip_file, allow_unnamed=True, allow_duplicates=True): super().__init__(zip_file) - self.name_counter = defaultdict(int) + self.name_counter = Counter() self.ndarray_names = {} self.allow_unnamed = allow_unnamed self.allow_duplicates = allow_duplicates diff --git a/pytensor/scalar/basic.py b/pytensor/scalar/basic.py index 2a3db168ba..865a68b584 100644 --- a/pytensor/scalar/basic.py +++ b/pytensor/scalar/basic.py @@ -1575,9 +1575,7 @@ def get_grad(self, elem): def L_op(self, inputs, outputs, gout): (x, low, hi) = inputs (gz,) = gout - grads = [] - for elem in [x, low, hi]: - grads.append(self.get_grad(elem)) + grads = [self.get_grad(elem) for elem in [x, low, hi]] return grads diff --git a/pytensor/scan/basic.py b/pytensor/scan/basic.py index d830e1a0ce..0f7c9dcc69 100644 --- a/pytensor/scan/basic.py +++ b/pytensor/scan/basic.py @@ -646,9 +646,7 @@ def wrap_into_list(x): # Since we've added all sequences now we need to level them up based on # n_steps or their different shapes - lengths_vec = [] - for seq in scan_seqs: - lengths_vec.append(seq.shape[0]) + lengths_vec = [seq.shape[0] for seq in scan_seqs] if not isNaN_or_Inf_or_None(n_steps): # ^ N_steps should also be considered diff --git a/pytensor/scan/op.py b/pytensor/scan/op.py index 4235220e81..592af1c44d 100644 --- a/pytensor/scan/op.py +++ b/pytensor/scan/op.py @@ -46,7 +46,6 @@ import dataclasses import logging import time -from collections import OrderedDict from collections.abc import Callable, Iterable from copy import copy from itertools import chain, product @@ -2188,7 +2187,7 @@ def infer_shape(self, fgraph, node, input_shapes): # corresponding outer inputs that the Scan would use as input for # any given iteration. For simplicity, we use iteration 0. inner_ins_shapes = [] - out_equivalent = OrderedDict() + out_equivalent = {} # The two following blocks are commented as it cause in some # cases extra scans in the graph. See gh-XXX for the @@ -2469,7 +2468,7 @@ def compute_all_gradients(known_grads): if (x in diff_inputs) and get_inp_idx(self_inputs.index(x)) in connected_inputs ] - gmp = OrderedDict() + gmp = {} # Required in case there is a pair of variables X and Y, with X # used to compute Y, for both of which there is an external @@ -2478,7 +2477,7 @@ def compute_all_gradients(known_grads): # it will be the sum of the external gradient signal and the # gradient obtained by propagating Y's external gradient signal # to X. - known_grads = OrderedDict([(k.copy(), v) for (k, v) in known_grads.items()]) + known_grads = {k.copy(): v for (k, v) in known_grads.items()} grads = grad( cost=None, @@ -2548,7 +2547,7 @@ def compute_all_gradients(known_grads): dC_dXt = safe_new(dC_douts[idx][0]) dC_dXts.append(dC_dXt) - known_grads = OrderedDict() + known_grads = {} dc_dxts_idx = 0 for i in range(len(diff_outputs)): if i < idx_nitsot_start or i >= idx_nitsot_end: diff --git a/pytensor/scan/utils.py b/pytensor/scan/utils.py index b31ee6a9f6..c48dd028a4 100644 --- a/pytensor/scan/utils.py +++ b/pytensor/scan/utils.py @@ -3,7 +3,7 @@ import copy import dataclasses import logging -from collections import OrderedDict, namedtuple +from collections import namedtuple from collections.abc import Callable, Sequence from itertools import chain from typing import TYPE_CHECKING @@ -258,7 +258,7 @@ def __init__(self, valid=None, invalid=None, valid_equivalent=None): if invalid is None: invalid = [] if valid_equivalent is None: - valid_equivalent = OrderedDict() + valid_equivalent = {} # Nodes that are valid to have in the graph computing outputs self.valid = set(valid) @@ -416,7 +416,7 @@ def compress_outs(op, not_required, inputs): op_inputs = op.inner_inputs[: op_info.n_seqs] op_outputs = [] node_inputs = inputs[: op_info.n_seqs + 1] - map_old_new = OrderedDict() + map_old_new = {} offset = 0 ni_offset = op_info.n_seqs + 1 diff --git a/pytensor/tensor/basic.py b/pytensor/tensor/basic.py index 135433a0ab..414e3b6ed2 100644 --- a/pytensor/tensor/basic.py +++ b/pytensor/tensor/basic.py @@ -1629,10 +1629,7 @@ def infer_shape(self, fgraph, node, input_shapes): return [node.inputs[1:]] def connection_pattern(self, node): - rval = [[True]] - - for ipt in node.inputs[1:]: - rval.append([False]) + rval = [[True], *([False] for _ in node.inputs[1:])] return rval @@ -1859,9 +1856,7 @@ def grad(self, inputs, output_gradients): if self.dtype in discrete_dtypes: return [ipt.zeros_like().astype(config.floatX) for ipt in inputs] - grads = [] - for i, inp in enumerate(inputs): - grads.append(output_gradients[0][i]) + grads = [output_gradients[0][i] for i in range(len(inputs))] return grads def R_op(self, inputs, eval_points): @@ -2514,13 +2509,11 @@ def c_code(self, node, name, inputs, outputs, sub): (out,) = outputs fail = sub["fail"] adtype = node.inputs[0].type.dtype_specs()[1] - copy_to_list = [] - for i, inp in enumerate(tens): - copy_to_list.append( - f"""Py_INCREF({inp}); - PyList_SetItem(list, {i}, (PyObject*){inp});""" - ) + copy_to_list = ( + f"""Py_INCREF({inp}); PyList_SetItem(list, {i}, (PyObject*){inp});""" + for i, inp in enumerate(tens) + ) copy_inputs_to_list = "\n".join(copy_to_list) n = len(tens) @@ -3442,9 +3435,7 @@ def infer_shape(self, fgraph, node, in_shapes): shp_x = in_shapes[0] shp_y = in_shapes[1] assert len(shp_x) == len(shp_y) - out_shape = [] - for i in range(len(shp_x)): - out_shape.append(maximum(shp_x[i], shp_y[i])) + out_shape = [maximum(sx, sy) for sx, sy in zip(shp_x, shp_y, strict=True)] return [out_shape] def grad(self, inp, grads): diff --git a/pytensor/tensor/blockwise.py b/pytensor/tensor/blockwise.py index d9c634b6c9..92d425cb00 100644 --- a/pytensor/tensor/blockwise.py +++ b/pytensor/tensor/blockwise.py @@ -167,9 +167,8 @@ def infer_shape( batch_ndims = self.batch_ndim(node) core_dims: dict[str, Any] = {} - batch_shapes = [] + batch_shapes = [input_shape[:batch_ndims] for input_shape in input_shapes] for input_shape, sig in zip(input_shapes, self.inputs_sig): - batch_shapes.append(input_shape[:batch_ndims]) core_shape = input_shape[batch_ndims:] for core_dim, dim_name in zip(core_shape, sig): diff --git a/pytensor/tensor/elemwise.py b/pytensor/tensor/elemwise.py index ff6f969878..2fdc8e7fd5 100644 --- a/pytensor/tensor/elemwise.py +++ b/pytensor/tensor/elemwise.py @@ -201,12 +201,12 @@ def make_node(self, _input): f"input is incorrect for this op. Expected {self.input_broadcastable}, got {ib}." ) for expected, b in zip(self.input_broadcastable, ib): - if expected is True and b is False: + if expected and not b: raise TypeError( "The broadcastable pattern of the " f"input is incorrect for this op. Expected {self.input_broadcastable}, got {ib}." ) - # else, expected == b or expected is False and b is True + # else, expected == b or not expected and b # Both case are good. out_static_shape = [] @@ -1161,8 +1161,10 @@ def c_code_cache_version_apply(self, node): ], ) version.append(self.scalar_op.c_code_cache_version_apply(scalar_node)) - for i in node.inputs + node.outputs: - version.append(get_scalar_type(dtype=i.type.dtype).c_code_cache_version()) + version.extend( + get_scalar_type(dtype=i.type.dtype).c_code_cache_version() + for i in node.inputs + node.outputs + ) version.append(("openmp", self.openmp)) version.append(("openmp_elemwise_minsize", config.openmp_elemwise_minsize)) if all(version): @@ -1664,8 +1666,10 @@ def c_code_cache_version_apply(self, node): ], ) version.append(self.scalar_op.c_code_cache_version_apply(scalar_node)) - for i in node.inputs + node.outputs: - version.append(get_scalar_type(dtype=i.type.dtype).c_code_cache_version()) + version.extend( + get_scalar_type(dtype=i.type.dtype).c_code_cache_version() + for i in node.inputs + node.outputs + ) if all(version): return tuple(version) else: diff --git a/pytensor/tensor/rewriting/elemwise.py b/pytensor/tensor/rewriting/elemwise.py index ec049abc67..67701cc15a 100644 --- a/pytensor/tensor/rewriting/elemwise.py +++ b/pytensor/tensor/rewriting/elemwise.py @@ -1,6 +1,6 @@ import itertools import sys -from collections import defaultdict, deque +from collections import Counter, defaultdict, deque from collections.abc import Generator from functools import cache from typing import TypeVar @@ -127,7 +127,7 @@ def apply(self, fgraph): "nb_call_replace": 0, "nb_call_validate": 0, "nb_inconsistent": 0, - "ndim": defaultdict(int), + "ndim": Counter(), } check_each_change = config.tensor__insert_inplace_optimizer_validate_nb diff --git a/pytensor/tensor/rewriting/shape.py b/pytensor/tensor/rewriting/shape.py index 2ec1afa930..edac16bdee 100644 --- a/pytensor/tensor/rewriting/shape.py +++ b/pytensor/tensor/rewriting/shape.py @@ -777,7 +777,7 @@ def f(fgraph, node): # constant value... but in the meantime, better not apply this # rewrite. if rval.type.ndim == node.outputs[0].type.ndim and all( - s1 == s1 + s1 == s2 for s1, s2 in zip(rval.type.shape, node.outputs[0].type.shape) if s1 == 1 or s2 == 1 ): diff --git a/pytensor/tensor/subtensor.py b/pytensor/tensor/subtensor.py index f0f5555499..2cec476c4a 100644 --- a/pytensor/tensor/subtensor.py +++ b/pytensor/tensor/subtensor.py @@ -952,10 +952,7 @@ def grad(self, inputs, grads): return [first] + [DisconnectedType()()] * len(rest) def connection_pattern(self, node): - rval = [[True]] - - for ipt in node.inputs[1:]: - rval.append([False]) + rval = [[True], *([False] for _ in node.inputs[1:])] return rval @@ -1963,10 +1960,7 @@ def R_op(self, inputs, eval_points): return self(eval_points[0], eval_points[1], *inputs[2:], return_list=True) def connection_pattern(self, node): - rval = [[True], [True]] - - for ipt in node.inputs[2:]: - rval.append([False]) + rval = [[True], [True], *([False] for _ in node.inputs[2:])] return rval @@ -2765,10 +2759,7 @@ def perform(self, node, inputs, out_): out[0] = rval def connection_pattern(self, node): - rval = [[True]] - - for ipt in node.inputs[1:]: - rval.append([False]) + rval = [[True], *([False] for _ in node.inputs[1:])] return rval @@ -2912,10 +2903,7 @@ def infer_shape(self, fgraph, node, ishapes): return [ishapes[0]] def connection_pattern(self, node): - rval = [[True], [True]] - - for ipt in node.inputs[2:]: - rval.append([False]) + rval = [[True], [True], *([False] for _ in node.inputs[2:])] return rval diff --git a/pytensor/typed_list/basic.py b/pytensor/typed_list/basic.py index 476841f619..d62eb86739 100644 --- a/pytensor/typed_list/basic.py +++ b/pytensor/typed_list/basic.py @@ -238,8 +238,7 @@ def perform(self, node, inputs, outputs): # need to copy toAppend due to destroy_handler limitation if toAppend: o = out[0] - for i in toAppend: - o.append(_lessbroken_deepcopy(i)) + o.extend(_lessbroken_deepcopy(i) for i in toAppend) def __str__(self): return self.__class__.__name__ diff --git a/pytensor/updates.py b/pytensor/updates.py index a60c8cd91f..fa1320200c 100644 --- a/pytensor/updates.py +++ b/pytensor/updates.py @@ -1,8 +1,6 @@ """Defines Updates object for storing a (SharedVariable, new_value) mapping.""" import logging -import warnings -from collections import OrderedDict from pytensor.compile.sharedvalue import SharedVariable @@ -12,9 +10,9 @@ logger = logging.getLogger("pytensor.updates") -# Must be an OrderedDict or updates will be applied in a non-deterministic -# order. -class OrderedUpdates(OrderedDict): +# Relies on the fact that dict is ordered, otherwise updates will be applied +# in a non-deterministic order. +class OrderedUpdates(dict): """ Dict-like mapping from SharedVariable keys to their new values. @@ -22,20 +20,6 @@ class OrderedUpdates(OrderedDict): """ def __init__(self, *key, **kwargs): - if ( - len(key) >= 1 - and isinstance(key[0], dict) - and len(key[0]) > 1 - and not isinstance(key[0], OrderedDict) - ): - # Warn when using as input a non-ordered dictionary. - warnings.warn( - "Initializing an `OrderedUpdates` from a " - "non-ordered dictionary with 2+ elements could " - "make your code non-deterministic. You can use " - "an OrderedDict that is available at " - "collections.OrderedDict for python 2.6+." - ) super().__init__(*key, **kwargs) for key in self: if not isinstance(key, SharedVariable): @@ -56,19 +40,7 @@ def __setitem__(self, key, value): def update(self, other=None): if other is None: return - if ( - isinstance(other, dict) - and len(other) > 1 - and not isinstance(other, OrderedDict) - ): - # Warn about non-determinism. - warnings.warn( - "Updating an `OrderedUpdates` with a " - "non-ordered dictionary with 2+ elements could " - "make your code non-deterministic", - stacklevel=2, - ) - for key, val in OrderedDict(other).items(): + for key, val in dict(other).items(): if key in self: if self[key] == val: continue diff --git a/pytensor/utils.py b/pytensor/utils.py index ee937f1932..9fa9a4aff8 100644 --- a/pytensor/utils.py +++ b/pytensor/utils.py @@ -6,15 +6,12 @@ import struct import subprocess import sys -from collections import OrderedDict -from collections.abc import Callable from functools import partial __all__ = [ "get_unbound_function", "maybe_add_to_os_environ_pathlist", - "DefaultOrderedDict", "subprocess_Popen", "call_subprocess_Popen", "output_subprocess_Popen", @@ -376,36 +373,3 @@ def __eq__(self, other): def __hash__(self): return hash(type(self)) - - -class DefaultOrderedDict(OrderedDict): - def __init__(self, default_factory=None, *a, **kw): - if default_factory is not None and not isinstance(default_factory, Callable): - raise TypeError("first argument must be callable") - OrderedDict.__init__(self, *a, **kw) - self.default_factory = default_factory - - def __getitem__(self, key): - try: - return OrderedDict.__getitem__(self, key) - except KeyError: - return self.__missing__(key) - - def __missing__(self, key): - if self.default_factory is None: - raise KeyError(key) - self[key] = value = self.default_factory() - return value - - def __reduce__(self): - if self.default_factory is None: - args = tuple() - else: - args = (self.default_factory,) - return type(self), args, None, None, iter(self.items()) - - def copy(self): - return self.__copy__() - - def __copy__(self): - return type(self)(self.default_factory, self) diff --git a/tests/link/test_vm.py b/tests/link/test_vm.py index 8331146322..c8ac274494 100644 --- a/tests/link/test_vm.py +++ b/tests/link/test_vm.py @@ -1,4 +1,5 @@ import time +from collections import Counter import numpy as np import pytest @@ -34,11 +35,10 @@ class TestCallbacks: # Test the `VMLinker`'s callback argument, which can be useful for debugging. def setup_method(self): - self.n_callbacks = {} + self.n_callbacks = Counter() def callback(self, node, thunk, storage_map, compute_map): key = node.op.__class__.__name__ - self.n_callbacks.setdefault(key, 0) self.n_callbacks[key] += 1 def test_callback(self): @@ -50,9 +50,9 @@ def test_callback(self): ) f(1, 2, 3) - assert sum(self.n_callbacks.values()) == len(f.maker.fgraph.toposort()) + assert self.n_callbacks.total() == len(f.maker.fgraph.toposort()) f(1, 2, 3) - assert sum(self.n_callbacks.values()) == len(f.maker.fgraph.toposort()) * 2 + assert self.n_callbacks.total() == len(f.maker.fgraph.toposort()) * 2 def test_callback_with_ifelse(self): a, b, c = scalars("abc") diff --git a/tests/scan/test_basic.py b/tests/scan/test_basic.py index 343f539274..bda8dec782 100644 --- a/tests/scan/test_basic.py +++ b/tests/scan/test_basic.py @@ -13,7 +13,6 @@ import pickle import shutil import sys -from collections import OrderedDict from tempfile import mkdtemp import numpy as np @@ -764,11 +763,9 @@ def test_output_padding(self): b = shared(np.random.default_rng(utt.fetch_seed()).random((5, 4))) def inner_func(a): - return a + 1, OrderedDict([(b, 2 * b)]) + return a + 1, {b: 2 * b} - out, updates = scan( - inner_func, outputs_info=[OrderedDict([("initial", init_a)])], n_steps=1 - ) + out, updates = scan(inner_func, outputs_info=[{"initial": init_a}], n_steps=1) out = out[-1] assert out.type.ndim == a.type.ndim assert updates[b].type.ndim == b.type.ndim @@ -934,7 +931,7 @@ def test_only_shared_no_input_no_output(self): state = shared(v_state, "vstate") def f_2(): - return OrderedDict([(state, 2 * state)]) + return {state: 2 * state} n_steps = iscalar("nstep") output, updates = scan( @@ -968,7 +965,7 @@ def test_shared_updates(self): X = shared(np.array(1)) out, updates = scan( - lambda: OrderedDict([(X, (X + 1))]), + lambda: {X: (X + 1)}, outputs_info=[], non_sequences=[], sequences=[], @@ -984,7 +981,7 @@ def test_shared_memory_aliasing_updates(self): y = shared(np.array(1)) out, updates = scan( - lambda: OrderedDict([(x, x + 1), (y, x)]), + lambda: {x: x + 1, y: x}, outputs_info=[], non_sequences=[], sequences=[], @@ -1914,7 +1911,7 @@ def test_grad_numeric_shared(self): shared_var = shared(np.float32(1.0)) def inner_fn(): - return [], OrderedDict([(shared_var, shared_var + np.float32(1.0))]) + return [], {shared_var: shared_var + np.float32(1.0)} _, updates = scan( inner_fn, n_steps=10, truncate_gradient=-1, go_backwards=False @@ -2746,7 +2743,7 @@ def one_step(x_t, h_tm1, W): v1 = shared(np.ones(5, dtype=config.floatX)) v2 = shared(np.ones((5, 5), dtype=config.floatX)) - shapef = function([W], expr, givens=OrderedDict([(initial, v1), (inpt, v2)])) + shapef = function([W], expr, givens={initial: v1, inpt: v2}) # First execution to cache n_steps shapef(np.ones((5, 5), dtype=config.floatX)) @@ -2755,7 +2752,7 @@ def one_step(x_t, h_tm1, W): f = function( [W, inpt], d_cost_wrt_W, - givens=OrderedDict([(initial, shared(np.zeros(5)))]), + givens={initial: shared(np.zeros(5))}, ) rval = np.asarray([[5187989] * 5] * 5, dtype=config.floatX) @@ -2956,7 +2953,7 @@ def onestep(x, x_tm4): seq = matrix() initial_value = shared(np.zeros((4, 1), dtype=config.floatX)) - outputs_info = [OrderedDict([("initial", initial_value), ("taps", [-4])]), None] + outputs_info = [{"initial": initial_value, "taps": [-4]}, None] results, updates = scan(fn=onestep, sequences=seq, outputs_info=outputs_info) f = function([seq], results[1]) @@ -2979,10 +2976,10 @@ def onestep(x, x_tm4): seq = matrix() initial_value = shared(np.zeros((4, 1), dtype=config.floatX)) - outputs_info = [OrderedDict([("initial", initial_value), ("taps", [-4])]), None] + outputs_info = [{"initial": initial_value, "taps": [-4]}, None] results, _ = scan(fn=onestep, sequences=seq, outputs_info=outputs_info) sharedvar = shared(np.zeros((1, 1), dtype=config.floatX)) - updates = OrderedDict([(sharedvar, results[0][-1:])]) + updates = {sharedvar: results[0][-1:]} f = function([seq], results[1], updates=updates) diff --git a/tests/test_gradient.py b/tests/test_gradient.py index f9f1e8fe4b..a92939369f 100644 --- a/tests/test_gradient.py +++ b/tests/test_gradient.py @@ -1,5 +1,3 @@ -from collections import OrderedDict - import numpy as np import pytest @@ -637,7 +635,7 @@ def test_known_grads(): for layer in layers: first = grad(cost, layer, disconnected_inputs="ignore") - known = OrderedDict(zip(layer, first)) + known = dict(zip(layer, first)) full = grad( cost=None, known_grads=known, wrt=inputs, disconnected_inputs="ignore" ) @@ -755,7 +753,7 @@ def test_subgraph_grad(): param_grad, next_grad = subgraph_grad( wrt=params[i], end=grad_ends[i], start=next_grad, cost=costs[i] ) - next_grad = OrderedDict(zip(grad_ends[i], next_grad)) + next_grad = dict(zip(grad_ends[i], next_grad)) param_grads.extend(param_grad) pgrads = pytensor.function(inputs, param_grads)