diff --git a/.github/workflows/pytorchsim_test.yml b/.github/workflows/pytorchsim_test.yml index 25ce0a16..c27df48a 100644 --- a/.github/workflows/pytorchsim_test.yml +++ b/.github/workflows/pytorchsim_test.yml @@ -20,6 +20,13 @@ jobs: name: Run test_add.py runs-on: self-hosted steps: + - name: Log in to GitHub Container Registry + uses: docker/login-action@v3 + with: + registry: ghcr.io + username: ${{ github.actor }} + password: ${{ secrets.GITHUB_TOKEN }} + - name: Run test_add.py run: | echo "Running test_add.py" @@ -30,10 +37,38 @@ jobs: -e TORCHSIM_SPAD_SIZE="${{ inputs.spad_size }}" \ ${{ inputs.image_name }} python3 PyTorchSim/tests/test_add.py + test_transcendental: + name: Run test_transcendental.py + runs-on: self-hosted + steps: + - name: Log in to GitHub Container Registry + uses: docker/login-action@v3 + with: + registry: ghcr.io + username: ${{ github.actor }} + password: ${{ secrets.GITHUB_TOKEN }} + + - name: Run test_transcendental.py + run: | + echo "Running test_transcendental.py" + docker run --rm \ + -v /tmp/torchsim-ci/${GITHUB_SHA}:/dump \ + -e TORCHSIM_DUMP_PATH=/dump \ + -e TORCHSIM_VECTOR_LANE="${{ inputs.vector_lane }}" \ + -e TORCHSIM_SPAD_SIZE="${{ inputs.spad_size }}" \ + ${{ inputs.image_name }} python3 PyTorchSim/tests/test_transcendental.py + test_activation: name: Run test_activation.py runs-on: self-hosted steps: + - name: Log in to GitHub Container Registry + uses: docker/login-action@v3 + with: + registry: ghcr.io + username: ${{ github.actor }} + password: ${{ secrets.GITHUB_TOKEN }} + - name: Run test_activation.py run: | echo "Running test_activation.py" @@ -48,6 +83,13 @@ jobs: name: Run test_batchnorm.py runs-on: self-hosted steps: + - name: Log in to GitHub Container Registry + uses: docker/login-action@v3 + with: + registry: ghcr.io + username: ${{ github.actor }} + password: ${{ secrets.GITHUB_TOKEN }} + - name: Run test_batchnorm.py run: | echo "Running test_batchnorm.py" @@ -62,6 +104,13 @@ jobs: name: Run test_bmm.py runs-on: 
self-hosted steps: + - name: Log in to GitHub Container Registry + uses: docker/login-action@v3 + with: + registry: ghcr.io + username: ${{ github.actor }} + password: ${{ secrets.GITHUB_TOKEN }} + - name: Run test_bmm.py run: | echo "Running test_bmm.py" @@ -76,6 +125,13 @@ jobs: name: Run test_cnn.py runs-on: self-hosted steps: + - name: Log in to GitHub Container Registry + uses: docker/login-action@v3 + with: + registry: ghcr.io + username: ${{ github.actor }} + password: ${{ secrets.GITHUB_TOKEN }} + - name: Run test_cnn.py run: | echo "Running test_cnn.py" @@ -90,6 +146,13 @@ jobs: name: Run test_conv2d.py runs-on: self-hosted steps: + - name: Log in to GitHub Container Registry + uses: docker/login-action@v3 + with: + registry: ghcr.io + username: ${{ github.actor }} + password: ${{ secrets.GITHUB_TOKEN }} + - name: Run test_conv2d.py run: | echo "Running test_conv2d.py" @@ -104,6 +167,13 @@ jobs: name: Run test_matmul.py runs-on: self-hosted steps: + - name: Log in to GitHub Container Registry + uses: docker/login-action@v3 + with: + registry: ghcr.io + username: ${{ github.actor }} + password: ${{ secrets.GITHUB_TOKEN }} + - name: Run test_matmul.py run: | echo "Running test_matmul.py" @@ -118,6 +188,13 @@ jobs: name: Run test_reduce.py runs-on: self-hosted steps: + - name: Log in to GitHub Container Registry + uses: docker/login-action@v3 + with: + registry: ghcr.io + username: ${{ github.actor }} + password: ${{ secrets.GITHUB_TOKEN }} + - name: Run test_reduce.py run: | echo "Running test_reduce.py" @@ -132,6 +209,13 @@ jobs: name: Run test_softmax.py runs-on: self-hosted steps: + - name: Log in to GitHub Container Registry + uses: docker/login-action@v3 + with: + registry: ghcr.io + username: ${{ github.actor }} + password: ${{ secrets.GITHUB_TOKEN }} + - name: Run test_softmax.py run: | echo "Running test_softmax.py" @@ -146,6 +230,13 @@ jobs: name: Run test_transpose2D.py runs-on: self-hosted steps: + - name: Log in to GitHub Container Registry + 
uses: docker/login-action@v3 + with: + registry: ghcr.io + username: ${{ github.actor }} + password: ${{ secrets.GITHUB_TOKEN }} + - name: Run test_transpose2D.py run: | echo "Running test_transpose2D.py" @@ -160,6 +251,13 @@ jobs: name: Run test_view3D_2D.py runs-on: self-hosted steps: + - name: Log in to GitHub Container Registry + uses: docker/login-action@v3 + with: + registry: ghcr.io + username: ${{ github.actor }} + password: ${{ secrets.GITHUB_TOKEN }} + - name: Run test_view3D_2D.py run: | echo "Running test_view3D_2D.py" @@ -174,6 +272,13 @@ jobs: name: Run test_layernorm.py runs-on: self-hosted steps: + - name: Log in to GitHub Container Registry + uses: docker/login-action@v3 + with: + registry: ghcr.io + username: ${{ github.actor }} + password: ${{ secrets.GITHUB_TOKEN }} + - name: Run test_layernorm.py run: | echo "Running test_layernorm.py" @@ -188,6 +293,13 @@ jobs: name: Run test_mlp.py runs-on: self-hosted steps: + - name: Log in to GitHub Container Registry + uses: docker/login-action@v3 + with: + registry: ghcr.io + username: ${{ github.actor }} + password: ${{ secrets.GITHUB_TOKEN }} + - name: Run test_mlp.py run: | echo "Running test_mlp.py" @@ -233,6 +345,13 @@ jobs: name: Run test_transformer.py runs-on: self-hosted steps: + - name: Log in to GitHub Container Registry + uses: docker/login-action@v3 + with: + registry: ghcr.io + username: ${{ github.actor }} + password: ${{ secrets.GITHUB_TOKEN }} + - name: Run test_transformer.py run: | echo "Running test_transformer.py" @@ -247,6 +366,13 @@ jobs: name: Run test_transpose3D.py runs-on: self-hosted steps: + - name: Log in to GitHub Container Registry + uses: docker/login-action@v3 + with: + registry: ghcr.io + username: ${{ github.actor }} + password: ${{ secrets.GITHUB_TOKEN }} + - name: Run test_transpose3D.py run: | echo "Running test_transpose3D.py" @@ -261,6 +387,13 @@ jobs: name: Run test_sparsity.py runs-on: self-hosted steps: + - name: Log in to GitHub Container Registry + uses: 
docker/login-action@v3 + with: + registry: ghcr.io + username: ${{ github.actor }} + password: ${{ secrets.GITHUB_TOKEN }} + - name: Run test_sparsity.py run: | echo "Running test_sparsity.py" @@ -275,6 +408,13 @@ jobs: name: Run test_pool.py runs-on: self-hosted steps: + - name: Log in to GitHub Container Registry + uses: docker/login-action@v3 + with: + registry: ghcr.io + username: ${{ github.actor }} + password: ${{ secrets.GITHUB_TOKEN }} + - name: Run test_pool.py run: | echo "Running test_pool.py" @@ -289,6 +429,13 @@ jobs: name: Run test_perceptron.py runs-on: self-hosted steps: + - name: Log in to GitHub Container Registry + uses: docker/login-action@v3 + with: + registry: ghcr.io + username: ${{ github.actor }} + password: ${{ secrets.GITHUB_TOKEN }} + - name: Run test_single_perceptron.py run: | echo "Running test_single_perceptron.py" @@ -303,6 +450,13 @@ jobs: name: Run test_fusion runs-on: self-hosted steps: + - name: Log in to GitHub Container Registry + uses: docker/login-action@v3 + with: + registry: ghcr.io + username: ${{ github.actor }} + password: ${{ secrets.GITHUB_TOKEN }} + - name: Run test_addmm_residual.py run: | echo "Running test_addmm_residual.py" @@ -387,6 +541,13 @@ jobs: name: Run test_moe runs-on: self-hosted steps: + - name: Log in to GitHub Container Registry + uses: docker/login-action@v3 + with: + registry: ghcr.io + username: ${{ github.actor }} + password: ${{ secrets.GITHUB_TOKEN }} + - name: Run test_moe.py run: | echo "Running test_moe.py" @@ -401,6 +562,13 @@ jobs: name: Run test_mistral runs-on: self-hosted steps: + - name: Log in to GitHub Container Registry + uses: docker/login-action@v3 + with: + registry: ghcr.io + username: ${{ github.actor }} + password: ${{ secrets.GITHUB_TOKEN }} + - name: Run test_mistral.py run: | echo "Running test_mistral.py" @@ -411,16 +579,62 @@ jobs: -e TORCHSIM_SPAD_SIZE="${{ inputs.spad_size }}" \ ${{ inputs.image_name }} python3 PyTorchSim/tests/Mixtral_8x7B/test_attention.py + test_vit: 
+ name: Run test_vit + runs-on: self-hosted + steps: + - name: Log in to GitHub Container Registry + uses: docker/login-action@v3 + with: + registry: ghcr.io + username: ${{ github.actor }} + password: ${{ secrets.GITHUB_TOKEN }} + + - name: Run test_vit.py + run: | + echo "Running test_vit.py" + docker run --rm \ + -v /tmp/torchsim-ci/${GITHUB_SHA}:/dump \ + -e TORCHSIM_DUMP_PATH=/dump \ + -e TORCHSIM_VECTOR_LANE="${{ inputs.vector_lane }}" \ + -e TORCHSIM_SPAD_SIZE="${{ inputs.spad_size }}" \ + ${{ inputs.image_name }} python3 PyTorchSim/tests/test_vit.py + + test_diffusion: + name: Run test_diffusion + runs-on: self-hosted + steps: + - name: Log in to GitHub Container Registry + uses: docker/login-action@v3 + with: + registry: ghcr.io + username: ${{ github.actor }} + password: ${{ secrets.GITHUB_TOKEN }} + + - name: Run test_diffusion.py + run: | + echo "Running test_diffusion.py" + docker run --rm \ + -v /tmp/torchsim-ci/${GITHUB_SHA}:/dump \ + -e TORCHSIM_DUMP_PATH=/dump \ + -e TORCHSIM_VECTOR_LANE="${{ inputs.vector_lane }}" \ + -e TORCHSIM_SPAD_SIZE="${{ inputs.spad_size }}" \ + ${{ inputs.image_name }} python3 PyTorchSim/tests/Diffusion/test_diffusion.py + test_indirect: name: Run test_indirect runs-on: self-hosted - env: - GIT_ACCESS_TOKEN: ${{ secrets.GITHUB_TOKEN }} steps: + - name: Log in to GitHub Container Registry + uses: docker/login-action@v3 + with: + registry: ghcr.io + username: ${{ github.actor }} + password: ${{ secrets.GITHUB_TOKEN }} + - name: Run test_indirect.py run: | echo "Running test_indirect.py" - echo $GIT_ACCESS_TOKEN | docker login ghcr.io -u USERNAME --password-stdin docker run --rm \ -v /tmp/torchsim-ci/${GITHUB_SHA}:/dump \ -e TORCHSIM_DUMP_PATH=/dump \ @@ -431,13 +645,17 @@ jobs: test_scheduler: name: Run test_scheduler runs-on: self-hosted - env: - GIT_ACCESS_TOKEN: ${{ secrets.GITHUB_TOKEN }} steps: + - name: Log in to GitHub Container Registry + uses: docker/login-action@v3 + with: + registry: ghcr.io + username: ${{ 
github.actor }} + password: ${{ secrets.GITHUB_TOKEN }} + - name: Run test_scheduler.py run: | echo "Running test_scheduler.py" - echo $GIT_ACCESS_TOKEN | docker login ghcr.io -u USERNAME --password-stdin docker run --rm \ -v /tmp/torchsim-ci/${GITHUB_SHA}:/dump \ -e TORCHSIM_DUMP_PATH=/dump \ diff --git a/Dockerfile.base b/Dockerfile.base index 2b3d58d6..1ac5e175 100644 --- a/Dockerfile.base +++ b/Dockerfile.base @@ -42,7 +42,7 @@ RUN wget https://github.com/riscv-collab/riscv-gnu-toolchain/releases/download/2 rm *.tar.gz # Install torchsim dependency -RUN apt install ninja-build && pip install onnx matplotlib && pip install --user conan==1.56.0 +RUN apt install ninja-build && pip install onnx matplotlib && pip install --user conan==1.56.0 && pip install "transformers<4.44" && pip install diffusers==0.34.0 ENV RISCV=/workspace/riscv ENV PATH=$RISCV/bin:$PATH diff --git a/PyTorchSimFrontend/extension_device.cpp b/PyTorchSimFrontend/extension_device.cpp index 4d33db08..1a02bfe3 100644 --- a/PyTorchSimFrontend/extension_device.cpp +++ b/PyTorchSimFrontend/extension_device.cpp @@ -11,6 +11,7 @@ #include #include #include +#include #include #include #include @@ -204,19 +205,25 @@ at::Tensor custom__copy_from(const at::Tensor& self, const at::Tensor& dst, bool // Some dummy asserts for the basic use case: inputs are the same size / dtype, all contiguous. TORCH_CHECK(self.sizes() == dst.sizes()); - TORCH_CHECK(self.scalar_type() == dst.scalar_type()); - if (self.is_contiguous() && dst.is_contiguous()) { + const bool same_dtype = (self.scalar_type() == dst.scalar_type()); + const bool both_contig = self.is_contiguous() && dst.is_contiguous(); + + // 1) fast path + if (same_dtype && both_contig) { std::memcpy(dst.mutable_data_ptr(), self.data_ptr(), dst.storage().nbytes()); - } else { - // Using cpu tensor to accomplishment stride copy. 
- at::Tensor cpu_self = unsafe_create_cpu_tensor_from_dummy_tensor(self); - at::Tensor cpu_dst = unsafe_create_cpu_tensor_from_dummy_tensor(dst); - cpu_dst.copy_(cpu_self); + return dst; } + // 2) slow path + at::Tensor cpu_self = unsafe_create_cpu_tensor_from_dummy_tensor(self); + at::Tensor cpu_dst = unsafe_create_cpu_tensor_from_dummy_tensor(dst); + if (!same_dtype) { + cpu_self = cpu_self.to(cpu_dst.scalar_type(), /*non_blocking=*/false, /*copy=*/true); + } + cpu_dst.copy_(cpu_self); return dst; } @@ -230,7 +237,6 @@ at::Tensor& custom_abs_out(const at::Tensor& self, at::Tensor& out) { at::Tensor custom_empty_strided(c10::IntArrayRef size, c10::IntArrayRef stride, c10::optional dtype_opt, c10::optional layout_opt, c10::optional device_opt, c10::optional pin_memory_opt) { op_counter += 1; - constexpr c10::DispatchKeySet private_use_ks(c10::DispatchKey::PrivateUse1); auto dtype = c10::dtype_or_default(dtype_opt); return at::detail::empty_strided_generic(size, stride, &global_custom_alloc, private_use_ks, dtype); @@ -244,7 +250,23 @@ at::Tensor custom_empty(c10::IntArrayRef size, c10::optional dty return at::detail::empty_generic(size, &global_custom_alloc, private_use_ks, dtype, optional_memory_format); } -// This macro does the heavy lifting. +at::Tensor& custom_arange_start_out_impl( + const c10::Scalar& start, + const c10::Scalar& end, + const c10::Scalar& step, + at::Tensor& out) { + //const int64_t n = arange_len(start.toDouble(), end.toDouble(), step.toDouble()); + //at::native::resize_output(out, {n}); + return out; +} + +static at::Tensor custom_to_dtype_impl(const at::Tensor& self, + c10::ScalarType dtype, + bool non_blocking, bool copy, + c10::optional memory_format) { + return at::native::to(self, dtype, non_blocking, copy, memory_format); +} + // With TORCH_LIBRARY_IMPL, you can register custom kernels for your backend. // For open registration, we're registering all of our kernels to the PrivateUse1 dispatch key. 
// Later in this file, we map a custom device to the PrivateUse1 device type, @@ -255,6 +277,7 @@ at::Tensor custom_empty(c10::IntArrayRef size, c10::optional dty // More details on the dispatcher can be found at http://blog.ezyang.com/2020/09/lets-talk-about-the-pytorch-dispatcher/. TORCH_LIBRARY_IMPL(aten, PrivateUse1, m) { m.impl("to.Device", &custom_to_device); + m.impl("to.dtype", &custom_to_dtype_impl); m.impl("fill_.Scalar", &custom_fill__scalar); m.impl("_copy_from", &custom__copy_from); m.impl("_copy_from_and_resize", &custom__copy_from_and_resize); @@ -262,14 +285,19 @@ TORCH_LIBRARY_IMPL(aten, PrivateUse1, m) { m.impl("empty.memory_format", &custom_empty); m.impl("as_strided", at::native::as_strided_tensorimpl); m.impl("view", at::native::view); + m.impl("arange.start_out", &custom_arange_start_out_impl); +} + +TORCH_LIBRARY_IMPL(aten, AutogradPrivateUse1, m) { + m.impl("to.dtype", &custom_to_dtype_impl); } TORCH_LIBRARY_FRAGMENT(aten, m) { - m.def( - "_reinterpret_tensor(Tensor self, int[] size, int[] stride, int offset_increment=0) -> Tensor", - torch::dispatch( - c10::DispatchKey::AutogradPrivateUse1, _reinterpret_tensor), - {at::Tag::pt2_compliant_tag}); +m.def( + "_reinterpret_tensor(Tensor self, int[] size, int[] stride, int offset_increment=0) -> Tensor", + torch::dispatch( + c10::DispatchKey::AutogradPrivateUse1, _reinterpret_tensor), + {at::Tag::pt2_compliant_tag}); } void custom_cpu_fallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) { @@ -307,6 +335,9 @@ TORCH_LIBRARY_IMPL(aten, PrivateUse1, m) { m.impl("_foreach_addcdiv_.ScalarList", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); m.impl("_foreach_add_.List", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); m.impl("cat.out", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); + m.impl("_native_multi_head_attention", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); + m.impl("resize_", 
torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); + m.impl("exp.out", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); } // This basic implementation doesn't bother dealing with different device indices diff --git a/PyTorchSimFrontend/mlir/mlir_bmm_template.py b/PyTorchSimFrontend/mlir/mlir_bmm_template.py index 9a9785e1..79e03bd5 100644 --- a/PyTorchSimFrontend/mlir/mlir_bmm_template.py +++ b/PyTorchSimFrontend/mlir/mlir_bmm_template.py @@ -173,9 +173,9 @@ def render(self, W_tensor = empty_strided(W.layout.size, W.layout.stride) X_tensor = empty_strided(X.layout.size, X.layout.stride) - if len(W_tensor.size()) > 3: + if len(W_tensor.size()) > 3 or len(W_tensor.size()) == 2: W_tensor = W_tensor.view([-1, W_tensor.shape[-2], W_tensor.shape[-1]]) - if len(X_tensor.size()) > 3: + if len(X_tensor.size()) > 3 or len(X_tensor.size()) == 2: X_tensor = X_tensor.view([-1, X_tensor.shape[-2], X_tensor.shape[-1]]) B, M, N, K = X_tensor.size()[0], X_tensor.size()[1], W_tensor.size()[2], X_tensor.size()[2] @@ -217,6 +217,7 @@ def render(self, X_tile_desc = mlir_common.MLIRMultiDimTile(X_tile_size, kernel.vector_lane, vlane_split_axis, vlane_stride) X_tile_desc.set_tile_size_stride(X_tile_size, X_tile_stride) X_tile_desc.set_name("X_buffer") + X_tile_desc.offset = X.get_layout().offset X_stride = X_tensor.stride() X_idx = [loop_dim[0]*X_stride[0], loop_dim[1]*X_stride[1], loop_dim[3]*X_stride[2]] # To keep index arguemnt order, we used index_list @@ -225,6 +226,7 @@ def render(self, W_tile_desc = mlir_common.MLIRMultiDimTile(X_tile_size, kernel.vector_lane, vlane_split_axis, vlane_stride) W_tile_desc.set_tile_size_stride(W_tile_size, W_tile_stride) W_tile_desc.set_name("W_buffer") + W_tile_desc.offset = W.get_layout().offset W_stride = W_tensor.stride() W_idx = [loop_dim[0]*W_stride[0], loop_dim[3]*W_stride[1], loop_dim[2]*W_stride[2]] @@ -241,8 +243,12 @@ def render(self, Y_idx = [loop_dim[0]*Y_stride[0], loop_dim[2]*Y_stride[2], 
loop_dim[1]*Y_stride[1]] # Extract Bias info + Bias_tile_desc = mlir_common.MLIRMultiDimTile(Y_tile_size, kernel.vector_lane, vlane_split_axis, vlane_stride) + Bias_tile_desc.set_tile_size_stride(Y_tile_size, Y_tile_stride) + Bias_tile_desc.set_name("Y_buffer") if Bias is not None: Bias_stride = Bias.get_layout().stride + Bias_tile_desc.offset = Bias.get_layout().offset if nr_rdim == 0: Bias_idx = [loop_dim[0]*Bias_stride[0], loop_dim[1]*Bias_stride[1], loop_dim[2]*Bias_stride[2]] else: diff --git a/PyTorchSimFrontend/mlir/mlir_codegen_backend.py b/PyTorchSimFrontend/mlir/mlir_codegen_backend.py index ff87c1d3..79677a3d 100644 --- a/PyTorchSimFrontend/mlir/mlir_codegen_backend.py +++ b/PyTorchSimFrontend/mlir/mlir_codegen_backend.py @@ -2,7 +2,8 @@ import sympy import re import os -import math +from functools import reduce +from operator import mul import torch from collections import defaultdict from concurrent.futures import ThreadPoolExecutor @@ -15,7 +16,7 @@ is_welford_reduction, sympy_product ) -from torch.utils._sympy.functions import ModularIndexing +from torch.utils._sympy.functions import ModularIndexing, FloorDiv import PyTorchSimFrontend.extension_codecache as extension_codecache from PyTorchSimFrontend import extension_config @@ -260,10 +261,10 @@ def binary_elementwise_common(operand1, operand2, var_info): operand2 = ops.to_dtype(operand2, op_type1[1], var_info) op_type2 = var_info[operand2] elif op_type1[1][0] == op_type2[1][0]: - if int(op_type1[1][1:]) > int(op_type2[1][1:]): + if mlir_common.MLIR_TO_BIT[op_type1[1]] > mlir_common.MLIR_TO_BIT[op_type2[1]]: operand2 = ops.ext(operand2, op_type1[1]) op_type2 = var_info[operand2] - elif int(op_type1[1][1:]) < int(op_type2[1][1:]): + elif mlir_common.MLIR_TO_BIT[op_type1[1]] < mlir_common.MLIR_TO_BIT[op_type2[1]]: operand1 = ops.ext(operand1, op_type2[1]) op_type1 = var_info[operand1] else: @@ -348,17 +349,21 @@ def maximum(operand1, operand2, *args, var_info=None, **kwargs): @staticmethod def 
to_dtype(operand, dst_mlir_dtype, *args, var_info=None, **kwargs): src_mlir_dtype = var_info[operand][1] + if src_mlir_dtype == "index": + operand = ops.index_cast(operand, "i64", var_info=var_info) + src_mlir_dtype = var_info[operand][1] + tile_size = var_info[operand][0] if isinstance(dst_mlir_dtype, torch.dtype): dst_mlir_dtype = mlir_common.DTYPE_TO_MLIR[dst_mlir_dtype] - dst_bits = int(dst_mlir_dtype[1:]) - src_bits = int(src_mlir_dtype[1:]) + dst_bits = mlir_common.MLIR_TO_BIT[dst_mlir_dtype] + src_bits = mlir_common.MLIR_TO_BIT[src_mlir_dtype] shape = f"vector<{tile_size}x{dst_mlir_dtype}>" if tile_size > 1 else dst_mlir_dtype src_shape = f"vector<{tile_size}x{src_mlir_dtype}>" if tile_size > 1 else src_mlir_dtype if dst_mlir_dtype[0] == "i" and src_mlir_dtype[0] == "f": - return f"arith.fptoui%{operand} : {src_shape} to {shape}", [tile_size, dst_mlir_dtype] + return f"arith.fptoui %{operand} : {src_shape} to {shape}", [tile_size, dst_mlir_dtype] if dst_mlir_dtype[0] == "f" and src_mlir_dtype[0] == "i": - return f"arith.uitofp%{operand} : {src_shape} to {shape}", [tile_size, dst_mlir_dtype] + return f"arith.uitofp %{operand} : {src_shape} to {shape}", [tile_size, dst_mlir_dtype] if dst_mlir_dtype[0] == "i": if dst_bits > src_bits: return f"arith.extui %{operand} : {src_shape} to {shape}", [tile_size, dst_mlir_dtype] @@ -428,7 +433,7 @@ def erf(operand, *args, var_info=None, **kwargs): val = ops.constant(0, op_type[1]) var_info[val][0] = 4 operand = ops.broadcast(operand, val) - val = ops.exp(operand) + val = ops.erf(operand) result = ops.extractelement(val, 0) return result, var_info[result] op_type = var_info[operand] @@ -447,7 +452,7 @@ def tanh(operand, *args, var_info=None, **kwargs): val = ops.constant(0, op_type[1]) var_info[val][0] = 4 operand = ops.broadcast(operand, val) - val = ops.exp(operand) + val = ops.tanh(operand) result = ops.extractelement(val, 0) return result, var_info[result] op_type = var_info[operand] @@ -461,6 +466,54 @@ def 
tanh(operand, *args, var_info=None, **kwargs): shape = f"vector<{tile_size}x{dtype}>" if tile_size > 1 else dtype return f'math.tanh %{operand} : {shape}', [tile_size, dtype] + @staticmethod + def sin(operand, *args, var_info=None, **kwargs): + op_type = var_info[operand] + + # Check scalar + op_type = var_info[operand] + if op_type[0] == 1: + val = ops.constant(0, op_type[1]) + var_info[val][0] = 4 + operand = ops.broadcast(operand, val) + val = ops.sin(operand) + result = ops.extractelement(val, 0) + return result, var_info[result] + op_type = var_info[operand] + tile_size = op_type[0] + dtype = op_type[1] + + # Type check & auto cast + if dtype[0] != "f": + operand, dtype = ops.to_dtype(operand, "f32", var_info=var_info) + var_info[operand] = dtype + shape = f"vector<{tile_size}x{dtype}>" if tile_size > 1 else dtype + return f'math.sin %{operand} : {shape}', [tile_size, dtype] + + @staticmethod + def cos(operand, *args, var_info=None, **kwargs): + op_type = var_info[operand] + + # Check scalar + op_type = var_info[operand] + if op_type[0] == 1: + val = ops.constant(0, op_type[1]) + var_info[val][0] = 4 + operand = ops.broadcast(operand, val) + val = ops.cos(operand) + result = ops.extractelement(val, 0) + return result, var_info[result] + op_type = var_info[operand] + tile_size = op_type[0] + dtype = op_type[1] + + # Type check & auto cast + if dtype[0] != "f": + operand, dtype = ops.to_dtype(operand, "f32", var_info=var_info) + var_info[operand] = dtype + shape = f"vector<{tile_size}x{dtype}>" if tile_size > 1 else dtype + return f'math.cos %{operand} : {shape}', [tile_size, dtype] + @staticmethod def sqrt(operand, *args, var_info=None, **kwargs): op_type = var_info[operand] @@ -954,14 +1007,18 @@ def parse_indices(self, expr, buffer=None, comments="", indirect_dims=[]) -> com # Extract index var indirect_args = [f"%{i}" for i in indirect_dims] + if len(indirect_args): + comments = "{indirect_access} " + comments # Add indirect access attribute expr_str = 
str(expr) + if "//" in expr_str: + expr_str = expr_str.replace("//", " floordiv ") args = ", ".join(map(str, indices)) map_var = self.map_cse.generate(self.global_vars, f"affine_map<({args})[{','.join(indirect_dims)}] -> ({expr_str})>") args = ", ".join([f"%{i}" for i in indices]) index = self.apply_cse.generate(buffer, f"affine.apply #{map_var}({args})[{','.join(indirect_args)}] {comments}") return index - def parse_index_list(self, expr_list:list, buffer=None) -> common.CSEVariable: + def parse_index_list(self, expr_list:list, buffer=None, offset=sympy.Number(0)) -> common.CSEVariable: if buffer is None: buffer = self.applys zero_var = self.get_const_cse(0) @@ -970,8 +1027,8 @@ def parse_index_list(self, expr_list:list, buffer=None) -> common.CSEVariable: if len(expr_list) == 1 and expr_list[0].is_number: # Constant case - return self.get_const_cse(int(expr_list[0])) - elif len(expr_list) == 1 and expr_list[0].is_symbol: + return self.get_const_cse(int(expr_list[0] + offset)) + elif len(expr_list) == 1 and expr_list[0].is_symbol and int(offset) == 0: # Identity case return expr_list[0] @@ -993,7 +1050,7 @@ def parse_index_list(self, expr_list:list, buffer=None) -> common.CSEVariable: indices.append(str(new_arg)) # Extract index var - expr_str = str(sum(new_expr_list)) + expr_str = str(sum(new_expr_list) + offset) args = ", ".join(map(str, dim_list)) map_var = self.map_cse.generate(self.global_vars, f"affine_map<({args})[] -> ({expr_str})>") args = ", ".join([f"%{i}" for i in indices]) @@ -1002,9 +1059,19 @@ def parse_index_list(self, expr_list:list, buffer=None) -> common.CSEVariable: def load(self, name: str, index: sympy.Expr): index = self.rename_indexing(index) - index = self.convert_indirect_indexing(index) + index, comptute_depedency = self.convert_indirect_indexing(index) padding = self.get_padding_type() + # In case of special form of indirect access, we need to put load in dma_store buffer + if comptute_depedency: + apply_buffer = self.dma_stores + 
dma_buffer = self.dma_stores + load_buffer = self.dma_stores + else: + apply_buffer = None + dma_buffer = self.dma_loads + load_buffer = self.loads + # Extract dram info dram_var = self.kernel_group.args.input(name) dram_shape = mlir_common.MLIRKernelArgs.get_mlir_shape(self.buffer_types[name]) @@ -1012,7 +1079,7 @@ def load(self, name: str, index: sympy.Expr): mlir_dtype = mlir_common.DTYPE_TO_MLIR[dtype] # Extract sram info - local_tile_desc, index_var, dram_stride = self.get_dma_info(name, index) + local_tile_desc, index_var, dram_stride = self.get_dma_info(name, index, buffer=apply_buffer) vlane_split_axis = local_tile_desc.vlane_split_axis vlane_stride = local_tile_desc.vlane_stride tile_numel_per_lane = local_tile_desc.get_numel_per_lane() @@ -1030,19 +1097,27 @@ def load(self, name: str, index: sympy.Expr): attribute = f"{{dram_stride={dram_stride}, sram_stride={tile_stride}, padding={padding}}}" code = self.get_dma_code("MVIN", vlane_split_axis, vlane_stride, mlir_dtype, dram_var, index_var, sram_var, sram_index_var, dram_shape, tile_shape, attribute) - self.cse.generate(self.dma_loads, code, assignment = False) # FIXME: assignment = False does not support caching - compute_index_var = ",".join(sram_index_var.split(",")[:-1] + [f"%{self.compute_idx}"]) - # Generate vector load instruction - if compute_vec_size > 1: - operation = "affine.vector_load" - line = f"{operation} %{sram_var}[{compute_index_var}] : {tile_shape}, {vshape}" + self.cse.generate(dma_buffer, code, assignment = False) # FIXME: assignment = False does not support caching + + if not comptute_depedency: + compute_index_var = ",".join(sram_index_var.split(",")[:-1] + [f"%{self.compute_idx}"]) + # Generate vector load instruction + if compute_vec_size > 1: + operation = "affine.vector_load" + line = f"{operation} %{sram_var}[{compute_index_var}] : {tile_shape}, {vshape}" + else: + operation = "affine.load" + line = f"{operation} %{sram_var}[{compute_index_var}] : {tile_shape}" + + out = 
self.cse.generate(load_buffer, line) + self.register_var_info(out, [compute_vec_size, mlir_dtype]) + self.spad_buffer_dict[str(out)] = [sram_var, local_tile_desc.get_tile_size(), tile_numel_per_lane, sram_index_var, tile_shape, vshape] + return out else: - operation = "affine.load" - line = f"{operation} %{sram_var}[{compute_index_var}] : {tile_shape}" - out = self.cse.generate(self.loads, line) - self.register_var_info(out, [compute_vec_size, mlir_dtype]) - self.spad_buffer_dict[str(out)] = [sram_var, local_tile_desc.get_tile_size(), tile_numel_per_lane, sram_index_var, tile_shape, vshape] - return out + out = sram_var + self.register_var_info(out, [compute_vec_size, mlir_dtype]) + self.spad_buffer_dict[str(out)] = [sram_var, local_tile_desc.get_tile_size(), tile_numel_per_lane, sram_index_var, tile_shape, vshape] + return out def store(self, name: str, index: sympy.Expr, value, *args, **kwargs): index = self.rename_indexing(index) @@ -1063,6 +1138,9 @@ def store(self, name: str, index: sympy.Expr, value, *args, **kwargs): vshape = self.kernel_group.tile_desc.get_mlir_vshape(mlir_dtype) compute_vec_size = self.kernel_group.tile_desc.get_compute_vec_size() require_store = True + if compute_vec_size < self.var_info[value][0]: + value = self.cse.generate(self.stores, f"vector.extract_strided_slice %{value} {{offsets = [0], sizes = [{compute_vec_size}], strides = [1]}}: vector<{self.var_info[value][0]}x{self.var_info[value][1]}> to {vshape}") + self.register_var_info(value, [compute_vec_size, mlir_dtype]) if str(value) in self.spad_buffer_dict: # Todo. 
If tile_size is not same (i.e., view operation), we can't apply peephole optimization easily @@ -1107,63 +1185,61 @@ def reduction(self, dtype, src_dtype, reduction_type, value): reduction_key = src_dtype, reduction_type, value sum = self.reduction(dtype, src_dtype, "sum", value) sqr_sum = self.reduction(dtype, src_dtype, "sum", ops.mul(value, value)) - self.welford_reduce_out = (sum, sqr_sum, None) - return sum, sqr_sum, None + if self.welford_reduce_out is not None: + return self.welford_reduce_out + else: + self.welford_reduce_out = (sum, sqr_sum, None) + return sum, sqr_sum, None # Prepare reduction loop - reduction_key = src_dtype, reduction_type, value - acc = self.reduction_cse.generate( - self.loads, f"reduction {reduction_key}", write=False - ) - iterator = self.iterator_cse.generate( - self.loads, f"reduction {reduction_key}", write=False - ) - init = self.init_cse.generate( - self.loads, f"reduction {reduction_key}", write=False - ) - init_vec = self.init_vec_cse.generate( - self.loads, f"reduction {reduction_key}", write=False - ) type_name = mlir_common.DTYPE_TO_MLIR[dtype] - init = self.const_cse.generate(self.const_buffer, f"arith.constant {reduction_init(reduction_type, dtype)} : {type_name}") vec_len = self.kernel_group.tile_desc.get_compute_vec_size() reduced_shape = self.kernel_group.tile_desc.get_mlir_vshape(type_name) - # Set accumulation var - if vec_len == 1: # 1-D vector to scalar - # Edge case for scalar - init_vec = init - else: - # Adjust shape and inital value - init_vec = self.const_cse.generate(self.const_buffer, f"vector.broadcast %{init} : {type_name} to {reduced_shape}") - self.register_var_info(init_vec, [vec_len, type_name]) - acc_var = init_vec + # Prepare reduction init + init = self.const_cse.generate(self.const_buffer, f"arith.constant {reduction_init(reduction_type, dtype)} : {type_name}") + init_vec = init if vec_len == 1 else self.const_cse.generate(self.const_buffer, f"vector.broadcast %{init} : {type_name} to 
{reduced_shape}") + self.register_var_info(init_vec, [vec_len, type_name]) + + acc_var_list = [] + iter_var_list = [] + for reduction_depth in range(self.get_nr_rdim()): + # Create reduction key + reduction_key = src_dtype, reduction_type, value, reduction_depth + acc_init_var = init_vec if reduction_depth == 0 else iter_var_list[-1] + + acc = self.reduction_cse.generate(self.loads, f"reduction {reduction_key}", write=False) + iterator = self.iterator_cse.generate(self.loads, f"reduction {reduction_key}", write=False) + acc_var_list.append(acc) + iter_var_list.append(iterator) + + # Register reduction info + self.reduction_vars[acc] = (reduction_type, iterator, acc_init_var, reduced_shape, reduction_depth) + self.reduction_cse.reduction_cache[reduction_key] = acc # Reduction body prepare - body_acc = self.reduction_cse.generate( - self.compute, f"reduction {reduction_key}body_acc", write=False - ) - body_iter_arg = self.iterator_cse.generate( - self.compute, f"reduction {reduction_key}body_iter_arg", write=False - ) + # Note: reduction body is inner most loop body. So it doesn't need reduction depth. 
+ body_key = src_dtype, reduction_type, value + body_acc = self.reduction_cse.generate(self.compute, f"reduction {body_key}body_acc", write=False) + body_iter_arg = self.iterator_cse.generate(self.compute, f"reduction {body_key}body_iter_arg", write=False) self.register_var_info(body_iter_arg, [vec_len, type_name]) - - self.reduction_vars[acc] = (reduction_type, iterator, acc_var, reduced_shape) - self.affine_yield[body_acc] = reduced_shape - self.reduction_cse.reduction_cache[reduction_key] = acc - self.iterator_cse.reduction_cache[reduction_key] = iterator - self.init_cse.reduction_cache[reduction_key] = init_vec + acc_var_list.append(body_acc) # Reduction body codegen - mask_shape, mask_var = self.get_mask() + _, mask_var = self.get_mask() if mask_var is not None: value = ops.where(mask_var, value, init_vec) result = reduction_partial_combine_vec(reduction_type, value, body_iter_arg) - self.compute_body_loop.reduction_vars[body_acc] = (reduction_type, body_iter_arg, iterator, reduced_shape) + self.compute_body_loop.reduction_vars[body_acc] = (reduction_type, body_iter_arg, iter_var_list[-1], reduced_shape) self.compute_body_loop.affine_yield[result] = reduced_shape + # Register affine yield var + for reduction_depth, acc in enumerate(acc_var_list[1:]): + self.affine_yield[acc] = reduced_shape, reduction_depth + # Final reduction - reduction_size = self.kernel_group.tile_desc.get_numel_per_lane() // self.kernel_group.tile_desc.get_tile_size()[-1] + acc = acc_var_list[0] # Set outermost acc var + reduction_size = self.kernel_group.tile_desc.get_numel_per_lane() // self.kernel_group.tile_desc.get_reduction_numel() assert(vec_len % reduction_size==0) if vec_len > reduction_size: init = self.const_cse.generate(self.reductions_suffix, f"arith.constant {reduction_init(reduction_type, dtype)} : {type_name}") @@ -1205,7 +1281,7 @@ def store_reduction(self, name, index, value): dram_shape = mlir_common.MLIRKernelArgs.get_mlir_shape(self.buffer_types[name]) tile_shape = 
local_tile_desc.get_mlir_shape(mlir_dtype) tile_stride = local_tile_desc.get_tile_stride() - compute_vec_size = self.kernel_group.tile_desc.get_numel_per_lane() // self.kernel_group.tile_desc.get_tile_size()[-1] + compute_vec_size = self.kernel_group.tile_desc.get_numel_per_lane() // self.kernel_group.tile_desc.get_reduction_numel() if compute_vec_size == 1: vshape = f"{mlir_dtype}" else: @@ -1214,7 +1290,8 @@ def store_reduction(self, name, index, value): if self.welford_reduce_out is not None: sum, sqr_sum, _ = self.welford_reduce_out # mean - divider = self.cse.generate(self.reductions_suffix, f"arith.constant {float(self.ranges[self.reduction_depth])} : f32") + reduction_numel = reduce(mul, self.ranges[self.reduction_depth:], 1) + divider = self.cse.generate(self.reductions_suffix, f"arith.constant {float(reduction_numel)} : f32") if self.buffer_types[name][1] > 1: divider_vec = self.cse.generate(self.reductions_suffix, f"vector.broadcast %{divider} : f32 to vector<{self.var_info[sum][0]}x{mlir_dtype}>") else: @@ -1254,13 +1331,17 @@ def store_reduction(self, name, index, value): def indirect_indexing(self, index_var, size, check=True): return str(index_var) - def _index_expr(self, tile_size, renamed_expression, index, base_vector_index): - tile_desc = self.kernel_group.tile_desc - compute_vec_size = tile_desc.get_compute_vec_size() + def _index_expr(self, tile_desc, renamed_expression, index, base_vector_index): + # In case of index expr, dimension size should be divisible by tile size + if not self.kernel_group.tile_desc.is_dim_dividable(self.ranges): + new_tile_size = self.kernel_group.tile_desc.adjust_tile_to_divisible(self.ranges) + self.kernel_group.tile_desc.set_tile_size(new_tile_size) + self.reset("recompile") + raise mlir_common.RecompileSignal(f"Index access (tile size {self.kernel_group.tile_desc.get_tile_size()} is not divisible by {self.ranges})") - strides = [1] * len(tile_size) - for i in range(len(tile_size) - 2, -1, -1): - strides[i] = 
strides[i + 1] * tile_size[i + 1] + tile_size = tile_desc.get_tile_size_per_lane() + compute_vec_size = tile_desc.get_compute_vec_size() + strides = tile_desc.get_tile_stride_per_lane() # Create vector index compute_vec = self.cse.generate(self.compute, f"vector.broadcast %{self.compute_idx} : index to vector<{compute_vec_size}xindex>") @@ -1278,7 +1359,25 @@ def _index_expr(self, tile_size, renamed_expression, index, base_vector_index): self.register_var_info(mod_vec, [compute_vec_size, "index"]) dim = ops.modular(ops.div(vector_index, div_vec), mod_vec) if idx == tile_desc.vlane_split_axis: # Need to add vector lane offset - offset = tile_desc.vlane_stride * strides[idx] + offset = tile_desc.vlane_stride #* strides[idx] + outer_sz = tile_size[idx] // tile_desc.vlane_stride + + nr_vector_lane = self.get_const_cse(self.vector_lane, "index") + nr_vector_lane_vec = self.const_cse.generate(self.const_buffer, f"vector.broadcast %{nr_vector_lane} : index to vector<{compute_vec_size}xindex>") + self.register_var_info(nr_vector_lane_vec, [compute_vec_size, "index"]) + + vlane_stride_coeff = self.get_const_cse(tile_desc.vlane_stride, "index") + vlane_outer_coeff = self.get_const_cse(outer_sz, "index") + vlane_stride_vec = self.const_cse.generate(self.const_buffer, f"vector.broadcast %{vlane_stride_coeff} : index to vector<{compute_vec_size}xindex>") + vlane_outer_vec = self.const_cse.generate(self.const_buffer, f"vector.broadcast %{vlane_outer_coeff} : index to vector<{compute_vec_size}xindex>") + self.register_var_info(vlane_stride_vec, [compute_vec_size, "index"]) + self.register_var_info(vlane_outer_vec, [compute_vec_size, "index"]) + stride_dim = ops.modular(dim, vlane_stride_vec) + outer_dim = ops.modular(ops.div(dim, vlane_stride_vec), vlane_outer_vec) + + dim = ops.add(stride_dim, ops.mul(outer_dim, nr_vector_lane_vec)) + + # Prepare vlane offset (vidx) vlane_coeff = self.get_const_cse(0, "i64") vlane_vec_size = 4 vlane_vec = 
self.const_cse.generate(self.const_buffer, f"vector.broadcast %{vlane_coeff} : i64 to vector<{vlane_vec_size}xi64>") @@ -1286,6 +1385,7 @@ def _index_expr(self, tile_size, renamed_expression, index, base_vector_index): self.register_var_info(vlane_offset, [vlane_vec_size, "i64"]) vlane_offset = ops.index_cast(vlane_offset, "index") self.register_var_info(vlane_offset, [vlane_vec_size, "index"]) + dim = ops.add(dim, vlane_offset) dim_list.append(dim) @@ -1296,6 +1396,7 @@ def _index_expr(self, tile_size, renamed_expression, index, base_vector_index): self.register_var_info(index_vec, [compute_vec_size, "index"]) offset = ops.add(index_vec, dim_list[i]) dim_list[i] = offset + arg_lists = [] for arg in renamed_expression.args: if isinstance(arg, sympy.Integer): @@ -1330,23 +1431,37 @@ def _index_expr(self, tile_size, renamed_expression, index, base_vector_index): return accum def index_expr(self, index, dtype): - tile_desc = self.kernel_group.tile_desc - tile_size = tile_desc.get_tile_size_per_lane() - mlir_dtype = mlir_common.DTYPE_TO_MLIR[dtype] - str_tile_size = [str(dim) for dim in tile_size] + base_tile_desc = self.kernel_group.tile_desc + if len(self.ranges) != self.reduction_depth: + # FIXME. 
This is a temporary solution to get tile stride of the reduction case + tile_desc = mlir_common.MLIRMultiDimTile( + base_tile_desc.get_tile_size(), + base_tile_desc.vector_lane, + base_tile_desc.vlane_split_axis, + base_tile_desc.vlane_stride, + base_tile_desc.get_compute_vec_size(), + ) + axis_order = list(range(len(tile_desc.get_tile_size()))) + axis_order = axis_order[1:] + axis_order[:1] # Move the first axis to the end + tile_desc.set_tile_size(tile_desc.get_tile_size(), axis_order) + else: + tile_desc = base_tile_desc compute_vec_size = tile_desc.get_compute_vec_size() - tile_shape = f"memref<{compute_vec_size}xi64, 1>" - vshape = f"vector<{compute_vec_size}xi64>" + + + tile_shape = f"memref<{compute_vec_size*self.vector_lane}xindex, 1>" + vshape = f"vector<{compute_vec_size}xindex>" # Create base_vector index var c_type = "uint64_t" new_name = f"index_expr_{compute_vec_size}" if new_name not in self.global_vars_dict: - self.header.writeline(f"{c_type} {new_name}_spad[{compute_vec_size}] __attribute__ ((section(\".spad\")));") + self.header.writeline(f"{c_type} {new_name}_spad[{compute_vec_size*self.vector_lane}] __attribute__ ((section(\".spad\")));") self.gem5_header.writeline(f"{c_type} {new_name}_spad[{compute_vec_size}] __attribute__((aligned(64)));") self.global_vars.writeline(f"memref.global @{new_name}_spad : {tile_shape}") self.global_vars_dict[new_name] = dict() sram_var = self.spad_cse.generate(self.spad_buffer, f"memref.get_global @{new_name}_spad : {tile_shape}") + # Initialize base vector if not self.base_vector_initialized: init_iter = "iter" @@ -1354,20 +1469,17 @@ def index_expr(self, index, dtype): self.spad_buffer.writeline(parallel_map) with self.spad_buffer.indent(): self.spad_buffer.writeline(f"%init_vec = vector.broadcast %{init_iter} : index to vector<2xindex>") - self.spad_buffer.writeline(f"%init_cvt_vec = arith.index_cast %init_vec : vector<2xindex> to vector<2xi64>") - self.spad_buffer.writeline(f"affine.vector_store %init_cvt_vec, 
%{sram_var}[%{init_iter}] : {tile_shape}, vector<2xi64>") + self.spad_buffer.writeline(f"affine.vector_store %init_vec, %{sram_var}[%{init_iter}] : {tile_shape}, vector<2xindex>") self.spad_buffer.writeline("}") self.base_vector_initialized = True line = f"affine.vector_load %{sram_var}[0] : {tile_shape}, {vshape}" - out = self.cse.generate(self.compute, line) - self.register_var_info(out, [compute_vec_size, "i64"]) - base_vector_index = ops.index_cast(out, "index") + base_vector_index = self.cse.generate(self.compute, line) self.register_var_info(base_vector_index, [compute_vec_size, "index"]) renamed_symbols = {symbol: "d"+str(symbol)[5:] for symbol in index.free_symbols} renamed_expression = index.subs(renamed_symbols) - result = self._index_expr(tile_size, renamed_expression, index, base_vector_index) + result = self._index_expr(tile_desc, renamed_expression, index, base_vector_index) return result def codegen_global_init(self): @@ -1403,10 +1515,11 @@ def codegen_loops(self): with contextlib.ExitStack() as stack: # Add reduction loops if len(reductions.loops): - reduction_lines = reductions.loops[0].lines() - epilogue = reductions.loops[0].epilogue_line() - code.writelines(reduction_lines) - stack.enter_context(code.indent(attribute="{accumulation_loop=true}", suffix=epilogue)) + for reduction_loop in reductions.loops: + reduction_lines = reduction_loop.lines() + epilogue = reduction_loop.epilogue_line() + code.writelines(reduction_lines) + stack.enter_context(code.indent(attribute="{accumulation_loop=true}", suffix=epilogue)) code.splice(self.applys) code.splice(self.indexed_buffer) code.splice(self.dma_loads) @@ -1596,7 +1709,8 @@ def get_dma_info(self, name, index, broadcast=True, store_reduction=False, buffe elif len(local_dims) == 3: is_reduction = self.reduction_depth < 3 and not store_reduction if is_reduction: - local_tile_desc.set_tile_size([kg_tile_desc.get_dim_size(dim) for dim in local_dims], [1, 2, 0]) + axis_order = [1, 2, 0] if 
self.get_nr_rdim()==1 else [2, 1, 0] + local_tile_desc.set_tile_size([kg_tile_desc.get_dim_size(dim) for dim in local_dims], axis_order) local_tile_desc.vlane_split_axis = local_vlane_split_axis local_tile_desc.vlane_stride = kg_tile_desc.vlane_stride else: @@ -1652,6 +1766,40 @@ def get_dma_info(self, name, index, broadcast=True, store_reduction=False, buffe sorted_keys = sorted(dram_dict.keys()) dram_stride = sum((dram_dict[key] for key in sorted_keys), []) + # Support floordiv pattern + # FIXME. How to integrate implicit dims and floordiv? + # This was introduced to support GroupNorm + if index.has(FloorDiv) and not index.has(ModularIndexing): + dim_divisor = [1] * len(local_dims) + for sub in sympy.preorder_traversal(index): + if isinstance(sub, FloorDiv): + if not str(sub.args[0]).startswith("index"): + continue + dim_idx = int((str(sub.args[0])[5:])) + if int(self.kernel_group.tile_desc.get_tile_size()[dim_idx] % sub.args[1]) != 0: + # In this case, need to recompile + original_size = self.kernel_group.tile_desc.get_tile_size()[dim_idx] + divisor = sub.args[1] + new_size = ((original_size + divisor - 1) // divisor) * divisor + new_tile_sizes = list(self.kernel_group.tile_desc.get_tile_size()) + new_tile_sizes[dim_idx] = new_size + self.kernel_group.tile_desc.set_tile_size(new_tile_sizes) + + # Send recompile signal + self.reset("recompile") + raise mlir_common.RecompileSignal(f"Tile size {self.kernel_group.tile_desc.get_tile_size()[dim_idx]} is not divisible by {sub.args[1]}") + dim_divisor[dim_idx] = sub.args[1] + + # Update dram_stride, just insert 0 next to target dim + offset = 0 + for dim_idx, divisor in enumerate(dim_divisor): + if divisor == 1: + continue + dram_stride.insert(dim_idx+offset+1, 0) + local_tile_desc.apply_divisor(dim_idx+offset, divisor, "pad") + local_tile_desc.apply_divisor(dim_idx+offset, divisor, "split") + offset = offset+1 + # FIXME. It will be nice to modify node instead of this exception handling... 
if len(self.itervars) == 1 and self.reduction_depth == 0: # In case of reduction loop only case, we will add dummy loop so shift it once @@ -1771,18 +1919,50 @@ def get_mask(self): def convert_indirect_indexing(self, index :sympy.Expr): if "tmp" not in str(index): - return index + return index, None + + # Note: In case of indirect indexing, dimensions should be divisible by tile size + if not self.kernel_group.tile_desc.is_dim_dividable(self.ranges): + new_tile_size = self.kernel_group.tile_desc.adjust_tile_to_divisible(self.ranges) + self.kernel_group.tile_desc.set_tile_size(new_tile_size) + self.reset("recompile") + raise mlir_common.RecompileSignal(f"Indirect access (tile size {self.kernel_group.tile_desc.get_tile_size()} is not divisible by {self.ranges})") # Process start indirect_dims = [str(dim) for dim in index.free_symbols if "tmp" in str(dim)] indirect_dims.sort() first_dim = indirect_dims[0] spad_vars = dict() - tmp_comp, self.compute = self.compute, self.dma_loads + old_compute, old_dma_lods, old_dma_stores = self.compute, self.dma_loads, self.dma_stores + compute_dependecy = any([target_dim not in self.spad_buffer_dict for target_dim in indirect_dims]) + if compute_dependecy: + self.compute = old_dma_stores + target_dma_buffers = self.dma_stores + else: + self.compute = old_dma_lods + target_dma_buffers = self.dma_loads # Load indirect operands for target_dim in indirect_dims: - sram_var, _, tile_numel_per_lane, sram_index_var, tile_shape, vshape = self.spad_buffer_dict[target_dim] + if target_dim in self.spad_buffer_dict: + sram_var, _, tile_numel_per_lane, sram_index_var, tile_shape, vshape = self.spad_buffer_dict[target_dim] + else: + # FIXME. 
+ var_info = [v for k, v in self.var_info.items() if str(k) == target_dim][0] + dtype = mlir_common.MLIR_TO_DTYPE[var_info[1]] + + local_tile_desc = self.kernel_group.tile_desc + tile_numel_per_lane = local_tile_desc.get_numel_per_lane() + tile_shape = local_tile_desc.get_mlir_shape(var_info[1]) + vshape = f"vector<{var_info[0]}x{var_info[1]}>" + sram_var, sram_index_var = self.get_scratchpad_buffer(dtype, target_dim, local_tile_desc, target_dim) + self.spad_buffer_dict[target_dim] = [sram_var, local_tile_desc.get_tile_size(), tile_numel_per_lane, sram_index_var, tile_shape, vshape] + + # Store the indirect index variable + opeartion = "affine.vector_store" + compute_index_var = ",".join(sram_index_var.split(",")[:-1] + [f"%{self.compute_idx}"]) + line = f"{opeartion} %{target_dim}, %{sram_var}[{compute_index_var}] : {tile_shape}, {vshape}" + self.stores.writeline(line) mlir_dtype = vshape.split("x")[1][:-1] vshape = f"vector<{tile_numel_per_lane}x{mlir_dtype}>" # FIXME. Maybe require fine grain compute... 
if tile_numel_per_lane > 1: @@ -1791,7 +1971,7 @@ def convert_indirect_indexing(self, index :sympy.Expr): else: operation = "affine.load" line = f"{operation} %{sram_var}[{sram_index_var}] : {tile_shape} // For indirect access" - out = self.cse.generate(self.dma_loads, line) + out = self.cse.generate(target_dma_buffers, line) self.register_var_info(out, [tile_numel_per_lane, mlir_dtype]) spad_vars[target_dim] = out @@ -1821,15 +2001,15 @@ def convert_indirect_indexing(self, index :sympy.Expr): else: operation = "affine.store" line = f"{operation} %{spad_vars[first_dim]}, %{sram_var}[{sram_index_var}] : {tile_shape}" - out = self.cse.generate(self.dma_loads, line, assignment=False) + out = self.cse.generate(target_dma_buffers, line, assignment=False) # Conversion mlir_dtype = self.var_info[spad_vars[first_dim]][1] line = f"affine.load %{sram_var}[{sram_index_var}] : {tile_shape}" - out = self.cse.generate(self.dma_loads, line) + out = self.cse.generate(target_dma_buffers, line) if mlir_dtype != "index": line = f"arith.index_cast %{out} : {mlir_dtype} to {'index'}" - out = self.cse.generate(self.dma_loads, line) + out = self.cse.generate(target_dma_buffers, line) self.register_var_info(out, [1, "index", [1]]) - self.compute = tmp_comp - return index + sympy.Symbol(str(out)) + self.compute, self.dma_loads, self.dma_stores = old_compute, old_dma_lods, old_dma_stores + return index + sympy.Symbol(str(out)), compute_dependecy diff --git a/PyTorchSimFrontend/mlir/mlir_common.py b/PyTorchSimFrontend/mlir/mlir_common.py index 9151ac0b..73996351 100644 --- a/PyTorchSimFrontend/mlir/mlir_common.py +++ b/PyTorchSimFrontend/mlir/mlir_common.py @@ -48,6 +48,17 @@ torch.bfloat16: "bf16", } +MLIR_TO_DTYPE = { + "f32": torch.float32, + "f64": torch.float64, + "f16": torch.float16, + "i64": torch.int64, + "i32": torch.int32, + "i16": torch.int16, + "i8": torch.int8, + "bf16": torch.bfloat16, +} + DTYPE_TO_C = { torch.float32: "float", torch.float64: "double", @@ -61,6 +72,19 @@ 
torch.bfloat16: "bfloat16", } +MLIR_TO_BIT = { + "i1": 1, + "i8": 8, + "i16": 16, + "i32": 32, + "i64": 64, + "f16": 16, + "f32": 32, + "f64": 64, + "bf16": 16, + "index": 64 +} + DTYPE_LOWP_FP = [ torch.bfloat16, torch.float16, @@ -105,6 +129,14 @@ def ctx(): return ctx() +class RecompileSignal(BaseException): + """ + Exception raised when a recompilation of a kernel or code block is required. + """ + def __init__(self, message="Recompilation requested."): + self.message = message + super().__init__(self.message) + class MLIRKernelArgs(common.KernelArgs): MLIR_ARGS_IN = 0x01 MLIR_ARGS_OUT = 0x02 @@ -193,6 +225,9 @@ def __init__(self, tile_size, vector_lane, vlane_split_axis=None, vlane_stride=N self.implicit_dim_size = None self.nr_rdim = 0 + # Dram offset + self.offset = sympy.Integer(0) + def set_name(self, name: str): self.name = name @@ -247,6 +282,20 @@ def update_tile_stride(self): def get_tile_stride(self): return self._tile_stride + def get_tile_stride_per_lane(self): + tile_stride = list(self.get_tile_stride()) # original strides + tile_size = list(self.get_tile_size()) # original tile size + split_axis = self.vlane_split_axis + + tile_size_per_lane = self.get_tile_size_per_lane() + coeff = tile_size[split_axis]//tile_size_per_lane[split_axis] + + # Propagate stride according to per-lane tile size + for i in range(len(tile_stride)): + if tile_stride[i] > tile_stride[split_axis]: + tile_stride[i] = tile_stride[i] // coeff + return tile_stride + def get_tile_size_per_lane(self): tile_size_per_lane = list(self._tile_size) if self.vlane_split_axis < 0 or self.vlane_split_axis >= len(tile_size_per_lane): @@ -293,8 +342,8 @@ def get_compute_vec_size(self): if self.vec_size is not None: return self.vec_size if self.nr_rdim: - assert self.nr_rdim==1 - val = self.get_numel_per_lane() // self._tile_size[-1] + assert self.nr_rdim!=0 + val = self.get_numel_per_lane() // self.get_reduction_numel() if self.get_numel_per_lane() >= val * 8: return val*8 elif 
self.get_numel_per_lane() >= val * 4: @@ -314,6 +363,73 @@ def get_compute_vec_size(self): def div_round_up(size, round_val): return (size + round_val - 1) // round_val + def apply_divisor(self, axis: int, divisor: int, mode: str = "split"): + # Apply divisor to tile size at given axis. + # This method based on axis order. + old_size = self._tile_size[axis] + if divisor == 1: + return + padded = self.div_round_up(old_size, divisor) * divisor + outer = self.div_round_up(old_size, divisor) + inner = divisor + if mode == "pad": + self._tile_size[axis] = padded + self.update_tile_stride() + return + elif mode == "split": + new_sizes = list(self._tile_size) + new_sizes[axis] = outer + new_sizes.insert(axis + 1, inner) + self._tile_size = new_sizes + + # Update tile_axis_order + old_order_val = self.tile_axis_order[axis] + new_order = list(self.tile_axis_order) + new_order.insert(axis + 1, old_order_val + 0.1) + sorted_pairs = sorted( + zip(range(len(new_order)), new_order), + key=lambda x: x[1] + ) + self.tile_axis_order = [idx for idx, _ in sorted_pairs] + self.update_tile_stride() + + if self.vlane_split_axis == axis: + self.vlane_split_axis = axis + elif self.vlane_split_axis > axis: + self.vlane_split_axis += 1 + return + else: + raise ValueError(f"Unknown mode: {mode}. 
Supported modes are 'pad' and 'split'.") + + def get_reduction_numel(self): + return reduce(mul, self.get_tile_size()[-1*self.nr_rdim:], 1) + + def is_dim_dividable(self, dim_sizes): + if len(dim_sizes) != len(self._tile_size): + raise ValueError("dim_sizes must match the tile size dimensions.") + dim_sizes_cpy = [int(d) for d in dim_sizes] + remain = dim_sizes_cpy[self.vlane_split_axis] % self.vlane_stride + if remain: + dim_sizes_cpy[self.vlane_split_axis] += self.vlane_stride - remain + return all(d % t == 0 for d, t in zip(dim_sizes_cpy, self._tile_size)) + + def adjust_tile_to_divisible(self, dim_sizes): + def _adjust_one(dim_size, tile_size): + for candidate in range(tile_size, 0, -1): + if dim_size % candidate == 0: + return candidate + return 1 + + if len(dim_sizes) != len(self._tile_size): + raise ValueError("dim_sizes must match the tile size dimensions.") + candidate_tile_size = [_adjust_one(d, t) for d, t in zip(dim_sizes, self._tile_size)] + # FIXME. Is this the only solution? 
+ # Round up + remain = candidate_tile_size[self.vlane_split_axis] % self.vlane_stride + if remain: + candidate_tile_size[self.vlane_split_axis] += self.vlane_stride - remain + return candidate_tile_size + class MLIRWrapperKenrelGroup(cpp.KernelGroup): def __init__(self): super().__init__() @@ -376,6 +492,9 @@ def set_ranges(self, lengths, reduction_lengths): self.itervars[self.reduction_depth :], ) + def get_nr_rdim(self): + return len(self.itervars[self.reduction_depth:]) + def load(self, name: str, index: sympy.Expr): raise NotImplementedError() @@ -521,6 +640,8 @@ def dummy_tile_size(): dim = int(self.recodegen.split("_")[-1]) tile_size = self.kernel_group.tile_desc.get_tile_size() # TODO: tile_size[dim] = tile_size[dim] * 2 + elif self.recodegen == "recompile": + return self.kernel_group.tile_desc else: raise NotImplementedError(f"Unknown recodegen reason: {self.recodegen}") @@ -591,26 +712,36 @@ def dummy_tile_size(): return tile_desc def codegen_nodes(self, nodes, kernel_name): - _, (group, reduction_group) = max( - nodes, key=lambda x: int(x.is_reduction()) - ).group - - # Set node range info - vars, reduction_vars = self.set_ranges(group, reduction_group) - tile_desc = self.compute_tile_size(nodes, vars, reduction_vars) - self.compute_body_loop.size = tile_desc.get_numel_per_lane() - self.compute_body_loop.step = tile_desc.get_compute_vec_size() - self.kernel_group.set_tile_info(tile_desc) - - _, _, _, self.buffer_types = self.kernel_group.args.mlir_argdefs() - with self as kernel: - for node in nodes: - node.run(vars, reduction_vars) - V.graph.removed_buffers |= self.removed_buffers - # V.graph.inplaced_to_remove |= self.inplaced_to_remove - src_code = self.codegen_kernel(kernel_name=kernel_name) - self.meta_kernel() - return src_code + recompile_try = 0 + max_retry_compile = 5 + while True: + _, (group, reduction_group) = max( + nodes, key=lambda x: int(x.is_reduction()) + ).group + + # Set node range info + vars, reduction_vars = self.set_ranges(group, 
reduction_group) + tile_desc = self.compute_tile_size(nodes, vars, reduction_vars) + self.compute_body_loop.size = tile_desc.get_numel_per_lane() + self.compute_body_loop.step = tile_desc.get_compute_vec_size() + self.kernel_group.set_tile_info(tile_desc) + try: + _, _, _, self.buffer_types = self.kernel_group.args.mlir_argdefs() + with self as kernel: + for node in nodes: + node.run(vars, reduction_vars) + except RecompileSignal as e: + recompile_try += 1 + if recompile_try > max_retry_compile: + raise RuntimeError("Failed to compile kernel after multiple attempts.") + # Retry compile nodes + #print(f"Try recompile({recompile_try}/{max_retry_compile}). Reason: {e}") + continue + V.graph.removed_buffers |= self.removed_buffers + # V.graph.inplaced_to_remove |= self.inplaced_to_remove + src_code = self.codegen_kernel(kernel_name=kernel_name) + self.meta_kernel() + return src_code def run_bench(self, nodes, kernel_name, src_code): _, _, arg_attributes, _ = self.kernel_group.args.mlir_argdefs() @@ -863,9 +994,9 @@ def __bool__(self): return bool(self.loops) def mark_reduction(self, reduction_vars, affine_yield=dict()): - for loop in self.loops: - loop.reduction_vars = reduction_vars - loop.affine_yield = affine_yield + for loop_depth, loop in enumerate(self.loops): + loop.reduction_vars = {key: list(val)[:-1] for key, val in reduction_vars.items() if val[-1] == loop_depth} + loop.affine_yield = {key: val[0] for key, val in affine_yield.items() if val[-1] == loop_depth} def mark_parallel(self, par_depth): loops = self.loops diff --git a/PyTorchSimFrontend/mlir/mlir_conv_mt_template.py b/PyTorchSimFrontend/mlir/mlir_conv_mt_template.py index c5ec004c..6dd17576 100644 --- a/PyTorchSimFrontend/mlir/mlir_conv_mt_template.py +++ b/PyTorchSimFrontend/mlir/mlir_conv_mt_template.py @@ -64,7 +64,7 @@ affine.for %o_w = 0 to {{ O_W }} step {{ TILE_O_W }} { // Initialize output {%- if BIAS %} - {{ kernel.def_dma_op("MVIN", "Bias", Bias_idx, Y_tile_desc, subtile_size=[SUB_TILE_M, 
SUB_TILE_N, TILE_O_H, TILE_O_W], indent_size=10) }} + {{ kernel.def_dma_op("MVIN", "Bias", Bias_idx, Bias_tile_desc, subtile_size=[SUB_TILE_M, SUB_TILE_N, TILE_O_H, TILE_O_W], indent_size=10) }} {%- else %} affine.vector_store %v0, %output_buffer[%c0, %c0, %c0, %c0] : {{ Y_tile_desc.get_mlir_shape(DATA_STYPE) }}, vector<{{ kernel.get_spad_size_per_lane(TILE_O_H * TILE_M, TILE_N) }}xf32> {%- endif %} @@ -138,8 +138,8 @@ def __init__(self, input_nodes, layout, input_reorder=None, **kwargs): self.padding = kwargs["padding"] self.dilation = kwargs["dilation"] self.weight_shape = [str(i) for i in input_nodes[1].layout.size] - self.input_shape = [i for i in input_nodes[0].layout.size] - self.function_name = "Conv2D_" + "_".join(self.weight_shape)+ "_" \ + self.input_shape = [str(i) for i in input_nodes[0].layout.size] + self.function_name = "Conv2D_" + "_".join(self.input_shape) + "_".join(self.weight_shape)+ "_" \ + "_".join([str(i) for i in self.stride]) \ + "_" + "_".join([str(i) for i in self.padding]) \ + "_" + "_".join([str(i) for i in self.dilation]) @@ -219,6 +219,11 @@ def render(self, # Extract Bias info Bias_idx = [Number(0), Symbol("tile_n"), Number(0), Number(0)] + Bias_tile_desc = mlir_common.MLIRMultiDimTile(Y_tile_size, kernel.vector_lane, vlane_split_axis, vlane_stride) + Bias_tile_desc.set_tile_size_stride(Y_tile_size, Y_tile_stride) + Bias_tile_desc.set_name("output_buffer") + if Bias is not None: + Bias_tile_desc.offset = Bias.get_layout().offset kernel.render_options = dict( KERNEL_NAME=self.name, @@ -257,6 +262,7 @@ def render(self, X_tile_desc = X_tile_desc, W_tile_desc = W_tile_desc, Y_tile_desc = Y_tile_desc, + Bias_tile_desc = Bias_tile_desc, X_idx = X_idx, W_idx = W_idx, Bias_idx = Bias_idx, diff --git a/PyTorchSimFrontend/mlir/mlir_conv_sb_template.py b/PyTorchSimFrontend/mlir/mlir_conv_sb_template.py index 6c31776d..8b1bf7c5 100644 --- a/PyTorchSimFrontend/mlir/mlir_conv_sb_template.py +++ b/PyTorchSimFrontend/mlir/mlir_conv_sb_template.py @@ 
-63,7 +63,7 @@ affine.for %tile_m = 0 to {{ O_W }} step {{ TILE_M }} { // Initialize output {%- if BIAS %} - {{ kernel.def_dma_op("MVIN", "Bias", Bias_idx, Y_tile_desc, subtile_size=[1, SUB_TILE_N, TILE_O_H, SUB_TILE_M], indent_size=8) }} + {{ kernel.def_dma_op("MVIN", "Bias", Bias_idx, Bias_tile_desc, subtile_size=[1, SUB_TILE_N, TILE_O_H, SUB_TILE_M], indent_size=8) }} {%- else %} affine.vector_store %v0, %output_buffer[%c0, %c0, %c0, %c0] : {{ Y_tile_desc.get_mlir_shape(DATA_STYPE) }}, vector<{{ kernel.get_spad_size_per_lane(TILE_O_H * TILE_M, TILE_N) }}xf32> {%- endif %} @@ -139,8 +139,8 @@ def __init__(self, input_nodes, layout, input_reorder=None, **kwargs): self.padding = kwargs["padding"] self.dilation = kwargs["dilation"] self.weight_shape = [str(i) for i in input_nodes[1].layout.size] - self.input_shape = [i for i in input_nodes[0].layout.size] - self.function_name = "Conv2D_" + "_".join(self.weight_shape)+ "_" \ + self.input_shape = [str(i) for i in input_nodes[0].layout.size] + self.function_name = "Conv2D_" + "_".join(self.input_shape) + "_".join(self.weight_shape)+ "_" \ + "_".join([str(i) for i in self.stride]) \ + "_" + "_".join([str(i) for i in self.padding]) \ + "_" + "_".join([str(i) for i in self.dilation]) @@ -218,6 +218,11 @@ def render(self, # Extract Bias info Bias_idx = [Number(0), Symbol("tile_n"), Number(0), Number(0)] + Bias_tile_desc = mlir_common.MLIRMultiDimTile(Y_tile_size, kernel.vector_lane, vlane_split_axis, vlane_stride) + Bias_tile_desc.set_tile_size_stride(Y_tile_size, Y_tile_stride) + Bias_tile_desc.set_name("output_buffer") + if Bias is not None: + Bias_tile_desc.offset = Bias.get_layout().offset kernel.render_options = dict( KERNEL_NAME=self.name, @@ -256,6 +261,7 @@ def render(self, X_tile_desc = X_tile_desc, W_tile_desc = W_tile_desc, Y_tile_desc = Y_tile_desc, + Bias_tile_desc = Bias_tile_desc, X_idx = X_idx, W_idx = W_idx, Bias_idx = Bias_idx, diff --git a/PyTorchSimFrontend/mlir/mlir_conv_sbs_template.py 
b/PyTorchSimFrontend/mlir/mlir_conv_sbs_template.py index 74309b30..2284c86c 100644 --- a/PyTorchSimFrontend/mlir/mlir_conv_sbs_template.py +++ b/PyTorchSimFrontend/mlir/mlir_conv_sbs_template.py @@ -64,7 +64,7 @@ affine.for %tile_m = 0 to {{ O_W }} step {{ TILE_M }} { // Initialize output {%- if BIAS %} - {{ kernel.def_dma_op("MVIN", "Bias", Bias_idx, Y_tile_desc, subtile_size=[1, SUB_TILE_N, TILE_O_H, SUB_TILE_M], indent_size=8) }} + {{ kernel.def_dma_op("MVIN", "Bias", Bias_idx, Bias_tile_desc, subtile_size=[1, SUB_TILE_N, TILE_O_H, SUB_TILE_M], indent_size=8) }} {%- else %} affine.vector_store %v0, %output_buffer[%c0, %c0, %c0, %c0] : {{ Y_tile_desc.get_mlir_shape(DATA_STYPE) }}, vector<{{ kernel.get_spad_size_per_lane(TILE_O_H * TILE_M, TILE_N) }}xf32> {%- endif %} @@ -139,8 +139,8 @@ def __init__(self, input_nodes, layout, input_reorder=None, **kwargs): self.padding = kwargs["padding"] self.dilation = kwargs["dilation"] self.weight_shape = [str(i) for i in input_nodes[1].layout.size] - self.input_shape = [i for i in input_nodes[0].layout.size] - self.function_name = "Conv2D_" + "_".join(self.weight_shape)+ "_" \ + self.input_shape = [str(i) for i in input_nodes[0].layout.size] + self.function_name = "Conv2D_" + "_".join(self.input_shape) + "_".join(self.weight_shape)+ "_" \ + "_".join([str(i) for i in self.stride]) \ + "_" + "_".join([str(i) for i in self.padding]) \ + "_" + "_".join([str(i) for i in self.dilation]) @@ -219,6 +219,11 @@ def render(self, # Extract Bias info Bias_idx = [Number(0), Symbol("tile_n"), Number(0), Number(0)] + Bias_tile_desc = mlir_common.MLIRMultiDimTile(Y_tile_size, kernel.vector_lane, vlane_split_axis, vlane_stride) + Bias_tile_desc.set_tile_size_stride(Y_tile_size, Y_tile_stride) + Bias_tile_desc.set_name("output_buffer") + if Bias is not None: + Bias_tile_desc.offset = Bias.get_layout().offset kernel.render_options = dict( KERNEL_NAME=self.name, @@ -257,6 +262,7 @@ def render(self, X_tile_desc = X_tile_desc, W_tile_desc = 
W_tile_desc, Y_tile_desc = Y_tile_desc, + Bias_tile_desc = Bias_tile_desc, X_idx = X_idx, W_idx = W_idx, Bias_idx = Bias_idx, diff --git a/PyTorchSimFrontend/mlir/mlir_conv_template.py b/PyTorchSimFrontend/mlir/mlir_conv_template.py index 9cbd6514..890b76b7 100644 --- a/PyTorchSimFrontend/mlir/mlir_conv_template.py +++ b/PyTorchSimFrontend/mlir/mlir_conv_template.py @@ -65,7 +65,7 @@ affine.for %o_w = 0 to {{ O_W }} step {{ TILE_O_W }} { // Initialize output {%- if BIAS %} - {{ kernel.def_dma_op("MVIN", "Bias", Bias_idx, Y_tile_desc, subtile_size=[SUB_TILE_M, SUB_TILE_N, TILE_O_H, TILE_O_W], indent_size=10) }} + {{ kernel.def_dma_op("MVIN", "Bias", Bias_idx, Bias_tile_desc, subtile_size=[SUB_TILE_M, SUB_TILE_N, TILE_O_H, TILE_O_W], indent_size=10) }} {%- else %} affine.vector_store %v0, %output_buffer[%c0, %c0, %c0, %c0] : {{ Y_tile_desc.get_mlir_shape(DATA_STYPE) }}, vector<{{ kernel.get_spad_size_per_lane(TILE_O_H * TILE_M, TILE_N) }}xf32> {%- endif %} @@ -143,8 +143,8 @@ def __init__(self, input_nodes, layout, input_reorder=None, **kwargs): self.padding = kwargs["padding"] self.dilation = kwargs["dilation"] self.weight_shape = [str(i) for i in input_nodes[1].layout.size] - self.input_shape = [i for i in input_nodes[0].layout.size] - self.function_name = "Conv2D_" + "_".join(self.weight_shape)+ "_" \ + self.input_shape = [str(i) for i in input_nodes[0].layout.size] + self.function_name = "Conv2D_" + "_".join(self.input_shape) + "_".join(self.weight_shape)+ "_" \ + "_".join([str(i) for i in self.stride]) \ + "_" + "_".join([str(i) for i in self.padding]) \ + "_" + "_".join([str(i) for i in self.dilation]) @@ -224,6 +224,11 @@ def render(self, # Extract Bias info Bias_idx = [Number(0), Symbol("tile_n"), Number(0), Number(0)] + Bias_tile_desc = mlir_common.MLIRMultiDimTile(Y_tile_size, kernel.vector_lane, vlane_split_axis, vlane_stride) + Bias_tile_desc.set_tile_size_stride(Y_tile_size, Y_tile_stride) + Bias_tile_desc.set_name("output_buffer") + if Bias is not None: 
+ Bias_tile_desc.offset = Bias.get_layout().offset kernel.render_options = dict( KERNEL_NAME=self.name, @@ -262,6 +267,7 @@ def render(self, X_tile_desc = X_tile_desc, W_tile_desc = W_tile_desc, Y_tile_desc = Y_tile_desc, + Bias_tile_desc = Bias_tile_desc, X_idx = X_idx, W_idx = W_idx, Bias_idx = Bias_idx, diff --git a/PyTorchSimFrontend/mlir/mlir_gemm_template.py b/PyTorchSimFrontend/mlir/mlir_gemm_template.py index f706c2e5..5b339189 100644 --- a/PyTorchSimFrontend/mlir/mlir_gemm_template.py +++ b/PyTorchSimFrontend/mlir/mlir_gemm_template.py @@ -35,7 +35,7 @@ affine.for %index0 = 0 to {{ M }} step {{ TILE_M }} { affine.for %index1 = 0 to {{ N }} step {{ TILE_N }} { {%- if Bias %} - {{ kernel.def_dma_op("MVIN", "Bias", Bias_idx, Y_tile_desc, subtile_size=[SUB_TILE_M, SUB_TILE_N], indent_size=6) }} + {{ kernel.def_dma_op("MVIN", "Bias", Bias_idx, Bias_tile_desc, subtile_size=[SUB_TILE_M, SUB_TILE_N], indent_size=6) }} {%- else %} affine.vector_store %v0, %Y_buffer[0, 0] : {{ Y_tile_desc.get_mlir_shape(DATA_STYPE) }}, vector<{{ kernel.get_spad_size_per_lane(TILE_M, TILE_N) }}xf32> {%- endif %} @@ -87,7 +87,7 @@ affine.for %index0 = 0 to {{ M }} step {{ TILE_M }} { %Y_bufferT = memref.reinterpret_cast %Y_buffer to offset: [0], sizes: [{{ TILE_M }}, {{ TILE_N }}], strides: [{{ TILE_N }}, 1] : {{ Y_tile_desc.get_mlir_shape(DATA_STYPE) }} to memref<{{ TILE_M }}x{{ TILE_N }}xf32, 1> {%- if Bias %} - {{ kernel.def_dma_op("MVIN", "Bias", Bias_idx, Y_tile_desc, subtile_size=[SUB_TILE_M, SUB_TILE_N], indent_size=6) }} + {{ kernel.def_dma_op("MVIN", "Bias", Bias_idx, Bias_tile_desc, subtile_size=[SUB_TILE_M, SUB_TILE_N], indent_size=6) }} {%- else %} affine.vector_store %v0, %Y_buffer[0, 0] : memref<{{ TILE_N }}x{{ TILE_M }}xf32, 1>, vector<{{ kernel.get_spad_size_per_lane(TILE_M, TILE_N) }}xf32> {%- endif %} @@ -164,6 +164,7 @@ def render(self, X_tile_desc = mlir_common.MLIRMultiDimTile(X_tile_size, kernel.vector_lane, vlane_split_axis, vlane_stride) 
X_tile_desc.set_tile_size_stride(X_tile_size, X_tile_stride) X_tile_desc.set_name("X_buffer") + X_tile_desc.offset = X.get_layout().offset X_stride = X.get_layout().stride X_idx = [sympy.Symbol("index0") * X_stride[0], sympy.Symbol("index2") * X_stride[1]] # To keep index arguemnt order, we used index_list @@ -172,6 +173,7 @@ def render(self, W_tile_desc = mlir_common.MLIRMultiDimTile(X_tile_size, kernel.vector_lane, vlane_split_axis, vlane_stride) W_tile_desc.set_tile_size_stride(W_tile_size, W_tile_stride) W_tile_desc.set_name("W_buffer") + W_tile_desc.offset = W.get_layout().offset W_stride = W.get_layout().stride W_idx = [sympy.Symbol("index2") * W_stride[0], sympy.Symbol("index1") * W_stride[1]] @@ -189,8 +191,12 @@ def render(self, # Extract Bias info Bias = None if len(self.input_nodes) == 2 else self.input_nodes[2] + Bias_tile_desc = mlir_common.MLIRMultiDimTile(Y_tile_size, kernel.vector_lane, vlane_split_axis, vlane_stride) + Bias_tile_desc.set_tile_size_stride(Y_tile_size, Y_tile_stride) + Bias_tile_desc.set_name("Y_buffer") if Bias is not None: Bias_stride = Bias.get_layout().stride + Bias_tile_desc.offset = Bias.get_layout().offset if nr_rdim == 0: Bias_idx = [sympy.Symbol("index0") * Bias_stride[0], sympy.Symbol("index1") * Bias_stride[1]] else: @@ -217,6 +223,7 @@ def render(self, X_tile_desc = X_tile_desc, W_tile_desc = W_tile_desc, Y_tile_desc = Y_tile_desc, + Bias_tile_desc = Bias_tile_desc, epilogue_nodes = epilogue_nodes, prologue_nodes = prologue_nodes, input_reorder = self.input_reorder diff --git a/PyTorchSimFrontend/mlir/mlir_lowering.py b/PyTorchSimFrontend/mlir/mlir_lowering.py index 9aa08754..6508ea86 100644 --- a/PyTorchSimFrontend/mlir/mlir_lowering.py +++ b/PyTorchSimFrontend/mlir/mlir_lowering.py @@ -1,7 +1,7 @@ from typing import List, Optional, Sequence import torch -from torch._inductor.lowering import lowerings +from torch._inductor.lowering import lowerings, index_impl from torch._inductor.kernel.mm_common import mm_args # from 
torch._inductor.select_algorithm import ExternKernelChoice from torch._inductor import ir @@ -175,10 +175,17 @@ def sparse_addmm(*args, **kwargs): ) return aten_spmm.bind((sp_mat1, sp_mat2), layout).output_node() +def custom_unsafe_index(x, indices): + # We can't fuse indirect access + indexed_expression + computation + if isinstance(x, TensorBox): + x.realize() + return index_impl(x, indices, check=False) + lowerings.update({getattr(aten.mm, overload): tuned_mm for overload in aten.mm.overloads()}) lowerings.update({getattr(aten.addmm, overload): tuned_addmm for overload in aten.addmm.overloads()}) lowerings.update({getattr(aten.convolution, overload): convolution for overload in aten.convolution.overloads()}) lowerings.update({getattr(aten.bmm, overload): tuned_bmm for overload in aten.bmm.overloads()}) lowerings.update({getattr(aten._sparse_addmm, overload): sparse_addmm for overload in aten._sparse_addmm.overloads()}) +lowerings.update({getattr(aten._unsafe_index, overload): custom_unsafe_index for overload in aten._unsafe_index.overloads()}) if CONFIG_USE_TIMING_POOLING: lowerings.update({getattr(aten.max_pool2d_with_indices, overload): custom_maxpool for overload in aten.max_pool2d_with_indices.overloads()}) # FIXME: maxpool should be implemented as a template \ No newline at end of file diff --git a/PyTorchSimFrontend/mlir/mlir_scheduling.py b/PyTorchSimFrontend/mlir/mlir_scheduling.py index 786971fe..2bbdb41d 100644 --- a/PyTorchSimFrontend/mlir/mlir_scheduling.py +++ b/PyTorchSimFrontend/mlir/mlir_scheduling.py @@ -46,10 +46,15 @@ def can_fuse_with_exceptions(self, node1: BaseSchedulerNode, node2: BaseSchedule if (isinstance(base_template_node1[0].node.template, MLIRGemmTemplate) or isinstance(base_template_node1[0].node.template, MLIRBMMTemplate)) and node2.is_reduction(): # For matmul/bmm+reduction case size_match = node1.get_nodes()[0].node.get_numel() == reduce(operator.mul, node2.get_nodes()[0].node.get_size(), 1) * reduce(operator.mul, 
node2.get_nodes()[0].node.get_reduction_size(), 1) - stride = [i.strip()[:-1].split(",")[-1].strip() for i in str(node2.get_nodes()[0].node).split("\n") if "r0" in i][1] target_symbol = symbols("r0") + try: + stride = [i.strip()[:-1].split(",")[-1].strip() for i in str(node2.get_nodes()[0].node).split("\n") if "r0" in i][1] + stride = int(sympify(stride).coeff(target_symbol)) + except: + return False + # We can't fuse dim=-1 - layout_possible = int(sympify(stride).coeff(target_symbol)) != 1 + layout_possible = stride != 1 # Directed linked? dependency_check = node2.get_nodes()[0] in [node.node for node in base_template_node1[0].users]# and len(node2.read_writes.reads)==1 dependency_size = all([i.get_numel() == node1.get_nodes()[0].node.get_numel() for i in node2.read_writes.reads]) @@ -104,6 +109,9 @@ def can_fuse_horizontal(self, node1, node2): if node1.is_template() and node2.is_template(): return False + if '_unsafe_index' in node1.get_nodes()[0].node.origins or "_unsafe_index" in node2.get_nodes()[0].node.origins: + return False + # Check template node fusion if node1.is_template() or node2.is_template(): # Don't fuse maxpool template code @@ -277,15 +285,15 @@ def codegen_template_code(self, kernel, render, template_node, prologue_nodes, e vars, reduction_vars = kernel.set_ranges(group, reduction_group) for node in prologue_nodes: # Reuse created spad - read_list = sorted(list(node.read_writes.reads)) + read_list = sorted([i.name for i in node.read_writes.reads]) candidate_found = False # Why? 
There is a case that memdep.get_size() != data.get_size() buf_dict = {} buf_dict.update({val.name : val for val in V.graph.buffers}) buf_dict.update(V.graph.graph_inputs) for candidate_read in read_list: - if candidate_read.name in buf_dict and reduce(operator.mul, buf_dict[candidate_read.name].get_size(), 1) == node.node.get_numel(): - prologue_input_arg = candidate_read.name + if candidate_read in buf_dict and reduce(operator.mul, buf_dict[candidate_read].get_size(), 1) == node.node.get_numel(): + prologue_input_arg = candidate_read candidate_found = True break assert(candidate_found) diff --git a/PyTorchSimFrontend/mlir/mlir_template.py b/PyTorchSimFrontend/mlir/mlir_template.py index 0b2a08f8..820d5c0d 100644 --- a/PyTorchSimFrontend/mlir/mlir_template.py +++ b/PyTorchSimFrontend/mlir/mlir_template.py @@ -663,7 +663,7 @@ def def_dma_op(self, dma_type, dram_var:str, index_list:list, tile_desc:mlir_com # Prepare code block local_code = IndentedBuffer() with V.set_kernel_handler(self): - index_var = self.parse_index_list(index_list, local_code) + index_var = self.parse_index_list(index_list, local_code, offset=tile_desc.offset) node_layout = self.named_nodes[dram_var].get_layout() if dram_var in self.exception_nodes: numel = self.exception_nodes[dram_var]["numel"] diff --git a/gem5_script/vpu_config.py b/gem5_script/vpu_config.py index eeeaefab..33d26b5f 100644 --- a/gem5_script/vpu_config.py +++ b/gem5_script/vpu_config.py @@ -15,9 +15,11 @@ class SparseAccelerator(MinorFU): class SpecialFunctionUnit(MinorFU): opClasses = minorMakeOpClassSet([ - "CustomMatMulvexp", - "CustomMatMulverf", - "CustomMatMulvtanh", + "CustomVexp", + "CustomVerf", + "CustomVtanh", + "CustomVsin", + "CustomVcos", ]) opLat = 10 diff --git a/tests/Diffusion/test_diffusion.py b/tests/Diffusion/test_diffusion.py new file mode 100644 index 00000000..03d1b721 --- /dev/null +++ b/tests/Diffusion/test_diffusion.py @@ -0,0 +1,577 @@ +import os +import sys +import math +import argparse +import 
torch +import torch._dynamo +from diffusers.models.unets.unet_2d_blocks import CrossAttnDownBlock2D, CrossAttnUpBlock2D, UNetMidBlock2DCrossAttn +from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel +from diffusers.models.upsampling import Upsample2D +from diffusers.models.resnet import ResnetBlock2D + +def test_result(name, out, cpu_out, rtol=1e-4, atol=1e-4): + if torch.allclose(out.cpu(), cpu_out, rtol=rtol, atol=atol): + message = f"|{name} Test Passed|" + print("-" * len(message)) + print(message) + print("-" * len(message)) + else: + message = f"|{name} Test Failed|" + print("-" * len(message)) + print(message) + print("-" * len(message)) + print("custom out: ", out.cpu()) + print("cpu out: ", cpu_out) + diff = torch.max(torch.abs(out.cpu() - cpu_out)).item() + print(f"Max abs diff: {diff}") + exit(1) + +@torch.no_grad() +def test_unet_conditional( + device, + model_id="runwayml/stable-diffusion-v1-5", + batch=1, + dtype="float32", + rtol=1e-4, + atol=1e-4, + prompt="a cat in a hat", +): + from diffusers import DiffusionPipeline + + dtype_map = {"float32": torch.float32, "float16": torch.float16, "bfloat16": torch.bfloat16} + torch_dtype = dtype_map.get(dtype, torch.float32) + + print(f"Loading pipeline: {model_id} (dtype={torch_dtype})") + pipe = DiffusionPipeline.from_pretrained(model_id, torch_dtype=torch_dtype) + pipe.to("cpu") + + unet = pipe.unet.eval() + in_ch = unet.config.in_channels + latent_sz = getattr(unet.config, "sample_size", 64) + cross_dim = getattr(unet.config, "cross_attention_dim", None) + + g = torch.Generator().manual_seed(0) + latents = torch.randn(batch, in_ch, latent_sz, latent_sz, generator=g, dtype=torch_dtype) + timestep = torch.tensor(999, dtype=torch.float32) + + enc_states_cpu = None + if hasattr(pipe, "tokenizer") and hasattr(pipe, "text_encoder") and cross_dim is not None: + try: + tokens = pipe.tokenizer( + [prompt] * batch, + padding="max_length", + max_length=getattr(pipe.tokenizer, "model_max_length", 
77), + truncation=True, + return_tensors="pt", + ) + text_out = pipe.text_encoder(input_ids=tokens.input_ids).last_hidden_state # [B, T, D] + if text_out.shape[-1] != cross_dim: + print(f"Warning: text_encoder dim {text_out.shape[-1]} != cross_attn dim {cross_dim}. Fallback to random.") + raise RuntimeError("cross-dim mismatch") + enc_states_cpu = text_out.to(dtype=torch_dtype) + except Exception as e: + print(f"Text encoder unavailable or mismatch: {e}. Fallback to random encoder states.") + if enc_states_cpu is None: + if cross_dim is None: + enc_states_cpu = None + else: + seq_len = 77 + enc_states_cpu = torch.randn(batch, seq_len, cross_dim, generator=g, dtype=torch_dtype) + + latents_dev = latents.to(device) + timestep_dev = timestep.to(device) + if enc_states_cpu is not None: + enc_states_dev = enc_states_cpu.to(device) + else: + enc_states_dev = None + + print("Compiling UNet with torch.compile(...)") + unet_dev = unet.to(device) + unet_compiled = torch.compile(unet_dev, dynamic=False) + + # Forward (device) + with torch.no_grad(): + if enc_states_dev is None: + out_dev = unet_compiled(latents_dev, timestep_dev).sample + else: + out_dev = unet_compiled(latents_dev, timestep_dev, encoder_hidden_states=enc_states_dev).sample + + unet_cpu = unet.to("cpu") + if enc_states_cpu is None: + out_cpu = unet_cpu(latents.cpu(), timestep).sample + else: + out_cpu = unet_cpu(latents.cpu(), timestep, encoder_hidden_states=enc_states_cpu).sample + + test_result(f"UNet({model_id}) forward", out_dev, out_cpu, rtol=rtol, atol=atol) + print("Max diff >", torch.max(torch.abs(out_dev.cpu() - out_cpu)).item()) + print("UNet Simulation Done") + +def test_unet_mid_block2d_cross_attn( + device, + in_channels=320, + temb_channels=320, + cross_attention_dim=768, + batch=1, + height=32, + width=32, + rtol=1e-4, + atol=1e-4, + num_layers=1, + num_attention_heads=8, + dual_cross_attention=False, +): + print(f"Testing UNetMidBlock2DCrossAttn on device: {device}") + + cpu_block = 
UNetMidBlock2DCrossAttn( + in_channels=in_channels, + temb_channels=temb_channels, + num_layers=num_layers, + cross_attention_dim=cross_attention_dim, + num_attention_heads=num_attention_heads, + dual_cross_attention=dual_cross_attention, + ).to("cpu").eval() + + g = torch.Generator().manual_seed(0) + hidden_states_cpu = torch.randn(batch, in_channels, height, width, generator=g) + temb_cpu = torch.randn(batch, temb_channels, generator=g) + encoder_hidden_states_cpu = torch.randn(batch, 77, cross_attention_dim, generator=g) + + with torch.no_grad(): + cpu_out = cpu_block( + hidden_states=hidden_states_cpu, + temb=temb_cpu, + encoder_hidden_states=encoder_hidden_states_cpu, + ) + + dev_block = cpu_block.to(device).eval() + dev_block = torch.compile(dev_block, dynamic=False) + + hidden_states_dev = hidden_states_cpu.to(device) + temb_dev = temb_cpu.to(device) + encoder_hidden_states_dev = encoder_hidden_states_cpu.to(device) + + with torch.no_grad(): + dev_out = dev_block( + hidden_states=hidden_states_dev, + temb=temb_dev, + encoder_hidden_states=encoder_hidden_states_dev, + ) + + test_result("UNetMidBlock2DCrossAttn", dev_out, cpu_out, rtol=rtol, atol=atol) + print("Max diff >", torch.max(torch.abs(dev_out.cpu() - cpu_out)).item()) + print("UNetMidBlock2DCrossAttn simulation done.") + +def test_cross_attn_up_block2d( + device, + in_channels=320, + out_channels=320, + prev_output_channel=320, + temb_channels=1280, + cross_attention_dim=768, + batch=1, + height=32, + width=32, + rtol=1e-4, + atol=1e-4, + num_layers=1, + num_attention_heads=8, + dual_cross_attention=False, +): + print(f"Testing CrossAttnUpBlock2D on device: {device}") + + cpu_block = CrossAttnUpBlock2D( + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + num_layers=num_layers, + cross_attention_dim=cross_attention_dim, + num_attention_heads=num_attention_heads, + dual_cross_attention=dual_cross_attention, + # 
add_upsample=add_upsample, + ).to("cpu").eval() + + g = torch.Generator().manual_seed(0) + hidden_states_cpu = torch.randn(batch, in_channels, height, width, generator=g) + temb_cpu = torch.randn(batch, temb_channels, generator=g) + encoder_hidden_states_cpu = torch.randn(batch, 77, cross_attention_dim, generator=g) + + res_hidden_states_tuple_cpu = tuple( + torch.randn(batch, prev_output_channel, height, width, generator=g) for _ in range(num_layers) + ) + + with torch.no_grad(): + cpu_out = cpu_block( + hidden_states=hidden_states_cpu, + res_hidden_states_tuple=res_hidden_states_tuple_cpu, + temb=temb_cpu, + encoder_hidden_states=encoder_hidden_states_cpu, + ) + + dev_block = cpu_block.to(device).eval() + dev_block = torch.compile(dev_block, dynamic=False) + + hidden_states_dev = hidden_states_cpu.to(device) + temb_dev = temb_cpu.to(device) + encoder_hidden_states_dev = encoder_hidden_states_cpu.to(device) + res_hidden_states_tuple_dev = tuple(t.to(device) for t in res_hidden_states_tuple_cpu) + + with torch.no_grad(): + dev_out = dev_block( + hidden_states=hidden_states_dev, + res_hidden_states_tuple=res_hidden_states_tuple_dev, + temb=temb_dev, + encoder_hidden_states=encoder_hidden_states_dev, + ) + + test_result("CrossAttnUpBlock2D", dev_out, cpu_out, rtol=rtol, atol=atol) + print("Max diff >", torch.max(torch.abs(dev_out.cpu() - cpu_out)).item()) + print("CrossAttnUpBlock2D simulation done.") + +def test_unet2d_condition_model( + device, + batch=1, + in_channels=4, + out_channels=4, + sample_size=32, + cross_attention_dim=[768, 768], + seq_len=77, + block_out_channels=(64, 64), + layers_per_block=[1, 1], + attention_head_dim=(8, 8), + rtol=1e-4, + atol=1e-4, + stride=None, +): + down_block_types = ("CrossAttnDownBlock2D", "DownBlock2D") + up_block_types = ("UpBlock2D", "CrossAttnUpBlock2D") + + unet_cpu = UNet2DConditionModel( + sample_size=sample_size, + in_channels=in_channels, + out_channels=out_channels, + down_block_types=down_block_types, + 
up_block_types=up_block_types, + block_out_channels=block_out_channels, + layers_per_block=layers_per_block, + cross_attention_dim=cross_attention_dim, + attention_head_dim=attention_head_dim, + ).to("cpu").eval() + + g = torch.Generator().manual_seed(0) + + if stride is not None: + x_cpu = torch.empty_strided([batch, in_channels, sample_size, sample_size], stride).normal_(generator=g) + else: + x_cpu = torch.randn(batch, in_channels, sample_size, sample_size, generator=g) + + t_cpu = torch.randint(low=0, high=1000, size=(batch,), generator=g, dtype=torch.long) + encoder_hidden_states_cpu = torch.randn(batch, seq_len, cross_attention_dim[0], generator=g) + + # CPU result + with torch.no_grad(): + y_cpu = unet_cpu( + sample=x_cpu, + timestep=t_cpu, + encoder_hidden_states=encoder_hidden_states_cpu, + ).sample # UNet2DConditionOutput.sample (Tensor) + + # Device + torch.compile + unet_dev = unet_cpu.to(device).eval() + unet_dev = torch.compile(unet_dev, dynamic=False) + + x_dev = x_cpu.to(device) + t_dev = t_cpu.to(device) + encoder_hidden_states_dev = encoder_hidden_states_cpu.to(device) + + with torch.no_grad(): + y_dev = unet_dev( + sample=x_dev, + timestep=t_dev, + encoder_hidden_states=encoder_hidden_states_dev, + ).sample + + for idx, (cpu, dev) in enumerate(zip(y_cpu, y_dev)): + test_result(f"[{idx}] UNet2DConditionModel", dev.cpu(), cpu, rtol=rtol, atol=atol) + max_diff = torch.max(torch.abs(dev.detach().cpu() - cpu)).item() + print("Max diff >", max_diff) + print("UNet2DConditionModel simulation done.") + +def test_cross_attn_down_block2d( + device, + in_channels=320, + out_channels=320, + temb_channels=1280, + cross_attention_dim=768, + batch=1, + height=32, + width=32, + rtol=1e-4, + atol=1e-4, + num_layers=1, + num_attention_heads=8, + dual_cross_attention=False +): + print(f"Testing CrossAttnDownBlock2D on device: {device}") + + # 1. 
Initialize the module on CPU + cpu_block = CrossAttnDownBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + num_layers=num_layers, + cross_attention_dim=cross_attention_dim, + num_attention_heads=num_attention_heads, + dual_cross_attention=dual_cross_attention + ).to("cpu").eval() + + # 2. Create synthetic inputs on CPU + g = torch.Generator().manual_seed(0) + hidden_states_cpu = torch.randn(batch, in_channels, height, width, generator=g) + temb_cpu = torch.randn(batch, temb_channels, generator=g) + encoder_hidden_states_cpu = torch.randn(batch, 77, cross_attention_dim, generator=g) + + # 3. Get the output from the CPU module + with torch.no_grad(): + cpu_out, _ = cpu_block( + hidden_states=hidden_states_cpu, + temb=temb_cpu, + encoder_hidden_states=encoder_hidden_states_cpu, + ) + + # 4. Initialize the module on the custom device + device_block = cpu_block.to(device).eval() + device_block = torch.compile(device_block, dynamic=False) + + # 5. Move inputs to the custom device + hidden_states_dev = hidden_states_cpu.to(device) + temb_dev = temb_cpu.to(device) + encoder_hidden_states_dev = encoder_hidden_states_cpu.to(device) + + # 6. Get the output from the custom device module + with torch.no_grad(): + dev_out, _ = device_block( + hidden_states=hidden_states_dev, + temb=temb_dev, + encoder_hidden_states=encoder_hidden_states_dev, + ) + + # 7. 
Compare the results + test_result("CrossAttnDownBlock2D", dev_out, cpu_out, rtol=rtol, atol=atol) + print("Max diff >", torch.max(torch.abs(dev_out.cpu() - cpu_out)).item()) + print("CrossAttnDownBlock2D simulation done.") + +def test_resnetblock2d( + device, + batch=1, + in_channels=320, + out_channels=320, + height=32, + width=32, + temb_channels=128, + resnet_eps=1e-5, + resnet_groups=32, + dropout=0.0, + resnet_time_scale_shift="default", # e.g., "default" | "scale_shift" + resnet_act_fn="swish", + output_scale_factor=1.0, + resnet_pre_norm=True, + rtol=1e-4, + atol=1e-4, + stride=None, +): + print(f"Testing ResnetBlock2D(down=True) on device: {device}") + + g = torch.Generator().manual_seed(0) + cpu_blk = ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm + ).to("cpu").eval() + + if stride is not None: + x_cpu = torch.empty_strided([batch, in_channels, height, width], stride).normal_() + else: + x_cpu = torch.randn(batch, in_channels, height, width, generator=g) + + temb_cpu = torch.randn(batch, temb_channels, generator=g) + + with torch.no_grad(): + y_cpu = cpu_blk(x_cpu, temb=temb_cpu) + + dev_blk = cpu_blk.to(device).eval() + dev_blk = torch.compile(dev_blk, dynamic=False) + + x_dev = x_cpu.to(device) + temb_dev = temb_cpu.to(device) + + with torch.no_grad(): + y_dev = dev_blk(x_dev, temb=temb_dev) + + try: + test_result("ResnetBlock2D(down=True)", y_dev, y_cpu, rtol=rtol, atol=atol) + except NameError: + # fallback: PyTorch의 기본 엄밀 비교 + torch.testing.assert_close(y_dev.cpu(), y_cpu, rtol=rtol, atol=atol) + print("ResnetBlock2D(down=True) close-check passed.") + + max_diff = torch.max(torch.abs(y_dev.cpu() - y_cpu)).item() + print("Max diff >", max_diff) + print("ResnetBlock2D simulation done.") + +def 
test_groupnorm( + device, + batch=1, + channels=320, + height=32, + width=32, + num_groups=32, + eps=1e-5, + rtol=1e-4, + atol=1e-4, + stride=None +): + print(f"Testing GroupNorm on device: {device}") + + # 1. Initialize the module on CPU + cpu_norm = torch.nn.GroupNorm( + num_groups=num_groups, + num_channels=channels, + eps=eps, + affine=True + ).to("cpu").eval() + + # 2. Create synthetic inputs on CPU + g = torch.Generator().manual_seed(0) + if stride is not None: + input_cpu = torch.empty_strided([batch, channels, height, width], stride) + input_cpu = input_cpu.normal_() + else: + input_cpu = torch.randn(batch, channels, height, width, generator=g) + + # 3. Get the output from the CPU module + with torch.no_grad(): + cpu_out = cpu_norm(input_cpu) + + # 4. Initialize the module on the custom device + device_norm = torch.nn.GroupNorm( + num_groups=num_groups, + num_channels=channels, + eps=eps, + affine=True + ).to(device).eval() + device_norm = torch.compile(device_norm, dynamic=False) + + # Copy the weights from the CPU module to ensure they are identical + device_norm.weight.data.copy_(cpu_norm.weight.data) + device_norm.bias.data.copy_(cpu_norm.bias.data) + + # 5. Move inputs to the custom device + input_dev = input_cpu.to(device) + + # 6. Get the output from the custom device module + with torch.no_grad(): + dev_out = device_norm(input_dev) + + # 7. 
Compare the results + test_result("GroupNorm", dev_out, cpu_out, rtol=rtol, atol=atol) + print("Max diff >", torch.max(torch.abs(dev_out.cpu() - cpu_out)).item()) + print("GroupNorm simulation done.") + +def test_upsample2d( + device, + batch=1, + channels=320, + height=32, + width=32, + rtol=1e-4, + atol=1e-4, + use_conv=True, + use_conv_transpose=False, + out_channels=320, + name="conv", + kernel_size=None, + padding=1, + norm_type=None, + eps=None, + elementwise_affine=None, + bias=True, + interpolate=True, + stride=None, +): + cpu_block = Upsample2D( + channels=channels, + use_conv=use_conv, + use_conv_transpose=use_conv_transpose, + out_channels=out_channels, + name=name, + kernel_size=kernel_size, + padding=padding, + norm_type=norm_type, + eps=eps, + elementwise_affine=elementwise_affine, + bias=bias, + interpolate=interpolate, + ).to("cpu").eval() + + g = torch.Generator().manual_seed(0) + if stride is not None: + x_cpu = torch.empty_strided([batch, channels, height, width], stride).normal_(generator=g) + else: + x_cpu = torch.randn(batch, channels, height, width, generator=g) + + with torch.no_grad(): + y_cpu = cpu_block(x_cpu) + + dev_block = cpu_block.to(device).eval() + dev_block = torch.compile(dev_block, dynamic=False) + x_dev = x_cpu.to(device) + + with torch.no_grad(): + y_dev = dev_block(x_dev) + + test_result("Upsample2D", y_dev, y_cpu, rtol=rtol, atol=atol) + print("Max diff >", torch.max(torch.abs(y_dev.cpu() - y_cpu)).item()) + print("Upsample2D simulation done.") + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Run UNet (diffusers) test with comparison") + parser.add_argument("--model", type=str, default="runwayml/stable-diffusion-v1-5", + help="Diffusers model id (e.g., Qwen/Qwen-Image)") + parser.add_argument("--batch", type=int, default=1) + parser.add_argument("--dtype", type=str, default="float32", choices=["float32", "float16", "bfloat16"]) + parser.add_argument("--rtol", type=float, default=1e-4) + 
+    parser.add_argument("--atol", type=float, default=1e-4)
+    parser.add_argument("--prompt", type=str, default="a cat in a hat")
+    args = parser.parse_args()
+
+    sys.path.append(os.environ.get("TORCHSIM_DIR", "/workspace/PyTorchSim"))
+    from Scheduler.scheduler import ExecutionEngine
+    module = ExecutionEngine.setup_device()
+    device = module.custom_device()
+
+    #test_upsample2d(device)
+    #test_groupnorm(device)
+    #test_groupnorm(device, stride=[1, 1, 320*32, 320])
+    #test_resnetblock2d(device, in_channels=640, out_channels=320, temb_channels=320)
+    #test_resnetblock2d(device, in_channels=640, out_channels=320, temb_channels=1280)
+    #test_cross_attn_down_block2d(device)
+    #test_unet_mid_block2d_cross_attn(device)
+    #test_cross_attn_up_block2d(device)
+    test_unet2d_condition_model(device)
+    #test_unet_conditional(
+    #    device=device,
+    #    model_id=args.model,
+    #    batch=args.batch,
+    #    dtype=args.dtype,
+    #    rtol=args.rtol,
+    #    atol=args.atol,
+    #    prompt=args.prompt,
+    #)
\ No newline at end of file
diff --git a/tests/test_conv2d.py b/tests/test_conv2d.py
index c679b431..21bbfec7 100644
--- a/tests/test_conv2d.py
+++ b/tests/test_conv2d.py
@@ -55,3 +55,4 @@ def custom_conv2d(a, b, bias):
     test_conv2d(device, batch_size=1, in_channels=128, out_channels=256, input_size=7, kernel_size=3, stride=2, padding=1)
     test_conv2d(device, batch_size=1, in_channels=128, out_channels=256, input_size=2, kernel_size=1, stride=1, padding=0)
     test_conv2d(device, batch_size=1, in_channels=128, out_channels=256, input_size=14, kernel_size=1, stride=2, padding=0)
+    test_conv2d(device, batch_size=1, in_channels=3, out_channels=768, input_size=224, kernel_size=16,stride=16, padding=0)
diff --git a/tests/test_transcendental.py b/tests/test_transcendental.py
new file mode 100644
index 00000000..5f296581
--- /dev/null
+++ b/tests/test_transcendental.py
@@ -0,0 +1,83 @@
+import torch
+import torch._dynamo
+import torch.utils.cpp_extension
+
+def test_result(name, out, cpu_out, rtol=1e-4, atol=1e-4):
+    if torch.allclose(out.cpu(), cpu_out, rtol=rtol, atol=atol):
+        message = f"|{name} Test Passed|"
+        print("-" * len(message))
+        print(message)
+        print("-" * len(message))
+    else:
+        message = f"|{name} Test Failed|"
+        print("-" * len(message))
+        print(message)
+        print("-" * len(message))
+        print("custom out: ", out.cpu())
+        print("cpu out: ", cpu_out)
+        exit(1)
+
+def test_tanh(device, size=(128, 128)):
+    def tanh(a):
+        return torch.tanh(a)
+    x = torch.randn(size).to(device=device)
+    opt_fn = torch.compile(dynamic=False)(tanh)
+    res = opt_fn(x)
+    out = tanh(x.cpu())
+    test_result("Tanh", res, out)
+
+def test_exp(device, size=(128, 128)):
+    def exp(a):
+        return torch.exp(a)
+    x = torch.randn(size).to(device=device)
+    opt_fn = torch.compile(dynamic=False)(exp)
+    res = opt_fn(x)
+    out = exp(x.cpu())
+    test_result("Exp", res, out)
+
+def test_erf(device, size=(128, 128)):
+    def erf(a):
+        return torch.erf(a)
+    x = torch.randn(size).to(device=device)
+    opt_fn = torch.compile(dynamic=False)(erf)
+    res = opt_fn(x)
+    out = erf(x.cpu())
+    test_result("Erf", res, out)
+
+def test_sin(device, size=(128, 128)):
+    def sin(a):
+        return torch.sin(a)
+    x = torch.randn(size).to(device=device)
+    opt_fn = torch.compile(dynamic=False)(sin)
+    res = opt_fn(x)
+    out = sin(x.cpu())
+    test_result("Sin", res, out)
+
+def test_cos(device, size=(128, 128)):
+    def cos(a):
+        return torch.cos(a)
+    x = torch.randn(size).to(device=device)
+    opt_fn = torch.compile(dynamic=False)(cos)
+    res = opt_fn(x)
+    out = cos(x.cpu())
+    test_result("Cos", res, out)
+
+if __name__ == "__main__":
+    import os
+    import sys
+    import argparse
+    sys.path.append(os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim'))
+
+    parser = argparse.ArgumentParser(description="Run LayerNorm test with dynamic shape")
+    parser.add_argument('--shape', type=str, default="(512,768)")
+    args = parser.parse_args()
+    shape = tuple(map(int, args.shape.strip('()').split(',')))
+
+    from Scheduler.scheduler import ExecutionEngine
+    module = ExecutionEngine.setup_device()
+    device = module.custom_device()
+    test_tanh(device)
+    test_exp(device)
+    test_erf(device)
+    test_sin(device)
+    test_cos(device)
\ No newline at end of file
diff --git a/tests/test_vit.py b/tests/test_vit.py
new file mode 100644
index 00000000..6f587127
--- /dev/null
+++ b/tests/test_vit.py
@@ -0,0 +1,219 @@
+import torch
+import torch._dynamo
+import torch.utils.cpp_extension
+import argparse
+from torchvision import models
+from torchvision.models.vision_transformer import _vision_transformer, EncoderBlock
+
+def test_result(name, out, cpu_out, rtol=1e-4, atol=1e-4):
+    if torch.allclose(out.cpu(), cpu_out, rtol=rtol, atol=atol):
+        message = f"|{name} Test Passed|"
+        print("-" * len(message))
+        print(message)
+        print("-" * len(message))
+    else:
+        message = f"|{name} Test Failed|"
+        print("-" * len(message))
+        print(message)
+        print("-" * len(message))
+        print("custom out: ", out.cpu())
+        print("cpu out: ", cpu_out)
+        exit(1)
+
+def init_vit_weights(m):
+    if isinstance(m, torch.nn.Linear):
+        torch.nn.init.normal_(m.weight, mean=0.0, std=0.02)
+        if m.bias is not None:
+            torch.nn.init.zeros_(m.bias)
+
+    elif isinstance(m, torch.nn.Conv2d):
+        torch.nn.init.kaiming_normal_(m.weight, nonlinearity='linear')
+        if m.bias is not None:
+            torch.nn.init.zeros_(m.bias)
+
+    elif isinstance(m, torch.nn.LayerNorm):
+        if m.weight is not None:
+            torch.nn.init.normal_(m.weight)
+        if m.bias is not None:
+            torch.nn.init.normal_(m.bias)
+
+    elif isinstance(m, torch.nn.MultiheadAttention):
+        # QKV projection
+        if hasattr(m, 'in_proj_weight'):
+            torch.nn.init.normal_(m.in_proj_weight, mean=0.0, std=0.02)
+        if hasattr(m, 'in_proj_bias') and m.in_proj_bias is not None:
+            torch.nn.init.normal_(m.in_proj_bias)
+
+        # Output projection
+        if hasattr(m, 'out_proj'):
+            torch.nn.init.normal_(m.out_proj.weight, mean=0.0, std=0.02)
+            if m.out_proj.bias is not None:
+                torch.nn.init.normal_(m.out_proj.bias)
+
+def test_vit(device, batch=1, shape=(3, 224, 224),
+             num_layers=1, num_heads=12, hidden_dim=768, mlp_dim=3072):
+    with torch.no_grad():
+        #model = models.vit_b_16(models.ViT_B_16_Weights.IMAGENET1K_V1).eval()
+        model = _vision_transformer(
+            patch_size=16,
+            num_layers=num_layers,
+            num_heads=num_heads,
+            hidden_dim=hidden_dim,
+            mlp_dim=mlp_dim,
+            weights=None,
+            progress=False
+        ).eval()
+        model.apply(init_vit_weights)
+
+        input_tensor = torch.randn(batch, *shape)
+        x_device = input_tensor.to(device=device)
+        x_cpu = input_tensor.cpu()
+
+        model.to(device)
+        opt_model = torch.compile(dynamic=False)(model)
+        out_device = opt_model(x_device)
+
+        cpu_model = model.cpu()
+        out_cpu = cpu_model(x_cpu)
+
+        test_result("VisionTransformer inference", out_device, out_cpu)
+        print("Max diff > ", torch.max(torch.abs(out_device.cpu() - out_cpu)))
+        print("VisionTransformer Simulation Done")
+
+def test_multihead_attention(device, batch=1, seq_len=32, hidden_dim=768, num_heads=12):
+    print(f"Testing MultiheadAttention (batch={batch}, seq_len={seq_len}, dim={hidden_dim}, heads={num_heads})")
+
+    mha = torch.nn.MultiheadAttention(
+        embed_dim=hidden_dim,
+        num_heads=num_heads,
+        batch_first=True,
+        dropout=0.0,
+    ).eval()
+    mha.apply(init_vit_weights)
+
+    x = torch.randn(seq_len, hidden_dim)
+    query, key, value = x.clone(), x.clone(), x.clone()
+
+    mha_device = mha.to(device)
+    q1, k1, v1 = query.to(device), key.to(device), value.to(device)
+
+    compiled_mha = torch.compile(mha_device, dynamic=False)
+    with torch.no_grad():
+        out_device, _ = compiled_mha(q1, k1, v1)
+
+    mha_cpu = mha.cpu()
+    q2, k2, v2 = query.cpu(), key.cpu(), value.cpu()
+    with torch.no_grad():
+        out_cpu, _ = mha_cpu(q2, k2, v2)
+
+    test_result("MultiheadAttention", out_device, out_cpu)
+    print("Max diff > ", torch.max(torch.abs(out_device.cpu() - out_cpu)))
+    print("MultiheadAttention Simulation Done")
+
+def test_encoder_block(device, batch=1, seq_len=16, hidden_dim=768, num_heads=12, mlp_dim=3072, dropout=0.1, attention_dropout=0.1):
+    with torch.no_grad():
+        block = EncoderBlock(
+            num_heads=num_heads,
+            hidden_dim=hidden_dim,
+            mlp_dim=mlp_dim,
+            dropout=dropout,
+            attention_dropout=attention_dropout,
+        ).eval()
+
+        input_tensor = torch.randn(batch, seq_len, hidden_dim)
+
+        x_device = input_tensor.to(device=device)
+        x_cpu = input_tensor.cpu()
+
+        block.to(device)
+        opt_block = torch.compile(dynamic=False)(block)
+        out_device = opt_block(x_device)
+
+        cpu_block = block.cpu()
+        out_cpu = cpu_block(x_cpu)
+
+        test_result("EncoderBlock", out_device, out_cpu)
+        print("Max diff >", torch.max(torch.abs(out_device.cpu() - out_cpu)))
+        print("EncoderBlock Simulation Done")
+
+class EncoderWrapper(torch.nn.Module):
+    def __init__(self, encoder_block: torch.nn.Module, hidden_dim=768):
+        super().__init__()
+        self.encoder = encoder_block
+        self.class_token = torch.nn.Parameter(torch.ones(1, 1, hidden_dim)*99)
+        self.ln = torch.nn.LayerNorm(hidden_dim, eps=1e-6)
+
+    def forward(self, x):
+        n = x.shape[0]
+        batch_class_token = self.class_token.expand(n, 1, -1)
+        x = torch.cat([batch_class_token, x], dim=1)
+        #return self.ln(x)
+        return self.encoder(x)
+        #return torch.var_mean(x, dim=-1) #self.encoder(x)
+
+def test_encoder_block_with_class_token(
+    device,
+    batch=1,
+    seq_len=16,
+    hidden_dim=768,
+    num_heads=12,
+    mlp_dim=3072,
+    dropout=0.1,
+    attention_dropout=0.1,
+):
+    with torch.no_grad():
+        block = EncoderBlock(
+            num_heads=num_heads,
+            hidden_dim=hidden_dim,
+            mlp_dim=mlp_dim,
+            dropout=dropout,
+            attention_dropout=attention_dropout,
+        ).eval()
+        block.apply(init_vit_weights)
+
+        wrapper = EncoderWrapper(block).eval()
+
+        #input_tensor = torch.arange(seq_len).view(1, seq_len, 1).expand(batch, seq_len, hidden_dim).contiguous()
+        input_tensor = torch.randn(batch, seq_len, hidden_dim)
+
+        x_device = input_tensor.to(device)
+        wrapper.to(device)
+
+        opt_wrapper = torch.compile(wrapper, dynamic=False)
+        out_device = opt_wrapper(x_device)
+        wrapper_cpu = wrapper.cpu()
+        out_cpu = wrapper_cpu(input_tensor.cpu())
+
+        test_result("EncoderBlock with class token", out_device, out_cpu)
+        print("Max diff >", torch.max(torch.abs(out_device.cpu() - out_cpu)))
+        print("EncoderBlock with class token Simulation Done")
+
+if __name__ == "__main__":
+    import os
+    import sys
+    parser = argparse.ArgumentParser(description="Run Vision Transformer test with comparison")
+    parser.add_argument('--batch', type=int, default=1)
+    parser.add_argument('--shape', type=str, default="(3,224,224)", help="e.g. '(3,224,224)'")
+    parser.add_argument('--num_layers', type=int, default=1)
+    parser.add_argument('--num_heads', type=int, default=12)
+    parser.add_argument('--hidden_dim', type=int, default=768)
+    parser.add_argument('--mlp_dim', type=int, default=3072)
+    args = parser.parse_args()
+
+    shape = tuple(map(int, args.shape.strip('()').split(',')))
+
+    sys.path.append(os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim'))
+    from Scheduler.scheduler import ExecutionEngine
+    module = ExecutionEngine.setup_device()
+    device = module.custom_device()
+    #test_multihead_attention(device)
+    #test_encoder_block(device, seq_len=197)
+    #test_encoder_block_with_class_token(device, seq_len=196)
+    test_vit(
+        device,
+        batch=args.batch,
+        shape=shape,
+        num_layers=args.num_layers,
+        num_heads=args.num_heads,
+        hidden_dim=args.hidden_dim,
+        mlp_dim=args.mlp_dim
+    )
\ No newline at end of file