Skip to content

Commit a1f7f3a

Browse files
committed
[Python][RDF] Add tests
1 parent fd33615 commit a1f7f3a

File tree

1 file changed

+90
-0
lines changed

1 file changed

+90
-0
lines changed

bindings/pyroot/pythonizations/test/numbadeclare.py

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -633,5 +633,95 @@ def pass_reference(v):
633633
self.assertTrue(np.array_equal(rvecf, np.array([1.0, 4.0])))
634634

635635

636+
class NumbaDeclareInferred(unittest.TestCase):
637+
"""
638+
Test decorator created with a reconstructed list of arguments using RDF column types,
639+
and a return type inferred from the numba jitted function.
640+
"""
641+
642+
def test_fund_types(self):
643+
"""
644+
Test fundamental types
645+
"""
646+
df = ROOT.RDataFrame(4).Define("x", "rdfentry_")
647+
648+
with self.subTest("function"):
649+
def is_even(x):
650+
return x % 2 == 0
651+
df = df.Define("is_even_x_1", is_even, ["x"])
652+
results = df.Take["bool"]("is_even_x_1").GetValue()[0]
653+
self.assertEqual(results, True)
654+
655+
with self.subTest("lambda"):
656+
df = df.Define("is_even_x_2", lambda x: x % 2 == 0, ["x"])
657+
results = df.Take["bool"]("is_even_x_2").GetValue()[0]
658+
self.assertEqual(results, True)
659+
660+
def test_rvec(self):
661+
"""
662+
Test RVec
663+
"""
664+
df = ROOT.RDataFrame(4).Define("x", "ROOT::VecOps::RVec<int>({1, 2, 3})")
665+
666+
with self.subTest("function"):
667+
def square_rvec(v):
668+
return v*v
669+
df = df.Define("square_rvec_1", square_rvec, ["x"])
670+
results = df.Take["RVec<int>"]("square_rvec_1").GetValue()[0]
671+
self.assertTrue(np.array_equal(results, np.array([1, 4, 9])))
672+
673+
with self.subTest("lambda"):
674+
df = df.Define("square_rvec_2", lambda v: v*v, ["x"])
675+
results = df.Take["RVec<int>"]("square_rvec_2").GetValue()[0]
676+
self.assertTrue(np.array_equal(results, np.array([1, 4, 9])))
677+
678+
def test_std_vec(self):
679+
"""
680+
Test std::vector
681+
"""
682+
df = ROOT.RDataFrame(4).Define("x", "std::vector<int>({1, 2, 3})")
683+
684+
with self.subTest("function"):
685+
def square_std_vec(v):
686+
return v*v
687+
df = df.Define("square_std_vec_1", square_std_vec, ["x"])
688+
results = df.Take["RVec<int>"]("square_std_vec_1").GetValue()[0]
689+
self.assertTrue(np.array_equal(results, np.array([1, 4, 9])))
690+
691+
with self.subTest("lambda"):
692+
df = df.Define("square_std_vec_2", lambda v: v*v, ["x"])
693+
results = df.Take["RVec<int>"]("square_std_vec_2").GetValue()[0]
694+
self.assertTrue(np.array_equal(results, np.array([1, 4, 9])))
695+
696+
def test_std_array(self):
697+
"""
698+
Test std::array
699+
"""
700+
df = ROOT.RDataFrame(4).Define("x", "std::array<int, 3>({1, 2, 3})")
701+
702+
with self.subTest("function"):
703+
def square_std_arr(v):
704+
return v*v
705+
df = df.Define("square_std_arr_1", square_std_arr, ["x"])
706+
results = df.Take["RVec<int>"]("square_std_arr_1").GetValue()[0]
707+
self.assertTrue(np.array_equal(results, np.array([1, 4, 9])))
708+
709+
with self.subTest("lambda"):
710+
df = df.Define("square_std_arr_2", lambda v: v*v, ["x"])
711+
results = df.Take["RVec<int>"]("square_std_arr_2").GetValue()[0]
712+
self.assertTrue(np.array_equal(results, np.array([1, 4, 9])))
713+
714+
def test_missing_signature_raises(self):
715+
"""
716+
Ensure an Exception is raised when return type cannot be inferred
717+
and no explicit signature is provided in the decorator.
718+
"""
719+
def f(x):
720+
return x.M()
721+
722+
with self.assertRaises(Exception):
723+
ROOT.RDataFrame(4).Define("v", "ROOT::Math::PtEtaPhiMVector(1, 2, 3, 4)").Define("m", f, ["v"])
724+
725+
636726
if __name__ == "__main__":
637727
unittest.main()

0 commit comments

Comments
 (0)