Skip to content

Commit 1246da1

Browse files
committed
try saveing more along the way
1 parent 1d90073 commit 1246da1

File tree

2 files changed

+4
-1
lines changed

2 files changed

+4
-1
lines changed

src/main/java/io/bioimage/modelrunner/pytorch/PytorchInterface.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -275,6 +275,7 @@ protected void runFromShmas(List<String> inputs, List<String> outputs) throws IO
275275
for (String ee : inputs) {
276276
Map<String, Object> decoded = Types.decode(ee);
277277
SharedMemoryArray shma = SharedMemoryArray.read((String) decoded.get(MEM_NAME_KEY));
278+
DecodeNumpy.saveNpy("/home/carlos/git/mm_in.npy", Cast.unchecked(shma.getSharedRAI()));
278279
NDArray inT = TensorBuilder.build(shma, manager);
279280
if (PlatformDetection.isWindows()) shma.close();
280281
inputList.add(inT);
@@ -339,7 +340,6 @@ else if (task.status == TaskStatus.CRASHED) {
339340
shmaOutputList.add(shm);
340341
}
341342
RandomAccessibleInterval<?> rai = shm.getSharedRAI();
342-
DecodeNumpy.saveNpy("/home/carlos/git/mm.npy", Cast.unchecked(rai));
343343
outputTensors.get(i).setData(Tensor.createCopyOfRaiInWantedDataType(Cast.unchecked(rai), Util.getTypeFromInterval(Cast.unchecked(rai))));
344344
}
345345
} catch (Exception e) {

src/main/java/io/bioimage/modelrunner/pytorch/shm/ShmBuilder.java

+3
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
*/
2121
package io.bioimage.modelrunner.pytorch.shm;
2222

23+
import io.bioimage.modelrunner.numpy.DecodeNumpy;
2324
import io.bioimage.modelrunner.system.PlatformDetection;
2425
import io.bioimage.modelrunner.tensor.shm.SharedMemoryArray;
2526
import io.bioimage.modelrunner.utils.CommonUtils;
@@ -33,6 +34,7 @@
3334
import net.imglib2.type.numeric.integer.UnsignedByteType;
3435
import net.imglib2.type.numeric.real.DoubleType;
3536
import net.imglib2.type.numeric.real.FloatType;
37+
import net.imglib2.util.Cast;
3638

3739
/**
3840
* A utility class that converts {@link NDArray}s into {@link SharedMemoryArray}s for
@@ -114,6 +116,7 @@ private static void buildFromTensorFloat(NDArray tensor, String memoryName) thro
114116

115117
SharedMemoryArray shma = SharedMemoryArray.readOrCreate(memoryName, arrayShape, new FloatType(), false, true);
116118
shma.getDataBufferNoHeader().put(tensor.toByteArray());
119+
DecodeNumpy.saveNpy("/home/carlos/git/mm_out.npy", Cast.unchecked(shma.getSharedRAI()));
117120
if (PlatformDetection.isWindows()) shma.close();
118121
}
119122

0 commit comments

Comments
 (0)