Fixes to interpolation to pressure levels #158

Merged 13 commits on Feb 19, 2025
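Only the rollout-script and version-file hunks appear below. As background on the PR title, here is a minimal sketch of interpolating a model-level field to fixed pressure levels, linear in log-pressure; the function name, signature, and method choice are illustrative assumptions, not this PR's code:

```python
import numpy as np

def to_pressure_levels(field, p_model, p_target):
    """Hypothetical helper, not from this PR.
    field: (nlev, nlat, nlon); p_model: (nlev,) increasing, in Pa;
    p_target: (ntarget,) in Pa. Returns (ntarget, nlat, nlon)."""
    logp, logt = np.log(p_model), np.log(p_target)
    out = np.empty((len(p_target),) + field.shape[1:], dtype=field.dtype)
    for k, lt in enumerate(logt):
        # bracket the target level; clipping makes out-of-range targets
        # extrapolate linearly from the nearest interval
        i = np.clip(np.searchsorted(logp, lt), 1, len(logp) - 1)
        w = (lt - logp[i - 1]) / (logp[i] - logp[i - 1])
        out[k] = (1.0 - w) * field[i - 1] + w * field[i]
    return out

# usage: interpolate a wind field to 500 and 850 hPa
# u_plev = to_pressure_levels(u, p_model, np.array([50000.0, 85000.0]))
```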
95 changes: 16 additions & 79 deletions applications/rollout_to_netcdf.py
@@ -8,14 +8,12 @@
from pathlib import Path
from argparse import ArgumentParser
import multiprocessing as mp
from collections import defaultdict

# ---------- #
# Numerics
from datetime import datetime, timedelta
import xarray as xr
import numpy as np
import pandas as pd

# ---------- #
import torch
@@ -35,7 +33,6 @@
from credit.transforms import load_transforms, Normalize_ERA5_and_Forcing
from credit.pbs import launch_script, launch_script_mpi
from credit.pol_lapdiff_filt import Diffusion_and_Pole_Filter
from credit.metrics import LatWeightedMetrics
from credit.forecast import load_forecasts
from credit.distributed import distributed_model_wrapper, setup
from credit.models.checkpoint import load_model_state, load_state_dict_error_handler
@@ -214,16 +211,16 @@ def predict(rank, world_size, conf, p):
model = distributed_model_wrapper(conf, model, device)
ckpt = os.path.join(save_loc, "checkpoint.pt")
checkpoint = torch.load(ckpt, map_location=device)
load_msg = model.module.load_state_dict(
checkpoint["model_state_dict"], strict=False
)
load_msg = model.module.load_state_dict(checkpoint["model_state_dict"], strict=False)
load_state_dict_error_handler(load_msg)

elif conf["predict"]["mode"] == "fsdp":
model = load_model(conf, load_weights=True).to(device)
model = distributed_model_wrapper(conf, model, device)
# Load model weights (if any), an optimizer, scheduler, and gradient scaler
model = load_model_state(conf, model, device)
else:
model = None
# ================================================================================ #

model.eval()
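For context on the non-strict load in the hunk above: in PyTorch, `load_state_dict(..., strict=False)` returns an `IncompatibleKeys` named tuple, which a handler such as `load_state_dict_error_handler` (whose actual behavior lives in `credit.models.checkpoint`) can inspect. A hedged sketch of that pattern:

```python
import torch

def check_load_result(load_msg):
    # load_msg.missing_keys: parameters the model expects but the ckpt lacks
    # load_msg.unexpected_keys: ckpt entries the model has no parameter for
    if load_msg.missing_keys:
        print(f"WARNING: missing keys: {load_msg.missing_keys}")
    if load_msg.unexpected_keys:
        print(f"WARNING: unexpected keys: {load_msg.unexpected_keys}")

# usage (assuming a DDP-wrapped `model` and a checkpoint path `ckpt`):
# checkpoint = torch.load(ckpt, map_location=device)
# msg = model.module.load_state_dict(checkpoint["model_state_dict"], strict=False)
# check_load_result(msg)
```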
@@ -232,16 +229,8 @@
latlons = xr.open_dataset(conf["loss"]["latitude_weights"])

meta_data = load_metadata(conf)

# Set up metrics and containers
metrics = LatWeightedMetrics(conf, predict_mode=True)
metrics_results = defaultdict(list)

# Set up the diffusion and pole filters
if (
"use_laplace_filter" in conf["predict"]
and conf["predict"]["use_laplace_filter"]
):
if "use_laplace_filter" in conf["predict"] and conf["predict"]["use_laplace_filter"]:
dpf = Diffusion_and_Pole_Filter(
nlat=conf["model"]["image_height"],
nlon=conf["model"]["image_width"],
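An aside on the guard rewritten above (the `Diffusion_and_Pole_Filter` call continues in lines hidden by the collapsed diff): the `"key" in conf and conf[key]` pattern collapses to a single `dict.get` with a default. A minimal, self-contained illustration, not a change made in this PR:

```python
# .get() replaces the "key present AND truthy" two-step check
conf = {"predict": {}}  # example config without the optional key
use_laplace_filter = conf["predict"].get("use_laplace_filter", False)
assert use_laplace_filter is False
```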
@@ -268,11 +257,7 @@
# combine x and x_surf
# input: (batch_num, time, var, level, lat, lon), (batch_num, time, var, lat, lon)
# output: (batch_num, var, time, lat, lon), 'x' first and then 'x_surf'
x = (
concat_and_reshape(batch["x"], batch["x_surf"])
.to(device)
.float()
)
x = concat_and_reshape(batch["x"], batch["x_surf"]).to(device).float()
else:
# no x_surf
x = reshape_only(batch["x"]).to(device).float()
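A hedged sketch of what the `concat_and_reshape` helper used above plausibly does, inferred only from the shape comments in this hunk; the real implementation may differ:

```python
import torch

def concat_and_reshape_sketch(x, x_surf):
    # x:      (batch, time, var, level, lat, lon)  upper-air variables
    # x_surf: (batch, time, var_surf, lat, lon)    surface variables
    b, t, v, l, h, w = x.shape
    x = x.reshape(b, t, v * l, h, w)           # fold levels into the var dim
    merged = torch.cat([x, x_surf], dim=2)     # 'x' first, then 'x_surf'
    return merged.permute(0, 2, 1, 3, 4)       # -> (batch, var, time, lat, lon)

xs = concat_and_reshape_sketch(torch.zeros(1, 2, 4, 15, 8, 16),
                               torch.zeros(1, 2, 3, 8, 16))
assert xs.shape == (1, 4 * 15 + 3, 2, 8, 16)
```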
@@ -284,9 +269,7 @@
# add forcing and static variables (regardless of fcst hours)
if "x_forcing_static" in batch:
# (batch_num, time, var, lat, lon) --> (batch_num, var, time, lat, lon)
x_forcing_batch = (
batch["x_forcing_static"].to(device).permute(0, 2, 1, 3, 4).float()
)
x_forcing_batch = batch["x_forcing_static"].to(device).permute(0, 2, 1, 3, 4).float()

# concat on var dimension
x = torch.cat((x, x_forcing_batch), dim=1)
@@ -345,32 +328,12 @@

# y_pred with unit
y_pred = state_transformer.inverse_transform(y_pred.cpu())
# y_target with unit
y = state_transformer.inverse_transform(y.cpu())

if (
"use_laplace_filter" in conf["predict"]
and conf["predict"]["use_laplace_filter"]
):
y_pred = (
dpf.diff_lap2d_filt(y_pred.to(device).squeeze())
.unsqueeze(0)
.unsqueeze(2)
.cpu()
)

# Compute metrics
metrics_dict = metrics(
y_pred.float(), y.float(), forecast_datetime=forecast_hour
)
for k, m in metrics_dict.items():
metrics_results[k].append(m.item())
metrics_results["forecast_hour"].append(forecast_hour)

if "use_laplace_filter" in conf["predict"] and conf["predict"]["use_laplace_filter"]:
y_pred = dpf.diff_lap2d_filt(y_pred.to(device).squeeze()).unsqueeze(0).unsqueeze(2).cpu()

# Save the current forecast hour data in parallel
utc_datetime = init_datetime + timedelta(
hours=lead_time_periods * forecast_hour
)
utc_datetime = init_datetime + timedelta(hours=lead_time_periods * forecast_hour)

# convert the current step result as x-array
darray_upper_air, darray_single_level = make_xarray(
@@ -395,13 +358,6 @@
)
results.append(result)

metrics_results["datetime"].append(utc_datetime)

print_str = f"Forecast: {forecast_count} "
print_str += f"Date: {utc_datetime.strftime('%Y-%m-%d %H:%M:%S')} "
print_str += f"Hour: {batch['forecast_hour'].item()} "
print_str += f"ACC: {metrics_dict['acc']} "

# Update the input
# setup for next iteration, transform to z-space and send to device
y_pred = state_transformer.transform_array(y_pred).to(device)
@@ -424,9 +380,7 @@

# cut diagnostic vars from y_pred, they are not inputs
if "y_diag" in batch:
x = torch.cat(
[x_detach, y_pred[:, :-varnum_diag, ...].detach()], dim=2
)
x = torch.cat([x_detach, y_pred[:, :-varnum_diag, ...].detach()], dim=2)
else:
x = torch.cat([x_detach, y_pred.detach()], dim=2)
# ============================================================ #
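To make the rollout state update above concrete: diagnostic variables are model outputs only, so the last `varnum_diag` channels are sliced off the variable axis (dim 1) before the prediction is appended along time (dim 2) as input for the next step. A small runnable illustration with made-up sizes:

```python
import torch

varnum_diag = 2
y_pred = torch.zeros(1, 10, 1, 8, 16)    # (batch, var, time=1, lat, lon)
x_detach = torch.zeros(1, 8, 1, 8, 16)   # earlier input slab, no diagnostics
# drop diagnostic channels, then append the new step on the time axis
x_next = torch.cat([x_detach, y_pred[:, :-varnum_diag, ...]], dim=2)
assert x_next.shape == (1, 8, 2, 8, 16)
```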
@@ -439,19 +393,6 @@
# Wait for all processes to finish in order
for result in results:
result.get()

# save metrics file
save_location = os.path.join(
os.path.expandvars(conf["save_loc"]), "forecasts", "metrics"
)
os.makedirs(
save_location, exist_ok=True
) # should already be made above
df = pd.DataFrame(metrics_results)
df.to_csv(
os.path.join(save_location, f"metrics{init_datetime_str}.csv")
)

# forecast count = a constant for each run
forecast_count += 1
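The `results.append(result)` / `result.get()` pairing in this hunk is the usual fire-and-collect `multiprocessing` pattern: submit each forecast-hour write asynchronously, keep the `AsyncResult` handles, then block on them so the run cannot exit before the files land. A self-contained sketch with a stand-in writer:

```python
import multiprocessing as mp

def save_netcdf_stub(step):
    # stand-in for the real NetCDF writer used by the rollout script
    return f"saved step {step}"

if __name__ == "__main__":
    with mp.Pool(4) as p:
        results = [p.apply_async(save_netcdf_stub, (h,)) for h in range(6)]
        for r in results:      # blocks until each write completes, in order
            print(r.get())
```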

@@ -563,17 +504,13 @@ def predict(rank, world_size, conf, p):

# ======================================================== #
# handling config args
conf = credit_main_parser(
conf, parse_training=False, parse_predict=True, print_summary=False
)
conf = credit_main_parser(conf, parse_training=False, parse_predict=True, print_summary=False)
predict_data_check(conf, print_summary=False)
# ======================================================== #

# create a save location for rollout
# ---------------------------------------------------- #
assert (
"save_forecast" in conf["predict"]
), "Please specify the output dir through conf['predict']['save_forecast']"
assert "save_forecast" in conf["predict"], "Please specify the output dir through conf['predict']['save_forecast']"

forecast_save_loc = conf["predict"]["save_forecast"]
os.makedirs(forecast_save_loc, exist_ok=True)
@@ -627,6 +564,6 @@ def predict(rank, world_size, conf, p):
else: # single device inference
_ = predict(0, 1, conf, p=p)

# Ensure all processes are finished
p.close()
p.join()
2 changes: 1 addition & 1 deletion credit/VERSION
@@ -1 +1 @@
2024.1.0
2025.1.0