diff --git a/ftn/tools/ftn_opt.py b/ftn/tools/ftn_opt.py index 931f854..2a90677 100755 --- a/ftn/tools/ftn_opt.py +++ b/ftn/tools/ftn_opt.py @@ -5,8 +5,8 @@ from ftn.transforms.rewrite_fir_to_core import RewriteFIRToCore from ftn.transforms.merge_memref_deref import MergeMemRefDeref +from ftn.transforms.extract_target import ExtractTarget from ftn.transforms.lower_omp_target_data import LowerOmpTargetDataPass -# from ftn.transforms.extract_target import ExtractTarget # from ftn.transforms.isolate_target import IsolateTarget # from psy.extract_stencil import ExtractStencil # from ftn.transforms.tenstorrent.convert_to_tt import ConvertToTT @@ -27,8 +27,8 @@ def register_all_passes(self): super().register_all_passes() self.register_pass("rewrite-fir-to-core", lambda: RewriteFIRToCore) self.register_pass("merge-memref-deref", lambda: MergeMemRefDeref) + self.register_pass("extract-target", lambda: ExtractTarget) self.register_pass("lower-omp-target-data", lambda: LowerOmpTargetDataPass) - # self.register_pass("extract-target", lambda: ExtractTarget) # self.register_pass("isolate-target", lambda: IsolateTarget) # self.register_pass("convert-to-tt", lambda: ConvertToTT) diff --git a/ftn/transforms/extract_target.py b/ftn/transforms/extract_target.py index a84eb8c..8300612 100644 --- a/ftn/transforms/extract_target.py +++ b/ftn/transforms/extract_target.py @@ -1,10 +1,13 @@ from abc import ABC +from ast import Module +from hmac import new from typing import TypeVar, cast -from dataclasses import dataclass +from dataclasses import dataclass, field import itertools from xdsl.utils.hints import isa from xdsl.dialects import memref, scf, omp -from xdsl.ir import Operation, SSAValue, OpResult, Attribute, MLContext, Block, Region +from xdsl.context import Context +from xdsl.ir import Operation, SSAValue, OpResult, Attribute, Block, Region from xdsl.pattern_rewriter import (RewritePattern, PatternRewriter, op_type_rewrite_pattern, @@ -13,10 +16,12 @@ from xdsl.passes import 
ModulePass from xdsl.dialects import builtin, func, llvm, arith from ftn.util.visitor import Visitor +from xdsl.rewriter import InsertPoint +@dataclass class RewriteTarget(RewritePattern): - def __init__(self): - self.target_ops=[] + module : builtin.ModuleOp + target_ops: list[Operation] = field(default_factory=list) @op_type_rewrite_pattern def match_and_rewrite(self, op: omp.TargetOp, rewriter: PatternRewriter, /): @@ -31,38 +36,38 @@ def match_and_rewrite(self, op: omp.TargetOp, rewriter: PatternRewriter, /): # Grab bounds and info, then at end the terminator for var in op.map_vars: var_op=var.owner - var_op.parent.detach_op(var_op) - arg_types.append(var_op.var_ptr[0].type) - arg_ssa.append(var_op.var_ptr[0]) - if isa(var_op.var_ptr[0].type, builtin.MemRefType): - memref_type=var_op.var_ptr[0].type - src_memref=var_op.var_ptr[0] + arg_types.append(var_op.var_ptr.type) + arg_ssa.append(var_op.var_ptr) + locations[var_op]=loc_idx + loc_idx+=1 + if isa(var_op.var_ptr.type, builtin.MemRefType): + memref_type=var_op.var_ptr.type + src_memref=var_op.var_ptr if isa(memref_type.element_type, builtin.MemRefType): assert len(memref_type.shape) == 0 - memref_type=var_op.var_ptr[0].type.element_type - memref_loadop=memref.Load.get(src_memref, []) + memref_type=var_op.var_ptr.type.element_type + memref_loadop=memref.LoadOp.get(src_memref, []) src_memref=memref_loadop.results[0] memref_dim_ops.append(memref_loadop) for idx, s in enumerate(memref_type.shape): assert isa(s, builtin.IntAttr) if (s.data == -1): # Need to pass the dimension shape size in explicitly as it is deferred - const_op=arith.Constant.from_int_and_width(idx, builtin.IndexType()) - dim_size=memref.Dim.from_source_and_index(src_memref, const_op) + const_op=arith.ConstantOp.from_int_and_width(idx, builtin.IndexType()) + dim_size=memref.DimOp.from_source_and_index(src_memref, const_op) memref_dim_ops+=[const_op, dim_size] arg_ssa.append(dim_size.results[0]) arg_types.append(dim_size.results[0].type) + 
loc_idx+=1 - locations[var_op]=loc_idx - loc_idx+=1 if len(var_op.bounds) > 0: bound_op=var_op.bounds[0].owner - bound_op.parent.detach_op(bound_op) - #self.target_ops+=[bound_op, var_op] - arg_types.append(bound_op.lower[0].type) - arg_ssa.append(bound_op.lower[0]) + arg_types.append(bound_op.lower_bound.type) + arg_ssa.append(bound_op.lower_bound) + arg_types.append(bound_op.upper_bound.type) + arg_ssa.append(bound_op.upper_bound) locations[bound_op]=loc_idx - # Add two, as second is the size + # Adding both lower and upper bound loc_idx+=2 else: pass#self.target_ops+=[var_op] @@ -78,7 +83,7 @@ def match_and_rewrite(self, op: omp.TargetOp, rewriter: PatternRewriter, /): bound_op=var_op.bounds[0].owner res_types=[] for res in bound_op.results: res_types.append(res.type) - new_bounds_op=omp.BoundsOp.build(operands=[[new_block.args[locations[bound_op]]], [], [], [], []], + new_bounds_op=omp.MapBoundsOp.build(operands=[[new_block.args[locations[bound_op]]], [new_block.args[locations[bound_op]+1]], [], [], []], properties={"stride_in_bytes": bound_op.stride_in_bytes}, result_types=res_types) @@ -87,8 +92,8 @@ def match_and_rewrite(self, op: omp.TargetOp, rewriter: PatternRewriter, /): res_types=[] for res in var_op.results: res_types.append(res.type) - mapinfo_op=omp.MapInfoOp.build(operands=[[new_block.args[locations[var_op]]], [], map_bounds], - properties={"map_type": var_op.map_type, "var_name": var_op.var_name, "var_type": var_op.var_type}, + mapinfo_op=omp.MapInfoOp.build(operands=[new_block.args[locations[var_op]], [], [], map_bounds], + properties={"map_type": var_op.map_type, "name": var_op.var_name, "var_type": var_op.var_type, "map_capture_type": omp.VariableCaptureKindAttr(omp.VariableCaptureKind.BY_REF)}, result_types=res_types) new_mapinfo_ssa.append(mapinfo_op.results[0]) @@ -97,9 +102,9 @@ def match_and_rewrite(self, op: omp.TargetOp, rewriter: PatternRewriter, /): reg=op.region op.detach_region(reg) - 
new_omp_target_op=omp.TargetOp.build(operands=[[],[],[], new_mapinfo_ssa], regions=[reg]) + new_omp_target_op=omp.TargetOp.build(operands=[[],[],[],[],[],[],[],[],[], new_mapinfo_ssa, [], []], regions=[reg]) new_block.add_op(new_omp_target_op) - new_block.add_op(func.Return()) + new_block.add_op(func.ReturnOp()) new_fn_type=builtin.FunctionType.from_lists(arg_types, []) @@ -107,11 +112,13 @@ def match_and_rewrite(self, op: omp.TargetOp, rewriter: PatternRewriter, /): body.add_block(new_block) new_func=func.FuncOp("tt_device", new_fn_type, body) + new_func_signature=func.FuncOp.external("tt_device", new_fn_type.inputs.data, new_fn_type.outputs.data) self.target_ops=[new_func] - call_fn=func.Call.create(properties={"callee": builtin.SymbolRefAttr("tt_device")}, operands=arg_ssa, result_types=[]) + call_fn=func.CallOp.create(properties={"callee": builtin.SymbolRefAttr("tt_device")}, operands=arg_ssa, result_types=[]) op.parent.insert_ops_before(memref_dim_ops+[call_fn], op) + rewriter.insert_op(new_func_signature, InsertPoint.at_start(self.module.body.block)) op.parent.detach_op(op) @@ -122,15 +129,17 @@ class ExtractTarget(ModulePass): """ name = 'extract-target' - def apply(self, ctx: MLContext, module: builtin.ModuleOp): - rw_target= RewriteTarget() + def apply(self, ctx: Context, module: builtin.ModuleOp): + rw_target= RewriteTarget(module) walker = PatternRewriteWalker(GreedyRewritePatternApplier([ rw_target, ]), apply_recursively=False, walk_reverse=True) walker.rewrite_module(module) - containing_mod=builtin.ModuleOp([]) + # NOTE: The region receiving the block must be empty. Otherwise, the single block region rule of + # the module will not be satisfied. + containing_mod=builtin.ModuleOp(Region()) module.regions[0].move_blocks(containing_mod.regions[0]) new_module=builtin.ModuleOp(rw_target.target_ops, {"target": builtin.StringAttr("tt_device")})