Skip to content

Commit

Permalink
Merge pull request #54 from NOAA-CEFI-Portal/develop
Browse files Browse the repository at this point in the history
include MHW processing to main branch
  • Loading branch information
chiaweh2 authored Nov 21, 2024
2 parents 0fb1e90 + 0e5776e commit 66013e7
Show file tree
Hide file tree
Showing 13 changed files with 1,899 additions and 68 deletions.
93 changes: 93 additions & 0 deletions mom6/mom6_module/mom6_detrend.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
"""
This is the module to implement the detrending
"""
from typing import Tuple
import xarray as xr

class ForecastDetrend:
"""Detrend class for forecast data"""
def __init__(
self,
da_data : xr.DataArray,
initialization_name : str = 'init',
member_name : str = 'member',
) -> None:
"""
Parameters
----------
da_data : xr.DataArray
The dataarray one want to use to
detrend.
initialization_name : str, optional
initialization dimension name, by default 'init'
member_name : str, optional
ensemble member dimension name, by default 'member'
"""
self.data = da_data
self.init = initialization_name
self.mem = member_name

def polyfit_coef(
self,
deg: int = 1
) -> xr.Dataset:
"""determine the polyfit coefficient based on
lead-time-dependent forecast ensemble mean anomalies
Parameters
----------
deg : int, optional
the order of polynomical fit to use for determining the
fit coefficient, by default 1
Returns
-------
xr.Dataset
coefficient of the polynomical fit
"""

# calculate the ensemble mean of the anomaly
da_ensmean = self.data.mean(dim=self.mem)
# use the ensemble mean anomaly to determine lead time dependent trend
ds_p = da_ensmean.polyfit(dim=self.init, deg=deg, skipna=True).compute()

return ds_p

def detrend_linear(
self,
precompute_coeff : bool = False,
ds_coeff : xr.Dataset = None,
in_place_memory_replace : bool = False
) -> Tuple[xr.DataArray,xr.Dataset]:
"""detrend the original data by using the
degree 1 ployfit coeff
Returns
-------
xr.DataArray
the data with linear trend removed
"""
if precompute_coeff:
ds_p = ds_coeff
else:
# get degree 1 polyfit coeff
ds_p = self.polyfit_coef(deg=1)

# # calculate linear trend based on polyfit coeff
# da_linear_trend = xr.polyval(self.data[self.init], ds_p.polyfit_coefficients)
# # remove the linear trend
# da_detrend = (self.data - da_linear_trend).persist()

if in_place_memory_replace:
self.data = (
self.data-
xr.polyval(self.data[self.init], ds_p.polyfit_coefficients)
).persist()
return self.data, ds_p
else:
da_detrend = (
self.data -
xr.polyval(self.data[self.init], ds_p.polyfit_coefficients)
).persist()
return da_detrend,ds_p
22 changes: 16 additions & 6 deletions mom6/mom6_module/mom6_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -1029,7 +1029,7 @@ def mom6_encoding_attr(
ds_data_ori : xr.Dataset,
ds_data : xr.Dataset,
dataset_name : str,
var_names : List[str] = None
var_names : List[str] = None,
):
"""
This function is designed for creating attribute and netCDF encoding
Expand Down Expand Up @@ -1132,17 +1132,27 @@ def mom6_encoding_attr(
# copy original attrs and encoding for dims
for dim in misc_dims_list:
try:
ds_data[dim].attrs = ds_data_ori[dim].attrs
ds_data[dim].encoding = ds_data_ori[dim].encoding
ds_data[dim].encoding['complevel'] = 2
if ds_data[dim].attrs == {}:
ds_data[dim].attrs = ds_data_ori[dim].attrs
ds_data[dim].encoding = ds_data_ori[dim].encoding
ds_data[dim].encoding['complevel'] = 2
except KeyError:
print(f'no {dim} dimension')

# copy original attrs and encoding for variables
for var_name in var_names:
try:
ds_data[var_name].attrs = ds_data_ori[var_name].attrs
ds_data[var_name].encoding = ds_data_ori[var_name].encoding
if ds_data[var_name].attrs == {}:
ds_data[var_name].attrs = ds_data_ori[var_name].attrs
ds_data[var_name].encoding = ds_data_ori[var_name].encoding
else:
ds_data[var_name].encoding = ds_data_ori[var_name].encoding
new_attrs = list(ds_data[var_name].attrs.keys())
ori_attrs = list(ds_data_ori[var_name].attrs.keys())
for attr in ori_attrs:
if attr not in new_attrs:
ds_data[var_name].attrs[attr] = ds_data_ori[var_name].attrs[attr]

except KeyError:
print(f'new variable name {var_name}')
ds_data[var_name].encoding['complevel'] = 2
Expand Down
71 changes: 57 additions & 14 deletions mom6/mom6_module/mom6_mhw.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from mom6.mom6_module.mom6_types import (
TimeGroupByOptions
)
from mom6.mom6_module.mom6_detrend import ForecastDetrend

warnings.simplefilter("ignore")
xr.set_options(keep_attrs=True)
Expand Down Expand Up @@ -73,7 +74,8 @@ def generate_forecast_batch(
climo_end_year : int = 2020,
anom_start_year : int = 1993,
anom_end_year : int = 2020,
quantile_threshold : float = 90.
quantile_threshold : float = 90.,
detrend : bool = False
) -> xr.Dataset:
"""generate the MHW statistics and identify MHW
Expand All @@ -89,6 +91,8 @@ def generate_forecast_batch(
end year of anomaly that need to identify MHW, by default 2020
quantile_threshold : float, optional
quantile value that define the threshold, by default 90.
detrend : bool, optional
flag for whether the MHW is based on detrended ssta or not.
Returns
-------
Expand All @@ -98,27 +102,49 @@ def generate_forecast_batch(

# calculate anomaly based on climatology
class_forecast_climo = ForecastClimatology(self.dataset,self.varname)
dict_anom = class_forecast_climo.generate_anom_batch(
dict_anom_thres = class_forecast_climo.generate_anom_batch(
climo_start_year,
climo_end_year,
climo_start_year, # force the anom start year for threshold be the same as climo period
climo_end_year, # force the anom end year for threshold be the same as climo period
'persist'
)

# detrend or not
if detrend:
class_detrend_thres = ForecastDetrend(dict_anom_thres['anomaly'])
dict_anom_thres['anomaly'], ds_p = class_detrend_thres.detrend_linear(
precompute_coeff=False,
in_place_memory_replace=True
)

# anomaly used for the threshold
ds_anom = xr.Dataset()
ds_anom[f'{self.varname}_anom'] = dict_anom['anomaly']
ds_anom['lon'] = self.dataset['lon']
ds_anom['lat'] = self.dataset['lat']
ds_anom_thres = xr.Dataset()
ds_anom_thres[f'{self.varname}_anom'] = dict_anom_thres['anomaly']
ds_anom_thres['lon'] = self.dataset['lon']
ds_anom_thres['lat'] = self.dataset['lat']

# calculate threshold
class_forecast_quantile = ForecastQuantile(ds_anom,f'{self.varname}_anom')
# if detrend:
# ### in memery result when creating the class
# class_forecast_quantile = ForecastQuantile(
# ds_anom_thres.compute(),
# f'{self.varname}_anom'
# )
# da_threshold = class_forecast_quantile.generate_quantile(
# climo_start_year,
# climo_end_year,
# quantile_threshold,
# dask_obj=False
# )
# else:
class_forecast_quantile = ForecastQuantile(ds_anom_thres,f'{self.varname}_anom')
### in memery result not lazy-loaded (same as climo period)
da_threshold = class_forecast_quantile.generate_quantile(
climo_start_year,
climo_end_year,
quantile_threshold
quantile_threshold,
dask_obj=True
)

# anomaly that need to find MHW
Expand All @@ -129,10 +155,18 @@ def generate_forecast_batch(
anom_end_year,
'persist',
precompute_climo = True,
da_climo = dict_anom['climatology']
da_climo = dict_anom_thres['climatology']
)
da_anom = dict_anom['anomaly']

if detrend:
class_detrend = ForecastDetrend(da_anom)
da_anom,_ = class_detrend.detrend_linear(
precompute_coeff=True,
ds_coeff=ds_p,
in_place_memory_replace=True
)

# calculate average mhw magnitude
da_mhw_mag = da_anom.where(da_anom.groupby(f'{self.init}.{self.tfreq}')>=da_threshold)
da_mhw_mag_ave = da_anom.mean(dim=f'{self.mem}').compute()
Expand All @@ -152,12 +186,21 @@ def generate_forecast_batch(

# output dataset
ds_mhw = xr.Dataset()
if detrend :
ds_mhw['polyfit_coefficients'] = ds_p['polyfit_coefficients']

ds_mhw[f'{self.varname}_threshold{quantile_threshold:02d}'] = da_threshold
ds_mhw[f'{self.varname}_threshold{quantile_threshold:02d}'].attrs['long_name'] = (
f'{self.varname} threshold{quantile_threshold:02d})'
f'{self.varname} threshold{quantile_threshold:02d}'
)
ds_mhw[f'{self.varname}_threshold{quantile_threshold:02d}'].attrs['units'] = 'degC'

ds_mhw[f'{self.varname}_climo'] = dict_anom_thres['climatology']
ds_mhw[f'{self.varname}_climo'].attrs['long_name'] = (
f'{self.varname} climatology'
)
ds_mhw[f'{self.varname}_climo'].attrs['units'] = 'degC'

ds_mhw[f'mhw_prob{quantile_threshold:02d}'] = da_prob
ds_mhw[f'mhw_prob{quantile_threshold:02d}'].attrs['long_name'] = (
f'marine heatwave probability (threshold{quantile_threshold:02d})'
Expand All @@ -170,11 +213,11 @@ def generate_forecast_batch(
)
ds_mhw['ssta_avg'].attrs['units'] = 'degC'

ds_mhw['mhw_mag_indentified_ens'] = da_mhw_mag
ds_mhw['mhw_mag_indentified_ens'].attrs['long_name'] = (
'marine heatwave magnitude in each ensemble'
ds_mhw['ssta'] = da_anom
ds_mhw['ssta'].attrs['long_name'] = (
'anomalous sea surface temperature'
)
ds_mhw['mhw_mag_indentified_ens'].attrs['units'] = 'degC'
ds_mhw['ssta'].attrs['units'] = 'degC'

ds_mhw.attrs['period_of_quantile'] = da_threshold.attrs['period_of_quantile']
ds_mhw.attrs['period_of_climatology'] = da_threshold.attrs['period_of_climatology']
Expand Down
58 changes: 36 additions & 22 deletions mom6/mom6_module/mom6_statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,28 +213,6 @@ def generate_anom_batch(
"""Generate the anomaly based on the input
dataset covered period
Parameters
----------
climo_start_year : int, optional
start year to calculation the climatology, by default 1993
climo_end_year : int, optional
end year to calculation the climatology, by default 2020
dask_option : DaskOptions, optional
flag to determine one want the return result
to be 'compute', 'persist' or keep 'lazy' in anomaly, by default 'lazy'
Returns
-------
dict
anomaly: dataarray which represent the anomaly,
climatology: dataarray which represent the climatology
Raises
------
ValueError
when the kwarg anom_start_year & anom_end_year result in
empty array crop
Parameters
----------
climo_start_year : int, optional
Expand Down Expand Up @@ -700,6 +678,42 @@ def generate_anom_batch(
elif dask_option == 'compute':
return {'anomaly':da_anom.compute(),'climatology':da_climo}


class HistoricalQuantile:
"""
Class for calculating the quantile of historical
The method should be able to accomadate the
raw and regridded format
"""
def __init__(
self,
ds_data : xr.Dataset,
var_name : str,
time_name : str = 'time',
time_frequency : TimeGroupByOptions = 'month'
) -> None:
"""
Parameters
----------
ds_data : xr.Dataset
The dataset one want to use to
derived the histiorical run statistics.
var_name : str
The variable name in the dataset
initialization_name : str, optional
initialization dimension name, by default 'init'
member_name : str, optional
ensemble member dimension name, by default 'member'
time_frequency : TimeGroupByOptions, optional
name in time frequency to do the time group, by default 'month'
'year', 'month', 'dayofyear' are the available options.
"""
self.dataset = CoordinateWrangle(ds_data).to_360()
self.varname = var_name
self.timename = time_name
self.tfreq = time_frequency

def generate_quantile(
self,
quantile_start_year : int = 1993,
Expand Down
Loading

0 comments on commit 66013e7

Please sign in to comment.