Skip to content

Commit e6a50a2

Browse files
committed
perf: speed up transformAlongDimension methods
* both for distance and voronoiDistance transforms * use primitive arrays rather than RealComposites * get between 40% and 100% speedup
1 parent fa18945 commit e6a50a2

File tree

2 files changed

+317
-5
lines changed

2 files changed

+317
-5
lines changed

src/main/java/net/imglib2/algorithm/morphology/distance/DistanceTransform.java

+202-5
Original file line numberDiff line numberDiff line change
@@ -355,9 +355,9 @@ public static < T extends RealType< T >, U extends RealType< U >, V extends Real
355355
final int nTasks,
356356
final double... weights ) throws InterruptedException, ExecutionException
357357
{
358-
359358
final boolean isIsotropic = weights.length <= 1;
360-
final double[] w = weights.length == source.numDimensions() ? weights : DoubleStream.generate( () -> weights.length == 0 ? 1.0 : weights[ 0 ] ).limit( source.numDimensions() ).toArray();
359+
final double[] w = weights.length == source.numDimensions() ? weights
360+
: DoubleStream.generate(() -> weights.length == 0 ? 1.0 : weights[0]).limit(source.numDimensions()).toArray();
361361

362362
switch ( distanceType )
363363
{
@@ -1105,6 +1105,19 @@ private static < T extends RealType< T >, U extends RealType< U > > void transfo
11051105
final RandomAccessibleInterval< U > target,
11061106
final Distance d,
11071107
final int dim )
1108+
{
1109+
final long size = target.dimension( dim );
1110+
if( size > Integer.MAX_VALUE )
1111+
transformAlongDimensionComposite(source, target, d, dim);
1112+
else
1113+
transformAlongDimensionPrimitive(source, target, d, dim);
1114+
}
1115+
1116+
private static < T extends RealType< T >, U extends RealType< U > > void transformAlongDimensionComposite(
1117+
final RandomAccessible< T > source,
1118+
final RandomAccessibleInterval< U > target,
1119+
final Distance d,
1120+
final int dim )
11081121
{
11091122
final int lastDim = target.numDimensions() - 1;
11101123
final long size = target.dimension( dim );
@@ -1129,6 +1142,35 @@ private static < T extends RealType< T >, U extends RealType< U > > void transfo
11291142
}
11301143
}
11311144

1145+
private static < T extends RealType< T >, U extends RealType< U > > void transformAlongDimensionPrimitive(
1146+
final RandomAccessible< T > source,
1147+
final RandomAccessibleInterval< U > target,
1148+
final Distance d,
1149+
final int dim )
1150+
{
1151+
final int lastDim = target.numDimensions() - 1;
1152+
final long size = target.dimension( dim );
1153+
final RealComposite< DoubleType > tmp = Views.collapseReal( createAppropriateOneDimensionalImage( size, new DoubleType() ) ).randomAccess().get();
1154+
1155+
// do not permute if we already work on last dimension
1156+
final Cursor< RealComposite< T > > s = Views.flatIterable( Views.collapseReal( dim == lastDim ? Views.interval( source, target ) : Views.permute( Views.interval( source, target ), dim, lastDim ) ) ).cursor();
1157+
final Cursor< RealComposite< U > > t = Views.flatIterable( Views.collapseReal( dim == lastDim ? target : Views.permute( target, dim, lastDim ) ) ).cursor();
1158+
1159+
final long[] lowerBoundDistanceIndex = new long[(int)size];
1160+
final double[] envelopeIntersectLocation = new double[(int)size + 1];
1161+
1162+
while ( s.hasNext() )
1163+
{
1164+
final RealComposite< T > sourceComp = s.next();
1165+
final RealComposite< U > targetComp = t.next();
1166+
for ( long i = 0; i < size; ++i )
1167+
{
1168+
tmp.get( i ).set( sourceComp.get( i ).getRealDouble() );
1169+
}
1170+
transformSingleColumnPrimitive( tmp, targetComp, lowerBoundDistanceIndex, envelopeIntersectLocation, d, dim, size );
1171+
}
1172+
}
1173+
11321174
private static < T extends RealType< T >, U extends RealType< U > > void transformAlongDimensionParallel(
11331175
final RandomAccessible< T > source,
11341176
final RandomAccessibleInterval< U > target,
@@ -1167,6 +1209,53 @@ private static < T extends RealType< T >, U extends RealType< U > > void transfo
11671209
invokeAllAndWait( es, tasks );
11681210
}
11691211

1212+
private static < T extends RealType< T >, U extends RealType< U > > void transformSingleColumnPrimitive(
1213+
final RealComposite< T > source,
1214+
final RealComposite< U > target,
1215+
final long[] lowerBoundDistanceIndex,
1216+
final double[] envelopeIntersectLocation,
1217+
final Distance d,
1218+
final int dim,
1219+
final long size )
1220+
{
1221+
int k = 0;
1222+
1223+
lowerBoundDistanceIndex[0] = 0;
1224+
envelopeIntersectLocation[0] = Double.NEGATIVE_INFINITY;
1225+
envelopeIntersectLocation[1] = Double.POSITIVE_INFINITY;
1226+
for ( long position = 1; position < size; ++position )
1227+
{
1228+
long envelopeIndexAtK = lowerBoundDistanceIndex[k];
1229+
final double sourceAtPosition = source.get( position ).getRealDouble();
1230+
double s = d.intersect( envelopeIndexAtK, source.get( envelopeIndexAtK ).getRealDouble(), position, sourceAtPosition, dim );
1231+
1232+
for ( double envelopeValueAtK = envelopeIntersectLocation[k]; s <= envelopeValueAtK && k >= 1; envelopeValueAtK = envelopeIntersectLocation[k] )
1233+
{
1234+
--k;
1235+
envelopeIndexAtK = lowerBoundDistanceIndex[k];
1236+
s = d.intersect( envelopeIndexAtK, source.get( envelopeIndexAtK ).getRealDouble(), position, sourceAtPosition, dim );
1237+
}
1238+
++k;
1239+
lowerBoundDistanceIndex[k] = position;
1240+
envelopeIntersectLocation[k] = s;
1241+
envelopeIntersectLocation[k + 1] = Double.POSITIVE_INFINITY;
1242+
}
1243+
1244+
k = 0;
1245+
1246+
for ( long position = 0; position < size; ++position )
1247+
{
1248+
while ( envelopeIntersectLocation[ k + 1 ] < position )
1249+
{
1250+
++k;
1251+
}
1252+
final long envelopeIndexAtK = lowerBoundDistanceIndex[k];
1253+
// copy necessary because of the following line, access to source
1254+
// after write to source -> source and target cannot be the same
1255+
target.get( position ).setReal( d.evaluate( position, envelopeIndexAtK, source.get( envelopeIndexAtK ).getRealDouble(), dim ) );
1256+
}
1257+
}
1258+
11701259
private static < T extends RealType< T >, U extends RealType< U > > void transformSingleColumn(
11711260
final RealComposite< T > source,
11721261
final RealComposite< U > target,
@@ -1212,7 +1301,6 @@ private static < T extends RealType< T >, U extends RealType< U > > void transfo
12121301
// after write to source -> source and target cannot be the same
12131302
target.get( position ).setReal( d.evaluate( position, envelopeIndexAtK, source.get( envelopeIndexAtK ).getRealDouble(), dim ) );
12141303
}
1215-
12161304
}
12171305

12181306
private static < T extends RealType< T >, U extends RealType< U > > void transformL1AlongDimension(
@@ -1725,6 +1813,21 @@ private static < T extends RealType< T >, U extends RealType< U >, L extends Int
17251813
final RandomAccessible< M > labelTarget,
17261814
final Distance d,
17271815
final int dim )
1816+
{
1817+
final long size = target.dimension( dim );
1818+
if( size > Integer.MAX_VALUE )
1819+
transformAlongDimensionPropagateLabelsComposite(source, target, labelSource, labelTarget, d, dim);
1820+
else
1821+
transformAlongDimensionPropagateLabelsPrimitive(source, target, labelSource, labelTarget, d, dim);
1822+
}
1823+
1824+
private static < T extends RealType< T >, U extends RealType< U >, L extends IntegerType< L >, M extends IntegerType< M > > void transformAlongDimensionPropagateLabelsComposite(
1825+
final RandomAccessible< T > source,
1826+
final RandomAccessibleInterval< U > target,
1827+
final RandomAccessible< L > labelSource,
1828+
final RandomAccessible< M > labelTarget,
1829+
final Distance d,
1830+
final int dim )
17281831
{
17291832
final int lastDim = target.numDimensions() - 1;
17301833
final long size = target.dimension( dim );
@@ -1759,11 +1862,11 @@ private static < T extends RealType< T >, U extends RealType< U >, L extends Int
17591862
tmp.get( i ).set( sourceComp.get( i ).getRealDouble() );
17601863
tmpLabel.get( i ).setInteger( labelComp.get( i ).getIntegerLong() );
17611864
}
1762-
transformSingleColumnPropagateLabels( tmp, targetComp, tmpLabel, labelTargetComp, lowerBoundDistanceIndex, envelopeIntersectLocation, d, dim, size );
1865+
transformSingleColumnPropagateLabelsComposite( tmp, targetComp, tmpLabel, labelTargetComp, lowerBoundDistanceIndex, envelopeIntersectLocation, d, dim, size );
17631866
}
17641867
}
17651868

1766-
private static < T extends RealType< T >, U extends RealType< U >, L extends IntegerType<L>, M extends IntegerType<M> > void transformSingleColumnPropagateLabels(
1869+
private static < T extends RealType< T >, U extends RealType< U >, L extends IntegerType<L>, M extends IntegerType<M> > void transformSingleColumnPropagateLabelsComposite(
17671870
final RealComposite< T > source,
17681871
final RealComposite< U > target,
17691872
final RealComposite< L > labelsSource,
@@ -1813,6 +1916,100 @@ private static < T extends RealType< T >, U extends RealType< U >, L extends Int
18131916

18141917
}
18151918

1919+
private static < T extends RealType< T >, U extends RealType< U >, L extends IntegerType< L >, M extends IntegerType< M > > void transformAlongDimensionPropagateLabelsPrimitive(
1920+
final RandomAccessible< T > source,
1921+
final RandomAccessibleInterval< U > target,
1922+
final RandomAccessible< L > labelSource,
1923+
final RandomAccessible< M > labelTarget,
1924+
final Distance d,
1925+
final int dim )
1926+
{
1927+
final int lastDim = target.numDimensions() - 1;
1928+
final int size = (int)target.dimension( dim );
1929+
1930+
final Img< DoubleType > tmpImg = createAppropriateOneDimensionalImage( size, new DoubleType() );
1931+
final RealComposite< DoubleType > tmp = Views.collapseReal( tmpImg ).randomAccess().get();
1932+
1933+
final Img< L > tmpLabelImg = Util.getSuitableImgFactory( tmpImg, labelSource.getType() ).create( tmpImg );
1934+
final RealComposite< L > tmpLabel = Views.collapseReal( tmpLabelImg ).randomAccess().get();
1935+
1936+
// do not permute if we already work on last dimension
1937+
final Cursor< RealComposite< T > > s = Views.flatIterable( Views.collapseReal( dim == lastDim ? Views.interval( source, target ) : Views.permute( Views.interval( source, target ), dim, lastDim ) ) ).cursor();
1938+
final Cursor< RealComposite< U > > t = Views.flatIterable( Views.collapseReal( dim == lastDim ? target : Views.permute( target, dim, lastDim ) ) ).cursor();
1939+
1940+
final Cursor< RealComposite< L > > ls = Views.flatIterable(
1941+
Views.collapseReal( dim == lastDim ? Views.interval( labelSource, target ) : Views.permute( Views.interval( labelSource, target ), dim, lastDim ) ) ).cursor();
1942+
1943+
final Cursor< RealComposite< M > > lt = Views.flatIterable(
1944+
Views.collapseReal( dim == lastDim ? Views.interval( labelTarget, target ) : Views.permute( Views.interval( labelTarget, target ), dim, lastDim ) ) ).cursor();
1945+
1946+
final long[] lowerBoundDistanceIndex = new long[size];
1947+
final double[] envelopeIntersectLocation = new double[size+1];
1948+
1949+
while ( s.hasNext() )
1950+
{
1951+
final RealComposite< T > sourceComp = s.next();
1952+
final RealComposite< U > targetComp = t.next();
1953+
final RealComposite< L > labelComp = ls.next();
1954+
final RealComposite< M > labelTargetComp = lt.next();
1955+
for ( long i = 0; i < size; ++i )
1956+
{
1957+
tmp.get( i ).set( sourceComp.get( i ).getRealDouble() );
1958+
tmpLabel.get( i ).setInteger( labelComp.get( i ).getIntegerLong() );
1959+
}
1960+
transformSingleColumnPropagateLabelsPrimitive( tmp, targetComp, tmpLabel, labelTargetComp, lowerBoundDistanceIndex, envelopeIntersectLocation, d, dim, size );
1961+
}
1962+
}
1963+
1964+
private static < T extends RealType< T >, U extends RealType< U >, L extends IntegerType<L>, M extends IntegerType<M> > void transformSingleColumnPropagateLabelsPrimitive(
1965+
final RealComposite< T > source,
1966+
final RealComposite< U > target,
1967+
final RealComposite< L > labelsSource,
1968+
final RealComposite< M > labelsResult,
1969+
final long[] lowerBoundDistanceIndex,
1970+
final double[] envelopeIntersectLocation,
1971+
final Distance d,
1972+
final int dim,
1973+
final long size )
1974+
{
1975+
int k = 0;
1976+
1977+
lowerBoundDistanceIndex[0] = 0;
1978+
envelopeIntersectLocation[0] = Double.NEGATIVE_INFINITY;
1979+
envelopeIntersectLocation[1] = Double.POSITIVE_INFINITY;
1980+
for ( long position = 1; position < size; ++position )
1981+
{
1982+
long envelopeIndexAtK = lowerBoundDistanceIndex[k];
1983+
final double sourceAtPosition = source.get( position ).getRealDouble();
1984+
double s = d.intersect( envelopeIndexAtK, source.get( envelopeIndexAtK ).getRealDouble(), position, sourceAtPosition, dim );
1985+
1986+
for ( double envelopeValueAtK = envelopeIntersectLocation[k]; s <= envelopeValueAtK && k >= 1; envelopeValueAtK = envelopeIntersectLocation[k] )
1987+
{
1988+
--k;
1989+
envelopeIndexAtK = lowerBoundDistanceIndex[k];
1990+
s = d.intersect( envelopeIndexAtK, source.get( envelopeIndexAtK ).getRealDouble(), position, sourceAtPosition, dim );
1991+
}
1992+
++k;
1993+
lowerBoundDistanceIndex[k] = position;
1994+
envelopeIntersectLocation[k] = s;
1995+
envelopeIntersectLocation[k + 1] = Double.POSITIVE_INFINITY;
1996+
}
1997+
1998+
k = 0;
1999+
for ( long position = 0; position < size; ++position )
2000+
{
2001+
while ( envelopeIntersectLocation[ k + 1 ] < position )
2002+
{
2003+
++k;
2004+
}
2005+
final long envelopeIndexAtK = lowerBoundDistanceIndex[k];
2006+
// copy necessary because of the following line, access to source
2007+
// after write to source -> source and target cannot be the same
2008+
target.get( position ).setReal( d.evaluate( position, envelopeIndexAtK, source.get( envelopeIndexAtK ).getRealDouble(), dim ) );
2009+
labelsResult.get( position ).setInteger( labelsSource.get( envelopeIndexAtK ).getIntegerLong() );
2010+
}
2011+
}
2012+
18162013
private static < T extends RealType< T >, U extends RealType< U >, L extends IntegerType< L >, M extends IntegerType< M > > void transformAlongDimensionPropagateLabelsParallel(
18172014
final RandomAccessible< T > source,
18182015
final RandomAccessibleInterval< U > target,
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
package net.imglib2.algorithm.morphology.distance;
2+
3+
import java.util.Arrays;
4+
import java.util.Random;
5+
import java.util.concurrent.ExecutionException;
6+
import java.util.concurrent.ExecutorService;
7+
import java.util.concurrent.Executors;
8+
import java.util.concurrent.TimeUnit;
9+
10+
import org.openjdk.jmh.annotations.Benchmark;
11+
import org.openjdk.jmh.annotations.BenchmarkMode;
12+
import org.openjdk.jmh.annotations.Mode;
13+
import org.openjdk.jmh.annotations.OutputTimeUnit;
14+
import org.openjdk.jmh.annotations.Scope;
15+
import org.openjdk.jmh.annotations.Setup;
16+
import org.openjdk.jmh.annotations.State;
17+
import org.openjdk.jmh.runner.Runner;
18+
import org.openjdk.jmh.runner.RunnerException;
19+
import org.openjdk.jmh.runner.options.Options;
20+
import org.openjdk.jmh.runner.options.OptionsBuilder;
21+
import org.openjdk.jmh.runner.options.TimeValue;
22+
23+
import net.imglib2.img.array.ArrayImg;
24+
import net.imglib2.img.array.ArrayImgs;
25+
import net.imglib2.type.numeric.integer.LongType;
26+
import net.imglib2.type.numeric.real.DoubleType;
27+
28+
@State( Scope.Benchmark )
29+
public class DistanceTransformBenchmark {
30+
31+
private int numPoints = 4000;
32+
private long[] dim = new long[] {256, 256, 128};
33+
DistanceTransform.DISTANCE_TYPE distanceType = DistanceTransform.DISTANCE_TYPE.EUCLIDIAN;
34+
35+
Random random;
36+
private ArrayImg<DoubleType, ?> distanceImg;
37+
private ArrayImg<LongType, ?> labelImg;
38+
ExecutorService es;
39+
int nThreads;
40+
int numTasks;
41+
42+
@Setup
43+
public void setup()
44+
{
45+
random = new Random(0);
46+
47+
final int N = Arrays.stream(dim).mapToInt( i -> (int)i).reduce(1, (x,y) -> x*y);
48+
49+
final double[] initValues = new double[N];
50+
Arrays.fill(initValues, Double.MAX_VALUE);
51+
for( int i = 0; i < numPoints; i++ ) {
52+
initValues[random.nextInt(N)] = 0.0;
53+
}
54+
distanceImg = ArrayImgs.doubles(initValues, dim);
55+
56+
final long[] initLabels = new long[N];
57+
for( int i = 0; i < numPoints; i++ ) {
58+
initLabels[random.nextInt(N)] = 1 + random.nextInt(numPoints);
59+
}
60+
labelImg = ArrayImgs.longs(initLabels, dim);
61+
62+
int nThreads = 1;
63+
es = Executors.newFixedThreadPool(nThreads);
64+
numTasks = 2*nThreads;
65+
}
66+
67+
@Benchmark
68+
@BenchmarkMode( Mode.AverageTime )
69+
@OutputTimeUnit( TimeUnit.MILLISECONDS )
70+
public void distanceTransform()
71+
{
72+
if( nThreads == 1 ) {
73+
DistanceTransform.transform(distanceImg, distanceType);
74+
} else {
75+
try {
76+
DistanceTransform.transform(distanceImg, distanceType, es, numTasks);
77+
} catch (InterruptedException e) {
78+
e.printStackTrace();
79+
} catch (ExecutionException e) {
80+
e.printStackTrace();
81+
}
82+
}
83+
}
84+
85+
@Benchmark
86+
@BenchmarkMode( Mode.AverageTime )
87+
@OutputTimeUnit( TimeUnit.MILLISECONDS )
88+
public void voronoiDistanceTransform()
89+
{
90+
if( nThreads == 1 ) {
91+
DistanceTransform.voronoiDistanceTransform(labelImg, 0l);
92+
} else {
93+
try {
94+
DistanceTransform.voronoiDistanceTransform(labelImg, 0l, es, numTasks);
95+
} catch (InterruptedException e) {
96+
e.printStackTrace();
97+
} catch (ExecutionException e) {
98+
e.printStackTrace();
99+
}
100+
}
101+
}
102+
103+
public static void main( final String... args ) throws RunnerException
104+
{
105+
final Options opt = new OptionsBuilder()
106+
.include( DistanceTransformBenchmark.class.getSimpleName() )
107+
.forks( 0 )
108+
.warmupIterations( 5 )
109+
.measurementIterations( 25 )
110+
.measurementTime(TimeValue.seconds( 2 ))
111+
.warmupTime(TimeValue.seconds( 2 ))
112+
.build();
113+
new Runner( opt ).run();
114+
}
115+
}

0 commit comments

Comments
 (0)