From ef93ce788721a969d886d3bdefaf1ac8960c3f85 Mon Sep 17 00:00:00 2001 From: yrl <2535184404@qq.com> Date: Thu, 4 Dec 2025 17:57:01 +0800 Subject: [PATCH 1/6] incubate_ncu_do_bench_func --- python/triton/_flagtree_tools.py | 321 +++++++++++++++++++++++++++++++ 1 file changed, 321 insertions(+) create mode 100644 python/triton/_flagtree_tools.py diff --git a/python/triton/_flagtree_tools.py b/python/triton/_flagtree_tools.py new file mode 100644 index 000000000..ed51a2428 --- /dev/null +++ b/python/triton/_flagtree_tools.py @@ -0,0 +1,321 @@ +import subprocess +import os +import textwrap +import inspect +import tempfile +import csv +from io import StringIO +from pathlib import Path +import contextlib +import math +import statistics +import triton.runtime as runtime + + +def flagtree_do_bench(fn, warmup=25, rep=100, quantiles=None, return_mode="mean"): + bench = FlagtreeBench(warmup=warmup, rep=rep, quantiles=quantiles, return_mode=return_mode) + bench.do_bench(fn=fn) + return bench._get_index() + + +''' + function _quantile and _summarize_statistics is from .testing. + if used directly, it will lead to circular dependencies +''' + + +def _quantile(a, q): + n = len(a) + a = sorted(a) + + def get_quantile(q): + if not (0 <= q <= 1): + raise ValueError("Quantiles must be in the range [0, 1]") + point = q * (n - 1) + lower = math.floor(point) + upper = math.ceil(point) + t = point - lower + return (1 - t) * a[lower] + t * a[upper] + + return [get_quantile(q) for q in q] + + +def _summarize_statistics(times, quantiles, return_mode): + if quantiles is not None: + ret = _quantile(times, quantiles) + if len(ret) == 1: + ret = ret[0] + return ret + if return_mode == "all": + return times + elif return_mode == "min": + return min(times) + elif return_mode == "max": + return max(times) + elif return_mode == "mean": + return statistics.mean(times) + elif return_mode == "median": + return statistics.median(times) + + +''' + IndentedBuffer Referred to + https://github.com/flagos-ai/FlagGems/blob/master/src/flag_gems/utils/code_utils.py::IndentedBuffer +''' + + +class IndentedBuffer: + tabwidth = 4 + + def __init__(self, initial_indent=0): + self._lines = [] + self._indent = initial_indent + + def getvalue(self) -> str: + buf = StringIO() + for line in self._lines: + assert isinstance(line, str) + buf.write(line) + buf.write("\n") + return buf.getvalue() + + def clear(self): + self._lines.clear() + + def __bool__(self): + return bool(self._lines) + + def prefix(self): + return " " * (self._indent * self.tabwidth) + + def newline(self): + self.writeline("\n") + + def writeline(self, line): + if line.strip(): + self._lines.append(f"{self.prefix()}{line}") + else: + self._lines.append("") + + def tpl(self, format_str, **kwargs): + assert isinstance(format_str, str), "format_str must be string of type." 
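+        # Substitute kwargs into the template, then emit it line by line so
+        # every line picks up the buffer's current indentation; literal tabs
+        # are expanded to tabwidth spaces.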
+ format_str = format_str.format(**kwargs) + lines = format_str.strip().splitlines() + for line in lines: + line = line.replace("\t", " " * self.tabwidth) + self.writeline(line) + + def writelines(self, lines): + for line in lines: + self.writeline(line) + + def writemultiline(self, s): + self.writelines(s.splitlines()) + + def indent(self, offset=1): + + @contextlib.contextmanager + def ctx(): + self._indent += offset + try: + yield + finally: + self._indent -= offset + + return ctx() + + +''' + FlagtreeBench using ncu to measure performance +''' + + +class FlagtreeBench: + + def __init__(self, warmup=100, rep=100, quantiles=None, return_mode="mean", metrics='gpu__time_duration'): + if FlagtreeBench.check_ncu(): + self.metrics = metrics + self.warmup = warmup + self.rep = rep + self.quantiles = quantiles + self.return_mode = return_mode + self.function_paths = [] + self.import_modules = [] + self._get_package_path() + self._create_temp_file() + + staticmethod + + def check_ncu(): + cmd = ["ncu", "--query-metrics"] + try: + subprocess.run(cmd, capture_output=True, check=True) + print("[INFO]: ncu check successfully") + return True + except Exception as err_msg: + print(f"[Hint] The inability to invoke ncu on this machine" + f"might be due to issues such as the absence of ncu, " + f"lack of permissions, or a version that is too low. Specifically {err_msg}") + return False + + @staticmethod + def is_triton_jit_decorated(obj, max_depth=10): + ''' + attrs temporarily adds specialized support to the kernel of flag_gems. + About flag_gems see https://github.com/flagos-ai/FlagGems + ''' + attrs = ['AnonymousLibTunerImpl', 'LibEntry', 'JITFunction'] + if hasattr(obj, '__class__') and obj.__class__.__name__ in attrs: + return True + + def _get_kernels(self, _fn): + import ast + source = inspect.getsource(_fn) + tree = ast.parse(source) + globals_dict = _fn.__globals__ + calls = [] + + class CallVisitor(ast.NodeVisitor): + + def visit_Call(self, node): + if isinstance(node.func, ast.Name): + calls.append({'name': node.func.id}) + elif isinstance(node.func, ast.Attribute): + calls.append({'name': node.func.attr}) + self.generic_visit(node) + + visitor = CallVisitor() + visitor.visit(tree) + jit_funcs = [] + for call in calls: + name = call['name'] + if name not in globals_dict: + continue + entity = globals_dict[call['name']] + if callable(entity): + module = __import__(entity.__module__) + else: + module = entity + _path = module.__file__ + self.function_paths.append(_path) + module_name = _path.split('/')[-1] + module_name = Path(module_name).stem + self.import_modules.append((name, module_name)) + for name in dir(module): + if name.startswith('__'): + continue + obj = getattr(module, name) + if FlagtreeBench.is_triton_jit_decorated(obj=obj): + jit_funcs.append(name) + self.triton_funcs = jit_funcs + + def _get_package_path(self): + self.user_package_path = os.environ.get('BENCH_MODULE_PATH', '') + + def _create_temp_file(self): + self.python_exec = tempfile.NamedTemporaryFile(delete=False, suffix=".py") + self.python_exec.close() + + self.out_csv = tempfile.NamedTemporaryFile(delete=False, suffix=".csv") + self.out_csv.close() + + def _write_script(self, script): + with open(self.python_exec.name, 'w+') as f: + f.write(script) + + def _exec(self): + runtime.driver.active.clear_cache(self.bench_cache) + cmd = [ + "ncu", "--metrics", self.metrics, "--csv", "--log-file", self.out_csv.name, "python3", self.python_exec.name + ] + print(f"[INFO]: ncu running on {self.python_exec.name}") + 
subprocess.run(cmd, capture_output=True, check=True) + + def _get_index(self): + # indexs = ['avg', 'max', 'min', 'sum'] + _index_package = {} + kernel_name = '' + with open(self.out_csv.name, newline="", encoding="utf-8") as f: + reader = csv.reader(f) + for row in reader: + for jit in self.triton_funcs: + if jit in row: + index_name = row[12].split('.')[-1] + index_val = float(row[14]) / 1e6 + kernel_name = jit + if jit not in _index_package: + _index_package.update({jit: {index_name: index_val}}) + else: + _index_package[jit].update({index_name: index_val}) + return _index_package[kernel_name]['avg'] + + def _gen_import_and_path(self, script_code: IndentedBuffer, path_mode='insert'): + sys_path_action_str = '0, ' + if path_mode == 'insert': + script_code.writeline('import torch') + script_code.writeline('import os') + script_code.writeline('import sys') + else: + sys_path_action_str = '' + if self.user_package_path != '': + script_code.writeline(f"sys.path.{path_mode}({sys_path_action_str}'{self.user_package_path}')") + for path in self.function_paths: + if not os.path.isdir(path): + path = os.path.dirname(path) + script_code.writeline(f"sys.path.{path_mode}({sys_path_action_str}'{path}')") + if path_mode == 'insert': + for module_message in self.import_modules: + fn, module = module_message + script_code.writeline(f"from {module} import {fn}") + + def _generate_script(self, fn): + fn_src_code_string = textwrap.dedent(inspect.getsource(fn)) + script_code = IndentedBuffer() + self._gen_import_and_path(script_code, path_mode='insert') + + script_code.writeline(fn_src_code_string) + script_code.writeline(f'{fn.__name__}()') + script_code.writeline("torch.cuda.synchronize()") + + self._gen_import_and_path(script_code, path_mode='remove') + self.script = script_code.getvalue() + self._write_script(self.script) + + def _pre_operation(self, fn): + ''' + Referred to triton.testing.do_bench + ''' + di = runtime.driver.active.get_device_interface() + fn() + di.synchronize() + cache = runtime.driver.active.get_empty_cache_for_benchmark() + + # Estimate the runtime of the function + start_event = di.Event(enable_timing=True) + end_event = di.Event(enable_timing=True) + start_event.record() + for _ in range(5): + runtime.driver.active.clear_cache(cache) + fn() + end_event.record() + di.synchronize() + estimate_ms = start_event.elapsed_time(end_event) / 5 + + # compute number of warmup and repeat + n_warmup = max(1, int(self.warmup / estimate_ms)) + # n_repeat = max(1, int(self.rep / estimate_ms)) + + self.bench_cache = cache + for _ in range(n_warmup): + fn() + + def do_bench(self, fn): + ''' + Measure the GPU kernel time of fn() using ncu. + Generate a temporary Python file and then run it with 'ncu'. 
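+        The flow: collect the Triton kernels fn reaches, generate a
+        standalone script, warm up, profile the script under ncu, then
+        parse the resulting CSV for the timings.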
+ ''' + self._get_kernels(fn) + self._generate_script(fn=fn) + self._pre_operation(fn=fn) + self._exec() + self.index_set = self._get_index() From 4d7d62785773ae9639c2b0c564974a7ee1aadd0c Mon Sep 17 00:00:00 2001 From: yrl <2535184404@qq.com> Date: Fri, 5 Dec 2025 18:03:06 +0800 Subject: [PATCH 2/6] update --- python/triton/_flagtree_tools.py | 167 ++++++++++++++++--------------- 1 file changed, 85 insertions(+), 82 deletions(-) diff --git a/python/triton/_flagtree_tools.py b/python/triton/_flagtree_tools.py index ed51a2428..905eb174d 100644 --- a/python/triton/_flagtree_tools.py +++ b/python/triton/_flagtree_tools.py @@ -3,19 +3,18 @@ import textwrap import inspect import tempfile -import csv from io import StringIO -from pathlib import Path import contextlib import math +import ast import statistics import triton.runtime as runtime def flagtree_do_bench(fn, warmup=25, rep=100, quantiles=None, return_mode="mean"): - bench = FlagtreeBench(warmup=warmup, rep=rep, quantiles=quantiles, return_mode=return_mode) - bench.do_bench(fn=fn) - return bench._get_index() + bench = FlagtreeBench(current_fn=fn, warmup=warmup, rep=rep, quantiles=quantiles, return_mode=return_mode) + bench.do_bench() + return 0 ''' @@ -132,20 +131,21 @@ def ctx(): class FlagtreeBench: - def __init__(self, warmup=100, rep=100, quantiles=None, return_mode="mean", metrics='gpu__time_duration'): + def __init__(self, current_fn, warmup=100, rep=100, quantiles=None, return_mode="mean", + metrics='gpu__time_duration'): if FlagtreeBench.check_ncu(): + self._current_fn = current_fn self.metrics = metrics self.warmup = warmup self.rep = rep self.quantiles = quantiles self.return_mode = return_mode - self.function_paths = [] - self.import_modules = [] + self.triton_funcs = [] self._get_package_path() self._create_temp_file() + print(self.python_exec.file, self.out_csv.file) - staticmethod - + @staticmethod def check_ncu(): cmd = ["ncu", "--query-metrics"] try: @@ -159,55 +159,71 @@ def check_ncu(): return False @staticmethod - def is_triton_jit_decorated(obj, max_depth=10): + def gather_triton_jit_kernel(mod): ''' attrs temporarily adds specialized support to the kernel of flag_gems. 
About flag_gems see https://github.com/flagos-ai/FlagGems ''' + if FlagtreeBench.is_from_sitepackages(mod): + return set() + + kernels = set() attrs = ['AnonymousLibTunerImpl', 'LibEntry', 'JITFunction'] - if hasattr(obj, '__class__') and obj.__class__.__name__ in attrs: - return True + for node in dir(mod): + if node.startswith('__'): + continue + obj = getattr(mod, node) + if hasattr(obj, '__class__') and obj.__class__.__name__ in attrs: + kernels.add(node) + return kernels + + @staticmethod + def is_from_sitepackages(mod): + return 'site-packages' in mod.__file__ - def _get_kernels(self, _fn): - import ast + def _get_current_function_used_mod(self, _fn=None): + _fn = _fn or self._current_fn + func_global_dict = _fn.__globals__ source = inspect.getsource(_fn) tree = ast.parse(source) - globals_dict = _fn.__globals__ - calls = [] - - class CallVisitor(ast.NodeVisitor): + modules = set() + calls = set() + deps_path = set() + triton_jit_kernels = set() + + class Visitor(ast.NodeVisitor): + + def visit_Attribute(self, node): + if isinstance(node.value, ast.Name): + mod_name = node.value.id + mod_instance = func_global_dict[mod_name] + if hasattr(mod_instance, '__file__'): + mod_dir_path = os.path.dirname(mod_instance.__file__) + deps_path.add(mod_dir_path) + modules.add(mod_name) + self.generic_visit(node) def visit_Call(self, node): if isinstance(node.func, ast.Name): - calls.append({'name': node.func.id}) + fun_name = node.func.id + func_instance = func_global_dict[fun_name] + mod_instance = __import__(func_instance.__module__) + triton_jit_kernels.update(FlagtreeBench.gather_triton_jit_kernel(mod_instance)) + if hasattr(mod_instance, '__file__'): + mod_dir_path = os.path.dirname(mod_instance.__file__) + deps_path.add(mod_dir_path) + calls.add((fun_name, mod_instance.__name__)) + elif isinstance(node.func, ast.Attribute): - calls.append({'name': node.func.attr}) + fun_name = node.func.attr + if isinstance(node.func.value, ast.Name): + mod = node.func.value.id + mod_instance = func_global_dict[mod] + triton_jit_kernels.update(FlagtreeBench.gather_triton_jit_kernel(mod_instance)) self.generic_visit(node) - visitor = CallVisitor() - visitor.visit(tree) - jit_funcs = [] - for call in calls: - name = call['name'] - if name not in globals_dict: - continue - entity = globals_dict[call['name']] - if callable(entity): - module = __import__(entity.__module__) - else: - module = entity - _path = module.__file__ - self.function_paths.append(_path) - module_name = _path.split('/')[-1] - module_name = Path(module_name).stem - self.import_modules.append((name, module_name)) - for name in dir(module): - if name.startswith('__'): - continue - obj = getattr(module, name) - if FlagtreeBench.is_triton_jit_decorated(obj=obj): - jit_funcs.append(name) - self.triton_funcs = jit_funcs + Visitor().visit(tree) + return (calls, modules, deps_path) def _get_package_path(self): self.user_package_path = os.environ.get('BENCH_MODULE_PATH', '') @@ -225,31 +241,16 @@ def _write_script(self, script): def _exec(self): runtime.driver.active.clear_cache(self.bench_cache) - cmd = [ - "ncu", "--metrics", self.metrics, "--csv", "--log-file", self.out_csv.name, "python3", self.python_exec.name - ] + cmd = ["ncu", "--metrics", self.metrics, "python3", self.python_exec.name] print(f"[INFO]: ncu running on {self.python_exec.name}") - subprocess.run(cmd, capture_output=True, check=True) + result = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, check=True, text=True) + self.clean_out = "\n".join(line for line in 
result.stdout.splitlines() if not line.startswith("==PROF==")) def _get_index(self): - # indexs = ['avg', 'max', 'min', 'sum'] - _index_package = {} - kernel_name = '' - with open(self.out_csv.name, newline="", encoding="utf-8") as f: - reader = csv.reader(f) - for row in reader: - for jit in self.triton_funcs: - if jit in row: - index_name = row[12].split('.')[-1] - index_val = float(row[14]) / 1e6 - kernel_name = jit - if jit not in _index_package: - _index_package.update({jit: {index_name: index_val}}) - else: - _index_package[jit].update({index_name: index_val}) - return _index_package[kernel_name]['avg'] + ... def _gen_import_and_path(self, script_code: IndentedBuffer, path_mode='insert'): + calls, modules, deps_path = self._get_current_function_used_mod() sys_path_action_str = '0, ' if path_mode == 'insert': script_code.writeline('import torch') @@ -259,34 +260,37 @@ def _gen_import_and_path(self, script_code: IndentedBuffer, path_mode='insert'): sys_path_action_str = '' if self.user_package_path != '': script_code.writeline(f"sys.path.{path_mode}({sys_path_action_str}'{self.user_package_path}')") - for path in self.function_paths: + for path in deps_path: if not os.path.isdir(path): path = os.path.dirname(path) script_code.writeline(f"sys.path.{path_mode}({sys_path_action_str}'{path}')") if path_mode == 'insert': - for module_message in self.import_modules: - fn, module = module_message - script_code.writeline(f"from {module} import {fn}") - - def _generate_script(self, fn): - fn_src_code_string = textwrap.dedent(inspect.getsource(fn)) + for mod in modules: + script_code.writeline(f'import {mod}') + for call, mod in calls: + script_code.writeline(f"from {mod} import {call}") + + def _generate_script(self, _fn=None): + _fn = _fn or self._current_fn + fn_src_code_string = textwrap.dedent(inspect.getsource(_fn)) script_code = IndentedBuffer() self._gen_import_and_path(script_code, path_mode='insert') script_code.writeline(fn_src_code_string) - script_code.writeline(f'{fn.__name__}()') + script_code.writeline(f'{_fn.__name__}()') script_code.writeline("torch.cuda.synchronize()") self._gen_import_and_path(script_code, path_mode='remove') self.script = script_code.getvalue() self._write_script(self.script) - def _pre_operation(self, fn): + def _pre_operation(self, _fn=None): ''' Referred to triton.testing.do_bench ''' + _fn = _fn or self._current_fn di = runtime.driver.active.get_device_interface() - fn() + _fn() di.synchronize() cache = runtime.driver.active.get_empty_cache_for_benchmark() @@ -296,26 +300,25 @@ def _pre_operation(self, fn): start_event.record() for _ in range(5): runtime.driver.active.clear_cache(cache) - fn() + _fn() end_event.record() di.synchronize() estimate_ms = start_event.elapsed_time(end_event) / 5 # compute number of warmup and repeat n_warmup = max(1, int(self.warmup / estimate_ms)) - # n_repeat = max(1, int(self.rep / estimate_ms)) self.bench_cache = cache for _ in range(n_warmup): - fn() + _fn() - def do_bench(self, fn): + def do_bench(self): ''' Measure the GPU kernel time of fn() using ncu. Generate a temporary Python file and then run it with 'ncu'. 
''' - self._get_kernels(fn) - self._generate_script(fn=fn) - self._pre_operation(fn=fn) + self.used_mods = self._get_current_function_used_mod() + self._generate_script() + self._pre_operation() self._exec() self.index_set = self._get_index() From d0e54fad3b03b4d07177ca2a8209d7e4d2403737 Mon Sep 17 00:00:00 2001 From: yrl <2535184404@qq.com> Date: Mon, 8 Dec 2025 15:51:36 +0800 Subject: [PATCH 3/6] polish and update --- python/triton/_flagtree_tools.py | 95 ++++++++++++++------------------ 1 file changed, 40 insertions(+), 55 deletions(-) diff --git a/python/triton/_flagtree_tools.py b/python/triton/_flagtree_tools.py index 905eb174d..758692b0b 100644 --- a/python/triton/_flagtree_tools.py +++ b/python/triton/_flagtree_tools.py @@ -5,56 +5,16 @@ import tempfile from io import StringIO import contextlib -import math import ast -import statistics +import pandas as pd import triton.runtime as runtime -def flagtree_do_bench(fn, warmup=25, rep=100, quantiles=None, return_mode="mean"): +def flagtree_do_bench(fn, warmup=10, rep=5, quantiles=None, return_mode="mean"): + assert return_mode in ["mean", "min", "max", "sum"] bench = FlagtreeBench(current_fn=fn, warmup=warmup, rep=rep, quantiles=quantiles, return_mode=return_mode) bench.do_bench() - return 0 - - -''' - function _quantile and _summarize_statistics is from .testing. - if used directly, it will lead to circular dependencies -''' - - -def _quantile(a, q): - n = len(a) - a = sorted(a) - - def get_quantile(q): - if not (0 <= q <= 1): - raise ValueError("Quantiles must be in the range [0, 1]") - point = q * (n - 1) - lower = math.floor(point) - upper = math.ceil(point) - t = point - lower - return (1 - t) * a[lower] + t * a[upper] - - return [get_quantile(q) for q in q] - - -def _summarize_statistics(times, quantiles, return_mode): - if quantiles is not None: - ret = _quantile(times, quantiles) - if len(ret) == 1: - ret = ret[0] - return ret - if return_mode == "all": - return times - elif return_mode == "min": - return min(times) - elif return_mode == "max": - return max(times) - elif return_mode == "mean": - return statistics.mean(times) - elif return_mode == "median": - return statistics.median(times) + return bench.results[return_mode] ''' @@ -131,8 +91,7 @@ def ctx(): class FlagtreeBench: - def __init__(self, current_fn, warmup=100, rep=100, quantiles=None, return_mode="mean", - metrics='gpu__time_duration'): + def __init__(self, current_fn, warmup=10, rep=5, quantiles=None, return_mode="mean", metrics='gpu__time_duration'): if FlagtreeBench.check_ncu(): self._current_fn = current_fn self.metrics = metrics @@ -143,7 +102,6 @@ def __init__(self, current_fn, warmup=100, rep=100, quantiles=None, return_mode= self.triton_funcs = [] self._get_package_path() self._create_temp_file() - print(self.python_exec.file, self.out_csv.file) @staticmethod def check_ncu(): @@ -153,9 +111,9 @@ def check_ncu(): print("[INFO]: ncu check successfully") return True except Exception as err_msg: - print(f"[Hint] The inability to invoke ncu on this machine" + print(f"\033[31m[Error] The inability to invoke ncu on this machine" f"might be due to issues such as the absence of ncu, " - f"lack of permissions, or a version that is too low. Specifically {err_msg}") + f"lack of permissions, or a version that is too low. 
Specifically \n{err_msg}\033[0m") return False @staticmethod @@ -198,7 +156,7 @@ def visit_Attribute(self, node): mod_name = node.value.id mod_instance = func_global_dict[mod_name] if hasattr(mod_instance, '__file__'): - mod_dir_path = os.path.dirname(mod_instance.__file__) + mod_dir_path = os.path.dirname(os.path.dirname(mod_instance.__file__)) deps_path.add(mod_dir_path) modules.add(mod_name) self.generic_visit(node) @@ -241,13 +199,40 @@ def _write_script(self, script): def _exec(self): runtime.driver.active.clear_cache(self.bench_cache) - cmd = ["ncu", "--metrics", self.metrics, "python3", self.python_exec.name] + cmd = [ + "ncu", + "--metrics", + self.metrics, + "--csv", + "--log-file", + self.out_csv.name, + "python3", + self.python_exec.name, + ] print(f"[INFO]: ncu running on {self.python_exec.name}") - result = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, check=True, text=True) - self.clean_out = "\n".join(line for line in result.stdout.splitlines() if not line.startswith("==PROF==")) + subprocess.run(cmd, check=True) + self._pure_csv_log() + + def _pure_csv_log(self): + FILTER_PREFIXES = ["==PROF=", "==ERROR=", "==WARNING="] + with open(self.out_csv.name, 'r') as csv_f: + lines = csv_f.readlines() + new_lines = [line for line in lines if not any(line.startswith(prefix) for prefix in FILTER_PREFIXES)] + with open(self.out_csv.name, "w") as csv_f: + csv_f.writelines(new_lines) def _get_index(self): - ... + indexs = ['avg', 'max', 'min', 'sum'] + patterns = "at::|std::" + index_dict = dict.fromkeys(indexs, 0) + df = pd.read_csv(self.out_csv.name) + metric_values = df[~df["Kernel Name"].str.contains(patterns, regex=True)][["Metric Name", "Metric Value"]] + for _, row in metric_values.iterrows(): + metric_name = str(row['Metric Name']).split('.')[-1] + gpu_time = float(row['Metric Value']) / 1e6 + index_dict[metric_name] += gpu_time + index_dict['mean'] = index_dict['avg'] + return index_dict def _gen_import_and_path(self, script_code: IndentedBuffer, path_mode='insert'): calls, modules, deps_path = self._get_current_function_used_mod() @@ -321,4 +306,4 @@ def do_bench(self): self._generate_script() self._pre_operation() self._exec() - self.index_set = self._get_index() + self.results = self._get_index() From f9664818e029afc9d726a00958e0acf7a179c8f0 Mon Sep 17 00:00:00 2001 From: yrl <2535184404@qq.com> Date: Wed, 10 Dec 2025 17:03:36 +0800 Subject: [PATCH 4/6] update --- python/triton/_flagtree_tools.py | 319 ++++++++++++++++++++++--------- 1 file changed, 229 insertions(+), 90 deletions(-) diff --git a/python/triton/_flagtree_tools.py b/python/triton/_flagtree_tools.py index 758692b0b..31accd787 100644 --- a/python/triton/_flagtree_tools.py +++ b/python/triton/_flagtree_tools.py @@ -6,17 +6,71 @@ from io import StringIO import contextlib import ast +import types import pandas as pd +import pickle +from dataclasses import dataclass +from typing import Callable, Any import triton.runtime as runtime +''' + Currently, the use of flagtree_do_bench is restricted, mainly including: + + 1. This method can only be used to test the kernel running time of triton and torch; + + 2. Arg fn, is either a direct single-test wrapper, such as + (1) def test(): + (2). 
fn = op() + +''' -def flagtree_do_bench(fn, warmup=10, rep=5, quantiles=None, return_mode="mean"): +def flagtree_do_bench(fn, warmup=10, rep=5, quantiles=None, return_mode="mean") -> float: assert return_mode in ["mean", "min", "max", "sum"] bench = FlagtreeBench(current_fn=fn, warmup=warmup, rep=rep, quantiles=quantiles, return_mode=return_mode) bench.do_bench() return bench.results[return_mode] +def check_ncu(): + cmd = ["ncu", "--query-metrics"] + try: + subprocess.run(cmd, capture_output=True, check=True) + print("[INFO]: ncu check successfully") + return True + except Exception as err_msg: + print(f"\033[31m[Error] The inability to invoke ncu on this machine" + f"might be due to issues such as the absence of ncu, " + f"lack of permissions, or a version that is too low. Specifically \n{err_msg}\033[0m") + return False + + +def run_warmup(_fn, warmup): + ''' + Referred to triton.testing.do_bench + ''' + di = runtime.driver.active.get_device_interface() + _fn() + di.synchronize() + cache = runtime.driver.active.get_empty_cache_for_benchmark() + + # Estimate the runtime of the function + start_event = di.Event(enable_timing=True) + end_event = di.Event(enable_timing=True) + start_event.record() + for _ in range(5): + runtime.driver.active.clear_cache(cache) + _fn() + end_event.record() + di.synchronize() + estimate_ms = start_event.elapsed_time(end_event) / 5 + + # compute number of warmup and repeat + n_warmup = max(1, int(warmup / estimate_ms)) + + for _ in range(n_warmup): + _fn() + + ''' IndentedBuffer Referred to https://github.com/flagos-ai/FlagGems/blob/master/src/flag_gems/utils/code_utils.py::IndentedBuffer @@ -84,6 +138,148 @@ def ctx(): return ctx() +@dataclass +class FuncAttrs: + _globals: dict = None, + is_lambda: bool = False + is_torch_method: bool = False, + source: str = '' + ast_tree: any = None, + functor_source: str = '' + Argument_serialized: bool = False + Argument_serialized_path: str = '' + + +class CodeGenerator: + + @staticmethod + def gen_load_args_kwargs_method(script_code: IndentedBuffer, Trait: FuncAttrs): + script_code.writeline("def load_args_kwargs(filename):") + with script_code.indent(): + script_code.writeline("import pickle") + script_code.writeline("with open(filename, 'rb') as f:") + with script_code.indent(): + script_code.writeline("data = pickle.load(f)") + script_code.writeline("return data['args'], data['kwargs']") + + script_code.writeline(f"args, kwargs = load_args_kwargs('{Trait.Argument_serialized_path}')") + script_code.writeline(f"{Trait.functor_source}(*args, **kwargs)") + + @staticmethod + def _gen_import_and_path(script_code: IndentedBuffer, unpacked, path_mode='insert'): + sys_path_action_str = '0, ' + if path_mode == 'insert': + script_code.writeline('import torch') + script_code.writeline('import os') + script_code.writeline('import sys') + else: + sys_path_action_str = '' + user_package_path = os.environ.get('BENCH_MODULE_PATH', '') + if user_package_path != '': + script_code.writeline(f"sys.path.{path_mode}({sys_path_action_str}'{user_package_path}')") + + # create extra modules + if not unpacked: + return + else: + calls, modules, deps_path = unpacked + for path in deps_path: + if not os.path.isdir(path): + path = os.path.dirname(path) + script_code.writeline(f"sys.path.{path_mode}({sys_path_action_str}'{path}')") + if path_mode == 'insert': + for mod in modules: + script_code.writeline(f'import {mod}') + for call, mod in calls: + script_code.writeline(f"from {mod} import {call}") + + @staticmethod + def 
_gen_lambda_source_code(script_code: IndentedBuffer, Trait: FuncAttrs): + if Trait.is_torch_method: + CodeGenerator._gen_torch_using_lambda_code(script_code, Trait) + else: + CodeGenerator._gen_triton_code(script_code) + + @staticmethod + def _gen_torch_using_lambda_code(script_code: IndentedBuffer, Trait: FuncAttrs): + if Trait.Argument_serialized: + CodeGenerator.gen_load_args_kwargs_method(script_code, Trait) + else: + script_code.writeline(f"{Trait.functor_source}()") + + @staticmethod + def _gen_triton_code(): + ... + + +class FunctionExtractor: + + def __init__(self, fn: Callable[..., Any] = None): + if fn.__name__ == "": + self.trait = self.analyse_lambda(fn) + else: + self.trait = self.analyse_general(fn) + + def args_serialization(self, *args, **kwargs): + packed_datas = tempfile.NamedTemporaryFile(delete=False, suffix=".pkl") + with open(packed_datas.name, "wb") as f: + pickle.dump({"args": args, "kwargs": kwargs}, f) + return packed_datas.name + + def analyse_lambda(self, fn): + is_lambda = True + if not hasattr(fn, '__closure__'): + return + + def handlecase_single(closure_package, fn): + ... + + def handlecase_both(closure_package, fn): + ... + + def handlecase_all(closure_package, fn): + source = inspect.getsource(fn) + args, kwargs = closure_package[0].cell_contents, closure_package[1].cell_contents + serialized_path = self.args_serialization(*args, **kwargs) + functor = closure_package[-1].cell_contents + + # Case 1: method_descriptor → Tensor-level method + # e.g., torch.Tensor.addmm + if inspect.ismethoddescriptor(functor) or hasattr(functor, "__objclass__"): + functor_mod_name = functor.__objclass__.__module__.replace("torch._C", "torch") + functor_source = f"{functor_mod_name}.{functor.__name__}" + + # Case 2: bound method (x.addmm) + if inspect.ismethod(functor): + cls = functor.__self__.__class__ + functor_source = f"{cls.__module__}.{cls.__qualname__}.{functor.__name__}" + + if isinstance(functor, (types.BuiltinFunctionType, types.BuiltinMethodType)): + functor_source = f"torch.{functor.__name__}" + + is_torch_method = 'torch' in functor_source + return FuncAttrs(_globals=fn.__globals__, source=inspect.getsource(fn), is_torch_method=is_torch_method, + Argument_serialized=True, Argument_serialized_path=serialized_path, + functor_source=functor_source, ast_tree=ast.parse(textwrap.dedent(source)), + is_lambda=is_lambda) + + casehandlers = [handlecase_single, handlecase_both, handlecase_all] + casehandlers_mapping = {} + for case_idx, handler in enumerate(casehandlers): + casehandlers_mapping[case_idx] = handler + + handle_closure = lambda _type, _closure_package, _fn: casehandlers_mapping[_type](_closure_package, _fn) + closure_package = fn.__closure__ + closure_package_len = len(closure_package) - 1 + + return handle_closure(closure_package_len, closure_package, fn) + + def analyse_general(self, fn): + source = inspect.getsource(fn) + return FuncAttrs(source=source, ast_tree=ast.parse(textwrap.dedent(source)), is_lambda=False, + _globals=fn.__globals__) + + ''' FlagtreeBench using ncu to measure performance ''' @@ -92,7 +288,7 @@ def ctx(): class FlagtreeBench: def __init__(self, current_fn, warmup=10, rep=5, quantiles=None, return_mode="mean", metrics='gpu__time_duration'): - if FlagtreeBench.check_ncu(): + if check_ncu(): self._current_fn = current_fn self.metrics = metrics self.warmup = warmup @@ -100,21 +296,8 @@ def __init__(self, current_fn, warmup=10, rep=5, quantiles=None, return_mode="me self.quantiles = quantiles self.return_mode = return_mode 
self.triton_funcs = [] - self._get_package_path() self._create_temp_file() - - @staticmethod - def check_ncu(): - cmd = ["ncu", "--query-metrics"] - try: - subprocess.run(cmd, capture_output=True, check=True) - print("[INFO]: ncu check successfully") - return True - except Exception as err_msg: - print(f"\033[31m[Error] The inability to invoke ncu on this machine" - f"might be due to issues such as the absence of ncu, " - f"lack of permissions, or a version that is too low. Specifically \n{err_msg}\033[0m") - return False + print(self.python_exec.name, self.out_csv.name) @staticmethod def gather_triton_jit_kernel(mod): @@ -140,10 +323,10 @@ def is_from_sitepackages(mod): return 'site-packages' in mod.__file__ def _get_current_function_used_mod(self, _fn=None): - _fn = _fn or self._current_fn - func_global_dict = _fn.__globals__ - source = inspect.getsource(_fn) - tree = ast.parse(source) + attrs = self.fn_trait + if attrs.is_lambda: + return None + func_global_dict = attrs._globals modules = set() calls = set() deps_path = set() @@ -164,13 +347,14 @@ def visit_Attribute(self, node): def visit_Call(self, node): if isinstance(node.func, ast.Name): fun_name = node.func.id - func_instance = func_global_dict[fun_name] - mod_instance = __import__(func_instance.__module__) - triton_jit_kernels.update(FlagtreeBench.gather_triton_jit_kernel(mod_instance)) - if hasattr(mod_instance, '__file__'): - mod_dir_path = os.path.dirname(mod_instance.__file__) - deps_path.add(mod_dir_path) - calls.add((fun_name, mod_instance.__name__)) + if fun_name in func_global_dict: + func_instance = func_global_dict[fun_name] + mod_instance = __import__(func_instance.__module__) + triton_jit_kernels.update(FlagtreeBench.gather_triton_jit_kernel(mod_instance)) + if hasattr(mod_instance, '__file__'): + mod_dir_path = os.path.dirname(mod_instance.__file__) + deps_path.add(mod_dir_path) + calls.add((fun_name, mod_instance.__name__)) elif isinstance(node.func, ast.Attribute): fun_name = node.func.attr @@ -180,12 +364,9 @@ def visit_Call(self, node): triton_jit_kernels.update(FlagtreeBench.gather_triton_jit_kernel(mod_instance)) self.generic_visit(node) - Visitor().visit(tree) + Visitor().visit(attrs.ast_tree) return (calls, modules, deps_path) - def _get_package_path(self): - self.user_package_path = os.environ.get('BENCH_MODULE_PATH', '') - def _create_temp_file(self): self.python_exec = tempfile.NamedTemporaryFile(delete=False, suffix=".py") self.python_exec.close() @@ -198,7 +379,6 @@ def _write_script(self, script): f.write(script) def _exec(self): - runtime.driver.active.clear_cache(self.bench_cache) cmd = [ "ncu", "--metrics", @@ -223,10 +403,13 @@ def _pure_csv_log(self): def _get_index(self): indexs = ['avg', 'max', 'min', 'sum'] - patterns = "at::|std::" + patterns = "at::|std::|void" index_dict = dict.fromkeys(indexs, 0) df = pd.read_csv(self.out_csv.name) - metric_values = df[~df["Kernel Name"].str.contains(patterns, regex=True)][["Metric Name", "Metric Value"]] + if self.fn_trait.is_torch_method: + metric_values = df[df["Kernel Name"].str.contains(patterns, regex=True)][["Metric Name", "Metric Value"]] + else: + metric_values = df[~df["Kernel Name"].str.contains(patterns, regex=True)][["Metric Name", "Metric Value"]] for _, row in metric_values.iterrows(): metric_name = str(row['Metric Name']).split('.')[-1] gpu_time = float(row['Metric Value']) / 1e6 @@ -234,76 +417,32 @@ def _get_index(self): index_dict['mean'] = index_dict['avg'] return index_dict - def _gen_import_and_path(self, script_code: IndentedBuffer, 
path_mode='insert'): - calls, modules, deps_path = self._get_current_function_used_mod() - sys_path_action_str = '0, ' - if path_mode == 'insert': - script_code.writeline('import torch') - script_code.writeline('import os') - script_code.writeline('import sys') - else: - sys_path_action_str = '' - if self.user_package_path != '': - script_code.writeline(f"sys.path.{path_mode}({sys_path_action_str}'{self.user_package_path}')") - for path in deps_path: - if not os.path.isdir(path): - path = os.path.dirname(path) - script_code.writeline(f"sys.path.{path_mode}({sys_path_action_str}'{path}')") - if path_mode == 'insert': - for mod in modules: - script_code.writeline(f'import {mod}') - for call, mod in calls: - script_code.writeline(f"from {mod} import {call}") - def _generate_script(self, _fn=None): _fn = _fn or self._current_fn fn_src_code_string = textwrap.dedent(inspect.getsource(_fn)) script_code = IndentedBuffer() - self._gen_import_and_path(script_code, path_mode='insert') + unpacked = self._get_current_function_used_mod() + CodeGenerator._gen_import_and_path(script_code, unpacked, path_mode='insert') - script_code.writeline(fn_src_code_string) - script_code.writeline(f'{_fn.__name__}()') + if self.fn_trait.is_lambda: + CodeGenerator._gen_lambda_source_code(script_code, self.fn_trait) + else: + script_code.writeline(fn_src_code_string) + script_code.writeline(f'{_fn.__name__}()') script_code.writeline("torch.cuda.synchronize()") - self._gen_import_and_path(script_code, path_mode='remove') + CodeGenerator._gen_import_and_path(script_code, unpacked, path_mode='remove') + self.script = script_code.getvalue() self._write_script(self.script) - def _pre_operation(self, _fn=None): - ''' - Referred to triton.testing.do_bench - ''' - _fn = _fn or self._current_fn - di = runtime.driver.active.get_device_interface() - _fn() - di.synchronize() - cache = runtime.driver.active.get_empty_cache_for_benchmark() - - # Estimate the runtime of the function - start_event = di.Event(enable_timing=True) - end_event = di.Event(enable_timing=True) - start_event.record() - for _ in range(5): - runtime.driver.active.clear_cache(cache) - _fn() - end_event.record() - di.synchronize() - estimate_ms = start_event.elapsed_time(end_event) / 5 - - # compute number of warmup and repeat - n_warmup = max(1, int(self.warmup / estimate_ms)) - - self.bench_cache = cache - for _ in range(n_warmup): - _fn() - - def do_bench(self): + def do_bench(self) -> float: ''' Measure the GPU kernel time of fn() using ncu. Generate a temporary Python file and then run it with 'ncu'. 
''' - self.used_mods = self._get_current_function_used_mod() + self.fn_trait = FunctionExtractor(self._current_fn).trait self._generate_script() - self._pre_operation() + run_warmup(self._current_fn, self.warmup) self._exec() self.results = self._get_index() From bf537408b7b2561e79e80c373590332d088f628f Mon Sep 17 00:00:00 2001 From: yrl <2535184404@qq.com> Date: Thu, 11 Dec 2025 10:44:06 +0800 Subject: [PATCH 5/6] update --- python/triton/_flagtree_tools.py | 185 ++++++++++++++++--------------- 1 file changed, 96 insertions(+), 89 deletions(-) diff --git a/python/triton/_flagtree_tools.py b/python/triton/_flagtree_tools.py index 31accd787..d786dc845 100644 --- a/python/triton/_flagtree_tools.py +++ b/python/triton/_flagtree_tools.py @@ -8,6 +8,7 @@ import ast import types import pandas as pd +import torch import pickle from dataclasses import dataclass from typing import Callable, Any @@ -31,6 +32,10 @@ def flagtree_do_bench(fn, warmup=10, rep=5, quantiles=None, return_mode="mean") return bench.results[return_mode] +def get_cuda_impl(op_name): + return torch._C._dispatch_get_registrations_for_dispatch_key("CUDA").get(op_name, None) + + def check_ncu(): cmd = ["ncu", "--query-metrics"] try: @@ -44,7 +49,7 @@ def check_ncu(): return False -def run_warmup(_fn, warmup): +def function_warmup(_fn, warmup): ''' Referred to triton.testing.do_bench ''' @@ -148,10 +153,39 @@ class FuncAttrs: functor_source: str = '' Argument_serialized: bool = False Argument_serialized_path: str = '' + calls: list = None, + modules: list = None, + deps_path: list = None class CodeGenerator: + save_path: str = '' + + def gen_benchmark_python_code(self, _fn: Callable[..., Any] = None, Trait: FuncAttrs = None, save=True): + script_code = IndentedBuffer() + fn_src_code_string = textwrap.dedent(inspect.getsource(_fn)) + script_code = IndentedBuffer() + CodeGenerator._gen_import_and_path(script_code, Trait, path_mode='insert') + + if Trait.is_lambda: + CodeGenerator._gen_lambda_source_code(script_code, Trait) + else: + script_code.writeline(fn_src_code_string) + script_code.writeline(f'{_fn.__name__}()') + script_code.writeline("torch.cuda.synchronize()") + + CodeGenerator._gen_import_and_path(script_code, Trait, path_mode='remove') + self.save_script_code(script_code) + + def save_script_code(self, script_code: IndentedBuffer): + script = script_code.getvalue() + python_temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".py") + python_temp_file.close() + with open(python_temp_file.name, 'w+') as f: + f.write(script) + self.save_path = python_temp_file.name + @staticmethod def gen_load_args_kwargs_method(script_code: IndentedBuffer, Trait: FuncAttrs): script_code.writeline("def load_args_kwargs(filename):") @@ -166,7 +200,7 @@ def gen_load_args_kwargs_method(script_code: IndentedBuffer, Trait: FuncAttrs): script_code.writeline(f"{Trait.functor_source}(*args, **kwargs)") @staticmethod - def _gen_import_and_path(script_code: IndentedBuffer, unpacked, path_mode='insert'): + def _gen_import_and_path(script_code: IndentedBuffer, Trait: FuncAttrs, path_mode='insert'): sys_path_action_str = '0, ' if path_mode == 'insert': script_code.writeline('import torch') @@ -179,10 +213,10 @@ def _gen_import_and_path(script_code: IndentedBuffer, unpacked, path_mode='inser script_code.writeline(f"sys.path.{path_mode}({sys_path_action_str}'{user_package_path}')") # create extra modules - if not unpacked: + if Trait.is_lambda: return else: - calls, modules, deps_path = unpacked + calls, modules, deps_path = Trait.calls, 
Trait.modules, Trait.deps_path for path in deps_path: if not os.path.isdir(path): path = os.path.dirname(path) @@ -276,57 +310,13 @@ def handlecase_all(closure_package, fn): def analyse_general(self, fn): source = inspect.getsource(fn) - return FuncAttrs(source=source, ast_tree=ast.parse(textwrap.dedent(source)), is_lambda=False, - _globals=fn.__globals__) - - -''' - FlagtreeBench using ncu to measure performance -''' - - -class FlagtreeBench: - - def __init__(self, current_fn, warmup=10, rep=5, quantiles=None, return_mode="mean", metrics='gpu__time_duration'): - if check_ncu(): - self._current_fn = current_fn - self.metrics = metrics - self.warmup = warmup - self.rep = rep - self.quantiles = quantiles - self.return_mode = return_mode - self.triton_funcs = [] - self._create_temp_file() - print(self.python_exec.name, self.out_csv.name) - - @staticmethod - def gather_triton_jit_kernel(mod): - ''' - attrs temporarily adds specialized support to the kernel of flag_gems. - About flag_gems see https://github.com/flagos-ai/FlagGems - ''' - if FlagtreeBench.is_from_sitepackages(mod): - return set() - - kernels = set() - attrs = ['AnonymousLibTunerImpl', 'LibEntry', 'JITFunction'] - for node in dir(mod): - if node.startswith('__'): - continue - obj = getattr(mod, node) - if hasattr(obj, '__class__') and obj.__class__.__name__ in attrs: - kernels.add(node) - return kernels - - @staticmethod - def is_from_sitepackages(mod): - return 'site-packages' in mod.__file__ + ast_tree = ast.parse(textwrap.dedent(source)) + calls, modules, deps_path = self._get_current_function_used_mod(fn, ast_tree) + return FuncAttrs(source=source, ast_tree=ast_tree, is_lambda=False, _globals=fn.__globals__, calls=calls, + modules=modules, deps_path=deps_path) - def _get_current_function_used_mod(self, _fn=None): - attrs = self.fn_trait - if attrs.is_lambda: - return None - func_global_dict = attrs._globals + def _get_current_function_used_mod(self, _fn=None, ast_tree=None): + func_global_dict = _fn.__globals__ modules = set() calls = set() deps_path = set() @@ -350,7 +340,7 @@ def visit_Call(self, node): if fun_name in func_global_dict: func_instance = func_global_dict[fun_name] mod_instance = __import__(func_instance.__module__) - triton_jit_kernels.update(FlagtreeBench.gather_triton_jit_kernel(mod_instance)) + triton_jit_kernels.update(FunctionExtractor.gather_triton_jit_kernel(mod_instance)) if hasattr(mod_instance, '__file__'): mod_dir_path = os.path.dirname(mod_instance.__file__) deps_path.add(mod_dir_path) @@ -361,24 +351,59 @@ def visit_Call(self, node): if isinstance(node.func.value, ast.Name): mod = node.func.value.id mod_instance = func_global_dict[mod] - triton_jit_kernels.update(FlagtreeBench.gather_triton_jit_kernel(mod_instance)) + triton_jit_kernels.update(FunctionExtractor.gather_triton_jit_kernel(mod_instance)) self.generic_visit(node) - Visitor().visit(attrs.ast_tree) + Visitor().visit(ast_tree) return (calls, modules, deps_path) - def _create_temp_file(self): - self.python_exec = tempfile.NamedTemporaryFile(delete=False, suffix=".py") - self.python_exec.close() + @staticmethod + def gather_triton_jit_kernel(mod): + ''' + attrs temporarily adds specialized support to the kernel of flag_gems. 
+ About flag_gems see https://github.com/flagos-ai/FlagGems + ''' + if FunctionExtractor.is_from_sitepackages(mod): + return set() + + kernels = set() + attrs = ['AnonymousLibTunerImpl', 'LibEntry', 'JITFunction'] + for node in dir(mod): + if node.startswith('__'): + continue + obj = getattr(mod, node) + if hasattr(obj, '__class__') and obj.__class__.__name__ in attrs: + kernels.add(node) + return kernels + + @staticmethod + def is_from_sitepackages(mod): + return 'site-packages' in mod.__file__ + + +''' + FlagtreeBench using ncu to measure performance +''' + + +class FlagtreeBench: + + def __init__(self, current_fn, warmup=10, rep=5, quantiles=None, return_mode="mean", metrics='gpu__time_duration'): + if check_ncu(): + self._current_fn = current_fn + self.metrics = metrics + self.warmup = warmup + self.rep = rep + self.quantiles = quantiles + self.return_mode = return_mode + self._create_temp_file() + def _create_temp_file(self): self.out_csv = tempfile.NamedTemporaryFile(delete=False, suffix=".csv") self.out_csv.close() - def _write_script(self, script): - with open(self.python_exec.name, 'w+') as f: - f.write(script) - - def _exec(self): + def run_code_script(self): + path = self.code_instance.save_path cmd = [ "ncu", "--metrics", @@ -387,9 +412,9 @@ def _exec(self): "--log-file", self.out_csv.name, "python3", - self.python_exec.name, + path, ] - print(f"[INFO]: ncu running on {self.python_exec.name}") + print(f"[INFO]: ncu running on {path}") subprocess.run(cmd, check=True) self._pure_csv_log() @@ -417,32 +442,14 @@ def _get_index(self): index_dict['mean'] = index_dict['avg'] return index_dict - def _generate_script(self, _fn=None): - _fn = _fn or self._current_fn - fn_src_code_string = textwrap.dedent(inspect.getsource(_fn)) - script_code = IndentedBuffer() - unpacked = self._get_current_function_used_mod() - CodeGenerator._gen_import_and_path(script_code, unpacked, path_mode='insert') - - if self.fn_trait.is_lambda: - CodeGenerator._gen_lambda_source_code(script_code, self.fn_trait) - else: - script_code.writeline(fn_src_code_string) - script_code.writeline(f'{_fn.__name__}()') - script_code.writeline("torch.cuda.synchronize()") - - CodeGenerator._gen_import_and_path(script_code, unpacked, path_mode='remove') - - self.script = script_code.getvalue() - self._write_script(self.script) - def do_bench(self) -> float: ''' Measure the GPU kernel time of fn() using ncu. Generate a temporary Python file and then run it with 'ncu'. 
''' self.fn_trait = FunctionExtractor(self._current_fn).trait - self._generate_script() - run_warmup(self._current_fn, self.warmup) - self._exec() + self.code_instance = CodeGenerator() + self.code_instance.gen_benchmark_python_code(_fn=self._current_fn, Trait=self.fn_trait) + function_warmup(self._current_fn, self.warmup) + self.run_code_script() self.results = self._get_index() From 158199ff40391340aee8c8ac02e92a1f2a5629c6 Mon Sep 17 00:00:00 2001 From: yrl <2535184404@qq.com> Date: Thu, 11 Dec 2025 16:24:11 +0800 Subject: [PATCH 6/6] update --- python/triton/_flagtree_tools.py | 41 +++++++++++++++++++++++++------- 1 file changed, 32 insertions(+), 9 deletions(-) diff --git a/python/triton/_flagtree_tools.py b/python/triton/_flagtree_tools.py index d786dc845..7e4a00d21 100644 --- a/python/triton/_flagtree_tools.py +++ b/python/triton/_flagtree_tools.py @@ -1,5 +1,6 @@ import subprocess import os +import sys import textwrap import inspect import tempfile @@ -156,6 +157,7 @@ class FuncAttrs: calls: list = None, modules: list = None, deps_path: list = None + is_flaggems_functor: bool = False class CodeGenerator: @@ -195,9 +197,14 @@ def gen_load_args_kwargs_method(script_code: IndentedBuffer, Trait: FuncAttrs): with script_code.indent(): script_code.writeline("data = pickle.load(f)") script_code.writeline("return data['args'], data['kwargs']") - script_code.writeline(f"args, kwargs = load_args_kwargs('{Trait.Argument_serialized_path}')") - script_code.writeline(f"{Trait.functor_source}(*args, **kwargs)") + if Trait.is_flaggems_functor: + script_code.writeline("import flag_gems") + script_code.writeline("with flag_gems.use_gems():") + with script_code.indent(): + script_code.writeline(f"{Trait.functor_source}(*args, **kwargs)") + else: + script_code.writeline(f"{Trait.functor_source}(*args, **kwargs)") @staticmethod def _gen_import_and_path(script_code: IndentedBuffer, Trait: FuncAttrs, path_mode='insert'): @@ -295,7 +302,7 @@ def handlecase_all(closure_package, fn): return FuncAttrs(_globals=fn.__globals__, source=inspect.getsource(fn), is_torch_method=is_torch_method, Argument_serialized=True, Argument_serialized_path=serialized_path, functor_source=functor_source, ast_tree=ast.parse(textwrap.dedent(source)), - is_lambda=is_lambda) + is_lambda=is_lambda, is_flaggems_functor=FunctionExtractor.is_flaggems_operator()) casehandlers = [handlecase_single, handlecase_both, handlecase_all] casehandlers_mapping = {} @@ -312,8 +319,16 @@ def analyse_general(self, fn): source = inspect.getsource(fn) ast_tree = ast.parse(textwrap.dedent(source)) calls, modules, deps_path = self._get_current_function_used_mod(fn, ast_tree) - return FuncAttrs(source=source, ast_tree=ast_tree, is_lambda=False, _globals=fn.__globals__, calls=calls, - modules=modules, deps_path=deps_path) + return FuncAttrs( + source=source, + ast_tree=ast_tree, + is_lambda=False, + _globals=fn.__globals__, + calls=calls, + modules=modules, + deps_path=deps_path, + is_flaggems_functor=FunctionExtractor.is_flaggems_operator(), + ) def _get_current_function_used_mod(self, _fn=None, ast_tree=None): func_global_dict = _fn.__globals__ @@ -380,6 +395,15 @@ def gather_triton_jit_kernel(mod): def is_from_sitepackages(mod): return 'site-packages' in mod.__file__ + @staticmethod + def is_flaggems_operator(): + try: + import flag_gems + with flag_gems.use_gems(): + return False + except Exception: + return True + ''' FlagtreeBench using ncu to measure performance @@ -411,11 +435,10 @@ def run_code_script(self): "--csv", "--log-file", 
self.out_csv.name, - "python3", + sys.executable, path, ] - print(f"[INFO]: ncu running on {path}") - subprocess.run(cmd, check=True) + subprocess.Popen(cmd, text=True).communicate() self._pure_csv_log() def _pure_csv_log(self): @@ -431,7 +454,7 @@ def _get_index(self): patterns = "at::|std::|void" index_dict = dict.fromkeys(indexs, 0) df = pd.read_csv(self.out_csv.name) - if self.fn_trait.is_torch_method: + if self.fn_trait.is_torch_method and not self.fn_trait.is_flaggems_functor: metric_values = df[df["Kernel Name"].str.contains(patterns, regex=True)][["Metric Name", "Metric Value"]] else: metric_values = df[~df["Kernel Name"].str.contains(patterns, regex=True)][["Metric Name", "Metric Value"]]
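
---

A minimal usage sketch of the final flagtree_do_bench API, assuming the
kernel lives in an importable module so the generated script can re-import
it; add_kernel, the grid, and the sizes below are illustrative only:

    import torch
    import triton
    import triton.language as tl
    from triton._flagtree_tools import flagtree_do_bench

    @triton.jit
    def add_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
        pid = tl.program_id(axis=0)
        offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        x = tl.load(x_ptr + offsets, mask=mask)
        y = tl.load(y_ptr + offsets, mask=mask)
        tl.store(out_ptr + offsets, x + y, mask=mask)

    def test():
        n = 1 << 20
        x = torch.randn(n, device="cuda")
        y = torch.randn(n, device="cuda")
        out = torch.empty_like(x)
        add_kernel[(triton.cdiv(n, 1024), )](x, y, out, n, BLOCK_SIZE=1024)

    # Aggregated gpu__time_duration over the profiled kernels; _get_index
    # divides the raw metric value by 1e6, i.e. nanoseconds -> milliseconds
    # where ncu reports this metric in ns.
    ms = flagtree_do_bench(test, warmup=10, rep=5, return_mode="mean")
    print(f"mean kernel time: {ms} ms")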