@@ -633,5 +633,95 @@ def pass_reference(v):
633
633
self .assertTrue (np .array_equal (rvecf , np .array ([1.0 , 4.0 ])))
634
634
635
635
636
+ class NumbaDeclareInferred (unittest .TestCase ):
637
+ """
638
+ Test decorator created with a reconstructed list of arguments using RDF column types,
639
+ and a return type inferred from the numba jitted function.
640
+ """
641
+
642
+ def test_fund_types (self ):
643
+ """
644
+ Test fundamental types
645
+ """
646
+ df = ROOT .RDataFrame (4 ).Define ("x" , "rdfentry_" )
647
+
648
+ with self .subTest ("function" ):
649
+ def is_even (x ):
650
+ return x % 2 == 0
651
+ df = df .Define ("is_even_x_1" , is_even , ["x" ])
652
+ results = df .Take ["bool" ]("is_even_x_1" ).GetValue ()[0 ]
653
+ self .assertEqual (results , True )
654
+
655
+ with self .subTest ("lambda" ):
656
+ df = df .Define ("is_even_x_2" , lambda x : x % 2 == 0 , ["x" ])
657
+ results = df .Take ["bool" ]("is_even_x_2" ).GetValue ()[0 ]
658
+ self .assertEqual (results , True )
659
+
660
+ def test_rvec (self ):
661
+ """
662
+ Test RVec
663
+ """
664
+ df = ROOT .RDataFrame (4 ).Define ("x" , "ROOT::VecOps::RVec<int>({1, 2, 3})" )
665
+
666
+ with self .subTest ("function" ):
667
+ def square_rvec (v ):
668
+ return v * v
669
+ df = df .Define ("square_rvec_1" , square_rvec , ["x" ])
670
+ results = df .Take ["RVec<int>" ]("square_rvec_1" ).GetValue ()[0 ]
671
+ self .assertTrue (np .array_equal (results , np .array ([1 , 4 , 9 ])))
672
+
673
+ with self .subTest ("lambda" ):
674
+ df = df .Define ("square_rvec_2" , lambda v : v * v , ["x" ])
675
+ results = df .Take ["RVec<int>" ]("square_rvec_2" ).GetValue ()[0 ]
676
+ self .assertTrue (np .array_equal (results , np .array ([1 , 4 , 9 ])))
677
+
678
+ def test_std_vec (self ):
679
+ """
680
+ Test std::vector
681
+ """
682
+ df = ROOT .RDataFrame (4 ).Define ("x" , "std::vector<int>({1, 2, 3})" )
683
+
684
+ with self .subTest ("function" ):
685
+ def square_std_vec (v ):
686
+ return v * v
687
+ df = df .Define ("square_std_vec_1" , square_std_vec , ["x" ])
688
+ results = df .Take ["RVec<int>" ]("square_std_vec_1" ).GetValue ()[0 ]
689
+ self .assertTrue (np .array_equal (results , np .array ([1 , 4 , 9 ])))
690
+
691
+ with self .subTest ("lambda" ):
692
+ df = df .Define ("square_std_vec_2" , lambda v : v * v , ["x" ])
693
+ results = df .Take ["RVec<int>" ]("square_std_vec_2" ).GetValue ()[0 ]
694
+ self .assertTrue (np .array_equal (results , np .array ([1 , 4 , 9 ])))
695
+
696
+ def test_std_array (self ):
697
+ """
698
+ Test std::array
699
+ """
700
+ df = ROOT .RDataFrame (4 ).Define ("x" , "std::array<int, 3>({1, 2, 3})" )
701
+
702
+ with self .subTest ("function" ):
703
+ def square_std_arr (v ):
704
+ return v * v
705
+ df = df .Define ("square_std_arr_1" , square_std_arr , ["x" ])
706
+ results = df .Take ["RVec<int>" ]("square_std_arr_1" ).GetValue ()[0 ]
707
+ self .assertTrue (np .array_equal (results , np .array ([1 , 4 , 9 ])))
708
+
709
+ with self .subTest ("lambda" ):
710
+ df = df .Define ("square_std_arr_2" , lambda v : v * v , ["x" ])
711
+ results = df .Take ["RVec<int>" ]("square_std_arr_2" ).GetValue ()[0 ]
712
+ self .assertTrue (np .array_equal (results , np .array ([1 , 4 , 9 ])))
713
+
714
+ def test_missing_signature_raises (self ):
715
+ """
716
+ Ensure an Exception is raised when return type cannot be inferred
717
+ and no explicit signature is provided in the decorator.
718
+ """
719
+ def f (x ):
720
+ return x .M ()
721
+
722
+ with self .assertRaises (Exception ):
723
+ ROOT .RDataFrame (4 ).Define ("v" , "ROOT::Math::PtEtaPhiMVector(1, 2, 3, 4)" ).Define ("m" , f , ["v" ])
724
+
725
+
636
726
if __name__ == "__main__" :
637
727
unittest .main ()
0 commit comments