Skip to content

Commit b186e89

Browse files
committed
Fix: Missing from_xarray method in STTransform class
1 parent 0baf32b commit b186e89

File tree

2 files changed

+58
-2
lines changed

2 files changed

+58
-2
lines changed

pyproject.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,8 @@ classifiers = [
2525
dependencies = [
2626
"pydantic-zarr >= 0.7.0",
2727
"s3fs >= 2023.10.0",
28-
"rich >= 13.7.0"
28+
"rich >= 13.7.0",
29+
"xarray >=2022.03.0"
2930
]
3031

3132
[project.urls]

src/cellmap_schemas/multiscale/cosem.py

Lines changed: 56 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
in the [OME-NGFF](https://ngff.openmicroscopy.org/) specification.
88
"""
99
from __future__ import annotations
10-
from typing import Annotated, Any, Literal, Optional, Sequence, TYPE_CHECKING
10+
from typing import Tuple, List, Annotated, Any, Literal, Optional, Sequence, TYPE_CHECKING
1111

1212
from cellmap_schemas.base import normalize_chunks
1313

@@ -19,6 +19,7 @@
1919
from numcodecs.abc import Codec
2020
from numpy.typing import NDArray
2121
from pydantic_zarr.v2 import GroupSpec, ArraySpec
22+
from xarray import DataArray
2223
from pydantic import BaseModel, Field, model_validator
2324
from cellmap_schemas.multiscale import neuroglancer_n5
2425
from cellmap_schemas.multiscale.neuroglancer_n5 import PixelResolution
@@ -81,6 +82,60 @@ def validate_argument_length(self: Self):
8182
)
8283
return self
8384

85+
@classmethod
86+
def from_xarray(cls, array: DataArray, reverse_axes: bool = False) -> "STTransform":
87+
"""
88+
Generate a spatial transform from a DataArray.
89+
90+
Parameters
91+
----------
92+
93+
array: xarray.DataArray
94+
A DataArray with coordinates that can be expressed as scaling + translation
95+
applied to a regular grid.
96+
reverse_axes: boolean, default=False
97+
If `True`, the order of the `axes` in the spatial transform will
98+
be reversed relative to the order of the dimensions of `array`, and the
99+
`order` field of the resulting STTransform will be set to "F". This is
100+
designed for compatibility with N5 tools.
101+
102+
Returns
103+
-------
104+
105+
STTransform
106+
An instance of STTransform that is consistent with the coordinates defined
107+
on the input DataArray.
108+
"""
109+
110+
orderer = slice(None)
111+
output_order = "C"
112+
if reverse_axes:
113+
orderer = slice(-1, None, -1)
114+
output_order = "F"
115+
116+
axes = [str(d) for d in array.dims[orderer]]
117+
# default unit is m
118+
units = [array.coords[ax].attrs.get("units", "m") for ax in axes]
119+
translate = [float(array.coords[ax][0]) for ax in axes]
120+
scale = []
121+
for ax in axes:
122+
if len(array.coords[ax]) > 1:
123+
scale_estimate = abs(
124+
float(array.coords[ax][1]) - float(array.coords[ax][0])
125+
)
126+
else:
127+
raise ValueError(
128+
f"""
129+
Cannot infer scale parameter along dimension {ax}
130+
with length {len(array.coords[ax])}
131+
"""
132+
)
133+
scale.append(scale_estimate)
134+
135+
return cls(
136+
axes=axes, units=units, translate=translate, scale=scale, order=output_order
137+
)
138+
84139

85140
class ArrayMetadata(BaseModel):
86141
"""

0 commit comments

Comments
 (0)