From d06d1974951a0f03ecacbab83c29f095be5213f7 Mon Sep 17 00:00:00 2001 From: Wonhyuk Yang Date: Thu, 9 Jan 2025 01:30:50 +0000 Subject: [PATCH 1/2] [Frontend] Make operator return type #96 Currently, we don't track the types of all variables. The return type can vary according to the operation. Now, operations should return their return type. So the type of every variable (SSA form) is tracked and we can support automatic type casting and broadcasting --- .../llvm/llvm_codegen_backend.py | 2 +- .../mlir/mlir_codegen_backend.py | 524 +++++++++++++++--- PyTorchSimFrontend/mlir/mlir_common.py | 36 +- 3 files changed, 440 insertions(+), 122 deletions(-) diff --git a/PyTorchSimFrontend/llvm/llvm_codegen_backend.py b/PyTorchSimFrontend/llvm/llvm_codegen_backend.py index e8daa889..6951b5bd 100644 --- a/PyTorchSimFrontend/llvm/llvm_codegen_backend.py +++ b/PyTorchSimFrontend/llvm/llvm_codegen_backend.py @@ -212,7 +212,7 @@ def maximum(operand1, operand2, tile_size=4): @staticmethod def relu(x, tile_size=4): - return ops.maximum(x, ops.constant(0.0, torch.int32)) + return ops.maximum(x, ops.constant(0.0, "f32")) SYMPY_TO_LLVM = { sympy.core.mul.Mul: "mul", diff --git a/PyTorchSimFrontend/mlir/mlir_codegen_backend.py b/PyTorchSimFrontend/mlir/mlir_codegen_backend.py index 140ba3fd..0bbda189 100644 --- a/PyTorchSimFrontend/mlir/mlir_codegen_backend.py +++ b/PyTorchSimFrontend/mlir/mlir_codegen_backend.py @@ -107,53 +107,131 @@ def write_header(self): ) class ExtensionOverrides(common.OpOverrides): + # Binary element wise operations @staticmethod - def add(operand1, operand2, tile_size=16, dtype="f32"): - shape = f"vector<{tile_size}x{dtype}>" if tile_size > 1 else dtype - return f'arith.add{dtype[0]} %{operand1}, %{operand2} : {shape}' + def binary_elementwise_common(operand1, operand2, var_info): + op_type1 = var_info[operand1] + op_type2 = var_info[operand2] + # Tile size check + if op_type1[0] != op_type2[0]: + # Try to broad cast + lhs_tile_size, lhs_dtype = op_type1 + 
rhs_tile_size, rhs_dtype = op_type2 + if lhs_tile_size > rhs_tile_size: + operand2 = ops.broadcast(operand2, operand1, var_info=var_info) + op_type2 = var_info[operand2] + elif lhs_tile_size < rhs_tile_size: + operand1 = ops.broadcast(operand1, operand2, var_info=var_info) + op_type1 = var_info[operand1] + + # Data type check + if op_type1[1] != op_type2[1]: + if op_type1[1] == "index" or op_type1 == "index": + if op_type1[1] == "index": + operand1 = ops.index_cast(operand1, op_type2[1], var_info) + op_type1 = var_info[operand1] + if op_type2[1] == "index": + operand2 = ops.index_cast(operand2, op_type1[1], var_info) + op_type2 = var_info[operand2] + elif op_type1[1][0] == "i" and op_type2[1][0] == "f": + operand1 = ops.to_dtype(operand1, op_type2[1], var_info) + op_type1 = var_info[operand1] + elif op_type1[1][0] == "f" and op_type2[1][0] == "i": + operand2 = ops.to_dtype(operand2, op_type1[1], var_info) + op_type2 = var_info[operand2] + else: + raise NotImplementedError("Unsupported type converting") + + # Updated var info + tile_size = op_type1[0] + ret_type = op_type1[1] + return tile_size, ret_type, operand1, operand2 @staticmethod - def sub(operand1, operand2, tile_size=16, dtype="f32"): - shape = f"vector<{tile_size}x{dtype}>" if tile_size > 1 else dtype - return f'arith.sub{dtype[0]} %{operand1}, %{operand2} : {shape}' + def add(operand1, operand2, *args, var_info=None): + tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2, var_info) + shape = f"vector<{tile_size}x{ret_type}>" if tile_size > 1 else ret_type + opcode = f'arith.add{ret_type[0]}' + return f'{opcode} %{operand1}, %{operand2} : {shape}', [tile_size, ret_type] @staticmethod - def mul(operand1, operand2, tile_size=16, dtype="f32"): - shape = f"vector<{tile_size}x{dtype}>" if tile_size > 1 else dtype - return f'arith.mul{dtype[0]} %{operand1}, %{operand2} : {shape}' + def sub(operand1, operand2, *args, var_info=None): + tile_size, ret_type, 
operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2, var_info) + shape = f"vector<{tile_size}x{ret_type}>" if tile_size > 1 else ret_type + opcode = f'arith.sub{ret_type[0]}' + return f'{opcode} %{operand1}, %{operand2} : {shape}', [tile_size, ret_type] @staticmethod - def div(operand1, operand2, tile_size=16, dtype="f32"): - shape = f"vector<{tile_size}x{dtype}>" if tile_size > 1 else dtype - return f'arith.div{dtype[0]} %{operand1}, %{operand2} : {shape}' + def mul(operand1, operand2, *args, var_info=None): + tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2, var_info) + shape = f"vector<{tile_size}x{ret_type}>" if tile_size > 1 else ret_type + opcode = f'arith.mul{ret_type[0]}' + return f'{opcode} %{operand1}, %{operand2} : {shape}', [tile_size, ret_type] @staticmethod - def truediv(operand1, operand2, tile_size=16, dtype="f32"): - shape = f"vector<{tile_size}x{dtype}>" if tile_size > 1 else dtype - return f'arith.div{dtype[0]} %{operand1}, %{operand2} : {shape}' + def div(operand1, operand2, *args, var_info=None): + tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2, var_info) + shape = f"vector<{tile_size}x{ret_type}>" if tile_size > 1 else ret_type + if ret_type[0] == "f": + opcode = f'arith.divf' + else: + opcode = f'arith.divui' + return f'{opcode} %{operand1}, %{operand2} : {shape}', [tile_size, ret_type] + + @staticmethod + def truediv(operand1, operand2, *args, var_info=None): + tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2, var_info) + shape = f"vector<{tile_size}x{ret_type}>" if tile_size > 1 else ret_type + if ret_type[0] == "f": + opcode = f'arith.divf' + else: + opcode = f'arith.divui' + return f'{opcode} %{operand1}, %{operand2} : {shape}', [tile_size, ret_type] @staticmethod - def to_dtype(x, dst_type, src_dtype=None, tile_size=16, dtype="f32"): - 
mlir_dtype = mlir_common.DTYPE_TO_MLIR[dst_type] - src_mlir_dtype = mlir_common.DTYPE_TO_MLIR[src_dtype] + def minimum(operand1, operand2, *args, var_info=None): + tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2, var_info) + shape = f"vector<{tile_size}x{ret_type}>" if tile_size > 1 else ret_type + if ret_type[0] == "f": + opcode = f'arith.minimumf' + else: + opcode = f'arith.minimumui' + return f'{opcode} %{operand1}, %{operand2} : {shape}', [tile_size, ret_type] - dst_bits = 1 if dst_type == torch.bool else torch.finfo(dst_type).bits if dst_type.is_floating_point else torch.iinfo(dst_type).bits - src_bits = 1 if src_dtype == torch.bool else torch.finfo(src_dtype).bits if src_dtype.is_floating_point else torch.iinfo(src_dtype).bits - shape = f"vector<{tile_size}x{mlir_dtype}>" if tile_size > 1 else mlir_dtype + @staticmethod + def maximum(operand1, operand2, *args, var_info=None): + tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2, var_info) + shape = f"vector<{tile_size}x{ret_type}>" if tile_size > 1 else ret_type + if ret_type[0] == "f": + opcode = f'arith.maximumf' + else: + opcode = f'arith.maximumui' + return f'{opcode} %{operand1}, %{operand2} : {shape}', [tile_size, ret_type] + + @staticmethod + def to_dtype(operand, dst_mlir_dtype, *args, var_info=None): + src_mlir_dtype = var_info[operand][1] + tile_size = var_info[operand][0] + + dst_bits = int(dst_mlir_dtype[1:]) + src_bits = int(src_mlir_dtype[1:]) + shape = f"vector<{tile_size}x{dst_mlir_dtype}>" if tile_size > 1 else dst_mlir_dtype src_shape = f"vector<{tile_size}x{src_mlir_dtype}>" if tile_size > 1 else src_mlir_dtype - if dst_type.is_floating_point and not src_dtype.is_floating_point: + if dst_mlir_dtype[0] == "i" and src_mlir_dtype[0] == "f": raise NotImplementedError("floating point to integer conversion") - elif not dst_type.is_floating_point and src_dtype.is_floating_point: + if 
dst_mlir_dtype[0] == "f" and src_mlir_dtype[0] == "i": raise NotImplementedError("integer to floating point conversion") else: if dst_bits > src_bits: - return f"arith.extui %{x} : {src_shape} to {shape}" + return f"arith.extui %{operand} : {src_shape} to {shape}" elif dst_bits < src_bits: - return f"arith.trunc %{x} : {src_shape} to {shape}" + return f"arith.trunc %{operand} : {src_shape} to {shape}" @staticmethod - def constant(value, src_type, tile_size=16, dtype="f32"): - src_type = mlir_common.DTYPE_TO_MLIR[src_type] + def constant(value, src_type, *args, var_info=None): + if isinstance(src_type, torch.dtype): + src_type = mlir_common.DTYPE_TO_MLIR[src_type] + # if value represented by e notation, convert to float (ex 1e-3 -> 1.0e-3) if "e" in str(value): value = float(value) @@ -161,89 +239,346 @@ def constant(value, src_type, tile_size=16, dtype="f32"): value = format(value, ".20f") if src_type[0] == "i": value = int(value) - return f'arith.constant {value} : {src_type}' + return f'arith.constant {value} : {src_type}', [1, src_type] + # transcendental functions @staticmethod - def exp(operand, tile_size=16, dtype="f32"): + def exp(operand, *args, var_info=None): + op_type = var_info[operand] + tile_size = op_type[0] + dtype = op_type[1] + shape = f"vector<{tile_size}x{dtype}>" if tile_size > 1 else dtype - return f'math.exp %{operand} : {shape}' + return f'math.exp %{operand} : {shape}', [tile_size, dtype] @staticmethod - def maximum(operand1, operand2, tile_size=16, dtype="f32"): + def sqrt(operand, *args, var_info=None): + op_type = var_info[operand] + tile_size = op_type[0] + dtype = op_type[1] + + # Type check & auto cast + if dtype[0] != "f": + operand, dtype = ops.to_dtype(operand, "f32", var_info=var_info) + var_info[operand] = dtype + shape = f"vector<{tile_size}x{dtype}>" if tile_size > 1 else dtype - return f'arith.maximum{dtype[0]} %{operand1}, %{operand2} : {shape}' + return f'math.sqrt %{operand} : {shape}', [tile_size, dtype] @staticmethod - 
def sqrt(x, tile_size=16, dtype="f32"): + def rsqrt(operand, *args, var_info=None): + op_type = var_info[operand] + tile_size = op_type[0] + dtype = op_type[1] + + # Type check & auto cast + if dtype[0] != "f": + operand, dtype = ops.to_dtype(operand, "f32", var_info=var_info) + var_info[operand] = dtype + shape = f"vector<{tile_size}x{dtype}>" if tile_size > 1 else dtype - return f'math.sqrt %{x} : {shape}' + return f'math.rsqrt %{operand} : {shape}', [tile_size, dtype] @staticmethod - def ne(operand1, operand2, tile_size=16, dtype="f32"): - shape = f"vector<{tile_size}x{dtype}>" if tile_size > 1 else "i1" - return f'arith.cmp{dtype[0]} one, %{operand1}, %{operand2} : {shape}' + def pow(operand1, operand2, *args, var_info=None): + op_type1 = var_info[operand1] + op_type2 = var_info[operand2] + + # Type check & auto cast + if op_type1[1][0] != "f": + operand1, dtype = ops.to_dtype(operand1, "f32", var_info=var_info) + var_info[operand1] = dtype + + # Type check & auto cast + if op_type2[1][0] != "f": + operand2, dtype = ops.to_dtype(operand2, "f32", var_info=var_info) + var_info[operand2] = dtype + + op_type1 = var_info[operand1] + tile_size = op_type1[0] + dtype = op_type1[1] + + shape = f"vector<{tile_size}x{dtype}>" if tile_size > 1 else dtype + return f"math.pow{dtype[0]} %{operand1}, %{operand2} : {shape}", [] @staticmethod - def lt(operand1, operand2, tile_size=16, dtype="f32"): - shape = f"vector<{tile_size}x{dtype}>" if tile_size > 1 else "i1" - return f'arith.cmp{dtype[0]} olt, %{operand1}, %{operand2} : {shape}' + def log(operand, *args, var_info=None): + op_type = var_info[operand] + tile_size = op_type[0] + dtype = op_type[1] + + # Type check & auto cast + if dtype[0] != "f": + operand, dtype = ops.to_dtype(operand, "f32", var_info=var_info) + var_info[operand] = dtype + + shape = f"vector<{tile_size}x{dtype}>" if tile_size > 1 else dtype + return f'math.log %{operand} : {shape}', [tile_size, dtype] @staticmethod - def gt(operand1, operand2, 
tile_size=16, dtype="f32"): - shape = f"vector<{tile_size}x{dtype}>" if tile_size > 1 else "i1" - return f'arith.cmp{dtype[0]} ogt, %{operand1}, %{operand2} : {shape}' + def reciprocal(operand, *args, var_info): + op_type = var_info[operand] + tile_size = op_type[0] + dtype = op_type[1] + + # Type check & auto cast + if dtype[0] != "f": + operand, dtype = ops.to_dtype(operand, "f32", var_info=var_info) + var_info[operand] = dtype + return ops.div(ops.constant(1.0, dtype), operand), [tile_size, dtype] + + # Logical operations @staticmethod - def le(operand1, operand2, tile_size=16, dtype="f32"): + def neg(operand, *args, var_info=None): + op_type = var_info[operand] + tile_size = op_type[0] + dtype = op_type[1] + + # Type check & auto cast + if dtype[0] != "f": + operand, dtype = ops.to_dtype(operand, "f32", var_info=var_info) + var_info[operand] = dtype + shape = f"vector<{tile_size}x{dtype}>" if tile_size > 1 else dtype - return f'arith.cmp{dtype[0]} ole, %{operand1}, %{operand2} : {shape}' + return f'arith.negf %{operand} : {shape}', [tile_size, dtype] @staticmethod - def relu(x, tile_size=16, dtype=None): - return ops.maximum(x, ops.constant(0.0, torch.float32)) + def eq(operand1, operand2, *args, var_info=None): + tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2, var_info) + if ret_type[0] == "f": + op_type = "arith.cmpf" + attribute = "oeq" + elif ret_type[0] == "i": + op_type = "arith.cmpi" + attribute = "eq" + else: + raise ValueError(f"Unsupported data type for 'eq' operation: {ret_type}") + + shape = f"vector<{tile_size}x{ret_type}>" if tile_size > 1 else ret_type + return f'{op_type} {attribute}, %{operand1}, %{operand2} : {shape}', [tile_size, "i1"] @staticmethod - def sigmoid(x, tile_size=16, dtype=None): - one = ops.constant(1, torch.float32) - return ops.truediv(one, ops.add(one, ops.exp(ops.neg(x)))) + def ne(operand1, operand2, *args, var_info=None): + tile_size, ret_type, operand1, operand2 = 
ExtensionOverrides.binary_elementwise_common(operand1, operand2, var_info) + if ret_type[0] == "f": + op_type = "arith.cmpf" + attribute = "one" + elif ret_type[0] == "i": + op_type = "arith.cmpi" + attribute = "sne" + else: + raise ValueError(f"Unsupported data type for 'ne' operation: {ret_type}") + + shape = f"vector<{tile_size}x{ret_type}>" if tile_size > 1 else ret_type + return f'{op_type} {attribute}, %{operand1}, %{operand2} : {shape}', [tile_size, "i1"] @staticmethod - def neg(x, tile_size=16, dtype="f32"): - shape = f"vector<{tile_size}x{dtype}>" if tile_size > 1 else dtype - return f'arith.neg{dtype[0]} %{x} : {shape}' + def lt(operand1, operand2, *args, var_info=None): + tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2, var_info) + if ret_type[0] == "f": + op_type = "arith.cmpf" + attribute = "olt" + elif ret_type[0] == "i": + op_type = "arith.cmpi" + attribute = "slt" + else: + raise ValueError(f"Unsupported data type for 'lt' operation: {ret_type}") + + shape = f"vector<{tile_size}x{ret_type}>" if tile_size > 1 else ret_type + return f'{op_type} {attribute}, %{operand1}, %{operand2} : {shape}', [tile_size, "i1"] @staticmethod - def where(condition, x, y, tile_size=16, dtype="f32"): - shape = f"vector<{tile_size}x{dtype}>" if tile_size > 1 else dtype - cond_shape = f"vector<{tile_size}xi1>," if tile_size > 1 else "" - return f"arith.select %{condition}, %{x}, %{y} : {cond_shape} {shape}" + def gt(operand1, operand2, *args, var_info=None): + tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2, var_info) + if ret_type[0] == "f": + op_type = "arith.cmpf" + attribute = "ogt" + elif ret_type[0] == "i": + op_type = "arith.cmpi" + attribute = "sgt" + else: + raise ValueError(f"Unsupported data type for 'gt' operation: {ret_type}") + + shape = f"vector<{tile_size}x{ret_type}>" if tile_size > 1 else ret_type + return f'{op_type} {attribute}, 
%{operand1}, %{operand2} : {shape}', [tile_size, "i1"] @staticmethod - def logical_not(operand, tile_size=16, dtype="f32"): - tile_size=16 - shape = f"vector<{tile_size}x{dtype}>" if tile_size > 1 else dtype - result_shape = f"vector<{tile_size}xi1>" if tile_size > 1 else "i1" + def le(operand1, operand2, *args, var_info=None): + tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2, var_info) + if ret_type[0] == "f": + op_type = "arith.cmpf" + attribute = "ole" + elif ret_type[0] == "i": + op_type = "arith.cmpi" + attribute = "sle" + else: + raise ValueError(f"Unsupported data type for 'le' operation: {ret_type}") + + shape = f"vector<{tile_size}x{ret_type}>" if tile_size > 1 else ret_type + return f'{op_type} {attribute}, %{operand1}, %{operand2} : {shape}', [tile_size, "i1"] + + @staticmethod + def ge(operand1, operand2, *args, var_info=None): + tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2, var_info) + if ret_type[0] == "f": + op_type = "arith.cmpf" + attribute = "oge" + elif ret_type[0] == "i": + op_type = "arith.cmpi" + attribute = "sge" + else: + raise ValueError(f"Unsupported data type for 'ne' operation: {ret_type}") + + shape = f"vector<{tile_size}x{ret_type}>" if tile_size > 1 else ret_type + return f'{op_type} {attribute}, %{operand1}, %{operand2} : {shape}', [tile_size, "i1"] + + @staticmethod + def and_(operand1, operand2, *args, var_info=None): + op_type1 = var_info[operand1] + op_type2 = var_info[operand2] + + # Type check & auto cast + if op_type1[1][0] != "i": + operand1, dtype = ops.to_dtype(operand1, "i32", var_info=var_info) + var_info[operand1] = dtype + + # Type check & auto cast + if op_type2[1][0] != "i": + operand1, dtype = ops.to_dtype(operand1, "i32", var_info=var_info) + var_info[operand2] = dtype + + ret_type = op_type1[1] + tile_size = op_type1[0] + + shape = f"vector<{tile_size}x{ret_type}>" if tile_size > 1 else ret_type + 
return f'arith.andi %{operand1}, %{operand2} : {shape}', [tile_size, ret_type] + + @staticmethod + def or_(operand1, operand2, *args, var_info=None): + op_type1 = var_info[operand1] + op_type2 = var_info[operand2] + + # Type check & auto cast + if op_type1[1][0] != "i": + operand1, dtype = ops.to_dtype(operand1, "i32", var_info=var_info) + var_info[operand1] = dtype + + # Type check & auto cast + if op_type2[1][0] != "i": + operand1, dtype = ops.to_dtype(operand1, "i32", var_info=var_info) + var_info[operand2] = dtype + + ret_type = op_type1[1] + tile_size = op_type1[0] + + shape = f"vector<{tile_size}x{ret_type}>" if tile_size > 1 else ret_type + return f'arith.ori %{operand1}, %{operand2} : {shape}', [tile_size, ret_type] + + @staticmethod + def xor(operand1, operand2, *args, var_info=None): + op_type1 = var_info[operand1] + op_type2 = var_info[operand2] + + # Type check & auto cast + if op_type1[1][0] != "i": + operand1, dtype = ops.to_dtype(operand1, "i32", var_info=var_info) + var_info[operand1] = dtype + + # Type check & auto cast + if op_type2[1][0] != "i": + operand1, dtype = ops.to_dtype(operand1, "i32", var_info=var_info) + var_info[operand2] = dtype + + ret_type = op_type1[1] + tile_size = op_type1[0] + + shape = f"vector<{tile_size}x{ret_type}>" if tile_size > 1 else ret_type + return f'arith.xori %{operand1}, %{operand2} : {shape}', [tile_size, ret_type] + + + @staticmethod + def logical_and(operand, *args, var_info=None): + raise NotImplementedError("logical_and") + + @staticmethod + def logical_not(operand, *args, var_info=None): raise NotImplementedError("logical_not") - return f"arith.cmp{dtype[0]} oeq, %{operand}, %zero_vec{tile_size} : {shape} -> {result_shape}" @staticmethod - def rsqrt(x, tile_size=16, dtype="f32"): - shape = f"vector<{tile_size}x{dtype}>" if tile_size > 1 else dtype - return f'math.rsqrt %{x} : {shape}' + def logical_or(operand, *args, var_info=None): + raise NotImplementedError("logical_not") @staticmethod - def pow(a, b, 
tile_size=16, dtype="f32"): - shape = f"vector<{tile_size}x{dtype}>" if tile_size > 1 else dtype - return f"math.pow{dtype[0]} %{a}, %{b} : {shape}" + def logical_xor(operand, *args, var_info=None): + raise NotImplementedError("logical_not") @staticmethod - def log(x, tile_size=16, dtype="f32"): - shape = f"vector<{tile_size}x{dtype}>" if tile_size > 1 else dtype - return f'math.log %{x} : {shape}' + def relu(operand, *args, var_info=None): + op_type = var_info[operand] + tile_size = op_type[0] + ret_type = "f32" + return ops.maximum(operand, ops.constant(0.0, "f32")), [tile_size, ret_type] + + @staticmethod + def sigmoid(operand, *args, var_info=None): + op_type = var_info[operand] + tile_size = op_type[0] + ret_type = "f32" + one = ops.constant(1, "f32") + return ops.truediv(one, ops.add(one, ops.exp(ops.neg(operand)))), [tile_size, ret_type] + + # Special operaitons + @staticmethod + def where(condition, operand1, operand2, *args, var_info=None): + tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2, var_info) + cond_type = var_info[condition] + if cond_type[0] != tile_size: + condition = ops.broadcast(condition, operand1, var_info=var_info) + + shape = f"vector<{tile_size}x{ret_type}>" if tile_size > 1 else ret_type + cond_shape = f"vector<{tile_size}xi1>," if tile_size > 1 else "" + return f"arith.select %{condition}, %{operand1}, %{operand2} : {cond_shape} {shape}", [tile_size, ret_type] + @staticmethod - def reciprocal(a, tile_size=16, dtype="f32"): - return ops.div(ops.constant(1.0, torch.float32), a) + def masked(mask, body, other, *args, var_info=None, tile_size=16, dtype="f32", ninf_declared=False): + result = body() + val = ops.constant(0.0, "f32") + result = ops.where(mask, result, val) + return result, var_info[result] + + @staticmethod + def _index_expr(operand, *args, var_info=None, **kwargs): + symbols = sorted(operand.free_symbols) + renamed_symbols = {symbol: sympy.Symbol(f"d{i}") for i, 
symbol in enumerate(symbols)} + + renamed_expression = operand.subs(renamed_symbols) + + affine_map_str = "(" + ", ".join([f"d{i}" for i in range(len(symbols))]) + ") -> (" + affine_map_str += sympy.printing.ccode(renamed_expression) + ")" + + map_operands = [f"%{str(symbol)}" for symbol in symbols] + return f"affine.apply affine_map<{affine_map_str}>({', '.join(map_operands)})", [1, "index"] + + @staticmethod + def index_expr(operand, *args, var_info=None, **kwargs): + result = ops._index_expr(operand) + ret_type = [1, "index"] + return result, ret_type + + @staticmethod + def index_cast(operand, target_type, *args, var_info=None, **kwrags): + return f"arith.index_cast %{operand} : index to {target_type}", [1, target_type] + + @staticmethod + def broadcast(operand1, operand2, *args, var_info=None): + op_type1 = var_info[operand1] + op_type2 = var_info[operand2] + src_shape = f"vector<{op_type1[0]}x{op_type1[1]}>" if op_type1[0] > 1 else op_type1[1] + des_shape = f"vector<{op_type2[0]}x{op_type1[1]}>" if op_type2[0] > 1 else op_type1[1] # Use tile size only + expand = f"vector.broadcast %{operand1} : {src_shape} to {des_shape}" + return expand, op_type2 RTYPE_TO_MLIR = { "sum": "add", @@ -325,7 +660,7 @@ def __init__(self): self.reduction_suffix = IndentedBuffer() self.body = IndentedBuffer() self.global_vars = IndentedBuffer() - self.global_vars_set = set() + self.global_vars_dict = dict() self.header = IndentedBuffer() self.gem5_header = IndentedBuffer() self.reduction_vars = {} @@ -563,7 +898,8 @@ def load_epilogue(self, name: str, index: sympy.Expr): shape = f", vector<{tile_size_per_lane}x{type_name}>" if tile_size_per_lane > 1 else "" line = f"{operation} %{buffer}[0, 0] : memref<{self.render_options['TILE_M']}x{self.render_options['TILE_N']}x{type_name}, 1>{shape}" out = self.cse.generate(self.loads, line) - self.tile_info[out] = tile_size_per_lane, dtype + var_info = [tile_size_per_lane, mlir_common.DTYPE_TO_MLIR[dtype]] + self.register_var_info(out, 
var_info) return out def load(self, name: str, index: sympy.Expr): @@ -582,7 +918,7 @@ def load(self, name: str, index: sympy.Expr): dram_tile_shape = f"{tile_shape[0]}x{tile_shape[1]}" # Define scratch pad buffer - buffer, indices = self.get_scratchpad_buffer(dtype, name, self.tile_desc.n_row, self.tile_desc.n_col, dram_tile_shape, self.loads, indices) + buffer, indices = self.get_scratchpad_buffer(dtype, name, self.tile_desc.n_row, self.tile_desc.n_col, dram_tile_shape, self.loads, indices, index) # MVIN Encoding dma_key = (stride, chunk, dtype) if dma_key in self.dma_cache: @@ -604,7 +940,8 @@ def load(self, name: str, index: sympy.Expr): shape = f", vector<{tile_size_per_lane}x{type_name}>" if tile_size_per_lane > 1 else "" line = f"{operation} %{buffer}[0, 0] : memref<{dram_tile_shape}x{type_name}, 1>{shape}" out = self.cse.generate(self.loads, line) - self.tile_info[out] = tile_size_per_lane, dtype + var_info = [tile_size_per_lane, mlir_common.DTYPE_TO_MLIR[dtype]] + self.register_var_info(out, var_info) return out def store_epilogue(self, name: str, index: sympy.Expr, value, *args, **kwargs): @@ -625,7 +962,7 @@ def store_epilogue(self, name: str, index: sympy.Expr, value, *args, **kwargs): buffer = self.buffer_names[name] else: dram_tile_shape = f"{self.render_options['TILE_M']}x{self.render_options['TILE_N']}" - buffer, indices = self.get_scratchpad_buffer(dtype, name, self.render_options['TILE_M'], self.render_options['TILE_N'], dram_tile_shape, self.stores, indices) + buffer, indices = self.get_scratchpad_buffer(dtype, name, self.render_options['TILE_M'], self.render_options['TILE_N'], dram_tile_shape, self.stores, indices, index) self.buffer_names[name] = buffer tile_size_per_lane = self.render_options['TILE_M'] * self.render_options['TILE_N'] // self.vector_lane @@ -655,7 +992,7 @@ def store(self, name: str, index: sympy.Expr, value, *args, **kwargs): dram_tile_shape = f"{tile_shape[0]}x{tile_shape[1]}" # Define scratch pad buffer - buffer, indices = 
self.get_scratchpad_buffer(dtype, name, self.tile_desc.n_row, self.tile_desc.n_col, dram_tile_shape, self.stores, indices) + buffer, indices = self.get_scratchpad_buffer(dtype, name, self.tile_desc.n_row, self.tile_desc.n_col, dram_tile_shape, self.stores, indices, index) # MVOUT Encoding dmaType = 3 # MVIN 2, MVIN2 1, MVIN3 14, MVOUT 3 @@ -663,7 +1000,7 @@ def store(self, name: str, index: sympy.Expr, value, *args, **kwargs): self.consts.add(stride) self.consts.add(chunk) - store_size = self.tile_info[value][0] + store_size = self.var_info[value][0] operation = "affine.vector_store" if tile_size_per_lane > 1 and store_size > 1 else "affine.store" shape = f", vector<{tile_size_per_lane}x{type_name}>" if tile_size_per_lane > 1 and store_size > 1 else "" @@ -728,13 +1065,14 @@ def reduction(self, dtype, src_dtype, reduction_type, value): init_vec = init axis = "0, 1" acc_var = init - self.tile_info[acc] = 1, dtype + var_info = [1, mlir_common.DTYPE_TO_MLIR[dtype]] else: reduced_shape = f"vector<{vec_len}x{type_name}>" init_vec = self.cse.generate(self.reduction_prefix, f"vector.broadcast %{init} : {type_name} to {reduced_shape}") axis = "0" acc_var = init_vec - self.tile_info[acc] = vec_len, dtype + var_info = [vec_len, mlir_common.DTYPE_TO_MLIR[dtype]] + self.register_var_info(acc, var_info) else: raise NotImplementedError() @@ -761,7 +1099,7 @@ def store_reduction(self, name, index, value): tile_col = self.tile_desc.n_row tile_row = 1 dram_tile_shape = f"{tile_row}x{tile_col}" - buffer, indices = self.get_scratchpad_buffer(dtype, name, tile_row, tile_col, dram_tile_shape, self.reductions_suffix, indices) + buffer, indices = self.get_scratchpad_buffer(dtype, name, tile_row, tile_col, dram_tile_shape, self.reductions_suffix, indices, index) if self.welford_reduce_out is not None: # raise NotImplementedError() sum, sqr_sum, _ = self.welford_reduce_out @@ -769,7 +1107,7 @@ def store_reduction(self, name, index, value): # mean divider = 
self.cse.generate(self.reductions_suffix, f"arith.constant {float(self.ranges[self.reduction_depth])} : f32") if self.buffer_types[name][1] > 1: - divider_vec = self.cse.generate(self.reductions_suffix, f"vector.broadcast %{divider} : f32 to vector<{self.tile_info[sum][0]}x{type_name}>") + divider_vec = self.cse.generate(self.reductions_suffix, f"vector.broadcast %{divider} : f32 to vector<{self.var_info[sum][0]}x{type_name}>") else: divider_vec = f"f{self.buffer_types[name][1]}" mean = self.cse.generate(self.reductions_suffix, f"arith.divf %{sum}, %{divider_vec} : {shape}") @@ -1001,7 +1339,7 @@ def set_ranges(self, lengths, reduction_lengths, read_writes): self.itervars[self.reduction_depth :], ) - def get_scratchpad_buffer(self, dtype, name, tile_row, tile_col, dram_tile_shape, code_buffer, indices): + def get_scratchpad_buffer(self, dtype, name, tile_row, tile_col, dram_tile_shape, code_buffer, indices, raw_index): c_type = mlir_common.DTYPE_TO_C[dtype] mlir_type = mlir_common.DTYPE_TO_MLIR[dtype] # Make sure each lane's buffer has at least two element @@ -1014,13 +1352,17 @@ def get_scratchpad_buffer(self, dtype, name, tile_row, tile_col, dram_tile_shape mapping = self.map_cse.generate(self.global_vars, f"affine_map<({indices}) -> ({indices} floordiv 8)>") indices = self.cse.generate(self.loads, f"affine.apply #{mapping}(%{indices})") # FIXME. Only loads? 
- if name not in self.global_vars_set: + if name not in self.global_vars_dict: + self.global_vars_dict[name] = set() + + if str(raw_index) not in self.global_vars_dict[name]: + new_name = f"{name}_{len(self.global_vars_dict[name])}" # Add definition to header - self.header.writeline(f"{c_type} {name}_spad[{tile_size // self.vector_lane}] __attribute__ ((section(\".spad\")));") - self.gem5_header.writeline(f"{c_type} {name}_spad[{tile_size}];") - self.global_vars_set.add(name) - self.global_vars.writeline(f"memref.global @{name}_spad : memref<{dram_tile_shape}x{mlir_type}, 1>") - buffer = self.cse.generate(code_buffer, f"memref.get_global @{name}_spad : memref<{dram_tile_shape}x{mlir_type}, 1>") + self.header.writeline(f"{c_type} {new_name}_spad[{tile_size // self.vector_lane}] __attribute__ ((section(\".spad\")));") + self.gem5_header.writeline(f"{c_type} {new_name}_spad[{tile_size}];") + self.global_vars.writeline(f"memref.global @{new_name}_spad : memref<{dram_tile_shape}x{mlir_type}, 1>") + self.global_vars_dict[name].add(str(raw_index)) + buffer = self.cse.generate(code_buffer, f"memref.get_global @{new_name}_spad : memref<{dram_tile_shape}x{mlir_type}, 1>") return buffer, indices def roundup_vectorlane(self, size, amp=1): diff --git a/PyTorchSimFrontend/mlir/mlir_common.py b/PyTorchSimFrontend/mlir/mlir_common.py index f76fd0cc..912704b5 100644 --- a/PyTorchSimFrontend/mlir/mlir_common.py +++ b/PyTorchSimFrontend/mlir/mlir_common.py @@ -172,7 +172,7 @@ def __init__(self, args=None): self.tile_col = extension_config.CONFIG_TILE_COL if self.tile_col == -1: self.tile_col = 8 # FIXME: tile_col is not always vector_lane * vlen - self.tile_info = {} + self.var_info = {} def load(self, name: str, index: sympy.Expr): raise NotImplementedError() @@ -193,32 +193,8 @@ def check_dtype_in_args(self, args): dtype = arg return dtype - def expand(self, args, buf_bounds): - cse_args = [arg for arg in args if isinstance(arg, common.CSEVariable)] - if len(cse_args) == 0: - 
return args, 1, self.check_dtype_in_args(args) - elif len(cse_args) == 1: - if not cse_args[0] in self.tile_info: - return args, 1, self.check_dtype_in_args(cse_args) - info = self.tile_info[cse_args[0]] - return args, info[0], info[1] - lhs_idx = args.index(cse_args[-2]) - rhs_idx = args.index(cse_args[-1]) - if not args[lhs_idx] in self.tile_info or not args[rhs_idx] in self.tile_info: - return args, 1, self.check_dtype_in_args(args) - lhs_tile_size, lhs_dtype = self.tile_info[args[lhs_idx]] - rhs_tile_size, rhs_dtype = self.tile_info[args[rhs_idx]] - lhs_shape = f"vector<{lhs_tile_size}x{DTYPE_TO_MLIR[lhs_dtype]}>" if lhs_tile_size > 1 else DTYPE_TO_MLIR[lhs_dtype] - rhs_shape = f"vector<{rhs_tile_size}x{DTYPE_TO_MLIR[rhs_dtype]}>" if rhs_tile_size > 1 else DTYPE_TO_MLIR[rhs_dtype] - temp = list(args) - if lhs_tile_size > rhs_tile_size: - expand = f"vector.broadcast %{args[rhs_idx]} : {rhs_shape} to {lhs_shape}" - temp[rhs_idx] = self.cse.generate(self.compute, expand, bounds=buf_bounds) - elif lhs_tile_size < rhs_tile_size: - expand = f"vector.broadcast %{args[lhs_idx]} : {lhs_shape} to {rhs_shape}" - temp[lhs_idx] = self.cse.generate(self.compute, expand, bounds=buf_bounds) - args = tuple(temp) - return args, max(lhs_tile_size, rhs_tile_size), lhs_dtype + def register_var_info(self, var, var_info): + self.var_info[var] = var_info def __enter__(self): class CSEProxy: @@ -235,13 +211,13 @@ def inner(*args, **kwargs): buf_bounds = self.node_to_bounds.get( fx_node, ValueRanges.unknown() ) - args, tile_size, dtype = self.expand(args, buf_bounds) + code, ret_info = getattr(parent_handler, name)(*args, var_info=self.var_info) csevar = self.cse.generate( self.compute, - getattr(parent_handler, name)(*args, tile_size=tile_size, dtype=DTYPE_TO_MLIR[dtype], **kwargs), # type: ignore[has-type] + code, bounds=buf_bounds, ) - self.tile_info[csevar] = tile_size, dtype + self.register_var_info(csevar, ret_info) csevar.update_on_args(name, args, kwargs) return csevar From 
f90d46d9b8071b65e78e99553bc91d9af1f47740 Mon Sep 17 00:00:00 2001 From: Wonhyuk Yang Date: Thu, 9 Jan 2025 04:44:43 +0000 Subject: [PATCH 2/2] [Frontend] Make custom pooling disabled --- .../mlir/mlir_codegen_backend.py | 29 +++++++++++++++---- PyTorchSimFrontend/mlir/mlir_lowering.py | 3 +- tests/test_resnet.py | 8 ++--- 3 files changed, 29 insertions(+), 11 deletions(-) diff --git a/PyTorchSimFrontend/mlir/mlir_codegen_backend.py b/PyTorchSimFrontend/mlir/mlir_codegen_backend.py index 0bbda189..3d65aa53 100644 --- a/PyTorchSimFrontend/mlir/mlir_codegen_backend.py +++ b/PyTorchSimFrontend/mlir/mlir_codegen_backend.py @@ -108,6 +108,15 @@ def write_header(self): class ExtensionOverrides(common.OpOverrides): # Binary element wise operations + @staticmethod + def custom_cast(operand, target_type, *args, var_info=None): + dtype = var_info[operand][1] + if dtype == "index": + ret = ops.index_cast(operand, target_type, var_info=var_info) + else: + ret = ops.to_dtype(operand, target_type, var_info=var_info) + return ret, var_info[ret] + @staticmethod def binary_elementwise_common(operand1, operand2, var_info): op_type1 = var_info[operand1] @@ -533,8 +542,12 @@ def sigmoid(operand, *args, var_info=None): def where(condition, operand1, operand2, *args, var_info=None): tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2, var_info) cond_type = var_info[condition] - if cond_type[0] != tile_size: + if cond_type[0] < tile_size: condition = ops.broadcast(condition, operand1, var_info=var_info) + elif cond_type[0] > tile_size: + operand1 = ops.broadcast(operand1, condition, var_info=var_info) + operand2 = ops.broadcast(operand2, condition, var_info=var_info) + tile_size, ret_type = var_info[operand1] shape = f"vector<{tile_size}x{ret_type}>" if tile_size > 1 else ret_type cond_shape = f"vector<{tile_size}xi1>," if tile_size > 1 else "" @@ -550,7 +563,7 @@ def masked(mask, body, other, *args, var_info=None, tile_size=16, 
dtype="f32", n @staticmethod def _index_expr(operand, *args, var_info=None, **kwargs): - symbols = sorted(operand.free_symbols) + symbols = sorted([str(i) for i in operand.free_symbols]) renamed_symbols = {symbol: sympy.Symbol(f"d{i}") for i, symbol in enumerate(symbols)} renamed_expression = operand.subs(renamed_symbols) @@ -569,7 +582,11 @@ def index_expr(operand, *args, var_info=None, **kwargs): @staticmethod def index_cast(operand, target_type, *args, var_info=None, **kwrags): - return f"arith.index_cast %{operand} : index to {target_type}", [1, target_type] + op_type = var_info[operand] + src_shape = f"vector<{op_type[0]}x{op_type[1]}>" if op_type[0] > 1 else op_type[1] + des_shape = f"vector<{op_type[0]}x{target_type}>" if op_type[0] > 1 else target_type + return f"arith.index_cast %{operand} : {src_shape} to {des_shape}", [op_type[0], target_type] + @staticmethod def broadcast(operand1, operand2, *args, var_info=None): @@ -578,7 +595,7 @@ def broadcast(operand1, operand2, *args, var_info=None): src_shape = f"vector<{op_type1[0]}x{op_type1[1]}>" if op_type1[0] > 1 else op_type1[1] des_shape = f"vector<{op_type2[0]}x{op_type1[1]}>" if op_type2[0] > 1 else op_type1[1] # Use tile size only expand = f"vector.broadcast %{operand1} : {src_shape} to {des_shape}" - return expand, op_type2 + return expand, [op_type2[0], op_type1[1]] RTYPE_TO_MLIR = { "sum": "add", @@ -1000,9 +1017,11 @@ def store(self, name: str, index: sympy.Expr, value, *args, **kwargs): self.consts.add(stride) self.consts.add(chunk) - store_size = self.var_info[value][0] + store_size, operand_type = self.var_info[value] operation = "affine.vector_store" if tile_size_per_lane > 1 and store_size > 1 else "affine.store" shape = f", vector<{tile_size_per_lane}x{type_name}>" if tile_size_per_lane > 1 and store_size > 1 else "" + if type_name != operand_type: + value = ops.custom_cast(value, type_name, var_info=self.var_info) line = f"{operation} %{value}, %{buffer}[0, 0] : 
memref<{dram_tile_shape}x{type_name}, 1>{shape}" self.cse.generate(self.stores, line, assignment = False) diff --git a/PyTorchSimFrontend/mlir/mlir_lowering.py b/PyTorchSimFrontend/mlir/mlir_lowering.py index c1950f82..e7ca37eb 100644 --- a/PyTorchSimFrontend/mlir/mlir_lowering.py +++ b/PyTorchSimFrontend/mlir/mlir_lowering.py @@ -144,5 +144,4 @@ def custom_maxpool( lowerings.update({getattr(aten.mm, overload): tuned_mm for overload in aten.mm.overloads()}) lowerings.update({getattr(aten.addmm, overload): tuned_addmm for overload in aten.addmm.overloads()}) lowerings.update({getattr(aten.convolution, overload): convolution for overload in aten.convolution.overloads()}) -lowerings.update({getattr(aten.bmm, overload): tuned_bmm for overload in aten.bmm.overloads()}) -lowerings.update({getattr(aten.max_pool2d_with_indices, overload): custom_maxpool for overload in aten.max_pool2d_with_indices.overloads()}) \ No newline at end of file +lowerings.update({getattr(aten.bmm, overload): tuned_bmm for overload in aten.bmm.overloads()}) \ No newline at end of file diff --git a/tests/test_resnet.py b/tests/test_resnet.py index 1fe0c674..37f8a583 100644 --- a/tests/test_resnet.py +++ b/tests/test_resnet.py @@ -14,11 +14,11 @@ def test_result(name, out, cpu_out, rtol=1e-4, atol=1e-4): exit(1) def test_resnet(device): - from torchvision.models import resnet18 - model = resnet18().eval() - model.to(device) + from torchvision.models import resnet + model = resnet._resnet(resnet.BasicBlock, [1, 1, 0, 0], weights=None, progress=False).eval() + model.to(device, memory_format=torch.channels_last) input = torch.randn(1, 3, 224, 224).to(device=device) - x1 = input.to(device=device) + x1 = input.to(device=device, memory_format=torch.channels_last) opt_fn = torch.compile(dynamic=False)(model) res = opt_fn(x1) print("ResNet18 Simulation Done")