34
34
#include " kernels/dpctl_tensor_types.hpp"
35
35
#include " kernels/sorting/search_sorted_detail.hpp"
36
36
#include " utils/offset_utils.hpp"
37
+ #include " utils/rich_comparisons.hpp"
37
38
38
39
namespace dpctl
39
40
{
@@ -47,8 +48,7 @@ using dpctl::tensor::ssize_t;
47
48
template <typename T,
48
49
typename HayIndexerT,
49
50
typename NeedlesIndexerT,
50
- typename OutIndexerT,
51
- typename Compare>
51
+ typename OutIndexerT>
52
52
struct IsinFunctor
53
53
{
54
54
private:
@@ -78,6 +78,8 @@ struct IsinFunctor
78
78
79
79
void operator ()(sycl::id<1 > id) const
80
80
{
81
+ using Compare =
82
+ typename dpctl::tensor::rich_comparisons::AscendingSorter<T>::type;
81
83
static constexpr Compare comp{};
82
84
83
85
const std::size_t i = id[0 ];
@@ -115,7 +117,7 @@ typedef sycl::event (*isin_contig_impl_fp_ptr_t)(
115
117
116
118
template <typename T> class isin_contig_impl_krn ;
117
119
118
- template <typename T, typename Compare >
120
+ template <typename T>
119
121
sycl::event isin_contig_impl (sycl::queue &exec_q,
120
122
const bool invert,
121
123
const std::size_t hay_nelems,
@@ -148,9 +150,9 @@ sycl::event isin_contig_impl(sycl::queue &exec_q,
148
150
static constexpr TrivialIndexerT out_indexer{};
149
151
150
152
const auto fnctr =
151
- IsinFunctor<T, TrivialIndexerT, TrivialIndexerT, TrivialIndexerT,
152
- Compare>( invert, hay_tp, needles_tp, out_tp, hay_nelems,
153
- hay_indexer, needles_indexer, out_indexer);
153
+ IsinFunctor<T, TrivialIndexerT, TrivialIndexerT, TrivialIndexerT>(
154
+ invert, hay_tp, needles_tp, out_tp, hay_nelems, hay_indexer ,
155
+ needles_indexer, out_indexer);
154
156
155
157
cgh.parallel_for <KernelName>(gRange , fnctr);
156
158
});
@@ -176,7 +178,7 @@ typedef sycl::event (*isin_strided_impl_fp_ptr_t)(
176
178
177
179
template <typename T> class isin_strided_impl_krn ;
178
180
179
- template <typename T, typename Compare >
181
+ template <typename T>
180
182
sycl::event isin_strided_impl (
181
183
sycl::queue &exec_q,
182
184
const bool invert,
@@ -224,7 +226,7 @@ sycl::event isin_strided_impl(
224
226
out_strides);
225
227
226
228
const auto fnctr =
227
- IsinFunctor<T, HayIndexerT, NeedlesIndexerT, OutIndexerT, Compare >(
229
+ IsinFunctor<T, HayIndexerT, NeedlesIndexerT, OutIndexerT>(
228
230
invert, hay_tp, needles_tp, out_tp, hay_nelems, hay_indexer,
229
231
needles_indexer, out_indexer);
230
232
using KernelName = class isin_strided_impl_krn <T>;
0 commit comments