Skip to content

Commit 84742d9

Browse files
Fix tif writing
1 parent 667a991 commit 84742d9

File tree

1 file changed

+15
-15
lines changed

1 file changed

+15
-15
lines changed

bioimageio/core/build_spec/build_model.py

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -413,46 +413,46 @@ def get_size(fname, axes):
413413
def _write_sample_data(input_paths, output_paths, input_axes, output_axes, pixel_sizes, export_folder: Path):
414414
def write_im(path, im, axes, pixel_size=None):
415415
assert tifffile is not None, "need tifffile for writing deepimagej config"
416-
assert len(axes) == im.ndim
417-
assert im.ndim in (3, 4)
416+
assert len(axes) == im.ndim, f"{len(axes), {im.ndim}}"
417+
assert im.ndim in (4, 5)
418418

419419
# convert the image to expects (Z)CYX axis order
420420
if im.ndim == 3:
421-
assert set(axes) == {"x", "y", "c"}
422-
axes_ij = "cyx"
421+
assert set(axes) == {"b", "x", "y", "c"}
422+
axes_ij = "cyxb"
423423
else:
424-
assert set(axes) == {"x", "y", "z", "c"}
425-
axes_ij = "zcyx"
424+
assert set(axes) == {"b", "x", "y", "z", "c"}
425+
axes_ij = "zcyxb"
426426

427-
axis_permutation = tuple(axes_ij.index(ax) for ax in axes)
427+
axis_permutation = tuple(axes.index(ax) for ax in axes_ij)
428428
im = im.transpose(axis_permutation)
429429
# expand to TZCYXS
430-
if len(axes_ij) == 2: # add singleton z axis
431-
im = im[None, None, ..., None]
432-
else:
433-
im = im[None, ..., None]
430+
if len(axes_ij) == 2: # add singleton t and z axis
431+
im = im[None, None]
432+
else: # add singeton z axis
433+
im = im[None]
434434

435435
if pixel_size is None:
436436
resolution = None
437437
else:
438-
spatial_axes = list(set(axes_ij) - set(["c"]))
439-
resolution = tuple(1.0 / pixel_size[ax] for ax in spatial_axes)
438+
spatial_axes = list(set(axes_ij) - set("bc"))
439+
resolution = tuple(1.0 / pixel_size[ax] for ax in axes_ij if ax in spatial_axes)
440440
# does not work for double
441441
if np.dtype(im.dtype) == np.dtype("float64"):
442442
im = im.astype("float32")
443443
tifffile.imsave(path, im, imagej=True, resolution=resolution)
444444

445445
sample_in_paths = []
446446
for i, (in_path, axes) in enumerate(zip(input_paths, input_axes)):
447-
inp = np.load(export_folder / in_path)[0]
447+
inp = np.load(export_folder / in_path)
448448
sample_in_path = export_folder / f"sample_input_{i}.tif"
449449
pixel_size = None if pixel_sizes is None else pixel_sizes[i]
450450
write_im(sample_in_path, inp, axes, pixel_size)
451451
sample_in_paths.append(sample_in_path)
452452

453453
sample_out_paths = []
454454
for i, (out_path, axes) in enumerate(zip(output_paths, output_axes)):
455-
outp = np.load(export_folder / out_path)[0]
455+
outp = np.load(export_folder / out_path)
456456
sample_out_path = export_folder / f"sample_output_{i}.tif"
457457
write_im(sample_out_path, outp, axes)
458458
sample_out_paths.append(sample_out_path)

0 commit comments

Comments
 (0)