Skip to content

Commit 00a8395

Browse files
2 parents 36108f3 + 3f6074c commit 00a8395

File tree

2 files changed

+121
-95
lines changed

2 files changed

+121
-95
lines changed

tests/test_grids/test_sgrid.py

Lines changed: 70 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -1,72 +1,88 @@
11
import os
22

3-
import fsspec
43
import numpy as np
5-
import pytest
64
import xarray as xr
75

86
import xarray_subset_grid.accessor # noqa: F401
97
from tests.test_utils import get_test_file_dir
8+
from xarray_subset_grid.grids.sgrid import _get_location_info_from_topology
109

1110
# open dataset as zarr object using fsspec reference file system and xarray
1211

1312

1413
test_dir = get_test_file_dir()
1514
sample_sgrid_file = os.path.join(test_dir, 'arakawa_c_test_grid.nc')
1615

17-
@pytest.mark.online
18-
def test_polygon_subset():
19-
'''
20-
This is a basic integration test for the subsetting of a ROMS sgrid dataset using a polygon.
21-
'''
22-
if fsspec is None:
23-
raise ImportError("Must have fsspec installed to run --online tests")
24-
25-
fs = fsspec.filesystem(
26-
"reference",
27-
fo="s3://nextgen-dmac-cloud-ingest/nos/wcofs/nos.wcofs.2ds.best.nc.zarr",
28-
remote_protocol="s3",
29-
remote_options={"anon": True},
30-
target_protocol="s3",
31-
target_options={"anon": True},
32-
)
33-
m = fs.get_mapper("")
34-
35-
ds = xr.open_dataset(
36-
m, engine="zarr", backend_kwargs=dict(consolidated=False), chunks={}
37-
)
38-
39-
polygon = np.array(
40-
[
41-
[-122.38488806417945, 34.98888604471138],
42-
[-122.02425311530737, 33.300351211467074],
43-
[-120.60402628930146, 32.723214427630836],
44-
[-116.63789131284673, 32.54346959375448],
45-
[-116.39346090873218, 33.8541384965596],
46-
[-118.83845767505964, 35.257586401855164],
47-
[-121.34541503969862, 35.50073821008141],
48-
[-122.38488806417945, 34.98888604471138],
49-
]
50-
)
51-
ds_temp = ds.xsg.subset_vars(['temp_sur'])
52-
ds_subset = ds_temp.xsg.subset_polygon(polygon)
16+
def test_grid_topology_location_parse():
17+
ds = xr.open_dataset(sample_sgrid_file, decode_times=False)
18+
node_info = _get_location_info_from_topology(ds['grid'], 'node')
19+
edge1_info = _get_location_info_from_topology(ds['grid'], 'edge1')
20+
edge2_info = _get_location_info_from_topology(ds['grid'], 'edge2')
21+
face_info = _get_location_info_from_topology(ds['grid'], 'face')
5322

54-
#Check that the subset dataset has the correct dimensions given the original padding
55-
assert ds_subset.sizes['eta_rho'] == ds_subset.sizes['eta_psi'] + 1
56-
assert ds_subset.sizes['eta_u'] == ds_subset.sizes['eta_psi'] + 1
57-
assert ds_subset.sizes['eta_v'] == ds_subset.sizes['eta_psi']
58-
assert ds_subset.sizes['xi_rho'] == ds_subset.sizes['xi_psi'] + 1
59-
assert ds_subset.sizes['xi_u'] == ds_subset.sizes['xi_psi']
60-
assert ds_subset.sizes['xi_v'] == ds_subset.sizes['xi_psi'] + 1
23+
assert node_info == {'dims': ['xi_psi', 'eta_psi'],
24+
'coords': ['lon_psi', 'lat_psi'],
25+
'padding': {'xi_psi': 'none', 'eta_psi': 'none'}}
26+
assert edge1_info == {'dims': ['xi_u', 'eta_u'],
27+
'coords': ['lon_u', 'lat_u'],
28+
'padding': {'eta_u': 'both', 'xi_u': 'none'}}
29+
assert edge2_info == {'dims': ['xi_v', 'eta_v'],
30+
'coords': ['lon_v', 'lat_v'],
31+
'padding': {'xi_v': 'both', 'eta_v': 'none'}}
32+
assert face_info == {'dims': ['xi_rho', 'eta_rho'],
33+
'coords': ['lon_rho', 'lat_rho'],
34+
'padding': {'xi_rho': 'both', 'eta_rho': 'both'}}
35+
36+
37+
# def test_polygon_subset():
38+
# '''
39+
# This is a basic integration test for the subsetting of a ROMS sgrid dataset using a polygon.
40+
# '''
41+
# fs = fsspec.filesystem(
42+
# "reference",
43+
# fo="s3://nextgen-dmac-cloud-ingest/nos/wcofs/nos.wcofs.2ds.best.nc.zarr",
44+
# remote_protocol="s3",
45+
# remote_options={"anon": True},
46+
# target_protocol="s3",
47+
# target_options={"anon": True},
48+
# )
49+
# m = fs.get_mapper("")
50+
51+
# ds = xr.open_dataset(
52+
# m, engine="zarr", backend_kwargs=dict(consolidated=False), chunks={}
53+
# )
6154

62-
#Check that the subset rho/psi/u/v positional relationship makes sense aka psi point is
63-
#'between' its neighbor rho points
64-
#Note that this needs to be better generalized; it's not trivial to write a test that
65-
#works in all potential cases.
66-
assert (ds_subset['lon_rho'][0,0] < ds_subset['lon_psi'][0,0]
67-
and ds_subset['lon_rho'][0,1] > ds_subset['lon_psi'][0,0])
55+
# polygon = np.array(
56+
# [
57+
# [-122.38488806417945, 34.98888604471138],
58+
# [-122.02425311530737, 33.300351211467074],
59+
# [-120.60402628930146, 32.723214427630836],
60+
# [-116.63789131284673, 32.54346959375448],
61+
# [-116.39346090873218, 33.8541384965596],
62+
# [-118.83845767505964, 35.257586401855164],
63+
# [-121.34541503969862, 35.50073821008141],
64+
# [-122.38488806417945, 34.98888604471138],
65+
# ]
66+
# )
67+
# ds_temp = ds.xsg.subset_vars(['temp_sur'])
68+
# ds_subset = ds_temp.xsg.subset_polygon(polygon)
6869

69-
#ds_subset.temp_sur.isel(ocean_time=0).plot(x="lon_rho", y="lat_rho")
70+
# #Check that the subset dataset has the correct dimensions given the original padding
71+
# assert ds_subset.sizes['eta_rho'] == ds_subset.sizes['eta_psi'] + 1
72+
# assert ds_subset.sizes['eta_u'] == ds_subset.sizes['eta_psi'] + 1
73+
# assert ds_subset.sizes['eta_v'] == ds_subset.sizes['eta_psi']
74+
# assert ds_subset.sizes['xi_rho'] == ds_subset.sizes['xi_psi'] + 1
75+
# assert ds_subset.sizes['xi_u'] == ds_subset.sizes['xi_psi']
76+
# assert ds_subset.sizes['xi_v'] == ds_subset.sizes['xi_psi'] + 1
77+
78+
# #Check that the subset rho/psi/u/v positional relationship makes sense aka psi point is
79+
# #'between' its neighbor rho points
80+
# #Note that this needs to be better generalized; it's not trivial to write a test that
81+
# #works in all potential cases.
82+
# assert (ds_subset['lon_rho'][0,0] < ds_subset['lon_psi'][0,0]
83+
# and ds_subset['lon_rho'][0,1] > ds_subset['lon_psi'][0,0])
84+
85+
# #ds_subset.temp_sur.isel(ocean_time=0).plot(x="lon_rho", y="lat_rho")
7086

7187
def test_polygon_subset_2():
7288
ds = xr.open_dataset(sample_sgrid_file, decode_times=False)
@@ -89,3 +105,5 @@ def test_polygon_subset_2():
89105

90106
assert ds_subset.lon_psi.min() <= 6.5 and ds_subset.lon_psi.max() >= 9.5
91107
assert ds_subset.lat_psi.min() <= 37.5 and ds_subset.lat_psi.max() >= 40.5
108+
109+
assert 'u' in ds_subset.variables.keys()

xarray_subset_grid/grids/sgrid.py

Lines changed: 51 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -108,18 +108,20 @@ def compute_polygon_subset_selector(
108108
dims = _get_sgrid_dim_coord_names(grid_topology)
109109
subset_masks: list[tuple[list[str], xr.DataArray]] = []
110110

111-
node_dims = grid_topology.attrs["node_dimensions"].split()
112-
node_coords = grid_topology.attrs["node_coordinates"].split()
111+
node_info = _get_location_info_from_topology(grid_topology, 'node')
112+
node_dims = node_info['dims']
113+
node_coords = node_info['coords']
114+
115+
unique_dims = set(node_dims)
116+
node_vars = [k for k in ds.variables if unique_dims.issubset(set(ds[k].dims))]
113117

114118
node_lon: xr.DataArray | None = None
115119
node_lat: xr.DataArray | None = None
116120
for c in node_coords:
117-
if 'lon' in c:
121+
if 'lon' in ds[c].standard_name.lower():
118122
node_lon = ds[c]
119-
elif 'lat' in c:
123+
elif 'lat' in ds[c].standard_name.lower():
120124
node_lat = ds[c]
121-
if node_lon is None or node_lat is None:
122-
raise ValueError(f"Could not find lon and lat for dimension {node_dims}")
123125

124126
node_mask = compute_2d_subset_mask(lat=node_lat, lon=node_lon, polygon=polygon)
125127
msk = np.where(node_mask)
@@ -134,28 +136,27 @@ def compute_polygon_subset_selector(
134136
node_mask[index_bounding_box[0][0]:index_bounding_box[0][1],
135137
index_bounding_box[1][0]:index_bounding_box[1][1]] = True
136138

137-
subset_masks.append(([node_coords[0], node_coords[1]], node_mask))
139+
subset_masks.append((node_vars, node_mask))
140+
138141
for s in ('face', 'edge1', 'edge2'):
139-
dims = grid_topology.attrs.get(f"{s}_dimensions", None)
140-
coords = grid_topology.attrs.get(f"{s}_coordinates", None).split()
142+
info = _get_location_info_from_topology(grid_topology, s)
143+
dims = info['dims']
144+
coords = info['coords']
145+
unique_dims = set(dims)
146+
vars = [k for k in ds.variables if unique_dims.issubset(set(ds[k].dims))]
141147

142148
lon: xr.DataArray | None = None
143-
lat: xr.DataArray | None = None
144149
for c in coords:
145150
if 'lon' in ds[c].standard_name.lower():
146151
lon = ds[c]
147-
elif 'lat' in ds[c].standard_name.lower():
148-
lat = ds[c]
149-
if lon is None or lat is None:
150-
raise ValueError(f"Could not find lon and lat for dimension {dims}")
151-
padding = parse_padding_string(dims)
152+
padding = info['padding']
152153
arranged_padding = [padding[d] for d in lon.dims]
153154
arranged_padding = [0 if p == 'none' or p == 'low' else 1 for p in arranged_padding]
154155
mask = np.zeros(lon.shape, dtype=bool)
155156
mask[index_bounding_box[0][0]:index_bounding_box[0][1] + arranged_padding[0],
156157
index_bounding_box[1][0]:index_bounding_box[1][1] + arranged_padding[1]] = True
157158
xr_mask = xr.DataArray(mask, dims=lon.dims)
158-
subset_masks.append(([coords[0], coords[1]], xr_mask))
159+
subset_masks.append((vars, xr_mask))
159160

160161
return SGridSelector(
161162
name=name or 'selector',
@@ -165,6 +166,40 @@ def compute_polygon_subset_selector(
165166
subset_masks=subset_masks,
166167
)
167168

169+
def _get_location_info_from_topology(grid_topology: xr.DataArray, location) -> dict[str, str]:
170+
'''Get the dimensions and coordinates for a given location from the grid_topology'''
171+
rdict = {}
172+
dim_str = grid_topology.attrs.get(f"{location}_dimensions", None)
173+
coord_str = grid_topology.attrs.get(f"{location}_coordinates", None)
174+
if dim_str is None or coord_str is None:
175+
raise ValueError(f"Could not find {location} dimensions or coordinates")
176+
# Remove padding for now
177+
dims_only = " ".join([v for v in dim_str.split(" ") if "(" not in v and ")" not in v])
178+
if ":" in dims_only:
179+
dims_only = [s.replace(":", "") for s in dims_only.split(" ") if ":" in s]
180+
else:
181+
dims_only = dims_only.split(" ")
182+
183+
padding = dim_str.replace(':', '').split(')')
184+
pdict = {}
185+
if len(padding) == 3: #two padding values
186+
pdict[dims_only[0]] = padding[0].split(' ')[-1]
187+
pdict[dims_only[1]] = padding[1].split(' ')[-1]
188+
elif len(padding) == 2: #one padding value
189+
if padding[-1] == '': #padding is on second dim
190+
pdict[dims_only[1]] = padding[0].split(' ')[-1]
191+
pdict[dims_only[0]] = 'none'
192+
else:
193+
pdict[dims_only[0]] = padding[0].split(' ')[-1]
194+
pdict[dims_only[1]] = 'none'
195+
else:
196+
pdict[dims_only[0]] = 'none'
197+
pdict[dims_only[1]] = 'none'
198+
199+
rdict['dims'] = dims_only
200+
rdict['coords'] = coord_str.split(" ")
201+
rdict['padding'] = pdict
202+
return rdict
168203

169204
def _get_sgrid_dim_coord_names(
170205
grid_topology: xr.DataArray,
@@ -189,30 +224,3 @@ def _get_sgrid_dim_coord_names(
189224
coords.append(v.split(" "))
190225

191226
return list(zip(dims, coords))
192-
193-
def parse_padding_string(dim_string):
194-
'''
195-
Given a grid_topology dimension string, parse the padding for each dimension.
196-
Returns a dict of {dim0name: padding,
197-
dim1name: padding
198-
}
199-
valid values of padding are: 'none', 'low', 'high', 'both'
200-
'''
201-
parsed_string = dim_string.replace('(padding: ', '').replace(')', '').replace(':', '')
202-
split_parsed_string = parsed_string.split(' ')
203-
if len(split_parsed_string) == 6:
204-
return {split_parsed_string[0]:split_parsed_string[2],
205-
split_parsed_string[3]:split_parsed_string[5]}
206-
elif len(split_parsed_string) == 5:
207-
if split_parsed_string[4] in {'none', 'low', 'high', 'both'}:
208-
#2nd dim has padding, and with len 5 that means first does not
209-
split_parsed_string.insert(2, 'none')
210-
else:
211-
split_parsed_string.insert(5, 'none')
212-
return {split_parsed_string[0]:split_parsed_string[2],
213-
split_parsed_string[3]:split_parsed_string[5]}
214-
elif len(split_parsed_string) == 2:
215-
#node dimensions string could look like this: 'node_dimensions: xi_psi eta_psi'
216-
return {split_parsed_string[0]: 'none', split_parsed_string[1]: 'none'}
217-
else:
218-
raise ValueError(f"Padding parsing failure: {dim_string}")

0 commit comments

Comments
 (0)