diff --git a/tests/test_reductions.cxx b/tests/test_reductions.cxx index 12f6a58e..4fc7fa39 100644 --- a/tests/test_reductions.cxx +++ b/tests/test_reductions.cxx @@ -1,5 +1,6 @@ #include +#include #include #include @@ -9,18 +10,44 @@ using namespace gt::placeholders; -TEST(reductions, sum_axis_to_2d) +template +struct reductions_config { - gt::gtensor a({{11., 21., 31.}, {12., 22., 32.}}); + using value_type = T; + using space_type = S; +}; + +template +class Reductions : public testing::Test +{}; + +// FIXME, would be nice to test gt::complex, too, but that doesn't +// compile right now +using reduction_configs = + ::testing::Types, + reductions_config +#ifdef GTENSOR_HAVE_DEVICE + , + reductions_config, + reductions_config +#endif + >; +TYPED_TEST_SUITE(Reductions, reduction_configs); + +TYPED_TEST(Reductions, sum_axis_to_2d) +{ + using T = typename TypeParam::value_type; + using S = typename TypeParam::space_type; + gt::gtensor a({{11., 21., 31.}, {12., 22., 32.}}); GT_DEBUG_VAR(a.shape()); - gt::gtensor asum0(gt::shape(2)); - gt::gtensor asum1(gt::shape(3)); + gt::gtensor asum0(gt::shape(2)); + gt::gtensor asum1(gt::shape(3)); sum_axis_to(asum0, a, 0); - EXPECT_EQ(asum0, (gt::gtensor{63., 66.})); + EXPECT_EQ(asum0, (gt::gtensor{63., 66.})); sum_axis_to(asum1, a, 1); - EXPECT_EQ(asum1, (gt::gtensor{23., 43., 63.})); + EXPECT_EQ(asum1, (gt::gtensor{23., 43., 63.})); } TEST(reductions, sum_axis_to_2d_view) @@ -62,24 +89,6 @@ TEST(reductions, sum_axis_to_3d_view_2d) #ifdef GTENSOR_HAVE_DEVICE -TEST(reductions, device_sum_axis_to_2d) -{ - gt::gtensor_device a({{11., 21., 31.}, {12., 22., 32.}}); - GT_DEBUG_VAR(a.shape()); - - gt::gtensor_device asum0(gt::shape(2)); - gt::gtensor_device asum1(gt::shape(3)); - gt::gtensor h_asum0(gt::shape(2)); - gt::gtensor h_asum1(gt::shape(3)); - - sum_axis_to(asum0, a, 0); - gt::copy(asum0, h_asum0); - EXPECT_EQ(h_asum0, (gt::gtensor{63., 66.})); - sum_axis_to(asum1, a, 1); - gt::copy(asum1, h_asum1); - EXPECT_EQ(h_asum1, (gt::gtensor{23., 43., 63.})); -} - TEST(reductions, device_sum_axis_to_3d_view_2d) { gt::gtensor_device a({{{11., 21., 31.}, {12., 22., 32.}},