Skip to content

Commit 112af3e

Browse files
authored
Remove deprecated generator data (#7664)
1 parent 358b825 commit 112af3e

File tree

6 files changed

+3
-327
lines changed

6 files changed

+3
-327
lines changed

docs/source/api/data.rst

-1
Original file line numberDiff line numberDiff line change
@@ -10,5 +10,4 @@ Data
1010
MutableData
1111
get_data
1212
Data
13-
GeneratorAdapter
1413
Minibatch

pymc/data.py

+1-50
Original file line numberDiff line numberDiff line change
@@ -33,18 +33,16 @@
3333
from pytensor.scalar import Cast
3434
from pytensor.tensor.elemwise import Elemwise
3535
from pytensor.tensor.random.basic import IntegersRV
36-
from pytensor.tensor.type import TensorType
3736
from pytensor.tensor.variable import TensorConstant, TensorVariable
3837

3938
import pymc as pm
4039

41-
from pymc.pytensorf import GeneratorOp, convert_data, smarttypeX
40+
from pymc.pytensorf import convert_data
4241
from pymc.vartypes import isgenerator
4342

4443
__all__ = [
4544
"ConstantData",
4645
"Data",
47-
"GeneratorAdapter",
4846
"Minibatch",
4947
"MutableData",
5048
"get_data",
@@ -86,51 +84,6 @@ def clone(self):
8684
return cp
8785

8886

89-
class GeneratorAdapter:
90-
"""Class that helps infer data type of generator.
91-
92-
It looks at the first item, preserving the order of the resulting generator.
93-
"""
94-
95-
def make_variable(self, gop, name=None):
96-
var = GenTensorVariable(gop, self.tensortype, name)
97-
var.tag.test_value = self.test_value
98-
return var
99-
100-
def __init__(self, generator):
101-
if not pm.vartypes.isgenerator(generator):
102-
raise TypeError("Object should be generator like")
103-
self.test_value = smarttypeX(copy(next(generator)))
104-
# make pickling potentially possible
105-
self._yielded_test_value = False
106-
self.gen = generator
107-
self.tensortype = TensorType(self.test_value.dtype, ((False,) * self.test_value.ndim))
108-
109-
# python3 generator
110-
def __next__(self):
111-
"""Next value in the generator."""
112-
if not self._yielded_test_value:
113-
self._yielded_test_value = True
114-
return self.test_value
115-
else:
116-
return smarttypeX(copy(next(self.gen)))
117-
118-
# python2 generator
119-
next = __next__
120-
121-
def __iter__(self):
122-
"""Return an iterator."""
123-
return self
124-
125-
def __eq__(self, other):
126-
"""Return true if both objects are actually the same."""
127-
return id(self) == id(other)
128-
129-
def __hash__(self):
130-
"""Return a hash of the object."""
131-
return hash(id(self))
132-
133-
13487
class MinibatchIndexRV(IntegersRV):
13588
_print_name = ("minibatch_index", r"\operatorname{minibatch\_index}")
13689

@@ -170,8 +123,6 @@ def is_valid_observed(v) -> bool:
170123
isinstance(v.owner.op, MinibatchOp)
171124
and all(is_valid_observed(inp) for inp in v.owner.inputs)
172125
)
173-
# Or Generator
174-
or isinstance(v.owner.op, GeneratorOp)
175126
)
176127

177128

pymc/pytensorf.py

+1-106
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,6 @@
3636
walk,
3737
)
3838
from pytensor.graph.fg import FunctionGraph, Output
39-
from pytensor.graph.op import Op
4039
from pytensor.scalar.basic import Cast
4140
from pytensor.scan.op import Scan
4241
from pytensor.tensor.basic import _as_tensor_variable
@@ -63,10 +62,8 @@
6362
"compile_pymc",
6463
"cont_inputs",
6564
"convert_data",
66-
"convert_generator_data",
6765
"convert_observed_data",
6866
"floatX",
69-
"generator",
7067
"gradient",
7168
"hessian",
7269
"hessian_diag",
@@ -81,20 +78,10 @@
8178
def convert_observed_data(data) -> np.ndarray | Variable:
8279
"""Convert user provided dataset to accepted formats."""
8380
if isgenerator(data):
84-
return convert_generator_data(data)
81+
raise TypeError("Data passed to `observed` cannot be a generator.")
8582
return convert_data(data)
8683

8784

88-
def convert_generator_data(data) -> TensorVariable:
89-
warnings.warn(
90-
"Generator data is deprecated and we intend to remove it."
91-
" If you disagree and need this, please get in touch via https://github.com/pymc-devs/pymc/issues.",
92-
DeprecationWarning,
93-
stacklevel=2,
94-
)
95-
return generator(data)
96-
97-
9885
def convert_data(data) -> np.ndarray | Variable:
9986
ret: np.ndarray | Variable
10087
if hasattr(data, "to_numpy") and hasattr(data, "isnull"):
@@ -625,98 +612,6 @@ def __call__(self, input):
625612
return pytensor.clone_replace(self.tensor, {oldinput: input}, rebuild_strict=False)
626613

627614

628-
class GeneratorOp(Op):
629-
"""
630-
Generator Op is designed for storing python generators inside pytensor graph.
631-
632-
__call__ creates TensorVariable
633-
It has 2 new methods
634-
- var.set_gen(gen): sets new generator
635-
- var.set_default(value): sets new default value (None erases default value)
636-
637-
If generator is exhausted, variable will produce default value if it is not None,
638-
else raises `StopIteration` exception that can be caught on runtime.
639-
640-
Parameters
641-
----------
642-
gen: generator that implements __next__ (py3) or next (py2) method
643-
and yields np.arrays with same types
644-
default: np.array with the same type as generator produces
645-
"""
646-
647-
__props__ = ("generator",)
648-
649-
def __init__(self, gen, default=None):
650-
warnings.warn(
651-
"generator data is deprecated and will be removed in a future release", FutureWarning
652-
)
653-
from pymc.data import GeneratorAdapter
654-
655-
super().__init__()
656-
if not isinstance(gen, GeneratorAdapter):
657-
gen = GeneratorAdapter(gen)
658-
self.generator = gen
659-
self.set_default(default)
660-
661-
def make_node(self, *inputs):
662-
gen_var = self.generator.make_variable(self)
663-
return Apply(self, [], [gen_var])
664-
665-
def perform(self, node, inputs, output_storage, params=None):
666-
if self.default is not None:
667-
output_storage[0][0] = next(self.generator, self.default)
668-
else:
669-
output_storage[0][0] = next(self.generator)
670-
671-
def do_constant_folding(self, fgraph, node):
672-
return False
673-
674-
__call__ = pytensor.config.change_flags(compute_test_value="off")(Op.__call__)
675-
676-
def set_gen(self, gen):
677-
from pymc.data import GeneratorAdapter
678-
679-
if not isinstance(gen, GeneratorAdapter):
680-
gen = GeneratorAdapter(gen)
681-
if not gen.tensortype == self.generator.tensortype:
682-
raise ValueError("New generator should yield the same type")
683-
self.generator = gen
684-
685-
def set_default(self, value):
686-
if value is None:
687-
self.default = None
688-
else:
689-
value = np.asarray(value, self.generator.tensortype.dtype)
690-
t1 = (False,) * value.ndim
691-
t2 = self.generator.tensortype.broadcastable
692-
if not t1 == t2:
693-
raise ValueError("Default value should have the same type as generator")
694-
self.default = value
695-
696-
697-
def generator(gen, default=None):
698-
"""
699-
Create a generator variable with possibility to set default value and new generator.
700-
701-
If generator is exhausted variable will produce default value if it is not None,
702-
else raises `StopIteration` exception that can be caught on runtime.
703-
704-
Parameters
705-
----------
706-
gen: generator that implements __next__ (py3) or next (py2) method
707-
and yields np.arrays with same types
708-
default: np.array with the same type as generator produces
709-
710-
Returns
711-
-------
712-
TensorVariable
713-
It has 2 new methods
714-
- var.set_gen(gen): sets new generator
715-
- var.set_default(value): sets new default value (None erases default value)
716-
"""
717-
return GeneratorOp(gen, default)()
718-
719-
720615
def ix_(*args):
721616
"""
722617
PyTensor np.ix_ analog.

tests/test_data.py

+1-94
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,9 @@
1313
# limitations under the License.
1414

1515
import io
16-
import itertools as it
1716

1817
from os import path
1918

20-
import cloudpickle
2119
import numpy as np
2220
import pytensor
2321
import pytensor.tensor as pt
@@ -29,7 +27,7 @@
2927
import pymc as pm
3028

3129
from pymc.data import MinibatchOp
32-
from pymc.pytensorf import GeneratorOp, floatX
30+
from pymc.pytensorf import floatX
3331

3432

3533
class TestData:
@@ -495,97 +493,6 @@ def integers_ndim(ndim):
495493
i += 1
496494

497495

498-
@pytest.mark.usefixtures("strict_float32")
499-
class TestGenerator:
500-
def test_basic(self):
501-
generator = pm.GeneratorAdapter(integers())
502-
gop = GeneratorOp(generator)()
503-
assert gop.tag.test_value == np.float32(0)
504-
f = pytensor.function([], gop)
505-
assert f() == np.float32(0)
506-
assert f() == np.float32(1)
507-
for _ in range(2, 100):
508-
f()
509-
assert f() == np.float32(100)
510-
511-
def test_ndim(self):
512-
for ndim in range(10):
513-
res = list(it.islice(integers_ndim(ndim), 0, 2))
514-
generator = pm.GeneratorAdapter(integers_ndim(ndim))
515-
gop = GeneratorOp(generator)()
516-
f = pytensor.function([], gop)
517-
assert ndim == res[0].ndim
518-
np.testing.assert_equal(f(), res[0])
519-
np.testing.assert_equal(f(), res[1])
520-
521-
def test_cloning_available(self):
522-
gop = pm.generator(integers())
523-
res = gop**2
524-
shared = pytensor.shared(pm.floatX(10))
525-
res1 = pytensor.clone_replace(res, {gop: shared})
526-
f = pytensor.function([], res1)
527-
assert f() == np.float32(100)
528-
529-
def test_default_value(self):
530-
def gen():
531-
for i in range(2):
532-
yield pm.floatX(np.ones((10, 10)) * i)
533-
534-
gop = pm.generator(gen(), np.ones((10, 10)) * 10)
535-
f = pytensor.function([], gop)
536-
np.testing.assert_equal(np.ones((10, 10)) * 0, f())
537-
np.testing.assert_equal(np.ones((10, 10)) * 1, f())
538-
np.testing.assert_equal(np.ones((10, 10)) * 10, f())
539-
with pytest.raises(ValueError):
540-
gop.set_default(1)
541-
542-
def test_set_gen_and_exc(self):
543-
def gen():
544-
for i in range(2):
545-
yield pm.floatX(np.ones((10, 10)) * i)
546-
547-
gop = pm.generator(gen())
548-
f = pytensor.function([], gop)
549-
np.testing.assert_equal(np.ones((10, 10)) * 0, f())
550-
np.testing.assert_equal(np.ones((10, 10)) * 1, f())
551-
with pytest.raises(StopIteration):
552-
f()
553-
gop.set_gen(gen())
554-
np.testing.assert_equal(np.ones((10, 10)) * 0, f())
555-
np.testing.assert_equal(np.ones((10, 10)) * 1, f())
556-
557-
def test_pickling(self, datagen):
558-
gen = pm.generator(datagen)
559-
cloudpickle.loads(cloudpickle.dumps(gen))
560-
bad_gen = pm.generator(integers())
561-
with pytest.raises(TypeError):
562-
cloudpickle.dumps(bad_gen)
563-
564-
def test_gen_cloning_with_shape_change(self, datagen):
565-
gen = pm.generator(datagen)
566-
gen_r = pt.random.normal(size=gen.shape).T
567-
X = gen.dot(gen_r)
568-
res, _ = pytensor.scan(lambda x: x.sum(), X, n_steps=X.shape[0])
569-
assert res.eval().shape == (50,)
570-
shared = pytensor.shared(datagen.data.astype(gen.dtype))
571-
res2 = pytensor.clone_replace(res, {gen: shared**2})
572-
assert res2.eval().shape == (1000,)
573-
574-
575-
def gen1():
576-
i = 0
577-
while True:
578-
yield np.ones((10, 100)) * i
579-
i += 1
580-
581-
582-
def gen2():
583-
i = 0
584-
while True:
585-
yield np.ones((20, 100)) * i
586-
i += 1
587-
588-
589496
@pytest.mark.usefixtures("strict_float32")
590497
class TestMinibatch:
591498
data = np.random.rand(30, 10)

0 commit comments

Comments
 (0)