@@ -285,30 +285,12 @@ def _copy_array(x, complex_input):
285285 dtype = map_dtype_to_device (dpnp .float64 , x .sycl_device )
286286
287287 if copy_flag :
288- x = _copy_kernel ( x , dtype )
288+ x = x . astype ( dtype , order = "C" , copy = True )
289289
290290 # if copying is done, FFT can be in-place (copy_flag = in_place flag)
291291 return x , copy_flag
292292
293293
294- def _copy_kernel (x , dtype ):
295- x_copy = dpnp .empty_like (x , dtype = dtype , order = "C" )
296-
297- exec_q = x .sycl_queue
298- _manager = dpu .SequentialOrderManager [exec_q ]
299- dep_evs = _manager .submitted_events
300-
301- ht_copy_ev , copy_ev = ti ._copy_usm_ndarray_into_usm_ndarray (
302- src = dpnp .get_usm_ndarray (x ),
303- dst = x_copy .get_array (),
304- sycl_queue = exec_q ,
305- depends = dep_evs ,
306- )
307- _manager .add_event_pair (ht_copy_ev , copy_ev )
308-
309- return x_copy
310-
311-
312294def _extract_axes_chunk (a , s , chunk_size = 3 ):
313295 """
314296 Classify the first input into a list of lists with each list containing
@@ -438,9 +420,9 @@ def _fft(a, norm, out, forward, in_place, c2c, axes, batch_fft=True):
438420 return result
439421
440422
441- def _make_array_hermitian (a , n , copy_needed ):
423+ def _make_array_hermitian (a , axis , copy_needed ):
442424 """
443- For `dpnp.fft.irfft` , the input array should be Hermitian. If it is not,
425+ For complex-to-real FFT , the input array should be Hermitian. If it is not,
444426 the behavior is undefined. This function makes necessary changes to make
445427 sure the given array is Hermitian.
446428
@@ -449,6 +431,8 @@ def _make_array_hermitian(a, n, copy_needed):
449431 `_truncate_or_pad`, so the array has enough length.
450432 """
451433
434+ a = dpnp .moveaxis (a , axis , 0 )
435+ n = a .shape [0 ]
452436 length_is_even = n % 2 == 0
453437 hermitian = dpnp .all (a [0 ].imag == 0 )
454438 assert n is not None
@@ -463,14 +447,14 @@ def _make_array_hermitian(a, n, copy_needed):
463447
464448 if not hermitian :
465449 if copy_needed :
466- a = _copy_kernel ( a , a .dtype )
450+ a = a . astype ( a .dtype , order = "C" , copy = True )
467451
468452 a [0 ].imag = 0
469453 if length_is_even :
470454 f_ny = n // 2
471455 a [f_ny ].imag = 0
472456
473- return a
457+ return dpnp . moveaxis ( a , 0 , axis )
474458
475459
476460def _scale_result (res , a_shape , norm , forward , index ):
@@ -634,11 +618,9 @@ def dpnp_fft(a, forward, real, n=None, axis=-1, norm=None, out=None):
634618
635619 if c2r :
636620 # input array should be Hermitian for c2r FFT
637- a = dpnp .moveaxis (a , axis , 0 )
638621 a = _make_array_hermitian (
639- a , a . shape [ 0 ] , dpnp .are_same_logical_tensors (a , a_orig )
622+ a , axis , dpnp .are_same_logical_tensors (a , a_orig )
640623 )
641- a = dpnp .moveaxis (a , 0 , axis )
642624
643625 return _fft (
644626 a ,
@@ -687,11 +669,9 @@ def dpnp_fftn(a, forward, real, s=None, axes=None, norm=None, out=None):
687669 if len_axes == 1 :
688670 a = _truncate_or_pad (a , (s [- 1 ],), (axes [- 1 ],))
689671 if c2r :
690- a = dpnp .moveaxis (a , axes [- 1 ], 0 )
691672 a = _make_array_hermitian (
692- a , a . shape [ 0 ], dpnp .are_same_logical_tensors (a , a_orig )
673+ a , axes [ - 1 ], dpnp .are_same_logical_tensors (a , a_orig )
693674 )
694- a = dpnp .moveaxis (a , 0 , axes [- 1 ])
695675 return _fft (
696676 a , norm , out , forward , in_place and c2c , c2c , axes [- 1 ], a .ndim != 1
697677 )
@@ -743,11 +723,9 @@ def dpnp_fftn(a, forward, real, s=None, axes=None, norm=None, out=None):
743723 )
744724 a = _truncate_or_pad (a , (s [- 1 ],), (axes [- 1 ],))
745725 if c2r :
746- a = dpnp .moveaxis (a , axes [- 1 ], 0 )
747726 a = _make_array_hermitian (
748- a , a . shape [ 0 ], dpnp .are_same_logical_tensors (a , a_orig )
727+ a , axes [ - 1 ], dpnp .are_same_logical_tensors (a , a_orig )
749728 )
750- a = dpnp .moveaxis (a , 0 , axes [- 1 ])
751729 return _fft (
752730 a , norm , out , forward , in_place and c2c , c2c , axes [- 1 ], a .ndim != 1
753731 )
0 commit comments