Skip to content

Commit

Permalink
Linting
Browse files Browse the repository at this point in the history
  • Loading branch information
FlorianDeconinck committed May 17, 2024
1 parent 480a574 commit 2bfdf47
Showing 1 changed file with 28 additions and 10 deletions.
38 changes: 28 additions & 10 deletions ndsl/stencils/testing/serialbox_to_netcdf.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,19 @@
import argparse
import os
import shutil
import xarray as xr
from typing import Any, Dict, Optional

import f90nml
import numpy as np
from typing import Optional
import xarray as xr


try:
import serialbox
except ModuleNotFoundError:
raise ModuleNotFoundError("Serialbox couldn't be imported, make sure it's in your PYTHONPATH or you env")
raise ModuleNotFoundError(
"Serialbox couldn't be imported, make sure it's in your PYTHONPATH or you env"
)


def get_parser():
Expand All @@ -23,7 +27,10 @@ def get_parser():
"output_path", type=str, help="output directory where netcdf data will be saved"
)
parser.add_argument(
"-dn", "--data_name", type=str, help="[Optional] Give the name of the data, will default to Generator_rankX"
"-dn",
"--data_name",
type=str,
help="[Optional] Give the name of the data, will default to Generator_rankX",
)
return parser

Expand All @@ -43,7 +50,7 @@ def get_all_savepoint_names(serializer):
return savepoint_names


def get_serializer(data_path: str, rank:int , data_name:Optional[str] = None):
def get_serializer(data_path: str, rank: int, data_name: Optional[str] = None):
if data_name:
name = data_name
else:
Expand Down Expand Up @@ -73,14 +80,16 @@ def main(data_path: str, output_path: str, data_name: Optional[str] = None):
for savepoint_name in sorted(list(savepoint_names)):
rank_list = []
names_list = list(
serializer_0.fields_at_savepoint(serializer_0.get_savepoint(savepoint_name)[0])
serializer_0.fields_at_savepoint(
serializer_0.get_savepoint(savepoint_name)[0]
)
)
serializer_list = []
for rank in range(total_ranks):
serializer = get_serializer(data_path, rank, data_name)
serializer_list.append(serializer)
savepoints = serializer.get_savepoint(savepoint_name)
rank_data = {}
rank_data: Dict[str, Any] = {}
for name in set(names_list):
rank_data[name] = []
for savepoint in savepoints:
Expand All @@ -94,7 +103,12 @@ def main(data_path: str, output_path: str, data_name: Optional[str] = None):
encoding = {}
for varname in set(names_list).difference(["rank"]):
data_shape = list(rank_list[0][varname][0].shape)
if savepoint_name in ["FVDynamics-In", "FVDynamics-Out", "Driver-In", "Driver-Out"]:
if savepoint_name in [
"FVDynamics-In",
"FVDynamics-Out",
"Driver-In",
"Driver-Out",
]:
if varname in [
"qvapor",
"qliquid",
Expand Down Expand Up @@ -138,10 +152,14 @@ def get_data(data_shape, total_ranks, n_savepoints, output_list, varname):
data[i_savepoint, rank] = output_list[rank][varname][i_savepoint]
return data


def entry_point():
parser = get_parser()
args = parser.parse_args()
main(data_path=args.data_path, output_path=args.output_path, data_name=args.data_name)
main(
data_path=args.data_path, output_path=args.output_path, data_name=args.data_name
)


if __name__ == "__main__":
entry_point()
entry_point()

0 comments on commit 2bfdf47

Please sign in to comment.