@@ -117,7 +117,6 @@ def _complex_nd_fft(
117
117
out ,
118
118
forward ,
119
119
in_place ,
120
- c2c ,
121
120
axes ,
122
121
batch_fft ,
123
122
* ,
@@ -126,34 +125,38 @@ def _complex_nd_fft(
126
125
"""Computes complex-to-complex FFT of the input N-D array."""
127
126
128
127
len_axes = len (axes )
129
- # OneMKL supports up to 3-dimensional FFT on GPU
130
- # repeated axis in OneMKL FFT is not allowed
128
+ # oneMKL supports up to 3-dimensional FFT on GPU
129
+ # repeated axis in oneMKL FFT is not allowed
131
130
if len_axes > 3 or len (set (axes )) < len_axes :
132
131
axes_chunk , shape_chunk = _extract_axes_chunk (
133
132
axes , s , chunk_size = 3 , reversed_axes = reversed_axes
134
133
)
134
+
135
+ # We try to use in-place calculations where possible, which is
136
+ # everywhere except when the size changes after the first iteration.
137
+ size_changes = [axis for axis , n in zip (axes , s ) if a .shape [axis ] != n ]
138
+
139
+ # cannot use out in the intermediate steps if size changes
140
+ res = None if size_changes else out
141
+
135
142
for i , (s_chunk , a_chunk ) in enumerate (zip (shape_chunk , axes_chunk )):
136
143
a = _truncate_or_pad (a , shape = s_chunk , axes = a_chunk )
137
- # if out is used in an intermediate step, it will have memory
138
- # overlap with input and cannot be used in the final step (a new
139
- # result array will be created for the final step), so there is no
140
- # benefit in using out in an intermediate step
141
- if i == len (axes_chunk ) - 1 :
142
- tmp_out = out
143
- else :
144
- tmp_out = None
144
+ # if size_changes, out cannot be used in intermediate steps
145
+ if size_changes and i == len (axes_chunk ) - 1 :
146
+ res = out
145
147
146
148
a = _fft (
147
149
a ,
148
150
norm = norm ,
149
- out = tmp_out ,
151
+ out = res ,
150
152
forward = forward ,
151
- # TODO: in-place FFT is only implemented for c2c, see SAT-7154
152
- in_place = in_place and c2c ,
153
- c2c = c2c ,
153
+ in_place = in_place ,
154
+ c2c = True ,
154
155
axes = a_chunk ,
155
156
)
156
-
157
+ if not size_changes :
158
+ # Default output for next iteration.
159
+ res = a
157
160
return a
158
161
159
162
a = _truncate_or_pad (a , s , axes )
@@ -165,9 +168,8 @@ def _complex_nd_fft(
165
168
norm = norm ,
166
169
out = out ,
167
170
forward = forward ,
168
- # TODO: in-place FFT is only implemented for c2c, see SAT-7154
169
- in_place = in_place and c2c ,
170
- c2c = c2c ,
171
+ in_place = in_place ,
172
+ c2c = True ,
171
173
axes = axes ,
172
174
batch_fft = batch_fft ,
173
175
)
@@ -198,7 +200,7 @@ def _compute_result(dsc, a, out, forward, c2c, out_strides):
198
200
res_usm = dpnp .get_usm_ndarray (out )
199
201
result = out
200
202
else :
201
- # Result array that is used in OneMKL must have the exact same
203
+ # Result array that is used in oneMKL must have the exact same
202
204
# stride as input array
203
205
204
206
if c2c : # c2c FFT
@@ -277,9 +279,9 @@ def _copy_array(x, complex_input):
277
279
dtype = x .dtype
278
280
copy_flag = False
279
281
if numpy .min (x .strides ) < 0 :
280
- # negative stride is not allowed in OneMKL FFT
282
+ # negative stride is not allowed in oneMKL FFT
281
283
# TODO: support for negative strides will be added in the future
282
- # versions of OneMKL , see discussion in MKLD-17597
284
+ # versions of oneMKL , see discussion in MKLD-17597
283
285
copy_flag = True
284
286
285
287
if complex_input and not dpnp .issubdtype (dtype , dpnp .complexfloating ):
@@ -388,6 +390,9 @@ def _fft(a, norm, out, forward, in_place, c2c, axes, batch_fft=True):
388
390
389
391
index = 0
390
392
fft_1d = isinstance (axes , int )
393
+ if not in_place and out is not None :
394
+ # if input and output are the same array, use in-place FFT
395
+ in_place = dpnp .are_same_logical_tensors (a , out )
391
396
if batch_fft :
392
397
len_axes = 1 if fft_1d else len (axes )
393
398
local_axes = numpy .arange (- len_axes , 0 )
@@ -627,9 +632,6 @@ def dpnp_fft(a, forward, real, n=None, axis=-1, norm=None, out=None):
627
632
_validate_out_keyword (a , out , (n ,), (axis ,), c2c , c2r , r2c )
628
633
# if input array is copied, in-place FFT can be used
629
634
a , in_place = _copy_array (a , c2c or c2r )
630
- if not in_place and out is not None :
631
- # if input is also given for out, in-place FFT can be used
632
- in_place = dpnp .are_same_logical_tensors (a , out )
633
635
634
636
if a .size == 0 :
635
637
return dpnp .get_result_array (a , out = out , casting = "same_kind" )
@@ -695,31 +697,30 @@ def dpnp_fftn(a, forward, real, s=None, axes=None, norm=None, out=None):
695
697
)
696
698
697
699
if r2c :
698
- # a 1D real-to-complext FFT is performed on the last axis and then
700
+ size_changes = [axis for axis , n in zip (axes , s ) if a .shape [axis ] != n ]
701
+ # cannot use out in the intermediate steps if size changes
702
+ res = None if size_changes else out
703
+
704
+ # a 1D real-to-complex FFT is performed on the last axis and then
699
705
# an N-D complex-to-complex FFT over the remaining axes
700
706
a = _truncate_or_pad (a , (s [- 1 ],), (axes [- 1 ],))
701
707
a = _fft (
702
708
a ,
703
709
norm = norm ,
704
- # if out is used in an intermediate step, it will have memory
705
- # overlap with input and cannot be used in the final step (a new
706
- # result array will be created for the final step), so there is no
707
- # benefit in using out in an intermediate step
708
- out = None ,
710
+ out = res ,
709
711
forward = forward ,
710
- in_place = in_place and c2c ,
711
- c2c = c2c ,
712
+ in_place = False ,
713
+ c2c = False ,
712
714
axes = axes [- 1 ],
713
715
batch_fft = a .ndim != 1 ,
714
716
)
715
717
return _complex_nd_fft (
716
718
a ,
717
- s = s ,
719
+ s = s [: - 1 ] ,
718
720
norm = norm ,
719
721
out = out ,
720
722
forward = forward ,
721
723
in_place = in_place ,
722
- c2c = True ,
723
724
axes = axes [:- 1 ],
724
725
batch_fft = a .ndim != len_axes - 1 ,
725
726
)
@@ -729,29 +730,25 @@ def dpnp_fftn(a, forward, real, s=None, axes=None, norm=None, out=None):
729
730
# last one then a 1D complex-to-real FFT is performed on the last axis
730
731
a = _complex_nd_fft (
731
732
a ,
732
- s = s ,
733
+ s = s [: - 1 ] ,
733
734
norm = norm ,
734
735
# out has real dtype and cannot be used in intermediate steps
735
736
out = None ,
736
737
forward = forward ,
737
738
in_place = in_place ,
738
- c2c = True ,
739
739
axes = axes [:- 1 ],
740
740
batch_fft = a .ndim != len_axes - 1 ,
741
741
reversed_axes = False ,
742
742
)
743
743
a = _truncate_or_pad (a , (s [- 1 ],), (axes [- 1 ],))
744
- if c2r :
745
- a = _make_array_hermitian (
746
- a , axes [- 1 ], dpnp .are_same_logical_tensors (a , a_orig )
747
- )
748
- return _fft (
749
- a , norm , out , forward , in_place and c2c , c2c , axes [- 1 ], a .ndim != 1
744
+ a = _make_array_hermitian (
745
+ a , axes [- 1 ], dpnp .are_same_logical_tensors (a , a_orig )
750
746
)
747
+ return _fft (a , norm , out , forward , False , False , axes [- 1 ], a .ndim != 1 )
751
748
752
749
# c2c
753
750
return _complex_nd_fft (
754
- a , s , norm , out , forward , in_place , c2c , axes , a .ndim != len_axes
751
+ a , s , norm , out , forward , in_place , axes , a .ndim != len_axes
755
752
)
756
753
757
754
0 commit comments