Skip to content

Commit 99dfd9d

Browse files
committed
get rid of direction
1 parent ec4aa09 commit 99dfd9d

File tree

1 file changed

+21
-36
lines changed

1 file changed

+21
-36
lines changed

mkl_fft/_pydfti.pyx

+21-36
Original file line numberDiff line numberDiff line change
@@ -224,12 +224,10 @@ cdef cnp.ndarray _process_arguments(
224224
object x,
225225
object n,
226226
object axis,
227-
object direction,
228227
long *axis_,
229228
long *n_,
230229
int *in_place,
231230
int *xnd,
232-
int *dir_,
233231
int realQ,
234232
):
235233
"""
@@ -239,11 +237,6 @@ cdef cnp.ndarray _process_arguments(
239237
cdef long n_max = 0
240238
cdef cnp.ndarray x_arr "xx_arrayObject"
241239

242-
if direction not in [-1, +1]:
243-
raise ValueError("Direction of FFT should +1 or -1")
244-
else:
245-
dir_[0] = -1 if direction is -1 else +1
246-
247240
# convert x to ndarray, ensure that strides are multiples of itemsize
248241
x_arr = PyArray_CheckFromAny(
249242
x, NULL, 0, 0,
@@ -379,18 +372,18 @@ def _c2c_fft1d_impl(x, n=None, axis=-1, direction=+1, double fsc=1.0, out=None):
379372
"""
380373
cdef cnp.ndarray x_arr "x_arrayObject"
381374
cdef cnp.ndarray f_arr "f_arrayObject"
382-
cdef int xnd, n_max = 0, in_place, dir_
375+
cdef int xnd, n_max = 0, in_place
383376
cdef long n_, axis_
384377
cdef int x_type, f_type, status = 0
385378
cdef int ALL_HARMONICS = 1
386379
cdef char * c_error_msg = NULL
387380
cdef bytes py_error_msg
388381
cdef DftiCache *_cache
389382

390-
x_arr = _process_arguments(
391-
x, n, axis, direction, &axis_, &n_, &in_place, &xnd, &dir_, 0
392-
)
383+
if direction not in [-1, +1]:
384+
raise ValueError("Direction of FFT should +1 or -1")
393385

386+
x_arr = _process_arguments(x, n, axis, &axis_, &n_, &in_place, &xnd, 0)
394387
x_type = cnp.PyArray_TYPE(x_arr)
395388

396389
if out is not None:
@@ -424,7 +417,7 @@ def _c2c_fft1d_impl(x, n=None, axis=-1, direction=+1, double fsc=1.0, out=None):
424417
_cache_capsule, capsule_name
425418
)
426419
if x_type is cnp.NPY_CDOUBLE:
427-
if dir_ < 0:
420+
if direction < 0:
428421
status = cdouble_mkl_ifft1d_in(
429422
x_arr, n_, <int> axis_, fsc, _cache
430423
)
@@ -433,7 +426,7 @@ def _c2c_fft1d_impl(x, n=None, axis=-1, direction=+1, double fsc=1.0, out=None):
433426
x_arr, n_, <int> axis_, fsc, _cache
434427
)
435428
elif x_type is cnp.NPY_CFLOAT:
436-
if dir_ < 0:
429+
if direction < 0:
437430
status = cfloat_mkl_ifft1d_in(
438431
x_arr, n_, <int> axis_, fsc, _cache
439432
)
@@ -482,7 +475,7 @@ def _c2c_fft1d_impl(x, n=None, axis=-1, direction=+1, double fsc=1.0, out=None):
482475
)
483476
if f_type is cnp.NPY_CDOUBLE:
484477
if x_type is cnp.NPY_DOUBLE:
485-
if dir_ < 0:
478+
if direction < 0:
486479
status = double_cdouble_mkl_ifft1d_out(
487480
x_arr,
488481
n_,
@@ -503,7 +496,7 @@ def _c2c_fft1d_impl(x, n=None, axis=-1, direction=+1, double fsc=1.0, out=None):
503496
_cache,
504497
)
505498
elif x_type is cnp.NPY_CDOUBLE:
506-
if dir_ < 0:
499+
if direction < 0:
507500
status = cdouble_cdouble_mkl_ifft1d_out(
508501
x_arr, n_, <int> axis_, f_arr, fsc, _cache
509502
)
@@ -513,7 +506,7 @@ def _c2c_fft1d_impl(x, n=None, axis=-1, direction=+1, double fsc=1.0, out=None):
513506
)
514507
else:
515508
if x_type is cnp.NPY_FLOAT:
516-
if dir_ < 0:
509+
if direction < 0:
517510
status = float_cfloat_mkl_ifft1d_out(
518511
x_arr,
519512
n_,
@@ -534,7 +527,7 @@ def _c2c_fft1d_impl(x, n=None, axis=-1, direction=+1, double fsc=1.0, out=None):
534527
_cache,
535528
)
536529
elif x_type is cnp.NPY_CFLOAT:
537-
if dir_ < 0:
530+
if direction < 0:
538531
status = cfloat_cfloat_mkl_ifft1d_out(
539532
x_arr, n_, <int> axis_, f_arr, fsc, _cache
540533
)
@@ -566,18 +559,15 @@ def _r2c_fft1d_impl(
566559
"""
567560
cdef cnp.ndarray x_arr "x_arrayObject"
568561
cdef cnp.ndarray f_arr "f_arrayObject"
569-
cdef int xnd, in_place, dir_
562+
cdef int xnd, in_place
570563
cdef long n_, axis_
571564
cdef int x_type, f_type, status, requirement
572565
cdef int HALF_HARMONICS = 0 # give only positive index harmonics
573-
cdef int direction = 1 # dummy, only used for the sake of arg-processing
574566
cdef char * c_error_msg = NULL
575567
cdef bytes py_error_msg
576568
cdef DftiCache *_cache
577569

578-
x_arr = _process_arguments(
579-
x, n, axis, direction, &axis_, &n_, &in_place, &xnd, &dir_, 1
580-
)
570+
x_arr = _process_arguments(x, n, axis, &axis_, &n_, &in_place, &xnd, 1)
581571

582572
x_type = cnp.PyArray_TYPE(x_arr)
583573

@@ -667,20 +657,17 @@ def _c2r_fft1d_impl(
667657
"""
668658
cdef cnp.ndarray x_arr "x_arrayObject"
669659
cdef cnp.ndarray f_arr "f_arrayObject"
670-
cdef int xnd, in_place, dir_, int_n
660+
cdef int xnd, in_place, int_n
671661
cdef long n_, axis_
672662
cdef int x_type, f_type, status
673-
cdef int direction = 1 # dummy, only used for the sake of arg-processing
674663
cdef char * c_error_msg = NULL
675664
cdef bytes py_error_msg
676665
cdef DftiCache *_cache
677666

678667
int_n = _is_integral(n)
679668
# nn gives the number elements along axis of the input that we use
680669
nn = (n // 2 + 1) if int_n and n > 0 else n
681-
x_arr = _process_arguments(
682-
x, nn, axis, direction, &axis_, &n_, &in_place, &xnd, &dir_, 0
683-
)
670+
x_arr = _process_arguments(x, nn, axis, &axis_, &n_, &in_place, &xnd, 0)
684671
n_ = 2*(n_ - 1)
685672
if int_n and (n % 2 == 1):
686673
n_ += 1
@@ -769,12 +756,10 @@ def _direct_fftnd(
769756
cdef int err
770757
cdef cnp.ndarray x_arr "xxnd_arrayObject"
771758
cdef cnp.ndarray f_arr "ffnd_arrayObject"
772-
cdef int dir_, in_place, x_type, f_type
759+
cdef int in_place, x_type, f_type
773760

774761
if direction not in [-1, +1]:
775762
raise ValueError("Direction of FFT should +1 or -1")
776-
else:
777-
dir_ = -1 if direction is -1 else +1
778763

779764
# convert x to ndarray, ensure that strides are multiples of itemsize
780765
x_arr = PyArray_CheckFromAny(
@@ -814,12 +799,12 @@ def _direct_fftnd(
814799

815800
if in_place:
816801
if x_type == cnp.NPY_CDOUBLE:
817-
if dir_ == 1:
802+
if direction == 1:
818803
err = cdouble_cdouble_mkl_fftnd_in(x_arr, fsc)
819804
else:
820805
err = cdouble_cdouble_mkl_ifftnd_in(x_arr, fsc)
821806
elif x_type == cnp.NPY_CFLOAT:
822-
if dir_ == 1:
807+
if direction == 1:
823808
err = cfloat_cfloat_mkl_fftnd_in(x_arr, fsc)
824809
else:
825810
err = cfloat_cfloat_mkl_ifftnd_in(x_arr, fsc)
@@ -846,22 +831,22 @@ def _direct_fftnd(
846831
f_arr = _allocate_result(x_arr, -1, 0, f_type)
847832

848833
if x_type == cnp.NPY_CDOUBLE:
849-
if dir_ == 1:
834+
if direction == 1:
850835
err = cdouble_cdouble_mkl_fftnd_out(x_arr, f_arr, fsc)
851836
else:
852837
err = cdouble_cdouble_mkl_ifftnd_out(x_arr, f_arr, fsc)
853838
elif x_type == cnp.NPY_CFLOAT:
854-
if dir_ == 1:
839+
if direction == 1:
855840
err = cfloat_cfloat_mkl_fftnd_out(x_arr, f_arr, fsc)
856841
else:
857842
err = cfloat_cfloat_mkl_ifftnd_out(x_arr, f_arr, fsc)
858843
elif x_type == cnp.NPY_DOUBLE:
859-
if dir_ == 1:
844+
if direction == 1:
860845
err = double_cdouble_mkl_fftnd_out(x_arr, f_arr, fsc)
861846
else:
862847
err = double_cdouble_mkl_ifftnd_out(x_arr, f_arr, fsc)
863848
elif x_type == cnp.NPY_FLOAT:
864-
if dir_ == 1:
849+
if direction == 1:
865850
err = float_cfloat_mkl_fftnd_out(x_arr, f_arr, fsc)
866851
else:
867852
err = float_cfloat_mkl_ifftnd_out(x_arr, f_arr, fsc)

0 commit comments

Comments
 (0)