Skip to content

Commit 3dd9ecd

Browse files
committed
Chang PartialDerivative to use multi-threaded LoopBuilder
1 parent fe79678 commit 3dd9ecd

File tree

1 file changed

+58
-101
lines changed

1 file changed

+58
-101
lines changed

src/main/java/net/imglib2/algorithm/gradient/PartialDerivative.java

Lines changed: 58 additions & 101 deletions
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,13 @@
1111
* %%
1212
* Redistribution and use in source and binary forms, with or without
1313
* modification, are permitted provided that the following conditions are met:
14-
*
14+
*
1515
* 1. Redistributions of source code must retain the above copyright notice,
1616
* this list of conditions and the following disclaimer.
1717
* 2. Redistributions in binary form must reproduce the above copyright notice,
1818
* this list of conditions and the following disclaimer in the documentation
1919
* and/or other materials provided with the distribution.
20-
*
20+
*
2121
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
2222
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
2323
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
@@ -34,21 +34,16 @@
3434

3535
package net.imglib2.algorithm.gradient;
3636

37-
import java.util.ArrayList;
38-
import java.util.List;
39-
import java.util.concurrent.Callable;
40-
import java.util.concurrent.ExecutionException;
4137
import java.util.concurrent.ExecutorService;
42-
import java.util.concurrent.Future;
4338

44-
import net.imglib2.Cursor;
45-
import net.imglib2.FinalInterval;
4639
import net.imglib2.RandomAccessible;
4740
import net.imglib2.RandomAccessibleInterval;
4841
import net.imglib2.loops.LoopBuilder;
42+
import net.imglib2.parallel.TaskExecutor;
43+
import net.imglib2.parallel.Parallelization;
44+
import net.imglib2.parallel.TaskExecutors;
4945
import net.imglib2.type.numeric.NumericType;
5046
import net.imglib2.util.Intervals;
51-
import net.imglib2.view.IntervalView;
5247
import net.imglib2.view.Views;
5348

5449
/**
@@ -62,105 +57,67 @@ public class PartialDerivative
6257
{
6358
// nice version...
6459
/**
65-
* Compute the partial derivative (central difference approximation) of source
66-
* in a particular dimension:
67-
* {@code d_f( x ) = ( f( x + e ) - f( x - e ) ) / 2},
68-
* where {@code e} is the unit vector along that dimension.
69-
*
70-
* @param source
71-
* source image, has to provide valid data in the interval of the
72-
* gradient image plus a one pixel border in dimension.
73-
* @param gradient
74-
* output image
75-
* @param dimension
76-
* along which dimension the partial derivatives are computed
60+
* @deprecated
61+
* Use {@link #gradientCentralDifference(RandomAccessible, RandomAccessibleInterval, int)}
62+
* instead.
7763
*/
64+
@Deprecated
7865
public static < T extends NumericType< T > > void gradientCentralDifference2( final RandomAccessible< T > source, final RandomAccessibleInterval< T > gradient, final int dimension )
7966
{
80-
final Cursor< T > front = Views.flatIterable( Views.interval( source, Intervals.translate( gradient, 1, dimension ) ) ).cursor();
81-
final Cursor< T > back = Views.flatIterable( Views.interval( source, Intervals.translate( gradient, -1, dimension ) ) ).cursor();
82-
83-
for ( final T t : Views.flatIterable( gradient ) )
84-
{
85-
t.set( front.next() );
86-
t.sub( back.next() );
87-
t.mul( 0.5 );
88-
}
67+
gradientCentralDifference( source, gradient, dimension );
8968
}
9069

9170
// parallel version...
9271
/**
93-
* Compute the partial derivative (central difference approximation) of source
94-
* in a particular dimension:
95-
* {@code d_f( x ) = ( f( x + e ) - f( x - e ) ) / 2},
96-
* where {@code e} is the unit vector along that dimension.
97-
*
98-
* @param source
99-
* source image, has to provide valid data in the interval of the
100-
* gradient image plus a one pixel border in dimension.
101-
* @param gradient
102-
* output image
103-
* @param dimension
104-
* along which dimension the partial derivatives are computed
105-
* @param nTasks
106-
* Number of tasks for gradient computation.
107-
* @param es
108-
* {@link ExecutorService} providing workers for gradient
109-
* computation. Service is managed (created, shutdown) by caller.
72+
* @deprecated
73+
* Use {@link #gradientCentralDifference(RandomAccessible, RandomAccessibleInterval, int)}
74+
* instead.
75+
* <p>
76+
* Read {@link Parallelization} to learn how to run the method multi-threaded. Here is an example:
77+
* <p>
78+
* <pre>
79+
* {@code
80+
* TaskExecutor taskExecutor = TaskExecutors.forExecutorServiceAndNumTasks( executorService, numTasks );
81+
* Parallelization.runWithExecutor( taskExecutor, () -> {
82+
* gradientCentralDerivativeParallel( source, result, dimension );
83+
* } );
84+
* }
85+
* </pre>
11086
*/
87+
@Deprecated
11188
public static < T extends NumericType< T > > void gradientCentralDifferenceParallel(
11289
final RandomAccessible< T > source,
113-
final RandomAccessibleInterval< T > gradient,
90+
final RandomAccessibleInterval< T > result,
11491
final int dimension,
11592
final int nTasks,
116-
final ExecutorService es ) throws InterruptedException, ExecutionException
93+
final ExecutorService es )
11794
{
118-
final int nDim = source.numDimensions();
119-
if ( nDim < 2 )
120-
{
121-
gradientCentralDifference( source, gradient, dimension );
122-
return;
123-
}
124-
125-
long dimensionMax = Long.MIN_VALUE;
126-
int dimensionArgMax = -1;
127-
128-
for ( int d = 0; d < nDim; ++d )
129-
{
130-
final long size = gradient.dimension( d );
131-
if ( d != dimension && size > dimensionMax )
132-
{
133-
dimensionMax = size;
134-
dimensionArgMax = d;
135-
}
136-
}
137-
138-
final long stepSize = Math.max( dimensionMax / nTasks, 1 );
139-
final long stepSizeMinusOne = stepSize - 1;
140-
final long min = gradient.min( dimensionArgMax );
141-
final long max = gradient.max( dimensionArgMax );
142-
143-
final ArrayList< Callable< Void > > tasks = new ArrayList<>();
144-
for ( long currentMin = min, minZeroBase = 0; minZeroBase < dimensionMax; currentMin += stepSize, minZeroBase += stepSize )
145-
{
146-
final long currentMax = Math.min( currentMin + stepSizeMinusOne, max );
147-
final long[] mins = new long[ nDim ];
148-
final long[] maxs = new long[ nDim ];
149-
gradient.min( mins );
150-
gradient.max( maxs );
151-
mins[ dimensionArgMax ] = currentMin;
152-
maxs[ dimensionArgMax ] = currentMax;
153-
final IntervalView< T > currentInterval = Views.interval( gradient, new FinalInterval( mins, maxs ) );
154-
tasks.add( () -> {
155-
gradientCentralDifference( source, currentInterval, dimension );
156-
return null;
157-
} );
158-
}
159-
160-
final List< Future< Void > > futures = es.invokeAll( tasks );
95+
TaskExecutor taskExecutor = TaskExecutors.forExecutorServiceAndNumTasks( es, nTasks );
96+
Parallelization.runWithExecutor( taskExecutor, () -> {
97+
gradientCentralDerivativeParallel( source, result, dimension );
98+
} );
99+
}
161100

162-
for ( final Future< Void > f : futures )
163-
f.get();
101+
/**
102+
* @deprecated
103+
* Use {@link #gradientCentralDifference(RandomAccessible, RandomAccessibleInterval, int)}
104+
* instead.
105+
* <p>
106+
* Read {@link Parallelization} to learn how to run the method multi-threaded. Here is an example:
107+
* <p>
108+
* <pre>
109+
* {@code
110+
* TaskExecutor taskExecutor = TaskExecutors.forExecutorServiceAndNumTasks( executorService, numTasks );
111+
* Parallelization.runWithExecutor( taskExecutor, () -> {
112+
* gradientCentralDerivativeParallel( source, result, dimension );
113+
* } );
114+
* }
115+
* </pre>
116+
*/
117+
private static <T extends NumericType< T >> void gradientCentralDerivativeParallel( RandomAccessible<T> source,
118+
RandomAccessibleInterval<T> result, int dimension )
119+
{
120+
gradientCentralDifference( source, result, dimension );
164121
}
165122

166123
// fast version
@@ -181,10 +138,10 @@ public static < T extends NumericType< T > > void gradientCentralDifferenceParal
181138
public static < T extends NumericType< T > > void gradientCentralDifference( final RandomAccessible< T > source,
182139
final RandomAccessibleInterval< T > result, final int dimension )
183140
{
184-
final RandomAccessibleInterval< T > back = Views.interval( source, Intervals.translate( result, -1, dimension ) );
185-
final RandomAccessibleInterval< T > front = Views.interval( source, Intervals.translate( result, 1, dimension ) );
141+
final RandomAccessibleInterval<T> back = Views.interval( source, Intervals.translate( result, -1, dimension ) );
142+
final RandomAccessibleInterval<T> front = Views.interval( source, Intervals.translate( result, 1, dimension ) );
186143

187-
LoopBuilder.setImages( result, back, front ).forEachPixel( ( r, b, f ) -> {
144+
LoopBuilder.setImages( result, back, front ).multiThreaded().forEachPixel( ( r, b, f ) -> {
188145
r.set( f );
189146
r.sub( b );
190147
r.mul( 0.5 );
@@ -207,7 +164,7 @@ public static < T extends NumericType< T > > void gradientBackwardDifference( fi
207164
final RandomAccessibleInterval< T > back = Views.interval( source, Intervals.translate( result, -1, dimension ) );
208165
final RandomAccessibleInterval< T > front = Views.interval( source, result );
209166

210-
LoopBuilder.setImages( result, back, front ).forEachPixel( ( r, b, f ) -> {
167+
LoopBuilder.setImages( result, back, front ).multiThreaded().forEachPixel( ( r, b, f ) -> {
211168
r.set( f );
212169
r.sub( b );
213170
} );
@@ -217,7 +174,7 @@ public static < T extends NumericType< T > > void gradientBackwardDifference( fi
217174
* Compute the forward difference of source in a particular dimension:
218175
* {@code d_f( x ) = ( f( x + e ) - f( x ) )}
219176
* where {@code e} is the unit vector along that dimension
220-
177+
221178
* @param source source image, has to provide valid data in the interval of
222179
* the gradient image plus a one pixel border in dimension.
223180
* @param result output image
@@ -229,7 +186,7 @@ public static < T extends NumericType< T > > void gradientForwardDifference( fin
229186
final RandomAccessibleInterval< T > back = Views.interval( source, result );
230187
final RandomAccessibleInterval< T > front = Views.interval( source, Intervals.translate( result, 1, dimension ) );
231188

232-
LoopBuilder.setImages( result, back, front ).forEachPixel( ( r, b, f ) -> {
189+
LoopBuilder.setImages( result, back, front ).multiThreaded().forEachPixel( ( r, b, f ) -> {
233190
r.set( f );
234191
r.sub( b );
235192
} );

0 commit comments

Comments
 (0)