Skip to content

Commit c619131

Browse files
committed
AbstractMultiThreadedConvolution: shutdown automatically create ExecutorService
1 parent 44d2c88 commit c619131

File tree

3 files changed

+49
-24
lines changed

3 files changed

+49
-24
lines changed
Lines changed: 38 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,37 +1,66 @@
11
package net.imglib2.algorithm.convolution;
22

3+
import net.imglib2.RandomAccessible;
4+
import net.imglib2.RandomAccessibleInterval;
5+
36
import java.util.concurrent.ExecutorService;
47
import java.util.concurrent.Executors;
58
import java.util.concurrent.ThreadPoolExecutor;
69

710
/**
811
* Abstract class to help implementing a Convolution, that is multi threaded
912
* using an {@link ExecutorService}. This implements the method
10-
* {@link Convolution#setExecutor(ExecutorService)} and has useful protected
11-
* methods {@link #getExecutor()} and {@link #getNumThreads()}.
13+
* {@link Convolution#setExecutor(ExecutorService)}.
14+
* <p>
15+
* Classes that derive from
16+
* {@link AbstractMultiThreadedConvolution} must override
17+
* {@link AbstractMultiThreadedConvolution#process(RandomAccessible, RandomAccessibleInterval, ExecutorService, int)}
1218
*/
1319
public abstract class AbstractMultiThreadedConvolution< T > implements Convolution< T >
1420
{
15-
private ExecutorService executor = null;
21+
22+
private ExecutorService executor;
23+
24+
abstract protected void process( RandomAccessible< ? extends T > source,
25+
RandomAccessibleInterval< ? extends T > target,
26+
ExecutorService executorService,
27+
int numThreads);
1628

1729
@Override
1830
public void setExecutor( ExecutorService executor )
1931
{
2032
this.executor = executor;
2133
}
2234

23-
protected ExecutorService getExecutor()
35+
@Override
36+
final public void process( RandomAccessible< ? extends T > source, RandomAccessibleInterval< ? extends T > target )
2437
{
25-
if ( executor == null )
26-
executor = Executors.newFixedThreadPool( Runtime.getRuntime().availableProcessors() );
27-
return executor;
38+
if(executor == null) {
39+
int numThreads = suggestNumThreads();
40+
ExecutorService executor = Executors.newFixedThreadPool( numThreads );
41+
try
42+
{
43+
process( source, target, executor, numThreads );
44+
}
45+
finally
46+
{
47+
executor.shutdown();
48+
}
49+
}
50+
else {
51+
process( source, target, executor, getNumThreads( executor ) );
52+
}
2853
}
2954

30-
protected int getNumThreads()
55+
private int getNumThreads(ExecutorService executor)
3156
{
32-
ExecutorService executor = getExecutor();
3357
if ( executor instanceof ThreadPoolExecutor )
3458
return ( ( ThreadPoolExecutor ) executor ).getMaximumPoolSize();
59+
return suggestNumThreads();
60+
}
61+
62+
private int suggestNumThreads()
63+
{
3564
return Runtime.getRuntime().availableProcessors();
3665
}
3766
}

src/main/java/net/imglib2/algorithm/convolution/LineConvolution.java

Lines changed: 7 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -53,21 +53,18 @@ public LineConvolution( LineConvolverFactory< ? super T > factory, int direction
5353
return targetType;
5454
}
5555

56-
@Override public void process( RandomAccessible< ? extends T > source, RandomAccessibleInterval< ? extends T > target )
56+
@Override
57+
protected void process( RandomAccessible< ? extends T > source, RandomAccessibleInterval< ? extends T > target, ExecutorService executorService, int numThreads )
5758
{
58-
internConvolve( factory, Views.interval( source, requiredSourceInterval( target ) ), target, direction );
59-
}
60-
61-
private < T > void internConvolve( LineConvolverFactory< ? super T > factory, RandomAccessibleInterval< ? extends T > source, RandomAccessibleInterval< ? extends T > target, int d )
62-
{
63-
final long[] sourceMin = Intervals.minAsLongArray( source );
59+
RandomAccessibleInterval< ? extends T > sourceInterval = Views.interval( source, requiredSourceInterval( target ) );
60+
final long[] sourceMin = Intervals.minAsLongArray( sourceInterval );
6461
final long[] targetMin = Intervals.minAsLongArray( target );
6562

6663
Supplier< Consumer< Localizable > > actionFactory = () -> {
6764

68-
final RandomAccess< ? extends T > in = source.randomAccess();
65+
final RandomAccess< ? extends T > in = sourceInterval.randomAccess();
6966
final RandomAccess< ? extends T > out = target.randomAccess();
70-
final Runnable convolver = factory.getConvolver( in, out, d, target.dimension( d ) );
67+
final Runnable convolver = factory.getConvolver( in, out, direction, target.dimension( direction ) );
7168

7269
return position -> {
7370
in.setPosition( sourceMin );
@@ -79,12 +76,9 @@ private < T > void internConvolve( LineConvolverFactory< ? super T > factory, Ra
7976
};
8077

8178
final long[] dim = Intervals.dimensionsAsLongArray( target );
82-
dim[ d ] = 1;
79+
dim[ direction ] = 1;
8380

84-
// FIXME: is there a better way to determine the number of threads
85-
final int numThreads = getNumThreads();
8681
final int numTasks = numThreads > 1 ? numThreads * 4 : 1;
87-
ExecutorService executorService = getExecutor();
8882
LineConvolution.forEachIntervalElementInParallel( executorService, numTasks, new FinalInterval( dim ), actionFactory );
8983
}
9084

src/test/java/net/imglib2/algorithm/convolution/GaussBenchmark.java

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
import net.imglib2.algorithm.gauss3.SeparableSymmetricConvolution;
99
import net.imglib2.img.Img;
1010
import net.imglib2.img.array.ArrayImgs;
11+
import net.imglib2.type.numeric.NumericType;
12+
import net.imglib2.type.numeric.RealType;
1113
import net.imglib2.type.numeric.real.DoubleType;
1214
import net.imglib2.view.Views;
1315
import org.openjdk.jmh.annotations.Benchmark;
@@ -54,14 +56,14 @@ public void benchmarkSeparableSymmertricConvolution()
5456
@Benchmark
5557
public void benchmarkFastGauss()
5658
{
57-
FastGauss.convolve( sigma, inImage, outImage );
59+
FastGauss.convolution( sigma ).process( inImage, outImage );
5860
}
5961

6062
public static void main( String[] args ) throws RunnerException
6163
{
6264
Options opt = new OptionsBuilder()
6365
.include( GaussBenchmark.class.getSimpleName() )
64-
.forks( 0 )
66+
.forks( 1 )
6567
.warmupIterations( 4 )
6668
.measurementIterations( 8 )
6769
.warmupTime( TimeValue.milliseconds( 100 ) )

0 commit comments

Comments
 (0)