@@ -338,8 +338,10 @@ def _get_deepimagej_macro(name, kwargs, export_folder):
338
338
return {"spec" : "ij.IJ::runMacroFile" , "kwargs" : macro }
339
339
340
340
341
- def _get_deepimagej_config (export_folder , sample_inputs , sample_outputs , pixel_sizes , preprocessing , postprocessing ):
342
- assert len (sample_inputs ) == len (sample_outputs ) == 1 , "deepimagej config only valid for single input/output"
341
+ def _get_deepimagej_config (
342
+ export_folder , test_inputs , test_outputs , input_axes , output_axes , pixel_sizes , preprocessing , postprocessing
343
+ ):
344
+ assert len (test_inputs ) == len (test_outputs ) == 1 , "deepimagej config only valid for single input/output"
343
345
344
346
if any (preproc is not None for preproc in preprocessing ):
345
347
assert len (preprocessing ) == 1
@@ -363,13 +365,21 @@ def _get_deepimagej_config(export_folder, sample_inputs, sample_outputs, pixel_s
363
365
else :
364
366
postprocess_ij = [{"spec" : None }]
365
367
366
- def get_size (path ):
367
- assert tifffile is not None , "need tifffile for writing deepimagej config"
368
- with tifffile .TiffFile (export_folder / path ) as f :
369
- shape = f .asarray ().shape
370
- # add singleton z axis if we have 2d data
368
+ def get_size (fname , axes ):
369
+ shape = np .load (export_folder / fname ).shape
370
+ assert len (shape ) == len (axes )
371
+ shape = [sh for sh , ax in zip (shape , axes ) if ax != "b" ]
372
+ axes = [ax for ax in axes if ax != "b" ]
373
+ # the shape for deepij is always given as xyzc
374
+ if len (shape ) == 3 :
375
+ axes_ij = "xyc"
376
+ else :
377
+ axes_ij = "xyzc"
378
+ assert set (axes ) == set (axes_ij )
379
+ axis_permutation = [axes_ij .index (ax ) for ax in axes ]
380
+ shape = [shape [permut ] for permut in axis_permutation ]
371
381
if len (shape ) == 3 :
372
- shape = shape [:2 ] + ( 1 ,) + shape [- 1 :]
382
+ shape = shape [:2 ] + [ 1 ] + shape [- 1 :]
373
383
assert len (shape ) == 4
374
384
return " x " .join (map (str , shape ))
375
385
@@ -378,10 +388,13 @@ def get_size(path):
378
388
379
389
test_info = {
380
390
"inputs" : [
381
- {"name" : in_path , "size" : get_size (in_path ), "pixel_size" : pix_size }
382
- for in_path , pix_size in zip (sample_inputs , pixel_sizes_ )
391
+ {"name" : in_path , "size" : get_size (in_path , axes ), "pixel_size" : pix_size }
392
+ for in_path , axes , pix_size in zip (test_inputs , input_axes , pixel_sizes_ )
393
+ ],
394
+ "outputs" : [
395
+ {"name" : out_path , "type" : "image" , "size" : get_size (out_path , axes )}
396
+ for out_path , axes in zip (test_outputs , output_axes )
383
397
],
384
- "outputs" : [{"name" : out_path , "type" : "image" , "size" : get_size (out_path )} for out_path in sample_outputs ],
385
398
"memory_peak" : None ,
386
399
"runtime" : None ,
387
400
}
@@ -397,36 +410,49 @@ def get_size(path):
397
410
return {"deepimagej" : config }, [Path (a ) for a in attachments ]
398
411
399
412
400
- def _write_sample_data (input_paths , output_paths , input_axes , output_axes , export_folder : Path ):
401
- def write_im (path , im , axes ):
413
+ def _write_sample_data (input_paths , output_paths , input_axes , output_axes , pixel_sizes , export_folder : Path ):
414
+ def write_im (path , im , axes , pixel_size = None ):
402
415
assert tifffile is not None , "need tifffile for writing deepimagej config"
403
- assert len (axes ) == im .ndim
404
- assert im .ndim in (3 , 4 )
416
+ assert len (axes ) == im .ndim , f" { len ( axes ), { im . ndim } } "
417
+ assert im .ndim in (4 , 5 ), f" { im . ndim } "
405
418
406
- # deepimagej expects xyzc axis order
407
- if im .ndim == 3 :
408
- assert set (axes ) == {"x" , "y" , "c" }
409
- axes_ij = "xyc "
419
+ # convert the image to expects (Z)CYX axis order
420
+ if im .ndim == 4 :
421
+ assert set (axes ) == {"b" , " x" , "y" , "c" }, f" { axes } "
422
+ axes_ij = "cyxb "
410
423
else :
411
- assert set (axes ) == {"x" , "y" , "z" , "c" }
412
- axes_ij = "xyzc "
424
+ assert set (axes ) == {"b" , " x" , "y" , "z" , "c" }, f" { axes } "
425
+ axes_ij = "zcyxb "
413
426
414
- axis_permutation = tuple (axes_ij .index (ax ) for ax in axes )
427
+ axis_permutation = tuple (axes .index (ax ) for ax in axes_ij )
415
428
im = im .transpose (axis_permutation )
416
-
417
- with tifffile .TiffWriter (path ) as f :
418
- f .write (im )
429
+ # expand to TZCYXS
430
+ if len (axes_ij ) == 4 : # add singleton t and z axis
431
+ im = im [None , None ]
432
+ else : # add singeton z axis
433
+ im = im [None ]
434
+
435
+ if pixel_size is None :
436
+ resolution = None
437
+ else :
438
+ spatial_axes = list (set (axes_ij ) - set ("bc" ))
439
+ resolution = tuple (1.0 / pixel_size [ax ] for ax in axes_ij if ax in spatial_axes )
440
+ # does not work for double
441
+ if np .dtype (im .dtype ) == np .dtype ("float64" ):
442
+ im = im .astype ("float32" )
443
+ tifffile .imsave (path , im , imagej = True , resolution = resolution )
419
444
420
445
sample_in_paths = []
421
446
for i , (in_path , axes ) in enumerate (zip (input_paths , input_axes )):
422
- inp = np .load (export_folder / in_path )[ 0 ]
447
+ inp = np .load (export_folder / in_path )
423
448
sample_in_path = export_folder / f"sample_input_{ i } .tif"
424
- write_im (sample_in_path , inp , axes )
449
+ pixel_size = None if pixel_sizes is None else pixel_sizes [i ]
450
+ write_im (sample_in_path , inp , axes , pixel_size )
425
451
sample_in_paths .append (sample_in_path )
426
452
427
453
sample_out_paths = []
428
454
for i , (out_path , axes ) in enumerate (zip (output_paths , output_axes )):
429
- outp = np .load (export_folder / out_path )[ 0 ]
455
+ outp = np .load (export_folder / out_path )
430
456
sample_out_path = export_folder / f"sample_output_{ i } .tif"
431
457
write_im (sample_out_path , outp , axes )
432
458
sample_out_paths .append (sample_out_path )
@@ -797,17 +823,15 @@ def build_model(
797
823
# add the deepimagej config if specified
798
824
if add_deepimagej_config :
799
825
if sample_inputs is None :
800
- input_axes_ij = [inp .axes [1 :] for inp in inputs ]
801
- output_axes_ij = [out .axes [1 :] for out in outputs ]
802
826
sample_inputs , sample_outputs = _write_sample_data (
803
- test_inputs , test_outputs , input_axes_ij , output_axes_ij , root
827
+ test_inputs , test_outputs , input_axes , output_axes , pixel_sizes , root
804
828
)
805
829
# deepimagej expect tifs as sample data
806
830
assert all (os .path .splitext (path )[1 ] in (".tif" , ".tiff" ) for path in sample_inputs )
807
831
assert all (os .path .splitext (path )[1 ] in (".tif" , ".tiff" ) for path in sample_outputs )
808
832
809
833
ij_config , ij_attachments = _get_deepimagej_config (
810
- root , sample_inputs , sample_outputs , pixel_sizes , preprocessing , postprocessing
834
+ root , test_inputs , test_outputs , input_axes , output_axes , pixel_sizes , preprocessing , postprocessing
811
835
)
812
836
813
837
if config is None :
0 commit comments