Skip to content

Commit a8d9474

Browse files
committed
factor out compare template param in isin
1 parent 9dc5a88 commit a8d9474

File tree

2 files changed

+12
-19
lines changed

2 files changed

+12
-19
lines changed

dpctl/tensor/libtensor/include/kernels/sorting/isin.hpp

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
#include "kernels/dpctl_tensor_types.hpp"
3535
#include "kernels/sorting/search_sorted_detail.hpp"
3636
#include "utils/offset_utils.hpp"
37+
#include "utils/rich_comparisons.hpp"
3738

3839
namespace dpctl
3940
{
@@ -47,8 +48,7 @@ using dpctl::tensor::ssize_t;
4748
template <typename T,
4849
typename HayIndexerT,
4950
typename NeedlesIndexerT,
50-
typename OutIndexerT,
51-
typename Compare>
51+
typename OutIndexerT>
5252
struct IsinFunctor
5353
{
5454
private:
@@ -78,6 +78,8 @@ struct IsinFunctor
7878

7979
void operator()(sycl::id<1> id) const
8080
{
81+
using Compare =
82+
typename dpctl::tensor::rich_comparisons::AscendingSorter<T>::type;
8183
static constexpr Compare comp{};
8284

8385
const std::size_t i = id[0];
@@ -115,7 +117,7 @@ typedef sycl::event (*isin_contig_impl_fp_ptr_t)(
115117

116118
template <typename T> class isin_contig_impl_krn;
117119

118-
template <typename T, typename Compare>
120+
template <typename T>
119121
sycl::event isin_contig_impl(sycl::queue &exec_q,
120122
const bool invert,
121123
const std::size_t hay_nelems,
@@ -148,9 +150,9 @@ sycl::event isin_contig_impl(sycl::queue &exec_q,
148150
static constexpr TrivialIndexerT out_indexer{};
149151

150152
const auto fnctr =
151-
IsinFunctor<T, TrivialIndexerT, TrivialIndexerT, TrivialIndexerT,
152-
Compare>(invert, hay_tp, needles_tp, out_tp, hay_nelems,
153-
hay_indexer, needles_indexer, out_indexer);
153+
IsinFunctor<T, TrivialIndexerT, TrivialIndexerT, TrivialIndexerT>(
154+
invert, hay_tp, needles_tp, out_tp, hay_nelems, hay_indexer,
155+
needles_indexer, out_indexer);
154156

155157
cgh.parallel_for<KernelName>(gRange, fnctr);
156158
});
@@ -176,7 +178,7 @@ typedef sycl::event (*isin_strided_impl_fp_ptr_t)(
176178

177179
template <typename T> class isin_strided_impl_krn;
178180

179-
template <typename T, typename Compare>
181+
template <typename T>
180182
sycl::event isin_strided_impl(
181183
sycl::queue &exec_q,
182184
const bool invert,
@@ -224,7 +226,7 @@ sycl::event isin_strided_impl(
224226
out_strides);
225227

226228
const auto fnctr =
227-
IsinFunctor<T, HayIndexerT, NeedlesIndexerT, OutIndexerT, Compare>(
229+
IsinFunctor<T, HayIndexerT, NeedlesIndexerT, OutIndexerT>(
228230
invert, hay_tp, needles_tp, out_tp, hay_nelems, hay_indexer,
229231
needles_indexer, out_indexer);
230232
using KernelName = class isin_strided_impl_krn<T>;

dpctl/tensor/libtensor/source/sorting/isin.cpp

Lines changed: 2 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,6 @@
3535
#include "kernels/sorting/isin.hpp"
3636
#include "utils/memory_overlap.hpp"
3737
#include "utils/output_validation.hpp"
38-
#include "utils/rich_comparisons.hpp"
3938
#include "utils/sycl_alloc_utils.hpp"
4039
#include "utils/type_dispatch.hpp"
4140
#include "utils/type_utils.hpp"
@@ -68,11 +67,7 @@ template <typename fnT, typename argTy> struct IsinContigFactory
6867
fnT get() const
6968
{
7069
using dpctl::tensor::kernels::isin_contig_impl;
71-
using dpctl::tensor::rich_comparisons::AscendingSorter;
72-
73-
using Compare = typename AscendingSorter<argTy>::type;
74-
75-
return isin_contig_impl<argTy, Compare>;
70+
return isin_contig_impl<argTy>;
7671
}
7772
};
7873

@@ -88,11 +83,7 @@ template <typename fnT, typename argTy> struct IsinStridedFactory
8883
fnT get() const
8984
{
9085
using dpctl::tensor::kernels::isin_strided_impl;
91-
using dpctl::tensor::rich_comparisons::AscendingSorter;
92-
93-
using Compare = typename AscendingSorter<argTy>::type;
94-
95-
return isin_strided_impl<argTy, Compare>;
86+
return isin_strided_impl<argTy>;
9687
}
9788
};
9889

0 commit comments

Comments
 (0)