Skip to content

[DNM] GeoTransformIndex prototype #10055

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion xarray/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -3050,6 +3050,7 @@ def sel(
method: str | None = None,
tolerance: int | float | Iterable[int | float] | None = None,
drop: bool = False,
options: dict[Any, Any] | None = None,
**indexers_kwargs: Any,
) -> Self:
"""Returns a new dataset with each array indexed by tick labels
Expand Down Expand Up @@ -3121,7 +3122,7 @@ def sel(
"""
indexers = either_dict_or_kwargs(indexers, indexers_kwargs, "sel")
query_results = map_index_queries(
self, indexers=indexers, method=method, tolerance=tolerance
self, indexers=indexers, method=method, tolerance=tolerance, options=options
)

if drop:
Expand Down
198 changes: 198 additions & 0 deletions xarray/core/geoindex.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,198 @@
from collections.abc import Hashable, Iterable, Mapping
from functools import lru_cache
from typing import Any

import affine
import numpy as np
import pyproj
from affine import Affine

import xarray as xr
from xarray.core.dataarray import DataArray
from xarray.core.dataset import Dataset
from xarray.core.indexes import CoordinateTransform, CoordinateTransformIndex
from xarray.core.types import Self
from xarray.core.variable import Variable

# https://pyproj4.github.io/pyproj/stable/advanced_examples.html#caching-pyproj-objects
transformer_from_crs = lru_cache(pyproj.Transformer.from_crs)


def with_geoindex(ds: Dataset, dims=("x", "y")) -> Dataset:
"""simple helper function."""

index = GeoTransformIndex.from_grid_mapping(
ds.cf["grid_mapping"], dim_sizes={k: ds.sizes[k] for k in dims}
)
newcoords = xr.Coordinates.from_xindex(index)
# TODO: assign coordinates to (xc, yc)
newds = ds.assign_coords(newcoords) # TODO: got confused with set_xindex
return newds


class Affine2DCoordinateTransform(CoordinateTransform):
"""Affine 2D coordinate transform."""

# Copied from benbovy's https://github.com/pydata/xarray/pull/9543

affine: affine.Affine
xy_dims = tuple[str]

def __init__(
self,
affine: affine.Affine,
coord_names: Iterable[Hashable],
dim_size: Mapping[str, int],
dtype: Any = np.dtype(np.float64),
):
# two dimensions
assert len(coord_names) == 2
assert len(dim_size) == 2

super().__init__(coord_names, dim_size, dtype=dtype)
self.affine = affine

# array dimensions in reverse order (y = rows, x = cols)
self.xy_dims = tuple(self.dims)
self.dims = (self.dims[1], self.dims[0])

def forward(self, dim_positions):
positions = [dim_positions[dim] for dim in self.xy_dims]
x_labels, y_labels = self.affine * tuple(positions)

results = {}
for name, labels in zip(self.coord_names, [x_labels, y_labels], strict=False):
results[name] = labels

return results

def reverse(self, coord_labels):
labels = [coord_labels[name] for name in self.coord_names]
x_positions, y_positions = ~self.affine * tuple(labels)

results = {}
for dim, positions in zip(
self.xy_dims, [x_positions, y_positions], strict=False
):
results[dim] = positions

return results

def equals(self, other):
return self.affine == other.affine and self.dim_size == other.dim_size


class GeoTransformIndex(CoordinateTransformIndex):
def __init__(self, *, transform, grid_mapping: DataArray, coord_names) -> Self:
self.gm_name = grid_mapping.name
self.crs = pyproj.CRS.from_cf(grid_mapping.attrs)
self.xcoord, self.ycoord = coord_names
return super().__init__(transform=transform)

def create_variables(self, variables=None):
res = super().create_variables(variables)
# assign the spatial_ref since the GeoTransform is updated when isel is called.
res[self.gm_name] = Variable(dims=(), data=0, attrs=self.crs.to_cf())
res[self.gm_name].attrs["GeoTransform"] = " ".join(
map(str, self.transform.affine.to_gdal())
)
return res

@classmethod
def from_grid_mapping(
cls,
grid_mapping: Variable,
*,
dim_sizes: dict[str, int],
coord_names=("xc", "yc"),
) -> Self:
geotransform = np.fromstring(
grid_mapping.attrs["GeoTransform"], sep=" "
).tolist()
fwd = Affine.from_gdal(*geotransform[:6])
centers = fwd * fwd.translation(0.5, 0.5)

xtransform = Affine2DCoordinateTransform(
affine=centers,
coord_names=coord_names,
dim_size=dim_sizes,
)
return cls(
transform=xtransform, grid_mapping=grid_mapping, coord_names=coord_names
)

def isel(self, indexers) -> Self:
# FIXME: now this gets called after sel and breaks.
indexers.setdefault("x", slice(None))
indexers.setdefault("y", slice(None))

assert isinstance(indexers["x"], slice)
assert isinstance(indexers["y"], slice)

offsets = tuple((indexers[dim].start or 0) for dim in ("x", "y"))
scales = tuple(indexers[dim].step or 1 for dim in ("x", "y"))
fwd = self.transform.affine
new_transform = fwd * fwd.scale(*scales) * fwd.translation(*offsets)
xtransform = Affine2DCoordinateTransform(
affine=new_transform,
coord_names=self.transform.coord_names,
dim_size=self.transform.dim_size,
)
new_spatial_ref = DataArray(0, attrs=self.crs.to_cf())
new_spatial_ref.attrs["GeoTransform"] = " ".join(
map(str, new_transform.to_gdal())
)
new_index = type(self)(
transform=xtransform,
grid_mapping=new_spatial_ref,
coord_names=self.transform.coord_names,
)
# TODO: we need to return the new spatial_ref, since the GeoTransform may have changed.
return new_index

def sel(self, labels: dict[Any, Any], method="nearest", tolerance=None, crs=None):
"""
This implements a selection API like OGC EDR where queries are always X,Y
interpreted using the provided CRS.
"""
from xarray.core.variable import broadcast_variables

# FIXME: Assume default CRS is self.crs, but we could alternatively assume EPSG:4326
qcrs = self.crs if crs is None else pyproj.CRS.from_user_input(crs)
transformer = transformer_from_crs(
crs_from=qcrs, crs_to=self.crs, always_xy=True
)

X = labels.get(self.xcoord, None)
Y = labels.get(self.ycoord, None)

# TODO: handle, X, Y being None.
assert X is not None
assert Y is not None

transformed_labels = dict(
zip((self.xcoord, self.ycoord), transformer.transform(X, Y), strict=False)
)
# Transformer always returns numpy?
for coord, raw, trans in zip(
labels.keys(), labels.values(), transformed_labels.values(), strict=False
):
if isinstance(raw, Variable):
transformed_labels[coord] = raw.copy(data=trans)
elif isinstance(raw, DataArray):
transformed_labels[coord] = raw.variable.copy(data=trans)
elif isinstance(raw, (np.ndarray | list | tuple)):
transformed_labels[coord] = Variable(dims=(coord,), data=trans)
else:
assert np.isscalar(trans)
transformed_labels[coord] = Variable(dims=(), data=trans)
transformed_labels = dict(
zip(
labels.keys(),
broadcast_variables(*tuple(transformed_labels.values())),
strict=False,
)
)
res = super().sel(labels=transformed_labels, method=method, tolerance=tolerance)
# TODO: create a new spatial_ref here without GeoTransform and return that.
return res
7 changes: 3 additions & 4 deletions xarray/core/indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,7 @@ def map_index_queries(
indexers: Mapping[Any, Any],
method=None,
tolerance: int | float | Iterable[int | float] | None = None,
options: dict[Any, Any] | None = None,
**indexers_kwargs: Any,
) -> IndexSelResult:
"""Execute index queries from a DataArray / Dataset and label-based indexers
Expand All @@ -179,9 +180,7 @@ def map_index_queries(
from xarray.core.dataarray import DataArray

# TODO benbovy - flexible indexes: remove when custom index options are available
if method is None and tolerance is None:
options = {}
else:
if options is None:
options = {"method": method, "tolerance": tolerance}

indexers = either_dict_or_kwargs(indexers, indexers_kwargs, "map_index_queries")
Expand Down Expand Up @@ -2047,7 +2046,7 @@ def dtype(self) -> np.dtype:

@property
def shape(self) -> tuple[int, ...]:
return tuple(self._transform.dim_size.values())
return tuple(self._transform.dim_size[dim] for dim in self._transform.dims)

def get_duck_array(self) -> np.ndarray:
all_coords = self._transform.generate_coords(dims=self._dims)
Expand Down
Loading