Skip to content

Commit 0cd4d0e

Browse files
committed
Use LoopBuilder and Parallelization classes in Gauss3
This simplifies the code, but also improves the performance, because Gauss3.gauss(sigma, source, target) will no longer create it's own ExecutorService. This also means that the multi threading behaviour of Gauss3 can now conviniently be controled using the Parallelization class. For example: Parallelization.runSingleThreaded(() -> { Gauss3.gauss(sigma, source, target); });
1 parent 368a9f0 commit 0cd4d0e

File tree

8 files changed

+61
-112
lines changed

8 files changed

+61
-112
lines changed

src/main/java/net/imglib2/algorithm/convolution/AbstractMultiThreadedConvolution.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@
5151
*
5252
* @author Matthias Arzt
5353
*/
54+
@Deprecated
5455
public abstract class AbstractMultiThreadedConvolution< T > implements Convolution< T >
5556
{
5657

@@ -61,6 +62,7 @@ abstract protected void process( RandomAccessible< ? extends T > source,
6162
ExecutorService executorService,
6263
int numThreads );
6364

65+
@Deprecated
6466
@Override
6567
public void setExecutor( final ExecutorService executor )
6668
{

src/main/java/net/imglib2/algorithm/convolution/Concatenation.java

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,10 +64,12 @@ class Concatenation< T > implements Convolution< T >
6464
this.steps = new ArrayList<>( steps );
6565
}
6666

67+
@Deprecated
6768
@Override
68-
public void setExecutor( final ExecutorService executor )
69+
public void setExecutor( ExecutorService executor )
6970
{
70-
steps.forEach( step -> step.setExecutor( executor ) );
71+
for ( Convolution<T> step : steps )
72+
step.setExecutor( executor );
7173
}
7274

7375
@Override

src/main/java/net/imglib2/algorithm/convolution/Convolution.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ public interface Convolution< T >
6767
/**
6868
* Set the {@link ExecutorService} to be used for convolution.
6969
*/
70+
@Deprecated
7071
default void setExecutor( final ExecutorService executor )
7172
{}
7273

src/main/java/net/imglib2/algorithm/convolution/LineConvolution.java

Lines changed: 36 additions & 96 deletions
Original file line numberDiff line numberDiff line change
@@ -33,44 +33,50 @@
3333
*/
3434
package net.imglib2.algorithm.convolution;
3535

36-
import java.util.ArrayList;
37-
import java.util.List;
38-
import java.util.concurrent.Callable;
39-
import java.util.concurrent.ExecutionException;
40-
import java.util.concurrent.ExecutorService;
41-
import java.util.concurrent.Future;
42-
import java.util.function.Consumer;
43-
import java.util.function.Supplier;
44-
4536
import net.imglib2.FinalInterval;
4637
import net.imglib2.Interval;
4738
import net.imglib2.Localizable;
48-
import net.imglib2.Point;
4939
import net.imglib2.RandomAccess;
5040
import net.imglib2.RandomAccessible;
5141
import net.imglib2.RandomAccessibleInterval;
52-
import net.imglib2.util.IntervalIndexer;
42+
import net.imglib2.loops.LoopBuilder;
43+
import net.imglib2.parallel.Parallelization;
44+
import net.imglib2.parallel.TaskExecutor;
45+
import net.imglib2.parallel.TaskExecutors;
46+
import net.imglib2.util.Cast;
5347
import net.imglib2.util.Intervals;
48+
import net.imglib2.util.Localizables;
5449
import net.imglib2.view.Views;
5550

51+
import java.util.concurrent.ExecutorService;
52+
5653
/**
5754
* This class can be used to implement a separable convolution. It applies a
5855
* {@link LineConvolverFactory} on the given images.
5956
*
6057
* @author Matthias Arzt
6158
*/
62-
public class LineConvolution< T > extends AbstractMultiThreadedConvolution< T >
59+
public class LineConvolution< T > implements Convolution<T>
6360
{
6461
private final LineConvolverFactory< ? super T > factory;
6562

6663
private final int direction;
6764

65+
private ExecutorService executor;
66+
6867
public LineConvolution( final LineConvolverFactory< ? super T > factory, final int direction )
6968
{
7069
this.factory = factory;
7170
this.direction = direction;
7271
}
7372

73+
@Deprecated
74+
@Override
75+
public void setExecutor( ExecutorService executor )
76+
{
77+
this.executor = executor;
78+
}
79+
7480
@Override
7581
public Interval requiredSourceInterval( final Interval targetInterval )
7682
{
@@ -84,104 +90,38 @@ public Interval requiredSourceInterval( final Interval targetInterval )
8490
@Override
8591
public T preferredSourceType( final T targetType )
8692
{
87-
return (T) factory.preferredSourceType( targetType );
93+
return Cast.unchecked( factory.preferredSourceType( targetType ) );
8894
}
8995

9096
@Override
91-
protected void process( final RandomAccessible< ? extends T > source, final RandomAccessibleInterval< ? extends T > target, final ExecutorService executorService, final int numThreads )
97+
public void process( RandomAccessible< ? extends T > source, RandomAccessibleInterval< ? extends T > target )
9298
{
9399
final RandomAccessibleInterval< ? extends T > sourceInterval = Views.interval( source, requiredSourceInterval( target ) );
94100
final long[] sourceMin = Intervals.minAsLongArray( sourceInterval );
95101
final long[] targetMin = Intervals.minAsLongArray( target );
96102

97-
final Supplier< Consumer< Localizable > > actionFactory = () -> {
98-
99-
final RandomAccess< ? extends T > in = sourceInterval.randomAccess();
100-
final RandomAccess< ? extends T > out = target.randomAccess();
101-
final Runnable convolver = factory.getConvolver( in, out, direction, target.dimension( direction ) );
102-
103-
return position -> {
104-
in.setPosition( sourceMin );
105-
out.setPosition( targetMin );
106-
in.move( position );
107-
out.move( position );
108-
convolver.run();
109-
};
110-
};
111-
112103
final long[] dim = Intervals.dimensionsAsLongArray( target );
113104
dim[ direction ] = 1;
114105

115-
final int numTasks = numThreads > 1 ? timesFourAvoidOverflow(numThreads) : 1;
116-
LineConvolution.forEachIntervalElementInParallel( executorService, numTasks, new FinalInterval( dim ), actionFactory );
117-
}
106+
RandomAccessibleInterval< Localizable > positions = Localizables.randomAccessibleInterval( new FinalInterval( dim ) );
107+
TaskExecutor taskExecutor = executor == null ? Parallelization.getTaskExecutor() : TaskExecutors.forExecutorService( executor );
108+
LoopBuilder.setImages( positions ).multiThreaded(taskExecutor).forEachChunk(
109+
chunk -> {
118110

119-
private int timesFourAvoidOverflow( int x )
120-
{
121-
return (int) Math.min((long) x * 4, Integer.MAX_VALUE);
122-
}
111+
final RandomAccess< ? extends T > in = sourceInterval.randomAccess();
112+
final RandomAccess< ? extends T > out = target.randomAccess();
113+
final Runnable convolver = factory.getConvolver( in, out, direction, target.dimension( direction ) );
123114

124-
/**
125-
* {@link #forEachIntervalElementInParallel(ExecutorService, int, Interval, Supplier)}
126-
* executes a given action for each position in a given interval. Therefor
127-
* it starts the specified number of tasks. Each tasks calls the action
128-
* factory once, to get an instance of the action that should be executed.
129-
* The action is then called multiple times by the task.
130-
*
131-
* @param service
132-
* {@link ExecutorService} used to create the tasks.
133-
* @param numTasks
134-
* number of tasks to use.
135-
* @param interval
136-
* interval to iterate over.
137-
* @param actionFactory
138-
* factory that returns the action to be executed.
139-
*/
140-
// TODO: move to a better place
141-
public static void forEachIntervalElementInParallel( final ExecutorService service, final int numTasks, final Interval interval,
142-
final Supplier< Consumer< Localizable > > actionFactory )
143-
{
144-
final long[] min = Intervals.minAsLongArray( interval );
145-
final long[] dim = Intervals.dimensionsAsLongArray( interval );
146-
final long size = Intervals.numElements( dim );
147-
final int boundedNumTasks = (int) Math.max( 1, Math.min(size, numTasks ));
148-
final long taskSize = ( size - 1 ) / boundedNumTasks + 1; // taskSize = roundUp(size / boundedNumTasks);
149-
final ArrayList< Callable< Void > > callables = new ArrayList<>();
115+
chunk.forEachPixel( position -> {
116+
in.setPosition( sourceMin );
117+
out.setPosition( targetMin );
118+
in.move( position );
119+
out.move( position );
120+
convolver.run();
121+
} );
150122

151-
for ( int taskNum = 0; taskNum < boundedNumTasks; ++taskNum )
152-
{
153-
final long myStartIndex = taskNum * taskSize;
154-
final long myEndIndex = Math.min( size, myStartIndex + taskSize );
155-
final Callable< Void > r = () -> {
156-
final Consumer< Localizable > action = actionFactory.get();
157-
final long[] position = new long[ dim.length ];
158-
final Localizable localizable = Point.wrap( position );
159-
for ( long index = myStartIndex; index < myEndIndex; ++index )
160-
{
161-
IntervalIndexer.indexToPositionWithOffset( index, dim, min, position );
162-
action.accept( localizable );
123+
return null;
163124
}
164-
return null;
165-
};
166-
callables.add( r );
167-
}
168-
execute( service, callables );
169-
}
170-
171-
private static void execute( final ExecutorService service, final ArrayList< Callable< Void > > callables )
172-
{
173-
try
174-
{
175-
final List< Future< Void > > futures = service.invokeAll( callables );
176-
for ( final Future< Void > future : futures )
177-
future.get();
178-
}
179-
catch ( final InterruptedException | ExecutionException e )
180-
{
181-
final Throwable cause = e.getCause();
182-
if ( cause instanceof RuntimeException )
183-
throw ( RuntimeException ) cause;
184-
throw new RuntimeException( e );
185-
}
125+
);
186126
}
187127
}

src/main/java/net/imglib2/algorithm/convolution/MultiDimensionConvolution.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ public class MultiDimensionConvolution< T > implements Convolution< T >
5353
{
5454
private ExecutorService executor;
5555

56+
@Deprecated
5657
@Override
5758
public void setExecutor( final ExecutorService executor )
5859
{

src/main/java/net/imglib2/algorithm/gauss3/Gauss3.java

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
import net.imglib2.algorithm.convolution.kernel.Kernel1D;
4444
import net.imglib2.algorithm.convolution.kernel.SeparableKernelConvolution;
4545
import net.imglib2.exception.IncompatibleTypeException;
46+
import net.imglib2.parallel.Parallelization;
4647
import net.imglib2.type.numeric.NumericType;
4748
import net.imglib2.type.numeric.RealType;
4849
import net.imglib2.type.numeric.real.DoubleType;
@@ -126,10 +127,9 @@ public static < S extends NumericType< S >, T extends NumericType< T > > void ga
126127
*/
127128
public static < S extends NumericType< S >, T extends NumericType< T > > void gauss( final double[] sigma, final RandomAccessible< S > source, final RandomAccessibleInterval< T > target ) throws IncompatibleTypeException
128129
{
129-
final int numthreads = Runtime.getRuntime().availableProcessors();
130-
final ExecutorService service = Executors.newFixedThreadPool( numthreads );
131-
gauss( sigma, source, target, service );
132-
service.shutdown();
130+
final double[][] halfkernels = halfkernels( sigma );
131+
final Convolution< NumericType< ? > > convolution = SeparableKernelConvolution.convolution( Kernel1D.symmetric( halfkernels ) );
132+
convolution.process( source, target );
133133
}
134134

135135
/**
@@ -201,10 +201,9 @@ public static < S extends NumericType< S >, T extends NumericType< T > > void ga
201201
*/
202202
public static < S extends NumericType< S >, T extends NumericType< T > > void gauss( final double[] sigma, final RandomAccessible< S > source, final RandomAccessibleInterval< T > target, final ExecutorService service ) throws IncompatibleTypeException
203203
{
204-
final double[][] halfkernels = halfkernels( sigma );
205-
final Convolution< NumericType< ? > > convolution = SeparableKernelConvolution.convolution( Kernel1D.symmetric( halfkernels ) );
206-
convolution.setExecutor( service );
207-
convolution.process( source, target );
204+
Parallelization.runWithExecutor( service,
205+
() -> gauss( sigma, source, target )
206+
);
208207
}
209208

210209
public static double[][] halfkernels( final double[] sigma )

src/test/java/net/imglib2/algorithm/convolution/ConcatenationTest.java

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -37,11 +37,8 @@
3737
import net.imglib2.Interval;
3838
import net.imglib2.RandomAccessible;
3939
import net.imglib2.RandomAccessibleInterval;
40-
import net.imglib2.algorithm.convolution.kernel.SeparableKernelConvolution;
4140
import net.imglib2.img.Img;
42-
import net.imglib2.img.ImgFactory;
4341
import net.imglib2.img.array.ArrayImgs;
44-
import net.imglib2.img.cell.CellImgFactory;
4542
import net.imglib2.loops.LoopBuilder;
4643
import net.imglib2.type.numeric.RealType;
4744
import net.imglib2.type.numeric.integer.IntType;
@@ -51,6 +48,7 @@
5148
import net.imglib2.util.Intervals;
5249
import net.imglib2.view.IntervalView;
5350
import net.imglib2.view.Views;
51+
import org.junit.Ignore;
5452
import org.junit.Test;
5553

5654
import static org.junit.Assert.assertArrayEquals;
@@ -88,6 +86,7 @@ public void testDifferences2()
8886
assertArrayEquals( new int[] { -3, 3 }, targetPixels );
8987
}
9088

89+
@Ignore( "takes to long" )
9190
@Test
9291
public void testHugeImage()
9392
{
@@ -96,7 +95,7 @@ public void testHugeImage()
9695
assertTrue( width * height > Integer.MAX_VALUE );
9796
RandomAccessible< UnsignedByteType > source = ConstantUtils.constantRandomAccessible( new UnsignedByteType(), 2 );
9897
RandomAccessibleInterval< UnsignedByteType > target = ConstantUtils.constantRandomAccessibleInterval(
99-
new UnsignedByteType(), 2, new FinalInterval( width, height ) );
98+
new UnsignedByteType(), new FinalInterval( width, height ) );
10099
double[][] kernels = { { 2 }, { 3 } };
101100
try
102101
{

src/test/java/net/imglib2/algorithm/convolution/LineConvolutionTest.java

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,8 @@
3737
import net.imglib2.RandomAccess;
3838
import net.imglib2.img.Img;
3939
import net.imglib2.img.array.ArrayImgs;
40+
import net.imglib2.parallel.Parallelization;
41+
import net.imglib2.parallel.TaskExecutors;
4042
import net.imglib2.type.numeric.integer.UnsignedByteType;
4143
import net.imglib2.util.Intervals;
4244
import org.junit.Test;
@@ -85,8 +87,11 @@ public void testNumTasksEqualsIntegerMaxValue() {
8587
byte[] result = new byte[ 1 ];
8688
Img< UnsignedByteType > out = ArrayImgs.unsignedBytes( result, result.length );
8789
Img< UnsignedByteType > in = ArrayImgs.unsignedBytes( new byte[] { 1, 2 }, 2 );
88-
final LineConvolution< UnsignedByteType > convolution = new LineConvolution<>( new ForwardDifferenceConvolverFactory(), 0 );
89-
convolution.process( in, out, Executors.newSingleThreadExecutor(), Integer.MAX_VALUE );
90+
Runnable runnable = () -> {
91+
final LineConvolution< UnsignedByteType > convolution = new LineConvolution<>( new ForwardDifferenceConvolverFactory(), 0 );
92+
convolution.process( in, out );
93+
};
94+
Parallelization.runWithExecutor( TaskExecutors.forExecutorServiceAndNumTasks( Executors.newSingleThreadExecutor(), Integer.MAX_VALUE) , runnable );
9095
assertArrayEquals( new byte[] { 1 }, result );
9196
}
9297

0 commit comments

Comments
 (0)