Skip to content

Commit 8a2800c

Browse files
tf-model-analysis-teamtfx-copybara
tf-model-analysis-team
authored andcommitted
Refactor batch ptransform in PredictionsExtractor.
PiperOrigin-RevId: 489042610
1 parent a21dc44 commit 8a2800c

File tree

1 file changed

+7
-6
lines changed

1 file changed

+7
-6
lines changed

tensorflow_model_analysis/extractors/predictions_extractor.py

+7-6
Original file line numberDiff line numberDiff line change
@@ -331,11 +331,12 @@ def _RunInference(extracts: beam.pvalue.PCollection,
331331
# Beam batch will group single Extracts into a batch. Then
332332
# merge_extracts will flatten the batch into a single "batched"
333333
# extract.
334-
batch_extracts_stage_name = 'BatchSingleExampleExtracts'
335334
if batch_size is not None:
336-
extracts |= batch_extracts_stage_name >> beam.BatchElements(
337-
min_batch_size=batch_size, max_batch_size=batch_size)
335+
batch_kwargs = {'min_batch_size': batch_size, 'max_batch_size': batch_size}
338336
else:
339-
extracts |= batch_extracts_stage_name >> beam.BatchElements()
340-
return extracts | 'MergeExtracts' >> beam.Map(
341-
util.merge_extracts, squeeze_two_dim_vector=False)
337+
# Default batch parameters.
338+
batch_kwargs = {}
339+
return (extracts
340+
| 'BatchSingleExampleExtracts' >> beam.BatchElements(**batch_kwargs)
341+
| 'MergeExtracts' >> beam.Map(
342+
util.merge_extracts, squeeze_two_dim_vector=False))

0 commit comments

Comments
 (0)