diff --git a/.github/workflows/docker-base-image.yml b/.github/workflows/docker-base-image.yml index 82614c8c..5f2005b7 100644 --- a/.github/workflows/docker-base-image.yml +++ b/.github/workflows/docker-base-image.yml @@ -27,7 +27,7 @@ jobs: with: registry: ghcr.io username: ${{ github.actor }} - password: ${{ secrets.GITHUB_TOKEN }} + password: ${{ secrets.GIT_ACCESS_TOKEN }} # Step 4: Build and Push Docker Image - name: Build and Push Docker Image diff --git a/.github/workflows/docker-image.yml b/.github/workflows/docker-image.yml index 2b420ff8..ece4a6d5 100644 --- a/.github/workflows/docker-image.yml +++ b/.github/workflows/docker-image.yml @@ -6,7 +6,7 @@ on: jobs: build: - runs-on: self-hosted + runs-on: [self-hosted, Linux] permissions: contents: read @@ -18,14 +18,16 @@ jobs: # Step 1: Checkout the repository - name: Checkout Code uses: actions/checkout@v4 - # Step 2: Log in to GitHub Container Registry (optional) - # If you need to push the built image, authenticate here. + with: + repository: PSAL-POSTECH/PyTorchSim + ref: ${{ env.GITHUB_SHA }} + submodules: recursive - name: Log in to GitHub Container Registry uses: docker/login-action@v3 with: registry: ghcr.io username: ${{ github.actor }} - password: ${{ secrets.GITHUB_TOKEN }} + password: ${{ secrets.GIT_ACCESS_TOKEN }} # Step 3: Pull the Cached Image - name: Pull Cached Image & Set environment @@ -53,9 +55,7 @@ jobs: # Step 4: Build and Push Docker Image - name: Build and Push Docker Image - uses: docker/build-push-action@v4 - env: - GIT_ACCESS_TOKEN: ${{ secrets.GIT_ACCESS_TOKEN }} + uses: docker/build-push-action@v6 with: context: . file: ./Dockerfile @@ -63,8 +63,9 @@ jobs: build-args: | GEM5_ASSET_ID=${{ env.GEM5_ASSET_ID }} LLVM_ASSET_ID=${{ env.LLVM_ASSET_ID }} - GIT_ACCESS_TOKEN=${{ env.GIT_ACCESS_TOKEN }} TORCHSIM_SHA=${{ env.GITHUB_SHA }} + secrets: | + GIT_ACCESS_TOKEN=${{ secrets.GIT_ACCESS_TOKEN }} tags: ghcr.io/psal-postech/${{ env.IMAGE_TAG }} test_add: @@ -72,7 +73,15 @@ jobs: runs-on: self-hosted needs: build steps: + - name: Log in to GitHub Container Registry + uses: docker/login-action@v3 + with: + registry: ghcr.io + username: ${{ github.actor }} + password: ${{ secrets.GIT_ACCESS_TOKEN }} - name: Run test_add.py + env: + GIT_ACCESS_TOKEN: ${{ secrets.GIT_ACCESS_TOKEN }} run: | echo "Running test_add.py" docker run --rm \ @@ -80,25 +89,41 @@ jobs: -e TORCHSIM_DUMP_PATH=/dump \ ghcr.io/psal-postech/torchsim-ci:${GITHUB_SHA} python3 PyTorchSim/tests/test_add.py - test_relu: - name: Run test_relu.py + test_activation: + name: Run test_activation.py runs-on: self-hosted needs: build steps: - - name: Run test_relu.py + - name: Log in to GitHub Container Registry + uses: docker/login-action@v3 + with: + registry: ghcr.io + username: ${{ github.actor }} + password: ${{ secrets.GIT_ACCESS_TOKEN }} + - name: Run test_activation.py + env: + GIT_ACCESS_TOKEN: ${{ secrets.GIT_ACCESS_TOKEN }} run: | - echo "Running test_relu.py" + echo "Running test_activation.py" docker run --rm \ -v /tmp/torchsim-ci/${GITHUB_SHA}:/dump \ -e TORCHSIM_DUMP_PATH=/dump \ - ghcr.io/psal-postech/torchsim-ci:${GITHUB_SHA} python3 PyTorchSim/tests/test_relu.py + ghcr.io/psal-postech/torchsim-ci:${GITHUB_SHA} python3 PyTorchSim/tests/test_activation.py test_batchnorm: name: Run test_batchnorm.py runs-on: self-hosted needs: build steps: + - name: Log in to GitHub Container Registry + uses: docker/login-action@v3 + with: + registry: ghcr.io + username: ${{ github.actor }} + password: ${{ secrets.GIT_ACCESS_TOKEN }} - name: Run test_batchnorm.py + env: + GIT_ACCESS_TOKEN: ${{ secrets.GIT_ACCESS_TOKEN }} run: | echo "Running test_batchnorm.py" docker run --rm \ @@ -111,7 +136,15 @@ jobs: runs-on: self-hosted needs: build steps: + - name: Log in to GitHub Container Registry + uses: docker/login-action@v3 + with: + registry: ghcr.io + username: ${{ github.actor }} + password: ${{ secrets.GIT_ACCESS_TOKEN }} - name: Run test_bmm.py + env: + GIT_ACCESS_TOKEN: ${{ secrets.GIT_ACCESS_TOKEN }} run: | echo "Running test_bmm.py" docker run --rm \ @@ -124,7 +157,15 @@ jobs: runs-on: self-hosted needs: build steps: + - name: Log in to GitHub Container Registry + uses: docker/login-action@v3 + with: + registry: ghcr.io + username: ${{ github.actor }} + password: ${{ secrets.GIT_ACCESS_TOKEN }} - name: Run test_cnn.py + env: + GIT_ACCESS_TOKEN: ${{ secrets.GIT_ACCESS_TOKEN }} run: | echo "Running test_cnn.py" docker run --rm \ @@ -137,7 +178,15 @@ jobs: runs-on: self-hosted needs: build steps: + - name: Log in to GitHub Container Registry + uses: docker/login-action@v3 + with: + registry: ghcr.io + username: ${{ github.actor }} + password: ${{ secrets.GIT_ACCESS_TOKEN }} - name: Run test_conv2d.py + env: + GIT_ACCESS_TOKEN: ${{ secrets.GIT_ACCESS_TOKEN }} run: | echo "Running test_conv2d.py" docker run --rm \ @@ -150,7 +199,15 @@ jobs: runs-on: self-hosted needs: build steps: + - name: Log in to GitHub Container Registry + uses: docker/login-action@v3 + with: + registry: ghcr.io + username: ${{ github.actor }} + password: ${{ secrets.GIT_ACCESS_TOKEN }} - name: Run test_matmul.py + env: + GIT_ACCESS_TOKEN: ${{ secrets.GIT_ACCESS_TOKEN }} run: | echo "Running test_matmul.py" docker run --rm \ @@ -163,7 +220,15 @@ jobs: runs-on: self-hosted needs: build steps: + - name: Log in to GitHub Container Registry + uses: docker/login-action@v3 + with: + registry: ghcr.io + username: ${{ github.actor }} + password: ${{ secrets.GIT_ACCESS_TOKEN }} - name: Run test_reduce.py + env: + GIT_ACCESS_TOKEN: ${{ secrets.GIT_ACCESS_TOKEN }} run: | echo "Running test_reduce.py" docker run --rm \ @@ -176,7 +241,15 @@ jobs: runs-on: self-hosted needs: build steps: + - name: Log in to GitHub Container Registry + uses: docker/login-action@v3 + with: + registry: ghcr.io + username: ${{ github.actor }} + password: ${{ secrets.GIT_ACCESS_TOKEN }} - name: Run test_softmax.py + env: + GIT_ACCESS_TOKEN: ${{ secrets.GIT_ACCESS_TOKEN }} run: | echo "Running test_softmax.py" docker run --rm \ @@ -189,7 +262,15 @@ jobs: runs-on: self-hosted needs: build steps: + - name: Log in to GitHub Container Registry + uses: docker/login-action@v3 + with: + registry: ghcr.io + username: ${{ github.actor }} + password: ${{ secrets.GIT_ACCESS_TOKEN }} - name: Run test_transpose2D.py + env: + GIT_ACCESS_TOKEN: ${{ secrets.GIT_ACCESS_TOKEN }} run: | echo "Running test_transpose2D.py" docker run --rm \ @@ -202,7 +283,15 @@ jobs: runs-on: self-hosted needs: build steps: + - name: Log in to GitHub Container Registry + uses: docker/login-action@v3 + with: + registry: ghcr.io + username: ${{ github.actor }} + password: ${{ secrets.GIT_ACCESS_TOKEN }} - name: Run test_view3D_2D.py + env: + GIT_ACCESS_TOKEN: ${{ secrets.GIT_ACCESS_TOKEN }} run: | echo "Running test_view3D_2D.py" docker run --rm \ @@ -215,7 +304,15 @@ jobs: runs-on: self-hosted needs: build steps: + - name: Log in to GitHub Container Registry + uses: docker/login-action@v3 + with: + registry: ghcr.io + username: ${{ github.actor }} + password: ${{ secrets.GIT_ACCESS_TOKEN }} - name: Run test_layernorm.py + env: + GIT_ACCESS_TOKEN: ${{ secrets.GIT_ACCESS_TOKEN }} run: | echo "Running test_layernorm.py" docker run --rm \ @@ -228,7 +325,15 @@ jobs: runs-on: self-hosted needs: build steps: + - name: Log in to GitHub Container Registry + uses: docker/login-action@v3 + with: + registry: ghcr.io + username: ${{ github.actor }} + password: ${{ secrets.GIT_ACCESS_TOKEN }} - name: Run test_mlp.py + env: + GIT_ACCESS_TOKEN: ${{ secrets.GIT_ACCESS_TOKEN }} run: | echo "Running test_mlp.py" docker run --rm \ @@ -241,7 +346,15 @@ jobs: runs-on: self-hosted needs: build steps: + - name: Log in to GitHub Container Registry + uses: docker/login-action@v3 + with: + registry: ghcr.io + username: ${{ github.actor }} + password: ${{ secrets.GIT_ACCESS_TOKEN }} - name: Run test_resnet.py + env: + GIT_ACCESS_TOKEN: ${{ secrets.GIT_ACCESS_TOKEN }} run: | echo "Running test_resnet.py" docker run --rm \ @@ -254,7 +367,15 @@ jobs: runs-on: self-hosted needs: build steps: + - name: Log in to GitHub Container Registry + uses: docker/login-action@v3 + with: + registry: ghcr.io + username: ${{ github.actor }} + password: ${{ secrets.GIT_ACCESS_TOKEN }} - name: Run test_transformer.py + env: + GIT_ACCESS_TOKEN: ${{ secrets.GIT_ACCESS_TOKEN }} run: | echo "Running test_transformer.py" docker run --rm \ @@ -267,7 +388,15 @@ jobs: runs-on: self-hosted needs: build steps: + - name: Log in to GitHub Container Registry + uses: docker/login-action@v3 + with: + registry: ghcr.io + username: ${{ github.actor }} + password: ${{ secrets.GIT_ACCESS_TOKEN }} - name: Run test_transpose3D.py + env: + GIT_ACCESS_TOKEN: ${{ secrets.GIT_ACCESS_TOKEN }} run: | echo "Running test_transpose3D.py" docker run --rm \ @@ -280,7 +409,15 @@ jobs: runs-on: self-hosted needs: build steps: + - name: Log in to GitHub Container Registry + uses: docker/login-action@v3 + with: + registry: ghcr.io + username: ${{ github.actor }} + password: ${{ secrets.GIT_ACCESS_TOKEN }} - name: Run test_sparsity.py + env: + GIT_ACCESS_TOKEN: ${{ secrets.GIT_ACCESS_TOKEN }} run: | echo "Running test_sparsity.py" docker run --rm \ @@ -293,7 +430,15 @@ jobs: runs-on: self-hosted needs: build steps: + - name: Log in to GitHub Container Registry + uses: docker/login-action@v3 + with: + registry: ghcr.io + username: ${{ github.actor }} + password: ${{ secrets.GIT_ACCESS_TOKEN }} - name: Run test_pool.py + env: + GIT_ACCESS_TOKEN: ${{ secrets.GIT_ACCESS_TOKEN }} run: | echo "Running test_pool.py" docker run --rm \ @@ -306,7 +451,15 @@ jobs: runs-on: self-hosted needs: build steps: + - name: Log in to GitHub Container Registry + uses: docker/login-action@v3 + with: + registry: ghcr.io + username: ${{ github.actor }} + password: ${{ secrets.GIT_ACCESS_TOKEN }} - name: Run test_single_perceptron.py + env: + GIT_ACCESS_TOKEN: ${{ secrets.GIT_ACCESS_TOKEN }} run: | echo "Running test_single_perceptron.py" docker run --rm \ @@ -319,21 +472,45 @@ jobs: runs-on: self-hosted needs: build steps: + - name: Log in to GitHub Container Registry + uses: docker/login-action@v3 + with: + registry: ghcr.io + username: ${{ github.actor }} + password: ${{ secrets.GIT_ACCESS_TOKEN }} - name: Run test_addmm_residual.py + env: + GIT_ACCESS_TOKEN: ${{ secrets.GIT_ACCESS_TOKEN }} run: | echo "Running test_addmm_residual.py" docker run --rm \ -v /tmp/torchsim-ci/${GITHUB_SHA}:/dump \ -e TORCHSIM_DUMP_PATH=/dump \ ghcr.io/psal-postech/torchsim-ci:${GITHUB_SHA} python3 PyTorchSim/tests/Fusion/test_addmm_residual.py + - name: Log in to GitHub Container Registry + uses: docker/login-action@v3 + with: + registry: ghcr.io + username: ${{ github.actor }} + password: ${{ secrets.GIT_ACCESS_TOKEN }} - name: Run test_matmul_activation.py + env: + GIT_ACCESS_TOKEN: ${{ secrets.GIT_ACCESS_TOKEN }} run: | echo "Running test_matmul_activation.py" docker run --rm \ -v /tmp/torchsim-ci/${GITHUB_SHA}:/dump \ -e TORCHSIM_DUMP_PATH=/dump \ ghcr.io/psal-postech/torchsim-ci:${GITHUB_SHA} python3 PyTorchSim/tests/Fusion/test_matmul_activation.py + - name: Log in to GitHub Container Registry + uses: docker/login-action@v3 + with: + registry: ghcr.io + username: ${{ github.actor }} + password: ${{ secrets.GIT_ACCESS_TOKEN }} - name: Run test_matmul_scalar.py + env: + GIT_ACCESS_TOKEN: ${{ secrets.GIT_ACCESS_TOKEN }} run: | echo "Running test_matmul_scalar.py" docker run --rm \ @@ -346,7 +523,15 @@ jobs: runs-on: self-hosted needs: build steps: + - name: Log in to GitHub Container Registry + uses: docker/login-action@v3 + with: + registry: ghcr.io + username: ${{ github.actor }} + password: ${{ secrets.GIT_ACCESS_TOKEN }} - name: Run test_moe.py + env: + GIT_ACCESS_TOKEN: ${{ secrets.GIT_ACCESS_TOKEN }} run: | echo "Running test_moe.py" docker run --rm \ @@ -354,6 +539,59 @@ jobs: -e TORCHSIM_DUMP_PATH=/dump \ ghcr.io/psal-postech/torchsim-ci:${GITHUB_SHA} python3 PyTorchSim/tests/MoE/test_moe.py + test_mistral: + name: Run test_mistral + runs-on: self-hosted + needs: build + steps: + - name: Log in to GitHub Container Registry + uses: docker/login-action@v3 + with: + registry: ghcr.io + username: ${{ github.actor }} + password: ${{ secrets.GIT_ACCESS_TOKEN }} + - name: Run test_mistral.py + env: + GIT_ACCESS_TOKEN: ${{ secrets.GIT_ACCESS_TOKEN }} + run: | + echo "Running test_mistral.py" + docker run --rm \ + -v /tmp/torchsim-ci/${GITHUB_SHA}:/dump \ + -e TORCHSIM_DUMP_PATH=/dump \ + ghcr.io/psal-postech/torchsim-ci:${GITHUB_SHA} python3 PyTorchSim/tests/Mixtral_8x7B/test_attention.py + + test_indirect: + name: Run test_indirect + runs-on: self-hosted + needs: build + env: + GIT_ACCESS_TOKEN: ${{ secrets.GIT_ACCESS_TOKEN }} + steps: + - name: Run test_indirect.py + run: | + echo "Running test_indirect.py" + echo $GIT_ACCESS_TOKEN | docker login ghcr.io -u USERNAME --password-stdin + docker run --rm \ + -v /tmp/torchsim-ci/${GITHUB_SHA}:/dump \ + -e TORCHSIM_DUMP_PATH=/dump \ + ghcr.io/psal-postech/torchsim-ci:${GITHUB_SHA} python3 PyTorchSim/tests/test_indirect_access.py + + test_scheduler: + name: Run test_scheduler + runs-on: self-hosted + needs: build + env: + GIT_ACCESS_TOKEN: ${{ secrets.GIT_ACCESS_TOKEN }} + steps: + - name: Run test_scheduler.py + run: | + echo "Running test_scheduler.py" + echo $GIT_ACCESS_TOKEN | docker login ghcr.io -u USERNAME --password-stdin + docker run --rm \ + -v /tmp/torchsim-ci/${GITHUB_SHA}:/dump \ + -e TORCHSIM_DUMP_PATH=/dump \ + ghcr.io/psal-postech/torchsim-ci:${GITHUB_SHA} python3 PyTorchSim/tests/test_scheduler.py + test_cleanup: name: Clean test cases runs-on: self-hosted @@ -361,8 +599,9 @@ jobs: test_matmul, test_reduce, test_softmax, test_transpose2D, test_view3D_2D, test_layernorm, test_mlp, test_resnet, test_transformer, test_transpose3D, - test_sparsity, test_relu, test_pool, test_perceptron, - test_fusion, test_moe] + test_sparsity, test_activation, test_pool, test_perceptron, + test_fusion, test_mistral, test_moe, test_indirect, test_scheduler] + steps: - name: Checkout code uses: actions/checkout@v3 @@ -370,4 +609,4 @@ jobs: run: | docker run --rm \ -v /tmp/torchsim-ci/${GITHUB_SHA}:/dump \ - ghcr.io/psal-postech/torchsim-ci:${GITHUB_SHA} ghcr.io/psal-postech/torchsim-ci:${GITHUB_SHA} chown -R $(id -u):$(id -g) /dump \ No newline at end of file + ghcr.io/psal-postech/torchsim-ci:${GITHUB_SHA} chown -R $(id -u):$(id -g) /dump \ No newline at end of file diff --git a/.github/workflows/pull-request.yml b/.github/workflows/pull-request.yml index 003a0d01..ecdbf861 100644 --- a/.github/workflows/pull-request.yml +++ b/.github/workflows/pull-request.yml @@ -6,7 +6,7 @@ on: jobs: build: - runs-on: self-hosted + runs-on: [self-hosted, Linux] permissions: contents: read @@ -18,6 +18,10 @@ jobs: # Step 1: Checkout the repository - name: Checkout Code uses: actions/checkout@v4 + with: + repository: PSAL-POSTECH/PyTorchSim + ref: ${{ github.event.pull_request.head.sha }} + submodules: recursive # Step 2: Log in to GitHub Container Registry (optional) # If you need to push the built image, authenticate here. - name: Log in to GitHub Container Registry @@ -25,7 +29,7 @@ jobs: with: registry: ghcr.io username: ${{ github.actor }} - password: ${{ secrets.GITHUB_TOKEN }} + password: ${{ secrets.GIT_ACCESS_TOKEN }} # Step 3: Pull the Cached Image - name: Pull Cached Image & Set environment @@ -53,9 +57,7 @@ jobs: # Step 4: Build and Push Docker Image - name: Build and Push Docker Image - uses: docker/build-push-action@v4 - env: - GIT_ACCESS_TOKEN: ${{ secrets.GIT_ACCESS_TOKEN }} + uses: docker/build-push-action@v6 with: context: . file: ./Dockerfile @@ -63,16 +65,32 @@ jobs: build-args: | GEM5_ASSET_ID=${{ env.GEM5_ASSET_ID }} LLVM_ASSET_ID=${{ env.LLVM_ASSET_ID }} - GIT_ACCESS_TOKEN=${{ env.GIT_ACCESS_TOKEN }} TORCHSIM_SHA=${{ env.GITHUB_SHA }} + secrets: | + GIT_ACCESS_TOKEN=${{ secrets.GIT_ACCESS_TOKEN }} tags: ghcr.io/psal-postech/${{ env.IMAGE_TAG}} test_add: name: Run test_add.py runs-on: self-hosted + + permissions: + contents: read + packages: write + attestations: write + id-token: write needs: build + steps: + - name: Log in to GitHub Container Registry + uses: docker/login-action@v3 + with: + registry: ghcr.io + username: ${{ github.actor }} + password: ${{ secrets.GIT_ACCESS_TOKEN }} - name: Run test_add.py + env: + GIT_ACCESS_TOKEN: ${{ secrets.GIT_ACCESS_TOKEN }} run: | echo "Running test_add.py" docker run --rm \ @@ -80,25 +98,41 @@ jobs: -e TORCHSIM_DUMP_PATH=/dump \ ghcr.io/psal-postech/torchsim-ci:${GITHUB_SHA} python3 PyTorchSim/tests/test_add.py - test_relu: - name: Run test_relu.py + test_activation: + name: Run test_activation.py runs-on: self-hosted needs: build steps: - - name: Run test_relu.py + - name: Log in to GitHub Container Registry + uses: docker/login-action@v3 + with: + registry: ghcr.io + username: ${{ github.actor }} + password: ${{ secrets.GIT_ACCESS_TOKEN }} + - name: Run test_activation.py + env: + GIT_ACCESS_TOKEN: ${{ secrets.GIT_ACCESS_TOKEN }} run: | - echo "Running test_relu.py" + echo "Running test_activation.py" docker run --rm \ -v /tmp/torchsim-ci/${GITHUB_SHA}:/dump \ -e TORCHSIM_DUMP_PATH=/dump \ - ghcr.io/psal-postech/torchsim-ci:${GITHUB_SHA} python3 PyTorchSim/tests/test_relu.py + ghcr.io/psal-postech/torchsim-ci:${GITHUB_SHA} python3 PyTorchSim/tests/test_activation.py test_batchnorm: name: Run test_batchnorm.py runs-on: self-hosted needs: build steps: + - name: Log in to GitHub Container Registry + uses: docker/login-action@v3 + with: + registry: ghcr.io + username: ${{ github.actor }} + password: ${{ secrets.GIT_ACCESS_TOKEN }} - name: Run test_batchnorm.py + env: + GIT_ACCESS_TOKEN: ${{ secrets.GIT_ACCESS_TOKEN }} run: | echo "Running test_batchnorm.py" docker run --rm \ @@ -111,7 +145,15 @@ jobs: runs-on: self-hosted needs: build steps: + - name: Log in to GitHub Container Registry + uses: docker/login-action@v3 + with: + registry: ghcr.io + username: ${{ github.actor }} + password: ${{ secrets.GIT_ACCESS_TOKEN }} - name: Run test_bmm.py + env: + GIT_ACCESS_TOKEN: ${{ secrets.GIT_ACCESS_TOKEN }} run: | echo "Running test_bmm.py" docker run --rm \ @@ -124,7 +166,15 @@ jobs: runs-on: self-hosted needs: build steps: + - name: Log in to GitHub Container Registry + uses: docker/login-action@v3 + with: + registry: ghcr.io + username: ${{ github.actor }} + password: ${{ secrets.GIT_ACCESS_TOKEN }} - name: Run test_cnn.py + env: + GIT_ACCESS_TOKEN: ${{ secrets.GIT_ACCESS_TOKEN }} run: | echo "Running test_cnn.py" docker run --rm \ @@ -137,7 +187,15 @@ jobs: runs-on: self-hosted needs: build steps: + - name: Log in to GitHub Container Registry + uses: docker/login-action@v3 + with: + registry: ghcr.io + username: ${{ github.actor }} + password: ${{ secrets.GIT_ACCESS_TOKEN }} - name: Run test_conv2d.py + env: + GIT_ACCESS_TOKEN: ${{ secrets.GIT_ACCESS_TOKEN }} run: | echo "Running test_conv2d.py" docker run --rm \ @@ -150,7 +208,15 @@ jobs: runs-on: self-hosted needs: build steps: + - name: Log in to GitHub Container Registry + uses: docker/login-action@v3 + with: + registry: ghcr.io + username: ${{ github.actor }} + password: ${{ secrets.GIT_ACCESS_TOKEN }} - name: Run test_matmul.py + env: + GIT_ACCESS_TOKEN: ${{ secrets.GIT_ACCESS_TOKEN }} run: | echo "Running test_matmul.py" docker run --rm \ @@ -163,7 +229,15 @@ jobs: runs-on: self-hosted needs: build steps: + - name: Log in to GitHub Container Registry + uses: docker/login-action@v3 + with: + registry: ghcr.io + username: ${{ github.actor }} + password: ${{ secrets.GIT_ACCESS_TOKEN }} - name: Run test_reduce.py + env: + GIT_ACCESS_TOKEN: ${{ secrets.GIT_ACCESS_TOKEN }} run: | echo "Running test_reduce.py" docker run --rm \ @@ -176,7 +250,15 @@ jobs: runs-on: self-hosted needs: build steps: + - name: Log in to GitHub Container Registry + uses: docker/login-action@v3 + with: + registry: ghcr.io + username: ${{ github.actor }} + password: ${{ secrets.GIT_ACCESS_TOKEN }} - name: Run test_softmax.py + env: + GIT_ACCESS_TOKEN: ${{ secrets.GIT_ACCESS_TOKEN }} run: | echo "Running test_softmax.py" docker run --rm \ @@ -189,7 +271,15 @@ jobs: runs-on: self-hosted needs: build steps: + - name: Log in to GitHub Container Registry + uses: docker/login-action@v3 + with: + registry: ghcr.io + username: ${{ github.actor }} + password: ${{ secrets.GIT_ACCESS_TOKEN }} - name: Run test_transpose2D.py + env: + GIT_ACCESS_TOKEN: ${{ secrets.GIT_ACCESS_TOKEN }} run: | echo "Running test_transpose2D.py" docker run --rm \ @@ -202,7 +292,15 @@ jobs: runs-on: self-hosted needs: build steps: + - name: Log in to GitHub Container Registry + uses: docker/login-action@v3 + with: + registry: ghcr.io + username: ${{ github.actor }} + password: ${{ secrets.GIT_ACCESS_TOKEN }} - name: Run test_view3D_2D.py + env: + GIT_ACCESS_TOKEN: ${{ secrets.GIT_ACCESS_TOKEN }} run: | echo "Running test_view3D_2D.py" docker run --rm \ @@ -215,7 +313,15 @@ jobs: runs-on: self-hosted needs: build steps: + - name: Log in to GitHub Container Registry + uses: docker/login-action@v3 + with: + registry: ghcr.io + username: ${{ github.actor }} + password: ${{ secrets.GIT_ACCESS_TOKEN }} - name: Run test_layernorm.py + env: + GIT_ACCESS_TOKEN: ${{ secrets.GIT_ACCESS_TOKEN }} run: | echo "Running test_layernorm.py" docker run --rm \ @@ -228,7 +334,15 @@ jobs: runs-on: self-hosted needs: build steps: + - name: Log in to GitHub Container Registry + uses: docker/login-action@v3 + with: + registry: ghcr.io + username: ${{ github.actor }} + password: ${{ secrets.GIT_ACCESS_TOKEN }} - name: Run test_mlp.py + env: + GIT_ACCESS_TOKEN: ${{ secrets.GIT_ACCESS_TOKEN }} run: | echo "Running test_mlp.py" docker run --rm \ @@ -241,7 +355,16 @@ jobs: runs-on: self-hosted needs: build steps: - - name: Run test_resnet.py + - name: Log in to GitHub Container Registry + uses: docker/login-action@v3 + with: + registry: ghcr.io + username: ${{ github.actor }} + password: ${{ secrets.GIT_ACCESS_TOKEN }} + + - name: Run test_resnet18.py + env: + GIT_ACCESS_TOKEN: ${{ secrets.GIT_ACCESS_TOKEN }} run: | echo "Running test_resnet.py" docker run --rm \ @@ -249,12 +372,30 @@ jobs: -e TORCHSIM_DUMP_PATH=/dump \ ghcr.io/psal-postech/torchsim-ci:${GITHUB_SHA} python3 PyTorchSim/tests/test_resnet.py + - name: Run test_resnet50.py + env: + GIT_ACCESS_TOKEN: ${{ secrets.GIT_ACCESS_TOKEN }} + run: | + echo "Running test_resnet.py" + docker run --rm \ + -v /tmp/torchsim-ci/${GITHUB_SHA}:/dump \ + -e TORCHSIM_DUMP_PATH=/dump \ + ghcr.io/psal-postech/torchsim-ci:${GITHUB_SHA} python3 PyTorchSim/tests/test_resnet.py --model_type resnet50 + test_transformer: name: Run test_transformer.py runs-on: self-hosted needs: build steps: + - name: Log in to GitHub Container Registry + uses: docker/login-action@v3 + with: + registry: ghcr.io + username: ${{ github.actor }} + password: ${{ secrets.GIT_ACCESS_TOKEN }} - name: Run test_transformer.py + env: + GIT_ACCESS_TOKEN: ${{ secrets.GIT_ACCESS_TOKEN }} run: | echo "Running test_transformer.py" docker run --rm \ @@ -267,7 +408,15 @@ jobs: runs-on: self-hosted needs: build steps: + - name: Log in to GitHub Container Registry + uses: docker/login-action@v3 + with: + registry: ghcr.io + username: ${{ github.actor }} + password: ${{ secrets.GIT_ACCESS_TOKEN }} - name: Run test_transpose3D.py + env: + GIT_ACCESS_TOKEN: ${{ secrets.GIT_ACCESS_TOKEN }} run: | echo "Running test_transpose3D.py" docker run --rm \ @@ -280,7 +429,15 @@ jobs: runs-on: self-hosted needs: build steps: + - name: Log in to GitHub Container Registry + uses: docker/login-action@v3 + with: + registry: ghcr.io + username: ${{ github.actor }} + password: ${{ secrets.GIT_ACCESS_TOKEN }} - name: Run test_sparsity.py + env: + GIT_ACCESS_TOKEN: ${{ secrets.GIT_ACCESS_TOKEN }} run: | echo "Running test_sparsity.py" docker run --rm \ @@ -293,7 +450,15 @@ jobs: runs-on: self-hosted needs: build steps: + - name: Log in to GitHub Container Registry + uses: docker/login-action@v3 + with: + registry: ghcr.io + username: ${{ github.actor }} + password: ${{ secrets.GIT_ACCESS_TOKEN }} - name: Run test_pool.py + env: + GIT_ACCESS_TOKEN: ${{ secrets.GIT_ACCESS_TOKEN }} run: | echo "Running test_pool.py" docker run --rm \ @@ -306,7 +471,15 @@ jobs: runs-on: self-hosted needs: build steps: + - name: Log in to GitHub Container Registry + uses: docker/login-action@v3 + with: + registry: ghcr.io + username: ${{ github.actor }} + password: ${{ secrets.GIT_ACCESS_TOKEN }} - name: Run test_single_perceptron.py + env: + GIT_ACCESS_TOKEN: ${{ secrets.GIT_ACCESS_TOKEN }} run: | echo "Running test_single_perceptron.py" docker run --rm \ @@ -319,21 +492,35 @@ jobs: runs-on: self-hosted needs: build steps: + - name: Log in to GitHub Container Registry + uses: docker/login-action@v3 + with: + registry: ghcr.io + username: ${{ github.actor }} + password: ${{ secrets.GIT_ACCESS_TOKEN }} - name: Run test_addmm_residual.py + env: + GIT_ACCESS_TOKEN: ${{ secrets.GIT_ACCESS_TOKEN }} run: | echo "Running test_addmm_residual.py" docker run --rm \ -v /tmp/torchsim-ci/${GITHUB_SHA}:/dump \ -e TORCHSIM_DUMP_PATH=/dump \ ghcr.io/psal-postech/torchsim-ci:${GITHUB_SHA} python3 PyTorchSim/tests/Fusion/test_addmm_residual.py + - name: Run test_matmul_activation.py + env: + GIT_ACCESS_TOKEN: ${{ secrets.GIT_ACCESS_TOKEN }} run: | echo "Running test_matmul_activation.py" docker run --rm \ -v /tmp/torchsim-ci/${GITHUB_SHA}:/dump \ -e TORCHSIM_DUMP_PATH=/dump \ ghcr.io/psal-postech/torchsim-ci:${GITHUB_SHA} python3 PyTorchSim/tests/Fusion/test_matmul_activation.py + - name: Run test_matmul_scalar.py + env: + GIT_ACCESS_TOKEN: ${{ secrets.GIT_ACCESS_TOKEN }} run: | echo "Running test_matmul_scalar.py" docker run --rm \ @@ -341,12 +528,70 @@ jobs: -e TORCHSIM_DUMP_PATH=/dump \ ghcr.io/psal-postech/torchsim-ci:${GITHUB_SHA} python3 PyTorchSim/tests/Fusion/test_matmul_scalar.py + - name: Run test_matmul_reduction.py + env: + GIT_ACCESS_TOKEN: ${{ secrets.GIT_ACCESS_TOKEN }} + run: | + echo "Running test_matmul_reduction.py" + docker run --rm \ + -v /tmp/torchsim-ci/${GITHUB_SHA}:/dump \ + -e TORCHSIM_DUMP_PATH=/dump \ + ghcr.io/psal-postech/torchsim-ci:${GITHUB_SHA} python3 PyTorchSim/tests/Fusion/test_matmul_reduction.py + + - name: Run test_bmm_reduction.py + env: + GIT_ACCESS_TOKEN: ${{ secrets.GIT_ACCESS_TOKEN }} + run: | + echo "Running test_bmm_reduction.py" + docker run --rm \ + -v /tmp/torchsim-ci/${GITHUB_SHA}:/dump \ + -e TORCHSIM_DUMP_PATH=/dump \ + ghcr.io/psal-postech/torchsim-ci:${GITHUB_SHA} python3 PyTorchSim/tests/Fusion/test_bmm_reduction.py + + - name: Run test_prologue_fusion.py + env: + GIT_ACCESS_TOKEN: ${{ secrets.GIT_ACCESS_TOKEN }} + run: | + echo "Running test_prologue_fusion.py" + docker run --rm \ + -v /tmp/torchsim-ci/${GITHUB_SHA}:/dump \ + -e TORCHSIM_DUMP_PATH=/dump \ + ghcr.io/psal-postech/torchsim-ci:${GITHUB_SHA} python3 PyTorchSim/tests/Fusion/test_prologue_fusion.py + + - name: Run test_transformer_fusion.py + env: + GIT_ACCESS_TOKEN: ${{ secrets.GIT_ACCESS_TOKEN }} + run: | + echo "Running test_transformer_fusion.py" + docker run --rm \ + -v /tmp/torchsim-ci/${GITHUB_SHA}:/dump \ + -e TORCHSIM_DUMP_PATH=/dump \ + ghcr.io/psal-postech/torchsim-ci:${GITHUB_SHA} python3 PyTorchSim/tests/Fusion/test_transformer_fusion.py + + - name: Run test_conv_fusion.py + env: + GIT_ACCESS_TOKEN: ${{ secrets.GIT_ACCESS_TOKEN }} + run: | + echo "Running test_conv_fusion.py" + docker run --rm \ + -v /tmp/torchsim-ci/${GITHUB_SHA}:/dump \ + -e TORCHSIM_DUMP_PATH=/dump \ + ghcr.io/psal-postech/torchsim-ci:${GITHUB_SHA} python3 PyTorchSim/tests/Fusion/test_conv_fusion.py + test_moe: name: Run test_moe runs-on: self-hosted needs: build steps: + - name: Log in to GitHub Container Registry + uses: docker/login-action@v3 + with: + registry: ghcr.io + username: ${{ github.actor }} + password: ${{ secrets.GIT_ACCESS_TOKEN }} - name: Run test_moe.py + env: + GIT_ACCESS_TOKEN: ${{ secrets.GIT_ACCESS_TOKEN }} run: | echo "Running test_moe.py" docker run --rm \ @@ -354,6 +599,59 @@ jobs: -e TORCHSIM_DUMP_PATH=/dump \ ghcr.io/psal-postech/torchsim-ci:${GITHUB_SHA} python3 PyTorchSim/tests/MoE/test_moe.py + test_mistral: + name: Run test_mistral + runs-on: self-hosted + needs: build + steps: + - name: Log in to GitHub Container Registry + uses: docker/login-action@v3 + with: + registry: ghcr.io + username: ${{ github.actor }} + password: ${{ secrets.GIT_ACCESS_TOKEN }} + - name: Run test_mistral.py + env: + GIT_ACCESS_TOKEN: ${{ secrets.GIT_ACCESS_TOKEN }} + run: | + echo "Running test_mistral.py" + docker run --rm \ + -v /tmp/torchsim-ci/${GITHUB_SHA}:/dump \ + -e TORCHSIM_DUMP_PATH=/dump \ + ghcr.io/psal-postech/torchsim-ci:${GITHUB_SHA} python3 PyTorchSim/tests/Mixtral_8x7B/test_attention.py + + test_indirect: + name: Run test_indirect + runs-on: self-hosted + needs: build + env: + GIT_ACCESS_TOKEN: ${{ secrets.GIT_ACCESS_TOKEN }} + steps: + - name: Run test_indirect.py + run: | + echo "Running test_indirect.py" + echo $GIT_ACCESS_TOKEN | docker login ghcr.io -u USERNAME --password-stdin + docker run --rm \ + -v /tmp/torchsim-ci/${GITHUB_SHA}:/dump \ + -e TORCHSIM_DUMP_PATH=/dump \ + ghcr.io/psal-postech/torchsim-ci:${GITHUB_SHA} python3 PyTorchSim/tests/test_indirect_access.py + + test_scheduler: + name: Run test_scheduler + runs-on: self-hosted + needs: build + env: + GIT_ACCESS_TOKEN: ${{ secrets.GIT_ACCESS_TOKEN }} + steps: + - name: Run test_scheduler.py + run: | + echo "Running test_scheduler.py" + echo $GIT_ACCESS_TOKEN | docker login ghcr.io -u USERNAME --password-stdin + docker run --rm \ + -v /tmp/torchsim-ci/${GITHUB_SHA}:/dump \ + -e TORCHSIM_DUMP_PATH=/dump \ + ghcr.io/psal-postech/torchsim-ci:${GITHUB_SHA} python3 PyTorchSim/tests/test_scheduler.py + test_cleanup: name: Clean test cases runs-on: self-hosted @@ -361,8 +659,8 @@ jobs: test_matmul, test_reduce, test_softmax, test_transpose2D, test_view3D_2D, test_layernorm, test_mlp, test_resnet, test_transformer, test_transpose3D, - test_sparsity, test_relu, test_pool, test_perceptron, - test_fusion, test_moe] + test_sparsity, test_activation, test_pool, test_perceptron, + test_fusion, test_mistral, test_moe, test_indirect, test_scheduler] steps: - name: Checkout code uses: actions/checkout@v3 @@ -370,4 +668,4 @@ jobs: run: | docker run --rm \ -v /tmp/torchsim-ci/${GITHUB_SHA}:/dump \ - ghcr.io/psal-postech/torchsim-ci:${GITHUB_SHA} ghcr.io/psal-postech/torchsim-ci:${GITHUB_SHA} chown -R $(id -u):$(id -g) /dump + ghcr.io/psal-postech/torchsim-ci:${GITHUB_SHA} chown -R $(id -u):$(id -g) /dump \ No newline at end of file diff --git a/.github/workflows/pull-request_mobile.yml b/.github/workflows/pull-request_mobile.yml new file mode 100644 index 00000000..053e3eac --- /dev/null +++ b/.github/workflows/pull-request_mobile.yml @@ -0,0 +1,658 @@ +name: PR test CI for mobile + +on: + pull_request: + branches: [ "master", "develop" ] + +jobs: + build: + runs-on: [self-hosted, Linux] + + permissions: + contents: read + packages: write + attestations: write + id-token: write + + steps: + # Step 1: Checkout the repository + - name: Checkout Code + uses: actions/checkout@v4 + with: + repository: PSAL-POSTECH/PyTorchSim + ref: ${{ env.github.event.pull_request.head.sha }} + submodules: recursive + - name: Log in to GitHub Container Registry + uses: docker/login-action@v3 + with: + registry: ghcr.io + username: ${{ github.actor }} + password: ${{ secrets.GIT_ACCESS_TOKEN }} + + # Step 3: Pull the Cached Image + - name: Pull Cached Image & Set environment + env: + GIT_ACCESS_TOKEN: ${{ secrets.GIT_ACCESS_TOKEN }} + run: | + docker pull ghcr.io/psal-postech/torchsim_base:latest || echo "No cache available" + echo "IMAGE_TAG=torchsim-ci:${GITHUB_SHA}" >> $GITHUB_ENV + echo "GITHUB_SHA=${{github.event.pull_request.head.sha}}" >> $GITHUB_ENV + echo "GITHUB_SHA=${{github.event.pull_request.head.sha}}" + gem5_response_file=/tmp/releases-gem5-latest.json + response=$(curl -sH "Authorization: Bearer ${GIT_ACCESS_TOKEN}" https://api.github.com/repos/PSAL-POSTECH/GEM5/releases/latest > ${gem5_response_file} ) + GEM5_ASSET_ID=$(cat ${gem5_response_file} | jq ".assets[0]."id"") + echo "GEM5_ASSET_ID=$GEM5_ASSET_ID" + echo "GEM5_ASSET_ID=$GEM5_ASSET_ID" >> $GITHUB_ENV + + llvm_response_file=/tmp/releases-gem5-latest.json + response=$(curl -sH "Authorization: Bearer ${GIT_ACCESS_TOKEN}" https://api.github.com/repos/PSAL-POSTECH/llvm-project/releases/latest > ${llvm_response_file} ) + LLVM_ASSET_ID=$(cat ${llvm_response_file} | jq ".assets[0]."id"") + echo "LLVM_ASSET_ID=$LLVM_ASSET_ID" + echo "LLVM_ASSET_ID=$LLVM_ASSET_ID" >> $GITHUB_ENV + + mkdir -p /tmp/torchsim-ci/${GITHUB_SHA} + echo "DUMP_PATH=/tmp/torchsim-ci/${GITHUB_SHA}" + + # Step 4: Build and Push Docker Image + - name: Build and Push Docker Image + uses: docker/build-push-action@v6 + with: + context: . + file: ./Dockerfile + push: true + build-args: | + GEM5_ASSET_ID=${{ env.GEM5_ASSET_ID }} + LLVM_ASSET_ID=${{ env.LLVM_ASSET_ID }} + TORCHSIM_SHA=${{ env.GITHUB_SHA }} + secrets: | + GIT_ACCESS_TOKEN=${{ secrets.GIT_ACCESS_TOKEN }} + tags: ghcr.io/psal-postech/${{ env.IMAGE_TAG}} + + test_add: + name: Run test_add.py + runs-on: self-hosted + + permissions: + contents: read + packages: write + attestations: write + id-token: write + needs: build + + steps: + - name: Log in to GitHub Container Registry + uses: docker/login-action@v3 + with: + registry: ghcr.io + username: ${{ github.actor }} + password: ${{ secrets.GIT_ACCESS_TOKEN }} + - name: Run test_add.py + env: + GIT_ACCESS_TOKEN: ${{ secrets.GIT_ACCESS_TOKEN }} + run: | + echo "Running test_add.py" + docker run --rm \ + -v /tmp/torchsim-ci/${GITHUB_SHA}:/dump \ + -e TORCHSIM_DUMP_PATH=/dump -e TORCHSIM_VECTOR_LANE=8 -e TORCHSIM_SPAD_SIZE=32 \ + ghcr.io/psal-postech/torchsim-ci:${GITHUB_SHA} python3 PyTorchSim/tests/test_add.py + + test_activation: + name: Run test_activation.py + runs-on: self-hosted + needs: build + steps: + - name: Log in to GitHub Container Registry + uses: docker/login-action@v3 + with: + registry: ghcr.io + username: ${{ github.actor }} + password: ${{ secrets.GIT_ACCESS_TOKEN }} + - name: Run test_activation.py + env: + GIT_ACCESS_TOKEN: ${{ secrets.GIT_ACCESS_TOKEN }} + run: | + echo "Running test_activation.py" + docker run --rm \ + -v /tmp/torchsim-ci/${GITHUB_SHA}:/dump \ + -e TORCHSIM_DUMP_PATH=/dump -e TORCHSIM_VECTOR_LANE=8 -e TORCHSIM_SPAD_SIZE=32 \ + ghcr.io/psal-postech/torchsim-ci:${GITHUB_SHA} python3 PyTorchSim/tests/test_activation.py + + test_batchnorm: + name: Run test_batchnorm.py + runs-on: self-hosted + needs: build + steps: + - name: Log in to GitHub Container Registry + uses: docker/login-action@v3 + with: + registry: ghcr.io + username: ${{ github.actor }} + password: ${{ secrets.GIT_ACCESS_TOKEN }} + - name: Run test_batchnorm.py + env: + GIT_ACCESS_TOKEN: ${{ secrets.GIT_ACCESS_TOKEN }} + run: | + echo "Running test_batchnorm.py" + docker run --rm \ + -v /tmp/torchsim-ci/${GITHUB_SHA}:/dump \ + -e TORCHSIM_DUMP_PATH=/dump -e TORCHSIM_VECTOR_LANE=8 -e TORCHSIM_SPAD_SIZE=32 \ + ghcr.io/psal-postech/torchsim-ci:${GITHUB_SHA} python3 PyTorchSim/tests/test_batchnorm.py + + test_bmm: + name: Run test_bmm.py + runs-on: self-hosted + needs: build + steps: + - name: Log in to GitHub Container Registry + uses: docker/login-action@v3 + with: + registry: ghcr.io + username: ${{ github.actor }} + password: ${{ secrets.GIT_ACCESS_TOKEN }} + - name: Run test_bmm.py + env: + GIT_ACCESS_TOKEN: ${{ secrets.GIT_ACCESS_TOKEN }} + run: | + echo "Running test_bmm.py" + docker run --rm \ + -v /tmp/torchsim-ci/${GITHUB_SHA}:/dump \ + -e TORCHSIM_DUMP_PATH=/dump -e TORCHSIM_VECTOR_LANE=8 -e TORCHSIM_SPAD_SIZE=32 \ + ghcr.io/psal-postech/torchsim-ci:${GITHUB_SHA} python3 PyTorchSim/tests/test_bmm.py + + test_cnn: + name: Run test_cnn.py + runs-on: self-hosted + needs: build + steps: + - name: Log in to GitHub Container Registry + uses: docker/login-action@v3 + with: + registry: ghcr.io + username: ${{ github.actor }} + password: ${{ secrets.GIT_ACCESS_TOKEN }} + - name: Run test_cnn.py + env: + GIT_ACCESS_TOKEN: ${{ secrets.GIT_ACCESS_TOKEN }} + run: | + echo "Running test_cnn.py" + docker run --rm \ + -v /tmp/torchsim-ci/${GITHUB_SHA}:/dump \ + -e TORCHSIM_DUMP_PATH=/dump -e TORCHSIM_VECTOR_LANE=8 -e TORCHSIM_SPAD_SIZE=32 \ + ghcr.io/psal-postech/torchsim-ci:${GITHUB_SHA} python3 PyTorchSim/tests/test_cnn.py + + test_conv2d: + name: Run test_conv2d.py + runs-on: self-hosted + needs: build + steps: + - name: Log in to GitHub Container Registry + uses: docker/login-action@v3 + with: + registry: ghcr.io + username: ${{ github.actor }} + password: ${{ secrets.GIT_ACCESS_TOKEN }} + - name: Run test_conv2d.py + env: + GIT_ACCESS_TOKEN: ${{ secrets.GIT_ACCESS_TOKEN }} + run: | + echo "Running test_conv2d.py" + docker run --rm \ + -v /tmp/torchsim-ci/${GITHUB_SHA}:/dump \ + -e TORCHSIM_DUMP_PATH=/dump -e TORCHSIM_VECTOR_LANE=8 -e TORCHSIM_SPAD_SIZE=32 \ + ghcr.io/psal-postech/torchsim-ci:${GITHUB_SHA} python3 PyTorchSim/tests/test_conv2d.py + + test_matmul: + name: Run test_matmul.py + runs-on: self-hosted + needs: build + steps: + - name: Log in to GitHub Container Registry + uses: docker/login-action@v3 + with: + registry: ghcr.io + username: ${{ github.actor }} + password: ${{ secrets.GIT_ACCESS_TOKEN }} + - name: Run test_matmul.py + env: + GIT_ACCESS_TOKEN: ${{ secrets.GIT_ACCESS_TOKEN }} + run: | + echo "Running test_matmul.py" + docker run --rm \ + -v /tmp/torchsim-ci/${GITHUB_SHA}:/dump \ + -e TORCHSIM_DUMP_PATH=/dump -e TORCHSIM_VECTOR_LANE=8 -e TORCHSIM_SPAD_SIZE=32 \ + ghcr.io/psal-postech/torchsim-ci:${GITHUB_SHA} python3 PyTorchSim/tests/test_matmul.py + + test_reduce: + name: Run test_reduce.py + runs-on: self-hosted + needs: build + steps: + - name: Log in to GitHub Container Registry + uses: docker/login-action@v3 + with: + registry: ghcr.io + username: ${{ github.actor }} + password: ${{ secrets.GIT_ACCESS_TOKEN }} + - name: Run test_reduce.py + env: + GIT_ACCESS_TOKEN: ${{ secrets.GIT_ACCESS_TOKEN }} + run: | + echo "Running test_reduce.py" + docker run --rm \ + -v /tmp/torchsim-ci/${GITHUB_SHA}:/dump \ + -e TORCHSIM_DUMP_PATH=/dump -e TORCHSIM_VECTOR_LANE=8 -e TORCHSIM_SPAD_SIZE=32 \ + ghcr.io/psal-postech/torchsim-ci:${GITHUB_SHA} python3 PyTorchSim/tests/test_reduce.py + + test_softmax: + name: Run test_softmax.py + runs-on: self-hosted + needs: build + steps: + - name: Log in to GitHub Container Registry + uses: docker/login-action@v3 + with: + registry: ghcr.io + username: ${{ github.actor }} + password: ${{ secrets.GIT_ACCESS_TOKEN }} + - name: Run test_softmax.py + env: + GIT_ACCESS_TOKEN: ${{ secrets.GIT_ACCESS_TOKEN }} + run: | + echo "Running test_softmax.py" + docker run --rm \ + -v /tmp/torchsim-ci/${GITHUB_SHA}:/dump \ + -e TORCHSIM_DUMP_PATH=/dump -e TORCHSIM_VECTOR_LANE=8 -e TORCHSIM_SPAD_SIZE=32 \ + ghcr.io/psal-postech/torchsim-ci:${GITHUB_SHA} python3 PyTorchSim/tests/test_softmax.py + + test_transpose2D: + name: Run test_transpose2D.py + runs-on: self-hosted + needs: build + steps: + - name: Log in to GitHub Container Registry + uses: docker/login-action@v3 + with: + registry: ghcr.io + username: ${{ github.actor }} + password: ${{ secrets.GIT_ACCESS_TOKEN }} + - name: Run test_transpose2D.py + env: + GIT_ACCESS_TOKEN: ${{ secrets.GIT_ACCESS_TOKEN }} + run: | + echo "Running test_transpose2D.py" + docker run --rm \ + -v /tmp/torchsim-ci/${GITHUB_SHA}:/dump \ + -e TORCHSIM_DUMP_PATH=/dump -e TORCHSIM_VECTOR_LANE=8 -e TORCHSIM_SPAD_SIZE=32 \ + ghcr.io/psal-postech/torchsim-ci:${GITHUB_SHA} python3 PyTorchSim/tests/test_transpose2D.py + + test_view3D_2D: + name: Run test_view3D_2D.py + runs-on: self-hosted + needs: build + steps: + - name: Log in to GitHub Container Registry + uses: docker/login-action@v3 + with: + registry: ghcr.io + username: ${{ github.actor }} + password: ${{ secrets.GIT_ACCESS_TOKEN }} + - name: Run test_view3D_2D.py + env: + GIT_ACCESS_TOKEN: ${{ secrets.GIT_ACCESS_TOKEN }} + run: | + echo "Running test_view3D_2D.py" + docker run --rm \ + -v /tmp/torchsim-ci/${GITHUB_SHA}:/dump \ + -e TORCHSIM_DUMP_PATH=/dump -e TORCHSIM_VECTOR_LANE=8 -e TORCHSIM_SPAD_SIZE=32 \ + ghcr.io/psal-postech/torchsim-ci:${GITHUB_SHA} python3 PyTorchSim/tests/test_view3D_2D.py + + test_layernorm: + name: Run test_layernorm.py + runs-on: self-hosted + needs: build + steps: + - name: Log in to GitHub Container Registry + uses: docker/login-action@v3 + with: + registry: ghcr.io + username: ${{ github.actor }} + password: ${{ secrets.GIT_ACCESS_TOKEN }} + - name: Run test_layernorm.py + env: + GIT_ACCESS_TOKEN: ${{ secrets.GIT_ACCESS_TOKEN }} + run: | + echo "Running test_layernorm.py" + docker run --rm \ + -v /tmp/torchsim-ci/${GITHUB_SHA}:/dump \ + -e TORCHSIM_DUMP_PATH=/dump -e TORCHSIM_VECTOR_LANE=8 -e TORCHSIM_SPAD_SIZE=32 \ + ghcr.io/psal-postech/torchsim-ci:${GITHUB_SHA} python3 PyTorchSim/tests/test_layernorm.py + + test_mlp: + name: Run test_mlp.py + runs-on: self-hosted + needs: build + steps: + - name: Log in to GitHub Container Registry + uses: docker/login-action@v3 + with: + registry: ghcr.io + username: ${{ github.actor }} + password: ${{ secrets.GIT_ACCESS_TOKEN }} + - name: Run test_mlp.py + env: + GIT_ACCESS_TOKEN: ${{ secrets.GIT_ACCESS_TOKEN }} + run: | + echo "Running test_mlp.py" + docker run --rm \ + -v /tmp/torchsim-ci/${GITHUB_SHA}:/dump \ + -e TORCHSIM_DUMP_PATH=/dump -e TORCHSIM_VECTOR_LANE=8 -e TORCHSIM_SPAD_SIZE=32 \ + ghcr.io/psal-postech/torchsim-ci:${GITHUB_SHA} python3 PyTorchSim/tests/test_mlp.py + + test_resnet: + name: Run test_resnet.py + runs-on: self-hosted + needs: build + steps: + - name: Log in to GitHub Container Registry + uses: docker/login-action@v3 + with: + registry: ghcr.io + username: ${{ github.actor }} + password: ${{ secrets.GIT_ACCESS_TOKEN }} + - name: Run test_resnet.py + env: + GIT_ACCESS_TOKEN: ${{ secrets.GIT_ACCESS_TOKEN }} + run: | + echo "Running test_resnet.py" + docker run --rm \ + -v /tmp/torchsim-ci/${GITHUB_SHA}:/dump \ + -e TORCHSIM_DUMP_PATH=/dump -e TORCHSIM_VECTOR_LANE=8 -e TORCHSIM_SPAD_SIZE=32 \ + ghcr.io/psal-postech/torchsim-ci:${GITHUB_SHA} python3 PyTorchSim/tests/test_resnet.py + + test_transformer: + name: Run test_transformer.py + runs-on: self-hosted + needs: build + steps: + - name: Log in to GitHub Container Registry + uses: docker/login-action@v3 + with: + registry: ghcr.io + username: ${{ github.actor }} + password: ${{ secrets.GIT_ACCESS_TOKEN }} + - name: Run test_transformer.py + env: + GIT_ACCESS_TOKEN: ${{ secrets.GIT_ACCESS_TOKEN }} + run: | + echo "Running test_transformer.py" + docker run --rm \ + -v /tmp/torchsim-ci/${GITHUB_SHA}:/dump \ + -e TORCHSIM_DUMP_PATH=/dump -e TORCHSIM_VECTOR_LANE=8 -e TORCHSIM_SPAD_SIZE=32 \ + ghcr.io/psal-postech/torchsim-ci:${GITHUB_SHA} python3 PyTorchSim/tests/test_transformer.py + + test_transpose3D: + name: Run test_transpose3D.py + runs-on: self-hosted + needs: build + steps: + - name: Log in to GitHub Container Registry + uses: docker/login-action@v3 + with: + registry: ghcr.io + username: ${{ github.actor }} + password: ${{ secrets.GIT_ACCESS_TOKEN }} + - name: Run test_transpose3D.py + env: + GIT_ACCESS_TOKEN: ${{ secrets.GIT_ACCESS_TOKEN }} + run: | + echo "Running test_transpose3D.py" + docker run --rm \ + -v /tmp/torchsim-ci/${GITHUB_SHA}:/dump \ + -e TORCHSIM_DUMP_PATH=/dump -e TORCHSIM_VECTOR_LANE=8 -e TORCHSIM_SPAD_SIZE=32 \ + ghcr.io/psal-postech/torchsim-ci:${GITHUB_SHA} python3 PyTorchSim/tests/test_transpose3D.py + + test_sparsity: + name: Run test_sparsity.py + runs-on: self-hosted + needs: build + steps: + - name: Log in to GitHub Container Registry + uses: docker/login-action@v3 + with: + registry: ghcr.io + username: ${{ github.actor }} + password: ${{ secrets.GIT_ACCESS_TOKEN }} + - name: Run test_sparsity.py + env: + GIT_ACCESS_TOKEN: ${{ secrets.GIT_ACCESS_TOKEN }} + run: | + echo "Running test_sparsity.py" + docker run --rm \ + -v /tmp/torchsim-ci/${GITHUB_SHA}:/dump \ + -e TORCHSIM_DUMP_PATH=/dump -e TORCHSIM_VECTOR_LANE=8 -e TORCHSIM_SPAD_SIZE=32 \ + ghcr.io/psal-postech/torchsim-ci:${GITHUB_SHA} python3 PyTorchSim/tests/test_sparsity.py + + test_pool: + name: Run test_pool.py + runs-on: self-hosted + needs: build + steps: + - name: Log in to GitHub Container Registry + uses: docker/login-action@v3 + with: + registry: ghcr.io + username: ${{ github.actor }} + password: ${{ secrets.GIT_ACCESS_TOKEN }} + - name: Run test_pool.py + env: + GIT_ACCESS_TOKEN: ${{ secrets.GIT_ACCESS_TOKEN }} + run: | + echo "Running test_pool.py" + docker run --rm \ + -v /tmp/torchsim-ci/${GITHUB_SHA}:/dump \ + -e TORCHSIM_DUMP_PATH=/dump -e TORCHSIM_VECTOR_LANE=8 -e TORCHSIM_SPAD_SIZE=32 \ + ghcr.io/psal-postech/torchsim-ci:${GITHUB_SHA} python3 PyTorchSim/tests/test_pool.py + + test_perceptron: + name: Run test_perceptron.py + runs-on: self-hosted + needs: build + steps: + - name: Log in to GitHub Container Registry + uses: docker/login-action@v3 + with: + registry: ghcr.io + username: ${{ github.actor }} + password: ${{ secrets.GIT_ACCESS_TOKEN }} + - name: Run test_single_perceptron.py + env: + GIT_ACCESS_TOKEN: ${{ secrets.GIT_ACCESS_TOKEN }} + run: | + echo "Running test_single_perceptron.py" + docker run --rm \ + -v /tmp/torchsim-ci/${GITHUB_SHA}:/dump \ + -e TORCHSIM_DUMP_PATH=/dump -e TORCHSIM_VECTOR_LANE=8 -e TORCHSIM_SPAD_SIZE=32 \ + ghcr.io/psal-postech/torchsim-ci:${GITHUB_SHA} python3 PyTorchSim/tests/test_single_perceptron.py + + test_fusion: + name: Run test_fusion + runs-on: self-hosted + needs: build + steps: + - name: Log in to GitHub Container Registry + uses: docker/login-action@v3 + with: + registry: ghcr.io + username: ${{ github.actor }} + password: ${{ secrets.GIT_ACCESS_TOKEN }} + - name: Run test_addmm_residual.py + env: + GIT_ACCESS_TOKEN: ${{ secrets.GIT_ACCESS_TOKEN }} + run: | + echo "Running test_addmm_residual.py" + docker run --rm \ + -v /tmp/torchsim-ci/${GITHUB_SHA}:/dump \ + -e TORCHSIM_DUMP_PATH=/dump -e TORCHSIM_VECTOR_LANE=8 -e TORCHSIM_SPAD_SIZE=32 \ + ghcr.io/psal-postech/torchsim-ci:${GITHUB_SHA} python3 PyTorchSim/tests/Fusion/test_addmm_residual.py + + - name: Run test_matmul_activation.py + env: + GIT_ACCESS_TOKEN: ${{ secrets.GIT_ACCESS_TOKEN }} + run: | + echo "Running test_matmul_activation.py" + docker run --rm \ + -v /tmp/torchsim-ci/${GITHUB_SHA}:/dump \ + -e TORCHSIM_DUMP_PATH=/dump -e TORCHSIM_VECTOR_LANE=8 -e TORCHSIM_SPAD_SIZE=32 \ + ghcr.io/psal-postech/torchsim-ci:${GITHUB_SHA} python3 PyTorchSim/tests/Fusion/test_matmul_activation.py + + - name: Run test_matmul_scalar.py + env: + GIT_ACCESS_TOKEN: ${{ secrets.GIT_ACCESS_TOKEN }} + run: | + echo "Running test_matmul_scalar.py" + docker run --rm \ + -v /tmp/torchsim-ci/${GITHUB_SHA}:/dump \ + -e TORCHSIM_DUMP_PATH=/dump -e TORCHSIM_VECTOR_LANE=8 -e TORCHSIM_SPAD_SIZE=32 \ + ghcr.io/psal-postech/torchsim-ci:${GITHUB_SHA} python3 PyTorchSim/tests/Fusion/test_matmul_scalar.py + + - name: Run test_conv_fusion.py + env: + GIT_ACCESS_TOKEN: ${{ secrets.GIT_ACCESS_TOKEN }} + run: | + echo "Running test_conv_fusion.py" + docker run --rm \ + -v /tmp/torchsim-ci/${GITHUB_SHA}:/dump \ + -e TORCHSIM_DUMP_PATH=/dump -e TORCHSIM_VECTOR_LANE=8 -e TORCHSIM_SPAD_SIZE=32 \ + ghcr.io/psal-postech/torchsim-ci:${GITHUB_SHA} python3 PyTorchSim/tests/Fusion/test_conv_fusion.py + + - name: Run test_matmul_reduction.py + env: + GIT_ACCESS_TOKEN: ${{ secrets.GIT_ACCESS_TOKEN }} + run: | + echo "Running test_matmul_reduction.py" + docker run --rm \ + -v /tmp/torchsim-ci/${GITHUB_SHA}:/dump \ + -e TORCHSIM_DUMP_PATH=/dump \ + ghcr.io/psal-postech/torchsim-ci:${GITHUB_SHA} python3 PyTorchSim/tests/Fusion/test_matmul_reduction.py + + - name: Run test_bmm_reduction.py + env: + GIT_ACCESS_TOKEN: ${{ secrets.GIT_ACCESS_TOKEN }} + run: | + echo "Running test_bmm_reduction.py" + docker run --rm \ + -v /tmp/torchsim-ci/${GITHUB_SHA}:/dump \ + -e TORCHSIM_DUMP_PATH=/dump \ + ghcr.io/psal-postech/torchsim-ci:${GITHUB_SHA} python3 PyTorchSim/tests/Fusion/test_bmm_reduction.py + + - name: Run test_prologue_fusion.py + env: + GIT_ACCESS_TOKEN: ${{ secrets.GIT_ACCESS_TOKEN }} + run: | + echo "Running test_prologue_fusion.py" + docker run --rm \ + -v /tmp/torchsim-ci/${GITHUB_SHA}:/dump \ + -e TORCHSIM_DUMP_PATH=/dump \ + ghcr.io/psal-postech/torchsim-ci:${GITHUB_SHA} python3 PyTorchSim/tests/Fusion/test_prologue_fusion.py + + - name: Run test_transformer_fusion.py + env: + GIT_ACCESS_TOKEN: ${{ secrets.GIT_ACCESS_TOKEN }} + run: | + echo "Running test_transformer_fusion.py" + docker run --rm \ + -v /tmp/torchsim-ci/${GITHUB_SHA}:/dump \ + -e TORCHSIM_DUMP_PATH=/dump \ + ghcr.io/psal-postech/torchsim-ci:${GITHUB_SHA} python3 PyTorchSim/tests/Fusion/test_transformer_fusion.py + + test_moe: + name: Run test_moe + runs-on: self-hosted + needs: build + steps: + - name: Log in to GitHub Container Registry + uses: docker/login-action@v3 + with: + registry: ghcr.io + username: ${{ github.actor }} + password: ${{ secrets.GIT_ACCESS_TOKEN }} + - name: Run test_moe.py + env: + GIT_ACCESS_TOKEN: ${{ secrets.GIT_ACCESS_TOKEN }} + run: | + echo "Running test_moe.py" + docker run --rm \ + -v /tmp/torchsim-ci/${GITHUB_SHA}:/dump \ + -e TORCHSIM_DUMP_PATH=/dump -e TORCHSIM_VECTOR_LANE=8 -e TORCHSIM_SPAD_SIZE=32 \ + ghcr.io/psal-postech/torchsim-ci:${GITHUB_SHA} python3 PyTorchSim/tests/MoE/test_moe.py + + test_mistral: + name: Run test_mistral + runs-on: self-hosted + needs: build + steps: + - name: Log in to GitHub Container Registry + uses: docker/login-action@v3 + with: + registry: ghcr.io + username: ${{ github.actor }} + password: ${{ secrets.GIT_ACCESS_TOKEN }} + - name: Run test_mistral.py + env: + GIT_ACCESS_TOKEN: ${{ secrets.GIT_ACCESS_TOKEN }} + run: | + echo "Running test_mistral.py" + docker run --rm \ + -v /tmp/torchsim-ci/${GITHUB_SHA}:/dump \ + -e TORCHSIM_DUMP_PATH=/dump -e TORCHSIM_VECTOR_LANE=8 -e TORCHSIM_SPAD_SIZE=32 \ + ghcr.io/psal-postech/torchsim-ci:${GITHUB_SHA} python3 PyTorchSim/tests/Mixtral_8x7B/test_attention.py + + test_indirect: + name: Run test_indirect + runs-on: self-hosted + needs: build + env: + GIT_ACCESS_TOKEN: ${{ secrets.GIT_ACCESS_TOKEN }} + steps: + - name: Run test_indirect.py + run: | + echo "Running test_indirect.py" + echo $GIT_ACCESS_TOKEN | docker login ghcr.io -u USERNAME --password-stdin + docker run --rm \ + -v /tmp/torchsim-ci/${GITHUB_SHA}:/dump \ + -e TORCHSIM_DUMP_PATH=/dump -e TORCHSIM_VECTOR_LANE=8 -e TORCHSIM_SPAD_SIZE=32 \ + ghcr.io/psal-postech/torchsim-ci:${GITHUB_SHA} python3 PyTorchSim/tests/test_indirect_access.py + + test_scheduler: + name: Run test_scheduler + runs-on: self-hosted + needs: build + env: + GIT_ACCESS_TOKEN: ${{ secrets.GIT_ACCESS_TOKEN }} + steps: + - name: Run test_scheduler.py + run: | + echo "Running test_scheduler.py" + echo $GIT_ACCESS_TOKEN | docker login ghcr.io -u USERNAME --password-stdin + docker run --rm \ + -v /tmp/torchsim-ci/${GITHUB_SHA}:/dump \ + -e TORCHSIM_DUMP_PATH=/dump -e TORCHSIM_VECTOR_LANE=8 -e TORCHSIM_SPAD_SIZE=32 \ + ghcr.io/psal-postech/torchsim-ci:${GITHUB_SHA} python3 PyTorchSim/tests/test_scheduler.py + + test_cleanup: + name: Clean test cases + runs-on: self-hosted + needs: [test_add, test_batchnorm, test_bmm, test_cnn, test_conv2d, + test_matmul, test_reduce, test_softmax, + test_transpose2D, test_view3D_2D, test_layernorm, + test_mlp, test_resnet, test_transformer, test_transpose3D, + test_sparsity, test_activation, test_pool, test_perceptron, + test_fusion, test_mistral, test_moe, test_indirect, test_scheduler] + steps: + - name: Checkout code + uses: actions/checkout@v3 + - name: Clean test case + run: | + docker run --rm \ + -v /tmp/torchsim-ci/${GITHUB_SHA}:/dump \ + ghcr.io/psal-postech/torchsim-ci:${GITHUB_SHA} chown -R $(id -u):$(id -g) /dump diff --git a/.github/workflows/tag_release.yml b/.github/workflows/tag_release.yml new file mode 100644 index 00000000..258c0e40 --- /dev/null +++ b/.github/workflows/tag_release.yml @@ -0,0 +1,70 @@ +name: Build & Push Docker Image on Tag + +on: + push: + tags: + - 'v*' + +jobs: + build: + runs-on: self-hosted + + permissions: + contents: read + packages: write + id-token: write + + steps: + - name: Checkout code + uses: actions/checkout@v4 + with: + repository: PSAL-POSTECH/PyTorchSim + ref: ${{ github.ref_name }} + submodules: recursive + + - name: Log in to GitHub Container Registry + uses: docker/login-action@v3 + with: + registry: ghcr.io + username: ${{ github.actor }} + password: ${{ secrets.GITHUB_TOKEN }} + + - name: Set Tag Environment + run: | + echo "IMAGE_TAG=torchsim-ci:${GITHUB_REF#refs/tags/}" >> $GITHUB_ENV + echo "GITHUB_SHA=$GITHUB_SHA" >> $GITHUB_ENV + echo "GITHUB_SHA=$GITHUB_SHA" + + - name: Pull Cached Image & Set environment + env: + GIT_ACCESS_TOKEN: ${{ secrets.GIT_ACCESS_TOKEN }} + run: | + docker pull ghcr.io/psal-postech/torchsim_base:latest || echo "No cache available" + gem5_response_file=/tmp/releases-gem5-latest.json + response=$(curl -sH "Authorization: Bearer ${GIT_ACCESS_TOKEN}" https://api.github.com/repos/PSAL-POSTECH/GEM5/releases/latest > ${gem5_response_file} ) + GEM5_ASSET_ID=$(cat ${gem5_response_file} | jq ".assets[0]."id"") + echo "GEM5_ASSET_ID=$GEM5_ASSET_ID" + echo "GEM5_ASSET_ID=$GEM5_ASSET_ID" >> $GITHUB_ENV + + llvm_response_file=/tmp/releases-gem5-latest.json + response=$(curl -sH "Authorization: Bearer ${GIT_ACCESS_TOKEN}" https://api.github.com/repos/PSAL-POSTECH/llvm-project/releases/latest > ${llvm_response_file} ) + LLVM_ASSET_ID=$(cat ${llvm_response_file} | jq ".assets[0]."id"") + echo "LLVM_ASSET_ID=$LLVM_ASSET_ID" + echo "LLVM_ASSET_ID=$LLVM_ASSET_ID" >> $GITHUB_ENV + + mkdir -p /tmp/torchsim-ci/${GITHUB_SHA} + echo "DUMP_PATH=/tmp/torchsim-ci/${GITHUB_SHA}" + + - name: Build and Push Docker Image + uses: docker/build-push-action@v6 + with: + context: . + file: ./Dockerfile + push: true + build-args: | + GEM5_ASSET_ID=${{ env.GEM5_ASSET_ID }} + LLVM_ASSET_ID=${{ env.LLVM_ASSET_ID }} + TORCHSIM_SHA=${{ env.GITHUB_SHA }} + secrets: | + GIT_ACCESS_TOKEN=${{ secrets.GIT_ACCESS_TOKEN }} + tags: ghcr.io/psal-postech/${{ env.IMAGE_TAG}} \ No newline at end of file diff --git a/.gitignore b/.gitignore new file mode 100644 index 00000000..88eb2fb8 --- /dev/null +++ b/.gitignore @@ -0,0 +1,3 @@ +__pycache__/ +PyTorchSimBackend/build/ +.vscode diff --git a/.gitmodules b/.gitmodules index 831a8746..f65e5f2b 100644 --- a/.gitmodules +++ b/.gitmodules @@ -13,3 +13,6 @@ [submodule "PyTorchSimBackend/extern/ramulator2"] path = PyTorchSimBackend/extern/ramulator2 url = https://github.com/PSAL-POSTECH/ramulator2 +[submodule "PyTorchSimBackend/extern/stonneCore"] + path = PyTorchSimBackend/extern/stonneCore + url = https://github.com/PSAL-POSTECH/stonne_core.git diff --git a/AsmParser/onnx_utility.py b/AsmParser/onnx_utility.py index d6eab9b2..4f76ef35 100644 --- a/AsmParser/onnx_utility.py +++ b/AsmParser/onnx_utility.py @@ -66,12 +66,15 @@ def __init__(self, tile_info, inst_list=list(), node_id=0): super().__init__(node_id) self.inst = inst_list self.torchsim_base_addr = tile_info["base_addr"] - self.torchsim_stride_list = tile_info["stride_list"] self.torchsim_tile_size = tile_info["tile_size"] + self.torchsim_tile_stride = tile_info["tile_stride"] self.torchsim_element_size = tile_info["element_size"] self.torchsim_tag_idx_list = tile_info["tag_idx_list"] + self.torchsim_tag_stride_list = tile_info["tag_stride_list"] self.torchsim_loop_idx_list = tile_info["loop_idx_list"] + self.torchsim_loop_stride_list = tile_info["loop_stride_list"] self.torchsim_is_async = tile_info["is_async"] + self.torchsim_indirect_mode = tile_info["indirect_mode"] class load_node(memory_node): pass @@ -83,6 +86,8 @@ class memory_wait_node(node): def __init__(self, tile_info, inst_list=list(), node_id=0): super().__init__(node_id) self.torchsim_tag_idx_list = tile_info["tag_idx_list"] + self.torchsim_tag_stride_list = tile_info["tag_stride_list"] + self.torchsim_tag_divider_list = tile_info["tag_divider_list"] self.torchsim_base_addr = tile_info["base_addr"] class compute_node(node): @@ -93,11 +98,79 @@ def __init__(self, inst_list=list(), cycle=0, overlapping_cycle=0, compute_type= self.torchsim_overlapping_cycle = overlapping_cycle self.torchsim_compute_type = compute_type +class stonne_node(node): + def __init__(self, tile_info, node_id=0): + super().__init__(node_id) + self.torchsim_stonne_operation = tile_info.get("stonne_operation", "CONV") + self.torchsim_stonne_layer_name = tile_info.get("stonne_layer_name", "") + self.torchsim_stonne_mem_init = tile_info.get("stonne_mem_init", "") + + # Convolution Parameters + self.torchsim_stonne_R = tile_info.get("stonne_R", 1) + self.torchsim_stonne_S = tile_info.get("stonne_S", 1) + self.torchsim_stonne_C = tile_info.get("stonne_C", 1) + self.torchsim_stonne_K = tile_info.get("stonne_K", 1) + self.torchsim_stonne_G = tile_info.get("stonne_G", 1) + self.torchsim_stonne_N = tile_info.get("stonne_N", 1) + self.torchsim_stonne_X = tile_info.get("stonne_X", 1) + self.torchsim_stonne_Y = tile_info.get("stonne_Y", 1) + self.torchsim_stonne_X_ = tile_info.get("stonne_X_", 1) + self.torchsim_stonne_Y_ = tile_info.get("stonne_Y_", 1) + self.torchsim_stonne_strides = tile_info.get("stonne_strides", 1) + + # Convolution Tile Parameters + self.torchsim_stonne_T_R = tile_info.get("stonne_T_R", 1) + self.torchsim_stonne_T_S = tile_info.get("stonne_T_S", 1) + self.torchsim_stonne_T_C = tile_info.get("stonne_T_C", 1) + self.torchsim_stonne_T_K = tile_info.get("stonne_T_K", 1) + self.torchsim_stonne_T_G = tile_info.get("stonne_T_G", 1) + self.torchsim_stonne_T_N = tile_info.get("stonne_T_N", 1) + self.torchsim_stonne_T_X_ = tile_info.get("stonne_T_X_", 1) + self.torchsim_stonne_T_Y_ = tile_info.get("stonne_T_Y_", 1) + + # GEMM Parameters + self.torchsim_stonne_GEMM_K = tile_info.get("stonne_GEMM_K", 1) + self.torchsim_stonne_GEMM_N = tile_info.get("stonne_GEMM_N", 1) + self.torchsim_stonne_GEMM_M = tile_info.get("stonne_GEMM_M", 1) + self.torchsim_stonne_GEMM_T_K = tile_info.get("stonne_GEMM_T_K", 1) + self.torchsim_stonne_GEMM_T_N = tile_info.get("stonne_GEMM_T_N", 1) + self.torchsim_stonne_GEMM_T_M = tile_info.get("stonne_GEMM_T_M", 1) + + # Memory Addresses + self.torchsim_stonne_matrix_a_dram_address = tile_info.get("stonne_matrix_a_dram_address", 0) + self.torchsim_stonne_matrix_b_dram_address = tile_info.get("stonne_matrix_b_dram_address", 0) + self.torchsim_stonne_matrix_c_dram_address = tile_info.get("stonne_matrix_c_dram_address", 0) + self.torchsim_stonne_mem_matrix_c_file_name = tile_info.get("stonne_mem_matrix_c_file_name", "") + + # Bitmap and CSR Data + self.torchsim_stonne_bitmap_matrix_a_init = tile_info.get("stonne_bitmap_matrix_a_init", "") + self.torchsim_stonne_bitmap_matrix_b_init = tile_info.get("stonne_bitmap_matrix_b_init", "") + self.torchsim_stonne_rowpointer_matrix_a_init = tile_info.get("stonne_rowpointer_matrix_a_init", "") + self.torchsim_stonne_colpointer_matrix_a_init = tile_info.get("stonne_colpointer_matrix_a_init", "") + self.torchsim_stonne_rowpointer_matrix_b_init = tile_info.get("stonne_rowpointer_matrix_b_init", "") + self.torchsim_stonne_colpointer_matrix_b_init = tile_info.get("stonne_colpointer_matrix_b_init", "") + self.torchsim_trace_path = tile_info.get("stonne_trace_path", "") + +class stonne_trace_compute_node(node): + def __init__(self, cycle=0, node_id=0): + super().__init__(node_id) + self.torchsim_trace_compute_cycle = cycle + +class stonne_trace_store_node(node): + def __init__(self, addr_list=list(), node_id=0): + super().__init__(node_id) + self.torchsim_trace_address = addr_list + +class stonne_trace_load_node(node): + def __init__(self, addr_list=list(), node_id=0): + super().__init__(node_id) + self.torchsim_trace_address = addr_list + def connect_nodes(parent, child): child.add_parent(parent) parent.add_child(child) -def dump_onnx_graph(name, node_list, sa_size, origin_info="dummy_tile_graph"): +def dump_onnx_graph(name, node_list, sa_size, origin_info="dummy_tile_graph", stonneGraph=False): graph_def = onnx.helper.make_graph( inputs=[], outputs=[], @@ -109,6 +182,10 @@ def dump_onnx_graph(name, node_list, sa_size, origin_info="dummy_tile_graph"): meta = model_def.metadata_props.add() meta.key = "systolic_size" meta.value = str(sa_size) + + meta = model_def.metadata_props.add() + meta.key = "stonneGraph" + meta.value = str(int(stonneGraph)) onnx.save(model_def, name) if __name__ == "__main__": diff --git a/AsmParser/tog_generator.py b/AsmParser/tog_generator.py index 1b5971e2..5f586d99 100644 --- a/AsmParser/tog_generator.py +++ b/AsmParser/tog_generator.py @@ -6,8 +6,10 @@ if __name__ == "__main__": from onnx_utility import node, loop_index_node, loop_end_node, load_node, store_node, memory_wait_node, compute_node, connect_nodes, dump_onnx_graph + from onnx_utility import stonne_node, stonne_trace_compute_node, stonne_trace_load_node, stonne_trace_store_node else: from AsmParser.onnx_utility import node, loop_index_node, loop_end_node, load_node, store_node, memory_wait_node, compute_node, connect_nodes, dump_onnx_graph + from AsmParser.onnx_utility import stonne_node, stonne_trace_compute_node, stonne_trace_load_node, stonne_trace_store_node def import_module_from_path(module_name, path): @@ -31,7 +33,11 @@ class tog_generator: LoopNodeKind = 2 DMANodeKind = 3 DMAWaitNodeKind = 4 - def __init__(self, origins=None) -> None: + StonneNodeKind = 5 + StonneTraceCompute= 6 + StonneTraceLoad = 7 + StonneTraceStore = 8 + def __init__(self, origins="Unknown") -> None: self.module_name = "tile_operation_graph" self.module = None self.raw_graph = {} @@ -85,12 +91,15 @@ def _create_node(self, dump_data): elif node_type == self.DMANodeKind: tile_info = {} tile_info["base_addr"] = dump_data["base_address"] - tile_info["stride_list"] = dump_data["stride_list"] tile_info["tile_size"] = dump_data["tile_size"] + tile_info["tile_stride"] = dump_data["tile_stride"] tile_info["element_size"] = dump_data["element_size"] tile_info["tag_idx_list"] = dump_data["tag_idx_list"] + tile_info["tag_stride_list"] = dump_data["tag_stride_list"] tile_info["loop_idx_list"] = dump_data["loop_idx_list"] + tile_info["loop_stride_list"] = dump_data["loop_stride_list"] tile_info["is_async"] = dump_data["is_async"] + tile_info["indirect_mode"] = dump_data["indirect_mode"] is_write = dump_data["is_write"] if is_write: new_node = store_node(tile_info, node_id=node_id) @@ -99,8 +108,18 @@ def _create_node(self, dump_data): elif node_type == self.DMAWaitNodeKind: tile_info = {} tile_info["tag_idx_list"] = dump_data["tag_idx_list"] + tile_info["tag_stride_list"] = dump_data["tag_stride_list"] + tile_info["tag_divider_list"] = dump_data["tag_divider_list"] tile_info["base_addr"] = dump_data["base_address"] new_node = memory_wait_node(tile_info, node_id=node_id) + elif node_type == self.StonneNodeKind: + new_node = stonne_node(dump_data, node_id=node_id) + elif node_type == self.StonneTraceCompute: + new_node = stonne_trace_compute_node(dump_data['trace_compute_cycle'], node_id=node_id) + elif node_type == self.StonneTraceLoad: + new_node = stonne_trace_load_node(dump_data['trace_address'], node_id=node_id) + elif node_type == self.StonneTraceStore: + new_node = stonne_trace_store_node(dump_data['trace_address'], node_id=node_id) else: print("Unexpected node_type :", node_type) exit(1) @@ -136,8 +155,6 @@ def create_node(self, dump_data, prev_node): connect_nodes(prev_node[-1].get_parent()[-1], new_node) elif isinstance(prev_node[-1], memory_wait_node) and isinstance(new_node, memory_wait_node): connect_nodes(prev_node[-1].get_parent()[-1], new_node) - elif isinstance(prev_node[-1], store_node) and isinstance(new_node, store_node): - connect_nodes(prev_node[-1].get_parent()[-1], new_node) elif isinstance(prev_node[-1], load_node) and isinstance(new_node, compute_node) or \ isinstance(prev_node[-1], memory_wait_node) and isinstance(new_node, compute_node): for pn in prev_node: @@ -192,7 +209,7 @@ def parse_graph(self): connect_nodes(prev_node, end_node) prev_node = end_node - def generate_tile_graph(self, name="tile_graph", cycle_list=list, offset=int, vector_lane=int): + def generate_tile_graph(self, name="tile_graph", cycle_list=list, x_offset=int, w_offset=int, vector_lane=int, stonneGraph=False): node_list = list(self.node_dict.values())[1:] if len(node_list): node_list[0].set_parent([]) @@ -204,14 +221,16 @@ def generate_tile_graph(self, name="tile_graph", cycle_list=list, offset=int, ve print("[TOGGen] Error compute cycle timing is missing...!") iter_node.torchsim_cycle = 10 # FIXME. - if iter_node.torchsim_compute_type == 1: - iter_node.torchsim_overlapping_cycle = iter_node.torchsim_cycle - offset + if iter_node.torchsim_compute_type > 0: + is_preload = iter_node.torchsim_compute_type == 2 + offset = w_offset if is_preload else x_offset + iter_node.torchsim_overlapping_cycle = max(iter_node.torchsim_cycle - offset, 0) origin_info = "_".join(map(str, self.origins)) onnx_node_list = [node.to_onnx() for node in node_list] # Exclude root node - dump_onnx_graph(name, onnx_node_list, vector_lane, origin_info) + dump_onnx_graph(name, onnx_node_list, vector_lane, origin_info, stonneGraph=stonneGraph) if __name__ == "__main__": t = tog_generator() - t.load_file("/workspace/llvm-project/build/tile_operation_graph.py") - t.parse_graph() \ No newline at end of file + t.load_file("/tmp/torchinductor/tmp/sz6qi7bqkxn/csz6qi7bqkxnam5sxok4l4sppddjkijq5rd55s4qvdutd5ni73fc_tog.py") + t.generate_tile_graph("./tile_graph.onnx", cycle_list=[1,1,1,1,1], x_offset=0, w_offset=0, vector_lane=128) \ No newline at end of file diff --git a/Dockerfile b/Dockerfile index 8e149883..44f6fd5e 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,3 +1,4 @@ +# syntax=docker/dockerfile:1.4 # Copyright (c) 2020 The Regents of the University of California # All Rights Reserved. # @@ -26,7 +27,6 @@ FROM ghcr.io/psal-postech/torchsim_base:latest # Pass Access Token securely -ARG GIT_ACCESS_TOKEN ARG GEM5_ASSET_ID ARG LLVM_ASSET_ID ARG TORCHSIM_SHA @@ -34,14 +34,18 @@ ENV PATH $PATH:/root/.local/bin ENV LD_LIBRARY_PATH /usr/lib/x86_64-linux-gnu:/opt/conda/lib:/usr/local/nvidia/lib:/usr/local/nvidia/lib64:$LD_LIBRARY_PATH # Download GEM5 for torchsim -RUN curl -L -H "Accept: application/octet-stream" -H "Authorization: Bearer ${GIT_ACCESS_TOKEN}" https://api.github.com/repos/PSAL-POSTECH/gem5/releases/assets/${GEM5_ASSET_ID} -o /tmp/gem5-release.tar.gz && \ +RUN --mount=type=secret,id=GIT_ACCESS_TOKEN \ + GIT_ACCESS_TOKEN=$(cat /run/secrets/GIT_ACCESS_TOKEN) && \ + curl -L -H "Accept: application/octet-stream" -H "Authorization: Bearer ${GIT_ACCESS_TOKEN}" https://api.github.com/repos/PSAL-POSTECH/gem5/releases/assets/${GEM5_ASSET_ID} -o /tmp/gem5-release.tar.gz && \ mkdir -p /gem5 && \ tar -xzf /tmp/gem5-release.tar.gz -C /gem5 && \ rm /tmp/gem5-release.tar.gz ENV GEM5_PATH /gem5/release/gem5.opt # Download LLVM RISC-V for torchsim -RUN curl -L -H "Accept: application/octet-stream" -H "Authorization: Bearer ${GIT_ACCESS_TOKEN}" https://api.github.com/repos/PSAL-POSTECH/llvm-project/releases/assets/${LLVM_ASSET_ID} -o /tmp/riscv-llvm-release.tar.gz && \ +RUN --mount=type=secret,id=GIT_ACCESS_TOKEN \ + GIT_ACCESS_TOKEN=$(cat /run/secrets/GIT_ACCESS_TOKEN) && \ + curl -L -H "Accept: application/octet-stream" -H "Authorization: Bearer ${GIT_ACCESS_TOKEN}" https://api.github.com/repos/PSAL-POSTECH/llvm-project/releases/assets/${LLVM_ASSET_ID} -o /tmp/riscv-llvm-release.tar.gz && \ tar -xzf /tmp/riscv-llvm-release.tar.gz -C / && \ rm /tmp/riscv-llvm-release.tar.gz @@ -52,18 +56,20 @@ ENV TORCHSIM_DIR /workspace/PyTorchSim ENV LLVM_DIR /riscv-llvm # Install Spike simulator -RUN git clone https://${GIT_ACCESS_TOKEN}@github.com/PSAL-POSTECH/riscv-isa-sim.git --branch TorchSim && cd riscv-isa-sim && mkdir build && cd build && \ - ../configure --prefix=$RISCV && make -j && make install +RUN --mount=type=secret,id=GIT_ACCESS_TOKEN \ + GIT_ACCESS_TOKEN=$(cat /run/secrets/GIT_ACCESS_TOKEN) && \ + git clone https://$GIT_ACCESS_TOKEN@github.com/PSAL-POSTECH/riscv-isa-sim.git --branch TorchSim && cd riscv-isa-sim && mkdir build && cd build && \ + ../configure --prefix=$RISCV && make -j && make install && cd ../../ && rm -rf riscv-isa-sim # Install Proxy kernel RUN git clone https://github.com/riscv-software-src/riscv-pk.git && \ cd riscv-pk && git checkout 4f3debe4d04f56d31089c1c716a27e2d5245e9a1 && mkdir build && cd build && \ ../configure --prefix=$RISCV --host=riscv64-unknown-elf && make -j && make install -# Prepare ONNXim project -RUN git clone https://${GIT_ACCESS_TOKEN}@github.com/PSAL-POSTECH/PyTorchSim.git && cd PyTorchSim && git checkout ${TORCHSIM_SHA} +# Prepare PyTorchSim project +COPY . /workspace/PyTorchSim + RUN cd PyTorchSim/PyTorchSimBackend && \ - git submodule update --recursive --init && \ mkdir -p build && \ cd build && \ conan install .. --build=missing && \ diff --git a/PyTorchSimBackend/CMakeLists.txt b/PyTorchSimBackend/CMakeLists.txt index 3cb296d5..0d36d463 100644 --- a/PyTorchSimBackend/CMakeLists.txt +++ b/PyTorchSimBackend/CMakeLists.txt @@ -27,12 +27,10 @@ message("BINARY DIR ${CMAKE_BINARY_DIR}") add_subdirectory("${PROJECT_SOURCE_DIR}/src") # Add libaray ramulator -add_subdirectory("${PROJECT_SOURCE_DIR}/extern/ramulator_custom") - -# Add libaray ramulator +include_directories("${PROJECT_SOURCE_DIR}/include") add_subdirectory("${PROJECT_SOURCE_DIR}/extern/ramulator2") include_directories("${PROJECT_SOURCE_DIR}/extern/ramulator2/src") -include_directories("${PROJECT_SOURCE_DIR}/extern/ramulator2/resources/ndp_wrappers") +include_directories("${PROJECT_SOURCE_DIR}/extern/ramulator2/resources/wrappers") # Add libaray booksim add_subdirectory("${PROJECT_SOURCE_DIR}/extern/booksim") @@ -42,16 +40,19 @@ add_subdirectory("${PROJECT_SOURCE_DIR}/extern/protobuf/cmake" EXCLUDE_FROM_ALL) set_target_properties(libprotoc PROPERTIES FOLDER "external/protobuf") set_target_properties(protoc PROPERTIES FOLDER "external/protobuf") +# Add libaray stonne core +add_subdirectory("${PROJECT_SOURCE_DIR}/extern/stonneCore") + # Add libaray onnx add_definitions("-DONNX_NAMESPACE=onnx") add_subdirectory("${PROJECT_SOURCE_DIR}/extern/onnx" EXCLUDE_FROM_ALL) set_target_properties(onnx PROPERTIES FOLDER "extern/onnx") set_target_properties(onnx_proto PROPERTIES FOLDER "extern/onnx") +target_include_directories(Simulator PUBLIC ${PROJECT_SOURCE_DIR}/extern/stonneCore/include) +target_include_directories(Simulator PUBLIC ${PROJECT_SOURCE_DIR}/include) +target_include_directories(Simulator PUBLIC ${PROJECT_SOURCE_DIR}/include/scheduler) +target_include_directories(Simulator PUBLIC ${PROJECT_SOURCE_DIR}/src) target_include_directories(Simulator PUBLIC ${ONNX_INCLUDE_DIRS}) -target_link_libraries(Simulator ramulator1 booksim2 ramulator) -target_link_libraries(Simulator ${PROTOBUF_LIB} onnx_proto ${CONAN_LIBS} stdc++fs) - -target_include_directories(Simulator_lib PUBLIC ${ONNX_INCLUDE_DIRS}) -target_link_libraries(Simulator_lib ramulator1 booksim2 ramulator) -target_link_libraries(Simulator_lib ${PROTOBUF_LIB} onnx_proto ${CONAN_LIBS} stdc++fs) \ No newline at end of file +target_link_libraries(Simulator booksim2 ramulator sstStonne) +target_link_libraries(Simulator ${PROTOBUF_LIB} onnx_proto sstStonne ${CONAN_LIBS} stdc++fs) diff --git a/PyTorchSimBackend/configs/booksim2_configs/fly_c2_m32.icnt b/PyTorchSimBackend/configs/booksim2_configs/fly_c2_m32.icnt new file mode 100644 index 00000000..f8874f20 --- /dev/null +++ b/PyTorchSimBackend/configs/booksim2_configs/fly_c2_m32.icnt @@ -0,0 +1,17 @@ +[config] +use_map = 0 +flit_size = 64 +topology = fly +k = 34 +n = 1 +routing_function = dest_tag +subnets = 1 + +vc_buf_size = 64 +input_buffer_size = 256 +ejection_buffer_size = 64 +boundary_buffer_size = 64 +wait_for_tail_credit = 0 +vc_allocator = islip +sw_allocator = islip +alloc_iters = 1 diff --git a/PyTorchSimBackend/configs/booksim2_configs/fly_c32_m4.icnt b/PyTorchSimBackend/configs/booksim2_configs/fly_c32_m4.icnt new file mode 100644 index 00000000..e5765207 --- /dev/null +++ b/PyTorchSimBackend/configs/booksim2_configs/fly_c32_m4.icnt @@ -0,0 +1,18 @@ +[config] +use_map = 0 +flit_size = 32 +topology = fly +k = 36 +n = 1 +routing_function = dest_tag +subnets = 1 + +vc_buf_size = 256 +input_buffer_size = 256 +ejection_buffer_size = 256 +boundary_buffer_size = 256 +wait_for_tail_credit = 0 +vc_allocator = islip +sw_allocator = islip +alloc_iters = 1 +deadlock_warn_timeout = 10000 \ No newline at end of file diff --git a/PyTorchSimBackend/configs/booksim2_configs/fly_c32_m8.icnt b/PyTorchSimBackend/configs/booksim2_configs/fly_c32_m8.icnt new file mode 100644 index 00000000..29e573cb --- /dev/null +++ b/PyTorchSimBackend/configs/booksim2_configs/fly_c32_m8.icnt @@ -0,0 +1,18 @@ +[config] +use_map = 0 +flit_size = 32 +topology = fly +k = 40 +n = 1 +routing_function = dest_tag +subnets = 1 + +vc_buf_size = 256 +input_buffer_size = 256 +ejection_buffer_size = 256 +boundary_buffer_size = 256 +wait_for_tail_credit = 0 +vc_allocator = islip +sw_allocator = islip +alloc_iters = 1 +deadlock_warn_timeout = 10000 \ No newline at end of file diff --git a/PyTorchSimBackend/configs/heterogeneous_c2_simple_noc.json b/PyTorchSimBackend/configs/heterogeneous_c2_simple_noc.json new file mode 100644 index 00000000..8f196e81 --- /dev/null +++ b/PyTorchSimBackend/configs/heterogeneous_c2_simple_noc.json @@ -0,0 +1,34 @@ +{ + "core_type" : ["stonne", "ws_mesh"], + "stonne_config_path" : "/workspace/PyTorchSim/PyTorchSimBackend/extern/stonneCore/tests/sparseflex_op_128mses_128_bw.cfg", + "num_cores" : 2, + "core_freq" : 940, + "sram_size" : 65536, + "core_print_interval" : 10000, + "num_stonne_per_core" : 8, + "num_stonne_port" : 64, + "num_systolic_array_per_core" : 2, + + "dram_type" : "ramulator2", + "dram_freq" : 940, + "dram_channels": 16, + "dram_req_size": 32, + "dram_latency" : 10, + "dram_size" : 32, + "dram_nbl" : 2, + "dram_print_interval": 10000, + "dram_config_path" : "../configs/ramulator2_configs/HBM2_TPUv3.yaml", + + "icnt_type" : "simple", + "icnt_latency" : 7, + "icnt_freq" : 15000, + "icnt_config_path" : "../configs/booksim2_configs/fly_c1_m8.icnt", + + "precision" : 4, + "scheduler" : "simple", + "num_partition" : 2, + "partition": { + "core_0":0, + "core_1":1 + } +} \ No newline at end of file diff --git a/PyTorchSimBackend/configs/ramulator2_configs/HBM2_TPUv3.yaml b/PyTorchSimBackend/configs/ramulator2_configs/HBM2_TPUv3.yaml new file mode 100644 index 00000000..e6543d14 --- /dev/null +++ b/PyTorchSimBackend/configs/ramulator2_configs/HBM2_TPUv3.yaml @@ -0,0 +1,25 @@ +Frontend: + impl: GEM5 + +MemorySystem: + impl: GenericDRAM + clock_ratio: 1 + + DRAM: + impl: HBM2 + org: + preset: HBM2_8Gb + channel: 1 + timing: + preset: HBM2_1.8Gbps + + Controller: + impl: Generic + Scheduler: + impl: FRFCFS + RefreshManager: + impl: AllBank + plugins: + + AddrMapper: + impl: RoBaRaCoCh \ No newline at end of file diff --git a/PyTorchSimBackend/configs/stonne_big_c1_simple_noc.json b/PyTorchSimBackend/configs/stonne_big_c1_simple_noc.json new file mode 100644 index 00000000..c7ef15f7 --- /dev/null +++ b/PyTorchSimBackend/configs/stonne_big_c1_simple_noc.json @@ -0,0 +1,32 @@ +{ + "core_type" : ["stonne"], + "stonne_config_path" : "/workspace/PyTorchSim/PyTorchSimBackend/extern/stonneCore/tests/sparseflex_op_128mses_128_bw.cfg", + "num_cores" : 1, + "core_freq" : 940, + "sram_size" : 65536, + "core_print_interval" : 10000, + "num_stonne_per_core" : 8, + "num_stonne_port" : 64, + + "dram_type" : "ramulator2", + "dram_freq" : 940, + "dram_channels": 8, + "dram_req_size": 32, + "dram_latency" : 10, + "dram_size" : 32, + "dram_nbl" : 2, + "dram_print_interval": 10000, + "dram_config_path" : "../configs/ramulator2_configs/HBM2_TPUv3.yaml", + + "icnt_type" : "simple", + "icnt_latency" : 7, + "icnt_freq" : 15000, + "icnt_config_path" : "../configs/booksim2_configs/fly_c1_m8.icnt", + + "precision" : 4, + "scheduler" : "simple", + "num_partition" : 1, + "partition": { + "core_0":0 + } +} \ No newline at end of file diff --git a/PyTorchSimBackend/configs/stonne_single_c1_simple_noc.json b/PyTorchSimBackend/configs/stonne_single_c1_simple_noc.json new file mode 100644 index 00000000..2293e197 --- /dev/null +++ b/PyTorchSimBackend/configs/stonne_single_c1_simple_noc.json @@ -0,0 +1,31 @@ +{ + "core_type" : ["stonne"], + "stonne_config_path" : "/workspace/PyTorchSim/PyTorchSimBackend/extern/stonneCore/tests/sparseflex_op_128mses_128_bw.cfg", + "num_cores" : 1, + "core_freq" : 700, + "sram_size" : 65536, + "core_print_interval" : 10000, + "num_stonne_per_core" : 1, + "num_stonne_port" : 8, + + "dram_type" : "ramulator2", + "dram_freq" : 700, + "dram_channels": 8, + "dram_req_size": 32, + "dram_latency" : 10, + "dram_nbl" : 2, + "dram_print_interval": 10000, + "dram_config_path" : "../configs/ramulator2_configs/HBM2.yaml", + + "icnt_type" : "simple", + "icnt_latency" : 7, + "icnt_freq" : 7000, + "icnt_config_path" : "../configs/booksim2_configs/fly_c1_m8.icnt", + + "precision" : 4, + "scheduler" : "simple", + "num_partition" : 1, + "partition": { + "core_0":0 + } +} \ No newline at end of file diff --git a/PyTorchSimBackend/configs/stonne_validation_c1_simple_noc.json b/PyTorchSimBackend/configs/stonne_validation_c1_simple_noc.json new file mode 100644 index 00000000..08548638 --- /dev/null +++ b/PyTorchSimBackend/configs/stonne_validation_c1_simple_noc.json @@ -0,0 +1,31 @@ +{ + "core_type" : ["stonne"], + "stonne_config_path" : "/workspace/PyTorchSim/PyTorchSimBackend/extern/stonneCore/tests/sparseflex_op_128mses_128_bw.cfg", + "num_cores" : 1, + "core_freq" : 1000, + "sram_size" : 65536, + "core_print_interval" : 10000, + "num_stonne_per_core" : 1, + "num_stonne_port" : 32, + + "dram_type" : "simple", + "dram_freq" : 1000, + "dram_channels": 1, + "dram_req_size": 32, + "dram_latency" : 100, + "dram_print_interval": 10000, + "l2d_type" : "datacache", + "l2d_config" : "S:128:128:64,32,L:T:m:W:L,A:192:4,32:0,32", + + "icnt_type" : "simple", + "icnt_latency" : 7, + "icnt_freq" : 7000, + "icnt_config_path" : "../configs/booksim2_configs/fly_c1_m8.icnt", + + "precision" : 4, + "scheduler" : "simple", + "num_partition" : 1, + "partition": { + "core_0":0 + } +} \ No newline at end of file diff --git a/PyTorchSimBackend/configs/systolic_ws_128x128_c16_simple_noc_tpuv4.json b/PyTorchSimBackend/configs/systolic_ws_128x128_c16_simple_noc_tpuv4.json deleted file mode 100644 index 17fe87bc..00000000 --- a/PyTorchSimBackend/configs/systolic_ws_128x128_c16_simple_noc_tpuv4.json +++ /dev/null @@ -1,22 +0,0 @@ -{ - "num_cores" : 16, - "core_freq" : 1000, - "sram_size" : 65536, - - "dram_type" : "ramulator", - "dram_freq" : 877, - "dram_channels": 32, - "dram_req_size": 32, - "dram_latency" : 10, - "dram_print_interval" : 10000, - "dram_config_path" : "../configs/ramulator_configs/HBM-config.cfg", - - "icnt_type" : "simple", - "icnt_latency" : 1, - "icnt_freq" : 2000, - "icnt_print_interval" : 10000, - "icnt_config_path" : "../configs/booksim2_configs/fly_c16_m32.icnt", - - "precision" : 2, - "scheduler" : "simple" -} \ No newline at end of file diff --git a/PyTorchSimBackend/configs/systolic_ws_128x128_c1_booksim_tpuv2.json b/PyTorchSimBackend/configs/systolic_ws_128x128_c1_booksim_tpuv2.json index e623730a..5d7b0d35 100644 --- a/PyTorchSimBackend/configs/systolic_ws_128x128_c1_booksim_tpuv2.json +++ b/PyTorchSimBackend/configs/systolic_ws_128x128_c1_booksim_tpuv2.json @@ -10,7 +10,7 @@ "dram_req_size": 32, "dram_latency" : 10, "dram_size" : 16, - "dram_nbl" : 1, + "dram_nbl" : 2, "dram_print_interval": 10000, "dram_config_path" : "../configs/ramulator2_configs/HBM2.yaml", diff --git a/PyTorchSimBackend/configs/systolic_ws_128x128_c1_simple_noc_tpuv2.json b/PyTorchSimBackend/configs/systolic_ws_128x128_c1_simple_noc_tpuv2.json index 12d5ee39..38acafc0 100644 --- a/PyTorchSimBackend/configs/systolic_ws_128x128_c1_simple_noc_tpuv2.json +++ b/PyTorchSimBackend/configs/systolic_ws_128x128_c1_simple_noc_tpuv2.json @@ -5,18 +5,18 @@ "core_print_interval" : 10000, "dram_type" : "ramulator2", - "dram_freq" :700, + "dram_freq" : 700, "dram_channels": 32, "dram_req_size": 32, "dram_latency" : 10, "dram_size" : 16, - "dram_nbl" : 1, + "dram_nbl" : 2, "dram_print_interval": 10000, "dram_config_path" : "../configs/ramulator2_configs/HBM2.yaml", "icnt_type" : "simple", "icnt_latency" : 7, - "icnt_freq" : 7000, + "icnt_freq" : 10000, "icnt_config_path" : "../configs/booksim2_configs/fly_c4_m32.icnt", "precision" : 4, diff --git a/PyTorchSimBackend/configs/systolic_ws_128x128_c1_simple_noc_tpuv3.json b/PyTorchSimBackend/configs/systolic_ws_128x128_c1_simple_noc_tpuv3.json new file mode 100644 index 00000000..7348d5bc --- /dev/null +++ b/PyTorchSimBackend/configs/systolic_ws_128x128_c1_simple_noc_tpuv3.json @@ -0,0 +1,29 @@ +{ + "num_cores" : 1, + "core_freq" : 940, + "sram_size" : 65536, + "core_print_interval" : 10000, + "num_systolic_array_per_core" : 2, + + "dram_type" : "ramulator2", + "dram_freq" : 940, + "dram_channels": 16, + "dram_req_size": 32, + "dram_latency" : 10, + "dram_size" : 32, + "dram_nbl" : 2, + "dram_print_interval": 10000, + "dram_config_path" : "../configs/ramulator2_configs/HBM2_TPUv3.yaml", + + "icnt_type" : "simple", + "icnt_latency" : 7, + "icnt_freq" : 15000, + "icnt_config_path" : "../configs/booksim2_configs/fly_c4_m32.icnt", + + "precision" : 4, + "scheduler" : "simple", + "num_partition" : 1, + "partition": { + "core_0": 0 + } +} \ No newline at end of file diff --git a/PyTorchSimBackend/configs/systolic_ws_128x128_c1_simple_noc_tpuv3_half.json b/PyTorchSimBackend/configs/systolic_ws_128x128_c1_simple_noc_tpuv3_half.json new file mode 100644 index 00000000..69ec8bd0 --- /dev/null +++ b/PyTorchSimBackend/configs/systolic_ws_128x128_c1_simple_noc_tpuv3_half.json @@ -0,0 +1,29 @@ +{ + "num_cores" : 1, + "core_freq" : 940, + "sram_size" : 65536, + "core_print_interval" : 10000, + "num_systolic_array_per_core" : 2, + + "dram_type" : "ramulator2", + "dram_freq" : 940, + "dram_channels": 8, + "dram_req_size": 32, + "dram_latency" : 10, + "dram_size" : 32, + "dram_nbl" : 2, + "dram_print_interval": 10000, + "dram_config_path" : "../configs/ramulator2_configs/HBM2_TPUv3.yaml", + + "icnt_type" : "simple", + "icnt_latency" : 7, + "icnt_freq" : 15000, + "icnt_config_path" : "../configs/booksim2_configs/fly_c4_m32.icnt", + + "precision" : 4, + "scheduler" : "simple", + "num_partition" : 1, + "partition": { + "core_0": 0 + } +} \ No newline at end of file diff --git a/PyTorchSimBackend/configs/systolic_ws_128x128_c1_simple_noc_tpuv4.json b/PyTorchSimBackend/configs/systolic_ws_128x128_c1_simple_noc_tpuv4.json index e9df64b3..bff4e224 100644 --- a/PyTorchSimBackend/configs/systolic_ws_128x128_c1_simple_noc_tpuv4.json +++ b/PyTorchSimBackend/configs/systolic_ws_128x128_c1_simple_noc_tpuv4.json @@ -1,8 +1,9 @@ { "num_cores" : 1, - "core_freq" : 700, - "sram_size" : 65536, + "core_freq" : 1050, + "sram_size" : 16777216, "core_print_interval" : 10000, + "num_systolic_array_per_core" : 4, "dram_type" : "ramulator2", "dram_freq" :1200, @@ -10,20 +11,21 @@ "dram_req_size": 32, "dram_latency" : 10, "dram_size" : 16, - "dram_nbl" : 1, + "dram_nbl" : 2, "dram_print_interval": 10000, "dram_config_path" : "../configs/ramulator2_configs/HBM2.yaml", + "l2d_type" : "datacache", + "l2d_config" : "S:128:128:512,32,L:T:m:W:L,A:192:4,32:0,32", "icnt_type" : "simple", - "icnt_latency" : 1, - "icnt_freq" : 8000, + "icnt_latency" : 7, + "icnt_freq" : 19200, "icnt_config_path" : "../configs/booksim2_configs/fly_c4_m32.icnt", "precision" : 4, "scheduler" : "simple", - "num_partition" : 2, + "num_partition" : 1, "partition": { - "core_0":0, - "core_1":0 + "core_0":0 } } \ No newline at end of file diff --git a/PyTorchSimBackend/configs/systolic_ws_128x128_c2_booksim_tpuv2.json b/PyTorchSimBackend/configs/systolic_ws_128x128_c2_booksim_tpuv3.json similarity index 51% rename from PyTorchSimBackend/configs/systolic_ws_128x128_c2_booksim_tpuv2.json rename to PyTorchSimBackend/configs/systolic_ws_128x128_c2_booksim_tpuv3.json index 73eb77d1..d51e9c5f 100644 --- a/PyTorchSimBackend/configs/systolic_ws_128x128_c2_booksim_tpuv2.json +++ b/PyTorchSimBackend/configs/systolic_ws_128x128_c2_booksim_tpuv3.json @@ -1,25 +1,26 @@ { "num_cores" : 2, - "core_freq" : 700, + "core_freq" : 940, "sram_size" : 65536, "core_print_interval" : 10000, + "num_systolic_array_per_core" : 2, "dram_type" : "ramulator2", - "dram_freq" :700, + "dram_freq" : 940, "dram_channels": 32, "dram_req_size": 32, "dram_latency" : 10, - "dram_size" : 16, - "dram_nbl" : 1, + "dram_size" : 32, + "dram_nbl" : 2, "dram_print_interval": 10000, - "dram_config_path" : "../configs/ramulator2_configs/HBM2.yaml", - + "dram_config_path" : "../configs/ramulator2_configs/HBM2_TPUv3.yaml", + "icnt_type" : "booksim2", - "icnt_latency" : 1, - "icnt_freq" : 1000, - "icnt_node_per_core" : 16, - "icnt_config_path" : "../configs/booksim2_configs/fly_c32_m32.icnt", - + "icnt_latency" : 7, + "icnt_freq" : 28000, + "icnt_node_per_core" : 1, + "icnt_config_path" : "../configs/booksim2_configs/fly_c2_m32.icnt", + "precision" : 4, "scheduler" : "simple", "num_partition" : 2, diff --git a/PyTorchSimBackend/configs/systolic_ws_128x128_c2_chiplet_tpuv2.json b/PyTorchSimBackend/configs/systolic_ws_128x128_c2_chiplet_tpuv3.json similarity index 70% rename from PyTorchSimBackend/configs/systolic_ws_128x128_c2_chiplet_tpuv2.json rename to PyTorchSimBackend/configs/systolic_ws_128x128_c2_chiplet_tpuv3.json index f22cf1a7..b2661894 100644 --- a/PyTorchSimBackend/configs/systolic_ws_128x128_c2_chiplet_tpuv2.json +++ b/PyTorchSimBackend/configs/systolic_ws_128x128_c2_chiplet_tpuv3.json @@ -1,25 +1,27 @@ { "num_cores" : 2, - "core_freq" : 700, + "core_freq" : 940, "sram_size" : 65536, "core_print_interval" : 10000, + "num_systolic_array_per_core" : 2, "dram_type" : "ramulator2", - "dram_freq" :700, + "dram_freq" : 940, "dram_channels": 32, "dram_req_size": 32, "dram_latency" : 10, - "dram_size" : 16, - "dram_nbl" : 1, + "dram_size" : 32, + "dram_nbl" : 2, "dram_print_interval": 10000, "dram_num_partitions" : 2, - "dram_config_path" : "../configs/ramulator2_configs/HBM2.yaml", + "dram_config_path" : "../configs/ramulator2_configs/HBM2_TPUv3.yaml", "icnt_type" : "booksim2", "icnt_latency" : 1, "icnt_freq" : 1000, "icnt_node_per_core" : 16, "icnt_config_path" : "../configs/booksim2_configs/chiplet_32_32_2.icnt", + "icnt_print_interval" : 10000, "precision" : 4, "scheduler" : "simple", diff --git a/PyTorchSimBackend/configs/systolic_ws_128x128_c2_chiplet_tpuv2_xnuma.json b/PyTorchSimBackend/configs/systolic_ws_128x128_c2_chiplet_tpuv3_xnuma.json similarity index 74% rename from PyTorchSimBackend/configs/systolic_ws_128x128_c2_chiplet_tpuv2_xnuma.json rename to PyTorchSimBackend/configs/systolic_ws_128x128_c2_chiplet_tpuv3_xnuma.json index 9f8922b4..922ede5b 100644 --- a/PyTorchSimBackend/configs/systolic_ws_128x128_c2_chiplet_tpuv2_xnuma.json +++ b/PyTorchSimBackend/configs/systolic_ws_128x128_c2_chiplet_tpuv3_xnuma.json @@ -1,19 +1,20 @@ { "num_cores" : 2, - "core_freq" : 700, + "core_freq" : 940, "sram_size" : 65536, "core_print_interval" : 10000, + "num_systolic_array_per_core" : 2, "dram_type" : "ramulator2", - "dram_freq" :700, + "dram_freq" : 940, "dram_channels": 32, "dram_req_size": 32, "dram_latency" : 10, - "dram_size" : 16, - "dram_nbl" : 1, + "dram_size" : 32, + "dram_nbl" : 2, "dram_print_interval": 10000, "dram_num_partitions" : 1, - "dram_config_path" : "../configs/ramulator2_configs/HBM2.yaml", + "dram_config_path" : "../configs/ramulator2_configs/HBM2_TPUv3.yaml", "icnt_type" : "booksim2", "icnt_latency" : 1, diff --git a/PyTorchSimBackend/configs/systolic_ws_128x128_c2_simple_noc_tpuv2.json b/PyTorchSimBackend/configs/systolic_ws_128x128_c2_simple_noc_tpuv2.json index 8c6c07dc..034542fe 100644 --- a/PyTorchSimBackend/configs/systolic_ws_128x128_c2_simple_noc_tpuv2.json +++ b/PyTorchSimBackend/configs/systolic_ws_128x128_c2_simple_noc_tpuv2.json @@ -10,13 +10,13 @@ "dram_req_size": 32, "dram_latency" : 10, "dram_size" : 16, - "dram_nbl" : 1, + "dram_nbl" : 2, "dram_print_interval": 10000, "dram_config_path" : "../configs/ramulator2_configs/HBM2.yaml", "icnt_type" : "simple", "icnt_latency" : 7, - "icnt_freq" : 7000, + "icnt_freq" : 20000, "icnt_config_path" : "../configs/booksim2_configs/fly_c4_m32.icnt", "precision" : 4, diff --git a/PyTorchSimBackend/configs/systolic_ws_128x128_c4_simple_noc_tpuv4.json b/PyTorchSimBackend/configs/systolic_ws_128x128_c2_simple_noc_tpuv3.json similarity index 55% rename from PyTorchSimBackend/configs/systolic_ws_128x128_c4_simple_noc_tpuv4.json rename to PyTorchSimBackend/configs/systolic_ws_128x128_c2_simple_noc_tpuv3.json index 9606281d..82f42c00 100644 --- a/PyTorchSimBackend/configs/systolic_ws_128x128_c4_simple_noc_tpuv4.json +++ b/PyTorchSimBackend/configs/systolic_ws_128x128_c2_simple_noc_tpuv3.json @@ -1,21 +1,23 @@ { "num_cores" : 2, - "core_freq" : 700, + "core_freq" : 940, "sram_size" : 65536, "core_print_interval" : 10000, + "num_systolic_array_per_core" : 2, - "dram_type" : "ramulator", - "dram_freq" : 700, + "dram_type" : "ramulator2", + "dram_freq" : 940, "dram_channels": 32, "dram_req_size": 32, "dram_latency" : 10, - "dram_print_interval" : 10000, - "dram_config_path" : "../configs/ramulator_configs/HBM-config.cfg", + "dram_size" : 32, + "dram_nbl" : 2, + "dram_print_interval": 10000, + "dram_config_path" : "../configs/ramulator2_configs/HBM2_TPUv3.yaml", "icnt_type" : "simple", - "icnt_latency" : 1, - "icnt_freq" : 2000, - "icnt_print_interval" : 0, + "icnt_latency" : 7, + "icnt_freq" : 28000, "icnt_config_path" : "../configs/booksim2_configs/fly_c4_m32.icnt", "precision" : 4, diff --git a/PyTorchSimBackend/configs/systolic_ws_128x128_c4_simple_noc_tpuv4_partition.json b/PyTorchSimBackend/configs/systolic_ws_128x128_c2_simple_noc_tpuv3_partition.json similarity index 52% rename from PyTorchSimBackend/configs/systolic_ws_128x128_c4_simple_noc_tpuv4_partition.json rename to PyTorchSimBackend/configs/systolic_ws_128x128_c2_simple_noc_tpuv3_partition.json index f705506a..132a52e6 100644 --- a/PyTorchSimBackend/configs/systolic_ws_128x128_c4_simple_noc_tpuv4_partition.json +++ b/PyTorchSimBackend/configs/systolic_ws_128x128_c2_simple_noc_tpuv3_partition.json @@ -1,24 +1,26 @@ { "num_cores" : 2, - "core_freq" : 1000, + "core_freq" : 940, "sram_size" : 65536, "core_print_interval" : 10000, + "num_systolic_array_per_core" : 2, - "dram_type" : "ramulator", - "dram_freq" : 877, + "dram_type" : "ramulator2", + "dram_freq" : 940, "dram_channels": 32, "dram_req_size": 32, "dram_latency" : 10, - "dram_print_interval" : 10000, - "dram_config_path" : "../configs/ramulator_configs/HBM-config.cfg", + "dram_size" : 32, + "dram_nbl" : 2, + "dram_print_interval": 10000, + "dram_config_path" : "../configs/ramulator2_configs/HBM2_TPUv3.yaml", "icnt_type" : "simple", - "icnt_latency" : 1, - "icnt_freq" : 2000, - "icnt_print_interval" : 0, + "icnt_latency" : 7, + "icnt_freq" : 28000, "icnt_config_path" : "../configs/booksim2_configs/fly_c4_m32.icnt", - "precision" : 2, + "precision" : 4, "scheduler" : "simple", "num_partition" : 2, "partition": { diff --git a/PyTorchSimBackend/configs/systolic_ws_128x128_c2_simple_noc_tpuv4.json b/PyTorchSimBackend/configs/systolic_ws_128x128_c2_simple_noc_tpuv4.json new file mode 100644 index 00000000..4b4df4e6 --- /dev/null +++ b/PyTorchSimBackend/configs/systolic_ws_128x128_c2_simple_noc_tpuv4.json @@ -0,0 +1,32 @@ +{ + "num_cores" : 2, + "core_freq" : 1050, + "sram_size" : 32768, + "core_print_interval" : 10000, + "num_systolic_array_per_core" : 4, + + "dram_type" : "ramulator2", + "dram_freq" :1200, + "dram_channels": 32, + "dram_req_size": 32, + "dram_latency" : 10, + "dram_size" : 32, + "dram_nbl" : 2, + "dram_print_interval": 10000, + "dram_config_path" : "../configs/ramulator2_configs/HBM2.yaml", + "l2d_type" : "datacache", + "l2d_config" : "S:64:128:512,32,L:T:m:W:L,A:192:4,32:0,32", + + "icnt_type" : "simple", + "icnt_latency" : 7, + "icnt_freq" : 38400, + "icnt_config_path" : "../configs/booksim2_configs/fly_c4_m32.icnt", + + "precision" : 4, + "scheduler" : "simple", + "num_partition" : 1, + "partition": { + "core_0":0, + "core_1":0 + } +} \ No newline at end of file diff --git a/PyTorchSimBackend/configs/systolic_ws_128x128_c32_simple_noc_tpuv4.json b/PyTorchSimBackend/configs/systolic_ws_128x128_c32_simple_noc_tpuv4.json deleted file mode 100644 index 80814c42..00000000 --- a/PyTorchSimBackend/configs/systolic_ws_128x128_c32_simple_noc_tpuv4.json +++ /dev/null @@ -1,22 +0,0 @@ -{ - "num_cores" : 32, - "core_freq" : 1000, - "sram_size" : 65536, - - "dram_type" : "ramulator", - "dram_freq" : 877, - "dram_channels": 32, - "dram_req_size": 32, - "dram_latency" : 10, - "dram_print_interval" : 10000, - "dram_config_path" : "../configs/ramulator_configs/HBM-config.cfg", - - "icnt_type" : "simple", - "icnt_latency" : 1, - "icnt_freq" : 2000, - "icnt_print_interval" : 10000, - "icnt_config_path" : "../configs/booksim2_configs/fly_c32_m32.icnt", - - "precision" : 2, - "scheduler" : "simple" -} \ No newline at end of file diff --git a/PyTorchSimBackend/configs/systolic_ws_128x128_c4_booksim_tpuv2.json b/PyTorchSimBackend/configs/systolic_ws_128x128_c4_booksim_tpuv2.json deleted file mode 100644 index e387650c..00000000 --- a/PyTorchSimBackend/configs/systolic_ws_128x128_c4_booksim_tpuv2.json +++ /dev/null @@ -1,25 +0,0 @@ -{ - "num_cores" : 4, - "core_freq" : 700, - "sram_size" : 65536, - "core_print_interval" : 10000, - - "dram_type" : "ramulator2", - "dram_freq" :700, - "dram_channels": 64, - "dram_req_size": 32, - "dram_latency" : 10, - "dram_size" : 16, - "dram_nbl" : 1, - "dram_print_interval": 10000, - "dram_config_path" : "../configs/ramulator2_configs/HBM2.yaml", - - "icnt_type" : "booksim2", - "icnt_latency" : 1, - "icnt_freq" : 1000, - "icnt_node_per_core" : 16, - "icnt_config_path" : "booksim2_configs/fly_128.icnt", - - "precision" : 4, - "scheduler" : "simple" -} \ No newline at end of file diff --git a/PyTorchSimBackend/configs/systolic_ws_128x128_c4_chiplet_map_tpuv2.json b/PyTorchSimBackend/configs/systolic_ws_128x128_c4_chiplet_map_tpuv2.json deleted file mode 100644 index ec9493de..00000000 --- a/PyTorchSimBackend/configs/systolic_ws_128x128_c4_chiplet_map_tpuv2.json +++ /dev/null @@ -1,27 +0,0 @@ -{ - "num_cores" : 4, - "core_freq" : 700, - "sram_size" : 65536, - "core_print_interval" : 10000, - - "dram_type" : "ramulator2", - "dram_freq" :700, - "dram_channels": 64, - "dram_req_size": 32, - "dram_latency" : 10, - "dram_size" : 16, - "dram_nbl" : 1, - "dram_print_interval": 10000, - "dram_num_partitions" : 4, - "dram_config_path" : "../configs/ramulator2_configs/HBM2.yaml", - - "icnt_type" : "booksim2", - "icnt_latency" : 1, - "icnt_freq" : 1000, - "icnt_node_per_core" : 16, - "icnt_config_path" : "../configs/booksim2_configs/chiplet_16_16_4.icnt", - - "precision" : 4, - "scheduler" : "simple", - "num_partition" : 0 -} \ No newline at end of file diff --git a/PyTorchSimBackend/configs/systolic_ws_128x128_c4_chiplet_tpuv2.json b/PyTorchSimBackend/configs/systolic_ws_128x128_c4_chiplet_tpuv2.json deleted file mode 100644 index d06c505d..00000000 --- a/PyTorchSimBackend/configs/systolic_ws_128x128_c4_chiplet_tpuv2.json +++ /dev/null @@ -1,26 +0,0 @@ -{ - "num_cores" : 4, - "core_freq" : 700, - "sram_size" : 65536, - "core_print_interval" : 10000, - - "dram_type" : "ramulator2", - "dram_freq" :700, - "dram_channels": 64, - "dram_req_size": 32, - "dram_latency" : 10, - "dram_size" : 16, - "dram_nbl" : 1, - "dram_print_interval": 10000, - "dram_config_path" : "../configs/ramulator2_configs/HBM2.yaml", - - "icnt_type" : "booksim2", - "icnt_latency" : 1, - "icnt_freq" : 1000, - "icnt_node_per_core" : 16, - "icnt_config_path" : "../configs/booksim2_configs/chiplet_16_16_4.icnt", - - "precision" : 4, - "scheduler" : "simple", - "num_partition" : 0 -} \ No newline at end of file diff --git a/PyTorchSimBackend/configs/systolic_ws_128x128_c8_simple_noc_tpuv4.json b/PyTorchSimBackend/configs/systolic_ws_128x128_c8_simple_noc_tpuv4.json deleted file mode 100644 index 496531a5..00000000 --- a/PyTorchSimBackend/configs/systolic_ws_128x128_c8_simple_noc_tpuv4.json +++ /dev/null @@ -1,22 +0,0 @@ -{ - "num_cores" : 8, - "core_freq" : 1000, - "sram_size" : 65536, - - "dram_type" : "ramulator", - "dram_freq" : 877, - "dram_channels": 32, - "dram_req_size": 32, - "dram_latency" : 10, - "dram_print_interval" : 10000, - "dram_config_path" : "../configs/ramulator_configs/HBM-config.cfg", - - "icnt_type" : "simple", - "icnt_latency" : 1, - "icnt_freq" : 2000, - "icnt_print_interval" : 10000, - "icnt_config_path" : "../configs/booksim2_configs/fly_c8_m32.icnt", - - "precision" : 2, - "scheduler" : "simple" -} \ No newline at end of file diff --git a/PyTorchSimBackend/configs/systolic_ws_8x8_c1_12G_simple_noc.json b/PyTorchSimBackend/configs/systolic_ws_8x8_c1_12G_simple_noc.json new file mode 100644 index 00000000..e9a64f2e --- /dev/null +++ b/PyTorchSimBackend/configs/systolic_ws_8x8_c1_12G_simple_noc.json @@ -0,0 +1,24 @@ +{ + "num_cores" : 1, + "core_freq" : 1000, + "sram_size" : 256, + "core_print_interval" : 100000, + + "dram_type" : "ramulator2", + "dram_freq" :800, + "dram_channels": 1, + "dram_req_size": 64, + "dram_latency" : 10, + "dram_size" : 16, + "dram_nbl" : 4, + "dram_print_interval": 100000, + "dram_config_path" : "../configs/ramulator2_configs/DDR4.yaml", + + "icnt_type" : "simple", + "icnt_latency" : 1, + "icnt_freq" : 1000, + "icnt_config_path" : "../configs/booksim2_configs/fly_c1_m1.icnt", + + "precision" : 4, + "scheduler" : "simple" +} \ No newline at end of file diff --git a/PyTorchSimBackend/configs/systolic_ws_8x8_c1_24G_simple_noc.json b/PyTorchSimBackend/configs/systolic_ws_8x8_c1_24G_simple_noc.json new file mode 100644 index 00000000..37e18b35 --- /dev/null +++ b/PyTorchSimBackend/configs/systolic_ws_8x8_c1_24G_simple_noc.json @@ -0,0 +1,24 @@ +{ + "num_cores" : 1, + "core_freq" : 1000, + "sram_size" : 256, + "core_print_interval" : 100000, + + "dram_type" : "ramulator2", + "dram_freq" :800, + "dram_channels": 2, + "dram_req_size": 64, + "dram_latency" : 10, + "dram_size" : 16, + "dram_nbl" : 4, + "dram_print_interval": 100000, + "dram_config_path" : "../configs/ramulator2_configs/DDR4.yaml", + + "icnt_type" : "simple", + "icnt_latency" : 1, + "icnt_freq" : 8000, + "icnt_config_path" : "../configs/booksim2_configs/fly_c1_m1.icnt", + + "precision" : 4, + "scheduler" : "simple" +} \ No newline at end of file diff --git a/PyTorchSimBackend/configs/systolic_ws_8x8_c1_48G_simple_noc.json b/PyTorchSimBackend/configs/systolic_ws_8x8_c1_48G_simple_noc.json new file mode 100644 index 00000000..49225d77 --- /dev/null +++ b/PyTorchSimBackend/configs/systolic_ws_8x8_c1_48G_simple_noc.json @@ -0,0 +1,24 @@ +{ + "num_cores" : 1, + "core_freq" : 1000, + "sram_size" : 256, + "core_print_interval" : 100000, + + "dram_type" : "ramulator2", + "dram_freq" :800, + "dram_channels": 4, + "dram_req_size": 64, + "dram_latency" : 10, + "dram_size" : 16, + "dram_nbl" : 4, + "dram_print_interval": 100000, + "dram_config_path" : "../configs/ramulator2_configs/DDR4.yaml", + + "icnt_type" : "simple", + "icnt_latency" : 1, + "icnt_freq" : 8000, + "icnt_config_path" : "../configs/booksim2_configs/fly_c1_m1.icnt", + + "precision" : 4, + "scheduler" : "simple" +} \ No newline at end of file diff --git a/PyTorchSimBackend/configs/systolic_ws_8x8_c2_12G_simple_noc.json b/PyTorchSimBackend/configs/systolic_ws_8x8_c2_12G_simple_noc.json new file mode 100644 index 00000000..f76fec32 --- /dev/null +++ b/PyTorchSimBackend/configs/systolic_ws_8x8_c2_12G_simple_noc.json @@ -0,0 +1,25 @@ +{ + "core_type" : ["ws_mesh","ws_mesh"], + "num_cores" : 2, + "core_freq" : 1000, + "sram_size" : 256, + "core_print_interval" : 100000, + + "dram_type" : "ramulator2", + "dram_freq" :800, + "dram_channels": 1, + "dram_req_size": 64, + "dram_latency" : 10, + "dram_size" : 16, + "dram_nbl" : 4, + "dram_print_interval": 100000, + "dram_config_path" : "../configs/ramulator2_configs/DDR4.yaml", + + "icnt_type" : "simple", + "icnt_latency" : 1, + "icnt_freq" : 8000, + "icnt_config_path" : "../configs/booksim2_configs/fly_c2_m4.icnt", + + "precision" : 4, + "scheduler" : "simple" +} \ No newline at end of file diff --git a/PyTorchSimBackend/configs/systolic_ws_8x8_c2_24G_simple_noc.json b/PyTorchSimBackend/configs/systolic_ws_8x8_c2_24G_simple_noc.json new file mode 100644 index 00000000..7571b830 --- /dev/null +++ b/PyTorchSimBackend/configs/systolic_ws_8x8_c2_24G_simple_noc.json @@ -0,0 +1,24 @@ +{ + "num_cores" : 2, + "core_freq" : 1000, + "sram_size" : 256, + "core_print_interval" : 100000, + + "dram_type" : "ramulator2", + "dram_freq" :800, + "dram_channels": 2, + "dram_req_size": 64, + "dram_latency" : 10, + "dram_size" : 16, + "dram_nbl" : 4, + "dram_print_interval": 100000, + "dram_config_path" : "../configs/ramulator2_configs/DDR4.yaml", + + "icnt_type" : "simple", + "icnt_latency" : 1, + "icnt_freq" : 8000, + "icnt_config_path" : "../configs/booksim2_configs/fly_c2_m8.icnt", + + "precision" : 4, + "scheduler" : "simple" +} \ No newline at end of file diff --git a/PyTorchSimBackend/configs/systolic_ws_8x8_c2_48G_simple_noc.json b/PyTorchSimBackend/configs/systolic_ws_8x8_c2_48G_simple_noc.json new file mode 100644 index 00000000..be163336 --- /dev/null +++ b/PyTorchSimBackend/configs/systolic_ws_8x8_c2_48G_simple_noc.json @@ -0,0 +1,24 @@ +{ + "num_cores" : 2, + "core_freq" : 1000, + "sram_size" : 256, + "core_print_interval" : 100000, + + "dram_type" : "ramulator2", + "dram_freq" :800, + "dram_channels": 4, + "dram_req_size": 64, + "dram_latency" : 10, + "dram_size" : 16, + "dram_nbl" : 4, + "dram_print_interval": 100000, + "dram_config_path" : "../configs/ramulator2_configs/DDR4.yaml", + + "icnt_type" : "simple", + "icnt_latency" : 1, + "icnt_freq" : 8000, + "icnt_config_path" : "../configs/booksim2_configs/fly_c1_m1.icnt", + + "precision" : 4, + "scheduler" : "simple" +} \ No newline at end of file diff --git a/PyTorchSimBackend/extern/ramulator2 b/PyTorchSimBackend/extern/ramulator2 index 10f8dcaa..748cd709 160000 --- a/PyTorchSimBackend/extern/ramulator2 +++ b/PyTorchSimBackend/extern/ramulator2 @@ -1 +1 @@ -Subproject commit 10f8dcaab94b5696988a92461500bd3212f82e7d +Subproject commit 748cd7099778d7196326aeb6384da92efb0c34c9 diff --git a/PyTorchSimBackend/extern/stonneCore b/PyTorchSimBackend/extern/stonneCore new file mode 160000 index 00000000..97804185 --- /dev/null +++ b/PyTorchSimBackend/extern/stonneCore @@ -0,0 +1 @@ +Subproject commit 97804185f00e98e56f74638e4282b9aecab8cfce diff --git a/PyTorchSimBackend/include/Cache.h b/PyTorchSimBackend/include/Cache.h new file mode 100644 index 00000000..1d5927ac --- /dev/null +++ b/PyTorchSimBackend/include/Cache.h @@ -0,0 +1,471 @@ +#ifndef CACHE_H_ +#define CACHE_H_ +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "Cache_defs.h" +#include "Cache_stats.h" +#include "DelayQueue.h" +#include "Memfetch.h" + +class CacheConfig { + public: + CacheConfig() {} + void init(std::string config); + bool disabled() const { return m_disabled; } + uint32_t get_line_size() const { return m_line_size; } + uint32_t get_atom_size() const { return m_atom_size; } + uint32_t get_num_lines() const { return m_nset * m_assoc; } + uint32_t get_num_assoc() const { return m_assoc; } + uint32_t get_max_assoc() const { return m_origin_assoc; } + uint32_t get_max_sets() const { return m_origin_nset; } + uint32_t get_num_sets() const { return m_nset; } + void set_sets(uint32_t sets) { m_nset = sets; } + void set_assoc (uint32_t assoc) { m_assoc = assoc; } + uint32_t get_mshr_entries() const { return m_mshr_entries; } + uint32_t get_mshr_max_merge() const { return m_mshr_max_merge; } + uint32_t get_miss_queue_size() const { return m_miss_queue_size; } + uint32_t get_sector_size() { return m_sector_size; } + uint32_t get_set_index(uint64_t addr) const; + uint64_t get_tag(uint64_t addr) const; + uint64_t get_block_addr(uint64_t addr) const; + uint64_t get_mshr_addr(uint64_t addr) const; + CacheType get_cache_type() const { return m_cache_type; } + EvictPolicy get_evict_policy() const { return m_evict_policy; } + WritePolicy get_write_policy() const { return m_write_policy; } + WriteAllocatePolicy get_write_alloc_policy() const { + return m_write_alloc_policy; + } + AllocationPolicy get_alloc_policy() const { return m_alloc_policy; } + MshrConfig get_mshr_config() const { return m_mshr_type; } + uint32_t get_nset() const { return m_nset; } + uint32_t get_total_size_in_kb() const { + return (m_line_size * m_nset * m_assoc) / 1024; + } + uint32_t get_origin_size () const { + return m_line_size * m_origin_assoc * m_origin_nset; + } + uint32_t get_data_port_width() const { return m_data_port_width; } + protected: + bool m_valid = false; + bool m_disabled = false; + uint32_t m_origin_nset = 0; + uint32_t m_line_size = 0; + uint32_t m_line_size_log2 = 0; + uint32_t m_nset = 0; + uint32_t m_nset_log2 = 0; + uint32_t m_assoc = 0; + uint32_t m_origin_assoc = 0; + uint32_t m_atom_size = 0; + uint32_t m_sector_size = 0; + uint32_t m_sector_size_log2 = 0; + uint32_t m_mshr_entries = 0; + uint32_t m_mshr_max_merge = 0; + uint32_t m_miss_queue_size = 0; + uint32_t m_result_fifo_entries = 0; + uint32_t m_data_port_width = 0; + CacheType m_cache_type; + EvictPolicy m_evict_policy; + WritePolicy m_write_policy; + WriteAllocatePolicy m_write_alloc_policy; + AllocationPolicy m_alloc_policy; + MshrConfig m_mshr_type; + SetIndexFunction m_set_index_function; + + uint32_t hash_function(uint64_t addr) const; +}; + +class CacheBlock { + public: + virtual void allocate(uint64_t tage, uint64_t block_addr, uint32_t time, + SectorMask sector_mask) = 0; + virtual void fill(uint32_t time, SectorMask sector_mask) = 0; + virtual bool match_tag(uint64_t tag) { return m_tag == tag; } + virtual uint64_t get_block_addr() { return m_block_addr; } + virtual bool is_valid_line() = 0; + virtual bool is_invalid_line() = 0; + virtual bool is_reserved_line() = 0; + virtual bool is_modified_line() = 0; + virtual SectorMask get_dirty_mask() = 0; + virtual CacheBlockState get_status(SectorMask mask) = 0; + virtual void set_status(CacheBlockState status, SectorMask mask) = 0; + virtual bool is_readable(SectorMask mask) = 0; + virtual uint64_t get_last_access_time() = 0; + virtual uint64_t get_alloc_time() = 0; + virtual void set_ignore_on_fill(bool ignore, SectorMask sector_mask) = 0; + virtual void set_modified_on_fill(bool modified, SectorMask sector_mask) = 0; + virtual void set_readable(bool readable, SectorMask sector_mask) = 0; + virtual void set_last_access_time(uint64_t time, SectorMask sector_mask) = 0; + virtual uint32_t get_modified_size() = 0; + + protected: + uint64_t m_tag; + uint64_t m_block_addr; +}; + +class LineCacheBlock : public CacheBlock { + public: + LineCacheBlock(uint32_t sector_size) : m_sector_size(sector_size) {}; + virtual void allocate(uint64_t tag, uint64_t block_addr, uint32_t time, + SectorMask sector_mask) override; + virtual void fill(uint32_t time, SectorMask sector_mask) override; + virtual bool is_valid_line() override { return m_status == VALID; } + virtual bool is_invalid_line() override { return m_status == INVALID; } + virtual bool is_reserved_line() override { return m_status == RESERVED; } + virtual bool is_modified_line() override { return m_status == MODIFIED; } + virtual SectorMask get_dirty_mask() override; + virtual CacheBlockState get_status(SectorMask mask) override { + return m_status; + } + virtual void set_status(CacheBlockState status, SectorMask mask) override { + m_status = status; + } + virtual bool is_readable(SectorMask mask) override { return m_readable; } + virtual uint64_t get_last_access_time() override { + return m_last_access_time; + } + virtual uint64_t get_alloc_time() override { return m_alloc_time; } + virtual void set_ignore_on_fill(bool ignore, + SectorMask sector_mask) override { + m_ignore_on_fill_status = ignore; + } + virtual void set_modified_on_fill(bool modified, + SectorMask sector_mask) override { + m_set_modified_on_fill = modified; + } + virtual void set_readable(bool readable, SectorMask sector_mask) override { + m_readable = readable; + } + virtual void set_last_access_time(uint64_t time, + SectorMask sector_mask) override { + m_last_access_time = time; + } + virtual uint32_t get_modified_size() override { + return SECTOR_CHUNCK_SIZE * m_sector_size; + } + + protected: + uint64_t m_alloc_time = 0; + uint32_t m_sector_size = 0; + uint64_t m_last_access_time = 0; + uint64_t m_fill_time = 0; + CacheBlockState m_status = INVALID; + bool m_ignore_on_fill_status = false; + bool m_set_modified_on_fill = false; + bool m_readable = true; +}; + +class SectorCacheBlock : public CacheBlock { + public: + SectorCacheBlock(uint32_t sector_size) : m_sector_size(sector_size) {}; + virtual void allocate(uint64_t tag, uint64_t block_addr, uint32_t time, + SectorMask sector_mask) override; + virtual void allocate_sector(uint32_t time, SectorMask sector_mask); + virtual void fill(uint32_t time, SectorMask sector_mask) override; + virtual bool is_valid_line() override; + virtual bool is_invalid_line() override; + virtual bool is_reserved_line() override; + virtual bool is_modified_line() override; + virtual SectorMask get_dirty_mask() override; + virtual CacheBlockState get_status(SectorMask mask) override; + virtual void set_status(CacheBlockState status, SectorMask mask) override; + virtual bool is_readable(SectorMask mask) override; + virtual uint64_t get_last_access_time() override; + virtual uint64_t get_alloc_time() override; + virtual void set_ignore_on_fill(bool ignore, SectorMask sector_mask) override; + virtual void set_modified_on_fill(bool modified, + SectorMask sector_mask) override; + virtual void set_readable(bool readable, SectorMask sector_mask) override; + virtual void set_last_access_time(uint64_t time, + SectorMask sector_mask) override; + virtual uint32_t get_modified_size() override; + + private: + uint32_t m_sector_alloc_time[SECTOR_CHUNCK_SIZE] = {0}; + uint32_t m_sector_fill_time[SECTOR_CHUNCK_SIZE] = {0}; + uint32_t m_sector_last_access_time[SECTOR_CHUNCK_SIZE] = {0}; + uint32_t m_sector_size = 0; + uint32_t m_line_alloc_time = 0; + uint32_t m_line_fill_time = 0; + uint32_t m_line_last_access_time = 0; + CacheBlockState m_status[SECTOR_CHUNCK_SIZE] = {INVALID}; + bool m_ignore_on_fill_status[SECTOR_CHUNCK_SIZE] = {false}; + bool m_set_modified_on_fill_status[SECTOR_CHUNCK_SIZE] = {false}; + bool m_readable[SECTOR_CHUNCK_SIZE] = {true}; + + void init(); + + uint32_t get_sector_index(SectorMask sector_mask) { + assert(sector_mask.count() == 1); + for (unsigned i = 0; i < SECTOR_CHUNCK_SIZE; ++i) { + if (sector_mask.to_ulong() & (1 << i)) return i; + } + assert(false); + return 0; + } +}; + +class TagArray { + public: + TagArray(CacheConfig &config, int core_id, int type_id); + ~TagArray(); + CacheRequestStatus probe(uint64_t addr, uint32_t &idx, mem_fetch *mf, + bool probe_mode = false) const; + CacheRequestStatus probe(uint64_t addr, uint32_t &idx, SectorMask mask, + mem_fetch *mf = NULL, bool probe_mode = false) const; + CacheRequestStatus access(uint64_t addr, uint32_t time, uint32_t &idx, + mem_fetch *mf); + CacheRequestStatus access(uint64_t addr, uint32_t time, uint32_t &idx, + mem_fetch *mf, bool &wb, + EvictedBlockInfo &evicted_block); + void fill(uint64_t addr, uint32_t time, mem_fetch *mf); + void fill(uint32_t idx, uint32_t time, mem_fetch *mf); + void fill(uint64_t addr, uint32_t time, SectorMask mask); + uint32_t size() const { return m_config.get_num_lines(); } + CacheBlock *get_block(uint32_t idx) const { return m_lines[idx]; } + void invalidate(); + + protected: + CacheConfig &m_config; + CacheBlock **m_lines; /* N banks x M sets x assoc lines in total */ + uint32_t m_core_id; + uint32_t m_type_id; + uint32_t m_access; + uint32_t m_miss; + uint32_t m_pending_hit; + uint32_t m_res_fail; + uint32_t m_sector_miss; + bool is_used; + void init(int core_id, int type_id); +}; + +class MshrTable { + public: + MshrTable(uint32_t num_entries, uint32_t max_merged) + : m_num_entries(num_entries), m_max_merged(max_merged) {} + bool probe(uint64_t block_addr) const; + bool full(uint64_t block_addr) const; + void add(uint64_t block_addr, mem_fetch *mf); + bool busy() const { return false; } + void mark_ready(uint64_t block_addr, bool &has_atomic); + bool access_ready() const { return !m_current_response.empty(); } + mem_fetch *pop_next_access(); + mem_fetch *top_next_access(); + bool is_read_after_write_pending(uint64_t block_addr); + void print(FILE *fp) const; + + private: + const unsigned m_num_entries; + const unsigned m_max_merged; + + struct MshrEntry { + std::deque m_list; + bool m_has_atomic = false; + }; + std::map m_table; + std::map m_line_table; + bool m_current_response_ready; + std::deque m_current_response; +}; + +class Cache { + public: + Cache(std::string name, CacheConfig &config, int core_id, int type_id, + std::queue *to_mem_queue); + ~Cache() { + delete m_tag_array; + delete m_mshrs; + } + virtual CacheRequestStatus access(uint64_t addr, uint32_t time, mem_fetch *mf, + std::deque &event) = 0; + virtual void cycle(); + virtual void fill(mem_fetch *mf, uint32_t time); + virtual bool waiting_for_fill(mem_fetch *mf); + virtual bool access_ready() { return m_mshrs->access_ready(); } + virtual mem_fetch *pop_next_access() { return m_mshrs->pop_next_access(); } + virtual mem_fetch *top_next_access() { return m_mshrs->top_next_access(); } + virtual void invalidate() { m_tag_array->invalidate(); } + + virtual bool data_port_free() { + return m_bandwidth_management.data_port_free(); + } + virtual bool fill_port_free() { + return m_bandwidth_management.fill_port_free(); + } + // virtual bool miss_queue_size(bool from_ndp); + virtual void force_tag_access(uint64_t addr, uint32_t time, SectorMask mask) { + m_tag_array->fill(addr, time, mask); + } + virtual CacheStats get_stats() const { return m_stats; } + virtual void print_cache_stats() {} + + protected: + uint32_t m_id; + std::string m_name; + CacheConfig &m_config; + TagArray *m_tag_array; + MshrTable *m_mshrs; + std::deque m_miss_queue; + std::queue *m_to_mem_queue; + CacheStats m_stats; + struct ExtraMfFields { + bool m_valid = false; + uint64_t m_block_addr; + uint64_t m_addr; + uint32_t m_cache_index; + uint32_t m_data_size; + uint32_t pending_read; + }; + class BandwidthManagement { + public: + BandwidthManagement(CacheConfig &config) : m_config(config) {} + void use_data_port(mem_fetch *mf, CacheRequestStatus outcome, + const std::deque &events); + void use_fill_port(mem_fetch *mf); + void replenish_port_bandwidth(); + bool data_port_free() const; + bool fill_port_free() const; + + protected: + const CacheConfig &m_config; + int m_data_port_occupied_cycles = 0; + int m_fill_port_occupied_cycles = 0; + }; + + std::map m_extra_mf_fields; + BandwidthManagement m_bandwidth_management; + + protected: + /// Checks whether this request can be handled on this cycle. num_miss equals + /// max # of misses to be handled on this cycle + bool miss_queue_full(uint32_t num_misses) { + return (m_miss_queue.size() + num_misses) > m_config.get_miss_queue_size(); + ; + } + // Read miss handler without write back + void send_read_request(uint64_t addr, uint64_t block_addr, + uint32_t cache_index, mem_fetch *mf, uint32_t time, + bool &do_miss, std::deque &events, + bool read_only, bool wa); + // Read miss handler. Check MSHR hit or avaiable + void send_read_request(uint64_t addr, uint64_t block_addr, + uint32_t cache_index, mem_fetch *mf, uint32_t time, + bool &do_miss, bool &wb, EvictedBlockInfo &eviced, + std::deque &events, bool read_only, + bool wa); +}; + +class ReadOnlyCache : public Cache { + public: + ReadOnlyCache(std::string name, CacheConfig &config, int core_id, int type_id, + std::queue *to_mem_queue) + : Cache(name, config, core_id, type_id, to_mem_queue) {} + + virtual CacheRequestStatus access(uint64_t addr, uint32_t time, mem_fetch *mf, + std::deque &event) override; +}; + +class DataCache : public Cache { + public: + DataCache(std::string name, CacheConfig &config, int core_id, int type_id, + std::queue *to_mem_queue, bool is_l1 = false) + : Cache(name, config, core_id, type_id, to_mem_queue) { + init(); + m_write_alloc_type = L2_CACHE_WA; + m_write_back_type = L2_CACHE_WB; + } + virtual void init(); + virtual void print_cache_stats(); + virtual CacheRequestStatus access(uint64_t addr, uint32_t time, mem_fetch *mf, + std::deque &event) override; + protected: + mem_access_type m_write_alloc_type; + mem_access_type m_write_back_type; + CacheRequestStatus process_tag_probe(bool wr, CacheRequestStatus status, + uint64_t addr, uint32_t cache_index, + mem_fetch *mf, uint32_t time, + std::deque &events); + // Functions for data cache access + /// Sends write request to lower level memory (write or writeback) + void send_write_request(mem_fetch *mf, CacheEvent request, uint32_t time, + std::deque &events); + void write_back(EvictedBlockInfo &evicted, uint32_t time, std::deque &events); + + CacheRequestStatus (DataCache::*m_wr_hit)(uint64_t addr, uint32_t cache_index, + mem_fetch *mf, uint32_t time, + std::deque &event, + CacheRequestStatus status); + CacheRequestStatus (DataCache::*m_wr_miss)(uint64_t addr, + uint32_t cache_index, + mem_fetch *mf, uint32_t time, + std::deque &event, + CacheRequestStatus status); + CacheRequestStatus (DataCache::*m_rd_hit)(uint64_t addr, uint32_t cache_index, + mem_fetch *mf, uint32_t time, + std::deque &event, + CacheRequestStatus status); + CacheRequestStatus (DataCache::*m_rd_miss)(uint64_t addr, + uint32_t cache_index, + mem_fetch *mf, uint32_t time, + std::deque &event, + CacheRequestStatus status); + + // Function pointers for different cache access + // Write hit + CacheRequestStatus wr_hit_wb( + uint64_t addr, uint32_t cache_index, mem_fetch *mf, uint32_t time, + std::deque &event, + CacheRequestStatus status); // write hit with write back + CacheRequestStatus wr_hit_wt( + uint64_t addr, uint32_t cache_index, mem_fetch *mf, uint32_t time, + std::deque &event, + CacheRequestStatus status); // write hit with write through + CacheRequestStatus wr_hit_we( + uint64_t addr, uint32_t cache_index, mem_fetch *mf, uint32_t time, + std::deque &event, + CacheRequestStatus status); // write hit with write evict + CacheRequestStatus wr_hit_global_we_local_wb( + uint64_t addr, uint32_t cache_index, mem_fetch *mf, uint32_t time, + std::deque &event, + CacheRequestStatus status); // write hit with write evict for global and + // write back for local + // Write miss + CacheRequestStatus wr_miss_wa_naive( + uint64_t addr, uint32_t cache_index, mem_fetch *mf, uint32_t time, + std::deque &event, + CacheRequestStatus status); // write allocate send write and read requsts + CacheRequestStatus wr_miss_wa_lazy_fetch_on_read( + uint64_t addr, uint32_t cache_index, mem_fetch *mf, uint32_t time, + std::deque &event, + CacheRequestStatus status); // write allocate with read-fetch-only + CacheRequestStatus wr_miss_wa_write_validate( + uint64_t addr, uint32_t cache_index, mem_fetch *mf, uint32_t time, + std::deque &event, + CacheRequestStatus + status); // write-allocate that writes with no read fetch + CacheRequestStatus wr_miss_no_wa( + uint64_t addr, uint32_t cache_index, mem_fetch *mf, uint32_t time, + std::deque &event, + CacheRequestStatus status); // no write allocate + + // Read hit + CacheRequestStatus rd_hit_base(uint64_t addr, uint32_t cache_index, + mem_fetch *mf, uint32_t time, + std::deque &event, + CacheRequestStatus status); // read hit base + + // Read miss + CacheRequestStatus rd_miss_base(uint64_t addr, uint32_t cache_index, + mem_fetch *mf, uint32_t time, + std::deque &event, + CacheRequestStatus status); // read miss base +}; + +#endif \ No newline at end of file diff --git a/PyTorchSimBackend/include/Cache_defs.h b/PyTorchSimBackend/include/Cache_defs.h new file mode 100644 index 00000000..af5035fc --- /dev/null +++ b/PyTorchSimBackend/include/Cache_defs.h @@ -0,0 +1,147 @@ +#ifndef CACHE_DEFS_H +#define CACHE_DEFS_H +#include +#include +#include +#include +#include + +const int SECTOR_CHUNCK_SIZE = 4; +typedef std::bitset SectorMask; +enum CacheBlockState { + INVALID, // Initial state + RESERVED, // Reserved state (alloc()) + VALID, // Filled state (fill) + MODIFIED // Filled with modified data (fill) +}; +enum CacheRequestStatus { + HIT, + HIT_RESERVED, + MISS, + RESERVATION_FAIL, + SECTOR_MISS, + MSHR_HIT, + NUM_CACHE_REQUEST_STATUS +}; +static const char *cache_request_status_str[] = { + "HIT", "HIT_RESERVED", "MISS", "RESERVATION_FAIL", + "SECTOR_MISS", "MSHR_HIT"}; + +enum CacheReservationFailReason { + LINE_ALLOC_FAIL, + MISS_QUEUE_FULL, + MSHR_ENTRY_FAIL, + MSHR_MERGE_ENTRY_FAIL, + MSHR_RW_PENDING, + NUM_CACHE_RESERVATION_FAIL_REASON +}; +static const char *cache_reservation_fail_reason_str[] = { + "LINE_ALLOCATE_FAIL", "MISS_QUEUE_FULL", "MSHR_ENTRY_FAIL", + "MSHR_MERGE_ENTRY_FAIL", "MSHR_RW_PENDING"}; + +enum CacheEventType { + WRITE_BACK_REQUEST_SENT, + READ_REQUEST_SENT, + WRITE_REQUEST_SENT, + WRITE_ALLOCATE_SENT +}; + +struct EvictedBlockInfo { + uint64_t m_block_addr = 0; + uint32_t m_modified_size = 0; + SectorMask m_dirty_mask; + void set_info(uint64_t block_addr, uint32_t modified_size, + SectorMask dirty_mask) { + m_block_addr = block_addr; + m_modified_size = modified_size; + m_dirty_mask = dirty_mask; + } +}; +struct CacheEvent { + CacheEvent() {} + CacheEvent(CacheEventType cache_event_type) + : m_cache_event_type(cache_event_type) {} + CacheEvent(CacheEventType cache_event_type, EvictedBlockInfo evicted_block) + : m_cache_event_type(cache_event_type), m_evicted_block(evicted_block) {} + CacheEventType m_cache_event_type; + EvictedBlockInfo m_evicted_block; // only valid for WRITE_BACK_REQUEST_SENT + static bool was_event_sent(const std::deque &events, + CacheEventType event_type, + CacheEvent &found_event) { + for (auto &event : events) { + if (event.m_cache_event_type == event_type) { + found_event = event; + return true; + } + } + return false; + } + static bool was_write_sent(const std::deque &events) { + CacheEvent event; + return was_event_sent(events, WRITE_REQUEST_SENT, event); + } + static bool was_read_sent(const std::deque &events) { + CacheEvent event; + return was_event_sent(events, READ_REQUEST_SENT, event); + } + static bool was_writeback_sent(const std::deque &events, + CacheEvent event) { + return was_event_sent(events, WRITE_BACK_REQUEST_SENT, event); + } + static bool was_write_allocate_sent(const std::deque &events) { + CacheEvent event; + return was_event_sent(events, WRITE_ALLOCATE_SENT, event); + } +}; + +enum WritePolicy { + READ_ONLY, + WRITE_BACK, + WRITE_THROUGH, + WRITE_EVICT, + LOCAL_WB_GLOBAL_WT +}; +static std::map WritePolicyMap = {{'R', READ_ONLY}, + {'B', WRITE_BACK}, + {'T', WRITE_THROUGH}, + {'E', WRITE_EVICT}, + {'L', LOCAL_WB_GLOBAL_WT}}; + +enum AllocationPolicy { ON_MISS, ON_FILL, STREAMING }; +static std::map AllocationPolicyMap = { + {'m', ON_MISS}, {'f', ON_FILL}, {'s', STREAMING}}; + +enum WriteAllocatePolicy { + NO_WRITE_ALLOCATE, + WRITE_ALLOCATE, + FETCH_ON_WRITE, + LAZY_FETCH_ON_READ +}; +static std::map WriteAllocatePolicyMap = { + {'N', NO_WRITE_ALLOCATE}, + {'W', WRITE_ALLOCATE}, + {'F', FETCH_ON_WRITE}, + {'L', LAZY_FETCH_ON_READ}}; + +enum CacheType { NORMAL, SECTOR }; +static std::map CacheTypeMap = {{'N', NORMAL}, {'S', SECTOR}}; + +enum EvictPolicy { LRU, FIFO }; +static std::map EvictPolicyMap = {{'L', LRU}, {'F', FIFO}}; + +enum MshrConfig { ASSOC, SECTOR_ASSOC }; +static std::map MshrConfigMap = {{'A', ASSOC}, + {'S', SECTOR_ASSOC}}; + +enum SetIndexFunction { + LINEAR_SET_FUNCTION, + BITWISE_XORING_FUNCTION, + HASH_IPOLY_FUNCTION, + CUSTOM_SET_FUNCTION +}; +static std::map SetIndexFunctionMap = { + {'L', LINEAR_SET_FUNCTION}, + {'X', BITWISE_XORING_FUNCTION}, + {'P', HASH_IPOLY_FUNCTION}, + {'C', CUSTOM_SET_FUNCTION}}; +#endif \ No newline at end of file diff --git a/PyTorchSimBackend/include/Cache_stats.h b/PyTorchSimBackend/include/Cache_stats.h new file mode 100644 index 00000000..1bf92d8a --- /dev/null +++ b/PyTorchSimBackend/include/Cache_stats.h @@ -0,0 +1,51 @@ +#ifndef CACHE_STATS_H +#define CACHE_STATS_H +#include +#include +#include + +#include "Cache_defs.h" + +class CacheStats { + public: + CacheStats(); + void clear(); + void inc_stats(int access_type, int accss_outcome); + void inc_fail_stats(int access_type, int fail_outcome); + CacheRequestStatus select_stats_status(CacheRequestStatus probe, + CacheRequestStatus access) const; + uint64_t &operator()(int access_type, int access_outcome, bool fail_outcome); + uint64_t operator()(int access_type, int access_outcome, + bool fail_outcome) const; + CacheStats operator+(const CacheStats &other); + CacheStats &operator+=(const CacheStats &other); + void aggregate_stats(); + uint64_t get_hit() const; + uint64_t get_read_hit() const; + uint64_t get_write_hit() const; + uint64_t get_miss() const; + uint64_t get_read_miss() const; + uint64_t get_write_miss() const; + uint64_t get_accesses() const; + uint64_t get_interval_hit(); + uint64_t get_interval_miss(); + void print_stats(FILE *out, const char *cache_name = "CacheStats") const; + void print_fail_stats(FILE *out, const char *cache_name = "CacheStats") const; + void print_energy_stats(FILE *out, + const char *cache_name = "CacheStats") const; + + private: + bool check_valid(int type, int status) const; + bool check_fail_valid(int type, int fail) const; + + std::vector> m_stats; + std::vector> m_fail_stats; + + uint64_t m_cache_port_available_cycles; + uint64_t m_cache_data_port_busy_cycles; + uint64_t m_cache_fill_port_busy_cycles; + + uint64_t m_prev_hit; + uint64_t m_prev_miss; +}; +#endif \ No newline at end of file diff --git a/PyTorchSimBackend/src/Common.h b/PyTorchSimBackend/include/Common.h similarity index 100% rename from PyTorchSimBackend/src/Common.h rename to PyTorchSimBackend/include/Common.h diff --git a/PyTorchSimBackend/include/Core.h b/PyTorchSimBackend/include/Core.h new file mode 100644 index 00000000..a3d55fa2 --- /dev/null +++ b/PyTorchSimBackend/include/Core.h @@ -0,0 +1,99 @@ +#pragma once +#include +#include +#include +#include +#include +#include + +#include "Dram.h" +#include "Tile.h" +#include "SimulationConfig.h" +#include "TMA.h" + +class Core { + public: + Core(uint32_t id, SimulationConfig config); + ~Core()=default; + virtual bool running(); + virtual bool can_issue(const std::shared_ptr& op); + virtual void issue(std::shared_ptr tile); + virtual std::shared_ptr pop_finished_tile(); + virtual void cycle(); + virtual void print_stats(); + virtual void print_current_stats(); + virtual void finish_instruction(std::shared_ptr& inst); + virtual bool has_memory_request(); + virtual void pop_memory_request(); + virtual mem_fetch* top_memory_request() { return _request_queue.front(); } + virtual void push_memory_response(mem_fetch* response); + void check_tag() { _tma.check_table(); } + void inc_numa_hit() { _stat_numa_hit++; } + void inc_numa_miss() { _stat_numa_miss++; } + + std::queue>& get_compute_pipeline(int compute_type); + enum { + VECTOR_UNIT, + MATMUL, + PRELOAD, + NR_COMPUTE_UNIT + }; + + protected: + void dma_cycle(); + void compute_cycle(); + void vu_cycle(); + void sa_cycle(); + bool can_issue_compute(std::shared_ptr& inst); + void update_stats(); + + /* Core id & config file */ + const uint32_t _id; + const SimulationConfig _config; + size_t _sram_size; + size_t _used_sram_size; + uint32_t _num_systolic_array_per_core; + uint32_t _systolic_array_rr = 0; + + /* TMA Unit */ + TMA _tma; + + /* cycle */ + cycle_type _core_cycle; + cycle_type _stat_tot_vu_compute_cycle = 0; + std::vector _stat_tot_sa_compute_cycle; + cycle_type _stat_tot_tma_cycle = 0; + cycle_type _stat_tot_tma_idle_cycle = 0; + cycle_type _stat_tot_vu_compute_idle_cycle = 0; + std::vector _stat_tot_sa_compute_idle_cycle; + std::vector _stat_inst_count; + std::vector _stat_tot_skipped_inst; + uint64_t _stat_tot_mem_response = 0; + uint64_t _stat_gemm_inst = 0; + uint64_t _stat_skip_dma = 0; + uint64_t _stat_numa_hit = 0; + uint64_t _stat_numa_miss = 0; + + cycle_type _stat_vu_compute_cycle = 0; + std::vector _stat_sa_compute_cycle; + cycle_type _stat_tma_cycle = 0; + cycle_type _stat_tma_idle_cycle = 0; + cycle_type _stat_vu_compute_idle_cycle = 0; + std::vector _stat_sa_compute_idle_cycle; + uint64_t _stat_mem_response = 0; + + std::vector> _tiles; + std::queue> _finished_tiles; + + std::queue> _vu_compute_pipeline; + std::vector>> _sa_compute_pipeline; + std::queue> _ld_inst_queue; + std::queue> _st_inst_queue; + + std::unordered_map> _dma_waiting_queue; + std::vector> _dma_finished_queue; + /* Interconnect queue */ + std::queue _request_queue; + std::queue _response_queue; + uint32_t _waiting_write_reqs; +}; \ No newline at end of file diff --git a/PyTorchSimBackend/include/DelayQueue.h b/PyTorchSimBackend/include/DelayQueue.h new file mode 100644 index 00000000..67d08a23 --- /dev/null +++ b/PyTorchSimBackend/include/DelayQueue.h @@ -0,0 +1,45 @@ +#ifndef PairDelayQueue_H +#define PairDelayQueue_H +#include +#include +#include +#include + +template +class DelayQueue { + public: + DelayQueue() {} + DelayQueue(std::string name, bool only_latency, int max_size) + : m_only_latency(only_latency), + m_name(name), + m_interval(0), + m_cycle(0), + m_max_size(max_size), + m_issued(false), + m_size(0) {} + DelayQueue(std::string name) : DelayQueue(name, false, -1) {} + void push(T data, int delay); + void push(T data, int delay, int interval); + void pop(); + T top(); + int size() { return m_size; } + bool arrived(); + bool queue_empty(); + bool full(); + void cycle(); + + private: + struct QueueEntry { + T data; + uint64_t finish_cycle = 0; + }; + std::string m_name; + int m_interval; + uint64_t m_cycle; + int m_size; + int m_max_size; + bool m_issued; + bool m_only_latency; + std::queue m_queue; +}; +#endif \ No newline at end of file diff --git a/PyTorchSimBackend/include/Dram.h b/PyTorchSimBackend/include/Dram.h new file mode 100644 index 00000000..5e51b96d --- /dev/null +++ b/PyTorchSimBackend/include/Dram.h @@ -0,0 +1,91 @@ +#ifndef DRAM_H +#define DRAM_H +#include +#include +#include +#include + +#include "Common.h" +#include "TMA.h" +#include "ramulator2.hh" +#include "Hashing.h" +#include "Cache.h" +#include "DelayQueue.h" +#include "L2Cache.h" + +class Dram { + public: + Dram(SimulationConfig config, cycle_type* core_cycle); + virtual ~Dram() = default; + virtual bool running() = 0; + virtual void cycle() = 0; + virtual void cache_cycle() = 0; + virtual bool is_full(uint32_t cid, mem_fetch* request) = 0; + virtual void push(uint32_t cid, mem_fetch* request) = 0; + virtual bool is_empty(uint32_t cid) = 0; + virtual mem_fetch* top(uint32_t cid) = 0; + virtual void pop(uint32_t cid) = 0; + uint32_t get_channel_id(mem_fetch* request); + virtual void print_stat() {} + virtual void print_cache_stats() {}; + uint32_t get_channels_per_partition() { return _n_ch_per_partition; } + protected: + SimulationConfig _config; + CacheConfig _m_cache_config; + uint32_t _n_ch; + uint32_t _n_bl; + uint32_t _n_partitions; + uint32_t _n_ch_per_partition; + uint32_t _req_size; + cycle_type _cycles; + cycle_type* _core_cycles; + std::vector> m_cache_latency_queue; + std::vector> m_from_crossbar_queue; + std::vector> m_to_crossbar_queue; + std::vector> m_to_mem_queue; + std::vector _m_caches; +}; + +class DramRamulator2 : public Dram { + public: + DramRamulator2(SimulationConfig config, cycle_type *core_cycle); + + virtual bool running() override; + virtual void cycle() override; + virtual void cache_cycle() override; + virtual bool is_full(uint32_t cid, mem_fetch* request) override; + virtual void push(uint32_t cid, mem_fetch* request) override; + virtual bool is_empty(uint32_t cid) override; + virtual mem_fetch* top(uint32_t cid) override; + virtual void pop(uint32_t cid) override; + virtual void print_stat() override; + void print_cache_stats() override; + + private: + std::vector> _mem; + int _tx_ch_log2; + int _tx_log2; +}; + +class SimpleDRAM: public Dram { + public: + SimpleDRAM(SimulationConfig config, cycle_type *core_cycle); + + virtual bool running() override; + virtual void cycle() override; + virtual void cache_cycle() override; + virtual bool is_full(uint32_t cid, mem_fetch* request) override; + virtual void push(uint32_t cid, mem_fetch* request) override; + virtual bool is_empty(uint32_t cid) override; + virtual mem_fetch* top(uint32_t cid) override; + virtual void pop(uint32_t cid) override; + virtual void print_stat() override; + void print_cache_stats() override; + private: + int _latency = 1; + int _tx_ch_log2; + int _tx_log2; + std::vector>> _mem; +}; + +#endif \ No newline at end of file diff --git a/PyTorchSimBackend/include/Hashing.h b/PyTorchSimBackend/include/Hashing.h new file mode 100644 index 00000000..da03de04 --- /dev/null +++ b/PyTorchSimBackend/include/Hashing.h @@ -0,0 +1,24 @@ +// author: Mahmoud Khairy, (Purdue Univ) +// email: abdallm@purdue.edu + +#include +#include +#include + +#ifndef HASHING_H +#define HASHING_H + +typedef unsigned long long new_addr_type; + +unsigned ipoly_hash_function(new_addr_type higher_bits, unsigned index, + unsigned bank_set_num); + +unsigned bitwise_hash_function(new_addr_type higher_bits, unsigned index, + unsigned bank_set_num); + +unsigned PAE_hash_function(new_addr_type higher_bits, unsigned index, + unsigned bank_set_num); + +unsigned mini_hash_function(new_addr_type higher_bits, unsigned index, + unsigned bank_set_num); +#endif \ No newline at end of file diff --git a/PyTorchSimBackend/src/Instruction.h b/PyTorchSimBackend/include/Instruction.h similarity index 56% rename from PyTorchSimBackend/src/Instruction.h rename to PyTorchSimBackend/include/Instruction.h index ad469b9c..4c14dd81 100644 --- a/PyTorchSimBackend/src/Instruction.h +++ b/PyTorchSimBackend/include/Instruction.h @@ -1,8 +1,9 @@ #pragma once - +#include #include #include #include +#include #include #include @@ -11,18 +12,20 @@ #include #include -enum class Opcode { MOVIN, MOVOUT, COMP, BAR}; +enum class Opcode { MOVIN, MOVOUT, COMP, BAR, COUNT}; typedef uint64_t addr_type; typedef uint64_t cycle_type; std::string opcode_to_string(Opcode opcode); -class Instruction { +class Instruction : public std::enable_shared_from_this { public: Instruction(Opcode opcode, cycle_type compute_cycle, size_t num_parents, addr_type dram_addr, - std::vector tile_size, size_t precision, std::vector &idx_list, - std::vector &stride_list, std::vector tag_idx_list, std::vector loop_size_list); + std::vector tile_size, std::vector tile_stride, size_t precision, + std::vector tag_idx_list, std::vector tag_stride_list, + std::vector accum_tag_idx_list); + Instruction(Opcode opcode); void finish_instruction(); void add_child(std::shared_ptr child); bool check_ready() { return ready_counter == 0; } @@ -30,11 +33,16 @@ class Instruction { bool is_dma_read() { return opcode == Opcode::MOVIN; } bool is_dma_write() { return opcode == Opcode::MOVOUT; } bool is_async_dma() { return _is_async_dma; } + bool is_indirect_mode() { return _is_indirect_mode; } + std::string get_indirect_index_path() { return _indirect_index_path; } bool is_ready() { return ready_counter == 0; } void inc_ready_counter() { ready_counter++; } void dec_ready_counter() { assert(ready_counter!=0); ready_counter--; + if (!ready_counter && _owner_ready_queue_ref != nullptr) { + _owner_ready_queue_ref->push_back(shared_from_this()); + } } size_t get_tile_numel() { return _tile_numel; } size_t get_precision() { return _precision; } @@ -46,36 +54,35 @@ class Instruction { cycle_type get_overlapping_cycle() { return overlapping_cycle; } cycle_type get_compute_cycle() { return compute_cycle; } void set_compute_cycle(cycle_type cycle) { compute_cycle = cycle; } + void set_indirect_index_path(std::string indirect_path) { _is_indirect_mode=true; _indirect_index_path=indirect_path; } void print(); - std::set get_dram_address(addr_type dram_req_size) { - std::set address_set; - for (int row=0; row> get_dram_address(addr_type dram_req_size); + std::vector get_trace_address() { return _trace_address; } + bool load_indirect_index(const std::string& path, uint64_t*& indirect_index, const std::vector& tile_size); + void set_trace_address(std::vector& trace_address) { _trace_address = trace_address; } size_t get_free_sram_size() { return _free_sram_size; } - void adjust_dram_address() { - int offset = std::inner_product(_idx_list.begin(), _idx_list.end(), _stride_list.begin(), 0); - dram_addr += offset * _precision; - } + addr_type get_base_dram_address() { return dram_addr; } void set_free_sram_size(size_t sram_size) { _free_sram_size=sram_size; } void* get_owner() { return _owner; } void set_owner(void *owner) { _owner = owner;} + void set_owner_ready_queue(std::list>* q) { _owner_ready_queue_ref = q; } void set_compute_type(int type) { _compute_type = type; } int get_compute_type() { return _compute_type; } void set_numa_id(int numa_id) { _numa_id = numa_id; } uint32_t get_numa_id() { return _numa_id; } - std::vector& get_idx_list() { return _idx_list; } std::vector& get_tag_idx_list() { return _tag_idx_list; } - void set_addr_name(std::string name) { _addr_name = name; } + std::vector& get_tag_stride_list() { return _tag_stride_list; } + std::vector& get_tag_id() { return _tag_key; } + void set_addr_name(std::string name, int id) { _addr_name = name; _addr_id = id; } std::string get_addr_name() { return _addr_name; } + int get_addr_id() { return _addr_id; } void set_nr_inner_loop(int nr) { _nr_inner_loop = nr; } int get_nr_inner_loop() { return _nr_inner_loop; } void set_is_async(bool is_async) { _is_async_dma = is_async; } + void prepare_tag_key(); + bool is_sparse_inst() { return _is_sparse_inst; } + void set_sparse_state(bool state) { _is_sparse_inst = state; } + std::set>& get_child_inst() { return child_inst; } cycle_type start_cycle; cycle_type finish_cycle; @@ -84,13 +91,15 @@ class Instruction { bool finished=false; int subgraph_id; private: - void *_owner; + void *_owner = nullptr; + std::list>* _owner_ready_queue_ref = nullptr; Opcode opcode; cycle_type compute_cycle; cycle_type overlapping_cycle; size_t ready_counter; std::set> child_inst; std::vector tile_size; + std::vector tile_stride; size_t _tile_numel; size_t _nr_waiting_request=0; size_t _precision=0; @@ -98,11 +107,16 @@ class Instruction { addr_type dram_addr; uint32_t _numa_id = 0; // For DMA instruction int _compute_type = 0; - std::vector _idx_list; - std::vector _stride_list; std::vector _tag_idx_list; - std::vector _loop_size_list; + std::vector _tag_stride_list; + std::vector _tag_key; + std::vector _accum_tag_idx_list; + std::vector _trace_address; std::string _addr_name; + int _addr_id; int _nr_inner_loop = 0; bool _is_async_dma=false; + bool _is_indirect_mode=false; + bool _is_sparse_inst=false; + std::string _indirect_index_path=""; }; \ No newline at end of file diff --git a/PyTorchSimBackend/src/Interconnect.h b/PyTorchSimBackend/include/Interconnect.h similarity index 68% rename from PyTorchSimBackend/src/Interconnect.h rename to PyTorchSimBackend/include/Interconnect.h index a47b8c6a..8467b7aa 100644 --- a/PyTorchSimBackend/src/Interconnect.h +++ b/PyTorchSimBackend/include/Interconnect.h @@ -12,10 +12,10 @@ class Interconnect { virtual ~Interconnect() = default; virtual bool running() = 0; virtual void cycle() = 0; - virtual void push(uint32_t src, uint32_t dest, MemoryAccess* request) = 0; - virtual bool is_full(uint32_t src, MemoryAccess* request) = 0; + virtual void push(uint32_t src, uint32_t dest, mem_fetch* request) = 0; + virtual bool is_full(uint32_t src, mem_fetch* request) = 0; virtual bool is_empty(uint32_t nid) = 0; - virtual MemoryAccess* top(uint32_t nid) = 0; + virtual mem_fetch* top(uint32_t nid) = 0; virtual void pop(uint32_t nid) = 0; virtual void print_stats() = 0; @@ -32,10 +32,10 @@ class SimpleInterconnect : public Interconnect { virtual bool running() override; virtual void cycle() override; virtual void push(uint32_t src, uint32_t dest, - MemoryAccess* request) override; - virtual bool is_full(uint32_t src, MemoryAccess* request) override; + mem_fetch* request) override; + virtual bool is_full(uint32_t src, mem_fetch* request) override; virtual bool is_empty(uint32_t nid) override; - virtual MemoryAccess* top(uint32_t nid) override; + virtual mem_fetch* top(uint32_t nid) override; virtual void pop(uint32_t nid) override; virtual void print_stats() override {} @@ -48,11 +48,11 @@ class SimpleInterconnect : public Interconnect { struct Entity { cycle_type finish_cycle; uint32_t dest; - MemoryAccess* access; + mem_fetch* access; }; std::vector> _in_buffers; - std::vector> _out_buffers; + std::vector> _out_buffers; std::vector _busy_node; }; @@ -62,10 +62,10 @@ class Booksim2Interconnect : public Interconnect { virtual bool running() override; virtual void cycle() override; virtual void push(uint32_t src, uint32_t dest, - MemoryAccess* request) override; - virtual bool is_full(uint32_t src, MemoryAccess* request) override; + mem_fetch* request) override; + virtual bool is_full(uint32_t src, mem_fetch* request) override; virtual bool is_empty(uint32_t nid) override; - virtual MemoryAccess* top(uint32_t nid) override; + virtual mem_fetch* top(uint32_t nid) override; virtual void pop(uint32_t nid) override; virtual void print_stats() override; @@ -74,7 +74,7 @@ class Booksim2Interconnect : public Interconnect { std::string _config_path; std::unique_ptr _booksim; - booksim2::Interconnect::Type get_booksim_type(MemoryAccess* access); - uint32_t get_packet_size(MemoryAccess* access); + booksim2::Interconnect::Type get_booksim_type(mem_fetch* access); + uint32_t get_packet_size(mem_fetch* access); }; #endif \ No newline at end of file diff --git a/PyTorchSimBackend/include/IntervalTree.h b/PyTorchSimBackend/include/IntervalTree.h new file mode 100644 index 00000000..ddc2b915 --- /dev/null +++ b/PyTorchSimBackend/include/IntervalTree.h @@ -0,0 +1,344 @@ +#ifndef __INTERVAL_TREE_H +#define __INTERVAL_TREE_H + +#include +#include +#include +#include +#include +#include + +#ifdef USE_INTERVAL_TREE_NAMESPACE +namespace interval_tree { +#endif +template +class Interval { +public: + Scalar start; + Scalar stop; + Value value; + Interval(const Scalar& s, const Scalar& e, const Value& v) + : start(std::min(s, e)) + , stop(std::max(s, e)) + , value(v) + {} +}; + +template +Value intervalStart(const Interval& i) { + return i.start; +} + +template +Value intervalStop(const Interval& i) { + return i.stop; +} + +template +std::ostream& operator<<(std::ostream& out, const Interval& i) { + out << "Interval(" << i.start << ", " << i.stop << "): " << i.value; + return out; +} + +template +class IntervalTree { +public: + typedef Interval interval; + typedef std::vector interval_vector; + + + struct IntervalStartCmp { + bool operator()(const interval& a, const interval& b) { + return a.start < b.start; + } + }; + + struct IntervalStopCmp { + bool operator()(const interval& a, const interval& b) { + return a.stop < b.stop; + } + }; + + IntervalTree() + : left(nullptr) + , right(nullptr) + , center(0) + {} + + ~IntervalTree() = default; + + std::unique_ptr clone() const { + return std::unique_ptr(new IntervalTree(*this)); + } + + IntervalTree(const IntervalTree& other) + : intervals(other.intervals), + left(other.left ? other.left->clone() : nullptr), + right(other.right ? other.right->clone() : nullptr), + center(other.center) + {} + + IntervalTree& operator=(IntervalTree&&) = default; + IntervalTree(IntervalTree&&) = default; + + IntervalTree& operator=(const IntervalTree& other) { + center = other.center; + intervals = other.intervals; + left = other.left ? other.left->clone() : nullptr; + right = other.right ? other.right->clone() : nullptr; + return *this; + } + + IntervalTree( + interval_vector&& ivals, + std::size_t depth = 16, + std::size_t minbucket = 64, + std::size_t maxbucket = 512, + Scalar leftextent = 0, + Scalar rightextent = 0) + : left(nullptr) + , right(nullptr) + { + --depth; + const auto minmaxStop = std::minmax_element(ivals.begin(), ivals.end(), + IntervalStopCmp()); + const auto minmaxStart = std::minmax_element(ivals.begin(), ivals.end(), + IntervalStartCmp()); + if (!ivals.empty()) { + center = (minmaxStart.first->start + minmaxStop.second->stop) / 2; + } + if (leftextent == 0 && rightextent == 0) { + // sort intervals by start + std::sort(ivals.begin(), ivals.end(), IntervalStartCmp()); + } else { + assert(std::is_sorted(ivals.begin(), ivals.end(), IntervalStartCmp())); + } + if (depth == 0 || (ivals.size() < minbucket && ivals.size() < maxbucket)) { + std::sort(ivals.begin(), ivals.end(), IntervalStartCmp()); + intervals = std::move(ivals); + assert(is_valid().first); + return; + } else { + Scalar leftp = 0; + Scalar rightp = 0; + + if (leftextent || rightextent) { + leftp = leftextent; + rightp = rightextent; + } else { + leftp = ivals.front().start; + rightp = std::max_element(ivals.begin(), ivals.end(), + IntervalStopCmp())->stop; + } + + interval_vector lefts; + interval_vector rights; + + for (typename interval_vector::const_iterator i = ivals.begin(); + i != ivals.end(); ++i) { + const interval& interval = *i; + if (interval.stop < center) { + lefts.push_back(interval); + } else if (interval.start > center) { + rights.push_back(interval); + } else { + assert(interval.start <= center); + assert(center <= interval.stop); + intervals.push_back(interval); + } + } + + if (!lefts.empty()) { + left.reset(new IntervalTree(std::move(lefts), + depth, minbucket, maxbucket, + leftp, center)); + } + if (!rights.empty()) { + right.reset(new IntervalTree(std::move(rights), + depth, minbucket, maxbucket, + center, rightp)); + } + } + assert(is_valid().first); + } + + // Call f on all intervals near the range [start, stop]: + template + void visit_near(const Scalar& start, const Scalar& stop, UnaryFunction f) const { + if (!intervals.empty() && ! (stop < intervals.front().start)) { + for (auto & i : intervals) { + f(i); + } + } + if (left && start <= center) { + left->visit_near(start, stop, f); + } + if (right && stop >= center) { + right->visit_near(start, stop, f); + } + } + + // Call f on all intervals crossing pos + template + void visit_overlapping(const Scalar& pos, UnaryFunction f) const { + visit_overlapping(pos, pos, f); + } + + // Call f on all intervals overlapping [start, stop] + template + void visit_overlapping(const Scalar& start, const Scalar& stop, UnaryFunction f) const { + auto filterF = [&](const interval& interval) { + if (interval.stop >= start && interval.start <= stop) { + // Only apply f if overlapping + f(interval); + } + }; + visit_near(start, stop, filterF); + } + + // Call f on all intervals contained within [start, stop] + template + void visit_contained(const Scalar& start, const Scalar& stop, UnaryFunction f) const { + auto filterF = [&](const interval& interval) { + if (start <= interval.start && interval.stop <= stop) { + f(interval); + } + }; + visit_near(start, stop, filterF); + } + + interval_vector findOverlapping(const Scalar& start, const Scalar& stop) const { + interval_vector result; + visit_overlapping(start, stop, + [&](const interval& interval) { + result.emplace_back(interval); + }); + return result; + } + + interval_vector findContained(const Scalar& start, const Scalar& stop) const { + interval_vector result; + visit_contained(start, stop, + [&](const interval& interval) { + result.push_back(interval); + }); + return result; + } + bool empty() const { + if (left && !left->empty()) { + return false; + } + if (!intervals.empty()) { + return false; + } + if (right && !right->empty()) { + return false; + } + return true; + } + + template + void visit_all(UnaryFunction f) const { + if (left) { + left->visit_all(f); + } + std::for_each(intervals.begin(), intervals.end(), f); + if (right) { + right->visit_all(f); + } + } + + std::pair extentBruitForce() const { + struct Extent { + std::pair x = {std::numeric_limits::max(), + std::numeric_limits::min() }; + void operator()(const interval & interval) { + x.first = std::min(x.first, interval.start); + x.second = std::max(x.second, interval.stop); + } + }; + Extent extent; + + visit_all([&](const interval & interval) { extent(interval); }); + return extent.x; + } + + // Check all constraints. + // If first is false, second is invalid. + std::pair> is_valid() const { + const auto minmaxStop = std::minmax_element(intervals.begin(), intervals.end(), + IntervalStopCmp()); + const auto minmaxStart = std::minmax_element(intervals.begin(), intervals.end(), + IntervalStartCmp()); + + std::pair> result = {true, { std::numeric_limits::max(), + std::numeric_limits::min() }}; + if (!intervals.empty()) { + result.second.first = std::min(result.second.first, minmaxStart.first->start); + result.second.second = std::min(result.second.second, minmaxStop.second->stop); + } + if (left) { + auto valid = left->is_valid(); + result.first &= valid.first; + result.second.first = std::min(result.second.first, valid.second.first); + result.second.second = std::min(result.second.second, valid.second.second); + if (!result.first) { return result; } + if (valid.second.second >= center) { + result.first = false; + return result; + } + } + if (right) { + auto valid = right->is_valid(); + result.first &= valid.first; + result.second.first = std::min(result.second.first, valid.second.first); + result.second.second = std::min(result.second.second, valid.second.second); + if (!result.first) { return result; } + if (valid.second.first <= center) { + result.first = false; + return result; + } + } + if (!std::is_sorted(intervals.begin(), intervals.end(), IntervalStartCmp())) { + result.first = false; + } + return result; + } + + friend std::ostream& operator<<(std::ostream& os, const IntervalTree& itree) { + return writeOut(os, itree); + } + + friend std::ostream& writeOut(std::ostream& os, const IntervalTree& itree, + std::size_t depth = 0) { + auto pad = [&]() { for (std::size_t i = 0; i != depth; ++i) { os << ' '; } }; + pad(); os << "center: " << itree.center << '\n'; + for (const interval & inter : itree.intervals) { + pad(); os << inter << '\n'; + } + if (itree.left) { + pad(); os << "left:\n"; + writeOut(os, *itree.left, depth + 1); + } else { + pad(); os << "left: nullptr\n"; + } + if (itree.right) { + pad(); os << "right:\n"; + writeOut(os, *itree.right, depth + 1); + } else { + pad(); os << "right: nullptr\n"; + } + return os; + } + +private: + interval_vector intervals; + std::unique_ptr left; + std::unique_ptr right; + Scalar center; +}; +#ifdef USE_INTERVAL_TREE_NAMESPACE +} +#endif + +#endif \ No newline at end of file diff --git a/PyTorchSimBackend/include/L2Cache.h b/PyTorchSimBackend/include/L2Cache.h new file mode 100644 index 00000000..e822e6be --- /dev/null +++ b/PyTorchSimBackend/include/L2Cache.h @@ -0,0 +1,55 @@ +#include +#include +#include "Memfetch.h" +#include "Cache.h" +#include "Instruction.h" +#include "IntervalTree.h" + +class L2CacheBase { +public: + L2CacheBase(std::string name, CacheConfig &cache_config, uint32_t id, cycle_type *core_cycle, + uint32_t l2d_hit_latency, std::queue *to_xbar_queue, + std::queue *from_xbar_queue) : + l_name(name), l_cache_config(cache_config), l_id(id), l_core_cycle(core_cycle), + l2d_hit_latency(l2d_hit_latency), + l_to_xbar_queue(to_xbar_queue), l_from_xbar_queue(from_xbar_queue) {} + virtual void cycle()=0; + // Push memory response from DRAM + virtual bool push(mem_fetch* req)=0; + // Pop memory request from Cache + void pop() { l_to_mem_queue.pop(); } + mem_fetch* top() { return l_to_mem_queue.empty() ? NULL : l_to_mem_queue.front(); } + virtual void print_stats() {}; + +protected: + cycle_type *l_core_cycle; // Core cycle + std::string l_name; // L2 name + CacheConfig l_cache_config; // L2 cache config + uint32_t l_id; // L2 partition id + uint32_t l2d_hit_latency; + std::queue *l_to_xbar_queue; + std::queue *l_from_xbar_queue; + std::queue l_to_mem_queue; + DelayQueue l_from_cache_queue; + std::unique_ptr l_cache; +}; + +class NoL2Cache : public L2CacheBase { +public: + NoL2Cache(std::string name, CacheConfig &cache_config, uint32_t id, cycle_type *core_cycle, + std::queue *to_xbar_queue, std::queue *from_xbar_queue) : + L2CacheBase(name, cache_config, id, core_cycle, 0, to_xbar_queue, from_xbar_queue) {} + void cycle() override; + bool push(mem_fetch* req) override; // Push memory response from DRAM +}; + +class L2DataCache : public L2CacheBase { +public: + typedef IntervalTree CachePlan; + L2DataCache(std::string name, CacheConfig &cache_config, uint32_t id, cycle_type *core_cycle, + uint32_t l2d_hit_latency, std::queue *to_xbar_queue, + std::queue *from_xbar_queue); + void cycle() override; + bool push(mem_fetch* req) override; // Push memory response from DRAM + virtual void print_stats() override; +}; \ No newline at end of file diff --git a/PyTorchSimBackend/include/Memfetch.h b/PyTorchSimBackend/include/Memfetch.h new file mode 100644 index 00000000..8934d5c7 --- /dev/null +++ b/PyTorchSimBackend/include/Memfetch.h @@ -0,0 +1,100 @@ +#ifndef MEM_FETCH_H +#define MEM_FETCH_H +#include +#include "Cache_defs.h" + +typedef unsigned long long new_addr_type; + +enum mem_access_type { + GLOBAL_ACC_R, + GLOBAL_ACC_W, + L2_CACHE_WA, /* Data L2 cache write alloc */ + L2_CACHE_WB, /* Data L2 cache write back */ + NUM_MEM_ACCESS_TYPE +}; + +static const char* mem_access_type_str[] = { + "GLOBAL_ACC_R", "GLOBAL_ACC_W", + "L2_CACHE_WA", "L2_CACHE_WB"}; +enum mf_type { READ_REQUEST = 0, WRITE_REQUEST, READ_REPLY, WRITE_ACK }; + +static unsigned long long unique_uid = 0; + +class mem_fetch { + public: + mem_fetch(new_addr_type addr, mem_access_type acc_type, mf_type type, + unsigned data_size, unsigned numa_id=-1, + void* custom_data=NULL) : + m_addr(addr), m_mem_access_type(acc_type), + m_type(type), m_data_size(data_size), + m_numa_id(numa_id), m_custom_data(custom_data) { + m_request_id = unique_uid++; + } + mem_fetch(std::deque mfs); // for wrapping multiple mfs into one + /* Src & Des */ + void set_core_id(int core_id) {m_core_id = core_id;} + int get_core_id() { return m_core_id; } + void set_channel(unsigned channel) { m_channel = channel; } + unsigned get_channel() { return m_channel; } + void set_numa_id(unsigned numa_id) { m_numa_id=numa_id; } + unsigned get_numa_id() { return m_numa_id; } + /* Data & size */ + void set_data(void* data) { m_data = data; } + void* get_data() { return m_data; } + void set_data_size(unsigned size) { m_data_size = size; } + unsigned get_data_size() { return m_data_size; } + new_addr_type get_addr() { return m_addr; } + void set_addr(new_addr_type addr) { m_addr = addr; } + /* Mem info */ + mem_access_type get_access_type() { return m_mem_access_type; } + mf_type get_type() { return m_type; } + void set_type(mf_type type) { m_type = type; } + bool is_write() { return m_type == mf_type::WRITE_REQUEST || m_type == mf_type::WRITE_ACK; } + void set_request_id(unsigned id) { m_request_id = id; } + unsigned get_request_id() { return m_request_id; } + void set_access_sector_mask(uint32_t line_size, uint32_t sector_size) { m_sector_mask.set((m_addr % line_size) / sector_size); } + void set_access_sector_mask(SectorMask mask) { m_sector_mask = mask; } + SectorMask get_access_sector_mask() { return m_sector_mask; } + void set_dirty_mask(SectorMask dirty_mask) { m_dirty_mask = dirty_mask; } + SectorMask get_dirty_mask() { return m_dirty_mask; } + mem_fetch* get_original_mf() { return m_original_mf; } + bool is_atomic() { return false; } + bool is_request() { return m_type == mf_type::READ_REQUEST || m_type == mf_type::WRITE_REQUEST; } + void set_cacheable(bool cacheable) { m_cacheable = cacheable; } + bool is_cacheable() { return m_cacheable; } + void set_reply() { + if (m_type == mf_type::READ_REQUEST) + m_type = mf_type::READ_REPLY; + else if(m_type == mf_type::WRITE_REQUEST) + m_type = mf_type::WRITE_ACK; + else + spdlog::error("Unexpected mf_type in the set_reply"); + } + void set_custom_data(void* custom_data) { m_custom_data = custom_data; } + void* get_custom_data() { return m_custom_data; } + /* Stat */ + void set_start_cycle(uint64_t start_cycle) { m_start_cycle = start_cycle; } + uint64_t get_start_cycle() { return m_start_cycle; } + + std::string current_state = "NONE"; + uint64_t request_cycle; + uint64_t response_cycle; + private: + uint64_t m_request_id; + unsigned m_data_size; + new_addr_type m_addr; + void* m_data = NULL; + mem_access_type m_mem_access_type; + mf_type m_type; + unsigned m_core_id; + unsigned m_channel; + unsigned m_numa_id; + SectorMask m_sector_mask; + SectorMask m_dirty_mask; + mem_fetch* m_original_mf; + void* m_custom_data = NULL; + uint64_t m_start_cycle = 0ULL; + bool m_cacheable = true; +}; + +#endif \ No newline at end of file diff --git a/PyTorchSimBackend/src/Model.h b/PyTorchSimBackend/include/Model.h similarity index 100% rename from PyTorchSimBackend/src/Model.h rename to PyTorchSimBackend/include/Model.h diff --git a/PyTorchSimBackend/src/SimulationConfig.h b/PyTorchSimBackend/include/SimulationConfig.h similarity index 68% rename from PyTorchSimBackend/src/SimulationConfig.h rename to PyTorchSimBackend/include/SimulationConfig.h index 29296cf7..8f011d00 100644 --- a/PyTorchSimBackend/src/SimulationConfig.h +++ b/PyTorchSimBackend/include/SimulationConfig.h @@ -5,16 +5,25 @@ using json = nlohmann::json; +enum class CoreType { WS_MESH, STONNE }; + enum class DramType { SIMPLE, RAMULATOR1, RAMULATOR2 }; enum class IcntType { SIMPLE, BOOKSIM2 }; +enum class L2CacheType { NOCACHE, DATACACHE }; + struct SimulationConfig { /* Core config */ + std::vector core_type; + std::string stonne_config_path; uint32_t num_cores; uint32_t core_freq; uint32_t sram_size; uint32_t core_print_interval = 0; + uint32_t num_systolic_array_per_core = 1; + uint32_t num_stonne_per_core = 1; + uint32_t num_stonne_port = 1; /* DRAM config */ DramType dram_type; @@ -23,9 +32,15 @@ struct SimulationConfig { uint32_t dram_channels; uint32_t dram_req_size; uint32_t dram_latency; + uint32_t dram_nbl = 1; uint32_t dram_print_interval; std::string dram_config_path; + /* L2 Cache config */ + L2CacheType l2d_type = L2CacheType::NOCACHE; + std::string l2d_config_str; + uint32_t l2d_hit_latency = 1; + /* ICNT config */ IcntType icnt_type; uint32_t icnt_node_per_core = 1; @@ -50,6 +65,6 @@ struct SimulationConfig { } float max_dram_bandwidth() { - return dram_freq * dram_channels * dram_req_size / 1000; // GB/s + return dram_freq * dram_channels * dram_req_size * 2 / dram_nbl / 1000; // GB/s } }; \ No newline at end of file diff --git a/PyTorchSimBackend/src/Simulator.h b/PyTorchSimBackend/include/Simulator.h similarity index 95% rename from PyTorchSimBackend/src/Simulator.h rename to PyTorchSimBackend/include/Simulator.h index 907808a6..4d9defd1 100644 --- a/PyTorchSimBackend/src/Simulator.h +++ b/PyTorchSimBackend/include/Simulator.h @@ -5,6 +5,7 @@ #include #include "Common.h" #include "Core.h" +#include "SparseCore.h" #include "Dram.h" #include "Interconnect.h" #include "scheduler/Scheduler.h" @@ -31,16 +32,17 @@ class Simulator { int get_partition_id(int core_id) { return _config.partiton_map[core_id]; } std::unique_ptr& get_partition_scheduler(int core_id) { return _partition_scheduler.at(get_partition_id(core_id)); } void print_core_stat(); - private: void cycle(); + private: void core_cycle(); void dram_cycle(); void icnt_cycle(); bool running(); void set_cycle_mask(); - uint32_t get_dest_node(MemoryAccess *access); + uint32_t get_dest_node(mem_fetch *access); SimulationConfig _config; uint32_t _n_cores; + uint32_t _n_sp_cores; uint32_t _noc_node_per_core; uint32_t _n_memories; uint32_t _memory_req_size; diff --git a/PyTorchSimBackend/include/SparseCore.h b/PyTorchSimBackend/include/SparseCore.h new file mode 100644 index 00000000..9188b21d --- /dev/null +++ b/PyTorchSimBackend/include/SparseCore.h @@ -0,0 +1,91 @@ +#include +#include +#include "Core.h" +#include "sstStonne.h" +#include "SimpleMem.h" +#include "Config.h" + +class TraceNode { +private: + int node_id; + int node_type; + std::string node_name; + std::set address_set; + int compute_cycle; + +public: + enum TraceType {StonneTraceCompute=6, StonneTraceLoad=7, StonneTraceStore=8}; + TraceNode(int id, std::string name, int type, int cycle = 0) + : node_id(id), node_name(name), node_type(type), compute_cycle(cycle) {} + void setAddress(std::set addr_set) { address_set = addr_set; } + friend std::ostream& operator<<(std::ostream& os, const TraceNode& node) { + os << " " << node.node_id << ": {\n" + << " \"node_id\": " << node.node_id << ",\n" + << " \"node_name\": " << std::quoted(node.node_name) << ",\n" + << " \"node_type\": " << node.node_type << ",\n" + << " \"parents\": [0],\n" + << " \"trace_address\": ["; + + bool first = true; + for (uint64_t addr : node.address_set) { + if (!first) + os << ", "; + os << addr; + first = false; + } + + os << "],\n" + << " \"trace_compute_cycle\": " << node.compute_cycle << "\n" + << " }"; + return os; + } +}; + +class SparseCore : public Core { +public: + SparseCore(uint32_t id, SimulationConfig config); + ~SparseCore(); + bool running() override; + bool can_issue(const std::shared_ptr& op) override; + void issue(std::shared_ptr tile) override; + void cycle() override; + void subCoreCycle(uint32_t subcore_id); + void stonneCycle(SST_STONNE::sstStonne *&stonneCore, uint32_t stonne_core_id, bool &retFlag); + bool has_memory_request(); + void pop_memory_request(); + mem_fetch* top_memory_request() { return _request_queue.front(); } + void push_memory_response(mem_fetch* response) override; + void print_stats() override; + void print_current_stats() override; + std::shared_ptr pop_finished_tile() override; + void finish_instruction(std::shared_ptr& inst) override; + void dumpTrace(int stonne_core_id, const std::string& path); + bool isTraceMode(int stonne_core_id) { return traceMode.at(stonne_core_id); } + void setTraceMode(int stonne_core_id, bool mode) { traceMode.at(stonne_core_id) = mode; } + void checkStatus(uint32_t subcore_id); + void registerMemfetch(const std::tuple& key, std::function callback); + int allocTrafficID() { int id = traffic_id; traffic_id++; return 0; } + uint32_t num_ms = 1; + uint32_t r_port_nr = 1; + uint32_t w_port_nr = 1; + uint32_t nr_cores = 1; +private: + uint32_t rr_idx = 0; + std::vector coreBusy; + std::vector traceCoreStatus; + std::vector traceCoreCycle; + std::vector traceMode; + std::vector> traceNodeList; + std::vector> traceLoadTraffic; // To trace dma traffic + std::vector> traceStoreTraffic; // To trace dma traffic + std::vector>> percore_tiles; + std::vector stonneCores; + /* Interconnect queue */ + std::queue _request_queue; + std::queue _response_queue; + std::map, mem_fetch*> request_merge_table; + std::vector percore_stat; + std::vector percore_total_stat; + int traffic_id=0; +}; + diff --git a/PyTorchSimBackend/src/TMA.h b/PyTorchSimBackend/include/TMA.h similarity index 51% rename from PyTorchSimBackend/src/TMA.h rename to PyTorchSimBackend/include/TMA.h index decd3c60..f8355470 100644 --- a/PyTorchSimBackend/src/TMA.h +++ b/PyTorchSimBackend/include/TMA.h @@ -8,20 +8,7 @@ #include "Instruction.h" #include "SimulationConfig.h" #include "Tile.h" - -typedef struct { - uint32_t id; - addr_type dram_address; - uint64_t size; - bool write; - bool request; - uint32_t core_id; - uint32_t numa_id=0; - Instruction* owner_instruction; - cycle_type start_cycle; - cycle_type dram_enter_cycle; - cycle_type dram_finish_cycle; -} MemoryAccess; +#include "Memfetch.h" struct VectorCompare { bool operator()(const std::vector& a, const std::vector& b) const { @@ -36,21 +23,61 @@ class TMA { void issue_tile(std::shared_ptr inst); bool is_finished() { return _finished; } bool empty() { return _current_inst==nullptr; } - void register_tag(int subgraph_id, const std::pair>& key) { + void register_tag(int subgraph_id, std::vector& key) { if (tag_table.find(subgraph_id) == tag_table.end()) { - tag_table[subgraph_id] = std::map>, bool>(); - waiters[subgraph_id] = std::map>, std::vector>>(); + tag_table[subgraph_id] = std::map, uint32_t>(); + waiters[subgraph_id] = std::map, std::vector>>(); } - tag_table[subgraph_id][key] = false; + tag_table[subgraph_id][key] = 0; waiters[subgraph_id][key] = std::vector>(); } - void set_tag_finish(int subgraph_id, const std::pair>& key) { + void set_tag_finish(int subgraph_id, std::vector& key) { + if (tag_table.find(subgraph_id) == tag_table.end()) { + throw std::runtime_error("Subgraph does not exist in tag_table"); + } + tag_table[subgraph_id][key] = 1; + } + + void set_tag_sparse(int subgraph_id, std::vector& key) { + if (tag_table.find(subgraph_id) == tag_table.end()) { + throw std::runtime_error("Subgraph does not exist in tag_table"); + } + tag_table[subgraph_id][key] = -1; + } + + void mark_tag_used(int subgraph_id, std::vector& key) { if (tag_table.find(subgraph_id) == tag_table.end()) { throw std::runtime_error("Subgraph does not exist in tag_table"); + } else if (!tag_table[subgraph_id][key]) { + throw std::runtime_error("Tag is not ready but freed"); } - tag_table[subgraph_id][key] = true; + tag_table[subgraph_id][key] += 1; + } + + void check_table() { + for (const auto& entry: tag_table) { + auto subgraph_id = entry.first; + for (const auto& tag_entry: tag_table[subgraph_id]) { + const std::vector& tag_key = tag_entry.first; + uint32_t value = tag_entry.second; + if (value == 1) { + spdlog::warn("[Tag Table][{}] Unused tag found: (key={}, val={})", + subgraph_id, fmt::format("[{}]", fmt::join(tag_key, ", ")), value); + } + } + } + } + + bool tag_key_exist(int subgraph_id, std::vector& key) { + auto subgraph_it = tag_table.find(subgraph_id); + if (subgraph_it == tag_table.end()) + return false; + + auto& key_map = subgraph_it->second; + auto key_it = key_map.find(key); + return key_it != key_map.end(); } - bool get_tag_finish(int subgraph_id, const std::pair>& key) { + uint32_t get_tag_finish(int subgraph_id, std::vector& key) { auto subgraph_it = tag_table.find(subgraph_id); auto& key_map = subgraph_it->second; auto key_it = key_map.find(key); @@ -67,7 +94,7 @@ class TMA { tag_table.erase(subgraph_id); waiters.erase(subgraph_id); } - void register_tag_waiter(int subgraph_id, const std::pair>& key, std::shared_ptr inst) { + void register_tag_waiter(int subgraph_id, std::vector& key, std::shared_ptr inst) { auto subgraph_it = tag_table.find(subgraph_id); auto& key_map = subgraph_it->second; auto key_it = key_map.find(key); @@ -76,7 +103,7 @@ class TMA { } waiters[subgraph_id][key].push_back(inst); } - std::vector>& get_tag_waiter(int subgraph_id, const std::pair>& key) { + std::vector>& get_tag_waiter(int subgraph_id, std::vector& key) { auto subgraph_it = tag_table.find(subgraph_id); auto& key_map = subgraph_it->second; auto key_it = key_map.find(key); @@ -87,11 +114,13 @@ class TMA { } std::shared_ptr& get_current_inst() { return _current_inst; } - std::vector get_memory_access(); + std::shared_ptr> get_memory_access(); uint32_t generate_mem_access_id(); + const uint32_t get_max_dim() { return _max_dim; } protected: uint32_t _id; + const uint32_t _max_dim = 4; std::shared_ptr _current_inst; uint32_t _dram_req_size; uint32_t _tile_size_x=0; @@ -99,7 +128,7 @@ class TMA { size_t _tile_idx_stride=1; uint32_t _tile_idx; bool _finished=true; - std::map>, bool>> tag_table; - std::map>, std::vector>>> waiters; + std::map, uint32_t>> tag_table; + std::map, std::vector>>> waiters; }; #endif \ No newline at end of file diff --git a/PyTorchSimBackend/src/Tile.h b/PyTorchSimBackend/include/Tile.h similarity index 65% rename from PyTorchSimBackend/src/Tile.h rename to PyTorchSimBackend/include/Tile.h index c62329c3..d867a037 100644 --- a/PyTorchSimBackend/src/Tile.h +++ b/PyTorchSimBackend/include/Tile.h @@ -3,11 +3,12 @@ #include #include +#include #include "Instruction.h" class TileSubGraph; -class Tile { +class Tile : public std::enable_shared_from_this { public: enum class Status { INITIALIZED, @@ -17,8 +18,8 @@ class Tile { }; Tile(Status status); - TileSubGraph* get_owner() { return _onwer_graph; } - void set_ownwer(TileSubGraph* graph) { _onwer_graph = graph; } + std::shared_ptr get_owner() { return _onwer_graph; } + void set_owner(std::shared_ptr graph) { _onwer_graph = graph; } Status get_status() { return _status; } void set_status(Status status) { _status=status; } size_t get_ready_counter() { return _ready_counter; } @@ -33,6 +34,8 @@ class Tile { void finish_tile(); bool is_ready() { return _ready_counter==0; } std::deque>& get_instructions() { return _instructions; } + void enqueue_ready(const std::shared_ptr& inst) { _ready_queue.push_back(inst); } + std::list>& get_ready_instructions() { return _ready_queue; } void print(); size_t nr_insts() { return _nr_insts; } size_t nr_finshed_insts() { return _nr_finished_insts; } @@ -40,16 +43,23 @@ class Tile { _nr_finished_insts++; }; bool all_insts_finshed() { return _nr_insts == _nr_finished_insts; } + void* get_custom_data() { return _custom_data; } + void set_custom_data(void* custom_data ) { _custom_data = custom_data; } + void set_stonne_tile(bool stonne_tile) { _stonne_tile = stonne_tile; } + bool is_stonne_tile() { return _stonne_tile; } protected: - TileSubGraph* _onwer_graph; + std::shared_ptr _onwer_graph; Status _status = Status::EMPTY; size_t _required_sram_size=0; size_t _ready_counter=0; size_t _nr_insts = 0; size_t _nr_finished_insts = 0; std::deque> _instructions; + std::list> _ready_queue; std::vector> _child_tiles; + void *_custom_data=NULL; + bool _stonne_tile=false; }; #endif \ No newline at end of file diff --git a/PyTorchSimBackend/src/TileGraph.h b/PyTorchSimBackend/include/TileGraph.h similarity index 87% rename from PyTorchSimBackend/src/TileGraph.h rename to PyTorchSimBackend/include/TileGraph.h index d17ee4b3..990c107d 100644 --- a/PyTorchSimBackend/src/TileGraph.h +++ b/PyTorchSimBackend/include/TileGraph.h @@ -5,6 +5,7 @@ #include #include #include "Tile.h" +#include "IntervalTree.h" class TileSubGraph { public: @@ -17,6 +18,8 @@ class TileSubGraph { int get_id() { return _id; } void set_core_id(int core_id) { _core_id = core_id; } int get_core_id() { return _core_id; } + void init_cache_plan(std::shared_ptr> plan) { _cache_plan = plan; } + bool is_cacheable(unsigned long long start, unsigned long long end) { return _cache_plan->findOverlapping(start, end).size() != 0; } struct CompareReadyTile { bool operator()(const std::shared_ptr& a, const std::shared_ptr& b) const { return a->get_required_sram_size() > b->get_required_sram_size(); @@ -29,6 +32,7 @@ class TileSubGraph { int _id; int _core_id = -1; static int _next_id; + std::shared_ptr> _cache_plan; }; class TileGraph { @@ -63,6 +67,10 @@ class TileGraph { std::string get_name() { return _name; } void set_arrival_time(cycle_type arrival_time) { _arrival_time = arrival_time; } cycle_type get_arrival_time() { return _arrival_time; } + void init_cache_plan(IntervalTree::interval_vector it) { + _cache_plan = std::make_shared>(std::move(it)); + } + bool StonneGraph = false; class Iterator { public: @@ -127,6 +135,7 @@ class TileGraph { std::vector> _subgraph_vec; std::vector> _finished_subgraph_vec; std::map>> _cpu_graph_map; + std::shared_ptr> _cache_plan; cycle_type _arrival_time; static std::shared_ptr null_tile; }; \ No newline at end of file diff --git a/PyTorchSimBackend/include/TileGraphParser.h b/PyTorchSimBackend/include/TileGraphParser.h new file mode 100644 index 00000000..5b561127 --- /dev/null +++ b/PyTorchSimBackend/include/TileGraphParser.h @@ -0,0 +1,401 @@ +#pragma once +#include +#include +#include +#include +#include +#include +#include "TileGraph.h" +#include "Instruction.h" +#include "sstStonne.h" +#include "IntervalTree.h" +#include "onnx/defs/schema.h" +#include "onnx/onnx-operators_pb.h" +#include "onnx/onnx_pb.h" + +using json = nlohmann::json; + +enum class TileType{ + LOOP_INDEX_NODE, + LOOP_END_NODE, + LOAD_NODE, + STORE_NODE, + COMPUTE_NODE, + MEMORY_WAIT_NODE, + STONNE_NODE, + STONNE_TRACE_COMPUTE_NODE, + STONNE_TRACE_LOAD_NODE, + STONNE_TRACE_STORE_NODE +}; + +enum class LoopType { + NORMAL_LOOP, + PARALLEL_LOOP, + ACCUMULATION_LOOP, + INNER_LOOP +}; + +bool loadConfig(const std::string& config_path, json& config_json); + +class TileNode { + public: + TileNode(onnx::NodeProto& node); + static TileType get_tile_type(std::string type); + void add_child(std::shared_ptr child) { _child.push_back(std::move(child)); } + std::vector>& get_child() { return _child; } + void add_parent(std::shared_ptr parent) { _parent.push_back(std::move(parent)); } + std::vector>& get_parent() { return _parent; } + std::vector& get_child_name() { return _child_name; } + std::vector& get_parent_name() { return _parent_name; } + TileType get_type() { return _type; } + std::shared_ptr get_owner_loop() { return _owner_loop; } + std::string get_name() { return _name; } + void set_owner_loop(std::shared_ptr owner) { _owner_loop=std::move(owner); } + virtual void print_node(); + void set_depth(int depth) { _depth=depth; } + int get_depth() { return _depth; } + + private: + std::vector> _parent; + std::vector> _child; + std::vector _parent_name; + std::vector _child_name; + std::shared_ptr _owner_loop; + std::string _name; + int _depth; + TileType _type; +}; + +class TileGraphParser { + public: + TileGraphParser(std::string onnx_path, std::string attribute_path); + std::shared_ptr get_top_loop(); + std::unique_ptr& get_tile_graph() { return _tile_graph; } + addr_type lookup(std::string key); + void register_loop(std::shared_ptr); + void increase_loop_top() { _loop_stack_pointer++; } + void decrease_loop_top() { _loop_stack_pointer--; } + int get_loop_size(std::string key) { return std::get<0>(_loop_size_map[key]); } + int get_loop_step(std::string key) { return std::get<1>(_loop_size_map[key]); } + LoopType get_loop_type(std::string key) { return std::get<2>(_loop_size_map[key]); } + const std::map> & get_loop_map() { return _loop_size_map; } + const std::vector &lookupNumaInfo(std::string key); + int getCoreIdFromJson(const json& attribute_json, int subgraph_id); + std::string getMetaByName(std::string key) { return _tog_meta[key]; } + const json& get_attribute_file() { return _attribute_json; } + std::vector calc_tag(std::vector& accum_tag, std::vector& tag_idx, std::vector& tag_stride); + void register_memory_tag(std::string name, std::vector& tag_key); + bool check_memory_tag(std::string name, std::vector& tag_key); + void clear_tag_table() { _tag_table.clear(); } + std::string get_indirect_path() { + namespace fs = std::filesystem; + fs::path original(_attribute_path); + fs::path base_folder = original.parent_path().parent_path(); + fs::path new_path = base_folder / "indirect_access" / (std::string("indirect_index") + std::to_string(indirect_counter) + ".raw"); + return new_path.string(); + } + std::string get_sparse_tile_meta_path() { + namespace fs = std::filesystem; + fs::path original(_attribute_path); + fs::path base_folder = original.parent_path().parent_path(); + fs::path new_path = base_folder / "dma_access" / (std::string("sparse_tile.raw")); + return new_path.string(); + } + void load_sparse_meta_data() { + /* Prepare runtime attribute */ + std::string sparse_meta_path = get_sparse_tile_meta_path(); + std::ifstream file(sparse_meta_path, std::ios::binary); + if (file) { + file.seekg(0, std::ios::end); + std::streamsize size = file.tellg(); + file.seekg(0, std::ios::beg); + size_t count = size / sizeof(int64_t); + for (size_t i = 0; i < count; ++i) { + int64_t val; + file.read(reinterpret_cast(&val), sizeof(int64_t)); + sparse_tile_set.insert(val); + } + } + } + void inc_indirect_counter() { indirect_counter++; } + uint64_t get_dma_counter() { return dma_counter; } + void inc_dma_counter() { dma_counter++; } + bool is_sparse_tile(uint64_t idx) { return sparse_tile_set.find(idx) != sparse_tile_set.end(); } + int register_addr_name(const std::string& addr_name) { + if (_addr_name_map.find(addr_name) == _addr_name_map.end()) + _addr_name_map[addr_name] = _addr_name_map.size(); + return _addr_name_map[addr_name]; + } + int get_addr_name_id(const std::string& addr_name) { return _addr_name_map[addr_name]; } + + private: + void register_tile(std::shared_ptr tile_node); + void _tile_generate() {} + void _base_addr_update() {} + void _tile_index_generate() {} + int _loop_stack_pointer = 0; + + json _attribute_json; + std::string _tog_path; + std::string _attribute_path; + uint64_t indirect_counter = 0; + uint64_t dma_counter = 0; + std::set sparse_tile_set; + std::map> _output_map; + std::vector>> _loop_nodes; + std::vector> _tile_vec; + std::unique_ptr _tile_graph; + std::map _arg_to_address; + std::map> _arg_numa_stride; + std::vector> _cache_plan; + std::map> _loop_size_map; + std::map _tog_meta; + std::map>, uint32_t> _tag_table; + std::unordered_map _addr_name_map; +}; + +class TileComputeNode : public TileNode { + public: + TileComputeNode(onnx::NodeProto& node); + uint32_t get_cycle() { return _cycle; } + uint32_t get_overlapping_cycle() { return _overlapping_cycle; } + int get_compute_type() { return _compute_type; } + void print_node(); + + private: + std::map> tile_map; + uint32_t _cycle; + uint32_t _overlapping_cycle = 0; + int _compute_type; +}; + +class TileMemoryNode : public TileNode { + public: + TileMemoryNode(onnx::NodeProto& node); + std::string get_base_addr_name() { return _base_addr_name; } + size_t get_precision() { return _element_size; } + std::vector get_tile_size() { return _tile_size; } + std::vector& get_tile_stride() { return _tile_stride; } + std::vector& get_tag_idx_list() { return _tag_idx_list; } + std::vector& get_tag_stride_list() { return _tag_stride_list; } + std::vector& get_loop_idx_list() { return _loop_idx_list; } + std::vector& get_loop_stride_list () { return _loop_stride_list; } + bool is_async_node() { return _is_async; } + bool is_indirect() { return _is_indirect; } + void print_node() override; + + private: + std::vector _tile_size; + std::vector _tile_stride; + size_t _element_size; + bool _is_async; + bool _is_indirect; + std::string _base_addr_name; + std::vector _tag_idx_list; + std::vector _tag_stride_list; + std::vector _loop_idx_list; + std::vector _loop_stride_list; +}; + +class TileMemoryWaitNode : public TileNode { + public: + TileMemoryWaitNode(onnx::NodeProto& node); + std::string get_base_addr_name() { return _base_addr_name; } + std::vector& get_tag_idx_list() { return _tag_idx_list; } + std::vector& get_tag_stride_list() { return _tag_stride_list; } + std::vector& get_tag_divider_list() { return _tag_divider_list; } + void print_node() override; + + private: + std::vector _tag_idx_list; + std::vector _tag_stride_list; + std::vector _tag_divider_list; + std::string _base_addr_name; +}; + +class TileLoopNode : public TileNode { + public: + TileLoopNode(onnx::NodeProto& node); + void add_body(std::shared_ptr body) { _body_node.push_back(body); } + std::vector> get_tiles_from_iter(TileGraphParser*, std::map&); + std::string get_idx_name() { return _tile_index_name; } + uint64_t get_start() { return _start; } + uint64_t get_stride() { return _stride; } + uint64_t get_end() { return _end; } + LoopType get_loop_type() { return _loop_type; } + void print_node() override; + private: + std::string _tile_index_name; + uint64_t _stride; + uint64_t _start; + uint64_t _end; + LoopType _loop_type; + std::vector> _body_node; +}; + +class TileLoopEndNode : public TileNode { + public: + TileLoopEndNode(onnx::NodeProto& node) : TileNode(node) {} +}; + +class TileStonneNode : public TileNode { + public: + TileStonneNode(onnx::NodeProto& node) : TileNode(node) { + for (auto attribute : node.attribute()) { + if (attribute.name() == "torchsim_stonne_operation") { + std::string op_type = attribute.s(); + if (op_type == "CONV") { + desc.operation = Layer_t::CONV; + } else if (op_type == "GEMM") { + desc.operation = Layer_t::GEMM; + } else if (op_type == "POOL") { + desc.operation = Layer_t::POOL; + } else if (op_type == "FC") { + desc.operation = Layer_t::FC; + } else if (op_type == "SPARSE_DENSE") { + desc.operation = Layer_t::SPARSE_DENSE; + } else if (op_type == "bitmapSpMSpM") { + desc.operation = Layer_t::bitmapSpMSpM; + } else if (op_type == "csrSpMM") { + desc.operation = Layer_t::csrSpMM; + } else if (op_type == "outerProductGEMM") { + desc.operation = Layer_t::outerProductGEMM; + } else if (op_type == "gustavsonsGEMM") { + desc.operation = Layer_t::gustavsonsGEMM; + } else { + spdlog::error("[TileStonneNode] Unknown operation type: {}", op_type); + throw std::runtime_error("Invalid operation type in TileStonneNode"); + } + } else if (attribute.name() == "torchsim_stonne_layer_name") { + desc.layer_name = attribute.s(); + } else if (attribute.name() == "torchsim_stonne_mem_init") { + desc.mem_init = attribute.s(); + } else if (attribute.name() == "torchsim_stonne_R") { + desc.R = attribute.i(); + } else if (attribute.name() == "torchsim_stonne_S") { + desc.S = attribute.i(); + } else if (attribute.name() == "torchsim_stonne_C") { + desc.C = attribute.i(); + } else if (attribute.name() == "torchsim_stonne_K") { + desc.K = attribute.i(); + } else if (attribute.name() == "torchsim_stonne_G") { + desc.G = attribute.i(); + } else if (attribute.name() == "torchsim_stonne_N") { + desc.N = attribute.i(); + } else if (attribute.name() == "torchsim_stonne_X") { + desc.X = attribute.i(); + } else if (attribute.name() == "torchsim_stonne_Y") { + desc.Y = attribute.i(); + } else if (attribute.name() == "torchsim_stonne_X_") { + desc.X_ = attribute.i(); + } else if (attribute.name() == "torchsim_stonne_Y_") { + desc.Y_ = attribute.i(); + } else if (attribute.name() == "torchsim_stonne_strides") { + desc.strides = attribute.i(); + } else if (attribute.name() == "torchsim_stonne_T_R") { + desc.T_R = attribute.i(); + } else if (attribute.name() == "torchsim_stonne_T_S") { + desc.T_S = attribute.i(); + } else if (attribute.name() == "torchsim_stonne_T_C") { + desc.T_C = attribute.i(); + } else if (attribute.name() == "torchsim_stonne_T_K") { + desc.T_K = attribute.i(); + } else if (attribute.name() == "torchsim_stonne_T_G") { + desc.T_G = attribute.i(); + } else if (attribute.name() == "torchsim_stonne_T_N") { + desc.T_N = attribute.i(); + } else if (attribute.name() == "torchsim_stonne_T_X_") { + desc.T_X_ = attribute.i(); + } else if (attribute.name() == "torchsim_stonne_T_Y_") { + desc.T_Y_ = attribute.i(); + } else if (attribute.name() == "torchsim_stonne_GEMM_K") { + desc.GEMM_K = attribute.i(); + } else if (attribute.name() == "torchsim_stonne_GEMM_N") { + desc.GEMM_N = attribute.i(); + } else if (attribute.name() == "torchsim_stonne_GEMM_M") { + desc.GEMM_M = attribute.i(); + } else if (attribute.name() == "torchsim_stonne_GEMM_T_K") { + desc.GEMM_T_K = attribute.i(); + } else if (attribute.name() == "torchsim_stonne_GEMM_T_N") { + desc.GEMM_T_N = attribute.i(); + } else if (attribute.name() == "torchsim_stonne_GEMM_T_M") { + desc.GEMM_T_M = attribute.i(); + } else if (attribute.name() == "torchsim_stonne_matrix_a_dram_address") { + desc.matrix_a_dram_address = attribute.i(); + } else if (attribute.name() == "torchsim_stonne_matrix_b_dram_address") { + desc.matrix_b_dram_address = attribute.i(); + } else if (attribute.name() == "torchsim_stonne_matrix_c_dram_address") { + desc.matrix_c_dram_address = attribute.i(); + } else if (attribute.name() == "torchsim_stonne_mem_matrix_c_file_name") { + desc.mem_matrix_c_file_name = attribute.s(); + } else if (attribute.name() == "torchsim_stonne_bitmap_matrix_a_init") { + desc.bitmap_matrix_a_init = attribute.s(); + } else if (attribute.name() == "torchsim_stonne_bitmap_matrix_b_init") { + desc.bitmap_matrix_b_init = attribute.s(); + } else if (attribute.name() == "torchsim_stonne_rowpointer_matrix_a_init") { + desc.rowpointer_matrix_a_init = attribute.s(); + } else if (attribute.name() == "torchsim_stonne_colpointer_matrix_a_init") { + desc.colpointer_matrix_a_init = attribute.s(); + } else if (attribute.name() == "torchsim_stonne_rowpointer_matrix_b_init") { + desc.rowpointer_matrix_b_init = attribute.s(); + } else if (attribute.name() == "torchsim_stonne_colpointer_matrix_b_init") { + desc.colpointer_matrix_b_init = attribute.s(); + } else if (attribute.name() == "torchsim_bitmap_matrix_a_init") { + desc.bitmap_matrix_a_init = attribute.s(); + } else if (attribute.name() == "torchsim_bitmap_matrix_b_init") { + desc.bitmap_matrix_b_init = attribute.s(); + } else if (attribute.name() == "torchsim_mem_matrix_c_file_name") { + desc.mem_matrix_c_file_name = attribute.s(); + } else if (attribute.name() == "torchsim_trace_path") { + desc.trace_path = attribute.s(); + } else { + spdlog::warn("[TileStonneNode] Unrecognized attribute: {}", attribute.name()); + } + } + } + SST_STONNE::StonneOpDesc* getDesc() { return &desc; } + void print_node() override; + private: + SST_STONNE::StonneOpDesc desc; +}; + +class TileStonneTraceComputeNode : public TileNode { + public: + TileStonneTraceComputeNode(onnx::NodeProto& node) : TileNode(node) { + for (auto attribute : node.attribute()) { + if (attribute.name() == "torchsim_trace_compute_cycle") { + _cycle = attribute.i(); + } + } + } + uint32_t get_cycle() { return _cycle; } + void print_node(); + + private: + uint64_t _cycle; +}; + +class TileStonneTraceMemoryNode : public TileNode { + public: + TileStonneTraceMemoryNode(onnx::NodeProto& node) : TileNode(node) { + for (auto attribute : node.attribute()) { + if (attribute.name() == "torchsim_trace_address") { + trace_address.assign(attribute.ints().begin(), attribute.ints().end()); + } + } + } + std::vector& get_address() { return trace_address; } + void print_node(); + + private: + std::vector trace_address; +}; +class TileStonneTraceLoadNode : public TileStonneTraceMemoryNode { + public: + using TileStonneTraceMemoryNode::TileStonneTraceMemoryNode; +}; + +class TileStonneTraceStoreNode : public TileStonneTraceMemoryNode { + public: + using TileStonneTraceMemoryNode::TileStonneTraceMemoryNode; +}; \ No newline at end of file diff --git a/PyTorchSimBackend/src/scheduler/Scheduler.h b/PyTorchSimBackend/include/scheduler/Scheduler.h similarity index 88% rename from PyTorchSimBackend/src/scheduler/Scheduler.h rename to PyTorchSimBackend/include/scheduler/Scheduler.h index 1ceb9f4d..39ab7576 100644 --- a/PyTorchSimBackend/src/scheduler/Scheduler.h +++ b/PyTorchSimBackend/include/scheduler/Scheduler.h @@ -1,8 +1,9 @@ #pragma once #include -#include "../Tile.h" -#include "../Common.h" -#include "../TileGraph.h" +#include "Tile.h" +#include "Common.h" +#include "TileGraph.h" +#include "SimulationConfig.h" class Scheduler { public: @@ -12,7 +13,7 @@ class Scheduler { /* For other schedulers */ virtual std::shared_ptr get_tile(int core_id=0, int slot_id=0); - virtual const std::shared_ptr peek_tile(int core_id=0, int slot_id=0); + virtual const std::shared_ptr peek_tile(int core_id=0, int slot_id=0, CoreType ctype=CoreType::WS_MESH); virtual bool empty(); virtual bool empty(int core_id); virtual void refresh_status(); diff --git a/PyTorchSimBackend/src/CMakeLists.txt b/PyTorchSimBackend/src/CMakeLists.txt index 0b7d93b3..65cd4dd4 100644 --- a/PyTorchSimBackend/src/CMakeLists.txt +++ b/PyTorchSimBackend/src/CMakeLists.txt @@ -12,4 +12,3 @@ file(GLOB_RECURSE SRC_FILES # build add_executable(${LIB_NAME} ${SRC_FILES}) -add_library(${LIB_NAME}_lib ${SRC_FILES}) diff --git a/PyTorchSimBackend/src/Cache.cc b/PyTorchSimBackend/src/Cache.cc new file mode 100644 index 00000000..8346fae8 --- /dev/null +++ b/PyTorchSimBackend/src/Cache.cc @@ -0,0 +1,961 @@ +#include "Cache.h" +#include "Hashing.h" + +unsigned int LOGB2(unsigned int v) { + unsigned int shift; + unsigned int r; + r = 0; + shift = ((v & 0xFFFF0000) != 0) << 4; + v >>= shift; + r |= shift; + shift = ((v & 0xFF00) != 0) << 3; + v >>= shift; + r |= shift; + shift = ((v & 0xF0) != 0) << 2; + v >>= shift; + r |= shift; + shift = ((v & 0xC) != 0) << 1; + v >>= shift; + r |= shift; + shift = ((v & 0x2) != 0) << 0; + v >>= shift; + r |= shift; + return r; +} + +void CacheConfig::init(std::string config) { + assert(config.size() > 0); + char cache_type, evict_policy, write_policy, alloc_policy, write_alloc_policy, + sif; + char mshr_type; + // sif : sector index function + int ntok = + sscanf(config.c_str(), "%c:%u:%u:%u,%u,%c:%c:%c:%c:%c,%c:%u:%u,%u:%u,%u", + &cache_type, &m_nset, &m_line_size, &m_assoc, &m_sector_size, &evict_policy, + &write_policy, &alloc_policy, &write_alloc_policy, &sif, + &mshr_type, &m_mshr_entries, &m_mshr_max_merge, &m_miss_queue_size, + &m_result_fifo_entries, &m_data_port_width); + assert(ntok >= 12); + m_valid = true; + m_cache_type = CacheTypeMap[cache_type]; + m_evict_policy = EvictPolicyMap[evict_policy]; + m_write_policy = WritePolicyMap[write_policy]; + m_alloc_policy = AllocationPolicyMap[alloc_policy]; + m_write_alloc_policy = WriteAllocatePolicyMap[write_alloc_policy]; + m_set_index_function = SetIndexFunctionMap[sif]; + m_mshr_type = MshrConfigMap[mshr_type]; + m_line_size_log2 = LOGB2(m_line_size); + m_nset_log2 = LOGB2(m_nset); + m_atom_size = m_cache_type == SECTOR ? m_sector_size : m_line_size; + m_sector_size_log2 = LOGB2(m_sector_size); + m_origin_assoc = m_assoc; + m_origin_nset = m_nset; +} + +uint32_t CacheConfig::get_set_index(uint64_t addr) const { + return hash_function(addr); +} + +uint64_t CacheConfig::get_tag(uint64_t addr) const { + return addr & ~(uint64_t)(m_line_size - 1); +} + +uint64_t CacheConfig::get_block_addr(uint64_t addr) const { + return addr & ~(uint64_t)(m_line_size - 1); +} + +uint64_t CacheConfig::get_mshr_addr(uint64_t addr) const { + return addr & ~(uint64_t)(m_atom_size - 1); +} + +uint32_t CacheConfig::hash_function(uint64_t addr) const { + uint32_t set_index = 0; + switch (m_set_index_function) { + case LINEAR_SET_FUNCTION: + set_index = (addr >> m_line_size_log2) & (m_nset - 1); + break; + case BITWISE_XORING_FUNCTION: { + uint64_t higher_bits = addr > (m_line_size_log2 + m_nset_log2); + uint32_t index = (addr >> m_line_size_log2) & (m_nset - 1); + set_index = bitwise_hash_function(higher_bits, index, m_nset); + } break; + case HASH_IPOLY_FUNCTION: { + uint64_t higher_bits = addr > (m_line_size_log2 + m_nset_log2); + uint32_t index = (addr >> m_line_size_log2) & (m_nset - 1); + set_index = ipoly_hash_function(higher_bits, index, m_nset); + } break; + case CUSTOM_SET_FUNCTION: + break; + default: + assert(0); + } + return set_index; +} + +/* Normal Cache Block */ +void LineCacheBlock::allocate(uint64_t tag, uint64_t block_addr, uint32_t time, + SectorMask mask) { + m_tag = tag; + m_block_addr = block_addr; + m_alloc_time = time; + m_last_access_time = time; + m_fill_time = 0; + m_status = RESERVED; + m_ignore_on_fill_status = false; + m_set_modified_on_fill = false; +} + +void LineCacheBlock::fill(uint32_t time, SectorMask) { + m_fill_time = time; + m_status = m_set_modified_on_fill ? MODIFIED : VALID; +} + +SectorMask LineCacheBlock::get_dirty_mask() { + SectorMask dirty_mask; + dirty_mask.reset(); + if (m_status == MODIFIED) + dirty_mask.set(); + return dirty_mask; +} + +/* Sector Cache Block */ +void SectorCacheBlock::allocate(uint64_t tag, uint64_t block_addr, + uint32_t time, SectorMask sector_mask) { + // Allocate line + init(); + m_tag = tag; + m_block_addr = block_addr; + uint32_t sidx = get_sector_index(sector_mask); + m_sector_alloc_time[sidx] = time; + m_sector_last_access_time[sidx] = time; + m_line_alloc_time = time; + m_line_last_access_time = time; +} + +void SectorCacheBlock::allocate_sector(uint32_t time, SectorMask sector_mask) { + assert(is_valid_line()); + uint32_t sidx = get_sector_index(sector_mask); + m_sector_alloc_time[sidx] = time; + m_sector_last_access_time[sidx] = time; + m_line_last_access_time = time; + m_set_modified_on_fill_status[sidx] = m_status[sidx] == MODIFIED ? true : false; + m_status[sidx] = RESERVED; + m_ignore_on_fill_status[sidx] = false; + m_readable[sidx] = true; +} + +void SectorCacheBlock::fill(uint32_t time, SectorMask sector_mask) { + uint32_t sidx = get_sector_index(sector_mask); + m_status[sidx] = m_set_modified_on_fill_status[sidx] ? MODIFIED : VALID; + m_sector_fill_time[sidx] = time; + m_line_fill_time = time; +} + +bool SectorCacheBlock::is_valid_line() { return !(is_invalid_line()); } + +bool SectorCacheBlock::is_invalid_line() { + // all the sectors should be invalid + for (unsigned i = 0; i < SECTOR_CHUNCK_SIZE; ++i) { + if (m_status[i] != INVALID) return false; + } + return true; +} + +bool SectorCacheBlock::is_reserved_line() { + // all the sectors should be invalid + for (unsigned i = 0; i < SECTOR_CHUNCK_SIZE; ++i) { + if (m_status[i] == RESERVED) return true; + } + return false; +} + +bool SectorCacheBlock::is_modified_line() { + for (unsigned i = 0; i < SECTOR_CHUNCK_SIZE; ++i) { + if (m_status[i] == MODIFIED) return true; + } + return false; +} + +SectorMask SectorCacheBlock::get_dirty_mask() { + SectorMask dirty_mask; + dirty_mask.reset(); + for (unsigned i = 0; i < SECTOR_CHUNCK_SIZE; ++i) { + if (m_status[i] == MODIFIED) dirty_mask.set(i); + } + return dirty_mask; +} + +void SectorCacheBlock::init() { + for (int i = 0; i < SECTOR_CHUNCK_SIZE; i++) { + m_sector_alloc_time[i] = 0; + m_sector_fill_time[i] = 0; + m_sector_last_access_time[i] = 0; + m_status[i] = INVALID; + m_ignore_on_fill_status[i] = false; + m_set_modified_on_fill_status[i] = false; + m_readable[i] = true; + } + m_line_alloc_time = 0; + m_line_fill_time = 0; + m_line_last_access_time = 0; +} + +CacheBlockState SectorCacheBlock::get_status(SectorMask mask) { + uint32_t sidx = get_sector_index(mask); + return m_status[sidx]; +} + +void SectorCacheBlock::set_status(CacheBlockState status, SectorMask mask) { + uint32_t sidx = get_sector_index(mask); + m_status[sidx] = status; +} + +bool SectorCacheBlock::is_readable(SectorMask mask) { + uint32_t sidx = get_sector_index(mask); + return m_readable[sidx]; +} + +uint64_t SectorCacheBlock::get_last_access_time() { + return m_line_last_access_time; +} + +uint64_t SectorCacheBlock::get_alloc_time() { return m_line_alloc_time; } + +void SectorCacheBlock::set_ignore_on_fill(bool ignore, SectorMask mask) { + uint32_t sidx = get_sector_index(mask); + m_ignore_on_fill_status[sidx] = ignore; +} + +void SectorCacheBlock::set_modified_on_fill(bool modified, SectorMask mask) { + uint32_t sidx = get_sector_index(mask); + m_set_modified_on_fill_status[sidx] = modified; +} + +void SectorCacheBlock::set_readable(bool readable, SectorMask mask) { + uint32_t sidx = get_sector_index(mask); + m_readable[sidx] = readable; +} + +void SectorCacheBlock::set_last_access_time(uint64_t time, + SectorMask sector_mask) { + m_line_last_access_time = time; + uint32_t sidx = get_sector_index(sector_mask); + m_sector_last_access_time[sidx] = time; +} + +uint32_t SectorCacheBlock::get_modified_size() { + uint32_t modified_size = 0; + for (unsigned i = 0; i < SECTOR_CHUNCK_SIZE; ++i) { + if (m_status[i] == MODIFIED) modified_size++; + } + return modified_size * m_sector_size; +} + +/*Tag Array*/ +TagArray::TagArray(CacheConfig &config, int core_id, int type_id) + : m_config(config) { + uint32_t cache_lines_num = config.get_num_lines(); + m_lines = new CacheBlock *[cache_lines_num]; + for (uint32_t i = 0; i < cache_lines_num; ++i) { + if (config.get_cache_type() == SECTOR) + m_lines[i] = new SectorCacheBlock(config.get_sector_size()); + else if (config.get_cache_type() == NORMAL) + m_lines[i] = new LineCacheBlock(config.get_sector_size()); + else + assert(0); + } + init(core_id, type_id); +} + +TagArray::~TagArray() { + uint32_t cache_lines_num = m_config.get_num_lines(); + for (uint32_t i = 0; i < cache_lines_num; ++i) { + delete m_lines[i]; + } + delete[] m_lines; +} + +CacheRequestStatus TagArray::probe(uint64_t addr, uint32_t &idx, mem_fetch *mf, + bool probe_mode) const { + SectorMask sector_mask = mf->get_access_sector_mask(); + return probe(addr, idx, sector_mask, mf, probe_mode); +} + +CacheRequestStatus TagArray::probe(uint64_t addr, uint32_t &idx, + SectorMask mask, mem_fetch *mf, + bool probe_mode) const { + int set_index = m_config.get_set_index(addr); + uint64_t tag = m_config.get_tag(addr); + uint32_t valid_line = (uint32_t)-1; + uint32_t invalid_line = (uint32_t)-1; + uint64_t valid_timestamp = (uint64_t)-1; + bool all_reserved = true; + for (uint32_t way = 0; way < m_config.get_num_assoc(); way++) { + uint32_t index = set_index * m_config.get_num_assoc() + way; + CacheBlock *line = m_lines[index]; + + // Handle tag matched case + if (line->match_tag(tag)) { + idx = index; + if (line->get_status(mask) == RESERVED) { + return HIT_RESERVED; + } else if (line->get_status(mask) == VALID || + (line->get_status(mask) == MODIFIED && + line->is_readable(mask))) { + return HIT; + } else if ((line->get_status(mask) == MODIFIED && + !line->is_readable(mask)) || + (line->is_valid_line() && line->get_status(mask) == INVALID)) { + return SECTOR_MISS; + } else { + assert(line->get_status(mask) == INVALID); + } + } else if (!line->is_reserved_line()) { + all_reserved = false; + if (line->is_invalid_line()) { + invalid_line = index; + continue; + } + + // Choose cacheline for eviction + if (m_config.get_evict_policy() == LRU) { + if (line->get_last_access_time() < valid_timestamp) { + valid_timestamp = line->get_last_access_time(); + valid_line = index; + } + } else if (m_config.get_evict_policy() == FIFO) { + if (line->get_alloc_time() < valid_timestamp) { + valid_timestamp = line->get_alloc_time(); + valid_line = index; + } + } + } + } + + // All target cachelines are reserved + if (all_reserved) { + assert(m_config.get_alloc_policy() == ON_MISS); + return RESERVATION_FAIL; + } + + if (invalid_line != (uint32_t)-1) { + idx = invalid_line; + } else if (valid_line != (uint32_t)-1) { + idx = valid_line; + } else { + assert(0); + } + return MISS; +} + +CacheRequestStatus TagArray::access(uint64_t addr, uint32_t time, uint32_t &idx, + mem_fetch *mf) { + bool wb = false; + EvictedBlockInfo evicted; + return access(addr, time, idx, mf, wb, evicted); +} + +CacheRequestStatus TagArray::access(uint64_t addr, uint32_t time, uint32_t &idx, + mem_fetch *mf, bool &wb, + EvictedBlockInfo &evicted) { + is_used = true; + m_access++; + SectorMask sector_mask = mf->get_access_sector_mask(); + uint64_t tag = m_config.get_tag(addr); + uint64_t block_addr = m_config.get_block_addr(addr); + CacheRequestStatus status = probe(addr, idx, mf); + switch (status) { + case HIT_RESERVED: + m_pending_hit++; + break; + case HIT: + m_lines[idx]->set_last_access_time(time, sector_mask); + break; + case SECTOR_MISS: + assert(m_config.get_cache_type() == SECTOR); + m_sector_miss++; + if (m_config.get_alloc_policy() == ON_MISS) { + ((SectorCacheBlock *)m_lines[idx])->allocate_sector(time, sector_mask); + } + break; + case MISS: + m_miss++; + if (m_config.get_alloc_policy() == ON_MISS) { + if (m_lines[idx]->is_modified_line()) { + wb = true; + evicted.set_info(m_lines[idx]->get_block_addr(), m_lines[idx]->get_modified_size(), + m_lines[idx]->get_status(sector_mask)); + } + m_lines[idx]->allocate(tag, block_addr, time, sector_mask); + } + break; + case RESERVATION_FAIL: + m_res_fail++; + break; + } + return status; +} + +void TagArray::fill(uint64_t addr, uint32_t time, mem_fetch *mf) { + fill(addr, time, mf->get_access_sector_mask()); +} + +void TagArray::fill(uint32_t index, uint32_t time, mem_fetch *mf) { + assert(m_config.get_alloc_policy() == ON_MISS); + m_lines[index]->fill(time, mf->get_access_sector_mask()); +} + +void TagArray::fill(uint64_t addr, uint32_t time, SectorMask mask) { + uint32_t idx; + CacheRequestStatus status = probe(addr, idx, mask); + if (status == MISS) { + m_lines[idx]->allocate(m_config.get_tag(addr), + m_config.get_block_addr(addr), time, mask); + } else if (status == SECTOR_MISS) { + assert(m_config.get_cache_type() == SECTOR); + ((SectorCacheBlock *)m_lines[idx])->allocate_sector(time, mask); + } + m_lines[idx]->fill(time, mask); +} + +void TagArray::invalidate() { + if (!is_used) return; + for (uint32_t i = 0; i < m_config.get_num_lines(); i++) { + for (uint32_t j = 0; j < SECTOR_CHUNCK_SIZE; j++) { + m_lines[i]->set_status(INVALID, SectorMask().set(j)); + } + } +} + +void TagArray::init(int core_id, int type_id) { + m_core_id = core_id; + m_type_id = type_id; + m_access = 0; + m_miss = 0; + m_pending_hit = 0; + m_res_fail = 0; + m_sector_miss = 0; + is_used = false; +} + +/* MSHR Table */ +bool MshrTable::probe(uint64_t block_addr) const { + return m_table.find(block_addr) != m_table.end(); +} + +bool MshrTable::full(uint64_t block_addr) const { + if (probe(block_addr)) + return m_table.at(block_addr).m_list.size() >= m_max_merged; + else + return m_table.size() >= m_num_entries; +} + +void MshrTable::add(uint64_t block_addr, mem_fetch *mf) { + assert(!full(block_addr)); + m_table[block_addr].m_list.push_back(mf); + if (mf->is_atomic()) { + m_table[block_addr].m_has_atomic = true; + } +} + +void MshrTable::mark_ready(uint64_t block_addr, bool &has_atomic) { + assert(probe(block_addr)); + has_atomic = m_table[block_addr].m_has_atomic; + m_current_response.push_back(block_addr); + } + +mem_fetch *MshrTable::pop_next_access() { + assert(access_ready()); + uint64_t block_addr = m_current_response.front(); + assert(probe(block_addr)); + mem_fetch *mf = m_table[block_addr].m_list.front(); + m_table[block_addr].m_list.pop_front(); + if (m_table[block_addr].m_list.empty()) { + m_table.erase(block_addr); + m_current_response.pop_front(); + } + return mf; +} + +mem_fetch *MshrTable::top_next_access() { + assert(access_ready()); + uint64_t block_addr = m_current_response.front(); + assert(probe(block_addr)); + mem_fetch *mf = m_table[block_addr].m_list.front(); + return mf; +} + +bool MshrTable::is_read_after_write_pending(uint64_t block_addr) { + std::deque list = m_table[block_addr].m_list; + bool write_found = false; + for (auto it = list.begin(); it != list.end(); ++it) { + if ((*it)->is_write()) { + write_found = true; // Pending write + } else if (write_found) { + return true; // Pending read after write + } + } + return false; +} + +void MshrTable::print(FILE *fp) const { + +} + +/* Cache */ +Cache::Cache(std::string name, CacheConfig &config, int core_id, int type_id, + std::queue *to_mem_queue) + : m_config(config), m_bandwidth_management(config) { + m_tag_array = new TagArray(config, core_id, type_id); + m_mshrs = new MshrTable(config.get_mshr_entries(), + config.get_mshr_max_merge()); + m_name = name + std::to_string(core_id); + m_id = core_id; + m_to_mem_queue = to_mem_queue; +} + +void Cache::cycle() { + if (!m_miss_queue.empty()) { + mem_fetch *mf = m_miss_queue.front(); + m_to_mem_queue->push(mf); + m_miss_queue.pop_front(); + } + m_bandwidth_management.replenish_port_bandwidth(); +} + +void Cache::fill(mem_fetch *mf, uint32_t time) { + if (m_config.get_mshr_config() == SECTOR_ASSOC) { + assert(mf->get_original_mf()); + assert(m_extra_mf_fields.find(mf->get_original_mf()) != + m_extra_mf_fields.end()); + m_extra_mf_fields[mf->get_original_mf()].pending_read--; + if (m_extra_mf_fields[mf->get_original_mf()].pending_read > 0) { + delete mf; + return; + } else { + mem_fetch *tmp = mf; + mf = mf->get_original_mf(); + delete tmp; + } + } + assert(m_extra_mf_fields.find(mf) != m_extra_mf_fields.end()); + ExtraMfFields field = m_extra_mf_fields[mf]; + mf->set_data_size(field.m_data_size); + mf->set_addr(field.m_addr); + if (m_config.get_alloc_policy() == ON_MISS) { + m_tag_array->fill(field.m_cache_index, time, mf); + } else if (m_config.get_alloc_policy() == ON_FILL) { + m_tag_array->fill(field.m_block_addr, time, mf); + } + bool has_atomic = false; + m_mshrs->mark_ready(field.m_block_addr, has_atomic); + if (has_atomic) { + assert(m_config.get_alloc_policy() == ON_MISS); + CacheBlock *block = m_tag_array->get_block(field.m_cache_index); + if(!block->is_modified_line()) { + // m_tag_array->inc_dirty(); // TODO + } + block->set_status(MODIFIED, mf->get_access_sector_mask()); + } + m_extra_mf_fields.erase(mf); + m_bandwidth_management.use_fill_port(mf); +} + +bool Cache::waiting_for_fill(mem_fetch *mf) { + return m_extra_mf_fields.find(mf) != m_extra_mf_fields.end(); +} + +void Cache::send_read_request(uint64_t addr, uint64_t block_addr, + uint32_t cache_index, mem_fetch *mf, + uint32_t time, bool &do_miss, + std::deque &events, bool read_only, + bool ws) { + bool wb = false; + EvictedBlockInfo evicted; + send_read_request(addr, block_addr, cache_index, mf, time, do_miss, wb, + evicted, events, read_only, ws); +} + +void Cache::send_read_request(uint64_t addr, uint64_t block_addr, + uint32_t cache_index, mem_fetch *mf, + uint32_t time, bool &do_miss, bool &wb, + EvictedBlockInfo &evicted, + std::deque &events, bool read_only, + bool wa) { + new_addr_type mshr_addr = m_config.get_mshr_addr(addr); + bool mshr_hit = m_mshrs->probe(mshr_addr); + bool mshr_avail = !m_mshrs->full(mshr_addr); + if (mshr_hit && mshr_avail) { + if (read_only) + m_tag_array->access(block_addr, time, cache_index, mf); + else + m_tag_array->access(block_addr, time, cache_index, mf, wb, evicted); + m_mshrs->add(mshr_addr, mf); + m_stats.inc_stats(mf->get_access_type(), MSHR_HIT); + do_miss = true; + } else if (!mshr_hit && mshr_avail && !miss_queue_full(0)) { + if (read_only) + m_tag_array->access(block_addr, time, cache_index, mf); + else + m_tag_array->access(block_addr, time, cache_index, mf, wb, evicted); + m_mshrs->add(mshr_addr, mf); + m_extra_mf_fields[mf] = ExtraMfFields(); + m_extra_mf_fields[mf].m_valid = true; + m_extra_mf_fields[mf].m_block_addr = mshr_addr; + m_extra_mf_fields[mf].m_addr = mf->get_addr(); + m_extra_mf_fields[mf].m_cache_index = cache_index; + m_extra_mf_fields[mf].m_data_size = mf->get_data_size(); + m_extra_mf_fields[mf].pending_read = m_config.get_mshr_config() == SECTOR_ASSOC + ? m_config.get_line_size() / m_config.get_sector_size() + : 0; + mf->set_data_size(m_config.get_atom_size()); + // assert(m_config.get_atom_size() <= PACKET_SIZE); //TODO: for now, it should be true + mf->set_addr(mshr_addr); + m_miss_queue.push_back(mf); + if (!wa) events.push_back(CacheEvent(READ_REQUEST_SENT)); + do_miss = true; + } else if (mshr_hit && !mshr_avail) { + m_stats.inc_fail_stats(mf->get_access_type(), MSHR_MERGE_ENTRY_FAIL); + } else if (!mshr_hit && !mshr_avail) { + m_stats.inc_fail_stats(mf->get_access_type(), MSHR_ENTRY_FAIL); + } +} + +void Cache::BandwidthManagement::use_data_port( + mem_fetch *mf, CacheRequestStatus outcome, + const std::deque &events) { + uint32_t data_size = mf->get_data_size(); + uint32_t port_width = m_config.get_data_port_width(); + uint32_t data_cycles = 0; + CacheEvent event; + switch (outcome) { + case HIT: + data_cycles = data_size / port_width + ((data_size % port_width) ? 1 : 0); + m_data_port_occupied_cycles += data_cycles; + break; + case HIT_RESERVED: + case MISS: + if (CacheEvent::was_writeback_sent(events, event)) { + data_cycles = event.m_evicted_block.m_modified_size / port_width; + m_data_port_occupied_cycles += data_cycles; + } + break; + case SECTOR_MISS: + case RESERVATION_FAIL: + break; + default: + assert(0); + } +} + +void Cache::BandwidthManagement::use_fill_port(mem_fetch *mf) { + unsigned fill_cycles = + m_config.get_atom_size() / m_config.get_data_port_width(); + m_fill_port_occupied_cycles += fill_cycles; +} + +void Cache::BandwidthManagement::replenish_port_bandwidth() { + if (m_data_port_occupied_cycles > 0) { + m_data_port_occupied_cycles--; + } + if (m_fill_port_occupied_cycles > 0) { + m_fill_port_occupied_cycles--; + } +} + +bool Cache::BandwidthManagement::data_port_free() const { + return true; // ignore this feature +} + +bool Cache::BandwidthManagement::fill_port_free() const { + return true; +} + +/* Read-only Cache */ +CacheRequestStatus ReadOnlyCache::access(uint64_t addr, uint32_t time, + mem_fetch *mf, + std::deque &events) { + assert(mf->get_data_size() <= m_config.get_atom_size()); + assert(m_config.get_write_policy() == READ_ONLY); + assert(!mf->is_write()); + uint64_t block_addr = m_config.get_block_addr(addr); + uint32_t cache_index = (uint32_t)-1; + CacheRequestStatus status = + m_tag_array->probe(block_addr, cache_index, mf, true); + CacheRequestStatus cache_status = RESERVATION_FAIL; + if (status == HIT) { + cache_status = m_tag_array->access(block_addr, time, cache_index, mf); + } else if (status != RESERVATION_FAIL) { + if (!miss_queue_full(0)) { + bool do_miss = false; + send_read_request(addr, block_addr, cache_index, mf, time, do_miss, + events, true, false); + if (do_miss) + cache_status = MISS; + else + cache_status = RESERVATION_FAIL; + } else { + cache_status = RESERVATION_FAIL; + m_stats.inc_fail_stats(mf->get_access_type(), MISS_QUEUE_FULL); + } + } else { + m_stats.inc_fail_stats(mf->get_access_type(), LINE_ALLOC_FAIL); + } + m_stats.inc_stats(mf->get_access_type(), + m_stats.select_stats_status(status, cache_status)); + + m_bandwidth_management.use_data_port(mf, cache_status, events); + return cache_status; +} + +/* Data Cache */ +void DataCache::init() { + m_rd_hit = &DataCache::rd_hit_base; + m_rd_miss = &DataCache::rd_miss_base; + switch (m_config.get_write_policy()) { + case READ_ONLY: + assert(0); // Data cache cannot be read only + case WRITE_BACK: + m_wr_hit = &DataCache::wr_hit_wb; + break; + case WRITE_THROUGH: + m_wr_hit = &DataCache::wr_hit_wt; + break; + case WRITE_EVICT: + m_wr_hit = &DataCache::wr_hit_we; + break; + default: + assert(0); + } + switch (m_config.get_write_alloc_policy()) { + case NO_WRITE_ALLOCATE: + m_wr_miss = &DataCache::wr_miss_no_wa; + break; + case WRITE_ALLOCATE: + m_wr_miss = &DataCache::wr_miss_wa_naive; + break; + default: + assert(0); + } +} + +void DataCache::print_cache_stats() { + uint64_t hit = m_stats.get_interval_hit(); + uint64_t miss = m_stats.get_interval_miss(); + if (m_id == 0) { + spdlog::info("NDP {:2}: average Data Cache Hit : {}, Miss : {} , Hit Raito : {:.2f}\%", m_id, + hit, miss, ((float)hit) / (hit + miss) * 100); + } else { + spdlog::debug("NDP {:2}: average Data Cache Hit : {}, Miss : {} , Hit Raito : {:.2f}\%", m_id, + hit, miss, ((float)hit) / (hit + miss) * 100); + } +} + +CacheRequestStatus DataCache::access(uint64_t addr, uint32_t time, + mem_fetch *mf, + std::deque &events) { + bool wr = mf->is_write(); + uint64_t block_addr = m_config.get_block_addr(addr); + uint32_t cache_index = (uint32_t)-1; + CacheRequestStatus probe_status = + m_tag_array->probe(block_addr, cache_index, mf, true); + CacheRequestStatus access_status = + process_tag_probe(wr, probe_status, addr, cache_index, mf, time, events); + m_stats.inc_stats(mf->get_access_type(), + m_stats.select_stats_status(probe_status, access_status)); + return access_status; +} + +CacheRequestStatus DataCache::process_tag_probe(bool wr, + CacheRequestStatus probe_status, + uint64_t addr, + uint32_t cache_index, + mem_fetch *mf, uint32_t time, + std::deque &events) { + CacheRequestStatus access_status = probe_status; + if (wr) { // Write + if (probe_status == HIT) { + access_status = + (this->*m_wr_hit)(addr, cache_index, mf, time, events, probe_status); + } else if (probe_status != RESERVATION_FAIL || + (probe_status == RESERVATION_FAIL && + m_config.get_write_alloc_policy() == NO_WRITE_ALLOCATE)) { + access_status = + (this->*m_wr_miss)(addr, cache_index, mf, time, events, probe_status); + } else { + m_stats.inc_fail_stats(mf->get_access_type(), LINE_ALLOC_FAIL); + } + } else { // Read + if (probe_status == HIT) { + access_status = + (this->*m_rd_hit)(addr, cache_index, mf, time, events, probe_status); + } else if (probe_status != RESERVATION_FAIL) { + access_status = + (this->*m_rd_miss)(addr, cache_index, mf, time, events, probe_status); + } else { + m_stats.inc_fail_stats(mf->get_access_type(), LINE_ALLOC_FAIL); + } + } + m_bandwidth_management.use_data_port(mf, access_status, events); + return access_status; +} + +void DataCache::send_write_request(mem_fetch *mf, CacheEvent request, + uint32_t time, + std::deque &events) { + events.push_back(request); + m_miss_queue.push_back(mf); +} + +void DataCache::write_back(EvictedBlockInfo &evicted, uint32_t time, std::deque &events) { + auto packet_size = m_config.get_atom_size(); + for(int i = 0; i < evicted.m_modified_size / packet_size; i++) { + uint64_t evicted_addr = evicted.m_block_addr + i * packet_size; + mem_fetch *wb_mf = + new mem_fetch(evicted_addr, m_write_back_type, WRITE_REQUEST, + packet_size); + wb_mf->set_dirty_mask(evicted.m_dirty_mask); + send_write_request(wb_mf, CacheEvent(WRITE_BACK_REQUEST_SENT, evicted), + time, events); + } +} + +/*** WRITE-hit functions (Set by config file) ***/ +// Write hit: Write back +CacheRequestStatus DataCache::wr_hit_wb(uint64_t addr, uint32_t cache_index, + mem_fetch *mf, uint32_t time, + std::deque &events, + CacheRequestStatus status) { + uint64_t block_addr = m_config.get_block_addr(addr); + m_tag_array->access(block_addr, time, cache_index, mf); + CacheBlock *block = m_tag_array->get_block(cache_index); + block->set_status(MODIFIED, mf->get_access_sector_mask()); + return HIT; +} + +// Write hit: Write through +CacheRequestStatus DataCache::wr_hit_wt(uint64_t addr, uint32_t cache_index, + mem_fetch *mf, uint32_t time, + std::deque &events, + CacheRequestStatus status) { + if (miss_queue_full(0)) { + m_stats.inc_fail_stats(mf->get_access_type(), MISS_QUEUE_FULL); + return RESERVATION_FAIL; + } + uint64_t block_addr = m_config.get_block_addr(addr); + m_tag_array->access(block_addr, time, cache_index, mf); + CacheBlock *block = m_tag_array->get_block(cache_index); + block->set_status(MODIFIED, mf->get_access_sector_mask()); + + // Generate a write-through + send_write_request(mf, CacheEvent(WRITE_REQUEST_SENT), time, events); + return HIT; +} + +// Write hit: Write evict +CacheRequestStatus DataCache::wr_hit_we(uint64_t addr, uint32_t cache_index, + mem_fetch *mf, uint32_t time, + std::deque &events, + CacheRequestStatus status) { + if (miss_queue_full(0)) { + m_stats.inc_fail_stats(mf->get_access_type(), MISS_QUEUE_FULL); + return RESERVATION_FAIL; + } + CacheBlock *block = m_tag_array->get_block(cache_index); + send_write_request(mf, CacheEvent(WRITE_REQUEST_SENT), time, events); + block->set_status(INVALID, mf->get_access_sector_mask()); + return HIT; +} + +/*** WRITE-miss functions (Set by config file) ***/ +// Write miss: Write allocate naive +CacheRequestStatus DataCache::wr_miss_wa_naive(uint64_t addr, + uint32_t cache_index, + mem_fetch *mf, uint32_t time, + std::deque &events, + CacheRequestStatus status) { + uint64_t block_addr = m_config.get_block_addr(addr); + uint64_t mshr_addr = m_config.get_mshr_addr(addr); + bool mshr_hit = m_mshrs->probe(mshr_addr); + bool mshr_avail = !m_mshrs->full(mshr_addr); + if (miss_queue_full(2)) { + m_stats.inc_fail_stats(mf->get_access_type(), MISS_QUEUE_FULL); + return RESERVATION_FAIL; + } else if (mshr_hit && !mshr_avail) { + m_stats.inc_fail_stats(mf->get_access_type(), MSHR_MERGE_ENTRY_FAIL); + return RESERVATION_FAIL; + } else if (!mshr_hit && !mshr_avail) { + m_stats.inc_fail_stats(mf->get_access_type(), MSHR_ENTRY_FAIL); + return RESERVATION_FAIL; + } + send_write_request(mf, CacheEvent(WRITE_REQUEST_SENT), time, events); + mem_fetch *new_mf = new mem_fetch( + mf->get_addr(), m_write_alloc_type, READ_REQUEST, m_config.get_atom_size()); + new_mf->set_access_sector_mask(mf->get_access_sector_mask()); + new_mf->set_core_id(mf->get_core_id()); + bool do_miss = false; + bool wb = false; + EvictedBlockInfo evicted; + + // Send read request resulting from write miss + send_read_request(addr, block_addr, cache_index, new_mf, time, do_miss, wb, + evicted, events, false, true); + if (do_miss) { + if (wb && (m_config.get_write_policy() != WRITE_THROUGH)) { + assert(status == MISS); + write_back(evicted, time, events); + } + return MISS; + } + return RESERVATION_FAIL; +} + +// Write miss: Write allocate no write allocate +CacheRequestStatus DataCache::wr_miss_no_wa(uint64_t addr, uint32_t cache_index, + mem_fetch *mf, uint32_t time, + std::deque &events, + CacheRequestStatus status) { + if (miss_queue_full(0)) { + m_stats.inc_fail_stats(mf->get_access_type(), MISS_QUEUE_FULL); + return RESERVATION_FAIL; + } + send_write_request(mf, CacheEvent(WRITE_REQUEST_SENT), time, events); + return MISS; +} + +CacheRequestStatus DataCache::rd_hit_base(uint64_t addr, uint32_t cache_index, + mem_fetch *mf, uint32_t time, + std::deque &events, + CacheRequestStatus status) { + uint64_t block_addr = m_config.get_block_addr(addr); + m_tag_array->access(block_addr, time, cache_index, mf); + if (mf->is_atomic()) { + CacheBlock *block = m_tag_array->get_block(cache_index); + block->set_status(MODIFIED, mf->get_access_sector_mask()); + } + return HIT; +} + +CacheRequestStatus DataCache::rd_miss_base(uint64_t addr, uint32_t cache_index, + mem_fetch *mf, uint32_t time, + std::deque &events, + CacheRequestStatus status) { + if (miss_queue_full(1)) { + mf->current_state = "MISS_QUEUE_FULL"; + m_stats.inc_fail_stats(mf->get_access_type(), MISS_QUEUE_FULL); + return RESERVATION_FAIL; + } + uint64_t block_addr = m_config.get_block_addr(addr); + bool do_miss = false; + bool wb = false; + EvictedBlockInfo evicted; + send_read_request(addr, block_addr, cache_index, mf, time, do_miss, wb, + evicted, events, false, true); + if (do_miss) { + if (wb && (m_config.get_write_policy() != WRITE_THROUGH)) { + write_back(evicted, time, events); + } + return MISS; + } + return RESERVATION_FAIL; +} \ No newline at end of file diff --git a/PyTorchSimBackend/src/Cache_stats.cc b/PyTorchSimBackend/src/Cache_stats.cc new file mode 100644 index 00000000..eacd0c2f --- /dev/null +++ b/PyTorchSimBackend/src/Cache_stats.cc @@ -0,0 +1,244 @@ +#include "Cache_stats.h" +#include "Memfetch.h" + +CacheStats::CacheStats() { + m_stats.resize(NUM_MEM_ACCESS_TYPE); + m_fail_stats.resize(NUM_MEM_ACCESS_TYPE); + for (int i = 0; i < NUM_MEM_ACCESS_TYPE; i++) { + m_stats[i].resize(NUM_CACHE_REQUEST_STATUS, 0); + m_fail_stats[i].resize(NUM_CACHE_RESERVATION_FAIL_REASON, 0); + } + m_cache_port_available_cycles = 0; + m_cache_data_port_busy_cycles = 0; + m_cache_fill_port_busy_cycles = 0; + + m_prev_hit = 0; + m_prev_miss = 0; +} + +void CacheStats::clear() { + for (int i = 0; i < NUM_MEM_ACCESS_TYPE; i++) { + std::fill(m_stats[i].begin(), m_stats[i].end(), 0); + std::fill(m_fail_stats[i].begin(), m_fail_stats[i].end(), 0); + } + m_cache_port_available_cycles = 0; + m_cache_data_port_busy_cycles = 0; + m_cache_fill_port_busy_cycles = 0; +} + +void CacheStats::inc_stats(int access_type, int access_outcome) { + assert(check_valid(access_type, access_outcome)); + m_stats[access_type][access_outcome]++; +} + +void CacheStats::inc_fail_stats(int access_type, int fail_outcome) { + assert(check_fail_valid(access_type, fail_outcome)); + m_fail_stats[access_type][fail_outcome]++; +} + +CacheRequestStatus CacheStats::select_stats_status( + CacheRequestStatus probe, CacheRequestStatus access) const { + if (probe == HIT_RESERVED && access != RESERVATION_FAIL) + return probe; + else if (probe == SECTOR_MISS && access == MISS) + return probe; + else + return access; +} + +uint64_t &CacheStats::operator()(int access_type, int access_outcome, + bool fail_outcome) { + if (fail_outcome) { + assert(check_fail_valid(access_type, access_outcome)); + return m_fail_stats[access_type][access_outcome]; + } else { + assert(check_valid(access_type, access_outcome)); + return m_stats[access_type][access_outcome]; + } +} + +uint64_t CacheStats::operator()(int access_type, int access_outcome, + bool fail_outcome) const { + if (fail_outcome) { + assert(check_fail_valid(access_type, access_outcome)); + return m_fail_stats[access_type][access_outcome]; + } else { + assert(check_valid(access_type, access_outcome)); + return m_stats[access_type][access_outcome]; + } +} + +CacheStats CacheStats::operator+(const CacheStats &other) { + CacheStats sum; + for (int i = 0; i < NUM_MEM_ACCESS_TYPE; i++) { + for (int j = 0; j < NUM_CACHE_REQUEST_STATUS; j++) { + sum.m_stats[i][j] = m_stats[i][j] + other.m_stats[i][j]; + sum.m_fail_stats[i][j] = m_fail_stats[i][j] + other.m_fail_stats[i][j]; + } + } + sum.m_cache_port_available_cycles = + m_cache_port_available_cycles + other.m_cache_port_available_cycles; + sum.m_cache_data_port_busy_cycles = + m_cache_data_port_busy_cycles + other.m_cache_data_port_busy_cycles; + sum.m_cache_fill_port_busy_cycles = + m_cache_fill_port_busy_cycles + other.m_cache_fill_port_busy_cycles; + return sum; +} + +CacheStats &CacheStats::operator+=(const CacheStats &other) { + for (int i = 0; i < NUM_MEM_ACCESS_TYPE; i++) { + for (int j = 0; j < NUM_CACHE_REQUEST_STATUS; j++) { + m_stats[i][j] += other.m_stats[i][j]; + m_fail_stats[i][j] += other.m_fail_stats[i][j]; + } + } + m_cache_port_available_cycles += other.m_cache_port_available_cycles; + m_cache_data_port_busy_cycles += other.m_cache_data_port_busy_cycles; + m_cache_fill_port_busy_cycles += other.m_cache_fill_port_busy_cycles; + return *this; +} + +uint64_t CacheStats::get_hit() const { + uint64_t hit = 0; + for (int i = 0; i < NUM_MEM_ACCESS_TYPE; i++) { + for (int j = 0; j < NUM_CACHE_REQUEST_STATUS; j++) { + if (j == HIT) hit += m_stats[i][j]; + } + } + return hit; +} + +uint64_t CacheStats::get_read_hit() const { + uint64_t hit = 0; + mem_access_type types[] = {GLOBAL_ACC_R}; + CacheRequestStatus status[] = {HIT, HIT_RESERVED}; + for (int i = 0; i < 1; i++) { + for (int j = 0; j < 2; j++) { + hit += m_stats[types[i]][status[j]]; + } + } + return hit; +} + +uint64_t CacheStats::get_write_hit() const { + uint64_t hit = 0; + mem_access_type types[] = {GLOBAL_ACC_W, L2_CACHE_WA, L2_CACHE_WB}; + CacheRequestStatus status[] = {HIT, HIT_RESERVED}; + for (int i = 0; i < 3; i++) { + for (int j = 0; j < 2; j++) { + hit += m_stats[types[i]][status[j]]; + } + } + return hit; +} + + +uint64_t CacheStats::get_miss() const { + uint64_t miss = 0; + for (int i = 0; i < NUM_MEM_ACCESS_TYPE; i++) { + for (int j = 0; j < NUM_CACHE_REQUEST_STATUS; j++) { + if (j == MISS || j == SECTOR_MISS) miss += m_stats[i][j]; + } + } + return miss; +} + +uint64_t CacheStats::get_read_miss() const { + uint64_t miss = 0; + mem_access_type types[] = {GLOBAL_ACC_R}; + CacheRequestStatus status[] = {MISS, SECTOR_MISS}; + for (int i = 0; i < 1; i++) { + for (int j = 0; j < 2; j++) { + miss += m_stats[types[i]][status[j]]; + } + } + return miss; +} + +uint64_t CacheStats::get_write_miss() const { + uint64_t miss = 0; + mem_access_type types[] = {GLOBAL_ACC_W, L2_CACHE_WA, L2_CACHE_WB}; + CacheRequestStatus status[] = {MISS, SECTOR_MISS}; + for (int i = 0; i < 3; i++) { + for (int j = 0; j < 2; j++) { + miss += m_stats[types[i]][status[j]]; + } + } + return miss; +} + +uint64_t CacheStats::get_accesses() const { + uint64_t access = 0; + for (int i = 0; i < NUM_MEM_ACCESS_TYPE; i++) { + for (int j = 0; j < NUM_CACHE_REQUEST_STATUS; j++) { + if(j == HIT || j == MISS || j == SECTOR_MISS || j == HIT_RESERVED) + access += m_stats[i][j]; + } + } + return access; +} + +uint64_t CacheStats::get_interval_hit() { + uint64_t prev_hit = m_prev_hit; + m_prev_hit = get_hit(); + + return m_prev_hit - prev_hit; +} + +uint64_t CacheStats::get_interval_miss() { + uint64_t prev_miss = m_prev_miss; + m_prev_miss = get_miss(); + + return m_prev_miss - prev_miss; +} + +void CacheStats::print_stats(FILE *out, const char *cache_name) const { + uint64_t hit = get_hit(); + uint64_t miss = get_miss(); + fprintf(out, "\tCache Hit : %lu, Cache Miss : %lu, Hit Ratio : %.2f\n", hit, + miss, (float)hit / (get_accesses())); + std::vector total_access; + total_access.resize(NUM_MEM_ACCESS_TYPE, 0); + for (int type = 0; type < NUM_MEM_ACCESS_TYPE; type++) { + for (int status = 0; status < NUM_CACHE_REQUEST_STATUS; status++) { + fprintf(out, "\t%s[%s][%s] = %lu\n", cache_name, + mem_access_type_str[type], cache_request_status_str[status], + m_stats[type][status]); + if (status != RESERVATION_FAIL && status != MSHR_HIT) + total_access[type] += m_stats[type][status]; + } + } + for (int type = 0; type < NUM_MEM_ACCESS_TYPE; type++) { + fprintf(out, "\t%s[%s][TOTAL] = %u\n", cache_name, + mem_access_type_str[type], total_access[type]); + } +} + +void CacheStats::print_fail_stats(FILE *out, const char *cache_name) const { + for (int type = 0; type < NUM_MEM_ACCESS_TYPE; type++) { + for (int status = 0; status < NUM_CACHE_RESERVATION_FAIL_REASON; status++) { + fprintf(out, "\t%s[%s][%s] = %lu\n", cache_name, + mem_access_type_str[type], + cache_reservation_fail_reason_str[status], + m_fail_stats[type][status]); + } + } +} + +void CacheStats ::print_energy_stats(FILE *out, const char *cache_name) const { + fprintf(out, "%s_RH: %lu\n", cache_name, get_read_hit()); + fprintf(out, "%s_RM: %lu\n", cache_name, get_read_miss()); + fprintf(out, "%s_WH: %lu\n", cache_name, get_write_hit()); + fprintf(out, "%s_WM: %lu\n", cache_name, get_write_miss()); +} + +bool CacheStats::check_valid(int access_type, int access_outcome) const { + return (access_type >= 0 && access_type < NUM_MEM_ACCESS_TYPE && + access_outcome >= 0 && access_outcome < NUM_CACHE_REQUEST_STATUS); +} + +bool CacheStats::check_fail_valid(int access_type, int fail_outcome) const { + return (access_type >= 0 && access_type < NUM_MEM_ACCESS_TYPE && + fail_outcome >= 0 && + fail_outcome < NUM_CACHE_RESERVATION_FAIL_REASON); +} diff --git a/PyTorchSimBackend/src/Common.cc b/PyTorchSimBackend/src/Common.cc index 8c437b14..5581f8bd 100644 --- a/PyTorchSimBackend/src/Common.cc +++ b/PyTorchSimBackend/src/Common.cc @@ -19,10 +19,40 @@ SimulationConfig initialize_config(json config) { /* Core configs */ parsed_config.num_cores = config["num_cores"]; + if (config.contains("core_type")) { + std::vector core_types = config["core_type"].get>(); + + if (core_types.size() != parsed_config.num_cores) + throw std::runtime_error("Mismatch between num_cores and core_type list size"); + + for (const auto& core_type : core_types) { + if (core_type == "ws_mesh") { + parsed_config.core_type.push_back(CoreType::WS_MESH); + } else if (core_type == "stonne") { + parsed_config.core_type.push_back(CoreType::STONNE); + } else { + throw std::runtime_error(fmt::format("Not implemented core type: {}", core_type)); + } + } + } else { + /* Used WS as default */ + for (int i=0; i(config, "core_print_interval"); + /* Stonne config */ + if (config.contains("stonne_config_path")) + parsed_config.stonne_config_path = config["stonne_config_path"]; + /* DRAM config */ if ((std::string)config["dram_type"] == "simple") parsed_config.dram_type = DramType::SIMPLE; @@ -43,9 +73,29 @@ SimulationConfig initialize_config(json config) { parsed_config.dram_req_size = config["dram_req_size"]; if (config.contains("dram_print_interval")) parsed_config.dram_print_interval = config["dram_print_interval"]; + if(config.contains("dram_nbl")) + parsed_config.dram_nbl = config["dram_nbl"]; if (config.contains("dram_num_partitions")) parsed_config.dram_num_partitions = config["dram_num_partitions"]; + /* L2D config */ + if (config.contains("l2d_type")) { + if ((std::string)config["l2d_type"] == "nocache") + parsed_config.l2d_type = L2CacheType::NOCACHE; + else if ((std::string)config["l2d_type"] == "datacache") + parsed_config.l2d_type = L2CacheType::DATACACHE; + else + throw std::runtime_error(fmt::format("Not implemented l2 cache type {} ", + (std::string)config["l2d_type"])); + } else { + parsed_config.l2d_type = L2CacheType::NOCACHE; + } + + if (config.contains("l2d_config")) + parsed_config.l2d_config_str = config["l2d_config"]; + if (config.contains("l2d_hit_latency")) + parsed_config.l2d_config_str = config["l2d_hit_latency"]; + /* Icnt config */ if ((std::string)config["icnt_type"] == "simple") parsed_config.icnt_type = IcntType::SIMPLE; @@ -72,14 +122,14 @@ SimulationConfig initialize_config(json config) { std::string core_partition = "core_" + std::to_string(i); uint32_t partition_id = uint32_t(config["partition"][core_partition]); parsed_config.partiton_map[i] = partition_id; - spdlog::info("CPU {}: Partition {}", i, partition_id); + spdlog::info("[Config/Core] CPU {}: Partition {}", i, partition_id); } } else { /* Default: all partition 0 */ for (int i=0; i(Opcode::COUNT), 0); + _stat_tot_skipped_inst.resize(static_cast(Opcode::COUNT), 0); } bool Core::can_issue(const std::shared_ptr& op) { /* Check SRAM is enough to run tile */ - return op->get_required_sram_size() + _used_sram_size <= _sram_size && _tiles.size() < 2; + return _tiles.size() < 4 && !op->is_stonne_tile(); } void Core::issue(std::shared_ptr op) { if (op->get_instructions().size()){ spdlog::trace("[Core {}][{}] New Tile is issued, remain sram: {} Required size: {}, Free size: {}", - _id, _core_cycle, _sram_size-_used_sram_size, op->get_required_sram_size(), op->get_instructions().back()->get_free_sram_size()); + _id, _core_cycle, _sram_size-_used_sram_size, op->get_required_sram_size(), + op->get_instructions().back()->get_free_sram_size()); } else { spdlog::trace("[Core {}][{}] New Tile is issued, remain sram: {} Required size: {}", _id, _core_cycle, _sram_size-_used_sram_size, op->get_required_sram_size()); } //_used_sram_size += op->get_required_sram_size(); + for (const auto& inst : op->get_instructions()) { + if (inst->is_ready()) + op->enqueue_ready(inst); + } _tiles.push_back(std::move(op)); } @@ -37,37 +49,74 @@ std::shared_ptr Core::pop_finished_tile() { return result; } -void Core::compute_cycle() { - for (int i=0; i>& Core::get_compute_pipeline(int compute_type) { + if (compute_type == VECTOR_UNIT) + return _vu_compute_pipeline; + else if (compute_type == MATMUL || compute_type == PRELOAD) { + uint32_t sa_idx = _systolic_array_rr; + _systolic_array_rr = (_systolic_array_rr + 1) % _num_systolic_array_per_core; + return _sa_compute_pipeline.at(sa_idx); + } + else { + spdlog::error("Undefined compute type"); + exit(EXIT_FAILURE); + } +} + +void Core::vu_cycle() { + bool retry = true; + while (retry) { + if (!_vu_compute_pipeline.empty()) { + _stat_vu_compute_cycle++; + if(_vu_compute_pipeline.front()->finish_cycle <= _core_cycle) { + int bubble = _vu_compute_pipeline.front()->bubble_cycle; + _stat_vu_compute_idle_cycle += bubble; + _stat_vu_compute_cycle -= bubble; + finish_instruction(_vu_compute_pipeline.front()); + _vu_compute_pipeline.pop(); + } else { + retry = false; + } + } else { + _stat_vu_compute_idle_cycle++; + retry = false; + } + } +} + +void Core::sa_cycle() { + for (int i=0; i<_num_systolic_array_per_core; i++) { bool retry = true; while (retry) { - if (!target_pipeline.empty()) { - _stat_compute_cycle[i]++; - if(target_pipeline.front()->finish_cycle <= _core_cycle) { - int bubble = target_pipeline.front()->bubble_cycle; - _stat_compute_idle_cycle[i] += bubble; - _stat_compute_cycle[i] -= bubble; - finish_instruction(target_pipeline.front()); - target_pipeline.pop(); + if (!_sa_compute_pipeline.at(i).empty()) { + if(_sa_compute_pipeline.at(i).front()->finish_cycle <= _core_cycle) { + int bubble = _sa_compute_pipeline.at(i).front()->bubble_cycle; + _stat_sa_compute_idle_cycle.at(i) += bubble; + _stat_sa_compute_cycle.at(i) -= bubble; + finish_instruction(_sa_compute_pipeline.at(i).front()); + _sa_compute_pipeline.at(i).pop(); } else { + _stat_sa_compute_cycle.at(i)++; retry = false; } } else { - _stat_compute_idle_cycle[i]++; + _stat_sa_compute_idle_cycle.at(i)++; retry = false; } } } } +void Core::compute_cycle() { + vu_cycle(); + sa_cycle(); +} + void Core::dma_cycle() { /* Check finished dma operation */ - for (int i=0; i<_dma_waiting_queue.size(); i++){ - std::shared_ptr& instruction = _dma_waiting_queue.at(i); - /* Pass not finished instruction */ - if (instruction->get_waiting_request()) - continue; + while(_dma_finished_queue.size()) { + std::shared_ptr& instruction = _dma_finished_queue.at(0); + assert(instruction->get_waiting_request()==0); /* Finish DMA read instruction */ if (instruction->is_dma_read() && !instruction->is_async_dma()) @@ -75,24 +124,22 @@ void Core::dma_cycle() { /* Set tag table of async dma load */ if (instruction->is_dma_read() && instruction->is_async_dma()) { - std::ostringstream oss; - auto key = std::make_pair(instruction->get_addr_name(), instruction->get_tag_idx_list()); + auto& key = instruction->get_tag_id(); assert(!_tma.get_tag_finish(instruction->subgraph_id, key)); _tma.set_tag_finish(instruction->subgraph_id, key); - for (const auto& idx : instruction->get_tag_idx_list()) - oss << idx << ", "; - spdlog::trace("[Core {}][{}] {} ASYNC FINISHED, Used sram: {}, Release sram: {}, subgraph_id: {} addr_name: {} tag_idx_list: {}", + spdlog::trace("[Core {}][{}] {} ASYNC FINISHED, Used sram: {}, Release sram: {}, subgraph_id: {} addr_name: {} tag_id: {} tag_idx_list: {} tag_stride_list: {}", _id, _core_cycle, opcode_to_string(instruction->get_opcode()), _used_sram_size, instruction->get_free_sram_size(), - instruction->subgraph_id, instruction->get_addr_name(), oss.str()); + instruction->subgraph_id, instruction->get_addr_name(), + fmt::format("[{}]", fmt::join(instruction->get_tag_id(), ", ")), + fmt::format("[{}]", fmt::join(instruction->get_tag_idx_list(), ", ")), + fmt::format("[{}]", fmt::join(instruction->get_tag_stride_list(), ", "))); for (auto & wait_inst : _tma.get_tag_waiter(instruction->subgraph_id, key)) { + _tma.mark_tag_used(instruction->subgraph_id, key); finish_instruction(wait_inst); } } - - /* Erase the instruction in DMA waiting queue */ - _dma_waiting_queue.erase(_dma_waiting_queue.begin() + i); - i--; + _dma_finished_queue.erase(_dma_finished_queue.begin()); } if (_tma.is_finished()) { @@ -104,21 +151,20 @@ void Core::dma_cycle() { finish_instruction(finished_inst); } else if (finished_inst->is_dma_read() && finished_inst->is_async_dma()) { /* Register tag table for async dma load */ - _tma.register_tag(finished_inst->subgraph_id, - std::make_pair(finished_inst->get_addr_name(), finished_inst->get_tag_idx_list())); + _tma.register_tag(finished_inst->subgraph_id, finished_inst->get_tag_id()); finish_instruction(finished_inst); } else if(!finished_inst->is_dma_read()) { spdlog::error("[Core {}][{}] TMA instruction in not valid", _id, _core_cycle); exit(EXIT_FAILURE); } else if (finished_inst->get_opcode() == Opcode::BAR) { - std::ostringstream oss; - for (const auto& idx : finished_inst->get_tag_idx_list()) - oss << idx << ", "; - spdlog::trace("[Core {}][{}] {} FINISHED, addr_name: {} tag_list: {}", _id, _core_cycle, - opcode_to_string(finished_inst->get_opcode()), finished_inst->get_addr_name(), oss.str()); + spdlog::trace("[Core {}][{}] {} FINISHED, addr_name: {} tag_id: {} tag_idx_list: {} tag_stride_list: {}", _id, _core_cycle, + opcode_to_string(finished_inst->get_opcode()), finished_inst->get_addr_name(), + fmt::format("[{}]", fmt::join(finished_inst->get_tag_id(), ", ")), + fmt::format("[{}]", fmt::join(finished_inst->get_tag_idx_list(), ", ")), + fmt::format("[{}]", fmt::join(finished_inst->get_tag_stride_list(), ", "))); } /*Pass to waiting queue */ - _dma_waiting_queue.push_back(std::move(finished_inst)); + _dma_waiting_queue[finished_inst.get()] = std::move(finished_inst); } /* Issue new DMA operation */ @@ -136,11 +182,10 @@ void Core::dma_cycle() { return; } } - /* Generate MemoryAccess */ - std::vector access_vec = _tma.get_memory_access(); - for (auto access : access_vec) { - access->core_id = _id; - access->start_cycle = _core_cycle; + /* Generate memfetch */ + auto access_vec = _tma.get_memory_access(); + for (auto access : *access_vec) { + access->set_start_cycle(_core_cycle); _request_queue.push(access); } @@ -160,26 +205,52 @@ void Core::cycle() { bool issued = false; for (int i=0; i<_tiles.size() && !issued; i++) { - auto& instructions = _tiles[i]->get_instructions(); - for (int j=0; jget_ready_instructions(); + for (auto it=instructions.begin(); it!=instructions.end();) { + auto& inst = *it; /* Skip instruction is not ready */ - if (!inst->is_ready()) - continue; + //if (!inst->is_ready()) + // continue; switch (inst->get_opcode()) { case Opcode::MOVIN: { - std::ostringstream oss; - for (const auto& idx : inst->get_tag_idx_list()) - oss << idx << ", "; - spdlog::trace("[Core {}][{}] {} ISSUED, free_sram_size: {} addr_name: {} tag_idx_list: {}", _id, _core_cycle, - opcode_to_string(inst->get_opcode()), inst->get_free_sram_size(), - inst->get_addr_name(), oss.str()); - _ld_inst_queue.push(inst); - issued = true; + /* Check another MOVIN with same tag is issued */ + auto& key = inst->get_tag_id(); + if (inst->is_sparse_inst()) { + _tma.register_tag(inst->subgraph_id, key); + _tma.set_tag_sparse(inst->subgraph_id, key); + finish_instruction(inst); + issued = true; + _stat_tot_skipped_inst.at(static_cast(inst->get_opcode()))++; + break; + } else if (inst->is_async_dma() && _tma.tag_key_exist(inst->subgraph_id, key)) { + bool finished = _tma.get_tag_finish(inst->subgraph_id, key); + if (finished) + finish_instruction(inst); + else + _tma.register_tag_waiter(inst->subgraph_id, key, inst); + spdlog::trace("[Core {}][{}] {} SKIPPED, free_sram_size: {} addr_name: {} tag_id: {} tag_idx_list: {} tag_stride_list: {}", _id, _core_cycle, + opcode_to_string(inst->get_opcode()), inst->get_free_sram_size(), + inst->get_addr_name(), + fmt::format("[{}]", fmt::join(inst->get_tag_id(), ", ")), + fmt::format("[{}]", fmt::join(inst->get_tag_idx_list(), ", ")), + fmt::format("[{}]", fmt::join(inst->get_tag_stride_list(), ", "))); + issued = true; + _stat_tot_skipped_inst.at(static_cast(inst->get_opcode()))++; + break; + } else { + spdlog::trace("[Core {}][{}] {} ISSUED, free_sram_size: {} addr_name: {} tag_id: {} tag_idx_list: {} tag_stride_list: {}", _id, _core_cycle, + opcode_to_string(inst->get_opcode()), inst->get_free_sram_size(), + inst->get_addr_name(), + fmt::format("[{}]", fmt::join(inst->get_tag_id(), ", ")), + fmt::format("[{}]", fmt::join(inst->get_tag_idx_list(), ", ")), + fmt::format("[{}]", fmt::join(inst->get_tag_stride_list(), ", "))); + _ld_inst_queue.push(inst); + issued = true; + break; + } } - break; case Opcode::MOVOUT: spdlog::trace("[Core {}][{}] {} ISSUED, free_sram_size: {}", _id, _core_cycle, opcode_to_string(inst->get_opcode()), inst->get_free_sram_size()); @@ -188,31 +259,54 @@ void Core::cycle() { break; case Opcode::COMP: { - auto& target_pipeline = _compute_pipeline.at(inst->get_compute_type()); - if (target_pipeline.empty()) + auto& target_pipeline = get_compute_pipeline(inst->get_compute_type()); + if (target_pipeline.empty()) { inst->finish_cycle = _core_cycle + inst->get_compute_cycle(); - else { + inst->bubble_cycle = inst->get_overlapping_cycle(); + } else { int overlapped_cycle = std::min(target_pipeline.back()->finish_cycle - _core_cycle, inst->get_overlapping_cycle()); int bubble_cycle = inst->get_overlapping_cycle() - overlapped_cycle; inst->finish_cycle = target_pipeline.back()->finish_cycle + inst->get_compute_cycle() - overlapped_cycle; inst->bubble_cycle = bubble_cycle; } - spdlog::trace("[Core {}][{}] {}-{} ISSUED, finsh at {}", _id, _core_cycle, - opcode_to_string(inst->get_opcode()), inst->get_compute_type(), inst->finish_cycle); - target_pipeline.push(inst); - issued = true; + if (inst->get_compute_cycle() == 0) { + inst->finish_instruction(); + static_cast(inst->get_owner())->inc_finished_inst(); + _stat_tot_skipped_inst.at(static_cast(inst->get_opcode()))++; + instructions.erase(it); + } else { + spdlog::trace("[Core {}][SA {}][{}] {}-{} ISSUED, finsh at {}", _id, _systolic_array_rr, _core_cycle, + opcode_to_string(inst->get_opcode()), inst->get_compute_type(), inst->finish_cycle); + target_pipeline.push(inst); + issued = true; + if (inst->get_compute_type()) { + _stat_gemm_inst++; + } + } } break; case Opcode::BAR: { - std::ostringstream oss; - auto key = std::make_pair(inst->get_addr_name(), inst->get_tag_idx_list()); - bool finished = _tma.get_tag_finish(inst->subgraph_id, key); - if (finished) { + auto& key = inst->get_tag_id(); + uint32_t finished = _tma.get_tag_finish(inst->subgraph_id, key); + if (finished == -1) { + for (auto child_inst : inst->get_child_inst()) { + if (child_inst->get_opcode() == Opcode::COMP && child_inst->get_compute_type() == MATMUL) { + child_inst->set_compute_cycle(0); + } + } + finish_instruction(inst); + } else if (finished != 0) { + _tma.mark_tag_used(inst->subgraph_id, key); finish_instruction(inst); } else { _tma.register_tag_waiter(inst->subgraph_id, key, inst); } + spdlog::trace("[Core {}][{}] {} ISSUED, addr_name: {} tag_id: {} tag_idx_list: {} tag_stride_list: {}", _id, _core_cycle, + opcode_to_string(inst->get_opcode()), inst->get_addr_name(), + fmt::format("[{}]", fmt::join(inst->get_tag_id(), ", ")), + fmt::format("[{}]", fmt::join(inst->get_tag_idx_list(), ", ")), + fmt::format("[{}]", fmt::join(inst->get_tag_stride_list(), ", "))); issued = true; } break; @@ -222,10 +316,11 @@ void Core::cycle() { } if (issued) { - auto it = instructions.begin() + j; // Position 2 is the third element + _stat_inst_count.at(static_cast(inst->get_opcode()))++; instructions.erase(it); break; } + it++; } } @@ -262,12 +357,12 @@ void Core::finish_instruction(std::shared_ptr& inst) { _id, _core_cycle, opcode_to_string(inst->get_opcode()), inst->get_compute_type(), _used_sram_size, inst->get_free_sram_size()); } else if (inst->get_opcode() != Opcode::BAR && inst->is_async_dma()){ - std::ostringstream oss; - for (const auto& idx : inst->get_tag_idx_list()) - oss << idx << ", "; - spdlog::trace("[Core {}][{}] {} ASYNC REGISTERED, Used sram: {}, Release sram: {} subgraph_id: {} addr_name: {} tag_idx_list: {}", + spdlog::trace("[Core {}][{}] {} ASYNC REGISTERED, Used sram: {}, Release sram: {} subgraph_id: {} addr_name: {} tag_id: {} tag_idx_list: {} tag_stride_list: {}", _id, _core_cycle, opcode_to_string(inst->get_opcode()), _used_sram_size, - inst->get_free_sram_size(), inst->subgraph_id, inst->get_addr_name(), oss.str()); + inst->get_free_sram_size(), inst->subgraph_id, inst->get_addr_name(), + inst->get_tag_id(), + fmt::format("[{}]", fmt::join(inst->get_tag_idx_list(), ", ")), + fmt::format("[{}]", fmt::join(inst->get_tag_stride_list(), ", "))); } else if ((inst->get_opcode() == Opcode::MOVIN || inst->get_opcode() == Opcode::MOVOUT) && !inst->is_async_dma()) { spdlog::trace("[Core {}][{}] {} FINISHED, free_sram_size: {} addr_name: {}", _id, _core_cycle, opcode_to_string(inst->get_opcode()), inst->get_free_sram_size(), @@ -279,28 +374,40 @@ void Core::finish_instruction(std::shared_ptr& inst) { bool Core::running() { bool running = false; running = running || _tiles.size() > 0; - for (int i=0; i 0; } +bool Core::has_memory_request() { + return !_request_queue.empty(); +} void Core::pop_memory_request() { _request_queue.pop(); } -void Core::push_memory_response(MemoryAccess *response) { - Instruction * owner_inst = response->owner_instruction; - - assert(owner_inst); +void Core::push_memory_response(mem_fetch* response) { + Instruction* owner_inst = static_cast(response->get_custom_data()); assert(owner_inst->get_waiting_request()); owner_inst->dec_waiting_request(); + if (!owner_inst->get_waiting_request()) { + auto it = _dma_waiting_queue.find(owner_inst); + if (it != _dma_waiting_queue.end()) { + std::shared_ptr moved_inst = std::move(it->second); + _dma_finished_queue.push_back(std::move(moved_inst)); + _dma_waiting_queue.erase(it); + } else { + assert(true || "Can't happend...!"); + } + } + _stat_mem_response++; delete response; } @@ -309,47 +416,65 @@ bool Core::can_issue_compute(std::shared_ptr& inst) { } void Core::print_stats() { + std::vector sa_utilization; update_stats(); - spdlog::info( - "Core [{}] : MatMul active cycle {} Vector active cycle {} ", - _id, _stat_tot_compute_cycle[SYSTOLIC_ARRAY], _stat_tot_compute_cycle[VECTOR_UNIT]); - spdlog::info( - "Core [{}] : TMA active cycle {} TMA idle cycle {} Systolic Array idle cycle {} Vector unit idle cycle {}", - _id, _stat_tot_tma_cycle, _stat_tot_tma_idle_cycle, _stat_tot_compute_idle_cycle[SYSTOLIC_ARRAY], _stat_compute_idle_cycle[VECTOR_UNIT]); - spdlog::info("Core [{}] : Systolic Array Utilization(%) {:.2f}, Vector Unit Utilization(%) {:.2f}, Total cycle: {}", - _id, static_cast(_stat_tot_compute_cycle[SYSTOLIC_ARRAY] * 100) / _core_cycle, - static_cast(_stat_tot_compute_cycle[VECTOR_UNIT] * 100) / _core_cycle, _core_cycle); + spdlog::info("===== Instructions count ====="); + for (int i=0; i < static_cast(Opcode::COUNT); i++) { + if (i == static_cast(Opcode::COMP)) + spdlog::info("Core [{}] : {} inst count {} (GEMM: {}, Vector: {}), skipped inst count {}", _id, opcode_to_string(static_cast(i)), _stat_inst_count.at(i), _stat_gemm_inst, _stat_inst_count.at(i) - _stat_gemm_inst, _stat_tot_skipped_inst.at(i)); + else + spdlog::info("Core [{}] : {} inst count {}, skipped inst count {}", _id, opcode_to_string(static_cast(i)), _stat_inst_count.at(i), _stat_tot_skipped_inst.at(i)); + } + spdlog::info("========= Core stat ========="); + for (int i=0; i<_num_systolic_array_per_core; i++) + sa_utilization.push_back(static_cast(_stat_tot_sa_compute_cycle.at(i) * 100) / _core_cycle); + for (int i=0; i<_num_systolic_array_per_core; i++) + spdlog::info("Core [{}] : Systolic array [{}] Utilization(%) {:.2f}, active cycle {}, idle cycle {}", _id, i, sa_utilization.at(i), + _stat_tot_sa_compute_cycle.at(i), _stat_tot_sa_compute_idle_cycle.at(i)); + float dram_bw = _config.dram_req_size * _stat_tot_mem_response * _config.core_freq / (_core_cycle * 1000); // B/cycle + spdlog::info("Core [{}] : TMA active cycle {} TMA idle cycle {} DRAM BW {:.3f} GB/s ({})", _id, _stat_tot_tma_cycle, _stat_tot_tma_idle_cycle, dram_bw, _stat_tot_mem_response); + spdlog::info("Core [{}] : Vector Unit Utilization(%) {:.2f}, active cycle {}, idle_cycle {}", _id, + static_cast(_stat_tot_vu_compute_cycle * 100) / _core_cycle, _stat_tot_vu_compute_cycle, _stat_tot_vu_compute_idle_cycle); + spdlog::info("Core [{}] : Numa hit count : {}, Numa miss count : {}", _id, _stat_numa_hit, _stat_numa_miss); + spdlog::info("Core [{}] : Total cycle {}", _id, _core_cycle); } void Core::print_current_stats() { + std::vector sa_utilization; + for (int i=0; i<_num_systolic_array_per_core; i++) + sa_utilization.push_back(static_cast(_stat_sa_compute_cycle.at(i) * 100) / _config.core_print_interval); + float dram_bw = _config.dram_req_size * _stat_mem_response * _config.core_freq / (_config.core_print_interval * 1000); // B/cycle auto level = spdlog::level::info; if(_id != 0) level = spdlog::level::debug; - spdlog::log(level, - "Core [{}] : MatMul active cycle {} Vector active cycle {} ", - _id, _stat_compute_cycle[SYSTOLIC_ARRAY], _stat_compute_cycle[VECTOR_UNIT]); - spdlog::log(level, - "Core [{}] : TMA active cycle {} TMA idle cycle {} Systolic Array idle cycle {} Vector unit idle cycle {}", - _id, _stat_tma_cycle, _stat_tma_idle_cycle, _stat_compute_idle_cycle[SYSTOLIC_ARRAY], _stat_compute_idle_cycle[VECTOR_UNIT]); - spdlog::log(level, - "Core [{}] : Systolic Array Utilization(%) {:.2f}, Vector Unit Utilization(%) {:.2f}, Total cycle: {}", - _id, static_cast(_stat_compute_cycle[SYSTOLIC_ARRAY] * 100) / _config.core_print_interval, - static_cast(_stat_compute_cycle[VECTOR_UNIT] * 100) / _config.core_print_interval, _core_cycle); + + spdlog::info("========= Core stat ========="); + for (int i=0; i<_num_systolic_array_per_core; i++) + spdlog::info("Core [{}] : Systolic array [{}] Utilization(%) {:.2f}, active cycle {}, idle cycle {}", _id, i, sa_utilization.at(i), + _stat_sa_compute_cycle.at(i), _stat_sa_compute_idle_cycle.at(i)); + spdlog::info("Core [{}] : TMA active cycle {} TMA idle cycle {} DRAM BW {:.3f} GB/s ({})", _id, _stat_tma_cycle, _stat_tma_idle_cycle, dram_bw, _stat_mem_response); + spdlog::info("Core [{}] : Vector Unit Utilization(%) {:.2f}, active cycle {}, idle_cycle {}", _id, + static_cast(_stat_vu_compute_cycle * 100) / _config.core_print_interval, _stat_vu_compute_cycle, _stat_vu_compute_idle_cycle); + spdlog::info("Core [{}] : Total cycle {}", _id, _core_cycle); update_stats(); } void Core::update_stats() { - _stat_tot_compute_cycle[SYSTOLIC_ARRAY] += _stat_compute_cycle[SYSTOLIC_ARRAY]; - _stat_tot_compute_cycle[VECTOR_UNIT] += _stat_compute_cycle[VECTOR_UNIT]; + for (int i=0; i<_num_systolic_array_per_core; i++) { + _stat_tot_sa_compute_cycle.at(i) += _stat_sa_compute_cycle.at(i); + _stat_tot_sa_compute_idle_cycle.at(i) += _stat_sa_compute_idle_cycle.at(i); + _stat_sa_compute_cycle.at(i) = 0; + _stat_sa_compute_idle_cycle.at(i) = 0; + } + + _stat_tot_vu_compute_cycle += _stat_vu_compute_cycle; _stat_tot_tma_cycle += _stat_tma_cycle; _stat_tot_tma_idle_cycle += _stat_tma_idle_cycle; - _stat_tot_compute_idle_cycle[SYSTOLIC_ARRAY] += _stat_compute_idle_cycle[SYSTOLIC_ARRAY]; - _stat_compute_idle_cycle[VECTOR_UNIT] += _stat_compute_idle_cycle[VECTOR_UNIT]; + _stat_tot_mem_response += +_stat_mem_response; - _stat_compute_cycle[SYSTOLIC_ARRAY] = 0; - _stat_compute_cycle[VECTOR_UNIT] = 0; + _stat_vu_compute_cycle = 0; _stat_tma_cycle = 0; _stat_tma_idle_cycle = 0; - _stat_compute_idle_cycle[SYSTOLIC_ARRAY] = 0; - _stat_compute_idle_cycle[VECTOR_UNIT] = 0; + _stat_vu_compute_idle_cycle = 0; + _stat_mem_response = 0; } \ No newline at end of file diff --git a/PyTorchSimBackend/src/Core.h b/PyTorchSimBackend/src/Core.h deleted file mode 100644 index 77f13fec..00000000 --- a/PyTorchSimBackend/src/Core.h +++ /dev/null @@ -1,74 +0,0 @@ -#pragma once -#include - -#include -#include - -#include "Dram.h" -#include "Tile.h" -#include "SimulationConfig.h" -#include "TMA.h" - -class Core { - public: - Core(uint32_t id, SimulationConfig config); - ~Core() = default; - bool running(); - bool can_issue(const std::shared_ptr& op); - void issue(std::shared_ptr tile); - std::shared_ptr pop_finished_tile(); - void cycle(); - void compute_cycle(); - void dma_cycle(); - bool has_memory_request(); - void pop_memory_request(); - MemoryAccess* top_memory_request() { return _request_queue.front(); } - void push_memory_response(MemoryAccess* response); - void print_stats(); - void print_current_stats(); - void finish_instruction(std::shared_ptr& inst); - cycle_type get_compute_cycles() { return _stat_tot_compute_cycle[SYSTOLIC_ARRAY]; } - enum { - VECTOR_UNIT, - SYSTOLIC_ARRAY, - NR_COMPUTE_UNIT - }; - - protected: - bool can_issue_compute(std::shared_ptr& inst); - void update_stats(); - - /* Core id & config file */ - const uint32_t _id; - const SimulationConfig _config; - size_t _sram_size; - size_t _used_sram_size; - - /* TMA Unit */ - TMA _tma; - - /* cycle */ - cycle_type _core_cycle; - cycle_type _stat_tot_compute_cycle[NR_COMPUTE_UNIT] = {0, }; - cycle_type _stat_tot_tma_cycle = 0; - cycle_type _stat_tot_tma_idle_cycle = 0; - cycle_type _stat_tot_compute_idle_cycle[NR_COMPUTE_UNIT] = {0, }; - - cycle_type _stat_compute_cycle[NR_COMPUTE_UNIT] = {0, }; - cycle_type _stat_tma_cycle = 0; - cycle_type _stat_tma_idle_cycle = 0; - cycle_type _stat_compute_idle_cycle[NR_COMPUTE_UNIT] = {0, }; - - std::vector> _tiles; - std::queue> _finished_tiles; - - std::vector>> _compute_pipeline; - std::queue> _ld_inst_queue; - std::queue> _st_inst_queue; - - std::vector> _dma_waiting_queue; - /* Interconnect queue */ - std::queue _request_queue; - std::queue _response_queue; - uint32_t _waiting_write_reqs; -}; \ No newline at end of file diff --git a/PyTorchSimBackend/src/DelayQueue.cc b/PyTorchSimBackend/src/DelayQueue.cc new file mode 100644 index 00000000..fd1463fa --- /dev/null +++ b/PyTorchSimBackend/src/DelayQueue.cc @@ -0,0 +1,55 @@ +#include "DelayQueue.h" +#include "Memfetch.h" + +template +void DelayQueue::push(T data, int delay) { + assert(m_only_latency); + m_size++; + m_queue.push(QueueEntry{data, m_cycle + delay}); +} + +template +void DelayQueue::push(T data, int delay, int interval) { + assert(m_issued == false); + m_size++; + m_queue.push(QueueEntry{data, m_cycle + delay}); + if(!m_only_latency) m_issued = true; + m_interval = interval; +} + +template +void DelayQueue::pop() { + assert(arrived()); + m_queue.pop(); + m_size--; +} + +template +T DelayQueue::top() { + assert(arrived()); + return m_queue.front().data; +} + +template +bool DelayQueue::arrived() { + return !m_queue.empty() && (m_queue.front().finish_cycle <= m_cycle); +} + +template +bool DelayQueue::queue_empty() { + return m_queue.empty(); +} + +template +bool DelayQueue::full() { + return m_issued || (m_max_size > 0 && m_size >= m_max_size); +} + +template +void DelayQueue::cycle() { + if (m_interval > 0) m_interval--; + if (m_interval <= 0) m_issued = false; + m_cycle++; +} + +template class DelayQueue; diff --git a/PyTorchSimBackend/src/Dram.cc b/PyTorchSimBackend/src/Dram.cc index 1d564dc8..e604f73f 100644 --- a/PyTorchSimBackend/src/Dram.cc +++ b/PyTorchSimBackend/src/Dram.cc @@ -1,204 +1,211 @@ #include "Dram.h" -uint32_t Dram::get_channel_id(MemoryAccess* access) { +uint32_t Dram::get_channel_id(mem_fetch* access) { uint32_t channel_id; if (_n_ch_per_partition >= 16) - channel_id = ipoly_hash_function((new_addr_type)access->dram_address/_config.dram_req_size, 0, _n_ch_per_partition); + channel_id = ipoly_hash_function((new_addr_type)access->get_addr()/_req_size, 0, _n_ch_per_partition); else - channel_id = ipoly_hash_function((new_addr_type)access->dram_address/_config.dram_req_size, 0, 16) % _n_ch_per_partition; - - channel_id += ((access->numa_id % _n_partitions)* _n_ch_per_partition); + channel_id = ipoly_hash_function((new_addr_type)access->get_addr()/_req_size, 0, 16) % _n_ch_per_partition; + + channel_id += ((access->get_numa_id() % _n_partitions)* _n_ch_per_partition); return channel_id; } -/* FIXME: Simple DRAM has bugs */ -SimpleDram::SimpleDram(SimulationConfig config) - : _latency(config.dram_latency) { - _cycles = 0; - _config = config; +Dram::Dram(SimulationConfig config, cycle_type* core_cycle) { + _core_cycles = core_cycle; _n_ch = config.dram_channels; + _n_bl = config.dram_nbl; + _req_size = config.dram_req_size; _n_partitions = config.dram_num_partitions; _n_ch_per_partition = _n_ch / _n_partitions; - _waiting_queue.resize(_n_ch); - _response_queue.resize(_n_ch); -} - -bool SimpleDram::running() { return false; } + _config = config; -void SimpleDram::cycle() { - for (uint32_t ch = 0; ch < _n_ch; ch++) { - if (!_waiting_queue[ch].empty() && - _waiting_queue[ch].front().first <= _cycles) { - _response_queue[ch].push(_waiting_queue[ch].front().second); - _waiting_queue[ch].pop(); - } + spdlog::info("[Config/DRAM] DRAM Bandwidth {} GB/s, Freq: {} MHz, Channels: {}, Request_size: {}", config.max_dram_bandwidth(), config.dram_freq, _n_ch, _req_size); + /* Initialize DRAM Channels */ + for (int ch = 0; ch < _n_ch; ch++) { + m_to_crossbar_queue.push_back(std::queue()); + m_from_crossbar_queue.push_back(std::queue()); } - _cycles++; + /* Initialize L2 cache */ + _m_caches.resize(_n_ch); + if (config.l2d_type == L2CacheType::NOCACHE) { + std::string name = "No cache"; + spdlog::info("[Config/L2Cache] No L2 cache"); + for (int ch = 0; ch < _n_ch; ch++) + _m_caches[ch] = new NoL2Cache(name, _m_cache_config, ch, _core_cycles, &m_to_crossbar_queue[ch], &m_from_crossbar_queue[ch]); + } else if (config.l2d_type == L2CacheType::DATACACHE) { + std::string name = "L2 cache"; + _m_cache_config.init(config.l2d_config_str); + spdlog::info("[Config/L2Cache] Total Size: {} KB, Partition Size: {} KB, Set: {}, Assoc: {}, Line Size: {}B Sector Size: {}B", + _m_cache_config.get_total_size_in_kb() * _n_ch, _m_cache_config.get_total_size_in_kb(), + _m_cache_config.get_num_sets(), _m_cache_config.get_num_assoc(), + _m_cache_config.get_line_size(), _m_cache_config.get_sector_size()); + for (int ch = 0; ch < _n_ch; ch++) + _m_caches[ch] = new L2DataCache(name, _m_cache_config, ch, _core_cycles, _config.l2d_hit_latency, &m_to_crossbar_queue[ch], &m_from_crossbar_queue[ch]); + } else { + spdlog::error("[Config/L2D] Invalid L2 cache type...!"); + exit(EXIT_FAILURE); + } } -bool SimpleDram::is_full(uint32_t cid, MemoryAccess* request) { return false; } +DramRamulator2::DramRamulator2(SimulationConfig config, cycle_type* core_cycle) : Dram(config, core_cycle) { + /* Initialize DRAM Channels */ + _mem.resize(_n_ch); + for (int ch = 0; ch < _n_ch; ch++) { + _mem[ch] = std::make_unique( + ch, _n_ch, config.dram_config_path, "Ramulator2", _config.dram_print_interval, _n_bl); + } + _tx_log2 = log2(_req_size); + _tx_ch_log2 = log2(_n_ch_per_partition) + _tx_log2; +} -void SimpleDram::push(uint32_t cid, MemoryAccess* request) { - request->request = false; - std::pair entity; - entity.first = MAX(_cycles + _latency, _last_finish_cycle); - _last_finish_cycle = entity.first; - entity.second = request; - _waiting_queue[cid].push(entity); +bool DramRamulator2::running() { + for (int ch = 0; ch < _n_ch; ch++) { + if (mem_fetch* req = _mem[ch]->return_queue_top()) + return true; + if (mem_fetch* req = _m_caches[ch]->top()) + return true; + } + return false; } -bool SimpleDram::is_empty(uint32_t cid) { return _response_queue[cid].empty(); } +void DramRamulator2::cycle() { + for (int ch = 0; ch < _n_ch; ch++) { + _mem[ch]->cycle(); -MemoryAccess* SimpleDram::top(uint32_t cid) { - assert(!is_empty(cid)); - return _response_queue[cid].front(); -} + // From Cache to DRAM + if (mem_fetch* req = _m_caches[ch]->top()) { + _mem[ch]->push(req); + _m_caches[ch]->pop(); + } -void SimpleDram::pop(uint32_t cid) { - assert(!is_empty(cid)); - _response_queue[cid].pop(); + // From DRAM to Cache + if (mem_fetch* req = _mem[ch]->return_queue_top()) { + if(_m_caches[ch]->push(req)) + _mem[ch]->return_queue_pop(); + } + } } -DramRamulator::DramRamulator(SimulationConfig config) - : _mem(std::make_unique(config.dram_config_path, - config.num_cores, false)) { - _n_ch = config.dram_channels; - _config = config; - _cycles = 0; - _total_processed_requests.resize(_n_ch); - _processed_requests.resize(_n_ch); +void DramRamulator2::cache_cycle() { for (int ch = 0; ch < _n_ch; ch++) { - _total_processed_requests[ch] = 0; - _processed_requests[ch] = 0; + _m_caches[ch]->cycle(); } } -bool DramRamulator::running() { return false; } - -void DramRamulator::cycle() { - _mem->tick(); - _cycles++; - int interval = _config.dram_print_interval? _config.dram_print_interval: INT32_MAX; - int average = 0; - if (_cycles % interval == 0) { - for (int ch = 0; ch < _n_ch; ch++) { - float util = ((float)_processed_requests[ch]) / interval * 100; - _total_processed_requests[ch] += _processed_requests[ch]; - average += _processed_requests[ch]; - _processed_requests[ch] = 0; - } - spdlog::info("Avg DRAM: BW Util {:.2f}%", (float)average / (interval * _n_ch) * 100); - } +bool DramRamulator2::is_full(uint32_t cid, mem_fetch* request) { + return false; //m_from_crossbar_queue[cid].full(); Infinite length } -bool DramRamulator::is_full(uint32_t cid, MemoryAccess* request) { - return !_mem->isAvailable(cid, request->dram_address, request->write); +void DramRamulator2::push(uint32_t cid, mem_fetch* request) { + addr_type target_addr = (request->get_addr() >> _tx_ch_log2) << _tx_log2; + request->set_addr(target_addr); + m_from_crossbar_queue[cid].push(request); } -void DramRamulator::push(uint32_t cid, MemoryAccess* request) { - const addr_type atomic_bytes = _mem->getAtomicBytes(); - const addr_type target_addr = request->dram_address; - // align address - const addr_type start_addr = target_addr - (target_addr % atomic_bytes); - assert(start_addr == target_addr); - assert(request->size == atomic_bytes); - int count = 0; - request->request = false; - _mem->push(cid, target_addr, request->write, request->core_id, request); +bool DramRamulator2::is_empty(uint32_t cid) { + return m_to_crossbar_queue[cid].empty(); } -bool DramRamulator::is_empty(uint32_t cid) { return _mem->isEmpty(cid); } - -MemoryAccess* DramRamulator::top(uint32_t cid) { +mem_fetch* DramRamulator2::top(uint32_t cid) { assert(!is_empty(cid)); - return (MemoryAccess*)_mem->top(cid); + return m_to_crossbar_queue[cid].front(); } -void DramRamulator::pop(uint32_t cid) { +void DramRamulator2::pop(uint32_t cid) { assert(!is_empty(cid)); - _mem->pop(cid); - _processed_requests[cid]++; + m_to_crossbar_queue[cid].pop(); } -void DramRamulator::print_stat() { - uint32_t total_reqs = 0; +void DramRamulator2::print_stat() { for (int ch = 0; ch < _n_ch; ch++) { - _total_processed_requests[ch] += _processed_requests[ch]; - float util = ((float)_total_processed_requests[ch]) / _cycles * 100; - spdlog::info("DRAM CH[{}]: AVG BW Util {:.2f}%", ch, util); - total_reqs += _total_processed_requests[ch]; + _mem[ch]->print(stdout); } - float util = ((float)total_reqs / _n_ch) / _cycles * 100; - spdlog::info("DRAM: AVG BW Util {:.2f}%", util); - _mem->print_stats(); } -DramRamulator2::DramRamulator2(SimulationConfig config) { - _n_ch = config.dram_channels; - _req_size = config.dram_req_size; - _n_partitions = config.dram_num_partitions; - _n_ch_per_partition = _n_ch / _n_partitions; - _config = config; - _mem.resize(_n_ch); +void DramRamulator2::print_cache_stats() { + for (int ch = 0; ch < _n_ch; ch++) { + _m_caches[ch]->print_stats(); + } +} + +SimpleDRAM::SimpleDRAM(SimulationConfig config, cycle_type* core_cycle) : Dram(config, core_cycle) { + /* Initialize DRAM Channels */ + spdlog::info("[SimpleDRAM] DRAM latecny: {}", config.dram_latency); for (int ch = 0; ch < _n_ch; ch++) { - _mem[ch] = std::make_unique( - ch, _n_ch, config.dram_config_path, "Ramulator2", _config.dram_print_interval, 1); + _mem.push_back(std::make_unique>("SimpleDRAM", true, -1)); } + _latency = config.dram_latency; _tx_log2 = log2(_req_size); _tx_ch_log2 = log2(_n_ch_per_partition) + _tx_log2; } -bool DramRamulator2::running() { +bool SimpleDRAM::running() { + for (int ch = 0; ch < _n_ch; ch++) { + if (!_mem[ch]->queue_empty()) + return true; + if (mem_fetch* req = _m_caches[ch]->top()) + return true; + } return false; } -void DramRamulator2::cycle() { +void SimpleDRAM::cycle() { for (int ch = 0; ch < _n_ch; ch++) { _mem[ch]->cycle(); + + // From Cache to DRAM + if (mem_fetch* req = _m_caches[ch]->top()) { + //spdlog::info("[Cache->DRAM] mem_fetch: addr={:#x}", req->get_addr()); + + _mem[ch]->push(req, _latency); + _m_caches[ch]->pop(); + } + + // From DRAM to Cache + if (_mem[ch]->arrived()) { + mem_fetch* req = _mem[ch]->top(); + req->set_reply(); + //spdlog::info("[DRAM->Cache] mem_fetch: addr={:#x}", req->get_addr()); + if(_m_caches[ch]->push(req)) + _mem[ch]->pop(); + } + } +} + +void SimpleDRAM::cache_cycle() { + for (int ch = 0; ch < _n_ch; ch++) { + _m_caches[ch]->cycle(); } } -bool DramRamulator2::is_full(uint32_t cid, MemoryAccess* request) { - return _mem[cid]->full(); +bool SimpleDRAM::is_full(uint32_t cid, mem_fetch* request) { + return false; //m_from_crossbar_queue[cid].full(); Infinite length } -void DramRamulator2::push(uint32_t cid, MemoryAccess* request) { - addr_type atomic_bytes =_config.dram_req_size; - addr_type target_addr = request->dram_address; - // align address - addr_type start_addr = target_addr - (target_addr % atomic_bytes); - assert(start_addr == target_addr); - assert(request->size == atomic_bytes); - target_addr = (target_addr >> _tx_ch_log2) << _tx_log2; - NDPSim::mem_fetch* mf = new NDPSim::mem_fetch(); - mf->addr = target_addr; - mf->size = request->size; - mf->write = request->write; - mf->request = true; - mf->origin_data = request; - _mem[cid]->push(mf); +void SimpleDRAM::push(uint32_t cid, mem_fetch* request) { + m_from_crossbar_queue[cid].push(request); } -bool DramRamulator2::is_empty(uint32_t cid) { - return _mem[cid]->return_queue_top() == NULL; +bool SimpleDRAM::is_empty(uint32_t cid) { + return m_to_crossbar_queue[cid].empty(); } -MemoryAccess* DramRamulator2::top(uint32_t cid) { +mem_fetch* SimpleDRAM::top(uint32_t cid) { assert(!is_empty(cid)); - NDPSim::mem_fetch* mf = _mem[cid]->return_queue_top(); - ((MemoryAccess*)mf->origin_data)->request = false; - return (MemoryAccess*)mf->origin_data; + return m_to_crossbar_queue[cid].front(); } -void DramRamulator2::pop(uint32_t cid) { +void SimpleDRAM::pop(uint32_t cid) { assert(!is_empty(cid)); - NDPSim::mem_fetch* mf = _mem[cid]->return_queue_pop(); - delete mf; + m_to_crossbar_queue[cid].pop(); } -void DramRamulator2::print_stat() { +void SimpleDRAM::print_stat() {} + +void SimpleDRAM::print_cache_stats() { for (int ch = 0; ch < _n_ch; ch++) { - _mem[ch]->print(stdout); + _m_caches[ch]->print_stats(); } } diff --git a/PyTorchSimBackend/src/Dram.h b/PyTorchSimBackend/src/Dram.h deleted file mode 100644 index 112d1783..00000000 --- a/PyTorchSimBackend/src/Dram.h +++ /dev/null @@ -1,97 +0,0 @@ -#ifndef DRAM_H -#define DRAM_H -#include -#include -#include -#include - -#include "Common.h" -#include "TMA.h" -#include "ramulator/Ramulator.hpp" -#include "ramulator2.hh" -#include "Hashing.h" - -class Dram { - public: - virtual ~Dram() = default; - virtual bool running() = 0; - virtual void cycle() = 0; - virtual bool is_full(uint32_t cid, MemoryAccess* request) = 0; - virtual void push(uint32_t cid, MemoryAccess* request) = 0; - virtual bool is_empty(uint32_t cid) = 0; - virtual MemoryAccess* top(uint32_t cid) = 0; - virtual void pop(uint32_t cid) = 0; - uint32_t get_channel_id(MemoryAccess* request); - virtual void print_stat() {} - - protected: - SimulationConfig _config; - uint32_t _n_ch; - uint32_t _n_partitions; - uint32_t _n_ch_per_partition; - cycle_type _cycles; -}; - -class SimpleDram : public Dram { - public: - SimpleDram(SimulationConfig config); - virtual bool running() override; - virtual void cycle() override; - virtual bool is_full(uint32_t cid, MemoryAccess* request) override; - virtual void push(uint32_t cid, MemoryAccess* request) override; - virtual bool is_empty(uint32_t cid) override; - virtual MemoryAccess* top(uint32_t cid) override; - virtual void pop(uint32_t cid) override; - - private: - uint32_t _latency; - double _bandwidth; - - uint64_t _last_finish_cycle; - std::vector>> _waiting_queue; - std::vector> _response_queue; -}; - -class DramRamulator : public Dram { - public: - DramRamulator(SimulationConfig config); - - virtual bool running() override; - virtual void cycle() override; - virtual bool is_full(uint32_t cid, MemoryAccess* request) override; - virtual void push(uint32_t cid, MemoryAccess* request) override; - virtual bool is_empty(uint32_t cid) override; - virtual MemoryAccess* top(uint32_t cid) override; - virtual void pop(uint32_t cid) override; - virtual void print_stat() override; - - private: - std::unique_ptr _mem; - robin_hood::unordered_flat_map _waiting_mem_access; - std::queue _responses; - - std::vector _total_processed_requests; - std::vector _processed_requests; -}; - -class DramRamulator2 : public Dram { - public: - DramRamulator2(SimulationConfig config); - - virtual bool running() override; - virtual void cycle() override; - virtual bool is_full(uint32_t cid, MemoryAccess* request) override; - virtual void push(uint32_t cid, MemoryAccess* request) override; - virtual bool is_empty(uint32_t cid) override; - virtual MemoryAccess* top(uint32_t cid) override; - virtual void pop(uint32_t cid) override; - virtual void print_stat() override; - - private: - std::vector> _mem; - int _tx_ch_log2; - int _tx_log2; - int _req_size; -}; - -#endif \ No newline at end of file diff --git a/PyTorchSimBackend/src/Hashing.cc b/PyTorchSimBackend/src/Hashing.cc index 45482867..868178ae 100644 --- a/PyTorchSimBackend/src/Hashing.cc +++ b/PyTorchSimBackend/src/Hashing.cc @@ -95,3 +95,57 @@ unsigned ipoly_hash_function(new_addr_type higher_bits, unsigned index, return 0; } } + +unsigned bitwise_hash_function(new_addr_type higher_bits, unsigned index, + unsigned bank_set_num) { + return (index) ^ (higher_bits & (bank_set_num - 1)); +} + +unsigned PAE_hash_function(new_addr_type higher_bits, unsigned index, + unsigned bank_set_num) { + // Page Address Entropy + // random selected bits from the page and bank bits + // similar to + // Liu, Yuxi, et al. "Get Out of the Valley: Power-Efficient Address + if (bank_set_num == 32) { + std::bitset<64> a(higher_bits); + std::bitset<5> b(index); + std::bitset<5> new_index(index); + new_index[0] = a[13] ^ a[10] ^ a[9] ^ a[5] ^ a[0] ^ b[3] ^ b[0] ^ b[0]; + new_index[1] = a[12] ^ a[11] ^ a[6] ^ a[1] ^ b[3] ^ b[2] ^ b[1] ^ b[1]; + new_index[2] = a[14] ^ a[9] ^ a[8] ^ a[7] ^ a[2] ^ b[1] ^ b[2]; + new_index[3] = a[11] ^ a[10] ^ a[8] ^ a[3] ^ b[2] ^ b[3] ^ b[3]; + new_index[4] = a[12] ^ a[9] ^ a[8] ^ a[5] ^ a[4] ^ b[1] ^ b[0] ^ b[4]; + + return new_index.to_ulong(); + } else { + assert(0); + return 0; + } +} + +unsigned mini_hash_function(new_addr_type higher_bits, unsigned index, + unsigned bank_set_num) { + if (bank_set_num == 16) { + std::bitset<64> a(higher_bits); + std::bitset<4> b(index); + std::bitset<4> new_index(index); + + new_index[0] = a[0] ^ b[0]; + new_index[1] = a[0] ^ b[1]; + new_index[2] = a[1] ^ b[2]; + new_index[3] = a[1] ^ b[3]; + + + return new_index.to_ulong(); + } else { /* Else incorrect number of channels for the hashing function */ + assert( + "\nmemory_partition_indexing error: The number of " + "channels should be " + "16, 32 or 64 for the hashing IPOLY index function. other banks " + "numbers are not supported. Generate it by yourself! \n" && + 0); + + return 0; + } +} \ No newline at end of file diff --git a/PyTorchSimBackend/src/Hashing.h b/PyTorchSimBackend/src/Hashing.h deleted file mode 100644 index dc134792..00000000 --- a/PyTorchSimBackend/src/Hashing.h +++ /dev/null @@ -1,16 +0,0 @@ -// author: Mahmoud Khairy, (Purdue Univ) -// email: abdallm@purdue.edu - -#include -#include -#include - -#ifndef HASHING_H -#define HASHING_H - -typedef unsigned long long new_addr_type; - -unsigned ipoly_hash_function(new_addr_type higher_bits, unsigned index, - unsigned bank_set_num); - -#endif \ No newline at end of file diff --git a/PyTorchSimBackend/src/Instruction.cc b/PyTorchSimBackend/src/Instruction.cc index 9039f3d2..aef9079c 100644 --- a/PyTorchSimBackend/src/Instruction.cc +++ b/PyTorchSimBackend/src/Instruction.cc @@ -11,19 +11,22 @@ std::string opcode_to_string(Opcode opcode) { } Instruction::Instruction(Opcode opcode, cycle_type compute_cycle, size_t num_parents, - addr_type dram_addr, std::vector tile_size, size_t precision, - std::vector& idx_list, std::vector& stride_list, std::vector tag_idx_list, std::vector loop_size_list) + addr_type dram_addr, std::vector tile_size, std::vector tile_stride, size_t precision, + std::vector tag_idx_list, std::vector tag_stride_list, + std::vector accum_tag_idx_list) : opcode(opcode), compute_cycle(compute_cycle), ready_counter(num_parents), dram_addr(dram_addr), - tile_size(tile_size), _precision(precision), _idx_list(idx_list), - _stride_list(stride_list), _tag_idx_list(tag_idx_list), _loop_size_list(loop_size_list) { + tile_size(tile_size), tile_stride(tile_stride), _precision(precision), + _tag_idx_list(tag_idx_list), _tag_stride_list(tag_stride_list), + _accum_tag_idx_list(accum_tag_idx_list) { + assert(_tag_idx_list.size()==_tag_stride_list.size()); _tile_numel = 1; for (auto dim : tile_size) _tile_numel *= dim; +} - /* Supporting vector */ - if (_stride_list.size() == 1) { - _stride_list.push_back(1); - } +Instruction::Instruction(Opcode opcode) + : opcode(opcode) { + _tile_numel = 1; } void Instruction::finish_instruction() { @@ -46,6 +49,84 @@ void Instruction::dec_waiting_request() { _nr_waiting_request--; } +void Instruction::prepare_tag_key() { + /* Calculate tag key */ + int key_offset = 0; + _tag_key.push_back(_addr_id); + for (int i=0; i<_tag_idx_list.size(); i++) + key_offset += _tag_idx_list.at(i) * _tag_stride_list.at(i); + for (auto accum_dim : _accum_tag_idx_list) + _tag_key.push_back(accum_dim); + _tag_key.push_back(key_offset); +} + void Instruction::print() { spdlog::info("{}", opcode_to_string(opcode)); +} + +std::shared_ptr> Instruction::get_dram_address(addr_type dram_req_size) { + auto address_set = std::make_shared>(); + uint64_t* indirect_index = NULL; + size_t index_count = 0; + /* Set 4D shape*/ + while (tile_size.size() < 4) + tile_size.insert(tile_size.begin(), 1); + + while (tile_stride.size() < 4) + tile_stride.insert(tile_stride.begin(), 0); + if (_is_indirect_mode) { + spdlog::trace("[Indirect Access] Indirect mode, dump_path: {}", _indirect_index_path); + load_indirect_index(_indirect_index_path, indirect_index, tile_size); + } + + /* Iterate tile_size */ + for (int dim0=0; dim0insert(address - (address & dram_req_size-1)); + } + } + } + } + return address_set; +} + +bool Instruction::load_indirect_index(const std::string& path, uint64_t*& indirect_index, const std::vector& tile_size) { + size_t count; + std::ifstream ifs(path, std::ios::binary | std::ios::ate); + if (!ifs) { + spdlog::warn("[Indirect Access] Failed to open index file(\'{}\')", path); + return false; + } + + std::streamsize size = ifs.tellg(); + ifs.seekg(0, std::ios::beg); + count = size / sizeof(uint64_t); + + uint64_t expected_count = tile_size[0] * tile_size[1] * tile_size[2] * tile_size[3]; + if (size % sizeof(uint64_t) != 0 || count != expected_count) { + spdlog::warn("[Indirect Access] Invalid file size ({} Bytes) at \'{}\'", size, path); + return false; + } + + indirect_index = new uint64_t[count]; + + if (!ifs.read(reinterpret_cast(indirect_index), size)) { + spdlog::warn("[Indirect Access] Failed to read data from file (\'{}\')", path); + delete[] indirect_index; + indirect_index = NULL; + count = 0; + return false; + } + return true; } \ No newline at end of file diff --git a/PyTorchSimBackend/src/Interconnect.cc b/PyTorchSimBackend/src/Interconnect.cc index dc62e402..8a684ff7 100644 --- a/PyTorchSimBackend/src/Interconnect.cc +++ b/PyTorchSimBackend/src/Interconnect.cc @@ -2,7 +2,6 @@ SimpleInterconnect::SimpleInterconnect(SimulationConfig config) : _latency(config.icnt_latency) { - spdlog::info("Initialize SimpleInterconnect"); _cycles = 0; _config = config; _n_nodes = config.num_cores + config.dram_channels; @@ -40,7 +39,7 @@ void SimpleInterconnect::cycle() { _cycles++; } -void SimpleInterconnect::push(uint32_t src, uint32_t dest, MemoryAccess* request) { +void SimpleInterconnect::push(uint32_t src, uint32_t dest, mem_fetch* request) { SimpleInterconnect::Entity entity; if(_in_buffers[src].empty()) entity.finish_cycle = _cycles + _latency; @@ -51,7 +50,7 @@ void SimpleInterconnect::push(uint32_t src, uint32_t dest, MemoryAccess* request _in_buffers[src].push(entity); } -bool SimpleInterconnect::is_full(uint32_t nid, MemoryAccess* request) { +bool SimpleInterconnect::is_full(uint32_t nid, mem_fetch* request) { //TODO: limit buffersize return false; } @@ -60,7 +59,7 @@ bool SimpleInterconnect::is_empty(uint32_t nid) { return _out_buffers[nid].empty(); } -MemoryAccess* SimpleInterconnect::top(uint32_t nid) { +mem_fetch* SimpleInterconnect::top(uint32_t nid) { assert(!is_empty(nid)); return _out_buffers[nid].front(); } @@ -93,13 +92,13 @@ void Booksim2Interconnect::cycle() { _booksim->run(); } -void Booksim2Interconnect::push(uint32_t src, uint32_t dest, MemoryAccess* request) { +void Booksim2Interconnect::push(uint32_t src, uint32_t dest, mem_fetch* request) { booksim2::Interconnect::Type type = get_booksim_type(request); uint32_t size = get_packet_size(request); _booksim->push(request, 0, 0, size, type, src, dest); } -bool Booksim2Interconnect::is_full(uint32_t nid, MemoryAccess* request) { +bool Booksim2Interconnect::is_full(uint32_t nid, mem_fetch* request) { uint32_t size = get_packet_size(request); return _booksim->is_full(nid, 0, size); } @@ -108,9 +107,9 @@ bool Booksim2Interconnect::is_empty(uint32_t nid) { return _booksim->is_empty(nid, 0); } -MemoryAccess* Booksim2Interconnect::top(uint32_t nid) { +mem_fetch* Booksim2Interconnect::top(uint32_t nid) { assert(!is_empty(nid)); - return (MemoryAccess*) _booksim->top(nid, 0); + return (mem_fetch*) _booksim->top(nid, 0); } void Booksim2Interconnect::pop(uint32_t nid) { @@ -122,44 +121,44 @@ void Booksim2Interconnect::print_stats() { _booksim->print_stats(); } -booksim2::Interconnect::Type Booksim2Interconnect::get_booksim_type(MemoryAccess* access) { +booksim2::Interconnect::Type Booksim2Interconnect::get_booksim_type(mem_fetch* access) { booksim2::Interconnect::Type type; - if(access->write && access->request) { - /* Write request */ - type = booksim2::Interconnect::Type::WRITE; - } - else if(access->write && !access->request) { - /* Write response */ - type = booksim2::Interconnect::Type::WRITE_REPLY; - } - else if(!access->write && access->request){ - /* Read request */ + switch (access->get_type()) + { + case mf_type::READ_REQUEST: type = booksim2::Interconnect::Type::READ; - } - else if(!access->write && !access->request) { - /* Read reply */ + break; + case mf_type::READ_REPLY: type = booksim2::Interconnect::Type::READ_REPLY; + break; + case mf_type::WRITE_REQUEST: + type = booksim2::Interconnect::Type::WRITE; + break; + case mf_type::WRITE_ACK: + type = booksim2::Interconnect::Type::WRITE_REPLY; + break; + default: + spdlog::error("[Interconenct] Unexpected memory type..."); + break; } return type; } -uint32_t Booksim2Interconnect::get_packet_size(MemoryAccess* access) { +uint32_t Booksim2Interconnect::get_packet_size(mem_fetch* access) { uint32_t size; - if(access->write && access->request) { - /* Write request */ - size = access->size; - } - else if(access->write && !access->request) { - /* Write response */ - size = _ctrl_size; - } - else if(!access->write && access->request){ - /* Read request */ + switch (access->get_type()) + { + case mf_type::READ_REQUEST: + case mf_type::WRITE_ACK: size = _ctrl_size; - } - else if(!access->write && !access->request) { - /* Read reply */ - size = access->size; + break; + case mf_type::READ_REPLY: + case mf_type::WRITE_REQUEST: + size = access->get_data_size(); + break; + default: + spdlog::error("[Interconenct] Unexpected memory type..."); + break; } return size; } \ No newline at end of file diff --git a/PyTorchSimBackend/src/L2Cache.cc b/PyTorchSimBackend/src/L2Cache.cc new file mode 100644 index 00000000..14c9b9da --- /dev/null +++ b/PyTorchSimBackend/src/L2Cache.cc @@ -0,0 +1,105 @@ +#include "L2Cache.h" + +bool NoL2Cache::push(mem_fetch* req) { + l_to_xbar_queue->push(req); + return true; +} +void NoL2Cache::cycle() { + if (!l_from_xbar_queue->empty()) { + mem_fetch* req = l_from_xbar_queue->front(); + l_to_mem_queue.push(req); + l_from_xbar_queue->pop(); + } +} + +L2DataCache::L2DataCache(std::string name, CacheConfig &cache_config, uint32_t id, + cycle_type *core_cycle, uint32_t l2d_hit_latency, + std::queue *to_xbar_queue, std::queue *from_xbar_queue) : + L2CacheBase(name, cache_config, id, core_cycle, l2d_hit_latency, to_xbar_queue, from_xbar_queue) { + l_cache = std::make_unique(name, cache_config, id, 0, &l_to_mem_queue); + l_from_cache_queue = DelayQueue(l_name + "_latency_queue", true, 0); +} + +bool L2DataCache::push(mem_fetch* req) { + if (l_cache->waiting_for_fill(req)) { + if (!l_cache->fill_port_free()) + return false; + l_cache->fill(req, *l_core_cycle); + } else { + if (req->get_access_type() == GLOBAL_ACC_R || req->get_access_type() == GLOBAL_ACC_W) + l_to_xbar_queue->push(req); + } + return true; +} + +void L2DataCache::cycle() { + l_from_cache_queue.cycle(); + l_cache->cycle(); + + // Mem to Cache + uint32_t line_size = l_cache_config.get_line_size(); + uint32_t sector_size = l_cache_config.get_sector_size(); + + /* Pass a request to cache */ + if (!l_from_xbar_queue->empty()) { + mem_fetch* req = l_from_xbar_queue->front(); + /* Check cache plan */ + bool is_cacheable = req->is_cacheable(); + + /* Go to l2 cache */ + if (is_cacheable && l_cache->data_port_free()) { + req->set_access_sector_mask(line_size, sector_size); + std::deque events; + CacheRequestStatus status = l_cache->access( + req->get_addr(), *l_core_cycle, req, events); + bool write_sent = CacheEvent::was_write_sent(events); + bool read_sent = CacheEvent::was_read_sent(events); + if (status == HIT) { + if (!write_sent) { + req->set_reply(); + req->current_state = "L2HIT"; + l_from_cache_queue.push(req, l2d_hit_latency); + } + l_from_xbar_queue->pop(); + } else if (status != RESERVATION_FAIL) { + req->current_state = "L2MISS"; + if (req->is_write() && + (l_cache_config.get_write_alloc_policy() == FETCH_ON_WRITE || + l_cache_config.get_write_alloc_policy() == LAZY_FETCH_ON_READ)) { + req->set_reply(); + req->current_state = "L2MISS-WRITE"; + l_from_cache_queue.push(req, l2d_hit_latency); + } + l_from_xbar_queue->pop(); + } else { + // Status Reservation fail, Retry it + assert(!write_sent); + assert(!read_sent); + } + } else if (!is_cacheable) { + l_to_mem_queue.push(req); + l_from_xbar_queue->pop(); + } + } + + if (l_cache->access_ready() && + !l_from_cache_queue.full()) { + mem_fetch* req = l_cache->top_next_access(); + if (req->is_request()) req->set_reply(); + l_from_cache_queue.push(req, l2d_hit_latency); + l_cache->pop_next_access(); + } + + if (l_from_cache_queue.arrived()) { + mem_fetch* req = l_from_cache_queue.top(); + if (req->get_access_type() == GLOBAL_ACC_R || req->get_access_type() == GLOBAL_ACC_W) + l_to_xbar_queue->push(req); + l_from_cache_queue.pop(); + } +} + +void L2DataCache::print_stats() { + if (l_id == 0) { + l_cache->get_stats().print_stats(stdout, l_name.c_str()); + } +} \ No newline at end of file diff --git a/PyTorchSimBackend/src/Simulator.cc b/PyTorchSimBackend/src/Simulator.cc index a296897a..6bc80286 100644 --- a/PyTorchSimBackend/src/Simulator.cc +++ b/PyTorchSimBackend/src/Simulator.cc @@ -3,11 +3,6 @@ Simulator::Simulator(SimulationConfig config) : _config(config), _core_cycles(0) { // Create dram object - spdlog::info("Simulator Configuration:"); - for (int i=0; i(core_index, _config); + } else if(config.core_type[core_index] == CoreType::STONNE) { + spdlog::info("[Config/Core] Core {}: {} MHz, Stonne Core selected", core_index, config.core_freq); + _cores.at(core_index) = std::make_unique(core_index, _config); + } else { + throw std::runtime_error(fmt::format("Not implemented Core type {} ", + (int)config.core_type[core_index])); + } + } + if (config.dram_type == DramType::SIMPLE) { - _dram = std::make_unique(config); - } else if (config.dram_type == DramType::RAMULATOR1) { - std::string ramulator_config = fs::path(onnxim_path) - .append("configs") - .append(config.dram_config_path) - .string(); - config.dram_config_path = ramulator_config; - _dram = std::make_unique(config); + _dram = std::make_unique(config, &_core_cycles); } else if (config.dram_type == DramType::RAMULATOR2) { std::string ramulator_config = fs::path(onnxim_path) .append("configs") .append(config.dram_config_path) .string(); - spdlog::info("Ramulator2 config: {}", ramulator_config); + spdlog::info("[Config/DRAM] Ramulator2 config: {}", ramulator_config); config.dram_config_path = ramulator_config; - _dram = std::make_unique(config); + _dram = std::make_unique(config, &_core_cycles); } else { spdlog::error("[Configuration] Invalid DRAM type...!"); exit(EXIT_FAILURE); } // Create interconnect object + spdlog::info("[Config/Interconnect] Inerconnect freq: {} MHz", config.icnt_freq); if (config.icnt_type == IcntType::SIMPLE) { + spdlog::info("[Config/Interconnect] SimpleInerconnect selected"); _icnt = std::make_unique(config); } else if (config.icnt_type == IcntType::BOOKSIM2) { + spdlog::info("[Config/Interconnect] BookSim2 selected"); _icnt = std::make_unique(config); } else { - spdlog::error("[Configuration] {} Invalid interconnect type...!"); + spdlog::error("[Configuration] Invalid interconnect type...!"); exit(EXIT_FAILURE); } _icnt_interval = config.icnt_print_interval; - // Create core objects - _cores.resize(config.num_cores); - for (int core_index = 0; core_index < _n_cores; core_index++) - _cores[core_index] = std::make_unique(core_index, _config); - // Initialize Scheduler for (int i=0; i(Scheduler(config, &_core_cycles, &_core_time, i))); @@ -72,15 +75,10 @@ void Simulator::run_simulator() { } void Simulator::core_cycle() { - for (int core_id = 0; core_id < _n_cores; core_id++) { - std::shared_ptr finished_tile = _cores[core_id]->pop_finished_tile(); - if (finished_tile->get_status() == Tile::Status::FINISH) { - get_partition_scheduler(core_id)->finish_tile(std::move(finished_tile)); - } - + for (int i=0; i<_max_slot; i++, _slot_id=(_slot_id + 1) % _max_slot) { // Issue new tile to core - for (int i=0; i<_max_slot; i++, _slot_id=(_slot_id + 1) % _max_slot) { - const std::shared_ptr tile = get_partition_scheduler(core_id)->peek_tile(core_id, _slot_id); + for (int core_id = 0; core_id < _n_cores; core_id++) { + const std::shared_ptr tile = get_partition_scheduler(core_id)->peek_tile(core_id, _slot_id, _config.core_type[core_id]); if (tile->get_status() != Tile::Status::EMPTY && _cores[core_id]->can_issue(tile)) { if (tile->get_status() == Tile::Status::INITIALIZED) { _cores[core_id]->issue(std::move(get_partition_scheduler(core_id)->get_tile(core_id, _slot_id))); @@ -91,8 +89,16 @@ void Simulator::core_cycle() { } } } + } + for (int core_id = 0; core_id < _n_cores; core_id++) { + std::shared_ptr finished_tile = _cores[core_id]->pop_finished_tile(); + if (finished_tile->get_status() == Tile::Status::FINISH) { + get_partition_scheduler(core_id)->finish_tile(std::move(finished_tile)); + } _cores[core_id]->cycle(); } + /* L2 cache */ + _dram->cache_cycle(); _core_cycles++; } @@ -108,9 +114,14 @@ void Simulator::icnt_cycle() { // PUHS core to ICNT. memory request int port_id = core_id * _noc_node_per_core + noc_id; if (_cores[core_id]->has_memory_request()) { - MemoryAccess *front = _cores[core_id]->top_memory_request(); - front->core_id = core_id; + mem_fetch *front = _cores[core_id]->top_memory_request(); + front->set_core_id(core_id); if (!_icnt->is_full(port_id, front)) { + //int node_id = _dram->get_channel_id(front) / 16; + //if (core_id == node_id) + // _cores[core_id]->inc_numa_hit(); + //else + // _cores[core_id]->inc_numa_miss(); _icnt->push(port_id , get_dest_node(front), front); _cores[core_id]->pop_memory_request(); _nr_from_core++; @@ -177,17 +188,30 @@ int Simulator::until(cycle_type until_cycle) { // Check if core status has changed if (_core_cycles % 10 == 0) { + int bitmap = 0; for (int i=0; i<_partition_scheduler.size(); i++) { /* Skip this */ if (partition_scheudler_status.at(i)) continue; - if (_partition_scheduler.at(i)->empty()) - return i; + if (_partition_scheduler.at(i)->empty()) { + bitmap |= (1 << i); + } } + if (bitmap) + return bitmap; } } - return -1; + int bitmap = 0; + for (int i=0; i<_partition_scheduler.size(); i++) { + /* Skip this */ + if (partition_scheudler_status.at(i)) + continue; + + if (_partition_scheduler.at(i)->empty()) + bitmap |= (1ULL << i); + } + return bitmap; } void Simulator::cycle() { @@ -206,6 +230,9 @@ void Simulator::cycle() { icnt_cycle(); } spdlog::info("Simulation Finished"); + for (auto &core: _cores) { + core->check_tag(); + } } bool Simulator::running() { @@ -238,19 +265,31 @@ void Simulator::set_cycle_mask() { } } -uint32_t Simulator::get_dest_node(MemoryAccess *access) { - if (access->request) { +uint32_t Simulator::get_dest_node(mem_fetch *access) { + switch (access->get_type()) + { + case mf_type::READ_REQUEST: + case mf_type::WRITE_REQUEST: return _config.num_cores * _noc_node_per_core + _dram->get_channel_id(access); - } else { - return access->core_id * _noc_node_per_core + (_dram->get_channel_id(access) % _noc_node_per_core); + break; + case mf_type::READ_REPLY: + case mf_type::WRITE_ACK: + return access->get_core_id() * _noc_node_per_core + (_dram->get_channel_id(access) % _noc_node_per_core); + break; + default: + spdlog::error("Unexpected memfetc type..."); + return -1; + break; } } void Simulator::print_core_stat() { + _icnt->print_stats(); + _dram->print_stat(); + _dram->print_cache_stats(); for (int core_id = 0; core_id < _n_cores; core_id++) { _cores[core_id]->print_stats(); } - _icnt->print_stats(); - _dram->print_stat(); + spdlog::info("Total execution cycle: {}", _core_cycles); } \ No newline at end of file diff --git a/PyTorchSimBackend/src/SparseCore.cc b/PyTorchSimBackend/src/SparseCore.cc new file mode 100644 index 00000000..64d3da55 --- /dev/null +++ b/PyTorchSimBackend/src/SparseCore.cc @@ -0,0 +1,464 @@ +#include "SparseCore.h" + +SparseCore::SparseCore(uint32_t id, SimulationConfig config) : Core(id, config) { + /* Init stonne cores*/ + nr_cores = config.num_stonne_per_core; + coreBusy.resize(nr_cores); + traceCoreStatus.resize(nr_cores); + traceCoreCycle.resize(nr_cores); + traceNodeList.resize(nr_cores); + traceLoadTraffic.resize(nr_cores); + traceStoreTraffic.resize(nr_cores); + percore_tiles.resize(nr_cores); + stonneCores.resize(nr_cores); + traceMode.resize(nr_cores); + percore_stat.resize(nr_cores); + percore_total_stat.resize(nr_cores); + for (int i=0; iinit(1); + coreBusy.at(i) = false; + traceCoreStatus.at(i) = 0; + traceCoreCycle.at(i) = 0; + percore_tiles.at(i) = std::vector>(); + percore_stat.at(i).reset(); + percore_total_stat.at(i).reset(); + } + + Config stonneConfig = stonneCores.at(0)->getStonneConfig(); + unsigned int core_freq = config.core_freq; // MHz; + num_ms = stonneConfig.m_MSNetworkCfg.ms_size; + r_port_nr = config.num_stonne_port; + w_port_nr = config.num_stonne_port; + + double compute_throughput = static_cast(num_ms) * core_freq / 1e3; // FLOPs/sec + double dn_bandwidth = static_cast(r_port_nr) * config.dram_req_size * core_freq * 1e6 / 8.0 / 1e9; // GB/s + double rn_bandwidth = static_cast(w_port_nr) * config.dram_req_size * core_freq * 1e6 / 8.0 / 1e9; // GB/s + for (int i=0; i tile) { + int32_t selected_core_idx = -1; + for (int i=0; iinit(1); + traceNodeList.at(selected_core_idx).clear(); + + SST_STONNE::StonneOpDesc *opDesc = static_cast(tile->get_custom_data()); + bool is_trace_mode = true; + if (opDesc) { + is_trace_mode = false; + stonneCores.at(selected_core_idx)->setup(*opDesc, 0x1000000 * selected_core_idx); // FIXME. To avoid same address + stonneCores.at(selected_core_idx)->init(1); + } + setTraceMode(selected_core_idx, is_trace_mode); + percore_tiles.at(selected_core_idx).push_back(tile); + coreBusy.at(selected_core_idx) = true; + spdlog::info("[StonneCore {}][{}] issued new tile (trace_mode: {})", _id, selected_core_idx, is_trace_mode); +}; + +bool SparseCore::can_issue(const std::shared_ptr& op) { + bool idle_exist = false; + for (bool flag : coreBusy) { + idle_exist |= !flag; + } + return idle_exist && op->is_stonne_tile(); +} + +void SparseCore::checkStatus(uint32_t subcore_id) { + auto& stonneCore = stonneCores.at(subcore_id); + int new_status = stonneCore->getMCFSMStats(); + int compute_cycle = stonneCore->getMSStats().n_multiplications; + if (traceCoreStatus.at(subcore_id) != new_status) { + spdlog::trace("Stonne Core [{}][{}] status transition {} -> {}, Load/Store: {}/{}, compute_cycle: {}", + _id, _core_cycle, traceCoreStatus.at(subcore_id), new_status, + traceLoadTraffic.at(subcore_id).size(), traceStoreTraffic.at(subcore_id).size(), (compute_cycle - traceCoreCycle.at(subcore_id))/num_ms); + if (traceLoadTraffic.at(subcore_id).size()) { + TraceNode load_node = TraceNode(traceNodeList.at(subcore_id).size()+2, "load", TraceNode::StonneTraceLoad); + load_node.setAddress(traceLoadTraffic.at(subcore_id)); + traceNodeList.at(subcore_id).push_back(load_node); + } + if (_core_cycle - traceCoreCycle.at(subcore_id)) {//((compute_cycle - traceCoreCycle.at(subcore_id))/num_ms) { + TraceNode compute_node = TraceNode(traceNodeList.at(subcore_id).size()+2, "compute", TraceNode::StonneTraceCompute, _core_cycle - traceCoreCycle.at(subcore_id)); + traceNodeList.at(subcore_id).push_back(compute_node); + } + if (traceStoreTraffic.at(subcore_id).size()) { + TraceNode store_node = TraceNode(traceNodeList.at(subcore_id).size()+2, "store", TraceNode::StonneTraceStore); + store_node.setAddress(traceStoreTraffic.at(subcore_id)); + traceNodeList.at(subcore_id).push_back(store_node); + } + + traceCoreStatus.at(subcore_id) = new_status; + traceCoreCycle.at(subcore_id) = _core_cycle; + traceLoadTraffic.at(subcore_id).clear(); + traceStoreTraffic.at(subcore_id).clear(); + } +} + +void SparseCore::subCoreCycle(uint32_t subcore_id) { + if (!traceMode.at(subcore_id)) { + auto& stonneCore = stonneCores.at(subcore_id); + stonneCore->cycle(); + + /* Check FSM status transition */ + checkStatus(subcore_id); + + /* Send Memory Request */ + while (SimpleMem::Request* req = stonneCore->popRequest()) { + uint64_t target_addr = (req->getAddress() / _config.dram_req_size) * _config.dram_req_size; + mem_access_type acc_type; + mf_type type; + + switch(req->getcmd()) { + case SimpleMem::Request::Read: + acc_type = mem_access_type::GLOBAL_ACC_R; + type = mf_type::READ_REQUEST; + traceLoadTraffic.at(subcore_id).insert(target_addr); + break; + case SimpleMem::Request::Write: + acc_type = mem_access_type::GLOBAL_ACC_W; + type = mf_type::WRITE_REQUEST; + traceStoreTraffic.at(subcore_id).insert(target_addr); + break; + default: + spdlog::error("[SparseCore] Invalid request type from core"); + return; + } + req->request_time = _core_cycle; + req->stonneId = subcore_id; + std::tuple key = std::make_tuple(target_addr, acc_type, type, allocTrafficID()); + registerMemfetch(key, [this, req, acc_type, type]() { + spdlog::trace("[SparseCore][{}] Round Trip Cycle: {}, Address: {:#x}, Request Type: {}, DRAM Req Size: {}", \ + _core_cycle, _core_cycle - req->request_time, req->getAddress(), int(req->getcmd()), _config.dram_req_size); + req->setReply(); + stonneCores.at(req->stonneId)->pushResponse(req); + }); + } + + /* Finish stonne core */ + if (coreBusy.at(subcore_id) && stonneCore->isFinished()) { + stonneCore->finish(); + spdlog::info("[SparseCore][{}] Operation finished at {}", _id, _core_cycle); + std::shared_ptr target_tile = percore_tiles.at(subcore_id).front(); + SST_STONNE::StonneOpDesc *opDesc = static_cast(target_tile->get_custom_data()); + if (opDesc->trace_path != "") + dumpTrace(subcore_id, opDesc->trace_path); + + target_tile->set_status(Tile::Status::FINISH); + _finished_tiles.push(target_tile); + percore_tiles.at(subcore_id).erase(percore_tiles.at(subcore_id).begin()); + coreBusy.at(subcore_id) = false; + } + } else { + /* Check finished computation */ + auto& target_pipeline = get_compute_pipeline(0); + if (!target_pipeline.empty()) { + if (target_pipeline.front()->finish_cycle <= _core_cycle) { + finish_instruction(target_pipeline.front()); + target_pipeline.pop(); + } + percore_stat.at(subcore_id).n_multiplications += num_ms; + } + + /* Check finished dma operation */ + bool retry=true; + while (retry) { + retry = false; + for (auto it=_dma_finished_queue.begin();it!=_dma_finished_queue.end();it++) { + std::shared_ptr& instruction = _dma_finished_queue.at(0); + /* Pass not finished instruction */ + if (instruction->get_waiting_request()) + continue; + + /* Finish DMA read instruction */ + if (instruction->is_dma_read()) + finish_instruction(instruction); + /* Erase the instruction in DMA finished queue */ + _dma_finished_queue.erase(_dma_finished_queue.begin()); + retry = true; + break; + } + } + + auto& tile_queue = percore_tiles.at(subcore_id); + if (tile_queue.empty()) + return; + auto& instructions = tile_queue.front()->get_instructions(); + + /* Finish stonne core */ + if (coreBusy.at(subcore_id) && instructions.empty()) { + std::shared_ptr target_tile = percore_tiles.at(subcore_id).front(); + target_tile->set_status(Tile::Status::FINISH); + _finished_tiles.push(target_tile); + percore_tiles.at(subcore_id).erase(percore_tiles.at(subcore_id).begin()); + coreBusy.at(subcore_id) = false; + return; + } + + /* Peek instruction*/ + if (instructions.empty()) + return; + auto& inst = instructions.front(); + if (!inst->is_ready()) + return; + + + bool issued = false; + switch (inst->get_opcode()) { + case Opcode::MOVIN: + { + auto acc_type = mem_access_type::GLOBAL_ACC_R; + auto type = mf_type::READ_REQUEST; + spdlog::trace("[StonneCore {}][{}][{}] {} ISSUED", _id, subcore_id, _core_cycle, + opcode_to_string(inst->get_opcode())); + for (auto addr : inst->get_trace_address()) { + addr = addr - (addr & _config.dram_req_size-1); + inst->inc_waiting_request(); + std::tuple key = std::make_tuple(addr, acc_type, type, allocTrafficID()); + uint64_t current_time = _core_cycle; + registerMemfetch(key, [this, inst, addr, current_time, type]() { + spdlog::trace("[SparseCore][{}] Round Trip Cycle: {}, Address: {:#x}, Request Type: {}, DRAM Req Size: {}", \ + this->_core_cycle, this->_core_cycle - current_time, addr, int(type), _config.dram_req_size); + inst->dec_waiting_request(); + }); + } + issued = true; + _dma_finished_queue.push_back(std::move(inst)); + } + break; + case Opcode::MOVOUT: + { + auto acc_type = mem_access_type::GLOBAL_ACC_W; + auto type = mf_type::WRITE_REQUEST; + spdlog::trace("[StonneCore {}][{}][{}] {} ISSUED", _id, subcore_id, _core_cycle, + opcode_to_string(inst->get_opcode())); + for (auto addr : inst->get_trace_address()) { + addr = addr - (addr & _config.dram_req_size-1); + inst->inc_waiting_request(); + std::tuple key = std::make_tuple(addr, acc_type, type, allocTrafficID()); + uint64_t current_time = _core_cycle; + registerMemfetch(key, [this, inst, addr, current_time, type]() { + spdlog::trace("[SparseCore][{}] Round Trip Cycle: {}, Address: {:#x}, Request Type: {}, DRAM Req Size: {}", \ + this->_core_cycle, this->_core_cycle - current_time, addr, int(type), _config.dram_req_size); + inst->dec_waiting_request(); + }); + } + issued = true; + finish_instruction(inst); + _dma_finished_queue.push_back(std::move(inst)); + } + break; + case Opcode::COMP: + { + auto& target_pipeline = get_compute_pipeline(0); + if (target_pipeline.empty()) + inst->finish_cycle = _core_cycle + inst->get_compute_cycle(); + else + inst->finish_cycle = target_pipeline.back()->finish_cycle + inst->get_compute_cycle(); + spdlog::trace("[Core {}][{}][{}] {} ISSUED, finsh at {}", _id, subcore_id, _core_cycle, + opcode_to_string(inst->get_opcode()), inst->finish_cycle); + target_pipeline.push(inst); + issued = true; + } + break; + default: + spdlog::error("Undefined instruction opcode type"); + exit(EXIT_FAILURE); + } + if (issued) { + instructions.erase(std::find(instructions.begin(), instructions.end(), inst)); + } + } +} + +void SparseCore::cycle() { + _core_cycle++; + /* Handle core cycle*/ + for (uint32_t subcore_id=0; subcore_idget_addr(), int(req_pair.second->get_access_type()), int(req_pair.second->get_type()), + _config.dram_req_size, nr_request); + nr_request++; + break; + } + } + + // Send Memory Response + nr_request = 0; + while (!_response_queue.empty()) { + mem_fetch* resp_wrapper = _response_queue.front(); + auto* callbacks = static_cast>*>(resp_wrapper->get_custom_data()); + for (int i=0; isize(); i++) { + (*callbacks).at(i)(); + } + delete callbacks; + delete resp_wrapper; + _response_queue.pop(); + if (++nr_request > w_port_nr) + break; + } + + /* Check print stat */ + if(_config.core_print_interval && _core_cycle % _config.core_print_interval == 0) + print_current_stats(); +} + +bool SparseCore::has_memory_request() { + return !_request_queue.empty(); +} + +void SparseCore::pop_memory_request() { + _request_queue.pop(); +} + +void SparseCore::push_memory_response(mem_fetch* response) { + _response_queue.push(response); +} + +void SparseCore::print_current_stats() { + spdlog::info("========= Sparse Core stat ========="); + for (size_t i = 0; i < stonneCores.size(); ++i) { + if (!isTraceMode(i)) { + MSwitchStats stats = stonneCores.at(i)->getMSStats(); + stats -= percore_total_stat.at(i); + percore_stat.at(i) = stats; + percore_total_stat.at(i) = stonneCores.at(i)->getMSStats(); + } else { + percore_total_stat.at(i) += percore_stat.at(i); + } + cycle_type nr_mul = percore_stat.at(i).n_multiplications; + percore_stat.at(i).reset(); + spdlog::info("Stonne Core [{}][{}] : nr_multiplications: {}", _id, i, nr_mul); + } + spdlog::info("Stonne Core [{}] : Total cycle {}", _id, _core_cycle); +} + +void SparseCore::print_stats() { + spdlog::info("========= Sparse Core stat ========="); + for (size_t i = 0; i < stonneCores.size(); ++i) { + if (!isTraceMode(i)) { + MSwitchStats stats = stonneCores.at(i)->getMSStats(); + stats -= percore_total_stat.at(i); + percore_stat.at(i) = stats; + percore_total_stat.at(i) = stats; + } else { + percore_total_stat.at(i) += percore_stat.at(i); + } + cycle_type nr_mul = percore_total_stat.at(i).n_multiplications; + spdlog::info("Stonne Core [{}][{}] : nr_multiplications: {}", _id, i, nr_mul); + } + spdlog::info("Stonne Core [{}] : Total cycle {}", _id, _core_cycle); +} + +std::shared_ptr SparseCore::pop_finished_tile() { + std::shared_ptr result = std::make_unique(Tile(Tile::Status::EMPTY)); + if (_finished_tiles.size() > 0) { + result = std::move(_finished_tiles.front()); + _finished_tiles.pop(); + } + return result; +} + +void SparseCore::finish_instruction(std::shared_ptr& inst) { + if (inst->finished) { + spdlog::error("[Core {}][{}] {} FINISHED, inst already finished!!", _id, _core_cycle, + opcode_to_string(inst->get_opcode())); + exit(EXIT_FAILURE); + } + inst->finish_instruction(); + static_cast(inst->get_owner())->inc_finished_inst(); + if (inst->get_opcode() == Opcode::COMP) { + spdlog::info("[StonneCore {}][{}] {} FINISHED", + _id, _core_cycle, opcode_to_string(inst->get_opcode())); + } else if (inst->get_opcode() == Opcode::MOVIN || inst->get_opcode() == Opcode::MOVOUT) { + spdlog::info("[StonneCore {}][{}] {} FINISHED, free_sram_size: {}", _id, _core_cycle, + opcode_to_string(inst->get_opcode()), inst->get_free_sram_size()); + } +} + +void SparseCore::registerMemfetch(const std::tuple& key, std::function callback) { + if (request_merge_table.find(key) == request_merge_table.end()) { + mem_fetch* req_wrapper = new mem_fetch(std::get<0>(key), std::get<1>(key), std::get<2>(key), _config.dram_req_size, -1); + + auto* callbacks = new std::vector>(); + req_wrapper->set_custom_data(static_cast(callbacks)); + request_merge_table[key] = req_wrapper; + } + mem_fetch* req_wrapper = request_merge_table[key]; + auto* callbacks = static_cast>*>(req_wrapper->get_custom_data()); + callbacks->push_back(callback); +} + +void SparseCore::dumpTrace(int stonne_core_id, const std::string& path) { + std::ofstream outFile(path); + if (!outFile) { + spdlog::error("[StonneCore] Failed to make trace dump file to \"{}\"", path); + return; + } + // Static nodes for the graph + outFile << "graph = {\n 0: {\n" + << " \"node_id\": 0,\n" + << " \"node_name\": \"root\",\n" + << " \"node_type\": 0,\n" + << " \"parents\": [],\n" + << " \"children\": [1]\n" + << " },\n" + << " 1: {\n" + << " \"node_id\": 1,\n" + << " \"node_name\": \"loopNode\",\n" + << " \"node_type\": 2,\n" + << " \"parents\": [0],\n" + << " \"children\": [2],\n" + << " \"loop_index\": \"loop_arg000\",\n" + << " \"loop_start\": 0,\n" + << " \"loop_end\": 1,\n" + << " \"loop_step\": 1,\n" + << " \"loop_type\": \"outer_loop\"" + << " },\n"; + + // Output traceNodeList + for (size_t i = 0; i < traceNodeList.at(stonne_core_id).size(); ++i) { + if (i != 0) outFile << ",\n"; + outFile << traceNodeList.at(stonne_core_id)[i]; + } + outFile << "\n}" << std::endl; + spdlog::info("[StonneCore] Success to save trace dump file to \"{}\"", path); +} diff --git a/PyTorchSimBackend/src/TMA.cc b/PyTorchSimBackend/src/TMA.cc index 03d88ce6..7744b0f5 100644 --- a/PyTorchSimBackend/src/TMA.cc +++ b/PyTorchSimBackend/src/TMA.cc @@ -11,32 +11,32 @@ TMA::TMA(uint32_t id, uint32_t dram_req_size) { void TMA::issue_tile(std::shared_ptr inst) { _current_inst = std::move(inst); std::vector& tile_size = _current_inst->get_tile_size(); - if (tile_size.size() != 2) { - spdlog::error("[TMA {}] issued tile is not [y,x] format..", _id); + if (tile_size.size() <= 0 || tile_size.size() > get_max_dim()) { + spdlog::error("[TMA {}] issued tile is not supported format..", _id); exit(EXIT_FAILURE); } _finished = false; } -std::vector TMA::get_memory_access() { - std::set addr_set = _current_inst->get_dram_address(_dram_req_size); - std::vector access_vec; +std::shared_ptr> TMA::get_memory_access() { + auto addr_set = _current_inst->get_dram_address(_dram_req_size); + auto access_vec = std::make_shared>(); Tile* owner = (Tile*)_current_inst->get_owner(); - TileSubGraph* owner_subgraph = owner->get_owner(); - spdlog::trace("[NUMA Trace] Subgraph id: {} , Numa id: {}, Arg: {} is_write: {}", - owner_subgraph->get_core_id(), _current_inst->get_numa_id(), _current_inst->get_addr_name(), _current_inst->is_dma_write()); - for (auto addr: addr_set) { - MemoryAccess* access = new MemoryAccess({ - .id = generate_mem_access_id(), - .dram_address = addr, - .size = _dram_req_size, - .write = _current_inst->is_dma_write(), - .request = true, - .numa_id = _current_inst->get_numa_id(), - .owner_instruction = _current_inst.get() - }); + std::shared_ptr owner_subgraph = owner->get_owner(); + unsigned long long base_daddr = _current_inst->get_base_dram_address(); + // Todo. We use a ternsor level buffer allocation, so we don't need to check all memfetch + bool is_cacheable = owner_subgraph->is_cacheable(base_daddr, base_daddr + _dram_req_size); + spdlog::trace("[SRAM Trace] Core-{}, Address: 0x{:016x}, Is_cacheable: {}", _id, base_daddr, is_cacheable); + spdlog::trace("[NUMA Trace] Core-{}, Subgraph id: {} , Numa id: {}, Arg: {} is_write: {}", + _id, owner_subgraph->get_core_id(), _current_inst->get_numa_id(), _current_inst->get_addr_name(), _current_inst->is_dma_write()); + + for (auto addr: *addr_set) { + mem_access_type acc_type = _current_inst->is_dma_write() ? mem_access_type::GLOBAL_ACC_W : mem_access_type::GLOBAL_ACC_R; + mf_type type = _current_inst->is_dma_write() ? mf_type::WRITE_REQUEST : mf_type::READ_REQUEST; + mem_fetch* access = new mem_fetch(addr, acc_type, type, _dram_req_size, _current_inst->get_numa_id(), static_cast(_current_inst.get())); + access->set_cacheable(is_cacheable); _current_inst->inc_waiting_request(); - access_vec.push_back(access); + access_vec->push_back(access); } _finished = true; return access_vec; diff --git a/PyTorchSimBackend/src/Tile.cc b/PyTorchSimBackend/src/Tile.cc index bb166ca0..2e05cb08 100644 --- a/PyTorchSimBackend/src/Tile.cc +++ b/PyTorchSimBackend/src/Tile.cc @@ -21,6 +21,7 @@ void Tile::append_instuction(std::shared_ptr& inst) { /* Move instructions */ _nr_insts++; inst->set_owner(this); + inst->set_owner_ready_queue(&_ready_queue); _instructions.push_back(inst); } diff --git a/PyTorchSimBackend/src/TileGraph.cc b/PyTorchSimBackend/src/TileGraph.cc index 48d76990..33e995e9 100644 --- a/PyTorchSimBackend/src/TileGraph.cc +++ b/PyTorchSimBackend/src/TileGraph.cc @@ -5,7 +5,6 @@ TileSubGraph::TileSubGraph() : _ready_tile_queue(), _tile_set(), _id(_next_id++) } void TileSubGraph::add_tile(std::shared_ptr tile) { - tile->set_ownwer(this); for (auto& inst : tile->get_instructions()) inst->subgraph_id = _id; if (tile->get_ready_counter() == 0) { @@ -48,6 +47,7 @@ std::shared_ptr TileSubGraph::get_tile() { void TileGraph::append_subgraph(std::shared_ptr subgraph) { + subgraph->init_cache_plan(_cache_plan); _subgraph_vec.push_back(std::move(subgraph)); } @@ -63,7 +63,6 @@ bool TileGraph::is_finished() { if (tile_pair.second != nullptr) finished &= tile_pair.second->is_finished(); } - return finished; } diff --git a/PyTorchSimBackend/src/TileGraphParser.cc b/PyTorchSimBackend/src/TileGraphParser.cc index b9ea2b08..9374dcb5 100644 --- a/PyTorchSimBackend/src/TileGraphParser.cc +++ b/PyTorchSimBackend/src/TileGraphParser.cc @@ -1,5 +1,18 @@ #include "TileGraphParser.h" +bool loadConfig(const std::string& config_path, json& config_json) { + std::ifstream config_file(config_path); + if (config_file.is_open()) { + config_file >> config_json; + config_file.close(); + spdlog::info("[LoadConfig] Success to open \"{}\"", config_path); + return true; + } else { + spdlog::error("[LoadConfig] Failed to open \"{}\"", config_path); + return false; + } +} + void printIndexMap(std::string prefix, const std::map& indexMap) { std::ostringstream oss; for (const auto& [key, value] : indexMap) { @@ -58,6 +71,8 @@ std::vector calc_output_idx(TileGraphParser* tog_parser, std::map(desc.operation)); + spdlog::debug("{} layer_name: {}", spaces, desc.layer_name); + spdlog::debug("{} mem_init: {}", spaces, desc.mem_init); + + // Convolution Parameters + spdlog::debug("{} R: {}, S: {}, C: {}, K: {}, G: {}, N: {}", spaces, desc.R, desc.S, desc.C, desc.K, desc.G, desc.N); + spdlog::debug("{} X: {}, Y: {}, X_: {}, Y_: {}, strides: {}", spaces, desc.X, desc.Y, desc.X_, desc.Y_, desc.strides); + + // Convolution Tile Parameters + spdlog::debug("{} T_R: {}, T_S: {}, T_C: {}, T_K: {}, T_G: {}, T_N: {}", spaces, desc.T_R, desc.T_S, desc.T_C, desc.T_K, desc.T_G, desc.T_N); + spdlog::debug("{} T_X_: {}, T_Y_: {}", spaces, desc.T_X_, desc.T_Y_); + + // GEMM Parameters + spdlog::debug("{} GEMM_K: {}, GEMM_N: {}, GEMM_M: {}", spaces, desc.GEMM_K, desc.GEMM_N, desc.GEMM_M); + spdlog::debug("{} GEMM_T_K: {}, GEMM_T_N: {}, GEMM_T_M: {}", spaces, desc.GEMM_T_K, desc.GEMM_T_N, desc.GEMM_T_M); + + // Memory Addresses + spdlog::debug("{} matrix_a_dram_address: {}", spaces, desc.matrix_a_dram_address); + spdlog::debug("{} matrix_b_dram_address: {}", spaces, desc.matrix_b_dram_address); + spdlog::debug("{} matrix_c_dram_address: {}", spaces, desc.matrix_c_dram_address); + spdlog::debug("{} mem_matrix_c_file_name: {}", spaces, desc.mem_matrix_c_file_name); + + // Bitmap and CSR Data + spdlog::debug("{} bitmap_matrix_a_init: {}", spaces, desc.bitmap_matrix_a_init); + spdlog::debug("{} bitmap_matrix_b_init: {}", spaces, desc.bitmap_matrix_b_init); + spdlog::debug("{} rowpointer_matrix_a_init: {}", spaces, desc.rowpointer_matrix_a_init); + spdlog::debug("{} colpointer_matrix_a_init: {}", spaces, desc.colpointer_matrix_a_init); + spdlog::debug("{} rowpointer_matrix_b_init: {}", spaces, desc.rowpointer_matrix_b_init); + spdlog::debug("{} colpointer_matrix_b_init: {}", spaces, desc.colpointer_matrix_b_init); + spdlog::debug("{} trace_path: {}", spaces, desc.trace_path); +} + void TileMemoryWaitNode::print_node() { TileNode::print_node(); std::string spaces(get_depth(), '\t'); - spdlog::debug("{} tag_list: {}", spaces, fmt::join(_tag_idx_list, ", ")); + spdlog::debug("{} tag_idx_list: {}", spaces, fmt::join(_tag_idx_list, ", ")); + spdlog::debug("{} tag_stride_list: {}", spaces, fmt::join(_tag_stride_list, ", ")); +} + +void TileStonneTraceComputeNode::print_node() { + TileNode::print_node(); + std::string spaces(get_depth(), '\t'); + spdlog::debug("{} ComputeCycle: {}", spaces, _cycle); +} + +void TileStonneTraceMemoryNode::print_node() { + TileNode::print_node(); + std::string spaces(get_depth(), '\t'); + spdlog::debug("{} Address: {}", spaces, fmt::join(trace_address, ", ")); } TileLoopNode::TileLoopNode(onnx::NodeProto& node) : TileNode(node) { @@ -254,59 +346,39 @@ std::vector> TileLoopNode::get_tiles_from_iter(TileGraphPa for (auto& tile_node: _body_node) { if (tile_node->get_type() == TileType::LOAD_NODE) { std::shared_ptr mem_node = std::static_pointer_cast(tile_node); - auto base_addr_name = mem_node->get_base_addr_name(); - std::vector& tag_idx_list = mem_node->get_tag_idx_list(); - std::vector skip_idx_list; - std::vector values; - bool skip = false; - /* Find axis */ - if (mem_node->is_async_node()) { - for (int i=0;i& pair) { return pair.second; }); - - for (auto axis : skip_idx_list) { - if (values.at(iter.size() - tag_idx_list.size() + axis) != 0) { - skip = true; - break; - } - } - - /* Skip this node */ - if (skip) - continue; - } - - /* Lookup given name's address */ - addr_type base_addr = tog_parser->lookup(base_addr_name); std::vector iter_list; - std::vector tag_list; - std::vector loop_size_list; - std::vector outer_loop_idx; - std::vector outer_loop_size; int nr_inner_loop = 0; auto& loop_idx_list = mem_node->get_loop_idx_list(); for (auto loop_idx: loop_idx_list) { - auto iter_value = getLoopIndexValue(iter, loop_idx); + int iter_value = getLoopIndexValue(iter, loop_idx); iter_list.push_back(iter_value); - loop_size_list.push_back(tog_parser->get_loop_size(loop_idx)); if (tog_parser->get_loop_type(loop_idx)==LoopType::INNER_LOOP) nr_inner_loop++; } - /* Add accumulation loop info to tag list */ + + /* Base address setting */ + std::string base_addr_name = mem_node->get_base_addr_name(); + int base_addr_id = tog_parser->register_addr_name(base_addr_name); + addr_type base_addr = tog_parser->lookup(base_addr_name); + addr_type offset = std::inner_product(iter_list.begin(), iter_list.end(), mem_node->get_loop_stride_list().begin(), 0); + + std::vector tag_list; + std::vector accum_tag_list; + std::vector outer_loop_idx; + std::vector outer_loop_size; + /* Add accumulation loop info to accum_tag list */ for (auto loop_idx = loop_idx_list.begin(); loop_idx != loop_idx_list.end() - nr_inner_loop; ++loop_idx) { // Check loop type and process if (tog_parser->get_loop_type(*loop_idx)==LoopType::ACCUMULATION_LOOP) { auto iter_value = getLoopIndexValue(iter, *loop_idx); - tag_list.push_back(iter_value); + accum_tag_list.push_back(iter_value); } } + /* Default accum tag */ + if (accum_tag_list.empty()) { + accum_tag_list.push_back(0); + } for (auto loop_idx = loop_idx_list.begin(); loop_idx != loop_idx_list.end(); ++loop_idx) { @@ -318,11 +390,14 @@ std::vector> TileLoopNode::get_tiles_from_iter(TileGraphPa } } + uint32_t systolic_size = std::stoi(tog_parser->getMetaByName("systolic_size")); for (auto loop_idx: mem_node->get_tag_idx_list()) { if (iter.find(loop_idx) == iter.end()) tag_list.push_back(0); else { - auto iter_value = getLoopIndexValue(iter, loop_idx); + uint32_t step = (uint32_t)tog_parser->get_loop_step(loop_idx); + step = step > systolic_size ? systolic_size : step; + auto iter_value = getLoopIndexValue(iter, loop_idx) / step; tag_list.push_back(iter_value); } } @@ -336,27 +411,47 @@ std::vector> TileLoopNode::get_tiles_from_iter(TileGraphPa numa_id = total_idx / stride_idx; } + /* Check need to make this memory node */ + std::vector& tag_stride_list = mem_node->get_tag_stride_list(); + std::vector key = tog_parser->calc_tag(accum_tag_list, tag_list, tag_stride_list); + if (tog_parser->check_memory_tag(base_addr_name, key)) + continue; + tog_parser->register_memory_tag(base_addr_name, key); + printIndexMap("[TOGParser] Load Node " + mem_node->get_base_addr_name() + " Numa_id: " + std::to_string(numa_id), iter); + spdlog::trace("[TOGParser] Load Node {}({}) key = [{}], accum = [{}], tag = [{}], stride = [{}]", mem_node->get_base_addr_name(), + base_addr_id, + fmt::join(key, ", "), + fmt::join(accum_tag_list, ", "), + fmt::join(tag_list, ", "), + fmt::join(tag_stride_list, ", ")); std::shared_ptr inst = std::make_shared( Opcode::MOVIN, 0, - 0, base_addr, - mem_node->get_tile_size(), mem_node->get_precision(), iter_list, - mem_node->get_stride_list(), tag_list, loop_size_list + 0, base_addr+offset, + mem_node->get_tile_size(), mem_node->get_tile_stride(), mem_node->get_precision(), + tag_list, tag_stride_list, accum_tag_list ); - inst->set_addr_name(base_addr_name); + inst->set_addr_name(base_addr_name, base_addr_id); + inst->prepare_tag_key(); inst->set_nr_inner_loop(nr_inner_loop); - inst->adjust_dram_address(); inst->set_is_async(mem_node->is_async_node()); inst->set_numa_id(numa_id); + + if (mem_node->is_indirect()) { + inst->set_indirect_index_path(tog_parser->get_indirect_path()); + tog_parser->inc_indirect_counter(); + } else { + bool is_sparse_tile = tog_parser->is_sparse_tile(tog_parser->get_dma_counter()); + tog_parser->inc_dma_counter(); + if (is_sparse_tile) { + inst->set_sparse_state(is_sparse_tile); + } + } link_map[tile_node] = inst; tile_vec.back()->append_instuction(inst); } else if (tile_node->get_type() == TileType::STORE_NODE) { std::shared_ptr mem_node = std::static_pointer_cast(tile_node); - auto base_addr_name = mem_node->get_base_addr_name(); - /* Lookup given name's address */ - addr_type base_addr = tog_parser->lookup(base_addr_name); std::vector iter_list; - std::vector loop_size_list; std::vector outer_loop_idx; std::vector outer_loop_size; int nr_inner_loop = 0; @@ -364,7 +459,6 @@ std::vector> TileLoopNode::get_tiles_from_iter(TileGraphPa for (auto loop_idx: loop_idx_list) { auto iter_value = getLoopIndexValue(iter, loop_idx); iter_list.push_back(iter_value); - loop_size_list.push_back(tog_parser->get_loop_size(loop_idx)); if (tog_parser->get_loop_type(loop_idx)==LoopType::INNER_LOOP) nr_inner_loop++; if (tog_parser->get_loop_type(loop_idx)==LoopType::PARALLEL_LOOP) { @@ -375,6 +469,12 @@ std::vector> TileLoopNode::get_tiles_from_iter(TileGraphPa } } + /* Lookup given name's address */ + std::string base_addr_name = mem_node->get_base_addr_name(); + int base_addr_id = tog_parser->register_addr_name(base_addr_name); + addr_type base_addr = tog_parser->lookup(base_addr_name); + addr_type offset = std::inner_product(iter_list.begin(), iter_list.end(), mem_node->get_loop_stride_list().begin(), 0); + /* Calc numa id */ int numa_id = 0; auto numa_stride_size = tog_parser->lookupNumaInfo(base_addr_name).size(); @@ -387,77 +487,92 @@ std::vector> TileLoopNode::get_tiles_from_iter(TileGraphPa printIndexMap("[TOGParser] Store Node " + mem_node->get_base_addr_name() + " Numa_id: " + std::to_string(numa_id), iter); std::shared_ptr inst = std::make_shared( Opcode::MOVOUT, 0, - 0, base_addr, - mem_node->get_tile_size(), mem_node->get_precision(), iter_list, - mem_node->get_stride_list(), std::vector(), loop_size_list + 0, base_addr+offset, + mem_node->get_tile_size(), mem_node->get_tile_stride(), mem_node->get_precision(), + std::vector(1), mem_node->get_tag_stride_list(), std::vector() ); - inst->set_addr_name(base_addr_name); + inst->set_addr_name(base_addr_name, base_addr_id); + inst->prepare_tag_key(); inst->set_nr_inner_loop(nr_inner_loop); - inst->adjust_dram_address(); inst->set_is_async(mem_node->is_async_node()); inst->set_numa_id(numa_id); + if (mem_node->is_indirect()) { + inst->set_indirect_index_path(tog_parser->get_indirect_path()); + tog_parser->inc_indirect_counter(); + } link_map[tile_node] = inst; tile_vec.back()->append_instuction(inst); } else if (tile_node->get_type() == TileType::MEMORY_WAIT_NODE) { printIndexMap("[TOGParser] DMA Wait Node ", iter); std::shared_ptr wait_node = std::static_pointer_cast(tile_node); auto base_addr_name = wait_node->get_base_addr_name(); + int base_addr_id = tog_parser->register_addr_name(base_addr_name); addr_type base_addr = tog_parser->lookup(base_addr_name); /* Lookup given name's address */ std::vector iter_list; std::vector tag_list; + std::vector& tag_stride_list = wait_node->get_tag_stride_list(); + std::vector& tag_divider_list = wait_node->get_tag_divider_list(); + std::vector new_tag_stride_list; + std::vector accum_tag_list; auto& wait_tag_list = wait_node->get_tag_idx_list(); - int inner_step = std::stoi(tog_parser->getMetaByName("systolic_size")); - /* Add accumulation loop info to tag list */ - for (auto loop_idx = iter.begin(); loop_idx != iter.end(); ++loop_idx) { - /* FIXME. Used heuristic that wait_tag_size has 2d dim */ - if (tog_parser->get_loop_type(loop_idx->first)==LoopType::ACCUMULATION_LOOP && wait_tag_list.size() != 2) { - tag_list.push_back(loop_idx->second); - } - } - for (auto loop_idx: wait_tag_list) { - if (iter.find(loop_idx) == iter.end()) + for (int i=0; iget_loop_type(loop_idx)==LoopType::ACCUMULATION_LOOP) { + auto iter_value = getLoopIndexValue(iter, loop_idx); + accum_tag_list.push_back(iter_value); + } else { + auto iter_value = getLoopIndexValue(iter, loop_idx) / tag_divider_list.at(i); tag_list.push_back(iter_value); } } + /* Default accum tag */ + if (accum_tag_list.empty()) { + accum_tag_list.push_back(0); + } + + /* Skip accum stride */ + for (auto i : tag_stride_list) { + if (i!=-1) + new_tag_stride_list.push_back(i); + } + + spdlog::trace("[TOGParser] Wait Node {}, accum = [{}], tag = [{}], stride = [{}]", wait_node->get_base_addr_name(), + fmt::join(accum_tag_list, ", "), + fmt::join(tag_list, ", "), + fmt::join(new_tag_stride_list, ", ")); std::shared_ptr inst = std::make_shared( Opcode::BAR, 0, 0, base_addr, - std::vector(), 0, iter_list, - iter_list, tag_list, std::vector() + std::vector(), std::vector(), 0, + tag_list, new_tag_stride_list, accum_tag_list ); - inst->set_addr_name(base_addr_name); + inst->set_addr_name(base_addr_name, base_addr_id); + inst->prepare_tag_key(); link_map[tile_node] = inst; tile_vec.back()->append_instuction(inst); } else if (tile_node->get_type() == TileType::COMPUTE_NODE) { printIndexMap("[TOGParser] Compute Node ", iter); std::shared_ptr compute_node = std::static_pointer_cast(tile_node); - std::vector iter_list; + std::vector tag_list = {0}; + std::vector tag_stride_list = {1}; + std::vector accum_tag_list; std::shared_ptr inst = std::make_shared( Opcode::COMP, compute_node->get_cycle(), 0, 0, - std::vector(), 0, iter_list, iter_list, - std::vector(), std::vector() + std::vector(), std::vector(), 0, + tag_list, tag_stride_list, accum_tag_list ); inst->set_overlapping_cycle(compute_node->get_overlapping_cycle()); inst->set_compute_type(compute_node->get_compute_type()); - /* Check should we have to skip */ - auto output_idx_list = calc_output_idx(tog_parser, iter); // (M,N,K) order - if (compute_node->get_compute_type() == 1 && output_idx_list.size() == 3) { // FIXME. hardcoded type - bool skip = find_output_idx(tog_parser, output_idx_list); - if (skip) { - inst->set_compute_cycle(0); - inst->set_overlapping_cycle(0); - spdlog::trace("[TOGParser/Sparse] Skip output tile index: {}", fmt::join(output_idx_list, ",")); - } - } - link_map[tile_node] = inst; tile_vec.back()->append_instuction(inst); } else if (tile_node->get_type() == TileType::LOOP_INDEX_NODE) { @@ -519,6 +634,35 @@ std::vector> TileLoopNode::get_tiles_from_iter(TileGraphPa parent->append_child(child); /* Create new tile */ tile_vec.push_back(child); + } else if (tile_node->get_type() == TileType::STONNE_NODE) { + printIndexMap("[TOGParser] Stonne Node ", iter); + std::shared_ptr stonne_node = std::static_pointer_cast(tile_node); + std::shared_ptr inst = std::make_shared(Opcode::COMP); + link_map[tile_node] = inst; + tile_vec.back()->append_instuction(inst); + tile_vec.back()->set_custom_data(stonne_node->getDesc()); + tile_vec.back()->set_stonne_tile(true); + } else if (tile_node->get_type() == TileType::STONNE_TRACE_COMPUTE_NODE) { + std::shared_ptr stonne_node = std::static_pointer_cast(tile_node); + std::shared_ptr inst = std::make_shared(Opcode::COMP); + inst->set_compute_cycle(stonne_node->get_cycle()); + link_map[tile_node] = inst; + tile_vec.back()->append_instuction(inst); + tile_vec.back()->set_stonne_tile(true); + } else if (tile_node->get_type() == TileType::STONNE_TRACE_LOAD_NODE) { + std::shared_ptr stonne_node = std::static_pointer_cast(tile_node); + std::shared_ptr inst = std::make_shared(Opcode::MOVIN); + inst->set_trace_address(stonne_node->get_address()); + link_map[tile_node] = inst; + tile_vec.back()->append_instuction(inst); + tile_vec.back()->set_stonne_tile(true); + } else if (tile_node->get_type() == TileType::STONNE_TRACE_STORE_NODE) { + std::shared_ptr stonne_node = std::static_pointer_cast(tile_node); + std::shared_ptr inst = std::make_shared(Opcode::MOVOUT); + inst->set_trace_address(stonne_node->get_address()); + link_map[tile_node] = inst; + tile_vec.back()->append_instuction(inst); + tile_vec.back()->set_stonne_tile(true); } } @@ -555,20 +699,22 @@ void TileLoopNode::print_node() { spdlog::debug("{} stride: {} ", spaces, _stride); } -TileGraphParser::TileGraphParser(std::string onnx_path, json& attribute_json) { +TileGraphParser::TileGraphParser(std::string onnx_path, std::string attribute_path) { + loadConfig(attribute_path, _attribute_json); + _attribute_path = attribute_path; + /* Note: this parsing algorithm assume that all node are sorted in topological-order */ std::ifstream model_istream(onnx_path); google::protobuf::io::IstreamInputStream zero_copy_input(&model_istream); onnx::ModelProto model_proto; /* Attribute parsing */ - _attribute_json = attribute_json; if (_attribute_json.contains("address_info")) { auto address_info = _attribute_json["address_info"]; for (auto it = address_info.begin(); it != address_info.end(); ++it) { uint64_t value = it.value(); _arg_to_address[it.key()] = value; - spdlog::info("[TOGParser] Address Attribute key: {} address: 0x{:x}", it.key(), value); + spdlog::info("[TOGParser/Attribute] Address Attribute key: {} address: 0x{:x}", it.key(), value); } } if (_attribute_json.contains("address_numa_stride")) { @@ -578,9 +724,22 @@ TileGraphParser::TileGraphParser(std::string onnx_path, json& attribute_json) { for (auto value : value_list) { _arg_numa_stride[it.key()].push_back(value); } - spdlog::info("[TOGParser] Address numa info key: {} numa stride : {}", it.key(), fmt::join(_arg_numa_stride[it.key()], ", ")); + spdlog::info("[TOGParser/Attribute] Address numa info key: {} numa stride : {}", it.key(), fmt::join(_arg_numa_stride[it.key()], ", ")); + } + } + if (_attribute_json.contains("sram_alloc") and _attribute_json.contains("l2d_type") and _attribute_json["l2d_type"] == "datacache") { + auto sram_alloc_list = _attribute_json["sram_alloc"]; + spdlog::info("[TOGParser/Attribute] ================= SRAM Alloc Plan ================"); + for (auto it = sram_alloc_list.begin(); it != sram_alloc_list.end(); ++it) { + auto value_list = it.value(); + unsigned long long start = value_list.at(0); + unsigned long long end = value_list.at(1); + spdlog::info("[TOGParser/Attribute] {:16s}: 0x{:016x} ~ 0x{:016x}", it.key(), start, end); + Interval entry = {start, end, 0}; + _cache_plan.push_back(entry); } } + load_sparse_meta_data(); /* ONNX file parsing */ _tog_path = onnx_path; @@ -630,6 +789,22 @@ TileGraphParser::TileGraphParser(std::string onnx_path, json& attribute_json) { std::shared_ptr tile_node = std::make_shared(node_proto); /* Register output */ register_tile(tile_node); + } else if (type == TileType::STONNE_NODE) { + std::shared_ptr tile_node = std::make_shared(node_proto); + /* Register output */ + register_tile(tile_node); + } else if (type == TileType::STONNE_TRACE_COMPUTE_NODE) { + std::shared_ptr tile_node = std::make_shared(node_proto); + /* Register output */ + register_tile(tile_node); + } else if (type == TileType::STONNE_TRACE_LOAD_NODE) { + std::shared_ptr tile_node = std::make_shared(node_proto); + /* Register output */ + register_tile(tile_node); + } else if (type == TileType::STONNE_TRACE_STORE_NODE) { + std::shared_ptr tile_node = std::make_shared(node_proto); + /* Register output */ + register_tile(tile_node); } } @@ -639,6 +814,10 @@ TileGraphParser::TileGraphParser(std::string onnx_path, json& attribute_json) { } _tile_graph = std::make_unique(TileGraph(onnx_path, graph_name)); + _tile_graph->init_cache_plan(_cache_plan); + if (std::stoi(this->getMetaByName("stonneGraph"))) + _tile_graph->StonneGraph=true; + /* Generate subgraph */ if (_loop_nodes.empty()) { spdlog::warn("[TileGraphParser] Null Kernel \"{}\"", onnx_path); @@ -663,15 +842,17 @@ TileGraphParser::TileGraphParser(std::string onnx_path, json& attribute_json) { /* Iterate outer loop and initialize inner loop */ for (auto iter=_tile_graph->begin(); iter!=_tile_graph->end(); ++iter) { std::shared_ptr subgraph = std::make_shared(); - subgraph->set_core_id(getCoreIdFromJson(attribute_json, subgraph->get_id())); + subgraph->set_core_id(getCoreIdFromJson(_attribute_json, subgraph->get_id())); auto indices = iter.get_indices(); for (auto loop : _loop_nodes.at(last_outer_idx)) { std::shared_ptr outer_loop = std::static_pointer_cast(loop); + this->clear_tag_table(); // Clear tag table for each inner loop std::vector> sub_tiles = outer_loop->get_tiles_from_iter(this, indices); /* insert tiles to subgraph */ for (const auto& sub_tile: sub_tiles){ subgraph->add_tile(sub_tile); + sub_tile->set_owner(subgraph); } } /* insert subgraph to graph */ @@ -720,6 +901,26 @@ void TileGraphParser::register_tile(std::shared_ptr tile_node) { } } +std::vector TileGraphParser::calc_tag(std::vector& accum_tag, std::vector& tag_idx, std::vector& tag_stride) { + int key_offset = 0; + std::vector tag_key; + for (int i=0; i& tag_key) { + assert(_tag_table.find(std::make_pair(name, tag_key))==_tag_table.end()); + _tag_table[std::make_pair(name, tag_key)] = true; +} + +bool TileGraphParser::check_memory_tag(std::string name, std::vector& tag_key) { + return _tag_table.find(std::make_pair(name, tag_key))==_tag_table.end() ? false : true; +} + std::shared_ptr TileGraphParser::get_top_loop() { if (_loop_nodes.empty()) return nullptr; diff --git a/PyTorchSimBackend/src/TileGraphParser.h b/PyTorchSimBackend/src/TileGraphParser.h deleted file mode 100644 index 36ec8091..00000000 --- a/PyTorchSimBackend/src/TileGraphParser.h +++ /dev/null @@ -1,170 +0,0 @@ -#pragma once -#include -#include -#include -#include -#include -#include "TileGraph.h" -#include "Instruction.h" -#include "onnx/defs/schema.h" -#include "onnx/onnx-operators_pb.h" -#include "onnx/onnx_pb.h" - -using json = nlohmann::json; - -enum class TileType{ - LOOP_INDEX_NODE, - LOOP_END_NODE, - LOAD_NODE, - STORE_NODE, - COMPUTE_NODE, - MEMORY_WAIT_NODE -}; - -enum class LoopType { - NORMAL_LOOP, - PARALLEL_LOOP, - ACCUMULATION_LOOP, - INNER_LOOP -}; - -class TileNode { - public: - TileNode(onnx::NodeProto& node); - static TileType get_tile_type(std::string type); - void add_child(std::shared_ptr child) { _child.push_back(std::move(child)); } - std::vector>& get_child() { return _child; } - void add_parent(std::shared_ptr parent) { _parent.push_back(std::move(parent)); } - std::vector>& get_parent() { return _parent; } - std::vector& get_child_name() { return _child_name; } - std::vector& get_parent_name() { return _parent_name; } - TileType get_type() { return _type; } - std::shared_ptr get_owner_loop() { return _owner_loop; } - std::string get_name() { return _name; } - void set_owner_loop(std::shared_ptr owner) { _owner_loop=std::move(owner); } - virtual void print_node(); - void set_depth(int depth) { _depth=depth; } - int get_depth() { return _depth; } - - private: - std::vector> _parent; - std::vector> _child; - std::vector _parent_name; - std::vector _child_name; - std::shared_ptr _owner_loop; - std::string _name; - int _depth; - TileType _type; -}; - -class TileGraphParser { - public: - TileGraphParser(std::string onnx_path, json& attribute_json); - std::shared_ptr get_top_loop(); - std::unique_ptr& get_tile_graph() { return _tile_graph; } - addr_type lookup(std::string key); - void register_loop(std::shared_ptr); - void increase_loop_top() { _loop_stack_pointer++; } - void decrease_loop_top() { _loop_stack_pointer--; } - int get_loop_size(std::string key) { return std::get<0>(_loop_size_map[key]); } - int get_loop_step(std::string key) { return std::get<1>(_loop_size_map[key]); } - LoopType get_loop_type(std::string key) { return std::get<2>(_loop_size_map[key]); } - const std::map> & get_loop_map() { return _loop_size_map; } - const std::vector &lookupNumaInfo(std::string key); - int getCoreIdFromJson(const json& attribute_json, int subgraph_id); - std::string getMetaByName(std::string key) { return _tog_meta[key]; } - const json& get_attribute_file() { return _attribute_json; } - private: - void register_tile(std::shared_ptr tile_node); - void _tile_generate() {} - void _base_addr_update() {} - void _tile_index_generate() {} - int _loop_stack_pointer = 0; - - json _attribute_json; - std::string _tog_path; - std::map> _output_map; - std::vector>> _loop_nodes; - std::vector> _tile_vec; - std::unique_ptr _tile_graph; - std::map _arg_to_address; - std::map> _arg_numa_stride; - std::map> _loop_size_map; - std::map _tog_meta; -}; - -class TileComputeNode : public TileNode { - public: - TileComputeNode(onnx::NodeProto& node); - uint32_t get_cycle() { return _cycle; } - uint32_t get_overlapping_cycle() { return _overlapping_cycle; } - int get_compute_type() { return _compute_type; } - void print_node(); - - private: - std::map> tile_map; - uint32_t _cycle; - uint32_t _overlapping_cycle = 0; - int _compute_type; -}; - -class TileMemoryNode : public TileNode { - public: - TileMemoryNode(onnx::NodeProto& node); - std::string get_base_addr_name() { return _base_addr_name; } - size_t get_precision() { return _element_size; } - std::vector get_tile_size() { return _tile_size; } - std::vector& get_stride_list () { return _stride_list; } - std::vector& get_tag_idx_list() { return _tag_idx_list; } - std::vector& get_loop_idx_list() { return _loop_idx_list; } - bool is_async_node() { return _is_async; } - void print_node() override; - - private: - std::vector _tile_size; - std::vector _stride_list; - size_t _element_size; - bool _is_async; - std::string _base_addr_name; - std::vector _tag_idx_list; - std::vector _loop_idx_list; -}; - -class TileMemoryWaitNode : public TileNode { - public: - TileMemoryWaitNode(onnx::NodeProto& node); - std::string get_base_addr_name() { return _base_addr_name; } - std::vector& get_tag_idx_list() { return _tag_idx_list; } - void print_node() override; - - private: - std::vector _tag_idx_list; - std::string _base_addr_name; -}; - - - -class TileLoopNode : public TileNode { - public: - TileLoopNode(onnx::NodeProto& node); - void add_body(std::shared_ptr body) { _body_node.push_back(body); } - std::vector> get_tiles_from_iter(TileGraphParser*, std::map&); - std::string get_idx_name() { return _tile_index_name; } - uint64_t get_start() { return _start; } - uint64_t get_stride() { return _stride; } - uint64_t get_end() { return _end; } - LoopType get_loop_type() { return _loop_type; } - void print_node() override; - private: - std::string _tile_index_name; - uint64_t _stride; - uint64_t _start; - uint64_t _end; - LoopType _loop_type; - std::vector> _body_node; -}; - -class TileLoopEndNode : public TileNode { - public: - TileLoopEndNode(onnx::NodeProto& node) : TileNode(node) {} -}; diff --git a/PyTorchSimBackend/src/main.cc b/PyTorchSimBackend/src/main.cc index 67f19d6d..ecdd85aa 100644 --- a/PyTorchSimBackend/src/main.cc +++ b/PyTorchSimBackend/src/main.cc @@ -12,24 +12,8 @@ namespace po = boost::program_options; const char* env_value = std::getenv("BACKENDSIM_DRYRUN"); bool isDryRun = (env_value != nullptr && std::string(env_value) == "1"); -bool loadConfig(const std::string& config_path, json& config_json) { - std::ifstream config_file(config_path); - if (config_file.is_open()) { - config_file >> config_json; - config_file.close(); - spdlog::info("[LoadConfig] Success to open \"{}\"", config_path); - return true; - } else { - spdlog::error("[LoadConfig] Failed to open \"{}\"", config_path); - return false; - } -} - void launchKernel(Simulator* simulator, std::string onnx_path, std::string attribute_path, cycle_type request_time=0, int partiton_id=0) { - json attribute_json; - loadConfig(attribute_path, attribute_json); - - auto graph_praser = TileGraphParser(onnx_path, attribute_json); + auto graph_praser = TileGraphParser(onnx_path, attribute_path); std::unique_ptr& tile_graph = graph_praser.get_tile_graph(); tile_graph->set_arrival_time(request_time ? request_time : simulator->get_core_cycle()); spdlog::info("[Scheduler {}] Register graph path: {} operation: {} at {}", partiton_id, onnx_path, tile_graph->get_name(), simulator->get_core_cycle()); @@ -39,7 +23,9 @@ void launchKernel(Simulator* simulator, std::string onnx_path, std::string attri Simulator* create_simulator(std::string config_path) { json config_json; - loadConfig(config_path, config_json); + if(!loadConfig(config_path, config_json)) { + exit(1); + } SimulationConfig config = initialize_config(config_json); auto simulator = new Simulator(config); return simulator; @@ -87,7 +73,7 @@ void interactive_mode(Simulator* simulator) { cycle_type current_cycle = simulator->get_core_cycle(); std::cerr << "Current cycle: " << current_cycle << std::endl; }else if (token == "quit") { - spdlog::info("Exiting BackendSim."); + std::cerr << "Quit" << std::endl; break; } else { spdlog::error("Error: unknown command {} Available commands are: launch, until, quit.", token); @@ -95,6 +81,7 @@ void interactive_mode(Simulator* simulator) { if (isDryRun) std::cout << "[" << simulator->get_core_cycle() << "] BackendSim> "; } + simulator->cycle(); if (simulator->get_core_cycle()==0) simulator->until(0); simulator->print_core_stat(); @@ -157,6 +144,7 @@ int main(int argc, char** argv) { /* Get onnx_path, attribute from user input, request_time */ interactive_mode(simulator); } + delete simulator; /* Simulation time measurement */ auto end = std::chrono::high_resolution_clock::now(); diff --git a/PyTorchSimBackend/src/scheduler/Scheduler.cc b/PyTorchSimBackend/src/scheduler/Scheduler.cc index 5bf2f6ee..bb5d29cf 100644 --- a/PyTorchSimBackend/src/scheduler/Scheduler.cc +++ b/PyTorchSimBackend/src/scheduler/Scheduler.cc @@ -11,10 +11,12 @@ void Scheduler::schedule_graph(std::unique_ptr tile_graph) { refresh_status(); } -const std::shared_ptr Scheduler::peek_tile(int core_id, int slot_id) { +const std::shared_ptr Scheduler::peek_tile(int core_id, int slot_id, CoreType ctype) { if (_tile_graph.empty() || _tile_graph.at(0)->get_arrival_time() > *_core_cycle) return std::make_unique(Tile(Tile::Status::EMPTY)); - return _tile_graph.at(0)->peek_tile(core_id, slot_id); + if ((!_tile_graph.at(0)->StonneGraph && ctype == CoreType::WS_MESH) || (_tile_graph.at(0)->StonneGraph && ctype == CoreType::STONNE)) + return _tile_graph.at(0)->peek_tile(core_id, slot_id); + return std::make_unique(Tile(Tile::Status::EMPTY)); } std::shared_ptr Scheduler::get_tile(int core_id, int slot_id) { diff --git a/PyTorchSimFrontend/common_diff.py b/PyTorchSimFrontend/common_diff.py deleted file mode 100644 index 6c1c875c..00000000 --- a/PyTorchSimFrontend/common_diff.py +++ /dev/null @@ -1,1031 +0,0 @@ -import contextlib -import dataclasses -import functools -import itertools -import logging -import operator -import re -from collections import namedtuple -from itertools import chain -from typing import Any, Callable, ClassVar, Dict, List, NamedTuple, Optional, Set, Union - -import sympy -from sympy.printing.printer import Printer - -import torch -import torch.fx -from torch.utils._sympy.value_ranges import ValueRanges - -from .. import metrics -from ..utils import ( - DeferredLineBase, - free_symbol_startswith, - get_sympy_Expr_dtype, - IndentedBuffer, - sympy_dot, - sympy_subs, - unique, -) -from ..virtualized import ops, OpsValue, V - -schedule_log = torch._logging.getArtifactLogger(__name__, "schedule") - - -def data_type_logger(msg): - if schedule_log.isEnabledFor(logging.DEBUG): - schedule_log.debug("Data type propagation: %s", msg) - - -TensorArg = namedtuple("TensorArg", ["name", "buffer", "dtype"]) -SizeArg = namedtuple("SizeArg", ["name", "expr"]) - -DeviceCodegen = namedtuple("DeviceCodegen", ["scheduling", "wrapper_codegen"]) -device_codegens: Dict[str, DeviceCodegen] = {} - - -# The code generated by Inductor consists of two main parts: kernel code and wrapper code. -# For any new backend looking to integrate with Inductor, customization of these two main -# parts are necessary to generate its specific code. -# -# Kernel code generation is determined by different Scheduling. Consequently, a new -# backend needs to provide a custom Scheduling for its unique kernel code generation. Currently, -# CppScheduling and TritonScheduling serve the C++/OpenMP and Triton backends, respectively. -# -# For the Wrapper, Inductor provides a WrapperCodeGen class to generate the Python wrapper code -# that bridges kernels. This allows out-of-tree backends to inherit from WrapperCodeGen, -# and override specific member functions to create backend-specific Python wrapper code. -# -# Other classes, such as CppKernel and TritonKernel, used for code generation, typically form part -# of the logic for either Scheduling or WrapperCodeGen. So the Scheduling and WrapperCodeGen interfaces -# provide flexibility to the backend. A backend can choose to implement these classes from scratch, -# or reuse them by extending and overriding as necessary. And Inductor provides the registration API, -# register_backend_for_device, to equip a new backend at runtime. -# -# Intel has developed a new backend on top of Triton to support Intel GPUs, leveraging these interfaces. -# This backend can be used as a reference: -# https://github.com/intel/intel-extension-for-pytorch/blob/5dcc9d57e5422cf295e1a1ee97896d6b6a554a85/intel_extension_for_pytorch/_inductor/__init__.py#L9 -def register_backend_for_device( - device: str, device_scheduling: type, device_wrapper_codegen: type -): - device_codegens[device] = DeviceCodegen(device_scheduling, device_wrapper_codegen) - - -def get_scheduling_for_device(device: str): - return device_codegens[device].scheduling if device in device_codegens else None - - -def get_wrapper_codegen_for_device(device: str): - return ( - device_codegens[device].wrapper_codegen if device in device_codegens else None - ) - - -def index_prevent_reordering(index: List[sympy.Expr], index_vars, sizes): - from ..ir import FlexibleLayout - - # added contiguous index prevents reordering - return [*index, sympy_dot(index_vars, FlexibleLayout.contiguous_strides(sizes))] - - -@functools.lru_cache(None) -def boolean_ops(): - return ( - "is_inf", - "is_nan", - "bitwise_xor", - "logical_not", - "signbit", - "le", - "lt", - "ge", - "gt", - "eq", - "ne", - ) - - -DTYPE_TO_COMPUTATION_DTYPE = { - torch.bfloat16: torch.float, - torch.float16: torch.float, - **{ - dtype: dtype - for dtype in [ - torch.bool, - torch.float32, - torch.float64, - torch.int8, - torch.int16, - torch.int32, - torch.int64, - torch.uint8, - ] - }, -} - - -class DataTypePropagation: - def __init__(self, body) -> None: - self.body = body - self.graphs: Dict[Union[Callable[..., Any], str], Any] = { - "root": body.root_block.graph - } - for k, v in body.subblocks.items(): - self.graphs[k] = v.graph - - def deduce_node_dtype_by_inputs(self, node: torch.fx.Node): - inputs = node.all_input_nodes - input_nodes = [ - n for n in inputs if isinstance(n, torch.fx.Node) and n.op != "placeholder" - ] - if len(input_nodes) == 0: - return None - - all_input_nodes_propogated = all( - OptimizationContext.key in n.meta - and n.meta[OptimizationContext.key].dtype is not None - for n in input_nodes - ) - if not all_input_nodes_propogated: - return None - - return functools.reduce( - torch.promote_types, - [n.meta[OptimizationContext.key].dtype for n in input_nodes], - ) - - def deduce_node_dtype_by_subgraph(self, node: torch.fx.Node): - sub_graph = self.graphs[node.target] - dtype = self.propagate_graph(sub_graph) - assert dtype - return dtype - - def deduce_node_dtype(self, node: torch.fx.Node): - if node.target in boolean_ops(): - return torch.bool - - if node.op == "placeholder": - return None - - if node.target == "output": - # we can infer output node if it only have 1 arg - if len(node.args) != 1: - return None - - if node.target in ( - "to_dtype", - "index_expr", - ): - return node.args[-1] - - if node.target in ( - "rand", - "randn", - ): - return torch.float - - if node.target in ( - "get_index", - "index_expr", - ): - return torch.int64 - - if node.target in ( - "load", - "store", - "store_reduction", - ): - buf_name = node.args[1] - return V.graph.get_dtype(buf_name) - - if node.target == operator.getitem: - return self.deduce_node_dtype(node.args[0]) - - assert isinstance(node.target, str) - - if node.target == "reduction": - return node.args[1] - - if node.target == "constant": - return DTYPE_TO_COMPUTATION_DTYPE[node.args[-1]] - - if node.target.startswith("masked_subblock"): - return self.deduce_node_dtype_by_subgraph(node) - - return self.deduce_node_dtype_by_inputs(node) - - def propagate_graph(self, graph: torch.fx.Graph): - assert graph.nodes - graph_dtype = None - # For masked_subblock, we use output's dtype to represent - # the dtype of this subgraph. For other cases, graph_dtype - # might be None - for node in graph.nodes: - if OptimizationContext.key in node.meta: - opt_ctx = node.meta[OptimizationContext.key] - else: - opt_ctx = OptimizationContext() - - opt_ctx.dtype = self.deduce_node_dtype(node) - node.meta[OptimizationContext.key] = opt_ctx - if node.target == "output": - graph_dtype = opt_ctx.dtype - return graph_dtype - - def propagate(self): - self.propagate_graph(self.graphs["root"]) - - @classmethod - def propagate_loopbody(cls, body): - return cls(body).propagate() - - @classmethod - def propagate_scheduler_node(cls, node): - from ..ir import LoopBody - from ..scheduler import SchedulerNode - - assert isinstance(node, SchedulerNode) - assert isinstance(node._body, LoopBody) - DataTypePropagation.propagate_loopbody(node._body) - - -class ExprPrinter(Printer): - @staticmethod - def paren(string): - def all_in_parens(string): - if string[0] != "(" or len(string) < 2: - return False - count = 1 - for i, char in enumerate(string[1:]): - if char == "(": - count += 1 - elif char == ")": - count -= 1 - if count == 0 and i != len(string) - 2: - return False - assert count == 0 - return True - - if ( - isinstance(string, CSEVariable) - or re.match(r"^[a-z0-9_.]+$", string, re.I) - or re.match(r"^\([^)]*\)$", string, re.I) - or string == "" - ): - return string - # don't put extra parens for strings that are already wrapped in parens - if all_in_parens(string): - return string - return f"({string})" - - def _print_Pow(self, expr): - # Pow() confuses triton - base, exp = expr.args - # NB: Remember this is sizevar computation! You don't typically - # expect to have to do floating point computation including exponents - # in sizevar compute. Instead of adding support for floating - # point pow, you should make upstream retranslate the Sympy expression - # into Tensor expressions earlier and do that instead. - if exp == 0.5: - return self._helper_sqrt(base) # type: ignore[attr-defined] - elif exp == -0.5: - return "1/" + self._helper_sqrt(base) # type: ignore[attr-defined] - base = self._print(base) - assert exp == int(exp), exp - exp = int(exp) - if exp > 0: - return "*".join([self.paren(base)] * exp) - elif exp < 0: - return "1/" + self.paren("*".join([self.paren(base)] * abs(exp))) - else: # exp == 0 - return "1" - - def _print_Unequality(self, expr): - return " != ".join(map(self.paren, map(self._print, expr.args))) - - def _print_Mul(self, expr): - return "*".join(map(self.paren, map(self._print, expr.args))) - - def _print_Add(self, expr): - return " + ".join(map(self.paren, map(self._print, expr.args))) - - def _print_Mod(self, expr): - return " % ".join(map(self.paren, map(self._print, expr.args))) - - def _print_CleanDiv(self, expr): - return self._print_FloorDiv(expr) # type: ignore[attr-defined] - - def _print_GreaterThan(self, expr): - # GreaterThan: >= - # StrictlyGreaterThan: > - # Go figure... - return " >= ".join(map(self.paren, map(self._print, expr.args))) - - -class PythonPrinter(ExprPrinter): - def _print_ModularIndexing(self, expr): - x, div, mod = expr.args - x = self.paren(self.doprint(x)) - div = self.paren(self.doprint(div)) - mod = self.paren(self.doprint(mod)) - if div != "1": - x = f"({x} // {div})" - return f"{x} % {mod}" - - def _print_FloorDiv(self, expr): - x, div = expr.args - x = self.paren(self.doprint(x)) - div = self.paren(self.doprint(div)) - return f"({x} // {div})" - - def _helper_sqrt(self, expr): - return f"math.sqrt({self._print(expr)})" - - def _print_floor(self, expr): - assert len(expr.args) == 1 - return f"math.floor({self._print(expr.args[0])})" - - def _print_ceiling(self, expr): - assert len(expr.args) == 1 - return f"math.ceil({self._print(expr.args[0])})" - - -class OpOverrides: - def __init__(self, parent): - super().__init__() - self._parent = parent - - def __getattr__(self, item): - return getattr(self._parent, item) - - @staticmethod - def identity(value): - # used to trigger cse - return value - - @staticmethod - def constant(value, dtype): - return repr(value) - - @staticmethod - def reciprocal(x): - return ops.div("1", x) - - @staticmethod - def square(x): - return ops.mul(x, x) - - @staticmethod - def bitwise_not(x): - return f"~{ExprPrinter.paren(x)}" - - @staticmethod - def logical_not(a): - return f"{ExprPrinter.paren(a)} == 0" - - @staticmethod - def bitwise_and(x, y): - return f"{ExprPrinter.paren(x)} & {ExprPrinter.paren(y)}" - - @staticmethod - def bitwise_or(x, y): - return f"{ExprPrinter.paren(x)} | {ExprPrinter.paren(y)}" - - @staticmethod - def bitwise_xor(x, y): - return f"{ExprPrinter.paren(x)} ^ {ExprPrinter.paren(y)}" - - @staticmethod - def bitwise_left_shift(x, y): - return f"{ExprPrinter.paren(x)} << {ExprPrinter.paren(y)}" - - # TODO(fdrocha): this is currently not being used anywhere, - # pending on moving triton pin past 972b761 - @staticmethod - def bitwise_right_shift(x, y): - return f"{ExprPrinter.paren(x)} >> {ExprPrinter.paren(y)}" - - @staticmethod - def remainder(a, b): - r = ops.mod(a, b) - return ops.where(f"(({r} != 0) & (({r} < 0) != ({b} < 0)))", ops.add(r, b), r) - - @staticmethod - def load_seed(name, offset): - return ops.load(name, sympy.Integer(offset)) - - -class DeferredLine(DeferredLineBase): - """A line that can be 'unwritten' by adding name to V.graph.removed_buffers""" - - def __init__(self, name, line): - super().__init__(line) - self.name = name - - def __call__(self): - if ( - self.name not in V.graph.removed_buffers - and self.name not in V.graph.inplaced_to_remove - ): - return self.line - return None - - def _new_line(self, line): - return DeferredLine(self.name, line) - - -class BracesBuffer(IndentedBuffer): - def indent(self, offset=1): - @contextlib.contextmanager - def ctx(): - for _ in range(offset): - self.writeline("{") - self._indent += 1 - for _ in range(-offset): - self._indent -= 1 - self.writeline("}") - yield - for _ in range(-offset): - self.writeline("{") - self._indent += 1 - for _ in range(offset): - self._indent -= 1 - self.writeline("}") - - return ctx() - - -class InplacedBuffer(NamedTuple): - inner_name: str - other_names: List[str] - - -class KernelArgs: - @staticmethod - def _lookup(prefix, odict, name): - assert isinstance(name, (str, sympy.Symbol)) - if name not in odict: - odict[name] = f"{prefix}{len(odict)}" - return odict[name] - - def __init__(self, sizevars=None): - self.input_buffers = dict() - self.output_buffers = dict() - self.inplace_buffers = dict() - self.sizevars = sizevars or dict() - - def __repr__(self): - return "KernelArgs({})".format( - ", ".join( - map( - repr, - [ - self.input_buffers, - self.output_buffers, - self.inplace_buffers, - self.sizevars, - ], - ) - ) - ) - - def _buffer_is_marked_removed(self, name): - return isinstance(name, str) and name.startswith("REMOVED") - - def input(self, name): - if V.graph.scheduler: - name = V.graph.scheduler.mutation_real_name.get(name, name) - assert name not in V.graph.removed_buffers, name - if name in self.output_buffers: - return self.output_buffers[name] - if name in self.inplace_buffers: - return self.inplace_buffers[name].inner_name - if name.startswith("seed"): - return self._lookup("seed", self.input_buffers, name) - return self._lookup("in_ptr", self.input_buffers, name) - - def output(self, name): - if V.graph.scheduler: - name = V.graph.scheduler.mutation_real_name.get(name, name) - assert name not in V.graph.removed_buffers, name - if name in self.inplace_buffers: - return self.inplace_buffers[name].inner_name - return self._lookup("out_ptr", self.output_buffers, name) - - def make_inplace(self, input_name, output_name): - assert output_name not in self.inplace_buffers - if input_name in self.inplace_buffers: - buf = self.inplace_buffers[input_name] - buf.other_names.append(output_name) - self.inplace_buffers[output_name] = buf - else: - buf = InplacedBuffer( - f"in_out_ptr{len(unique(self.inplace_buffers.values()))}", - [input_name, output_name], - ) - self.inplace_buffers[input_name] = buf - self.inplace_buffers[output_name] = buf - - def seed_offset(self, name, value): - if value in self.sizevars: - return self.sizevars[value] - if name in self.sizevars.values(): - name = ( - f"{name}{sum(1 for v in self.sizevars.values() if v.startswith(name))}" - ) - self.sizevars[value] = name - return name - - def size(self, name): - if str(name) == "seed": - self.sizevars["seed"] = "seed" - return "seed" - return self._lookup("ks", self.sizevars, name) - - def call_names(self): - return chain( - self.input_buffers.keys(), self.output_buffers.keys(), self.sizevars.keys() - ) - - def wrap_ptr_arg(self, buf, dtype): - return f"c_void_p({buf}.data_ptr())" - - def wrap_size_arg(self, size): - return f"c_long({size})" - - def cpp_argdefs(self): - from .cpp import DTYPE_TO_CPP, INDEX_TYPE - - # TODO(jansel): replace this with data from scheduler - buffer_types = {x.get_name(): x.get_dtype() for x in V.graph.buffers} - for name, val in V.graph.graph_inputs.items(): - if isinstance(val, sympy.Expr): - buffer_types[name] = get_sympy_Expr_dtype(val) - else: - buffer_types[name] = val.get_dtype() - buffer_types.update( - {name: val.dtype for name, val in V.graph.constants.items()} - ) - - call_args = [] - arg_defs = [] - arg_types = [] - for inplaced in unique(self.inplace_buffers.values()): - if self._buffer_is_marked_removed(inplaced): - continue - outer = inplaced.other_names[-1] - inner = inplaced.inner_name - dtype = buffer_types[outer] - cpp_dtype = DTYPE_TO_CPP[dtype] - arg_defs.append(f"{cpp_dtype}* {inner}") - call_args.append(self.wrap_ptr_arg(outer, dtype)) - arg_types.append(f"{cpp_dtype}*") - for outer, inner in self.input_buffers.items(): - if outer in self.inplace_buffers: - continue - dtype = buffer_types[outer] - cpp_dtype = DTYPE_TO_CPP[dtype] - arg_defs.append(f"const {cpp_dtype}* {inner}") - call_args.append(self.wrap_ptr_arg(outer, dtype)) - arg_types.append(f"const {cpp_dtype}*") - for outer, inner in self.output_buffers.items(): - if outer in self.inplace_buffers or self._buffer_is_marked_removed(inner): - continue - dtype = buffer_types[outer] - cpp_dtype = DTYPE_TO_CPP[dtype] - arg_defs.append(f"{cpp_dtype}* {inner}") - call_args.append(self.wrap_ptr_arg(outer, dtype)) - arg_types.append(f"{cpp_dtype}*") - for outer, inner in self.sizevars.items(): - arg_defs.append(f"const {INDEX_TYPE} {inner}") - call_args.append(self.wrap_size_arg(outer)) - arg_types.append(f"const {INDEX_TYPE}") - return arg_defs, call_args, arg_types - - def python_argdefs(self): - arg_defs = [] - call_args = [] - precompile_args: List[Union[TensorArg, SizeArg]] = [] - for inplaced in unique(self.inplace_buffers.values()): - if self._buffer_is_marked_removed(inplaced): - continue - arg_defs.append(inplaced.inner_name) - call_args.append(inplaced.other_names[-1]) - precompile_args.append( - TensorArg( - inplaced.inner_name, - inplaced.other_names[-1], - V.graph.get_dtype(inplaced.other_names[-1]), - ) - ) - for outer, inner in chain( - self.input_buffers.items(), self.output_buffers.items() - ): - if outer in self.inplace_buffers or self._buffer_is_marked_removed(inner): - continue - arg_defs.append(inner) - call_args.append(outer) - precompile_args.append(TensorArg(inner, outer, V.graph.get_dtype(outer))) - for outer, inner in self.sizevars.items(): - arg_defs.append(inner) - call_args.append(outer) - precompile_args.append(SizeArg(inner, outer)) - - return arg_defs, call_args, precompile_args - - def aliases(self): - for inplaced in unique(self.inplace_buffers.values()): - if self._buffer_is_marked_removed(inplaced): - continue - for other in inplaced.other_names: - if other in V.graph.inplaced_to_remove: - continue - if other in self.input_buffers: - yield self.input_buffers[other], inplaced.inner_name - if other in self.output_buffers: - yield self.output_buffers[other], inplaced.inner_name - - def is_removed(self, name): - def _is_removed(name, buffers): - return name not in buffers or self._buffer_is_marked_removed(buffers[name]) - - return _is_removed(name, self.output_buffers) and _is_removed( - name, self.inplace_buffers - ) - - # Includes inplace buffers, excludes removed buffers. Essentially, - # after you do a call into this kernel, which buffers actually contain - # updated data? Modeled off of python_argdefs. - def live_output_buffers(self): - live_outs = set() - for inplaced in unique(self.inplace_buffers.values()): - if self._buffer_is_marked_removed(inplaced): - continue - live_outs.add(inplaced.other_names[-1]) - for outer, inner in self.output_buffers.items(): - if outer in self.inplace_buffers or self._buffer_is_marked_removed(inner): - continue - live_outs.add(outer) - return live_outs - - -class CSEVariable: - """A CSEVariable is just a name for an expression but it is useful to be able to annotate them on a backend dependent basis. - To do so, the backends can simply overload `Kernel.create_cse_var` - The "CSEVariable.update_on_args" method gives you a hook for annotations - See example of TritonCSEVariable in triton.py - """ - - def __init__(self, name, bounds: ValueRanges): - assert isinstance(bounds, ValueRanges) - self.name = name - self.bounds = bounds - - def __str__(self): - return self.name - - def __hash__(self) -> int: - return hash(self.name) - - def __eq__(self, other) -> bool: - return type(other) == type(self) and other.name == self.name - - def update_on_args(self, name, args, kwargs): - pass - - -class CppWrapperKernelArgs(KernelArgs): - def wrap_ptr_arg(self, buf, dtype): - from .cpp import DTYPE_TO_CPP - - return f"({DTYPE_TO_CPP[dtype]}*)({buf}.data_ptr())" - - def wrap_size_arg(self, size): - return f"{size}" - - -class CSE: - """Common subexpression elimination""" - - def __init__( - self, - prefix="", - suffix="", - name_prefix="tmp", - iter_buffers=None, - store_cache=None, - reduction_cache=None, - varname_map=None, - ): - self.prefix = prefix - self.suffix = suffix - self.cache = {} - self.name_prefix = name_prefix - self.store_cache = store_cache or {} - self.reduction_cache = reduction_cache or {} - self.iter_buffer_ids = iter_buffers or itertools.count() - self.invalidated_stores = set() - self.varname_map = varname_map or {} - - def invalidate(self, keep_vars: Set[str]): - for name, tmp in list(self.store_cache.items()): - if tmp not in keep_vars: - del self.store_cache[name] - self.invalidated_stores.add(name) - self.cache = {k: v for k, v in self.cache.items() if v in keep_vars} - - def clone(self): - # Note(fdrocha): reduction_cache is not being cloned, not sure if this is intentional - return CSE( - prefix=self.prefix, - suffix=self.suffix, - name_prefix=self.name_prefix, - iter_buffers=self.iter_buffer_ids, - store_cache=self.store_cache, - varname_map=self.varname_map, - ) - - def generate( - self, - buffer: IndentedBuffer, - expr: Union[str, CSEVariable, OpsValue], - *, - bounds: ValueRanges = ValueRanges.unknown(), - write=True, - assignment=True, - ) -> CSEVariable: - if isinstance(expr, OpsValue): - expr = expr.value - - assert isinstance(expr, (str, CSEVariable)), type(expr) - assert write or assignment - if isinstance(expr, CSEVariable): - # If the expressions were always created with all the information, we could - # assert expr.bounds == bounds, but sometimes the expression is created - # with the loose ValueRanges.unknown(), so we need to tighten the bounds - expr.bounds = expr.bounds.tighten(bounds) - return expr - cache_key = expr - var = self.cache.get(cache_key, None) - if not var: - var = self.newvar(bounds) if assignment else None - self.cache[cache_key] = var - if write: - if V.kernel.current_node: - V.kernel.current_node.codegen_originating_info( - buffer, only_once=True - ) - if assignment: - line = f"{self.prefix}{var} = {expr}{self.suffix}" - else: - line = f"{expr}{self.suffix}" - buffer.writeline(line) - else: - var.bounds = var.bounds.tighten(bounds) - - return var - - def newvar(self, bounds: ValueRanges = ValueRanges.unknown()) -> CSEVariable: - var_name = f"{self.name_prefix}{next(self.iter_buffer_ids)}" - var = V.kernel.create_cse_var(var_name, bounds) - self.varname_map[var_name] = var - return var - - -class CodeGen: - def __init__(self): - super().__init__() - self.exit_stack = contextlib.ExitStack() - - def __enter__(self): - self.exit_stack.__enter__() - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - self.exit_stack.__exit__(exc_type, exc_val, exc_tb) - - -class Kernel(CodeGen): - newvar_prefix = "" - suffix = "" - overrides = None - load_format = None - store_format = None - - def __init__(self, args=None): - super().__init__() - metrics.generated_kernel_count += 1 - self.args = args or KernelArgs() - self.loads = IndentedBuffer() - self.compute = IndentedBuffer() - self.stores = IndentedBuffer() - self.cse = CSE(self.newvar_prefix, self.suffix) - self.must_keep_buffers = set() - self.store_buffer_names = set() - # set in set_current_node - self.current_node = None - self.node_to_bounds: Optional[Dict[torch.fx.Node, ValueRanges]] = None - - @contextlib.contextmanager - def set_current_node(self, node): - prior = self.current_node - self.current_node = node - self.node_to_bounds = node._body.bounds().get_bounds() - try: - yield - finally: - self.current_node = prior - - @contextlib.contextmanager - def swap_buffers(self, lb, cb=None, sb=None): - if cb is None: - cb = lb - loads = self.loads - compute = self.compute - stores = self.stores - cse = self.cse - self.loads = lb - self.compute = cb - self.stores = sb - self.cse = cse.clone() - try: - yield - finally: - self.loads = loads - self.compute = compute - self.stores = stores - self.cse = cse - - def load(self, name: str, index: sympy.Expr): - raise NotImplementedError() - - def indirect_load(self, name: str, index: sympy.Expr): - """A load the depends on an index we have read""" - prior = self.loads - try: - # put the load in the compute section as it might have deps - self.loads = self.compute - return self.load(name, index) - finally: - self.loads = prior - - def store_reduction(self, name, index, value): - raise NotImplementedError() - - def store(self, name, index, value, mode=None): - raise NotImplementedError() - - def reduction(self, dtype, src_dtype, reduction_type, value): - raise NotImplementedError() - - def bucketize( - self, - values, - offsets_name: str, - offsets_size: sympy.Expr, - indexing_dtype: torch.dtype, - right: bool, - ): - """ - See [Note: Inductor bucketize op] - """ - raise NotImplementedError() - - def __enter__(self): - class CSEProxy: - self.name = "CSEProxy" - - @staticmethod - def __getattr__(name: str) -> Callable[..., CSEVariable]: # type: ignore[misc] - def inner(*args, **kwargs): - # TritonTemplateKernel has no current_node - buf_bounds = ValueRanges.unknown() - if hasattr(V.interpreter, "current_node"): - fx_node = V.interpreter.current_node - assert isinstance(self.node_to_bounds, dict) - buf_bounds = self.node_to_bounds.get( - fx_node, ValueRanges.unknown() - ) - - csevar = self.cse.generate( - self.compute, - getattr(parent_handler, name)(*args, **kwargs), # type: ignore[has-type] - bounds=buf_bounds, - ) - csevar.update_on_args(name, args, kwargs) - return csevar - - return inner - - @staticmethod - def indirect_indexing(index_var, size, check=True): - # Skip CSE since this doesn't return an expression - return self.indirect_indexing(index_var, size, check) # type: ignore[attr-defined] - - @staticmethod - def load(name: str, index: sympy.Expr): - if name in self.cse.invalidated_stores: - # A load from an invalidated store requires us to - # keep the actual buffer around - V.kernel.must_keep_buffers.add(name) - if free_symbol_startswith(index, "tmp"): - return self.indirect_load(name, index) - store_cache = self.cse.store_cache - if name in store_cache: - return store_cache[name] - return self.load(name, index) - - @staticmethod - def store(name, index, value, mode=None): - self.store_buffer_names.add(name) - if mode is None: - self.cse.store_cache[name] = value - if self.current_node: - for other_name in self.current_node.get_mutations(): - self.cse.store_cache[other_name] = value - if name not in V.graph.removed_buffers: - return self.store(name, index, value, mode=mode) - - @staticmethod - def store_reduction(name, index, value): - self.store_buffer_names.add(name) - self.cse.store_cache[name] = value - if self.current_node: - for other_name in self.current_node.get_mutations(): - self.cse.store_cache[other_name] = value - - if name not in V.graph.removed_buffers: - return self.store_reduction(name, index, value) - - @staticmethod - def reduction(dtype, src_dtype, reduction_type, value): - return self.reduction(dtype, src_dtype, reduction_type, value) - - @staticmethod - def bucketize( - values, - offsets_name: str, - offsets_size: sympy.Expr, - indexing_dtype: torch.dtype, - right: bool, - ): - """ - [Note: Inductor bucketize op] - - Given values (tensor) and offsets_name (reference to the name of a 1D - tensor), calculate the bucket that each value belongs to. - - e.g. for values [-1, 0, 1, 2, 3, 4, 5, 9], offsets [0, 4, 4, 8], right=True - return = [ 0, 1, 1, 1, 1, 3, 3, 4]. - - When right == False, bucket i refers to range (offsets[i], offsets[i+1]]. - When right == True, bucket i refers to range [offsets[i], offsets[i+1]). - - Offsets must be non-decreasing or the result is undefined. - """ - return self.bucketize( - values, offsets_name, offsets_size, indexing_dtype, right - ) - - super().__enter__() - assert self.overrides - parent_handler = self.overrides(V.get_ops_handler()) - self.exit_stack.enter_context(V.set_ops_handler(CSEProxy())) - self.exit_stack.enter_context(V.set_kernel_handler(self)) - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - if V.graph.scheduler: - V.graph.scheduler.remove_kernel_local_buffers() - super().__exit__(exc_type, exc_val, exc_tb) - - def rename_indexing(self, index) -> sympy.Expr: - # adds the necessary kernel args for index expressions - # and renames variables in index expressions to kernel arg names - if isinstance(index, (list, tuple)): - return [self.rename_indexing(x) for x in index] - index = V.graph.sizevars.simplify(index) - sorted_symbols = sorted(index.free_symbols, key=lambda s: s.name) - replacements = { - x: self.args.size(x) - for x in sorted_symbols - if x.name.startswith("s") or x.name.startswith("ps") - } - return sympy_subs(index, replacements) - - def create_cse_var(self, *args, **kwargs): - return CSEVariable(*args, **kwargs) - - -@dataclasses.dataclass -class OptimizationContext: - key: ClassVar[str] = "opt_ctx" - - # Load value as mask - is_load_as_mask: bool = False - - dtype: torch.dtype = None - ops_name: str = "" - is_most_inner_loop_irrevelant: bool = False - - # Load uint8 value as float32 - is_load_uint8_as_float: bool = False \ No newline at end of file diff --git a/PyTorchSimFrontend/extension_codecache.py b/PyTorchSimFrontend/extension_codecache.py index a3aa1e28..83561bd4 100644 --- a/PyTorchSimFrontend/extension_codecache.py +++ b/PyTorchSimFrontend/extension_codecache.py @@ -1,5 +1,3 @@ -import getpass -import tempfile import os import re import shlex @@ -16,7 +14,7 @@ LOCK_TIMEOUT = 600 def hash_prefix(hash_value): - return hash_value[1:5] + return hash_value[1:12] def get_write_path(src_code): return os.path.join(extension_config.CONFIG_TORCHSIM_DUMP_PATH, "tmp", hash_prefix(get_hash(src_code.strip()))) @@ -31,6 +29,21 @@ def dump_metadata(args, arg_attributes, path): file.write(f'{arg_name}=({arg_attribute[0]}, {arg.dtype}, {arg.shape})\n') return +def parse_stack_sizes(file_path): + meta_path = file_path.split(".")[0]+".meta" + cmd = ["riscv64-unknown-elf-objcopy", "--dump-section", f".stack_sizes={meta_path}", file_path, "/dev/null"] + subprocess.run(cmd, check=True) + + with open(meta_path, 'rb') as f: + stack_sizes_data = list(f.read()) + if len(stack_sizes_data) <= 17: + raise ValueError("Invalid .stack_sizes section size") + + stack_size_bytes = stack_sizes_data[8:-9] + stack_size = int.from_bytes(stack_size_bytes, byteorder='little') + return stack_size + + def llvm_compile_command(input, output): opt_output = f"{input[:-3]}_opt.ll" return [re.sub(r"[ \n]+", " ", @@ -44,18 +57,21 @@ def llvm_compile_command(input, output): """, ).strip()] -def mlir_compile_command(filename, vectorlane_size, tile_size, vlen=256): +def mlir_compile_command(filename, vectorlane_size, vlen=256): return [re.sub(r"[ \n]+", " ", f""" {extension_config.CONFIG_TORCHSIM_LLVM_PATH}/mlir-opt \ -test-loop-padding \ - -dma-fine-grained='systolic-array-size={vectorlane_size} tile-size={tile_size[0]},{tile_size[1]},{tile_size[2]}' \ + -dma-fine-grained='systolic-array-size={vectorlane_size}' \ + -global-idx='vlen={vlen}' \ -test-pytorchsim-to-vcix='systolic-array-size={vectorlane_size} vlen={vlen}' \ + -test-memref-to-gemmini="vectorlane={vectorlane_size}" \ + -convert-linalg-to-loops \ + -convert-vector-to-scf='full-unroll' \ -lower-affine \ + -finalize-memref-to-llvm \ -lower-vector-multi-reduction \ -convert-vector-to-llvm \ - -test-memref-to-gemmini="vectorlane={vectorlane_size}" \ - -finalize-memref-to-llvm \ -convert-arith-to-llvm \ -convert-math-to-llvm \ -convert-scf-to-cf \ @@ -63,6 +79,7 @@ def mlir_compile_command(filename, vectorlane_size, tile_size, vlen=256): -convert-func-to-llvm \ -convert-index-to-llvm \ -reconcile-unrealized-casts \ + {'--mlir-print-ir-after-all' if extension_config.CONFIG_TORCHSIM_DUMP_MLIR_IR else ''} \ {filename}.mlir -o {filename}_llvm.mlir """, ).strip(), @@ -73,23 +90,30 @@ def mlir_compile_command(filename, vectorlane_size, tile_size, vlen=256): ).strip(), re.sub(r"[ \n]+", " ", f""" - {extension_config.CONFIG_TORCHSIM_LLVM_PATH}/llc -relocation-model=pic -march=riscv64 -mattr=+m,+f,+d,+a,+c,+v,+xsfvcp,zvl{vlen}b -O2 {filename}.ll -o {filename}.s + {extension_config.CONFIG_TORCHSIM_LLVM_PATH}/llc \ + -relocation-model=pic -march=riscv64 -O3 --stack-size-section \ + -mattr=+m,+f,+d,+a,+c,+v,+xsfvcp,zvl{vlen}b \ + {'--print-after-all' if extension_config.CONFIG_TORCHSIM_DUMP_LLVM_IR else ''} \ + -O2 {filename}.ll -o {filename}.s """, ).strip()] -def mlir_gem5_compile_command(filename, sample_filename, tog_file, vectorlane_size, tile_size, vlen=256): +def mlir_gem5_compile_command(filename, sample_filename, tog_file, vectorlane_size, vlen=256): return [re.sub(r"[ \n]+", " ", f""" {extension_config.CONFIG_TORCHSIM_LLVM_PATH}/mlir-opt \ -test-loop-padding='timing_mode=1' \ - -dma-fine-grained='systolic-array-size={vectorlane_size} tile-size={tile_size[0]},{tile_size[1]},{tile_size[2]}' \ - -test-pytorchsim-to-vcix='systolic-array-size={vectorlane_size} vlen=256' \ - -test-tile-operation-graph='vectorlane={vectorlane_size}' \ + -dma-fine-grained='systolic-array-size={vectorlane_size}' \ + -global-idx='vlen={vlen}' \ + -test-pytorchsim-to-vcix='systolic-array-size={vectorlane_size} vlen={vlen}' \ + -test-tile-operation-graph='vectorlane={vectorlane_size} tls_mode={extension_config.CONFIG_TLS_MODE}' \ + -test-memref-to-gemmini="vectorlane={vectorlane_size} timing=1" \ + -convert-linalg-to-loops \ + -convert-vector-to-scf='full-unroll' \ -lower-affine \ + -finalize-memref-to-llvm \ -lower-vector-multi-reduction \ -convert-vector-to-llvm \ - -test-memref-to-gemmini="vectorlane={vectorlane_size} timing=1" \ - -finalize-memref-to-llvm \ -convert-arith-to-llvm \ -convert-math-to-llvm \ -convert-scf-to-cf \ @@ -97,6 +121,7 @@ def mlir_gem5_compile_command(filename, sample_filename, tog_file, vectorlane_si -convert-func-to-llvm \ -convert-index-to-llvm \ -reconcile-unrealized-casts \ + {'--mlir-print-ir-after-all' if extension_config.CONFIG_TORCHSIM_DUMP_MLIR_IR else ''} \ {filename}.mlir -o {sample_filename}_llvm.mlir """, ).strip(), @@ -107,10 +132,18 @@ def mlir_gem5_compile_command(filename, sample_filename, tog_file, vectorlane_si ).strip(), re.sub(r"[ \n]+", " ", f""" - {extension_config.CONFIG_TORCHSIM_LLVM_PATH}/llc -relocation-model=pic -march=riscv64 -mattr=+m,+f,+d,+a,+c,+v,+xsfvcp,zvl{vlen}b -O2 {sample_filename}.ll -o {sample_filename}.s + {extension_config.CONFIG_TORCHSIM_LLVM_PATH}/llc \ + -relocation-model=pic -march=riscv64 -O3 --stack-size-section \ + -mattr=+m,+f,+d,+a,+c,+v,+xsfvcp,zvl{vlen}b \ + {'--print-after-all' if extension_config.CONFIG_TORCHSIM_DUMP_LLVM_IR else ''} \ + -O2 {sample_filename}.ll -o {sample_filename}.s """, ).strip()] +class SpadOverflowError(Exception): + def __init__(self, message="SPAD overflow occurred."): + super().__init__(message) + class MLIRCodeCache: cache = dict() clear = staticmethod(cache.clear) # Todo: Cache @@ -125,14 +158,16 @@ def load(cls, source_code, validation_binary_name="validation_bin", cycle_wrapper_name="cycle_wrapper", cycle_binary_name="cycle_bin", - arg_attributes=[], vectorlane_size=16, tile_size=[], - spad_info=None, origins=None, **kwargs): + arg_attributes=[], vectorlane_size=16, + spad_info=None, origins=None, silent_mode=False, **kwargs): + vlen = kwargs['vlen'] + vlenb = vlen // 8 write_path = get_write_path(source_code) key, input_path = write(source_code, "mlir", specified_dir=write_path) new_input_path = os.path.splitext(input_path)[0] raw_tog_path = new_input_path + "_tog.py" sample_mlir_path = new_input_path + "_sample" - gem5_cmds = mlir_gem5_compile_command(new_input_path, sample_mlir_path, raw_tog_path, vectorlane_size, tile_size) + gem5_cmds = mlir_gem5_compile_command(new_input_path, sample_mlir_path, raw_tog_path, vectorlane_size) from filelock import FileLock lock_dir = get_lock_dir() @@ -144,7 +179,9 @@ def load(cls, source_code, link_option = "" # Generate LLVM kernel calller and binary for validation if extension_config.CONFIG_TORCHSIM_VALIDATION_MODE: - cmds = mlir_compile_command(new_input_path, vectorlane_size, tile_size, vlen=256) + # Use custom malloc to avoid size error + new_link_option = link_option + " -Wl,--wrap=malloc -Wl,--wrap=free" + cmds = mlir_compile_command(new_input_path, vectorlane_size, vlen=vlen) opt_cmd = shlex.split(cmds[0]) translate_cmd = shlex.split(cmds[1]) llc_cmd = shlex.split(cmds[2]) @@ -161,7 +198,16 @@ def load(cls, source_code, val_llvm_caller = MLIRKernelCallerCodeGen(extension_config.CONFIG_TORCHSIM_VALIDATION_MODE, arg_attributes) val_llvm_caller.generate_wrapper_file(write_path, validation_wrapper_name) val_llvm_caller.compile_wih_kernel(write_path, key, validation_wrapper_name, - validation_binary_name, link_option) + validation_binary_name, new_link_option) + target = os.path.join(write_path, validation_binary_name) + stack_size = val_llvm_caller.parse_stack_sizes(f"{write_path}/{key}.s", vlenb=vlenb) + spad_size = val_llvm_caller.get_spad_size(target) + spad_usage = stack_size + spad_size # Spad usage per lane + if extension_config.CONFIG_SPAD_INFO["spad_size"] < spad_usage: + print(f"[Warning] Scratchpad size exceeded: required {spad_usage} bytes, " + f"but only {extension_config.CONFIG_SPAD_INFO['spad_size']} bytes available.") + raise SpadOverflowError() + # Launch tile graph generator gem5_sample_cmd = shlex.split(gem5_cmds[0]) gem5_translate_cmd = shlex.split(gem5_cmds[1]) @@ -180,33 +226,37 @@ def load(cls, source_code, print("Error output:", e.output) assert(0) - if extension_config.CONFIG_BACKENDSIM_SPIKE_ONLY: - return key + if extension_config.CONFIG_BACKENDSIM_SPIKE_ONLY: + return key - # Generate MLIR kernel calller and binary for cycle calculation - cycle_llvm_caller = MLIRKernelCallerCodeGen(False, arg_attributes, cycle_sim=True) - cycle_llvm_caller.generate_wrapper_file(write_path, cycle_wrapper_name) - cycle_llvm_caller.compile_wih_kernel(write_path, key + "_sample", cycle_wrapper_name, cycle_binary_name, link_option) - array_size = [] - for (arg_name, arg_attribute) in arg_attributes: - array_size.append(str(arg_attribute[2])) - - # Run cyclesim - cyclesim = CycleSimulator() - cycle_list = cyclesim.compile_and_simulate(os.path.join(write_path, cycle_binary_name), " ".join(array_size), vectorlane_size) - - # Create TOG - offset = vectorlane_size - if kwargs['loop_size'] is not None and kwargs['loop_size'][0] < vectorlane_size: - offset = kwargs['loop_size'][0] - tile_graph_generator = tog_generator(origins) - tile_graph_generator.load_file(raw_tog_path) - tile_graph_generator.generate_tile_graph( - os.path.join(write_path, "tile_graph.onnx"), - cycle_list=cycle_list, - offset=offset, # FIXME. - vector_lane=vectorlane_size - ) + # Generate MLIR kernel calller and binary for cycle calculation + cycle_llvm_caller = MLIRKernelCallerCodeGen(False, arg_attributes, cycle_sim=True) + cycle_llvm_caller.generate_wrapper_file(write_path, cycle_wrapper_name) + cycle_llvm_caller.compile_wih_kernel(write_path, key + "_sample", cycle_wrapper_name, cycle_binary_name, link_option) + array_size = [] + for (arg_name, arg_attribute) in arg_attributes: + array_size.append(str(arg_attribute[2])) + + # Run cyclesim + cyclesim = CycleSimulator() + cycle_list = cyclesim.compile_and_simulate(os.path.join(write_path, cycle_binary_name), " ".join(array_size), vectorlane_size, silent_mode=silent_mode) + + # Create TOG + w_offset, x_offset = vectorlane_size, vectorlane_size + if kwargs['loop_size'] is not None and kwargs['loop_size'][-3] < vectorlane_size: + x_offset = kwargs['loop_size'][-3] + if kwargs['loop_size'] is not None and kwargs['loop_size'][-1] < vectorlane_size: + w_offset = kwargs['loop_size'][-1] + w_offset = 0 # max(w_offset - x_offset, 0) + tile_graph_generator = tog_generator(origins) + tile_graph_generator.load_file(raw_tog_path) + tile_graph_generator.generate_tile_graph( + os.path.join(write_path, "tile_graph.onnx"), + cycle_list=cycle_list, + x_offset=x_offset, # FIXME. + w_offset=w_offset, # FIXME. + vector_lane=vectorlane_size + ) return key class LLVMCodeCache: @@ -285,46 +335,84 @@ def __init__(self): self.cycle_wrapper_name = "cycle_wrapper" self.cycle_binary_name = "cycle_binary" - def mlir(self, source_code, arg_attributes=[], vectorlane_size=16, tile_size=[], spad_info=None, origins=None, **kwargs): + def mlir(self, source_code, arg_attributes=[], vectorlane_size=16, tile_size=[], spad_info=None, origins=None, silent_mode=False, **kwargs): def task(): key = MLIRCodeCache.load(source_code, valdiation_wrapper_name=self.validation_binary_name, validation_binary_name=self.validation_binary_name, arg_attributes=arg_attributes, vectorlane_size=vectorlane_size, - tile_size=tile_size, spad_info=spad_info, origins=origins, **kwargs) + tile_size=tile_size, spad_info=spad_info, origins=origins, + silent_mode=silent_mode, **kwargs) return key future = self.submit(task) + if "loop_size" in kwargs: + loop_size = kwargs["loop_size"] + else: + loop_size = [] def dummy_simulator(*args, **kwargs): + validate = kwargs.get('validate', False) # Wait for compilation key = future.result() + from filelock import FileLock + lock_dir = get_lock_dir() + lock = FileLock(os.path.join(lock_dir, key + ".lock"), timeout=LOCK_TIMEOUT) + with lock: + # Run simulator pass + result_path = os.path.join(extension_config.CONFIG_TORCHSIM_DUMP_PATH, "tmp", hash_prefix(key)) + # Dump arguments and meta data + dump_metadata(args, arg_attributes, result_path) + runtime_path = FunctionalSimulator.get_runtime_dump_path(result_path) + if extension_config.CONFIG_TORCHSIM_VALIDATION_MODE or validate: + funcsim = FunctionalSimulator(result_path, key) + funcsim.run_spike(args, arg_attributes, + runtime_path, self.validation_binary_name, + vectorlane_size=vectorlane_size, spad_info=spad_info, + cleanup=extension_config.CONFIG_CLEANUP_DUMP_ARGS, silent_mode=silent_mode) + if extension_config.CONFIG_BACKENDSIM_SPIKE_ONLY: + return + + onnx_path = os.path.join(result_path, "tile_graph.onnx") + attribute_path = os.path.join(runtime_path, "attribute") + backend_path = os.path.join(extension_config.CONFIG_TORCHSIM_DIR, "PyTorchSimBackend") + backsim = BackendSimulator(backend_path, extension_config.CONFIG_TORCHSIM_BACKEND_CONFIG) + backsim.vectorlane_size = vectorlane_size + attribute_path = backsim.create_attribute_file(attribute_path, args, loop_size=loop_size) + result_path = backsim.simulation(onnx_path, attribute_path, silent_mode=silent_mode) + result = BackendSimulator.get_result_from_file(result_path) + return result - # Run simulator pass + def dryrun_simulator(*args, **kwargs): + autotune = kwargs.get('autotune', False) + key = future.result() + # Run simulator pass result_path = os.path.join(extension_config.CONFIG_TORCHSIM_DUMP_PATH, "tmp", hash_prefix(key)) # Dump arguments and meta data dump_metadata(args, arg_attributes, result_path) - if extension_config.CONFIG_TORCHSIM_VALIDATION_MODE: + runtime_path = FunctionalSimulator.get_runtime_dump_path(result_path) + if extension_config.CONFIG_BACKENDSIM_SPIKE_ONLY: + return + + if autotune: + onnx_path = os.path.join(result_path, "tile_graph.onnx") + attribute_path = os.path.join(runtime_path, "attribute") + backend_path = os.path.join(extension_config.CONFIG_TORCHSIM_DIR, "PyTorchSimBackend") + backsim = BackendSimulator(backend_path, extension_config.CONFIG_TORCHSIM_BACKEND_CONFIG) + backsim.vectorlane_size = vectorlane_size + attribute_path = backsim.create_attribute_file(attribute_path, args, loop_size=loop_size) + result_path_2 = backsim.simulation(onnx_path, attribute_path) + result = BackendSimulator.get_result_from_file(result_path_2) + return result_path, runtime_path, result + + # Todo. Support valude dependent mode for graph mode + if False: # extension_config.CONFIG_TORCHSIM_VALIDATION_MODE: funcsim = FunctionalSimulator(result_path, key) funcsim.run_spike(args, arg_attributes, - result_path, self.validation_binary_name, - kwargs['intermediate_op'] if 'intermediate_op' in kwargs else None, + runtime_path, self.validation_binary_name, vectorlane_size=vectorlane_size, spad_info=spad_info, cleanup=extension_config.CONFIG_CLEANUP_DUMP_ARGS) - if extension_config.CONFIG_BACKENDSIM_SPIKE_ONLY: - return - - onnx_path = os.path.join(result_path, "tile_graph.onnx") - attribute_path = os.path.join(extension_config.CONFIG_TORCHSIM_DUMP_PATH, "tmp", hash_prefix(key), "attribute") - backend_path = os.path.join(extension_config.CONFIG_TORCHSIM_DIR, "PyTorchSimBackend") - backsim = BackendSimulator(backend_path, extension_config.CONFIG_TORCHSIM_BACKEND_CONFIG) - attribute_path = backsim.create_attribute_file(attribute_path, args, tile_size=tile_size) - result_path = backsim.simulation(onnx_path, attribute_path) - result = BackendSimulator.get_result_from_file(result_path) - return result - - def dryrun_simulator(*args, **kwargs): - key = future.result() + return result_path, runtime_path, None - is_dryrun = extension_config.CONFIG_BACKENDSIM_DRYRUN + is_dryrun = int(os.environ.get('BACKENDSIM_DRYRUN', default=False)) target_simulator = dryrun_simulator if is_dryrun else dummy_simulator target_simulator.arg_attributes = arg_attributes target_simulator.future = future diff --git a/PyTorchSimFrontend/extension_config.py b/PyTorchSimFrontend/extension_config.py index 6f5b3e00..59f3818c 100644 --- a/PyTorchSimFrontend/extension_config.py +++ b/PyTorchSimFrontend/extension_config.py @@ -1,20 +1,21 @@ import os +import sys import tempfile +import importlib # Hardware info config -CONFIG_VECTOR_LANE = 128 +CONFIG_VECTOR_LANE = int(os.environ.get("TORCHSIM_VECTOR_LANE", default=128)) +CONFIG_VECTOR_LANE_STRIDE = int(os.environ.get("TORCHSIM_VECTOR_LANE_STRIDE", default=2)) CONFIG_SPAD_INFO = { "spad_vaddr" : 0xD0000000, - "spad_paddr" : 0xD0000000, - "spad_size" : 128 << 10 + "spad_paddr" : 0x2000000000, + "spad_size" : int(os.environ.get("TORCHSIM_SPAD_SIZE", default=128)) << 10 # Note: spad size per lane } CONFIG_PRECISION = 4 # 32bit CONFIG_NUM_CORES = 1 -CONFIG_VLEN = 32 // CONFIG_PRECISION # 256bits / 32bits = 8 [elements] +CONFIG_VLEN = 256 # 256bits / 32bits = 8 [elements] # Tile size config -CONFIG_TILE_ROW = int(os.environ.get("TORCHSIM_TILE_ROW", default=-1)) -CONFIG_TILE_COL = int(os.environ.get("TORCHSIM_TILE_COL", default=-1)) CONFIG_TORCHSIM_DIR = os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim') # DUMP PATH @@ -30,16 +31,69 @@ CONFIG_TORCHSIM_LLVM_PATH = os.environ.get('TORCHSIM_LLVM_PATH', default="/usr/bin") CONFIG_TORCHSIM_CUSTOM_PASS_PATH = os.environ.get('TORCHSIM_CUSTOM_PASS_PATH', default=f"{CONFIG_TORCHSIM_DIR}/GemminiLowerPass/build") +CONFIG_TORCHSIM_DUMP_MLIR_IR = int(os.environ.get("TORCHSIM_DUMP_MLIR_IR", default=False)) +CONFIG_TORCHSIM_DUMP_LLVM_IR = int(os.environ.get("TORCHSIM_DUMP_LLVM_IR", default=False)) # Backendsim config CONFIG_TORCHSIM_BACKEND_CONFIG = os.environ.get('TORCHSIM_CONFIG', - default=f'{CONFIG_TORCHSIM_DIR}/PyTorchSimBackend/configs/systolic_ws_128x128_c2_simple_noc_tpuv2.json') -CONFIG_BACKENDSIM_SPIKE_ONLY = bool(os.environ.get("BACKENDSIM_SPIKE_ONLY", False)) -CONFIG_BACKENDSIM_EAGER_MODE = bool(os.environ.get("BACKENDSIM_EAGER_MODE", default=False)) -CONFIG_BACKENDSIM_DRYRUN = bool(int(os.environ.get('BACKENDSIM_DRYRUN', default=0))) + default=f'{CONFIG_TORCHSIM_DIR}/PyTorchSimBackend/configs/systolic_ws_128x128_c1_simple_noc_tpuv3.json') +CONFIG_BACKENDSIM_SPIKE_ONLY = int(os.environ.get("BACKENDSIM_SPIKE_ONLY", False)) +CONFIG_BACKENDSIM_EAGER_MODE = int(os.environ.get("BACKENDSIM_EAGER_MODE", default=False)) +CONFIG_BACKENDSIM_DRYRUN = int(os.environ.get('BACKENDSIM_DRYRUN', default=False)) CONFIG_BACKENDSIM_DEBUG_LEVEL = os.environ.get("BACKENDSIM_DEBUG_LEVEL", "") # GEM5 config CONFIG_GEM5_PATH = os.environ.get('GEM5_PATH', default="/workspace/gem5/build/RISCV/gem5.opt") CONFIG_GEM5_SCRIPT_PATH = os.environ.get('GEM5_SCRIPT_PATH', - default=f"{CONFIG_TORCHSIM_DIR}/gem5_script/script_systolic.py") \ No newline at end of file + default=f"{CONFIG_TORCHSIM_DIR}/gem5_script/script_systolic.py") + +# AUTOTUNE config +CONFIG_AUTOTUNE = int(os.environ.get('AUTOTUNE', default=True)) +CONFIG_MAX_AUTOTUNE_TRY = int(os.environ.get('MAX_AUTOTUNE_TRY', default=10)) + +# For block sparse +CONFIG_BLOCK_SPARSE = int(os.environ.get('BLOCK_SPARSE', default=0)) + +# For GEMM tile size +CONFIG_MANUAL_TILE_SIZE = int(os.environ.get('TORCHSIM_MANUAL_TILE_SIZE', default=False)) +CONFIG_TILE_M = int(os.environ.get('TORCHSIM_TILE_M', default=CONFIG_VECTOR_LANE)) +CONFIG_TILE_N = int(os.environ.get('TORCHSIM_TILE_N', default=CONFIG_VECTOR_LANE)) +CONFIG_TILE_K = int(os.environ.get('TORCHSIM_TILE_K', default=CONFIG_VECTOR_LANE)) +CONFIG_GEMM_CHEATSHEET_PATH = os.environ.get('TORCHSIM_GEMM_CHEATSHEET_PATH', + default=f"{CONFIG_TORCHSIM_DIR}/validation/gemm_tpuv3_cheatsheet.json") +CONFIG_SUBTILE = int(os.environ.get('TORCHSIM_SUBTILE', default=True)) +CONFIG_MANUAL_SUBTILE_SIZE = int(os.environ.get('TORCHSIM_MANUAL_SUBTILE_SIZE', default=False)) +CONFIG_SUBTILE_M = int(os.environ.get('TORCHSIM_SUBTILE_M', default=CONFIG_VECTOR_LANE)) +CONFIG_SUBTILE_N = int(os.environ.get('TORCHSIM_SUBTILE_N', default=CONFIG_VECTOR_LANE)) +CONFIG_SUBTILE_K = int(os.environ.get('TORCHSIM_SUBTILE_K', default=CONFIG_VECTOR_LANE)) + +# Advanced fusion options +CONFIG_FUSION_REDUCTION_EPILOGUE = int(os.environ.get('TORCHSIM_FUSION_REDUCTION_EPILOGUE', default=True)) +CONFIG_FUSION_REDUCTION_REDUCTION = int(os.environ.get('TORCHSIM_FUSION_REDUCTION_REDUCTION', default=True)) +CONFIG_FUSION_PROLOGUE = int(os.environ.get('TORCHSIM_FUSION_PROLOGUE', default=True)) + +# SRAM Buffer allocation plan +def load_plan_from_module(module_path): + if module_path is None: + return None + + try: + spec = importlib.util.spec_from_file_location("plan_module", module_path) + if spec is None: + return None + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + if hasattr(module, 'plan'): + return module.plan + return None + except Exception as e: + print(f"[Warning] Failed to load SRAM buffer plan from module: {e}") + return None + +CONFIG_SRAM_BUFFER_PLAN_PATH = os.environ.get("SRAM_BUFFER_PLAN_PATH", default=None) +CONFIG_SRAM_BUFFER_PLAN = load_plan_from_module(CONFIG_SRAM_BUFFER_PLAN_PATH) + +# For ILS experiment +CONFIG_TLS_MODE = int(os.environ.get('TORCHSIM_TLS_MODE', default=1)) + +CONFIG_USE_TIMING_POOLING = int(os.environ.get('TORCHSIM_USE_TIMING_POOLING', default=0)) diff --git a/PyTorchSimFrontend/extension_device.cpp b/PyTorchSimFrontend/extension_device.cpp index 6bceb8ae..4d33db08 100644 --- a/PyTorchSimFrontend/extension_device.cpp +++ b/PyTorchSimFrontend/extension_device.cpp @@ -307,8 +307,6 @@ TORCH_LIBRARY_IMPL(aten, PrivateUse1, m) { m.impl("_foreach_addcdiv_.ScalarList", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); m.impl("_foreach_add_.List", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); m.impl("cat.out", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); - // m.impl("addmm.out", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); // TODO: only for optimizer test - // m.impl("mm.out", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); // TODO: only for optimizer test } // This basic implementation doesn't bother dealing with different device indices diff --git a/PyTorchSimFrontend/extension_op.py b/PyTorchSimFrontend/extension_op.py new file mode 100644 index 00000000..22a727c5 --- /dev/null +++ b/PyTorchSimFrontend/extension_op.py @@ -0,0 +1,303 @@ +import os +import subprocess +import math +import struct +from datetime import datetime +import random +import torch +import numpy as np +import hashlib +from torch._inductor.select_algorithm import ExternKernelChoice +from torch._inductor.codecache import get_hash +from AsmParser.tog_generator import tog_generator +from torch._inductor.codecache import write +from PyTorchSimFrontend.extension_codecache import get_write_path +from PyTorchSimFrontend import extension_config +from Simulator.simulator import BackendSimulator, TORCH_TO_NUMPY + +graph_template = { + 0: { + "node_id": 0, + "node_name": "root", + "node_type": 0, + "parents": [], + "children": [1] + }, + 1: { + "node_id": 1, + "node_name": "loopNode", + "node_type": 2, + "parents": [0], + "children": [2], + "loop_index": "loop_arg000", + "loop_start": 0, + "loop_end": 4, # FIXME. this is a trick that generate multiple tile. + "loop_step": 1, + "loop_type": "outer_loop" + }, + 2: { + "node_id": 2, + "node_name": "stonneNode", + "node_type": 5, + "parents": [1], + "children": [], + } +} + +class MLIRExternKernelChoice(ExternKernelChoice): + def call_name(self): + is_dryrun = int(os.environ.get('BACKENDSIM_DRYRUN', default=False)) + if is_dryrun: + return f"yield from sparse_mm_dummy_stonne_outer" + return f"torch.ops.extension_op.{self.name}" + +custom_lib = torch.library.Library("extension_op", "DEF") + +def calculate_sparsity(tensor): + total_elements = tensor.numel() + zero_elements = torch.sum(tensor.cpu() == 0) + sparsity_ratio = zero_elements / total_elements * 100 + return math.ceil(sparsity_ratio.item()) + +def generate_outer_product_matrix(a, b, M, K, N, prefix, dir_path): + # Generating matrix A + data_width = 4 + a_cpu = a.cpu() + b_cpu = b.cpu() + value_pointer = os.path.join(dir_path, f'{prefix}_outerproduct_gemm_mem.ini') + rowA_pointer = os.path.join(dir_path, f'{prefix}_outerproduct_gemm_rowpointerA.in') + colA_pointer = os.path.join(dir_path, f'{prefix}_outerproduct_gemm_colpointerA.in') + rowB_pointer = os.path.join(dir_path, f'{prefix}_outerproduct_gemm_rowpointerB.in') + colB_pointer = os.path.join(dir_path, f'{prefix}_outerproduct_gemm_colpointerB.in') + + with open(value_pointer, "w") as fd, open(rowA_pointer, "w") as rpA, open(colA_pointer, "w") as cpA, open(rowB_pointer, "w") as rpB, open(colB_pointer, "w") as cpB: + #generating matrixA + n_nonzeros=0 + for k in range(K): # col major + initial_values=0 + rpA.write(str(n_nonzeros)+","); # writing the index of A + for m in range(M): + if(a_cpu[m, k]): # value is nonzero + if((m==(M-1)) and (k==(K-1))): + cpA.write(str(m)) + else: + cpA.write(str(m)+","); #writing the row index + initial_values+=1 + value = a_cpu[m, k] + ba = bytearray(struct.pack(">f", value)) # generating list of bytes + my_int = int.from_bytes(ba, "big") + fd.write(str(my_int)) + fd.write(",") + n_nonzeros+=1 + rpA.write(str(n_nonzeros)) + address_matrix_b=n_nonzeros*data_width + #Generating matrix B + n_nonzeros=0 + for k in range(0,K): # Row major + initial_values=0 + rpB.write(str(n_nonzeros)+","); # writing the index of A + for n in range(0,N): + if(b_cpu[k, n]): # value is nonzero + if((k==(K-1)) and (n==(N-1))): + cpB.write(str(n)) + else: + cpB.write(str(n)+","); #writing the row index + + initial_values+=1 + value = b_cpu[k, n] + ba = bytearray(struct.pack(">f", value)) # generating list of bytes + my_int = int.from_bytes(ba, "big") + fd.write(str(my_int)) + fd.write(",") + n_nonzeros+=1 + + rpB.write(str(n_nonzeros)) + fd.write(str(0)) # Adding a final 0 to the memory which will never be used. This is just to avoid having a last comma. + address_matrix_c=address_matrix_b+(n_nonzeros*data_width) + return 0, address_matrix_b, address_matrix_c + +def generate_inner_product_matrix(a, b, M, K, N, file_name, in_file_bitmap_a, in_file_bitmap_b): + data_width = 4 + a_cpu = a.cpu() + b_cpu = b.cpu() + matrixA_size=int(M*K) + matrixB_size=int(N*K) + matrixC_size=int(M*N) + + random.seed(a=0, version=2) + + address_matrix_a = 0 + with open(file_name, "w") as fd, open(in_file_bitmap_a, "w") as fbA, open(in_file_bitmap_b, "w") as fbB: + #generating matrixA + n_nonzeros=0 + for m in range(M): # Row major + for k in range(K): + is_sparse = a_cpu[m,k] + if(torch.isclose(is_sparse, torch.zeros(1), atol=1e-1)): + if((m==(M-1)) and (k==(K-1))): + fbA.write(str(1)) + else: + fbA.write(str(1)+","); #writing a 1 in bitmap + ba = bytearray(struct.pack(">f", is_sparse)) # generating list of bytes + my_int = int.from_bytes(ba, "big") + fd.write(str(my_int)) + fd.write(",") + n_nonzeros+=1 + else: + if((m==(M-1)) and (k==(K-1))): # this is to insert a comma + fbA.write(str(0)) + # note no data element is inserted in this case + else: + # note no data element is inserted in this case + fbA.write(str(0)+",") + + address_matrix_b=n_nonzeros*data_width + #Generating matrix B + n_nonzeros=0 + bitmapB=list(range(0,matrixB_size)) + for n in range(0,N): # Row major + for k in range(0,K): + is_sparse = b_cpu[k,n] + if(torch.isclose(is_sparse, torch.zeros(1), atol=1e-1)): # value is generated + bitmapB[k*N+n]=1 + ba = bytearray(struct.pack(">f", float(is_sparse))) # generating list of bytes + my_int = int.from_bytes(ba, "big") + fd.write(str(my_int)) + fd.write(",") + n_nonzeros+=1 + else: + # no data element is inserted in this case + bitmapB[k*N+n]=0; #writing a 0 + # writing the bitmapB in the appropiate order + for i in range(0, matrixB_size): + fbB.write(str(bitmapB[i])) + if(i < (matrixB_size-1)): + fbB.write(",") + + fd.write(str(0)) # Adding a final 0 to the memory which will never be used. This is just to avoid having a last comma. + address_matrix_c=address_matrix_b+(n_nonzeros*data_width) + print("Offset matrix A: "+str(address_matrix_a)) + print("Offset matrix B: "+str(address_matrix_b)) + print("Offset matrix C: "+str(address_matrix_c)) + return address_matrix_a, matrixA_size, matrixA_size+matrixB_size + +def prepare_outer_product_matrix(a, b, out): + M, K, N = a.shape[0], b.shape[0], b.shape[1] + + prefix = datetime.now().strftime("%m%d%H%M%S%f") + w_sparsity = calculate_sparsity(a) + x_sparsity = calculate_sparsity(b) + print(f"A Sparsity: {w_sparsity}") + print(f"B Sparsity: {x_sparsity}") + assert(x_sparsity >= 0 and x_sparsity < 100) + assert(w_sparsity >= 0 and w_sparsity < 100) + + graph = dict(graph_template) + meta_data = { + # Operation Type + "stonne_operation": "outerProductGEMM", + + # GEMM Parameters + "stonne_GEMM_K": K, + "stonne_GEMM_N": N, + "stonne_GEMM_M": M, + "a_hash" : hashlib.sha512(a.cpu().numpy().tobytes()).hexdigest(), + "b_hash" : hashlib.sha512(b.cpu().numpy().tobytes()).hexdigest(), + } + graph[2].update(meta_data) + + # Create write path + write_path = get_write_path(str(graph)) + os.makedirs(write_path, exist_ok=True) + + # Generating inputs + mem_init = os.path.join(write_path, f'{prefix}_outerproduct_gemm_mem.ini') + a_row_init = os.path.join(write_path, f'{prefix}_outerproduct_gemm_rowpointerA.in') + a_col_init = os.path.join(write_path, f'{prefix}_outerproduct_gemm_colpointerA.in') + b_row_init = os.path.join(write_path, f'{prefix}_outerproduct_gemm_rowpointerB.in') + b_col_init = os.path.join(write_path, f'{prefix}_outerproduct_gemm_colpointerB.in') + c_result = os.path.join(write_path, f'{prefix}_result.out') + trace_path = os.path.join(write_path, "trace.py") + + if not os.path.isfile(trace_path): + dram_a_address, dram_b_address, dram_c_address = generate_outer_product_matrix(a, b, M, K, N, prefix, write_path) + meta_data = { + # Memory Initialization & File Paths + "stonne_mem_init": mem_init, + "stonne_mem_matrix_c_file_name": c_result, + + # Memory Addresses + "stonne_matrix_a_dram_address": dram_a_address, + "stonne_matrix_b_dram_address": dram_b_address, + "stonne_matrix_c_dram_address": dram_c_address, + + # CSR & Bitmap Initialization + "stonne_rowpointer_matrix_a_init": a_row_init, + "stonne_colpointer_matrix_a_init": a_col_init, + "stonne_rowpointer_matrix_b_init": b_row_init, + "stonne_colpointer_matrix_b_init": b_col_init, + "stonne_trace_path": trace_path + } + graph[2].update(meta_data) + + source_code = "graph = " + str(graph) + key, raw_tog_path = write(source_code, "py", specified_dir=write_path) + tile_graph_generator = tog_generator(["flexagon_matmul"]) + tile_graph_generator.load_file(raw_tog_path) + tile_graph_generator.generate_tile_graph( + os.path.join(write_path, "tile_graph.onnx"), + cycle_list=[0], + x_offset=0, + w_offset=0, + vector_lane=0, + stonneGraph=True + ) + onnx_path = os.path.join(write_path, "tile_graph.onnx") + attribute_path = os.path.join(write_path, "attributes") + return onnx_path, attribute_path, c_result + else: # Use trace file to generate onnx graph + tile_graph_generator = tog_generator(["flexagon_matmul"]) + tile_graph_generator.load_file(trace_path) + tile_graph_generator.generate_tile_graph( + os.path.join(write_path, "trace_tile_graph.onnx"), + cycle_list=[0], + x_offset=0, + w_offset=0, + vector_lane=0, + stonneGraph=True + ) + onnx_path = os.path.join(write_path, "trace_tile_graph.onnx") + attribute_path = os.path.join(write_path, "attributes") + return onnx_path, attribute_path, c_result + + + +def sparse_mm_stonne_outer(a, b, out): + onnx_path, attribute_path, c_result_path = prepare_outer_product_matrix(a, b, out) + + backend_path = os.path.join(extension_config.CONFIG_TORCHSIM_DIR, "PyTorchSimBackend") + stonne_config_path = f'{extension_config.CONFIG_TORCHSIM_DIR}/PyTorchSimBackend/configs/stonne_single_c1_simple_noc.json' + backsim = BackendSimulator(backend_path, stonne_config_path) + result_path = backsim.simulation(onnx_path) + BackendSimulator.get_result_from_file(result_path) + + # Load result data + #with open(c_result_path, 'rb') as f: + # np_array = np.fromfile(f, dtype=TORCH_TO_NUMPY[out.dtype]) + # src_tensor = torch.as_strided(torch.from_numpy(np_array), out.size(), out.stride()) + # out.copy_(src_tensor.to(dtype=out.dtype)) + +def sparse_mm_dummy_stonne_outer(a, b, out): + onnx_path, attribute_path, c_result_path = prepare_outer_product_matrix(a, b, out) + out.copy_(torch.matmul(a.cpu(), b.cpu())) + yield (onnx_path, attribute_path) + + # Load result data + # with open(c_result_path, 'rb') as f: + # np_array = np.fromfile(f, dtype=TORCH_TO_NUMPY[out.dtype]) + # src_tensor = torch.as_strided(torch.from_numpy(np_array), out.size(), out.stride()) + # out.copy_(src_tensor.to(dtype=out.dtype)) + +custom_lib.define("_sparse_mm(Tensor a, Tensor b, Tensor out) -> Tensor") +custom_lib.impl("_sparse_mm", sparse_mm_stonne_outer, "PrivateUse1") +custom_lib.impl("_sparse_mm", sparse_mm_stonne_outer, "AutogradPrivateUse1") \ No newline at end of file diff --git a/PyTorchSimFrontend/llvm/llvm_caller_codegen.py b/PyTorchSimFrontend/llvm/llvm_caller_codegen.py index 4ef03d3f..3690f533 100644 --- a/PyTorchSimFrontend/llvm/llvm_caller_codegen.py +++ b/PyTorchSimFrontend/llvm/llvm_caller_codegen.py @@ -1,6 +1,7 @@ import os import subprocess import shlex +import re from torch._inductor.utils import IndentedBuffer from torch._inductor.codegen import cpp @@ -83,7 +84,7 @@ def generate_args_define(self): self.writeline(self.newline) def generate_load_dump_fn(self): - self.writeline(f'{self.newline}int load_arg(void *arg, int size, const char *path) {self.open_bracket}') + self.writeline(f'{self.newline}int load_arg(void *arg, size_t size, const char *path) {self.open_bracket}') with self.code.indent(): self.writeline(f'int fd = open(path, 0x00000000){self.ending}') self.writeline(f'if (fd == -1) {self.open_bracket}') @@ -99,7 +100,7 @@ def generate_load_dump_fn(self): self.writeline(f'return 0{self.ending}') self.writeline(self.closed_bracket) - self.writeline(f'{self.newline}int dump_arg(void *arg, int size, const char *path) {self.open_bracket}') + self.writeline(f'{self.newline}int dump_arg(void *arg, size_t size, const char *path) {self.open_bracket}') with self.code.indent(): self.writeline(f'int fd = open(path, 0x00000001 | 0x00000040, 0644){self.ending}') self.writeline(f'if (fd == -1) {self.open_bracket}') @@ -174,3 +175,62 @@ def compile_wih_kernel(self, write_path, llvm_name, wrapper_name, binary_name, l print("Command failed with exit code", e.returncode) print("Error output:", e.output) assert(0) + + def parse_stack_sizes(self, file_path, vlenb=256): + with open(file_path, 'r') as f: + stack_sizes_data = f.readlines() + + in_proc = False + stack_base = None + dynamic_expr = None + max_offset = 0 + + for line in stack_sizes_data: + line = line.strip() + if line.startswith(".cfi_startproc"): + in_proc = True + continue + elif line.startswith(".cfi_endproc") and in_proc: + if dynamic_expr: + total_stack = eval(dynamic_expr, {"vlenb": vlenb}) + return total_stack + elif stack_base: + return stack_base + else: + return max_offset + + # Skip outer function + if not in_proc: + continue + + if line.startswith(".cfi_def_cfa_offset"): + stack_base = int(line.split()[-1]) + + if ".cfi_escape" in line and "#" in line: + comment = line.split("#")[-1].strip() + m = re.search(r"sp \+ (\d+)\s*\+\s*(\d+)\s*\*\s*vlenb", comment) + if m: + base, scale = int(m.group(1)), int(m.group(2)) + dynamic_expr = f"{base} + {scale} * vlenb" + + def get_spad_size(self, binary_path): + cmd = ["riscv64-unknown-elf-readelf", "-s", binary_path] + result = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True) + if result.returncode != 0: + raise RuntimeError(f"Readelf error: {result.stderr}") + + output = result.stdout + spad_start = None + spad_end = None + for line in output.splitlines(): + if '.spad' in line and 'SECTION' in line: + parts = line.split() + spad_start = int(parts[1], 16) + elif 'spad_end' in line: + parts = line.split() + spad_end = int(parts[1], 16) + + if spad_start is None or spad_end is None: + return 0 + spad_size = spad_end - spad_start + return spad_size \ No newline at end of file diff --git a/PyTorchSimFrontend/mlir/mlir_autotune.py b/PyTorchSimFrontend/mlir/mlir_autotune.py index cea9834b..af101f44 100644 --- a/PyTorchSimFrontend/mlir/mlir_autotune.py +++ b/PyTorchSimFrontend/mlir/mlir_autotune.py @@ -1,8 +1,8 @@ import functools import torch +import dataclasses from torch._inductor.autotune_process import BenchmarkRequest from torch._inductor.autotune_process import TensorMeta -from torch._inductor.codecache import CUDACodeCache from typing import ( Any, @@ -15,8 +15,8 @@ TYPE_CHECKING, Union, ) - -class MLIRBenchmarkRequest(BenchmarkRequest): +@dataclasses.dataclass +class MLIRBenchmarkRequest(): def __init__( self, kernel_name: str, @@ -25,50 +25,41 @@ def __init__( extra_args: Iterable[Any], source_code: str, ): - super().__init__(kernel_name, input_tensor_meta, output_tensor_meta, extra_args) + self.kernel_name = kernel_name + if isinstance(input_tensor_meta, TensorMeta): + input_tensor_meta = [input_tensor_meta] + self.input_tensor_meta = input_tensor_meta + + if isinstance(output_tensor_meta, TensorMeta): + output_tensor_meta = [output_tensor_meta] + self.output_tensor_meta = output_tensor_meta self.source_code = source_code self.workspace_size: int = 0 self.workspace: Optional[torch.Tensor] = None self.hash_key: str = "" self.source_file: str = "" + self.extra_args = extra_args #self.hash_key, self.source_file = CUDACodeCache.write(self.source_code, "so") def make_run_fn( - self, *input_tensors: torch.Tensor, output_tensor: torch.Tensor + self, input_tensors: torch.Tensor, output_tensors: torch.Tensor ) -> Callable[[], None]: - self.DLL, self.hash_key, self.source_file = CUDACodeCache.load( - self.source_code, "so" - ) + from PyTorchSimFrontend.extension_codecache import CustomAsyncCompile + custom_async_compile = CustomAsyncCompile() + run_method = custom_async_compile.mlir( + self.source_code, vectorlane_size=self.extra_args["vector_lane"], + loop_size=None, spad_info=self.extra_args["spad_info"], + vlen=self.extra_args["vlen"], arg_attributes=self.extra_args["arg_attributes"], + origins="Unknown", silent_mode=True) args = [ - tensor.data_ptr() - for tensor in list(input_tensors) + [output_tensor] + tensor + for tensor in list(input_tensors) + list(output_tensors) ] - - print( - "make_run_fn: self.kernel_name=%s, self.source_file=%s, self.hash_key=%s, args=%s, self.extra_args=%s", - self.kernel_name, - self.source_file, - self.hash_key, - args, - self.extra_args, - ) - - run_method = getattr(self.DLL, self.kernel_name) - - # Retrieve workspace_size and initialize workspace. - run_method( - *args, # input ptrs and output ptrs - *self.extra_args, - ) - # Generate partial function. return functools.partial( run_method, *args, - *self.extra_args, - None, # null workspace size ptr - None, # set workspace ptr, TODO: update it to a real ptr if workspace_size > 0 ) def __str__(self) -> str: diff --git a/PyTorchSimFrontend/mlir/mlir_bmm_template.py b/PyTorchSimFrontend/mlir/mlir_bmm_template.py index cd99d52e..9a9785e1 100644 --- a/PyTorchSimFrontend/mlir/mlir_bmm_template.py +++ b/PyTorchSimFrontend/mlir/mlir_bmm_template.py @@ -1,69 +1,153 @@ import os -from typing import List, Optional, cast +from torch import empty_strided +from typing import List, Optional +import sympy from PyTorchSimFrontend.mlir.mlir_template import MLIRTemplate from PyTorchSimFrontend.mlir.mlir_template import MLIRTemplateKernel -from torch._inductor.ir import Buffer from torch._inductor.ir import IRNode -from torch._inductor.ir import ReinterpretView from torch._inductor.codecache import write_atomic import PyTorchSimFrontend.extension_codecache as extension_codecache +from PyTorchSimFrontend.mlir import mlir_common BMM_TEMPLATE = r""" -{% if X_transposed %}#map0 = affine_map<(d0, d1, d2) -> (d0 * {{ K * M }} + d2 * {{ M }} + d1)>{% else %}#map0 = affine_map<(d0, d1, d2) -> (d0 * {{ M * K }} + d1 * {{ K }} + d2)>{% endif %} -{% if W_transposed %}#map1 = affine_map<(d0, d1, d2) -> (d0 * {{ N * K }} + d2 * {{ K }} + d1)>{% else %}#map1 = affine_map<(d0, d1, d2) -> (d0 * {{ K * N }} + d1 * {{ N }} + d2)>{% endif %} -#map2 = affine_map<(d0, d1, d2) -> (d0 * {{ M * N }} + d1 * {{ N }} + d2)> -memref.global @X_spad : memref<{{ TILE_M }}x{{ TILE_K }}xf32, 1> -memref.global @W_spad : memref<{{ TILE_K }}x{{ TILE_N }}xf32, 1> -memref.global @Y_spad : memref<{{ TILE_M }}x{{ TILE_N }}xf32, 1> +// BMM kernel +// BATCH = {{ B }} +// M = {{ M }} +// N = {{ N }} +// K = {{ K }} +// TILE_M = {{ TILE_M }} +// TILE_N = {{ TILE_N }} +// TILE_K = {{ TILE_K }} +// SUB_TILE_M = {{ SUB_TILE_M }} +// SUB_TILE_N = {{ SUB_TILE_N }} +{{kernel.def_global_vars()}} func.func @{{ KERNEL_NAME }}{{kernel.def_kernel(inputs=[X, W, Bias], outputs=[Y], names_str="X, W, Bias, Y", input_reorder=input_reorder)}} { - %c_mvin = arith.constant 2 : index - %c_mvin2 = arith.constant 1 : index{% if Bias %} - %c_mvin3 = arith.constant 14 : index{% endif %} - %c_mvout = arith.constant 3 : index - %c_set = arith.constant 2 : index - %c{{ TILE_K * 2 + 0}} = arith.constant {{ TILE_K * 2 + 0}} : index - %c0 = arith.constant 0 : index{% if X_transposed %} - %x_chunk = arith.constant {{ kernel.vector_lane * 2 + 0 }} : index{% endif %}{% if W_transposed %} - %w_chunk = arith.constant {{ TILE_K * 2 + 0 }} : index{% endif %} - %M = arith.constant {{ M }} : index - %N = arith.constant {{ N }} : index - %K = arith.constant {{ K }} : index - %X_buffer = memref.get_global @X_spad : memref<{{ TILE_M }}x{{ TILE_K }}xf32, 1> - %W_buffer = memref.get_global @W_spad : memref<{{ TILE_K }}x{{ TILE_N }}xf32, 1> - %Y_buffer = memref.get_global @Y_spad : memref<{{ TILE_M }}x{{ TILE_N }}xf32, 1> - %tag = memref.alloc() : memref<1xi32> - %v0 = arith.constant dense<0.0> : vector<{{ TILE_M * TILE_N // kernel.vector_lane }}xf32> - - affine.for %b=0 to {{ B }} { - affine.for %t_m = 0 to {{ M }} step {{ TILE_M }} { - affine.for %t_n = 0 to {{ N }} step {{ TILE_N }} { - %index2 = affine.apply #map2(%b, %t_m, %t_n){% if Bias %} - affine.dma_start %Bias[ - {%- if Bias_rank == 2 -%} %index2 {%- else -%} %t_n {%- endif -%} - ], %Y_buffer[0, 0], %tag[0], %c_mvin3, - %{%- if Bias_rank == 2 -%} N {%- else -%} c0 {%- endif -%} - , %c_set : memref< - {%- if Bias_rank == 2 -%} {{ M * N }} {%- else -%} {{ N }} {%- endif -%} - xf32>, memref<{{ TILE_M }}x{{ TILE_N }}xf32, 1>, memref<1xi32> - {%- else %} - affine.vector_store %v0, %Y_buffer[0, 0] : memref<{{ TILE_M }}x{{ TILE_N }}xf32, 1>, vector<{{ TILE_M * TILE_N // kernel.vector_lane }}xf32>{% endif %} - affine.for %t_k = 0 to {{ K }} step {{ TILE_K }} { - %index0 = affine.apply #map0(%b, %t_m, %t_k) - %index1 = affine.apply #map1(%b, %t_k, %t_n) - affine.dma_start %X[%index0], %X_buffer[%c0, %c0], %tag[0], %c_mvin, - {%- if X_transposed -%} %M, %x_chunk {%- else -%} %K, %c_set {%- endif -%} - : memref<{{ B * M * K }}xf32>, memref<{{ TILE_M }}x{{ TILE_K }}xf32, 1>, memref<1xi32> { subtile_size=[{{ kernel.vector_lane }}, {{ TILE_K }}], async=1{% if X_transposed %}, transpose=1{% endif %} } - affine.dma_start %W[%index1], %W_buffer[%c0, %c0], %tag[0], %c_mvin2, - {%- if W_transposed -%} %K, %w_chunk {%- else -%} %N, %c_set {%- endif -%} - : memref<{{ B * K * N }}xf32>, memref<{{ TILE_K }}x{{ TILE_N }}xf32, 1>, memref<1xi32> { subtile_size=[{{ TILE_K }}, {{ kernel.vector_lane }}], async=1{% if W_transposed %}, transpose=1{% endif %} } - linalg.matmul ins(%X_buffer, %W_buffer : memref<{{ TILE_M }}x{{ TILE_K }}x{{ DATA_STYPE }}, 1>, memref<{{ TILE_K }}x{{ TILE_N }}x{{ DATA_STYPE }}, 1>) - outs(%Y_buffer : memref<{{ TILE_M }}x{{ TILE_N }}x{{ DATA_STYPE }}, 1>) - } { accumulation_loop=true } - affine.dma_start %Y_buffer[%c0, %c0], %Y[%index2], %tag[0], %c_mvout, %N, %c_set : memref<{{ TILE_M }}x{{ TILE_N }}xf32, 1>, memref<{{ B * M * N }}xf32>, memref<1xi32> { async=1 } - } { outer_loop=true } - } { outer_loop=true } + {{ kernel.def_sram_buffer("X", X_tile_desc, indent_size=2) }} + {{ kernel.def_sram_buffer("W", W_tile_desc, indent_size=2) }} + {{ kernel.def_sram_buffer("Y", Y_tile_desc, indent_size=2) }} + {% if not Bias %} + %v0 = arith.constant dense<0.0> : vector<{{ kernel.get_spad_size_per_lane(TILE_M, TILE_N) }}xf32> + {% endif %} + %c0 = arith.constant 0 : index + {{ kernel.def_local_vars(indent_size=2) }} + affine.for %index0 = 0 to {{ B }} { + affine.for %index1 = 0 to {{ M }} step {{ TILE_M }} { + affine.for %index2 = 0 to {{ N }} step {{ TILE_N }} { + %X_buffer2D = memref.reinterpret_cast %X_buffer to offset: [0], sizes: [{{ TILE_M }}, {{ TILE_K }}], strides: [{{ TILE_K }}, 1] : {{ X_tile_desc.get_mlir_shape(DATA_STYPE) }} to memref<{{ TILE_M }}x{{ TILE_K }}xf32, 1> + %W_buffer2D = memref.reinterpret_cast %W_buffer to offset: [0], sizes: [{{ TILE_K }}, {{ TILE_N }}], strides: [{{ TILE_N }}, 1] : {{ W_tile_desc.get_mlir_shape(DATA_STYPE) }} to memref<{{ TILE_K }}x{{ TILE_N }}xf32, 1> + %Y_buffer2D = memref.reinterpret_cast %Y_buffer to offset: [0], sizes: [{{ TILE_M }}, {{ TILE_N }}], strides: [{{ TILE_N }}, 1] : {{ Y_tile_desc.get_mlir_shape(DATA_STYPE) }} to memref<{{ TILE_M }}x{{ TILE_N }}xf32, 1> + {% if Bias -%} + {{ kernel.def_dma_op("MVIN", "Bias", Bias_idx, Y_tile_desc, subtile_size=[1, SUB_TILE_M, SUB_TILE_N], indent_size=8) }} + {%- else -%} + affine.vector_store %v0, %Y_buffer[0, 0, 0] : {{ Y_tile_desc.get_mlir_shape(DATA_STYPE) }}, vector<{{ kernel.get_spad_size_per_lane(TILE_M, TILE_N) }}xf32> + {% endif %} + + affine.for %index3 = 0 to {{ K }} step {{ TILE_K }} { + {{ kernel.def_dma_op("MVIN", "X", X_idx, X_tile_desc, subtile_size=[1, SUB_TILE_M, SUB_TILE_K], indent_size=10) }} + {{ kernel.def_dma_op("MVIN", "W", W_idx, W_tile_desc, subtile_size=[1, SUB_TILE_K, SUB_TILE_N], indent_size=10) }} + linalg.matmul ins(%X_buffer2D, %W_buffer2D : memref<{{ TILE_M }}x{{ TILE_K }}x{{ DATA_STYPE }}, 1>, memref<{{ TILE_K }}x{{ TILE_N }}x{{ DATA_STYPE }}, 1>) + outs(%Y_buffer2D : memref<{{ TILE_M }}x{{ TILE_N }}x{{ DATA_STYPE }}, 1>) + } { accumulation_loop=true, subtile_loop="k" } + {{kernel.store_output(indent_size=8)}} + } { outer_loop=true, subtile_loop="n" } + } { outer_loop=true, subtile_loop="m" } + } { outer_loop=true } + return +} +""" + +BMM_PROLOGUE_TEMPLATE = r""" +// BMM Prologue kernel +// BATCH = {{ B }} +// M = {{ M }} +// N = {{ N }} +// K = {{ K }} +// TILE_M = {{ TILE_M }} +// TILE_N = {{ TILE_N }} +// TILE_K = {{ TILE_K }} +// SUB_TILE_M = {{ SUB_TILE_M }} +// SUB_TILE_N = {{ SUB_TILE_N }} +{{kernel.def_global_vars()}} + +func.func @{{ KERNEL_NAME }}{{kernel.def_kernel(inputs=[X, W, Bias], outputs=[Y], names_str="X, W, Bias, Y", input_reorder=input_reorder)}} { + {{ kernel.def_sram_buffer("X", X_tile_desc, indent_size=2) }} + {{ kernel.def_sram_buffer("W", W_tile_desc, indent_size=2) }} + {{ kernel.def_sram_buffer("Y", Y_tile_desc, indent_size=2) }} + {% if not Bias %} + %v0 = arith.constant dense<0.0> : vector<{{ kernel.get_spad_size_per_lane(TILE_M, TILE_N) }}xf32> + {% endif %} + %c0 = arith.constant 0 : index + {{ kernel.def_local_vars(indent_size=2) }} + affine.for %index0 = 0 to {{ B }} { + affine.for %index1 = 0 to {{ M }} step {{ TILE_M }} { + affine.for %index2 = 0 to {{ N }} step {{ TILE_N }} { + %X_buffer2D = memref.reinterpret_cast %X_buffer to offset: [0], sizes: [{{ TILE_M }}, {{ TILE_K }}], strides: [{{ TILE_K }}, 1] : memref<1x{{ TILE_M }}x{{ TILE_K }}xf32, 1> to memref<{{ TILE_M }}x{{ TILE_K }}xf32, 1> + %W_buffer2D = memref.reinterpret_cast %W_buffer to offset: [0], sizes: [{{ TILE_K }}, {{ TILE_N }}], strides: [{{ TILE_N }}, 1] : memref<1x{{ TILE_K }}x{{ TILE_N }}xf32, 1> to memref<{{ TILE_K }}x{{ TILE_N }}xf32, 1> + %Y_buffer2D = memref.reinterpret_cast %Y_buffer to offset: [0], sizes: [{{ TILE_M }}, {{ TILE_N }}], strides: [{{ TILE_N }}, 1] : memref<1x{{ TILE_M }}x{{ TILE_N }}xf32, 1> to memref<{{ TILE_M }}x{{ TILE_N }}xf32, 1> + {% if Bias -%} + {{ kernel.def_dma_op("MVIN", "Bias", Bias_idx, Y_tile_desc, subtile_size=[1, SUB_TILE_M, SUB_TILE_N], indent_size=8) }} + {%- else -%} + affine.vector_store %v0, %Y_buffer[0, 0, 0] : {{ Y_tile_desc.get_mlir_shape(DATA_STYPE) }}, vector<{{ kernel.get_spad_size_per_lane(TILE_M, TILE_N) }}xf32> + {% endif %} + affine.for %index3 = 0 to {{ K }} step {{ TILE_K }} { + {{kernel.load_input(indent_size=10)}} + linalg.matmul ins(%X_buffer2D, %W_buffer2D : memref<{{ TILE_M }}x{{ TILE_K }}x{{ DATA_STYPE }}, 1>, memref<{{ TILE_K }}x{{ TILE_N }}x{{ DATA_STYPE }}, 1>) + outs(%Y_buffer2D : memref<{{ TILE_M }}x{{ TILE_N }}x{{ DATA_STYPE }}, 1>) + } { accumulation_loop=true, subtile_loop="k" } + {{kernel.store_output(indent_size=8)}} + } { outer_loop=true, subtile_loop="n" } + } { outer_loop=true, subtile_loop="m" } + } { outer_loop=true } + return +} +""" + +BMM_REDUCTION_TEMPLATE = r""" +// BMM Reduction kernel +// BATCH = {{ B }} +// M = {{ M }} +// N = {{ N }} +// K = {{ K }} +// TILE_M = {{ TILE_M }} +// TILE_N = {{ TILE_N }} +// TILE_K = {{ TILE_K }} +// SUB_TILE_M = {{ SUB_TILE_M }} +// SUB_TILE_N = {{ SUB_TILE_N }} +{{kernel.def_global_vars()}} + +func.func @{{ KERNEL_NAME }}{{kernel.def_kernel(inputs=[X, W, Bias], outputs=[Y], names_str="X, W, Bias, Y", input_reorder=input_reorder)}} { + {{ kernel.def_sram_buffer("X", X_tile_desc, indent_size=2) }} + {{ kernel.def_sram_buffer("W", W_tile_desc, indent_size=2) }} + {{ kernel.def_sram_buffer("Y", Y_tile_desc, indent_size=2) }} + {% if not Bias %} + %v0 = arith.constant dense<0.0> : vector<{{ kernel.get_spad_size_per_lane(TILE_M, TILE_N) }}xf32> + {% endif %} + %c0 = arith.constant 0 : index + {{ kernel.def_local_vars(indent_size=2) }} + affine.for %index0=0 to {{ B }} { + affine.for %index2 = 0 to {{ N }} step {{ TILE_N }} { + affine.for %index1 = 0 to {{ M }} step {{ TILE_M }} { + %X_buffer2D = memref.reinterpret_cast %X_buffer to offset: [0], sizes: [{{ TILE_M }}, {{ TILE_K }}], strides: [{{ TILE_K }}, 1] : memref<1x{{ TILE_M }}x{{ TILE_K }}xf32, 1> to memref<{{ TILE_M }}x{{ TILE_K }}xf32, 1> + %W_buffer2D = memref.reinterpret_cast %W_buffer to offset: [0], sizes: [{{ TILE_K }}, {{ TILE_N }}], strides: [{{ TILE_N }}, 1] : memref<1x{{ TILE_K }}x{{ TILE_N }}xf32, 1> to memref<{{ TILE_K }}x{{ TILE_N }}xf32, 1> + %Y_buffer2D = memref.reinterpret_cast %Y_buffer to offset: [0], sizes: [{{ TILE_M }}, {{ TILE_N }}], strides: [{{ TILE_N }}, 1] : memref<1x{{ TILE_N }}x{{ TILE_M }}xf32, 1> to memref<{{ TILE_M }}x{{ TILE_N }}xf32, 1> + + {% if Bias -%} + {{ kernel.def_dma_op("MVIN", "Bias", Bias_idx, Y_tile_desc, subtile_size=[1, SUB_TILE_M, SUB_TILE_N], indent_size=8) }} // Why not N,M? Currently, dma-fine-grained pass assume M->N order... + {%- else -%} + affine.vector_store %v0, %Y_buffer[0, 0, 0] : memref<1x{{ TILE_N }}x{{ TILE_M }}xf32, 1>, vector<{{ kernel.get_spad_size_per_lane(TILE_M, TILE_N) }}xf32> + {% endif %} + affine.for %index3 = 0 to {{ K }} step {{ TILE_K }} { + {{ kernel.def_dma_op("MVIN", "X", X_idx, X_tile_desc, subtile_size=[1, SUB_TILE_M, SUB_TILE_K], indent_size=10) }} + {{ kernel.def_dma_op("MVIN", "W", W_idx, W_tile_desc, subtile_size=[1, SUB_TILE_K, SUB_TILE_N], indent_size=10) }} + linalg.matmul ins(%X_buffer2D, %W_buffer2D : memref<{{ TILE_M }}x{{ TILE_K }}x{{ DATA_STYPE }}, 1>, memref<{{ TILE_K }}x{{ TILE_N }}x{{ DATA_STYPE }}, 1>) + outs(%Y_buffer2D : memref<{{ TILE_M }}x{{ TILE_N }}x{{ DATA_STYPE }}, 1>) + } { accumulation_loop=true, subtile_loop="k" } + {{kernel.store_output(indent_size=8)}} + } { outer_loop=true, subtile_loop="m" } + {{kernel.reduction_output(indent_size=6)}} + } { outer_loop=true, subtile_loop="n" } } { outer_loop=true } return } @@ -73,69 +157,170 @@ class MLIRBMMTemplate(MLIRTemplate): def __init__(self, input_nodes, layout, input_reorder=None): super().__init__("kernel", input_nodes, layout, input_reorder) - def is_transposed(self, node): - if isinstance(node, ReinterpretView): - # if node.layout.stride != node.data.layout.stride: - if node.layout.stride[-1] != node.data.layout.stride[-1] or node.layout.stride[-2] != node.data.layout.stride[-2]: - if node.layout.stride[-2] == node.data.layout.stride[-1] and node.layout.stride[-1] == node.data.layout.stride[-2]: - return True - else: - raise NotImplementedError("If the stride is not equal to the original stride, it should have been transposed.") - return False - def render(self, kernel: MLIRTemplateKernel, template_buffer_node = None, epilogue_nodes: Optional[List[IRNode]] = None, + prologue_nodes: Optional[List[IRNode]] = None, **kwargs): if template_buffer_node is not None: self.output_node = template_buffer_node - if epilogue_nodes is not None and len(epilogue_nodes) > 0: - self.output_node = cast(Buffer, epilogue_nodes[-1]) + # Extract input arguments info X, W = self.input_nodes[0], self.input_nodes[1] Y = self.output_node Bias = None if len(self.input_nodes) == 2 else self.input_nodes[2] - M, N, K = X.get_size()[1], W.get_size()[2], X.get_size()[2] - TILE_M, TILE_N, TILE_K = kernel.gemmini_gemm_mapping(M, N, K) - kernel.tile_size = [TILE_M, TILE_N, TILE_K] - kernel.loop_size = [M, N, K] + W_tensor = empty_strided(W.layout.size, W.layout.stride) + X_tensor = empty_strided(X.layout.size, X.layout.stride) + if len(W_tensor.size()) > 3: + W_tensor = W_tensor.view([-1, W_tensor.shape[-2], W_tensor.shape[-1]]) + if len(X_tensor.size()) > 3: + X_tensor = X_tensor.view([-1, X_tensor.shape[-2], X_tensor.shape[-1]]) + B, M, N, K = X_tensor.size()[0], X_tensor.size()[1], W_tensor.size()[2], X_tensor.size()[2] + + W_stride = W_tensor.stride() + X_stride = X_tensor.stride() + + # Select tile size + n_extra_node = len(epilogue_nodes) if epilogue_nodes is not None else 0 + TILE_M, TILE_N, TILE_K = kernel.gemm_combination_mapping(M, N, K, n_extra_node=n_extra_node) + SUB_TILE_M = TILE_M if (TILE_M < kernel.vector_lane) or prologue_nodes else kernel.vector_lane + SUB_TILE_N = TILE_N # if (TILE_N < kernel.vector_lane) or prologue_nodes else kernel.vector_lane + SUB_TILE_K = TILE_K # if (TILE_K < kernel.vector_lane) or prologue_nodes else kernel.vector_lane + + TOG_latency = M if TILE_M > M else TILE_M + kernel.loop_size = [TOG_latency, TILE_N, TILE_K] + TILE_K = TILE_K // 2 if prologue_nodes else TILE_K + + # Select template code + nr_reduction_nodes = [node for node in epilogue_nodes if node.is_reduction()] if epilogue_nodes is not None else [] + if nr_reduction_nodes: + template = BMM_REDUCTION_TEMPLATE + epilogue_dim_aliasing = {"index0":"index0", "index1":"index2", "index2": "index1"} + nr_rdim = 1 + elif prologue_nodes: + template = BMM_PROLOGUE_TEMPLATE + epilogue_dim_aliasing = {"index0":"index0", "index1":"index1", "index2": "index2"} + nr_rdim = 0 + else: + template = BMM_TEMPLATE + epilogue_dim_aliasing = {"index0":"index0", "index1":"index1", "index2": "index2"} + nr_rdim = 0 - W_transposed = self.is_transposed(W) - X_transposed = self.is_transposed(X) + # Prepare tile descriptors + vlane_stride = 1 + vlane_split_axis = 2 + loop_dim = [sympy.Symbol("index0"), sympy.Symbol("index1"), sympy.Symbol("index2"), sympy.Symbol("index3")] + X_tile_size = [1, TILE_M, TILE_K] + X_tile_stride = [0, 1, TILE_M] + X_tile_desc = mlir_common.MLIRMultiDimTile(X_tile_size, kernel.vector_lane, vlane_split_axis, vlane_stride) + X_tile_desc.set_tile_size_stride(X_tile_size, X_tile_stride) + X_tile_desc.set_name("X_buffer") + X_stride = X_tensor.stride() + X_idx = [loop_dim[0]*X_stride[0], loop_dim[1]*X_stride[1], loop_dim[3]*X_stride[2]] # To keep index arguemnt order, we used index_list - options = dict( + W_tile_size = [1, TILE_K, TILE_N] + W_tile_stride = [0, 1, TILE_K] + W_tile_desc = mlir_common.MLIRMultiDimTile(X_tile_size, kernel.vector_lane, vlane_split_axis, vlane_stride) + W_tile_desc.set_tile_size_stride(W_tile_size, W_tile_stride) + W_tile_desc.set_name("W_buffer") + W_stride = W_tensor.stride() + W_idx = [loop_dim[0]*W_stride[0], loop_dim[3]*W_stride[1], loop_dim[2]*W_stride[2]] + + vlane_split_axis = vlane_split_axis if nr_rdim==0 else 1 + Y_tile_size = [1, TILE_M, TILE_N] if nr_rdim == 0 else [1, TILE_N, TILE_M] + Y_tile_stride=[0, 1, TILE_M] if nr_rdim == 0 else [0, TILE_M, 1] + Y_tile_desc = mlir_common.MLIRMultiDimTile(Y_tile_size, kernel.vector_lane, vlane_split_axis, vlane_stride) + Y_tile_desc.set_tile_size_stride(Y_tile_size, Y_tile_stride) + Y_tile_desc.set_name("Y_buffer") + Y_stride = Y.get_layout().stride + if nr_rdim == 0: + Y_idx = [loop_dim[0]*Y_stride[0], loop_dim[1]*Y_stride[1], loop_dim[2]*Y_stride[2]] + else: + Y_idx = [loop_dim[0]*Y_stride[0], loop_dim[2]*Y_stride[2], loop_dim[1]*Y_stride[1]] + + # Extract Bias info + if Bias is not None: + Bias_stride = Bias.get_layout().stride + if nr_rdim == 0: + Bias_idx = [loop_dim[0]*Bias_stride[0], loop_dim[1]*Bias_stride[1], loop_dim[2]*Bias_stride[2]] + else: + Bias_idx = [loop_dim[0]*Bias_stride[0], loop_dim[2]*Bias_stride[2], loop_dim[1]*Bias_stride[1]] + else: + Bias_idx = None + + kernel.render_options = dict( KERNEL_NAME=self.name, kernel=kernel, - B=X.get_size()[0], - M=M, - N=N, - K=K, - TILE_M=TILE_M, - TILE_N=TILE_N, - TILE_K=TILE_K, + B=B, M=M, N=N, K=K, + TILE_M=TILE_M, TILE_N=TILE_N, TILE_K=TILE_K, + SUB_TILE_M=SUB_TILE_M, + SUB_TILE_N=SUB_TILE_N, + SUB_TILE_K=SUB_TILE_K, DATA_STYPE="f32", - DATA_SIZE=4, - X = X, - W = W, - Y = Y, - Bias = Bias, - Bias_rank = len(Bias.data.get_size()) if Bias is not None else 0, - W_transposed = W_transposed, - X_transposed = X_transposed, + X = X, W = W,Y = Y, Bias = Bias, + X_idx = X_idx, + W_idx = W_idx, + Bias_idx = Bias_idx, + X_tile_desc = X_tile_desc, + W_tile_desc = W_tile_desc, + Y_tile_desc = Y_tile_desc, input_reorder = self.input_reorder ) - code = self._template_from_string(BMM_TEMPLATE).render(**options) - kernel.add_loop_info([options["M"], options["N"], options["K"]], [options["TILE_M"], options["TILE_N"], options["TILE_K"]]) - self.header = f"float X_spad[{TILE_M * TILE_K // kernel.vector_lane}] __attribute__ ((section(\".spad\")));\n" - self.header += f"float W_spad[{TILE_K * TILE_N // kernel.vector_lane}] __attribute__ ((section(\".spad\")));\n" - self.header += f"float Y_spad[{TILE_M * TILE_N // kernel.vector_lane}] __attribute__ ((section(\".spad\")));\n" - self.gem5_header = f"float X_spad[{TILE_M * TILE_K}] __attribute__ ((section(\".spad\")));\n" - self.gem5_header += f"float W_spad[{TILE_K * TILE_N}] __attribute__ ((section(\".spad\")));\n" - self.gem5_header += f"float Y_spad[{TILE_M * TILE_N}] __attribute__ ((section(\".spad\")));\n" + if prologue_nodes: + prologue_output_name = list(prologue_nodes[0].read_writes.writes)[0].name + if prologue_output_name == X.get_name(): + # Input fusion case + prologue_var = "X" + prologue_sram_var = "X_buffer" + prologue_tile_desc = X_tile_desc + prologue_dim_aliasing = {"index0":"index0", "index1":"index1", "index2":"index3"} + is_input_fused = True + else: + # Weight fusion case + prologue_var = "W" + prologue_sram_var = "W_buffer" + prologue_tile_desc = W_tile_desc + prologue_dim_aliasing = {"index0":"index0", "index1":"index3", "index2":"index2"} + is_input_fused = False + + kernel.prologue_info = dict ( + input_dram_var = "X", + input_sram_var = "X_buffer", + input_tile_desc = X_tile_desc, + input_idx = X_idx, + input_subtile_size = [1, TILE_M, TILE_K], # TODO. Curently, Subtiling is not supported for prologue template + input_dim_aliasing = {"index0":"index0", "index1":"index1", "index2":"index3"}, + weight_dram_var = "W", + weight_sram_var = "W_buffer", + weight_tile_desc = W_tile_desc, + weight_idx = W_idx, + weight_subtile_size = [1, TILE_K, TILE_N], # TODO. Curently, Subtiling is not supported for prologue template + weight_dim_aliasing = {"index0":"index0", "index1":"index3", "index2":"index2"}, + + # Descriptor for fusion + dram_var = prologue_var, + sram_var = prologue_sram_var, + dram_tile_desc = prologue_tile_desc, + dim_aliasing = prologue_dim_aliasing, + is_bmm = True, + is_input_fused = is_input_fused + ) + + kernel.epilogue_info = dict( + output_node = self.output_node.name, + sram_var = "Y_buffer", + dram_var = "Y", + dram_idx = Y_idx, + dram_tile_desc = Y_tile_desc, + nr_rdim = nr_rdim, + dim_aliasing = epilogue_dim_aliasing + ) + code = self._template_from_string(template).render(**kernel.render_options) + kernel.add_loop_info([kernel.render_options["M"], kernel.render_options["N"], kernel.render_options["K"]], [kernel.render_options["TILE_M"], kernel.render_options["TILE_N"], kernel.render_options["TILE_K"]]) return code def codegen_header(self, code, extra_headers): @@ -145,6 +330,6 @@ def codegen_header(self, code, extra_headers): spike_write_path = os.path.join(write_path, "global_var.h") gem5_write_path = os.path.join(write_path, "gem5_global_var.h") if not os.path.exists(spike_write_path): - write_atomic(spike_write_path, self.header+extra_headers[0]) + write_atomic(spike_write_path, extra_headers[0]) if not os.path.exists(gem5_write_path): - write_atomic(gem5_write_path, self.gem5_header+extra_headers[1]) \ No newline at end of file + write_atomic(gem5_write_path, extra_headers[1]) \ No newline at end of file diff --git a/PyTorchSimFrontend/mlir/mlir_caller_codegen.py b/PyTorchSimFrontend/mlir/mlir_caller_codegen.py index 92f250df..3fff9958 100644 --- a/PyTorchSimFrontend/mlir/mlir_caller_codegen.py +++ b/PyTorchSimFrontend/mlir/mlir_caller_codegen.py @@ -50,10 +50,12 @@ def generate_kernel_declare(self): def generate_args_define(self): name_set = set() - for arg_name, (_, arg_type, arg_size) in self.arg_attributes: + if self.validation: + self.writeline(f'int padding[0x100000]{self.ending}') # FIXME. For pooling operation... Some pooling layer use negative offset + for arg_name, (_, arg_type, arg_size, arg_sizes, arg_stride) in self.arg_attributes: if not arg_name in name_set: if self.validation: - self.writeline(f'{DTYPE_TO_C[arg_type]} c_{arg_name}[{arg_size}]{self.ending}') + self.writeline(f'{DTYPE_TO_C[arg_type]} c_{arg_name}[{arg_size}ULL]{self.ending}') else: if torch.is_floating_point(torch.tensor([], dtype=arg_type)): bits = torch.finfo(arg_type).bits @@ -61,7 +63,7 @@ def generate_args_define(self): bits = 8 else: bits = torch.iinfo(arg_type).bits - self.writeline(f'{DTYPE_TO_C[arg_type]}* c_{arg_name} = malloc({arg_size * bits // 8}){self.ending}') + self.writeline(f'{DTYPE_TO_C[arg_type]}* c_{arg_name} = malloc({arg_size * bits // 8}ULL){self.ending}') name_set.add(arg_name) self.writeline(self.newline) @@ -77,7 +79,7 @@ def generate_main(self): else: self.generate_args_define() - func_arguments = [f"c_{arg_name}, c_{arg_name}, 0, {arg_shape}, 1" if arg_type != torch.bool else f"c_{arg_name}, c_{arg_name}, 0, {(arg_shape + 7) // 8}, 1" for arg_name, (_, arg_type, arg_shape) in self.arg_attributes] + func_arguments = [f"c_{arg_name}, c_{arg_name}, 0, {arg_shape}, 1" for arg_name, (_, arg_type, arg_shape, _, _) in self.arg_attributes] self.writeline(f"wrapper_{self.kernel_name}({', '.join(func_arguments)}){self.ending}{self.newline}") if self.validation: diff --git a/PyTorchSimFrontend/mlir/mlir_codegen_backend.py b/PyTorchSimFrontend/mlir/mlir_codegen_backend.py index 09c5aa34..ff87c1d3 100644 --- a/PyTorchSimFrontend/mlir/mlir_codegen_backend.py +++ b/PyTorchSimFrontend/mlir/mlir_codegen_backend.py @@ -1,31 +1,26 @@ -import dataclasses import contextlib import sympy -import itertools import re import os import math -from functools import reduce -from operator import mul -from typing import List -from typing import Dict -from collections import OrderedDict import torch -from torch._inductor import dependencies, config -from torch._inductor.codegen import cpp, wrapper, common -from torch._inductor.scheduler import BaseScheduling +from collections import defaultdict +from concurrent.futures import ThreadPoolExecutor +from torch._dynamo.utils import dynamo_timed +from torch._inductor.codegen import cpp, wrapper, common, memory_planning from torch._inductor.virtualized import V, _ops as ops from torch._inductor.codecache import write_atomic, write -from Simulator.simulator import BackendSimulator -from PyTorchSimFrontend import extension_config from torch._inductor.utils import ( IndentedBuffer, is_welford_reduction, + sympy_product ) +from torch.utils._sympy.functions import ModularIndexing import PyTorchSimFrontend.extension_codecache as extension_codecache - +from PyTorchSimFrontend import extension_config from . import mlir_common +from .mlir_common import LoopLevel, LoopNest def reduction_init(reduction_type, dtype): if dtype in cpp.DTYPE_LOWP_FP: @@ -37,28 +32,34 @@ def reduction_init(reduction_type, dtype): if reduction_type == "prod": return float(1) if dtype.is_floating_point else int(1) if reduction_type in {"max", "argmax"}: - return "0.0" + if dtype == torch.float32: + return f"0x{mlir_common.MLIR_INF['-inf']['f32']:x}" + elif dtype == torch.float64: + return f"0x{mlir_common.MLIR_INF['-inf']['f64']:x}" + else: + return "0.0" if reduction_type in {"min", "argmin"}: - return "0.0" + if dtype == torch.float32: + return f"0x{mlir_common.MLIR_INF['inf']['f32']:x}" + elif dtype == torch.float64: + return f"0x{mlir_common.MLIR_INF['inf']['f64']:x}" + else: + return "0.0" if reduction_type in {"welford_reduce"}: return f"0.0" raise AssertionError(reduction_type) -def reduction_combine(reduction_type, start_value, vector_value, tile_size=64): +def reduction_partial_combine_vec(reduction_type, vector_value, init_value): if reduction_type == "sum": - return f"arith.addf %{start_value}, %{vector_value} : vector<{tile_size}xf32>" + return ops.add(vector_value, init_value) if reduction_type == "prod": - return f"arith.mulf %{start_value}, %{vector_value} : vector<{tile_size}xf32>" - if reduction_type == "xor_sum": - raise NotImplementedError() # TODO: implement + return ops.mul(vector_value, init_value) + if reduction_type == "max": + return ops.maximum(vector_value, init_value) + if reduction_type == "min": + return ops.minimum(vector_value, init_value) if reduction_type == "any": - raise NotImplementedError() - if reduction_type in ("min", "max"): - raise NotImplementedError() - if reduction_type == "welford_reduce": - raise NotImplementedError() - if reduction_type == "welford_combine": - raise NotImplementedError() + return ops.logical_and(vector_value, init_value) raise AssertionError(reduction_type) def reduction_combine_vec(reduction_type, vector_value, init_value, axis, shape, reduced_shape): @@ -94,6 +95,9 @@ def write_header(self): from torch import device, empty, empty_strided from {extension_codecache.__name__} import CustomAsyncCompile + from PyTorchSimFrontend.extension_config import CONFIG_SRAM_BUFFER_PLAN, CONFIG_BACKENDSIM_EAGER_MODE + from Simulator.simulator import BackendSimulator + from PyTorchSimFrontend.extension_op import sparse_mm_dummy_stonne_outer from torch._inductor.select_algorithm import extern_kernels aten = torch.ops.aten @@ -101,15 +105,120 @@ def write_header(self): assert_size_stride = torch._C._dynamo.guards.assert_size_stride alloc_from_pool = torch.ops.inductor._alloc_from_pool reinterpret_tensor = torch.ops.aten._reinterpret_tensor - async_compile = CustomAsyncCompile() + custom_async_compile = CustomAsyncCompile() os.environ["TORCHSIM_LAST_COMPILED_MODULE"] = __file__ """ ) + self.header.splice( + f""" + def sram_plan_prefix(buffer_name, buffer): + #if CONFIG_SRAM_BUFFER_PLAN is None: + # return + #elif buffer_name not in CONFIG_SRAM_BUFFER_PLAN: + # return + buffer_size = buffer.element_size() * buffer.untyped_storage().size() + start = buffer.data_ptr() + end = start + buffer_size + # print(f'Alloc {{buffer_name}}(0x{{start:x}} ~ 0x{{end:x}})') + BackendSimulator.sram_alloc(buffer_name, [start, end]) + + def sram_plan_postfix(buffer_name, buffer): + #if CONFIG_SRAM_BUFFER_PLAN is None: + # return + #elif buffer_name not in CONFIG_SRAM_BUFFER_PLAN: + # return + buffer_size = buffer.element_size() * buffer.untyped_storage().size() + start = buffer.data_ptr() + end = start + buffer_size + # print(f'Dealloc {{buffer_name}}(0x{{start:x}} ~ 0x{{end:x}})') + BackendSimulator.sram_dealloc(buffer_name, [start, end]) + + def host2device_memcopy(buffer): + pass + + def device2host_memcpy(buffer): + pass + """ + ) + + def write_prefix(self): + self.prefix.splice( + """ + def call(args): + """ + ) + with self.prefix.indent(): + inp_len = len(V.graph.graph_inputs.keys()) + if inp_len != 0: + lhs = f"{', '.join(V.graph.graph_inputs.keys())}{'' if inp_len != 1 else ','}" + self.prefix.writeline(f"{lhs} = args") + self.prefix.writeline("args.clear()") + + self.codegen_inputs(self.prefix, V.graph.graph_inputs) + self.codegen_input_size_asserts() + self.codegen_sram_plan_prefix() + + def codegen_sram_plan_prefix(self): + for name, buf in V.graph.graph_inputs.items(): + if isinstance(buf, sympy.Expr): + continue + if sympy_product(buf.get_size()) == 0: + continue + if buf is None: + continue + self.prefix.writeline(f"sram_plan_prefix('{name}', {name})") + + def codegen_sram_plan_postfix(self, outputs): + for name in outputs: + if name is None or name == "None": + continue + self.wrapper_call.writeline(f"sram_plan_postfix('{name}', {name})") + + @dynamo_timed + def generate(self, is_inference): + result = IndentedBuffer() + result.splice(self.header) + with contextlib.ExitStack() as stack: + stack.enter_context(self.wrapper_call.indent()) + self.memory_plan_reuse() + for line in self.lines: + # Add buffer plan hook for dealloc + if isinstance(line, memory_planning.DeallocFromPoolLine): + self.wrapper_call.writeline(f"sram_plan_postfix('{line.node.get_name()}', {line.node.get_name()})") + elif isinstance(line, str) and "del" in line: + name = line.split(" ")[1] + self.wrapper_call.writeline(f"sram_plan_postfix('{name}', {name})") + + if isinstance(line, wrapper.MemoryPlanningLine): + line.codegen(self.wrapper_call) + else: + self.wrapper_call.writeline(line) + # Add buffer plan hook for alloc + if isinstance(line, memory_planning.AllocFromPoolLine) or isinstance(line, wrapper.AllocateLine): + self.wrapper_call.writeline(f"sram_plan_prefix('{line.node.get_name()}', {line.node.get_name()})") + output_refs = self.get_output_refs() + self.codegen_sram_plan_postfix(output_refs) + self.mark_output_type() + self.generate_return(output_refs) + + self.append_precomputed_sizes_to_prefix() + self.finalize_prefix() + result.splice(self.prefix) + + with result.indent(): + result.splice(self.wrapper_call) + + self.generate_end(result) + self.add_benchmark_harness(result) + return result.getvaluewithlinemap() + + def memory_plan(self): + self.lines = memory_planning.MemoryPlanner(self).plan(self.lines) class ExtensionOverrides(common.OpOverrides): # Binary element wise operations @staticmethod - def custom_cast(operand, target_type, *args, var_info=None): + def custom_cast(operand, target_type, *args, var_info=None, **kwargs): dtype = var_info[operand][1] if dtype == "index": ret = ops.index_cast(operand, target_type, var_info=var_info) @@ -119,6 +228,8 @@ def custom_cast(operand, target_type, *args, var_info=None): @staticmethod def binary_elementwise_common(operand1, operand2, var_info): + operand1.bounds = operand1.bounds.unknown() + operand2.bounds = operand2.bounds.unknown() op_type1 = var_info[operand1] op_type2 = var_info[operand2] # Tile size check @@ -148,6 +259,13 @@ def binary_elementwise_common(operand1, operand2, var_info): elif op_type1[1][0] == "f" and op_type2[1][0] == "i": operand2 = ops.to_dtype(operand2, op_type1[1], var_info) op_type2 = var_info[operand2] + elif op_type1[1][0] == op_type2[1][0]: + if int(op_type1[1][1:]) > int(op_type2[1][1:]): + operand2 = ops.ext(operand2, op_type1[1]) + op_type2 = var_info[operand2] + elif int(op_type1[1][1:]) < int(op_type2[1][1:]): + operand1 = ops.ext(operand1, op_type2[1]) + op_type1 = var_info[operand1] else: raise NotImplementedError("Unsupported type converting") @@ -157,28 +275,28 @@ def binary_elementwise_common(operand1, operand2, var_info): return tile_size, ret_type, operand1, operand2 @staticmethod - def add(operand1, operand2, *args, var_info=None): + def add(operand1, operand2, *args, var_info=None, **kwargs): tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2, var_info) shape = f"vector<{tile_size}x{ret_type}>" if tile_size > 1 else ret_type opcode = f'arith.add{ret_type[0]}' return f'{opcode} %{operand1}, %{operand2} : {shape}', [tile_size, ret_type] @staticmethod - def sub(operand1, operand2, *args, var_info=None): + def sub(operand1, operand2, *args, var_info=None, **kwargs): tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2, var_info) shape = f"vector<{tile_size}x{ret_type}>" if tile_size > 1 else ret_type opcode = f'arith.sub{ret_type[0]}' return f'{opcode} %{operand1}, %{operand2} : {shape}', [tile_size, ret_type] @staticmethod - def mul(operand1, operand2, *args, var_info=None): + def mul(operand1, operand2, *args, var_info=None, **kwargs): tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2, var_info) shape = f"vector<{tile_size}x{ret_type}>" if tile_size > 1 else ret_type opcode = f'arith.mul{ret_type[0]}' return f'{opcode} %{operand1}, %{operand2} : {shape}', [tile_size, ret_type] @staticmethod - def div(operand1, operand2, *args, var_info=None): + def div(operand1, operand2, *args, var_info=None, **kwargs): tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2, var_info) shape = f"vector<{tile_size}x{ret_type}>" if tile_size > 1 else ret_type if ret_type[0] == "f": @@ -188,7 +306,7 @@ def div(operand1, operand2, *args, var_info=None): return f'{opcode} %{operand1}, %{operand2} : {shape}', [tile_size, ret_type] @staticmethod - def truediv(operand1, operand2, *args, var_info=None): + def truediv(operand1, operand2, *args, var_info=None, **kwargs): tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2, var_info) shape = f"vector<{tile_size}x{ret_type}>" if tile_size > 1 else ret_type if ret_type[0] == "f": @@ -198,7 +316,17 @@ def truediv(operand1, operand2, *args, var_info=None): return f'{opcode} %{operand1}, %{operand2} : {shape}', [tile_size, ret_type] @staticmethod - def minimum(operand1, operand2, *args, var_info=None): + def modular(operand1, operand2, *args, var_info=None, **kwargs): + tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2, var_info) + shape = f"vector<{tile_size}x{ret_type}>" if tile_size > 1 else ret_type + if ret_type[0] == "f": + raise NotImplementedError("Not support remainder operation for floating point") + else: + opcode = f'arith.remui' + return f'{opcode} %{operand1}, %{operand2} : {shape}', [tile_size, ret_type] + + @staticmethod + def minimum(operand1, operand2, *args, var_info=None, **kwargs): tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2, var_info) shape = f"vector<{tile_size}x{ret_type}>" if tile_size > 1 else ret_type if ret_type[0] == "f": @@ -208,7 +336,7 @@ def minimum(operand1, operand2, *args, var_info=None): return f'{opcode} %{operand1}, %{operand2} : {shape}', [tile_size, ret_type] @staticmethod - def maximum(operand1, operand2, *args, var_info=None): + def maximum(operand1, operand2, *args, var_info=None, **kwargs): tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2, var_info) shape = f"vector<{tile_size}x{ret_type}>" if tile_size > 1 else ret_type if ret_type[0] == "f": @@ -218,50 +346,123 @@ def maximum(operand1, operand2, *args, var_info=None): return f'{opcode} %{operand1}, %{operand2} : {shape}', [tile_size, ret_type] @staticmethod - def to_dtype(operand, dst_mlir_dtype, *args, var_info=None): + def to_dtype(operand, dst_mlir_dtype, *args, var_info=None, **kwargs): src_mlir_dtype = var_info[operand][1] tile_size = var_info[operand][0] - + if isinstance(dst_mlir_dtype, torch.dtype): + dst_mlir_dtype = mlir_common.DTYPE_TO_MLIR[dst_mlir_dtype] dst_bits = int(dst_mlir_dtype[1:]) src_bits = int(src_mlir_dtype[1:]) shape = f"vector<{tile_size}x{dst_mlir_dtype}>" if tile_size > 1 else dst_mlir_dtype src_shape = f"vector<{tile_size}x{src_mlir_dtype}>" if tile_size > 1 else src_mlir_dtype if dst_mlir_dtype[0] == "i" and src_mlir_dtype[0] == "f": - raise NotImplementedError("floating point to integer conversion") + return f"arith.fptoui%{operand} : {src_shape} to {shape}", [tile_size, dst_mlir_dtype] if dst_mlir_dtype[0] == "f" and src_mlir_dtype[0] == "i": - raise NotImplementedError("integer to floating point conversion") - else: + return f"arith.uitofp%{operand} : {src_shape} to {shape}", [tile_size, dst_mlir_dtype] + if dst_mlir_dtype[0] == "i": if dst_bits > src_bits: - return f"arith.extui %{operand} : {src_shape} to {shape}" + return f"arith.extui %{operand} : {src_shape} to {shape}", [tile_size, dst_mlir_dtype] elif dst_bits < src_bits: - return f"arith.trunc %{operand} : {src_shape} to {shape}" + return f"arith.trunc %{operand} : {src_shape} to {shape}", [tile_size, dst_mlir_dtype] + return f"arith.maximumi %{operand}, %{operand} : {shape}", [tile_size, dst_mlir_dtype] + elif dst_mlir_dtype[0] == "f": + if dst_bits > src_bits: + return f"arith.extf %{operand} : {src_shape} to {shape}", [tile_size, dst_mlir_dtype] + elif dst_bits < src_bits: + return f"arith.trunf %{operand} : {src_shape} to {shape}", [tile_size, dst_mlir_dtype] + return f"arith.maximumf %{operand}, %{operand} : {shape}", [tile_size, dst_mlir_dtype] + else: + raise NotImplementedError("Unsupported type for to_dtype ops") @staticmethod - def constant(value, src_type, *args, var_info=None): + def constant(value, src_type, *args, var_info=None, **kwargs): if isinstance(src_type, torch.dtype): src_type = mlir_common.DTYPE_TO_MLIR[src_type] + if "inf" == str(value) or "-inf" == str(value) or "nan" == str(value): + value = f"0x{mlir_common.MLIR_INF[str(value)][src_type]:x}" # if value represented by e notation, convert to float (ex 1e-3 -> 1.0e-3) - if "e" in str(value): - value = float(value) - if src_type[0] == "f": + elif "e" in str(value): + value = format(float(value), ".20f") + elif src_type[0] == "f": value = format(value, ".20f") - if src_type[0] == "i": + elif src_type[0] == "i": value = int(value) return f'arith.constant {value} : {src_type}', [1, src_type] - # transcendental functions @staticmethod - def exp(operand, *args, var_info=None): + def alloc(size, src_type, *args, var_info=None, **kwargs): + return f"memref.alloc() : memref<{size}x{src_type}>", [size, src_type] + + @staticmethod + def extractelement(operand, idx, *args, var_info=None, **kwargs): op_type = var_info[operand] tile_size = op_type[0] dtype = op_type[1] + shape = f"vector<{tile_size}x{dtype}>" if tile_size > 1 else dtype + return f"vector.extract %{operand}[{idx}]: {dtype} from {shape}", [1, dtype] + # transcendental functions + @staticmethod + def exp(operand, *args, var_info=None, **kwargs): + # Check scalar + op_type = var_info[operand] + if op_type[0] == 1: + val = ops.constant(0, op_type[1]) + var_info[val][0] = 4 + operand = ops.broadcast(operand, val) + val = ops.exp(operand) + result = ops.extractelement(val, 0) + return result, var_info[result] + op_type = var_info[operand] + tile_size = op_type[0] + dtype = op_type[1] shape = f"vector<{tile_size}x{dtype}>" if tile_size > 1 else dtype return f'math.exp %{operand} : {shape}', [tile_size, dtype] @staticmethod - def sqrt(operand, *args, var_info=None): + def erf(operand, *args, var_info=None, **kwargs): + # Check scalar + op_type = var_info[operand] + if op_type[0] == 1: + val = ops.constant(0, op_type[1]) + var_info[val][0] = 4 + operand = ops.broadcast(operand, val) + val = ops.exp(operand) + result = ops.extractelement(val, 0) + return result, var_info[result] + op_type = var_info[operand] + tile_size = op_type[0] + dtype = op_type[1] + shape = f"vector<{tile_size}x{dtype}>" if tile_size > 1 else dtype + return f'math.erf %{operand} : {shape}', [tile_size, dtype] + + @staticmethod + def tanh(operand, *args, var_info=None, **kwargs): + op_type = var_info[operand] + + # Check scalar + op_type = var_info[operand] + if op_type[0] == 1: + val = ops.constant(0, op_type[1]) + var_info[val][0] = 4 + operand = ops.broadcast(operand, val) + val = ops.exp(operand) + result = ops.extractelement(val, 0) + return result, var_info[result] + op_type = var_info[operand] + tile_size = op_type[0] + dtype = op_type[1] + + # Type check & auto cast + if dtype[0] != "f": + operand, dtype = ops.to_dtype(operand, "f32", var_info=var_info) + var_info[operand] = dtype + shape = f"vector<{tile_size}x{dtype}>" if tile_size > 1 else dtype + return f'math.tanh %{operand} : {shape}', [tile_size, dtype] + + @staticmethod + def sqrt(operand, *args, var_info=None, **kwargs): op_type = var_info[operand] tile_size = op_type[0] dtype = op_type[1] @@ -275,7 +476,7 @@ def sqrt(operand, *args, var_info=None): return f'math.sqrt %{operand} : {shape}', [tile_size, dtype] @staticmethod - def rsqrt(operand, *args, var_info=None): + def rsqrt(operand, *args, var_info=None, **kwargs): op_type = var_info[operand] tile_size = op_type[0] dtype = op_type[1] @@ -289,29 +490,23 @@ def rsqrt(operand, *args, var_info=None): return f'math.rsqrt %{operand} : {shape}', [tile_size, dtype] @staticmethod - def pow(operand1, operand2, *args, var_info=None): - op_type1 = var_info[operand1] - op_type2 = var_info[operand2] - + def pow(operand1, operand2, *args, var_info=None, **kwargs): + tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2, var_info) # Type check & auto cast - if op_type1[1][0] != "f": - operand1, dtype = ops.to_dtype(operand1, "f32", var_info=var_info) - var_info[operand1] = dtype + if ret_type[0] != "f": + operand1, ret_type = ops.to_dtype(operand1, "f32", var_info=var_info) + var_info[operand1] = ret_type # Type check & auto cast - if op_type2[1][0] != "f": - operand2, dtype = ops.to_dtype(operand2, "f32", var_info=var_info) - var_info[operand2] = dtype - - op_type1 = var_info[operand1] - tile_size = op_type1[0] - dtype = op_type1[1] + if ret_type[0] != "f": + operand2, ret_type = ops.to_dtype(operand2, "f32", var_info=var_info) + var_info[operand2] = ret_type - shape = f"vector<{tile_size}x{dtype}>" if tile_size > 1 else dtype - return f"math.pow{dtype[0]} %{operand1}, %{operand2} : {shape}", [] + shape = f"vector<{tile_size}x{ret_type}>" if tile_size > 1 else ret_type + return f"math.pow{ret_type[0]} %{operand1}, %{operand2} : {shape}", [tile_size, ret_type] @staticmethod - def log(operand, *args, var_info=None): + def log(operand, *args, var_info=None, **kwargs): op_type = var_info[operand] tile_size = op_type[0] dtype = op_type[1] @@ -325,7 +520,7 @@ def log(operand, *args, var_info=None): return f'math.log %{operand} : {shape}', [tile_size, dtype] @staticmethod - def reciprocal(operand, *args, var_info): + def reciprocal(operand, *args, var_info=None, **kwargs): op_type = var_info[operand] tile_size = op_type[0] dtype = op_type[1] @@ -337,9 +532,20 @@ def reciprocal(operand, *args, var_info): return ops.div(ops.constant(1.0, dtype), operand), [tile_size, dtype] + @staticmethod + def ext(operand, dtype, *args, var_info=None, **kwargs): + op_type = var_info[operand] + shape = f"vector<{op_type[0]}x{op_type[1]}>" if op_type[0] > 1 else f"{op_type[1]}" + target_type = f"vector<{op_type[0]}x{dtype}>" if op_type[0] > 1 else f"{dtype}" + if op_type[0] == "f": + opcode = f'arith.extf' + else: + opcode = f'arith.extui' + return f'{opcode} %{operand} : {shape} to {target_type}', [op_type[0], dtype] + # Logical operations @staticmethod - def neg(operand, *args, var_info=None): + def neg(operand, *args, var_info=None, **kwargs): op_type = var_info[operand] tile_size = op_type[0] dtype = op_type[1] @@ -353,7 +559,7 @@ def neg(operand, *args, var_info=None): return f'arith.negf %{operand} : {shape}', [tile_size, dtype] @staticmethod - def eq(operand1, operand2, *args, var_info=None): + def eq(operand1, operand2, *args, var_info=None, **kwargs): tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2, var_info) if ret_type[0] == "f": op_type = "arith.cmpf" @@ -368,7 +574,7 @@ def eq(operand1, operand2, *args, var_info=None): return f'{op_type} {attribute}, %{operand1}, %{operand2} : {shape}', [tile_size, "i1"] @staticmethod - def ne(operand1, operand2, *args, var_info=None): + def ne(operand1, operand2, *args, var_info=None, **kwargs): tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2, var_info) if ret_type[0] == "f": op_type = "arith.cmpf" @@ -383,7 +589,7 @@ def ne(operand1, operand2, *args, var_info=None): return f'{op_type} {attribute}, %{operand1}, %{operand2} : {shape}', [tile_size, "i1"] @staticmethod - def lt(operand1, operand2, *args, var_info=None): + def lt(operand1, operand2, *args, var_info=None, **kwargs): tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2, var_info) if ret_type[0] == "f": op_type = "arith.cmpf" @@ -398,7 +604,7 @@ def lt(operand1, operand2, *args, var_info=None): return f'{op_type} {attribute}, %{operand1}, %{operand2} : {shape}', [tile_size, "i1"] @staticmethod - def gt(operand1, operand2, *args, var_info=None): + def gt(operand1, operand2, *args, var_info=None, **kwargs): tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2, var_info) if ret_type[0] == "f": op_type = "arith.cmpf" @@ -413,7 +619,7 @@ def gt(operand1, operand2, *args, var_info=None): return f'{op_type} {attribute}, %{operand1}, %{operand2} : {shape}', [tile_size, "i1"] @staticmethod - def le(operand1, operand2, *args, var_info=None): + def le(operand1, operand2, *args, var_info=None, **kwargs): tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2, var_info) if ret_type[0] == "f": op_type = "arith.cmpf" @@ -428,7 +634,7 @@ def le(operand1, operand2, *args, var_info=None): return f'{op_type} {attribute}, %{operand1}, %{operand2} : {shape}', [tile_size, "i1"] @staticmethod - def ge(operand1, operand2, *args, var_info=None): + def ge(operand1, operand2, *args, var_info=None, **kwargs): tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2, var_info) if ret_type[0] == "f": op_type = "arith.cmpf" @@ -443,7 +649,7 @@ def ge(operand1, operand2, *args, var_info=None): return f'{op_type} {attribute}, %{operand1}, %{operand2} : {shape}', [tile_size, "i1"] @staticmethod - def and_(operand1, operand2, *args, var_info=None): + def and_(operand1, operand2, *args, var_info=None, **kwargs): op_type1 = var_info[operand1] op_type2 = var_info[operand2] @@ -464,7 +670,7 @@ def and_(operand1, operand2, *args, var_info=None): return f'arith.andi %{operand1}, %{operand2} : {shape}', [tile_size, ret_type] @staticmethod - def or_(operand1, operand2, *args, var_info=None): + def or_(operand1, operand2, *args, var_info=None, **kwargs): op_type1 = var_info[operand1] op_type2 = var_info[operand2] @@ -485,7 +691,7 @@ def or_(operand1, operand2, *args, var_info=None): return f'arith.ori %{operand1}, %{operand2} : {shape}', [tile_size, ret_type] @staticmethod - def xor(operand1, operand2, *args, var_info=None): + def xor(operand1, operand2, *args, var_info=None, **kwargs): op_type1 = var_info[operand1] op_type2 = var_info[operand2] @@ -507,30 +713,50 @@ def xor(operand1, operand2, *args, var_info=None): @staticmethod - def logical_and(operand, *args, var_info=None): - raise NotImplementedError("logical_and") + def logical_and(operand1, operand2, *args, var_info=None, **kwargs): + op_type = var_info[operand1] + # Type check & auto cast + if op_type[1] != "i1": + raise NotImplementedError("Logical operation with not bool data type") + return ExtensionOverrides.and_(operand1, operand2, *args, var_info=var_info, **kwargs) @staticmethod - def logical_not(operand, *args, var_info=None): - raise NotImplementedError("logical_not") + def logical_not(operand, *args, var_info=None, **kwargs): + op_type = var_info[operand] + + ret_type = op_type[1] + tile_size = op_type[0] + shape = f"vector<{tile_size}x{ret_type}>" if tile_size > 1 else ret_type + const_one = ops.constant(0, ret_type) + const_one = ops.broadcast(const_one, operand, var_info=var_info) + ret = ops.eq(operand,const_one) + return ret, [tile_size, var_info[ret]] @staticmethod - def logical_or(operand, *args, var_info=None): - raise NotImplementedError("logical_not") + def logical_or(operand1, operand2, *args, var_info=None, **kwargs): + op_type = var_info[operand1] + # Type check & auto cast + if op_type[1] != "i1": + raise NotImplementedError("Logical operation with not bool data type") + return ExtensionOverrides.or_(operand1, operand2, *args, var_info=var_info, **kwargs) @staticmethod - def logical_xor(operand, *args, var_info=None): - raise NotImplementedError("logical_not") + def logical_xor(operand1, operand2, *args, var_info=None, **kwargs): + op_type = var_info[operand1] + # Type check & auto cast + if op_type[1] != "i1": + raise NotImplementedError("Logical operation with not bool data type") + return ExtensionOverrides.xor(operand1, operand2, *args, var_info=var_info, **kwargs) @staticmethod - def relu(operand, *args, var_info=None): + def relu(operand, *args, var_info=None, **kwargs): op_type = var_info[operand] tile_size = op_type[0] ret_type = "f32" return ops.maximum(operand, ops.constant(0.0, "f32")), [tile_size, ret_type] @staticmethod - def sigmoid(operand, *args, var_info=None): + def sigmoid(operand, *args, var_info=None, **kwargs): op_type = var_info[operand] tile_size = op_type[0] ret_type = "f32" @@ -539,7 +765,7 @@ def sigmoid(operand, *args, var_info=None): # Special operaitons @staticmethod - def where(condition, operand1, operand2, *args, var_info=None): + def where(condition, operand1, operand2, *args, var_info=None, **kwargs): tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2, var_info) cond_type = var_info[condition] if cond_type[0] < tile_size: @@ -555,31 +781,12 @@ def where(condition, operand1, operand2, *args, var_info=None): @staticmethod - def masked(mask, body, other, *args, var_info=None, tile_size=16, dtype="f32", ninf_declared=False): + def masked(mask, body, other, *args, var_info=None, tile_size=16, dtype="f32", ninf_declared=False, **kwargs): result = body() - val = ops.constant(0.0, "f32") + val = ops.constant(other, dtype, *args, **kwargs) result = ops.where(mask, result, val) return result, var_info[result] - @staticmethod - def _index_expr(operand, *args, var_info=None, **kwargs): - symbols = sorted([str(i) for i in operand.free_symbols]) - renamed_symbols = {symbol: sympy.Symbol(f"d{i}") for i, symbol in enumerate(symbols)} - - renamed_expression = operand.subs(renamed_symbols) - - affine_map_str = "(" + ", ".join([f"d{i}" for i in range(len(symbols))]) + ") -> (" - affine_map_str += sympy.printing.ccode(renamed_expression) + ")" - - map_operands = [f"%{str(symbol)}" for symbol in symbols] - return f"affine.apply affine_map<{affine_map_str}>({', '.join(map_operands)})", [1, "index"] - - @staticmethod - def index_expr(operand, *args, var_info=None, **kwargs): - result = ops._index_expr(operand) - ret_type = [1, "index"] - return result, ret_type - @staticmethod def index_cast(operand, target_type, *args, var_info=None, **kwrags): op_type = var_info[operand] @@ -587,14 +794,32 @@ def index_cast(operand, target_type, *args, var_info=None, **kwrags): des_shape = f"vector<{op_type[0]}x{target_type}>" if op_type[0] > 1 else target_type return f"arith.index_cast %{operand} : {src_shape} to {des_shape}", [op_type[0], target_type] + @staticmethod + def broadcast_unflat(operand1, operand2, *args, var_info=None, **kwargs): + op_type1 = var_info[operand1] + op_type2 = var_info[operand2] + src_shape = f"vector<{op_type1[0]}x{op_type1[1]}>"# if op_type1[0] > 1 else op_type1[1] + des_shape = f"vector<{op_type2[0]//op_type1[0]}x{op_type1[0]}x{op_type1[1]}>"# if op_type2[0] > 1 else op_type1[1] # Use tile size only + + expand = f"vector.broadcast %{operand1} : {src_shape} to {des_shape}" + return expand, [op_type2[0], op_type1[1]] @staticmethod - def broadcast(operand1, operand2, *args, var_info=None): + def broadcast(operand1, operand2, *args, var_info=None, **kwargs): op_type1 = var_info[operand1] op_type2 = var_info[operand2] src_shape = f"vector<{op_type1[0]}x{op_type1[1]}>" if op_type1[0] > 1 else op_type1[1] - des_shape = f"vector<{op_type2[0]}x{op_type1[1]}>" if op_type2[0] > 1 else op_type1[1] # Use tile size only - expand = f"vector.broadcast %{operand1} : {src_shape} to {des_shape}" + des_shape = f"vector<{op_type2[0]}x{op_type1[1]}>" # if op_type2[0] > 1 else op_type1[1] # Use tile size only + + # Special case for length 2 vector. We used this vector to avoid scalar operations... + if op_type1[0] != 1 and op_type2[0] % op_type1[0] == 0: + unflat_operand = ops.broadcast_unflat(operand1, operand2) + unflat_shape = f"vector<{op_type2[0]//op_type1[0]}x{op_type1[0]}x{op_type1[1]}>" + expand = f"vector.shape_cast %{unflat_operand} : {unflat_shape} to {des_shape}" + elif op_type1[0] == 1: + expand = f"vector.broadcast %{operand1} : {src_shape} to {des_shape}" + else: + raise NotImplementedError("Not supporting broadcast type...") return expand, [op_type2[0], op_type1[1]] RTYPE_TO_MLIR = { @@ -606,363 +831,268 @@ def broadcast(operand1, operand2, *args, var_info=None): "MVIN1": 2, "MVIN2": 1, "MVIN3": 14, - "MVOUT": 3, + "MVOUT1": 3, } -class MLIRTile(): - TILE_ROW_WISE = 0 - TILE_COL_WISE = 1 - TILE_PER_LANE_ROW_WISE = 2 - TILE_PER_LANE_COL_WISE = 3 - def __init__(self, n_row, n_col, vector_lane, used_vector_lane=None) -> None: - self.n_row = n_row - self.n_col = n_col - self.vector_lane = vector_lane - if used_vector_lane is None: - self.used_vector_lane = self.vector_lane - else: - self.used_vector_lane = used_vector_lane - self.tile_per_lane_layout = self.TILE_PER_LANE_ROW_WISE # How a given tile per lane is stored - self.tile_layout = self.TILE_ROW_WISE # How a given tile is stored per lane - self.vector_lane_axis = (self.n_col//self.used_vector_lane) > 0 #(0: Col major, 1: Row major) - - def get_tile_size(self): - return self.n_row * self.n_col - - def get_rows_per_lane(self): - if self.n_row % self.used_vector_lane != 0 and self.n_row > 1: - print(f"[Warning] n_row({self.n_row}) % vector_lane({self.used_vector_lane}) != 0") - return self.div_round_up(self.n_row, self.used_vector_lane) - - def get_cols_per_lane(self): - if self.n_col % self.used_vector_lane != 0 and self.n_col > 1: - print(f"[Warning] n_col({self.n_col}) % vector_lane({self.used_vector_lane}) != 0") - return self.div_round_up(self.n_col, self.used_vector_lane) - - def get_tile_size_per_lane(self): - if self.get_tile_size() % self.used_vector_lane != 0: - print(f"[Warning] n_col({self.n_col}) % vector_lane({self.used_vector_lane}) != 0") - return self.div_round_up(self.get_tile_size(), self.used_vector_lane) - - def get_tile_shape(self): - return f"{self.n_row}x{self.n_col}" - - def get_chunk_size(self): - if self.tile_layout == self.TILE_ROW_WISE: - chunk_size = self.get_tile_size_per_lane() - else: - chunk_size = self.get_cols_per_lane() - return chunk_size - - @staticmethod - def div_round_up(size, round_val): - return (size + round_val - 1) // round_val - class MLIRKernel(mlir_common.BaseMLIRKernel): overrides = ExtensionOverrides newvar_prefix = "%" - def __init__(self): - super().__init__(mlir_common.MLIRKernelArgs()) - self.kernel_group = None - self.call_ranges = None - self.ranges = None - self.itervars = None - self.reduction_depth = None + def __init__(self, kernel_group, reason=None): + super().__init__(kernel_group, reason=reason) + self.const_buffer = IndentedBuffer() + self.alloc_buffer = IndentedBuffer() + self.spad_buffer = IndentedBuffer() self.reduction_prefix = IndentedBuffer() self.reduction_suffix = IndentedBuffer() - self.body = IndentedBuffer() + self.applys = IndentedBuffer() + self.masks = IndentedBuffer() + self.dma_loads = IndentedBuffer() + self.dma_stores = IndentedBuffer() + self.indexed_buffer = IndentedBuffer() self.global_vars = IndentedBuffer() - self.global_vars_dict = dict() self.header = IndentedBuffer() self.gem5_header = IndentedBuffer() - self.reduction_vars = {} + self.header.writeline("#include ") + self.header.writeline("#include ") + self.header.writeline("void* __wrap_malloc(size_t size) { return sbrk(size); }") + self.header.writeline("void __wrap_free(void *ptr) { return; }") self.reduction_cse = common.CSE(self.newvar_prefix, self.suffix, name_prefix="tmp_acc") + self.spad_cse = common.CSE(self.newvar_prefix, self.suffix, name_prefix="spad") + self.apply_cse = common.CSE(self.newvar_prefix, self.suffix, name_prefix="apply") + self.mask_cse = common.CSE(self.newvar_prefix, self.suffix, name_prefix="mask") self.iterator_cse = common.CSE(self.newvar_prefix, self.suffix, name_prefix="iter") self.init_cse = common.CSE(self.newvar_prefix, self.suffix, name_prefix="init") self.init_vec_cse = common.CSE(self.newvar_prefix, self.suffix, name_prefix="init_vec") + self.const_cse = common.CSE(self.newvar_prefix, self.suffix, name_prefix="const") + self.alloc_cse = common.CSE(self.newvar_prefix, self.suffix, name_prefix="alloc") + self.indexed_cse = common.CSE(self.newvar_prefix, self.suffix, name_prefix="indexed_op") self.map_cse = common.CSE("#", self.suffix, name_prefix="map") - self.consts = set() - self.tags = set() - self.tile_desc = MLIRTile(self.tile_row, self.tile_col, self.vector_lane) - self.dma_cache = {} - self.dma_counter = 1 - self.reduction_idx = {} + self.global_vars_dict = dict() + self.reduction_vars = dict() + self.consts = dict() + self.tags = dict() + self.dma_read_cache = dict() + self.dma_write_cache = dict() + self.spadbuf_counter = 0 + self.dma_read_counter = 1 + self.dma_write_counter = 1 + self.dma_tag_id = 0 self.affine_yield = {} self.welford_reduce_out = None self.reduce_iterator = {} - self.is_template_kernel = False + self.spad_buffer_dict = dict() + self.base_vector_initialized = False + + def reset(self, reason): + self.__init__(self.kernel_group, reason=reason) + + # padding type 0: zero-padding 1: negative-padding(-inf) ... + def get_padding_type(self): + ops = self.current_node.node.origins + if self.current_node.is_reduction(): + for op in ops: + if "exp" in op.name: # exponential reduciton case + return 1 + # for op in ops: # TODO: padding has some problem in the case of max_pool + # if "max_pool" in op.args[0].name: + # return 1 + return 0 + + def convert_index(self, expr, buffer): + if len(expr.free_symbols) != 1: + raise NotImplementedError("Not supporting this view operation...!") - def get_constant_vector(self, expr): - constant_vector = [[int(expr.coeff(var)),None] for var in self.itervars] - return constant_vector - - def get_constant_vector2(self, expr): - # Case 0. symbol ex) index 0 - # Case 1. inner product form ex) 16 * index0 + 1 * index1 - # Case 2. Complicated form ex) 16 * index0 + 8 * (index//4) + (index % 4) - constant_vector = [] if expr.is_symbol: - constant_vector.append(tuple([1, expr])) - return constant_vector + return expr - for arg in expr.args: - if arg.is_symbol: - constant_vector.append(tuple([1,arg])) - continue - if len(arg.args) == 0: #TODO: check this - continue - if arg.args[0].is_number: - constant_vector.append(arg.args) - else: - constant_vector.append([1, arg]) + expr_str = str(expr) + if isinstance(expr, ModularIndexing): + replace_str = f"({expr.args[0]} floordiv {expr.args[1]}) mod {expr.args[2]}" + expr_str = re.sub(r"ModularIndexing\([^)]*\)", replace_str, expr_str) + elif "//" in expr_str: + expr_str = expr_str.replace("//", " floordiv ") + else: + raise NotImplementedError("What is this case?") - return constant_vector + indices = [expr.args[0]] + args = ", ".join(map(str, indices)) + map_var = self.map_cse.generate(self.global_vars, f"affine_map<({args}) -> ({expr_str})>") + args = ", ".join([f"%{i}" for i in indices]) + index = self.apply_cse.generate(buffer, f"affine.apply #{map_var}({args})") + return index - def find_node_by_name(self, name): - if name in V.graph.graph_inputs: - return V.graph.graph_inputs[name] - else: - for output_node in V.graph.graph_outputs: - if output_node.data.name == name: - return output_node - - def get_dma_info(self, name, index, dtype): - current_tile = MLIRTile(self.tile_desc.n_row, self.tile_desc.n_col, self.tile_desc.vector_lane, self.tile_desc.used_vector_lane) - cv = self.get_constant_vector(index) - cv2 = self.get_constant_vector2(index) - tile_size_per_lane = self.tile_desc.get_tile_size_per_lane() # FIXME. move this - tile_size_per_lane = 2 if tile_size_per_lane==1 else tile_size_per_lane # Avoid scalar operation - - if len(cv) != len(cv2) and len(cv2) == 3: - print("Mismatch! ", cv) - # FIXME. this is really shitty code :( - cv = cv2#[[1 if x[0] == 0 else x[0], x[1]] for x in cv] + def parse_indices(self, expr, buffer=None, comments="", indirect_dims=[]) -> common.CSEVariable: + if buffer is None: + buffer = self.applys - # Case 0. Tile is 0-D scalar - if len(cv) == 0: - # Use only one vectorlane to handle scalar data - current_tile.n_row = 1 - current_tile.n_col = 1 - current_tile.tile_layout = MLIRTile.TILE_ROW_WISE - current_tile.tile_per_lane_layout = MLIRTile.TILE_PER_LANE_ROW_WISE - mm_stride, tile_size_per_lane = 1, 1 - chunk_size = current_tile.get_chunk_size() - # Case 1. Tile is 1-D vector type - elif len(cv) == 1 and len(cv) <= self.reduction_depth: - current_tile.n_row = 1 - current_tile.n_col = self.tile_desc.get_tile_size() - current_tile.tile_layout = MLIRTile.TILE_ROW_WISE - current_tile.tile_per_lane_layout = MLIRTile.TILE_PER_LANE_COL_WISE # Actually it is not needed in vector case - chunk_size = current_tile.get_chunk_size() - mm_stride = current_tile.n_col - # Case 2. Tile is 1-D vector type with reduction - elif len(cv) == 1 and len(cv) == self.reduction_depth + 1: - # Use only one vectorlane to reduce a vector - current_tile.tile_layout = MLIRTile.TILE_ROW_WISE - current_tile.tile_per_lane_layout = MLIRTile.TILE_PER_LANE_ROW_WISE - current_tile.n_row = 1 - current_tile.n_col = self.tile_desc.get_tile_size() - current_tile.used_vector_lane = 1 - chunk_size = current_tile.get_chunk_size() - mm_stride = 0 # don't care - # Case 3. Tile is 2-D tile - elif len(cv) == 2: - is_reduction = self.reduction_depth == 1 - if cv[0][0] != 0 and cv[1][0] != 0: - is_transposed = cv[0][0] < cv[1][0] - if is_transposed: - current_tile.n_row = self.tile_desc.n_col - current_tile.n_col = self.tile_desc.n_row - mm_stride = self.ranges[0] - else: - current_tile.n_row = self.tile_desc.n_row - current_tile.n_col = self.tile_desc.n_col - mm_stride = self.ranges[1] - - if is_reduction and is_transposed: - current_tile.tile_layout = MLIRTile.TILE_COL_WISE - current_tile.tile_per_lane_layout = MLIRTile.TILE_PER_LANE_ROW_WISE - chunk_size = current_tile.get_chunk_size() - elif is_reduction and not is_transposed: - current_tile.tile_layout = MLIRTile.TILE_ROW_WISE - current_tile.tile_per_lane_layout = MLIRTile.TILE_PER_LANE_COL_WISE - chunk_size = current_tile.get_chunk_size() - elif not is_reduction and is_transposed: - # Transposed case - current_tile.tile_layout = MLIRTile.TILE_COL_WISE - current_tile.tile_per_lane_layout = MLIRTile.TILE_PER_LANE_COL_WISE - chunk_size = current_tile.get_chunk_size() - else: # not is_reduction and not is_transpose - current_tile.tile_layout = MLIRTile.TILE_COL_WISE if self.tile_desc.vector_lane_axis else MLIRTile.TILE_ROW_WISE - current_tile.tile_per_lane_layout = MLIRTile.TILE_PER_LANE_ROW_WISE - chunk_size = current_tile.get_chunk_size() - else: - # Broadcast pattern - current_tile.tile_per_lane_layout = MLIRTile.TILE_PER_LANE_ROW_WISE - mm_stride = 0 - if cv[0][0] == 0: - current_tile.tile_layout = MLIRTile.TILE_COL_WISE if self.tile_desc.vector_lane_axis else MLIRTile.TILE_ROW_WISE - current_tile.n_row = self.tile_desc.n_row - current_tile.n_col = self.tile_desc.n_col - chunk_size = current_tile.get_chunk_size() - else: # cv[1][0] == 0 - current_tile.n_row = self.tile_desc.n_col - current_tile.n_col = self.tile_desc.n_row - chunk_size = current_tile.get_cols_per_lane() - if not is_reduction: - current_tile.tile_per_lane_layout = MLIRTile.TILE_PER_LANE_COL_WISE - chunk_size = current_tile.n_col if self.tile_desc.vector_lane_axis else chunk_size - elif len(cv) == 3: - current_tile.tile_per_lane_layout = MLIRTile.TILE_PER_LANE_COL_WISE # Actually it is not needed in vector case - mm_stride = cv[-1][0] - # When current_tile.n_col stride is 1, we can access row vector - if mm_stride == 1: - current_tile.n_row = 1 - current_tile.n_col = self.tile_desc.get_tile_size() - # if current_tile.n_col stride is not 1, we have to access in a column vector - else: - current_tile.n_row = self.tile_desc.get_tile_size() - current_tile.n_col = 1 - chunk_size = current_tile.get_tile_size_per_lane() - else: - raise NotImplementedError() + # Constant case + if expr.is_number and len(indirect_dims) == 0: + return self.get_const_cse(int(expr)) - #assert(not (dtype==torch.bool and chunk_size < 8)) - chunk = chunk_size << 1 | (current_tile.tile_per_lane_layout == MLIRTile.TILE_PER_LANE_COL_WISE) - return mm_stride, chunk, [current_tile.n_row, current_tile.n_col], tile_size_per_lane + # Identity case + if len(expr.args) == 0 and len(indirect_dims) == 0: + return expr - def parse_indices(self, expr): if len(expr.args) == 0: - return expr + args = [expr] + else: + args = list(expr.args) + # Sort index variable.. ex) (%index1, %index0) + args_dict = {term: list(term.free_symbols)[0] for term in args if term.free_symbols} + sorted_args = sorted(args_dict.keys(), key=lambda term: str(args_dict[term])) + indices = [] + for arg in sorted_args: + if arg.is_Mul and arg.args[0].is_number: + new_arg = sympy.Symbol(str(self.convert_index(arg.args[1], buffer))) + expr = expr.replace(arg.args[1], new_arg) + indices.append(str(new_arg)) + elif not arg.is_number: + new_arg = sympy.Symbol(str(self.convert_index(arg, buffer))) + expr = expr.replace(arg, new_arg) + indices.append(str(new_arg)) # Extract index var + indirect_args = [f"%{i}" for i in indirect_dims] expr_str = str(expr) - pattern = r'index\d+' - indices = OrderedDict() - for index in re.findall(pattern, expr_str): - indices[index] = None - indices = list(indices.keys()) - args = ", ".join(map(str, indices)) - if "//" in expr_str: - expr_str = expr_str.replace("//", " floordiv ") - pattern = r"ModularIndexing\((.*?)\)" - matches = re.search(pattern, expr_str) - if matches: - mod_args = matches.group(1) - args_list = mod_args.split(", ") - replace_str = f"({args_list[0]} floordiv {args_list[1]}) mod {args_list[2]}" - expr_str = re.sub(r"ModularIndexing\([^)]*\)", replace_str, expr_str) - - map_var = self.map_cse.generate(self.global_vars, f"affine_map<({args}) -> ({expr_str})>") + map_var = self.map_cse.generate(self.global_vars, f"affine_map<({args})[{','.join(indirect_dims)}] -> ({expr_str})>") args = ", ".join([f"%{i}" for i in indices]) - index = self.cse.generate(self.loads, f"affine.apply #{map_var}({args})") + index = self.apply_cse.generate(buffer, f"affine.apply #{map_var}({args})[{','.join(indirect_args)}] {comments}") return index - def codegen_nodes(self, nodes, kernel_name): - _, (group, reduction_group) = max( - nodes, key=lambda x: int(x.is_reduction()) - ).group - - self.set_ranges(group, reduction_group, None) - with self as kernel: - kernel.args = kernel.kernel_group.args - for node in nodes: - vars, reduction_vars = kernel.set_ranges(group, reduction_group, node.read_writes) - kernel.args.tile_row = kernel.tile_desc.n_row - kernel.args.tile_col = kernel.tile_desc.n_col - _, _, _, kernel.buffer_types = kernel.args.mlir_argdefs() - kernel.reduction_idx = {var: i for i, var in enumerate(reduction_vars)} - node.run(vars, reduction_vars) - src_code = self.codegen_kernel(kernel_name=kernel_name) - self.meta_kernel() + def parse_index_list(self, expr_list:list, buffer=None) -> common.CSEVariable: + if buffer is None: + buffer = self.applys + zero_var = self.get_const_cse(0) + expr_list = [arg for arg in expr_list] + dim_list = [f"d{i}" for i in range(len(expr_list))] + + if len(expr_list) == 1 and expr_list[0].is_number: + # Constant case + return self.get_const_cse(int(expr_list[0])) + elif len(expr_list) == 1 and expr_list[0].is_symbol: + # Identity case + return expr_list[0] + + indices = [] + new_expr_list = [0] * len(expr_list) + for idx, arg in enumerate(expr_list): + if arg.is_Mul and arg.args[0].is_number: + new_arg = sympy.Symbol(str(self.convert_index(arg.args[1], buffer))) + new_expr_list[idx] = arg.subs(arg.args[1], dim_list[idx]) + indices.append(str(new_arg)) + elif not arg.is_number: + new_arg = sympy.Symbol(str(self.convert_index(arg, buffer))) + new_expr_list[idx] = new_arg.subs(new_arg, dim_list[idx]) + indices.append(str(new_arg)) + else: + const_var = self.get_const_cse(int(arg)) + new_arg = sympy.Symbol(f"{const_var}") + new_expr_list[idx] = arg + indices.append(str(new_arg)) - write_path = extension_codecache.get_write_path(src_code) - if not os.path.exists(write_path): - os.makedirs(write_path) - spike_write_path = os.path.join(write_path, "global_var.h") - gem5_write_path = os.path.join(write_path, "gem5_global_var.h") - if not os.path.exists(spike_write_path): - write_atomic(spike_write_path, self.header.getvalue()) - if not os.path.exists(gem5_write_path): - write_atomic(gem5_write_path, self.gem5_header.getvalue()) - return src_code + # Extract index var + expr_str = str(sum(new_expr_list)) + args = ", ".join(map(str, dim_list)) + map_var = self.map_cse.generate(self.global_vars, f"affine_map<({args})[] -> ({expr_str})>") + args = ", ".join([f"%{i}" for i in indices]) + index = self.apply_cse.generate(buffer, f"affine.apply #{map_var}({args})[]") + return index def load(self, name: str, index: sympy.Expr): index = self.rename_indexing(index) - indices = self.parse_indices(index) - prefix = self.newvar_prefix - if index.is_number: - prefix = prefix + "c" - self.consts.add(int(index)) - var = self.args.input(name) + index = self.convert_indirect_indexing(index) + padding = self.get_padding_type() + + # Extract dram info + dram_var = self.kernel_group.args.input(name) + dram_shape = mlir_common.MLIRKernelArgs.get_mlir_shape(self.buffer_types[name]) dtype = V.graph.get_dtype(name) - type_name = mlir_common.DTYPE_TO_MLIR[dtype] - stride, chunk, tile_shape, tile_size_per_lane = self.get_dma_info(name, index, dtype) - dram_tile_shape = f"{tile_shape[0]}x{tile_shape[1]}" + mlir_dtype = mlir_common.DTYPE_TO_MLIR[dtype] + + # Extract sram info + local_tile_desc, index_var, dram_stride = self.get_dma_info(name, index) + vlane_split_axis = local_tile_desc.vlane_split_axis + vlane_stride = local_tile_desc.vlane_stride + tile_numel_per_lane = local_tile_desc.get_numel_per_lane() + tile_shape = local_tile_desc.get_mlir_shape(mlir_dtype) + tile_stride = local_tile_desc.get_tile_stride() + + # Compute vector unit size + vshape = self.kernel_group.tile_desc.get_mlir_vshape(mlir_dtype) + compute_vec_size = self.kernel_group.tile_desc.get_compute_vec_size() # Define scratch pad buffer - buffer, indices = self.get_scratchpad_buffer(dtype, name, self.tile_desc.n_row, self.tile_desc.n_col, dram_tile_shape, self.loads, indices, index) + sram_var, sram_index_var = self.get_scratchpad_buffer(dtype, name, local_tile_desc, index) + # MVIN Encoding - dma_key = (stride, chunk, dtype) - if dma_key in self.dma_cache: - dmaType, stride, chunk = self.dma_cache[dma_key] + attribute = f"{{dram_stride={dram_stride}, sram_stride={tile_stride}, padding={padding}}}" + code = self.get_dma_code("MVIN", vlane_split_axis, vlane_stride, mlir_dtype, dram_var, index_var, sram_var, sram_index_var, + dram_shape, tile_shape, attribute) + self.cse.generate(self.dma_loads, code, assignment = False) # FIXME: assignment = False does not support caching + compute_index_var = ",".join(sram_index_var.split(",")[:-1] + [f"%{self.compute_idx}"]) + # Generate vector load instruction + if compute_vec_size > 1: + operation = "affine.vector_load" + line = f"{operation} %{sram_var}[{compute_index_var}] : {tile_shape}, {vshape}" else: - assert(self.dma_counter < 4) - dmaType = DMA_TYPE[f"MVIN{self.dma_counter}"] - self.dma_counter += 1 - self.consts.add(dmaType) - self.consts.add(stride) - self.consts.add(chunk) - self.dma_cache[dma_key] = dmaType, stride, chunk - self.tags.add(f"{name}_tag") - self.consts.add(0) - code = f"affine.dma_start %{var}[{prefix}{indices}], %{buffer}[%c0, %c0], %{name}_tag[0], %c{dmaType}, %c{stride}, %c{chunk} : memref<{self.buffer_types[name][1]}x{type_name}>, memref<{dram_tile_shape}x{type_name}, 1>, memref<1xi32>" - self.cse.generate(self.loads, code, assignment = False) # FIXME: assignment = False does not support caching - - operation = "affine.vector_load" if tile_size_per_lane > 1 else "affine.load" - shape = f", vector<{tile_size_per_lane}x{type_name}>" if tile_size_per_lane > 1 else "" - line = f"{operation} %{buffer}[0, 0] : memref<{dram_tile_shape}x{type_name}, 1>{shape}" + operation = "affine.load" + line = f"{operation} %{sram_var}[{compute_index_var}] : {tile_shape}" out = self.cse.generate(self.loads, line) - var_info = [tile_size_per_lane, mlir_common.DTYPE_TO_MLIR[dtype]] - self.register_var_info(out, var_info) + self.register_var_info(out, [compute_vec_size, mlir_dtype]) + self.spad_buffer_dict[str(out)] = [sram_var, local_tile_desc.get_tile_size(), tile_numel_per_lane, sram_index_var, tile_shape, vshape] return out def store(self, name: str, index: sympy.Expr, value, *args, **kwargs): index = self.rename_indexing(index) - indices = self.parse_indices(index) - prefix = self.newvar_prefix - if index.is_number: - prefix = prefix + "c" - self.consts.add(int(index)) - var = self.args.output(name) + dram_var = self.kernel_group.args.output(name) dtype = V.graph.get_dtype(name) - type_name = mlir_common.DTYPE_TO_MLIR[dtype] - stride, chunk, tile_shape, tile_size_per_lane = self.get_dma_info(name, index, dtype) - dram_tile_shape = f"{tile_shape[0]}x{tile_shape[1]}" - - # Define scratch pad buffer - buffer, indices = self.get_scratchpad_buffer(dtype, name, self.tile_desc.n_row, self.tile_desc.n_col, dram_tile_shape, self.stores, indices, index) + mlir_dtype = mlir_common.DTYPE_TO_MLIR[dtype] + + # Prepare dma instruction + local_tile_desc, index_var, dram_stride = self.get_dma_info(name, index) + vlane_split_axis = local_tile_desc.vlane_split_axis + vlane_stride = local_tile_desc.vlane_stride + + dram_shape = mlir_common.MLIRKernelArgs.get_mlir_shape(self.buffer_types[name]) + tile_shape = local_tile_desc.get_mlir_shape(mlir_dtype) + tile_stride = local_tile_desc.get_tile_stride() + tile_size = local_tile_desc.get_tile_size() + # Compute vector unit size + vshape = self.kernel_group.tile_desc.get_mlir_vshape(mlir_dtype) + compute_vec_size = self.kernel_group.tile_desc.get_compute_vec_size() + require_store = True + + if str(value) in self.spad_buffer_dict: + # Todo. If tile_size is not same (i.e., view operation), we can't apply peephole optimization easily + require_store = self.spad_buffer_dict[str(value)][1] != tile_size + + if require_store: + # Define scratch pad buffer + sram_var, sram_index_var = self.get_scratchpad_buffer(dtype, name, local_tile_desc, index) + compute_index_var = ",".join(sram_index_var.split(",")[:-1] + [f"%{self.compute_idx}"]) + # Generate vector store instruction + store_size, operand_type = self.var_info[value] + if mlir_dtype != operand_type: + value = ops.custom_cast(value, mlir_dtype, var_info=self.var_info) + + if compute_vec_size > 1 and store_size > 1: + operation = "affine.vector_store" + line = f"{operation} %{value}, %{sram_var}[{compute_index_var}] : {tile_shape}, {vshape}" + else: + operation = "affine.store" + line = f"{operation} %{value}, %{sram_var}[{compute_index_var}] : {tile_shape}" + self.stores.writeline(common.DeferredLine(name, line)) # TODO: Should be changed to self.compute? + else: + sram_var = self.spad_buffer_dict[str(value)][0] + sram_index_var = self.spad_buffer_dict[str(value)][3] - # MVOUT Encoding - dmaType = 3 # MVIN 2, MVIN2 1, MVIN3 14, MVOUT 3 - self.consts.add(dmaType) - self.consts.add(stride) - self.consts.add(chunk) - - store_size, operand_type = self.var_info[value] - operation = "affine.vector_store" if tile_size_per_lane > 1 and store_size > 1 else "affine.store" - shape = f", vector<{tile_size_per_lane}x{type_name}>" if tile_size_per_lane > 1 and store_size > 1 else "" - if type_name != operand_type: - value = ops.custom_cast(value, type_name, var_info=self.var_info) - - line = f"{operation} %{value}, %{buffer}[0, 0] : memref<{dram_tile_shape}x{type_name}, 1>{shape}" - self.cse.generate(self.stores, line, assignment = False) - self.consts.add(0) - self.tags.add(f"{name}_tag") - code = f"affine.dma_start %{buffer}[%c0, %c0], %{var}[{prefix}{indices}], %{name}_tag[0], %c{dmaType}, %c{stride}, %c{chunk} : memref<{dram_tile_shape}x{type_name}, 1>, memref<{self.buffer_types[name][1]}x{type_name}>, memref<1xi32>" - self.cse.generate(self.stores, code, assignment = False) + # Generate DMA instruction + attribute = f"{{dram_stride={dram_stride}, sram_stride={tile_stride}, padding=0}}" + code = self.get_dma_code("MVOUT", vlane_split_axis, vlane_stride, mlir_dtype, dram_var, index_var, sram_var, sram_index_var, + dram_shape, tile_shape, attribute) + self.dma_stores.writeline(common.DeferredLine(name, code)) def reduction(self, dtype, src_dtype, reduction_type, value): argmax_or_argmin = reduction_type in {"argmax", "argmin"} @@ -979,537 +1109,727 @@ def reduction(self, dtype, src_dtype, reduction_type, value): sqr_sum = self.reduction(dtype, src_dtype, "sum", ops.mul(value, value)) self.welford_reduce_out = (sum, sqr_sum, None) return sum, sqr_sum, None + + # Prepare reduction loop + reduction_key = src_dtype, reduction_type, value + acc = self.reduction_cse.generate( + self.loads, f"reduction {reduction_key}", write=False + ) + iterator = self.iterator_cse.generate( + self.loads, f"reduction {reduction_key}", write=False + ) + init = self.init_cse.generate( + self.loads, f"reduction {reduction_key}", write=False + ) + init_vec = self.init_vec_cse.generate( + self.loads, f"reduction {reduction_key}", write=False + ) + type_name = mlir_common.DTYPE_TO_MLIR[dtype] + init = self.const_cse.generate(self.const_buffer, f"arith.constant {reduction_init(reduction_type, dtype)} : {type_name}") + vec_len = self.kernel_group.tile_desc.get_compute_vec_size() + reduced_shape = self.kernel_group.tile_desc.get_mlir_vshape(type_name) + + # Set accumulation var + if vec_len == 1: # 1-D vector to scalar + # Edge case for scalar + init_vec = init else: - reduction_key = src_dtype, reduction_type, value - acc = self.reduction_cse.generate( - self.loads, f"reduction {reduction_key}", write=False - ) - iterator = self.iterator_cse.generate( - self.loads, f"reduction {reduction_key}", write=False - ) - init = self.init_cse.generate( - self.loads, f"reduction {reduction_key}", write=False - ) - init_vec = self.init_vec_cse.generate( - self.loads, f"reduction {reduction_key}", write=False - ) - type_name = mlir_common.DTYPE_TO_MLIR[dtype] - acc_var = init - acc_shape = type_name - shape = f"vector<{self.tile_desc.get_tile_size()}x{type_name}>" - reduced_shape = type_name - init = self.cse.generate(self.reduction_prefix, f"arith.constant {reduction_init(reduction_type, dtype)} : {type_name}") - if len(self.ranges) == 1: - axis = "0" - acc_var = init - shape = f"vector<{self.tile_desc.get_tile_size_per_lane()}x{type_name}>" - elif len(self.ranges) == 2: - vec_len = self.tile_desc.get_rows_per_lane() - flattened_size = f"vector<{self.tile_desc.get_tile_size_per_lane()}x{type_name}>" - - # It is column majored per lane tile - expaned_size = f"vector<{self.tile_desc.get_tile_size_per_lane()//vec_len}x{vec_len}x{type_name}>" - value = self.cse.generate(self.compute, f"vector.shape_cast %{value} : {flattened_size} to {expaned_size}") - shape = expaned_size - - # Edge case for scalar - if vec_len == 1: - reduced_shape = f"{type_name}" - init_vec = init - axis = "0, 1" - acc_var = init - var_info = [1, mlir_common.DTYPE_TO_MLIR[dtype]] - else: - reduced_shape = f"vector<{vec_len}x{type_name}>" - init_vec = self.cse.generate(self.reduction_prefix, f"vector.broadcast %{init} : {type_name} to {reduced_shape}") - axis = "0" - acc_var = init_vec - var_info = [vec_len, mlir_common.DTYPE_TO_MLIR[dtype]] - self.register_var_info(acc, var_info) + # Adjust shape and inital value + init_vec = self.const_cse.generate(self.const_buffer, f"vector.broadcast %{init} : {type_name} to {reduced_shape}") + self.register_var_info(init_vec, [vec_len, type_name]) + acc_var = init_vec + + # Reduction body prepare + body_acc = self.reduction_cse.generate( + self.compute, f"reduction {reduction_key}body_acc", write=False + ) + body_iter_arg = self.iterator_cse.generate( + self.compute, f"reduction {reduction_key}body_iter_arg", write=False + ) + self.register_var_info(body_iter_arg, [vec_len, type_name]) + + self.reduction_vars[acc] = (reduction_type, iterator, acc_var, reduced_shape) + self.affine_yield[body_acc] = reduced_shape + self.reduction_cse.reduction_cache[reduction_key] = acc + self.iterator_cse.reduction_cache[reduction_key] = iterator + self.init_cse.reduction_cache[reduction_key] = init_vec + + # Reduction body codegen + mask_shape, mask_var = self.get_mask() + if mask_var is not None: + value = ops.where(mask_var, value, init_vec) + result = reduction_partial_combine_vec(reduction_type, value, body_iter_arg) + self.compute_body_loop.reduction_vars[body_acc] = (reduction_type, body_iter_arg, iterator, reduced_shape) + self.compute_body_loop.affine_yield[result] = reduced_shape + + # Final reduction + reduction_size = self.kernel_group.tile_desc.get_numel_per_lane() // self.kernel_group.tile_desc.get_tile_size()[-1] + assert(vec_len % reduction_size==0) + if vec_len > reduction_size: + init = self.const_cse.generate(self.reductions_suffix, f"arith.constant {reduction_init(reduction_type, dtype)} : {type_name}") + if reduction_size == 1: + final_reduced_shape = f"{type_name}" + out = self.cse.generate(self.reductions_suffix, reduction_combine_vec(reduction_type, acc, init, axis=0, shape=reduced_shape, reduced_shape=final_reduced_shape)) else: - raise NotImplementedError() - - self.reduction_vars[acc] = (reduction_type, iterator, acc_var, reduced_shape) - out = self.cse.generate(self.compute, reduction_combine_vec(reduction_type, value, iterator, axis=axis, shape=shape, reduced_shape=reduced_shape)) - self.affine_yield[out] = reduced_shape - - self.reduction_cse.reduction_cache[reduction_key] = acc - self.iterator_cse.reduction_cache[reduction_key] = iterator - self.init_cse.reduction_cache[reduction_key] = init_vec + final_reduced_shape = f"vector<{reduction_size}x{type_name}>" + init_vec = self.cse.generate(self.reductions_suffix, f"vector.broadcast %{init} : {type_name} to {final_reduced_shape}") + new_vshape= f"vector<{vec_len//reduction_size}x{reduction_size}x{type_name}>" + value = self.cse.generate(self.reductions_suffix, f"vector.shape_cast %{acc} : {reduced_shape} to {new_vshape}") + out = self.cse.generate(self.reductions_suffix, reduction_combine_vec(reduction_type, value, init_vec, axis=0, shape=new_vshape, reduced_shape=final_reduced_shape)) + acc = out + + # reigster reduction output + var_info = [reduction_size, mlir_common.DTYPE_TO_MLIR[dtype]] + self.register_var_info(acc, var_info) return acc def store_reduction(self, name, index, value): - var = self.args.output(name) + # Note: Change cse temporaily + # Store reduction can't share cached value stored in cse, + # since it is not innermost loop body. + tmp_cse = self.cse + tmp_apply_cse = self.apply_cse + self.cse = self.reduction_cse + self.apply_cse = self.reduction_cse + + dram_var = self.kernel_group.args.output(name) dtype = V.graph.get_dtype(name) - type_name = mlir_common.DTYPE_TO_MLIR[dtype] + mlir_dtype = mlir_common.DTYPE_TO_MLIR[dtype] index = self.rename_indexing(index) - indices = self.parse_indices(index) - prefix = self.newvar_prefix - if index.is_number: - prefix = prefix + "c" - self.consts.add(int(index)) + # Tile is always reuduced in inner loop - tile_col = self.tile_desc.n_row - tile_row = 1 - dram_tile_shape = f"{tile_row}x{tile_col}" - buffer, indices = self.get_scratchpad_buffer(dtype, name, tile_row, tile_col, dram_tile_shape, self.reductions_suffix, indices, index) + local_tile_desc, index_var, dram_stride = self.get_dma_info(name, index, broadcast=False, store_reduction=True, buffer=self.reductions_suffix) + vlane_split_axis = local_tile_desc.vlane_split_axis + vlane_stride = local_tile_desc.vlane_stride + + dram_shape = mlir_common.MLIRKernelArgs.get_mlir_shape(self.buffer_types[name]) + tile_shape = local_tile_desc.get_mlir_shape(mlir_dtype) + tile_stride = local_tile_desc.get_tile_stride() + compute_vec_size = self.kernel_group.tile_desc.get_numel_per_lane() // self.kernel_group.tile_desc.get_tile_size()[-1] + if compute_vec_size == 1: + vshape = f"{mlir_dtype}" + else: + vshape = f"vector<{compute_vec_size}x{mlir_dtype}>" + sram_var, sram_index_var = self.get_scratchpad_buffer(dtype, name, local_tile_desc, index) if self.welford_reduce_out is not None: - # raise NotImplementedError() sum, sqr_sum, _ = self.welford_reduce_out - shape = f"vector<{self.tile_desc.get_rows_per_lane()}x{type_name}>" if self.buffer_types[name][1] > 1 else type_name # mean divider = self.cse.generate(self.reductions_suffix, f"arith.constant {float(self.ranges[self.reduction_depth])} : f32") if self.buffer_types[name][1] > 1: - divider_vec = self.cse.generate(self.reductions_suffix, f"vector.broadcast %{divider} : f32 to vector<{self.var_info[sum][0]}x{type_name}>") + divider_vec = self.cse.generate(self.reductions_suffix, f"vector.broadcast %{divider} : f32 to vector<{self.var_info[sum][0]}x{mlir_dtype}>") else: - divider_vec = f"f{self.buffer_types[name][1]}" - mean = self.cse.generate(self.reductions_suffix, f"arith.divf %{sum}, %{divider_vec} : {shape}") + divider_vec = divider + mean = self.cse.generate(self.reductions_suffix, f"arith.divf %{sum}, %{divider_vec} : {vshape}") # m2 = (E(X^2) - E(X)^2) * N - sqr_mean = self.cse.generate(self.reductions_suffix, f"arith.divf %{sqr_sum}, %{divider_vec} : {shape}") - mean_sqr = self.cse.generate(self.reductions_suffix, f"arith.mulf %{mean}, %{mean} : {shape}") - variance = self.cse.generate(self.reductions_suffix, f"arith.subf %{sqr_mean}, %{mean_sqr} : {shape}") - m2 = self.cse.generate(self.reductions_suffix, f"arith.mulf %{variance}, %{divider_vec} : {shape}") + sqr_mean = self.cse.generate(self.reductions_suffix, f"arith.divf %{sqr_sum}, %{divider_vec} : {vshape}") + mean_sqr = self.cse.generate(self.reductions_suffix, f"arith.mulf %{mean}, %{mean} : {vshape}") + variance = self.cse.generate(self.reductions_suffix, f"arith.subf %{sqr_mean}, %{mean_sqr} : {vshape}") + m2 = self.cse.generate(self.reductions_suffix, f"arith.mulf %{variance}, %{divider_vec} : {vshape}") if self.current_node.node.origin_node: # FIXME: This is a temporary solution value = mean else: value = m2 - # Select mlir store operaiton - if self.buffer_types[name][1] == 1 or self.tile_desc.get_rows_per_lane() == 1: + # Select src type + if compute_vec_size == 1: operation = "affine.store" - # raise NotImplementedError("Scalar store!") + line = f"{operation} %{value}, %{sram_var}[{sram_index_var}] : {tile_shape}" else: operation = "affine.vector_store" - - # Select src type - if self.tile_desc.get_rows_per_lane() == 1: - shape = "" - else: - shape = f"vector<{self.tile_desc.get_rows_per_lane()}x{type_name}>" - shape = f", {shape}" if self.buffer_types[name][1] > 1 else "" - line = f"{operation} %{value}, %{buffer}[0, 0] : memref<{tile_row}x{tile_col}x{type_name}, 1>{shape}" - self.cse.generate(self.reductions_suffix, line, assignment = False) + line = f"{operation} %{value}, %{sram_var}[{sram_index_var}] : {tile_shape}, {vshape}" + self.reductions_suffix.writeline(common.DeferredLine(name, line)) # MVOUT Encoding - dmaType = 3 # MVIN 2, MVIN2 1, MVIN3 14, MVOUT 3 - mm_stride = tile_col - is_col_major = MLIRTile.TILE_PER_LANE_ROW_WISE - chunk_size = self.tile_desc.get_rows_per_lane() - chunk = chunk_size << 1 | (is_col_major == MLIRTile.TILE_PER_LANE_COL_WISE) - self.consts.add(dmaType) - self.consts.add(mm_stride) - self.consts.add(chunk) - self.tags.add(f"{name}_tag") - # Change row, col - self.consts.add(0) - code = f"affine.dma_start %{buffer}[%c0, %c0], %{var}[{prefix}{indices}], %{name}_tag[0], %c{dmaType}, %c{mm_stride}, %c{chunk} : memref<{tile_row}x{tile_col}x{type_name}, 1>, memref<{self.buffer_types[name][1]}x{type_name}>, memref<1xi32>" - self.cse.generate(self.reductions_suffix, code, assignment = False) - - def codegen_body(self): - # if not ( - # self.loads - # or self.stores - # or self.compute - # ): - # return - def template_store(options): - subtile_size = [self.vector_lane, self.vector_lane] - async_flag = 1 - self.consts.add(0) - line = f"affine.dma_start %Y_buffer[%c0, %c0], %Y[%index2], %tag[0], %c_mvout, %N, %c_set"\ - f": memref<{options['TILE_M']}x{options['TILE_N']}xf32, 1>,"\ - f"memref<{options['M'] * options['N']}xf32>, memref<1xi32>" #FIXME: Using constant index - self.cse.generate(self.stores, line, assignment = False) - self.body.splice(self.codegen_init()) - self.body.splice(self.loads) - self.body.splice(self.compute) - if len(self.stores._lines) == 0: - template_store(self.render_options) - self.body.splice(self.stores) - self.loads.clear() - self.compute.clear() - self.stores.clear() - - def codegen_init(self): - code = IndentedBuffer() - tags = sorted(self.tags) - consts = sorted(self.consts) - for tag in tags: - code.writeline(f"%{tag} = memref.alloc() : memref<1xi32>") - for const in consts: - code.writeline(f"%c{const} = arith.constant {const} : index") - return code + # Generate DMA instruction + attribute = f"{{dram_stride={dram_stride}, sram_stride={tile_stride}, padding=0}}" + code = self.get_dma_code("MVOUT", vlane_split_axis, vlane_stride, mlir_dtype, dram_var, index_var, sram_var, sram_index_var, + dram_shape, tile_shape, attribute) + self.reductions_suffix.writeline(common.DeferredLine(name, code)) + + # Restore origin cse + self.cse = tmp_cse + self.apply_cse = tmp_apply_cse + + def indirect_indexing(self, index_var, size, check=True): + return str(index_var) + + def _index_expr(self, tile_size, renamed_expression, index, base_vector_index): + tile_desc = self.kernel_group.tile_desc + compute_vec_size = tile_desc.get_compute_vec_size() + + strides = [1] * len(tile_size) + for i in range(len(tile_size) - 2, -1, -1): + strides[i] = strides[i + 1] * tile_size[i + 1] + + # Create vector index + compute_vec = self.cse.generate(self.compute, f"vector.broadcast %{self.compute_idx} : index to vector<{compute_vec_size}xindex>") + self.register_var_info(compute_vec, [compute_vec_size, "index"]) + vector_index = ops.add(base_vector_index, compute_vec) + + # Create tile_dim index + dim_list = [] + for idx in range(len(tile_size)): + div_coeff = self.get_const_cse(strides[idx], "index") + mod_coeff = self.get_const_cse(tile_size[idx], "index") + div_vec = self.const_cse.generate(self.const_buffer, f"vector.broadcast %{div_coeff} : index to vector<{compute_vec_size}xindex>") + mod_vec = self.const_cse.generate(self.const_buffer, f"vector.broadcast %{mod_coeff} : index to vector<{compute_vec_size}xindex>") + self.register_var_info(div_vec, [compute_vec_size, "index"]) + self.register_var_info(mod_vec, [compute_vec_size, "index"]) + dim = ops.modular(ops.div(vector_index, div_vec), mod_vec) + if idx == tile_desc.vlane_split_axis: # Need to add vector lane offset + offset = tile_desc.vlane_stride * strides[idx] + vlane_coeff = self.get_const_cse(0, "i64") + vlane_vec_size = 4 + vlane_vec = self.const_cse.generate(self.const_buffer, f"vector.broadcast %{vlane_coeff} : i64 to vector<{vlane_vec_size}xi64>") + vlane_offset = self.const_cse.generate(self.const_buffer, f"arith.addi %{vlane_vec}, %{vlane_vec} {{ vlane_offset={offset} }} : vector<{vlane_vec_size}xi64> // vlane offset") + self.register_var_info(vlane_offset, [vlane_vec_size, "i64"]) + vlane_offset = ops.index_cast(vlane_offset, "index") + self.register_var_info(vlane_offset, [vlane_vec_size, "index"]) + dim = ops.add(dim, vlane_offset) + dim_list.append(dim) + + indices = [str(i) for i in index.free_symbols] + for idx in indices: + i = int(idx[5:]) + index_vec = self.cse.generate(self.compute, f"vector.broadcast %{idx} : index to vector<{compute_vec_size}xindex>") + self.register_var_info(index_vec, [compute_vec_size, "index"]) + offset = ops.add(index_vec, dim_list[i]) + dim_list[i] = offset + arg_lists = [] + for arg in renamed_expression.args: + if isinstance(arg, sympy.Integer): + offset = self.get_const_cse(int(arg)) + offset_vec = self.const_cse.generate(self.const_buffer, f"vector.broadcast %{offset} : index to vector<{compute_vec_size}xindex>") + self.register_var_info(offset_vec, [compute_vec_size, "index"]) + arg_lists.append(offset_vec) + elif isinstance(arg, sympy.Mul): + if isinstance(arg.args[0], sympy.Integer) and isinstance(arg.args[1], sympy.Symbol): + coeff = self.get_const_cse(int(arg.args[0])) + coeff_vec = self.const_cse.generate(self.const_buffer, f"vector.broadcast %{coeff} : index to vector<{compute_vec_size}xindex>") + self.register_var_info(coeff_vec, [compute_vec_size, "index"]) + result = ops.mul(dim_list[int(str(arg.args[1])[1:])], coeff_vec) + arg_lists.append(result) + elif isinstance(arg.args[1], sympy.Integer) and isinstance(arg.args[0], sympy.Symbol): + coeff = self.get_const_cse(int(arg.args[1])) + coeff_vec = self.cse.generate(self.compute, f"vector.broadcast %{coeff} : index to vector<{compute_vec_size}xindex>") + self.register_var_info(coeff_vec, [compute_vec_size, "index"]) + result = ops.mul(dim_list[int(str(arg.args[0])[1:])], coeff_vec) + arg_lists.append(result) + else: + raise NotImplementedError("Not supporting format") + elif isinstance(arg, sympy.Symbol): + arg_lists.append(dim_list[int(str(arg)[1:])]) + else: + raise NotImplementedError("Not supporting format") + if isinstance(renamed_expression, sympy.Symbol): + arg_lists.append(dim_list[int(str(renamed_expression)[1:])]) + accum = arg_lists[0] + for arg in arg_lists[1:]: + accum = ops.add(accum, arg) + return accum + + def index_expr(self, index, dtype): + tile_desc = self.kernel_group.tile_desc + tile_size = tile_desc.get_tile_size_per_lane() + mlir_dtype = mlir_common.DTYPE_TO_MLIR[dtype] + str_tile_size = [str(dim) for dim in tile_size] + compute_vec_size = tile_desc.get_compute_vec_size() + tile_shape = f"memref<{compute_vec_size}xi64, 1>" + vshape = f"vector<{compute_vec_size}xi64>" + + # Create base_vector index var + c_type = "uint64_t" + new_name = f"index_expr_{compute_vec_size}" + if new_name not in self.global_vars_dict: + self.header.writeline(f"{c_type} {new_name}_spad[{compute_vec_size}] __attribute__ ((section(\".spad\")));") + self.gem5_header.writeline(f"{c_type} {new_name}_spad[{compute_vec_size}] __attribute__((aligned(64)));") + self.global_vars.writeline(f"memref.global @{new_name}_spad : {tile_shape}") + self.global_vars_dict[new_name] = dict() + sram_var = self.spad_cse.generate(self.spad_buffer, f"memref.get_global @{new_name}_spad : {tile_shape}") + # Initialize base vector + if not self.base_vector_initialized: + init_iter = "iter" + parallel_map = f"affine.parallel (%{init_iter}) = ({0}) to ({compute_vec_size}) {{ // Base vector initializer" + self.spad_buffer.writeline(parallel_map) + with self.spad_buffer.indent(): + self.spad_buffer.writeline(f"%init_vec = vector.broadcast %{init_iter} : index to vector<2xindex>") + self.spad_buffer.writeline(f"%init_cvt_vec = arith.index_cast %init_vec : vector<2xindex> to vector<2xi64>") + self.spad_buffer.writeline(f"affine.vector_store %init_cvt_vec, %{sram_var}[%{init_iter}] : {tile_shape}, vector<2xi64>") + self.spad_buffer.writeline("}") + self.base_vector_initialized = True + + line = f"affine.vector_load %{sram_var}[0] : {tile_shape}, {vshape}" + out = self.cse.generate(self.compute, line) + self.register_var_info(out, [compute_vec_size, "i64"]) + base_vector_index = ops.index_cast(out, "index") + self.register_var_info(base_vector_index, [compute_vec_size, "index"]) + + renamed_symbols = {symbol: "d"+str(symbol)[5:] for symbol in index.free_symbols} + renamed_expression = index.subs(renamed_symbols) + result = self._index_expr(tile_size, renamed_expression, index, base_vector_index) + return result + + def codegen_global_init(self): + return self.global_vars def codegen_loops(self): code = mlir_common.ParallelLoopBuffer() # Loop body part - tile_row, tile_col = self.tile_desc.n_row, self.tile_desc.n_col - # FIXME. - #if (self.tiling_idx < self.reduction_depth and len(self.reduction_idx) > 0): - # tile_row, tile_col = self.tile_desc.n_col, self.tile_desc.n_row - tile_row = self.tile_desc.get_tile_size() if len(self.itervars) == 1 else tile_row - loops = [LoopLevel(var, size, idx-len(self.itervars), tile_row=tile_row, tile_col=tile_col) for idx, (var, size) in enumerate(zip(self.itervars, self.ranges))] + tile_size = self.kernel_group.tile_desc.get_tile_size() + # Apply paddings + loops = [LoopLevel(var, size, step=step) for idx, (var, size, step) in enumerate(zip(self.itervars, self.ranges, tile_size))] loops, reductions = [LoopNest(loops[: self.reduction_depth]), LoopNest(loops[self.reduction_depth :])] + reductions.mark_reduction(self.reduction_vars, self.affine_yield) + # For non-loop code if (self.reduction_depth==0): - loops = LoopNest([LoopLevel("dummy", 1, 1, 0)]) - reductions.mark_reduction(self.reduction_vars) - if len(self.affine_yield) > 0: - vars = ', '.join([f"%{name}" for name, _ in self.affine_yield.items()]) - reduced_shapes = ', '.join([f"{shape}" for _, shape in self.affine_yield.items()]) - self.stores.writeline(f"affine.yield {vars} : {reduced_shapes}") + loops = LoopNest([LoopLevel("dummy", 1)]) + + if len(reductions.loops) > 1: + NotImplementedError("Not support multiple reduction axis..") + + code.splice(self.const_buffer) + code.splice(self.alloc_buffer) + code.splice(self.spad_buffer) + # Outerloop with contextlib.ExitStack() as stack: for loop in loops.loops: loop_lines = loop.lines() - if loop_lines is None: - return code.writelines(loop_lines) - stack.enter_context(code.indent(outer_loop=True)) - with contextlib.ExitStack() as stack_outer: - code.splice(self.reduction_prefix) + stack.enter_context(code.indent(attribute="{outer_loop=true}")) + # Non-outerloop start + code.splice(self.reduction_prefix) + with contextlib.ExitStack() as stack: + # Add reduction loops + if len(reductions.loops): + reduction_lines = reductions.loops[0].lines() + epilogue = reductions.loops[0].epilogue_line() + code.writelines(reduction_lines) + stack.enter_context(code.indent(attribute="{accumulation_loop=true}", suffix=epilogue)) + code.splice(self.applys) + code.splice(self.indexed_buffer) + code.splice(self.dma_loads) + # Compute body + code.writelines(self.compute_body_loop.lines()) with contextlib.ExitStack() as stack: - for reduction in reductions.loops: - reduction_lines = reduction.lines() - if reduction_lines is None: - return - code.writelines(reduction_lines) - stack.enter_context(code.indent(outer_loop=False)) + stack.enter_context(code.indent(attribute="{inner_loop=false}",suffix=self.compute_body_loop.epilogue_line())) + code.splice(self.masks) code.splice(self.loads) code.splice(self.compute) code.splice(self.stores) - code.splice(self.reductions_suffix) + code.splice(self.dma_stores) + code.splice(self.reductions_suffix) + # Non-outerloop end code.writeline(f"return") return code - def codegen_kernel(self, kernel_name): - wrapper = V.graph.wrapper_code - arg_defs, _, _, _ = self.kernel_group.args.mlir_argdefs() - code = self._codegen_kernel(arg_defs, kernel_name) - return code.getvalue() - - def meta_kernel(self): - wrapper = V.graph.wrapper_code - _, _, arg_attributes, _ = self.kernel_group.args.mlir_argdefs() - wrapper.add_import_once('\nprint(f\'Wrapper Codegen Path = {__file__}\')') - wrapper.add_import_once(f'\nfrom PyTorchSimFrontend.extension_codecache import CustomAsyncCompile') - wrapper.add_import_once(f'\ncustom_async_compile = CustomAsyncCompile()') - # Dump loop and load/store information - wrapper.add_import_once(f"arg_attributes = {arg_attributes}") - - - def call_kernel(self, kernel_name): - wrapper = V.graph.wrapper_code - _, call_args, _, _ = self.kernel_group.args.mlir_argdefs() - # generate the code to call this - wrapper.generate_kernel_call(kernel_name, call_args, cuda=False) - - def _codegen_kernel(self, arg_defs, kernel_name): - arg_defs = ",\n".ljust(25).join(arg_defs) - code = common.BracesBuffer() - - code.splice(self.global_vars) - #TODO:. kernel name custom - kernel_decl_name = kernel_name if V.graph.cpp_wrapper else "kernel" - code.writeline(f'func.func @{kernel_decl_name}({arg_defs})') - with code.indent(): - for old, new in self.kernel_group.args.aliases(): - code.writeline(f"auto {old} = {new};") - # Loop body part - code.splice(self.codegen_init()) - code.splice(self.codegen_loops()) - return code + def make_choices(self, nodes, kernel_name): + choices = [] + initial_tile_size = self.kernel_group.tile_desc.get_tile_size() + previous_ranges = self.ranges + prevent_infinite_loop = 0 + if len(initial_tile_size) < 2: + return choices # Can't autotune for 1-D tile size + for vlane_stride in [2, 4, 8]: + os.environ['TORCHSIM_VECTOR_LANE_STRIDE'] = str(vlane_stride) + previous_tile_size = initial_tile_size + increase_dim = -2 # increase the first dimension + while previous_tile_size[increase_dim] * 2 <= previous_ranges[increase_dim] and previous_tile_size[increase_dim] <= 2 ** 13 and prevent_infinite_loop < 10: + incrase_dim = -1 # only increase the last dimension + prevent_infinite_loop += 1 + while previous_tile_size[incrase_dim] * 2 <= previous_ranges[incrase_dim] and previous_tile_size[incrase_dim] <= 2 ** 13: + src_code = super().codegen_nodes(nodes, kernel_name) + if self.stop_autotune: + print(f"[Auto-tune] Skipping autotuning due to enough tile size: {self.kernel_group.tile_desc.get_tile_size()}") + break + print(f"[Auto-tune] Trying tile size: {self.kernel_group.tile_desc.get_tile_size()}, vlane_stride: {vlane_stride}") + previous_tile_size = self.kernel_group.tile_desc.get_tile_size() + self._prepare_simulator_headers(src_code) + bench_runner = self.run_bench(nodes, kernel_name, src_code) + choices.append((bench_runner, src_code, self.kernel_group)) + self.reset(f"tile_size_{incrase_dim}") + previous_tile_size[incrase_dim] = initial_tile_size[incrase_dim] + self.kernel_group.tile_desc.set_tile_size(previous_tile_size) + self.reset(f"tile_size_{increase_dim}") + self.reset("vlane_stride") + return choices + + def autotune(self, nodes, kernel_name): + def get_cycle(choice): + bench_runner, src_code, kernel_group = choice + for n_try in range(extension_config.CONFIG_MAX_AUTOTUNE_TRY): # TODO: make simple + try: + # bench_runner = self.run_bench(nodes, kernel_name, src_code) + if int(os.environ.get('BACKENDSIM_DRYRUN', default=False)): + _, _, out = bench_runner(autotune=1) + else: + out = bench_runner(validate=extension_config.CONFIG_TORCHSIM_VALIDATION_MODE) + return out[-1] + except (extension_codecache.SpadOverflowError, RuntimeError) as e: + return float("inf") + #if isinstance(e, RuntimeError) and str(e) != "STACK_OVERFLOW": + # print(f"Benchmark[trial-{n_try}] failed with unexpected error: {e}") + # return float("inf") + #print(f"Benchmark failed due to spad overflow with tile size: {self.kernel_group.tile_desc.get_tile_size()}") + #self.kernel_group = kernel_group # Reset to the original tile desc + #self.reset("spad_overflow") + #src_code = super().codegen_nodes(nodes, kernel_name) + #bench_runner = self.run_bench(nodes, kernel_name, src_code) + #kernel_group = self.kernel_group + #self._prepare_simulator_headers(src_code) + return float("inf") # Exceeded maximum number of autotuning attempts + + choices = self.make_choices(nodes, kernel_name) + + if len(choices) == 0: # can't autotune + return None + with ThreadPoolExecutor(max_workers=8) as executor: + results = list(executor.map(get_cycle, choices)) + max_idx = results.index(min(results)) + if min(results) == float("inf"): + raise RuntimeError("Failed to find optimal tile size...") + print(f"[Auto-tune] Optimal tile size: {choices[max_idx][2].tile_desc.get_tile_size()}, vlane_stride: {choices[max_idx][2].tile_desc.vlane_stride}, cycles: {results[max_idx]}") + optimal_src_code = choices[max_idx][1] + return optimal_src_code - def adjust_tile_size(self): - if self.read_writes is not None: - read_writes = list(self.read_writes.reads) + list(self.read_writes.writes) - cv_list = [] - for node in read_writes: - if len(node) > 1: - cv_list.append(self.get_constant_vector2(node[1])) - max_element = max(cv_list, key=len) - max_nr_dim = len(max_element) - - sorted_max_element = sorted(max_element, key=lambda x:x[0]) - # Force vector tile size when 3D node is originated from view - if max_nr_dim == 3 and max_nr_dim != len(self.itervars): - self.tile_desc.n_col = min(self.tile_desc.get_tile_size(), sorted_max_element[1][0]) - self.tile_desc.n_row = 1 - return - - # Case 1. vector kernel - if len(self.itervars) == 1: - self.tile_desc.n_col = self.tile_desc.get_tile_size() - self.tile_desc.n_row = 1 - elif len(self.itervars) == 0: - self.tile_desc.n_col = 1 - self.tile_desc.n_row = 1 - - # Case 2. 2-D tensor (e.g., softmax) - if len(self.itervars) == 2 and self.reduction_depth == len(self.itervars): - # Avoid too much padding - if (self.ranges[0] <= self.vector_lane and self.ranges[0] <= self.tile_desc.n_row): - self.tile_desc.n_row = self.ranges[0] - self.tile_desc.used_vector_lane = self.ranges[0] - - # Case 2. 2-D reduction (e.g., batchnorm) - if len(self.itervars) == 2 and self.reduction_depth == len(self.itervars) - 1: - if (((self.ranges[0] + 1) // 2) <= self.vector_lane and ((self.ranges[0] + 1) // 2) <= self.tile_desc.n_row): - self.tile_desc.n_row = ((self.ranges[0] + 1) // 2) * 2 - self.tile_desc.used_vector_lane = (self.ranges[0] + 1) // 2 - - # Case 2. 3-D tensor kernel without reduction. Access vector granule! - if len(self.itervars) == 3 and self.reduction_depth == len(self.itervars): - self.tile_desc.n_col = self.ranges[-1] - self.tile_desc.n_row = 1 - - # Case 3. N-D tensor kernel with reduction. Not implemented. Need this? - if len(self.itervars) >= 3 and self.reduction_depth < len(self.itervars): - raise NotImplementedError() - - def set_ranges(self, lengths, reduction_lengths, read_writes): - self.read_writes = read_writes - if self.call_ranges: - assert self.call_ranges == tuple(lengths) + tuple( - reduction_lengths - ), f"{self.call_ranges} == {tuple(lengths)} + {tuple(reduction_lengths)}" - assert self.reduction_depth == len(lengths) + def codegen_nodes(self, nodes, kernel_name): + src_code = super().codegen_nodes(nodes, kernel_name) + self._prepare_simulator_headers(src_code) + if not extension_config.CONFIG_AUTOTUNE or extension_config.CONFIG_BACKENDSIM_SPIKE_ONLY: + return src_code else: - self.call_ranges = tuple(lengths) + tuple(reduction_lengths) - self.ranges = [self.rename_indexing(x) for x in self.call_ranges] - self.itervars = [sympy.Symbol(f"index{n}") for n in range(len(self.ranges))] - self.reduction_depth = len(lengths) + optimal_src_code = self.autotune(nodes, kernel_name) + if optimal_src_code: + return optimal_src_code + else: + return src_code + + def _prepare_simulator_headers(self, src_code): + write_path = extension_codecache.get_write_path(src_code) + os.makedirs(write_path, exist_ok=True) - # Adjust time size when it is vector - self.adjust_tile_size() + spike_write_path = os.path.join(write_path, "global_var.h") + gem5_write_path = os.path.join(write_path, "gem5_global_var.h") - return ( - self.itervars[: self.reduction_depth], - self.itervars[self.reduction_depth :], + spad_end_symbol = "int spad_end[0] __attribute__ ((section(\".spad\")));\n" + spad_section_end_symbol = ( + f"int spad_section_end[0] __attribute__ ((section(\".spad\"), aligned({self.spad_info['spad_size']*self.vector_lane})));" ) + write_atomic(spike_write_path, self.header.getvalue() + spad_end_symbol + spad_section_end_symbol) + write_atomic(gem5_write_path, self.gem5_header.getvalue()) + + def get_arg_info(self, name): + arg_info = dict() + arg_info.update(V.graph.graph_inputs) + arg_info.update({i.get_name(): i for i in V.graph.buffers}) + return arg_info[name] + + def get_dma_info(self, name, index, broadcast=True, store_reduction=False, buffer=None): # Need more argument? + """ + A tile descriptor exists that is configured on a kernel group + DMA desc should be adjusted according to buffer. + Therefore, this function shoulde determin DRAM, SRAM stride and + vectorlane mapping policy + """ + # Use loads as default + if buffer is None: + buffer = self.applys if "tmp" not in str(index) else self.dma_loads + + # TODO. + kg_tile_desc = self.kernel_group.tile_desc + # Note: index could contain symbols that represent dynamic axies + # Extract dimension of index(e.g, index0, index1) + local_dims = [int(str(i)[5:]) for i in index.free_symbols if "index" in str(i)] + implicit_local_dims = list(index.args) + total_dims = [int(str(i)[5:]) for i in self.itervars] + local_tile_desc = mlir_common.MLIRMultiDimTile([1], self.vector_lane) + local_dims.sort() # Assume that smaller index is placed in the outer loop + indirect_dims = [f"{i}" for i in index.free_symbols if "tmp" in str(i)] + for indirect_dim in indirect_dims: + index = index.replace(sympy.Symbol(indirect_dim), 0) + + # Reduction can have two type of tile size + if broadcast and (total_dims != local_dims or (self.reduction_depth!=len(total_dims) and total_dims[:self.reduction_depth] == local_dims)): + local_dims = total_dims # Brodatcast tile shape + + index_var = self.parse_indices(index, buffer=buffer, indirect_dims=indirect_dims) + + if kg_tile_desc.vlane_split_axis in local_dims: + local_vlane_split_axis = local_dims.index(kg_tile_desc.vlane_split_axis) + else: + local_vlane_split_axis = max(len(local_dims) - 1, 0) - def get_scratchpad_buffer(self, dtype, name, tile_row, tile_col, dram_tile_shape, code_buffer, indices, raw_index): + # Case 0. Tile is 0-D scalar + if len(local_dims) == 0: + if not store_reduction: + local_tile_desc.set_tile_size([kg_tile_desc.get_used_vlane() * kg_tile_desc.vlane_stride]) # Force it to use vector instruction. + local_tile_desc.vlane_split_axis = local_vlane_split_axis # last axis + local_tile_desc.vlane_stride = kg_tile_desc.vlane_stride + else: + local_tile_desc.set_tile_size([1]) + local_tile_desc.vlane_split_axis = 0 + local_tile_desc.vlane_stride = 1 + dram_stride = [0] # Edge case + # Case 1. Tile is 1-D vector type + elif len(local_dims) == 1 and len(local_dims) <= self.reduction_depth: + local_tile_desc.set_tile_size([kg_tile_desc.get_dim_size(local_dims[0])]) + local_tile_desc.vlane_split_axis = local_vlane_split_axis + local_tile_desc.vlane_stride = kg_tile_desc.vlane_stride + # Case 2. Tile is 1-D vector type with reduction + elif len(local_dims) == 1 and len(local_dims) == self.reduction_depth + 1: + local_tile_desc.set_tile_size([1, kg_tile_desc.get_dim_size(local_dims[0])]) + local_tile_desc.vlane_split_axis = local_vlane_split_axis + 1 + local_tile_desc.vlane_stride = kg_tile_desc.vlane_stride + # Case 3. Tile is 2-D tile + elif len(local_dims) == 2: + is_reduction = self.reduction_depth == 1 and not store_reduction + if is_reduction: + local_tile_desc.set_tile_size([kg_tile_desc.get_dim_size(dim) for dim in local_dims], [1, 0]) + local_tile_desc.vlane_split_axis = local_vlane_split_axis + local_tile_desc.vlane_stride = kg_tile_desc.vlane_stride + else: + local_tile_desc.set_tile_size([kg_tile_desc.get_dim_size(dim) for dim in local_dims]) + local_tile_desc.vlane_split_axis = local_vlane_split_axis + local_tile_desc.vlane_stride = kg_tile_desc.vlane_stride + # Case 3. Tile is 3-D tile + elif len(local_dims) == 3: + is_reduction = self.reduction_depth < 3 and not store_reduction + if is_reduction: + local_tile_desc.set_tile_size([kg_tile_desc.get_dim_size(dim) for dim in local_dims], [1, 2, 0]) + local_tile_desc.vlane_split_axis = local_vlane_split_axis + local_tile_desc.vlane_stride = kg_tile_desc.vlane_stride + else: + local_tile_desc.set_tile_size([kg_tile_desc.get_dim_size(dim) for dim in local_dims]) + local_tile_desc.vlane_split_axis = local_vlane_split_axis + local_tile_desc.vlane_stride = kg_tile_desc.vlane_stride + # Case 4. Tile is 4-D tile (e.g., Convolution epilogue) + elif len(local_dims) == 4: + is_reduction = self.reduction_depth < 3 and not store_reduction + if is_reduction: + raise NotImplementedError("Currently not implemented... ;)") + local_tile_desc.set_tile_size([kg_tile_desc.get_dim_size(dim) for dim in local_dims]) + local_tile_desc.vlane_split_axis = local_vlane_split_axis + local_tile_desc.vlane_stride = kg_tile_desc.vlane_stride + else: + raise NotImplementedError("Currently not implemented... ;)") + + if len(implicit_local_dims)!=0 and len(local_dims) != len(implicit_local_dims) and self.is_modular_indexing(index): + tile_size = local_tile_desc.get_tile_size() + new_tile_size = [] + new_vlane_split_axis = local_tile_desc.vlane_split_axis + implicit_dim_size = list(kg_tile_desc.implicit_dim_size.values()) + for i, target_dim_size in enumerate(implicit_dim_size): + new_tile_size += [1]*(len(target_dim_size)-1) + tile_size[i:i+1] + if local_tile_desc.vlane_split_axis >= i: + new_vlane_split_axis += len(target_dim_size)-1 + # Update + local_tile_desc.set_tile_size(new_tile_size) + local_tile_desc.vlane_split_axis = new_vlane_split_axis + + # Calculate dram stride + dram_stride = [0] * local_tile_desc.get_nr_dim() + if index.is_Symbol: + dim_idx = int(str(index)[5:]) + dram_stride[dim_idx] = 1 + elif index.is_Number: + pass + else: + dram_dict = defaultdict(list) + # Assume that div will have high priority than mod + for arg in index.as_ordered_terms(): + coeff, dim = arg.as_coeff_mul() + if len(dim) == 0: + continue + real_dim = list(dim[0].free_symbols)[0] + dram_dict[str(real_dim)].append(coeff) + # Add missing dims if not added + max_dim = len(self.ranges) if not store_reduction else len(self.ranges) - 1 + for i in range(max_dim): + target_dim = f"index{i}" + if target_dim not in str(index): + dram_dict[target_dim] = [0] + sorted_keys = sorted(dram_dict.keys()) + dram_stride = sum((dram_dict[key] for key in sorted_keys), []) + + # FIXME. It will be nice to modify node instead of this exception handling... + if len(self.itervars) == 1 and self.reduction_depth == 0: + # In case of reduction loop only case, we will add dummy loop so shift it once + dram_stride = [0] + dram_stride[:-1] + return local_tile_desc, index_var, dram_stride + + def get_dma_code(self, dma_type_name, vlane_split_axis, vlane_stride, mlir_dtype, dram_var, dram_index_var, sram_var, sram_index_var, + dram_shape, tile_shape, attribute): + dma_key = (vlane_split_axis, vlane_stride, mlir_dtype) + if dma_type_name == "MVIN" and dma_key in self.dma_read_cache: + dma_type, vlane_split_axis, vlane_stride = self.dma_read_cache[dma_key] + elif dma_type_name == "MVOUT" and dma_key in self.dma_write_cache: + dma_type, vlane_split_axis, vlane_stride = self.dma_write_cache[dma_key] + else: + vlane_split_axis = self.get_const_cse(vlane_split_axis) + vlane_stride = self.get_const_cse(vlane_stride) + if dma_type_name == "MVIN": + dma_type = self.get_const_cse(DMA_TYPE[f"{dma_type_name}{self.dma_read_counter}"]) + self.dma_read_counter += 1 + self.dma_read_cache[dma_key] = [dma_type, vlane_split_axis, vlane_stride] + else: + dma_type = self.get_const_cse(DMA_TYPE[f"{dma_type_name}{self.dma_write_counter}"]) + self.dma_write_cache[dma_key] = [dma_type, vlane_split_axis, vlane_stride] + tag = self.get_tag_cse() + zero_cse = self.get_const_cse(0) + + # Prepare opearnds and attributes + dram_operand = f"%{dram_var}[%{dram_index_var}]" + sram_operand = f"%{sram_var}[{sram_index_var}]" # Use string + tag_var = f"%{tag}[%{zero_cse}]" + dma_attribute = f"%{vlane_split_axis}, %{vlane_stride}" + sram_shape = tile_shape + tag_shape = "memref<1xi32>" + + if dma_type_name == "MVIN": + src_operand, dst_operand = dram_operand, sram_operand + src_shape, dst_shape = dram_shape, sram_shape + else: + src_operand, dst_operand = sram_operand, dram_operand + src_shape, dst_shape = sram_shape, dram_shape + + return f"memref.dma_start {src_operand}, {dst_operand}, %{dma_type}, {tag_var}, {dma_attribute} : {src_shape}, {dst_shape}, {tag_shape} {attribute}" + + def allocate_sram_buffer(self, dtype, dram_name, tile_desc, raw_index, buffer=None, forced_name=None): c_type = mlir_common.DTYPE_TO_C[dtype] - mlir_type = mlir_common.DTYPE_TO_MLIR[dtype] + mlir_dtype = mlir_common.DTYPE_TO_MLIR[dtype] + tile_numel_per_lane = tile_desc.get_numel_per_lane() + tile_shape = tile_desc.get_mlir_shape(mlir_dtype) # Make sure each lane's buffer has at least two element - tile_size = max(self.roundup_vectorlane(tile_row * tile_col), self.vector_lane * 2) - if dtype == torch.bool and not self.is_template_kernel: #FIXME: epilogue ReLU does not need this - if self.is_template_kernel: - mapping = f"template_{indices} " - self.map_cse.generate(self.global_vars, f"#{mapping} = affine_map<({indices}) -> ({indices} floordiv 8)>", assignment=False) - else: - mapping = self.map_cse.generate(self.global_vars, f"affine_map<({indices}) -> ({indices} floordiv 8)>") - indices = self.cse.generate(self.loads, f"affine.apply #{mapping}(%{indices})") # FIXME. Only loads? + tile_size = max(tile_numel_per_lane, 2) * self.vector_lane + + if buffer is None: + buffer = self.spad_buffer - if name not in self.global_vars_dict: - self.global_vars_dict[name] = list() + if dram_name not in self.global_vars_dict: + self.global_vars_dict[dram_name] = dict() - if str(raw_index) not in self.global_vars_dict[name]: - new_name = f"{name}_{len(self.global_vars_dict[name])}" + if str(raw_index) not in self.global_vars_dict[dram_name]: + new_name = f"buf{self.spadbuf_counter}_spad" if forced_name is None else f"{forced_name}_spad" + self.spadbuf_counter+=1 # Add definition to header - self.header.writeline(f"{c_type} {new_name}_spad[{tile_size // self.vector_lane}] __attribute__ ((section(\".spad\")));") - self.gem5_header.writeline(f"{c_type} {new_name}_spad[{tile_size}];") - self.global_vars.writeline(f"memref.global @{new_name}_spad : memref<{dram_tile_shape}x{mlir_type}, 1>") - self.global_vars_dict[name].append(str(raw_index)) + self.header.writeline(f"{c_type} {new_name}[{tile_size // self.vector_lane}] __attribute__ ((section(\".spad\")));") + self.gem5_header.writeline(f"{c_type} {new_name}[{tile_size}] __attribute__((aligned(64)));") + self.global_vars.writeline(f"memref.global @{new_name} : {tile_shape}") + self.global_vars_dict[dram_name][str(raw_index)] = new_name else: - new_name = f"{name}_{self.global_vars_dict[name].index(str(raw_index))}" - buffer = self.cse.generate(code_buffer, f"memref.get_global @{new_name}_spad : memref<{dram_tile_shape}x{mlir_type}, 1>") - return buffer, indices - - def roundup_vectorlane(self, size, amp=1): - return ((size + self.vector_lane - 1) // self.vector_lane) * self.vector_lane * amp - -from . import mlir_lowering - -@dataclasses.dataclass -class LoopLevel: - var: sympy.Expr - size: sympy.Expr - idx: int - start: int = 0 - tile_row: int = 4 - tile_col: int = 4 - reduction_vars: Dict[str, str] = None - - def lines(self): - step = 1 - if self.idx == -2: - step = self.tile_row - elif self.idx == -1: - step = self.tile_col - if self.reduction_vars: - acc = ', '.join([f"%{acc.name}" for acc in self.reduction_vars.keys()]) - args = ', '.join([f"%{iter.name} = %{init.name}" for (_, iter, init, _) in self.reduction_vars.values()]) - dtype = ', '.join([f"{dtype}" for (_, _, _, dtype) in self.reduction_vars.values()]) - line = f"{acc} = affine.for %{self.var} = {self.start} to {self.size} step {step} iter_args({args}) -> ({dtype})" - else: - line = f"affine.for %{self.var} = {self.start} to {self.size} step {step}" - - return [line] + new_name = self.global_vars_dict[dram_name][str(raw_index)] + return new_name -@dataclasses.dataclass -class LoopNest: - loops: List[LoopLevel] + def get_scratchpad_buffer(self, dtype, dram_name, tile_desc, raw_index, buffer=None): + if buffer is None: + buffer = self.spad_buffer - def __bool__(self): - return bool(self.loops) + mlir_dtype = mlir_common.DTYPE_TO_MLIR[dtype] + tile_shape = tile_desc.get_mlir_shape(mlir_dtype) + new_name = self.allocate_sram_buffer(dtype, dram_name, tile_desc, raw_index, buffer=buffer) + sram_var = self.spad_cse.generate(buffer, f"memref.get_global @{new_name} : {tile_shape}") - def mark_reduction(self, reduction_vars): - for loop in self.loops: - loop.reduction_vars = reduction_vars + zero_cse = self.get_const_cse(0) + sram_index_var = ",".join([f"%{zero_cse}"] * tile_desc.get_nr_dim()) + return sram_var, sram_index_var - def mark_parallel(self, par_depth): - loops = self.loops - loops[0].parallel = par_depth - for i in range(1, par_depth): - loops[i].collapsed = True - loops[0].simd = loops[par_depth - 1].simd + def get_const_cse(self, value, dtype="index") -> common.CSEVariable: + # Type convert + if dtype[0] == "f": + value = float(value) + else: + value = int(value) -class MLIRWrapperKenrelGroup(cpp.KernelGroup): - def __init__(self): - super().__init__() - self.args = mlir_common.MLIRKernelArgs() - -class MLIRScheduling(BaseScheduling): - count = 0 - target_kernel = MLIRKernel - def __init__(self, scheduler): - self.scheduler = scheduler - self.kernel_group = MLIRWrapperKenrelGroup() - self._ready_to_flush = False - self.outer_function = set() - config.inplace_buffers = False # FIXME. inout kernel makes trouble.. So disabled it! - - def _set_flush_status(self, status: bool): - self._ready_to_flush = status - - def can_fuse_vertical(self, node1, node2): - return False - return self.can_fuse_horizontal(node1, node2) and not node1.is_reduction() - - def can_fuse_horizontal(self, node1, node2): - return False - _, (vars1, reduce1) = node1.group - _, (vars2, reduce2) = node2.group - if vars1 == vars2 and reduce1 == reduce2: - return True - #TODO: Temporary solution determining the fusion condition similar to CPP/OpenMP - v1_total = math.prod(vars1) if len(vars1) else 0 - v2_total = math.prod(vars2) if len(vars2) else 0 - r1_total = math.prod(reduce1) if len(reduce1) else 0 - r2_total = math.prod(reduce2) if len(reduce2) else 0 - if reduce1 == () \ - and v1_total == (v2_total + r2_total): - # and node1.node.layout.size == node2.node.layout.size: #FIXME: Need to check layout too? - return True - return False - - def group_fn(self, sizes): - return tuple(tuple(map(V.graph.sizevars.simplify, s)) for s in sizes) - - def codegen_nodes(self, nodes): - _, (group, reduction_group) = max( - nodes, key=lambda x: int(x.is_reduction()) - ).group - ex_kernel = self.target_kernel() - ex_kernel.kernel_group = self.kernel_group - - kernel_name = f"extension_kernel_{MLIRScheduling.count}" - MLIRScheduling.count += 1 - src_code = ex_kernel.codegen_nodes(nodes, kernel_name) - self.define_kernel(src_code, kernel_name, ex_kernel.vector_lane, - ex_kernel.spad_info, origins= {str(i) for i in nodes[0].node.origins}) - ex_kernel.call_kernel(kernel_name) - _, args, _, _ = ex_kernel.args.mlir_argdefs() - args = ", ".join(args) - if (extension_config.CONFIG_BACKENDSIM_EAGER_MODE): - V.graph.wrapper_code.writeline( - f"yield ({kernel_name}, ({args}))" - ) - self._set_flush_status(True) - - def ready_to_flush(self): - return self._ready_to_flush - - def codegen_sync(self): - pass - - def flush(self): - self.kernel_group.codegen_define_and_call(V.graph.wrapper_code) - self.kernel_group = MLIRWrapperKenrelGroup() - self._set_flush_status(False) - - def define_function(self, kernel): - code, function_name = kernel.def_function() - if code is not None and function_name not in self.outer_function: - wrapper = V.graph.wrapper_code - wrapper.header.writeline(code) - self.outer_function.add(function_name) - - def define_kernel(self, src_code, kernel_name, vector_lane, spad_info, tile_size=[1, 1, 1], loop_size=None, origins={}): - wrapper = V.graph.wrapper_code - if src_code in wrapper.src_to_kernel: - kernel_name = wrapper.src_to_kernel[src_code] + if value not in self.consts: + self.consts[str(value)+dtype] = self.const_cse.generate(self.const_buffer, f"arith.constant {value} : {dtype}") + return self.consts[str(value)+dtype] + + def get_tag_cse(self, value=None, shape="memref<1xi32>"): + if value is None: + value = self.dma_tag_id + self.dma_tag_id += 1 + if value not in self.tags: + self.tags[value] = self.alloc_cse.generate(self.alloc_buffer, f"memref.alloc() : {shape} // {value}") + return self.tags[value] + + def get_mask(self): + if self.compute_body_loop.size % self.compute_body_loop.step == 0: + return None, None + compute_vec_size = self.kernel_group.tile_desc.get_compute_vec_size() + index_shape = f"vector<{self.compute_body_loop.step}xindex>" + mask_shape = f"vector<{compute_vec_size}xi1>" + + upper_bound = self.get_const_cse(self.compute_body_loop.size) + step_vec = self.const_cse.generate(self.const_buffer, f"vector.step : {index_shape}") + + gap = self.mask_cse.generate(self.masks, f"arith.subi %{upper_bound}, %{self.compute_idx} : index") + gap_vec = self.mask_cse.generate(self.masks, f"vector.broadcast %{gap} : index to {index_shape}") + mask_var = self.mask_cse.generate(self.masks, f"arith.cmpi ult, %{step_vec}, %{gap_vec} : {index_shape}") + self.register_var_info(mask_var, [compute_vec_size, "i1"]) + return mask_shape, mask_var + + def convert_indirect_indexing(self, index :sympy.Expr): + if "tmp" not in str(index): + return index + + # Process start + indirect_dims = [str(dim) for dim in index.free_symbols if "tmp" in str(dim)] + indirect_dims.sort() + first_dim = indirect_dims[0] + spad_vars = dict() + tmp_comp, self.compute = self.compute, self.dma_loads + + # Load indirect operands + for target_dim in indirect_dims: + sram_var, _, tile_numel_per_lane, sram_index_var, tile_shape, vshape = self.spad_buffer_dict[target_dim] + mlir_dtype = vshape.split("x")[1][:-1] + vshape = f"vector<{tile_numel_per_lane}x{mlir_dtype}>" # FIXME. Maybe require fine grain compute... + if tile_numel_per_lane > 1: + operation = "affine.vector_load" + line = f"{operation} %{sram_var}[{sram_index_var}] : {tile_shape}, {vshape} // For indirect access" + else: + operation = "affine.load" + line = f"{operation} %{sram_var}[{sram_index_var}] : {tile_shape} // For indirect access" + out = self.cse.generate(self.dma_loads, line) + self.register_var_info(out, [tile_numel_per_lane, mlir_dtype]) + spad_vars[target_dim] = out + + # Apply stride + for arg in index.args: + if "tmp" not in str(arg): + continue + if arg.is_Mul and arg.args[0].is_number: + coeff_dtype = self.var_info[spad_vars[str(arg.args[1])]][1] + coeff = ops.constant(int(arg.args[0]), coeff_dtype) + spad_vars[str(arg.args[1])] = ops.mul(spad_vars[str(arg.args[1])], coeff) + index = index.replace(arg, 0) + + # Sum + for dim, var in spad_vars.items(): + if dim == first_dim: + continue + spad_vars[first_dim] = ops.add(spad_vars[first_dim], var) + + # Store index var + sram_var, _, tile_numel_per_lane, sram_index_var, tile_shape, vshape = self.spad_buffer_dict[first_dim] + mlir_dtype = vshape.split("x")[1][:-1] + vshape = f"vector<{tile_numel_per_lane}x{mlir_dtype}>" # FIXME. Maybe require fine grain compute... + if tile_numel_per_lane > 1: + operation = "affine.vector_store" + line = f"{operation} %{spad_vars[first_dim]}, %{sram_var}[{sram_index_var}] : {tile_shape}, {vshape}" else: - wrapper.src_to_kernel[src_code] = kernel_name - - codecache_def = IndentedBuffer() - codecache_def.writeline(f"custom_async_compile.mlir('''{src_code}''', ") - codecache_def.writeline(f"vectorlane_size={vector_lane},") - codecache_def.writeline(f"tile_size={tile_size},") - codecache_def.writeline(f"loop_size={loop_size},") - codecache_def.writeline(f"spad_info={spad_info},") - codecache_def.writeline(f"origins={origins},") - codecache_def.writeline("arg_attributes=arg_attributes)") - wrapper.define_kernel(kernel_name, codecache_def.getvalue(), cuda=False) - return kernel_name - - def codegen_src_code(self, kernel, render, template_node, epilogue_nodes): - with kernel: - for node in [template_node, *epilogue_nodes]: - node.mark_run() - partial_code = render() - for node in epilogue_nodes: - ranges = node.get_ranges() - node.codegen(kernel.set_ranges(ranges[0], ranges[1], None)) - with V.set_kernel_handler(kernel): - src_code = ( - partial_code - if isinstance(partial_code, str) - else partial_code.finalize() - ) - src_code = kernel.add_extra_global_vars(src_code) - return src_code - - def codegen_template(self, template_node, epilogue_nodes): - _, (numel, rnumel) = template_node.group - template_buffer = template_node.node - kernel, render, codegen_header = template_buffer.make_kernel_render(template_buffer, epilogue_nodes=epilogue_nodes) - _, _, _, kernel.buffer_types = kernel.args.mlir_argdefs() - - src_code = self.codegen_src_code(kernel, render, template_node, epilogue_nodes) - wrapper = V.graph.wrapper_code - - if src_code in wrapper.src_to_kernel: # [CONV] check inner function is already defined - kernel_name = wrapper.src_to_kernel[src_code] - kernel, render, codegen_header = template_buffer.make_kernel_render(template_buffer, epilogue_nodes=epilogue_nodes, kernel_name=kernel_name) # update kernel name - src_code = self.codegen_src_code(kernel, render, template_node, epilogue_nodes) - - with V.set_kernel_handler(kernel): - codegen_header(src_code, (kernel.header.getvalue(), kernel.gem5_header.getvalue())) - # node_schedule = [template_node, *epilogue_nodes] - kernel.meta_kernel() - kernel_name = self.define_kernel(src_code, kernel.kernel_name, kernel.vector_lane, kernel.spad_info, - kernel.tile_size, kernel.loop_size, origins={str(i) for i in template_node.node.origins}) - self.define_function(kernel) - - kernel.call_kernel(kernel_name) - V.graph.removed_buffers |= kernel.removed_buffers - _, args, _, _ = kernel.args.mlir_argdefs() - args = ", ".join(args) - if (extension_config.CONFIG_BACKENDSIM_EAGER_MODE): - target_kernel_name = kernel_name if kernel.outer_func_name is None else kernel.outer_func_name - V.graph.wrapper_code.writeline( - f"yield ({target_kernel_name}, ({args}))" - ) - self._set_flush_status(True) \ No newline at end of file + operation = "affine.store" + line = f"{operation} %{spad_vars[first_dim]}, %{sram_var}[{sram_index_var}] : {tile_shape}" + out = self.cse.generate(self.dma_loads, line, assignment=False) + + # Conversion + mlir_dtype = self.var_info[spad_vars[first_dim]][1] + line = f"affine.load %{sram_var}[{sram_index_var}] : {tile_shape}" + out = self.cse.generate(self.dma_loads, line) + if mlir_dtype != "index": + line = f"arith.index_cast %{out} : {mlir_dtype} to {'index'}" + out = self.cse.generate(self.dma_loads, line) + self.register_var_info(out, [1, "index", [1]]) + self.compute = tmp_comp + return index + sympy.Symbol(str(out)) diff --git a/PyTorchSimFrontend/mlir/mlir_common.py b/PyTorchSimFrontend/mlir/mlir_common.py index 912704b5..9151ac0b 100644 --- a/PyTorchSimFrontend/mlir/mlir_common.py +++ b/PyTorchSimFrontend/mlir/mlir_common.py @@ -1,8 +1,19 @@ -import os +import dataclasses +import math +from typing import Dict +from typing import List +from collections import defaultdict +from functools import reduce +from operator import mul import torch +from torch._dynamo.testing import rand_strided +from torch._inductor.autotune_process import TensorMeta from torch._inductor.codegen import common +from torch._inductor.codegen import cpp from torch._inductor.virtualized import V from torch._inductor.ir import MultiOutputLayout +from torch._inductor.dependencies import MemoryDep, StarDep, WeakDep +from torch.utils._sympy.functions import ModularIndexing import sympy import contextlib @@ -21,6 +32,7 @@ unique, ) from PyTorchSimFrontend import extension_config +from PyTorchSimFrontend.mlir.mlir_autotune import MLIRBenchmarkRequest schedule_log = torch._logging.getArtifactLogger(__name__, "schedule") DTYPE_TO_MLIR = { @@ -32,7 +44,7 @@ torch.int16: "i16", torch.int8: "i8", torch.uint8: "i8", - torch.bool: "i1", + torch.bool: "i8", torch.bfloat16: "bf16", } @@ -54,15 +66,31 @@ torch.float16, ] +MLIR_INF = { + "inf" : { + "f32" : 0x7F800000, + "f64" : 0x7FF0000000000000 + }, + "-inf" : { + "f32" : 0xFF800000, + "f64" : 0xFFF0000000000000 + }, + "nan" : { + "f32" : 0x7FC00000, + "f64" : 0x7FF8000000000000 + } +} + class ParallelLoopBuffer(IndentedBuffer): - def indent(self, offset=1, outer_loop=True): + def indent(self, offset=1, attribute="", suffix=""): @contextlib.contextmanager def ctx(): - attribute = "{outer_loop=true}" if outer_loop else "{accumulation_loop=true}" for _ in range(offset): self.writeline("{") self._indent += 1 for _ in range(-offset): + if suffix: + self.writeline(suffix) self._indent -= 1 self.writeline("} " + attribute) yield @@ -70,6 +98,8 @@ def ctx(): self.writeline("{") self._indent += 1 for _ in range(offset): + if suffix: + self.writeline(suffix) self._indent -= 1 self.writeline("} " + attribute) @@ -98,28 +128,34 @@ def is_mlir_arg_out(value): def is_mlir_arg_inout(value): return MLIRKernelArgs.MLIR_ARGS_INOUT & value + @staticmethod + def get_mlir_shape(info): + tensor_type = DTYPE_TO_MLIR[info[0]] + return f"memref<{info[1]}x{tensor_type}>" + def mlir_argdefs(self, extra_node=dict()): buffer_types = {} for x in V.graph.buffers: if not isinstance(x.layout, MultiOutputLayout): # FIXME: MultiOutputLayout should be handled - buffer_types[x.get_name()] = [x.get_dtype(), x.get_numel()] + buffer_types[x.get_name()] = [x.get_dtype(), x.get_numel(), x.get_size(), x.get_stride()] for name, val in V.graph.graph_inputs.items(): if isinstance(val, sympy.Expr): - buffer_types[name] = [get_sympy_Expr_dtype(val), 1] + buffer_types[name] = [get_sympy_Expr_dtype(val), 1, [1], [1]] else: - buffer_types[name] = [val.get_dtype(), val.get_numel()] + buffer_types[name] = [val.get_dtype(), val.get_numel(), val.get_size(), val.get_stride()] buffer_types.update( - {name: val.dtype for name, val in V.graph.constants.items()} + {name: [val.dtype, 1, [1], [1]] for name, val in V.graph.constants.items()} ) buffer_types.update( - {name: [val.get_dtype(), val.get_numel()] for name, val in extra_node.items()} + {name: [val.get_dtype(), val.get_numel(), val.get_size(), val.get_stride()] for name, val in extra_node.items()} ) call_args = [] arg_defs = [] arg_attributes = [] def set_info(outer, inner, arg_type): - arg_defs.append(f"%{inner}: memref<{buffer_types[outer][1]}x{DTYPE_TO_MLIR[buffer_types[outer][0]]}>") + mlir_shape = self.get_mlir_shape(buffer_types[outer]) + arg_defs.append(f"%{inner}: {mlir_shape}") call_args.append(outer) arg_attributes.append([outer] + [[arg_type] + buffer_types[outer]]) @@ -141,18 +177,160 @@ def set_info(outer, inner, arg_type): set_info(outer, inner, self.MLIR_ARGS_VAR) return arg_defs, call_args, arg_attributes, buffer_types +class MLIRMultiDimTile(): + def __init__(self, tile_size, vector_lane, vlane_split_axis=None, vlane_stride=None, vec_size=None): + self.name = "" + self._tile_size = list(tile_size) + self._tile_stride = None + self.tile_axis_order = list(range(len(tile_size))) + self.vec_size = vec_size + self.update_tile_stride() + + # Vector lane mapping config + self.vector_lane = vector_lane + self.vlane_split_axis = vlane_split_axis + self.vlane_stride = vlane_stride + self.implicit_dim_size = None + self.nr_rdim = 0 + + def set_name(self, name: str): + self.name = name + + def set_tile_size(self, tile_size, tile_axis_order=None): + self._tile_size = tile_size + if tile_axis_order is None: + self.tile_axis_order = list(range(len(tile_size))) + else: + self.tile_axis_order = tile_axis_order + self.update_tile_stride() + + def set_tile_size_stride(self, tile_size, tile_stride): + self._tile_size = tile_size + self._tile_stride = tile_stride + + def get_name(self) -> str: + return self.name + + def get_tile_size(self): + return self._tile_size + + def get_numel(self): + """ + Return size of multi-dimensional tile + """ + size = 1 + for dim_size in self._tile_size: + size *= dim_size + return size + + def get_numel_per_lane(self): + tile_size_per_lane = self.get_tile_size_per_lane() + size = 1 + for dim_size in tile_size_per_lane: + size *= dim_size + return size + + def update_tile_stride(self): + strides = [1] * len(self._tile_size) + init = 1 + + original_indices = list(range(len(self.tile_axis_order))) + sorted_pairs = sorted( + zip(self.tile_axis_order, self._tile_size, original_indices), + key=lambda x: x[0], reverse=True + ) + for _, size, original_indices in sorted_pairs: + strides[original_indices] = init + init *= size + self._tile_stride = strides + + def get_tile_stride(self): + return self._tile_stride + + def get_tile_size_per_lane(self): + tile_size_per_lane = list(self._tile_size) + if self.vlane_split_axis < 0 or self.vlane_split_axis >= len(tile_size_per_lane): + raise AssertionError("Not allowed split_axis") + used_vlane = self.get_used_vlane() + tile_size_per_lane[self.vlane_split_axis] = \ + self.div_round_up(tile_size_per_lane[self.vlane_split_axis], used_vlane) + return tile_size_per_lane + + def get_nr_dim(self): + """ + Return number of dimensions + """ + return len(self._tile_size) + + def get_dim_size(self, index): + if isinstance(index, int): + return self._tile_size[index] + elif "index" in str(index): + return self._tile_size[int(str(index)[5:])] + raise NotImplementedError("Unsupported format of index") + + def get_mlir_shape(self, dtype): + str_tile_size = [str(dim) for dim in self._tile_size] + shape = "x".join(str_tile_size) + return f"memref<{shape}x{dtype}, 1>" + + def get_mlir_vshape(self, mlir_dtype): + return f"vector<{self.get_compute_vec_size()}x{mlir_dtype}>" if self.get_compute_vec_size() > 1 else f"{mlir_dtype}" + + def get_used_vlane(self): + """ + Return number of used vector lane + """ + if self.vlane_split_axis < 0 or self.vlane_split_axis >= len(self._tile_size): + raise AssertionError("Not allowed split_axis") + return min(self.div_round_up(self._tile_size[self.vlane_split_axis], self.vlane_stride), self.vector_lane) + + def get_vlane_stride(self): + return self.vlane_stride + + def get_compute_vec_size(self): + # Granule size used in compute loop + if self.vec_size is not None: + return self.vec_size + if self.nr_rdim: + assert self.nr_rdim==1 + val = self.get_numel_per_lane() // self._tile_size[-1] + if self.get_numel_per_lane() >= val * 8: + return val*8 + elif self.get_numel_per_lane() >= val * 4: + return val*4 + elif self.get_numel_per_lane() >= val * 2: + return val*2 + return val + if (self.get_numel_per_lane() // self.vlane_stride) >= 8: + return self.vlane_stride * 8 + if (self.get_numel_per_lane() // self.vlane_stride) >= 4: + return self.vlane_stride * 4 + if (self.get_numel_per_lane() // self.vlane_stride) >= 2: + return self.vlane_stride * 2 + return self.vlane_stride + + @staticmethod + def div_round_up(size, round_val): + return (size + round_val - 1) // round_val + +class MLIRWrapperKenrelGroup(cpp.KernelGroup): + def __init__(self): + super().__init__() + self.args = MLIRKernelArgs() + self.tile_desc : MLIRMultiDimTile = None + + def set_tile_info(self, tile_desc : MLIRMultiDimTile): + self.tile_desc = tile_desc + class BaseMLIRHardwareInfo(): def __init__(self): # Default HW setting - self.vector_lane = 128 - self.spad_info = { - "spad_vaddr" : 0xD0000000, - "spad_paddr" : 0xD0000000, - "spad_size" : 128 << 10 # 128KB per Lane - } - self.precision = 4 # 32bit - self.num_cores = 1 - self.vlen = 32 // self.precision # 256bits / 32bits = 8 [elements] + self.vector_lane = extension_config.CONFIG_VECTOR_LANE + self.spad_info = extension_config.CONFIG_SPAD_INFO + self.precision = extension_config.CONFIG_PRECISION + self.num_cores = extension_config.CONFIG_NUM_CORES + self.vlen = extension_config.CONFIG_VLEN class BaseMLIRKernel(common.Kernel, BaseMLIRHardwareInfo): newvar_prefix = "%" @@ -161,18 +339,42 @@ class BaseMLIRKernel(common.Kernel, BaseMLIRHardwareInfo): load_format = None store_format = None - def __init__(self, args=None): - super().__init__(args) + def __init__(self, kernel_group, reason=None): + super().__init__(kernel_group.args) + self.kernel_group = kernel_group + # Kernel iteration range info + self.call_ranges = None + self.ranges = None + self.reduction_depth = None + self.itervars = None + # Code buffer self.vector_compute = IndentedBuffer() self.reductions_suffix = IndentedBuffer() self.cse = common.CSE(self.newvar_prefix, self.suffix) - self.tile_row = extension_config.CONFIG_TILE_ROW - if self.tile_row == -1: - self.tile_row = self.vlen * self.vector_lane - self.tile_col = extension_config.CONFIG_TILE_COL - if self.tile_col == -1: - self.tile_col = 8 # FIXME: tile_col is not always vector_lane * vlen - self.var_info = {} + # MLIR SSA tracker + self.var_info = {} # MLIR variable info + self.buffer_types : dict = None # format: dtype, numel, size, stride + self.compute_idx = "compute_idx" + self.compute_body_loop = LoopLevel(self.compute_idx, 1) + self.prologue_compute_body_loop = LoopLevel(self.compute_idx, 1) + self.recodegen = reason # spad overflow, tile size, vlane stride + self.stop_autotune = False + + def set_ranges(self, lengths, reduction_lengths): + if self.call_ranges: + assert self.call_ranges == tuple(lengths) + tuple( + reduction_lengths + ), f"{self.call_ranges} == {tuple(lengths)} + {tuple(reduction_lengths)}" + assert self.reduction_depth == len(lengths) + else: + self.call_ranges = tuple(lengths) + tuple(reduction_lengths) + self.ranges = [self.rename_indexing(x) for x in self.call_ranges] + self.itervars = [sympy.Symbol(f"index{n}") for n in range(len(self.ranges))] + self.reduction_depth = len(lengths) + return ( + self.itervars[: self.reduction_depth], + self.itervars[self.reduction_depth :], + ) def load(self, name: str, index: sympy.Expr): raise NotImplementedError() @@ -186,16 +388,335 @@ def store(self, name, index, value, mode=None): def reduction(self, dtype, src_dtype, reduction_type, value): raise NotImplementedError() - def check_dtype_in_args(self, args): - dtype = torch.float32 # default dtype - for arg in args: - if arg in list(DTYPE_TO_MLIR.keys()): - dtype = arg - return dtype + def indirect_indexing(self, index_var, size, check): + raise NotImplementedError() + + def codegen_global_init(self): + raise NotImplementedError() + + def codegen_loops(self): + raise NotImplementedError() + + def call_kernel(self, kernel_name): + wrapper = V.graph.wrapper_code + _, call_args, _, _ = self.kernel_group.args.mlir_argdefs() + # generate the code to call this + wrapper.generate_kernel_call(kernel_name, call_args, cuda=False) + + def is_modular_indexing(self, expr): + return "ModularIndexing" in str(expr) + + def compute_tile_size(self, nodes, vars, reduction_vars): + # Handle implict dims. Input operand could have larger dimension space. + implicit_ranges = False + target_operand : MemoryDep = None + implicit_dim_size = defaultdict(list) + for read_operand in nodes[0].read_writes.reads: + read_operand : MemoryDep + if isinstance(read_operand, StarDep) or isinstance(read_operand, WeakDep): # FIXME: WeakDep & StarDep are not supported (MoE case) + continue + read_index = read_operand.index + for arg in read_index.args: + if "ModularIndexing" in str(arg) or "//" in str(arg): + implicit_ranges = True + target_operand = read_operand + break + + if implicit_ranges: + #print("This operation contain implicit dimension space!") + linearized_stride = [1] * len(target_operand.var_names) + for i in range(len(target_operand[3])-2, -1, -1): + linearized_stride[i] = linearized_stride[i+1] * target_operand[3][i+1] + + linearized_index = sympy.Integer(0) + for dim, stride in zip(target_operand[2], linearized_stride): + linearized_index += stride * dim + + new_dim_expression = [] + new_dim_size = [] + for arg in target_operand.index.args: + if len(arg.free_symbols) != 1: + raise NotImplementedError("Not supporting this view operation...!") + + if arg.is_Mul and arg.args[0].is_number: + arg = arg.args[1] + + if isinstance(arg, ModularIndexing): + modular_expr = ModularIndexing(arg.args[0], arg.args[1], arg.args[2]) + elif arg.is_symbol: + modular_expr = ModularIndexing(arg, 1, target_operand.ranges[arg]) + elif "//" in str(arg): + modular_expr = ModularIndexing(arg.args[0], arg.args[1], target_operand.ranges[arg.args[0]]//arg.args[1]) + else: + raise NotImplementedError("What is this case?") + new_dim_expression.append(modular_expr) + new_dim_size.append(modular_expr.args[2]) + implicit_dim_size[int(str(modular_expr.args[0])[1:])].append(int(modular_expr.args[2])) + + # Sanity check + for dim, sub_dims in implicit_dim_size.items(): + sz = reduce(mul, sub_dims, 1) + if sz != target_operand[3][dim]: + raise NotImplementedError("Not supporting type...") + + vlane_split_axis = len(vars) - 1 # Set split_axis as a last normal loop not reduction loop + + # FIXME: Naive decrease tile size + def decrease_tile_size(tile_size, vlane_split_axis): + is_decreased = False + + # Decrease vlane_split_axis when it is too large + if tile_size[vlane_split_axis] > vlane_stride * self.vector_lane: + tile_size[vlane_split_axis] = int(tile_size[vlane_split_axis] // 2) + return tile_size + + for i in range(len(tile_size)): + if i == vlane_split_axis: + continue + if tile_size[i] > 1: + tile_size[i] = int(tile_size[i] // 2) + is_decreased = True + break + + # Decrease vlane_split_axis at the end to maximize the vlane usage + if not is_decreased: + if tile_size[vlane_split_axis] > 1: + tile_size[vlane_split_axis] = int(tile_size[vlane_split_axis] // 2) + return tile_size + + # Dummy tile size + def dummy_tile_size(): + tile_size = [1] * (len(vars) + len(reduction_vars)) + if len(tile_size) == 2: + tile_size[-1] = vlane_stride * self.vector_lane + tile_size[-2] = 2 * self.vector_lane + elif len(tile_size) == 0: # Scalar + tile_size = [1] + self.ranges = [1] + elif len(tile_size) == 1: + tile_size[0] = 2 * vlane_stride * self.vector_lane + elif len(tile_size) == 3: + tile_size[-1] = self.vector_lane + tile_size[-2] = 4 * self.vector_lane + tile_size[-3] = 2 + elif len(tile_size) == 4: + tile_size[-1] = self.vector_lane + tile_size[-2] = 4 * self.vector_lane + tile_size[-3] = 2 + tile_size[-4] = 1 + else: + raise NotImplementedError("dummy tile size fail!") + return tile_size + + vlane_stride = extension_config.CONFIG_VECTOR_LANE_STRIDE + if self.recodegen is None: + tile_size = dummy_tile_size() + else: + if self.recodegen == "spad_overflow": + tile_size = self.kernel_group.tile_desc.get_tile_size() + decrease_tile_size(tile_size, vlane_split_axis) + elif self.recodegen == "vlane_stride": + tile_size = dummy_tile_size() + elif "tile_size" in self.recodegen: + dim = int(self.recodegen.split("_")[-1]) + tile_size = self.kernel_group.tile_desc.get_tile_size() # TODO: + tile_size[dim] = tile_size[dim] * 2 + else: + raise NotImplementedError(f"Unknown recodegen reason: {self.recodegen}") + + # FIXME: Not considering removed buffers + n_buffer = sum( + len(node.read_writes.reads) + len(node.read_writes.writes) + for node in nodes + ) + + spad_overflow = True + # Find proper tile size + while spad_overflow: + # Adjust tile size to avoid too much paddings + for i in range(1, len(tile_size)+1): + target_range = self.ranges[-i] + if implicit_ranges: + target_range = implicit_dim_size[len(tile_size)-i][-1] + + if tile_size[-i] > target_range: + remains = (target_range % vlane_stride) + self.stop_autotune = True + tile_size[-i] = target_range + if remains: + tile_size[-i] += vlane_stride - remains + + # Adjust tile size + for i in range(len(vars)): + if tile_size[i] >= self.vector_lane: # maximize used vector lane + vlane_split_axis = i + used_vlane = min((tile_size[vlane_split_axis] + vlane_stride - 1) // vlane_stride, self.vector_lane) + padded_size = used_vlane * vlane_stride + tile_size[vlane_split_axis] = ((tile_size[vlane_split_axis] + padded_size - 1) // padded_size) * padded_size + + # Check spad overflow + spad_usage_per_vlane = n_buffer * math.prod(tile_size) * self.precision // used_vlane + if spad_usage_per_vlane >= self.spad_info["spad_size"]: + new_tile_size = decrease_tile_size(tile_size.copy(), vlane_split_axis) + if new_tile_size == tile_size: + raise NotImplementedError("Error: Cannot find proper tile size") + tile_size = new_tile_size + spad_overflow = True + self.stop_autotune = True # for auto-tune + continue + else: + spad_overflow = False + + # Maximize the utilizaiotn of vectorlane + if len(reduction_vars): + minimum_stride = max(self.roundup_vectorlane(tile_size[vlane_split_axis]) // self.vector_lane, 2) + vlane_stride = min(minimum_stride, 8) + + # Handle scalar case + if len(self.ranges)==1 and self.ranges[0] == 1: + vlane_stride = 1 + vlane_split_axis = 0 + tile_size[0] = 1 + elif vlane_split_axis == -1: + vlane_split_axis = 0 + vlane_stride = tile_size[0] + + # Select tile info. + # Note: Kernel Group have to share same tile desc for fusion + tile_desc = MLIRMultiDimTile(tile_size, self.vector_lane) + tile_desc.vlane_split_axis = vlane_split_axis + tile_desc.vlane_stride = vlane_stride + tile_desc.implicit_dim_size = implicit_dim_size + tile_desc.nr_rdim = len(reduction_vars) + return tile_desc + + def codegen_nodes(self, nodes, kernel_name): + _, (group, reduction_group) = max( + nodes, key=lambda x: int(x.is_reduction()) + ).group + + # Set node range info + vars, reduction_vars = self.set_ranges(group, reduction_group) + tile_desc = self.compute_tile_size(nodes, vars, reduction_vars) + self.compute_body_loop.size = tile_desc.get_numel_per_lane() + self.compute_body_loop.step = tile_desc.get_compute_vec_size() + self.kernel_group.set_tile_info(tile_desc) + + _, _, _, self.buffer_types = self.kernel_group.args.mlir_argdefs() + with self as kernel: + for node in nodes: + node.run(vars, reduction_vars) + V.graph.removed_buffers |= self.removed_buffers + # V.graph.inplaced_to_remove |= self.inplaced_to_remove + src_code = self.codegen_kernel(kernel_name=kernel_name) + self.meta_kernel() + return src_code + + def run_bench(self, nodes, kernel_name, src_code): + _, _, arg_attributes, _ = self.kernel_group.args.mlir_argdefs() + input_call_args = tuple(self.args.input_buffers.keys()) + output_call_args = tuple(self.args.output_buffers.keys()) + full_input_nodes = tuple([V.graph.get_buffer(k) for k in input_call_args]) + full_output_nodes = tuple([V.graph.get_buffer(k) for k in output_call_args]) + + bmreq = MLIRBenchmarkRequest( + kernel_name=kernel_name, + input_tensor_meta=TensorMeta.from_irnodes(full_input_nodes), + output_tensor_meta=TensorMeta.from_irnodes(full_output_nodes), + extra_args={ + "vector_lane" : self.vector_lane, + "spad_info": self.spad_info, + "vlen" : self.vlen, + "arg_attributes" : arg_attributes + }, + source_code=src_code, + ) + dummy_inputs = [rand_strided(meta.sizes,meta.strides,dtype=meta.dtype, extra_size=meta.offset).to(device=nodes[0].get_device()) for meta in bmreq.input_tensor_meta] + dummy_outputs = [rand_strided(meta.sizes,meta.strides,dtype=meta.dtype, extra_size=meta.offset).to(device=nodes[0].get_device()) for meta in bmreq.output_tensor_meta] + return bmreq.make_run_fn(dummy_inputs, dummy_outputs) + + def codegen_kernel(self, kernel_name): + arg_defs, _, _, _ = self.kernel_group.args.mlir_argdefs() + arg_defs = ",\n".ljust(25).join(arg_defs) + code = common.BracesBuffer() + + #TODO:. kernel name custom + kernel_decl_name = kernel_name if V.graph.cpp_wrapper else "kernel" + + code.splice(self.codegen_global_init()) + code.writeline(f'func.func @{kernel_decl_name}({arg_defs})') + with code.indent(): + for old, new in self.kernel_group.args.aliases(): + code.writeline(f"auto {old} = {new};") + # Loop body part + code.splice(self.codegen_loops()) + return code.getvalue() + + def meta_kernel(self): + wrapper = V.graph.wrapper_code + _, _, arg_attributes, _ = self.kernel_group.args.mlir_argdefs() + wrapper.add_import_once('\nprint(f\'Wrapper Codegen Path = {__file__}\')') + # Dump loop and load/store information + wrapper.add_import_once(f"arg_attributes = {arg_attributes}") + return arg_attributes + + def get_constant_vector(self, expr): + constant_vector = [[int(expr.coeff(var)),None] for var in self.itervars] + return constant_vector + + def get_constant_vector2(self, expr): + # Case 0. symbol ex) index 0 + # Case 1. inner product form ex) 16 * index0 + 1 * index1 + # Case 2. Complicated form ex) 16 * index0 + 8 * (index//4) + (index % 4) + constant_vector = [] + if expr.is_symbol: + constant_vector.append(tuple([1, expr])) + return constant_vector + + for arg in expr.args: + if arg.is_symbol: + constant_vector.append(tuple([1,arg])) + continue + if len(arg.args) == 0: #TODO: check this + continue + if arg.args[0].is_number: + constant_vector.append(arg.args) + else: + constant_vector.append([1, arg]) + + return constant_vector + + def find_node_by_name(self, name): + if name in V.graph.graph_inputs: + return V.graph.graph_inputs[name] + else: + for output_node in V.graph.graph_outputs: + if output_node.data.name == name: + return output_node + + def is_scalar(self, name): + return self.buffer_types[name][1] == 1 + + def roundup_vectorlane(self, size, amp=1): + return ((size + self.vector_lane - 1) // self.vector_lane) * self.vector_lane * amp def register_var_info(self, var, var_info): self.var_info[var] = var_info + def rename_indexing(self, index) -> sympy.Expr: + # adds the necessary kernel args for index expressions + # and renames variables in index expressions to kernel arg names + if isinstance(index, (list, tuple)): + return [self.rename_indexing(x) for x in index] + index = V.graph.sizevars.simplify(index) + sorted_symbols = sorted(index.free_symbols, key=lambda s: s.name) + replacements = { + x: self.kernel_group.args.size(x) + for x in sorted_symbols + if x.name.startswith("s") or x.name.startswith("ps") + } + return sympy_subs(index, replacements) + def __enter__(self): class CSEProxy: self.name = "CSEProxy" @@ -203,22 +724,16 @@ class CSEProxy: @staticmethod def __getattr__(name: str) -> Callable[..., common.CSEVariable]: # type: ignore[misc] def inner(*args, **kwargs): - # TritonTemplateKernel has no current_node - buf_bounds = ValueRanges.unknown() - if hasattr(V.interpreter, "current_node"): - fx_node = V.interpreter.current_node - assert isinstance(self.node_to_bounds, dict) - buf_bounds = self.node_to_bounds.get( - fx_node, ValueRanges.unknown() - ) code, ret_info = getattr(parent_handler, name)(*args, var_info=self.var_info) csevar = self.cse.generate( self.compute, code, - bounds=buf_bounds, + bounds=ValueRanges.unknown(), + assignment=(ret_info[0] is not None) ) - self.register_var_info(csevar, ret_info) - csevar.update_on_args(name, args, kwargs) + if ret_info[0] is not None: + self.register_var_info(csevar, ret_info) + csevar.update_on_args(name, args, kwargs) return csevar return inner @@ -226,7 +741,7 @@ def inner(*args, **kwargs): @staticmethod def indirect_indexing(index_var, size, check=True): # Skip CSE since this doesn't return an expression - return sympy_symbol(str(index_var)) # type: ignore[attr-defined] + return self.indirect_indexing(index_var, size, check) @staticmethod def load(name: str, index: sympy.Expr): @@ -239,7 +754,11 @@ def load(name: str, index: sympy.Expr): store_cache = self.cse.store_cache if name in store_cache: return store_cache[name] - return self.load(name, index) + key = name+str(index) + if key not in self.cse.cache: + result = self.load(name, index) + self.cse.cache[key] = result + return self.cse.cache[key] @staticmethod def store(name, index, value, mode=None): @@ -267,6 +786,14 @@ def store_reduction(name, index, value): def reduction(dtype, src_dtype, reduction_type, value): return self.reduction(dtype, src_dtype, reduction_type, value) + @staticmethod + def _index_expr(tile_size, buffer, renamed_expression, index): + return self._index_expr(tile_size, buffer, renamed_expression, index) + + @staticmethod + def index_expr(index, dtype): + return self.index_expr(index, dtype) + @staticmethod def bucketize( values, @@ -300,16 +827,49 @@ def bucketize( self.exit_stack.enter_context(V.set_kernel_handler(self)) return self - def rename_indexing(self, index) -> sympy.Expr: - # adds the necessary kernel args for index expressions - # and renames variables in index expressions to kernel arg names - if isinstance(index, (list, tuple)): - return [self.rename_indexing(x) for x in index] - index = V.graph.sizevars.simplify(index) - sorted_symbols = sorted(index.free_symbols, key=lambda s: s.name) - replacements = { - x: self.args.size(x) - for x in sorted_symbols - if x.name.startswith("s") or x.name.startswith("ps") - } - return sympy_subs(index, replacements) + +@dataclasses.dataclass +class LoopLevel: + var: sympy.Expr + size: sympy.Expr + start: int = 0 + step: int = 1 + reduction_vars: Dict[str, str] = dataclasses.field(default_factory=dict) + affine_yield: Dict[str, str] = dataclasses.field(default_factory=dict) + + def lines(self): + if len(self.reduction_vars): + acc = ', '.join([f"%{acc.name}" for acc in self.reduction_vars.keys()]) + args = ', '.join([f"%{iter.name} = %{init.name}" for (_, iter, init, _) in self.reduction_vars.values()]) + dtype = ', '.join([f"{dtype}" for (_, _, _, dtype) in self.reduction_vars.values()]) + line = f"{acc} = affine.for %{self.var} = {self.start} to {self.size} step {self.step} iter_args({args}) -> ({dtype})" + else: + line = f"affine.for %{self.var} = {self.start} to {self.size} step {self.step}" + + return [line] + + def epilogue_line(self): + if len(self.affine_yield): + vars = ', '.join([f"%{name}" for name, _ in self.affine_yield.items()]) + reduced_shapes = ', '.join([f"{shape}" for _, shape in self.affine_yield.items()]) + return f"affine.yield {vars} : {reduced_shapes}" + return "" + +@dataclasses.dataclass +class LoopNest: + loops: List[LoopLevel] + + def __bool__(self): + return bool(self.loops) + + def mark_reduction(self, reduction_vars, affine_yield=dict()): + for loop in self.loops: + loop.reduction_vars = reduction_vars + loop.affine_yield = affine_yield + + def mark_parallel(self, par_depth): + loops = self.loops + loops[0].parallel = par_depth + for i in range(1, par_depth): + loops[i].collapsed = True + loops[0].simd = loops[par_depth - 1].simd \ No newline at end of file diff --git a/PyTorchSimFrontend/mlir/mlir_conv_mt_template.py b/PyTorchSimFrontend/mlir/mlir_conv_mt_template.py new file mode 100644 index 00000000..c5ec004c --- /dev/null +++ b/PyTorchSimFrontend/mlir/mlir_conv_mt_template.py @@ -0,0 +1,346 @@ +import os +import math +from sympy import Symbol, Number +from typing import List, Optional + +from PyTorchSimFrontend.mlir.mlir_common import MLIRKernelArgs +from PyTorchSimFrontend.mlir.mlir_template import MLIRTemplate +from PyTorchSimFrontend.mlir.mlir_template import MLIRTemplateKernel +from torch._inductor.ir import IRNode +from torch._inductor.codecache import write_atomic +import PyTorchSimFrontend.extension_codecache as extension_codecache +from PyTorchSimFrontend.mlir import mlir_common +from torch._inductor.codecache import get_hash +from PyTorchSimFrontend import extension_config + +CONV_TEMPLATE = r""" +// Multi Channel Tile Conv2D kernel +// BATCH = {{ BATCH }} +// I_C = {{ I_C }} +// I_H = {{ I_H }} +// I_W = {{ I_W }} +// O_C = {{ O_C }} +// K_H = {{ K_H }} +// K_W = {{ K_W }} +// O_H = {{ O_H }} +// O_W = {{ O_W }} +// TILE_M = {{ TILE_M }} +// TILE_N = {{ TILE_N }} +// TILE_K = {{ TILE_K }} +// TILE_I_H={{ TILE_I_H }}, +// TILE_I_W={{ TILE_I_W }}, +// TILE_O_H={{ TILE_O_H }}, +// TILE_O_W={{ TILE_O_W }}, +// TILE_K_H={{ TILE_K_H }}, +// TILE_K_W={{ TILE_K_W }}, +// SUB_TILE_M={{ SUB_TILE_M }}, +// SUB_TILE_N={{ SUB_TILE_N }}, +// SUB_TILE_I_W={{ SUB_TILE_I_W }}, +// SUB_TILE_K_H={{ SUB_TILE_K_H }}, +// SUB_TILE_K_W={{ SUB_TILE_K_W }}, +// PADDING_H = {{ PADDING_H }} +// PADDING_W = {{ PADDING_W }} +// STRIDE_H = {{ STRIDE_H }} +// STRIDE_W = {{ STRIDE_W }} +// DATA_STYPE = {{ DATA_STYPE }} + +#map_I_H = affine_map<(d0, d1) -> (d0 * {{ STRIDE_H }} + d1)> +#offset_w_map = affine_map<(d0, d1) -> (d0 * {{ kernel.get_spad_size_per_lane(1 * TILE_K, TILE_N) }} + d1 * {{ kernel.get_spad_size_per_lane(TILE_K, TILE_N) }})> +#offset_x_map = affine_map<(d0, d1) -> (d0 * {{ kernel.get_spad_size_per_lane(TILE_O_W * TILE_M, TILE_K) }} + d1 * {{ kernel.get_spad_size_per_lane(TILE_M, TILE_K) }})> +#offset_y_map = affine_map<(d0, d1) -> (d0 * {{ kernel.get_spad_size_per_lane(TILE_O_W * TILE_M, TILE_N) }} + d1 * {{ kernel.get_spad_size_per_lane(TILE_M, TILE_N) }})> +{{kernel.def_global_vars()}} + +func.func @{{ KERNEL_NAME }}{{kernel.def_conv_kernel(inputs=[X, W, BIAS], outputs=[Y], names_str="X, W, Bias, Y", padded_input_size=PADDED_INPUT_SIZE, input_reorder=input_reorder)}} { + {{ kernel.def_sram_buffer("X", X_tile_desc, indent_size=2) }} + {{ kernel.def_sram_buffer("W", W_tile_desc, indent_size=2) }} + {{ kernel.def_sram_buffer("Y", Y_tile_desc, indent_size=2) }} + %v0 = arith.constant dense<0.0> : vector<{{ kernel.get_spad_size_per_lane(TILE_O_H * TILE_M, TILE_N) }}xf32> + %c0 = arith.constant 0 : index + {{- kernel.def_local_vars(indent_size=2) }} + + affine.for %tile_m = 0 to {{ BATCH }} step {{ TILE_M }} { + affine.for %tile_n = 0 to {{ O_C }} step {{ TILE_N }} { + affine.for %o_h = 0 to {{ O_H }} step {{ TILE_O_H }} { + affine.for %o_w = 0 to {{ O_W }} step {{ TILE_O_W }} { + // Initialize output + {%- if BIAS %} + {{ kernel.def_dma_op("MVIN", "Bias", Bias_idx, Y_tile_desc, subtile_size=[SUB_TILE_M, SUB_TILE_N, TILE_O_H, TILE_O_W], indent_size=10) }} + {%- else %} + affine.vector_store %v0, %output_buffer[%c0, %c0, %c0, %c0] : {{ Y_tile_desc.get_mlir_shape(DATA_STYPE) }}, vector<{{ kernel.get_spad_size_per_lane(TILE_O_H * TILE_M, TILE_N) }}xf32> + {%- endif %} + affine.for %k_h = 0 to {{ K_H }} step {{ TILE_K_H }} { + affine.for %tile_k = 0 to {{ I_C * K_W }} step {{ TILE_K }} { + %index_i_h = affine.apply #map_I_H(%o_h, %k_h) + // Load input matrix + {{ kernel.def_dma_op("MVIN", "X", X_idx, X_tile_desc, subtile_size=[SUB_TILE_I_H, SUB_TILE_I_W, SUB_TILE_M, SUB_TILE_K], indent_size=14) }} + {{ kernel.def_dma_op("MVIN", "W", W_idx, W_tile_desc, subtile_size=[SUB_TILE_K_H, SUB_TILE_K_W, SUB_TILE_K, SUB_TILE_N], indent_size=14) }} + // Compute body part + affine.for %tile_k_h = 0 to {{ TILE_K_H }} { // loop order should be fixed for timing simulation. Do not change this order. + affine.for %tile_k_w = 0 to 1 { + %offset_w = affine.apply #offset_w_map(%tile_k_h, %tile_k_w) + %W_buffer = memref.reinterpret_cast %weight_buffer to offset: [%offset_w], sizes: [{{ TILE_K }}, {{ TILE_N }}], strides: [{{ TILE_N }}, 1] : {{ W_tile_desc.get_mlir_shape(DATA_STYPE) }} to memref<{{ TILE_K }}x{{ TILE_N }}xf32, strided<[{{ TILE_N }}, 1], offset: ?>, 1> + affine.for %tile_o_h = 0 to {{ TILE_O_H }} { + affine.for %tile_o_w = 0 to {{ TILE_O_W }} { + %tile_i_h = affine.apply #map_I_H(%tile_o_h, %tile_k_h) + %offset_x = affine.apply #offset_x_map(%tile_i_h, %tile_o_w) + %offset_y = affine.apply #offset_y_map(%tile_o_h, %tile_o_w) + %X_buffer = memref.reinterpret_cast %input_buffer to offset: [%offset_x], sizes: [{{ TILE_M }}, {{ TILE_K }}], strides: [{{ TILE_K }}, 1] : {{ X_tile_desc.get_mlir_shape(DATA_STYPE) }} to memref<{{ TILE_M }}x{{ TILE_K }}xf32, strided<[{{ TILE_K }}, 1], offset: ?>, 1> + %Y_buffer = memref.reinterpret_cast %output_buffer to offset: [%offset_y], sizes: [{{ TILE_M }}, {{ TILE_N }}], strides: [{{ TILE_N }}, 1] : {{ Y_tile_desc.get_mlir_shape(DATA_STYPE) }} to memref<{{ TILE_M }}x{{ TILE_N }}xf32, strided<[{{ TILE_N }}, 1], offset: ?>, 1> + linalg.matmul ins(%X_buffer, %W_buffer : memref<{{ TILE_M }}x{{ TILE_K }}xf32, strided<[{{ TILE_K }}, 1], offset: ?>, 1>, memref<{{ TILE_K }}x{{ TILE_N }}xf32, strided<[{{ TILE_N }}, 1], offset: ?>, 1>) + outs(%Y_buffer : memref<{{ TILE_M }}x{{ TILE_N }}xf32, strided<[{{ TILE_N }}, 1], offset: ?>, 1>) + } { inner_loop=true } + } { inner_loop=true } + } { inner_loop=true } + } { inner_loop=true } + } { accumulation_loop=true, subtile_loop="k" } + } { accumulation_loop=true } + // Store output matrix + {{kernel.store_output(indent_size=10)}} + } { outer_loop=true } + } { outer_loop=true } + } { outer_loop=true, subtile_loop="n" } + } { outer_loop=true, subtile_loop="m" } + return +} +""" + +WRAPPER_TEMPLATE = r""" +def {{ FUNC_NAME }}{{kernel.def_wrapper()}}: + # Padding input + padded_shape = list(X.shape) + padded_shape[2] += 2 * {{ PADDING_H }} + padded_shape[3] += 2 * {{ PADDING_W }} + X_padding = torch.zeros(padded_shape, device=X.device) + X_padding[:, :, {{ PADDING_H }}:X.shape[2] + {{ PADDING_H }}, {{ PADDING_W }}:X.shape[3] + {{ PADDING_W }}] = X + + # Tanspose inputs + {%- for buf, name in kernel.get_conv_inputs().items() %} + {%- if name == "X" %} + {{ name }} = {{ name }}_padding.permute(2, 0, 3, 1).contiguous() # (BATCH, I_C, I_H, I_W) -> (I_H, BATCH, I_W, I_C) + {%- elif name == "W" %} + {{ name }} = {{ name }}.permute(2, 3, 1, 0).contiguous() # (O_C, I_C, K_H, K_W) -> (K_H, K_W, I_C, O_C) + {%- elif name == "Bias" %} + {{ name }} = {{ name }} + {%- endif %} + {%- endfor %} + + # Launch kernel + {{ KERNEL_NAME }} + {%- if BACKENDSIM_EAGER_MODE %} + yield ({{KERNEL_NAME}}, ) + {%- endif %} +""" + +class MLIRConvMultiTileTemplate(MLIRTemplate): + def __init__(self, input_nodes, layout, input_reorder=None, **kwargs): + super().__init__("kernel", input_nodes, layout, input_reorder) + self.stride = kwargs["stride"] + self.padding = kwargs["padding"] + self.dilation = kwargs["dilation"] + self.weight_shape = [str(i) for i in input_nodes[1].layout.size] + self.input_shape = [i for i in input_nodes[0].layout.size] + self.function_name = "Conv2D_" + "_".join(self.weight_shape)+ "_" \ + + "_".join([str(i) for i in self.stride]) \ + + "_" + "_".join([str(i) for i in self.padding]) \ + + "_" + "_".join([str(i) for i in self.dilation]) + self.kernel_args = ['X', 'W', 'Bias', 'Y'] + + def get_padded_input_size(self, X): + input_padded = list(X.layout.size) + input_padded[2] += 2 * self.padding[0] + input_padded[3] += 2 * self.padding[1] + return math.prod(input_padded) + + def render(self, + kernel: MLIRTemplateKernel, + template_buffer_node = None, + epilogue_nodes: Optional[List[IRNode]] = None, + **kwargs): + # Extract input arguments info + if template_buffer_node is not None: + self.output_node = template_buffer_node + self.kernel = kernel + self.epilogue_nodes = epilogue_nodes + + X, W = self.input_nodes[0], self.input_nodes[1] + Y = self.output_node + Bias = None if len(self.input_nodes) == 2 else self.input_nodes[2] + + if epilogue_nodes is not None: + extra_node_rw = { + item.name for epilogue_node in epilogue_nodes + for item in epilogue_node.read_writes.reads | epilogue_node.read_writes.writes + if item.name != Y.name + } + n_extra_node = len(extra_node_rw) if epilogue_nodes is not None else 0 + + BATCH, I_C, I_H, I_W = X.layout.size + O_C, _, K_H, K_W = W.layout.size + O_H = Y.layout.size[2] if template_buffer_node is None else template_buffer_node.layout.size[2] + O_W = Y.layout.size[3] if template_buffer_node is None else template_buffer_node.layout.size[3] + PADDING_H=self.padding[0] + PADDING_W=self.padding[1] + STRIDE_H=self.stride[0] + STRIDE_W=self.stride[1] + + # Select tile size adn template + conv_template = CONV_TEMPLATE + TILE_K_H, TILE_K_W, TILE_O_H, TILE_O_W, TILE_M, TILE_N, TILE_K, TILE_I_H, TILE_I_W, SUB_TILE_I_H, SUB_TILE_I_W, SUB_TILE_K_H, SUB_TILE_K_W, SUB_TILE_M, SUB_TILE_N, SUB_TILE_K, TOG_latency = self.select_tile(kernel, n_extra_node, BATCH, I_C, O_C, K_H, K_W, O_H, O_W) + SUB_TILE_N = TILE_N if TILE_N > 512 else SUB_TILE_N + TOG_latency = 8 if TOG_latency < 8 else TOG_latency + kernel.loop_size = [TOG_latency, TILE_N, TILE_K] + + # Prepare tile descriptors + vlane_stride = 1 + vlane_split_axis = 1 + X_tile_size = [TILE_I_H, TILE_O_W, TILE_M, TILE_K] + X_tile_stride = [TILE_O_W*TILE_M*TILE_K, TILE_M*TILE_K, 1, TILE_M] + X_tile_desc = mlir_common.MLIRMultiDimTile(X_tile_size, kernel.vector_lane, 3, vlane_stride) + X_tile_desc.set_tile_size_stride(X_tile_size, X_tile_stride) + X_tile_desc.set_name("input_buffer") + X_dim = [Symbol("index_i_h"), Symbol("o_w"), Symbol("tile_m"), Symbol("tile_k")] + X_idx = [X_dim[0]*(I_W+2*PADDING_W)*BATCH*I_C, X_dim[1]*I_C*STRIDE_W, X_dim[2]*I_C*(I_W+2*PADDING_W), X_dim[3]] + + W_tile_size = [TILE_K_H, 1, TILE_K, TILE_N] + W_tile_stride = [TILE_K * TILE_N, TILE_K * TILE_N, 1, TILE_K] + W_tile_desc = mlir_common.MLIRMultiDimTile(X_tile_size, kernel.vector_lane, 3, vlane_stride) + W_tile_desc.set_tile_size_stride(W_tile_size, W_tile_stride) + W_tile_desc.set_name("weight_buffer") + W_dim = [Symbol("k_h"), Symbol("k_w"), Symbol("tile_k"), Symbol("tile_n")] + W_idx = [W_dim[0]*K_W*I_C*O_C , Symbol("c0"), W_dim[2]*O_C, W_dim[3]] + + Y_tile_size = [TILE_M, TILE_N, TILE_O_H, TILE_O_W] + Y_tile_stride = [1, TILE_M, TILE_O_W * TILE_M * TILE_N, TILE_M * TILE_N] # N, C, H, W + Y_tile_desc = mlir_common.MLIRMultiDimTile(Y_tile_size, kernel.vector_lane, vlane_split_axis, vlane_stride) + Y_tile_desc.set_tile_size_stride(Y_tile_size, Y_tile_stride) + Y_tile_desc.set_name("output_buffer") + Y_dim = [Symbol("tile_m"), Symbol("tile_n"), Symbol("o_h"), Symbol("o_w")] + Y_idx = [Y_dim[0]*O_C*O_H*O_W, Y_dim[1]*O_H*O_W, Y_dim[2]*O_W, Y_dim[3]] + + # Extract Bias info + Bias_idx = [Number(0), Symbol("tile_n"), Number(0), Number(0)] + + kernel.render_options = dict( + KERNEL_NAME=self.name, + kernel=kernel, + X=X, W=W, Y=Y, BIAS=Bias, + PADDED_INPUT_SIZE=self.get_padded_input_size(X), + BATCH=BATCH, + I_C=I_C, + I_H=I_H, + I_W=I_W, + O_C=O_C, + K_H=K_H, + K_W=K_W, + O_H=O_H, + O_W=O_W, + TILE_M=TILE_M, + TILE_N=TILE_N, + TILE_K=TILE_K, + TILE_I_H=TILE_I_H, + TILE_I_W=TILE_I_W, + TILE_O_H=TILE_O_H, + TILE_O_W=TILE_O_W, + TILE_K_H=TILE_K_H, + TILE_K_W=TILE_K_W, + SUB_TILE_M=SUB_TILE_M, + SUB_TILE_N=SUB_TILE_N, + SUB_TILE_K=SUB_TILE_K, + SUB_TILE_I_H=SUB_TILE_I_H, + SUB_TILE_I_W=SUB_TILE_I_W, + SUB_TILE_K_H=SUB_TILE_K_H, + SUB_TILE_K_W=SUB_TILE_K_W, + PADDING_H=PADDING_H, + PADDING_W=PADDING_W, + STRIDE_H=STRIDE_H, + STRIDE_W=STRIDE_W, + X_tile_desc = X_tile_desc, + W_tile_desc = W_tile_desc, + Y_tile_desc = Y_tile_desc, + X_idx = X_idx, + W_idx = W_idx, + Bias_idx = Bias_idx, + DATA_STYPE="f32", + input_reorder=self.input_reorder + ) + + kernel.epilogue_info = dict( + output_node = self.output_node.name, + sram_var = "output_buffer", + dram_var = "Y", + dram_idx = Y_idx, + dram_tile_desc = Y_tile_desc, + dim_aliasing = {"index0":"tile_m", "index1":"tile_n", "index2":"o_h", "index3":"o_w"} + ) + kernel.exception_nodes["X"] = {"numel" : (I_W+2*PADDING_W)*(I_H+2*PADDING_H)*I_C*BATCH} + code = self._template_from_string(conv_template).render(**kernel.render_options) + kernel.add_loop_info([kernel.render_options["K_H"], kernel.render_options["K_W"], kernel.render_options["O_H"], kernel.render_options["O_W"], kernel.render_options["BATCH"], kernel.render_options["O_C"], kernel.render_options["I_C"]], [kernel.render_options["TILE_M"], kernel.render_options["TILE_N"], kernel.render_options["TILE_K"]]) + return code + + def select_tile(self, kernel, n_extra_node, BATCH, I_C, O_C, K_H, K_W, O_H, O_W): + TILE_K_H, TILE_K_W, TILE_O_H, TILE_O_W, TILE_M, TILE_N, TILE_K = kernel.conv_combination_mapping(BATCH, O_C, I_C, K_H, K_W, O_H, O_W, self.stride, self.dilation, n_extra_node) + SUB_TILE_M = TILE_M if TILE_M < kernel.vector_lane else kernel.vector_lane + SUB_TILE_N = TILE_N if TILE_N < kernel.vector_lane else kernel.vector_lane + + TILE_K_H, TILE_K_W, TILE_O_H, TILE_O_W, TILE_M, TILE_N, TILE_K = kernel.conv_multi_tile_mapping(BATCH, O_C, I_C, K_H, K_W, O_H, O_W, self.stride, self.dilation, n_extra_node) + TILE_I_W = 1 + (TILE_O_W - 1) * self.stride[1] + TILE_I_H = 1 + (TILE_O_H - 1) * self.stride[0] + (TILE_K_H - 1) * self.dilation[0] + SUB_TILE_I_H, SUB_TILE_I_W, SUB_TILE_K_H, SUB_TILE_K_W = 1, 1, 1, 1 + SUB_TILE_K = TILE_K + + TOG_latency = O_W if TILE_M > O_W else TILE_M + return TILE_K_H,TILE_K_W,TILE_O_H,TILE_O_W,TILE_M,TILE_N,TILE_K,TILE_I_H,TILE_I_W,SUB_TILE_I_H,SUB_TILE_I_W,SUB_TILE_K_H,SUB_TILE_K_W,SUB_TILE_M,SUB_TILE_N,SUB_TILE_K,TOG_latency + + def outer_func_render(self, kernel_name, input_args): + X, W = self.input_nodes[0], self.input_nodes[1] + Y = self.output_node + Bias = None if len(self.input_nodes) == 2 else self.input_nodes[2] + + eager_mode = int(os.environ.get('BACKENDSIM_EAGER_MODE', default=False)) + options = dict( + kernel=self.kernel, + KERNEL_NAME=kernel_name, + FUNC_NAME=self.function_name + f"_{len(input_args)}", + INPUT=X, + WEIGHT=W, + BIAS=Bias, + OUTPUT=Y, + PADDING_H=self.padding[0], + PADDING_W=self.padding[1], + VALIDATION_MODE=extension_config.CONFIG_TORCHSIM_VALIDATION_MODE, + BACKENDSIM_EAGER_MODE=eager_mode, + input_reorder=self.input_reorder + ) + code = self._template_from_string(WRAPPER_TEMPLATE).render(**options) + return code, self.function_name + f"_{len(input_args)}" + + def get_arg_attributes(self): + arg_attributes = [] + + X = self.input_nodes[0] + X_shape = [X.get_size()[i] for i in (2, 3, 0, 1)] + X_shape[0] += 2 * self.padding[0] + X_shape[1] += 2 * self.padding[1] + + def compute_stride(shape): + stride = [1] * len(shape) + for i in range(len(shape)-2, -1, -1): + stride[i] = stride[i+1] * shape[i+1] + return stride + + X_stride = compute_stride(X_shape) + arg_attributes.append([X.data.data.name, [MLIRKernelArgs.MLIR_ARGS_IN, X.layout.dtype, math.prod(X_shape), X_shape, X_stride]]) + + return arg_attributes + + def codegen_header(self, code, extra_headers): + write_path = extension_codecache.get_write_path(code) + if not os.path.exists(write_path): + os.makedirs(write_path) + spike_write_path = os.path.join(write_path, "global_var.h") + gem5_write_path = os.path.join(write_path, "gem5_global_var.h") + if not os.path.exists(spike_write_path): + write_atomic(spike_write_path, extra_headers[0]) + if not os.path.exists(gem5_write_path): + write_atomic(gem5_write_path, extra_headers[1]) + self.hash_value = get_hash(code.strip()) \ No newline at end of file diff --git a/PyTorchSimFrontend/mlir/mlir_conv_sb_template.py b/PyTorchSimFrontend/mlir/mlir_conv_sb_template.py new file mode 100644 index 00000000..6c31776d --- /dev/null +++ b/PyTorchSimFrontend/mlir/mlir_conv_sb_template.py @@ -0,0 +1,342 @@ +import os +import math +from sympy import Symbol, Number +from typing import List, Optional + +from PyTorchSimFrontend.mlir.mlir_common import MLIRKernelArgs +from PyTorchSimFrontend.mlir.mlir_template import MLIRTemplate +from PyTorchSimFrontend.mlir.mlir_template import MLIRTemplateKernel +from torch._inductor.ir import IRNode +from torch._inductor.codecache import write_atomic +import PyTorchSimFrontend.extension_codecache as extension_codecache +from PyTorchSimFrontend.mlir import mlir_common +from torch._inductor.codecache import get_hash +from PyTorchSimFrontend import extension_config + +CONV_TEMPLATE = r""" +// Single Batch Conv2D kernel +// BATCH = {{ BATCH }} +// I_C = {{ I_C }} +// I_H = {{ I_H }} +// I_W = {{ I_W }} +// O_C = {{ O_C }} +// K_H = {{ K_H }} +// K_W = {{ K_W }} +// O_H = {{ O_H }} +// O_W = {{ O_W }} +// TILE_M = {{ TILE_M }} +// TILE_N = {{ TILE_N }} +// TILE_K = {{ TILE_K }} +// TILE_I_H={{ TILE_I_H }}, +// TILE_I_W={{ TILE_I_W }}, +// TILE_O_H={{ TILE_O_H }}, +// TILE_O_W={{ TILE_O_W }}, +// TILE_K_H={{ TILE_K_H }}, +// TILE_K_W={{ TILE_K_W }}, +// SUB_TILE_M={{ SUB_TILE_M }}, +// SUB_TILE_N={{ SUB_TILE_N }}, +// SUB_TILE_I_W={{ SUB_TILE_I_W }}, +// SUB_TILE_K_H={{ SUB_TILE_K_H }}, +// SUB_TILE_K_W={{ SUB_TILE_K_W }}, +// PADDING_H = {{ PADDING_H }} +// PADDING_W = {{ PADDING_W }} +// STRIDE_H = {{ STRIDE_H }} +// STRIDE_W = {{ STRIDE_W }} +// DATA_STYPE = {{ DATA_STYPE }} + +#map_I_H = affine_map<(d0, d1) -> (d0 * {{ STRIDE_H }} + d1)> +#map_I_W = affine_map<(d0, d1) -> (d0 * {{ STRIDE_W }} + d1)> +#offset_w_map = affine_map<(d0, d1) -> (d0 * {{ kernel.get_spad_size_per_lane(TILE_K_W * TILE_K, TILE_N) }} + d1 * {{ kernel.get_spad_size_per_lane(TILE_K, TILE_N) }})> +#offset_x_map = affine_map<(d0, d1) -> (d0 * {{ kernel.get_spad_size_per_lane(TILE_I_W, TILE_K) }} + d1)> +#offset_y_map = affine_map<(d0, d1) -> (d0 * {{ kernel.get_spad_size_per_lane(TILE_M, TILE_N) }} + d1 * {{ kernel.get_spad_size_per_lane(TILE_M, TILE_N) }})> +{{kernel.def_global_vars()}} + +func.func @{{ KERNEL_NAME }}{{kernel.def_conv_kernel(inputs=[X, W, BIAS], outputs=[Y], names_str="X, W, Bias, Y", padded_input_size=PADDED_INPUT_SIZE, input_reorder=input_reorder)}} { + {{ kernel.def_sram_buffer("X", X_tile_desc, indent_size=2) }} + {{ kernel.def_sram_buffer("W", W_tile_desc, indent_size=2) }} + {{ kernel.def_sram_buffer("Y", Y_tile_desc, indent_size=2) }} + %v0 = arith.constant dense<0.0> : vector<{{ kernel.get_spad_size_per_lane(TILE_O_H * TILE_M, TILE_N) }}xf32> + %c0 = arith.constant 0 : index + {{- kernel.def_local_vars(indent_size=2) }} + affine.for %tile_n = 0 to {{ O_C }} step {{ TILE_N }} { + affine.for %o_h = 0 to {{ O_H }} step {{ TILE_O_H }} { + affine.for %tile_m = 0 to {{ O_W }} step {{ TILE_M }} { + // Initialize output + {%- if BIAS %} + {{ kernel.def_dma_op("MVIN", "Bias", Bias_idx, Y_tile_desc, subtile_size=[1, SUB_TILE_N, TILE_O_H, SUB_TILE_M], indent_size=8) }} + {%- else %} + affine.vector_store %v0, %output_buffer[%c0, %c0, %c0, %c0] : {{ Y_tile_desc.get_mlir_shape(DATA_STYPE) }}, vector<{{ kernel.get_spad_size_per_lane(TILE_O_H * TILE_M, TILE_N) }}xf32> + {%- endif %} + affine.for %k_h = 0 to {{ K_H }} step {{ TILE_K_H }} { + affine.for %k_w = 0 to {{ K_W }} step {{ TILE_K_W }} { + affine.for %tile_k = 0 to {{ I_C }} step {{ TILE_K }} { + %index_i_h = affine.apply #map_I_H(%o_h, %k_h) + %index_i_w = affine.apply #map_I_W(%tile_m, %k_w) + // Load input & weight matrix + {{ kernel.def_dma_op("MVIN", "X", X_idx, X_tile_desc, subtile_size=[1, SUB_TILE_I_H, SUB_TILE_M, SUB_TILE_K], indent_size=14) }} + {{ kernel.def_dma_op("MVIN", "W", W_idx, W_tile_desc, subtile_size=[SUB_TILE_K_H, SUB_TILE_K_W, SUB_TILE_K, SUB_TILE_N], indent_size=14) }} + // Compute body part + affine.for %tile_k_h = 0 to {{ TILE_K_H }} { // loop order should be fixed for timing simulation. Do not change this order. + affine.for %tile_k_w = 0 to {{ TILE_K_W }} { + %offset_w = affine.apply #offset_w_map(%tile_k_h, %tile_k_w) + %W_buffer = memref.reinterpret_cast %weight_buffer to offset: [%offset_w], sizes: [{{ TILE_K }}, {{ TILE_N }}], strides: [{{ TILE_N }}, 1] : {{ W_tile_desc.get_mlir_shape(DATA_STYPE) }} to memref<{{ TILE_K }}x{{ TILE_N }}xf32, strided<[{{ TILE_N }}, 1], offset: ?>, 1> + affine.for %tile_o_h = 0 to {{ TILE_O_H }} { + affine.for %tile_o_w = 0 to {{ 1 }} { // TILE_O_W + %tile_i_h = affine.apply #map_I_H(%tile_o_h, %tile_k_h) + %offset_x = affine.apply #offset_x_map(%tile_i_h, %tile_k_w) + %offset_y = affine.apply #offset_y_map(%tile_o_h, %tile_o_w) + %X_buffer = memref.reinterpret_cast %input_buffer to offset: [%offset_x], sizes: [{{ TILE_M }}, {{ TILE_K }}], strides: [{{ TILE_K }}, 1] : {{ X_tile_desc.get_mlir_shape(DATA_STYPE) }} to memref<{{ TILE_M }}x{{ TILE_K }}xf32, strided<[{{ TILE_K }}, 1], offset: ?>, 1> + %Y_buffer = memref.reinterpret_cast %output_buffer to offset: [%offset_y], sizes: [{{ TILE_M }}, {{ TILE_N }}], strides: [{{ TILE_N }}, 1] : {{ Y_tile_desc.get_mlir_shape(DATA_STYPE) }} to memref<{{ TILE_M }}x{{ TILE_N }}xf32, strided<[{{ TILE_N }}, 1], offset: ?>, 1> + linalg.matmul ins(%X_buffer, %W_buffer : memref<{{ TILE_M }}x{{ TILE_K }}xf32, strided<[{{ TILE_K }}, 1], offset: ?>, 1>, memref<{{ TILE_K }}x{{ TILE_N }}xf32, strided<[{{ TILE_N }}, 1], offset: ?>, 1>) + outs(%Y_buffer : memref<{{ TILE_M }}x{{ TILE_N }}xf32, strided<[{{ TILE_N }}, 1], offset: ?>, 1>) + } { inner_loop=true } + } { inner_loop=true } + } { inner_loop=true } + } { inner_loop=true } + } { accumulation_loop=true, subtile_loop="k" } + } { accumulation_loop=true } + } { accumulation_loop=true } + // Store output matrix + {{kernel.store_output(indent_size=8)}} + } { outer_loop=true, subtile_loop="m" } + } { outer_loop=true } + } { outer_loop=true, subtile_loop="n" } + return +} +""" + +WRAPPER_TEMPLATE = r""" +def {{ FUNC_NAME }}{{kernel.def_wrapper()}}: + # Padding input + padded_shape = list(X.shape) + padded_shape[2] += 2 * {{ PADDING_H }} + padded_shape[3] += 2 * {{ PADDING_W }} + X_padding = torch.zeros(padded_shape, device=X.device) + X_padding[:, :, {{ PADDING_H }}:X.shape[2] + {{ PADDING_H }}, {{ PADDING_W }}:X.shape[3] + {{ PADDING_W }}] = X + + # Tanspose inputs + {%- for buf, name in kernel.get_conv_inputs().items() %} + {%- if name == "X" %} + {{ name }} = {{ name }}_padding.permute(0, 2, 3, 1).contiguous() # (BATCH, I_C, I_H, I_W) -> (BATCH, I_H, I_W, I_C) + {%- elif name == "W" %} + {{ name }} = {{ name }}.permute(2, 3, 1, 0).contiguous() # (O_C, I_C, K_H, K_W) -> (K_H, K_W, I_C, O_C) + {%- elif name == "Bias" %} + {{ name }} = {{ name }} + {%- endif %} + {%- endfor %} + + # Launch kernel + {{ KERNEL_NAME }} + {%- if BACKENDSIM_EAGER_MODE %} + yield ({{KERNEL_NAME}}, ) + {%- endif %} +""" + +class MLIRConvSingleBatchTemplate(MLIRTemplate): + def __init__(self, input_nodes, layout, input_reorder=None, **kwargs): + super().__init__("kernel", input_nodes, layout, input_reorder) + self.stride = kwargs["stride"] + self.padding = kwargs["padding"] + self.dilation = kwargs["dilation"] + self.weight_shape = [str(i) for i in input_nodes[1].layout.size] + self.input_shape = [i for i in input_nodes[0].layout.size] + self.function_name = "Conv2D_" + "_".join(self.weight_shape)+ "_" \ + + "_".join([str(i) for i in self.stride]) \ + + "_" + "_".join([str(i) for i in self.padding]) \ + + "_" + "_".join([str(i) for i in self.dilation]) + self.kernel_args = ['X', 'W', 'Bias', 'Y'] + + def get_padded_input_size(self, X): + input_padded = list(X.layout.size) + input_padded[2] += 2 * self.padding[0] + input_padded[3] += 2 * self.padding[1] + return math.prod(input_padded) + + def render(self, + kernel: MLIRTemplateKernel, + template_buffer_node = None, + epilogue_nodes: Optional[List[IRNode]] = None, + **kwargs): + # Extract input arguments info + if template_buffer_node is not None: + self.output_node = template_buffer_node + self.kernel = kernel + self.epilogue_nodes = epilogue_nodes + + X, W = self.input_nodes[0], self.input_nodes[1] + Y = self.output_node + Bias = None if len(self.input_nodes) == 2 else self.input_nodes[2] + + if epilogue_nodes is not None: + extra_node_rw = { + item.name for epilogue_node in epilogue_nodes + for item in epilogue_node.read_writes.reads | epilogue_node.read_writes.writes + if item.name != Y.name + } + n_extra_node = len(extra_node_rw) if epilogue_nodes is not None else 0 + + BATCH, I_C, I_H, I_W = X.layout.size + O_C, _, K_H, K_W = W.layout.size + O_H = Y.layout.size[2] if template_buffer_node is None else template_buffer_node.layout.size[2] + O_W = Y.layout.size[3] if template_buffer_node is None else template_buffer_node.layout.size[3] + PADDING_H=self.padding[0] + PADDING_W=self.padding[1] + STRIDE_H=self.stride[0] + STRIDE_W=self.stride[1] + + # Select tile size adn template + conv_template = CONV_TEMPLATE + TILE_K_H, TILE_K_W, TILE_O_H, TILE_O_W, TILE_M, TILE_N, TILE_K, TILE_I_H, TILE_I_W, SUB_TILE_I_H, SUB_TILE_I_W, SUB_TILE_K_H, SUB_TILE_K_W, SUB_TILE_M, SUB_TILE_N, SUB_TILE_K, TOG_latency = self.select_tile(kernel, n_extra_node, BATCH, I_C, O_C, K_H, K_W, O_H, O_W) + SUB_TILE_N = TILE_N if TILE_N > 512 else SUB_TILE_N + TOG_latency = 8 if TOG_latency < 8 else TOG_latency + kernel.loop_size = [TOG_latency, TILE_N, TILE_K] + # Prepare tile descriptors + vlane_stride = 1 + vlane_split_axis = 1 + X_tile_size = [1, TILE_I_H, TILE_I_W, TILE_K] + X_tile_stride = [TILE_I_H * TILE_I_W * TILE_K , TILE_I_W * TILE_K, 1, TILE_I_W] + X_tile_desc = mlir_common.MLIRMultiDimTile(X_tile_size, kernel.vector_lane, 3, vlane_stride) + X_tile_desc.set_tile_size_stride(X_tile_size, X_tile_stride) + X_tile_desc.set_name("input_buffer") + X_dim = [Symbol("c0"), Symbol("index_i_h"), Symbol("index_i_w"), Symbol("tile_k")] + X_idx = [X_dim[0]*((I_W+2*PADDING_W)*(I_H+2*PADDING_H)*I_C), X_dim[1]*((I_W+2*PADDING_W)*I_C), X_dim[2]*I_C, X_dim[3]] + + W_tile_size = [TILE_K_H, TILE_K_W, TILE_K, TILE_N] + W_tile_stride = [TILE_K_W * TILE_K * TILE_N, TILE_K * TILE_N, 1, TILE_K] + W_tile_desc = mlir_common.MLIRMultiDimTile(X_tile_size, kernel.vector_lane, 3, vlane_stride) + W_tile_desc.set_tile_size_stride(W_tile_size, W_tile_stride) + W_tile_desc.set_name("weight_buffer") + W_dim = [Symbol("k_h"), Symbol("k_w"), Symbol("tile_k"), Symbol("tile_n")] + W_idx = [W_dim[0]*K_W*I_C*O_C , W_dim[1]*I_C*O_C, W_dim[2]*O_C, W_dim[3]] + + Y_tile_size = [1, TILE_N, TILE_O_H, TILE_M] + Y_tile_stride = [TILE_O_H * TILE_M * TILE_N, TILE_M, TILE_M * TILE_N, 1] # N, C, H, W + Y_tile_desc = mlir_common.MLIRMultiDimTile(Y_tile_size, kernel.vector_lane, vlane_split_axis, vlane_stride) + Y_tile_desc.set_tile_size_stride(Y_tile_size, Y_tile_stride) + Y_tile_desc.set_name("output_buffer") + Y_idx = [Number(0), Symbol("tile_n")*O_H*O_W, Symbol("o_h")*O_W, Symbol("tile_m")] + + # Extract Bias info + Bias_idx = [Number(0), Symbol("tile_n"), Number(0), Number(0)] + + kernel.render_options = dict( + KERNEL_NAME=self.name, + kernel=kernel, + X=X, W=W, Y=Y, BIAS=Bias, + PADDED_INPUT_SIZE=self.get_padded_input_size(X), + BATCH=BATCH, + I_C=I_C, + I_H=I_H, + I_W=I_W, + O_C=O_C, + K_H=K_H, + K_W=K_W, + O_H=O_H, + O_W=O_W, + TILE_M=TILE_M, + TILE_N=TILE_N, + TILE_K=TILE_K, + TILE_I_H=TILE_I_H, + TILE_I_W=TILE_I_W, + TILE_O_H=TILE_O_H, + TILE_O_W=TILE_O_W, + TILE_K_H=TILE_K_H, + TILE_K_W=TILE_K_W, + SUB_TILE_M=SUB_TILE_M, + SUB_TILE_N=SUB_TILE_N, + SUB_TILE_K=SUB_TILE_K, + SUB_TILE_I_H=SUB_TILE_I_H, + SUB_TILE_I_W=SUB_TILE_I_W, + SUB_TILE_K_H=SUB_TILE_K_H, + SUB_TILE_K_W=SUB_TILE_K_W, + PADDING_H=PADDING_H, + PADDING_W=PADDING_W, + STRIDE_H=STRIDE_H, + STRIDE_W=STRIDE_W, + X_tile_desc = X_tile_desc, + W_tile_desc = W_tile_desc, + Y_tile_desc = Y_tile_desc, + X_idx = X_idx, + W_idx = W_idx, + Bias_idx = Bias_idx, + DATA_STYPE="f32", + input_reorder=self.input_reorder + ) + + kernel.epilogue_info = dict( + output_node = self.output_node.name, + sram_var = "output_buffer", + dram_var = "Y", + dram_idx = Y_idx, + dram_tile_desc = Y_tile_desc, + dim_aliasing = {"index0":"c0", "index1":"tile_n", "index2":"o_h", "index3":"tile_m"} + ) + kernel.exception_nodes["X"] = {"numel" : (I_W+2*PADDING_W)*(I_H+2*PADDING_H)*I_C*BATCH} + code = self._template_from_string(conv_template).render(**kernel.render_options) + kernel.add_loop_info([kernel.render_options["K_H"], kernel.render_options["K_W"], kernel.render_options["O_H"], kernel.render_options["O_W"], kernel.render_options["BATCH"], kernel.render_options["O_C"], kernel.render_options["I_C"]], [kernel.render_options["TILE_M"], kernel.render_options["TILE_N"], kernel.render_options["TILE_K"]]) + return code + + def select_tile(self, kernel, n_extra_node, BATCH, I_C, O_C, K_H, K_W, O_H, O_W): + TILE_K_H, TILE_K_W, TILE_O_H, TILE_O_W, TILE_M, TILE_N, TILE_K = kernel.conv_single_batch_mapping(BATCH, O_C, I_C, K_H, 1, O_H, O_W, self.stride, self.dilation, n_extra_node) # TODO: implement K_W + TILE_I_H = 1 + (TILE_O_H - 1) * self.stride[0] + (TILE_K_H - 1) * self.dilation[0] + TILE_I_W = 1 + (TILE_O_W - 1) * self.stride[1] + (TILE_K_W - 1) * self.dilation[1] + SUB_TILE_I_H, SUB_TILE_I_W, SUB_TILE_K_H, SUB_TILE_K_W = 1, 1, 1, 1 + SUB_TILE_M = TILE_I_W if TILE_I_W < kernel.vector_lane else kernel.vector_lane + SUB_TILE_N = TILE_N if TILE_N < kernel.vector_lane else kernel.vector_lane + SUB_TILE_K = TILE_K + TOG_latency = O_W if TILE_M > O_W else TILE_M + return TILE_K_H,TILE_K_W,TILE_O_H,TILE_O_W,TILE_M,TILE_N,TILE_K,TILE_I_H,TILE_I_W,SUB_TILE_I_H,SUB_TILE_I_W,SUB_TILE_K_H,SUB_TILE_K_W,SUB_TILE_M,SUB_TILE_N,SUB_TILE_K,TOG_latency + + def outer_func_render(self, kernel_name, input_args): + X, W = self.input_nodes[0], self.input_nodes[1] + Y = self.output_node + Bias = None if len(self.input_nodes) == 2 else self.input_nodes[2] + + eager_mode = int(os.environ.get('BACKENDSIM_EAGER_MODE', default=False)) + options = dict( + kernel=self.kernel, + KERNEL_NAME=kernel_name, + FUNC_NAME=self.function_name + f"_{len(input_args)}", + INPUT=X, + WEIGHT=W, + BIAS=Bias, + OUTPUT=Y, + PADDING_H=self.padding[0], + PADDING_W=self.padding[1], + VALIDATION_MODE=extension_config.CONFIG_TORCHSIM_VALIDATION_MODE, + BACKENDSIM_EAGER_MODE=eager_mode, + input_reorder=self.input_reorder + ) + code = self._template_from_string(WRAPPER_TEMPLATE).render(**options) + return code, self.function_name + f"_{len(input_args)}" + + def get_arg_attributes(self): + arg_attributes = [] + + X = self.input_nodes[0] + X_shape = [X.get_size()[i] for i in (2, 3, 0, 1)] + X_shape[0] += 2 * self.padding[0] + X_shape[1] += 2 * self.padding[1] + + def compute_stride(shape): + stride = [1] * len(shape) + for i in range(len(shape)-2, -1, -1): + stride[i] = stride[i+1] * shape[i+1] + return stride + + X_stride = compute_stride(X_shape) + arg_attributes.append([X.data.data.name, [MLIRKernelArgs.MLIR_ARGS_IN, X.layout.dtype, math.prod(X_shape), X_shape, X_stride]]) + + return arg_attributes + + def codegen_header(self, code, extra_headers): + write_path = extension_codecache.get_write_path(code) + if not os.path.exists(write_path): + os.makedirs(write_path) + spike_write_path = os.path.join(write_path, "global_var.h") + gem5_write_path = os.path.join(write_path, "gem5_global_var.h") + if not os.path.exists(spike_write_path): + write_atomic(spike_write_path, extra_headers[0]) + if not os.path.exists(gem5_write_path): + write_atomic(gem5_write_path, extra_headers[1]) + self.hash_value = get_hash(code.strip()) \ No newline at end of file diff --git a/PyTorchSimFrontend/mlir/mlir_conv_sbs_template.py b/PyTorchSimFrontend/mlir/mlir_conv_sbs_template.py new file mode 100644 index 00000000..74309b30 --- /dev/null +++ b/PyTorchSimFrontend/mlir/mlir_conv_sbs_template.py @@ -0,0 +1,343 @@ +import os +import math +from sympy import Symbol, Number +from typing import List, Optional + +from PyTorchSimFrontend.mlir.mlir_common import MLIRKernelArgs +from PyTorchSimFrontend.mlir.mlir_template import MLIRTemplate +from PyTorchSimFrontend.mlir.mlir_template import MLIRTemplateKernel +from torch._inductor.ir import IRNode +from torch._inductor.codecache import write_atomic +import PyTorchSimFrontend.extension_codecache as extension_codecache +from PyTorchSimFrontend.mlir import mlir_common +from torch._inductor.codecache import get_hash +from PyTorchSimFrontend import extension_config + +CONV_TEMPLATE = r""" +// Single Batch Conv2D (Stride != 1) kernel +// BATCH = {{ BATCH }} +// I_C = {{ I_C }} +// I_H = {{ I_H }} +// I_W = {{ I_W }} +// O_C = {{ O_C }} +// K_H = {{ K_H }} +// K_W = {{ K_W }} +// O_H = {{ O_H }} +// O_W = {{ O_W }} +// TILE_M = {{ TILE_M }} +// TILE_N = {{ TILE_N }} +// TILE_K = {{ TILE_K }} +// TILE_I_H={{ TILE_I_H }}, +// TILE_I_W={{ TILE_I_W }}, +// TILE_O_H={{ TILE_O_H }}, +// TILE_O_W={{ TILE_O_W }}, +// TILE_K_H={{ TILE_K_H }}, +// TILE_K_W={{ TILE_K_W }}, +// SUB_TILE_M={{ SUB_TILE_M }}, +// SUB_TILE_N={{ SUB_TILE_N }}, +// SUB_TILE_I_W={{ SUB_TILE_I_W }}, +// SUB_TILE_K_H={{ SUB_TILE_K_H }}, +// SUB_TILE_K_W={{ SUB_TILE_K_W }}, +// PADDING_H = {{ PADDING_H }} +// PADDING_W = {{ PADDING_W }} +// STRIDE_H = {{ STRIDE_H }} +// STRIDE_W = {{ STRIDE_W }} +// DATA_STYPE = {{ DATA_STYPE }} + +#map_I_H = affine_map<(d0, d1) -> (d0 * {{ STRIDE_H }} + d1)> +#offset_w_map = affine_map<(d0, d1) -> (d0 * {{ kernel.get_spad_size_per_lane(TILE_K_W * TILE_K, TILE_N) }} + d1 * {{ kernel.get_spad_size_per_lane(TILE_K, TILE_N) }})> +#offset_x_map = affine_map<(d0, d1) -> (d0 * {{ kernel.get_spad_size_per_lane(TILE_M * TILE_K_W, TILE_K) }} + d1 * {{ kernel.get_spad_size_per_lane(TILE_M, TILE_K) }})> +#offset_y_map = affine_map<(d0, d1) -> (d0 * {{ kernel.get_spad_size_per_lane(TILE_M, TILE_N) }} + d1 * {{ kernel.get_spad_size_per_lane(TILE_M, TILE_N) }})> + +{{kernel.def_global_vars()}} + +func.func @{{ KERNEL_NAME }}{{kernel.def_conv_kernel(inputs=[X, W, BIAS], outputs=[Y], names_str="X, W, Bias, Y", padded_input_size=PADDED_INPUT_SIZE, input_reorder=input_reorder)}} { + {{ kernel.def_sram_buffer("X", X_tile_desc, indent_size=2) }} + {{ kernel.def_sram_buffer("W", W_tile_desc, indent_size=2) }} + {{ kernel.def_sram_buffer("Y", Y_tile_desc, indent_size=2) }} + %v0 = arith.constant dense<0.0> : vector<{{ kernel.get_spad_size_per_lane(TILE_O_H * TILE_M, TILE_N) }}xf32> + %c0 = arith.constant 0 : index + {{- kernel.def_local_vars(indent_size=2) }} + + affine.for %tile_n = 0 to {{ O_C }} step {{ TILE_N }} { + affine.for %o_h = 0 to {{ O_H }} step {{ TILE_O_H }} { + affine.for %tile_m = 0 to {{ O_W }} step {{ TILE_M }} { + // Initialize output + {%- if BIAS %} + {{ kernel.def_dma_op("MVIN", "Bias", Bias_idx, Y_tile_desc, subtile_size=[1, SUB_TILE_N, TILE_O_H, SUB_TILE_M], indent_size=8) }} + {%- else %} + affine.vector_store %v0, %output_buffer[%c0, %c0, %c0, %c0] : {{ Y_tile_desc.get_mlir_shape(DATA_STYPE) }}, vector<{{ kernel.get_spad_size_per_lane(TILE_O_H * TILE_M, TILE_N) }}xf32> + {%- endif %} + affine.for %k_h = 0 to {{ K_H }} step {{ TILE_K_H }} { + affine.for %k_w = 0 to {{ K_W }} step {{ TILE_K_W }} { + affine.for %tile_k = 0 to {{ I_C }} step {{ TILE_K }} { + %index_i_h = affine.apply #map_I_H(%o_h, %k_h) + // Load input & weight matrix + {{ kernel.def_dma_op("MVIN", "X", X_idx, X_tile_desc, subtile_size=[SUB_TILE_I_H, SUB_TILE_K_W, SUB_TILE_M, SUB_TILE_K], indent_size=14) }} + {{ kernel.def_dma_op("MVIN", "W", W_idx, W_tile_desc, subtile_size=[SUB_TILE_K_H, SUB_TILE_K_W, SUB_TILE_K, SUB_TILE_N], indent_size=14) }} + // Compute body part + affine.for %tile_k_h = 0 to {{ TILE_K_H }} { // loop order should be fixed for timing simulation. Do not change this order. + affine.for %tile_k_w = 0 to {{ TILE_K_W }} { + %offset_w = affine.apply #offset_w_map(%tile_k_h, %tile_k_w) + %W_buffer = memref.reinterpret_cast %weight_buffer to offset: [%offset_w], sizes: [{{ TILE_K }}, {{ TILE_N }}], strides: [{{ TILE_N }}, 1] : memref<{{ TILE_K_H }}x{{ TILE_K_W }}x{{ TILE_K }}x{{ TILE_N }}xf32, 1> to memref<{{ TILE_K }}x{{ TILE_N }}xf32, strided<[{{ TILE_N }}, 1], offset: ?>, 1> + affine.for %tile_o_h = 0 to {{ TILE_O_H }} { + affine.for %tile_o_w = 0 to {{ 1 }} { // TILE_O_W + %tile_i_h = affine.apply #map_I_H(%tile_o_h, %tile_k_h) + %offset_x = affine.apply #offset_x_map(%tile_i_h, %tile_k_w) + %offset_y = affine.apply #offset_y_map(%tile_o_h, %tile_o_w) + %X_buffer = memref.reinterpret_cast %input_buffer to offset: [%offset_x], sizes: [{{ TILE_M }}, {{ TILE_K }}], strides: [{{ TILE_K }}, 1] : {{ X_tile_desc.get_mlir_shape(DATA_STYPE) }} to memref<{{ TILE_M }}x{{ TILE_K }}xf32, strided<[{{ TILE_K }}, 1], offset: ?>, 1> + %Y_buffer = memref.reinterpret_cast %output_buffer to offset: [%offset_y], sizes: [{{ TILE_M }}, {{ TILE_N }}], strides: [{{ TILE_N }}, 1] : {{ Y_tile_desc.get_mlir_shape(DATA_STYPE) }} to memref<{{ TILE_M }}x{{ TILE_N }}xf32, strided<[{{ TILE_N }}, 1], offset: ?>, 1> + linalg.matmul ins(%X_buffer, %W_buffer : memref<{{ TILE_M }}x{{ TILE_K }}xf32, strided<[{{ TILE_K }}, 1], offset: ?>, 1>, memref<{{ TILE_K }}x{{ TILE_N }}xf32, strided<[{{ TILE_N }}, 1], offset: ?>, 1>) + outs(%Y_buffer : memref<{{ TILE_M }}x{{ TILE_N }}xf32, strided<[{{ TILE_N }}, 1], offset: ?>, 1>) + } { inner_loop=true } + } { inner_loop=true } + } { inner_loop=true } + } { inner_loop=true } + } { accumulation_loop=true, subtile_loop="k" } + } { accumulation_loop=true } + } { accumulation_loop=true } + // Store output matrix + {{kernel.store_output(indent_size=8)}} + } { outer_loop=true, subtile_loop="m" } + } { outer_loop=true } + } { outer_loop=true, subtile_loop="n" } + return +} +""" + +WRAPPER_TEMPLATE = r""" +def {{ FUNC_NAME }}{{kernel.def_wrapper()}}: + # Padding input + padded_shape = list(X.shape) + padded_shape[2] += 2 * {{ PADDING_H }} + padded_shape[3] += 2 * {{ PADDING_W }} + X_padding = torch.zeros(padded_shape, device=X.device) + X_padding[:, :, {{ PADDING_H }}:X.shape[2] + {{ PADDING_H }}, {{ PADDING_W }}:X.shape[3] + {{ PADDING_W }}] = X + + # Tanspose inputs + {%- for buf, name in kernel.get_conv_inputs().items() %} + {%- if name == "X" %} + {{ name }} = {{ name }}_padding.permute(0, 2, 3, 1).contiguous() # (BATCH, I_C, I_H, I_W) -> (BATCH, I_H, I_W, I_C) + {%- elif name == "W" %} + {{ name }} = {{ name }}.permute(2, 3, 1, 0).contiguous() # (O_C, I_C, K_H, K_W) -> (K_H, K_W, I_C, O_C) + {%- elif name == "Bias" %} + {{ name }} = {{ name }} + {%- endif %} + {%- endfor %} + + # Launch kernel + {{ KERNEL_NAME }} + {%- if BACKENDSIM_EAGER_MODE %} + yield ({{KERNEL_NAME}}, ) + {%- endif %} +""" + +class MLIRConvSingleBatchStridedTemplate(MLIRTemplate): + def __init__(self, input_nodes, layout, input_reorder=None, **kwargs): + super().__init__("kernel", input_nodes, layout, input_reorder) + self.stride = kwargs["stride"] + self.padding = kwargs["padding"] + self.dilation = kwargs["dilation"] + self.weight_shape = [str(i) for i in input_nodes[1].layout.size] + self.input_shape = [i for i in input_nodes[0].layout.size] + self.function_name = "Conv2D_" + "_".join(self.weight_shape)+ "_" \ + + "_".join([str(i) for i in self.stride]) \ + + "_" + "_".join([str(i) for i in self.padding]) \ + + "_" + "_".join([str(i) for i in self.dilation]) + self.kernel_args = ['X', 'W', 'Bias', 'Y'] + + def get_padded_input_size(self, X): + input_padded = list(X.layout.size) + input_padded[2] += 2 * self.padding[0] + input_padded[3] += 2 * self.padding[1] + return math.prod(input_padded) + + def render(self, + kernel: MLIRTemplateKernel, + template_buffer_node = None, + epilogue_nodes: Optional[List[IRNode]] = None, + **kwargs): + # Extract input arguments info + if template_buffer_node is not None: + self.output_node = template_buffer_node + self.kernel = kernel + self.epilogue_nodes = epilogue_nodes + + X, W = self.input_nodes[0], self.input_nodes[1] + Y = self.output_node + Bias = None if len(self.input_nodes) == 2 else self.input_nodes[2] + + if epilogue_nodes is not None: + extra_node_rw = { + item.name for epilogue_node in epilogue_nodes + for item in epilogue_node.read_writes.reads | epilogue_node.read_writes.writes + if item.name != Y.name + } + n_extra_node = len(extra_node_rw) if epilogue_nodes is not None else 0 + + BATCH, I_C, I_H, I_W = X.layout.size + O_C, _, K_H, K_W = W.layout.size + O_H = Y.layout.size[2] if template_buffer_node is None else template_buffer_node.layout.size[2] + O_W = Y.layout.size[3] if template_buffer_node is None else template_buffer_node.layout.size[3] + PADDING_H=self.padding[0] + PADDING_W=self.padding[1] + STRIDE_H=self.stride[0] + STRIDE_W=self.stride[1] + + # Select tile size adn template + conv_template = CONV_TEMPLATE + TILE_K_H, TILE_K_W, TILE_O_H, TILE_O_W, TILE_M, TILE_N, TILE_K, TILE_I_H, TILE_I_W, SUB_TILE_I_H, SUB_TILE_I_W, SUB_TILE_K_H, SUB_TILE_K_W, SUB_TILE_M, SUB_TILE_N, SUB_TILE_K, TOG_latency = self.select_tile(kernel, n_extra_node, BATCH, I_C, O_C, K_H, K_W, O_H, O_W) + SUB_TILE_N = TILE_N if TILE_N > 512 else SUB_TILE_N + TOG_latency = 8 if TOG_latency < 8 else TOG_latency + kernel.loop_size = [TOG_latency, TILE_N, TILE_K] + + # Prepare tile descriptors + vlane_stride = 1 + vlane_split_axis = 1 + X_tile_size = [TILE_I_H, TILE_K_W, TILE_M, TILE_K] + X_tile_stride = [TILE_K_W*TILE_M*TILE_K, TILE_M*TILE_K, 1, TILE_M] + X_tile_desc = mlir_common.MLIRMultiDimTile(X_tile_size, kernel.vector_lane, 3, vlane_stride) + X_tile_desc.set_tile_size_stride(X_tile_size, X_tile_stride) + X_tile_desc.set_name("input_buffer") + X_dim = [Symbol("index_i_h"), Symbol("k_w"), Symbol("tile_m"), Symbol("tile_k")] + X_idx = [X_dim[0]*((I_W+2*PADDING_W)*I_C), X_dim[1]*I_C, X_dim[2]*(I_C*STRIDE_W), X_dim[3]] + + W_tile_size = [TILE_K_H, TILE_K_W, TILE_K, TILE_N] + W_tile_stride = [TILE_K_W * TILE_K * TILE_N, TILE_K * TILE_N, 1, TILE_K] + W_tile_desc = mlir_common.MLIRMultiDimTile(X_tile_size, kernel.vector_lane, 3, vlane_stride) + W_tile_desc.set_tile_size_stride(W_tile_size, W_tile_stride) + W_tile_desc.set_name("weight_buffer") + W_dim = [Symbol("k_h"), Symbol("k_w"), Symbol("tile_k"), Symbol("tile_n")] + W_idx = [W_dim[0]*K_W*I_C*O_C , W_dim[1]*I_C*O_C, W_dim[2]*O_C, W_dim[3]] + + Y_tile_size = [1, TILE_N, TILE_O_H, TILE_M] + Y_tile_stride = [TILE_O_H * TILE_M * TILE_N, TILE_M, TILE_M * TILE_N, 1] # N, C, H, W + Y_tile_desc = mlir_common.MLIRMultiDimTile(Y_tile_size, kernel.vector_lane, vlane_split_axis, vlane_stride) + Y_tile_desc.set_tile_size_stride(Y_tile_size, Y_tile_stride) + Y_tile_desc.set_name("output_buffer") + Y_idx = [Number(0), Symbol("tile_n")*O_H*O_W, Symbol("o_h")*O_W, Symbol("tile_m")] + + # Extract Bias info + Bias_idx = [Number(0), Symbol("tile_n"), Number(0), Number(0)] + + kernel.render_options = dict( + KERNEL_NAME=self.name, + kernel=kernel, + X=X, W=W, Y=Y, BIAS=Bias, + PADDED_INPUT_SIZE=self.get_padded_input_size(X), + BATCH=BATCH, + I_C=I_C, + I_H=I_H, + I_W=I_W, + O_C=O_C, + K_H=K_H, + K_W=K_W, + O_H=O_H, + O_W=O_W, + TILE_M=TILE_M, + TILE_N=TILE_N, + TILE_K=TILE_K, + TILE_I_H=TILE_I_H, + TILE_I_W=TILE_I_W, + TILE_O_H=TILE_O_H, + TILE_O_W=TILE_O_W, + TILE_K_H=TILE_K_H, + TILE_K_W=TILE_K_W, + SUB_TILE_M=SUB_TILE_M, + SUB_TILE_N=SUB_TILE_N, + SUB_TILE_K=SUB_TILE_K, + SUB_TILE_I_H=SUB_TILE_I_H, + SUB_TILE_I_W=SUB_TILE_I_W, + SUB_TILE_K_H=SUB_TILE_K_H, + SUB_TILE_K_W=SUB_TILE_K_W, + PADDING_H=PADDING_H, + PADDING_W=PADDING_W, + STRIDE_H=STRIDE_H, + STRIDE_W=STRIDE_W, + X_tile_desc = X_tile_desc, + W_tile_desc = W_tile_desc, + Y_tile_desc = Y_tile_desc, + X_idx = X_idx, + W_idx = W_idx, + Bias_idx = Bias_idx, + DATA_STYPE="f32", + input_reorder=self.input_reorder + ) + + kernel.epilogue_info = dict( + output_node = self.output_node.name, + sram_var = "output_buffer", + dram_var = "Y", + dram_idx = Y_idx, + dram_tile_desc = Y_tile_desc, + dim_aliasing = {"index0":"c0", "index1":"tile_n", "index2":"o_h", "index3":"tile_m"} + ) + kernel.exception_nodes["X"] = {"numel" : (I_W+2*PADDING_W)*(I_H+2*PADDING_H)*I_C*BATCH} + code = self._template_from_string(conv_template).render(**kernel.render_options) + kernel.add_loop_info([kernel.render_options["K_H"], kernel.render_options["K_W"], kernel.render_options["O_H"], kernel.render_options["O_W"], kernel.render_options["BATCH"], kernel.render_options["O_C"], kernel.render_options["I_C"]], [kernel.render_options["TILE_M"], kernel.render_options["TILE_N"], kernel.render_options["TILE_K"]]) + return code + + def select_tile(self, kernel, n_extra_node, BATCH, I_C, O_C, K_H, K_W, O_H, O_W): + TILE_K_H, TILE_K_W, TILE_O_H, TILE_O_W, TILE_M, TILE_N, TILE_K = kernel.conv_single_batch_mapping(BATCH, O_C, I_C, K_H, K_W, O_H, O_W, self.stride, self.dilation, n_extra_node) # TODO: implement K_W + TILE_I_H = 1 + (TILE_O_H - 1) * self.stride[0] + (TILE_K_H - 1) * self.dilation[0] + TILE_I_W = 1 + (TILE_O_W - 1) * self.stride[1] + (TILE_K_W - 1) * self.dilation[1] + SUB_TILE_I_H, SUB_TILE_I_W, SUB_TILE_K_H, SUB_TILE_K_W = 1, 1, 1, 1 + SUB_TILE_M = TILE_M if TILE_M < kernel.vector_lane else kernel.vector_lane + SUB_TILE_N = TILE_N if TILE_N < kernel.vector_lane else kernel.vector_lane + SUB_TILE_K = TILE_K + TOG_latency = O_W if TILE_M > O_W else TILE_M + return TILE_K_H,TILE_K_W,TILE_O_H,TILE_O_W,TILE_M,TILE_N,TILE_K,TILE_I_H,TILE_I_W,SUB_TILE_I_H,SUB_TILE_I_W,SUB_TILE_K_H,SUB_TILE_K_W,SUB_TILE_M,SUB_TILE_N,SUB_TILE_K,TOG_latency + + def outer_func_render(self, kernel_name, input_args): + X, W = self.input_nodes[0], self.input_nodes[1] + Y = self.output_node + Bias = None if len(self.input_nodes) == 2 else self.input_nodes[2] + + eager_mode = int(os.environ.get('BACKENDSIM_EAGER_MODE', default=False)) + options = dict( + kernel=self.kernel, + KERNEL_NAME=kernel_name, + FUNC_NAME=self.function_name + f"_{len(input_args)}", + INPUT=X, + WEIGHT=W, + BIAS=Bias, + OUTPUT=Y, + PADDING_H=self.padding[0], + PADDING_W=self.padding[1], + VALIDATION_MODE=extension_config.CONFIG_TORCHSIM_VALIDATION_MODE, + BACKENDSIM_EAGER_MODE=eager_mode, + input_reorder=self.input_reorder + ) + code = self._template_from_string(WRAPPER_TEMPLATE).render(**options) + return code, self.function_name + f"_{len(input_args)}" + + def get_arg_attributes(self): + arg_attributes = [] + + X = self.input_nodes[0] + X_shape = [X.get_size()[i] for i in (2, 3, 0, 1)] + X_shape[0] += 2 * self.padding[0] + X_shape[1] += 2 * self.padding[1] + + def compute_stride(shape): + stride = [1] * len(shape) + for i in range(len(shape)-2, -1, -1): + stride[i] = stride[i+1] * shape[i+1] + return stride + + X_stride = compute_stride(X_shape) + arg_attributes.append([X.data.data.name, [MLIRKernelArgs.MLIR_ARGS_IN, X.layout.dtype, math.prod(X_shape), X_shape, X_stride]]) + + return arg_attributes + + def codegen_header(self, code, extra_headers): + write_path = extension_codecache.get_write_path(code) + if not os.path.exists(write_path): + os.makedirs(write_path) + spike_write_path = os.path.join(write_path, "global_var.h") + gem5_write_path = os.path.join(write_path, "gem5_global_var.h") + if not os.path.exists(spike_write_path): + write_atomic(spike_write_path, extra_headers[0]) + if not os.path.exists(gem5_write_path): + write_atomic(gem5_write_path, extra_headers[1]) + self.hash_value = get_hash(code.strip()) \ No newline at end of file diff --git a/PyTorchSimFrontend/mlir/mlir_conv_template.py b/PyTorchSimFrontend/mlir/mlir_conv_template.py index 3f52a61d..9cbd6514 100644 --- a/PyTorchSimFrontend/mlir/mlir_conv_template.py +++ b/PyTorchSimFrontend/mlir/mlir_conv_template.py @@ -1,254 +1,339 @@ import os import math -from typing import List, Optional, cast +from sympy import Symbol, Number +from typing import List, Optional from PyTorchSimFrontend.mlir.mlir_common import MLIRKernelArgs from PyTorchSimFrontend.mlir.mlir_template import MLIRTemplate from PyTorchSimFrontend.mlir.mlir_template import MLIRTemplateKernel -from torch._inductor.ir import Buffer from torch._inductor.ir import IRNode -from torch._inductor.ir import ReinterpretView from torch._inductor.codecache import write_atomic import PyTorchSimFrontend.extension_codecache as extension_codecache +from PyTorchSimFrontend.mlir import mlir_common from torch._inductor.codecache import get_hash from PyTorchSimFrontend import extension_config -GEMM_TEMPLATE = r""" -#map0 = affine_map<(d0, d1) -> (d0 * {{ K }} + d1)> -#map1 = affine_map<(d0, d1) -> (d0 * {{ N }} + d1)> -memref.global @X_spad : memref<{{ TILE_M }}x{{ TILE_K }}xf32, 1> -memref.global @W_spad : memref<{{ TILE_K }}x{{ TILE_N }}xf32, 1> -memref.global @B_spad : memref<{{ TILE_M }}x{{ TILE_N }}xf32, 1> -memref.global @Y_spad : memref<{{ TILE_M }}x{{ TILE_N }}xf32, 1> - -func.func @{{ KERNEL_NAME }}({{ KERNEL_DEF }}) { - %c_mvin = arith.constant 2 : index - %c_mvin2 = arith.constant 1 : index - %c_mvin3 = arith.constant 14 : index - %c_mvout = arith.constant 3 : index - %c_set = arith.constant 2 : index +CONV_TEMPLATE = r""" +// Conv2D kernel +// BATCH = {{ BATCH }} +// I_C = {{ I_C }} +// I_H = {{ I_H }} +// I_W = {{ I_W }} +// O_C = {{ O_C }} +// K_H = {{ K_H }} +// K_W = {{ K_W }} +// O_H = {{ O_H }} +// O_W = {{ O_W }} +// TILE_M = {{ TILE_M }} +// TILE_N = {{ TILE_N }} +// TILE_K = {{ TILE_K }} +// TILE_I_H={{ TILE_I_H }}, +// TILE_I_W={{ TILE_I_W }}, +// TILE_O_H={{ TILE_O_H }}, +// TILE_O_W={{ TILE_O_W }}, +// TILE_K_H={{ TILE_K_H }}, +// TILE_K_W={{ TILE_K_W }}, +// SUB_TILE_M={{ SUB_TILE_M }}, +// SUB_TILE_N={{ SUB_TILE_N }}, +// SUB_TILE_I_W={{ SUB_TILE_I_W }}, +// SUB_TILE_K_H={{ SUB_TILE_K_H }}, +// SUB_TILE_K_W={{ SUB_TILE_K_W }}, +// PADDING_H = {{ PADDING_H }} +// PADDING_W = {{ PADDING_W }} +// STRIDE_H = {{ STRIDE_H }} +// STRIDE_W = {{ STRIDE_W }} +// DATA_STYPE = {{ DATA_STYPE }} + +#map_I_H = affine_map<(d0, d1) -> (d0 * {{ STRIDE_H }} + d1)> +#map_I_W = affine_map<(d0, d1) -> (d0 * {{ STRIDE_W }} + d1)> +#offset_w_map = affine_map<(d0, d1) -> (d0 * {{ kernel.get_spad_size_per_lane(TILE_K_W * TILE_K, TILE_N) }} + d1 * {{ kernel.get_spad_size_per_lane(TILE_K, TILE_N) }})> +#offset_x_map = affine_map<(d0, d1) -> (d0 * {{ kernel.get_spad_size_per_lane(TILE_I_W * TILE_M, TILE_K) }} + d1 * {{ kernel.get_spad_size_per_lane(TILE_M, TILE_K) }})> +#offset_y_map = affine_map<(d0, d1) -> (d0 * {{ kernel.get_spad_size_per_lane(TILE_O_W * TILE_M, TILE_N) }} + d1 * {{ kernel.get_spad_size_per_lane(TILE_M, TILE_N) }})> +{{kernel.def_global_vars()}} + +func.func @{{ KERNEL_NAME }}{{kernel.def_conv_kernel(inputs=[X, W, BIAS], outputs=[Y], names_str="X, W, Bias, Y", padded_input_size=PADDED_INPUT_SIZE, input_reorder=input_reorder)}} { + {{ kernel.def_sram_buffer("X", X_tile_desc, indent_size=2) }} + {{ kernel.def_sram_buffer("W", W_tile_desc, indent_size=2) }} + {{ kernel.def_sram_buffer("Y", Y_tile_desc, indent_size=2) }} + %v0 = arith.constant dense<0.0> : vector<{{ kernel.get_spad_size_per_lane(TILE_O_H * TILE_M, TILE_N) }}xf32> %c0 = arith.constant 0 : index - - %N = arith.constant {{ N }} : index - %K = arith.constant {{ K }} : index - %X_buffer = memref.get_global @X_spad : memref<{{ TILE_M }}x{{ TILE_K }}xf32, 1> - %W_buffer = memref.get_global @W_spad : memref<{{ TILE_K }}x{{ TILE_N }}xf32, 1> - %Y_buffer = memref.get_global @Y_spad : memref<{{ TILE_M }}x{{ TILE_N }}xf32, 1> - %tag = memref.alloc() : memref<1xi32> - - affine.for %t_m = 0 to {{ M }} step {{ TILE_M }} { - affine.for %t_n = 0 to {{ N }} step {{ TILE_N }} { - %index2 = affine.apply #map1(%t_m, %t_n) - affine.dma_start %B[%index2], %Y_buffer[%c0, %c0], %tag[0], %c_mvin3, %N, %c_set : memref<{{ M * N }}xf32>, memref<{{ TILE_M }}x{{ TILE_N }}xf32, 1>, memref<1xi32> { subtile_size=[{{ kernel.vector_lane }}, {{ kernel.vector_lane }}], async=1 } - affine.for %t_k = 0 to {{ K }} step {{ TILE_K }} { - %index0 = affine.apply #map0(%t_m, %t_k) - %index1 = affine.apply #map1(%t_k, %t_n) - affine.dma_start %X[%index0], %X_buffer[%c0, %c0], %tag[0], %c_mvin, %K, %c_set : memref<{{ M * K }}xf32>, memref<{{ TILE_M }}x{{ TILE_K }}xf32, 1>, memref<1xi32> { subtile_size=[{{ kernel.vector_lane }}, {{ TILE_K }}], async=1 } - affine.dma_start %W[%index1], %W_buffer[%c0, %c0], %tag[0], %c_mvin2, %N, %c_set : memref<{{ K * N }}xf32>, memref<{{ TILE_K }}x{{ TILE_N }}xf32, 1>, memref<1xi32> { subtile_size=[{{ TILE_K }}, {{ kernel.vector_lane }}], async=1 } - linalg.matmul ins(%X_buffer, %W_buffer : memref<{{ TILE_M }}x{{ TILE_K }}x{{ DATA_STYPE }}, 1>, memref<{{ TILE_K }}x{{ TILE_N }}x{{ DATA_STYPE }}, 1>) - outs(%Y_buffer : memref<{{ TILE_M }}x{{ TILE_N }}x{{ DATA_STYPE }}, 1>) - } { accumulation_loop=true } - affine.dma_start %Y_buffer[%c0, %c0], %Y[%index2], %tag[0], %c_mvout, %N, %c_set : memref<{{ TILE_M }}x{{ TILE_N }}xf32, 1>, memref<{{ M * N }}xf32>, memref<1xi32> { async=1 } - } { outer_loop=true } - } { outer_loop=true } + {{ kernel.def_local_vars(indent_size=2) }} + + affine.for %tile_m = 0 to {{ BATCH }} step {{ TILE_M }} { + affine.for %tile_n = 0 to {{ O_C }} step {{ TILE_N }} { + affine.for %o_h = 0 to {{ O_H }} step {{ TILE_O_H }} { + affine.for %o_w = 0 to {{ O_W }} step {{ TILE_O_W }} { + // Initialize output + {%- if BIAS %} + {{ kernel.def_dma_op("MVIN", "Bias", Bias_idx, Y_tile_desc, subtile_size=[SUB_TILE_M, SUB_TILE_N, TILE_O_H, TILE_O_W], indent_size=10) }} + {%- else %} + affine.vector_store %v0, %output_buffer[%c0, %c0, %c0, %c0] : {{ Y_tile_desc.get_mlir_shape(DATA_STYPE) }}, vector<{{ kernel.get_spad_size_per_lane(TILE_O_H * TILE_M, TILE_N) }}xf32> + {%- endif %} + affine.for %k_h = 0 to {{ K_H }} step {{ TILE_K_H }} { + affine.for %k_w = 0 to {{ K_W }} step {{ TILE_K_W }} { + affine.for %tile_k = 0 to {{ I_C }} step {{ TILE_K }} { + %index_i_h = affine.apply #map_I_H(%o_h, %k_h) + %index_i_w = affine.apply #map_I_W(%o_w, %k_w) + // Load input matrix + {{ kernel.def_dma_op("MVIN", "X", X_idx, X_tile_desc, subtile_size=[SUB_TILE_I_H, SUB_TILE_I_W, SUB_TILE_M, SUB_TILE_K], indent_size=16) }} + {{ kernel.def_dma_op("MVIN", "W", W_idx, W_tile_desc, subtile_size=[SUB_TILE_K_H, SUB_TILE_K_W, SUB_TILE_K, SUB_TILE_N], indent_size=16) }} + // Compute body part + affine.for %tile_k_h = 0 to {{ TILE_K_H }} { // loop order should be fixed for timing simulation. Do not change this order. + affine.for %tile_k_w = 0 to {{ TILE_K_W }} { + %offset_w = affine.apply #offset_w_map(%tile_k_h, %tile_k_w) + %W_buffer = memref.reinterpret_cast %weight_buffer to offset: [%offset_w], sizes: [{{ TILE_K }}, {{ TILE_N }}], strides: [{{ TILE_N }}, 1] : {{ W_tile_desc.get_mlir_shape(DATA_STYPE) }} to memref<{{ TILE_K }}x{{ TILE_N }}xf32, strided<[{{ TILE_N }}, 1], offset: ?>, 1> + affine.for %tile_o_h = 0 to {{ TILE_O_H }} { + affine.for %tile_o_w = 0 to {{ TILE_O_W }} { + %tile_i_h = affine.apply #map_I_H(%tile_o_h, %tile_k_h) + %tile_i_w = affine.apply #map_I_W(%tile_o_w, %tile_k_w) + %offset_x = affine.apply #offset_x_map(%tile_i_h, %tile_i_w) + %offset_y = affine.apply #offset_y_map(%tile_o_h, %tile_o_w) + %X_buffer = memref.reinterpret_cast %input_buffer to offset: [%offset_x], sizes: [{{ TILE_M }}, {{ TILE_K }}], strides: [{{ TILE_K }}, 1] : {{ X_tile_desc.get_mlir_shape(DATA_STYPE) }} to memref<{{ TILE_M }}x{{ TILE_K }}xf32, strided<[{{ TILE_K }}, 1], offset: ?>, 1> + %Y_buffer = memref.reinterpret_cast %output_buffer to offset: [%offset_y], sizes: [{{ TILE_M }}, {{ TILE_N }}], strides: [{{ TILE_N }}, 1] : {{ Y_tile_desc.get_mlir_shape(DATA_STYPE) }} to memref<{{ TILE_M }}x{{ TILE_N }}xf32, strided<[{{ TILE_N }}, 1], offset: ?>, 1> + linalg.matmul ins(%X_buffer, %W_buffer : memref<{{ TILE_M }}x{{ TILE_K }}xf32, strided<[{{ TILE_K }}, 1], offset: ?>, 1>, memref<{{ TILE_K }}x{{ TILE_N }}xf32, strided<[{{ TILE_N }}, 1], offset: ?>, 1>) + outs(%Y_buffer : memref<{{ TILE_M }}x{{ TILE_N }}xf32, strided<[{{ TILE_N }}, 1], offset: ?>, 1>) + } { inner_loop=true } + } { inner_loop=true } + } { inner_loop=true } + } { inner_loop=true } + } { accumulation_loop=true, subtile_loop="k" } + } { accumulation_loop=true } + } { accumulation_loop=true } + // Store output matrix + {{kernel.store_output(indent_size=10)}} + } { outer_loop=true } + } { outer_loop=true } + } { outer_loop=true, subtile_loop="n" } + } { outer_loop=true, subtile_loop="m" } return } """ - -CONV2D_FUNC_TEMPLATE = r""" -def {{ FUNC_NAME }}({{ INPUT }}, {{ WEIGHT }}{% if BIAS %}, {{ BIAS }}{% endif %}, {{ OUT }}): - {{ INPUT }}_cpu = {{ INPUT }}.cpu() - {{ WEIGHT }}_cpu = {{ WEIGHT }}.cpu(){% if BIAS %} - {{ BIAS }}_cpu = {{ BIAS }}.cpu(){% endif %} - {{ OUT }}_cpu = {{ OUT }}.cpu() - - # Torch support NCHW, so we need to transpose for now - {{ INPUT }}_cpu = {{ INPUT }}_cpu.permute(0, 2, 3, 1) - {{ WEIGHT }}_cpu = {{ WEIGHT }}_cpu.permute(0, 2, 3, 1) - {{ OUT }}_cpu = {{ OUT }}_cpu.permute(0, 2, 3, 1) - {{ OUT }}_cpu.zero_() - - input_shape = {{ INPUT }}_cpu.shape - weight_shape = {{ WEIGHT }}_cpu.shape - output_shape = {{ OUT }}_cpu.shape - {{ OUT }}_cpu = {{ OUT }}_cpu.reshape(-1, output_shape[3]).contiguous() - - input_pad_shape = (input_shape[0], input_shape[1]+2*{{ PADDING_H }}, input_shape[2]+2*{{ PADDING_W }}, input_shape[3]) - input_pad = torch.zeros(input_pad_shape) - - if {{ PADDING_H }} != 0 and {{ PADDING_W }} != 0: - input_pad[:, {{ PADDING_H }}:-{{ PADDING_H }}, {{ PADDING_W }}:-{{ PADDING_W }}, :] = {{ INPUT }}_cpu - elif {{ PADDING_H }} != 0: - input_pad[:, {{ PADDING_H }}:-{{ PADDING_H }}, :, :] = {{ INPUT }}_cpu - elif {{ PADDING_W }} != 0: - input_pad[:,:, {{ PADDING_W }}:-{{ PADDING_W }}, :] = {{ INPUT }}_cpu - else: - input_pad = {{ INPUT }}_cpu - - {% if VALIDATION_MODE %} - {% endif %} - - for kh in range(weight_shape[1]): - for kw in range(weight_shape[2]): - input_tile = input_pad[:, kh:input_pad_shape[1]-(weight_shape[1]-1)+kh, kw:input_pad_shape[2]-(weight_shape[2]-1)+kw, :] - input_tile = input_tile[:,::{{ STRIDE_H }},::{{ STRIDE_W }}, :] - kernel_tile = {{ WEIGHT }}_cpu[:, kh, kw, :].t() - input_tile = input_tile.reshape(-1, input_pad_shape[3]) - - {% if VALIDATION_MODE %} - if kh == 0 and kw == 0: - {{ KERNEL_NAME }}(input_tile, kernel_tile, {{ OUT }}_cpu, {{ OUT }}_cpu, intermediate_op=0b01) - elif kh == weight_shape[1]-1 and kw == weight_shape[2]-1: - {{ KERNEL_NAME }}(input_tile, kernel_tile, {{ OUT }}_cpu, {{ OUT }}_cpu, intermediate_op=0b10) - else: - {{ KERNEL_NAME }}(input_tile, kernel_tile, {{ OUT }}_cpu, {{ OUT }}_cpu, intermediate_op=0b11) - {% else %} - {{ KERNEL_NAME }}(input_tile, kernel_tile, {{ OUT }}_cpu, {{ OUT }}_cpu) # input, weight, bias, out - {% endif %} - {% if BACKENDSIM_EAGER_MODE %} - yield ({{KERNEL_NAME}}, (input_tile, kernel_tile, {{ OUT }}_cpu, {{ OUT }}_cpu)) - {% endif %} - - {{ OUT }}_cpu = {{ OUT }}_cpu.reshape(output_shape) - {{ OUT }}_cpu = {{ OUT }}_cpu.permute(0, 3, 1, 2){% if BIAS %} - {{ OUT }}_cpu += {{ BIAS }}_cpu.reshape(-1, 1, 1) #TODO: BIAS should be added in the kernel{% endif %} - {{ OUT }}.copy_({{ OUT }}_cpu) +WRAPPER_TEMPLATE = r""" +def {{ FUNC_NAME }}{{kernel.def_wrapper()}}: + # Padding input + padded_shape = list(X.shape) + padded_shape[2] += 2 * {{ PADDING_H }} + padded_shape[3] += 2 * {{ PADDING_W }} + X_padding = torch.zeros(padded_shape, device=X.device) + X_padding[:, :, {{ PADDING_H }}:X.shape[2] + {{ PADDING_H }}, {{ PADDING_W }}:X.shape[3] + {{ PADDING_W }}] = X + + # Tanspose inputs + {%- for buf, name in kernel.get_conv_inputs().items() %} + {%- if name == "X" %} + {{ name }} = {{ name }}_padding.permute(2, 3, 0, 1).contiguous() # (BATCH, I_C, I_H, I_W) -> (I_H, I_W, BATCH, I_C) + {%- elif name == "W" %} + {{ name }} = {{ name }}.permute(2, 3, 1, 0).contiguous() # (O_C, I_C, K_H, K_W) -> (K_H, K_W, I_C, O_C) + {%- elif name == "Bias" %} + {{ name }} = {{ name }} + {%- endif %} + {%- endfor %} + + # Launch kernel + {{ KERNEL_NAME }} + {%- if BACKENDSIM_EAGER_MODE %} + yield ({{KERNEL_NAME}}, ) + {%- endif %} """ - class MLIRConvTemplate(MLIRTemplate): def __init__(self, input_nodes, layout, input_reorder=None, **kwargs): super().__init__("kernel", input_nodes, layout, input_reorder) self.stride = kwargs["stride"] self.padding = kwargs["padding"] self.dilation = kwargs["dilation"] - weight_shape = [str(i) for i in input_nodes[1].layout.size] - self.function_name = "Conv2D_" + "_".join(weight_shape)+ "_" \ + self.weight_shape = [str(i) for i in input_nodes[1].layout.size] + self.input_shape = [i for i in input_nodes[0].layout.size] + self.function_name = "Conv2D_" + "_".join(self.weight_shape)+ "_" \ + "_".join([str(i) for i in self.stride]) \ + "_" + "_".join([str(i) for i in self.padding]) \ + "_" + "_".join([str(i) for i in self.dilation]) - self.gemm_args = ['input', 'weight', 'bias', 'output'] - - self.calculate_gemm_shape() + self.kernel_args = ['X', 'W', 'Bias', 'Y'] - def is_transposed(self, node): - if isinstance(node, ReinterpretView): - if node.layout.stride != node.data.layout.stride: - if node.layout.stride[-2] == node.data.layout.stride[-1] and node.layout.stride[-1] == node.data.layout.stride[-2]: - return True - else: - raise NotImplementedError("If the stride is not equal to the original stride, it should have been transposed.") - return False - - def calculate_gemm_shape(self): - input_shape = self.input_nodes[0].get_size() - weight_shape = self.input_nodes[1].get_size() - gemm_h = int((input_shape[2] + 2*self.padding[0] - (weight_shape[2]-1) - 1) / self.stride[0]) + 1 - gemm_w = int((input_shape[3] + 2*self.padding[1] - (weight_shape[3]-1) - 1) / self.stride[1]) + 1 - - self.gemm_input_shape = [input_shape[0],input_shape[1],gemm_h, gemm_w] - self.gemm_weight_shape = [weight_shape[0],weight_shape[1],1,1] - self.gemm_output_shape = [self.gemm_input_shape[2]*self.gemm_input_shape[3], self.gemm_weight_shape[0]] # Consider Batch size 1 - - def def_kernel(self) ->str: - input_size = self.gemm_input_shape[1]*self.gemm_input_shape[2]*self.gemm_input_shape[3] - weight_size = self.gemm_weight_shape[0]*self.gemm_weight_shape[1] - output_size = self.gemm_output_shape[0]*self.gemm_output_shape[1] - return f"%X: memref<{input_size}xf32>, %W: memref<{weight_size}xf32>, %B: memref<{output_size}xf32>, %Y: memref<{output_size}xf32>" + def get_padded_input_size(self, X): + input_padded = list(X.layout.size) + input_padded[2] += 2 * self.padding[0] + input_padded[3] += 2 * self.padding[1] + return math.prod(input_padded) def render(self, kernel: MLIRTemplateKernel, template_buffer_node = None, epilogue_nodes: Optional[List[IRNode]] = None, **kwargs): + # Extract input arguments info if template_buffer_node is not None: self.output_node = template_buffer_node - if epilogue_nodes is not None and len(epilogue_nodes) > 0: - self.output_node = cast(Buffer, epilogue_nodes[-1]) - self.function_name += f"_fused_{epilogue_nodes[0].node.origin_node.name}" + self.kernel = kernel + self.epilogue_nodes = epilogue_nodes X, W = self.input_nodes[0], self.input_nodes[1] Y = self.output_node Bias = None if len(self.input_nodes) == 2 else self.input_nodes[2] - M = self.gemm_input_shape[2] * self.gemm_input_shape[3] - N = self.gemm_weight_shape[0] - K = self.gemm_weight_shape[1] - TILE_M, TILE_N, TILE_K = kernel.gemm_combination_mapping(M, N, K) - kernel.tile_size = [TILE_M, TILE_N, TILE_K] - kernel.loop_size = [M, N, K] - - W_transposed = self.is_transposed(W) - X_transposed = self.is_transposed(X) - - options = dict( + if epilogue_nodes is not None: + extra_node_rw = { + item.name for epilogue_node in epilogue_nodes + for item in epilogue_node.read_writes.reads | epilogue_node.read_writes.writes + if item.name != Y.name + } + n_extra_node = len(extra_node_rw) if epilogue_nodes is not None else 0 + + BATCH, I_C, I_H, I_W = X.layout.size + O_C, _, K_H, K_W = W.layout.size + O_H = Y.layout.size[2] if template_buffer_node is None else template_buffer_node.layout.size[2] + O_W = Y.layout.size[3] if template_buffer_node is None else template_buffer_node.layout.size[3] + PADDING_H=self.padding[0] + PADDING_W=self.padding[1] + STRIDE_H=self.stride[0] + STRIDE_W=self.stride[1] + + # Select tile size adn template + conv_template = CONV_TEMPLATE + TILE_K_H, TILE_K_W, TILE_O_H, TILE_O_W, TILE_M, TILE_N, TILE_K, TILE_I_H, TILE_I_W, SUB_TILE_I_H, SUB_TILE_I_W, SUB_TILE_K_H, SUB_TILE_K_W, SUB_TILE_M, SUB_TILE_N, SUB_TILE_K, TOG_latency = self.select_tile(kernel, n_extra_node, BATCH, I_C, O_C, K_H, K_W, O_H, O_W) + SUB_TILE_N = TILE_N if TILE_N > 512 else SUB_TILE_N + TOG_latency = 8 if TOG_latency < 8 else TOG_latency + kernel.loop_size = [TOG_latency, TILE_N, TILE_K] + + # Prepare tile descriptors + vlane_stride = 1 + vlane_split_axis = 1 + X_tile_size = [TILE_I_H, TILE_I_W, TILE_M, TILE_K ] + X_tile_stride = [TILE_I_W*TILE_M*TILE_K, TILE_M*TILE_K, 1, TILE_M] + X_tile_desc = mlir_common.MLIRMultiDimTile(X_tile_size, kernel.vector_lane, 3, vlane_stride) + X_tile_desc.set_tile_size_stride(X_tile_size, X_tile_stride) + X_tile_desc.set_name("input_buffer") + X_dim = [Symbol("index_i_h"), Symbol("index_i_w"), Symbol("tile_m"), Symbol("tile_k")] + X_idx = [X_dim[0]*(I_W+2*PADDING_W)*BATCH*I_C, X_dim[1]*I_C*BATCH, X_dim[2]*I_C, X_dim[3]] + + W_tile_size = [TILE_K_H, TILE_K_W, TILE_K, TILE_N] + W_tile_stride = [TILE_K_W * TILE_K * TILE_N, TILE_K * TILE_N, 1, TILE_K] + W_tile_desc = mlir_common.MLIRMultiDimTile(X_tile_size, kernel.vector_lane, 3, vlane_stride) + W_tile_desc.set_tile_size_stride(W_tile_size, W_tile_stride) + W_tile_desc.set_name("weight_buffer") + W_dim = [Symbol("k_h"), Symbol("k_w"), Symbol("tile_k"), Symbol("tile_n")] + W_idx = [W_dim[0]*K_W*I_C*O_C , W_dim[1]*I_C*O_C, W_dim[2]*O_C, W_dim[3]] + + Y_tile_size = [TILE_M, TILE_N, TILE_O_H, TILE_O_W] + Y_tile_stride = [1, TILE_M, TILE_O_W * TILE_M * TILE_N, TILE_M * TILE_N] # N, C, H, W + Y_tile_desc = mlir_common.MLIRMultiDimTile(Y_tile_size, kernel.vector_lane, vlane_split_axis, vlane_stride) + Y_tile_desc.set_tile_size_stride(Y_tile_size, Y_tile_stride) + Y_tile_desc.set_name("output_buffer") + Y_dim = [Symbol("tile_m"), Symbol("tile_n"), Symbol("o_h"), Symbol("o_w")] + Y_idx = [Y_dim[0]*O_C*O_H*O_W, Y_dim[1]*O_H*O_W, Y_dim[2]*O_W, Y_dim[3]] + + # Extract Bias info + Bias_idx = [Number(0), Symbol("tile_n"), Number(0), Number(0)] + + kernel.render_options = dict( KERNEL_NAME=self.name, - KERNEL_DEF=self.def_kernel(), kernel=kernel, - M=M, - N=N, - K=K, + X=X, W=W, Y=Y, BIAS=Bias, + PADDED_INPUT_SIZE=self.get_padded_input_size(X), + BATCH=BATCH, + I_C=I_C, + I_H=I_H, + I_W=I_W, + O_C=O_C, + K_H=K_H, + K_W=K_W, + O_H=O_H, + O_W=O_W, TILE_M=TILE_M, TILE_N=TILE_N, TILE_K=TILE_K, + TILE_I_H=TILE_I_H, + TILE_I_W=TILE_I_W, + TILE_O_H=TILE_O_H, + TILE_O_W=TILE_O_W, + TILE_K_H=TILE_K_H, + TILE_K_W=TILE_K_W, + SUB_TILE_M=SUB_TILE_M, + SUB_TILE_N=SUB_TILE_N, + SUB_TILE_K=SUB_TILE_K, + SUB_TILE_I_H=SUB_TILE_I_H, + SUB_TILE_I_W=SUB_TILE_I_W, + SUB_TILE_K_H=SUB_TILE_K_H, + SUB_TILE_K_W=SUB_TILE_K_W, + PADDING_H=PADDING_H, + PADDING_W=PADDING_W, + STRIDE_H=STRIDE_H, + STRIDE_W=STRIDE_W, + X_tile_desc = X_tile_desc, + W_tile_desc = W_tile_desc, + Y_tile_desc = Y_tile_desc, + X_idx = X_idx, + W_idx = W_idx, + Bias_idx = Bias_idx, DATA_STYPE="f32", - DATA_SIZE=4, + input_reorder=self.input_reorder ) - code = self._template_from_string(GEMM_TEMPLATE).render(**options) - - self.header = f"float X_spad[{TILE_M * TILE_K // kernel.vector_lane}] __attribute__ ((section(\".spad\")));\n" - self.header += f"float W_spad[{TILE_K * TILE_N // kernel.vector_lane}] __attribute__ ((section(\".spad\")));\n" - self.header += f"float Y_spad[{TILE_M * TILE_N // kernel.vector_lane}] __attribute__ ((section(\".spad\")));\n" - self.gem5_header = f"float X_spad[{TILE_M * TILE_K}] __attribute__ ((section(\".spad\")));\n" - self.gem5_header += f"float W_spad[{TILE_K * TILE_N}] __attribute__ ((section(\".spad\")));\n" - self.gem5_header += f"float Y_spad[{TILE_M * TILE_N}] __attribute__ ((section(\".spad\")));\n" - if Bias is not None: - self.header += f"float B_spad[{TILE_M * TILE_N // kernel.vector_lane}] __attribute__ ((section(\".spad\")));\n" - self.gem5_header += f"float B_spad[{TILE_M * TILE_N}] __attribute__ ((section(\".spad\")));\n" - - kernel.add_loop_info([options["M"], options["N"], options["K"]], [options["TILE_M"], options["TILE_N"], options["TILE_K"]]) - kernel.def_kernel(inputs=[X, W, Bias], outputs=[Y], names_str="X, W, Bias, Y", input_reorder=self.input_reorder) + kernel.epilogue_info = dict( + output_node = self.output_node.name, + sram_var = "output_buffer", + dram_var = "Y", + dram_idx = Y_idx, + dram_tile_desc = Y_tile_desc, + dim_aliasing = {"index0":"tile_m", "index1":"tile_n", "index2":"o_h", "index3":"o_w"} + ) + kernel.exception_nodes["X"] = {"numel" : (I_W+2*PADDING_W)*(I_H+2*PADDING_H)*I_C*BATCH} + code = self._template_from_string(conv_template).render(**kernel.render_options) + kernel.add_loop_info([kernel.render_options["K_H"], kernel.render_options["K_W"], kernel.render_options["O_H"], kernel.render_options["O_W"], kernel.render_options["BATCH"], kernel.render_options["O_C"], kernel.render_options["I_C"]], [kernel.render_options["TILE_M"], kernel.render_options["TILE_N"], kernel.render_options["TILE_K"]]) return code + def select_tile(self, kernel, n_extra_node, BATCH, I_C, O_C, K_H, K_W, O_H, O_W): + TILE_K_H, TILE_K_W, TILE_O_H, TILE_O_W, TILE_M, TILE_N, TILE_K = kernel.conv_combination_mapping(BATCH, O_C, I_C, K_H, K_W, O_H, O_W, self.stride, self.dilation, n_extra_node) + SUB_TILE_M = TILE_M if TILE_M < kernel.vector_lane else kernel.vector_lane + SUB_TILE_N = TILE_N if TILE_N < kernel.vector_lane else kernel.vector_lane + SUB_TILE_K = TILE_K + TILE_I_H = 1 + (TILE_O_H - 1) * self.stride[0] + (TILE_K_H - 1) * self.dilation[0] + TILE_I_W = 1 + (TILE_O_W - 1) * self.stride[1] + (TILE_K_W - 1) * self.dilation[1] + SUB_TILE_I_H, SUB_TILE_I_W, SUB_TILE_K_H, SUB_TILE_K_W = 1, 1, 1, 1 + SUB_TILE_N = TILE_N if TILE_N > 512 else SUB_TILE_N + TOG_latency = BATCH if TILE_M > BATCH else TILE_M + TOG_latency = 8 if TOG_latency < 8 else TOG_latency + return TILE_K_H,TILE_K_W,TILE_O_H,TILE_O_W,TILE_M,TILE_N,TILE_K,TILE_I_H,TILE_I_W,SUB_TILE_I_H,SUB_TILE_I_W,SUB_TILE_K_H,SUB_TILE_K_W,SUB_TILE_M,SUB_TILE_N,SUB_TILE_K,TOG_latency + def outer_func_render(self, kernel_name, input_args): + X, W = self.input_nodes[0], self.input_nodes[1] + Y = self.output_node + Bias = None if len(self.input_nodes) == 2 else self.input_nodes[2] + + eager_mode = int(os.environ.get('BACKENDSIM_EAGER_MODE', default=False)) options = dict( + kernel=self.kernel, KERNEL_NAME=kernel_name, - FUNC_NAME=self.function_name, - INPUT=input_args[0], - WEIGHT=input_args[1], - BIAS=input_args[2] if len(input_args) == 4 else None, - OUT=input_args[3] if len(input_args) == 4 else input_args[2], + FUNC_NAME=self.function_name + f"_{len(input_args)}", + INPUT=X, + WEIGHT=W, + BIAS=Bias, + OUTPUT=Y, PADDING_H=self.padding[0], PADDING_W=self.padding[1], - STRIDE_H=self.stride[0], - STRIDE_W=self.stride[1], - DILATION_H=self.dilation[0], - DILATION_W=self.dilation[1], VALIDATION_MODE=extension_config.CONFIG_TORCHSIM_VALIDATION_MODE, - BACKENDSIM_EAGER_MODE=extension_config.CONFIG_BACKENDSIM_EAGER_MODE, - HASH_VALUE=self.hash_value + BACKENDSIM_EAGER_MODE=eager_mode, + input_reorder=self.input_reorder ) - code = self._template_from_string(CONV2D_FUNC_TEMPLATE).render(**options) - return code, self.function_name + code = self._template_from_string(WRAPPER_TEMPLATE).render(**options) + return code, self.function_name + f"_{len(input_args)}" def get_arg_attributes(self): arg_attributes = [] - input_shape = self.input_nodes[0].get_size() - weight_shape = self.input_nodes[1].get_size() - gemm_h = int((input_shape[2] + 2*self.padding[0] - (weight_shape[2]-1) - 1) / self.stride[0]) + 1 - gemm_w = int((input_shape[3] + 2*self.padding[1] - (weight_shape[3]-1) - 1) / self.stride[1]) + 1 + X = self.input_nodes[0] + X_shape = [X.get_size()[i] for i in (2, 3, 0, 1)] + X_shape[0] += 2 * self.padding[0] + X_shape[1] += 2 * self.padding[1] - gemm_input_shape = [input_shape[0],input_shape[1],gemm_h, gemm_w] - gemm_weight_shape = [weight_shape[0],weight_shape[1],1,1] - gemm_output_shape = [gemm_input_shape[2]*gemm_input_shape[3], gemm_weight_shape[0]] # Consider Batch size 1 + def compute_stride(shape): + stride = [1] * len(shape) + for i in range(len(shape)-2, -1, -1): + stride[i] = stride[i+1] * shape[i+1] + return stride - arg_attributes.append([self.gemm_args[0], [MLIRKernelArgs.MLIR_ARGS_IN, self.input_nodes[0].layout.dtype, math.prod(gemm_input_shape)]]) - arg_attributes.append([self.gemm_args[1], [MLIRKernelArgs.MLIR_ARGS_IN, self.input_nodes[1].layout.dtype, math.prod(gemm_weight_shape)]]) - arg_attributes.append([self.gemm_args[2], [MLIRKernelArgs.MLIR_ARGS_IN, self.input_nodes[0].layout.dtype, math.prod(gemm_output_shape)]]) - arg_attributes.append([self.gemm_args[3], [MLIRKernelArgs.MLIR_ARGS_OUT, self.input_nodes[0].layout.dtype, math.prod(gemm_output_shape)]]) + X_stride = compute_stride(X_shape) + arg_attributes.append([X.data.data.name, [MLIRKernelArgs.MLIR_ARGS_IN, X.layout.dtype, math.prod(X_shape), X_shape, X_stride]]) return arg_attributes @@ -259,7 +344,7 @@ def codegen_header(self, code, extra_headers): spike_write_path = os.path.join(write_path, "global_var.h") gem5_write_path = os.path.join(write_path, "gem5_global_var.h") if not os.path.exists(spike_write_path): - write_atomic(spike_write_path, self.header+extra_headers[0]) + write_atomic(spike_write_path, extra_headers[0]) if not os.path.exists(gem5_write_path): - write_atomic(gem5_write_path, self.gem5_header+extra_headers[1]) + write_atomic(gem5_write_path, extra_headers[1]) self.hash_value = get_hash(code.strip()) \ No newline at end of file diff --git a/PyTorchSimFrontend/mlir/mlir_gemm_template.py b/PyTorchSimFrontend/mlir/mlir_gemm_template.py index 954059c0..f706c2e5 100644 --- a/PyTorchSimFrontend/mlir/mlir_gemm_template.py +++ b/PyTorchSimFrontend/mlir/mlir_gemm_template.py @@ -1,70 +1,58 @@ import os -from typing import List, Optional, cast +import json +from pathlib import Path +from torch import empty_strided +from typing import List, Optional +import sympy from PyTorchSimFrontend.mlir.mlir_template import MLIRTemplate from PyTorchSimFrontend.mlir.mlir_template import MLIRTemplateKernel -from torch._inductor.ir import Buffer from torch._inductor.ir import IRNode -from torch._inductor.ir import ReinterpretView from torch._inductor.codecache import write_atomic import PyTorchSimFrontend.extension_codecache as extension_codecache +from PyTorchSimFrontend import extension_config +from PyTorchSimFrontend.mlir import mlir_common GEMM_TEMPLATE = r""" -{% if X_transposed %}#map0 = affine_map<(d0, d1) -> (d1 * {{ M }} + d0)>{% else %}#map0 = affine_map<(d0, d1) -> (d0 * {{ K }} + d1)>{% endif %} -{% if W_transposed %}#map1 = affine_map<(d0, d1) -> (d1 * {{ K }} + d0)>{% else %}#map1 = affine_map<(d0, d1) -> (d0 * {{ N }} + d1)>{% endif %} -#map2 = affine_map<(d0, d1) -> (d0 * {{ N }} + d1)> -memref.global @X_spad : memref<{{ TILE_M }}x{{ TILE_K }}xf32, 1> -memref.global @W_spad : memref<{{ TILE_K }}x{{ TILE_N }}xf32, 1> -memref.global @Y_spad : memref<{{ TILE_M }}x{{ TILE_N }}xf32, 1> +// GEMM {% if prologue_nodes -%}prologue fused{%- endif %} {% if epilogue_nodes -%}eilogue fused{%- endif %} kernel +// M = {{ M }} +// N = {{ N }} +// K = {{ K }} +// TILE_M = {{ TILE_M }} +// TILE_N = {{ TILE_N }} +// TILE_K = {{ TILE_K }} +// SUB_TILE_M = {{ SUB_TILE_M }} +// SUB_TILE_N = {{ SUB_TILE_N }} {{kernel.def_global_vars()}} func.func @{{ KERNEL_NAME }}{{kernel.def_kernel(inputs=[X, W, Bias], outputs=[Y], names_str="X, W, Bias, Y", input_reorder=input_reorder)}} { - %c_mvin = arith.constant 2 : index - %c_mvin2 = arith.constant 1 : index{% if Bias %} - %c_mvin3 = arith.constant 14 : index{% endif %} - %c_mvout = arith.constant 3 : index - %c_set = arith.constant 2 : index - %x_chunk = arith.constant {% if X_transposed %} {{ kernel.vector_lane * 2 + 0 }} {% else %} {{ 2 }} {% endif %} : index - %w_chunk = arith.constant {% if W_transposed %} {{ TILE_K * 2 + 0 }} {% else %} {{ 2 }} {% endif %} : index - %M = arith.constant {{ M }} : index - %N = arith.constant {{ N }} : index - %K = arith.constant {{ K }} : index - %X_buffer = memref.get_global @X_spad : memref<{{ TILE_M }}x{{ TILE_K }}xf32, 1> - %W_buffer = memref.get_global @W_spad : memref<{{ TILE_K }}x{{ TILE_N }}xf32, 1> - %Y_buffer = memref.get_global @Y_spad : memref<{{ TILE_M }}x{{ TILE_N }}xf32, 1> - %tag = memref.alloc() : memref<1xi32>{% if not Bias %} - %v0 = arith.constant dense<0.0> : vector<{{ TILE_M * TILE_N // kernel.vector_lane }}xf32>{% endif %} - %c0 = arith.constant 0 : index - - affine.for %t_m = 0 to {{ M }} step {{ TILE_M }} { - affine.for %t_n = 0 to {{ N }} step {{ TILE_N }} { - %index2 = affine.apply #map2(%t_m, %t_n) - {% if Bias -%} - affine.dma_start %Bias[ - {%- if Bias_rank == 2 -%} %index2 {%- else -%} %t_n {%- endif -%} - ], %Y_buffer[%c0, %c0], %tag[0], %c_mvin3, % - {%- if Bias_rank == 2 -%} N {%- else -%} c0 {%- endif -%} - , %c_set : memref< - {%- if Bias_rank == 2 -%} {{ M * N }} {%- else -%} {{ N }} {%- endif -%} - xf32>, memref<{{ TILE_M }}x{{ TILE_N }}xf32, 1>, memref<1xi32> { subtile_size=[{{ kernel.vector_lane }}, {{ kernel.vector_lane }}], async=1 } - {%- else -%} - affine.vector_store %v0, %Y_buffer[0, 0] : memref<{{ TILE_M }}x{{ TILE_N }}xf32, 1>, vector<{{ TILE_M * TILE_N // kernel.vector_lane }}xf32> + {{ kernel.def_sram_buffer("X", X_tile_desc, indent_size=2) }} + {{ kernel.def_sram_buffer("W", W_tile_desc, indent_size=2) }} + {{ kernel.def_sram_buffer("Y", Y_tile_desc, indent_size=2) }} + {% if not Bias %} + %v0 = arith.constant dense<0.0> : vector<{{ kernel.get_spad_size_per_lane(TILE_M, TILE_N) }}xf32>{% endif %} + {{ kernel.def_local_vars(indent_size=2) }} + affine.for %index0 = 0 to {{ M }} step {{ TILE_M }} { + affine.for %index1 = 0 to {{ N }} step {{ TILE_N }} { + {%- if Bias %} + {{ kernel.def_dma_op("MVIN", "Bias", Bias_idx, Y_tile_desc, subtile_size=[SUB_TILE_M, SUB_TILE_N], indent_size=6) }} + {%- else %} + affine.vector_store %v0, %Y_buffer[0, 0] : {{ Y_tile_desc.get_mlir_shape(DATA_STYPE) }}, vector<{{ kernel.get_spad_size_per_lane(TILE_M, TILE_N) }}xf32> {%- endif %} - affine.for %t_k = 0 to {{ K }} step {{ TILE_K }} { - %index0 = affine.apply #map0(%t_m, %t_k) - %index1 = affine.apply #map1(%t_k, %t_n) - affine.dma_start %X[%index0], %X_buffer[%c0, %c0], %tag[0], %c_mvin, - {%- if X_transposed -%} %M, %x_chunk {%- else -%} %K, %x_chunk {%- endif -%} - : memref<{{ M * K }}xf32>, memref<{{ TILE_M }}x{{ TILE_K }}xf32, 1>, memref<1xi32> { subtile_size=[{{ kernel.vector_lane }}, {{ TILE_K }}], async=1{% if X_transposed %}, transpose=1{% endif %} } - affine.dma_start %W[%index1], %W_buffer[%c0, %c0], %tag[0], %c_mvin2, - {%- if W_transposed -%} %K, %w_chunk {%- else -%} %N, %w_chunk {%- endif -%} - : memref<{{ K * N }}xf32>, memref<{{ TILE_K }}x{{ TILE_N }}xf32, 1>, memref<1xi32> { subtile_size=[{{ TILE_K }}, {{ kernel.vector_lane }}], async=1{% if W_transposed %}, transpose=1{% endif %} } - linalg.matmul ins(%X_buffer, %W_buffer : memref<{{ TILE_M }}x{{ TILE_K }}x{{ DATA_STYPE }}, 1>, memref<{{ TILE_K }}x{{ TILE_N }}x{{ DATA_STYPE }}, 1>) - outs(%Y_buffer : memref<{{ TILE_M }}x{{ TILE_N }}x{{ DATA_STYPE }}, 1>) - } { accumulation_loop=true } - {{kernel.store_output()}} - } { outer_loop=true } - } { outer_loop=true } + affine.for %index2 = 0 to {{ K }} step {{ TILE_K }} { + {% if prologue_nodes -%} + // prologue nodes + {{kernel.load_input(indent_size=8)}} + {%- else -%} + {{ kernel.def_dma_op("MVIN", "X", X_idx, X_tile_desc, subtile_size=[SUB_TILE_M, SUB_TILE_K], indent_size=8) }} + {{ kernel.def_dma_op("MVIN", "W", W_idx, W_tile_desc, subtile_size=[SUB_TILE_K, SUB_TILE_N], indent_size=8) }} + {%- endif %} + linalg.matmul ins(%X_buffer, %W_buffer : {{ X_tile_desc.get_mlir_shape(DATA_STYPE) }}, {{ W_tile_desc.get_mlir_shape(DATA_STYPE) }}) + outs(%Y_buffer : {{ Y_tile_desc.get_mlir_shape(DATA_STYPE) }}) + } { accumulation_loop=true, subtile_loop="k" } + {{kernel.store_output(indent_size=6)}} + } { outer_loop=true, subtile_loop="n" } + } { outer_loop=true, subtile_loop="m" } return } """ @@ -75,83 +63,267 @@ } """ +GEMM_REDUCTION_TEMPLATE = r""" +// GEMM reduction kernel +// M = {{ M }} +// N = {{ N }} +// K = {{ K }} +// TILE_M = {{ TILE_M }} +// TILE_N = {{ TILE_N }} +// TILE_K = {{ TILE_K }} +// SUB_TILE_M = {{ SUB_TILE_M }} +// SUB_TILE_N = {{ SUB_TILE_N }} +{{kernel.def_global_vars()}} + +func.func @{{ KERNEL_NAME }}{{kernel.def_kernel(inputs=[X, W, Bias], outputs=[Y], names_str="X, W, Bias, Y", input_reorder=input_reorder)}} { + {{ kernel.def_sram_buffer("X", X_tile_desc, indent_size=2) }} + {{ kernel.def_sram_buffer("W", W_tile_desc, indent_size=2) }} + {{ kernel.def_sram_buffer("Y", Y_tile_desc, indent_size=2) }} + {% if not Bias %} + %v0 = arith.constant dense<0.0> : vector<{{ kernel.get_spad_size_per_lane(TILE_M, TILE_N) }}xf32> + {% endif %} + {{ kernel.def_local_vars(indent_size=2) }} + affine.for %index1 = 0 to {{ N }} step {{ TILE_N }} { + affine.for %index0 = 0 to {{ M }} step {{ TILE_M }} { + %Y_bufferT = memref.reinterpret_cast %Y_buffer to offset: [0], sizes: [{{ TILE_M }}, {{ TILE_N }}], strides: [{{ TILE_N }}, 1] : {{ Y_tile_desc.get_mlir_shape(DATA_STYPE) }} to memref<{{ TILE_M }}x{{ TILE_N }}xf32, 1> + {%- if Bias %} + {{ kernel.def_dma_op("MVIN", "Bias", Bias_idx, Y_tile_desc, subtile_size=[SUB_TILE_M, SUB_TILE_N], indent_size=6) }} + {%- else %} + affine.vector_store %v0, %Y_buffer[0, 0] : memref<{{ TILE_N }}x{{ TILE_M }}xf32, 1>, vector<{{ kernel.get_spad_size_per_lane(TILE_M, TILE_N) }}xf32> + {%- endif %} + affine.for %index2 = 0 to {{ K }} step {{ TILE_K }} { + {{ kernel.def_dma_op("MVIN", "X", X_idx, X_tile_desc, subtile_size=[SUB_TILE_M, SUB_TILE_K], indent_size=8) }} + {{ kernel.def_dma_op("MVIN", "W", W_idx, W_tile_desc, subtile_size=[SUB_TILE_K, SUB_TILE_N], indent_size=8) }} + linalg.matmul ins(%X_buffer, %W_buffer : {{ X_tile_desc.get_mlir_shape(DATA_STYPE) }}, {{ W_tile_desc.get_mlir_shape(DATA_STYPE) }}) + outs(%Y_bufferT : memref<{{TILE_M}}x{{TILE_N}}x{{DATA_STYPE}}, 1>) + } { accumulation_loop=true, subtile_loop="k" } + {{kernel.store_output(indent_size=6)}} + } { outer_loop=true, subtile_loop="m" } + {{kernel.reduction_output(indent_size=4)}} + } { outer_loop=true, subtile_loop="n" } + return +} +""" + class MLIRGemmTemplate(MLIRTemplate): def __init__(self, input_nodes, layout, input_reorder=None): super().__init__("kernel", input_nodes, layout, input_reorder) - def is_transposed(self, node): - if isinstance(node, ReinterpretView): - if 0 in node.layout.stride: # [MoE] Temporary solution - if node.layout.stride[1] == 0: - return True - if node.layout.stride != node.data.layout.stride: - if node.layout.stride[-2] == node.data.layout.stride[-1] and node.layout.stride[-1] == node.data.layout.stride[-2]: - return True - else: - raise NotImplementedError("If the stride is not equal to the original stride, it should have been transposed.") - return False - def render(self, kernel: MLIRTemplateKernel, template_buffer_node = None, epilogue_nodes: Optional[List[IRNode]] = None, + prologue_nodes: Optional[List[IRNode]] = None, **kwargs): if template_buffer_node is not None: self.output_node = template_buffer_node - # if epilogue_nodes is not None and len(epilogue_nodes) > 0: - # self.output_node = cast(Buffer, epilogue_nodes[-1]) #FIXME: Temperary solution - X, W = self.input_nodes[0], self.input_nodes[1] - Y = self.output_node - Bias = None if len(self.input_nodes) == 2 else self.input_nodes[2] + # Extract input arguments info + X, W, Y = self.input_nodes[0], self.input_nodes[1], self.output_node + X_tensor = empty_strided(X.layout.size, X.layout.stride) + W_tensor = empty_strided(W.layout.size, W.layout.stride) + if len(W_tensor.size()) > 2 or len(X_tensor.size()) > 2: + raise NotImplementedError("Please report this case to us...") - M, N, K = X.get_size()[0], W.get_size()[1], X.get_size()[1] - if (M == 0) or (N == 0) or (K == 0): - TILE_M, TILE_N, TILE_K = 0, 0, 0 + # Extract fusion info + n_epilogue_node = len(epilogue_nodes) if epilogue_nodes is not None else 0 + n_prologue_node = len(prologue_nodes) if prologue_nodes is not None else 0 + n_extra_read = set() + if epilogue_nodes is not None: + for enode in epilogue_nodes: + n_extra_read.update(enode.node.get_read_names()) + if self.output_node.name in n_extra_read: + n_extra_read.remove(self.output_node.name) + + # Select tile size + M, N, K = X_tensor.size()[0], W_tensor.size()[1], X_tensor.size()[1] + TILE_M, TILE_N, TILE_K, SUB_TILE_M, SUB_TILE_N, SUB_TILE_K = self.select_tile(kernel, M, N, K, n_epilogue_node, n_extra_read, n_prologue_node) + + # Select template code + if (M == 0) or (N == 0) or (K == 0): # exception for MoE template = EMPTY_TEMPLATE + nr_rdim = 0 + epilogue_dim_aliasing = {} + elif n_epilogue_node>=1 and epilogue_nodes[0].is_reduction(): + template = GEMM_REDUCTION_TEMPLATE + epilogue_dim_aliasing = {"index0":"index1", "index1":"index0"} + nr_rdim = 1 else: - TILE_M, TILE_N, TILE_K = kernel.gemm_combination_mapping(M, N, K) template = GEMM_TEMPLATE - kernel.tile_size = [TILE_M, TILE_N, TILE_K] - kernel.loop_size =[M, N, K] + epilogue_dim_aliasing = {"index0":"index0", "index1":"index1"} + nr_rdim = 0 + + TOG_latency = M if SUB_TILE_M > M else SUB_TILE_M + kernel.loop_size =[TOG_latency, SUB_TILE_N, SUB_TILE_K] - W_transposed = self.is_transposed(W) - X_transposed = self.is_transposed(X) + # Prepare tile descriptors + vlane_stride = 1 + vlane_split_axis = 1 + X_tile_size = [TILE_M, TILE_K] + X_tile_stride = [1, TILE_M] + X_tile_desc = mlir_common.MLIRMultiDimTile(X_tile_size, kernel.vector_lane, vlane_split_axis, vlane_stride) + X_tile_desc.set_tile_size_stride(X_tile_size, X_tile_stride) + X_tile_desc.set_name("X_buffer") + X_stride = X.get_layout().stride + X_idx = [sympy.Symbol("index0") * X_stride[0], sympy.Symbol("index2") * X_stride[1]] # To keep index arguemnt order, we used index_list + + W_tile_size = [TILE_K, TILE_N] + W_tile_stride = [1, TILE_K] + W_tile_desc = mlir_common.MLIRMultiDimTile(X_tile_size, kernel.vector_lane, vlane_split_axis, vlane_stride) + W_tile_desc.set_tile_size_stride(W_tile_size, W_tile_stride) + W_tile_desc.set_name("W_buffer") + W_stride = W.get_layout().stride + W_idx = [sympy.Symbol("index2") * W_stride[0], sympy.Symbol("index1") * W_stride[1]] + + vlane_split_axis = vlane_split_axis if nr_rdim==0 else 0 + Y_tile_size = [TILE_M, TILE_N] if nr_rdim == 0 else [TILE_N, TILE_M] + Y_tile_stride=[1, TILE_M] if nr_rdim == 0 else [TILE_M, 1] + Y_tile_desc = mlir_common.MLIRMultiDimTile(Y_tile_size, kernel.vector_lane, vlane_split_axis, vlane_stride) + Y_tile_desc.set_tile_size_stride(Y_tile_size, Y_tile_stride) + Y_tile_desc.set_name("Y_buffer") + Y_stride = Y.get_layout().stride + if nr_rdim == 0: + Y_idx = [sympy.Symbol("index0") * Y_stride[0], sympy.Symbol("index1") * Y_stride[1]] + else: + Y_idx = [sympy.Symbol("index1") * Y_stride[1], sympy.Symbol("index0") * Y_stride[0]] + + # Extract Bias info + Bias = None if len(self.input_nodes) == 2 else self.input_nodes[2] + if Bias is not None: + Bias_stride = Bias.get_layout().stride + if nr_rdim == 0: + Bias_idx = [sympy.Symbol("index0") * Bias_stride[0], sympy.Symbol("index1") * Bias_stride[1]] + else: + Bias_idx = [sympy.Symbol("index1") * Bias_stride[1], sympy.Symbol("index0") * Bias_stride[0]] + else: + Bias_idx = None kernel.render_options = dict( KERNEL_NAME=self.name, kernel=kernel, - M=M, - N=N, - K=K, + M=M, N=N, K=K, TILE_M=TILE_M, TILE_N=TILE_N, TILE_K=TILE_K, + SUB_TILE_M=SUB_TILE_M, + SUB_TILE_N=SUB_TILE_N, + SUB_TILE_K=SUB_TILE_K, DATA_STYPE="f32", - DATA_SIZE=4, - X = X, - W = W, - Y = Y, + X = X, W = W, Y = Y, Bias = Bias, - Bias_rank = len(Bias.data.get_size()) if Bias is not None else 0, - W_transposed = W_transposed, - X_transposed = X_transposed, + X_idx = X_idx, + W_idx = W_idx, + Bias_idx = Bias_idx, + X_tile_desc = X_tile_desc, + W_tile_desc = W_tile_desc, + Y_tile_desc = Y_tile_desc, epilogue_nodes = epilogue_nodes, + prologue_nodes = prologue_nodes, input_reorder = self.input_reorder ) - code = self._template_from_string(template).render(**kernel.render_options) + if prologue_nodes: + prologue_output_name = list(prologue_nodes[0].read_writes.writes)[0].name + if prologue_output_name == X.get_name(): + # Input fusion case + prologue_var = "X" + prologue_sram_var = "X_buffer" + prologue_tile_desc = X_tile_desc + prologue_dim_aliasing = {"index0":"index0", "index1":"index2"} + is_input_fused = True + else: + # Weight fusion case + prologue_var = "W" + prologue_sram_var = "W_buffer" + prologue_tile_desc = W_tile_desc + prologue_dim_aliasing = {"index0":"index2", "index1":"index1"} + is_input_fused = False - self.header = f"float X_spad[{TILE_M * TILE_K // kernel.vector_lane}] __attribute__ ((section(\".spad\")));\n" - self.header += f"float W_spad[{TILE_K * TILE_N // kernel.vector_lane}] __attribute__ ((section(\".spad\")));\n" - self.header += f"float Y_spad[{TILE_M * TILE_N // kernel.vector_lane}] __attribute__ ((section(\".spad\")));\n" - self.gem5_header = f"float X_spad[{TILE_M * TILE_K}] __attribute__ ((section(\".spad\")));\n" - self.gem5_header += f"float W_spad[{TILE_K * TILE_N}] __attribute__ ((section(\".spad\")));\n" - self.gem5_header += f"float Y_spad[{TILE_M * TILE_N}] __attribute__ ((section(\".spad\")));\n" + kernel.prologue_info = dict ( + input_dram_var = "X", + input_sram_var = "X_buffer", + input_tile_desc = X_tile_desc, + input_idx = X_idx, + input_subtile_size = [TILE_M, TILE_K], + input_dim_aliasing = {"index0":"index0", "index1":"index2"}, - kernel.add_loop_info([kernel.render_options["M"], kernel.render_options["N"], kernel.render_options["K"]], [kernel.render_options["TILE_M"], kernel.render_options["TILE_N"], kernel.render_options["TILE_K"]]) + weight_dram_var = "W", + weight_sram_var = "W_buffer", + weight_tile_desc = W_tile_desc, + weight_idx = W_idx, + weight_subtile_size = [TILE_K, TILE_N], + weight_dim_aliasing = {"index0":"index2", "index1":"index1"}, + # Descriptor for fusion + dram_var = prologue_var, + sram_var = prologue_sram_var, + dram_tile_desc = prologue_tile_desc, + dim_aliasing = prologue_dim_aliasing, + is_bmm = False, + is_input_fused = is_input_fused + ) + kernel.epilogue_info = dict( + output_node = self.output_node.name, + dram_var = "Y", + sram_var = "Y_buffer", + dram_idx = Y_idx, + dram_tile_desc = Y_tile_desc, + nr_rdim = nr_rdim, + dim_aliasing = epilogue_dim_aliasing + ) + code = self._template_from_string(template).render(**kernel.render_options) + kernel.add_loop_info([kernel.render_options["M"], kernel.render_options["N"], kernel.render_options["K"]], [kernel.render_options["TILE_M"], kernel.render_options["TILE_N"], kernel.render_options["TILE_K"]]) return code + def select_tile(self, kernel, M, N, K, n_extra_node, n_extra_read, n_prologue_node): + # Check cheat sheet + cheatsheet_path = extension_config.CONFIG_GEMM_CHEATSHEET_PATH + data = {} + if extension_config.CONFIG_GEMM_CHEATSHEET_PATH is not None: + path = Path(cheatsheet_path) + if path.is_file(): + with path.open("r") as f: + data = json.load(f) + + gemm_shape = f"{M}_{K}_{N}" + if gemm_shape in data: + tile_info = data[gemm_shape] + TILE_M = tile_info["TILE_M"] + TILE_N = tile_info["TILE_N"] + TILE_K = tile_info["TILE_K"] + else: # case 2: use gemm_combination_mapping + min_tile = (n_extra_node + n_prologue_node) == 0 + TILE_M, TILE_N, TILE_K = kernel.gemm_combination_mapping(M, N, K, max(len(n_extra_read)-2, 0), n_prologue_node, min_tile=min_tile) + # case 3: use manual tile size + if extension_config.CONFIG_MANUAL_TILE_SIZE: + TILE_M = extension_config.CONFIG_TILE_M + TILE_N = extension_config.CONFIG_TILE_N + TILE_K = extension_config.CONFIG_TILE_K + + # Edge case + if (M == 0) or (N == 0) or (K == 0): + TILE_M, TILE_N, TILE_K = 1, 1, 1 + + # Calculate Sub Tile Size for fine-grained DMA + if extension_config.CONFIG_SUBTILE: + # Case 1: adjust selective fine-grained DMA (SFG-DMA) + SUB_TILE_M = TILE_M if (TILE_M < kernel.vector_lane or n_prologue_node) else kernel.vector_lane + if (TILE_M == M and TILE_N == N and TILE_N <= 512): + SUB_TILE_N = TILE_N if TILE_N < kernel.vector_lane else kernel.vector_lane + else: # Avoid Row Conflict of weights + SUB_TILE_N = TILE_N + SUB_TILE_K = TILE_K + # Case 2: use manual sub tile size (FG-DMA) + if extension_config.CONFIG_MANUAL_SUBTILE_SIZE: + SUB_TILE_M = extension_config.CONFIG_SUBTILE_M + SUB_TILE_N = extension_config.CONFIG_SUBTILE_N + SUB_TILE_K = extension_config.CONFIG_SUBTILE_K + # Case 3: None Subtile + else: + SUB_TILE_M = TILE_M + SUB_TILE_N = TILE_N + SUB_TILE_K = TILE_K + return TILE_M,TILE_N,TILE_K, SUB_TILE_M,SUB_TILE_N,SUB_TILE_K + def codegen_header(self, code, extra_headers): write_path = extension_codecache.get_write_path(code) if not os.path.exists(write_path): @@ -159,6 +331,6 @@ def codegen_header(self, code, extra_headers): spike_write_path = os.path.join(write_path, "global_var.h") gem5_write_path = os.path.join(write_path, "gem5_global_var.h") if not os.path.exists(spike_write_path): - write_atomic(spike_write_path, self.header+extra_headers[0]) + write_atomic(spike_write_path, extra_headers[0]) if not os.path.exists(gem5_write_path): - write_atomic(gem5_write_path, self.gem5_header+extra_headers[1]) + write_atomic(gem5_write_path, extra_headers[1]) diff --git a/PyTorchSimFrontend/mlir/mlir_lowering.py b/PyTorchSimFrontend/mlir/mlir_lowering.py index e7ca37eb..9aa08754 100644 --- a/PyTorchSimFrontend/mlir/mlir_lowering.py +++ b/PyTorchSimFrontend/mlir/mlir_lowering.py @@ -3,15 +3,22 @@ import torch from torch._inductor.lowering import lowerings from torch._inductor.kernel.mm_common import mm_args +# from torch._inductor.select_algorithm import ExternKernelChoice from torch._inductor import ir from torch._inductor.virtualized import V from torch._inductor.ir import TensorBox +from PyTorchSimFrontend.extension_op import MLIRExternKernelChoice from PyTorchSimFrontend.mlir.mlir_gemm_template import MLIRGemmTemplate from PyTorchSimFrontend.mlir.mlir_bmm_template import MLIRBMMTemplate from PyTorchSimFrontend.mlir.mlir_conv_template import MLIRConvTemplate +from PyTorchSimFrontend.mlir.mlir_conv_mt_template import MLIRConvMultiTileTemplate +from PyTorchSimFrontend.mlir.mlir_conv_sb_template import MLIRConvSingleBatchTemplate +from PyTorchSimFrontend.mlir.mlir_conv_sbs_template import MLIRConvSingleBatchStridedTemplate from PyTorchSimFrontend.mlir.mlir_maxpool_template import MLIRMaxPoolTemplate +from PyTorchSimFrontend.extension_config import CONFIG_VECTOR_LANE, CONFIG_USE_TIMING_POOLING aten = torch.ops.aten +aten_spmm = MLIRExternKernelChoice(torch.sparse.mm, "custom_op::sparse_addmm") def tuned_mm(mat1, mat2, * ,layout=None): m, n, k, layout, mat1, mat2 = mm_args(mat1, mat2, layout=layout) @@ -90,8 +97,23 @@ def convolution( "groups": groups, } + x.realize() + weight.realize() + x = ir.ExternKernel.require_channels_last(x) + BATCH = x.layout.size[0] + I_C = x.layout.size[1] + weight = ir.ExternKernel.require_channels_last(weight) layout = conv_layout(x, weight, None, **kwargs) - mlir_template = MLIRConvTemplate([x, weight, bias], layout, **kwargs) + + # Select conv kernel + if BATCH == 1 and stride[0] == 1: + mlir_template = MLIRConvSingleBatchTemplate([x, weight, bias], layout, **kwargs) + elif BATCH == 1 and stride[0] != 1: + mlir_template = MLIRConvSingleBatchStridedTemplate([x, weight, bias], layout, **kwargs) + elif I_C < CONFIG_VECTOR_LANE // 8: # 8 is hard-coded for now. This should be changed to a better heuristic. + mlir_template = MLIRConvMultiTileTemplate([x, weight, bias], layout, **kwargs) + else: + mlir_template = MLIRConvTemplate([x, weight, bias], layout, **kwargs) return mlir_template.generate().output_node() def maxpool_layout( @@ -139,9 +161,24 @@ def custom_maxpool( } layout = maxpool_layout(x, kernel_size, stride, padding, dilation, ceil_mode) mlir_template = MLIRMaxPoolTemplate([x], layout, **kwargs) - return mlir_template.generate().output_node(), x # FIXME: x is dummy IRNode, indices are not used in our case + x.realize() + template_node = mlir_template.generate().output_node() + return template_node, x # FIXME: x is dummy IRNode, indices are not used in our case + +def sparse_addmm(*args, **kwargs): + _, sp_mat1, sp_mat2 = args + mat1_layout = sp_mat1.layout + out_range = args[0].data.data.data.ranges + size = [out_range[i] for i in args[0].data.dims] + layout = ir.FlexibleLayout( + device=mat1_layout.device, dtype=mat1_layout.dtype, size=size # FIXME: Example code for aten op overwrite by externkernel call + ) + return aten_spmm.bind((sp_mat1, sp_mat2), layout).output_node() lowerings.update({getattr(aten.mm, overload): tuned_mm for overload in aten.mm.overloads()}) lowerings.update({getattr(aten.addmm, overload): tuned_addmm for overload in aten.addmm.overloads()}) lowerings.update({getattr(aten.convolution, overload): convolution for overload in aten.convolution.overloads()}) -lowerings.update({getattr(aten.bmm, overload): tuned_bmm for overload in aten.bmm.overloads()}) \ No newline at end of file +lowerings.update({getattr(aten.bmm, overload): tuned_bmm for overload in aten.bmm.overloads()}) +lowerings.update({getattr(aten._sparse_addmm, overload): sparse_addmm for overload in aten._sparse_addmm.overloads()}) +if CONFIG_USE_TIMING_POOLING: + lowerings.update({getattr(aten.max_pool2d_with_indices, overload): custom_maxpool for overload in aten.max_pool2d_with_indices.overloads()}) # FIXME: maxpool should be implemented as a template \ No newline at end of file diff --git a/PyTorchSimFrontend/mlir/mlir_maxpool_template.py b/PyTorchSimFrontend/mlir/mlir_maxpool_template.py index 1f93f82a..6f605d56 100644 --- a/PyTorchSimFrontend/mlir/mlir_maxpool_template.py +++ b/PyTorchSimFrontend/mlir/mlir_maxpool_template.py @@ -8,27 +8,21 @@ from torch._inductor.ir import ReinterpretView from torch._inductor.codecache import write_atomic import PyTorchSimFrontend.extension_codecache as extension_codecache +from PyTorchSimFrontend.mlir import mlir_common +import sympy # This template only represents the DMA operations -TEMPLATE = r"""#map0 = affine_map<(d0, d1) -> (d0 * {{ W }} + d1)> -memref.global @X_spad : memref<{{ in_tile }}x{{ in_tile }}xf32, 1> -memref.global @Y_spad : memref<{{ out_tile }}x{{ out_tile }}xf32, 1> +TEMPLATE = r""" +{{kernel.def_global_vars()}} -func.func @{{ KERNEL_NAME }} {{kernel.def_kernel(inputs=[X], outputs=[Y], names_str="X, Y")}} { - %c_mvin = arith.constant 2 : index - %c_mvout = arith.constant 3 : index - %dummy = arith.constant 2 : index - %in_chunk = arith.constant {{ in_tile * 2}} : index - %out_chunk = arith.constant {{ out_tile * 2}} : index - %X_buffer = memref.get_global @X_spad : memref<{{ in_tile }}x{{ in_tile }}xf32, 1> - %Y_buffer = memref.get_global @Y_spad : memref<{{ out_tile }}x{{ out_tile }}xf32, 1> - %tag = memref.alloc() : memref<1xi32> - %c0 = arith.constant 0 : index - affine.for %i = 0 to {{ BCH }} step {{ out_tile }} { - affine.for %j = 0 to {{ W }} step {{ out_tile }} { - %index0 = affine.apply #map0(%i, %j) - affine.dma_start %X[%index0], %X_buffer[%c0, %c0], %tag[0], %c_mvin, %dummy, %in_chunk : memref<{{ IN }}xf32>, memref<{{ in_tile }}x{{ in_tile }}xf32, 1>, memref<1xi32> - affine.dma_start %Y_buffer[%c0, %c0], %Y[%index0], %tag[0], %c_mvout, %dummy, %out_chunk : memref<{{ out_tile }}x{{ out_tile }}xf32, 1>, memref<{{ OUT }}xf32>, memref<1xi32> +func.func @{{ KERNEL_NAME }} {{kernel.def_kernel(inputs=[X], outputs=[Y], names_str="X, Y", input_reorder=input_reorder)}} { + {{ kernel.def_sram_buffer("X", X_tile_desc, indent_size=2) }} + {{ kernel.def_sram_buffer("Y", Y_tile_desc, indent_size=2) }} + {{- kernel.def_local_vars(indent_size=2) }} + affine.for %index0 = 0 to {{ BCH }} step {{ out_tile }} { + affine.for %index1 = 0 to {{ W }} step {{ out_tile }} { + {{ kernel.def_dma_op("MVIN", "X", X_idx, X_tile_desc, indent_size=6) }} + {{ kernel.def_dma_op("MVOUT", "Y", Y_idx, Y_tile_desc, indent_size=6) }} } { outer_loop=true } } { outer_loop=true } return @@ -36,8 +30,8 @@ """ class MLIRMaxPoolTemplate(MLIRTemplate): - def __init__(self, input_nodes, layout, kernel_size, stride, padding, dilation, ceil_mode): - super().__init__("kernel", input_nodes, layout) + def __init__(self, input_nodes, layout, kernel_size, stride, padding, dilation, ceil_mode, input_reorder=None): + super().__init__("kernel", input_nodes, layout, input_reorder) self.kernel_size = kernel_size self.stride = stride self.padding = padding @@ -62,27 +56,48 @@ def render(self, H = Y.get_size()[2] W = Y.get_size()[3] BCH = B * C * H - kernel.tile_size = [1, 1, 1] # Dummy Tile kernel.loop_size = None - options = { - "KERNEL_NAME" : self.name, - "kernel" : kernel, - "IN" : X.get_numel(), - "OUT" : Y.get_numel(), - "X" : X, - "Y" : Y, - "BCH" : BCH, - "W" : W, - "in_tile" : in_tile, - "out_tile" : out_tile, - } - code = self._template_from_string(TEMPLATE).render(**options) - self.header = f"float X_spad[{in_tile * in_tile // kernel.vector_lane}] __attribute__ ((section(\".spad\")));\n" - self.header += f"float Y_spad[{out_tile * out_tile // kernel.vector_lane}] __attribute__ ((section(\".spad\")));\n" - self.gem5_header = f"float X_spad[{in_tile * in_tile // kernel.vector_lane}] __attribute__ ((section(\".spad\")));\n" - self.gem5_header += f"float Y_spad[{out_tile * out_tile // kernel.vector_lane}] __attribute__ ((section(\".spad\")));\n" - kernel.add_loop_info([options["IN"]], [kernel.vector_lane, kernel.vector_lane]) + # Prepare tile descriptors + vlane_stride = 1 # Used dummy value + vlane_split_axis = 1 + X_tile_size = [in_tile, in_tile] + X_tile_stride = [1, in_tile] + X_tile_desc = mlir_common.MLIRMultiDimTile(X_tile_size, kernel.vector_lane, vlane_split_axis, vlane_stride) + X_tile_desc.set_tile_size_stride(X_tile_size, X_tile_stride) + X_tile_desc.set_name("X_buffer") + X_idx = [sympy.Symbol("index0"), sympy.Symbol("index1")*W] # To keep index arguemnt order, we used index_list + + Y_tile_size = [out_tile, out_tile] + Y_tile_stride = [1, out_tile] + Y_tile_desc = mlir_common.MLIRMultiDimTile(X_tile_size, kernel.vector_lane, vlane_split_axis, vlane_stride) + Y_tile_desc.set_tile_size_stride(Y_tile_size, Y_tile_stride) + Y_tile_desc.set_name("W_buffer") + Y_idx = [sympy.Symbol("index0"), sympy.Symbol("index1")*W] + + kernel.render_options = dict( + KERNEL_NAME=self.name, + kernel=kernel, + X=X, + Y=Y, + BCH=BCH, + W=W, + out_tile=out_tile, + X_idx = X_idx, + Y_idx = Y_idx, + X_tile_desc = X_tile_desc, + Y_tile_desc = Y_tile_desc, + input_reorder = self.input_reorder + ) + kernel.epilogue_info = dict( + output_node = self.output_node.name, + sram_var = "Y_buffer", + dram_var = "Y", + dram_tile_desc = Y_tile_desc, + ) + kernel.exception_nodes["Y"] = {"numel" : Y.get_numel()} + code = self._template_from_string(TEMPLATE).render(**kernel.render_options) + kernel.add_loop_info([X.get_numel()], [kernel.vector_lane, kernel.vector_lane]) return code def codegen_header(self, code, extra_headers): @@ -92,6 +107,6 @@ def codegen_header(self, code, extra_headers): spike_write_path = os.path.join(write_path, "global_var.h") gem5_write_path = os.path.join(write_path, "gem5_global_var.h") if not os.path.exists(spike_write_path): - write_atomic(spike_write_path, self.header+extra_headers[0]) + write_atomic(spike_write_path, extra_headers[0]) if not os.path.exists(gem5_write_path): - write_atomic(gem5_write_path, self.gem5_header+extra_headers[1]) \ No newline at end of file + write_atomic(gem5_write_path, extra_headers[1]) \ No newline at end of file diff --git a/PyTorchSimFrontend/mlir/mlir_scheduling.py b/PyTorchSimFrontend/mlir/mlir_scheduling.py new file mode 100644 index 00000000..786971fe --- /dev/null +++ b/PyTorchSimFrontend/mlir/mlir_scheduling.py @@ -0,0 +1,388 @@ +import os +import math +import sympy +from functools import reduce +import operator +from sympy import symbols, sympify, Symbol +from PyTorchSimFrontend import extension_config +from PyTorchSimFrontend.mlir.mlir_codegen_backend import MLIRKernel + +from torch._inductor import config +from torch._inductor.scheduler import BaseScheduling, FusedSchedulerNode, SchedulerNode, BaseSchedulerNode +from torch._inductor.utils import IndentedBuffer +from torch._inductor.virtualized import V +from torch._inductor.ir import LoopBody +from torch._inductor import dependencies + +from . import mlir_common +from . import mlir_lowering # DO NOT REMOVE THIS LINE, it is used for lowering + +class MLIRScheduling(BaseScheduling): + count = 0 + target_kernel = MLIRKernel + def __init__(self, scheduler): + self.scheduler = scheduler + self.scheduler.can_fuse_origin = self.scheduler.can_fuse + self.scheduler.can_fuse = self.can_fuse_with_exceptions + #self.scheduler.enter_context = self.enter_context_fixed # FIXME. Monkey patch: For fixing the inductor bug + self.kernel_group = mlir_common.MLIRWrapperKenrelGroup() + self._ready_to_flush = False + self.outer_function = set() + config.inplace_buffers = False # FIXME. inout kernel makes trouble.. So disabled it! + self.max_fusion_size = 5 + + def can_fuse_with_exceptions(self, node1: BaseSchedulerNode, node2: BaseSchedulerNode) -> bool: + # Extract base template node + base_template_node1 = [node for node in node1.get_nodes() if node.is_template()] + base_template_node2 = [node for node in node2.get_nodes() if node.is_template()] + if node1.get_device() != node2.get_device(): + return False + if not (isinstance(node1, (SchedulerNode, FusedSchedulerNode)) and isinstance(node2, (SchedulerNode, FusedSchedulerNode))): + return False + + if len(base_template_node1) == 1 and len(base_template_node2) == 0 and extension_config.CONFIG_FUSION_REDUCTION_EPILOGUE: + from PyTorchSimFrontend.mlir.mlir_gemm_template import MLIRGemmTemplate + from PyTorchSimFrontend.mlir.mlir_bmm_template import MLIRBMMTemplate + if (isinstance(base_template_node1[0].node.template, MLIRGemmTemplate) or isinstance(base_template_node1[0].node.template, MLIRBMMTemplate)) and node2.is_reduction(): + # For matmul/bmm+reduction case + size_match = node1.get_nodes()[0].node.get_numel() == reduce(operator.mul, node2.get_nodes()[0].node.get_size(), 1) * reduce(operator.mul, node2.get_nodes()[0].node.get_reduction_size(), 1) + stride = [i.strip()[:-1].split(",")[-1].strip() for i in str(node2.get_nodes()[0].node).split("\n") if "r0" in i][1] + target_symbol = symbols("r0") + # We can't fuse dim=-1 + layout_possible = int(sympify(stride).coeff(target_symbol)) != 1 + # Directed linked? + dependency_check = node2.get_nodes()[0] in [node.node for node in base_template_node1[0].users]# and len(node2.read_writes.reads)==1 + dependency_size = all([i.get_numel() == node1.get_nodes()[0].node.get_numel() for i in node2.read_writes.reads]) + return size_match and layout_possible and dependency_check and dependency_size + + # For prologue fusion case + if extension_config.CONFIG_FUSION_PROLOGUE and len(base_template_node1) == 0 and len(node1.get_nodes())==1 and len(base_template_node2) == 1: + from PyTorchSimFrontend.mlir.mlir_gemm_template import MLIRGemmTemplate + from PyTorchSimFrontend.mlir.mlir_bmm_template import MLIRBMMTemplate + target_node = base_template_node2[0].node + if target_node.origin_node is not None and hasattr(target_node.origin_node.target, "_name") and target_node.origin_node.target._name == 'aten::convolution': + return False + if node1.is_reduction(): + return False + if len(node1.read_writes.writes) != 1: + return False + if node1.node not in target_node.inputs or any(["view" in str(ori) for ori in node1.node.origins]): #FIXME + return False + + # Currently only BMM, MM support prologue fusion + if not isinstance(target_node.template, (MLIRBMMTemplate, MLIRGemmTemplate)): + return False + # We don't fuse this edge case... + if base_template_node2[0].group[1][0][0] == 1: + return False + + if list(node1.read_writes.writes)[0].name in [dep.name for dep in node2.read_writes.reads]: + node1 = self.revert_group(node1) + return True + + return self.scheduler.can_fuse_origin(node1, node2) + + def _set_flush_status(self, status: bool): + self._ready_to_flush = status + + def can_fuse_vertical(self, node1, node2): + return self.can_fuse_horizontal(node1, node2) + + def can_fuse_horizontal(self, node1, node2): + if (len(node1.get_nodes())+ len(node2.get_nodes())) > self.max_fusion_size: + return False + _, (vars1, reduce1) = node1.group + _, (vars2, reduce2) = node2.group + + # Reduction is currently not supported + if node1.is_reduction() and node2.is_reduction() and not node1.is_template() and not node2.is_template() and extension_config.CONFIG_FUSION_REDUCTION_REDUCTION: + return vars1 == vars2 and reduce1 == reduce2 and node1.inverse_users == node2.inverse_users + if node1.is_reduction() or node2.is_reduction(): + return False + + # Can't fuse two template node + if node1.is_template() and node2.is_template(): + return False + + # Check template node fusion + if node1.is_template() or node2.is_template(): + # Don't fuse maxpool template code + from PyTorchSimFrontend.mlir.mlir_maxpool_template import MLIRMaxPoolTemplate + from PyTorchSimFrontend.mlir.mlir_bmm_template import MLIRBMMTemplate + from PyTorchSimFrontend.mlir.mlir_gemm_template import MLIRGemmTemplate + template_node1 = next((n for n in node1.get_nodes() if n.is_template()), None) + template_node2 = next((n for n in node2.get_nodes() if n.is_template()), None) + + if template_node1 and len(node1.get_nodes()) == 1 and isinstance(template_node1.node.template, MLIRMaxPoolTemplate) or \ + template_node2 and len(node2.get_nodes()) == 1 and isinstance(template_node2.node.template, MLIRMaxPoolTemplate): + return False + + # Pointwise check + v1_total = math.prod(vars1) if len(vars1) else 0 + v2_total = math.prod(vars2) if len(vars2) else 0 + if v1_total != v2_total: + return False + + # Pattern check + template_node, act_node = (template_node1, node2) if template_node1 else (template_node2, node1) + has_depedency = set(act_node.inverse_users) <= set(template_node.get_nodes()) + if not has_depedency: + return False + + # Revert act_node.group : simplify_and_reorder() modified _body, _size, group + if template_node.group != act_node.group: + # We don't fuse this case... + if (isinstance(template_node.node.template, MLIRBMMTemplate) or isinstance(template_node.node.template, MLIRGemmTemplate)) and template_node.group[1][0][0] == 1: + return False + + if list(template_node.group[1][0]) != list(act_node.get_nodes()[0].node.data.get_size()): + return False + self.revert_group(act_node) + return True + + # Check elementwise fusion + if vars1 == vars2 and reduce1 == reduce2: + return True + return False + + def revert_group(self, act_nodes, args=None, var_ranges=None): + for act_node in act_nodes.get_nodes(): + if args is None or var_ranges is None: + args, var_ranges = dependencies.index_vars_no_squeeze( + act_node.node.data.get_size(), act_node.node.data.get_reduction_size(), prefix="q" + ) + body = LoopBody( + act_node.node.get_store_function(), + (args if act_node.node.get_reduction_type() else args[:1]), + var_ranges, + ) + index_size = [] + reduce_size = [] + for v, s in var_ranges.items(): + if v in args[0]: + index_size.append(s) + else: + reduce_size.append(s) + node_device = act_node.get_device() + ranges = (index_size, reduce_size) + act_node._sizes, act_node._body, act_node.group = (ranges), body, (node_device, self.group_fn(ranges)) + + def group_fn(self, sizes): + return tuple(tuple(map(V.graph.sizevars.simplify, s)) for s in sizes) + + def codegen_nodes(self, nodes): + _, (group, reduction_group) = max( + nodes, key=lambda x: int(x.is_reduction()) + ).group + + # Note: We assume that ther is at least one loop in the nodes + # But, inductor simplifies the group, there could be no loop + # In that case, we add dummy loop(size=1) to the group + if len(group) == 0: + for idx, node in enumerate(nodes): + if len(node.node.data.get_size()) == 0: + continue + if len(reduction_group) != 0: + sym0, sym1 = sympy.Symbol("q0"), sympy.Symbol("q1") + args = [[sym0] + [sympy.Number(0)] * (len(node.node.data.get_size())-1), [sym1]] + var_ranges = {sym0: sympy.Number(1), sym1: reduction_group[0]} + else: + sym0 = sympy.Symbol("q0") + args = [[sym0] + [sympy.Number(0)] * (len(node.node.data.get_size())-1), []] + var_ranges = {sym0: sympy.Number(1)} + self.revert_group(node, args, var_ranges) + _, (group, reduction_group) = max( + nodes, key=lambda x: int(x.is_reduction()) + ).group + + ex_kernel = self.target_kernel(kernel_group=self.kernel_group) + ex_kernel.kernel_group = self.kernel_group + + kernel_name_candidate = f"extension_kernel_{MLIRScheduling.count}" + MLIRScheduling.count += 1 + src_code = ex_kernel.codegen_nodes(nodes, kernel_name_candidate) + kernel_name = self.define_kernel(src_code, kernel_name_candidate, ex_kernel.vector_lane, + ex_kernel.spad_info, origins= {str(i) for i in nodes[0].node.origins}) + ex_kernel.call_kernel(kernel_name) + _, args, _, _ = ex_kernel.args.mlir_argdefs() + args = ", ".join(args) + eager_mode = int(os.environ.get('BACKENDSIM_EAGER_MODE', default=False)) + if (eager_mode): + V.graph.wrapper_code.writeline( + f"yield ({kernel_name}, ({args}))" + ) + self._set_flush_status(True) + + def ready_to_flush(self): + return self._ready_to_flush + + def codegen_sync(self): + pass + + def flush(self): + self.kernel_group.codegen_define_and_call(V.graph.wrapper_code) + self.kernel_group = mlir_common.MLIRWrapperKenrelGroup() + self._set_flush_status(False) + + def define_function(self, kernel): + partial_code, function_name = kernel.def_function() + if partial_code is not None and function_name not in self.outer_function: + with V.set_kernel_handler(kernel): + code = partial_code.finalize() + wrapper = V.graph.wrapper_code + wrapper.header.writeline(code) + self.outer_function.add(function_name) + + def define_kernel(self, src_code, kernel_name, vector_lane, spad_info, loop_size=None, origins={}): + wrapper = V.graph.wrapper_code + if src_code in wrapper.src_to_kernel: + kernel_name = wrapper.src_to_kernel[src_code] + else: + wrapper.src_to_kernel[src_code] = kernel_name + + codecache_def = IndentedBuffer() + codecache_def.writeline(f"custom_async_compile.mlir('''{src_code}''', ") + codecache_def.writeline(f"vectorlane_size={vector_lane},") + codecache_def.writeline(f"loop_size={loop_size},") + codecache_def.writeline(f"spad_info={spad_info},") + codecache_def.writeline(f"origins={origins},") + codecache_def.writeline("arg_attributes=arg_attributes,") + codecache_def.writeline(f"vlen={extension_config.CONFIG_VLEN})") + wrapper.define_kernel(kernel_name, codecache_def.getvalue(), cuda=False) + return kernel_name + + def codegen_template_code(self, kernel, render, template_node, prologue_nodes, epilogue_nodes): + with kernel: + _, _, _, kernel.buffer_types = self.kernel_group.args.mlir_argdefs() + for node in [template_node, *prologue_nodes, *epilogue_nodes]: + node.mark_run() + # Partial codgen template nodes + partial_code = render() + + # Swap load/store functions + kernel.load = kernel.load_epilogue + kernel.store = kernel.store_epilogue + kernel.store_reduction = kernel.store_reduction_epilogue + kernel.reduction = kernel.reduction_epilogue + + # Codegen prologue nodes + if prologue_nodes: + # Flush created varaibles, since template fusion doen't share variable + with kernel.prologue_buffer_group.as_local(): + _, (group, reduction_group) = max( + [prologue_nodes[-1]], key=lambda x: int(x.is_reduction()) + ).group + prologue_tile_desc = kernel.set_tile_size(kernel.prologue_info, prologue=True) + kernel.kernel_group.set_tile_info(prologue_tile_desc) + vars, reduction_vars = kernel.set_ranges(group, reduction_group) + for node in prologue_nodes: + # Reuse created spad + read_list = sorted(list(node.read_writes.reads)) + candidate_found = False + # Why? There is a case that memdep.get_size() != data.get_size() + buf_dict = {} + buf_dict.update({val.name : val for val in V.graph.buffers}) + buf_dict.update(V.graph.graph_inputs) + for candidate_read in read_list: + if candidate_read.name in buf_dict and reduce(operator.mul, buf_dict[candidate_read.name].get_size(), 1) == node.node.get_numel(): + prologue_input_arg = candidate_read.name + candidate_found = True + break + assert(candidate_found) + assert(len(node.read_writes.writes)==1) + prologue_output_arg = list(node.read_writes.writes)[0].name + template_buf = self.kernel_group.args.input_buffers[prologue_output_arg] + target_buf = f"{template_buf}_buffer" # FIXME. How to pass spad buffer name? + + # To skip the dma code gen + kernel.buffer_names[prologue_input_arg] = target_buf + kernel.buffer_names[prologue_output_arg] = target_buf + + # Edge delete + kernel.kernel_group.args.input_buffers = { + (arg if buf != template_buf else prologue_input_arg): buf + for arg, buf in kernel.kernel_group.args.input_buffers.items() + } + node.codegen((vars, reduction_vars)) + + # Codegen epilogue nodes + tile_desc = kernel.set_tile_size(kernel.epilogue_info) + kernel.kernel_group.set_tile_info(tile_desc) + kernel.call_ranges = None + if epilogue_nodes: + with kernel.epilogue_buffer_group.as_local(): + _, (group, reduction_group) = max( + epilogue_nodes, key=lambda x: int(x.is_reduction()) + ).group + vars, reduction_vars = kernel.set_ranges(group, reduction_group) + for node in epilogue_nodes: + node.codegen((vars, reduction_vars)) + + with V.set_kernel_handler(kernel): + src_code = ( + partial_code + if isinstance(partial_code, str) + else partial_code.finalize() + ) + + # For consistency, white space could make wrong write_path + buffer = IndentedBuffer() + buffer.splice(src_code) + return buffer.getvalue() + + def codegen_template(self, template_node, epilogue_nodes): + # Handle prologue pattern + prologue_nodes = [] + if not template_node.is_template(): + epilogue_nodes = [template_node] + epilogue_nodes + for i, node in enumerate(epilogue_nodes): + if node.is_template(): + template_node = node + prologue_nodes = epilogue_nodes[:i] + epilogue_nodes = epilogue_nodes[i+1:] + break + + _, (numel, rnumel) = template_node.group + template_buffer = template_node.node + kernel, render, codegen_header = template_buffer.make_kernel_render(template_buffer, prologue_nodes=prologue_nodes, epilogue_nodes=epilogue_nodes, kernel_group=self.kernel_group) + _, _, _, kernel.buffer_types = self.kernel_group.args.mlir_argdefs() + + src_code = self.codegen_template_code(kernel, render, template_node, prologue_nodes, epilogue_nodes) + wrapper = V.graph.wrapper_code + + if src_code in wrapper.src_to_kernel: # [CONV] check inner function is already defined + kernel_name = wrapper.src_to_kernel[src_code] + kernel, render, codegen_header = template_buffer.make_kernel_render(template_buffer, prologue_nodes=prologue_nodes, epilogue_nodes=epilogue_nodes, kernel_name=kernel_name) # update kernel name + src_code = self.codegen_template_code(kernel, render, template_node, prologue_nodes, epilogue_nodes) + + with V.set_kernel_handler(kernel): + spad_end_symbol = f"int spad_end[0] __attribute__ ((section(\".spad\")));\n" + spad_section_end_symbol = f"int spad_section_end[0] __attribute__ ((section(\".spad\"), aligned({kernel.spad_info['spad_size']*kernel.vector_lane})));" + codegen_header(src_code, (kernel.header.getvalue()+spad_end_symbol+spad_section_end_symbol, kernel.gem5_header.getvalue())) + kernel.meta_kernel() + kernel_name = self.define_kernel(src_code, kernel.kernel_name, kernel.vector_lane, kernel.spad_info, + kernel.loop_size, origins={str(i) for i in template_node.node.origins}) + self.define_function(kernel) + + kernel.call_kernel(kernel_name) + V.graph.removed_buffers |= kernel.removed_buffers + _, args, _, _ = self.kernel_group.args.mlir_argdefs() + eager_mode = int(os.environ.get('BACKENDSIM_EAGER_MODE', default=False)) + if (eager_mode): + target_kernel_name = kernel_name if kernel.outer_func_name is None else kernel.outer_func_name + f"_{len(args)}" + args = ", ".join(args) + V.graph.wrapper_code.writeline( + f"yield ({target_kernel_name}, ({args}))" + ) + self._set_flush_status(True) + + def enter_context_fixed(self, node): + def get_order(n): + if n not in self.scheduler.origin_to_index: + self.scheduler.origin_to_index.update({n: i for i, n in enumerate(n.graph.nodes)}) + return self.scheduler.origin_to_index[n] + + origins = [(get_order(e), idx, e) for n in node.get_nodes() for idx, e in enumerate(n.node.origins)] + if origins: + _, _, last = max(origins) + V.graph.wrapper_code.enter_context(last) diff --git a/PyTorchSimFrontend/mlir/mlir_template.py b/PyTorchSimFrontend/mlir/mlir_template.py index ec1340e7..0b2a08f8 100644 --- a/PyTorchSimFrontend/mlir/mlir_template.py +++ b/PyTorchSimFrontend/mlir/mlir_template.py @@ -2,59 +2,129 @@ import itertools import textwrap import re +import os +import contextlib import math import sympy +from collections import OrderedDict from typing import List, Optional from unittest.mock import patch -from torch._inductor.codegen.common import KernelTemplate -from torch._inductor.codegen.common import ChoiceCaller -from torch._inductor.codegen.common import Kernel -from torch._inductor.codegen.common import OpOverrides -from torch._inductor.ir import Buffer -from torch._inductor.ir import IRNode -from torch._inductor.ir import TemplateBuffer +from torch._inductor.codegen.common import Kernel, KernelTemplate, ChoiceCaller, OpOverrides, CSE, DeferredLine +from torch._inductor.ir import Buffer, IRNode, TemplateBuffer, View from torch._inductor.select_algorithm import PartialRender from torch._inductor.codegen.cuda.cuda_kernel import CUDATemplateCaller from torch._inductor.autotune_process import TensorMeta -from torch._inductor.virtualized import V +from torch._inductor.virtualized import V, NullHandler, _ops as ops +from torch._inductor.utils import IndentedBuffer from PyTorchSimFrontend.mlir.mlir_autotune import MLIRBenchmarkRequest from PyTorchSimFrontend.mlir.mlir_common import BaseMLIRHardwareInfo -from PyTorchSimFrontend.mlir.mlir_codegen_backend import MLIRKernel, MLIRTile +from PyTorchSimFrontend.mlir.mlir_codegen_backend import MLIRKernel, reduction_init, reduction_partial_combine_vec, reduction_combine_vec, is_welford_reduction +from PyTorchSimFrontend.mlir.mlir_scheduling import SchedulerNode +from torch._inductor.codegen import common +from PyTorchSimFrontend.extension_config import CONFIG_TORCHSIM_DIR from . import mlir_common +class IndentedBufferGroup: + def __init__(self, kernel: 'MLIRTemplateKernel', prefix=""): + self.kernel = kernel + self.body = IndentedBuffer() + self.loads = IndentedBuffer() + self.compute = IndentedBuffer() + self.stores = IndentedBuffer() + self.applys = IndentedBuffer() + self.dma_loads = IndentedBuffer() + self.dma_stores = IndentedBuffer() + self.spad_buffer = IndentedBuffer() + self.cse = common.CSE("%", "", name_prefix=f"{prefix}") + self.apply_cse = common.CSE("%", "", name_prefix=f"{prefix}apply") + # Original buffers will be saved later in the 'with' block + self.original_buffers = {} + + def set_buffers(self): + self.kernel.loads = self.loads + self.kernel.compute = self.compute + self.kernel.stores = self.stores + self.kernel.applys = self.applys + self.kernel.dma_loads = self.dma_loads + self.kernel.dma_stores = self.dma_stores + self.kernel.spad_buffer = self.spad_buffer + self.kernel.cse = self.cse + self.kernel.apply_cse = self.apply_cse + + def restore_buffers(self): + self.kernel.loads = self.original_buffers['loads'] + self.kernel.compute = self.original_buffers['compute'] + self.kernel.stores = self.original_buffers['stores'] + self.kernel.applys = self.original_buffers['applys'] + self.kernel.dma_loads = self.original_buffers['dma_loads'] + self.kernel.dma_stores = self.original_buffers['dma_stores'] + self.kernel.spad_buffer = self.original_buffers['spad_buffer'] + self.kernel.cse = self.original_buffers['cse'] + self.kernel.apply_cse = self.original_buffers['apply_cse'] + + @contextlib.contextmanager + def as_local(self): + self.original_buffers = { + 'loads': self.kernel.loads, + 'compute': self.kernel.compute, + 'stores': self.kernel.stores, + 'applys': self.kernel.applys, + 'dma_loads': self.kernel.dma_loads, + 'dma_stores': self.kernel.dma_stores, + 'spad_buffer': self.kernel.spad_buffer, + 'cse': self.kernel.cse, + 'apply_cse': self.kernel.apply_cse, + } + try: + self.set_buffers() + yield self + finally: + self.restore_buffers() + class MLIRTemplateKernel(MLIRKernel, BaseMLIRHardwareInfo): def __init__(self, kernel_name, input_nodes, call_size, + kernel_group = None, outer_func_name=None, outer_func_render=None, kernel_arg_attributes=None) -> None: - super().__init__() + super().__init__(kernel_group if kernel_group is not None else mlir_common.MLIRWrapperKenrelGroup()) self.kernel_name = kernel_name self.input_nodes = input_nodes self.call_size = call_size self.named_nodes = {} self.loop_info = {} - self.load_desc = {} - self.store_desc = {} self.outer_func_name = outer_func_name self.outer_func_render = outer_func_render self.kernel_arg_attributes = kernel_arg_attributes - self.render_hooks = dict() + self.render_hooks = OrderedDict() self.buffer_names = dict() self.render_options = dict() self.tile_size = [] self.loop_size = None - self.is_template_kernel = True - - # Overwrite ops - self.load = self.load_epilogue - self.store = self.store_epilogue + self.map_cse = CSE("#", self.suffix, name_prefix="t_map") + self.const_cse = CSE(self.newvar_prefix, self.suffix, name_prefix="t_const") + self.alloc_cse = CSE(self.newvar_prefix, self.suffix, name_prefix="t_alloc") + self.prologue_buffer_group = IndentedBufferGroup(self, prefix="prologue_") + self.epilogue_buffer_group = IndentedBufferGroup(self, prefix="epilogue_") + self.global_vars = IndentedBuffer() + self.exception_nodes = {} + # Reduction data structure + self.reduction_epilogue_suffix = IndentedBuffer() + self.reduction_fusion = False + self.reduction_body_loop = None + self.reduction_buffer_idx = 0 + self.reduction_info = {} + self.reduction_epilogue_result = {} + self.reduction_mean = [] + # Dim info + self.dim_aliasing = {} def add_loop_info(self, mat_size, tile_size): for idx, (loop_size, stride) in enumerate(zip(mat_size, tile_size)): @@ -115,53 +185,288 @@ def gemmini_gemm_mapping(self, M, N, K): return inner_I, inner_J, inner_K - def gemm_combination_mapping(self, M, N, K): - spad_size = self.spad_info["spad_size"] * self.vector_lane + def gemm_combination_mapping(self, M, N, K, n_extra_node=0, n_prologue_node=0, pad_k=True, min_tile=False): + spad_size_per_lane = self.spad_info["spad_size"] + spad_size = spad_size_per_lane * self.vector_lane max_spad_size = spad_size // 2 # double buffer - M_padded = ((M + self.vector_lane - 1) // self.vector_lane) * self.vector_lane - N_padded = ((N + self.vector_lane - 1) // self.vector_lane) * self.vector_lane - K_padded = ((K + self.vector_lane - 1) // self.vector_lane) * self.vector_lane + max_spad_per_lane = spad_size_per_lane // 2 # double buffer + minimum_n_tile = self.num_cores if min_tile else 1 + m_pad_factor = self.vector_lane if M > self.vector_lane else 8 + n_pad_factor = self.vector_lane if N > self.vector_lane else 8 + k_pad_factor = self.vector_lane if K > self.vector_lane else (8 if pad_k else 1) + K = max(K, 8) + M_padded = ((M + m_pad_factor - 1) // m_pad_factor) * m_pad_factor + N_padded = ((N + n_pad_factor - 1) // n_pad_factor) * n_pad_factor + K_padded = ((K + k_pad_factor - 1) // k_pad_factor) * k_pad_factor + indexI, indexJ, indexK = (M_padded // self.vector_lane, N_padded // self.vector_lane, K_padded // self.vector_lane) max_used_spad_size = 0 mapping = (self.vector_lane, self.vector_lane, self.vector_lane) - for tile_M in range(self.vector_lane, M_padded + 1, self.vector_lane): - for tile_N in range(self.vector_lane, N_padded + 1, self.vector_lane): - for tile_K in range(self.vector_lane, K_padded + 1, self.vector_lane): - used_spad_size = (tile_M * tile_K + tile_K * tile_N + tile_M * tile_N) * self.precision - if used_spad_size < max_spad_size and max_used_spad_size < used_spad_size: + tile_M_range = sympy.divisors(indexI) if M > self.vector_lane else [1] + tile_N_range = sympy.divisors(indexJ) if N > self.vector_lane else [1] + tile_K_range = sympy.divisors(indexK) if K > self.vector_lane else [1] + maximize_i_j = 1 # reuse weight + for k in tile_K_range: # store tile candidates for manual mapping + tile_K = k * self.vector_lane if K > self.vector_lane else K_padded + for i in tile_M_range: + tile_M = i * self.vector_lane if M > self.vector_lane else M_padded + for j in tile_N_range: + tile_N = j * self.vector_lane if N > self.vector_lane else N_padded + used_spad_size = (tile_M * tile_K * (1 + n_prologue_node) + tile_K * tile_N + tile_M * tile_N * (1 + n_extra_node)) * self.precision + weight_size_per_lane = self.get_spad_size_per_lane(tile_K, tile_N) + input_size_per_lane = self.get_spad_size_per_lane(tile_M * (1 + n_prologue_node), tile_K) + output_size_per_lane = self.get_spad_size_per_lane(tile_M * (1 + n_extra_node), tile_N) + used_spad_size_per_lane = (weight_size_per_lane + input_size_per_lane + output_size_per_lane) * self.precision + check_spad_size = (used_spad_size < max_spad_size and used_spad_size_per_lane < max_spad_per_lane) + if check_spad_size: + dir_path = f"{CONFIG_TORCHSIM_DIR}/validation/gemm_candidates" + os.makedirs(dir_path, exist_ok=True) + file_path = f"{dir_path}/gemm_{M}_{K}_{N}.txt" + line_to_write = f"{tile_M} {tile_K} {tile_N}\n" + try: + with open(file_path, "r") as f: + lines = f.readlines() + except FileNotFoundError: + lines = [] + if line_to_write not in lines: + with open(file_path, "a") as f: + f.write(line_to_write) + + for k in tile_K_range: # heuristic search + tile_K = k * self.vector_lane if K > self.vector_lane else K_padded + for i in tile_M_range: + tile_M = i * self.vector_lane if M > self.vector_lane else M_padded + for j in tile_N_range: + tile_N = j * self.vector_lane if N > self.vector_lane else N_padded + used_spad_size = (tile_M * tile_K * (1 + n_prologue_node) + tile_K * tile_N + tile_M * tile_N * (1 + n_extra_node)) * self.precision + weight_size_per_lane = self.get_spad_size_per_lane(tile_K, tile_N) + input_size_per_lane = self.get_spad_size_per_lane(tile_M * (1 + n_prologue_node), tile_K) + output_size_per_lane = self.get_spad_size_per_lane(tile_M * (1 + n_extra_node), tile_N) + used_spad_size_per_lane = (weight_size_per_lane + input_size_per_lane + output_size_per_lane) * self.precision + n_tile = math.ceil(M / max(tile_M, 128)) * math.ceil(N / max(tile_N, 128)) + check_spad_size = (used_spad_size < max_spad_size and used_spad_size_per_lane < max_spad_per_lane) + if check_spad_size and max_used_spad_size < used_spad_size and maximize_i_j <= tile_M * tile_N and n_tile >= minimum_n_tile and max(tile_N, 128) // max(tile_M, 128) < 10: max_used_spad_size = used_spad_size + maximize_i_j = tile_M * tile_N mapping = (tile_M, tile_N, tile_K) + return mapping - Outer_M = math.ceil(M_padded / mapping[0]) - Outer_N = math.ceil(N_padded / mapping[1]) - Outer_K = math.ceil(K_padded / mapping[2]) + def search_mapping_space(self, mapping, idx, increment, stride, dilation, n_extra_node=0): + if idx == 0 or idx == 1 or idx == 4 or idx == 5 or idx == 6: + raise NotImplementedError("Only O_H and O_W are supported for search_mapping_space") + spad_size_per_lane = self.spad_info["spad_size"] + spad_size = spad_size_per_lane * self.vector_lane + max_spad_size = spad_size // 2 # double buffer + max_spad_per_lane = spad_size_per_lane // 2 # double buffer + + mapping = list(mapping) + mapping[idx] += increment + k_h, k_w, o_h, o_w, M, N, K = mapping + i_h = 1 + (o_h - 1) * stride[0] + (k_h - 1) * dilation[0] + i_w = 1 + (o_w - 1) * stride[1] + (k_w - 1) * dilation[1] + weight_size = k_w * k_h * K * N + input_size = i_w * i_h * M * K + output_size = o_w * o_h * M * N + used_spad_size = (weight_size + input_size + output_size * (1 + n_extra_node)) * self.precision + weight_size_per_lane = self.get_spad_size_per_lane(k_w * k_h * K, N) + input_size_per_lane = self.get_spad_size_per_lane(i_w * i_h * M, K) + output_size_per_lane = self.get_spad_size_per_lane(o_w * o_h * M * (1 + n_extra_node), N) + used_spad_size_per_lane = (weight_size_per_lane + input_size_per_lane + output_size_per_lane) * self.precision + if used_spad_size < max_spad_size and used_spad_size_per_lane < max_spad_per_lane: + mapping = (k_h, k_w, o_h, o_w, M, N, K) + else: + mapping[idx] -= increment + + return mapping - # split mapping equally to avoid unnecessary padding - mapping = (M_padded // Outer_M, N_padded // Outer_N, K_padded // Outer_K) + def pseudo_auto_tune(self, mapping, stride, dilation, O_H, O_W, n_extra_node=0): + # pseudo auto-tune + if mapping[2] == 1 and not (O_H == 1): + mapping = self.search_mapping_space(mapping, 2, 1, stride, dilation, n_extra_node=n_extra_node) + if mapping[3] == 1 and not (O_W == 1): + mapping = self.search_mapping_space(mapping, 3, 1, stride, dilation, n_extra_node=n_extra_node) + return mapping + + def conv_combination_mapping(self, M, N, K, K_H, K_W, O_H, O_W, stride, dilation, n_extra_node=0): + spad_size_per_lane = self.spad_info["spad_size"] + spad_size = spad_size_per_lane * self.vector_lane + max_spad_size = spad_size // 2 # double buffer + max_spad_per_lane = spad_size_per_lane // 2 # double buffer + + max_used_spad_size = 0 + M, N, K = self.gemm_combination_mapping(M, N, K, n_extra_node=n_extra_node, pad_k=False) + max_k_h_w = 1 # maximize kernel size + max_o_h_w = 1 # maximize output size + K = min(K, self.vector_lane) + for o_h in sympy.divisors(O_H): + for o_w in sympy.divisors(O_W): + for k_h in sympy.divisors(K_H): + for k_w in sympy.divisors(K_W): + i_h = 1 + (o_h - 1) * stride[0] + (k_h - 1) * dilation[0] + i_w = 1 + (o_w - 1) * stride[1] + (k_w - 1) * dilation[1] + weight_size = k_w * k_h * K * N + input_size = i_w * i_h * M * K + output_size = o_w * o_h * M * N + used_spad_size = (weight_size + input_size + output_size * (1 + n_extra_node)) * self.precision + weight_size_per_lane = self.get_spad_size_per_lane(k_w * k_h * K, N) + input_size_per_lane = self.get_spad_size_per_lane(i_w * i_h * M, K) + output_size_per_lane = self.get_spad_size_per_lane(o_w * o_h * M * (1 + n_extra_node), N) + used_spad_size_per_lane = (weight_size_per_lane + input_size_per_lane + output_size_per_lane) * self.precision + if used_spad_size < max_spad_size and max_used_spad_size < used_spad_size and used_spad_size_per_lane < max_spad_per_lane and max_k_h_w <= k_h * k_w and max_o_h_w <= o_h * o_w: + max_used_spad_size = used_spad_size + max_k_h_w = k_h * k_w + max_o_h_w = o_h * o_w + mapping = (k_h, k_w, o_h, o_w, M, N, K) + if max_used_spad_size == 0: + raise RuntimeError("Cannot find a valid mapping") + + # FIXME: this should be implemented with auto-tuning + mapping = self.pseudo_auto_tune(mapping, stride, dilation, O_H, O_W, n_extra_node=n_extra_node) + + return mapping + + def conv_multi_tile_mapping(self, M, N, K, K_H, K_W, O_H, O_W, stride, dilation, n_extra_node=0): + spad_size_per_lane = self.spad_info["spad_size"] + spad_size = spad_size_per_lane * self.vector_lane + max_spad_size = spad_size // 2 + max_spad_per_lane = spad_size_per_lane // 2 + + max_used_spad_size = 0 + M, N, K = self.gemm_combination_mapping(M, N, K * K_W, n_extra_node=n_extra_node, pad_k=False) + max_k_h_w = K_W + for o_h in sympy.divisors(O_H): + for o_w in sympy.divisors(O_W): + for k_h in sympy.divisors(K_H): + i_h = 1 + (o_h - 1) * stride[0] + (k_h - 1) * dilation[0] + i_w = 1 + (o_w - 1) * stride[1] + (K_W - 1) * dilation[1] + weight_size = 1 * k_h * K * N + input_size = i_w * i_h * M * K + output_size = o_w * o_h * M * N + used_spad_size = (weight_size + input_size + output_size * (1 + n_extra_node)) * self.precision + weight_size_per_lane = self.get_spad_size_per_lane(1 * k_h * K, N) + input_size_per_lane = self.get_spad_size_per_lane(i_w * i_h * M, K) + output_size_per_lane = self.get_spad_size_per_lane(o_w * o_h * M * (1 + n_extra_node), N) + used_spad_size_per_lane = (weight_size_per_lane + input_size_per_lane + output_size_per_lane) * self.precision + if used_spad_size < max_spad_size and max_used_spad_size < used_spad_size and used_spad_size_per_lane < max_spad_per_lane and max_k_h_w <= k_h: + max_used_spad_size = used_spad_size + max_k_h_w = k_h + mapping = (k_h, K_W, o_h, o_w, M, N, K) + if max_used_spad_size == 0: + raise RuntimeError("Cannot find a valid mapping") + return mapping + + def conv_single_batch_mapping(self, M, N, K, K_H, K_W, O_H, O_W, stride, dilation, n_extra_node=0): + spad_size_per_lane = self.spad_info["spad_size"] + spad_size = spad_size_per_lane * self.vector_lane + max_spad_size = spad_size // 2 + max_spad_per_lane = spad_size_per_lane // 2 + + max_used_spad_size = 0 + M, N, K = self.gemm_combination_mapping(O_W, N, K, n_extra_node=n_extra_node, pad_k=False) + max_k_h_w = 1 + for o_h in sympy.divisors(O_H): + for k_h in sympy.divisors(K_H): + for k_w in sympy.divisors(K_W): + i_h = 1 + (o_h - 1) * stride[0] + (k_h - 1) * dilation[0] + i_w = 1 + (M - 1) * stride[1] + (k_w - 1) * dilation[1] + weight_size = k_w * k_h * K * N + input_size = i_w * i_h * k_w * K + output_size = M * o_h * N + used_spad_size = (weight_size + input_size + output_size * (1 + n_extra_node)) * self.precision + weight_size_per_lane = self.get_spad_size_per_lane(k_w * k_h * K, N) + input_size_per_lane = self.get_spad_size_per_lane(i_w * i_h * k_w, K) + output_size_per_lane = self.get_spad_size_per_lane(M * o_h * (1 + n_extra_node), N) + used_spad_size_per_lane = (weight_size_per_lane + input_size_per_lane + output_size_per_lane) * self.precision + if used_spad_size < max_spad_size and max_used_spad_size < used_spad_size and used_spad_size_per_lane < max_spad_per_lane and max_k_h_w <= k_h * k_w: + max_used_spad_size = used_spad_size + max_k_h_w = k_h * k_w + mapping = (k_h, k_w, o_h, M, M, N, K) + if max_used_spad_size == 0: + raise RuntimeError("Cannot find a valid mapping") return mapping def meta_kernel(self): wrapper = V.graph.wrapper_code - arg_attributes = self.kernel_arg_attributes - if arg_attributes is None: - _, _, arg_attributes, _ = self.args.mlir_argdefs() + kernel_arg_attributes = self.kernel_arg_attributes + _, _, arg_attributes, _ = self.kernel_group.args.mlir_argdefs() + if kernel_arg_attributes is not None: + for name, attr in kernel_arg_attributes: + for idx in range(len(arg_attributes)): + if arg_attributes[idx][0] == name: + arg_attributes[idx][1] = attr wrapper.add_import_once('\nprint(f\'Wrapper Codegen Path = {__file__}\')') - wrapper.add_import_once(f'\nfrom PyTorchSimFrontend.extension_codecache import CustomAsyncCompile') - wrapper.add_import_once(f'\ncustom_async_compile = CustomAsyncCompile()') # Dump loop and load/store information wrapper.add_import_once(f"loop_info = {self.loop_info}") - wrapper.add_import_once(f"load_tile_info = {self.load_desc}") - wrapper.add_import_once(f"store_tile_info = {self.store_desc}") wrapper.add_import_once(f"arg_attributes = {arg_attributes}") def call_kernel(self, kernel_name): wrapper = V.graph.wrapper_code - _, call_args, _, _ = self.args.mlir_argdefs() + _, call_args, _, _ = self.kernel_group.args.mlir_argdefs() # generate the code to call this wrapper.generate_kernel_call( - kernel_name if self.outer_func_name is None else self.outer_func_name, + kernel_name if self.outer_func_name is None else self.outer_func_name + f"_{len(call_args)}", call_args, cuda=False) + def codegen_prologue_body(self): + body = IndentedBuffer() + with self.prologue_buffer_group.as_local(): + body.splice(self.spad_buffer) + body.splice(self.applys) + body.splice(self.dma_loads) + + if (self.loads.getvalue() != '' or self.compute.getvalue() != '' or self.stores.getvalue() != ''): + body.writelines(self.prologue_compute_body_loop.lines()) + compute_body = mlir_common.ParallelLoopBuffer() + with contextlib.ExitStack() as stack: + stack.enter_context(compute_body.indent(attribute="{inner_loop=false}")) + compute_body.splice(self.loads) + compute_body.splice(self.compute) + compute_body.splice(self.stores) + body.splice(compute_body) + body.splice(self.dma_stores) + return body + + def codegen_epilogue_body(self): + def template_store(): + dram_var = self.epilogue_info["dram_var"] + index_list = self.epilogue_info["dram_idx"] + tile_desc = self.epilogue_info["dram_tile_desc"] + code = self.def_dma_op("MVOUT", dram_var, index_list, tile_desc) + self.cse.generate(self.dma_stores, code, assignment = False) + + body = IndentedBuffer() + with self.epilogue_buffer_group.as_local(): + # Do dma store first to overlap epilogue nodes + if self.reduction_fusion: + if len(self.stores._lines) == 0: + template_store() + body.splice(self.dma_stores) + self.dma_stores.clear() + body.splice(self.spad_buffer) + body.splice(self.applys) + body.splice(self.dma_loads) + body.writelines(self.compute_body_loop.lines()) + compute_body = mlir_common.ParallelLoopBuffer() + with contextlib.ExitStack() as stack: + stack.enter_context(compute_body.indent(attribute="{inner_loop=false}",suffix=self.compute_body_loop.epilogue_line())) + if self.reduction_fusion: + compute_body.writelines(self.reduction_body_loop.lines()) + compute_body.splice(self.masks) + stack.enter_context(compute_body.indent(attribute="{inner_loop=false}")) + compute_body.splice(self.loads) + compute_body.splice(self.compute) + else: + compute_body.splice(self.loads) + compute_body.splice(self.compute) + if len(self.stores._lines) == 0: + template_store() + compute_body.splice(self.stores) + if (compute_body.getvalue()): + body.splice(compute_body) + body.splice(self.dma_stores) + body.splice(self.reduction_epilogue_suffix) + return body + def def_kernel( self, inputs: List[IRNode], @@ -185,134 +490,547 @@ def def_kernel( node = inputs[idx] if node is not None: self.named_nodes[name] = node - self.args.input_buffers[node.get_name()] = name + self.kernel_group.args.input_buffers[node.get_name()] = name extra_node = {} for name, node in zip(names[len(inputs) : len(inputs) + len(outputs)], outputs): if node is not None: self.named_nodes[name] = node - self.args.output_buffers[node.get_name()] = name + self.kernel_group.args.output_buffers[node.get_name()] = name self.store_buffer_names.add(node.get_name()) #TODO: Is this enough not calling store() in mlir_common.py? - extra_node[node.get_name()] = node - self.buffer_names[node.name] = 'Y_buffer' #TODO: Buffer name fixed + if isinstance(node, SchedulerNode): + extra_node[node.get_name()] = node.node + else: + extra_node[node.get_name()] = node + self.buffer_names[node.get_name()] = self.epilogue_info['sram_var'] def hook(): - arg_defs, *_ = self.args.mlir_argdefs(extra_node=extra_node) + arg_defs, *_ = self.kernel_group.args.mlir_argdefs(extra_node=extra_node) return f"({', '.join(arg_defs)})" assert "" not in self.render_hooks self.render_hooks[""] = hook return "" - def output_name(self): - # Cannot know the output name from the template, so we need to hook it + # This function is a temporal function for convolution because currently convolution kernel is not considering padding. + # Padding is done by python wrapper so the padded input size is manually applied here. + def def_conv_kernel( + self, + inputs: List[IRNode], + outputs: List[IRNode], + names_str: str = "", + padded_input_size: List[int] = [], + input_reorder: Optional[List[int]] = None, + ) -> str: + names = [x.strip() for x in names_str.strip().split(",")] + if len(inputs) + len(outputs) != len(names): + raise RuntimeError( + f"{len(inputs) + len(outputs)=} != {len(names)=}, {inputs=}, {outputs=}, {names=}" + ) + + if input_reorder is not None: + assert len(inputs) == len(input_reorder) + else: + input_reorder = list(range(len(inputs))) + + for idx in input_reorder: + name = names[idx] + node = inputs[idx] + if node is not None: + self.named_nodes[name] = node + self.kernel_group.args.input_buffers[node.get_name()] = name + + self.extra_node = {} + for name, node in zip(names[len(inputs) : len(inputs) + len(outputs)], outputs): + if node is not None: + self.named_nodes[name] = node + self.kernel_group.args.output_buffers[node.get_name()] = name + self.store_buffer_names.add(node.get_name()) #TODO: Is this enough not calling store() in mlir_common.py? + self.extra_node[node.get_name()] = node + self.buffer_names[node.get_name()] = self.epilogue_info['sram_var'] #TODO: Buffer name fixed + + def kernel_hook(): + arg_defs, *_ = self.kernel_group.args.mlir_argdefs(extra_node=self.extra_node) + arg_defs[0] = re.sub(r'(\d+)(?=xf32)', str(padded_input_size), arg_defs[0]) + return f"({', '.join(arg_defs)})" + + assert "" not in self.render_hooks + self.render_hooks[""] = kernel_hook + return "" + + # This function is for convolution wrapper function finalizing. + def def_wrapper(self, only_store_buffer: bool = False, epilogue_buffer: str = False): + def wrapper_hook(): + arg_defs, *_ = self.kernel_group.args.mlir_argdefs(extra_node=self.extra_node) + wrapper_arg_defs = [arg.split('%')[1].split(':')[0] for arg in arg_defs] + return f"({', '.join(wrapper_arg_defs)})" + + if "" not in self.render_hooks: + self.render_hooks[""] = wrapper_hook + return "" + + def get_conv_inputs(self): + return self.kernel_group.args.input_buffers + + def get_conv_outputs(self): + return {k: v for k, v in self.kernel_group.args.output_buffers.items() if v != 'REMOVED'} + + def load_input(self, indent_size: int = 0): def hook(): - arg_defs, *_ = self.args.mlir_argdefs() - output = arg_defs[3] #FIXME: Constant index used - pattern = r"%(\w+):" - output = re.search(pattern, output).group(1) - return output - assert "" not in self.render_hooks - self.render_hooks[""] = hook - return "" - - def store_output(self): + code = IndentedBuffer() + prologue_code = self.codegen_prologue_body() + if prologue_code.getvalue(): + input_dma_code = self.def_dma_op("MVIN", self.prologue_info["input_dram_var"], self.prologue_info["input_idx"], + self.prologue_info["input_tile_desc"], subtile_size=self.prologue_info["input_subtile_size"], async_type=False) + weight_dma_code = self.def_dma_op("MVIN", self.prologue_info["weight_dram_var"], self.prologue_info["weight_idx"], + self.prologue_info["weight_tile_desc"], subtile_size=self.prologue_info["weight_subtile_size"], async_type=False) + if (self.prologue_info["is_input_fused"]): + code.splice(input_dma_code) + code.splice(prologue_code) + code.splice(weight_dma_code) + else: + code.splice(weight_dma_code) + code.splice(prologue_code) + code.splice(input_dma_code) + else: + dma_code = self.def_dma_op("MVIN", self.prologue_info["input_dram_var"], self.prologue_info["input_idx"], + self.prologue_info["input_tile_desc"], subtile_size=self.prologue_info["input_subtile_size"], async_type=False) + code.splice(dma_code) + dma_code = self.def_dma_op("MVIN", self.prologue_info["weight_dram_var"], self.prologue_info["weight_idx"], + self.prologue_info["weight_tile_desc"], subtile_size=self.prologue_info["weight_subtile_size"], async_type=False) + code.splice(dma_code) + code = textwrap.indent(code.getvalue(), " "*indent_size).strip() + return code + + assert "" not in self.render_hooks + self.render_hooks[""] = hook + self.render_hooks.move_to_end("", last=False) # Force order to be triggered first + return "" + + def store_output(self, indent_size: int = 0): def hook(): - self.codegen_body() - return textwrap.indent(self.body.getvalue(), " ").strip() #TODO: First line is not indented + epilogue_code = self.codegen_epilogue_body() + return textwrap.indent(epilogue_code.getvalue(), " "*indent_size).strip() assert "" not in self.render_hooks self.render_hooks[""] = hook + self.render_hooks.move_to_end("", last=False) # Force order to be triggered first return "" + def reduction_output(self, indent_size: int = 0): + def hook(): + return textwrap.indent(self.reductions_suffix.getvalue(), " "*indent_size).strip() + + assert "" not in self.render_hooks + self.render_hooks[""] = hook + return "" + def def_function(self): - _, call_args, _ = self.args.python_argdefs() + _, call_args, _ = self.kernel_group.args.python_argdefs() if self.outer_func_render is not None: - return self.outer_func_render(input_args=call_args) + partial_code, function_name = self.outer_func_render(input_args=call_args) + return PartialRender( + partial_code, + self.render_hooks, + ), function_name else: return None, None def def_global_vars(self): - return "" + key = "" + def hook(): + return textwrap.indent(self.global_vars.getvalue(), "").strip() - def replace_global_vars(self): - return textwrap.indent(self.global_vars.getvalue(), "").strip() + assert key not in self.render_hooks + self.render_hooks[key] = hook + return key - def add_extra_global_vars(self, code): - key = "" - return code.replace(key, self.replace_global_vars()) + def def_local_vars(self, indent_size=0): + key = "" + def hook(): + code = IndentedBuffer() + code.tabwidth = 1 + code.splice(self.const_buffer) + code.splice(self.alloc_buffer) + return textwrap.indent(code.getvalue(), " "*indent_size).strip() + + assert key not in self.render_hooks + self.render_hooks[key] = hook + return key + + def def_dma_op(self, dma_type, dram_var:str, index_list:list, tile_desc:mlir_common.MLIRMultiDimTile, + subtile_size:list=[], async_type=None, indent_size=0): + # Prepare code block + local_code = IndentedBuffer() + with V.set_kernel_handler(self): + index_var = self.parse_index_list(index_list, local_code) + node_layout = self.named_nodes[dram_var].get_layout() + if dram_var in self.exception_nodes: + numel = self.exception_nodes[dram_var]["numel"] + else: + numel = self.get_arg_info(self.named_nodes[dram_var].get_name()).get_numel() + mlir_dtype = mlir_common.DTYPE_TO_MLIR[node_layout.dtype] + dram_shape = f"memref<{numel}x{mlir_dtype}>" + dram_stride = [] + for idx in index_list: + if idx.is_Mul: + dram_stride.append(int(idx.args[0])) + elif idx == sympy.Symbol("c0"): + dram_stride.append(0) + elif not idx.is_Number: + dram_stride.append(1) + else: + dram_stride.append(0) + + sram_var = tile_desc.get_name() + tile_shape = tile_desc.get_mlir_shape(mlir_dtype) + tile_stride = tile_desc.get_tile_stride() + vlane_split_axis = tile_desc.vlane_split_axis + vlane_stride = tile_desc.vlane_stride + + zero_cse = self.get_const_cse(0, "index") + sram_index_var = ", ".join([f"%{str(zero_cse)}"]*tile_desc.get_nr_dim()) + + attribute_parts = [f"dram_stride={dram_stride}", f"sram_stride={tile_stride}", "padding=0"] + if subtile_size: + attribute_parts.append(f"subtile_size={subtile_size}, async={int(async_type) if async_type is not None else 1}") + attribute = " {" + ", ".join(attribute_parts) + "}" + code = self.get_dma_code(dma_type, vlane_split_axis, vlane_stride, mlir_dtype, dram_var, index_var, sram_var, sram_index_var, + dram_shape, tile_shape, "") + local_code.writeline(code) + local_code.writeline(attribute) + return textwrap.indent(local_code.getvalue(), " "*indent_size).strip() + + def def_sram_buffer(self, dram_name, tile_desc, id=0, indent_size=0): + # Prepare code block + with V.set_kernel_handler(self): + dtype = self.named_nodes[dram_name].get_layout().dtype + tile_shape = tile_desc.get_mlir_shape(mlir_common.DTYPE_TO_MLIR[dtype]) + buffer_name = self.allocate_sram_buffer(dtype, dram_name, tile_desc, id, forced_name=dram_name) + code = f"%{tile_desc.name} = memref.get_global @{buffer_name} : {tile_shape}" + return textwrap.indent(code, " "*indent_size).strip() + + def render(self, template, kwargs, define_function=None): + code = template.render(**kwargs) + if define_function is not None: + define_function(self) - def render(self, template, kwargs): - # self.render_hooks = {} return PartialRender( - template.render(**kwargs), + code, self.render_hooks, ) - def adjust_tile_size(self): - self.tile_desc.n_row = self.render_options['TILE_M'] - self.tile_desc.n_col = self.render_options['TILE_N'] - return + def get_spad_size_per_lane(self, tile_m, tile_n): + size = tile_m * ((tile_n + self.vector_lane - 1) // self.vector_lane) + return max(size, 2) # vector load/store def load_epilogue(self, name: str, index: sympy.Expr): index = self.rename_indexing(index) - var = self.args.input(name) + dram_var = self.kernel_group.args.input(name) + dram_shape = mlir_common.MLIRKernelArgs.get_mlir_shape(self.buffer_types[name]) dtype = V.graph.get_dtype(name) - type_name = mlir_common.DTYPE_TO_MLIR[dtype] + mlir_dtype = mlir_common.DTYPE_TO_MLIR[dtype] - if name in self.buffer_names: - buffer = self.buffer_names[name] + # Want to use tile_desc from epilogue_info + index_var = self.parse_indices(index) + dram_stride = [index.coeff(sympy.Symbol(val)) for val in self.dim_aliasing.values()] + vlane_split_axis = self.kernel_group.tile_desc.vlane_split_axis + vlane_stride = self.kernel_group.tile_desc.vlane_stride + tile_shape = self.kernel_group.tile_desc.get_mlir_shape(mlir_dtype) + tile_stride = self.kernel_group.tile_desc.get_tile_stride() + + # Compute vector unit size + vshape = self.kernel_group.tile_desc.get_mlir_vshape(mlir_dtype) + compute_vec_size = self.kernel_group.tile_desc.get_compute_vec_size() + + if name not in self.buffer_names: + # Allocate sram buffer + dram_shape = mlir_common.MLIRKernelArgs.get_mlir_shape(self.buffer_types[name]) + sram_var, sram_index_var = self.get_scratchpad_buffer(dtype, name, self.kernel_group.tile_desc, index) + attribute = f"{{dram_stride={dram_stride}, sram_stride={tile_stride}, padding=0}}" + code = self.get_dma_code("MVIN", vlane_split_axis, vlane_stride, mlir_dtype, dram_var, index_var, sram_var, sram_index_var, + dram_shape, tile_shape, attribute) + self.cse.generate(self.dma_loads, code, assignment = False) + self.buffer_names[name] = sram_var else: - dram_mlir_shape = mlir_common.MLIRKernelArgs.get_mlir_shape(self.buffer_types[name]) - mvin3 = 14 - self.consts.add(mvin3) - self.consts.add(0) - dram_tile_shape = f"{self.render_options['TILE_M']}x{self.render_options['TILE_N']}" - buffer, indices = self.get_scratchpad_buffer(dtype, name, self.render_options['TILE_M'], self.render_options['TILE_N'], dram_tile_shape, self.loads, index) - self.buffer_names[name] = buffer - line = f"affine.dma_start %{var}[%index2], %{buffer}[%c0, %c0], %tag[0], %c{mvin3}, %N, %c_set : {dram_mlir_shape}, memref<{dram_tile_shape}x{type_name}, 1>, memref<1xi32>" - self.cse.generate(self.loads, line, assignment = False) - - tile_size_per_lane = self.render_options['TILE_M'] * self.render_options['TILE_N'] // self.vector_lane - operation = "affine.vector_load" if tile_size_per_lane > 1 else "affine.load" - shape = f", vector<{tile_size_per_lane}x{type_name}>" if tile_size_per_lane > 1 else "" - line = f"{operation} %{buffer}[0, 0] : memref<{self.render_options['TILE_M']}x{self.render_options['TILE_N']}x{type_name}, 1>{shape}" - out = self.cse.generate(self.loads, line) - var_info = [tile_size_per_lane, mlir_common.DTYPE_TO_MLIR[dtype]] - self.register_var_info(out, var_info) + sram_var = self.buffer_names[name] + + # Load vector from sram + zero_var = self.get_const_cse(0) + if not self.reduction_fusion: + compute_index_var = ",".join([f"%{zero_var}"] * (self.kernel_group.tile_desc.get_nr_dim()-1) + [f"%{self.compute_idx}"]) + if compute_vec_size > 1: + operation = "affine.vector_load" + line = f"{operation} %{sram_var}[{compute_index_var}] : {tile_shape}, {vshape}" + else: + operation = "affine.load" + line = f"{operation} %{sram_var}[{compute_index_var}] : {tile_shape}" + out = self.cse.generate(self.loads, line) + self.register_var_info(out, [compute_vec_size, mlir_dtype]) + else: # For reduction case + reduce_size = self.reduction_nr_outer_loop + vsize = compute_vec_size//reduce_size + vshape = f"vector<{vsize}x{mlir_dtype}>" + + if compute_vec_size > 1: + offset = self.cse.generate(self.loads, f"affine.apply affine_map<(d0, d1) -> (d0 + d1*{(self.reduction_axis_size)})>(%{self.compute_idx}, %{self.reduction_loop_idx})") + compute_index_var = ",".join([f"%{zero_var}"] * (self.kernel_group.tile_desc.get_nr_dim()-1) + [f"%{offset}"]) + operation = "affine.vector_load" + line = f"{operation} %{sram_var}[{compute_index_var}] : {tile_shape}, {vshape}" + out = self.cse.generate(self.loads, line) + else: + line = f"{operation} %{sram_var}[{compute_index_var}] : {tile_shape}" + out = self.cse.generate(self.loads, line) + self.register_var_info(out, [self.compute_body_loop.step, mlir_dtype]) return out def store_epilogue(self, name: str, index: sympy.Expr, value, *args, **kwargs): - indices = self.parse_indices(index) - prefix = self.newvar_prefix - if index.is_number: - prefix = prefix + "c" - self.consts.add(int(index)) - var = self.args.output(name) + index = self.rename_indexing(index) + dram_var = self.kernel_group.args.output(name) + dram_shape = mlir_common.MLIRKernelArgs.get_mlir_shape(self.buffer_types[name]) dtype = V.graph.get_dtype(name) + mlir_dtype = mlir_common.DTYPE_TO_MLIR[dtype] + + index_var = self.parse_indices(index) + dram_stride = [index.coeff(sympy.Symbol(val)) for val in self.dim_aliasing.values()] + vlane_split_axis = self.kernel_group.tile_desc.vlane_split_axis + vlane_stride = self.kernel_group.tile_desc.vlane_stride + tile_shape = self.kernel_group.tile_desc.get_mlir_shape(mlir_dtype) + tile_stride = self.kernel_group.tile_desc.get_tile_stride() + + # Compute vector unit size + vshape = self.kernel_group.tile_desc.get_mlir_vshape(mlir_dtype) + compute_vec_size = self.kernel_group.tile_desc.get_compute_vec_size() + + if name not in self.buffer_names: + sram_var, sram_index_var = self.get_scratchpad_buffer(dtype, name, self.kernel_group.tile_desc, index) + self.buffer_names[name] = sram_var + store_force = False + else: + zero_cse = self.get_const_cse(0) + sram_dims = len(tile_shape.split("x")) - 1 + sram_index_var = ",".join([f"%{zero_cse}"] * sram_dims) + store_force = True + sram_var = self.buffer_names[name] + zero_var = self.get_const_cse(0) + + _, operand_type = self.var_info[value] + if mlir_dtype != operand_type: + value = ops.to_dtype(value, mlir_dtype, var_info=self.var_info) + compute_index_var = ",".join([f"%{zero_var}"] * (self.kernel_group.tile_desc.get_nr_dim()-1) + [f"%{self.compute_idx}"]) + # Generate vector load instruction + if compute_vec_size > 1: + operation = "affine.vector_store" + line = f"{operation} %{value}, %{sram_var}[{compute_index_var}] : {tile_shape}, {vshape}" + else: + operation = "affine.store" + line = f"{operation} %{value}, %{sram_var}[{compute_index_var}] : {tile_shape}" + line = line if store_force else DeferredLine(name, line) + self.stores.writeline(line) + + # Generate DMA instruction + attribute = f"{{dram_stride={dram_stride}, sram_stride={tile_stride}, padding=0}}" + code = self.get_dma_code("MVOUT", vlane_split_axis, vlane_stride, mlir_dtype, dram_var, index_var, sram_var, sram_index_var, + dram_shape, tile_shape, attribute) + self.dma_stores.writeline(DeferredLine(name, code)) + + def reduction_epilogue(self, dtype, src_dtype, reduction_type, value): + argmax_or_argmin = reduction_type in {"argmax", "argmin"} + if argmax_or_argmin: + raise NotImplementedError() #TODO: argmin, argmax + if is_welford_reduction(reduction_type): + if reduction_type == "welford_combine": + raise NotImplementedError("welford_combine") + else: + assert reduction_type == "welford_reduce" + type_name = mlir_common.DTYPE_TO_MLIR[dtype] + reduction_key = src_dtype, reduction_type, value + sum = self.reduction_epilogue(dtype, src_dtype, "sum", value) + sqr_sum = self.reduction_epilogue(dtype, src_dtype, "sum", ops.mul(value, value)) + self.welford_reduce_out = (sum, sqr_sum, None) + return sum, sqr_sum, None + + # Check duplicated reductions + reduction_key = src_dtype, reduction_type, value + if reduction_key in self.reduction_epilogue_result: + return self.reduction_epilogue_result[reduction_key] + + # Reduction fusion codegen part + vec_size = self.compute_body_loop.step type_name = mlir_common.DTYPE_TO_MLIR[dtype] + new_tile_size = self.kernel_group.tile_desc.get_tile_size()[:-1] + [vec_size] + new_vlane_split_axis = self.kernel_group.tile_desc.vlane_split_axis + new_vlane_stride = self.kernel_group.tile_desc.vlane_stride + local_tile_desc = mlir_common.MLIRMultiDimTile(new_tile_size, self.vector_lane, new_vlane_split_axis, new_vlane_stride, vec_size) + + tile_shape = local_tile_desc.get_mlir_shape(type_name) + vshape = local_tile_desc.get_mlir_vshape(type_name) + + name = f"{reduction_type}_buffer{self.reduction_buffer_idx}" + self.reduction_buffer_idx += 1 + index = "dummy_index" # Not used + sram_var, _ = self.get_scratchpad_buffer(dtype, name, local_tile_desc, index, self.const_buffer) + self.reduction_epilogue_result[reduction_key] = sram_var + + # Load partial result + zero_var_list = [f"%{self.get_const_cse(0)}"] * local_tile_desc.get_nr_dim() + zero_var_list[-2] = f"%{self.reduction_loop_idx}" + compute_index_var = ", ".join(zero_var_list) + operation = "affine.vector_load" + line = f"{operation} %{sram_var}[{compute_index_var}] : {tile_shape}, {vshape}" + out = self.cse.generate(self.loads, line) + self.register_var_info(out, [self.compute_body_loop.step, type_name]) + + # Reduction body codegen + init = self.const_cse.generate(self.const_buffer, f"arith.constant {reduction_init(reduction_type, dtype)} : {type_name}") + init_vec = self.const_cse.generate(self.const_buffer, f"vector.broadcast %{init} : {type_name} to {vshape}") + self.register_var_info(init_vec, [local_tile_desc.get_compute_vec_size(), type_name]) + mask_shape, mask_var = self.get_mask() + if mask_var is not None: + value = ops.where(mask_var, value, init_vec) + result = reduction_partial_combine_vec(reduction_type, value, out) + + # Store partial result + operation = "affine.vector_store" + line = f"{operation} %{result}, %{sram_var}[{compute_index_var}] : {tile_shape}, {vshape}" + self.compute.writeline(line) # Need to be placed after partial reduction + self.reduction_info[sram_var] = [reduction_type, local_tile_desc] + return sram_var + + def store_reduction_epilogue(self, name, index, value): + index = self.rename_indexing(index) + dram_var = self.kernel_group.args.output(name) + dram_shape = mlir_common.MLIRKernelArgs.get_mlir_shape(self.buffer_types[name]) + dtype = V.graph.get_dtype(name) + mlir_dtype = mlir_common.DTYPE_TO_MLIR[dtype] + + index_var = self.parse_indices(index, self.reductions_suffix, comments="// Store reduction") + dram_stride = [index.coeff(sympy.Symbol(val)) for val in self.dim_aliasing.values()][:-1] # Assume that there is only one reduction axis + vlane_split_axis = self.kernel_group.tile_desc.vlane_split_axis + vlane_stride = self.kernel_group.tile_desc.vlane_stride + + # Create final buffer descriptor + nr_outer_loop = self.reduction_nr_outer_loop + tile_size = self.kernel_group.tile_desc.get_tile_size()[:-1] + final_tile_desc = mlir_common.MLIRMultiDimTile(tile_size, self.vector_lane, vlane_split_axis, vlane_stride*nr_outer_loop*2) + final_tile_shape = final_tile_desc.get_mlir_shape(mlir_dtype) + final_tile_stride = final_tile_desc.get_tile_stride() + sram_var, sram_index_var = self.get_scratchpad_buffer(dtype, name, final_tile_desc, index, buffer=self.const_buffer) + + # Set partial buffer descriptor + partial_tile_desc = self.reduction_info[value][1] + partial_vec_size = partial_tile_desc.get_compute_vec_size() + partial_vshape = partial_tile_desc.get_mlir_vshape(mlir_dtype) + partial_tile_shape = partial_tile_desc.get_mlir_shape(mlir_dtype) + + # Prepare constant + init = self.const_cse.generate(self.const_buffer, f"arith.constant {reduction_init(self.reduction_info[value][0], dtype)} : {mlir_dtype}") + partial_zero_var_list = [f"%{self.get_const_cse(0)}"] * partial_tile_desc.get_nr_dim() + final_zero_var_list = [f"%{self.get_const_cse(0)}"] * final_tile_desc.get_nr_dim() + for i in range(self.reduction_body_loop.size): + # Load partial result + body_index_var = self.const_cse.generate(self.const_buffer, f"arith.constant {i} : index") + partial_zero_var_list[-2] = f"%{body_index_var}" + compute_index_var = ",".join(partial_zero_var_list) + + operation = "affine.vector_load" + line = f"{operation} %{value}[{compute_index_var}] : {partial_tile_shape}, {partial_vshape}" + out = self.cse.generate(self.reductions_suffix, line) + operation = "affine.vector_store" + init_vec = self.const_cse.generate(self.const_buffer, f"vector.broadcast %{init} : {mlir_dtype} to {partial_vshape}") + line = f"{operation} %{init_vec}, %{value}[{compute_index_var}] : {partial_tile_shape}, {partial_vshape}" + self.reductions_suffix.writeline(line) + + # 2 step reduction + new_vec_size = 2 + new_vshape = f"vector<{partial_vec_size//new_vec_size}x{new_vec_size}x{mlir_dtype}>" + new_reduced_shape = f"vector<{new_vec_size}x{mlir_dtype}>" + out = self.cse.generate(self.reductions_suffix, f"vector.shape_cast %{out} : {partial_vshape} to {new_vshape}") + init_vec = self.const_cse.generate(self.const_buffer, f"vector.broadcast %{init} : {mlir_dtype} to {new_reduced_shape}") + out = self.cse.generate(self.reductions_suffix, reduction_combine_vec(self.reduction_info[value][0], out, init_vec, axis=0, shape=new_vshape, reduced_shape=new_reduced_shape)) + out2 = self.cse.generate(self.reductions_suffix, f"vector.shuffle %{out}, %{out} [1, 0] : {new_reduced_shape}, {new_reduced_shape}") + + self.compute, self.reductions_suffix = self.reductions_suffix, self.compute + self.register_var_info(out, [new_vec_size, mlir_dtype]) + self.register_var_info(out2, [new_vec_size, mlir_dtype]) + out = reduction_partial_combine_vec(self.reduction_info[value][0], out, out2) + self.compute, self.reductions_suffix = self.reductions_suffix, self.compute + + if self.welford_reduce_out is not None: + # NOTE: It not a real welford algorithm... We just used E(X^2) - E(X)^2 + divider = self.cse.generate(self.reductions_suffix, f"arith.constant {float(self.reduction_axis_size)} : f32") + if self.reduction_axis_size - 1 > 0: + divider2 = self.cse.generate(self.reductions_suffix, f"arith.constant {float(self.reduction_axis_size-1)} : f32") + else: + divider2 = divider + + if self.buffer_types[name][1] > 1: + divider_vec = self.cse.generate(self.reductions_suffix, f"vector.broadcast %{divider} : f32 to {new_reduced_shape}") + else: + divider_vec = divider + + if self.current_node.node.origin_node: # FIXME: This is a temporary solution + # mean = SUM(X) / N + self.reduction_mean.append(self.cse.generate(self.reductions_suffix, f"arith.divf %{out}, %{divider_vec} : {new_reduced_shape}")) + out = self.reduction_mean[i] + else: + # m2 = (E(X^2) - E(X)^2) * N + sqr_mean = self.cse.generate(self.reductions_suffix, f"arith.divf %{out}, %{divider_vec} : {new_reduced_shape}") + mean_sqr = self.cse.generate(self.reductions_suffix, f"arith.mulf %{self.reduction_mean[i]}, %{self.reduction_mean[i]} : {new_reduced_shape}") + variance = self.cse.generate(self.reductions_suffix, f"arith.subf %{sqr_mean}, %{mean_sqr} : {new_reduced_shape}") + m2 = self.cse.generate(self.reductions_suffix, f"arith.mulf %{variance}, %{divider_vec} : {new_reduced_shape}") + out = m2 + + final_zero_var_list[-1] = f"%{body_index_var}" + final_compute_index_var = ",".join(final_zero_var_list) + operation = "affine.vector_store" + line = f"{operation} %{out}, %{sram_var}[{final_compute_index_var}] : {final_tile_shape}, {new_reduced_shape}" + self.reductions_suffix.writeline(DeferredLine(name, line)) + + # MVOUT Encoding + # Generate DMA instruction + attribute = f"{{dram_stride={dram_stride}, sram_stride={final_tile_stride}, padding=0}}" + code = self.get_dma_code("MVOUT", vlane_split_axis, vlane_stride, mlir_dtype, dram_var, index_var, sram_var, sram_index_var, + dram_shape, final_tile_shape, attribute) + self.reductions_suffix.writeline(DeferredLine(name, code)) + + def set_tile_size(self, template_fusion_info, prologue=False): + tile_desc = template_fusion_info["dram_tile_desc"] + if "dim_aliasing" in template_fusion_info: + self.dim_aliasing = template_fusion_info["dim_aliasing"] - chunk_size = self.tile_desc.get_chunk_size() - chunk = chunk_size << 1 | (self.tile_desc.tile_per_lane_layout == MLIRTile.TILE_PER_LANE_COL_WISE) - self.consts.add(chunk) + if 'nr_rdim' in template_fusion_info and template_fusion_info['nr_rdim']==1: + tile_desc.nr_rdim = 1 + numel_per_lane = tile_desc.get_numel_per_lane() + reduction_axis_size = tile_desc.get_tile_size()[-1] + nr_outer_loop = (numel_per_lane + reduction_axis_size-1) // reduction_axis_size + tile_desc.vec_size = nr_outer_loop * 32 # Why? Emprically selected, other option failed to functionality... - if name in self.buffer_names: - buffer = self.buffer_names[name] + self.reduction_fusion = True + self.reduction_axis_size = tile_desc.get_tile_size()[-1] + self.reduction_nr_outer_loop = nr_outer_loop + self.reduction_loop_idx = "reduce_loop_idx" + self.compute_body_loop.size = reduction_axis_size + self.compute_body_loop.step = tile_desc.get_compute_vec_size() // nr_outer_loop + self.reduction_body_loop = mlir_common.LoopLevel(self.reduction_loop_idx, nr_outer_loop) else: - dram_tile_shape = f"{self.render_options['TILE_M']}x{self.render_options['TILE_N']}" - buffer, indices = self.get_scratchpad_buffer(dtype, name, self.render_options['TILE_M'], self.render_options['TILE_N'], dram_tile_shape, self.stores, indices, index) - self.buffer_names[name] = buffer + tile_desc.vec_size=64 - tile_size_per_lane = self.render_options['TILE_M'] * self.render_options['TILE_N'] // self.vector_lane - operation = "affine.vector_store" if tile_size_per_lane > 1 else "affine.store" - shape = f", vector<{tile_size_per_lane}x{type_name}>" if tile_size_per_lane > 1 else "" - line = f"{operation} %{value}, %{buffer}[0, 0] : memref<{self.render_options['TILE_M']}x{self.render_options['TILE_N']}x{type_name}, 1>{shape}" - self.cse.generate(self.stores, line, assignment = False) + if prologue: + self.prologue_compute_body_loop.size = tile_desc.get_numel_per_lane() + self.prologue_compute_body_loop.step = tile_desc.get_compute_vec_size() + else: + self.compute_body_loop.size = tile_desc.get_numel_per_lane() + self.compute_body_loop.step = tile_desc.get_compute_vec_size() + return tile_desc - self.tags.add(f"{name}_tag") - self.consts.add(0) - code = f"affine.dma_start %{buffer}[%c0, %c0], %{var}[%index2], %tag[0], %c_mvout, %N, %c{chunk} : memref<{self.render_options['TILE_M']}x{self.render_options['TILE_N']}x{type_name}, 1>, memref<{self.render_options['M'] * self.render_options['N']}x{type_name}>, memref<1xi32>" #FIXME: Using constant index and tag - self.cse.generate(self.stores, code, assignment = False) + def rename_indexing(self, index) -> sympy.Expr: + for dim_name, dim_aliased_name in self.dim_aliasing.items(): + index = index.subs(sympy.Symbol(dim_name), sympy.Symbol("tmp_"+dim_aliased_name)) + # To avoid this case ({"index0":"index1", "index1":"index0"}) + for dim_aliased_name in self.dim_aliasing.values(): + index = index.subs(sympy.Symbol("tmp_"+dim_aliased_name), sympy.Symbol(dim_aliased_name)) + return index class MLIRTemplateCaller(CUDATemplateCaller): def __str__(self): @@ -344,7 +1062,7 @@ def __init__(self, name, input_nodes, layout, input_reorder = None): def generate(self, **kwargs) -> ChoiceCaller: kernel_name = f"mlir_{self.name}" with patch.object(V.graph, "get_dtype", self._fake_get_dtype(self.output_node)): - kernel = MLIRTemplateKernel(kernel_name=kernel_name, input_nodes=self.input_nodes, call_size=self.layout.size, + kernel = MLIRTemplateKernel(kernel_name=kernel_name, input_nodes=self.input_nodes, call_size=self.layout.size, kernel_group=None, outer_func_name=self.function_name if hasattr(self, 'function_name') else None, outer_func_render=self.outer_func_render if hasattr(self, 'outer_func_render') else None, kernel_arg_attributes=self.get_arg_attributes() if hasattr(self, 'get_arg_attributes') else None) @@ -363,13 +1081,16 @@ def generate(self, **kwargs) -> ChoiceCaller: def make_kernel_render( template_node: TemplateBuffer, + prologue_nodes: Optional[List[IRNode]] = None, epilogue_nodes: Optional[List[IRNode]] = None, kernel_name: str = kernel_hash_name, + kernel_group: Optional[mlir_common.MLIRWrapperKenrelGroup] = None ): kernel = MLIRTemplateKernel( kernel_name=kernel_name, input_nodes=self.input_nodes, call_size=self.layout.size, + kernel_group=kernel_group, outer_func_name=self.function_name if hasattr(self, 'function_name') else None, outer_func_render=functools.partial( self.outer_func_render, @@ -381,7 +1102,8 @@ def make_kernel_render( kwargs = { 'kernel': kernel, 'template_buffer_node': template_node, - 'epilogue_nodes': epilogue_nodes + 'epilogue_nodes': epilogue_nodes, + 'prologue_nodes': prologue_nodes, } render = functools.partial( kernel.render, diff --git a/Scheduler/scheduler.py b/Scheduler/scheduler.py index 1d8064f9..834698a6 100644 --- a/Scheduler/scheduler.py +++ b/Scheduler/scheduler.py @@ -1,7 +1,6 @@ from typing import List import os -import sys -import json +import numpy as np import torch from pathlib import Path import importlib.util @@ -23,6 +22,19 @@ def import_module_from_path(module_name, path): return module +def poisson_request_generator(lambda_requests, max_msec_time=None): + current_time = 0.0 # msec + + yield 0 + while max_msec_time is None or current_time < max_msec_time: + inter_arrival_time = np.random.exponential(scale=1000 / lambda_requests) + current_time += inter_arrival_time + + if max_msec_time is not None and current_time > max_msec_time: + break + + yield current_time + class Request: """ Each request has model name, it's own id, and requested time. """ request_id = 0 @@ -47,9 +59,6 @@ def allocate_id(self): Request.request_id += 1 return allocated_id - def is_arrived(self, current_time): - return current_time >= self.arrival_time - def set_start(self, start_time): self.state = self.RUNNING self.start_time.append(start_time) @@ -109,9 +118,7 @@ def find_model(self, model_name : str): if model_name in SchedulerDNNModel.MODEL_MAP: return SchedulerDNNModel.MODEL_MAP[model_name] else: - print(f'[Scheduler] Requested model "{model_name}"is not registered...') - return None - + raise KeyError(f'[Scheduler] Requested model "{model_name}" is not registered...') def get_batchable_input(self): batched_input_tensor = [] @@ -179,9 +186,11 @@ def setup_device(): register_backend_for_device, ) from PyTorchSimFrontend.mlir.mlir_codegen_backend import ( - MLIRScheduling, ExtensionWrapperCodegen, ) + from PyTorchSimFrontend.mlir.mlir_scheduling import ( + MLIRScheduling + ) register_backend_for_device( "extension_device", MLIRScheduling, ExtensionWrapperCodegen ) @@ -206,7 +215,10 @@ def get_compiled_model(self, batched_req: List[Request], request_queue_idx): def is_partition_idle(self, partition_idx): return len(self.launch_model_dicts[partition_idx]) == 0 - def is_idle(self): + def is_any_idle(self, skip_list): + return any([self.is_partition_idle(i) and not skip_list[i] for i in range(self.num_partion)]) + + def is_all_idle(self): return all([self.is_partition_idle(i) for i in range(self.num_partion)]) def prepare_model(self, req_model: SchedulerDNNModel): @@ -228,24 +240,25 @@ def finish_model(self, model : SchedulerDNNModel, output : torch.Tensor): self.finish_req_dict[req] = RequestReturn(RequestReturn.FINISHED) def prepare_launch_kernel(self, kernel, inputs): - key = kernel.future.result() - result_path = os.path.join(extension_config.CONFIG_TORCHSIM_DUMP_PATH, "tmp", hash_prefix(key)) + result_path, runtime_path, _ = kernel(*inputs) onnx_path = os.path.join(result_path, "tile_graph.onnx") - attribute_path = os.path.join(extension_config.CONFIG_TORCHSIM_DUMP_PATH, "tmp", hash_prefix(key), "attribute") + attribute_path = os.path.join(runtime_path, "attribute") attribute_path = self.backend_simulator.create_attribute_file(attribute_path, inputs) return onnx_path, attribute_path def launch_kernel(self, current_cycle, partion_idx=0): # Check partition is busy if self.partition_state[partion_idx] != self.PARTITION_IDLE: - return None + return self.partition_state[partion_idx] result = self.select_kernel(partion_idx) if result == self.SELECT_NOTHING: - return None - + return self.SELECT_NOTHING kernel, inputs = result - onnx_path, attribute_path = self.prepare_launch_kernel(kernel, inputs) + if not isinstance(kernel, str): + onnx_path, attribute_path = self.prepare_launch_kernel(kernel, inputs) + else: + onnx_path, attribute_path = kernel, inputs self.partition_state[partion_idx] = self.PARTITION_BUSY return self.backend_simulator.launch(onnx_path, attribute_path, current_cycle, partion_idx) @@ -265,6 +278,10 @@ def select_kernel(self, partition_idx): try: kernel, inputs = next(target_model) + # For extern call + if isinstance(kernel, str): + return kernel, inputs + # For convolution... if not hasattr(kernel, "future"): nested_gen = kernel(*inputs) @@ -327,11 +344,12 @@ def select_kernel(self, partition_idx): return self.SELECT_NOTHING class Scheduler: + FIFO_ENGINE = 0 RR_ENGINE = 1 - def __init__(self, num_request_queue=1, engine_select=FIFO_ENGINE) -> None: - self.current_time = 0 - self.max_batch = 1 + def __init__(self, num_request_queue=1, max_batch=1, engine_select=FIFO_ENGINE, backend_config=extension_config.CONFIG_TORCHSIM_BACKEND_CONFIG) -> None: + self.current_cycle = 0 + self.max_batch = max_batch self.num_request_queue = num_request_queue self.request_queue : List[List[Request]] = [] for i in range(self.num_request_queue): @@ -339,7 +357,7 @@ def __init__(self, num_request_queue=1, engine_select=FIFO_ENGINE) -> None: self.finish_queue : List[Request] = [] backend_path = os.path.join(extension_config.CONFIG_TORCHSIM_DIR, "PyTorchSimBackend") - self.backend_simulator = BackendSimulator(backend_path, extension_config.CONFIG_TORCHSIM_BACKEND_CONFIG) + self.backend_simulator = BackendSimulator(backend_path, backend_config) self.backend_simulator.interactive_simulation() if engine_select == Scheduler.FIFO_ENGINE: self.execution_engine = FIFOExecutionEngine(self.backend_simulator, self.num_request_queue) @@ -350,11 +368,16 @@ def __init__(self, num_request_queue=1, engine_select=FIFO_ENGINE) -> None: exit(1) def add_request(self, request: Request, request_time=-1): - """register model at timestamp time""" - request_time = self.current_time if request_time == -1 else request_time + """register model at timestamp time + request_time : msec + """ + request_time = self.current_time() if request_time == -1 else request_time request.arrival_time = request_time self.request_queue[request.request_queue_idx].append(request) + def request_empty(self, request_queue_idx): + return len(self.request_queue[request_queue_idx])==0 + def select(self, request_queue_idx=0) -> List[Request]: """ Select 1 request from request_queue in FCFS manner. @@ -364,7 +387,8 @@ def select(self, request_queue_idx=0) -> List[Request]: if not self.request_queue[request_queue_idx]: return candidate_req for req in self.request_queue[request_queue_idx]: - if req.is_arrived(self.current_time) and req.state == Request.QUEUED: + + if self.msec_to_cycle(req.arrival_time) <= self.current_cycle and req.state == Request.QUEUED: candidate_req.append(req) # Stop batching @@ -392,7 +416,7 @@ def nearest_next_reqeust_time(self): return nearest_req, nearest_arrival_time def finish_request(self, req : Request): - req.set_finished(self.current_time) + req.set_finished(self.current_time()) # Free resources req.free_memory() @@ -406,17 +430,22 @@ def finish_request(self, req : Request): f"response time: {response_time} tbt_time: {tbt_time}") def per_schedule(self, request_queue_idx): + # Wait partition is idle + if not self.execution_engine.is_partition_idle(request_queue_idx): + return False + request_list = self.select(request_queue_idx) if not request_list: return False + print(f"[Request issue] partition: {request_queue_idx} batch size: {len(request_list)}", flush=True) for req in request_list: - req.set_start(self.current_time) - + req.set_start(self.current_time()) + print(f"[Request-{req.id} issue] partition: {req.request_queue_idx} " + f"arrival_time: {req.arrival_time} start_time: {req.start_time[0]}", flush=True) # Submit batched request self.execution_engine.submit(request_list, request_queue_idx) - print(f"[Request-{req.id} issue] partition: {req.request_queue_idx} " - f"arrival_time: {req.arrival_time} start_time: {req.start_time[0]}") + return True def check_finish_request(self): @@ -434,51 +463,58 @@ def schedule(self): # Try move to next nearest request time next_req, next_time = self.nearest_next_reqeust_time() - if next_req is None and self.execution_engine.is_idle(): + if next_req is None and self.execution_engine.is_all_idle(): # No request remained... return # Need to forward the time until next_arrival_time - if self.execution_engine.is_idle(): - reason = self.backend_simulator.until(next_time) - self.current_time = self.backend_simulator.cycle() + if self.execution_engine.is_all_idle(): + reason = self.backend_simulator.until(self.msec_to_cycle(next_time)) + self.current_cycle = self.backend_simulator.cycle() else: self.run(next_time) return def run(self, until_time): + req_empty_info = [self.request_empty(i) for i in range(self.execution_engine.num_partion)] def execute_cycle(): + launch_ret_info = [] for i in range(self.execution_engine.num_partion): if self.execution_engine.partition_state[i] == ExecutionEngine.PARTITION_IDLE: - ret = self.execution_engine.launch_kernel(self.current_time, i) + ret = self.execution_engine.launch_kernel(self.current_cycle, i) + launch_ret_info.append(ret) self.check_finish_request() # Check if the stop condition is met - if self.execution_engine.is_idle(): - return -1 + if self.execution_engine.is_any_idle(req_empty_info) or self.execution_engine.is_all_idle(): # Ignore empty request queue + return [] # Schedule jobs and update the current time - result = self.backend_simulator.until(until_time) - self.current_time = self.backend_simulator.cycle() + result_list = self.backend_simulator.until(self.msec_to_cycle(until_time)) + self.current_cycle = self.backend_simulator.cycle() - if result != -1: + for core_idx in result_list: # Kernel is finished. So set idle state - self.execution_engine.partition_state[result] = ExecutionEngine.PARTITION_IDLE + self.execution_engine.partition_state[core_idx] = ExecutionEngine.PARTITION_IDLE + + return result_list - return result + if self.current_cycle >= self.msec_to_cycle(until_time): + until_time = -1 if until_time == -1: - while not self.execution_engine.is_idle(): + while not self.execution_engine.is_any_idle(req_empty_info): result = execute_cycle() + req_empty_info = [self.request_empty(i) for i in range(self.execution_engine.num_partion)] # if result is not -1, schedule new request - if result == -1: + if len(result)==0: break else: - while self.current_time <= until_time and not self.execution_engine.is_idle(): + while self.current_cycle <= self.msec_to_cycle(until_time) and not self.execution_engine.is_all_idle(): result = execute_cycle() # if result is not -1, schedule new request - if result == -1: + if len(result)==0: break return @@ -489,12 +525,22 @@ def is_request_queue_empty(self): return result def is_finished(self): - return self.is_request_queue_empty() and self.execution_engine.is_idle() + if self.is_request_queue_empty() and self.execution_engine.is_all_idle(): + self.backend_simulator.wait() + return True + return False + + def current_time(self): + return self.cycle_to_msec(self.current_cycle) def cycle_to_msec(self, cycle): freq = self.backend_simulator.get_core_freq() return cycle / (freq / 1000) def msec_to_cycle(self, msec): + # We treat -1 as special time + if (msec == -1): + return msec + freq = self.backend_simulator.get_core_freq() - return msec * (freq / 1000) \ No newline at end of file + return int(msec * (freq / 1000)) \ No newline at end of file diff --git a/Simulator/simulator.py b/Simulator/simulator.py index c4a778a5..292c5a9b 100644 --- a/Simulator/simulator.py +++ b/Simulator/simulator.py @@ -1,5 +1,6 @@ import os import shlex +import ctypes import subprocess import re import sys @@ -35,8 +36,6 @@ def load_tensor(self, arg, arg_name, arg_attribute, path): # path = os.path.join(dump_path, arg_name, f'{n_call}.raw') with open(path, 'rb') as f: np_array = np.fromfile(f, dtype=TORCH_TO_NUMPY[arg.dtype]) - if (arg.dtype == torch.bool): - np_array = np.unpackbits(np_array) src_tensor = torch.as_strided(torch.from_numpy(np_array), arg.size(), arg.stride()) arg.copy_(src_tensor.to(dtype=arg.dtype)) @@ -50,10 +49,10 @@ def write_arg(self, arg, path, name): if (isinstance(arg, torch.Tensor)): data_path = os.path.join(dump_path, f'{index}.raw') - tensor = arg.cpu() - t_arr = tensor.numpy().flatten() - if (tensor.dtype == torch.bool): - t_arr = np.packbits(t_arr) + tensor = arg.cpu().detach() + buffer_size = tensor.untyped_storage().size() + buffer = (ctypes.c_char * buffer_size).from_address(tensor.data_ptr()) + t_arr = np.frombuffer(buffer, dtype=tensor.numpy().dtype, count=buffer_size // tensor.element_size()) t_arr.tofile(data_path) else: assert(0) @@ -75,49 +74,50 @@ def dump_args(self, args, arg_attributes, load_path, dump_path): return array_size, file_path - def run_spike(self, args, arg_attributes, path, binary, intermediate_op=None, vectorlane_size=4, spad_info=None, cleanup=False): - load_path = self.path - dump_path = self.path + def run_spike(self, args, arg_attributes, runtime_path, binary, vectorlane_size=4, spad_info=None, cleanup=False, silent_mode=False): + load_path = runtime_path + dump_path = runtime_path - target_binary = os.path.join(path, binary) - objdump = f"riscv64-unknown-elf-objdump -d {target_binary} > {os.path.join(path, 'binary.dump')}" + target_binary = os.path.join(self.path, binary) + objdump = f"riscv64-unknown-elf-objdump -d {target_binary} > {os.path.join(self.path, 'binary.dump')}" kernel_start = f"nm {target_binary} | grep 'kernel' | awk 'NR==1 {{print $1}}'" - kernel_end = f"nm {target_binary} | grep 'kernel' | awk 'NR==1 {{print $1}}' | xargs -I {{}} awk '/{{}}/,0' {os.path.join(path, 'binary.dump')} | grep ret | awk 'NR==1 {{print $1}}' | awk '{{gsub(/:$/, \"\"); print}}'" + kernel_end = f"nm {target_binary} | grep 'kernel' | awk 'NR==1 {{print $1}}' | xargs -I {{}} awk '/{{}}/,0' {os.path.join(self.path, 'binary.dump')} | grep ret | awk 'NR==1 {{print $1}}' | awk '{{gsub(/:$/, \"\"); print}}'" subprocess.run(objdump, shell=True) kernel_start_addr = subprocess.run(kernel_start, shell=True, stdout=subprocess.PIPE).stdout.strip().decode('utf-8') kernel_end_addr = subprocess.run(kernel_end, shell=True, stdout=subprocess.PIPE).stdout.strip().decode('utf-8') - if intermediate_op is not None: - os.makedirs(os.path.join(self.path, "intermediate"), exist_ok=True) - if intermediate_op & 0b10: # input comes from intermediate - load_path = os.path.join(self.path, "intermediate") - if intermediate_op & 0b01: # output dumps to intermediate - dump_path = os.path.join(self.path, "intermediate") - for name, attr in arg_attributes: - if attr[0] == 2: - os.makedirs(os.path.join(dump_path, name), exist_ok=True) - _, file_path = self.dump_args(args, arg_attributes, load_path, dump_path) file_path_str = ' '.join(file_path) # Set hardware information - spad_option = f"--scratchpad-base-paddr={spad_info['spad_paddr']} " + \ + spad_option = f"-m0x{0x80000000:x}:0x{100<<30:x},0x{spad_info['spad_paddr']:x}:0x{spad_info['spad_size']*vectorlane_size:x} " + \ + f"--scratchpad-base-paddr={spad_info['spad_paddr']} " + \ f"--scratchpad-base-vaddr={spad_info['spad_vaddr']} " + \ - f"--scratchpad-size={spad_info['spad_size']}" + f"--scratchpad-size={spad_info['spad_size']} " vectorlane_option = f"--vectorlane-size={vectorlane_size}" kernel_address = f"--kernel-addr={kernel_start_addr}:{kernel_end_addr}" - base_addr = f"--base-path={path}" - run = f'spike --isa rv64gcv --varch=vlen:256,elen:64 {vectorlane_option} {spad_option} {kernel_address} {base_addr} /workspace/riscv-pk/build/pk {target_binary} {file_path_str}' - - print("[SpikeSimulator] cmd> ", run) + base_path= f"--base-path={runtime_path}" + os.makedirs(os.path.join(runtime_path, "indirect_access"), exist_ok=True) + os.makedirs(os.path.join(runtime_path, "dma_access"), exist_ok=True) + run = f'spike --isa rv64gcv --varch=vlen:256,elen:64 {vectorlane_option} {spad_option} {kernel_address} {base_path} /workspace/riscv-pk/build/pk {target_binary} {file_path_str}' + if not silent_mode: + print("[SpikeSimulator] cmd> ", run) run_cmd = shlex.split(run) try: - subprocess.check_call(run_cmd) + stdout_setting = subprocess.DEVNULL if silent_mode else None + stderr_setting = subprocess.DEVNULL if silent_mode else None + subprocess.check_call(run_cmd, stdout=stdout_setting, stderr=stderr_setting) except subprocess.CalledProcessError as e: print("[SpikeSimulator] Command failed with exit code", e.returncode) - print("[SpikeSimulator] Error output:", e.output) - assert(0) + error_msg = "" + if e.returncode == 200: + error_msg = "INVALID_SPAD_ACCESS" + elif e.returncode == 201: + error_msg = "STACK_OVERFLOW" + else: + error_msg = "UNKNOWN_ERROR" + raise RuntimeError(f"{error_msg}") for (arg_name, arg_attribute), arg, path in zip(arg_attributes, args, file_path): if LLVMKernelArgs.is_llvm_arg_out(arg_attribute[0]): @@ -128,11 +128,27 @@ def run_spike(self, args, arg_attributes, path, binary, intermediate_op=None, ve if os.path.exists(path): os.remove(path) + @staticmethod + def get_runtime_dump_path(base_path, prefix="runtime", zfill=4): + indices = [ + int(match.group(1)) + for d in os.listdir(base_path) + if (match := re.fullmatch(rf"{prefix}_(\d{{{zfill}}})", d)) + ] + + max_index = max(indices, default=-1) + next_index = max_index + 1 + folder_name = f"{prefix}_{str(next_index).zfill(zfill)}" + full_path = os.path.join(base_path, folder_name) + + os.makedirs(full_path) + return full_path + class CycleSimulator(): def __init__(self) -> None: pass - def compile_and_simulate(self, target_binary, array_size, vectorlane_size): + def compile_and_simulate(self, target_binary, array_size, vectorlane_size, silent_mode=False): def show_progress(): i = 0 while not finished: @@ -143,10 +159,11 @@ def show_progress(): print("") dir_path = os.path.join(os.path.dirname(target_binary), "m5out") - gem5_cmd = [extension_config.CONFIG_GEM5_PATH, "-d", dir_path, extension_config.CONFIG_GEM5_SCRIPT_PATH, "-c", target_binary, "--vlane", str(vectorlane_size)] + gem5_cmd = [extension_config.CONFIG_GEM5_PATH, "-r", "--stdout-file=sto.log", "-d", dir_path, extension_config.CONFIG_GEM5_SCRIPT_PATH, "-c", target_binary, "--vlane", str(vectorlane_size)] try: # Create progress thread - if not extension_config.CONFIG_BACKENDSIM_DRYRUN: + is_dryrun = int(os.environ.get('BACKENDSIM_DRYRUN', default=False)) or silent_mode + if not is_dryrun: print("[Gem5Simulator] cmd> ", " ".join(gem5_cmd)) finished = False progress_thread = threading.Thread(target=show_progress) @@ -173,6 +190,7 @@ def show_progress(): class BackendSimulator(): BACKEND_RESULT_PATH_KEY = "BACKEND_RESULT_PATH" FINISH_STR = "Simulation Finished" + ALLOC_POOL = dict() # For eagermode buffer plan def __init__(self, backend_path, config_path, vectorlane_size=-1) -> None: self.base_dir = backend_path self.config_path = config_path @@ -186,7 +204,7 @@ def get_backend_command(self): cmd = f"{bin} --config {config}" return cmd - def simulation(self, model_path, attribute_path=""): + def simulation(self, model_path, attribute_path="", silent_mode=False): def show_progress(): i = 0 while not finished: @@ -200,21 +218,25 @@ def show_progress(): cmd += f" --log_level {extension_config.CONFIG_BACKENDSIM_DEBUG_LEVEL}" if attribute_path: cmd = f"{cmd} --attributes_list {attribute_path}" - print("[BackendSimulator] cmd> ", cmd) + if not silent_mode: + print("[BackendSimulator] cmd> ", cmd) # Create progress thread - finished = False - progress_thread = threading.Thread(target=show_progress) - progress_thread.start() + if not silent_mode: + finished = False + progress_thread = threading.Thread(target=show_progress) + progress_thread.start() try: result = subprocess.check_output(shlex.split(cmd)) - finished = True - progress_thread.join() + if not silent_mode: + finished = True + progress_thread.join() except subprocess.CalledProcessError as e: - finished = True - progress_thread.join() - print("[BackendSimulator] Command failed with exit code", e.returncode) - print("[BackendSimulator] Error output:", e.output) + if not silent_mode: + finished = True + progress_thread.join() + print("[BackendSimulator] Command failed with exit code", e.returncode) + print("[BackendSimulator] Error output:", e.output) assert 0 result_path = extension_config.CONFIG_BACKEND_RESULT_PATH_KEY if result_path is None: @@ -226,7 +248,7 @@ def show_progress(): result_path = os.path.join(result_path, file_name) with open(result_path, "w") as f: f.write(result.decode()) - print(f'[BackendSimulator] Simulation of "{model_path}" is stored to "{result_path}"') + print(f'[BackendSimulator] Simulation of "{model_path}" is stored to "{result_path}"') return result_path def interactive_simulation(self): @@ -248,13 +270,23 @@ def interactive_simulation(self): def stop(self): if self.process: self.process.terminate() + self.process.wait() self.process = None + print("[BackendSimulator] Simulator stopped.") + + def wait(self): + if self.process: + print("[BackendSimulator] Waiting for simulation to complete...") + self.quit() + self.process.wait() + self.process = None + print("[BackendSimulator] Simulation completed.") def send_command(self, command): if self.process: try: if not extension_config.CONFIG_BACKENDSIM_DRYRUN: - print(command) + print(command, flush=True) self.process.stdin.write(command + '\n') self.process.stdin.flush() ret = self.process.stderr.readline().strip() @@ -281,10 +313,30 @@ def cycle(self): def until(self, until_cycle): command = f"until {until_cycle}" ret = self.send_command(command) - return int(ret.split(" ")[-1]) + bitmap = int(ret.split(" ")[-1]) + indices = [] + for i in range(64): + if (bitmap >> i) & 1: + indices.append(i) + return indices + + def quit(self): + command = "quit" + ret = self.send_command(command) + return + + @classmethod + def sram_alloc(cls, buf_name, addr_range): + cls.ALLOC_POOL[buf_name] = addr_range + + @classmethod + def sram_dealloc(cls, buf_name, addr_range): + if buf_name in cls.ALLOC_POOL: + del cls.ALLOC_POOL[buf_name] def create_attribute_file(self, attribute_path, inputs, **kwargs): address_info = {} + sram_buffer = {} json_content = {} os.makedirs(attribute_path, exist_ok=True) index = str(len(os.listdir(attribute_path))) @@ -294,61 +346,9 @@ def create_attribute_file(self, attribute_path, inputs, **kwargs): address_info[f"arg{idx}"] = tensor.data_ptr() json_content["address_info"] = address_info - if "tile_size" in kwargs and len(kwargs['tile_size'])==3 and kwargs['tile_size'][0] != 1: - # GEMM - import copy - zero_skip = {} - input, weight = inputs[:2] - M, N, K = kwargs['tile_size'] - - padded_input = copy.deepcopy(input.cpu()) - padded_weight = copy.deepcopy(weight.cpu()) - - original_input_shape = input.shape - original_weight_shape = weight.shape - - # Initialize padding for all dimensions - pad_input = [(0, 0)] * input.ndim - pad_weight = [(0, 0)] * weight.ndim - - if input.ndim == 2: - # 2D tensor: (Height, Width) - pad_input[0] = (0, M - original_input_shape[0] if original_input_shape[0] < M else 0) - pad_input[1] = (0, K - original_input_shape[1] if original_input_shape[1] < K else 0) - elif input.ndim == 3: - # 3D tensor: (Depth, Height, Width) - pad_input[1] = (0, M - original_input_shape[1] if original_input_shape[1] < M else 0) - pad_input[2] = (0, K - original_input_shape[2] if original_input_shape[2] < K else 0) - - if weight.ndim == 2: - # 2D tensor: (Height, Width) - pad_weight[0] = (0, K - original_weight_shape[0] if original_weight_shape[0] < K else 0) - pad_weight[1] = (0, N - original_weight_shape[1] if original_weight_shape[1] < N else 0) - elif weight.ndim == 3: - # 3D tensor: (Depth, Height, Width) - pad_weight[1] = (0, K - original_weight_shape[1] if original_weight_shape[1] < K else 0) - pad_weight[2] = (0, N - original_weight_shape[2] if original_weight_shape[2] < N else 0) - - # Apply padding - padded_input = np.pad( - padded_input, - pad_width=pad_input, - mode='constant', - constant_values=0 - ) - - padded_weight = np.pad( - padded_weight, - pad_width=pad_weight, - mode='constant', - constant_values=0 - ) - - #input_zero_pos = self.find_zero_sub_tensors(padded_input) - weight_zero_pos = self.find_zero_sub_tensors(padded_weight) - #zero_skip["arg0"] = input_zero_pos - zero_skip["arg1"] = weight_zero_pos - json_content["zero_skip"] = zero_skip + for buf_name, range in self.ALLOC_POOL.items(): + sram_buffer[buf_name] = range + json_content["sram_alloc"] = sram_buffer with open(attribute_path, "w") as f: json.dump(json_content, f, indent=4) @@ -440,10 +440,13 @@ def get_result_from_file(result_path): if 'DRAM: AVG BW Util' in line: avg_dram_bw = float(re.search(r'AVG BW Util (\d+\.?\d*)%', line).group(1)) + if 'Total execution cycle' in line: + total_cycle = int(re.search(r'Total execution cycle: (\d+)', line).group(1)) + # Parse total simulation time if 'Simulation time' in line: simulation_time = float(re.search(r'Simulation time: (\d+\.?\d*) seconds', line).group(1)) - return core_metrics, dram_channel_bw, avg_dram_bw, simulation_time + return core_metrics, dram_channel_bw, avg_dram_bw, simulation_time, total_cycle if __name__ == "__main__": sim = BackendSimulator("/workspace/PyTorchSim/PyTorchSimBackend", "/workspace/PyTorchSim/PyTorchSimBackend/configs/systolic_ws_128x128_c4_simple_noc_tpuv4.json") diff --git a/debug/gem5.sh b/debug/gem5.sh new file mode 100755 index 00000000..b4791775 --- /dev/null +++ b/debug/gem5.sh @@ -0,0 +1,27 @@ +#!/bin/bash + +path="$1" + +# if $TORCHSIM_DIR/debug does not exist, create it +if [ ! -d "$TORCHSIM_DIR/debug" ]; then + mkdir $TORCHSIM_DIR/debug +fi +if [ ! -d "$TORCHSIM_DIR/debug/out" ]; then + mkdir $TORCHSIM_DIR/debug/out +fi + +/workspace/gem5/build/RISCV/gem5.debug \ +--debug-flags=Fetch,Decode,MinorExecute \ +-d m5out $TORCHSIM_DIR/gem5_script/script_systolic.py \ +-c $path/cycle_bin --vlane 128 > $TORCHSIM_DIR/debug/out/gem5_log.txt + +# grep ticks of M5Op +ticks=($(grep "Changing stream on" $TORCHSIM_DIR/debug/out/gem5_log.txt | grep "M5Op" | awk '{print $1}')) + +# trim only cycle number +for i in "${!ticks[@]}"; do + ticks[$i]=${ticks[$i]::-4} +done + +# extract instruction +python $TORCHSIM_DIR/debug/pipeline.py --input $TORCHSIM_DIR/debug/out/gem5_log.txt --min ${ticks[-2]} --max ${ticks[-1]} > $TORCHSIM_DIR/debug/out/gem5_inst.txt \ No newline at end of file diff --git a/debug/pipeline.py b/debug/pipeline.py new file mode 100644 index 00000000..88f6a544 --- /dev/null +++ b/debug/pipeline.py @@ -0,0 +1,257 @@ +import argparse +import os +args = argparse.ArgumentParser() +args.add_argument('--input', type=str, default='input.txt') +args.add_argument('--min', type=int, default=0) +args.add_argument('--max', type=int, default=0) + +parsed_args = args.parse_args() + +filename = parsed_args.input +min_value = parsed_args.min +max_value = parsed_args.max - 1 +#max_addr = '0x10378' # address of the last M5Op + +def filter_exec_in_file(filename, match_string): + result = [] + prev_line = "prev line" + + with open(filename, 'r') as file: + lines = file.readlines() + + cur_stream = 0 + index_dict = dict() + + for line in lines: + parts = line.strip().split(' ') + + if len(parts) >= 3: + try: + num = int(int(parts[0][:-1])/1000) + first_string = parts[1][:-1] + + addr = "temp addr" + + if min_value <= num <= max_value and first_string == match_string: + idx = 0 + +# if parts[2] == "Pushing": # Pushing mem inst: %s pc: addr (inst) +# result.append(f"{num}|execute I: {parts[-2]} {parts[-1]} this is a memory ref. instruction.") + if parts[2] == "Issuing": + if parts[3] == "inst:": # Issuing inst: %s pc: addr (inst) into FU %d + stream = int(parts[-7].split('.')[0].split('/')[-1]) + index = int(parts[-7].split('.')[-1]) + addr = parts[-5] + inst = parts[-4] + +# if addr > max_addr: +# continue + + if (cur_stream != stream): + index_dict.clear() + + if addr in index_dict: + idx = index - index_dict[addr] + else: + index_dict[addr] = index + + result.append(f"{num}|2_execute I: {addr} {inst} {idx}") + + cur_stream = stream + elif parts[3] == "mem": # Issuing mem ref early inst: %s pc: addr (inst) instToWaitFor: %d + stream = int(parts[-6].split('.')[0].split('/')[-1]) + index = int(parts[-6].split('.')[-1]) + addr = parts[-4] + inst = parts[-3] + +# if addr > max_addr: +# continue + + if (cur_stream != stream): + index_dict.clear() + + if addr in index_dict: + idx = index - index_dict[addr] + else: + index_dict[addr] = index + + result.append(f"{num}|1_execute M: {addr} {inst} {idx}") + + cur_stream = stream + else : # Issuing %s to %d # Prev : Trying to issue inst: %s pc: addr (inst) to FU %d + prev_parts = prev_line.strip().split(' ') + + stream = int(prev_parts[-7].split('.')[0].split('/')[-1]) + index = int(prev_parts[-7].split('.')[-1]) + addr = prev_parts[-5] + inst = prev_parts[-4] + +# if addr > max_addr: +# continue + + if (cur_stream != stream): + index_dict.clear() + + if addr in index_dict: + idx = index - index_dict[addr] + else: + index_dict[addr] = index + + result.append(f"{num}|2_execute I: {addr} {inst} {idx}") + + cur_stream = stream +# elif parts[2] == "Discarding": # Discarding inst: %s pc: addr (inst) as its stream state was unexpected, expected: %d +# result.append(f"") + elif parts[2] == "Completed": # Completed inst: %s pc: addr (inst) + stream = int(parts[-4].split('.')[0].split('/')[-1]) + index = int(parts[-4].split('.')[-1]) + addr = parts[-2] + inst = parts[-1] + +# if addr > max_addr: +# continue + + if (cur_stream != stream): + continue + + if addr in index_dict: + idx = index - index_dict[addr] + else: + index_dict[addr] = index + + result.append(f"{num}|0_execute C: {addr} {inst} {idx}") + + cur_stream = stream + else: + prev_line = line + continue + except ValueError: + continue + prev_line = line + + return result + +def filter_decode_in_file(filename, match_string): + result = [] + + with open(filename, 'r') as file: + lines = file.readlines() + + for line in lines: + parts = line.strip().split(' ') + + if len(parts) >= 3: + try: + num = int(int(parts[0][:-1])/1000) + first_string = parts[1][:-1] + + if min_value <= num <= max_value and first_string == match_string: + if parts[2] == "Microop": + inst = parts[-1] + addr = parts[-2] + + if inst == '(vnop)': + continue + +# if addr > max_addr: +# continue + + result.append(f"{num}|3_decode: {parts[-2]} {parts[-1]} {parts[-6][-5]}") + elif parts[2] == "Passing": + addr = parts[7] + +# if addr > max_addr: +# continue + + result.append(f"{num}|3_decode: {parts[7]} {parts[8]}") + else: + continue + except ValueError: + continue + + return result + +def filter_fetch2_in_file(filename, match_string): + result = [] + + with open(filename, 'r') as file: + lines = file.readlines() + + for line in lines: + parts = line.strip().split(' ') + + if len(parts) >= 3: + try: + num = int(int(parts[0][:-1])/1000) + first_string = parts[1][:-1] + + if min_value <= num <= max_value and first_string == match_string: + if parts[2] == "Instruction": # Instruction extracted from line ~ + addr = parts[-2] + +# if addr > max_addr: +# continue + + result.append(f"{num}|4_fetch2: {parts[-2]} {parts[-1]}") + else: + continue + except ValueError: + continue + + return result + +def filter_fetch1_in_file(filename, match_string): + result = [] + + with open(filename, 'r') as file: + lines = file.readlines() + + temp = "start_addr" + + for line in lines: + parts = line.strip().split(' ') + + if len(parts) >= 3: + try: + num = int(int(parts[0][:-1])/1000) + first_string = parts[1][:-1] + + if min_value <= num <= max_value and first_string == match_string: + if parts[2] == "Inserting": + addr = parts[-7] + temp = addr + +# if addr > max_addr: +# continue + + result.append(f"{num}|5_fetch1: {addr} ~ ") + elif parts[2] == "Processing": +# if temp > max_addr: +# continue + + result.append(f"{num}|5_fetch1: {temp} ~ ") + else: + continue + except ValueError: + continue + + return result + + + +filtered_fetch1 = filter_fetch1_in_file(filename, 'system.cpu.fetch1') +filtered_fetch2 = filter_fetch2_in_file(filename, 'system.cpu.fetch2') +filtered_decode = filter_decode_in_file(filename, 'system.cpu.decode') +filtered_exec = filter_exec_in_file(filename, 'system.cpu.execute') + +for line in filtered_exec: + print(line) + +for line in filtered_decode: + print(line) + +for line in filtered_fetch2: + print(line) + +for line in filtered_fetch1: + print(line) diff --git a/experiments/BERT.py b/experiments/BERT.py new file mode 100644 index 00000000..3534505d --- /dev/null +++ b/experiments/BERT.py @@ -0,0 +1,57 @@ +import torch +import torch._dynamo +import torch.utils.cpp_extension + +import argparse +import datetime + +def run_BERT(size, input_seq, config): + from Scheduler.scheduler import Scheduler, SchedulerDNNModel, Request + # from tests.test_transformer import EncoderBlock + from tests.Fusion.test_transformer_fusion import EncoderBlock + scheduler = Scheduler(num_request_queue=1, engine_select=Scheduler.FIFO_ENGINE, backend_config=config) + device = scheduler.execution_engine.module.custom_device() + + hidden_dim = {'base': 768, 'large': 1024, 'xlarge': 2048} + embedding_size = {'base': 768, 'large': 1024, 'xlarge': 2048} + heads = {'base': 12, 'large': 16, 'xlarge': 32} # hidden/64 https://arxiv.org/pdf/1909.11942 + cpu_query = torch.randn(input_seq, hidden_dim[size]) + encoder_block = EncoderBlock(embedding_size[size], heads[size]).eval() + + query = cpu_query.clone().to(device=device) + opt_fn = torch.compile(dynamic=False)(encoder_block.to(device=device)) + + SchedulerDNNModel.register_model(f"BERT-{size}", opt_fn) + request = Request(f"BERT-{size}", [query], [], request_queue_idx=0) + scheduler.add_request(request, request_time=0) + + # Run scheduler + while not scheduler.is_finished(): + with torch.no_grad(): + scheduler.schedule() + + print(f"BERT-{size} Simulation Done") + +if __name__ == "__main__": + import os + import sys + base_dir = os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim') + config = os.environ.get('TORCHSIM_CONFIG', default=f'{base_dir}/PyTorchSimBackend/configs/systolic_ws_128x128_c1_simple_noc_tpuv3.json') + config_prefix = config.split('/')[-1].split('.')[0][9:] # extract config name from config path FIXME: gem5 result is different as directoy name + sys.path.append(base_dir) + args = argparse.ArgumentParser() + args.add_argument('--size', type=str, default='base') + args.add_argument('--dump_path', type=str, default='results') + args.add_argument('--input_size', type=int, default=512) + args = args.parse_args() + size = args.size + input_seq = args.input_size + result_path = os.path.join(base_dir, args.dump_path, config_prefix, f"BERT_{size}_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}") + # setting environment variables + os.environ['TORCHSIM_DUMP_PATH'] = result_path + # only timing simulation + os.environ['TORCHSIM_VALIDATION_MODE'] = "0" + if 'BACKENDSIM_SPIKE_ONLY' in os.environ: + del os.environ['BACKENDSIM_SPIKE_ONLY'] + + run_BERT(size, input_seq, config) diff --git a/experiments/artifact/baseline_cycle.csv b/experiments/artifact/baseline_cycle.csv new file mode 100644 index 00000000..afd795e4 --- /dev/null +++ b/experiments/artifact/baseline_cycle.csv @@ -0,0 +1,24 @@ +Workload,TPUv3,mNPUSim,Timeloop,Maestro,SCALE-Sim v3,TOGSim(Ours) +gemm_256x256x256,3101,2426,512,522,1020, +gemm_512x512x512,9490,8584,4096,4162,6128, +gemm_1024x1024x1024,47099,32625,32768,33282,40896, +gemm_2048x2048x2048,317435,3046069,262144,266242,294656, +conv_64x56x56x64x64x3x1x1,1496802,1076160,451584,1683074,610880, +conv_64x28x28x128x128x3x1x1,241935,577408,225792,391042,269952, +conv_64x14x14x256x256x3x1x1,246790,540160,451584,167426,327424, +conv_64x7x7x512x512x3x1x1,247383,1128192,903168,117762,622336, +layernorm_512x768,24895,,,,, +layernorm_2048x768,98234,,,,, +layernorm_8192x768,389863,,,,, +softmax_512x512,12902,,,,, +softmax_2048x2048,169750,,,,, +softmax_8192x8192,2700994,,,,, +attention_12x512x64,110093,101964,30720,49944,48912, +attention_16x512x64,145425,135952,40960,66592,65216, +attention_32x512x64,288250,271904,196608,133184,130432, +resnet18,844524,232699,495520,1518676,206111, +resnet50,1094428,524398,721016,1596151,398447, +bert_base,305077,493756,169984,149820,706886, +bert_large,445317,785205,286720,249676,237056, +bert_xlarge,1283925,3030001,1507328,898700,990016, + diff --git a/experiments/artifact/baseline_latency.csv b/experiments/artifact/baseline_latency.csv new file mode 100644 index 00000000..159b10a3 --- /dev/null +++ b/experiments/artifact/baseline_latency.csv @@ -0,0 +1,11 @@ +Workload,Accel-Sim,mNPUSim,PyTorchSim-SN,PyTorchSim-CA,PyTorchSim-ILS +conv_1x56x56x64x64x3x1x1,38.86346,3.915,,, +conv_1x28x28x128x128x3x1x1,44.58898,2.588,,, +conv_1x14x14x256x256x3x1x1,70.53904,3.162,,, +conv_1x7x7x512x512x3x1x1,78.92694,1.527,,, +gemm_512x512x512,53.5097,5.767,,, +gemm_1024x1024x1024,150.8592,40.946,,, +gemm_2048x2048x2048,951.4306,5157.396,,, +resnet50,1222.504,294.242,,, +bert_large,1436.558,350.84,,, + diff --git a/experiments/artifact/cycle_validation/run_cycle.sh b/experiments/artifact/cycle_validation/run_cycle.sh new file mode 100755 index 00000000..a32cd0a6 --- /dev/null +++ b/experiments/artifact/cycle_validation/run_cycle.sh @@ -0,0 +1,85 @@ +#!/bin/bash +set -e + +export TORCHSIM_CONFIG=$TORCHSIM_DIR/PyTorchSimBackend/configs/systolic_ws_128x128_c1_simple_noc_tpuv3.json +LOG_DIR=$TORCHSIM_DIR/experiments/artifact/logs +mkdir -p $LOG_DIR + +# Matmul +for sz in "256 256 256" "512 512 512" "1024 1024 1024" "2048 2048 2048"; do + name="gemm_${sz// /x}" + echo "" + echo "===================================================" + echo "[*] Running Matmul size=$sz" + echo "===================================================" + python3 $TORCHSIM_DIR/experiments/gemm.py --size $sz | tee $LOG_DIR/${name}.log +done + +# Conv +for sz in \ + "1 56 56 64 64 3 1 1" \ + "1 28 28 128 128 3 1 1" \ + "1 14 14 256 256 3 1 1" \ + "1 7 7 512 512 3 1 1" \ + "64 56 56 64 64 3 1 1" \ + "64 28 28 128 128 3 1 1" \ + "64 14 14 256 256 3 1 1" \ + "64 7 7 512 512 3 1 1"; do + name="conv_${sz// /x}" + echo "" + echo "===================================================" + echo "[*] Running Conv size=$sz" + echo "===================================================" + python3 $TORCHSIM_DIR/experiments/conv.py --size $sz | tee $LOG_DIR/${name}.log +done + +# Attention +for sz in "12 512 64" "16 512 64" "32 512 64"; do + name="attention_${sz// /x}" + echo "" + echo "===================================================" + echo "[*] Running Attention size=$sz" + echo "===================================================" + python3 $TORCHSIM_DIR/experiments/attention.py --size $sz | tee $LOG_DIR/${name}.log +done + +# LayerNorm +for sz in "512 768" "2048 768" "8192 768"; do + name="layernorm_${sz// /x}" + echo "" + echo "===================================================" + echo "[*] Running LayerNorm size=$sz" + echo "===================================================" + python3 $TORCHSIM_DIR/experiments/layernorm.py --size $sz | tee $LOG_DIR/${name}.log +done + +# Softmax +for sz in "512 512" "2048 2048" "8192 8192"; do + name="softmax_${sz// /x}" + echo "" + echo "===================================================" + echo "[*] Running Softmax size=$sz" + echo "===================================================" + python3 $TORCHSIM_DIR/experiments/softmax.py --size $sz | tee $LOG_DIR/${name}.log +done + +# ResNet +for model in "resnet18" "resnet50"; do + echo "" + echo "===================================================" + echo "[*] Running $model" + echo "===================================================" + python3 $TORCHSIM_DIR/experiments/${model}.py | tee $LOG_DIR/${model}.log +done + +# BERT +for model in "base" "large" "xlarge"; do + echo "" + echo "===================================================" + echo "[*] Running BERT size=$model" + echo "===================================================" + python3 $TORCHSIM_DIR/experiments/BERT.py --size $model | tee $LOG_DIR/bert_${model}.log +done + +# Cycle Summary +python3 $TORCHSIM_DIR/experiments/artifact/cycle_validation/summary_cycle.py | tee "$TORCHSIM_DIR/experiments/artifact/cycle_validation/summary_cycle.out" \ No newline at end of file diff --git a/experiments/artifact/cycle_validation/summary_cycle.py b/experiments/artifact/cycle_validation/summary_cycle.py new file mode 100644 index 00000000..529d0161 --- /dev/null +++ b/experiments/artifact/cycle_validation/summary_cycle.py @@ -0,0 +1,143 @@ +import os +import math +import csv +import re +import matplotlib.pyplot as plt +import numpy as np + +TORCHSIM_DIR = os.environ.get("TORCHSIM_DIR", ".") +LOG_DIR = os.path.join(TORCHSIM_DIR, "experiments/artifact/logs") +BASELINE_CSV = os.path.join(TORCHSIM_DIR, "experiments/artifact/baseline_cycle.csv") + +def plot_error_bars(data: dict, filename: str): + colors = { + 'SCALE-Sim v3': 'gold', + 'mNPUSim': 'orange', + 'Timeloop': 'green', + 'Maestro': 'violet', + 'PyTorchSim-SN': 'royalblue', + } + + labels = list(data.keys()) + num_sims = len(colors) + bar_width = 1 + fig, ax = plt.subplots(figsize=(48, 8)) + + grouped_data = {sim: [[], []] for sim in colors} + x_pos = [] + x_offset = 0 + + for key, value in data.items(): + for i, (sim, color) in enumerate(colors.items()): + grouped_data[sim][0].append(value[i]) + grouped_data[sim][1].append(x_offset + bar_width * i) + x_pos.append(x_offset + bar_width * (num_sims // 2)) + x_offset += bar_width * (num_sims + 2) + + for sim, (heights, xpos) in grouped_data.items(): + bars = ax.bar(xpos, heights, width=bar_width, color=colors[sim], label=sim, edgecolor='black') + mae_val = heights[-1] + ax.text( + xpos[-1], + mae_val + 2 if mae_val >= 0 else mae_val - 6, + f'{mae_val:.1f}%', + ha='center', + va='bottom' if mae_val >= 0 else 'top', + fontsize=9, + rotation=90 + ) + + ax.set_xticks(x_pos) + ax.set_xticklabels(labels, rotation=20, ha='right') + ax.set_ylim(-100, 150) + ax.set_yticks(np.arange(-100, 151, 50)) + ax.yaxis.grid(True, linestyle='--', linewidth=0.5, alpha=0.7) + ax.legend() + + plt.savefig(filename) + plt.close() + print(f"Saved plot to {filename}") + +def format_with_error(value, ref, error_list=None): + try: + if value == "" or ref == "" or float(ref) == 0: + return "N/A", 0.0 + val = float(value) + ref = float(ref) + err = ((val - ref) / ref) * 100 + if error_list is not None: + error_list.append(abs(err)) + val_str = f"{int(val):>7}" + err_str = f"{err:+.2f}%" + return f"{val_str} ({err_str:>8})", err + except (ValueError, TypeError): + return "N/A", 0.0 + +def compute_mae(errors): + if not errors: + return "N/A" + abs_errors = [abs(err) for err in errors] + return sum(abs_errors) / len(errors) + +if __name__ == "__main__": + # 1. Generate cycle_map + cycle_map = {} + for file in os.listdir(LOG_DIR): + if file.endswith(".log"): + full_path = os.path.join(LOG_DIR, file) + name = file[:-4] + with open(full_path, errors="ignore") as f: + for line in f: + match = re.search(r"Total execution cycle:\s*([0-9]+)", line) + if match: + cycle_map[name] = int(match.group(1)) + break + + # Error list init + mnpusim_errors = [] + timeloop_errors = [] + maestro_errors = [] + scalesim_errors = [] + togsim_errors = [] + + # Plot data + plot_data ={} + + # Header + print("[*] Summary of Total Execution Cycles with TPUv3-relative (%) Error") + print("=" * 190) + print(f"{'Workload':>30} {'TPUv3':>25} {'mNPUSim':>25} {'Timeloop':>25} {'Maestro':>25} {'SCALE-Sim v3':>25} {'TOGSim(Ours)':>25}") + print("=" * 190) + + with open(BASELINE_CSV, newline="") as csvfile: + reader = csv.DictReader(csvfile) + for row in reader: + workload = row["Workload"].lstrip('\ufeff') + tpv3 = row["TPUv3"] + + mnpusim, mnpusim_err = format_with_error(row["mNPUSim"], tpv3, mnpusim_errors) + timeloop, timeloop_err = format_with_error(row["Timeloop"], tpv3, timeloop_errors) + maestro, maestro_err = format_with_error(row["Maestro"], tpv3, maestro_errors) + scalesim, scalesim_err = format_with_error(row["SCALE-Sim v3"], tpv3, scalesim_errors) + + togsim_val = cycle_map.get(workload, "") + if "softmax" in workload or "layernorm" in workload: + togsim_str, togsim_err = format_with_error(str(togsim_val), tpv3, []) + else: + togsim_str, togsim_err = format_with_error(str(togsim_val), tpv3, togsim_errors) + plot_data[workload] = [scalesim_err, mnpusim_err, timeloop_err, maestro_err, togsim_err] + print(f"{workload:>30} {tpv3:>25} {mnpusim:>25} {timeloop:>25} {maestro:>25} {scalesim:>25} {togsim_str:>25}") + + # MAE row + mae_mnpusim = compute_mae(mnpusim_errors) + mae_timeloop = compute_mae(timeloop_errors) + mae_maestro = compute_mae(maestro_errors) + mae_scalesim = compute_mae(scalesim_errors) + mae_togsim = compute_mae(togsim_errors) + plot_data["MAE"] = [mae_scalesim, mae_mnpusim, mae_timeloop, mae_maestro, mae_togsim] + print("=" * 190) + print(f"{'[*] Mean Absolute Error(%)':>30} {'0.00%':>25} {mae_mnpusim:>24.2f}% {mae_timeloop:>24.2f}% {mae_maestro:>24.2f}% {mae_scalesim:>24.2f}% {mae_togsim:>24.2f}%") + + # Plot the error bars + path = os.path.join(TORCHSIM_DIR, "experiments/artifact/cycle_validation/cycle_validation.png") + plot_error_bars(plot_data, path) diff --git a/experiments/artifact/speedup/run_speedup.sh b/experiments/artifact/speedup/run_speedup.sh new file mode 100755 index 00000000..7d0c0da2 --- /dev/null +++ b/experiments/artifact/speedup/run_speedup.sh @@ -0,0 +1,102 @@ +#!/bin/bash +LOG_DIR=$TORCHSIM_DIR/experiments/artifact/logs +CONFIG_DIR="$TORCHSIM_DIR/PyTorchSimBackend/configs" +SIMULATOR_BIN="$TORCHSIM_DIR/PyTorchSimBackend/build/bin/Simulator" + +configs=( + "systolic_ws_128x128_c2_simple_noc_tpuv3.json" + "systolic_ws_128x128_c2_booksim_tpuv3.json" +) + +target_list=( + "gemm_512x512x512" + "gemm_1024x1024x1024" + "gemm_2048x2048x2048" + "conv_1x56x56x64x64x3x1x1" + "conv_1x28x28x128x128x3x1x1" + "conv_1x14x14x256x256x3x1x1" + "conv_1x7x7x512x512x3x1x1" + "resnet50" + "bert_large" +) + +TIMESTAMP=$(date +"%Y-%m-%d_%H-%M-%S") +output_dir="$TORCHSIM_DIR/experiments/artifact/speedup/results" +mkdir -p "$output_dir" + +echo "[*] Scanning log files in: $LOG_DIR" +echo "" + +for log_file in "$LOG_DIR"/*.log; do + filename=$(basename "$log_file") + workload="${filename%.log}" + + if [[ ! " ${target_list[@]} " =~ " ${workload} " ]]; then + continue + fi + echo "==> Workload: $workload" + + declare -a ONNX_ATTR_PAIRS=() + + # === Grep launch line === + while IFS= read -r line; do + if [[ "$line" == launch* ]]; then + read -r _ onnx_path attr_path _ <<< "$line" + ONNX_ATTR_PAIRS+=("$onnx_path|$attr_path") + fi + done < "$log_file" + + # Normal configs + for config in "${configs[@]}"; do + output_file="$output_dir/${workload}_${config}.txt" + echo "Running with config=$config" + echo "===== config=$config | model=$workload =====" >> "$output_file" + sum_all_iters=0.0 + iter_count=0 + + # === Run 5 iterations === + for iter in {1..5}; do + echo "[Iter $iter] Running simulation for workload=$workload config=$config" + cmd="" + for pair in "${ONNX_ATTR_PAIRS[@]}"; do + IFS="|" read -r onnx_path attr_path <<< "$pair" + cmd+=" $SIMULATOR_BIN --config $CONFIG_DIR/$config --models_list $onnx_path --attributes_list $attr_path;" + done + + output=$(bash -c "$cmd") + sim_times=$(echo "$output" | grep "Simulation time:" | sed -E 's/.*Simulation time: ([0-9]+\.[0-9]+).*/\1/') + + if [[ -n "$sim_times" ]]; then + sum_per_iter=0.0 + while IFS= read -r sim_time; do + echo "Iteration $iter: simulation_time = $sim_time" >> "$output_file" + sum_per_iter=$(awk -v a="$sum_per_iter" -v b="$sim_time" 'BEGIN {printf "%.6f", a + b}') + done <<< "$sim_times" + + echo "Iteration $iter: total_simulation_time = $sum_per_iter" >> "$output_file" + sum_all_iters=$(awk -v a="$sum_all_iters" -v b="$sum_per_iter" 'BEGIN {printf "%.6f", a + b}') + iter_count=$((iter_count + 1)) + else + echo "Iteration $iter: No simulation time found." >> "$output_file" + fi + done + + # === Final average === + if [[ $iter_count -gt 0 ]]; then + avg=$(awk -v total="$sum_all_iters" -v n="$iter_count" 'BEGIN {printf "%.6f", total / n}') + echo "Average simulation time for $workload with config $config: $avg seconds" + echo "Average simulation time = $avg" >> "$output_file" + else + echo "No valid simulation times found for config $config" + echo "Average simulation time = NA" >> "$output_file" + fi + done +done + +# ILS mode should be run separately +$TORCHSIM_DIR/experiments/artifact/speedup/scripts/run_speed_ils_matmul.sh +$TORCHSIM_DIR/experiments/artifact/speedup/scripts/run_speed_ils_conv.sh +$TORCHSIM_DIR/experiments/artifact/speedup/scripts/run_speed_ils_bert.sh +$TORCHSIM_DIR/experiments/artifact/speedup/scripts/run_speed_ils_resnet.sh + +python3 $TORCHSIM_DIR/experiments/artifact/speedup/summary_speedup.py | tee "$TORCHSIM_DIR/experiments/artifact/speedup/summary_speedup.log" \ No newline at end of file diff --git a/experiments/artifact/speedup/scripts/ils_parser.sh b/experiments/artifact/speedup/scripts/ils_parser.sh new file mode 100755 index 00000000..913daeea --- /dev/null +++ b/experiments/artifact/speedup/scripts/ils_parser.sh @@ -0,0 +1,34 @@ +#!/bin/bash + +ignore_rest=false +gem5_cmd="" +result_path="" +gem5_time="" +togsim_time="" + +total_gem5=0 +total_togsim=0 + +while IFS= read -r line; do + if [[ "$line" == launch* ]]; then + tile_graph_path=$(echo "$line" | awk '{for (i=1; i<=NF; i++) if ($i ~ /tile_graph\.onnx$/) print $i}') + if [[ -n "$tile_graph_path" ]]; then + dir_path=$(dirname "$tile_graph_path") + sto_log_path="$dir_path/m5out/sto.log" + echo "sto.log path: $sto_log_path" + gem5_time=$(grep "Simulation time:" "$sto_log_path" | \ + sed -E 's/^Simulation time: ([0-9.]+) seconds$/\1/') + echo "GEM5: $gem5_time" + total_gem5=$(awk -v a="$total_gem5" -v b="$gem5_time" 'BEGIN {printf "%.6f", a+b}') + fi + fi + if [[ "$line" == *"Simulation time:"* ]]; then + togsim_time=$(echo "$line" | sed -E 's/.*Simulation time: ([0-9.]+) seconds/\1/') + echo "TOGSim: $togsim_time" + fi +done + +if [[ -n "$total_gem5" && -n "$total_togsim" ]]; then + total_time=$(python3 -c "print(round($total_gem5 + $total_togsim, 6))") + echo "Simulation time: $total_time seconds" +fi \ No newline at end of file diff --git a/experiments/artifact/speedup/scripts/run_speed_ils_bert.sh b/experiments/artifact/speedup/scripts/run_speed_ils_bert.sh new file mode 100755 index 00000000..66829f02 --- /dev/null +++ b/experiments/artifact/speedup/scripts/run_speed_ils_bert.sh @@ -0,0 +1,65 @@ +#!/bin/bash + +base_dir=$TORCHSIM_DIR/experiments/artifact/speedup +config=( + # "systolic_ws_8x8_c1_simple_noc.json" + "systolic_ws_128x128_c2_simple_noc_tpuv3.json" + #"systolic_ws_128x128_c2_booksim_tpuv3.json" + # "systolic_ws_128x128_c2_simple_noc_tpuv4.json" +) +TIMESTAMP=$(date +"%Y-%m-%d_%H-%M-%S") +SIZE_LIST=( + # "base" + "large" + # "xlarge" +) +seq=512 +output_dir="$base_dir/results" +mkdir -p "$output_dir" + +for i in "${config[@]}"; do + echo "Running with config=$i" + for size in "${SIZE_LIST[@]}"; do + ops="bert_$size" + output_file="$output_dir/ils_${ops}_${i}.txt" + workload="$TORCHSIM_DIR/experiments/BERT.py --size $size --input_size $seq" + echo "===== config=$i | model=$ops =====" >> "$output_file" + sum=0.0 + count=0 + config_path="$TORCHSIM_DIR/PyTorchSimBackend/configs/$i" + + for iter in {1..5}; do + echo "[Iter $iter] Running simulation for workload=ils_$ops config=$config" + output=$(bash -c " + export TORCHSIM_TLS_MODE=0; + export TORCHSIM_VALIDATION_MODE=0; + export TORCHSIM_CONFIG=$config_path; + export AUTOTUNE=0; + printenv; + python3 $workload 2> /dev/null | $TORCHSIM_DIR/experiments/artifact/speedup/scripts/ils_parser.sh + ") + + sim_time=$(echo "$output" | grep "Simulation time:" | tail -n 1 | sed -E 's/.*Simulation time: ([0-9]+\.[0-9]+).*/\1/') + + if [[ -n "$sim_time" ]]; then + echo "Iteration $iter: Simulation time = $sim_time" + echo "Iteration $iter: simulation_time = $sim_time" >> "$output_file" + sum=$(awk -v a="$sum" -v b="$sim_time" 'BEGIN {printf "%.6f", a + b}') + count=$((count + 1)) + else + echo "Iteration $iter: Simulation time not found." + echo "Iteration $iter: simulation_time = NA" >> "$output_file" + fi + done + + if [[ $count -gt 0 ]]; then + avg=$(awk -v total="$sum" -v n="$count" 'BEGIN {printf "%.6f", total / n}') + echo "Average simulation time for $ops with config $i: $avg seconds" + echo "Average simulation time = $avg" >> "$output_file" + else + echo "No valid simulation times found for $ops with config $i" + echo "Average simulation time = NA" >> "$output_file" + fi + echo "" >> "$output_file" + done +done \ No newline at end of file diff --git a/experiments/artifact/speedup/scripts/run_speed_ils_conv.sh b/experiments/artifact/speedup/scripts/run_speed_ils_conv.sh new file mode 100755 index 00000000..2f9718f1 --- /dev/null +++ b/experiments/artifact/speedup/scripts/run_speed_ils_conv.sh @@ -0,0 +1,66 @@ +#!/bin/bash + +base_dir=$TORCHSIM_DIR/experiments/artifact/speedup +config=( + # "systolic_ws_8x8_c1_simple_noc.json" + "systolic_ws_128x128_c2_simple_noc_tpuv3.json" + #"systolic_ws_128x128_c2_booksim_tpuv3.json" + # "systolic_ws_128x128_c2_simple_noc_tpuv4.json" +) +TIMESTAMP=$(date +"%Y-%m-%d_%H-%M-%S") +SHAPE_LIST=( + #B H W I_C O_C K S P + "1 56 56 64 64 3 1 1" + "1 28 28 128 128 3 1 1" + "1 14 14 256 256 3 1 1" + "1 7 7 512 512 3 1 1" +) +output_dir="$base_dir/results" +mkdir -p "$output_dir" + +for i in "${config[@]}"; do + echo "Running with config=$i" + for shape in "${SHAPE_LIST[@]}"; do + ops="conv_${shape// /x}" + output_file="$output_dir/ils_${ops}_${i}.txt" + workload="$TORCHSIM_DIR/experiments/conv.py --size $shape" + echo "===== config=$i | model=$ops =====" >> "$output_file" + sum=0.0 + count=0 + config_path="$TORCHSIM_DIR/PyTorchSimBackend/configs/$i" + + for iter in {1..5}; do + echo "[Iter $iter] Running simulation for workload=ils_$ops config=$config" + output=$(bash -c " + export TORCHSIM_TLS_MODE=0; + export TORCHSIM_VALIDATION_MODE=0; + export TORCHSIM_CONFIG=$config_path; + export AUTOTUNE=0; + printenv; + python3 $workload 2> /dev/null | $TORCHSIM_DIR/experiments/artifact/speedup/scripts/ils_parser.sh + ") + + sim_time=$(echo "$output" | grep "Simulation time:" | tail -n 1 | sed -E 's/.*Simulation time: ([0-9]+\.[0-9]+).*/\1/') + + if [[ -n "$sim_time" ]]; then + echo "Iteration $iter: Simulation time = $sim_time" + echo "Iteration $iter: simulation_time = $sim_time" >> "$output_file" + sum=$(awk -v a="$sum" -v b="$sim_time" 'BEGIN {printf "%.6f", a + b}') + count=$((count + 1)) + else + echo "Iteration $iter: Simulation time not found." + echo "Iteration $iter: simulation_time = NA" >> "$output_file" + fi + done + + if [[ $count -gt 0 ]]; then + avg=$(awk -v total="$sum" -v n="$count" 'BEGIN {printf "%.6f", total / n}') + echo "Average simulation time for $ops with config $i: $avg seconds" + echo "Average simulation time = $avg" >> "$output_file" + else + echo "No valid simulation times found for $ops with config $i" + echo "Average simulation time = NA" >> "$output_file" + fi + echo "" >> "$output_file" + done +done \ No newline at end of file diff --git a/experiments/artifact/speedup/scripts/run_speed_ils_matmul.sh b/experiments/artifact/speedup/scripts/run_speed_ils_matmul.sh new file mode 100755 index 00000000..8ff7e2b6 --- /dev/null +++ b/experiments/artifact/speedup/scripts/run_speed_ils_matmul.sh @@ -0,0 +1,63 @@ +#!/bin/bash + +base_dir=$TORCHSIM_DIR/experiments/artifact/speedup +config=( + # "systolic_ws_8x8_c1_simple_noc.json" + "systolic_ws_128x128_c2_simple_noc_tpuv3.json" + #"systolic_ws_128x128_c2_booksim_tpuv3.json" + # "systolic_ws_128x128_c2_simple_noc_tpuv4.json" +) +TIMESTAMP=$(date +"%Y-%m-%d_%H-%M-%S") +SHAPE_LIST=( + "512 512 512" + "1024 1024 1024" + "2048 2048 2048" +) +output_dir="$base_dir/results" +mkdir -p "$output_dir" + +for i in "${config[@]}"; do + echo "Running with config=$i" + for shape in "${SHAPE_LIST[@]}"; do + ops="gemm_${shape// /x}" + output_file="$output_dir/ils_${ops}_${i}.txt" + workload="$TORCHSIM_DIR/experiments/gemm.py --size $shape" + echo "===== config=$i | model=$ops =====" >> "$output_file" + sum=0.0 + count=0 + config_path="$TORCHSIM_DIR/PyTorchSimBackend/configs/$i" + + for iter in {1..5}; do + echo "[Iter $iter] Running simulation for workload=ils_$ops config=$config" + output=$(bash -c " + export TORCHSIM_TLS_MODE=0; + export TORCHSIM_VALIDATION_MODE=1; + export TORCHSIM_CONFIG=$config_path; + export AUTOTUNE=0; + printenv; + python3 $workload 2> /dev/null | $TORCHSIM_DIR/experiments/artifact/speedup/scripts/ils_parser.sh + ") + + sim_time=$(echo "$output" | grep "Simulation time:" | tail -n 1 | sed -E 's/.*Simulation time: ([0-9]+\.[0-9]+).*/\1/') + + if [[ -n "$sim_time" ]]; then + echo "Iteration $iter: simulation_time = $sim_time" >> "$output_file" + sum=$(awk -v a="$sum" -v b="$sim_time" 'BEGIN {printf "%.6f", a + b}') + count=$((count + 1)) + else + echo "Iteration $iter: Simulation time not found." + echo "Iteration $iter: simulation_time = NA" >> "$output_file" + fi + done + + if [[ $count -gt 0 ]]; then + avg=$(awk -v total="$sum" -v n="$count" 'BEGIN {printf "%.6f", total / n}') + echo "Average simulation time for $ops with config $i: $avg seconds" + echo "Average simulation time = $avg" >> "$output_file" + else + echo "No valid simulation times found for $ops with config $i" + echo "Average simulation time = NA" >> "$output_file" + fi + echo "" >> "$output_file" + done +done \ No newline at end of file diff --git a/experiments/artifact/speedup/scripts/run_speed_ils_resnet.sh b/experiments/artifact/speedup/scripts/run_speed_ils_resnet.sh new file mode 100755 index 00000000..aa35735c --- /dev/null +++ b/experiments/artifact/speedup/scripts/run_speed_ils_resnet.sh @@ -0,0 +1,73 @@ +#!/bin/bash + +base_dir=$TORCHSIM_DIR/experiments/artifact/speedup +config=( + # "systolic_ws_8x8_c1_simple_noc.json" + "systolic_ws_128x128_c2_simple_noc_tpuv3.json" + #"systolic_ws_128x128_c2_booksim_tpuv3.json" + # "systolic_ws_128x128_c2_simple_noc_tpuv4.json" +) +TIMESTAMP=$(date +"%Y-%m-%d_%H-%M-%S") +SIZE_LIST=( + # "18" + "50" +) +BATCH_LIST=( + "1" + # "8" + # "16" + # "32" + # "64" + # "128" +) +output_dir="$base_dir/results" +mkdir -p "$output_dir" + +for i in "${config[@]}"; do + echo "Running with config=$i" + for size in "${SIZE_LIST[@]}"; do + for batch in "${BATCH_LIST[@]}"; do + ops="resnet$size" + output_file="$output_dir/ils_${ops}_${i}.txt" + workload="$TORCHSIM_DIR/experiments/resnet$size.py --batch $batch" + echo "===== config=$i | model=$ops =====" >> "$output_file" + sum=0.0 + count=0 + config_path="$TORCHSIM_DIR/PyTorchSimBackend/configs/$i" + + for iter in {1..5}; do + echo "[Iter $iter] Running simulation for workload=ils_$ops config=$config" + output=$(bash -c " + export TORCHSIM_TLS_MODE=0; + export TORCHSIM_VALIDATION_MODE=0; + export TORCHSIM_CONFIG=$config_path; + export AUTOTUNE=0; + printenv; + python3 $workload 2> /dev/null | $TORCHSIM_DIR/experiments/artifact/speedup/scripts/ils_parser.sh + ") + + sim_time=$(echo "$output" | grep "Simulation time:" | tail -n 1 | sed -E 's/.*Simulation time: ([0-9]+\.[0-9]+).*/\1/') + + if [[ -n "$sim_time" ]]; then + echo "Iteration $iter: Simulation time = $sim_time" + echo "Iteration $iter: simulation_time = $sim_time" >> "$output_file" + sum=$(awk -v a="$sum" -v b="$sim_time" 'BEGIN {printf "%.6f", a + b}') + count=$((count + 1)) + else + echo "Iteration $iter: Simulation time not found." + echo "Iteration $iter: simulation_time = NA" >> "$output_file" + fi + done + + if [[ $count -gt 0 ]]; then + avg=$(awk -v total="$sum" -v n="$count" 'BEGIN {printf "%.6f", total / n}') + echo "Average simulation time for $ops with config $i: $avg seconds" + echo "Average simulation time = $avg" >> "$output_file" + else + echo "No valid simulation times found for $ops with config $i" + echo "Average simulation time = NA" >> "$output_file" + fi + echo "" >> "$output_file" + done + done +done \ No newline at end of file diff --git a/experiments/artifact/speedup/summary_speedup.py b/experiments/artifact/speedup/summary_speedup.py new file mode 100644 index 00000000..67a741a0 --- /dev/null +++ b/experiments/artifact/speedup/summary_speedup.py @@ -0,0 +1,156 @@ +import os +import csv +import re +import matplotlib.pyplot as plt +import numpy as np + +TORCHSIM_DIR = os.environ.get("TORCHSIM_DIR", ".") +LOG_DIR = os.path.join(TORCHSIM_DIR, "experiments/artifact/speedup/results") +BASELINE_CSV = os.path.join(TORCHSIM_DIR, "experiments/artifact/baseline_latency.csv") + + +def plot_speedup_bars(data: dict, filename: str): + colors = { + 'Accel-Sim': '#A6A6A6', + 'mNPUSim': '#E97132', + 'PyTorchSim(ILS)-SN': '#4EA72E', + 'PyTorchSim-SN': '#0070C0', + 'PyTorchSim-CN': '#A6CAEC', + } + + labels = list(data.keys()) + num_sims = len(colors) + bar_width = 1 + fig, ax = plt.subplots(figsize=(48, 16)) + + grouped_data = {sim: [[], []] for sim in colors} + x_pos = [] + x_offset = 0 + + for key, value in data.items(): + for i, (sim, color) in enumerate(colors.items()): + grouped_data[sim][0].append(value[i]) + grouped_data[sim][1].append(x_offset + bar_width * i) + x_pos.append(x_offset + bar_width * (num_sims // 2)) + x_offset += bar_width * (num_sims + 2) + + for sim, (heights, xpos) in grouped_data.items(): + bars = ax.bar(xpos, heights, width=bar_width, color=colors[sim], label=sim, edgecolor='black') + mae_val = heights[-1] + ax.text( + xpos[-1], + mae_val + 2 if mae_val >= 0 else mae_val - 6, + f'{mae_val:.1f}x', + ha='center', + va='bottom' if mae_val >= 0 else 'top', + fontsize=9, + rotation=90 + ) + + ax.set_xticks(x_pos) + ax.set_xticklabels(labels, rotation=20, ha='right') + ax.set_yscale('log') + ax.set_ylim(0.1, 150) + ax.set_yticks([0.1, 1, 10, 100]) + ax.get_yaxis().set_major_formatter(plt.ScalarFormatter()) + ax.yaxis.grid(True, linestyle='--', linewidth=0.5, alpha=0.7) + ax.legend() + + plt.savefig(filename) + plt.close() + print(f"Saved plot to {filename}") + +def format_with_speedup(value, ref, speedup_list=None): + try: + if value == "" or ref == "" or float(value) == 0: + return "N/A", 0.0 + val = float(value) + ref = float(ref) + spd = ref / val + if speedup_list is not None: + speedup_list.append(spd) + val_str = f"{float(val):>7.3f}" + spd_str = f"{spd:.2f}×" + return f"{val_str} ({spd_str:>7})", spd + except (ValueError, TypeError): + return "N/A", 0.0 + +def compute_geomean(errors): + if not errors: + return "N/A" + filtered = [abs(e) for e in errors if e > 0] + if not filtered: + return "0.00x" + prod = 1.0 + for e in filtered: + prod *= e + geo = prod ** (1.0 / len(filtered)) + return geo + +if __name__ == "__main__": + # 1. Generate cycle_map + average_time_map = {} + for file in os.listdir(LOG_DIR): + if file.endswith(".txt"): + full_path = os.path.join(LOG_DIR, file) + full_name = file[:-4] + name = full_name.split("_systolic", 1)[0] + if "ils" in full_name: + name = name + elif "booksim" in full_name: + name = name +"cn" + elif "simple_noc" in full_name: + name = name +"sn" + else: + raise ValueError(f"Unsupported file name format: {file}") + with open(full_path, errors="ignore") as f: + for line in f: + match = re.search(r"Average simulation time\s*=\s*([0-9]+(?:\.[0-9]+)?)", line) + if match: + average_time_map[name] = float(match.group(1)) + break + + # Speedup list init + accelsim_speedup = [] + mnpusim_speedup = [] + torchsim_ils_sn_speedup = [] + torchsim_sn_speedup = [] + torchsim_cn_speedup = [] + + # Plot data + plot_data ={} + + # Header + print("[*] Summary of Latency (Seconds) and Speedup (vs Accel-Sim)") + print("=" * 165) + print(f"{'Workload':>30} {'Accel-Sim':>25} {'mNPUSim':>25} {'PyTorchSim(ILS)-SN':>25} {'PyTorchSim-SN':>25} {'PyTorchSim-CN':>25}") + print("=" * 165) + + with open(BASELINE_CSV, newline="") as csvfile: + reader = csv.DictReader(csvfile) + for row in reader: + workload = row["Workload"].lstrip('\ufeff') + accelsim = row["Accel-Sim"] + + mnpusim, mnpusim_spd = format_with_speedup(row["mNPUSim"], accelsim, mnpusim_speedup) + + togsim_ils_sn_val = average_time_map.get("ils_" + workload, "") + togsim_sn_val = average_time_map.get(workload+"sn", "") + togsim_cn_val = average_time_map.get(workload+"cn", "") + torchsim_ils_sn, ils_sn_spd = format_with_speedup(togsim_ils_sn_val, accelsim, torchsim_ils_sn_speedup) + torchsim_sn, sn_spd = format_with_speedup(togsim_sn_val, accelsim, torchsim_sn_speedup) + torchsim_cn, cn_spd = format_with_speedup(togsim_cn_val, accelsim, torchsim_cn_speedup) + plot_data[workload] = [1.0, mnpusim_spd, ils_sn_spd, sn_spd, cn_spd] + print(f"{workload:>30} {accelsim:>25} {mnpusim:>25} {torchsim_ils_sn:>25} {torchsim_sn:>25} {torchsim_cn:>25}") + + # MAE row + geomean_accelsim = 1.0 + geomean_mnpusim = compute_geomean(mnpusim_speedup) + geomean_torchsim_ils_sn = compute_geomean(torchsim_ils_sn_speedup) + geomean_torchsim_sn = compute_geomean(torchsim_sn_speedup) + geomean_torchsim_cn = compute_geomean(torchsim_cn_speedup) + plot_data["Geomean"] = [geomean_accelsim, geomean_mnpusim, geomean_torchsim_ils_sn, geomean_torchsim_sn, geomean_torchsim_cn] + print("=" * 165) + print(f"{'Geomean Speedup':>30} {'1x':>25} {geomean_mnpusim:>24.2f}x {geomean_torchsim_ils_sn:>24.2f}x {geomean_torchsim_sn:>24.2f}x {geomean_torchsim_cn:>24.2f}x") + path = os.path.join(TORCHSIM_DIR, "experiments/artifact/speedup/speedup.png") + plot_speedup_bars(plot_data, path) diff --git a/experiments/attention.py b/experiments/attention.py new file mode 100644 index 00000000..e8f89dac --- /dev/null +++ b/experiments/attention.py @@ -0,0 +1,56 @@ +import torch +import torch._dynamo +import torch.utils.cpp_extension + +import argparse +import datetime + + +def run_attention(size, config): + def attention(query, key, value): + import math + d_k = query.size(-1) + scores = torch.matmul(key, query.transpose(-2, -1)) / math.sqrt(d_k) + p_attn = scores.softmax(dim=-2) + return torch.matmul(value.transpose(-1, -2), p_attn) + from Scheduler.scheduler import Scheduler, SchedulerDNNModel, Request + scheduler = Scheduler(num_request_queue=1, engine_select=Scheduler.FIFO_ENGINE, backend_config=config) + device = scheduler.execution_engine.module.custom_device() + query = torch.randn(size).to(device=device) + key = torch.randn(size).to(device=device) + value = torch.randn(size).to(device=device) + opt_fn = torch.compile(dynamic=False)(attention) + + SchedulerDNNModel.register_model("attention", opt_fn) + request = Request("attention", [query, key, value], [], request_queue_idx=0) + scheduler.add_request(request, request_time=0) + + # Run scheduler + while not scheduler.is_finished(): + with torch.no_grad(): + scheduler.schedule() + + print(f"Attention {str(size)} Simulation Done") + +if __name__ == "__main__": + import os + import sys + base_dir = os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim') + config = os.environ.get('TORCHSIM_CONFIG', default=f'{base_dir}/PyTorchSimBackend/configs/systolic_ws_128x128_c1_simple_noc_tpuv3.json') + config_prefix = config.split('/')[-1].split('.')[0][9:] # extract config name from config path + sys.path.append(base_dir) + args = argparse.ArgumentParser() + args.add_argument('--size', nargs='+', type=int, default=[12, 512, 64], help='Tensor Shape') + args.add_argument('--dump_path', type=str, default='results') + args = args.parse_args() + size = args.size + size_str = "x".join([str(i) for i in size]) + result_path = os.path.join(base_dir, args.dump_path, config_prefix, f"attention_{size_str}_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}") + # setting environment variables + os.environ['TORCHSIM_DUMP_PATH'] = result_path + # only timing simulation + os.environ['TORCHSIM_VALIDATION_MODE'] = "0" + if 'BACKENDSIM_SPIKE_ONLY' in os.environ: + del os.environ['BACKENDSIM_SPIKE_ONLY'] + + run_attention(size, config) diff --git a/experiments/conv.py b/experiments/conv.py new file mode 100644 index 00000000..e8b97906 --- /dev/null +++ b/experiments/conv.py @@ -0,0 +1,57 @@ +import torch +import torch._dynamo +import torch.utils.cpp_extension + +import argparse +import datetime + + +def run_conv2d(batch_size, i_h, i_w, i_c, o_c, kernel_size, stride, padding, config): + from Scheduler.scheduler import Scheduler, SchedulerDNNModel, Request + def custom_conv2d(a, b, bias): + i_c = a.shape[1] + o_c = b.shape[0] + conv2d = torch.nn.Conv2d(i_c, o_c, b.shape[-1], stride=stride, padding=padding, dilation=1, bias=False) + conv2d.weight = torch.nn.Parameter(b) + # conv2d.bias = torch.nn.Parameter(bias) + return conv2d(a) + scheduler = Scheduler(num_request_queue=1, engine_select=Scheduler.FIFO_ENGINE, backend_config=config) + device = scheduler.execution_engine.module.custom_device() + conv_input = torch.randn(batch_size, i_c, i_h, i_w).to(memory_format=torch.channels_last, device=device) + conv_kernel = torch.randn(o_c, i_c, kernel_size, kernel_size).to(memory_format=torch.channels_last, device=device) + conv_bias = torch.randn(o_c).to(device=device) + opt_fn = torch.compile(dynamic=False)(custom_conv2d) + + SchedulerDNNModel.register_model("CONV", opt_fn) + request = Request("CONV", [conv_input, conv_kernel, conv_bias], [], request_queue_idx=0) + scheduler.add_request(request, request_time=0) + + # Run scheduler + while not scheduler.is_finished(): + with torch.no_grad(): + scheduler.schedule() + + print(f"CONV {batch_size}_{i_h}_{i_w}_{i_c}_{o_c}_{kernel_size}_{stride}_{padding} (B_H_W_I_C_O_C_K_S_P) Simulation Done") + +if __name__ == "__main__": + import os + import sys + base_dir = os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim') + config = os.environ.get('TORCHSIM_CONFIG', default=f'{base_dir}/PyTorchSimBackend/configs/systolic_ws_128x128_c2_simple_noc_tpuv4.json') + config_prefix = config.split('/')[-1].split('.')[0][9:] # extract config name from config path + sys.path.append(base_dir) + args = argparse.ArgumentParser() + args.add_argument('--size', nargs='+', type=int, default=[8, 28, 28, 128, 128, 3, 1, 1], help='B H W I_C O_C K S P') + args.add_argument('--dump_path', type=str, default='results') + args = args.parse_args() + size = args.size + size_str = "_".join([str(i) for i in size]) + result_path = os.path.join(base_dir, args.dump_path, config_prefix, f"CONV_{size_str}_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}") + # setting environment variables + os.environ['TORCHSIM_DUMP_PATH'] = result_path + # only timing simulation + os.environ['TORCHSIM_VALIDATION_MODE'] = "0" + if 'BACKENDSIM_SPIKE_ONLY' in os.environ: + del os.environ['BACKENDSIM_SPIKE_ONLY'] + + run_conv2d(size[0], size[1], size[2], size[3], size[4], size[5], size[6], size[7], config) \ No newline at end of file diff --git a/experiments/gemm.py b/experiments/gemm.py new file mode 100644 index 00000000..a1fdcff6 --- /dev/null +++ b/experiments/gemm.py @@ -0,0 +1,54 @@ +import torch +import torch._dynamo +import torch.utils.cpp_extension + +import argparse +import datetime + + +def run_matmul(input_size, hidden_size, output_size, config): + from Scheduler.scheduler import Scheduler, SchedulerDNNModel, Request + def custom_matmul(a, b): + return torch.matmul(a, b) + scheduler = Scheduler(num_request_queue=1, engine_select=Scheduler.FIFO_ENGINE, backend_config=config) + device = scheduler.execution_engine.module.custom_device() + torch.manual_seed(0) + input = torch.randn(input_size, hidden_size).to(device=device) + weight = torch.randn(hidden_size, output_size).to(device=device) + opt_fn = torch.compile(dynamic=False)(custom_matmul) + + SchedulerDNNModel.register_model("GEMM", opt_fn) + request = Request("GEMM", [input, weight], [], request_queue_idx=0) + scheduler.add_request(request, request_time=0) + + # Run scheduler + while not scheduler.is_finished(): + scheduler.schedule() + + print(f"GEMM {input_size}x{hidden_size}x{output_size} (MxKxN) Simulation Done") + +if __name__ == "__main__": + import os + import sys + base_dir = os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim') + config = os.environ.get('TORCHSIM_CONFIG', default=f'{base_dir}/PyTorchSimBackend/configs/systolic_ws_128x128_c2_simple_noc_tpuv4.json') + config_prefix = config.split('/')[-1].split('.')[0][9:] # extract config name from config path + sys.path.append(base_dir) + args = argparse.ArgumentParser() + args.add_argument('--size', nargs='+', type=int, default=[128, 128, 128], help='M K N') + args.add_argument('--dump_path', type=str, default='results') + args = args.parse_args() + size = args.size + size_str = "x".join([str(i) for i in size]) + result_path = os.path.join(base_dir, args.dump_path, config_prefix, f"GEMM_{size_str}_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}") + # setting environment variables + os.environ['TORCHSIM_DUMP_PATH'] = result_path + # only timing simulation + os.environ['TORCHSIM_VALIDATION_MODE'] = "0" + if 'BACKENDSIM_SPIKE_ONLY' in os.environ: + del os.environ['BACKENDSIM_SPIKE_ONLY'] + + from Scheduler.scheduler import ExecutionEngine + module = ExecutionEngine.setup_device() + device = module.custom_device() + run_matmul(size[0], size[1], size[2], config) diff --git a/experiments/layernorm.py b/experiments/layernorm.py new file mode 100644 index 00000000..f149394e --- /dev/null +++ b/experiments/layernorm.py @@ -0,0 +1,48 @@ +import torch +import torch._dynamo +import torch.utils.cpp_extension + +import argparse +import datetime + + +def run_layernorm(size, config): + from Scheduler.scheduler import Scheduler, SchedulerDNNModel, Request + scheduler = Scheduler(num_request_queue=1, engine_select=Scheduler.FIFO_ENGINE, backend_config=config) + device = scheduler.execution_engine.module.custom_device() + input = torch.randn(size).to(device=device) + opt_fn = torch.compile(dynamic=False)(torch.nn.LayerNorm(size[-1]).to(device=device)) + + SchedulerDNNModel.register_model("LayerNorm", opt_fn) + request = Request("LayerNorm", [input], [], request_queue_idx=0) + scheduler.add_request(request, request_time=0) + + # Run scheduler + while not scheduler.is_finished(): + scheduler.schedule() + + print(f"LayerNorm {str(size)} Simulation Done") + +if __name__ == "__main__": + import os + import sys + base_dir = os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim') + config = os.environ.get('TORCHSIM_CONFIG', default=f'{base_dir}/PyTorchSimBackend/configs/systolic_ws_128x128_c2_simple_noc_tpuv4.json') + config_prefix = config.split('/')[-1].split('.')[0][9:] # extract config name from config path + sys.path.append(base_dir) + args = argparse.ArgumentParser() + args.add_argument('--size', nargs='+', type=int, default=[512, 768], help='Tensor Shape') + args.add_argument('--dump_path', type=str, default='results') + args = args.parse_args() + size = args.size + size_str = "x".join([str(i) for i in size]) + result_path = os.path.join(base_dir, args.dump_path, config_prefix, f"LayerNorm_{size_str}_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}") + # setting environment variables + os.environ['TORCHSIM_DUMP_PATH'] = result_path + os.environ['TORCHSIM_FUSION_REDUCTION_REDUCTION'] = "0" + # only timing simulation + os.environ['TORCHSIM_VALIDATION_MODE'] = "0" + if 'BACKENDSIM_SPIKE_ONLY' in os.environ: + del os.environ['BACKENDSIM_SPIKE_ONLY'] + + run_layernorm(size, config) diff --git a/experiments/resnet18.py b/experiments/resnet18.py new file mode 100644 index 00000000..5d9dcf86 --- /dev/null +++ b/experiments/resnet18.py @@ -0,0 +1,49 @@ +import torch +import torch._dynamo +import torch.utils.cpp_extension + +import argparse +import datetime + +def run_resnet(batch, config): + from torchvision.models import resnet18 + from Scheduler.scheduler import Scheduler, SchedulerDNNModel, Request + scheduler = Scheduler(num_request_queue=1, engine_select=Scheduler.FIFO_ENGINE, backend_config=config) + device = scheduler.execution_engine.module.custom_device() + model = resnet18().eval() + input = torch.randn(batch, 3, 224, 224).to(device=device) + opt_fn = torch.compile(dynamic=False)(model.to(device, memory_format=torch.channels_last)) + + SchedulerDNNModel.register_model("resnet18", opt_fn) + request = Request("resnet18", [input], [], request_queue_idx=0) + scheduler.add_request(request, request_time=0) + + # Run scheduler + while not scheduler.is_finished(): + with torch.no_grad(): + scheduler.schedule() + + print("ResNet18 Simulation Done") + +if __name__ == "__main__": + import os + import sys + base_dir = os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim') + config = os.environ.get('TORCHSIM_CONFIG', default=f'{base_dir}/PyTorchSimBackend/configs/systolic_ws_128x128_c1_simple_noc_tpuv3.json') + config_prefix = config.split('/')[-1].split('.')[0][9:] # extract config name from config path + sys.path.append(base_dir) + args = argparse.ArgumentParser() + args.add_argument('--batch', type=int, default=1) + args.add_argument('--dump_path', type=str, default='results') + args = args.parse_args() + batch = args.batch + result_path = os.path.join(base_dir, args.dump_path, config_prefix, f"resnet18_{batch}_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}") + # setting environment variables + os.environ['TORCHSIM_DUMP_PATH'] = result_path + os.environ['TORCHSIM_USE_TIMING_POOLING'] = "1" + # only timing simulation + os.environ['TORCHSIM_VALIDATION_MODE'] = "0" + if 'BACKENDSIM_SPIKE_ONLY' in os.environ: + del os.environ['BACKENDSIM_SPIKE_ONLY'] + + run_resnet(batch, config) diff --git a/experiments/resnet50.py b/experiments/resnet50.py new file mode 100644 index 00000000..bd52afc1 --- /dev/null +++ b/experiments/resnet50.py @@ -0,0 +1,49 @@ +import torch +import torch._dynamo +import torch.utils.cpp_extension + +import argparse +import datetime + +def run_resnet(batch, config): + from torchvision.models import resnet50 + from Scheduler.scheduler import Scheduler, SchedulerDNNModel, Request + scheduler = Scheduler(num_request_queue=1, engine_select=Scheduler.FIFO_ENGINE, backend_config=config) + device = scheduler.execution_engine.module.custom_device() + model = resnet50().eval() + input = torch.randn(batch, 3, 224, 224).to(device=device) + opt_fn = torch.compile(dynamic=False)(model.to(device, memory_format=torch.channels_last)) + + SchedulerDNNModel.register_model("resnet50", opt_fn) + request = Request("resnet50", [input], [], request_queue_idx=0) + scheduler.add_request(request, request_time=0) + + # Run scheduler + while not scheduler.is_finished(): + with torch.no_grad(): + scheduler.schedule() + + print("ResNet50 Simulation Done") + +if __name__ == "__main__": + import os + import sys + base_dir = os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim') + config = os.environ.get('TORCHSIM_CONFIG', default=f'{base_dir}/PyTorchSimBackend/configs/systolic_ws_128x128_c2_simple_noc_tpuv4.json') + config_prefix = config.split('/')[-1].split('.')[0][9:] # extract config name from config path + sys.path.append(base_dir) + args = argparse.ArgumentParser() + args.add_argument('--batch', type=int, default=1) + args.add_argument('--dump_path', type=str, default='results') + args = args.parse_args() + batch = args.batch + result_path = os.path.join(base_dir, args.dump_path, config_prefix, f"resnet50_{batch}_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}") + # setting environment variables + os.environ['TORCHSIM_DUMP_PATH'] = result_path + os.environ['TORCHSIM_USE_TIMING_POOLING'] = "1" + # only timing simulation + os.environ['TORCHSIM_VALIDATION_MODE'] = "0" + if 'BACKENDSIM_SPIKE_ONLY' in os.environ: + del os.environ['BACKENDSIM_SPIKE_ONLY'] + + run_resnet(batch, config) diff --git a/experiments/softmax.py b/experiments/softmax.py new file mode 100644 index 00000000..14d28fee --- /dev/null +++ b/experiments/softmax.py @@ -0,0 +1,47 @@ +import torch +import torch._dynamo +import torch.utils.cpp_extension + +import argparse +import datetime + + +def run_softmax(size, config, dim=1): + from Scheduler.scheduler import Scheduler, SchedulerDNNModel, Request + scheduler = Scheduler(num_request_queue=1, engine_select=Scheduler.FIFO_ENGINE, backend_config=config) + device = scheduler.execution_engine.module.custom_device() + input = torch.randn(size).to(device=device) + opt_fn = torch.compile(dynamic=False)(torch.nn.Softmax(dim=dim).to(device=device)) + + SchedulerDNNModel.register_model("Softmax", opt_fn) + request = Request("Softmax", [input], [], request_queue_idx=0) + scheduler.add_request(request, request_time=0) + + # Run scheduler + while not scheduler.is_finished(): + scheduler.schedule() + + print(f"Softmax {str(size)} Simulation Done") + +if __name__ == "__main__": + import os + import sys + base_dir = os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim') + config = os.environ.get('TORCHSIM_CONFIG', default=f'{base_dir}/PyTorchSimBackend/configs/systolic_ws_128x128_c2_simple_noc_tpuv4.json') + config_prefix = config.split('/')[-1].split('.')[0][9:] # extract config name from config path + sys.path.append(base_dir) + args = argparse.ArgumentParser() + args.add_argument('--size', nargs='+', type=int, default=[512, 512], help='Tensor Shape') + args.add_argument('--dump_path', type=str, default='results') + args = args.parse_args() + size = args.size + size_str = "x".join([str(i) for i in size]) + result_path = os.path.join(base_dir, args.dump_path, config_prefix, f"Softmax_{size_str}_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}") + # setting environment variables + os.environ['TORCHSIM_DUMP_PATH'] = result_path + # only timing simulation + os.environ['TORCHSIM_VALIDATION_MODE'] = "0" + if 'BACKENDSIM_SPIKE_ONLY' in os.environ: + del os.environ['BACKENDSIM_SPIKE_ONLY'] + + run_softmax(size, config) diff --git a/gem5_script/script_systolic.py b/gem5_script/script_systolic.py index 89052f27..d5d3a92d 100644 --- a/gem5_script/script_systolic.py +++ b/gem5_script/script_systolic.py @@ -1,540 +1,196 @@ +import time import argparse -import os import sys - +import math import m5 from m5.objects import * -from ctypes import cdll +sys.path.append(os.environ.get('TORCHSIM_DIR')) +from gem5_script.vpu_config import * bin_path = sys.argv[1] parser = argparse.ArgumentParser() -parser.add_argument( - "-c", - "--cmd", - default="", - help="The binary to run in syscall emulation mode.", -) -parser.add_argument( - "-o", - "--options", - default="", - help="""The options to pass to the binary, use - around the entire string""", -) +parser.add_argument("-c", "--cmd", default="", help="The binary to run in syscall emulation mode.") +parser.add_argument("-o", "--options", default="", help="""The options to pass to the binary, use around the entire string""") +parser.add_argument("--cpu", choices=["RiscvAtomicSimpleCPU", "RiscvTimingSimpleCPU", "RiscvMinorCPU", "RiscvDerivO3CPU", + "RiscvMinorCPU", "RiscvCustomCPU", "RiscvMinorV2CPU", "RiscvMinorV4CPU", "RiscvVPU", + "RiscvSparseVPU"], default="RiscvVPU") +parser.add_argument("--mem", choices=["SimpleMemory", "ScratchpadMemory", "DDR3_1600_8x8"], default="ScratchpadMemory") +parser.add_argument("--sparse", type=bool, default=False) +parser.add_argument("--vlane", type=int, default=128) +parser.add_argument("--vlen", type=int, default=256) +args = parser.parse_args() -class MySimpleMemory(SimpleMemory): +class InstMemory(SimpleMemory): latency = "1ns" + bandwidth = "64GB/s" class SpadMemory(SimpleMemory): latency = "1ns" # latency unit is "tick" 1ns = 1000 ticks - bandwidth = "64GB/s" - # TODO: bandwidth = "XXGB/s" what is a proper value? (ref. simple_mem.cc:154) - -class SystolicArray(MinorFU): - unitType = "SystolicArray" - opClasses = minorMakeOpClassSet([ - "CustomMatMul", - "CustomMatMuliVpush", - "CustomMatMulwVpush", - "CustomMatMulvpop", - ]) - opLat = 1 - systolicArrayWidth = 128 - systolicArrayHeight = 128 - -class SparseAccelerator(MinorFU): - unitType = "SparseAccelerator" - opClasses = minorMakeOpClassSet([ - "CustomMatMul", - "CustomMatMuliVpush", - "CustomMatMulwVpush", - "CustomMatMulvpop", - ]) - opLat = 1 - -class SpecialFunctionUnit(MinorFU): - opClasses = minorMakeOpClassSet([ - "CustomMatMulvexp", - ]) - opLat = 10 - -class MinorFPUnit(MinorFU): - opClasses = minorMakeOpClassSet( - [ - "FloatAdd", - "FloatCmp", - "FloatCvt", - "FloatMult", - "FloatMultAcc", - "FloatDiv", - "FloatMisc", - "FloatSqrt" - ] - ) - -class MinorVecAdder(MinorFU): - opClasses = minorMakeOpClassSet( - [ - "SimdAdd", - "SimdFloatAdd", - "SimdFloatAlu", - "SimdFloatCmp", - ] - ) - opLat = 1 - -class MinorVecMultiplier(MinorFU): - opClasses = minorMakeOpClassSet( - [ - "SimdMult", - "SimdFloatMult", - ] - ) - opLat = 3 - -class MinorVecDivider(MinorFU): - opClasses = minorMakeOpClassSet( - [ - "SimdDiv", - "SimdFloatDiv", - ] - ) - opLat = 5 - -class MinorVecMisc(MinorFU): - opClasses = minorMakeOpClassSet( - [ - "SimdUnitStrideLoad", - "SimdUnitStrideStore", - "SimdUnitStrideMaskLoad", - "SimdUnitStrideMaskStore", - "SimdStridedLoad", - "SimdStridedStore", - "SimdIndexedLoad", - "SimdIndexedStore", - "SimdUnitStrideFaultOnlyFirstLoad", - "SimdWholeRegisterLoad", - "SimdWholeRegisterStore", - "SimdAddAcc", - "SimdAlu", - "SimdCmp", - "SimdCvt", - "SimdMultAcc", - "SimdMatMultAcc", - "SimdShift", - "SimdShiftAcc", - "SimdSqrt", - "SimdFloatCvt", - "SimdFloatMisc", - "SimdFloatMultAcc", - "SimdFloatMatMultAcc", - "SimdFloatSqrt", - "SimdReduceAdd", - "SimdReduceAlu", - "SimdReduceCmp", - "SimdFloatReduceAdd", - "SimdFloatReduceCmp", - "SimdAes", - "SimdAesMix", - "SimdSha1Hash", - "SimdSha1Hash2", - "SimdSha256Hash", - "SimdSha256Hash2", - "SimdShaSigma2", - "SimdShaSigma3", - "SimdPredAlu", - "SimdMisc", - - "SimdUnitStrideSegmentedLoad", - "SimdUnitStrideSegmentedStore", - "SimdExt", - "SimdFloatExt", - ] - ) - opLat = 1 - -class MinorVecConfig(MinorFU): - opClasses = minorMakeOpClassSet( - [ - "SimdConfig", - ] - ) - opLat = 1 - -class MinorCustomVecFU(MinorFU): - opClasses = minorMakeOpClassSet( - [ - "SimdUnitStrideLoad", - "SimdUnitStrideStore", - "SimdUnitStrideMaskLoad", - "SimdUnitStrideMaskStore", - "SimdStridedLoad", - "SimdStridedStore", - "SimdIndexedLoad", - "SimdIndexedStore", - "SimdUnitStrideFaultOnlyFirstLoad", - "SimdWholeRegisterLoad", - "SimdWholeRegisterStore", - "SimdAdd", - "SimdAddAcc", - "SimdAlu", - "SimdCmp", - "SimdCvt", - "SimdMisc", - "SimdMult", - "SimdMultAcc", - "SimdMatMultAcc", - "SimdShift", - "SimdShiftAcc", - "SimdDiv", - "SimdSqrt", - "SimdFloatAdd", - "SimdFloatAlu", - "SimdFloatCmp", - "SimdFloatCvt", - "SimdFloatDiv", - "SimdFloatMisc", - "SimdFloatMult", - "SimdFloatMultAcc", - "SimdFloatMatMultAcc", - "SimdFloatSqrt", - "SimdReduceAdd", - "SimdReduceAlu", - "SimdReduceCmp", - "SimdFloatReduceAdd", - "SimdFloatReduceCmp", - "SimdAes", - "SimdAesMix", - "SimdSha1Hash", - "SimdSha1Hash2", - "SimdSha256Hash", - "SimdSha256Hash2", - "SimdShaSigma2", - "SimdShaSigma3", - "SimdPredAlu", - "SimdMisc", - "SimdConfig", - ] - ) - opLat = 1 - -class MinorCustomIntFU(MinorFU): - opClasses = minorMakeOpClassSet(["IntAlu"]) - timings = [MinorFUTiming(description="Int", srcRegsRelativeLats=[2])] - opLat = 1 - -class MinorCustomFUPool(MinorFUPool): - funcUnits = [ - SystolicArray(), # 0 - - MinorVecConfig(), # 1 for vector config - - MinorFPUnit(), - MinorVecMisc(), # 2~5 - MinorVecMisc(), - MinorVecMisc(), - MinorVecMisc(), - - # ALU0 - MinorVecAdder(), # 6 - MinorVecMultiplier(), # 7 - MinorVecDivider(), # 8 - MinorVecAdder(), # 9 - MinorVecMultiplier(), # 10 - MinorVecDivider(), # 11 - MinorVecAdder(), # 12 - MinorVecMultiplier(), # 13 - MinorVecDivider(), # 14 - MinorVecAdder(), # 15 - MinorVecMultiplier(), # 16 - MinorVecDivider(), # 17 - - # ALU1 - MinorVecAdder(), # 18 ~ 29 - MinorVecMultiplier(), - MinorVecDivider(), - MinorVecAdder(), - MinorVecMultiplier(), - MinorVecDivider(), - MinorVecAdder(), - MinorVecMultiplier(), - MinorVecDivider(), - MinorVecAdder(), - MinorVecMultiplier(), - MinorVecDivider(), - - MinorCustomIntFU(), # 30 - MinorCustomIntFU(), - - MinorDefaultIntMulFU(), - MinorDefaultIntDivFU(), - MinorDefaultPredFU(), - MinorDefaultMemFU(), - MinorDefaultMiscFU(), - - SpecialFunctionUnit(), - - # SparseAccelerator(), - # Serializer0(), - # Serializer1(), - # DeSerializer(), - ] - -class MinorCustomSparseFUPool(MinorFUPool): - funcUnits = [ - MinorVecConfig(), # for vector config - - MinorFPUnit(), - MinorVecMisc(), - MinorVecMisc(), - MinorVecMisc(), - MinorVecMisc(), - - # ALU0 - MinorVecAdder(), - MinorVecMultiplier(), - MinorVecDivider(), - MinorVecAdder(), - MinorVecMultiplier(), - MinorVecDivider(), - MinorVecAdder(), - MinorVecMultiplier(), - MinorVecDivider(), - MinorVecAdder(), - MinorVecMultiplier(), - MinorVecDivider(), - - # ALU1 - MinorVecAdder(), - MinorVecMultiplier(), - MinorVecDivider(), - MinorVecAdder(), - MinorVecMultiplier(), - MinorVecDivider(), - MinorVecAdder(), - MinorVecMultiplier(), - MinorVecDivider(), - MinorVecAdder(), - MinorVecMultiplier(), - MinorVecDivider(), - MinorCustomIntFU(), - MinorCustomIntFU(), - - MinorDefaultIntMulFU(), - MinorDefaultIntDivFU(), - MinorDefaultPredFU(), - MinorDefaultMemFU(), - MinorDefaultMiscFU(), - - SparseAccelerator(), - # Serializer0(), - # Serializer1(), - # DeSerializer(), - ] - -class RiscvCustomCPU(RiscvMinorCPU): - fetch2InputBufferSize = 4 - decodeInputWidth = 4 - executeInputWidth = 4 - executeIssueLimit = 8 - executeMemoryIssueLimit = 2 - executeCommitLimit = 8 - executeMemoryCommitLimit = 2 - executeFuncUnits = MinorCustomFUPool() - -class RiscvVPU(RiscvMinorCPU): - fetch2InputBufferSize = 2 - decodeInputBufferSize = 1 - decodeInputWidth = 1 - executeInputWidth = 8 - executeIssueLimit = 8 - executeMemoryIssueLimit = 8 - executeCommitLimit = 8 - executeMemoryCommitLimit = 8 - executeFuncUnits = MinorCustomFUPool() - -class RiscvSparseVPU(RiscvMinorCPU): - fetch2InputBufferSize = 2 - decodeInputBufferSize = 1 - decodeInputWidth = 1 - executeInputWidth = 8 - executeIssueLimit = 8 - executeMemoryIssueLimit = 8 - executeCommitLimit = 8 - executeMemoryCommitLimit = 8 - executeFuncUnits = MinorCustomSparseFUPool() - -class MinorV2FUPool(MinorFUPool): - funcUnits = [ - MinorDefaultIntFU(), - MinorDefaultIntFU(), - MinorDefaultIntMulFU(), - MinorDefaultIntDivFU(), - MinorDefaultFloatSimdFU(), - MinorDefaultPredFU(), - MinorDefaultMemFU(), - MinorDefaultMiscFU(), - # MinorDefaultVecFU(), - # MinorDefaultVecFU(), - ] - -class RiscvMinorV2CPU(RiscvMinorCPU): - executeFuncUnits = MinorV2FUPool() - -class MinorV4FUPool(MinorFUPool): - funcUnits = [ - MinorDefaultIntFU(), - MinorDefaultIntFU(), - MinorDefaultIntMulFU(), - MinorDefaultIntDivFU(), - MinorDefaultFloatSimdFU(), - MinorDefaultPredFU(), - MinorDefaultMemFU(), - MinorDefaultMiscFU(), - # MinorDefaultVecFU(), - # MinorDefaultVecFU(), - # MinorDefaultVecFU(), - # MinorDefaultVecFU(), - ] - -class RiscvMinorV4CPU(RiscvMinorCPU): - executeFuncUnits = MinorV4FUPool() - executeCommitLimit = 4 - executeMemoryCommitLimit = 1 - -class L1Cache(Cache): + def __init__(self, bandwidth="4GB/s"): + super().__init__() + self.bandwidth = bandwidth # Set the bandwidth for this memory bank + +class MultiBankMemorySystem(): + def __init__(self, bus_port, mem_range, num_banks=8, granule_size=4, total_bandwidth="32GB/s"): + self.num_banks = num_banks + self.granule_size = granule_size + + # Calculate interleaving properties + self.intlvBits = int(math.log2(self.num_banks)) # Interleaving bits + self.intlvLowBit = int(math.log2(self.granule_size)) # Granule size low bit + self.intlvHighBit = self.intlvLowBit + self.intlvBits - 1 # High bit for interleaving + self.mem_ctrls = [] + self.bandwidth_per_bank = self.divide_bandwidth(total_bandwidth[:-2], self.num_banks) + + # Create memory controllers for each bank + self.create_memory_banks(bus_port, mem_range) + + def create_memory_banks(self, bus_port, mem_range): + """Create memory banks and interleave them""" + for i in range(self.num_banks): + #print(f"[Spad Bank {i}] Bandwidth {self.bandwidth_per_bank}") + mem = SpadMemory(self.bandwidth_per_bank) # Create a new memory bank + # Define the memory range for each bank (interleaving range) + if self.num_banks!=1: + print("intlvBits:",self.intlvBits, " intlvHighBits: ", self.intlvHighBit) + mem.range = AddrRange( + start=mem_range.start, size=mem_range.size(), + intlvBits=self.intlvBits, intlvMatch=i, intlvHighBit=self.intlvHighBit + ) + else: + mem.range = AddrRange(start=mem_range.start, size=mem_range.size()) + mem.port = bus_port + self.mem_ctrls.append(mem) + + def divide_bandwidth(self, total_bandwidth, num_banks): + total_bandwidth_bytes = self.bandwidth_to_bytes(total_bandwidth) + per_bank_bandwidth_bytes = total_bandwidth_bytes / num_banks + return self.bytes_to_bandwidth(per_bank_bandwidth_bytes) + + def bandwidth_to_bytes(self, bandwidth): + # Extract the value and unit + value, unit = bandwidth[:-2], bandwidth[-2:] + value = float(value) + # Convert based on the unit + if unit == "GB": + return value * 1e9 + elif unit == "MB": + return value * 1e6 + elif unit == "KB": + return value * 1e3 + elif unit == "B": + return value + else: + raise ValueError(f"Unknown bandwidth unit: {unit}") + + def bytes_to_bandwidth(self, bandwidth_bytes): + if bandwidth_bytes >= 1e9: + return f"{bandwidth_bytes / 1e9}GB/s" + elif bandwidth_bytes >= 1e6: + return f"{bandwidth_bytes / 1e6}MB/s" + elif bandwidth_bytes >= 1e3: + return f"{bandwidth_bytes / 1e3}KB/s" + else: + return f"{bandwidth_bytes}B/s" + + def get_ctrls(self): + return self.mem_ctrls + +class L1Cache(NoncoherentCache): """Simple L1 Cache with default values""" - assoc = 8 tag_latency = 1 data_latency = 1 response_latency = 1 mshrs = 16 tgts_per_mshr = 20 - def connectBus(self, bus): - """Connect this cache to a memory-side bus""" self.mem_side = bus.cpu_side_ports def connectCPU(self, cpu): - """Connect this cache's port to a CPU-side port - This must be defined in a subclass""" raise NotImplementedError class L1ICache(L1Cache): - """Simple L1 instruction cache with default values""" - - # Set the default size - size = "8192kB" # is it enough for infinite ICache? + size = "8192kB" + tag_latency = 0 + data_latency = 0 + response_latency = 0 def connectCPU(self, cpu): - """Connect this cache's port to a CPU icache port""" self.cpu_side = cpu.icache_port valid_cpu = { -# "X86AtomicSimpleCPU": X86AtomicSimpleCPU, -# "X86TimingSimpleCPU": X86TimingSimpleCPU, -# "X86DerivO3CPU": X86O3CPU, -# "ArmAtomicSimpleCPU": ArmAtomicSimpleCPU, -# "ArmTimingSimpleCPU": ArmTimingSimpleCPU, -# "ArmMinorCPU": ArmMinorCPU, -# "ArmDerivO3CPU": ArmO3CPU, - "RiscvAtomicSimpleCPU": RiscvAtomicSimpleCPU, - "RiscvTimingSimpleCPU": RiscvTimingSimpleCPU, "RiscvMinorCPU": RiscvMinorCPU, "RiscvDerivO3CPU": RiscvO3CPU, "RiscvMinorCPU": RiscvMinorCPU, - "RiscvCustomCPU": RiscvCustomCPU, - "RiscvMinorV2CPU": RiscvMinorV2CPU, - "RiscvMinorV4CPU": RiscvMinorV4CPU, "RiscvVPU": RiscvVPU, - "RiscvSparseVPU": RiscvSparseVPU, } -valid_mem = {"SimpleMemory": MySimpleMemory, "ScratchpadMemory": SpadMemory, "DDR3_1600_8x8": DDR3_1600_8x8} - -#parser = argparse.ArgumentParser() -#parser.add_argument("binary", type=str) -#parser.add_argument("--cpu", choices=valid_cpu.keys(), default="RiscvTimingSimpleCPU") -parser.add_argument("--cpu", choices=valid_cpu.keys(), default="RiscvVPU") -parser.add_argument("--mem", choices=valid_mem.keys(), default="ScratchpadMemory") -parser.add_argument("--sparse", type=bool, default=False) -parser.add_argument("--vlane", type=int, default=128) - -args = parser.parse_args() - # change systolicArrayWidth and systolicArrayHeight into args.vlane SystolicArray.systolicArrayWidth = args.vlane SystolicArray.systolicArrayHeight = args.vlane - -system = System() - -thispath = os.path.dirname(os.path.realpath(__file__)) binary = args.cmd -#binary = os.path.join( -# thispath, -# "../../../", -# args.binary, -#) -#system.workload = SEWorkload.init_compatible(args.binary) +# Main System Setup +system = System() system.workload = SEWorkload.init_compatible(binary) +# Clock setting system.clk_domain = SrcClockDomain() system.clk_domain.clock = "1GHz" system.clk_domain.voltage_domain = VoltageDomain() -if args.cpu not in ( - "X86AtomicSimpleCPU", - "ArmAtomicSimpleCPU", - "RiscvAtomicSimpleCPU", -): - system.mem_mode = "timing" - -system.mem_ranges = [AddrRange("8192MB")] +fast_clk = SrcClockDomain() +fast_clk.clock = '8GHz' +fast_clk.voltage_domain = VoltageDomain() +system.mem_mode = "timing" +system.cache_line_size = 64 system.cpu = valid_cpu[args.cpu]() +system.cpu.ArchISA.vlen = args.vlen + +# Memory range +granule_sz = 64 +spad_num_bank = 1 +system.mem_ranges = [AddrRange(start=0, size="16GB")] system.membus = SpmXBar( - width = 64, - frontend_latency = 0, - forward_latency = 0, - response_latency = 0) -# system.cpu.icache_port = system.membus.cpu_side_ports + width = granule_sz, + header_latency = 0, + frontend_latency = 0, + forward_latency = 0, + response_latency = 0) +system.membus.clk_domain = fast_clk + +# Instruction cache connection +system.cpu.icache= L1ICache() +system.cpu.icache.connectCPU(system.cpu) +system.cpu.icache.connectBus(system.membus) +#system.cpu.icache.mem_side = inst_mem.port system.cpu.dcache_port = system.membus.cpu_side_ports - -system.cpu.l1i = L1ICache() -system.cpu.l1i.connectCPU(system.cpu) -system.cpu.l1i.connectBus(system.membus) - system.cpu.createInterruptController() -if args.cpu in ("X86AtomicSimpleCPU", "X86TimingSimpleCPU", "X86DerivO3CPU"): - system.cpu.interrupts[0].pio = system.membus.mem_side_ports - system.cpu.interrupts[0].int_master = system.membus.cpu_side_ports - system.cpu.interrupts[0].int_slave = system.membus.mem_side_ports -system.mem_ctrl = valid_mem[args.mem]() -system.mem_ctrl.range = system.mem_ranges[0] -system.mem_ctrl.port = system.membus.mem_side_ports +# Create and connect memory nodes +multi_banked_spm = MultiBankMemorySystem(system.membus.mem_side_ports, system.mem_ranges[0], num_banks=spad_num_bank, granule_size=granule_sz) +system.mem_ctrls = multi_banked_spm.get_ctrls() + system.system_port = system.membus.cpu_side_ports process = Process() -#process.cmd = [args.binary] process.cmd = [binary] + args.options.split() system.cpu.workload = process system.cpu.createThreads() +# Simulation root = Root(full_system=False, system=system) m5.instantiate() - +start_time = time.time() exit_event = m5.simulate() if exit_event.getCause() != "exiting with last active thread context": exit(1) - -# print(f"Exiting @ tick {m5.curTick()} because {exit_event.getCause()}") -print(f"{m5.curTick() / 1000}") -print(f"{m5.curTick()}") - +end_time = time.time() +elapsed_seconds = end_time - start_time +print(f"Simulation time: {elapsed_seconds:.6f} seconds") diff --git a/gem5_script/vpu_config.py b/gem5_script/vpu_config.py new file mode 100644 index 00000000..eeeaefab --- /dev/null +++ b/gem5_script/vpu_config.py @@ -0,0 +1,240 @@ +import m5 +from m5.objects import * + +class SystolicArray(MinorFU): + unitType = "SystolicArray" + opClasses = minorMakeOpClassSet(["CustomMatMul", "CustomMatMuliVpush", "CustomMatMulwVpush", "CustomMatMulvpop"]) + opLat = 1 + systolicArrayWidth = 128 + systolicArrayHeight = 128 + +class SparseAccelerator(MinorFU): + unitType = "SparseAccelerator" + opClasses = minorMakeOpClassSet(["CustomMatMul", "CustomMatMuliVpush", "CustomMatMulwVpush", "CustomMatMulvpop"]) + opLat = 1 + +class SpecialFunctionUnit(MinorFU): + opClasses = minorMakeOpClassSet([ + "CustomMatMulvexp", + "CustomMatMulverf", + "CustomMatMulvtanh", + ]) + opLat = 10 + +class MinorFPUnit(MinorFU): + opClasses = minorMakeOpClassSet( + [ + "FloatAdd", + "FloatCmp", + "FloatCvt", + "FloatMult", + "FloatMultAcc", + "FloatDiv", + "FloatMisc", + "FloatSqrt" + ] + ) + +class MinorVecAdder(MinorFU): + opClasses = minorMakeOpClassSet( + [ + "SimdAdd", + "SimdFloatAdd", + "SimdFloatAlu", + "SimdFloatCmp", + "SimdShift", + "SimdShiftAcc", + "SimdAddAcc", + "SimdAlu", + "SimdCmp", + ] + ) + opLat = 1 + +class MinorVecMultiplier(MinorFU): + opClasses = minorMakeOpClassSet( + [ + "SimdMult", + "SimdFloatMult", + "SimdMultAcc", + "SimdMatMultAcc", + "SimdSqrt", + "SimdFloatMultAcc", + "SimdFloatMatMultAcc", + "SimdFloatSqrt", + ] + ) + opLat = 1 + +class MinorVecDivider(MinorFU): + opClasses = minorMakeOpClassSet( + [ + "SimdDiv", + "SimdFloatDiv", + ] + ) + opLat = 1 + +class MinorVecReduce(MinorFU): + opClasses = minorMakeOpClassSet( + [ + "SimdReduceAdd", + "SimdReduceAlu", + "SimdReduceCmp", + "SimdFloatReduceAdd", + "SimdFloatReduceCmp", + ] + ) + opLat = 1 + +class MinorVecLdStore(MinorFU): + opClasses = minorMakeOpClassSet( + [ + "SimdUnitStrideLoad", + "SimdUnitStrideStore", + "SimdUnitStrideMaskLoad", + "SimdUnitStrideMaskStore", + "SimdStridedLoad", + "SimdStridedStore", + "SimdIndexedLoad", + "SimdIndexedStore", + "SimdUnitStrideFaultOnlyFirstLoad", + "SimdWholeRegisterLoad", + "SimdWholeRegisterStore", + "SimdUnitStrideSegmentedLoad", + "SimdUnitStrideSegmentedStore", + ] + ) + opLat = 1 + +class MinorVecMisc(MinorFU): + opClasses = minorMakeOpClassSet( + [ + "SimdCvt", + "SimdFloatCvt", + "SimdFloatMisc", + "SimdPredAlu", + "SimdMisc", + "SimdExt", + "SimdFloatExt", + "CustomVlaneIdx", + ] + ) + opLat = 1 + +class MinorVecConfig(MinorFU): + opClasses = minorMakeOpClassSet( + [ + "SimdConfig", + ] + ) + opLat = 1 + +class MinorCustomIntFU(MinorDefaultIntFU): + opLat = 1 + +class MinorCustomIntDivFU(MinorDefaultIntDivFU): + opLat = 1 + +class MinorCustomIntMulFU(MinorDefaultIntMulFU): + opLat = 1 + +class MinorCustomPredFU(MinorDefaultPredFU): + opLat = 1 + +class MinorCustomMemFU(MinorDefaultMemFU): + opLat = 1 + +class MinorCustomMiscFU(MinorDefaultMiscFU): + opLat = 1 + +class MinorCustomFUPool(MinorFUPool): + funcUnits = [ + # Scalar unit + MinorFPUnit(), + MinorCustomIntFU(), + MinorCustomIntFU(), + MinorCustomIntMulFU(), + MinorCustomIntDivFU(), + MinorCustomPredFU(), + MinorCustomMemFU(), + MinorCustomMiscFU(), + + # Scalar unit + MinorFPUnit(), + MinorCustomIntFU(), + MinorCustomIntFU(), + MinorCustomIntMulFU(), + MinorCustomIntDivFU(), + MinorCustomPredFU(), + MinorCustomMemFU(), + MinorCustomMiscFU(), + + # Matmul unit + SystolicArray(), # 0 + + # Vector + MinorVecConfig(), # 1 for vector config + MinorVecConfig(), + MinorVecMisc(), + MinorVecMisc(), + MinorVecLdStore(), + MinorVecLdStore(), + + # Vector ALU0 + MinorVecAdder(), # 6 + MinorVecMultiplier(), # 7 + MinorVecDivider(), # 8 + MinorVecReduce(), + + # Vector ALU1 + MinorVecAdder(), # 18 ~ 29 + MinorVecMultiplier(), + MinorVecDivider(), + MinorVecReduce(), + + # Vector + MinorVecConfig(), # 1 for vector config + MinorVecConfig(), + MinorVecMisc(), + MinorVecMisc(), + MinorVecLdStore(), + MinorVecLdStore(), + + # Vector ALU0 + MinorVecAdder(), # 6 + MinorVecMultiplier(), # 7 + MinorVecDivider(), # 8 + MinorVecReduce(), + + # Vector ALU1 + MinorVecAdder(), # 18 ~ 29 + MinorVecMultiplier(), + MinorVecDivider(), + MinorVecReduce(), + + # SFU + SpecialFunctionUnit(), + ] + +class RiscvVPU(RiscvMinorCPU): + fetch1FetchLimit = 8 + decodeInputWidth = 8 + fetch1ToFetch2BackwardDelay = 0 + fetch2InputBufferSize = 8 + decodeInputBufferSize = 8 + decodeInputWidth = 8 + executeInputBufferSize = 128 + executeInputWidth = 12 + executeIssueLimit = 12 + executeCommitLimit = 12 + + # Memory + executeMemoryIssueLimit = 8 + executeMemoryCommitLimit = 8 + executeMaxAccessesInMemory = 8 + executeLSQMaxStoreBufferStoresPerCycle = 8 + executeLSQTransfersQueueSize = 8 + executeLSQStoreBufferSize = 8 + + executeFuncUnits = MinorCustomFUPool() diff --git a/scripts/CompilerOpt_experiment/DMAopt.sh b/scripts/CompilerOpt_experiment/DMAopt.sh new file mode 100644 index 00000000..469cf766 --- /dev/null +++ b/scripts/CompilerOpt_experiment/DMAopt.sh @@ -0,0 +1,28 @@ +#!/bin/bash +export TORCHSIM_CONFIG="/root/workspace/PyTorchSim/PyTorchSimBackend/configs/systolic_ws_128x128_c2_simple_noc_tpuv2.json" + +# None FG DMA +export TORCHSIM_SUBTILE=0 +python experiments/gemm.py --size 128 128 128 +python experiments/gemm.py --size 256 256 256 +python experiments/gemm.py --size 512 512 512 +python experiments/gemm.py --size 1024 1024 1024 +python experiments/gemm.py --size 2048 2048 2048 + +# FG DMA +export TORCHSIM_SUBTILE=1 +export TORCHSIM_MANUAL_SUBTILE_SIZE=1 +python experiments/gemm.py --size 128 128 128 +python experiments/gemm.py --size 256 256 256 +python experiments/gemm.py --size 512 512 512 +python experiments/gemm.py --size 1024 1024 1024 +python experiments/gemm.py --size 2048 2048 2048 + +# SFG DMA +export TORCHSIM_SUBTILE=1 +export TORCHSIM_MANUAL_SUBTILE_SIZE=0 +python experiments/gemm.py --size 128 128 128 +python experiments/gemm.py --size 256 256 256 +python experiments/gemm.py --size 512 512 512 +python experiments/gemm.py --size 1024 1024 1024 +python experiments/gemm.py --size 2048 2048 2048 \ No newline at end of file diff --git a/scripts/ILS_experiment/ils_parser.sh b/scripts/ILS_experiment/ils_parser.sh new file mode 100755 index 00000000..a02d8edb --- /dev/null +++ b/scripts/ILS_experiment/ils_parser.sh @@ -0,0 +1,34 @@ +#!/bin/bash + +ignore_rest=false +gem5_cmd="" +result_path="" +gem5_time="" +togsim_time="" + +total_gem5=0 +total_togsim=0 + +while IFS= read -r line; do + if [[ "$line" == launch* ]]; then + tile_graph_path=$(echo "$line" | awk '{for (i=1; i<=NF; i++) if ($i ~ /tile_graph\.onnx$/) print $i}') + if [[ -n "$tile_graph_path" ]]; then + dir_path=$(dirname "$tile_graph_path") + sto_log_path="$dir_path/m5out/sto.log" + echo "sto.log path: $sto_log_path" + gem5_time=$(grep "Simulation time:" "$sto_log_path" | \ + sed -E 's/^Simulation time: ([0-9.]+) seconds$/\1/') + echo "GEM5: $gem5_time" + total_gem5=$(echo "$total_gem5 + $gem5_time" | bc) + fi + fi + if [[ "$line" == *"Simulation time:"* ]]; then + togsim_time=$(echo "$line" | sed -E 's/.*Simulation time: ([0-9.]+) seconds/\1/') + echo "TOGSim: $togsim_time" + fi +done + +if [[ -n "$total_gem5" && -n "$total_togsim" ]]; then + total_time=$(python3 -c "print(round($total_gem5 + $total_togsim, 6))") + echo "Simulation time: $total_time seconds" +fi \ No newline at end of file diff --git a/scripts/ILS_experiment/test_matmul.py b/scripts/ILS_experiment/test_matmul.py new file mode 100644 index 00000000..09cc407d --- /dev/null +++ b/scripts/ILS_experiment/test_matmul.py @@ -0,0 +1,66 @@ +import torch +import argparse +import torch._dynamo +import torch.utils.cpp_extension + +def test_result(name, out, cpu_out, rtol=1e-4, atol=1e-4): + if torch.allclose(out.cpu(), cpu_out, rtol=rtol, atol=atol): + message = f"|{name} Test Passed|" + print("-" * len(message)) + print(message) + print("-" * len(message)) + else: + message = f"|{name} Test Failed|" + print("-" * len(message)) + print(message) + print("-" * len(message)) + print("custom out: ", out.cpu()) + print("cpu out: ", cpu_out) + exit(1) + +def test_matmul(device, input_size=128, hidden_size=128, output_size=128): + def custom_matmul(a, b): + return torch.matmul(a, b) + torch.manual_seed(0) + input = torch.randn(input_size, hidden_size) + weight = torch.randn(hidden_size, output_size) + x1 = input.to(device=device) + w1 = weight.to(device=device) + x2 = input.to("cpu") + w2 = weight.to("cpu") + opt_fn = torch.compile(dynamic=False)(custom_matmul) + res = opt_fn(x1, w1) + y = custom_matmul(x2, w2) + test_result("Matmul Forward", res, y) + +def test_addmm(device, input_size=128, hidden_size=128, output_size=128, bias_rank=1): + def custom_matmul(bias, a, b): + return torch.addmm(bias, a, b) + torch.manual_seed(0) + input = torch.randn(input_size, hidden_size) + weight = torch.randn(hidden_size, output_size) + bias = torch.randn(output_size) if bias_rank == 1 else torch.randn(input_size, output_size) + x1 = input.to(device=device) + w1 = weight.to(device=device) + b1 = bias.to(device=device) + x2 = input.to("cpu") + w2 = weight.to("cpu") + b2 = bias.to("cpu") + opt_fn = torch.compile(dynamic=False)(custom_matmul) + res = opt_fn(b1, x1, w1) + y = custom_matmul(b2, x2, w2) + test_result("Addmm Forward", res, y) + +if __name__ == "__main__": + import os + import sys + sys.path.append(os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim')) + parser = argparse.ArgumentParser(description="Run matmul with given shape") + parser.add_argument('--shape', type=str, default="(512,512,512)") + args = parser.parse_args() + shape = tuple(map(int, args.shape.strip('()').split(','))) + + from Scheduler.scheduler import ExecutionEngine + module = ExecutionEngine.setup_device() + device = module.custom_device() + test_matmul(device, *shape) diff --git a/scripts/batch_experiment/avg.py b/scripts/batch_experiment/avg.py new file mode 100644 index 00000000..b91287b6 --- /dev/null +++ b/scripts/batch_experiment/avg.py @@ -0,0 +1,22 @@ +import re +import sys + +def parse_log_file(file_path, interval): + with open(file_path, "r") as file: + index = 0 + for line in file: + if index % interval != 0: + index += 1 + continue + index += 1 + print(line.strip()) + +if __name__ == "__main__": + if len(sys.argv) != 3: + print("Wrong input") + sys.exit(1) + + log_file = sys.argv[1] + interval = int(sys.argv[2]) + parse_log_file(log_file, interval) + diff --git a/scripts/batch_experiment/batch_time.py b/scripts/batch_experiment/batch_time.py new file mode 100644 index 00000000..9f8778d7 --- /dev/null +++ b/scripts/batch_experiment/batch_time.py @@ -0,0 +1,37 @@ +import re +import sys + +def time_to_milliseconds(timestamp): + match = re.match(r"\[(\d{4}-\d{2}-\d{2}) (\d{2}):(\d{2}):(\d{2})\.(\d{3})\]", timestamp) + if not match: + return None + + _, hh, mm, ss, ms = match.groups() + + total_ms = (int(hh) * 3600 + int(mm) * 60 + int(ss)) * 1000 + int(ms) + return total_ms + +def parse_log_file(file_path): + with open(file_path, "r") as file: + counter = 0 + for line in file: + if "batch size" in line: + print(line.strip()) + counter = 40 + continue + counter -= 1 + if (counter > 0): + time_match = re.search(r"\[(\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}\.\d{3})\]", line) + if time_match: + timestamp = time_match.group(0) # "[YYYY-MM-DD HH:MM:SS.sss]" 형식 + time_ms = time_to_milliseconds(timestamp) + print(time_ms) + +if __name__ == "__main__": + if len(sys.argv) != 2: + print("Wrong input") + sys.exit(1) + + log_file = sys.argv[1] + parse_log_file(log_file) + diff --git a/scripts/batch_experiment/parse.py b/scripts/batch_experiment/parse.py new file mode 100644 index 00000000..dd3e504f --- /dev/null +++ b/scripts/batch_experiment/parse.py @@ -0,0 +1,35 @@ +import re +import sys + +def time_to_milliseconds(timestamp): + match = re.match(r"\[(\d{4}-\d{2}-\d{2}) (\d{2}):(\d{2}):(\d{2})\.(\d{3})\]", timestamp) + if not match: + return None + + _, hh, mm, ss, ms = match.groups() + + total_ms = (int(hh) * 3600 + int(mm) * 60 + int(ss)) * 1000 + int(ms) + return total_ms + +def parse_log_file(file_path): + with open(file_path, "r") as file: + for line in file: + time_match = re.search(r"\[(\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}\.\d{3})\]", line) + # Cycle 값 추출 (예: Total cycle 43858000) + cycle_match = re.search(r"\[0\] : Total cycle (\d+)", line) + + if time_match and cycle_match: + timestamp = time_match.group(0) # "[YYYY-MM-DD HH:MM:SS.sss]" 형식 + cycle = cycle_match.group(1) # Cycle 값 + + time_ms = time_to_milliseconds(timestamp) + print(time_ms, cycle) + +if __name__ == "__main__": + if len(sys.argv) != 2: + print("Wrong input") + sys.exit(1) + + log_file = sys.argv[1] + parse_log_file(log_file) + diff --git a/scripts/chiplet.sh b/scripts/chiplet.sh index f0404eb6..3dfba3d9 100755 --- a/scripts/chiplet.sh +++ b/scripts/chiplet.sh @@ -13,19 +13,22 @@ if [ $# -lt 1 ]; then fi GEMM_PATH="$1" +INDEX_NAME="$2" SIMULATOR_PATH="$TORCHSIM_DIR/PyTorchSimBackend/build/bin/Simulator" GEMM_DIR_NAME=$(basename "$GEMM_PATH") echo "GEMM Directory Name: $GEMM_DIR_NAME" CONFIG_LIST=( - "$TORCHSIM_DIR/PyTorchSimBackend/configs/systolic_ws_128x128_c2_simple_noc_tpuv2.json" - "$TORCHSIM_DIR/PyTorchSimBackend/configs/systolic_ws_128x128_c2_booksim_tpuv2.json" - "$TORCHSIM_DIR/PyTorchSimBackend/configs/systolic_ws_128x128_c2_chiplet_tpuv2.json" - "$TORCHSIM_DIR/PyTorchSimBackend/configs/systolic_ws_128x128_c2_chiplet_tpuv2_xnuma.json" + "$TORCHSIM_DIR/PyTorchSimBackend/configs/systolic_ws_128x128_c2_chiplet_tpuv3.json" ) +CONFIG_LIST2=( + "$TORCHSIM_DIR/PyTorchSimBackend/configs/systolic_ws_128x128_c2_booksim_tpuv3.json" + "$TORCHSIM_DIR/PyTorchSimBackend/configs/systolic_ws_128x128_c2_chiplet_tpuv3_xnuma.json" +) +shift shift for ATTRIBUTE in "$@"; do - ATTRIBUTE_FILE="$GEMM_PATH/attribute/$ATTRIBUTE" + ATTRIBUTE_FILE="$GEMM_PATH/runtime_0000/attribute/$ATTRIBUTE" if [ ! -f "$ATTRIBUTE_FILE" ]; then echo "Error: Attribute file '$ATTRIBUTE_FILE' does not exist." exit 1 @@ -33,7 +36,7 @@ for ATTRIBUTE in "$@"; do ATTRIBUTE_FILES+=("$ATTRIBUTE_FILE") done MODELS_LIST="$GEMM_PATH/tile_graph.onnx" -ATTRIBUTE_PATH="$GEMM_PATH/attribute" +ATTRIBUTE_PATH="$GEMM_PATH/runtime_0000/attribute" for CONFIG in "${CONFIG_LIST[@]}"; do CONFIG_NAME=$(basename "$CONFIG" .json) @@ -41,16 +44,27 @@ for CONFIG in "${CONFIG_LIST[@]}"; do for ATTRIBUTE_FILE in "${ATTRIBUTE_FILES[@]}"; do ATTRIBUTE_NAME=$(basename "$ATTRIBUTE_FILE") - RESULTS_DIR="./results/$GEMM_DIR_NAME/$ATTRIBUTE_NAME" + RESULTS_DIR="./chiplet_results$INDEX_NAME/$GEMM_DIR_NAME/$ATTRIBUTE_NAME" mkdir -p "$RESULTS_DIR" OUTPUT_FILE="$RESULTS_DIR/${CONFIG_NAME}_result.txt" # Run Simulator echo "$SIMULATOR_PATH" --config "$CONFIG" --models_list "$MODELS_LIST" --attributes_list "$ATTRIBUTE_PATH/$ATTRIBUTE_NAME" "$SIMULATOR_PATH" --config "$CONFIG" --models_list "$MODELS_LIST" --log_level trace --attributes_list "$ATTRIBUTE_PATH/$ATTRIBUTE_NAME" > "$OUTPUT_FILE" & - - echo "===== Simulation for $CONFIG completed. Results saved to $OUTPUT_FILE =====" + echo "[BackendSimulator] for $CONFIG stored to \"$(pwd)/$OUTPUT_FILE\"" done done +for CONFIG in "${CONFIG_LIST2[@]}"; do + CONFIG_NAME=$(basename "$CONFIG" .json) + ATTRIBUTE_NAME=0 + RESULTS_DIR="./chiplet_results$INDEX_NAME/$GEMM_DIR_NAME/$ATTRIBUTE_NAME" + mkdir -p "$RESULTS_DIR" + OUTPUT_FILE="$RESULTS_DIR/${CONFIG_NAME}_result.txt" + + # Run Simulator + # echo "$SIMULATOR_PATH" --config "$CONFIG" --models_list "$MODELS_LIST" --attributes_list "$ATTRIBUTE_PATH/$ATTRIBUTE_NAME" + "$SIMULATOR_PATH" --config "$CONFIG" --models_list "$MODELS_LIST" --log_level trace --attributes_list "$ATTRIBUTE_PATH/$ATTRIBUTE_NAME" > "$OUTPUT_FILE" & + echo "[BackendSimulator] for $CONFIG stored to \"$(pwd)/$OUTPUT_FILE\"" +done wait \ No newline at end of file diff --git a/scripts/chiplet_prep.py b/scripts/chiplet_prep.py new file mode 100644 index 00000000..168532f1 --- /dev/null +++ b/scripts/chiplet_prep.py @@ -0,0 +1,120 @@ +import os +import json +import shutil +import argparse +import torch +import torch._dynamo +import torch.utils.cpp_extension + +def test_result(name, out, cpu_out, rtol=1e-4, atol=1e-4): + if torch.allclose(out.cpu(), cpu_out, rtol=rtol, atol=atol): + message = f"|{name} Test Passed|" + print("-" * len(message)) + print(message) + print("-" * len(message)) + else: + message = f"|{name} Test Failed|" + print("-" * len(message)) + print(message) + print("-" * len(message)) + print("custom out: ", out.cpu()) + print("cpu out: ", cpu_out) + exit(1) + +def test_matmul(device, input_size=128, hidden_size=128, output_size=128): + def custom_matmul(a, b): + return torch.matmul(a, b) + torch.manual_seed(0) + input = torch.randn(input_size, hidden_size) + weight = torch.randn(hidden_size, output_size) + x1 = input.to(device=device) + w1 = weight.to(device=device) + x2 = input.to("cpu") + w2 = weight.to("cpu") + opt_fn = torch.compile(dynamic=False)(custom_matmul) + res = opt_fn(x1, w1) + y = custom_matmul(x2, w2) + #test_result("Matmul Forward", res, y) + +def modify_file(dump_path, name, address_numa_stride=None, subgraph_map=None): + file_path = os.path.join(dump_path, 'runtime_0000', 'attribute', '0') + if not os.path.exists(file_path): + print(f"File {file_path} does not exist.") + return + with open(file_path, 'r') as f: + data = json.load(f) + # address_numa_stride와 subgraph_map 추가 + if address_numa_stride: + data['address_numa_stride'] = address_numa_stride + if subgraph_map: + data['subgraph_map'] = subgraph_map + + output_path = file_path = os.path.join(dump_path, 'runtime_0000', 'attribute') + os.makedirs(output_path, exist_ok=True) + output_file = os.path.join(output_path, name) + with open(output_file, 'w') as f: + json.dump(data, f, indent=4) + print(f"Modified file saved to {output_file}") + +if __name__ == "__main__": + import os + import sys + sys.path.append(os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim')) + + from Scheduler.scheduler import ExecutionEngine + module = ExecutionEngine.setup_device() + device = module.custom_device() + parser = argparse.ArgumentParser(description='Process folder argument.') + parser.add_argument('size', type=int, help='Folder value', default=256) + args = parser.parse_args() + + folder = int(args.size) + print("Taget size: ", folder) + folder_path = os.environ.get("TORCHSIM_DUMP_PATH") + print(folder_path) + os.makedirs(folder_path, exist_ok=True) + test_matmul(device, folder, folder, folder) + + pp = os.listdir(folder_path)[0] + dump_path = os.path.join(folder_path, pp) + pp = os.listdir(dump_path)[0] + dump_path = os.path.join(dump_path, pp) + subgraph_map_best = { "0": 0, "1": 0, "2": 1, "3": 1 } + subgraph_map_worst = { "0": 1, "1": 1, "2": 0, "3": 0 } + numa_stride = { "arg0" : [1], "arg1" : [1] , "arg2": [0 , 2] } + + subgraph_map_best1k = { "0": 0, "1": 0, "2": 1, "3": 1 } + subgraph_map_worst1k = { "0": 1, "1": 1, "2": 0, "3": 0 } + numa_stride_1k = { "arg0" : [1], "arg1" : [1] , "arg2": [0 , 2] } + + subgraph_map_best2k = { + "0": 0, + "1": 0, + "2": 0, + "3": 0, + "4": 1, + "5": 1, + "6": 1, + "7": 1 + } + subgraph_map_worst2k = { + "0": 1, + "1": 1, + "2": 1, + "3": 1, + "4": 0, + "5": 0, + "6": 0, + "7": 0 + } + numa_stride_2k = { "arg0" : [2], "arg1" : [1] , "arg2": [0 , 4] } + if args.size == 1024: + modify_file(dump_path, "best", numa_stride_1k, subgraph_map_best1k) + modify_file(dump_path, "worst", numa_stride_1k, subgraph_map_worst1k) + elif args.size == 2048: + modify_file(dump_path, "best", numa_stride_2k, subgraph_map_best2k) + modify_file(dump_path, "worst", numa_stride_2k, subgraph_map_worst2k) + else: + modify_file(dump_path, "best", numa_stride, subgraph_map_best) + modify_file(dump_path, "worst", numa_stride, subgraph_map_worst) + diff --git a/scripts/chiplet_prep.sh b/scripts/chiplet_prep.sh new file mode 100755 index 00000000..cddf1a58 --- /dev/null +++ b/scripts/chiplet_prep.sh @@ -0,0 +1,14 @@ +#!/bin/bash + +sizes=(256 512 1024 2048) +for size in "${sizes[@]}"; do + echo "Processing size: $size" + + # Set environment variables + export TORCHSIM_TILE_M=$((size / 2)) + export TORCHSIM_TILE_K=$((size / 2)) + export TORCHSIM_TILE_N=$((size / 2)) + export TORCHSIM_DUMP_PATH=$(pwd)/chiplet_result/$size + python3 chiplet_prep.py $size + #python3 chiplet_run.py $(pwd)/chiplet_result +done \ No newline at end of file diff --git a/scripts/chiplet_run.py b/scripts/chiplet_run.py new file mode 100644 index 00000000..e53352e6 --- /dev/null +++ b/scripts/chiplet_run.py @@ -0,0 +1,37 @@ +import argparse +from pathlib import Path +import os + +def list_nested_folders(root_path): + root = Path(root_path) + + if not root.exists() or not root.is_dir(): + print(f"[Error] '{root}' is not a valid directory.") + return [] + + folders = set() + for p in root.rglob('*'): + if p.is_dir(): + rel_depth = len(p.relative_to(root).parts) + if rel_depth == 3: + folders.add(p) + + return sorted(folders) + +def main(): + parser = argparse.ArgumentParser(description="List folders up to depth 3 and parse arguments.") + + parser.add_argument("path", type=str, help="Root directory to start scanning") + parser.add_argument("--index", type=int, default=0, help="Index value (default: 0)") + parser.add_argument("--attr", nargs='*', default=["best", "worst"], + help='List of attr (default: ["best", "worst"])') + + args = parser.parse_args() + folders = list_nested_folders(args.path) + for folder in folders: + cmd = f"./chiplet.sh {folder} {args.index} {' '.join(args.attr)}" + print(cmd) + os.system(cmd) + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/scripts/end2end.sh b/scripts/end2end.sh new file mode 100755 index 00000000..7ca5c93d --- /dev/null +++ b/scripts/end2end.sh @@ -0,0 +1,62 @@ +#!/bin/bash + +# Base directory +BASE_PATH=$1 # Input as the first argument + +# Initialize the total cycle sum +total_sum=0 +total_core=0 +total_vector=0 +# Find all backendsim_result folders +mapfile -t backend_folders < <(find "$BASE_PATH" -type d -name "backendsim_result") + +# Iterate over each backendsim_result folder +for backend_folder in "${backend_folders[@]}"; do + # echo "Processing folder: $backend_folder" + + # Find all files within the backendsim_result folder + mapfile -t files < <(find "$backend_folder" -type f) + + for file in "${files[@]}"; do + # echo "Processing $file" + + # Extract the last line containing "Total cycle" + total_cycle=$(grep "Total cycle" "$file" | tail -n 1 | sed -E 's/.*Total cycle ([0-9]+).*/\1/') + # echo "total_cycle: $total_cycle" + active_cycles=($(grep -o 'active cycle [0-9]*' "$file" | awk '{print $3}')) + num_cycles=${#active_cycles[@]} + if [ "$num_cycles" -ge 3 ]; then + core_cycle=${active_cycles[$((num_cycles-3))]} + else + echo "Error: cannot find core active cycle" + fi + if [[ "$num_cycles" -ge 1 ]]; then + # Extract the last two active cycles + vector_core_cycle=${active_cycles[$((num_cycles-1))]} + else + echo "Error: cannot find vector core active cycle" + fi + echo "file: $file total_cycle: $total_cycle SA core_cycle: $core_cycle vector_core_cycle: $vector_core_cycle" + + if [[ -n "$total_cycle" ]]; then + # Add the total cycle to the total sum + # echo "Adding $total_cycle to total_sum" + total_sum=$((total_sum + total_cycle)) + fi + if [[ -n "$core_cycle" ]]; then + # Add the total cycle to the total sum + # echo "Adding $total_cycle to total_sum" + total_core=$((total_core + core_cycle)) + fi + if [[ -n "$vector_core_cycle" ]]; then + # Add the total cycle to the total sum + # echo "Adding $total_cycle to total_sum" + total_vector=$((total_vector + vector_core_cycle)) + fi + done +done + +# Print the total cycle sum +echo "total end2end cycle: $total_sum" +echo "total core cycle: $total_core" +echo "total vector core cycle: $total_vector" \ No newline at end of file diff --git a/scripts/get_tog_result.sh b/scripts/get_tog_result.sh new file mode 100755 index 00000000..9359e1e5 --- /dev/null +++ b/scripts/get_tog_result.sh @@ -0,0 +1,36 @@ +#!/bin/bash +total_cycles=0 + +# Read through input stream line by line +while IFS= read -r line; do + # Check if the line contains both "[BackendSimulator]" and "stored" + if [[ "$line" == *"[BackendSimulator]"* && "$line" == *"stored"* ]]; then + # Extract the file path from the line + file_path=$(echo "$line" | sed -n 's/.*stored to "\(.*\)"$/\1/p') + + # If the file exists, grep for "Total cycle" and output the last matching line + if [[ -f "$file_path" ]]; then + last_line=$(grep "Total cycle" "$file_path" | tail -n 1) + echo "$last_line ($file_path)" + # Accumulate the cycle value + cycle_value=$(echo "$last_line" | sed -n 's/.*Total cycle \([0-9]\+\)$/\1/p') + total_cycles=$((total_cycles + cycle_value)) + else + echo "File not found: $file_path" + fi + fi + # Check if the line ends with "Test passed|" + if [[ "$line" == *"Test Passed|" ]]; then + echo "$line" + echo "Accumulated Total Cycle: $total_cycles" + total_cycles=0 + fi + if [[ "$line" == *"Test Failed|" ]]; then + echo "$line" + echo "Accumulated Total Cycle: $total_cycles" + total_cycles=0 + fi + if [[ "$line" == *"[log]"* ]]; then + echo "$line" + fi +done \ No newline at end of file diff --git a/scripts/sim_time.sh b/scripts/sim_time.sh new file mode 100755 index 00000000..15c60736 --- /dev/null +++ b/scripts/sim_time.sh @@ -0,0 +1,27 @@ +#!/bin/bash + +# Base directory +BASE_PATH=$1 # Input as the first argument + +# Initialize total_sum as string for awk processing +total_sum=0.0 + +# Find all backendsim_result folders +mapfile -t backend_folders < <(find "$BASE_PATH" -type d -name "backendsim_result") + +# Iterate over each backendsim_result folder +for backend_folder in "${backend_folders[@]}"; do + mapfile -t files < <(find "$backend_folder" -type f) + + for file in "${files[@]}"; do + sim_time=$(grep "Simulation time:" "$file" | tail -n 1 | sed -E 's/.*Simulation time: ([0-9]+(\.[0-9]+)?).*/\1/') + echo "file: $file total_cycle: $sim_time" + + if [[ -n "$sim_time" ]]; then + total_sum=$(awk -v a="$total_sum" -v b="$sim_time" 'BEGIN {printf "%.6f", a + b}') + fi + done +done + +# Print the total simulation time +echo "simulation time: $total_sum" diff --git a/scripts/sparsity_experiment/parse.py b/scripts/sparsity_experiment/parse.py new file mode 100644 index 00000000..7b15e156 --- /dev/null +++ b/scripts/sparsity_experiment/parse.py @@ -0,0 +1,74 @@ +import argparse +import os +import subprocess + +def get_stored_paths(log_file): + """Extracts stored file paths from the given log file.""" + stored_paths = [] + try: + result = subprocess.run(["grep", "stored", log_file], capture_output=True, text=True) + for line in result.stdout.splitlines(): + parts = line.split(" ") + if "stored" in parts: + index = parts.index("stored") + if index + 1 < len(parts): + stored_paths.append(parts[index + 2].strip('"')) + except Exception as e: + print(f"Error reading stored paths: {e}") + return stored_paths + +def get_last_total_cycle(file_path): + """Extracts the last Total cycle value from the given file.""" + total_cycle = None + try: + result = subprocess.run(["grep", "Total cycle", file_path], capture_output=True, text=True) + lines = result.stdout.splitlines() + if lines: + last_line = lines[-1] + total_cycle = last_line.split()[-1] # Extract the last value + except Exception as e: + print(f"Error reading total cycle from {file_path}: {e}") + return total_cycle + +def main(log_file): + stored_paths = get_stored_paths(log_file) + k = [] + for path in stored_paths: + print(path) + if os.path.exists(path): + total_cycle = get_last_total_cycle(path) + if total_cycle: + k.append(total_cycle) + else: + print(f"{path}: No Total cycle found") + else: + print(f"{path}: File does not exist") + return k + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Extract Total Cycle from stored paths.") + parser.add_argument("log_file", type=str, help="Path to the log file containing stored paths") + args = parser.parse_args() + a_l = [] + b_l = [] + if os.path.exists(args.log_file): + a, b = main(args.log_file + "/0.0") + a_l.append(a) + b_l.append(b) + a, b = main(args.log_file + "/0.2") + a_l.append(a) + b_l.append(b) + a, b = main(args.log_file + "/0.4") + a_l.append(a) + b_l.append(b) + a, b = main(args.log_file + "/0.6") + a_l.append(a) + b_l.append(b) + a, b = main(args.log_file + "/0.8") + a_l.append(a) + b_l.append(b) + print(" ".join(a_l)) + print(" ".join(b_l)) + + else: + print(f"Log file {args.log_file} not found.") diff --git a/scripts/sparsity_experiment/run.sh b/scripts/sparsity_experiment/run.sh new file mode 100755 index 00000000..0b7bc6f5 --- /dev/null +++ b/scripts/sparsity_experiment/run.sh @@ -0,0 +1,53 @@ +export TORCHSIM_DUMP_PATH=$(pwd)/result +export SPIKE_DUMP_SPARSE_TILE=1 +export TORCHSIM_FORCE_TIME_K=8 +export TORCHSIM_FORCE_TIME_M=8 +export TORCHSIM_FORCE_TIME_N=8 + +OUTPUT_DIR="12GB" +export TORCHSIM_CONFIG="/workspace/PyTorchSim/PyTorchSimBackend/configs/systolic_ws_8x8_c1_12G_simple_noc.json" +python3 /workspace/PyTorchSim/tests/test_sparsity.py --sparsity 0.0 > ${OUTPUT_DIR}/0.0 +python3 /workspace/PyTorchSim/tests/test_sparsity.py --sparsity 0.2 > ${OUTPUT_DIR}/0.2 +python3 /workspace/PyTorchSim/tests/test_sparsity.py --sparsity 0.4 > ${OUTPUT_DIR}/0.4 +python3 /workspace/PyTorchSim/tests/test_sparsity.py --sparsity 0.6 > ${OUTPUT_DIR}/0.6 +python3 /workspace/PyTorchSim/tests/test_sparsity.py --sparsity 0.8 > ${OUTPUT_DIR}/0.8 + +OUTPUT_DIR="24GB" +export TORCHSIM_CONFIG="/workspace/PyTorchSim/PyTorchSimBackend/configs/systolic_ws_8x8_c1_24G_simple_noc.json" +python3 /workspace/PyTorchSim/tests/test_sparsity.py --sparsity 0.0 > ${OUTPUT_DIR}/0.0 +python3 /workspace/PyTorchSim/tests/test_sparsity.py --sparsity 0.2 > ${OUTPUT_DIR}/0.2 +python3 /workspace/PyTorchSim/tests/test_sparsity.py --sparsity 0.4 > ${OUTPUT_DIR}/0.4 +python3 /workspace/PyTorchSim/tests/test_sparsity.py --sparsity 0.6 > ${OUTPUT_DIR}/0.6 +python3 /workspace/PyTorchSim/tests/test_sparsity.py --sparsity 0.8 > ${OUTPUT_DIR}/0.8 + +OUTPUT_DIR="48GB" +export TORCHSIM_CONFIG="/workspace/PyTorchSim/PyTorchSimBackend/configs/systolic_ws_8x8_c1_48G_simple_noc.json" +python3 /workspace/PyTorchSim/tests/test_sparsity.py --sparsity 0.0 > ${OUTPUT_DIR}/0.0 +python3 /workspace/PyTorchSim/tests/test_sparsity.py --sparsity 0.2 > ${OUTPUT_DIR}/0.2 +python3 /workspace/PyTorchSim/tests/test_sparsity.py --sparsity 0.4 > ${OUTPUT_DIR}/0.4 +python3 /workspace/PyTorchSim/tests/test_sparsity.py --sparsity 0.6 > ${OUTPUT_DIR}/0.6 +python3 /workspace/PyTorchSim/tests/test_sparsity.py --sparsity 0.8 > ${OUTPUT_DIR}/0.8 + +OUTPUT_DIR="12GB_2core" +export TORCHSIM_CONFIG="/workspace/PyTorchSim/PyTorchSimBackend/configs/systolic_ws_8x8_c2_12G_simple_noc.json" +python3 /workspace/PyTorchSim/tests/test_sparsity.py --sparsity 0.0 > ${OUTPUT_DIR}/0.0 +python3 /workspace/PyTorchSim/tests/test_sparsity.py --sparsity 0.2 > ${OUTPUT_DIR}/0.2 +python3 /workspace/PyTorchSim/tests/test_sparsity.py --sparsity 0.4 > ${OUTPUT_DIR}/0.4 +python3 /workspace/PyTorchSim/tests/test_sparsity.py --sparsity 0.6 > ${OUTPUT_DIR}/0.6 +python3 /workspace/PyTorchSim/tests/test_sparsity.py --sparsity 0.8 > ${OUTPUT_DIR}/0.8 + +OUTPUT_DIR="24GB_2core" +export TORCHSIM_CONFIG="/workspace/PyTorchSim/PyTorchSimBackend/configs/systolic_ws_8x8_c2_24G_simple_noc.json" +python3 /workspace/PyTorchSim/tests/test_sparsity.py --sparsity 0.0 > ${OUTPUT_DIR}/0.0 +python3 /workspace/PyTorchSim/tests/test_sparsity.py --sparsity 0.2 > ${OUTPUT_DIR}/0.2 +python3 /workspace/PyTorchSim/tests/test_sparsity.py --sparsity 0.4 > ${OUTPUT_DIR}/0.4 +python3 /workspace/PyTorchSim/tests/test_sparsity.py --sparsity 0.6 > ${OUTPUT_DIR}/0.6 +python3 /workspace/PyTorchSim/tests/test_sparsity.py --sparsity 0.8 > ${OUTPUT_DIR}/0.8 + +OUTPUT_DIR="48GB_2core" +export TORCHSIM_CONFIG="/workspace/PyTorchSim/PyTorchSimBackend/configs/systolic_ws_8x8_c2_48G_simple_noc.json" +python3 /workspace/PyTorchSim/tests/test_sparsity.py --sparsity 0.0 > ${OUTPUT_DIR}/0.0 +python3 /workspace/PyTorchSim/tests/test_sparsity.py --sparsity 0.2 > ${OUTPUT_DIR}/0.2 +python3 /workspace/PyTorchSim/tests/test_sparsity.py --sparsity 0.4 > ${OUTPUT_DIR}/0.4 +python3 /workspace/PyTorchSim/tests/test_sparsity.py --sparsity 0.6 > ${OUTPUT_DIR}/0.6 +python3 /workspace/PyTorchSim/tests/test_sparsity.py --sparsity 0.8 > ${OUTPUT_DIR}/0.8 diff --git a/scripts/stonne_experiment/run.sh b/scripts/stonne_experiment/run.sh new file mode 100755 index 00000000..1825817f --- /dev/null +++ b/scripts/stonne_experiment/run.sh @@ -0,0 +1,9 @@ +#!/bin/bash +export TORCHSIM_FORCE_TIME_M=1024 +export TORCHSIM_FORCE_TIME_K=1024 +export TORCHSIM_FORCE_TIME_N=1024 +python3 ../../tests/test_hetro.py --M 1024 --N 1024 --K 1024 --sparsity 0.9 --config stonne_big_c1_simple_noc.json --mode 0 > hetero/big_sparse.log +python3 ../../tests/test_hetro.py --M 1024 --N 1024 --K 1024 --sparsity 0.9 --config systolic_ws_128x128_c1_simple_noc_tpuv3_half.json --mode 1 > hetero/big.log +python3 ../../tests/test_hetro.py --M 1024 --N 1024 --K 1024 --sparsity 0.9 --config heterogeneous_c2_simple_noc.json --mode 2 > hetero/hetero.log + +echo "All processes completed!" diff --git a/scripts/stonne_experiment/run_trace.sh b/scripts/stonne_experiment/run_trace.sh new file mode 100755 index 00000000..5a4ff890 --- /dev/null +++ b/scripts/stonne_experiment/run_trace.sh @@ -0,0 +1,18 @@ +#!/bin/bash + +SCRIPT="/workspace/PyTorchSim/tests/test_stonne.py" + +SIZES=(32 64 128) +SPARSITIES=(0.0 0.2 0.4 0.6 0.8) + +for sz in "${SIZES[@]}"; do + for sparsity in "${SPARSITIES[@]}"; do + FILE_PATH=$(python "$SCRIPT" "$sz" "$sparsity" | grep -oP '(?<=stored to ")[^"]+') + TOTAL_CYCLE=$(grep -oP '\[.*?\] \[info\] Stonne Core \[0\] : Total cycle \K\d+' "$FILE_PATH" | tail -n 1) + echo "Stonne $sz $sparsity $TOTAL_CYCLE" + + FILE_PATH=$(python "$SCRIPT" "$sz" "$sparsity" | grep -oP '(?<=stored to ")[^"]+') + TOTAL_CYCLE=$(grep -oP '\[.*?\] \[info\] Stonne Core \[0\] : Total cycle \K\d+' "$FILE_PATH" | tail -n 1) + echo "TOG $sz $sparsity $TOTAL_CYCLE" + done +done \ No newline at end of file diff --git a/scripts/stonne_experiment2/tog_gen.py b/scripts/stonne_experiment2/tog_gen.py new file mode 100644 index 00000000..2f184f4c --- /dev/null +++ b/scripts/stonne_experiment2/tog_gen.py @@ -0,0 +1,85 @@ +import os +import sys +import re +import glob +from collections import defaultdict +sys.path.append(os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim')) +from AsmParser.tog_generator import tog_generator +from Simulator.simulator import BackendSimulator +from PyTorchSimFrontend import extension_config + +def extract_simulation_stats(result_path): + with open(result_path, "r") as f: + lines = f.readlines()[-4:] + + nr_multiplications = None + total_cycle = None + sim_time = None + + for line in lines: + if "nr_multiplications" in line: + nr_multiplications = line.strip().split(":")[-1].strip() + elif "Total execution cycle" in line: + total_cycle = line.strip().split(":")[-1].strip() + elif "Simulation time" in line: + sim_time = line.strip().split(":")[-1].replace("seconds", "").strip() + return nr_multiplications, total_cycle, sim_time + +if __name__ == "__main__": + base_dir = "/home/workspace/stonneResult" + trace_mode_paths = [] + perf_mode_paths = [] + for root, dirs, files in os.walk(base_dir): + if "raw_tog.py" in files: + raw_tog_path = os.path.join(root, "raw_tog.py") + tog_path = os.path.join(root, "tile_graph.onnx") + if not os.path.exists(tog_path): + tile_graph_generator = tog_generator([root]) + tile_graph_generator.load_file(raw_tog_path) + tile_graph_generator.generate_tile_graph( + tog_path, + cycle_list=[0], + x_offset=0, + w_offset=0, + vector_lane=0, + stonneGraph=True + ) + print(f"TOG genereted at {tog_path}") + rel_depth = os.path.relpath(root, base_dir).count(os.sep) + if rel_depth == 0: + trace_mode_paths.append(root) + else: + perf_mode_paths.append(root) + cycle_list = {} + simul_list = defaultdict(list) + for path in perf_mode_paths: + parent = os.path.dirname(path) + counter_files = glob.glob(os.path.join(path, "*.counters")) + for counter_file in counter_files: + with open(counter_file, 'r') as f: + first_line = f.readline().strip() + second_line = f.readline().strip() + if first_line.startswith("CYCLES="): + cycle = int(first_line.split("=")[1]) + cycle_list[parent] = cycle + if second_line.startswith("Simulation time="): + match = re.search(r'Simulation time=([0-9.]+)', second_line) + simul_list[parent].append(float(match.group(1))) + + print("\n=== Run TLS simulation ===") + for path in trace_mode_paths: + if "outerPro" in path: + continue + tog_path = os.path.join(path, "tile_graph.onnx") + backend_path = os.path.join(extension_config.CONFIG_TORCHSIM_DIR, "PyTorchSimBackend") + stonne_config_path = f'{extension_config.CONFIG_TORCHSIM_DIR}/PyTorchSimBackend/configs/stonne_validation_c1_simple_noc.json' + backsim = BackendSimulator(backend_path, stonne_config_path) + result_path = backsim.simulation(tog_path) + nr_multiplications, total_cycle, sim_time = extract_simulation_stats(result_path) + sim_time, total_cycle = float(sim_time), int(total_cycle) + print(f"[TLS] Cycle={total_cycle} Sim time={sim_time} nr_multiplications={nr_multiplications}") + avg_simul = sum(simul_list[path]) / len(simul_list[path]) + print(f"[ILS] Cycle={cycle_list[path]} Sim time= {avg_simul} at {path}") + speedup = avg_simul / sim_time if avg_simul != 0 else float('inf') + error_rate = abs(cycle_list[path] - int(total_cycle)) / total_cycle if total_cycle != 0 else float('inf') + print(f"[EVAL] Speedup={speedup:.3f}x Error rate={error_rate:.4%}") \ No newline at end of file diff --git a/test_extension_backend.py b/test_extension_backend.py index 6c5429c7..f0a9353a 100644 --- a/test_extension_backend.py +++ b/test_extension_backend.py @@ -1,6 +1,6 @@ import torch._dynamo import torch.utils.cpp_extension -from tests.test_add import test_vectoradd +from tests.test_add import test_vectoradd, test_vector_scalar_add from tests.test_reduce import test_reduce_sum from tests.test_transpose2D import test_Transpose2D, test_Transpose2D_2 from tests.test_transpose3D import test_Transpose3D_1, test_Transpose3D_2, test_Transpose3D_3 @@ -12,7 +12,7 @@ from tests.test_matmul import test_matmul from tests.test_bmm import test_BMM from tests.test_cnn import test_CNN -from tests.test_transformer import test_DecoderBlock +from tests.test_transformer import test_EncoderBlock from tests.test_resnet import test_resnet from tests.test_mlp import test_mlp, test_mlp_inf from tests.MoE.test_moe import test_moe @@ -25,30 +25,34 @@ from Scheduler.scheduler import ExecutionEngine module = ExecutionEngine.setup_device() device = module.custom_device() - test_vectoradd(device, (47, 10)) - test_reduce_sum(device, (29, 47), 1, keepdim=True) - test_reduce_sum(device, (17, 68), 0, keepdim=True) - test_Transpose2D(device, [64, 156]) - test_Transpose2D_2(device, [16, 64]) - test_Transpose3D_1(device, [62, 34, 44]) - test_Transpose3D_2(device, [62, 34, 44]) - test_Transpose3D_3(device, [62, 34, 44]) - test_view3D_2D(device) + #test_vectoradd(device, (47, 10)) + #test_vector_scalar_add(device, (10, 10)) + #test_reduce_sum(device, (32, 32), 1, keepdim=True) + #test_reduce_sum(device, (32, 32), 0, keepdim=True) + #test_reduce_sum(device, (512, 512), 1, keepdim=True) + #test_reduce_sum(device, (512, 512), 0, keepdim=True) + #test_Transpose2D(device, [64, 156]) + #test_Transpose2D_2(device, [16, 64]) + #test_Transpose3D_1(device, [62, 34, 256]) + #test_Transpose3D_2(device, [62, 34, 256]) + #test_Transpose3D_3(device, [62, 34, 256]) + #test_view3D_2D(device) test_maxpool(device) - test_avgpool(device) - test_softmax(device, (64, 128), dim=1) - test_BatchNorm(device) - test_LayerNorm(device, (64, 128)) - test_conv2d(device) - test_matmul(device, 33, 45, 68) - test_BMM(device) - test_CNN(device) - test_DecoderBlock(device) - test_resnet(device) - test_mlp(device) - test_mlp_inf(device, batch_size=64, input_size=256, hidden_size=512, output_size=256, sparsity=0.97) + #test_avgpool(device) + #test_softmax(device, (256, 256), dim=1) + #test_BatchNorm(device) + #test_LayerNorm(device, (64, 128)) + #test_conv2d(device) + #test_matmul(device, 33, 45, 68) + #test_BMM(device) + #test_CNN(device) + #test_EncoderBlock(device) + #test_resnet(device) + #test_mlp(device) + #test_mlp_inf(device, batch_size=64, input_size=256, hidden_size=512, output_size=256, sparsity=0.97) # # Fusion Test - test_matmul_scalar(device) - test_matmul_activation(device, batch_size=32, input_size=32, output_size=32, activation_fn="sigmoid") - test_addmm_residual(device) + #test_matmul_scalar(device) + #test_matmul_activation(device, batch_size=32, input_size=32, output_size=32, activation_fn="relu") + #test_matmul_activation(device, batch_size=32, input_size=32, output_size=32, activation_fn="sigmoid") + #test_addmm_residual(device) diff --git a/tests/Fusion/test_addmm_residual.py b/tests/Fusion/test_addmm_residual.py index 10f387e8..a5e05182 100644 --- a/tests/Fusion/test_addmm_residual.py +++ b/tests/Fusion/test_addmm_residual.py @@ -3,12 +3,16 @@ import torch.utils.cpp_extension def test_result(name, out, cpu_out, rtol=1e-4, atol=1e-4): - message = f"|{name} Test Passed|" if torch.allclose(out.cpu(), cpu_out, rtol=rtol, atol=atol): + message = f"|{name} Test Passed|" print("-" * len(message)) print(message) print("-" * len(message)) else: + message = f"|{name} Test Failed|" + print("-" * len(message)) + print(message) + print("-" * len(message)) print("custom out: ", out.cpu()) print("cpu out: ", cpu_out) exit(1) diff --git a/tests/Fusion/test_attention_fusion.py b/tests/Fusion/test_attention_fusion.py new file mode 100644 index 00000000..95bdf165 --- /dev/null +++ b/tests/Fusion/test_attention_fusion.py @@ -0,0 +1,83 @@ +import math +import copy +import torch +import torch._dynamo +import torch.utils.cpp_extension + +def test_result(name, out, cpu_out, rtol=1e-4, atol=1e-4): + if torch.allclose(out.cpu(), cpu_out, rtol=rtol, atol=atol): + message = f"|{name} Test Passed|" + print("-" * len(message)) + print(message) + print("-" * len(message)) + else: + message = f"|{name} Test Failed|" + print("-" * len(message)) + print(message) + print("-" * len(message)) + print("custom out: ", out.cpu()) + print("cpu out: ", cpu_out) + exit(1) + +def clones(module, N): + "Produce N identical layers." + return torch.nn.ModuleList([copy.deepcopy(module) for _ in range(N)]) + +class my_MultiheadAttention(torch.nn.Module): + def __init__(self, h, d_model, dropout=0.1): + super(my_MultiheadAttention, self).__init__() + assert d_model % h == 0 + # We assume d_v always equals d_k + self.d_k = d_model // h + self.h = h + self.linear = torch.nn.Linear(d_model, d_model) + self.attn = None + + def forward(self, query, key, value): + # BMM + Max + scores = torch.matmul(key, query.transpose(-2, -1)) + s_max = scores.max(dim=-2, keepdim=True).values + + # Reduce Sum + scores = torch.exp(scores-s_max) + s_sum = scores.sum(dim=-2, keepdim=True) + + # Elementwise + BMM + p_attn = scores/s_sum + x = torch.matmul(value.transpose(-1, -2), p_attn) + # 3) "Concat" using a view and apply a final linear. + x = ( + x.view(-1, self.h * self.d_k) + ) + del query + del key + del value + return self.linear(x) + +def test_MHA(device, num_heads=12, embed_dim=768, input_seq=512): + MHA = my_MultiheadAttention(num_heads, embed_dim) + cpu_query = torch.randn(num_heads, input_seq, embed_dim//num_heads) + cpu_key = torch.randn(num_heads, input_seq, embed_dim//num_heads) + cpu_value = torch.randn(num_heads, input_seq, embed_dim//num_heads) + cpu_res = MHA(cpu_query, cpu_key, cpu_value) + + query = cpu_query.clone().to(device=device) + key = cpu_key.clone().to(device=device) + value = cpu_value.clone().to(device=device) + MHA.to(device=device) + opt_fn = torch.compile(dynamic=False)(MHA) + res = opt_fn(query, key, value) + + test_result("MHA Forward", res, cpu_res) + +if __name__ == "__main__": + import os + import sys + sys.path.append(os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim')) + + from Scheduler.scheduler import ExecutionEngine + module = ExecutionEngine.setup_device() + device = module.custom_device() + test_MHA(device) + # test_Attention(device, head=16, seq=512, d_k=64) + # test_MHA(device, num_heads=12, embed_dim=768) diff --git a/tests/Fusion/test_bmm_reduction.py b/tests/Fusion/test_bmm_reduction.py new file mode 100644 index 00000000..42e38095 --- /dev/null +++ b/tests/Fusion/test_bmm_reduction.py @@ -0,0 +1,52 @@ +import torch +import torch._dynamo +import torch.utils.cpp_extension + +def test_result(name, out, cpu_out, rtol=1e-4, atol=1e-4): + if torch.allclose(out.cpu(), cpu_out, rtol=rtol, atol=atol): + message = f"|{name} Test Passed|" + print("-" * len(message)) + print(message) + print("-" * len(message)) + else: + message = f"|{name} Test Failed|" + print("-" * len(message)) + print(message) + print("-" * len(message)) + print("custom out: ", out.cpu()) + print("cpu out: ", cpu_out) + exit(1) + +def test_bmm_reduce(device, batch=12, size=512): + def bmm(a, b): + result = torch.bmm(a, b.transpose(1,2)) + return result, result.max(dim=1).values + torch.manual_seed(0) + N = size + input = torch.randn(batch, N, 64) + weight = torch.randn(batch, N, 64) + #input = torch.arange(1, N * N + 1, dtype=torch.float32).reshape(N, N).to(dtype=torch.float32) + #weight = torch.eye(N, dtype=torch.float32) + x1 = input.to(device=device) + w1 = weight.to(device=device) + x2 = input.to("cpu") + w2 = weight.to("cpu") + opt_fn = torch.compile(dynamic=False)(bmm) + res = opt_fn(x1, w1) + y = bmm(x2, w2) + test_result("BMM Reduction Fusion activation", res[0], y[0]) + test_result("BMM Reduction Fusion reduction", res[1], y[1]) + +if __name__ == "__main__": + import os + import sys + sys.path.append(os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim')) + + from Scheduler.scheduler import ExecutionEngine + module = ExecutionEngine.setup_device() + device = module.custom_device() + #test_bmm_reduce(device) + test_bmm_reduce(device, 12, 512) + test_bmm_reduce(device, 4, 256) + test_bmm_reduce(device, 6, 768) + test_bmm_reduce(device, 2, 128) diff --git a/tests/Fusion/test_conv_fusion.py b/tests/Fusion/test_conv_fusion.py new file mode 100644 index 00000000..42210b13 --- /dev/null +++ b/tests/Fusion/test_conv_fusion.py @@ -0,0 +1,124 @@ +import torch +import torch._dynamo +import torch.utils.cpp_extension + +def test_result(name, out, cpu_out, rtol=1e-4, atol=1e-4): + message = f"|{name} Test Passed|" + if torch.allclose(out.cpu(), cpu_out, rtol=rtol, atol=atol): + print("-" * len(message)) + print(message) + print("-" * len(message)) + else: + print("Failed") + print("custom out: ", out.cpu()) + print("cpu out: ", cpu_out) + # exit(1) + +def test_conv_residual(device, batch_size=1, in_channels=8, out_channels=16, input_size=64, kernel_size=3, stride=1, padding=0): + def custom_conv2d(a, b, bias, c): + i_c = a.shape[1] + o_c = b.shape[0] + conv2d = torch.nn.Conv2d(i_c, o_c, b.shape[-1], stride=stride, padding=padding, dilation=1, bias=True) + conv2d.weight = torch.nn.Parameter(b) + conv2d.bias = torch.nn.Parameter(bias) + return conv2d(a) + c + torch.manual_seed(0) + conv_input = torch.randn(batch_size, in_channels, input_size, input_size).to(memory_format=torch.channels_last, device=device) + conv_kernel = torch.randn(out_channels, in_channels, kernel_size, kernel_size).to(memory_format=torch.channels_last, device=device) + conv_bias = torch.randn(out_channels).to(device=device) + o_h = (input_size + 2 * padding - kernel_size) // stride + 1 + o_w = (input_size + 2 * padding - kernel_size) // stride + 1 + add_tensor = torch.randn(batch_size, out_channels, o_h, o_w).to(device=device) + opt_fn = torch.compile(dynamic=False)(custom_conv2d) + res = opt_fn(conv_input, conv_kernel, conv_bias, add_tensor) + out = custom_conv2d(conv_input.cpu(), conv_kernel.cpu(), conv_bias.cpu(), add_tensor.cpu()) + test_result("Conv2d Residual Fusion Forward", res, out, rtol=1e-3, atol=1e-3) + print("Max diff > ", torch.max(torch.abs(res.cpu() - out))) + + +def test_conv_scalar(device, batch_size=1, in_channels=8, out_channels=16, input_size=64, kernel_size=3, stride=1, padding=0): + def custom_conv2d(a, b, bias, c): + i_c = a.shape[1] + o_c = b.shape[0] + conv2d = torch.nn.Conv2d(i_c, o_c, b.shape[-1], stride=stride, padding=padding, dilation=1, bias=False) + conv2d.weight = torch.nn.Parameter(b) + # conv2d.bias = torch.nn.Parameter(bias) + return conv2d(a) * c + torch.manual_seed(0) + conv_input = torch.randn(batch_size, in_channels, input_size, input_size).to(memory_format=torch.channels_last, device=device) + conv_kernel = torch.randn(out_channels, in_channels, kernel_size, kernel_size).to(memory_format=torch.channels_last, device=device) + conv_bias = torch.randn(out_channels).to(device=device) + opt_fn = torch.compile(dynamic=False)(custom_conv2d) + res = opt_fn(conv_input, conv_kernel, conv_bias, 2) + out = custom_conv2d(conv_input.cpu(), conv_kernel.cpu(), conv_bias.cpu(), 2) + test_result("Conv2d + Scalar Fusion Forward", res, out, rtol=1e-3, atol=1e-3) + print("Max diff > ", torch.max(torch.abs(res.cpu() - out))) + +def test_conv_relu(device, batch_size=1, in_channels=8, out_channels=16, input_size=64, kernel_size=3, stride=1, padding=0): + def custom_conv2d(a, b, bias): + i_c = a.shape[1] + o_c = b.shape[0] + conv2d = torch.nn.Conv2d(i_c, o_c, b.shape[-1], stride=stride, padding=padding, dilation=1, bias=True) + conv2d.weight = torch.nn.Parameter(b) + conv2d.bias = torch.nn.Parameter(bias) + return torch.nn.functional.relu(conv2d(a)) + torch.manual_seed(0) + conv_input = torch.randn(batch_size, in_channels, input_size, input_size).to(memory_format=torch.channels_last, device=device) + conv_kernel = torch.randn(out_channels, in_channels, kernel_size, kernel_size).to(memory_format=torch.channels_last, device=device) + conv_bias = torch.randn(out_channels).to(device=device) + opt_fn = torch.compile(dynamic=False)(custom_conv2d) + res = opt_fn(conv_input, conv_kernel, conv_bias) + out = custom_conv2d(conv_input.cpu(), conv_kernel.cpu(), conv_bias.cpu()) + test_result("Conv2d + ReLU Fusion Forward", res, out, rtol=1e-3, atol=1e-3) + print("Max diff > ", torch.max(torch.abs(res.cpu() - out))) + +def test_conv_bn_relu(device, batch_size=1, in_channels=8, out_channels=16, input_size=64, kernel_size=3, stride=1, padding=0): + def custom_conv_bn_relu(a, b, bias, c, d, e, f): + i_c = a.shape[1] + o_c = b.shape[0] + conv2d = torch.nn.Conv2d(in_channels, out_channels, b.shape[-1], stride=stride, padding=padding, dilation=1, bias=True).eval() + conv2d.weight = torch.nn.Parameter(b) + conv2d.bias = torch.nn.Parameter(bias) + # return torch.nn.functional.batch_norm(conv2d(a), c, d, weight=e, bias=f) + return torch.nn.functional.relu(torch.nn.functional.batch_norm(conv2d(a), c, d, weight=e, bias=f)) + torch.manual_seed(0) + conv_input = torch.randn(batch_size, in_channels, input_size, input_size).to(memory_format=torch.channels_last, device=device) + conv_kernel = torch.randn(out_channels, in_channels, kernel_size, kernel_size).to(memory_format=torch.channels_last, device=device) + conv_bias = torch.randn(out_channels).to(device=device) + bn_weight = torch.randn(out_channels).to(device=device) + bn_bias = torch.randn(out_channels).to(device=device) + bn_mean = torch.zeros(out_channels).to(device=device) + bn_var = torch.ones(out_channels).to(device=device) + opt_fn = torch.compile(dynamic=False)(custom_conv_bn_relu) + with torch.no_grad(): + res = opt_fn(conv_input, conv_kernel, conv_bias, bn_mean, bn_var, bn_weight, bn_bias) + out = custom_conv_bn_relu(conv_input.cpu(), conv_kernel.cpu(), conv_bias.cpu(), bn_mean.cpu(), bn_var.cpu(), bn_weight.cpu(), bn_bias.cpu()) + test_result("Conv2d + BN + ReLU Fusion Forward", res, out, rtol=1e-3, atol=1e-3) + print("Max diff > ", torch.max(torch.abs(res.cpu() - out))) + +if __name__ == "__main__": + import os + import sys + sys.path.append(os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim')) + + from Scheduler.scheduler import ExecutionEngine + module = ExecutionEngine.setup_device() + device = module.custom_device() + + # Vanila test + test_conv_residual(device, batch_size=3, in_channels=64, out_channels=64, input_size=28, kernel_size=3, stride=1, padding=1) + + # Multi-tile test + test_conv_residual(device, batch_size=1, in_channels=3, out_channels=32, input_size=32, kernel_size=3, stride=1, padding=1) + + # Single batch test + test_conv_residual(device, batch_size=1, in_channels=16, out_channels=16, input_size=64, kernel_size=3, stride=1, padding=1) + + # Scalar + test_conv_scalar(device, batch_size=1, in_channels=16, out_channels=48, input_size=48, kernel_size=3, stride=1, padding=1) + + # Relu + test_conv_relu(device, batch_size=1, in_channels=16, out_channels=16, input_size=64, kernel_size=3, stride=1, padding=1) + + # Conv + BN + ReLU + test_conv_bn_relu(device, batch_size=1, in_channels=8, out_channels=16, input_size=64, kernel_size=3, stride=1, padding=1) \ No newline at end of file diff --git a/tests/Fusion/test_matmul_activation.py b/tests/Fusion/test_matmul_activation.py index fc7960c5..2381bd8c 100644 --- a/tests/Fusion/test_matmul_activation.py +++ b/tests/Fusion/test_matmul_activation.py @@ -4,12 +4,16 @@ import torch.utils.cpp_extension def test_result(name, out, cpu_out, rtol=1e-4, atol=1e-4): - message = f"|{name} Test Passed|" if torch.allclose(out.cpu(), cpu_out, rtol=rtol, atol=atol): + message = f"|{name} Test Passed|" print("-" * len(message)) print(message) print("-" * len(message)) else: + message = f"|{name} Test Failed|" + print("-" * len(message)) + print(message) + print("-" * len(message)) print("custom out: ", out.cpu()) print("cpu out: ", cpu_out) exit(1) diff --git a/tests/Fusion/test_matmul_reduction.py b/tests/Fusion/test_matmul_reduction.py new file mode 100644 index 00000000..31ea1b0d --- /dev/null +++ b/tests/Fusion/test_matmul_reduction.py @@ -0,0 +1,97 @@ +import torch +import torch._dynamo +import torch.utils.cpp_extension + +def test_result(name, out, cpu_out, rtol=1e-4, atol=1e-4): + if torch.allclose(out.cpu(), cpu_out, rtol=rtol, atol=atol): + message = f"|{name} Test Passed|" + print("-" * len(message)) + print(message) + print("-" * len(message)) + else: + message = f"|{name} Test Failed|" + print("-" * len(message)) + print(message) + print("-" * len(message)) + print("custom out: ", out.cpu()) + print("cpu out: ", cpu_out) + exit(1) + +def test_matmul_reduce(device, M=512, N=512, K=512): + def matmul_fused(a, b): + result = torch.matmul(a, b) + return result, result.max(dim=-2).values + torch.manual_seed(0) + input = torch.randn(M, K) + weight = torch.randn(K, N) + #input = torch.arange(1, M * K + 1, dtype=torch.float32).reshape(M, K).to(dtype=torch.float32) + #weight = torch.eye(K, dtype=torch.float32) + x1 = input.to(device=device) + w1 = weight.to(device=device) + x2 = input.to("cpu") + w2 = weight.to("cpu") + opt_fn = torch.compile(dynamic=False)(matmul_fused) + res = opt_fn(x1, w1) + y = matmul_fused(x2, w2) + test_result("Matmul Reduction Fusion activation", res[0], y[0]) + test_result("Matmul Reduction Fusion reduction", res[1], y[1]) + +def test_matmul_var_mean(device, size=512): + def matmul_fused(a, b, c): + result = torch.matmul(a, b.T) + var, mean = torch.var_mean(result, dim=-2) + return result, var, mean + torch.manual_seed(0) + N = size + input = torch.randn(1024, 768) + weight = torch.randn(512, 768) + #input = torch.arange(1, N * N + 1, dtype=torch.float32).reshape(N, N).to(dtype=torch.float32) + #weight = torch.eye(N, dtype=torch.float32) + x1 = input.to(device=device) + w1 = weight.to(device=device) + x2 = input.to("cpu") + w2 = weight.to("cpu") + c = 7 + opt_fn = torch.compile(dynamic=False)(matmul_fused) + res = opt_fn(x1, w1, c) + y = matmul_fused(x2, w2, c) + test_result("Matmul var_mean Fusion activation", res[0], y[0]) + test_result("Matmul var_mean Fusion reduction", res[1], y[1]) + test_result("Matmul var_mean Fusion reduction", res[2], y[2]) + +def test_matmul_add_var_mean(device, M=768, N=512, K=3072): + def matmul_fused(a, b, c, d): + result = torch.matmul(a, b.T) + c.T + var, mean = torch.var_mean(result + d, dim=-2) + return result, var, mean + torch.manual_seed(0) + input = torch.randn(M, K) + weight = torch.randn(N, K) + bias = torch.zeros(N, M) + residual = torch.randn(M,N) + x1 = input.to(device=device) + w1 = weight.to(device=device) + b1 = bias.to(device=device) + r1 = residual.to(device=device) + x2 = input.to("cpu") + w2 = weight.to("cpu") + b2 = bias.to("cpu") + r2 = residual.to("cpu") + opt_fn = torch.compile(dynamic=False)(matmul_fused) + res = opt_fn(x1, w1, b1, r1) + y = matmul_fused(x2, w2, b2, r2) + test_result("Matmul+residual+var_mean Fusion activation", res[0], y[0]) + test_result("Matmul+residual+var_mean Fusion reduction", res[1], y[1]) + test_result("Matmul+residual+var_mean Fusion reduction", res[2], y[2]) + +if __name__ == "__main__": + import os + import sys + sys.path.append(os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim')) + + from Scheduler.scheduler import ExecutionEngine + module = ExecutionEngine.setup_device() + device = module.custom_device() + test_matmul_reduce(device, 3072, 512, 768) + test_matmul_var_mean(device) + test_matmul_add_var_mean(device) diff --git a/tests/Fusion/test_matmul_scalar.py b/tests/Fusion/test_matmul_scalar.py index b29f37f8..0dcb54f9 100644 --- a/tests/Fusion/test_matmul_scalar.py +++ b/tests/Fusion/test_matmul_scalar.py @@ -3,12 +3,16 @@ import torch.utils.cpp_extension def test_result(name, out, cpu_out, rtol=1e-4, atol=1e-4): - message = f"|{name} Test Passed|" if torch.allclose(out.cpu(), cpu_out, rtol=rtol, atol=atol): + message = f"|{name} Test Passed|" print("-" * len(message)) print(message) print("-" * len(message)) else: + message = f"|{name} Test Failed|" + print("-" * len(message)) + print(message) + print("-" * len(message)) print("custom out: ", out.cpu()) print("cpu out: ", cpu_out) exit(1) diff --git a/tests/Fusion/test_prologue_fusion.py b/tests/Fusion/test_prologue_fusion.py new file mode 100644 index 00000000..797f9e76 --- /dev/null +++ b/tests/Fusion/test_prologue_fusion.py @@ -0,0 +1,97 @@ +import torch +import torch._dynamo +import torch.utils.cpp_extension + +def test_result(name, out, cpu_out, rtol=1e-4, atol=1e-4): + if torch.allclose(out.cpu(), cpu_out, rtol=rtol, atol=atol): + message = f"|{name} Test Passed|" + print("-" * len(message)) + print(message) + print("-" * len(message)) + else: + message = f"|{name} Test Failed|" + print("-" * len(message)) + print(message) + print("-" * len(message)) + print("custom out: ", out.cpu()) + print("cpu out: ", cpu_out) + exit(1) + +def test_elem_broadcast_fusion(device): + def matmul_fused(a, b, c): + return torch.matmul(c * a, b) + torch.manual_seed(0) + input = torch.randn(128, 128) + weight = torch.randn(128, 128) + c = torch.randn(128, 1, dtype=torch.float32) + x1 = input.to(device=device) + w1 = weight.to(device=device) + c1 = c.to(device=device) + x2 = input.to("cpu") + w2 = weight.to("cpu") + c2 = c.to("cpu") + opt_fn = torch.compile(dynamic=False)(matmul_fused) + res = opt_fn(x1, w1, c1) + y = matmul_fused(x2, w2, c2) + test_result("Matmul Scalar Fusion Forward", res, y) + +def test_elem_fusion(device): + def matmul_fused(a, b, c): + return torch.matmul(c * a, b) + torch.manual_seed(0) + input = torch.randn(128, 128) + weight = torch.randn(128, 128) + c = torch.randn(128, 128, dtype=torch.float32) + x1 = input.to(device=device) + w1 = weight.to(device=device) + c1 = c.to(device=device) + x2 = input.to("cpu") + w2 = weight.to("cpu") + c2 = c.to("cpu") + opt_fn = torch.compile(dynamic=False)(matmul_fused) + res = opt_fn(x1, w1, c1) + y = matmul_fused(x2, w2, c2) + test_result("Matmul Element-wise Fusion Forward", res, y) + +def test_elem_bmm_weight_fusion(device, batch_size=1, m=512, n=512, k=64): + def bmm(a, b, c, d): + return torch.bmm(a , (d+b)*c) + torch.manual_seed(0) + a = torch.randn(batch_size, m, k).to(device=device) + b = torch.randn(batch_size, 1, n).to(device=device) + c = torch.randn(batch_size, 1, n) + c = c.to(device=device) + d = torch.randn(batch_size, k, n).to(device=device) + opt_fn = torch.compile(dynamic=False)(bmm) + res = opt_fn(a, b, c, d) + out = bmm(a.cpu(), b.cpu(), c.cpu(), d.cpu()) + print(torch.max(torch.abs(res.cpu() - out))) + test_result("BMM Element-wise Fusion Forward", res, out) + +def test_elem_bmm_input_fusion(device, batch_size=1, m=512, n=512, k=64): + def bmm(a, b, c, d): + return torch.bmm((a+b)*c , d) + torch.manual_seed(0) + a = torch.randn(batch_size, m, k).to(device=device) + b = torch.randn(batch_size, 1, k).to(device=device) + c = torch.randn(batch_size, 1, k) + c = c.to(device=device) + d = torch.randn(batch_size, k, n).to(device=device) + opt_fn = torch.compile(dynamic=False)(bmm) + res = opt_fn(a, b, c, d) + out = bmm(a.cpu(), b.cpu(), c.cpu(), d.cpu()) + print(torch.max(torch.abs(res.cpu() - out))) + test_result("BMM Element-wise Fusion Forward", res, out) + +if __name__ == "__main__": + import os + import sys + sys.path.append(os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim')) + + from Scheduler.scheduler import ExecutionEngine + module = ExecutionEngine.setup_device() + device = module.custom_device() + test_elem_broadcast_fusion(device) + test_elem_fusion(device) + test_elem_bmm_input_fusion(device, batch_size=4, m=512, n=512, k=64) + test_elem_bmm_weight_fusion(device, batch_size=12, m=512, n=512, k=64) \ No newline at end of file diff --git a/tests/Fusion/test_transformer_fusion.py b/tests/Fusion/test_transformer_fusion.py new file mode 100644 index 00000000..0e500b5b --- /dev/null +++ b/tests/Fusion/test_transformer_fusion.py @@ -0,0 +1,213 @@ +import math +import copy +import torch +import torch._dynamo +import torch.utils.cpp_extension + +def test_result(name, out, cpu_out, rtol=1e-4, atol=1e-4): + if torch.allclose(out.cpu(), cpu_out, rtol=rtol, atol=atol): + message = f"|{name} Test Passed|" + print("-" * len(message)) + print(message) + print("-" * len(message)) + else: + message = f"|{name} Test Failed|" + print("-" * len(message)) + print(message) + print("-" * len(message)) + print("custom out: ", out.cpu()) + print("cpu out: ", cpu_out) + exit(1) + +def clones(module, N): + "Produce N identical layers." + return torch.nn.ModuleList([copy.deepcopy(module) for _ in range(N)]) + +class my_MultiheadAttention_origin(torch.nn.Module): + def __init__(self, h, d_model, dropout=0.1): + super(my_MultiheadAttention_origin, self).__init__() + assert d_model % h == 0 + # We assume d_v always equals d_k + self.d_k = d_model // h + self.h = h + self.linears = clones(torch.nn.Linear(d_model, d_model), 4) + self.attn = None + + def forward(self, query, key, value): + # 1) Do all the linear projections in batch from d_model => h x d_k + query, key, value = [ + lin(x).view(-1, self.h, self.d_k).transpose(0, 1) + for lin, x in zip(self.linears, (query, key, value)) + ] + + # 2) Apply attention on all the projected vectors in batch. + scores = torch.matmul(key, query.transpose(-2, -1)) / math.sqrt(self.d_k) + p_attn = scores.softmax(dim=-2) + x = torch.matmul(value.transpose(-1, -2), p_attn) + # 3) "Concat" using a view and apply a final linear. + x = ( + x.view(-1, self.h * self.d_k) + ) + del query + del key + del value + return self.linears[-1](x) + +class EncoderBlock_origin(torch.nn.Module): + def __init__(self, embed_dim, num_heads): + super(EncoderBlock_origin, self).__init__() + self.multihead_attn = my_MultiheadAttention_origin(num_heads, embed_dim) + self.layer_norm = torch.nn.LayerNorm(embed_dim) + self.ffn1 = torch.nn.Linear(embed_dim, embed_dim*4) + self.act = torch.nn.ReLU() + self.ffn2 = torch.nn.Linear(embed_dim*4, embed_dim) + + def forward(self, x): + result = self.multihead_attn(x, x, x).reshape(x.shape) + result = self.layer_norm(result+x) + + ffn1_result = self.ffn1(result) + act_result = self.act(ffn1_result) + ffn2_result = self.ffn2(act_result) + return self.layer_norm(ffn2_result + result) + +class my_MultiheadAttention(torch.nn.Module): + def __init__(self, h, d_model, dropout=0.1): + super(my_MultiheadAttention, self).__init__() + assert d_model % h == 0 + # We assume d_v always equals d_k + self.d_k = d_model // h + self.h = h + self.linears = clones(torch.nn.Linear(d_model, d_model), 3) + self.attn = None + + def forward(self, query, key, value): + # 1) Do all the linear projections in batch from d_model => h x d_k + query, key, value = [ + lin(x).view(-1, self.h, self.d_k).transpose(0, 1) + for lin, x in zip(self.linears, (query, key, value)) + ] + + # 2) Apply attention on all the projected vectors in batch. + scores = torch.matmul(key, query.transpose(-2, -1)) / math.sqrt(self.d_k) + p_attn = scores.softmax(dim=-2) + x = torch.matmul(value.transpose(-1, -2), p_attn) + # 3) "Concat" using a view and apply a final linear. + x = ( + x.view(-1, self.h * self.d_k) + ) + del query + del key + del value + return x + +class custom_MatmulLayerNorm(torch.nn.Module): + def __init__(self, hidden_size, output_size): # (512, 3072, 768) + super(custom_MatmulLayerNorm, self).__init__() + self.weight = torch.nn.Parameter(torch.randn(output_size, hidden_size)) # (768, 3072) + self.bias = torch.nn.Parameter(torch.randn(output_size)) # (768) + self.layer_norm = torch.nn.LayerNorm(output_size) # 768 + def forward(self, x, residual): + out = torch.matmul(self.weight, x.transpose(-1, -2)) + self.bias[:, None] # (1, 768, 512) + return self.layer_norm(out.transpose(-1, -2) + residual) + +class EncoderBlock(torch.nn.Module): + def __init__(self, embed_dim, num_heads): + super(EncoderBlock, self).__init__() + self.multihead_attn = my_MultiheadAttention(num_heads, embed_dim) + self.layer_norm = torch.nn.LayerNorm(embed_dim) + self.ffn1 = torch.nn.Linear(embed_dim, embed_dim*4) + self.act = torch.nn.ReLU() + self.ffn2 = torch.nn.Linear(embed_dim*4, embed_dim) + self.matmulln1 = custom_MatmulLayerNorm(embed_dim, embed_dim) + self.matmulln2 = custom_MatmulLayerNorm(embed_dim*4, embed_dim) + + def forward(self, x): + result = self.multihead_attn(x, x, x) + result = self.matmulln1(result, x) + + ffn1_result = self.ffn1(result) + act_result = self.act(ffn1_result) + return self.matmulln2(act_result, result) + +def test_EncoderBlock(device, head=12, embed_dim=768, input_seq=512): + cpu_query = torch.randn(input_seq, embed_dim) + encoder_block = EncoderBlock(embed_dim, head) + cpu_res = encoder_block(cpu_query) + + query = cpu_query.clone().to(device=device) + encoder_block.to(device=device) + with torch.no_grad(): + opt_fn = torch.compile(dynamic=False)(encoder_block) + res = opt_fn(query) + + test_result("Encoder Block Forwrad", res, cpu_res) + +def test_Attention(device, head=16, seq=512, d_k=64): + def attention(query, key, value): + import math + d_k = query.size(-1) + scores = torch.matmul(key, query.transpose(-2, -1)) / math.sqrt(d_k) + p_attn = scores.softmax(dim=-2) + return torch.matmul(value.transpose(-1, -2), p_attn) + + torch.manual_seed(0) + query = torch.randn(head, seq, d_k).to(device=device) + key = torch.randn(head, seq, d_k).to(device=device) + value = torch.randn(head, seq, d_k).to(device=device) + + opt_fn = torch.compile(dynamic=False)(attention) + res = opt_fn(query, key, value) + + cpu_res = attention(query.cpu(), key.cpu(), value.cpu()) + test_result("Attention Forward", res, cpu_res) + +def test_MHA(device, num_heads=12, embed_dim=768, input_seq=512): + MHA = my_MultiheadAttention(num_heads, embed_dim) + cpu_query = torch.randn(input_seq, embed_dim) + with torch.no_grad(): + cpu_res = MHA(cpu_query, cpu_query, cpu_query) + query = cpu_query.clone().to(device=device) + MHA.to(device=device) + opt_fn = torch.compile(dynamic=False)(MHA) + res = opt_fn(query, query, query) + + test_result("MHA Forward", res, cpu_res) + +def test_EncoderBlock_validation(head=12, embed_dim=768, input_seq=512): + bert_origin = EncoderBlock_origin(embed_dim, head) + bert = EncoderBlock(embed_dim, head) + + bert.multihead_attn.linears[0].weight = bert_origin.multihead_attn.linears[0].weight + bert.multihead_attn.linears[0].bias = bert_origin.multihead_attn.linears[0].bias + bert.multihead_attn.linears[1].weight = bert_origin.multihead_attn.linears[1].weight + bert.multihead_attn.linears[1].bias = bert_origin.multihead_attn.linears[1].bias + bert.multihead_attn.linears[2].weight = bert_origin.multihead_attn.linears[2].weight + bert.multihead_attn.linears[2].bias = bert_origin.multihead_attn.linears[2].bias + bert.ffn1.weight = bert_origin.ffn1.weight + bert.ffn1.bias = bert_origin.ffn1.bias + bert.matmulln1.weight = torch.nn.Parameter(bert_origin.multihead_attn.linears[-1].weight) + bert.matmulln1.bias = torch.nn.Parameter(bert_origin.multihead_attn.linears[-1].bias) + bert.matmulln2.weight = torch.nn.Parameter(bert_origin.ffn2.weight) + bert.matmulln2.bias = torch.nn.Parameter(bert_origin.ffn2.bias) + + origin_query = torch.randn(input_seq, embed_dim) + query = origin_query.clone() + origin_res = bert_origin(origin_query) + res = bert(query) + + test_result("Encoder Block Validation", res, origin_res) + +if __name__ == "__main__": + import os + import sys + sys.path.append(os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim')) + + from Scheduler.scheduler import ExecutionEngine + module = ExecutionEngine.setup_device() + device = module.custom_device() + #test_MHA(device) + test_EncoderBlock(device) + # test_EncoderBlock_validation() + # test_Attention(device, head=16, seq=512, d_k=64) + # test_MHA(device, num_heads=12, embed_dim=768) diff --git a/tests/Mixtral_8x7B/model.py b/tests/Mixtral_8x7B/model.py new file mode 100644 index 00000000..4c583a0b --- /dev/null +++ b/tests/Mixtral_8x7B/model.py @@ -0,0 +1,241 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +from dataclasses import dataclass +from typing import Optional, List + +import torch +import torch.nn as nn +from torch import Tensor +from torch.nn import functional as F +def multinomial_sample_one_no_sync( + probs_sort, +): # Does multinomial sampling without a cuda synchronization + q = torch.empty_like(probs_sort).exponential_(1) + return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int) + + +def logits_to_probs(logits, temperature: float = 1.0, top_k: Optional[int] = None): + logits = logits / max(temperature, 1e-5) + + if top_k is not None: + v, _ = torch.topk(logits, min(top_k, logits.size(-1))) + pivot = v.select(-1, -1).unsqueeze(-1) + logits = torch.where(logits < pivot, -float("Inf"), logits) + probs = torch.nn.functional.softmax(logits, dim=-1) + return probs + + +def sample(logits, temperature: float = 1.0, top_k: Optional[int] = None): + probs = logits_to_probs(logits[0, -1], temperature, top_k) + idx_next = multinomial_sample_one_no_sync(probs) + return idx_next, probs + + +def find_multiple(n: int, k: int) -> int: + if n % k == 0: + return n + return n + k - (n % k) + +@dataclass +class ModelArgs: + block_size: int = 2048 + vocab_size: int = 32000 + n_layer: int = 32 + n_head: int = 32 + dim: int = 4096 + intermediate_size: int = None + n_local_heads: int = -1 + head_dim: int = 64 + rope_base: float = 10000 + norm_eps: float = 1e-5 + num_experts: int = 8 + num_activated_experts: int = 2 + + def __post_init__(self): + if self.n_local_heads == -1: + self.n_local_heads = self.n_head + if self.intermediate_size is None: + hidden_dim = 4 * self.dim + n_hidden = int(2 * hidden_dim / 3) + self.intermediate_size = find_multiple(n_hidden, 256) + self.head_dim = self.dim // self.n_head + + @classmethod + def from_name(cls, name: str): + if name in transformer_configs: + return cls(**transformer_configs[name]) + # fuzzy search + config = [config for config in transformer_configs if config in str(name).upper() or config in str(name)] + assert len(config) == 1, name + return cls(**transformer_configs[config[0]]) + + +transformer_configs = { + "Mixtral-8x7B-v0.1": dict(block_size=32768, n_layer=32, n_head=32, n_local_heads=8, dim=4096, intermediate_size=14336, rope_base=1000000.0, num_experts=8, num_activated_experts=2), +} + +class KVCache(nn.Module): + def __init__(self, max_batch_size, max_seq_length, n_heads, head_dim, dtype=torch.bfloat16): + super().__init__() + cache_shape = (max_batch_size, n_heads, 0, head_dim) + self.register_buffer('k_cache', torch.zeros(cache_shape, dtype=dtype)) + self.register_buffer('v_cache', torch.zeros(cache_shape, dtype=dtype)) + + def update(self, k_val, v_val): + self.k_cache = torch.cat([self.k_cache, k_val], dim=2) + self.v_cache = torch.cat([self.v_cache, v_val], dim=2) + + return self.k_cache, self.v_cache + +class Transformer(nn.Module): + def __init__(self, config: ModelArgs) -> None: + super().__init__() + self.config = config + + self.tok_embeddings = nn.Embedding(config.vocab_size, config.dim) + self.layers = nn.ModuleList(TransformerBlock(config) for _ in range(config.n_layer)) + self.norm = RMSNorm(config.dim, eps=config.norm_eps) + self.output = nn.Linear(config.dim, config.vocab_size, bias=False) + + self.freqs_cis: Optional[Tensor] = None + self.mask_cache: Optional[Tensor] = None + self.max_batch_size = -1 + self.max_seq_length = -1 + self.setup_caches(1, 512) + + def setup_caches(self, max_batch_size, max_seq_length): + if self.max_seq_length >= max_seq_length and self.max_batch_size >= max_batch_size: + return + head_dim = self.config.dim // self.config.n_head + max_seq_length = find_multiple(max_seq_length, 8) + self.max_seq_length = max_seq_length + self.max_batch_size = max_batch_size + #for b in self.layers: + # b.attention.kv_cache = KVCache(max_batch_size, max_seq_length, self.config.n_local_heads, head_dim) + + def forward(self, x: Tensor, mask, freqs_cis: Tensor, input_pos: Optional[Tensor] = None, kv_cache: List[KVCache] = None) -> Tensor: + for i, layer in enumerate(self.layers): + x = layer(x, input_pos, freqs_cis, mask, kv_cache[i]) + x = self.norm(x) + logits = self.output(x) + return logits + + @classmethod + def from_name(cls, name: str): + return cls(ModelArgs.from_name(name)) + + +class TransformerBlock(nn.Module): + def __init__(self, config: ModelArgs) -> None: + super().__init__() + self.attention = Attention(config) + self.ffn = FeedForward(config) + self.ffn_norm = RMSNorm(config.dim, config.norm_eps) + self.attention_norm = RMSNorm(config.dim, config.norm_eps) + + def forward(self, x: Tensor, input_pos: Tensor, freqs_cis: Tensor, mask: Tensor, kv_cache: KVCache = None) -> Tensor: + h = x + self.attention(self.attention_norm(x), freqs_cis, mask, input_pos, kv_cache) + out = h + self.ffn(self.ffn_norm(h)) + return out + +class Attention(nn.Module): + def __init__(self, config: ModelArgs): + super().__init__() + assert config.dim % config.n_head == 0 + + total_head_dim = (config.n_head + 2 * config.n_local_heads) * config.head_dim + # key, query, value projections for all heads, but in a batch + self.wqkv = nn.Linear(config.dim, total_head_dim, bias=False) + self.wo = nn.Linear(config.dim, config.dim, bias=False) + + self.n_head = config.n_head + self.head_dim = config.head_dim + self.n_local_heads = config.n_local_heads + self.dim = config.dim + + def forward(self, x: Tensor, freqs_cis: Tensor, mask: Tensor, input_pos: Optional[Tensor] = None, kv_cache: KVCache = None) -> Tensor: + bsz, seqlen, _ = x.shape + + kv_size = self.n_local_heads * self.head_dim + q, k, v = self.wqkv(x).split([self.dim, kv_size, kv_size], dim=-1) + + q = q.view(bsz, seqlen, self.n_head, self.head_dim) + k = k.view(bsz, seqlen, self.n_local_heads, self.head_dim) + v = v.view(bsz, seqlen, self.n_local_heads, self.head_dim) + + # Todo. + if freqs_cis is not None: + q = apply_rotary_emb(q, freqs_cis) + k = apply_rotary_emb(k, freqs_cis) + + q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v)) + + if kv_cache is not None: + k, v = kv_cache.update(k, v) + + k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1) + v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1) + y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0) + + y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim) + + y = self.wo(y) + return y + + +class FeedForward(nn.Module): + def __init__(self, config): + super().__init__() + self.w1 = nn.Parameter(torch.empty(config.intermediate_size, config.dim)) + self.w2 = nn.Parameter(torch.empty(config.dim, config.intermediate_size)) + self.w3 = nn.Parameter(torch.empty(config.intermediate_size, config.dim)) + + def forward(self, x) -> Tensor: + x1 = F.silu(torch.einsum('bti,oi -> bto', x, self.w1)) + x3 = torch.einsum('bti, oi -> bto', x, self.w3) + out = torch.einsum('bto, io -> bti', (x1 * x3), self.w2) + return out + +class RMSNorm(nn.Module): + def __init__(self, dim: int, eps: float = 1e-5): + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim)) + + def _norm(self, x): + return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps) + + def forward(self, x: Tensor) -> Tensor: + output = self._norm(x.float()).type_as(x) + return output * self.weight + + +def precompute_freqs_cis( + seq_len: int, n_elem: int, base: int = 10000 +) -> Tensor: + freqs = 1.0 / (base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem)) + t = torch.arange(seq_len, device=freqs.device) + freqs = torch.outer(t, freqs) + freqs_cis = torch.polar(torch.ones_like(freqs), freqs) + cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1) + return cache.to(dtype=torch.bfloat16) + +def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor: + # FIXME. This is dummy rotary embedding + return x*freqs_cis + +def apply_rotary_emb2(x: Tensor, freqs_cis: Tensor) -> Tensor: + xshaped = x.reshape(*x.shape[:-1], -1, 2) + x_out2 = torch.stack( + [ + xshaped[..., 0] * freqs_cis[..., 0] - xshaped[..., 1] * freqs_cis[..., 1], + xshaped[..., 1] * freqs_cis[..., 0] + xshaped[..., 0] * freqs_cis[..., 1], + ], + -1, + ) + + x_out2 = x_out2.flatten(3) + return x_out2.type_as(x) diff --git a/tests/Mixtral_8x7B/test_attention.py b/tests/Mixtral_8x7B/test_attention.py new file mode 100644 index 00000000..aa1af651 --- /dev/null +++ b/tests/Mixtral_8x7B/test_attention.py @@ -0,0 +1,173 @@ +import copy +import torch +import torch._dynamo +import torch.utils.cpp_extension +from model import Transformer, TransformerBlock, ModelArgs, Attention, FeedForward, KVCache, RMSNorm, precompute_freqs_cis, sample + +def test_result(name, out, cpu_out, rtol=1e-4, atol=1e-4): + if torch.allclose(out.cpu(), cpu_out, rtol=rtol, atol=atol): + message = f"|{name} Test Passed|" + print("-" * len(message)) + print(message) + print("-" * len(message)) + else: + message = f"|{name} Test Failed|" + print("-" * len(message)) + print(message) + print("-" * len(message)) + print("custom out: ", out.cpu()) + print("cpu out: ", cpu_out) + exit(1) + +def test_decode(device, prompt_length, nr_tokens): + # Setup model & model args + args = ModelArgs() + args.n_head = 8 + args.n_local_heads = -1 + args.intermediate_size = None + args.dim = 512 + args.n_layer = 1 + args.__post_init__() + max_batch = 1 + max_seq = 512 + head_dim = args.dim // args.n_head + model = Transformer(args) + model.setup_caches(max_batch, max_seq) + model = model.to(device=device) + + # Prepare inputs + T = prompt_length + prompt = torch.randn([1, T, args.dim] , dtype=torch.float32) + cpu_prompt = copy.deepcopy(prompt) + cpu_model = copy.deepcopy(model).to("cpu") + opt_fn = torch.compile(dynamic=False)(model) + + # Prepare KV cache + kv_caches = [KVCache(max_batch, max_seq, args.n_head, head_dim, torch.float32) for i in range(args.n_layer)] + cpu_kv_caches = copy.deepcopy(kv_caches) + kv_caches = [kv.to(device=device) for kv in kv_caches] + + for i in range(nr_tokens): + input_pos = torch.arange(0, T) + mask = torch.tril(torch.ones(T, T, dtype=torch.bool)) + freqs_cis = precompute_freqs_cis(args.block_size, args.dim // args.n_head, args.rope_base)[input_pos].to(dtype=torch.float32) + prompt = prompt.to(device=device) + cpu_input_pos = copy.deepcopy(input_pos) + input_pos = input_pos.to(device=device) + cpu_mask = copy.deepcopy(mask) + mask = mask.to(device=device) + + freqs_cis = freqs_cis.view(1, T, 1, -1) + cpu_freqs_cis = copy.deepcopy(freqs_cis) + freqs_cis = freqs_cis.to(device=device) + + # Run models + res = opt_fn(prompt, mask, freqs_cis, input_pos, kv_caches) + cpu_res = cpu_model(cpu_prompt, cpu_mask, cpu_freqs_cis, cpu_input_pos, cpu_kv_caches) + new_token = sample(cpu_res.cpu())[0] + print(new_token) + new_token = cpu_model.tok_embeddings(new_token).unsqueeze(1) + cpu_prompt = new_token #torch.cat([cpu_prompt, new_token], dim=1) + prompt = cpu_prompt.clone() + T = 1 + + # Check output token + test_result("Mistral", res, cpu_res) + +def test_attention(device): + args = ModelArgs() + args.n_head = 8 + args.n_local_heads = -1 + args.intermediate_size = None + args.dim = 512 + args.__post_init__() + model = Attention(args) + model = model.to(device=device) + + T = 32 + prompt = torch.randn([1, T, args.dim] , dtype=torch.float32) + input_pos = torch.arange(0, T) + cpu_prompt = copy.deepcopy(prompt) + prompt = prompt.to(device=device) + cpu_input_pos = copy.deepcopy(input_pos) + input_pos = input_pos.to(device=device) + mask = torch.tril(torch.ones(T, T, dtype=torch.bool)) + cpu_mask = copy.deepcopy(mask) + mask = mask.to(device=device) + + cpu_model = copy.deepcopy(model).to("cpu") + opt_fn = torch.compile(dynamic=False)(model) + res = opt_fn(prompt, None, mask, input_pos) + cpu_res = cpu_model(cpu_prompt, None, cpu_mask, cpu_input_pos) + test_result("Attention", res, cpu_res) + +def test_ffn(device): + args = ModelArgs() + args.n_head = 8 + args.n_local_heads = -1 + args.intermediate_size = None + args.dim = 512 + args.__post_init__() + model = FeedForward(args) + model = model.to(device=device) + + T = 32 + prompt = torch.randn([1, T, args.dim] , dtype=torch.float32) + cpu_prompt = copy.deepcopy(prompt) + prompt = prompt.to(device=device) + + cpu_model = copy.deepcopy(model).to("cpu") + opt_fn = torch.compile(dynamic=False)(model) + res = opt_fn(prompt) + cpu_res = cpu_model(cpu_prompt) + test_result("FFN", res, cpu_res) + +def test_concat(device, size1=(1, 8, 32, 64), size2=(1, 8, 1, 64), dim=2): + def concat_tensors(a, b): + return torch.cat((a, b), dim=dim) + + x = torch.randn(size1) + y = torch.randn(size2) + cpu_x = x.clone() + cpu_y = y.clone() + x = x.to(device=device) + y = y.to(device=device) + + opt_fn = torch.compile(dynamic=False)(concat_tensors) + res = opt_fn(x, y) + out = concat_tensors(cpu_x, cpu_y) + + test_result("ConcatTensors", res, out) + +def test_rmsnorm(device, seq=32): + dim = 512 + eps = 1e-5 + T = seq + rmsnorm = RMSNorm(dim=dim, eps=eps) + rmsnorm = rmsnorm.to(device=device) + + x = torch.randn([1, T, dim], dtype=torch.float32) + cpu_x = copy.deepcopy(x) + x = x.to(device) + + cpu_model = copy.deepcopy(rmsnorm).to("cpu") + opt_fn = torch.compile(dynamic=False)(rmsnorm) + + res = opt_fn(x) + cpu_res = cpu_model(cpu_x) + + test_result("RMSNorm", res, cpu_res) + +if __name__ == "__main__": + import os + import sys + sys.path.append(os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim')) + + from Scheduler.scheduler import ExecutionEngine + module = ExecutionEngine.setup_device() + device = module.custom_device() + test_rmsnorm(device, seq=1) + test_concat(device, size1=(1, 8, 64, 64), size2=(1,8,1,64), dim=2) + test_decode(device, 32, 3) + #test_attention(device) + #test_ffn(device) diff --git a/tests/MoE/test_moe.py b/tests/MoE/test_moe.py index ff6dd00b..c5ab8107 100644 --- a/tests/MoE/test_moe.py +++ b/tests/MoE/test_moe.py @@ -1,12 +1,7 @@ # Owner(s): ["module: inductor"] import os -import shutil import sys -import time -import contextlib -import unittest import copy -import numpy as np import matplotlib.pyplot as plt @@ -341,7 +336,7 @@ def forward(self, x, loss_coef=1e-2): expert_inputs = dispatcher.dispatch(x) gates = dispatcher.expert_to_gates() expert_outputs = [self.experts[i](expert_inputs[i]) for i in range(self.num_experts)] - y = dispatcher.combine(expert_outputs, multiply_by_gates=False) + y = dispatcher.combine(expert_outputs, multiply_by_gates=True) return y, loss @torch.compiler.disable(recursive=True) @@ -420,15 +415,15 @@ def test_moe(device): x1 = copy.deepcopy(X).to(device=device) x2 = copy.deepcopy(X).to("cpu") - # model.train() - model.eval() + model.train() + # model.eval() model_device = model.to(device=device) opt_model = torch.compile(model_device, dynamic=False) y_hat, aux_loss = opt_model(x1) print("MoE Custom Device Done!") - # model_cpu.train() - model_cpu.eval() + model_cpu.train() + # model_cpu.eval() cpu_hat, cpu_aux_loss = model_cpu(x2) test_result("MoE Forward", y_hat, cpu_hat) test_result("MoE Aux Loss", aux_loss, cpu_aux_loss) @@ -453,15 +448,15 @@ def test_moe(device): total_cpu_loss.backward() print("MoE Backward Done!") - print("MoE Weight Bias print") - for i in range(num_experts): - print(f"\nExpert {i}") - print(f"FC1 Weight: {model.experts[i].fc1.weight.cpu()}") - print(f"FC1 Bias: {model.experts[i].fc1.bias.cpu()}") - print("\n") - print(f"FC2 Weight: {model.experts[i].fc2.weight.cpu()}") - print(f"FC2 Bias: {model.experts[i].fc2.bias.cpu()}") - print("\n") + # print("MoE Weight Bias print") + # for i in range(num_experts): + # print(f"\nExpert {i}") + # print(f"FC1 Weight: {model.experts[i].fc1.weight.cpu()}") + # print(f"FC1 Bias: {model.experts[i].fc1.bias.cpu()}") + # print("\n") + # print(f"FC2 Weight: {model.experts[i].fc2.weight.cpu()}") + # print(f"FC2 Bias: {model.experts[i].fc2.bias.cpu()}") + # print("\n") print("MoE Weight Bias Grad") for i in range(num_experts): @@ -514,7 +509,7 @@ def weight_update(a, b, lr): # model.eval() model_device = model.to(device=device) opt_model = torch.compile(model_device, dynamic=False) - opt_w = torch.compile()(weight_update, dynamic=False) + # opt_w = torch.compile()(weight_update, dynamic=False) y_hat, aux_loss = opt_model(x1) print("MoE Custom Device Done!") diff --git a/tests/test_activation.py b/tests/test_activation.py new file mode 100644 index 00000000..de3542c3 --- /dev/null +++ b/tests/test_activation.py @@ -0,0 +1,101 @@ +import torch +import torch._dynamo +import torch.utils.cpp_extension +import torch.nn.functional as F + +def test_result(name, out, cpu_out, rtol=1e-4, atol=1e-4): + if torch.allclose(out.cpu(), cpu_out, rtol=rtol, atol=atol): + message = f"|{name} Test Passed|" + print("-" * len(message)) + print(message) + print("-" * len(message)) + else: + message = f"|{name} Test Failed|" + print("-" * len(message)) + print(message) + print("-" * len(message)) + print("custom out: ", out.cpu()) + print("cpu out: ", cpu_out) + exit(1) + +def test_ReLU(device, size=(128, 128)): + torch.manual_seed(0) + input = torch.randn(size) + x1 = input.to(device=device) + x2 = input.to("cpu") + opt_fn = torch.compile(dynamic=False)(torch.nn.functional.relu) + y = opt_fn(x1) + cpu_y = torch.nn.functional.relu(x2) + test_result("ReLU", y, cpu_y) + +def test_GeLU(device, size=(128, 128), approximate='none'): + torch.manual_seed(0) + input = torch.randn(size) + x1 = input.to(device=device) + x2 = input.to("cpu") + GeLU = torch.nn.GELU(approximate=approximate) + opt_fn = torch.compile(dynamic=False)(GeLU) + y = opt_fn(x1) + cpu_y = GeLU(x2) + test_result("GeLU", y, cpu_y) + +def test_sigmoid(device, size=(128, 128)): + torch.manual_seed(0) + input = torch.randn(size) + x1 = input.to(device=device) + x2 = input.to("cpu") + Sigmoid = torch.nn.Sigmoid() + opt_fn = torch.compile(dynamic=False)(Sigmoid) + y = opt_fn(x1) + cpu_y = Sigmoid(x2) + test_result("Sigmoid", y, cpu_y) + +def test_SiLU(device, size=(128, 128)): + torch.manual_seed(0) + input = torch.randn(size) + x1 = input.to(device=device) + x2 = input.to("cpu") + SiLU = torch.nn.SiLU() + opt_fn = torch.compile(dynamic=False)(SiLU) + y = opt_fn(x1) + cpu_y = SiLU(x2) + test_result("SiLU", y, cpu_y) + +class SwiGLU(torch.nn.Module): + def forward(self, x): + x, gate = x.chunk(2, dim=-1) + return F.silu(gate) * x + +def test_SwiGLU(device, size=(128, 128)): + torch.manual_seed(0) + input = torch.randn(size) + x1 = input.to(device=device) + x2 = input.to("cpu") + SwiGLU_fn = SwiGLU() + opt_fn = torch.compile(dynamic=False)(SwiGLU_fn) + y = opt_fn(x1) + cpu_y = SwiGLU_fn(x2) + test_result("SwiGLU", y, cpu_y) + +if __name__ == "__main__": + import os + import sys + import argparse + sys.path.append(os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim')) + + parser = argparse.ArgumentParser(description="Run LayerNorm test with dynamic shape") + parser.add_argument('--shape', type=str, default="(512,768)") + args = parser.parse_args() + shape = tuple(map(int, args.shape.strip('()').split(','))) + + from Scheduler.scheduler import ExecutionEngine + module = ExecutionEngine.setup_device() + device = module.custom_device() + test_ReLU(device, (47, 10)) + test_ReLU(device, (128, 128)) + test_ReLU(device, (4071, 429)) + test_sigmoid(device, (128, 128)) + test_SiLU(device, (128, 128)) + test_SwiGLU(device, (128, 128)) + test_GeLU(device, (128, 128)) + test_GeLU(device, (128, 128), approximate='tanh') diff --git a/tests/test_add.py b/tests/test_add.py index a3c2be9a..5e1ab15e 100644 --- a/tests/test_add.py +++ b/tests/test_add.py @@ -3,12 +3,16 @@ import torch.utils.cpp_extension def test_result(name, out, cpu_out, rtol=1e-4, atol=1e-4): - message = f"|{name} Test Passed|" if torch.allclose(out.cpu(), cpu_out, rtol=rtol, atol=atol): + message = f"|{name} Test Passed|" print("-" * len(message)) print(message) print("-" * len(message)) else: + message = f"|{name} Test Failed|" + print("-" * len(message)) + print(message) + print("-" * len(message)) print("custom out: ", out.cpu()) print("cpu out: ", cpu_out) exit(1) @@ -23,15 +27,42 @@ def vectoradd(a, b): out = vectoradd(x.cpu(), y.cpu()) test_result("VectorAdd", res, out) +def test_vector_scalar_add(device, size=(128, 128)): + def vectoradd(a, b): + return a + b + x = torch.randn(size).to(device=device) + y = torch.randn([1]).to(device=device) + opt_fn = torch.compile(dynamic=False)(vectoradd) + res = opt_fn(x, y) + out = vectoradd(x.cpu(), y.cpu()) + test_result("VectorScalarAdd", res, out) + +def test_vector_tensor_add(device, size=(128, 128)): + def vectoradd(a, b): + return a + b + x = torch.randn(size).to(device=device) + y = torch.randn(size[-1]).to(device=device) + opt_fn = torch.compile(dynamic=False)(vectoradd) + res = opt_fn(x, y) + out = vectoradd(x.cpu(), y.cpu()) + test_result("VectorTensorAdd", res, out) if __name__ == "__main__": import os import sys + import argparse sys.path.append(os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim')) + parser = argparse.ArgumentParser(description="Run LayerNorm test with dynamic shape") + parser.add_argument('--shape', type=str, default="(512,768)") + args = parser.parse_args() + shape = tuple(map(int, args.shape.strip('()').split(','))) + from Scheduler.scheduler import ExecutionEngine module = ExecutionEngine.setup_device() device = module.custom_device() + test_vectoradd(device, (1, 1)) test_vectoradd(device, (47, 10)) test_vectoradd(device, (128, 128)) test_vectoradd(device, (4071, 429)) + test_vector_tensor_add(device, (128, 128)) diff --git a/tests/test_batchnorm.py b/tests/test_batchnorm.py index bb8d529f..f7abacf5 100644 --- a/tests/test_batchnorm.py +++ b/tests/test_batchnorm.py @@ -3,12 +3,16 @@ import torch.utils.cpp_extension def test_result(name, out, cpu_out, rtol=1e-4, atol=1e-4): - message = f"|{name} Test Passed|" if torch.allclose(out.cpu(), cpu_out, rtol=rtol, atol=atol): + message = f"|{name} Test Passed|" print("-" * len(message)) print(message) print("-" * len(message)) else: + message = f"|{name} Test Failed|" + print("-" * len(message)) + print(message) + print("-" * len(message)) print("custom out: ", out.cpu()) print("cpu out: ", cpu_out) exit(1) @@ -16,10 +20,12 @@ def test_result(name, out, cpu_out, rtol=1e-4, atol=1e-4): def test_BatchNorm(device, size=(1, 16, 64, 64)): torch.manual_seed(0) model = torch.nn.BatchNorm2d(size[1]).eval() - model.to(device=device) - input = torch.randn(size) - x1 = input.to(device=device) - x2 = input.to("cpu") + model.to(device=device, memory_format=torch.channels_last) + input = torch.empty_strided(size, (size[1]*size[2]*size[3], 1, size[1], size[1]*size[2])) + input.uniform_(-1, 1) + + x1 = input.to(device=device, memory_format=torch.channels_last) + x2 = input.to("cpu", memory_format=torch.channels_last) opt_fn = torch.compile(dynamic=False)(model) y = opt_fn(x1) cpu_model = model.to("cpu") @@ -35,3 +41,6 @@ def test_BatchNorm(device, size=(1, 16, 64, 64)): module = ExecutionEngine.setup_device() device = module.custom_device() test_BatchNorm(device) + test_BatchNorm(device, size=(1,64, 32, 32)) + test_BatchNorm(device, size=(1, 8, 4, 4)) + test_BatchNorm(device, size=(1,256, 32, 32)) diff --git a/tests/test_bmm.py b/tests/test_bmm.py index 483980d8..6d9279aa 100644 --- a/tests/test_bmm.py +++ b/tests/test_bmm.py @@ -3,12 +3,16 @@ import torch.utils.cpp_extension def test_result(name, out, cpu_out, rtol=1e-4, atol=1e-4): - message = f"|{name} Test Passed|" if torch.allclose(out.cpu(), cpu_out, rtol=rtol, atol=atol): + message = f"|{name} Test Passed|" print("-" * len(message)) print(message) print("-" * len(message)) else: + message = f"|{name} Test Failed|" + print("-" * len(message)) + print(message) + print("-" * len(message)) print("custom out: ", out.cpu()) print("cpu out: ", cpu_out) exit(1) @@ -24,6 +28,19 @@ def bmm(a, b): out = bmm(a.cpu(), b.cpu()) test_result("BMM Forward", res, out) +def test_addBMM(device, batch_size=1, m=32, n=16, k=64, bias_rank=1):#TODO: Fusion should be implemented for this test + def bmm(a, b, bias): + return torch.bmm(a, b.transpose(1, 2)) + bias + torch.manual_seed(0) + a = torch.randn(batch_size, m, k).to(device=device) + b = torch.randn(batch_size, n, k).to(device=device) + bias = torch.randn(batch_size, n) if bias_rank == 1 else torch.randn(batch_size, m, n) + bias = bias.to(device=device) + opt_fn = torch.compile(dynamic=False)(bmm) + res = opt_fn(a, b, bias) + out = bmm(a.cpu(), b.cpu(), bias.cpu()) + test_result("BMM Forward", res, out) + if __name__ == "__main__": import os import sys @@ -33,4 +50,9 @@ def bmm(a, b): module = ExecutionEngine.setup_device() device = module.custom_device() test_BMM(device) - test_BMM(device, 2, 512, 512, 512) + test_BMM(device, 2, 256, 128, 256) + test_BMM(device, 2, 128, 256, 256) + test_BMM(device, 2, 256, 256, 128) + test_BMM(device, 4, 256, 256, 256) + test_BMM(device, 12, 512, 512, 64) + test_BMM(device, 16, 512, 512, 64) \ No newline at end of file diff --git a/tests/test_cnn.py b/tests/test_cnn.py index 2d96fe7a..aaad2836 100644 --- a/tests/test_cnn.py +++ b/tests/test_cnn.py @@ -3,12 +3,16 @@ import torch.utils.cpp_extension def test_result(name, out, cpu_out, rtol=1e-4, atol=1e-4): - message = f"|{name} Test Passed|" if torch.allclose(out.cpu(), cpu_out, rtol=rtol, atol=atol): + message = f"|{name} Test Passed|" print("-" * len(message)) print(message) print("-" * len(message)) else: + message = f"|{name} Test Failed|" + print("-" * len(message)) + print(message) + print("-" * len(message)) print("custom out: ", out.cpu()) print("cpu out: ", cpu_out) exit(1) @@ -24,9 +28,9 @@ def __init__(self): def forward(self, x): x = self.conv1(x) + x = self.maxpool(x) x = self.norm(x) x = self.conv2(x) - # x = self.maxpool(x) x = torch.nn.functional.relu(x) return x diff --git a/tests/test_compile_overhead.py b/tests/test_compile_overhead.py new file mode 100644 index 00000000..cf0dc1bb --- /dev/null +++ b/tests/test_compile_overhead.py @@ -0,0 +1,45 @@ +import os +import time +import sys +import torch +from torchvision.models import resnet18 as model1 +import argparse +import shutil + +sys.path.append(os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim')) +from Scheduler.scheduler import Scheduler, SchedulerDNNModel, Request, poisson_request_generator +CONFIG_TORCHSIM_DIR = os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim') + +if __name__ == "__main__": + target_model1 = model1().eval() + + # Init scheduler + for i in range(1): + timestamp = time.time() # 현재 타임스탬프 (초 단위) + print(f"[{i}] Time Stamp: {timestamp:.6f}") # 소수점 6자리까지 출력 + #try: + # shutil.rmtree("/tmp/torchinductor") + #except FileNotFoundError: + # print("no cache") + scheduler = Scheduler(num_request_queue=1, max_batch=4, engine_select=Scheduler.FIFO_ENGINE, backend_config=f"{CONFIG_TORCHSIM_DIR}/PyTorchSimBackend/configs/systolic_ws_128x128_c2_simple_noc_tpuv2.json") + # Register compiled model + opt_model1 = torch.compile(target_model1.to(device=scheduler.execution_engine.module.custom_device(), memory_format=torch.channels_last), dynamic=False) + SchedulerDNNModel.register_model("resnet18", opt_model1) + + # Generate time stamp + for request_time in [0]*12: + # Init input data + model_input1 = torch.randn(1, 3, 224, 224) + + # Init request + new_request1 = Request("resnet18", [model_input1], [], request_queue_idx=0) + + # Add request to scheduler + print("[Reqest] Resnet18 request time: ", request_time, flush=True) + scheduler.add_request(new_request1, request_time=request_time) + + # Run scheduler + while not scheduler.is_finished(): + scheduler.schedule() + + print("Done", file=sys.stderr) \ No newline at end of file diff --git a/tests/test_conv2d.py b/tests/test_conv2d.py index 29924156..c679b431 100644 --- a/tests/test_conv2d.py +++ b/tests/test_conv2d.py @@ -3,12 +3,16 @@ import torch.utils.cpp_extension def test_result(name, out, cpu_out, rtol=1e-4, atol=1e-4): - message = f"|{name} Test Passed|" if torch.allclose(out.cpu(), cpu_out, rtol=rtol, atol=atol): + message = f"|{name} Test Passed|" print("-" * len(message)) print(message) print("-" * len(message)) else: + message = f"|{name} Test Failed|" + print("-" * len(message)) + print(message) + print("-" * len(message)) print("custom out: ", out.cpu()) print("cpu out: ", cpu_out) exit(1) @@ -17,7 +21,7 @@ def test_conv2d(device, batch_size=1, in_channels=8, out_channels=16, input_size def custom_conv2d(a, b, bias): i_c = a.shape[1] o_c = b.shape[0] - conv2d = torch.nn.Conv2d(i_c, o_c, b.shape[-1], stride=stride, padding=padding, dilation=1) + conv2d = torch.nn.Conv2d(i_c, o_c, b.shape[-1], stride=stride, padding=padding, dilation=1, bias=False) conv2d.weight = torch.nn.Parameter(b) conv2d.bias = torch.nn.Parameter(bias) return conv2d(a) @@ -39,4 +43,15 @@ def custom_conv2d(a, b, bias): from Scheduler.scheduler import ExecutionEngine module = ExecutionEngine.setup_device() device = module.custom_device() - test_conv2d(device, batch_size=1, in_channels=128, out_channels=128, input_size=28, kernel_size=3, stride=1, padding=1) + torch._dynamo.config.cache_size_limit = 64 + test_conv2d(device, batch_size=8, in_channels=3, out_channels=32, input_size=32, kernel_size=1, stride=1, padding=0) + test_conv2d(device, batch_size=1, in_channels=3, out_channels=64, input_size=64//2, kernel_size=7, stride=2, padding=3) + test_conv2d(device, batch_size=2, in_channels=3, out_channels=64, input_size=32//2, kernel_size=7, stride=1, padding=3) + test_conv2d(device, batch_size=4, in_channels=3, out_channels=64, input_size=64//2, kernel_size=7, stride=1, padding=3) + test_conv2d(device, batch_size=4, in_channels=3, out_channels=64, input_size=64//2, kernel_size=7, stride=1, padding=3) + test_conv2d(device, batch_size=2, in_channels=128, out_channels=256, input_size=13, kernel_size=5, stride=1, padding=2) + test_conv2d(device, batch_size=2, in_channels=128, out_channels=512, input_size=14, kernel_size=7, stride=1, padding=3) + test_conv2d(device, batch_size=1, in_channels=128, out_channels=256, input_size=14, kernel_size=3, stride=2, padding=1) + test_conv2d(device, batch_size=1, in_channels=128, out_channels=256, input_size=7, kernel_size=3, stride=2, padding=1) + test_conv2d(device, batch_size=1, in_channels=128, out_channels=256, input_size=2, kernel_size=1, stride=1, padding=0) + test_conv2d(device, batch_size=1, in_channels=128, out_channels=256, input_size=14, kernel_size=1, stride=2, padding=0) diff --git a/tests/test_relu.py b/tests/test_exponent.py similarity index 57% rename from tests/test_relu.py rename to tests/test_exponent.py index 3c3915d7..c95823cb 100644 --- a/tests/test_relu.py +++ b/tests/test_exponent.py @@ -3,25 +3,28 @@ import torch.utils.cpp_extension def test_result(name, out, cpu_out, rtol=1e-4, atol=1e-4): - message = f"|{name} Test Passed|" if torch.allclose(out.cpu(), cpu_out, rtol=rtol, atol=atol): + message = f"|{name} Test Passed|" print("-" * len(message)) print(message) print("-" * len(message)) else: + message = f"|{name} Test Failed|" + print("-" * len(message)) + print(message) + print("-" * len(message)) print("custom out: ", out.cpu()) print("cpu out: ", cpu_out) exit(1) -def test_ReLU(device, size=(128, 128)): - torch.manual_seed(0) - input = torch.randn(size) - x1 = input.to(device=device) - x2 = input.to("cpu") - opt_fn = torch.compile(dynamic=False)(torch.nn.functional.relu) - y = opt_fn(x1) - cpu_y = torch.nn.functional.relu(x2) - test_result("ReLU", y, cpu_y) +def test_exponent(device, size=(128, 128)): + def exponent(a): + return a.exp() + x = torch.randn(size).to(device=device) + opt_fn = torch.compile(dynamic=False)(exponent) + res = opt_fn(x) + out = exponent(x.cpu()) + test_result("exponent", res, out) if __name__ == "__main__": import os @@ -31,6 +34,4 @@ def test_ReLU(device, size=(128, 128)): from Scheduler.scheduler import ExecutionEngine module = ExecutionEngine.setup_device() device = module.custom_device() - test_ReLU(device, (47, 10)) - test_ReLU(device, (128, 128)) - test_ReLU(device, (4071, 429)) + test_exponent(device, size=(32, 32)) diff --git a/tests/test_hetro.py b/tests/test_hetro.py new file mode 100644 index 00000000..5e36d730 --- /dev/null +++ b/tests/test_hetro.py @@ -0,0 +1,77 @@ +import os +import sys +import torch +import argparse +sys.path.append(os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim')) +from Scheduler.scheduler import Scheduler, SchedulerDNNModel, Request +from test_stonne import sparse_matmul + +def custom_matmul(a, b): + return torch.matmul(a, b) +torch.manual_seed(0) +CONFIG_TORCHSIM_DIR = os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim') + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="") + parser.add_argument("--M", type=int, default=128, help="Batch size") + parser.add_argument("--N", type=int, default=128, help="Input layer size") + parser.add_argument("--K", type=int, default=128, help="Hidden layer size") + parser.add_argument("--sparsity", type=float, default=0.9, help="Output layer size") + parser.add_argument("--config", type=str, default="stonne_big_c1_simple_noc.json", help="Output layer size") + parser.add_argument("--mode", type=int, default=0, help="Output layer size") + args = parser.parse_args() + + M = args.M + N = args.N + K = args.K + sparsity = args.sparsity + mode = args.mode + config_path = f"{CONFIG_TORCHSIM_DIR}/PyTorchSimBackend/configs/{args.config}" + + print("M: ", M) + print("N: ", N) + print("K: ", K) + print("sparsity: ", sparsity) + + with torch.no_grad(): + # Init scheduler + scheduler = Scheduler(num_request_queue=2, engine_select=Scheduler.FIFO_ENGINE, + backend_config=config_path) + + # Register compiled model + opt_model1 = torch.compile(custom_matmul) + opt_model2 = torch.compile(sparse_matmul) + SchedulerDNNModel.register_model("matmul", opt_model1) + SchedulerDNNModel.register_model("spmm", opt_model2) + + # Init input data + for i in range(1): + dense_input1 = torch.randn(M, K) + dense_input2 = torch.randn(K, N) + + sparse_input1 = torch.randn(128, 128) + sparse_input2 = torch.randn(128, 128) + mask1 = torch.rand(sparse_input1.shape) > sparsity + mask2 = torch.rand(sparse_input2.shape) > sparsity + + sparse_input1 = sparse_input1 * mask1 + sparse_input2 = sparse_input2 * mask2 + + # Init request + if mode == 0: + new_request1 = Request("spmm", [sparse_input1, sparse_input2], [], request_queue_idx=0) + scheduler.add_request(new_request1, request_time=0) + elif mode == 1: + new_request2 = Request("matmul", [dense_input1, dense_input2], [], request_queue_idx=0) + scheduler.add_request(new_request2, request_time=0) + elif mode == 2: + new_request1 = Request("spmm", [sparse_input1, sparse_input2], [], request_queue_idx=0) + new_request2 = Request("matmul", [dense_input1, dense_input2], [], request_queue_idx=1) + + # Add request to scheduler + scheduler.add_request(new_request1, request_time=0) + scheduler.add_request(new_request2, request_time=0) + + # Run scheduler + while not scheduler.is_finished(): + scheduler.schedule() \ No newline at end of file diff --git a/tests/test_indirect_access.py b/tests/test_indirect_access.py new file mode 100644 index 00000000..b7b20074 --- /dev/null +++ b/tests/test_indirect_access.py @@ -0,0 +1,55 @@ +import torch +import copy +import torch._dynamo +import torch.utils.cpp_extension + +def test_result(name, out, cpu_out, rtol=1e-4, atol=1e-4): + if torch.allclose(out.cpu(), cpu_out, rtol=rtol, atol=atol): + message = f"|{name} Test Passed|" + print("-" * len(message)) + print(message) + print("-" * len(message)) + else: + message = f"|{name} Test Failed|" + print("-" * len(message)) + print(message) + print("-" * len(message)) + print("custom out: ", out.cpu()) + print("cpu out: ", cpu_out) + exit(1) + +def test_indirect_vectoradd(device, size=(128, 128)): + def vectoradd(a, idx, b): + return a[idx] + b + x = torch.randn(size, dtype=torch.float32).to(device=device) + idx = torch.randint(0,128, [128]).to(device=device) + y = torch.randn(128, dtype=torch.float32).to(device=device) + opt_fn = torch.compile(dynamic=False)(vectoradd) + res = opt_fn(x, idx, y) + out = vectoradd(x.cpu(), idx.cpu(), y.cpu()) + test_result("Indirect VectorAdd", res, out) + +def test_embedding(device, vocab_size, dim): + emb = torch.nn.Embedding(vocab_size, dim) + cpu_emb = copy.deepcopy(emb) + + prompt = torch.randint(0, 1023, [511], dtype=torch.int) + cpu_prompt = copy.deepcopy(prompt) + prompt = prompt.to(device=device) + + emb.to(device=device) + opt_emb = torch.compile(dynamic=False)(emb) + res = opt_emb(prompt) + cpu_res = cpu_emb(cpu_prompt) + test_result("Embedding", res, cpu_res) + +if __name__ == "__main__": + import os + import sys + sys.path.append(os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim')) + + from Scheduler.scheduler import ExecutionEngine + module = ExecutionEngine.setup_device() + device = module.custom_device() + test_indirect_vectoradd(device) + #test_embedding(device, 1024, 2048) \ No newline at end of file diff --git a/tests/test_layernorm.py b/tests/test_layernorm.py index 26f5ca17..1cea9d9f 100644 --- a/tests/test_layernorm.py +++ b/tests/test_layernorm.py @@ -3,12 +3,16 @@ import torch.utils.cpp_extension def test_result(name, out, cpu_out, rtol=1e-4, atol=1e-4): - message = f"|{name} Test Passed|" if torch.allclose(out.cpu(), cpu_out, rtol=rtol, atol=atol): + message = f"|{name} Test Passed|" print("-" * len(message)) print(message) print("-" * len(message)) else: + message = f"|{name} Test Failed|" + print("-" * len(message)) + print(message) + print("-" * len(message)) print("custom out: ", out.cpu()) print("cpu out: ", cpu_out) exit(1) @@ -29,10 +33,16 @@ def test_LayerNorm(device, size=(64, 64)): if __name__ == "__main__": import os import sys + import argparse sys.path.append(os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim')) + parser = argparse.ArgumentParser(description="Run LayerNorm test with dynamic shape") + parser.add_argument('--shape', type=str, help="Shape of the tensor in the format (batch_size, features)", default="(512,768)") + args = parser.parse_args() + shape = tuple(map(int, args.shape.strip('()').split(','))) + from Scheduler.scheduler import ExecutionEngine module = ExecutionEngine.setup_device() device = module.custom_device() - test_LayerNorm(device) - test_LayerNorm(device, (64, 128)) + #test_LayerNorm(device) + test_LayerNorm(device, shape) diff --git a/tests/test_matmul.py b/tests/test_matmul.py index 3913df5b..6f41468b 100644 --- a/tests/test_matmul.py +++ b/tests/test_matmul.py @@ -3,12 +3,16 @@ import torch.utils.cpp_extension def test_result(name, out, cpu_out, rtol=1e-4, atol=1e-4): - message = f"|{name} Test Passed|" if torch.allclose(out.cpu(), cpu_out, rtol=rtol, atol=atol): + message = f"|{name} Test Passed|" print("-" * len(message)) print(message) print("-" * len(message)) else: + message = f"|{name} Test Failed|" + print("-" * len(message)) + print(message) + print("-" * len(message)) print("custom out: ", out.cpu()) print("cpu out: ", cpu_out) exit(1) @@ -19,7 +23,6 @@ def custom_matmul(a, b): torch.manual_seed(0) input = torch.randn(input_size, hidden_size) weight = torch.randn(hidden_size, output_size) - bias = torch.randn(output_size) x1 = input.to(device=device) w1 = weight.to(device=device) x2 = input.to("cpu") @@ -29,13 +32,13 @@ def custom_matmul(a, b): y = custom_matmul(x2, w2) test_result("Matmul Forward", res, y) -def test_addmm(device, input_size=128, hidden_size=128, output_size=128): +def test_addmm(device, input_size=128, hidden_size=128, output_size=128, bias_rank=1): def custom_matmul(bias, a, b): return torch.addmm(bias, a, b) torch.manual_seed(0) input = torch.randn(input_size, hidden_size) weight = torch.randn(hidden_size, output_size) - bias = torch.randn(output_size) + bias = torch.randn(output_size) if bias_rank == 1 else torch.randn(input_size, output_size) x1 = input.to(device=device) w1 = weight.to(device=device) b1 = bias.to(device=device) @@ -47,6 +50,45 @@ def custom_matmul(bias, a, b): y = custom_matmul(b2, x2, w2) test_result("Addmm Forward", res, y) +def test_addmm2(device, input_size=128, hidden_size=128, output_size=128): + def custom_matmul(bias, a, b): + return torch.matmul(a, b) #+ bias + torch.manual_seed(0) + input = torch.randn(input_size, hidden_size) + weight = torch.randn(hidden_size, output_size) + bias = torch.randn(input_size, 1, dtype=torch.float32) + x1 = input.to(device=device) + w1 = weight.to(device=device) + b1 = bias.to(device=device) + x2 = input.to("cpu") + w2 = weight.to("cpu") + b2 = bias.to("cpu") + opt_fn = torch.compile(dynamic=False)(custom_matmul) + res = opt_fn(b1, x1, w1) + y = custom_matmul(b2, x2, w2) + test_result("Addmm2 Forward", res, y) + +def test_linear(device, input_size=128, hidden_size=128, output_size=128): + def custom_linear(a, b, bias): + linear = torch.nn.Linear(hidden_size, output_size) + linear.weight = torch.nn.Parameter(b) + linear.bias = torch.nn.Parameter(bias) + return linear(a) + torch.manual_seed(0) + input = torch.randn(input_size, hidden_size) + weight = torch.randn(output_size, hidden_size) + bias = torch.randn(output_size) + x1 = input.to(device=device) + w1 = weight.to(device=device) + b1 = bias.to(device=device) + x2 = input.to("cpu") + w2 = weight.to("cpu") + b2 = bias.to("cpu") + opt_fn = torch.compile(dynamic=False)(custom_linear) + res = opt_fn(x1, w1, b1) + y = custom_linear(x2, w2, b2) + test_result("Linear Forward", res, y) + if __name__ == "__main__": import os import sys @@ -57,7 +99,12 @@ def custom_matmul(bias, a, b): device = module.custom_device() test_matmul(device, 32, 32, 32) test_matmul(device, 128, 128, 128) - test_matmul(device, 512, 512, 512) - test_matmul(device, 129, 61, 56) - test_addmm(device, 128, 128, 128) + test_matmul(device, 256, 256, 256) + test_matmul(device, 128, 256, 256) + test_matmul(device, 128, 63, 56) + test_addmm(device, 128, 256, 512) + test_addmm(device, 128, 256, 512, bias_rank=2) test_addmm(device, 129, 61, 56) + test_addmm2(device, 129, 61, 56) + test_addmm(device, 129*4, 61*4, 56*4) + test_addmm2(device, 129*4, 61*4, 56*4) diff --git a/tests/test_mlp.py b/tests/test_mlp.py index 0582ce74..b8118aa3 100644 --- a/tests/test_mlp.py +++ b/tests/test_mlp.py @@ -4,12 +4,16 @@ import torch.utils.cpp_extension def test_result(name, out, cpu_out, rtol=1e-4, atol=1e-4): - message = f"|{name} Test Passed|" if torch.allclose(out.cpu(), cpu_out, rtol=rtol, atol=atol): + message = f"|{name} Test Passed|" print("-" * len(message)) print(message) print("-" * len(message)) else: + message = f"|{name} Test Failed|" + print("-" * len(message)) + print(message) + print("-" * len(message)) print("custom out: ", out.cpu()) print("cpu out: ", cpu_out) exit(1) @@ -79,14 +83,16 @@ def test_mlp_inf(device, batch_size=64, input_size=64, hidden_size=32, output_si def test_optimizer(device): torch.manual_seed(0) model = MLP(input_size=16, hidden_size=16, output_size=16).to(device=device) + model.requires_grad = True cpu_model = copy.deepcopy(model).to("cpu") + opt_model = torch.compile(dynamic=False)(model) optimizer = torch.optim.Adam(model.parameters(), lr=0.001) cpu_optimizer = torch.optim.Adam(cpu_model.parameters(), lr=0.001) opt_step = torch.compile(dynamic=False)(optimizer.step) input = torch.randn(16, 16) x1 = copy.deepcopy(input).to(device=device) x2 = copy.deepcopy(input).to("cpu") - y = model(x1) + y = opt_model(x1) cpu_y = cpu_model(x2) loss = y.sum() cpu_loss = cpu_y.sum() @@ -110,3 +116,4 @@ def test_optimizer(device): test_mlp_inf(device, batch_size=1, input_size=256, hidden_size=512, output_size=256) test_mlp_inf(device, batch_size=8, input_size=256, hidden_size=512, output_size=256) test_mlp_inf(device, batch_size=64, input_size=256, hidden_size=512, output_size=256) + test_optimizer(device) diff --git a/tests/test_pool.py b/tests/test_pool.py index f28becac..304a5e7c 100644 --- a/tests/test_pool.py +++ b/tests/test_pool.py @@ -3,34 +3,38 @@ import torch.utils.cpp_extension def test_result(name, out, cpu_out, rtol=1e-4, atol=1e-4): - message = f"|{name} Test Passed|" if torch.allclose(out.cpu(), cpu_out, rtol=rtol, atol=atol): + message = f"|{name} Test Passed|" print("-" * len(message)) print(message) print("-" * len(message)) else: + message = f"|{name} Test Failed|" + print("-" * len(message)) + print(message) + print("-" * len(message)) print("custom out: ", out.cpu()) print("cpu out: ", cpu_out) exit(1) -def test_maxpool(device): +def test_maxpool(device, b=1, c=64, h=112, w=112): torch.manual_seed(0) model = torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1).eval() model.to(device=device) - input = torch.randn(1, 8, 64, 64).to(device=device) + input = torch.randn(b, c, h, w).to(device=device) x1 = input.to(device=device) x2 = input.to("cpu") opt_fn = torch.compile(dynamic=False)(model) res = opt_fn(x1) model.to("cpu") out = model(x2) - # test_result("Maxpool Forward", res, out) # TODO: MaxPool Functionality is not working + test_result("Maxpool Forward", res, out) # TODO: MaxPool Functionality is not working -def test_avgpool(device): +def test_avgpool(device, b=1, c=64, h=112, w=112): def avgpool(a): return torch.nn.AdaptiveAvgPool2d((1, 1))(a) torch.manual_seed(0) - input = torch.randn(1, 16, 64, 64).to(device=device) #FIXME: channel 8 does not work (range padding issue) + input = torch.randn(b, c, h, w).to(device=device) #FIXME: channel 8 does not work (range padding issue) x1 = input.to(device=device) x2 = input.to("cpu") opt_fn = torch.compile(dynamic=False)(avgpool) @@ -46,5 +50,6 @@ def avgpool(a): from Scheduler.scheduler import ExecutionEngine module = ExecutionEngine.setup_device() device = module.custom_device() - test_maxpool(device) - test_avgpool(device) + #test_maxpool(device, b=1, c=8, h=16, w=16) + #test_maxpool(device, b=1, c=8, h=112, w=112) + test_avgpool(device, b=1, c=512, h=7, w=7) diff --git a/tests/test_reduce.py b/tests/test_reduce.py index 90a7487e..e1a84b7f 100644 --- a/tests/test_reduce.py +++ b/tests/test_reduce.py @@ -3,12 +3,16 @@ import torch.utils.cpp_extension def test_result(name, out, cpu_out, rtol=1e-4, atol=1e-4): - message = f"|{name} Test Passed|" if torch.allclose(out.cpu(), cpu_out, rtol=rtol, atol=atol): + message = f"|{name} Test Passed|" print("-" * len(message)) print(message) print("-" * len(message)) else: + message = f"|{name} Test Failed|" + print("-" * len(message)) + print(message) + print("-" * len(message)) print("custom out: ", out.cpu()) print("cpu out: ", cpu_out) exit(1) @@ -23,11 +27,26 @@ def reduce_sum(a, b, dim, keepdim): out = reduce_sum(x.cpu(), y.cpu(), dim, keepdim) test_result("ReduceSum", res, out) +def test_reduce_sum2(device, size, dim=-1, keepdim=False): + def reduce_sum(a, dim, keepdim): + return torch.sum(a, axis=dim, keepdim=keepdim) + x = torch.randn(size).to(device=device) + opt_fn = torch.compile(dynamic=False)(reduce_sum) + res = opt_fn(x, dim, keepdim) + out = reduce_sum(x.cpu(), dim, keepdim) + test_result("ReduceMax", res, out) + if __name__ == "__main__": import os import sys + import argparse sys.path.append(os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim')) + parser = argparse.ArgumentParser(description="Run LayerNorm test with dynamic shape") + parser.add_argument('--shape', type=str, default="(128,768)") + args = parser.parse_args() + shape = tuple(map(int, args.shape.strip('()').split(','))) + from Scheduler.scheduler import ExecutionEngine module = ExecutionEngine.setup_device() device = module.custom_device() @@ -35,4 +54,4 @@ def reduce_sum(a, b, dim, keepdim): test_reduce_sum(device, (17, 68), 0, keepdim=True) test_reduce_sum(device, (327, 447), 1, keepdim=True) test_reduce_sum(device, (327, 447), 0, keepdim=True) - + test_reduce_sum2(device, shape) \ No newline at end of file diff --git a/tests/test_resnet.py b/tests/test_resnet.py index 37f8a583..97c60528 100644 --- a/tests/test_resnet.py +++ b/tests/test_resnet.py @@ -1,34 +1,55 @@ +import argparse import torch import torch._dynamo import torch.utils.cpp_extension +from torchvision.models import resnet18, resnet50 def test_result(name, out, cpu_out, rtol=1e-4, atol=1e-4): - message = f"|{name} Test Passed|" if torch.allclose(out.cpu(), cpu_out, rtol=rtol, atol=atol): + message = f"|{name} Test Passed|" print("-" * len(message)) print(message) print("-" * len(message)) else: + message = f"|{name} Test Failed|" + print("-" * len(message)) + print(message) + print("-" * len(message)) print("custom out: ", out.cpu()) print("cpu out: ", cpu_out) exit(1) -def test_resnet(device): +def test_resnet(device, batch=1, model_type='resnet18'): from torchvision.models import resnet - model = resnet._resnet(resnet.BasicBlock, [1, 1, 0, 0], weights=None, progress=False).eval() - model.to(device, memory_format=torch.channels_last) - input = torch.randn(1, 3, 224, 224).to(device=device) - x1 = input.to(device=device, memory_format=torch.channels_last) - opt_fn = torch.compile(dynamic=False)(model) - res = opt_fn(x1) - print("ResNet18 Simulation Done") + with torch.no_grad(): + #model = resnet._resnet(resnet.BasicBlock, [1, 1, 1, 1], weights=None, progress=False).eval() + if model_type == 'resnet50': + model = resnet50().eval() + elif model_type == 'resnet18': + model = resnet18().eval() + else: + raise ValueError(f"Unsupported model type: {model_type}") + model.to(device, memory_format=torch.channels_last) + input = torch.randn(batch, 3, 224, 224) + x1 = input.to(device=device, memory_format=torch.channels_last) + x2 = input.cpu().to(memory_format=torch.channels_last) + opt_fn = torch.compile(dynamic=False)(model) + res = opt_fn(x1) + cpu_model = model.cpu().to(memory_format=torch.channels_last) + cpu_res = cpu_model(x2) + test_result(f"{model_type} inference", res, cpu_res) + print("Max diff > ", torch.max(torch.abs(res.cpu() - cpu_res))) + print(f"{model_type} Simulation Done") if __name__ == "__main__": import os import sys + args = argparse.ArgumentParser() + args.add_argument('--model_type', type=str, default="resnet18", help='ex) resnet18') + args = args.parse_args() sys.path.append(os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim')) from Scheduler.scheduler import ExecutionEngine module = ExecutionEngine.setup_device() device = module.custom_device() - test_resnet(device) + test_resnet(device, model_type=args.model_type) diff --git a/tests/test_scheduler.py b/tests/test_scheduler.py index 6eb2f0e0..c64093a0 100644 --- a/tests/test_scheduler.py +++ b/tests/test_scheduler.py @@ -2,16 +2,18 @@ import sys import torch from torchvision.models import resnet18 as model1 +from test_transformer import EncoderBlock as model2 -sys.path.append(os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim')) +base_path = os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim') +sys.path.append(base_path) from Scheduler.scheduler import Scheduler, SchedulerDNNModel, Request -from test_extension_backend import DecoderBlock as model2 +config = f'{base_path}/PyTorchSimBackend/configs/systolic_ws_128x128_c2_simple_noc_tpuv3_partition.json' target_model1 = model1().eval() target_model2 = model2(768, 12).eval() # Init scheduler -scheduler = Scheduler(num_request_queue=2, engine_select=Scheduler.FIFO_ENGINE) +scheduler = Scheduler(num_request_queue=2, engine_select=Scheduler.FIFO_ENGINE, backend_config=config) # Register compiled model opt_model1 = torch.compile(target_model1.to(device=scheduler.execution_engine.module.custom_device(), memory_format=torch.channels_last)) opt_model2 = torch.compile(target_model2.to(device=scheduler.execution_engine.module.custom_device())) @@ -20,15 +22,19 @@ # Init input data model_input1 = torch.randn(1, 3, 224, 224) -model_input2 = torch.randn(512, 768) +model_input2 = torch.randn(128, 768) # Init request new_request1 = Request("resnet18", [model_input1], [], request_queue_idx=0) new_request2 = Request("bert", [model_input2], [], request_queue_idx=1) +new_request3 = Request("resnet18", [model_input1], [], request_queue_idx=0) +new_request4 = Request("bert", [model_input2], [], request_queue_idx=1) # Add request to scheduler scheduler.add_request(new_request1, request_time=0) scheduler.add_request(new_request2, request_time=0) +scheduler.add_request(new_request3, request_time=0) +scheduler.add_request(new_request4, request_time=0) # Run scheduler while not scheduler.is_finished(): diff --git a/tests/test_scheduler_batching.py b/tests/test_scheduler_batching.py new file mode 100644 index 00000000..f3b54159 --- /dev/null +++ b/tests/test_scheduler_batching.py @@ -0,0 +1,41 @@ +import os +import sys +import torch +from torchvision.models import resnet18 as model1 +import argparse + +sys.path.append(os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim')) +from Scheduler.scheduler import Scheduler, SchedulerDNNModel, Request, poisson_request_generator +CONFIG_TORCHSIM_DIR = os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim') + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Poisson Request Generator (ms)") + parser.add_argument("lambda_requests", nargs="?", type=int, help="Average requests per second (λ)", default=2000) + parser.add_argument("max_time", nargs="?", type=int, help="Maximum simulation time in milliseconds", default=30) + + args = parser.parse_args() + target_model1 = model1().eval() + + # Init scheduler + scheduler = Scheduler(num_request_queue=1, max_batch=32, engine_select=Scheduler.FIFO_ENGINE, backend_config=f"{CONFIG_TORCHSIM_DIR}/PyTorchSimBackend/configs/systolic_ws_128x128_c2_simple_noc_tpuv2.json") + # Register compiled model + opt_model1 = torch.compile(target_model1.to(device=scheduler.execution_engine.module.custom_device(), memory_format=torch.channels_last), dynamic=False) + SchedulerDNNModel.register_model("resnet18", opt_model1) + + # Generate time stamp + for request_time in poisson_request_generator(args.lambda_requests, args.max_time): + # Init input data + model_input1 = torch.randn(1, 3, 224, 224) + + # Init request + new_request1 = Request("resnet18", [model_input1], [], request_queue_idx=0) + + # Add request to scheduler + print("[Reqest] Resnet18 request time: ", request_time, flush=True) + scheduler.add_request(new_request1, request_time=request_time) + + # Run scheduler + while not scheduler.is_finished(): + scheduler.schedule() + + print("Done", file=sys.stderr) \ No newline at end of file diff --git a/tests/test_single_perceptron.py b/tests/test_single_perceptron.py index 7ab02656..c7fdca06 100644 --- a/tests/test_single_perceptron.py +++ b/tests/test_single_perceptron.py @@ -4,12 +4,16 @@ import torch.utils.cpp_extension def test_result(name, out, cpu_out, rtol=1e-4, atol=1e-4): - message = f"|{name} Test Passed|" if torch.allclose(out.cpu(), cpu_out, rtol=rtol, atol=atol): + message = f"|{name} Test Passed|" print("-" * len(message)) print(message) print("-" * len(message)) else: + message = f"|{name} Test Failed|" + print("-" * len(message)) + print(message) + print("-" * len(message)) print("custom out: ", out.cpu()) print("cpu out: ", cpu_out) exit(1) @@ -41,13 +45,14 @@ def weight_update(a, b, lr): b2.requires_grad = True opt_mlp = torch.compile(dynamic=False)(perceptron) opt_w = torch.compile(dynamic=False)(weight_update) - opt_loss = torch.compile(dynamic=False)(torch.nn.MSELoss()) + loss_fn = torch.nn.MSELoss() + opt_loss = torch.compile(dynamic=False)(loss_fn) lr = torch.tensor(5e-2).to(device=device) # learning rate y = opt_mlp(w1, x1, b1) loss = opt_loss(y, y1) loss.backward() cpu_y = perceptron(x2, w2, b2) - cpu_loss = torch.nn.MSELoss()(cpu_y, y2) + cpu_loss = loss_fn(cpu_y, y2) cpu_loss.backward() test_result("Perceptron", y, cpu_y) test_result("Loss", loss, cpu_loss) diff --git a/tests/test_softmax.py b/tests/test_softmax.py index ca49953c..9fba41dd 100644 --- a/tests/test_softmax.py +++ b/tests/test_softmax.py @@ -18,6 +18,30 @@ def test_softmax(device, size=(128, 128), dim=1): input = torch.randn(size) x1 = input.to(device=device) x2 = input.to("cpu") + + # split softmax into 3 steps + #def softmax1(x): # find max + # return x.max(dim=dim, keepdim=True).values + #def softmax2(x, max): + # return (x - max).exp().sum(dim=dim, keepdim=True) + #def softmax3(x, max, sum): + # return (x - max).exp().div(sum) + + #opt_fn1 = torch.compile(dynamic=False)(softmax1) + #opt_fn2 = torch.compile(dynamic=False)(softmax2) + #opt_fn3 = torch.compile(dynamic=False)(softmax3) + + #max = opt_fn1(x1) + #cpu_max = softmax1(x2) + #test_result("Softmax Max", max, cpu_max) + #sum = opt_fn2(x1, max) + #cpu_sum = softmax2(x2, cpu_max) + #test_result("Softmax Sum", sum, cpu_sum) + + #y = opt_fn3(x1, max, sum) + #cpu_y = softmax3(x2, cpu_max, cpu_sum) + #test_result("Softmax", y, cpu_y) + opt_fn = torch.compile(dynamic=False)(torch.nn.functional.softmax) y = opt_fn(x1, dim=dim) cpu_y = torch.nn.functional.softmax(x2, dim=dim) @@ -26,10 +50,20 @@ def test_softmax(device, size=(128, 128), dim=1): if __name__ == "__main__": import os import sys + import argparse sys.path.append(os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim')) + parser = argparse.ArgumentParser(description="Run LayerNorm test with dynamic shape") + parser.add_argument('--shape', type=str, help="Shape of the tensor in the format (batch_size, features)", default="(512,768)") + args = parser.parse_args() + shape = tuple(map(int, args.shape.strip('()').split(','))) + from Scheduler.scheduler import ExecutionEngine module = ExecutionEngine.setup_device() device = module.custom_device() test_softmax(device, size=(64, 128)) + test_softmax(device, size=(64, 128), dim=0) test_softmax(device, size=(256, 128)) + test_softmax(device, size=(256, 128), dim=0) + test_softmax(device, size=(1, 16)) + test_softmax(device, size=(5, 8)) diff --git a/tests/test_sparse_core.py b/tests/test_sparse_core.py new file mode 100644 index 00000000..b2b16818 --- /dev/null +++ b/tests/test_sparse_core.py @@ -0,0 +1,88 @@ +import torch +import torch.nn as nn +import torch._dynamo +import torch.utils.cpp_extension +import torch.nn.utils.prune as prune + +def test_result(name, out, cpu_out, rtol=1e-4, atol=1e-4): + if torch.allclose(out.cpu(), cpu_out, rtol=rtol, atol=atol): + message = f"|{name} Test Passed|" + print("-" * len(message)) + print(message) + print("-" * len(message)) + else: + message = f"|{name} Test Failed|" + print("-" * len(message)) + print(message) + print("-" * len(message)) + print("custom out: ", out.cpu()) + print("cpu out: ", cpu_out) + exit(1) + +class MLP(nn.Module): + def __init__(self, input_size=16, hidden_size=16, output_size=16, sparsity_fc1=0, sparsity_fc2=0): + super(MLP, self).__init__() + self.fc1 = nn.Linear(input_size, hidden_size, bias=False) + self.fc2 = nn.Linear(hidden_size, output_size, bias=False) + + prune.l1_unstructured(self.fc1, name="weight", amount=sparsity_fc1) + prune.l1_unstructured(self.fc2, name="weight", amount=sparsity_fc2) + + prune.remove(self.fc1, "weight") + prune.remove(self.fc2, "weight") + + def forward(self, x): + x = torch.sparse.mm(x, self.fc1.weight) + x = torch.sparse.mm(x, self.fc2.weight) + return x + +class SparseMLP(nn.Module): + def __init__(self, input_size=16, hidden_size=16, output_size=16, sparsity_fc1=0, sparsity_fc2=0, device="cpu"): + super(SparseMLP, self).__init__() + + self.weight1 = torch.empty(input_size, hidden_size, requires_grad=False) + self.weight2 = torch.empty(hidden_size, output_size, requires_grad=False) + + nn.init.xavier_uniform_(self.weight1) + nn.init.xavier_uniform_(self.weight2) + + self._apply_pruning(self.weight1, sparsity_fc1) + self._apply_pruning(self.weight2, sparsity_fc2) + + self.weight1 = self.weight1.to(device=device) + self.weight2 = self.weight2.to(device=device) + + print(f"WEIGHT1 SHAPE > {self.weight1.shape}") # (input_size, hidden_size) + print(f"WEIGHT2 SHAPE > {self.weight2.shape}") # (hidden_size, output_size) + + def _apply_pruning(self, tensor, sparsity): + mask = torch.rand_like(tensor) > sparsity + tensor *= mask + + def forward(self, x): + x = torch.sparse.mm(x, self.weight1) + x = torch.sparse.mm(x, self.weight2) + return x + + +def test_sparse_mlp(device, batch_size=32, input_size=128, hidden_size=128, output_size=128): + torch.manual_seed(0) + # mlp = MLP(input_size, hidden_size, output_size) + mlp = SparseMLP(input_size, hidden_size, output_size, device) + mlp = mlp.to(device=device) + input = torch.randn(batch_size, input_size) + x1 = input.to(device=device) + opt_fn = torch.compile(dynamic=False)(mlp) + res = opt_fn(x1) + + +if __name__ == "__main__": + import os + import sys + sys.path.append(os.environ.get('TORCHSIM_DIR', default='/root/workspace/PyTorchSim')) + from Scheduler.scheduler import ExecutionEngine + + module = ExecutionEngine.setup_device() + device = module.custom_device() + test_sparse_mlp(device, batch_size=8, input_size=16, hidden_size=32, output_size=64) + diff --git a/tests/test_sparsity.py b/tests/test_sparsity.py index c72dbb98..3e079f83 100644 --- a/tests/test_sparsity.py +++ b/tests/test_sparsity.py @@ -8,7 +8,7 @@ import torch._dynamo import torch.utils.cpp_extension sys.path.append(os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim')) -from test_transformer import DecoderBlock, test_result +from test_transformer import EncoderBlock, test_result from test_mlp import MLP def apply_random_zero(tensor, zero_prob, block_size=8): @@ -35,30 +35,30 @@ def count_zeros_in_tensor_list(tensor_list): def test_dec_inf(device, sparsity=0.0, block=8): torch.manual_seed(0) - decoder_block = DecoderBlock(768, 12) + encoder_block = EncoderBlock(768, 12) cpu_query = torch.randn(512, 768) query = cpu_query.clone().to(device=device) - cpu_y = decoder_block(cpu_query) + cpu_y = encoder_block(cpu_query) with torch.no_grad(): - decoder_block.multihead_attn.linears[0].weight.copy_(apply_random_zero(decoder_block.multihead_attn.linears[0].weight, sparsity, block_size=block)) - decoder_block.multihead_attn.linears[1].weight.copy_(apply_random_zero(decoder_block.multihead_attn.linears[1].weight, sparsity, block_size=block)) - decoder_block.multihead_attn.linears[2].weight.copy_(apply_random_zero(decoder_block.multihead_attn.linears[2].weight, sparsity, block_size=block)) - decoder_block.multihead_attn.linears[3].weight.copy_(apply_random_zero(decoder_block.multihead_attn.linears[3].weight, sparsity, block_size=block)) - decoder_block.ffn1.weight.copy_(apply_random_zero(decoder_block.ffn1.weight, sparsity, block_size=block)) - decoder_block.ffn2.weight.copy_(apply_random_zero(decoder_block.ffn2.weight, sparsity, block_size=block)) + encoder_block.multihead_attn.linears[0].weight.copy_(apply_random_zero(encoder_block.multihead_attn.linears[0].weight, sparsity, block_size=block)) + encoder_block.multihead_attn.linears[1].weight.copy_(apply_random_zero(encoder_block.multihead_attn.linears[1].weight, sparsity, block_size=block)) + encoder_block.multihead_attn.linears[2].weight.copy_(apply_random_zero(encoder_block.multihead_attn.linears[2].weight, sparsity, block_size=block)) + encoder_block.multihead_attn.linears[3].weight.copy_(apply_random_zero(encoder_block.multihead_attn.linears[3].weight, sparsity, block_size=block)) + encoder_block.ffn1.weight.copy_(apply_random_zero(encoder_block.ffn1.weight, sparsity, block_size=block)) + encoder_block.ffn2.weight.copy_(apply_random_zero(encoder_block.ffn2.weight, sparsity, block_size=block)) count_zeros_in_tensor_list([ - decoder_block.multihead_attn.linears[0].weight, - decoder_block.multihead_attn.linears[1].weight, - decoder_block.multihead_attn.linears[2].weight, - decoder_block.multihead_attn.linears[3].weight, - decoder_block.ffn1.weight, - decoder_block.ffn2.weight + encoder_block.multihead_attn.linears[0].weight, + encoder_block.multihead_attn.linears[1].weight, + encoder_block.multihead_attn.linears[2].weight, + encoder_block.multihead_attn.linears[3].weight, + encoder_block.ffn1.weight, + encoder_block.ffn2.weight ]) - decoder_block.to(device=device) - opt_fn = torch.compile(dynamic=False)(decoder_block) + encoder_block.to(device=device) + opt_fn = torch.compile(dynamic=False)(encoder_block) y = opt_fn(query) test_result("MLP Forward", y, cpu_y) @@ -83,7 +83,6 @@ def test_mlp_inf(device, batch_size=64, input_size=64, hidden_size=32, output_si test_result("MLP Forward", y, cpu_y) if __name__ == "__main__": - parser = argparse.ArgumentParser(description="Count zeros in tensors from command-line arguments.") parser.add_argument( "--sparsity", @@ -102,4 +101,4 @@ def test_mlp_inf(device, batch_size=64, input_size=64, hidden_size=32, output_si device = module.custom_device() #test_dec_inf(device, sparsity=args.sparsity, block=args.block) - test_mlp_inf(device, batch_size=64, input_size=784, hidden_size=512, output_size=256, sparsity=args.sparsity, block=args.block) + test_mlp_inf(device, batch_size=32, input_size=784, hidden_size=512, output_size=256, sparsity=args.sparsity, block=args.block) diff --git a/tests/test_spmm_scheduler.py b/tests/test_spmm_scheduler.py new file mode 100644 index 00000000..1cf0d3b3 --- /dev/null +++ b/tests/test_spmm_scheduler.py @@ -0,0 +1,66 @@ +import os +import sys +import torch +import argparse +sys.path.append(os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim')) +from Scheduler.scheduler import Scheduler, SchedulerDNNModel, Request +from test_sparse_core import SparseMLP as model1 +from test_transformer import EncoderBlock as model2 +CONFIG_TORCHSIM_DIR = os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim') + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="") + parser.add_argument("--batch_size", type=int, default=128, help="Batch size") + parser.add_argument("--input_size", type=int, default=128, help="Input layer size") + parser.add_argument("--hidden_size", type=int, default=128, help="Hidden layer size") + parser.add_argument("--output_size", type=int, default=128, help="Output layer size") + parser.add_argument("--w1_sparsity", type=float, default=0.5, help="Sparsity of first layer weights (0 to 1)") + parser.add_argument("--w2_sparsity", type=float, default=0.5, help="Sparsity of second layer weights (0 to 1)") + parser.add_argument("--config", type=str) + args = parser.parse_args() + + batch_size = args.batch_size + input_size = args.input_size + hidden_size = args.hidden_size + output_size = args.output_size + w1_sparsity = args.w1_sparsity + w2_sparsity = args.w2_sparsity + config_path = f"{CONFIG_TORCHSIM_DIR}/PyTorchSimBackend/configs/{args.config}" + + print("batch_size: ", batch_size) + print("input_size: ", input_size) + print("hidden_size: ", hidden_size) + print("output_size: ", output_size) + print("w1_sparsity: ", w1_sparsity) + print("w2_sparsity: ", w2_sparsity) + + with torch.no_grad(): + # Init scheduler + scheduler = Scheduler(num_request_queue=2, engine_select=Scheduler.FIFO_ENGINE, + backend_config=config_path) + + target_model1 = model1(input_size, hidden_size, output_size, w1_sparsity, w2_sparsity, scheduler.execution_engine.module.custom_device()).eval() + target_model2 = model2(768, 12).eval() + + # Register compiled model + opt_model1 = torch.compile(target_model1.to(device=scheduler.execution_engine.module.custom_device())) + opt_model2 = torch.compile(target_model2.to(device=scheduler.execution_engine.module.custom_device())) + SchedulerDNNModel.register_model("mlp", opt_model1) + SchedulerDNNModel.register_model("bert", opt_model2) + + # Init input data + model_input1 = torch.randn(batch_size, input_size) + model_input2 = torch.randn(1, 512, 768) + + # Init request + new_request1 = Request("mlp", [model_input1], [], request_queue_idx=0) + #new_request2 = Request("bert", [model_input2], [], request_queue_idx=1) + + + # Add request to scheduler + scheduler.add_request(new_request1, request_time=0) + #scheduler.add_request(new_request2, request_time=0) + + # Run scheduler + while not scheduler.is_finished(): + scheduler.schedule() \ No newline at end of file diff --git a/tests/test_stonne.py b/tests/test_stonne.py new file mode 100644 index 00000000..5e4fe5fb --- /dev/null +++ b/tests/test_stonne.py @@ -0,0 +1,60 @@ +import torch +import torch._dynamo +import torch.utils.cpp_extension +import random +import numpy as np +import argparse + +random.seed(0) +np.random.seed(0) +torch.manual_seed(0) + +def apply_pruning(tensor, sparsity): + mask = torch.rand_like(tensor) >= sparsity + tensor *= mask + +def test_result(name, out, cpu_out, rtol=1e-4, atol=1e-4): + if torch.allclose(out.cpu(), cpu_out, rtol=rtol, atol=atol): + message = f"|{name} Test Passed|" + print("-" * len(message)) + print(message) + print("-" * len(message)) + else: + message = f"|{name} Test Failed|" + print("-" * len(message)) + print(message) + print("-" * len(message)) + print("custom out: ", out.cpu()) + print("cpu out: ", cpu_out) + exit(1) + +def sparse_matmul(a, b): + return torch.sparse.mm(a, b) + +def test_sparse_mm(device, input_size=128, hidden_size=128, output_size=128, sparsity=0.0): + input = torch.randn(input_size, hidden_size) + weight = torch.randn(hidden_size, output_size) + apply_pruning(input, sparsity) + apply_pruning(weight, sparsity) + x1 = input.to(device=device) + w1 = weight.to(device=device) + opt_fn = torch.compile(dynamic=False)(sparse_matmul) + res = opt_fn(x1, w1) + cpu_res = sparse_matmul(input.cpu(), weight.cpu()) + #test_result("spmm", res, cpu_res) + + +if __name__ == "__main__": + import os + import sys + parser = argparse.ArgumentParser(description="stonne test") + parser.add_argument("sz", nargs="?", type=int, help="size", default=64) + parser.add_argument("sparsity", nargs="?", type=float, help="%% of zero", default=0.0) + + args = parser.parse_args() + sys.path.append(os.environ.get('TORCHSIM_DIR', default='/root/workspace/PyTorchSim')) + + from Scheduler.scheduler import ExecutionEngine + module = ExecutionEngine.setup_device() + device = module.custom_device() + test_sparse_mm(device, args.sz, args.sz, args.sz, args.sparsity) \ No newline at end of file diff --git a/tests/test_transformer.py b/tests/test_transformer.py index 44ffe5b8..4d45707e 100644 --- a/tests/test_transformer.py +++ b/tests/test_transformer.py @@ -5,12 +5,16 @@ import torch.utils.cpp_extension def test_result(name, out, cpu_out, rtol=1e-4, atol=1e-4): - message = f"|{name} Test Passed|" if torch.allclose(out.cpu(), cpu_out, rtol=rtol, atol=atol): + message = f"|{name} Test Passed|" print("-" * len(message)) print(message) print("-" * len(message)) else: + message = f"|{name} Test Failed|" + print("-" * len(message)) + print(message) + print("-" * len(message)) print("custom out: ", out.cpu()) print("cpu out: ", cpu_out) exit(1) @@ -29,43 +33,29 @@ def __init__(self, h, d_model, dropout=0.1): self.linears = clones(torch.nn.Linear(d_model, d_model), 4) self.attn = None - def attention(self, query, key, value): - d_k = query.size(-1) - print(torch.matmul(query, key.transpose(-2, -1))) - - scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k) - p_attn = scores.softmax(dim=-1) - print(p_attn) - return torch.matmul(p_attn, value), p_attn - def forward(self, query, key, value): # 1) Do all the linear projections in batch from d_model => h x d_k query, key, value = [ - lin(x).view(-1, self.h, self.d_k).transpose(0, 1).contiguous() + lin(x).view(-1, self.h, self.d_k).transpose(0, 1) for lin, x in zip(self.linears, (query, key, value)) ] # 2) Apply attention on all the projected vectors in batch. - # x, self.attn = self.attention(query, key, value) - # d_k = query.size(-1) - - scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.d_k) - p_attn = scores.softmax(dim=-1) - x = torch.matmul(p_attn, value) + scores = torch.matmul(key, query.transpose(-2, -1)) / math.sqrt(self.d_k) + p_attn = scores.softmax(dim=-2) + x = torch.matmul(value.transpose(-1, -2), p_attn) # 3) "Concat" using a view and apply a final linear. x = ( - x.transpose(0, 1) - .contiguous() - .view(-1, self.h * self.d_k) + x.view(-1, self.h * self.d_k) ) del query del key del value return self.linears[-1](x) -class DecoderBlock(torch.nn.Module): +class EncoderBlock(torch.nn.Module): def __init__(self, embed_dim, num_heads): - super(DecoderBlock, self).__init__() + super(EncoderBlock, self).__init__() self.multihead_attn = my_MultiheadAttention(num_heads, embed_dim) self.layer_norm = torch.nn.LayerNorm(embed_dim) self.ffn1 = torch.nn.Linear(embed_dim, embed_dim*4) @@ -73,7 +63,7 @@ def __init__(self, embed_dim, num_heads): self.ffn2 = torch.nn.Linear(embed_dim*4, embed_dim) def forward(self, x): - result = self.multihead_attn(x, x, x) + result = self.multihead_attn(x, x, x).reshape(x.shape) result = self.layer_norm(result+x) ffn1_result = self.ffn1(result) @@ -81,37 +71,49 @@ def forward(self, x): ffn2_result = self.ffn2(act_result) return self.layer_norm(ffn2_result + result) -def test_DecoderBlock(device): - cpu_query = torch.randn(512, 768) - decoder_block = DecoderBlock(768, 12) - cpu_res = decoder_block(cpu_query) +def test_EncoderBlock(device, head=12, embed_dim=768, input_seq=512): + cpu_query = torch.randn(1, input_seq, embed_dim) + encoder_block = EncoderBlock(embed_dim, head) + cpu_res = encoder_block(cpu_query) query = cpu_query.clone().to(device=device) - decoder_block.to(device=device) - opt_fn = torch.compile(dynamic=False)(decoder_block) + encoder_block.to(device=device) + opt_fn = torch.compile(dynamic=False)(encoder_block) res = opt_fn(query) - test_result("Decoder Block Forwrad", res, cpu_res) + test_result("Encoder Block Forwrad", res, cpu_res) -def test_Attention(device): +def test_Attention(device, head=16, seq=512, d_k=64): def attention(query, key, value): import math d_k = query.size(-1) - scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k) - p_attn = scores.softmax(dim=-1) - return torch.matmul(p_attn, value), p_attn + scores = torch.matmul(key, query.transpose(-2, -1)) / math.sqrt(d_k) + p_attn = scores.softmax(dim=-2) + return torch.matmul(value.transpose(-1, -2), p_attn) torch.manual_seed(0) - query = torch.randn(16, 128).to(device=device) - key = torch.randn(16, 128).to(device=device) - value = torch.randn(16, 128).to(device=device) + query = torch.randn(head, seq, d_k).to(device=device) + key = torch.randn(head, seq, d_k).to(device=device) + value = torch.randn(head, seq, d_k).to(device=device) opt_fn = torch.compile(dynamic=False)(attention) - res, p_attn = opt_fn(query, key, value) + res = opt_fn(query, key, value) - cpu_res, cpu_p_attn = attention(query.cpu(), key.cpu(), value.cpu()) + cpu_res = attention(query.cpu(), key.cpu(), value.cpu()) test_result("Attention Forward", res, cpu_res) +def test_MHA(device, num_heads=12, embed_dim=768, input_seq=512): + MHA = my_MultiheadAttention(num_heads, embed_dim) + cpu_query = torch.randn(input_seq, embed_dim) + cpu_res = MHA(cpu_query, cpu_query, cpu_query) + + query = cpu_query.clone().to(device=device) + MHA.to(device=device) + opt_fn = torch.compile(dynamic=False)(MHA) + res = opt_fn(query, query, query) + + test_result("MHA Forward", res, cpu_res) + if __name__ == "__main__": import os import sys @@ -120,4 +122,6 @@ def attention(query, key, value): from Scheduler.scheduler import ExecutionEngine module = ExecutionEngine.setup_device() device = module.custom_device() - test_DecoderBlock(device) + test_EncoderBlock(device) + # test_Attention(device, head=16, seq=512, d_k=64) + # test_MHA(device, num_heads=12, embed_dim=768) diff --git a/tests/test_transpose2D.py b/tests/test_transpose2D.py index afc17a23..14f16fbb 100644 --- a/tests/test_transpose2D.py +++ b/tests/test_transpose2D.py @@ -3,12 +3,16 @@ import torch.utils.cpp_extension def test_result(name, out, cpu_out, rtol=1e-4, atol=1e-4): - message = f"|{name} Test Passed|" if torch.allclose(out.cpu(), cpu_out, rtol=rtol, atol=atol): + message = f"|{name} Test Passed|" print("-" * len(message)) print(message) print("-" * len(message)) else: + message = f"|{name} Test Failed|" + print("-" * len(message)) + print(message) + print("-" * len(message)) print("custom out: ", out.cpu()) print("cpu out: ", cpu_out) exit(1) diff --git a/tests/test_transpose3D.py b/tests/test_transpose3D.py index d19ea242..937948c4 100644 --- a/tests/test_transpose3D.py +++ b/tests/test_transpose3D.py @@ -3,12 +3,16 @@ import torch.utils.cpp_extension def test_result(name, out, cpu_out, rtol=1e-4, atol=1e-4): - message = f"|{name} Test Passed|" if torch.allclose(out.cpu(), cpu_out, rtol=rtol, atol=atol): + message = f"|{name} Test Passed|" print("-" * len(message)) print(message) print("-" * len(message)) else: + message = f"|{name} Test Failed|" + print("-" * len(message)) + print(message) + print("-" * len(message)) print("custom out: ", out.cpu()) print("cpu out: ", cpu_out) exit(1) diff --git a/tests/test_vectorops.py b/tests/test_vectorops.py new file mode 100644 index 00000000..0677b7ae --- /dev/null +++ b/tests/test_vectorops.py @@ -0,0 +1,32 @@ +import torch +import torch._dynamo +import torch.utils.cpp_extension + +if __name__ == "__main__": + import os + import sys + sys.path.append(os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim')) + from Scheduler.scheduler import ExecutionEngine + module = ExecutionEngine.setup_device() + device = module.custom_device() + + # Target shape + seq_list = [1,128,512,2048,8192] + d_model = 768 + from tests.test_add import test_vectoradd + from tests.test_activation import test_GeLU + from tests.test_reduce import test_reduce_sum2 + from tests.test_layernorm import test_LayerNorm + from tests.test_softmax import test_softmax + func_list = [test_vectoradd, test_GeLU, test_reduce_sum2, test_LayerNorm, test_softmax] + for test_func in func_list: + for seq in seq_list: + if test_func == test_GeLU: + print(f"[log] {test_func.__name__}, seq: {seq}") + test_func(device, size=[seq, d_model*4]) + elif test_func == test_softmax: + print(f"[log] {test_func.__name__}, seq: {seq}") + test_func(device, size=[seq, seq]) + else: + print(f"[log] {test_func.__name__}, seq: {seq}") + test_func(device, size=[seq, d_model]) diff --git a/tests/test_view3D_2D.py b/tests/test_view3D_2D.py index 60575ada..a5a31a85 100644 --- a/tests/test_view3D_2D.py +++ b/tests/test_view3D_2D.py @@ -3,27 +3,42 @@ import torch.utils.cpp_extension def test_result(name, out, cpu_out, rtol=1e-4, atol=1e-4): - message = f"|{name} Test Passed|" if torch.allclose(out.cpu(), cpu_out, rtol=rtol, atol=atol): + message = f"|{name} Test Passed|" print("-" * len(message)) print(message) print("-" * len(message)) else: + message = f"|{name} Test Failed|" + print("-" * len(message)) + print(message) + print("-" * len(message)) print("custom out: ", out.cpu()) print("cpu out: ", cpu_out) exit(1) -def test_view3D_2D(device): +def test_view3D_2D(device, size=(16, 8, 16), t_x=0, t_y=1): def view3D_2D(a): - return a.view(16, 128).contiguous() + return a.transpose(t_x, t_y).contiguous().view(-1, size[0] * size[2]) torch.manual_seed(0) - cpu_input = torch.randn(16, 8, 16) + cpu_input = torch.randn(size) input = cpu_input.clone().to(device=device) opt_fn = torch.compile(dynamic=False)(view3D_2D) res = opt_fn(input) out = view3D_2D(cpu_input) test_result("view 3D->2D", res, out) +def test_view2D_3D(device, size=(512, 768), h=12, d_k=64): + def view2D_3D(a): + return a.view(-1, h, d_k).transpose(0, 1).contiguous() + torch.manual_seed(0) + cpu_input = torch.randn(size) + input = cpu_input.clone().to(device=device) + opt_fn = torch.compile(dynamic=False)(view2D_3D) + res = opt_fn(input) + out = view2D_3D(cpu_input) + test_result("view 2D->3D", res, out) + if __name__ == "__main__": import os import sys @@ -33,4 +48,6 @@ def view3D_2D(a): module = ExecutionEngine.setup_device() device = module.custom_device() test_view3D_2D(device) + test_view3D_2D(device, [12, 512, 64]) + test_view2D_3D(device, size=(512, 1024), h=16, d_k=64) diff --git a/validation/gemm_tpuv3_cheatsheet.json b/validation/gemm_tpuv3_cheatsheet.json new file mode 100644 index 00000000..76a26e1a --- /dev/null +++ b/validation/gemm_tpuv3_cheatsheet.json @@ -0,0 +1,17 @@ +{ + "512_2048_8192" : { + "TILE_M" : 512, + "TILE_K" : 512, + "TILE_N" : 1024 + }, + "512_2048_2048" : { + "TILE_M" : 512, + "TILE_K" : 512, + "TILE_N" : 1024 + }, + "2048_2048_512" : { + "TILE_M" : 1024, + "TILE_K" : 512, + "TILE_N" : 512 + } +} \ No newline at end of file