Skip to content

Commit 80137a4

Browse files
committed
Remove frozendict
1 parent 294a4bf commit 80137a4

File tree

3 files changed

+19
-72
lines changed

3 files changed

+19
-72
lines changed

doc/LICENSE.txt

-2
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,6 @@ theano/tensor/sharedvar.py: James Bergstra, (c) 2010, Universite de Montreal, 3-
1818
theano/gradient.py: James Bergstra, Razvan Pascanu, Arnaud Bergeron, Ian Goodfellow, PyMC Developers, PyTensor Developers, (c) 2011, Universite de Montreal, 3-clause BSD License
1919
theano/compile/monitormode.py: this code was initially copied from the 'pyutools' package by its original author, and re-licensed under Theano's license.
2020

21-
Contains frozendict code from slezica’s python-frozendict(https://github.com/slezica/python-frozendict/blob/master/frozendict/__init__.py), Copyright (c) 2012 Santiago Lezica. All rights reserved.
22-
2321
Redistribution and use in source and binary forms, with or without
2422
modification, are permitted provided that the following conditions are met:
2523

pytensor/misc/frozendict.py

-44
This file was deleted.

pytensor/tensor/elemwise.py

+19-26
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
from pytensor.link.c.basic import failure_code
1414
from pytensor.link.c.op import COp, ExternalCOp, OpenMPOp
1515
from pytensor.link.c.params_type import ParamsType
16-
from pytensor.misc.frozendict import frozendict
1716
from pytensor.misc.safe_asarray import _asarray
1817
from pytensor.printing import Printer, pprint
1918
from pytensor.scalar import get_scalar_type
@@ -374,11 +373,11 @@ def __init__(
374373
"""
375374
assert not isinstance(scalar_op, type(self))
376375
if inplace_pattern is None:
377-
inplace_pattern = frozendict({})
376+
inplace_pattern = {}
378377
self.name = name
379378
self.scalar_op = scalar_op
380-
self.inplace_pattern = inplace_pattern
381-
self.destroy_map = {o: [i] for o, i in self.inplace_pattern.items()}
379+
self.inplace_pattern = tuple(inplace_pattern.items())
380+
self.destroy_map = {o: [i] for o, i in self.inplace_pattern}
382381

383382
if nfunc_spec is None:
384383
nfunc_spec = getattr(scalar_op, "nfunc_spec", None)
@@ -397,7 +396,6 @@ def __setstate__(self, d):
397396
super().__setstate__(d)
398397
self.ufunc = None
399398
self.nfunc = None
400-
self.inplace_pattern = frozendict(self.inplace_pattern)
401399

402400
def get_output_info(self, dim_shuffle, *inputs):
403401
"""Return the outputs dtype and broadcastable pattern and the
@@ -446,27 +444,23 @@ def get_output_info(self, dim_shuffle, *inputs):
446444
)
447445

448446
# inplace_pattern maps output idx -> input idx
449-
inplace_pattern = self.inplace_pattern
450-
if inplace_pattern:
451-
for overwriter, overwritten in inplace_pattern.items():
452-
for out_s, in_s in zip(
453-
out_shapes[overwriter],
454-
inputs[overwritten].type.shape,
455-
):
456-
if in_s == 1 and out_s != 1:
457-
raise ValueError(
458-
"Operation cannot be done inplace on an input "
459-
"with broadcasted dimensions."
460-
)
447+
for overwriter, overwritten in self.inplace_pattern:
448+
for out_s, in_s in zip(
449+
out_shapes[overwriter],
450+
inputs[overwritten].type.shape,
451+
):
452+
if in_s == 1 and out_s != 1:
453+
raise ValueError(
454+
"Operation cannot be done inplace on an input "
455+
"with broadcasted dimensions."
456+
)
461457

462458
out_dtypes = [o.type.dtype for o in shadow.outputs]
463-
if any(
464-
inputs[i].type.dtype != out_dtypes[o] for o, i in inplace_pattern.items()
465-
):
459+
if any(inputs[i].type.dtype != out_dtypes[o] for o, i in self.inplace_pattern):
466460
raise TypeError(
467461
(
468462
"Cannot do an inplace operation on incompatible data types.",
469-
([i.type.dtype for i in inputs], out_dtypes, inplace_pattern),
463+
([i.type.dtype for i in inputs], out_dtypes, self.inplace_pattern),
470464
)
471465
)
472466
assert len(out_dtypes) == len(out_shapes)
@@ -755,6 +749,7 @@ def perform(self, node, inputs, output_storage):
755749
if nout == 1:
756750
variables = [variables]
757751

752+
inplace_pattern = dict(self.inplace_pattern)
758753
for i, (variable, storage, nout) in enumerate(
759754
zip(variables, output_storage, node.outputs)
760755
):
@@ -763,8 +758,8 @@ def perform(self, node, inputs, output_storage):
763758
# always return an ndarray with dtype object
764759
variable = np.asarray(variable, dtype=nout.dtype)
765760

766-
if i in self.inplace_pattern:
767-
odat = inputs[self.inplace_pattern[i]]
761+
if i in inplace_pattern:
762+
odat = inputs[inplace_pattern[i]]
768763
odat[...] = variable
769764
storage[0] = odat
770765

@@ -832,9 +827,7 @@ def _c_all(self, node, nodename, inames, onames, sub):
832827
# The destroy map is a map of output indices to input indices
833828
# that overwrite them. We just convert them to the actual
834829
# Variables.
835-
dmap = {
836-
node.outputs[o]: [node.inputs[i]] for o, i in self.inplace_pattern.items()
837-
}
830+
dmap = {node.outputs[o]: [node.inputs[i]] for o, i in self.inplace_pattern}
838831

839832
# dtypes of the inputs
840833
idtypes = [input.type.dtype_specs()[1] for input in inputs]

0 commit comments

Comments
 (0)