Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 11 additions & 5 deletions bindings/pyroot/pythonizations/python/ROOT/_numbadeclare.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,12 +46,12 @@ def _NumbaDeclareDecorator(input_types, return_type=None, name=None):
"match_pattern": r"(?:ROOT::)?(?:VecOps::)?RVec\w+|(?:ROOT::)?(?:VecOps::)?RVec<[\w\s]+>",
"cpp_name": ["ROOT::RVec", "ROOT::VecOps::RVec"],
},
"std::vector": {
"match_pattern": r"std::vector<[\w\s]+>",
"vector": {
"match_pattern": r"(?:std::)?vector<[\w\s]+>",
"cpp_name": ["std::vector"],
},
"std::array": {
"match_pattern": r"std::array<[\w\s,<>]+>",
"array": {
"match_pattern": r"(?:std::)?array<[\w\s,<>]+>",
"cpp_name": ["std::array"],
},
}
Expand Down Expand Up @@ -233,7 +233,6 @@ def inner(func, input_types=input_types, return_type=return_type, name=name):
"""
Inner decorator without arguments, see outer decorator for documentation
"""

# Jit the given Python callable with numba
try:
nb_return_type, nb_input_types = get_numba_signature(input_types, return_type)
Expand All @@ -255,6 +254,13 @@ def inner(func, input_types=input_types, return_type=return_type, name=name):
"See https://cppyy.readthedocs.io/en/latest/numba.html#numba-support"
)
nbjit = nb.jit(nopython=True, inline="always")(func)
# In this case, the user has to explicitly provide the return type, since it cannot be inferred
if return_type is None:
raise RuntimeError(
"Failed to infer the return type for the provided function. "
"Please specify the signature explicitly in the decorator, e.g.: "
"@ROOT.NumbaDeclare(['double'], 'double')"
)
except: # noqa E722
raise Exception("Failed to jit Python callable {} with numba.jit".format(func))
func.numba_func = nbjit
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
function_cache = {}
lambda_function_counter = 0 # Counter to name the lambda functions

def __init__(self, rdf: "RDataFrame") -> None:

Check failure on line 48 in bindings/pyroot/pythonizations/python/ROOT/_pythonization/_rdf_pyz.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (F821)

bindings/pyroot/pythonizations/python/ROOT/_pythonization/_rdf_pyz.py:48:30: F821 Undefined name `RDataFrame`
self.rdf = rdf
self.col_names: typing.List[str] = rdf.GetColumnNames()
self.func: typing.Callable
Expand All @@ -71,7 +71,7 @@
"""
try:
import numpy as np
except:

Check failure on line 74 in bindings/pyroot/pythonizations/python/ROOT/_pythonization/_rdf_pyz.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (E722)

bindings/pyroot/pythonizations/python/ROOT/_pythonization/_rdf_pyz.py:74:9: E722 Do not use bare `except`
raise ImportError("Failed to import numpy during call to determine function signature.")
from ._rdf_conversion_maps import FUNDAMENTAL_PYTHON_TYPES, NUMPY_TO_TREE, TREE_TO_NUMBA

Expand All @@ -81,17 +81,14 @@
t = self.rdf.GetColumnType(x)
if t in TREE_TO_NUMBA: # The column is a fundamental type from tree
return TREE_TO_NUMBA[t]
elif "<" in t: # The column type is a RVec<type>
if ">>" in t: # It is a RVec<RVec<T>>
raise TypeError(
f"Only columns with 'RVec<T>' where T is is a fundamental type are supported, not '{t}'."
)
g = re.match("(.*)<(.*)>", t).groups(0)
if g[1] in TREE_TO_NUMBA:
return "RVec<" + TREE_TO_NUMBA[g[1]] + ">"
# There are data type that leak into here. Not sure from where. But need to implement something here such that this condition is never met.
return "RVec<" + str(g[1]) + ">"

match = re.match(r"([\w:]+)<(.+)>", t)
if match:
container_type, inner_type = match.groups()
container_type = container_type.strip()
inner_type = inner_type.strip()
inner_mapped = TREE_TO_NUMBA.get(inner_type, inner_type)
return f"{container_type}<{inner_mapped}>"
else:
return t
else:
Expand Down Expand Up @@ -173,7 +170,7 @@
value_of_p = func_args[p]
type_of_p = self.find_type(value_of_p)
# Bool(s) in python are represented as True/False but in C++ are true/false. The following if statements are to account for that
if type(value_of_p) == bool:

Check failure on line 173 in bindings/pyroot/pythonizations/python/ROOT/_pythonization/_rdf_pyz.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (E721)

bindings/pyroot/pythonizations/python/ROOT/_pythonization/_rdf_pyz.py:173:20: E721 Use `is` and `is not` for type comparisons, or `isinstance()` for isinstance checks
if value_of_p:
value_of_p = "true"
else:
Expand Down Expand Up @@ -377,9 +374,9 @@

import cppyy

is_cpp_functor = lambda: isinstance(getattr(func, "__call__", None), cppyy._backend.CPPOverload)

Check failure on line 377 in bindings/pyroot/pythonizations/python/ROOT/_pythonization/_rdf_pyz.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (E731)

bindings/pyroot/pythonizations/python/ROOT/_pythonization/_rdf_pyz.py:377:5: E731 Do not assign a `lambda` expression, use a `def`

is_std_function = lambda: isinstance(getattr(func, "target_type", None), cppyy._backend.CPPOverload)

Check failure on line 379 in bindings/pyroot/pythonizations/python/ROOT/_pythonization/_rdf_pyz.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (E731)

bindings/pyroot/pythonizations/python/ROOT/_pythonization/_rdf_pyz.py:379:5: E731 Do not assign a `lambda` expression, use a `def`

# handle free functions
if callable(func) and not is_cpp_functor() and not is_std_function():
Expand Down
90 changes: 90 additions & 0 deletions bindings/pyroot/pythonizations/test/numbadeclare.py
Original file line number Diff line number Diff line change
Expand Up @@ -633,5 +633,95 @@ def pass_reference(v):
self.assertTrue(np.array_equal(rvecf, np.array([1.0, 4.0])))


class NumbaDeclareInferred(unittest.TestCase):
    """
    Test decorator created with a reconstructed list of arguments using RDF column types,
    and a return type inferred from the numba jitted function.

    Every test passes a plain Python callable (a named function or a lambda)
    together with a list of input column names to ``Define``; no explicit
    signature is given at the call site. Each test then reads the new column
    back with ``Take`` and checks the value of the first entry only.
    """

    def test_fund_types(self):
        """
        Test fundamental types

        The input column ``x`` is defined from ``rdfentry_`` and the callable
        returns ``x % 2 == 0``; the result is read back as ``bool``.
        """
        df = ROOT.RDataFrame(4).Define("x", "rdfentry_")

        # Named Python function used as the Define expression.
        with self.subTest("function"):
            def is_even(x):
                return x % 2 == 0
            df = df.Define("is_even_x_1", is_even, ["x"])
            # Only the first entry is inspected.
            # NOTE(review): presumably rdfentry_ is 0 for the first entry, hence True - confirm.
            results = df.Take["bool"]("is_even_x_1").GetValue()[0]
            self.assertEqual(results, True)

        # Same computation, passed as a lambda. The column is added on top of
        # whatever `df` currently refers to (it is rebound in the block above).
        with self.subTest("lambda"):
            df = df.Define("is_even_x_2", lambda x: x % 2 == 0, ["x"])
            results = df.Take["bool"]("is_even_x_2").GetValue()[0]
            self.assertEqual(results, True)

    def test_rvec(self):
        """
        Test RVec

        The input column holds ``ROOT::VecOps::RVec<int>({1, 2, 3})`` and the
        callable multiplies the vector by itself; the expected result for the
        first entry is ``[1, 4, 9]``.
        """
        df = ROOT.RDataFrame(4).Define("x", "ROOT::VecOps::RVec<int>({1, 2, 3})")

        # Named Python function used as the Define expression.
        with self.subTest("function"):
            def square_rvec(v):
                return v*v
            df = df.Define("square_rvec_1", square_rvec, ["x"])
            results = df.Take["RVec<int>"]("square_rvec_1").GetValue()[0]
            self.assertTrue(np.array_equal(results, np.array([1, 4, 9])))

        # Same computation, passed as a lambda.
        with self.subTest("lambda"):
            df = df.Define("square_rvec_2", lambda v: v*v, ["x"])
            results = df.Take["RVec<int>"]("square_rvec_2").GetValue()[0]
            self.assertTrue(np.array_equal(results, np.array([1, 4, 9])))

    def test_std_vec(self):
        """
        Test std::vector

        The input column holds ``std::vector<int>({1, 2, 3})``. Note that the
        output column is read back as ``RVec<int>``, not as ``std::vector``.
        """
        df = ROOT.RDataFrame(4).Define("x", "std::vector<int>({1, 2, 3})")

        # Named Python function used as the Define expression.
        with self.subTest("function"):
            def square_std_vec(v):
                return v*v
            df = df.Define("square_std_vec_1", square_std_vec, ["x"])
            results = df.Take["RVec<int>"]("square_std_vec_1").GetValue()[0]
            self.assertTrue(np.array_equal(results, np.array([1, 4, 9])))

        # Same computation, passed as a lambda.
        with self.subTest("lambda"):
            df = df.Define("square_std_vec_2", lambda v: v*v, ["x"])
            results = df.Take["RVec<int>"]("square_std_vec_2").GetValue()[0]
            self.assertTrue(np.array_equal(results, np.array([1, 4, 9])))

    def test_std_array(self):
        """
        Test std::array

        The input column holds ``std::array<int, 3>({1, 2, 3})``. As in the
        std::vector test, the output column is read back as ``RVec<int>``.
        """
        df = ROOT.RDataFrame(4).Define("x", "std::array<int, 3>({1, 2, 3})")

        # Named Python function used as the Define expression.
        with self.subTest("function"):
            def square_std_arr(v):
                return v*v
            df = df.Define("square_std_arr_1", square_std_arr, ["x"])
            results = df.Take["RVec<int>"]("square_std_arr_1").GetValue()[0]
            self.assertTrue(np.array_equal(results, np.array([1, 4, 9])))

        # Same computation, passed as a lambda.
        with self.subTest("lambda"):
            df = df.Define("square_std_arr_2", lambda v: v*v, ["x"])
            results = df.Take["RVec<int>"]("square_std_arr_2").GetValue()[0]
            self.assertTrue(np.array_equal(results, np.array([1, 4, 9])))

    def test_missing_signature_raises(self):
        """
        Ensure an Exception is raised when return type cannot be inferred
        and no explicit signature is provided in the decorator.

        The input column is a ``ROOT::Math::PtEtaPhiMVector`` and the callable
        calls ``M()`` on it. The assertion only requires that some
        ``Exception`` is raised by the chained ``Define`` calls; it does not
        pin the exception type or message.
        """
        def f(x):
            return x.M()

        # The whole chain is built inside the context manager, so an exception
        # from either Define call satisfies the assertion.
        with self.assertRaises(Exception):
            ROOT.RDataFrame(4).Define("v", "ROOT::Math::PtEtaPhiMVector(1, 2, 3, 4)").Define("m", f, ["v"])


# Run the unittest test runner when this file is executed as a script.
if __name__ == "__main__":
    unittest.main()
Loading