Skip to content

Commit 38595d0

Browse files
author
Vahid Tavanashad
committed
address comments
1 parent 9ad7238 commit 38595d0

File tree

2 files changed

+11
-33
lines changed

2 files changed

+11
-33
lines changed

CHANGELOG.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ This release achieves 100% compliance with Python Array API specification (revis
2929
* Removed `einsum_call` keyword from `dpnp.einsum_path` signature [#2421](https://github.com/IntelPython/dpnp/pull/2421)
3030
* Changed `"max dimensions"` to `None` in array API capabilities [#2432](https://github.com/IntelPython/dpnp/pull/2432)
3131
* Updated kernel header `i0.hpp` to expose `cyl_bessel_i0` function depending on build target [#2440](https://github.com/IntelPython/dpnp/pull/2440)
32-
* Updated FFT module to make input array Hermitian before calling complex-to-real FFT [#2444](https://github.com/IntelPython/dpnp/pull/2444)
32+
* Updated FFT module to ensure an input array is Hermitian before calling complex-to-real FFT [#2444](https://github.com/IntelPython/dpnp/pull/2444)
3333

3434
### Fixed
3535

dpnp/fft/dpnp_utils_fft.py

Lines changed: 10 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -285,30 +285,12 @@ def _copy_array(x, complex_input):
285285
dtype = map_dtype_to_device(dpnp.float64, x.sycl_device)
286286

287287
if copy_flag:
288-
x = _copy_kernel(x, dtype)
288+
x = x.astype(dtype, order="C", copy=True)
289289

290290
# if copying is done, FFT can be in-place (copy_flag = in_place flag)
291291
return x, copy_flag
292292

293293

294-
def _copy_kernel(x, dtype):
295-
x_copy = dpnp.empty_like(x, dtype=dtype, order="C")
296-
297-
exec_q = x.sycl_queue
298-
_manager = dpu.SequentialOrderManager[exec_q]
299-
dep_evs = _manager.submitted_events
300-
301-
ht_copy_ev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
302-
src=dpnp.get_usm_ndarray(x),
303-
dst=x_copy.get_array(),
304-
sycl_queue=exec_q,
305-
depends=dep_evs,
306-
)
307-
_manager.add_event_pair(ht_copy_ev, copy_ev)
308-
309-
return x_copy
310-
311-
312294
def _extract_axes_chunk(a, s, chunk_size=3):
313295
"""
314296
Classify the first input into a list of lists with each list containing
@@ -438,9 +420,9 @@ def _fft(a, norm, out, forward, in_place, c2c, axes, batch_fft=True):
438420
return result
439421

440422

441-
def _make_array_hermitian(a, n, copy_needed):
423+
def _make_array_hermitian(a, axis, copy_needed):
442424
"""
443-
For `dpnp.fft.irfft`, the input array should be Hermitian. If it is not,
425+
For complex-to-real FFT, the input array should be Hermitian. If it is not,
444426
the behavior is undefined. This function makes necessary changes to make
445427
sure the given array is Hermitian.
446428
@@ -449,6 +431,8 @@ def _make_array_hermitian(a, n, copy_needed):
449431
`_truncate_or_pad`, so the array has enough length.
450432
"""
451433

434+
a = dpnp.moveaxis(a, axis, 0)
435+
n = a.shape[0]
452436
length_is_even = n % 2 == 0
453437
hermitian = dpnp.all(a[0].imag == 0)
454438
assert n is not None
@@ -463,14 +447,14 @@ def _make_array_hermitian(a, n, copy_needed):
463447

464448
if not hermitian:
465449
if copy_needed:
466-
a = _copy_kernel(a, a.dtype)
450+
a = a.astype(a.dtype, order="C", copy=True)
467451

468452
a[0].imag = 0
469453
if length_is_even:
470454
f_ny = n // 2
471455
a[f_ny].imag = 0
472456

473-
return a
457+
return dpnp.moveaxis(a, 0, axis)
474458

475459

476460
def _scale_result(res, a_shape, norm, forward, index):
@@ -634,11 +618,9 @@ def dpnp_fft(a, forward, real, n=None, axis=-1, norm=None, out=None):
634618

635619
if c2r:
636620
# input array should be Hermitian for c2r FFT
637-
a = dpnp.moveaxis(a, axis, 0)
638621
a = _make_array_hermitian(
639-
a, a.shape[0], dpnp.are_same_logical_tensors(a, a_orig)
622+
a, axis, dpnp.are_same_logical_tensors(a, a_orig)
640623
)
641-
a = dpnp.moveaxis(a, 0, axis)
642624

643625
return _fft(
644626
a,
@@ -687,11 +669,9 @@ def dpnp_fftn(a, forward, real, s=None, axes=None, norm=None, out=None):
687669
if len_axes == 1:
688670
a = _truncate_or_pad(a, (s[-1],), (axes[-1],))
689671
if c2r:
690-
a = dpnp.moveaxis(a, axes[-1], 0)
691672
a = _make_array_hermitian(
692-
a, a.shape[0], dpnp.are_same_logical_tensors(a, a_orig)
673+
a, axes[-1], dpnp.are_same_logical_tensors(a, a_orig)
693674
)
694-
a = dpnp.moveaxis(a, 0, axes[-1])
695675
return _fft(
696676
a, norm, out, forward, in_place and c2c, c2c, axes[-1], a.ndim != 1
697677
)
@@ -743,11 +723,9 @@ def dpnp_fftn(a, forward, real, s=None, axes=None, norm=None, out=None):
743723
)
744724
a = _truncate_or_pad(a, (s[-1],), (axes[-1],))
745725
if c2r:
746-
a = dpnp.moveaxis(a, axes[-1], 0)
747726
a = _make_array_hermitian(
748-
a, a.shape[0], dpnp.are_same_logical_tensors(a, a_orig)
727+
a, axes[-1], dpnp.are_same_logical_tensors(a, a_orig)
749728
)
750-
a = dpnp.moveaxis(a, 0, axes[-1])
751729
return _fft(
752730
a, norm, out, forward, in_place and c2c, c2c, axes[-1], a.ndim != 1
753731
)

0 commit comments

Comments
 (0)