You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
@@ -34,13 +34,14 @@ returns an operator which performs an FFT on Arrays of type T
34
34
* `shape::Tuple` - size of the array to transform
35
35
* (`shift=true`) - if true, fftshifts are performed
36
36
* (`unitary=true`) - if true, FFT is normalized such that it is unitary
37
+
* (`S = Vector{T}`) - type of temporary vector, change to use on GPU
38
+
* (`kwargs...`) - keyword arguments given to fft plan
37
39
"""
38
-
function LinearOperatorCollection.FFTOp(T::Type; shape::NTuple{D,Int64}, shift::Bool=true, unitary::Bool=true, cuda::Bool=false) where D
40
+
function LinearOperatorCollection.FFTOp(T::Type; shape::NTuple{D,Int64}, shift::Bool=true, unitary::Bool=true, S = Array{Complex{real(T)}}, kwargs...) where D
39
41
40
-
#tmpVec = cuda ? CuArray{T}(undef,shape) : Array{Complex{real(T)}}(undef, shape)
41
-
tmpVec =Array{Complex{real(T)}}(undef, shape)
42
-
plan =plan_fft!(tmpVec; flags=FFTW.MEASURE)
43
-
iplan =plan_bfft!(tmpVec; flags=FFTW.MEASURE)
42
+
tmpVec =similar(S(undef, 0), shape...)
43
+
plan =plan_fft!(tmpVec; kwargs...)
44
+
iplan =plan_bfft!(tmpVec; kwargs...)
44
45
45
46
if unitary
46
47
facF =T(1.0/sqrt(prod(shape)))
@@ -50,39 +51,25 @@ function LinearOperatorCollection.FFTOp(T::Type; shape::NTuple{D,Int64}, shift::
50
51
facB =T(1.0)
51
52
end
52
53
53
-
let shape_=shape, plan_=plan, iplan_=iplan, tmpVec_=tmpVec, facF_=facF, facB_=facB
54
+
let shape_=shape, plan_=plan, iplan_=iplan, tmpVec_=tmpVec, facF_=facF, facB_=facB
0 commit comments