@@ -181,7 +181,7 @@ def _fft(a, norm, out, forward, in_place, axes=None):
181181 a = dpnp .moveaxis (a , axes , local_axes )
182182 a_shape_orig = a .shape
183183 local_shape = (- 1 ,) + a_shape_orig [- len_axes :]
184- a = dpt .reshape (a . get_array () , local_shape )
184+ a = dpnp .reshape (a , local_shape )
185185 index = 1
186186
187187 a_strides = _standardize_strides_to_nonzero (a .strides , a .shape )
@@ -190,16 +190,15 @@ def _fft(a, norm, out, forward, in_place, axes=None):
190190 res = _scale_result (res , norm , forward , index )
191191
192192 if axes is not None : # batch_fft
193- res = dpt .reshape (res . get_array () , a_shape_orig )
193+ res = dpnp .reshape (res , a_shape_orig )
194194 res = dpnp .moveaxis (res , local_axes , axes )
195195
196196 result = dpnp .get_result_array (res , out = out , casting = "same_kind" )
197197 if out is None and not (
198198 result .flags .c_contiguous or result .flags .f_contiguous
199199 ):
200200 result = dpnp .ascontiguousarray (result )
201- else :
202- dpnp .synchronize_array_data (result )
201+
203202 return result
204203
205204
@@ -214,9 +213,8 @@ def _scale_result(res, norm, forward, index):
214213 elif norm in [None , "backward" ] and not forward :
215214 norm_factor = scale
216215
217- usm_res = res .get_array ()
218- usm_res /= norm_factor
219- return dpnp_array ._create_from_usm_ndarray (usm_res )
216+ res /= norm_factor
217+ return res
220218
221219
222220def _truncate_or_pad (a , shape , axes ):
0 commit comments