diff --git a/ndsl/stencils/testing/serialbox_to_netcdf.py b/ndsl/stencils/testing/serialbox_to_netcdf.py index bee0a789..453a861b 100644 --- a/ndsl/stencils/testing/serialbox_to_netcdf.py +++ b/ndsl/stencils/testing/serialbox_to_netcdf.py @@ -1,15 +1,19 @@ import argparse import os import shutil -import xarray as xr +from typing import Any, Dict, Optional + import f90nml import numpy as np -from typing import Optional +import xarray as xr + try: import serialbox except ModuleNotFoundError: - raise ModuleNotFoundError("Serialbox couldn't be imported, make sure it's in your PYTHONPATH or you env") + raise ModuleNotFoundError( + "Serialbox couldn't be imported, make sure it's in your PYTHONPATH or you env" + ) def get_parser(): @@ -23,7 +27,10 @@ def get_parser(): "output_path", type=str, help="output directory where netcdf data will be saved" ) parser.add_argument( - "-dn", "--data_name", type=str, help="[Optional] Give the name of the data, will default to Generator_rankX" + "-dn", + "--data_name", + type=str, + help="[Optional] Give the name of the data, will default to Generator_rankX", ) return parser @@ -43,7 +50,7 @@ def get_all_savepoint_names(serializer): return savepoint_names -def get_serializer(data_path: str, rank:int , data_name:Optional[str] = None): +def get_serializer(data_path: str, rank: int, data_name: Optional[str] = None): if data_name: name = data_name else: @@ -73,14 +80,16 @@ def main(data_path: str, output_path: str, data_name: Optional[str] = None): for savepoint_name in sorted(list(savepoint_names)): rank_list = [] names_list = list( - serializer_0.fields_at_savepoint(serializer_0.get_savepoint(savepoint_name)[0]) + serializer_0.fields_at_savepoint( + serializer_0.get_savepoint(savepoint_name)[0] + ) ) serializer_list = [] for rank in range(total_ranks): serializer = get_serializer(data_path, rank, data_name) serializer_list.append(serializer) savepoints = serializer.get_savepoint(savepoint_name) - rank_data = {} + rank_data: Dict[str, Any] = {} for name in set(names_list): rank_data[name] = [] for savepoint in savepoints: @@ -94,7 +103,12 @@ def main(data_path: str, output_path: str, data_name: Optional[str] = None): encoding = {} for varname in set(names_list).difference(["rank"]): data_shape = list(rank_list[0][varname][0].shape) - if savepoint_name in ["FVDynamics-In", "FVDynamics-Out", "Driver-In", "Driver-Out"]: + if savepoint_name in [ + "FVDynamics-In", + "FVDynamics-Out", + "Driver-In", + "Driver-Out", + ]: if varname in [ "qvapor", "qliquid", @@ -138,10 +152,14 @@ def get_data(data_shape, total_ranks, n_savepoints, output_list, varname): data[i_savepoint, rank] = output_list[rank][varname][i_savepoint] return data + def entry_point(): parser = get_parser() args = parser.parse_args() - main(data_path=args.data_path, output_path=args.output_path, data_name=args.data_name) + main( + data_path=args.data_path, output_path=args.output_path, data_name=args.data_name + ) + if __name__ == "__main__": - entry_point() \ No newline at end of file + entry_point()