@@ -307,6 +307,51 @@ void run(List<Tensor<T>> inputTensors, List<Tensor<R>> outputTensors)
307
307
tt .close ();
308
308
}
309
309
}
310
+
311
+ @ Override
312
+ public <T extends RealType <T > & NativeType <T >, R extends RealType <R > & NativeType <R >> List <RandomAccessibleInterval <R >> inference (
313
+ List <RandomAccessibleInterval <T >> inputs ) throws RunModelException {
314
+ if (interprocessing ) {
315
+ return runInterprocessing (inputs );
316
+ }
317
+ Session session = model .session ();
318
+ Session .Runner runner = session .runner ();
319
+ List <String > inputListNames = new ArrayList <String >();
320
+ List <org .tensorflow .Tensor <?>> inTensors = new ArrayList <org .tensorflow .Tensor <?>>();
321
+ int c = 0 ;
322
+ for (RandomAccessibleInterval <T > tt : inputs ) {
323
+ org .tensorflow .Tensor <?> inT = TensorBuilder .build (tt );
324
+ inTensors .add (inT );
325
+ String inputName = getModelInputName ("input" + c , c ++);
326
+ runner .feed (inputName , inT );
327
+ }
328
+ c = 0 ;
329
+ List <String > outputInfo = sig .getOutputsMap ().values ().stream ()
330
+ .map (nn -> {
331
+ String name = nn .getName ();
332
+ if (name .endsWith (":0" ))
333
+ return name .substring (0 , name .length () - 2 );
334
+ return name ;
335
+ }).collect (Collectors .toList ());
336
+
337
+ for (String name : outputInfo )
338
+ runner = runner .fetch (name );
339
+ // Run runner
340
+ List <org .tensorflow .Tensor <?>> resultPatchTensors = runner .run ();
341
+ for (org .tensorflow .Tensor <?> tt : inTensors )
342
+ tt .close ();
343
+ List <RandomAccessibleInterval <R >> rais = new ArrayList <RandomAccessibleInterval <R >>();
344
+ for (int i = 0 ; i < resultPatchTensors .size (); i ++) {
345
+ try {
346
+ rais .add (ImgLib2Builder .build (resultPatchTensors .get (i )));
347
+ } catch (IllegalArgumentException ex ) {
348
+ for (org .tensorflow .Tensor <?> tt : resultPatchTensors )
349
+ tt .close ();
350
+ throw new RunModelException (Types .stackTrace (ex ));
351
+ }
352
+ }
353
+ return rais ;
354
+ }
310
355
311
356
protected void runFromShmas (List <String > inputs , List <String > outputs ) throws IOException {
312
357
Session session = model .session ();
@@ -345,6 +390,45 @@ protected void runFromShmas(List<String> inputs, List<String> outputs) throws IO
345
390
}
346
391
}
347
392
393
+ protected List <String > inferenceFromShmas (List <String > inputs ) throws IOException , RunModelException {
394
+ Session session = model .session ();
395
+ Session .Runner runner = session .runner ();
396
+ List <org .tensorflow .Tensor <?>> inTensors =
397
+ new ArrayList <org .tensorflow .Tensor <?>>();
398
+ int c = 0 ;
399
+ for (String ee : inputs ) {
400
+ Map <String , Object > decoded = Types .decode (ee );
401
+ SharedMemoryArray shma = SharedMemoryArray .read ((String ) decoded .get (MEM_NAME_KEY ));
402
+ org .tensorflow .Tensor <?> inT = io .bioimage .modelrunner .tensorflow .v2 .api020 .shm .TensorBuilder .build (shma );
403
+ if (PlatformDetection .isWindows ()) shma .close ();
404
+ inTensors .add (inT );
405
+ String inputName = getModelInputName ((String ) decoded .get (NAME_KEY ), c ++);
406
+ runner .feed (inputName , inT );
407
+ }
408
+ List <String > outputInfo = sig .getOutputsMap ().values ().stream ()
409
+ .map (nn -> {
410
+ String name = nn .getName ();
411
+ if (name .endsWith (":0" ))
412
+ return name .substring (0 , name .length () - 2 );
413
+ return name ;
414
+ }).collect (Collectors .toList ());
415
+
416
+ for (String name : outputInfo )
417
+ runner = runner .fetch (name );
418
+ // Run runner
419
+ List <org .tensorflow .Tensor <?>> resultPatchTensors = runner .run ();
420
+ for (org .tensorflow .Tensor <?> tt : inTensors )
421
+ tt .close ();
422
+
423
+ shmaNamesList = new ArrayList <String >();
424
+ for (int i = 0 ; i < resultPatchTensors .size (); i ++) {
425
+ String name = SharedMemoryArray .createShmName ();
426
+ ShmBuilder .build (resultPatchTensors .get (i ), name , false );
427
+ shmaNamesList .add (name );
428
+ }
429
+ return shmaNamesList ;
430
+ }
431
+
348
432
/**
349
433
* MEthod only used in MacOS Intel and Windows systems that makes all the arrangements
350
434
* to create another process, communicate the model info and tensors to the other
@@ -397,7 +481,89 @@ else if (task.status == TaskStatus.CRASHED) {
397
481
closeShmas ();
398
482
}
399
483
400
- private void closeShmas () {
484
+ private <T extends RealType <T > & NativeType <T >, R extends RealType <R > & NativeType <R >>
485
+ List <RandomAccessibleInterval <R >> runInterprocessing (List <RandomAccessibleInterval <T >> inputs ) throws RunModelException {
486
+ shmaInputList = new ArrayList <SharedMemoryArray >();
487
+ List <String > encIns = new ArrayList <String >();
488
+ Gson gson = new Gson ();
489
+ for (RandomAccessibleInterval <T > tt : inputs ) {
490
+ SharedMemoryArray shma = SharedMemoryArray .createSHMAFromRAI (tt , false , true );
491
+ shmaInputList .add (shma );
492
+ HashMap <String , Object > map = new HashMap <String , Object >();
493
+ map .put (SHAPE_KEY , tt .dimensionsAsLongArray ());
494
+ map .put (DTYPE_KEY , CommonUtils .getDataTypeFromRAI (tt ));
495
+ map .put (IS_INPUT_KEY , true );
496
+ map .put (MEM_NAME_KEY , shma .getName ());
497
+ encIns .add (gson .toJson (map ));
498
+ }
499
+ LinkedHashMap <String , Object > args = new LinkedHashMap <String , Object >();
500
+ args .put ("inputs" , encIns );
501
+
502
+ try {
503
+ Task task = runner .task ("inference" , args );
504
+ task .waitFor ();
505
+ if (task .status == TaskStatus .CANCELED )
506
+ throw new RuntimeException ();
507
+ else if (task .status == TaskStatus .FAILED )
508
+ throw new RuntimeException (task .error );
509
+ else if (task .status == TaskStatus .CRASHED ) {
510
+ this .runner .close ();
511
+ runner = null ;
512
+ throw new RuntimeException (task .error );
513
+ } else if (task .outputs == null )
514
+ throw new RuntimeException ("No outputs generated" );
515
+ List <String > outputs = (List <String >) task .outputs .get ("encoded" );
516
+ List <RandomAccessibleInterval <R >> rais = new ArrayList <RandomAccessibleInterval <R >>();
517
+ for (String out : outputs ) {
518
+ String name = (String ) Types .decode (out ).get (MEM_NAME_KEY );
519
+ SharedMemoryArray shm = SharedMemoryArray .read (name );
520
+ RandomAccessibleInterval <R > rai = shm .getSharedRAI ();
521
+ rais .add (Tensor .createCopyOfRaiInWantedDataType (Cast .unchecked (rai ), Util .getTypeFromInterval (Cast .unchecked (rai ))));
522
+ shm .close ();
523
+ }
524
+ closeShmas ();
525
+ return rais ;
526
+ } catch (Exception e ) {
527
+ closeShmas ();
528
+ if (e instanceof RunModelException )
529
+ throw (RunModelException ) e ;
530
+ throw new RunModelException (Types .stackTrace (e ));
531
+ }
532
+ }
533
+
534
+ private void closeInterprocess () throws RunModelException {
535
+ try {
536
+ Task task = runner .task ("closeTensors" );
537
+ task .waitFor ();
538
+ if (task .status == TaskStatus .CANCELED )
539
+ throw new RuntimeException ();
540
+ else if (task .status == TaskStatus .FAILED )
541
+ throw new RuntimeException (task .error );
542
+ else if (task .status == TaskStatus .CRASHED ) {
543
+ this .runner .close ();
544
+ runner = null ;
545
+ throw new RuntimeException (task .error );
546
+ }
547
+ } catch (Exception e ) {
548
+ if (e instanceof RunModelException )
549
+ throw (RunModelException ) e ;
550
+ throw new RunModelException (Types .stackTrace (e ));
551
+ }
552
+ }
553
+
554
+ protected void closeFromInterp () {
555
+ if (!PlatformDetection .isWindows ())
556
+ return ;
557
+ this .shmaNamesList .stream ().forEach (nn -> {
558
+ try {
559
+ SharedMemoryArray .read (nn ).close ();
560
+ } catch (IOException e ) {
561
+ e .printStackTrace ();
562
+ }
563
+ });
564
+ }
565
+
566
+ private void closeShmas () throws RunModelException {
401
567
shmaInputList .forEach (shm -> {
402
568
try { shm .close (); } catch (IOException e1 ) { e1 .printStackTrace ();}
403
569
});
@@ -406,8 +572,9 @@ private void closeShmas() {
406
572
try { shm .close (); } catch (IOException e1 ) { e1 .printStackTrace ();}
407
573
});
408
574
shmaOutputList = null ;
409
- }
410
-
575
+ if (interprocessing )
576
+ closeInterprocess ();
577
+ }
411
578
412
579
private <T extends RealType <T > & NativeType <T >> List <String > encodeInputs (List <Tensor <T >> inputTensors ) {
413
580
List <String > encodedInputTensors = new ArrayList <String >();
0 commit comments