Skip to content

Commit 2e5eba3

Browse files
[tf.data] Make the result of tf.data.Dataset.options() immutable. This change is in preparation of making tf.data.Options() persistent across tf.function boundaries.
PiperOrigin-RevId: 361053121 Change-Id: I9b4ab3592f914e2311381d41b1a7bd11c45830aa
1 parent dd1ce23 commit 2e5eba3

File tree

5 files changed

+41
-19
lines changed

5 files changed

+41
-19
lines changed

RELEASE.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
multiple input batches should be computed in parallel. With
3939
`num_parallel_calls` set, `deterministic` is used to indicate that
4040
outputs can be obtained in the non-deterministic order.
41+
* Options returned by `tf.data.Dataset.options()` are no longer mutable.
4142

4243
## Bug Fixes and Other Changes
4344

tensorflow/python/data/experimental/ops/optimization_options.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -418,3 +418,9 @@ def _from_proto(self, pb):
418418
self.reorder_data_discarding_ops = pb.reorder_data_discarding_ops
419419
if pb.WhichOneof("optional_shuffle_and_repeat_fusion") is not None:
420420
self.shuffle_and_repeat_fusion = pb.shuffle_and_repeat_fusion
421+
422+
def _set_mutable(self, mutable):
423+
"""Change the mutability value to `mutable` on this options and children."""
424+
# pylint: disable=protected-access
425+
object.__setattr__(self, "_mutable", mutable)
426+
self.map_vectorization._set_mutable(mutable)

tensorflow/python/data/kernel_tests/options_test.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -111,14 +111,18 @@ def testOptionsHaveDefaults(self):
111111
threading_options.ThreadingOptions())
112112

113113
@combinations.generate(test_base.default_test_combinations())
114-
def testMutableOptions(self):
114+
def testMutatingOptionsRaiseValueError(self):
115115
ds = dataset_ops.Dataset.range(0)
116-
ds.options().experimental_optimization.autotune = True
117-
self.assertTrue(ds.options().experimental_optimization.autotune)
118-
options = dataset_ops.Options()
119-
ds = ds.with_options(options)
120-
ds.options().experimental_deterministic = True
121-
self.assertTrue(ds.options().experimental_deterministic)
116+
options1 = dataset_ops.Options()
117+
options1.experimental_slack = True
118+
options2 = dataset_ops.Options()
119+
options2.experimental_optimization.autotune = True
120+
ds = ds.with_options(options1)
121+
ds = ds.map(lambda x: 2 * x)
122+
ds = ds.with_options(options2)
123+
with self.assertRaises(ValueError):
124+
dataset_options = ds.options()
125+
dataset_options.experimental_deterministic = True
122126

123127
@combinations.generate(test_base.eager_only_combinations())
124128
def testNestedDataset(self):

tensorflow/python/data/ops/dataset_ops.py

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -218,6 +218,7 @@ def __init__(self, variant_tensor):
218218
input_options = input_dataset.options()
219219
if input_options is not None:
220220
self._options_attr = self._options_attr.merge(input_options)
221+
self._options_attr._set_mutable(False) # pylint: disable=protected-access
221222

222223
@property
223224
def _variant_tensor(self):
@@ -2990,16 +2991,10 @@ class Options(options_lib.OptionsBase):
29902991
The options are set for the entire dataset and are carried over to datasets
29912992
created through tf.data transformations.
29922993
2993-
The options can be set either by mutating the object returned by
2994-
`tf.data.Dataset.options()` or by constructing an `Options` object and using
2995-
the `tf.data.Dataset.with_options(options)` transformation, which returns a
2994+
The options can be set by constructing an `Options` object and using the
2995+
`tf.data.Dataset.with_options(options)` transformation, which returns a
29962996
dataset with the options set.
29972997
2998-
>>> dataset = tf.data.Dataset.range(42)
2999-
>>> dataset.options().experimental_deterministic = False
3000-
>>> print(dataset.options().experimental_deterministic)
3001-
False
3002-
30032998
>>> dataset = tf.data.Dataset.range(42)
30042999
>>> options = tf.data.Options()
30053000
>>> options.experimental_deterministic = False
@@ -3099,6 +3094,14 @@ def _from_proto(self, pb):
30993094
self.experimental_slack = pb.slack
31003095
self.experimental_threading._from_proto(pb.threading_options) # pylint: disable=protected-access
31013096

3097+
def _set_mutable(self, mutable):
3098+
"""Change the mutability value to `mutable` on this options and children."""
3099+
# pylint: disable=protected-access
3100+
object.__setattr__(self, "_mutable", mutable)
3101+
self.experimental_distribute._set_mutable(mutable)
3102+
self.experimental_optimization._set_mutable(mutable)
3103+
self.experimental_threading._set_mutable(mutable)
3104+
31023105
def _graph_rewrites(self):
31033106
"""Produces lists of enabled, disabled, default static graph rewrites.
31043107
@@ -4665,17 +4668,17 @@ class _OptionsDataset(UnaryUnchangedStructureDataset):
46654668
"""An identity `Dataset` that stores options."""
46664669

46674670
def __init__(self, input_dataset, options):
4671+
# pylint: disable=protected-access
46684672
self._input_dataset = input_dataset
4669-
variant_tensor = input_dataset._variant_tensor # pylint: disable=protected-access
4673+
variant_tensor = input_dataset._variant_tensor
46704674
super(_OptionsDataset, self).__init__(input_dataset, variant_tensor)
46714675

46724676
if self._options_attr:
4677+
self._options_attr._set_mutable(True)
46734678
self._options_attr = self._options_attr.merge(options)
46744679
else:
46754680
self._options_attr = options
4676-
4677-
def options(self):
4678-
return self._options_attr
4681+
self._options_attr._set_mutable(False)
46794682

46804683

46814684
class _ModelDataset(UnaryUnchangedStructureDataset):

tensorflow/python/data/util/options.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ class OptionsBase(object):
3737
def __init__(self):
3838
# NOTE: Cannot use `self._options` here as we override `__setattr__`
3939
object.__setattr__(self, "_options", {})
40+
object.__setattr__(self, "_mutable", True)
4041

4142
def __eq__(self, other):
4243
if not isinstance(other, self.__class__):
@@ -53,12 +54,19 @@ def __ne__(self, other):
5354
return NotImplemented
5455

5556
def __setattr__(self, name, value):
57+
if not self._mutable:
58+
raise ValueError("Mutating `tf.data.Options()` returned by "
59+
"`tf.data.Dataset.options()` has no effect.")
5660
if hasattr(self, name):
5761
object.__setattr__(self, name, value)
5862
else:
5963
raise AttributeError(
6064
"Cannot set the property %s on %s." % (name, type(self).__name__))
6165

66+
def _set_mutable(self, mutable):
67+
"""Change the mutability property to `mutable`."""
68+
object.__setattr__(self, "_mutable", mutable)
69+
6270
def _to_proto(self):
6371
"""Convert options to protocol buffer."""
6472
raise NotImplementedError("%s._to_proto()" % type(self).__name__)

0 commit comments

Comments
 (0)