2121import dpctl .tensor ._tensor_impl as ti
2222import dpctl .utils
2323
24- from ._copy_utils import _extract_impl , _nonzero_impl , _take_multi_index
24+ from ._copy_utils import (
25+ _extract_impl ,
26+ _nonzero_impl ,
27+ _put_multi_index ,
28+ _take_multi_index ,
29+ )
2530from ._numpy_helper import normalize_axis_index
2631
2732
@@ -206,22 +211,18 @@ def put_vec_duplicates(vec, ind, vals):
206211 raise TypeError (
207212 "Expected instance of `dpt.usm_ndarray`, got `{}`." .format (type (x ))
208213 )
209- if isinstance (vals , dpt .usm_ndarray ):
210- queues_ = [x .sycl_queue , vals .sycl_queue ]
211- usm_types_ = [x .usm_type , vals .usm_type ]
212- else :
213- queues_ = [
214- x .sycl_queue ,
215- ]
216- usm_types_ = [
217- x .usm_type ,
218- ]
219214 if not isinstance (indices , dpt .usm_ndarray ):
220215 raise TypeError (
221216 "`indices` expected `dpt.usm_ndarray`, got `{}`." .format (
222217 type (indices )
223218 )
224219 )
220+ if isinstance (vals , dpt .usm_ndarray ):
221+ queues_ = [x .sycl_queue , indices .sycl_queue , vals .sycl_queue ]
222+ usm_types_ = [x .usm_type , indices .usm_type , vals .usm_type ]
223+ else :
224+ queues_ = [x .sycl_queue , indices .sycl_queue ]
225+ usm_types_ = [x .usm_type , indices .usm_type ]
225226 if indices .ndim != 1 :
226227 raise ValueError (
227228 "`indices` expected a 1D array, got `{}`" .format (indices .ndim )
@@ -232,7 +233,6 @@ def put_vec_duplicates(vec, ind, vals):
232233 indices .dtype
233234 )
234235 )
235- queues_ .append (indices .sycl_queue )
236236 usm_types_ .append (indices .usm_type )
237237 exec_q = dpctl .utils .get_execution_queue (queues_ )
238238 if exec_q is None :
@@ -502,3 +502,81 @@ def take_along_axis(x, indices, /, *, axis=-1, mode="wrap"):
502502 for i in range (x_nd )
503503 )
504504 return _take_multi_index (x , _ind , 0 , mode = mode_i )
505+
506+
def put_along_axis(x, indices, vals, /, *, axis=-1, mode="wrap"):
    """
    Writes ``vals`` into ``x`` at the one-dimensional positions given by
    ``indices`` along the dimension selected by ``axis``.

    Args:
        x (usm_ndarray):
            array to write into. Must be compatible with ``indices``,
            except for the dimension specified by ``axis``.
        indices (usm_ndarray):
            array of indices. Must have the same rank (i.e., number of
            dimensions) as ``x``.
        vals (usm_ndarray):
            Array of values to be put into ``x``.
            Must be broadcastable to the shape of ``indices``.
        axis (int):
            axis along which values are put. A negative ``axis`` is
            counted from the last dimension. Default: ``-1``.
        mode (str, optional):
            How out-of-bounds indices will be handled. Possible values
            are:

            - ``"wrap"``: clamps indices to (``-n <= i < n``), then wraps
              negative indices.
            - ``"clip"``: clips indices to (``0 <= i < n``).

            Default: ``"wrap"``.

    Raises:
        TypeError:
            if ``x`` or ``indices`` is not a ``usm_ndarray``.
        ValueError:
            if ``x`` and ``indices`` differ in number of dimensions.
        dpctl.utils.ExecutionPlacementError:
            if a common execution queue can not be determined from the
            array arguments.

    .. note::

        If input array ``indices`` contains duplicates, a race condition
        occurs, and the value written into corresponding positions in ``x``
        may vary from run to run. Preserving sequential semantics in
        handling the duplicates to achieve deterministic behavior requires
        additional work.

    See :func:`dpctl.tensor.put` for an example.
    """
    # Both the target and the index array must be usm_ndarray instances;
    # `x` is validated before `indices`.
    for arr in (x, indices):
        if not isinstance(arr, dpt.usm_ndarray):
            raise TypeError(
                f"Expected dpctl.tensor.usm_ndarray, got {type(arr)}"
            )

    ndim = x.ndim
    if indices.ndim != ndim:
        raise ValueError(
            "Number of dimensions in the first and the second "
            "argument arrays must be equal"
        )
    target_axis = normalize_axis_index(operator.index(axis), ndim)

    # `vals` takes part in queue / USM-type resolution only when it is
    # itself a usm_ndarray.
    participants = [x, indices]
    if isinstance(vals, dpt.usm_ndarray):
        participants.append(vals)
    queue_list = [arr.sycl_queue for arr in participants]
    usm_type_list = [arr.usm_type for arr in participants]

    exec_q = dpctl.utils.get_execution_queue(queue_list)
    if exec_q is None:
        raise dpctl.utils.ExecutionPlacementError(
            "Execution placement can not be unambiguously inferred "
            "from input arguments. "
        )
    coerced_usm_type = dpctl.utils.get_coerced_usm_type(usm_type_list)
    mode_code = _get_indexing_mode(mode)
    index_dtype = ti.default_device_index_type(exec_q.sycl_device)

    # One index array per dimension: the caller's `indices` along the
    # target axis, and a `_range`-generated array for every other one.
    multi_index = []
    for dim in range(ndim):
        if dim == target_axis:
            multi_index.append(indices)
        else:
            multi_index.append(
                _range(
                    x.shape[dim],
                    dim,
                    ndim,
                    exec_q,
                    coerced_usm_type,
                    index_dtype,
                )
            )
    return _put_multi_index(x, tuple(multi_index), 0, vals, mode=mode_code)
0 commit comments