From fa1894537821bdbf49ce859085e580d3831b3d5b Mon Sep 17 00:00:00 2001 From: John Bogovic Date: Thu, 6 Mar 2025 16:08:42 -0500 Subject: [PATCH 1/2] fix/test: DistanceTransform array out of bounds exception see #108 --- .../distance/DistanceTransform.java | 4 +- .../distance/DistanceTransformTest.java | 63 +++++++++++++++++++ 2 files changed, 65 insertions(+), 2 deletions(-) diff --git a/src/main/java/net/imglib2/algorithm/morphology/distance/DistanceTransform.java b/src/main/java/net/imglib2/algorithm/morphology/distance/DistanceTransform.java index 48d9182c4..cec18be9d 100644 --- a/src/main/java/net/imglib2/algorithm/morphology/distance/DistanceTransform.java +++ b/src/main/java/net/imglib2/algorithm/morphology/distance/DistanceTransform.java @@ -1187,7 +1187,7 @@ private static < T extends RealType< T >, U extends RealType< U > > void transfo final double sourceAtPosition = source.get( position ).getRealDouble(); double s = d.intersect( envelopeIndexAtK, source.get( envelopeIndexAtK ).getRealDouble(), position, sourceAtPosition, dim ); - for ( double envelopeValueAtK = envelopeIntersectLocation.get( k ).get(); s <= envelopeValueAtK; envelopeValueAtK = envelopeIntersectLocation.get( k ).get() ) + for ( double envelopeValueAtK = envelopeIntersectLocation.get( k ).get(); s <= envelopeValueAtK && k >= 1; envelopeValueAtK = envelopeIntersectLocation.get( k ).get() ) { --k; envelopeIndexAtK = lowerBoundDistanceIndex.get( k ).get(); @@ -1785,7 +1785,7 @@ private static < T extends RealType< T >, U extends RealType< U >, L extends Int final double sourceAtPosition = source.get( position ).getRealDouble(); double s = d.intersect( envelopeIndexAtK, source.get( envelopeIndexAtK ).getRealDouble(), position, sourceAtPosition, dim ); - for ( double envelopeValueAtK = envelopeIntersectLocation.get( k ).get(); s <= envelopeValueAtK; envelopeValueAtK = envelopeIntersectLocation.get( k ).get() ) + for ( double envelopeValueAtK = envelopeIntersectLocation.get( k ).get(); s <= envelopeValueAtK && k >= 1; envelopeValueAtK = envelopeIntersectLocation.get( k ).get() ) { --k; envelopeIndexAtK = lowerBoundDistanceIndex.get( k ).get(); diff --git a/src/test/java/net/imglib2/algorithm/morphology/distance/DistanceTransformTest.java b/src/test/java/net/imglib2/algorithm/morphology/distance/DistanceTransformTest.java index 9129ad728..79092b3b7 100644 --- a/src/test/java/net/imglib2/algorithm/morphology/distance/DistanceTransformTest.java +++ b/src/test/java/net/imglib2/algorithm/morphology/distance/DistanceTransformTest.java @@ -34,6 +34,7 @@ package net.imglib2.algorithm.morphology.distance; +import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; @@ -501,6 +502,68 @@ public void testLabelPropagation() } + @Test + public void testWeights() + { + final double tolerance = 1e-9; + final double M = Double.MAX_VALUE; + + double[] data = new double[] { + 0, M, M, M, + M, 0, M, M + }; + ArrayImg dists = ArrayImgs.doubles(data, 4, 2); + DistanceTransform.transform(dists, new EuclidianDistanceAnisotropic(0.1, 1.0)); + + final double[] expected = new double[] {0.0, 0.1, 0.4, 0.9, 0.1, 0.0, 0.1, 0.4 }; + assertArrayEquals( expected, dists.getAccessType().getCurrentStorageArray(), tolerance ); + } + + @Test + public void testWeightsL1() + { + final double tolerance = 1e-9; + final double M = Double.MAX_VALUE; + + double[] data = new double[] { + 0, M, M, M, + M, 0, M, M + }; + ArrayImg dists = ArrayImgs.doubles(data, 4, 2); + DistanceTransform.transform(dists, DISTANCE_TYPE.L1, 0.1, 1.0 ); + + final double[] expected = new double[] {0.0, 0.1, 0.2, 0.3, 0.1, 0.0, 0.1, 0.2 }; + assertArrayEquals( expected, dists.getAccessType().getCurrentStorageArray(), tolerance ); + } + + @Test + public void testLabelPropagationWeights() + { + final long[] labelData = new long[]{ + 1, 0, 0, 0, + 0, 2, 0, 0 }; + + final long[] expectedYClose = new long[] { + 1, 2, 2, 2, + 1, 2, 2, 2}; + + final long[] expectedXClose = new long[] { + 1, 1, 1, 1, + 2, 2, 2, 2}; + + double rx = 99.0; + double ry = 0.01; + ArrayImg labels = ArrayImgs.longs(Arrays.copyOf(labelData, 8), 4, 2); + DistanceTransform.voronoiDistanceTransform(labels, 0, rx, ry); + assertArrayEquals( expectedYClose,labels.getAccessType().getCurrentStorageArray()); + + rx = 0.01; + ry = 99.0; + ArrayImg labels2 = ArrayImgs.longs(Arrays.copyOf(labelData, 8), 4, 2); + DistanceTransform.voronoiDistanceTransform(labels2, 0, rx, ry); + assertArrayEquals( expectedXClose, labels2.getAccessType().getCurrentStorageArray()); + } + /** * Creates an label and distances images with the requested number of dimensions (ndims), * and places nLabels points with non-zero label. Checks that the propagated labels correctly From e6a50a26ee984aa52abf617b862234565b9a68e1 Mon Sep 17 00:00:00 2001 From: John Bogovic Date: Tue, 18 Mar 2025 15:18:46 -0400 Subject: [PATCH 2/2] perf: speed up transformAlongDimension methods * both for distance and voronoiDistance transforms * use primitive arrays rather than RealComposites * get between 40% and 100% speedup --- .../distance/DistanceTransform.java | 207 +++++++++++++++++- .../distance/DistanceTransformBenchmark.java | 115 ++++++++++ 2 files changed, 317 insertions(+), 5 deletions(-) create mode 100644 src/test/java/net/imglib2/algorithm/morphology/distance/DistanceTransformBenchmark.java diff --git a/src/main/java/net/imglib2/algorithm/morphology/distance/DistanceTransform.java b/src/main/java/net/imglib2/algorithm/morphology/distance/DistanceTransform.java index cec18be9d..bf01f357e 100644 --- a/src/main/java/net/imglib2/algorithm/morphology/distance/DistanceTransform.java +++ b/src/main/java/net/imglib2/algorithm/morphology/distance/DistanceTransform.java @@ -355,9 +355,9 @@ public static < T extends RealType< T >, U extends RealType< U >, V extends Real final int nTasks, final double... weights ) throws InterruptedException, ExecutionException { - final boolean isIsotropic = weights.length <= 1; - final double[] w = weights.length == source.numDimensions() ? weights : DoubleStream.generate( () -> weights.length == 0 ? 1.0 : weights[ 0 ] ).limit( source.numDimensions() ).toArray(); + final double[] w = weights.length == source.numDimensions() ? weights + : DoubleStream.generate(() -> weights.length == 0 ? 1.0 : weights[0]).limit(source.numDimensions()).toArray(); switch ( distanceType ) { @@ -1105,6 +1105,19 @@ private static < T extends RealType< T >, U extends RealType< U > > void transfo final RandomAccessibleInterval< U > target, final Distance d, final int dim ) + { + final long size = target.dimension( dim ); + if( size > Integer.MAX_VALUE ) + transformAlongDimensionComposite(source, target, d, dim); + else + transformAlongDimensionPrimitive(source, target, d, dim); + } + + private static < T extends RealType< T >, U extends RealType< U > > void transformAlongDimensionComposite( + final RandomAccessible< T > source, + final RandomAccessibleInterval< U > target, + final Distance d, + final int dim ) { final int lastDim = target.numDimensions() - 1; final long size = target.dimension( dim ); @@ -1129,6 +1142,35 @@ private static < T extends RealType< T >, U extends RealType< U > > void transfo } } + private static < T extends RealType< T >, U extends RealType< U > > void transformAlongDimensionPrimitive( + final RandomAccessible< T > source, + final RandomAccessibleInterval< U > target, + final Distance d, + final int dim ) + { + final int lastDim = target.numDimensions() - 1; + final long size = target.dimension( dim ); + final RealComposite< DoubleType > tmp = Views.collapseReal( createAppropriateOneDimensionalImage( size, new DoubleType() ) ).randomAccess().get(); + + // do not permute if we already work on last dimension + final Cursor< RealComposite< T > > s = Views.flatIterable( Views.collapseReal( dim == lastDim ? Views.interval( source, target ) : Views.permute( Views.interval( source, target ), dim, lastDim ) ) ).cursor(); + final Cursor< RealComposite< U > > t = Views.flatIterable( Views.collapseReal( dim == lastDim ? target : Views.permute( target, dim, lastDim ) ) ).cursor(); + + final long[] lowerBoundDistanceIndex = new long[(int)size]; + final double[] envelopeIntersectLocation = new double[(int)size + 1]; + + while ( s.hasNext() ) + { + final RealComposite< T > sourceComp = s.next(); + final RealComposite< U > targetComp = t.next(); + for ( long i = 0; i < size; ++i ) + { + tmp.get( i ).set( sourceComp.get( i ).getRealDouble() ); + } + transformSingleColumnPrimitive( tmp, targetComp, lowerBoundDistanceIndex, envelopeIntersectLocation, d, dim, size ); + } + } + private static < T extends RealType< T >, U extends RealType< U > > void transformAlongDimensionParallel( final RandomAccessible< T > source, final RandomAccessibleInterval< U > target, @@ -1167,6 +1209,53 @@ private static < T extends RealType< T >, U extends RealType< U > > void transfo invokeAllAndWait( es, tasks ); } + private static < T extends RealType< T >, U extends RealType< U > > void transformSingleColumnPrimitive( + final RealComposite< T > source, + final RealComposite< U > target, + final long[] lowerBoundDistanceIndex, + final double[] envelopeIntersectLocation, + final Distance d, + final int dim, + final long size ) + { + int k = 0; + + lowerBoundDistanceIndex[0] = 0; + envelopeIntersectLocation[0] = Double.NEGATIVE_INFINITY; + envelopeIntersectLocation[1] = Double.POSITIVE_INFINITY; + for ( long position = 1; position < size; ++position ) + { + long envelopeIndexAtK = lowerBoundDistanceIndex[k]; + final double sourceAtPosition = source.get( position ).getRealDouble(); + double s = d.intersect( envelopeIndexAtK, source.get( envelopeIndexAtK ).getRealDouble(), position, sourceAtPosition, dim ); + + for ( double envelopeValueAtK = envelopeIntersectLocation[k]; s <= envelopeValueAtK && k >= 1; envelopeValueAtK = envelopeIntersectLocation[k] ) + { + --k; + envelopeIndexAtK = lowerBoundDistanceIndex[k]; + s = d.intersect( envelopeIndexAtK, source.get( envelopeIndexAtK ).getRealDouble(), position, sourceAtPosition, dim ); + } + ++k; + lowerBoundDistanceIndex[k] = position; + envelopeIntersectLocation[k] = s; + envelopeIntersectLocation[k + 1] = Double.POSITIVE_INFINITY; + } + + k = 0; + + for ( long position = 0; position < size; ++position ) + { + while ( envelopeIntersectLocation[ k + 1 ] < position ) + { + ++k; + } + final long envelopeIndexAtK = lowerBoundDistanceIndex[k]; + // copy necessary because of the following line, access to source + // after write to source -> source and target cannot be the same + target.get( position ).setReal( d.evaluate( position, envelopeIndexAtK, source.get( envelopeIndexAtK ).getRealDouble(), dim ) ); + } + } + private static < T extends RealType< T >, U extends RealType< U > > void transformSingleColumn( final RealComposite< T > source, final RealComposite< U > target, @@ -1212,7 +1301,6 @@ private static < T extends RealType< T >, U extends RealType< U > > void transfo // after write to source -> source and target cannot be the same target.get( position ).setReal( d.evaluate( position, envelopeIndexAtK, source.get( envelopeIndexAtK ).getRealDouble(), dim ) ); } - } private static < T extends RealType< T >, U extends RealType< U > > void transformL1AlongDimension( @@ -1725,6 +1813,21 @@ private static < T extends RealType< T >, U extends RealType< U >, L extends Int final RandomAccessible< M > labelTarget, final Distance d, final int dim ) + { + final long size = target.dimension( dim ); + if( size > Integer.MAX_VALUE ) + transformAlongDimensionPropagateLabelsComposite(source, target, labelSource, labelTarget, d, dim); + else + transformAlongDimensionPropagateLabelsPrimitive(source, target, labelSource, labelTarget, d, dim); + } + + private static < T extends RealType< T >, U extends RealType< U >, L extends IntegerType< L >, M extends IntegerType< M > > void transformAlongDimensionPropagateLabelsComposite( + final RandomAccessible< T > source, + final RandomAccessibleInterval< U > target, + final RandomAccessible< L > labelSource, + final RandomAccessible< M > labelTarget, + final Distance d, + final int dim ) { final int lastDim = target.numDimensions() - 1; final long size = target.dimension( dim ); @@ -1759,11 +1862,11 @@ private static < T extends RealType< T >, U extends RealType< U >, L extends Int tmp.get( i ).set( sourceComp.get( i ).getRealDouble() ); tmpLabel.get( i ).setInteger( labelComp.get( i ).getIntegerLong() ); } - transformSingleColumnPropagateLabels( tmp, targetComp, tmpLabel, labelTargetComp, lowerBoundDistanceIndex, envelopeIntersectLocation, d, dim, size ); + transformSingleColumnPropagateLabelsComposite( tmp, targetComp, tmpLabel, labelTargetComp, lowerBoundDistanceIndex, envelopeIntersectLocation, d, dim, size ); } } - private static < T extends RealType< T >, U extends RealType< U >, L extends IntegerType, M extends IntegerType > void transformSingleColumnPropagateLabels( + private static < T extends RealType< T >, U extends RealType< U >, L extends IntegerType, M extends IntegerType > void transformSingleColumnPropagateLabelsComposite( final RealComposite< T > source, final RealComposite< U > target, final RealComposite< L > labelsSource, @@ -1813,6 +1916,100 @@ private static < T extends RealType< T >, U extends RealType< U >, L extends Int } + private static < T extends RealType< T >, U extends RealType< U >, L extends IntegerType< L >, M extends IntegerType< M > > void transformAlongDimensionPropagateLabelsPrimitive( + final RandomAccessible< T > source, + final RandomAccessibleInterval< U > target, + final RandomAccessible< L > labelSource, + final RandomAccessible< M > labelTarget, + final Distance d, + final int dim ) + { + final int lastDim = target.numDimensions() - 1; + final int size = (int)target.dimension( dim ); + + final Img< DoubleType > tmpImg = createAppropriateOneDimensionalImage( size, new DoubleType() ); + final RealComposite< DoubleType > tmp = Views.collapseReal( tmpImg ).randomAccess().get(); + + final Img< L > tmpLabelImg = Util.getSuitableImgFactory( tmpImg, labelSource.getType() ).create( tmpImg ); + final RealComposite< L > tmpLabel = Views.collapseReal( tmpLabelImg ).randomAccess().get(); + + // do not permute if we already work on last dimension + final Cursor< RealComposite< T > > s = Views.flatIterable( Views.collapseReal( dim == lastDim ? Views.interval( source, target ) : Views.permute( Views.interval( source, target ), dim, lastDim ) ) ).cursor(); + final Cursor< RealComposite< U > > t = Views.flatIterable( Views.collapseReal( dim == lastDim ? target : Views.permute( target, dim, lastDim ) ) ).cursor(); + + final Cursor< RealComposite< L > > ls = Views.flatIterable( + Views.collapseReal( dim == lastDim ? Views.interval( labelSource, target ) : Views.permute( Views.interval( labelSource, target ), dim, lastDim ) ) ).cursor(); + + final Cursor< RealComposite< M > > lt = Views.flatIterable( + Views.collapseReal( dim == lastDim ? Views.interval( labelTarget, target ) : Views.permute( Views.interval( labelTarget, target ), dim, lastDim ) ) ).cursor(); + + final long[] lowerBoundDistanceIndex = new long[size]; + final double[] envelopeIntersectLocation = new double[size+1]; + + while ( s.hasNext() ) + { + final RealComposite< T > sourceComp = s.next(); + final RealComposite< U > targetComp = t.next(); + final RealComposite< L > labelComp = ls.next(); + final RealComposite< M > labelTargetComp = lt.next(); + for ( long i = 0; i < size; ++i ) + { + tmp.get( i ).set( sourceComp.get( i ).getRealDouble() ); + tmpLabel.get( i ).setInteger( labelComp.get( i ).getIntegerLong() ); + } + transformSingleColumnPropagateLabelsPrimitive( tmp, targetComp, tmpLabel, labelTargetComp, lowerBoundDistanceIndex, envelopeIntersectLocation, d, dim, size ); + } + } + + private static < T extends RealType< T >, U extends RealType< U >, L extends IntegerType, M extends IntegerType > void transformSingleColumnPropagateLabelsPrimitive( + final RealComposite< T > source, + final RealComposite< U > target, + final RealComposite< L > labelsSource, + final RealComposite< M > labelsResult, + final long[] lowerBoundDistanceIndex, + final double[] envelopeIntersectLocation, + final Distance d, + final int dim, + final long size ) + { + int k = 0; + + lowerBoundDistanceIndex[0] = 0; + envelopeIntersectLocation[0] = Double.NEGATIVE_INFINITY; + envelopeIntersectLocation[1] = Double.POSITIVE_INFINITY; + for ( long position = 1; position < size; ++position ) + { + long envelopeIndexAtK = lowerBoundDistanceIndex[k]; + final double sourceAtPosition = source.get( position ).getRealDouble(); + double s = d.intersect( envelopeIndexAtK, source.get( envelopeIndexAtK ).getRealDouble(), position, sourceAtPosition, dim ); + + for ( double envelopeValueAtK = envelopeIntersectLocation[k]; s <= envelopeValueAtK && k >= 1; envelopeValueAtK = envelopeIntersectLocation[k] ) + { + --k; + envelopeIndexAtK = lowerBoundDistanceIndex[k]; + s = d.intersect( envelopeIndexAtK, source.get( envelopeIndexAtK ).getRealDouble(), position, sourceAtPosition, dim ); + } + ++k; + lowerBoundDistanceIndex[k] = position; + envelopeIntersectLocation[k] = s; + envelopeIntersectLocation[k + 1] = Double.POSITIVE_INFINITY; + } + + k = 0; + for ( long position = 0; position < size; ++position ) + { + while ( envelopeIntersectLocation[ k + 1 ] < position ) + { + ++k; + } + final long envelopeIndexAtK = lowerBoundDistanceIndex[k]; + // copy necessary because of the following line, access to source + // after write to source -> source and target cannot be the same + target.get( position ).setReal( d.evaluate( position, envelopeIndexAtK, source.get( envelopeIndexAtK ).getRealDouble(), dim ) ); + labelsResult.get( position ).setInteger( labelsSource.get( envelopeIndexAtK ).getIntegerLong() ); + } + } + private static < T extends RealType< T >, U extends RealType< U >, L extends IntegerType< L >, M extends IntegerType< M > > void transformAlongDimensionPropagateLabelsParallel( final RandomAccessible< T > source, final RandomAccessibleInterval< U > target, diff --git a/src/test/java/net/imglib2/algorithm/morphology/distance/DistanceTransformBenchmark.java b/src/test/java/net/imglib2/algorithm/morphology/distance/DistanceTransformBenchmark.java new file mode 100644 index 000000000..476320f03 --- /dev/null +++ b/src/test/java/net/imglib2/algorithm/morphology/distance/DistanceTransformBenchmark.java @@ -0,0 +1,115 @@ +package net.imglib2.algorithm.morphology.distance; + +import java.util.Arrays; +import java.util.Random; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; + +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.BenchmarkMode; +import org.openjdk.jmh.annotations.Mode; +import org.openjdk.jmh.annotations.OutputTimeUnit; +import org.openjdk.jmh.annotations.Scope; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.State; +import org.openjdk.jmh.runner.Runner; +import org.openjdk.jmh.runner.RunnerException; +import org.openjdk.jmh.runner.options.Options; +import org.openjdk.jmh.runner.options.OptionsBuilder; +import org.openjdk.jmh.runner.options.TimeValue; + +import net.imglib2.img.array.ArrayImg; +import net.imglib2.img.array.ArrayImgs; +import net.imglib2.type.numeric.integer.LongType; +import net.imglib2.type.numeric.real.DoubleType; + +@State( Scope.Benchmark ) +public class DistanceTransformBenchmark { + + private int numPoints = 4000; + private long[] dim = new long[] {256, 256, 128}; + DistanceTransform.DISTANCE_TYPE distanceType = DistanceTransform.DISTANCE_TYPE.EUCLIDIAN; + + Random random; + private ArrayImg distanceImg; + private ArrayImg labelImg; + ExecutorService es; + int nThreads; + int numTasks; + + @Setup + public void setup() + { + random = new Random(0); + + final int N = Arrays.stream(dim).mapToInt( i -> (int)i).reduce(1, (x,y) -> x*y); + + final double[] initValues = new double[N]; + Arrays.fill(initValues, Double.MAX_VALUE); + for( int i = 0; i < numPoints; i++ ) { + initValues[random.nextInt(N)] = 0.0; + } + distanceImg = ArrayImgs.doubles(initValues, dim); + + final long[] initLabels = new long[N]; + for( int i = 0; i < numPoints; i++ ) { + initLabels[random.nextInt(N)] = 1 + random.nextInt(numPoints); + } + labelImg = ArrayImgs.longs(initLabels, dim); + + int nThreads = 1; + es = Executors.newFixedThreadPool(nThreads); + numTasks = 2*nThreads; + } + + @Benchmark + @BenchmarkMode( Mode.AverageTime ) + @OutputTimeUnit( TimeUnit.MILLISECONDS ) + public void distanceTransform() + { + if( nThreads == 1 ) { + DistanceTransform.transform(distanceImg, distanceType); + } else { + try { + DistanceTransform.transform(distanceImg, distanceType, es, numTasks); + } catch (InterruptedException e) { + e.printStackTrace(); + } catch (ExecutionException e) { + e.printStackTrace(); + } + } + } + + @Benchmark + @BenchmarkMode( Mode.AverageTime ) + @OutputTimeUnit( TimeUnit.MILLISECONDS ) + public void voronoiDistanceTransform() + { + if( nThreads == 1 ) { + DistanceTransform.voronoiDistanceTransform(labelImg, 0l); + } else { + try { + DistanceTransform.voronoiDistanceTransform(labelImg, 0l, es, numTasks); + } catch (InterruptedException e) { + e.printStackTrace(); + } catch (ExecutionException e) { + e.printStackTrace(); + } + } + } + + public static void main( final String... args ) throws RunnerException + { + final Options opt = new OptionsBuilder() + .include( DistanceTransformBenchmark.class.getSimpleName() ) + .forks( 0 ) + .warmupIterations( 5 ) + .measurementIterations( 25 ) + .measurementTime(TimeValue.seconds( 2 )) + .warmupTime(TimeValue.seconds( 2 )) + .build(); + new Runner( opt ).run(); + } +}