diff --git a/PyRDF/CallableGenerator.py b/PyRDF/CallableGenerator.py index 56e7f65..dfc66c8 100644 --- a/PyRDF/CallableGenerator.py +++ b/PyRDF/CallableGenerator.py @@ -33,7 +33,7 @@ def get_action_nodes(self, node_py=None): # current PyRDF node as the head node node_py = self.head_node else: - if node_py.operation.is_action(): + if node_py.operation.is_action() or node_py.operation.is_info(): # Collect all action nodes in order to return them return_nodes.append(node_py) @@ -96,7 +96,7 @@ def mapper(node_cpp, node_py=None): # recursive call parent_node = pyroot_node - if node_py.operation.is_action(): + if node_py.operation.is_action() or node_py.operation.is_info(): # Collect all action nodes in order to return them return_vals.append(pyroot_node) diff --git a/PyRDF/Node.py b/PyRDF/Node.py index 9e0aefe..f7109d5 100644 --- a/PyRDF/Node.py +++ b/PyRDF/Node.py @@ -62,6 +62,7 @@ def __init__(self, get_head, operation, *args): self.children = [] self._new_op_name = "" self.value = None + self.ResultPtr = None self.pyroot_node = None self.has_user_references = True @@ -101,7 +102,29 @@ def __setstate__(self, state): else: self.operation = None - def is_prunable(self): + def is_prunable_action(self): + """ + Checks if an action node can be pruned. Action nodes whose value was + already computed can be pruned. + + Returns: + bool: True if the node is an action and its value was already + computed, False otherwise. + """ + return self.operation and self.operation.is_action() and self.value + + def is_prunable_info(self): + """ + Checks if an info node can be pruned. Info nodes whose ResultPtr was + already assigned can be pruned. + + Returns: + bool: True if the node is an action and its value was already + computed, False otherwise. + """ + return self.operation and self.operation.is_info() and self.ResultPtr + + def is_prunable_node(self): """ Checks whether the current node can be pruned from the computational graph. @@ -113,7 +136,7 @@ def is_prunable(self): if not self.children: # Every pruning condition is written on a separate line if not self.has_user_references or \ - (self.operation and self.operation.is_action() and self.value): + self.is_prunable_action() or self.is_prunable_info(): # ***** Condition 1 ***** # If the node is wrapped by a proxy which is not directly @@ -123,6 +146,11 @@ def is_prunable(self): # If the current node's value was already # computed, it should get pruned only if it's # an Action node. + + # ***** Condition 3 ***** + # If the current node's value was already + # computed, it should get pruned only if it's + # an Info operation. return True return False @@ -144,4 +172,4 @@ def graph_prune(self): children.append(n) self.children = children - return self.is_prunable() + return self.is_prunable_node() diff --git a/PyRDF/Operation.py b/PyRDF/Operation.py index 69068e8..ddedcc7 100644 --- a/PyRDF/Operation.py +++ b/PyRDF/Operation.py @@ -32,7 +32,7 @@ class Operation(object): print(PyRDF.current_backend.supported_operations) """ - Types = Enum("Types", "ACTION TRANSFORMATION INSTANT_ACTION") + Types = Enum("Types", "ACTION TRANSFORMATION INSTANT_ACTION INFO") def __init__(self, name, *args, **kwargs): """ @@ -80,7 +80,11 @@ def _classify_operation(self, name): 'Take': ops.ACTION, 'Graph': ops.ACTION, 'Snapshot': ops.INSTANT_ACTION, - 'Foreach': ops.INSTANT_ACTION + 'Foreach': ops.INSTANT_ACTION, + 'GetColumnNames': ops.INFO, + 'GetDefinedColumnNames': ops.INFO, + 'GetColumnType': ops.INFO, + 'GetFilterNames': ops.INFO, } op_type = operations_dict.get(name) @@ -107,3 +111,13 @@ def is_transformation(self): False otherwise. """ return self.op_type == Operation.Types.TRANSFORMATION + + def is_info(self): + """ + Checks if the current operation is an info operation. + + Returns: + bool: True if the current operation is a transformation, + False otherwise. + """ + return self.op_type == Operation.Types.INFO diff --git a/PyRDF/Proxy.py b/PyRDF/Proxy.py index c63c6bc..67ebb1a 100644 --- a/PyRDF/Proxy.py +++ b/PyRDF/Proxy.py @@ -77,7 +77,7 @@ def GetValue(self): from PyRDF import current_backend if not self.proxied_node.value: # If event-loop not triggered generator = CallableGenerator(self.proxied_node.get_head()) - current_backend.execute(generator) + current_backend.execute(generator, trigger_loop=True) return self.proxied_node.value @@ -123,10 +123,10 @@ def _create_new_op(self, *args, **kwargs): Handles an operation call to the current node and returns the new node built using the operation call. """ + from PyRDF import current_backend # Create a new `Operation` object for the # incoming operation call op = Operation(self.proxied_node._new_op_name, *args, **kwargs) - # Create a new `Node` object to house the operation newNode = Node(operation=op, get_head=self.proxied_node.get_head) @@ -136,5 +136,13 @@ def _create_new_op(self, *args, **kwargs): # Return the appropriate proxy object for the node if op.is_action(): return ActionProxy(newNode) - else: + elif op.is_transformation(): return TransformationProxy(newNode) + else: + try: + generator = CallableGenerator(self.proxied_node.get_head()) + current_backend.execute(generator, trigger_loop=False) + except TypeError as e: + self.proxied_node.children.remove(newNode) + raise e + return newNode.ResultPtr diff --git a/PyRDF/backend/Backend.py b/PyRDF/backend/Backend.py index 69ff786..b3e015b 100644 --- a/PyRDF/backend/Backend.py +++ b/PyRDF/backend/Backend.py @@ -40,6 +40,10 @@ class Backend(ABC): 'Foreach', 'Reduce', 'Aggregate', + 'GetColumnNames', + 'GetDefinedColumnNames', + 'GetColumnType', + 'GetFilterNames', 'Graph' ] @@ -93,7 +97,7 @@ def check_supported(self, operation_name): ) @abstractmethod - def execute(self, generator): + def execute(self, generator, trigger_loop=False): """ Subclasses must define how to run the RDataFrame graph on a given environment. diff --git a/PyRDF/backend/Dist.py b/PyRDF/backend/Dist.py index 99f10a2..decd393 100644 --- a/PyRDF/backend/Dist.py +++ b/PyRDF/backend/Dist.py @@ -353,7 +353,7 @@ def _get_friend_info(self, tree): return FriendInfo(friend_names, friend_file_names) - def execute(self, generator): + def execute(self, generator, trigger_loop=True): """ Executes the current RDataFrame graph in the given distributed environment. @@ -526,7 +526,7 @@ def reducer(values_list1, values_list2): warnings.warn(msg, UserWarning, stacklevel=2) PyRDF.use("local") from .. import current_backend - return current_backend.execute(generator) + return current_backend.execute(generator, trigger_loop=True) # Values produced after Map-Reduce values = self.ProcessAndMerge(mapper, reducer) diff --git a/PyRDF/backend/Local.py b/PyRDF/backend/Local.py index 04d0b58..80cd816 100644 --- a/PyRDF/backend/Local.py +++ b/PyRDF/backend/Local.py @@ -38,7 +38,7 @@ def __init__(self, config={}): if op not in operations_not_supported] self.pyroot_rdf = None - def execute(self, generator): + def execute(self, generator, trigger_loop=False): """ Executes locally the current RDataFrame graph. @@ -52,8 +52,7 @@ def execute(self, generator): # if the RDataFrame has not been created yet or if a new one # is created by the user in the same session - if (not self.pyroot_rdf) or \ - (self.pyroot_rdf is not generator.head_node): + if not self.pyroot_rdf or self.pyroot_rdf is not generator.head_node: self.pyroot_rdf = ROOT.ROOT.RDataFrame(*generator.head_node.args) values = mapper(self.pyroot_rdf) # Execute the mapper function @@ -61,13 +60,13 @@ def execute(self, generator): # Get the action nodes in the same order as values nodes = generator.get_action_nodes() - values[0].GetValue() # Trigger event-loop - - for i in range(len(values)): + for node, value in zip(nodes, values): # Set the obtained values and # 'RResultPtr's of action nodes - nodes[i].value = values[i].GetValue() + if trigger_loop and hasattr(value, 'GetValue'): + # Info actions do not have GetValue + node.value = value.GetValue() # We store the 'RResultPtr's because, # those should be in scope while doing # a 'GetValue' call on them - nodes[i].ResultPtr = values[i] + node.ResultPtr = value diff --git a/tests/integration/local/test_info_operations.py b/tests/integration/local/test_info_operations.py new file mode 100644 index 0000000..a26825b --- /dev/null +++ b/tests/integration/local/test_info_operations.py @@ -0,0 +1,69 @@ +import unittest +import PyRDF +import ROOT + + +class InfoOperationsLocalTest(unittest.TestCase): + """ + Check that Info operations return the expected result rather than a proxy. + """ + + def test_GetColumnNames(self): + """ + GetColumnNames returns ROOT string vector without running the event + loop. + """ + rdf = PyRDF.RDataFrame(1) + d = rdf.Define('a', 'rdfentry_').Define('b', 'a*a') + + column_names = d.GetColumnNames() + expected_columns = ROOT.std.vector('string')() + expected_columns.push_back("a") + expected_columns.push_back("b") + + for column, expected in zip(column_names, expected_columns): + self.assertEqual(column, expected) + + def test_GetColumnType(self): + """ + GetColumnType returns the type of a given column as a string. + """ + rdf = PyRDF.RDataFrame(1) + d = rdf.Define('a', 'rdfentry_').Define('b', 'a*a') + + a_typename = d.GetColumnType('a') + b_typename = d.GetColumnType('b') + expected_type = 'ULong64_t' + + self.assertEqual(a_typename, expected_type) + self.assertEqual(b_typename, expected_type) + + def test_GetDefinedColumnNames(self): + """ + GetDefinedColumnNames returns the names of the defined columns. + """ + rdf = PyRDF.RDataFrame(1) + d = rdf.Define('a', 'rdfentry_').Define('b', 'a*a') + + column_names = d.GetColumnNames() + expected_columns = ROOT.std.vector('string')() + expected_columns.push_back("a") + expected_columns.push_back("b") + + for column, expected in zip(column_names, expected_columns): + self.assertEqual(column, expected) + + def test_GetFilterNames(self): + """ + GetFilterNames returns the names of the filters created. + """ + rdf = PyRDF.RDataFrame(1) + filter_name = 'custom_filter' + d = rdf.Filter('rdfentry_ > 1', filter_name) + + filters = d.GetFilterNames() + expected_filters = ROOT.std.vector('string')() + expected_filters.push_back(filter_name) + + for f, expected in zip(filters, expected_filters): + self.assertEqual(f, expected) diff --git a/tests/unit/test_operation.py b/tests/unit/test_operation.py index e658181..bbadcbf 100644 --- a/tests/unit/test_operation.py +++ b/tests/unit/test_operation.py @@ -20,6 +20,11 @@ def test_transformation(self): op = Operation("Define", "c1") self.assertEqual(op.op_type, Operation.Types.TRANSFORMATION) + def test_info_action(self): + """Info nodes are classified accurately.""" + op = Operation("GetColumnNames") + self.assertEqual(op.op_type, Operation.Types.INFO) + def test_none(self): """Incorrect operations raise an Exception.""" with self.assertRaises(Exception): diff --git a/tests/unit/test_proxy.py b/tests/unit/test_proxy.py index edd9e14..c997caf 100644 --- a/tests/unit/test_proxy.py +++ b/tests/unit/test_proxy.py @@ -94,6 +94,7 @@ def test_node_attr_transformation(self): "children", "_new_op_name", "value", + "ResultPtr", "pyroot_node", "has_user_references" ] @@ -145,7 +146,7 @@ class TestBackend(Backend): Test backend to verify the working of 'GetValue' instance method in Proxy. """ - def execute(self, generator): + def execute(self, generator, trigger_loop): """ Test implementation of the execute method for 'TestBackend'. This records the head