diff --git a/scripts/evaluate.py b/scripts/evaluate.py index 1ec9659..a524f8c 100644 --- a/scripts/evaluate.py +++ b/scripts/evaluate.py @@ -258,13 +258,22 @@ FANOUT = flags.DEFINE_integer( 'fanout', None, - help='Beam CombineFn fanout. Might be required for large dataset.', + help='Beam CombineFn fanout. Recommended when evaluating large datasets.', ) NUM_THREADS = flags.DEFINE_integer( 'num_threads', None, help='Number of chunks to read/write Zarr in parallel per worker.', ) +SHUFFLE_BEFORE_TEMPORAL_MEAN = flags.DEFINE_bool( + 'shuffle_before_temporal_mean', + False, + help=( + 'Shuffle before computing the temporal mean. This is a good idea when' + ' evaluation metric outputs are small compared to the size of the' + ' input data, such as when aggregating over space or a large ensemble.' + ), +) def _wind_vector_error(err_type: str): @@ -661,6 +670,7 @@ def main(argv: list[str]) -> None: skipna=SKIPNA.value, fanout=FANOUT.value, num_threads=NUM_THREADS.value, + shuffle_before_temporal_mean=SHUFFLE_BEFORE_TEMPORAL_MEAN.value, argv=argv, ) else: diff --git a/weatherbench2/config.py b/weatherbench2/config.py index 3be06e9..f407d8d 100644 --- a/weatherbench2/config.py +++ b/weatherbench2/config.py @@ -114,7 +114,7 @@ class Eval: by-valid convention. For by-init, specify analysis dataset as obs. derived_variables: dict of DerivedVariable instances to compute on the fly. temporal_mean: Compute temporal mean (over time/init_time) for metrics. - output_format: Wether to save to 'netcdf' or 'zarr'. + output_format: whether to save to 'netcdf' or 'zarr'. """ metrics: t.Dict[str, Metric] diff --git a/weatherbench2/evaluation.py b/weatherbench2/evaluation.py index 942e414..1eb3679 100644 --- a/weatherbench2/evaluation.py +++ b/weatherbench2/evaluation.py @@ -395,6 +395,10 @@ def _metric_and_region_loop( """Compute metric results looping over metrics and regions in eval config.""" # Compute derived variables logging.info('Starting _metric_and_region_loop') + logging.info( + f'{len(forecast)} variables, {forecast.sizes=}, {truth.sizes=}, ' + f'({forecast.nbytes + truth.nbytes} bytes)' + ) for name, dv in eval_config.derived_variables.items(): logging.info(f'Logging: derived_variable {name!r}: {dv}') forecast[name] = dv.compute(forecast) @@ -559,7 +563,11 @@ class _EvaluateAllMetrics(beam.PTransform): input_chunks: Chunks to use for input files. skipna: Whether to skip NaN values in both forecasts and observations during evaluation. - fanout: Fanout parameter for Beam combiners. + fanout: Fanout parameter for Beam combiners in the temporal mean. + shuffle_before_temporal_mean: If True, shuffle before computing the temporal + mean. This is a good idea when evaluation metric outputs are small + compared to the size of the input data, such as when aggregating over + space or a large ensemble. num_threads: Number of threads for reading/writing files. """ @@ -569,6 +577,7 @@ class _EvaluateAllMetrics(beam.PTransform): input_chunks: abc.Mapping[str, int] skipna: bool fanout: Optional[int] = None + shuffle_before_temporal_mean: bool = False num_threads: Optional[int] = None def _evaluate_chunk( @@ -724,6 +733,10 @@ def _evaluate( forecast_pipeline |= 'EvaluateChunk' >> beam.MapTuple(self._evaluate_chunk) if self.eval_config.temporal_mean: + if self.shuffle_before_temporal_mean: + # Reshuffle to avoid fusing evaluation of chunks with the temporal mean. + forecast_pipeline |= beam.Reshuffle() + forecast_pipeline |= 'TemporalMean' >> xbeam.Mean( dim='init_time' if self.data_config.by_init else 'time', fanout=self.fanout, @@ -749,6 +762,7 @@ def evaluate_with_beam( input_chunks: abc.Mapping[str, int], runner: str, fanout: Optional[int] = None, + shuffle_before_temporal_mean: bool = False, num_threads: Optional[int] = None, argv: Optional[list[str]] = None, skipna: bool = False, @@ -777,7 +791,11 @@ def evaluate_with_beam( eval_configs: Dictionary of config.Eval instances. input_chunks: Chunking of input datasets. runner: Beam runner. - fanout: Beam CombineFn fanout. + fanout: Fanout parameter for Beam combiners in the temporal mean. + shuffle_before_temporal_mean: If True, shuffle before computing the temporal + mean. This is a good idea when evaluation metric outputs are small + compared to the size of the input data, such as when aggregating over + space or a large ensemble. num_threads: Number of threads to use for reading/writing data. argv: Other arguments to pass into the Beam pipeline. skipna: Whether to skip NaN values in both forecasts and observations during @@ -795,9 +813,10 @@ def evaluate_with_beam( eval_config, data_config, input_chunks, + skipna=skipna, fanout=fanout, + shuffle_before_temporal_mean=shuffle_before_temporal_mean, num_threads=num_threads, - skipna=skipna, ) | f'save_{eval_name}' >> _SaveOutputs(