@@ -294,8 +294,10 @@ def _get_deepimagej_macro(name, kwargs, export_folder):
294
294
return {"spec" : "ij.IJ::runMacroFile" , "kwargs" : macro }
295
295
296
296
297
- def _get_deepimagej_config (export_folder , sample_inputs , sample_outputs , pixel_sizes , preprocessing , postprocessing ):
298
- assert len (sample_inputs ) == len (sample_outputs ) == 1 , "deepimagej config only valid for single input/output"
297
+ def _get_deepimagej_config (
298
+ export_folder , test_inputs , test_outputs , input_axes , output_axes , pixel_sizes , preprocessing , postprocessing
299
+ ):
300
+ assert len (test_inputs ) == len (test_outputs ) == 1 , "deepimagej config only valid for single input/output"
299
301
300
302
if any (preproc is not None for preproc in preprocessing ):
301
303
assert len (preprocessing ) == 1
@@ -319,15 +321,21 @@ def _get_deepimagej_config(export_folder, sample_inputs, sample_outputs, pixel_s
319
321
else :
320
322
postprocess_ij = [{"spec" : None }]
321
323
322
- def get_size (path ):
323
- assert tifffile is not None , "need tifffile for writing deepimagej config"
324
- with tifffile .TiffFile (export_folder / path ) as f :
325
- shape = f .asarray ().shape
326
- # add singleton z/c axis
327
- if len (shape ) == 2 :
328
- shape = (1 ,) + shape + (1 ,)
329
- elif len (shape ) == 3 :
330
- shape = shape [:2 ] + (1 ,) + shape [- 1 :]
324
+ def get_size (fname , axes ):
325
+ shape = np .load (export_folder / fname ).shape
326
+ assert len (shape ) == len (axes )
327
+ shape = [sh for sh , ax in zip (shape , axes ) if ax != "b" ]
328
+ axes = [ax for ax in axes if ax != "b" ]
329
+ # the shape for deepij is always given as xyzc
330
+ if len (shape ) == 3 :
331
+ axes_ij = "xyc"
332
+ else :
333
+ axes_ij = "xyzc"
334
+ assert set (axes ) == set (axes_ij )
335
+ axis_permutation = [axes_ij .index (ax ) for ax in axes ]
336
+ shape = [shape [permut ] for permut in axis_permutation ]
337
+ if len (shape ) == 3 :
338
+ shape = shape [:2 ] + [1 ] + shape [- 1 :]
331
339
assert len (shape ) == 4
332
340
return " x " .join (map (str , shape ))
333
341
@@ -336,10 +344,10 @@ def get_size(path):
336
344
337
345
test_info = {
338
346
"inputs" : [
339
- {"name" : in_path , "size" : get_size (in_path ), "pixel_size" : pix_size }
340
- for in_path , pix_size in zip (sample_inputs , pixel_sizes_ )
347
+ {"name" : in_path , "size" : get_size (in_path , axes ), "pixel_size" : pix_size }
348
+ for in_path , axes , pix_size in zip (test_inputs , input_axes , pixel_sizes_ )
341
349
],
342
- "outputs" : [{"name" : out_path , "type" : "image" , "size" : get_size (out_path )} for out_path in sample_outputs ],
350
+ "outputs" : [{"name" : out_path , "type" : "image" , "size" : get_size (out_path , axes )} for out_path , axes in zip ( test_outputs , output_axes ) ],
343
351
"memory_peak" : None ,
344
352
"runtime" : None ,
345
353
}
@@ -786,7 +794,7 @@ def build_model(
786
794
assert all (os .path .splitext (path )[1 ] in (".tif" , ".tiff" ) for path in sample_outputs )
787
795
788
796
ij_config , ij_attachments = _get_deepimagej_config (
789
- root , sample_inputs , sample_outputs , pixel_sizes , preprocessing , postprocessing
797
+ root , test_inputs , test_outputs , input_axes , output_axes , pixel_sizes , preprocessing , postprocessing
790
798
)
791
799
792
800
if config is None :
0 commit comments