@@ -295,6 +295,51 @@ void run(List<Tensor<T>> inputTensors, List<Tensor<R>> outputTensors)
295
295
resultPatchTensors .get (i ).close ();
296
296
}
297
297
}
298
+
299
+ @ Override
300
+ public <T extends RealType <T > & NativeType <T >, R extends RealType <R > & NativeType <R >> List <RandomAccessibleInterval <R >> inference (
301
+ List <RandomAccessibleInterval <T >> inputs ) throws RunModelException {
302
+ if (interprocessing ) {
303
+ return runInterprocessing (inputs );
304
+ }
305
+ Session session = model .session ();
306
+ Session .Runner runner = session .runner ();
307
+ List <TType > inTensors = new ArrayList <TType >();
308
+ int c = 0 ;
309
+ for (RandomAccessibleInterval <T > tt : inputs ) {
310
+ TType inT = TensorBuilder .build (tt );
311
+ inTensors .add (inT );
312
+ String inputName = getModelInputName ("input" + c , c ++);
313
+ runner .feed (inputName , inT );
314
+ }
315
+ c = 0 ;
316
+ List <String > outputInfo = sig .getOutputsMap ().values ().stream ()
317
+ .map (nn -> {
318
+ String name = nn .getName ();
319
+ if (name .endsWith (":0" ))
320
+ return name .substring (0 , name .length () - 2 );
321
+ return name ;
322
+ }).collect (Collectors .toList ());
323
+
324
+ for (String name : outputInfo )
325
+ runner = runner .fetch (name );
326
+ // Run runner
327
+ Result resultPatchTensors = runner .run ();
328
+ for (TType tt : inTensors )
329
+ tt .close ();
330
+ List <RandomAccessibleInterval <R >> rais = new ArrayList <RandomAccessibleInterval <R >>();
331
+ for (int i = 0 ; i < resultPatchTensors .size (); i ++) {
332
+ try {
333
+ rais .add (ImgLib2Builder .build ((TType ) resultPatchTensors .get (i )));
334
+ ((TType ) resultPatchTensors .get (i )).close ();
335
+ } catch (IllegalArgumentException ex ) {
336
+ for (int j = i ; j < resultPatchTensors .size (); j ++)
337
+ ((TType ) resultPatchTensors .get (j )).close ();
338
+ throw new RunModelException (Types .stackTrace (ex ));
339
+ }
340
+ }
341
+ return rais ;
342
+ }
298
343
299
344
protected void runFromShmas (List <String > inputs , List <String > outputs ) throws IOException {
300
345
Session session = model .session ();
@@ -333,6 +378,46 @@ protected void runFromShmas(List<String> inputs, List<String> outputs) throws IO
333
378
}
334
379
}
335
380
381
+ protected List <String > inferenceFromShmas (List <String > inputs ) throws IOException , RunModelException {
382
+ Session session = model .session ();
383
+ Session .Runner runner = session .runner ();
384
+ List <TType > inTensors = new ArrayList <TType >();
385
+ int c = 0 ;
386
+ for (String ee : inputs ) {
387
+ Map <String , Object > decoded = Types .decode (ee );
388
+ SharedMemoryArray shma = SharedMemoryArray .read ((String ) decoded .get (MEM_NAME_KEY ));
389
+ TType inT = io .bioimage .modelrunner .tensorflow .v2 .api050 .shm .TensorBuilder .build (shma );
390
+ if (PlatformDetection .isWindows ()) shma .close ();
391
+ inTensors .add (inT );
392
+ String inputName = getModelInputName ((String ) decoded .get (NAME_KEY ), c ++);
393
+ runner .feed (inputName , inT );
394
+ }
395
+ List <String > outputInfo = sig .getOutputsMap ().values ().stream ()
396
+ .map (nn -> {
397
+ String name = nn .getName ();
398
+ if (name .endsWith (":0" ))
399
+ return name .substring (0 , name .length () - 2 );
400
+ return name ;
401
+ }).collect (Collectors .toList ());
402
+
403
+ for (String name : outputInfo )
404
+ runner = runner .fetch (name );
405
+ // Run runner
406
+ Result resultPatchTensors = runner .run ();
407
+ for (TType tt : inTensors )
408
+ tt .close ();
409
+
410
+ shmaNamesList = new ArrayList <String >();
411
+ for (int i = 0 ; i < resultPatchTensors .size (); i ++) {
412
+ String name = SharedMemoryArray .createShmName ();
413
+ ShmBuilder .build ((TType ) resultPatchTensors .get (i ), name , false );
414
+ shmaNamesList .add (name );
415
+ }
416
+ for (int i = 0 ; i < resultPatchTensors .size (); i ++)
417
+ resultPatchTensors .get (i ).close ();
418
+ return shmaNamesList ;
419
+ }
420
+
336
421
/**
337
422
* MEthod only used in MacOS Intel and Windows systems that makes all the arrangements
338
423
* to create another process, communicate the model info and tensors to the other
@@ -389,7 +474,89 @@ else if (task.status == TaskStatus.CRASHED) {
389
474
closeShmas ();
390
475
}
391
476
392
- private void closeShmas () {
477
+ private <T extends RealType <T > & NativeType <T >, R extends RealType <R > & NativeType <R >>
478
+ List <RandomAccessibleInterval <R >> runInterprocessing (List <RandomAccessibleInterval <T >> inputs ) throws RunModelException {
479
+ shmaInputList = new ArrayList <SharedMemoryArray >();
480
+ List <String > encIns = new ArrayList <String >();
481
+ Gson gson = new Gson ();
482
+ for (RandomAccessibleInterval <T > tt : inputs ) {
483
+ SharedMemoryArray shma = SharedMemoryArray .createSHMAFromRAI (tt , false , true );
484
+ shmaInputList .add (shma );
485
+ HashMap <String , Object > map = new HashMap <String , Object >();
486
+ map .put (SHAPE_KEY , tt .dimensionsAsLongArray ());
487
+ map .put (DTYPE_KEY , CommonUtils .getDataTypeFromRAI (tt ));
488
+ map .put (IS_INPUT_KEY , true );
489
+ map .put (MEM_NAME_KEY , shma .getName ());
490
+ encIns .add (gson .toJson (map ));
491
+ }
492
+ LinkedHashMap <String , Object > args = new LinkedHashMap <String , Object >();
493
+ args .put ("inputs" , encIns );
494
+
495
+ try {
496
+ Task task = runner .task ("inference" , args );
497
+ task .waitFor ();
498
+ if (task .status == TaskStatus .CANCELED )
499
+ throw new RuntimeException ();
500
+ else if (task .status == TaskStatus .FAILED )
501
+ throw new RuntimeException (task .error );
502
+ else if (task .status == TaskStatus .CRASHED ) {
503
+ this .runner .close ();
504
+ runner = null ;
505
+ throw new RuntimeException (task .error );
506
+ } else if (task .outputs == null )
507
+ throw new RuntimeException ("No outputs generated" );
508
+ List <String > outputs = (List <String >) task .outputs .get ("encoded" );
509
+ List <RandomAccessibleInterval <R >> rais = new ArrayList <RandomAccessibleInterval <R >>();
510
+ for (String out : outputs ) {
511
+ String name = (String ) Types .decode (out ).get (MEM_NAME_KEY );
512
+ SharedMemoryArray shm = SharedMemoryArray .read (name );
513
+ RandomAccessibleInterval <R > rai = shm .getSharedRAI ();
514
+ rais .add (Tensor .createCopyOfRaiInWantedDataType (Cast .unchecked (rai ), Util .getTypeFromInterval (Cast .unchecked (rai ))));
515
+ shm .close ();
516
+ }
517
+ closeShmas ();
518
+ return rais ;
519
+ } catch (Exception e ) {
520
+ closeShmas ();
521
+ if (e instanceof RunModelException )
522
+ throw (RunModelException ) e ;
523
+ throw new RunModelException (Types .stackTrace (e ));
524
+ }
525
+ }
526
+
527
+ private void closeInterprocess () throws RunModelException {
528
+ try {
529
+ Task task = runner .task ("closeTensors" );
530
+ task .waitFor ();
531
+ if (task .status == TaskStatus .CANCELED )
532
+ throw new RuntimeException ();
533
+ else if (task .status == TaskStatus .FAILED )
534
+ throw new RuntimeException (task .error );
535
+ else if (task .status == TaskStatus .CRASHED ) {
536
+ this .runner .close ();
537
+ runner = null ;
538
+ throw new RuntimeException (task .error );
539
+ }
540
+ } catch (Exception e ) {
541
+ if (e instanceof RunModelException )
542
+ throw (RunModelException ) e ;
543
+ throw new RunModelException (Types .stackTrace (e ));
544
+ }
545
+ }
546
+
547
+ protected void closeFromInterp () {
548
+ if (!PlatformDetection .isWindows ())
549
+ return ;
550
+ this .shmaNamesList .stream ().forEach (nn -> {
551
+ try {
552
+ SharedMemoryArray .read (nn ).close ();
553
+ } catch (IOException e ) {
554
+ e .printStackTrace ();
555
+ }
556
+ });
557
+ }
558
+
559
+ private void closeShmas () throws RunModelException {
393
560
shmaInputList .forEach (shm -> {
394
561
try { shm .close (); } catch (IOException e1 ) { e1 .printStackTrace ();}
395
562
});
@@ -398,6 +565,8 @@ private void closeShmas() {
398
565
try { shm .close (); } catch (IOException e1 ) { e1 .printStackTrace ();}
399
566
});
400
567
shmaOutputList = null ;
568
+ if (interprocessing )
569
+ closeInterprocess ();
401
570
}
402
571
403
572
private <T extends RealType <T > & NativeType <T >> List <String > encodeInputs (List <Tensor <T >> inputTensors ) {
0 commit comments