
Commit 19dafe4

Armavica authored and ricardoV94 committed
Add a strict argument to all zips
1 parent 6de3151 commit 19dafe4


106 files changed: +769 -481 lines changed
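Background: since Python 3.10, zip accepts a strict keyword argument. With strict=True, zip raises a ValueError as soon as one of its iterables is exhausted before the others, instead of silently truncating the result to the shortest input. Silent truncation can hide real bugs, such as an input list and its shape list drifting out of sync, which is what this commit guards against. A minimal sketch of the behavior, using illustrative variable names that are not taken from the codebase:

    inputs = ["a", "b", "c"]
    input_shapes = [(2,), (3,)]  # one entry short

    # Default zip silently drops the unmatched element "a"/"c" pairing
    list(zip(inputs, input_shapes))  # -> [('a', (2,)), ('b', (3,))]

    # strict=True turns the silent mismatch into an immediate error
    try:
        list(zip(inputs, input_shapes, strict=True))
    except ValueError as err:
        print(err)  # zip() argument 2 is shorter than argument 1

The changes below mechanically thread strict=True (or, where truncation is intentional, an explicit strict=False) through every zip call.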

pytensor/compile/builders.py (+21 -11)

@@ -43,15 +43,15 @@ def infer_shape(outs, inputs, input_shapes):
     # TODO: ShapeFeature should live elsewhere
     from pytensor.tensor.rewriting.shape import ShapeFeature
 
-    for inp, inp_shp in zip(inputs, input_shapes):
+    for inp, inp_shp in zip(inputs, input_shapes, strict=True):
         if inp_shp is not None and len(inp_shp) != inp.type.ndim:
             assert len(inp_shp) == inp.type.ndim
 
     shape_feature = ShapeFeature()
     shape_feature.on_attach(FunctionGraph([], []))
 
     # Initialize shape_of with the input shapes
-    for inp, inp_shp in zip(inputs, input_shapes):
+    for inp, inp_shp in zip(inputs, input_shapes, strict=True):
         shape_feature.set_shape(inp, inp_shp)
 
     def local_traverse(out):
@@ -108,7 +108,9 @@ def construct_nominal_fgraph(
     replacements = dict(
         zip(
-            inputs + implicit_shared_inputs, dummy_inputs + dummy_implicit_shared_inputs
+            inputs + implicit_shared_inputs,
+            dummy_inputs + dummy_implicit_shared_inputs,
+            strict=True,
         )
     )
@@ -138,7 +140,7 @@ def construct_nominal_fgraph(
         NominalVariable(n, var.type) for n, var in enumerate(local_inputs)
     )
 
-    fgraph.replace_all(zip(local_inputs, nominal_local_inputs))
+    fgraph.replace_all(zip(local_inputs, nominal_local_inputs, strict=True))
 
     for i, inp in enumerate(fgraph.inputs):
         nom_inp = nominal_local_inputs[i]
@@ -562,7 +564,9 @@ def lop_overrides(inps, grads):
         # compute non-overriding downsteam grads from upstreams grads
         # it's normal some input may be disconnected, thus the 'ignore'
         wrt = [
-            lin for lin, gov in zip(inner_inputs, custom_input_grads) if gov is None
+            lin
+            for lin, gov in zip(inner_inputs, custom_input_grads, strict=True)
+            if gov is None
         ]
         default_input_grads = fn_grad(wrt=wrt) if wrt else []
         input_grads = self._combine_list_overrides(
@@ -653,7 +657,7 @@ def _build_and_cache_rop_op(self):
         f = [
             output
             for output, custom_output_grad in zip(
-                inner_outputs, custom_output_grads
+                inner_outputs, custom_output_grads, strict=True
             )
             if custom_output_grad is None
         ]
@@ -733,18 +737,24 @@ def make_node(self, *inputs):
         non_shared_inputs = [
             inp_t.filter_variable(inp)
-            for inp, inp_t in zip(non_shared_inputs, self.input_types)
+            for inp, inp_t in zip(non_shared_inputs, self.input_types, strict=True)
         ]
 
         new_shared_inputs = inputs[num_expected_inps:]
-        inner_and_input_shareds = list(zip(self.shared_inputs, new_shared_inputs))
+        inner_and_input_shareds = list(
+            zip(self.shared_inputs, new_shared_inputs, strict=True)
+        )
 
         if not all(inp_s == inn_s for inn_s, inp_s in inner_and_input_shareds):
             # The shared variables are not equal to the original shared
             # variables, so we construct a new `Op` that uses the new shared
             # variables instead.
             replace = dict(
-                zip(self.inner_inputs[num_expected_inps:], new_shared_inputs)
+                zip(
+                    self.inner_inputs[num_expected_inps:],
+                    new_shared_inputs,
+                    strict=True,
+                )
             )
 
             # If the new shared variables are inconsistent with the inner-graph,
@@ -811,7 +821,7 @@ def infer_shape(self, fgraph, node, shapes):
         # each shape call. PyTensor optimizer will clean this up later, but this
         # will make extra work for the optimizer.
 
-        repl = dict(zip(self.inner_inputs, node.inputs))
+        repl = dict(zip(self.inner_inputs, node.inputs, strict=True))
         clone_out_shapes = [s for s in out_shapes if isinstance(s, tuple)]
         cloned = clone_replace(sum(clone_out_shapes, ()), replace=repl)
         ret = []
@@ -853,5 +863,5 @@ def clone(self):
     def perform(self, node, inputs, outputs):
         variables = self.fn(*inputs)
         assert len(variables) == len(outputs)
-        for output, variable in zip(outputs, variables):
+        for output, variable in zip(outputs, variables, strict=True):
             output[0] = variable

pytensor/compile/debugmode.py (+9 -7)

@@ -865,7 +865,7 @@ def _get_preallocated_maps(
             # except if broadcastable, or for dimensions above
             # config.DebugMode__check_preallocated_output_ndim
             buf_shape = []
-            for s, b in zip(r_vals[r].shape, r.broadcastable):
+            for s, b in zip(r_vals[r].shape, r.broadcastable, strict=True):
                 if b or ((r.ndim - len(buf_shape)) > check_ndim):
                     buf_shape.append(s)
                 else:
@@ -943,7 +943,7 @@ def _get_preallocated_maps(
                 r_shape_diff = shape_diff[: r.ndim]
                 new_buf_shape = [
                     max((s + sd), 0)
-                    for s, sd in zip(r_vals[r].shape, r_shape_diff)
+                    for s, sd in zip(r_vals[r].shape, r_shape_diff, strict=True)
                 ]
                 new_buf = np.empty(new_buf_shape, dtype=r.type.dtype)
                 new_buf[...] = np.asarray(def_val).astype(r.type.dtype)
@@ -1575,7 +1575,7 @@ def f():
             # try:
             # compute the value of all variables
             for i, (thunk_py, thunk_c, node) in enumerate(
-                zip(thunks_py, thunks_c, order)
+                zip(thunks_py, thunks_c, order, strict=True)
             ):
                 _logger.debug(f"{i} - starting node {i} {node}")
@@ -1855,7 +1855,7 @@ def thunk():
                 assert s[0] is None
 
                 # store our output variables to their respective storage lists
-                for output, storage in zip(fgraph.outputs, output_storage):
+                for output, storage in zip(fgraph.outputs, output_storage, strict=True):
                     storage[0] = r_vals[output]
 
                 # transfer all inputs back to their respective storage lists
@@ -1931,11 +1931,11 @@ def deco():
             f,
             [
                 Container(input, storage, readonly=False)
-                for input, storage in zip(fgraph.inputs, input_storage)
+                for input, storage in zip(fgraph.inputs, input_storage, strict=True)
             ],
             [
                 Container(output, storage, readonly=True)
-                for output, storage in zip(fgraph.outputs, output_storage)
+                for output, storage in zip(fgraph.outputs, output_storage, strict=True)
             ],
             thunks_py,
             order,
@@ -2122,7 +2122,9 @@ def __init__(
         no_borrow = [
             output
-            for output, spec in zip(fgraph.outputs, outputs + additional_outputs)
+            for output, spec in zip(
+                fgraph.outputs, outputs + additional_outputs, strict=True
+            )
             if not spec.borrow
         ]
         if no_borrow:

pytensor/compile/function/pfunc.py (+3 -3)

@@ -603,7 +603,7 @@ def construct_pfunc_ins_and_outs(
     new_inputs = []
 
-    for i, iv in zip(inputs, input_variables):
+    for i, iv in zip(inputs, input_variables, strict=True):
         new_i = copy(i)
         new_i.variable = iv
 
@@ -637,13 +637,13 @@ def construct_pfunc_ins_and_outs(
         assert len(fgraph.inputs) == len(inputs)
         assert len(fgraph.outputs) == len(outputs)
 
-        for fg_inp, inp in zip(fgraph.inputs, inputs):
+        for fg_inp, inp in zip(fgraph.inputs, inputs, strict=True):
             if fg_inp != getattr(inp, "variable", inp):
                 raise ValueError(
                     f"`fgraph`'s input does not match the provided input: {fg_inp}, {inp}"
                 )
 
-        for fg_out, out in zip(fgraph.outputs, outputs):
+        for fg_out, out in zip(fgraph.outputs, outputs, strict=True):
             if fg_out != getattr(out, "variable", out):
                 raise ValueError(
                     f"`fgraph`'s output does not match the provided output: {fg_out}, {out}"
                 )

pytensor/compile/function/types.py (+20 -14)

@@ -241,7 +241,7 @@ def std_fgraph(
     fgraph.attach_feature(
         Supervisor(
             input
-            for spec, input in zip(input_specs, fgraph.inputs)
+            for spec, input in zip(input_specs, fgraph.inputs, strict=True)
             if not (
                 spec.mutable
                 or (hasattr(fgraph, "destroyers") and fgraph.has_destroyers([input]))
@@ -442,7 +442,7 @@ def __init__(
         # this loop works by modifying the elements (as variable c) of
         # self.input_storage inplace.
         for i, ((input, indices, sinputs), (required, refeed, value)) in enumerate(
-            zip(self.indices, defaults)
+            zip(self.indices, defaults, strict=True)
         ):
             if indices is None:
                 # containers is being used as a stack. Here we pop off
@@ -671,7 +671,7 @@ def checkSV(sv_ori, sv_rpl):
         else:
             outs = list(map(SymbolicOutput, fg_cpy.outputs))
 
-        for out_ori, out_cpy in zip(maker.outputs, outs):
+        for out_ori, out_cpy in zip(maker.outputs, outs, strict=False):
             out_cpy.borrow = out_ori.borrow
 
         # swap SharedVariable
@@ -684,7 +684,7 @@ def checkSV(sv_ori, sv_rpl):
             raise ValueError(f"SharedVariable: {sv.name} not found")
 
         # Swap SharedVariable in fgraph and In instances
-        for index, (i, in_v) in enumerate(zip(ins, fg_cpy.inputs)):
+        for index, (i, in_v) in enumerate(zip(ins, fg_cpy.inputs, strict=True)):
             # Variables in maker.inputs are defined by user, therefore we
             # use them to make comparison and do the mapping.
             # Otherwise we don't touch them.
@@ -708,7 +708,7 @@ def checkSV(sv_ori, sv_rpl):
 
         # Delete update if needed
         rev_update_mapping = {v: k for k, v in fg_cpy.update_mapping.items()}
-        for n, (inp, in_var) in enumerate(zip(ins, fg_cpy.inputs)):
+        for n, (inp, in_var) in enumerate(zip(ins, fg_cpy.inputs, strict=True)):
             inp.variable = in_var
             if not delete_updates and inp.update is not None:
                 out_idx = rev_update_mapping[n]
@@ -768,7 +768,11 @@ def checkSV(sv_ori, sv_rpl):
         ).create(input_storage, storage_map=new_storage_map)
 
         for in_ori, in_cpy, ori, cpy in zip(
-            maker.inputs, f_cpy.maker.inputs, self.input_storage, f_cpy.input_storage
+            maker.inputs,
+            f_cpy.maker.inputs,
+            self.input_storage,
+            f_cpy.input_storage,
+            strict=True,
         ):
             # Share immutable ShareVariable and constant input's storage
             swapped = swap is not None and in_ori.variable in swap
@@ -999,7 +1003,7 @@ def __call__(self, *args, **kwargs):
         # output reference from the internal storage cells
         if getattr(self.vm, "allow_gc", False):
             for o_container, o_variable in zip(
-                self.output_storage, self.maker.fgraph.outputs
+                self.output_storage, self.maker.fgraph.outputs, strict=True
             ):
                 if o_variable.owner is not None:
                     # this node is the variable of computation
@@ -1009,7 +1013,7 @@ def __call__(self, *args, **kwargs):
         if getattr(self.vm, "need_update_inputs", True):
             # Update the inputs that have an update function
             for input, storage in reversed(
-                list(zip(self.maker.expanded_inputs, input_storage))
+                list(zip(self.maker.expanded_inputs, input_storage, strict=True))
             ):
                 if input.update is not None:
                     storage.data = outputs.pop()
@@ -1040,7 +1044,7 @@ def __call__(self, *args, **kwargs):
             assert len(self.output_keys) == len(outputs)
 
             if output_subset is None:
-                return dict(zip(self.output_keys, outputs))
+                return dict(zip(self.output_keys, outputs, strict=True))
             else:
                 return {
                     self.output_keys[index]: outputs[index]
@@ -1108,7 +1112,7 @@ def _pickle_Function(f):
     input_storage = []
 
     for (input, indices, inputs), (required, refeed, default) in zip(
-        f.indices, f.defaults
+        f.indices, f.defaults, strict=True
     ):
         input_storage.append(ins[0])
         del ins[0]
@@ -1150,7 +1154,7 @@ def _constructor_Function(maker, input_storage, inputs_data, trust_input=False):
     f = maker.create(input_storage)
     assert len(f.input_storage) == len(inputs_data)
-    for container, x in zip(f.input_storage, inputs_data):
+    for container, x in zip(f.input_storage, inputs_data, strict=True):
         assert (
             (container.data is x)
             or (isinstance(x, np.ndarray) and (container.data == x).all())
@@ -1184,7 +1188,7 @@ def insert_deepcopy(fgraph, wrapped_inputs, wrapped_outputs):
     reason = "insert_deepcopy"
     updated_fgraph_inputs = {
         fgraph_i
-        for i, fgraph_i in zip(wrapped_inputs, fgraph.inputs)
+        for i, fgraph_i in zip(wrapped_inputs, fgraph.inputs, strict=True)
         if getattr(i, "update", False)
     }
@@ -1521,7 +1525,9 @@ def __init__(
         # return the internal storage pointer.
         no_borrow = [
             output
-            for output, spec in zip(fgraph.outputs, outputs + found_updates)
+            for output, spec in zip(
+                fgraph.outputs, outputs + found_updates, strict=True
+            )
             if not spec.borrow
         ]
@@ -1590,7 +1596,7 @@ def create(self, input_storage=None, storage_map=None):
         # defaults lists.
         assert len(self.indices) == len(input_storage)
         for i, ((input, indices, subinputs), input_storage_i) in enumerate(
-            zip(self.indices, input_storage)
+            zip(self.indices, input_storage, strict=True)
         ):
             # Replace any default value given as a variable by its
             # container. Note that this makes sense only in the
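Note that one zip in the copy path above (pairing maker.outputs with outs) is marked strict=False rather than strict=True, presumably because the two sequences can legitimately differ in length there; the flag makes the pre-existing truncating behavior explicit. For contrast, a minimal illustration of strict=False, which behaves like plain zip:

    a = [1, 2, 3]
    b = [10, 20]
    list(zip(a, b, strict=False))  # -> [(1, 10), (2, 20)]; the trailing 3 is silently dropped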

pytensor/d3viz/formatting.py (+2 -2)

@@ -244,14 +244,14 @@ def format_map(m):
         ext_inputs = [self.__node_id(x) for x in node.inputs]
         int_inputs = [gf.__node_id(x) for x in node.op.inner_inputs]
         assert len(ext_inputs) == len(int_inputs)
-        h = format_map(zip(ext_inputs, int_inputs))
+        h = format_map(zip(ext_inputs, int_inputs, strict=True))
         pd_node.get_attributes()["subg_map_inputs"] = h
 
         # Outputs mapping
         ext_outputs = [self.__node_id(x) for x in node.outputs]
         int_outputs = [gf.__node_id(x) for x in node.op.inner_outputs]
         assert len(ext_outputs) == len(int_outputs)
-        h = format_map(zip(int_outputs, ext_outputs))
+        h = format_map(zip(int_outputs, ext_outputs, strict=True))
         pd_node.get_attributes()["subg_map_outputs"] = h
 
         return graph
