
Commit 0054da5

yaxan and Ubuntu authored

centml.compile + performance prediction backend (#70)

* Implements centml.compile module for remote compilation or prediction
* New centml_prediction_backend for predicting inference time and exporting to Prometheus
* New class for profiling / stepping through a graph module
* Adds scripts for prediction data collection and a sample script for running the prediction workflow
* Adds sample prediction data for A10G and A100

Co-authored-by: Ubuntu <ubuntu@ip-172-31-46-81.us-east-2.compute.internal>
1 parent 80cc0b3 commit 0054da5

File tree

12 files changed: +775 additions, 0 deletions


centml/__init__.py

Lines changed: 3 additions & 0 deletions
@@ -0,0 +1,3 @@
+from .compile import compile
+
+__all__ = ["compile"]

centml/compile.py

Lines changed: 56 additions & 0 deletions
@@ -0,0 +1,56 @@
+import builtins
+from typing import Callable, Dict, Optional, Union
+
+import torch
+
+from centml.compiler.backend import centml_dynamo_backend
+from centml.compiler.config import OperationMode, settings
+from centml.compiler.prediction.backend import centml_prediction_backend, get_gauge
+
+
+def compile(
+    model: Optional[Callable] = None,
+    *,
+    fullgraph: builtins.bool = False,
+    dynamic: Optional[builtins.bool] = None,
+    mode: Union[str, None] = None,
+    options: Optional[Dict[str, Union[str, builtins.int, builtins.bool]]] = None,
+    disable: builtins.bool = False,
+) -> Callable:
+
+    if settings.CENTML_MODE == OperationMode.REMOTE_COMPILATION:
+        # Return the remote-compiled model
+        compiled_model = torch.compile(
+            model,
+            backend=centml_dynamo_backend,  # Compilation backend
+            fullgraph=fullgraph,
+            dynamic=dynamic,
+            mode=mode,
+            options=options,
+            disable=disable,
+        )
+        return compiled_model
+    elif settings.CENTML_MODE == OperationMode.PREDICTION:
+        # Proceed with prediction workflow
+        compiled_model = torch.compile(
+            model,
+            backend=centml_prediction_backend,  # Prediction backend
+            fullgraph=fullgraph,
+            dynamic=dynamic,
+            mode=mode,
+            options=options,
+            disable=disable,
+        )
+
+        def centml_wrapper(*args, **kwargs):
+            out = compiled_model(*args, **kwargs)
+            # Update the prometheus metrics with final values
+            gauge = get_gauge()
+            for gpu in settings.CENTML_PREDICTION_GPUS.split(','):
+                gauge.set_metric_value(gpu)
+
+            return out
+
+        return centml_wrapper
+    else:
+        raise Exception("Invalid operation mode")
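Note (not part of the diff): a minimal usage sketch of the new entry point, on a hypothetical toy model. centml.compile mirrors the torch.compile signature; which backend runs is decided by settings.CENTML_MODE, and the sketch assumes the selected mode's prerequisites (remote compilation service or prediction data file) are in place.

import torch
import centml

# Hypothetical toy model; any callable torch.compile accepts should work.
model = torch.nn.Linear(128, 64).eval()

# With CENTML_MODE=PREDICTION, the returned wrapper also flushes the
# per-GPU predicted times to Prometheus after every call.
compiled = centml.compile(model)
out = compiled(torch.randn(8, 128))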

centml/compiler/config.py

Lines changed: 10 additions & 0 deletions
@@ -9,6 +9,11 @@ class CompilationStatus(Enum):
     DONE = "DONE"
 
 
+class OperationMode(Enum):
+    PREDICTION = "PREDICTION"
+    REMOTE_COMPILATION = "REMOTE_COMPILATION"
+
+
 class Config(BaseSettings):
     CENTML_COMPILER_TIMEOUT: int = 10
     CENTML_COMPILER_MAX_RETRIES: int = 3
@@ -31,5 +36,10 @@ class Config(BaseSettings):
     # If the server response is smaller than this, don't gzip it
     CENTML_MINIMUM_GZIP_SIZE: int = 1000
 
+    CENTML_MODE: OperationMode = OperationMode.REMOTE_COMPILATION
+    CENTML_PREDICTION_DATA_FILE: str = 'tests/sample_data.csv'
+    CENTML_PREDICTION_GPUS: str = "A10G,A100SXM440GB"
+    CENTML_PROMETHEUS_PORT: int = 8000
+
 
 settings = Config()
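Note (not part of the diff): since Config extends pydantic's BaseSettings, these fields can presumably be overridden through same-named environment variables; a hedged sketch, assuming standard BaseSettings behavior:

import os

# Assumed: environment variables override field defaults. They must be
# set before the first import of centml.compiler.config, because
# `settings = Config()` runs at import time.
os.environ["CENTML_MODE"] = "PREDICTION"
os.environ["CENTML_PREDICTION_GPUS"] = "A10G"  # predict for a single GPU

from centml.compiler.config import OperationMode, settings
assert settings.CENTML_MODE == OperationMode.PREDICTION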
centml/compiler/prediction/backend.py

Lines changed: 28 additions & 0 deletions
@@ -0,0 +1,28 @@
+from typing import List
+
+import torch
+from torch._subclasses.fake_tensor import FakeTensorMode
+
+from centml.compiler.config import settings
+from centml.compiler.prediction.kdtree import get_tree_db
+from centml.compiler.prediction.metric import get_gauge
+from centml.compiler.prediction.profiler import Profiler
+
+
+def centml_prediction_backend(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
+    profilers = []
+    tree_db = get_tree_db()
+    for gpu in settings.CENTML_PREDICTION_GPUS.split(','):
+        profilers.append(Profiler(gm, gpu, tree_db))
+
+    def forward(*args):
+        fake_mode = FakeTensorMode(allow_non_fake_inputs=True)
+        fake_args = [fake_mode.from_tensor(arg) if isinstance(arg, torch.Tensor) else arg for arg in args]
+        with fake_mode:
+            for prof in profilers:
+                out, t = prof.propagate(*fake_args)
+                gauge = get_gauge()
+                gauge.increment(prof.gpu, t)
+        return out
+
+    return forward
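Note (not part of the diff): a hedged sketch of handing this backend straight to torch.compile, bypassing the centml.compile wrapper. Dynamo passes the captured FX graph to centml_prediction_backend, and the returned forward replays it under FakeTensorMode, so no real GPU kernels are launched.

import torch
from centml.compiler.prediction.backend import centml_prediction_backend

# Hypothetical toy model.
model = torch.nn.Sequential(torch.nn.Linear(32, 32), torch.nn.ReLU())
predicted = torch.compile(model, backend=centml_prediction_backend)

# Each call propagates fake tensors through the graph once per
# configured GPU and accumulates the predicted times in the gauge.
predicted(torch.randn(4, 32))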
centml/compiler/prediction/kdtree.py

Lines changed: 69 additions & 0 deletions
@@ -0,0 +1,69 @@
+import ast
+import csv
+import logging
+
+from sklearn.neighbors import KDTree  # type: ignore
+
+from centml.compiler.config import settings
+
+_tree_db = None
+
+
+class KDTreeWithValues:
+    def __init__(self, points=None, values=None):
+        self.points = points if points else []
+        self.values = values if values else []
+        if self.points:
+            self.tree = KDTree(self.points)
+        else:
+            self.tree = None
+
+    def add(self, point, value):
+        self.points.append(point)
+        self.values.append(value)
+        self.tree = KDTree(self.points)
+
+    def query(self, point):
+        if self.tree is None:
+            return None, None
+
+        dist, idx = self.tree.query([point], k=1)
+        return dist[0][0], self.values[idx[0][0]]
+
+
+class TreeDB:
+    def __init__(self, data_csv):
+        self.db = {}
+        self._populate_db(data_csv)
+
+    def get(self, key, inp):
+        if key not in self.db:
+            logging.getLogger(__name__).warning(f"Key {key} not found in database")
+            return float('-inf')
+        # TODO: Handle the case of unfound keys better. For now, return -inf to indicate something went wrong.
+        # Ideally, we shouldn't throw away a whole prediction because of one possibly insignificant node.
+
+        _, val = self.db[key].query(inp)
+        return val
+
+    def _add_from_db(self, key, points, values):
+        self.db[key] = KDTreeWithValues(points, values)
+
+    def _populate_db(self, data_csv):
+        with open(data_csv, newline='') as f:
+            reader = csv.DictReader(f)
+            for row in reader:
+                try:
+                    key = (row['op'], int(row['dim']), row['inp_dtypes'], row['out_dtypes'], row['gpu'])
+                    points = ast.literal_eval(row['points'])
+                    values = ast.literal_eval(row['values'])
+                    self._add_from_db(key, points, values)
+                except ValueError as e:
+                    logging.getLogger(__name__).exception(f"Error parsing row: {row}\n{e}")
+
+
+def get_tree_db():
+    global _tree_db
+    if _tree_db is None:
+        _tree_db = TreeDB(settings.CENTML_PREDICTION_DATA_FILE)
+    return _tree_db
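Note (not part of the diff): to make the on-disk format concrete, a hypothetical row in the shape _populate_db parses. The column names come from the reader above; the example values are invented, not taken from the shipped A10G/A100 sample data. 'points' and 'values' hold Python-literal lists read with ast.literal_eval, and each point has 'dim' entries (the flattened input shape the Profiler builds).

import csv

# Hypothetical data file matching the reader above; values are assumed
# here to be measured kernel times in microseconds.
row = {
    'op': 'addmm',
    'dim': '4',
    'inp_dtypes': 'torch.float32,torch.float32,torch.float32',
    'out_dtypes': 'torch.float32',
    'gpu': 'A10G',
    'points': '[[8, 128, 128, 64], [16, 128, 128, 64]]',
    'values': '[42.0, 55.5]',
}
with open('sample_data.csv', 'w', newline='') as f:
    writer = csv.DictWriter(f, fieldnames=list(row))
    writer.writeheader()
    writer.writerow(row)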
centml/compiler/prediction/metric.py

Lines changed: 30 additions & 0 deletions
@@ -0,0 +1,30 @@
+import time
+
+from prometheus_client import Gauge, start_http_server
+
+from centml.compiler.config import settings
+
+_gauge = None
+
+
+def get_gauge():
+    global _gauge
+    if _gauge is None:
+        _gauge = GaugeMetric()
+    return _gauge
+
+
+class GaugeMetric:
+    def __init__(self):
+        start_http_server(settings.CENTML_PROMETHEUS_PORT)
+        self._gauge = Gauge('execution_time_microseconds', 'Kernel execution times by GPU', ['gpu', 'timestamp'])
+        self._values = {}
+
+    def increment(self, gpu_name, value):
+        if gpu_name not in self._values:
+            self._values[gpu_name] = 0
+        self._values[gpu_name] += value
+
+    def set_metric_value(self, gpu_name):
+        self._gauge.labels(gpu=gpu_name, timestamp=time.time()).set(self._values[gpu_name])
+        self._values[gpu_name] = 0
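Note (not part of the diff): prometheus_client's start_http_server serves plain-text metrics at /metrics, so once the wrapped model has run, the exported gauge can be inspected like this (port 8000 is the configured default):

import urllib.request

metrics = urllib.request.urlopen('http://localhost:8000/metrics').read().decode()
for line in metrics.splitlines():
    # e.g. execution_time_microseconds{gpu="A10G",timestamp="..."} 1234.5
    if line.startswith('execution_time_microseconds'):
        print(line)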
centml/compiler/prediction/profiler.py

Lines changed: 153 additions & 0 deletions
@@ -0,0 +1,153 @@
+from typing import Dict
+
+import torch
+import torch.fx
+from torch.fx.node import Node
+
+
+class Profiler:
+    def __init__(self, mod, gpu, treeDB, data_collection_mode=False):
+        self.mod = mod
+        self.graph = mod.graph
+        self.modules = dict(self.mod.named_modules())
+        self.tree_db = treeDB
+        self.gpu = gpu
+        self.data_collection_mode = data_collection_mode
+
+    def propagate(self, *args):
+        args_iter = iter(args)
+        env: Dict[str, Node] = {}
+        total_time = 0
+
+        def load_arg(a):
+            return torch.fx.graph.map_arg(a, lambda n: env[n.name])
+
+        def fetch_attr(target: str):
+            target_atoms = target.split('.')
+            attr_itr = self.mod
+            for i, atom in enumerate(target_atoms):
+                if not hasattr(attr_itr, atom):
+                    raise RuntimeError(f"Node referenced nonexistent target {'.'.join(target_atoms[:i])}")
+                attr_itr = getattr(attr_itr, atom)
+            return attr_itr
+
+        def get_flattened_shapes(args):
+            flattened_shapes = []
+            dtypes = []
+
+            for arg in args:
+                if isinstance(arg, (tuple, list)):
+                    if len(arg) > 0 and isinstance(arg[0], (tuple, list, torch.Tensor)):
+                        nested_shapes, nested_dtypes = get_flattened_shapes(arg[0])
+                        shape = [len(arg)] + nested_shapes
+                        dtypes.extend(nested_dtypes.split(','))
+                    else:
+                        shape = [len(arg)]
+                elif isinstance(arg, torch.Tensor):
+                    shape = list(arg.shape)
+                    dtypes.append(str(arg.dtype))
+                elif isinstance(arg, bool):
+                    shape = [1 if arg is True else 0]
+                elif isinstance(arg, (int, float)):
+                    shape = [arg]
+                else:
+                    shape = [1]
+                flattened_shapes.extend(shape)
+
+            if len(flattened_shapes) < 2:
+                flattened_shapes.extend([1])
+
+            input_dtypes = ','.join(dtypes) if dtypes else 'N/A'
+
+            return flattened_shapes, input_dtypes
+
+        def get_output_dtypes(results):
+            def find_dtypes(results):
+                if isinstance(results, torch.Tensor):
+                    return [str(results.dtype)]
+                if isinstance(results, (list, tuple)):
+                    dtypes = []
+                    for item in results:
+                        dtypes.extend(find_dtypes(item))
+                    return dtypes
+                return []
+
+            types = find_dtypes(results)
+
+            if types:
+                return ','.join(types)
+            return 'N/A'
+
+        def get_time_or_profile(key, inp_shapes, operation, *args, **kwargs):
+            t = self.tree_db.get(key, inp_shapes)
+
+            if self.data_collection_mode and t is None:
+                with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA]) as prof:
+                    result = operation(*args, **kwargs)
+                event_time_total = 0
+                for event in prof.key_averages():
+                    event_time_total += event.cuda_time_total
+                t = event_time_total
+                self.tree_db.add(key, inp_shapes, t)
+
+            return t
+
+        for node in self.graph.nodes:
+            result = None
+            if node.op == 'placeholder':
+                result = next(args_iter)
+            elif node.op == 'get_attr':
+                result = fetch_attr(node.target)
+            elif node.op == 'call_function':
+                args = load_arg(node.args)
+                kwargs = load_arg(node.kwargs)
+                result = node.target(*args, **kwargs)
+
+                inp_shapes, input_dtypes = get_flattened_shapes(args)
+                output_dtypes = get_output_dtypes(result)
+
+                key = (node.target.__name__, len(inp_shapes), input_dtypes, output_dtypes, self.gpu)
+
+                t = get_time_or_profile(key, inp_shapes, node.target, *args, **kwargs)
+
+                total_time += t
+            elif node.op == 'call_method':
+                self_obj, *args = load_arg(node.args)
+                kwargs = load_arg(node.kwargs)
+                result = getattr(self_obj, node.target)(*args, **kwargs)
+
+                inp_shapes, input_dtypes = get_flattened_shapes(args)
+                output_dtypes = get_output_dtypes(result)
+
+                key = (node.target, len(inp_shapes), input_dtypes, output_dtypes, self.gpu)
+
+                t = get_time_or_profile(key, inp_shapes, getattr(self_obj, node.target), *args, **kwargs)
+
+                total_time += t
+            elif node.op == 'call_module':
+                mod = self.modules[node.target]
+                args = load_arg(node.args)
+                kwargs = load_arg(node.kwargs)
+                result = mod(*args, **kwargs)
+
+                inp_shapes, input_dtypes = get_flattened_shapes(args)
+
+                param_shapes = [param.shape for name, param in mod.named_parameters()]
+                param_dtypes = [str(param.dtype) for name, param in mod.named_parameters()]
+                flattened_params = [dim for shape in param_shapes for dim in shape]
+
+                inp_shapes = inp_shapes + flattened_params
+                input_dtypes = input_dtypes + ',' + ','.join(param_dtypes)
+
+                output_dtypes = get_output_dtypes(result)
+
+                key = (mod._get_name(), len(inp_shapes), input_dtypes, output_dtypes, self.gpu)
+
+                t = get_time_or_profile(key, inp_shapes, mod, *args, **kwargs)
+
+                total_time += t
+            elif node.op == 'output':
+                args = load_arg(node.args)
+                return args[0], total_time
+
+            env[node.name] = result
requirements-dev.txt

Lines changed: 2 additions & 0 deletions
@@ -11,3 +11,5 @@ parameterized>=0.9.0
 mypy==1.5.1
 types-requests==2.31.0.2
 types-tabulate>=0.9.0
+prometheus-client>=0.20.0
+
requirements.txt

Lines changed: 2 additions & 0 deletions
@@ -7,5 +7,7 @@ Requests==2.32.2
 tabulate>=0.9.0
 pyjwt>=2.8.0
 cryptography==42.0.8
+prometheus-client>=0.20.0
 scipy>=1.6.0
+scikit-learn>=1.5.1
 platform_api_client @ git+https://github.com/CentML/platform_api_python_client.git@main
