Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 22 additions & 4 deletions PyTorchSimFrontend/mlir/mlir_codegen_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -639,6 +639,15 @@ def set_ranges(self, lengths, reduction_lengths, read_writes):
self.adjust_tile_size()
return ret

def get_padding_type(self):
    """Return the DMA padding type for the current node.

    Returns:
        int: 1 (negative-infinity padding) when the current node is a
        reduction whose origin ops include an exponential op; otherwise
        0 (zero padding).

    NOTE(review): the check is a substring match, so any origin op whose
    name contains "exp" (e.g. "expand") would also select -inf padding —
    confirm this is intended.
    """
    if not self.current_node.is_reduction():
        return 0
    origin_ops = self.current_node.node.origins
    # Exponential reductions (e.g. softmax denominator) must pad with -inf
    # so padded lanes contribute exp(-inf) == 0 to the reduction.
    has_exp_origin = any("exp" in origin_op.name for origin_op in origin_ops)
    return 1 if has_exp_origin else 0

def parse_indices(self, expr):
if len(expr.args) == 0:
return expr
Expand Down Expand Up @@ -670,6 +679,7 @@ def parse_indices(self, expr):
def load(self, name: str, index: sympy.Expr):
index = self.rename_indexing(index)
indices = self.parse_indices(index)
padding = self.get_padding_type()
prefix = self.newvar_prefix
if index.is_number:
prefix = prefix + "c"
Expand All @@ -696,7 +706,7 @@ def load(self, name: str, index: sympy.Expr):
self.dma_cache[dma_key] = dmaType, stride, chunk
self.tags.add(f"{name}_tag")
self.consts.add(0)
code = f"affine.dma_start %{var}[{prefix}{indices}], %{buffer}[%c0, %c0], %{name}_tag[0], %c{dmaType}, %c{stride}, %c{chunk} : memref<{self.buffer_types[name][1]}x{type_name}>, memref<{dram_tile_shape}x{type_name}, 1>, memref<1xi32>"
code = f"affine.dma_start %{var}[{prefix}{indices}], %{buffer}[%c0, %c0], %{name}_tag[0], %c{dmaType}, %c{stride}, %c{chunk} : memref<{self.buffer_types[name][1]}x{type_name}>, memref<{dram_tile_shape}x{type_name}, 1>, memref<1xi32> {{padding = {padding}}}"
self.cse.generate(self.loads, code, assignment = False) # FIXME: assignment = False does not support caching

operation = "affine.vector_load" if tile_size_per_lane > 1 else "affine.load"
Expand Down Expand Up @@ -777,10 +787,10 @@ def reduction(self, dtype, src_dtype, reduction_type, value):
shape = f"vector<{self.tile_desc.get_tile_size()}x{type_name}>"
reduced_shape = type_name
init = self.cse.generate(self.reduction_prefix, f"arith.constant {reduction_init(reduction_type, dtype)} : {type_name}")
if len(self.ranges) == 1:
if len(self.ranges) == 1: # 1-D vector to scalar
axis = "0"
acc_var = init
shape = f"vector<{self.tile_desc.get_tile_size_per_lane()}x{type_name}>"
shape = f"vector<{self.tile_desc.get_tile_size()}x{type_name}>" # use single vector lane
elif len(self.ranges) == 2:
vec_len = self.tile_desc.get_rows_per_lane()
flattened_size = f"vector<{self.tile_desc.get_tile_size_per_lane()}x{type_name}>"
Expand Down Expand Up @@ -999,6 +1009,9 @@ def get_dma_info(self, name, index, dtype):
current_tile.tile_per_lane_layout = mlir_common.MLIRTile.TILE_PER_LANE_COL_WISE # Actually it is not needed in vector case
chunk_size = current_tile.get_chunk_size()
mm_stride = current_tile.n_col
if self.is_scalar(name): # scalar to vector broadcasting
mm_stride = 0
current_tile.n_row, current_tile.n_col = current_tile.n_col, current_tile.n_row
# Case 2. Tile is 1-D vector type with reduction
elif len(cv) == 1 and len(cv) == self.reduction_depth + 1:
# Use only one vectorlane to reduce a vector
Expand All @@ -1009,6 +1022,9 @@ def get_dma_info(self, name, index, dtype):
current_tile.used_vector_lane = 1
chunk_size = current_tile.get_chunk_size()
mm_stride = 0 # don't care
tile_size_per_lane = current_tile.get_tile_size_per_lane()
if self.is_scalar(name): # scalar to vector broadcasting
current_tile.n_row, current_tile.n_col = current_tile.n_col, current_tile.n_row
# Case 3. Tile is 2-D tile
elif len(cv) == 2:
is_reduction = self.reduction_depth == 1
Expand Down Expand Up @@ -1094,7 +1110,9 @@ def adjust_tile_size(self):

# Case 1. vector kernel
if len(self.itervars) == 1:
self.tile_desc.n_col = self.tile_desc.get_tile_size()
tile_size = self.tile_desc.get_tile_size() if self.tile_desc.get_tile_size() < self.ranges[0] else self.ranges[0]
min_tile_size_unit = self.vector_lane * self.vlen # TODO: VCIX widening is not implemented
self.tile_desc.n_col = math.ceil(tile_size / min_tile_size_unit) * min_tile_size_unit # padding
self.tile_desc.n_row = 1
elif len(self.itervars) == 0:
self.tile_desc.n_col = 1
Expand Down
3 changes: 3 additions & 0 deletions PyTorchSimFrontend/mlir/mlir_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -369,6 +369,9 @@ def find_node_by_name(self, name):
if output_node.data.name == name:
return output_node

def is_scalar(self, name):
    """Return True when the buffer registered under `name` holds exactly
    one element (element count is stored at index 1 of its buffer_types
    entry)."""
    element_count = self.buffer_types[name][1]
    return element_count == 1

def roundup_vectorlane(self, size, amp=1):
return ((size + self.vector_lane - 1) // self.vector_lane) * self.vector_lane * amp

Expand Down
26 changes: 13 additions & 13 deletions tests/MoE/test_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -420,15 +420,15 @@ def test_moe(device):
x1 = copy.deepcopy(X).to(device=device)
x2 = copy.deepcopy(X).to("cpu")

# model.train()
model.eval()
model.train()
# model.eval()
model_device = model.to(device=device)
opt_model = torch.compile(model_device, dynamic=False)
y_hat, aux_loss = opt_model(x1)
print("MoE Custom Device Done!")

# model_cpu.train()
model_cpu.eval()
model_cpu.train()
# model_cpu.eval()
cpu_hat, cpu_aux_loss = model_cpu(x2)
test_result("MoE Forward", y_hat, cpu_hat)
test_result("MoE Aux Loss", aux_loss, cpu_aux_loss)
Expand All @@ -453,15 +453,15 @@ def test_moe(device):
total_cpu_loss.backward()
print("MoE Backward Done!")

print("MoE Weight Bias print")
for i in range(num_experts):
print(f"\nExpert {i}")
print(f"FC1 Weight: {model.experts[i].fc1.weight.cpu()}")
print(f"FC1 Bias: {model.experts[i].fc1.bias.cpu()}")
print("\n")
print(f"FC2 Weight: {model.experts[i].fc2.weight.cpu()}")
print(f"FC2 Bias: {model.experts[i].fc2.bias.cpu()}")
print("\n")
# print("MoE Weight Bias print")
# for i in range(num_experts):
# print(f"\nExpert {i}")
# print(f"FC1 Weight: {model.experts[i].fc1.weight.cpu()}")
# print(f"FC1 Bias: {model.experts[i].fc1.bias.cpu()}")
# print("\n")
# print(f"FC2 Weight: {model.experts[i].fc2.weight.cpu()}")
# print(f"FC2 Bias: {model.experts[i].fc2.bias.cpu()}")
# print("\n")

print("MoE Weight Bias Grad")
for i in range(num_experts):
Expand Down
5 changes: 3 additions & 2 deletions tests/test_single_perceptron.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,13 +41,14 @@ def weight_update(a, b, lr):
b2.requires_grad = True
opt_mlp = torch.compile(dynamic=False)(perceptron)
opt_w = torch.compile(dynamic=False)(weight_update)
opt_loss = torch.compile(dynamic=False)(torch.nn.MSELoss())
loss_fn = torch.nn.MSELoss()
opt_loss = torch.compile(dynamic=False)(loss_fn)
lr = torch.tensor(5e-2).to(device=device) # learning rate
y = opt_mlp(w1, x1, b1)
loss = opt_loss(y, y1)
loss.backward()
cpu_y = perceptron(x2, w2, b2)
cpu_loss = torch.nn.MSELoss()(cpu_y, y2)
cpu_loss = loss_fn(cpu_y, y2)
cpu_loss.backward()
test_result("Perceptron", y, cpu_y)
test_result("Loss", loss, cpu_loss)
Expand Down
24 changes: 24 additions & 0 deletions tests/test_softmax.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,29 @@ def test_softmax(device, size=(128, 128), dim=1):
input = torch.randn(size)
x1 = input.to(device=device)
x2 = input.to("cpu")

# split softmax into 3 steps
# def softmax1(x): # find max
# return x.max(dim=dim, keepdim=True).values
# def softmax2(x, max):
# return (x - max).exp().sum(dim=dim, keepdim=True)
# def softmax3(x, max, sum):
# return (x - max).exp().div(sum)

# opt_fn1 = torch.compile(dynamic=False)(softmax1)
# opt_fn2 = torch.compile(dynamic=False)(softmax2)
# opt_fn3 = torch.compile(dynamic=False)(softmax3)

# max = opt_fn1(x1)
# cpu_max = softmax1(x2)
# test_result("Softmax Max", max, cpu_max)
# sum = opt_fn2(x1, max)
# cpu_sum = softmax2(x2, cpu_max)
# test_result("Softmax Sum", sum, cpu_sum)
# y = opt_fn3(x1, max, sum)
# cpu_y = softmax3(x2, cpu_max, cpu_sum)
# test_result("Softmax", y, cpu_y)

opt_fn = torch.compile(dynamic=False)(torch.nn.functional.softmax)
y = opt_fn(x1, dim=dim)
cpu_y = torch.nn.functional.softmax(x2, dim=dim)
Expand All @@ -33,3 +56,4 @@ def test_softmax(device, size=(128, 128), dim=1):
device = module.custom_device()
test_softmax(device, size=(64, 128))
test_softmax(device, size=(256, 128))
test_softmax(device, size=(1, 16))
Loading