diff --git a/PyTorchSimFrontend/mlir/mlir_codegen_backend.py b/PyTorchSimFrontend/mlir/mlir_codegen_backend.py index 99c39322..c63f53bb 100644 --- a/PyTorchSimFrontend/mlir/mlir_codegen_backend.py +++ b/PyTorchSimFrontend/mlir/mlir_codegen_backend.py @@ -639,6 +639,15 @@ def set_ranges(self, lengths, reduction_lengths, read_writes): self.adjust_tile_size() return ret + # padding type 0: zero-padding 1: negative-padding(-inf) ... + def get_padding_type(self): + ops = self.current_node.node.origins + if self.current_node.is_reduction(): + for op in ops: + if "exp" in op.name: # exponential reduciton case + return 1 + return 0 + def parse_indices(self, expr): if len(expr.args) == 0: return expr @@ -670,6 +679,7 @@ def parse_indices(self, expr): def load(self, name: str, index: sympy.Expr): index = self.rename_indexing(index) indices = self.parse_indices(index) + padding = self.get_padding_type() prefix = self.newvar_prefix if index.is_number: prefix = prefix + "c" @@ -696,7 +706,7 @@ def load(self, name: str, index: sympy.Expr): self.dma_cache[dma_key] = dmaType, stride, chunk self.tags.add(f"{name}_tag") self.consts.add(0) - code = f"affine.dma_start %{var}[{prefix}{indices}], %{buffer}[%c0, %c0], %{name}_tag[0], %c{dmaType}, %c{stride}, %c{chunk} : memref<{self.buffer_types[name][1]}x{type_name}>, memref<{dram_tile_shape}x{type_name}, 1>, memref<1xi32>" + code = f"affine.dma_start %{var}[{prefix}{indices}], %{buffer}[%c0, %c0], %{name}_tag[0], %c{dmaType}, %c{stride}, %c{chunk} : memref<{self.buffer_types[name][1]}x{type_name}>, memref<{dram_tile_shape}x{type_name}, 1>, memref<1xi32> {{padding = {padding}}}" self.cse.generate(self.loads, code, assignment = False) # FIXME: assignment = False does not support caching operation = "affine.vector_load" if tile_size_per_lane > 1 else "affine.load" @@ -777,10 +787,10 @@ def reduction(self, dtype, src_dtype, reduction_type, value): shape = f"vector<{self.tile_desc.get_tile_size()}x{type_name}>" reduced_shape = type_name init = self.cse.generate(self.reduction_prefix, f"arith.constant {reduction_init(reduction_type, dtype)} : {type_name}") - if len(self.ranges) == 1: + if len(self.ranges) == 1: # 1-D vector to scalar axis = "0" acc_var = init - shape = f"vector<{self.tile_desc.get_tile_size_per_lane()}x{type_name}>" + shape = f"vector<{self.tile_desc.get_tile_size()}x{type_name}>" # use single vector lane elif len(self.ranges) == 2: vec_len = self.tile_desc.get_rows_per_lane() flattened_size = f"vector<{self.tile_desc.get_tile_size_per_lane()}x{type_name}>" @@ -999,6 +1009,9 @@ def get_dma_info(self, name, index, dtype): current_tile.tile_per_lane_layout = mlir_common.MLIRTile.TILE_PER_LANE_COL_WISE # Actually it is not needed in vector case chunk_size = current_tile.get_chunk_size() mm_stride = current_tile.n_col + if self.is_scalar(name): # scalar to vector broadcasting + mm_stride = 0 + current_tile.n_row, current_tile.n_col = current_tile.n_col, current_tile.n_row # Case 2. Tile is 1-D vector type with reduction elif len(cv) == 1 and len(cv) == self.reduction_depth + 1: # Use only one vectorlane to reduce a vector @@ -1009,6 +1022,9 @@ def get_dma_info(self, name, index, dtype): current_tile.used_vector_lane = 1 chunk_size = current_tile.get_chunk_size() mm_stride = 0 # don't care + tile_size_per_lane = current_tile.get_tile_size_per_lane() + if self.is_scalar(name): # scalar to vector broadcasting + current_tile.n_row, current_tile.n_col = current_tile.n_col, current_tile.n_row # Case 3. Tile is 2-D tile elif len(cv) == 2: is_reduction = self.reduction_depth == 1 @@ -1094,7 +1110,9 @@ def adjust_tile_size(self): # Case 1. vector kernel if len(self.itervars) == 1: - self.tile_desc.n_col = self.tile_desc.get_tile_size() + tile_size = self.tile_desc.get_tile_size() if self.tile_desc.get_tile_size() < self.ranges[0] else self.ranges[0] + min_tile_size_unit = self.vector_lane * self.vlen # TODO: VCIX widening is not implemented + self.tile_desc.n_col = math.ceil(tile_size / min_tile_size_unit) * min_tile_size_unit # padding self.tile_desc.n_row = 1 elif len(self.itervars) == 0: self.tile_desc.n_col = 1 diff --git a/PyTorchSimFrontend/mlir/mlir_common.py b/PyTorchSimFrontend/mlir/mlir_common.py index a949cb5d..21612a4c 100644 --- a/PyTorchSimFrontend/mlir/mlir_common.py +++ b/PyTorchSimFrontend/mlir/mlir_common.py @@ -369,6 +369,9 @@ def find_node_by_name(self, name): if output_node.data.name == name: return output_node + def is_scalar(self, name): + return self.buffer_types[name][1] == 1 + def roundup_vectorlane(self, size, amp=1): return ((size + self.vector_lane - 1) // self.vector_lane) * self.vector_lane * amp diff --git a/tests/MoE/test_moe.py b/tests/MoE/test_moe.py index ff6dd00b..d14bf5c6 100644 --- a/tests/MoE/test_moe.py +++ b/tests/MoE/test_moe.py @@ -420,15 +420,15 @@ def test_moe(device): x1 = copy.deepcopy(X).to(device=device) x2 = copy.deepcopy(X).to("cpu") - # model.train() - model.eval() + model.train() + # model.eval() model_device = model.to(device=device) opt_model = torch.compile(model_device, dynamic=False) y_hat, aux_loss = opt_model(x1) print("MoE Custom Device Done!") - # model_cpu.train() - model_cpu.eval() + model_cpu.train() + # model_cpu.eval() cpu_hat, cpu_aux_loss = model_cpu(x2) test_result("MoE Forward", y_hat, cpu_hat) test_result("MoE Aux Loss", aux_loss, cpu_aux_loss) @@ -453,15 +453,15 @@ def test_moe(device): total_cpu_loss.backward() print("MoE Backward Done!") - print("MoE Weight Bias print") - for i in range(num_experts): - print(f"\nExpert {i}") - print(f"FC1 Weight: {model.experts[i].fc1.weight.cpu()}") - print(f"FC1 Bias: {model.experts[i].fc1.bias.cpu()}") - print("\n") - print(f"FC2 Weight: {model.experts[i].fc2.weight.cpu()}") - print(f"FC2 Bias: {model.experts[i].fc2.bias.cpu()}") - print("\n") + # print("MoE Weight Bias print") + # for i in range(num_experts): + # print(f"\nExpert {i}") + # print(f"FC1 Weight: {model.experts[i].fc1.weight.cpu()}") + # print(f"FC1 Bias: {model.experts[i].fc1.bias.cpu()}") + # print("\n") + # print(f"FC2 Weight: {model.experts[i].fc2.weight.cpu()}") + # print(f"FC2 Bias: {model.experts[i].fc2.bias.cpu()}") + # print("\n") print("MoE Weight Bias Grad") for i in range(num_experts): diff --git a/tests/test_single_perceptron.py b/tests/test_single_perceptron.py index 7ab02656..78a6b117 100644 --- a/tests/test_single_perceptron.py +++ b/tests/test_single_perceptron.py @@ -41,13 +41,14 @@ def weight_update(a, b, lr): b2.requires_grad = True opt_mlp = torch.compile(dynamic=False)(perceptron) opt_w = torch.compile(dynamic=False)(weight_update) - opt_loss = torch.compile(dynamic=False)(torch.nn.MSELoss()) + loss_fn = torch.nn.MSELoss() + opt_loss = torch.compile(dynamic=False)(loss_fn) lr = torch.tensor(5e-2).to(device=device) # learning rate y = opt_mlp(w1, x1, b1) loss = opt_loss(y, y1) loss.backward() cpu_y = perceptron(x2, w2, b2) - cpu_loss = torch.nn.MSELoss()(cpu_y, y2) + cpu_loss = loss_fn(cpu_y, y2) cpu_loss.backward() test_result("Perceptron", y, cpu_y) test_result("Loss", loss, cpu_loss) diff --git a/tests/test_softmax.py b/tests/test_softmax.py index ca49953c..d68638f8 100644 --- a/tests/test_softmax.py +++ b/tests/test_softmax.py @@ -18,6 +18,29 @@ def test_softmax(device, size=(128, 128), dim=1): input = torch.randn(size) x1 = input.to(device=device) x2 = input.to("cpu") + + # split softmax into 3 steps + # def softmax1(x): # find max + # return x.max(dim=dim, keepdim=True).values + # def softmax2(x, max): + # return (x - max).exp().sum(dim=dim, keepdim=True) + # def softmax3(x, max, sum): + # return (x - max).exp().div(sum) + + # opt_fn1 = torch.compile(dynamic=False)(softmax1) + # opt_fn2 = torch.compile(dynamic=False)(softmax2) + # opt_fn3 = torch.compile(dynamic=False)(softmax3) + + # max = opt_fn1(x1) + # cpu_max = softmax1(x2) + # test_result("Softmax Max", max, cpu_max) + # sum = opt_fn2(x1, max) + # cpu_sum = softmax2(x2, cpu_max) + # test_result("Softmax Sum", sum, cpu_sum) + # y = opt_fn3(x1, max, sum) + # cpu_y = softmax3(x2, cpu_max, cpu_sum) + # test_result("Softmax", y, cpu_y) + opt_fn = torch.compile(dynamic=False)(torch.nn.functional.softmax) y = opt_fn(x1, dim=dim) cpu_y = torch.nn.functional.softmax(x2, dim=dim) @@ -33,3 +56,4 @@ def test_softmax(device, size=(128, 128), dim=1): device = module.custom_device() test_softmax(device, size=(64, 128)) test_softmax(device, size=(256, 128)) + test_softmax(device, size=(1, 16))