Skip to content

Commit 5ffab8f

Browse files
committed
Remove frozendict
1 parent 294a4bf commit 5ffab8f

File tree

5 files changed

+21
-74
lines changed

5 files changed

+21
-74
lines changed

doc/LICENSE.txt

-2
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,6 @@ theano/tensor/sharedvar.py: James Bergstra, (c) 2010, Universite de Montreal, 3-
1818
theano/gradient.py: James Bergstra, Razvan Pascanu, Arnaud Bergeron, Ian Goodfellow, PyMC Developers, PyTensor Developers, (c) 2011, Universite de Montreal, 3-clause BSD License
1919
theano/compile/monitormode.py: this code was initially copied from the 'pyutools' package by its original author, and re-licensed under Theano's license.
2020

21-
Contains frozendict code from slezica’s python-frozendict(https://github.com/slezica/python-frozendict/blob/master/frozendict/__init__.py), Copyright (c) 2012 Santiago Lezica. All rights reserved.
22-
2321
Redistribution and use in source and binary forms, with or without
2422
modification, are permitted provided that the following conditions are met:
2523

pytensor/link/numba/dispatch/elemwise.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -488,7 +488,7 @@ def numba_funcify_Elemwise(op, node, **kwargs):
488488
input_bc_patterns = tuple([inp.type.broadcastable for inp in node.inputs])
489489
output_bc_patterns = tuple([out.type.broadcastable for out in node.outputs])
490490
output_dtypes = tuple(out.type.dtype for out in node.outputs)
491-
inplace_pattern = tuple(op.inplace_pattern.items())
491+
inplace_pattern = op.inplace_pattern
492492
core_output_shapes = tuple(() for _ in range(nout))
493493

494494
# numba doesn't support nested literals right now...

pytensor/misc/frozendict.py

-44
This file was deleted.

pytensor/tensor/elemwise.py

+19-26
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
from pytensor.link.c.basic import failure_code
1414
from pytensor.link.c.op import COp, ExternalCOp, OpenMPOp
1515
from pytensor.link.c.params_type import ParamsType
16-
from pytensor.misc.frozendict import frozendict
1716
from pytensor.misc.safe_asarray import _asarray
1817
from pytensor.printing import Printer, pprint
1918
from pytensor.scalar import get_scalar_type
@@ -374,11 +373,11 @@ def __init__(
374373
"""
375374
assert not isinstance(scalar_op, type(self))
376375
if inplace_pattern is None:
377-
inplace_pattern = frozendict({})
376+
inplace_pattern = {}
378377
self.name = name
379378
self.scalar_op = scalar_op
380-
self.inplace_pattern = inplace_pattern
381-
self.destroy_map = {o: [i] for o, i in self.inplace_pattern.items()}
379+
self.inplace_pattern = tuple(inplace_pattern.items())
380+
self.destroy_map = {o: [i] for o, i in self.inplace_pattern}
382381

383382
if nfunc_spec is None:
384383
nfunc_spec = getattr(scalar_op, "nfunc_spec", None)
@@ -397,7 +396,6 @@ def __setstate__(self, d):
397396
super().__setstate__(d)
398397
self.ufunc = None
399398
self.nfunc = None
400-
self.inplace_pattern = frozendict(self.inplace_pattern)
401399

402400
def get_output_info(self, dim_shuffle, *inputs):
403401
"""Return the outputs dtype and broadcastable pattern and the
@@ -446,27 +444,23 @@ def get_output_info(self, dim_shuffle, *inputs):
446444
)
447445

448446
# inplace_pattern maps output idx -> input idx
449-
inplace_pattern = self.inplace_pattern
450-
if inplace_pattern:
451-
for overwriter, overwritten in inplace_pattern.items():
452-
for out_s, in_s in zip(
453-
out_shapes[overwriter],
454-
inputs[overwritten].type.shape,
455-
):
456-
if in_s == 1 and out_s != 1:
457-
raise ValueError(
458-
"Operation cannot be done inplace on an input "
459-
"with broadcasted dimensions."
460-
)
447+
for overwriter, overwritten in self.inplace_pattern:
448+
for out_s, in_s in zip(
449+
out_shapes[overwriter],
450+
inputs[overwritten].type.shape,
451+
):
452+
if in_s == 1 and out_s != 1:
453+
raise ValueError(
454+
"Operation cannot be done inplace on an input "
455+
"with broadcasted dimensions."
456+
)
461457

462458
out_dtypes = [o.type.dtype for o in shadow.outputs]
463-
if any(
464-
inputs[i].type.dtype != out_dtypes[o] for o, i in inplace_pattern.items()
465-
):
459+
if any(inputs[i].type.dtype != out_dtypes[o] for o, i in self.inplace_pattern):
466460
raise TypeError(
467461
(
468462
"Cannot do an inplace operation on incompatible data types.",
469-
([i.type.dtype for i in inputs], out_dtypes, inplace_pattern),
463+
([i.type.dtype for i in inputs], out_dtypes, self.inplace_pattern),
470464
)
471465
)
472466
assert len(out_dtypes) == len(out_shapes)
@@ -755,6 +749,7 @@ def perform(self, node, inputs, output_storage):
755749
if nout == 1:
756750
variables = [variables]
757751

752+
inplace_pattern = dict(self.inplace_pattern)
758753
for i, (variable, storage, nout) in enumerate(
759754
zip(variables, output_storage, node.outputs)
760755
):
@@ -763,8 +758,8 @@ def perform(self, node, inputs, output_storage):
763758
# always return an ndarray with dtype object
764759
variable = np.asarray(variable, dtype=nout.dtype)
765760

766-
if i in self.inplace_pattern:
767-
odat = inputs[self.inplace_pattern[i]]
761+
if i in inplace_pattern:
762+
odat = inputs[inplace_pattern[i]]
768763
odat[...] = variable
769764
storage[0] = odat
770765

@@ -832,9 +827,7 @@ def _c_all(self, node, nodename, inames, onames, sub):
832827
# The destroy map is a map of output indices to input indices
833828
# that overwrite them. We just convert them to the actual
834829
# Variables.
835-
dmap = {
836-
node.outputs[o]: [node.inputs[i]] for o, i in self.inplace_pattern.items()
837-
}
830+
dmap = {node.outputs[o]: [node.inputs[i]] for o, i in self.inplace_pattern}
838831

839832
# dtypes of the inputs
840833
idtypes = [input.type.dtype_specs()[1] for input in inputs]

pytensor/tensor/rewriting/elemwise.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -173,7 +173,7 @@ def apply(self, fgraph):
173173
# original node add already some inplace patter and we
174174
# still try to add more pattern.
175175

176-
baseline = op.inplace_pattern
176+
baseline = dict(op.inplace_pattern)
177177
candidate_outputs = [
178178
i for i in self.candidate_input_idxs(node) if i not in baseline
179179
]

0 commit comments

Comments
 (0)