diff --git a/dpctl/tensor/_copy_utils.py b/dpctl/tensor/_copy_utils.py index 264ebf0bff..e360de2a24 100644 --- a/dpctl/tensor/_copy_utils.py +++ b/dpctl/tensor/_copy_utils.py @@ -756,20 +756,28 @@ def _extract_impl(ary, ary_mask, axis=0): raise TypeError( f"Expecting type dpctl.tensor.usm_ndarray, got {type(ary)}" ) - if not isinstance(ary_mask, dpt.usm_ndarray): - raise TypeError( - f"Expecting type dpctl.tensor.usm_ndarray, got {type(ary_mask)}" + if isinstance(ary_mask, dpt.usm_ndarray): + dst_usm_type = dpctl.utils.get_coerced_usm_type( + (ary.usm_type, ary_mask.usm_type) ) - dst_usm_type = dpctl.utils.get_coerced_usm_type( - (ary.usm_type, ary_mask.usm_type) - ) - exec_q = dpctl.utils.get_execution_queue( - (ary.sycl_queue, ary_mask.sycl_queue) - ) - if exec_q is None: - raise dpctl.utils.ExecutionPlacementError( - "arrays have different associated queues. " - "Use `y.to_device(x.device)` to migrate." + exec_q = dpctl.utils.get_execution_queue( + (ary.sycl_queue, ary_mask.sycl_queue) + ) + if exec_q is None: + raise dpctl.utils.ExecutionPlacementError( + "arrays have different associated queues. " + "Use `y.to_device(x.device)` to migrate." + ) + elif isinstance(ary_mask, np.ndarray): + dst_usm_type = ary.usm_type + exec_q = ary.sycl_queue + ary_mask = dpt.asarray( + ary_mask, usm_type=dst_usm_type, sycl_queue=exec_q + ) + else: + raise TypeError( + "Expecting type dpctl.tensor.usm_ndarray or numpy.ndarray, got " + f"{type(ary_mask)}" ) ary_nd = ary.ndim pp = normalize_axis_index(operator.index(axis), ary_nd) @@ -837,35 +845,40 @@ def _nonzero_impl(ary): return res -def _validate_indices(inds, queue_list, usm_type_list): +def _get_indices_queue_usm_type(inds, queue, usm_type): """ - Utility for validating indices are usm_ndarray of integral dtype or Python - integers. At least one must be an array. + Utility for validating indices are NumPy ndarray or usm_ndarray of integral + dtype or Python integers. At least one must be an array. For each array, the queue and usm type are appended to `queue_list` and `usm_type_list`, respectively. """ - any_usmarray = False + queues = [queue] + usm_types = [usm_type] + any_array = False for ind in inds: - if isinstance(ind, dpt.usm_ndarray): - any_usmarray = True + if isinstance(ind, (np.ndarray, dpt.usm_ndarray)): + any_array = True if ind.dtype.kind not in "ui": raise IndexError( "arrays used as indices must be of integer (or boolean) " "type" ) - queue_list.append(ind.sycl_queue) - usm_type_list.append(ind.usm_type) + if isinstance(ind, dpt.usm_ndarray): + queues.append(ind.sycl_queue) + usm_types.append(ind.usm_type) elif not isinstance(ind, Integral): raise TypeError( - "all elements of `ind` expected to be usm_ndarrays " - f"or integers, found {type(ind)}" + "all elements of `ind` expected to be usm_ndarrays, " + f"NumPy arrays, or integers, found {type(ind)}" ) - if not any_usmarray: + if not any_array: raise TypeError( - "at least one element of `inds` expected to be a usm_ndarray" + "at least one element of `inds` expected to be an array" ) - return inds + usm_type = dpctl.utils.get_coerced_usm_type(usm_types) + q = dpctl.utils.get_execution_queue(queues) + return q, usm_type def _prepare_indices_arrays(inds, q, usm_type): @@ -922,18 +935,12 @@ def _take_multi_index(ary, inds, p, mode=0): raise ValueError( "Invalid value for mode keyword, only 0 or 1 is supported" ) - queues_ = [ - ary.sycl_queue, - ] - usm_types_ = [ - ary.usm_type, - ] if not isinstance(inds, (list, tuple)): inds = (inds,) - _validate_indices(inds, queues_, usm_types_) - res_usm_type = dpctl.utils.get_coerced_usm_type(usm_types_) - exec_q = dpctl.utils.get_execution_queue(queues_) + exec_q, res_usm_type = _get_indices_queue_usm_type( + inds, ary.sycl_queue, ary.usm_type + ) if exec_q is None: raise dpctl.utils.ExecutionPlacementError( "Can not automatically determine where to allocate the " @@ -942,8 +949,7 @@ def _take_multi_index(ary, inds, p, mode=0): "be associated with the same queue." ) - if len(inds) > 1: - inds = _prepare_indices_arrays(inds, exec_q, res_usm_type) + inds = _prepare_indices_arrays(inds, exec_q, res_usm_type) ind0 = inds[0] ary_sh = ary.shape @@ -976,16 +982,28 @@ def _place_impl(ary, ary_mask, vals, axis=0): raise TypeError( f"Expecting type dpctl.tensor.usm_ndarray, got {type(ary)}" ) - if not isinstance(ary_mask, dpt.usm_ndarray): - raise TypeError( - f"Expecting type dpctl.tensor.usm_ndarray, got {type(ary_mask)}" + if isinstance(ary_mask, dpt.usm_ndarray): + exec_q = dpctl.utils.get_execution_queue( + ( + ary.sycl_queue, + ary_mask.sycl_queue, + ) ) - exec_q = dpctl.utils.get_execution_queue( - ( - ary.sycl_queue, - ary_mask.sycl_queue, + if exec_q is None: + raise dpctl.utils.ExecutionPlacementError( + "arrays have different associated queues. " + "Use `y.to_device(x.device)` to migrate." + ) + elif isinstance(ary_mask, np.ndarray): + exec_q = ary.sycl_queue + ary_mask = dpt.asarray( + ary_mask, usm_type=ary.usm_type, sycl_queue=exec_q + ) + else: + raise TypeError( + "Expecting type dpctl.tensor.usm_ndarray or numpy.ndarray, got " + f"{type(ary_mask)}" ) - ) if exec_q is not None: if not isinstance(vals, dpt.usm_ndarray): vals = dpt.asarray(vals, dtype=ary.dtype, sycl_queue=exec_q) @@ -1048,23 +1066,13 @@ def _put_multi_index(ary, inds, p, vals, mode=0): raise ValueError( "Invalid value for mode keyword, only 0 or 1 is supported" ) - if isinstance(vals, dpt.usm_ndarray): - queues_ = [ary.sycl_queue, vals.sycl_queue] - usm_types_ = [ary.usm_type, vals.usm_type] - else: - queues_ = [ - ary.sycl_queue, - ] - usm_types_ = [ - ary.usm_type, - ] if not isinstance(inds, (list, tuple)): inds = (inds,) - _validate_indices(inds, queues_, usm_types_) + exec_q, vals_usm_type = _get_indices_queue_usm_type( + inds, ary.sycl_queue, ary.usm_type + ) - vals_usm_type = dpctl.utils.get_coerced_usm_type(usm_types_) - exec_q = dpctl.utils.get_execution_queue(queues_) if exec_q is not None: if not isinstance(vals, dpt.usm_ndarray): vals = dpt.asarray( @@ -1080,8 +1088,7 @@ def _put_multi_index(ary, inds, p, vals, mode=0): "be associated with the same queue." ) - if len(inds) > 1: - inds = _prepare_indices_arrays(inds, exec_q, vals_usm_type) + inds = _prepare_indices_arrays(inds, exec_q, vals_usm_type) ind0 = inds[0] ary_sh = ary.shape diff --git a/dpctl/tensor/_slicing.pxi b/dpctl/tensor/_slicing.pxi index b94167e7e6..eaf9a855fb 100644 --- a/dpctl/tensor/_slicing.pxi +++ b/dpctl/tensor/_slicing.pxi @@ -17,6 +17,7 @@ import numbers from operator import index from cpython.buffer cimport PyObject_CheckBuffer +from numpy import ndarray cdef bint _is_buffer(object o): @@ -46,7 +47,7 @@ cdef Py_ssize_t _slice_len( cdef bint _is_integral(object x) except *: """Gives True if x is an integral slice spec""" - if isinstance(x, usm_ndarray): + if isinstance(x, (ndarray, usm_ndarray)): if x.ndim > 0: return False if x.dtype.kind not in "ui": @@ -74,7 +75,7 @@ cdef bint _is_integral(object x) except *: cdef bint _is_boolean(object x) except *: """Gives True if x is an integral slice spec""" - if isinstance(x, usm_ndarray): + if isinstance(x, (ndarray, usm_ndarray)): if x.ndim > 0: return False if x.dtype.kind not in "b": @@ -185,7 +186,7 @@ def _basic_slice_meta(ind, shape : tuple, strides : tuple, offset : int): raise IndexError( "Index {0} is out of range for axes 0 with " "size {1}".format(ind, shape[0])) - elif isinstance(ind, usm_ndarray): + elif isinstance(ind, (ndarray, usm_ndarray)): return (shape, strides, offset, (ind,), 0) elif isinstance(ind, tuple): axes_referenced = 0 @@ -216,7 +217,7 @@ def _basic_slice_meta(ind, shape : tuple, strides : tuple, offset : int): axes_referenced += 1 if not array_streak_started and array_streak_interrupted: explicit_index += 1 - elif isinstance(i, usm_ndarray): + elif isinstance(i, (ndarray, usm_ndarray)): if not seen_arrays_yet: seen_arrays_yet = True array_streak_started = True @@ -302,7 +303,7 @@ def _basic_slice_meta(ind, shape : tuple, strides : tuple, offset : int): array_streak = False elif _is_integral(ind_i): if array_streak: - if not isinstance(ind_i, usm_ndarray): + if not isinstance(ind_i, (ndarray, usm_ndarray)): ind_i = index(ind_i) # integer will be converted to an array, # still raise if OOB @@ -337,7 +338,7 @@ def _basic_slice_meta(ind, shape : tuple, strides : tuple, offset : int): "Index {0} is out of range for axes " "{1} with size {2}".format(ind_i, k, shape[k]) ) - elif isinstance(ind_i, usm_ndarray): + elif isinstance(ind_i, (ndarray, usm_ndarray)): if not array_streak: array_streak = True if not advanced_start_pos_set: diff --git a/dpctl/tests/test_usm_ndarray_indexing.py b/dpctl/tests/test_usm_ndarray_indexing.py index eeb97461fd..3c6164ea32 100644 --- a/dpctl/tests/test_usm_ndarray_indexing.py +++ b/dpctl/tests/test_usm_ndarray_indexing.py @@ -440,6 +440,28 @@ def test_advanced_slice16(): assert isinstance(y, dpt.usm_ndarray) +def test_integer_indexing_numpy_array(): + q = get_queue_or_skip() + ii = np.asarray([1, 2]) + x = dpt.arange(10, dtype="i4", sycl_queue=q) + y = x[ii] + assert isinstance(y, dpt.usm_ndarray) + assert y.shape == ii.shape + assert dpt.all(dpt.asarray(ii, sycl_queue=q) == y) + + +def test_boolean_indexing_numpy_array(): + q = get_queue_or_skip() + ii = np.asarray( + [False, True, True, False, False, False, False, False, False, False] + ) + x = dpt.arange(10, dtype="i4", sycl_queue=q) + y = x[ii] + assert isinstance(y, dpt.usm_ndarray) + assert y.shape == (2,) + assert dpt.all(x[1:3] == y) + + def test_boolean_indexing_validation(): get_queue_or_skip() x = dpt.zeros(10, dtype="i4")