Commit 5e74536

Rewrite for/append as list comprehensions
1 parent 4c714e0 commit 5e74536
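
Note: every hunk below applies the same textbook transformation — an accumulator loop rewritten as a comprehension, or an `extend` over a generator. A minimal sketch of the equivalence (the names `items` and `result` are illustrative, not taken from the diff):

```python
items = range(5)

# Before: grow a list imperatively with repeated .append() calls.
result = []
for x in items:
    if x % 2 == 0:
        result.append(x * x)

# After: the same loop as a list comprehension; clause order mirrors loop order.
assert result == [x * x for x in items if x % 2 == 0]  # [0, 4, 16]
```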

21 files changed, +113 -164 lines

pytensor/compile/debugmode.py

Lines changed: 4 additions & 5 deletions
@@ -906,11 +906,10 @@ def _get_preallocated_maps(
         name = f"strided{tuple(steps)}"
         for r in considered_outputs:
             if r in init_strided:
-                strides = []
-                shapes = []
-                for i, size in enumerate(r_vals[r].shape):
-                    shapes.append(slice(None, size, None))
-                    strides.append(slice(None, None, steps[i]))
+                shapes = [slice(None, size, None) for size in r_vals[r].shape]
+                strides = [
+                    slice(None, None, steps[i]) for i in range(r_vals[r].ndim)
+                ]

                 r_buf = init_strided[r]

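This hunk splits one `enumerate` loop into two comprehensions, relying on `len(a.shape) == a.ndim` for any ndarray so that `range(ndim)` covers the same indices. A small check of that assumption (the array and step values are illustrative):

```python
import numpy as np

a = np.empty((3, 4, 5))
steps = [1, 2, 1]  # illustrative step sizes, standing in for `steps` above

shapes = [slice(None, size, None) for size in a.shape]
strides = [slice(None, None, steps[i]) for i in range(a.ndim)]

# ndim == len(shape), so both comprehensions walk the same index range.
assert len(shapes) == len(strides) == a.ndim
```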
pytensor/compile/function/__init__.py

Lines changed: 3 additions & 11 deletions
@@ -247,18 +247,10 @@ def opt_log1p(node):

     """
     if isinstance(outputs, dict):
-        output_items = list(outputs.items())
+        assert all(isinstance(k, str) for k in outputs)

-        for item_pair in output_items:
-            assert isinstance(item_pair[0], str)
-
-        output_items_sorted = sorted(output_items)
-
-        output_keys = []
-        outputs = []
-        for pair in output_items_sorted:
-            output_keys.append(pair[0])
-            outputs.append(pair[1])
+        output_keys = sorted(outputs)
+        outputs = [outputs[key] for key in output_keys]

     else:
         output_keys = None

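The rewrite leans on the fact that iterating (and therefore sorting) a dict yields its keys, so `sorted(outputs)` plus index-back lookups reproduces the old sort-items-then-unzip dance. A sketch with an illustrative dict:

```python
outputs = {"loss": 0.5, "acc": 0.9}  # illustrative output dict

output_keys = sorted(outputs)                   # sorted keys
values = [outputs[key] for key in output_keys]  # values in key order

assert output_keys == ["acc", "loss"]
assert values == [0.9, 0.5]
```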
pytensor/compile/function/types.py

Lines changed: 2 additions & 6 deletions
@@ -212,18 +212,14 @@ def std_fgraph(

         found_updates.extend(map(SymbolicOutput, updates))
     elif fgraph is None:
-        input_vars = []
-
         # If one of the inputs is non-atomic (i.e. has a non-`None` `Variable.owner`),
         # then we need to create/clone the graph starting at these inputs.
         # The result will be atomic versions of the given inputs connected to
         # the same outputs.
         # Otherwise, when all the inputs are already atomic, there's no need to
         # clone the graph.
-        clone = force_clone
-        for spec in input_specs:
-            input_vars.append(spec.variable)
-            clone |= spec.variable.owner is not None
+        input_vars = [spec.variable for spec in input_specs]
+        clone = force_clone or any(var.owner is not None for var in input_vars)

         fgraph = FunctionGraph(
             input_vars,

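One behavioral nuance worth noting: `force_clone or any(...)` short-circuits, whereas the old `clone |= ...` evaluated the test for every spec; the resulting boolean is identical. A sketch with stand-in objects:

```python
class Var:
    def __init__(self, owner=None):
        self.owner = owner

input_vars = [Var(), Var(owner=object()), Var()]  # illustrative stand-ins
force_clone = False

# any() stops at the first var with an owner; the old loop scanned them all.
clone = force_clone or any(var.owner is not None for var in input_vars)
assert clone is True
```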
pytensor/compile/profiling.py

Lines changed: 7 additions & 7 deletions
@@ -1204,8 +1204,7 @@ def min_memory_generator(executable_nodes, viewed_by, view_of):
                     compute_map[var][0] = 0

         for k_remove, v_remove in viewedby_remove.items():
-            for i in v_remove:
-                viewed_by[k_remove].append(i)
+            viewed_by[k_remove].extend(v_remove)

         for k_add, v_add in viewedby_add.items():
             for i in v_add:
@@ -1215,15 +1214,16 @@ def min_memory_generator(executable_nodes, viewed_by, view_of):
         del view_of[k]

     # two data structure used to mimic Python gc
-    viewed_by = {}  # {var1: [vars that view var1]}
+    # * {var1: [vars that view var1]}
     # The len of the list is the value of python ref
     # count. But we use a list, not just the ref count value.
-    # This is more safe to help detect potential bug in the algo
-    for var in fgraph.variables:
-        viewed_by[var] = []
-    view_of = {}  # {var1: original var viewed by var1}
+    # This is more safe to help detect potential bug in the algo
+    viewed_by = {var: [] for var in fgraph.variables}
+
+    # * {var1: original var viewed by var1}
     # The original mean that we don't keep track of all the intermediate
     # relationship in the view.
+    view_of = {}

     min_memory_generator(executable_nodes, viewed_by, view_of)

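The dict comprehension matters here: it builds a fresh list per variable, where the tempting `dict.fromkeys(variables, [])` would alias one list across all keys and break the ref-count mimicry. An illustrative demo:

```python
variables = ["v1", "v2"]  # stand-ins for fgraph.variables

viewed_by = {var: [] for var in variables}  # one list per key
aliased = dict.fromkeys(variables, [])      # one shared list!

viewed_by["v1"].append("viewer")
aliased["v1"].append("viewer")

assert viewed_by["v2"] == []            # independent lists
assert aliased["v2"] == ["viewer"]      # shared state leaks across keys
```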
pytensor/graph/basic.py

Lines changed: 1 addition & 2 deletions
@@ -1474,9 +1474,8 @@ def _compute_deps_cache_(io):

     _clients: dict[T, list[T]] = {}
     sources: deque[T] = deque()
-    search_res_len: int = 0
+    search_res_len = len(search_res)
     for snode, children in search_res:
-        search_res_len += 1
         if children:
             for child in children:
                 _clients.setdefault(child, []).append(snode)

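Replacing the in-loop counter with `len(search_res)` assumes `search_res` is a sized container rather than a one-shot iterator; for a list the two are equivalent. A quick check with an illustrative value:

```python
search_res = [("a", []), ("b", ["a"])]  # illustrative; a list, not a generator

search_res_len = len(search_res)       # up-front, O(1)
counted = sum(1 for _ in search_res)   # what the old loop computed
assert search_res_len == counted == 2
```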
pytensor/graph/fg.py

Lines changed: 4 additions & 2 deletions
@@ -270,8 +270,10 @@ def remove_client(

         self.execute_callbacks("on_prune", apply_node, reason)

-        for i, in_var in enumerate(apply_node.inputs):
-            removal_stack.append((in_var, (apply_node, i)))
+        removal_stack.extend(
+            (in_var, (apply_node, i))
+            for i, in_var in enumerate(apply_node.inputs)
+        )

         if remove_if_empty:
             del clients[var]

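`list.extend` accepts any iterable, so feeding it a generator expression avoids materializing an intermediate list of pairs. A sketch with stand-in values:

```python
removal_stack = []
apply_node, inputs = "node", ["x", "y"]  # illustrative stand-ins

removal_stack.extend(
    (in_var, (apply_node, i)) for i, in_var in enumerate(inputs)
)
assert removal_stack == [("x", ("node", 0)), ("y", ("node", 1))]
```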
pytensor/graph/rewriting/basic.py

Lines changed: 4 additions & 4 deletions
@@ -479,9 +479,9 @@ def merge_profile(prof1, prof2):
                 new_sub_profile.append(p[6][idx])

         new_rewrite = SequentialGraphRewriter(*new_l)
-        new_nb_nodes = []
-        for p1, p2 in zip(prof1[8], prof2[8]):
-            new_nb_nodes.append((p1[0] + p2[0], p1[1] + p2[1]))
+        new_nb_nodes = [
+            (p1[0] + p2[0], p1[1] + p2[1]) for p1, p2 in zip(prof1[8], prof2[8])
+        ]
         new_nb_nodes.extend(prof1[8][len(new_nb_nodes) :])
         new_nb_nodes.extend(prof2[8][len(new_nb_nodes) :])

@@ -960,9 +960,9 @@ def register(self, rewriter: NodeRewriter, tag_list: IterableType[str]):

         tracks = rewriter.tracks()
         if tracks:
+            self._tracks.extend(tracks)
             for c in tracks:
                 self.track_dict[c].append(rewriter)
-                self._tracks.append(c)

         for tag in tag_list:
             self.tag_dict[tag].append(rewriter)

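In the `merge_profile` hunk, `zip` stops at the shorter profile, so the comprehension only sums overlapping entries; the two `extend` calls that follow copy whichever tail remains. An illustrative check:

```python
prof1_nodes = [(1, 2), (3, 4), (5, 6)]  # illustrative node-count pairs
prof2_nodes = [(10, 20)]

new_nb_nodes = [
    (p1[0] + p2[0], p1[1] + p2[1]) for p1, p2 in zip(prof1_nodes, prof2_nodes)
]
new_nb_nodes.extend(prof1_nodes[len(new_nb_nodes):])
new_nb_nodes.extend(prof2_nodes[len(new_nb_nodes):])

assert new_nb_nodes == [(11, 22), (3, 4), (5, 6)]
```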
pytensor/link/basic.py

Lines changed: 17 additions & 18 deletions
@@ -524,12 +524,13 @@ def make_thunk(self, **kwargs):
         thunk_groups = list(zip(*thunk_lists))
         order = [x[0] for x in zip(*order_lists)]

-        to_reset = []
-        for thunks, node in zip(thunk_groups, order):
-            for j, output in enumerate(node.outputs):
-                if output in no_recycling:
-                    for thunk in thunks:
-                        to_reset.append(thunk.outputs[j])
+        to_reset = [
+            thunk.outputs[j]
+            for thunks, node in zip(thunk_groups, order)
+            for j, output in enumerate(node.outputs)
+            if output in no_recycling
+            for thunk in thunks
+        ]

         wrapper = self.wrapper
         pre = self.pre
@@ -692,18 +693,16 @@ def make_all(self, input_storage=None, output_storage=None, storage_map=None):
         computed, last_user = gc_helper(nodes)

         if self.allow_gc:
-            post_thunk_old_storage = []
-
-            for node in nodes:
-                post_thunk_old_storage.append(
-                    [
-                        storage_map[input]
-                        for input in node.inputs
-                        if (input in computed)
-                        and (input not in fgraph.outputs)
-                        and (node == last_user[input])
-                    ]
-                )
+            post_thunk_old_storage = [
+                [
+                    storage_map[input]
+                    for input in node.inputs
+                    if (input in computed)
+                    and (input not in fgraph.outputs)
+                    and (node == last_user[input])
+                ]
+                for node in nodes
+            ]
         else:
             post_thunk_old_storage = None

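The `to_reset` rewrite is the trickiest hunk in the commit: a comprehension's `for`/`if` clauses read top to bottom exactly like the nested loops they replace, with the filter applying before the innermost loop. A structural sketch with stand-in data (the real code collects `thunk.outputs[j]`):

```python
thunk_groups = [["t1a", "t1b"], ["t2a", "t2b"]]  # illustrative stand-ins
node_outputs = [["o1"], ["o2"]]
no_recycling = {"o2"}

to_reset = [
    thunk
    for thunks, outputs in zip(thunk_groups, node_outputs)
    for j, output in enumerate(outputs)  # j indexes thunk.outputs in the real code
    if output in no_recycling            # filter runs before the inner loop
    for thunk in thunks
]
assert to_reset == ["t2a", "t2b"]
```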
pytensor/link/c/basic.py

Lines changed: 15 additions & 17 deletions
@@ -1129,19 +1129,18 @@ def __compile__(
         )

     def get_init_tasks(self):
-        init_tasks = []
-        tasks = []
+        vars = [v for v in self.variables if v not in self.consts]
         id = 1
-        for v in self.variables:
-            if v in self.consts:
-                continue
-            init_tasks.append((v, "init", id))
-            tasks.append((v, "get", id + 1))
-            id += 2
-        for node in self.node_order:
-            tasks.append((node, "code", id))
-            init_tasks.append((node, "init", id + 1))
-            id += 2
+        init_tasks = [(v, "init", id + 2 * i) for i, v in enumerate(vars)]
+        tasks = [(v, "get", id + 2 * i + 1) for i, v in enumerate(vars)]
+
+        id += 2 * len(vars)
+        tasks.extend(
+            (node, "code", id + 2 * i) for i, node in enumerate(self.node_order)
+        )
+        init_tasks.extend(
+            (node, "init", id + 2 * i + 1) for i, node in enumerate(self.node_order)
+        )
         return init_tasks, tasks

     def make_thunk(
@@ -1492,12 +1491,11 @@ def in_sig(i, topological_pos, i_idx):
         # graph's information used to compute the key. If we mistakenly
         # pretend that inputs with clients don't have any, were are only using
         # those inputs more than once to compute the key.
-        for ipos, var in [
-            (i, var)
-            for i, var in enumerate(fgraph.inputs)
+        sig.extend(
+            (var.type, in_sig(var, -1, ipos))
+            for ipos, var in enumerate(fgraph.inputs)
             if not len(fgraph.clients[var])
-        ]:
-            sig.append((var.type, in_sig(var, -1, ipos)))
+        )

         # crystalize the signature and version
         sig = tuple(sig)

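The `get_init_tasks` rewrite swaps a running counter for closed-form arithmetic: variable `i` gets ids `1 + 2*i` and `1 + 2*i + 1`, and node ids start after `2 * len(vars)`. A sanity check that both forms agree, with stand-in variables and nodes (`start` used here to avoid shadowing the `id` builtin):

```python
vars = ["v0", "v1"]        # illustrative non-const variables
node_order = ["n0", "n1"]  # illustrative stand-in for self.node_order

# Closed form, as in the rewrite.
start = 1
init_tasks = [(v, "init", start + 2 * i) for i, v in enumerate(vars)]
tasks = [(v, "get", start + 2 * i + 1) for i, v in enumerate(vars)]
start += 2 * len(vars)
tasks.extend((n, "code", start + 2 * i) for i, n in enumerate(node_order))
init_tasks.extend((n, "init", start + 2 * i + 1) for i, n in enumerate(node_order))

# Running counter, as in the old loop.
old_init, old_tasks, k = [], [], 1
for v in vars:
    old_init.append((v, "init", k))
    old_tasks.append((v, "get", k + 1))
    k += 2
for n in node_order:
    old_tasks.append((n, "code", k))
    old_init.append((n, "init", k + 1))
    k += 2

assert init_tasks == old_init and tasks == old_tasks
```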
pytensor/link/c/op.py

Lines changed: 13 additions & 21 deletions
@@ -220,12 +220,7 @@ def prepare_node(self, node, storage_map, compute_map, impl):

 def lquote_macro(txt: str) -> str:
     """Turn the last line of text into a ``\\``-commented line."""
-    res = []
-    spl = txt.split("\n")
-    for l in spl[:-1]:
-        res.append(l + " \\")
-    res.append(spl[-1])
-    return "\n".join(res)
+    return " \\\n".join(txt.split("\n"))


 def get_sub_macros(sub: dict[str, str]) -> tuple[str, str]:
@@ -240,21 +235,17 @@ def get_sub_macros(sub: dict[str, str]) -> tuple[str, str]:
     return "\n".join(define_macros), "\n".join(undef_macros)


-def get_io_macros(
-    inputs: list[str], outputs: list[str]
-) -> tuple[list[str]] | tuple[str, str]:
-    define_macros = []
-    undef_macros = []
+def get_io_macros(inputs: list[str], outputs: list[str]) -> tuple[str, str]:
+    define_inputs = [f"#define INPUT_{int(i)} {inp}" for i, inp in enumerate(inputs)]
+    define_outputs = [f"#define OUTPUT_{int(i)} {out}" for i, out in enumerate(outputs)]

-    for i, inp in enumerate(inputs):
-        define_macros.append(f"#define INPUT_{int(i)} {inp}")
-        undef_macros.append(f"#undef INPUT_{int(i)}")
+    undef_inputs = [f"#undef INPUT_{int(i)}" for i in range(len(inputs))]
+    undef_outputs = [f"#undef OUTPUT_{int(i)}" for i in range(len(outputs))]

-    for i, out in enumerate(outputs):
-        define_macros.append(f"#define OUTPUT_{int(i)} {out}")
-        undef_macros.append(f"#undef OUTPUT_{int(i)}")
+    define_all = "\n".join(define_inputs + define_outputs)
+    undef_all = "\n".join(undef_inputs + undef_outputs)

-    return "\n".join(define_macros), "\n".join(undef_macros)
+    return define_all, undef_all


 class ExternalCOp(COp):
@@ -560,9 +551,10 @@ def get_c_macros(
         define_macros.append(define_template % ("APPLY_SPECIFIC(str)", f"str##_{name}"))
         undef_macros.append(undef_template % "APPLY_SPECIFIC")

-        for n, v in self.__get_op_params():
-            define_macros.append(define_template % (n, v))
-            undef_macros.append(undef_template % (n,))
+        define_macros.extend(
+            define_template % (n, v) for n, v in self.__get_op_params()
+        )
+        undef_macros.extend(undef_template % (n,) for n, _ in self.__get_op_params())

         return "\n".join(define_macros), "\n".join(undef_macros)

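The `lquote_macro` one-liner relies on `str.join` placing the separator between elements only, so every line except the last gets a trailing backslash, matching the old append-per-line loop. A quick check:

```python
def lquote_macro(txt: str) -> str:
    # " \\\n" is: space, backslash, newline -- a C line continuation.
    return " \\\n".join(txt.split("\n"))

assert lquote_macro("a\nb\nc") == "a \\\nb \\\nc"
```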