Skip to content

Commit aae1f4d

Browse files
authored
Merge pull request #93 from imglib/gauss3-parallelization
Speed up Gauss3.gauss() by using the imglib2 "Parallelization" class
2 parents 377ed3c + 41e55dc commit aae1f4d

File tree

9 files changed

+189
-120
lines changed

9 files changed

+189
-120
lines changed

src/main/java/net/imglib2/algorithm/convolution/AbstractMultiThreadedConvolution.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@
5151
*
5252
* @author Matthias Arzt
5353
*/
54+
@Deprecated
5455
public abstract class AbstractMultiThreadedConvolution< T > implements Convolution< T >
5556
{
5657

@@ -61,6 +62,7 @@ abstract protected void process( RandomAccessible< ? extends T > source,
6162
ExecutorService executorService,
6263
int numThreads );
6364

65+
@Deprecated
6466
@Override
6567
public void setExecutor( final ExecutorService executor )
6668
{

src/main/java/net/imglib2/algorithm/convolution/Concatenation.java

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,10 +64,12 @@ class Concatenation< T > implements Convolution< T >
6464
this.steps = new ArrayList<>( steps );
6565
}
6666

67+
@Deprecated
6768
@Override
68-
public void setExecutor( final ExecutorService executor )
69+
public void setExecutor( ExecutorService executor )
6970
{
70-
steps.forEach( step -> step.setExecutor( executor ) );
71+
for ( Convolution<T> step : steps )
72+
step.setExecutor( executor );
7173
}
7274

7375
@Override

src/main/java/net/imglib2/algorithm/convolution/Convolution.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ public interface Convolution< T >
6767
/**
6868
* Set the {@link ExecutorService} to be used for convolution.
6969
*/
70+
@Deprecated
7071
default void setExecutor( final ExecutorService executor )
7172
{}
7273

src/main/java/net/imglib2/algorithm/convolution/LineConvolution.java

Lines changed: 36 additions & 96 deletions
Original file line numberDiff line numberDiff line change
@@ -33,44 +33,50 @@
3333
*/
3434
package net.imglib2.algorithm.convolution;
3535

36-
import java.util.ArrayList;
37-
import java.util.List;
38-
import java.util.concurrent.Callable;
39-
import java.util.concurrent.ExecutionException;
40-
import java.util.concurrent.ExecutorService;
41-
import java.util.concurrent.Future;
42-
import java.util.function.Consumer;
43-
import java.util.function.Supplier;
44-
4536
import net.imglib2.FinalInterval;
4637
import net.imglib2.Interval;
4738
import net.imglib2.Localizable;
48-
import net.imglib2.Point;
4939
import net.imglib2.RandomAccess;
5040
import net.imglib2.RandomAccessible;
5141
import net.imglib2.RandomAccessibleInterval;
52-
import net.imglib2.util.IntervalIndexer;
42+
import net.imglib2.loops.LoopBuilder;
43+
import net.imglib2.parallel.Parallelization;
44+
import net.imglib2.parallel.TaskExecutor;
45+
import net.imglib2.parallel.TaskExecutors;
46+
import net.imglib2.util.Cast;
5347
import net.imglib2.util.Intervals;
48+
import net.imglib2.util.Localizables;
5449
import net.imglib2.view.Views;
5550

51+
import java.util.concurrent.ExecutorService;
52+
5653
/**
5754
* This class can be used to implement a separable convolution. It applies a
5855
* {@link LineConvolverFactory} on the given images.
5956
*
6057
* @author Matthias Arzt
6158
*/
62-
public class LineConvolution< T > extends AbstractMultiThreadedConvolution< T >
59+
public class LineConvolution< T > implements Convolution<T>
6360
{
6461
private final LineConvolverFactory< ? super T > factory;
6562

6663
private final int direction;
6764

65+
private ExecutorService executor;
66+
6867
public LineConvolution( final LineConvolverFactory< ? super T > factory, final int direction )
6968
{
7069
this.factory = factory;
7170
this.direction = direction;
7271
}
7372

73+
@Deprecated
74+
@Override
75+
public void setExecutor( ExecutorService executor )
76+
{
77+
this.executor = executor;
78+
}
79+
7480
@Override
7581
public Interval requiredSourceInterval( final Interval targetInterval )
7682
{
@@ -84,104 +90,38 @@ public Interval requiredSourceInterval( final Interval targetInterval )
8490
@Override
8591
public T preferredSourceType( final T targetType )
8692
{
87-
return (T) factory.preferredSourceType( targetType );
93+
return Cast.unchecked( factory.preferredSourceType( targetType ) );
8894
}
8995

9096
@Override
91-
protected void process( final RandomAccessible< ? extends T > source, final RandomAccessibleInterval< ? extends T > target, final ExecutorService executorService, final int numThreads )
97+
public void process( RandomAccessible< ? extends T > source, RandomAccessibleInterval< ? extends T > target )
9298
{
9399
final RandomAccessibleInterval< ? extends T > sourceInterval = Views.interval( source, requiredSourceInterval( target ) );
94100
final long[] sourceMin = Intervals.minAsLongArray( sourceInterval );
95101
final long[] targetMin = Intervals.minAsLongArray( target );
96102

97-
final Supplier< Consumer< Localizable > > actionFactory = () -> {
98-
99-
final RandomAccess< ? extends T > in = sourceInterval.randomAccess();
100-
final RandomAccess< ? extends T > out = target.randomAccess();
101-
final Runnable convolver = factory.getConvolver( in, out, direction, target.dimension( direction ) );
102-
103-
return position -> {
104-
in.setPosition( sourceMin );
105-
out.setPosition( targetMin );
106-
in.move( position );
107-
out.move( position );
108-
convolver.run();
109-
};
110-
};
111-
112103
final long[] dim = Intervals.dimensionsAsLongArray( target );
113104
dim[ direction ] = 1;
114105

115-
final int numTasks = numThreads > 1 ? timesFourAvoidOverflow(numThreads) : 1;
116-
LineConvolution.forEachIntervalElementInParallel( executorService, numTasks, new FinalInterval( dim ), actionFactory );
117-
}
106+
RandomAccessibleInterval< Localizable > positions = Localizables.randomAccessibleInterval( new FinalInterval( dim ) );
107+
TaskExecutor taskExecutor = executor == null ? Parallelization.getTaskExecutor() : TaskExecutors.forExecutorService( executor );
108+
LoopBuilder.setImages( positions ).multiThreaded(taskExecutor).forEachChunk(
109+
chunk -> {
118110

119-
private int timesFourAvoidOverflow( int x )
120-
{
121-
return (int) Math.min((long) x * 4, Integer.MAX_VALUE);
122-
}
111+
final RandomAccess< ? extends T > in = sourceInterval.randomAccess();
112+
final RandomAccess< ? extends T > out = target.randomAccess();
113+
final Runnable convolver = factory.getConvolver( in, out, direction, target.dimension( direction ) );
123114

124-
/**
125-
* {@link #forEachIntervalElementInParallel(ExecutorService, int, Interval, Supplier)}
126-
* executes a given action for each position in a given interval. Therefor
127-
* it starts the specified number of tasks. Each tasks calls the action
128-
* factory once, to get an instance of the action that should be executed.
129-
* The action is then called multiple times by the task.
130-
*
131-
* @param service
132-
* {@link ExecutorService} used to create the tasks.
133-
* @param numTasks
134-
* number of tasks to use.
135-
* @param interval
136-
* interval to iterate over.
137-
* @param actionFactory
138-
* factory that returns the action to be executed.
139-
*/
140-
// TODO: move to a better place
141-
public static void forEachIntervalElementInParallel( final ExecutorService service, final int numTasks, final Interval interval,
142-
final Supplier< Consumer< Localizable > > actionFactory )
143-
{
144-
final long[] min = Intervals.minAsLongArray( interval );
145-
final long[] dim = Intervals.dimensionsAsLongArray( interval );
146-
final long size = Intervals.numElements( dim );
147-
final int boundedNumTasks = (int) Math.max( 1, Math.min(size, numTasks ));
148-
final long taskSize = ( size - 1 ) / boundedNumTasks + 1; // taskSize = roundUp(size / boundedNumTasks);
149-
final ArrayList< Callable< Void > > callables = new ArrayList<>();
115+
chunk.forEachPixel( position -> {
116+
in.setPosition( sourceMin );
117+
out.setPosition( targetMin );
118+
in.move( position );
119+
out.move( position );
120+
convolver.run();
121+
} );
150122

151-
for ( int taskNum = 0; taskNum < boundedNumTasks; ++taskNum )
152-
{
153-
final long myStartIndex = taskNum * taskSize;
154-
final long myEndIndex = Math.min( size, myStartIndex + taskSize );
155-
final Callable< Void > r = () -> {
156-
final Consumer< Localizable > action = actionFactory.get();
157-
final long[] position = new long[ dim.length ];
158-
final Localizable localizable = Point.wrap( position );
159-
for ( long index = myStartIndex; index < myEndIndex; ++index )
160-
{
161-
IntervalIndexer.indexToPositionWithOffset( index, dim, min, position );
162-
action.accept( localizable );
123+
return null;
163124
}
164-
return null;
165-
};
166-
callables.add( r );
167-
}
168-
execute( service, callables );
169-
}
170-
171-
private static void execute( final ExecutorService service, final ArrayList< Callable< Void > > callables )
172-
{
173-
try
174-
{
175-
final List< Future< Void > > futures = service.invokeAll( callables );
176-
for ( final Future< Void > future : futures )
177-
future.get();
178-
}
179-
catch ( final InterruptedException | ExecutionException e )
180-
{
181-
final Throwable cause = e.getCause();
182-
if ( cause instanceof RuntimeException )
183-
throw ( RuntimeException ) cause;
184-
throw new RuntimeException( e );
185-
}
125+
);
186126
}
187127
}

src/main/java/net/imglib2/algorithm/convolution/MultiDimensionConvolution.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ public class MultiDimensionConvolution< T > implements Convolution< T >
5353
{
5454
private ExecutorService executor;
5555

56+
@Deprecated
5657
@Override
5758
public void setExecutor( final ExecutorService executor )
5859
{

src/main/java/net/imglib2/algorithm/gauss3/Gauss3.java

Lines changed: 55 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -34,15 +34,18 @@
3434

3535
package net.imglib2.algorithm.gauss3;
3636

37+
import java.util.Arrays;
3738
import java.util.concurrent.ExecutorService;
3839
import java.util.concurrent.Executors;
40+
import java.util.concurrent.ForkJoinPool;
3941

4042
import net.imglib2.RandomAccessible;
4143
import net.imglib2.RandomAccessibleInterval;
4244
import net.imglib2.algorithm.convolution.Convolution;
4345
import net.imglib2.algorithm.convolution.kernel.Kernel1D;
4446
import net.imglib2.algorithm.convolution.kernel.SeparableKernelConvolution;
4547
import net.imglib2.exception.IncompatibleTypeException;
48+
import net.imglib2.parallel.Parallelization;
4649
import net.imglib2.type.numeric.NumericType;
4750
import net.imglib2.type.numeric.RealType;
4851
import net.imglib2.type.numeric.real.DoubleType;
@@ -56,7 +59,7 @@
5659
public final class Gauss3
5760
{
5861
/**
59-
* Apply Gaussian convolution to source and write the result to output.
62+
* Apply Gaussian convolution to source and write the result to target.
6063
* In-place operation (source==target) is supported.
6164
*
6265
* <p>
@@ -66,6 +69,11 @@ public final class Gauss3
6669
* in their own precision. The source type S and target type T are either
6770
* both {@link RealType RealTypes} or both the same type.
6871
*
72+
* <p>
73+
* Computation may be multi-threaded, according to the current
74+
* {@link Parallelization} context. (By default, it will use the
75+
* {@link ForkJoinPool#commonPool() common ForkJoinPool})
76+
*
6977
* @param sigma
7078
* standard deviation of isotropic Gaussian.
7179
* @param source
@@ -93,7 +101,7 @@ public static < S extends NumericType< S >, T extends NumericType< T > > void ga
93101
}
94102

95103
/**
96-
* Apply Gaussian convolution to source and write the result to output.
104+
* Apply Gaussian convolution to source and write the result to target.
97105
* In-place operation (source==target) is supported.
98106
*
99107
* <p>
@@ -104,9 +112,10 @@ public static < S extends NumericType< S >, T extends NumericType< T > > void ga
104112
* both {@link RealType RealTypes} or both the same type.
105113
*
106114
* <p>
107-
* Computation is multi-threaded with as many threads as processors
108-
* available.
109-
*
115+
* Computation may be multi-threaded, according to the current
116+
* {@link Parallelization} context. (By default, it will use the
117+
* {@link ForkJoinPool#commonPool() common ForkJoinPool})
118+
*
110119
* @param sigma
111120
* standard deviation in every dimension.
112121
* @param source
@@ -126,13 +135,27 @@ public static < S extends NumericType< S >, T extends NumericType< T > > void ga
126135
*/
127136
public static < S extends NumericType< S >, T extends NumericType< T > > void gauss( final double[] sigma, final RandomAccessible< S > source, final RandomAccessibleInterval< T > target ) throws IncompatibleTypeException
128137
{
129-
final int numthreads = Runtime.getRuntime().availableProcessors();
130-
final ExecutorService service = Executors.newFixedThreadPool( numthreads );
131-
gauss( sigma, source, target, service );
132-
service.shutdown();
138+
final double[][] halfkernels = halfkernels( sigma );
139+
final Convolution< NumericType< ? > > convolution = SeparableKernelConvolution.convolution( Kernel1D.symmetric( halfkernels ) );
140+
convolution.process( source, target );
133141
}
134142

135143
/**
144+
* @deprecated
145+
* Deprecated. Please use
146+
* {@link Gauss3#gauss(double, RandomAccessible, RandomAccessibleInterval)
147+
* gauss(sigma, source, target)} instead. The number of threads used to
148+
* calculate the Gaussion convolution may by set with the
149+
* {@link Parallelization} context, as show in this example:
150+
* <pre>
151+
* {@code
152+
* Parallelization.runWithNumThreads( numThreads,
153+
* () -> gauss( sigma, source, target )
154+
* );
155+
* }
156+
* </pre>
157+
*
158+
* <p>
136159
* Apply Gaussian convolution to source and write the result to output.
137160
* In-place operation (source==target) is supported.
138161
*
@@ -162,14 +185,30 @@ public static < S extends NumericType< S >, T extends NumericType< T > > void ga
162185
* if source and target type are not compatible (they must be
163186
* either both {@link RealType RealTypes} or the same type).
164187
*/
188+
@Deprecated
165189
public static < S extends NumericType< S >, T extends NumericType< T > > void gauss( final double[] sigma, final RandomAccessible< S > source, final RandomAccessibleInterval< T > target, final int numThreads ) throws IncompatibleTypeException
166190
{
167-
final ExecutorService service = Executors.newFixedThreadPool( numThreads );
168-
gauss( sigma, source, target, service );
169-
service.shutdown();
191+
Parallelization.runWithNumThreads( numThreads,
192+
() -> gauss( sigma, source, target )
193+
);
170194
}
171195

172196
/**
197+
* @deprecated
198+
* Deprecated. Please use
199+
* {@link Gauss3#gauss(double, RandomAccessible, RandomAccessibleInterval)
200+
* gauss(sigma, source, target)} instead. The ExecutorService used to
201+
* calculate the Gaussion convolution may by set with the
202+
* {@link Parallelization} context, as show in this example:
203+
* <pre>
204+
* {@code
205+
* Parallelization.runWithExecutor( executorService,
206+
* () -> gauss( sigma, source, target )
207+
* );
208+
* }
209+
* </pre>
210+
*
211+
* <p>
173212
* Apply Gaussian convolution to source and write the result to output.
174213
* In-place operation (source==target) is supported.
175214
*
@@ -199,12 +238,12 @@ public static < S extends NumericType< S >, T extends NumericType< T > > void ga
199238
* if source and target type are not compatible (they must be
200239
* either both {@link RealType RealTypes} or the same type).
201240
*/
241+
@Deprecated
202242
public static < S extends NumericType< S >, T extends NumericType< T > > void gauss( final double[] sigma, final RandomAccessible< S > source, final RandomAccessibleInterval< T > target, final ExecutorService service ) throws IncompatibleTypeException
203243
{
204-
final double[][] halfkernels = halfkernels( sigma );
205-
final Convolution< NumericType< ? > > convolution = SeparableKernelConvolution.convolution( Kernel1D.symmetric( halfkernels ) );
206-
convolution.setExecutor( service );
207-
convolution.process( source, target );
244+
Parallelization.runWithExecutor( service,
245+
() -> gauss( sigma, source, target )
246+
);
208247
}
209248

210249
public static double[][] halfkernels( final double[] sigma )

0 commit comments

Comments
 (0)