Skip to content

NetCDF Compatibility and requirements.txt update #228

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
matplotlib
matplotlib==3.7.1
opencv-python
pint
polygon3
Expand All @@ -8,4 +8,4 @@ scipy
zarr
netCDF4
numpy
numba
numba
147 changes: 75 additions & 72 deletions src/py_eddy_tracker/dataset/grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,7 @@ class GridDataset(object):
"filename",
"dimensions",
"indexs",
"nc4file",
"variables_description",
"global_attrs",
"vars",
Expand All @@ -275,6 +276,7 @@ def __init__(
indexs=None,
unset=False,
nan_masking=False,
nc4file=None,
):
"""
:param str filename: Filename to load
Expand All @@ -301,6 +303,7 @@ def __init__(
self.coordinates = x_name, y_name
self.vars = dict()
self.indexs = dict() if indexs is None else indexs
self.nc4file = Dataset(filename, "r") if nc4file is None else nc4file
if centered is None:
logger.warning(
"We assume pixel position of grid is centered for %s", filename
Expand Down Expand Up @@ -344,25 +347,25 @@ def load_general_features(self):
logger.debug(
"Load general feature from %(filename)s", dict(filename=self.filename)
)
with Dataset(self.filename) as h:
# Load generals
self.dimensions = {i: len(v) for i, v in h.dimensions.items()}
self.variables_description = dict()
for i, v in h.variables.items():
args = (i, v.datatype)
kwargs = dict(dimensions=v.dimensions, zlib=True)
if hasattr(v, "_FillValue"):
kwargs["fill_value"] = (v._FillValue,)
attrs = dict()
for attr in v.ncattrs():
if attr in kwargs.keys():
continue
if attr == "_FillValue":
continue
attrs[attr] = getattr(v, attr)
self.variables_description[i] = dict(
args=args, kwargs=kwargs, attrs=attrs, infos=dict()
)
h = self.nc4file
# Load generals
self.dimensions = {i: len(v) for i, v in h.dimensions.items()}
self.variables_description = dict()
for i, v in h.variables.items():
args = (i, v.datatype)
kwargs = dict(dimensions=v.dimensions, zlib=True)
if hasattr(v, "_FillValue"):
kwargs["fill_value"] = (v._FillValue,)
attrs = dict()
for attr in v.ncattrs():
if attr in kwargs.keys():
continue
if attr == "_FillValue":
continue
attrs[attr] = getattr(v, attr)
self.variables_description[i] = dict(
args=args, kwargs=kwargs, attrs=attrs, infos=dict()
)
self.global_attrs = {attr: getattr(h, attr) for attr in h.ncattrs()}

def write(self, filename):
Expand Down Expand Up @@ -407,14 +410,14 @@ def load(self):
Get coordinates and setup coordinates function
"""
x_name, y_name = self.coordinates
with Dataset(self.filename) as h:
self.x_dim = h.variables[x_name].dimensions
self.y_dim = h.variables[y_name].dimensions
h = self.nc4file
self.x_dim = h.variables[x_name].dimensions
self.y_dim = h.variables[y_name].dimensions

sl_x = [self.indexs.get(dim, slice(None)) for dim in self.x_dim]
sl_y = [self.indexs.get(dim, slice(None)) for dim in self.y_dim]
self.vars[x_name] = h.variables[x_name][sl_x]
self.vars[y_name] = h.variables[y_name][sl_y]
sl_x = [self.indexs.get(dim, slice(None)) for dim in self.x_dim]
sl_y = [self.indexs.get(dim, slice(None)) for dim in self.y_dim]
self.vars[x_name] = h.variables[x_name][sl_x]
self.vars[y_name] = h.variables[y_name][sl_y]

self.setup_coordinates()

Expand Down Expand Up @@ -481,10 +484,10 @@ def units(self, varname):
stored_units = self.variables_description[varname]["attrs"].get("units", None)
if stored_units is not None:
return stored_units
with Dataset(self.filename) as h:
var = h.variables[varname]
if hasattr(var, "units"):
return var.units
h = self.nc4file
var = h.variables[varname]
if hasattr(var, "units"):
return var.units

@property
def variables(self):
Expand Down Expand Up @@ -535,24 +538,24 @@ def grid(self, varname, indexs=None):
"Load %(varname)s from %(filename)s",
dict(varname=varname, filename=self.filename),
)
with Dataset(self.filename) as h:
dims = h.variables[varname].dimensions
sl = [
indexs.get(
dim,
self.indexs.get(
dim, slice(None) if dim in coordinates_dims else 0
),
)
for dim in dims
]
self.vars[varname] = h.variables[varname][sl]
if len(self.x_dim) == 1:
i_x = where(array(dims) == self.x_dim)[0][0]
i_y = where(array(dims) == self.y_dim)[0][0]
if i_x > i_y:
self.variables_description[varname]["infos"]["transpose"] = True
self.vars[varname] = self.vars[varname].T
h = self.nc4file
dims = h.variables[varname].dimensions
sl = [
indexs.get(
dim,
self.indexs.get(
dim, slice(None) if dim in coordinates_dims else 0
),
)
for dim in dims
]
self.vars[varname] = h.variables[varname][sl]
if len(self.x_dim) == 1:
i_x = where(array(dims) == self.x_dim)[0][0]
i_y = where(array(dims) == self.y_dim)[0][0]
if i_x > i_y:
self.variables_description[varname]["infos"]["transpose"] = True
self.vars[varname] = self.vars[varname].T
if self.nan_mask:
self.vars[varname] = ma.array(
self.vars[varname],
Expand All @@ -578,20 +581,20 @@ def grid_tiles(self, varname, slice_x, slice_y):
slice_x=slice_x,
),
)
with Dataset(self.filename) as h:
dims = h.variables[varname].dimensions
sl = [
(slice_x if dim in list(self.x_dim) else slice_y)
if dim in coordinates_dims
else 0
for dim in dims
]
data = h.variables[varname][sl]
if len(self.x_dim) == 1:
i_x = where(array(dims) == self.x_dim)[0][0]
i_y = where(array(dims) == self.y_dim)[0][0]
if i_x > i_y:
data = data.T
h = self.nc4file
dims = h.variables[varname].dimensions
sl = [
(slice_x if dim in list(self.x_dim) else slice_y)
if dim in coordinates_dims
else 0
for dim in dims
]
data = h.variables[varname][sl]
if len(self.x_dim) == 1:
i_x = where(array(dims) == self.x_dim)[0][0]
i_y = where(array(dims) == self.y_dim)[0][0]
if i_x > i_y:
data = data.T
if not hasattr(data, "mask"):
data = ma.array(data, mask=zeros(data.shape, dtype="bool"))
return data
Expand Down Expand Up @@ -1086,19 +1089,19 @@ class UnRegularGridDataset(GridDataset):
def load(self):
"""Load variable (data)"""
x_name, y_name = self.coordinates
with Dataset(self.filename) as h:
self.x_dim = h.variables[x_name].dimensions
self.y_dim = h.variables[y_name].dimensions
h = self.nc4file
self.x_dim = h.variables[x_name].dimensions
self.y_dim = h.variables[y_name].dimensions

sl_x = [self.indexs.get(dim, slice(None)) for dim in self.x_dim]
sl_y = [self.indexs.get(dim, slice(None)) for dim in self.y_dim]
self.vars[x_name] = h.variables[x_name][sl_x]
self.vars[y_name] = h.variables[y_name][sl_y]
sl_x = [self.indexs.get(dim, slice(None)) for dim in self.x_dim]
sl_y = [self.indexs.get(dim, slice(None)) for dim in self.y_dim]
self.vars[x_name] = h.variables[x_name][sl_x]
self.vars[y_name] = h.variables[y_name][sl_y]

self.x_c = self.vars[x_name]
self.y_c = self.vars[y_name]
self.x_c = self.vars[x_name]
self.y_c = self.vars[y_name]

self.init_pos_interpolator()
self.init_pos_interpolator()

@property
def bounds(self):
Expand Down