Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
518fd6a
Implement tensor.isin
ndgrigorian Jun 6, 2025
85106f4
factor common utilities for scalar arguments to a new file
ndgrigorian Jun 8, 2025
ce757c8
Make constexpr variables in `isin` static
ndgrigorian Jun 11, 2025
8425cdc
Update implementation of `isin`
ndgrigorian Jun 11, 2025
204025a
Update per review comments
ndgrigorian Jun 16, 2025
cfb5154
Allow x to be a scalar in isin and remove assume_unique
ndgrigorian Jun 16, 2025
a448725
Make comparator static constexpr
ndgrigorian Jun 16, 2025
882a431
add basic tests for isin functionality
ndgrigorian Jun 17, 2025
d11626c
Remove unused import of dpctl in _set_functions.py
ndgrigorian Jun 17, 2025
710a6d0
Add type hints to isin
ndgrigorian Jun 17, 2025
411b976
Add usm_type to test_buf in isin
ndgrigorian Jun 18, 2025
3077bd6
Address review comments for isin tests
ndgrigorian Jun 18, 2025
ed9f376
Add test covering nans and +/- 0 in isin
ndgrigorian Jun 18, 2025
72278b5
Add test for isin with Python scalar args
ndgrigorian Jun 23, 2025
5b877c8
Add test for combinations of dtypes as inputs to isin
ndgrigorian Jun 23, 2025
538176c
Add compute follows data test for isin
ndgrigorian Jun 23, 2025
0cd24f5
Add isin to rendered docs
ndgrigorian Jun 23, 2025
0070331
Test that isin output is C-contiguous when input is strided
ndgrigorian Jun 23, 2025
2ebb3a1
improve formatting of isin docstring
ndgrigorian Jun 26, 2025
26f68d2
move rich_comparisons.hpp into utils
ndgrigorian Sep 11, 2025
365c3b4
factor out compare template param in isin
ndgrigorian Sep 11, 2025
44661ba
add missing includes to rich_comparisons
ndgrigorian Sep 17, 2025
62a9fa4
drop unused includes in isin kernel implementation
ndgrigorian Sep 17, 2025
ff68bca
validate that `invert` kwarg is boolean
ndgrigorian Sep 17, 2025
e28e1aa
add a test for input validation of invert
ndgrigorian Sep 17, 2025
68bc84e
make invert bool type check stricter
ndgrigorian Sep 17, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ Set Functions
.. autosummary::
:toctree: generated

isin
unique_all
unique_counts
unique_inverse
Expand Down
1 change: 1 addition & 0 deletions dpctl/tensor/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ set(_reduction_sources
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/reductions/sum.cpp
)
set(_sorting_sources
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/sorting/isin.cpp
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/sorting/merge_sort.cpp
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/sorting/merge_argsort.cpp
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/sorting/radix_sort.cpp
Expand Down
2 changes: 2 additions & 0 deletions dpctl/tensor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,7 @@
)
from ._searchsorted import searchsorted
from ._set_functions import (
isin,
unique_all,
unique_counts,
unique_inverse,
Expand Down Expand Up @@ -394,4 +395,5 @@
"top_k",
"dldevice_to_sycl_device",
"sycl_device_to_dldevice",
"isin",
]
10 changes: 5 additions & 5 deletions dpctl/tensor/_clip.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,16 +23,16 @@
_empty_like_pair_orderK,
_empty_like_triple_orderK,
)
from dpctl.tensor._elementwise_common import (
from dpctl.tensor._manipulation_functions import _broadcast_shape_impl
from dpctl.tensor._type_utils import _can_cast
from dpctl.utils import ExecutionPlacementError, SequentialOrderManager

from ._scalar_utils import (
_get_dtype,
_get_queue_usm_type,
_get_shape,
_validate_dtype,
)
from dpctl.tensor._manipulation_functions import _broadcast_shape_impl
from dpctl.tensor._type_utils import _can_cast
from dpctl.utils import ExecutionPlacementError, SequentialOrderManager

from ._type_utils import (
_resolve_one_strong_one_weak_types,
_resolve_one_strong_two_weak_types,
Expand Down
89 changes: 6 additions & 83 deletions dpctl/tensor/_elementwise_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,32 +14,27 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import numbers

import numpy as np

import dpctl
import dpctl.memory as dpm
import dpctl.tensor as dpt
import dpctl.tensor._tensor_impl as ti
from dpctl.tensor._manipulation_functions import _broadcast_shape_impl
from dpctl.tensor._usmarray import _is_object_with_buffer_protocol as _is_buffer
from dpctl.utils import ExecutionPlacementError, SequentialOrderManager

from ._copy_utils import _empty_like_orderK, _empty_like_pair_orderK
from ._scalar_utils import (
_get_dtype,
_get_queue_usm_type,
_get_shape,
_validate_dtype,
)
from ._type_utils import (
WeakBooleanType,
WeakComplexType,
WeakFloatingType,
WeakIntegralType,
_acceptance_fn_default_binary,
_acceptance_fn_default_unary,
_all_data_types,
_find_buf_dtype,
_find_buf_dtype2,
_find_buf_dtype_in_place_op,
_resolve_weak_types,
_to_device_supported_dtype,
)


Expand Down Expand Up @@ -289,78 +284,6 @@ def __call__(self, x, /, *, out=None, order="K"):
return out


def _get_queue_usm_type(o):
"""Return SYCL device where object `o` allocated memory, or None."""
if isinstance(o, dpt.usm_ndarray):
return o.sycl_queue, o.usm_type
elif hasattr(o, "__sycl_usm_array_interface__"):
try:
m = dpm.as_usm_memory(o)
return m.sycl_queue, m.get_usm_type()
except Exception:
return None, None
return None, None


def _get_dtype(o, dev):
if isinstance(o, dpt.usm_ndarray):
return o.dtype
if hasattr(o, "__sycl_usm_array_interface__"):
return dpt.asarray(o).dtype
if _is_buffer(o):
host_dt = np.array(o).dtype
dev_dt = _to_device_supported_dtype(host_dt, dev)
return dev_dt
if hasattr(o, "dtype"):
dev_dt = _to_device_supported_dtype(o.dtype, dev)
return dev_dt
if isinstance(o, bool):
return WeakBooleanType(o)
if isinstance(o, int):
return WeakIntegralType(o)
if isinstance(o, float):
return WeakFloatingType(o)
if isinstance(o, complex):
return WeakComplexType(o)
return np.object_


def _validate_dtype(dt) -> bool:
return isinstance(
dt,
(WeakBooleanType, WeakIntegralType, WeakFloatingType, WeakComplexType),
) or (
isinstance(dt, dpt.dtype)
and dt
in [
dpt.bool,
dpt.int8,
dpt.uint8,
dpt.int16,
dpt.uint16,
dpt.int32,
dpt.uint32,
dpt.int64,
dpt.uint64,
dpt.float16,
dpt.float32,
dpt.float64,
dpt.complex64,
dpt.complex128,
]
)


def _get_shape(o):
if isinstance(o, dpt.usm_ndarray):
return o.shape
if _is_buffer(o):
return memoryview(o).shape
if isinstance(o, numbers.Number):
return tuple()
return getattr(o, "shape", tuple())


class BinaryElementwiseFunc:
"""
Class that implements binary element-wise functions.
Expand Down
111 changes: 111 additions & 0 deletions dpctl/tensor/_scalar_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
# Data Parallel Control (dpctl)
#
# Copyright 2020-2025 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numbers

import numpy as np

import dpctl.memory as dpm
import dpctl.tensor as dpt
from dpctl.tensor._usmarray import _is_object_with_buffer_protocol as _is_buffer

from ._type_utils import (
WeakBooleanType,
WeakComplexType,
WeakFloatingType,
WeakIntegralType,
_to_device_supported_dtype,
)


def _get_queue_usm_type(o):
"""Return SYCL device where object `o` allocated memory, or None."""
if isinstance(o, dpt.usm_ndarray):
return o.sycl_queue, o.usm_type
elif hasattr(o, "__sycl_usm_array_interface__"):
try:
m = dpm.as_usm_memory(o)
return m.sycl_queue, m.get_usm_type()
except Exception:
return None, None
return None, None


def _get_dtype(o, dev):
if isinstance(o, dpt.usm_ndarray):
return o.dtype
if hasattr(o, "__sycl_usm_array_interface__"):
return dpt.asarray(o).dtype
if _is_buffer(o):
host_dt = np.array(o).dtype
dev_dt = _to_device_supported_dtype(host_dt, dev)
return dev_dt
if hasattr(o, "dtype"):
dev_dt = _to_device_supported_dtype(o.dtype, dev)
return dev_dt
if isinstance(o, bool):
return WeakBooleanType(o)
if isinstance(o, int):
return WeakIntegralType(o)
if isinstance(o, float):
return WeakFloatingType(o)
if isinstance(o, complex):
return WeakComplexType(o)
return np.object_


def _validate_dtype(dt) -> bool:
return isinstance(
dt,
(WeakBooleanType, WeakIntegralType, WeakFloatingType, WeakComplexType),
) or (
isinstance(dt, dpt.dtype)
and dt
in [
dpt.bool,
dpt.int8,
dpt.uint8,
dpt.int16,
dpt.uint16,
dpt.int32,
dpt.uint32,
dpt.int64,
dpt.uint64,
dpt.float16,
dpt.float32,
dpt.float64,
dpt.complex64,
dpt.complex128,
]
)


def _get_shape(o):
if isinstance(o, dpt.usm_ndarray):
return o.shape
if _is_buffer(o):
return memoryview(o).shape
if isinstance(o, numbers.Number):
return tuple()
return getattr(o, "shape", tuple())


__all__ = [
"_get_dtype",
"_get_queue_usm_type",
"_get_shape",
"_validate_dtype",
]
10 changes: 5 additions & 5 deletions dpctl/tensor/_search_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,16 @@
import dpctl
import dpctl.tensor as dpt
import dpctl.tensor._tensor_impl as ti
from dpctl.tensor._elementwise_common import (
from dpctl.tensor._manipulation_functions import _broadcast_shape_impl
from dpctl.utils import ExecutionPlacementError, SequentialOrderManager

from ._copy_utils import _empty_like_orderK, _empty_like_triple_orderK
from ._scalar_utils import (
_get_dtype,
_get_queue_usm_type,
_get_shape,
_validate_dtype,
)
from dpctl.tensor._manipulation_functions import _broadcast_shape_impl
from dpctl.utils import ExecutionPlacementError, SequentialOrderManager

from ._copy_utils import _empty_like_orderK, _empty_like_triple_orderK
from ._type_utils import (
WeakBooleanType,
WeakComplexType,
Expand Down
Loading
Loading