Skip to content

Commit fdaf4d6

Browse files
Fix issues with writing the shape to the deepIJ config
1 parent 9241a39 commit fdaf4d6

File tree

1 file changed

+23
-15
lines changed

1 file changed

+23
-15
lines changed

bioimageio/core/build_spec/build_model.py

Lines changed: 23 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -294,8 +294,10 @@ def _get_deepimagej_macro(name, kwargs, export_folder):
294294
return {"spec": "ij.IJ::runMacroFile", "kwargs": macro}
295295

296296

297-
def _get_deepimagej_config(export_folder, sample_inputs, sample_outputs, pixel_sizes, preprocessing, postprocessing):
298-
assert len(sample_inputs) == len(sample_outputs) == 1, "deepimagej config only valid for single input/output"
297+
def _get_deepimagej_config(
298+
export_folder, test_inputs, test_outputs, input_axes, output_axes, pixel_sizes, preprocessing, postprocessing
299+
):
300+
assert len(test_inputs) == len(test_outputs) == 1, "deepimagej config only valid for single input/output"
299301

300302
if any(preproc is not None for preproc in preprocessing):
301303
assert len(preprocessing) == 1
@@ -319,15 +321,21 @@ def _get_deepimagej_config(export_folder, sample_inputs, sample_outputs, pixel_s
319321
else:
320322
postprocess_ij = [{"spec": None}]
321323

322-
def get_size(path):
323-
assert tifffile is not None, "need tifffile for writing deepimagej config"
324-
with tifffile.TiffFile(export_folder / path) as f:
325-
shape = f.asarray().shape
326-
# add singleton z/c axis
327-
if len(shape) == 2:
328-
shape = (1,) + shape + (1,)
329-
elif len(shape) == 3:
330-
shape = shape[:2] + (1,) + shape[-1:]
324+
def get_size(fname, axes):
325+
shape = np.load(export_folder / fname).shape
326+
assert len(shape) == len(axes)
327+
shape = [sh for sh, ax in zip(shape, axes) if ax != "b"]
328+
axes = [ax for ax in axes if ax != "b"]
329+
# the shape for deepij is always given as xyzc
330+
if len(shape) == 3:
331+
axes_ij = "xyc"
332+
else:
333+
axes_ij = "xyzc"
334+
assert set(axes) == set(axes_ij)
335+
axis_permutation = [axes_ij.index(ax) for ax in axes]
336+
shape = [shape[permut] for permut in axis_permutation]
337+
if len(shape) == 3:
338+
shape = shape[:2] + [1] + shape[-1:]
331339
assert len(shape) == 4
332340
return " x ".join(map(str, shape))
333341

@@ -336,10 +344,10 @@ def get_size(path):
336344

337345
test_info = {
338346
"inputs": [
339-
{"name": in_path, "size": get_size(in_path), "pixel_size": pix_size}
340-
for in_path, pix_size in zip(sample_inputs, pixel_sizes_)
347+
{"name": in_path, "size": get_size(in_path, axes), "pixel_size": pix_size}
348+
for in_path, axes, pix_size in zip(test_inputs, input_axes, pixel_sizes_)
341349
],
342-
"outputs": [{"name": out_path, "type": "image", "size": get_size(out_path)} for out_path in sample_outputs],
350+
"outputs": [{"name": out_path, "type": "image", "size": get_size(out_path, axes)} for out_path, axes in zip(test_outputs, output_axes)],
343351
"memory_peak": None,
344352
"runtime": None,
345353
}
@@ -786,7 +794,7 @@ def build_model(
786794
assert all(os.path.splitext(path)[1] in (".tif", ".tiff") for path in sample_outputs)
787795

788796
ij_config, ij_attachments = _get_deepimagej_config(
789-
root, sample_inputs, sample_outputs, pixel_sizes, preprocessing, postprocessing
797+
root, test_inputs, test_outputs, input_axes, output_axes, pixel_sizes, preprocessing, postprocessing
790798
)
791799

792800
if config is None:

0 commit comments

Comments
 (0)