Skip to content

Commit 9dc5a88

Browse files
committed
move rich_comparisons.hpp into utils
allows these comparisons to be shared outside of sorting more conveniently
1 parent 08ea385 commit 9dc5a88

File tree

6 files changed

+27
-10
lines changed

6 files changed

+27
-10
lines changed

dpctl/tensor/libtensor/source/sorting/rich_comparisons.hpp renamed to dpctl/tensor/libtensor/include/utils/rich_comparisons.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ namespace dpctl
3131
{
3232
namespace tensor
3333
{
34-
namespace py_internal
34+
namespace rich_comparisons
3535
{
3636

3737
namespace detail
@@ -129,6 +129,6 @@ template <typename T> struct DescendingSorter<std::complex<T>>
129129
using type = detail::ExtendedComplexFPGreater<std::complex<T>>;
130130
};
131131

132-
} // end of namespace py_internal
132+
} // end of namespace rich_comparisons
133133
} // end of namespace tensor
134134
} // end of namespace dpctl

dpctl/tensor/libtensor/source/sorting/isin.cpp

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,11 +35,11 @@
3535
#include "kernels/sorting/isin.hpp"
3636
#include "utils/memory_overlap.hpp"
3737
#include "utils/output_validation.hpp"
38+
#include "utils/rich_comparisons.hpp"
3839
#include "utils/sycl_alloc_utils.hpp"
3940
#include "utils/type_dispatch.hpp"
4041
#include "utils/type_utils.hpp"
4142

42-
#include "rich_comparisons.hpp"
4343
#include "simplify_iteration_space.hpp"
4444

4545
namespace py = pybind11;
@@ -68,7 +68,10 @@ template <typename fnT, typename argTy> struct IsinContigFactory
6868
fnT get() const
6969
{
7070
using dpctl::tensor::kernels::isin_contig_impl;
71+
using dpctl::tensor::rich_comparisons::AscendingSorter;
72+
7173
using Compare = typename AscendingSorter<argTy>::type;
74+
7275
return isin_contig_impl<argTy, Compare>;
7376
}
7477
};
@@ -85,7 +88,10 @@ template <typename fnT, typename argTy> struct IsinStridedFactory
8588
fnT get() const
8689
{
8790
using dpctl::tensor::kernels::isin_strided_impl;
91+
using dpctl::tensor::rich_comparisons::AscendingSorter;
92+
8893
using Compare = typename AscendingSorter<argTy>::type;
94+
8995
return isin_strided_impl<argTy, Compare>;
9096
}
9197
};

dpctl/tensor/libtensor/source/sorting/merge_argsort.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,11 +30,11 @@
3030
#include "utils/math_utils.hpp"
3131
#include "utils/memory_overlap.hpp"
3232
#include "utils/output_validation.hpp"
33+
#include "utils/rich_comparisons.hpp"
3334
#include "utils/type_dispatch.hpp"
3435

3536
#include "kernels/sorting/merge_sort.hpp"
3637
#include "kernels/sorting/sort_impl_fn_ptr_t.hpp"
37-
#include "rich_comparisons.hpp"
3838

3939
#include "merge_argsort.hpp"
4040
#include "py_argsort_common.hpp"
@@ -63,6 +63,7 @@ struct AscendingArgSortContigFactory
6363
if constexpr (std::is_same_v<IndexTy, std::int64_t> ||
6464
std::is_same_v<IndexTy, std::int32_t>)
6565
{
66+
using dpctl::tensor::rich_comparisons::AscendingSorter;
6667
using Comp = typename AscendingSorter<argTy>::type;
6768

6869
using dpctl::tensor::kernels::stable_argsort_axis1_contig_impl;
@@ -82,6 +83,7 @@ struct DescendingArgSortContigFactory
8283
if constexpr (std::is_same_v<IndexTy, std::int64_t> ||
8384
std::is_same_v<IndexTy, std::int32_t>)
8485
{
86+
using dpctl::tensor::rich_comparisons::DescendingSorter;
8587
using Comp = typename DescendingSorter<argTy>::type;
8688

8789
using dpctl::tensor::kernels::stable_argsort_axis1_contig_impl;

dpctl/tensor/libtensor/source/sorting/merge_sort.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,14 +31,14 @@
3131
#include "utils/math_utils.hpp"
3232
#include "utils/memory_overlap.hpp"
3333
#include "utils/output_validation.hpp"
34+
#include "utils/rich_comparisons.hpp"
3435
#include "utils/type_dispatch.hpp"
3536

3637
#include "kernels/sorting/merge_sort.hpp"
3738
#include "kernels/sorting/sort_impl_fn_ptr_t.hpp"
3839

3940
#include "merge_sort.hpp"
4041
#include "py_sort_common.hpp"
41-
#include "rich_comparisons.hpp"
4242

4343
namespace td_ns = dpctl::tensor::type_dispatch;
4444

@@ -59,6 +59,7 @@ template <typename fnT, typename argTy> struct AscendingSortContigFactory
5959
{
6060
fnT get()
6161
{
62+
using dpctl::tensor::rich_comparisons::AscendingSorter;
6263
using Comp = typename AscendingSorter<argTy>::type;
6364

6465
using dpctl::tensor::kernels::stable_sort_axis1_contig_impl;
@@ -70,7 +71,9 @@ template <typename fnT, typename argTy> struct DescendingSortContigFactory
7071
{
7172
fnT get()
7273
{
74+
using dpctl::tensor::rich_comparisons::DescendingSorter;
7375
using Comp = typename DescendingSorter<argTy>::type;
76+
7477
using dpctl::tensor::kernels::stable_sort_axis1_contig_impl;
7578
return stable_sort_axis1_contig_impl<argTy, Comp>;
7679
}

dpctl/tensor/libtensor/source/sorting/searchsorted.cpp

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,11 +35,11 @@
3535
#include "kernels/sorting/searchsorted.hpp"
3636
#include "utils/memory_overlap.hpp"
3737
#include "utils/output_validation.hpp"
38+
#include "utils/rich_comparisons.hpp"
3839
#include "utils/sycl_alloc_utils.hpp"
3940
#include "utils/type_dispatch.hpp"
4041
#include "utils/type_utils.hpp"
4142

42-
#include "rich_comparisons.hpp"
4343
#include "simplify_iteration_space.hpp"
4444

4545
namespace py = pybind11;
@@ -76,6 +76,7 @@ struct LeftSideSearchSortedContigFactory
7676
{
7777
static constexpr bool left_side_search(true);
7878
using dpctl::tensor::kernels::searchsorted_contig_impl;
79+
using dpctl::tensor::rich_comparisons::AscendingSorter;
7980

8081
using Compare = typename AscendingSorter<argTy>::type;
8182

@@ -99,7 +100,9 @@ struct RightSideSearchSortedContigFactory
99100
std::is_same_v<indTy, std::int64_t>)
100101
{
101102
static constexpr bool right_side_search(false);
103+
102104
using dpctl::tensor::kernels::searchsorted_contig_impl;
105+
using dpctl::tensor::rich_comparisons::AscendingSorter;
103106

104107
using Compare = typename AscendingSorter<argTy>::type;
105108

@@ -132,6 +135,7 @@ struct LeftSideSearchSortedStridedFactory
132135
{
133136
static constexpr bool left_side_search(true);
134137
using dpctl::tensor::kernels::searchsorted_strided_impl;
138+
using dpctl::tensor::rich_comparisons::AscendingSorter;
135139

136140
using Compare = typename AscendingSorter<argTy>::type;
137141

@@ -156,6 +160,7 @@ struct RightSideSearchSortedStridedFactory
156160
{
157161
static constexpr bool right_side_search(false);
158162
using dpctl::tensor::kernels::searchsorted_strided_impl;
163+
using dpctl::tensor::rich_comparisons::AscendingSorter;
159164

160165
using Compare = typename AscendingSorter<argTy>::type;
161166

dpctl/tensor/libtensor/source/sorting/topk.cpp

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -41,10 +41,10 @@
4141
#include "utils/math_utils.hpp"
4242
#include "utils/memory_overlap.hpp"
4343
#include "utils/output_validation.hpp"
44+
#include "utils/rich_comparisons.hpp"
4445
#include "utils/type_dispatch.hpp"
4546
#include "utils/type_utils.hpp"
4647

47-
#include "rich_comparisons.hpp"
4848
#include "topk.hpp"
4949

5050
namespace dpctl
@@ -110,15 +110,16 @@ sycl::event topk_caller(sycl::queue &exec_q,
110110
using dpctl::tensor::kernels::topk_merge_impl;
111111
if (largest) {
112112
using CompTy =
113-
typename dpctl::tensor::py_internal::DescendingSorter<
113+
typename dpctl::tensor::rich_comparisons::DescendingSorter<
114114
argTy>::type;
115115
return topk_merge_impl<argTy, IndexTy, CompTy>(
116116
exec_q, iter_nelems, axis_nelems, k, arg_cp, vals_cp, inds_cp,
117117
depends);
118118
}
119119
else {
120-
using CompTy = typename dpctl::tensor::py_internal::AscendingSorter<
121-
argTy>::type;
120+
using CompTy =
121+
typename dpctl::tensor::rich_comparisons::AscendingSorter<
122+
argTy>::type;
122123
return topk_merge_impl<argTy, IndexTy, CompTy>(
123124
exec_q, iter_nelems, axis_nelems, k, arg_cp, vals_cp, inds_cp,
124125
depends);

0 commit comments

Comments
 (0)