8 changes: 8 additions & 0 deletions PyRDF/backend/Local.py
@@ -82,3 +82,11 @@ def execute(self, generator):
# those should be in scope while doing
# a 'GetValue' call on them
nodes[i].ResultPtr = values[i]

@staticmethod
def RunGraphs(proxies, concurrent_runs=4):
"""
Trigger the execution of multiple RDataFrame computation graphs. Not
implemented in the Local backend.
"""
raise NotImplementedError
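
For context, a minimal sketch of how the Local backend is still served by the generic entry point added in PyRDF/backend/Utils.py further down: the NotImplementedError raised here makes Utils.RunGraphs fall back to triggering each graph sequentially. The snippet mirrors the unit test added in this PR and assumes a working PyRDF installation:

    import PyRDF
    from PyRDF.backend.Utils import Utils

    PyRDF.use("local")

    # Book one lazy Count action per dataframe; nothing runs yet
    counts = [PyRDF.RDataFrame(10).Count() for _ in range(4)]

    # Local.RunGraphs raises NotImplementedError, so Utils.RunGraphs falls
    # back to calling GetValue() on each proxy one after the other
    Utils.RunGraphs(counts)

    print([count.GetValue() for count in counts])  # each Count resolves to 10
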
50 changes: 47 additions & 3 deletions PyRDF/backend/Spark.py
@@ -1,9 +1,16 @@
from __future__ import print_function

import ntpath # Filename from path (should be platform-independent)
import threading

from PyRDF.backend.Dist import Dist
from PyRDF.backend.Utils import Utils
from pyspark import SparkConf, SparkContext
from pyspark import SparkFiles
import ntpath # Filename from path (should be platform-independent)
from pyspark import SparkConf, SparkContext, SparkFiles

try:
import queue
except ImportError:
import Queue as queue


class Spark(Dist):
@@ -112,6 +119,43 @@ def spark_mapper(current_range):
# Map-Reduce using Spark
return parallel_collection.map(spark_mapper).treeReduce(reducer)

@staticmethod
def RunGraphs(proxies, numthreads=4):
"""
Trigger the execution of multiple RDataFrame computation graphs
concurrently, submitting each graph from a separate thread, as suggested
by the Spark docs on `job scheduling <https://spark.apache.org/docs/latest/job-scheduling.html#scheduling-within-an-application>`_.

Args:
proxies(iterable): Action proxies that should be triggered. Only
actions belonging to different RDataFrame graphs will be
triggered to avoid useless calls.

numthreads(int, optional): Number of threads to spawn at the same
time. Each thread will submit a separate job to the Spark
cluster through the same SparkContext. Defaults to 4.
"""

# Create queue to store all the action proxies
q = queue.Queue()

for proxy in proxies:
q.put(proxy)

# Function to trigger the computation graph of each proxy in the queue
def trigger_loop(queue_):
while True:
queue_.get().GetValue()
queue_.task_done()

# Create `numthreads` threads that will each submit a Spark job
for _ in range(numthreads):
worker = threading.Thread(target=trigger_loop, args=(q,))
# Daemonize via the attribute: the `daemon` constructor keyword is
# Python 3 only, while the queue import above still supports Python 2
worker.daemon = True
worker.start()

# Start the execution and wait for all computations to finish
q.join()
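
A practical note to go with the job-scheduling link above (not part of this patch): under Spark's default FIFO scheduler the first submitted job gets priority on cluster resources, while the FAIR scheduler lets the jobs spawned by these threads share executors more evenly. Below is a minimal sketch of how a user could opt in; that the Spark backend picks up a pre-existing context through SparkContext.getOrCreate is an assumption here, not something this diff shows:

    import PyRDF
    from pyspark import SparkConf, SparkContext

    # "spark.scheduler.mode" is a standard Spark property; FAIR lets jobs
    # submitted from different threads share executors instead of queueing
    # strictly behind each other
    conf = SparkConf().set("spark.scheduler.mode", "FAIR")
    sc = SparkContext.getOrCreate(conf)

    # Assumption: PyRDF's Spark backend reuses the already-running context,
    # so jobs submitted by RunGraphs inherit the FAIR setting
    PyRDF.use("spark")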

def distribute_files(self, includes_list):
"""
Spark supports sending files to the executors via the
58 changes: 55 additions & 3 deletions PyRDF/backend/Utils.py
@@ -1,6 +1,15 @@
import ROOT
import os
import logging
import multiprocessing
import os
import threading

try:
import queue
except ImportError:
import Queue as queue

import PyRDF
import ROOT

logger = logging.getLogger(__name__)

@@ -25,7 +34,7 @@ def extend_include_path(cls, include_path):

# Retrieve ROOT internal list of include paths and add debug statement
root_includepath = ROOT.gInterpreter.GetIncludePath()
logger.debug("ROOT include paths:\n{}".format(root_includepath))
logger.debug("ROOT include paths:\n%s", root_includepath)

@classmethod
def declare_headers(cls, headers_to_include):
@@ -69,3 +78,46 @@ def declare_shared_libraries(cls, libraries_to_include):
if not os.path.exists(shared_library):
raise IOError("Shared library does not exist!")
raise Exception("ROOT couldn't load the shared library!")

@classmethod
def RunGraphs(cls, proxies, concurrent_runs=4):
"""
Trigger the execution of multiple RDataFrame computation graphs on the
backend currently in use. If that backend doesn't support concurrent job
submission, the computation graphs are executed sequentially.

Args:
proxies(list): List of action proxies that should be triggered. Only
actions belonging to different RDataFrame graphs will be
triggered to avoid useless calls.

concurrent_runs(int, optional): Number of graphs that will be
executed concurrently in a distributed backend. Defaults to 4.

Example::

# Create four different dataframes and book a histogram on each one
histoproxies = [
PyRDF.RDataFrame(100)
.Define("x", "rdfentry_")
.Histo1D(("name", "title", 10, 0, 100), "x")
for _ in range(4)
]

# Execute the four computation graphs
PyRDF.backend.Utils.Utils.RunGraphs(histoproxies)

# Retrieve all the histograms in one go
histos = [histoproxy.GetValue() for histoproxy in histoproxies]
"""

# Get proxies belonging to distinct computation graphs
uniquegraphs = {proxy.proxied_node.get_head(): proxy
for proxy in proxies}.values()

try:
PyRDF.current_backend.RunGraphs(uniquegraphs, concurrent_runs)
except NotImplementedError:
for proxy in uniquegraphs:
proxy.GetValue()
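
A small illustration (hypothetical variable names, Local backend assumed) of why the deduplication by head node matters: actions booked on the same RDataFrame share one computation graph, so triggering any one of them fills all of them, and the redundant proxies can be dropped from the run list:

    import PyRDF
    from PyRDF.backend.Utils import Utils

    PyRDF.use("local")

    df = PyRDF.RDataFrame(100).Define("x", "rdfentry_")
    total = df.Sum("x")                    # action on graph A
    entries = df.Count()                   # another action on the same graph A
    other = PyRDF.RDataFrame(50).Count()   # action on a separate graph B

    # Only one proxy per distinct head node survives the dict above, so two
    # graphs (A and B) are triggered instead of three separate event loops
    Utils.RunGraphs([total, entries, other])
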
75 changes: 74 additions & 1 deletion tests/unit/backend/test_common.py
@@ -1,6 +1,10 @@
import array
import os
import unittest
import ROOT

import PyRDF
import pyspark
import ROOT
from PyRDF.backend.Utils import Utils


@@ -156,3 +160,72 @@ def defineIntVariable(name, value):
varvalue = 2
PyRDF.initialize(defineIntVariable, "myInt", varvalue)
self.assertEqual(ROOT.myInt, varvalue)


class RunGraphsTest(unittest.TestCase):
"""Tests for the concurrent submission of distributed jobs in PyRDF"""

def tearDown(self):
"""Clean up the `SparkContext` object that was created."""
pyspark.SparkContext.getOrCreate().stop()

def ttree_write(self, treename, filename, mean, std_dev):
"""Create a TTree and write it to file."""
f = ROOT.TFile(filename, "recreate")
t = ROOT.TTree(treename, "ConcurrentSparkJobsTest")

x = array.array("f", [0])
t.Branch("x", x, "x/F")

r = ROOT.TRandom()
# Fill the branch with a gaussian distribution
for _ in range(10000):
x[0] = r.Gaus(mean, std_dev)
t.Fill()

f.Write()
f.Close()

def test_rungraphs_local(self):
"""Test RunGraphs with Local backend"""
PyRDF.use("local")

counts = [PyRDF.RDataFrame(10).Count() for _ in range(4)]

Utils.RunGraphs(counts)

for count in counts:
self.assertEqual(count.GetValue(), 10)

def test_rungraphs_spark_3histos(self):
"""
Create three datasets to run some simple analysis on, then submit them
concurrently as Spark jobs from different threads.
"""
PyRDF.use("spark")

treenames = ["tree{}".format(i) for i in range(1, 4)]
filenames = ["file{}.root".format(i) for i in range(1, 4)]
means = [10, 20, 30]
std_devs = [1, 2, 3]

for treename, filename, mean, std_dev in zip(
treenames, filenames, means, std_devs):
self.ttree_write(treename, filename, mean, std_dev)

histoproxies = [PyRDF.RDataFrame(treename, filename)
.Histo1D(("x", "x", 100, 0, 50), "x")
for treename, filename in zip(treenames, filenames)]

Utils.RunGraphs(histoproxies)

delta_equal = 0.1

for histo, mean, std_dev in zip(histoproxies, means, std_devs):
self.assertEqual(histo.GetEntries(), 10000)
self.assertAlmostEqual(histo.GetMean(), mean, delta=delta_equal)
self.assertAlmostEqual(
histo.GetStdDev(), std_dev, delta=delta_equal)

for filename in filenames:
os.remove(filename)
9 changes: 9 additions & 0 deletions tests/unit/backend/test_local.py
@@ -88,3 +88,12 @@ def init(value):
df = PyRDF.RDataFrame(1)
s = df.Define("userValue", "getUserValue()").Sum("userValue")
self.assertEqual(s.GetValue(), 123)


class MiscTest(unittest.TestCase):
"""Miscellaneous tests for Local backend."""

def test_rungraphs_notimplemented(self):
"""Check that RunGraphs is not implemented"""
with self.assertRaises(NotImplementedError):
Local.RunGraphs([PyRDF.RDataFrame(10).Count() for _ in range(4)])
6 changes: 5 additions & 1 deletion tests/unit/backend/test_spark.py
@@ -1,7 +1,11 @@
import array
import os
import unittest

import PyRDF
from PyRDF.backend.Spark import Spark
import ROOT
from PyRDF.backend.Local import Local
from PyRDF.backend.Spark import Spark
from pyspark import SparkContext


Expand Down