Skip to content

Commit d9c2893

Browse files
authored
Merge pull request #78 from imglib/fix-issue-77
Fix issue 77
2 parents b4b2c1a + cfb775a commit d9c2893

File tree

4 files changed

+64
-18
lines changed

4 files changed

+64
-18
lines changed

src/main/java/net/imglib2/algorithm/convolution/AbstractMultiThreadedConvolution.java

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ final public void process( final RandomAccessible< ? extends T > source, final R
3939
{
4040
if ( executor == null )
4141
{
42-
final int numThreads = suggestNumThreads();
42+
final int numThreads = Runtime.getRuntime().availableProcessors();
4343
final ExecutorService executor = Executors.newFixedThreadPool( numThreads );
4444
try
4545
{
@@ -56,15 +56,13 @@ final public void process( final RandomAccessible< ? extends T > source, final R
5656
}
5757
}
5858

59-
private int getNumThreads( final ExecutorService executor )
59+
static int getNumThreads( final ExecutorService executor )
6060
{
61-
if ( executor instanceof ThreadPoolExecutor )
62-
return ( ( ThreadPoolExecutor ) executor ).getMaximumPoolSize();
63-
return suggestNumThreads();
61+
int maxPoolSize = ( executor instanceof ThreadPoolExecutor ) ?
62+
( ( ThreadPoolExecutor ) executor ).getMaximumPoolSize() :
63+
Integer.MAX_VALUE;
64+
int availableProcessors = Runtime.getRuntime().availableProcessors();
65+
return Math.max(1, Math.min(availableProcessors, maxPoolSize));
6466
}
6567

66-
private int suggestNumThreads()
67-
{
68-
return Runtime.getRuntime().availableProcessors();
69-
}
7068
}

src/main/java/net/imglib2/algorithm/convolution/LineConvolution.java

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,6 @@
1616
import net.imglib2.RandomAccess;
1717
import net.imglib2.RandomAccessible;
1818
import net.imglib2.RandomAccessibleInterval;
19-
import net.imglib2.type.Type;
20-
import net.imglib2.type.numeric.RealType;
21-
import net.imglib2.type.numeric.real.DoubleType;
22-
import net.imglib2.type.numeric.real.FloatType;
2319
import net.imglib2.util.IntervalIndexer;
2420
import net.imglib2.util.Intervals;
2521
import net.imglib2.view.Views;
@@ -83,10 +79,15 @@ protected void process( final RandomAccessible< ? extends T > source, final Rand
8379
final long[] dim = Intervals.dimensionsAsLongArray( target );
8480
dim[ direction ] = 1;
8581

86-
final int numTasks = numThreads > 1 ? numThreads * 4 : 1;
82+
final int numTasks = numThreads > 1 ? timesFourAvoidOverflow(numThreads) : 1;
8783
LineConvolution.forEachIntervalElementInParallel( executorService, numTasks, new FinalInterval( dim ), actionFactory );
8884
}
8985

86+
private int timesFourAvoidOverflow( int x )
87+
{
88+
return (int) Math.min((long) x * 4, Integer.MAX_VALUE);
89+
}
90+
9091
/**
9192
* {@link #forEachIntervalElementInParallel(ExecutorService, int, Interval, Supplier)}
9293
* executes a given action for each position in a given interval. Therefor
@@ -110,14 +111,14 @@ public static void forEachIntervalElementInParallel( final ExecutorService servi
110111
final long[] min = Intervals.minAsLongArray( interval );
111112
final long[] dim = Intervals.dimensionsAsLongArray( interval );
112113
final long size = Intervals.numElements( dim );
113-
final long endIndex = size;
114-
final long taskSize = ( size + numTasks - 1 ) / numTasks; // round up
114+
final int boundedNumTasks = (int) Math.max( 1, Math.min(size, numTasks ));
115+
final long taskSize = ( size - 1 ) / boundedNumTasks + 1; // taskSize = roundUp(size / boundedNumTasks);
115116
final ArrayList< Callable< Void > > callables = new ArrayList<>();
116117

117-
for ( int taskNum = 0; taskNum < numTasks; ++taskNum )
118+
for ( int taskNum = 0; taskNum < boundedNumTasks; ++taskNum )
118119
{
119120
final long myStartIndex = taskNum * taskSize;
120-
final long myEndIndex = Math.min( endIndex, myStartIndex + taskSize );
121+
final long myEndIndex = Math.min( size, myStartIndex + taskSize );
121122
final Callable< Void > r = () -> {
122123
final Consumer< Localizable > action = actionFactory.get();
123124
final long[] position = new long[ dim.length ];
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
package net.imglib2.algorithm.convolution;
2+
3+
import java.util.concurrent.ExecutorService;
4+
import java.util.concurrent.Executors;
5+
6+
import org.junit.Test;
7+
8+
import static junit.framework.TestCase.assertTrue;
9+
import static org.junit.Assert.assertArrayEquals;
10+
import static org.junit.Assert.assertEquals;
11+
import static org.junit.Assert.assertNotEquals;
12+
13+
/**
14+
* Tests {@link AbstractMultiThreadedConvolution}
15+
*
16+
* @author Matthias Arzt
17+
*/
18+
public class AbstractMultiThreadedConvolutionTest
19+
{
20+
@Test
21+
public void testSuggestNumTasksFixedThreadPool() {
22+
if ( Runtime.getRuntime().availableProcessors() < 3 )
23+
return;
24+
final ExecutorService executor = Executors.newFixedThreadPool( 3 );
25+
int result = AbstractMultiThreadedConvolution.getNumThreads( executor );
26+
assertEquals( 3, result );
27+
}
28+
29+
@Test
30+
public void testSuggestNumTasksCachedThreadPool() {
31+
final ExecutorService executor = Executors.newCachedThreadPool();
32+
int result = AbstractMultiThreadedConvolution.getNumThreads( executor );
33+
assertTrue( Runtime.getRuntime().availableProcessors() >= result );
34+
}
35+
}

src/test/java/net/imglib2/algorithm/convolution/LineConvolutionTest.java

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
import net.imglib2.util.Intervals;
99
import org.junit.Test;
1010

11+
import java.util.concurrent.Executors;
12+
1113
import static org.junit.Assert.assertArrayEquals;
1214
import static org.junit.Assert.assertTrue;
1315

@@ -45,6 +47,16 @@ public void testConvolve()
4547
assertArrayEquals( expected, result );
4648
}
4749

50+
@Test
51+
public void testNumTasksEqualsIntegerMaxValue() {
52+
byte[] result = new byte[ 1 ];
53+
Img< UnsignedByteType > out = ArrayImgs.unsignedBytes( result, result.length );
54+
Img< UnsignedByteType > in = ArrayImgs.unsignedBytes( new byte[] { 1, 2 }, 2 );
55+
final LineConvolution< UnsignedByteType > convolution = new LineConvolution<>( new ForwardDifferenceConvolverFactory(), 0 );
56+
convolution.process( in, out, Executors.newSingleThreadExecutor(), Integer.MAX_VALUE );
57+
assertArrayEquals( new byte[] { 1 }, result );
58+
}
59+
4860
static class ForwardDifferenceConvolverFactory implements LineConvolverFactory< UnsignedByteType >
4961
{
5062

0 commit comments

Comments
 (0)