update http_tensorstore #19

Open
wants to merge 22 commits into base: http_tensorstore

Commits (22)
8aa687c  Force script variable definitions to usable type. (rhoadesScholar, Jan 29, 2025)
8d7ffb5  Change log level from error to info when opening dataset files (rhoadesScholar, Jan 29, 2025)
16d7b23  Refactor server_check function to accept script path and dataset as a… (rhoadesScholar, Jan 29, 2025)
c166d9b  Merge branch 'main' into script_type_fix (rhoadesScholar, Jan 29, 2025)
fb56154  Add confirmation message for successful server check (rhoadesScholar, Jan 29, 2025)
8104ef1  Add script server check command to CLI for model validation (rhoadesScholar, Jan 29, 2025)
ebcaa5f  Revert server_check.py to use hardcoded script path and dataset; add … (rhoadesScholar, Jan 29, 2025)
36842e5  Fix model_setup04.py and server_check.py for improved readability and… (rhoadesScholar, Jan 29, 2025)
c90fc31  Merge pull request #15 from janelia-cellmap/main (rhoadesScholar, Jan 29, 2025)
9bc8706  fix typo (mzouink, Jan 30, 2025)
6a5891f  Merge pull request #17 from janelia-cellmap/norm (mzouink, Jan 30, 2025)
489162f  Merge pull request #20 from janelia-cellmap/main (rhoadesScholar, Jan 30, 2025)
9e9d476  Merge pull request #13 from janelia-cellmap/script_type_fix (mzouink, Jan 30, 2025)
786f16a  minor fix (mzouink, Jan 31, 2025)
b4586d1  support multiscale (mzouink, Jan 31, 2025)
cca3e50  update norm (mzouink, Jan 31, 2025)
9c75ef7  add needed deps (mzouink, Jan 31, 2025)
5c9f6e8  fix: :bug: fix issues where non-uniform block_shapes are used (davidackerman, Feb 5, 2025)
3d145a2  refactor: :recycle: move stuff around (davidackerman, Feb 5, 2025)
5cc93b9  Merge pull request #24 from janelia-cellmap/nonuniform_chunkshape_fix (davidackerman, Feb 5, 2025)
70f1fd8  fix: :bug: adjust shape appropriately when input and output resolutio… (davidackerman, Feb 6, 2025)
df652ef  Merge pull request #25 from janelia-cellmap/input_output_resolution_m… (davidackerman, Feb 6, 2025)
22 changes: 22 additions & 0 deletions cellmap_flow/cli/cli.py
@@ -2,7 +2,9 @@
 import logging
 import click
 
+from cellmap_flow.server import CellMapFlowServer
 from cellmap_flow.utils.bsub_utils import start_hosts
+from cellmap_flow.utils.data import ScriptModelConfig
 from cellmap_flow.utils.neuroglancer_utils import generate_neuroglancer_link
 
 
@@ -152,6 +154,26 @@ def bioimage(model_path, data_path, queue, charge_group):
     run(command, data_path, queue, charge_group)
 
 
+@cli.command()
+@click.option(
+    "--script_path",
+    "-s",
+    type=str,
+    help="Path to the Python script containing model specification",
+)
+@click.option("--dataset", "-d", type=str, help="Path to the dataset")
+def script_server_check(script_path, dataset):
+    model_config = ScriptModelConfig(script_path=script_path)
+    server = CellMapFlowServer(dataset, model_config)
+    chunk_x = 2
+    chunk_y = 2
+    chunk_z = 2
+
+    server._chunk_impl(None, None, chunk_x, chunk_y, chunk_z, None)
+
+    print("Server check passed")
+
+
 def run(
     command,
     dataset_path,
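For a quick smoke test of the new command, a sketch like the following can drive it in-process with Click's test runner. Both paths are hypothetical placeholders, and the dashed command name assumes Click's default underscore-to-dash renaming:

```python
# Minimal sketch, not part of the PR: exercise the new CLI command end to end.
from click.testing import CliRunner

from cellmap_flow.cli.cli import cli

runner = CliRunner()
result = runner.invoke(
    cli,
    ["script-server-check", "-s", "/path/to/model_spec.py", "-d", "/path/to/data.zarr/raw"],
)
assert result.exit_code == 0
print(result.output)  # expected to end with "Server check passed"
```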
2 changes: 0 additions & 2 deletions cellmap_flow/image_data_interface.py
@@ -31,7 +31,6 @@ def __init__(
             self.output_voxel_size = output_voxel_size
         else:
             self.output_voxel_size = self.voxel_size
-
 
     @property
     def ts(self):
@@ -51,5 +50,4 @@ def to_ndarray_ts(self, roi=None):
             self.swap_axes,
             self.custom_fill_value,
         )
-        self.ts = None
         return res
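Dropping `self.ts = None` keeps the tensorstore handle alive between reads instead of forcing a reopen on every chunk. A minimal sketch of the lazy-open pattern this preserves, with illustrative names that are not the project's actual implementation:

```python
# Illustrative sketch only; class and attribute names are assumptions.
import tensorstore as ts

class CachedStore:
    """Opens a tensorstore once and reuses the handle for every read."""

    def __init__(self, spec: dict):
        self._spec = spec
        self._ts = None

    @property
    def store(self):
        if self._ts is None:  # open lazily on first access
            self._ts = ts.open(self._spec).result()
        return self._ts       # resetting this to None would force a costly reopen
```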
11 changes: 8 additions & 3 deletions cellmap_flow/inferencer.py
@@ -28,7 +28,7 @@ def predict(read_roi, write_roi, config, **kwargs):
     device = kwargs.get("device")
     if device is None:
         raise ValueError("device must be provided in kwargs")
-
+
     use_half_prediction = kwargs.get("use_half_prediction", False)
 
     raw_input = idi.to_ndarray_ts(read_roi)
@@ -54,7 +54,7 @@ def __init__(self, model_config: ModelConfig, use_half_prediction=False):
             self.model_config.config, "write_shape"
         ):
             self.context = (
-                Coordinate(self.model_config.config.read_shape)
+                Coordinate(self.model_config.config.read_shape)
                 - Coordinate(self.model_config.config.write_shape)
             ) / 2
 
@@ -115,7 +115,12 @@ def process_chunk_basic(self, idi, roi):
 
         input_roi = output_roi.grow(self.context, self.context)
         result = self.model_config.config.predict(
-            input_roi, output_roi, self.model_config.config, idi=idi, device=self.device, use_half_prediction=self.use_half_prediction
+            input_roi,
+            output_roi,
+            self.model_config.config,
+            idi=idi,
+            device=self.device,
+            use_half_prediction=self.use_half_prediction,
         )
         write_data = self.model_config.config.normalize_output(result)
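The `context` computed above is the halo a model needs around each write region. A worked example with assumed shapes, using the same funlib.geometry types the code imports:

```python
# Worked example with made-up shapes; funlib.geometry supplies Coordinate/Roi.
from funlib.geometry import Coordinate, Roi

read_shape = Coordinate((36, 36, 36))     # what the model consumes
write_shape = Coordinate((20, 20, 20))    # what the model produces
context = (read_shape - write_shape) / 2  # (8, 8, 8) halo per side

output_roi = Roi((40, 40, 40), write_shape)    # offset, shape in world units
input_roi = output_roi.grow(context, context)  # offset (32, 32, 32), shape (36, 36, 36)
```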
4 changes: 2 additions & 2 deletions cellmap_flow/norm/input_normalize.py
@@ -33,13 +33,13 @@ def __init__(self, min_value=0.0, max_value=255.0):
 
     def normalize(self, data: np.ndarray) -> np.ndarray:
         data = data.astype(np.float32)
-        data.clip(self.min_value, self.max_value)
+        data = data.clip(self.min_value, self.max_value)
         return ((data - self.min_value) / (self.max_value - self.min_value)).astype(
             np.float32
         )
 
 
-NormalizationMethods = [f.name() for f in InputNormalizer.__subclasses__()]
+NormalizationMethods = [f for f in InputNormalizer.__subclasses__()]
 
 
 def get_normalization(elms: dict) -> InputNormalizer:
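The one-line change matters because `ndarray.clip` returns a clipped copy rather than modifying the array in place, so the old call was a no-op:

```python
import numpy as np

data = np.array([-5.0, 100.0, 300.0], dtype=np.float32)

data.clip(0.0, 255.0)         # returns a copy that is discarded; data is unchanged
assert data[2] == 300.0

data = data.clip(0.0, 255.0)  # the fix: rebind to the clipped result
assert data[2] == 255.0
```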
35 changes: 26 additions & 9 deletions cellmap_flow/server.py
@@ -30,7 +30,17 @@ def __init__(self, dataset_name: str, model_config: ModelConfig):
         """
         Initialize the server and set up routes via decorators.
         """
-        self.block_shape = [int(x) for x in model_config.config.block_shape]
+
+        # this is zyx
+        self.read_block_shape = [int(x) for x in model_config.config.block_shape]
+
+        # this needs to have z and x swapped
+        self.n5_block_shape = self.read_block_shape.copy()
+        self.n5_block_shape[0], self.n5_block_shape[2] = (
+            self.n5_block_shape[2],
+            self.n5_block_shape[0],
+        )
+
         self.input_voxel_size = Coordinate(model_config.config.input_voxel_size)
         self.output_voxel_size = Coordinate(model_config.config.output_voxel_size)
         self.output_channels = model_config.config.output_channels
@@ -41,22 +51,29 @@ def __init__(self, dataset_name: str, model_config: ModelConfig):
         self.idi_raw = ImageDataInterface(
             dataset_name, target_resolution=self.input_voxel_size
         )
+        output_shape = (
+            np.array(self.idi_raw.shape)
+            * np.array(self.input_voxel_size)
+            / np.array(self.output_voxel_size)
+        )
+
         if ".zarr" in dataset_name:
             # Convert from (z, y, x) -> (x, y, z) plus channels
             self.vol_shape = np.array(
-                [*np.array(self.idi_raw.shape)[::-1], self.output_channels]
+                [
+                    *output_shape[::-1],
+                    self.output_channels,
+                ]
             )
             self.axis = ["x", "y", "z", "c^"]
         else:
             # For non-Zarr data
-            self.vol_shape = np.array(
-                [*np.array(self.idi_raw.shape), self.output_channels]
-            )
+            self.vol_shape = np.array([*output_shape, self.output_channels])
             self.axis = ["z", "y", "x", "c^"]
 
         # Chunk encoding for N5
         self.chunk_encoder = N5ChunkWrapper(
-            np.uint8, self.block_shape, compressor=numcodecs.GZip()
+            np.uint8, self.n5_block_shape, compressor=numcodecs.GZip()
         )
 
         # Create and configure Flask
@@ -269,7 +286,7 @@ def _attributes_impl(self, dataset, scale):
                 "translate": [0.0, 0.0, 0.0, 0.0],
             },
             "compression": {"type": "gzip", "useZlib": False, "level": -1},
-            "blockSize": list(self.block_shape),
+            "blockSize": list(self.n5_block_shape),
             "dataType": "uint8",
             "dimensions": self.vol_shape.tolist(),
         }
@@ -292,8 +309,8 @@ def _input_normalize_impl(self, norm_type, min_value, max_value):
         )
 
     def _chunk_impl(self, dataset, scale, chunk_x, chunk_y, chunk_z, chunk_c):
-        corner = self.block_shape[:3] * np.array([chunk_z, chunk_y, chunk_x])
-        box = np.array([corner, self.block_shape[:3]]) * self.output_voxel_size
+        corner = self.read_block_shape[:3] * np.array([chunk_z, chunk_y, chunk_x])
+        box = np.array([corner, self.read_block_shape[:3]]) * self.output_voxel_size
         roi = Roi(box[0], box[1])
         chunk_data = self.inferencer.process_chunk(self.idi_raw, roi)
         return (
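Two things change here: N5 attributes advertise `blockSize` in (x, y, z) order while reads stay in (z, y, x), and the advertised volume shape is rescaled from input to output resolution. A numeric sketch with made-up sizes:

```python
import numpy as np

# blockSize for N5 metadata: swap z and x relative to the read order
read_block_shape = [64, 128, 256]  # (z, y, x), as read from the model config
n5_block_shape = read_block_shape.copy()
n5_block_shape[0], n5_block_shape[2] = n5_block_shape[2], n5_block_shape[0]
assert n5_block_shape == [256, 128, 64]  # (x, y, z)

# advertised shape: raw voxel counts rescaled to the output voxel size
raw_shape = np.array([1000, 2000, 3000])  # voxels at input resolution
input_voxel_size = np.array([8, 8, 8])    # nm
output_voxel_size = np.array([4, 4, 4])   # nm
output_shape = raw_shape * input_voxel_size / output_voxel_size
assert (output_shape == [2000.0, 4000.0, 6000.0]).all()
```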
67 changes: 49 additions & 18 deletions cellmap_flow/utils/neuroglancer_utils.py
@@ -1,28 +1,55 @@
 import neuroglancer
+import itertools
 import logging
+import os
 
 neuroglancer.set_server_bind_address("0.0.0.0")
 
 logger = logging.getLogger(__name__)
 
 from cellmap_flow.image_data_interface import ImageDataInterface
+from cellmap_flow.utils.scale_pyramid import ScalePyramid
 
 
 # TODO support multiresolution datasets
-def get_raw_layer(dataset_path, filetype):
-    if filetype == "zarr":
+def get_raw_layer(dataset_path, filetype, is_multiscale=False):
+    if filetype == "n5":
         axis = ["x", "y", "z"]
     else:
         axis = ["z", "y", "x"]
-    image = ImageDataInterface(dataset_path)
-    return neuroglancer.ImageLayer(
-        source=neuroglancer.LocalVolume(
-            data=image.ts,
-            dimensions=neuroglancer.CoordinateSpace(
-                names=axis,
-                units="nm",
-                scales=image.voxel_size,
-            ),
-            voxel_offset=image.offset,
+
+    layers = []
+
+    if is_multiscale:
+        scales = [
+            f for f in os.listdir(dataset_path) if f[0] == "s" and f[1:].isdigit()
+        ]
+        scales.sort(key=lambda x: int(x[1:]))
+        for scale in scales:
+            image = ImageDataInterface(f"{os.path.join(dataset_path, scale)}")
+            layers.append(
+                neuroglancer.LocalVolume(
+                    data=image.ts,
+                    dimensions=neuroglancer.CoordinateSpace(
+                        names=axis,
+                        units="nm",
+                        scales=image.voxel_size,
+                    ),
+                    voxel_offset=image.offset,
+                )
+            )
+        return ScalePyramid(layers)
+    else:
+        image = ImageDataInterface(dataset_path)
+        return neuroglancer.ImageLayer(
+            source=neuroglancer.LocalVolume(
+                data=image.ts,
+                dimensions=neuroglancer.CoordinateSpace(
+                    names=axis,
+                    units="nm",
+                    scales=image.voxel_size,
+                ),
+                voxel_offset=image.offset,
+            )
         )
 
@@ -33,20 +60,24 @@ def generate_neuroglancer_link(dataset_path, inference_dict):
 
     # Add a layer to the viewer
     with viewer.txn() as s:
+        is_multi_scale = False
         # if multiscale dataset
-        # if (
-        #     dataset_path.split("/")[-1].startswith("s")
-        #     and dataset_path.split("/")[-1][1:].isdigit()
-        # ):
-        #     dataset_path = dataset_path.rsplit("/", 1)[0]
+        if (
+            dataset_path.split("/")[-1].startswith("s")
+            and dataset_path.split("/")[-1][1:].isdigit()
+        ):
+            dataset_path = dataset_path.rsplit("/", 1)[0]
+            is_multi_scale = True
+
         if ".zarr" in dataset_path:
             filetype = "zarr"
         elif ".n5" in dataset_path:
             filetype = "n5"
         else:
             filetype = "precomputed"
         if dataset_path.startswith("/"):
-            s.layers["raw"] = get_raw_layer(dataset_path, filetype)
+            layer = get_raw_layer(dataset_path, filetype, is_multi_scale)
+            s.layers.append("raw", layer)
         # if "nrs/cellmap" in dataset_path:
         #     security = "https"
         #     dataset_path = dataset_path.replace("/nrs/cellmap/", "nrs/")
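The multiscale convention used above: a dataset path whose final component is `s<N>` is treated as one level of a scale pyramid, the pyramid root replaces it, and one layer is built per `sN` subdirectory. A small sketch of that logic, with a hypothetical path:

```python
import os

dataset_path = "/groups/example/data.n5/volumes/raw/s1"  # hypothetical path

leaf = dataset_path.split("/")[-1]
is_multi_scale = leaf.startswith("s") and leaf[1:].isdigit()
if is_multi_scale:
    dataset_path = dataset_path.rsplit("/", 1)[0]  # .../volumes/raw

# one layer per sN directory, sorted by pyramid level
scales = sorted(
    (f for f in os.listdir(dataset_path) if f[0] == "s" and f[1:].isdigit()),
    key=lambda s: int(s[1:]),
)
print(scales)  # e.g. ["s0", "s1", "s2"]
```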
113 changes: 113 additions & 0 deletions cellmap_flow/utils/scale_pyramid.py
@@ -0,0 +1,113 @@
# copied from https://github.com/funkelab/funlib.show.neuroglancer/blob/master/funlib/show/neuroglancer/scale_pyramid.py

import neuroglancer
import operator
import logging

import numpy as np


logger = logging.getLogger(__name__)


class ScalePyramid(neuroglancer.LocalVolume):
    """A neuroglancer layer that provides volume data on different scales.
    Mimics a LocalVolume.

    Args:

        volume_layers (``list`` of ``LocalVolume``):

            One ``LocalVolume`` per provided resolution.
    """

    def __init__(self, volume_layers):
        volume_layers = volume_layers

        super(neuroglancer.LocalVolume, self).__init__()

        logger.info("Creating scale pyramid...")

        self.min_voxel_size = min(
            [tuple(layer.dimensions.scales) for layer in volume_layers]
        )
        self.max_voxel_size = max(
            [tuple(layer.dimensions.scales) for layer in volume_layers]
        )

        self.dims = len(volume_layers[0].dimensions.scales)
        self.volume_layers = {
            tuple(
                int(x)
                for x in map(
                    operator.truediv, layer.dimensions.scales, self.min_voxel_size
                )
            ): layer
            for layer in volume_layers
        }

        logger.info("min_voxel_size: %s", self.min_voxel_size)
        logger.info("scale keys: %s", self.volume_layers.keys())
        logger.info(self.info())

    @property
    def volume_type(self):
        return self.volume_layers[(1,) * self.dims].volume_type

    @property
    def token(self):
        return self.volume_layers[(1,) * self.dims].token

    def info(self):
        reference_layer = self.volume_layers[(1,) * self.dims]
        # return reference_layer.info()

        reference_info = reference_layer.info()

        info = {
            "dataType": reference_info["dataType"],
            "encoding": reference_info["encoding"],
            "generation": reference_info["generation"],
            "coordinateSpace": reference_info["coordinateSpace"],
            "shape": reference_info["shape"],
            "volumeType": reference_info["volumeType"],
            "voxelOffset": reference_info["voxelOffset"],
            "chunkLayout": reference_info["chunkLayout"],
            "downsamplingLayout": reference_info["downsamplingLayout"],
            "maxDownsampling": int(
                np.prod(np.array(self.max_voxel_size) // np.array(self.min_voxel_size))
            ),
            "maxDownsampledSize": reference_info["maxDownsampledSize"],
            "maxDownsamplingScales": reference_info["maxDownsamplingScales"],
        }

        return info

    def get_encoded_subvolume(self, data_format, start, end, scale_key=None):
        if scale_key is None:
            scale_key = ",".join(("1",) * self.dims)

        scale = tuple(int(s) for s in scale_key.split(","))
        closest_scale = None
        min_diff = np.inf
        for volume_scales in self.volume_layers.keys():
            scale_diff = np.array(scale) // np.array(volume_scales)
            if any(scale_diff < 1):
                continue
            scale_diff = scale_diff.max()
            if scale_diff < min_diff:
                min_diff = scale_diff
                closest_scale = volume_scales

        assert closest_scale is not None
        relative_scale = np.array(scale) // np.array(closest_scale)

        return self.volume_layers[closest_scale].get_encoded_subvolume(
            data_format, start, end, scale_key=",".join(map(str, relative_scale))
        )

    def get_object_mesh(self, object_id):
        return self.volume_layers[(1,) * self.dims].get_object_mesh(object_id)

    def invalidate(self):
        return self.volume_layers[(1,) * self.dims].invalidate()
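The scale selection in `get_encoded_subvolume` picks the coarsest stored level that is not coarser than the requested scale, then forwards the leftover downsampling factor to that layer. A worked run with made-up scale keys:

```python
import numpy as np

stored_scales = [(1, 1, 1), (2, 2, 2), (4, 4, 4)]  # keys of volume_layers
requested = (8, 8, 8)                              # parsed from scale_key "8,8,8"

closest_scale, min_diff = None, np.inf
for volume_scales in stored_scales:
    scale_diff = np.array(requested) // np.array(volume_scales)
    if any(scale_diff < 1):  # stored level is coarser than the request; skip it
        continue
    if scale_diff.max() < min_diff:
        min_diff = scale_diff.max()
        closest_scale = volume_scales

relative = np.array(requested) // np.array(closest_scale)
assert closest_scale == (4, 4, 4) and list(relative) == [2, 2, 2]
```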