@@ -262,23 +262,43 @@ def _iter_fftnd(
262
262
axes = None ,
263
263
out = None ,
264
264
direction = + 1 ,
265
- overwrite_x = False ,
266
- scale_function = lambda n , ind : 1.0 ,
265
+ scale_function = lambda ind : 1.0 ,
267
266
):
268
267
a = np .asarray (a )
269
268
s , axes = _init_nd_shape_and_axes (a , s , axes )
270
- ovwr = overwrite_x
271
- for ii in reversed (range (len (axes ))):
269
+
270
+ # Combine the two, but in reverse, to end with the first axis given.
271
+ axes_and_s = list (zip (axes , s ))[::- 1 ]
272
+ # We try to use in-place calculations where possible, which is
273
+ # everywhere except when the size changes after the first FFT.
274
+ size_changes = [axis for axis , n in axes_and_s [1 :] if a .shape [axis ] != n ]
275
+
276
+ # If there are any size changes, we cannot use out
277
+ res = None if size_changes else out
278
+ for ind , (axis , n ) in enumerate (axes_and_s ):
279
+ if axis in size_changes :
280
+ if axis == size_changes [- 1 ]:
281
+ # Last size change, so any output should now be OK
282
+ # (an error will be raised if not), and if no output is
283
+ # required, we want a freshly allocated array of the right size.
284
+ res = out
285
+ elif res is not None and n < res .shape [axis ]:
286
+ # For an intermediate step where we return fewer elements, we
287
+ # can use a smaller view of the previous array.
288
+ res = res [(slice (None ),) * axis + (slice (n ),)]
289
+ else :
290
+ # If we need more elements, we cannot use res.
291
+ res = None
272
292
a = _c2c_fft1d_impl (
273
293
a ,
274
- n = s [ii ],
275
- axis = axes [ii ],
276
- overwrite_x = ovwr ,
294
+ n = n ,
295
+ axis = axis ,
277
296
direction = direction ,
278
- fsc = scale_function (s [ ii ], ii ),
279
- out = out ,
297
+ fsc = scale_function (ind ),
298
+ out = res ,
280
299
)
281
- ovwr = True
300
+ # Default output for next iteration.
301
+ res = a
282
302
return a
283
303
284
304
@@ -360,7 +380,6 @@ def _c2c_fftnd_impl(
360
380
x ,
361
381
s = None ,
362
382
axes = None ,
363
- overwrite_x = False ,
364
383
direction = + 1 ,
365
384
fsc = 1.0 ,
366
385
out = None ,
@@ -385,7 +404,6 @@ def _c2c_fftnd_impl(
385
404
if _direct :
386
405
return _direct_fftnd (
387
406
x ,
388
- overwrite_x = overwrite_x ,
389
407
direction = direction ,
390
408
fsc = fsc ,
391
409
out = out ,
@@ -403,11 +421,7 @@ def _c2c_fftnd_impl(
403
421
x ,
404
422
axes ,
405
423
_direct_fftnd ,
406
- {
407
- "overwrite_x" : overwrite_x ,
408
- "direction" : direction ,
409
- "fsc" : fsc ,
410
- },
424
+ {"direction" : direction , "fsc" : fsc },
411
425
res ,
412
426
)
413
427
else :
@@ -418,97 +432,122 @@ def _c2c_fftnd_impl(
418
432
axes = axes ,
419
433
out = out ,
420
434
direction = direction ,
421
- overwrite_x = overwrite_x ,
422
- scale_function = lambda n , i : fsc if i == 0 else 1.0 ,
435
+ scale_function = lambda i : fsc if i == 0 else 1.0 ,
423
436
)
424
437
425
438
426
439
def _r2c_fftnd_impl (x , s = None , axes = None , fsc = 1.0 , out = None ):
427
440
a = np .asarray (x )
428
441
no_trim = (s is None ) and (axes is None )
429
442
s , axes = _cook_nd_args (a , s , axes )
443
+ axes = [ax + a .ndim if ax < 0 else ax for ax in axes ]
430
444
la = axes [- 1 ]
445
+
431
446
# trim array, so that rfft avoids doing unnecessary computations
432
447
if not no_trim :
433
448
a = _trim_array (a , s , axes )
449
+
450
+ # last axis is not included since we calculate r2c FFT separately
451
+ # and not in the loop
452
+ axes_and_s = list (zip (axes , s ))[- 2 ::- 1 ]
453
+ size_changes = [axis for axis , n in axes_and_s if a .shape [axis ] != n ]
454
+ res = None if size_changes else out
455
+
434
456
# r2c along last axis
435
- a = _r2c_fft1d_impl (a , n = s [- 1 ], axis = la , fsc = fsc , out = out )
457
+ a = _r2c_fft1d_impl (a , n = s [- 1 ], axis = la , fsc = fsc , out = res )
458
+ res = a
436
459
if len (s ) > 1 :
437
- if not no_trim :
438
- ss = list (s )
439
- ss [- 1 ] = a .shape [la ]
440
- a = _pad_array (a , tuple (ss ), axes )
460
+
441
461
len_axes = len (axes )
442
462
if len (set (axes )) == len_axes and len_axes == a .ndim and len_axes > 2 :
463
+ if not no_trim :
464
+ ss = list (s )
465
+ ss [- 1 ] = a .shape [la ]
466
+ a = _pad_array (a , tuple (ss ), axes )
443
467
# a series of ND c2c FFTs along last axis
444
468
ss , aa = _remove_axis (s , axes , - 1 )
445
- ind = [
446
- slice (None , None , 1 ),
447
- ] * len (s )
469
+ ind = [slice (None , None , 1 )] * len (s )
448
470
for ii in range (a .shape [la ]):
449
471
ind [la ] = ii
450
472
tind = tuple (ind )
451
473
a_inp = a [tind ]
452
- res = out [tind ] if out is not None else None
453
- a_res = _c2c_fftnd_impl (
454
- a_inp , s = ss , axes = aa , overwrite_x = True , direction = 1 , out = res
455
- )
456
- if a_res is not a_inp :
457
- a [tind ] = a_res # copy in place
474
+ res = out [tind ] if out is not None else a_inp
475
+ _ = _c2c_fftnd_impl (a_inp , s = ss , axes = aa , direction = 1 , out = res )
476
+ if out is not None :
477
+ a = out
458
478
else :
479
+ # another size_changes check is needed if there are repeated axes
480
+ # of last axis, since since FFT changes the shape along last axis
481
+ size_changes = [
482
+ axis for axis , n in axes_and_s if a .shape [axis ] != n
483
+ ]
484
+
459
485
# a series of 1D c2c FFTs along all axes except last
460
- for ii in range (len (axes ) - 2 , - 1 , - 1 ):
461
- a = _c2c_fft1d_impl (a , s [ii ], axes [ii ], overwrite_x = True )
486
+ for axis , n in axes_and_s :
487
+ if axis in size_changes :
488
+ if axis == size_changes [- 1 ]:
489
+ res = out
490
+ elif res is not None and n < res .shape [axis ]:
491
+ res = res [(slice (None ),) * axis + (slice (n ),)]
492
+ else :
493
+ res = None
494
+ a = _c2c_fft1d_impl (a , n , axis , out = res )
495
+ res = a
462
496
return a
463
497
464
498
465
499
def _c2r_fftnd_impl (x , s = None , axes = None , fsc = 1.0 , out = None ):
466
500
a = np .asarray (x )
467
501
no_trim = (s is None ) and (axes is None )
468
502
s , axes = _cook_nd_args (a , s , axes , invreal = True )
503
+ axes = [ax + a .ndim if ax < 0 else ax for ax in axes ]
469
504
la = axes [- 1 ]
470
505
if not no_trim :
471
506
a = _trim_array (a , s , axes )
472
507
if len (s ) > 1 :
473
- if not no_trim :
474
- a = _pad_array (a , s , axes )
475
- ovr_x = True if _datacopied (a , x ) else False
476
508
len_axes = len (axes )
477
509
if len (set (axes )) == len_axes and len_axes == a .ndim and len_axes > 2 :
510
+ if not no_trim :
511
+ a = _pad_array (a , s , axes )
478
512
# a series of ND c2c FFTs along last axis
479
513
# due to need to write into a, we must copy
480
- if not ovr_x :
481
- a = a .copy ()
482
- ovr_x = True
514
+ a = a if _datacopied (a , x ) else a .copy ()
483
515
if not np .issubdtype (a .dtype , np .complexfloating ):
484
516
# complex output will be copied to input, copy is needed
485
517
if a .dtype == np .float32 :
486
518
a = a .astype (np .complex64 )
487
519
else :
488
520
a = a .astype (np .complex128 )
489
- ovr_x = True
490
521
ss , aa = _remove_axis (s , axes , - 1 )
491
- ind = [
492
- slice (None , None , 1 ),
493
- ] * len (s )
522
+ ind = [slice (None , None , 1 )] * len (s )
494
523
for ii in range (a .shape [la ]):
495
524
ind [la ] = ii
496
525
tind = tuple (ind )
497
526
a_inp = a [tind ]
498
527
# out has real dtype and cannot be used in intermediate steps
499
- a_res = _c2c_fftnd_impl (
500
- a_inp , s = ss , axes = aa , overwrite_x = True , direction = - 1
528
+ # ss and aa are reversed since np.irfftn uses forward order but
529
+ # np.ifftn uses reverse order see numpy-gh-28950
530
+ _ = _c2c_fftnd_impl (
531
+ a_inp , s = ss [::- 1 ], axes = aa [::- 1 ], out = a_inp , direction = - 1
501
532
)
502
- if a_res is not a_inp :
503
- a [tind ] = a_res # copy in place
504
533
else :
505
534
# a series of 1D c2c FFTs along all axes except last
506
- for ii in range (len (axes ) - 1 ):
507
- # out has real dtype and cannot be used in intermediate steps
508
- a = _c2c_fft1d_impl (
509
- a , s [ii ], axes [ii ], overwrite_x = ovr_x , direction = - 1
510
- )
511
- ovr_x = True
535
+ # forward order, see numpy-gh-28950
536
+ axes_and_s = list (zip (axes , s ))[:- 1 ]
537
+ size_changes = [
538
+ axis for axis , n in axes_and_s [1 :] if a .shape [axis ] != n
539
+ ]
540
+ # out has real dtype cannot be used for intermediate steps
541
+ res = None
542
+ for axis , n in axes_and_s :
543
+ if axis in size_changes :
544
+ if res is not None and n < res .shape [axis ]:
545
+ # pylint: disable=unsubscriptable-object
546
+ res = res [(slice (None ),) * axis + (slice (n ),)]
547
+ else :
548
+ res = None
549
+ a = _c2c_fft1d_impl (a , n , axis , out = res , direction = - 1 )
550
+ res = a
512
551
# c2r along last axis
513
552
a = _c2r_fft1d_impl (a , n = s [- 1 ], axis = la , fsc = fsc , out = out )
514
553
return a
0 commit comments