diff --git a/.github/workflows/docker-base-image.yml b/.github/workflows/docker-base-image.yml index 708c9efe..82614c8c 100644 --- a/.github/workflows/docker-base-image.yml +++ b/.github/workflows/docker-base-image.yml @@ -2,7 +2,9 @@ name: Docker Base Image CI on: push: - branches: [ "master" ] + branches: [ "base" ] + repository_dispatch: + types: [ build_base ] jobs: build: @@ -27,21 +29,8 @@ jobs: username: ${{ github.actor }} password: ${{ secrets.GITHUB_TOKEN }} - # Step 3: Check if the Base Image Exists - - name: Check if Base Image Exists - id: check-image - run: | - if docker pull ghcr.io/psal-postech/torchsim_base:latest; then - echo "Base image already exists. Skipping build and push." - echo "exists=true" >> $GITHUB_ENV - else - echo "Base image does not exist. Proceeding with build and push." - echo "exists=false" >> $GITHUB_ENV - fi - # Step 4: Build and Push Docker Image - name: Build and Push Docker Image - if: env.exists == 'false' uses: docker/build-push-action@v4 with: context: . 
diff --git a/.github/workflows/docker-image.yml b/.github/workflows/docker-image.yml index 70c58a04..2b420ff8 100644 --- a/.github/workflows/docker-image.yml +++ b/.github/workflows/docker-image.yml @@ -2,9 +2,7 @@ name: Docker Image CI on: push: - branches: [ "master", "develop" ] - pull_request: - branches: [ "master", "develop" ] + branches: [ "master" ] jobs: build: diff --git a/.github/workflows/pull-request.yml b/.github/workflows/pull-request.yml new file mode 100644 index 00000000..003a0d01 --- /dev/null +++ b/.github/workflows/pull-request.yml @@ -0,0 +1,373 @@ +name: PR test CI + +on: + pull_request: + branches: [ "master", "develop" ] + +jobs: + build: + runs-on: self-hosted + + permissions: + contents: read + packages: write + attestations: write + id-token: write + + steps: + # Step 1: Checkout the repository + - name: Checkout Code + uses: actions/checkout@v4 + # Step 2: Log in to GitHub Container Registry (optional) + # If you need to push the built image, authenticate here. 
+ - name: Log in to GitHub Container Registry + uses: docker/login-action@v3 + with: + registry: ghcr.io + username: ${{ github.actor }} + password: ${{ secrets.GITHUB_TOKEN }} + + # Step 3: Pull the Cached Image + - name: Pull Cached Image & Set environment + env: + GIT_ACCESS_TOKEN: ${{ secrets.GIT_ACCESS_TOKEN }} + run: | + docker pull ghcr.io/psal-postech/torchsim_base:latest || echo "No cache available" + echo "IMAGE_TAG=torchsim-ci:${GITHUB_SHA}" >> $GITHUB_ENV + echo "GITHUB_SHA=${{github.event.pull_request.head.sha}}" >> $GITHUB_ENV + echo "GITHUB_SHA=${{github.event.pull_request.head.sha}}" + gem5_response_file=/tmp/releases-gem5-latest.json + response=$(curl -sH "Authorization: Bearer ${GIT_ACCESS_TOKEN}" https://api.github.com/repos/PSAL-POSTECH/GEM5/releases/latest > ${gem5_response_file} ) + GEM5_ASSET_ID=$(cat ${gem5_response_file} | jq '.assets[0].id') + echo "GEM5_ASSET_ID=$GEM5_ASSET_ID" + echo "GEM5_ASSET_ID=$GEM5_ASSET_ID" >> $GITHUB_ENV + + llvm_response_file=/tmp/releases-llvm-latest.json + response=$(curl -sH "Authorization: Bearer ${GIT_ACCESS_TOKEN}" https://api.github.com/repos/PSAL-POSTECH/llvm-project/releases/latest > ${llvm_response_file} ) + LLVM_ASSET_ID=$(cat ${llvm_response_file} | jq '.assets[0].id') + echo "LLVM_ASSET_ID=$LLVM_ASSET_ID" + echo "LLVM_ASSET_ID=$LLVM_ASSET_ID" >> $GITHUB_ENV + + mkdir -p /tmp/torchsim-ci/${GITHUB_SHA} + echo "DUMP_PATH=/tmp/torchsim-ci/${GITHUB_SHA}" >> $GITHUB_ENV + + # Step 4: Build and Push Docker Image + - name: Build and Push Docker Image + uses: docker/build-push-action@v4 + env: + GIT_ACCESS_TOKEN: ${{ secrets.GIT_ACCESS_TOKEN }} + with: + context: . 
+ file: ./Dockerfile + push: true + build-args: | + GEM5_ASSET_ID=${{ env.GEM5_ASSET_ID }} + LLVM_ASSET_ID=${{ env.LLVM_ASSET_ID }} + GIT_ACCESS_TOKEN=${{ env.GIT_ACCESS_TOKEN }} + TORCHSIM_SHA=${{ env.GITHUB_SHA }} + tags: ghcr.io/psal-postech/${{ env.IMAGE_TAG}} + + test_add: + name: Run test_add.py + runs-on: self-hosted + needs: build + steps: + - name: Run test_add.py + run: | + echo "Running test_add.py" + docker run --rm \ + -v /tmp/torchsim-ci/${GITHUB_SHA}:/dump \ + -e TORCHSIM_DUMP_PATH=/dump \ + ghcr.io/psal-postech/torchsim-ci:${GITHUB_SHA} python3 PyTorchSim/tests/test_add.py + + test_relu: + name: Run test_relu.py + runs-on: self-hosted + needs: build + steps: + - name: Run test_relu.py + run: | + echo "Running test_relu.py" + docker run --rm \ + -v /tmp/torchsim-ci/${GITHUB_SHA}:/dump \ + -e TORCHSIM_DUMP_PATH=/dump \ + ghcr.io/psal-postech/torchsim-ci:${GITHUB_SHA} python3 PyTorchSim/tests/test_relu.py + + test_batchnorm: + name: Run test_batchnorm.py + runs-on: self-hosted + needs: build + steps: + - name: Run test_batchnorm.py + run: | + echo "Running test_batchnorm.py" + docker run --rm \ + -v /tmp/torchsim-ci/${GITHUB_SHA}:/dump \ + -e TORCHSIM_DUMP_PATH=/dump \ + ghcr.io/psal-postech/torchsim-ci:${GITHUB_SHA} python3 PyTorchSim/tests/test_batchnorm.py + + test_bmm: + name: Run test_bmm.py + runs-on: self-hosted + needs: build + steps: + - name: Run test_bmm.py + run: | + echo "Running test_bmm.py" + docker run --rm \ + -v /tmp/torchsim-ci/${GITHUB_SHA}:/dump \ + -e TORCHSIM_DUMP_PATH=/dump \ + ghcr.io/psal-postech/torchsim-ci:${GITHUB_SHA} python3 PyTorchSim/tests/test_bmm.py + + test_cnn: + name: Run test_cnn.py + runs-on: self-hosted + needs: build + steps: + - name: Run test_cnn.py + run: | + echo "Running test_cnn.py" + docker run --rm \ + -v /tmp/torchsim-ci/${GITHUB_SHA}:/dump \ + -e TORCHSIM_DUMP_PATH=/dump \ + ghcr.io/psal-postech/torchsim-ci:${GITHUB_SHA} python3 PyTorchSim/tests/test_cnn.py + + test_conv2d: + name: Run test_conv2d.py 
+ runs-on: self-hosted + needs: build + steps: + - name: Run test_conv2d.py + run: | + echo "Running test_conv2d.py" + docker run --rm \ + -v /tmp/torchsim-ci/${GITHUB_SHA}:/dump \ + -e TORCHSIM_DUMP_PATH=/dump \ + ghcr.io/psal-postech/torchsim-ci:${GITHUB_SHA} python3 PyTorchSim/tests/test_conv2d.py + + test_matmul: + name: Run test_matmul.py + runs-on: self-hosted + needs: build + steps: + - name: Run test_matmul.py + run: | + echo "Running test_matmul.py" + docker run --rm \ + -v /tmp/torchsim-ci/${GITHUB_SHA}:/dump \ + -e TORCHSIM_DUMP_PATH=/dump \ + ghcr.io/psal-postech/torchsim-ci:${GITHUB_SHA} python3 PyTorchSim/tests/test_matmul.py + + test_reduce: + name: Run test_reduce.py + runs-on: self-hosted + needs: build + steps: + - name: Run test_reduce.py + run: | + echo "Running test_reduce.py" + docker run --rm \ + -v /tmp/torchsim-ci/${GITHUB_SHA}:/dump \ + -e TORCHSIM_DUMP_PATH=/dump \ + ghcr.io/psal-postech/torchsim-ci:${GITHUB_SHA} python3 PyTorchSim/tests/test_reduce.py + + test_softmax: + name: Run test_softmax.py + runs-on: self-hosted + needs: build + steps: + - name: Run test_softmax.py + run: | + echo "Running test_softmax.py" + docker run --rm \ + -v /tmp/torchsim-ci/${GITHUB_SHA}:/dump \ + -e TORCHSIM_DUMP_PATH=/dump \ + ghcr.io/psal-postech/torchsim-ci:${GITHUB_SHA} python3 PyTorchSim/tests/test_softmax.py + + test_transpose2D: + name: Run test_transpose2D.py + runs-on: self-hosted + needs: build + steps: + - name: Run test_transpose2D.py + run: | + echo "Running test_transpose2D.py" + docker run --rm \ + -v /tmp/torchsim-ci/${GITHUB_SHA}:/dump \ + -e TORCHSIM_DUMP_PATH=/dump \ + ghcr.io/psal-postech/torchsim-ci:${GITHUB_SHA} python3 PyTorchSim/tests/test_transpose2D.py + + test_view3D_2D: + name: Run test_view3D_2D.py + runs-on: self-hosted + needs: build + steps: + - name: Run test_view3D_2D.py + run: | + echo "Running test_view3D_2D.py" + docker run --rm \ + -v /tmp/torchsim-ci/${GITHUB_SHA}:/dump \ + -e TORCHSIM_DUMP_PATH=/dump \ + 
ghcr.io/psal-postech/torchsim-ci:${GITHUB_SHA} python3 PyTorchSim/tests/test_view3D_2D.py + + test_layernorm: + name: Run test_layernorm.py + runs-on: self-hosted + needs: build + steps: + - name: Run test_layernorm.py + run: | + echo "Running test_layernorm.py" + docker run --rm \ + -v /tmp/torchsim-ci/${GITHUB_SHA}:/dump \ + -e TORCHSIM_DUMP_PATH=/dump \ + ghcr.io/psal-postech/torchsim-ci:${GITHUB_SHA} python3 PyTorchSim/tests/test_layernorm.py + + test_mlp: + name: Run test_mlp.py + runs-on: self-hosted + needs: build + steps: + - name: Run test_mlp.py + run: | + echo "Running test_mlp.py" + docker run --rm \ + -v /tmp/torchsim-ci/${GITHUB_SHA}:/dump \ + -e TORCHSIM_DUMP_PATH=/dump \ + ghcr.io/psal-postech/torchsim-ci:${GITHUB_SHA} python3 PyTorchSim/tests/test_mlp.py + + test_resnet: + name: Run test_resnet.py + runs-on: self-hosted + needs: build + steps: + - name: Run test_resnet.py + run: | + echo "Running test_resnet.py" + docker run --rm \ + -v /tmp/torchsim-ci/${GITHUB_SHA}:/dump \ + -e TORCHSIM_DUMP_PATH=/dump \ + ghcr.io/psal-postech/torchsim-ci:${GITHUB_SHA} python3 PyTorchSim/tests/test_resnet.py + + test_transformer: + name: Run test_transformer.py + runs-on: self-hosted + needs: build + steps: + - name: Run test_transformer.py + run: | + echo "Running test_transformer.py" + docker run --rm \ + -v /tmp/torchsim-ci/${GITHUB_SHA}:/dump \ + -e TORCHSIM_DUMP_PATH=/dump \ + ghcr.io/psal-postech/torchsim-ci:${GITHUB_SHA} python3 PyTorchSim/tests/test_transformer.py + + test_transpose3D: + name: Run test_transpose3D.py + runs-on: self-hosted + needs: build + steps: + - name: Run test_transpose3D.py + run: | + echo "Running test_transpose3D.py" + docker run --rm \ + -v /tmp/torchsim-ci/${GITHUB_SHA}:/dump \ + -e TORCHSIM_DUMP_PATH=/dump \ + ghcr.io/psal-postech/torchsim-ci:${GITHUB_SHA} python3 PyTorchSim/tests/test_transpose3D.py + + test_sparsity: + name: Run test_sparsity.py + runs-on: self-hosted + needs: build + steps: + - name: Run test_sparsity.py + 
run: | + echo "Running test_sparsity.py" + docker run --rm \ + -v /tmp/torchsim-ci/${GITHUB_SHA}:/dump \ + -e TORCHSIM_DUMP_PATH=/dump \ + ghcr.io/psal-postech/torchsim-ci:${GITHUB_SHA} python3 PyTorchSim/tests/test_sparsity.py + + test_pool: + name: Run test_pool.py + runs-on: self-hosted + needs: build + steps: + - name: Run test_pool.py + run: | + echo "Running test_pool.py" + docker run --rm \ + -v /tmp/torchsim-ci/${GITHUB_SHA}:/dump \ + -e TORCHSIM_DUMP_PATH=/dump \ + ghcr.io/psal-postech/torchsim-ci:${GITHUB_SHA} python3 PyTorchSim/tests/test_pool.py + + test_perceptron: + name: Run test_perceptron.py + runs-on: self-hosted + needs: build + steps: + - name: Run test_single_perceptron.py + run: | + echo "Running test_single_perceptron.py" + docker run --rm \ + -v /tmp/torchsim-ci/${GITHUB_SHA}:/dump \ + -e TORCHSIM_DUMP_PATH=/dump \ + ghcr.io/psal-postech/torchsim-ci:${GITHUB_SHA} python3 PyTorchSim/tests/test_single_perceptron.py + + test_fusion: + name: Run test_fusion + runs-on: self-hosted + needs: build + steps: + - name: Run test_addmm_residual.py + run: | + echo "Running test_addmm_residual.py" + docker run --rm \ + -v /tmp/torchsim-ci/${GITHUB_SHA}:/dump \ + -e TORCHSIM_DUMP_PATH=/dump \ + ghcr.io/psal-postech/torchsim-ci:${GITHUB_SHA} python3 PyTorchSim/tests/Fusion/test_addmm_residual.py + - name: Run test_matmul_activation.py + run: | + echo "Running test_matmul_activation.py" + docker run --rm \ + -v /tmp/torchsim-ci/${GITHUB_SHA}:/dump \ + -e TORCHSIM_DUMP_PATH=/dump \ + ghcr.io/psal-postech/torchsim-ci:${GITHUB_SHA} python3 PyTorchSim/tests/Fusion/test_matmul_activation.py + - name: Run test_matmul_scalar.py + run: | + echo "Running test_matmul_scalar.py" + docker run --rm \ + -v /tmp/torchsim-ci/${GITHUB_SHA}:/dump \ + -e TORCHSIM_DUMP_PATH=/dump \ + ghcr.io/psal-postech/torchsim-ci:${GITHUB_SHA} python3 PyTorchSim/tests/Fusion/test_matmul_scalar.py + + test_moe: + name: Run test_moe + runs-on: self-hosted + needs: build + steps: + - name: Run 
test_moe.py + run: | + echo "Running test_moe.py" + docker run --rm \ + -v /tmp/torchsim-ci/${GITHUB_SHA}:/dump \ + -e TORCHSIM_DUMP_PATH=/dump \ + ghcr.io/psal-postech/torchsim-ci:${GITHUB_SHA} python3 PyTorchSim/tests/MoE/test_moe.py + + test_cleanup: + name: Clean test cases + runs-on: self-hosted + needs: [test_add, test_batchnorm, test_bmm, test_cnn, test_conv2d, + test_matmul, test_reduce, test_softmax, + test_transpose2D, test_view3D_2D, test_layernorm, + test_mlp, test_resnet, test_transformer, test_transpose3D, + test_sparsity, test_relu, test_pool, test_perceptron, + test_fusion, test_moe] + steps: + - name: Checkout code + uses: actions/checkout@v3 + - name: Clean test case + run: | + docker run --rm \ + -v /tmp/torchsim-ci/${GITHUB_SHA}:/dump \ + ghcr.io/psal-postech/torchsim-ci:${GITHUB_SHA} chown -R $(id -u):$(id -g) /dump diff --git a/AsmParser/tog_generator.py b/AsmParser/tog_generator.py index 9a39b782..1b5971e2 100644 --- a/AsmParser/tog_generator.py +++ b/AsmParser/tog_generator.py @@ -56,8 +56,9 @@ def decrease_depth_stack(self): def load_file(self, path): self.module = import_module_from_path(self.module_name, path) - self.raw_graph = self.module.graph - self.parse_graph() + if hasattr(self.module, "graph"): + self.raw_graph = self.module.graph + self.parse_graph() def _create_node(self, dump_data): node_id = dump_data["node_id"] @@ -193,17 +194,18 @@ def parse_graph(self): def generate_tile_graph(self, name="tile_graph", cycle_list=list, offset=int, vector_lane=int): node_list = list(self.node_dict.values())[1:] - node_list[0].set_parent([]) - for iter_node in self.node_dict.values(): - if isinstance(iter_node, compute_node): - if cycle_list: - iter_node.torchsim_cycle = cycle_list.pop(0) - else: - print("[TOGGen] Error compute cycle timing is missing...!") - iter_node.torchsim_cycle = 10 - # FIXME. 
- if iter_node.torchsim_compute_type == 1: - iter_node.torchsim_overlapping_cycle = iter_node.torchsim_cycle - offset + if len(node_list): + node_list[0].set_parent([]) + for iter_node in self.node_dict.values(): + if isinstance(iter_node, compute_node): + if cycle_list: + iter_node.torchsim_cycle = cycle_list.pop(0) + else: + print("[TOGGen] Error compute cycle timing is missing...!") + iter_node.torchsim_cycle = 10 + # FIXME. + if iter_node.torchsim_compute_type == 1: + iter_node.torchsim_overlapping_cycle = iter_node.torchsim_cycle - offset origin_info = "_".join(map(str, self.origins)) onnx_node_list = [node.to_onnx() for node in node_list] # Exclude root node diff --git a/PyTorchSimBackend/src/TileGraphParser.cc b/PyTorchSimBackend/src/TileGraphParser.cc index 4a88f4b9..b9ea2b08 100644 --- a/PyTorchSimBackend/src/TileGraphParser.cc +++ b/PyTorchSimBackend/src/TileGraphParser.cc @@ -23,6 +23,26 @@ uint32_t calculateAddress(const std::vector& loop_size, const std::vec return address; } + +int getLoopIndexValue(const std::map& iter, const std::string& loop_idx) { + // Check if loop_idx starts with "c" + if (!loop_idx.empty() && loop_idx[0] == 'c') { + // Extract substring after 'c' and convert to integer + const char* numberPart = loop_idx.c_str() + 1; // Skip the first character 'c' + int convertedValue = std::atoi(numberPart); + return convertedValue; + } + + // If loop_idx does not start with 'c', check in the map + auto it = iter.find(loop_idx); + if (it != iter.end()) { + return it->second; + } + + // If loop_idx is not found, throw an exception + throw std::runtime_error("Key not found in map and does not start with 'c': " + loop_idx); +} + std::vector calc_output_idx(TileGraphParser* tog_parser, std::map& iter) { // Extract outer loop // Extract inner loop @@ -272,7 +292,7 @@ std::vector> TileLoopNode::get_tiles_from_iter(TileGraphPa int nr_inner_loop = 0; auto& loop_idx_list = mem_node->get_loop_idx_list(); for (auto loop_idx: loop_idx_list) { - auto 
iter_value = iter.at(loop_idx); + auto iter_value = getLoopIndexValue(iter, loop_idx); iter_list.push_back(iter_value); loop_size_list.push_back(tog_parser->get_loop_size(loop_idx)); if (tog_parser->get_loop_type(loop_idx)==LoopType::INNER_LOOP) @@ -283,7 +303,7 @@ std::vector> TileLoopNode::get_tiles_from_iter(TileGraphPa loop_idx != loop_idx_list.end() - nr_inner_loop; ++loop_idx) { // Check loop type and process if (tog_parser->get_loop_type(*loop_idx)==LoopType::ACCUMULATION_LOOP) { - auto iter_value = iter.at(*loop_idx); + auto iter_value = getLoopIndexValue(iter, *loop_idx); tag_list.push_back(iter_value); } } @@ -292,7 +312,7 @@ std::vector> TileLoopNode::get_tiles_from_iter(TileGraphPa loop_idx != loop_idx_list.end(); ++loop_idx) { if (tog_parser->get_loop_type(*loop_idx)==LoopType::PARALLEL_LOOP) { uint32_t step = (uint32_t)tog_parser->get_loop_step(*loop_idx); - auto iter_value = iter.at(*loop_idx) / step; + auto iter_value = getLoopIndexValue(iter, *loop_idx) / step; outer_loop_idx.push_back(iter_value); outer_loop_size.push_back(tog_parser->get_loop_size(*loop_idx)); } @@ -302,7 +322,7 @@ std::vector> TileLoopNode::get_tiles_from_iter(TileGraphPa if (iter.find(loop_idx) == iter.end()) tag_list.push_back(0); else { - auto iter_value = iter.at(loop_idx); + auto iter_value = getLoopIndexValue(iter, loop_idx); tag_list.push_back(iter_value); } } @@ -342,14 +362,14 @@ std::vector> TileLoopNode::get_tiles_from_iter(TileGraphPa int nr_inner_loop = 0; auto& loop_idx_list = mem_node->get_loop_idx_list(); for (auto loop_idx: loop_idx_list) { - auto iter_value = iter.at(loop_idx); + auto iter_value = getLoopIndexValue(iter, loop_idx); iter_list.push_back(iter_value); loop_size_list.push_back(tog_parser->get_loop_size(loop_idx)); if (tog_parser->get_loop_type(loop_idx)==LoopType::INNER_LOOP) nr_inner_loop++; if (tog_parser->get_loop_type(loop_idx)==LoopType::PARALLEL_LOOP) { uint32_t step = (uint32_t) tog_parser->get_loop_step(loop_idx); - auto iter_value = 
iter.at(loop_idx) / step; + auto iter_value = getLoopIndexValue(iter, loop_idx) / step; outer_loop_idx.push_back(iter_value); outer_loop_size.push_back(tog_parser->get_loop_size(loop_idx)/ step); } @@ -400,7 +420,7 @@ std::vector> TileLoopNode::get_tiles_from_iter(TileGraphPa if (iter.find(loop_idx) == iter.end()) tag_list.push_back(0); else { - auto iter_value = iter.at(loop_idx) * inner_step; + auto iter_value = getLoopIndexValue(iter, loop_idx) * inner_step; tag_list.push_back(iter_value); } } @@ -618,13 +638,13 @@ TileGraphParser::TileGraphParser(std::string onnx_path, json& attribute_json) { tile->print_node(); } + _tile_graph = std::make_unique(TileGraph(onnx_path, graph_name)); /* Generate subgraph */ if (_loop_nodes.empty()) { - spdlog::error("[TileGraphParser] No loop found..."); - exit(EXIT_FAILURE); + spdlog::warn("[TileGraphParser] Null Kernel \"{}\"", onnx_path); + return; } - _tile_graph = std::make_unique(TileGraph(onnx_path, graph_name)); int last_outer_idx = -1; /* Extract outer loop */ for (int i=0;i<_loop_nodes.size();i++) { diff --git a/PyTorchSimBackend/src/main.cc b/PyTorchSimBackend/src/main.cc index e32b3eed..67f19d6d 100644 --- a/PyTorchSimBackend/src/main.cc +++ b/PyTorchSimBackend/src/main.cc @@ -95,6 +95,8 @@ void interactive_mode(Simulator* simulator) { if (isDryRun) std::cout << "[" << simulator->get_core_cycle() << "] BackendSim> "; } + if (simulator->get_core_cycle()==0) + simulator->until(0); simulator->print_core_stat(); } @@ -148,6 +150,8 @@ int main(int argc, char** argv) { /* launch kernels */ launchKernel(simulator, onnx_path, attribute_path); simulator->run_simulator(); + if (simulator->get_core_cycle()==0) + simulator->until(1); simulator->print_core_stat(); } else if (execution_mode.compare("interactive") == 0) { /* Get onnx_path, attribute from user input, request_time */ diff --git a/PyTorchSimFrontend/llvm/llvm_codegen_backend.py b/PyTorchSimFrontend/llvm/llvm_codegen_backend.py index e8daa889..6951b5bd 100644 --- 
a/PyTorchSimFrontend/llvm/llvm_codegen_backend.py +++ b/PyTorchSimFrontend/llvm/llvm_codegen_backend.py @@ -212,7 +212,7 @@ def maximum(operand1, operand2, tile_size=4): @staticmethod def relu(x, tile_size=4): - return ops.maximum(x, ops.constant(0.0, torch.int32)) + return ops.maximum(x, ops.constant(0.0, "f32")) SYMPY_TO_LLVM = { sympy.core.mul.Mul: "mul", diff --git a/PyTorchSimFrontend/mlir/mlir_bmm_template.py b/PyTorchSimFrontend/mlir/mlir_bmm_template.py index 41a2e0cd..cd99d52e 100644 --- a/PyTorchSimFrontend/mlir/mlir_bmm_template.py +++ b/PyTorchSimFrontend/mlir/mlir_bmm_template.py @@ -23,8 +23,8 @@ %c_mvin3 = arith.constant 14 : index{% endif %} %c_mvout = arith.constant 3 : index %c_set = arith.constant 2 : index - %c{{ TILE_K * 2 + 0}} = arith.constant {{ TILE_K * 2 + 0}} : index{% if Bias_rank == 1 %} - %c0 = arith.constant 0 : index{% endif %}{% if X_transposed %} + %c{{ TILE_K * 2 + 0}} = arith.constant {{ TILE_K * 2 + 0}} : index + %c0 = arith.constant 0 : index{% if X_transposed %} %x_chunk = arith.constant {{ kernel.vector_lane * 2 + 0 }} : index{% endif %}{% if W_transposed %} %w_chunk = arith.constant {{ TILE_K * 2 + 0 }} : index{% endif %} %M = arith.constant {{ M }} : index @@ -52,16 +52,16 @@ affine.for %t_k = 0 to {{ K }} step {{ TILE_K }} { %index0 = affine.apply #map0(%b, %t_m, %t_k) %index1 = affine.apply #map1(%b, %t_k, %t_n) - affine.dma_start %X[%index0], %X_buffer[0, 0], %tag[0], %c_mvin, + affine.dma_start %X[%index0], %X_buffer[%c0, %c0], %tag[0], %c_mvin, {%- if X_transposed -%} %M, %x_chunk {%- else -%} %K, %c_set {%- endif -%} : memref<{{ B * M * K }}xf32>, memref<{{ TILE_M }}x{{ TILE_K }}xf32, 1>, memref<1xi32> { subtile_size=[{{ kernel.vector_lane }}, {{ TILE_K }}], async=1{% if X_transposed %}, transpose=1{% endif %} } - affine.dma_start %W[%index1], %W_buffer[0, 0], %tag[0], %c_mvin2, + affine.dma_start %W[%index1], %W_buffer[%c0, %c0], %tag[0], %c_mvin2, {%- if W_transposed -%} %K, %w_chunk {%- else -%} %N, %c_set {%- 
endif -%} : memref<{{ B * K * N }}xf32>, memref<{{ TILE_K }}x{{ TILE_N }}xf32, 1>, memref<1xi32> { subtile_size=[{{ TILE_K }}, {{ kernel.vector_lane }}], async=1{% if W_transposed %}, transpose=1{% endif %} } linalg.matmul ins(%X_buffer, %W_buffer : memref<{{ TILE_M }}x{{ TILE_K }}x{{ DATA_STYPE }}, 1>, memref<{{ TILE_K }}x{{ TILE_N }}x{{ DATA_STYPE }}, 1>) outs(%Y_buffer : memref<{{ TILE_M }}x{{ TILE_N }}x{{ DATA_STYPE }}, 1>) } { accumulation_loop=true } - affine.dma_start %Y_buffer[0, 0], %Y[%index2], %tag[0], %c_mvout, %N, %c_set : memref<{{ TILE_M }}x{{ TILE_N }}xf32, 1>, memref<{{ B * M * N }}xf32>, memref<1xi32> { subtile_size=[{{ kernel.vector_lane }}, {{ kernel.vector_lane }}], async=1 } + affine.dma_start %Y_buffer[%c0, %c0], %Y[%index2], %tag[0], %c_mvout, %N, %c_set : memref<{{ TILE_M }}x{{ TILE_N }}xf32, 1>, memref<{{ B * M * N }}xf32>, memref<1xi32> { async=1 } } { outer_loop=true } } { outer_loop=true } } { outer_loop=true } diff --git a/PyTorchSimFrontend/mlir/mlir_caller_codegen.py b/PyTorchSimFrontend/mlir/mlir_caller_codegen.py index b6a2d7d7..92f250df 100644 --- a/PyTorchSimFrontend/mlir/mlir_caller_codegen.py +++ b/PyTorchSimFrontend/mlir/mlir_caller_codegen.py @@ -28,7 +28,7 @@ def load_arg(self): if self.is_in_arg(arg_attribute[0]): argv_idx = self.get_argv_idx() if arg_name not in self.load_args else self.load_args[arg_name] self.load_args[arg_name] = argv_idx - self.writeline(f'if(load_arg({arg_name}, sizeof({arg_name}), argv[{argv_idx}]) == -1){self.open_bracket}') + self.writeline(f'if(load_arg(c_{arg_name}, sizeof(c_{arg_name}), argv[{argv_idx}]) == -1){self.open_bracket}') with self.code.indent(): self.writeline(f'return -1{self.ending}') self.writeline(self.closed_bracket) @@ -37,7 +37,7 @@ def dump_arg(self): for arg_name, arg_attribute in self.arg_attributes: if self.is_out_arg(arg_attribute[0]): argv_idx = self.get_argv_idx() if not self.is_inout_arg(arg_attribute[0]) else self.load_args[arg_name] - 
self.writeline(f'if(dump_arg({arg_name}, sizeof({arg_name}), argv[{argv_idx}]) == -1){self.open_bracket}') + self.writeline(f'if(dump_arg(c_{arg_name}, sizeof(c_{arg_name}), argv[{argv_idx}]) == -1){self.open_bracket}') with self.code.indent(): self.writeline(f'return -1{self.ending}') self.writeline(self.closed_bracket) @@ -53,7 +53,7 @@ def generate_args_define(self): for arg_name, (_, arg_type, arg_size) in self.arg_attributes: if not arg_name in name_set: if self.validation: - self.writeline(f'{DTYPE_TO_C[arg_type]} {arg_name}[{arg_size}]{self.ending}') + self.writeline(f'{DTYPE_TO_C[arg_type]} c_{arg_name}[{arg_size}]{self.ending}') else: if torch.is_floating_point(torch.tensor([], dtype=arg_type)): bits = torch.finfo(arg_type).bits @@ -61,7 +61,7 @@ def generate_args_define(self): bits = 8 else: bits = torch.iinfo(arg_type).bits - self.writeline(f'{DTYPE_TO_C[arg_type]}* {arg_name} = malloc({arg_size * bits // 8}){self.ending}') + self.writeline(f'{DTYPE_TO_C[arg_type]}* c_{arg_name} = malloc({arg_size * bits // 8}){self.ending}') name_set.add(arg_name) self.writeline(self.newline) @@ -77,7 +77,7 @@ def generate_main(self): else: self.generate_args_define() - func_arguments = [f"{arg_name}, {arg_name}, 0, {arg_shape}, 1" if arg_type != torch.bool else f"{arg_name}, {arg_name}, 0, {(arg_shape + 7) // 8}, 1" for arg_name, (_, arg_type, arg_shape) in self.arg_attributes] + func_arguments = [f"c_{arg_name}, c_{arg_name}, 0, {arg_shape}, 1" if arg_type != torch.bool else f"c_{arg_name}, c_{arg_name}, 0, {(arg_shape + 7) // 8}, 1" for arg_name, (_, arg_type, arg_shape) in self.arg_attributes] self.writeline(f"wrapper_{self.kernel_name}({', '.join(func_arguments)}){self.ending}{self.newline}") if self.validation: diff --git a/PyTorchSimFrontend/mlir/mlir_codegen_backend.py b/PyTorchSimFrontend/mlir/mlir_codegen_backend.py index faa08889..3d65aa53 100644 --- a/PyTorchSimFrontend/mlir/mlir_codegen_backend.py +++ b/PyTorchSimFrontend/mlir/mlir_codegen_backend.py @@ 
-107,53 +107,140 @@ def write_header(self): ) class ExtensionOverrides(common.OpOverrides): + # Binary element wise operations @staticmethod - def add(operand1, operand2, tile_size=16, dtype="f32"): - shape = f"vector<{tile_size}x{dtype}>" if tile_size > 1 else dtype - return f'arith.add{dtype[0]} %{operand1}, %{operand2} : {shape}' + def custom_cast(operand, target_type, *args, var_info=None): + dtype = var_info[operand][1] + if dtype == "index": + ret = ops.index_cast(operand, target_type, var_info=var_info) + else: + ret = ops.to_dtype(operand, target_type, var_info=var_info) + return ret, var_info[ret] @staticmethod - def sub(operand1, operand2, tile_size=16, dtype="f32"): - shape = f"vector<{tile_size}x{dtype}>" if tile_size > 1 else dtype - return f'arith.sub{dtype[0]} %{operand1}, %{operand2} : {shape}' + def binary_elementwise_common(operand1, operand2, var_info): + op_type1 = var_info[operand1] + op_type2 = var_info[operand2] + # Tile size check + if op_type1[0] != op_type2[0]: + # Try to broad cast + lhs_tile_size, lhs_dtype = op_type1 + rhs_tile_size, rhs_dtype = op_type2 + if lhs_tile_size > rhs_tile_size: + operand2 = ops.broadcast(operand2, operand1, var_info=var_info) + op_type2 = var_info[operand2] + elif lhs_tile_size < rhs_tile_size: + operand1 = ops.broadcast(operand1, operand2, var_info=var_info) + op_type1 = var_info[operand1] + + # Data type check + if op_type1[1] != op_type2[1]: + if op_type1[1] == "index" or op_type1 == "index": + if op_type1[1] == "index": + operand1 = ops.index_cast(operand1, op_type2[1], var_info) + op_type1 = var_info[operand1] + if op_type2[1] == "index": + operand2 = ops.index_cast(operand2, op_type1[1], var_info) + op_type2 = var_info[operand2] + elif op_type1[1][0] == "i" and op_type2[1][0] == "f": + operand1 = ops.to_dtype(operand1, op_type2[1], var_info) + op_type1 = var_info[operand1] + elif op_type1[1][0] == "f" and op_type2[1][0] == "i": + operand2 = ops.to_dtype(operand2, op_type1[1], var_info) + op_type2 = 
var_info[operand2] + else: + raise NotImplementedError("Unsupported type converting") + + # Updated var info + tile_size = op_type1[0] + ret_type = op_type1[1] + return tile_size, ret_type, operand1, operand2 @staticmethod - def mul(operand1, operand2, tile_size=16, dtype="f32"): - shape = f"vector<{tile_size}x{dtype}>" if tile_size > 1 else dtype - return f'arith.mul{dtype[0]} %{operand1}, %{operand2} : {shape}' + def add(operand1, operand2, *args, var_info=None): + tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2, var_info) + shape = f"vector<{tile_size}x{ret_type}>" if tile_size > 1 else ret_type + opcode = f'arith.add{ret_type[0]}' + return f'{opcode} %{operand1}, %{operand2} : {shape}', [tile_size, ret_type] @staticmethod - def div(operand1, operand2, tile_size=16, dtype="f32"): - shape = f"vector<{tile_size}x{dtype}>" if tile_size > 1 else dtype - return f'arith.div{dtype[0]} %{operand1}, %{operand2} : {shape}' + def sub(operand1, operand2, *args, var_info=None): + tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2, var_info) + shape = f"vector<{tile_size}x{ret_type}>" if tile_size > 1 else ret_type + opcode = f'arith.sub{ret_type[0]}' + return f'{opcode} %{operand1}, %{operand2} : {shape}', [tile_size, ret_type] @staticmethod - def truediv(operand1, operand2, tile_size=16, dtype="f32"): - shape = f"vector<{tile_size}x{dtype}>" if tile_size > 1 else dtype - return f'arith.div{dtype[0]} %{operand1}, %{operand2} : {shape}' + def mul(operand1, operand2, *args, var_info=None): + tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2, var_info) + shape = f"vector<{tile_size}x{ret_type}>" if tile_size > 1 else ret_type + opcode = f'arith.mul{ret_type[0]}' + return f'{opcode} %{operand1}, %{operand2} : {shape}', [tile_size, ret_type] @staticmethod - def to_dtype(x, dst_type, src_dtype=None, tile_size=16, 
dtype="f32"): - mlir_dtype = mlir_common.DTYPE_TO_MLIR[dst_type] - src_mlir_dtype = mlir_common.DTYPE_TO_MLIR[src_dtype] + def div(operand1, operand2, *args, var_info=None): + tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2, var_info) + shape = f"vector<{tile_size}x{ret_type}>" if tile_size > 1 else ret_type + if ret_type[0] == "f": + opcode = f'arith.divf' + else: + opcode = f'arith.divui' + return f'{opcode} %{operand1}, %{operand2} : {shape}', [tile_size, ret_type] - dst_bits = 1 if dst_type == torch.bool else torch.finfo(dst_type).bits if dst_type.is_floating_point else torch.iinfo(dst_type).bits - src_bits = 1 if src_dtype == torch.bool else torch.finfo(src_dtype).bits if src_dtype.is_floating_point else torch.iinfo(src_dtype).bits - shape = f"vector<{tile_size}x{mlir_dtype}>" if tile_size > 1 else mlir_dtype + @staticmethod + def truediv(operand1, operand2, *args, var_info=None): + tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2, var_info) + shape = f"vector<{tile_size}x{ret_type}>" if tile_size > 1 else ret_type + if ret_type[0] == "f": + opcode = f'arith.divf' + else: + opcode = f'arith.divui' + return f'{opcode} %{operand1}, %{operand2} : {shape}', [tile_size, ret_type] + + @staticmethod + def minimum(operand1, operand2, *args, var_info=None): + tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2, var_info) + shape = f"vector<{tile_size}x{ret_type}>" if tile_size > 1 else ret_type + if ret_type[0] == "f": + opcode = f'arith.minimumf' + else: + opcode = f'arith.minimumui' + return f'{opcode} %{operand1}, %{operand2} : {shape}', [tile_size, ret_type] + + @staticmethod + def maximum(operand1, operand2, *args, var_info=None): + tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2, var_info) + shape = f"vector<{tile_size}x{ret_type}>" if 
tile_size > 1 else ret_type + if ret_type[0] == "f": + opcode = f'arith.maximumf' + else: + opcode = f'arith.maximumui' + return f'{opcode} %{operand1}, %{operand2} : {shape}', [tile_size, ret_type] + + @staticmethod + def to_dtype(operand, dst_mlir_dtype, *args, var_info=None): + src_mlir_dtype = var_info[operand][1] + tile_size = var_info[operand][0] + + dst_bits = int(dst_mlir_dtype[1:]) + src_bits = int(src_mlir_dtype[1:]) + shape = f"vector<{tile_size}x{dst_mlir_dtype}>" if tile_size > 1 else dst_mlir_dtype src_shape = f"vector<{tile_size}x{src_mlir_dtype}>" if tile_size > 1 else src_mlir_dtype - if dst_type.is_floating_point and not src_dtype.is_floating_point: + if dst_mlir_dtype[0] == "i" and src_mlir_dtype[0] == "f": raise NotImplementedError("floating point to integer conversion") - elif not dst_type.is_floating_point and src_dtype.is_floating_point: + if dst_mlir_dtype[0] == "f" and src_mlir_dtype[0] == "i": raise NotImplementedError("integer to floating point conversion") else: if dst_bits > src_bits: - return f"arith.extui %{x} : {src_shape} to {shape}" + return f"arith.extui %{operand} : {src_shape} to {shape}" elif dst_bits < src_bits: - return f"arith.trunc %{x} : {src_shape} to {shape}" + return f"arith.trunc %{operand} : {src_shape} to {shape}" @staticmethod - def constant(value, src_type, tile_size=16, dtype="f32"): - src_type = mlir_common.DTYPE_TO_MLIR[src_type] + def constant(value, src_type, *args, var_info=None): + if isinstance(src_type, torch.dtype): + src_type = mlir_common.DTYPE_TO_MLIR[src_type] + # if value represented by e notation, convert to float (ex 1e-3 -> 1.0e-3) if "e" in str(value): value = float(value) @@ -161,89 +248,354 @@ def constant(value, src_type, tile_size=16, dtype="f32"): value = format(value, ".20f") if src_type[0] == "i": value = int(value) - return f'arith.constant {value} : {src_type}' + return f'arith.constant {value} : {src_type}', [1, src_type] + # transcendental functions @staticmethod - def exp(operand, 
tile_size=16, dtype="f32"): + def exp(operand, *args, var_info=None): + op_type = var_info[operand] + tile_size = op_type[0] + dtype = op_type[1] + shape = f"vector<{tile_size}x{dtype}>" if tile_size > 1 else dtype - return f'math.exp %{operand} : {shape}' + return f'math.exp %{operand} : {shape}', [tile_size, dtype] @staticmethod - def maximum(operand1, operand2, tile_size=16, dtype="f32"): + def sqrt(operand, *args, var_info=None): + op_type = var_info[operand] + tile_size = op_type[0] + dtype = op_type[1] + + # Type check & auto cast + if dtype[0] != "f": + operand, dtype = ops.to_dtype(operand, "f32", var_info=var_info) + var_info[operand] = dtype + shape = f"vector<{tile_size}x{dtype}>" if tile_size > 1 else dtype - return f'arith.maximum{dtype[0]} %{operand1}, %{operand2} : {shape}' + return f'math.sqrt %{operand} : {shape}', [tile_size, dtype] @staticmethod - def sqrt(x, tile_size=16, dtype="f32"): + def rsqrt(operand, *args, var_info=None): + op_type = var_info[operand] + tile_size = op_type[0] + dtype = op_type[1] + + # Type check & auto cast + if dtype[0] != "f": + operand, dtype = ops.to_dtype(operand, "f32", var_info=var_info) + var_info[operand] = dtype + shape = f"vector<{tile_size}x{dtype}>" if tile_size > 1 else dtype - return f'math.sqrt %{x} : {shape}' + return f'math.rsqrt %{operand} : {shape}', [tile_size, dtype] @staticmethod - def ne(operand1, operand2, tile_size=16, dtype="f32"): - shape = f"vector<{tile_size}x{dtype}>" if tile_size > 1 else "i1" - return f'arith.cmp{dtype[0]} one, %{operand1}, %{operand2} : {shape}' + def pow(operand1, operand2, *args, var_info=None): + op_type1 = var_info[operand1] + op_type2 = var_info[operand2] + + # Type check & auto cast + if op_type1[1][0] != "f": + operand1, dtype = ops.to_dtype(operand1, "f32", var_info=var_info) + var_info[operand1] = dtype + + # Type check & auto cast + if op_type2[1][0] != "f": + operand2, dtype = ops.to_dtype(operand2, "f32", var_info=var_info) + var_info[operand2] = dtype + + 
op_type1 = var_info[operand1] + tile_size = op_type1[0] + dtype = op_type1[1] + + shape = f"vector<{tile_size}x{dtype}>" if tile_size > 1 else dtype + return f"math.pow{dtype[0]} %{operand1}, %{operand2} : {shape}", [] @staticmethod - def lt(operand1, operand2, tile_size=16, dtype="f32"): - shape = f"vector<{tile_size}x{dtype}>" if tile_size > 1 else "i1" - return f'arith.cmp{dtype[0]} olt, %{operand1}, %{operand2} : {shape}' + def log(operand, *args, var_info=None): + op_type = var_info[operand] + tile_size = op_type[0] + dtype = op_type[1] + + # Type check & auto cast + if dtype[0] != "f": + operand, dtype = ops.to_dtype(operand, "f32", var_info=var_info) + var_info[operand] = dtype + + shape = f"vector<{tile_size}x{dtype}>" if tile_size > 1 else dtype + return f'math.log %{operand} : {shape}', [tile_size, dtype] @staticmethod - def gt(operand1, operand2, tile_size=16, dtype="f32"): - shape = f"vector<{tile_size}x{dtype}>" if tile_size > 1 else "i1" - return f'arith.cmp{dtype[0]} ogt, %{operand1}, %{operand2} : {shape}' + def reciprocal(operand, *args, var_info): + op_type = var_info[operand] + tile_size = op_type[0] + dtype = op_type[1] + # Type check & auto cast + if dtype[0] != "f": + operand, dtype = ops.to_dtype(operand, "f32", var_info=var_info) + var_info[operand] = dtype + + return ops.div(ops.constant(1.0, dtype), operand), [tile_size, dtype] + + # Logical operations @staticmethod - def le(operand1, operand2, tile_size=16, dtype="f32"): + def neg(operand, *args, var_info=None): + op_type = var_info[operand] + tile_size = op_type[0] + dtype = op_type[1] + + # Type check & auto cast + if dtype[0] != "f": + operand, dtype = ops.to_dtype(operand, "f32", var_info=var_info) + var_info[operand] = dtype + shape = f"vector<{tile_size}x{dtype}>" if tile_size > 1 else dtype - return f'arith.cmp{dtype[0]} ole, %{operand1}, %{operand2} : {shape}' + return f'arith.negf %{operand} : {shape}', [tile_size, dtype] @staticmethod - def relu(x, tile_size=16, dtype=None): - 
return ops.maximum(x, ops.constant(0.0, torch.float32)) + def eq(operand1, operand2, *args, var_info=None): + tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2, var_info) + if ret_type[0] == "f": + op_type = "arith.cmpf" + attribute = "oeq" + elif ret_type[0] == "i": + op_type = "arith.cmpi" + attribute = "eq" + else: + raise ValueError(f"Unsupported data type for 'eq' operation: {ret_type}") + + shape = f"vector<{tile_size}x{ret_type}>" if tile_size > 1 else ret_type + return f'{op_type} {attribute}, %{operand1}, %{operand2} : {shape}', [tile_size, "i1"] @staticmethod - def sigmoid(x, tile_size=16, dtype=None): - one = ops.constant(1, torch.float32) - return ops.truediv(one, ops.add(one, ops.exp(ops.neg(x)))) + def ne(operand1, operand2, *args, var_info=None): + tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2, var_info) + if ret_type[0] == "f": + op_type = "arith.cmpf" + attribute = "one" + elif ret_type[0] == "i": + op_type = "arith.cmpi" + attribute = "sne" + else: + raise ValueError(f"Unsupported data type for 'ne' operation: {ret_type}") + + shape = f"vector<{tile_size}x{ret_type}>" if tile_size > 1 else ret_type + return f'{op_type} {attribute}, %{operand1}, %{operand2} : {shape}', [tile_size, "i1"] @staticmethod - def neg(x, tile_size=16, dtype="f32"): - shape = f"vector<{tile_size}x{dtype}>" if tile_size > 1 else dtype - return f'arith.neg{dtype[0]} %{x} : {shape}' + def lt(operand1, operand2, *args, var_info=None): + tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2, var_info) + if ret_type[0] == "f": + op_type = "arith.cmpf" + attribute = "olt" + elif ret_type[0] == "i": + op_type = "arith.cmpi" + attribute = "slt" + else: + raise ValueError(f"Unsupported data type for 'lt' operation: {ret_type}") + + shape = f"vector<{tile_size}x{ret_type}>" if tile_size > 1 else ret_type + return 
f'{op_type} {attribute}, %{operand1}, %{operand2} : {shape}', [tile_size, "i1"] @staticmethod - def where(condition, x, y, tile_size=16, dtype="f32"): - shape = f"vector<{tile_size}x{dtype}>" if tile_size > 1 else dtype - cond_shape = f"vector<{tile_size}xi1>," if tile_size > 1 else "" - return f"arith.select %{condition}, %{x}, %{y} : {cond_shape} {shape}" + def gt(operand1, operand2, *args, var_info=None): + tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2, var_info) + if ret_type[0] == "f": + op_type = "arith.cmpf" + attribute = "ogt" + elif ret_type[0] == "i": + op_type = "arith.cmpi" + attribute = "sgt" + else: + raise ValueError(f"Unsupported data type for 'gt' operation: {ret_type}") + + shape = f"vector<{tile_size}x{ret_type}>" if tile_size > 1 else ret_type + return f'{op_type} {attribute}, %{operand1}, %{operand2} : {shape}', [tile_size, "i1"] @staticmethod - def logical_not(operand, tile_size=16, dtype="f32"): - tile_size=16 - shape = f"vector<{tile_size}x{dtype}>" if tile_size > 1 else dtype - result_shape = f"vector<{tile_size}xi1>" if tile_size > 1 else "i1" + def le(operand1, operand2, *args, var_info=None): + tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2, var_info) + if ret_type[0] == "f": + op_type = "arith.cmpf" + attribute = "ole" + elif ret_type[0] == "i": + op_type = "arith.cmpi" + attribute = "sle" + else: + raise ValueError(f"Unsupported data type for 'le' operation: {ret_type}") + + shape = f"vector<{tile_size}x{ret_type}>" if tile_size > 1 else ret_type + return f'{op_type} {attribute}, %{operand1}, %{operand2} : {shape}', [tile_size, "i1"] + + @staticmethod + def ge(operand1, operand2, *args, var_info=None): + tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2, var_info) + if ret_type[0] == "f": + op_type = "arith.cmpf" + attribute = "oge" + elif ret_type[0] == "i": + 
op_type = "arith.cmpi" + attribute = "sge" + else: + raise ValueError(f"Unsupported data type for 'ne' operation: {ret_type}") + + shape = f"vector<{tile_size}x{ret_type}>" if tile_size > 1 else ret_type + return f'{op_type} {attribute}, %{operand1}, %{operand2} : {shape}', [tile_size, "i1"] + + @staticmethod + def and_(operand1, operand2, *args, var_info=None): + op_type1 = var_info[operand1] + op_type2 = var_info[operand2] + + # Type check & auto cast + if op_type1[1][0] != "i": + operand1, dtype = ops.to_dtype(operand1, "i32", var_info=var_info) + var_info[operand1] = dtype + + # Type check & auto cast + if op_type2[1][0] != "i": + operand1, dtype = ops.to_dtype(operand1, "i32", var_info=var_info) + var_info[operand2] = dtype + + ret_type = op_type1[1] + tile_size = op_type1[0] + + shape = f"vector<{tile_size}x{ret_type}>" if tile_size > 1 else ret_type + return f'arith.andi %{operand1}, %{operand2} : {shape}', [tile_size, ret_type] + + @staticmethod + def or_(operand1, operand2, *args, var_info=None): + op_type1 = var_info[operand1] + op_type2 = var_info[operand2] + + # Type check & auto cast + if op_type1[1][0] != "i": + operand1, dtype = ops.to_dtype(operand1, "i32", var_info=var_info) + var_info[operand1] = dtype + + # Type check & auto cast + if op_type2[1][0] != "i": + operand1, dtype = ops.to_dtype(operand1, "i32", var_info=var_info) + var_info[operand2] = dtype + + ret_type = op_type1[1] + tile_size = op_type1[0] + + shape = f"vector<{tile_size}x{ret_type}>" if tile_size > 1 else ret_type + return f'arith.ori %{operand1}, %{operand2} : {shape}', [tile_size, ret_type] + + @staticmethod + def xor(operand1, operand2, *args, var_info=None): + op_type1 = var_info[operand1] + op_type2 = var_info[operand2] + + # Type check & auto cast + if op_type1[1][0] != "i": + operand1, dtype = ops.to_dtype(operand1, "i32", var_info=var_info) + var_info[operand1] = dtype + + # Type check & auto cast + if op_type2[1][0] != "i": + operand1, dtype = ops.to_dtype(operand1, 
"i32", var_info=var_info) + var_info[operand2] = dtype + + ret_type = op_type1[1] + tile_size = op_type1[0] + + shape = f"vector<{tile_size}x{ret_type}>" if tile_size > 1 else ret_type + return f'arith.xori %{operand1}, %{operand2} : {shape}', [tile_size, ret_type] + + + @staticmethod + def logical_and(operand, *args, var_info=None): + raise NotImplementedError("logical_and") + + @staticmethod + def logical_not(operand, *args, var_info=None): raise NotImplementedError("logical_not") - return f"arith.cmp{dtype[0]} oeq, %{operand}, %zero_vec{tile_size} : {shape} -> {result_shape}" @staticmethod - def rsqrt(x, tile_size=16, dtype="f32"): - shape = f"vector<{tile_size}x{dtype}>" if tile_size > 1 else dtype - return f'math.rsqrt %{x} : {shape}' + def logical_or(operand, *args, var_info=None): + raise NotImplementedError("logical_not") @staticmethod - def pow(a, b, tile_size=16, dtype="f32"): - shape = f"vector<{tile_size}x{dtype}>" if tile_size > 1 else dtype - return f"math.pow{dtype[0]} %{a}, %{b} : {shape}" + def logical_xor(operand, *args, var_info=None): + raise NotImplementedError("logical_not") @staticmethod - def log(x, tile_size=16, dtype="f32"): - shape = f"vector<{tile_size}x{dtype}>" if tile_size > 1 else dtype - return f'math.log %{x} : {shape}' + def relu(operand, *args, var_info=None): + op_type = var_info[operand] + tile_size = op_type[0] + ret_type = "f32" + return ops.maximum(operand, ops.constant(0.0, "f32")), [tile_size, ret_type] + + @staticmethod + def sigmoid(operand, *args, var_info=None): + op_type = var_info[operand] + tile_size = op_type[0] + ret_type = "f32" + one = ops.constant(1, "f32") + return ops.truediv(one, ops.add(one, ops.exp(ops.neg(operand)))), [tile_size, ret_type] + + # Special operaitons + @staticmethod + def where(condition, operand1, operand2, *args, var_info=None): + tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2, var_info) + cond_type = var_info[condition] + if 
cond_type[0] < tile_size: + condition = ops.broadcast(condition, operand1, var_info=var_info) + elif cond_type[0] > tile_size: + operand1 = ops.broadcast(operand1, condition, var_info=var_info) + operand2 = ops.broadcast(operand2, condition, var_info=var_info) + tile_size, ret_type = var_info[operand1] + + shape = f"vector<{tile_size}x{ret_type}>" if tile_size > 1 else ret_type + cond_shape = f"vector<{tile_size}xi1>," if tile_size > 1 else "" + return f"arith.select %{condition}, %{operand1}, %{operand2} : {cond_shape} {shape}", [tile_size, ret_type] + + + @staticmethod + def masked(mask, body, other, *args, var_info=None, tile_size=16, dtype="f32", ninf_declared=False): + result = body() + val = ops.constant(0.0, "f32") + result = ops.where(mask, result, val) + return result, var_info[result] + + @staticmethod + def _index_expr(operand, *args, var_info=None, **kwargs): + symbols = sorted([str(i) for i in operand.free_symbols]) + renamed_symbols = {symbol: sympy.Symbol(f"d{i}") for i, symbol in enumerate(symbols)} + + renamed_expression = operand.subs(renamed_symbols) + + affine_map_str = "(" + ", ".join([f"d{i}" for i in range(len(symbols))]) + ") -> (" + affine_map_str += sympy.printing.ccode(renamed_expression) + ")" + + map_operands = [f"%{str(symbol)}" for symbol in symbols] + return f"affine.apply affine_map<{affine_map_str}>({', '.join(map_operands)})", [1, "index"] + + @staticmethod + def index_expr(operand, *args, var_info=None, **kwargs): + result = ops._index_expr(operand) + ret_type = [1, "index"] + return result, ret_type + + @staticmethod + def index_cast(operand, target_type, *args, var_info=None, **kwrags): + op_type = var_info[operand] + src_shape = f"vector<{op_type[0]}x{op_type[1]}>" if op_type[0] > 1 else op_type[1] + des_shape = f"vector<{op_type[0]}x{target_type}>" if op_type[0] > 1 else target_type + return f"arith.index_cast %{operand} : {src_shape} to {des_shape}", [op_type[0], target_type] + @staticmethod - def reciprocal(a, tile_size=16, 
dtype="f32"): - return ops.div(ops.constant(1.0, torch.float32), a) + def broadcast(operand1, operand2, *args, var_info=None): + op_type1 = var_info[operand1] + op_type2 = var_info[operand2] + src_shape = f"vector<{op_type1[0]}x{op_type1[1]}>" if op_type1[0] > 1 else op_type1[1] + des_shape = f"vector<{op_type2[0]}x{op_type1[1]}>" if op_type2[0] > 1 else op_type1[1] # Use tile size only + expand = f"vector.broadcast %{operand1} : {src_shape} to {des_shape}" + return expand, [op_type2[0], op_type1[1]] RTYPE_TO_MLIR = { "sum": "add", @@ -325,7 +677,7 @@ def __init__(self): self.reduction_suffix = IndentedBuffer() self.body = IndentedBuffer() self.global_vars = IndentedBuffer() - self.global_vars_set = set() + self.global_vars_dict = dict() self.header = IndentedBuffer() self.gem5_header = IndentedBuffer() self.reduction_vars = {} @@ -378,7 +730,7 @@ def find_node_by_name(self, name): if output_node.data.name == name: return output_node - def get_dma_info(self, name, index, dtype, is_store): + def get_dma_info(self, name, index, dtype): current_tile = MLIRTile(self.tile_desc.n_row, self.tile_desc.n_col, self.tile_desc.vector_lane, self.tile_desc.used_vector_lane) cv = self.get_constant_vector(index) cv2 = self.get_constant_vector2(index) @@ -485,7 +837,7 @@ def get_dma_info(self, name, index, dtype, is_store): def parse_indices(self, expr): if len(expr.args) == 0: - return expr, expr + return expr # Extract index var expr_str = str(expr) @@ -509,7 +861,7 @@ def parse_indices(self, expr): map_var = self.map_cse.generate(self.global_vars, f"affine_map<({args}) -> ({expr_str})>") args = ", ".join([f"%{i}" for i in indices]) index = self.cse.generate(self.loads, f"affine.apply #{map_var}({args})") - return index, expr + return index def codegen_nodes(self, nodes, kernel_name): _, (group, reduction_group) = max( @@ -551,10 +903,11 @@ def load_epilogue(self, name: str, index: sympy.Expr): else: mvin3 = 14 self.consts.add(mvin3) + self.consts.add(0) dram_tile_shape = 
f"{self.render_options['TILE_M']}x{self.render_options['TILE_N']}" buffer, indices = self.get_scratchpad_buffer(dtype, name, self.render_options['TILE_M'], self.render_options['TILE_N'], dram_tile_shape, self.loads, index) self.buffer_names[name] = buffer - line = f"affine.dma_start %{var}[%index2], %{buffer}[0, 0], %tag[0], %c{mvin3}, %N, %c_set : memref<{self.buffer_types[name][1]}x{type_name}>, memref<{dram_tile_shape}x{type_name}, 1>, memref<1xi32>" + line = f"affine.dma_start %{var}[%index2], %{buffer}[%c0, %c0], %tag[0], %c{mvin3}, %N, %c_set : memref<{self.buffer_types[name][1]}x{type_name}>, memref<{dram_tile_shape}x{type_name}, 1>, memref<1xi32>" self.cse.generate(self.loads, line, assignment = False) tile_size_per_lane = self.render_options['TILE_M'] * self.render_options['TILE_N'] // self.vector_lane @@ -562,23 +915,27 @@ def load_epilogue(self, name: str, index: sympy.Expr): shape = f", vector<{tile_size_per_lane}x{type_name}>" if tile_size_per_lane > 1 else "" line = f"{operation} %{buffer}[0, 0] : memref<{self.render_options['TILE_M']}x{self.render_options['TILE_N']}x{type_name}, 1>{shape}" out = self.cse.generate(self.loads, line) - self.tile_info[out] = tile_size_per_lane, dtype + var_info = [tile_size_per_lane, mlir_common.DTYPE_TO_MLIR[dtype]] + self.register_var_info(out, var_info) return out def load(self, name: str, index: sympy.Expr): if self.is_template_kernel: return self.load_epilogue(name, index) index = self.rename_indexing(index) - indices, index = self.parse_indices(index) - prefix = "" if index.is_number else "%" + indices = self.parse_indices(index) + prefix = self.newvar_prefix + if index.is_number: + prefix = prefix + "c" + self.consts.add(int(index)) var = self.args.input(name) dtype = V.graph.get_dtype(name) type_name = mlir_common.DTYPE_TO_MLIR[dtype] - stride, chunk, tile_shape, tile_size_per_lane = self.get_dma_info(name, index, dtype, 0) + stride, chunk, tile_shape, tile_size_per_lane = self.get_dma_info(name, index, dtype) 
dram_tile_shape = f"{tile_shape[0]}x{tile_shape[1]}" # Define scratch pad buffer - buffer, indices = self.get_scratchpad_buffer(dtype, name, self.tile_desc.n_row, self.tile_desc.n_col, dram_tile_shape, self.loads, indices) + buffer, indices = self.get_scratchpad_buffer(dtype, name, self.tile_desc.n_row, self.tile_desc.n_col, dram_tile_shape, self.loads, indices, index) # MVIN Encoding dma_key = (stride, chunk, dtype) if dma_key in self.dma_cache: @@ -592,28 +949,37 @@ def load(self, name: str, index: sympy.Expr): self.consts.add(chunk) self.dma_cache[dma_key] = dmaType, stride, chunk self.tags.add(f"{name}_tag") - code = f"affine.dma_start %{var}[{prefix}{indices}], %{buffer}[0, 0], %{name}_tag[0], %c{dmaType}, %c{stride}, %c{chunk} : memref<{self.buffer_types[name][1]}x{type_name}>, memref<{dram_tile_shape}x{type_name}, 1>, memref<1xi32>" + self.consts.add(0) + code = f"affine.dma_start %{var}[{prefix}{indices}], %{buffer}[%c0, %c0], %{name}_tag[0], %c{dmaType}, %c{stride}, %c{chunk} : memref<{self.buffer_types[name][1]}x{type_name}>, memref<{dram_tile_shape}x{type_name}, 1>, memref<1xi32>" self.cse.generate(self.loads, code, assignment = False) # FIXME: assignment = False does not support caching operation = "affine.vector_load" if tile_size_per_lane > 1 else "affine.load" shape = f", vector<{tile_size_per_lane}x{type_name}>" if tile_size_per_lane > 1 else "" line = f"{operation} %{buffer}[0, 0] : memref<{dram_tile_shape}x{type_name}, 1>{shape}" out = self.cse.generate(self.loads, line) - self.tile_info[out] = tile_size_per_lane, dtype + var_info = [tile_size_per_lane, mlir_common.DTYPE_TO_MLIR[dtype]] + self.register_var_info(out, var_info) return out def store_epilogue(self, name: str, index: sympy.Expr, value, *args, **kwargs): - indices, index = self.parse_indices(index) - prefix = "" if index.is_number else "%" + indices = self.parse_indices(index) + prefix = self.newvar_prefix + if index.is_number: + prefix = prefix + "c" + self.consts.add(int(index)) var = 
self.args.output(name) dtype = V.graph.get_dtype(name) type_name = mlir_common.DTYPE_TO_MLIR[dtype] + chunk_size = self.tile_desc.get_chunk_size() + chunk = chunk_size << 1 | (self.tile_desc.tile_per_lane_layout == MLIRTile.TILE_PER_LANE_COL_WISE) + self.consts.add(chunk) + if name in self.buffer_names: buffer = self.buffer_names[name] else: dram_tile_shape = f"{self.render_options['TILE_M']}x{self.render_options['TILE_N']}" - buffer, indices = self.get_scratchpad_buffer(dtype, name, self.render_options['TILE_M'], self.render_options['TILE_N'], dram_tile_shape, self.stores, indices) + buffer, indices = self.get_scratchpad_buffer(dtype, name, self.render_options['TILE_M'], self.render_options['TILE_N'], dram_tile_shape, self.stores, indices, index) self.buffer_names[name] = buffer tile_size_per_lane = self.render_options['TILE_M'] * self.render_options['TILE_N'] // self.vector_lane @@ -623,23 +989,27 @@ def store_epilogue(self, name: str, index: sympy.Expr, value, *args, **kwargs): self.cse.generate(self.stores, line, assignment = False) self.tags.add(f"{name}_tag") - code = f"affine.dma_start %{buffer}[0, 0], %{var}[%index2], %tag[0], %c_mvout, %N, %c_set : memref<{self.render_options['TILE_M']}x{self.render_options['TILE_N']}x{type_name}, 1>, memref<{self.render_options['M'] * self.render_options['N']}x{type_name}>, memref<1xi32>" #FIXME: Using constant index and tag + self.consts.add(0) + code = f"affine.dma_start %{buffer}[%c0, %c0], %{var}[%index2], %tag[0], %c_mvout, %N, %c{chunk} : memref<{self.render_options['TILE_M']}x{self.render_options['TILE_N']}x{type_name}, 1>, memref<{self.render_options['M'] * self.render_options['N']}x{type_name}>, memref<1xi32>" #FIXME: Using constant index and tag self.cse.generate(self.stores, code, assignment = False) def store(self, name: str, index: sympy.Expr, value, *args, **kwargs): if self.is_template_kernel: return self.store_epilogue(name, index, value, args, kwargs) index = self.rename_indexing(index) - indices, index = 
self.parse_indices(index) - prefix = "" if index.is_number else "%" + indices = self.parse_indices(index) + prefix = self.newvar_prefix + if index.is_number: + prefix = prefix + "c" + self.consts.add(int(index)) var = self.args.output(name) dtype = V.graph.get_dtype(name) type_name = mlir_common.DTYPE_TO_MLIR[dtype] - stride, chunk, tile_shape, tile_size_per_lane = self.get_dma_info(name, index, dtype, 1) + stride, chunk, tile_shape, tile_size_per_lane = self.get_dma_info(name, index, dtype) dram_tile_shape = f"{tile_shape[0]}x{tile_shape[1]}" # Define scratch pad buffer - buffer, indices = self.get_scratchpad_buffer(dtype, name, self.tile_desc.n_row, self.tile_desc.n_col, dram_tile_shape, self.stores, indices) + buffer, indices = self.get_scratchpad_buffer(dtype, name, self.tile_desc.n_row, self.tile_desc.n_col, dram_tile_shape, self.stores, indices, index) # MVOUT Encoding dmaType = 3 # MVIN 2, MVIN2 1, MVIN3 14, MVOUT 3 @@ -647,14 +1017,17 @@ def store(self, name: str, index: sympy.Expr, value, *args, **kwargs): self.consts.add(stride) self.consts.add(chunk) - store_size = self.tile_info[value][0] + store_size, operand_type = self.var_info[value] operation = "affine.vector_store" if tile_size_per_lane > 1 and store_size > 1 else "affine.store" shape = f", vector<{tile_size_per_lane}x{type_name}>" if tile_size_per_lane > 1 and store_size > 1 else "" + if type_name != operand_type: + value = ops.custom_cast(value, type_name, var_info=self.var_info) line = f"{operation} %{value}, %{buffer}[0, 0] : memref<{dram_tile_shape}x{type_name}, 1>{shape}" self.cse.generate(self.stores, line, assignment = False) + self.consts.add(0) self.tags.add(f"{name}_tag") - code = f"affine.dma_start %{buffer}[0, 0], %{var}[{prefix}{indices}], %{name}_tag[0], %c{dmaType}, %c{stride}, %c{chunk} : memref<{dram_tile_shape}x{type_name}, 1>, memref<{self.buffer_types[name][1]}x{type_name}>, memref<1xi32>" + code = f"affine.dma_start %{buffer}[%c0, %c0], %{var}[{prefix}{indices}], 
%{name}_tag[0], %c{dmaType}, %c{stride}, %c{chunk} : memref<{dram_tile_shape}x{type_name}, 1>, memref<{self.buffer_types[name][1]}x{type_name}>, memref<1xi32>" self.cse.generate(self.stores, code, assignment = False) def reduction(self, dtype, src_dtype, reduction_type, value): @@ -711,13 +1084,14 @@ def reduction(self, dtype, src_dtype, reduction_type, value): init_vec = init axis = "0, 1" acc_var = init - self.tile_info[acc] = 1, dtype + var_info = [1, mlir_common.DTYPE_TO_MLIR[dtype]] else: reduced_shape = f"vector<{vec_len}x{type_name}>" init_vec = self.cse.generate(self.reduction_prefix, f"vector.broadcast %{init} : {type_name} to {reduced_shape}") axis = "0" acc_var = init_vec - self.tile_info[acc] = vec_len, dtype + var_info = [vec_len, mlir_common.DTYPE_TO_MLIR[dtype]] + self.register_var_info(acc, var_info) else: raise NotImplementedError() @@ -735,14 +1109,16 @@ def store_reduction(self, name, index, value): dtype = V.graph.get_dtype(name) type_name = mlir_common.DTYPE_TO_MLIR[dtype] index = self.rename_indexing(index) - indices, index = self.parse_indices(index) - prefix = "" if index.is_number else "%" - + indices = self.parse_indices(index) + prefix = self.newvar_prefix + if index.is_number: + prefix = prefix + "c" + self.consts.add(int(index)) # Tile is always reuduced in inner loop tile_col = self.tile_desc.n_row tile_row = 1 dram_tile_shape = f"{tile_row}x{tile_col}" - buffer, indices = self.get_scratchpad_buffer(dtype, name, tile_row, tile_col, dram_tile_shape, self.reductions_suffix, indices) + buffer, indices = self.get_scratchpad_buffer(dtype, name, tile_row, tile_col, dram_tile_shape, self.reductions_suffix, indices, index) if self.welford_reduce_out is not None: # raise NotImplementedError() sum, sqr_sum, _ = self.welford_reduce_out @@ -750,7 +1126,7 @@ def store_reduction(self, name, index, value): # mean divider = self.cse.generate(self.reductions_suffix, f"arith.constant {float(self.ranges[self.reduction_depth])} : f32") if 
self.buffer_types[name][1] > 1: - divider_vec = self.cse.generate(self.reductions_suffix, f"vector.broadcast %{divider} : f32 to vector<{self.tile_info[sum][0]}x{type_name}>") + divider_vec = self.cse.generate(self.reductions_suffix, f"vector.broadcast %{divider} : f32 to vector<{self.var_info[sum][0]}x{type_name}>") else: divider_vec = f"f{self.buffer_types[name][1]}" mean = self.cse.generate(self.reductions_suffix, f"arith.divf %{sum}, %{divider_vec} : {shape}") @@ -792,7 +1168,8 @@ def store_reduction(self, name, index, value): self.consts.add(chunk) self.tags.add(f"{name}_tag") # Change row, col - code = f"affine.dma_start %{buffer}[0, 0], %{var}[{prefix}{indices}], %{name}_tag[0], %c{dmaType}, %c{mm_stride}, %c{chunk} : memref<{tile_row}x{tile_col}x{type_name}, 1>, memref<{self.buffer_types[name][1]}x{type_name}>, memref<1xi32>" + self.consts.add(0) + code = f"affine.dma_start %{buffer}[%c0, %c0], %{var}[{prefix}{indices}], %{name}_tag[0], %c{dmaType}, %c{mm_stride}, %c{chunk} : memref<{tile_row}x{tile_col}x{type_name}, 1>, memref<{self.buffer_types[name][1]}x{type_name}>, memref<1xi32>" self.cse.generate(self.reductions_suffix, code, assignment = False) def codegen_body(self): @@ -805,9 +1182,10 @@ def codegen_body(self): def template_store(options): subtile_size = [self.vector_lane, self.vector_lane] async_flag = 1 - line = f"affine.dma_start %Y_buffer[0, 0], %Y[%index2], %tag[0], %c_mvout, %N, %c_set"\ + self.consts.add(0) + line = f"affine.dma_start %Y_buffer[%c0, %c0], %Y[%index2], %tag[0], %c_mvout, %N, %c_set"\ f": memref<{options['TILE_M']}x{options['TILE_N']}xf32, 1>,"\ - f"memref<{options['M'] * options['N']}xf32>, memref<1xi32> " #FIXME: Using constant index and tag + f"memref<{options['M'] * options['N']}xf32>, memref<1xi32>" #FIXME: Using constant index self.cse.generate(self.stores, line, assignment = False) self.body.splice(self.codegen_init()) self.body.splice(self.loads) @@ -840,6 +1218,8 @@ def codegen_loops(self): loops = [LoopLevel(var, 
size, idx-len(self.itervars), tile_row=tile_row, tile_col=tile_col) for idx, (var, size) in enumerate(zip(self.itervars, self.ranges))] loops, reductions = [LoopNest(loops[: self.reduction_depth]), LoopNest(loops[self.reduction_depth :])] + if (self.reduction_depth==0): + loops = LoopNest([LoopLevel("dummy", 1, 1, 0)]) reductions.mark_reduction(self.reduction_vars) if len(self.affine_yield) > 0: vars = ', '.join([f"%{name}" for name, _ in self.affine_yield.items()]) @@ -907,6 +1287,10 @@ def _codegen_kernel(self, arg_defs, kernel_name): return code def adjust_tile_size(self): + if self.is_template_kernel: + self.tile_desc.n_row = self.render_options['TILE_M'] + self.tile_desc.n_col = self.render_options['TILE_N'] + return if self.read_writes is not None: read_writes = list(self.read_writes.reads) + list(self.read_writes.writes) cv_list = [] @@ -974,7 +1358,7 @@ def set_ranges(self, lengths, reduction_lengths, read_writes): self.itervars[self.reduction_depth :], ) - def get_scratchpad_buffer(self, dtype, name, tile_row, tile_col, dram_tile_shape, code_buffer, indices): + def get_scratchpad_buffer(self, dtype, name, tile_row, tile_col, dram_tile_shape, code_buffer, indices, raw_index): c_type = mlir_common.DTYPE_TO_C[dtype] mlir_type = mlir_common.DTYPE_TO_MLIR[dtype] # Make sure each lane's buffer has at least two element @@ -987,13 +1371,17 @@ def get_scratchpad_buffer(self, dtype, name, tile_row, tile_col, dram_tile_shape mapping = self.map_cse.generate(self.global_vars, f"affine_map<({indices}) -> ({indices} floordiv 8)>") indices = self.cse.generate(self.loads, f"affine.apply #{mapping}(%{indices})") # FIXME. Only loads? 
- if name not in self.global_vars_set: + if name not in self.global_vars_dict: + self.global_vars_dict[name] = set() + + if str(raw_index) not in self.global_vars_dict[name]: + new_name = f"{name}_{len(self.global_vars_dict[name])}" # Add definition to header - self.header.writeline(f"{c_type} {name}_spad[{tile_size // self.vector_lane}] __attribute__ ((section(\".spad\")));") - self.gem5_header.writeline(f"{c_type} {name}_spad[{tile_size}];") - self.global_vars_set.add(name) - self.global_vars.writeline(f"memref.global @{name}_spad : memref<{dram_tile_shape}x{mlir_type}, 1>") - buffer = self.cse.generate(code_buffer, f"memref.get_global @{name}_spad : memref<{dram_tile_shape}x{mlir_type}, 1>") + self.header.writeline(f"{c_type} {new_name}_spad[{tile_size // self.vector_lane}] __attribute__ ((section(\".spad\")));") + self.gem5_header.writeline(f"{c_type} {new_name}_spad[{tile_size}];") + self.global_vars.writeline(f"memref.global @{new_name}_spad : memref<{dram_tile_shape}x{mlir_type}, 1>") + self.global_vars_dict[name].add(str(raw_index)) + buffer = self.cse.generate(code_buffer, f"memref.get_global @{new_name}_spad : memref<{dram_tile_shape}x{mlir_type}, 1>") return buffer, indices def roundup_vectorlane(self, size, amp=1): @@ -1068,6 +1456,7 @@ def can_fuse_vertical(self, node1, node2): return self.can_fuse_horizontal(node1, node2) and not node1.is_reduction() def can_fuse_horizontal(self, node1, node2): + return False _, (vars1, reduce1) = node1.group _, (vars2, reduce2) = node2.group if vars1 == vars2 and reduce1 == reduce2: diff --git a/PyTorchSimFrontend/mlir/mlir_common.py b/PyTorchSimFrontend/mlir/mlir_common.py index f76fd0cc..912704b5 100644 --- a/PyTorchSimFrontend/mlir/mlir_common.py +++ b/PyTorchSimFrontend/mlir/mlir_common.py @@ -172,7 +172,7 @@ def __init__(self, args=None): self.tile_col = extension_config.CONFIG_TILE_COL if self.tile_col == -1: self.tile_col = 8 # FIXME: tile_col is not always vector_lane * vlen - self.tile_info = {} + 
self.var_info = {} def load(self, name: str, index: sympy.Expr): raise NotImplementedError() @@ -193,32 +193,8 @@ def check_dtype_in_args(self, args): dtype = arg return dtype - def expand(self, args, buf_bounds): - cse_args = [arg for arg in args if isinstance(arg, common.CSEVariable)] - if len(cse_args) == 0: - return args, 1, self.check_dtype_in_args(args) - elif len(cse_args) == 1: - if not cse_args[0] in self.tile_info: - return args, 1, self.check_dtype_in_args(cse_args) - info = self.tile_info[cse_args[0]] - return args, info[0], info[1] - lhs_idx = args.index(cse_args[-2]) - rhs_idx = args.index(cse_args[-1]) - if not args[lhs_idx] in self.tile_info or not args[rhs_idx] in self.tile_info: - return args, 1, self.check_dtype_in_args(args) - lhs_tile_size, lhs_dtype = self.tile_info[args[lhs_idx]] - rhs_tile_size, rhs_dtype = self.tile_info[args[rhs_idx]] - lhs_shape = f"vector<{lhs_tile_size}x{DTYPE_TO_MLIR[lhs_dtype]}>" if lhs_tile_size > 1 else DTYPE_TO_MLIR[lhs_dtype] - rhs_shape = f"vector<{rhs_tile_size}x{DTYPE_TO_MLIR[rhs_dtype]}>" if rhs_tile_size > 1 else DTYPE_TO_MLIR[rhs_dtype] - temp = list(args) - if lhs_tile_size > rhs_tile_size: - expand = f"vector.broadcast %{args[rhs_idx]} : {rhs_shape} to {lhs_shape}" - temp[rhs_idx] = self.cse.generate(self.compute, expand, bounds=buf_bounds) - elif lhs_tile_size < rhs_tile_size: - expand = f"vector.broadcast %{args[lhs_idx]} : {lhs_shape} to {rhs_shape}" - temp[lhs_idx] = self.cse.generate(self.compute, expand, bounds=buf_bounds) - args = tuple(temp) - return args, max(lhs_tile_size, rhs_tile_size), lhs_dtype + def register_var_info(self, var, var_info): + self.var_info[var] = var_info def __enter__(self): class CSEProxy: @@ -235,13 +211,13 @@ def inner(*args, **kwargs): buf_bounds = self.node_to_bounds.get( fx_node, ValueRanges.unknown() ) - args, tile_size, dtype = self.expand(args, buf_bounds) + code, ret_info = getattr(parent_handler, name)(*args, var_info=self.var_info) csevar = self.cse.generate( 
self.compute, - getattr(parent_handler, name)(*args, tile_size=tile_size, dtype=DTYPE_TO_MLIR[dtype], **kwargs), # type: ignore[has-type] + code, bounds=buf_bounds, ) - self.tile_info[csevar] = tile_size, dtype + self.register_var_info(csevar, ret_info) csevar.update_on_args(name, args, kwargs) return csevar diff --git a/PyTorchSimFrontend/mlir/mlir_conv_template.py b/PyTorchSimFrontend/mlir/mlir_conv_template.py index 56d43a11..3f52a61d 100644 --- a/PyTorchSimFrontend/mlir/mlir_conv_template.py +++ b/PyTorchSimFrontend/mlir/mlir_conv_template.py @@ -27,6 +27,7 @@ %c_mvin3 = arith.constant 14 : index %c_mvout = arith.constant 3 : index %c_set = arith.constant 2 : index + %c0 = arith.constant 0 : index %N = arith.constant {{ N }} : index %K = arith.constant {{ K }} : index @@ -38,16 +39,16 @@ affine.for %t_m = 0 to {{ M }} step {{ TILE_M }} { affine.for %t_n = 0 to {{ N }} step {{ TILE_N }} { %index2 = affine.apply #map1(%t_m, %t_n) - affine.dma_start %B[%index2], %Y_buffer[0, 0], %tag[0], %c_mvin3, %N, %c_set : memref<{{ M * N }}xf32>, memref<{{ TILE_M }}x{{ TILE_N }}xf32, 1>, memref<1xi32> { subtile_size=[{{ kernel.vector_lane }}, {{ kernel.vector_lane }}], async=1 } + affine.dma_start %B[%index2], %Y_buffer[%c0, %c0], %tag[0], %c_mvin3, %N, %c_set : memref<{{ M * N }}xf32>, memref<{{ TILE_M }}x{{ TILE_N }}xf32, 1>, memref<1xi32> { subtile_size=[{{ kernel.vector_lane }}, {{ kernel.vector_lane }}], async=1 } affine.for %t_k = 0 to {{ K }} step {{ TILE_K }} { %index0 = affine.apply #map0(%t_m, %t_k) %index1 = affine.apply #map1(%t_k, %t_n) - affine.dma_start %X[%index0], %X_buffer[0, 0], %tag[0], %c_mvin, %K, %c_set : memref<{{ M * K }}xf32>, memref<{{ TILE_M }}x{{ TILE_K }}xf32, 1>, memref<1xi32> { subtile_size=[{{ kernel.vector_lane }}, {{ TILE_K }}], async=1 } - affine.dma_start %W[%index1], %W_buffer[0, 0], %tag[0], %c_mvin2, %N, %c_set : memref<{{ K * N }}xf32>, memref<{{ TILE_K }}x{{ TILE_N }}xf32, 1>, memref<1xi32> { subtile_size=[{{ TILE_K }}, {{ 
kernel.vector_lane }}], async=1 } + affine.dma_start %X[%index0], %X_buffer[%c0, %c0], %tag[0], %c_mvin, %K, %c_set : memref<{{ M * K }}xf32>, memref<{{ TILE_M }}x{{ TILE_K }}xf32, 1>, memref<1xi32> { subtile_size=[{{ kernel.vector_lane }}, {{ TILE_K }}], async=1 } + affine.dma_start %W[%index1], %W_buffer[%c0, %c0], %tag[0], %c_mvin2, %N, %c_set : memref<{{ K * N }}xf32>, memref<{{ TILE_K }}x{{ TILE_N }}xf32, 1>, memref<1xi32> { subtile_size=[{{ TILE_K }}, {{ kernel.vector_lane }}], async=1 } linalg.matmul ins(%X_buffer, %W_buffer : memref<{{ TILE_M }}x{{ TILE_K }}x{{ DATA_STYPE }}, 1>, memref<{{ TILE_K }}x{{ TILE_N }}x{{ DATA_STYPE }}, 1>) outs(%Y_buffer : memref<{{ TILE_M }}x{{ TILE_N }}x{{ DATA_STYPE }}, 1>) } { accumulation_loop=true } - affine.dma_start %Y_buffer[0, 0], %Y[%index2], %tag[0], %c_mvout, %N, %c_set : memref<{{ TILE_M }}x{{ TILE_N }}xf32, 1>, memref<{{ M * N }}xf32>, memref<1xi32> { subtile_size=[{{ kernel.vector_lane }}, {{ kernel.vector_lane }}], async=1 } + affine.dma_start %Y_buffer[%c0, %c0], %Y[%index2], %tag[0], %c_mvout, %N, %c_set : memref<{{ TILE_M }}x{{ TILE_N }}xf32, 1>, memref<{{ M * N }}xf32>, memref<1xi32> { async=1 } } { outer_loop=true } } { outer_loop=true } return @@ -59,7 +60,7 @@ def {{ FUNC_NAME }}({{ INPUT }}, {{ WEIGHT }}{% if BIAS %}, {{ BIAS }}{% endif %}, {{ OUT }}): {{ INPUT }}_cpu = {{ INPUT }}.cpu() {{ WEIGHT }}_cpu = {{ WEIGHT }}.cpu(){% if BIAS %} - {{ BIAS }}_cpu = {{ BIAS }}.cpu(){% endif %} #FIXME: BIAS is not used in the current implementation + {{ BIAS }}_cpu = {{ BIAS }}.cpu(){% endif %} {{ OUT }}_cpu = {{ OUT }}.cpu() # Torch support NCHW, so we need to transpose for now @@ -110,7 +111,8 @@ def {{ FUNC_NAME }}({{ INPUT }}, {{ WEIGHT }}{% if BIAS %}, {{ BIAS }}{% endif % {% endif %} {{ OUT }}_cpu = {{ OUT }}_cpu.reshape(output_shape) - {{ OUT }}_cpu = {{ OUT }}_cpu.permute(0, 3, 1, 2) + {{ OUT }}_cpu = {{ OUT }}_cpu.permute(0, 3, 1, 2){% if BIAS %} + {{ OUT }}_cpu += {{ BIAS }}_cpu.reshape(-1, 1, 1) #TODO: 
BIAS should be added in the kernel{% endif %} {{ OUT }}.copy_({{ OUT }}_cpu) """ @@ -173,7 +175,7 @@ def render(self, M = self.gemm_input_shape[2] * self.gemm_input_shape[3] N = self.gemm_weight_shape[0] K = self.gemm_weight_shape[1] - TILE_M, TILE_N, TILE_K = kernel.gemmini_gemm_mapping(M, N, K) + TILE_M, TILE_N, TILE_K = kernel.gemm_combination_mapping(M, N, K) kernel.tile_size = [TILE_M, TILE_N, TILE_K] kernel.loop_size = [M, N, K] diff --git a/PyTorchSimFrontend/mlir/mlir_gemm_template.py b/PyTorchSimFrontend/mlir/mlir_gemm_template.py index 83d43174..954059c0 100644 --- a/PyTorchSimFrontend/mlir/mlir_gemm_template.py +++ b/PyTorchSimFrontend/mlir/mlir_gemm_template.py @@ -33,8 +33,8 @@ %W_buffer = memref.get_global @W_spad : memref<{{ TILE_K }}x{{ TILE_N }}xf32, 1> %Y_buffer = memref.get_global @Y_spad : memref<{{ TILE_M }}x{{ TILE_N }}xf32, 1> %tag = memref.alloc() : memref<1xi32>{% if not Bias %} - %v0 = arith.constant dense<0.0> : vector<{{ TILE_M * TILE_N // kernel.vector_lane }}xf32>{% else %} - %c0 = arith.constant 0 : index{% endif %} + %v0 = arith.constant dense<0.0> : vector<{{ TILE_M * TILE_N // kernel.vector_lane }}xf32>{% endif %} + %c0 = arith.constant 0 : index affine.for %t_m = 0 to {{ M }} step {{ TILE_M }} { affine.for %t_n = 0 to {{ N }} step {{ TILE_N }} { @@ -42,7 +42,7 @@ {% if Bias -%} affine.dma_start %Bias[ {%- if Bias_rank == 2 -%} %index2 {%- else -%} %t_n {%- endif -%} - ], %Y_buffer[0, 0], %tag[0], %c_mvin3, % + ], %Y_buffer[%c0, %c0], %tag[0], %c_mvin3, % {%- if Bias_rank == 2 -%} N {%- else -%} c0 {%- endif -%} , %c_set : memref< {%- if Bias_rank == 2 -%} {{ M * N }} {%- else -%} {{ N }} {%- endif -%} @@ -53,16 +53,16 @@ affine.for %t_k = 0 to {{ K }} step {{ TILE_K }} { %index0 = affine.apply #map0(%t_m, %t_k) %index1 = affine.apply #map1(%t_k, %t_n) - affine.dma_start %X[%index0], %X_buffer[0, 0], %tag[0], %c_mvin, + affine.dma_start %X[%index0], %X_buffer[%c0, %c0], %tag[0], %c_mvin, {%- if X_transposed -%} %M, %x_chunk {%- 
else -%} %K, %x_chunk {%- endif -%} : memref<{{ M * K }}xf32>, memref<{{ TILE_M }}x{{ TILE_K }}xf32, 1>, memref<1xi32> { subtile_size=[{{ kernel.vector_lane }}, {{ TILE_K }}], async=1{% if X_transposed %}, transpose=1{% endif %} } - affine.dma_start %W[%index1], %W_buffer[0, 0], %tag[0], %c_mvin2, + affine.dma_start %W[%index1], %W_buffer[%c0, %c0], %tag[0], %c_mvin2, {%- if W_transposed -%} %K, %w_chunk {%- else -%} %N, %w_chunk {%- endif -%} : memref<{{ K * N }}xf32>, memref<{{ TILE_K }}x{{ TILE_N }}xf32, 1>, memref<1xi32> { subtile_size=[{{ TILE_K }}, {{ kernel.vector_lane }}], async=1{% if W_transposed %}, transpose=1{% endif %} } linalg.matmul ins(%X_buffer, %W_buffer : memref<{{ TILE_M }}x{{ TILE_K }}x{{ DATA_STYPE }}, 1>, memref<{{ TILE_K }}x{{ TILE_N }}x{{ DATA_STYPE }}, 1>) outs(%Y_buffer : memref<{{ TILE_M }}x{{ TILE_N }}x{{ DATA_STYPE }}, 1>) } { accumulation_loop=true } - {{kernel.store_output()}} { subtile_size=[{{ kernel.vector_lane }}, {{ kernel.vector_lane }}], async=1 } + {{kernel.store_output()}} } { outer_loop=true } } { outer_loop=true } return @@ -110,7 +110,7 @@ def render(self, TILE_M, TILE_N, TILE_K = 0, 0, 0 template = EMPTY_TEMPLATE else: - TILE_M, TILE_N, TILE_K = kernel.gemmini_gemm_mapping(M, N, K) + TILE_M, TILE_N, TILE_K = kernel.gemm_combination_mapping(M, N, K) template = GEMM_TEMPLATE kernel.tile_size = [TILE_M, TILE_N, TILE_K] kernel.loop_size =[M, N, K] diff --git a/PyTorchSimFrontend/mlir/mlir_lowering.py b/PyTorchSimFrontend/mlir/mlir_lowering.py index c1950f82..e7ca37eb 100644 --- a/PyTorchSimFrontend/mlir/mlir_lowering.py +++ b/PyTorchSimFrontend/mlir/mlir_lowering.py @@ -144,5 +144,4 @@ def custom_maxpool( lowerings.update({getattr(aten.mm, overload): tuned_mm for overload in aten.mm.overloads()}) lowerings.update({getattr(aten.addmm, overload): tuned_addmm for overload in aten.addmm.overloads()}) lowerings.update({getattr(aten.convolution, overload): convolution for overload in aten.convolution.overloads()}) 
-lowerings.update({getattr(aten.bmm, overload): tuned_bmm for overload in aten.bmm.overloads()}) -lowerings.update({getattr(aten.max_pool2d_with_indices, overload): custom_maxpool for overload in aten.max_pool2d_with_indices.overloads()}) \ No newline at end of file +lowerings.update({getattr(aten.bmm, overload): tuned_bmm for overload in aten.bmm.overloads()}) \ No newline at end of file diff --git a/PyTorchSimFrontend/mlir/mlir_maxpool_template.py b/PyTorchSimFrontend/mlir/mlir_maxpool_template.py index 87c9ce0e..1f93f82a 100644 --- a/PyTorchSimFrontend/mlir/mlir_maxpool_template.py +++ b/PyTorchSimFrontend/mlir/mlir_maxpool_template.py @@ -23,11 +23,12 @@ %X_buffer = memref.get_global @X_spad : memref<{{ in_tile }}x{{ in_tile }}xf32, 1> %Y_buffer = memref.get_global @Y_spad : memref<{{ out_tile }}x{{ out_tile }}xf32, 1> %tag = memref.alloc() : memref<1xi32> + %c0 = arith.constant 0 : index affine.for %i = 0 to {{ BCH }} step {{ out_tile }} { affine.for %j = 0 to {{ W }} step {{ out_tile }} { %index0 = affine.apply #map0(%i, %j) - affine.dma_start %X[%index0], %X_buffer[0, 0], %tag[0], %c_mvin, %dummy, %in_chunk : memref<{{ IN }}xf32>, memref<{{ in_tile }}x{{ in_tile }}xf32, 1>, memref<1xi32> - affine.dma_start %Y_buffer[0, 0], %Y[%index0], %tag[0], %c_mvout, %dummy, %out_chunk : memref<{{ out_tile }}x{{ out_tile }}xf32, 1>, memref<{{ OUT }}xf32>, memref<1xi32> + affine.dma_start %X[%index0], %X_buffer[%c0, %c0], %tag[0], %c_mvin, %dummy, %in_chunk : memref<{{ IN }}xf32>, memref<{{ in_tile }}x{{ in_tile }}xf32, 1>, memref<1xi32> + affine.dma_start %Y_buffer[%c0, %c0], %Y[%index0], %tag[0], %c_mvout, %dummy, %out_chunk : memref<{{ out_tile }}x{{ out_tile }}xf32, 1>, memref<{{ OUT }}xf32>, memref<1xi32> } { outer_loop=true } } { outer_loop=true } return diff --git a/PyTorchSimFrontend/mlir/mlir_template.py b/PyTorchSimFrontend/mlir/mlir_template.py index 476e73f5..63d00dab 100644 --- a/PyTorchSimFrontend/mlir/mlir_template.py +++ 
b/PyTorchSimFrontend/mlir/mlir_template.py @@ -20,7 +20,7 @@ from PyTorchSimFrontend.mlir.mlir_autotune import MLIRBenchmarkRequest from PyTorchSimFrontend.mlir.mlir_common import BaseMLIRHardwareInfo -from PyTorchSimFrontend.mlir.mlir_codegen_backend import MLIRKernel +from PyTorchSimFrontend.mlir.mlir_codegen_backend import MLIRKernel, MLIRTile class MLIRTemplateKernel(MLIRKernel, BaseMLIRHardwareInfo): def __init__(self, @@ -106,6 +106,31 @@ def gemmini_gemm_mapping(self, M, N, K): return inner_I, inner_J, inner_K + def gemm_combination_mapping(self, M, N, K): + spad_size = self.spad_info["spad_size"] * self.vector_lane + max_spad_size = spad_size // 2 # double buffer + M_padded = ((M + self.vector_lane - 1) // self.vector_lane) * self.vector_lane + N_padded = ((N + self.vector_lane - 1) // self.vector_lane) * self.vector_lane + K_padded = ((K + self.vector_lane - 1) // self.vector_lane) * self.vector_lane + + max_used_spad_size = 0 + mapping = (self.vector_lane, self.vector_lane, self.vector_lane) + for tile_M in range(self.vector_lane, M_padded + 1, self.vector_lane): + for tile_N in range(self.vector_lane, N_padded + 1, self.vector_lane): + for tile_K in range(self.vector_lane, K_padded + 1, self.vector_lane): + used_spad_size = (tile_M * tile_K + tile_K * tile_N + tile_M * tile_N) * self.precision + if used_spad_size < max_spad_size and max_used_spad_size < used_spad_size: + max_used_spad_size = used_spad_size + mapping = (tile_M, tile_N, tile_K) + + Outer_M = math.ceil(M_padded / mapping[0]) + Outer_N = math.ceil(N_padded / mapping[1]) + Outer_K = math.ceil(K_padded / mapping[2]) + + # split mapping equally to avoid unnecessary padding + mapping = (M_padded // Outer_M, N_padded // Outer_N, K_padded // Outer_K) + return mapping + def meta_kernel(self): wrapper = V.graph.wrapper_code arg_attributes = self.kernel_arg_attributes @@ -183,7 +208,6 @@ def hook(): return "" def store_output(self): - def hook(): self.codegen_body() return 
textwrap.indent(self.body.getvalue(), " ").strip() #TODO: First line is not indented diff --git a/tests/MLP/test_mlp.py b/tests/MLP/test_mlp.py index 5927502d..6f6c9444 100644 --- a/tests/MLP/test_mlp.py +++ b/tests/MLP/test_mlp.py @@ -21,133 +21,6 @@ sys.path.append(os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim')) -try: - from PyTorchSimFrontend.mlir.mlir_codegen_backend import ( - MLIRScheduling, - ExtensionWrapperCodegen, - ) -except ImportError: - from PyTorchSimFrontend.mlir.mlir_codegen_backend import ( - MLIRScheduling, - ExtensionWrapperCodegen, - ) - -from torch._C import FileCheck -from torch._inductor import metrics -from torch._inductor.codegen.common import ( - get_scheduling_for_device, - get_wrapper_codegen_for_device, - register_backend_for_device, -) -from torch.testing._internal.common_utils import IS_MACOS -from torch.testing._internal.common_utils import TestCase as TorchTestCase - - -def remove_build_path(): - if sys.platform == "win32": - # Not wiping extensions build folder because Windows - return - default_build_root = torch.utils.cpp_extension.get_default_build_root() - if os.path.exists(default_build_root): - shutil.rmtree(default_build_root, ignore_errors=True) - -class TestCase(TorchTestCase): - @classmethod - def setUpClass(cls): - super().setUpClass() - cls._stack = contextlib.ExitStack() - cls._stack.enter_context( - config.patch( - { - "debug": True, - "debug_index_asserts": True, - "cpp.min_chunk_size": 1, - "triton.autotune_pointwise": False, # too slow - "implicit_fallbacks": False, - "generate_intermediate_hooks": True, - } - ) - ) - - @classmethod - def tearDownClass(cls): - cls._stack.close() - super().tearDownClass() - - def setUp(self): - torch._dynamo.reset() - torch._inductor.metrics.reset() - super().setUp() - self._start = time.perf_counter() - - def tearDown(self): - super().tearDown() - torch._dynamo.reset() - if os.environ.get("ERROR_ON_SLOW") == "1": - elapsed = time.perf_counter() - self._start - 
assert elapsed < 120 - - -class ExtensionBackendTests(TestCase): - module = None - - @classmethod - def setUpClass(cls): - super().setUpClass() - - # Build Extension - remove_build_path() - source_file_path = os.path.dirname(os.path.abspath(__file__)) - source_file = os.path.join( - source_file_path, "PyTorchSimFrontend/extension_device.cpp" - ) - cls.module = torch.utils.cpp_extension.load( - name="extension_device", - sources=[ - str(source_file), - ], - extra_cflags=["-g"], - verbose=True, - ) - - @classmethod - def tearDownClass(cls): - cls._stack.close() - super().tearDownClass() - - remove_build_path() - - def setUp(self): - torch._dynamo.reset() - super().setUp() - - # cpp extensions use relative paths. Those paths are relative to - # this file, so we'll change the working directory temporarily - self.old_working_dir = os.getcwd() - os.chdir(os.path.dirname(os.path.abspath(__file__))) - assert self.module is not None - - def tearDown(self): - super().tearDown() - torch._dynamo.reset() - - # return the working directory (see setUp) - os.chdir(self.old_working_dir) - - def test_open_device_registration(self): - torch.utils.rename_privateuse1_backend("extension_device") - - register_backend_for_device( - "extension_device", MLIRScheduling, ExtensionWrapperCodegen - ) - self.assertTrue( - get_scheduling_for_device("extension_device") == MLIRScheduling - ) - self.assertTrue( - get_wrapper_codegen_for_device("extension_device") - == ExtensionWrapperCodegen - ) - def test_result(name, out, cpu_out, rtol=1e-4, atol=1e-4): pass_message = f"|{name} Test Passed|" fail_message = f"|{name} Test Failed|" @@ -162,6 +35,7 @@ def test_result(name, out, cpu_out, rtol=1e-4, atol=1e-4): print("custom out: ", out.cpu()) print("cpu out: ", cpu_out) + exit(1) class MLP(nn.Module): def __init__(self, input_size, output_size, hidden_size): diff --git a/tests/MoE/test_moe.py b/tests/MoE/test_moe.py index a7ab630b..ff6dd00b 100644 --- a/tests/MoE/test_moe.py +++ 
b/tests/MoE/test_moe.py @@ -22,132 +22,6 @@ sys.path.append(os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim')) -try: - from PyTorchSimFrontend.mlir.mlir_codegen_backend import ( - MLIRScheduling, - ExtensionWrapperCodegen, - ) -except ImportError: - from .PyTorchSimFrontend.mlir.mlir_codegen_backend import ( - MLIRScheduling, - ExtensionWrapperCodegen, - ) - -from torch._C import FileCheck -from torch._inductor import metrics -from torch._inductor.codegen.common import ( - get_scheduling_for_device, - get_wrapper_codegen_for_device, - register_backend_for_device, -) -from torch.testing._internal.common_utils import IS_MACOS -from torch.testing._internal.common_utils import TestCase as TorchTestCase - -def remove_build_path(): - if sys.platform == "win32": - # Not wiping extensions build folder because Windows - return - default_build_root = torch.utils.cpp_extension.get_default_build_root() - if os.path.exists(default_build_root): - shutil.rmtree(default_build_root, ignore_errors=True) - -class TestCase(TorchTestCase): - @classmethod - def setUpClass(cls): - super().setUpClass() - cls._stack = contextlib.ExitStack() - cls._stack.enter_context( - config.patch( - { - "debug": True, - "debug_index_asserts": True, - "cpp.min_chunk_size": 1, - "triton.autotune_pointwise": False, # too slow - "implicit_fallbacks": False, - "generate_intermediate_hooks": True, - } - ) - ) - - @classmethod - def tearDownClass(cls): - cls._stack.close() - super().tearDownClass() - - def setUp(self): - torch._dynamo.reset() - torch._inductor.metrics.reset() - super().setUp() - self._start = time.perf_counter() - - def tearDown(self): - super().tearDown() - torch._dynamo.reset() - if os.environ.get("ERROR_ON_SLOW") == "1": - elapsed = time.perf_counter() - self._start - assert elapsed < 120 - - -class ExtensionBackendTests(TestCase): - module = None - - @classmethod - def setUpClass(cls): - super().setUpClass() - - # Build Extension - remove_build_path() - source_file_path = 
os.path.dirname(os.path.abspath(__file__)) - source_file = os.path.join( - source_file_path, "PyTorchSimFrontend/extension_device.cpp" - ) - cls.module = torch.utils.cpp_extension.load( - name="extension_device", - sources=[ - str(source_file), - ], - extra_cflags=["-g"], - verbose=True, - ) - - @classmethod - def tearDownClass(cls): - cls._stack.close() - super().tearDownClass() - - remove_build_path() - - def setUp(self): - torch._dynamo.reset() - super().setUp() - - # cpp extensions use relative paths. Those paths are relative to - # this file, so we'll change the working directory temporarily - self.old_working_dir = os.getcwd() - os.chdir(os.path.dirname(os.path.abspath(__file__))) - assert self.module is not None - - def tearDown(self): - super().tearDown() - torch._dynamo.reset() - - # return the working directory (see setUp) - os.chdir(self.old_working_dir) - - def test_open_device_registration(self): - torch.utils.rename_privateuse1_backend("extension_device") - - register_backend_for_device( - "extension_device", MLIRScheduling, ExtensionWrapperCodegen - ) - self.assertTrue( - get_scheduling_for_device("extension_device") == MLIRScheduling - ) - self.assertTrue( - get_wrapper_codegen_for_device("extension_device") - == ExtensionWrapperCodegen - ) - def test_result(name, out, cpu_out, rtol=1e-4, atol=1e-4): pass_message = f"|{name} Test Passed|" fail_message = f"|{name} Test Failed|" @@ -162,6 +36,7 @@ def test_result(name, out, cpu_out, rtol=1e-4, atol=1e-4): print("custom out: ", out.cpu()) print("cpu out: ", cpu_out) + exit(1) class SparseDispatcher(object): """Helper for implementing a mixture of experts. 
diff --git a/tests/test_bmm.py b/tests/test_bmm.py index 1e0eab59..483980d8 100644 --- a/tests/test_bmm.py +++ b/tests/test_bmm.py @@ -13,12 +13,12 @@ def test_result(name, out, cpu_out, rtol=1e-4, atol=1e-4): print("cpu out: ", cpu_out) exit(1) -def test_BMM(device): +def test_BMM(device, batch_size=1, m=32, n=16, k=64): def bmm(a, b): return torch.bmm(a, b.transpose(1, 2)) torch.manual_seed(0) - a = torch.randn(1, 32, 64).to(device=device) - b = torch.randn(1, 16, 64).to(device=device) + a = torch.randn(batch_size, m, k).to(device=device) + b = torch.randn(batch_size, n, k).to(device=device) opt_fn = torch.compile(dynamic=False)(bmm) res = opt_fn(a, b) out = bmm(a.cpu(), b.cpu()) @@ -33,3 +33,4 @@ def bmm(a, b): module = ExecutionEngine.setup_device() device = module.custom_device() test_BMM(device) + test_BMM(device, 2, 512, 512, 512) diff --git a/tests/test_conv2d.py b/tests/test_conv2d.py index 10f47faa..29924156 100644 --- a/tests/test_conv2d.py +++ b/tests/test_conv2d.py @@ -13,20 +13,22 @@ def test_result(name, out, cpu_out, rtol=1e-4, atol=1e-4): print("cpu out: ", cpu_out) exit(1) -def test_conv2d(device): - def custom_conv2d(a, b): +def test_conv2d(device, batch_size=1, in_channels=8, out_channels=16, input_size=64, kernel_size=3, stride=1, padding=0): + def custom_conv2d(a, b, bias): i_c = a.shape[1] o_c = b.shape[0] - conv2d = torch.nn.Conv2d(i_c, o_c, b.shape[-1], stride=1, padding=0, dilation=1) + conv2d = torch.nn.Conv2d(i_c, o_c, b.shape[-1], stride=stride, padding=padding, dilation=1) conv2d.weight = torch.nn.Parameter(b) + conv2d.bias = torch.nn.Parameter(bias) return conv2d(a) torch.manual_seed(0) - conv_input = torch.randn(1, 8, 64, 64).to(device=device) - conv_kernel = torch.randn(16, 8, 3, 3).to(device=device) + conv_input = torch.randn(batch_size, in_channels, input_size, input_size).to(memory_format=torch.channels_last, device=device) + conv_kernel = torch.randn(out_channels, in_channels, kernel_size, 
kernel_size).to(memory_format=torch.channels_last, device=device) + conv_bias = torch.randn(out_channels).to(device=device) opt_fn = torch.compile(dynamic=False)(custom_conv2d) - res = opt_fn(conv_input, conv_kernel) - out = custom_conv2d(conv_input.cpu(), conv_kernel.cpu()) - test_result("Conv2d Forward", res, out, rtol=1e-1, atol=1e-1) + res = opt_fn(conv_input, conv_kernel, conv_bias) + out = custom_conv2d(conv_input.cpu(), conv_kernel.cpu(), conv_bias.cpu()) + test_result("Conv2d Forward", res, out, rtol=1e-3, atol=1e-3) print("Max diff > ", torch.max(torch.abs(res.cpu() - out))) if __name__ == "__main__": @@ -37,4 +39,4 @@ def custom_conv2d(a, b): from Scheduler.scheduler import ExecutionEngine module = ExecutionEngine.setup_device() device = module.custom_device() - test_conv2d(device) + test_conv2d(device, batch_size=1, in_channels=128, out_channels=128, input_size=28, kernel_size=3, stride=1, padding=1) diff --git a/tests/test_resnet.py b/tests/test_resnet.py index 1fe0c674..37f8a583 100644 --- a/tests/test_resnet.py +++ b/tests/test_resnet.py @@ -14,11 +14,11 @@ def test_result(name, out, cpu_out, rtol=1e-4, atol=1e-4): exit(1) def test_resnet(device): - from torchvision.models import resnet18 - model = resnet18().eval() - model.to(device) + from torchvision.models import resnet + model = resnet._resnet(resnet.BasicBlock, [1, 1, 0, 0], weights=None, progress=False).eval() + model.to(device, memory_format=torch.channels_last) input = torch.randn(1, 3, 224, 224).to(device=device) - x1 = input.to(device=device) + x1 = input.to(device=device, memory_format=torch.channels_last) opt_fn = torch.compile(dynamic=False)(model) res = opt_fn(x1) print("ResNet18 Simulation Done") diff --git a/tests/test_sparsity.py b/tests/test_sparsity.py index d4f4c273..c72dbb98 100644 --- a/tests/test_sparsity.py +++ b/tests/test_sparsity.py @@ -8,7 +8,8 @@ import torch._dynamo import torch.utils.cpp_extension sys.path.append(os.environ.get('TORCHSIM_DIR', 
default='/workspace/PyTorchSim')) -from test_extension_backend import DecoderBlock, MLP, test_result +from test_transformer import DecoderBlock, test_result +from test_mlp import MLP def apply_random_zero(tensor, zero_prob, block_size=8): if not 0 <= zero_prob <= 1: