|
7 | 7 | in the [OME-NGFF](https://ngff.openmicroscopy.org/) specification.
|
8 | 8 | """
|
9 | 9 | from __future__ import annotations
|
10 |
| -from typing import Annotated, Any, Literal, Optional, Sequence, TYPE_CHECKING |
| 10 | +from typing import Tuple, List, Annotated, Any, Literal, Optional, Sequence, TYPE_CHECKING |
11 | 11 |
|
12 | 12 | from cellmap_schemas.base import normalize_chunks
|
13 | 13 |
|
|
19 | 19 | from numcodecs.abc import Codec
|
20 | 20 | from numpy.typing import NDArray
|
21 | 21 | from pydantic_zarr.v2 import GroupSpec, ArraySpec
|
| 22 | +from xarray import DataArray |
22 | 23 | from pydantic import BaseModel, Field, model_validator
|
23 | 24 | from cellmap_schemas.multiscale import neuroglancer_n5
|
24 | 25 | from cellmap_schemas.multiscale.neuroglancer_n5 import PixelResolution
|
@@ -81,6 +82,60 @@ def validate_argument_length(self: Self):
|
81 | 82 | )
|
82 | 83 | return self
|
83 | 84 |
|
| 85 | + @classmethod |
| 86 | + def from_xarray(cls, array: DataArray, reverse_axes: bool = False) -> "STTransform": |
| 87 | + """ |
| 88 | + Generate a spatial transform from a DataArray. |
| 89 | +
|
| 90 | + Parameters |
| 91 | + ---------- |
| 92 | +
|
| 93 | + array: xarray.DataArray |
| 94 | + A DataArray with coordinates that can be expressed as scaling + translation |
| 95 | + applied to a regular grid. |
| 96 | + reverse_axes: boolean, default=False |
| 97 | + If `True`, the order of the `axes` in the spatial transform will |
| 98 | + be reversed relative to the order of the dimensions of `array`, and the |
| 99 | + `order` field of the resulting STTransform will be set to "F". This is |
| 100 | + designed for compatibility with N5 tools. |
| 101 | +
|
| 102 | + Returns |
| 103 | + ------- |
| 104 | +
|
| 105 | + STTransform |
| 106 | + An instance of STTransform that is consistent with the coordinates defined |
| 107 | + on the input DataArray. |
| 108 | + """ |
| 109 | + |
| 110 | + orderer = slice(None) |
| 111 | + output_order = "C" |
| 112 | + if reverse_axes: |
| 113 | + orderer = slice(-1, None, -1) |
| 114 | + output_order = "F" |
| 115 | + |
| 116 | + axes = [str(d) for d in array.dims[orderer]] |
| 117 | + # default unit is m |
| 118 | + units = [array.coords[ax].attrs.get("units", "m") for ax in axes] |
| 119 | + translate = [float(array.coords[ax][0]) for ax in axes] |
| 120 | + scale = [] |
| 121 | + for ax in axes: |
| 122 | + if len(array.coords[ax]) > 1: |
| 123 | + scale_estimate = abs( |
| 124 | + float(array.coords[ax][1]) - float(array.coords[ax][0]) |
| 125 | + ) |
| 126 | + else: |
| 127 | + raise ValueError( |
| 128 | + f""" |
| 129 | + Cannot infer scale parameter along dimension {ax} |
| 130 | + with length {len(array.coords[ax])} |
| 131 | + """ |
| 132 | + ) |
| 133 | + scale.append(scale_estimate) |
| 134 | + |
| 135 | + return cls( |
| 136 | + axes=axes, units=units, translate=translate, scale=scale, order=output_order |
| 137 | + ) |
| 138 | + |
84 | 139 |
|
85 | 140 | class ArrayMetadata(BaseModel):
|
86 | 141 | """
|
|
0 commit comments