Skip to content

1 grid class #7

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jan 14, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ unit-tests:
python -m pytest -vv --cov=. --cov-report=$(COV_REPORT) --doctest-glob="*.md" --doctest-glob="*.rst"

type-check:
python -m mypy . --follow-untyped-imports
python -m mypy .

conda-env-update:
$(CONDA) install -y -c conda-forge conda-merge
Expand Down
7 changes: 7 additions & 0 deletions environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,10 @@ channels:
# - package1
# - package2
# DO NOT EDIT ABOVE THIS LINE, ADD DEPENDENCIES BELOW AS SHOWN IN THE EXAMPLE
dependecies:
- cf_xarray
- netcdf4
- numpy
- xarray
- xesmf
- xcdat
10 changes: 8 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@ dependencies = [
"netCDF4",
"numpy",
"xarray",
"xesmf"
"xesmf",
"xcdat @git+https://github.com/xCDAT/xcdat"
]
description = "NEMO Regional Configuration Toolbox"
dynamic = ["version"]
Expand Down Expand Up @@ -65,6 +66,10 @@ branch = true
[tool.mypy]
strict = false

[[tool.mypy.overrides]]
ignore_missing_imports = true
module = ["xesmf.*", "xcdat.*"]

[tool.ruff]
# Black line length is 88, but black does not format comments.
line-length = 110
Expand All @@ -80,7 +85,8 @@ lint.ignore = [
"D413",
"D415",
"D416",
"D417"
"D417",
"F401"
]
lint.select = [
# pyflakes
Expand Down
60 changes: 48 additions & 12 deletions src/pyic/grid.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,27 @@
import copy as cp
import warnings

import cf_xarray
import numpy as np
import xarray as xr


class GRID:
"""Class that provides methods to handle and regrid gridded datasets for NEMO."""

def open_dataset(self, filename):
def open_dataset(self, filename, convert_to_z, z_kwargs):
"""Open a dataset from a specified filename using xarray.

Args:
filename (str): The path to the dataset file.
convert_to_z (bool)
zkawrgs (dict)

Returns:
xarray.Dataset: The opened dataset.
"""
if convert_to_z:
return self.convert_grid(filename)
return xr.open_dataset(filename) # Use xarray to open the dataset file

def get_dim_varname(self, dimtype):
Expand Down Expand Up @@ -82,14 +87,13 @@ def extract_lonlat(self, lon_name=None, lat_name=None):
lat_da = self.ds[lat_name] # Get the latitude DataArray

# If longitude or latitude is 1D, create a meshgrid for 2D representation
if len(lon_da.shape) == 1:
lon_da = xr.DataArray(
np.meshgrid(lon_da, lon_da), dims=["y", "x"]
) # Create a 2D meshgrid for longitude
if len(lat_da.shape) == 1:
lat_da = xr.DataArray(
np.meshgrid(lat_da, lat_da), dims=["y", "x"]
) # Create a 2D meshgrid for latitude
if lon_da.ndim == 1 and lat_da.ndim == 1:
lon_arr, lat_arr = np.meshgrid(lon_da, lat_da)
lon_da = xr.DataArray(lon_arr, dims=["y", "x"]) # Create a 2D meshgrid for longitude
lat_da = xr.DataArray(lat_arr, dims=["y", "x"]) # Create a 2D meshgrid for latitude

self.ds[lat_name] = lat_da
self.ds[lon_name] = lon_da

return (
lon_da,
Expand All @@ -98,7 +102,7 @@ def extract_lonlat(self, lon_name=None, lat_name=None):
lat_name,
) # Return the longitude and latitude DataArrays and their names

def make_common_coords(self, lon_name, lat_name, time_counter="time_counter"):
def make_common_coords(self, lon_name, lat_name, time_counter="time_counter", convert_to_z_grid=False):
"""Align the grid dataset with common coordinate names for regridding.

Args:
Expand Down Expand Up @@ -129,12 +133,39 @@ def make_common_coords(self, lon_name, lat_name, time_counter="time_counter"):

return ds_grid # Return the modified dataset with common coordinates

def convert_grid(filename, z_kwargs):
"""TODO. Vertical regrid of data.

using xgcm's built in vertical regridder with xCDAT.
"""
# xcdat documentation here https://xcdat.readthedocs.io/en/main-doc-fix/examples/regridding-vertical.html
# xgcm documentation here https://xgcm.readthedocs.io/en/latest/transform.html?highlight=vertical

import xcdat

ds_grid = xr.open_dataset(filename)

if "lev" not in z_kwargs:
raise Exception("Provide z levels to regrid to using z_kwargs = {'lev':[some levels]}.")
if "var" not in z_kwargs:
raise Exception("Provide origin vertical grid variable as z_kwargs = {'var':'so'}.")
if "method" not in z_kwargs:
method = z_kwargs["method"]
else:
method = "linear"

ds_grid = ds_grid.regridder.vertical(z_kwargs["var"], z_kwargs["lev"], tool="xgcm", method=method)

return ds_grid

def __init__(
self,
data_filename=None,
ds_lon_name=None,
ds_lat_name=None,
ds_time_counter="time_counter",
convert_to_z_grid=False,
z_kwargs={},
):
"""Initialize the GRID class with the specified dataset and coordinate names.

Expand All @@ -146,19 +177,24 @@ def __init__(
If None, it will be inferred from common names.
ds_time_counter (str, optional): The name of the time counter variable in the dataset.
If None, it will be inferred from common names.
convert_to_z_grid (bool, optional): whether to convert from a sigma-level grid to
a z-level grid.
z_kwargs (dict, optional): additional details required for vertical regridding
"""
self.data_filename = data_filename # Store the path to the dataset file
self.lon_names = ["glamt", "nav_lon"] # List of potential longitude variable names
self.lat_names = ["gphit", "nav_lat"] # List of potential latitude variable names

# Open the dataset using the provided filename
self.ds = self.open_dataset(self.data_filename)
self.ds = self.open_dataset(self.data_filename, convert_to_z_grid, z_kwargs)

# Extract longitude and latitude DataArrays and their names
self.lon, self.lat, ds_lon_name, ds_lat_name = self.extract_lonlat(ds_lon_name, ds_lat_name)

# Create a common grid with standardized coordinate names
self.common_grid = self.make_common_coords(ds_lon_name, ds_lat_name, ds_time_counter)
self.common_grid = self.make_common_coords(
ds_lon_name, ds_lat_name, ds_time_counter, convert_to_z_grid
)

# Store the names of the longitude and latitude variables for later use
self.coords = {"lon_name": ds_lon_name, "lat_name": ds_lat_name}
Expand Down
Loading