@@ -237,6 +237,52 @@ void run(List<Tensor<T>> inputTensors, List<Tensor<R>> outputTensors) throws Run
237
237
inputsVector .close ();
238
238
inputsVector .deallocate ();
239
239
}
240
+
241
+ /**
242
+ * {@inheritDoc}
243
+ *
244
+ * Run a Pytorch model using JavaCpp on the data provided by the {@link Tensor} input list
245
+ * and modifies the output list with the results obtained
246
+ *
247
+ */
248
+ public <T extends RealType <T > & NativeType <T >, R extends RealType <R > & NativeType <R >>
249
+ List <RandomAccessibleInterval <R >> inference (List <RandomAccessibleInterval <T >> inputs ) throws RunModelException {
250
+ if (interprocessing ) {
251
+ return runInterprocessing (inputs );
252
+ }
253
+ IValueVector inputsVector = new IValueVector ();
254
+ for (RandomAccessibleInterval <T > tt : inputs ) {
255
+ inputsVector .put (new IValue (JavaCPPTensorBuilder .build (tt )));
256
+ }
257
+ // Run model
258
+ model .eval ();
259
+ IValue output = model .forward (inputsVector );
260
+ TensorVector outputTensorVector = null ;
261
+ if (output .isTensorList ()) {
262
+ outputTensorVector = output .toTensorVector ();
263
+ } else {
264
+ outputTensorVector = new TensorVector ();
265
+ outputTensorVector .put (output .toTensor ());
266
+ }
267
+ // Fill the agnostic output tensors list with data from the inference result
268
+ List <RandomAccessibleInterval <R >> rais = new ArrayList <RandomAccessibleInterval <R >>();
269
+ for (int i = 0 ; i < outputTensorVector .size (); i ++) {
270
+ rais .add (ImgLib2Builder .build (outputTensorVector .get (i )));
271
+ outputTensorVector .get (i ).close ();
272
+ outputTensorVector .get (i ).deallocate ();
273
+ }
274
+ outputTensorVector .close ();
275
+ outputTensorVector .deallocate ();
276
+ output .close ();
277
+ output .deallocate ();
278
+ for (int i = 0 ; i < inputsVector .size (); i ++) {
279
+ inputsVector .get (i ).close ();
280
+ inputsVector .get (i ).deallocate ();
281
+ }
282
+ inputsVector .close ();
283
+ inputsVector .deallocate ();
284
+ return rais ;
285
+ }
240
286
241
287
protected void runFromShmas (List <String > inputs , List <String > outputs ) throws IOException {
242
288
IValueVector inputsVector = new IValueVector ();
@@ -276,17 +322,46 @@ protected void runFromShmas(List<String> inputs, List<String> outputs) throws IO
276
322
inputsVector .deallocate ();
277
323
}
278
324
279
- /**
280
- * MEthod only used in MacOS Intel and Windows systems that makes all the arrangements
281
- * to create another process, communicate the model info and tensors to the other
282
- * process and then retrieve the results of the other process
283
- * @param inputTensors
284
- * tensors that are going to be run on the model
285
- * @param outputTensors
286
- * expected results of the model
287
- * @throws RunModelException if there is any issue running the model
288
- */
289
- public <T extends RealType <T > & NativeType <T >, R extends RealType <R > & NativeType <R >>
325
/**
 * Runs the model on inputs received through shared memory and publishes the
 * outputs in newly created shared-memory blocks.
 * Each element of {@code inputs} is an encoded map that contains (at least)
 * the name of a shared-memory block under {@code MEM_NAME_KEY}.
 *
 * @param inputs encoded descriptors of the shared-memory blocks holding the input tensors
 * @return the names of the shared-memory blocks where the outputs were written
 * @throws IOException if a shared-memory block cannot be read or created
 * @throws RunModelException if the model cannot be run
 */
protected List<String> inferenceFromShmas(List<String> inputs) throws IOException, RunModelException {
	IValueVector inputsVector = new IValueVector();
	for (String ee : inputs) {
		Map<String, Object> decoded = Types.decode(ee);
		SharedMemoryArray shma = SharedMemoryArray.read((String) decoded.get(MEM_NAME_KEY));
		org.bytedeco.pytorch.Tensor inT = TensorBuilder.build(shma);
		inputsVector.put(new IValue(inT));
		// On Windows the mapping is released as soon as the tensor is built; on
		// other platforms it is kept open — presumably closed elsewhere, TODO confirm
		if (PlatformDetection.isWindows()) shma.close();
	}
	// Run model
	model.eval();
	IValue output = model.forward(inputsVector);
	// Normalize the result to a TensorVector whether the model returns a
	// tensor list or a single tensor
	TensorVector outputTensorVector = null;
	if (output.isTensorList()) {
		outputTensorVector = output.toTensorVector();
	} else {
		outputTensorVector = new TensorVector();
		outputTensorVector.put(output.toTensor());
	}

	// Copy every output tensor into a fresh shared-memory block and record its
	// name; shmaNamesList is also used later by closeFromInterp() for cleanup
	shmaNamesList = new ArrayList<String>();
	for (int i = 0; i < outputTensorVector.size(); i++) {
		String name = SharedMemoryArray.createShmName();
		ShmBuilder.build(outputTensorVector.get(i), name, false);
		shmaNamesList.add(name);
	}
	// Release the native (JavaCPP) memory backing outputs and inputs
	outputTensorVector.close();
	outputTensorVector.deallocate();
	output.close();
	output.deallocate();
	for (int i = 0; i < inputsVector.size(); i++) {
		inputsVector.get(i).close();
		inputsVector.get(i).deallocate();
	}
	inputsVector.close();
	inputsVector.deallocate();
	return shmaNamesList;
}
363
+
364
+ private <T extends RealType <T > & NativeType <T >, R extends RealType <R > & NativeType <R >>
290
365
void runInterprocessing (List <Tensor <T >> inputTensors , List <Tensor <R >> outputTensors ) throws RunModelException {
291
366
shmaInputList = new ArrayList <SharedMemoryArray >();
292
367
shmaOutputList = new ArrayList <SharedMemoryArray >();
@@ -297,7 +372,7 @@ void runInterprocessing(List<Tensor<T>> inputTensors, List<Tensor<R>> outputTens
297
372
args .put ("outputs" , encOuts );
298
373
299
374
try {
300
- Task task = runner .task ("inference " , args );
375
+ Task task = runner .task ("run " , args );
301
376
task .waitFor ();
302
377
if (task .status == TaskStatus .CANCELED )
303
378
throw new RuntimeException ();
@@ -328,7 +403,89 @@ else if (task.status == TaskStatus.CRASHED) {
328
403
closeShmas ();
329
404
}
330
405
331
- private void closeShmas () {
406
+ private <T extends RealType <T > & NativeType <T >, R extends RealType <R > & NativeType <R >>
407
+ List <RandomAccessibleInterval <R >> runInterprocessing (List <RandomAccessibleInterval <T >> inputs ) throws RunModelException {
408
+ shmaInputList = new ArrayList <SharedMemoryArray >();
409
+ List <String > encIns = new ArrayList <String >();
410
+ Gson gson = new Gson ();
411
+ for (RandomAccessibleInterval <T > tt : inputs ) {
412
+ SharedMemoryArray shma = SharedMemoryArray .createSHMAFromRAI (tt , false , true );
413
+ shmaInputList .add (shma );
414
+ HashMap <String , Object > map = new HashMap <String , Object >();
415
+ map .put (SHAPE_KEY , tt .dimensionsAsLongArray ());
416
+ map .put (DTYPE_KEY , CommonUtils .getDataTypeFromRAI (tt ));
417
+ map .put (IS_INPUT_KEY , true );
418
+ map .put (MEM_NAME_KEY , shma .getName ());
419
+ encIns .add (gson .toJson (map ));
420
+ }
421
+ LinkedHashMap <String , Object > args = new LinkedHashMap <String , Object >();
422
+ args .put ("inputs" , encIns );
423
+
424
+ try {
425
+ Task task = runner .task ("inference" , args );
426
+ task .waitFor ();
427
+ if (task .status == TaskStatus .CANCELED )
428
+ throw new RuntimeException ();
429
+ else if (task .status == TaskStatus .FAILED )
430
+ throw new RuntimeException (task .error );
431
+ else if (task .status == TaskStatus .CRASHED ) {
432
+ this .runner .close ();
433
+ runner = null ;
434
+ throw new RuntimeException (task .error );
435
+ } else if (task .outputs == null )
436
+ throw new RuntimeException ("No outputs generated" );
437
+ List <String > outputs = (List <String >) task .outputs .get ("encoded" );
438
+ List <RandomAccessibleInterval <R >> rais = new ArrayList <RandomAccessibleInterval <R >>();
439
+ for (String out : outputs ) {
440
+ String name = (String ) Types .decode (out ).get (MEM_NAME_KEY );
441
+ SharedMemoryArray shm = SharedMemoryArray .read (name );
442
+ RandomAccessibleInterval <R > rai = shm .getSharedRAI ();
443
+ rais .add (Tensor .createCopyOfRaiInWantedDataType (Cast .unchecked (rai ), Util .getTypeFromInterval (Cast .unchecked (rai ))));
444
+ shm .close ();
445
+ }
446
+ closeShmas ();
447
+ return rais ;
448
+ } catch (Exception e ) {
449
+ closeShmas ();
450
+ if (e instanceof RunModelException )
451
+ throw (RunModelException ) e ;
452
+ throw new RunModelException (Types .stackTrace (e ));
453
+ }
454
+ }
455
+
456
+ private void closeInterprocess () throws RunModelException {
457
+ try {
458
+ Task task = runner .task ("closeTensors" );
459
+ task .waitFor ();
460
+ if (task .status == TaskStatus .CANCELED )
461
+ throw new RuntimeException ();
462
+ else if (task .status == TaskStatus .FAILED )
463
+ throw new RuntimeException (task .error );
464
+ else if (task .status == TaskStatus .CRASHED ) {
465
+ this .runner .close ();
466
+ runner = null ;
467
+ throw new RuntimeException (task .error );
468
+ }
469
+ } catch (Exception e ) {
470
+ if (e instanceof RunModelException )
471
+ throw (RunModelException ) e ;
472
+ throw new RunModelException (Types .stackTrace (e ));
473
+ }
474
+ }
475
+
476
+ protected void closeFromInterp () {
477
+ if (!PlatformDetection .isWindows ())
478
+ return ;
479
+ this .shmaNamesList .stream ().forEach (nn -> {
480
+ try {
481
+ SharedMemoryArray .read (nn ).close ();
482
+ } catch (IOException e ) {
483
+ e .printStackTrace ();
484
+ }
485
+ });
486
+ }
487
+
488
+ private void closeShmas () throws RunModelException {
332
489
shmaInputList .forEach (shm -> {
333
490
try { shm .close (); } catch (IOException e1 ) { e1 .printStackTrace ();}
334
491
});
@@ -337,6 +494,8 @@ private void closeShmas() {
337
494
try { shm .close (); } catch (IOException e1 ) { e1 .printStackTrace ();}
338
495
});
339
496
shmaOutputList = null ;
497
+ if (interprocessing )
498
+ closeInterprocess ();
340
499
}
341
500
342
501
/**
0 commit comments