Skip to content

Code cleanups #842

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 23 commits into from
Jul 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
4968e44
Fix bug in local_reshape_chain
Armavica Jun 21, 2024
bf29020
Rewrite for/append as list comprehensions
Armavica Jun 21, 2024
bfdc134
Simplify boolean operations with any and all
Armavica Jun 21, 2024
ec48ac6
Type function.copy method parameters
Armavica Jun 21, 2024
63f9337
Refactor check_duplicate_key and remove mod.cu
Armavica Jun 21, 2024
aee49e7
Use a Counter in compiledir
Armavica Jun 21, 2024
a37460a
Replace defaultdict(int) with Counter in pkl_utils
Armavica Jun 21, 2024
849b6a6
Replace defaultdict(int) with Counter in rewriting/elemwise
Armavica Jun 21, 2024
e5c7913
Use defaultdict in graph/rewriting/basic.py
Armavica Jun 21, 2024
7785ea6
Use defaultdict and Counter in profiling.py
Armavica Jun 21, 2024
184cf18
Use a Counter in tests/link/test_vm
Armavica Jun 21, 2024
84d2403
Use a defaultdict in graph/fg
Armavica Jun 21, 2024
d5b39d9
Remove DefaultOrderedDict (dicts are ordered now)
Armavica Jun 21, 2024
cd01e6d
Remove OrderedDict in graph/destroyhandler
Armavica Jun 21, 2024
0b1a19c
Remove OrderedDict from graph/features
Armavica Jun 21, 2024
efdb973
Remove OrderedDict from graph/fg
Armavica Jun 21, 2024
909680f
Remove OrderedDict from scan/op
Armavica Jun 21, 2024
8cb8734
Remove OrderedDict from scan/utils
Armavica Jun 21, 2024
8c0d4d0
Remove OrderedDict from updates.py
Armavica Jun 21, 2024
74e2eed
Remove OrderedDict from tests/scan/test_basic
Armavica Jun 21, 2024
484cca2
Remove OrderedDict from tests/test_gradient
Armavica Jun 21, 2024
452a15b
Remove unused FrozenOrderedDict
Armavica Jun 21, 2024
7a7daac
Rewrite OrderedSet
Armavica Jun 21, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 5 additions & 8 deletions pytensor/compile/compiledir.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import os
import pickle
import shutil
from collections import Counter

Check warning on line 10 in pytensor/compile/compiledir.py

View check run for this annotation

Codecov / codecov/patch

pytensor/compile/compiledir.py#L10

Added line #L10 was not covered by tests

import numpy as np

Expand Down Expand Up @@ -111,11 +112,11 @@
compiledir = config.compiledir
table = []
table_multiple_ops = []
table_op_class = {}
table_op_class = Counter()

Check warning on line 115 in pytensor/compile/compiledir.py

View check run for this annotation

Codecov / codecov/patch

pytensor/compile/compiledir.py#L115

Added line #L115 was not covered by tests
zeros_op = 0
big_key_files = []
total_key_sizes = 0
nb_keys = {}
nb_keys = Counter()

Check warning on line 119 in pytensor/compile/compiledir.py

View check run for this annotation

Codecov / codecov/patch

pytensor/compile/compiledir.py#L119

Added line #L119 was not covered by tests
for dir in os.listdir(compiledir):
filename = os.path.join(compiledir, dir, "key.pkl")
if not os.path.exists(filename):
Expand All @@ -125,9 +126,7 @@
keydata = pickle.load(file)
ops = list({x for x in flatten(keydata.keys) if isinstance(x, Op)})
# Whatever the case, we count compilations for OP classes.
for op_class in {op.__class__ for op in ops}:
table_op_class.setdefault(op_class, 0)
table_op_class[op_class] += 1
table_op_class.update({op.__class__ for op in ops})
if len(ops) == 0:
zeros_op += 1
else:
Expand Down Expand Up @@ -159,7 +158,6 @@
if size > max_key_file_size:
big_key_files.append((dir, size, ops))

nb_keys.setdefault(len(keydata.keys), 0)
nb_keys[len(keydata.keys)] += 1
except OSError:
pass
Expand Down Expand Up @@ -198,8 +196,7 @@
),
underline="+",
)
table_op_class = sorted(table_op_class.items(), key=lambda t: t[1])
for op_class, nb in table_op_class:
for op_class, nb in reversed(table_op_class.most_common()):
print(op_class, nb)

if big_key_files:
Expand Down
9 changes: 4 additions & 5 deletions pytensor/compile/debugmode.py
Original file line number Diff line number Diff line change
Expand Up @@ -906,11 +906,10 @@ def _get_preallocated_maps(
name = f"strided{tuple(steps)}"
for r in considered_outputs:
if r in init_strided:
strides = []
shapes = []
for i, size in enumerate(r_vals[r].shape):
shapes.append(slice(None, size, None))
strides.append(slice(None, None, steps[i]))
shapes = [slice(None, size, None) for size in r_vals[r].shape]
strides = [
slice(None, None, steps[i]) for i in range(r_vals[r].ndim)
]

r_buf = init_strided[r]

Expand Down
14 changes: 3 additions & 11 deletions pytensor/compile/function/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,18 +247,10 @@ def opt_log1p(node):

"""
if isinstance(outputs, dict):
output_items = list(outputs.items())
assert all(isinstance(k, str) for k in outputs)

for item_pair in output_items:
assert isinstance(item_pair[0], str)

output_items_sorted = sorted(output_items)

output_keys = []
outputs = []
for pair in output_items_sorted:
output_keys.append(pair[0])
outputs.append(pair[1])
output_keys = sorted(outputs)
outputs = [outputs[key] for key in output_keys]

else:
output_keys = None
Expand Down
29 changes: 10 additions & 19 deletions pytensor/compile/function/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import pytensor.compile.profiling
from pytensor.compile.io import In, SymbolicInput, SymbolicOutput
from pytensor.compile.ops import deep_copy_op, view_op
from pytensor.compile.profiling import ProfileStats
from pytensor.configdefaults import config
from pytensor.graph.basic import (
Constant,
Expand Down Expand Up @@ -212,18 +213,14 @@

found_updates.extend(map(SymbolicOutput, updates))
elif fgraph is None:
input_vars = []

# If one of the inputs is non-atomic (i.e. has a non-`None` `Variable.owner`),
# then we need to create/clone the graph starting at these inputs.
# The result will be atomic versions of the given inputs connected to
# the same outputs.
# Otherwise, when all the inputs are already atomic, there's no need to
# clone the graph.
clone = force_clone
for spec in input_specs:
input_vars.append(spec.variable)
clone |= spec.variable.owner is not None
input_vars = [spec.variable for spec in input_specs]
clone = force_clone or any(var.owner is not None for var in input_vars)

fgraph = FunctionGraph(
input_vars,
Expand Down Expand Up @@ -557,11 +554,11 @@

def copy(
self,
share_memory=False,
swap=None,
delete_updates=False,
name=None,
profile=None,
share_memory: bool = False,
swap: dict | None = None,
delete_updates: bool = False,
name: str | None = None,
profile: bool | str | ProfileStats | None = None,
):
"""
Copy this function. Copied function will have separated maker and
Expand All @@ -588,7 +585,7 @@
If provided, will be the name of the new
Function. Otherwise, it will be old + " copy"

profile :
profile : bool | str | ProfileStats | None
as pytensor.function profile parameter

Returns
Expand Down Expand Up @@ -727,14 +724,8 @@
# reinitialize new maker and create new function
if profile is None:
profile = config.profile or config.print_global_stats
# profile -> True or False
if profile is True:
if name:
message = name
else:
message = str(profile.message) + " copy"
profile = pytensor.compile.profiling.ProfileStats(message=message)
# profile -> object
profile = pytensor.compile.profiling.ProfileStats(message=name)

Check warning on line 728 in pytensor/compile/function/types.py

View check run for this annotation

Codecov / codecov/patch

pytensor/compile/function/types.py#L728

Added line #L728 was not covered by tests
elif isinstance(profile, str):
profile = pytensor.compile.profiling.ProfileStats(message=profile)

Expand Down
77 changes: 30 additions & 47 deletions pytensor/compile/profiling.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@
import operator
import sys
import time
from collections import defaultdict
from collections import Counter, defaultdict
from contextlib import contextmanager
from typing import TYPE_CHECKING, Any, Union
from typing import TYPE_CHECKING, Any

import numpy as np

Expand Down Expand Up @@ -204,8 +204,8 @@
self.fct_call_time = 0.0
self.fct_callcount = 0
self.vm_call_time = 0.0
self.apply_time = {}
self.apply_callcount = {}
self.apply_time = defaultdict(float)
self.apply_callcount = Counter()

Check warning on line 208 in pytensor/compile/profiling.py

View check run for this annotation

Codecov / codecov/patch

pytensor/compile/profiling.py#L207-L208

Added lines #L207 - L208 were not covered by tests
# self.apply_cimpl = None
# self.message = None

Expand Down Expand Up @@ -234,9 +234,9 @@
# Total time spent in Function.vm.__call__
#

apply_time: dict[Union["FunctionGraph", Variable], float] | None = None
apply_time: dict[tuple["FunctionGraph", Apply], float]

apply_callcount: dict[Union["FunctionGraph", Variable], int] | None = None
apply_callcount: dict[tuple["FunctionGraph", Apply], int]

apply_cimpl: dict[Apply, bool] | None = None
# dict from node -> bool (1 if c, 0 if py)
Expand Down Expand Up @@ -292,10 +292,9 @@
# param is called flag_time_thunks because most other attributes with time
# in the name are times *of* something, rather than configuration flags.
def __init__(self, atexit_print=True, flag_time_thunks=None, **kwargs):
self.apply_callcount = {}
self.apply_callcount = Counter()
self.output_size = {}
# Keys are `(FunctionGraph, Variable)`
self.apply_time = {}
self.apply_time = defaultdict(float)
self.apply_cimpl = {}
self.variable_shape = {}
self.variable_strides = {}
Expand All @@ -320,37 +319,29 @@

"""
# timing is stored by node, we compute timing by class on demand
rval = {}
for (fgraph, node), t in self.apply_time.items():
typ = type(node.op)
rval.setdefault(typ, 0)
rval[typ] += t
return rval
rval = defaultdict(float)
for (_fgraph, node), t in self.apply_time.items():
rval[type(node.op)] += t
return dict(rval)

def class_callcount(self):
"""
dict op -> total number of thunk calls

"""
# timing is stored by node, we compute timing by class on demand
rval = {}
for (fgraph, node), count in self.apply_callcount.items():
typ = type(node.op)
rval.setdefault(typ, 0)
rval[typ] += count
rval = Counter()
for (_fgraph, node), count in self.apply_callcount.items():
rval[type(node.op)] += count
return rval

def class_nodes(self):
def class_nodes(self) -> Counter:
"""
dict op -> total number of nodes

"""
# timing is stored by node, we compute timing by class on demand
rval = {}
for (fgraph, node), count in self.apply_callcount.items():
typ = type(node.op)
rval.setdefault(typ, 0)
rval[typ] += 1
rval = Counter(type(node.op) for _fgraph, node in self.apply_callcount)
return rval

def class_impl(self):
Expand All @@ -360,12 +351,9 @@
"""
# timing is stored by node, we compute timing by class on demand
rval = {}
for fgraph, node in self.apply_callcount:
for _fgraph, node in self.apply_callcount:
typ = type(node.op)
if self.apply_cimpl[node]:
impl = "C "
else:
impl = "Py"
impl = "C " if self.apply_cimpl[node] else "Py"
rval.setdefault(typ, impl)
if rval[typ] != impl and len(rval[typ]) == 2:
rval[typ] += impl
Expand All @@ -377,11 +365,10 @@

"""
# timing is stored by node, we compute timing by Op on demand
rval = {}
rval = defaultdict(float)
for (fgraph, node), t in self.apply_time.items():
rval.setdefault(node.op, 0)
rval[node.op] += t
return rval
return dict(rval)

def fill_node_total_time(self, fgraph, node, total_times):
"""
Expand Down Expand Up @@ -414,9 +401,8 @@

"""
# timing is stored by node, we compute timing by Op on demand
rval = {}
rval = Counter()
for (fgraph, node), count in self.apply_callcount.items():
rval.setdefault(node.op, 0)
rval[node.op] += count
return rval

Expand All @@ -426,10 +412,7 @@

"""
# timing is stored by node, we compute timing by Op on demand
rval = {}
for (fgraph, node), count in self.apply_callcount.items():
rval.setdefault(node.op, 0)
rval[node.op] += 1
rval = Counter(node.op for _fgraph, node in self.apply_callcount)
return rval

def op_impl(self):
Expand Down Expand Up @@ -1204,8 +1187,7 @@
compute_map[var][0] = 0

for k_remove, v_remove in viewedby_remove.items():
for i in v_remove:
viewed_by[k_remove].append(i)
viewed_by[k_remove].extend(v_remove)

for k_add, v_add in viewedby_add.items():
for i in v_add:
Expand All @@ -1215,15 +1197,16 @@
del view_of[k]

# two data structure used to mimic Python gc
viewed_by = {} # {var1: [vars that view var1]}
# * {var1: [vars that view var1]}
# The len of the list is the value of python ref
# count. But we use a list, not just the ref count value.
# This is more safe to help detect potential bug in the algo
for var in fgraph.variables:
viewed_by[var] = []
view_of = {} # {var1: original var viewed by var1}
# This is more safe to help detect potential bug in the algo
viewed_by = {var: [] for var in fgraph.variables}

# * {var1: original var viewed by var1}
# The original mean that we don't keep track of all the intermediate
# relationship in the view.
view_of = {}

min_memory_generator(executable_nodes, viewed_by, view_of)

Expand Down
14 changes: 6 additions & 8 deletions pytensor/gradient.py
Original file line number Diff line number Diff line change
Expand Up @@ -1041,13 +1041,12 @@ def access_term_cache(node):
# list of bools indicating if each input is connected to the cost
inputs_connected = [
(
True
in [
any(
input_to_output and output_to_cost
for input_to_output, output_to_cost in zip(
input_to_outputs, outputs_connected
)
]
)
)
for input_to_outputs in connection_pattern
]
Expand All @@ -1067,25 +1066,24 @@ def access_term_cache(node):
# List of bools indicating if each input only has NullType outputs
only_connected_to_nan = [
(
True
not in [
not any(
in_to_out and out_to_cost and not out_nan
for in_to_out, out_to_cost, out_nan in zip(
in_to_outs, outputs_connected, ograd_is_nan
)
]
)
)
for in_to_outs in connection_pattern
]

if True not in inputs_connected:
if not any(inputs_connected):
# All outputs of this op are disconnected so we can skip
# Calling the op's grad method and report that the inputs
# are disconnected
# (The op's grad method could do this too, but this saves the
# implementer the trouble of worrying about this case)
input_grads = [disconnected_type() for ipt in inputs]
elif False not in only_connected_to_nan:
elif all(only_connected_to_nan):
# All inputs are only connected to nan gradients, so we don't
# need to bother calling the grad method. We know the gradient
# with respect to all connected inputs is nan.
Expand Down
3 changes: 1 addition & 2 deletions pytensor/graph/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1474,9 +1474,8 @@ def _compute_deps_cache_(io):

_clients: dict[T, list[T]] = {}
sources: deque[T] = deque()
search_res_len: int = 0
search_res_len = len(search_res)
for snode, children in search_res:
search_res_len += 1
if children:
for child in children:
_clients.setdefault(child, []).append(snode)
Expand Down
Loading
Loading