@@ -120,26 +120,25 @@ def _getitems(
120
120
_getitem (self , record_key , generator , serialized = serialized )
121
121
for record_key in record_keys
122
122
]
123
- if serialized :
124
- return np .array (items )
125
- return items
123
+ return np .asarray (items )
126
124
127
125
128
- def _deserialize_example_np (serialized_example , * , decoders = None ):
126
+ def _deserialize_example_np (self , serialized_example , * , decoders = None ):
129
127
"""Function to overwrite dataset_info.features.deserialize_example_np.
130
128
131
129
Warning: this has to be defined in the outer scope in order for the function
132
130
to be pickable.
133
131
134
132
Args:
133
+ self: the dataset builder.
135
134
serialized_example: the example to deserialize.
136
135
decoders: optional decoders.
137
136
138
137
Returns:
139
138
The serialized example, because deserialization is taken care by
140
139
RandomFakeGenerator.
141
140
"""
142
- del decoders
141
+ del self , decoders
143
142
return serialized_example
144
143
145
144
@@ -173,6 +172,7 @@ def mock_data(
173
172
as_data_source_fn : Optional [Callable [..., Sequence [Any ]]] = None ,
174
173
data_dir : Optional [str ] = None ,
175
174
mock_array_record_data_source : Optional [PickableDataSourceMock ] = None ,
175
+ use_in_multiprocessing : bool = False ,
176
176
) -> Iterator [None ]:
177
177
"""Mock tfds to generate random data.
178
178
@@ -262,6 +262,10 @@ def as_dataset(self, *args, **kwargs):
262
262
mock_array_record_data_source: Overwrite a mock for the underlying
263
263
ArrayRecord data source if it is used. Note: If used the same mock will be
264
264
used for all data sources loaded within this context.
265
+ use_in_multiprocessing: If True, the mock will use a multiprocessing-safe
266
+ approach to generate the data. It's notably useful for PyGrain. The goal
267
+ is to migrate the codebase to this mode by default. Find a more detailed
268
+ explanation of this parameter in a comment in the code below.
265
269
266
270
Yields:
267
271
None
@@ -361,9 +365,31 @@ def mock_as_data_source(self, split, decoders=None, **kwargs):
361
365
if split is None :
362
366
split = {s : s for s in self .info .splits }
363
367
364
- generator_cls , features , _ , _ = _get_fake_data_components (
365
- decoders , self .info .features
366
- )
368
+ features = self .info .features
369
+ if use_in_multiprocessing :
370
+ # In multiprocessing, we generate serialized data. The data is then
371
+ # re-deserialized by the feature as it would normally happen in TFDS. In
372
+ # this approach, we don't need to monkey-patch workers to propagate the
373
+ # information that deserialize_example_np should be a no-op. Indeed, doing
374
+ # so is difficult as PyGrain uses the `spawn` multiprocessing mode. Users
375
+ # of tfds.testing.mock_data in the codebase started relying on the
376
+ # function not serializing (for example, they don't have TensorFlow in
377
+ # their dependency), so we cannot have use_in_multiprocessing by default.
378
+ # ┌─────────────┐
379
+ # │ Main process│
380
+ # └─┬──────┬────┘
381
+ # ┌───────▼─┐ ┌─▼───────┐
382
+ # │ worker1 │ │ worker2 │ ...
383
+ # └───────┬─┘ └─┬───────┘
384
+ # serialized data by the generator
385
+ # ┌───────▼─┐ ┌─▼───────┐
386
+ # │ tfds 1 │ │ tfds 2 │ ...
387
+ # └───────┬─┘ └─┬───────┘
388
+ # deserialized data
389
+ generator_cls = SerializedRandomFakeGenerator
390
+ else :
391
+ # We generate already deserialized data with the generator.
392
+ generator_cls , _ , _ , _ = _get_fake_data_components (decoders , features )
367
393
generator = generator_cls (features , num_examples )
368
394
369
395
if actual_policy == MockPolicy .USE_CODE :
@@ -385,7 +411,6 @@ def mock_as_data_source(self, split, decoders=None, **kwargs):
385
411
# Force ARRAY_RECORD as the default file_format.
386
412
return_value = file_adapters .FileFormat .ARRAY_RECORD ,
387
413
):
388
- self .info .features .deserialize_example_np = _deserialize_example_np
389
414
mock_data_source .return_value .__len__ .return_value = num_examples
390
415
mock_data_source .return_value ._generator = ( # pylint:disable=protected-access
391
416
generator
@@ -399,7 +424,7 @@ def mock_as_data_source(self, split, decoders=None, **kwargs):
399
424
400
425
def build_single_data_source (split ):
401
426
single_data_source = array_record .ArrayRecordDataSource (
402
- dataset_info = self . info , split = split , decoders = decoders
427
+ dataset_builder = self , split = split , decoders = decoders
403
428
)
404
429
return single_data_source
405
430
@@ -463,6 +488,10 @@ def new_builder_from_files(*args, **kwargs):
463
488
f'{ core } .dataset_builder.FileReaderBuilder._as_dataset' ,
464
489
as_dataset_fn ,
465
490
),
491
+ (
492
+ f'{ core } .features.top_level_feature.TopLevelFeature.deserialize_example_np' ,
493
+ _deserialize_example_np ,
494
+ ),
466
495
]:
467
496
stack .enter_context (mock .patch (path , mocked_fn ))
468
497
yield
0 commit comments