2727import io .bioimage .modelrunner .engine .DeepLearningEngineInterface ;
2828import io .bioimage .modelrunner .exceptions .LoadModelException ;
2929import io .bioimage .modelrunner .exceptions .RunModelException ;
30- import io .bioimage .modelrunner .numpy .DecodeNumpy ;
3130import io .bioimage .modelrunner .pytorch .shm .ShmBuilder ;
3231import io .bioimage .modelrunner .pytorch .shm .TensorBuilder ;
3332import io .bioimage .modelrunner .pytorch .tensor .ImgLib2Builder ;
3736import io .bioimage .modelrunner .tensor .shm .SharedMemoryArray ;
3837import io .bioimage .modelrunner .utils .CommonUtils ;
3938import net .imglib2 .RandomAccessibleInterval ;
40- import net .imglib2 .img .array .ArrayImgs ;
4139import net .imglib2 .type .NativeType ;
4240import net .imglib2 .type .numeric .RealType ;
4341import net .imglib2 .util .Cast ;
@@ -246,6 +244,34 @@ private static String getModelName(String modelSource) throws IOException {
246244 return modelName .substring (0 , ind );
247245 }
248246
247+ @ Override
248+ public <T extends RealType <T > & NativeType <T >, R extends RealType <R > & NativeType <R >> List <RandomAccessibleInterval <R >> inference (
249+ List <RandomAccessibleInterval <T >> inputs ) throws RunModelException {
250+ if (interprocessing ) {
251+ return runInterprocessing (inputs );
252+ }
253+ try (NDManager manager = NDManager .newBaseManager ()) {
254+ // Create the input lists of engine tensors (NDArrays) and their
255+ // corresponding names
256+ NDList inputList = new NDList ();
257+ for (RandomAccessibleInterval <T > tt : inputs ) {
258+ inputList .add (NDArrayBuilder .build (tt , manager ));
259+ }
260+ // Run model
261+ Predictor <NDList , NDList > predictor = model .newPredictor ();
262+ NDList outputNDArrays = predictor .predict (inputList );
263+ // Fill the agnostic output tensors list with data from the inference
264+ // result
265+ List <RandomAccessibleInterval <R >> outs = new ArrayList <RandomAccessibleInterval <R >>();
266+ for (int i = 0 ; i < outputNDArrays .size (); i ++)
267+ outs .add (ImgLib2Builder .build (outputNDArrays .get (i )));
268+ return outs ;
269+ }
270+ catch (TranslateException e ) {
271+ throw new RunModelException (Types .stackTrace (e ));
272+ }
273+ }
274+
249275 /**
250276 * {@inheritDoc}
251277 *
@@ -310,21 +336,36 @@ protected void runFromShmas(List<String> inputs, List<String> outputs) throws IO
310336 }
311337 }
312338
313- /**
314- * MEthod only used in MacOS Intel and Windows systems that makes all the arrangements
315- * to create another process, communicate the model info and tensors to the other
316- * process and then retrieve the results of the other process
317- * @param <T>
318- * ImgLib2 data type of the input tensors
319- * @param <R>
320- * ImgLib2 data type of the output tensors, it can be the same as the input tensors' data type
321- * @param inputTensors
322- * tensors that are going to be run on the model
323- * @param outputTensors
324- * expected results of the model
325- * @throws RunModelException if there is any issue running the model
326- */
327- public <T extends RealType <T > & NativeType <T >, R extends RealType <R > & NativeType <R >>
339+ protected List <String > inferenceFromShmas (List <String > inputs ) throws IOException , RunModelException {
340+ try (NDManager manager = NDManager .newBaseManager ()) {
341+ // Create the input lists of engine tensors (NDArrays) and their
342+ // corresponding names
343+ NDList inputList = new NDList ();
344+ for (String ee : inputs ) {
345+ Map <String , Object > decoded = Types .decode (ee );
346+ SharedMemoryArray shma = SharedMemoryArray .read ((String ) decoded .get (MEM_NAME_KEY ));
347+ NDArray inT = TensorBuilder .build (shma , manager );
348+ if (PlatformDetection .isWindows ()) shma .close ();
349+ inputList .add (inT );
350+ }
351+ // Run model
352+ Predictor <NDList , NDList > predictor = model .newPredictor ();
353+ NDList outputNDArrays = predictor .predict (inputList );
354+
355+ shmaNamesList = new ArrayList <String >();
356+ for (int i = 0 ; i < outputNDArrays .size (); i ++) {
357+ String name = SharedMemoryArray .createShmName ();
358+ ShmBuilder .build (outputNDArrays .get (i ), name , false );
359+ shmaNamesList .add (name );
360+ }
361+ return shmaNamesList ;
362+ }
363+ catch (TranslateException e ) {
364+ throw new RunModelException (Types .stackTrace (e ));
365+ }
366+ }
367+
368+ private <T extends RealType <T > & NativeType <T >, R extends RealType <R > & NativeType <R >>
328369 void runInterprocessing (List <Tensor <T >> inputTensors , List <Tensor <R >> outputTensors ) throws RunModelException {
329370 shmaInputList = new ArrayList <SharedMemoryArray >();
330371 shmaOutputList = new ArrayList <SharedMemoryArray >();
@@ -335,7 +376,7 @@ void runInterprocessing(List<Tensor<T>> inputTensors, List<Tensor<R>> outputTens
335376 args .put ("outputs" , encOuts );
336377
337378 try {
338- Task task = runner .task ("inference " , args );
379+ Task task = runner .task ("run " , args );
339380 task .waitFor ();
340381 if (task .status == TaskStatus .CANCELED )
341382 throw new RuntimeException ();
@@ -365,6 +406,76 @@ else if (task.status == TaskStatus.CRASHED) {
365406 }
366407 closeShmas ();
367408 }
409+
410+ private <T extends RealType <T > & NativeType <T >, R extends RealType <R > & NativeType <R >>
411+ List <RandomAccessibleInterval <R >> runInterprocessing (List <RandomAccessibleInterval <T >> inputs ) throws RunModelException {
412+ shmaInputList = new ArrayList <SharedMemoryArray >();
413+ List <String > encIns = new ArrayList <String >();
414+ Gson gson = new Gson ();
415+ for (RandomAccessibleInterval <T > tt : inputs ) {
416+ SharedMemoryArray shma = SharedMemoryArray .createSHMAFromRAI (tt , false , true );
417+ shmaInputList .add (shma );
418+ HashMap <String , Object > map = new HashMap <String , Object >();
419+ map .put (SHAPE_KEY , tt .dimensionsAsLongArray ());
420+ map .put (DTYPE_KEY , CommonUtils .getDataTypeFromRAI (tt ));
421+ map .put (IS_INPUT_KEY , true );
422+ map .put (MEM_NAME_KEY , shma .getName ());
423+ encIns .add (gson .toJson (map ));
424+ }
425+ LinkedHashMap <String , Object > args = new LinkedHashMap <String , Object >();
426+ args .put ("inputs" , encIns );
427+
428+ try {
429+ Task task = runner .task ("inference" , args );
430+ task .waitFor ();
431+ if (task .status == TaskStatus .CANCELED )
432+ throw new RuntimeException ();
433+ else if (task .status == TaskStatus .FAILED )
434+ throw new RuntimeException (task .error );
435+ else if (task .status == TaskStatus .CRASHED ) {
436+ this .runner .close ();
437+ runner = null ;
438+ throw new RuntimeException (task .error );
439+ } else if (task .outputs == null )
440+ throw new RuntimeException ("No outputs generated" );
441+ List <String > outputs = (List <String >) task .outputs .get ("encoded" );
442+ List <RandomAccessibleInterval <R >> rais = new ArrayList <RandomAccessibleInterval <R >>();
443+ for (String out : outputs ) {
444+ String name = (String ) Types .decode (out ).get (MEM_NAME_KEY );
445+ SharedMemoryArray shm = SharedMemoryArray .read (name );
446+ RandomAccessibleInterval <R > rai = shm .getSharedRAI ();
447+ rais .add (Tensor .createCopyOfRaiInWantedDataType (Cast .unchecked (rai ), Util .getTypeFromInterval (Cast .unchecked (rai ))));
448+ shm .close ();
449+ }
450+ closeShmas ();
451+ return rais ;
452+ } catch (Exception e ) {
453+ closeShmas ();
454+ if (e instanceof RunModelException )
455+ throw (RunModelException ) e ;
456+ throw new RunModelException (Types .stackTrace (e ));
457+ }
458+ }
459+
460+ private void closeInterprocess () throws RunModelException {
461+ try {
462+ Task task = runner .task ("closeTensors" );
463+ task .waitFor ();
464+ if (task .status == TaskStatus .CANCELED )
465+ throw new RuntimeException ();
466+ else if (task .status == TaskStatus .FAILED )
467+ throw new RuntimeException (task .error );
468+ else if (task .status == TaskStatus .CRASHED ) {
469+ this .runner .close ();
470+ runner = null ;
471+ throw new RuntimeException (task .error );
472+ }
473+ } catch (Exception e ) {
474+ if (e instanceof RunModelException )
475+ throw (RunModelException ) e ;
476+ throw new RunModelException (Types .stackTrace (e ));
477+ }
478+ }
368479
369480 /**
370481 * Create the list a list of output tensors agnostic to the Deep Learning
@@ -391,7 +502,19 @@ void fillOutputTensors(NDList outputNDArrays,
391502 }
392503 }
393504
394- private void closeShmas () {
505+ protected void closeFromInterp () {
506+ if (!PlatformDetection .isWindows ())
507+ return ;
508+ this .shmaNamesList .stream ().forEach (nn -> {
509+ try {
510+ SharedMemoryArray .read (nn ).close ();
511+ } catch (IOException e ) {
512+ e .printStackTrace ();
513+ }
514+ });
515+ }
516+
517+ private void closeShmas () throws RunModelException {
395518 shmaInputList .forEach (shm -> {
396519 try { shm .close (); } catch (IOException e1 ) { e1 .printStackTrace ();}
397520 });
@@ -400,6 +523,7 @@ private void closeShmas() {
400523 try { shm .close (); } catch (IOException e1 ) { e1 .printStackTrace ();}
401524 });
402525 shmaOutputList = null ;
526+ closeInterprocess ();
403527 }
404528
405529
0 commit comments