@@ -316,6 +316,51 @@ void run(List<Tensor<T>> inputTensors, List<Tensor<R>> outputTensors)
316316 tt .close ();
317317 }
318318 }
319+
320+ @ Override
321+ public <T extends RealType <T > & NativeType <T >, R extends RealType <R > & NativeType <R >> List <RandomAccessibleInterval <R >> inference (
322+ List <RandomAccessibleInterval <T >> inputs ) throws RunModelException {
323+ if (interprocessing ) {
324+ return runInterprocessing (inputs );
325+ }
326+ Session session = model .session ();
327+ Session .Runner runner = session .runner ();
328+ List <org .tensorflow .Tensor <?>> inTensors =
329+ new ArrayList <org .tensorflow .Tensor <?>>();
330+ int c = 0 ;
331+ for (RandomAccessibleInterval <T > tt : inputs ) {
332+ org .tensorflow .Tensor <?> inT = TensorBuilder .build (tt );
333+ inTensors .add (inT );
334+ String inputName = getModelInputName ("input" + c , c ++);
335+ runner .feed (inputName , inT );
336+ }
337+ c = 0 ;
338+ List <String > outputInfo = sig .getOutputsMap ().values ().stream ()
339+ .map (nn -> {
340+ String name = nn .getName ();
341+ if (name .endsWith (":0" ))
342+ return name .substring (0 , name .length () - 2 );
343+ return name ;
344+ }).collect (Collectors .toList ());
345+
346+ for (String name : outputInfo )
347+ runner = runner .fetch (name );
348+ // Run runner
349+ List <org .tensorflow .Tensor <?>> resultPatchTensors = runner .run ();
350+ for (org .tensorflow .Tensor <?> tt : inTensors )
351+ tt .close ();
352+ List <RandomAccessibleInterval <R >> rais = new ArrayList <RandomAccessibleInterval <R >>();
353+ for (int i = 0 ; i < resultPatchTensors .size (); i ++) {
354+ try {
355+ rais .add (ImgLib2Builder .build (resultPatchTensors .get (i )));
356+ } catch (IllegalArgumentException ex ) {
357+ for (org .tensorflow .Tensor <?> tt : resultPatchTensors )
358+ tt .close ();
359+ throw new RunModelException (Types .stackTrace (ex ));
360+ }
361+ }
362+ return rais ;
363+ }
319364
320365 protected void runFromShmas (List <String > inputs , List <String > outputs ) throws IOException {
321366 Session session = model .session ();
@@ -354,17 +399,46 @@ protected void runFromShmas(List<String> inputs, List<String> outputs) throws IO
354399 }
355400 }
356401
357- /**
358- * MEthod only used in MacOS Intel systems that makes all the arangements
359- * to create another process, communicate the model info and tensors to the other
360- * process and then retrieve the results of the other process
361- * @param inputTensors
362- * tensors that are going to be run on the model
363- * @param outputTensors
364- * expected results of the model
365- * @throws RunModelException if there is any issue running the model
366- */
367- public <T extends RealType <T > & NativeType <T >, R extends RealType <R > & NativeType <R >>
402+ protected List <String > inferenceFromShmas (List <String > inputs ) throws IOException , RunModelException {
403+ Session session = model .session ();
404+ Session .Runner runner = session .runner ();
405+ List <org .tensorflow .Tensor <?>> inTensors =
406+ new ArrayList <org .tensorflow .Tensor <?>>();
407+ int c = 0 ;
408+ for (String ee : inputs ) {
409+ Map <String , Object > decoded = Types .decode (ee );
410+ SharedMemoryArray shma = SharedMemoryArray .read ((String ) decoded .get (MEM_NAME_KEY ));
411+ org .tensorflow .Tensor <?> inT = io .bioimage .modelrunner .tensorflow .v1 .shm .TensorBuilder .build (shma );
412+ if (PlatformDetection .isWindows ()) shma .close ();
413+ inTensors .add (inT );
414+ String inputName = getModelInputName ((String ) decoded .get (NAME_KEY ), c ++);
415+ runner .feed (inputName , inT );
416+ }
417+ List <String > outputInfo = sig .getOutputsMap ().values ().stream ()
418+ .map (nn -> {
419+ String name = nn .getName ();
420+ if (name .endsWith (":0" ))
421+ return name .substring (0 , name .length () - 2 );
422+ return name ;
423+ }).collect (Collectors .toList ());
424+
425+ for (String name : outputInfo )
426+ runner = runner .fetch (name );
427+ // Run runner
428+ List <org .tensorflow .Tensor <?>> resultPatchTensors = runner .run ();
429+ for (org .tensorflow .Tensor <?> tt : inTensors )
430+ tt .close ();
431+
432+ shmaNamesList = new ArrayList <String >();
433+ for (int i = 0 ; i < resultPatchTensors .size (); i ++) {
434+ String name = SharedMemoryArray .createShmName ();
435+ ShmBuilder .build (resultPatchTensors .get (i ), name , false );
436+ shmaNamesList .add (name );
437+ }
438+ return shmaNamesList ;
439+ }
440+
441+ private <T extends RealType <T > & NativeType <T >, R extends RealType <R > & NativeType <R >>
368442 void runInterprocessing (List <Tensor <T >> inputTensors , List <Tensor <R >> outputTensors ) throws RunModelException {
369443 shmaInputList = new ArrayList <SharedMemoryArray >();
370444 shmaOutputList = new ArrayList <SharedMemoryArray >();
@@ -375,7 +449,7 @@ void runInterprocessing(List<Tensor<T>> inputTensors, List<Tensor<R>> outputTens
375449 args .put ("outputs" , encOuts );
376450
377451 try {
378- Task task = runner .task ("inference " , args );
452+ Task task = runner .task ("run " , args );
379453 task .waitFor ();
380454 if (task .status == TaskStatus .CANCELED )
381455 throw new RuntimeException ();
@@ -406,7 +480,89 @@ else if (task.status == TaskStatus.CRASHED) {
406480 closeShmas ();
407481 }
408482
409- private void closeShmas () {
483+ private <T extends RealType <T > & NativeType <T >, R extends RealType <R > & NativeType <R >>
484+ List <RandomAccessibleInterval <R >> runInterprocessing (List <RandomAccessibleInterval <T >> inputs ) throws RunModelException {
485+ shmaInputList = new ArrayList <SharedMemoryArray >();
486+ List <String > encIns = new ArrayList <String >();
487+ Gson gson = new Gson ();
488+ for (RandomAccessibleInterval <T > tt : inputs ) {
489+ SharedMemoryArray shma = SharedMemoryArray .createSHMAFromRAI (tt , false , true );
490+ shmaInputList .add (shma );
491+ HashMap <String , Object > map = new HashMap <String , Object >();
492+ map .put (SHAPE_KEY , tt .dimensionsAsLongArray ());
493+ map .put (DTYPE_KEY , CommonUtils .getDataTypeFromRAI (tt ));
494+ map .put (IS_INPUT_KEY , true );
495+ map .put (MEM_NAME_KEY , shma .getName ());
496+ encIns .add (gson .toJson (map ));
497+ }
498+ LinkedHashMap <String , Object > args = new LinkedHashMap <String , Object >();
499+ args .put ("inputs" , encIns );
500+
501+ try {
502+ Task task = runner .task ("inference" , args );
503+ task .waitFor ();
504+ if (task .status == TaskStatus .CANCELED )
505+ throw new RuntimeException ();
506+ else if (task .status == TaskStatus .FAILED )
507+ throw new RuntimeException (task .error );
508+ else if (task .status == TaskStatus .CRASHED ) {
509+ this .runner .close ();
510+ runner = null ;
511+ throw new RuntimeException (task .error );
512+ } else if (task .outputs == null )
513+ throw new RuntimeException ("No outputs generated" );
514+ List <String > outputs = (List <String >) task .outputs .get ("encoded" );
515+ List <RandomAccessibleInterval <R >> rais = new ArrayList <RandomAccessibleInterval <R >>();
516+ for (String out : outputs ) {
517+ String name = (String ) Types .decode (out ).get (MEM_NAME_KEY );
518+ SharedMemoryArray shm = SharedMemoryArray .read (name );
519+ RandomAccessibleInterval <R > rai = shm .getSharedRAI ();
520+ rais .add (Tensor .createCopyOfRaiInWantedDataType (Cast .unchecked (rai ), Util .getTypeFromInterval (Cast .unchecked (rai ))));
521+ shm .close ();
522+ }
523+ closeShmas ();
524+ return rais ;
525+ } catch (Exception e ) {
526+ closeShmas ();
527+ if (e instanceof RunModelException )
528+ throw (RunModelException ) e ;
529+ throw new RunModelException (Types .stackTrace (e ));
530+ }
531+ }
532+
533+ private void closeInterprocess () throws RunModelException {
534+ try {
535+ Task task = runner .task ("closeTensors" );
536+ task .waitFor ();
537+ if (task .status == TaskStatus .CANCELED )
538+ throw new RuntimeException ();
539+ else if (task .status == TaskStatus .FAILED )
540+ throw new RuntimeException (task .error );
541+ else if (task .status == TaskStatus .CRASHED ) {
542+ this .runner .close ();
543+ runner = null ;
544+ throw new RuntimeException (task .error );
545+ }
546+ } catch (Exception e ) {
547+ if (e instanceof RunModelException )
548+ throw (RunModelException ) e ;
549+ throw new RunModelException (Types .stackTrace (e ));
550+ }
551+ }
552+
553+ protected void closeFromInterp () {
554+ if (!PlatformDetection .isWindows ())
555+ return ;
556+ this .shmaNamesList .stream ().forEach (nn -> {
557+ try {
558+ SharedMemoryArray .read (nn ).close ();
559+ } catch (IOException e ) {
560+ e .printStackTrace ();
561+ }
562+ });
563+ }
564+
565+ private void closeShmas () throws RunModelException {
410566 shmaInputList .forEach (shm -> {
411567 try { shm .close (); } catch (IOException e1 ) { e1 .printStackTrace ();}
412568 });
@@ -415,6 +571,8 @@ private void closeShmas() {
415571 try { shm .close (); } catch (IOException e1 ) { e1 .printStackTrace ();}
416572 });
417573 shmaOutputList = null ;
574+ if (interprocessing )
575+ closeInterprocess ();
418576 }
419577
420578
@@ -483,7 +641,7 @@ public static <T extends RealType<T> & NativeType<T>> void fillOutputTensors(
483641 try {
484642 outputTensors .get (i ).setData (ImgLib2Builder .build (outputNDArrays .get (i )));
485643 } catch (IllegalArgumentException ex ) {
486- throw new RunModelException (ex . toString ( ));
644+ throw new RunModelException (Types . stackTrace ( ex ));
487645 }
488646 }
489647 }
0 commit comments