From 3bffdfb7bb958220e0c9298edd9396e64aca24f2 Mon Sep 17 00:00:00 2001 From: ConvolutedDog Date: Sun, 28 Sep 2025 21:15:16 +0800 Subject: [PATCH 1/3] [RFC] Clean up debug code and improve output messages --- codegen/op_impl/codegenR.py | 223 +++------------------------------ op/Op.py | 10 +- policy/ConstructionPolicyRT.py | 8 +- test_op.py | 45 ++++--- test_op_mp.py | 34 ++--- 5 files changed, 70 insertions(+), 250 deletions(-) diff --git a/codegen/op_impl/codegenR.py b/codegen/op_impl/codegenR.py index 46a211c..6d5295b 100644 --- a/codegen/op_impl/codegenR.py +++ b/codegen/op_impl/codegenR.py @@ -16,17 +16,6 @@ ) from utils import LatestTVM, get_axis_names, get_blocks -""" -from tvm.topi.cuda.tensor_intrin import ( - intrin_wmma_load_matrix_A, - intrin_wmma_load_matrix_W, - intrin_wmma_store_matrix, - intrin_wmma_gemm, -) -""" - -DEBUG = False - class CodeGeneratorR: def get_codegen_dict(self, rprog): @@ -58,14 +47,6 @@ def split_axis( if sche == None: sche = self.sche - if DEBUG: - print("self.tiling: ", self.tiling) - if hasattr(self, "iter_vars_map"): - print("self.iter_vars_map: ", self.iter_vars_map) - if hasattr(self, "iter_vars_map_reverse"): - print("self.iter_vars_map_reverse: ", self.iter_vars_map_reverse) - print(axis, type(axis)) - factors = None # If axis is loop_var, it has no attr 'var'. @@ -106,10 +87,6 @@ def split_axis( ) ] - if DEBUG: - print("iter_vars_compute_local:", iter_vars_compute_local) - print("iter_vars_compute:", iter_vars_compute) - self.iter_vars_map = {} self.iter_vars_map_reverse = {} for i in range(len(iter_vars_compute_local)): @@ -141,10 +118,6 @@ def split_axis( block_compute_local = self.sche.get_block(target_stage) loops = self.sche.get_loops(block_compute_local) - if DEBUG: - print("iter_vars_map: ", self.iter_vars_map) - print("iter_vars_map_reverse: ", self.iter_vars_map_reverse) - target_loop = axis if target_loop is None: raise ValueError(f"Unfound axis") @@ -154,8 +127,6 @@ def split_axis( math.ceil(int(axis.dom.extent) / int(np.prod(factors))), ] + factors[:] else: - if DEBUG: - print(self.sche.get(target_loop).extent) new_factors = [ math.ceil( int(self.sche.get(target_loop).extent) / int(np.prod(factors)) @@ -164,8 +135,6 @@ def split_axis( axises = sche.split(target_loop, factors=new_factors) - if DEBUG: - self.sche.show() return axises else: @@ -179,8 +148,6 @@ def split_axis( def update_thread_per_block(self, stage, sche=None, vthread=True): num = 1 - if DEBUG: - print(self.tiling) if LatestTVM: saxis, _ = get_axis_names(stage) # stage is out for name in saxis: @@ -207,8 +174,6 @@ def cooperative_fetch( fused_loops.append(looprv) fused = sch.fuse(*fused_loops) - if DEBUG: - sch.show() if not isinstance(sch.get(fused).extent, int): assert self.bank_size % 4 == 0 and self.bank_size >= 4 @@ -224,8 +189,6 @@ def cooperative_fetch( self.bank_size // 4 ] # right? - if DEBUG: - print(new_factors) fused, ii_n = sch.split(fused, factors=new_factors) if not isinstance(sch.get(fused).extent, int): @@ -294,7 +257,6 @@ def calc_grid(self, reduce_iters, space_iters, vthread=True): thrd_dict["threadIdx.y"], thrd_dict["threadIdx.z"], ] - # print("blck_grid: ", self.blck_grid, "thrd_grid: ", self.thrd_grid) def adjust_format(self, out): if LatestTVM: @@ -311,7 +273,6 @@ def adjust_format(self, out): vthrd = self.tiling[name][1] thrd = self.tiling[name][0] self.tiling[name] = [vthrd, thrd, 1] - # print("Config:", self.tiling) # [Parameters] # schedule: the original TVM schedule of an op @@ -409,8 +370,6 @@ def rewrite_schedule( reg_tile = None if LatestTVM: - if DEBUG: - self.sche.mod.show() if self.need_smem_tiling: read_buffer_index = 0 for input_tensor in input_tensors: @@ -420,8 +379,6 @@ def rewrite_schedule( ) read_buffer_index += 1 smem_tensor.append(shared_tensor) - if DEBUG: - self.sche.show() if self.need_reg_tiling: read_buffer_index = 0 if codegen_input_reg_tiling: @@ -435,8 +392,6 @@ def rewrite_schedule( reg_tile = self.sche.cache_write( block_compute, write_buffer_index=0, storage_scope="local" ) - if DEBUG: - self.sche.show() else: if self.need_smem_tiling: for input_tensor in input_tensors: @@ -453,12 +408,6 @@ def rewrite_schedule( reg_tensor.append(local_tensor) reg_tile = self.sche.cache_write(out, "local") - if DEBUG: - mod = tvm.lower( - self.sche, in_tensors + out_tensors, simple_mode=False - ) - print(mod.script()) - blck_axis = [] vthd_axis = [] thrd_axis = [] @@ -478,13 +427,13 @@ def rewrite_schedule( -1 ] * (self.bank_size // 4) else: - print("shared mem tiling is too small.") + print("Shared mem tiling is too small.") self.tiling[axis.var.name][-1] = ( self.tiling[axis.var.name][-1] * self.tiling[axis.var.name][-3] ) self.tiling[axis.var.name][-3] = 1 - print("updated self.tiling: ", self.tiling) + print("Updated self.tiling: ", self.tiling) bx, vx, tx, tn = self.split_axis( out, @@ -494,13 +443,6 @@ def rewrite_schedule( all_tensors=in_tensors + out_tensors, ) - if DEBUG: - if not LatestTVM: - mod = tvm.lower( - self.sche, in_tensors + out_tensors, simple_mode=False - ) - print(mod.script()) - blck_axis.append(bx) vthd_axis.append(vx) thrd_axis.append(tx) @@ -511,8 +453,6 @@ def rewrite_schedule( ) looprvs = self.sche.get_loops(block_compute) - print([self.sche.get(looprvs[i]).loop_var for i in range(len(looprvs))]) - iter_vars = self.sche.get(block_compute).iter_vars iter_types = [iter_var.iter_type for iter_var in iter_vars] @@ -538,12 +478,12 @@ def rewrite_schedule( self.bank_size // 4 ) else: - print("shared mem tiling is too small.") + print("Shared mem tiling is too small.") self.tiling[loop_var][-1] = ( self.tiling[loop_var][-1] * self.tiling[loop_var][-3] ) self.tiling[loop_var][-3] = 1 - print("updated self.tiling: ", self.tiling) + print("Updated self.tiling: ", self.tiling) bx, vx, tx, tn = self.split_axis( None, @@ -565,15 +505,6 @@ def rewrite_schedule( else: self.sche[out].reorder(*axis_order) - if DEBUG: - if not LatestTVM: - mod = tvm.lower( - self.sche, in_tensors + out_tensors, simple_mode=False - ) - print(mod.script()) - else: - self.sche.show() - if LatestTVM: blck_fused = self.sche.fuse(*blck_axis) thrd_fused = self.sche.fuse(*thrd_axis) @@ -581,15 +512,6 @@ def rewrite_schedule( blck_fused = self.sche[out].fuse(*blck_axis) thrd_fused = self.sche[out].fuse(*thrd_axis) - if DEBUG: - if not LatestTVM: - mod = tvm.lower( - self.sche, in_tensors + out_tensors, simple_mode=False - ) - print(mod.script()) - else: - self.sche.show() - if self.binding["space"][0] is not None: if LatestTVM: self.sche.bind(blck_fused, self.binding["space"][0]) @@ -598,15 +520,6 @@ def rewrite_schedule( blck_fused, te.thread_axis(self.binding["space"][0]) ) - if DEBUG: - if not LatestTVM: - mod = tvm.lower( - self.sche, in_tensors + out_tensors, simple_mode=False - ) - print(mod.script()) - else: - self.sche.show() - if self.binding["space"][1] is not None: if LatestTVM: vthd_map = {0: ".x", 1: ".y", 2: ".z"} @@ -619,15 +532,6 @@ def rewrite_schedule( va, te.thread_axis(self.binding["space"][1]) ) - if DEBUG: - if not LatestTVM: - mod = tvm.lower( - self.sche, in_tensors + out_tensors, simple_mode=False - ) - print(mod.script()) - else: - self.sche.show() - if self.binding["space"][2] is not None: if LatestTVM: self.sche.bind(thrd_fused, self.binding["space"][2]) @@ -636,26 +540,13 @@ def rewrite_schedule( thrd_fused, te.thread_axis(self.binding["space"][2]) ) - if DEBUG: - if not LatestTVM: - mod = tvm.lower( - self.sche, in_tensors + out_tensors, simple_mode=False - ) - print(mod.script()) - else: - self.sche.show() - reduce_axis = [] if reg_tile is not None: if LatestTVM: self.sche.compute_at(target_stage, thrd_fused) - if DEBUG: - self.sche.show() space_axis = [] block_compute = self.sche.get_block(target_stage) - if DEBUG: - print(self.sche.get(block_compute).iter_vars) loops = self.sche.get_loops(block_compute) @@ -702,12 +593,6 @@ def rewrite_schedule( target_stage += "_update" for i in range(len(loops)): - if DEBUG: - print( - "loops: ", - self.sche.get(loops[i]).loop_var, - self.sche.get(loops[i]).kind, - ) if tvm.tir.ForKind.SERIAL == self.sche.get(loops[i]).kind: reduce_axis.append(loops[i]) else: @@ -719,10 +604,6 @@ def rewrite_schedule( ] space_axis.append(loops[i]) - if DEBUG: - print("space_axis: ", space_axis) - print("reduce_axis: ", reduce_axis) - new_reduce_axis = [] for axis in reduce_axis.copy(): @@ -738,21 +619,12 @@ def rewrite_schedule( ) reduce_axis = reduce_axis + res new_reduce_axis = new_reduce_axis + res - if DEBUG: - self.sche.show() axis_order = new_reduce_axis + space_axis - if DEBUG: - print("-" * 100) - for axis in axis_order: - print(axis, end=" ") - print(self.sche.get(axis).loop_var) - print("-" * 100) - - # self.sche.reorder(*axis_order) # ERROR ERROR ERROR ERROR ERROR ERROR ERROR ERROR ERROR ERROR - # space_fused = self.sche.fuse(*space_axis) # ERROR ERROR ERROR ERROR ERROR ERROR ERROR ERROR ERROR ERROR - # self.sche.unroll(space_fused) # ERROR ERROR ERROR ERROR ERROR ERROR ERROR ERROR ERROR ERROR + # self.sche.reorder(*axis_order) # TODO: ERROR + # space_fused = self.sche.fuse(*space_axis) # TODO: ERROR + # self.sche.unroll(space_fused) # TODO: ERROR else: self.sche[reg_tile].compute_at(self.sche[out], thrd_fused) space_axis = [] @@ -762,23 +634,15 @@ def rewrite_schedule( res = self.split_axis(reg_tile, axis) reduce_axis = reduce_axis + res axis_order = reduce_axis + space_axis - # print('axis_order', axis_order) - # print("[Split reduction axis]\n", axis_order) self.sche[reg_tile].reorder(*axis_order) space_fused = self.sche[reg_tile].fuse(*space_axis) self.sche[reg_tile].unroll(space_fused) else: if LatestTVM: - if DEBUG: - self.sche.show() - space_axis = [] block_compute = self.sche.get_block(target_stage) - if DEBUG: - print(self.sche.get(block_compute).iter_vars) loops = self.sche.get_loops(block_compute) - print([self.sche.get(loops[i]).loop_var for i in range(len(loops))]) # Purpose: Optimize reduction operations by hoisting initialization # out of hot loops @@ -821,8 +685,6 @@ def rewrite_schedule( self.sche.decompose_reduction(block_compute, loops[3]) ) target_stage += "_update" - if DEBUG: - self.sche.show() block_compute = self.sche.get_block(target_stage) iter_vars = self.sche.get(block_compute).iter_vars @@ -838,12 +700,6 @@ def rewrite_schedule( else: space_axis.append(loops[i]) else: - if DEBUG: - print( - "loops: ", - self.sche.get(loops[i]).loop_var, - self.sche.get(loops[i]).kind, - ) if tvm.tir.ForKind.SERIAL == self.sche.get(loops[i]).kind: reduce_axis.append(loops[i]) else: @@ -855,21 +711,6 @@ def rewrite_schedule( ] space_axis.append(loops[i]) - if DEBUG: - print( - "space_axis: ", - [ - self.sche.get(space_axis[i]).loop_var - for i in range(len(space_axis)) - ], - ) - print( - "reduce_axis: ", - [ - self.sche.get(reduce_axis[i]).loop_var - for i in range(len(reduce_axis)) - ], - ) new_reduce_axis = [] for axis in reduce_axis.copy(): @@ -884,15 +725,13 @@ def rewrite_schedule( ) reduce_axis = reduce_axis + res new_reduce_axis = new_reduce_axis + res - if DEBUG: - self.sche.show() # axis_order = reduce_axis + space_axis axis_order = new_reduce_axis + space_axis - # self.sche.reorder(*axis_order) # ERROR ERROR ERROR ERROR ERROR ERROR ERROR ERROR ERROR ERROR - # space_fused = self.sche.fuse(*space_axis) # ERROR ERROR ERROR ERROR ERROR ERROR ERROR ERROR ERROR ERROR - # self.sche.unroll(space_fused) # ERROR ERROR ERROR ERROR ERROR ERROR ERROR ERROR ERROR ERROR + # self.sche.reorder(*axis_order) # TODO: ERROR + # space_fused = self.sche.fuse(*space_axis) # TODO: ERROR + # self.sche.unroll(space_fused) # TODO: ERROR else: for axis in self.sche[out].op.reduce_axis: res = self.split_axis(out, axis) @@ -902,37 +741,22 @@ def rewrite_schedule( self.sche[out].bind(reduce_axis[1], bind_idx) self.sche[out].set_store_predicate(bind_idx.var.equal(0)) - if DEBUG: - if not LatestTVM: - mod = tvm.lower( - self.sche, in_tensors + out_tensors, simple_mode=False - ) - print(mod.script()) - else: - self.sche.show() - if reg_tile is not None: if LatestTVM: for rt in reg_tensor: self.sche.compute_at(rt, new_reduce_axis[-1]) for st in smem_tensor: - if DEBUG: - print(dir(st)) old_axis = [ str(self.sche.get(self.sche.get_loops(st)[i]).loop_var) for i in range(len(self.sche.get_loops(st))) ] self.sche.compute_at(st, new_reduce_axis[0]) - if DEBUG: - self.sche.show() self.cooperative_fetch( st, self.sche, old_axis, shared_fetch_vectorize=shared_fetch_vectorize, ) - if DEBUG: - self.sche.show() else: for rt in reg_tensor: self.sche[rt].compute_at(self.sche[reg_tile], reduce_axis[-1]) @@ -949,23 +773,17 @@ def rewrite_schedule( for rt in reg_tensor: self.sche.compute_at(rt, new_reduce_axis[-1]) for st in smem_tensor: - if DEBUG: - print(dir(st)) old_axis = [ str(self.sche.get(self.sche.get_loops(st)[i]).loop_var) for i in range(len(self.sche.get_loops(st))) ] self.sche.compute_at(st, new_reduce_axis[0]) - if DEBUG: - self.sche.show() self.cooperative_fetch( st, self.sche, old_axis, shared_fetch_vectorize=shared_fetch_vectorize, ) - if DEBUG: - self.sche.show() else: for rt in reg_tensor: self.sche[rt].compute_at(self.sche[out], reduce_axis[-1]) @@ -978,15 +796,6 @@ def rewrite_schedule( shared_fetch_vectorize=shared_fetch_vectorize, ) - if DEBUG: - if not LatestTVM: - mod = tvm.lower( - self.sche, in_tensors + out_tensors, simple_mode=False - ) - print(mod.script()) - else: - self.sche.show() - if LatestTVM: for block_name in old_blocks: if block_name not in op_names: @@ -1052,8 +861,6 @@ def rewrite_schedule_fuse( # align_info = self.get_align_info_fuse(schedule, rprog, smem_bool, reg_bool, target_stage, st_align, bank_size, bank_number) for out in output_tensors: - # print('reduce:', self.sche[out].op.reduce_axis) - # print('space:', self.sche[out].op.axis) self.adjust_format(out) # TVM only allows binding reduce axis if it's the only one if self.binding["reduce"][1] is not None: @@ -1064,12 +871,11 @@ def rewrite_schedule_fuse( reduce_iters = out.op.reduce_axis space_iters = list(set(all_iters) - set(reduce_iters)) self.calc_grid(reduce_iters, space_iters) - # print("Target: {}\nSpace Iters: {}\nReduce Iters: {}\n".format(out, space_iters, reduce_iters)) smem_tensor = [] reg_tensor = [] reg_tile = self.sche.cache_write(out, "local") - # print("[Add cache stage]") + if self.need_smem_tiling: for input_tensor in input_tensors: self.sche[input_tensor].compute_inline() @@ -1101,13 +907,13 @@ def rewrite_schedule_fuse( -1 ] * (self.bank_size // 4) else: - print("shared mem tiling is too small.") + print("Shared mem tiling is too small.") self.tiling[axis.var.name][-1] = ( self.tiling[axis.var.name][-1] * self.tiling[axis.var.name][-3] ) self.tiling[axis.var.name][-3] = 1 - print("updated self.tiling: ", self.tiling) + print("Updated self.tiling: ", self.tiling) bx, vx, tx, tn = self.split_axis(out, axis) # bx, tx, tn = self.split_axis(out, axis) blck_axis.append(bx) @@ -1115,7 +921,7 @@ def rewrite_schedule_fuse( thrd_axis.append(tx) tile_axis.append(tn) axis_order = blck_axis + vthd_axis + thrd_axis + tile_axis - # print("[Split spatial axis]\n", axis_order) + self.sche[out].reorder(*axis_order) blck_fused = self.sche[out].fuse(*blck_axis) thrd_fused = self.sche[out].fuse(*thrd_axis) @@ -1153,7 +959,6 @@ def rewrite_schedule_fuse( self.sche[out].bind(reduce_axis[1], bind_idx) self.sche[out].set_store_predicate(bind_idx.var.equal(0)) - # print("[Cooperative fetching]") if reg_tile is not None: for rt in reg_tensor: self.sche[rt].compute_at(self.sche[reg_tile], reduce_axis[-1]) diff --git a/op/Op.py b/op/Op.py index 30c479a..f98eb7c 100644 --- a/op/Op.py +++ b/op/Op.py @@ -36,11 +36,11 @@ def __init__(self, expr, shape, data_type, use_tc=False) -> None: self.saxis, self.raxis = get_axis_names(self.output_tensors[0]) if LatestTVM: - if len(self.unpad_outs) > 0: - import warnings - warnings.warn( - "The unpad_outs length > 0, please check here." - ) + # if len(self.unpad_outs) > 0: + # import warnings + # warnings.warn( + # "The unpad_outs length > 0, please check here." + # ) fadd_pf = te.create_prim_func(self.input_tensors + self.output_tensors) mod = tvm.IRModule({"main": fadd_pf}) # Create a TIR schedule diff --git a/policy/ConstructionPolicyRT.py b/policy/ConstructionPolicyRT.py index f99ac02..ee53dcf 100644 --- a/policy/ConstructionPolicyRT.py +++ b/policy/ConstructionPolicyRT.py @@ -538,10 +538,14 @@ def emit_config_without_trails(self, topk): if len(self.top_results) == 0: self.top_results = self.border_rprogs[0][: self.TOPK] if len(self.top_results) == 0: - print("failed to find results with padding threshold {}".format(th)) + print( + "Failed to find results with padding threshold {} (padding_threshold_cap)".format( + th + ) + ) else: print( - "found {} results with threshold {}".format( + "Found {} results with threshold {} (padding_threshold_cap)".format( len(self.top_results), th ) ) diff --git a/test_op.py b/test_op.py index 61ee4ab..dba6863 100644 --- a/test_op.py +++ b/test_op.py @@ -67,7 +67,7 @@ ) # Generate result checking code for each kernel. parser.add_argument( - "--gen_check_code", dest="gen_check_code", action="store_true", default=True + "--gen_check_code", dest="gen_check_code", action="store_true", default=False ) parser.add_argument("--code_dir", type=str, default="./tmp_dir") parser.add_argument("--topk", type=int, default=10) @@ -78,7 +78,12 @@ parser.add_argument("--data_type", type=str, default="float32") parser.add_argument("--padding_threshold_cap", type=float, default=1.0) parser.add_argument("--keep_tiny", dest="keep_tiny", action="store_true") - +parser.add_argument( + "--verbose_cuda_code", dest="verbose_cuda_code", action="store_true", default=False +) +parser.add_argument( + "--verbose_irmodule", dest="verbose_irmodule", action="store_true", default=False +) args = parser.parse_args() @@ -316,7 +321,6 @@ def get_tvm_source( out_tensor = out_tensors[0] if args.fuse or args.schedule_fuse: pad = get_pad(rprog, out_tensor) - print("pad: ", pad) expr_out = expr(shape, dtype, False, pad) in_tensors, out_tensors = expr_out[0], expr_out[1] ori_in = [] @@ -403,16 +407,20 @@ def get_tvm_source( codegen_input_reg_tiling=args.codegen_input_reg_tiling, ) if LatestTVM: - print(s.mod) - target = tvm.target.Target("cuda") - mod = tvm.build(s.mod, target=target) + if args.verbose_irmodule: + print(s.mod) + + mod = tvm.build(s.mod, target=target) return mod.imported_modules[0].get_source() else: s.normalize() mod = tvm.lower(s, in_tensors + out_tensors, simple_mode=False) + if args.verbose_irmodule: + print(mod.script()) + func = tvm.build(s, in_tensors + out_tensors, "cuda") return func.imported_modules[0].get_source() @@ -470,7 +478,7 @@ def get_tvm_source( else: rprogs = policy.emit_config_without_trails(args.topk) - print("evaluating top {} configs".format(len(rprogs))) + print("Evaluating top {} configs".format(len(rprogs))) best_idx = -1 best_time = 1e100 idx = 0 @@ -480,8 +488,7 @@ def get_tvm_source( bar_id = 0 dtype = "float16" if args.use_tc else "float32" for rprog in rprogs: - print("id: {}".format(idx)) - print(rprog.Dump()) + print("rProg: ", rprog.Dump()) block_size = rprog.GetParallelism(1) * (32 if args.use_tc else 1) grid_size = rprog.GetParallelism(0) blocks = (block_size, 1, 1) @@ -532,9 +539,10 @@ def get_tvm_source( with open("{}.cu".format(file_name), "w") as ouf: ouf.write(main_source) - print("v" * 40) - print(main_source) - print("^" * 40) + if args.verbose_cuda_code: + print("v" * 40) + print(main_source) + print("^" * 40) os.system( "nvcc {}.cu -lcuda -gencode=arch=compute_{},code=compute_{} -o {}".format( @@ -556,13 +564,13 @@ def get_tvm_source( os.system("rm {}".format(file_name)) os.system("rm {}.cu".format(file_name)) - print("LOG_NAME: {}".format(log_name)) with open(log_name, "r") as f: for line in f.readlines(): print(line, end="") exec_time = get_time_from_nvprof_file(log_name) os.system("rm {}".format(log_name)) + if exec_time < best_time: best_idx = idx best_rprog = rprog @@ -573,7 +581,6 @@ def get_tvm_source( best_grid_size = grid_size idx += 1 - print(idx, bar_id) if idx == eval_bar[bar_id]: cur_time = time.time() eval_results = {} @@ -585,11 +592,11 @@ def get_tvm_source( bar_id += 1 for topx, eval_results in zip(eval_bar, evals): - print("Eval top {} configs ======================".format(topx)) - print("compilation time: {}s".format(eval_results["compilation time"])) - print("best time: {}ms".format(eval_results["best time"])) - print("best config: {}".format(eval_results["best config"])) - print("best idx: {}".format(eval_results["best idx"])) + print("Eval top {} configs".format(topx)) + print("Compilation time: {}s".format(eval_results["compilation time"])) + print("Best time: {}ms".format(eval_results["best time"])) + print("Best config: {}".format(eval_results["best config"])) + print("Best idx: {}".format(eval_results["best idx"])) cu_file_name = "roller_{}_{}.cu".format( args.op, "_".join([str(d) for d in args.shape]) diff --git a/test_op_mp.py b/test_op_mp.py index 03fda25..f584fe1 100644 --- a/test_op_mp.py +++ b/test_op_mp.py @@ -67,7 +67,7 @@ ) # Generate result checking code for each kernel. parser.add_argument( - "--gen_check_code", dest="gen_check_code", action="store_true", default=True + "--gen_check_code", dest="gen_check_code", action="store_true", default=False ) parser.add_argument("--code_dir", type=str, default="./tmp_dir") parser.add_argument("--topk", type=int, default=10) @@ -81,6 +81,12 @@ # If you have several GPUs with the same architecture, you can change # the num_threads to run them in parallel. parser.add_argument("--num_threads", type=int, default=1) +parser.add_argument( + "--verbose_cuda_code", dest="verbose_cuda_code", action="store_true", default=False +) +parser.add_argument( + "--verbose_irmodule", dest="verbose_irmodule", action="store_true", default=False +) args = parser.parse_args() top1_time = 0 @@ -319,7 +325,6 @@ def get_tvm_source( out_tensor = out_tensors[0] if args.fuse or args.schedule_fuse: pad = get_pad(rprog, out_tensor) - print("pad: ", pad) expr_out = expr(shape, dtype, False, pad) in_tensors, out_tensors = expr_out[0], expr_out[1] ori_in = [] @@ -406,8 +411,6 @@ def get_tvm_source( codegen_input_reg_tiling=args.codegen_input_reg_tiling, ) if LatestTVM: - print(s.mod) - target = tvm.target.Target("cuda") mod = tvm.build(s.mod, target=target) @@ -435,6 +438,7 @@ def compile_and_run_kernel( device_id, idx, ): + print("rProg: ", rprog.Dump()) block_size = rprog.GetParallelism(1) * (32 if args.use_tc else 1) grid_size = rprog.GetParallelism(0) blocks = (block_size, 1, 1) @@ -488,9 +492,10 @@ def compile_and_run_kernel( with open("{}.cu".format(file_name), "w") as ouf: ouf.write(main_source) - print("v" * 40) - print(main_source) - print("^" * 40) + if args.verbose_cuda_code: + print("v" * 40) + print(main_source) + print("^" * 40) os.system( "nvcc {}.cu -lcuda -gencode=arch=compute_{},code=compute_{} -o {}".format( @@ -512,7 +517,6 @@ def compile_and_run_kernel( os.system("rm {}".format(file_name)) os.system("rm {}.cu".format(file_name)) - print("LOG_NAME: {}".format(log_name)) with open(log_name, "r") as f: for line in f.readlines(): print(line, end="") @@ -613,7 +617,7 @@ def eval_thread( else: rprogs = policy.emit_config_without_trails(args.topk) - print("evaluating top {} configs".format(len(rprogs))) + print("Evaluating top {} configs".format(len(rprogs))) rprog_idx = alloc_configs_for_subprocess(args.num_threads, len(rprogs)) threads = [] @@ -646,12 +650,12 @@ def eval_thread( eval_time = time.time() - start_time - print("top1 time: {} ms".format(top1_time)) - print("top10 time: {} ms".format(best_time)) - print("best idx: {}".format(best_idx)) - print("best config: {}".format(rprogs[best_idx].Dump())) - print("top1 compile time: {} s".format(emit_time)) - print("top10 compile time: {} s".format(eval_time)) + print("Top1 time: {} ms".format(top1_time)) + print("Top10 time: {} ms".format(best_time)) + print("Best idx: {}".format(best_idx)) + print("Best config: {}".format(rprogs[best_idx].Dump())) + print("Top1 compile time: {} s".format(emit_time)) + print("Top10 compile time: {} s".format(eval_time)) cu_file_name = "roller_{}_{}.cu".format( args.op, "_".join([str(d) for d in args.shape]) From 54892687d02909c9b7916a3f70addd07aa6d9613 Mon Sep 17 00:00:00 2001 From: ConvolutedDog Date: Sun, 28 Sep 2025 21:22:29 +0800 Subject: [PATCH 2/3] Update test_op_mp.py Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- test_op_mp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test_op_mp.py b/test_op_mp.py index f584fe1..07cceb1 100644 --- a/test_op_mp.py +++ b/test_op_mp.py @@ -438,7 +438,7 @@ def compile_and_run_kernel( device_id, idx, ): - print("rProg: ", rprog.Dump()) + print(f"rProg[{idx}]: {rprog.Dump()}") block_size = rprog.GetParallelism(1) * (32 if args.use_tc else 1) grid_size = rprog.GetParallelism(0) blocks = (block_size, 1, 1) From 46bcdfe05100ad4b514fafa8e546061ed1ab2450 Mon Sep 17 00:00:00 2001 From: ConvolutedDog Date: Sun, 28 Sep 2025 21:23:32 +0800 Subject: [PATCH 3/3] Update test_op.py Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- test_op.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test_op.py b/test_op.py index dba6863..963981f 100644 --- a/test_op.py +++ b/test_op.py @@ -488,7 +488,7 @@ def get_tvm_source( bar_id = 0 dtype = "float16" if args.use_tc else "float32" for rprog in rprogs: - print("rProg: ", rprog.Dump()) + print(f"rProg[{idx}]: {rprog.Dump()}") block_size = rprog.GetParallelism(1) * (32 if args.use_tc else 1) grid_size = rprog.GetParallelism(0) blocks = (block_size, 1, 1)