Skip to content

Commit e4236a2

Browse files
Merge pull request #1132 from IntelPython/clipping-changes
Changes to integer indexing modes
2 parents 2273287 + a1078c7 commit e4236a2

File tree

4 files changed

+93
-41
lines changed

4 files changed

+93
-41
lines changed

dpctl/tensor/_indexing_functions.py

Lines changed: 25 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,18 @@
2626
from ._copy_utils import _extract_impl, _nonzero_impl
2727

2828

29-
def take(x, indices, /, *, axis=None, mode="clip"):
30-
"""take(x, indices, axis=None, mode="clip")
29+
def _get_indexing_mode(name):
30+
modes = {"wrap": 0, "clip": 1}
31+
try:
32+
return modes[name]
33+
except KeyError:
34+
raise ValueError(
35+
"`mode` must be `wrap` or `clip`." "Got `{}`.".format(name)
36+
)
37+
38+
39+
def take(x, indices, /, *, axis=None, mode="wrap"):
40+
"""take(x, indices, axis=None, mode="wrap")
3141
3242
Takes elements from array along a given axis.
3343
@@ -42,15 +52,15 @@ def take(x, indices, /, *, axis=None, mode="clip"):
4252
Default: `None`.
4353
mode:
4454
How out-of-bounds indices will be handled.
45-
"clip" - clamps indices to (-n <= i < n), then wraps
55+
"wrap" - clamps indices to (-n <= i < n), then wraps
4656
negative indices.
47-
"wrap" - wraps both negative and positive indices.
48-
Default: `"clip"`.
57+
"clip" - clips indices to (0 <= i < n)
58+
Default: `"wrap"`.
4959
5060
Returns:
5161
out: usm_ndarray
5262
Array with shape x.shape[:axis] + indices.shape + x.shape[axis + 1:]
53-
filled with elements .
63+
filled with elements from x.
5464
"""
5565
if not isinstance(x, dpt.usm_ndarray):
5666
raise TypeError(
@@ -80,11 +90,7 @@ def take(x, indices, /, *, axis=None, mode="clip"):
8090
[x.usm_type, indices.usm_type]
8191
)
8292

83-
modes = {"clip": 0, "wrap": 1}
84-
try:
85-
mode = modes[mode]
86-
except KeyError:
87-
raise ValueError("`mode` must be `clip` or `wrap`.")
93+
mode = _get_indexing_mode(mode)
8894

8995
x_ndim = x.ndim
9096
if axis is None:
@@ -114,8 +120,8 @@ def take(x, indices, /, *, axis=None, mode="clip"):
114120
return res
115121

116122

117-
def put(x, indices, vals, /, *, axis=None, mode="clip"):
118-
"""put(x, indices, vals, axis=None, mode="clip")
123+
def put(x, indices, vals, /, *, axis=None, mode="wrap"):
124+
"""put(x, indices, vals, axis=None, mode="wrap")
119125
120126
Puts values of an array into another array
121127
along a given axis.
@@ -134,10 +140,10 @@ def put(x, indices, vals, /, *, axis=None, mode="clip"):
134140
Default: `None`.
135141
mode:
136142
How out-of-bounds indices will be handled.
137-
"clip" - clamps indices to (-axis_size <= i < axis_size),
138-
then wraps negative indices.
139-
"wrap" - wraps both negative and positive indices.
140-
Default: `"clip"`.
143+
"wrap" - clamps indices to (-n <= i < n), then wraps
144+
negative indices.
145+
"clip" - clips indices to (0 <= i < n)
146+
Default: `"wrap"`.
141147
"""
142148
if not isinstance(x, dpt.usm_ndarray):
143149
raise TypeError(
@@ -175,11 +181,8 @@ def put(x, indices, vals, /, *, axis=None, mode="clip"):
175181
if exec_q is None:
176182
raise dpctl.utils.ExecutionPlacementError
177183
vals_usm_type = dpctl.utils.get_coerced_usm_type(usm_types_)
178-
modes = {"clip": 0, "wrap": 1}
179-
try:
180-
mode = modes[mode]
181-
except KeyError:
182-
raise ValueError("`mode` must be `clip` or `wrap`.")
184+
185+
mode = _get_indexing_mode(mode)
183186

184187
x_ndim = x.ndim
185188
if axis is None:

dpctl/tensor/libtensor/include/kernels/integer_advanced_indexing.hpp

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -46,10 +46,10 @@ namespace py = pybind11;
4646
template <typename ProjectorT, typename Ty, typename indT> class take_kernel;
4747
template <typename ProjectorT, typename Ty, typename indT> class put_kernel;
4848

49-
class ClipIndex
49+
class WrapIndex
5050
{
5151
public:
52-
ClipIndex() = default;
52+
WrapIndex() = default;
5353

5454
void operator()(py::ssize_t max_item, py::ssize_t &ind) const
5555
{
@@ -60,16 +60,15 @@ class ClipIndex
6060
}
6161
};
6262

63-
class WrapIndex
63+
class ClipIndex
6464
{
6565
public:
66-
WrapIndex() = default;
66+
ClipIndex() = default;
6767

6868
void operator()(py::ssize_t max_item, py::ssize_t &ind) const
6969
{
7070
max_item = std::max<py::ssize_t>(max_item, 1);
71-
ind = (ind < 0) ? (ind + max_item * ((-ind / max_item) + 1)) % max_item
72-
: ind % max_item;
71+
ind = std::clamp<py::ssize_t>(ind, 0, max_item - 1);
7372
return;
7473
}
7574
};

dpctl/tensor/libtensor/source/integer_advanced_indexing.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,8 @@
4040
#include "integer_advanced_indexing.hpp"
4141

4242
#define INDEXING_MODES 2
43-
#define CLIP_MODE 0
44-
#define WRAP_MODE 1
43+
#define WRAP_MODE 0
44+
#define CLIP_MODE 1
4545

4646
namespace dpctl
4747
{
@@ -252,8 +252,8 @@ usm_ndarray_take(dpctl::tensor::usm_ndarray src,
252252
throw py::value_error("Axis cannot be negative.");
253253
}
254254

255-
if (mode != 0 && mode != 1) {
256-
throw py::value_error("Mode must be 0 or 1.");
255+
if (mode != 0 && mode != 1 && mode != 2) {
256+
throw py::value_error("Mode must be 0, 1, or 2.");
257257
}
258258

259259
const dpctl::tensor::usm_ndarray ind_rep = ind[0];
@@ -575,8 +575,8 @@ usm_ndarray_put(dpctl::tensor::usm_ndarray dst,
575575
throw py::value_error("Axis cannot be negative.");
576576
}
577577

578-
if (mode != 0 && mode != 1) {
579-
throw py::value_error("Mode must be 0 or 1.");
578+
if (mode != 0 && mode != 1 && mode != 2) {
579+
throw py::value_error("Mode must be 0, 1, or 2.");
580580
}
581581

582582
if (!dst.is_writable()) {

dpctl/tests/test_usm_ndarray_indexing.py

Lines changed: 57 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from helper import get_queue_or_skip, skip_if_dtype_not_supported
2121
from numpy.testing import assert_array_equal
2222

23+
import dpctl
2324
import dpctl.tensor as dpt
2425
from dpctl.utils import ExecutionPlacementError
2526

@@ -895,20 +896,21 @@ def test_integer_indexing_modes():
895896
q = get_queue_or_skip()
896897

897898
x = dpt.arange(5, sycl_queue=q)
899+
x_np = dpt.asnumpy(x)
900+
901+
# wrapping negative indices
902+
ind = dpt.asarray([-4, -3, 0, 2, 4], dtype=np.intp, sycl_queue=q)
898903

899-
# wrapping
900-
ind = dpt.asarray([-6, -3, 0, 2, 6], dtype=np.intp, sycl_queue=q)
901904
res = dpt.take(x, ind, mode="wrap")
902-
expected_arr = np.take(dpt.asnumpy(x), dpt.asnumpy(ind), mode="wrap")
905+
expected_arr = np.take(x_np, dpt.asnumpy(ind), mode="raise")
903906

904907
assert (dpt.asnumpy(res) == expected_arr).all()
905908

906-
# clipping to -n<=i<n,
907-
# where n is the axis length
908-
ind = dpt.asarray([-4, -3, 0, 2, 4], dtype=np.intp, sycl_queue=q)
909+
# clipping to 0 (disabling negative indices)
910+
ind = dpt.asarray([-6, -3, 0, 2, 6], dtype=np.intp, sycl_queue=q)
909911

910912
res = dpt.take(x, ind, mode="clip")
911-
expected_arr = np.take(dpt.asnumpy(x), dpt.asnumpy(ind), mode="raise")
913+
expected_arr = np.take(x_np, dpt.asnumpy(ind), mode="clip")
912914

913915
assert (dpt.asnumpy(res) == expected_arr).all()
914916

@@ -939,6 +941,10 @@ def test_take_arg_validation():
939941
dpt.take(dpt.reshape(x, (2, 2)), ind0, axis=None)
940942
with pytest.raises(ValueError):
941943
dpt.take(x, dpt.reshape(ind0, (2, 2)))
944+
with pytest.raises(ValueError):
945+
dpt.take(x[0], ind0, axis=2)
946+
with pytest.raises(ValueError):
947+
dpt.take(x[:, dpt.newaxis, dpt.newaxis], ind0, axis=None)
942948

943949

944950
def test_put_arg_validation():
@@ -968,6 +974,10 @@ def test_put_arg_validation():
968974
dpt.put(x, ind0, val, mode=0)
969975
with pytest.raises(ValueError):
970976
dpt.put(x, dpt.reshape(ind0, (2, 2)), val)
977+
with pytest.raises(ValueError):
978+
dpt.put(x[0], ind0, val, axis=2)
979+
with pytest.raises(ValueError):
980+
dpt.put(x[:, dpt.newaxis, dpt.newaxis], ind0, val, axis=None)
971981

972982

973983
def test_advanced_indexing_compute_follows_data():
@@ -1269,3 +1279,43 @@ def test_nonzero_large():
12691279

12701280
m = dpt.full((30, 60, 80), True)
12711281
assert m[m].size == m.size
1282+
1283+
1284+
def test_extract_arg_validation():
1285+
get_queue_or_skip()
1286+
with pytest.raises(TypeError):
1287+
dpt.extract(None, None)
1288+
cond = dpt.ones(10, dtype="?")
1289+
with pytest.raises(TypeError):
1290+
dpt.extract(cond, None)
1291+
q1 = dpctl.SyclQueue()
1292+
with pytest.raises(ExecutionPlacementError):
1293+
dpt.extract(cond.to_device(q1), dpt.zeros_like(cond, dtype="u1"))
1294+
with pytest.raises(ValueError):
1295+
dpt.extract(dpt.ones((2, 3), dtype="?"), dpt.ones((3, 2), dtype="i1"))
1296+
1297+
1298+
def test_place_arg_validation():
1299+
get_queue_or_skip()
1300+
with pytest.raises(TypeError):
1301+
dpt.place(None, None, None)
1302+
arr = dpt.zeros(8, dtype="i1")
1303+
with pytest.raises(TypeError):
1304+
dpt.place(arr, None, None)
1305+
cond = dpt.ones(8, dtype="?")
1306+
with pytest.raises(TypeError):
1307+
dpt.place(arr, cond, None)
1308+
vals = dpt.ones_like(arr)
1309+
q1 = dpctl.SyclQueue()
1310+
with pytest.raises(ExecutionPlacementError):
1311+
dpt.place(arr.to_device(q1), cond, vals)
1312+
with pytest.raises(ValueError):
1313+
dpt.place(dpt.reshape(arr, (2, 2, 2)), cond, vals)
1314+
1315+
1316+
def test_nonzero_arg_validation():
1317+
get_queue_or_skip()
1318+
with pytest.raises(TypeError):
1319+
dpt.nonzero(list())
1320+
with pytest.raises(ValueError):
1321+
dpt.nonzero(dpt.asarray(1))

0 commit comments

Comments
 (0)