33using AbstractFFTs
44using AbstractFFTs: Plan
55using ChainRulesTestUtils
6+ using ChainRulesCore: NoTangent
67
78using LinearAlgebra
89using Random
197198 @test @inferred (f9 (plan_fft (zeros (10 ), 1 ), 10 )) == 1 / 10
198199end
199200
201+ @testset " output size" begin
202+ @testset " complex fft output size" begin
203+ for x in (randn (3 ), randn (3 , 4 ), randn (3 , 4 , 5 ))
204+ N = ndims (x)
205+ y = randn (size (x))
206+ for dims in unique ((1 , 1 : N, N))
207+ P = plan_fft (x, dims)
208+ @test AbstractFFTs. output_size (P) == size (x)
209+ @test AbstractFFTs. output_size (P' ) == size (x)
210+ Pinv = plan_ifft (x)
211+ @test AbstractFFTs. output_size (Pinv) == size (x)
212+ @test AbstractFFTs. output_size (Pinv' ) == size (x)
213+ end
214+ end
215+ end
216+ @testset " real fft output size" begin
217+ for x in (randn (3 ), randn (4 ), randn (3 , 4 ), randn (3 , 4 , 5 )) # test odd and even lengths
218+ N = ndims (x)
219+ for dims in unique ((1 , 1 : N, N))
220+ P = plan_rfft (x, dims)
221+ Px_sz = size (P * x)
222+ @test AbstractFFTs. output_size (P) == Px_sz
223+ @test AbstractFFTs. output_size (P' ) == size (x)
224+ y = randn (Px_sz) .+ randn (Px_sz) * im
225+ Pinv = plan_irfft (y, size (x)[first (dims)], dims)
226+ @test AbstractFFTs. output_size (Pinv) == size (Pinv * y)
227+ @test AbstractFFTs. output_size (Pinv' ) == size (y)
228+ end
229+ end
230+ end
231+ end
232+
233+ @testset " adjoint" begin
234+ @testset " complex fft adjoint" begin
235+ for x in (randn (3 ), randn (3 , 4 ), randn (3 , 4 , 5 ))
236+ N = ndims (x)
237+ y = randn (size (x))
238+ for dims in unique ((1 , 1 : N, N))
239+ P = plan_fft (x, dims)
240+ @test (P' )' * x == P * x # test adjoint of adjoint
241+ @test size (P' ) == AbstractFFTs. output_size (P) # test size of adjoint
242+ @test dot (y, P * x) ≈ dot (P' * y, x) # test validity of adjoint
243+ @test_broken dot (y, P \ x) ≈ dot (P' \ y, x)
244+ Pinv = plan_ifft (y)
245+ @test (Pinv' )' * y == Pinv * y
246+ @test size (Pinv' ) == AbstractFFTs. output_size (Pinv)
247+ @test dot (x, Pinv * y) ≈ dot (Pinv' * x, y)
248+ @test_broken dot (x, Pinv \ y) ≈ dot (Pinv' \ x, y)
249+ end
250+ end
251+ end
252+ @testset " real fft adjoint" begin
253+ for x in (randn (3 ), randn (4 ), randn (3 , 4 ), randn (3 , 4 , 5 )) # test odd and even lengths
254+ N = ndims (x)
255+ for dims in unique ((1 , 1 : N, N))
256+ P = plan_rfft (x, dims)
257+ y_real = randn (size (P * x))
258+ y_imag = randn (size (P * x))
259+ y = y_real .+ y_imag .* im
260+ @test (P' )' * x == P * x
261+ @test size (P' ) == AbstractFFTs. output_size (P)
262+ @test dot (y_real, real .(P * x)) + dot (y_imag, imag .(P * x)) ≈ dot (P' * y, x)
263+ @test_broken dot (y_real, real .(P \ x)) + dot (y_imag, imag .(P \ x)) ≈ dot (P' * y, x)
264+ Pinv = plan_irfft (y, size (x)[first (dims)], dims)
265+ @test (Pinv' )' * y == Pinv * y
266+ @test size (Pinv' ) == AbstractFFTs. output_size (Pinv)
267+ @test dot (x, Pinv * y) ≈ dot (y_real, real .(Pinv' * x)) + dot (y_imag, imag .(Pinv' * x))
268+ @test_broken dot (x, Pinv \ y) ≈ dot (y_real, real .(Pinv' \ x)) + dot (y_imag, imag .(Pinv' \ x))
269+ end
270+ end
271+ end
272+ end
273+
200274@testset " ChainRules" begin
201275 @testset " shift functions" begin
202276 for x in (randn (3 ), randn (3 , 4 ), randn (3 , 4 , 5 ))
@@ -218,20 +292,31 @@ end
218292 end
219293
220294 @testset " fft" begin
221- for x in (randn (3 ), randn (3 , 4 ), randn (3 , 4 , 5 ))
295+ for x in (randn (2 ), randn (2 , 3 ), randn (3 , 4 , 5 ))
222296 N = ndims (x)
223297 complex_x = complex .(x)
224298 for dims in unique ((1 , 1 : N, N))
299+ # fft, ifft, bfft
225300 for f in (fft, ifft, bfft)
226301 test_frule (f, x, dims)
227302 test_rrule (f, x, dims)
228303 test_frule (f, complex_x, dims)
229304 test_rrule (f, complex_x, dims)
230305 end
306+ for pf in (plan_fft, plan_ifft, plan_bfft)
307+ test_frule (* , pf (x, dims) ⊢ NoTangent (), x)
308+ test_rrule (* , pf (x, dims) ⊢ NoTangent (), x)
309+ test_frule (* , pf (complex_x, dims) ⊢ NoTangent (), complex_x)
310+ test_rrule (* , pf (complex_x, dims) ⊢ NoTangent (), complex_x)
311+ end
231312
313+ # rfft
232314 test_frule (rfft, x, dims)
233315 test_rrule (rfft, x, dims)
316+ test_frule (* , plan_rfft (x, dims) ⊢ NoTangent (), x)
317+ test_rrule (* , plan_rfft (x, dims) ⊢ NoTangent (), x)
234318
319+ # irfft, brfft
235320 for f in (irfft, brfft)
236321 for d in (2 * size (x, first (dims)) - 1 , 2 * size (x, first (dims)) - 2 )
237322 test_frule (f, x, d, dims)
240325 test_rrule (f, complex_x, d, dims)
241326 end
242327 end
328+ for pf in (plan_irfft, plan_brfft)
329+ for d in (2 * size (x, first (dims)) - 1 , 2 * size (x, first (dims)) - 2 )
330+ test_frule (* , pf (complex_x, d, dims) ⊢ NoTangent (), complex_x)
331+ test_rrule (* , pf (complex_x, d, dims) ⊢ NoTangent (), complex_x)
332+ end
333+ end
243334 end
244335 end
245336 end
0 commit comments