 import io.bioimage.modelrunner.engine.DeepLearningEngineInterface;
 import io.bioimage.modelrunner.exceptions.LoadModelException;
 import io.bioimage.modelrunner.exceptions.RunModelException;
-import io.bioimage.modelrunner.numpy.DecodeNumpy;
 import io.bioimage.modelrunner.pytorch.shm.ShmBuilder;
 import io.bioimage.modelrunner.pytorch.shm.TensorBuilder;
 import io.bioimage.modelrunner.pytorch.tensor.ImgLib2Builder;
 import io.bioimage.modelrunner.tensor.shm.SharedMemoryArray;
 import io.bioimage.modelrunner.utils.CommonUtils;
 import net.imglib2.RandomAccessibleInterval;
-import net.imglib2.img.array.ArrayImgs;
 import net.imglib2.type.NativeType;
 import net.imglib2.type.numeric.RealType;
 import net.imglib2.util.Cast;
@@ -246,6 +244,34 @@ private static String getModelName(String modelSource) throws IOException {
 		return modelName.substring(0, ind);
 	}
 
+	@Override
+	public <T extends RealType<T> & NativeType<T>, R extends RealType<R> & NativeType<R>> List<RandomAccessibleInterval<R>> inference(
+			List<RandomAccessibleInterval<T>> inputs) throws RunModelException {
+		if (interprocessing) {
+			return runInterprocessing(inputs);
+		}
+		try (NDManager manager = NDManager.newBaseManager()) {
+			// Create the input lists of engine tensors (NDArrays) and their
+			// corresponding names
+			NDList inputList = new NDList();
+			for (RandomAccessibleInterval<T> tt : inputs) {
+				inputList.add(NDArrayBuilder.build(tt, manager));
+			}
+			// Run model
+			Predictor<NDList, NDList> predictor = model.newPredictor();
+			NDList outputNDArrays = predictor.predict(inputList);
+			// Fill the agnostic output tensors list with data from the inference
+			// result
+			List<RandomAccessibleInterval<R>> outs = new ArrayList<RandomAccessibleInterval<R>>();
+			for (int i = 0; i < outputNDArrays.size(); i ++)
+				outs.add(ImgLib2Builder.build(outputNDArrays.get(i)));
+			return outs;
+		}
+		catch (TranslateException e) {
+			throw new RunModelException(Types.stackTrace(e));
+		}
+	}
+
 	/**
 	 * {@inheritDoc}
 	 *
@@ -310,21 +336,36 @@ protected void runFromShmas(List<String> inputs, List<String> outputs) throws IO
 		}
 	}
 
-	/**
-	 * MEthod only used in MacOS Intel and Windows systems that makes all the arrangements
-	 * to create another process, communicate the model info and tensors to the other
-	 * process and then retrieve the results of the other process
-	 * @param <T>
-	 * 	ImgLib2 data type of the input tensors
-	 * @param <R>
-	 * 	ImgLib2 data type of the output tensors, it can be the same as the input tensors' data type
-	 * @param inputTensors
-	 * 	tensors that are going to be run on the model
-	 * @param outputTensors
-	 * 	expected results of the model
-	 * @throws RunModelException if there is any issue running the model
-	 */
-	public <T extends RealType<T> & NativeType<T>, R extends RealType<R> & NativeType<R>>
+	protected List<String> inferenceFromShmas(List<String> inputs) throws IOException, RunModelException {
+		try (NDManager manager = NDManager.newBaseManager()) {
+			// Create the input lists of engine tensors (NDArrays) and their
+			// corresponding names
+			NDList inputList = new NDList();
+			for (String ee : inputs) {
+				Map<String, Object> decoded = Types.decode(ee);
+				SharedMemoryArray shma = SharedMemoryArray.read((String) decoded.get(MEM_NAME_KEY));
+				NDArray inT = TensorBuilder.build(shma, manager);
+				if (PlatformDetection.isWindows()) shma.close();
+				inputList.add(inT);
+			}
+			// Run model
+			Predictor<NDList, NDList> predictor = model.newPredictor();
+			NDList outputNDArrays = predictor.predict(inputList);
+
+			shmaNamesList = new ArrayList<String>();
+			for (int i = 0; i < outputNDArrays.size(); i ++) {
+				String name = SharedMemoryArray.createShmName();
+				ShmBuilder.build(outputNDArrays.get(i), name, false);
+				shmaNamesList.add(name);
+			}
+			return shmaNamesList;
+		}
+		catch (TranslateException e) {
+			throw new RunModelException(Types.stackTrace(e));
+		}
+	}
+
+	private <T extends RealType<T> & NativeType<T>, R extends RealType<R> & NativeType<R>>
 	void runInterprocessing(List<Tensor<T>> inputTensors, List<Tensor<R>> outputTensors) throws RunModelException {
 		shmaInputList = new ArrayList<SharedMemoryArray>();
 		shmaOutputList = new ArrayList<SharedMemoryArray>();
@@ -335,7 +376,7 @@ void runInterprocessing(List<Tensor<T>> inputTensors, List<Tensor<R>> outputTens
 		args.put("outputs", encOuts);
 
 		try {
-			Task task = runner.task("inference", args);
+			Task task = runner.task("run", args);
 			task.waitFor();
 			if (task.status == TaskStatus.CANCELED)
 				throw new RuntimeException();
@@ -365,6 +406,76 @@ else if (task.status == TaskStatus.CRASHED) {
 		}
 		closeShmas();
 	}
+
+	private <T extends RealType<T> & NativeType<T>, R extends RealType<R> & NativeType<R>>
+	List<RandomAccessibleInterval<R>> runInterprocessing(List<RandomAccessibleInterval<T>> inputs) throws RunModelException {
+		shmaInputList = new ArrayList<SharedMemoryArray>();
+		List<String> encIns = new ArrayList<String>();
+		Gson gson = new Gson();
+		for (RandomAccessibleInterval<T> tt : inputs) {
+			SharedMemoryArray shma = SharedMemoryArray.createSHMAFromRAI(tt, false, true);
+			shmaInputList.add(shma);
+			HashMap<String, Object> map = new HashMap<String, Object>();
+			map.put(SHAPE_KEY, tt.dimensionsAsLongArray());
+			map.put(DTYPE_KEY, CommonUtils.getDataTypeFromRAI(tt));
+			map.put(IS_INPUT_KEY, true);
+			map.put(MEM_NAME_KEY, shma.getName());
+			encIns.add(gson.toJson(map));
+		}
+		LinkedHashMap<String, Object> args = new LinkedHashMap<String, Object>();
+		args.put("inputs", encIns);
+
+		try {
+			Task task = runner.task("inference", args);
+			task.waitFor();
+			if (task.status == TaskStatus.CANCELED)
+				throw new RuntimeException();
+			else if (task.status == TaskStatus.FAILED)
+				throw new RuntimeException(task.error);
+			else if (task.status == TaskStatus.CRASHED) {
+				this.runner.close();
+				runner = null;
+				throw new RuntimeException(task.error);
+			} else if (task.outputs == null)
+				throw new RuntimeException("No outputs generated");
+			List<String> outputs = (List<String>) task.outputs.get("encoded");
+			List<RandomAccessibleInterval<R>> rais = new ArrayList<RandomAccessibleInterval<R>>();
+			for (String out : outputs) {
+				String name = (String) Types.decode(out).get(MEM_NAME_KEY);
+				SharedMemoryArray shm = SharedMemoryArray.read(name);
+				RandomAccessibleInterval<R> rai = shm.getSharedRAI();
+				rais.add(Tensor.createCopyOfRaiInWantedDataType(Cast.unchecked(rai), Util.getTypeFromInterval(Cast.unchecked(rai))));
+				shm.close();
+			}
+			closeShmas();
+			return rais;
+		} catch (Exception e) {
+			closeShmas();
+			if (e instanceof RunModelException)
+				throw (RunModelException) e;
+			throw new RunModelException(Types.stackTrace(e));
+		}
+	}
+
+	private void closeInterprocess() throws RunModelException {
+		try {
+			Task task = runner.task("closeTensors");
+			task.waitFor();
+			if (task.status == TaskStatus.CANCELED)
+				throw new RuntimeException();
+			else if (task.status == TaskStatus.FAILED)
+				throw new RuntimeException(task.error);
+			else if (task.status == TaskStatus.CRASHED) {
+				this.runner.close();
+				runner = null;
+				throw new RuntimeException(task.error);
+			}
+		} catch (Exception e) {
+			if (e instanceof RunModelException)
+				throw (RunModelException) e;
+			throw new RunModelException(Types.stackTrace(e));
+		}
+	}
 
 	/**
 	 * Create the list a list of output tensors agnostic to the Deep Learning
@@ -391,7 +502,19 @@ void fillOutputTensors(NDList outputNDArrays,
 		}
 	}
 
-	private void closeShmas() {
+	protected void closeFromInterp() {
+		if (!PlatformDetection.isWindows())
+			return;
+		this.shmaNamesList.stream().forEach(nn -> {
+			try {
+				SharedMemoryArray.read(nn).close();
+			} catch (IOException e) {
+				e.printStackTrace();
+			}
+		});
+	}
+
+	private void closeShmas() throws RunModelException {
 		shmaInputList.forEach(shm -> {
 			try { shm.close(); } catch (IOException e1) { e1.printStackTrace();}
 		});
@@ -400,6 +523,7 @@ private void closeShmas() {
 			try { shm.close(); } catch (IOException e1) { e1.printStackTrace();}
 		});
 		shmaOutputList = null;
+		closeInterprocess();
 	}
 
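Note on the new API: the inference(...) override added above runs a loaded model directly on ImgLib2 RandomAccessibleIntervals instead of wrapped Tensor objects, delegating to runInterprocessing(...) when a separate process is required. Below is a minimal, hypothetical usage sketch; the helper method, model paths, and the assumption that inference(...) is declared on DeepLearningEngineInterface (as the @Override suggests) are not part of this diff.

import java.util.ArrayList;
import java.util.List;

import io.bioimage.modelrunner.engine.DeepLearningEngineInterface;
import io.bioimage.modelrunner.exceptions.LoadModelException;
import io.bioimage.modelrunner.exceptions.RunModelException;
import net.imglib2.RandomAccessibleInterval;
import net.imglib2.img.array.ArrayImgs;
import net.imglib2.type.numeric.real.FloatType;

public class DirectInferenceSketch {

	// Sketch only: any DeepLearningEngineInterface implementation exposing the new
	// inference(List<RandomAccessibleInterval>) override can be passed in; the model
	// folder and weights path are placeholders.
	static void runDirect(DeepLearningEngineInterface engine, String modelFolder, String modelSource)
			throws LoadModelException, RunModelException {
		engine.loadModel(modelFolder, modelSource);

		// Inputs are plain ImgLib2 images; no Tensor wrapping or axis metadata is needed.
		List<RandomAccessibleInterval<FloatType>> inputs = new ArrayList<>();
		inputs.add(ArrayImgs.floats(1, 1, 256, 256));

		// One RandomAccessibleInterval is returned per model output, in output order.
		List<RandomAccessibleInterval<FloatType>> outputs = engine.inference(inputs);
		System.out.println("Number of outputs: " + outputs.size());

		engine.closeModel();
	}
}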
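On the interprocess path, each input image is published as a shared-memory block and described to the worker with a small Gson-encoded map, as in the new runInterprocessing(...) above. The sketch below only illustrates the shape of one encoded descriptor; the literal key strings and the memory name are hypothetical stand-ins for the SHAPE_KEY, DTYPE_KEY, IS_INPUT_KEY and MEM_NAME_KEY constants, whose real values are defined elsewhere in the class.

import java.util.HashMap;

import com.google.gson.Gson;

public class ShmDescriptorSketch {

	// Hypothetical stand-ins for the key constants referenced in runInterprocessing(...).
	private static final String SHAPE_KEY = "shape";
	private static final String DTYPE_KEY = "dtype";
	private static final String IS_INPUT_KEY = "isInput";
	private static final String MEM_NAME_KEY = "memoryName";

	public static void main(String[] args) {
		HashMap<String, Object> map = new HashMap<String, Object>();
		map.put(SHAPE_KEY, new long[] {1, 1, 256, 256}); // tt.dimensionsAsLongArray()
		map.put(DTYPE_KEY, "float32");                   // CommonUtils.getDataTypeFromRAI(tt)
		map.put(IS_INPUT_KEY, true);
		map.put(MEM_NAME_KEY, "psm_1234abcd");           // shma.getName()

		// The worker receives this string in the "inputs" list of the "inference" task
		// and decodes it with Types.decode(...) before attaching to the shared memory.
		System.out.println(new Gson().toJson(map));
	}
}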