diff --git a/.github/workflows/pull-request.yml b/.github/workflows/pull-request.yml index 3dbb3e36..bc5c9dab 100644 --- a/.github/workflows/pull-request.yml +++ b/.github/workflows/pull-request.yml @@ -493,12 +493,7 @@ jobs: -v /tmp/torchsim-ci/${GITHUB_SHA}:/dump \ -e TORCHSIM_DUMP_PATH=/dump \ ghcr.io/psal-postech/torchsim-ci:${GITHUB_SHA} python3 PyTorchSim/tests/Fusion/test_addmm_residual.py - - name: Log in to GitHub Container Registry - uses: docker/login-action@v3 - with: - registry: ghcr.io - username: ${{ github.actor }} - password: ${{ secrets.GIT_ACCESS_TOKEN }} + - name: Run test_matmul_activation.py env: GIT_ACCESS_TOKEN: ${{ secrets.GIT_ACCESS_TOKEN }} @@ -508,12 +503,7 @@ jobs: -v /tmp/torchsim-ci/${GITHUB_SHA}:/dump \ -e TORCHSIM_DUMP_PATH=/dump \ ghcr.io/psal-postech/torchsim-ci:${GITHUB_SHA} python3 PyTorchSim/tests/Fusion/test_matmul_activation.py - - name: Log in to GitHub Container Registry - uses: docker/login-action@v3 - with: - registry: ghcr.io - username: ${{ github.actor }} - password: ${{ secrets.GIT_ACCESS_TOKEN }} + - name: Run test_matmul_scalar.py env: GIT_ACCESS_TOKEN: ${{ secrets.GIT_ACCESS_TOKEN }} @@ -523,12 +513,47 @@ jobs: -v /tmp/torchsim-ci/${GITHUB_SHA}:/dump \ -e TORCHSIM_DUMP_PATH=/dump \ ghcr.io/psal-postech/torchsim-ci:${GITHUB_SHA} python3 PyTorchSim/tests/Fusion/test_matmul_scalar.py - - name: Log in to GitHub Container Registry - uses: docker/login-action@v3 - with: - registry: ghcr.io - username: ${{ github.actor }} - password: ${{ secrets.GIT_ACCESS_TOKEN }} + + - name: Run test_matmul_reduction.py + env: + GIT_ACCESS_TOKEN: ${{ secrets.GIT_ACCESS_TOKEN }} + run: | + echo "Running test_matmul_reduction.py" + docker run --rm \ + -v /tmp/torchsim-ci/${GITHUB_SHA}:/dump \ + -e TORCHSIM_DUMP_PATH=/dump \ + ghcr.io/psal-postech/torchsim-ci:${GITHUB_SHA} python3 PyTorchSim/tests/Fusion/test_matmul_reduction.py + + - name: Run test_bmm_reduction.py + env: + GIT_ACCESS_TOKEN: ${{ secrets.GIT_ACCESS_TOKEN }} + run: | + echo 
"Running test_bmm_reduction.py" + docker run --rm \ + -v /tmp/torchsim-ci/${GITHUB_SHA}:/dump \ + -e TORCHSIM_DUMP_PATH=/dump \ + ghcr.io/psal-postech/torchsim-ci:${GITHUB_SHA} python3 PyTorchSim/tests/Fusion/test_bmm_reduction.py + + - name: Run test_prologue_fusion.py + env: + GIT_ACCESS_TOKEN: ${{ secrets.GIT_ACCESS_TOKEN }} + run: | + echo "Running test_prologue_fusion.py" + docker run --rm \ + -v /tmp/torchsim-ci/${GITHUB_SHA}:/dump \ + -e TORCHSIM_DUMP_PATH=/dump \ + ghcr.io/psal-postech/torchsim-ci:${GITHUB_SHA} python3 PyTorchSim/tests/Fusion/test_prologue_fusion.py + + - name: Run test_transformer_fusion.py + env: + GIT_ACCESS_TOKEN: ${{ secrets.GIT_ACCESS_TOKEN }} + run: | + echo "Running test_transformer_fusion.py" + docker run --rm \ + -v /tmp/torchsim-ci/${GITHUB_SHA}:/dump \ + -e TORCHSIM_DUMP_PATH=/dump \ + ghcr.io/psal-postech/torchsim-ci:${GITHUB_SHA} python3 PyTorchSim/tests/Fusion/test_transformer_fusion.py + - name: Run test_conv_fusion.py env: GIT_ACCESS_TOKEN: ${{ secrets.GIT_ACCESS_TOKEN }} diff --git a/.github/workflows/pull-request_mobile.yml b/.github/workflows/pull-request_mobile.yml index 945bac3b..0043eaf4 100644 --- a/.github/workflows/pull-request_mobile.yml +++ b/.github/workflows/pull-request_mobile.yml @@ -493,12 +493,7 @@ jobs: -v /tmp/torchsim-ci/${GITHUB_SHA}:/dump \ -e TORCHSIM_DUMP_PATH=/dump -e TORCHSIM_VECTOR_LANE=8 -e TORCHSIM_SPAD_SIZE=32 \ ghcr.io/psal-postech/torchsim-ci:${GITHUB_SHA} python3 PyTorchSim/tests/Fusion/test_addmm_residual.py - - name: Log in to GitHub Container Registry - uses: docker/login-action@v3 - with: - registry: ghcr.io - username: ${{ github.actor }} - password: ${{ secrets.GIT_ACCESS_TOKEN }} + - name: Run test_matmul_activation.py env: GIT_ACCESS_TOKEN: ${{ secrets.GIT_ACCESS_TOKEN }} @@ -508,12 +503,7 @@ jobs: -v /tmp/torchsim-ci/${GITHUB_SHA}:/dump \ -e TORCHSIM_DUMP_PATH=/dump -e TORCHSIM_VECTOR_LANE=8 -e TORCHSIM_SPAD_SIZE=32 \ ghcr.io/psal-postech/torchsim-ci:${GITHUB_SHA} python3 
PyTorchSim/tests/Fusion/test_matmul_activation.py - - name: Log in to GitHub Container Registry - uses: docker/login-action@v3 - with: - registry: ghcr.io - username: ${{ github.actor }} - password: ${{ secrets.GIT_ACCESS_TOKEN }} + - name: Run test_matmul_scalar.py env: GIT_ACCESS_TOKEN: ${{ secrets.GIT_ACCESS_TOKEN }} @@ -523,12 +513,7 @@ jobs: -v /tmp/torchsim-ci/${GITHUB_SHA}:/dump \ -e TORCHSIM_DUMP_PATH=/dump -e TORCHSIM_VECTOR_LANE=8 -e TORCHSIM_SPAD_SIZE=32 \ ghcr.io/psal-postech/torchsim-ci:${GITHUB_SHA} python3 PyTorchSim/tests/Fusion/test_matmul_scalar.py - - name: Log in to GitHub Container Registry - uses: docker/login-action@v3 - with: - registry: ghcr.io - username: ${{ github.actor }} - password: ${{ secrets.GIT_ACCESS_TOKEN }} + - name: Run test_conv_fusion.py env: GIT_ACCESS_TOKEN: ${{ secrets.GIT_ACCESS_TOKEN }} @@ -539,6 +524,46 @@ jobs: -e TORCHSIM_DUMP_PATH=/dump -e TORCHSIM_VECTOR_LANE=8 -e TORCHSIM_SPAD_SIZE=32 \ ghcr.io/psal-postech/torchsim-ci:${GITHUB_SHA} python3 PyTorchSim/tests/Fusion/test_conv_fusion.py + - name: Run test_matmul_reduction.py + env: + GIT_ACCESS_TOKEN: ${{ secrets.GIT_ACCESS_TOKEN }} + run: | + echo "Running test_matmul_reduction.py" + docker run --rm \ + -v /tmp/torchsim-ci/${GITHUB_SHA}:/dump \ + -e TORCHSIM_DUMP_PATH=/dump \ + ghcr.io/psal-postech/torchsim-ci:${GITHUB_SHA} python3 PyTorchSim/tests/Fusion/test_matmul_reduction.py + + - name: Run test_bmm_reduction.py + env: + GIT_ACCESS_TOKEN: ${{ secrets.GIT_ACCESS_TOKEN }} + run: | + echo "Running test_bmm_reduction.py" + docker run --rm \ + -v /tmp/torchsim-ci/${GITHUB_SHA}:/dump \ + -e TORCHSIM_DUMP_PATH=/dump \ + ghcr.io/psal-postech/torchsim-ci:${GITHUB_SHA} python3 PyTorchSim/tests/Fusion/test_bmm_reduction.py + + - name: Run test_prologue_fusion.py + env: + GIT_ACCESS_TOKEN: ${{ secrets.GIT_ACCESS_TOKEN }} + run: | + echo "Running test_prologue_fusion.py" + docker run --rm \ + -v /tmp/torchsim-ci/${GITHUB_SHA}:/dump \ + -e TORCHSIM_DUMP_PATH=/dump \ + 
ghcr.io/psal-postech/torchsim-ci:${GITHUB_SHA} python3 PyTorchSim/tests/Fusion/test_prologue_fusion.py + + - name: Run test_transformer_fusion.py + env: + GIT_ACCESS_TOKEN: ${{ secrets.GIT_ACCESS_TOKEN }} + run: | + echo "Running test_transformer_fusion.py" + docker run --rm \ + -v /tmp/torchsim-ci/${GITHUB_SHA}:/dump \ + -e TORCHSIM_DUMP_PATH=/dump \ + ghcr.io/psal-postech/torchsim-ci:${GITHUB_SHA} python3 PyTorchSim/tests/Fusion/test_transformer_fusion.py + test_moe: name: Run test_moe runs-on: self-hosted diff --git a/AsmParser/onnx_utility.py b/AsmParser/onnx_utility.py index d46e8347..4f76ef35 100644 --- a/AsmParser/onnx_utility.py +++ b/AsmParser/onnx_utility.py @@ -66,12 +66,13 @@ def __init__(self, tile_info, inst_list=list(), node_id=0): super().__init__(node_id) self.inst = inst_list self.torchsim_base_addr = tile_info["base_addr"] - self.torchsim_stride_list = tile_info["stride_list"] self.torchsim_tile_size = tile_info["tile_size"] + self.torchsim_tile_stride = tile_info["tile_stride"] self.torchsim_element_size = tile_info["element_size"] self.torchsim_tag_idx_list = tile_info["tag_idx_list"] self.torchsim_tag_stride_list = tile_info["tag_stride_list"] self.torchsim_loop_idx_list = tile_info["loop_idx_list"] + self.torchsim_loop_stride_list = tile_info["loop_stride_list"] self.torchsim_is_async = tile_info["is_async"] self.torchsim_indirect_mode = tile_info["indirect_mode"] diff --git a/AsmParser/tog_generator.py b/AsmParser/tog_generator.py index 1dea2f8d..5f586d99 100644 --- a/AsmParser/tog_generator.py +++ b/AsmParser/tog_generator.py @@ -91,12 +91,13 @@ def _create_node(self, dump_data): elif node_type == self.DMANodeKind: tile_info = {} tile_info["base_addr"] = dump_data["base_address"] - tile_info["stride_list"] = dump_data["stride_list"] tile_info["tile_size"] = dump_data["tile_size"] + tile_info["tile_stride"] = dump_data["tile_stride"] tile_info["element_size"] = dump_data["element_size"] tile_info["tag_idx_list"] = dump_data["tag_idx_list"] 
tile_info["tag_stride_list"] = dump_data["tag_stride_list"] tile_info["loop_idx_list"] = dump_data["loop_idx_list"] + tile_info["loop_stride_list"] = dump_data["loop_stride_list"] tile_info["is_async"] = dump_data["is_async"] tile_info["indirect_mode"] = dump_data["indirect_mode"] is_write = dump_data["is_write"] diff --git a/PyTorchSimBackend/include/Instruction.h b/PyTorchSimBackend/include/Instruction.h index 84b17d7c..4c14dd81 100644 --- a/PyTorchSimBackend/include/Instruction.h +++ b/PyTorchSimBackend/include/Instruction.h @@ -22,9 +22,10 @@ std::string opcode_to_string(Opcode opcode); class Instruction : public std::enable_shared_from_this { public: Instruction(Opcode opcode, cycle_type compute_cycle, size_t num_parents, addr_type dram_addr, - std::vector tile_size, size_t precision, std::vector &idx_list, - std::vector &stride_list, std::vector tag_idx_list, std::vector tag_stride_list, - std::vector accum_tag_idx_list, std::vector loop_size_list); + std::vector tile_size, std::vector tile_stride, size_t precision, + std::vector tag_idx_list, std::vector tag_stride_list, + std::vector accum_tag_idx_list); + Instruction(Opcode opcode); void finish_instruction(); void add_child(std::shared_ptr child); bool check_ready() { return ready_counter == 0; } @@ -60,10 +61,6 @@ class Instruction : public std::enable_shared_from_this { bool load_indirect_index(const std::string& path, uint64_t*& indirect_index, const std::vector& tile_size); void set_trace_address(std::vector& trace_address) { _trace_address = trace_address; } size_t get_free_sram_size() { return _free_sram_size; } - void adjust_dram_address() { - int offset = std::inner_product(_idx_list.begin(), _idx_list.end(), _stride_list.begin(), 0); - dram_addr += offset * _precision; - } addr_type get_base_dram_address() { return dram_addr; } void set_free_sram_size(size_t sram_size) { _free_sram_size=sram_size; } void* get_owner() { return _owner; } @@ -73,7 +70,6 @@ class Instruction : public 
std::enable_shared_from_this { int get_compute_type() { return _compute_type; } void set_numa_id(int numa_id) { _numa_id = numa_id; } uint32_t get_numa_id() { return _numa_id; } - std::vector& get_idx_list() { return _idx_list; } std::vector& get_tag_idx_list() { return _tag_idx_list; } std::vector& get_tag_stride_list() { return _tag_stride_list; } std::vector& get_tag_id() { return _tag_key; } @@ -103,6 +99,7 @@ class Instruction : public std::enable_shared_from_this { size_t ready_counter; std::set> child_inst; std::vector tile_size; + std::vector tile_stride; size_t _tile_numel; size_t _nr_waiting_request=0; size_t _precision=0; @@ -110,13 +107,10 @@ class Instruction : public std::enable_shared_from_this { addr_type dram_addr; uint32_t _numa_id = 0; // For DMA instruction int _compute_type = 0; - std::vector _idx_list; - std::vector _stride_list; std::vector _tag_idx_list; std::vector _tag_stride_list; std::vector _tag_key; std::vector _accum_tag_idx_list; - std::vector _loop_size_list; std::vector _trace_address; std::string _addr_name; int _addr_id; diff --git a/PyTorchSimBackend/include/TileGraphParser.h b/PyTorchSimBackend/include/TileGraphParser.h index b5322b76..5b561127 100644 --- a/PyTorchSimBackend/include/TileGraphParser.h +++ b/PyTorchSimBackend/include/TileGraphParser.h @@ -175,17 +175,18 @@ class TileMemoryNode : public TileNode { std::string get_base_addr_name() { return _base_addr_name; } size_t get_precision() { return _element_size; } std::vector get_tile_size() { return _tile_size; } - std::vector& get_stride_list () { return _stride_list; } + std::vector& get_tile_stride() { return _tile_stride; } std::vector& get_tag_idx_list() { return _tag_idx_list; } std::vector& get_tag_stride_list() { return _tag_stride_list; } std::vector& get_loop_idx_list() { return _loop_idx_list; } + std::vector& get_loop_stride_list () { return _loop_stride_list; } bool is_async_node() { return _is_async; } bool is_indirect() { return _is_indirect; } void 
print_node() override; private: std::vector _tile_size; - std::vector _stride_list; + std::vector _tile_stride; size_t _element_size; bool _is_async; bool _is_indirect; @@ -193,6 +194,7 @@ class TileMemoryNode : public TileNode { std::vector _tag_idx_list; std::vector _tag_stride_list; std::vector _loop_idx_list; + std::vector _loop_stride_list; }; class TileMemoryWaitNode : public TileNode { diff --git a/PyTorchSimBackend/src/Instruction.cc b/PyTorchSimBackend/src/Instruction.cc index b706ca8f..aef9079c 100644 --- a/PyTorchSimBackend/src/Instruction.cc +++ b/PyTorchSimBackend/src/Instruction.cc @@ -11,23 +11,22 @@ std::string opcode_to_string(Opcode opcode) { } Instruction::Instruction(Opcode opcode, cycle_type compute_cycle, size_t num_parents, - addr_type dram_addr, std::vector tile_size, size_t precision, - std::vector& idx_list, std::vector& stride_list, + addr_type dram_addr, std::vector tile_size, std::vector tile_stride, size_t precision, std::vector tag_idx_list, std::vector tag_stride_list, - std::vector accum_tag_idx_list, std::vector loop_size_list) + std::vector accum_tag_idx_list) : opcode(opcode), compute_cycle(compute_cycle), ready_counter(num_parents), dram_addr(dram_addr), - tile_size(tile_size), _precision(precision), _idx_list(idx_list), - _stride_list(stride_list), _tag_idx_list(tag_idx_list), _tag_stride_list(tag_stride_list), - _accum_tag_idx_list(accum_tag_idx_list), _loop_size_list(loop_size_list) { + tile_size(tile_size), tile_stride(tile_stride), _precision(precision), + _tag_idx_list(tag_idx_list), _tag_stride_list(tag_stride_list), + _accum_tag_idx_list(accum_tag_idx_list) { assert(_tag_idx_list.size()==_tag_stride_list.size()); _tile_numel = 1; for (auto dim : tile_size) _tile_numel *= dim; +} - /* Supporting vector */ - if (_stride_list.size() == 1) { - _stride_list.push_back(1); - } +Instruction::Instruction(Opcode opcode) + : opcode(opcode) { + _tile_numel = 1; } void Instruction::finish_instruction() { @@ -73,8 +72,8 @@ 
std::shared_ptr> Instruction::get_dram_address(addr_type dra while (tile_size.size() < 4) tile_size.insert(tile_size.begin(), 1); - while (_stride_list.size() < 4) - _stride_list.insert(_stride_list.begin(), 0); + while (tile_stride.size() < 4) + tile_stride.insert(tile_stride.begin(), 0); if (_is_indirect_mode) { spdlog::trace("[Indirect Access] Indirect mode, dump_path: {}", _indirect_index_path); load_indirect_index(_indirect_index_path, indirect_index, tile_size); @@ -85,10 +84,10 @@ std::shared_ptr> Instruction::get_dram_address(addr_type dra for (int dim1=0; dim1> TileLoopNode::get_tiles_from_iter(TileGraphPa for (auto& tile_node: _body_node) { if (tile_node->get_type() == TileType::LOAD_NODE) { std::shared_ptr mem_node = std::static_pointer_cast(tile_node); - auto base_addr_name = mem_node->get_base_addr_name(); - int base_addr_id = tog_parser->register_addr_name(base_addr_name); - std::vector& tag_idx_list = mem_node->get_tag_idx_list(); - std::vector& tag_stride_list = mem_node->get_tag_stride_list(); - std::vector skip_idx_list; - std::vector values; - - /* Lookup given name's address */ - addr_type base_addr = tog_parser->lookup(base_addr_name); std::vector iter_list; - std::vector tag_list; - std::vector accum_tag_list; - std::vector loop_size_list; - std::vector outer_loop_idx; - std::vector outer_loop_size; int nr_inner_loop = 0; auto& loop_idx_list = mem_node->get_loop_idx_list(); for (auto loop_idx: loop_idx_list) { - auto iter_value = getLoopIndexValue(iter, loop_idx); + int iter_value = getLoopIndexValue(iter, loop_idx); iter_list.push_back(iter_value); - loop_size_list.push_back(tog_parser->get_loop_size(loop_idx)); if (tog_parser->get_loop_type(loop_idx)==LoopType::INNER_LOOP) nr_inner_loop++; } + + /* Base address setting */ + std::string base_addr_name = mem_node->get_base_addr_name(); + int base_addr_id = tog_parser->register_addr_name(base_addr_name); + addr_type base_addr = tog_parser->lookup(base_addr_name); + addr_type offset = 
std::inner_product(iter_list.begin(), iter_list.end(), mem_node->get_loop_stride_list().begin(), 0); + + std::vector tag_list; + std::vector accum_tag_list; + std::vector outer_loop_idx; + std::vector outer_loop_size; /* Add accumulation loop info to accum_tag list */ for (auto loop_idx = loop_idx_list.begin(); loop_idx != loop_idx_list.end() - nr_inner_loop; ++loop_idx) { @@ -375,6 +375,10 @@ std::vector> TileLoopNode::get_tiles_from_iter(TileGraphPa accum_tag_list.push_back(iter_value); } } + /* Default accum tag */ + if (accum_tag_list.empty()) { + accum_tag_list.push_back(0); + } for (auto loop_idx = loop_idx_list.begin(); loop_idx != loop_idx_list.end(); ++loop_idx) { @@ -387,7 +391,7 @@ std::vector> TileLoopNode::get_tiles_from_iter(TileGraphPa } uint32_t systolic_size = std::stoi(tog_parser->getMetaByName("systolic_size")); - for (auto loop_idx: tag_idx_list) { + for (auto loop_idx: mem_node->get_tag_idx_list()) { if (iter.find(loop_idx) == iter.end()) tag_list.push_back(0); else { @@ -406,25 +410,32 @@ std::vector> TileLoopNode::get_tiles_from_iter(TileGraphPa int stride_idx = calculateAddress(outer_loop_size, tog_parser->lookupNumaInfo(base_addr_name)); numa_id = total_idx / stride_idx; } + /* Check need to make this memory node */ + std::vector& tag_stride_list = mem_node->get_tag_stride_list(); std::vector key = tog_parser->calc_tag(accum_tag_list, tag_list, tag_stride_list); if (tog_parser->check_memory_tag(base_addr_name, key)) continue; tog_parser->register_memory_tag(base_addr_name, key); printIndexMap("[TOGParser] Load Node " + mem_node->get_base_addr_name() + " Numa_id: " + std::to_string(numa_id), iter); + spdlog::trace("[TOGParser] Load Node {} key = [{}], accum = [{}], tag = [{}], stride = [{}]", mem_node->get_base_addr_name(), + fmt::join(key, ", "), + fmt::join(accum_tag_list, ", "), + fmt::join(tag_list, ", "), + fmt::join(tag_stride_list, ", ")); std::shared_ptr inst = std::make_shared( Opcode::MOVIN, 0, - 0, base_addr, - 
mem_node->get_tile_size(), mem_node->get_precision(), iter_list, - mem_node->get_stride_list(), tag_list, tag_stride_list, accum_tag_list, loop_size_list + 0, base_addr+offset, + mem_node->get_tile_size(), mem_node->get_tile_stride(), mem_node->get_precision(), + tag_list, tag_stride_list, accum_tag_list ); inst->set_addr_name(base_addr_name, base_addr_id); inst->prepare_tag_key(); inst->set_nr_inner_loop(nr_inner_loop); - inst->adjust_dram_address(); inst->set_is_async(mem_node->is_async_node()); inst->set_numa_id(numa_id); + if (mem_node->is_indirect()) { inst->set_indirect_index_path(tog_parser->get_indirect_path()); tog_parser->inc_indirect_counter(); @@ -439,14 +450,7 @@ std::vector> TileLoopNode::get_tiles_from_iter(TileGraphPa tile_vec.back()->append_instuction(inst); } else if (tile_node->get_type() == TileType::STORE_NODE) { std::shared_ptr mem_node = std::static_pointer_cast(tile_node); - auto base_addr_name = mem_node->get_base_addr_name(); - int base_addr_id = tog_parser->register_addr_name(base_addr_name); - /* Lookup given name's address */ - addr_type base_addr = tog_parser->lookup(base_addr_name); - std::vector& tag_stride_list = mem_node->get_tag_stride_list(); - std::vector accum_tag_list; std::vector iter_list; - std::vector loop_size_list; std::vector outer_loop_idx; std::vector outer_loop_size; int nr_inner_loop = 0; @@ -454,7 +458,6 @@ std::vector> TileLoopNode::get_tiles_from_iter(TileGraphPa for (auto loop_idx: loop_idx_list) { auto iter_value = getLoopIndexValue(iter, loop_idx); iter_list.push_back(iter_value); - loop_size_list.push_back(tog_parser->get_loop_size(loop_idx)); if (tog_parser->get_loop_type(loop_idx)==LoopType::INNER_LOOP) nr_inner_loop++; if (tog_parser->get_loop_type(loop_idx)==LoopType::PARALLEL_LOOP) { @@ -465,6 +468,12 @@ std::vector> TileLoopNode::get_tiles_from_iter(TileGraphPa } } + /* Lookup given name's address */ + std::string base_addr_name = mem_node->get_base_addr_name(); + int base_addr_id = 
tog_parser->register_addr_name(base_addr_name); + addr_type base_addr = tog_parser->lookup(base_addr_name); + addr_type offset = std::inner_product(iter_list.begin(), iter_list.end(), mem_node->get_loop_stride_list().begin(), 0); + /* Calc numa id */ int numa_id = 0; auto numa_stride_size = tog_parser->lookupNumaInfo(base_addr_name).size(); @@ -477,14 +486,13 @@ std::vector> TileLoopNode::get_tiles_from_iter(TileGraphPa printIndexMap("[TOGParser] Store Node " + mem_node->get_base_addr_name() + " Numa_id: " + std::to_string(numa_id), iter); std::shared_ptr inst = std::make_shared( Opcode::MOVOUT, 0, - 0, base_addr, - mem_node->get_tile_size(), mem_node->get_precision(), iter_list, - mem_node->get_stride_list(), std::vector(1), tag_stride_list, accum_tag_list, loop_size_list + 0, base_addr+offset, + mem_node->get_tile_size(), mem_node->get_tile_stride(), mem_node->get_precision(), + std::vector(1), mem_node->get_tag_stride_list(), std::vector() ); inst->set_addr_name(base_addr_name, base_addr_id); inst->prepare_tag_key(); inst->set_nr_inner_loop(nr_inner_loop); - inst->adjust_dram_address(); inst->set_is_async(mem_node->is_async_node()); inst->set_numa_id(numa_id); if (mem_node->is_indirect()) { @@ -523,6 +531,10 @@ std::vector> TileLoopNode::get_tiles_from_iter(TileGraphPa tag_list.push_back(iter_value); } } + /* Default accum tag */ + if (accum_tag_list.empty()) { + accum_tag_list.push_back(0); + } /* Skip accum stride */ for (auto i : tag_stride_list) { @@ -530,11 +542,16 @@ std::vector> TileLoopNode::get_tiles_from_iter(TileGraphPa new_tag_stride_list.push_back(i); } + spdlog::trace("[TOGParser] Wait Node {}, accum = [{}], tag = [{}], stride = [{}]", wait_node->get_base_addr_name(), + fmt::join(accum_tag_list, ", "), + fmt::join(tag_list, ", "), + fmt::join(new_tag_stride_list, ", ")); + std::shared_ptr inst = std::make_shared( Opcode::BAR, 0, 0, base_addr, - std::vector(), 0, iter_list, - iter_list, tag_list, new_tag_stride_list, accum_tag_list, std::vector() + 
std::vector(), std::vector(), 0, + tag_list, new_tag_stride_list, accum_tag_list ); inst->set_addr_name(base_addr_name, base_addr_id); inst->prepare_tag_key(); @@ -543,15 +560,14 @@ std::vector> TileLoopNode::get_tiles_from_iter(TileGraphPa } else if (tile_node->get_type() == TileType::COMPUTE_NODE) { printIndexMap("[TOGParser] Compute Node ", iter); std::shared_ptr compute_node = std::static_pointer_cast(tile_node); - std::vector iter_list; std::vector tag_list = {0}; std::vector tag_stride_list = {1}; std::vector accum_tag_list; std::shared_ptr inst = std::make_shared( Opcode::COMP, compute_node->get_cycle(), 0, 0, - std::vector(), 0, iter_list, iter_list, - tag_list, tag_stride_list, accum_tag_list, std::vector() + std::vector(), std::vector(), 0, + tag_list, tag_stride_list, accum_tag_list ); inst->set_overlapping_cycle(compute_node->get_overlapping_cycle()); inst->set_compute_type(compute_node->get_compute_type()); @@ -620,72 +636,28 @@ std::vector> TileLoopNode::get_tiles_from_iter(TileGraphPa } else if (tile_node->get_type() == TileType::STONNE_NODE) { printIndexMap("[TOGParser] Stonne Node ", iter); std::shared_ptr stonne_node = std::static_pointer_cast(tile_node); - /* Lookup given name's address */ - std::vector iter_list; - std::vector tag_list; - std::vector tag_stride_list; - std::vector accum_tag_list; - - /* Put dummy computation instruction */ - std::shared_ptr inst = std::make_shared( - Opcode::COMP, 0, - 0, 0, - std::vector(), 0, iter_list, - iter_list, tag_list, tag_stride_list, accum_tag_list, std::vector() - ); + std::shared_ptr inst = std::make_shared(Opcode::COMP); link_map[tile_node] = inst; tile_vec.back()->append_instuction(inst); tile_vec.back()->set_custom_data(stonne_node->getDesc()); tile_vec.back()->set_stonne_tile(true); } else if (tile_node->get_type() == TileType::STONNE_TRACE_COMPUTE_NODE) { std::shared_ptr stonne_node = std::static_pointer_cast(tile_node); - /* Lookup given name's address */ - std::vector iter_list; - std::vector 
tag_list; - std::vector tag_stride_list; - std::vector accum_tag_list; - - std::shared_ptr inst = std::make_shared( - Opcode::COMP, stonne_node->get_cycle(), - 0, 0, - std::vector(), 0, iter_list, - iter_list, tag_list, tag_stride_list, accum_tag_list, std::vector() - ); + std::shared_ptr inst = std::make_shared(Opcode::COMP); + inst->set_compute_cycle(stonne_node->get_cycle()); link_map[tile_node] = inst; tile_vec.back()->append_instuction(inst); tile_vec.back()->set_stonne_tile(true); } else if (tile_node->get_type() == TileType::STONNE_TRACE_LOAD_NODE) { std::shared_ptr stonne_node = std::static_pointer_cast(tile_node); - /* Lookup given name's address */ - std::vector iter_list; - std::vector tag_list; - std::vector tag_stride_list; - std::vector accum_tag_list; - - std::shared_ptr inst = std::make_shared( - Opcode::MOVIN, 0, - 0, 0, - std::vector(), 0, iter_list, - iter_list, tag_list, tag_stride_list, accum_tag_list, std::vector() - ); + std::shared_ptr inst = std::make_shared(Opcode::MOVIN); inst->set_trace_address(stonne_node->get_address()); link_map[tile_node] = inst; tile_vec.back()->append_instuction(inst); tile_vec.back()->set_stonne_tile(true); } else if (tile_node->get_type() == TileType::STONNE_TRACE_STORE_NODE) { std::shared_ptr stonne_node = std::static_pointer_cast(tile_node); - /* Lookup given name's address */ - std::vector iter_list; - std::vector tag_list; - std::vector tag_stride_list; - std::vector accum_tag_list; - - std::shared_ptr inst = std::make_shared( - Opcode::MOVOUT, 0, - 0, 0, - std::vector(), 0, iter_list, - iter_list, tag_list, tag_stride_list, accum_tag_list, std::vector() - ); + std::shared_ptr inst = std::make_shared(Opcode::MOVOUT); inst->set_trace_address(stonne_node->get_address()); link_map[tile_node] = inst; tile_vec.back()->append_instuction(inst); diff --git a/PyTorchSimFrontend/common_diff.py b/PyTorchSimFrontend/common_diff.py deleted file mode 100644 index 6c1c875c..00000000 --- a/PyTorchSimFrontend/common_diff.py 
+++ /dev/null @@ -1,1031 +0,0 @@ -import contextlib -import dataclasses -import functools -import itertools -import logging -import operator -import re -from collections import namedtuple -from itertools import chain -from typing import Any, Callable, ClassVar, Dict, List, NamedTuple, Optional, Set, Union - -import sympy -from sympy.printing.printer import Printer - -import torch -import torch.fx -from torch.utils._sympy.value_ranges import ValueRanges - -from .. import metrics -from ..utils import ( - DeferredLineBase, - free_symbol_startswith, - get_sympy_Expr_dtype, - IndentedBuffer, - sympy_dot, - sympy_subs, - unique, -) -from ..virtualized import ops, OpsValue, V - -schedule_log = torch._logging.getArtifactLogger(__name__, "schedule") - - -def data_type_logger(msg): - if schedule_log.isEnabledFor(logging.DEBUG): - schedule_log.debug("Data type propagation: %s", msg) - - -TensorArg = namedtuple("TensorArg", ["name", "buffer", "dtype"]) -SizeArg = namedtuple("SizeArg", ["name", "expr"]) - -DeviceCodegen = namedtuple("DeviceCodegen", ["scheduling", "wrapper_codegen"]) -device_codegens: Dict[str, DeviceCodegen] = {} - - -# The code generated by Inductor consists of two main parts: kernel code and wrapper code. -# For any new backend looking to integrate with Inductor, customization of these two main -# parts are necessary to generate its specific code. -# -# Kernel code generation is determined by different Scheduling. Consequently, a new -# backend needs to provide a custom Scheduling for its unique kernel code generation. Currently, -# CppScheduling and TritonScheduling serve the C++/OpenMP and Triton backends, respectively. -# -# For the Wrapper, Inductor provides a WrapperCodeGen class to generate the Python wrapper code -# that bridges kernels. This allows out-of-tree backends to inherit from WrapperCodeGen, -# and override specific member functions to create backend-specific Python wrapper code. 
-# -# Other classes, such as CppKernel and TritonKernel, used for code generation, typically form part -# of the logic for either Scheduling or WrapperCodeGen. So the Scheduling and WrapperCodeGen interfaces -# provide flexibility to the backend. A backend can choose to implement these classes from scratch, -# or reuse them by extending and overriding as necessary. And Inductor provides the registration API, -# register_backend_for_device, to equip a new backend at runtime. -# -# Intel has developed a new backend on top of Triton to support Intel GPUs, leveraging these interfaces. -# This backend can be used as a reference: -# https://github.com/intel/intel-extension-for-pytorch/blob/5dcc9d57e5422cf295e1a1ee97896d6b6a554a85/intel_extension_for_pytorch/_inductor/__init__.py#L9 -def register_backend_for_device( - device: str, device_scheduling: type, device_wrapper_codegen: type -): - device_codegens[device] = DeviceCodegen(device_scheduling, device_wrapper_codegen) - - -def get_scheduling_for_device(device: str): - return device_codegens[device].scheduling if device in device_codegens else None - - -def get_wrapper_codegen_for_device(device: str): - return ( - device_codegens[device].wrapper_codegen if device in device_codegens else None - ) - - -def index_prevent_reordering(index: List[sympy.Expr], index_vars, sizes): - from ..ir import FlexibleLayout - - # added contiguous index prevents reordering - return [*index, sympy_dot(index_vars, FlexibleLayout.contiguous_strides(sizes))] - - -@functools.lru_cache(None) -def boolean_ops(): - return ( - "is_inf", - "is_nan", - "bitwise_xor", - "logical_not", - "signbit", - "le", - "lt", - "ge", - "gt", - "eq", - "ne", - ) - - -DTYPE_TO_COMPUTATION_DTYPE = { - torch.bfloat16: torch.float, - torch.float16: torch.float, - **{ - dtype: dtype - for dtype in [ - torch.bool, - torch.float32, - torch.float64, - torch.int8, - torch.int16, - torch.int32, - torch.int64, - torch.uint8, - ] - }, -} - - -class DataTypePropagation: - def 
__init__(self, body) -> None: - self.body = body - self.graphs: Dict[Union[Callable[..., Any], str], Any] = { - "root": body.root_block.graph - } - for k, v in body.subblocks.items(): - self.graphs[k] = v.graph - - def deduce_node_dtype_by_inputs(self, node: torch.fx.Node): - inputs = node.all_input_nodes - input_nodes = [ - n for n in inputs if isinstance(n, torch.fx.Node) and n.op != "placeholder" - ] - if len(input_nodes) == 0: - return None - - all_input_nodes_propogated = all( - OptimizationContext.key in n.meta - and n.meta[OptimizationContext.key].dtype is not None - for n in input_nodes - ) - if not all_input_nodes_propogated: - return None - - return functools.reduce( - torch.promote_types, - [n.meta[OptimizationContext.key].dtype for n in input_nodes], - ) - - def deduce_node_dtype_by_subgraph(self, node: torch.fx.Node): - sub_graph = self.graphs[node.target] - dtype = self.propagate_graph(sub_graph) - assert dtype - return dtype - - def deduce_node_dtype(self, node: torch.fx.Node): - if node.target in boolean_ops(): - return torch.bool - - if node.op == "placeholder": - return None - - if node.target == "output": - # we can infer output node if it only have 1 arg - if len(node.args) != 1: - return None - - if node.target in ( - "to_dtype", - "index_expr", - ): - return node.args[-1] - - if node.target in ( - "rand", - "randn", - ): - return torch.float - - if node.target in ( - "get_index", - "index_expr", - ): - return torch.int64 - - if node.target in ( - "load", - "store", - "store_reduction", - ): - buf_name = node.args[1] - return V.graph.get_dtype(buf_name) - - if node.target == operator.getitem: - return self.deduce_node_dtype(node.args[0]) - - assert isinstance(node.target, str) - - if node.target == "reduction": - return node.args[1] - - if node.target == "constant": - return DTYPE_TO_COMPUTATION_DTYPE[node.args[-1]] - - if node.target.startswith("masked_subblock"): - return self.deduce_node_dtype_by_subgraph(node) - - return 
self.deduce_node_dtype_by_inputs(node) - - def propagate_graph(self, graph: torch.fx.Graph): - assert graph.nodes - graph_dtype = None - # For masked_subblock, we use output's dtype to represent - # the dtype of this subgraph. For other cases, graph_dtype - # might be None - for node in graph.nodes: - if OptimizationContext.key in node.meta: - opt_ctx = node.meta[OptimizationContext.key] - else: - opt_ctx = OptimizationContext() - - opt_ctx.dtype = self.deduce_node_dtype(node) - node.meta[OptimizationContext.key] = opt_ctx - if node.target == "output": - graph_dtype = opt_ctx.dtype - return graph_dtype - - def propagate(self): - self.propagate_graph(self.graphs["root"]) - - @classmethod - def propagate_loopbody(cls, body): - return cls(body).propagate() - - @classmethod - def propagate_scheduler_node(cls, node): - from ..ir import LoopBody - from ..scheduler import SchedulerNode - - assert isinstance(node, SchedulerNode) - assert isinstance(node._body, LoopBody) - DataTypePropagation.propagate_loopbody(node._body) - - -class ExprPrinter(Printer): - @staticmethod - def paren(string): - def all_in_parens(string): - if string[0] != "(" or len(string) < 2: - return False - count = 1 - for i, char in enumerate(string[1:]): - if char == "(": - count += 1 - elif char == ")": - count -= 1 - if count == 0 and i != len(string) - 2: - return False - assert count == 0 - return True - - if ( - isinstance(string, CSEVariable) - or re.match(r"^[a-z0-9_.]+$", string, re.I) - or re.match(r"^\([^)]*\)$", string, re.I) - or string == "" - ): - return string - # don't put extra parens for strings that are already wrapped in parens - if all_in_parens(string): - return string - return f"({string})" - - def _print_Pow(self, expr): - # Pow() confuses triton - base, exp = expr.args - # NB: Remember this is sizevar computation! You don't typically - # expect to have to do floating point computation including exponents - # in sizevar compute. 
Instead of adding support for floating - # point pow, you should make upstream retranslate the Sympy expression - # into Tensor expressions earlier and do that instead. - if exp == 0.5: - return self._helper_sqrt(base) # type: ignore[attr-defined] - elif exp == -0.5: - return "1/" + self._helper_sqrt(base) # type: ignore[attr-defined] - base = self._print(base) - assert exp == int(exp), exp - exp = int(exp) - if exp > 0: - return "*".join([self.paren(base)] * exp) - elif exp < 0: - return "1/" + self.paren("*".join([self.paren(base)] * abs(exp))) - else: # exp == 0 - return "1" - - def _print_Unequality(self, expr): - return " != ".join(map(self.paren, map(self._print, expr.args))) - - def _print_Mul(self, expr): - return "*".join(map(self.paren, map(self._print, expr.args))) - - def _print_Add(self, expr): - return " + ".join(map(self.paren, map(self._print, expr.args))) - - def _print_Mod(self, expr): - return " % ".join(map(self.paren, map(self._print, expr.args))) - - def _print_CleanDiv(self, expr): - return self._print_FloorDiv(expr) # type: ignore[attr-defined] - - def _print_GreaterThan(self, expr): - # GreaterThan: >= - # StrictlyGreaterThan: > - # Go figure... 
- return " >= ".join(map(self.paren, map(self._print, expr.args))) - - -class PythonPrinter(ExprPrinter): - def _print_ModularIndexing(self, expr): - x, div, mod = expr.args - x = self.paren(self.doprint(x)) - div = self.paren(self.doprint(div)) - mod = self.paren(self.doprint(mod)) - if div != "1": - x = f"({x} // {div})" - return f"{x} % {mod}" - - def _print_FloorDiv(self, expr): - x, div = expr.args - x = self.paren(self.doprint(x)) - div = self.paren(self.doprint(div)) - return f"({x} // {div})" - - def _helper_sqrt(self, expr): - return f"math.sqrt({self._print(expr)})" - - def _print_floor(self, expr): - assert len(expr.args) == 1 - return f"math.floor({self._print(expr.args[0])})" - - def _print_ceiling(self, expr): - assert len(expr.args) == 1 - return f"math.ceil({self._print(expr.args[0])})" - - -class OpOverrides: - def __init__(self, parent): - super().__init__() - self._parent = parent - - def __getattr__(self, item): - return getattr(self._parent, item) - - @staticmethod - def identity(value): - # used to trigger cse - return value - - @staticmethod - def constant(value, dtype): - return repr(value) - - @staticmethod - def reciprocal(x): - return ops.div("1", x) - - @staticmethod - def square(x): - return ops.mul(x, x) - - @staticmethod - def bitwise_not(x): - return f"~{ExprPrinter.paren(x)}" - - @staticmethod - def logical_not(a): - return f"{ExprPrinter.paren(a)} == 0" - - @staticmethod - def bitwise_and(x, y): - return f"{ExprPrinter.paren(x)} & {ExprPrinter.paren(y)}" - - @staticmethod - def bitwise_or(x, y): - return f"{ExprPrinter.paren(x)} | {ExprPrinter.paren(y)}" - - @staticmethod - def bitwise_xor(x, y): - return f"{ExprPrinter.paren(x)} ^ {ExprPrinter.paren(y)}" - - @staticmethod - def bitwise_left_shift(x, y): - return f"{ExprPrinter.paren(x)} << {ExprPrinter.paren(y)}" - - # TODO(fdrocha): this is currently not being used anywhere, - # pending on moving triton pin past 972b761 - @staticmethod - def bitwise_right_shift(x, y): - return 
f"{ExprPrinter.paren(x)} >> {ExprPrinter.paren(y)}" - - @staticmethod - def remainder(a, b): - r = ops.mod(a, b) - return ops.where(f"(({r} != 0) & (({r} < 0) != ({b} < 0)))", ops.add(r, b), r) - - @staticmethod - def load_seed(name, offset): - return ops.load(name, sympy.Integer(offset)) - - -class DeferredLine(DeferredLineBase): - """A line that can be 'unwritten' by adding name to V.graph.removed_buffers""" - - def __init__(self, name, line): - super().__init__(line) - self.name = name - - def __call__(self): - if ( - self.name not in V.graph.removed_buffers - and self.name not in V.graph.inplaced_to_remove - ): - return self.line - return None - - def _new_line(self, line): - return DeferredLine(self.name, line) - - -class BracesBuffer(IndentedBuffer): - def indent(self, offset=1): - @contextlib.contextmanager - def ctx(): - for _ in range(offset): - self.writeline("{") - self._indent += 1 - for _ in range(-offset): - self._indent -= 1 - self.writeline("}") - yield - for _ in range(-offset): - self.writeline("{") - self._indent += 1 - for _ in range(offset): - self._indent -= 1 - self.writeline("}") - - return ctx() - - -class InplacedBuffer(NamedTuple): - inner_name: str - other_names: List[str] - - -class KernelArgs: - @staticmethod - def _lookup(prefix, odict, name): - assert isinstance(name, (str, sympy.Symbol)) - if name not in odict: - odict[name] = f"{prefix}{len(odict)}" - return odict[name] - - def __init__(self, sizevars=None): - self.input_buffers = dict() - self.output_buffers = dict() - self.inplace_buffers = dict() - self.sizevars = sizevars or dict() - - def __repr__(self): - return "KernelArgs({})".format( - ", ".join( - map( - repr, - [ - self.input_buffers, - self.output_buffers, - self.inplace_buffers, - self.sizevars, - ], - ) - ) - ) - - def _buffer_is_marked_removed(self, name): - return isinstance(name, str) and name.startswith("REMOVED") - - def input(self, name): - if V.graph.scheduler: - name = 
V.graph.scheduler.mutation_real_name.get(name, name) - assert name not in V.graph.removed_buffers, name - if name in self.output_buffers: - return self.output_buffers[name] - if name in self.inplace_buffers: - return self.inplace_buffers[name].inner_name - if name.startswith("seed"): - return self._lookup("seed", self.input_buffers, name) - return self._lookup("in_ptr", self.input_buffers, name) - - def output(self, name): - if V.graph.scheduler: - name = V.graph.scheduler.mutation_real_name.get(name, name) - assert name not in V.graph.removed_buffers, name - if name in self.inplace_buffers: - return self.inplace_buffers[name].inner_name - return self._lookup("out_ptr", self.output_buffers, name) - - def make_inplace(self, input_name, output_name): - assert output_name not in self.inplace_buffers - if input_name in self.inplace_buffers: - buf = self.inplace_buffers[input_name] - buf.other_names.append(output_name) - self.inplace_buffers[output_name] = buf - else: - buf = InplacedBuffer( - f"in_out_ptr{len(unique(self.inplace_buffers.values()))}", - [input_name, output_name], - ) - self.inplace_buffers[input_name] = buf - self.inplace_buffers[output_name] = buf - - def seed_offset(self, name, value): - if value in self.sizevars: - return self.sizevars[value] - if name in self.sizevars.values(): - name = ( - f"{name}{sum(1 for v in self.sizevars.values() if v.startswith(name))}" - ) - self.sizevars[value] = name - return name - - def size(self, name): - if str(name) == "seed": - self.sizevars["seed"] = "seed" - return "seed" - return self._lookup("ks", self.sizevars, name) - - def call_names(self): - return chain( - self.input_buffers.keys(), self.output_buffers.keys(), self.sizevars.keys() - ) - - def wrap_ptr_arg(self, buf, dtype): - return f"c_void_p({buf}.data_ptr())" - - def wrap_size_arg(self, size): - return f"c_long({size})" - - def cpp_argdefs(self): - from .cpp import DTYPE_TO_CPP, INDEX_TYPE - - # TODO(jansel): replace this with data from scheduler - 
buffer_types = {x.get_name(): x.get_dtype() for x in V.graph.buffers} - for name, val in V.graph.graph_inputs.items(): - if isinstance(val, sympy.Expr): - buffer_types[name] = get_sympy_Expr_dtype(val) - else: - buffer_types[name] = val.get_dtype() - buffer_types.update( - {name: val.dtype for name, val in V.graph.constants.items()} - ) - - call_args = [] - arg_defs = [] - arg_types = [] - for inplaced in unique(self.inplace_buffers.values()): - if self._buffer_is_marked_removed(inplaced): - continue - outer = inplaced.other_names[-1] - inner = inplaced.inner_name - dtype = buffer_types[outer] - cpp_dtype = DTYPE_TO_CPP[dtype] - arg_defs.append(f"{cpp_dtype}* {inner}") - call_args.append(self.wrap_ptr_arg(outer, dtype)) - arg_types.append(f"{cpp_dtype}*") - for outer, inner in self.input_buffers.items(): - if outer in self.inplace_buffers: - continue - dtype = buffer_types[outer] - cpp_dtype = DTYPE_TO_CPP[dtype] - arg_defs.append(f"const {cpp_dtype}* {inner}") - call_args.append(self.wrap_ptr_arg(outer, dtype)) - arg_types.append(f"const {cpp_dtype}*") - for outer, inner in self.output_buffers.items(): - if outer in self.inplace_buffers or self._buffer_is_marked_removed(inner): - continue - dtype = buffer_types[outer] - cpp_dtype = DTYPE_TO_CPP[dtype] - arg_defs.append(f"{cpp_dtype}* {inner}") - call_args.append(self.wrap_ptr_arg(outer, dtype)) - arg_types.append(f"{cpp_dtype}*") - for outer, inner in self.sizevars.items(): - arg_defs.append(f"const {INDEX_TYPE} {inner}") - call_args.append(self.wrap_size_arg(outer)) - arg_types.append(f"const {INDEX_TYPE}") - return arg_defs, call_args, arg_types - - def python_argdefs(self): - arg_defs = [] - call_args = [] - precompile_args: List[Union[TensorArg, SizeArg]] = [] - for inplaced in unique(self.inplace_buffers.values()): - if self._buffer_is_marked_removed(inplaced): - continue - arg_defs.append(inplaced.inner_name) - call_args.append(inplaced.other_names[-1]) - precompile_args.append( - TensorArg( - 
inplaced.inner_name, - inplaced.other_names[-1], - V.graph.get_dtype(inplaced.other_names[-1]), - ) - ) - for outer, inner in chain( - self.input_buffers.items(), self.output_buffers.items() - ): - if outer in self.inplace_buffers or self._buffer_is_marked_removed(inner): - continue - arg_defs.append(inner) - call_args.append(outer) - precompile_args.append(TensorArg(inner, outer, V.graph.get_dtype(outer))) - for outer, inner in self.sizevars.items(): - arg_defs.append(inner) - call_args.append(outer) - precompile_args.append(SizeArg(inner, outer)) - - return arg_defs, call_args, precompile_args - - def aliases(self): - for inplaced in unique(self.inplace_buffers.values()): - if self._buffer_is_marked_removed(inplaced): - continue - for other in inplaced.other_names: - if other in V.graph.inplaced_to_remove: - continue - if other in self.input_buffers: - yield self.input_buffers[other], inplaced.inner_name - if other in self.output_buffers: - yield self.output_buffers[other], inplaced.inner_name - - def is_removed(self, name): - def _is_removed(name, buffers): - return name not in buffers or self._buffer_is_marked_removed(buffers[name]) - - return _is_removed(name, self.output_buffers) and _is_removed( - name, self.inplace_buffers - ) - - # Includes inplace buffers, excludes removed buffers. Essentially, - # after you do a call into this kernel, which buffers actually contain - # updated data? Modeled off of python_argdefs. - def live_output_buffers(self): - live_outs = set() - for inplaced in unique(self.inplace_buffers.values()): - if self._buffer_is_marked_removed(inplaced): - continue - live_outs.add(inplaced.other_names[-1]) - for outer, inner in self.output_buffers.items(): - if outer in self.inplace_buffers or self._buffer_is_marked_removed(inner): - continue - live_outs.add(outer) - return live_outs - - -class CSEVariable: - """A CSEVariable is just a name for an expression but it is useful to be able to annotate them on a backend dependent basis. 
- To do so, the backends can simply overload `Kernel.create_cse_var` - The "CSEVariable.update_on_args" method gives you a hook for annotations - See example of TritonCSEVariable in triton.py - """ - - def __init__(self, name, bounds: ValueRanges): - assert isinstance(bounds, ValueRanges) - self.name = name - self.bounds = bounds - - def __str__(self): - return self.name - - def __hash__(self) -> int: - return hash(self.name) - - def __eq__(self, other) -> bool: - return type(other) == type(self) and other.name == self.name - - def update_on_args(self, name, args, kwargs): - pass - - -class CppWrapperKernelArgs(KernelArgs): - def wrap_ptr_arg(self, buf, dtype): - from .cpp import DTYPE_TO_CPP - - return f"({DTYPE_TO_CPP[dtype]}*)({buf}.data_ptr())" - - def wrap_size_arg(self, size): - return f"{size}" - - -class CSE: - """Common subexpression elimination""" - - def __init__( - self, - prefix="", - suffix="", - name_prefix="tmp", - iter_buffers=None, - store_cache=None, - reduction_cache=None, - varname_map=None, - ): - self.prefix = prefix - self.suffix = suffix - self.cache = {} - self.name_prefix = name_prefix - self.store_cache = store_cache or {} - self.reduction_cache = reduction_cache or {} - self.iter_buffer_ids = iter_buffers or itertools.count() - self.invalidated_stores = set() - self.varname_map = varname_map or {} - - def invalidate(self, keep_vars: Set[str]): - for name, tmp in list(self.store_cache.items()): - if tmp not in keep_vars: - del self.store_cache[name] - self.invalidated_stores.add(name) - self.cache = {k: v for k, v in self.cache.items() if v in keep_vars} - - def clone(self): - # Note(fdrocha): reduction_cache is not being cloned, not sure if this is intentional - return CSE( - prefix=self.prefix, - suffix=self.suffix, - name_prefix=self.name_prefix, - iter_buffers=self.iter_buffer_ids, - store_cache=self.store_cache, - varname_map=self.varname_map, - ) - - def generate( - self, - buffer: IndentedBuffer, - expr: Union[str, CSEVariable, 
OpsValue], - *, - bounds: ValueRanges = ValueRanges.unknown(), - write=True, - assignment=True, - ) -> CSEVariable: - if isinstance(expr, OpsValue): - expr = expr.value - - assert isinstance(expr, (str, CSEVariable)), type(expr) - assert write or assignment - if isinstance(expr, CSEVariable): - # If the expressions were always created with all the information, we could - # assert expr.bounds == bounds, but sometimes the expression is created - # with the loose ValueRanges.unknown(), so we need to tighten the bounds - expr.bounds = expr.bounds.tighten(bounds) - return expr - cache_key = expr - var = self.cache.get(cache_key, None) - if not var: - var = self.newvar(bounds) if assignment else None - self.cache[cache_key] = var - if write: - if V.kernel.current_node: - V.kernel.current_node.codegen_originating_info( - buffer, only_once=True - ) - if assignment: - line = f"{self.prefix}{var} = {expr}{self.suffix}" - else: - line = f"{expr}{self.suffix}" - buffer.writeline(line) - else: - var.bounds = var.bounds.tighten(bounds) - - return var - - def newvar(self, bounds: ValueRanges = ValueRanges.unknown()) -> CSEVariable: - var_name = f"{self.name_prefix}{next(self.iter_buffer_ids)}" - var = V.kernel.create_cse_var(var_name, bounds) - self.varname_map[var_name] = var - return var - - -class CodeGen: - def __init__(self): - super().__init__() - self.exit_stack = contextlib.ExitStack() - - def __enter__(self): - self.exit_stack.__enter__() - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - self.exit_stack.__exit__(exc_type, exc_val, exc_tb) - - -class Kernel(CodeGen): - newvar_prefix = "" - suffix = "" - overrides = None - load_format = None - store_format = None - - def __init__(self, args=None): - super().__init__() - metrics.generated_kernel_count += 1 - self.args = args or KernelArgs() - self.loads = IndentedBuffer() - self.compute = IndentedBuffer() - self.stores = IndentedBuffer() - self.cse = CSE(self.newvar_prefix, self.suffix) - 
self.must_keep_buffers = set() - self.store_buffer_names = set() - # set in set_current_node - self.current_node = None - self.node_to_bounds: Optional[Dict[torch.fx.Node, ValueRanges]] = None - - @contextlib.contextmanager - def set_current_node(self, node): - prior = self.current_node - self.current_node = node - self.node_to_bounds = node._body.bounds().get_bounds() - try: - yield - finally: - self.current_node = prior - - @contextlib.contextmanager - def swap_buffers(self, lb, cb=None, sb=None): - if cb is None: - cb = lb - loads = self.loads - compute = self.compute - stores = self.stores - cse = self.cse - self.loads = lb - self.compute = cb - self.stores = sb - self.cse = cse.clone() - try: - yield - finally: - self.loads = loads - self.compute = compute - self.stores = stores - self.cse = cse - - def load(self, name: str, index: sympy.Expr): - raise NotImplementedError() - - def indirect_load(self, name: str, index: sympy.Expr): - """A load the depends on an index we have read""" - prior = self.loads - try: - # put the load in the compute section as it might have deps - self.loads = self.compute - return self.load(name, index) - finally: - self.loads = prior - - def store_reduction(self, name, index, value): - raise NotImplementedError() - - def store(self, name, index, value, mode=None): - raise NotImplementedError() - - def reduction(self, dtype, src_dtype, reduction_type, value): - raise NotImplementedError() - - def bucketize( - self, - values, - offsets_name: str, - offsets_size: sympy.Expr, - indexing_dtype: torch.dtype, - right: bool, - ): - """ - See [Note: Inductor bucketize op] - """ - raise NotImplementedError() - - def __enter__(self): - class CSEProxy: - self.name = "CSEProxy" - - @staticmethod - def __getattr__(name: str) -> Callable[..., CSEVariable]: # type: ignore[misc] - def inner(*args, **kwargs): - # TritonTemplateKernel has no current_node - buf_bounds = ValueRanges.unknown() - if hasattr(V.interpreter, "current_node"): - fx_node = 
V.interpreter.current_node - assert isinstance(self.node_to_bounds, dict) - buf_bounds = self.node_to_bounds.get( - fx_node, ValueRanges.unknown() - ) - - csevar = self.cse.generate( - self.compute, - getattr(parent_handler, name)(*args, **kwargs), # type: ignore[has-type] - bounds=buf_bounds, - ) - csevar.update_on_args(name, args, kwargs) - return csevar - - return inner - - @staticmethod - def indirect_indexing(index_var, size, check=True): - # Skip CSE since this doesn't return an expression - return self.indirect_indexing(index_var, size, check) # type: ignore[attr-defined] - - @staticmethod - def load(name: str, index: sympy.Expr): - if name in self.cse.invalidated_stores: - # A load from an invalidated store requires us to - # keep the actual buffer around - V.kernel.must_keep_buffers.add(name) - if free_symbol_startswith(index, "tmp"): - return self.indirect_load(name, index) - store_cache = self.cse.store_cache - if name in store_cache: - return store_cache[name] - return self.load(name, index) - - @staticmethod - def store(name, index, value, mode=None): - self.store_buffer_names.add(name) - if mode is None: - self.cse.store_cache[name] = value - if self.current_node: - for other_name in self.current_node.get_mutations(): - self.cse.store_cache[other_name] = value - if name not in V.graph.removed_buffers: - return self.store(name, index, value, mode=mode) - - @staticmethod - def store_reduction(name, index, value): - self.store_buffer_names.add(name) - self.cse.store_cache[name] = value - if self.current_node: - for other_name in self.current_node.get_mutations(): - self.cse.store_cache[other_name] = value - - if name not in V.graph.removed_buffers: - return self.store_reduction(name, index, value) - - @staticmethod - def reduction(dtype, src_dtype, reduction_type, value): - return self.reduction(dtype, src_dtype, reduction_type, value) - - @staticmethod - def bucketize( - values, - offsets_name: str, - offsets_size: sympy.Expr, - indexing_dtype: 
torch.dtype, - right: bool, - ): - """ - [Note: Inductor bucketize op] - - Given values (tensor) and offsets_name (reference to the name of a 1D - tensor), calculate the bucket that each value belongs to. - - e.g. for values [-1, 0, 1, 2, 3, 4, 5, 9], offsets [0, 4, 4, 8], right=True - return = [ 0, 1, 1, 1, 1, 3, 3, 4]. - - When right == False, bucket i refers to range (offsets[i], offsets[i+1]]. - When right == True, bucket i refers to range [offsets[i], offsets[i+1]). - - Offsets must be non-decreasing or the result is undefined. - """ - return self.bucketize( - values, offsets_name, offsets_size, indexing_dtype, right - ) - - super().__enter__() - assert self.overrides - parent_handler = self.overrides(V.get_ops_handler()) - self.exit_stack.enter_context(V.set_ops_handler(CSEProxy())) - self.exit_stack.enter_context(V.set_kernel_handler(self)) - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - if V.graph.scheduler: - V.graph.scheduler.remove_kernel_local_buffers() - super().__exit__(exc_type, exc_val, exc_tb) - - def rename_indexing(self, index) -> sympy.Expr: - # adds the necessary kernel args for index expressions - # and renames variables in index expressions to kernel arg names - if isinstance(index, (list, tuple)): - return [self.rename_indexing(x) for x in index] - index = V.graph.sizevars.simplify(index) - sorted_symbols = sorted(index.free_symbols, key=lambda s: s.name) - replacements = { - x: self.args.size(x) - for x in sorted_symbols - if x.name.startswith("s") or x.name.startswith("ps") - } - return sympy_subs(index, replacements) - - def create_cse_var(self, *args, **kwargs): - return CSEVariable(*args, **kwargs) - - -@dataclasses.dataclass -class OptimizationContext: - key: ClassVar[str] = "opt_ctx" - - # Load value as mask - is_load_as_mask: bool = False - - dtype: torch.dtype = None - ops_name: str = "" - is_most_inner_loop_irrevelant: bool = False - - # Load uint8 value as float32 - is_load_uint8_as_float: bool = False \ No 
newline at end of file diff --git a/PyTorchSimFrontend/extension_config.py b/PyTorchSimFrontend/extension_config.py index 17fa74d9..8994cffe 100644 --- a/PyTorchSimFrontend/extension_config.py +++ b/PyTorchSimFrontend/extension_config.py @@ -53,9 +53,23 @@ # For block sparse CONFIG_BLOCK_SPARSE = int(os.environ.get('BLOCK_SPARSE', default=0)) -CONFIG_FORCE_TILE_M = int(os.environ.get("TORCHSIM_FORCE_TIME_M", default=sys.maxsize)) -CONFIG_FORCE_TILE_N = int(os.environ.get("TORCHSIM_FORCE_TIME_N", default=sys.maxsize)) -CONFIG_FORCE_TILE_K = int(os.environ.get("TORCHSIM_FORCE_TIME_K", default=sys.maxsize)) + +# For GEMM tile size +CONFIG_MANUAL_TILE_SIZE = int(os.environ.get('TORCHSIM_MANUAL_TILE_SIZE', default=False)) +CONFIG_TILE_M = int(os.environ.get('TORCHSIM_TILE_M', default=CONFIG_VECTOR_LANE)) +CONFIG_TILE_N = int(os.environ.get('TORCHSIM_TILE_N', default=CONFIG_VECTOR_LANE)) +CONFIG_TILE_K = int(os.environ.get('TORCHSIM_TILE_K', default=CONFIG_VECTOR_LANE)) +CONFIG_GEMM_CHEATSHEET_PATH = os.environ.get('TORCHSIM_GEMM_CHEATSHEET_PATH', + default=f"{CONFIG_TORCHSIM_DIR}/validation/gemm_tpuv3_cheatsheet.json") +CONFIG_SUBTILE = int(os.environ.get('TORCHSIM_SUBTILE', default=True)) +CONFIG_MANUAL_SUBTILE_SIZE = int(os.environ.get('TORCHSIM_MANUAL_SUBTILE_SIZE', default=False)) +CONFIG_SUBTILE_M = int(os.environ.get('TORCHSIM_SUBTILE_M', default=CONFIG_VECTOR_LANE)) +CONFIG_SUBTILE_N = int(os.environ.get('TORCHSIM_SUBTILE_N', default=CONFIG_VECTOR_LANE)) +CONFIG_SUBTILE_K = int(os.environ.get('TORCHSIM_SUBTILE_K', default=CONFIG_VECTOR_LANE)) + +# Advanced fusion options +CONFIG_FUSION_REDUCTION = int(os.environ.get('TORCHSIM_FUSION_REDUCTION', default=True)) +CONFIG_FUSION_PROLOGUE = int(os.environ.get('TORCHSIM_FUSION_PROLOGUE', default=True)) # SRAM Buffer allocation plan def load_plan_from_module(module_path): diff --git a/PyTorchSimFrontend/llvm/llvm_caller_codegen.py b/PyTorchSimFrontend/llvm/llvm_caller_codegen.py index 835d9b80..3690f533 100644 --- 
a/PyTorchSimFrontend/llvm/llvm_caller_codegen.py +++ b/PyTorchSimFrontend/llvm/llvm_caller_codegen.py @@ -231,6 +231,6 @@ def get_spad_size(self, binary_path): spad_end = int(parts[1], 16) if spad_start is None or spad_end is None: - raise ValueError("Could not find .spad addresses") + return 0 spad_size = spad_end - spad_start return spad_size \ No newline at end of file diff --git a/PyTorchSimFrontend/mlir/mlir_bmm_template.py b/PyTorchSimFrontend/mlir/mlir_bmm_template.py index d6917cad..9a9785e1 100644 --- a/PyTorchSimFrontend/mlir/mlir_bmm_template.py +++ b/PyTorchSimFrontend/mlir/mlir_bmm_template.py @@ -1,14 +1,14 @@ import os from torch import empty_strided -from typing import List, Optional, cast +from typing import List, Optional +import sympy from PyTorchSimFrontend.mlir.mlir_template import MLIRTemplate from PyTorchSimFrontend.mlir.mlir_template import MLIRTemplateKernel -from torch._inductor.ir import Buffer from torch._inductor.ir import IRNode -from torch._inductor.ir import ReinterpretView from torch._inductor.codecache import write_atomic import PyTorchSimFrontend.extension_codecache as extension_codecache +from PyTorchSimFrontend.mlir import mlir_common BMM_TEMPLATE = r""" // BMM kernel @@ -21,70 +21,91 @@ // TILE_K = {{ TILE_K }} // SUB_TILE_M = {{ SUB_TILE_M }} // SUB_TILE_N = {{ SUB_TILE_N }} -#map0 = affine_map<(d0, d1, d2) -> ({{ X_map }})> -#map1 = affine_map<(d0, d1, d2) -> ({{ W_map }})> -#map2 = affine_map<(d0, d1, d2) -> (d0 * {{ M * N }} + d1 * {{ N }} + d2)> -memref.global @X_spad : memref<1x{{ TILE_M }}x{{ TILE_K }}xf32, 1> -memref.global @W_spad : memref<1x{{ TILE_K }}x{{ TILE_N }}xf32, 1> -memref.global @Y_spad : memref<1x{{ TILE_M }}x{{ TILE_N }}xf32, 1> {{kernel.def_global_vars()}} func.func @{{ KERNEL_NAME }}{{kernel.def_kernel(inputs=[X, W, Bias], outputs=[Y], names_str="X, W, Bias, Y", input_reorder=input_reorder)}} { - %c_mvin = arith.constant 2 : index - %c_mvin2 = arith.constant 1 : index{% if Bias %} - %c_mvin3 = 
arith.constant 14 : index{% endif %} - %c_mvout = arith.constant 3 : index - %vstride = arith.constant 1 : index - %axis = arith.constant 2 : index - %X_buffer = memref.get_global @X_spad : memref<1x{{ TILE_M }}x{{ TILE_K }}xf32, 1> - %W_buffer = memref.get_global @W_spad : memref<1x{{ TILE_K }}x{{ TILE_N }}xf32, 1> - %Y_buffer = memref.get_global @Y_spad : memref<1x{{ TILE_M }}x{{ TILE_N }}xf32, 1> - %tag = memref.alloc() : memref<1xi32> - %tag0 = memref.alloc() : memref<1xi32> - %tag1 = memref.alloc() : memref<1xi32> - %tag2 = memref.alloc() : memref<1xi32>{% if not Bias %} - %v0 = arith.constant dense<0.0> : vector<{{ kernel.get_spad_size_per_lane(TILE_M, TILE_N) }}xf32>{% endif %} + {{ kernel.def_sram_buffer("X", X_tile_desc, indent_size=2) }} + {{ kernel.def_sram_buffer("W", W_tile_desc, indent_size=2) }} + {{ kernel.def_sram_buffer("Y", Y_tile_desc, indent_size=2) }} + {% if not Bias %} + %v0 = arith.constant dense<0.0> : vector<{{ kernel.get_spad_size_per_lane(TILE_M, TILE_N) }}xf32> + {% endif %} %c0 = arith.constant 0 : index -{{ kernel.def_local_vars() }} - affine.for %b=0 to {{ B }} { - affine.for %t_m = 0 to {{ M }} step {{ TILE_M }} { - affine.for %t_n = 0 to {{ N }} step {{ TILE_N }} { + {{ kernel.def_local_vars(indent_size=2) }} + affine.for %index0 = 0 to {{ B }} { + affine.for %index1 = 0 to {{ M }} step {{ TILE_M }} { + affine.for %index2 = 0 to {{ N }} step {{ TILE_N }} { + %X_buffer2D = memref.reinterpret_cast %X_buffer to offset: [0], sizes: [{{ TILE_M }}, {{ TILE_K }}], strides: [{{ TILE_K }}, 1] : {{ X_tile_desc.get_mlir_shape(DATA_STYPE) }} to memref<{{ TILE_M }}x{{ TILE_K }}xf32, 1> + %W_buffer2D = memref.reinterpret_cast %W_buffer to offset: [0], sizes: [{{ TILE_K }}, {{ TILE_N }}], strides: [{{ TILE_N }}, 1] : {{ W_tile_desc.get_mlir_shape(DATA_STYPE) }} to memref<{{ TILE_K }}x{{ TILE_N }}xf32, 1> + %Y_buffer2D = memref.reinterpret_cast %Y_buffer to offset: [0], sizes: [{{ TILE_M }}, {{ TILE_N }}], strides: [{{ TILE_N }}, 1] : {{ 
Y_tile_desc.get_mlir_shape(DATA_STYPE) }} to memref<{{ TILE_M }}x{{ TILE_N }}xf32, 1> + {% if Bias -%} + {{ kernel.def_dma_op("MVIN", "Bias", Bias_idx, Y_tile_desc, subtile_size=[1, SUB_TILE_M, SUB_TILE_N], indent_size=8) }} + {%- else -%} + affine.vector_store %v0, %Y_buffer[0, 0, 0] : {{ Y_tile_desc.get_mlir_shape(DATA_STYPE) }}, vector<{{ kernel.get_spad_size_per_lane(TILE_M, TILE_N) }}xf32> + {% endif %} + + affine.for %index3 = 0 to {{ K }} step {{ TILE_K }} { + {{ kernel.def_dma_op("MVIN", "X", X_idx, X_tile_desc, subtile_size=[1, SUB_TILE_M, SUB_TILE_K], indent_size=10) }} + {{ kernel.def_dma_op("MVIN", "W", W_idx, W_tile_desc, subtile_size=[1, SUB_TILE_K, SUB_TILE_N], indent_size=10) }} + linalg.matmul ins(%X_buffer2D, %W_buffer2D : memref<{{ TILE_M }}x{{ TILE_K }}x{{ DATA_STYPE }}, 1>, memref<{{ TILE_K }}x{{ TILE_N }}x{{ DATA_STYPE }}, 1>) + outs(%Y_buffer2D : memref<{{ TILE_M }}x{{ TILE_N }}x{{ DATA_STYPE }}, 1>) + } { accumulation_loop=true, subtile_loop="k" } + {{kernel.store_output(indent_size=8)}} + } { outer_loop=true, subtile_loop="n" } + } { outer_loop=true, subtile_loop="m" } + } { outer_loop=true } + return +} +""" + +BMM_PROLOGUE_TEMPLATE = r""" +// BMM Prologue kernel +// BATCH = {{ B }} +// M = {{ M }} +// N = {{ N }} +// K = {{ K }} +// TILE_M = {{ TILE_M }} +// TILE_N = {{ TILE_N }} +// TILE_K = {{ TILE_K }} +// SUB_TILE_M = {{ SUB_TILE_M }} +// SUB_TILE_N = {{ SUB_TILE_N }} +{{kernel.def_global_vars()}} + +func.func @{{ KERNEL_NAME }}{{kernel.def_kernel(inputs=[X, W, Bias], outputs=[Y], names_str="X, W, Bias, Y", input_reorder=input_reorder)}} { + {{ kernel.def_sram_buffer("X", X_tile_desc, indent_size=2) }} + {{ kernel.def_sram_buffer("W", W_tile_desc, indent_size=2) }} + {{ kernel.def_sram_buffer("Y", Y_tile_desc, indent_size=2) }} + {% if not Bias %} + %v0 = arith.constant dense<0.0> : vector<{{ kernel.get_spad_size_per_lane(TILE_M, TILE_N) }}xf32> + {% endif %} + %c0 = arith.constant 0 : index + {{ kernel.def_local_vars(indent_size=2) 
}} + affine.for %index0 = 0 to {{ B }} { + affine.for %index1 = 0 to {{ M }} step {{ TILE_M }} { + affine.for %index2 = 0 to {{ N }} step {{ TILE_N }} { %X_buffer2D = memref.reinterpret_cast %X_buffer to offset: [0], sizes: [{{ TILE_M }}, {{ TILE_K }}], strides: [{{ TILE_K }}, 1] : memref<1x{{ TILE_M }}x{{ TILE_K }}xf32, 1> to memref<{{ TILE_M }}x{{ TILE_K }}xf32, 1> %W_buffer2D = memref.reinterpret_cast %W_buffer to offset: [0], sizes: [{{ TILE_K }}, {{ TILE_N }}], strides: [{{ TILE_N }}, 1] : memref<1x{{ TILE_K }}x{{ TILE_N }}xf32, 1> to memref<{{ TILE_K }}x{{ TILE_N }}xf32, 1> %Y_buffer2D = memref.reinterpret_cast %Y_buffer to offset: [0], sizes: [{{ TILE_M }}, {{ TILE_N }}], strides: [{{ TILE_N }}, 1] : memref<1x{{ TILE_M }}x{{ TILE_N }}xf32, 1> to memref<{{ TILE_M }}x{{ TILE_N }}xf32, 1> - - %index2 = affine.apply #map2(%b, %t_m, %t_n) {% if Bias -%} - memref.dma_start %Bias[ - {%- if Bias_rank == 2 -%} %index2 {%- else -%} %t_n {%- endif -%} - ], %Y_buffer2D[0, 0], %c_mvin3, %tag0[%c0], % - {%- if Bias_rank == 2 -%} axis {%- else -%} c0 {%- endif -%} - , %vstride : memref< - {%- if Bias_rank == 2 -%} {{ M * N }} {%- else -%} {{ N }} {%- endif -%} - xf32>, memref<{{ TILE_M }}x{{ TILE_N }}xf32, 1>, memref<1xi32> { subtile_size=[{{ SUB_TILE_M }}, {{ SUB_TILE_N }}], async=1, sram_stride=[1 , {{ TILE_M }}] } + {{ kernel.def_dma_op("MVIN", "Bias", Bias_idx, Y_tile_desc, subtile_size=[1, SUB_TILE_M, SUB_TILE_N], indent_size=8) }} {%- else -%} - affine.vector_store %v0, %Y_buffer2D[0, 0] : memref<{{ TILE_M }}x{{ TILE_N }}xf32, 1>, vector<{{ kernel.get_spad_size_per_lane(TILE_M, TILE_N) }}xf32>{% endif %} - affine.for %t_k = 0 to {{ K }} step {{ TILE_K }} { - %index0 = affine.apply #map0(%b, %t_m, %t_k) - %index1 = affine.apply #map1(%b, %t_k, %t_n) - memref.dma_start %X[%index0], %X_buffer2D[%c0, %c0], %c_mvin, %tag1[%c0], %axis, %vstride - : memref<{{ B * M * K }}xf32>, memref<{{ TILE_M }}x{{ TILE_K }}xf32, 1>, memref<1xi32> { subtile_size=[{{ SUB_TILE_M }}, {{ 
SUB_TILE_K }}], async=1, sram_stride=[1, {{ TILE_M }}]} - memref.dma_start %W[%index1], %W_buffer2D[%c0, %c0], %c_mvin2, %tag2[%c0], %axis, %vstride - : memref<{{ B * K * N }}xf32>, memref<{{ TILE_K }}x{{ TILE_N }}xf32, 1>, memref<1xi32> { subtile_size=[{{ SUB_TILE_K }}, {{ SUB_TILE_N }}], async=1, sram_stride=[1, {{ TILE_K }}]} - + affine.vector_store %v0, %Y_buffer[0, 0, 0] : {{ Y_tile_desc.get_mlir_shape(DATA_STYPE) }}, vector<{{ kernel.get_spad_size_per_lane(TILE_M, TILE_N) }}xf32> + {% endif %} + affine.for %index3 = 0 to {{ K }} step {{ TILE_K }} { + {{kernel.load_input(indent_size=10)}} linalg.matmul ins(%X_buffer2D, %W_buffer2D : memref<{{ TILE_M }}x{{ TILE_K }}x{{ DATA_STYPE }}, 1>, memref<{{ TILE_K }}x{{ TILE_N }}x{{ DATA_STYPE }}, 1>) outs(%Y_buffer2D : memref<{{ TILE_M }}x{{ TILE_N }}x{{ DATA_STYPE }}, 1>) - } { accumulation_loop=true } + } { accumulation_loop=true, subtile_loop="k" } {{kernel.store_output(indent_size=8)}} - } { outer_loop=true } - } { outer_loop=true } + } { outer_loop=true, subtile_loop="n" } + } { outer_loop=true, subtile_loop="m" } } { outer_loop=true } return } """ BMM_REDUCTION_TEMPLATE = r""" -// BMM kernel +// BMM Reduction kernel // BATCH = {{ B }} // M = {{ M }} // N = {{ N }} @@ -94,65 +115,39 @@ // TILE_K = {{ TILE_K }} // SUB_TILE_M = {{ SUB_TILE_M }} // SUB_TILE_N = {{ SUB_TILE_N }} -#map0 = affine_map<(d0, d1, d2) -> ({{ X_map }})> -#map1 = affine_map<(d0, d1, d2) -> ({{ W_map }})> -#map2 = affine_map<(d0, d1, d2) -> (d0 * {{ M * N }} + d1 * {{ N }} + d2)> -memref.global @X_spad : memref<1x{{ TILE_M }}x{{ TILE_K }}xf32, 1> -memref.global @W_spad : memref<1x{{ TILE_K }}x{{ TILE_N }}xf32, 1> -memref.global @Y_spad : memref<1x{{ TILE_M }}x{{ TILE_N }}xf32, 1> {{kernel.def_global_vars()}} func.func @{{ KERNEL_NAME }}{{kernel.def_kernel(inputs=[X, W, Bias], outputs=[Y], names_str="X, W, Bias, Y", input_reorder=input_reorder)}} { - %c_mvin = arith.constant 2 : index - %c_mvin2 = arith.constant 1 : index{% if Bias %} - %c_mvin3 
= arith.constant 14 : index{% endif %} - %c_mvout = arith.constant 3 : index - %vstride = arith.constant 1 : index - %axis = arith.constant 2 : index - %X_buffer = memref.get_global @X_spad : memref<1x{{ TILE_M }}x{{ TILE_K }}xf32, 1> - %W_buffer = memref.get_global @W_spad : memref<1x{{ TILE_K }}x{{ TILE_N }}xf32, 1> - %Y_buffer = memref.get_global @Y_spad : memref<1x{{ TILE_M }}x{{ TILE_N }}xf32, 1> - %tag = memref.alloc() : memref<1xi32> - %tag0 = memref.alloc() : memref<1xi32> - %tag1 = memref.alloc() : memref<1xi32> - %tag2 = memref.alloc() : memref<1xi32>{% if not Bias %} - %v0 = arith.constant dense<0.0> : vector<{{ kernel.get_spad_size_per_lane(TILE_M, TILE_N) }}xf32>{% endif %} + {{ kernel.def_sram_buffer("X", X_tile_desc, indent_size=2) }} + {{ kernel.def_sram_buffer("W", W_tile_desc, indent_size=2) }} + {{ kernel.def_sram_buffer("Y", Y_tile_desc, indent_size=2) }} + {% if not Bias %} + %v0 = arith.constant dense<0.0> : vector<{{ kernel.get_spad_size_per_lane(TILE_M, TILE_N) }}xf32> + {% endif %} %c0 = arith.constant 0 : index -{{ kernel.def_local_vars() }} - affine.for %b=0 to {{ B }} { - affine.for %t_n = 0 to {{ N }} step {{ TILE_N }} { - %red_idx = affine.apply affine_map<(d0, d1) -> ({{M}}*d0 + d1)>(%b, %t_n) - {{kernel.reduction_acc()}} affine.for %t_m = 0 to {{ M }} step {{ TILE_M }} {{kernel.reduction_iter_arg()}} { + {{ kernel.def_local_vars(indent_size=2) }} + affine.for %index0=0 to {{ B }} { + affine.for %index2 = 0 to {{ N }} step {{ TILE_N }} { + affine.for %index1 = 0 to {{ M }} step {{ TILE_M }} { %X_buffer2D = memref.reinterpret_cast %X_buffer to offset: [0], sizes: [{{ TILE_M }}, {{ TILE_K }}], strides: [{{ TILE_K }}, 1] : memref<1x{{ TILE_M }}x{{ TILE_K }}xf32, 1> to memref<{{ TILE_M }}x{{ TILE_K }}xf32, 1> %W_buffer2D = memref.reinterpret_cast %W_buffer to offset: [0], sizes: [{{ TILE_K }}, {{ TILE_N }}], strides: [{{ TILE_N }}, 1] : memref<1x{{ TILE_K }}x{{ TILE_N }}xf32, 1> to memref<{{ TILE_K }}x{{ TILE_N }}xf32, 1> - %Y_buffer2D = 
memref.reinterpret_cast %Y_buffer to offset: [0], sizes: [{{ TILE_M }}, {{ TILE_N }}], strides: [{{ TILE_N }}, 1] : memref<1x{{ TILE_M }}x{{ TILE_N }}xf32, 1> to memref<{{ TILE_M }}x{{ TILE_N }}xf32, 1> + %Y_buffer2D = memref.reinterpret_cast %Y_buffer to offset: [0], sizes: [{{ TILE_M }}, {{ TILE_N }}], strides: [{{ TILE_N }}, 1] : memref<1x{{ TILE_N }}x{{ TILE_M }}xf32, 1> to memref<{{ TILE_M }}x{{ TILE_N }}xf32, 1> - %index2 = affine.apply #map2(%b, %t_m, %t_n) {% if Bias -%} - memref.dma_start %Bias[ - {%- if Bias_rank == 2 -%} %index2 {%- else -%} %t_n {%- endif -%} - ], %Y_buffer2D[0, 0], %c_mvin3, %tag0[%c0], % - {%- if Bias_rank == 2 -%} axis {%- else -%} c0 {%- endif -%} - , %vstride : memref< - {%- if Bias_rank == 2 -%} {{ M * N }} {%- else -%} {{ N }} {%- endif -%} - xf32>, memref<{{ TILE_M }}x{{ TILE_N }}xf32, 1>, memref<1xi32> { subtile_size=[{{ SUB_TILE_M }}, {{ SUB_TILE_N }}], async=1, sram_stride=[1 , {{ TILE_M }}] } + {{ kernel.def_dma_op("MVIN", "Bias", Bias_idx, Y_tile_desc, subtile_size=[1, SUB_TILE_M, SUB_TILE_N], indent_size=8) }} // Why not N,M? Currently, dma-fine-grained pass assume M->N order... 
{%- else -%} - affine.vector_store %v0, %Y_buffer2D[0, 0] : memref<{{ TILE_M }}x{{ TILE_N }}xf32, 1>, vector<{{ kernel.get_spad_size_per_lane(TILE_M, TILE_N) }}xf32>{% endif %} - affine.for %t_k = 0 to {{ K }} step {{ TILE_K }} { - %index0 = affine.apply #map0(%b, %t_m, %t_k) - %index1 = affine.apply #map1(%b, %t_k, %t_n) - memref.dma_start %X[%index0], %X_buffer2D[%c0, %c0], %c_mvin, %tag1[%c0], %axis, %vstride - : memref<{{ B * M * K }}xf32>, memref<{{ TILE_M }}x{{ TILE_K }}xf32, 1>, memref<1xi32> { subtile_size=[{{ SUB_TILE_M }}, {{ SUB_TILE_K }}], async=1, sram_stride=[1, {{ TILE_M }}]} - memref.dma_start %W[%index1], %W_buffer2D[%c0, %c0], %c_mvin2, %tag2[%c0], %axis, %vstride - : memref<{{ B * K * N }}xf32>, memref<{{ TILE_K }}x{{ TILE_N }}xf32, 1>, memref<1xi32> { subtile_size=[{{ SUB_TILE_K }}, {{ SUB_TILE_N }}], async=1, sram_stride=[1, {{ TILE_K }}]} - + affine.vector_store %v0, %Y_buffer[0, 0, 0] : memref<1x{{ TILE_N }}x{{ TILE_M }}xf32, 1>, vector<{{ kernel.get_spad_size_per_lane(TILE_M, TILE_N) }}xf32> + {% endif %} + affine.for %index3 = 0 to {{ K }} step {{ TILE_K }} { + {{ kernel.def_dma_op("MVIN", "X", X_idx, X_tile_desc, subtile_size=[1, SUB_TILE_M, SUB_TILE_K], indent_size=10) }} + {{ kernel.def_dma_op("MVIN", "W", W_idx, W_tile_desc, subtile_size=[1, SUB_TILE_K, SUB_TILE_N], indent_size=10) }} linalg.matmul ins(%X_buffer2D, %W_buffer2D : memref<{{ TILE_M }}x{{ TILE_K }}x{{ DATA_STYPE }}, 1>, memref<{{ TILE_K }}x{{ TILE_N }}x{{ DATA_STYPE }}, 1>) outs(%Y_buffer2D : memref<{{ TILE_M }}x{{ TILE_N }}x{{ DATA_STYPE }}, 1>) - } { accumulation_loop=true, loop_k=true } + } { accumulation_loop=true, subtile_loop="k" } {{kernel.store_output(indent_size=8)}} - } { outer_loop=true, loop_m=true } + } { outer_loop=true, subtile_loop="m" } {{kernel.reduction_output(indent_size=6)}} - } { outer_loop=true, loop_n=true} + } { outer_loop=true, subtile_loop="n" } } { outer_loop=true } return } @@ -166,12 +161,12 @@ def render(self, kernel: MLIRTemplateKernel, 
template_buffer_node = None, epilogue_nodes: Optional[List[IRNode]] = None, + prologue_nodes: Optional[List[IRNode]] = None, **kwargs): if template_buffer_node is not None: self.output_node = template_buffer_node - #if epilogue_nodes is not None and len(epilogue_nodes) > 0: - # self.output_node = cast(Buffer, epilogue_nodes[-1]) + # Extract input arguments info X, W = self.input_nodes[0], self.input_nodes[1] Y = self.output_node Bias = None if len(self.input_nodes) == 2 else self.input_nodes[2] @@ -182,79 +177,150 @@ def render(self, W_tensor = W_tensor.view([-1, W_tensor.shape[-2], W_tensor.shape[-1]]) if len(X_tensor.size()) > 3: X_tensor = X_tensor.view([-1, X_tensor.shape[-2], X_tensor.shape[-1]]) + B, M, N, K = X_tensor.size()[0], X_tensor.size()[1], W_tensor.size()[2], X_tensor.size()[2] + W_stride = W_tensor.stride() X_stride = X_tensor.stride() - W_map = " + ".join([f"d{idx}*{s}" for idx, s in enumerate(W_stride)]) - X_map = " + ".join([f"d{idx}*{s}" for idx, s in enumerate(X_stride)]) - B, M, N, K = X_tensor.size()[0], X_tensor.size()[1], W_tensor.size()[2], X_tensor.size()[2] + # Select tile size n_extra_node = len(epilogue_nodes) if epilogue_nodes is not None else 0 TILE_M, TILE_N, TILE_K = kernel.gemm_combination_mapping(M, N, K, n_extra_node=n_extra_node) + SUB_TILE_M = TILE_M if (TILE_M < kernel.vector_lane) or prologue_nodes else kernel.vector_lane + SUB_TILE_N = TILE_N # if (TILE_N < kernel.vector_lane) or prologue_nodes else kernel.vector_lane + SUB_TILE_K = TILE_K # if (TILE_K < kernel.vector_lane) or prologue_nodes else kernel.vector_lane + TOG_latency = M if TILE_M > M else TILE_M kernel.loop_size = [TOG_latency, TILE_N, TILE_K] - SUB_TILE_M = TILE_M if TILE_M < kernel.vector_lane else kernel.vector_lane - SUB_TILE_N = TILE_N if TILE_N < kernel.vector_lane else kernel.vector_lane - SUB_TILE_K = TILE_K if TILE_K < kernel.vector_lane else kernel.vector_lane + TILE_K = TILE_K // 2 if prologue_nodes else TILE_K - if n_extra_node==1 and 
epilogue_nodes[0].is_reduction(): + # Select template code + nr_reduction_nodes = [node for node in epilogue_nodes if node.is_reduction()] if epilogue_nodes is not None else [] + if nr_reduction_nodes: template = BMM_REDUCTION_TEMPLATE + epilogue_dim_aliasing = {"index0":"index0", "index1":"index2", "index2": "index1"} nr_rdim = 1 + elif prologue_nodes: + template = BMM_PROLOGUE_TEMPLATE + epilogue_dim_aliasing = {"index0":"index0", "index1":"index1", "index2": "index2"} + nr_rdim = 0 else: template = BMM_TEMPLATE + epilogue_dim_aliasing = {"index0":"index0", "index1":"index1", "index2": "index2"} nr_rdim = 0 + # Prepare tile descriptors + vlane_stride = 1 + vlane_split_axis = 2 + loop_dim = [sympy.Symbol("index0"), sympy.Symbol("index1"), sympy.Symbol("index2"), sympy.Symbol("index3")] + X_tile_size = [1, TILE_M, TILE_K] + X_tile_stride = [0, 1, TILE_M] + X_tile_desc = mlir_common.MLIRMultiDimTile(X_tile_size, kernel.vector_lane, vlane_split_axis, vlane_stride) + X_tile_desc.set_tile_size_stride(X_tile_size, X_tile_stride) + X_tile_desc.set_name("X_buffer") + X_stride = X_tensor.stride() + X_idx = [loop_dim[0]*X_stride[0], loop_dim[1]*X_stride[1], loop_dim[3]*X_stride[2]] # To keep index arguemnt order, we used index_list + + W_tile_size = [1, TILE_K, TILE_N] + W_tile_stride = [0, 1, TILE_K] + W_tile_desc = mlir_common.MLIRMultiDimTile(X_tile_size, kernel.vector_lane, vlane_split_axis, vlane_stride) + W_tile_desc.set_tile_size_stride(W_tile_size, W_tile_stride) + W_tile_desc.set_name("W_buffer") + W_stride = W_tensor.stride() + W_idx = [loop_dim[0]*W_stride[0], loop_dim[3]*W_stride[1], loop_dim[2]*W_stride[2]] + + vlane_split_axis = vlane_split_axis if nr_rdim==0 else 1 + Y_tile_size = [1, TILE_M, TILE_N] if nr_rdim == 0 else [1, TILE_N, TILE_M] + Y_tile_stride=[0, 1, TILE_M] if nr_rdim == 0 else [0, TILE_M, 1] + Y_tile_desc = mlir_common.MLIRMultiDimTile(Y_tile_size, kernel.vector_lane, vlane_split_axis, vlane_stride) + 
Y_tile_desc.set_tile_size_stride(Y_tile_size, Y_tile_stride) + Y_tile_desc.set_name("Y_buffer") + Y_stride = Y.get_layout().stride + if nr_rdim == 0: + Y_idx = [loop_dim[0]*Y_stride[0], loop_dim[1]*Y_stride[1], loop_dim[2]*Y_stride[2]] + else: + Y_idx = [loop_dim[0]*Y_stride[0], loop_dim[2]*Y_stride[2], loop_dim[1]*Y_stride[1]] + + # Extract Bias info + if Bias is not None: + Bias_stride = Bias.get_layout().stride + if nr_rdim == 0: + Bias_idx = [loop_dim[0]*Bias_stride[0], loop_dim[1]*Bias_stride[1], loop_dim[2]*Bias_stride[2]] + else: + Bias_idx = [loop_dim[0]*Bias_stride[0], loop_dim[2]*Bias_stride[2], loop_dim[1]*Bias_stride[1]] + else: + Bias_idx = None + kernel.render_options = dict( KERNEL_NAME=self.name, kernel=kernel, - B=B, - M=M, - N=N, - K=K, - TILE_M=TILE_M, - TILE_N=TILE_N, - TILE_K=TILE_K, + B=B, M=M, N=N, K=K, + TILE_M=TILE_M, TILE_N=TILE_N, TILE_K=TILE_K, SUB_TILE_M=SUB_TILE_M, SUB_TILE_N=SUB_TILE_N, SUB_TILE_K=SUB_TILE_K, DATA_STYPE="f32", - DATA_SIZE=4, - X = X, - W = W, - Y = Y, - Bias = Bias, - Bias_rank = len(Bias.data.get_size()) if Bias is not None else 0, - X_map = X_map, - W_map = W_map, - Y_numel = B * M * N, + X = X, W = W,Y = Y, Bias = Bias, + X_idx = X_idx, + W_idx = W_idx, + Bias_idx = Bias_idx, + X_tile_desc = X_tile_desc, + W_tile_desc = W_tile_desc, + Y_tile_desc = Y_tile_desc, input_reorder = self.input_reorder ) - kernel.store_info = dict( + if prologue_nodes: + prologue_output_name = list(prologue_nodes[0].read_writes.writes)[0].name + if prologue_output_name == X.get_name(): + # Input fusion case + prologue_var = "X" + prologue_sram_var = "X_buffer" + prologue_tile_desc = X_tile_desc + prologue_dim_aliasing = {"index0":"index0", "index1":"index1", "index2":"index3"} + is_input_fused = True + else: + # Weight fusion case + prologue_var = "W" + prologue_sram_var = "W_buffer" + prologue_tile_desc = W_tile_desc + prologue_dim_aliasing = {"index0":"index0", "index1":"index3", "index2":"index2"} + is_input_fused = False + + 
kernel.prologue_info = dict ( + input_dram_var = "X", + input_sram_var = "X_buffer", + input_tile_desc = X_tile_desc, + input_idx = X_idx, + input_subtile_size = [1, TILE_M, TILE_K], # TODO. Curently, Subtiling is not supported for prologue template + input_dim_aliasing = {"index0":"index0", "index1":"index1", "index2":"index3"}, + + weight_dram_var = "W", + weight_sram_var = "W_buffer", + weight_tile_desc = W_tile_desc, + weight_idx = W_idx, + weight_subtile_size = [1, TILE_K, TILE_N], # TODO. Curently, Subtiling is not supported for prologue template + weight_dim_aliasing = {"index0":"index0", "index1":"index3", "index2":"index2"}, + + # Descriptor for fusion + dram_var = prologue_var, + sram_var = prologue_sram_var, + dram_tile_desc = prologue_tile_desc, + dim_aliasing = prologue_dim_aliasing, + is_bmm = True, + is_input_fused = is_input_fused + ) + + kernel.epilogue_info = dict( output_node = self.output_node.name, - dependent_buf = [], sram_var = "Y_buffer", dram_var = "Y", - index_var = "index2", - tag_var = "tag", - vlane_split_axis = 2, - vlane_stride = 1, - mlir_dtype = kernel.render_options['DATA_STYPE'], - dram_shape = f"memref<{kernel.render_options['Y_numel']}x{kernel.render_options['DATA_STYPE']}>", - tile_size = (1, TILE_M, TILE_N), - tile_stride = [1, 1, TILE_M], + dram_idx = Y_idx, + dram_tile_desc = Y_tile_desc, nr_rdim = nr_rdim, - reduction_idx = "red_idx" + dim_aliasing = epilogue_dim_aliasing ) code = self._template_from_string(template).render(**kernel.render_options) kernel.add_loop_info([kernel.render_options["M"], kernel.render_options["N"], kernel.render_options["K"]], [kernel.render_options["TILE_M"], kernel.render_options["TILE_N"], kernel.render_options["TILE_K"]]) - - self.header = f"float X_spad[{kernel.get_spad_size_per_lane(TILE_M, TILE_K)}] __attribute__ ((section(\".spad\")));\n" - self.header += f"float W_spad[{kernel.get_spad_size_per_lane(TILE_K, TILE_N)}] __attribute__ ((section(\".spad\")));\n" - self.header += f"float 
Y_spad[{kernel.get_spad_size_per_lane(TILE_M, TILE_N)}] __attribute__ ((section(\".spad\")));\n" - self.gem5_header = f"float X_spad[{TILE_M * TILE_K}] __attribute__ ((section(\".spad\")));\n" - self.gem5_header += f"float W_spad[{TILE_K * TILE_N}] __attribute__ ((section(\".spad\")));\n" - self.gem5_header += f"float Y_spad[{TILE_M * TILE_N}] __attribute__ ((section(\".spad\")));\n" - return code def codegen_header(self, code, extra_headers): @@ -264,6 +330,6 @@ def codegen_header(self, code, extra_headers): spike_write_path = os.path.join(write_path, "global_var.h") gem5_write_path = os.path.join(write_path, "gem5_global_var.h") if not os.path.exists(spike_write_path): - write_atomic(spike_write_path, self.header+extra_headers[0]) + write_atomic(spike_write_path, extra_headers[0]) if not os.path.exists(gem5_write_path): - write_atomic(gem5_write_path, self.gem5_header+extra_headers[1]) \ No newline at end of file + write_atomic(gem5_write_path, extra_headers[1]) \ No newline at end of file diff --git a/PyTorchSimFrontend/mlir/mlir_codegen_backend.py b/PyTorchSimFrontend/mlir/mlir_codegen_backend.py index 9a3c4148..725fec5d 100644 --- a/PyTorchSimFrontend/mlir/mlir_codegen_backend.py +++ b/PyTorchSimFrontend/mlir/mlir_codegen_backend.py @@ -4,6 +4,7 @@ import os import math import torch +from collections import defaultdict from concurrent.futures import ThreadPoolExecutor from torch._dynamo.utils import dynamo_timed from torch._inductor.codegen import cpp, wrapper, common, memory_planning @@ -845,7 +846,7 @@ def __init__(self, kernel_group, reason=None): self.reduction_prefix = IndentedBuffer() self.reduction_suffix = IndentedBuffer() self.applys = IndentedBuffer() - self.body = IndentedBuffer() + self.masks = IndentedBuffer() self.dma_loads = IndentedBuffer() self.dma_stores = IndentedBuffer() self.indexed_buffer = IndentedBuffer() @@ -859,6 +860,7 @@ def __init__(self, kernel_group, reason=None): self.reduction_cse = common.CSE(self.newvar_prefix, self.suffix, 
name_prefix="tmp_acc") self.spad_cse = common.CSE(self.newvar_prefix, self.suffix, name_prefix="spad") self.apply_cse = common.CSE(self.newvar_prefix, self.suffix, name_prefix="apply") + self.mask_cse = common.CSE(self.newvar_prefix, self.suffix, name_prefix="mask") self.iterator_cse = common.CSE(self.newvar_prefix, self.suffix, name_prefix="iter") self.init_cse = common.CSE(self.newvar_prefix, self.suffix, name_prefix="init") self.init_vec_cse = common.CSE(self.newvar_prefix, self.suffix, name_prefix="init_vec") @@ -872,12 +874,13 @@ def __init__(self, kernel_group, reason=None): self.tags = dict() self.dma_read_cache = dict() self.dma_write_cache = dict() + self.spadbuf_counter = 0 self.dma_read_counter = 1 self.dma_write_counter = 1 + self.dma_tag_id = 0 self.affine_yield = {} self.welford_reduce_out = None self.reduce_iterator = {} - self.is_template_kernel = False self.spad_buffer_dict = dict() self.base_vector_initialized = False @@ -919,19 +922,22 @@ def convert_index(self, expr, buffer): index = self.apply_cse.generate(buffer, f"affine.apply #{map_var}({args})") return index - def parse_indices(self, expr, buffer=None) -> common.CSEVariable: + def parse_indices(self, expr, buffer=None, comments="", indirect_dims=[]) -> common.CSEVariable: if buffer is None: buffer = self.applys # Constant case - if expr.is_number: + if expr.is_number and len(indirect_dims) == 0: return self.get_const_cse(int(expr)) # Identity case - if len(expr.args) == 0: + if len(expr.args) == 0 and len(indirect_dims) == 0: return expr - args = list(expr.args) + if len(expr.args) == 0: + args = [expr] + else: + args = list(expr.args) # Sort index variable.. 
ex) (%index1, %index0) args_dict = {term: list(term.free_symbols)[0] for term in args if term.free_symbols} sorted_args = sorted(args_dict.keys(), key=lambda term: str(args_dict[term])) @@ -947,8 +953,48 @@ def parse_indices(self, expr, buffer=None) -> common.CSEVariable: indices.append(str(new_arg)) # Extract index var + indirect_args = [f"%{i}" for i in indirect_dims] expr_str = str(expr) args = ", ".join(map(str, indices)) + map_var = self.map_cse.generate(self.global_vars, f"affine_map<({args})[{','.join(indirect_dims)}] -> ({expr_str})>") + args = ", ".join([f"%{i}" for i in indices]) + index = self.apply_cse.generate(buffer, f"affine.apply #{map_var}({args})[{','.join(indirect_args)}] {comments}") + return index + + def parse_index_list(self, expr_list:list, buffer=None) -> common.CSEVariable: + if buffer is None: + buffer = self.applys + zero_var = self.get_const_cse(0) + expr_list = [arg for arg in expr_list] + dim_list = [f"d{i}" for i in range(len(expr_list))] + + if len(expr_list) == 1 and expr_list[0].is_number: + # Constant case + return self.get_const_cse(int(expr_list[0])) + elif len(expr_list) == 1 and expr_list[0].is_symbol: + # Identity case + return expr_list[0] + + indices = [] + new_expr_list = [0] * len(expr_list) + for idx, arg in enumerate(expr_list): + if arg.is_Mul and arg.args[0].is_number: + new_arg = sympy.Symbol(str(self.convert_index(arg.args[1], buffer))) + new_expr_list[idx] = arg.subs(arg.args[1], dim_list[idx]) + indices.append(str(new_arg)) + elif not arg.is_number: + new_arg = sympy.Symbol(str(self.convert_index(arg, buffer))) + new_expr_list[idx] = new_arg.subs(new_arg, dim_list[idx]) + indices.append(str(new_arg)) + else: + const_var = self.get_const_cse(int(arg)) + new_arg = sympy.Symbol(f"{const_var}") + new_expr_list[idx] = arg + indices.append(str(new_arg)) + + # Extract index var + expr_str = str(sum(new_expr_list)) + args = ", ".join(map(str, dim_list)) map_var = self.map_cse.generate(self.global_vars, 
f"affine_map<({args})[] -> ({expr_str})>") args = ", ".join([f"%{i}" for i in indices]) index = self.apply_cse.generate(buffer, f"affine.apply #{map_var}({args})[]") @@ -958,15 +1004,18 @@ def load(self, name: str, index: sympy.Expr): index = self.rename_indexing(index) index = self.convert_indirect_indexing(index) padding = self.get_padding_type() - dram_var = self.kernel_group.args.input(name) + # Extract dram info + dram_var = self.kernel_group.args.input(name) + dram_shape = mlir_common.MLIRKernelArgs.get_mlir_shape(self.buffer_types[name]) dtype = V.graph.get_dtype(name) mlir_dtype = mlir_common.DTYPE_TO_MLIR[dtype] - local_tile_desc, index_var = self.get_dma_info(name, index) + + # Extract sram info + local_tile_desc, index_var, dram_stride = self.get_dma_info(name, index) vlane_split_axis = local_tile_desc.vlane_split_axis vlane_stride = local_tile_desc.vlane_stride tile_numel_per_lane = local_tile_desc.get_numel_per_lane() - dram_shape = mlir_common.MLIRKernelArgs.get_mlir_shape(self.buffer_types[name]) tile_shape = local_tile_desc.get_mlir_shape(mlir_dtype) tile_stride = local_tile_desc.get_tile_stride() @@ -975,33 +1024,18 @@ def load(self, name: str, index: sympy.Expr): compute_vec_size = self.kernel_group.tile_desc.get_compute_vec_size() # Define scratch pad buffer - sram_var, index_var, sram_index_var = self.get_scratchpad_buffer(dtype, name, tile_numel_per_lane, tile_shape, index_var, index) + sram_var, sram_index_var = self.get_scratchpad_buffer(dtype, name, local_tile_desc, index) # MVIN Encoding + attribute = f"{{dram_stride={dram_stride}, sram_stride={tile_stride}, padding={padding}}}" code = self.get_dma_code("MVIN", vlane_split_axis, vlane_stride, mlir_dtype, dram_var, index_var, sram_var, sram_index_var, - f"{name}_tag", dram_shape, tile_shape, tile_stride, padding) + dram_shape, tile_shape, attribute) self.cse.generate(self.dma_loads, code, assignment = False) # FIXME: assignment = False does not support caching compute_index_var = 
",".join(sram_index_var.split(",")[:-1] + [f"%{self.compute_idx}"]) # Generate vector load instruction - needs_mask = self.compute_body_loop.size % self.compute_body_loop.step != 0 and len(index.free_symbols) == len(self.ranges) if compute_vec_size > 1: - if needs_mask: - index_shape = f"vector<{self.compute_body_loop.step}xindex>" - mask_shape = f"vector<{compute_vec_size}xi1>" - step_vec = self.cse.generate(self.loads, f"vector.step : {index_shape}") - upper_bound = self.get_const_cse(self.compute_body_loop.size, "index") - gap = self.cse.generate(self.loads, f"arith.subi %{upper_bound}, %{self.compute_idx} : index") - gap_vec = self.cse.generate(self.loads, f"vector.broadcast %{gap} : index to {index_shape}") - mask_var = self.cse.generate(self.loads, f"arith.cmpi ult, %{step_vec}, %{gap_vec} : {index_shape}") - if padding: - pad_val = self.const_cse.generate(self.const_buffer, f"arith.constant 0x{mlir_common.MLIR_INF['-inf'][mlir_dtype]:x} : {mlir_dtype}") - else: - pad_val = self.get_const_cse(0, mlir_dtype) - pad_vec = self.const_cse.generate(self.const_buffer, f"vector.broadcast %{pad_val} : {mlir_dtype} to {vshape}") - line = f"vector.maskedload %{sram_var}[{compute_index_var}], %{mask_var}, %{pad_vec} : {tile_shape}, {mask_shape}, {vshape} into {vshape}" - else: - operation = "affine.vector_load" - line = f"{operation} %{sram_var}[{compute_index_var}] : {tile_shape}, {vshape}" + operation = "affine.vector_load" + line = f"{operation} %{sram_var}[{compute_index_var}] : {tile_shape}, {vshape}" else: operation = "affine.load" line = f"{operation} %{sram_var}[{compute_index_var}] : {tile_shape}" @@ -1017,16 +1051,14 @@ def store(self, name: str, index: sympy.Expr, value, *args, **kwargs): mlir_dtype = mlir_common.DTYPE_TO_MLIR[dtype] # Prepare dma instruction - local_tile_desc, index_var = self.get_dma_info(name, index) + local_tile_desc, index_var, dram_stride = self.get_dma_info(name, index) vlane_split_axis = local_tile_desc.vlane_split_axis vlane_stride = 
local_tile_desc.vlane_stride - tile_numel_per_lane = local_tile_desc.get_numel_per_lane() dram_shape = mlir_common.MLIRKernelArgs.get_mlir_shape(self.buffer_types[name]) tile_shape = local_tile_desc.get_mlir_shape(mlir_dtype) tile_stride = local_tile_desc.get_tile_stride() tile_size = local_tile_desc.get_tile_size() - # Compute vector unit size vshape = self.kernel_group.tile_desc.get_mlir_vshape(mlir_dtype) compute_vec_size = self.kernel_group.tile_desc.get_compute_vec_size() @@ -1038,7 +1070,7 @@ def store(self, name: str, index: sympy.Expr, value, *args, **kwargs): if require_store: # Define scratch pad buffer - sram_var, index_var, sram_index_var = self.get_scratchpad_buffer(dtype, name, tile_numel_per_lane, tile_shape, index_var, index) + sram_var, sram_index_var = self.get_scratchpad_buffer(dtype, name, local_tile_desc, index) compute_index_var = ",".join(sram_index_var.split(",")[:-1] + [f"%{self.compute_idx}"]) # Generate vector store instruction store_size, operand_type = self.var_info[value] @@ -1057,8 +1089,9 @@ def store(self, name: str, index: sympy.Expr, value, *args, **kwargs): sram_index_var = self.spad_buffer_dict[str(value)][3] # Generate DMA instruction + attribute = f"{{dram_stride={dram_stride}, sram_stride={tile_stride}, padding=0}}" code = self.get_dma_code("MVOUT", vlane_split_axis, vlane_stride, mlir_dtype, dram_var, index_var, sram_var, sram_index_var, - f"{name}_tag", dram_shape, tile_shape, tile_stride) + dram_shape, tile_shape, attribute) self.dma_stores.writeline(common.DeferredLine(name, code)) def reduction(self, dtype, src_dtype, reduction_type, value): @@ -1103,6 +1136,7 @@ def reduction(self, dtype, src_dtype, reduction_type, value): else: # Adjust shape and inital value init_vec = self.const_cse.generate(self.const_buffer, f"vector.broadcast %{init} : {type_name} to {reduced_shape}") + self.register_var_info(init_vec, [vec_len, type_name]) acc_var = init_vec # Reduction body prepare @@ -1121,6 +1155,9 @@ def reduction(self, 
dtype, src_dtype, reduction_type, value): self.init_cse.reduction_cache[reduction_key] = init_vec # Reduction body codegen + mask_shape, mask_var = self.get_mask() + if mask_var is not None: + value = ops.where(mask_var, value, init_vec) result = reduction_partial_combine_vec(reduction_type, value, body_iter_arg) self.compute_body_loop.reduction_vars[body_acc] = (reduction_type, body_iter_arg, iterator, reduced_shape) self.compute_body_loop.affine_yield[result] = reduced_shape @@ -1151,7 +1188,9 @@ def store_reduction(self, name, index, value): # Store reduction can't share cached value stored in cse, # since it is not innermost loop body. tmp_cse = self.cse + tmp_apply_cse = self.apply_cse self.cse = self.reduction_cse + self.apply_cse = self.reduction_cse dram_var = self.kernel_group.args.output(name) dtype = V.graph.get_dtype(name) @@ -1159,10 +1198,9 @@ def store_reduction(self, name, index, value): index = self.rename_indexing(index) # Tile is always reuduced in inner loop - local_tile_desc, index_var = self.get_dma_info(name, index, broadcast=False, store_reduction=True, buffer=self.reductions_suffix) + local_tile_desc, index_var, dram_stride = self.get_dma_info(name, index, broadcast=False, store_reduction=True, buffer=self.reductions_suffix) vlane_split_axis = local_tile_desc.vlane_split_axis vlane_stride = local_tile_desc.vlane_stride - tile_numel_per_lane = local_tile_desc.get_numel_per_lane() dram_shape = mlir_common.MLIRKernelArgs.get_mlir_shape(self.buffer_types[name]) tile_shape = local_tile_desc.get_mlir_shape(mlir_dtype) @@ -1172,9 +1210,8 @@ def store_reduction(self, name, index, value): vshape = f"{mlir_dtype}" else: vshape = f"vector<{compute_vec_size}x{mlir_dtype}>" - sram_var, index_var, sram_index_var = self.get_scratchpad_buffer(dtype, name, tile_numel_per_lane, tile_shape, index_var, index) + sram_var, sram_index_var = self.get_scratchpad_buffer(dtype, name, local_tile_desc, index) if self.welford_reduce_out is not None: - # raise 
NotImplementedError() sum, sqr_sum, _ = self.welford_reduce_out # mean divider = self.cse.generate(self.reductions_suffix, f"arith.constant {float(self.ranges[self.reduction_depth])} : f32") @@ -1205,12 +1242,14 @@ def store_reduction(self, name, index, value): # MVOUT Encoding # Generate DMA instruction + attribute = f"{{dram_stride={dram_stride}, sram_stride={tile_stride}, padding=0}}" code = self.get_dma_code("MVOUT", vlane_split_axis, vlane_stride, mlir_dtype, dram_var, index_var, sram_var, sram_index_var, - f"{name}_tag", dram_shape, tile_shape, tile_stride) + dram_shape, tile_shape, attribute) self.reductions_suffix.writeline(common.DeferredLine(name, code)) # Restore origin cse self.cse = tmp_cse + self.apply_cse = tmp_apply_cse def indirect_indexing(self, index_var, size, check=True): return str(index_var) @@ -1306,7 +1345,7 @@ def index_expr(self, index, dtype): self.header.writeline(f"{c_type} {new_name}_spad[{compute_vec_size}] __attribute__ ((section(\".spad\")));") self.gem5_header.writeline(f"{c_type} {new_name}_spad[{compute_vec_size}] __attribute__((aligned(64)));") self.global_vars.writeline(f"memref.global @{new_name}_spad : {tile_shape}") - self.global_vars_dict[new_name] = [] + self.global_vars_dict[new_name] = dict() sram_var = self.spad_cse.generate(self.spad_buffer, f"memref.get_global @{new_name}_spad : {tile_shape}") # Initialize base vector if not self.base_vector_initialized: @@ -1375,6 +1414,7 @@ def codegen_loops(self): code.writelines(self.compute_body_loop.lines()) with contextlib.ExitStack() as stack: stack.enter_context(code.indent(attribute="{inner_loop=false}",suffix=self.compute_body_loop.epilogue_line())) + code.splice(self.masks) code.splice(self.loads) code.splice(self.compute) code.splice(self.stores) @@ -1394,7 +1434,7 @@ def make_choices(self, nodes, kernel_name): for vlane_stride in [2, 4, 8]: os.environ['TORCHSIM_VECTOR_LANE_STRIDE'] = str(vlane_stride) previous_tile_size = initial_tile_size - increase_dim = 0 # increase 
the first dimension + increase_dim = -2 # increase the first dimension while previous_tile_size[increase_dim] * 2 <= previous_ranges[increase_dim] and previous_tile_size[increase_dim] <= 2 ** 13 and prevent_infinite_loop < 10: incrase_dim = -1 # only increase the last dimension prevent_infinite_loop += 1 @@ -1443,7 +1483,7 @@ def get_cycle(choice): if len(choices) == 0: # can't autotune return None - with ThreadPoolExecutor(max_workers=5) as executor: + with ThreadPoolExecutor(max_workers=1) as executor: results = list(executor.map(get_cycle, choices)) max_idx = results.index(min(results)) print(f"[Auto-tune] Optimal tile size: {choices[max_idx][2].tile_desc.get_tile_size()}, vlane_stride: {choices[max_idx][2].tile_desc.vlane_stride}, cycles: {results[max_idx]}") @@ -1476,6 +1516,12 @@ def _prepare_simulator_headers(self, src_code): write_atomic(spike_write_path, self.header.getvalue() + spad_end_symbol + spad_section_end_symbol) write_atomic(gem5_write_path, self.gem5_header.getvalue()) + def get_arg_info(self, name): + arg_info = dict() + arg_info.update(V.graph.graph_inputs) + arg_info.update({i.get_name(): i for i in V.graph.buffers}) + return arg_info[name] + def get_dma_info(self, name, index, broadcast=True, store_reduction=False, buffer=None): # Need more argument? """ A tile descriptor exists that is configured on a kernel group @@ -1489,7 +1535,6 @@ def get_dma_info(self, name, index, broadcast=True, store_reduction=False, buffe # TODO. 
kg_tile_desc = self.kernel_group.tile_desc - buffer_info = self.buffer_types[name] # Note: index could contain symbols that represent dynamic axies # Extract dimension of index(e.g, index0, index1) local_dims = [int(str(i)[5:]) for i in index.free_symbols if "index" in str(i)] @@ -1498,23 +1543,14 @@ def get_dma_info(self, name, index, broadcast=True, store_reduction=False, buffe local_tile_desc = mlir_common.MLIRMultiDimTile([1], self.vector_lane) local_dims.sort() # Assume that smaller index is placed in the outer loop indirect_dims = [f"{i}" for i in index.free_symbols if "tmp" in str(i)] - indirect_arg_dims = [f"%{i}" for i in index.free_symbols if "tmp" in str(i)] for indirect_dim in indirect_dims: index = index.replace(sympy.Symbol(indirect_dim), 0) # Reduction can have two type of tile size if broadcast and (total_dims != local_dims or (self.reduction_depth!=len(total_dims) and total_dims[:self.reduction_depth] == local_dims)): - # We have to create custom apply map to provide dram stride - # ex) (d0, d1, ... dn, dn+1, dn+2, dk) -> (s0*d0 + s1*d1 + ... dn*0+ dn+1*0 + ... 
dk*0 + const) - fake_dim = self.get_const_cse(0) - input_expr = ",".join(["d"+str(i) for i in total_dims]) - output_expr = str(index).replace('index', 'd') - input_argument = ",".join(["%index" + str(i) if i in local_dims else f"%{fake_dim}" for i in total_dims]) - map_var = self.map_cse.generate(self.global_vars, f"affine_map<({input_expr})[{','.join(indirect_dims)}] -> ({output_expr})>") - index_var = self.apply_cse.generate(buffer, f"affine.apply #{map_var}({input_argument})[{','.join(indirect_arg_dims)}]") local_dims = total_dims # Brodatcast tile shape - else: - index_var = self.parse_indices(index, buffer=buffer) + + index_var = self.parse_indices(index, buffer=buffer, indirect_dims=indirect_dims) if kg_tile_desc.vlane_split_axis in local_dims: local_vlane_split_axis = local_dims.index(kg_tile_desc.vlane_split_axis) @@ -1523,9 +1559,15 @@ def get_dma_info(self, name, index, broadcast=True, store_reduction=False, buffe # Case 0. Tile is 0-D scalar if len(local_dims) == 0: - local_tile_desc.set_tile_size([kg_tile_desc.get_used_vlane() * kg_tile_desc.vlane_stride]) # Force it to use vector instruction. - local_tile_desc.vlane_split_axis = local_vlane_split_axis # last axis - local_tile_desc.vlane_stride = kg_tile_desc.vlane_stride + if not store_reduction: + local_tile_desc.set_tile_size([kg_tile_desc.get_used_vlane() * kg_tile_desc.vlane_stride]) # Force it to use vector instruction. + local_tile_desc.vlane_split_axis = local_vlane_split_axis # last axis + local_tile_desc.vlane_stride = kg_tile_desc.vlane_stride + else: + local_tile_desc.set_tile_size([1]) + local_tile_desc.vlane_split_axis = 0 + local_tile_desc.vlane_stride = 1 + dram_stride = [0] # Edge case # Case 1. Tile is 1-D vector type elif len(local_dims) == 1 and len(local_dims) <= self.reduction_depth: local_tile_desc.set_tile_size([kg_tile_desc.get_dim_size(local_dims[0])]) @@ -1534,7 +1576,7 @@ def get_dma_info(self, name, index, broadcast=True, store_reduction=False, buffe # Case 2. 
Tile is 1-D vector type with reduction elif len(local_dims) == 1 and len(local_dims) == self.reduction_depth + 1: local_tile_desc.set_tile_size([1, kg_tile_desc.get_dim_size(local_dims[0])]) - local_tile_desc.vlane_split_axis = local_vlane_split_axis + local_tile_desc.vlane_split_axis = local_vlane_split_axis + 1 local_tile_desc.vlane_stride = kg_tile_desc.vlane_stride # Case 3. Tile is 2-D tile elif len(local_dims) == 2: @@ -1558,6 +1600,14 @@ def get_dma_info(self, name, index, broadcast=True, store_reduction=False, buffe local_tile_desc.set_tile_size([kg_tile_desc.get_dim_size(dim) for dim in local_dims]) local_tile_desc.vlane_split_axis = local_vlane_split_axis local_tile_desc.vlane_stride = kg_tile_desc.vlane_stride + # Case 4. Tile is 4-D tile (e.g., Convolution epilogue) + elif len(local_dims) == 4: + is_reduction = self.reduction_depth < 3 and not store_reduction + if is_reduction: + raise NotImplementedError("Currently not implemented... ;)") + local_tile_desc.set_tile_size([kg_tile_desc.get_dim_size(dim) for dim in local_dims]) + local_tile_desc.vlane_split_axis = local_vlane_split_axis + local_tile_desc.vlane_stride = kg_tile_desc.vlane_stride else: raise NotImplementedError("Currently not implemented... 
;)") @@ -1574,34 +1624,62 @@ def get_dma_info(self, name, index, broadcast=True, store_reduction=False, buffe local_tile_desc.set_tile_size(new_tile_size) local_tile_desc.vlane_split_axis = new_vlane_split_axis - return local_tile_desc, index_var - - def get_dma_code(self, dma_type_name, attribute1, attribute2, mlir_dtype, dram_var, dram_index_var, sram_var, sram_index_var, - tag_name, dram_shape, tile_shape, tile_stride, padding_type=0): - dma_key = (attribute1, attribute2, mlir_dtype) + # Calculate dram stride + dram_stride = [0] * local_tile_desc.get_nr_dim() + if index.is_Symbol: + dim_idx = int(str(index)[5:]) + dram_stride[dim_idx] = 1 + elif index.is_Number: + pass + else: + dram_dict = defaultdict(list) + # Assume that div will have high priority than mod + for arg in index.as_ordered_terms(): + coeff, dim = arg.as_coeff_mul() + if len(dim) == 0: + continue + real_dim = list(dim[0].free_symbols)[0] + dram_dict[str(real_dim)].append(coeff) + # Add missing dims if not added + max_dim = len(self.ranges) if not store_reduction else len(self.ranges) - 1 + for i in range(max_dim): + target_dim = f"index{i}" + if target_dim not in str(index): + dram_dict[target_dim] = [0] + sorted_keys = sorted(dram_dict.keys()) + dram_stride = sum((dram_dict[key] for key in sorted_keys), []) + + # FIXME. It will be nice to modify node instead of this exception handling... 
+ if len(self.itervars) == 1 and self.reduction_depth == 0: + # In case of reduction loop only case, we will add dummy loop so shift it once + dram_stride = [0] + dram_stride[:-1] + return local_tile_desc, index_var, dram_stride + + def get_dma_code(self, dma_type_name, vlane_split_axis, vlane_stride, mlir_dtype, dram_var, dram_index_var, sram_var, sram_index_var, + dram_shape, tile_shape, attribute): + dma_key = (vlane_split_axis, vlane_stride, mlir_dtype) if dma_type_name == "MVIN" and dma_key in self.dma_read_cache: - dma_type, attribute1, attribute2 = self.dma_read_cache[dma_key] + dma_type, vlane_split_axis, vlane_stride = self.dma_read_cache[dma_key] elif dma_type_name == "MVOUT" and dma_key in self.dma_write_cache: - dma_type, attribute1, attribute2 = self.dma_write_cache[dma_key] + dma_type, vlane_split_axis, vlane_stride = self.dma_write_cache[dma_key] else: - attribute1 = self.get_const_cse(attribute1) - attribute2 = self.get_const_cse(attribute2) + vlane_split_axis = self.get_const_cse(vlane_split_axis) + vlane_stride = self.get_const_cse(vlane_stride) if dma_type_name == "MVIN": dma_type = self.get_const_cse(DMA_TYPE[f"{dma_type_name}{self.dma_read_counter}"]) self.dma_read_counter += 1 - self.dma_read_cache[dma_key] = [dma_type, attribute1, attribute2] + self.dma_read_cache[dma_key] = [dma_type, vlane_split_axis, vlane_stride] else: dma_type = self.get_const_cse(DMA_TYPE[f"{dma_type_name}{self.dma_write_counter}"]) - # self.dma_write_counter += 1 Is it okay? 
- self.dma_write_cache[dma_key] = [dma_type, attribute1, attribute2] - tag = self.get_tag_cse(tag_name) + self.dma_write_cache[dma_key] = [dma_type, vlane_split_axis, vlane_stride] + tag = self.get_tag_cse() zero_cse = self.get_const_cse(0) # Prepare opearnds and attributes dram_operand = f"%{dram_var}[%{dram_index_var}]" sram_operand = f"%{sram_var}[{sram_index_var}]" # Use string tag_var = f"%{tag}[%{zero_cse}]" - dma_attribute = f"%{attribute1}, %{attribute2}" + dma_attribute = f"%{vlane_split_axis}, %{vlane_stride}" sram_shape = tile_shape tag_shape = "memref<1xi32>" @@ -1612,86 +1690,46 @@ def get_dma_code(self, dma_type_name, attribute1, attribute2, mlir_dtype, dram_v src_operand, dst_operand = sram_operand, dram_operand src_shape, dst_shape = sram_shape, dram_shape - code = f"memref.dma_start {src_operand}, {dst_operand}, %{dma_type}, {tag_var}, {dma_attribute} : {src_shape}, {dst_shape}, {tag_shape}" - code = code + f" {{padding={padding_type}, sram_stride={tile_stride}}}" - return code + return f"memref.dma_start {src_operand}, {dst_operand}, %{dma_type}, {tag_var}, {dma_attribute} : {src_shape}, {dst_shape}, {tag_shape} {attribute}" - def adjust_tile_size(self): - if self.read_writes is not None: - read_writes = list(self.read_writes.reads) + list(self.read_writes.writes) - cv_list = [] - for node in read_writes: - if len(node) > 1: - cv_list.append(self.get_constant_vector2(node[1])) - max_element = max(cv_list, key=len) - max_nr_dim = len(max_element) - - sorted_max_element = sorted(max_element, key=lambda x:x[0]) - # Force vector tile size when 3D node is originated from view - if max_nr_dim == 3 and max_nr_dim != len(self.itervars): - self.tile_desc.n_col = min(self.tile_desc.get_tile_size(), sorted_max_element[1][0]) - self.tile_desc.n_row = 1 - return - - # Case 1. 
vector kernel - if len(self.itervars) == 1: - tile_size = self.tile_desc.get_tile_size() if self.tile_desc.get_tile_size() < self.ranges[0] else self.ranges[0] - min_tile_size_unit = self.vector_lane * self.vlen // (8 * self.precision) # TODO: VCIX widening is not implemented - self.tile_desc.n_col = math.ceil(tile_size / min_tile_size_unit) * min_tile_size_unit # padding - self.tile_desc.n_row = 1 - elif len(self.itervars) == 0: - self.tile_desc.n_col = 1 - self.tile_desc.n_row = 1 - - # Case 2. 2-D tensor (e.g., softmax) - if len(self.itervars) == 2 and self.reduction_depth == len(self.itervars): - # Avoid too much padding - if (self.ranges[0] <= self.vector_lane and self.ranges[0] <= self.tile_desc.n_row): - self.tile_desc.n_row = self.ranges[0] - self.tile_desc.used_vector_lane = self.ranges[0] - - # Case 2. 2-D reduction (e.g., batchnorm) - if len(self.itervars) == 2 and self.reduction_depth == len(self.itervars) - 1: - if (((self.ranges[0] + 1) // 2) <= self.vector_lane and ((self.ranges[0] + 1) // 2) <= self.tile_desc.n_row): - self.tile_desc.n_row = ((self.ranges[0] + 1) // 2) * 2 - self.tile_desc.used_vector_lane = (self.ranges[0] + 1) // 2 - - # Case 2. 3-D tensor kernel without reduction. Access vector granule! - if len(self.itervars) == 3 and self.reduction_depth == len(self.itervars): - self.tile_desc.n_col = self.ranges[-1] - self.tile_desc.n_row = 1 - - # Case 3. N-D tensor kernel with reduction. Not implemented. Need this? 
- if len(self.itervars) >= 3 and self.reduction_depth < len(self.itervars): - raise NotImplementedError() - - def get_scratchpad_buffer(self, dtype, name, tile_size_per_lane, dram_tile_shape, indices, raw_index, is_template=False, buffer=None): + def allocate_sram_buffer(self, dtype, dram_name, tile_desc, raw_index, buffer=None, forced_name=None): c_type = mlir_common.DTYPE_TO_C[dtype] + mlir_dtype = mlir_common.DTYPE_TO_MLIR[dtype] + tile_numel_per_lane = tile_desc.get_numel_per_lane() + tile_shape = tile_desc.get_mlir_shape(mlir_dtype) # Make sure each lane's buffer has at least two element - tile_size = max(tile_size_per_lane, 2) * self.vector_lane + tile_size = max(tile_numel_per_lane, 2) * self.vector_lane if buffer is None: buffer = self.spad_buffer - if name not in self.global_vars_dict: - self.global_vars_dict[name] = list() + if dram_name not in self.global_vars_dict: + self.global_vars_dict[dram_name] = dict() - if str(raw_index) not in self.global_vars_dict[name]: - new_name = f"{name}_{len(self.global_vars_dict[name])}" + if str(raw_index) not in self.global_vars_dict[dram_name]: + new_name = f"buf{self.spadbuf_counter}_spad" if forced_name is None else f"{forced_name}_spad" + self.spadbuf_counter+=1 # Add definition to header - self.header.writeline(f"{c_type} {new_name}_spad[{tile_size // self.vector_lane}] __attribute__ ((section(\".spad\")));") - self.gem5_header.writeline(f"{c_type} {new_name}_spad[{tile_size}] __attribute__((aligned(64)));") - self.global_vars.writeline(f"memref.global @{new_name}_spad : {dram_tile_shape}") - self.global_vars_dict[name].append(str(raw_index)) + self.header.writeline(f"{c_type} {new_name}[{tile_size // self.vector_lane}] __attribute__ ((section(\".spad\")));") + self.gem5_header.writeline(f"{c_type} {new_name}[{tile_size}] __attribute__((aligned(64)));") + self.global_vars.writeline(f"memref.global @{new_name} : {tile_shape}") + self.global_vars_dict[dram_name][str(raw_index)] = new_name else: - new_name = 
f"{name}_{self.global_vars_dict[name].index(str(raw_index))}" - sram_var = self.spad_cse.generate(buffer, f"memref.get_global @{new_name}_spad : {dram_tile_shape}") + new_name = self.global_vars_dict[dram_name][str(raw_index)] + return new_name - zero_cse = self.get_const_cse(0) - sram_dims = len(dram_tile_shape.split("x")) - 1 - sram_index_var = ",".join([f"%{zero_cse}"] * sram_dims) + def get_scratchpad_buffer(self, dtype, dram_name, tile_desc, raw_index, buffer=None): + if buffer is None: + buffer = self.spad_buffer - return sram_var, indices, sram_index_var + mlir_dtype = mlir_common.DTYPE_TO_MLIR[dtype] + tile_shape = tile_desc.get_mlir_shape(mlir_dtype) + new_name = self.allocate_sram_buffer(dtype, dram_name, tile_desc, raw_index, buffer=buffer) + sram_var = self.spad_cse.generate(buffer, f"memref.get_global @{new_name} : {tile_shape}") + + zero_cse = self.get_const_cse(0) + sram_index_var = ",".join([f"%{zero_cse}"] * tile_desc.get_nr_dim()) + return sram_var, sram_index_var def get_const_cse(self, value, dtype="index") -> common.CSEVariable: # Type convert @@ -1704,11 +1742,30 @@ def get_const_cse(self, value, dtype="index") -> common.CSEVariable: self.consts[str(value)+dtype] = self.const_cse.generate(self.const_buffer, f"arith.constant {value} : {dtype}") return self.consts[str(value)+dtype] - def get_tag_cse(self, value, shape="memref<1xi32>"): + def get_tag_cse(self, value=None, shape="memref<1xi32>"): + if value is None: + value = self.dma_tag_id + self.dma_tag_id += 1 if value not in self.tags: - self.tags[value] = self.alloc_cse.generate(self.alloc_buffer, f"memref.alloc() : {shape}") + self.tags[value] = self.alloc_cse.generate(self.alloc_buffer, f"memref.alloc() : {shape} // {value}") return self.tags[value] + def get_mask(self): + if self.compute_body_loop.size % self.compute_body_loop.step == 0: + return None, None + compute_vec_size = self.kernel_group.tile_desc.get_compute_vec_size() + index_shape = 
f"vector<{self.compute_body_loop.step}xindex>" + mask_shape = f"vector<{compute_vec_size}xi1>" + + upper_bound = self.get_const_cse(self.compute_body_loop.size) + step_vec = self.const_cse.generate(self.const_buffer, f"vector.step : {index_shape}") + + gap = self.mask_cse.generate(self.masks, f"arith.subi %{upper_bound}, %{self.compute_idx} : index") + gap_vec = self.mask_cse.generate(self.masks, f"vector.broadcast %{gap} : index to {index_shape}") + mask_var = self.mask_cse.generate(self.masks, f"arith.cmpi ult, %{step_vec}, %{gap_vec} : {index_shape}") + self.register_var_info(mask_var, [compute_vec_size, "i1"]) + return mask_shape, mask_var + def convert_indirect_indexing(self, index :sympy.Expr): if "tmp" not in str(index): return index diff --git a/PyTorchSimFrontend/mlir/mlir_common.py b/PyTorchSimFrontend/mlir/mlir_common.py index 8ab94049..9151ac0b 100644 --- a/PyTorchSimFrontend/mlir/mlir_common.py +++ b/PyTorchSimFrontend/mlir/mlir_common.py @@ -179,9 +179,12 @@ def set_info(outer, inner, arg_type): class MLIRMultiDimTile(): def __init__(self, tile_size, vector_lane, vlane_split_axis=None, vlane_stride=None, vec_size=None): + self.name = "" self._tile_size = list(tile_size) + self._tile_stride = None self.tile_axis_order = list(range(len(tile_size))) self.vec_size = vec_size + self.update_tile_stride() # Vector lane mapping config self.vector_lane = vector_lane @@ -190,12 +193,23 @@ def __init__(self, tile_size, vector_lane, vlane_split_axis=None, vlane_stride=N self.implicit_dim_size = None self.nr_rdim = 0 + def set_name(self, name: str): + self.name = name + def set_tile_size(self, tile_size, tile_axis_order=None): self._tile_size = tile_size if tile_axis_order is None: self.tile_axis_order = list(range(len(tile_size))) else: self.tile_axis_order = tile_axis_order + self.update_tile_stride() + + def set_tile_size_stride(self, tile_size, tile_stride): + self._tile_size = tile_size + self._tile_stride = tile_stride + + def get_name(self) -> str: + return 
self.name def get_tile_size(self): return self._tile_size @@ -216,7 +230,7 @@ def get_numel_per_lane(self): size *= dim_size return size - def get_tile_stride(self): + def update_tile_stride(self): strides = [1] * len(self._tile_size) init = 1 @@ -228,7 +242,10 @@ def get_tile_stride(self): for _, size, original_indices in sorted_pairs: strides[original_indices] = init init *= size - return strides + self._tile_stride = strides + + def get_tile_stride(self): + return self._tile_stride def get_tile_size_per_lane(self): tile_size_per_lane = list(self._tile_size) @@ -339,6 +356,7 @@ def __init__(self, kernel_group, reason=None): self.buffer_types : dict = None # format: dtype, numel, size, stride self.compute_idx = "compute_idx" self.compute_body_loop = LoopLevel(self.compute_idx, 1) + self.prologue_compute_body_loop = LoopLevel(self.compute_idx, 1) self.recodegen = reason # spad overflow, tile size, vlane stride self.stop_autotune = False @@ -479,8 +497,13 @@ def dummy_tile_size(): tile_size[0] = 2 * vlane_stride * self.vector_lane elif len(tile_size) == 3: tile_size[-1] = self.vector_lane - tile_size[-2] = 2 * self.vector_lane + tile_size[-2] = 4 * self.vector_lane + tile_size[-3] = 2 + elif len(tile_size) == 4: + tile_size[-1] = self.vector_lane + tile_size[-2] = 4 * self.vector_lane tile_size[-3] = 2 + tile_size[-4] = 1 else: raise NotImplementedError("dummy tile size fail!") return tile_size diff --git a/PyTorchSimFrontend/mlir/mlir_conv_mt_template.py b/PyTorchSimFrontend/mlir/mlir_conv_mt_template.py new file mode 100644 index 00000000..8cd57077 --- /dev/null +++ b/PyTorchSimFrontend/mlir/mlir_conv_mt_template.py @@ -0,0 +1,346 @@ +import os +import math +from sympy import Symbol, Number +from typing import List, Optional + +from PyTorchSimFrontend.mlir.mlir_common import MLIRKernelArgs +from PyTorchSimFrontend.mlir.mlir_template import MLIRTemplate +from PyTorchSimFrontend.mlir.mlir_template import MLIRTemplateKernel +from torch._inductor.ir import IRNode 
+from torch._inductor.codecache import write_atomic +import PyTorchSimFrontend.extension_codecache as extension_codecache +from PyTorchSimFrontend.mlir import mlir_common +from torch._inductor.codecache import get_hash +from PyTorchSimFrontend import extension_config + +CONV_TEMPLATE = r""" +// Multi Channel Tile Conv2D kernel +// BATCH = {{ BATCH }} +// I_C = {{ I_C }} +// I_H = {{ I_H }} +// I_W = {{ I_W }} +// O_C = {{ O_C }} +// K_H = {{ K_H }} +// K_W = {{ K_W }} +// O_H = {{ O_H }} +// O_W = {{ O_W }} +// TILE_M = {{ TILE_M }} +// TILE_N = {{ TILE_N }} +// TILE_K = {{ TILE_K }} +// TILE_I_H={{ TILE_I_H }}, +// TILE_I_W={{ TILE_I_W }}, +// TILE_O_H={{ TILE_O_H }}, +// TILE_O_W={{ TILE_O_W }}, +// TILE_K_H={{ TILE_K_H }}, +// TILE_K_W={{ TILE_K_W }}, +// SUB_TILE_M={{ SUB_TILE_M }}, +// SUB_TILE_N={{ SUB_TILE_N }}, +// SUB_TILE_I_W={{ SUB_TILE_I_W }}, +// SUB_TILE_K_H={{ SUB_TILE_K_H }}, +// SUB_TILE_K_W={{ SUB_TILE_K_W }}, +// PADDING_H = {{ PADDING_H }} +// PADDING_W = {{ PADDING_W }} +// STRIDE_H = {{ STRIDE_H }} +// STRIDE_W = {{ STRIDE_W }} +// DATA_STYPE = {{ DATA_STYPE }} + +#map_I_H = affine_map<(d0, d1) -> (d0 * {{ STRIDE_H }} + d1)> +#offset_w_map = affine_map<(d0, d1) -> (d0 * {{ kernel.get_spad_size_per_lane(1 * TILE_K, TILE_N) }} + d1 * {{ kernel.get_spad_size_per_lane(TILE_K, TILE_N) }})> +#offset_x_map = affine_map<(d0, d1) -> (d0 * {{ kernel.get_spad_size_per_lane(TILE_O_W * TILE_M, TILE_K) }} + d1 * {{ kernel.get_spad_size_per_lane(TILE_M, TILE_K) }})> +#offset_y_map = affine_map<(d0, d1) -> (d0 * {{ kernel.get_spad_size_per_lane(TILE_O_W * TILE_M, TILE_N) }} + d1 * {{ kernel.get_spad_size_per_lane(TILE_M, TILE_N) }})> +{{kernel.def_global_vars()}} + +func.func @{{ KERNEL_NAME }}{{kernel.def_conv_kernel(inputs=[X, W, BIAS], outputs=[Y], names_str="X, W, Bias, Y", padded_input_size=PADDED_INPUT_SIZE, input_reorder=input_reorder)}} { + {{ kernel.def_sram_buffer("X", X_tile_desc, indent_size=2) }} + {{ kernel.def_sram_buffer("W", W_tile_desc, 
indent_size=2) }} + {{ kernel.def_sram_buffer("Y", Y_tile_desc, indent_size=2) }} + %v0 = arith.constant dense<0.0> : vector<{{ kernel.get_spad_size_per_lane(TILE_O_H * TILE_M, TILE_N) }}xf32> + %c0 = arith.constant 0 : index + {{- kernel.def_local_vars(indent_size=2) }} + + affine.for %tile_m = 0 to {{ BATCH }} step {{ TILE_M }} { + affine.for %tile_n = 0 to {{ O_C }} step {{ TILE_N }} { + affine.for %o_h = 0 to {{ O_H }} step {{ TILE_O_H }} { + affine.for %o_w = 0 to {{ O_W }} step {{ TILE_O_W }} { + // Initialize output + {%- if BIAS %} + {{ kernel.def_dma_op("MVIN", "Bias", Bias_idx, Y_tile_desc, subtile_size=[SUB_TILE_M, SUB_TILE_N, TILE_O_H, TILE_O_W], indent_size=10) }} + {%- else %} + affine.vector_store %v0, %output_buffer[%c0, %c0, %c0, %c0] : memref<{{ TILE_O_H }}x{{ TILE_O_W }}x{{ TILE_M }}x{{ TILE_N }}xf32, 1>, vector<{{ kernel.get_spad_size_per_lane(TILE_O_H * TILE_O_W * TILE_M, TILE_N) }}xf32> + {%- endif %} + affine.for %k_h = 0 to {{ K_H }} step {{ TILE_K_H }} { + affine.for %tile_k = 0 to {{ I_C * K_W }} step {{ TILE_K }} { + %index_i_h = affine.apply #map_I_H(%o_h, %k_h) + // Load input matrix + {{ kernel.def_dma_op("MVIN", "X", X_idx, X_tile_desc, subtile_size=[SUB_TILE_I_H, SUB_TILE_I_W, SUB_TILE_M, SUB_TILE_K], indent_size=14) }} + {{ kernel.def_dma_op("MVIN", "W", W_idx, W_tile_desc, subtile_size=[SUB_TILE_K_H, SUB_TILE_K_W, SUB_TILE_K, SUB_TILE_N], indent_size=14) }} + // Compute body part + affine.for %tile_k_h = 0 to {{ TILE_K_H }} { // loop order should be fixed for timing simulation. Do not change this order. 
+ affine.for %tile_k_w = 0 to 1 { + %offset_w = affine.apply #offset_w_map(%tile_k_h, %tile_k_w) + %W_buffer = memref.reinterpret_cast %weight_buffer to offset: [%offset_w], sizes: [{{ TILE_K }}, {{ TILE_N }}], strides: [{{ TILE_N }}, 1] : {{ W_tile_desc.get_mlir_shape(DATA_STYPE) }} to memref<{{ TILE_K }}x{{ TILE_N }}xf32, strided<[{{ TILE_N }}, 1], offset: ?>, 1> + affine.for %tile_o_h = 0 to {{ TILE_O_H }} { + affine.for %tile_o_w = 0 to {{ TILE_O_W }} { + %tile_i_h = affine.apply #map_I_H(%tile_o_h, %tile_k_h) + %offset_x = affine.apply #offset_x_map(%tile_i_h, %tile_o_w) + %offset_y = affine.apply #offset_y_map(%tile_o_h, %tile_o_w) + %X_buffer = memref.reinterpret_cast %input_buffer to offset: [%offset_x], sizes: [{{ TILE_M }}, {{ TILE_K }}], strides: [{{ TILE_K }}, 1] : {{ X_tile_desc.get_mlir_shape(DATA_STYPE) }} to memref<{{ TILE_M }}x{{ TILE_K }}xf32, strided<[{{ TILE_K }}, 1], offset: ?>, 1> + %Y_buffer = memref.reinterpret_cast %output_buffer to offset: [%offset_y], sizes: [{{ TILE_M }}, {{ TILE_N }}], strides: [{{ TILE_N }}, 1] : {{ Y_tile_desc.get_mlir_shape(DATA_STYPE) }} to memref<{{ TILE_M }}x{{ TILE_N }}xf32, strided<[{{ TILE_N }}, 1], offset: ?>, 1> + linalg.matmul ins(%X_buffer, %W_buffer : memref<{{ TILE_M }}x{{ TILE_K }}xf32, strided<[{{ TILE_K }}, 1], offset: ?>, 1>, memref<{{ TILE_K }}x{{ TILE_N }}xf32, strided<[{{ TILE_N }}, 1], offset: ?>, 1>) + outs(%Y_buffer : memref<{{ TILE_M }}x{{ TILE_N }}xf32, strided<[{{ TILE_N }}, 1], offset: ?>, 1>) + } { inner_loop=true } + } { inner_loop=true } + } { inner_loop=true } + } { inner_loop=true } + } { accumulation_loop=true, subtile_loop="k" } + } { accumulation_loop=true } + // Store output matrix + {{kernel.store_output(indent_size=10)}} + } { outer_loop=true } + } { outer_loop=true } + } { outer_loop=true, subtile_loop="n" } + } { outer_loop=true, subtile_loop="m" } + return +} +""" + +WRAPPER_TEMPLATE = r""" +def {{ FUNC_NAME }}{{kernel.def_wrapper()}}: + # Padding input + padded_shape = 
list(X.shape) + padded_shape[2] += 2 * {{ PADDING_H }} + padded_shape[3] += 2 * {{ PADDING_W }} + X_padding = torch.zeros(padded_shape, device=X.device) + X_padding[:, :, {{ PADDING_H }}:X.shape[2] + {{ PADDING_H }}, {{ PADDING_W }}:X.shape[3] + {{ PADDING_W }}] = X + + # Tanspose inputs + {%- for buf, name in kernel.get_conv_inputs().items() %} + {%- if name == "X" %} + {{ name }} = {{ name }}_padding.permute(2, 0, 3, 1).contiguous() # (BATCH, I_C, I_H, I_W) -> (I_H, BATCH, I_W, I_C) + {%- elif name == "W" %} + {{ name }} = {{ name }}.permute(2, 3, 1, 0).contiguous() # (O_C, I_C, K_H, K_W) -> (K_H, K_W, I_C, O_C) + {%- elif name == "Bias" %} + {{ name }} = {{ name }} + {%- endif %} + {%- endfor %} + + # Launch kernel + {{ KERNEL_NAME }} + {%- if BACKENDSIM_EAGER_MODE %} + yield ({{KERNEL_NAME}}, ) + {%- endif %} +""" + +class MLIRConvMultiTileTemplate(MLIRTemplate): + def __init__(self, input_nodes, layout, input_reorder=None, **kwargs): + super().__init__("kernel", input_nodes, layout, input_reorder) + self.stride = kwargs["stride"] + self.padding = kwargs["padding"] + self.dilation = kwargs["dilation"] + self.weight_shape = [str(i) for i in input_nodes[1].layout.size] + self.input_shape = [i for i in input_nodes[0].layout.size] + self.function_name = "Conv2D_" + "_".join(self.weight_shape)+ "_" \ + + "_".join([str(i) for i in self.stride]) \ + + "_" + "_".join([str(i) for i in self.padding]) \ + + "_" + "_".join([str(i) for i in self.dilation]) + self.kernel_args = ['X', 'W', 'Bias', 'Y'] + + def get_padded_input_size(self, X): + input_padded = list(X.layout.size) + input_padded[2] += 2 * self.padding[0] + input_padded[3] += 2 * self.padding[1] + return math.prod(input_padded) + + def render(self, + kernel: MLIRTemplateKernel, + template_buffer_node = None, + epilogue_nodes: Optional[List[IRNode]] = None, + **kwargs): + # Extract input arguments info + if template_buffer_node is not None: + self.output_node = template_buffer_node + self.kernel = kernel + 
self.epilogue_nodes = epilogue_nodes + + X, W = self.input_nodes[0], self.input_nodes[1] + Y = self.output_node + Bias = None if len(self.input_nodes) == 2 else self.input_nodes[2] + + if epilogue_nodes is not None: + extra_node_rw = { + item.name for epilogue_node in epilogue_nodes + for item in epilogue_node.read_writes.reads | epilogue_node.read_writes.writes + if item.name != Y.name + } + n_extra_node = len(extra_node_rw) if epilogue_nodes is not None else 0 + + BATCH, I_C, I_H, I_W = X.layout.size + O_C, _, K_H, K_W = W.layout.size + O_H = Y.layout.size[2] if template_buffer_node is None else template_buffer_node.layout.size[2] + O_W = Y.layout.size[3] if template_buffer_node is None else template_buffer_node.layout.size[3] + PADDING_H=self.padding[0] + PADDING_W=self.padding[1] + STRIDE_H=self.stride[0] + STRIDE_W=self.stride[1] + + # Select tile size adn template + conv_template = CONV_TEMPLATE + TILE_K_H, TILE_K_W, TILE_O_H, TILE_O_W, TILE_M, TILE_N, TILE_K, TILE_I_H, TILE_I_W, SUB_TILE_I_H, SUB_TILE_I_W, SUB_TILE_K_H, SUB_TILE_K_W, SUB_TILE_M, SUB_TILE_N, SUB_TILE_K, TOG_latency = self.select_tile(kernel, n_extra_node, BATCH, I_C, O_C, K_H, K_W, O_H, O_W) + SUB_TILE_N = TILE_N if TILE_N > 512 else SUB_TILE_N + TOG_latency = 8 if TOG_latency < 8 else TOG_latency + kernel.loop_size = [TOG_latency, TILE_N, TILE_K] + + # Prepare tile descriptors + vlane_stride = 1 + vlane_split_axis = 1 + X_tile_size = [TILE_I_H, TILE_O_W, TILE_M, TILE_K] + X_tile_stride = [TILE_O_W*TILE_M*TILE_K, TILE_M*TILE_K, 1, TILE_M] + X_tile_desc = mlir_common.MLIRMultiDimTile(X_tile_size, kernel.vector_lane, 3, vlane_stride) + X_tile_desc.set_tile_size_stride(X_tile_size, X_tile_stride) + X_tile_desc.set_name("input_buffer") + X_dim = [Symbol("index_i_h"), Symbol("o_w"), Symbol("tile_m"), Symbol("tile_k")] + X_idx = [X_dim[0]*(I_W+2*PADDING_W)*BATCH*I_C, X_dim[1]*I_C*STRIDE_W, X_dim[2]*I_C*(I_W+2*PADDING_W), X_dim[3]] + + W_tile_size = [TILE_K_H, 1, TILE_K, TILE_N] + W_tile_stride = 
[TILE_K * TILE_N, TILE_K * TILE_N, 1, TILE_K] + W_tile_desc = mlir_common.MLIRMultiDimTile(X_tile_size, kernel.vector_lane, 3, vlane_stride) + W_tile_desc.set_tile_size_stride(W_tile_size, W_tile_stride) + W_tile_desc.set_name("weight_buffer") + W_dim = [Symbol("k_h"), Symbol("k_w"), Symbol("tile_k"), Symbol("tile_n")] + W_idx = [W_dim[0]*K_W*I_C*O_C , Symbol("c0"), W_dim[2]*O_C, W_dim[3]] + + Y_tile_size = [TILE_M, TILE_N, TILE_O_H, TILE_O_W] + Y_tile_stride = [1, TILE_M, TILE_O_W * TILE_M * TILE_N, TILE_M * TILE_N] # N, C, H, W + Y_tile_desc = mlir_common.MLIRMultiDimTile(Y_tile_size, kernel.vector_lane, vlane_split_axis, vlane_stride) + Y_tile_desc.set_tile_size_stride(Y_tile_size, Y_tile_stride) + Y_tile_desc.set_name("output_buffer") + Y_dim = [Symbol("tile_m"), Symbol("tile_n"), Symbol("o_h"), Symbol("o_w")] + Y_idx = [Y_dim[0]*O_C*O_H*O_W, Y_dim[1]*O_H*O_W, Y_dim[2]*O_W, Y_dim[3]] + + # Extract Bias info + Bias_idx = [Number(0), Symbol("tile_n"), Number(0), Number(0)] + + kernel.render_options = dict( + KERNEL_NAME=self.name, + kernel=kernel, + X=X, W=W, Y=Y, BIAS=Bias, + PADDED_INPUT_SIZE=self.get_padded_input_size(X), + BATCH=BATCH, + I_C=I_C, + I_H=I_H, + I_W=I_W, + O_C=O_C, + K_H=K_H, + K_W=K_W, + O_H=O_H, + O_W=O_W, + TILE_M=TILE_M, + TILE_N=TILE_N, + TILE_K=TILE_K, + TILE_I_H=TILE_I_H, + TILE_I_W=TILE_I_W, + TILE_O_H=TILE_O_H, + TILE_O_W=TILE_O_W, + TILE_K_H=TILE_K_H, + TILE_K_W=TILE_K_W, + SUB_TILE_M=SUB_TILE_M, + SUB_TILE_N=SUB_TILE_N, + SUB_TILE_K=SUB_TILE_K, + SUB_TILE_I_H=SUB_TILE_I_H, + SUB_TILE_I_W=SUB_TILE_I_W, + SUB_TILE_K_H=SUB_TILE_K_H, + SUB_TILE_K_W=SUB_TILE_K_W, + PADDING_H=PADDING_H, + PADDING_W=PADDING_W, + STRIDE_H=STRIDE_H, + STRIDE_W=STRIDE_W, + X_tile_desc = X_tile_desc, + W_tile_desc = W_tile_desc, + Y_tile_desc = Y_tile_desc, + X_idx = X_idx, + W_idx = W_idx, + Bias_idx = Bias_idx, + DATA_STYPE="f32", + input_reorder=self.input_reorder + ) + + kernel.epilogue_info = dict( + output_node = self.output_node.name, + sram_var = 
"output_buffer", + dram_var = "Y", + dram_idx = Y_idx, + dram_tile_desc = Y_tile_desc, + dim_aliasing = {"index0":"tile_m", "index1":"tile_n", "index2":"o_h", "index3":"o_w"} + ) + kernel.exception_nodes["X"] = {"numel" : (I_W+2*PADDING_W)*(I_H+2*PADDING_H)*I_C*BATCH} + code = self._template_from_string(conv_template).render(**kernel.render_options) + kernel.add_loop_info([kernel.render_options["K_H"], kernel.render_options["K_W"], kernel.render_options["O_H"], kernel.render_options["O_W"], kernel.render_options["BATCH"], kernel.render_options["O_C"], kernel.render_options["I_C"]], [kernel.render_options["TILE_M"], kernel.render_options["TILE_N"], kernel.render_options["TILE_K"]]) + return code + + def select_tile(self, kernel, n_extra_node, BATCH, I_C, O_C, K_H, K_W, O_H, O_W): + TILE_K_H, TILE_K_W, TILE_O_H, TILE_O_W, TILE_M, TILE_N, TILE_K = kernel.conv_combination_mapping(BATCH, O_C, I_C, K_H, K_W, O_H, O_W, self.stride, self.dilation, n_extra_node) + SUB_TILE_M = TILE_M if TILE_M < kernel.vector_lane else kernel.vector_lane + SUB_TILE_N = TILE_N if TILE_N < kernel.vector_lane else kernel.vector_lane + + TILE_K_H, TILE_K_W, TILE_O_H, TILE_O_W, TILE_M, TILE_N, TILE_K = kernel.conv_multi_tile_mapping(BATCH, O_C, I_C, K_H, K_W, O_H, O_W, self.stride, self.dilation, n_extra_node) + TILE_I_W = 1 + (TILE_O_W - 1) * self.stride[1] + TILE_I_H = 1 + (TILE_O_H - 1) * self.stride[0] + (TILE_K_H - 1) * self.dilation[0] + SUB_TILE_I_H, SUB_TILE_I_W, SUB_TILE_K_H, SUB_TILE_K_W = 1, 1, 1, 1 + SUB_TILE_K = TILE_K + + TOG_latency = O_W if TILE_M > O_W else TILE_M + return TILE_K_H,TILE_K_W,TILE_O_H,TILE_O_W,TILE_M,TILE_N,TILE_K,TILE_I_H,TILE_I_W,SUB_TILE_I_H,SUB_TILE_I_W,SUB_TILE_K_H,SUB_TILE_K_W,SUB_TILE_M,SUB_TILE_N,SUB_TILE_K,TOG_latency + + def outer_func_render(self, kernel_name, input_args): + X, W = self.input_nodes[0], self.input_nodes[1] + Y = self.output_node + Bias = None if len(self.input_nodes) == 2 else self.input_nodes[2] + + eager_mode = 
int(os.environ.get('BACKENDSIM_EAGER_MODE', default=False)) + options = dict( + kernel=self.kernel, + KERNEL_NAME=kernel_name, + FUNC_NAME=self.function_name + f"_{len(input_args)}", + INPUT=X, + WEIGHT=W, + BIAS=Bias, + OUTPUT=Y, + PADDING_H=self.padding[0], + PADDING_W=self.padding[1], + VALIDATION_MODE=extension_config.CONFIG_TORCHSIM_VALIDATION_MODE, + BACKENDSIM_EAGER_MODE=eager_mode, + input_reorder=self.input_reorder + ) + code = self._template_from_string(WRAPPER_TEMPLATE).render(**options) + return code, self.function_name + f"_{len(input_args)}" + + def get_arg_attributes(self): + arg_attributes = [] + + X = self.input_nodes[0] + X_shape = [X.get_size()[i] for i in (2, 3, 0, 1)] + X_shape[0] += 2 * self.padding[0] + X_shape[1] += 2 * self.padding[1] + + def compute_stride(shape): + stride = [1] * len(shape) + for i in range(len(shape)-2, -1, -1): + stride[i] = stride[i+1] * shape[i+1] + return stride + + X_stride = compute_stride(X_shape) + arg_attributes.append([X.data.data.name, [MLIRKernelArgs.MLIR_ARGS_IN, X.layout.dtype, math.prod(X_shape), X_shape, X_stride]]) + + return arg_attributes + + def codegen_header(self, code, extra_headers): + write_path = extension_codecache.get_write_path(code) + if not os.path.exists(write_path): + os.makedirs(write_path) + spike_write_path = os.path.join(write_path, "global_var.h") + gem5_write_path = os.path.join(write_path, "gem5_global_var.h") + if not os.path.exists(spike_write_path): + write_atomic(spike_write_path, extra_headers[0]) + if not os.path.exists(gem5_write_path): + write_atomic(gem5_write_path, extra_headers[1]) + self.hash_value = get_hash(code.strip()) \ No newline at end of file diff --git a/PyTorchSimFrontend/mlir/mlir_conv_sb_template.py b/PyTorchSimFrontend/mlir/mlir_conv_sb_template.py new file mode 100644 index 00000000..6c31776d --- /dev/null +++ b/PyTorchSimFrontend/mlir/mlir_conv_sb_template.py @@ -0,0 +1,342 @@ +import os +import math +from sympy import Symbol, Number +from typing import 
List, Optional + +from PyTorchSimFrontend.mlir.mlir_common import MLIRKernelArgs +from PyTorchSimFrontend.mlir.mlir_template import MLIRTemplate +from PyTorchSimFrontend.mlir.mlir_template import MLIRTemplateKernel +from torch._inductor.ir import IRNode +from torch._inductor.codecache import write_atomic +import PyTorchSimFrontend.extension_codecache as extension_codecache +from PyTorchSimFrontend.mlir import mlir_common +from torch._inductor.codecache import get_hash +from PyTorchSimFrontend import extension_config + +CONV_TEMPLATE = r""" +// Single Batch Conv2D kernel +// BATCH = {{ BATCH }} +// I_C = {{ I_C }} +// I_H = {{ I_H }} +// I_W = {{ I_W }} +// O_C = {{ O_C }} +// K_H = {{ K_H }} +// K_W = {{ K_W }} +// O_H = {{ O_H }} +// O_W = {{ O_W }} +// TILE_M = {{ TILE_M }} +// TILE_N = {{ TILE_N }} +// TILE_K = {{ TILE_K }} +// TILE_I_H={{ TILE_I_H }}, +// TILE_I_W={{ TILE_I_W }}, +// TILE_O_H={{ TILE_O_H }}, +// TILE_O_W={{ TILE_O_W }}, +// TILE_K_H={{ TILE_K_H }}, +// TILE_K_W={{ TILE_K_W }}, +// SUB_TILE_M={{ SUB_TILE_M }}, +// SUB_TILE_N={{ SUB_TILE_N }}, +// SUB_TILE_I_W={{ SUB_TILE_I_W }}, +// SUB_TILE_K_H={{ SUB_TILE_K_H }}, +// SUB_TILE_K_W={{ SUB_TILE_K_W }}, +// PADDING_H = {{ PADDING_H }} +// PADDING_W = {{ PADDING_W }} +// STRIDE_H = {{ STRIDE_H }} +// STRIDE_W = {{ STRIDE_W }} +// DATA_STYPE = {{ DATA_STYPE }} + +#map_I_H = affine_map<(d0, d1) -> (d0 * {{ STRIDE_H }} + d1)> +#map_I_W = affine_map<(d0, d1) -> (d0 * {{ STRIDE_W }} + d1)> +#offset_w_map = affine_map<(d0, d1) -> (d0 * {{ kernel.get_spad_size_per_lane(TILE_K_W * TILE_K, TILE_N) }} + d1 * {{ kernel.get_spad_size_per_lane(TILE_K, TILE_N) }})> +#offset_x_map = affine_map<(d0, d1) -> (d0 * {{ kernel.get_spad_size_per_lane(TILE_I_W, TILE_K) }} + d1)> +#offset_y_map = affine_map<(d0, d1) -> (d0 * {{ kernel.get_spad_size_per_lane(TILE_M, TILE_N) }} + d1 * {{ kernel.get_spad_size_per_lane(TILE_M, TILE_N) }})> +{{kernel.def_global_vars()}} + +func.func @{{ KERNEL_NAME 
}}{{kernel.def_conv_kernel(inputs=[X, W, BIAS], outputs=[Y], names_str="X, W, Bias, Y", padded_input_size=PADDED_INPUT_SIZE, input_reorder=input_reorder)}} { + {{ kernel.def_sram_buffer("X", X_tile_desc, indent_size=2) }} + {{ kernel.def_sram_buffer("W", W_tile_desc, indent_size=2) }} + {{ kernel.def_sram_buffer("Y", Y_tile_desc, indent_size=2) }} + %v0 = arith.constant dense<0.0> : vector<{{ kernel.get_spad_size_per_lane(TILE_O_H * TILE_M, TILE_N) }}xf32> + %c0 = arith.constant 0 : index + {{- kernel.def_local_vars(indent_size=2) }} + affine.for %tile_n = 0 to {{ O_C }} step {{ TILE_N }} { + affine.for %o_h = 0 to {{ O_H }} step {{ TILE_O_H }} { + affine.for %tile_m = 0 to {{ O_W }} step {{ TILE_M }} { + // Initialize output + {%- if BIAS %} + {{ kernel.def_dma_op("MVIN", "Bias", Bias_idx, Y_tile_desc, subtile_size=[1, SUB_TILE_N, TILE_O_H, SUB_TILE_M], indent_size=8) }} + {%- else %} + affine.vector_store %v0, %output_buffer[%c0, %c0, %c0, %c0] : {{ Y_tile_desc.get_mlir_shape(DATA_STYPE) }}, vector<{{ kernel.get_spad_size_per_lane(TILE_O_H * TILE_M, TILE_N) }}xf32> + {%- endif %} + affine.for %k_h = 0 to {{ K_H }} step {{ TILE_K_H }} { + affine.for %k_w = 0 to {{ K_W }} step {{ TILE_K_W }} { + affine.for %tile_k = 0 to {{ I_C }} step {{ TILE_K }} { + %index_i_h = affine.apply #map_I_H(%o_h, %k_h) + %index_i_w = affine.apply #map_I_W(%tile_m, %k_w) + // Load input & weight matrix + {{ kernel.def_dma_op("MVIN", "X", X_idx, X_tile_desc, subtile_size=[1, SUB_TILE_I_H, SUB_TILE_M, SUB_TILE_K], indent_size=14) }} + {{ kernel.def_dma_op("MVIN", "W", W_idx, W_tile_desc, subtile_size=[SUB_TILE_K_H, SUB_TILE_K_W, SUB_TILE_K, SUB_TILE_N], indent_size=14) }} + // Compute body part + affine.for %tile_k_h = 0 to {{ TILE_K_H }} { // loop order should be fixed for timing simulation. Do not change this order. 
+ affine.for %tile_k_w = 0 to {{ TILE_K_W }} { + %offset_w = affine.apply #offset_w_map(%tile_k_h, %tile_k_w) + %W_buffer = memref.reinterpret_cast %weight_buffer to offset: [%offset_w], sizes: [{{ TILE_K }}, {{ TILE_N }}], strides: [{{ TILE_N }}, 1] : {{ W_tile_desc.get_mlir_shape(DATA_STYPE) }} to memref<{{ TILE_K }}x{{ TILE_N }}xf32, strided<[{{ TILE_N }}, 1], offset: ?>, 1> + affine.for %tile_o_h = 0 to {{ TILE_O_H }} { + affine.for %tile_o_w = 0 to {{ 1 }} { // TILE_O_W + %tile_i_h = affine.apply #map_I_H(%tile_o_h, %tile_k_h) + %offset_x = affine.apply #offset_x_map(%tile_i_h, %tile_k_w) + %offset_y = affine.apply #offset_y_map(%tile_o_h, %tile_o_w) + %X_buffer = memref.reinterpret_cast %input_buffer to offset: [%offset_x], sizes: [{{ TILE_M }}, {{ TILE_K }}], strides: [{{ TILE_K }}, 1] : {{ X_tile_desc.get_mlir_shape(DATA_STYPE) }} to memref<{{ TILE_M }}x{{ TILE_K }}xf32, strided<[{{ TILE_K }}, 1], offset: ?>, 1> + %Y_buffer = memref.reinterpret_cast %output_buffer to offset: [%offset_y], sizes: [{{ TILE_M }}, {{ TILE_N }}], strides: [{{ TILE_N }}, 1] : {{ Y_tile_desc.get_mlir_shape(DATA_STYPE) }} to memref<{{ TILE_M }}x{{ TILE_N }}xf32, strided<[{{ TILE_N }}, 1], offset: ?>, 1> + linalg.matmul ins(%X_buffer, %W_buffer : memref<{{ TILE_M }}x{{ TILE_K }}xf32, strided<[{{ TILE_K }}, 1], offset: ?>, 1>, memref<{{ TILE_K }}x{{ TILE_N }}xf32, strided<[{{ TILE_N }}, 1], offset: ?>, 1>) + outs(%Y_buffer : memref<{{ TILE_M }}x{{ TILE_N }}xf32, strided<[{{ TILE_N }}, 1], offset: ?>, 1>) + } { inner_loop=true } + } { inner_loop=true } + } { inner_loop=true } + } { inner_loop=true } + } { accumulation_loop=true, subtile_loop="k" } + } { accumulation_loop=true } + } { accumulation_loop=true } + // Store output matrix + {{kernel.store_output(indent_size=8)}} + } { outer_loop=true, subtile_loop="m" } + } { outer_loop=true } + } { outer_loop=true, subtile_loop="n" } + return +} +""" + +WRAPPER_TEMPLATE = r""" +def {{ FUNC_NAME }}{{kernel.def_wrapper()}}: + # Padding input 
+ padded_shape = list(X.shape) + padded_shape[2] += 2 * {{ PADDING_H }} + padded_shape[3] += 2 * {{ PADDING_W }} + X_padding = torch.zeros(padded_shape, device=X.device) + X_padding[:, :, {{ PADDING_H }}:X.shape[2] + {{ PADDING_H }}, {{ PADDING_W }}:X.shape[3] + {{ PADDING_W }}] = X + + # Tanspose inputs + {%- for buf, name in kernel.get_conv_inputs().items() %} + {%- if name == "X" %} + {{ name }} = {{ name }}_padding.permute(0, 2, 3, 1).contiguous() # (BATCH, I_C, I_H, I_W) -> (BATCH, I_H, I_W, I_C) + {%- elif name == "W" %} + {{ name }} = {{ name }}.permute(2, 3, 1, 0).contiguous() # (O_C, I_C, K_H, K_W) -> (K_H, K_W, I_C, O_C) + {%- elif name == "Bias" %} + {{ name }} = {{ name }} + {%- endif %} + {%- endfor %} + + # Launch kernel + {{ KERNEL_NAME }} + {%- if BACKENDSIM_EAGER_MODE %} + yield ({{KERNEL_NAME}}, ) + {%- endif %} +""" + +class MLIRConvSingleBatchTemplate(MLIRTemplate): + def __init__(self, input_nodes, layout, input_reorder=None, **kwargs): + super().__init__("kernel", input_nodes, layout, input_reorder) + self.stride = kwargs["stride"] + self.padding = kwargs["padding"] + self.dilation = kwargs["dilation"] + self.weight_shape = [str(i) for i in input_nodes[1].layout.size] + self.input_shape = [i for i in input_nodes[0].layout.size] + self.function_name = "Conv2D_" + "_".join(self.weight_shape)+ "_" \ + + "_".join([str(i) for i in self.stride]) \ + + "_" + "_".join([str(i) for i in self.padding]) \ + + "_" + "_".join([str(i) for i in self.dilation]) + self.kernel_args = ['X', 'W', 'Bias', 'Y'] + + def get_padded_input_size(self, X): + input_padded = list(X.layout.size) + input_padded[2] += 2 * self.padding[0] + input_padded[3] += 2 * self.padding[1] + return math.prod(input_padded) + + def render(self, + kernel: MLIRTemplateKernel, + template_buffer_node = None, + epilogue_nodes: Optional[List[IRNode]] = None, + **kwargs): + # Extract input arguments info + if template_buffer_node is not None: + self.output_node = template_buffer_node + self.kernel 
= kernel + self.epilogue_nodes = epilogue_nodes + + X, W = self.input_nodes[0], self.input_nodes[1] + Y = self.output_node + Bias = None if len(self.input_nodes) == 2 else self.input_nodes[2] + + if epilogue_nodes is not None: + extra_node_rw = { + item.name for epilogue_node in epilogue_nodes + for item in epilogue_node.read_writes.reads | epilogue_node.read_writes.writes + if item.name != Y.name + } + n_extra_node = len(extra_node_rw) if epilogue_nodes is not None else 0 + + BATCH, I_C, I_H, I_W = X.layout.size + O_C, _, K_H, K_W = W.layout.size + O_H = Y.layout.size[2] if template_buffer_node is None else template_buffer_node.layout.size[2] + O_W = Y.layout.size[3] if template_buffer_node is None else template_buffer_node.layout.size[3] + PADDING_H=self.padding[0] + PADDING_W=self.padding[1] + STRIDE_H=self.stride[0] + STRIDE_W=self.stride[1] + + # Select tile size adn template + conv_template = CONV_TEMPLATE + TILE_K_H, TILE_K_W, TILE_O_H, TILE_O_W, TILE_M, TILE_N, TILE_K, TILE_I_H, TILE_I_W, SUB_TILE_I_H, SUB_TILE_I_W, SUB_TILE_K_H, SUB_TILE_K_W, SUB_TILE_M, SUB_TILE_N, SUB_TILE_K, TOG_latency = self.select_tile(kernel, n_extra_node, BATCH, I_C, O_C, K_H, K_W, O_H, O_W) + SUB_TILE_N = TILE_N if TILE_N > 512 else SUB_TILE_N + TOG_latency = 8 if TOG_latency < 8 else TOG_latency + kernel.loop_size = [TOG_latency, TILE_N, TILE_K] + # Prepare tile descriptors + vlane_stride = 1 + vlane_split_axis = 1 + X_tile_size = [1, TILE_I_H, TILE_I_W, TILE_K] + X_tile_stride = [TILE_I_H * TILE_I_W * TILE_K , TILE_I_W * TILE_K, 1, TILE_I_W] + X_tile_desc = mlir_common.MLIRMultiDimTile(X_tile_size, kernel.vector_lane, 3, vlane_stride) + X_tile_desc.set_tile_size_stride(X_tile_size, X_tile_stride) + X_tile_desc.set_name("input_buffer") + X_dim = [Symbol("c0"), Symbol("index_i_h"), Symbol("index_i_w"), Symbol("tile_k")] + X_idx = [X_dim[0]*((I_W+2*PADDING_W)*(I_H+2*PADDING_H)*I_C), X_dim[1]*((I_W+2*PADDING_W)*I_C), X_dim[2]*I_C, X_dim[3]] + + W_tile_size = [TILE_K_H, TILE_K_W, 
TILE_K, TILE_N] + W_tile_stride = [TILE_K_W * TILE_K * TILE_N, TILE_K * TILE_N, 1, TILE_K] + W_tile_desc = mlir_common.MLIRMultiDimTile(X_tile_size, kernel.vector_lane, 3, vlane_stride) + W_tile_desc.set_tile_size_stride(W_tile_size, W_tile_stride) + W_tile_desc.set_name("weight_buffer") + W_dim = [Symbol("k_h"), Symbol("k_w"), Symbol("tile_k"), Symbol("tile_n")] + W_idx = [W_dim[0]*K_W*I_C*O_C , W_dim[1]*I_C*O_C, W_dim[2]*O_C, W_dim[3]] + + Y_tile_size = [1, TILE_N, TILE_O_H, TILE_M] + Y_tile_stride = [TILE_O_H * TILE_M * TILE_N, TILE_M, TILE_M * TILE_N, 1] # N, C, H, W + Y_tile_desc = mlir_common.MLIRMultiDimTile(Y_tile_size, kernel.vector_lane, vlane_split_axis, vlane_stride) + Y_tile_desc.set_tile_size_stride(Y_tile_size, Y_tile_stride) + Y_tile_desc.set_name("output_buffer") + Y_idx = [Number(0), Symbol("tile_n")*O_H*O_W, Symbol("o_h")*O_W, Symbol("tile_m")] + + # Extract Bias info + Bias_idx = [Number(0), Symbol("tile_n"), Number(0), Number(0)] + + kernel.render_options = dict( + KERNEL_NAME=self.name, + kernel=kernel, + X=X, W=W, Y=Y, BIAS=Bias, + PADDED_INPUT_SIZE=self.get_padded_input_size(X), + BATCH=BATCH, + I_C=I_C, + I_H=I_H, + I_W=I_W, + O_C=O_C, + K_H=K_H, + K_W=K_W, + O_H=O_H, + O_W=O_W, + TILE_M=TILE_M, + TILE_N=TILE_N, + TILE_K=TILE_K, + TILE_I_H=TILE_I_H, + TILE_I_W=TILE_I_W, + TILE_O_H=TILE_O_H, + TILE_O_W=TILE_O_W, + TILE_K_H=TILE_K_H, + TILE_K_W=TILE_K_W, + SUB_TILE_M=SUB_TILE_M, + SUB_TILE_N=SUB_TILE_N, + SUB_TILE_K=SUB_TILE_K, + SUB_TILE_I_H=SUB_TILE_I_H, + SUB_TILE_I_W=SUB_TILE_I_W, + SUB_TILE_K_H=SUB_TILE_K_H, + SUB_TILE_K_W=SUB_TILE_K_W, + PADDING_H=PADDING_H, + PADDING_W=PADDING_W, + STRIDE_H=STRIDE_H, + STRIDE_W=STRIDE_W, + X_tile_desc = X_tile_desc, + W_tile_desc = W_tile_desc, + Y_tile_desc = Y_tile_desc, + X_idx = X_idx, + W_idx = W_idx, + Bias_idx = Bias_idx, + DATA_STYPE="f32", + input_reorder=self.input_reorder + ) + + kernel.epilogue_info = dict( + output_node = self.output_node.name, + sram_var = "output_buffer", + dram_var = 
"Y", + dram_idx = Y_idx, + dram_tile_desc = Y_tile_desc, + dim_aliasing = {"index0":"c0", "index1":"tile_n", "index2":"o_h", "index3":"tile_m"} + ) + kernel.exception_nodes["X"] = {"numel" : (I_W+2*PADDING_W)*(I_H+2*PADDING_H)*I_C*BATCH} + code = self._template_from_string(conv_template).render(**kernel.render_options) + kernel.add_loop_info([kernel.render_options["K_H"], kernel.render_options["K_W"], kernel.render_options["O_H"], kernel.render_options["O_W"], kernel.render_options["BATCH"], kernel.render_options["O_C"], kernel.render_options["I_C"]], [kernel.render_options["TILE_M"], kernel.render_options["TILE_N"], kernel.render_options["TILE_K"]]) + return code + + def select_tile(self, kernel, n_extra_node, BATCH, I_C, O_C, K_H, K_W, O_H, O_W): + TILE_K_H, TILE_K_W, TILE_O_H, TILE_O_W, TILE_M, TILE_N, TILE_K = kernel.conv_single_batch_mapping(BATCH, O_C, I_C, K_H, 1, O_H, O_W, self.stride, self.dilation, n_extra_node) # TODO: implement K_W + TILE_I_H = 1 + (TILE_O_H - 1) * self.stride[0] + (TILE_K_H - 1) * self.dilation[0] + TILE_I_W = 1 + (TILE_O_W - 1) * self.stride[1] + (TILE_K_W - 1) * self.dilation[1] + SUB_TILE_I_H, SUB_TILE_I_W, SUB_TILE_K_H, SUB_TILE_K_W = 1, 1, 1, 1 + SUB_TILE_M = TILE_I_W if TILE_I_W < kernel.vector_lane else kernel.vector_lane + SUB_TILE_N = TILE_N if TILE_N < kernel.vector_lane else kernel.vector_lane + SUB_TILE_K = TILE_K + TOG_latency = O_W if TILE_M > O_W else TILE_M + return TILE_K_H,TILE_K_W,TILE_O_H,TILE_O_W,TILE_M,TILE_N,TILE_K,TILE_I_H,TILE_I_W,SUB_TILE_I_H,SUB_TILE_I_W,SUB_TILE_K_H,SUB_TILE_K_W,SUB_TILE_M,SUB_TILE_N,SUB_TILE_K,TOG_latency + + def outer_func_render(self, kernel_name, input_args): + X, W = self.input_nodes[0], self.input_nodes[1] + Y = self.output_node + Bias = None if len(self.input_nodes) == 2 else self.input_nodes[2] + + eager_mode = int(os.environ.get('BACKENDSIM_EAGER_MODE', default=False)) + options = dict( + kernel=self.kernel, + KERNEL_NAME=kernel_name, + FUNC_NAME=self.function_name + 
f"_{len(input_args)}", + INPUT=X, + WEIGHT=W, + BIAS=Bias, + OUTPUT=Y, + PADDING_H=self.padding[0], + PADDING_W=self.padding[1], + VALIDATION_MODE=extension_config.CONFIG_TORCHSIM_VALIDATION_MODE, + BACKENDSIM_EAGER_MODE=eager_mode, + input_reorder=self.input_reorder + ) + code = self._template_from_string(WRAPPER_TEMPLATE).render(**options) + return code, self.function_name + f"_{len(input_args)}" + + def get_arg_attributes(self): + arg_attributes = [] + + X = self.input_nodes[0] + X_shape = [X.get_size()[i] for i in (2, 3, 0, 1)] + X_shape[0] += 2 * self.padding[0] + X_shape[1] += 2 * self.padding[1] + + def compute_stride(shape): + stride = [1] * len(shape) + for i in range(len(shape)-2, -1, -1): + stride[i] = stride[i+1] * shape[i+1] + return stride + + X_stride = compute_stride(X_shape) + arg_attributes.append([X.data.data.name, [MLIRKernelArgs.MLIR_ARGS_IN, X.layout.dtype, math.prod(X_shape), X_shape, X_stride]]) + + return arg_attributes + + def codegen_header(self, code, extra_headers): + write_path = extension_codecache.get_write_path(code) + if not os.path.exists(write_path): + os.makedirs(write_path) + spike_write_path = os.path.join(write_path, "global_var.h") + gem5_write_path = os.path.join(write_path, "gem5_global_var.h") + if not os.path.exists(spike_write_path): + write_atomic(spike_write_path, extra_headers[0]) + if not os.path.exists(gem5_write_path): + write_atomic(gem5_write_path, extra_headers[1]) + self.hash_value = get_hash(code.strip()) \ No newline at end of file diff --git a/PyTorchSimFrontend/mlir/mlir_conv_sbs_template.py b/PyTorchSimFrontend/mlir/mlir_conv_sbs_template.py new file mode 100644 index 00000000..a4ea0b20 --- /dev/null +++ b/PyTorchSimFrontend/mlir/mlir_conv_sbs_template.py @@ -0,0 +1,343 @@ +import os +import math +from sympy import Symbol, Number +from typing import List, Optional + +from PyTorchSimFrontend.mlir.mlir_common import MLIRKernelArgs +from PyTorchSimFrontend.mlir.mlir_template import MLIRTemplate +from 
PyTorchSimFrontend.mlir.mlir_template import MLIRTemplateKernel +from torch._inductor.ir import IRNode +from torch._inductor.codecache import write_atomic +import PyTorchSimFrontend.extension_codecache as extension_codecache +from PyTorchSimFrontend.mlir import mlir_common +from torch._inductor.codecache import get_hash +from PyTorchSimFrontend import extension_config + +CONV_TEMPLATE = r""" +// Single Batch Conv2D (Stride != 1) kernel +// BATCH = {{ BATCH }} +// I_C = {{ I_C }} +// I_H = {{ I_H }} +// I_W = {{ I_W }} +// O_C = {{ O_C }} +// K_H = {{ K_H }} +// K_W = {{ K_W }} +// O_H = {{ O_H }} +// O_W = {{ O_W }} +// TILE_M = {{ TILE_M }} +// TILE_N = {{ TILE_N }} +// TILE_K = {{ TILE_K }} +// TILE_I_H={{ TILE_I_H }}, +// TILE_I_W={{ TILE_I_W }}, +// TILE_O_H={{ TILE_O_H }}, +// TILE_O_W={{ TILE_O_W }}, +// TILE_K_H={{ TILE_K_H }}, +// TILE_K_W={{ TILE_K_W }}, +// SUB_TILE_M={{ SUB_TILE_M }}, +// SUB_TILE_N={{ SUB_TILE_N }}, +// SUB_TILE_I_W={{ SUB_TILE_I_W }}, +// SUB_TILE_K_H={{ SUB_TILE_K_H }}, +// SUB_TILE_K_W={{ SUB_TILE_K_W }}, +// PADDING_H = {{ PADDING_H }} +// PADDING_W = {{ PADDING_W }} +// STRIDE_H = {{ STRIDE_H }} +// STRIDE_W = {{ STRIDE_W }} +// DATA_STYPE = {{ DATA_STYPE }} + +#map_I_H = affine_map<(d0, d1) -> (d0 * {{ STRIDE_H }} + d1)> +#offset_w_map = affine_map<(d0, d1) -> (d0 * {{ kernel.get_spad_size_per_lane(TILE_K_W * TILE_K, TILE_N) }} + d1 * {{ kernel.get_spad_size_per_lane(TILE_K, TILE_N) }})> +#offset_x_map = affine_map<(d0, d1) -> (d0 * {{ kernel.get_spad_size_per_lane(TILE_M * TILE_K_W, TILE_K) }} + d1 * {{ kernel.get_spad_size_per_lane(TILE_M, TILE_K) }})> +#offset_y_map = affine_map<(d0, d1) -> (d0 * {{ kernel.get_spad_size_per_lane(TILE_M, TILE_N) }} + d1 * {{ kernel.get_spad_size_per_lane(TILE_M, TILE_N) }})> + +{{kernel.def_global_vars()}} + +func.func @{{ KERNEL_NAME }}{{kernel.def_conv_kernel(inputs=[X, W, BIAS], outputs=[Y], names_str="X, W, Bias, Y", padded_input_size=PADDED_INPUT_SIZE, input_reorder=input_reorder)}} { + {{ 
kernel.def_sram_buffer("X", X_tile_desc, indent_size=2) }} + {{ kernel.def_sram_buffer("W", W_tile_desc, indent_size=2) }} + {{ kernel.def_sram_buffer("Y", Y_tile_desc, indent_size=2) }} + %v0 = arith.constant dense<0.0> : vector<{{ kernel.get_spad_size_per_lane(TILE_O_H * TILE_M, TILE_N) }}xf32> + %c0 = arith.constant 0 : index + {{- kernel.def_local_vars(indent_size=2) }} + + affine.for %tile_n = 0 to {{ O_C }} step {{ TILE_N }} { + affine.for %o_h = 0 to {{ O_H }} step {{ TILE_O_H }} { + affine.for %tile_m = 0 to {{ O_W }} step {{ TILE_M }} { + // Initialize output + {%- if BIAS %} + {{ kernel.def_dma_op("MVIN", "Bias", Bias_idx, Y_tile_desc, subtile_size=[1, SUB_TILE_N, TILE_O_H, SUB_TILE_M], indent_size=8) }} + {%- else %} + affine.vector_store %v0, %output_buffer[%c0, %c0, %c0, %c0] : {{ Y_tile_desc.get_mlir_shape(DATA_STYPE) }}, vector<{{ kernel.get_spad_size_per_lane(TILE_O_H * TILE_M, TILE_N) }}xf32> + {%- endif %} + affine.for %k_h = 0 to {{ K_H }} step {{ TILE_K_H }} { + affine.for %k_w = 0 to {{ K_W }} step {{ TILE_K_W }} { + affine.for %tile_k = 0 to {{ I_C }} step {{ TILE_K }} { + %index_i_h = affine.apply #map_I_H(%o_h, %k_h) + // Load input & weight matrix + {{ kernel.def_dma_op("MVIN", "X", X_idx, X_tile_desc, subtile_size=[SUB_TILE_I_H, SUB_TILE_K_W, SUB_TILE_M, SUB_TILE_K], indent_size=14) }} + {{ kernel.def_dma_op("MVIN", "W", W_idx, W_tile_desc, subtile_size=[SUB_TILE_K_H, SUB_TILE_K_W, SUB_TILE_K, SUB_TILE_N], indent_size=14) }} + // Compute body part + affine.for %tile_k_h = 0 to {{ TILE_K_H }} { // loop order should be fixed for timing simulation. Do not change this order. 
+ affine.for %tile_k_w = 0 to {{ TILE_K_W }} { + %offset_w = affine.apply #offset_w_map(%tile_k_h, %tile_k_w) + %W_buffer = memref.reinterpret_cast %weight_buffer to offset: [%offset_w], sizes: [{{ TILE_K }}, {{ TILE_N }}], strides: [{{ TILE_N }}, 1] : memref<{{ TILE_K_H }}x{{ TILE_K_W }}x{{ TILE_K }}x{{ TILE_N }}xf32, 1> to memref<{{ TILE_K }}x{{ TILE_N }}xf32, strided<[{{ TILE_N }}, 1], offset: ?>, 1> + affine.for %tile_o_h = 0 to {{ TILE_O_H }} { + affine.for %tile_o_w = 0 to {{ 1 }} { // TILE_O_W + %tile_i_h = affine.apply #map_I_H(%tile_o_h, %tile_k_h) + %offset_x = affine.apply #offset_x_map(%tile_i_h, %tile_k_w) + %offset_y = affine.apply #offset_y_map(%tile_o_h, %tile_o_w) + %X_buffer = memref.reinterpret_cast %input_buffer to offset: [%offset_x], sizes: [{{ TILE_M }}, {{ TILE_K }}], strides: [{{ TILE_K }}, 1] : {{ X_tile_desc.get_mlir_shape(DATA_STYPE) }} to memref<{{ TILE_M }}x{{ TILE_K }}xf32, strided<[{{ TILE_K }}, 1], offset: ?>, 1> + %Y_buffer = memref.reinterpret_cast %output_buffer to offset: [%offset_y], sizes: [{{ TILE_M }}, {{ TILE_N }}], strides: [{{ TILE_N }}, 1] : {{ Y_tile_desc.get_mlir_shape(DATA_STYPE) }} to memref<{{ TILE_M }}x{{ TILE_N }}xf32, strided<[{{ TILE_N }}, 1], offset: ?>, 1> + linalg.matmul ins(%X_buffer, %W_buffer : memref<{{ TILE_M }}x{{ TILE_K }}xf32, strided<[{{ TILE_K }}, 1], offset: ?>, 1>, memref<{{ TILE_K }}x{{ TILE_N }}xf32, strided<[{{ TILE_N }}, 1], offset: ?>, 1>) + outs(%Y_buffer : memref<{{ TILE_M }}x{{ TILE_N }}xf32, strided<[{{ TILE_N }}, 1], offset: ?>, 1>) + } { inner_loop=true } + } { inner_loop=true } + } { inner_loop=true } + } { inner_loop=true } + } { accumulation_loop=true, subtile_loop="k" } + } { accumulation_loop=true } + } { accumulation_loop=true } + // Store output matrix + {{kernel.store_output(indent_size=8)}} + } { outer_loop=true, subtile_loop="m" } + } { outer_loop=true } + } { outer_loop=true, subtile_loop="n" } + return +} +""" + +WRAPPER_TEMPLATE = r""" +def {{ FUNC_NAME 
}}{{kernel.def_wrapper()}}: + # Padding input + padded_shape = list(X.shape) + padded_shape[2] += 2 * {{ PADDING_H }} + padded_shape[3] += 2 * {{ PADDING_W }} + X_padding = torch.zeros(padded_shape, device=X.device) + X_padding[:, :, {{ PADDING_H }}:X.shape[2] + {{ PADDING_H }}, {{ PADDING_W }}:X.shape[3] + {{ PADDING_W }}] = X + + # Tanspose inputs + {%- for buf, name in kernel.get_conv_inputs().items() %} + {%- if name == "X" %} + {{ name }} = {{ name }}_padding.permute(0, 2, 3, 1).contiguous() # (BATCH, I_C, I_H, I_W) -> (BATCH, I_H, I_W, I_C) + {%- elif name == "W" %} + {{ name }} = {{ name }}.permute(2, 3, 1, 0).contiguous() # (O_C, I_C, K_H, K_W) -> (K_H, K_W, I_C, O_C) + {%- elif name == "Bias" %} + {{ name }} = {{ name }} + {%- endif %} + {%- endfor %} + + # Launch kernel + {{ KERNEL_NAME }} + {%- if BACKENDSIM_EAGER_MODE %} + yield ({{KERNEL_NAME}}, ) + {%- endif %} +""" + +class MLIRConvSingleBatchStridedTemplate(MLIRTemplate): + def __init__(self, input_nodes, layout, input_reorder=None, **kwargs): + super().__init__("kernel", input_nodes, layout, input_reorder) + self.stride = kwargs["stride"] + self.padding = kwargs["padding"] + self.dilation = kwargs["dilation"] + self.weight_shape = [str(i) for i in input_nodes[1].layout.size] + self.input_shape = [i for i in input_nodes[0].layout.size] + self.function_name = "Conv2D_" + "_".join(self.weight_shape)+ "_" \ + + "_".join([str(i) for i in self.stride]) \ + + "_" + "_".join([str(i) for i in self.padding]) \ + + "_" + "_".join([str(i) for i in self.dilation]) + self.kernel_args = ['X', 'W', 'Bias', 'Y'] + + def get_padded_input_size(self, X): + input_padded = list(X.layout.size) + input_padded[2] += 2 * self.padding[0] + input_padded[3] += 2 * self.padding[1] + return math.prod(input_padded) + + def render(self, + kernel: MLIRTemplateKernel, + template_buffer_node = None, + epilogue_nodes: Optional[List[IRNode]] = None, + **kwargs): + # Extract input arguments info + if template_buffer_node is not None: + 
self.output_node = template_buffer_node + self.kernel = kernel + self.epilogue_nodes = epilogue_nodes + + X, W = self.input_nodes[0], self.input_nodes[1] + Y = self.output_node + Bias = None if len(self.input_nodes) == 2 else self.input_nodes[2] + + if epilogue_nodes is not None: + extra_node_rw = { + item.name for epilogue_node in epilogue_nodes + for item in epilogue_node.read_writes.reads | epilogue_node.read_writes.writes + if item.name != Y.name + } + n_extra_node = len(extra_node_rw) if epilogue_nodes is not None else 0 + + BATCH, I_C, I_H, I_W = X.layout.size + O_C, _, K_H, K_W = W.layout.size + O_H = Y.layout.size[2] if template_buffer_node is None else template_buffer_node.layout.size[2] + O_W = Y.layout.size[3] if template_buffer_node is None else template_buffer_node.layout.size[3] + PADDING_H=self.padding[0] + PADDING_W=self.padding[1] + STRIDE_H=self.stride[0] + STRIDE_W=self.stride[1] + + # Select tile size adn template + conv_template = CONV_TEMPLATE + TILE_K_H, TILE_K_W, TILE_O_H, TILE_O_W, TILE_M, TILE_N, TILE_K, TILE_I_H, TILE_I_W, SUB_TILE_I_H, SUB_TILE_I_W, SUB_TILE_K_H, SUB_TILE_K_W, SUB_TILE_M, SUB_TILE_N, SUB_TILE_K, TOG_latency = self.select_tile(kernel, n_extra_node, BATCH, I_C, O_C, K_H, K_W, O_H, O_W) + SUB_TILE_N = TILE_N if TILE_N > 512 else SUB_TILE_N + TOG_latency = 8 if TOG_latency < 8 else TOG_latency + kernel.loop_size = [TOG_latency, TILE_N, TILE_K] + + # Prepare tile descriptors + vlane_stride = 1 + vlane_split_axis = 1 + X_tile_size = [TILE_I_H, TILE_K_H, TILE_M, TILE_K] + X_tile_stride = [TILE_K_W*TILE_M*TILE_K, TILE_M*TILE_K, 1, TILE_M] + X_tile_desc = mlir_common.MLIRMultiDimTile(X_tile_size, kernel.vector_lane, 3, vlane_stride) + X_tile_desc.set_tile_size_stride(X_tile_size, X_tile_stride) + X_tile_desc.set_name("input_buffer") + X_dim = [Symbol("index_i_h"), Symbol("k_w"), Symbol("tile_m"), Symbol("tile_k")] + X_idx = [X_dim[0]*((I_W+2*PADDING_W)*I_C), X_dim[1]*I_C, X_dim[2]*(I_C*STRIDE_W), X_dim[3]] + + W_tile_size = 
[TILE_K_H, TILE_K_W, TILE_K, TILE_N] + W_tile_stride = [TILE_K_W * TILE_K * TILE_N, TILE_K * TILE_N, 1, TILE_K] + W_tile_desc = mlir_common.MLIRMultiDimTile(X_tile_size, kernel.vector_lane, 3, vlane_stride) + W_tile_desc.set_tile_size_stride(W_tile_size, W_tile_stride) + W_tile_desc.set_name("weight_buffer") + W_dim = [Symbol("k_h"), Symbol("k_w"), Symbol("tile_k"), Symbol("tile_n")] + W_idx = [W_dim[0]*K_W*I_C*O_C , W_dim[1]*I_C*O_C, W_dim[2]*O_C, W_dim[3]] + + Y_tile_size = [1, TILE_N, TILE_O_H, TILE_M] + Y_tile_stride = [TILE_O_H * TILE_M * TILE_N, TILE_M, TILE_M * TILE_N, 1] # N, C, H, W + Y_tile_desc = mlir_common.MLIRMultiDimTile(Y_tile_size, kernel.vector_lane, vlane_split_axis, vlane_stride) + Y_tile_desc.set_tile_size_stride(Y_tile_size, Y_tile_stride) + Y_tile_desc.set_name("output_buffer") + Y_idx = [Number(0), Symbol("tile_n")*O_H*O_W, Symbol("o_h")*O_W, Symbol("tile_m")] + + # Extract Bias info + Bias_idx = [Number(0), Symbol("tile_n"), Number(0), Number(0)] + + kernel.render_options = dict( + KERNEL_NAME=self.name, + kernel=kernel, + X=X, W=W, Y=Y, BIAS=Bias, + PADDED_INPUT_SIZE=self.get_padded_input_size(X), + BATCH=BATCH, + I_C=I_C, + I_H=I_H, + I_W=I_W, + O_C=O_C, + K_H=K_H, + K_W=K_W, + O_H=O_H, + O_W=O_W, + TILE_M=TILE_M, + TILE_N=TILE_N, + TILE_K=TILE_K, + TILE_I_H=TILE_I_H, + TILE_I_W=TILE_I_W, + TILE_O_H=TILE_O_H, + TILE_O_W=TILE_O_W, + TILE_K_H=TILE_K_H, + TILE_K_W=TILE_K_W, + SUB_TILE_M=SUB_TILE_M, + SUB_TILE_N=SUB_TILE_N, + SUB_TILE_K=SUB_TILE_K, + SUB_TILE_I_H=SUB_TILE_I_H, + SUB_TILE_I_W=SUB_TILE_I_W, + SUB_TILE_K_H=SUB_TILE_K_H, + SUB_TILE_K_W=SUB_TILE_K_W, + PADDING_H=PADDING_H, + PADDING_W=PADDING_W, + STRIDE_H=STRIDE_H, + STRIDE_W=STRIDE_W, + X_tile_desc = X_tile_desc, + W_tile_desc = W_tile_desc, + Y_tile_desc = Y_tile_desc, + X_idx = X_idx, + W_idx = W_idx, + Bias_idx = Bias_idx, + DATA_STYPE="f32", + input_reorder=self.input_reorder + ) + + kernel.epilogue_info = dict( + output_node = self.output_node.name, + sram_var = 
"output_buffer", + dram_var = "Y", + dram_idx = Y_idx, + dram_tile_desc = Y_tile_desc, + dim_aliasing = {"index0":"c0", "index1":"tile_n", "index2":"o_h", "index3":"tile_m"} + ) + kernel.exception_nodes["X"] = {"numel" : (I_W+2*PADDING_W)*(I_H+2*PADDING_H)*I_C*BATCH} + code = self._template_from_string(conv_template).render(**kernel.render_options) + kernel.add_loop_info([kernel.render_options["K_H"], kernel.render_options["K_W"], kernel.render_options["O_H"], kernel.render_options["O_W"], kernel.render_options["BATCH"], kernel.render_options["O_C"], kernel.render_options["I_C"]], [kernel.render_options["TILE_M"], kernel.render_options["TILE_N"], kernel.render_options["TILE_K"]]) + return code + + def select_tile(self, kernel, n_extra_node, BATCH, I_C, O_C, K_H, K_W, O_H, O_W): + TILE_K_H, TILE_K_W, TILE_O_H, TILE_O_W, TILE_M, TILE_N, TILE_K = kernel.conv_single_batch_mapping(BATCH, O_C, I_C, K_H, K_W, O_H, O_W, self.stride, self.dilation, n_extra_node) # TODO: implement K_W + TILE_I_H = 1 + (TILE_O_H - 1) * self.stride[0] + (TILE_K_H - 1) * self.dilation[0] + TILE_I_W = 1 + (TILE_O_W - 1) * self.stride[1] + (TILE_K_W - 1) * self.dilation[1] + SUB_TILE_I_H, SUB_TILE_I_W, SUB_TILE_K_H, SUB_TILE_K_W = 1, 1, 1, 1 + SUB_TILE_M = TILE_M if TILE_M < kernel.vector_lane else kernel.vector_lane + SUB_TILE_N = TILE_N if TILE_N < kernel.vector_lane else kernel.vector_lane + SUB_TILE_K = TILE_K + TOG_latency = O_W if TILE_M > O_W else TILE_M + return TILE_K_H,TILE_K_W,TILE_O_H,TILE_O_W,TILE_M,TILE_N,TILE_K,TILE_I_H,TILE_I_W,SUB_TILE_I_H,SUB_TILE_I_W,SUB_TILE_K_H,SUB_TILE_K_W,SUB_TILE_M,SUB_TILE_N,SUB_TILE_K,TOG_latency + + def outer_func_render(self, kernel_name, input_args): + X, W = self.input_nodes[0], self.input_nodes[1] + Y = self.output_node + Bias = None if len(self.input_nodes) == 2 else self.input_nodes[2] + + eager_mode = int(os.environ.get('BACKENDSIM_EAGER_MODE', default=False)) + options = dict( + kernel=self.kernel, + KERNEL_NAME=kernel_name, + 
FUNC_NAME=self.function_name + f"_{len(input_args)}", + INPUT=X, + WEIGHT=W, + BIAS=Bias, + OUTPUT=Y, + PADDING_H=self.padding[0], + PADDING_W=self.padding[1], + VALIDATION_MODE=extension_config.CONFIG_TORCHSIM_VALIDATION_MODE, + BACKENDSIM_EAGER_MODE=eager_mode, + input_reorder=self.input_reorder + ) + code = self._template_from_string(WRAPPER_TEMPLATE).render(**options) + return code, self.function_name + f"_{len(input_args)}" + + def get_arg_attributes(self): + arg_attributes = [] + + X = self.input_nodes[0] + X_shape = [X.get_size()[i] for i in (2, 3, 0, 1)] + X_shape[0] += 2 * self.padding[0] + X_shape[1] += 2 * self.padding[1] + + def compute_stride(shape): + stride = [1] * len(shape) + for i in range(len(shape)-2, -1, -1): + stride[i] = stride[i+1] * shape[i+1] + return stride + + X_stride = compute_stride(X_shape) + arg_attributes.append([X.data.data.name, [MLIRKernelArgs.MLIR_ARGS_IN, X.layout.dtype, math.prod(X_shape), X_shape, X_stride]]) + + return arg_attributes + + def codegen_header(self, code, extra_headers): + write_path = extension_codecache.get_write_path(code) + if not os.path.exists(write_path): + os.makedirs(write_path) + spike_write_path = os.path.join(write_path, "global_var.h") + gem5_write_path = os.path.join(write_path, "gem5_global_var.h") + if not os.path.exists(spike_write_path): + write_atomic(spike_write_path, extra_headers[0]) + if not os.path.exists(gem5_write_path): + write_atomic(gem5_write_path, extra_headers[1]) + self.hash_value = get_hash(code.strip()) \ No newline at end of file diff --git a/PyTorchSimFrontend/mlir/mlir_conv_template.py b/PyTorchSimFrontend/mlir/mlir_conv_template.py index 7a3b4b19..73cf710f 100644 --- a/PyTorchSimFrontend/mlir/mlir_conv_template.py +++ b/PyTorchSimFrontend/mlir/mlir_conv_template.py @@ -1,16 +1,15 @@ import os import math -from sympy import divisors, Range -from typing import List, Optional, cast +from sympy import Symbol, Number +from typing import List, Optional from 
PyTorchSimFrontend.mlir.mlir_common import MLIRKernelArgs from PyTorchSimFrontend.mlir.mlir_template import MLIRTemplate from PyTorchSimFrontend.mlir.mlir_template import MLIRTemplateKernel -from torch._inductor.ir import Buffer from torch._inductor.ir import IRNode -from torch._inductor.ir import ReinterpretView from torch._inductor.codecache import write_atomic import PyTorchSimFrontend.extension_codecache as extension_codecache +from PyTorchSimFrontend.mlir import mlir_common from torch._inductor.codecache import get_hash from PyTorchSimFrontend import extension_config @@ -43,56 +42,30 @@ // PADDING_W = {{ PADDING_W }} // STRIDE_H = {{ STRIDE_H }} // STRIDE_W = {{ STRIDE_W }} -// DILATION_H = {{ DILATION_H }} -// DILATION_W = {{ DILATION_W }} // DATA_STYPE = {{ DATA_STYPE }} -// DATA_SIZE = {{ DATA_SIZE }} -#map0 = affine_map<(d0, d1, d2, d3) -> (d0 * {{ O_W * BATCH * O_C }} + d1 * {{ BATCH * O_C }} + d2 * {{ O_C }} + d3)> // output (O_H, O_W, BATCH, O_C) -#map1 = affine_map<(d0, d1, d2, d3) -> (d0 * {{ (I_W + 2 * PADDING_W) * BATCH * I_C }} + d1 * {{ BATCH * I_C }} + d2 * {{ I_C }} + d3)> // input (I_H, I_W, BATCH, I_C) -#map2 = affine_map<(d0, d1, d2, d3) -> (d0 * {{ K_W * I_C * O_C }} + d1 * {{ I_C * O_C }} + d2 * {{ O_C }} + d3)> // weight (K_H, K_W, I_C, O_C) #map_I_H = affine_map<(d0, d1) -> (d0 * {{ STRIDE_H }} + d1)> #map_I_W = affine_map<(d0, d1) -> (d0 * {{ STRIDE_W }} + d1)> #offset_w_map = affine_map<(d0, d1) -> (d0 * {{ kernel.get_spad_size_per_lane(TILE_K_W * TILE_K, TILE_N) }} + d1 * {{ kernel.get_spad_size_per_lane(TILE_K, TILE_N) }})> #offset_x_map = affine_map<(d0, d1) -> (d0 * {{ kernel.get_spad_size_per_lane(TILE_I_W * TILE_M, TILE_K) }} + d1 * {{ kernel.get_spad_size_per_lane(TILE_M, TILE_K) }})> #offset_y_map = affine_map<(d0, d1) -> (d0 * {{ kernel.get_spad_size_per_lane(TILE_O_W * TILE_M, TILE_N) }} + d1 * {{ kernel.get_spad_size_per_lane(TILE_M, TILE_N) }})> - -memref.global @X_spad : memref<{{ TILE_I_H }}x{{ TILE_I_W }}x{{ TILE_M }}x{{ 
TILE_K }}xf32, 1> -memref.global @W_spad : memref<{{ TILE_K_H }}x{{ TILE_K_W }}x{{ TILE_K }}x{{ TILE_N }}xf32, 1> -memref.global @Y_spad : memref<{{ TILE_O_H }}x{{ TILE_O_W }}x{{ TILE_M }}x{{ TILE_N }}xf32, 1> {{kernel.def_global_vars()}} func.func @{{ KERNEL_NAME }}{{kernel.def_conv_kernel(inputs=[X, W, BIAS], outputs=[Y], names_str="X, W, Bias, Y", padded_input_size=PADDED_INPUT_SIZE, input_reorder=input_reorder)}} { - %c_mvin = arith.constant 2 : index - %c_mvin2 = arith.constant 1 : index - %c_mvin3 = arith.constant 14 : index - %c_mvout = arith.constant 3 : index - %vstride = arith.constant 1 : index - %input_axis = arith.constant 3 : index - %weight_axis = arith.constant 2 : index - %input_buffer = memref.get_global @X_spad : memref<{{ TILE_I_H }}x{{ TILE_I_W }}x{{ TILE_M }}x{{ TILE_K }}xf32, 1> - %weight_buffer = memref.get_global @W_spad : memref<{{ TILE_K_H }}x{{ TILE_K_W }}x{{ TILE_K }}x{{ TILE_N }}xf32, 1> - %output_buffer = memref.get_global @Y_spad : memref<{{ TILE_O_H }}x{{ TILE_O_W }}x{{ TILE_M }}x{{ TILE_N }}xf32, 1> - %tag = memref.alloc() : memref<1xi32> - %tag0 = memref.alloc() : memref<1xi32> - %tag1 = memref.alloc() : memref<1xi32> - %tag2 = memref.alloc() : memref<1xi32> - %tag3 = memref.alloc() : memref<1xi32> - %v0 = arith.constant dense<0.0> : vector<{{ kernel.get_spad_size_per_lane(TILE_O_H * TILE_O_W * TILE_M, TILE_N) }}xf32> + {{ kernel.def_sram_buffer("X", X_tile_desc, indent_size=2) }} + {{ kernel.def_sram_buffer("W", W_tile_desc, indent_size=2) }} + {{ kernel.def_sram_buffer("Y", Y_tile_desc, indent_size=2) }} + %v0 = arith.constant dense<0.0> : vector<{{ kernel.get_spad_size_per_lane(TILE_O_H * TILE_M, TILE_N) }}xf32> %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %c2 = arith.constant 2 : index - {{- kernel.def_local_vars() }} + {{ kernel.def_local_vars(indent_size=2) }} - affine.for %o_h = 0 to {{ O_H }} step {{ TILE_O_H }} { - affine.for %o_w = 0 to {{ O_W }} step {{ TILE_O_W }} { - affine.for %tile_m = 0 to {{ 
BATCH }} step {{ TILE_M }} { - affine.for %tile_n = 0 to {{ O_C }} step {{ TILE_N }} { - %index0 = affine.apply #map0(%o_h, %o_w, %tile_m, %tile_n) + affine.for %tile_m = 0 to {{ BATCH }} step {{ TILE_M }} { + affine.for %tile_n = 0 to {{ O_C }} step {{ TILE_N }} { + affine.for %o_h = 0 to {{ O_H }} step {{ TILE_O_H }} { + affine.for %o_w = 0 to {{ O_W }} step {{ TILE_O_W }} { // Initialize output {%- if BIAS %} - memref.dma_start %Bias[%tile_n], %output_buffer[%c0, %c0, %c0, %c0], %c_mvin, %tag0[%c0], %c0, %vstride - : memref<{{ O_C }}xf32>, memref<{{ TILE_O_H }}x{{ TILE_O_W }}x{{ TILE_M }}x{{ TILE_N }}xf32, 1>, memref<1xi32> { subtile_size=[{{ TILE_O_H }}, {{ TILE_O_W }}, {{ SUB_TILE_M }}, {{ SUB_TILE_N }}], async=1, sram_stride=[{{ TILE_O_W * TILE_M * TILE_N }}, {{ TILE_M * TILE_N }}, 1, {{ TILE_M }}]} + {{ kernel.def_dma_op("MVIN", "Bias", Bias_idx, Y_tile_desc, subtile_size=[SUB_TILE_M, SUB_TILE_N, TILE_O_H, TILE_O_W], indent_size=10) }} {%- else %} affine.vector_store %v0, %output_buffer[%c0, %c0, %c0, %c0] : memref<{{ TILE_O_H }}x{{ TILE_O_W }}x{{ TILE_M }}x{{ TILE_N }}xf32, 1>, vector<{{ kernel.get_spad_size_per_lane(TILE_O_H * TILE_O_W * TILE_M, TILE_N) }}xf32> {%- endif %} @@ -101,406 +74,37 @@ affine.for %tile_k = 0 to {{ I_C }} step {{ TILE_K }} { %index_i_h = affine.apply #map_I_H(%o_h, %k_h) %index_i_w = affine.apply #map_I_W(%o_w, %k_w) - %index1 = affine.apply #map1(%index_i_h, %index_i_w, %tile_m, %tile_k) // input index - %index2 = affine.apply #map2(%k_h, %k_w, %tile_k, %tile_n) // weight index // Load input matrix - memref.dma_start %X[%index1], %input_buffer[%c0, %c0, %c0, %c0], %c_mvin, %tag1[%c0], %input_axis, %vstride - : memref<{{ BATCH * I_C * (I_H + 2 * PADDING_H) * (I_W + 2 * PADDING_W) }}xf32>, memref<{{ TILE_I_H }}x{{ TILE_I_W }}x{{ TILE_M }}x{{ TILE_K }}xf32, 1>, memref<1xi32> { subtile_size=[{{ SUB_TILE_I_H }}, {{ SUB_TILE_I_W }}, {{ SUB_TILE_M }}, {{ SUB_TILE_K }}], async=1, sram_stride=[{{ TILE_I_W * TILE_M * TILE_K }}, {{ TILE_M * 
TILE_K }}, 1, {{ TILE_M }}]} - // Load kernel matrix - memref.dma_start %W[%index2], %weight_buffer[%c0, %c0, %c0, %c0], %c_mvin, %tag2[%c0], %input_axis, %vstride - : memref<{{ O_C * I_C * K_H * K_W }}xf32>, memref<{{ TILE_K_H }}x{{ TILE_K_W }}x{{ TILE_K }}x{{ TILE_N }}xf32, 1>, memref<1xi32> { subtile_size=[{{ SUB_TILE_K_H }}, {{ SUB_TILE_K_W }}, {{ SUB_TILE_K }}, {{ SUB_TILE_N }}], async=1, sram_stride=[{{ TILE_K_W * TILE_K * TILE_N }}, {{ TILE_K * TILE_N }}, 1, {{ TILE_K }}]} + {{ kernel.def_dma_op("MVIN", "X", X_idx, X_tile_desc, subtile_size=[SUB_TILE_I_H, SUB_TILE_I_W, SUB_TILE_M, SUB_TILE_K], indent_size=16) }} + {{ kernel.def_dma_op("MVIN", "W", W_idx, W_tile_desc, subtile_size=[SUB_TILE_K_H, SUB_TILE_K_W, SUB_TILE_K, SUB_TILE_N], indent_size=16) }} + // Compute body part affine.for %tile_k_h = 0 to {{ TILE_K_H }} { // loop order should be fixed for timing simulation. Do not change this order. affine.for %tile_k_w = 0 to {{ TILE_K_W }} { %offset_w = affine.apply #offset_w_map(%tile_k_h, %tile_k_w) - %W_buffer = memref.reinterpret_cast %weight_buffer to offset: [%offset_w], sizes: [{{ TILE_K }}, {{ TILE_N }}], strides: [{{ TILE_N }}, 1] : memref<{{ TILE_K_H }}x{{ TILE_K_W }}x{{ TILE_K }}x{{ TILE_N }}xf32, 1> to memref<{{ TILE_K }}x{{ TILE_N }}xf32, strided<[{{ TILE_N }}, 1], offset: ?>, 1> + %W_buffer = memref.reinterpret_cast %weight_buffer to offset: [%offset_w], sizes: [{{ TILE_K }}, {{ TILE_N }}], strides: [{{ TILE_N }}, 1] : {{ W_tile_desc.get_mlir_shape(DATA_STYPE) }} to memref<{{ TILE_K }}x{{ TILE_N }}xf32, strided<[{{ TILE_N }}, 1], offset: ?>, 1> affine.for %tile_o_h = 0 to {{ TILE_O_H }} { affine.for %tile_o_w = 0 to {{ TILE_O_W }} { %tile_i_h = affine.apply #map_I_H(%tile_o_h, %tile_k_h) %tile_i_w = affine.apply #map_I_W(%tile_o_w, %tile_k_w) %offset_x = affine.apply #offset_x_map(%tile_i_h, %tile_i_w) %offset_y = affine.apply #offset_y_map(%tile_o_h, %tile_o_w) - %X_buffer = memref.reinterpret_cast %input_buffer to offset: [%offset_x], sizes: 
[{{ TILE_M }}, {{ TILE_K }}], strides: [{{ TILE_K }}, 1] : memref<{{ TILE_I_H }}x{{ TILE_I_W }}x{{ TILE_M }}x{{ TILE_K }}xf32, 1> to memref<{{ TILE_M }}x{{ TILE_K }}xf32, strided<[{{ TILE_K }}, 1], offset: ?>, 1> - %Y_buffer = memref.reinterpret_cast %output_buffer to offset: [%offset_y], sizes: [{{ TILE_M }}, {{ TILE_N }}], strides: [{{ TILE_N }}, 1] : memref<{{ TILE_O_H }}x{{ TILE_O_W }}x{{ TILE_M }}x{{ TILE_N }}xf32, 1> to memref<{{ TILE_M }}x{{ TILE_N }}xf32, strided<[{{ TILE_N }}, 1], offset: ?>, 1> + %X_buffer = memref.reinterpret_cast %input_buffer to offset: [%offset_x], sizes: [{{ TILE_M }}, {{ TILE_K }}], strides: [{{ TILE_K }}, 1] : {{ X_tile_desc.get_mlir_shape(DATA_STYPE) }} to memref<{{ TILE_M }}x{{ TILE_K }}xf32, strided<[{{ TILE_K }}, 1], offset: ?>, 1> + %Y_buffer = memref.reinterpret_cast %output_buffer to offset: [%offset_y], sizes: [{{ TILE_M }}, {{ TILE_N }}], strides: [{{ TILE_N }}, 1] : {{ Y_tile_desc.get_mlir_shape(DATA_STYPE) }} to memref<{{ TILE_M }}x{{ TILE_N }}xf32, strided<[{{ TILE_N }}, 1], offset: ?>, 1> linalg.matmul ins(%X_buffer, %W_buffer : memref<{{ TILE_M }}x{{ TILE_K }}xf32, strided<[{{ TILE_K }}, 1], offset: ?>, 1>, memref<{{ TILE_K }}x{{ TILE_N }}xf32, strided<[{{ TILE_N }}, 1], offset: ?>, 1>) outs(%Y_buffer : memref<{{ TILE_M }}x{{ TILE_N }}xf32, strided<[{{ TILE_N }}, 1], offset: ?>, 1>) } { inner_loop=true } } { inner_loop=true } } { inner_loop=true } } { inner_loop=true } - } { accumulation_loop=true } + } { accumulation_loop=true, subtile_loop="k" } } { accumulation_loop=true } } { accumulation_loop=true } // Store output matrix {{kernel.store_output(indent_size=10)}} } { outer_loop=true } } { outer_loop=true } - } { outer_loop=true } - } { outer_loop=true } - return -} -""" - -MULTI_TILE_CONV_TEMPLATE = r""" -// Multi Channel Tile Conv2D kernel -// BATCH = {{ BATCH }} -// I_C = {{ I_C }} -// I_H = {{ I_H }} -// I_W = {{ I_W }} -// O_C = {{ O_C }} -// K_H = {{ K_H }} -// K_W = {{ K_W }} -// O_H = {{ O_H }} -// O_W = {{ 
O_W }} -// TILE_M = {{ TILE_M }} -// TILE_N = {{ TILE_N }} -// TILE_K = {{ TILE_K }} -// TILE_I_H={{ TILE_I_H }}, -// TILE_I_W={{ TILE_I_W }}, -// TILE_O_H={{ TILE_O_H }}, -// TILE_O_W={{ TILE_O_W }}, -// TILE_K_H={{ TILE_K_H }}, -// TILE_K_W={{ TILE_K_W }}, -// SUB_TILE_M={{ SUB_TILE_M }}, -// SUB_TILE_N={{ SUB_TILE_N }}, -// SUB_TILE_I_W={{ SUB_TILE_I_W }}, -// SUB_TILE_K_H={{ SUB_TILE_K_H }}, -// SUB_TILE_K_W={{ SUB_TILE_K_W }}, -// PADDING_H = {{ PADDING_H }} -// PADDING_W = {{ PADDING_W }} -// STRIDE_H = {{ STRIDE_H }} -// STRIDE_W = {{ STRIDE_W }} -// DILATION_H = {{ DILATION_H }} -// DILATION_W = {{ DILATION_W }} -// DATA_STYPE = {{ DATA_STYPE }} -// DATA_SIZE = {{ DATA_SIZE }} - -#map0 = affine_map<(d0, d1, d2, d3) -> (d0 * {{ O_W * BATCH * O_C }} + d1 * {{ BATCH * O_C }} + d2 * {{ O_C }} + d3)> // output (O_H, O_W, BATCH, O_C) -#map1 = affine_map<(d0, d1, d2, d3) -> (d0 * {{ (I_W + 2 * PADDING_W) * BATCH * I_C }} + d1 * {{ I_C * STRIDE_W }} + d2 * {{ I_C * (I_W + 2 * PADDING_W) }} + d3)> // input (I_H, BATCH, I_W, I_C) -#map2 = affine_map<(d0, d1, d2, d3) -> (d0 * {{ K_W * I_C * O_C }} + d1 * {{ I_C * O_C }} + d2 * {{ O_C }} + d3)> // weight (K_H, K_W, I_C, O_C) -#map_I_H = affine_map<(d0, d1) -> (d0 * {{ STRIDE_H }} + d1)> -#offset_w_map = affine_map<(d0, d1) -> (d0 * {{ kernel.get_spad_size_per_lane(1 * TILE_K, TILE_N) }} + d1 * {{ kernel.get_spad_size_per_lane(TILE_K, TILE_N) }})> -#offset_x_map = affine_map<(d0, d1) -> (d0 * {{ kernel.get_spad_size_per_lane(TILE_O_W * TILE_M, TILE_K) }} + d1 * {{ kernel.get_spad_size_per_lane(TILE_M, TILE_K) }})> -#offset_y_map = affine_map<(d0, d1) -> (d0 * {{ kernel.get_spad_size_per_lane(TILE_O_W * TILE_M, TILE_N) }} + d1 * {{ kernel.get_spad_size_per_lane(TILE_M, TILE_N) }})> - -memref.global @X_spad : memref<{{ TILE_I_H }}x{{ TILE_O_W }}x{{ TILE_M }}x{{ TILE_K }}xf32, 1> -memref.global @W_spad : memref<{{ TILE_K_H }}x{{ 1 }}x{{ TILE_K }}x{{ TILE_N }}xf32, 1> -memref.global @Y_spad : memref<{{ TILE_O_H }}x{{ 
TILE_O_W }}x{{ TILE_M }}x{{ TILE_N }}xf32, 1> -{{kernel.def_global_vars()}} - -func.func @{{ KERNEL_NAME }}{{kernel.def_conv_kernel(inputs=[X, W, BIAS], outputs=[Y], names_str="X, W, Bias, Y", padded_input_size=PADDED_INPUT_SIZE, input_reorder=input_reorder)}} { - %c_mvin = arith.constant 2 : index - %c_mvin2 = arith.constant 1 : index - %c_mvin3 = arith.constant 14 : index - %c_mvout = arith.constant 3 : index - %vstride = arith.constant 1 : index - %input_axis = arith.constant 3 : index - %weight_axis = arith.constant 2 : index - %input_buffer = memref.get_global @X_spad : memref<{{ TILE_I_H }}x{{ TILE_O_W }}x{{ TILE_M }}x{{ TILE_K }}xf32, 1> - %weight_buffer = memref.get_global @W_spad : memref<{{ TILE_K_H }}x{{ 1 }}x{{ TILE_K }}x{{ TILE_N }}xf32, 1> - %output_buffer = memref.get_global @Y_spad : memref<{{ TILE_O_H }}x{{ TILE_O_W }}x{{ TILE_M }}x{{ TILE_N }}xf32, 1> - %tag = memref.alloc() : memref<1xi32> - %tag0 = memref.alloc() : memref<1xi32> - %tag1 = memref.alloc() : memref<1xi32> - %tag2 = memref.alloc() : memref<1xi32> - %tag3 = memref.alloc() : memref<1xi32> - %v0 = arith.constant dense<0.0> : vector<{{ kernel.get_spad_size_per_lane(TILE_O_H * TILE_O_W * TILE_M, TILE_N) }}xf32> - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %c2 = arith.constant 2 : index - {{- kernel.def_local_vars() }} - - affine.for %o_h = 0 to {{ O_H }} step {{ TILE_O_H }} { - affine.for %o_w = 0 to {{ O_W }} step {{ TILE_O_W }} { - affine.for %tile_m = 0 to {{ BATCH }} step {{ TILE_M }} { - affine.for %tile_n = 0 to {{ O_C }} step {{ TILE_N }} { - %index0 = affine.apply #map0(%o_h, %o_w, %tile_m, %tile_n) - // Initialize output - {%- if BIAS %} - memref.dma_start %Bias[%tile_n], %output_buffer[%c0, %c0, %c0, %c0], %c_mvin, %tag0[%c0], %c0, %vstride - : memref<{{ O_C }}xf32>, memref<{{ TILE_O_H }}x{{ TILE_O_W }}x{{ TILE_M }}x{{ TILE_N }}xf32, 1>, memref<1xi32> { subtile_size=[{{ TILE_O_H }}, {{ TILE_O_W }}, {{ SUB_TILE_M }}, {{ SUB_TILE_N }}], async=1, 
sram_stride=[{{ TILE_O_W * TILE_M * TILE_N }}, {{ TILE_M * TILE_N }}, 1, {{ TILE_M }}]} - {%- else %} - affine.vector_store %v0, %output_buffer[%c0, %c0, %c0, %c0] : memref<{{ TILE_O_H }}x{{ TILE_O_W }}x{{ TILE_M }}x{{ TILE_N }}xf32, 1>, vector<{{ kernel.get_spad_size_per_lane(TILE_O_H * TILE_O_W * TILE_M, TILE_N) }}xf32> - {%- endif %} - affine.for %k_h = 0 to {{ K_H }} step {{ TILE_K_H }} { - affine.for %tile_k = 0 to {{ I_C * K_W }} step {{ TILE_K }} { - %index_i_h = affine.apply #map_I_H(%o_h, %k_h) - %index1 = affine.apply #map1(%index_i_h, %o_w, %tile_m, %tile_k) // input index - %index2 = affine.apply #map2(%k_h, %c0, %tile_k, %tile_n) // weight index - // Load input matrix - memref.dma_start %X[%index1], %input_buffer[%c0, %c0, %c0, %c0], %c_mvin, %tag1[%c0], %input_axis, %vstride - : memref<{{ BATCH * I_C * (I_H + 2 * PADDING_H) * (I_W + 2 * PADDING_W) }}xf32>, memref<{{ TILE_I_H }}x{{ TILE_O_W }}x{{ TILE_M }}x{{ TILE_K }}xf32, 1>, memref<1xi32> { subtile_size=[{{ SUB_TILE_I_H }}, {{ SUB_TILE_I_W }}, {{ SUB_TILE_M }}, {{ SUB_TILE_K }}], async=1, sram_stride=[{{ TILE_O_W * TILE_M * TILE_K }}, {{ TILE_M * TILE_K }}, 1, {{ TILE_M }}]} - // Load kernel matrix - memref.dma_start %W[%index2], %weight_buffer[%c0, %c0, %c0, %c0], %c_mvin, %tag2[%c0], %input_axis, %vstride - : memref<{{ O_C * I_C * K_H * K_W }}xf32>, memref<{{ TILE_K_H }}x{{ 1 }}x{{ TILE_K }}x{{ TILE_N }}xf32, 1>, memref<1xi32> { subtile_size=[{{ SUB_TILE_K_H }}, {{ SUB_TILE_K_W }}, {{ SUB_TILE_K }}, {{ SUB_TILE_N }}], async=1, sram_stride=[{{ TILE_K_W * TILE_K * TILE_N }}, {{ TILE_K * TILE_N }}, 1, {{ TILE_K }}]} - affine.for %tile_k_h = 0 to {{ TILE_K_H }} { // loop order should be fixed for timing simulation. Do not change this order. 
- affine.for %tile_k_w = 0 to 1 { - %offset_w = affine.apply #offset_w_map(%tile_k_h, %tile_k_w) - %W_buffer = memref.reinterpret_cast %weight_buffer to offset: [%offset_w], sizes: [{{ TILE_K }}, {{ TILE_N }}], strides: [{{ TILE_N }}, 1] : memref<{{ TILE_K_H }}x{{ 1 }}x{{ TILE_K }}x{{ TILE_N }}xf32, 1> to memref<{{ TILE_K }}x{{ TILE_N }}xf32, strided<[{{ TILE_N }}, 1], offset: ?>, 1> - affine.for %tile_o_h = 0 to {{ TILE_O_H }} { - affine.for %tile_o_w = 0 to {{ TILE_O_W }} { - %tile_i_h = affine.apply #map_I_H(%tile_o_h, %tile_k_h) - %offset_x = affine.apply #offset_x_map(%tile_i_h, %tile_o_w) - %offset_y = affine.apply #offset_y_map(%tile_o_h, %tile_o_w) - %X_buffer = memref.reinterpret_cast %input_buffer to offset: [%offset_x], sizes: [{{ TILE_M }}, {{ TILE_K }}], strides: [{{ TILE_K }}, 1] : memref<{{ TILE_I_H }}x{{ TILE_O_W }}x{{ TILE_M }}x{{ TILE_K }}xf32, 1> to memref<{{ TILE_M }}x{{ TILE_K }}xf32, strided<[{{ TILE_K }}, 1], offset: ?>, 1> - %Y_buffer = memref.reinterpret_cast %output_buffer to offset: [%offset_y], sizes: [{{ TILE_M }}, {{ TILE_N }}], strides: [{{ TILE_N }}, 1] : memref<{{ TILE_O_H }}x{{ TILE_O_W }}x{{ TILE_M }}x{{ TILE_N }}xf32, 1> to memref<{{ TILE_M }}x{{ TILE_N }}xf32, strided<[{{ TILE_N }}, 1], offset: ?>, 1> - linalg.matmul ins(%X_buffer, %W_buffer : memref<{{ TILE_M }}x{{ TILE_K }}xf32, strided<[{{ TILE_K }}, 1], offset: ?>, 1>, memref<{{ TILE_K }}x{{ TILE_N }}xf32, strided<[{{ TILE_N }}, 1], offset: ?>, 1>) - outs(%Y_buffer : memref<{{ TILE_M }}x{{ TILE_N }}xf32, strided<[{{ TILE_N }}, 1], offset: ?>, 1>) - } { inner_loop=true } - } { inner_loop=true } - } { inner_loop=true } - } { inner_loop=true } - } { accumulation_loop=true } - } { accumulation_loop=true } - // Store output matrix - {{kernel.store_output(indent_size=10)}} - } { outer_loop=true } - } { outer_loop=true } - } { outer_loop=true } - } { outer_loop=true } - return -} -""" - -SINGLE_BATCH_CONV_TEMPLATE = r""" -// Single Batch Conv2D kernel -// BATCH = {{ BATCH }} -// 
I_C = {{ I_C }} -// I_H = {{ I_H }} -// I_W = {{ I_W }} -// O_C = {{ O_C }} -// K_H = {{ K_H }} -// K_W = {{ K_W }} -// O_H = {{ O_H }} -// O_W = {{ O_W }} -// TILE_M = {{ TILE_M }} -// TILE_N = {{ TILE_N }} -// TILE_K = {{ TILE_K }} -// TILE_I_H={{ TILE_I_H }}, -// TILE_I_W={{ TILE_I_W }}, -// TILE_O_H={{ TILE_O_H }}, -// TILE_O_W={{ TILE_O_W }}, -// TILE_K_H={{ TILE_K_H }}, -// TILE_K_W={{ TILE_K_W }}, -// SUB_TILE_M={{ SUB_TILE_M }}, -// SUB_TILE_N={{ SUB_TILE_N }}, -// SUB_TILE_I_W={{ SUB_TILE_I_W }}, -// SUB_TILE_K_H={{ SUB_TILE_K_H }}, -// SUB_TILE_K_W={{ SUB_TILE_K_W }}, -// PADDING_H = {{ PADDING_H }} -// PADDING_W = {{ PADDING_W }} -// STRIDE_H = {{ STRIDE_H }} -// STRIDE_W = {{ STRIDE_W }} -// DILATION_H = {{ DILATION_H }} -// DILATION_W = {{ DILATION_W }} -// DATA_STYPE = {{ DATA_STYPE }} -// DATA_SIZE = {{ DATA_SIZE }} - -#map0 = affine_map<(d0, d1, d2, d3) -> (d0 * {{ O_W * O_H * O_C }} + d1 * {{ O_W * O_C }} + d2 * {{ O_C }} + d3)> // output (BATCH, O_H, O_W, O_C) -#map1 = affine_map<(d0, d1, d2, d3) -> (d0 * {{ (I_W + 2 * PADDING_W) * (I_H + 2 * PADDING_W) * I_C }} + d1 * {{ (I_W + 2 * PADDING_W) * I_C }} + d2 * {{ I_C }} + d3)> // input (BATCH, I_H, I_W, I_C) Stride should be changed if kernel stride > 1 -#map2 = affine_map<(d0, d1, d2, d3) -> (d0 * {{ K_W * I_C * O_C }} + d1 * {{ I_C * O_C }} + d2 * {{ O_C }} + d3)> // weight (K_H, K_W, I_C, O_C) -#map_I_H = affine_map<(d0, d1) -> (d0 * {{ STRIDE_H }} + d1)> -#map_I_W = affine_map<(d0, d1) -> (d0 * {{ STRIDE_W }} + d1)> -#offset_w_map = affine_map<(d0, d1) -> (d0 * {{ kernel.get_spad_size_per_lane(TILE_K_W * TILE_K, TILE_N) }} + d1 * {{ kernel.get_spad_size_per_lane(TILE_K, TILE_N) }})> -#offset_x_map = affine_map<(d0, d1) -> (d0 * {{ kernel.get_spad_size_per_lane(TILE_I_W, TILE_K) }} + d1)> -#offset_y_map = affine_map<(d0, d1) -> (d0 * {{ kernel.get_spad_size_per_lane(TILE_M, TILE_N) }} + d1 * {{ kernel.get_spad_size_per_lane(TILE_M, TILE_N) }})> -memref.global @X_spad : memref<{{ 1 }}x{{ TILE_I_H 
}}x{{ TILE_I_W }}x{{ TILE_K }}xf32, 1> -memref.global @W_spad : memref<{{ TILE_K_H }}x{{ TILE_K_W }}x{{ TILE_K }}x{{ TILE_N }}xf32, 1> -memref.global @Y_spad : memref<{{ 1 }}x{{ TILE_O_H }}x{{ TILE_M }}x{{ TILE_N }}xf32, 1> -{{kernel.def_global_vars()}} - -func.func @{{ KERNEL_NAME }}{{kernel.def_conv_kernel(inputs=[X, W, BIAS], outputs=[Y], names_str="X, W, Bias, Y", padded_input_size=PADDED_INPUT_SIZE, input_reorder=input_reorder)}} { - %c_mvin = arith.constant 2 : index - %c_mvin2 = arith.constant 1 : index - %c_mvin3 = arith.constant 14 : index - %c_mvout = arith.constant 3 : index - %vstride = arith.constant 1 : index - %input_axis = arith.constant 3 : index - %weight_axis = arith.constant 2 : index - %input_buffer = memref.get_global @X_spad : memref<{{ 1 }}x{{ TILE_I_H }}x{{ TILE_I_W }}x{{ TILE_K }}xf32, 1> - %weight_buffer = memref.get_global @W_spad : memref<{{ TILE_K_H }}x{{ TILE_K_W }}x{{ TILE_K }}x{{ TILE_N }}xf32, 1> - %output_buffer = memref.get_global @Y_spad : memref<{{ 1 }}x{{ TILE_O_H }}x{{ TILE_M }}x{{ TILE_N }}xf32, 1> - %tag = memref.alloc() : memref<1xi32> - %tag0 = memref.alloc() : memref<1xi32> - %tag1 = memref.alloc() : memref<1xi32> - %tag2 = memref.alloc() : memref<1xi32> - %tag3 = memref.alloc() : memref<1xi32> - %v0 = arith.constant dense<0.0> : vector<{{ kernel.get_spad_size_per_lane(TILE_O_H * TILE_M, TILE_N) }}xf32> - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %c2 = arith.constant 2 : index - {{- kernel.def_local_vars() }} - affine.for %o_w = 0 to {{ O_W }} step {{ TILE_O_W }} { - affine.for %o_h = 0 to {{ O_H }} step {{ TILE_O_H }} { - affine.for %tile_m = 0 to {{ O_W }} step {{ TILE_M }} { - affine.for %tile_n = 0 to {{ O_C }} step {{ TILE_N }} { - %index0 = affine.apply #map0(%c0, %o_h, %tile_m, %tile_n) - // Initialize output - {%- if BIAS %} - memref.dma_start %Bias[%tile_n], %output_buffer[%c0, %c0, %c0, %c0], %c_mvin, %tag0[%c0], %c0, %vstride // not implemented yet - : memref<{{ O_C }}xf32>, memref<{{ 1 
}}x{{ TILE_O_H }}x{{ TILE_M }}x{{ TILE_N }}xf32, 1>, memref<1xi32> { subtile_size=[{{ 1 }}, {{ TILE_O_H }}, {{ SUB_TILE_M }}, {{ SUB_TILE_N }}], async=1, sram_stride=[{{ TILE_O_H * TILE_M * TILE_N }}, {{ TILE_M * TILE_N }}, 1, {{ TILE_M }}]} - {%- else %} - affine.vector_store %v0, %output_buffer[%c0, %c0, %c0, %c0] : memref<{{ 1 }}x{{ TILE_O_H }}x{{ TILE_M }}x{{ TILE_N }}xf32, 1>, vector<{{ kernel.get_spad_size_per_lane(TILE_O_H * TILE_M, TILE_N) }}xf32> - {%- endif %} - affine.for %k_h = 0 to {{ K_H }} step {{ TILE_K_H }} { - affine.for %k_w = 0 to {{ K_W }} step {{ TILE_K_W }} { - affine.for %tile_k = 0 to {{ I_C }} step {{ TILE_K }} { - %index_i_h = affine.apply #map_I_H(%o_h, %k_h) - %index_i_w = affine.apply #map_I_W(%o_w, %k_w) - %index1 = affine.apply #map1(%c0, %index_i_h, %index_i_w, %tile_k) // input index - %index2 = affine.apply #map2(%k_h, %k_w, %tile_k, %tile_n) // weight index - // Load input matrix - memref.dma_start %X[%index1], %input_buffer[%c0, %c0, %c0, %c0], %c_mvin, %tag1[%c0], %input_axis, %vstride - : memref<{{ BATCH * I_C * (I_H + 2 * PADDING_H) * (I_W + 2 * PADDING_W) }}xf32>, memref<{{ 1 }}x{{ TILE_I_H }}x{{ TILE_I_W }}x{{ TILE_K }}xf32, 1>, memref<1xi32> { subtile_size=[{{ 1 }}, {{ SUB_TILE_I_H }}, {{ SUB_TILE_M }}, {{ SUB_TILE_K }}], async=1, sram_stride=[{{ TILE_I_H * TILE_I_W * TILE_K }}, {{ TILE_I_W * TILE_K }}, 1, {{ TILE_I_W }}]} - // Load kernel matrix - memref.dma_start %W[%index2], %weight_buffer[%c0, %c0, %c0, %c0], %c_mvin, %tag2[%c0], %input_axis, %vstride - : memref<{{ O_C * I_C * K_H * K_W }}xf32>, memref<{{ TILE_K_H }}x{{ TILE_K_W }}x{{ TILE_K }}x{{ TILE_N }}xf32, 1>, memref<1xi32> { subtile_size=[{{ SUB_TILE_K_H }}, {{ SUB_TILE_K_W }}, {{ SUB_TILE_K }}, {{ SUB_TILE_N }}], async=1, sram_stride=[{{ TILE_K_W * TILE_K * TILE_N }}, {{ TILE_K * TILE_N }}, 1, {{ TILE_K }}]} - affine.for %tile_k_h = 0 to {{ TILE_K_H }} { // loop order should be fixed for timing simulation. Do not change this order. 
- affine.for %tile_k_w = 0 to {{ TILE_K_W }} { - %offset_w = affine.apply #offset_w_map(%tile_k_h, %tile_k_w) - %W_buffer = memref.reinterpret_cast %weight_buffer to offset: [%offset_w], sizes: [{{ TILE_K }}, {{ TILE_N }}], strides: [{{ TILE_N }}, 1] : memref<{{ TILE_K_H }}x{{ TILE_K_W }}x{{ TILE_K }}x{{ TILE_N }}xf32, 1> to memref<{{ TILE_K }}x{{ TILE_N }}xf32, strided<[{{ TILE_N }}, 1], offset: ?>, 1> - affine.for %tile_o_h = 0 to {{ TILE_O_H }} { - affine.for %tile_o_w = 0 to {{ 1 }} { // TILE_O_W - %tile_i_h = affine.apply #map_I_H(%tile_o_h, %tile_k_h) - %offset_x = affine.apply #offset_x_map(%tile_i_h, %tile_k_w) - %offset_y = affine.apply #offset_y_map(%tile_o_h, %tile_o_w) - %X_buffer = memref.reinterpret_cast %input_buffer to offset: [%offset_x], sizes: [{{ TILE_M }}, {{ TILE_K }}], strides: [{{ TILE_K }}, 1] : memref<{{ 1 }}x{{ TILE_I_H }}x{{ TILE_I_W }}x{{ TILE_K }}xf32, 1> to memref<{{ TILE_M }}x{{ TILE_K }}xf32, strided<[{{ TILE_K }}, 1], offset: ?>, 1> - %Y_buffer = memref.reinterpret_cast %output_buffer to offset: [%offset_y], sizes: [{{ TILE_M }}, {{ TILE_N }}], strides: [{{ TILE_N }}, 1] : memref<{{ 1 }}x{{ TILE_O_H }}x{{ TILE_M }}x{{ TILE_N }}xf32, 1> to memref<{{ TILE_M }}x{{ TILE_N }}xf32, strided<[{{ TILE_N }}, 1], offset: ?>, 1> - linalg.matmul ins(%X_buffer, %W_buffer : memref<{{ TILE_M }}x{{ TILE_K }}xf32, strided<[{{ TILE_K }}, 1], offset: ?>, 1>, memref<{{ TILE_K }}x{{ TILE_N }}xf32, strided<[{{ TILE_N }}, 1], offset: ?>, 1>) - outs(%Y_buffer : memref<{{ TILE_M }}x{{ TILE_N }}xf32, strided<[{{ TILE_N }}, 1], offset: ?>, 1>) - } { inner_loop=true } - } { inner_loop=true } - } { inner_loop=true } - } { inner_loop=true } - } { accumulation_loop=true } - } { accumulation_loop=true } - } { accumulation_loop=true } - // Store output matrix - {{kernel.store_output(indent_size=8)}} - } { outer_loop=true } - } { outer_loop=true } - } { outer_loop=true } - } { outer_loop=true } - return -} -""" - -SINGLE_BATCH_CONV_STRIDE_TEMPLATE = r""" -// Single 
Batch Conv2D (Stride != 1) kernel -// BATCH = {{ BATCH }} -// I_C = {{ I_C }} -// I_H = {{ I_H }} -// I_W = {{ I_W }} -// O_C = {{ O_C }} -// K_H = {{ K_H }} -// K_W = {{ K_W }} -// O_H = {{ O_H }} -// O_W = {{ O_W }} -// TILE_M = {{ TILE_M }} -// TILE_N = {{ TILE_N }} -// TILE_K = {{ TILE_K }} -// TILE_I_H={{ TILE_I_H }}, -// TILE_I_W={{ TILE_I_W }}, -// TILE_O_H={{ TILE_O_H }}, -// TILE_O_W={{ TILE_O_W }}, -// TILE_K_H={{ TILE_K_H }}, -// TILE_K_W={{ TILE_K_W }}, -// SUB_TILE_M={{ SUB_TILE_M }}, -// SUB_TILE_N={{ SUB_TILE_N }}, -// SUB_TILE_I_W={{ SUB_TILE_I_W }}, -// SUB_TILE_K_H={{ SUB_TILE_K_H }}, -// SUB_TILE_K_W={{ SUB_TILE_K_W }}, -// PADDING_H = {{ PADDING_H }} -// PADDING_W = {{ PADDING_W }} -// STRIDE_H = {{ STRIDE_H }} -// STRIDE_W = {{ STRIDE_W }} -// DILATION_H = {{ DILATION_H }} -// DILATION_W = {{ DILATION_W }} -// DATA_STYPE = {{ DATA_STYPE }} -// DATA_SIZE = {{ DATA_SIZE }} - -#map0 = affine_map<(d0, d1, d2, d3) -> (d0 * {{ O_W * O_H * O_C }} + d1 * {{ O_W * O_C }} + d2 * {{ O_C }} + d3)> // output (BATCH, O_H, O_W, O_C) -#map1 = affine_map<(d0, d1, d2, d3) -> (d0 * {{ (I_W + 2 * PADDING_W) * I_C }} + d1 * {{ I_C }} + d2 * {{ I_C * STRIDE_W }} + d3)> // input (I_H, (k_w), I_W, I_C) // duplicate for k_w -#map2 = affine_map<(d0, d1, d2, d3) -> (d0 * {{ K_W * I_C * O_C }} + d1 * {{ I_C * O_C }} + d2 * {{ O_C }} + d3)> // weight (K_H, K_W, I_C, O_C) -#map_I_H = affine_map<(d0, d1) -> (d0 * {{ STRIDE_H }} + d1)> -#offset_w_map = affine_map<(d0, d1) -> (d0 * {{ kernel.get_spad_size_per_lane(TILE_K_W * TILE_K, TILE_N) }} + d1 * {{ kernel.get_spad_size_per_lane(TILE_K, TILE_N) }})> -#offset_x_map = affine_map<(d0, d1) -> (d0 * {{ kernel.get_spad_size_per_lane(TILE_M * TILE_K_W, TILE_K) }} + d1 * {{ kernel.get_spad_size_per_lane(TILE_M, TILE_K) }})> -#offset_y_map = affine_map<(d0, d1) -> (d0 * {{ kernel.get_spad_size_per_lane(TILE_M, TILE_N) }} + d1 * {{ kernel.get_spad_size_per_lane(TILE_M, TILE_N) }})> - -memref.global @X_spad : memref<{{ TILE_I_H }}x{{ 
TILE_K_W }}x{{ TILE_M }}x{{ TILE_K }}xf32, 1> -memref.global @W_spad : memref<{{ TILE_K_H }}x{{ TILE_K_W }}x{{ TILE_K }}x{{ TILE_N }}xf32, 1> -memref.global @Y_spad : memref<{{ 1 }}x{{ TILE_O_H }}x{{ TILE_M }}x{{ TILE_N }}xf32, 1> -{{kernel.def_global_vars()}} - -func.func @{{ KERNEL_NAME }}{{kernel.def_conv_kernel(inputs=[X, W, BIAS], outputs=[Y], names_str="X, W, Bias, Y", padded_input_size=PADDED_INPUT_SIZE, input_reorder=input_reorder)}} { - %c_mvin = arith.constant 2 : index - %c_mvin2 = arith.constant 1 : index - %c_mvin3 = arith.constant 14 : index - %c_mvout = arith.constant 3 : index - %vstride = arith.constant 1 : index - %input_axis = arith.constant 3 : index - %weight_axis = arith.constant 2 : index - %input_buffer = memref.get_global @X_spad : memref<{{ TILE_I_H }}x{{ TILE_K_W }}x{{ TILE_M }}x{{ TILE_K }}xf32, 1> - %weight_buffer = memref.get_global @W_spad : memref<{{ TILE_K_H }}x{{ TILE_K_W }}x{{ TILE_K }}x{{ TILE_N }}xf32, 1> - %output_buffer = memref.get_global @Y_spad : memref<{{ 1 }}x{{ TILE_O_H }}x{{ TILE_M }}x{{ TILE_N }}xf32, 1> - %tag = memref.alloc() : memref<1xi32> - %tag0 = memref.alloc() : memref<1xi32> - %tag1 = memref.alloc() : memref<1xi32> - %tag2 = memref.alloc() : memref<1xi32> - %tag3 = memref.alloc() : memref<1xi32> - %v0 = arith.constant dense<0.0> : vector<{{ kernel.get_spad_size_per_lane(TILE_O_H * TILE_M, TILE_N) }}xf32> - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %c2 = arith.constant 2 : index - {{- kernel.def_local_vars() }} - - affine.for %o_h = 0 to {{ O_H }} step {{ TILE_O_H }} { - affine.for %tile_m = 0 to {{ O_W }} step {{ TILE_M }} { - affine.for %tile_n = 0 to {{ O_C }} step {{ TILE_N }} { - %index0 = affine.apply #map0(%c0, %o_h, %tile_m, %tile_n) - // Initialize output - {%- if BIAS %} - memref.dma_start %Bias[%tile_n], %output_buffer[%c0, %c0, %c0, %c0], %c_mvin, %tag0[%c0], %c0, %vstride // not implemented yet - : memref<{{ O_C }}xf32>, memref<{{ 1 }}x{{ TILE_O_H }}x{{ TILE_M }}x{{ TILE_N 
}}xf32, 1>, memref<1xi32> { subtile_size=[{{ 1 }}, {{ TILE_O_H }}, {{ SUB_TILE_M }}, {{ SUB_TILE_N }}], async=1, sram_stride=[{{ TILE_O_H * TILE_M * TILE_N }}, {{ TILE_M * TILE_N }}, 1, {{ TILE_M }}]} - {%- else %} - affine.vector_store %v0, %output_buffer[%c0, %c0, %c0, %c0] : memref<{{ 1 }}x{{ TILE_O_H }}x{{ TILE_M }}x{{ TILE_N }}xf32, 1>, vector<{{ kernel.get_spad_size_per_lane(TILE_O_H * TILE_M, TILE_N) }}xf32> - {%- endif %} - affine.for %k_h = 0 to {{ K_H }} step {{ TILE_K_H }} { - affine.for %k_w = 0 to {{ K_W }} step {{ TILE_K_W }} { - affine.for %tile_k = 0 to {{ I_C }} step {{ TILE_K }} { - %index_i_h = affine.apply #map_I_H(%o_h, %k_h) - %index1 = affine.apply #map1(%index_i_h, %k_w, %tile_m, %tile_k) // input index - %index2 = affine.apply #map2(%k_h, %k_w, %tile_k, %tile_n) // weight index - // Load input matrix - memref.dma_start %X[%index1], %input_buffer[%c0, %c0, %c0, %c0], %c_mvin, %tag1[%c0], %input_axis, %vstride - : memref<{{ BATCH * I_C * (I_H + 2 * PADDING_H) * (I_W + 2 * PADDING_W) }}xf32>, memref<{{ TILE_I_H }}x{{ TILE_K_W }}x{{ TILE_M }}x{{ TILE_K }}xf32, 1>, memref<1xi32> { subtile_size=[{{ SUB_TILE_I_H }}, {{ SUB_TILE_K_W }}, {{ SUB_TILE_M }}, {{ SUB_TILE_K }}], async=1, sram_stride=[{{ TILE_K_W * TILE_M * TILE_K }}, {{ TILE_M * TILE_K }}, 1, {{ TILE_M }}]} - // Load kernel matrix - memref.dma_start %W[%index2], %weight_buffer[%c0, %c0, %c0, %c0], %c_mvin, %tag2[%c0], %input_axis, %vstride - : memref<{{ O_C * I_C * K_H * K_W }}xf32>, memref<{{ TILE_K_H }}x{{ TILE_K_W }}x{{ TILE_K }}x{{ TILE_N }}xf32, 1>, memref<1xi32> { subtile_size=[{{ SUB_TILE_K_H }}, {{ SUB_TILE_K_W }}, {{ SUB_TILE_K }}, {{ SUB_TILE_N }}], async=1, sram_stride=[{{ TILE_K_W * TILE_K * TILE_N }}, {{ TILE_K * TILE_N }}, 1, {{ TILE_K }}]} - affine.for %tile_k_h = 0 to {{ TILE_K_H }} { // loop order should be fixed for timing simulation. Do not change this order. 
- affine.for %tile_k_w = 0 to {{ TILE_K_W }} { - %offset_w = affine.apply #offset_w_map(%tile_k_h, %tile_k_w) - %W_buffer = memref.reinterpret_cast %weight_buffer to offset: [%offset_w], sizes: [{{ TILE_K }}, {{ TILE_N }}], strides: [{{ TILE_N }}, 1] : memref<{{ TILE_K_H }}x{{ TILE_K_W }}x{{ TILE_K }}x{{ TILE_N }}xf32, 1> to memref<{{ TILE_K }}x{{ TILE_N }}xf32, strided<[{{ TILE_N }}, 1], offset: ?>, 1> - affine.for %tile_o_h = 0 to {{ TILE_O_H }} { - affine.for %tile_o_w = 0 to {{ 1 }} { // TILE_O_W - %tile_i_h = affine.apply #map_I_H(%tile_o_h, %tile_k_h) - %offset_x = affine.apply #offset_x_map(%tile_i_h, %tile_k_w) - %offset_y = affine.apply #offset_y_map(%tile_o_h, %tile_o_w) - %X_buffer = memref.reinterpret_cast %input_buffer to offset: [%offset_x], sizes: [{{ TILE_M }}, {{ TILE_K }}], strides: [{{ TILE_K }}, 1] : memref<{{ TILE_I_H }}x{{ TILE_K_W }}x{{ TILE_M }}x{{ TILE_K }}xf32, 1> to memref<{{ TILE_M }}x{{ TILE_K }}xf32, strided<[{{ TILE_K }}, 1], offset: ?>, 1> - %Y_buffer = memref.reinterpret_cast %output_buffer to offset: [%offset_y], sizes: [{{ TILE_M }}, {{ TILE_N }}], strides: [{{ TILE_N }}, 1] : memref<{{ 1 }}x{{ TILE_O_H }}x{{ TILE_M }}x{{ TILE_N }}xf32, 1> to memref<{{ TILE_M }}x{{ TILE_N }}xf32, strided<[{{ TILE_N }}, 1], offset: ?>, 1> - linalg.matmul ins(%X_buffer, %W_buffer : memref<{{ TILE_M }}x{{ TILE_K }}xf32, strided<[{{ TILE_K }}, 1], offset: ?>, 1>, memref<{{ TILE_K }}x{{ TILE_N }}xf32, strided<[{{ TILE_N }}, 1], offset: ?>, 1>) - outs(%Y_buffer : memref<{{ TILE_M }}x{{ TILE_N }}xf32, strided<[{{ TILE_N }}, 1], offset: ?>, 1>) - } { inner_loop=true } - } { inner_loop=true } - } { inner_loop=true } - } { inner_loop=true } - } { accumulation_loop=true } - } { accumulation_loop=true } - } { accumulation_loop=true } - // Store output matrix - {{kernel.store_output(indent_size=8)}} - } { outer_loop=true } - } { outer_loop=true } - } { outer_loop=true } + } { outer_loop=true, subtile_loop="n" } + } { outer_loop=true, subtile_loop="m" } return 
} """ @@ -514,40 +118,14 @@ def {{ FUNC_NAME }}{{kernel.def_wrapper()}}: X_padding = torch.zeros(padded_shape, device=X.device) X_padding[:, :, {{ PADDING_H }}:X.shape[2] + {{ PADDING_H }}, {{ PADDING_W }}:X.shape[3] + {{ PADDING_W }}] = X - # Holding original output tensor - {%- for buf, name in kernel.get_conv_outputs().items() %} - {{ name }}_t = {{ name }} - {%- endfor %} - # Tanspose inputs {%- for buf, name in kernel.get_conv_inputs().items() %} {%- if name == "X" %} - {%- if MULTI_TILE %} - {{ name }} = {{ name }}_padding.permute(2, 0, 3, 1).contiguous() # (BATCH, I_C, I_H, I_W) -> (I_H, BATCH, I_W, I_C) - {%- elif SINGLE_BATCH %} - {{ name }} = {{ name }}_padding.permute(0, 2, 3, 1).contiguous() # (BATCH, I_C, I_H, I_W) -> (BATCH, I_H, I_W, I_C) - {%- else %} {{ name }} = {{ name }}_padding.permute(2, 3, 0, 1).contiguous() # (BATCH, I_C, I_H, I_W) -> (I_H, I_W, BATCH, I_C) - {%- endif %} {%- elif name == "W" %} {{ name }} = {{ name }}.permute(2, 3, 1, 0).contiguous() # (O_C, I_C, K_H, K_W) -> (K_H, K_W, I_C, O_C) {%- elif name == "Bias" %} {{ name }} = {{ name }} - {%- else %} - {%- if SINGLE_BATCH %} - {{ name }} = {{ name }}.permute(0, 2, 3, 1).contiguous() if {{ name }}.dim() == 4 else {{ name }} # (BATCH, O_C, O_H, O_W) -> (BATCH, O_H, O_W, O_C) - {%- else %} - {{ name }} = {{ name }}.permute(2, 3, 0, 1).contiguous() if {{ name }}.dim() == 4 else {{ name }} # (BATCH, O_C, O_H, O_W) -> (O_H, O_W, BATCH, O_C) - {%- endif %} - {%- endif %} - {%- endfor %} - - # Transpose outputs - {%- for buf, name in kernel.get_conv_outputs().items() %} - {%- if SINGLE_BATCH %} - {{ name }} = {{ name }}.permute(0, 2, 3, 1).contiguous() # (BATCH, O_C, O_H, O_W) -> (BATCH, O_H, O_W, O_C) - {%- else %} - {{ name }} = {{ name }}.permute(2, 3, 0, 1).contiguous() # (BATCH, O_C, O_H, O_W) -> (O_H, O_W, BATCH, O_C) {%- endif %} {%- endfor %} @@ -556,15 +134,6 @@ def {{ FUNC_NAME }}{{kernel.def_wrapper()}}: {%- if BACKENDSIM_EAGER_MODE %} yield ({{KERNEL_NAME}}, ) {%- endif %} - - 
# Transpose back outputs - {%- for buf, name in kernel.get_conv_outputs().items() %} - {%- if SINGLE_BATCH %} - {{ name }}_t.copy_({{ name }}.permute(0, 3, 1, 2).contiguous()) # (BATCH, O_H, O_W, O_C) -> (BATCH, O_C, O_H, O_W) - {%- else %} - {{ name }}_t.copy_({{ name }}.permute(2, 3, 0, 1).contiguous()) # (O_H, O_W, BATCH, O_C) -> (BATCH, O_C, O_H, O_W) - {%- endif %} - {%- endfor %} """ class MLIRConvTemplate(MLIRTemplate): @@ -581,21 +150,6 @@ def __init__(self, input_nodes, layout, input_reorder=None, **kwargs): + "_" + "_".join([str(i) for i in self.dilation]) self.kernel_args = ['X', 'W', 'Bias', 'Y'] - def is_transposed(self, node): - if isinstance(node, ReinterpretView): - if node.layout.stride != node.data.layout.stride: - if node.layout.stride[-2] == node.data.layout.stride[-1] and node.layout.stride[-1] == node.data.layout.stride[-2]: - return True - else: - raise NotImplementedError("If the stride is not equal to the original stride, it should have been transposed.") - return False - - def is_multi_tile(self, I_C): - return I_C < (self.kernel.vector_lane // 8) # 8 is hard-coded for now. This should be changed to a better heuristic. 
- - def is_single_batch(self, BATCH): - return BATCH == 1 - def get_padded_input_size(self, X): input_padded = list(X.layout.size) input_padded[2] += 2 * self.padding[0] @@ -607,6 +161,7 @@ def render(self, template_buffer_node = None, epilogue_nodes: Optional[List[IRNode]] = None, **kwargs): + # Extract input arguments info if template_buffer_node is not None: self.output_node = template_buffer_node self.kernel = kernel @@ -617,93 +172,68 @@ def render(self, Bias = None if len(self.input_nodes) == 2 else self.input_nodes[2] if epilogue_nodes is not None: - extra_node_rw = { - item.name for epilogue_node in epilogue_nodes - for item in epilogue_node.read_writes.reads | epilogue_node.read_writes.writes - if item.name != Y.name - } + extra_node_rw = { + item.name for epilogue_node in epilogue_nodes + for item in epilogue_node.read_writes.reads | epilogue_node.read_writes.writes + if item.name != Y.name + } n_extra_node = len(extra_node_rw) if epilogue_nodes is not None else 0 - BATCH = X.layout.size[0] - I_C = X.layout.size[1] - O_C = W.layout.size[0] - K_H = W.layout.size[2] - K_W = W.layout.size[3] + BATCH, I_C, I_H, I_W = X.layout.size + O_C, _, K_H, K_W = W.layout.size O_H = Y.layout.size[2] if template_buffer_node is None else template_buffer_node.layout.size[2] O_W = Y.layout.size[3] if template_buffer_node is None else template_buffer_node.layout.size[3] + PADDING_H=self.padding[0] + PADDING_W=self.padding[1] + STRIDE_H=self.stride[0] + STRIDE_W=self.stride[1] - TILE_K_H, TILE_K_W, TILE_O_H, TILE_O_W, TILE_M, TILE_N, TILE_K = kernel.conv_combination_mapping(BATCH, O_C, I_C, K_H, K_W, O_H, O_W, self.stride, self.dilation, n_extra_node) - SUB_TILE_M = TILE_M if TILE_M < kernel.vector_lane else kernel.vector_lane - SUB_TILE_N = TILE_N if TILE_N < kernel.vector_lane else kernel.vector_lane - SUB_TILE_K = TILE_K - TILE_I_H = 1 + (TILE_O_H - 1) * self.stride[0] + (TILE_K_H - 1) * self.dilation[0] - TILE_I_W = 1 + (TILE_O_W - 1) * self.stride[1] + (TILE_K_W - 1) * 
self.dilation[1] - SUB_TILE_I_H, SUB_TILE_I_W, SUB_TILE_K_H, SUB_TILE_K_W = 1, 1, 1, 1 - x_spad_size_per_lane = kernel.get_spad_size_per_lane(TILE_I_W * TILE_I_H * TILE_M, TILE_K) - w_spad_size_per_lane = kernel.get_spad_size_per_lane(TILE_K_W * TILE_K_H * TILE_K, TILE_N) - y_spad_size_per_lane = kernel.get_spad_size_per_lane(TILE_O_H * TILE_O_W * TILE_M, TILE_N) - x_spad_size = TILE_I_W * TILE_I_H * TILE_M * TILE_K - w_spad_size = TILE_K_W * TILE_K_H * TILE_K * TILE_N - y_spad_size = TILE_O_H * TILE_O_W * TILE_M * TILE_N + # Select tile size and template conv_template = CONV_TEMPLATE - TOG_latency = BATCH if TILE_M > BATCH else TILE_M - if self.is_single_batch(BATCH) and self.stride[0] != 1: - conv_template = SINGLE_BATCH_CONV_STRIDE_TEMPLATE - TILE_K_H, TILE_K_W, TILE_O_H, TILE_O_W, TILE_M, TILE_N, TILE_K = kernel.conv_single_batch_mapping(BATCH, O_C, I_C, K_H, K_W, O_H, O_W, self.stride, self.dilation, n_extra_node) # TODO: implement K_W - TILE_I_H = 1 + (TILE_O_H - 1) * self.stride[0] + (TILE_K_H - 1) * self.dilation[0] - x_spad_size_per_lane = kernel.get_spad_size_per_lane(TILE_K_W * TILE_I_H * TILE_M, TILE_K) - w_spad_size_per_lane = kernel.get_spad_size_per_lane(TILE_K_W * TILE_K_H * TILE_K, TILE_N) - y_spad_size_per_lane = kernel.get_spad_size_per_lane(TILE_O_H * TILE_M, TILE_N) - x_spad_size = TILE_K_W * TILE_I_H * TILE_M * TILE_K - w_spad_size = TILE_K_W * TILE_K_H * TILE_K * TILE_N - y_spad_size = TILE_O_H * TILE_M * TILE_N - SUB_TILE_M = TILE_M if TILE_M < kernel.vector_lane else kernel.vector_lane - SUB_TILE_N = TILE_N if TILE_N < kernel.vector_lane else kernel.vector_lane - SUB_TILE_K = TILE_K - TOG_latency = O_W if TILE_M > O_W else TILE_M - elif self.is_single_batch(BATCH) and self.stride[0] == 1: - conv_template = SINGLE_BATCH_CONV_TEMPLATE - TILE_K_H, TILE_K_W, TILE_O_H, TILE_O_W, TILE_M, TILE_N, TILE_K = kernel.conv_single_batch_mapping(BATCH, O_C, I_C, K_H, 1, O_H, O_W, self.stride, self.dilation, n_extra_node) # TODO: implement K_W - TILE_I_H = 
1 + (TILE_O_H - 1) * self.stride[0] + (TILE_K_H - 1) * self.dilation[0] - TILE_I_W = 1 + (TILE_O_W - 1) * self.stride[1] + (TILE_K_W - 1) * self.dilation[1] - SUB_TILE_M = TILE_I_W if TILE_I_W < kernel.vector_lane else kernel.vector_lane - SUB_TILE_N = TILE_N if TILE_N < kernel.vector_lane else kernel.vector_lane - SUB_TILE_K = TILE_K - x_spad_size_per_lane = kernel.get_spad_size_per_lane(TILE_I_W * TILE_I_H, TILE_K) - w_spad_size_per_lane = kernel.get_spad_size_per_lane(TILE_K_W * TILE_K_H * TILE_K, TILE_N) - y_spad_size_per_lane = kernel.get_spad_size_per_lane(TILE_O_H * TILE_M, TILE_N) - x_spad_size = TILE_I_W * TILE_I_H * TILE_K - w_spad_size = TILE_K_W * TILE_K_H * TILE_K * TILE_N - y_spad_size = TILE_O_H * TILE_M * TILE_N - TOG_latency = O_W if TILE_M > O_W else TILE_M - elif self.is_multi_tile(I_C): - conv_template = MULTI_TILE_CONV_TEMPLATE - TILE_K_H, TILE_K_W, TILE_O_H, TILE_O_W, TILE_M, TILE_N, TILE_K = kernel.conv_multi_tile_mapping(BATCH, O_C, I_C, K_H, K_W, O_H, O_W, self.stride, self.dilation, n_extra_node) - TILE_I_W = 1 + (TILE_O_W - 1) * self.stride[1] - TILE_I_H = 1 + (TILE_O_H - 1) * self.stride[0] + (TILE_K_H - 1) * self.dilation[0] - SUB_TILE_K = TILE_K - x_spad_size_per_lane = kernel.get_spad_size_per_lane(TILE_I_W * TILE_I_H * TILE_M, TILE_K) - w_spad_size_per_lane = kernel.get_spad_size_per_lane(TILE_K_H * TILE_K, TILE_N) - y_spad_size_per_lane = kernel.get_spad_size_per_lane(TILE_O_H * TILE_O_W * TILE_M, TILE_N) - x_spad_size = TILE_I_W * TILE_I_H * TILE_M * TILE_K - w_spad_size = TILE_K_H * TILE_K * TILE_N - y_spad_size = TILE_O_H * TILE_O_W * TILE_M * TILE_N + TILE_K_H, TILE_K_W, TILE_O_H, TILE_O_W, TILE_M, TILE_N, TILE_K, TILE_I_H, TILE_I_W, SUB_TILE_I_H, SUB_TILE_I_W, SUB_TILE_K_H, SUB_TILE_K_W, SUB_TILE_M, SUB_TILE_N, SUB_TILE_K, TOG_latency = self.select_tile(kernel, n_extra_node, BATCH, I_C, O_C, K_H, K_W, O_H, O_W) SUB_TILE_N = TILE_N if TILE_N > 512 else SUB_TILE_N TOG_latency = 8 if TOG_latency < 8 else TOG_latency 
kernel.loop_size = [TOG_latency, TILE_N, TILE_K] + # Prepare tile descriptors + vlane_stride = 1 + vlane_split_axis = 1 + X_tile_size = [TILE_I_H, TILE_I_W, TILE_M, TILE_K ] + X_tile_stride = [TILE_I_W*TILE_M*TILE_K, TILE_M*TILE_K, 1, TILE_M] + X_tile_desc = mlir_common.MLIRMultiDimTile(X_tile_size, kernel.vector_lane, 3, vlane_stride) + X_tile_desc.set_tile_size_stride(X_tile_size, X_tile_stride) + X_tile_desc.set_name("input_buffer") + X_dim = [Symbol("index_i_h"), Symbol("index_i_w"), Symbol("tile_m"), Symbol("tile_k")] + X_idx = [X_dim[0]*(I_W+2*PADDING_W)*BATCH*I_C, X_dim[1]*I_C*BATCH, X_dim[2]*I_C, X_dim[3]] + + W_tile_size = [TILE_K_H, TILE_K_W, TILE_K, TILE_N] + W_tile_stride = [TILE_K_W * TILE_K * TILE_N, TILE_K * TILE_N, 1, TILE_K] + W_tile_desc = mlir_common.MLIRMultiDimTile(X_tile_size, kernel.vector_lane, 3, vlane_stride) + W_tile_desc.set_tile_size_stride(W_tile_size, W_tile_stride) + W_tile_desc.set_name("weight_buffer") + W_dim = [Symbol("k_h"), Symbol("k_w"), Symbol("tile_k"), Symbol("tile_n")] + W_idx = [W_dim[0]*K_W*I_C*O_C , W_dim[1]*I_C*O_C, W_dim[2]*O_C, W_dim[3]] + + Y_tile_size = [TILE_M, TILE_N, TILE_O_H, TILE_O_W] + Y_tile_stride = [1, TILE_M, TILE_O_W * TILE_M * TILE_N, TILE_M * TILE_N] # N, C, H, W + Y_tile_desc = mlir_common.MLIRMultiDimTile(Y_tile_size, kernel.vector_lane, vlane_split_axis, vlane_stride) + Y_tile_desc.set_tile_size_stride(Y_tile_size, Y_tile_stride) + Y_tile_desc.set_name("output_buffer") + Y_dim = [Symbol("tile_m"), Symbol("tile_n"), Symbol("o_h"), Symbol("o_w")] + Y_idx = [Y_dim[0]*O_C*O_H*O_W, Y_dim[1]*O_H*O_W, Y_dim[2]*O_W, Y_dim[3]] + + # Extract Bias info + Bias_idx = [Number(0), Symbol("tile_n"), Number(0), Number(0)] + kernel.render_options = dict( KERNEL_NAME=self.name, kernel=kernel, - X=X, - W=W, - BIAS=Bias, - Y=Y, + X=X, W=W, Y=Y, BIAS=Bias, PADDED_INPUT_SIZE=self.get_padded_input_size(X), - BATCH=X.layout.size[0], - I_C=X.layout.size[1], - I_H=X.layout.size[2], - I_W=X.layout.size[3], + BATCH=BATCH, + 
I_C=I_C, + I_H=I_H, + I_W=I_W, O_C=O_C, K_H=K_H, K_W=K_W, @@ -725,43 +255,46 @@ def render(self, SUB_TILE_I_W=SUB_TILE_I_W, SUB_TILE_K_H=SUB_TILE_K_H, SUB_TILE_K_W=SUB_TILE_K_W, - PADDING_H=self.padding[0], - PADDING_W=self.padding[1], - STRIDE_H=self.stride[0], - STRIDE_W=self.stride[1], - DILATION_H=self.dilation[0], - DILATION_W=self.dilation[1], + PADDING_H=PADDING_H, + PADDING_W=PADDING_W, + STRIDE_H=STRIDE_H, + STRIDE_W=STRIDE_W, + X_tile_desc = X_tile_desc, + W_tile_desc = W_tile_desc, + Y_tile_desc = Y_tile_desc, + X_idx = X_idx, + W_idx = W_idx, + Bias_idx = Bias_idx, DATA_STYPE="f32", - DATA_SIZE=4, input_reorder=self.input_reorder ) - kernel.store_info = dict( + kernel.epilogue_info = dict( output_node = self.output_node.name, - dependent_buf = [], sram_var = "output_buffer", dram_var = "Y", - index_var = "index0", - tag_var = "tag", - vlane_split_axis = 3, - vlane_stride = 1, - mlir_dtype = kernel.render_options['DATA_STYPE'], - dram_shape = f"memref<{BATCH * O_C * O_H * O_W}x{kernel.render_options['DATA_STYPE']}>", - tile_size = (TILE_O_H, TILE_O_W, TILE_M, TILE_N) if conv_template in (CONV_TEMPLATE, MULTI_TILE_CONV_TEMPLATE) else (1, TILE_O_H, TILE_M, TILE_N), - tile_stride = [TILE_O_W * TILE_M * TILE_N, TILE_M * TILE_N, 1, TILE_M] + dram_idx = Y_idx, + dram_tile_desc = Y_tile_desc, + dim_aliasing = {"index0":"tile_m", "index1":"tile_n", "index2":"o_h", "index3":"o_w"} ) + kernel.exception_nodes["X"] = {"numel" : (I_W+2*PADDING_W)*(I_H+2*PADDING_H)*I_C*BATCH} code = self._template_from_string(conv_template).render(**kernel.render_options) - self.header = f"float X_spad[{x_spad_size_per_lane}] __attribute__ ((section(\".spad\")));\n" - self.header += f"float W_spad[{w_spad_size_per_lane}] __attribute__ ((section(\".spad\")));\n" - self.header += f"float Y_spad[{y_spad_size_per_lane}] __attribute__ ((section(\".spad\")));\n" - self.gem5_header = f"float X_spad[{x_spad_size}] __attribute__ ((section(\".spad\")));\n" - self.gem5_header += f"float 
W_spad[{w_spad_size}] __attribute__ ((section(\".spad\")));\n" - self.gem5_header += f"float Y_spad[{y_spad_size}] __attribute__ ((section(\".spad\")));\n" - kernel.add_loop_info([kernel.render_options["K_H"], kernel.render_options["K_W"], kernel.render_options["O_H"], kernel.render_options["O_W"], kernel.render_options["BATCH"], kernel.render_options["O_C"], kernel.render_options["I_C"]], [kernel.render_options["TILE_M"], kernel.render_options["TILE_N"], kernel.render_options["TILE_K"]]) - return code + def select_tile(self, kernel, n_extra_node, BATCH, I_C, O_C, K_H, K_W, O_H, O_W): + TILE_K_H, TILE_K_W, TILE_O_H, TILE_O_W, TILE_M, TILE_N, TILE_K = kernel.conv_combination_mapping(BATCH, O_C, I_C, K_H, K_W, O_H, O_W, self.stride, self.dilation, n_extra_node) + SUB_TILE_M = TILE_M if TILE_M < kernel.vector_lane else kernel.vector_lane + SUB_TILE_N = TILE_N if TILE_N < kernel.vector_lane else kernel.vector_lane + SUB_TILE_K = TILE_K + TILE_I_H = 1 + (TILE_O_H - 1) * self.stride[0] + (TILE_K_H - 1) * self.dilation[0] + TILE_I_W = 1 + (TILE_O_W - 1) * self.stride[1] + (TILE_K_W - 1) * self.dilation[1] + SUB_TILE_I_H, SUB_TILE_I_W, SUB_TILE_K_H, SUB_TILE_K_W = 1, 1, 1, 1 + SUB_TILE_N = TILE_N if TILE_N > 512 else SUB_TILE_N + TOG_latency = BATCH if TILE_M > BATCH else TILE_M + TOG_latency = 8 if TOG_latency < 8 else TOG_latency + return TILE_K_H,TILE_K_W,TILE_O_H,TILE_O_W,TILE_M,TILE_N,TILE_K,TILE_I_H,TILE_I_W,SUB_TILE_I_H,SUB_TILE_I_W,SUB_TILE_K_H,SUB_TILE_K_W,SUB_TILE_M,SUB_TILE_N,SUB_TILE_K,TOG_latency + def outer_func_render(self, kernel_name, input_args): X, W = self.input_nodes[0], self.input_nodes[1] Y = self.output_node @@ -778,8 +311,6 @@ def outer_func_render(self, kernel_name, input_args): OUTPUT=Y, PADDING_H=self.padding[0], PADDING_W=self.padding[1], - MULTI_TILE=self.is_multi_tile(self.input_shape[1]), - SINGLE_BATCH=self.is_single_batch(self.input_shape[0]), VALIDATION_MODE=extension_config.CONFIG_TORCHSIM_VALIDATION_MODE, 
BACKENDSIM_EAGER_MODE=eager_mode, input_reorder=self.input_reorder @@ -813,7 +344,7 @@ def codegen_header(self, code, extra_headers): spike_write_path = os.path.join(write_path, "global_var.h") gem5_write_path = os.path.join(write_path, "gem5_global_var.h") if not os.path.exists(spike_write_path): - write_atomic(spike_write_path, self.header+extra_headers[0]) + write_atomic(spike_write_path, extra_headers[0]) if not os.path.exists(gem5_write_path): - write_atomic(gem5_write_path, self.gem5_header+extra_headers[1]) + write_atomic(gem5_write_path, extra_headers[1]) self.hash_value = get_hash(code.strip()) \ No newline at end of file diff --git a/PyTorchSimFrontend/mlir/mlir_gemm_template.py b/PyTorchSimFrontend/mlir/mlir_gemm_template.py index a6b3423b..f706c2e5 100644 --- a/PyTorchSimFrontend/mlir/mlir_gemm_template.py +++ b/PyTorchSimFrontend/mlir/mlir_gemm_template.py @@ -1,18 +1,20 @@ import os +import json +from pathlib import Path from torch import empty_strided -from typing import List, Optional, cast +from typing import List, Optional +import sympy from PyTorchSimFrontend.mlir.mlir_template import MLIRTemplate from PyTorchSimFrontend.mlir.mlir_template import MLIRTemplateKernel -from torch._inductor.ir import Buffer from torch._inductor.ir import IRNode -from torch._inductor.ir import ReinterpretView from torch._inductor.codecache import write_atomic import PyTorchSimFrontend.extension_codecache as extension_codecache from PyTorchSimFrontend import extension_config +from PyTorchSimFrontend.mlir import mlir_common GEMM_TEMPLATE = r""" -// GEMM kernel +// GEMM {% if prologue_nodes -%}prologue fused{%- endif %} {% if epilogue_nodes -%}epilogue fused{%- endif %} kernel // M = {{ M }} // N = {{ N }} // K = {{ K }} @@ -21,59 +23,36 @@ // TILE_K = {{ TILE_K }} // SUB_TILE_M = {{ SUB_TILE_M }} // SUB_TILE_N = {{ SUB_TILE_N }} -#map0 = affine_map<(d0, d1) -> ({{ X_map }})> -#map1 = affine_map<(d0, d1) -> ({{ W_map }})> -#map2 = affine_map<(d0, d1) -> (d0 * {{ N }} + 
d1)> -memref.global @X_spad : memref<{{ TILE_M }}x{{ TILE_K }}xf32, 1> -memref.global @W_spad : memref<{{ TILE_K }}x{{ TILE_N }}xf32, 1> -memref.global @Y_spad : memref<{{ TILE_M }}x{{ TILE_N }}xf32, 1> {{kernel.def_global_vars()}} func.func @{{ KERNEL_NAME }}{{kernel.def_kernel(inputs=[X, W, Bias], outputs=[Y], names_str="X, W, Bias, Y", input_reorder=input_reorder)}} { - %c_mvin = arith.constant 2 : index - %c_mvin2 = arith.constant 1 : index{% if Bias %} - %c_mvin3 = arith.constant 14 : index{% endif %} - %c_mvout = arith.constant 3 : index - %vstride = arith.constant 1 : index - %axis = arith.constant 1 : index - %X_buffer = memref.get_global @X_spad : memref<{{ TILE_M }}x{{ TILE_K }}xf32, 1> - %W_buffer = memref.get_global @W_spad : memref<{{ TILE_K }}x{{ TILE_N }}xf32, 1> - %Y_buffer = memref.get_global @Y_spad : memref<{{ TILE_M }}x{{ TILE_N }}xf32, 1> - %tag = memref.alloc() : memref<1xi32> - %tag0 = memref.alloc() : memref<1xi32> - %tag1 = memref.alloc() : memref<1xi32> - %tag2 = memref.alloc() : memref<1xi32>{% if not Bias %} + {{ kernel.def_sram_buffer("X", X_tile_desc, indent_size=2) }} + {{ kernel.def_sram_buffer("W", W_tile_desc, indent_size=2) }} + {{ kernel.def_sram_buffer("Y", Y_tile_desc, indent_size=2) }} + {% if not Bias %} %v0 = arith.constant dense<0.0> : vector<{{ kernel.get_spad_size_per_lane(TILE_M, TILE_N) }}xf32>{% endif %} - %c0 = arith.constant 0 : index -{{ kernel.def_local_vars() }} - - affine.for %t_m = 0 to {{ M }} step {{ TILE_M }} { - affine.for %t_n = 0 to {{ N }} step {{ TILE_N }} { - %index2 = affine.apply #map2(%t_m, %t_n) + {{ kernel.def_local_vars(indent_size=2) }} + affine.for %index0 = 0 to {{ M }} step {{ TILE_M }} { + affine.for %index1 = 0 to {{ N }} step {{ TILE_N }} { {%- if Bias %} - memref.dma_start %Bias[ - {%- if Bias_rank == 2 -%} %index2 {%- else -%} %t_n {%- endif -%} - ], %Y_buffer[%c0, %c0], %c_mvin3, %tag0[%c0], % - {%- if Bias_rank == 2 -%} axis {%- else -%} c0 {%- endif -%} - , %vstride : memref< - {%- if 
Bias_rank == 2 -%} {{ M * N }} {%- else -%} {{ N }} {%- endif -%} - xf32>, memref<{{ TILE_M }}x{{ TILE_N }}xf32, 1>, memref<1xi32> { subtile_size=[{{ SUB_TILE_M }}, {{ SUB_TILE_N }}], async=1, sram_stride=[1, {{ TILE_M }}] } + {{ kernel.def_dma_op("MVIN", "Bias", Bias_idx, Y_tile_desc, subtile_size=[SUB_TILE_M, SUB_TILE_N], indent_size=6) }} {%- else %} - affine.vector_store %v0, %Y_buffer[0, 0] : memref<{{ TILE_M }}x{{ TILE_N }}xf32, 1>, vector<{{ kernel.get_spad_size_per_lane(TILE_M, TILE_N) }}xf32> + affine.vector_store %v0, %Y_buffer[0, 0] : {{ Y_tile_desc.get_mlir_shape(DATA_STYPE) }}, vector<{{ kernel.get_spad_size_per_lane(TILE_M, TILE_N) }}xf32> {%- endif %} - affine.for %t_k = 0 to {{ K }} step {{ TILE_K }} { - %index0 = affine.apply #map0(%t_m, %t_k) - %index1 = affine.apply #map1(%t_k, %t_n) - memref.dma_start %X[%index0], %X_buffer[%c0, %c0], %c_mvin, %tag1[%c0], %axis, %vstride - : memref<{{ M * K }}xf32>, memref<{{ TILE_M }}x{{ TILE_K }}xf32, 1>, memref<1xi32> { subtile_size=[{{ SUB_TILE_M }}, {{ SUB_TILE_K }}], async=1, sram_stride=[1, {{ TILE_M }}]} - memref.dma_start %W[%index1], %W_buffer[%c0, %c0], %c_mvin2, %tag2[%c0], %axis, %vstride - : memref<{{ K * N }}xf32>, memref<{{ TILE_K }}x{{ TILE_N }}xf32, 1>, memref<1xi32> { subtile_size=[{{ SUB_TILE_K }}, {{ SUB_TILE_N }}], async=1, sram_stride=[1, {{ TILE_K }}]} - linalg.matmul ins(%X_buffer, %W_buffer : memref<{{ TILE_M }}x{{ TILE_K }}x{{ DATA_STYPE }}, 1>, memref<{{ TILE_K }}x{{ TILE_N }}x{{ DATA_STYPE }}, 1>) - outs(%Y_buffer : memref<{{ TILE_M }}x{{ TILE_N }}x{{ DATA_STYPE }}, 1>) - } { accumulation_loop=true } + affine.for %index2 = 0 to {{ K }} step {{ TILE_K }} { + {% if prologue_nodes -%} + // prologue nodes + {{kernel.load_input(indent_size=8)}} + {%- else -%} + {{ kernel.def_dma_op("MVIN", "X", X_idx, X_tile_desc, subtile_size=[SUB_TILE_M, SUB_TILE_K], indent_size=8) }} + {{ kernel.def_dma_op("MVIN", "W", W_idx, W_tile_desc, subtile_size=[SUB_TILE_K, SUB_TILE_N], indent_size=8) }} + {%- 
endif %} + linalg.matmul ins(%X_buffer, %W_buffer : {{ X_tile_desc.get_mlir_shape(DATA_STYPE) }}, {{ W_tile_desc.get_mlir_shape(DATA_STYPE) }}) + outs(%Y_buffer : {{ Y_tile_desc.get_mlir_shape(DATA_STYPE) }}) + } { accumulation_loop=true, subtile_loop="k" } {{kernel.store_output(indent_size=6)}} - } { outer_loop=true } - } { outer_loop=true } + } { outer_loop=true, subtile_loop="n" } + } { outer_loop=true, subtile_loop="m" } return } """ @@ -85,7 +64,7 @@ """ GEMM_REDUCTION_TEMPLATE = r""" -// GEMM kernel +// GEMM reduction kernel // M = {{ M }} // N = {{ N }} // K = {{ K }} @@ -94,60 +73,34 @@ // TILE_K = {{ TILE_K }} // SUB_TILE_M = {{ SUB_TILE_M }} // SUB_TILE_N = {{ SUB_TILE_N }} -#map0 = affine_map<(d0, d1) -> ({{ X_map }})> -#map1 = affine_map<(d0, d1) -> ({{ W_map }})> -#map2 = affine_map<(d0, d1) -> (d0 * {{ N }} + d1)> -memref.global @X_spad : memref<{{ TILE_M }}x{{ TILE_K }}xf32, 1> -memref.global @W_spad : memref<{{ TILE_K }}x{{ TILE_N }}xf32, 1> -memref.global @Y_spad : memref<{{ TILE_M }}x{{ TILE_N }}xf32, 1> {{kernel.def_global_vars()}} func.func @{{ KERNEL_NAME }}{{kernel.def_kernel(inputs=[X, W, Bias], outputs=[Y], names_str="X, W, Bias, Y", input_reorder=input_reorder)}} { - %c_mvin = arith.constant 2 : index - %c_mvin2 = arith.constant 1 : index{% if Bias %} - %c_mvin3 = arith.constant 14 : index{% endif %} - %c_mvout = arith.constant 3 : index - %vstride = arith.constant 1 : index - %axis = arith.constant 1 : index - %X_buffer = memref.get_global @X_spad : memref<{{ TILE_M }}x{{ TILE_K }}xf32, 1> - %W_buffer = memref.get_global @W_spad : memref<{{ TILE_K }}x{{ TILE_N }}xf32, 1> - %Y_buffer = memref.get_global @Y_spad : memref<{{ TILE_M }}x{{ TILE_N }}xf32, 1> - %tag = memref.alloc() : memref<1xi32> - %tag0 = memref.alloc() : memref<1xi32> - %tag1 = memref.alloc() : memref<1xi32> - %tag2 = memref.alloc() : memref<1xi32>{% if not Bias %} - %v0 = arith.constant dense<0.0> : vector<{{ kernel.get_spad_size_per_lane(TILE_M, TILE_N) }}xf32>{% endif %} - 
%c0 = arith.constant 0 : index -{{ kernel.def_local_vars() }} - - affine.for %t_n = 0 to {{ N }} step {{ TILE_N }} { - {{kernel.reduction_acc()}} affine.for %t_m = 0 to {{ M }} step {{ TILE_M }} {{kernel.reduction_iter_arg()}} { - %index2 = affine.apply #map2(%t_m, %t_n) + {{ kernel.def_sram_buffer("X", X_tile_desc, indent_size=2) }} + {{ kernel.def_sram_buffer("W", W_tile_desc, indent_size=2) }} + {{ kernel.def_sram_buffer("Y", Y_tile_desc, indent_size=2) }} + {% if not Bias %} + %v0 = arith.constant dense<0.0> : vector<{{ kernel.get_spad_size_per_lane(TILE_M, TILE_N) }}xf32> + {% endif %} + {{ kernel.def_local_vars(indent_size=2) }} + affine.for %index1 = 0 to {{ N }} step {{ TILE_N }} { + affine.for %index0 = 0 to {{ M }} step {{ TILE_M }} { + %Y_bufferT = memref.reinterpret_cast %Y_buffer to offset: [0], sizes: [{{ TILE_M }}, {{ TILE_N }}], strides: [{{ TILE_N }}, 1] : {{ Y_tile_desc.get_mlir_shape(DATA_STYPE) }} to memref<{{ TILE_M }}x{{ TILE_N }}xf32, 1> {%- if Bias %} - memref.dma_start %Bias[ - {%- if Bias_rank == 2 -%} %index2 {%- else -%} %t_n {%- endif -%} - ], %Y_buffer[%c0, %c0], %c_mvin3, %tag0[%c0], % - {%- if Bias_rank == 2 -%} axis {%- else -%} c0 {%- endif -%} - , %vstride : memref< - {%- if Bias_rank == 2 -%} {{ M * N }} {%- else -%} {{ N }} {%- endif -%} - xf32>, memref<{{ TILE_M }}x{{ TILE_N }}xf32, 1>, memref<1xi32> { subtile_size=[{{ SUB_TILE_M }}, {{ SUB_TILE_N }}], async=1, sram_stride=[1, {{ TILE_M }}] } + {{ kernel.def_dma_op("MVIN", "Bias", Bias_idx, Y_tile_desc, subtile_size=[SUB_TILE_M, SUB_TILE_N], indent_size=6) }} {%- else %} - affine.vector_store %v0, %Y_buffer[0, 0] : memref<{{ TILE_M }}x{{ TILE_N }}xf32, 1>, vector<{{ kernel.get_spad_size_per_lane(TILE_M, TILE_N) }}xf32> + affine.vector_store %v0, %Y_buffer[0, 0] : memref<{{ TILE_N }}x{{ TILE_M }}xf32, 1>, vector<{{ kernel.get_spad_size_per_lane(TILE_M, TILE_N) }}xf32> {%- endif %} - affine.for %t_k = 0 to {{ K }} step {{ TILE_K }} { - %index0 = affine.apply #map0(%t_m, %t_k) - 
%index1 = affine.apply #map1(%t_k, %t_n) - memref.dma_start %X[%index0], %X_buffer[%c0, %c0], %c_mvin, %tag1[%c0], %axis, %vstride - : memref<{{ M * K }}xf32>, memref<{{ TILE_M }}x{{ TILE_K }}xf32, 1>, memref<1xi32> { subtile_size=[{{ SUB_TILE_M }}, {{ SUB_TILE_K }}], async=1, sram_stride=[1, {{ TILE_M }}]} - memref.dma_start %W[%index1], %W_buffer[%c0, %c0], %c_mvin2, %tag2[%c0], %axis, %vstride - : memref<{{ K * N }}xf32>, memref<{{ TILE_K }}x{{ TILE_N }}xf32, 1>, memref<1xi32> { subtile_size=[{{ SUB_TILE_K }}, {{ SUB_TILE_N }}], async=1, sram_stride=[1, {{ TILE_K }}]} - linalg.matmul ins(%X_buffer, %W_buffer : memref<{{ TILE_M }}x{{ TILE_K }}x{{ DATA_STYPE }}, 1>, memref<{{ TILE_K }}x{{ TILE_N }}x{{ DATA_STYPE }}, 1>) - outs(%Y_buffer : memref<{{ TILE_M }}x{{ TILE_N }}x{{ DATA_STYPE }}, 1>) - } { accumulation_loop=true, loop_k=true } + affine.for %index2 = 0 to {{ K }} step {{ TILE_K }} { + {{ kernel.def_dma_op("MVIN", "X", X_idx, X_tile_desc, subtile_size=[SUB_TILE_M, SUB_TILE_K], indent_size=8) }} + {{ kernel.def_dma_op("MVIN", "W", W_idx, W_tile_desc, subtile_size=[SUB_TILE_K, SUB_TILE_N], indent_size=8) }} + linalg.matmul ins(%X_buffer, %W_buffer : {{ X_tile_desc.get_mlir_shape(DATA_STYPE) }}, {{ W_tile_desc.get_mlir_shape(DATA_STYPE) }}) + outs(%Y_bufferT : memref<{{TILE_M}}x{{TILE_N}}x{{DATA_STYPE}}, 1>) + } { accumulation_loop=true, subtile_loop="k" } {{kernel.store_output(indent_size=6)}} - } { outer_loop=true, loop_m=true} + } { outer_loop=true, subtile_loop="m" } {{kernel.reduction_output(indent_size=4)}} - } { outer_loop=true, loop_n=true } + } { outer_loop=true, subtile_loop="n" } return } """ @@ -160,65 +113,95 @@ def render(self, kernel: MLIRTemplateKernel, template_buffer_node = None, epilogue_nodes: Optional[List[IRNode]] = None, + prologue_nodes: Optional[List[IRNode]] = None, **kwargs): if template_buffer_node is not None: self.output_node = template_buffer_node - # if epilogue_nodes is not None and len(epilogue_nodes) > 0: - # self.output_node 
= cast(Buffer, epilogue_nodes[-1]) #FIXME: Temperary solution - - X, W = self.input_nodes[0], self.input_nodes[1] - Y = self.output_node - Bias = None if len(self.input_nodes) == 2 else self.input_nodes[2] - W_tensor = empty_strided(W.layout.size, W.layout.stride) - X_tensor = empty_strided(X.layout.size, X.layout.stride) + # Extract input arguments info + X, W, Y = self.input_nodes[0], self.input_nodes[1], self.output_node + X_tensor = empty_strided(X.layout.size, X.layout.stride) + W_tensor = empty_strided(W.layout.size, W.layout.stride) if len(W_tensor.size()) > 2 or len(X_tensor.size()) > 2: raise NotImplementedError("Please report this case to us...") - W_stride = W_tensor.stride() - X_stride = X_tensor.stride() - W_map = " + ".join([f"d{idx}*{s}" for idx, s in enumerate(W_stride)]) - X_map = " + ".join([f"d{idx}*{s}" for idx, s in enumerate(X_stride)]) - M, N, K = X_tensor.size()[0], W_tensor.size()[1], X_tensor.size()[1] - n_extra_node = len(epilogue_nodes) if epilogue_nodes is not None else 0 - # Caculate extra reads + # Extract fusion info + n_epilogue_node = len(epilogue_nodes) if epilogue_nodes is not None else 0 + n_prologue_node = len(prologue_nodes) if prologue_nodes is not None else 0 n_extra_read = set() if epilogue_nodes is not None: - for enode in epilogue_nodes: - n_extra_read.update(enode.node.get_read_names()) - if self.output_node.name in n_extra_read: - n_extra_read.remove(self.output_node.name) + for enode in epilogue_nodes: + n_extra_read.update(enode.node.get_read_names()) + if self.output_node.name in n_extra_read: + n_extra_read.remove(self.output_node.name) - nr_rdim = 0 - if (M == 0) or (N == 0) or (K == 0): - TILE_M, TILE_N, TILE_K = 1, 1, 1 + # Select tile size + M, N, K = X_tensor.size()[0], W_tensor.size()[1], X_tensor.size()[1] + TILE_M, TILE_N, TILE_K, SUB_TILE_M, SUB_TILE_N, SUB_TILE_K = self.select_tile(kernel, M, N, K, n_epilogue_node, n_extra_read, n_prologue_node) + + # Select template code + if (M == 0) or (N == 0) or (K == 
0): # exception for MoE template = EMPTY_TEMPLATE - elif n_extra_node==1 and epilogue_nodes[0].is_reduction(): - TILE_M, TILE_N, TILE_K = kernel.gemm_combination_mapping(M, N, K, n_extra_node, min_tile=True) + nr_rdim = 0 + epilogue_dim_aliasing = {} + elif n_epilogue_node>=1 and epilogue_nodes[0].is_reduction(): template = GEMM_REDUCTION_TEMPLATE + epilogue_dim_aliasing = {"index0":"index1", "index1":"index0"} nr_rdim = 1 else: - TILE_M, TILE_N, TILE_K = kernel.gemm_combination_mapping(M, N, K, len(n_extra_read), min_tile=True) template = GEMM_TEMPLATE - TILE_M = min(extension_config.CONFIG_FORCE_TILE_M, TILE_M) - TILE_N = min(extension_config.CONFIG_FORCE_TILE_N, TILE_N) - TILE_K = min(extension_config.CONFIG_FORCE_TILE_K, TILE_K) - SUB_TILE_M = TILE_M if TILE_M < kernel.vector_lane else kernel.vector_lane - if (TILE_M == M and TILE_N == N): - SUB_TILE_N = TILE_N if TILE_N < kernel.vector_lane else kernel.vector_lane - else: # Avoid Row Conflict of weights - SUB_TILE_N = TILE_N - SUB_TILE_N = TILE_N if TILE_N > 512 else SUB_TILE_N # FIXME: hardcoded & 126 line has same feature - SUB_TILE_K = TILE_K + epilogue_dim_aliasing = {"index0":"index0", "index1":"index1"} + nr_rdim = 0 + TOG_latency = M if SUB_TILE_M > M else SUB_TILE_M kernel.loop_size =[TOG_latency, SUB_TILE_N, SUB_TILE_K] + # Prepare tile descriptors + vlane_stride = 1 + vlane_split_axis = 1 + X_tile_size = [TILE_M, TILE_K] + X_tile_stride = [1, TILE_M] + X_tile_desc = mlir_common.MLIRMultiDimTile(X_tile_size, kernel.vector_lane, vlane_split_axis, vlane_stride) + X_tile_desc.set_tile_size_stride(X_tile_size, X_tile_stride) + X_tile_desc.set_name("X_buffer") + X_stride = X.get_layout().stride + X_idx = [sympy.Symbol("index0") * X_stride[0], sympy.Symbol("index2") * X_stride[1]] # To keep index argument order, we used index_list + + W_tile_size = [TILE_K, TILE_N] + W_tile_stride = [1, TILE_K] + W_tile_desc = mlir_common.MLIRMultiDimTile(X_tile_size, kernel.vector_lane, vlane_split_axis, vlane_stride) + 
W_tile_desc.set_tile_size_stride(W_tile_size, W_tile_stride) + W_tile_desc.set_name("W_buffer") + W_stride = W.get_layout().stride + W_idx = [sympy.Symbol("index2") * W_stride[0], sympy.Symbol("index1") * W_stride[1]] + + vlane_split_axis = vlane_split_axis if nr_rdim==0 else 0 + Y_tile_size = [TILE_M, TILE_N] if nr_rdim == 0 else [TILE_N, TILE_M] + Y_tile_stride=[1, TILE_M] if nr_rdim == 0 else [TILE_M, 1] + Y_tile_desc = mlir_common.MLIRMultiDimTile(Y_tile_size, kernel.vector_lane, vlane_split_axis, vlane_stride) + Y_tile_desc.set_tile_size_stride(Y_tile_size, Y_tile_stride) + Y_tile_desc.set_name("Y_buffer") + Y_stride = Y.get_layout().stride + if nr_rdim == 0: + Y_idx = [sympy.Symbol("index0") * Y_stride[0], sympy.Symbol("index1") * Y_stride[1]] + else: + Y_idx = [sympy.Symbol("index1") * Y_stride[1], sympy.Symbol("index0") * Y_stride[0]] + + # Extract Bias info + Bias = None if len(self.input_nodes) == 2 else self.input_nodes[2] + if Bias is not None: + Bias_stride = Bias.get_layout().stride + if nr_rdim == 0: + Bias_idx = [sympy.Symbol("index0") * Bias_stride[0], sympy.Symbol("index1") * Bias_stride[1]] + else: + Bias_idx = [sympy.Symbol("index1") * Bias_stride[1], sympy.Symbol("index0") * Bias_stride[0]] + else: + Bias_idx = None + kernel.render_options = dict( KERNEL_NAME=self.name, kernel=kernel, - M=M, - N=N, - K=K, + M=M, N=N, K=K, TILE_M=TILE_M, TILE_N=TILE_N, TILE_K=TILE_K, @@ -226,48 +209,120 @@ def render(self, SUB_TILE_N=SUB_TILE_N, SUB_TILE_K=SUB_TILE_K, DATA_STYPE="f32", - DATA_SIZE=4, - X = X, - W = W, - Y = Y, + X = X, W = W, Y = Y, Bias = Bias, - Bias_rank = len(Bias.data.get_size()) if Bias is not None else 0, - X_map = X_map, - W_map = W_map, - Y_numel = M * N, + X_idx = X_idx, + W_idx = W_idx, + Bias_idx = Bias_idx, + X_tile_desc = X_tile_desc, + W_tile_desc = W_tile_desc, + Y_tile_desc = Y_tile_desc, epilogue_nodes = epilogue_nodes, + prologue_nodes = prologue_nodes, input_reorder = self.input_reorder ) + if prologue_nodes: + 
prologue_output_name = list(prologue_nodes[0].read_writes.writes)[0].name + if prologue_output_name == X.get_name(): + # Input fusion case + prologue_var = "X" + prologue_sram_var = "X_buffer" + prologue_tile_desc = X_tile_desc + prologue_dim_aliasing = {"index0":"index0", "index1":"index2"} + is_input_fused = True + else: + # Weight fusion case + prologue_var = "W" + prologue_sram_var = "W_buffer" + prologue_tile_desc = W_tile_desc + prologue_dim_aliasing = {"index0":"index2", "index1":"index1"} + is_input_fused = False + + kernel.prologue_info = dict ( + input_dram_var = "X", + input_sram_var = "X_buffer", + input_tile_desc = X_tile_desc, + input_idx = X_idx, + input_subtile_size = [TILE_M, TILE_K], + input_dim_aliasing = {"index0":"index0", "index1":"index2"}, - kernel.store_info = dict( + weight_dram_var = "W", + weight_sram_var = "W_buffer", + weight_tile_desc = W_tile_desc, + weight_idx = W_idx, + weight_subtile_size = [TILE_K, TILE_N], + weight_dim_aliasing = {"index0":"index2", "index1":"index1"}, + + # Descriptor for fusion + dram_var = prologue_var, + sram_var = prologue_sram_var, + dram_tile_desc = prologue_tile_desc, + dim_aliasing = prologue_dim_aliasing, + is_bmm = False, + is_input_fused = is_input_fused + ) + kernel.epilogue_info = dict( output_node = self.output_node.name, - dependent_buf = [], - sram_var = "Y_buffer", dram_var = "Y", - index_var = "index2", - tag_var = "tag", - vlane_split_axis = 1, - vlane_stride = 1, - mlir_dtype = kernel.render_options['DATA_STYPE'], - dram_shape = f"memref<{kernel.render_options['Y_numel']}x{kernel.render_options['DATA_STYPE']}>", - tile_size = (TILE_M, TILE_N), - tile_stride = [1, TILE_M], + sram_var = "Y_buffer", + dram_idx = Y_idx, + dram_tile_desc = Y_tile_desc, nr_rdim = nr_rdim, - reduction_idx = "t_n" + dim_aliasing = epilogue_dim_aliasing ) code = self._template_from_string(template).render(**kernel.render_options) kernel.add_loop_info([kernel.render_options["M"], kernel.render_options["N"], 
kernel.render_options["K"]], [kernel.render_options["TILE_M"], kernel.render_options["TILE_N"], kernel.render_options["TILE_K"]]) + return code - self.header = f"float X_spad[{kernel.get_spad_size_per_lane(TILE_M, TILE_K)}] __attribute__ ((section(\".spad\")));\n" - self.header += f"float W_spad[{kernel.get_spad_size_per_lane(TILE_K, TILE_N)}] __attribute__ ((section(\".spad\")));\n" - self.header += f"float Y_spad[{kernel.get_spad_size_per_lane(TILE_M, TILE_N)}] __attribute__ ((section(\".spad\")));\n" - self.gem5_header = f"float X_spad[{TILE_M * TILE_K}] __attribute__ ((section(\".spad\")));\n" - self.gem5_header += f"float W_spad[{TILE_K * TILE_N}] __attribute__ ((section(\".spad\")));\n" - self.gem5_header += f"float Y_spad[{TILE_M * TILE_N}] __attribute__ ((section(\".spad\")));\n" + def select_tile(self, kernel, M, N, K, n_extra_node, n_extra_read, n_prologue_node): + # Check cheat sheet + cheatsheet_path = extension_config.CONFIG_GEMM_CHEATSHEET_PATH + data = {} + if extension_config.CONFIG_GEMM_CHEATSHEET_PATH is not None: + path = Path(cheatsheet_path) + if path.is_file(): + with path.open("r") as f: + data = json.load(f) - kernel.add_loop_info([kernel.render_options["M"], kernel.render_options["N"], kernel.render_options["K"]], [kernel.render_options["TILE_M"], kernel.render_options["TILE_N"], kernel.render_options["TILE_K"]]) + gemm_shape = f"{M}_{K}_{N}" + if gemm_shape in data: + tile_info = data[gemm_shape] + TILE_M = tile_info["TILE_M"] + TILE_N = tile_info["TILE_N"] + TILE_K = tile_info["TILE_K"] + else: # case 2: use gemm_combination_mapping + min_tile = (n_extra_node + n_prologue_node) == 0 + TILE_M, TILE_N, TILE_K = kernel.gemm_combination_mapping(M, N, K, max(len(n_extra_read)-2, 0), n_prologue_node, min_tile=min_tile) + # case 3: use manual tile size + if extension_config.CONFIG_MANUAL_TILE_SIZE: + TILE_M = extension_config.CONFIG_TILE_M + TILE_N = extension_config.CONFIG_TILE_N + TILE_K = extension_config.CONFIG_TILE_K - return code + # Edge 
case + if (M == 0) or (N == 0) or (K == 0): + TILE_M, TILE_N, TILE_K = 1, 1, 1 + + # Calculate Sub Tile Size for fine-grained DMA + if extension_config.CONFIG_SUBTILE: + # Case 1: adjust selective fine-grained DMA (SFG-DMA) + SUB_TILE_M = TILE_M if (TILE_M < kernel.vector_lane or n_prologue_node) else kernel.vector_lane + if (TILE_M == M and TILE_N == N and TILE_N <= 512): + SUB_TILE_N = TILE_N if TILE_N < kernel.vector_lane else kernel.vector_lane + else: # Avoid Row Conflict of weights + SUB_TILE_N = TILE_N + SUB_TILE_K = TILE_K + # Case 2: use manual sub tile size (FG-DMA) + if extension_config.CONFIG_MANUAL_SUBTILE_SIZE: + SUB_TILE_M = extension_config.CONFIG_SUBTILE_M + SUB_TILE_N = extension_config.CONFIG_SUBTILE_N + SUB_TILE_K = extension_config.CONFIG_SUBTILE_K + # Case 3: None Subtile + else: + SUB_TILE_M = TILE_M + SUB_TILE_N = TILE_N + SUB_TILE_K = TILE_K + return TILE_M,TILE_N,TILE_K, SUB_TILE_M,SUB_TILE_N,SUB_TILE_K def codegen_header(self, code, extra_headers): write_path = extension_codecache.get_write_path(code) @@ -276,6 +331,6 @@ def codegen_header(self, code, extra_headers): spike_write_path = os.path.join(write_path, "global_var.h") gem5_write_path = os.path.join(write_path, "gem5_global_var.h") if not os.path.exists(spike_write_path): - write_atomic(spike_write_path, self.header+extra_headers[0]) + write_atomic(spike_write_path, extra_headers[0]) if not os.path.exists(gem5_write_path): - write_atomic(gem5_write_path, self.gem5_header+extra_headers[1]) + write_atomic(gem5_write_path, extra_headers[1]) diff --git a/PyTorchSimFrontend/mlir/mlir_lowering.py b/PyTorchSimFrontend/mlir/mlir_lowering.py index b1e1ba0e..aa3cf16e 100644 --- a/PyTorchSimFrontend/mlir/mlir_lowering.py +++ b/PyTorchSimFrontend/mlir/mlir_lowering.py @@ -11,7 +11,11 @@ from PyTorchSimFrontend.mlir.mlir_gemm_template import MLIRGemmTemplate from PyTorchSimFrontend.mlir.mlir_bmm_template import MLIRBMMTemplate from PyTorchSimFrontend.mlir.mlir_conv_template import 
MLIRConvTemplate +from PyTorchSimFrontend.mlir.mlir_conv_mt_template import MLIRConvMultiTileTemplate +from PyTorchSimFrontend.mlir.mlir_conv_sb_template import MLIRConvSingleBatchTemplate +from PyTorchSimFrontend.mlir.mlir_conv_sbs_template import MLIRConvSingleBatchStridedTemplate from PyTorchSimFrontend.mlir.mlir_maxpool_template import MLIRMaxPoolTemplate +from PyTorchSimFrontend.extension_config import CONFIG_VECTOR_LANE aten = torch.ops.aten aten_spmm = MLIRExternKernelChoice(torch.sparse.mm, "custom_op::sparse_addmm") @@ -96,9 +100,20 @@ def convolution( x.realize() weight.realize() x = ir.ExternKernel.require_channels_last(x) + BATCH = x.layout.size[0] + I_C = x.layout.size[1] weight = ir.ExternKernel.require_channels_last(weight) layout = conv_layout(x, weight, None, **kwargs) - mlir_template = MLIRConvTemplate([x, weight, bias], layout, **kwargs) + + # Select conv kernel + if BATCH == 1 and stride[0] == 1: + mlir_template = MLIRConvSingleBatchTemplate([x, weight, bias], layout, **kwargs) + elif BATCH == 1 and stride[0] != 1: + mlir_template = MLIRConvSingleBatchStridedTemplate([x, weight, bias], layout, **kwargs) + elif I_C < CONFIG_VECTOR_LANE // 8: # 8 is hard-coded for now. This should be changed to a better heuristic. 
+ mlir_template = MLIRConvMultiTileTemplate([x, weight, bias], layout, **kwargs) + else: + mlir_template = MLIRConvTemplate([x, weight, bias], layout, **kwargs) return mlir_template.generate().output_node() def maxpool_layout( diff --git a/PyTorchSimFrontend/mlir/mlir_maxpool_template.py b/PyTorchSimFrontend/mlir/mlir_maxpool_template.py index 6a5aafa0..5395efb2 100644 --- a/PyTorchSimFrontend/mlir/mlir_maxpool_template.py +++ b/PyTorchSimFrontend/mlir/mlir_maxpool_template.py @@ -26,8 +26,8 @@ affine.for %i = 0 to {{ BCH }} step {{ out_tile }} { affine.for %j = 0 to {{ W }} step {{ out_tile }} { %index0 = affine.apply #map0(%i, %j) - memref.dma_start %X[%index0], %X_buffer[%c0, %c0], %c_mvin, %tag[%c0], %axis, %vstride : memref<{{ IN }}xf32>, memref<{{ in_tile }}x{{ in_tile }}xf32, 1>, memref<1xi32> - memref.dma_start %Y_buffer[%c0, %c0], %Y[%index0], %c_mvout, %tag[%c0], %axis, %vstride : memref<{{ out_tile }}x{{ out_tile }}xf32, 1>, memref<{{ OUT }}xf32>, memref<1xi32> + memref.dma_start %X[%index0], %X_buffer[%c0, %c0], %c_mvin, %tag[%c0], %axis, %vstride : memref<{{ IN }}xf32>, memref<{{ in_tile }}x{{ in_tile }}xf32, 1>, memref<1xi32> {dram_stride=[{{W}}, 1]} + memref.dma_start %Y_buffer[%c0, %c0], %Y[%index0], %c_mvout, %tag[%c0], %axis, %vstride : memref<{{ out_tile }}x{{ out_tile }}xf32, 1>, memref<{{ OUT }}xf32>, memref<1xi32> {dram_stride=[{{W}}, 1]} } { outer_loop=true } } { outer_loop=true } return @@ -62,6 +62,7 @@ def render(self, W = Y.get_size()[3] BCH = B * C * H kernel.loop_size = None + kernel.render_options = dict( KERNEL_NAME=self.name, kernel=kernel, @@ -75,28 +76,12 @@ def render(self, out_tile=out_tile, DATA_STYPE="f32", ) - kernel.store_info = dict( + kernel.epilogue_info = dict( output_node = self.output_node.name, - dependent_buf = [], sram_var = "Y_buffer", dram_var = "Y", - index_var = "index0", - tag_var = "tag", - vlane_split_axis = 1, - vlane_stride = 1, - mlir_dtype = kernel.render_options['DATA_STYPE'], - tile_nr_dim = 2, - 
dram_shape = f"memref<{kernel.render_options['OUT']}x{kernel.render_options['DATA_STYPE']}>", - tile_shape = f"memref<{out_tile}x{out_tile}x{kernel.render_options['DATA_STYPE']}, 1>", - tile_size = (out_tile, out_tile), - tile_stride = [1, out_tile] ) code = self._template_from_string(TEMPLATE).render(**kernel.render_options) - self.header = f"float X_spad[{in_tile * in_tile // kernel.vector_lane}] __attribute__ ((section(\".spad\")));\n" - self.header += f"float Y_spad[{out_tile * out_tile // kernel.vector_lane}] __attribute__ ((section(\".spad\")));\n" - self.gem5_header = f"float X_spad[{in_tile * in_tile // kernel.vector_lane}] __attribute__ ((section(\".spad\")));\n" - self.gem5_header += f"float Y_spad[{out_tile * out_tile // kernel.vector_lane}] __attribute__ ((section(\".spad\")));\n" - kernel.add_loop_info([kernel.render_options["IN"]], [kernel.vector_lane, kernel.vector_lane]) return code @@ -107,6 +92,6 @@ def codegen_header(self, code, extra_headers): spike_write_path = os.path.join(write_path, "global_var.h") gem5_write_path = os.path.join(write_path, "gem5_global_var.h") if not os.path.exists(spike_write_path): - write_atomic(spike_write_path, self.header+extra_headers[0]) + write_atomic(spike_write_path, extra_headers[0]) if not os.path.exists(gem5_write_path): - write_atomic(gem5_write_path, self.gem5_header+extra_headers[1]) \ No newline at end of file + write_atomic(gem5_write_path, extra_headers[1]) \ No newline at end of file diff --git a/PyTorchSimFrontend/mlir/mlir_scheduling.py b/PyTorchSimFrontend/mlir/mlir_scheduling.py index ec8de5a1..f1c72c44 100644 --- a/PyTorchSimFrontend/mlir/mlir_scheduling.py +++ b/PyTorchSimFrontend/mlir/mlir_scheduling.py @@ -1,6 +1,9 @@ import os import math -from sympy import symbols, sympify +import sympy +from functools import reduce +import operator +from sympy import symbols, sympify, Symbol from PyTorchSimFrontend import extension_config from PyTorchSimFrontend.mlir.mlir_codegen_backend import MLIRKernel @@ 
-8,9 +11,11 @@ from torch._inductor.scheduler import BaseScheduling, FusedSchedulerNode, SchedulerNode, BaseSchedulerNode from torch._inductor.utils import IndentedBuffer from torch._inductor.virtualized import V +from torch._inductor.ir import LoopBody +from torch._inductor import dependencies from . import mlir_common -from . import mlir_lowering +from . import mlir_lowering # DO NOT REMOVE THIS LINE, it is used for lowering class MLIRScheduling(BaseScheduling): count = 0 @@ -19,6 +24,7 @@ def __init__(self, scheduler): self.scheduler = scheduler self.scheduler.can_fuse_origin = self.scheduler.can_fuse self.scheduler.can_fuse = self.can_fuse_with_exceptions + #self.scheduler.enter_context = self.enter_context_fixed # FIXME. Monkey patch: For fixing the inductor bug self.kernel_group = mlir_common.MLIRWrapperKenrelGroup() self._ready_to_flush = False self.outer_function = set() @@ -26,21 +32,54 @@ def __init__(self, scheduler): self.max_fusion_size = 5 def can_fuse_with_exceptions(self, node1: BaseSchedulerNode, node2: BaseSchedulerNode) -> bool: - if node1.get_device() == node2.get_device(): + # Extract base template node + base_template_node1 = [node for node in node1.get_nodes() if node.is_template()] + base_template_node2 = [node for node in node2.get_nodes() if node.is_template()] + if node1.get_device() != node2.get_device(): + return False + if not (isinstance(node1, (SchedulerNode, FusedSchedulerNode)) and isinstance(node2, (SchedulerNode, FusedSchedulerNode))): + return False + + if len(base_template_node1) == 1 and len(base_template_node2) == 0 and extension_config.CONFIG_FUSION_REDUCTION: from PyTorchSimFrontend.mlir.mlir_gemm_template import MLIRGemmTemplate from PyTorchSimFrontend.mlir.mlir_bmm_template import MLIRBMMTemplate - if (node1.is_template() and len(node1.get_nodes())==1 and \ - (isinstance(node1.node.template, MLIRGemmTemplate) or isinstance(node1.node.template, MLIRBMMTemplate)) and \ - node2.is_reduction() and len(node2.get_nodes())==1): 
+ if (isinstance(base_template_node1[0].node.template, MLIRGemmTemplate) or isinstance(base_template_node1[0].node.template, MLIRBMMTemplate)) and node2.is_reduction(): # For matmul/bmm+reduction case - size_match = node1.node.get_size() == node2.node.get_size() + node2.node.get_reduction_size() - if len(node1.node.get_size()) == len(node2.node.get_size()): - size_match = node1.node.get_size() == [dim for dim in node2.node.get_size() if dim!=1] + node2.node.get_reduction_size() - stride = [i.strip()[:-1].split(",")[-1].strip() for i in str(node2.node).split("\n") if "r0" in i][1] + size_match = node1.get_nodes()[0].node.get_numel() == reduce(operator.mul, node2.get_nodes()[0].node.get_size(), 1) * reduce(operator.mul, node2.get_nodes()[0].node.get_reduction_size(), 1) + stride = [i.strip()[:-1].split(",")[-1].strip() for i in str(node2.get_nodes()[0].node).split("\n") if "r0" in i][1] target_symbol = symbols("r0") # We can't fuse dim=-1 - possible = int(sympify(stride).coeff(target_symbol)) != 1 - return size_match and possible + layout_possible = int(sympify(stride).coeff(target_symbol)) != 1 + # Directed linked? 
+ dependency_check = node2.get_nodes()[0] in [node.node for node in base_template_node1[0].users]# and len(node2.read_writes.reads)==1 + dependency_size = all([i.get_numel() == node1.get_nodes()[0].node.get_numel() for i in node2.read_writes.reads]) + return size_match and layout_possible and dependency_check and dependency_size + + # For prologue fusion case + if extension_config.CONFIG_FUSION_PROLOGUE and len(base_template_node1) == 0 and len(node1.get_nodes())==1 and len(base_template_node2) == 1: + from PyTorchSimFrontend.mlir.mlir_gemm_template import MLIRGemmTemplate + from PyTorchSimFrontend.mlir.mlir_bmm_template import MLIRBMMTemplate + target_node = base_template_node2[0].node + if target_node.origin_node is not None and hasattr(target_node.origin_node.target, "_name") and target_node.origin_node.target._name == 'aten::convolution': + return False + if node1.is_reduction(): + return False + if len(node1.read_writes.writes) != 1: + return False + if node1.node not in target_node.inputs or any(["view" in str(ori) for ori in node1.node.origins]): #FIXME + return False + + # Currently only BMM, MM support prologue fusion + if not isinstance(target_node.template, (MLIRBMMTemplate, MLIRGemmTemplate)): + return False + # We don't fuse this edge case... 
+ if base_template_node2[0].group[1][0][0] == 1: + return False + + if list(node1.read_writes.writes)[0].name in [dep.name for dep in node2.read_writes.reads]: + node1 = self.revert_group(node1) + return True + return self.scheduler.can_fuse_origin(node1, node2) def _set_flush_status(self, status: bool): @@ -56,54 +95,78 @@ def can_fuse_horizontal(self, node1, node2): _, (vars2, reduce2) = node2.group # Reduction is currently not supported + if node1.is_reduction() and node2.is_reduction() and not node1.is_template() and not node2.is_template(): + return vars1 == vars2 and reduce1 == reduce2 and node1.inverse_users == node2.inverse_users if node1.is_reduction() or node2.is_reduction(): return False # Can't fuse two template node - nr_template = 0 - for node in node1.get_nodes() + node2.get_nodes(): - if node.is_template(): - nr_template += 1 - - if nr_template > 1: + if node1.is_template() and node2.is_template(): return False # Check template node fusion if node1.is_template() or node2.is_template(): # Don't fuse maxpool template code from PyTorchSimFrontend.mlir.mlir_maxpool_template import MLIRMaxPoolTemplate - if node1.is_template() and len(node1.get_nodes())==1 and isinstance(node1.node.template, MLIRMaxPoolTemplate) or \ - node2.is_template() and len(node1.get_nodes())==1 and isinstance(node2.node.template, MLIRMaxPoolTemplate): - return False + from PyTorchSimFrontend.mlir.mlir_bmm_template import MLIRBMMTemplate + from PyTorchSimFrontend.mlir.mlir_gemm_template import MLIRGemmTemplate + template_node1 = next((n for n in node1.get_nodes() if n.is_template()), None) + template_node2 = next((n for n in node2.get_nodes() if n.is_template()), None) - # Different layout is not supported - if node1.get_nodes()[0].node.layout.dtype != node2.get_nodes()[0].node.layout.dtype: + if template_node1 and len(node1.get_nodes()) == 1 and isinstance(template_node1.node.template, MLIRMaxPoolTemplate) or \ + template_node2 and len(node2.get_nodes()) == 1 and 
isinstance(template_node2.node.template, MLIRMaxPoolTemplate): return False - # Convolution is currently not supported - # if node1.is_template() and node1.get_nodes()[0].node.origin_node is not None and hasattr(node1.get_nodes()[0].node.origin_node.target, "_name") and node1.get_nodes()[0].node.origin_node.target._name == 'aten::convolution': - # return False - - # if node2.is_template() and node2.get_nodes()[0].node.origin_node is not None and hasattr(node2.get_nodes()[0].node.origin_node.target, "_name") and node2.get_nodes()[0].node.origin_node.target._name == 'aten::convolution': - # return False - + # Pointwise check v1_total = math.prod(vars1) if len(vars1) else 0 v2_total = math.prod(vars2) if len(vars2) else 0 if v1_total != v2_total: return False - has_depedency = False - template_node = node1 if node1.is_template() else node2 - act_node = node2 if node1.is_template() else node1 - for write_buf in template_node.read_writes.writes: - has_depedency = has_depedency or (write_buf in act_node.read_writes.reads) - return has_depedency + # Pattern check + template_node, act_node = (template_node1, node2) if template_node1 else (template_node2, node1) + has_depedency = set(act_node.inverse_users) <= set(template_node.get_nodes()) + if not has_depedency: + return False + + # Revert act_node.group : simplify_and_reorder() modified _body, _size, group + if template_node.group != act_node.group: + # We don't fuse this case... 
+ if (isinstance(template_node.node.template, MLIRBMMTemplate) or isinstance(template_node.node.template, MLIRGemmTemplate)) and template_node.group[1][0][0] == 1: + return False + + if list(template_node.group[1][0]) != list(act_node.get_nodes()[0].node.data.get_size()): + return False + self.revert_group(act_node) + return True # Check elementwise fusion if vars1 == vars2 and reduce1 == reduce2: return True return False + def revert_group(self, act_nodes, args=None, var_ranges=None): + for act_node in act_nodes.get_nodes(): + if args is None or var_ranges is None: + args, var_ranges = dependencies.index_vars_no_squeeze( + act_node.node.data.get_size(), act_node.node.data.get_reduction_size(), prefix="q" + ) + body = LoopBody( + act_node.node.get_store_function(), + (args if act_node.node.get_reduction_type() else args[:1]), + var_ranges, + ) + index_size = [] + reduce_size = [] + for v, s in var_ranges.items(): + if v in args[0]: + index_size.append(s) + else: + reduce_size.append(s) + node_device = act_node.get_device() + ranges = (index_size, reduce_size) + act_node._sizes, act_node._body, act_node.group = (ranges), body, (node_device, self.group_fn(ranges)) + def group_fn(self, sizes): return tuple(tuple(map(V.graph.sizevars.simplify, s)) for s in sizes) @@ -111,13 +174,34 @@ def codegen_nodes(self, nodes): _, (group, reduction_group) = max( nodes, key=lambda x: int(x.is_reduction()) ).group + + # Note: We assume that ther is at least one loop in the nodes + # But, inductor simplifies the group, there could be no loop + # In that case, we add dummy loop(size=1) to the group + if len(group) == 0: + for idx, node in enumerate(nodes): + if len(node.node.data.get_size()) == 0: + continue + if len(reduction_group) != 0: + sym0, sym1 = sympy.Symbol("q0"), sympy.Symbol("q1") + args = [[sym0] + [sympy.Number(0)] * (len(node.node.data.get_size())-1), [sym1]] + var_ranges = {sym0: sympy.Number(1), sym1: reduction_group[0]} + else: + sym0 = sympy.Symbol("q0") + args = 
[[sym0] + [sympy.Number(0)] * (len(node.node.data.get_size())-1), []] + var_ranges = {sym0: sympy.Number(1)} + self.revert_group(node, args, var_ranges) + _, (group, reduction_group) = max( + nodes, key=lambda x: int(x.is_reduction()) + ).group + ex_kernel = self.target_kernel(kernel_group=self.kernel_group) ex_kernel.kernel_group = self.kernel_group - kernel_name = f"extension_kernel_{MLIRScheduling.count}" + kernel_name_candidate = f"extension_kernel_{MLIRScheduling.count}" MLIRScheduling.count += 1 - src_code = ex_kernel.codegen_nodes(nodes, kernel_name) - self.define_kernel(src_code, kernel_name, ex_kernel.vector_lane, + src_code = ex_kernel.codegen_nodes(nodes, kernel_name_candidate) + kernel_name = self.define_kernel(src_code, kernel_name_candidate, ex_kernel.vector_lane, ex_kernel.spad_info, origins= {str(i) for i in nodes[0].node.origins}) ex_kernel.call_kernel(kernel_name) _, args, _, _ = ex_kernel.args.mlir_argdefs() @@ -167,45 +251,109 @@ def define_kernel(self, src_code, kernel_name, vector_lane, spad_info, loop_size wrapper.define_kernel(kernel_name, codecache_def.getvalue(), cuda=False) return kernel_name - def codegen_template_code(self, kernel, render, template_node, epilogue_nodes): + def codegen_template_code(self, kernel, render, template_node, prologue_nodes, epilogue_nodes): with kernel: - for node in [template_node, *epilogue_nodes]: + _, _, _, kernel.buffer_types = self.kernel_group.args.mlir_argdefs() + for node in [template_node, *prologue_nodes, *epilogue_nodes]: node.mark_run() + # Partial codgen template nodes partial_code = render() - tile_desc = kernel.set_tile_size(kernel.store_info) + + # Swap load/store functions + kernel.load = kernel.load_epilogue + kernel.store = kernel.store_epilogue + kernel.store_reduction = kernel.store_reduction_epilogue + kernel.reduction = kernel.reduction_epilogue + + # Codegen prologue nodes + if prologue_nodes: + # Flush created varaibles, since template fusion doen't share variable + with 
kernel.prologue_buffer_group.as_local(): + _, (group, reduction_group) = max( + [prologue_nodes[-1]], key=lambda x: int(x.is_reduction()) + ).group + prologue_tile_desc = kernel.set_tile_size(kernel.prologue_info, prologue=True) + kernel.kernel_group.set_tile_info(prologue_tile_desc) + vars, reduction_vars = kernel.set_ranges(group, reduction_group) + for node in prologue_nodes: + # Reuse created spad + read_list = sorted(list(node.read_writes.reads)) + candidate_found = False + # Why? There is a case that memdep.get_size() != data.get_size() + buf_dict = {} + buf_dict.update({val.name : val for val in V.graph.buffers}) + buf_dict.update(V.graph.graph_inputs) + for candidate_read in read_list: + if candidate_read.name in buf_dict and reduce(operator.mul, buf_dict[candidate_read.name].get_size(), 1) == node.node.get_numel(): + prologue_input_arg = candidate_read.name + candidate_found = True + break + assert(candidate_found) + assert(len(node.read_writes.writes)==1) + prologue_output_arg = list(node.read_writes.writes)[0].name + template_buf = self.kernel_group.args.input_buffers[prologue_output_arg] + target_buf = f"{template_buf}_buffer" # FIXME. How to pass spad buffer name? 
+ + # To skip the dma code gen + kernel.buffer_names[prologue_input_arg] = target_buf + kernel.buffer_names[prologue_output_arg] = target_buf + + # Edge delete + kernel.kernel_group.args.input_buffers = { + (arg if buf != template_buf else prologue_input_arg): buf + for arg, buf in kernel.kernel_group.args.input_buffers.items() + } + node.codegen((vars, reduction_vars)) + + # Codegen epilogue nodes + tile_desc = kernel.set_tile_size(kernel.epilogue_info) kernel.kernel_group.set_tile_info(tile_desc) + kernel.call_ranges = None if epilogue_nodes: - _, (group, reduction_group) = max( - epilogue_nodes, key=lambda x: int(x.is_reduction()) - ).group - vars, reduction_vars = kernel.set_ranges(group, reduction_group) - # Flush created varaibles, since template fusion doen't share variable - kernel.cse.cache.clear() - for node in epilogue_nodes: - if template_node.node.name in [dep[0] for dep in list(node.read_writes.reads)]: - kernel.store_info['dependent_buf'].append(node.node.name) - node.codegen((vars, reduction_vars)) + with kernel.epilogue_buffer_group.as_local(): + _, (group, reduction_group) = max( + epilogue_nodes, key=lambda x: int(x.is_reduction()) + ).group + vars, reduction_vars = kernel.set_ranges(group, reduction_group) + for node in epilogue_nodes: + node.codegen((vars, reduction_vars)) + with V.set_kernel_handler(kernel): src_code = ( partial_code if isinstance(partial_code, str) else partial_code.finalize() ) - return src_code + + # For consistency, white space could make wrong write_path + buffer = IndentedBuffer() + buffer.splice(src_code) + return buffer.getvalue() def codegen_template(self, template_node, epilogue_nodes): + # Handle prologue pattern + prologue_nodes = [] + if not template_node.is_template(): + epilogue_nodes = [template_node] + epilogue_nodes + for i, node in enumerate(epilogue_nodes): + if node.is_template(): + template_node = node + prologue_nodes = epilogue_nodes[:i] + epilogue_nodes = epilogue_nodes[i+1:] + break + _, (numel, 
rnumel) = template_node.group template_buffer = template_node.node - kernel, render, codegen_header = template_buffer.make_kernel_render(template_buffer, epilogue_nodes=epilogue_nodes, kernel_group=self.kernel_group) + kernel, render, codegen_header = template_buffer.make_kernel_render(template_buffer, prologue_nodes=prologue_nodes, epilogue_nodes=epilogue_nodes, kernel_group=self.kernel_group) _, _, _, kernel.buffer_types = self.kernel_group.args.mlir_argdefs() - src_code = self.codegen_template_code(kernel, render, template_node, epilogue_nodes) + src_code = self.codegen_template_code(kernel, render, template_node, prologue_nodes, epilogue_nodes) wrapper = V.graph.wrapper_code if src_code in wrapper.src_to_kernel: # [CONV] check inner function is already defined kernel_name = wrapper.src_to_kernel[src_code] - kernel, render, codegen_header = template_buffer.make_kernel_render(template_buffer, epilogue_nodes=epilogue_nodes, kernel_name=kernel_name) # update kernel name - src_code = self.codegen_template_code(kernel, render, template_node, epilogue_nodes) + kernel, render, codegen_header = template_buffer.make_kernel_render(template_buffer, prologue_nodes=prologue_nodes, epilogue_nodes=epilogue_nodes, kernel_name=kernel_name) # update kernel name + src_code = self.codegen_template_code(kernel, render, template_node, prologue_nodes, epilogue_nodes) with V.set_kernel_handler(kernel): spad_end_symbol = f"int spad_end[0] __attribute__ ((section(\".spad\")));\n" @@ -226,4 +374,15 @@ def codegen_template(self, template_node, epilogue_nodes): V.graph.wrapper_code.writeline( f"yield ({target_kernel_name}, ({args}))" ) - self._set_flush_status(True) \ No newline at end of file + self._set_flush_status(True) + + def enter_context_fixed(self, node): + def get_order(n): + if n not in self.scheduler.origin_to_index: + self.scheduler.origin_to_index.update({n: i for i, n in enumerate(n.graph.nodes)}) + return self.scheduler.origin_to_index[n] + + origins = [(get_order(e), idx, 
e) for n in node.get_nodes() for idx, e in enumerate(n.node.origins)] + if origins: + _, _, last = max(origins) + V.graph.wrapper_code.enter_context(last) diff --git a/PyTorchSimFrontend/mlir/mlir_template.py b/PyTorchSimFrontend/mlir/mlir_template.py index a0537201..0455cbf1 100644 --- a/PyTorchSimFrontend/mlir/mlir_template.py +++ b/PyTorchSimFrontend/mlir/mlir_template.py @@ -2,6 +2,7 @@ import itertools import textwrap import re +import os import contextlib import math import sympy @@ -11,7 +12,7 @@ from unittest.mock import patch from torch._inductor.codegen.common import Kernel, KernelTemplate, ChoiceCaller, OpOverrides, CSE, DeferredLine -from torch._inductor.ir import Buffer, IRNode, TemplateBuffer, Pointwise +from torch._inductor.ir import Buffer, IRNode, TemplateBuffer, View from torch._inductor.select_algorithm import PartialRender from torch._inductor.codegen.cuda.cuda_kernel import CUDATemplateCaller from torch._inductor.autotune_process import TensorMeta @@ -22,9 +23,68 @@ from PyTorchSimFrontend.mlir.mlir_common import BaseMLIRHardwareInfo from PyTorchSimFrontend.mlir.mlir_codegen_backend import MLIRKernel, reduction_init, reduction_partial_combine_vec, reduction_combine_vec, is_welford_reduction from PyTorchSimFrontend.mlir.mlir_scheduling import SchedulerNode +from torch._inductor.codegen import common +from PyTorchSimFrontend.extension_config import CONFIG_TORCHSIM_DIR from . 
import mlir_common +class IndentedBufferGroup: + def __init__(self, kernel: 'MLIRTemplateKernel', prefix=""): + self.kernel = kernel + self.body = IndentedBuffer() + self.loads = IndentedBuffer() + self.compute = IndentedBuffer() + self.stores = IndentedBuffer() + self.applys = IndentedBuffer() + self.dma_loads = IndentedBuffer() + self.dma_stores = IndentedBuffer() + self.spad_buffer = IndentedBuffer() + self.cse = common.CSE("%", "", name_prefix=f"{prefix}") + self.apply_cse = common.CSE("%", "", name_prefix=f"{prefix}apply") + # Original buffers will be saved later in the 'with' block + self.original_buffers = {} + + def set_buffers(self): + self.kernel.loads = self.loads + self.kernel.compute = self.compute + self.kernel.stores = self.stores + self.kernel.applys = self.applys + self.kernel.dma_loads = self.dma_loads + self.kernel.dma_stores = self.dma_stores + self.kernel.spad_buffer = self.spad_buffer + self.kernel.cse = self.cse + self.kernel.apply_cse = self.apply_cse + + def restore_buffers(self): + self.kernel.loads = self.original_buffers['loads'] + self.kernel.compute = self.original_buffers['compute'] + self.kernel.stores = self.original_buffers['stores'] + self.kernel.applys = self.original_buffers['applys'] + self.kernel.dma_loads = self.original_buffers['dma_loads'] + self.kernel.dma_stores = self.original_buffers['dma_stores'] + self.kernel.spad_buffer = self.original_buffers['spad_buffer'] + self.kernel.cse = self.original_buffers['cse'] + self.kernel.apply_cse = self.original_buffers['apply_cse'] + + @contextlib.contextmanager + def as_local(self): + self.original_buffers = { + 'loads': self.kernel.loads, + 'compute': self.kernel.compute, + 'stores': self.kernel.stores, + 'applys': self.kernel.applys, + 'dma_loads': self.kernel.dma_loads, + 'dma_stores': self.kernel.dma_stores, + 'spad_buffer': self.kernel.spad_buffer, + 'cse': self.kernel.cse, + 'apply_cse': self.kernel.apply_cse, + } + try: + self.set_buffers() + yield self + finally: + 
self.restore_buffers() + class MLIRTemplateKernel(MLIRKernel, BaseMLIRHardwareInfo): def __init__(self, kernel_name, @@ -40,8 +100,6 @@ def __init__(self, self.call_size = call_size self.named_nodes = {} self.loop_info = {} - self.load_desc = {} - self.store_desc = {} self.outer_func_name = outer_func_name self.outer_func_render = outer_func_render self.kernel_arg_attributes = kernel_arg_attributes @@ -50,19 +108,23 @@ def __init__(self, self.render_options = dict() self.tile_size = [] self.loop_size = None - self.is_template_kernel = True - self.map_cse = CSE("#", self.suffix, name_prefix="template_map") - self.const_cse = CSE(self.newvar_prefix, self.suffix, name_prefix="template_const") - self.alloc_cse = CSE(self.newvar_prefix, self.suffix, name_prefix="template_alloc") + self.map_cse = CSE("#", self.suffix, name_prefix="t_map") + self.const_cse = CSE(self.newvar_prefix, self.suffix, name_prefix="t_const") + self.alloc_cse = CSE(self.newvar_prefix, self.suffix, name_prefix="t_alloc") + self.prologue_buffer_group = IndentedBufferGroup(self, prefix="prologue_") + self.epilogue_buffer_group = IndentedBufferGroup(self, prefix="epilogue_") + self.global_vars = IndentedBuffer() + self.exception_nodes = {} + # Reduction data structure self.reduction_epilogue_suffix = IndentedBuffer() self.reduction_fusion = False - self.reduction_idx = None - - # Overwrite ops - self.load = self.load_epilogue - self.store = self.store_epilogue - self.store_reduction = self.store_reduction_epilogue - self.reduction = self.reduction_epilogue + self.reduction_body_loop = None + self.reduction_buffer_idx = 0 + self.reduction_info = {} + self.reduction_epilogue_result = {} + self.reduction_mean = [] + # Dim info + self.dim_aliasing = {} def add_loop_info(self, mat_size, tile_size): for idx, (loop_size, stride) in enumerate(zip(mat_size, tile_size)): @@ -123,13 +185,12 @@ def gemmini_gemm_mapping(self, M, N, K): return inner_I, inner_J, inner_K - def gemm_combination_mapping(self, M, N, K, 
n_extra_node=0, pad_k=True, min_tile=False): + def gemm_combination_mapping(self, M, N, K, n_extra_node=0, n_prologue_node=0, pad_k=True, min_tile=False): spad_size_per_lane = self.spad_info["spad_size"] spad_size = spad_size_per_lane * self.vector_lane max_spad_size = spad_size // 2 # double buffer max_spad_per_lane = spad_size_per_lane // 2 # double buffer - force_double_buffer = 2 if n_extra_node > 0 else 1 # In fusion case, double buffer should be forced - minimum_n_tile = self.num_cores * force_double_buffer if min_tile else 1 + minimum_n_tile = self.num_cores if min_tile else 1 m_pad_factor = self.vector_lane if M > self.vector_lane else 8 n_pad_factor = self.vector_lane if N > self.vector_lane else 8 k_pad_factor = self.vector_lane if K > self.vector_lane else (8 if pad_k else 1) @@ -145,19 +206,46 @@ def gemm_combination_mapping(self, M, N, K, n_extra_node=0, pad_k=True, min_tile tile_N_range = sympy.divisors(indexJ) if N > self.vector_lane else [1] tile_K_range = sympy.divisors(indexK) if K > self.vector_lane else [1] maximize_i_j = 1 # reuse weight - for k in tile_K_range: + for k in tile_K_range: # store tile candidates for manual mapping + tile_K = k * self.vector_lane if K > self.vector_lane else K_padded + for i in tile_M_range: + tile_M = i * self.vector_lane if M > self.vector_lane else M_padded + for j in tile_N_range: + tile_N = j * self.vector_lane if N > self.vector_lane else N_padded + used_spad_size = (tile_M * tile_K * (1 + n_prologue_node) + tile_K * tile_N + tile_M * tile_N * (1 + n_extra_node)) * self.precision + weight_size_per_lane = self.get_spad_size_per_lane(tile_K, tile_N) + input_size_per_lane = self.get_spad_size_per_lane(tile_M * (1 + n_prologue_node), tile_K) + output_size_per_lane = self.get_spad_size_per_lane(tile_M * (1 + n_extra_node), tile_N) + used_spad_size_per_lane = (weight_size_per_lane + input_size_per_lane + output_size_per_lane) * self.precision + check_spad_size = (used_spad_size < max_spad_size and 
used_spad_size_per_lane < max_spad_per_lane) + if check_spad_size: + dir_path = f"{CONFIG_TORCHSIM_DIR}/validation/gemm_candidates" + os.makedirs(dir_path, exist_ok=True) + file_path = f"{dir_path}/gemm_{M}_{K}_{N}.txt" + line_to_write = f"{tile_M} {tile_K} {tile_N}\n" + try: + with open(file_path, "r") as f: + lines = f.readlines() + except FileNotFoundError: + lines = [] + if line_to_write not in lines: + with open(file_path, "a") as f: + f.write(line_to_write) + + for k in tile_K_range: # heuristic search tile_K = k * self.vector_lane if K > self.vector_lane else K_padded for i in tile_M_range: tile_M = i * self.vector_lane if M > self.vector_lane else M_padded for j in tile_N_range: tile_N = j * self.vector_lane if N > self.vector_lane else N_padded - used_spad_size = (tile_M * tile_K + tile_K * tile_N + tile_M * tile_N * (1 + n_extra_node)) * self.precision + used_spad_size = (tile_M * tile_K * (1 + n_prologue_node) + tile_K * tile_N + tile_M * tile_N * (1 + n_extra_node)) * self.precision weight_size_per_lane = self.get_spad_size_per_lane(tile_K, tile_N) - input_size_per_lane = self.get_spad_size_per_lane(tile_M, tile_K) + input_size_per_lane = self.get_spad_size_per_lane(tile_M * (1 + n_prologue_node), tile_K) output_size_per_lane = self.get_spad_size_per_lane(tile_M * (1 + n_extra_node), tile_N) used_spad_size_per_lane = (weight_size_per_lane + input_size_per_lane + output_size_per_lane) * self.precision - n_tile = math.ceil(M / tile_M) * math.ceil(N / tile_N) - if used_spad_size < max_spad_size and max_used_spad_size < used_spad_size and used_spad_size_per_lane < max_spad_per_lane and maximize_i_j <= tile_M * tile_N and n_tile >= minimum_n_tile: + n_tile = math.ceil(M / max(tile_M, 128)) * math.ceil(N / max(tile_N, 128)) + check_spad_size = (used_spad_size < max_spad_size and used_spad_size_per_lane < max_spad_per_lane) + if check_spad_size and max_used_spad_size < used_spad_size and maximize_i_j <= tile_M * tile_N and n_tile >= minimum_n_tile and 
max(tile_N, 128) // max(tile_M, 128) < 10: max_used_spad_size = used_spad_size maximize_i_j = tile_M * tile_N mapping = (tile_M, tile_N, tile_K) @@ -309,8 +397,6 @@ def meta_kernel(self): wrapper.add_import_once('\nprint(f\'Wrapper Codegen Path = {__file__}\')') # Dump loop and load/store information wrapper.add_import_once(f"loop_info = {self.loop_info}") - wrapper.add_import_once(f"load_tile_info = {self.load_desc}") - wrapper.add_import_once(f"store_tile_info = {self.store_desc}") wrapper.add_import_once(f"arg_attributes = {arg_attributes}") def call_kernel(self, kernel_name): @@ -321,43 +407,65 @@ def call_kernel(self, kernel_name): kernel_name if self.outer_func_name is None else self.outer_func_name + f"_{len(call_args)}", call_args, cuda=False) - def codegen_body(self): + def codegen_prologue_body(self): + body = IndentedBuffer() + with self.prologue_buffer_group.as_local(): + body.splice(self.spad_buffer) + body.splice(self.applys) + body.splice(self.dma_loads) + + if (self.loads.getvalue() != '' or self.compute.getvalue() != '' or self.stores.getvalue() != ''): + body.writelines(self.prologue_compute_body_loop.lines()) + compute_body = mlir_common.ParallelLoopBuffer() + with contextlib.ExitStack() as stack: + stack.enter_context(compute_body.indent(attribute="{inner_loop=false}")) + compute_body.splice(self.loads) + compute_body.splice(self.compute) + compute_body.splice(self.stores) + body.splice(compute_body) + body.splice(self.dma_stores) + return body + + def codegen_epilogue_body(self): def template_store(): - zero_cse = self.get_const_cse(0) - sram_var = self.store_info["sram_var"] - dram_var = self.store_info["dram_var"] - index_var = self.store_info["index_var"] - tag_var = self.store_info["tag_var"] - mlir_dtype = self.store_info["mlir_dtype"] - dram_shape = self.store_info["dram_shape"] - vlane_split_axis = self.kernel_group.tile_desc.vlane_split_axis - vlane_stride = self.kernel_group.tile_desc.get_vlane_stride() - tile_stride = 
self.store_info["tile_stride"] - tile_shape = self.kernel_group.tile_desc.get_mlir_shape(mlir_dtype) - sram_index_var = ",".join([f"%{zero_cse}"] * self.kernel_group.tile_desc.get_nr_dim()) - code = self.get_dma_code("MVOUT", vlane_split_axis, vlane_stride, mlir_dtype, dram_var, index_var, sram_var, sram_index_var, - tag_var, dram_shape, tile_shape, tile_stride) + dram_var = self.epilogue_info["dram_var"] + index_list = self.epilogue_info["dram_idx"] + tile_desc = self.epilogue_info["dram_tile_desc"] + code = self.def_dma_op("MVOUT", dram_var, index_list, tile_desc) self.cse.generate(self.dma_stores, code, assignment = False) - self.body.splice(self.spad_buffer) - self.body.splice(self.applys) - self.body.splice(self.dma_loads) - self.body.writelines(self.compute_body_loop.lines()) - compute_body = mlir_common.ParallelLoopBuffer() - with contextlib.ExitStack() as stack: - stack.enter_context(compute_body.indent(attribute="{inner_loop=false}",suffix=self.compute_body_loop.epilogue_line())) - compute_body.splice(self.loads) - compute_body.splice(self.compute) - if len(self.stores._lines) == 0: - template_store() - compute_body.splice(self.stores) - self.body.splice(compute_body) - self.body.splice(self.dma_stores) - self.body.splice(self.reduction_epilogue_suffix) - - # Clear buffers - self.loads.clear() - self.compute.clear() - self.stores.clear() + + body = IndentedBuffer() + with self.epilogue_buffer_group.as_local(): + # Do dma store first to overlap epilogue nodes + if self.reduction_fusion: + if len(self.stores._lines) == 0: + template_store() + body.splice(self.dma_stores) + self.dma_stores.clear() + body.splice(self.spad_buffer) + body.splice(self.applys) + body.splice(self.dma_loads) + body.writelines(self.compute_body_loop.lines()) + compute_body = mlir_common.ParallelLoopBuffer() + with contextlib.ExitStack() as stack: + stack.enter_context(compute_body.indent(attribute="{inner_loop=false}",suffix=self.compute_body_loop.epilogue_line())) + if 
self.reduction_fusion: + compute_body.writelines(self.reduction_body_loop.lines()) + compute_body.splice(self.masks) + stack.enter_context(compute_body.indent(attribute="{inner_loop=false}")) + compute_body.splice(self.loads) + compute_body.splice(self.compute) + else: + compute_body.splice(self.loads) + compute_body.splice(self.compute) + if len(self.stores._lines) == 0: + template_store() + compute_body.splice(self.stores) + if (compute_body.getvalue()): + body.splice(compute_body) + body.splice(self.dma_stores) + body.splice(self.reduction_epilogue_suffix) + return body def def_kernel( self, @@ -394,7 +502,7 @@ def def_kernel( extra_node[node.get_name()] = node.node else: extra_node[node.get_name()] = node - self.buffer_names[node.get_name()] = self.store_info['sram_var'] + self.buffer_names[node.get_name()] = self.epilogue_info['sram_var'] def hook(): arg_defs, *_ = self.kernel_group.args.mlir_argdefs(extra_node=extra_node) @@ -439,7 +547,7 @@ def def_conv_kernel( self.kernel_group.args.output_buffers[node.get_name()] = name self.store_buffer_names.add(node.get_name()) #TODO: Is this enough not calling store() in mlir_common.py? 
self.extra_node[node.get_name()] = node - self.buffer_names[node.get_name()] = self.store_info['sram_var'] #TODO: Buffer name fixed + self.buffer_names[node.get_name()] = self.epilogue_info['sram_var'] #TODO: Buffer name fixed def kernel_hook(): arg_defs, *_ = self.kernel_group.args.mlir_argdefs(extra_node=self.extra_node) @@ -467,22 +575,42 @@ def get_conv_inputs(self): def get_conv_outputs(self): return {k: v for k, v in self.kernel_group.args.output_buffers.items() if v != 'REMOVED'} - def output_name(self): - # Cannot know the output name from the template, so we need to hook it + def load_input(self, indent_size: int = 0): def hook(): - arg_defs, *_ = self.kernel_group.args.mlir_argdefs() - output = arg_defs[3] #FIXME: Constant index used - pattern = r"%(\w+):" - output = re.search(pattern, output).group(1) - return output - assert "" not in self.render_hooks - self.render_hooks[""] = hook - return "" + code = IndentedBuffer() + prologue_code = self.codegen_prologue_body() + if prologue_code.getvalue(): + input_dma_code = self.def_dma_op("MVIN", self.prologue_info["input_dram_var"], self.prologue_info["input_idx"], + self.prologue_info["input_tile_desc"], subtile_size=self.prologue_info["input_subtile_size"], async_type=False) + weight_dma_code = self.def_dma_op("MVIN", self.prologue_info["weight_dram_var"], self.prologue_info["weight_idx"], + self.prologue_info["weight_tile_desc"], subtile_size=self.prologue_info["weight_subtile_size"], async_type=False) + if (self.prologue_info["is_input_fused"]): + code.splice(input_dma_code) + code.splice(prologue_code) + code.splice(weight_dma_code) + else: + code.splice(weight_dma_code) + code.splice(prologue_code) + code.splice(input_dma_code) + else: + dma_code = self.def_dma_op("MVIN", self.prologue_info["input_dram_var"], self.prologue_info["input_idx"], + self.prologue_info["input_tile_desc"], subtile_size=self.prologue_info["input_subtile_size"], async_type=False) + code.splice(dma_code) + dma_code = 
self.def_dma_op("MVIN", self.prologue_info["weight_dram_var"], self.prologue_info["weight_idx"], + self.prologue_info["weight_tile_desc"], subtile_size=self.prologue_info["weight_subtile_size"], async_type=False) + code.splice(dma_code) + code = textwrap.indent(code.getvalue(), " "*indent_size).strip() + return code + + assert "" not in self.render_hooks + self.render_hooks[""] = hook + self.render_hooks.move_to_end("", last=False) # Force order to be triggered first + return "" def store_output(self, indent_size: int = 0): def hook(): - self.codegen_body() - return textwrap.indent(self.body.getvalue(), " "*indent_size).strip() + epilogue_code = self.codegen_epilogue_body() + return textwrap.indent(epilogue_code.getvalue(), " "*indent_size).strip() assert "" not in self.render_hooks self.render_hooks[""] = hook @@ -497,29 +625,6 @@ def hook(): self.render_hooks[""] = hook return "" - def reduction_iter_arg(self): - def hook(): - if len(self.reduction_vars): - args = ', '.join([f"%{iter.name} = %{init.name}" for (_, iter, init, _) in self.reduction_vars.values()]) - dtype = ', '.join([f"{dtype}" for (_, _, _, dtype) in self.reduction_vars.values()]) - return f"iter_args({args}) -> ({dtype})" - return "" - - assert "" not in self.render_hooks - self.render_hooks[""] = hook - return "" - - def reduction_acc(self): - def hook(): - if len(self.reduction_vars): - acc = ', '.join([f"%{acc.name}" for acc in self.reduction_vars.keys()]) - return f"{acc} =" - return "" - - assert "" not in self.render_hooks - self.render_hooks[""] = hook - return "" - def def_function(self): _, call_args, _ = self.kernel_group.args.python_argdefs() if self.outer_func_render is not None: @@ -540,26 +645,75 @@ def hook(): self.render_hooks[key] = hook return key - def def_local_vars(self): + def def_local_vars(self, indent_size=0): key = "" def hook(): code = IndentedBuffer() - code.tabwidth = 2 - code.splice("\n") - with code.indent(): - code.splice(self.const_buffer) - 
code.splice(self.alloc_buffer) - return code.getvalue() + code.tabwidth = 1 + code.splice(self.const_buffer) + code.splice(self.alloc_buffer) + return textwrap.indent(code.getvalue(), " "*indent_size).strip() assert key not in self.render_hooks self.render_hooks[key] = hook return key + def def_dma_op(self, dma_type, dram_var:str, index_list:list, tile_desc:mlir_common.MLIRMultiDimTile, + subtile_size:list=[], async_type=None, indent_size=0): + # Prepare code block + local_code = IndentedBuffer() + with V.set_kernel_handler(self): + index_var = self.parse_index_list(index_list, local_code) + node_layout = self.named_nodes[dram_var].get_layout() + numel = self.get_arg_info(self.named_nodes[dram_var].get_name()).get_numel() + if dram_var in self.exception_nodes: + numel = self.exception_nodes[dram_var]["numel"] + mlir_dtype = mlir_common.DTYPE_TO_MLIR[node_layout.dtype] + dram_shape = f"memref<{numel}x{mlir_dtype}>" + dram_stride = [] + for idx in index_list: + if idx.is_Mul: + dram_stride.append(int(idx.args[0])) + elif idx == sympy.Symbol("c0"): + dram_stride.append(0) + elif not idx.is_Number: + dram_stride.append(1) + else: + dram_stride.append(0) + + sram_var = tile_desc.get_name() + tile_shape = tile_desc.get_mlir_shape(mlir_dtype) + tile_stride = tile_desc.get_tile_stride() + vlane_split_axis = tile_desc.vlane_split_axis + vlane_stride = tile_desc.vlane_stride + + zero_cse = self.get_const_cse(0, "index") + sram_index_var = ", ".join([f"%{str(zero_cse)}"]*tile_desc.get_nr_dim()) + + attribute_parts = [f"dram_stride={dram_stride}", f"sram_stride={tile_stride}", "padding=0"] + if subtile_size: + attribute_parts.append(f"subtile_size={subtile_size}, async={int(async_type) if async_type is not None else 1}") + attribute = " {" + ", ".join(attribute_parts) + "}" + code = self.get_dma_code(dma_type, vlane_split_axis, vlane_stride, mlir_dtype, dram_var, index_var, sram_var, sram_index_var, + dram_shape, tile_shape, "") + local_code.writeline(code) + 
local_code.writeline(attribute) + return textwrap.indent(local_code.getvalue(), " "*indent_size).strip() + + def def_sram_buffer(self, dram_name, tile_desc, id=0, indent_size=0): + # Prepare code block + with V.set_kernel_handler(self): + dtype = self.named_nodes[dram_name].get_layout().dtype + tile_shape = tile_desc.get_mlir_shape(mlir_common.DTYPE_TO_MLIR[dtype]) + buffer_name = self.allocate_sram_buffer(dtype, dram_name, tile_desc, id, forced_name=dram_name) + code = f"%{tile_desc.name} = memref.get_global @{buffer_name} : {tile_shape}" + return textwrap.indent(code, " "*indent_size).strip() + def render(self, template, kwargs, define_function=None): - # self.render_hooks = {} code = template.render(**kwargs) if define_function is not None: define_function(self) + return PartialRender( code, self.render_hooks, @@ -570,19 +724,19 @@ def get_spad_size_per_lane(self, tile_m, tile_n): return max(size, 2) # vector load/store def load_epilogue(self, name: str, index: sympy.Expr): - load_dim = [] - if not isinstance(V.graph, NullHandler) and name in V.graph.graph_inputs: - load_dim = V.graph.graph_inputs[name].layout.size - index_var = self.store_info['index_var'] if len(load_dim) != 1 else 'tile_n' index = self.rename_indexing(index) dram_var = self.kernel_group.args.input(name) + dram_shape = mlir_common.MLIRKernelArgs.get_mlir_shape(self.buffer_types[name]) dtype = V.graph.get_dtype(name) mlir_dtype = mlir_common.DTYPE_TO_MLIR[dtype] - vlane_split_axis = self.kernel_group.tile_desc.vlane_split_axis if len(load_dim) != 1 else 0 # FIXME: Fixed split axis for 1d load dim - vlane_stride = self.kernel_group.tile_desc.vlane_stride if len(load_dim) != 1 else 1 # FIXME: Fixed stride for 1d load dim - tile_numel_per_lane = self.kernel_group.tile_desc.get_numel_per_lane() + + # Want to use tile_desc from epilogue_info + index_var = self.parse_indices(index) + dram_stride = [index.coeff(sympy.Symbol(val)) for val in self.dim_aliasing.values()] + vlane_split_axis = 
self.kernel_group.tile_desc.vlane_split_axis + vlane_stride = self.kernel_group.tile_desc.vlane_stride tile_shape = self.kernel_group.tile_desc.get_mlir_shape(mlir_dtype) - tile_stride = self.store_info['tile_stride'] + tile_stride = self.kernel_group.tile_desc.get_tile_stride() # Compute vector unit size vshape = self.kernel_group.tile_desc.get_mlir_vshape(mlir_dtype) @@ -591,14 +745,16 @@ def load_epilogue(self, name: str, index: sympy.Expr): if name not in self.buffer_names: # Allocate sram buffer dram_shape = mlir_common.MLIRKernelArgs.get_mlir_shape(self.buffer_types[name]) - sram_var, index_var, sram_index_var = self.get_scratchpad_buffer(dtype, name, tile_numel_per_lane, tile_shape, index_var, index) - self.buffer_names[name] = sram_var + sram_var, sram_index_var = self.get_scratchpad_buffer(dtype, name, self.kernel_group.tile_desc, index) + attribute = f"{{dram_stride={dram_stride}, sram_stride={tile_stride}, padding=0}}" code = self.get_dma_code("MVIN", vlane_split_axis, vlane_stride, mlir_dtype, dram_var, index_var, sram_var, sram_index_var, - f"{name}_tag", dram_shape, tile_shape, tile_stride) + dram_shape, tile_shape, attribute) self.cse.generate(self.dma_loads, code, assignment = False) + self.buffer_names[name] = sram_var + else: + sram_var = self.buffer_names[name] # Load vector from sram - sram_var = self.buffer_names[name] zero_var = self.get_const_cse(0) if not self.reduction_fusion: compute_index_var = ",".join([f"%{zero_var}"] * (self.kernel_group.tile_desc.get_nr_dim()-1) + [f"%{self.compute_idx}"]) @@ -609,56 +765,51 @@ def load_epilogue(self, name: str, index: sympy.Expr): operation = "affine.load" line = f"{operation} %{sram_var}[{compute_index_var}] : {tile_shape}" out = self.cse.generate(self.loads, line) + self.register_var_info(out, [compute_vec_size, mlir_dtype]) else: # For reduction case reduce_size = self.reduction_nr_outer_loop vsize = compute_vec_size//reduce_size vshape = f"vector<{vsize}x{mlir_dtype}>" - flatten_tshape = 
f"vector<{compute_vec_size}x{mlir_dtype}>" - init = self.cse.generate(self.loads, f"arith.constant 0.0 : {mlir_dtype}") - init_vec = self.cse.generate(self.loads, f"vector.broadcast %{init} : {mlir_dtype} to {flatten_tshape}") if compute_vec_size > 1: - out_list = [] - for i in range(reduce_size): - offset = self.cse.generate(self.loads, f"affine.apply affine_map<(d0) -> (d0 + {i*(self.reduction_axis_size)})>(%{self.compute_idx})") - compute_index_var = ",".join([f"%{zero_var}"] * (self.kernel_group.tile_desc.get_nr_dim()-1) + [f"%{offset}"]) - operation = "affine.vector_load" - line = f"{operation} %{sram_var}[{compute_index_var}] : {tile_shape}, {vshape}" - out = self.cse.generate(self.loads, line) - out_list.append(out) - for idx, partial_out in enumerate(out_list): - init_vec = self.cse.generate(self.loads, f"vector.insert_strided_slice %{partial_out}, %{init_vec} {{offsets=[{vsize*idx}],strides=[1]}} : {vshape} into {flatten_tshape}") - out = init_vec + offset = self.cse.generate(self.loads, f"affine.apply affine_map<(d0, d1) -> (d0 + d1*{(self.reduction_axis_size)})>(%{self.compute_idx}, %{self.reduction_loop_idx})") + compute_index_var = ",".join([f"%{zero_var}"] * (self.kernel_group.tile_desc.get_nr_dim()-1) + [f"%{offset}"]) + operation = "affine.vector_load" + line = f"{operation} %{sram_var}[{compute_index_var}] : {tile_shape}, {vshape}" + out = self.cse.generate(self.loads, line) else: line = f"{operation} %{sram_var}[{compute_index_var}] : {tile_shape}" out = self.cse.generate(self.loads, line) - self.register_var_info(out, [compute_vec_size, mlir_dtype]) + self.register_var_info(out, [self.compute_body_loop.step, mlir_dtype]) return out def store_epilogue(self, name: str, index: sympy.Expr, value, *args, **kwargs): - index_var = self.store_info['index_var'] + index = self.rename_indexing(index) dram_var = self.kernel_group.args.output(name) + dram_shape = mlir_common.MLIRKernelArgs.get_mlir_shape(self.buffer_types[name]) dtype = 
V.graph.get_dtype(name) mlir_dtype = mlir_common.DTYPE_TO_MLIR[dtype] + + index_var = self.parse_indices(index) + dram_stride = [index.coeff(sympy.Symbol(val)) for val in self.dim_aliasing.values()] vlane_split_axis = self.kernel_group.tile_desc.vlane_split_axis vlane_stride = self.kernel_group.tile_desc.vlane_stride - tile_numel_per_lane = self.kernel_group.tile_desc.get_numel_per_lane() - - dram_shape = mlir_common.MLIRKernelArgs.get_mlir_shape(self.buffer_types[name]) tile_shape = self.kernel_group.tile_desc.get_mlir_shape(mlir_dtype) - tile_stride = self.store_info['tile_stride'] + tile_stride = self.kernel_group.tile_desc.get_tile_stride() # Compute vector unit size vshape = self.kernel_group.tile_desc.get_mlir_vshape(mlir_dtype) compute_vec_size = self.kernel_group.tile_desc.get_compute_vec_size() if name not in self.buffer_names: - sram_var, index_var, sram_index_var = self.get_scratchpad_buffer(dtype, name, tile_numel_per_lane, tile_shape, index_var, index) + sram_var, sram_index_var = self.get_scratchpad_buffer(dtype, name, self.kernel_group.tile_desc, index) self.buffer_names[name] = sram_var + store_force = False else: zero_cse = self.get_const_cse(0) sram_dims = len(tile_shape.split("x")) - 1 sram_index_var = ",".join([f"%{zero_cse}"] * sram_dims) + store_force = True sram_var = self.buffer_names[name] zero_var = self.get_const_cse(0) @@ -673,174 +824,213 @@ def store_epilogue(self, name: str, index: sympy.Expr, value, *args, **kwargs): else: operation = "affine.store" line = f"{operation} %{value}, %{sram_var}[{compute_index_var}] : {tile_shape}" - self.stores.writeline(DeferredLine(name, line)) + line = line if store_force else DeferredLine(name, line) + self.stores.writeline(line) # Generate DMA instruction + attribute = f"{{dram_stride={dram_stride}, sram_stride={tile_stride}, padding=0}}" code = self.get_dma_code("MVOUT", vlane_split_axis, vlane_stride, mlir_dtype, dram_var, index_var, sram_var, sram_index_var, - f"{name}_tag", dram_shape, 
tile_shape, tile_stride) + dram_shape, tile_shape, attribute) self.dma_stores.writeline(DeferredLine(name, code)) def reduction_epilogue(self, dtype, src_dtype, reduction_type, value): argmax_or_argmin = reduction_type in {"argmax", "argmin"} - if argmax_or_argmin or is_welford_reduction(reduction_type): + if argmax_or_argmin: raise NotImplementedError() #TODO: argmin, argmax - - # Prepare reduction loop + if is_welford_reduction(reduction_type): + if reduction_type == "welford_combine": + raise NotImplementedError("welford_combine") + else: + assert reduction_type == "welford_reduce" + type_name = mlir_common.DTYPE_TO_MLIR[dtype] + reduction_key = src_dtype, reduction_type, value + sum = self.reduction_epilogue(dtype, src_dtype, "sum", value) + sqr_sum = self.reduction_epilogue(dtype, src_dtype, "sum", ops.mul(value, value)) + self.welford_reduce_out = (sum, sqr_sum, None) + return sum, sqr_sum, None + + # Check duplicated reductions reduction_key = src_dtype, reduction_type, value - acc = self.reduction_cse.generate( - self.loads, f"reduction {reduction_key}", write=False - ) - iterator = self.iterator_cse.generate( - self.loads, f"reduction {reduction_key}", write=False - ) - init = self.init_cse.generate( - self.loads, f"reduction {reduction_key}", write=False - ) - init_vec = self.init_vec_cse.generate( - self.loads, f"reduction {reduction_key}", write=False - ) - type_name = mlir_common.DTYPE_TO_MLIR[dtype] - init = self.const_cse.generate(self.const_buffer, f"arith.constant {reduction_init(reduction_type, dtype)} : {type_name}") - vec_len = self.kernel_group.tile_desc.get_compute_vec_size() - reduced_shape = self.kernel_group.tile_desc.get_mlir_vshape(type_name) + if reduction_key in self.reduction_epilogue_result: + return self.reduction_epilogue_result[reduction_key] - # Set accumulation var - if vec_len == 1: # 1-D vector to scalar - # Edge case for scalar - init_vec = init - else: - # Adjust shape and inital value - init_vec = 
self.const_cse.generate(self.const_buffer, f"vector.broadcast %{init} : {type_name} to {reduced_shape}") - acc_var = init_vec - - # Reduction body prepare - body_acc = self.reduction_cse.generate( - self.compute, f"reduction {reduction_key}body_acc", write=False - ) - body_iter_arg = self.iterator_cse.generate( - self.compute, f"reduction {reduction_key}body_iter_arg", write=False - ) - self.register_var_info(body_iter_arg, [vec_len, type_name]) - - self.reduction_vars[acc] = (reduction_type, iterator, acc_var, reduced_shape) - self.affine_yield[body_acc] = reduced_shape - self.reduction_cse.reduction_cache[reduction_key] = acc - self.iterator_cse.reduction_cache[reduction_key] = iterator - self.init_cse.reduction_cache[reduction_key] = init_vec + # Reduction fusion codegen part + vec_size = self.compute_body_loop.step + type_name = mlir_common.DTYPE_TO_MLIR[dtype] + new_tile_size = self.kernel_group.tile_desc.get_tile_size()[:-1] + [vec_size] + new_vlane_split_axis = self.kernel_group.tile_desc.vlane_split_axis + new_vlane_stride = self.kernel_group.tile_desc.vlane_stride + local_tile_desc = mlir_common.MLIRMultiDimTile(new_tile_size, self.vector_lane, new_vlane_split_axis, new_vlane_stride, vec_size) + + tile_shape = local_tile_desc.get_mlir_shape(type_name) + vshape = local_tile_desc.get_mlir_vshape(type_name) + + name = f"{reduction_type}_buffer{self.reduction_buffer_idx}" + self.reduction_buffer_idx += 1 + index = "dummy_index" # Not used + sram_var, _ = self.get_scratchpad_buffer(dtype, name, local_tile_desc, index, self.const_buffer) + self.reduction_epilogue_result[reduction_key] = sram_var + + # Load partial result + zero_var_list = [f"%{self.get_const_cse(0)}"] * local_tile_desc.get_nr_dim() + zero_var_list[-2] = f"%{self.reduction_loop_idx}" + compute_index_var = ", ".join(zero_var_list) + operation = "affine.vector_load" + line = f"{operation} %{sram_var}[{compute_index_var}] : {tile_shape}, {vshape}" + out = self.cse.generate(self.loads, line) + 
self.register_var_info(out, [self.compute_body_loop.step, type_name]) # Reduction body codegen - result = reduction_partial_combine_vec(reduction_type, value, body_iter_arg) - self.compute_body_loop.reduction_vars[body_acc] = (reduction_type, body_iter_arg, iterator, reduced_shape) - self.compute_body_loop.affine_yield[result] = reduced_shape - - # Final reduction - reduction_size = self.reduction_nr_outer_loop - if vec_len > reduction_size: - init = self.const_cse.generate(self.const_buffer, f"arith.constant {reduction_init(reduction_type, dtype)} : {type_name}") - if reduction_size == 1: - final_reduced_shape = f"{type_name}" - out = self.cse.generate(self.reductions_suffix, reduction_combine_vec(reduction_type, acc, init, axis=0, shape=reduced_shape, reduced_shape=final_reduced_shape)) - else: - final_reduced_shape = f"vector<{reduction_size}x{type_name}>" - init_vec = self.cse.generate(self.reductions_suffix, f"vector.broadcast %{init} : {type_name} to {final_reduced_shape}") - new_vshape= f"vector<{reduction_size}x{vec_len//reduction_size}x{type_name}>" - partial_vshape= f"vector<{vec_len//reduction_size}x{type_name}>" - value = self.cse.generate(self.reductions_suffix, f"vector.shape_cast %{acc} : {reduced_shape} to {new_vshape}") - # FIXME. I want to use N-Rank multi-reduciton, but we can't use it. It lowerd to scalar operations now... 
- for i in range(reduction_size): - partial_value = self.cse.generate(self.reductions_suffix, f"vector.extract %{value}[{i}] : {partial_vshape} from {new_vshape}") - out = self.cse.generate(self.reductions_suffix, reduction_combine_vec(reduction_type, partial_value, init, axis=0, shape=partial_vshape, reduced_shape=type_name)) - init_vec = self.cse.generate(self.reductions_suffix, f"vector.insert %{out}, %{init_vec}[{i}] : {type_name} into {final_reduced_shape}") - out = init_vec - acc = out - - # reigster reduction output - var_info = [reduction_size, mlir_common.DTYPE_TO_MLIR[dtype]] - self.register_var_info(acc, var_info) - - # Specail handling for fusion - self.reduction_epilogue_suffix.writeline(f"affine.yield %{body_acc} : {self.affine_yield[body_acc]}") - return acc + init = self.const_cse.generate(self.const_buffer, f"arith.constant {reduction_init(reduction_type, dtype)} : {type_name}") + init_vec = self.const_cse.generate(self.const_buffer, f"vector.broadcast %{init} : {type_name} to {vshape}") + self.register_var_info(init_vec, [local_tile_desc.get_compute_vec_size(), type_name]) + mask_shape, mask_var = self.get_mask() + if mask_var is not None: + value = ops.where(mask_var, value, init_vec) + result = reduction_partial_combine_vec(reduction_type, value, out) + + # Store partial result + operation = "affine.vector_store" + line = f"{operation} %{result}, %{sram_var}[{compute_index_var}] : {tile_shape}, {vshape}" + self.compute.writeline(line) # Need to be placed after partial reduction + self.reduction_info[sram_var] = [reduction_type, local_tile_desc] + return sram_var def store_reduction_epilogue(self, name, index, value): - index = self.reduction_idx - tmp_cse = self.cse - self.cse = self.reduction_cse - + index = self.rename_indexing(index) dram_var = self.kernel_group.args.output(name) + dram_shape = mlir_common.MLIRKernelArgs.get_mlir_shape(self.buffer_types[name]) dtype = V.graph.get_dtype(name) mlir_dtype = mlir_common.DTYPE_TO_MLIR[dtype] - 
index = self.rename_indexing(index) - # Tile is always reuduced in inner loop - numel_per_lane = self.kernel_group.tile_desc.get_numel_per_lane() - reduction_axis_size = self.kernel_group.tile_desc.get_tile_size()[-2] - nr_outer_loop = numel_per_lane // reduction_axis_size - - vlane_split_axis = self.kernel_group.tile_desc.vlane_split_axis - 1 + index_var = self.parse_indices(index, self.reductions_suffix, comments="// Store reduction") + dram_stride = [index.coeff(sympy.Symbol(val)) for val in self.dim_aliasing.values()][:-1] # Assume that there is only one reduction axis + vlane_split_axis = self.kernel_group.tile_desc.vlane_split_axis vlane_stride = self.kernel_group.tile_desc.vlane_stride - tile_numel_per_lane = vlane_stride * nr_outer_loop - dram_shape = mlir_common.MLIRKernelArgs.get_mlir_shape(self.buffer_types[name]) - tile_shape = f"memref<{self.kernel_group.tile_desc.get_tile_size()[1]}x{mlir_dtype}, 1>" - tile_stride = [1] - compute_vec_size = self.var_info[value][0] - if compute_vec_size == 1: - vshape = f"{mlir_dtype}" - else: - vshape = f"vector<{compute_vec_size}x{mlir_dtype}>" - sram_var, index_var, sram_index_var = self.get_scratchpad_buffer(dtype, name, tile_numel_per_lane, tile_shape, index, - index, buffer=self.const_buffer) + # Create final buffer descriptor + nr_outer_loop = self.reduction_nr_outer_loop + tile_size = self.kernel_group.tile_desc.get_tile_size()[:-1] + final_tile_desc = mlir_common.MLIRMultiDimTile(tile_size, self.vector_lane, vlane_split_axis, vlane_stride*nr_outer_loop*2) + final_tile_shape = final_tile_desc.get_mlir_shape(mlir_dtype) + final_tile_stride = final_tile_desc.get_tile_stride() + sram_var, sram_index_var = self.get_scratchpad_buffer(dtype, name, final_tile_desc, index, buffer=self.const_buffer) + + # Set partial buffer descriptor + partial_tile_desc = self.reduction_info[value][1] + partial_vec_size = partial_tile_desc.get_compute_vec_size() + partial_vshape = partial_tile_desc.get_mlir_vshape(mlir_dtype) + 
partial_tile_shape = partial_tile_desc.get_mlir_shape(mlir_dtype) + + # Prepare constant + init = self.const_cse.generate(self.const_buffer, f"arith.constant {reduction_init(self.reduction_info[value][0], dtype)} : {mlir_dtype}") + partial_zero_var_list = [f"%{self.get_const_cse(0)}"] * partial_tile_desc.get_nr_dim() + final_zero_var_list = [f"%{self.get_const_cse(0)}"] * final_tile_desc.get_nr_dim() + for i in range(self.reduction_body_loop.size): + # Load partial result + body_index_var = self.const_cse.generate(self.const_buffer, f"arith.constant {i} : index") + partial_zero_var_list[-2] = f"%{body_index_var}" + compute_index_var = ",".join(partial_zero_var_list) + + operation = "affine.vector_load" + line = f"{operation} %{value}[{compute_index_var}] : {partial_tile_shape}, {partial_vshape}" + out = self.cse.generate(self.reductions_suffix, line) + operation = "affine.vector_store" + init_vec = self.const_cse.generate(self.const_buffer, f"vector.broadcast %{init} : {mlir_dtype} to {partial_vshape}") + line = f"{operation} %{init_vec}, %{value}[{compute_index_var}] : {partial_tile_shape}, {partial_vshape}" + self.reductions_suffix.writeline(line) + + # 2 step reduction + new_vec_size = 2 + new_vshape = f"vector<{partial_vec_size//new_vec_size}x{new_vec_size}x{mlir_dtype}>" + new_reduced_shape = f"vector<{new_vec_size}x{mlir_dtype}>" + out = self.cse.generate(self.reductions_suffix, f"vector.shape_cast %{out} : {partial_vshape} to {new_vshape}") + init_vec = self.const_cse.generate(self.const_buffer, f"vector.broadcast %{init} : {mlir_dtype} to {new_reduced_shape}") + out = self.cse.generate(self.reductions_suffix, reduction_combine_vec(self.reduction_info[value][0], out, init_vec, axis=0, shape=new_vshape, reduced_shape=new_reduced_shape)) + out2 = self.cse.generate(self.reductions_suffix, f"vector.shuffle %{out}, %{out} [1, 0] : {new_reduced_shape}, {new_reduced_shape}") + + self.compute, self.reductions_suffix = self.reductions_suffix, self.compute + 
self.register_var_info(out, [new_vec_size, mlir_dtype]) + self.register_var_info(out2, [new_vec_size, mlir_dtype]) + out = reduction_partial_combine_vec(self.reduction_info[value][0], out, out2) + self.compute, self.reductions_suffix = self.reductions_suffix, self.compute + + if self.welford_reduce_out is not None: + # NOTE: It not a real welford algorithm... We just used E(X^2) - E(X)^2 + divider = self.cse.generate(self.reductions_suffix, f"arith.constant {float(self.reduction_axis_size)} : f32") + if self.reduction_axis_size - 1 > 0: + divider2 = self.cse.generate(self.reductions_suffix, f"arith.constant {float(self.reduction_axis_size-1)} : f32") + else: + divider2 = divider - if self.welford_reduce_out is not None: - raise NotImplementedError() + if self.buffer_types[name][1] > 1: + divider_vec = self.cse.generate(self.reductions_suffix, f"vector.broadcast %{divider} : f32 to {new_reduced_shape}") + else: + divider_vec = divider - # Select src type - if compute_vec_size == 1: - operation = "affine.store" - line = f"{operation} %{value}, %{sram_var}[{sram_index_var}] : {tile_shape}" - else: - operation = "affine.vector_store" - line = f"{operation} %{value}, %{sram_var}[{sram_index_var}] : {tile_shape}, {vshape}" - self.reductions_suffix.writeline(DeferredLine(name, line)) + if self.current_node.node.origin_node: # FIXME: This is a temporary solution + # mean = SUM(X) / N + self.reduction_mean.append(self.cse.generate(self.reductions_suffix, f"arith.divf %{out}, %{divider_vec} : {new_reduced_shape}")) + out = self.reduction_mean[i] + else: + # m2 = (E(X^2) - E(X)^2) * N + sqr_mean = self.cse.generate(self.reductions_suffix, f"arith.divf %{out}, %{divider_vec} : {new_reduced_shape}") + mean_sqr = self.cse.generate(self.reductions_suffix, f"arith.mulf %{self.reduction_mean[i]}, %{self.reduction_mean[i]} : {new_reduced_shape}") + variance = self.cse.generate(self.reductions_suffix, f"arith.subf %{sqr_mean}, %{mean_sqr} : {new_reduced_shape}") + m2 = 
self.cse.generate(self.reductions_suffix, f"arith.mulf %{variance}, %{divider_vec} : {new_reduced_shape}") + out = m2 + + final_zero_var_list[-1] = f"%{body_index_var}" + final_compute_index_var = ",".join(final_zero_var_list) + operation = "affine.vector_store" + line = f"{operation} %{out}, %{sram_var}[{final_compute_index_var}] : {final_tile_shape}, {new_reduced_shape}" + self.reductions_suffix.writeline(DeferredLine(name, line)) # MVOUT Encoding # Generate DMA instruction + attribute = f"{{dram_stride={dram_stride}, sram_stride={final_tile_stride}, padding=0}}" code = self.get_dma_code("MVOUT", vlane_split_axis, vlane_stride, mlir_dtype, dram_var, index_var, sram_var, sram_index_var, - f"{name}_tag", dram_shape, tile_shape, tile_stride) + dram_shape, final_tile_shape, attribute) self.reductions_suffix.writeline(DeferredLine(name, code)) - # Restore origin cse - self.cse = tmp_cse - - def get_scratchpad_buffer(self, dtype, name, tile_size_per_lane, dram_tile_shape, index_var, raw_index, buffer=None): - return super().get_scratchpad_buffer(dtype, name, tile_size_per_lane, dram_tile_shape, index_var, raw_index, True, buffer=buffer) + def set_tile_size(self, template_fusion_info, prologue=False): + tile_desc = template_fusion_info["dram_tile_desc"] + if "dim_aliasing" in template_fusion_info: + self.dim_aliasing = template_fusion_info["dim_aliasing"] - def set_tile_size(self, template_store_info): - tile_desc = mlir_common.MLIRMultiDimTile(template_store_info['tile_size'], - self.vector_lane, - vlane_split_axis=template_store_info['vlane_split_axis'], - vlane_stride=template_store_info['vlane_stride']) - - if 'nr_rdim' in template_store_info and template_store_info['nr_rdim']==1: + if 'nr_rdim' in template_fusion_info and template_fusion_info['nr_rdim']==1: tile_desc.nr_rdim = 1 numel_per_lane = tile_desc.get_numel_per_lane() - reduction_axis_size = tile_desc.get_tile_size()[-2] + reduction_axis_size = tile_desc.get_tile_size()[-1] nr_outer_loop = (numel_per_lane + 
reduction_axis_size-1) // reduction_axis_size - tile_desc.vec_size = nr_outer_loop * 2 # Why? Emprically selected, other option failed to functionality... + tile_desc.vec_size = nr_outer_loop * 32 # Why? Emprically selected, other option failed to functionality... self.reduction_fusion = True - self.reduction_axis_size = tile_desc.get_tile_size()[-2] - self.reduction_nr_outer_loop = (numel_per_lane + reduction_axis_size-1) // reduction_axis_size - self.reduction_idx = template_store_info["reduction_idx"] + self.reduction_axis_size = tile_desc.get_tile_size()[-1] + self.reduction_nr_outer_loop = nr_outer_loop + self.reduction_loop_idx = "reduce_loop_idx" self.compute_body_loop.size = reduction_axis_size self.compute_body_loop.step = tile_desc.get_compute_vec_size() // nr_outer_loop + self.reduction_body_loop = mlir_common.LoopLevel(self.reduction_loop_idx, nr_outer_loop) else: tile_desc.vec_size=64 - self.compute_body_loop.size = tile_desc.get_numel_per_lane() - self.compute_body_loop.step = tile_desc.get_compute_vec_size() + + if prologue: + self.prologue_compute_body_loop.size = tile_desc.get_numel_per_lane() + self.prologue_compute_body_loop.step = tile_desc.get_compute_vec_size() + else: + self.compute_body_loop.size = tile_desc.get_numel_per_lane() + self.compute_body_loop.step = tile_desc.get_compute_vec_size() return tile_desc + def rename_indexing(self, index) -> sympy.Expr: + for dim_name, dim_aliased_name in self.dim_aliasing.items(): + index = index.subs(sympy.Symbol(dim_name), sympy.Symbol("tmp_"+dim_aliased_name)) + # To avoid this case ({"index0":"index1", "index1":"index0"}) + for dim_aliased_name in self.dim_aliasing.values(): + index = index.subs(sympy.Symbol("tmp_"+dim_aliased_name), sympy.Symbol(dim_aliased_name)) + return index + class MLIRTemplateCaller(CUDATemplateCaller): def __str__(self): return f"MLIRTemplateCaller(source_file={self.bmreq.source_file})" @@ -890,6 +1080,7 @@ def generate(self, **kwargs) -> ChoiceCaller: def 
make_kernel_render( template_node: TemplateBuffer, + prologue_nodes: Optional[List[IRNode]] = None, epilogue_nodes: Optional[List[IRNode]] = None, kernel_name: str = kernel_hash_name, kernel_group: Optional[mlir_common.MLIRWrapperKenrelGroup] = None @@ -910,7 +1101,8 @@ def make_kernel_render( kwargs = { 'kernel': kernel, 'template_buffer_node': template_node, - 'epilogue_nodes': epilogue_nodes + 'epilogue_nodes': epilogue_nodes, + 'prologue_nodes': prologue_nodes, } render = functools.partial( kernel.render, diff --git a/experiments/BERT.py b/experiments/BERT.py index e111908e..3534505d 100644 --- a/experiments/BERT.py +++ b/experiments/BERT.py @@ -7,7 +7,8 @@ def run_BERT(size, input_seq, config): from Scheduler.scheduler import Scheduler, SchedulerDNNModel, Request - from tests.test_transformer import DecoderBlock + # from tests.test_transformer import EncoderBlock + from tests.Fusion.test_transformer_fusion import EncoderBlock scheduler = Scheduler(num_request_queue=1, engine_select=Scheduler.FIFO_ENGINE, backend_config=config) device = scheduler.execution_engine.module.custom_device() @@ -15,10 +16,10 @@ def run_BERT(size, input_seq, config): embedding_size = {'base': 768, 'large': 1024, 'xlarge': 2048} heads = {'base': 12, 'large': 16, 'xlarge': 32} # hidden/64 https://arxiv.org/pdf/1909.11942 cpu_query = torch.randn(input_seq, hidden_dim[size]) - decoder_block = DecoderBlock(embedding_size[size], heads[size]).eval() + encoder_block = EncoderBlock(embedding_size[size], heads[size]).eval() query = cpu_query.clone().to(device=device) - opt_fn = torch.compile(dynamic=False)(decoder_block.to(device=device)) + opt_fn = torch.compile(dynamic=False)(encoder_block.to(device=device)) SchedulerDNNModel.register_model(f"BERT-{size}", opt_fn) request = Request(f"BERT-{size}", [query], [], request_queue_idx=0) @@ -35,7 +36,7 @@ def run_BERT(size, input_seq, config): import os import sys base_dir = os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim') - config = 
os.environ.get('TORCHSIM_CONFIG', default=f'{base_dir}/PyTorchSimBackend/configs/systolic_ws_128x128_c2_simple_noc_tpuv4.json') + config = os.environ.get('TORCHSIM_CONFIG', default=f'{base_dir}/PyTorchSimBackend/configs/systolic_ws_128x128_c1_simple_noc_tpuv3.json') config_prefix = config.split('/')[-1].split('.')[0][9:] # extract config name from config path FIXME: gem5 result is different as directoy name sys.path.append(base_dir) args = argparse.ArgumentParser() diff --git a/experiments/attention.py b/experiments/attention.py index acfed848..e8f89dac 100644 --- a/experiments/attention.py +++ b/experiments/attention.py @@ -10,9 +10,9 @@ def run_attention(size, config): def attention(query, key, value): import math d_k = query.size(-1) - scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k) - p_attn = scores.softmax(dim=-1) - return torch.matmul(p_attn, value) + scores = torch.matmul(key, query.transpose(-2, -1)) / math.sqrt(d_k) + p_attn = scores.softmax(dim=-2) + return torch.matmul(value.transpose(-1, -2), p_attn) from Scheduler.scheduler import Scheduler, SchedulerDNNModel, Request scheduler = Scheduler(num_request_queue=1, engine_select=Scheduler.FIFO_ENGINE, backend_config=config) device = scheduler.execution_engine.module.custom_device() diff --git a/scripts/CompilerOpt_experiment/DMAopt.sh b/scripts/CompilerOpt_experiment/DMAopt.sh new file mode 100644 index 00000000..469cf766 --- /dev/null +++ b/scripts/CompilerOpt_experiment/DMAopt.sh @@ -0,0 +1,28 @@ +#!/bin/bash +export TORCHSIM_CONFIG="/root/workspace/PyTorchSim/PyTorchSimBackend/configs/systolic_ws_128x128_c2_simple_noc_tpuv2.json" + +# None FG DMA +export TORCHSIM_SUBTILE=0 +python experiments/gemm.py --size 128 128 128 +python experiments/gemm.py --size 256 256 256 +python experiments/gemm.py --size 512 512 512 +python experiments/gemm.py --size 1024 1024 1024 +python experiments/gemm.py --size 2048 2048 2048 + +# FG DMA +export TORCHSIM_SUBTILE=1 +export 
TORCHSIM_MANUAL_SUBTILE_SIZE=1 +python experiments/gemm.py --size 128 128 128 +python experiments/gemm.py --size 256 256 256 +python experiments/gemm.py --size 512 512 512 +python experiments/gemm.py --size 1024 1024 1024 +python experiments/gemm.py --size 2048 2048 2048 + +# SFG DMA +export TORCHSIM_SUBTILE=1 +export TORCHSIM_MANUAL_SUBTILE_SIZE=0 +python experiments/gemm.py --size 128 128 128 +python experiments/gemm.py --size 256 256 256 +python experiments/gemm.py --size 512 512 512 +python experiments/gemm.py --size 1024 1024 1024 +python experiments/gemm.py --size 2048 2048 2048 \ No newline at end of file diff --git a/scripts/chiplet_prep.sh b/scripts/chiplet_prep.sh index 99fc9b30..cddf1a58 100755 --- a/scripts/chiplet_prep.sh +++ b/scripts/chiplet_prep.sh @@ -1,14 +1,13 @@ #!/bin/bash sizes=(256 512 1024 2048) -# 각 size에 대해 처리 for size in "${sizes[@]}"; do echo "Processing size: $size" - # 환경 변수 설정 - export TORCHSIM_FORCE_TIME_M=$((size / 2)) - export TORCHSIM_FORCE_TIME_K=$((size / 2)) - export TORCHSIM_FORCE_TIME_N=$((size / 2)) + # Set environment variables + export TORCHSIM_TILE_M=$((size / 2)) + export TORCHSIM_TILE_K=$((size / 2)) + export TORCHSIM_TILE_N=$((size / 2)) export TORCHSIM_DUMP_PATH=$(pwd)/chiplet_result/$size python3 chiplet_prep.py $size #python3 chiplet_run.py $(pwd)/chiplet_result diff --git a/test_extension_backend.py b/test_extension_backend.py index 10bc9854..f0a9353a 100644 --- a/test_extension_backend.py +++ b/test_extension_backend.py @@ -12,7 +12,7 @@ from tests.test_matmul import test_matmul from tests.test_bmm import test_BMM from tests.test_cnn import test_CNN -from tests.test_transformer import test_DecoderBlock +from tests.test_transformer import test_EncoderBlock from tests.test_resnet import test_resnet from tests.test_mlp import test_mlp, test_mlp_inf from tests.MoE.test_moe import test_moe @@ -46,7 +46,7 @@ #test_matmul(device, 33, 45, 68) #test_BMM(device) #test_CNN(device) - #test_DecoderBlock(device) + 
#test_EncoderBlock(device) #test_resnet(device) #test_mlp(device) #test_mlp_inf(device, batch_size=64, input_size=256, hidden_size=512, output_size=256, sparsity=0.97) diff --git a/tests/Fusion/test_attention_fusion.py b/tests/Fusion/test_attention_fusion.py index a513b0bb..95bdf165 100644 --- a/tests/Fusion/test_attention_fusion.py +++ b/tests/Fusion/test_attention_fusion.py @@ -47,8 +47,7 @@ def forward(self, query, key, value): x = torch.matmul(value.transpose(-1, -2), p_attn) # 3) "Concat" using a view and apply a final linear. x = ( - x.contiguous() - .view(-1, self.h * self.d_k) + x.view(-1, self.h * self.d_k) ) del query del key diff --git a/tests/Fusion/test_bmm_reduction.py b/tests/Fusion/test_bmm_reduction.py new file mode 100644 index 00000000..42e38095 --- /dev/null +++ b/tests/Fusion/test_bmm_reduction.py @@ -0,0 +1,52 @@ +import torch +import torch._dynamo +import torch.utils.cpp_extension + +def test_result(name, out, cpu_out, rtol=1e-4, atol=1e-4): + if torch.allclose(out.cpu(), cpu_out, rtol=rtol, atol=atol): + message = f"|{name} Test Passed|" + print("-" * len(message)) + print(message) + print("-" * len(message)) + else: + message = f"|{name} Test Failed|" + print("-" * len(message)) + print(message) + print("-" * len(message)) + print("custom out: ", out.cpu()) + print("cpu out: ", cpu_out) + exit(1) + +def test_bmm_reduce(device, batch=12, size=512): + def bmm(a, b): + result = torch.bmm(a, b.transpose(1,2)) + return result, result.max(dim=1).values + torch.manual_seed(0) + N = size + input = torch.randn(batch, N, 64) + weight = torch.randn(batch, N, 64) + #input = torch.arange(1, N * N + 1, dtype=torch.float32).reshape(N, N).to(dtype=torch.float32) + #weight = torch.eye(N, dtype=torch.float32) + x1 = input.to(device=device) + w1 = weight.to(device=device) + x2 = input.to("cpu") + w2 = weight.to("cpu") + opt_fn = torch.compile(dynamic=False)(bmm) + res = opt_fn(x1, w1) + y = bmm(x2, w2) + test_result("BMM Reduction Fusion activation", res[0], 
y[0]) + test_result("BMM Reduction Fusion reduction", res[1], y[1]) + +if __name__ == "__main__": + import os + import sys + sys.path.append(os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim')) + + from Scheduler.scheduler import ExecutionEngine + module = ExecutionEngine.setup_device() + device = module.custom_device() + #test_bmm_reduce(device) + test_bmm_reduce(device, 12, 512) + test_bmm_reduce(device, 4, 256) + test_bmm_reduce(device, 6, 768) + test_bmm_reduce(device, 2, 128) diff --git a/tests/Fusion/test_matmul_reduction.py b/tests/Fusion/test_matmul_reduction.py index 9f2cc7f3..31ea1b0d 100644 --- a/tests/Fusion/test_matmul_reduction.py +++ b/tests/Fusion/test_matmul_reduction.py @@ -17,14 +17,34 @@ def test_result(name, out, cpu_out, rtol=1e-4, atol=1e-4): print("cpu out: ", cpu_out) exit(1) -def test_matmul_reduce(device, size=512): - def matmul_fused(a, b, c): +def test_matmul_reduce(device, M=512, N=512, K=512): + def matmul_fused(a, b): result = torch.matmul(a, b) return result, result.max(dim=-2).values torch.manual_seed(0) + input = torch.randn(M, K) + weight = torch.randn(K, N) + #input = torch.arange(1, M * K + 1, dtype=torch.float32).reshape(M, K).to(dtype=torch.float32) + #weight = torch.eye(K, dtype=torch.float32) + x1 = input.to(device=device) + w1 = weight.to(device=device) + x2 = input.to("cpu") + w2 = weight.to("cpu") + opt_fn = torch.compile(dynamic=False)(matmul_fused) + res = opt_fn(x1, w1) + y = matmul_fused(x2, w2) + test_result("Matmul Reduction Fusion activation", res[0], y[0]) + test_result("Matmul Reduction Fusion reduction", res[1], y[1]) + +def test_matmul_var_mean(device, size=512): + def matmul_fused(a, b, c): + result = torch.matmul(a, b.T) + var, mean = torch.var_mean(result, dim=-2) + return result, var, mean + torch.manual_seed(0) N = size - input = torch.randn(N, N) - weight = torch.randn(N, N) + input = torch.randn(1024, 768) + weight = torch.randn(512, 768) #input = torch.arange(1, N * N + 1, 
dtype=torch.float32).reshape(N, N).to(dtype=torch.float32) #weight = torch.eye(N, dtype=torch.float32) x1 = input.to(device=device) @@ -35,8 +55,34 @@ def matmul_fused(a, b, c): opt_fn = torch.compile(dynamic=False)(matmul_fused) res = opt_fn(x1, w1, c) y = matmul_fused(x2, w2, c) - test_result("Matmul Reduction Fusion activation", res[0], y[0]) - test_result("Matmul Reduction Fusion reduction", res[1], y[1]) + test_result("Matmul var_mean Fusion activation", res[0], y[0]) + test_result("Matmul var_mean Fusion reduction", res[1], y[1]) + test_result("Matmul var_mean Fusion reduction", res[2], y[2]) + +def test_matmul_add_var_mean(device, M=768, N=512, K=3072): + def matmul_fused(a, b, c, d): + result = torch.matmul(a, b.T) + c.T + var, mean = torch.var_mean(result + d, dim=-2) + return result, var, mean + torch.manual_seed(0) + input = torch.randn(M, K) + weight = torch.randn(N, K) + bias = torch.zeros(N, M) + residual = torch.randn(M,N) + x1 = input.to(device=device) + w1 = weight.to(device=device) + b1 = bias.to(device=device) + r1 = residual.to(device=device) + x2 = input.to("cpu") + w2 = weight.to("cpu") + b2 = bias.to("cpu") + r2 = residual.to("cpu") + opt_fn = torch.compile(dynamic=False)(matmul_fused) + res = opt_fn(x1, w1, b1, r1) + y = matmul_fused(x2, w2, b2, r2) + test_result("Matmul+residual+var_mean Fusion activation", res[0], y[0]) + test_result("Matmul+residual+var_mean Fusion reduction", res[1], y[1]) + test_result("Matmul+residual+var_mean Fusion reduction", res[2], y[2]) if __name__ == "__main__": import os @@ -46,4 +92,6 @@ def matmul_fused(a, b, c): from Scheduler.scheduler import ExecutionEngine module = ExecutionEngine.setup_device() device = module.custom_device() - test_matmul_reduce(device) + test_matmul_reduce(device, 3072, 512, 768) + test_matmul_var_mean(device) + test_matmul_add_var_mean(device) diff --git a/tests/Fusion/test_prologue_fusion.py b/tests/Fusion/test_prologue_fusion.py new file mode 100644 index 00000000..797f9e76 --- 
/dev/null +++ b/tests/Fusion/test_prologue_fusion.py @@ -0,0 +1,97 @@ +import torch +import torch._dynamo +import torch.utils.cpp_extension + +def test_result(name, out, cpu_out, rtol=1e-4, atol=1e-4): + if torch.allclose(out.cpu(), cpu_out, rtol=rtol, atol=atol): + message = f"|{name} Test Passed|" + print("-" * len(message)) + print(message) + print("-" * len(message)) + else: + message = f"|{name} Test Failed|" + print("-" * len(message)) + print(message) + print("-" * len(message)) + print("custom out: ", out.cpu()) + print("cpu out: ", cpu_out) + exit(1) + +def test_elem_broadcast_fusion(device): + def matmul_fused(a, b, c): + return torch.matmul(c * a, b) + torch.manual_seed(0) + input = torch.randn(128, 128) + weight = torch.randn(128, 128) + c = torch.randn(128, 1, dtype=torch.float32) + x1 = input.to(device=device) + w1 = weight.to(device=device) + c1 = c.to(device=device) + x2 = input.to("cpu") + w2 = weight.to("cpu") + c2 = c.to("cpu") + opt_fn = torch.compile(dynamic=False)(matmul_fused) + res = opt_fn(x1, w1, c1) + y = matmul_fused(x2, w2, c2) + test_result("Matmul Scalar Fusion Forward", res, y) + +def test_elem_fusion(device): + def matmul_fused(a, b, c): + return torch.matmul(c * a, b) + torch.manual_seed(0) + input = torch.randn(128, 128) + weight = torch.randn(128, 128) + c = torch.randn(128, 128, dtype=torch.float32) + x1 = input.to(device=device) + w1 = weight.to(device=device) + c1 = c.to(device=device) + x2 = input.to("cpu") + w2 = weight.to("cpu") + c2 = c.to("cpu") + opt_fn = torch.compile(dynamic=False)(matmul_fused) + res = opt_fn(x1, w1, c1) + y = matmul_fused(x2, w2, c2) + test_result("Matmul Element-wise Fusion Forward", res, y) + +def test_elem_bmm_weight_fusion(device, batch_size=1, m=512, n=512, k=64): + def bmm(a, b, c, d): + return torch.bmm(a , (d+b)*c) + torch.manual_seed(0) + a = torch.randn(batch_size, m, k).to(device=device) + b = torch.randn(batch_size, 1, n).to(device=device) + c = torch.randn(batch_size, 1, n) + c = 
c.to(device=device) + d = torch.randn(batch_size, k, n).to(device=device) + opt_fn = torch.compile(dynamic=False)(bmm) + res = opt_fn(a, b, c, d) + out = bmm(a.cpu(), b.cpu(), c.cpu(), d.cpu()) + print(torch.max(torch.abs(res.cpu() - out))) + test_result("BMM Element-wise Fusion Forward", res, out) + +def test_elem_bmm_input_fusion(device, batch_size=1, m=512, n=512, k=64): + def bmm(a, b, c, d): + return torch.bmm((a+b)*c , d) + torch.manual_seed(0) + a = torch.randn(batch_size, m, k).to(device=device) + b = torch.randn(batch_size, 1, k).to(device=device) + c = torch.randn(batch_size, 1, k) + c = c.to(device=device) + d = torch.randn(batch_size, k, n).to(device=device) + opt_fn = torch.compile(dynamic=False)(bmm) + res = opt_fn(a, b, c, d) + out = bmm(a.cpu(), b.cpu(), c.cpu(), d.cpu()) + print(torch.max(torch.abs(res.cpu() - out))) + test_result("BMM Element-wise Fusion Forward", res, out) + +if __name__ == "__main__": + import os + import sys + sys.path.append(os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim')) + + from Scheduler.scheduler import ExecutionEngine + module = ExecutionEngine.setup_device() + device = module.custom_device() + test_elem_broadcast_fusion(device) + test_elem_fusion(device) + test_elem_bmm_input_fusion(device, batch_size=4, m=512, n=512, k=64) + test_elem_bmm_weight_fusion(device, batch_size=12, m=512, n=512, k=64) \ No newline at end of file diff --git a/tests/Fusion/test_transformer_fusion.py b/tests/Fusion/test_transformer_fusion.py index 15bacb39..0e500b5b 100644 --- a/tests/Fusion/test_transformer_fusion.py +++ b/tests/Fusion/test_transformer_fusion.py @@ -53,9 +53,9 @@ def forward(self, query, key, value): del value return self.linears[-1](x) -class DecoderBlock_origin(torch.nn.Module): +class EncoderBlock_origin(torch.nn.Module): def __init__(self, embed_dim, num_heads): - super(DecoderBlock_origin, self).__init__() + super(EncoderBlock_origin, self).__init__() self.multihead_attn = 
my_MultiheadAttention_origin(num_heads, embed_dim) self.layer_norm = torch.nn.LayerNorm(embed_dim) self.ffn1 = torch.nn.Linear(embed_dim, embed_dim*4) @@ -111,9 +111,9 @@ def forward(self, x, residual): out = torch.matmul(self.weight, x.transpose(-1, -2)) + self.bias[:, None] # (1, 768, 512) return self.layer_norm(out.transpose(-1, -2) + residual) -class DecoderBlock(torch.nn.Module): +class EncoderBlock(torch.nn.Module): def __init__(self, embed_dim, num_heads): - super(DecoderBlock, self).__init__() + super(EncoderBlock, self).__init__() self.multihead_attn = my_MultiheadAttention(num_heads, embed_dim) self.layer_norm = torch.nn.LayerNorm(embed_dim) self.ffn1 = torch.nn.Linear(embed_dim, embed_dim*4) @@ -130,18 +130,18 @@ def forward(self, x): act_result = self.act(ffn1_result) return self.matmulln2(act_result, result) -def test_DecoderBlock(device, head=12, embed_dim=768, input_seq=512): +def test_EncoderBlock(device, head=12, embed_dim=768, input_seq=512): cpu_query = torch.randn(input_seq, embed_dim) - decoder_block = DecoderBlock(embed_dim, head) - cpu_res = decoder_block(cpu_query) + encoder_block = EncoderBlock(embed_dim, head) + cpu_res = encoder_block(cpu_query) query = cpu_query.clone().to(device=device) - decoder_block.to(device=device) + encoder_block.to(device=device) with torch.no_grad(): - opt_fn = torch.compile(dynamic=False)(decoder_block) + opt_fn = torch.compile(dynamic=False)(encoder_block) res = opt_fn(query) - test_result("Decoder Block Forwrad", res, cpu_res) + test_result("Encoder Block Forwrad", res, cpu_res) def test_Attention(device, head=16, seq=512, d_k=64): def attention(query, key, value): @@ -165,18 +165,18 @@ def attention(query, key, value): def test_MHA(device, num_heads=12, embed_dim=768, input_seq=512): MHA = my_MultiheadAttention(num_heads, embed_dim) cpu_query = torch.randn(input_seq, embed_dim) - cpu_res = MHA(cpu_query, cpu_query, cpu_query) - - query = cpu_query.clone().to(device=device) - MHA.to(device=device) - opt_fn = 
torch.compile(dynamic=False)(MHA) - res = opt_fn(query, query, query) + with torch.no_grad(): + cpu_res = MHA(cpu_query, cpu_query, cpu_query) + query = cpu_query.clone().to(device=device) + MHA.to(device=device) + opt_fn = torch.compile(dynamic=False)(MHA) + res = opt_fn(query, query, query) test_result("MHA Forward", res, cpu_res) -def test_DecoderBlock_validation(head=12, embed_dim=768, input_seq=512): - bert_origin = DecoderBlock_origin(embed_dim, head) - bert = DecoderBlock(embed_dim, head) +def test_EncoderBlock_validation(head=12, embed_dim=768, input_seq=512): + bert_origin = EncoderBlock_origin(embed_dim, head) + bert = EncoderBlock(embed_dim, head) bert.multihead_attn.linears[0].weight = bert_origin.multihead_attn.linears[0].weight bert.multihead_attn.linears[0].bias = bert_origin.multihead_attn.linears[0].bias @@ -196,7 +196,7 @@ def test_DecoderBlock_validation(head=12, embed_dim=768, input_seq=512): origin_res = bert_origin(origin_query) res = bert(query) - test_result("Decoder Block Validation", res, origin_res) + test_result("Encoder Block Validation", res, origin_res) if __name__ == "__main__": import os @@ -206,7 +206,8 @@ def test_DecoderBlock_validation(head=12, embed_dim=768, input_seq=512): from Scheduler.scheduler import ExecutionEngine module = ExecutionEngine.setup_device() device = module.custom_device() - test_DecoderBlock(device) - # test_DecoderBlock_validation() + #test_MHA(device) + test_EncoderBlock(device) + # test_EncoderBlock_validation() # test_Attention(device, head=16, seq=512, d_k=64) # test_MHA(device, num_heads=12, embed_dim=768) diff --git a/tests/Mixtral_8x7B/test_attention.py b/tests/Mixtral_8x7B/test_attention.py index cc2adc96..aa1af651 100644 --- a/tests/Mixtral_8x7B/test_attention.py +++ b/tests/Mixtral_8x7B/test_attention.py @@ -2,7 +2,7 @@ import torch import torch._dynamo import torch.utils.cpp_extension -from model import Transformer, TransformerBlock, ModelArgs, Attention, FeedForward, KVCache, 
precompute_freqs_cis, sample +from model import Transformer, TransformerBlock, ModelArgs, Attention, FeedForward, KVCache, RMSNorm, precompute_freqs_cis, sample def test_result(name, out, cpu_out, rtol=1e-4, atol=1e-4): if torch.allclose(out.cpu(), cpu_out, rtol=rtol, atol=atol): @@ -139,6 +139,25 @@ def concat_tensors(a, b): test_result("ConcatTensors", res, out) +def test_rmsnorm(device, seq=32): + dim = 512 + eps = 1e-5 + T = seq + rmsnorm = RMSNorm(dim=dim, eps=eps) + rmsnorm = rmsnorm.to(device=device) + + x = torch.randn([1, T, dim], dtype=torch.float32) + cpu_x = copy.deepcopy(x) + x = x.to(device) + + cpu_model = copy.deepcopy(rmsnorm).to("cpu") + opt_fn = torch.compile(dynamic=False)(rmsnorm) + + res = opt_fn(x) + cpu_res = cpu_model(cpu_x) + + test_result("RMSNorm", res, cpu_res) + if __name__ == "__main__": import os import sys @@ -147,7 +166,8 @@ def concat_tensors(a, b): from Scheduler.scheduler import ExecutionEngine module = ExecutionEngine.setup_device() device = module.custom_device() + test_rmsnorm(device, seq=1) + test_concat(device, size1=(1, 8, 64, 64), size2=(1,8,1,64), dim=2) test_decode(device, 32, 3) - #test_concat(device, size1=(1, 8, 32, 64), size2=(1,8,1,64), dim=2) #test_attention(device) #test_ffn(device) diff --git a/tests/MoE/test_moe.py b/tests/MoE/test_moe.py index cf2f37f4..c5ab8107 100644 --- a/tests/MoE/test_moe.py +++ b/tests/MoE/test_moe.py @@ -1,12 +1,7 @@ # Owner(s): ["module: inductor"] import os -import shutil import sys -import time -import contextlib -import unittest import copy -import numpy as np import matplotlib.pyplot as plt diff --git a/tests/test_conv2d.py b/tests/test_conv2d.py index 9d8b855a..c679b431 100644 --- a/tests/test_conv2d.py +++ b/tests/test_conv2d.py @@ -43,6 +43,15 @@ def custom_conv2d(a, b, bias): from Scheduler.scheduler import ExecutionEngine module = ExecutionEngine.setup_device() device = module.custom_device() - test_conv2d(device, batch_size=1, in_channels=128, out_channels=128, input_size=28, 
kernel_size=3, stride=1, padding=1) - test_conv2d(device, batch_size=1, in_channels=3, out_channels=64, input_size=64, kernel_size=7, stride=2, padding=3) - test_conv2d(device, batch_size=1, in_channels=3, out_channels=64, input_size=64, kernel_size=7, stride=1, padding=3) + torch._dynamo.config.cache_size_limit = 64 + test_conv2d(device, batch_size=8, in_channels=3, out_channels=32, input_size=32, kernel_size=1, stride=1, padding=0) + test_conv2d(device, batch_size=1, in_channels=3, out_channels=64, input_size=64//2, kernel_size=7, stride=2, padding=3) + test_conv2d(device, batch_size=2, in_channels=3, out_channels=64, input_size=32//2, kernel_size=7, stride=1, padding=3) + test_conv2d(device, batch_size=4, in_channels=3, out_channels=64, input_size=64//2, kernel_size=7, stride=1, padding=3) + test_conv2d(device, batch_size=4, in_channels=3, out_channels=64, input_size=64//2, kernel_size=7, stride=1, padding=3) + test_conv2d(device, batch_size=2, in_channels=128, out_channels=256, input_size=13, kernel_size=5, stride=1, padding=2) + test_conv2d(device, batch_size=2, in_channels=128, out_channels=512, input_size=14, kernel_size=7, stride=1, padding=3) + test_conv2d(device, batch_size=1, in_channels=128, out_channels=256, input_size=14, kernel_size=3, stride=2, padding=1) + test_conv2d(device, batch_size=1, in_channels=128, out_channels=256, input_size=7, kernel_size=3, stride=2, padding=1) + test_conv2d(device, batch_size=1, in_channels=128, out_channels=256, input_size=2, kernel_size=1, stride=1, padding=0) + test_conv2d(device, batch_size=1, in_channels=128, out_channels=256, input_size=14, kernel_size=1, stride=2, padding=0) diff --git a/tests/test_indirect_access.py b/tests/test_indirect_access.py index 6d16c9d0..b7b20074 100644 --- a/tests/test_indirect_access.py +++ b/tests/test_indirect_access.py @@ -27,7 +27,7 @@ def vectoradd(a, idx, b): opt_fn = torch.compile(dynamic=False)(vectoradd) res = opt_fn(x, idx, y) out = vectoradd(x.cpu(), idx.cpu(), y.cpu()) - 
test_result("VectorAdd", res, out) + test_result("Indirect VectorAdd", res, out) def test_embedding(device, vocab_size, dim): emb = torch.nn.Embedding(vocab_size, dim) diff --git a/tests/test_matmul.py b/tests/test_matmul.py index 44f70b69..6f41468b 100644 --- a/tests/test_matmul.py +++ b/tests/test_matmul.py @@ -50,6 +50,45 @@ def custom_matmul(bias, a, b): y = custom_matmul(b2, x2, w2) test_result("Addmm Forward", res, y) +def test_addmm2(device, input_size=128, hidden_size=128, output_size=128): + def custom_matmul(bias, a, b): + return torch.matmul(a, b) #+ bias + torch.manual_seed(0) + input = torch.randn(input_size, hidden_size) + weight = torch.randn(hidden_size, output_size) + bias = torch.randn(input_size, 1, dtype=torch.float32) + x1 = input.to(device=device) + w1 = weight.to(device=device) + b1 = bias.to(device=device) + x2 = input.to("cpu") + w2 = weight.to("cpu") + b2 = bias.to("cpu") + opt_fn = torch.compile(dynamic=False)(custom_matmul) + res = opt_fn(b1, x1, w1) + y = custom_matmul(b2, x2, w2) + test_result("Addmm2 Forward", res, y) + +def test_linear(device, input_size=128, hidden_size=128, output_size=128): + def custom_linear(a, b, bias): + linear = torch.nn.Linear(hidden_size, output_size) + linear.weight = torch.nn.Parameter(b) + linear.bias = torch.nn.Parameter(bias) + return linear(a) + torch.manual_seed(0) + input = torch.randn(input_size, hidden_size) + weight = torch.randn(output_size, hidden_size) + bias = torch.randn(output_size) + x1 = input.to(device=device) + w1 = weight.to(device=device) + b1 = bias.to(device=device) + x2 = input.to("cpu") + w2 = weight.to("cpu") + b2 = bias.to("cpu") + opt_fn = torch.compile(dynamic=False)(custom_linear) + res = opt_fn(x1, w1, b1) + y = custom_linear(x2, w2, b2) + test_result("Linear Forward", res, y) + if __name__ == "__main__": import os import sys @@ -62,7 +101,10 @@ def custom_matmul(bias, a, b): test_matmul(device, 128, 128, 128) test_matmul(device, 256, 256, 256) test_matmul(device, 128, 256, 
256) - test_matmul(device, 129, 61, 56) + test_matmul(device, 128, 63, 56) test_addmm(device, 128, 256, 512) test_addmm(device, 128, 256, 512, bias_rank=2) test_addmm(device, 129, 61, 56) + test_addmm2(device, 129, 61, 56) + test_addmm(device, 129*4, 61*4, 56*4) + test_addmm2(device, 129*4, 61*4, 56*4) diff --git a/tests/test_pool.py b/tests/test_pool.py index e94df65b..304a5e7c 100644 --- a/tests/test_pool.py +++ b/tests/test_pool.py @@ -50,6 +50,6 @@ def avgpool(a): from Scheduler.scheduler import ExecutionEngine module = ExecutionEngine.setup_device() device = module.custom_device() - test_maxpool(device, b=1, c=8, h=16, w=16) - test_maxpool(device, b=1, c=8, h=112, w=112) - test_avgpool(device) + #test_maxpool(device, b=1, c=8, h=16, w=16) + #test_maxpool(device, b=1, c=8, h=112, w=112) + test_avgpool(device, b=1, c=512, h=7, w=7) diff --git a/tests/test_reduce.py b/tests/test_reduce.py index c1556787..e1a84b7f 100644 --- a/tests/test_reduce.py +++ b/tests/test_reduce.py @@ -50,9 +50,8 @@ def reduce_sum(a, dim, keepdim): from Scheduler.scheduler import ExecutionEngine module = ExecutionEngine.setup_device() device = module.custom_device() - #test_reduce_sum(device, (29, 47), 1, keepdim=True) - #test_reduce_sum(device, (17, 68), 0, keepdim=True) - #test_reduce_sum(device, (327, 447), 1, keepdim=True) - #test_reduce_sum(device, (327, 447), 0, keepdim=True) - test_reduce_sum2(device, shape) - + test_reduce_sum(device, (29, 47), 1, keepdim=True) + test_reduce_sum(device, (17, 68), 0, keepdim=True) + test_reduce_sum(device, (327, 447), 1, keepdim=True) + test_reduce_sum(device, (327, 447), 0, keepdim=True) + test_reduce_sum2(device, shape) \ No newline at end of file diff --git a/tests/test_resnet.py b/tests/test_resnet.py index 5e96b922..f54ce9be 100644 --- a/tests/test_resnet.py +++ b/tests/test_resnet.py @@ -18,13 +18,13 @@ def test_result(name, out, cpu_out, rtol=1e-4, atol=1e-4): print("cpu out: ", cpu_out) exit(1) -def test_resnet(device): +def 
test_resnet(device, batch=1): from torchvision.models import resnet - # model = resnet._resnet(resnet.BasicBlock, [1, 1, 0, 0], weights=None, progress=False).eval() with torch.no_grad(): + #model = resnet._resnet(resnet.BasicBlock, [1, 1, 1, 1], weights=None, progress=False).eval() model = resnet18().eval() model.to(device, memory_format=torch.channels_last) - input = torch.randn(1, 3, 224, 224) + input = torch.randn(batch, 3, 224, 224) x1 = input.to(device=device, memory_format=torch.channels_last) x2 = input.cpu().to(memory_format=torch.channels_last) opt_fn = torch.compile(dynamic=False)(model) diff --git a/tests/test_scheduler.py b/tests/test_scheduler.py index e05fa392..c64093a0 100644 --- a/tests/test_scheduler.py +++ b/tests/test_scheduler.py @@ -2,7 +2,7 @@ import sys import torch from torchvision.models import resnet18 as model1 -from test_transformer import DecoderBlock as model2 +from test_transformer import EncoderBlock as model2 base_path = os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim') sys.path.append(base_path) diff --git a/tests/test_sparsity.py b/tests/test_sparsity.py index b3945520..3e079f83 100644 --- a/tests/test_sparsity.py +++ b/tests/test_sparsity.py @@ -8,7 +8,7 @@ import torch._dynamo import torch.utils.cpp_extension sys.path.append(os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim')) -from test_transformer import DecoderBlock, test_result +from test_transformer import EncoderBlock, test_result from test_mlp import MLP def apply_random_zero(tensor, zero_prob, block_size=8): @@ -35,30 +35,30 @@ def count_zeros_in_tensor_list(tensor_list): def test_dec_inf(device, sparsity=0.0, block=8): torch.manual_seed(0) - decoder_block = DecoderBlock(768, 12) + encoder_block = EncoderBlock(768, 12) cpu_query = torch.randn(512, 768) query = cpu_query.clone().to(device=device) - cpu_y = decoder_block(cpu_query) + cpu_y = encoder_block(cpu_query) with torch.no_grad(): - 
decoder_block.multihead_attn.linears[0].weight.copy_(apply_random_zero(decoder_block.multihead_attn.linears[0].weight, sparsity, block_size=block)) - decoder_block.multihead_attn.linears[1].weight.copy_(apply_random_zero(decoder_block.multihead_attn.linears[1].weight, sparsity, block_size=block)) - decoder_block.multihead_attn.linears[2].weight.copy_(apply_random_zero(decoder_block.multihead_attn.linears[2].weight, sparsity, block_size=block)) - decoder_block.multihead_attn.linears[3].weight.copy_(apply_random_zero(decoder_block.multihead_attn.linears[3].weight, sparsity, block_size=block)) - decoder_block.ffn1.weight.copy_(apply_random_zero(decoder_block.ffn1.weight, sparsity, block_size=block)) - decoder_block.ffn2.weight.copy_(apply_random_zero(decoder_block.ffn2.weight, sparsity, block_size=block)) + encoder_block.multihead_attn.linears[0].weight.copy_(apply_random_zero(encoder_block.multihead_attn.linears[0].weight, sparsity, block_size=block)) + encoder_block.multihead_attn.linears[1].weight.copy_(apply_random_zero(encoder_block.multihead_attn.linears[1].weight, sparsity, block_size=block)) + encoder_block.multihead_attn.linears[2].weight.copy_(apply_random_zero(encoder_block.multihead_attn.linears[2].weight, sparsity, block_size=block)) + encoder_block.multihead_attn.linears[3].weight.copy_(apply_random_zero(encoder_block.multihead_attn.linears[3].weight, sparsity, block_size=block)) + encoder_block.ffn1.weight.copy_(apply_random_zero(encoder_block.ffn1.weight, sparsity, block_size=block)) + encoder_block.ffn2.weight.copy_(apply_random_zero(encoder_block.ffn2.weight, sparsity, block_size=block)) count_zeros_in_tensor_list([ - decoder_block.multihead_attn.linears[0].weight, - decoder_block.multihead_attn.linears[1].weight, - decoder_block.multihead_attn.linears[2].weight, - decoder_block.multihead_attn.linears[3].weight, - decoder_block.ffn1.weight, - decoder_block.ffn2.weight + encoder_block.multihead_attn.linears[0].weight, + 
encoder_block.multihead_attn.linears[1].weight, + encoder_block.multihead_attn.linears[2].weight, + encoder_block.multihead_attn.linears[3].weight, + encoder_block.ffn1.weight, + encoder_block.ffn2.weight ]) - decoder_block.to(device=device) - opt_fn = torch.compile(dynamic=False)(decoder_block) + encoder_block.to(device=device) + opt_fn = torch.compile(dynamic=False)(encoder_block) y = opt_fn(query) test_result("MLP Forward", y, cpu_y) diff --git a/tests/test_spmm_scheduler.py b/tests/test_spmm_scheduler.py index 73bbdbae..1cf0d3b3 100644 --- a/tests/test_spmm_scheduler.py +++ b/tests/test_spmm_scheduler.py @@ -5,7 +5,7 @@ sys.path.append(os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim')) from Scheduler.scheduler import Scheduler, SchedulerDNNModel, Request from test_sparse_core import SparseMLP as model1 -from test_transformer import DecoderBlock as model2 +from test_transformer import EncoderBlock as model2 CONFIG_TORCHSIM_DIR = os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim') if __name__ == "__main__": diff --git a/tests/test_transformer.py b/tests/test_transformer.py index 83ed5850..4d45707e 100644 --- a/tests/test_transformer.py +++ b/tests/test_transformer.py @@ -41,23 +41,21 @@ def forward(self, query, key, value): ] # 2) Apply attention on all the projected vectors in batch. - scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.d_k) - p_attn = scores.softmax(dim=-1) - x = torch.matmul(p_attn, value) + scores = torch.matmul(key, query.transpose(-2, -1)) / math.sqrt(self.d_k) + p_attn = scores.softmax(dim=-2) + x = torch.matmul(value.transpose(-1, -2), p_attn) # 3) "Concat" using a view and apply a final linear. 
x = ( - x.transpose(0, 1) - .contiguous() - .view(-1, self.h * self.d_k) + x.view(-1, self.h * self.d_k) ) del query del key del value return self.linears[-1](x) -class DecoderBlock(torch.nn.Module): +class EncoderBlock(torch.nn.Module): def __init__(self, embed_dim, num_heads): - super(DecoderBlock, self).__init__() + super(EncoderBlock, self).__init__() self.multihead_attn = my_MultiheadAttention(num_heads, embed_dim) self.layer_norm = torch.nn.LayerNorm(embed_dim) self.ffn1 = torch.nn.Linear(embed_dim, embed_dim*4) @@ -73,25 +71,25 @@ def forward(self, x): ffn2_result = self.ffn2(act_result) return self.layer_norm(ffn2_result + result) -def test_DecoderBlock(device, head=12, embed_dim=768, input_seq=512): +def test_EncoderBlock(device, head=12, embed_dim=768, input_seq=512): cpu_query = torch.randn(1, input_seq, embed_dim) - decoder_block = DecoderBlock(embed_dim, head) - cpu_res = decoder_block(cpu_query) + encoder_block = EncoderBlock(embed_dim, head) + cpu_res = encoder_block(cpu_query) query = cpu_query.clone().to(device=device) - decoder_block.to(device=device) - opt_fn = torch.compile(dynamic=False)(decoder_block) + encoder_block.to(device=device) + opt_fn = torch.compile(dynamic=False)(encoder_block) res = opt_fn(query) - test_result("Decoder Block Forwrad", res, cpu_res) + test_result("Encoder Block Forwrad", res, cpu_res) def test_Attention(device, head=16, seq=512, d_k=64): def attention(query, key, value): import math d_k = query.size(-1) - scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k) - p_attn = scores.softmax(dim=-1) - return torch.matmul(p_attn, value), p_attn + scores = torch.matmul(key, query.transpose(-2, -1)) / math.sqrt(d_k) + p_attn = scores.softmax(dim=-2) + return torch.matmul(value.transpose(-1, -2), p_attn) torch.manual_seed(0) query = torch.randn(head, seq, d_k).to(device=device) @@ -99,9 +97,9 @@ def attention(query, key, value): value = torch.randn(head, seq, d_k).to(device=device) opt_fn = 
torch.compile(dynamic=False)(attention) - res, p_attn = opt_fn(query, key, value) + res = opt_fn(query, key, value) - cpu_res, cpu_p_attn = attention(query.cpu(), key.cpu(), value.cpu()) + cpu_res = attention(query.cpu(), key.cpu(), value.cpu()) test_result("Attention Forward", res, cpu_res) def test_MHA(device, num_heads=12, embed_dim=768, input_seq=512): @@ -124,6 +122,6 @@ def test_MHA(device, num_heads=12, embed_dim=768, input_seq=512): from Scheduler.scheduler import ExecutionEngine module = ExecutionEngine.setup_device() device = module.custom_device() - test_DecoderBlock(device) + test_EncoderBlock(device) # test_Attention(device, head=16, seq=512, d_k=64) # test_MHA(device, num_heads=12, embed_dim=768) diff --git a/validation/gemm_tpuv3_cheatsheet.json b/validation/gemm_tpuv3_cheatsheet.json new file mode 100644 index 00000000..76a26e1a --- /dev/null +++ b/validation/gemm_tpuv3_cheatsheet.json @@ -0,0 +1,17 @@ +{ + "512_2048_8192" : { + "TILE_M" : 512, + "TILE_K" : 512, + "TILE_N" : 1024 + }, + "512_2048_2048" : { + "TILE_M" : 512, + "TILE_K" : 512, + "TILE_N" : 1024 + }, + "2048_2048_512" : { + "TILE_M" : 1024, + "TILE_K" : 512, + "TILE_N" : 512 + } +} \ No newline at end of file