Skip to content

Commit e0e72dc

Browse files
Merge pull request #212 from bioimage-io/fix-tif
Change how the tiff sample data is written
2 parents 03abb7f + 23df15c commit e0e72dc

File tree

1 file changed

+56
-32
lines changed

1 file changed

+56
-32
lines changed

bioimageio/core/build_spec/build_model.py

+56-32
Original file line numberDiff line numberDiff line change
@@ -338,8 +338,10 @@ def _get_deepimagej_macro(name, kwargs, export_folder):
338338
return {"spec": "ij.IJ::runMacroFile", "kwargs": macro}
339339

340340

341-
def _get_deepimagej_config(export_folder, sample_inputs, sample_outputs, pixel_sizes, preprocessing, postprocessing):
342-
assert len(sample_inputs) == len(sample_outputs) == 1, "deepimagej config only valid for single input/output"
341+
def _get_deepimagej_config(
342+
export_folder, test_inputs, test_outputs, input_axes, output_axes, pixel_sizes, preprocessing, postprocessing
343+
):
344+
assert len(test_inputs) == len(test_outputs) == 1, "deepimagej config only valid for single input/output"
343345

344346
if any(preproc is not None for preproc in preprocessing):
345347
assert len(preprocessing) == 1
@@ -363,13 +365,21 @@ def _get_deepimagej_config(export_folder, sample_inputs, sample_outputs, pixel_s
363365
else:
364366
postprocess_ij = [{"spec": None}]
365367

366-
def get_size(path):
367-
assert tifffile is not None, "need tifffile for writing deepimagej config"
368-
with tifffile.TiffFile(export_folder / path) as f:
369-
shape = f.asarray().shape
370-
# add singleton z axis if we have 2d data
368+
def get_size(fname, axes):
369+
shape = np.load(export_folder / fname).shape
370+
assert len(shape) == len(axes)
371+
shape = [sh for sh, ax in zip(shape, axes) if ax != "b"]
372+
axes = [ax for ax in axes if ax != "b"]
373+
# the shape for deepij is always given as xyzc
374+
if len(shape) == 3:
375+
axes_ij = "xyc"
376+
else:
377+
axes_ij = "xyzc"
378+
assert set(axes) == set(axes_ij)
379+
axis_permutation = [axes_ij.index(ax) for ax in axes]
380+
shape = [shape[permut] for permut in axis_permutation]
371381
if len(shape) == 3:
372-
shape = shape[:2] + (1,) + shape[-1:]
382+
shape = shape[:2] + [1] + shape[-1:]
373383
assert len(shape) == 4
374384
return " x ".join(map(str, shape))
375385

@@ -378,10 +388,13 @@ def get_size(path):
378388

379389
test_info = {
380390
"inputs": [
381-
{"name": in_path, "size": get_size(in_path), "pixel_size": pix_size}
382-
for in_path, pix_size in zip(sample_inputs, pixel_sizes_)
391+
{"name": in_path, "size": get_size(in_path, axes), "pixel_size": pix_size}
392+
for in_path, axes, pix_size in zip(test_inputs, input_axes, pixel_sizes_)
393+
],
394+
"outputs": [
395+
{"name": out_path, "type": "image", "size": get_size(out_path, axes)}
396+
for out_path, axes in zip(test_outputs, output_axes)
383397
],
384-
"outputs": [{"name": out_path, "type": "image", "size": get_size(out_path)} for out_path in sample_outputs],
385398
"memory_peak": None,
386399
"runtime": None,
387400
}
@@ -397,36 +410,49 @@ def get_size(path):
397410
return {"deepimagej": config}, [Path(a) for a in attachments]
398411

399412

400-
def _write_sample_data(input_paths, output_paths, input_axes, output_axes, export_folder: Path):
401-
def write_im(path, im, axes):
413+
def _write_sample_data(input_paths, output_paths, input_axes, output_axes, pixel_sizes, export_folder: Path):
414+
def write_im(path, im, axes, pixel_size=None):
402415
assert tifffile is not None, "need tifffile for writing deepimagej config"
403-
assert len(axes) == im.ndim
404-
assert im.ndim in (3, 4)
416+
assert len(axes) == im.ndim, f"{len(axes), {im.ndim}}"
417+
assert im.ndim in (4, 5), f"{im.ndim}"
405418

406-
# deepimagej expects xyzc axis order
407-
if im.ndim == 3:
408-
assert set(axes) == {"x", "y", "c"}
409-
axes_ij = "xyc"
419+
# convert the image to expects (Z)CYX axis order
420+
if im.ndim == 4:
421+
assert set(axes) == {"b", "x", "y", "c"}, f"{axes}"
422+
axes_ij = "cyxb"
410423
else:
411-
assert set(axes) == {"x", "y", "z", "c"}
412-
axes_ij = "xyzc"
424+
assert set(axes) == {"b", "x", "y", "z", "c"}, f"{axes}"
425+
axes_ij = "zcyxb"
413426

414-
axis_permutation = tuple(axes_ij.index(ax) for ax in axes)
427+
axis_permutation = tuple(axes.index(ax) for ax in axes_ij)
415428
im = im.transpose(axis_permutation)
416-
417-
with tifffile.TiffWriter(path) as f:
418-
f.write(im)
429+
# expand to TZCYXS
430+
if len(axes_ij) == 4: # add singleton t and z axis
431+
im = im[None, None]
432+
else: # add singeton z axis
433+
im = im[None]
434+
435+
if pixel_size is None:
436+
resolution = None
437+
else:
438+
spatial_axes = list(set(axes_ij) - set("bc"))
439+
resolution = tuple(1.0 / pixel_size[ax] for ax in axes_ij if ax in spatial_axes)
440+
# does not work for double
441+
if np.dtype(im.dtype) == np.dtype("float64"):
442+
im = im.astype("float32")
443+
tifffile.imsave(path, im, imagej=True, resolution=resolution)
419444

420445
sample_in_paths = []
421446
for i, (in_path, axes) in enumerate(zip(input_paths, input_axes)):
422-
inp = np.load(export_folder / in_path)[0]
447+
inp = np.load(export_folder / in_path)
423448
sample_in_path = export_folder / f"sample_input_{i}.tif"
424-
write_im(sample_in_path, inp, axes)
449+
pixel_size = None if pixel_sizes is None else pixel_sizes[i]
450+
write_im(sample_in_path, inp, axes, pixel_size)
425451
sample_in_paths.append(sample_in_path)
426452

427453
sample_out_paths = []
428454
for i, (out_path, axes) in enumerate(zip(output_paths, output_axes)):
429-
outp = np.load(export_folder / out_path)[0]
455+
outp = np.load(export_folder / out_path)
430456
sample_out_path = export_folder / f"sample_output_{i}.tif"
431457
write_im(sample_out_path, outp, axes)
432458
sample_out_paths.append(sample_out_path)
@@ -797,17 +823,15 @@ def build_model(
797823
# add the deepimagej config if specified
798824
if add_deepimagej_config:
799825
if sample_inputs is None:
800-
input_axes_ij = [inp.axes[1:] for inp in inputs]
801-
output_axes_ij = [out.axes[1:] for out in outputs]
802826
sample_inputs, sample_outputs = _write_sample_data(
803-
test_inputs, test_outputs, input_axes_ij, output_axes_ij, root
827+
test_inputs, test_outputs, input_axes, output_axes, pixel_sizes, root
804828
)
805829
# deepimagej expect tifs as sample data
806830
assert all(os.path.splitext(path)[1] in (".tif", ".tiff") for path in sample_inputs)
807831
assert all(os.path.splitext(path)[1] in (".tif", ".tiff") for path in sample_outputs)
808832

809833
ij_config, ij_attachments = _get_deepimagej_config(
810-
root, sample_inputs, sample_outputs, pixel_sizes, preprocessing, postprocessing
834+
root, test_inputs, test_outputs, input_axes, output_axes, pixel_sizes, preprocessing, postprocessing
811835
)
812836

813837
if config is None:

0 commit comments

Comments
 (0)