Skip to content

Commit

Permalink
CESM- and ensemble-compatible rollout
Browse files Browse the repository at this point in the history
  • Loading branch information
dkimpara committed Jan 24, 2025
1 parent b5b4ac6 commit ccf8287
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 25 deletions.
48 changes: 29 additions & 19 deletions applications/rollout_to_netcdf_batcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ def predict(rank, world_size, conf, p):

# batch size
batch_size = conf["predict"].get("batch_size", 1)
ensemble_size = conf["predict"].get("ensemble_size", 1)

# transform and ToTensor class
if conf["data"]["scaler_type"] == "std_new":
Expand Down Expand Up @@ -223,18 +224,25 @@ def predict(rank, world_size, conf, p):
).to(device).float()
else:
x = reshape_only(batch["x"]).to(device).float()

# create ensemble:
if ensemble_size > 1:
x = torch.repeat_interleave(x, ensemble_size, 0)


# Add forcing and static variables for the entire batch
if "x_forcing_static" in batch:
x_forcing_batch = batch["x_forcing_static"].to(device).permute(0, 2, 1, 3, 4).float()
if ensemble_size > 1:
x_forcing_batch = torch.repeat_interleave(x_forcing_batch, ensemble_size, 0)
x = torch.cat((x, x_forcing_batch), dim=1)

# Clamp if needed
if flag_clamp:
x = torch.clamp(x, min=clamp_min, max=clamp_max)

# Model inference on the entire batch
y_pred = model(x)
y_pred = model(x.float())

# Post-processing blocks
if flag_mass_conserve:
Expand All @@ -256,7 +264,7 @@ def predict(rank, world_size, conf, p):

# Transform predictions
y_pred = state_transformer.inverse_transform(y_pred.cpu())

if "use_laplace_filter" in conf["predict"] and conf["predict"]["use_laplace_filter"]:
y_pred = dpf.diff_lap2d_filt(y_pred.to(device).squeeze()).unsqueeze(0).unsqueeze(2).cpu()

Expand All @@ -268,20 +276,27 @@ def predict(rank, world_size, conf, p):

# Convert to xarray and handle results
for j in range(batch_size):
darray_upper_air, darray_single_level = make_xarray(
y_pred[j:j+1], # Process each forecast step
utc_datetimes[j],
latlons.latitude.values,
latlons.longitude.values,
conf,
)
upper_air_list, single_level_list = [], []
for i in range(ensemble_size): # ensemble_size default is 1
darray_upper_air, darray_single_level = make_xarray(
y_pred[j+i:j+i+1], # Process each batch
utc_datetimes[j],
latlons.latitude.values,
latlons.longitude.values,
conf,
)
upper_air_list.append(darray_upper_air)
single_level_list.append(darray_single_level)

if ensemble_size > 1:
pass

# Save the current forecast hour data in parallel
result = p.apply_async(
save_netcdf_increment,
(
darray_upper_air,
darray_single_level,
upper_air_list,
single_level_list,
save_datetimes[forecast_count + j], # Use correct index for current batch item
lead_time_periods * forecast_step,
meta_data,
Expand All @@ -298,21 +313,16 @@ def predict(rank, world_size, conf, p):
# Prepare for next iteration
y_pred = state_transformer.transform_array(y_pred).to(device)

# rollout doesn't draw any more samples so we always need to remove diagnostics
if history_len == 1:
if "y_diag" in batch:
x = y_pred[:, :-varnum_diag, ...].detach()
else:
x = y_pred.detach()
x = y_pred[:, :-varnum_diag, ...].detach()
else:
if static_dim_size == 0:
x_detach = x[:, :, 1:, ...].detach()
else:
x_detach = x[:, :-static_dim_size, 1:, ...].detach()

if "y_diag" in batch:
x = torch.cat([x_detach, y_pred[:, :-varnum_diag, ...].detach()], dim=2)
else:
x = torch.cat([x_detach, y_pred.detach()], dim=2)
x = torch.cat([x_detach, y_pred[:, :-varnum_diag, ...].detach()], dim=2)

# Memory cleanup
torch.cuda.empty_cache()
Expand Down
7 changes: 4 additions & 3 deletions credit/datasets/era5_multistep.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def worker(
# ind_end_time = len(ERA5_subset['time'])

# datetime information as int number (used in some normalization methods)
datetime_as_number = ERA5_subset.time.values.astype("datetime64[s]").astype(int)
datetime_as_number = ERA5_subset.time.astype("datetime64[s]").values.astype(int)

# ==================================================== #
# xarray dataset as input
Expand Down Expand Up @@ -236,11 +236,12 @@ def worker(
# sample['stop_forecast'] = stop_forecast
sample["datetime"] = [
int(
historical_ERA5_images.time.values[0]
historical_ERA5_images.time[0]
.astype("datetime64[s]")
.values
.astype(int)
),
int(target_ERA5_images.time.values[0].astype("datetime64[s]").astype(int)),
int(target_ERA5_images.time[0].astype("datetime64[s]").values.astype(int)),
]

except Exception as e:
Expand Down
11 changes: 8 additions & 3 deletions credit/datasets/era5_predict_batcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -486,8 +486,13 @@ def find_start_stop_indices(self, index):
info = []
for init_time in self.init_time_list_np:
for i_file, ds in enumerate(self.all_files):
# get the year of the current file
ds_year = int(np.datetime_as_string(ds["time"][0].values, unit="Y"))
# get the year of the current file
# looks messy because extra code needed to handle cftime
ds_year = int((
np.datetime_as_string(ds["time"][0]
.astype('datetime64[ns]')
.values,
unit="Y")))

# get the first and last years of init times
init_year0 = nanoseconds_to_year(init_time)
Expand All @@ -497,7 +502,7 @@ def find_start_stop_indices(self, index):
N_times = len(ds["time"])
# convert ds['time'] to a list of nanosecondes
ds_time_list = [
np.datetime64(ds_time.values).astype(datetime)
ds_time.astype('datetime64[ns]').values.astype(datetime)
for ds_time in ds["time"]
]
ds_start_time = ds_time_list[0]
Expand Down

0 comments on commit ccf8287

Please sign in to comment.