@@ -285,30 +285,12 @@ def _copy_array(x, complex_input):
285285 dtype = map_dtype_to_device (dpnp .float64 , x .sycl_device )
286286
287287 if copy_flag :
288- x = _copy_kernel ( x , dtype )
288+ x = x . astype ( dtype , copy = True )
289289
290290 # if copying is done, FFT can be in-place (copy_flag = in_place flag)
291291 return x , copy_flag
292292
293293
294- def _copy_kernel (x , dtype ):
295- x_copy = dpnp .empty_like (x , dtype = dtype , order = "C" )
296-
297- exec_q = x .sycl_queue
298- _manager = dpu .SequentialOrderManager [exec_q ]
299- dep_evs = _manager .submitted_events
300-
301- ht_copy_ev , copy_ev = ti ._copy_usm_ndarray_into_usm_ndarray (
302- src = dpnp .get_usm_ndarray (x ),
303- dst = x_copy .get_array (),
304- sycl_queue = exec_q ,
305- depends = dep_evs ,
306- )
307- _manager .add_event_pair (ht_copy_ev , copy_ev )
308-
309- return x_copy
310-
311-
312294def _extract_axes_chunk (a , s , chunk_size = 3 ):
313295 """
314296 Classify the first input into a list of lists with each list containing
@@ -438,9 +420,9 @@ def _fft(a, norm, out, forward, in_place, c2c, axes, batch_fft=True):
438420 return result
439421
440422
441- def _make_array_hermitian (a , n , copy_needed ):
423+ def _make_array_hermitian (a , n , axis , copy_needed ):
442424 """
443- For `dpnp.fft.irfft` , the input array should be Hermitian. If it is not,
425+ For complex-to-real FFT , the input array should be Hermitian. If it is not,
444426 the behavior is undefined. This function makes necessary changes to make
445427 sure the given array is Hermitian.
446428
@@ -449,6 +431,7 @@ def _make_array_hermitian(a, n, copy_needed):
449431 `_truncate_or_pad`, so the array has enough length.
450432 """
451433
434+ a = dpnp .moveaxis (a , axis , 0 )
452435 length_is_even = n % 2 == 0
453436 hermitian = dpnp .all (a [0 ].imag == 0 )
454437 assert n is not None
@@ -463,14 +446,14 @@ def _make_array_hermitian(a, n, copy_needed):
463446
464447 if not hermitian :
465448 if copy_needed :
466- a = _copy_kernel ( a , a .dtype )
449+ a = a . astype ( a .dtype , copy = True )
467450
468451 a [0 ].imag = 0
469452 if length_is_even :
470453 f_ny = n // 2
471454 a [f_ny ].imag = 0
472455
473- return a
456+ return dpnp . moveaxis ( a , 0 , axis )
474457
475458
476459def _scale_result (res , a_shape , norm , forward , index ):
@@ -634,11 +617,12 @@ def dpnp_fft(a, forward, real, n=None, axis=-1, norm=None, out=None):
634617
635618 if c2r :
636619 # input array should be Hermitian for c2r FFT
637- a = dpnp .moveaxis (a , axis , 0 )
638620 a = _make_array_hermitian (
639- a , a .shape [0 ], dpnp .are_same_logical_tensors (a , a_orig )
621+ a ,
622+ n = a .shape [0 ],
623+ axis = axis ,
624+ copy_needed = dpnp .are_same_logical_tensors (a , a_orig ),
640625 )
641- a = dpnp .moveaxis (a , 0 , axis )
642626
643627 return _fft (
644628 a ,
@@ -687,11 +671,12 @@ def dpnp_fftn(a, forward, real, s=None, axes=None, norm=None, out=None):
687671 if len_axes == 1 :
688672 a = _truncate_or_pad (a , (s [- 1 ],), (axes [- 1 ],))
689673 if c2r :
690- a = dpnp .moveaxis (a , axes [- 1 ], 0 )
691674 a = _make_array_hermitian (
692- a , a .shape [0 ], dpnp .are_same_logical_tensors (a , a_orig )
675+ a ,
676+ n = a .shape [0 ],
677+ axis = axes [- 1 ],
678+ copy_needed = dpnp .are_same_logical_tensors (a , a_orig ),
693679 )
694- a = dpnp .moveaxis (a , 0 , axes [- 1 ])
695680 return _fft (
696681 a , norm , out , forward , in_place and c2c , c2c , axes [- 1 ], a .ndim != 1
697682 )
@@ -743,11 +728,12 @@ def dpnp_fftn(a, forward, real, s=None, axes=None, norm=None, out=None):
743728 )
744729 a = _truncate_or_pad (a , (s [- 1 ],), (axes [- 1 ],))
745730 if c2r :
746- a = dpnp .moveaxis (a , axes [- 1 ], 0 )
747731 a = _make_array_hermitian (
748- a , a .shape [0 ], dpnp .are_same_logical_tensors (a , a_orig )
732+ a ,
733+ n = a .shape [0 ],
734+ axis = axes [- 1 ],
735+ copy_needed = dpnp .are_same_logical_tensors (a , a_orig ),
749736 )
750- a = dpnp .moveaxis (a , 0 , axes [- 1 ])
751737 return _fft (
752738 a , norm , out , forward , in_place and c2c , c2c , axes [- 1 ], a .ndim != 1
753739 )
0 commit comments