From 58c854cdab6ef75a22448b4f2d19c3c2c04bfe8b Mon Sep 17 00:00:00 2001 From: David John Gagne Date: Mon, 10 Feb 2025 17:51:30 -0700 Subject: [PATCH 01/13] Added more detailed pressure interpolation --- credit/interp.py | 264 ++++++++++++++++++++++++++++++++++++++++++++--- credit/output.py | 6 +- 2 files changed, 252 insertions(+), 18 deletions(-) diff --git a/credit/interp.py b/credit/interp.py index 433d05eb..ef7fd076 100644 --- a/credit/interp.py +++ b/credit/interp.py @@ -2,19 +2,19 @@ from numba import njit import xarray as xr from tqdm import tqdm -from .physics_constants import RDGAS, RVGAS +from .physics_constants import RDGAS, RVGAS, GRAVITY import os def full_state_pressure_interpolation( state_dataset: xr.Dataset, + surface_geopotential: np.ndarray, pressure_levels: np.ndarray = np.array([500.0, 850.0]), interp_fields: tuple[str] = ("U", "V", "T", "Q"), pres_ending: str = "_PRES", temperature_var: str = "T", q_var: str = "Q", surface_pressure_var: str = "SP", - surface_geopotential_var: str = "Z_GDS4_SFC", geopotential_var: str = "Z", time_var: str = "time", lat_var: str = "latitude", @@ -23,19 +23,21 @@ def full_state_pressure_interpolation( level_var: str = "level", model_level_file: str = "../credit/metadata/ERA5_Lev_Info.nc", verbose: int = 1, + a_coord: str = "a_model", + b_coord: str = "b_model", ) -> xr.Dataset: """ Interpolate full model state variables from model levels to pressure levels. Args: state_dataset (xr.Dataset): state variables being interpolated + surface_geopotential (np.ndarray): surface geopotential field in units m^2/s^2. pressure_levels (np.ndarray): pressure levels for interpolation in hPa. interp_fields (tuple[str]): fields to be interpolated. pres_ending (str): ending string to attach to pressure interpolated variables. temperature_var (str): temperature variable to be interpolated (units K). q_var (str): mixing ratio/specific humidity variable to be interpolated (units kg/kg). surface_pressure_var (str): surface pressure variable (units Pa). - surface_geopotential_var (str): surface geoptential variable (units m^2/s^2). geopotential_var (str): geopotential variable being derived (units m^2/s^2). time_var (str): time coordinate lat_var (str): latitude coordinate @@ -44,14 +46,16 @@ def full_state_pressure_interpolation( level_var (str): name of level coordinate model_level_file (str): relative path to file containing model levels. verbose (int): verbosity level. If verbose > 0, print progress. + a_coord (str): Name of A weight in sigma coordinate formula. 'a_model' by default. + b_coord (str): Name of B weight in sigma coordinate formula. 'b_model' by default. Returns: pressure_ds (xr.Dataset): Dataset containing pressure interpolated variables.
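+ + Example: + A minimal sketch of a call (the file names here are hypothetical; any dataset holding the default variable names above will work): + + import numpy as np + import xarray as xr + + state_ds = xr.open_dataset("model_state.nc") + z_surf = xr.open_dataset("static.nc")["Z_GDS4_SFC"].values + pres_ds = full_state_pressure_interpolation( + state_ds, z_surf, pressure_levels=np.array([500.0, 850.0]) + )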
""" path_to_file = os.path.abspath(os.path.dirname(__file__)) model_level_file = os.path.join(path_to_file, model_level_file) with xr.open_dataset(model_level_file) as mod_lev_ds: - model_a = mod_lev_ds["a_model"].loc[state_dataset[level_var]].values - model_b = mod_lev_ds["b_model"].loc[state_dataset[level_var]].values + model_a = mod_lev_ds[a_coord].loc[state_dataset[level_var]].values + model_b = mod_lev_ds[b_coord].loc[state_dataset[level_var]].values pres_dims = (time_var, pres_var, lat_var, lon_var) coords = { time_var: state_dataset[time_var], @@ -82,7 +86,7 @@ def full_state_pressure_interpolation( state_dataset[surface_pressure_var][t].values, model_a, model_b ) geopotential_grid = geopotential_from_model_vars( - state_dataset[surface_geopotential_var][t].values, + surface_geopotential, state_dataset[surface_pressure_var][t].values, state_dataset[temperature_var][t].values, state_dataset[q_var][t].values, @@ -90,16 +94,33 @@ def full_state_pressure_interpolation( model_b, ) for interp_field in interp_fields: - pressure_ds[interp_field + pres_ending][t] = ( - interp_hybrid_to_pressure_levels( - state_dataset[interp_field][t].values, - pressure_grid / 100.0, - pressure_levels, + if interp_field == temperature_var: + pressure_ds[interp_field + pres_ending][t] = ( + interp_temperature_to_pressure_levels( + state_dataset[interp_field][t].values, + pressure_grid / 100.0, + pressure_levels, + state_dataset[surface_pressure_var][t].values / 100.0, + surface_geopotential, + state_dataset[temperature_var][t, -1].values, + ) + ) + else: + pressure_ds[interp_field + pres_ending][t] = ( + interp_hybrid_to_pressure_levels( + state_dataset[interp_field][t].values, + pressure_grid / 100.0, + pressure_levels, + ) ) - ) pressure_ds[geopotential_var + pres_ending][t] = ( - interp_hybrid_to_pressure_levels( - geopotential_grid, pressure_grid / 100.0, pressure_levels + interp_geopotential_to_pressure_levels( + geopotential_grid, + pressure_grid / 100.0, + pressure_levels, + state_dataset[surface_pressure_var][t].values / 100.0, + surface_geopotential, + state_dataset[temperature_var][t, -1].values, ) ) return pressure_ds @@ -171,6 +192,156 @@ def interp_hybrid_to_pressure_levels(model_var, model_pressure, interp_pressures return pressure_var +@njit +def interp_pressure_to_hybrid_levels( + pressure_var, pressure_levels, model_pressure, surface_pressure +): + """ + Interpolate data field from hybrid sigma-pressure vertical coordinates to pressure levels. + `model_pressure` and `pressure_levels` and 'surface_pressure' should have consistent units with each other. + + Args: + pressure_var (np.ndarray): 3D field on pressure levels with shape (levels, y, x). + pressure_levels (np.double): pressure levels for interpolation in units Pa or hPa. + model_pressure (np.ndarray): 3D pressure field with shape (levels, y, x) in units Pa or hPa + surface_pressure (np.ndarray): pressure at the surface in units Pa or hPa. + + Returns: + model_var (np.ndarray): 3D field on hybrid sigma-pressure levels with shape (model_pressure.shape[0], y, x). 
+ """ + model_var = np.zeros(model_pressure.shape, dtype=model_pressure.dtype) + log_interp_pressures = np.log(pressure_levels) + for (i, j), v in np.ndenumerate(model_var[0]): + air_levels = np.where(pressure_levels < surface_pressure[i, j])[0] + model_var[:, i, j] = np.interp( + np.log(model_pressure[:, i, j]), + log_interp_pressures[air_levels], + pressure_var[air_levels, i, j], + ) + return pressure_var + + +@njit +def interp_geopotential_to_pressure_levels( + model_var, + model_pressure, + interp_pressures, + surface_pressure, + surface_geopotential, + temperature_lowest_level_k, +): + """ + Interpolate geopotential field from hybrid sigma-pressure vertical coordinates to pressure levels. + `model_pressure` and `interp_pressure` should have consistent units of hPa or Pa. Geopotential height is extrapolated + below the surface based on Eq. 15 in Trenberth et al. (1993). + + Args: + model_var (np.ndarray): 3D field on hybrid sigma-pressure levels with shape (levels, y, x). + model_pressure (np.ndarray): 3D pressure field with shape (levels, y, x) in units Pa or hPa + interp_pressures (np.ndarray): pressure levels for interpolation in units Pa or hPa. + surface_pressure (np.ndarray): pressure at the surface in units Pa or hPa. + surface_geopotential (np.ndarray): geopotential at the surface in units m^2/s^2. + temperature_lowest_level_k (np.ndarray): lowest model level temperature in Kelvin. + + Returns: + pressure_var (np.ndarray): 3D field on pressure levels with shape (len(interp_pressures), y, x). + """ + LAPSE_RATE = 0.0065 # K / m + ALPHA = LAPSE_RATE * RDGAS / GRAVITY + pressure_var = np.zeros( + (interp_pressures.shape[0], model_var.shape[1], model_var.shape[2]), + dtype=model_var.dtype, + ) + log_interp_pressures = np.log(interp_pressures) + for (i, j), v in np.ndenumerate(model_var[0]): + pressure_var[:, i, j] = np.interp( + log_interp_pressures, np.log(model_pressure[:, i, j]), model_var[:, i, j] + ) + for pl, interp_pressure in enumerate(interp_pressures): + if interp_pressure > surface_pressure[i, j]: + temp_surface_k = temperature_lowest_level_k[ + i, j + ] + ALPHA * temperature_lowest_level_k[i, j] * ( + surface_pressure[i, j] / model_pressure[-1, i, j] - 1 + ) + ln_p_ps = np.log(interp_pressure / surface_pressure[i, j]) + pressure_var[pl, i, j] = surface_geopotential[ + i, j + ] - RDGAS * temp_surface_k * ln_p_ps * ( + 1 + 0.5 * ALPHA * ln_p_ps + 1 / 6.0 * (ALPHA * ln_p_ps) ** 2 + ) + return pressure_var + + +@njit +def interp_temperature_to_pressure_levels( + model_var, + model_pressure, + interp_pressures, + surface_pressure, + surface_geopotential, + temperature_lowest_level_k, +): + """ + Interpolate temperature field from hybrid sigma-pressure vertical coordinates to pressure levels. + `model_pressure` and `interp_pressure` should have consistent units of hPa or Pa. Temperature is extrapolated + below the surface based on Eq. 16 in Trenberth et al. (1993). + + Args: + model_var (np.ndarray): 3D field on hybrid sigma-pressure levels with shape (levels, y, x). + model_pressure (np.ndarray): 3D pressure field with shape (levels, y, x) in units Pa + interp_pressures: (np.ndarray): pressure levels for interpolation in units Pa or. + surface_pressure (np.ndarray): pressure at the surface in units Pa or hPa. + surface_geopotential (np.ndarray): geopotential at the surface in units m^2/s^2. + temperature_lowest_level_k (np.ndarray): lowest model level temperature in Kelvin. 
+ + Returns: + pressure_var (np.ndarray): 3D field on pressure levels with shape (len(interp_pressures), y, x). + """ + LAPSE_RATE = 0.0065 # K / m + ALPHA = LAPSE_RATE * RDGAS / GRAVITY + pressure_var = np.zeros( + (interp_pressures.shape[0], model_var.shape[1], model_var.shape[2]), + dtype=model_var.dtype, + ) + log_interp_pressures = np.log(interp_pressures) + for (i, j), v in np.ndenumerate(model_var[0]): + pressure_var[:, i, j] = np.interp( + log_interp_pressures, np.log(model_pressure[:, i, j]), model_var[:, i, j] + ) + for pl, interp_pressure in enumerate(interp_pressures): + if interp_pressure > surface_pressure[i, j]: + temp_surface_k = temperature_lowest_level_k[ + i, j + ] + ALPHA * temperature_lowest_level_k[i, j] * ( + surface_pressure[i, j] / model_pressure[-1, i, j] - 1 + ) + surface_height = surface_geopotential[i, j] / GRAVITY + temp_sea_level_k = temp_surface_k + LAPSE_RATE * surface_height + temp_pl = np.minimum(temp_surface_k, 298.0) + if surface_height > 2500.0: + a_adjusted = ( + RDGAS * (temp_pl - temp_surface_k) / surface_geopotential[i, j] + ) + elif 2000.0 <= surface_height <= 2500.0: + t_adjusted = 0.002 * ( + (2500 - surface_height) * temp_sea_level_k + + (surface_height - 2000.0) * temp_pl + ) + a_adjusted = ( + RDGAS + * (t_adjusted - temp_surface_k) + / surface_geopotential[i, j] + ) + else: + a_adjusted = ALPHA + a_ln_p = a_adjusted * np.log(interp_pressure / surface_pressure[i, j]) + pressure_var[pl, i, j] = temp_surface_k * ( + 1 + a_ln_p + 0.5 * a_ln_p**2 + 1 / 6.0 * a_ln_p**3 + ) + return pressure_var + + @njit def geopotential_from_model_vars( surface_geopotential, surface_pressure, temperature, mixing_ratio, model_a, model_b ): """ Calculate geopotential from the base state variables. @@ -219,3 +390,68 @@ def geopotential_from_model_vars( ] * np.log(half_pressure[h] / model_pressure[m]) h -= 1 return model_geopotential + + +@njit +def mean_sea_level_pressure( + surface_pressure_pa, + temperature_lowest_level_k, + pressure_lowest_level_pa, + surface_geopotential, +): + """ + Calculate mean sea level pressure from surface pressure, lowest model level temperature, + the pressure of the lowest model level (derived from create_pressure_grid), and surface_geopotential. + This calculation is based on the procedure from Trenberth et al. (1993) implemented in CESM CAM. + + Trenberth, K., J. Berry, and L. Buja, 1993: Vertical Interpolation and Truncation of Model-Coordinate Data, + University Corporation for Atmospheric Research, https://doi.org/10.5065/D6HX19NH. + + CAM implementation: https://github.com/ESCOMP/CAM/blob/cam_cesm2_2_rel/src/physics/cam/cpslec.F90 + + Args: + surface_pressure_pa: surface pressure in Pascals + temperature_lowest_level_k: Temperature at the lowest model level in Kelvin. + pressure_lowest_level_pa: Pressure at the lowest model level in Pascals. + surface_geopotential: Geopotential of the surface in m^2 s^-2. + + Returns: + mslp: Mean sea level pressure in Pascals.
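+ + Example: + A minimal sketch with synthetic inputs (shapes and values are illustrative only): + + sp = np.full((181, 360), 95000.0) + t_low = np.full((181, 360), 285.0) + p_low = 0.99 * sp + z_surf = np.full((181, 360), 5000.0) # about 510 m of surface elevation + mslp = mean_sea_level_pressure(sp, t_low, p_low, z_surf)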
+ """ + LAPSE_RATE = 0.0065 # K / m + ALPHA = LAPSE_RATE * RDGAS / GRAVITY + mslp = np.zeros(surface_pressure_pa.shape, dtype=surface_pressure_pa.dtype) + for (i, j), p in np.ndenumerate(mslp): + if np.abs(surface_geopotential[i, j] / GRAVITY) < 1e-4: + mslp[i, j] = surface_pressure_pa[i, j] + else: + temp_surface_k = temperature_lowest_level_k[ + i, j + ] + ALPHA * temperature_lowest_level_k[i, j] * ( + surface_pressure_pa[i, j] / pressure_lowest_level_pa[i, j] - 1 + ) + temp_sealevel_k = ( + temp_surface_k + LAPSE_RATE * surface_geopotential[i, j] / GRAVITY + ) + + if (temp_surface_k <= 290.5) and (temp_sealevel_k > 290.5): + alpha_adjusted = ( + RDGAS / surface_geopotential[i, j] * (290.5 - temp_surface_k) + ) + elif (temp_surface_k > 290.5) and (temp_sealevel_k > 290.5): + alpha_adjusted = 0.0 + temp_surface_k = 0.5 * (290.5 + temp_surface_k) + else: + alpha_adjusted = ALPHA + if temp_surface_k < 255: + temp_surface_k = 0.5 * (255.0 + temp_surface_k) + beta = surface_geopotential[i, j] / (RDGAS * temp_surface_k) + mslp[i, j] = surface_pressure_pa[i, j] * np.exp( + beta + * ( + 1 + - alpha_adjusted * beta / 2.0 + + ((alpha_adjusted * beta) ** 2) / 3.0 + ) + ) + return mslp diff --git a/credit/output.py b/credit/output.py index dfcf3411..fca063be 100644 --- a/credit/output.py +++ b/credit/output.py @@ -184,11 +184,9 @@ def save_netcdf_increment( else: surface_geopotential_var = "Z_GDS4_SFC" with xr.open_dataset(conf["data"]["save_loc_static"]) as static_ds: - ds_merged[surface_geopotential_var] = static_ds[ - surface_geopotential_var - ] + surface_geopotential = static_ds[surface_geopotential_var].values pressure_interp = full_state_pressure_interpolation( - ds_merged, **conf["predict"]["interp_pressure"] + ds_merged, surface_geopotential, **conf["predict"]["interp_pressure"] ) ds_merged = xr.merge([ds_merged, pressure_interp]) From 8682795048ffa5c2473dcc092728d1075c2b4003 Mon Sep 17 00:00:00 2001 From: David John Gagne Date: Mon, 10 Feb 2025 18:20:42 -0700 Subject: [PATCH 02/13] Updated test --- tests/test_interp.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tests/test_interp.py b/tests/test_interp.py index 8ad9dd35..7a35eb49 100644 --- a/tests/test_interp.py +++ b/tests/test_interp.py @@ -10,7 +10,11 @@ def test_full_state_pressure_interpolation(): ds = xr.open_dataset(input_file) pressure_levels = np.array([200.0, 500.0, 700.0, 850.0, 1000.0]) interp_ds = full_state_pressure_interpolation( - ds, pressure_levels=pressure_levels, lat_var="lat", lon_var="lon" + ds, + ds["Z_GDS4_SFC"].values, + pressure_levels=pressure_levels, + lat_var="lat", + lon_var="lon", ) for var in ["U", "V", "T", "Q"]: assert ( From 995a017e78f7a252af1669d1df6371c118ecfb53 Mon Sep 17 00:00:00 2001 From: David John Gagne Date: Tue, 11 Feb 2025 23:04:06 -0700 Subject: [PATCH 03/13] Fixed issues with mslp and temperature interp, still bigger differences with geopotential. --- credit/interp.py | 68 ++++++++++++++++++++++++++++-------------------- 1 file changed, 40 insertions(+), 28 deletions(-) diff --git a/credit/interp.py b/credit/interp.py index ef7fd076..b41f557b 100644 --- a/credit/interp.py +++ b/credit/interp.py @@ -229,6 +229,7 @@ def interp_geopotential_to_pressure_levels( surface_pressure, surface_geopotential, temperature_lowest_level_k, + temp_level_index=-6, ): """ Interpolate geopotential field from hybrid sigma-pressure vertical coordinates to pressure levels. 
@@ -242,7 +243,8 @@ surface_pressure (np.ndarray): pressure at the surface in units Pa or hPa. surface_geopotential (np.ndarray): geopotential at the surface in units m^2/s^2. temperature_lowest_level_k (np.ndarray): lowest model level temperature in Kelvin. - + temp_level_index (int): index of vertical level where temperature is extracted for extrapolation to + surface. Returns: pressure_var (np.ndarray): 3D field on pressure levels with shape (len(interp_pressures), y, x). """ LAPSE_RATE = 0.0065 # K / m ALPHA = LAPSE_RATE * RDGAS / GRAVITY pressure_var = np.zeros( (interp_pressures.shape[0], model_var.shape[1], model_var.shape[2]), dtype=model_var.dtype, ) log_interp_pressures = np.log(interp_pressures) for (i, j), v in np.ndenumerate(model_var[0]): pressure_var[:, i, j] = np.interp( log_interp_pressures, np.log(model_pressure[:, i, j]), model_var[:, i, j] ) for pl, interp_pressure in enumerate(interp_pressures): if interp_pressure > surface_pressure[i, j]: temp_surface_k = temperature_lowest_level_k[ i, j ] + ALPHA * temperature_lowest_level_k[i, j] * ( - surface_pressure[i, j] / model_pressure[-1, i, j] - 1 + surface_pressure[i, j] / model_pressure[temp_level_index, i, j] - 1 ) ln_p_ps = np.log(interp_pressure / surface_pressure[i, j]) pressure_var[pl, i, j] = surface_geopotential[ i, j ] - RDGAS * temp_surface_k * ln_p_ps * ( - 1 + 0.5 * ALPHA * ln_p_ps + 1 / 6.0 * (ALPHA * ln_p_ps) ** 2 + 1 + ALPHA * ln_p_ps / 2.0 + (ALPHA * ln_p_ps) ** 2 / 6.0 ) return pressure_var @@ -280,7 +282,7 @@ def interp_temperature_to_pressure_levels( interp_pressures, surface_pressure, surface_geopotential, - temperature_lowest_level_k, + temp_level_index=-6, ): """ Interpolate temperature field from hybrid sigma-pressure vertical coordinates to pressure levels. @@ -293,7 +295,8 @@ def interp_temperature_to_pressure_levels( interp_pressures (np.ndarray): pressure levels for interpolation in units Pa or hPa. surface_pressure (np.ndarray): pressure at the surface in units Pa or hPa. surface_geopotential (np.ndarray): geopotential at the surface in units m^2/s^2. - temperature_lowest_level_k (np.ndarray): lowest model level temperature in Kelvin. + temp_level_index (int): index of vertical level where temperature is extracted for extrapolation to + surface. Returns: pressure_var (np.ndarray): 3D field on pressure levels with shape (len(interp_pressures), y, x).
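+ + Note: the below-ground profile follows Eq. 16 of Trenberth et al. (1993): T(p) = T_s * (1 + x + x**2 / 2 + x**3 / 6), with x = gamma * RDGAS / GRAVITY * ln(p / p_s), where gamma is the lapse rate, adjusted over mid- and high-elevation terrain as in the code below.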
@@ -311,31 +314,38 @@ def interp_temperature_to_pressure_levels( ) for pl, interp_pressure in enumerate(interp_pressures): if interp_pressure > surface_pressure[i, j]: - temp_surface_k = temperature_lowest_level_k[ - i, j - ] + ALPHA * temperature_lowest_level_k[i, j] * ( - surface_pressure[i, j] / model_pressure[-1, i, j] - 1 + model_level_temp = model_var[temp_level_index, i, j] + temp_surface_k = model_level_temp + ALPHA * model_level_temp * ( + surface_pressure[i, j] / model_pressure[temp_level_index, i, j] - 1 ) surface_height = surface_geopotential[i, j] / GRAVITY temp_sea_level_k = temp_surface_k + LAPSE_RATE * surface_height - temp_pl = np.minimum(temp_surface_k, 298.0) + temp_pl = np.minimum(temp_sea_level_k, 298.0) if surface_height > 2500.0: - a_adjusted = ( - RDGAS * (temp_pl - temp_surface_k) / surface_geopotential[i, j] + gamma = ( + GRAVITY + / surface_geopotential[i, j] + * np.maximum(temp_pl - temp_surface_k, 0) ) + elif 2000.0 <= surface_height <= 2500.0: t_adjusted = 0.002 * ( (2500 - surface_height) * temp_sea_level_k + (surface_height - 2000.0) * temp_pl ) - a_adjusted = ( - RDGAS - * (t_adjusted - temp_surface_k) + gamma = ( + GRAVITY / surface_geopotential[i, j] + * (t_adjusted - temp_surface_k) ) else: - a_adjusted = ALPHA - a_ln_p = a_adjusted * np.log(interp_pressure / surface_pressure[i, j]) + gamma = LAPSE_RATE + a_ln_p = ( + gamma + * RDGAS + / GRAVITY + * np.log(interp_pressure / surface_pressure[i, j]) + ) pressure_var[pl, i, j] = temp_surface_k * ( 1 + a_ln_p + 0.5 * a_ln_p**2 + 1 / 6.0 * a_ln_p**3 ) @@ -435,23 +445,25 @@ def mean_sea_level_pressure( ) if (temp_surface_k <= 290.5) and (temp_sealevel_k > 290.5): - alpha_adjusted = ( - RDGAS / surface_geopotential[i, j] * (290.5 - temp_surface_k) - ) + gamma = GRAVITY / surface_geopotential[i, j] * (290.5 - temp_surface_k) elif (temp_surface_k > 290.5) and (temp_sealevel_k > 290.5): - alpha_adjusted = 0.0 + gamma = 0.0 temp_surface_k = 0.5 * (290.5 + temp_surface_k) else: - alpha_adjusted = ALPHA + gamma = LAPSE_RATE if temp_surface_k < 255: temp_surface_k = 0.5 * (255.0 + temp_surface_k) beta = surface_geopotential[i, j] / (RDGAS * temp_surface_k) + x = gamma * surface_geopotential[i, j] / (GRAVITY * temp_surface_k) + # mslp[i, j] = surface_pressure_pa[i, j] * np.exp( + # beta + # * ( + # 1 + # - alpha_adjusted * beta / 2.0 + # + ((alpha_adjusted * beta) ** 2) / 3.0 + # ) + # ) mslp[i, j] = surface_pressure_pa[i, j] * np.exp( - beta - * ( - 1 - - alpha_adjusted * beta / 2.0 - + ((alpha_adjusted * beta) ** 2) / 3.0 - ) + beta * (1.0 - x / 2.0 + x**2 / 3.0) ) return mslp From 623dd61e5289e214fd4ae5263948b9f1b60342e4 Mon Sep 17 00:00:00 2001 From: David John Gagne Date: Wed, 12 Feb 2025 17:56:17 -0700 Subject: [PATCH 04/13] Updated code to fix test issue. --- credit/interp.py | 170 +++++++++++++++++++---------------------------- pyproject.toml | 1 + 2 files changed, 68 insertions(+), 103 deletions(-) diff --git a/credit/interp.py b/credit/interp.py index b41f557b..82247190 100644 --- a/credit/interp.py +++ b/credit/interp.py @@ -25,6 +25,7 @@ def full_state_pressure_interpolation( verbose: int = 1, a_coord: str = "a_model", b_coord: str = "b_model", + temp_level_index: int = -2, ) -> xr.Dataset: """ Interpolate full model state variables from model levels to pressure levels. @@ -48,6 +49,7 @@ def full_state_pressure_interpolation( verbose (int): verbosity level. If verbose > 0, print progress. a_coord (str): Name of A weight in sigma coordinate formula. 'a_model' by default. 
b_coord (str): Name of B weight in sigma coordinate formula. 'b_model' by default. + temp_level_index (int): vertical index of the temperature level used for interpolation below ground. Returns: pressure_ds (xr.Dataset): Dataset containing pressure interpolated variables. """ @@ -57,12 +59,18 @@ def full_state_pressure_interpolation( model_a = mod_lev_ds[a_coord].loc[state_dataset[level_var]].values model_b = mod_lev_ds[b_coord].loc[state_dataset[level_var]].values pres_dims = (time_var, pres_var, lat_var, lon_var) + surface_dims = (time_var, lat_var, lon_var) coords = { time_var: state_dataset[time_var], pres_var: pressure_levels, lat_var: state_dataset[lat_var], lon_var: state_dataset[lon_var], } + coords_surface = { + time_var: state_dataset[time_var], + lat_var: state_dataset[lat_var], + lon_var: state_dataset[lon_var], + } pressure_ds = xr.Dataset( data_vars={ f + pres_ending: xr.DataArray( @@ -76,15 +84,16 @@ def full_state_pressure_interpolation( coords=coords, ) pressure_ds[geopotential_var + pres_ending] = xr.DataArray( - coords=coords, dims=pres_dims, name=geopotential_var + coords=coords, dims=pres_dims, name=geopotential_var + pres_ending + ) + pressure_ds["mean_sea_level_" + pres_var] = xr.DataArray( + coords=coords_surface, dims=surface_dims, name="mean_sea_level_" + pres_var ) disable = False if verbose == 0: disable = True for t, time in tqdm(enumerate(state_dataset[time_var]), disable=disable): - pressure_grid = create_pressure_grid( - state_dataset[surface_pressure_var][t].values, model_a, model_b - ) + pressure_grid = create_pressure_grid(state_dataset[surface_pressure_var][t].values, model_a, model_b) geopotential_grid = geopotential_from_model_vars( surface_geopotential, state_dataset[surface_pressure_var][t].values, @@ -95,33 +104,33 @@ def full_state_pressure_interpolation( ) for interp_field in interp_fields: if interp_field == temperature_var: - pressure_ds[interp_field + pres_ending][t] = ( - interp_temperature_to_pressure_levels( - state_dataset[interp_field][t].values, - pressure_grid / 100.0, - pressure_levels, - state_dataset[surface_pressure_var][t].values / 100.0, - surface_geopotential, - state_dataset[temperature_var][t, -1].values, - ) + pressure_ds[interp_field + pres_ending][t] = interp_temperature_to_pressure_levels( + state_dataset[interp_field][t].values, + pressure_grid / 100.0, + pressure_levels, + state_dataset[surface_pressure_var][t].values / 100.0, + surface_geopotential, ) else: - pressure_ds[interp_field + pres_ending][t] = ( - interp_hybrid_to_pressure_levels( - state_dataset[interp_field][t].values, - pressure_grid / 100.0, - pressure_levels, - ) + pressure_ds[interp_field + pres_ending][t] = interp_hybrid_to_pressure_levels( + state_dataset[interp_field][t].values, + pressure_grid / 100.0, + pressure_levels, ) - pressure_ds[geopotential_var + pres_ending][t] = ( - interp_geopotential_to_pressure_levels( - geopotential_grid, - pressure_grid / 100.0, - pressure_levels, - state_dataset[surface_pressure_var][t].values / 100.0, - surface_geopotential, - state_dataset[temperature_var][t, -1].values, - ) + pressure_ds[geopotential_var + pres_ending][t] = interp_geopotential_to_pressure_levels( + geopotential_grid, + pressure_grid / 100.0, + pressure_levels, + state_dataset[surface_pressure_var][t].values / 100.0, + surface_geopotential, + state_dataset[temperature_var][t].values, + temp_level_index, + ) + pressure_ds["mean_sea_level_" + pres_var][t] = mean_sea_level_pressure( + state_dataset[surface_pressure_var][t].values, + 
state_dataset[temperature_var][t, temp_level_index].values, + pressure_grid[temp_level_index], + surface_geopotential, ) return pressure_ds @@ -140,9 +149,7 @@ def create_pressure_grid(surface_pressure, model_a, model_b): Returns: pressure_3d: 3D pressure field with dimensions of surface_pressure and number of levels from model_a and model_b. """ - assert ( - model_a.size == model_b.size - ), "Model pressure coefficient arrays do not match." + assert model_a.size == model_b.size, "Model pressure coefficient arrays do not match." if surface_pressure.ndim == 3: # Generate the 3D pressure field for a time series of surface pressure grids pressure_3d = np.zeros( @@ -186,16 +193,12 @@ def interp_hybrid_to_pressure_levels(model_var, model_pressure, interp_pressures ) log_interp_pressures = np.log(interp_pressures) for (i, j), v in np.ndenumerate(model_var[0]): - pressure_var[:, i, j] = np.interp( - log_interp_pressures, np.log(model_pressure[:, i, j]), model_var[:, i, j] - ) + pressure_var[:, i, j] = np.interp(log_interp_pressures, np.log(model_pressure[:, i, j]), model_var[:, i, j]) return pressure_var @njit -def interp_pressure_to_hybrid_levels( - pressure_var, pressure_levels, model_pressure, surface_pressure -): +def interp_pressure_to_hybrid_levels(pressure_var, pressure_levels, model_pressure, surface_pressure): """ Interpolate data field from hybrid sigma-pressure vertical coordinates to pressure levels. `model_pressure` and `pressure_levels` and 'surface_pressure' should have consistent units with each other. @@ -228,8 +231,8 @@ def interp_geopotential_to_pressure_levels( interp_pressures, surface_pressure, surface_geopotential, - temperature_lowest_level_k, - temp_level_index=-6, + temperature_k, + temp_level_index=-2, ): """ Interpolate geopotential field from hybrid sigma-pressure vertical coordinates to pressure levels. @@ -242,7 +245,7 @@ def interp_geopotential_to_pressure_levels( interp_pressures (np.ndarray): pressure levels for interpolation in units Pa or hPa. surface_pressure (np.ndarray): pressure at the surface in units Pa or hPa. surface_geopotential (np.ndarray): geopotential at the surface in units m^2/s^2. - temperature_lowest_level_k (np.ndarray): lowest model level temperature in Kelvin. + temperature_k (np.ndarray): temperature state on model levels in Kelvin. temp_level_index (int): index of vertical level where temperature is extracted for extrapolation to surface. 
Returns: @@ -255,21 +258,16 @@ def interp_geopotential_to_pressure_levels( dtype=model_var.dtype, ) log_interp_pressures = np.log(interp_pressures) + temperature_lowest_level_k = temperature_k[temp_level_index] for (i, j), v in np.ndenumerate(model_var[0]): - pressure_var[:, i, j] = np.interp( - log_interp_pressures, np.log(model_pressure[:, i, j]), model_var[:, i, j] - ) + pressure_var[:, i, j] = np.interp(log_interp_pressures, np.log(model_pressure[:, i, j]), model_var[:, i, j]) for pl, interp_pressure in enumerate(interp_pressures): if interp_pressure > surface_pressure[i, j]: - temp_surface_k = temperature_lowest_level_k[ - i, j - ] + ALPHA * temperature_lowest_level_k[i, j] * ( + temp_surface_k = temperature_lowest_level_k[i, j] + ALPHA * temperature_lowest_level_k[i, j] * ( surface_pressure[i, j] / model_pressure[temp_level_index, i, j] - 1 ) ln_p_ps = np.log(interp_pressure / surface_pressure[i, j]) - pressure_var[pl, i, j] = surface_geopotential[ - i, j - ] - RDGAS * temp_surface_k * ln_p_ps * ( + pressure_var[pl, i, j] = surface_geopotential[i, j] - RDGAS * temp_surface_k * ln_p_ps * ( 1 + ALPHA * ln_p_ps / 2.0 + (ALPHA * ln_p_ps) ** 2 / 6.0 ) return pressure_var @@ -282,7 +280,7 @@ def interp_temperature_to_pressure_levels( interp_pressures, surface_pressure, surface_geopotential, - temp_level_index=-6, + temp_level_index=-2, ): """ Interpolate temperature field from hybrid sigma-pressure vertical coordinates to pressure levels. @@ -309,9 +307,7 @@ def interp_temperature_to_pressure_levels( ) log_interp_pressures = np.log(interp_pressures) for (i, j), v in np.ndenumerate(model_var[0]): - pressure_var[:, i, j] = np.interp( - log_interp_pressures, np.log(model_pressure[:, i, j]), model_var[:, i, j] - ) + pressure_var[:, i, j] = np.interp(log_interp_pressures, np.log(model_pressure[:, i, j]), model_var[:, i, j]) for pl, interp_pressure in enumerate(interp_pressures): if interp_pressure > surface_pressure[i, j]: model_level_temp = model_var[temp_level_index, i, j] @@ -322,40 +318,22 @@ def interp_temperature_to_pressure_levels( temp_sea_level_k = temp_surface_k + LAPSE_RATE * surface_height temp_pl = np.minimum(temp_sea_level_k, 298.0) if surface_height > 2500.0: - gamma = ( - GRAVITY - / surface_geopotential[i, j] - * np.maximum(temp_pl - temp_surface_k, 0) - ) + gamma = GRAVITY / surface_geopotential[i, j] * np.maximum(temp_pl - temp_surface_k, 0) elif 2000.0 <= surface_height <= 2500.0: t_adjusted = 0.002 * ( - (2500 - surface_height) * temp_sea_level_k - + (surface_height - 2000.0) * temp_pl - ) - gamma = ( - GRAVITY - / surface_geopotential[i, j] - * (t_adjusted - temp_surface_k) + (2500 - surface_height) * temp_sea_level_k + (surface_height - 2000.0) * temp_pl ) + gamma = GRAVITY / surface_geopotential[i, j] * (t_adjusted - temp_surface_k) else: gamma = LAPSE_RATE - a_ln_p = ( - gamma - * RDGAS - / GRAVITY - * np.log(interp_pressure / surface_pressure[i, j]) - ) - pressure_var[pl, i, j] = temp_surface_k * ( - 1 + a_ln_p + 0.5 * a_ln_p**2 + 1 / 6.0 * a_ln_p**3 - ) + a_ln_p = gamma * RDGAS / GRAVITY * np.log(interp_pressure / surface_pressure[i, j]) + pressure_var[pl, i, j] = temp_surface_k * (1 + a_ln_p + 0.5 * a_ln_p**2 + 1 / 6.0 * a_ln_p**3) return pressure_var @njit -def geopotential_from_model_vars( - surface_geopotential, surface_pressure, temperature, mixing_ratio, model_a, model_b -): +def geopotential_from_model_vars(surface_geopotential, surface_pressure, temperature, mixing_ratio, model_a, model_b): """ Calculate geopotential from the base state variables. 
Geopotential height is calculated by adding thicknesses calculated within each half-model-level to account for variations in temperature and moisture between grid cells. @@ -387,17 +365,17 @@ def geopotential_from_model_vars( virtual_temperature = temperature * (1.0 + gamma * mixing_ratio) m = model_geopotential.shape[-3] - 1 h = half_geopotential.shape[-3] - 1 - model_geopotential[m] = surface_geopotential + RDGAS * virtual_temperature[ - m - ] * np.log(surface_pressure / model_pressure[m]) + model_geopotential[m] = surface_geopotential + RDGAS * virtual_temperature[m] * np.log( + surface_pressure / model_pressure[m] + ) for i in range(1, model_geopotential.shape[-3]): - half_geopotential[h] = model_geopotential[m] + RDGAS * virtual_temperature[ - m - ] * np.log(model_pressure[m] / half_pressure[h]) + half_geopotential[h] = model_geopotential[m] + RDGAS * virtual_temperature[m] * np.log( + model_pressure[m] / half_pressure[h] + ) m -= 1 - model_geopotential[m] = half_geopotential[h] + RDGAS * virtual_temperature[ - m - ] * np.log(half_pressure[h] / model_pressure[m]) + model_geopotential[m] = half_geopotential[h] + RDGAS * virtual_temperature[m] * np.log( + half_pressure[h] / model_pressure[m] + ) h -= 1 return model_geopotential @@ -435,14 +413,10 @@ def mean_sea_level_pressure( if np.abs(surface_geopotential[i, j] / GRAVITY) < 1e-4: mslp[i, j] = surface_pressure_pa[i, j] else: - temp_surface_k = temperature_lowest_level_k[ - i, j - ] + ALPHA * temperature_lowest_level_k[i, j] * ( + temp_surface_k = temperature_lowest_level_k[i, j] + ALPHA * temperature_lowest_level_k[i, j] * ( surface_pressure_pa[i, j] / pressure_lowest_level_pa[i, j] - 1 ) - temp_sealevel_k = ( - temp_surface_k + LAPSE_RATE * surface_geopotential[i, j] / GRAVITY - ) + temp_sealevel_k = temp_surface_k + LAPSE_RATE * surface_geopotential[i, j] / GRAVITY if (temp_surface_k <= 290.5) and (temp_sealevel_k > 290.5): gamma = GRAVITY / surface_geopotential[i, j] * (290.5 - temp_surface_k) @@ -455,15 +429,5 @@ def mean_sea_level_pressure( temp_surface_k = 0.5 * (255.0 + temp_surface_k) beta = surface_geopotential[i, j] / (RDGAS * temp_surface_k) x = gamma * surface_geopotential[i, j] / (GRAVITY * temp_surface_k) - # mslp[i, j] = surface_pressure_pa[i, j] * np.exp( - # beta - # * ( - # 1 - # - alpha_adjusted * beta / 2.0 - # + ((alpha_adjusted * beta) ** 2) / 3.0 - # ) - # ) - mslp[i, j] = surface_pressure_pa[i, j] * np.exp( - beta * (1.0 - x / 2.0 + x**2 / 3.0) - ) + mslp[i, j] = surface_pressure_pa[i, j] * np.exp(beta * (1.0 - x / 2.0 + x**2 / 3.0)) return mslp diff --git a/pyproject.toml b/pyproject.toml index 47b6cca3..97041bf5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -54,6 +54,7 @@ version = {file = "credit/VERSION"} [tool.ruff] src = ["credit", "applications", "tests"] +line-length = 120 [tool.ruff.lint.pydocstyle] convention = "google" \ No newline at end of file From e811451439035213151b8ba0ffaf81ea62438613 Mon Sep 17 00:00:00 2001 From: David John Gagne Date: Wed, 12 Feb 2025 18:06:51 -0700 Subject: [PATCH 05/13] Update dependencies. 
--- credit/data_conversions.py | 184 -------------- credit/diagnostics.py | 496 ------------------------------------- environment_cpu.yml | 14 +- environment_gpu.yml | 6 +- pyproject.toml | 4 +- requirements.txt | 1 + 6 files changed, 10 insertions(+), 695 deletions(-) delete mode 100644 credit/data_conversions.py delete mode 100644 credit/diagnostics.py diff --git a/credit/data_conversions.py b/credit/data_conversions.py deleted file mode 100644 index 1d212f8b..00000000 --- a/credit/data_conversions.py +++ /dev/null @@ -1,184 +0,0 @@ -import logging - -import torch -import xarray as xr -from geocat.comp.interpolation import interp_hybrid_to_pressure - -logger = logging.getLogger(__name__) - - -class dataConverter: - """ - utility class for converting from various formats to xarray dataset. - e.g. in train.py, Tensor to Dataset - e.g. in predict.py DataArray to Dataset - """ - - def __init__(self, conf, new_levels=None) -> None: - self.conf = conf - static_ds = xr.open_dataset(self.conf["loss"]["latitude_weights"]) - self.lat = static_ds.latitude.values - self.lon = static_ds.longitude.values - self.SP = static_ds.SP - self.level_info = xr.open_dataset( - "/glade/derecho/scratch/dkimpara/nwp_files/hy_to_pressure.nc" - ) - self.new_levels = new_levels # levels to interpolate to - - def tensor_to_dataset(self, tensor, forecast_datetimes): - return self.dataArrays_to_dataset( - *self.tensor_to_dataArray(tensor, forecast_datetimes) - ) - - def tensor_to_pressure_lev_dataset(self, tensor, forecast_datetimes): - return self.dataset_to_pressure_levels( - self.tensor_to_dataset(tensor, forecast_datetimes) - ) - - def concat_and_reshape(self, x1, x2): # will be useful for getting back to tensor - x1 = x1.view( - x1.shape[0], - x1.shape[1], - x1.shape[2] * x1.shape[3], - x1.shape[4], - x1.shape[5], - ) - x_concat = torch.cat((x1, x2), dim=2) - return x_concat.permute(0, 2, 1, 3, 4) - - def split_and_reshape(self, tensor): - # get the number of levels - levels = self.conf["model"]["levels"] - # get number of channels - channels = len(self.conf["data"]["variables"]) - single_level_channels = len(self.conf["data"]["surface_variables"]) - - tensor1 = tensor[:, : int(channels * levels), :, :, :] - tensor2 = tensor[:, -int(single_level_channels) :, :, :, :] - tensor1 = tensor1.view( - tensor1.shape[0], - channels, - levels, - tensor1.shape[2], - tensor1.shape[3], - tensor1.shape[4], - ) - return tensor1, tensor2 - - def tensor_to_dataArray(self, pred, forecast_datetimes): - """ - Convert tensor to DataArray - - Args: - pred: Tensor with shape (B, C, T, lat, lon) - forecast_datetimes: array-like - """ - - # subset upper air and surface variables - tensor_upper_air, tensor_single_level = self.split_and_reshape(pred) - - tensor_upper_air = tensor_upper_air.squeeze(3) - tensor_single_level = tensor_single_level.squeeze( - 2 - ) # take out time dim=1, keep batch dim - - # upper air variables - darray_upper_air = xr.DataArray( - tensor_upper_air.detach().numpy(), # could put this in top level, this might be faster - dims=["datetime", "vars", "level", "latitude", "longitude"], - coords=dict( - datetime=forecast_datetimes, - vars=self.conf["data"]["variables"], - level=range(self.conf["model"]["levels"]), - latitude=self.lat, - longitude=self.lon, - ), - ) - - # diagnostics and surface variables - darray_single_level = xr.DataArray( - tensor_single_level.detach().numpy(), - dims=["datetime", "vars", "latitude", "longitude"], - coords=dict( - datetime=forecast_datetimes, - 
vars=self.conf["data"]["surface_variables"], - latitude=self.lat, - longitude=self.lon, - ), - ) - - # return dataarrays as outputs - return darray_upper_air, darray_single_level - - def dataArrays_to_dataset(self, darray_upper_air, darray_single_level): - # dArrays need to have a time dim - ds_x = darray_upper_air.to_dataset(dim="vars") - ds_surf = darray_single_level.to_dataset(dim="vars") - ds = xr.merge([ds_x, ds_surf]) - - # dataset as output - return ds - - def dataset_to_pressure_levels(self, dataset): - """ - unless specified in class init, - interpolation defaults to the pressure levels (in Pa): - - [100000., 92500., 85000., 70000., 50000., 40000., - 30000., 25000., 20000., 15000., 10000., 7000., 5000., - 3000., 2000., 1000., 700., 500., 300., 200., 100.], - - """ - dataset = dataset.assign_coords( - {"level": self.level_info.ref_Pa.values} - ) # these levels are from high to low - atmos_dataset = dataset[self.conf["data"]["variables"]] - SP = self.SP.expand_dims(dim={"datetime": dataset.datetime.values}) - - interp_kwargs = { - "lev_dim": "level", - "extrapolate": True, - "variable": "other", - } # need to build this kwarg dict because interpolation errors when given None - if self.new_levels: - interp_kwargs["new_levels"] = self.new_levels - - # modify atmos slice of dataset - atmos = atmos_dataset.map( - interp_hybrid_to_pressure, - args=[ - SP, - self.level_info.a_model - / 100000, # IFS a_coeffs are in Pa units. geocat computes pressure as hya * p0 + hyb * psfc - self.level_info.b_model, - ], # fix this patch if we ever use coeffs that arent IFS - **interp_kwargs, - ) - - dataset = xr.merge([atmos, dataset[self.conf["data"]["surface_variables"]]]) - return ( - dataset # original dataset with interpolated atmos vars and plev coordinate - ) - - -if __name__ == "__main__": - from os.path import join - import yaml - - test_dir = ( - "/glade/work/dkimpara/repos/global/miles-credit/results/test_files_quarter" - ) - config = join(test_dir, "model.yml") - with open(config) as cf: - conf = yaml.load(cf, Loader=yaml.FullLoader) - - y_pred = torch.load(join(test_dir, "pred.pt")) - # y = torch.load(join(test_dir, "y.pt")) - - converter = dataConverter(conf) - ds = converter.tensor_to_dataset(y_pred, [0]) - - ##### test hybrid to pressure ### - pressure = converter.dataset_to_pressure_levels(ds) - print(pressure) - print(f"nulls: {pressure.U.isnull().sum().values}") diff --git a/credit/diagnostics.py b/credit/diagnostics.py deleted file mode 100644 index 9e28c22d..00000000 --- a/credit/diagnostics.py +++ /dev/null @@ -1,496 +0,0 @@ -import os -from os.path import join, expandvars -import typing as t -import numpy as np -import xarray as xr -import matplotlib.pyplot as plt -import cartopy.crs as ccrs -import torch - -# from weatherbench2.derived_variables import ZonalEnergySpectrum -# WEC limiting spectrum shit cause weather bench2 not installed. 
- -import logging - -logger = logging.getLogger(__name__) - - -class Diagnostics: - """ - program flow: this class, sets up necessary pipeline, converts to xarray and/or pressure levels, calls diagnostics - """ - - def __init__(self, conf, init_datetime, data_converter): - """data_converter is a dataConverter object""" - self.conf = conf - self.converter = data_converter - self.w_lat, self.w_var = self.get_weights() - - self.diagnostics = [] - self.plev_diagnostics = [] - diag_conf = self.conf["diagnostics"] - if diag_conf["use_spectrum_vis"]: - logger.info("computing spectrum visualizations") - # save directory for spectra plots - plot_save_loc = join( - expandvars(self.conf["save_loc"]), f"forecasts/spectra_{init_datetime}/" - ) - os.makedirs(plot_save_loc, exist_ok=True) - - spectrum_vis = ZonalSpectrumVis( - self.conf, self.w_lat, self.w_var, plot_save_loc - ) - self.diagnostics.append(spectrum_vis) - - if diag_conf["use_KE_diagnostics"]: # only plotting summary spectra for KE - logger.info("computing KE visualizations") - plot_save_loc = join( - expandvars(self.conf["save_loc"]), f"forecasts/ke_{init_datetime}/" - ) - os.makedirs(plot_save_loc, exist_ok=True) - os.makedirs(join(plot_save_loc, "ke_diff"), exist_ok=True) - - ke_vis = KE_Diagnostic(self.conf, self.w_lat, self.w_var, plot_save_loc) - self.plev_diagnostics.append(ke_vis) - - def __call__(self, pred_ds, y_ds, fh): - metric_dict = {} - - for diagnostic in self.diagnostics: # non-plev diagnostics - diagnostic(pred_ds, y_ds, fh) - - # guard clauses for plev computation - if not self.plev_diagnostics: - return {} - if ( - self.conf["diagnostics"]["plev_summary_only"] - and fh not in self.conf["diagnostics"]["summary_plot_fhs"] - ): - logger.info(f"skipping plev diagnostics for fh {fh}") - return {} - - logger.info(f"computing plev diagnostics for fh {fh}") - pred_pressure = self.converter.dataset_to_pressure_levels(pred_ds).compute() - y_pressure = self.converter.dataset_to_pressure_levels(y_ds).compute() - for diagnostic in self.plev_diagnostics: - metric_dict = metric_dict | diagnostic(pred_pressure, y_pressure, fh) - # only return metrics if not doing summarys - return {} if self.conf["diagnostics"]["plev_summary_only"] else metric_dict - - def get_weights(self): - """ - gets variable and latitude weights from latitude file. 
- """ - w_lat = None - if self.conf["loss"]["use_latitude_weights"]: - lat = xr.open_dataset(self.conf["loss"]["latitude_weights"])["latitude"] - w_lat = np.cos(np.deg2rad(lat)) - self.w_lat = w_lat / w_lat.mean() - - w_var = None - if self.conf["loss"]["use_variable_weights"]: - var_weights = [ - value if isinstance(value, list) else [value] - for value in self.conf["loss"]["variable_weights"].values() - ] - var_weights = [item for sublist in var_weights for item in sublist] - - return w_lat, w_var - - -def calculate_KE(dataset): - wind_squared = dataset.U**2 + dataset.V**2 - return -1 * 0.5 * wind_squared.integrate("plev") - # negative needed because of doing integration 'backwards' could also reverse the plev coord - - -class KE_Diagnostic: - def __init__(self, conf, w_lat, w_var, plot_save_loc): - self.conf = conf - self.w_lat = w_lat - self.plot_save_loc = plot_save_loc - - self.summary_plot_fhs = conf["diagnostics"]["summary_plot_fhs"] - for k, v in conf["diagnostics"]["ke_vis"].items(): - setattr(self, k, v) - - # if self.use_KE_spectrum_vis: - # self.zonal_spectrum_calculator = ZonalEnergySpectrum("KE") - # if self.summary_plot_fhs: - # self.KE_fig, self.KE_axs = plt.subplots(ncols=1, figsize=(5, 5)) - # self.KE_axs = [self.KE_axs] - - def __call__(self, pred_ds, y_ds, fh): - """ - pressure level datasets - """ - pred_ke = calculate_KE(pred_ds).compute() - y_ke = calculate_KE(y_ds).compute() - - if self.use_KE_spectrum_vis: - self.KE_spectrum_vis(pred_ke, y_ke, fh) - - if self.use_KE_difference_vis: - self.KE_difference_vis(pred_ke, y_ke, fh) - - metric_dict = self.avg_KE_metric(pred_ke, y_ke) - return metric_dict - - def avg_KE_metric(self, pred_ke, y_ke): - diff = np.abs(pred_ke - y_ke) - weighted = diff.weighted(self.w_lat).mean() - return {"avg_KE_difference": weighted.values} - - def KE_difference_vis(self, pred_ke, y_ke, fh): - # ke_diff = pred_ke - y_ke - - # Plotting - fig = plt.figure(figsize=(10, 6)) - ax = fig.add_subplot(1, 1, 1, projection=ccrs.EckertIII()) - - # Plot data using colormesh - # divnorm = colors.TwoSlopeNorm(vcenter=0.0) # center cmap at 0 - datetime_str = np.datetime_as_string( - pred_ke.datetime.values[0], unit="h", timezone="UTC" - ) - ax.set_title(f"pred_ke - y_ke | fh={fh} {datetime_str}") - # Add coastlines and gridlines - ax.coastlines() - ax.gridlines() - - # Add colorbar - y_ke_total = y_ke.sum().values - pred_ke_total = pred_ke.sum().values - text = ( - f"Total ERA5 KE: {y_ke_total:.2e}; " - f"Difference: {pred_ke_total - y_ke_total:.2e}; " - f"abs diff: {np.abs(pred_ke - y_ke).sum().values:.2e}" - ) - ax.annotate( - text, - xy=(0.5, 0), - xytext=(0.5, -0.1), - xycoords="axes fraction", - textcoords="axes fraction", - ha="center", - va="top", - ) - - # save figure - filepath = join(self.plot_save_loc, f"ke_diff/ke_diff_{datetime_str}.pdf") - fig.savefig(filepath, format="pdf") - - def get_avg_spectrum_ke(self, da): - ds = xr.Dataset({"KE": da}) - ds_spectrum = self.zonal_spectrum_calculator.compute(ds) - ds_spectrum = interpolate_spectral_frequencies(ds_spectrum, "zonal_wavenumber") - ds_spectrum = ds_spectrum.weighted(self.w_lat).mean(dim="latitude") - return ds_spectrum - - def KE_spectrum_vis(self, pred_ke, y_ke, fh): - # plot on summary plot - - if int(fh) in self.summary_plot_fhs: # plot some fhs onto a single plot - avg_pred_spectrum = self.get_avg_spectrum_ke(pred_ke) - avg_y_spectrum = self.get_avg_spectrum_ke(y_ke) - - fh_idx = self.summary_plot_fhs.index(fh) - self.KE_fig, self.KE_axs = self.plot_avg_spectrum( - avg_pred_spectrum, - 
avg_y_spectrum, - self.KE_fig, - self.KE_axs, - alpha=1 - fh_idx / len(self.summary_plot_fhs), - label=f"fh={fh}", - ) - # if fh == self.summary_plot_fhs[-1]: - for ax in self.KE_axs: # overwrite every time in case of crash - ax.legend() - self.KE_fig.savefig(join(self.plot_save_loc, f"ke_spectra_summary{fh}")) - logger.info( - f"saved summary plot to {join(self.plot_save_loc, 'ke_spectra_summary')}" - ) - - def plot_avg_spectrum( - self, avg_pred_spectrum, avg_y_spectrum, fig, axs, alpha=1, label=None - ): - # copied from spectrum diagnostic function - for ax in axs: - ax.set_yscale("log") - ax.set_xscale("log") - curr_ax = 0 - avg_pred_spectrum.plot( - x="wavelength", ax=axs[curr_ax], color="r", alpha=alpha, label=label - ) - avg_y_spectrum.plot(x="wavelength", ax=axs[curr_ax], color="0") - - axs[curr_ax].set_title("KE Spectrum") - ticks = axs[curr_ax].get_xticks() # rescale x axis to km - axs[curr_ax].set_xticks(ticks, ticks / 1000) - axs[curr_ax].autoscale_view() - axs[curr_ax].set_xlabel("Wavelength (km)") - axs[curr_ax].set_ylabel("Power") - - return fig, axs - - -class ZonalSpectrumVis: - def __init__(self, conf, w_lat, w_var, plot_save_loc): - """ """ - self.conf = conf - self.w_lat = w_lat - self.plot_save_loc = plot_save_loc - self.summary_plot_fhs = conf["diagnostics"]["summary_plot_fhs"] - - # this replaces unpacking the dictionary (below) - for k, v in conf["diagnostics"]["spectrum_vis"].items(): - setattr(self, k, v) - # vis_conf = conf['diagnostics']['spectrum_vis'] - # self.atmos_variables = vis_conf['atmos_variables'] - # self.atmos_levels= vis_conf['atmos_levels'] - # self.single_level_variables = vis_conf['single_level_variables'] - - self.zonal_spectrum_calculator = None - # self.zonal_spectrum_calculator = ZonalEnergySpectrum( - # self.atmos_variables + self.single_level_variables - # ) - - self.ifs_levels = xr.open_dataset( - "/glade/derecho/scratch/dkimpara/nwp_files/ifs_levels.nc" - ) - # self.figsize = vis_conf['figsize'] - if len(self.figsize) == 0: - self.figsize = self.figsize - else: - num_vars = len(self.atmos_variables) * len(self.atmos_levels) + len( - self.single_level_variables - ) - self.figsize = (5 * num_vars, 5) - - if self.summary_plot_fhs: - self.summary_fig, self.summary_axs = plt.subplots( - ncols=len(self.atmos_variables) * len(self.atmos_levels) - + len(self.single_level_variables), - figsize=self.figsize, - ) - - def __call__(self, pred_ds, y_ds, fh): - """ - pred, y can be normalized or unnormalized tensors. 
- """ - # compute spectrum and add epsilon to avoid division by zero - epsilon = 1e-7 - pred_spectrum = self.zonal_spectrum_calculator.compute(pred_ds) + epsilon - y_spectrum = self.zonal_spectrum_calculator.compute(y_ds) + epsilon - - # compute average zonal spectrum - avg_pred_spectrum = self.get_avg_spectrum(pred_spectrum) - avg_y_spectrum = self.get_avg_spectrum(y_spectrum) - - # visualize - fig, axs = plt.subplots( - ncols=len(self.atmos_variables) * len(self.atmos_levels) - + len(self.single_level_variables), - figsize=self.figsize, - ) - datetime_str = np.datetime_as_string( - pred_ds.datetime.values[0], unit="h", timezone="UTC" - ) - - fig.suptitle(f"t={datetime_str}, fh={fh}") - fig, axs = self.plot_avg_spectrum(avg_pred_spectrum, avg_y_spectrum, fig, axs) - fig.savefig(join(self.plot_save_loc, f"spectra_{datetime_str}")) - - # plot on summary plot - - if fh in self.summary_plot_fhs: # plot some fhs onto a single plot - fh_idx = self.summary_plot_fhs.index(fh) - self.summary_fig, self.summary_axs = self.plot_avg_spectrum( - avg_pred_spectrum, - avg_y_spectrum, - self.summary_fig, - self.summary_axs, - alpha=1 - fh_idx / len(self.summary_plot_fhs), - label=f"fh={fh}", - ) - for ax in self.summary_axs: # overwrite every time in case of crash - ax.legend() - self.summary_fig.savefig(join(self.plot_save_loc, "spectra_summary")) - logger.info( - f"saved summary plot to {join(self.plot_save_loc, 'spectra_summary')}" - ) - - def plot_avg_spectrum( - self, avg_pred_spectrum, avg_y_spectrum, fig, axs, alpha=1, label=None - ): - for ax in axs: - ax.set_yscale("log") - ax.set_xscale("log") - - curr_ax = 0 - for level in self.atmos_levels: - for variable in self.atmos_variables: - avg_pred_spectrum[variable].sel(level=level).plot( - x="wavelength", ax=axs[curr_ax], color="r", alpha=alpha, label=label - ) - avg_y_spectrum[variable].sel(level=level).plot( - x="wavelength", ax=axs[curr_ax], color="0" - ) - - axs[curr_ax].set_title( - f"{variable} {self.ifs_levels.ref_hPa.sel(level=level).values}" - ) - ticks = axs[curr_ax].get_xticks() # rescale x axis to km - axs[curr_ax].set_xticks(ticks, ticks / 1000) - axs[curr_ax].autoscale_view() - axs[curr_ax].set_xlabel("Wavelength (km)") - axs[curr_ax].set_ylabel("Power") - curr_ax += 1 - - for variable in self.single_level_variables: - avg_pred_spectrum[variable].plot( - x="wavelength", ax=axs[curr_ax], color="r", alpha=alpha, label=label - ) - avg_y_spectrum[variable].plot(x="wavelength", ax=axs[curr_ax], color="0") - axs[curr_ax].set_title(variable) - - ticks = axs[curr_ax].get_xticks() # rescale x axis to km - axs[curr_ax].set_xticks(ticks, ticks / 1000) - axs[curr_ax].autoscale_view() - axs[curr_ax].set_xlabel("Wavelength (km)") - axs[curr_ax].set_ylabel("Power") - curr_ax += 1 - - return fig, axs - - def get_avg_spectrum(self, ds_spectrum): - ds_spectrum = ds_spectrum.sel(level=self.atmos_levels) - ds_spectrum = interpolate_spectral_frequencies(ds_spectrum, "zonal_wavenumber") - ds_spectrum = ds_spectrum.weighted(self.w_lat).mean(dim="latitude") - return ds_spectrum - - -# from weatherbench, slightly modified -def interpolate_spectral_frequencies( - spectrum: xr.DataArray, - wavenumber_dim: str, - frequencies: t.Optional[t.Sequence[float]] = None, - method: str = "linear", - **interp_kwargs: t.Optional[dict[str, t.Any]], -) -> xr.DataArray: - """ - Interpolate frequencies in `spectrum` to common values. - - Args: - spectrum: Data as produced by ZonalEnergySpectrum.compute. - wavenumber_dim: Dimension that indexes wavenumber, e.g. 
'zonal_wavenumber' - if `spectrum` is produced by ZonalEnergySpectrum. - frequencies: Optional 1-D sequence of frequencies to interpolate to. By - default, use the most narrow range of frequencies in `spectrum`. - method: Interpolation method passed on to DataArray.interp. - interp_kwargs: Additional kwargs passed on to DataArray.interp. - - Returns: - New DataArray with dimension "frequency" replacing the "wavenumber" dim in `spectrum`. - """ - - if set(spectrum.frequency.dims) != set((wavenumber_dim, "latitude")): - raise ValueError( - f"{spectrum.frequency.dims=} was not a permutation of " - f'("{wavenumber_dim}", "latitude")' - ) - - if frequencies is None: - freq_min = spectrum.frequency.max("latitude").min(wavenumber_dim).data - freq_max = spectrum.frequency.min("latitude").max(wavenumber_dim).data - frequencies = np.linspace( - freq_min, freq_max, num=spectrum.sizes[wavenumber_dim] - ) - frequencies = np.asarray(frequencies) - if frequencies.ndim != 1: - raise ValueError(f"Expected 1-D frequencies, found {frequencies.shape=}") - - def interp_at_one_lat(da: xr.DataArray) -> xr.DataArray: - if ( - len(da.latitude.values.shape) > 0 - ): # latitude weirdly not squeezed out by groupby sometimes - da = da.squeeze("latitude") - da = ( - da.swap_dims({wavenumber_dim: "frequency"}) - .drop_vars(wavenumber_dim) - .interp(frequency=frequencies, method=method, **interp_kwargs) - ) - # Interp didn't deal well with the infinite wavelength, so just reset λ as.. - da["wavelength"] = 1 / da.frequency - da["wavelength"] = da["wavelength"].assign_attrs(units="m") - return da - - return spectrum.groupby("latitude", squeeze=True).apply(interp_at_one_lat) - - -def anomaly_correlation_coefficient(pred, true): - pred = pred.float() - true = true.float() - - B, C, H, W = pred.size() - - # Flatten the spatial dimensions - pred_flat = pred.view(B, C, -1) - true_flat = true.view(B, C, -1) - - # Mean over spatial dimensions - pred_mean = torch.mean(pred_flat, dim=-1, keepdim=True) - true_mean = torch.mean(true_flat, dim=-1, keepdim=True) - - # Anomaly calculation - pred_anomaly = pred_flat - pred_mean - true_anomaly = true_flat - true_mean - - # Covariance matrix - covariance_matrix = torch.bmm(pred_anomaly, true_anomaly.transpose(1, 2)) / ( - H * W - 1 - ) - - # Variance terms - pred_var = torch.bmm(pred_anomaly, pred_anomaly.transpose(1, 2)) / (H * W - 1) - true_var = torch.bmm(true_anomaly, true_anomaly.transpose(1, 2)) / (H * W - 1) - - # Anomaly Correlation Coefficient - acc_numerator = torch.einsum("bii->b", covariance_matrix).sum() - acc_denominator = torch.sqrt( - torch.einsum("bii->b", pred_var).sum() * torch.einsum("bii->b", true_var).sum() - ) - - # Avoid division by zero - epsilon = 1e-8 - acc = acc_numerator / (acc_denominator + epsilon) - - return acc.item() - - -if __name__ == "__main__": - import yaml - from credit.data_conversions import dataConverter - import datetime - - test_dir = ( - "/glade/work/dkimpara/repos/global/miles-credit/results/test_files_quarter" - ) - config = join(test_dir, "model.yml") - with open(config) as cf: - conf = yaml.load(cf, Loader=yaml.FullLoader) - - y_pred = torch.load(join(test_dir, "pred.pt")) - y = torch.load(join(test_dir, "y.pt")) - - data_converter = dataConverter(conf) - diagnostic = Diagnostics(conf, 0, data_converter=dataConverter(conf)) - - time = datetime.datetime.now() - pred_ds = data_converter.tensor_to_dataset(y_pred.float(), [time]) - y_ds = data_converter.tensor_to_dataset(y.float(), [time]) - metrics = diagnostic(pred_ds, y_ds, 1) - metrics = 
diagnostic(pred_ds, y_ds + 1, 2) - metrics = diagnostic(y_ds, pred_ds, 3) - for k, v in metrics.items(): - print(v) diff --git a/environment_cpu.yml b/environment_cpu.yml index c8b8a618..381bbc9b 100644 --- a/environment_cpu.yml +++ b/environment_cpu.yml @@ -3,9 +3,9 @@ channels: - pytorch - conda-forge dependencies: - - python=3.11 + - python=3.12 - pip - - numpy<2 + - numpy - pandas - matplotlib - cartopy @@ -15,9 +15,6 @@ dependencies: - pytest - xarray - netcdf4 - - pytorch - - torchvision - - cpuonly - pyyaml - cartopy - dask @@ -25,15 +22,15 @@ dependencies: - dask-jobqueue - zarr - jupyter - - geocat-comp - - xesmf - mpi4py - mpich - myst-parser - pip: + - torch + - torchvision + - pre-commit - einops - echo-opt - - optuna==3.6.0 - bridgescaler - rotary-embedding-torch - segmentation-models-pytorch>=0.3.4 @@ -44,5 +41,4 @@ dependencies: - torch_geometric - sphinx-autoapi - timm - - wandb - . diff --git a/environment_gpu.yml b/environment_gpu.yml index 796dc3b4..abe3dc59 100644 --- a/environment_gpu.yml +++ b/environment_gpu.yml @@ -3,10 +3,10 @@ channels: - conda-forge dependencies: - - python=3.11 + - python=3.12 - pip - jupyter - - numpy<1.24 + - numpy - pandas - pyarrow - pyyaml @@ -15,7 +15,6 @@ dependencies: ## dev and ML ops - pytest - ruff - - wandb - pre_commit - myst-parser @@ -24,7 +23,6 @@ dependencies: - netcdf4 - zarr - xesmf - - geocat-comp - dask - dask-jobqueue - matplotlib diff --git a/pyproject.toml b/pyproject.toml index 97041bf5..dd63d092 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,10 +35,10 @@ dependencies = [ "vector-quantize-pytorch", "haversine", "pvlib", - "geocat-comp", "torch-harmonics", "torch_geometric", - "sphinx", + "pre-commit", + "ruff" ] [project.optional-dependencies] diff --git a/requirements.txt b/requirements.txt index 958836f8..385d464d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -26,3 +26,4 @@ pvlib pre-commit torch-harmonics torch_geometric +ruff From 411a64158043b24ae02df3f12d4f564c9fefeb90 Mon Sep 17 00:00:00 2001 From: David John Gagne Date: Thu, 13 Feb 2025 18:11:50 -0700 Subject: [PATCH 06/13] Updated the interpolation code to bring it more in line with ECMWF. --- credit/VERSION | 2 +- credit/interp.py | 214 +++++++++++++++++++++++++---------------------- 2 files changed, 113 insertions(+), 103 deletions(-) diff --git a/credit/VERSION b/credit/VERSION index a73a8518..dbaa6bc6 100644 --- a/credit/VERSION +++ b/credit/VERSION @@ -1 +1 @@ -2024.1.0 +2025.1.0 diff --git a/credit/interp.py b/credit/interp.py index 82247190..88fd3687 100644 --- a/credit/interp.py +++ b/credit/interp.py @@ -2,7 +2,7 @@ from numba import njit import xarray as xr from tqdm import tqdm -from .physics_constants import RDGAS, GRAVITY +from .physics_constants import RDGAS, GRAVITY import os @@ -23,8 +23,8 @@ def full_state_pressure_interpolation( level_var: str = "level", model_level_file: str = "../credit/metadata/ERA5_Lev_Info.nc", verbose: int = 1, - a_coord: str = "a_model", - b_coord: str = "b_model", + a_coord: str = "a_half", + b_coord: str = "b_half", temp_level_index: int = -2, ) -> xr.Dataset: """ Interpolate full model state variables from model levels to pressure levels. Args: state_dataset (xr.Dataset): state variables being interpolated surface_geopotential (np.ndarray): surface geopotential levels in units m^2/s^2. pressure_levels (np.ndarray): pressure levels for interpolation in hPa. interp_fields (tuple[str]): fields to be interpolated. pres_ending (str): ending string to attach to pressure interpolated variables. temperature_var (str): temperature variable to be interpolated (units K). q_var (str): mixing ratio/specific humidity variable to be interpolated (units kg/kg). surface_pressure_var (str): surface pressure variable (units Pa). geopotential_var (str): geopotential variable being derived (units m^2/s^2). time_var (str): time coordinate lat_var (str): latitude coordinate @@ -47,8 +47,8 @@ def full_state_pressure_interpolation( level_var (str): name of level coordinate model_level_file (str): relative path to file containing model levels. verbose (int): verbosity level. If verbose > 0, print progress. - a_coord (str): Name of A weight in sigma coordinate formula. 'a_model' by default. - b_coord (str): Name of B weight in sigma coordinate formula. 'b_model' by default.
+ a_coord (str): Name of A weight in sigma coordinate formula. 'a_half' by default. + b_coord (str): Name of B weight in sigma coordinate formula. 'b_half' by default. temp_level_index (int): vertical index of the temperature level used for interpolation below ground. Returns: pressure_ds (xr.Dataset): Dataset containing pressure interpolated variables. @@ -56,8 +56,8 @@ def full_state_pressure_interpolation( path_to_file = os.path.abspath(os.path.dirname(__file__)) model_level_file = os.path.join(path_to_file, model_level_file) with xr.open_dataset(model_level_file) as mod_lev_ds: - model_a = mod_lev_ds[a_coord].loc[state_dataset[level_var]].values - model_b = mod_lev_ds[b_coord].loc[state_dataset[level_var]].values + a_half = mod_lev_ds[a_coord].loc[state_dataset[level_var]].values + b_half = mod_lev_ds[b_coord].loc[state_dataset[level_var]].values pres_dims = (time_var, pres_var, lat_var, lon_var) surface_dims = (time_var, lat_var, lon_var) coords = { @@ -93,14 +93,16 @@ def full_state_pressure_interpolation( if verbose == 0: disable = True for t, time in tqdm(enumerate(state_dataset[time_var]), disable=disable): - pressure_grid = create_pressure_grid(state_dataset[surface_pressure_var][t].values, model_a, model_b) + pressure_grid, half_pressure_grid = create_pressure_grid( + state_dataset[surface_pressure_var][t].values.astype(np.float64), a_half, b_half + ) geopotential_grid = geopotential_from_model_vars( - surface_geopotential, - state_dataset[surface_pressure_var][t].values, - state_dataset[temperature_var][t].values, - state_dataset[q_var][t].values, - model_a, - model_b, + surface_geopotential.astype(np.float64), + state_dataset[surface_pressure_var][t].values.astype(np.float64), + state_dataset[temperature_var][t].values.astype(np.float64), + state_dataset[q_var][t].values.astype(np.float64), + a_half, + b_half, ) for interp_field in interp_fields: if interp_field == temperature_var: @@ -110,6 +112,7 @@ def full_state_pressure_interpolation( pressure_levels, state_dataset[surface_pressure_var][t].values / 100.0, surface_geopotential, + geopotential_grid, ) else: pressure_ds[interp_field + pres_ending][t] = interp_hybrid_to_pressure_levels( @@ -123,12 +126,12 @@ def full_state_pressure_interpolation( pressure_levels, state_dataset[surface_pressure_var][t].values / 100.0, surface_geopotential, + geopotential_grid, state_dataset[temperature_var][t].values, - temp_level_index, ) pressure_ds["mean_sea_level_" + pres_var][t] = mean_sea_level_pressure( state_dataset[surface_pressure_var][t].values, - state_dataset[temperature_var][t, temp_level_index].values, + state_dataset[temperature_var][t].values, pressure_grid[temp_level_index], surface_geopotential, ) @@ -136,41 +139,96 @@ def full_state_pressure_interpolation( @njit -def create_pressure_grid(surface_pressure, model_a, model_b): +def create_pressure_grid(surface_pressure, model_a_half, model_b_half): """ Create a 3D pressure field at model levels from the surface pressure field and the hybrid sigma-pressure coefficients from ECMWF. Conversion is `pressure_3d = a + b * SP`. Args: surface_pressure (np.ndarray): (time, latitude, longitude) or (latitude, longitude) grid in units of Pa. - model_a (np.ndarray): a coefficients at each model level being used in units of Pa. - model_b (np.ndarray): b coefficients at each model level being used (unitness). + model_a_half (np.ndarray): a coefficients at each model level being used in units of Pa. + model_b_half (np.ndarray): b coefficients at each model level being used (unitness). 
Returns:
         pressure_3d: 3D pressure field with dimensions of surface_pressure and number of levels
             from model_a and model_b.
     """
-    assert model_a.size == model_b.size, "Model pressure coefficient arrays do not match."
+    assert model_a_half.size == model_b_half.size, "Model pressure coefficient arrays do not match."
     if surface_pressure.ndim == 3:
         # Generate the 3D pressure field for a time series of surface pressure grids
         pressure_3d = np.zeros(
             (
                 surface_pressure.shape[0],
-                model_a.shape[0],
+                model_a_half.shape[0] - 1,
+                surface_pressure.shape[1],
+                surface_pressure.shape[2],
+            ),
+            dtype=surface_pressure.dtype,
+        )
+        pressure_3d_half = np.zeros(
+            (
+                surface_pressure.shape[0],
+                model_a_half.shape[0],
                 surface_pressure.shape[1],
                 surface_pressure.shape[2],
             ),
             dtype=surface_pressure.dtype,
         )
-        model_a_3d = model_a.reshape(-1, 1, 1)
-        model_b_3d = model_b.reshape(-1, 1, 1)
+        model_a_3d = model_a_half.reshape(-1, 1, 1)
+        model_b_3d = model_b_half.reshape(-1, 1, 1)
         for i in range(surface_pressure.shape[0]):
-            pressure_3d[i] = model_a_3d + model_b_3d * surface_pressure[i]
+            pressure_3d_half = model_a_3d + model_b_3d * surface_pressure[i]
+            pressure_3d[i] = 0.5 * (pressure_3d_half[:-1] + pressure_3d_half[1:])
     else:
         # Generate the 3D pressure field for a single surface pressure grid.
-        model_a_3d = model_a.reshape(-1, 1, 1)
-        model_b_3d = model_b.reshape(-1, 1, 1)
-        pressure_3d = model_a_3d + model_b_3d * surface_pressure
-    return pressure_3d
+        model_a_3d = model_a_half.reshape(-1, 1, 1)
+        model_b_3d = model_b_half.reshape(-1, 1, 1)
+        pressure_3d_half = model_a_3d + model_b_3d * surface_pressure
+        pressure_3d = 0.5 * (pressure_3d_half[:-1] + pressure_3d_half[1:])
+    return pressure_3d, pressure_3d_half
+
+
+@njit
+def geopotential_from_model_vars(surface_geopotential, surface_pressure, temperature, mixing_ratio, a_half, b_half):
+    """
+    Calculate geopotential from the base state variables. Geopotential height is calculated by adding thicknesses
+    calculated within each half-model-level to account for variations in temperature and moisture between grid cells.
+    Note that this function is calculating geopotential in units of (m^2 s^-2) not geopotential height.
+
+    To convert geopotential to geopotential height, divide geopotential by g (9.806 m s^-2).
+
+    Geopotential height is defined as the height above mean sea level. To get height above ground level, subtract
+    the surface geopotential height field from the 3D geopotential height field.
+
+    Args:
+        surface_geopotential (np.ndarray): Surface geopotential in shape (y, x) and units m^2 s^-2.
+        surface_pressure (np.ndarray): Surface pressure in shape (y, x) and units Pa.
+        temperature (np.ndarray): temperature in shape (levels, y, x) and units K.
+        mixing_ratio (np.ndarray): mixing ratio in shape (levels, y, x) and units kg/kg.
+        a_half (np.ndarray): a coefficients at each model half level being used in units of Pa.
+        b_half (np.ndarray): b coefficients at each model half level being used (unitless). 
+
+    Returns:
+        model_geopotential (np.ndarray): geopotential on model levels in shape (levels, y, x)
+    """
+    RDGAS = 287.06
+    gamma = 0.609133  # from MetView
+    model_pressure, half_pressure = create_pressure_grid(surface_pressure, a_half, b_half)
+    model_geopotential = np.zeros(model_pressure.shape, dtype=surface_pressure.dtype)
+    half_geopotential = np.zeros(half_pressure.shape, dtype=surface_pressure.dtype)
+    half_geopotential[-1] = surface_geopotential
+    virtual_temperature = temperature * (1.0 + gamma * mixing_ratio)
+    m = model_geopotential.shape[-3] - 1
+    for i in range(0, model_geopotential.shape[-3]):
+        if m == 0:
+            dlog_p = np.log(half_pressure[m + 1] / 0.1)
+            alpha = np.ones(half_pressure[m + 1].shape) * np.log(2)
+        else:
+            dlog_p = np.log(half_pressure[m + 1] / half_pressure[m])
+            alpha = alpha = 1.0 - ((half_pressure[m] / (half_pressure[m + 1] - half_pressure[m])) * dlog_p)
+        model_geopotential[m] = half_geopotential[m + 1] + RDGAS * virtual_temperature[m] * alpha
+        half_geopotential[m] = half_geopotential[m + 1] + RDGAS * virtual_temperature[m] * dlog_p
+        m -= 1
+    return model_geopotential, half_geopotential


 @njit
@@ -231,8 +289,9 @@ def interp_geopotential_to_pressure_levels(
     interp_pressures,
     surface_pressure,
     surface_geopotential,
+    geopotential,
     temperature_k,
-    temp_level_index=-2,
+    temp_height=150,
 ):
     """
     Interpolate geopotential field from hybrid sigma-pressure vertical coordinates to pressure levels.
@@ -245,9 +304,9 @@
     interp_pressures (np.ndarray): pressure levels for interpolation in units Pa or hPa.
     surface_pressure (np.ndarray): pressure at the surface in units Pa or hPa.
     surface_geopotential (np.ndarray): geopotential at the surface in units m^2/s^2.
-    temperature_k (np.ndarray): temperature state on model levels in Kelvin.
-    temp_level_index (int): index of vertical level where temperature is extracted for extrapolation to
-        surface.
+    geopotential (np.ndarray): geopotential in units m^2/s^2.
+    temperature_k (np.ndarray): temperature in units K.
+    temp_height (float): height above ground of nearest vertical grid cell.

     Returns:
         pressure_var (np.ndarray): 3D field on pressure levels with shape (len(interp_pressures), y, x). 
""" @@ -258,13 +317,14 @@ def interp_geopotential_to_pressure_levels( dtype=model_var.dtype, ) log_interp_pressures = np.log(interp_pressures) - temperature_lowest_level_k = temperature_k[temp_level_index] for (i, j), v in np.ndenumerate(model_var[0]): pressure_var[:, i, j] = np.interp(log_interp_pressures, np.log(model_pressure[:, i, j]), model_var[:, i, j]) for pl, interp_pressure in enumerate(interp_pressures): if interp_pressure > surface_pressure[i, j]: - temp_surface_k = temperature_lowest_level_k[i, j] + ALPHA * temperature_lowest_level_k[i, j] * ( - surface_pressure[i, j] / model_pressure[temp_level_index, i, j] - 1 + height_agl = (geopotential[:, i, j] - surface_geopotential[i, j]) / GRAVITY + h = np.argmin(np.abs(height_agl - temp_height)) + temp_surface_k = temperature_k[h, i, j] + ALPHA * temperature_k[h, i, j] * ( + surface_pressure[i, j] / model_pressure[h, i, j] - 1 ) ln_p_ps = np.log(interp_pressure / surface_pressure[i, j]) pressure_var[pl, i, j] = surface_geopotential[i, j] - RDGAS * temp_surface_k * ln_p_ps * ( @@ -275,12 +335,7 @@ def interp_geopotential_to_pressure_levels( @njit def interp_temperature_to_pressure_levels( - model_var, - model_pressure, - interp_pressures, - surface_pressure, - surface_geopotential, - temp_level_index=-2, + model_var, model_pressure, interp_pressures, surface_pressure, surface_geopotential, geopotential, temp_height=150 ): """ Interpolate temperature field from hybrid sigma-pressure vertical coordinates to pressure levels. @@ -293,8 +348,7 @@ def interp_temperature_to_pressure_levels( interp_pressures: (np.ndarray): pressure levels for interpolation in units Pa or. surface_pressure (np.ndarray): pressure at the surface in units Pa or hPa. surface_geopotential (np.ndarray): geopotential at the surface in units m^2/s^2. - temp_level_index (int): index of vertical level where temperature is extracted for extrapolation to - surface. + temp_height (float): height above ground of nearest vertical grid cell. Returns: pressure_var (np.ndarray): 3D field on pressure levels with shape (len(interp_pressures), y, x). @@ -310,9 +364,12 @@ def interp_temperature_to_pressure_levels( pressure_var[:, i, j] = np.interp(log_interp_pressures, np.log(model_pressure[:, i, j]), model_var[:, i, j]) for pl, interp_pressure in enumerate(interp_pressures): if interp_pressure > surface_pressure[i, j]: - model_level_temp = model_var[temp_level_index, i, j] - temp_surface_k = model_level_temp + ALPHA * model_level_temp * ( - surface_pressure[i, j] / model_pressure[temp_level_index, i, j] - 1 + # The height above ground of each sigma level varies, especially in complex terrain + # To minimize extrapolation error, pick the level closest to 150 m AGL, which is the ECMWF standard. + height_agl = (geopotential[:, i, j] - surface_geopotential[i, j]) / GRAVITY + h = np.argmin(np.abs(height_agl - temp_height)) + temp_surface_k = model_var[h, i, j] + ALPHA * model_var[h, i, j] * ( + surface_pressure[i, j] / model_pressure[h, i, j] - 1 ) surface_height = surface_geopotential[i, j] / GRAVITY temp_sea_level_k = temp_surface_k + LAPSE_RATE * surface_height @@ -332,60 +389,9 @@ def interp_temperature_to_pressure_levels( return pressure_var -@njit -def geopotential_from_model_vars(surface_geopotential, surface_pressure, temperature, mixing_ratio, model_a, model_b): - """ - Calculate geopotential from the base state variables. 
Geopotential height is calculated by adding thicknesses - calculated within each half-model-level to account for variations in temperature and moisture between grid cells. - Note that this function is calculating geopotential in units of (m^2 s^-2) not geopential height. - - To convert geopotential to geopotential height, divide geopotential by g (9.806 m s^-2). - - Geopotential height is defined as the height above mean sea level. To get height above ground level, substract - the surface geoptential height field from the 3D geopotential height field. - - Args: - surface_geopotential (np.ndarray): Surface geopotential in shape (y,x) and units m^2 s^-2. - surface_pressure (np.ndarray): Surface pressure in shape (y, x) and units Pa - temperature (np.ndarray): temperature in shape (levels, y, x) and units K - mixing_ratio (np.ndarray): mixing ratio in shape (levels, y, x) and units kg/kg. - model_a (np.ndarray): a coefficients at each model level being used in units of Pa. - model_b (np.ndarray): b coefficients at each model level being used (unitness). - - Returns: - model_geoptential (np.ndarray): geopotential on model levels in shape (levels, y, x) - """ - gamma = RVGAS / RDGAS - 1.0 - half_a = 0.5 * (model_a[:-1] + model_a[1:]) - half_b = 0.5 * (model_b[:-1] + model_b[1:]) - model_pressure = create_pressure_grid(surface_pressure, model_a, model_b) - half_pressure = create_pressure_grid(surface_pressure, half_a, half_b) - model_geopotential = np.zeros(model_pressure.shape, dtype=surface_pressure.dtype) - half_geopotential = np.zeros(half_pressure.shape, dtype=surface_pressure.dtype) - virtual_temperature = temperature * (1.0 + gamma * mixing_ratio) - m = model_geopotential.shape[-3] - 1 - h = half_geopotential.shape[-3] - 1 - model_geopotential[m] = surface_geopotential + RDGAS * virtual_temperature[m] * np.log( - surface_pressure / model_pressure[m] - ) - for i in range(1, model_geopotential.shape[-3]): - half_geopotential[h] = model_geopotential[m] + RDGAS * virtual_temperature[m] * np.log( - model_pressure[m] / half_pressure[h] - ) - m -= 1 - model_geopotential[m] = half_geopotential[h] + RDGAS * virtual_temperature[m] * np.log( - half_pressure[h] / model_pressure[m] - ) - h -= 1 - return model_geopotential - - @njit def mean_sea_level_pressure( - surface_pressure_pa, - temperature_lowest_level_k, - pressure_lowest_level_pa, - surface_geopotential, + surface_pressure_pa, temperature_k, pressure_pa, surface_geopotential, geopotential, temp_height=150.0 ): """ Calculate mean sea level pressure from surface pressure, lowest model level temperature, @@ -399,9 +405,11 @@ def mean_sea_level_pressure( Args: surface_pressure_pa: surface pressure in Pascals - temperature_lowest_level_k: Temperature at the lowest model level in Kelvin. - pressure_lowest_level_pa: Pressure at the lowest model level in Pascals. + temperature_k: Temperature at the lowest model level in Kelvin. + pressure_pa: Pressure at the lowest model level in Pascals. surface_geopotential: Geopotential of the surface in m^2 s^-2. + geopotential: Geopotential at all levels. + temp_height: height of nearest vertical grid cell Returns: mslp: Mean sea level pressure in Pascals. 
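The same below-ground extrapolation pattern now appears in the geopotential, temperature, and mean sea level pressure paths: find the model level whose height above ground is nearest temp_height (150 m, the ECMWF convention noted in the comments above), then extrapolate its temperature to the surface with an alpha * (p_s / p_h - 1) correction. A minimal single-column sketch of that step, assuming ALPHA is the usual lapse_rate * Rd / g value (the module's own constant may be defined differently) and using made-up column data:

import numpy as np

# Constants assumed for illustration; the module imports RDGAS and GRAVITY from
# physics_constants, and ALPHA is taken here as the standard ECMWF value.
GRAVITY = 9.806
RDGAS = 287.06
LAPSE_RATE = 0.0065
ALPHA = LAPSE_RATE * RDGAS / GRAVITY


def extrapolated_surface_temperature(temperature_k, model_pressure, geopotential,
                                     surface_pressure, surface_geopotential,
                                     temp_height=150.0):
    """Pick the level nearest temp_height m above ground, then extrapolate its
    temperature down to the surface pressure."""
    height_agl = (geopotential - surface_geopotential) / GRAVITY
    h = np.argmin(np.abs(height_agl - temp_height))
    return temperature_k[h] * (1.0 + ALPHA * (surface_pressure / model_pressure[h] - 1.0))


# Hypothetical column, ordered top to bottom.
t_col = np.array([230.0, 255.0, 270.0, 284.0])          # K
p_col = np.array([30000.0, 60000.0, 85000.0, 99000.0])  # Pa on model levels
z_col = np.array([90000.0, 42000.0, 14000.0, 3000.0])   # geopotential, m^2 s^-2
t_surf = extrapolated_surface_temperature(t_col, p_col, z_col, 101000.0, 1500.0)
t_sea_level = t_surf + LAPSE_RATE * 1500.0 / GRAVITY    # as in mean_sea_level_pressure
print(round(t_surf, 2), round(t_sea_level, 2))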
@@ -413,8 +421,10 @@ def mean_sea_level_pressure( if np.abs(surface_geopotential[i, j] / GRAVITY) < 1e-4: mslp[i, j] = surface_pressure_pa[i, j] else: - temp_surface_k = temperature_lowest_level_k[i, j] + ALPHA * temperature_lowest_level_k[i, j] * ( - surface_pressure_pa[i, j] / pressure_lowest_level_pa[i, j] - 1 + height_agl = (geopotential[:, i, j] - surface_geopotential[i, j]) / GRAVITY + h = np.argmin(np.abs(height_agl - temp_height)) + temp_surface_k = temperature_k[h, i, j] + ALPHA * temperature_k[h, i, j] * ( + surface_pressure_pa[i, j] / pressure_pa[h, i, j] - 1 ) temp_sealevel_k = temp_surface_k + LAPSE_RATE * surface_geopotential[i, j] / GRAVITY From e3a99c1a8da0267345bd458299d3e46248ca89b9 Mon Sep 17 00:00:00 2001 From: David John Gagne Date: Tue, 18 Feb 2025 12:01:13 -0700 Subject: [PATCH 07/13] Major fixes to interpolation and model loading. --- credit/interp.py | 43 +++++++++++++++++++++-------------- credit/models/__init__.py | 17 +++++++++----- credit/models/crossformer.py | 44 +++++++++++++----------------------- 3 files changed, 53 insertions(+), 51 deletions(-) diff --git a/credit/interp.py b/credit/interp.py index 88fd3687..376e807f 100644 --- a/credit/interp.py +++ b/credit/interp.py @@ -23,8 +23,10 @@ def full_state_pressure_interpolation( level_var: str = "level", model_level_file: str = "../credit/metadata/ERA5_Lev_Info.nc", verbose: int = 1, - a_coord: str = "a_half", - b_coord: str = "b_half", + a_coord: str = "a_model", + b_coord: str = "b_model", + a_half: str = "a_half", + b_half: str = "b_half", temp_level_index: int = -2, ) -> xr.Dataset: """ @@ -47,8 +49,10 @@ def full_state_pressure_interpolation( level_var (str): name of level coordinate model_level_file (str): relative path to file containing model levels. verbose (int): verbosity level. If verbose > 0, print progress. - a_coord (str): Name of A weight in sigma coordinate formula. 'a_half' by default. - b_coord (str): Name of B weight in sigma coordinate formula. 'b_half' by default. + a_coord (str): Name of A weight in sigma coordinate formula. 'a_model' by default. + b_coord (str): Name of B weight in sigma coordinate formula. 'b_model' by default. + a_half (str): Name of A weight in sigma coordinate formula at half levels. 'a_half' by default. + b_half (str): Name of B weight in sigma coordinate formula at half levels. 'b_half' by default. temp_level_index (int): vertical index of the temperature level used for interpolation below ground. Returns: pressure_ds (xr.Dataset): Dataset containing pressure interpolated variables. 
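For orientation, a minimal sketch of how this entry point might be exercised after the change. It assumes the credit package (with its bundled ERA5_Lev_Info.nc level metadata) is importable; the grid sizes, model level numbers, and field values below are synthetic and purely illustrative:

import numpy as np
import xarray as xr
from credit.interp import full_state_pressure_interpolation

rng = np.random.default_rng(0)
nlat, nlon, nlev = 32, 64, 16
levels = np.arange(122, 138)  # 1-based ERA5 model level numbers (illustrative subset)
dims4 = ("time", "level", "latitude", "longitude")
coords = {
    "time": [np.datetime64("2020-01-01T00:00")],
    "level": levels,
    "latitude": np.linspace(90, -90, nlat),
    "longitude": np.arange(0, 360, 360 / nlon),
}
state = xr.Dataset(
    {
        "U": (dims4, rng.normal(0, 10, (1, nlev, nlat, nlon))),
        "V": (dims4, rng.normal(0, 10, (1, nlev, nlat, nlon))),
        "T": (dims4, rng.normal(270, 15, (1, nlev, nlat, nlon))),
        "Q": (dims4, np.full((1, nlev, nlat, nlon), 1.0e-3)),
        "SP": (("time", "latitude", "longitude"), np.full((1, nlat, nlon), 101325.0)),
    },
    coords=coords,
)
z_surf = np.zeros((nlat, nlon))  # flat terrain; surface geopotential in m^2 s^-2
pres_ds = full_state_pressure_interpolation(
    state, z_surf, pressure_levels=np.array([500.0, 850.0]), verbose=0
)
print(pres_ds["T_PRES"].shape)  # (time, pressure level, latitude, longitude)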
@@ -56,8 +60,8 @@ def full_state_pressure_interpolation( path_to_file = os.path.abspath(os.path.dirname(__file__)) model_level_file = os.path.join(path_to_file, model_level_file) with xr.open_dataset(model_level_file) as mod_lev_ds: - a_half = mod_lev_ds[a_coord].loc[state_dataset[level_var]].values - b_half = mod_lev_ds[b_coord].loc[state_dataset[level_var]].values + a_half = mod_lev_ds[a_half].values + b_half = mod_lev_ds[b_half].values pres_dims = (time_var, pres_var, lat_var, lon_var) surface_dims = (time_var, lat_var, lon_var) coords = { @@ -92,23 +96,26 @@ def full_state_pressure_interpolation( disable = False if verbose == 0: disable = True + sub_half_levels = np.concatenate([state_dataset[level_var].values, [138]]) + sub_levels = state_dataset[level_var].values for t, time in tqdm(enumerate(state_dataset[time_var]), disable=disable): pressure_grid, half_pressure_grid = create_pressure_grid( state_dataset[surface_pressure_var][t].values.astype(np.float64), a_half, b_half ) + pressure_sub_grid = pressure_grid[sub_levels - 1] + half_pressure_sub_grid = half_pressure_grid[sub_half_levels - 1] geopotential_grid = geopotential_from_model_vars( surface_geopotential.astype(np.float64), state_dataset[surface_pressure_var][t].values.astype(np.float64), state_dataset[temperature_var][t].values.astype(np.float64), state_dataset[q_var][t].values.astype(np.float64), - a_half, - b_half, + half_pressure_sub_grid, ) for interp_field in interp_fields: if interp_field == temperature_var: pressure_ds[interp_field + pres_ending][t] = interp_temperature_to_pressure_levels( state_dataset[interp_field][t].values, - pressure_grid / 100.0, + pressure_sub_grid / 100.0, pressure_levels, state_dataset[surface_pressure_var][t].values / 100.0, surface_geopotential, @@ -117,12 +124,12 @@ def full_state_pressure_interpolation( else: pressure_ds[interp_field + pres_ending][t] = interp_hybrid_to_pressure_levels( state_dataset[interp_field][t].values, - pressure_grid / 100.0, + pressure_sub_grid / 100.0, pressure_levels, ) pressure_ds[geopotential_var + pres_ending][t] = interp_geopotential_to_pressure_levels( geopotential_grid, - pressure_grid / 100.0, + pressure_sub_grid / 100.0, pressure_levels, state_dataset[surface_pressure_var][t].values / 100.0, surface_geopotential, @@ -132,8 +139,9 @@ def full_state_pressure_interpolation( pressure_ds["mean_sea_level_" + pres_var][t] = mean_sea_level_pressure( state_dataset[surface_pressure_var][t].values, state_dataset[temperature_var][t].values, - pressure_grid[temp_level_index], + pressure_sub_grid, surface_geopotential, + geopotential_grid, ) return pressure_ds @@ -188,7 +196,7 @@ def create_pressure_grid(surface_pressure, model_a_half, model_b_half): @njit -def geopotential_from_model_vars(surface_geopotential, surface_pressure, temperature, mixing_ratio, a_half, b_half): +def geopotential_from_model_vars(surface_geopotential, surface_pressure, temperature, mixing_ratio, half_pressure): """ Calculate geopotential from the base state variables. Geopotential height is calculated by adding thicknesses calculated within each half-model-level to account for variations in temperature and moisture between grid cells. 
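Since create_pressure_grid now returns both grids, the relationship is easy to demonstrate in isolation: interface (half-level) pressures come from the a/b coefficients, full-level pressures are interface midpoints, and the interfaces feed the hydrostatic integration in geopotential_from_model_vars below. A standalone NumPy sketch with made-up coefficients (the real values come from the 138-entry ERA5 a_half/b_half table):

import numpy as np

# Five illustrative interfaces (Pa and dimensionless); NOT the real ERA5 coefficients.
a_half = np.array([0.0, 2000.0, 6000.0, 4000.0, 0.0])
b_half = np.array([0.0, 0.0, 0.02, 0.4, 1.0])
surface_pressure = np.full((3, 4), 101325.0)  # Pa, shape (y, x)

p_half = a_half[:, None, None] + b_half[:, None, None] * surface_pressure  # (interfaces, y, x)
p_full = 0.5 * (p_half[:-1] + p_half[1:])  # layer midpoints, (layers, y, x)
print(p_half[:, 0, 0])  # monotonically increasing toward the surface
print(p_full[:, 0, 0])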
@@ -212,8 +220,9 @@ def geopotential_from_model_vars(surface_geopotential, surface_pressure, tempera """ RDGAS = 287.06 gamma = 0.609133 # from MetView - model_pressure, half_pressure = create_pressure_grid(surface_pressure, a_half, b_half) - model_geopotential = np.zeros(model_pressure.shape, dtype=surface_pressure.dtype) + model_geopotential = np.zeros( + (half_pressure.shape[0] - 1, half_pressure.shape[1], half_pressure.shape[2]), dtype=surface_pressure.dtype + ) half_geopotential = np.zeros(half_pressure.shape, dtype=surface_pressure.dtype) half_geopotential[-1] = surface_geopotential virtual_temperature = temperature * (1.0 + gamma * mixing_ratio) @@ -224,11 +233,11 @@ def geopotential_from_model_vars(surface_geopotential, surface_pressure, tempera alpha = np.ones(half_pressure[m + 1].shape) * np.log(2) else: dlog_p = np.log(half_pressure[m + 1] / half_pressure[m]) - alpha = alpha = 1.0 - ((half_pressure[m] / (half_pressure[m + 1] - half_pressure[m])) * dlog_p) + alpha = 1.0 - ((half_pressure[m] / (half_pressure[m + 1] - half_pressure[m])) * dlog_p) model_geopotential[m] = half_geopotential[m + 1] + RDGAS * virtual_temperature[m] * alpha half_geopotential[m] = half_geopotential[m + 1] + RDGAS * virtual_temperature[m] * dlog_p m -= 1 - return model_geopotential, half_geopotential + return model_geopotential @njit diff --git a/credit/models/__init__.py b/credit/models/__init__.py index c06e8108..bd69a4ef 100644 --- a/credit/models/__init__.py +++ b/credit/models/__init__.py @@ -104,6 +104,7 @@ def load_fsdp_or_checkpoint_policy(conf): def load_model(conf, load_weights=False): conf = copy.deepcopy(conf) + model_conf = conf["model"] if "type" not in model_conf: @@ -120,18 +121,22 @@ def load_model(conf, load_weights=False): logger.info(message) if load_weights: model = model(**model_conf) - save_loc = conf["save_loc"] - ckpt = os.path.join(save_loc, "checkpoint.pt") + save_loc = os.path.expandvars(conf["save_loc"]) + if os.path.isfile(os.path.join(save_loc, "model_checkpoint.pt")): + ckpt = os.path.join(save_loc, "model_checkpoint.pt") + else: + ckpt = os.path.join(save_loc, "checkpoint.pt") if not os.path.isfile(ckpt): - raise ValueError( - "No saved checkpoint exists. You must train a model first. Exiting." - ) + raise ValueError("No saved checkpoint exists. You must train a model first. 
Exiting.") logging.info(f"Loading a model with pre-trained weights from path {ckpt}") checkpoint = torch.load(ckpt) - model.load_state_dict(checkpoint["model_state_dict"]) + if "model_state_dict" in checkpoint.keys(): + model.load_state_dict(checkpoint["model_state_dict"], strict=False) + else: + model.load_state_dict(checkpoint, strict=False) return model return model(**model_conf) diff --git a/credit/models/crossformer.py b/credit/models/crossformer.py index d7522f6a..ffec1811 100644 --- a/credit/models/crossformer.py +++ b/credit/models/crossformer.py @@ -35,9 +35,7 @@ class CubeEmbedding(nn.Module): patch_size: T, Lat, Lon """ - def __init__( - self, img_size, patch_size, in_chans, embed_dim, norm_layer=nn.LayerNorm - ): + def __init__(self, img_size, patch_size, in_chans, embed_dim, norm_layer=nn.LayerNorm): super().__init__() patches_resolution = [ img_size[0] // patch_size[0], @@ -48,9 +46,7 @@ def __init__( self.img_size = img_size self.patches_resolution = patches_resolution self.embed_dim = embed_dim - self.proj = nn.Conv3d( - in_chans, embed_dim, kernel_size=patch_size, stride=patch_size - ) + self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) if norm_layer is not None: self.norm = norm_layer(embed_dim) else: @@ -77,9 +73,7 @@ def __init__(self, in_chans, out_chans, num_groups, num_residuals=2): blk = [] for i in range(num_residuals): - blk.append( - nn.Conv2d(out_chans, out_chans, kernel_size=3, stride=1, padding=1) - ) + blk.append(nn.Conv2d(out_chans, out_chans, kernel_size=3, stride=1, padding=1)) blk.append(nn.GroupNorm(num_groups, out_chans)) blk.append(nn.SiLU()) @@ -242,9 +236,7 @@ def forward(self, x): # split heads - q, k, v = map( - lambda t: rearrange(t, "b (h d) x y -> b h (x y) d", h=heads), (q, k, v) - ) + q, k, v = map(lambda t: rearrange(t, "b (h d) x y -> b h (x y) d", h=heads), (q, k, v)) q = q * self.scale sim = einsum("b h i d, b h j d -> b h i j", q, k) @@ -422,7 +414,7 @@ def __init__( if padding_conf is None: padding_conf = {"activate": False} self.use_padding = padding_conf["activate"] - + if post_conf is None: post_conf = {"activate": False} self.use_post_block = post_conf["activate"] @@ -449,9 +441,7 @@ def __init__( # dimensions last_dim = dim[-1] - first_dim = ( - input_channels if (patch_height == 1 and patch_width == 1) else dim[0] - ) + first_dim = input_channels if (patch_height == 1 and patch_width == 1) else dim[0] dims = [first_dim, *dim] dim_in_and_out = tuple(zip(dims[:-1], dims[1:])) @@ -506,9 +496,7 @@ def __init__( self.up_block1 = UpBlock(1 * last_dim, last_dim // 2, dim[0]) self.up_block2 = UpBlock(2 * (last_dim // 2), last_dim // 4, dim[0]) self.up_block3 = UpBlock(2 * (last_dim // 4), last_dim // 8, dim[0]) - self.up_block4 = nn.ConvTranspose2d( - 2 * (last_dim // 8), output_channels, kernel_size=4, stride=2, padding=1 - ) + self.up_block4 = nn.ConvTranspose2d(2 * (last_dim // 8), output_channels, kernel_size=4, stride=2, padding=1) if self.use_spectral_norm: logger.info("Adding spectral norm to all conv and linear layers") @@ -516,12 +504,14 @@ def __init__( if self.use_post_block: # freeze base model weights before postblock init - if (post_conf["skebs"].get("activate", False) - and post_conf["skebs"].get("freeze_base_model_weights", False)): - logger.warning("freezing all base model weights due to skebs config") - for param in self.parameters(): - param.requires_grad = False - + if "skebs" in post_conf.keys(): + if post_conf["skebs"].get("activate", False) and post_conf["skebs"].get( + 
"freeze_base_model_weights", False + ): + logger.warning("freezing all base model weights due to skebs config") + for param in self.parameters(): + param.requires_grad = False + logger.info("using postblock") self.postblock = PostBlock(post_conf) @@ -558,9 +548,7 @@ def forward(self, x): x = self.padding_opt.unpad(x) if self.use_interp: - x = F.interpolate( - x, size=(self.image_height, self.image_width), mode="bilinear" - ) + x = F.interpolate(x, size=(self.image_height, self.image_width), mode="bilinear") x = x.unsqueeze(2) From 110d41f922a98ffec9be35646ee1312777c21480 Mon Sep 17 00:00:00 2001 From: "djgagne@ou.edu" Date: Tue, 18 Feb 2025 14:05:47 -0700 Subject: [PATCH 08/13] Fix to model loading --- applications/rollout_to_netcdf.py | 65 +++++++------------------------ credit/models/base_model.py | 17 ++++---- 2 files changed, 22 insertions(+), 60 deletions(-) diff --git a/applications/rollout_to_netcdf.py b/applications/rollout_to_netcdf.py index 6a8bb8d4..08e7d644 100644 --- a/applications/rollout_to_netcdf.py +++ b/applications/rollout_to_netcdf.py @@ -214,9 +214,7 @@ def predict(rank, world_size, conf, p): model = distributed_model_wrapper(conf, model, device) ckpt = os.path.join(save_loc, "checkpoint.pt") checkpoint = torch.load(ckpt, map_location=device) - load_msg = model.module.load_state_dict( - checkpoint["model_state_dict"], strict=False - ) + load_msg = model.module.load_state_dict(checkpoint["model_state_dict"], strict=False) load_state_dict_error_handler(load_msg) elif conf["predict"]["mode"] == "fsdp": @@ -234,14 +232,11 @@ def predict(rank, world_size, conf, p): meta_data = load_metadata(conf) # Set up metrics and containers - metrics = LatWeightedMetrics(conf, predict_mode=True) + metrics = LatWeightedMetrics(conf, training_mode=False) metrics_results = defaultdict(list) # Set up the diffusion and pole filters - if ( - "use_laplace_filter" in conf["predict"] - and conf["predict"]["use_laplace_filter"] - ): + if "use_laplace_filter" in conf["predict"] and conf["predict"]["use_laplace_filter"]: dpf = Diffusion_and_Pole_Filter( nlat=conf["model"]["image_height"], nlon=conf["model"]["image_width"], @@ -268,11 +263,7 @@ def predict(rank, world_size, conf, p): # combine x and x_surf # input: (batch_num, time, var, level, lat, lon), (batch_num, time, var, lat, lon) # output: (batch_num, var, time, lat, lon), 'x' first and then 'x_surf' - x = ( - concat_and_reshape(batch["x"], batch["x_surf"]) - .to(device) - .float() - ) + x = concat_and_reshape(batch["x"], batch["x_surf"]).to(device).float() else: # no x_surf x = reshape_only(batch["x"]).to(device).float() @@ -284,9 +275,7 @@ def predict(rank, world_size, conf, p): # add forcing and static variables (regardless of fcst hours) if "x_forcing_static" in batch: # (batch_num, time, var, lat, lon) --> (batch_num, var, time, lat, lon) - x_forcing_batch = ( - batch["x_forcing_static"].to(device).permute(0, 2, 1, 3, 4).float() - ) + x_forcing_batch = batch["x_forcing_static"].to(device).permute(0, 2, 1, 3, 4).float() # concat on var dimension x = torch.cat((x, x_forcing_batch), dim=1) @@ -348,29 +337,17 @@ def predict(rank, world_size, conf, p): # y_target with unit y = state_transformer.inverse_transform(y.cpu()) - if ( - "use_laplace_filter" in conf["predict"] - and conf["predict"]["use_laplace_filter"] - ): - y_pred = ( - dpf.diff_lap2d_filt(y_pred.to(device).squeeze()) - .unsqueeze(0) - .unsqueeze(2) - .cpu() - ) + if "use_laplace_filter" in conf["predict"] and conf["predict"]["use_laplace_filter"]: + y_pred = 
dpf.diff_lap2d_filt(y_pred.to(device).squeeze()).unsqueeze(0).unsqueeze(2).cpu() # Compute metrics - metrics_dict = metrics( - y_pred.float(), y.float(), forecast_datetime=forecast_hour - ) + metrics_dict = metrics(y_pred.float(), y.float(), forecast_datetime=forecast_hour) for k, m in metrics_dict.items(): metrics_results[k].append(m.item()) metrics_results["forecast_hour"].append(forecast_hour) # Save the current forecast hour data in parallel - utc_datetime = init_datetime + timedelta( - hours=lead_time_periods * forecast_hour - ) + utc_datetime = init_datetime + timedelta(hours=lead_time_periods * forecast_hour) # convert the current step result as x-array darray_upper_air, darray_single_level = make_xarray( @@ -424,9 +401,7 @@ def predict(rank, world_size, conf, p): # cut diagnostic vars from y_pred, they are not inputs if "y_diag" in batch: - x = torch.cat( - [x_detach, y_pred[:, :-varnum_diag, ...].detach()], dim=2 - ) + x = torch.cat([x_detach, y_pred[:, :-varnum_diag, ...].detach()], dim=2) else: x = torch.cat([x_detach, y_pred.detach()], dim=2) # ============================================================ # @@ -441,16 +416,10 @@ def predict(rank, world_size, conf, p): result.get() # save metrics file - save_location = os.path.join( - os.path.expandvars(conf["save_loc"]), "forecasts", "metrics" - ) - os.makedirs( - save_location, exist_ok=True - ) # should already be made above + save_location = os.path.join(os.path.expandvars(conf["save_loc"]), "forecasts", "metrics") + os.makedirs(save_location, exist_ok=True) # should already be made above df = pd.DataFrame(metrics_results) - df.to_csv( - os.path.join(save_location, f"metrics{init_datetime_str}.csv") - ) + df.to_csv(os.path.join(save_location, f"metrics{init_datetime_str}.csv")) # forecast count = a constant for each run forecast_count += 1 @@ -563,17 +532,13 @@ def predict(rank, world_size, conf, p): # ======================================================== # # handling config args - conf = credit_main_parser( - conf, parse_training=False, parse_predict=True, print_summary=False - ) + conf = credit_main_parser(conf, parse_training=False, parse_predict=True, print_summary=False) predict_data_check(conf, print_summary=False) # ======================================================== # # create a save location for rollout # ---------------------------------------------------- # - assert ( - "save_forecast" in conf["predict"] - ), "Please specify the output dir through conf['predict']['save_forecast']" + assert "save_forecast" in conf["predict"], "Please specify the output dir through conf['predict']['save_forecast']" forecast_save_loc = conf["predict"]["save_forecast"] os.makedirs(forecast_save_loc, exist_ok=True) diff --git a/credit/models/base_model.py b/credit/models/base_model.py index b353185d..ea5748de 100644 --- a/credit/models/base_model.py +++ b/credit/models/base_model.py @@ -61,15 +61,11 @@ def load_model(cls, conf): if os.path.isfile(os.path.join(save_loc, "model_checkpoint.pt")): ckpt = os.path.join(save_loc, "model_checkpoint.pt") - fsdp = True else: ckpt = os.path.join(save_loc, "checkpoint.pt") - fsdp = False if not os.path.isfile(ckpt): - raise ValueError( - "No saved checkpoint exists. You must train a model first. Exiting." - ) + raise ValueError("No saved checkpoint exists. You must train a model first. 
Exiting.") logging.info(f"Loading a model with pre-trained weights from path {ckpt}") @@ -82,11 +78,12 @@ def load_model(cls, conf): del conf["model"]["type"] model_class = cls(**conf["model"]) - - load_msg = model_class.load_state_dict( - checkpoint if fsdp else checkpoint["model_state_dict"], - strict=False - ) + print(list(checkpoint.keys())) + if "model_state_dict" in checkpoint.keys(): + print("model_state_dict in keys") + load_msg = model_class.load_state_dict(checkpoint, strict=False) + else: + load_msg = model_class.load_state_dict(checkpoint, strict=False) load_state_dict_error_handler(load_msg) return model_class From 0a918d192e36496740083a817b40a0cba574c7aa Mon Sep 17 00:00:00 2001 From: "djgagne@ou.edu" Date: Tue, 18 Feb 2025 14:07:46 -0700 Subject: [PATCH 09/13] removed print statement --- credit/models/base_model.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/credit/models/base_model.py b/credit/models/base_model.py index ea5748de..6084a3dc 100644 --- a/credit/models/base_model.py +++ b/credit/models/base_model.py @@ -78,9 +78,7 @@ def load_model(cls, conf): del conf["model"]["type"] model_class = cls(**conf["model"]) - print(list(checkpoint.keys())) if "model_state_dict" in checkpoint.keys(): - print("model_state_dict in keys") load_msg = model_class.load_state_dict(checkpoint, strict=False) else: load_msg = model_class.load_state_dict(checkpoint, strict=False) From cc64001d0ee9eb3b6e56e7955afc200ca8150606 Mon Sep 17 00:00:00 2001 From: "djgagne@ou.edu" Date: Tue, 18 Feb 2025 14:25:02 -0700 Subject: [PATCH 10/13] Changed default of use_spectral_norm in parser to True --- credit/parser.py | 478 ++++++++++++----------------------------------- 1 file changed, 118 insertions(+), 360 deletions(-) diff --git a/credit/parser.py b/credit/parser.py index 42b66685..b2ea0702 100644 --- a/credit/parser.py +++ b/credit/parser.py @@ -29,16 +29,10 @@ def remove_string_by_pattern(list_string, pattern): if pattern in single_string: pattern_collection.append(single_string) - return [ - single_string - for single_string in list_string - if single_string not in pattern_collection - ] + return [single_string for single_string in list_string if single_string not in pattern_collection] -def credit_main_parser( - conf, parse_training=True, parse_predict=True, print_summary=False -): +def credit_main_parser(conf, parse_training=True, parse_predict=True, print_summary=False): """ Parses and validates the configuration input for the CREDIT project. 
@@ -68,14 +62,10 @@ def credit_main_parser( """ - assert ( - "save_loc" in conf - ), "save location of the CREDIT project ('save_loc') is missing from conf" + assert "save_loc" in conf, "save location of the CREDIT project ('save_loc') is missing from conf" assert "data" in conf, "data section ('data') is missing from conf" assert "model" in conf, "model section ('model') is missing from conf" - assert ( - "latitude_weights" in conf["loss"] - ), "lat / lon file ('latitude_weights') is missing from conf['loss']" + assert "latitude_weights" in conf["loss"], "lat / lon file ('latitude_weights') is missing from conf['loss']" if parse_training: assert "trainer" in conf, "trainer section ('trainer') is missing from conf" @@ -90,9 +80,7 @@ def credit_main_parser( # conf['data'] section # must have upper-air variables - assert ( - "variables" in conf["data"] - ), "upper-air variable names ('variables') is missing from conf['data']" + assert "variables" in conf["data"], "upper-air variable names ('variables') is missing from conf['data']" if (conf["data"]["variables"] is None) or (len(conf["data"]["variables"]) == 0): print( @@ -102,9 +90,7 @@ def credit_main_parser( ) raise - assert ( - "save_loc" in conf["data"] - ), "upper-air var save locations ('save_loc') is missing from conf['data']" + assert "save_loc" in conf["data"], "upper-air var save locations ('save_loc') is missing from conf['data']" if conf["data"]["save_loc"] is None: print( @@ -118,9 +104,7 @@ def credit_main_parser( if "levels" in conf["model"]: conf["data"]["levels"] = conf["model"]["levels"] else: - print( - "number of upper-air levels ('levels') is missing from both conf['data'] and conf['model']" - ) + print("number of upper-air levels ('levels') is missing from both conf['data'] and conf['model']") raise # ========================================================================================= # # Check other input / output variable types @@ -236,9 +220,7 @@ def credit_main_parser( varname_counts = Counter(all_varnames) duplicates = [varname for varname, count in varname_counts.items() if count > 1] - assert ( - len(duplicates) == 0 - ), "Duplicated variable names: [{}] found. No duplicates allowed, stop.".format( + assert len(duplicates) == 0, "Duplicated variable names: [{}] found. 
No duplicates allowed, stop.".format( duplicates ) @@ -254,15 +236,11 @@ def credit_main_parser( conf["data"].setdefault("data_clamp", None) if parse_training: - assert ( - "train_years" in conf["data"] - ), "year range for training ('train_years') is missing from conf['data']" + assert "train_years" in conf["data"], "year range for training ('train_years') is missing from conf['data']" # 'valid_years' is required even for conf['trainer']['skip_validation']: True # 'valid_years' and 'train_years' can overlap - assert ( - "valid_years" in conf["data"] - ), "year range for validation ('valid_years') is missing from conf['data']" + assert "valid_years" in conf["data"], "year range for validation ('valid_years') is missing from conf['data']" assert ( "forecast_len" in conf["data"] @@ -289,21 +267,15 @@ def credit_main_parser( if "total_time_steps" not in conf["data"]: conf["data"]["total_time_steps"] = conf["data"]["forecast_len"] - assert ( - "history_len" in conf["data"] - ), "Number of input time frames ('history_len') is missing from conf['data']" + assert "history_len" in conf["data"], "Number of input time frames ('history_len') is missing from conf['data']" assert ( "lead_time_periods" in conf["data"] ), "Number of forecast hours ('lead_time_periods') is missing from conf['data']" assert "scaler_type" in conf["data"], "'scaler_type' is missing from conf['data']" if conf["data"]["scaler_type"] == "std_new": - assert ( - "mean_path" in conf["data"] - ), "The z-score mean file ('mean_path') is missing from conf['data']" - assert ( - "std_path" in conf["data"] - ), "The z-score std file ('std_path') is missing from conf['data']" + assert "mean_path" in conf["data"], "The z-score mean file ('mean_path') is missing from conf['data']" + assert "std_path" in conf["data"], "The z-score std file ('std_path') is missing from conf['data']" # skip_periods if ("skip_periods" not in conf["data"]) or (conf["data"]["skip_periods"] is None): @@ -319,14 +291,10 @@ def credit_main_parser( # conf['model'] section # spectral norm default to false - conf["model"].setdefault("use_spectral_norm", False) + conf["model"].setdefault("use_spectral_norm", True) - if (conf["model"]["type"] == "fuxi") and ( - conf["model"]["use_spectral_norm"] is False - ): - warnings.warn( - "FuXi may not work with 'use_spectral_norm: False' in fsdp training." - ) + if (conf["model"]["type"] == "fuxi") and (conf["model"]["use_spectral_norm"] is False): + warnings.warn("FuXi may not work with 'use_spectral_norm: False' in fsdp training.") # use interpolation if "interp" not in conf["model"]: @@ -356,12 +324,8 @@ def credit_main_parser( pad_lon = conf["model"]["padding_conf"]["pad_lon"] pad_lat = conf["model"]["padding_conf"]["pad_lat"] - assert all( - p >= 0 for p in pad_lon - ), "padding size for longitude dim must be non-negative." - assert all( - p >= 0 for p in pad_lat - ), "padding size for latitude dim must be non-negative." + assert all(p >= 0 for p in pad_lon), "padding size for longitude dim must be non-negative." + assert all(p >= 0 for p in pad_lat), "padding size for latitude dim must be non-negative." 
assert conf["model"]["padding_conf"]["mode"] in [ "mirror", @@ -386,17 +350,13 @@ def credit_main_parser( # see if any of the postconfs want to be activated post_conf = conf["model"]["post_conf"] - activate_any = any( - [post_conf[post_module]["activate"] for post_module in post_list] - ) + activate_any = any([post_conf[post_module]["activate"] for post_module in post_list]) if post_conf["activate"] and not activate_any: raise ("post_conf is set activate, but no post modules specified") if conf["model"]["post_conf"]["activate"]: # copy only model configs to post_conf subdictionary - conf["model"]["post_conf"]["model"] = { - k: v for k, v in conf["model"].items() if k != "post_conf" - } + conf["model"]["post_conf"]["model"] = {k: v for k, v in conf["model"].items() if k != "post_conf"} # copy data configs to post_conf (for de-normalize variables) conf["model"]["post_conf"]["data"] = {k: v for k, v in conf["data"].items()} @@ -428,14 +388,12 @@ def credit_main_parser( + conf["data"]["static_variables"] ) - varname_output += ( - conf["data"]["surface_variables"] + conf["data"]["diagnostic_variables"] - ) + varname_output += conf["data"]["surface_variables"] + conf["data"]["diagnostic_variables"] # # debug only conf["model"]["post_conf"]["varname_input"] = varname_input conf["model"]["post_conf"]["varname_output"] = varname_output - + # --------------------------------------------------------------------- # # SKEBS @@ -444,10 +402,7 @@ def credit_main_parser( # --------------------------------------------------------------------- # # tracer fixer - flag_tracer = ( - conf["model"]["post_conf"]["activate"] - and conf["model"]["post_conf"]["tracer_fixer"]["activate"] - ) + flag_tracer = conf["model"]["post_conf"]["activate"] and conf["model"]["post_conf"]["tracer_fixer"]["activate"] if flag_tracer: # when tracer fixer is on, get tensor indices of tracers @@ -476,26 +431,18 @@ def credit_main_parser( # --------------------------------------------------------------------- # # global mass fixer - flag_mass = ( - conf["model"]["post_conf"]["activate"] - and conf["model"]["post_conf"]["global_mass_fixer"]["activate"] - ) + flag_mass = conf["model"]["post_conf"]["activate"] and conf["model"]["post_conf"]["global_mass_fixer"]["activate"] if flag_mass: # when global mass fixer is on, get tensor indices of q, precip, evapor # these variables must be outputs # global mass fixer defaults - conf["model"]["post_conf"]["global_mass_fixer"].setdefault( - "activate_outside_model", False - ) + conf["model"]["post_conf"]["global_mass_fixer"].setdefault("activate_outside_model", False) conf["model"]["post_conf"]["global_mass_fixer"].setdefault("denorm", True) conf["model"]["post_conf"]["global_mass_fixer"].setdefault("simple_demo", False) conf["model"]["post_conf"]["global_mass_fixer"].setdefault("midpoint", False) - conf["model"]["post_conf"]["global_mass_fixer"].setdefault( - "grid_type", "pressure" - ) - + conf["model"]["post_conf"]["global_mass_fixer"].setdefault("grid_type", "pressure") assert ( "fix_level_num" in conf["model"]["post_conf"]["global_mass_fixer"] @@ -508,56 +455,39 @@ def credit_main_parser( if conf["model"]["post_conf"]["global_mass_fixer"]["grid_type"] == "sigma": assert ( - "surface_pressure_name" - in conf["model"]["post_conf"]["global_mass_fixer"] + "surface_pressure_name" in conf["model"]["post_conf"]["global_mass_fixer"] ), "Must specifiy surface pressure var name when using hybrid sigma-pressure coordinates" q_inds = [ i_var for i_var, var in enumerate(varname_output) - if var - in 
conf["model"]["post_conf"]["global_mass_fixer"][ - "specific_total_water_name" - ] + if var in conf["model"]["post_conf"]["global_mass_fixer"]["specific_total_water_name"] ] conf["model"]["post_conf"]["global_mass_fixer"]["q_inds"] = q_inds - if conf["model"]["post_conf"]["global_mass_fixer"]["grid_type"] == "sigma": sp_inds = [ i_var for i_var, var in enumerate(varname_output) - if var - in conf["model"]["post_conf"]["global_mass_fixer"][ - "surface_pressure_name" - ] + if var in conf["model"]["post_conf"]["global_mass_fixer"]["surface_pressure_name"] ] conf["model"]["post_conf"]["global_mass_fixer"]["sp_inds"] = sp_inds[0] # --------------------------------------------------------------------- # # global water fixer - flag_water = ( - conf["model"]["post_conf"]["activate"] - and conf["model"]["post_conf"]["global_water_fixer"]["activate"] - ) + flag_water = conf["model"]["post_conf"]["activate"] and conf["model"]["post_conf"]["global_water_fixer"]["activate"] if flag_water: # when global water fixer is on, get tensor indices of q, precip, evapor # these variables must be outputs # global water fixer defaults - conf["model"]["post_conf"]["global_water_fixer"].setdefault( - "activate_outside_model", False - ) + conf["model"]["post_conf"]["global_water_fixer"].setdefault("activate_outside_model", False) conf["model"]["post_conf"]["global_water_fixer"].setdefault("denorm", True) - conf["model"]["post_conf"]["global_water_fixer"].setdefault( - "simple_demo", False - ) + conf["model"]["post_conf"]["global_water_fixer"].setdefault("simple_demo", False) conf["model"]["post_conf"]["global_water_fixer"].setdefault("midpoint", False) - conf["model"]["post_conf"]["global_water_fixer"].setdefault( - "grid_type", "pressure" - ) + conf["model"]["post_conf"]["global_water_fixer"].setdefault("grid_type", "pressure") if conf["model"]["post_conf"]["global_water_fixer"]["simple_demo"] is False: assert ( @@ -566,53 +496,42 @@ def credit_main_parser( if conf["model"]["post_conf"]["global_water_fixer"]["grid_type"] == "sigma": assert ( - "surface_pressure_name" - in conf["model"]["post_conf"]["global_water_fixer"] + "surface_pressure_name" in conf["model"]["post_conf"]["global_water_fixer"] ), "Must specifiy surface pressure var name when using hybrid sigma-pressure coordinates" q_inds = [ i_var for i_var, var in enumerate(varname_output) - if var - in conf["model"]["post_conf"]["global_water_fixer"][ - "specific_total_water_name" - ] + if var in conf["model"]["post_conf"]["global_water_fixer"]["specific_total_water_name"] ] precip_inds = [ i_var for i_var, var in enumerate(varname_output) - if var - in conf["model"]["post_conf"]["global_water_fixer"]["precipitation_name"] + if var in conf["model"]["post_conf"]["global_water_fixer"]["precipitation_name"] ] evapor_inds = [ i_var for i_var, var in enumerate(varname_output) - if var - in conf["model"]["post_conf"]["global_water_fixer"]["evaporation_name"] + if var in conf["model"]["post_conf"]["global_water_fixer"]["evaporation_name"] ] conf["model"]["post_conf"]["global_water_fixer"]["q_inds"] = q_inds conf["model"]["post_conf"]["global_water_fixer"]["precip_ind"] = precip_inds[0] conf["model"]["post_conf"]["global_water_fixer"]["evapor_ind"] = evapor_inds[0] - if conf["model"]["post_conf"]["global_water_fixer"]["grid_type"] == "sigma": sp_inds = [ i_var for i_var, var in enumerate(varname_output) - if var - in conf["model"]["post_conf"]["global_water_fixer"][ - "surface_pressure_name" - ] + if var in 
conf["model"]["post_conf"]["global_water_fixer"]["surface_pressure_name"] ] conf["model"]["post_conf"]["global_water_fixer"]["sp_inds"] = sp_inds[0] # --------------------------------------------------------------------- # # global energy fixer flag_energy = ( - conf["model"]["post_conf"]["activate"] - and conf["model"]["post_conf"]["global_energy_fixer"]["activate"] + conf["model"]["post_conf"]["activate"] and conf["model"]["post_conf"]["global_energy_fixer"]["activate"] ) if flag_energy: @@ -620,45 +539,33 @@ def credit_main_parser( # geopotential at surface is input, others are outputs # global energy fixer defaults - conf["model"]["post_conf"]["global_energy_fixer"].setdefault( - "activate_outside_model", False - ) + conf["model"]["post_conf"]["global_energy_fixer"].setdefault("activate_outside_model", False) conf["model"]["post_conf"]["global_energy_fixer"].setdefault("denorm", True) - conf["model"]["post_conf"]["global_energy_fixer"].setdefault( - "simple_demo", False - ) + conf["model"]["post_conf"]["global_energy_fixer"].setdefault("simple_demo", False) conf["model"]["post_conf"]["global_energy_fixer"].setdefault("midpoint", False) - conf["model"]["post_conf"]["global_energy_fixer"].setdefault( - "grid_type", "pressure" - ) + conf["model"]["post_conf"]["global_energy_fixer"].setdefault("grid_type", "pressure") if conf["model"]["post_conf"]["global_energy_fixer"]["simple_demo"] is False: assert ( - "lon_lat_level_name" - in conf["model"]["post_conf"]["global_energy_fixer"] + "lon_lat_level_name" in conf["model"]["post_conf"]["global_energy_fixer"] ), "Must specifiy var names for lat/lon/level in physics reference file" if conf["model"]["post_conf"]["global_energy_fixer"]["grid_type"] == "sigma": assert ( - "surface_pressure_name" - in conf["model"]["post_conf"]["global_energy_fixer"] + "surface_pressure_name" in conf["model"]["post_conf"]["global_energy_fixer"] ), "Must specifiy surface pressure var name when using hybrid sigma-pressure coordinates" T_inds = [ i_var for i_var, var in enumerate(varname_output) - if var - in conf["model"]["post_conf"]["global_energy_fixer"]["air_temperature_name"] + if var in conf["model"]["post_conf"]["global_energy_fixer"]["air_temperature_name"] ] q_inds = [ i_var for i_var, var in enumerate(varname_output) - if var - in conf["model"]["post_conf"]["global_energy_fixer"][ - "specific_total_water_name" - ] + if var in conf["model"]["post_conf"]["global_energy_fixer"]["specific_total_water_name"] ] U_inds = [ @@ -676,28 +583,19 @@ def credit_main_parser( TOA_rad_inds = [ i_var for i_var, var in enumerate(varname_output) - if var - in conf["model"]["post_conf"]["global_energy_fixer"][ - "TOA_net_radiation_flux_name" - ] + if var in conf["model"]["post_conf"]["global_energy_fixer"]["TOA_net_radiation_flux_name"] ] surf_rad_inds = [ i_var for i_var, var in enumerate(varname_output) - if var - in conf["model"]["post_conf"]["global_energy_fixer"][ - "surface_net_radiation_flux_name" - ] + if var in conf["model"]["post_conf"]["global_energy_fixer"]["surface_net_radiation_flux_name"] ] surf_flux_inds = [ i_var for i_var, var in enumerate(varname_output) - if var - in conf["model"]["post_conf"]["global_energy_fixer"][ - "surface_energy_flux_name" - ] + if var in conf["model"]["post_conf"]["global_energy_fixer"]["surface_energy_flux_name"] ] conf["model"]["post_conf"]["global_energy_fixer"]["T_inds"] = T_inds @@ -705,31 +603,22 @@ def credit_main_parser( conf["model"]["post_conf"]["global_energy_fixer"]["U_inds"] = U_inds 
conf["model"]["post_conf"]["global_energy_fixer"]["V_inds"] = V_inds conf["model"]["post_conf"]["global_energy_fixer"]["TOA_rad_inds"] = TOA_rad_inds - conf["model"]["post_conf"]["global_energy_fixer"]["surf_rad_inds"] = ( - surf_rad_inds - ) - conf["model"]["post_conf"]["global_energy_fixer"]["surf_flux_inds"] = ( - surf_flux_inds - ) + conf["model"]["post_conf"]["global_energy_fixer"]["surf_rad_inds"] = surf_rad_inds + conf["model"]["post_conf"]["global_energy_fixer"]["surf_flux_inds"] = surf_flux_inds if conf["model"]["post_conf"]["global_energy_fixer"]["grid_type"] == "sigma": sp_inds = [ i_var for i_var, var in enumerate(varname_output) - if var - in conf["model"]["post_conf"]["global_energy_fixer"][ - "surface_pressure_name" - ] + if var in conf["model"]["post_conf"]["global_energy_fixer"]["surface_pressure_name"] ] conf["model"]["post_conf"]["global_energy_fixer"]["sp_inds"] = sp_inds[0] - + # --------------------------------------------------------- # # conf['trainer'] section if parse_training: - assert ( - "mode" in conf["trainer"] - ), "Resource type ('mode') is missing from conf['trainer']" + assert "mode" in conf["trainer"], "Resource type ('mode') is missing from conf['trainer']" assert conf["trainer"]["mode"] in [ "fsdp", @@ -737,16 +626,10 @@ def credit_main_parser( "none", ], "conf['trainer']['mode'] accepts fsdp, ddp, and none" - assert ( - "type" in conf["trainer"] - ), "Training strategy ('type') is missing from conf['trainer']" + assert "type" in conf["trainer"], "Training strategy ('type') is missing from conf['trainer']" - assert ( - "load_weights" in conf["trainer"] - ), "must specify 'load_weights' in conf['trainer']" - assert ( - "learning_rate" in conf["trainer"] - ), "must specify 'learning_rate' in conf['trainer']" + assert "load_weights" in conf["trainer"], "must specify 'load_weights' in conf['trainer']" + assert "learning_rate" in conf["trainer"], "must specify 'learning_rate' in conf['trainer']" assert ( "batches_per_epoch" in conf["trainer"] @@ -758,11 +641,11 @@ def credit_main_parser( if "ensemble_size" not in conf["trainer"]: conf["trainer"]["ensemble_size"] = 1 # default value of 1 means deterministic training - + if conf["trainer"]["ensemble_size"] > 1: assert ( conf["loss"]["training_loss"] in ["KCRPS"] - ), f'''{conf["loss"]["training_loss"]} loss incompatible with ensemble training. ensemble_size is {conf["trainer"]["ensemble_size"]}''' + ), f"""{conf["loss"]["training_loss"]} loss incompatible with ensemble training. 
ensemble_size is {conf["trainer"]["ensemble_size"]}""" if "load_scaler" not in conf["trainer"]: conf["trainer"]["load_scaler"] = False @@ -842,12 +725,8 @@ def credit_main_parser( conf["trainer"]["train_one_epoch"] = False if conf["trainer"]["train_one_epoch"] is False: - assert ( - "start_epoch" in conf["trainer"] - ), "must specify 'start_epoch' in conf['trainer']" - assert ( - "epochs" in conf["trainer"] - ), "must specify 'epochs' in conf['trainer']" + assert "start_epoch" in conf["trainer"], "must specify 'start_epoch' in conf['trainer']" + assert "epochs" in conf["trainer"], "must specify 'epochs' in conf['trainer']" else: conf["trainer"]["epochs"] = 999 if "num_epoch" in conf["trainer"]: @@ -859,9 +738,7 @@ def credit_main_parser( conf["trainer"]["amp"] = False if conf["trainer"]["amp"]: - assert ( - "load_scaler" in conf["trainer"] - ), "must specify 'load_scaler' in conf['trainer'] if AMP is used" + assert "load_scaler" in conf["trainer"], "must specify 'load_scaler' in conf['trainer'] if AMP is used" if "weight_decay" not in conf["trainer"]: conf["trainer"]["weight_decay"] = 0 @@ -888,15 +765,9 @@ def credit_main_parser( # conf['loss'] section if parse_training: - assert ( - "training_loss" in conf["loss"] - ), "Training loss ('training_loss') is missing from conf['loss']" - assert ( - "use_latitude_weights" in conf["loss"] - ), "must specify 'use_latitude_weights' in conf['loss']" - assert ( - "use_variable_weights" in conf["loss"] - ), "must specify 'use_variable_weights' in conf['loss']" + assert "training_loss" in conf["loss"], "Training loss ('training_loss') is missing from conf['loss']" + assert "use_latitude_weights" in conf["loss"], "must specify 'use_latitude_weights' in conf['loss']" + assert "use_variable_weights" in conf["loss"], "must specify 'use_variable_weights' in conf['loss']" if conf["loss"]["use_variable_weights"]: assert ( @@ -915,26 +786,16 @@ def credit_main_parser( varname_covered = list(conf["loss"]["variable_weights"].keys()) for varname in varname_upper_air: - assert ( - varname in varname_covered - ), "missing variable weights for '{}'".format(varname) + assert varname in varname_covered, "missing variable weights for '{}'".format(varname) N_weights = len(conf["loss"]["variable_weights"][varname]) - assert ( - N_weights == N_levels - ), "{} levels were defined, but weights only have {} levels".format( + assert N_weights == N_levels, "{} levels were defined, but weights only have {} levels".format( N_levels, N_weights ) - weights_dict_ordered[varname] = conf["loss"]["variable_weights"][ - varname - ] + weights_dict_ordered[varname] = conf["loss"]["variable_weights"][varname] for varname in varname_surface + varname_diagnostics: - assert ( - varname in varname_covered - ), "missing variable weights for '{}'".format(varname) - weights_dict_ordered[varname] = conf["loss"]["variable_weights"][ - varname - ] + assert varname in varname_covered, "missing variable weights for '{}'".format(varname) + weights_dict_ordered[varname] = conf["loss"]["variable_weights"][varname] conf["loss"]["variable_weights"] = weights_dict_ordered # ----------------------------------------------------------------------------------------- # @@ -946,9 +807,7 @@ def credit_main_parser( conf["loss"]["use_spectral_loss"] = False if conf["loss"]["use_power_loss"] and conf["loss"]["use_spectral_loss"]: - warnings.warn( - "'use_power_loss: True' and 'use_spectral_loss: True' are both applied" - ) + warnings.warn("'use_power_loss: True' and 'use_spectral_loss: True' are both applied") if 
conf["loss"]["use_power_loss"] or conf["loss"]["use_spectral_loss"]: if "spectral_lambda_reg" not in conf["loss"]: @@ -961,16 +820,12 @@ def credit_main_parser( # conf['parse_predict'] section if parse_predict: - assert ( - "forecasts" in conf["predict"] - ), "Rollout settings ('forecasts') is missing from conf['predict']" + assert "forecasts" in conf["predict"], "Rollout settings ('forecasts') is missing from conf['predict']" assert ( "save_forecast" in conf["predict"] ), "Rollout save location ('save_forecast') is missing from conf['predict']" - conf["predict"]["save_forecast"] = os.path.expandvars( - conf["predict"]["save_forecast"] - ) + conf["predict"]["save_forecast"] = os.path.expandvars(conf["predict"]["save_forecast"]) if "use_laplace_filter" not in conf["predict"]: conf["predict"]["use_laplace_filter"] = False @@ -985,9 +840,7 @@ def credit_main_parser( if "mode" in conf["trainer"]: conf["predict"]["mode"] = conf["trainer"]["mode"] else: - print( - "Resource type ('mode') is missing from both conf['trainer'] and conf['predict']" - ) + print("Resource type ('mode') is missing from both conf['trainer'] and conf['predict']") raise # ==================================================== # @@ -995,11 +848,7 @@ def credit_main_parser( if print_summary: print("Upper-air variables: {}".format(conf["data"]["variables"])) print("Surface variables: {}".format(conf["data"]["surface_variables"])) - print( - "Dynamic forcing variables: {}".format( - conf["data"]["dynamic_forcing_variables"] - ) - ) + print("Dynamic forcing variables: {}".format(conf["data"]["dynamic_forcing_variables"])) print("Diagnostic variables: {}".format(conf["data"]["diagnostic_variables"])) print("Forcing variables: {}".format(conf["data"]["forcing_variables"])) print("Static variables: {}".format(conf["data"]["static_variables"])) @@ -1033,36 +882,24 @@ def training_data_check(conf, print_summary=False): valid_years_range = conf["data"]["valid_years"] # convert year info to str for file name search - train_years = [ - str(year) for year in range(train_years_range[0], train_years_range[1]) - ] - valid_years = [ - str(year) for year in range(valid_years_range[0], valid_years_range[1]) - ] + train_years = [str(year) for year in range(train_years_range[0], train_years_range[1])] + valid_years = [str(year) for year in range(valid_years_range[0], valid_years_range[1])] # -------------------------------------------------- # # check file consistencies ## upper-air files all_ERA_files = sorted(glob(conf["data"]["save_loc"])) - train_ERA_files = [ - file for file in all_ERA_files if any(year in file for year in train_years) - ] - valid_ERA_files = [ - file for file in all_ERA_files if any(year in file for year in valid_years) - ] + train_ERA_files = [file for file in all_ERA_files if any(year in file for year in train_years)] + valid_ERA_files = [file for file in all_ERA_files if any(year in file for year in valid_years)] for i_year, year in enumerate(train_years): - assert ( - year in train_ERA_files[i_year] - ), "[Year {}] is missing from [upper-air files {}]".format( + assert year in train_ERA_files[i_year], "[Year {}] is missing from [upper-air files {}]".format( year, conf["data"]["save_loc"] ) for i_year, year in enumerate(valid_years): - assert ( - year in valid_ERA_files[i_year] - ), "[Year {}] is missing from [upper-air files {}]".format( + assert year in valid_ERA_files[i_year], "[Year {}] is missing from [upper-air files {}]".format( year, conf["data"]["save_loc"] ) @@ -1070,24 +907,16 @@ def training_data_check(conf, 
print_summary=False): if conf["data"]["flag_surface"]: surface_files = sorted(glob(conf["data"]["save_loc_surface"])) - train_surface_files = [ - file for file in surface_files if any(year in file for year in train_years) - ] - valid_surface_files = [ - file for file in surface_files if any(year in file for year in valid_years) - ] + train_surface_files = [file for file in surface_files if any(year in file for year in train_years)] + valid_surface_files = [file for file in surface_files if any(year in file for year in valid_years)] for i_year, year in enumerate(train_years): - assert ( - year in train_surface_files[i_year] - ), "[Year {}] is missing from [surface files {}]".format( + assert year in train_surface_files[i_year], "[Year {}] is missing from [surface files {}]".format( year, conf["data"]["save_loc_surface"] ) for i_year, year in enumerate(valid_years): - assert ( - year in valid_surface_files[i_year] - ), "[Year {}] is missing from [surface files {}]".format( + assert year in valid_surface_files[i_year], "[Year {}] is missing from [surface files {}]".format( year, conf["data"]["save_loc_surface"] ) @@ -1095,16 +924,8 @@ def training_data_check(conf, print_summary=False): if conf["data"]["flag_dyn_forcing"]: dyn_forcing_files = sorted(glob(conf["data"]["save_loc_dynamic_forcing"])) - train_dyn_forcing_files = [ - file - for file in dyn_forcing_files - if any(year in file for year in train_years) - ] - valid_dyn_forcing_files = [ - file - for file in dyn_forcing_files - if any(year in file for year in valid_years) - ] + train_dyn_forcing_files = [file for file in dyn_forcing_files if any(year in file for year in train_years)] + valid_dyn_forcing_files = [file for file in dyn_forcing_files if any(year in file for year in valid_years)] for i_year, year in enumerate(train_years): assert ( @@ -1124,36 +945,22 @@ def training_data_check(conf, print_summary=False): if conf["data"]["flag_diagnostic"]: diagnostic_files = sorted(glob(conf["data"]["save_loc_diagnostic"])) - train_diagnostic_files = [ - file - for file in diagnostic_files - if any(year in file for year in train_years) - ] - valid_diagnostic_files = [ - file - for file in diagnostic_files - if any(year in file for year in valid_years) - ] + train_diagnostic_files = [file for file in diagnostic_files if any(year in file for year in train_years)] + valid_diagnostic_files = [file for file in diagnostic_files if any(year in file for year in valid_years)] for i_year, year in enumerate(train_years): - assert ( - year in train_diagnostic_files[i_year] - ), "[Year {}] is missing from [diagnostic files {}]".format( + assert year in train_diagnostic_files[i_year], "[Year {}] is missing from [diagnostic files {}]".format( year, conf["data"]["save_loc_diagnostic"] ) for i_year, year in enumerate(valid_years): - assert ( - year in valid_diagnostic_files[i_year] - ), "[Year {}] is missing from [diagnostic files {}]".format( + assert year in valid_diagnostic_files[i_year], "[Year {}] is missing from [diagnostic files {}]".format( year, conf["data"]["save_loc_diagnostic"] ) if print_summary: print("Filename checking passed") - print( - "All input files can cover conf['data']['train_years'] and conf['data']['valid_years']" - ) + print("All input files can cover conf['data']['train_years'] and conf['data']['valid_years']") # --------------------------------------------------------------------------- # # variable checks @@ -1165,13 +972,11 @@ def training_data_check(conf, print_summary=False): assert all( varname in varnames_upper_air for 
varname in conf["data"]["variables"] - ), "upper-air variables [{}] are not fully covered by conf['data']['save_loc']".format( - conf["data"]["variables"] - ) + ), "upper-air variables [{}] are not fully covered by conf['data']['save_loc']".format(conf["data"]["variables"]) # assign the upper_air vars in yaml if it can pass checks varnames_upper_air = conf["data"]["variables"] - + # collecting all variables that require zscores # deep copy to avoid changing conf['data'] by accident all_vars = copy.deepcopy(conf["data"]["variables"]) @@ -1195,8 +1000,7 @@ def training_data_check(conf, print_summary=False): varnames_dyn_forcing = list(ds_dyn_forcing.keys()) assert all( - varname in varnames_dyn_forcing - for varname in conf["data"]["dynamic_forcing_variables"] + varname in varnames_dyn_forcing for varname in conf["data"]["dynamic_forcing_variables"] ), "Dynamic forcing variables [{}] are not fully covered by conf['data']['save_loc_dynamic_forcing']".format( conf["data"]["dynamic_forcing_variables"] ) @@ -1209,8 +1013,7 @@ def training_data_check(conf, print_summary=False): varnames_diagnostic = list(ds_diagnostic.keys()) assert all( - varname in varnames_diagnostic - for varname in conf["data"]["diagnostic_variables"] + varname in varnames_diagnostic for varname in conf["data"]["diagnostic_variables"] ), "Diagnostic variables [{}] are not fully covered by conf['data']['save_loc_diagnostic']".format( conf["data"]["diagnostic_variables"] ) @@ -1278,9 +1081,7 @@ def training_data_check(conf, print_summary=False): for coord_name in coord_surface: assert ds_upper_air.coords[coord_name].equals( ds_surface.coords[coord_name] - ), "coordinate {} mismatched between upper-air and surface files".format( - coord_name - ) + ), "coordinate {} mismatched between upper-air and surface files".format(coord_name) # dyn forcing files if conf["data"]["flag_dyn_forcing"]: @@ -1292,13 +1093,9 @@ def training_data_check(conf, print_summary=False): ), "Dynamic forcing file coordinate names mismatched with upper-air files" for coord_name in coord_dyn_forcing: - assert ds_upper_air.coords[ - coord_name - ].equals( + assert ds_upper_air.coords[coord_name].equals( ds_dyn_forcing.coords[coord_name] - ), "coordinate {} mismatched between upper-air and dynamic forcing files".format( - coord_name - ) + ), "coordinate {} mismatched between upper-air and dynamic forcing files".format(coord_name) # diagnostic files if conf["data"]["flag_diagnostic"]: @@ -1312,9 +1109,7 @@ def training_data_check(conf, print_summary=False): for coord_name in coord_diagnostic: assert ds_upper_air.coords[coord_name].equals( ds_diagnostic.coords[coord_name] - ), "coordinate {} mismatched between upper-air and diagnostic files".format( - coord_name - ) + ), "coordinate {} mismatched between upper-air and diagnostic files".format(coord_name) # forcing files if conf["data"]["flag_forcing"]: @@ -1328,9 +1123,7 @@ def training_data_check(conf, print_summary=False): for coord_name in coord_forcing: assert ds_upper_air.coords[coord_name].equals( ds_forcing.coords[coord_name] - ), "coordinate {} mismatched between upper-air and forcing files".format( - coord_name - ) + ), "coordinate {} mismatched between upper-air and forcing files".format(coord_name) # ============================================== # # !! assumed subdaily inputs, may need to fix !! 
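#

As context for the coordinate checks above and below: they rely on xarray's exact comparison, where `DataArray.equals` returns True only when dimensions and values match exactly (attributes are ignored). A minimal sketch of the pattern, using synthetic stand-in datasets rather than the real CREDIT files (the latitude values are illustrative):

    import numpy as np
    import xarray as xr

    lat = np.linspace(-90.0, 90.0, 4)
    ds_upper_air = xr.Dataset(coords={"latitude": lat})
    ds_surface = xr.Dataset(coords={"latitude": lat})  # same grid: the assert passes;
                                                       # any offset or length change would trip it

    for coord_name in ds_surface.coords:
        assert ds_upper_air.coords[coord_name].equals(
            ds_surface.coords[coord_name]
        ), "coordinate {} mismatched between upper-air and surface files".format(coord_name)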
@@ -1349,9 +1142,7 @@ def training_data_check(conf, print_summary=False):
         for coord_name in coord_static:
             assert ds_upper_air.coords[coord_name].equals(
                 ds_static.coords[coord_name]
-            ), "coordinate {} mismatched between upper-air and static files".format(
-                coord_name
-            )
+            ), "coordinate {} mismatched between upper-air and static files".format(coord_name)
 
     # zscore mean file (no time coordinate)
     coord_mean = list(ds_mean.coords.keys())
@@ -1364,9 +1155,7 @@ def training_data_check(conf, print_summary=False):
     for coord_name in coord_mean:
         assert ds_upper_air.coords[coord_name].equals(
             ds_mean.coords[coord_name]
-        ), "coordinate {} mismatched between upper-air and mean files".format(
-            coord_name
-        )
+        ), "coordinate {} mismatched between upper-air and mean files".format(coord_name)
 
     # zscore std file (no time coordinate)
     coord_std = list(ds_std.coords.keys())
@@ -1430,9 +1219,7 @@ def predict_data_check(conf, print_summary=False):
     # a rough estimate of how many years of initializations are needed
     # !!! Can be improved !!!
     if "duration" in conf["predict"]["forecasts"]:
-        assert (
-            "start_year" in conf["predict"]["forecasts"]
-        ), "Must specify which year to start predict."
+        assert "start_year" in conf["predict"]["forecasts"], "Must specify which year predictions start."
         N_years = conf["predict"]["forecasts"]["duration"] // 365
         N_years = N_years + 1
     else:
@@ -1448,42 +1235,28 @@ def predict_data_check(conf, print_summary=False):
     ## upper-air files
     all_ERA_files = sorted(glob(conf["data"]["save_loc"]))
-    pred_ERA_files = [
-        file for file in all_ERA_files if any(year in file for year in pred_years)
-    ]
+    pred_ERA_files = [file for file in all_ERA_files if any(year in file for year in pred_years)]
 
     if len(pred_years) != len(pred_ERA_files):
-        warnings.warn(
-            "Provided initializations in upper air files may not cover all forecasted dates"
-        )
+        warnings.warn("Provided initializations in upper air files may not cover all forecasted dates")
 
     ## surface files
     if conf["data"]["flag_surface"]:
         surface_files = sorted(glob(conf["data"]["save_loc_surface"]))
-        pred_surface_files = [
-            file for file in surface_files if any(year in file for year in pred_years)
-        ]
+        pred_surface_files = [file for file in surface_files if any(year in file for year in pred_years)]
 
         if len(pred_years) != len(pred_surface_files):
-            warnings.warn(
-                "Provided initializations in surface files may not cover all forecasted dates"
-            )
+            warnings.warn("Provided initializations in surface files may not cover all forecasted dates")
 
     ## dynamic forcing files
     if conf["data"]["flag_dyn_forcing"]:
         dyn_forcing_files = sorted(glob(conf["data"]["save_loc_dynamic_forcing"]))
-        pred_dyn_forcing_files = [
-            file
-            for file in dyn_forcing_files
-            if any(year in file for year in pred_years)
-        ]
+        pred_dyn_forcing_files = [file for file in dyn_forcing_files if any(year in file for year in pred_years)]
 
         if len(pred_years) != len(pred_dyn_forcing_files):
-            warnings.warn(
-                "Provided initializations in surface files may not cover all forecasted dates"
-            )
+            warnings.warn("Provided initializations in dynamic forcing files may not cover all forecasted dates")
 
     if print_summary:
         print("Filename checking passed")
@@ -1499,9 +1272,7 @@ def predict_data_check(conf, print_summary=False):
     assert all(
         varname in varnames_upper_air for varname in conf["data"]["variables"]
-    ), "upper-air variables [{}] are not fully covered by conf['data']['save_loc']".format(
-        conf["data"]["variables"]
-    )
+    ), "upper-air variables [{}] are not fully covered by 
conf['data']['save_loc']".format(conf["data"]["variables"]) # collecting all variables that require zscores # deep copy to avoid changing conf['data'] by accident @@ -1526,8 +1297,7 @@ def predict_data_check(conf, print_summary=False): varnames_dyn_forcing = list(ds_dyn_forcing.keys()) assert all( - varname in varnames_dyn_forcing - for varname in conf["data"]["dynamic_forcing_variables"] + varname in varnames_dyn_forcing for varname in conf["data"]["dynamic_forcing_variables"] ), "Dynamic forcing variables [{}] are not fully covered by conf['data']['save_loc_dynamic_forcing']".format( conf["data"]["dynamic_forcing_variables"] ) @@ -1599,9 +1369,7 @@ def predict_data_check(conf, print_summary=False): for coord_name in coord_surface: assert ds_upper_air.coords[coord_name].equals( ds_surface.coords[coord_name] - ), "coordinate {} mismatched between upper-air and surface files".format( - coord_name - ) + ), "coordinate {} mismatched between upper-air and surface files".format(coord_name) # dyn forcing files if conf["data"]["flag_dyn_forcing"]: @@ -1613,13 +1381,9 @@ def predict_data_check(conf, print_summary=False): ), "Dynamic forcing file coordinate names mismatched with upper-air files" for coord_name in coord_dyn_forcing: - assert ds_upper_air.coords[ - coord_name - ].equals( + assert ds_upper_air.coords[coord_name].equals( ds_dyn_forcing.coords[coord_name] - ), "coordinate {} mismatched between upper-air and dynamic forcing files".format( - coord_name - ) + ), "coordinate {} mismatched between upper-air and dynamic forcing files".format(coord_name) # forcing files if conf["data"]["flag_forcing"]: @@ -1633,9 +1397,7 @@ def predict_data_check(conf, print_summary=False): for coord_name in coord_forcing: assert ds_upper_air.coords[coord_name].equals( ds_forcing.coords[coord_name] - ), "coordinate {} mismatched between upper-air and forcing files".format( - coord_name - ) + ), "coordinate {} mismatched between upper-air and forcing files".format(coord_name) # ============================================== # # !! assumed subdaily inputs, may need to fix !! 
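
The file-coverage screens in `predict_data_check` reduce to simple year arithmetic: the forecast duration in days is rounded up to whole years, and a candidate file counts as coverage when its name contains one of those years. A small sketch of that logic under assumed inputs (the paths and the construction of `pred_years` from `start_year` are illustrative):

    import warnings

    duration_days = 400  # conf["predict"]["forecasts"]["duration"]
    start_year = 2018    # conf["predict"]["forecasts"]["start_year"]
    N_years = duration_days // 365 + 1
    pred_years = [str(start_year + i) for i in range(N_years)]

    all_ERA_files = ["/data/ERA5_2018.zarr", "/data/ERA5_2019.zarr"]
    pred_ERA_files = [f for f in all_ERA_files if any(year in f for year in pred_years)]

    if len(pred_years) != len(pred_ERA_files):
        warnings.warn("Provided initializations in upper air files may not cover all forecasted dates")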
# @@ -1654,9 +1416,7 @@ def predict_data_check(conf, print_summary=False): for coord_name in coord_static: assert ds_upper_air.coords[coord_name].equals( ds_static.coords[coord_name] - ), "coordinate {} mismatched between upper-air and static files".format( - coord_name - ) + ), "coordinate {} mismatched between upper-air and static files".format(coord_name) # zscore mean file (no time coordinate) coord_mean = list(ds_mean.coords.keys()) @@ -1669,9 +1429,7 @@ def predict_data_check(conf, print_summary=False): for coord_name in coord_mean: assert ds_upper_air.coords[coord_name].equals( ds_mean.coords[coord_name] - ), "coordinate {} mismatched between upper-air and mean files".format( - coord_name - ) + ), "coordinate {} mismatched between upper-air and mean files".format(coord_name) # zscore std file (no time coordinate) coord_std = list(ds_std.coords.keys()) From 3c0bacdfb6c2455db9d638c5dd43bdb777633b5b Mon Sep 17 00:00:00 2001 From: David John Gagne Date: Tue, 18 Feb 2025 15:48:12 -0700 Subject: [PATCH 11/13] removed metric calc in rollout_to_netcdf --- applications/rollout_to_netcdf.py | 38 +- credit/data.py | 629 +++--------------------------- 2 files changed, 63 insertions(+), 604 deletions(-) diff --git a/applications/rollout_to_netcdf.py b/applications/rollout_to_netcdf.py index 08e7d644..51b59c9f 100644 --- a/applications/rollout_to_netcdf.py +++ b/applications/rollout_to_netcdf.py @@ -8,14 +8,12 @@ from pathlib import Path from argparse import ArgumentParser import multiprocessing as mp -from collections import defaultdict # ---------- # # Numerics from datetime import datetime, timedelta import xarray as xr import numpy as np -import pandas as pd # ---------- # import torch @@ -35,7 +33,6 @@ from credit.transforms import load_transforms, Normalize_ERA5_and_Forcing from credit.pbs import launch_script, launch_script_mpi from credit.pol_lapdiff_filt import Diffusion_and_Pole_Filter -from credit.metrics import LatWeightedMetrics from credit.forecast import load_forecasts from credit.distributed import distributed_model_wrapper, setup from credit.models.checkpoint import load_model_state, load_state_dict_error_handler @@ -222,6 +219,8 @@ def predict(rank, world_size, conf, p): model = distributed_model_wrapper(conf, model, device) # Load model weights (if any), an optimizer, scheduler, and gradient scaler model = load_model_state(conf, model, device) + else: + model = None # ================================================================================ # model.eval() @@ -230,11 +229,6 @@ def predict(rank, world_size, conf, p): latlons = xr.open_dataset(conf["loss"]["latitude_weights"]) meta_data = load_metadata(conf) - - # Set up metrics and containers - metrics = LatWeightedMetrics(conf, training_mode=False) - metrics_results = defaultdict(list) - # Set up the diffusion and pole filters if "use_laplace_filter" in conf["predict"] and conf["predict"]["use_laplace_filter"]: dpf = Diffusion_and_Pole_Filter( @@ -334,18 +328,10 @@ def predict(rank, world_size, conf, p): # y_pred with unit y_pred = state_transformer.inverse_transform(y_pred.cpu()) - # y_target with unit - y = state_transformer.inverse_transform(y.cpu()) if "use_laplace_filter" in conf["predict"] and conf["predict"]["use_laplace_filter"]: y_pred = dpf.diff_lap2d_filt(y_pred.to(device).squeeze()).unsqueeze(0).unsqueeze(2).cpu() - # Compute metrics - metrics_dict = metrics(y_pred.float(), y.float(), forecast_datetime=forecast_hour) - for k, m in metrics_dict.items(): - metrics_results[k].append(m.item()) - 
metrics_results["forecast_hour"].append(forecast_hour) - # Save the current forecast hour data in parallel utc_datetime = init_datetime + timedelta(hours=lead_time_periods * forecast_hour) @@ -372,13 +358,6 @@ def predict(rank, world_size, conf, p): ) results.append(result) - metrics_results["datetime"].append(utc_datetime) - - print_str = f"Forecast: {forecast_count} " - print_str += f"Date: {utc_datetime.strftime('%Y-%m-%d %H:%M:%S')} " - print_str += f"Hour: {batch['forecast_hour'].item()} " - print_str += f"ACC: {metrics_dict['acc']} " - # Update the input # setup for next iteration, transform to z-space and send to device y_pred = state_transformer.transform_array(y_pred).to(device) @@ -414,13 +393,6 @@ def predict(rank, world_size, conf, p): # Wait for all processes to finish in order for result in results: result.get() - - # save metrics file - save_location = os.path.join(os.path.expandvars(conf["save_loc"]), "forecasts", "metrics") - os.makedirs(save_location, exist_ok=True) # should already be made above - df = pd.DataFrame(metrics_results) - df.to_csv(os.path.join(save_location, f"metrics{init_datetime_str}.csv")) - # forecast count = a constant for each run forecast_count += 1 @@ -592,6 +564,6 @@ def predict(rank, world_size, conf, p): else: # single device inference _ = predict(0, 1, conf, p=p) - # Ensure all processes are finished - p.close() - p.join() + # Ensure all processes are finished + p.close() + p.join() diff --git a/credit/data.py b/credit/data.py index dbc392f2..d06aa540 100644 --- a/credit/data.py +++ b/credit/data.py @@ -62,10 +62,7 @@ def nanoseconds_to_year(nanoseconds_value): """ Given datetime info as nanoseconds, compute which year it belongs to. """ - return ( - np.datetime64(nanoseconds_value, "ns").astype("datetime64[Y]").astype(int) - + 1970 - ) + return np.datetime64(nanoseconds_value, "ns").astype("datetime64[Y]").astype(int) + 1970 def extract_month_day_hour(dates): @@ -73,9 +70,7 @@ def extract_month_day_hour(dates): Given an 1-d array of np.datatime64[ns], extract their mon, day, hr into a zipped list """ months = dates.astype("datetime64[M]").astype(int) % 12 + 1 - days = ( - (dates - dates.astype("datetime64[M]") + 1).astype("timedelta64[D]").astype(int) - ) + days = (dates - dates.astype("datetime64[M]") + 1).astype("timedelta64[D]").astype(int) hours = dates.astype("datetime64[h]").astype(int) % 24 return list(zip(months, days, hours)) @@ -98,9 +93,7 @@ def concat_and_reshape(x1, x2): """ Flattening the "level" coordinate of upper-air variables and concatenate it will surface variables. """ - x1 = x1.view( - x1.shape[0], x1.shape[1], x1.shape[2] * x1.shape[3], x1.shape[4], x1.shape[5] - ) + x1 = x1.view(x1.shape[0], x1.shape[1], x1.shape[2] * x1.shape[3], x1.shape[4], x1.shape[5]) x_concat = torch.cat((x1, x2), dim=2) return x_concat.permute(0, 2, 1, 3, 4) @@ -110,9 +103,7 @@ def reshape_only(x1): Flattening the "level" coordinate of upper-air variables. 
As in "concat_and_reshape", but no concat """ - x1 = x1.view( - x1.shape[0], x1.shape[1], x1.shape[2] * x1.shape[3], x1.shape[4], x1.shape[5] - ) + x1 = x1.view(x1.shape[0], x1.shape[1], x1.shape[2] * x1.shape[3], x1.shape[4], x1.shape[5]) return x1.permute(0, 2, 1, 3, 4) @@ -424,9 +415,7 @@ def __init__( if self.filename_forcing is not None: # drop variables if they are not in the config ds = get_forward_data(filename_forcing) - ds_forcing = drop_var_from_dataset( - ds, varname_forcing - ).load() # <---- load in static + ds_forcing = drop_var_from_dataset(ds, varname_forcing).load() # <---- load in static self.xarray_forcing = ds_forcing else: @@ -439,9 +428,7 @@ def __init__( if self.filename_static is not None: # drop variables if they are not in the config ds = get_forward_data(filename_static) - ds_static = drop_var_from_dataset( - ds, varname_static - ).load() # <---- load in static + ds_static = drop_var_from_dataset(ds, varname_static).load() # <---- load in static self.xarray_static = ds_static else: @@ -470,9 +457,7 @@ def __getitem__(self, index): ind_start_in_file = index - ind_start # handle out-of-bounds - ind_largest = len(self.all_files[int(ind_file)]["time"]) - ( - self.history_len + self.forecast_len + 1 - ) + ind_largest = len(self.all_files[int(ind_file)]["time"]) - (self.history_len + self.forecast_len + 1) if ind_start_in_file > ind_largest: ind_start_in_file = ind_largest @@ -496,9 +481,7 @@ def __getitem__(self, index): ) # .load() NOT load into memory ## merge upper-air and surface here: - ERA5_subset = ERA5_subset.merge( - surface_subset - ) # <-- lazy merge, ERA5 and surface both not loaded + ERA5_subset = ERA5_subset.merge(surface_subset) # <-- lazy merge, ERA5 and surface both not loaded # ==================================================== # # split ERA5_subset into training inputs and targets @@ -536,17 +519,11 @@ def __getitem__(self, index): # ------------------------------------------------------------------------------- # # matching month, day, hour between forcing and upper air [time] # this approach handles leap year forcing file and non-leap-year upper air file - month_day_forcing = extract_month_day_hour( - np.array(self.xarray_forcing["time"]) - ) - month_day_inputs = extract_month_day_hour( - np.array(historical_ERA5_images["time"]) - ) # <-- upper air + month_day_forcing = extract_month_day_hour(np.array(self.xarray_forcing["time"])) + month_day_inputs = extract_month_day_hour(np.array(historical_ERA5_images["time"])) # <-- upper air # indices to subset ind_forcing, _ = find_common_indices(month_day_forcing, month_day_inputs) - forcing_subset_input = self.xarray_forcing.isel( - time=ind_forcing - ) # .load() # <-- loadded in init + forcing_subset_input = self.xarray_forcing.isel(time=ind_forcing) # .load() # <-- loadded in init # forcing and upper air have different years but the same mon/day/hour # safely replace forcing time with upper air time forcing_subset_input["time"] = historical_ERA5_images["time"] @@ -560,13 +537,9 @@ def __getitem__(self, index): if self.xarray_static: # expand static var on time dim N_time_dims = len(ERA5_subset["time"]) - static_subset_input = self.xarray_static.expand_dims( - dim={"time": N_time_dims} - ) + static_subset_input = self.xarray_static.expand_dims(dim={"time": N_time_dims}) # assign coords 'time' - static_subset_input = static_subset_input.assign_coords( - {"time": ERA5_subset["time"]} - ) + static_subset_input = static_subset_input.assign_coords({"time": ERA5_subset["time"]}) # slice + load to the GPU 
static_subset_input = static_subset_input.isel( @@ -585,9 +558,7 @@ def __getitem__(self, index): if self.one_shot is not None: # one_shot is True (on), go straight to the last element - target_ERA5_images = ERA5_subset.isel( - time=slice(-1, None) - ).load() # <-- load into memory + target_ERA5_images = ERA5_subset.isel(time=slice(-1, None)).load() # <-- load into memory ## merge diagnoisc input here: if self.diagnostic_files: @@ -595,9 +566,7 @@ def __getitem__(self, index): time=slice(ind_start_in_file, ind_end_in_file + 1) ) - diagnostic_subset = diagnostic_subset.isel( - time=slice(-1, None) - ).load() # <-- load into memory + diagnostic_subset = diagnostic_subset.isel(time=slice(-1, None)).load() # <-- load into memory target_ERA5_images = target_ERA5_images.merge(diagnostic_subset) @@ -642,9 +611,9 @@ def __getitem__(self, index): # for multi-input cases, use time=-1 ocean SKT for all times if self.history_len > 1: - input_skt[: self.history_len - 1] = input_skt[ - : self.history_len - 1 - ].where(~ocean_mask_bool, input_skt.isel(time=-1)) + input_skt[: self.history_len - 1] = input_skt[: self.history_len - 1].where( + ~ocean_mask_bool, input_skt.isel(time=-1) + ) # for target skt, replace ocean values using time=-1 input SKT target_skt = target_skt.where(~ocean_mask_bool, input_skt.isel(time=-1)) @@ -883,9 +852,7 @@ def __init__( if self.filename_forcing is not None: # drop variables if they are not in the config ds = get_forward_data(filename_forcing) - ds_forcing = drop_var_from_dataset( - ds, varname_forcing - ).load() # <---- load in static + ds_forcing = drop_var_from_dataset(ds, varname_forcing).load() # <---- load in static self.xarray_forcing = ds_forcing else: @@ -898,9 +865,7 @@ def __init__( if self.filename_static is not None: # drop variables if they are not in the config ds = get_forward_data(filename_static) - ds_static = drop_var_from_dataset( - ds, varname_static - ).load() # <---- load in static + ds_static = drop_var_from_dataset(ds, varname_static).load() # <---- load in static self.xarray_static = ds_static else: @@ -929,9 +894,7 @@ def __getitem__(self, index): ind_start_in_file = index - ind_start # handle out-of-bounds - ind_largest = len(self.all_files[int(ind_file)]["time"]) - ( - self.history_len + self.forecast_len + 1 - ) + ind_largest = len(self.all_files[int(ind_file)]["time"]) - (self.history_len + self.forecast_len + 1) if ind_start_in_file > ind_largest: ind_start_in_file = ind_largest @@ -955,9 +918,7 @@ def __getitem__(self, index): ) # .load() NOT load into memory ## merge upper-air and surface here: - ERA5_subset = ERA5_subset.merge( - surface_subset - ) # <-- lazy merge, ERA5 and surface both not loaded + ERA5_subset = ERA5_subset.merge(surface_subset) # <-- lazy merge, ERA5 and surface both not loaded # ==================================================== # # split ERA5_subset into training inputs and targets @@ -995,17 +956,11 @@ def __getitem__(self, index): # ------------------------------------------------------------------------------- # # matching month, day, hour between forcing and upper air [time] # this approach handles leap year forcing file and non-leap-year upper air file - month_day_forcing = extract_month_day_hour( - np.array(self.xarray_forcing["time"]) - ) - month_day_inputs = extract_month_day_hour( - np.array(historical_ERA5_images["time"]) - ) # <-- upper air + month_day_forcing = extract_month_day_hour(np.array(self.xarray_forcing["time"])) + month_day_inputs = 
extract_month_day_hour(np.array(historical_ERA5_images["time"])) # <-- upper air # indices to subset ind_forcing, _ = find_common_indices(month_day_forcing, month_day_inputs) - forcing_subset_input = self.xarray_forcing.isel( - time=ind_forcing - ) # .load() # <-- loadded in init + forcing_subset_input = self.xarray_forcing.isel(time=ind_forcing) # .load() # <-- loadded in init # forcing and upper air have different years but the same mon/day/hour # safely replace forcing time with upper air time forcing_subset_input["time"] = historical_ERA5_images["time"] @@ -1019,13 +974,9 @@ def __getitem__(self, index): if self.xarray_static: # expand static var on time dim N_time_dims = len(ERA5_subset["time"]) - static_subset_input = self.xarray_static.expand_dims( - dim={"time": N_time_dims} - ) + static_subset_input = self.xarray_static.expand_dims(dim={"time": N_time_dims}) # assign coords 'time' - static_subset_input = static_subset_input.assign_coords( - {"time": ERA5_subset["time"]} - ) + static_subset_input = static_subset_input.assign_coords({"time": ERA5_subset["time"]}) # slice + load to the GPU static_subset_input = static_subset_input.isel( @@ -1044,9 +995,7 @@ def __getitem__(self, index): if self.one_shot is not None: # one_shot is True (on), go straight to the last element - target_ERA5_images = ERA5_subset.isel( - time=slice(-1, None) - ).load() # <-- load into memory + target_ERA5_images = ERA5_subset.isel(time=slice(-1, None)).load() # <-- load into memory ## merge diagnoisc input here: if self.diagnostic_files: @@ -1054,9 +1003,7 @@ def __getitem__(self, index): time=slice(ind_start_in_file, ind_end_in_file + 1) ) - diagnostic_subset = diagnostic_subset.isel( - time=slice(-1, None) - ).load() # <-- load into memory + diagnostic_subset = diagnostic_subset.isel(time=slice(-1, None)).load() # <-- load into memory target_ERA5_images = target_ERA5_images.merge(diagnostic_subset) @@ -1101,9 +1048,9 @@ def __getitem__(self, index): # for multi-input cases, use time=-1 ocean SKT for all times if self.history_len > 1: - input_skt[: self.history_len - 1] = input_skt[ - : self.history_len - 1 - ].where(~ocean_mask_bool, input_skt.isel(time=-1)) + input_skt[: self.history_len - 1] = input_skt[: self.history_len - 1].where( + ~ocean_mask_bool, input_skt.isel(time=-1) + ) # for target skt, replace ocean values using time=-1 input SKT target_skt = target_skt.where(~ocean_mask_bool, input_skt.isel(time=-1)) @@ -1170,9 +1117,7 @@ def __init__( self.history_len = history_len self.init_datetime = fcst_datetime - self.which_forecast = ( - which_forecast # <-- got from the old roll-out script. Dont know - ) + self.which_forecast = which_forecast # <-- got from the old roll-out script. 
Don't know
 
         # -------------------------------------- #
         # file names
@@ -1198,9 +1143,7 @@ def __init__(
         for fn in self.filenames:
             # drop variables if they are not in the config
             xarray_dataset = get_forward_data(filename=fn)
-            xarray_dataset = drop_var_from_dataset(
-                xarray_dataset, self.varname_upper_air
-            )
+            xarray_dataset = drop_var_from_dataset(xarray_dataset, self.varname_upper_air)
             # collect yearly datasets within a list
             all_files.append(xarray_dataset)
         self.all_files = all_files
@@ -1225,9 +1168,7 @@ def load_zarr_as_input(self, i_file, i_init_start, i_init_end, mode="input"):
         # open the zarr file as xr.dataset and subset based on the needed time
         # sliced_x: the final output, starts with an upper air xr.dataset
-        sliced_x = self.ds_read_and_subset(
-            self.filenames[i_file], i_init_start, i_init_end + 1, self.varname_upper_air
-        )
+        sliced_x = self.ds_read_and_subset(self.filenames[i_file], i_init_start, i_init_end + 1, self.varname_upper_air)
         # surface variables
         if self.filename_surface is not None:
             sliced_surface = self.ds_read_and_subset(
@@ -1256,22 +1197,16 @@ def load_zarr_as_input(self, i_file, i_init_start, i_init_end, mode="input"):
         # forcing / static
         if self.filename_forcing is not None:
             sliced_forcing = get_forward_data(self.filename_forcing)
-            sliced_forcing = drop_var_from_dataset(
-                sliced_forcing, self.varname_forcing
-            )
+            sliced_forcing = drop_var_from_dataset(sliced_forcing, self.varname_forcing)
 
             # See also `ERA5_and_Forcing_Dataset`
             # =============================================================================== #
             # matching month, day, hour between forcing and upper air [time]
             # this approach handles leap year forcing file and non-leap-year upper air file
-            month_day_forcing = extract_month_day_hour(
-                np.array(sliced_forcing["time"])
-            )
+            month_day_forcing = extract_month_day_hour(np.array(sliced_forcing["time"]))
             month_day_inputs = extract_month_day_hour(np.array(sliced_x["time"]))
             # indices to subset
-            ind_forcing, _ = find_common_indices(
-                month_day_forcing, month_day_inputs
-            )
+            ind_forcing, _ = find_common_indices(month_day_forcing, month_day_inputs)
             sliced_forcing = sliced_forcing.isel(time=ind_forcing)
             # forcing and upper air have different years but the same mon/day/hour
             # safely replace forcing time with upper air time
@@ -1283,12 +1218,8 @@ def load_zarr_as_input(self, i_file, i_init_start, i_init_end, mode="input"):
         if self.filename_static is not None:
             sliced_static = get_forward_data(self.filename_static)
-            sliced_static = drop_var_from_dataset(
-                sliced_static, self.varname_static
-            )
-            sliced_static = sliced_static.expand_dims(
-                dim={"time": len(sliced_x["time"])}
-            )
+            sliced_static = drop_var_from_dataset(sliced_static, self.varname_static)
+            sliced_static = sliced_static.expand_dims(dim={"time": len(sliced_x["time"])})
             sliced_static["time"] = sliced_x["time"]
             # merge static to sliced_x
             sliced_x = sliced_x.merge(sliced_static)
@@ -1312,9 +1243,7 @@ def find_start_stop_indices(self, index):
         # ============================================================================ #
         # shift hours for history_len > 1, because more than one init time is needed
         # <--- !! it MAY NOT work when self.skip_periods != 1
-        shifted_hours = (
-            self.lead_time_periods * self.skip_periods * (self.history_len - 1)
-        )
+        shifted_hours = self.lead_time_periods * self.skip_periods * (self.history_len - 1)
         # ============================================================================ #
         # subtract shifted_hours from the 1st & last init times
         # convert to datetime object
@@ -1332,17 +1261,13 @@ def find_start_stop_indices(self, index):
             self.lead_time_periods,
         )
         # convert datetime obj to nanoseconds
-        init_time_list_dt = [
-            np.datetime64(date.strftime("%Y-%m-%d %H:%M:%S"))
-            for date in self.init_datetime[index]
-        ]
+        init_time_list_dt = [np.datetime64(date.strftime("%Y-%m-%d %H:%M:%S")) for date in self.init_datetime[index]]
 
         # init_time_list_np: a list of python datetime objects, each is a forecast step
         # init_time_list_np[0]: the first initialization time
         # init_time_list_np[t]: the forecasted time of the (t-1)th step; the initialization time of the t-th step
         self.init_time_list_np = [
-            np.datetime64(str(dt_obj) + ".000000000").astype(datetime.datetime)
-            for dt_obj in init_time_list_dt
+            np.datetime64(str(dt_obj) + ".000000000").astype(datetime.datetime) for dt_obj in init_time_list_dt
         ]
 
         info = []
@@ -1358,10 +1283,7 @@ def find_start_stop_indices(self, index):
             if init_year0 == ds_year:
                 N_times = len(ds["time"])
                 # convert ds['time'] to a list of nanoseconds
-                ds_time_list = [
-                    np.datetime64(ds_time.values).astype(datetime.datetime)
-                    for ds_time in ds["time"]
-                ]
+                ds_time_list = [np.datetime64(ds_time.values).astype(datetime.datetime) for ds_time in ds["time"]]
                 ds_start_time = ds_time_list[0]
                 ds_end_time = ds_time_list[-1]
@@ -1372,9 +1294,7 @@ def find_start_stop_indices(self, index):
                     i_init_start = ds_time_list.index(init_time_start)
                     # for multiple init time inputs (history_len > 1), init_end is different for init_start
-                    init_time_end = init_time_start + hour_to_nanoseconds(
-                        shifted_hours
-                    )
+                    init_time_end = init_time_start + hour_to_nanoseconds(shifted_hours)
                     # see if init_time_end is also in this file
                     if ds_start_time <= init_time_end <= ds_end_time:
@@ -1414,55 +1334,33 @@ def __iter__(self):
             output_dict = {}
 
             # get all inputs in one xr.Dataset
-            sliced_x = self.load_zarr_as_input(
-                i_file, i_init_start, i_init_end, mode="input"
-            )
+            sliced_x = self.load_zarr_as_input(i_file, i_init_start, i_init_end, mode="input")
 
             # Check if additional data from the next file is needed
-            if (len(sliced_x["time"]) < self.history_len) or (
-                i_init_end + 1 >= N_times
-            ):
+            if (len(sliced_x["time"]) < self.history_len) or (i_init_end + 1 >= N_times):
                 # Load excess data from the next file
                 next_file_idx = self.filenames.index(self.filenames[i_file]) + 1
 
                 if next_file_idx >= len(self.filenames):
                     # not enough input data to support this forecast
-                    raise OSError(
-                        "You have reached the end of the available data. Exiting."
-                    )
+                    raise OSError("You have reached the end of the available data. Exiting.")
 
                 else:
-                    sliced_y = self.load_zarr_as_input(
-                        i_file, i_init_end, i_init_end, mode="target"
-                    )
+                    sliced_y = self.load_zarr_as_input(i_file, i_init_end, i_init_end, mode="target")
 
                     # i_init_start = 0 because we need the beginning of the next file only
-                    sliced_x_next = self.load_zarr_as_input(
-                        next_file_idx, 0, self.history_len, mode="input"
-                    )
-                    sliced_y_next = self.load_zarr_as_input(
-                        next_file_idx, 0, 1, mode="target"
-                    )
+                    sliced_x_next = self.load_zarr_as_input(next_file_idx, 0, self.history_len, mode="input")
+                    sliced_y_next = self.load_zarr_as_input(next_file_idx, 0, 1, mode="target")  # 1 because target is one step at a time
 
                     # Concatenate excess data from the next file with the current data
-                    sliced_x_combine = xr.concat(
-                        [sliced_x, sliced_x_next], dim="time"
-                    )
-                    sliced_y_combine = xr.concat(
-                        [sliced_y, sliced_y_next], dim="time"
-                    )
-
-                    sliced_x = sliced_x_combine.isel(
-                        time=slice(0, self.history_len)
-                    )
-                    sliced_y = sliced_y_combine.isel(
-                        time=slice(self.history_len, self.history_len + 1)
-                    )
+                    sliced_x_combine = xr.concat([sliced_x, sliced_x_next], dim="time")
+                    sliced_y_combine = xr.concat([sliced_y, sliced_y_next], dim="time")
+
+                    sliced_x = sliced_x_combine.isel(time=slice(0, self.history_len))
+                    sliced_y = sliced_y_combine.isel(time=slice(self.history_len, self.history_len + 1))
             else:
-                sliced_y = self.load_zarr_as_input(
-                    i_file, i_init_end + 1, i_init_end + 1, mode="target"
-                )
+                sliced_y = self.load_zarr_as_input(i_file, i_init_end + 1, i_init_end + 1, mode="target")
 
             sample_x = {
                 "historical_ERA5_images": sliced_x,
@@ -1479,421 +1377,10 @@ def __iter__(self):
             output_dict["forecast_hour"] = k + 1
             # Adjust stopping condition
             output_dict["stop_forecast"] = k == (len(self.init_time_list_np) - 1)
-            output_dict["datetime"] = sliced_x.time.values.astype(
-                "datetime64[s]"
-            ).astype(int)[-1]
+            output_dict["datetime"] = sliced_x.time.values.astype("datetime64[s]").astype(int)[-1]
 
             # return output_dict
             yield output_dict
 
             if output_dict["stop_forecast"]:
                 break
-
-# # =============================================== #
-# # This dataset works for hourly model only
-# # it does not support forcing & static here
-# # but it pairs to the ToTensor that adds
-# # TOA (which has problem) and other static fields
-# # =============================================== #
-# class ERA5Dataset(torch.utils.data.Dataset):
-#     def __init__(
-#         self,
-#         filenames: list = (
-#             "/glade/derecho/scratch/wchapman/STAGING/TOTAL_2012-01-01_2012-12-31_staged.zarr",
-#             "/glade/derecho/scratch/wchapman/STAGING/TOTAL_2013-01-01_2013-12-31_staged.zarr",
-#         ),
-#         history_len: int = 1,
-#         forecast_len: int = 2,
-#         transform: Optional[Callable] = None,
-#         seed=42,
-#         skip_periods=None,
-#         one_shot=None,
-#         max_forecast_len=None,
-#     ):
-#         self.history_len = history_len
-#         self.forecast_len = forecast_len
-#         self.transform = transform
-#         self.skip_periods = skip_periods
-#         self.one_shot = one_shot
-#         self.total_seq_len = self.history_len + self.forecast_len
-#         all_fils = []
-#         filenames = sorted(filenames)
-#         for fn in filenames:
-#             all_fils.append(get_forward_data(filename=fn))
-#         self.all_fils = all_fils
-#         self.data_array = all_fils[0]
-#         self.rng = np.random.default_rng(seed=seed)
-#         self.max_forecast_len = max_forecast_len
-
-#         # set data places:
-#         indo = 0
-#         self.meta_data_dict = {}
-#         for ee, bb in enumerate(self.all_fils):
-#             self.meta_data_dict[str(ee)] = [
-#                 len(bb["time"]),
-#                 indo,
-#                 indo + len(bb["time"]),
-#             ]
-#             indo += len(bb["time"]) + 1
-
-#         # set out of bounds indexes... 
-# OOB = [] -# for kk in self.meta_data_dict.keys(): -# OOB.append(generate_integer_list_around(self.meta_data_dict[kk][2])) -# self.OOB = flatten_list(OOB) - -# def __post_init__(self): -# # Total sequence length of each sample. -# self.total_seq_len = self.history_len + self.forecast_len - -# def __len__(self): -# tlen = 0 -# for bb in self.all_fils: -# tlen += len(bb["time"]) - self.total_seq_len + 1 -# return tlen - -# def __getitem__(self, index): -# # find the result key: -# result_key = find_key_for_number(index, self.meta_data_dict) - -# # get the data selection: -# true_ind = index - self.meta_data_dict[result_key][1] - -# if true_ind > ( -# len(self.all_fils[int(result_key)]["time"]) -# - (self.history_len + self.forecast_len + 1) -# ): -# true_ind = len(self.all_fils[int(result_key)]["time"]) - ( -# self.history_len + self.forecast_len + 1 -# ) - -# datasel = self.all_fils[int(result_key)].isel( -# time=slice(true_ind, true_ind + self.history_len + self.forecast_len + 1) -# ) - -# if (self.skip_periods is not None) and (self.one_shot is None): -# sample = Sample( -# historical_ERA5_images=datasel.isel( -# time=slice(0, self.history_len, self.skip_periods) -# ), -# target_ERA5_images=datasel.isel( -# time=slice( -# self.history_len, len(datasel["time"]), self.skip_periods -# ) -# ), -# datetime_index=datasel.time.values.astype("datetime64[s]").astype(int), -# ) - -# elif (self.skip_periods is not None) and (self.one_shot is not None): -# target_ERA5_images = datasel.isel( -# time=slice(self.history_len, len(datasel["time"]), self.skip_periods) -# ) -# target_ERA5_images = target_ERA5_images.isel(time=slice(0, 1)) - -# sample = Sample( -# historical_ERA5_images=datasel.isel( -# time=slice(0, self.history_len, self.skip_periods) -# ), -# target_ERA5_images=target_ERA5_images, -# datetime_index=datasel.time.values.astype("datetime64[s]").astype(int), -# ) - -# elif self.one_shot is not None: -# historical_data = datasel.isel(time=slice(0, self.history_len)).load() -# target_data = datasel.isel(time=slice(-1, None)).load() -# # Create the Sample object with the loaded data -# sample = Sample( -# historical_ERA5_images=historical_data, -# target_ERA5_images=target_data, -# datetime_index=[ -# int( -# historical_data.time.values[0] -# .astype("datetime64[s]") -# .astype(int) -# ), -# int(target_data.time.values[0].astype("datetime64[s]").astype(int)), -# ], -# ) -# else: -# sample = Sample( -# historical_ERA5_images=datasel.isel(time=slice(0, self.history_len)), -# target_ERA5_images=datasel.isel( -# time=slice(self.history_len, len(datasel["time"])) -# ), -# datetime_index=datasel.time.values.astype("datetime64[s]").astype(int), -# ) - -# if self.transform: -# sample = self.transform(sample) - -# sample["index"] = index - -# return sample - - -# # ================================= # -# # This dataset is old, but not sure -# # if anyone still uses it -# # ================================= # -# class Dataset_BridgeScaler(torch.utils.data.Dataset): -# def __init__( -# self, -# conf, -# conf_dataset, -# transform: Optional[Callable] = None, -# ): -# years_do = list(conf["data"][conf_dataset]) -# self.available_dates = pd.date_range( -# str(years_do[0]), str(years_do[1]), freq="1H" -# ) -# self.data_path = str(conf["data"]["bs_data_path"]) -# self.history_len = int(conf["data"]["history_len"]) -# self.forecast_len = int(conf["data"]["forecast_len"]) -# self.forecast_len = 1 if self.forecast_len == 0 else self.forecast_len -# self.file_format = str(conf["data"]["bs_file_format"]) -# 
self.transform = transform -# self.skip_periods = conf["data"]["skip_periods"] -# self.one_shot = conf["data"]["one_shot"] -# self.total_seq_len = self.history_len + self.forecast_len -# self.first_date = self.available_dates[0] -# self.last_date = self.available_dates[-1] - -# def __post_init__(self): -# # Total sequence length of each sample. -# self.total_seq_len = self.history_len + self.forecast_len - -# def __len__(self): -# tlen = 0 -# tlen = len(self.available_dates) -# return tlen - -# def evenly_spaced_indlist(self, index, skip_periods, forecast_len, history_len): -# # Initialize the list with the base index -# indlist = [index] - -# # Add forecast indices -# for i in range(1, forecast_len + 1): -# indlist.append(index + i * skip_periods) - -# # Add history indices -# for i in range(1, history_len + 1): -# indlist.append(index - i * skip_periods) - -# # Sort the list to maintain order -# indlist = sorted(indlist) -# return indlist - -# def __getitem__(self, index): -# if (self.skip_periods is None) & (self.one_shot is None): -# date_index = self.available_dates[index] - -# indlist = sorted( -# [index] -# + [index + (i) + 1 for i in range(self.forecast_len)] -# + [index - i - 1 for i in range(self.history_len)] -# ) - -# if np.min(indlist) < 0: -# indlist = list(np.array(indlist) + np.abs(np.min(indlist))) -# index += np.abs(np.min(indlist)) -# if np.max(indlist) >= self.__len__(): -# indlist = list( -# np.array(indlist) - np.abs(np.max(indlist)) + self.__len__() - 1 -# ) -# index -= np.abs(np.max(indlist)) -# date_index = self.available_dates[indlist] -# str_tot_find = f"%Y/%m/%d/{self.file_format}" -# fs = [f"{self.data_path}/{bb.strftime(str_tot_find)}" for bb in date_index] -# if len(fs) < 2: -# raise "Must be greater than one day in the list [x and x+1 minimum]" - -# fe = [1 if os.path.exists(fn) else 0 for fn in fs] -# if np.sum(fe) == len(fs): -# pass -# else: -# raise "weve left the training dataset, check your dataloader logic" - -# DShist = xr.open_mfdataset(fs[1 : self.history_len + 1]).load() -# DSfor = xr.open_mfdataset( -# fs[self.history_len + 1 : self.history_len + 1 + self.forecast_len] -# ).load() - -# sample = Sample( -# historical_ERA5_images=DShist, -# target_ERA5_images=DSfor, -# datetime_index=date_index, -# ) - -# if self.transform: -# sample = self.transform(sample) -# return sample -# if self.one_shot is not None: -# date_index = self.available_dates[index] - -# indlist = sorted( -# [index] -# + [index + (i) + 1 for i in range(self.forecast_len)] -# + [index - i - 1 for i in range(self.history_len)] -# ) -# # indlist.append(index+self.one_shot) - -# if np.min(indlist) < 0: -# indlist = list(np.array(indlist) + np.abs(np.min(indlist))) -# index += np.abs(np.min(indlist)) -# if np.max(indlist) >= self.__len__(): -# indlist = list( -# np.array(indlist) - np.abs(np.max(indlist)) + self.__len__() - 1 -# ) -# index -= np.abs(np.max(indlist)) - -# date_index = self.available_dates[indlist] -# str_tot_find = f"%Y/%m/%d/{self.file_format}" -# fs = [f"{self.data_path}/{bb.strftime(str_tot_find)}" for bb in date_index] - -# if len(fs) < 2: -# raise "Must be greater than one day in the list [x and x+1 minimum]" - -# fe = [1 if os.path.exists(fn) else 0 for fn in fs] -# if np.sum(fe) == len(fs): -# pass -# else: -# raise "weve left the training dataset, check your dataloader logic" - -# DShist = xr.open_mfdataset(fs[: self.history_len]).load() -# DSfor = xr.open_mfdataset(fs[-2]).load() - -# sample = Sample( -# historical_ERA5_images=DShist, -# 
target_ERA5_images=DSfor, -# datetime_index=date_index, -# ) - -# if self.transform: -# sample = self.transform(sample) -# return sample - -# if (self.skip_periods is not None) and (self.one_shot is None): -# date_index = self.available_dates[index] -# indlist = self.evenly_spaced_indlist( -# index, self.skip_periods, self.forecast_len, self.history_len -# ) - -# if np.min(indlist) < 0: -# indlist = list(np.array(indlist) + np.abs(np.min(indlist))) -# index += np.abs(np.min(indlist)) -# if np.max(indlist) >= self.__len__(): -# indlist = list( -# np.array(indlist) - np.abs(np.max(indlist)) + self.__len__() - 1 -# ) -# index -= np.abs(np.max(indlist)) - -# date_index = self.available_dates[indlist] -# str_tot_find = f"%Y/%m/%d/{self.file_format}" -# fs = [f"{self.data_path}/{bb.strftime(str_tot_find)}" for bb in date_index] - -# if len(fs) < 2: -# raise "Must be greater than one day in the list [x and x+1 minimum]" - -# fe = [1 if os.path.exists(fn) else 0 for fn in fs] -# if np.sum(fe) == len(fs): -# pass -# else: -# raise "weve left the training dataset, check your dataloader logic" - -# DShist = xr.open_mfdataset(fs[: self.history_len]).load() -# DSfor = xr.open_mfdataset( -# fs[self.history_len : self.history_len + self.forecast_len] -# ).load() - -# sample = Sample( -# historical_ERA5_images=DShist, -# target_ERA5_images=DSfor, -# datetime_index=date_index, -# ) - -# if self.transform: -# sample = self.transform(sample) -# return sample - - -# # ================================================================== # -# # Note: DistributedSequentialDataset & DistributedSequentialDataset -# # are legacy; they wrap ERA5Dataset to send data batches to GPUs for -# # (1 class of?) huge sharded models, but otherwise have been -# # superseded by ERA5Dataset. 
-# # ================================================================== # -# class SequentialDataset(torch.utils.data.Dataset): -# def __init__( -# self, -# filenames, -# history_len=1, -# forecast_len=2, -# skip_periods=1, -# transform=None, -# random_forecast=True, -# ): -# self.dataset = ERA5Dataset( -# filenames=filenames, -# history_len=history_len, -# forecast_len=forecast_len, -# transform=transform, -# ) -# self.meta_data_dict = self.dataset.meta_data_dict -# self.all_fils = self.dataset.all_fils -# self.history_len = history_len -# self.forecast_len = forecast_len -# self.filenames = filenames -# self.transform = transform -# self.skip_periods = skip_periods -# self.random_forecast = random_forecast -# self.iteration_count = 0 -# self.current_epoch = 0 -# self.adjust_forecast = 0 - -# self.index_list = [] -# for i, x in enumerate(self.all_fils): -# times = x["time"].values -# slices = np.arange(0, times.shape[0] - (self.forecast_len + 1)) -# self.index_list += [(i, slice) for slice in slices] - -# def __len__(self): -# return len(self.index_list) - -# def set_params(self, epoch): -# self.current_epoch = epoch -# self.iteration_count = 0 - -# def __getitem__(self, index): -# if self.random_forecast and (self.iteration_count % self.forecast_len == 0): -# # Randomly choose a starting point within a valid range -# max_start = len(self.index_list) - (self.forecast_len + 1) -# self.adjust_forecast = np.random.randint(0, max_start + 1) - -# index = (index + self.adjust_forecast) % self.__len__() -# file_id, slice_idx = self.index_list[index] - -# dataset = xr.open_zarr(self.filenames[file_id], consolidated=True).isel( -# time=slice(slice_idx, slice_idx + self.skip_periods + 1, self.skip_periods) -# ) - -# sample = { -# "x": dataset.isel(time=slice(0, 1, 1)), -# "y": dataset.isel(time=slice(1, 2, 1)), -# } - -# if self.transform: -# sample = self.transform(sample) - -# sample["forecast_hour"] = self.iteration_count -# sample["forecast_datetime"] = dataset.time.values.astype( -# "datetime64[s]" -# ).astype(int) -# sample["stop_forecast"] = False - -# if self.iteration_count == self.forecast_len - 1: -# sample["stop_forecast"] = True - -# # Increment the iteration count -# self.iteration_count += 1 - -# return sample From b7adacaa72034493e62c399ba1bf7dafcdd78d15 Mon Sep 17 00:00:00 2001 From: David John Gagne Date: Tue, 18 Feb 2025 18:23:44 -0700 Subject: [PATCH 12/13] Fixed geopotential to pressure levels --- credit/interp.py | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/credit/interp.py b/credit/interp.py index 376e807f..6a732906 100644 --- a/credit/interp.py +++ b/credit/interp.py @@ -133,7 +133,6 @@ def full_state_pressure_interpolation( pressure_levels, state_dataset[surface_pressure_var][t].values / 100.0, surface_geopotential, - geopotential_grid, state_dataset[temperature_var][t].values, ) pressure_ds["mean_sea_level_" + pres_var][t] = mean_sea_level_pressure( @@ -293,12 +292,11 @@ def interp_pressure_to_hybrid_levels(pressure_var, pressure_levels, model_pressu @njit def interp_geopotential_to_pressure_levels( - model_var, + geopotential, model_pressure, interp_pressures, surface_pressure, surface_geopotential, - geopotential, temperature_k, temp_height=150, ): @@ -308,12 +306,11 @@ def interp_geopotential_to_pressure_levels( below the surface based on Eq. 15 in Trenberth et al. (1993). Args: - model_var (np.ndarray): 3D field on hybrid sigma-pressure levels with shape (levels, y, x). + geopotential (np.ndarray): geopotential in units m^2/s^2. 
model_pressure (np.ndarray): 3D pressure field with shape (levels, y, x) in units Pa or hPa
         interp_pressures (np.ndarray): pressure levels for interpolation in units Pa or hPa.
         surface_pressure (np.ndarray): pressure at the surface in units Pa or hPa.
         surface_geopotential (np.ndarray): geopotential at the surface in units m^2/s^2.
-        geopotential (np.ndarray): geopotential in units m^2/s^2.
         temperature_k (np.ndarray): temperature in units K.
         temp_height (float): height above ground of nearest vertical grid cell.
     Returns:
@@ -322,12 +319,12 @@ def interp_geopotential_to_pressure_levels(
     LAPSE_RATE = 0.0065  # K / m
     ALPHA = LAPSE_RATE * RDGAS / GRAVITY
     pressure_var = np.zeros(
-        (interp_pressures.shape[0], model_var.shape[1], model_var.shape[2]),
-        dtype=model_var.dtype,
+        (interp_pressures.shape[0], geopotential.shape[1], geopotential.shape[2]),
+        dtype=geopotential.dtype,
     )
     log_interp_pressures = np.log(interp_pressures)
-    for (i, j), v in np.ndenumerate(model_var[0]):
-        pressure_var[:, i, j] = np.interp(log_interp_pressures, np.log(model_pressure[:, i, j]), model_var[:, i, j])
+    for (i, j), v in np.ndenumerate(geopotential[0]):
+        pressure_var[:, i, j] = np.interp(log_interp_pressures, np.log(model_pressure[:, i, j]), geopotential[:, i, j])
         for pl, interp_pressure in enumerate(interp_pressures):
             if interp_pressure > surface_pressure[i, j]:
                 height_agl = (geopotential[:, i, j] - surface_geopotential[i, j]) / GRAVITY

From 156e435c391376327015d93eaf95fa0e5bac8679 Mon Sep 17 00:00:00 2001
From: David John Gagne
Date: Wed, 19 Feb 2025 13:14:11 -0700
Subject: [PATCH 13/13] Added elevation corrections to geopotential
 extrapolation.

---
 credit/interp.py | 16 +++++++++++++++-
 1 file changed, 15 insertions(+), 1 deletion(-)

diff --git a/credit/interp.py b/credit/interp.py
index 6a732906..4ec1beb8 100644
--- a/credit/interp.py
+++ b/credit/interp.py
@@ -332,9 +332,23 @@ def interp_geopotential_to_pressure_levels(
                 temp_surface_k = temperature_k[h, i, j] + ALPHA * temperature_k[h, i, j] * (
                     surface_pressure[i, j] / model_pressure[h, i, j] - 1
                 )
+                surface_height = surface_geopotential[i, j] / GRAVITY
+                temp_sea_level_k = temp_surface_k + LAPSE_RATE * surface_height
+                temp_pl = np.minimum(temp_sea_level_k, 298.0)
+                if surface_height > 2500.0:
+                    gamma = GRAVITY / surface_geopotential[i, j] * np.maximum(temp_pl - temp_surface_k, 0)
+
+                elif 2000.0 <= surface_height <= 2500.0:
+                    t_adjusted = 0.002 * (
+                        (2500 - surface_height) * temp_sea_level_k + (surface_height - 2000.0) * temp_pl
+                    )
+                    gamma = GRAVITY / surface_geopotential[i, j] * (t_adjusted - temp_surface_k)
+                else:
+                    gamma = LAPSE_RATE
+                a_ln_p = gamma * RDGAS / GRAVITY * np.log(interp_pressure / surface_pressure[i, j])
                 ln_p_ps = np.log(interp_pressure / surface_pressure[i, j])
                 pressure_var[pl, i, j] = surface_geopotential[i, j] - RDGAS * temp_surface_k * ln_p_ps * (
-                    1 + ALPHA * ln_p_ps / 2.0 + (ALPHA * ln_p_ps) ** 2 / 6.0
+                    1 + a_ln_p / 2.0 + a_ln_p**2 / 6.0
                 )
     return pressure_var
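
To make the elevation correction concrete: below the model surface, geopotential is extrapolated as Phi(p) = Phi_s - Rd * T_s * ln(p/p_s) * (1 + a/2 + a^2/6) with a = gamma * Rd / g * ln(p/p_s), where gamma is the standard lapse rate over low terrain and is damped over high terrain (after Trenberth et al. 1993). A single-column sketch of the scheme under illustrative constants and inputs; the patch applies the same formula pointwise over the (y, x) grid:

    import numpy as np

    RDGAS = 287.06       # J kg^-1 K^-1 (assumed value of the library constant)
    GRAVITY = 9.80665    # m s^-2
    LAPSE_RATE = 0.0065  # K m^-1

    def extrapolate_geopotential(p, p_surface, geopot_surface, temp_level_k, p_level):
        """Geopotential (m^2/s^2) at pressure p below the surface, one column."""
        alpha = LAPSE_RATE * RDGAS / GRAVITY
        # surface temperature extrapolated down from the lowest usable model level
        temp_surface_k = temp_level_k + alpha * temp_level_k * (p_surface / p_level - 1)
        surface_height = geopot_surface / GRAVITY
        temp_sea_level_k = temp_surface_k + LAPSE_RATE * surface_height
        temp_pl = min(temp_sea_level_k, 298.0)
        if surface_height > 2500.0:
            gamma = GRAVITY / geopot_surface * max(temp_pl - temp_surface_k, 0.0)
        elif surface_height >= 2000.0:
            # blend the two temperature estimates across the 2000-2500 m band
            t_adjusted = 0.002 * ((2500.0 - surface_height) * temp_sea_level_k
                                  + (surface_height - 2000.0) * temp_pl)
            gamma = GRAVITY / geopot_surface * (t_adjusted - temp_surface_k)
        else:
            gamma = LAPSE_RATE
        ln_p_ps = np.log(p / p_surface)
        a_ln_p = gamma * RDGAS / GRAVITY * ln_p_ps
        return geopot_surface - RDGAS * temp_surface_k * ln_p_ps * (1 + a_ln_p / 2.0 + a_ln_p**2 / 6.0)

    # e.g. the 1000 hPa surface under a 900 hPa model surface at roughly 1000 m elevation:
    print(extrapolate_geopotential(1000.0, 900.0, 9800.0, 285.0, 895.0))

Only the pressure ratio enters the logarithm, so Pa and hPa are interchangeable as long as p, p_surface, and p_level share units.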