@@ -316,6 +316,51 @@ void run(List<Tensor<T>> inputTensors, List<Tensor<R>> outputTensors)
316
316
tt .close ();
317
317
}
318
318
}
319
+
320
+ @ Override
321
+ public <T extends RealType <T > & NativeType <T >, R extends RealType <R > & NativeType <R >> List <RandomAccessibleInterval <R >> inference (
322
+ List <RandomAccessibleInterval <T >> inputs ) throws RunModelException {
323
+ if (interprocessing ) {
324
+ return runInterprocessing (inputs );
325
+ }
326
+ Session session = model .session ();
327
+ Session .Runner runner = session .runner ();
328
+ List <org .tensorflow .Tensor <?>> inTensors =
329
+ new ArrayList <org .tensorflow .Tensor <?>>();
330
+ int c = 0 ;
331
+ for (RandomAccessibleInterval <T > tt : inputs ) {
332
+ org .tensorflow .Tensor <?> inT = TensorBuilder .build (tt );
333
+ inTensors .add (inT );
334
+ String inputName = getModelInputName ("input" + c , c ++);
335
+ runner .feed (inputName , inT );
336
+ }
337
+ c = 0 ;
338
+ List <String > outputInfo = sig .getOutputsMap ().values ().stream ()
339
+ .map (nn -> {
340
+ String name = nn .getName ();
341
+ if (name .endsWith (":0" ))
342
+ return name .substring (0 , name .length () - 2 );
343
+ return name ;
344
+ }).collect (Collectors .toList ());
345
+
346
+ for (String name : outputInfo )
347
+ runner = runner .fetch (name );
348
+ // Run runner
349
+ List <org .tensorflow .Tensor <?>> resultPatchTensors = runner .run ();
350
+ for (org .tensorflow .Tensor <?> tt : inTensors )
351
+ tt .close ();
352
+ List <RandomAccessibleInterval <R >> rais = new ArrayList <RandomAccessibleInterval <R >>();
353
+ for (int i = 0 ; i < resultPatchTensors .size (); i ++) {
354
+ try {
355
+ rais .add (ImgLib2Builder .build (resultPatchTensors .get (i )));
356
+ } catch (IllegalArgumentException ex ) {
357
+ for (org .tensorflow .Tensor <?> tt : resultPatchTensors )
358
+ tt .close ();
359
+ throw new RunModelException (Types .stackTrace (ex ));
360
+ }
361
+ }
362
+ return rais ;
363
+ }
319
364
320
365
protected void runFromShmas (List <String > inputs , List <String > outputs ) throws IOException {
321
366
Session session = model .session ();
@@ -354,17 +399,46 @@ protected void runFromShmas(List<String> inputs, List<String> outputs) throws IO
354
399
}
355
400
}
356
401
357
- /**
358
- * MEthod only used in MacOS Intel systems that makes all the arangements
359
- * to create another process, communicate the model info and tensors to the other
360
- * process and then retrieve the results of the other process
361
- * @param inputTensors
362
- * tensors that are going to be run on the model
363
- * @param outputTensors
364
- * expected results of the model
365
- * @throws RunModelException if there is any issue running the model
366
- */
367
- public <T extends RealType <T > & NativeType <T >, R extends RealType <R > & NativeType <R >>
402
+ protected List <String > inferenceFromShmas (List <String > inputs ) throws IOException , RunModelException {
403
+ Session session = model .session ();
404
+ Session .Runner runner = session .runner ();
405
+ List <org .tensorflow .Tensor <?>> inTensors =
406
+ new ArrayList <org .tensorflow .Tensor <?>>();
407
+ int c = 0 ;
408
+ for (String ee : inputs ) {
409
+ Map <String , Object > decoded = Types .decode (ee );
410
+ SharedMemoryArray shma = SharedMemoryArray .read ((String ) decoded .get (MEM_NAME_KEY ));
411
+ org .tensorflow .Tensor <?> inT = io .bioimage .modelrunner .tensorflow .v1 .shm .TensorBuilder .build (shma );
412
+ if (PlatformDetection .isWindows ()) shma .close ();
413
+ inTensors .add (inT );
414
+ String inputName = getModelInputName ((String ) decoded .get (NAME_KEY ), c ++);
415
+ runner .feed (inputName , inT );
416
+ }
417
+ List <String > outputInfo = sig .getOutputsMap ().values ().stream ()
418
+ .map (nn -> {
419
+ String name = nn .getName ();
420
+ if (name .endsWith (":0" ))
421
+ return name .substring (0 , name .length () - 2 );
422
+ return name ;
423
+ }).collect (Collectors .toList ());
424
+
425
+ for (String name : outputInfo )
426
+ runner = runner .fetch (name );
427
+ // Run runner
428
+ List <org .tensorflow .Tensor <?>> resultPatchTensors = runner .run ();
429
+ for (org .tensorflow .Tensor <?> tt : inTensors )
430
+ tt .close ();
431
+
432
+ shmaNamesList = new ArrayList <String >();
433
+ for (int i = 0 ; i < resultPatchTensors .size (); i ++) {
434
+ String name = SharedMemoryArray .createShmName ();
435
+ ShmBuilder .build (resultPatchTensors .get (i ), name , false );
436
+ shmaNamesList .add (name );
437
+ }
438
+ return shmaNamesList ;
439
+ }
440
+
441
+ private <T extends RealType <T > & NativeType <T >, R extends RealType <R > & NativeType <R >>
368
442
void runInterprocessing (List <Tensor <T >> inputTensors , List <Tensor <R >> outputTensors ) throws RunModelException {
369
443
shmaInputList = new ArrayList <SharedMemoryArray >();
370
444
shmaOutputList = new ArrayList <SharedMemoryArray >();
@@ -375,7 +449,7 @@ void runInterprocessing(List<Tensor<T>> inputTensors, List<Tensor<R>> outputTens
375
449
args .put ("outputs" , encOuts );
376
450
377
451
try {
378
- Task task = runner .task ("inference " , args );
452
+ Task task = runner .task ("run " , args );
379
453
task .waitFor ();
380
454
if (task .status == TaskStatus .CANCELED )
381
455
throw new RuntimeException ();
@@ -406,7 +480,89 @@ else if (task.status == TaskStatus.CRASHED) {
406
480
closeShmas ();
407
481
}
408
482
409
- private void closeShmas () {
483
+ private <T extends RealType <T > & NativeType <T >, R extends RealType <R > & NativeType <R >>
484
+ List <RandomAccessibleInterval <R >> runInterprocessing (List <RandomAccessibleInterval <T >> inputs ) throws RunModelException {
485
+ shmaInputList = new ArrayList <SharedMemoryArray >();
486
+ List <String > encIns = new ArrayList <String >();
487
+ Gson gson = new Gson ();
488
+ for (RandomAccessibleInterval <T > tt : inputs ) {
489
+ SharedMemoryArray shma = SharedMemoryArray .createSHMAFromRAI (tt , false , true );
490
+ shmaInputList .add (shma );
491
+ HashMap <String , Object > map = new HashMap <String , Object >();
492
+ map .put (SHAPE_KEY , tt .dimensionsAsLongArray ());
493
+ map .put (DTYPE_KEY , CommonUtils .getDataTypeFromRAI (tt ));
494
+ map .put (IS_INPUT_KEY , true );
495
+ map .put (MEM_NAME_KEY , shma .getName ());
496
+ encIns .add (gson .toJson (map ));
497
+ }
498
+ LinkedHashMap <String , Object > args = new LinkedHashMap <String , Object >();
499
+ args .put ("inputs" , encIns );
500
+
501
+ try {
502
+ Task task = runner .task ("inference" , args );
503
+ task .waitFor ();
504
+ if (task .status == TaskStatus .CANCELED )
505
+ throw new RuntimeException ();
506
+ else if (task .status == TaskStatus .FAILED )
507
+ throw new RuntimeException (task .error );
508
+ else if (task .status == TaskStatus .CRASHED ) {
509
+ this .runner .close ();
510
+ runner = null ;
511
+ throw new RuntimeException (task .error );
512
+ } else if (task .outputs == null )
513
+ throw new RuntimeException ("No outputs generated" );
514
+ List <String > outputs = (List <String >) task .outputs .get ("encoded" );
515
+ List <RandomAccessibleInterval <R >> rais = new ArrayList <RandomAccessibleInterval <R >>();
516
+ for (String out : outputs ) {
517
+ String name = (String ) Types .decode (out ).get (MEM_NAME_KEY );
518
+ SharedMemoryArray shm = SharedMemoryArray .read (name );
519
+ RandomAccessibleInterval <R > rai = shm .getSharedRAI ();
520
+ rais .add (Tensor .createCopyOfRaiInWantedDataType (Cast .unchecked (rai ), Util .getTypeFromInterval (Cast .unchecked (rai ))));
521
+ shm .close ();
522
+ }
523
+ closeShmas ();
524
+ return rais ;
525
+ } catch (Exception e ) {
526
+ closeShmas ();
527
+ if (e instanceof RunModelException )
528
+ throw (RunModelException ) e ;
529
+ throw new RunModelException (Types .stackTrace (e ));
530
+ }
531
+ }
532
+
533
+ private void closeInterprocess () throws RunModelException {
534
+ try {
535
+ Task task = runner .task ("closeTensors" );
536
+ task .waitFor ();
537
+ if (task .status == TaskStatus .CANCELED )
538
+ throw new RuntimeException ();
539
+ else if (task .status == TaskStatus .FAILED )
540
+ throw new RuntimeException (task .error );
541
+ else if (task .status == TaskStatus .CRASHED ) {
542
+ this .runner .close ();
543
+ runner = null ;
544
+ throw new RuntimeException (task .error );
545
+ }
546
+ } catch (Exception e ) {
547
+ if (e instanceof RunModelException )
548
+ throw (RunModelException ) e ;
549
+ throw new RunModelException (Types .stackTrace (e ));
550
+ }
551
+ }
552
+
553
+ protected void closeFromInterp () {
554
+ if (!PlatformDetection .isWindows ())
555
+ return ;
556
+ this .shmaNamesList .stream ().forEach (nn -> {
557
+ try {
558
+ SharedMemoryArray .read (nn ).close ();
559
+ } catch (IOException e ) {
560
+ e .printStackTrace ();
561
+ }
562
+ });
563
+ }
564
+
565
+ private void closeShmas () throws RunModelException {
410
566
shmaInputList .forEach (shm -> {
411
567
try { shm .close (); } catch (IOException e1 ) { e1 .printStackTrace ();}
412
568
});
@@ -415,6 +571,8 @@ private void closeShmas() {
415
571
try { shm .close (); } catch (IOException e1 ) { e1 .printStackTrace ();}
416
572
});
417
573
shmaOutputList = null ;
574
+ if (interprocessing )
575
+ closeInterprocess ();
418
576
}
419
577
420
578
@@ -483,7 +641,7 @@ public static <T extends RealType<T> & NativeType<T>> void fillOutputTensors(
483
641
try {
484
642
outputTensors .get (i ).setData (ImgLib2Builder .build (outputNDArrays .get (i )));
485
643
} catch (IllegalArgumentException ex ) {
486
- throw new RunModelException (ex . toString ( ));
644
+ throw new RunModelException (Types . stackTrace ( ex ));
487
645
}
488
646
}
489
647
}
0 commit comments