diff --git a/PyTorchSimFrontend/llvm/llvm_codegen_backend.py b/PyTorchSimFrontend/llvm/llvm_codegen_backend.py index e8daa889..6951b5bd 100644 --- a/PyTorchSimFrontend/llvm/llvm_codegen_backend.py +++ b/PyTorchSimFrontend/llvm/llvm_codegen_backend.py @@ -212,7 +212,7 @@ def maximum(operand1, operand2, tile_size=4): @staticmethod def relu(x, tile_size=4): - return ops.maximum(x, ops.constant(0.0, torch.int32)) + return ops.maximum(x, ops.constant(0.0, "f32")) SYMPY_TO_LLVM = { sympy.core.mul.Mul: "mul", diff --git a/PyTorchSimFrontend/mlir/mlir_codegen_backend.py b/PyTorchSimFrontend/mlir/mlir_codegen_backend.py index 140ba3fd..3d65aa53 100644 --- a/PyTorchSimFrontend/mlir/mlir_codegen_backend.py +++ b/PyTorchSimFrontend/mlir/mlir_codegen_backend.py @@ -107,53 +107,140 @@ def write_header(self): ) class ExtensionOverrides(common.OpOverrides): + # Binary element wise operations @staticmethod - def add(operand1, operand2, tile_size=16, dtype="f32"): - shape = f"vector<{tile_size}x{dtype}>" if tile_size > 1 else dtype - return f'arith.add{dtype[0]} %{operand1}, %{operand2} : {shape}' + def custom_cast(operand, target_type, *args, var_info=None): + dtype = var_info[operand][1] + if dtype == "index": + ret = ops.index_cast(operand, target_type, var_info=var_info) + else: + ret = ops.to_dtype(operand, target_type, var_info=var_info) + return ret, var_info[ret] @staticmethod - def sub(operand1, operand2, tile_size=16, dtype="f32"): - shape = f"vector<{tile_size}x{dtype}>" if tile_size > 1 else dtype - return f'arith.sub{dtype[0]} %{operand1}, %{operand2} : {shape}' + def binary_elementwise_common(operand1, operand2, var_info): + op_type1 = var_info[operand1] + op_type2 = var_info[operand2] + # Tile size check + if op_type1[0] != op_type2[0]: + # Try to broad cast + lhs_tile_size, lhs_dtype = op_type1 + rhs_tile_size, rhs_dtype = op_type2 + if lhs_tile_size > rhs_tile_size: + operand2 = ops.broadcast(operand2, operand1, var_info=var_info) + op_type2 = var_info[operand2] + 
elif lhs_tile_size < rhs_tile_size: + operand1 = ops.broadcast(operand1, operand2, var_info=var_info) + op_type1 = var_info[operand1] + + # Data type check + if op_type1[1] != op_type2[1]: + if op_type1[1] == "index" or op_type2[1] == "index": + if op_type1[1] == "index": + operand1 = ops.index_cast(operand1, op_type2[1], var_info) + op_type1 = var_info[operand1] + if op_type2[1] == "index": + operand2 = ops.index_cast(operand2, op_type1[1], var_info) + op_type2 = var_info[operand2] + elif op_type1[1][0] == "i" and op_type2[1][0] == "f": + operand1 = ops.to_dtype(operand1, op_type2[1], var_info) + op_type1 = var_info[operand1] + elif op_type1[1][0] == "f" and op_type2[1][0] == "i": + operand2 = ops.to_dtype(operand2, op_type1[1], var_info) + op_type2 = var_info[operand2] + else: + raise NotImplementedError("Unsupported type converting") + + # Updated var info + tile_size = op_type1[0] + ret_type = op_type1[1] + return tile_size, ret_type, operand1, operand2 @staticmethod - def mul(operand1, operand2, tile_size=16, dtype="f32"): - shape = f"vector<{tile_size}x{dtype}>" if tile_size > 1 else dtype - return f'arith.mul{dtype[0]} %{operand1}, %{operand2} : {shape}' + def add(operand1, operand2, *args, var_info=None): + tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2, var_info) + shape = f"vector<{tile_size}x{ret_type}>" if tile_size > 1 else ret_type + opcode = f'arith.add{ret_type[0]}' + return f'{opcode} %{operand1}, %{operand2} : {shape}', [tile_size, ret_type] @staticmethod - def div(operand1, operand2, tile_size=16, dtype="f32"): - shape = f"vector<{tile_size}x{dtype}>" if tile_size > 1 else dtype - return f'arith.div{dtype[0]} %{operand1}, %{operand2} : {shape}' + def sub(operand1, operand2, *args, var_info=None): + tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2, var_info) + shape = f"vector<{tile_size}x{ret_type}>" if tile_size > 1 else ret_type + 
opcode = f'arith.sub{ret_type[0]}' + return f'{opcode} %{operand1}, %{operand2} : {shape}', [tile_size, ret_type] @staticmethod - def truediv(operand1, operand2, tile_size=16, dtype="f32"): - shape = f"vector<{tile_size}x{dtype}>" if tile_size > 1 else dtype - return f'arith.div{dtype[0]} %{operand1}, %{operand2} : {shape}' + def mul(operand1, operand2, *args, var_info=None): + tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2, var_info) + shape = f"vector<{tile_size}x{ret_type}>" if tile_size > 1 else ret_type + opcode = f'arith.mul{ret_type[0]}' + return f'{opcode} %{operand1}, %{operand2} : {shape}', [tile_size, ret_type] + + @staticmethod + def div(operand1, operand2, *args, var_info=None): + tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2, var_info) + shape = f"vector<{tile_size}x{ret_type}>" if tile_size > 1 else ret_type + if ret_type[0] == "f": + opcode = f'arith.divf' + else: + opcode = f'arith.divui' + return f'{opcode} %{operand1}, %{operand2} : {shape}', [tile_size, ret_type] + + @staticmethod + def truediv(operand1, operand2, *args, var_info=None): + tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2, var_info) + shape = f"vector<{tile_size}x{ret_type}>" if tile_size > 1 else ret_type + if ret_type[0] == "f": + opcode = f'arith.divf' + else: + opcode = f'arith.divui' + return f'{opcode} %{operand1}, %{operand2} : {shape}', [tile_size, ret_type] + + @staticmethod + def minimum(operand1, operand2, *args, var_info=None): + tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2, var_info) + shape = f"vector<{tile_size}x{ret_type}>" if tile_size > 1 else ret_type + if ret_type[0] == "f": + opcode = f'arith.minimumf' + else: + opcode = f'arith.minimumui' + return f'{opcode} %{operand1}, %{operand2} : {shape}', [tile_size, ret_type] + + 
@staticmethod + def maximum(operand1, operand2, *args, var_info=None): + tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2, var_info) + shape = f"vector<{tile_size}x{ret_type}>" if tile_size > 1 else ret_type + if ret_type[0] == "f": + opcode = f'arith.maximumf' + else: + opcode = f'arith.maximumui' + return f'{opcode} %{operand1}, %{operand2} : {shape}', [tile_size, ret_type] @staticmethod - def to_dtype(x, dst_type, src_dtype=None, tile_size=16, dtype="f32"): - mlir_dtype = mlir_common.DTYPE_TO_MLIR[dst_type] - src_mlir_dtype = mlir_common.DTYPE_TO_MLIR[src_dtype] + def to_dtype(operand, dst_mlir_dtype, *args, var_info=None): + src_mlir_dtype = var_info[operand][1] + tile_size = var_info[operand][0] - dst_bits = 1 if dst_type == torch.bool else torch.finfo(dst_type).bits if dst_type.is_floating_point else torch.iinfo(dst_type).bits - src_bits = 1 if src_dtype == torch.bool else torch.finfo(src_dtype).bits if src_dtype.is_floating_point else torch.iinfo(src_dtype).bits - shape = f"vector<{tile_size}x{mlir_dtype}>" if tile_size > 1 else mlir_dtype + dst_bits = int(dst_mlir_dtype[1:]) + src_bits = int(src_mlir_dtype[1:]) + shape = f"vector<{tile_size}x{dst_mlir_dtype}>" if tile_size > 1 else dst_mlir_dtype src_shape = f"vector<{tile_size}x{src_mlir_dtype}>" if tile_size > 1 else src_mlir_dtype - if dst_type.is_floating_point and not src_dtype.is_floating_point: + if dst_mlir_dtype[0] == "i" and src_mlir_dtype[0] == "f": raise NotImplementedError("floating point to integer conversion") - elif not dst_type.is_floating_point and src_dtype.is_floating_point: + if dst_mlir_dtype[0] == "f" and src_mlir_dtype[0] == "i": raise NotImplementedError("integer to floating point conversion") else: if dst_bits > src_bits: - return f"arith.extui %{x} : {src_shape} to {shape}" + return f"arith.extui %{operand} : {src_shape} to {shape}" elif dst_bits < src_bits: - return f"arith.trunc %{x} : {src_shape} to {shape}" + return 
f"arith.trunc %{operand} : {src_shape} to {shape}" @staticmethod - def constant(value, src_type, tile_size=16, dtype="f32"): - src_type = mlir_common.DTYPE_TO_MLIR[src_type] + def constant(value, src_type, *args, var_info=None): + if isinstance(src_type, torch.dtype): + src_type = mlir_common.DTYPE_TO_MLIR[src_type] + # if value represented by e notation, convert to float (ex 1e-3 -> 1.0e-3) if "e" in str(value): value = float(value) @@ -161,89 +248,354 @@ def constant(value, src_type, tile_size=16, dtype="f32"): value = format(value, ".20f") if src_type[0] == "i": value = int(value) - return f'arith.constant {value} : {src_type}' + return f'arith.constant {value} : {src_type}', [1, src_type] + # transcendental functions @staticmethod - def exp(operand, tile_size=16, dtype="f32"): + def exp(operand, *args, var_info=None): + op_type = var_info[operand] + tile_size = op_type[0] + dtype = op_type[1] + shape = f"vector<{tile_size}x{dtype}>" if tile_size > 1 else dtype - return f'math.exp %{operand} : {shape}' + return f'math.exp %{operand} : {shape}', [tile_size, dtype] @staticmethod - def maximum(operand1, operand2, tile_size=16, dtype="f32"): + def sqrt(operand, *args, var_info=None): + op_type = var_info[operand] + tile_size = op_type[0] + dtype = op_type[1] + + # Type check & auto cast + if dtype[0] != "f": + operand, dtype = ops.to_dtype(operand, "f32", var_info=var_info) + var_info[operand] = dtype + shape = f"vector<{tile_size}x{dtype}>" if tile_size > 1 else dtype - return f'arith.maximum{dtype[0]} %{operand1}, %{operand2} : {shape}' + return f'math.sqrt %{operand} : {shape}', [tile_size, dtype] @staticmethod - def sqrt(x, tile_size=16, dtype="f32"): + def rsqrt(operand, *args, var_info=None): + op_type = var_info[operand] + tile_size = op_type[0] + dtype = op_type[1] + + # Type check & auto cast + if dtype[0] != "f": + operand, dtype = ops.to_dtype(operand, "f32", var_info=var_info) + var_info[operand] = dtype + shape = f"vector<{tile_size}x{dtype}>" if 
tile_size > 1 else dtype - return f'math.sqrt %{x} : {shape}' + return f'math.rsqrt %{operand} : {shape}', [tile_size, dtype] @staticmethod - def ne(operand1, operand2, tile_size=16, dtype="f32"): - shape = f"vector<{tile_size}x{dtype}>" if tile_size > 1 else "i1" - return f'arith.cmp{dtype[0]} one, %{operand1}, %{operand2} : {shape}' + def pow(operand1, operand2, *args, var_info=None): + op_type1 = var_info[operand1] + op_type2 = var_info[operand2] + + # Type check & auto cast + if op_type1[1][0] != "f": + operand1, dtype = ops.to_dtype(operand1, "f32", var_info=var_info) + var_info[operand1] = dtype + + # Type check & auto cast + if op_type2[1][0] != "f": + operand2, dtype = ops.to_dtype(operand2, "f32", var_info=var_info) + var_info[operand2] = dtype + + op_type1 = var_info[operand1] + tile_size = op_type1[0] + dtype = op_type1[1] + + shape = f"vector<{tile_size}x{dtype}>" if tile_size > 1 else dtype + return f"math.pow{dtype[0]} %{operand1}, %{operand2} : {shape}", [] @staticmethod - def lt(operand1, operand2, tile_size=16, dtype="f32"): - shape = f"vector<{tile_size}x{dtype}>" if tile_size > 1 else "i1" - return f'arith.cmp{dtype[0]} olt, %{operand1}, %{operand2} : {shape}' + def log(operand, *args, var_info=None): + op_type = var_info[operand] + tile_size = op_type[0] + dtype = op_type[1] + + # Type check & auto cast + if dtype[0] != "f": + operand, dtype = ops.to_dtype(operand, "f32", var_info=var_info) + var_info[operand] = dtype + + shape = f"vector<{tile_size}x{dtype}>" if tile_size > 1 else dtype + return f'math.log %{operand} : {shape}', [tile_size, dtype] @staticmethod - def gt(operand1, operand2, tile_size=16, dtype="f32"): - shape = f"vector<{tile_size}x{dtype}>" if tile_size > 1 else "i1" - return f'arith.cmp{dtype[0]} ogt, %{operand1}, %{operand2} : {shape}' + def reciprocal(operand, *args, var_info): + op_type = var_info[operand] + tile_size = op_type[0] + dtype = op_type[1] + + # Type check & auto cast + if dtype[0] != "f": + operand, dtype = 
ops.to_dtype(operand, "f32", var_info=var_info) + var_info[operand] = dtype + + return ops.div(ops.constant(1.0, dtype), operand), [tile_size, dtype] + # Logical operations @staticmethod - def le(operand1, operand2, tile_size=16, dtype="f32"): + def neg(operand, *args, var_info=None): + op_type = var_info[operand] + tile_size = op_type[0] + dtype = op_type[1] + + # Type check & auto cast + if dtype[0] != "f": + operand, dtype = ops.to_dtype(operand, "f32", var_info=var_info) + var_info[operand] = dtype + shape = f"vector<{tile_size}x{dtype}>" if tile_size > 1 else dtype - return f'arith.cmp{dtype[0]} ole, %{operand1}, %{operand2} : {shape}' + return f'arith.negf %{operand} : {shape}', [tile_size, dtype] @staticmethod - def relu(x, tile_size=16, dtype=None): - return ops.maximum(x, ops.constant(0.0, torch.float32)) + def eq(operand1, operand2, *args, var_info=None): + tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2, var_info) + if ret_type[0] == "f": + op_type = "arith.cmpf" + attribute = "oeq" + elif ret_type[0] == "i": + op_type = "arith.cmpi" + attribute = "eq" + else: + raise ValueError(f"Unsupported data type for 'eq' operation: {ret_type}") + + shape = f"vector<{tile_size}x{ret_type}>" if tile_size > 1 else ret_type + return f'{op_type} {attribute}, %{operand1}, %{operand2} : {shape}', [tile_size, "i1"] @staticmethod - def sigmoid(x, tile_size=16, dtype=None): - one = ops.constant(1, torch.float32) - return ops.truediv(one, ops.add(one, ops.exp(ops.neg(x)))) + def ne(operand1, operand2, *args, var_info=None): + tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2, var_info) + if ret_type[0] == "f": + op_type = "arith.cmpf" + attribute = "one" + elif ret_type[0] == "i": + op_type = "arith.cmpi" + attribute = "ne" + else: + raise ValueError(f"Unsupported data type for 'ne' operation: {ret_type}") + + shape = f"vector<{tile_size}x{ret_type}>" if 
tile_size > 1 else ret_type + return f'{op_type} {attribute}, %{operand1}, %{operand2} : {shape}', [tile_size, "i1"] @staticmethod - def neg(x, tile_size=16, dtype="f32"): - shape = f"vector<{tile_size}x{dtype}>" if tile_size > 1 else dtype - return f'arith.neg{dtype[0]} %{x} : {shape}' + def lt(operand1, operand2, *args, var_info=None): + tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2, var_info) + if ret_type[0] == "f": + op_type = "arith.cmpf" + attribute = "olt" + elif ret_type[0] == "i": + op_type = "arith.cmpi" + attribute = "slt" + else: + raise ValueError(f"Unsupported data type for 'lt' operation: {ret_type}") + + shape = f"vector<{tile_size}x{ret_type}>" if tile_size > 1 else ret_type + return f'{op_type} {attribute}, %{operand1}, %{operand2} : {shape}', [tile_size, "i1"] @staticmethod - def where(condition, x, y, tile_size=16, dtype="f32"): - shape = f"vector<{tile_size}x{dtype}>" if tile_size > 1 else dtype - cond_shape = f"vector<{tile_size}xi1>," if tile_size > 1 else "" - return f"arith.select %{condition}, %{x}, %{y} : {cond_shape} {shape}" + def gt(operand1, operand2, *args, var_info=None): + tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2, var_info) + if ret_type[0] == "f": + op_type = "arith.cmpf" + attribute = "ogt" + elif ret_type[0] == "i": + op_type = "arith.cmpi" + attribute = "sgt" + else: + raise ValueError(f"Unsupported data type for 'gt' operation: {ret_type}") + + shape = f"vector<{tile_size}x{ret_type}>" if tile_size > 1 else ret_type + return f'{op_type} {attribute}, %{operand1}, %{operand2} : {shape}', [tile_size, "i1"] @staticmethod - def logical_not(operand, tile_size=16, dtype="f32"): - tile_size=16 - shape = f"vector<{tile_size}x{dtype}>" if tile_size > 1 else dtype - result_shape = f"vector<{tile_size}xi1>" if tile_size > 1 else "i1" + def le(operand1, operand2, *args, var_info=None): + tile_size, ret_type, 
operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2, var_info) + if ret_type[0] == "f": + op_type = "arith.cmpf" + attribute = "ole" + elif ret_type[0] == "i": + op_type = "arith.cmpi" + attribute = "sle" + else: + raise ValueError(f"Unsupported data type for 'le' operation: {ret_type}") + + shape = f"vector<{tile_size}x{ret_type}>" if tile_size > 1 else ret_type + return f'{op_type} {attribute}, %{operand1}, %{operand2} : {shape}', [tile_size, "i1"] + + @staticmethod + def ge(operand1, operand2, *args, var_info=None): + tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2, var_info) + if ret_type[0] == "f": + op_type = "arith.cmpf" + attribute = "oge" + elif ret_type[0] == "i": + op_type = "arith.cmpi" + attribute = "sge" + else: + raise ValueError(f"Unsupported data type for 'ge' operation: {ret_type}") + + shape = f"vector<{tile_size}x{ret_type}>" if tile_size > 1 else ret_type + return f'{op_type} {attribute}, %{operand1}, %{operand2} : {shape}', [tile_size, "i1"] + + @staticmethod + def and_(operand1, operand2, *args, var_info=None): + op_type1 = var_info[operand1] + op_type2 = var_info[operand2] + + # Type check & auto cast + if op_type1[1][0] != "i": + operand1, dtype = ops.to_dtype(operand1, "i32", var_info=var_info) + var_info[operand1] = dtype + + # Type check & auto cast + if op_type2[1][0] != "i": + operand2, dtype = ops.to_dtype(operand2, "i32", var_info=var_info) + var_info[operand2] = dtype + + ret_type = op_type1[1] + tile_size = op_type1[0] + + shape = f"vector<{tile_size}x{ret_type}>" if tile_size > 1 else ret_type + return f'arith.andi %{operand1}, %{operand2} : {shape}', [tile_size, ret_type] + + @staticmethod + def or_(operand1, operand2, *args, var_info=None): + op_type1 = var_info[operand1] + op_type2 = var_info[operand2] + + # Type check & auto cast + if op_type1[1][0] != "i": + operand1, dtype = ops.to_dtype(operand1, "i32", var_info=var_info) + 
var_info[operand1] = dtype + + # Type check & auto cast + if op_type2[1][0] != "i": + operand2, dtype = ops.to_dtype(operand2, "i32", var_info=var_info) + var_info[operand2] = dtype + + ret_type = op_type1[1] + tile_size = op_type1[0] + + shape = f"vector<{tile_size}x{ret_type}>" if tile_size > 1 else ret_type + return f'arith.ori %{operand1}, %{operand2} : {shape}', [tile_size, ret_type] + + @staticmethod + def xor(operand1, operand2, *args, var_info=None): + op_type1 = var_info[operand1] + op_type2 = var_info[operand2] + + # Type check & auto cast + if op_type1[1][0] != "i": + operand1, dtype = ops.to_dtype(operand1, "i32", var_info=var_info) + var_info[operand1] = dtype + + # Type check & auto cast + if op_type2[1][0] != "i": + operand2, dtype = ops.to_dtype(operand2, "i32", var_info=var_info) + var_info[operand2] = dtype + + ret_type = op_type1[1] + tile_size = op_type1[0] + + shape = f"vector<{tile_size}x{ret_type}>" if tile_size > 1 else ret_type + return f'arith.xori %{operand1}, %{operand2} : {shape}', [tile_size, ret_type] + + + @staticmethod + def logical_and(operand, *args, var_info=None): + raise NotImplementedError("logical_and") + + @staticmethod + def logical_not(operand, *args, var_info=None): raise NotImplementedError("logical_not") - return f"arith.cmp{dtype[0]} oeq, %{operand}, %zero_vec{tile_size} : {shape} -> {result_shape}" @staticmethod - def rsqrt(x, tile_size=16, dtype="f32"): - shape = f"vector<{tile_size}x{dtype}>" if tile_size > 1 else dtype - return f'math.rsqrt %{x} : {shape}' + def logical_or(operand, *args, var_info=None): + raise NotImplementedError("logical_or") @staticmethod - def pow(a, b, tile_size=16, dtype="f32"): - shape = f"vector<{tile_size}x{dtype}>" if tile_size > 1 else dtype - return f"math.pow{dtype[0]} %{a}, %{b} : {shape}" + def logical_xor(operand, *args, var_info=None): + raise NotImplementedError("logical_xor") @staticmethod - def log(x, tile_size=16, dtype="f32"): - shape = f"vector<{tile_size}x{dtype}>" if 
tile_size > 1 else dtype - return f'math.log %{x} : {shape}' + def relu(operand, *args, var_info=None): + op_type = var_info[operand] + tile_size = op_type[0] + ret_type = "f32" + return ops.maximum(operand, ops.constant(0.0, "f32")), [tile_size, ret_type] @staticmethod - def reciprocal(a, tile_size=16, dtype="f32"): - return ops.div(ops.constant(1.0, torch.float32), a) + def sigmoid(operand, *args, var_info=None): + op_type = var_info[operand] + tile_size = op_type[0] + ret_type = "f32" + one = ops.constant(1, "f32") + return ops.truediv(one, ops.add(one, ops.exp(ops.neg(operand)))), [tile_size, ret_type] + + # Special operaitons + @staticmethod + def where(condition, operand1, operand2, *args, var_info=None): + tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2, var_info) + cond_type = var_info[condition] + if cond_type[0] < tile_size: + condition = ops.broadcast(condition, operand1, var_info=var_info) + elif cond_type[0] > tile_size: + operand1 = ops.broadcast(operand1, condition, var_info=var_info) + operand2 = ops.broadcast(operand2, condition, var_info=var_info) + tile_size, ret_type = var_info[operand1] + + shape = f"vector<{tile_size}x{ret_type}>" if tile_size > 1 else ret_type + cond_shape = f"vector<{tile_size}xi1>," if tile_size > 1 else "" + return f"arith.select %{condition}, %{operand1}, %{operand2} : {cond_shape} {shape}", [tile_size, ret_type] + + + @staticmethod + def masked(mask, body, other, *args, var_info=None, tile_size=16, dtype="f32", ninf_declared=False): + result = body() + val = ops.constant(0.0, "f32") + result = ops.where(mask, result, val) + return result, var_info[result] + + @staticmethod + def _index_expr(operand, *args, var_info=None, **kwargs): + symbols = sorted([str(i) for i in operand.free_symbols]) + renamed_symbols = {symbol: sympy.Symbol(f"d{i}") for i, symbol in enumerate(symbols)} + + renamed_expression = operand.subs(renamed_symbols) + + affine_map_str = "(" + ", 
".join([f"d{i}" for i in range(len(symbols))]) + ") -> (" + affine_map_str += sympy.printing.ccode(renamed_expression) + ")" + + map_operands = [f"%{str(symbol)}" for symbol in symbols] + return f"affine.apply affine_map<{affine_map_str}>({', '.join(map_operands)})", [1, "index"] + + @staticmethod + def index_expr(operand, *args, var_info=None, **kwargs): + result = ops._index_expr(operand) + ret_type = [1, "index"] + return result, ret_type + + @staticmethod + def index_cast(operand, target_type, *args, var_info=None, **kwrags): + op_type = var_info[operand] + src_shape = f"vector<{op_type[0]}x{op_type[1]}>" if op_type[0] > 1 else op_type[1] + des_shape = f"vector<{op_type[0]}x{target_type}>" if op_type[0] > 1 else target_type + return f"arith.index_cast %{operand} : {src_shape} to {des_shape}", [op_type[0], target_type] + + + @staticmethod + def broadcast(operand1, operand2, *args, var_info=None): + op_type1 = var_info[operand1] + op_type2 = var_info[operand2] + src_shape = f"vector<{op_type1[0]}x{op_type1[1]}>" if op_type1[0] > 1 else op_type1[1] + des_shape = f"vector<{op_type2[0]}x{op_type1[1]}>" if op_type2[0] > 1 else op_type1[1] # Use tile size only + expand = f"vector.broadcast %{operand1} : {src_shape} to {des_shape}" + return expand, [op_type2[0], op_type1[1]] RTYPE_TO_MLIR = { "sum": "add", @@ -325,7 +677,7 @@ def __init__(self): self.reduction_suffix = IndentedBuffer() self.body = IndentedBuffer() self.global_vars = IndentedBuffer() - self.global_vars_set = set() + self.global_vars_dict = dict() self.header = IndentedBuffer() self.gem5_header = IndentedBuffer() self.reduction_vars = {} @@ -563,7 +915,8 @@ def load_epilogue(self, name: str, index: sympy.Expr): shape = f", vector<{tile_size_per_lane}x{type_name}>" if tile_size_per_lane > 1 else "" line = f"{operation} %{buffer}[0, 0] : memref<{self.render_options['TILE_M']}x{self.render_options['TILE_N']}x{type_name}, 1>{shape}" out = self.cse.generate(self.loads, line) - self.tile_info[out] = 
tile_size_per_lane, dtype + var_info = [tile_size_per_lane, mlir_common.DTYPE_TO_MLIR[dtype]] + self.register_var_info(out, var_info) return out def load(self, name: str, index: sympy.Expr): @@ -582,7 +935,7 @@ def load(self, name: str, index: sympy.Expr): dram_tile_shape = f"{tile_shape[0]}x{tile_shape[1]}" # Define scratch pad buffer - buffer, indices = self.get_scratchpad_buffer(dtype, name, self.tile_desc.n_row, self.tile_desc.n_col, dram_tile_shape, self.loads, indices) + buffer, indices = self.get_scratchpad_buffer(dtype, name, self.tile_desc.n_row, self.tile_desc.n_col, dram_tile_shape, self.loads, indices, index) # MVIN Encoding dma_key = (stride, chunk, dtype) if dma_key in self.dma_cache: @@ -604,7 +957,8 @@ def load(self, name: str, index: sympy.Expr): shape = f", vector<{tile_size_per_lane}x{type_name}>" if tile_size_per_lane > 1 else "" line = f"{operation} %{buffer}[0, 0] : memref<{dram_tile_shape}x{type_name}, 1>{shape}" out = self.cse.generate(self.loads, line) - self.tile_info[out] = tile_size_per_lane, dtype + var_info = [tile_size_per_lane, mlir_common.DTYPE_TO_MLIR[dtype]] + self.register_var_info(out, var_info) return out def store_epilogue(self, name: str, index: sympy.Expr, value, *args, **kwargs): @@ -625,7 +979,7 @@ def store_epilogue(self, name: str, index: sympy.Expr, value, *args, **kwargs): buffer = self.buffer_names[name] else: dram_tile_shape = f"{self.render_options['TILE_M']}x{self.render_options['TILE_N']}" - buffer, indices = self.get_scratchpad_buffer(dtype, name, self.render_options['TILE_M'], self.render_options['TILE_N'], dram_tile_shape, self.stores, indices) + buffer, indices = self.get_scratchpad_buffer(dtype, name, self.render_options['TILE_M'], self.render_options['TILE_N'], dram_tile_shape, self.stores, indices, index) self.buffer_names[name] = buffer tile_size_per_lane = self.render_options['TILE_M'] * self.render_options['TILE_N'] // self.vector_lane @@ -655,7 +1009,7 @@ def store(self, name: str, index: sympy.Expr, 
value, *args, **kwargs): dram_tile_shape = f"{tile_shape[0]}x{tile_shape[1]}" # Define scratch pad buffer - buffer, indices = self.get_scratchpad_buffer(dtype, name, self.tile_desc.n_row, self.tile_desc.n_col, dram_tile_shape, self.stores, indices) + buffer, indices = self.get_scratchpad_buffer(dtype, name, self.tile_desc.n_row, self.tile_desc.n_col, dram_tile_shape, self.stores, indices, index) # MVOUT Encoding dmaType = 3 # MVIN 2, MVIN2 1, MVIN3 14, MVOUT 3 @@ -663,9 +1017,11 @@ def store(self, name: str, index: sympy.Expr, value, *args, **kwargs): self.consts.add(stride) self.consts.add(chunk) - store_size = self.tile_info[value][0] + store_size, operand_type = self.var_info[value] operation = "affine.vector_store" if tile_size_per_lane > 1 and store_size > 1 else "affine.store" shape = f", vector<{tile_size_per_lane}x{type_name}>" if tile_size_per_lane > 1 and store_size > 1 else "" + if type_name != operand_type: + value = ops.custom_cast(value, type_name, var_info=self.var_info) line = f"{operation} %{value}, %{buffer}[0, 0] : memref<{dram_tile_shape}x{type_name}, 1>{shape}" self.cse.generate(self.stores, line, assignment = False) @@ -728,13 +1084,14 @@ def reduction(self, dtype, src_dtype, reduction_type, value): init_vec = init axis = "0, 1" acc_var = init - self.tile_info[acc] = 1, dtype + var_info = [1, mlir_common.DTYPE_TO_MLIR[dtype]] else: reduced_shape = f"vector<{vec_len}x{type_name}>" init_vec = self.cse.generate(self.reduction_prefix, f"vector.broadcast %{init} : {type_name} to {reduced_shape}") axis = "0" acc_var = init_vec - self.tile_info[acc] = vec_len, dtype + var_info = [vec_len, mlir_common.DTYPE_TO_MLIR[dtype]] + self.register_var_info(acc, var_info) else: raise NotImplementedError() @@ -761,7 +1118,7 @@ def store_reduction(self, name, index, value): tile_col = self.tile_desc.n_row tile_row = 1 dram_tile_shape = f"{tile_row}x{tile_col}" - buffer, indices = self.get_scratchpad_buffer(dtype, name, tile_row, tile_col, dram_tile_shape, 
self.reductions_suffix, indices) + buffer, indices = self.get_scratchpad_buffer(dtype, name, tile_row, tile_col, dram_tile_shape, self.reductions_suffix, indices, index) if self.welford_reduce_out is not None: # raise NotImplementedError() sum, sqr_sum, _ = self.welford_reduce_out @@ -769,7 +1126,7 @@ def store_reduction(self, name, index, value): # mean divider = self.cse.generate(self.reductions_suffix, f"arith.constant {float(self.ranges[self.reduction_depth])} : f32") if self.buffer_types[name][1] > 1: - divider_vec = self.cse.generate(self.reductions_suffix, f"vector.broadcast %{divider} : f32 to vector<{self.tile_info[sum][0]}x{type_name}>") + divider_vec = self.cse.generate(self.reductions_suffix, f"vector.broadcast %{divider} : f32 to vector<{self.var_info[sum][0]}x{type_name}>") else: divider_vec = f"f{self.buffer_types[name][1]}" mean = self.cse.generate(self.reductions_suffix, f"arith.divf %{sum}, %{divider_vec} : {shape}") @@ -1001,7 +1358,7 @@ def set_ranges(self, lengths, reduction_lengths, read_writes): self.itervars[self.reduction_depth :], ) - def get_scratchpad_buffer(self, dtype, name, tile_row, tile_col, dram_tile_shape, code_buffer, indices): + def get_scratchpad_buffer(self, dtype, name, tile_row, tile_col, dram_tile_shape, code_buffer, indices, raw_index): c_type = mlir_common.DTYPE_TO_C[dtype] mlir_type = mlir_common.DTYPE_TO_MLIR[dtype] # Make sure each lane's buffer has at least two element @@ -1014,13 +1371,17 @@ def get_scratchpad_buffer(self, dtype, name, tile_row, tile_col, dram_tile_shape mapping = self.map_cse.generate(self.global_vars, f"affine_map<({indices}) -> ({indices} floordiv 8)>") indices = self.cse.generate(self.loads, f"affine.apply #{mapping}(%{indices})") # FIXME. Only loads? 
- if name not in self.global_vars_set: + if name not in self.global_vars_dict: + self.global_vars_dict[name] = set() + + if str(raw_index) not in self.global_vars_dict[name]: + new_name = f"{name}_{len(self.global_vars_dict[name])}" # Add definition to header - self.header.writeline(f"{c_type} {name}_spad[{tile_size // self.vector_lane}] __attribute__ ((section(\".spad\")));") - self.gem5_header.writeline(f"{c_type} {name}_spad[{tile_size}];") - self.global_vars_set.add(name) - self.global_vars.writeline(f"memref.global @{name}_spad : memref<{dram_tile_shape}x{mlir_type}, 1>") - buffer = self.cse.generate(code_buffer, f"memref.get_global @{name}_spad : memref<{dram_tile_shape}x{mlir_type}, 1>") + self.header.writeline(f"{c_type} {new_name}_spad[{tile_size // self.vector_lane}] __attribute__ ((section(\".spad\")));") + self.gem5_header.writeline(f"{c_type} {new_name}_spad[{tile_size}];") + self.global_vars.writeline(f"memref.global @{new_name}_spad : memref<{dram_tile_shape}x{mlir_type}, 1>") + self.global_vars_dict[name].add(str(raw_index)) + buffer = self.cse.generate(code_buffer, f"memref.get_global @{new_name}_spad : memref<{dram_tile_shape}x{mlir_type}, 1>") return buffer, indices def roundup_vectorlane(self, size, amp=1): diff --git a/PyTorchSimFrontend/mlir/mlir_common.py b/PyTorchSimFrontend/mlir/mlir_common.py index f76fd0cc..912704b5 100644 --- a/PyTorchSimFrontend/mlir/mlir_common.py +++ b/PyTorchSimFrontend/mlir/mlir_common.py @@ -172,7 +172,7 @@ def __init__(self, args=None): self.tile_col = extension_config.CONFIG_TILE_COL if self.tile_col == -1: self.tile_col = 8 # FIXME: tile_col is not always vector_lane * vlen - self.tile_info = {} + self.var_info = {} def load(self, name: str, index: sympy.Expr): raise NotImplementedError() @@ -193,32 +193,8 @@ def check_dtype_in_args(self, args): dtype = arg return dtype - def expand(self, args, buf_bounds): - cse_args = [arg for arg in args if isinstance(arg, common.CSEVariable)] - if len(cse_args) == 0: - 
return args, 1, self.check_dtype_in_args(args) - elif len(cse_args) == 1: - if not cse_args[0] in self.tile_info: - return args, 1, self.check_dtype_in_args(cse_args) - info = self.tile_info[cse_args[0]] - return args, info[0], info[1] - lhs_idx = args.index(cse_args[-2]) - rhs_idx = args.index(cse_args[-1]) - if not args[lhs_idx] in self.tile_info or not args[rhs_idx] in self.tile_info: - return args, 1, self.check_dtype_in_args(args) - lhs_tile_size, lhs_dtype = self.tile_info[args[lhs_idx]] - rhs_tile_size, rhs_dtype = self.tile_info[args[rhs_idx]] - lhs_shape = f"vector<{lhs_tile_size}x{DTYPE_TO_MLIR[lhs_dtype]}>" if lhs_tile_size > 1 else DTYPE_TO_MLIR[lhs_dtype] - rhs_shape = f"vector<{rhs_tile_size}x{DTYPE_TO_MLIR[rhs_dtype]}>" if rhs_tile_size > 1 else DTYPE_TO_MLIR[rhs_dtype] - temp = list(args) - if lhs_tile_size > rhs_tile_size: - expand = f"vector.broadcast %{args[rhs_idx]} : {rhs_shape} to {lhs_shape}" - temp[rhs_idx] = self.cse.generate(self.compute, expand, bounds=buf_bounds) - elif lhs_tile_size < rhs_tile_size: - expand = f"vector.broadcast %{args[lhs_idx]} : {lhs_shape} to {rhs_shape}" - temp[lhs_idx] = self.cse.generate(self.compute, expand, bounds=buf_bounds) - args = tuple(temp) - return args, max(lhs_tile_size, rhs_tile_size), lhs_dtype + def register_var_info(self, var, var_info): + self.var_info[var] = var_info def __enter__(self): class CSEProxy: @@ -235,13 +211,13 @@ def inner(*args, **kwargs): buf_bounds = self.node_to_bounds.get( fx_node, ValueRanges.unknown() ) - args, tile_size, dtype = self.expand(args, buf_bounds) + code, ret_info = getattr(parent_handler, name)(*args, var_info=self.var_info) csevar = self.cse.generate( self.compute, - getattr(parent_handler, name)(*args, tile_size=tile_size, dtype=DTYPE_TO_MLIR[dtype], **kwargs), # type: ignore[has-type] + code, bounds=buf_bounds, ) - self.tile_info[csevar] = tile_size, dtype + self.register_var_info(csevar, ret_info) csevar.update_on_args(name, args, kwargs) return csevar diff 
--git a/PyTorchSimFrontend/mlir/mlir_lowering.py b/PyTorchSimFrontend/mlir/mlir_lowering.py index c1950f82..e7ca37eb 100644 --- a/PyTorchSimFrontend/mlir/mlir_lowering.py +++ b/PyTorchSimFrontend/mlir/mlir_lowering.py @@ -144,5 +144,4 @@ def custom_maxpool( lowerings.update({getattr(aten.mm, overload): tuned_mm for overload in aten.mm.overloads()}) lowerings.update({getattr(aten.addmm, overload): tuned_addmm for overload in aten.addmm.overloads()}) lowerings.update({getattr(aten.convolution, overload): convolution for overload in aten.convolution.overloads()}) -lowerings.update({getattr(aten.bmm, overload): tuned_bmm for overload in aten.bmm.overloads()}) -lowerings.update({getattr(aten.max_pool2d_with_indices, overload): custom_maxpool for overload in aten.max_pool2d_with_indices.overloads()}) \ No newline at end of file +lowerings.update({getattr(aten.bmm, overload): tuned_bmm for overload in aten.bmm.overloads()}) \ No newline at end of file diff --git a/tests/test_resnet.py b/tests/test_resnet.py index 1fe0c674..37f8a583 100644 --- a/tests/test_resnet.py +++ b/tests/test_resnet.py @@ -14,11 +14,11 @@ def test_result(name, out, cpu_out, rtol=1e-4, atol=1e-4): exit(1) def test_resnet(device): - from torchvision.models import resnet18 - model = resnet18().eval() - model.to(device) + from torchvision.models import resnet + model = resnet._resnet(resnet.BasicBlock, [1, 1, 0, 0], weights=None, progress=False).eval() + model.to(device, memory_format=torch.channels_last) input = torch.randn(1, 3, 224, 224).to(device=device) - x1 = input.to(device=device) + x1 = input.to(device=device, memory_format=torch.channels_last) opt_fn = torch.compile(dynamic=False)(model) res = opt_fn(x1) print("ResNet18 Simulation Done")