Skip to content

Commit ec4aa09

Browse files
committed
get rid of overwrite_x kwarg in mkl_fft, instead utilize out kwarg
1 parent 1503f18 commit ec4aa09

File tree

7 files changed

+176
-150
lines changed

7 files changed

+176
-150
lines changed

README.md

+3-3
Original file line numberDiff line numberDiff line change
@@ -53,11 +53,11 @@ While using these interfaces is the easiest way to leverage `mk_fft`, one can al
5353

5454
### complex-to-complex (c2c) transforms:
5555

56-
`fft(x, n=None, axis=-1, overwrite_x=False, fwd_scale=1.0, out=None)` - 1D FFT, similar to `scipy.fft.fft`
56+
`fft(x, n=None, axis=-1, fwd_scale=1.0, out=None)` - 1D FFT, similar to `numpy.fft.fft`
5757

58-
`fft2(x, s=None, axes=(-2, -1), overwrite_x=False, fwd_scale=1.0, out=None)` - 2D FFT, similar to `scipy.fft.fft2`
58+
`fft2(x, s=None, axes=(-2, -1), fwd_scale=1.0, out=None)` - 2D FFT, similar to `numpy.fft.fft2`
5959

60-
`fftn(x, s=None, axes=None, overwrite_x=False, fwd_scale=1.0, out=None)` - ND FFT, similar to `scipy.fft.fftn`
60+
`fftn(x, s=None, axes=None, fwd_scale=1.0, out=None)` - ND FFT, similar to `numpy.fft.fftn`
6161

6262
and similar inverse FFT (`ifft*`) functions.
6363

mkl_fft/_fft_utils.py

+93-55
Original file line numberDiff line numberDiff line change
@@ -262,23 +262,43 @@ def _iter_fftnd(
262262
axes=None,
263263
out=None,
264264
direction=+1,
265-
overwrite_x=False,
266-
scale_function=lambda n, ind: 1.0,
265+
scale_function=lambda ind: 1.0,
267266
):
268267
a = np.asarray(a)
269268
s, axes = _init_nd_shape_and_axes(a, s, axes)
270-
ovwr = overwrite_x
271-
for ii in reversed(range(len(axes))):
269+
270+
# Combine the two, but in reverse, to end with the first axis given.
271+
axes_and_s = list(zip(axes, s))[::-1]
272+
# We try to use in-place calculations where possible, which is
273+
# everywhere except when the size changes after the first FFT.
274+
size_changes = [axis for axis, n in axes_and_s[1:] if a.shape[axis] != n]
275+
276+
# If there are any size changes, we cannot use out
277+
res = None if size_changes else out
278+
for ind, (axis, n) in enumerate(axes_and_s):
279+
if axis in size_changes:
280+
if axis == size_changes[-1]:
281+
# Last size change, so any output should now be OK
282+
# (an error will be raised if not), and if no output is
283+
# required, we want a freshly allocated array of the right size.
284+
res = out
285+
elif res is not None and n < res.shape[axis]:
286+
# For an intermediate step where we return fewer elements, we
287+
# can use a smaller view of the previous array.
288+
res = res[(slice(None),) * axis + (slice(n),)]
289+
else:
290+
# If we need more elements, we cannot use res.
291+
res = None
272292
a = _c2c_fft1d_impl(
273293
a,
274-
n=s[ii],
275-
axis=axes[ii],
276-
overwrite_x=ovwr,
294+
n=n,
295+
axis=axis,
277296
direction=direction,
278-
fsc=scale_function(s[ii], ii),
279-
out=out,
297+
fsc=scale_function(ind),
298+
out=res,
280299
)
281-
ovwr = True
300+
# Default output for next iteration.
301+
res = a
282302
return a
283303

284304

@@ -360,7 +380,6 @@ def _c2c_fftnd_impl(
360380
x,
361381
s=None,
362382
axes=None,
363-
overwrite_x=False,
364383
direction=+1,
365384
fsc=1.0,
366385
out=None,
@@ -385,7 +404,6 @@ def _c2c_fftnd_impl(
385404
if _direct:
386405
return _direct_fftnd(
387406
x,
388-
overwrite_x=overwrite_x,
389407
direction=direction,
390408
fsc=fsc,
391409
out=out,
@@ -403,11 +421,7 @@ def _c2c_fftnd_impl(
403421
x,
404422
axes,
405423
_direct_fftnd,
406-
{
407-
"overwrite_x": overwrite_x,
408-
"direction": direction,
409-
"fsc": fsc,
410-
},
424+
{"direction": direction, "fsc": fsc},
411425
res,
412426
)
413427
else:
@@ -418,97 +432,121 @@ def _c2c_fftnd_impl(
418432
axes=axes,
419433
out=out,
420434
direction=direction,
421-
overwrite_x=overwrite_x,
422-
scale_function=lambda n, i: fsc if i == 0 else 1.0,
435+
scale_function=lambda i: fsc if i == 0 else 1.0,
423436
)
424437

425438

426439
def _r2c_fftnd_impl(x, s=None, axes=None, fsc=1.0, out=None):
427440
a = np.asarray(x)
428441
no_trim = (s is None) and (axes is None)
429442
s, axes = _cook_nd_args(a, s, axes)
443+
axes = [ax + a.ndim if ax < 0 else ax for ax in axes]
430444
la = axes[-1]
445+
431446
# trim array, so that rfft avoids doing unnecessary computations
432447
if not no_trim:
433448
a = _trim_array(a, s, axes)
449+
450+
# last axis is not included since we calculate FT sepaartely and it does not come in loop
451+
axes_and_s = list(zip(axes, s))[-2::-1]
452+
size_changes = [axis for axis, n in axes_and_s if a.shape[axis] != n]
453+
res = None if size_changes else out
454+
434455
# r2c along last axis
435-
a = _r2c_fft1d_impl(a, n=s[-1], axis=la, fsc=fsc, out=out)
456+
a = _r2c_fft1d_impl(a, n=s[-1], axis=la, fsc=fsc, out=res)
457+
res = a
436458
if len(s) > 1:
437-
if not no_trim:
438-
ss = list(s)
439-
ss[-1] = a.shape[la]
440-
a = _pad_array(a, tuple(ss), axes)
459+
441460
len_axes = len(axes)
442461
if len(set(axes)) == len_axes and len_axes == a.ndim and len_axes > 2:
462+
if not no_trim:
463+
ss = list(s)
464+
ss[-1] = a.shape[la]
465+
a = _pad_array(a, tuple(ss), axes)
443466
# a series of ND c2c FFTs along last axis
444467
ss, aa = _remove_axis(s, axes, -1)
445-
ind = [
446-
slice(None, None, 1),
447-
] * len(s)
468+
ind = [slice(None, None, 1)] * len(s)
448469
for ii in range(a.shape[la]):
449470
ind[la] = ii
450471
tind = tuple(ind)
451472
a_inp = a[tind]
452-
res = out[tind] if out is not None else None
453-
a_res = _c2c_fftnd_impl(
454-
a_inp, s=ss, axes=aa, overwrite_x=True, direction=1, out=res
455-
)
456-
if a_res is not a_inp:
457-
a[tind] = a_res # copy in place
473+
res = out[tind] if out is not None else a_inp
474+
_ = _c2c_fftnd_impl(a_inp, s=ss, axes=aa, direction=1, out=res)
475+
if out is not None:
476+
a = out
458477
else:
478+
# another size_changes check is needed if there are repeated axes
479+
# of last axis, since since FFT changes the shape along last axis
480+
size_changes = [
481+
axis for axis, n in axes_and_s if a.shape[axis] != n
482+
]
483+
459484
# a series of 1D c2c FFTs along all axes except last
460-
for ii in range(len(axes) - 2, -1, -1):
461-
a = _c2c_fft1d_impl(a, s[ii], axes[ii], overwrite_x=True)
485+
for axis, n in axes_and_s:
486+
if axis in size_changes:
487+
if axis == size_changes[-1]:
488+
res = out
489+
elif res is not None and n < res.shape[axis]:
490+
res = res[(slice(None),) * axis + (slice(n),)]
491+
else:
492+
res = None
493+
a = _c2c_fft1d_impl(a, n, axis, out=res)
494+
res = a
462495
return a
463496

464497

465498
def _c2r_fftnd_impl(x, s=None, axes=None, fsc=1.0, out=None):
466499
a = np.asarray(x)
467500
no_trim = (s is None) and (axes is None)
468501
s, axes = _cook_nd_args(a, s, axes, invreal=True)
502+
axes = [ax + a.ndim if ax < 0 else ax for ax in axes]
469503
la = axes[-1]
470504
if not no_trim:
471505
a = _trim_array(a, s, axes)
472506
if len(s) > 1:
473-
if not no_trim:
474-
a = _pad_array(a, s, axes)
475-
ovr_x = True if _datacopied(a, x) else False
476507
len_axes = len(axes)
477508
if len(set(axes)) == len_axes and len_axes == a.ndim and len_axes > 2:
509+
if not no_trim:
510+
a = _pad_array(a, s, axes)
478511
# a series of ND c2c FFTs along last axis
479512
# due to need to write into a, we must copy
480-
if not ovr_x:
481-
a = a.copy()
482-
ovr_x = True
513+
a = a if _datacopied(a, x) else a.copy()
483514
if not np.issubdtype(a.dtype, np.complexfloating):
484515
# complex output will be copied to input, copy is needed
485516
if a.dtype == np.float32:
486517
a = a.astype(np.complex64)
487518
else:
488519
a = a.astype(np.complex128)
489-
ovr_x = True
490520
ss, aa = _remove_axis(s, axes, -1)
491-
ind = [
492-
slice(None, None, 1),
493-
] * len(s)
521+
ind = [slice(None, None, 1)] * len(s)
494522
for ii in range(a.shape[la]):
495523
ind[la] = ii
496524
tind = tuple(ind)
497525
a_inp = a[tind]
498526
# out has real dtype and cannot be used in intermediate steps
499-
a_res = _c2c_fftnd_impl(
500-
a_inp, s=ss, axes=aa, overwrite_x=True, direction=-1
527+
# ss and aa are reversed since np.irfftn uses forward order but
528+
# np.ifftn uses reverse order see numpy-gh-28950
529+
_ = _c2c_fftnd_impl(
530+
a_inp, s=ss[::-1], axes=aa[::-1], out=a_inp, direction=-1
501531
)
502-
if a_res is not a_inp:
503-
a[tind] = a_res # copy in place
504532
else:
505533
# a series of 1D c2c FFTs along all axes except last
506-
for ii in range(len(axes) - 1):
507-
# out has real dtype and cannot be used in intermediate steps
508-
a = _c2c_fft1d_impl(
509-
a, s[ii], axes[ii], overwrite_x=ovr_x, direction=-1
510-
)
511-
ovr_x = True
534+
# forward order, see numpy-gh-28950
535+
axes_and_s = list(zip(axes, s))[:-1]
536+
size_changes = [
537+
axis for axis, n in axes_and_s[1:] if a.shape[axis] != n
538+
]
539+
# out has real dtype cannot be used for intermediate steps
540+
res = None
541+
for axis, n in axes_and_s:
542+
if axis in size_changes:
543+
if res is not None and n < res.shape[axis]:
544+
# pylint: disable=unsubscriptable-object
545+
res = res[(slice(None),) * axis + (slice(n),)]
546+
else:
547+
res = None
548+
a = _c2c_fft1d_impl(a, n, axis, out=res, direction=-1)
549+
res = a
512550
# c2r along last axis
513551
a = _c2r_fft1d_impl(a, n=s[-1], axis=la, fsc=fsc, out=out)
514552
return a

mkl_fft/_mkl_fft.py

+12-40
Original file line numberDiff line numberDiff line change
@@ -45,63 +45,35 @@
4545
]
4646

4747

48-
def fft(x, n=None, axis=-1, out=None, overwrite_x=False, fwd_scale=1.0):
48+
def fft(x, n=None, axis=-1, out=None, fwd_scale=1.0):
4949
return _c2c_fft1d_impl(
50-
x,
51-
n=n,
52-
axis=axis,
53-
out=out,
54-
overwrite_x=overwrite_x,
55-
direction=+1,
56-
fsc=fwd_scale,
50+
x, n=n, axis=axis, out=out, direction=+1, fsc=fwd_scale
5751
)
5852

5953

60-
def ifft(x, n=None, axis=-1, out=None, overwrite_x=False, fwd_scale=1.0):
54+
def ifft(x, n=None, axis=-1, out=None, fwd_scale=1.0):
6155
return _c2c_fft1d_impl(
62-
x,
63-
n=n,
64-
axis=axis,
65-
out=out,
66-
overwrite_x=overwrite_x,
67-
direction=-1,
68-
fsc=fwd_scale,
56+
x, n=n, axis=axis, out=out, direction=-1, fsc=fwd_scale
6957
)
7058

7159

72-
def fft2(x, s=None, axes=(-2, -1), out=None, overwrite_x=False, fwd_scale=1.0):
73-
return fftn(
74-
x, s=s, axes=axes, out=out, overwrite_x=overwrite_x, fwd_scale=fwd_scale
75-
)
60+
def fft2(x, s=None, axes=(-2, -1), out=None, fwd_scale=1.0):
61+
return fftn(x, s=s, axes=axes, out=out, fwd_scale=fwd_scale)
7662

7763

78-
def ifft2(x, s=None, axes=(-2, -1), out=None, overwrite_x=False, fwd_scale=1.0):
79-
return ifftn(
80-
x, s=s, axes=axes, out=out, overwrite_x=overwrite_x, fwd_scale=fwd_scale
81-
)
64+
def ifft2(x, s=None, axes=(-2, -1), out=None, fwd_scale=1.0):
65+
return ifftn(x, s=s, axes=axes, out=out, fwd_scale=fwd_scale)
8266

8367

84-
def fftn(x, s=None, axes=None, out=None, overwrite_x=False, fwd_scale=1.0):
68+
def fftn(x, s=None, axes=None, out=None, fwd_scale=1.0):
8569
return _c2c_fftnd_impl(
86-
x,
87-
s=s,
88-
axes=axes,
89-
out=out,
90-
overwrite_x=overwrite_x,
91-
direction=+1,
92-
fsc=fwd_scale,
70+
x, s=s, axes=axes, out=out, direction=+1, fsc=fwd_scale
9371
)
9472

9573

96-
def ifftn(x, s=None, axes=None, out=None, overwrite_x=False, fwd_scale=1.0):
74+
def ifftn(x, s=None, axes=None, out=None, fwd_scale=1.0):
9775
return _c2c_fftnd_impl(
98-
x,
99-
s=s,
100-
axes=axes,
101-
out=out,
102-
overwrite_x=overwrite_x,
103-
direction=-1,
104-
fsc=fwd_scale,
76+
x, s=s, axes=axes, out=out, direction=-1, fsc=fwd_scale
10577
)
10678

10779

0 commit comments

Comments
 (0)