Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions dpctl/tensor/_ctors.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,13 @@ def _usm_ndarray_from_suai(obj):
buffer=membuf,
strides=sua_iface.get("strides", None),
)
_data_field = sua_iface["data"]
if isinstance(_data_field, tuple) and len(_data_field) > 1:
ro_field = _data_field[1]
else:
ro_field = False
if ro_field:
ary.flags["W"] = False
return ary


Expand Down
25 changes: 25 additions & 0 deletions dpctl/tests/test_usm_ndarray_ctor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2368,3 +2368,28 @@ def test_gh_1201():
c = dpt.flip(dpt.empty(a.shape, dtype=a.dtype))
c[:] = a
assert (dpt.asnumpy(c) == a).all()


class ObjWithSyclUsmArrayInterface:
def __init__(self, ary):
self._array_obj = ary

@property
def __sycl_usm_array_interface__(self):
_suai = self._array_obj.__sycl_usm_array_interface__
return _suai


def test_asarray_writable_flag():
try:
a = dpt.empty(8)
except dpctl.SyclDeviceCreationError:
pytest.skip("No SYCL devices available")

a.flags["W"] = False
wrapped = ObjWithSyclUsmArrayInterface(a)

b = dpt.asarray(wrapped)

assert not b.flags["W"]
assert b._pointer == a._pointer