27
27
import io .bioimage .modelrunner .engine .DeepLearningEngineInterface ;
28
28
import io .bioimage .modelrunner .exceptions .LoadModelException ;
29
29
import io .bioimage .modelrunner .exceptions .RunModelException ;
30
- import io .bioimage .modelrunner .numpy .DecodeNumpy ;
31
30
import io .bioimage .modelrunner .pytorch .shm .ShmBuilder ;
32
31
import io .bioimage .modelrunner .pytorch .shm .TensorBuilder ;
33
32
import io .bioimage .modelrunner .pytorch .tensor .ImgLib2Builder ;
37
36
import io .bioimage .modelrunner .tensor .shm .SharedMemoryArray ;
38
37
import io .bioimage .modelrunner .utils .CommonUtils ;
39
38
import net .imglib2 .RandomAccessibleInterval ;
40
- import net .imglib2 .img .array .ArrayImgs ;
41
39
import net .imglib2 .type .NativeType ;
42
40
import net .imglib2 .type .numeric .RealType ;
43
41
import net .imglib2 .util .Cast ;
@@ -246,6 +244,34 @@ private static String getModelName(String modelSource) throws IOException {
246
244
return modelName .substring (0 , ind );
247
245
}
248
246
247
+ @ Override
248
+ public <T extends RealType <T > & NativeType <T >, R extends RealType <R > & NativeType <R >> List <RandomAccessibleInterval <R >> inference (
249
+ List <RandomAccessibleInterval <T >> inputs ) throws RunModelException {
250
+ if (interprocessing ) {
251
+ return runInterprocessing (inputs );
252
+ }
253
+ try (NDManager manager = NDManager .newBaseManager ()) {
254
+ // Create the input lists of engine tensors (NDArrays) and their
255
+ // corresponding names
256
+ NDList inputList = new NDList ();
257
+ for (RandomAccessibleInterval <T > tt : inputs ) {
258
+ inputList .add (NDArrayBuilder .build (tt , manager ));
259
+ }
260
+ // Run model
261
+ Predictor <NDList , NDList > predictor = model .newPredictor ();
262
+ NDList outputNDArrays = predictor .predict (inputList );
263
+ // Fill the agnostic output tensors list with data from the inference
264
+ // result
265
+ List <RandomAccessibleInterval <R >> outs = new ArrayList <RandomAccessibleInterval <R >>();
266
+ for (int i = 0 ; i < outputNDArrays .size (); i ++)
267
+ outs .add (ImgLib2Builder .build (outputNDArrays .get (i )));
268
+ return outs ;
269
+ }
270
+ catch (TranslateException e ) {
271
+ throw new RunModelException (Types .stackTrace (e ));
272
+ }
273
+ }
274
+
249
275
/**
250
276
* {@inheritDoc}
251
277
*
@@ -310,21 +336,36 @@ protected void runFromShmas(List<String> inputs, List<String> outputs) throws IO
310
336
}
311
337
}
312
338
313
- /**
314
- * MEthod only used in MacOS Intel and Windows systems that makes all the arrangements
315
- * to create another process, communicate the model info and tensors to the other
316
- * process and then retrieve the results of the other process
317
- * @param <T>
318
- * ImgLib2 data type of the input tensors
319
- * @param <R>
320
- * ImgLib2 data type of the output tensors, it can be the same as the input tensors' data type
321
- * @param inputTensors
322
- * tensors that are going to be run on the model
323
- * @param outputTensors
324
- * expected results of the model
325
- * @throws RunModelException if there is any issue running the model
326
- */
327
- public <T extends RealType <T > & NativeType <T >, R extends RealType <R > & NativeType <R >>
339
+ protected List <String > inferenceFromShmas (List <String > inputs ) throws IOException , RunModelException {
340
+ try (NDManager manager = NDManager .newBaseManager ()) {
341
+ // Create the input lists of engine tensors (NDArrays) and their
342
+ // corresponding names
343
+ NDList inputList = new NDList ();
344
+ for (String ee : inputs ) {
345
+ Map <String , Object > decoded = Types .decode (ee );
346
+ SharedMemoryArray shma = SharedMemoryArray .read ((String ) decoded .get (MEM_NAME_KEY ));
347
+ NDArray inT = TensorBuilder .build (shma , manager );
348
+ if (PlatformDetection .isWindows ()) shma .close ();
349
+ inputList .add (inT );
350
+ }
351
+ // Run model
352
+ Predictor <NDList , NDList > predictor = model .newPredictor ();
353
+ NDList outputNDArrays = predictor .predict (inputList );
354
+
355
+ shmaNamesList = new ArrayList <String >();
356
+ for (int i = 0 ; i < outputNDArrays .size (); i ++) {
357
+ String name = SharedMemoryArray .createShmName ();
358
+ ShmBuilder .build (outputNDArrays .get (i ), name , false );
359
+ shmaNamesList .add (name );
360
+ }
361
+ return shmaNamesList ;
362
+ }
363
+ catch (TranslateException e ) {
364
+ throw new RunModelException (Types .stackTrace (e ));
365
+ }
366
+ }
367
+
368
+ private <T extends RealType <T > & NativeType <T >, R extends RealType <R > & NativeType <R >>
328
369
void runInterprocessing (List <Tensor <T >> inputTensors , List <Tensor <R >> outputTensors ) throws RunModelException {
329
370
shmaInputList = new ArrayList <SharedMemoryArray >();
330
371
shmaOutputList = new ArrayList <SharedMemoryArray >();
@@ -335,7 +376,7 @@ void runInterprocessing(List<Tensor<T>> inputTensors, List<Tensor<R>> outputTens
335
376
args .put ("outputs" , encOuts );
336
377
337
378
try {
338
- Task task = runner .task ("inference " , args );
379
+ Task task = runner .task ("run " , args );
339
380
task .waitFor ();
340
381
if (task .status == TaskStatus .CANCELED )
341
382
throw new RuntimeException ();
@@ -365,6 +406,76 @@ else if (task.status == TaskStatus.CRASHED) {
365
406
}
366
407
closeShmas ();
367
408
}
409
+
410
+ private <T extends RealType <T > & NativeType <T >, R extends RealType <R > & NativeType <R >>
411
+ List <RandomAccessibleInterval <R >> runInterprocessing (List <RandomAccessibleInterval <T >> inputs ) throws RunModelException {
412
+ shmaInputList = new ArrayList <SharedMemoryArray >();
413
+ List <String > encIns = new ArrayList <String >();
414
+ Gson gson = new Gson ();
415
+ for (RandomAccessibleInterval <T > tt : inputs ) {
416
+ SharedMemoryArray shma = SharedMemoryArray .createSHMAFromRAI (tt , false , true );
417
+ shmaInputList .add (shma );
418
+ HashMap <String , Object > map = new HashMap <String , Object >();
419
+ map .put (SHAPE_KEY , tt .dimensionsAsLongArray ());
420
+ map .put (DTYPE_KEY , CommonUtils .getDataTypeFromRAI (tt ));
421
+ map .put (IS_INPUT_KEY , true );
422
+ map .put (MEM_NAME_KEY , shma .getName ());
423
+ encIns .add (gson .toJson (map ));
424
+ }
425
+ LinkedHashMap <String , Object > args = new LinkedHashMap <String , Object >();
426
+ args .put ("inputs" , encIns );
427
+
428
+ try {
429
+ Task task = runner .task ("inference" , args );
430
+ task .waitFor ();
431
+ if (task .status == TaskStatus .CANCELED )
432
+ throw new RuntimeException ();
433
+ else if (task .status == TaskStatus .FAILED )
434
+ throw new RuntimeException (task .error );
435
+ else if (task .status == TaskStatus .CRASHED ) {
436
+ this .runner .close ();
437
+ runner = null ;
438
+ throw new RuntimeException (task .error );
439
+ } else if (task .outputs == null )
440
+ throw new RuntimeException ("No outputs generated" );
441
+ List <String > outputs = (List <String >) task .outputs .get ("encoded" );
442
+ List <RandomAccessibleInterval <R >> rais = new ArrayList <RandomAccessibleInterval <R >>();
443
+ for (String out : outputs ) {
444
+ String name = (String ) Types .decode (out ).get (MEM_NAME_KEY );
445
+ SharedMemoryArray shm = SharedMemoryArray .read (name );
446
+ RandomAccessibleInterval <R > rai = shm .getSharedRAI ();
447
+ rais .add (Tensor .createCopyOfRaiInWantedDataType (Cast .unchecked (rai ), Util .getTypeFromInterval (Cast .unchecked (rai ))));
448
+ shm .close ();
449
+ }
450
+ closeShmas ();
451
+ return rais ;
452
+ } catch (Exception e ) {
453
+ closeShmas ();
454
+ if (e instanceof RunModelException )
455
+ throw (RunModelException ) e ;
456
+ throw new RunModelException (Types .stackTrace (e ));
457
+ }
458
+ }
459
+
460
+ private void closeInterprocess () throws RunModelException {
461
+ try {
462
+ Task task = runner .task ("closeTensors" );
463
+ task .waitFor ();
464
+ if (task .status == TaskStatus .CANCELED )
465
+ throw new RuntimeException ();
466
+ else if (task .status == TaskStatus .FAILED )
467
+ throw new RuntimeException (task .error );
468
+ else if (task .status == TaskStatus .CRASHED ) {
469
+ this .runner .close ();
470
+ runner = null ;
471
+ throw new RuntimeException (task .error );
472
+ }
473
+ } catch (Exception e ) {
474
+ if (e instanceof RunModelException )
475
+ throw (RunModelException ) e ;
476
+ throw new RunModelException (Types .stackTrace (e ));
477
+ }
478
+ }
368
479
369
480
/**
370
481
* Create the list a list of output tensors agnostic to the Deep Learning
@@ -391,7 +502,19 @@ void fillOutputTensors(NDList outputNDArrays,
391
502
}
392
503
}
393
504
394
- private void closeShmas () {
505
+ protected void closeFromInterp () {
506
+ if (!PlatformDetection .isWindows ())
507
+ return ;
508
+ this .shmaNamesList .stream ().forEach (nn -> {
509
+ try {
510
+ SharedMemoryArray .read (nn ).close ();
511
+ } catch (IOException e ) {
512
+ e .printStackTrace ();
513
+ }
514
+ });
515
+ }
516
+
517
+ private void closeShmas () throws RunModelException {
395
518
shmaInputList .forEach (shm -> {
396
519
try { shm .close (); } catch (IOException e1 ) { e1 .printStackTrace ();}
397
520
});
@@ -400,6 +523,7 @@ private void closeShmas() {
400
523
try { shm .close (); } catch (IOException e1 ) { e1 .printStackTrace ();}
401
524
});
402
525
shmaOutputList = null ;
526
+ closeInterprocess ();
403
527
}
404
528
405
529
0 commit comments