Skip to content

Commit 3b73c15

Browse files
author
Vahid Tavanashad
committed
address comments
1 parent 342b37f commit 3b73c15

File tree

2 files changed

+5
-8
lines changed

2 files changed

+5
-8
lines changed

dpnp/dpnp_iface.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -624,7 +624,6 @@ def get_result_array(a, out=None, casting="safe"):
624624
if out is None:
625625
return a
626626

627-
dpnp.check_supported_arrays_type(out)
628627
if isinstance(out, dpt.usm_ndarray):
629628
out = dpnp_array._create_from_usm_ndarray(out)
630629

dpnp/fft/dpnp_utils_fft.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,7 @@ def _fft(a, norm, out, forward, in_place, axes=None):
181181
a = dpnp.moveaxis(a, axes, local_axes)
182182
a_shape_orig = a.shape
183183
local_shape = (-1,) + a_shape_orig[-len_axes:]
184-
a = dpt.reshape(a.get_array(), local_shape)
184+
a = dpnp.reshape(a, local_shape)
185185
index = 1
186186

187187
a_strides = _standardize_strides_to_nonzero(a.strides, a.shape)
@@ -190,16 +190,15 @@ def _fft(a, norm, out, forward, in_place, axes=None):
190190
res = _scale_result(res, norm, forward, index)
191191

192192
if axes is not None: # batch_fft
193-
res = dpt.reshape(res.get_array(), a_shape_orig)
193+
res = dpnp.reshape(res, a_shape_orig)
194194
res = dpnp.moveaxis(res, local_axes, axes)
195195

196196
result = dpnp.get_result_array(res, out=out, casting="same_kind")
197197
if out is None and not (
198198
result.flags.c_contiguous or result.flags.f_contiguous
199199
):
200200
result = dpnp.ascontiguousarray(result)
201-
else:
202-
dpnp.synchronize_array_data(result)
201+
203202
return result
204203

205204

@@ -214,9 +213,8 @@ def _scale_result(res, norm, forward, index):
214213
elif norm in [None, "backward"] and not forward:
215214
norm_factor = scale
216215

217-
usm_res = res.get_array()
218-
usm_res /= norm_factor
219-
return dpnp_array._create_from_usm_ndarray(usm_res)
216+
res /= norm_factor
217+
return res
220218

221219

222220
def _truncate_or_pad(a, shape, axes):

0 commit comments

Comments
 (0)