From 95238df3119dc140bc90b6b0df26c1b5521c236f Mon Sep 17 00:00:00 2001 From: Peter Thoman Date: Wed, 30 Apr 2025 13:44:04 +0200 Subject: [PATCH] SYCL_KHR_GROUP_INTERFACE prototype --- .../simsycl/extensions/khr_group_interface.hh | 140 +++++++++ test/CMakeLists.txt | 1 + test/extensions/khr_group_interface_tests.cc | 272 ++++++++++++++++++ 3 files changed, 413 insertions(+) create mode 100644 include/simsycl/extensions/khr_group_interface.hh create mode 100644 test/extensions/khr_group_interface_tests.cc diff --git a/include/simsycl/extensions/khr_group_interface.hh b/include/simsycl/extensions/khr_group_interface.hh new file mode 100644 index 0000000..5513bf9 --- /dev/null +++ b/include/simsycl/extensions/khr_group_interface.hh @@ -0,0 +1,140 @@ +#include "sycl/sycl.hpp" // IWYU pragma: keep + +#define SYCL_KHR_GROUP_INTERFACE 1 + +namespace simsycl::sycl::khr { + +template +class member_item { + public: + using id_type = typename ParentGroup::id_type; + using linear_id_type = typename ParentGroup::linear_id_type; + using range_type = typename ParentGroup::range_type; + // using extents_type = /* extents of all 1s with ParentGroup's index type */; // C++23 + using size_type = typename ParentGroup::size_type; + static constexpr int dimensions = ParentGroup::dimensions; + static constexpr memory_scope fence_scope = memory_scope::work_item; + + /* -- common by-value interface members -- */ + + id_type id() const noexcept { return m_parent_group.get_local_id(); } + linear_id_type linear_id() const noexcept { return m_parent_group.get_local_linear_id(); } + + range_type range() const noexcept { return m_parent_group.get_local_range(); } + + // constexpr extents_type extents() const noexcept; // C++23 + // constexpr extents_type::index_type extent(extents_type::rank_type r) const noexcept; // C++23 + + // static constexpr extents_type::rank_type rank() noexcept; // C++23 + // static constexpr extents_type::rank_type rank_dynamic() noexcept; // C++23 + // static constexpr size_t static_extent(rank_type r) noexcept; // C++23 + + constexpr size_type size() const noexcept { return 1; } + + private: + ParentGroup m_parent_group; + member_item(ParentGroup g) noexcept : m_parent_group(g) {} + + linear_id_type get_local_linear_id() const noexcept { return m_parent_group.get_local_linear_id(); } + + template + friend member_item get_member_item(Group g) noexcept; + template + friend bool leader_of(Group g) noexcept; +}; + +template +class work_group { + public: + using id_type = sycl::id; + using linear_id_type = size_t; + using range_type = sycl::range; + // using extents_type = std::dextents; // C++23 + using size_type = size_t; + static constexpr int dimensions = Dimensions; + static constexpr memory_scope fence_scope = memory_scope::work_group; + + work_group(group g) noexcept : m_group(g) {} + + operator group() const noexcept { return m_group; } + + /* -- common by-value interface members -- */ + + id_type id() const noexcept { return m_group.get_group_id(); } + linear_id_type linear_id() const noexcept { return m_group.get_group_linear_id(); } + + range_type range() const noexcept { return m_group.get_group_range(); } + + // extents_type extents() const noexcept; // C++23 + // extents_type::index_type extent(extents_type::rank_type r) const noexcept; // C++23 + + // static constexpr extents_type::rank_type rank() noexcept; // C++23 + // static constexpr extents_type::rank_type rank_dynamic() noexcept; // C++23 + // static constexpr size_t static_extent(rank_type r) noexcept; // C++23 + + size_type size() const noexcept { return m_group.get_local_range().size(); } + + private: + group m_group; + + id_type get_local_id() const noexcept { return m_group.get_local_id(); } + linear_id_type get_local_linear_id() const noexcept { return m_group.get_local_linear_id(); } + range_type get_local_range() const noexcept { return m_group.get_local_range(); } + friend class member_item; + template + friend bool leader_of(Group g) noexcept; +}; + +class sub_group { + public: + using id_type = sycl::id<1>; + using linear_id_type = uint32_t; + using range_type = sycl::range<1>; + // using extents_type = std::dextents; // C++23 + using size_type = uint32_t; + static constexpr int dimensions = 1; + static constexpr memory_scope fence_scope = memory_scope::sub_group; + + sub_group(sycl::sub_group sg) noexcept : m_sub_group(sg) {} + + operator sycl::sub_group() const noexcept { return m_sub_group; } + + /* -- common by-value interface members -- */ + + id_type id() const noexcept { return m_sub_group.get_group_id(); } + linear_id_type linear_id() const noexcept { return m_sub_group.get_group_linear_id(); } + + range_type range() const noexcept { return m_sub_group.get_group_range(); } + + // extents_type extents() const noexcept; // C++23 + // extents_type::index_type extent(extents_type::rank_type r) const noexcept; // C++23 + + // static constexpr extents_type::rank_type rank() noexcept; // C++23 + // static constexpr extents_type::rank_type rank_dynamic() noexcept; // C++23 + // static constexpr size_t static_extent(rank_type r) noexcept; // C++23 + + size_type size() const noexcept { return m_sub_group.get_local_range().size(); } + size_type max_size() const noexcept { return m_sub_group.get_max_local_range().size(); } + + private: + sycl::sub_group m_sub_group; + + id_type get_local_id() const noexcept { return m_sub_group.get_local_id(); } + linear_id_type get_local_linear_id() const noexcept { return m_sub_group.get_local_linear_id(); } + range_type get_local_range() const noexcept { return m_sub_group.get_local_range(); } + friend class member_item; + template + friend bool leader_of(Group g) noexcept; +}; + +template +member_item get_member_item(Group g) noexcept { + return member_item(g); +} + +template +bool leader_of(Group g) noexcept { + return g.get_local_linear_id() == 0; +} + +} // namespace simsycl::sycl::khr diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 3b95dd7..457b23c 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -22,6 +22,7 @@ add_executable(tests simulation_tests.cc alloc_tests.cc vec_tests.cc + extensions/khr_group_interface_tests.cc ) add_sycl_to_target(TARGET tests SIMSYCL_ALL_WARNINGS) diff --git a/test/extensions/khr_group_interface_tests.cc b/test/extensions/khr_group_interface_tests.cc new file mode 100644 index 0000000..5d5fa2b --- /dev/null +++ b/test/extensions/khr_group_interface_tests.cc @@ -0,0 +1,272 @@ +#include +#include + +#include + +namespace group_interface::tests { + +template +static void test_work_group(sycl::nd_item it) { + sycl::khr::work_group work_group{it.get_group()}; + sycl::group group{it.get_group()}; + + // id + static_assert(std::is_same_v::id_type>); + CHECK(group.get_group_id() == work_group.id()); + + // linear_id + static_assert( + std::is_same_v::linear_id_type>); + CHECK(group.get_group_linear_id() == work_group.linear_id()); + + // range + static_assert(std::is_same_v::range_type>); + CHECK(group.get_group_range() == work_group.range()); + +#if __cplusplus >= 202302L + // extents + static_assert( + std::is_same_v::extents_type>); + { + const sycl::range localRange = group.get_local_range(); + if constexpr(Dimensions == 1) + CHECK(work_group.extents() == std::dextents(localRange[0])); + else if constexpr(Dimensions == 2) + CHECK(work_group.extents() == std::dextents(localRange[0], localRange[1])); + else if constexpr(Dimensions == 3) + CHECK(work_group.extents() + == std::dextents(localRange[0], localRange[1], localRange[2])); + } + + // extent + static_assert(std::is_same_v::extents_type::index_type>); + for(int i = 0; i < Dimensions; i++) CHECK(work_group.extent(i) == work_group.extents().extent(i)); + + // rank + static_assert(std::is_same_v::rank()), + typename sycl::khr::work_group::extents_type::rank_type>); + static_assert(decltype(work_group)::rank() == decltype(work_group.extents())::rank()); + + // rank_dynamic + static_assert(std::is_same_v::rank_dynamic()), + typename sycl::khr::work_group::extents_type::rank_type>); + static_assert(decltype(work_group)::rank_dynamic() == decltype(work_group.extents())::rank_dynamic()); + + // static_extent + static_assert(std::is_same_v::static_extent(0)), std::size_t>); + static_assert(decltype(work_group)::static_extent(0) == decltype(work_group.extents())::static_extent(0)); + if constexpr(Dimensions >= 2) + static_assert(decltype(work_group)::static_extent(1) == decltype(work_group.extents())::static_extent(1)); + if constexpr(Dimensions == 3) + static_assert(decltype(work_group)::static_extent(2) == decltype(work_group.extents())::static_extent(2)); +#endif + + // size + static_assert(std::is_same_v::size_type>); + CHECK(group.get_local_linear_range() == work_group.size()); + + // leader_of + static_assert(std::is_same_v); + CHECK(group.leader() == sycl::khr::leader_of(work_group)); +} + +template +static void test_sub_group(sycl::nd_item it) { + sycl::khr::sub_group sub_group{it.get_sub_group()}; + sycl::sub_group group{it.get_sub_group()}; + + // id + static_assert(std::is_same_v); + CHECK(group.get_group_id() == sub_group.id()); + + // linear_id + static_assert(std::is_same_v); + CHECK(group.get_group_linear_id() == sub_group.linear_id()); + + // range + static_assert(std::is_same_v); + CHECK(group.get_group_range() == sub_group.range()); + +#if __cplusplus >= 202302L + // extents + static_assert(std::is_same_v); + CHECK(sub_group.extents() == std::dextents(group.get_local_linear_range())); + + // extent + static_assert( + std::is_same_v); + CHECK(sub_group.extent(0) == sub_group.extents().extent(0)); + + // rank + static_assert( + std::is_same_v); + static_assert(decltype(sub_group)::rank() == decltype(sub_group.extents())::rank()); + + // rank_dynamic + static_assert(std::is_same_v); + static_assert(decltype(sub_group)::rank_dynamic() == decltype(sub_group.extents())::rank_dynamic()); + + // static_extent + static_assert(std::is_same_v); + static_assert(decltype(sub_group)::static_extent(0) == decltype(sub_group.extents())::static_extent(0)); +#endif + + // size + static_assert(std::is_same_v); + CHECK(group.get_local_range()[0] == sub_group.size()); + + // max_size + static_assert(std::is_same_v); + CHECK(group.get_max_local_range()[0] == sub_group.max_size()); + + // leader_of + static_assert(std::is_same_v); + CHECK(group.leader() == sycl::khr::leader_of(sub_group)); +} + +template +static void test_work_item_group(sycl::nd_item it) { + sycl::group group{it.get_group()}; + sycl::khr::work_group work_group{group}; + sycl::khr::member_item item{sycl::khr::get_member_item(work_group)}; + + // id + static_assert(std::is_same_v>::id_type>); + CHECK(group.get_local_id() == item.id()); + + // linear_id + static_assert(std::is_same_v>::linear_id_type>); + CHECK(group.get_local_linear_id() == item.linear_id()); + + // range + static_assert(std::is_same_v>::range_type>); + CHECK(group.get_local_range() == item.range()); + +#if __cplusplus >= 202302L + // extents + static_assert(std::is_same_v>::extents_type>); + if constexpr(Dimensions == 1) + CHECK(item.extents() == std::extents()); + else if constexpr(Dimensions == 2) + CHECK(item.extents() == std::extents()); + else if constexpr(Dimensions == 3) + CHECK(item.extents() == std::extents()); + + // extent + static_assert(std::is_same_v>::extents_type::index_type>); + for(int i = 0; i < Dimensions; i++) CHECK(item.extent(i) == item.extents().extent(i)); + + // rank + static_assert(std::is_same_v>::rank()), + typename sycl::khr::member_item>::extents_type::rank_type>); + static_assert(decltype(item)::rank() == decltype(item.extents())::rank()); + + // rank_dynamic + static_assert(std::is_same_v>::rank_dynamic()), + typename sycl::khr::member_item>::extents_type::rank_type>); + static_assert(decltype(item)::rank_dynamic() == decltype(item.extents())::rank_dynamic()); + + // static_extent + static_assert(std::is_same_v>::static_extent(0)), + std::size_t>); + static_assert(decltype(item)::static_extent(0) == decltype(item.extents())::static_extent(0)); + if constexpr(Dimensions >= 2) + static_assert(decltype(item)::static_extent(1) == decltype(item.extents())::static_extent(1)); + if constexpr(Dimensions == 3) + static_assert(decltype(item)::static_extent(2) == decltype(item.extents())::static_extent(2)); +#endif + + // size + static_assert(std::is_same_v>::size_type>); + CHECK(1 == item.size()); +} + +template +static void test_work_item_subgroup(sycl::nd_item it) { + sycl::sub_group group{it.get_sub_group()}; + sycl::khr::sub_group sub_group{group}; + sycl::khr::member_item item{sycl::khr::get_member_item(sub_group)}; + + // id + static_assert(std::is_same_v::id_type>); + CHECK(group.get_local_id() == item.id()); + + // linear_id + static_assert(std::is_same_v::linear_id_type>); + CHECK(group.get_local_linear_id() == item.linear_id()); + + // range + static_assert( + std::is_same_v::range_type>); + CHECK(group.get_local_range() == item.range()); + +#if __cplusplus >= 202302L + // extents + static_assert( + std::is_same_v::extents_type>); + CHECK(item.extents() == std::extents()); + + // extent + static_assert(std::is_same_v::extents_type::index_type>); + CHECK(item.extent(0) == item.extents().extent(0)); + + // rank + static_assert(std::is_same_v::rank()), + typename sycl::khr::member_item::extents_type::rank_type>); + static_assert(decltype(item)::rank() == decltype(item.extents())::rank()); + + // rank_dynamic + static_assert(std::is_same_v::rank_dynamic()), + typename sycl::khr::member_item::extents_type::rank_type>); + static_assert(decltype(item)::rank_dynamic() == decltype(item.extents())::rank_dynamic()); + + // static_extent + static_assert(std::is_same_v>::static_extent(0)), + std::size_t>); + static_assert(decltype(item)::static_extent(0) == decltype(item.extents())::static_extent(0)); +#endif + + // size + static_assert( + std::is_same_v::size_type>); + CHECK(1 == item.size()); +} + +template +static void test_group_interface(sycl::nd_range nd_range) { + sycl::queue q{}; + q.submit([&](sycl::handler &cgh) { + cgh.parallel_for(nd_range, [=](sycl::nd_item it) { + test_work_group(it); + test_sub_group(it); + test_work_item_group(it); + test_work_item_subgroup(it); + }); + }); +} + +TEST_CASE("the group interface extension defines the SYCL_KHR_GROUP_INTERFACE " + "macro", + "[khr_group_interface]") { +#ifndef SYCL_KHR_GROUP_INTERFACE + static_assert(false, "SYCL_KHR_GROUP_INTERFACE is not defined"); +#endif +} + +TEST_CASE("khr group interface extension work_group, sub_group, and work_item classes", "[khr_group_interface]") { + test_group_interface(sycl::nd_range<1>(10, 2)); + test_group_interface(sycl::nd_range<2>({6, 6}, {2, 3})); + test_group_interface(sycl::nd_range<3>({2, 4, 6}, {1, 2, 3})); +} + +} // namespace group_interface::tests