Skip to content

Commit 83f375a

Browse files
committed
use input as out in ND FFT when possible
1 parent 0aca8ab commit 83f375a

File tree

2 files changed

+42
-44
lines changed

2 files changed

+42
-44
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
3232
* Improved documentation of the `dpnp.ndarray` class and added a page with a description of supported constants [#2422](https://github.com/IntelPython/dpnp/pull/2422)
3333
* Updated `dpnp.size` to accept tuple of ints for `axes` argument [#2536](https://github.com/IntelPython/dpnp/pull/2536)
3434
* Replaced the `ci` section in `.pre-commit-config.yaml` with a new GitHub workflow that runs on a schedule to autoupdate the `pre-commit` configuration [#2542](https://github.com/IntelPython/dpnp/pull/2542)
35+
* Updated the FFT module to perform in-place FFT in the intermediate steps of N-D FFT [#2543](https://github.com/IntelPython/dpnp/pull/2543)
3536

3637
### Deprecated
3738

dpnp/fft/dpnp_utils_fft.py

Lines changed: 41 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,6 @@ def _complex_nd_fft(
117117
out,
118118
forward,
119119
in_place,
120-
c2c,
121120
axes,
122121
batch_fft,
123122
*,
@@ -126,34 +125,38 @@ def _complex_nd_fft(
126125
"""Computes complex-to-complex FFT of the input N-D array."""
127126

128127
len_axes = len(axes)
129-
# OneMKL supports up to 3-dimensional FFT on GPU
130-
# repeated axis in OneMKL FFT is not allowed
128+
# oneMKL supports up to 3-dimensional FFT on GPU
129+
# repeated axis in oneMKL FFT is not allowed
131130
if len_axes > 3 or len(set(axes)) < len_axes:
132131
axes_chunk, shape_chunk = _extract_axes_chunk(
133132
axes, s, chunk_size=3, reversed_axes=reversed_axes
134133
)
134+
135+
# We try to use in-place calculations where possible, which is
136+
# everywhere except when the size changes after the first iteration.
137+
size_changes = [axis for axis, n in zip(axes, s) if a.shape[axis] != n]
138+
139+
# cannot use out in the intermediate steps if size changes
140+
res = None if size_changes else out
141+
135142
for i, (s_chunk, a_chunk) in enumerate(zip(shape_chunk, axes_chunk)):
136143
a = _truncate_or_pad(a, shape=s_chunk, axes=a_chunk)
137-
# if out is used in an intermediate step, it will have memory
138-
# overlap with input and cannot be used in the final step (a new
139-
# result array will be created for the final step), so there is no
140-
# benefit in using out in an intermediate step
141-
if i == len(axes_chunk) - 1:
142-
tmp_out = out
143-
else:
144-
tmp_out = None
144+
# if size_changes, out cannot be used in intermediate steps
145+
if size_changes and i == len(axes_chunk) - 1:
146+
res = out
145147

146148
a = _fft(
147149
a,
148150
norm=norm,
149-
out=tmp_out,
151+
out=res,
150152
forward=forward,
151-
# TODO: in-place FFT is only implemented for c2c, see SAT-7154
152-
in_place=in_place and c2c,
153-
c2c=c2c,
153+
in_place=in_place,
154+
c2c=True,
154155
axes=a_chunk,
155156
)
156-
157+
if not size_changes:
158+
# Default output for next iteration.
159+
res = a
157160
return a
158161

159162
a = _truncate_or_pad(a, s, axes)
@@ -165,9 +168,8 @@ def _complex_nd_fft(
165168
norm=norm,
166169
out=out,
167170
forward=forward,
168-
# TODO: in-place FFT is only implemented for c2c, see SAT-7154
169-
in_place=in_place and c2c,
170-
c2c=c2c,
171+
in_place=in_place,
172+
c2c=True,
171173
axes=axes,
172174
batch_fft=batch_fft,
173175
)
@@ -198,7 +200,7 @@ def _compute_result(dsc, a, out, forward, c2c, out_strides):
198200
res_usm = dpnp.get_usm_ndarray(out)
199201
result = out
200202
else:
201-
# Result array that is used in OneMKL must have the exact same
203+
# Result array that is used in oneMKL must have the exact same
202204
# stride as input array
203205

204206
if c2c: # c2c FFT
@@ -277,9 +279,9 @@ def _copy_array(x, complex_input):
277279
dtype = x.dtype
278280
copy_flag = False
279281
if numpy.min(x.strides) < 0:
280-
# negative stride is not allowed in OneMKL FFT
282+
# negative stride is not allowed in oneMKL FFT
281283
# TODO: support for negative strides will be added in the future
282-
# versions of OneMKL, see discussion in MKLD-17597
284+
# versions of oneMKL, see discussion in MKLD-17597
283285
copy_flag = True
284286

285287
if complex_input and not dpnp.issubdtype(dtype, dpnp.complexfloating):
@@ -388,6 +390,9 @@ def _fft(a, norm, out, forward, in_place, c2c, axes, batch_fft=True):
388390

389391
index = 0
390392
fft_1d = isinstance(axes, int)
393+
if not in_place and out is not None:
394+
# if input and output are the same array, use in-place FFT
395+
in_place = dpnp.are_same_logical_tensors(a, out)
391396
if batch_fft:
392397
len_axes = 1 if fft_1d else len(axes)
393398
local_axes = numpy.arange(-len_axes, 0)
@@ -627,9 +632,6 @@ def dpnp_fft(a, forward, real, n=None, axis=-1, norm=None, out=None):
627632
_validate_out_keyword(a, out, (n,), (axis,), c2c, c2r, r2c)
628633
# if input array is copied, in-place FFT can be used
629634
a, in_place = _copy_array(a, c2c or c2r)
630-
if not in_place and out is not None:
631-
# if input is also given for out, in-place FFT can be used
632-
in_place = dpnp.are_same_logical_tensors(a, out)
633635

634636
if a.size == 0:
635637
return dpnp.get_result_array(a, out=out, casting="same_kind")
@@ -695,31 +697,30 @@ def dpnp_fftn(a, forward, real, s=None, axes=None, norm=None, out=None):
695697
)
696698

697699
if r2c:
698-
# a 1D real-to-complext FFT is performed on the last axis and then
700+
size_changes = [axis for axis, n in zip(axes, s) if a.shape[axis] != n]
701+
# cannot use out in the intermediate steps if size changes
702+
res = None if size_changes else out
703+
704+
# a 1D real-to-complex FFT is performed on the last axis and then
699705
# an N-D complex-to-complex FFT over the remaining axes
700706
a = _truncate_or_pad(a, (s[-1],), (axes[-1],))
701707
a = _fft(
702708
a,
703709
norm=norm,
704-
# if out is used in an intermediate step, it will have memory
705-
# overlap with input and cannot be used in the final step (a new
706-
# result array will be created for the final step), so there is no
707-
# benefit in using out in an intermediate step
708-
out=None,
710+
out=res,
709711
forward=forward,
710-
in_place=in_place and c2c,
711-
c2c=c2c,
712+
in_place=False,
713+
c2c=False,
712714
axes=axes[-1],
713715
batch_fft=a.ndim != 1,
714716
)
715717
return _complex_nd_fft(
716718
a,
717-
s=s,
719+
s=s[:-1],
718720
norm=norm,
719721
out=out,
720722
forward=forward,
721723
in_place=in_place,
722-
c2c=True,
723724
axes=axes[:-1],
724725
batch_fft=a.ndim != len_axes - 1,
725726
)
@@ -729,29 +730,25 @@ def dpnp_fftn(a, forward, real, s=None, axes=None, norm=None, out=None):
729730
# last one then a 1D complex-to-real FFT is performed on the last axis
730731
a = _complex_nd_fft(
731732
a,
732-
s=s,
733+
s=s[:-1],
733734
norm=norm,
734735
# out has real dtype and cannot be used in intermediate steps
735736
out=None,
736737
forward=forward,
737738
in_place=in_place,
738-
c2c=True,
739739
axes=axes[:-1],
740740
batch_fft=a.ndim != len_axes - 1,
741741
reversed_axes=False,
742742
)
743743
a = _truncate_or_pad(a, (s[-1],), (axes[-1],))
744-
if c2r:
745-
a = _make_array_hermitian(
746-
a, axes[-1], dpnp.are_same_logical_tensors(a, a_orig)
747-
)
748-
return _fft(
749-
a, norm, out, forward, in_place and c2c, c2c, axes[-1], a.ndim != 1
744+
a = _make_array_hermitian(
745+
a, axes[-1], dpnp.are_same_logical_tensors(a, a_orig)
750746
)
747+
return _fft(a, norm, out, forward, False, False, axes[-1], a.ndim != 1)
751748

752749
# c2c
753750
return _complex_nd_fft(
754-
a, s, norm, out, forward, in_place, c2c, axes, a.ndim != len_axes
751+
a, s, norm, out, forward, in_place, axes, a.ndim != len_axes
755752
)
756753

757754

0 commit comments

Comments
 (0)