diff --git a/CLAUDE.md b/CLAUDE.md index 5d8bbac..66e6722 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -73,31 +73,38 @@ native/ops/matmul/ ├── common/ # Shared utilities │ └── aligned_copy_sm120.cuh ├── gemm/ # GEMM kernels (M > 1) -│ └── {input_dtype}/{output_dtype}/{arch}/{compute}_{suffix}.{cu,cuh} +│ └── {w_dtype}_{a_dtype}_{out_dtype}/{arch}/{kernel}.{cu,cuh} ├── gemv/ # GEMV kernels (M = 1) -│ └── {input_dtype}/{output_dtype}/{arch}/{compute}_{suffix}.{cu,cuh} +│ └── {w_dtype}_{a_dtype}_{out_dtype}/{arch}/{kernel}.{cu,cuh} ├── cublaslt.cuh # cuBLASLt wrapper ├── matmul.cu # Main dispatcher └── matmul_cutlass.cu # CUTLASS dispatcher ``` -**Path Convention:** `{gemm|gemv}/{input_dtype}/{output_dtype}/{arch}/{compute}_{suffix}.cu` +**Path Convention:** `{gemm|gemv}/{w{weight}a{act}_{out}}/{arch}/{kernel}.cu` -| Component | Values | Examples | -|-----------|--------|----------| -| `input_dtype` | `f32`, `bf16`, `fp8`, `nvf4` | Input tensor dtype | -| `output_dtype` | `f32`, `bf16`, `fp8` | Output tensor dtype | +| Component | Values | Description | +|-----------|--------|-------------| +| `w_dtype` | `w4`, `w8`, `bf16`, `f32`, `int4`, `int8` | Weight dtype (w=weight) | +| `a_dtype` | `a4`, `a8`, `a16`, `bf16`, `f32`, `int4`, `int8` | Activation dtype (a=act) | +| `out_dtype` | `bf16`, `f32` | Output dtype | | `arch` | `generic`, `sm80`, `sm90`, `sm100`, `sm120` | Target architecture | -| `compute` | `naive`, `wmma`, `mma`, `cutlass` | Compute method | -| `suffix` | `blockwise`, `kernels`, etc. | Variant identifier | + +**Naming Rationale (Issue #122 Option 2):** +- `w8a16_bf16`: FP8 weights, BF16 activations, BF16 output (W8A16 GEMM) +- `w4a16_bf16`: NVF4 weights, BF16 activations, BF16 output (NVF4 GEMV) +- `w8a8_bf16`: FP8 weights, FP8 activations, BF16 output (pure FP8) +- `bf16_bf16`: BF16 weights, BF16 activations (no quantization) +- `f32_f32`: FP32 weights, FP32 activations (baseline) **Examples:** ``` -gemm/bf16/bf16/sm120/bf16_cutlass.cuh # BF16->BF16 GEMM, SM120, CUTLASS -gemm/fp8/f32/sm90/fp8_cutlass.cu # FP8->F32 GEMM, SM90, CUTLASS -gemm/nvf4/bf16/sm120/nvf4_cutlass.cu # NVF4->BF16 GEMM, SM120, CUTLASS -gemv/bf16/bf16/sm120/nvf4.cu # NVF4->BF16 GEMV, SM120 -gemm/f32/f32/generic/tf32_mma.cuh # TF32 GEMM, generic (SM80+) +gemm/bf16_bf16/sm80/bf16_cutlass.cuh # BF16 GEMM, SM80, CUTLASS +gemm/w8a8_f32/sm90/fp8_cutlass.cu # FP8->F32 GEMM, SM90, CUTLASS +gemm/w4a16_bf16/sm120/nvf4_cutlass.cu # NVF4 weights, BF16 act->BF16, SM120 +gemv/w4a16_bf16/sm120/nvf4.cu # NVF4 GEMV, SM120 +gemv/w8a16_bf16/sm120/fp8_opt_kernels.cu # FP8 weight, BF16 act GEMV, SM120 +gemm/f32_f32/generic/tf32_mma.cuh # TF32 GEMM, generic (SM80+) ``` ### Module Separation Policy diff --git a/native/CMakeLists.txt b/native/CMakeLists.txt index 6540f9f..ed0f604 100644 --- a/native/CMakeLists.txt +++ b/native/CMakeLists.txt @@ -153,30 +153,30 @@ pybind11_add_module(${MODULE_NAME} ops/reduction/reduction.cu ops/matmul/matmul.cu ops/matmul/matmul_cutlass.cu - # GEMM kernels - ops/matmul/gemm/f32/f32/generic/f32_ampere.cu - ops/matmul/gemm/fp8/f32/sm90/fp8_cutlass.cu - ops/matmul/gemm/fp8/f32/sm100/fp8_blockwise.cu - ops/matmul/gemm/fp8/bf16/sm120/fp8_blockwise.cu - ops/matmul/gemm/fp8/bf16/sm120/w8a16_gemm.cu - ops/matmul/gemm/fp8/bf16/sm120/w8a16_cutlass.cu - ops/matmul/gemm/fp8/bf16/sm120/grouped_gemm.cu - ops/matmul/gemm/fp8/fp8/sm120/fp8_cutlass.cu - ops/matmul/gemm/fp8/fp8/sm120/fp8_cutlass_v2.cu - ops/matmul/gemm/fp8/fp8/sm120/fp8_cutlass_v3.cu - ops/matmul/gemm/int8/int8/sm120/int8_native.cu - ops/matmul/gemm/int4/int4/sm120/int4_via_int8.cu - ops/matmul/gemm/nvf4/bf16/sm120/nvf4_cutlass.cu - ops/matmul/gemm/nvf4/bf16/sm120/nvf4_nvf4_cutlass.cu - # GEMV kernels - ops/matmul/gemv/bf16/bf16/sm120/nvf4.cu - ops/matmul/gemv/bf16/bf16/sm120/nvf4_kernels.cu - ops/matmul/gemv/bf16/bf16/sm120/fp8_opt_kernels.cu - ops/matmul/gemv/bf16/bf16/sm120/bf16_opt.cu - ops/matmul/gemv/fp8/fp8/sm120/fp8_gemv.cu - ops/matmul/gemv/fp8/fp8/sm120/fp8_accurate.cu - ops/matmul/gemv/nvf4/nvf4/sm120/nvf4_gemv.cu - ops/matmul/gemv/int4/int4/sm120/int4_gemv.cu + # GEMM kernels (Issue #122: Reorganized with w{weight}a{act}_{out} naming) + ops/matmul/gemm/f32_f32/generic/f32_ampere.cu + ops/matmul/gemm/w8a8_f32/sm90/fp8_cutlass.cu + ops/matmul/gemm/w8a8_f32/sm100/fp8_blockwise.cu + ops/matmul/gemm/w8a16_bf16/sm120/fp8_blockwise.cu + ops/matmul/gemm/w8a16_bf16/sm120/w8a16_gemm.cu + ops/matmul/gemm/w8a16_bf16/sm120/w8a16_cutlass.cu + ops/matmul/gemm/w8a16_bf16/sm120/grouped_gemm.cu + ops/matmul/gemm/w8a8_bf16/sm120/fp8_cutlass.cu + ops/matmul/gemm/w8a8_bf16/sm120/fp8_cutlass_v2.cu + ops/matmul/gemm/w8a8_bf16/sm120/fp8_cutlass_v3.cu + ops/matmul/gemm/int8_int8/sm120/int8_native.cu + ops/matmul/gemm/int4_int4/sm120/int4_via_int8.cu + ops/matmul/gemm/w4a16_bf16/sm120/nvf4_cutlass.cu + ops/matmul/gemm/w4a16_bf16/sm120/nvf4_nvf4_cutlass.cu + # GEMV kernels (Issue #122: Reorganized with w{weight}a{act}_{out} naming) + ops/matmul/gemv/w4a16_bf16/sm120/nvf4.cu + ops/matmul/gemv/w4a16_bf16/sm120/nvf4_kernels.cu + ops/matmul/gemv/w8a16_bf16/sm120/fp8_opt_kernels.cu + ops/matmul/gemv/bf16_bf16/sm120/bf16_opt.cu + ops/matmul/gemv/w8a8_bf16/sm120/fp8_gemv.cu + ops/matmul/gemv/w8a8_bf16/sm120/fp8_accurate.cu + ops/matmul/gemv/w4a4_bf16/sm120/nvf4_gemv.cu + ops/matmul/gemv/int4_int4/sm120/int4_gemv.cu ops/nn/nn.cu ops/quantize/quantize.cu ops/attention/paged_attention.cu diff --git a/native/ops/matmul/gemm/bf16/bf16/generic/bf16_naive.cuh b/native/ops/matmul/gemm/bf16_bf16/generic/bf16_naive.cuh similarity index 98% rename from native/ops/matmul/gemm/bf16/bf16/generic/bf16_naive.cuh rename to native/ops/matmul/gemm/bf16_bf16/generic/bf16_naive.cuh index 7d59bfb..98d0be0 100644 --- a/native/ops/matmul/gemm/bf16/bf16/generic/bf16_naive.cuh +++ b/native/ops/matmul/gemm/bf16_bf16/generic/bf16_naive.cuh @@ -14,7 +14,7 @@ #include #include #include -#include "../../../../../../core/cuda_graph.hpp" +#include "../../../../../core/cuda_graph.hpp" namespace pygpukit { namespace ops { diff --git a/native/ops/matmul/gemm/bf16/bf16/generic/bf16_wmma.cuh b/native/ops/matmul/gemm/bf16_bf16/generic/bf16_wmma.cuh similarity index 99% rename from native/ops/matmul/gemm/bf16/bf16/generic/bf16_wmma.cuh rename to native/ops/matmul/gemm/bf16_bf16/generic/bf16_wmma.cuh index 4324778..ab8ecbe 100644 --- a/native/ops/matmul/gemm/bf16/bf16/generic/bf16_wmma.cuh +++ b/native/ops/matmul/gemm/bf16_bf16/generic/bf16_wmma.cuh @@ -14,7 +14,7 @@ #include #include #include -#include "../../../../../../core/cuda_graph.hpp" +#include "../../../../../core/cuda_graph.hpp" namespace pygpukit { namespace ops { diff --git a/native/ops/matmul/gemm/bf16/bf16/generic/bf16_wmma_generic.cuh b/native/ops/matmul/gemm/bf16_bf16/generic/bf16_wmma_generic.cuh similarity index 99% rename from native/ops/matmul/gemm/bf16/bf16/generic/bf16_wmma_generic.cuh rename to native/ops/matmul/gemm/bf16_bf16/generic/bf16_wmma_generic.cuh index 98b68fa..bcf89cc 100644 --- a/native/ops/matmul/gemm/bf16/bf16/generic/bf16_wmma_generic.cuh +++ b/native/ops/matmul/gemm/bf16_bf16/generic/bf16_wmma_generic.cuh @@ -12,7 +12,7 @@ #include #include #include -#include "../../../../../../core/cuda_graph.hpp" +#include "../../../../../core/cuda_graph.hpp" namespace pygpukit { namespace ops { diff --git a/native/ops/matmul/gemm/bf16/bf16/sm100/bf16_cutlass.cuh b/native/ops/matmul/gemm/bf16_bf16/sm100/bf16_cutlass.cuh similarity index 100% rename from native/ops/matmul/gemm/bf16/bf16/sm100/bf16_cutlass.cuh rename to native/ops/matmul/gemm/bf16_bf16/sm100/bf16_cutlass.cuh diff --git a/native/ops/matmul/gemm/bf16/bf16/sm120/bf16_cutlass.cuh b/native/ops/matmul/gemm/bf16_bf16/sm120/bf16_cutlass.cuh similarity index 100% rename from native/ops/matmul/gemm/bf16/bf16/sm120/bf16_cutlass.cuh rename to native/ops/matmul/gemm/bf16_bf16/sm120/bf16_cutlass.cuh diff --git a/native/ops/matmul/gemm/bf16/bf16/sm80/bf16_cutlass.cuh b/native/ops/matmul/gemm/bf16_bf16/sm80/bf16_cutlass.cuh similarity index 100% rename from native/ops/matmul/gemm/bf16/bf16/sm80/bf16_cutlass.cuh rename to native/ops/matmul/gemm/bf16_bf16/sm80/bf16_cutlass.cuh diff --git a/native/ops/matmul/gemm/bf16/bf16/sm90/bf16_cutlass.cuh b/native/ops/matmul/gemm/bf16_bf16/sm90/bf16_cutlass.cuh similarity index 100% rename from native/ops/matmul/gemm/bf16/bf16/sm90/bf16_cutlass.cuh rename to native/ops/matmul/gemm/bf16_bf16/sm90/bf16_cutlass.cuh diff --git a/native/ops/matmul/gemm/f32/f32/generic/f32_ampere.cu b/native/ops/matmul/gemm/f32_f32/generic/f32_ampere.cu similarity index 100% rename from native/ops/matmul/gemm/f32/f32/generic/f32_ampere.cu rename to native/ops/matmul/gemm/f32_f32/generic/f32_ampere.cu diff --git a/native/ops/matmul/gemm/f32/f32/generic/f32_ampere.cuh b/native/ops/matmul/gemm/f32_f32/generic/f32_ampere.cuh similarity index 99% rename from native/ops/matmul/gemm/f32/f32/generic/f32_ampere.cuh rename to native/ops/matmul/gemm/f32_f32/generic/f32_ampere.cuh index 2a6586c..45c9f3a 100644 --- a/native/ops/matmul/gemm/f32/f32/generic/f32_ampere.cuh +++ b/native/ops/matmul/gemm/f32_f32/generic/f32_ampere.cuh @@ -18,7 +18,7 @@ #include #include -#include "../../../../../../core/cuda_graph.hpp" +#include "../../../../../core/cuda_graph.hpp" namespace pygpukit { namespace ops { diff --git a/native/ops/matmul/gemm/f32/f32/generic/f32_naive.cuh b/native/ops/matmul/gemm/f32_f32/generic/f32_naive.cuh similarity index 99% rename from native/ops/matmul/gemm/f32/f32/generic/f32_naive.cuh rename to native/ops/matmul/gemm/f32_f32/generic/f32_naive.cuh index d8afc27..5065ba5 100644 --- a/native/ops/matmul/gemm/f32/f32/generic/f32_naive.cuh +++ b/native/ops/matmul/gemm/f32_f32/generic/f32_naive.cuh @@ -10,7 +10,7 @@ #include #include -#include "../../../../../../core/cuda_graph.hpp" +#include "../../../../../core/cuda_graph.hpp" namespace pygpukit { namespace ops { diff --git a/native/ops/matmul/gemm/f32/f32/generic/tf32_mma.cuh b/native/ops/matmul/gemm/f32_f32/generic/tf32_mma.cuh similarity index 99% rename from native/ops/matmul/gemm/f32/f32/generic/tf32_mma.cuh rename to native/ops/matmul/gemm/f32_f32/generic/tf32_mma.cuh index ace60ac..8df7ded 100644 --- a/native/ops/matmul/gemm/f32/f32/generic/tf32_mma.cuh +++ b/native/ops/matmul/gemm/f32_f32/generic/tf32_mma.cuh @@ -11,7 +11,7 @@ #pragma once #include #include -#include "../../../../../../core/cuda_graph.hpp" +#include "../../../../../core/cuda_graph.hpp" namespace pygpukit { namespace ops { diff --git a/native/ops/matmul/gemm/f32/f32/generic/tf32_wmma.cuh b/native/ops/matmul/gemm/f32_f32/generic/tf32_wmma.cuh similarity index 99% rename from native/ops/matmul/gemm/f32/f32/generic/tf32_wmma.cuh rename to native/ops/matmul/gemm/f32_f32/generic/tf32_wmma.cuh index 15050b3..c505880 100644 --- a/native/ops/matmul/gemm/f32/f32/generic/tf32_wmma.cuh +++ b/native/ops/matmul/gemm/f32_f32/generic/tf32_wmma.cuh @@ -1,7 +1,7 @@ #pragma once #include #include -#include "../../../../../../core/cuda_graph.hpp" +#include "../../../../../core/cuda_graph.hpp" namespace pygpukit { namespace ops { diff --git a/native/ops/matmul/gemm/int4/int4/sm120/int4_via_int8.cu b/native/ops/matmul/gemm/int4_int4/sm120/int4_via_int8.cu similarity index 99% rename from native/ops/matmul/gemm/int4/int4/sm120/int4_via_int8.cu rename to native/ops/matmul/gemm/int4_int4/sm120/int4_via_int8.cu index 08a5cf1..3ce517c 100644 --- a/native/ops/matmul/gemm/int4/int4/sm120/int4_via_int8.cu +++ b/native/ops/matmul/gemm/int4_int4/sm120/int4_via_int8.cu @@ -37,7 +37,7 @@ #include "cutlass/util/device_memory.h" #define PYGPUKIT_PATCH_CUTLASS_LDSM_POST 1 -#include "../../../../common/aligned_copy_sm120.cuh" +#include "../../../common/aligned_copy_sm120.cuh" using namespace cute; diff --git a/native/ops/matmul/gemm/int8/int8/sm120/int8_native.cu b/native/ops/matmul/gemm/int8_int8/sm120/int8_native.cu similarity index 100% rename from native/ops/matmul/gemm/int8/int8/sm120/int8_native.cu rename to native/ops/matmul/gemm/int8_int8/sm120/int8_native.cu diff --git a/native/ops/matmul/gemm/nvf4/bf16/sm120/nvf4_cutlass.cu b/native/ops/matmul/gemm/w4a16_bf16/sm120/nvf4_cutlass.cu similarity index 100% rename from native/ops/matmul/gemm/nvf4/bf16/sm120/nvf4_cutlass.cu rename to native/ops/matmul/gemm/w4a16_bf16/sm120/nvf4_cutlass.cu diff --git a/native/ops/matmul/gemm/nvf4/bf16/sm120/nvf4_nvf4_cutlass.cu b/native/ops/matmul/gemm/w4a16_bf16/sm120/nvf4_nvf4_cutlass.cu similarity index 100% rename from native/ops/matmul/gemm/nvf4/bf16/sm120/nvf4_nvf4_cutlass.cu rename to native/ops/matmul/gemm/w4a16_bf16/sm120/nvf4_nvf4_cutlass.cu diff --git a/native/ops/matmul/gemm/fp8/bf16/sm120/fp8_blockwise.cu b/native/ops/matmul/gemm/w8a16_bf16/sm120/fp8_blockwise.cu similarity index 99% rename from native/ops/matmul/gemm/fp8/bf16/sm120/fp8_blockwise.cu rename to native/ops/matmul/gemm/w8a16_bf16/sm120/fp8_blockwise.cu index c612aba..84f1cdd 100644 --- a/native/ops/matmul/gemm/fp8/bf16/sm120/fp8_blockwise.cu +++ b/native/ops/matmul/gemm/w8a16_bf16/sm120/fp8_blockwise.cu @@ -48,7 +48,7 @@ // Provides alignment-safe LDSM operations for Issue #2902 workaround // ============================================================================ #define PYGPUKIT_PATCH_CUTLASS_LDSM_POST 1 -#include "../../../../common/aligned_copy_sm120.cuh" +#include "../../../common/aligned_copy_sm120.cuh" using namespace cute; diff --git a/native/ops/matmul/gemm/fp8/bf16/sm120/grouped_gemm.cu b/native/ops/matmul/gemm/w8a16_bf16/sm120/grouped_gemm.cu similarity index 100% rename from native/ops/matmul/gemm/fp8/bf16/sm120/grouped_gemm.cu rename to native/ops/matmul/gemm/w8a16_bf16/sm120/grouped_gemm.cu diff --git a/native/ops/matmul/gemm/fp8/bf16/sm120/w8a16_cutlass.cu b/native/ops/matmul/gemm/w8a16_bf16/sm120/w8a16_cutlass.cu similarity index 99% rename from native/ops/matmul/gemm/fp8/bf16/sm120/w8a16_cutlass.cu rename to native/ops/matmul/gemm/w8a16_bf16/sm120/w8a16_cutlass.cu index 4c9fb5a..5015663 100644 --- a/native/ops/matmul/gemm/fp8/bf16/sm120/w8a16_cutlass.cu +++ b/native/ops/matmul/gemm/w8a16_bf16/sm120/w8a16_cutlass.cu @@ -36,7 +36,7 @@ #include "cutlass/util/device_memory.h" #define PYGPUKIT_PATCH_CUTLASS_LDSM_POST 1 -#include "../../../../common/aligned_copy_sm120.cuh" +#include "../../../common/aligned_copy_sm120.cuh" using namespace cute; diff --git a/native/ops/matmul/gemm/fp8/bf16/sm120/w8a16_gemm.cu b/native/ops/matmul/gemm/w8a16_bf16/sm120/w8a16_gemm.cu similarity index 100% rename from native/ops/matmul/gemm/fp8/bf16/sm120/w8a16_gemm.cu rename to native/ops/matmul/gemm/w8a16_bf16/sm120/w8a16_gemm.cu diff --git a/native/ops/matmul/gemm/fp8/fp8/sm120/fp8_cutlass.cu b/native/ops/matmul/gemm/w8a8_bf16/sm120/fp8_cutlass.cu similarity index 99% rename from native/ops/matmul/gemm/fp8/fp8/sm120/fp8_cutlass.cu rename to native/ops/matmul/gemm/w8a8_bf16/sm120/fp8_cutlass.cu index 360a28e..99e07f6 100644 --- a/native/ops/matmul/gemm/fp8/fp8/sm120/fp8_cutlass.cu +++ b/native/ops/matmul/gemm/w8a8_bf16/sm120/fp8_cutlass.cu @@ -38,7 +38,7 @@ // Alignment patch for Issue #2902 workaround #define PYGPUKIT_PATCH_CUTLASS_LDSM_POST 1 -#include "../../../../common/aligned_copy_sm120.cuh" +#include "../../../common/aligned_copy_sm120.cuh" using namespace cute; diff --git a/native/ops/matmul/gemm/fp8/fp8/sm120/fp8_cutlass_v2.cu b/native/ops/matmul/gemm/w8a8_bf16/sm120/fp8_cutlass_v2.cu similarity index 99% rename from native/ops/matmul/gemm/fp8/fp8/sm120/fp8_cutlass_v2.cu rename to native/ops/matmul/gemm/w8a8_bf16/sm120/fp8_cutlass_v2.cu index 261e2e1..06165e0 100644 --- a/native/ops/matmul/gemm/fp8/fp8/sm120/fp8_cutlass_v2.cu +++ b/native/ops/matmul/gemm/w8a8_bf16/sm120/fp8_cutlass_v2.cu @@ -24,7 +24,7 @@ #include "cutlass/util/device_memory.h" #define PYGPUKIT_PATCH_CUTLASS_LDSM_POST 1 -#include "../../../../common/aligned_copy_sm120.cuh" +#include "../../../common/aligned_copy_sm120.cuh" using namespace cute; diff --git a/native/ops/matmul/gemm/fp8/fp8/sm120/fp8_cutlass_v3.cu b/native/ops/matmul/gemm/w8a8_bf16/sm120/fp8_cutlass_v3.cu similarity index 99% rename from native/ops/matmul/gemm/fp8/fp8/sm120/fp8_cutlass_v3.cu rename to native/ops/matmul/gemm/w8a8_bf16/sm120/fp8_cutlass_v3.cu index bd519f1..2775c0f 100644 --- a/native/ops/matmul/gemm/fp8/fp8/sm120/fp8_cutlass_v3.cu +++ b/native/ops/matmul/gemm/w8a8_bf16/sm120/fp8_cutlass_v3.cu @@ -31,7 +31,7 @@ #include "cutlass/util/device_memory.h" #define PYGPUKIT_PATCH_CUTLASS_LDSM_POST 1 -#include "../../../../common/aligned_copy_sm120.cuh" +#include "../../../common/aligned_copy_sm120.cuh" using namespace cute; diff --git a/native/ops/matmul/gemm/fp8/f32/sm100/fp8_blockwise.cu b/native/ops/matmul/gemm/w8a8_f32/sm100/fp8_blockwise.cu similarity index 100% rename from native/ops/matmul/gemm/fp8/f32/sm100/fp8_blockwise.cu rename to native/ops/matmul/gemm/w8a8_f32/sm100/fp8_blockwise.cu diff --git a/native/ops/matmul/gemm/fp8/f32/sm90/fp8_cutlass.cu b/native/ops/matmul/gemm/w8a8_f32/sm90/fp8_cutlass.cu similarity index 100% rename from native/ops/matmul/gemm/fp8/f32/sm90/fp8_cutlass.cu rename to native/ops/matmul/gemm/w8a8_f32/sm90/fp8_cutlass.cu diff --git a/native/ops/matmul/gemv/bf16/bf16/generic/bf16_cutlass.cuh b/native/ops/matmul/gemv/bf16_bf16/generic/bf16_cutlass.cuh similarity index 100% rename from native/ops/matmul/gemv/bf16/bf16/generic/bf16_cutlass.cuh rename to native/ops/matmul/gemv/bf16_bf16/generic/bf16_cutlass.cuh diff --git a/native/ops/matmul/gemv/bf16/bf16/sm120/bf16_opt.cu b/native/ops/matmul/gemv/bf16_bf16/sm120/bf16_opt.cu similarity index 100% rename from native/ops/matmul/gemv/bf16/bf16/sm120/bf16_opt.cu rename to native/ops/matmul/gemv/bf16_bf16/sm120/bf16_opt.cu diff --git a/native/ops/matmul/gemv/bf16/bf16/sm120/bf16_opt.cuh b/native/ops/matmul/gemv/bf16_bf16/sm120/bf16_opt.cuh similarity index 100% rename from native/ops/matmul/gemv/bf16/bf16/sm120/bf16_opt.cuh rename to native/ops/matmul/gemv/bf16_bf16/sm120/bf16_opt.cuh diff --git a/native/ops/matmul/gemv/int4/int4/sm120/int4_gemv.cu b/native/ops/matmul/gemv/int4_int4/sm120/int4_gemv.cu similarity index 100% rename from native/ops/matmul/gemv/int4/int4/sm120/int4_gemv.cu rename to native/ops/matmul/gemv/int4_int4/sm120/int4_gemv.cu diff --git a/native/ops/matmul/gemv/int4/int4/sm120/int4_gemv.cuh b/native/ops/matmul/gemv/int4_int4/sm120/int4_gemv.cuh similarity index 100% rename from native/ops/matmul/gemv/int4/int4/sm120/int4_gemv.cuh rename to native/ops/matmul/gemv/int4_int4/sm120/int4_gemv.cuh diff --git a/native/ops/matmul/gemv/bf16/bf16/sm120/nvf4.cu b/native/ops/matmul/gemv/w4a16_bf16/sm120/nvf4.cu similarity index 100% rename from native/ops/matmul/gemv/bf16/bf16/sm120/nvf4.cu rename to native/ops/matmul/gemv/w4a16_bf16/sm120/nvf4.cu diff --git a/native/ops/matmul/gemv/bf16/bf16/sm120/nvf4.cuh b/native/ops/matmul/gemv/w4a16_bf16/sm120/nvf4.cuh similarity index 100% rename from native/ops/matmul/gemv/bf16/bf16/sm120/nvf4.cuh rename to native/ops/matmul/gemv/w4a16_bf16/sm120/nvf4.cuh diff --git a/native/ops/matmul/gemv/bf16/bf16/sm120/nvf4_kernels.cu b/native/ops/matmul/gemv/w4a16_bf16/sm120/nvf4_kernels.cu similarity index 100% rename from native/ops/matmul/gemv/bf16/bf16/sm120/nvf4_kernels.cu rename to native/ops/matmul/gemv/w4a16_bf16/sm120/nvf4_kernels.cu diff --git a/native/ops/matmul/gemv/nvf4/nvf4/sm120/nvf4_gemv.cu b/native/ops/matmul/gemv/w4a4_bf16/sm120/nvf4_gemv.cu similarity index 100% rename from native/ops/matmul/gemv/nvf4/nvf4/sm120/nvf4_gemv.cu rename to native/ops/matmul/gemv/w4a4_bf16/sm120/nvf4_gemv.cu diff --git a/native/ops/matmul/gemv/nvf4/nvf4/sm120/nvf4_gemv.cuh b/native/ops/matmul/gemv/w4a4_bf16/sm120/nvf4_gemv.cuh similarity index 100% rename from native/ops/matmul/gemv/nvf4/nvf4/sm120/nvf4_gemv.cuh rename to native/ops/matmul/gemv/w4a4_bf16/sm120/nvf4_gemv.cuh diff --git a/native/ops/matmul/gemv/bf16/bf16/sm120/fp8.cuh b/native/ops/matmul/gemv/w8a16_bf16/sm120/fp8.cuh similarity index 100% rename from native/ops/matmul/gemv/bf16/bf16/sm120/fp8.cuh rename to native/ops/matmul/gemv/w8a16_bf16/sm120/fp8.cuh diff --git a/native/ops/matmul/gemv/bf16/bf16/sm120/fp8_opt.cuh b/native/ops/matmul/gemv/w8a16_bf16/sm120/fp8_opt.cuh similarity index 100% rename from native/ops/matmul/gemv/bf16/bf16/sm120/fp8_opt.cuh rename to native/ops/matmul/gemv/w8a16_bf16/sm120/fp8_opt.cuh diff --git a/native/ops/matmul/gemv/bf16/bf16/sm120/fp8_opt_kernels.cu b/native/ops/matmul/gemv/w8a16_bf16/sm120/fp8_opt_kernels.cu similarity index 100% rename from native/ops/matmul/gemv/bf16/bf16/sm120/fp8_opt_kernels.cu rename to native/ops/matmul/gemv/w8a16_bf16/sm120/fp8_opt_kernels.cu diff --git a/native/ops/matmul/gemv/fp8/fp8/sm120/fp8_accurate.cu b/native/ops/matmul/gemv/w8a8_bf16/sm120/fp8_accurate.cu similarity index 100% rename from native/ops/matmul/gemv/fp8/fp8/sm120/fp8_accurate.cu rename to native/ops/matmul/gemv/w8a8_bf16/sm120/fp8_accurate.cu diff --git a/native/ops/matmul/gemv/fp8/fp8/sm120/fp8_accurate.cuh b/native/ops/matmul/gemv/w8a8_bf16/sm120/fp8_accurate.cuh similarity index 100% rename from native/ops/matmul/gemv/fp8/fp8/sm120/fp8_accurate.cuh rename to native/ops/matmul/gemv/w8a8_bf16/sm120/fp8_accurate.cuh diff --git a/native/ops/matmul/gemv/fp8/fp8/sm120/fp8_gemv.cu b/native/ops/matmul/gemv/w8a8_bf16/sm120/fp8_gemv.cu similarity index 100% rename from native/ops/matmul/gemv/fp8/fp8/sm120/fp8_gemv.cu rename to native/ops/matmul/gemv/w8a8_bf16/sm120/fp8_gemv.cu diff --git a/native/ops/matmul/gemv/fp8/fp8/sm120/fp8_gemv.cuh b/native/ops/matmul/gemv/w8a8_bf16/sm120/fp8_gemv.cuh similarity index 100% rename from native/ops/matmul/gemv/fp8/fp8/sm120/fp8_gemv.cuh rename to native/ops/matmul/gemv/w8a8_bf16/sm120/fp8_gemv.cuh diff --git a/native/ops/matmul/matmul.cu b/native/ops/matmul/matmul.cu index 17631ae..077111f 100644 --- a/native/ops/matmul/matmul.cu +++ b/native/ops/matmul/matmul.cu @@ -1,22 +1,22 @@ /** * Matrix multiplication dispatch */ -#include "gemm/f32/f32/generic/f32_naive.cuh" +#include "gemm/f32_f32/generic/f32_naive.cuh" #include "../common/error.cuh" #include "../common/device.cuh" #include "../../core/memory.hpp" #include "../../core/cuda_graph.hpp" #include "../ops.cuh" // For transpose() -// Include existing optimized kernels -#include "gemm/f32/f32/generic/f32_ampere.cuh" -#include "gemm/f32/f32/generic/tf32_wmma.cuh" -#include "gemm/f32/f32/generic/tf32_mma.cuh" -#include "gemm/bf16/bf16/generic/bf16_naive.cuh" -#include "gemm/bf16/bf16/generic/bf16_wmma.cuh" -#include "gemm/bf16/bf16/generic/bf16_wmma_generic.cuh" +// Include existing optimized kernels (Issue #122: Updated paths) +#include "gemm/f32_f32/generic/f32_ampere.cuh" +#include "gemm/f32_f32/generic/tf32_wmma.cuh" +#include "gemm/f32_f32/generic/tf32_mma.cuh" +#include "gemm/bf16_bf16/generic/bf16_naive.cuh" +#include "gemm/bf16_bf16/generic/bf16_wmma.cuh" +#include "gemm/bf16_bf16/generic/bf16_wmma_generic.cuh" #include "cublaslt.cuh" -#include "gemm/bf16/bf16/sm80/bf16_cutlass.cuh" +#include "gemm/bf16_bf16/sm80/bf16_cutlass.cuh" #include #include diff --git a/native/ops/matmul/matmul_cutlass.cu b/native/ops/matmul/matmul_cutlass.cu index 56d7660..e50a52c 100644 --- a/native/ops/matmul/matmul_cutlass.cu +++ b/native/ops/matmul/matmul_cutlass.cu @@ -11,7 +11,7 @@ #if PYGPUKIT_HAS_CUTLASS -#include "gemm/bf16/bf16/sm80/bf16_cutlass.cuh" +#include "gemm/bf16_bf16/sm80/bf16_cutlass.cuh" namespace pygpukit { namespace ops {