Skip to content

Commit

Permalink
Merge pull request #158 from NCAR/vert_interp
Browse files Browse the repository at this point in the history
Fixes to interpolation to pressure levels
  • Loading branch information
djgagne authored Feb 19, 2025
2 parents 8cd8f41 + 156e435 commit 5d1de36
Show file tree
Hide file tree
Showing 16 changed files with 561 additions and 1,831 deletions.
95 changes: 16 additions & 79 deletions applications/rollout_to_netcdf.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,12 @@
from pathlib import Path
from argparse import ArgumentParser
import multiprocessing as mp
from collections import defaultdict

# ---------- #
# Numerics
from datetime import datetime, timedelta
import xarray as xr
import numpy as np
import pandas as pd

# ---------- #
import torch
Expand All @@ -35,7 +33,6 @@
from credit.transforms import load_transforms, Normalize_ERA5_and_Forcing
from credit.pbs import launch_script, launch_script_mpi
from credit.pol_lapdiff_filt import Diffusion_and_Pole_Filter
from credit.metrics import LatWeightedMetrics
from credit.forecast import load_forecasts
from credit.distributed import distributed_model_wrapper, setup
from credit.models.checkpoint import load_model_state, load_state_dict_error_handler
Expand Down Expand Up @@ -214,16 +211,16 @@ def predict(rank, world_size, conf, p):
model = distributed_model_wrapper(conf, model, device)
ckpt = os.path.join(save_loc, "checkpoint.pt")
checkpoint = torch.load(ckpt, map_location=device)
load_msg = model.module.load_state_dict(
checkpoint["model_state_dict"], strict=False
)
load_msg = model.module.load_state_dict(checkpoint["model_state_dict"], strict=False)
load_state_dict_error_handler(load_msg)

elif conf["predict"]["mode"] == "fsdp":
model = load_model(conf, load_weights=True).to(device)
model = distributed_model_wrapper(conf, model, device)
# Load model weights (if any), an optimizer, scheduler, and gradient scaler
model = load_model_state(conf, model, device)
else:
model = None
# ================================================================================ #

model.eval()
Expand All @@ -232,16 +229,8 @@ def predict(rank, world_size, conf, p):
latlons = xr.open_dataset(conf["loss"]["latitude_weights"])

meta_data = load_metadata(conf)

# Set up metrics and containers
metrics = LatWeightedMetrics(conf, predict_mode=True)
metrics_results = defaultdict(list)

# Set up the diffusion and pole filters
if (
"use_laplace_filter" in conf["predict"]
and conf["predict"]["use_laplace_filter"]
):
if "use_laplace_filter" in conf["predict"] and conf["predict"]["use_laplace_filter"]:
dpf = Diffusion_and_Pole_Filter(
nlat=conf["model"]["image_height"],
nlon=conf["model"]["image_width"],
Expand All @@ -268,11 +257,7 @@ def predict(rank, world_size, conf, p):
# combine x and x_surf
# input: (batch_num, time, var, level, lat, lon), (batch_num, time, var, lat, lon)
# output: (batch_num, var, time, lat, lon), 'x' first and then 'x_surf'
x = (
concat_and_reshape(batch["x"], batch["x_surf"])
.to(device)
.float()
)
x = concat_and_reshape(batch["x"], batch["x_surf"]).to(device).float()
else:
# no x_surf
x = reshape_only(batch["x"]).to(device).float()
Expand All @@ -284,9 +269,7 @@ def predict(rank, world_size, conf, p):
# add forcing and static variables (regardless of fcst hours)
if "x_forcing_static" in batch:
# (batch_num, time, var, lat, lon) --> (batch_num, var, time, lat, lon)
x_forcing_batch = (
batch["x_forcing_static"].to(device).permute(0, 2, 1, 3, 4).float()
)
x_forcing_batch = batch["x_forcing_static"].to(device).permute(0, 2, 1, 3, 4).float()

# concat on var dimension
x = torch.cat((x, x_forcing_batch), dim=1)
Expand Down Expand Up @@ -345,32 +328,12 @@ def predict(rank, world_size, conf, p):

# y_pred converted back to physical units via the inverse transform
y_pred = state_transformer.inverse_transform(y_pred.cpu())
# y_target with unit
y = state_transformer.inverse_transform(y.cpu())

if (
"use_laplace_filter" in conf["predict"]
and conf["predict"]["use_laplace_filter"]
):
y_pred = (
dpf.diff_lap2d_filt(y_pred.to(device).squeeze())
.unsqueeze(0)
.unsqueeze(2)
.cpu()
)

# Compute metrics
metrics_dict = metrics(
y_pred.float(), y.float(), forecast_datetime=forecast_hour
)
for k, m in metrics_dict.items():
metrics_results[k].append(m.item())
metrics_results["forecast_hour"].append(forecast_hour)

if "use_laplace_filter" in conf["predict"] and conf["predict"]["use_laplace_filter"]:
y_pred = dpf.diff_lap2d_filt(y_pred.to(device).squeeze()).unsqueeze(0).unsqueeze(2).cpu()

# Save the current forecast hour data in parallel
utc_datetime = init_datetime + timedelta(
hours=lead_time_periods * forecast_hour
)
utc_datetime = init_datetime + timedelta(hours=lead_time_periods * forecast_hour)

# convert the current step's result to xarray objects (upper-air and single-level)
darray_upper_air, darray_single_level = make_xarray(
Expand All @@ -395,13 +358,6 @@ def predict(rank, world_size, conf, p):
)
results.append(result)

metrics_results["datetime"].append(utc_datetime)

print_str = f"Forecast: {forecast_count} "
print_str += f"Date: {utc_datetime.strftime('%Y-%m-%d %H:%M:%S')} "
print_str += f"Hour: {batch['forecast_hour'].item()} "
print_str += f"ACC: {metrics_dict['acc']} "

# Update the input
# setup for next iteration, transform to z-space and send to device
y_pred = state_transformer.transform_array(y_pred).to(device)
Expand All @@ -424,9 +380,7 @@ def predict(rank, world_size, conf, p):

# cut diagnostic vars from y_pred, they are not inputs
if "y_diag" in batch:
x = torch.cat(
[x_detach, y_pred[:, :-varnum_diag, ...].detach()], dim=2
)
x = torch.cat([x_detach, y_pred[:, :-varnum_diag, ...].detach()], dim=2)
else:
x = torch.cat([x_detach, y_pred.detach()], dim=2)
# ============================================================ #
Expand All @@ -439,19 +393,6 @@ def predict(rank, world_size, conf, p):
# Wait for all processes to finish in order
for result in results:
result.get()

# save metrics file
save_location = os.path.join(
os.path.expandvars(conf["save_loc"]), "forecasts", "metrics"
)
os.makedirs(
save_location, exist_ok=True
) # should already be made above
df = pd.DataFrame(metrics_results)
df.to_csv(
os.path.join(save_location, f"metrics{init_datetime_str}.csv")
)

# forecast_count stays constant within one forecast run; advance it for the next run
forecast_count += 1

Expand Down Expand Up @@ -563,17 +504,13 @@ def predict(rank, world_size, conf, p):

# ======================================================== #
# handling config args
conf = credit_main_parser(
conf, parse_training=False, parse_predict=True, print_summary=False
)
conf = credit_main_parser(conf, parse_training=False, parse_predict=True, print_summary=False)
predict_data_check(conf, print_summary=False)
# ======================================================== #

# create a save location for rollout
# ---------------------------------------------------- #
assert (
"save_forecast" in conf["predict"]
), "Please specify the output dir through conf['predict']['save_forecast']"
assert "save_forecast" in conf["predict"], "Please specify the output dir through conf['predict']['save_forecast']"

forecast_save_loc = conf["predict"]["save_forecast"]
os.makedirs(forecast_save_loc, exist_ok=True)
Expand Down Expand Up @@ -627,6 +564,6 @@ def predict(rank, world_size, conf, p):
else: # single device inference
_ = predict(0, 1, conf, p=p)

# Ensure all processes are finished
p.close()
p.join()
# Ensure all processes are finished
p.close()
p.join()
2 changes: 1 addition & 1 deletion credit/VERSION
Original file line number Diff line number Diff line change
@@ -1 +1 @@
2024.1.0
2025.1.0
Loading

0 comments on commit 5d1de36

Please sign in to comment.