Skip to content

Commit beaa4a3

Browse files
Committed as: [Frontend] Fix ops conversion
1 parent e54325d · commit beaa4a3

File tree

2 files changed

+16
-17
lines changed

2 files changed

+16
-17
lines changed

PyTorchSimFrontend/mlir/mlir_codegen_backend.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -803,12 +803,11 @@ def where(condition, operand1, operand2, *args, var_info=None, **kwargs):
803803
cond_type = var_info[condition]
804804
operand_type = var_info[operand1]
805805
if cond_type[0] < tile_size:
806-
condition = ops.broadcast(condition, operand_type[0])
806+
condition = ops.broadcast(condition, tile_size)
807807
elif cond_type[0] > tile_size:
808-
operand1 = ops.broadcast(operand1, operand_type[0])
809-
operand2 = ops.broadcast(operand2, operand_type[0])
808+
operand1 = ops.broadcast(operand1, cond_type[0])
809+
operand2 = ops.broadcast(operand2, cond_type[0])
810810
tile_size, ret_type = var_info[operand1]
811-
812811
shape = f"vector<{tile_size}x{ret_type}>" if tile_size > 1 else ret_type
813812
cond_shape = f"vector<{tile_size}xi1>," if tile_size > 1 else ""
814813
return f"arith.select %{condition}, %{operand1}, %{operand2} : {cond_shape} {shape}", [tile_size, ret_type]
@@ -1164,10 +1163,6 @@ def store(self, name: str, index: sympy.Expr, value, *args, **kwargs):
11641163
# Todo. If tile_size is not same (i.e., view operation), we can't apply peephole optimization easily
11651164
require_store = self.spad_buffer_dict[str(value)][1] != tile_size
11661165

1167-
if compute_vec_size < self.var_info[value][0]:
1168-
value = self.cse.generate(self.stores, f"vector.extract_strided_slice %{value} {{offsets = [0], sizes = [{compute_vec_size}], strides = [1]}}: vector<{self.var_info[value][0]}x{self.var_info[value][1]}> to {vshape}")
1169-
self.register_var_info(value, [compute_vec_size, mlir_dtype])
1170-
11711166
if require_store:
11721167
# Define scratch pad buffer
11731168
sram_var, sram_index_var = self.get_scratchpad_buffer(dtype, name, local_tile_desc, index)
@@ -1176,6 +1171,11 @@ def store(self, name: str, index: sympy.Expr, value, *args, **kwargs):
11761171
_, operand_type = self.var_info[value]
11771172
if mlir_dtype != operand_type:
11781173
value = ops.custom_cast(value, mlir_dtype)
1174+
1175+
if compute_vec_size < self.var_info[value][0]:
1176+
value = self.cse.generate(self.stores, f"vector.extract_strided_slice %{value} {{offsets = [0], sizes = [{compute_vec_size}], strides = [1]}}: vector<{self.var_info[value][0]}x{self.var_info[value][1]}> to {vshape}")
1177+
self.register_var_info(value, [compute_vec_size, mlir_dtype])
1178+
11791179
with self.override_buffer_cse(buffer=self.stores):
11801180
ops._store(value, sram_var, compute_index_var, tile_shape, buffer_name=name)
11811181
else:

PyTorchSimFrontend/mlir/mlir_template.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -919,24 +919,23 @@ def store_reduction_epilogue(self, name, index, value):
919919
compute_index_var = ",".join(partial_zero_var_list)
920920

921921
with self.override_buffer_cse(buffer=self.reductions_suffix):
922-
out = ops._load(partial_vec_size, mlir_dtype, sram_var, compute_index_var, partial_tile_shape)
922+
out = ops._load(partial_vec_size, mlir_dtype, value, compute_index_var, partial_tile_shape)
923923
ops._store(init_vec, value, compute_index_var, partial_tile_shape) # Clear the partial buffer to zero
924924

925925
# 2 step reduction
926926
new_vec_size = 2
927-
new_reduced_shape = f"<{new_vec_size}x{mlir_dtype}>"
927+
new_reduced_shape = f"vector<{new_vec_size}x{mlir_dtype}>"
928928
reduction_type = self.reduction_info[value][0]
929-
out = ops.multi_reduction(out, init_vec, partial_vec_size, new_vec_size, reduction_type, partial_vshape, self.reduction_info[value][0], mlir_dtype)
929+
out = ops.multi_reduction(out, init_vec2, partial_vec_size, new_vec_size, partial_vshape, reduction_type, mlir_dtype)
930930

931931
out2 = self.cse.generate(self.reductions_suffix, f"vector.shuffle %{out}, %{out} [1, 0] : {new_reduced_shape}, {new_reduced_shape}")
932932
self.register_var_info(out2, [new_vec_size, mlir_dtype])
933933

934934
with self.override_buffer_cse(buffer=self.reductions_suffix):
935935
out = reduction_partial_combine_vec(self.reduction_info[value][0], out, out2)
936936

937-
if self.welford_reduce_out is not None:
938-
# NOTE: It not a real welford algorithm... We just used E(X^2) - E(X)^2
939-
with self.override_buffer_cse(buffer=self.reductions_suffix):
937+
if self.welford_reduce_out is not None:
938+
# NOTE: It not a real welford algorithm... We just used E(X^2) - E(X)^2
940939
divider = ops.constant(float(self.reduction_axis_size), "f32")
941940
if self.buffer_types[name][1] > 1:
942941
divider_vec = ops.broadcast(divider, new_vec_size)
@@ -955,9 +954,9 @@ def store_reduction_epilogue(self, name, index, value):
955954
m2 = ops.mul(variance, divider_vec)
956955
out = m2
957956

958-
final_zero_var_list[-1] = f"%{body_index_var}"
959-
final_compute_index_var = ",".join(final_zero_var_list)
960-
ops._store(out, sram_var, final_compute_index_var, final_tile_shape, buffer_name=name)
957+
final_zero_var_list[-1] = f"%{body_index_var}"
958+
final_compute_index_var = ",".join(final_zero_var_list)
959+
ops._store(out, sram_var, final_compute_index_var, final_tile_shape, buffer_name=name)
961960

962961
# MVOUT Encoding
963962
# Generate DMA instruction

0 commit comments

Comments (0)