diff --git a/native/bindings/gemm/fp8xbf16_bf16.cpp b/native/bindings/gemm/fp8xbf16_bf16.cpp
index a8c5190..1a2eda1 100644
--- a/native/bindings/gemm/fp8xbf16_bf16.cpp
+++ b/native/bindings/gemm/fp8xbf16_bf16.cpp
@@ -32,13 +32,59 @@ extern "C" {
 }

 void init_gemm_fp8xbf16_bf16(py::module_& m) {
+    // ============================================================
+    // W8A16 GEMM: FP8 weight x BF16 activation -> BF16 output (SM120)
+    // New name: gemm_w8a16_init_lut, alias: w8a16_gemm_init_lut
+    // ============================================================
+    m.def("gemm_w8a16_init_lut", []() {
+        cudaError_t err = pygpukit_w8a16_gemm_init_lut();
+        if (err != cudaSuccess) {
+            throw std::runtime_error("gemm_w8a16_init_lut failed: " + std::string(cudaGetErrorString(err)));
+        }
+    }, "Initialize FP8->F32 LUT for W8A16 GEMM");
+
     m.def("w8a16_gemm_init_lut", []() {
         cudaError_t err = pygpukit_w8a16_gemm_init_lut();
         if (err != cudaSuccess) {
             throw std::runtime_error("w8a16_gemm_init_lut failed: " + std::string(cudaGetErrorString(err)));
         }
-    }, "Initialize FP8->F32 LUT for W8A16 GEMM");
+    }, "[Alias for gemm_w8a16_init_lut] Initialize FP8->F32 LUT for W8A16 GEMM");
+
+    // ============================================================
+    // W8A16 GEMM with block-wise scale
+    // New name: gemm_w8a16_bf16_sm120, alias: w8a16_gemm_sm120
+    // ============================================================
+    m.def("gemm_w8a16_bf16_sm120", [](const GPUArray& A, const GPUArray& B_fp8, const GPUArray& B_scale, GPUArray& C) {
+        if (A.dtype() != DataType::BFloat16 || C.dtype() != DataType::BFloat16) {
+            throw std::runtime_error("gemm_w8a16_bf16_sm120: A and C must be bfloat16");
+        }
+        if (B_fp8.dtype() != DataType::UInt8) {
+            throw std::runtime_error("gemm_w8a16_bf16_sm120: B_fp8 must be uint8 (FP8 E4M3)");
+        }
+        if (B_scale.dtype() != DataType::BFloat16) {
+            throw std::runtime_error("gemm_w8a16_bf16_sm120: B_scale must be bfloat16");
+        }
+        if (A.ndim() != 2 || B_fp8.ndim() != 2 || C.ndim() != 2) {
+            throw std::runtime_error("gemm_w8a16_bf16_sm120: A[M,K], B_fp8[K,N], C[M,N] dimensions required");
+        }
+        int M = A.shape()[0];
+        int K = A.shape()[1];
+        int N = B_fp8.shape()[1];
+        int scale_stride_n = (N + 127) / 128;
+        if (B_fp8.shape()[0] != static_cast<size_t>(K)) {
+            throw std::runtime_error("gemm_w8a16_bf16_sm120: K dimension mismatch");
+        }
+        if (C.shape()[0] != static_cast<size_t>(M) || C.shape()[1] != static_cast<size_t>(N)) {
+            throw std::runtime_error("gemm_w8a16_bf16_sm120: output shape mismatch");
+        }
+        cudaError_t err = pygpukit_w8a16_gemm_sm120(
+            A.data(), B_fp8.data(), B_scale.data(), C.data(),
+            M, N, K, scale_stride_n, nullptr);
+        if (err != cudaSuccess) {
+            throw std::runtime_error("gemm_w8a16_bf16_sm120 failed: " + std::string(cudaGetErrorString(err)));
+        }
+    }, py::arg("A"), py::arg("B_fp8"), py::arg("B_scale"), py::arg("C"),
+    "GEMM W8A16->BF16 for SM120: C[M,N] = A[M,K] @ B_fp8[K,N] (FP8 weight x BF16 activation with block-wise scale)");
+
+    // Alias: w8a16_gemm_sm120
     m.def("w8a16_gemm_sm120", [](const GPUArray& A, const GPUArray& B_fp8, const GPUArray& B_scale, GPUArray& C) {
         if (A.dtype() != DataType::BFloat16 || C.dtype() != DataType::BFloat16) {
             throw std::runtime_error("w8a16_gemm_sm120: A and C must be bfloat16");
@@ -52,30 +98,57 @@ void init_gemm_fp8xbf16_bf16(py::module_& m) {
         if (A.ndim() != 2 || B_fp8.ndim() != 2 || C.ndim() != 2) {
             throw std::runtime_error("w8a16_gemm_sm120: A[M,K], B_fp8[K,N], C[M,N] dimensions required");
         }
-
         int M = A.shape()[0];
         int K = A.shape()[1];
         int N = B_fp8.shape()[1];
         int scale_stride_n = (N + 127) / 128;
-
         if (B_fp8.shape()[0] != static_cast<size_t>(K)) {
             throw std::runtime_error("w8a16_gemm_sm120: K dimension mismatch");
         }
         if (C.shape()[0] != static_cast<size_t>(M) || C.shape()[1] != static_cast<size_t>(N)) {
             throw std::runtime_error("w8a16_gemm_sm120: output shape mismatch");
         }
-
         cudaError_t err = pygpukit_w8a16_gemm_sm120(
             A.data(), B_fp8.data(), B_scale.data(), C.data(),
-            M, N, K, scale_stride_n, nullptr
-        );
-
+            M, N, K, scale_stride_n, nullptr);
         if (err != cudaSuccess) {
             throw std::runtime_error("w8a16_gemm_sm120 failed: " + std::string(cudaGetErrorString(err)));
         }
     }, py::arg("A"), py::arg("B_fp8"), py::arg("B_scale"), py::arg("C"),
-    "W8A16 GEMM: C[M,N] = A[M,K] @ B_fp8[K,N] (FP8 weight x BF16 activation with block-wise scale)");
+    "[Alias for gemm_w8a16_bf16_sm120] W8A16 GEMM: C[M,N] = A[M,K] @ B_fp8[K,N]");
+
+    // ============================================================
+    // W8A16 CUTLASS variant
+    // New name: gemm_w8a16_bf16_cutlass_sm120, alias: w8a16_cutlass_sm120
+    // ============================================================
+    m.def("gemm_w8a16_bf16_cutlass_sm120", [](const GPUArray& A, const GPUArray& B_fp8, GPUArray& D) {
+        if (A.dtype() != DataType::BFloat16 || D.dtype() != DataType::BFloat16) {
+            throw std::runtime_error("gemm_w8a16_bf16_cutlass_sm120: A and D must be bfloat16");
+        }
+        if (B_fp8.dtype() != DataType::UInt8) {
+            throw std::runtime_error("gemm_w8a16_bf16_cutlass_sm120: B_fp8 must be uint8 (FP8 E4M3)");
+        }
+        if (A.ndim() != 2 || B_fp8.ndim() != 2 || D.ndim() != 2) {
+            throw std::runtime_error("gemm_w8a16_bf16_cutlass_sm120: A[M,K], B_fp8[N,K], D[M,N] required");
+        }
+        int M = A.shape()[0];
+        int K = A.shape()[1];
+        int N = B_fp8.shape()[0];
+        if (B_fp8.shape()[1] != static_cast<size_t>(K)) {
+            throw std::runtime_error("gemm_w8a16_bf16_cutlass_sm120: K dimension mismatch (B_fp8 should be [N,K])");
+        }
+        if (D.shape()[0] != static_cast<size_t>(M) || D.shape()[1] != static_cast<size_t>(N)) {
+            throw std::runtime_error("gemm_w8a16_bf16_cutlass_sm120: output shape mismatch");
+        }
+        cudaError_t err = pygpukit_w8a16_cutlass_sm120(
+            A.data(), B_fp8.data(), D.data(),
+            M, N, K, 1.0f, 0.0f, nullptr);
+        if (err != cudaSuccess) {
+            throw std::runtime_error("gemm_w8a16_bf16_cutlass_sm120 failed: " + std::string(cudaGetErrorString(err)));
+        }
+    }, py::arg("A"), py::arg("B_fp8"), py::arg("D"),
+    "GEMM W8A16->BF16 (CUTLASS) for SM120: D[M,N] = A[M,K] @ B_fp8[N,K]");
+
+    // Alias: w8a16_cutlass_sm120
     m.def("w8a16_cutlass_sm120", [](const GPUArray& A, const GPUArray& B_fp8, GPUArray& D) {
         if (A.dtype() != DataType::BFloat16 || D.dtype() != DataType::BFloat16) {
             throw std::runtime_error("w8a16_cutlass_sm120: A and D must be bfloat16");
@@ -86,31 +159,56 @@ void init_gemm_fp8xbf16_bf16(py::module_& m) {
         if (A.ndim() != 2 || B_fp8.ndim() != 2 || D.ndim() != 2) {
             throw std::runtime_error("w8a16_cutlass_sm120: A[M,K], B_fp8[N,K], D[M,N] required");
         }
-
         int M = A.shape()[0];
         int K = A.shape()[1];
         int N = B_fp8.shape()[0];
-
         if (B_fp8.shape()[1] != static_cast<size_t>(K)) {
             throw std::runtime_error("w8a16_cutlass_sm120: K dimension mismatch (B_fp8 should be [N,K])");
         }
         if (D.shape()[0] != static_cast<size_t>(M) || D.shape()[1] != static_cast<size_t>(N)) {
             throw std::runtime_error("w8a16_cutlass_sm120: output shape mismatch");
         }
-
         cudaError_t err = pygpukit_w8a16_cutlass_sm120(
             A.data(), B_fp8.data(), D.data(),
-            M, N, K,
-            1.0f, 0.0f,
-            nullptr
-        );
-
+            M, N, K, 1.0f, 0.0f, nullptr);
         if (err != cudaSuccess) {
             throw std::runtime_error("w8a16_cutlass_sm120 failed: " + std::string(cudaGetErrorString(err)));
         }
     }, py::arg("A"), py::arg("B_fp8"), py::arg("D"),
-    "W8A16 GEMM using CUTLASS: D[M,N] = A[M,K] @ B_fp8[N,K] (B transposed for ColumnMajor, quantizes BF16->FP8 internally)");
+    "[Alias for gemm_w8a16_bf16_cutlass_sm120] W8A16 GEMM using CUTLASS");
+
+    // ============================================================
+    // W8A16 blockwise variant
+    // New name: gemm_w8a16_bf16_blockwise_sm120, alias: w8a16_blockwise_sm120
+    // ============================================================
+    m.def("gemm_w8a16_bf16_blockwise_sm120", [](const GPUArray& A, const GPUArray& B_fp8, GPUArray& D) {
+        if (A.dtype() != DataType::BFloat16 || D.dtype() != DataType::BFloat16) {
+            throw std::runtime_error("gemm_w8a16_bf16_blockwise_sm120: A and D must be bfloat16");
+        }
+        if (B_fp8.dtype() != DataType::UInt8) {
+            throw std::runtime_error("gemm_w8a16_bf16_blockwise_sm120: B_fp8 must be uint8 (FP8 E4M3)");
+        }
+        if (A.ndim() != 2 || B_fp8.ndim() != 2 || D.ndim() != 2) {
+            throw std::runtime_error("gemm_w8a16_bf16_blockwise_sm120: A[M,K], B_fp8[N,K], D[M,N] required");
+        }
+        int M = A.shape()[0];
+        int K = A.shape()[1];
+        int N = B_fp8.shape()[0];
+        if (B_fp8.shape()[1] != static_cast<size_t>(K)) {
+            throw std::runtime_error("gemm_w8a16_bf16_blockwise_sm120: K dimension mismatch (B_fp8 should be [N,K])");
+        }
+        if (D.shape()[0] != static_cast<size_t>(M) || D.shape()[1] != static_cast<size_t>(N)) {
+            throw std::runtime_error("gemm_w8a16_bf16_blockwise_sm120: output shape mismatch");
+        }
+        cudaError_t err = pygpukit_w8a16_blockwise_sm120(
+            A.data(), B_fp8.data(), D.data(),
+            M, N, K, 1.0f, 0.0f, nullptr);
+        if (err != cudaSuccess) {
+            throw std::runtime_error("gemm_w8a16_bf16_blockwise_sm120 failed: " + std::string(cudaGetErrorString(err)));
+        }
+    }, py::arg("A"), py::arg("B_fp8"), py::arg("D"),
+    "GEMM W8A16->BF16 (blockwise) for SM120: D[M,N] = A[M,K] @ B_fp8[N,K]");
+
+    // Alias: w8a16_blockwise_sm120
     m.def("w8a16_blockwise_sm120", [](const GPUArray& A, const GPUArray& B_fp8, GPUArray& D) {
         if (A.dtype() != DataType::BFloat16 || D.dtype() != DataType::BFloat16) {
             throw std::runtime_error("w8a16_blockwise_sm120: A and D must be bfloat16");
@@ -121,31 +219,59 @@ void init_gemm_fp8xbf16_bf16(py::module_& m) {
         if (A.ndim() != 2 || B_fp8.ndim() != 2 || D.ndim() != 2) {
             throw std::runtime_error("w8a16_blockwise_sm120: A[M,K], B_fp8[N,K], D[M,N] required");
         }
-
         int M = A.shape()[0];
         int K = A.shape()[1];
         int N = B_fp8.shape()[0];
-
         if (B_fp8.shape()[1] != static_cast<size_t>(K)) {
             throw std::runtime_error("w8a16_blockwise_sm120: K dimension mismatch (B_fp8 should be [N,K])");
         }
         if (D.shape()[0] != static_cast<size_t>(M) || D.shape()[1] != static_cast<size_t>(N)) {
             throw std::runtime_error("w8a16_blockwise_sm120: output shape mismatch");
         }
-
         cudaError_t err = pygpukit_w8a16_blockwise_sm120(
             A.data(), B_fp8.data(), D.data(),
-            M, N, K,
-            1.0f, 0.0f,
-            nullptr
-        );
-
+            M, N, K, 1.0f, 0.0f, nullptr);
         if (err != cudaSuccess) {
             throw std::runtime_error("w8a16_blockwise_sm120 failed: " + std::string(cudaGetErrorString(err)));
         }
     }, py::arg("A"), py::arg("B_fp8"), py::arg("D"),
-    "W8A16 GEMM using blockwise: D[M,N] = A[M,K] @ B_fp8[N,K] (same kernel as working fp8_blockwise)");
+    "[Alias for gemm_w8a16_bf16_blockwise_sm120] W8A16 GEMM using blockwise");
+
+    // ============================================================
+    // W8A16 optimized variant
+    // New name: gemm_w8a16_bf16_optimized_sm120, alias: w8a16_optimized_sm120
+    // ============================================================
+    m.def("gemm_w8a16_bf16_optimized_sm120", [](const GPUArray& A, const GPUArray& B_fp8, GPUArray& D) {
+        if (A.dtype() != DataType::BFloat16 || D.dtype() != DataType::BFloat16) {
+            throw std::runtime_error("gemm_w8a16_bf16_optimized_sm120: A and D must be bfloat16");
+        }
+        if (B_fp8.dtype() != DataType::UInt8) {
+            throw std::runtime_error("gemm_w8a16_bf16_optimized_sm120: B_fp8 must be uint8 (FP8 E4M3)");
+        }
+        if (A.ndim() != 2 || B_fp8.ndim() != 2 || D.ndim() != 2) {
+            throw std::runtime_error("gemm_w8a16_bf16_optimized_sm120: A[M,K], B_fp8[N,K], D[M,N] required");
+        }
+        int M = A.shape()[0];
+        int K = A.shape()[1];
+        int N = B_fp8.shape()[0];
+        if (B_fp8.shape()[1] != static_cast<size_t>(K)) {
+            throw std::runtime_error("gemm_w8a16_bf16_optimized_sm120: K dimension mismatch (B_fp8 should be [N,K])");
+        }
+        if (D.shape()[0] != static_cast<size_t>(M) || D.shape()[1] != static_cast<size_t>(N)) {
+            throw std::runtime_error("gemm_w8a16_bf16_optimized_sm120: output shape mismatch");
+        }
+        cudaError_t err = pygpukit_gemm_w8a16_optimized_sm120(
+            A.data(),
+            reinterpret_cast<const uint8_t*>(B_fp8.data()),
+            D.data(),
+            nullptr, nullptr,
+            M, N, K, 1.0f, 0.0f, nullptr);
+        if (err != cudaSuccess) {
+            throw std::runtime_error("gemm_w8a16_bf16_optimized_sm120 failed: " + std::string(cudaGetErrorString(err)));
+        }
+    }, py::arg("A"), py::arg("B_fp8"), py::arg("D"),
+    "GEMM W8A16->BF16 (optimized) for SM120: D[M,N] = A[M,K] @ B_fp8[N,K] (~220+ TFLOPS)");
+
+    // Alias: w8a16_optimized_sm120
     m.def("w8a16_optimized_sm120", [](const GPUArray& A, const GPUArray& B_fp8, GPUArray& D) {
         if (A.dtype() != DataType::BFloat16 || D.dtype() != DataType::BFloat16) {
             throw std::runtime_error("w8a16_optimized_sm120: A and D must be bfloat16");
@@ -156,31 +282,24 @@ void init_gemm_fp8xbf16_bf16(py::module_& m) {
         if (A.ndim() != 2 || B_fp8.ndim() != 2 || D.ndim() != 2) {
             throw std::runtime_error("w8a16_optimized_sm120: A[M,K], B_fp8[N,K], D[M,N] required");
         }
-
         int M = A.shape()[0];
         int K = A.shape()[1];
         int N = B_fp8.shape()[0];
-
         if (B_fp8.shape()[1] != static_cast<size_t>(K)) {
             throw std::runtime_error("w8a16_optimized_sm120: K dimension mismatch (B_fp8 should be [N,K])");
         }
         if (D.shape()[0] != static_cast<size_t>(M) || D.shape()[1] != static_cast<size_t>(N)) {
             throw std::runtime_error("w8a16_optimized_sm120: output shape mismatch");
         }
-
         cudaError_t err = pygpukit_gemm_w8a16_optimized_sm120(
             A.data(),
             reinterpret_cast<const uint8_t*>(B_fp8.data()),
             D.data(),
             nullptr, nullptr,
-            M, N, K,
-            1.0f, 0.0f,
-            nullptr
-        );
-
+            M, N, K, 1.0f, 0.0f, nullptr);
         if (err != cudaSuccess) {
             throw std::runtime_error("w8a16_optimized_sm120 failed: " + std::string(cudaGetErrorString(err)));
        }
     }, py::arg("A"), py::arg("B_fp8"), py::arg("D"),
-    "Optimized W8A16 GEMM: D[M,N] = A[M,K] @ B_fp8[N,K] (uses fast FP8xFP8 internally, ~220+ TFLOPS expected)");
+    "[Alias for gemm_w8a16_bf16_optimized_sm120] Optimized W8A16 GEMM");
 }
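For reference, the intended call sequence through the new W8A16 names looks like the sketch below. The array constructors (`alloc_bf16`, `alloc_u8`) are placeholders for whatever GPUArray allocation API the host code actually uses, and the `B_scale` layout is an assumption inferred from the `(N + 127) / 128` stride computed in the binding.

```python
# Hypothetical usage sketch for the renamed W8A16 bindings.
# alloc_bf16/alloc_u8 are placeholder allocators, not real pygpukit API.
from pygpukit.core.backend import get_native_module

native = get_native_module()
native.gemm_w8a16_init_lut()   # build the FP8->F32 LUT once per process

M, K, N = 16, 4096, 4096
A = alloc_bf16((M, K))         # BF16 activations
B_fp8 = alloc_u8((K, N))       # FP8 E4M3 weights stored as uint8
B_scale = alloc_bf16((K // 128, (N + 127) // 128))  # assumed 128x128 block scales
C = alloc_bf16((M, N))
native.gemm_w8a16_bf16_sm120(A, B_fp8, B_scale, C)  # C[M,N] = A @ B_fp8
native.w8a16_gemm_sm120(A, B_fp8, B_scale, C)       # old alias, same kernel
```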
diff --git a/native/bindings/gemm/fp8xfp8_bf16.cpp b/native/bindings/gemm/fp8xfp8_bf16.cpp
index fcfc419..1da94be 100644
--- a/native/bindings/gemm/fp8xfp8_bf16.cpp
+++ b/native/bindings/gemm/fp8xfp8_bf16.cpp
@@ -32,12 +32,19 @@ extern "C" {
 }

 void init_gemm_fp8xfp8_bf16(py::module_& m) {
-    // SM90 (Hopper)
+    // ============================================================
+    // SM90 (Hopper) - FP8 internally, F32 I/O
+    // New name: gemm_fp8_f32_sm90_available, alias: fp8_sm90_available
+    // ============================================================
+    m.def("gemm_fp8_f32_sm90_available", []() {
+        return pygpukit_fp8_sm90_available();
+    }, "Check if FP8 GEMM (F32 I/O) is available on SM90 (Hopper)");
     m.def("fp8_sm90_available", []() {
         return pygpukit_fp8_sm90_available();
-    }, "Check if FP8 GEMM is available on SM90 (Hopper)");
+    }, "[Alias for gemm_fp8_f32_sm90_available] Check if FP8 GEMM is available on SM90 (Hopper)");

-    m.def("gemm_fp8_sm90", [](const GPUArray& A, const GPUArray& B, GPUArray& D) {
+    // New name: gemm_fp8_f32_sm90, alias: gemm_fp8_sm90
+    m.def("gemm_fp8_f32_sm90", [](const GPUArray& A, const GPUArray& B, GPUArray& D) {
         if (A.dtype() != DataType::Float32 || B.dtype() != DataType::Float32 || D.dtype() != DataType::Float32) {
             throw std::runtime_error("gemm_fp8_sm90: all inputs must be float32");
         }
@@ -65,18 +72,52 @@ void init_gemm_fp8xfp8_bf16(py::module_& m) {
             nullptr
         );

+        if (err != cudaSuccess) {
+            throw std::runtime_error("gemm_fp8_f32_sm90 failed: " + std::string(cudaGetErrorString(err)));
+        }
+    }, py::arg("A"), py::arg("B"), py::arg("D"),
+    "GEMM FP8 (F32 I/O) for SM90: D = A @ B (FP8 quantization internally)");
+
+    // Alias: gemm_fp8_sm90
+    m.def("gemm_fp8_sm90", [](const GPUArray& A, const GPUArray& B, GPUArray& D) {
+        if (A.dtype() != DataType::Float32 || B.dtype() != DataType::Float32 || D.dtype() != DataType::Float32) {
+            throw std::runtime_error("gemm_fp8_sm90: all inputs must be float32");
+        }
+        if (A.ndim() != 2 || B.ndim() != 2 || D.ndim() != 2) {
+            throw std::runtime_error("gemm_fp8_sm90: all inputs must be 2D");
+        }
+        int M = A.shape()[0];
+        int K = A.shape()[1];
+        int N = B.shape()[1];
+        if (B.shape()[0] != static_cast<size_t>(K)) {
+            throw std::runtime_error("gemm_fp8_sm90: A.shape[1] must equal B.shape[0]");
+        }
+        if (D.shape()[0] != static_cast<size_t>(M) || D.shape()[1] != static_cast<size_t>(N)) {
+            throw std::runtime_error("gemm_fp8_sm90: D shape mismatch");
+        }
+        cudaError_t err = pygpukit_gemm_fp8_sm90(
+            static_cast<const float*>(A.data()),
+            static_cast<const float*>(B.data()),
+            static_cast<float*>(D.data()),
+            M, N, K, 1.0f, 0.0f, nullptr);
         if (err != cudaSuccess) {
             throw std::runtime_error("gemm_fp8_sm90 failed: " + std::string(cudaGetErrorString(err)));
         }
     }, py::arg("A"), py::arg("B"), py::arg("D"),
-    "FP8 GEMM for SM90 (Hopper): D = A @ B (with FP8 quantization internally)");
+    "[Alias for gemm_fp8_f32_sm90] FP8 GEMM for SM90 (Hopper)");

-    // SM100 (Blackwell datacenter)
+    // ============================================================
+    // SM100 (Blackwell datacenter) - FP8 internally, F32 I/O
+    // New name: gemm_fp8_f32_sm100_available, alias: fp8_sm100_available
+    // ============================================================
+    m.def("gemm_fp8_f32_sm100_available", []() {
+        return pygpukit_fp8_sm100_available();
+    }, "Check if FP8 GEMM (F32 I/O) is available on SM100 (Blackwell datacenter)");
     m.def("fp8_sm100_available", []() {
         return pygpukit_fp8_sm100_available();
-    }, "Check if FP8 GEMM is available on SM100 (Blackwell datacenter)");
+    }, "[Alias for gemm_fp8_f32_sm100_available] Check if FP8 GEMM is available on SM100 (Blackwell datacenter)");

-    m.def("gemm_fp8_sm100", [](const GPUArray& A, const GPUArray& B, GPUArray& D) {
+    // New name: gemm_fp8_f32_sm100, alias: gemm_fp8_sm100
+    m.def("gemm_fp8_f32_sm100", [](const GPUArray& A, const GPUArray& B, GPUArray& D) {
         if (A.dtype() != DataType::Float32 || B.dtype() != DataType::Float32 || D.dtype() != DataType::Float32) {
             throw std::runtime_error("gemm_fp8_sm100: all inputs must be float32");
         }
@@ -104,18 +145,52 @@ void init_gemm_fp8xfp8_bf16(py::module_& m) {
             nullptr
         );

+        if (err != cudaSuccess) {
+            throw std::runtime_error("gemm_fp8_f32_sm100 failed: " + std::string(cudaGetErrorString(err)));
+        }
+    }, py::arg("A"), py::arg("B"), py::arg("D"),
+    "GEMM FP8 (F32 I/O) for SM100: D = A @ B (FP8 quantization internally)");
+
+    // Alias: gemm_fp8_sm100
+    m.def("gemm_fp8_sm100", [](const GPUArray& A, const GPUArray& B, GPUArray& D) {
+        if (A.dtype() != DataType::Float32 || B.dtype() != DataType::Float32 || D.dtype() != DataType::Float32) {
+            throw std::runtime_error("gemm_fp8_sm100: all inputs must be float32");
+        }
+        if (A.ndim() != 2 || B.ndim() != 2 || D.ndim() != 2) {
+            throw std::runtime_error("gemm_fp8_sm100: all inputs must be 2D");
+        }
+        int M = A.shape()[0];
+        int K = A.shape()[1];
+        int N = B.shape()[1];
+        if (B.shape()[0] != static_cast<size_t>(K)) {
+            throw std::runtime_error("gemm_fp8_sm100: A.shape[1] must equal B.shape[0]");
+        }
+        if (D.shape()[0] != static_cast<size_t>(M) || D.shape()[1] != static_cast<size_t>(N)) {
+            throw std::runtime_error("gemm_fp8_sm100: D shape mismatch");
+        }
+        cudaError_t err = pygpukit_gemm_fp8_sm100(
+            static_cast<const float*>(A.data()),
+            static_cast<const float*>(B.data()),
+            static_cast<float*>(D.data()),
+            M, N, K, 1.0f, 0.0f, nullptr);
         if (err != cudaSuccess) {
             throw std::runtime_error("gemm_fp8_sm100 failed: " + std::string(cudaGetErrorString(err)));
         }
     }, py::arg("A"), py::arg("B"), py::arg("D"),
-    "FP8 GEMM for SM100 (Blackwell datacenter): D = A @ B (with FP8 quantization internally)");
+    "[Alias for gemm_fp8_f32_sm100] FP8 GEMM for SM100 (Blackwell datacenter)");

-    // SM120 (Blackwell GeForce)
+    // ============================================================
+    // SM120 (Blackwell GeForce) - FP8 internally, F32 I/O
+    // New name: gemm_fp8_f32_sm120_available, alias: fp8_sm120_available
+    // ============================================================
+    m.def("gemm_fp8_f32_sm120_available", []() {
+        return pygpukit_fp8_sm120_available();
+    }, "Check if FP8 GEMM (F32 I/O) is available on SM120 (Blackwell GeForce)");
     m.def("fp8_sm120_available", []() {
         return pygpukit_fp8_sm120_available();
-    }, "Check if FP8 GEMM is available on SM120 (currently disabled due to CUTLASS bug)");
+    }, "[Alias for gemm_fp8_f32_sm120_available] Check if FP8 GEMM is available on SM120");

-    m.def("gemm_fp8_sm120", [](const GPUArray& A, const GPUArray& B, GPUArray& D) {
+    // New name: gemm_fp8_f32_sm120, alias: gemm_fp8_sm120
+    m.def("gemm_fp8_f32_sm120", [](const GPUArray& A, const GPUArray& B, GPUArray& D) {
         if (A.dtype() != DataType::Float32 || B.dtype() != DataType::Float32 || D.dtype() != DataType::Float32) {
             throw std::runtime_error("gemm_fp8_sm120: all inputs must be float32");
         }
@@ -143,9 +218,36 @@ void init_gemm_fp8xfp8_bf16(py::module_& m) {
             nullptr
         );

+        if (err != cudaSuccess) {
+            throw std::runtime_error("gemm_fp8_f32_sm120 failed: " + std::string(cudaGetErrorString(err)));
+        }
+    }, py::arg("A"), py::arg("B"), py::arg("D"),
+    "GEMM FP8 (F32 I/O) for SM120: D = A @ B (FP8 quantization internally)");
+
+    // Alias: gemm_fp8_sm120
+    m.def("gemm_fp8_sm120", [](const GPUArray& A, const GPUArray& B, GPUArray& D) {
+        if (A.dtype() != DataType::Float32 || B.dtype() != DataType::Float32 || D.dtype() != DataType::Float32) {
+            throw std::runtime_error("gemm_fp8_sm120: all inputs must be float32");
+        }
+        if (A.ndim() != 2 || B.ndim() != 2 || D.ndim() != 2) {
+            throw std::runtime_error("gemm_fp8_sm120: all inputs must be 2D");
+        }
+        int M = A.shape()[0];
+        int K = A.shape()[1];
+        int N = B.shape()[1];
+        if (B.shape()[0] != static_cast<size_t>(K)) {
+            throw std::runtime_error("gemm_fp8_sm120: A.shape[1] must equal B.shape[0]");
+        }
+        if (D.shape()[0] != static_cast<size_t>(M) || D.shape()[1] != static_cast<size_t>(N)) {
+            throw std::runtime_error("gemm_fp8_sm120: D shape mismatch");
+        }
+        cudaError_t err = pygpukit_gemm_fp8_sm120(
+            static_cast<const float*>(A.data()),
+            static_cast<const float*>(B.data()),
+            static_cast<float*>(D.data()),
+            M, N, K, 1.0f, 0.0f, nullptr);
         if (err != cudaSuccess) {
             throw std::runtime_error("gemm_fp8_sm120 failed: " + std::string(cudaGetErrorString(err)));
         }
     }, py::arg("A"), py::arg("B"), py::arg("D"),
-    "FP8 GEMM for SM120: D = A @ B (with FP8 quantization internally)");
+    "[Alias for gemm_fp8_f32_sm120] FP8 GEMM for SM120");
 }
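The three architecture probes compose naturally into a dispatch helper; this sketch mirrors exactly what the updated `fp8_available()` in src/pygpukit/ops/matmul.py does further down in this diff.

```python
# Pick the F32-I/O FP8 GEMM backend by architecture using the new probes.
def run_gemm_fp8_f32(native, A, B, D):
    if native.gemm_fp8_f32_sm120_available():    # Blackwell GeForce
        native.gemm_fp8_f32_sm120(A, B, D)
    elif native.gemm_fp8_f32_sm100_available():  # Blackwell datacenter
        native.gemm_fp8_f32_sm100(A, B, D)
    elif native.gemm_fp8_f32_sm90_available():   # Hopper
        native.gemm_fp8_f32_sm90(A, B, D)
    else:
        raise RuntimeError("no FP8 GEMM backend available")
```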
diff --git a/native/bindings/gemm/fp8xfp8_fp8.cpp b/native/bindings/gemm/fp8xfp8_fp8.cpp
index 0a33ec8..ea95c97 100644
--- a/native/bindings/gemm/fp8xfp8_fp8.cpp
+++ b/native/bindings/gemm/fp8xfp8_fp8.cpp
@@ -39,9 +39,16 @@ extern "C" {
 }

 void init_gemm_fp8xfp8_fp8(py::module_& m) {
+    // ============================================================
+    // Pure FP8 I/O GEMM for SM120
+    // New name: gemm_fp8_fp8_sm120_available, alias: fp8_fp8_sm120_available
+    // ============================================================
+    m.def("gemm_fp8_fp8_sm120_available", []() {
+        return pygpukit_fp8_fp8_sm120_available();
+    }, "Check if Pure FP8 I/O GEMM is available on SM120 (Blackwell GeForce)");
     m.def("fp8_fp8_sm120_available", []() {
         return pygpukit_fp8_fp8_sm120_available();
-    }, "Check if Pure FP8 I/O GEMM is available on SM120");
+    }, "[Alias for gemm_fp8_fp8_sm120_available] Check if Pure FP8 I/O GEMM is available on SM120");

     m.def("gemm_fp8_fp8_sm120", [](const GPUArray& A, const GPUArray& B, GPUArray& D) {
         if (A.dtype() != DataType::UInt8 || B.dtype() != DataType::UInt8 || D.dtype() != DataType::UInt8) {
@@ -147,11 +154,20 @@ void init_gemm_fp8xfp8_fp8(py::module_& m) {
     }, py::arg("A"), py::arg("B"), py::arg("D"), py::arg("scale_A"), py::arg("scale_B"),
     "Blockwise scaled FP8 I/O GEMM for SM120: D = (A * scale_A) @ (B * scale_B)");

-    // Get scale factor sizes for FP8 blockwise GEMM
-    m.def("fp8_fp8_get_scale_sizes", [](int M, int N, int K) {
+    // ============================================================
+    // Helper: Get scale factor sizes for FP8 blockwise GEMM
+    // New name: gemm_fp8_fp8_get_scale_sizes, alias: fp8_fp8_get_scale_sizes
+    // ============================================================
+    m.def("gemm_fp8_fp8_get_scale_sizes", [](int M, int N, int K) {
         size_t sfa_size, sfb_size;
         pygpukit_fp8_fp8_get_scale_sizes(M, N, K, &sfa_size, &sfb_size);
         return py::make_tuple(sfa_size, sfb_size);
     }, py::arg("M"), py::arg("N"), py::arg("K"),
     "Get scale factor sizes for FP8 blockwise GEMM (returns (sfa_size, sfb_size))");
+
+    m.def("fp8_fp8_get_scale_sizes", [](int M, int N, int K) {
+        size_t sfa_size, sfb_size;
+        pygpukit_fp8_fp8_get_scale_sizes(M, N, K, &sfa_size, &sfb_size);
+        return py::make_tuple(sfa_size, sfb_size);
+    }, py::arg("M"), py::arg("N"), py::arg("K"),
+    "[Alias for gemm_fp8_fp8_get_scale_sizes] Get scale factor sizes for FP8 blockwise GEMM");
 }
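The size helper is meant to be called before allocating the per-operand scale buffers for the blockwise path. A minimal sketch, assuming a placeholder `alloc_u8` byte-buffer allocator (the real GPUArray constructors are not shown in this diff):

```python
# Query scale-factor buffer sizes before a blockwise FP8 GEMM.
sfa_size, sfb_size = native.gemm_fp8_fp8_get_scale_sizes(M, N, K)
scale_A = alloc_u8((sfa_size,))  # placeholder allocator
scale_B = alloc_u8((sfb_size,))
# ...fill the scales, then run the blockwise GEMM registered above:
# D = (A * scale_A) @ (B * scale_B)
```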
diff --git a/native/bindings/gemm/generic.cpp b/native/bindings/gemm/generic.cpp
index fb55414..0e454e8 100644
--- a/native/bindings/gemm/generic.cpp
+++ b/native/bindings/gemm/generic.cpp
@@ -4,28 +4,54 @@
 #include "../bindings_common.hpp"

 void init_gemm_generic(py::module_& m) {
-    // Basic matmul
+    // ============================================================
+    // Basic matmul (F32 -> F32)
+    // New name: gemm_f32_f32, alias: matmul
+    // ============================================================
+    m.def("gemm_f32_f32", py::overload_cast<const GPUArray&, const GPUArray&>(&ops::matmul),
+        py::arg("a"), py::arg("b"),
+        "GEMM F32->F32: Matrix multiplication of two GPUArrays");
     m.def("matmul", py::overload_cast<const GPUArray&, const GPUArray&>(&ops::matmul),
         py::arg("a"), py::arg("b"),
-        "Matrix multiplication of two GPUArrays");
+        "[Alias for gemm_f32_f32] Matrix multiplication of two GPUArrays");
+
+    m.def("gemm_f32_f32_", py::overload_cast<const GPUArray&, const GPUArray&, GPUArray&>(&ops::matmul),
+        py::arg("a"), py::arg("b"), py::arg("out"),
+        "GEMM F32->F32: Matrix multiplication with output array");
     m.def("matmul_", py::overload_cast<const GPUArray&, const GPUArray&, GPUArray&>(&ops::matmul),
         py::arg("a"), py::arg("b"), py::arg("out"),
-        "Matrix multiplication with output array");
+        "[Alias for gemm_f32_f32_] Matrix multiplication with output array");

-    // TF32 variants
+    // ============================================================
+    // TF32 variants (TF32 compute -> F32 output)
+    // New name: gemm_tf32_f32, alias: matmul_tf32
+    // ============================================================
+    m.def("gemm_tf32_f32", py::overload_cast<const GPUArray&, const GPUArray&, bool>(&ops::matmul),
+        py::arg("a"), py::arg("b"), py::arg("use_tf32"),
+        "GEMM TF32->F32: Matrix multiplication with TF32 TensorCore");
     m.def("matmul_tf32", py::overload_cast<const GPUArray&, const GPUArray&, bool>(&ops::matmul),
         py::arg("a"), py::arg("b"), py::arg("use_tf32"),
-        "Matrix multiplication with explicit TF32 control");
+        "[Alias for gemm_tf32_f32] Matrix multiplication with explicit TF32 control");
+
+    m.def("gemm_tf32_f32_", py::overload_cast<const GPUArray&, const GPUArray&, GPUArray&, bool>(&ops::matmul),
+        py::arg("a"), py::arg("b"), py::arg("out"), py::arg("use_tf32"),
+        "GEMM TF32->F32: Matrix multiplication with TF32 TensorCore and output array");
     m.def("matmul_tf32_", py::overload_cast<const GPUArray&, const GPUArray&, GPUArray&, bool>(&ops::matmul),
         py::arg("a"), py::arg("b"), py::arg("out"), py::arg("use_tf32"),
-        "Matrix multiplication with explicit TF32 control and output array");
+        "[Alias for gemm_tf32_f32_] Matrix multiplication with explicit TF32 control and output array");

-    // Strided Batched GEMM
+    // ============================================================
+    // Strided Batched GEMM (F32 -> F32)
+    // New name: gemm_f32_f32_batched, alias: gemm_strided_batched_fp32
+    // ============================================================
+    m.def("gemm_f32_f32_batched", &ops::batched_matmul_fp32,
+        py::arg("A"), py::arg("B"), py::arg("C"),
+        py::arg("M"), py::arg("N"), py::arg("K"), py::arg("batch_count"),
+        py::arg("strideA"), py::arg("strideB"), py::arg("strideC"),
+        "GEMM F32->F32 batched: C[b] = A[b] @ B[b] for b in [0, batch_count)");
     m.def("gemm_strided_batched_fp32", &ops::batched_matmul_fp32,
         py::arg("A"), py::arg("B"), py::arg("C"),
         py::arg("M"), py::arg("N"), py::arg("K"), py::arg("batch_count"),
         py::arg("strideA"), py::arg("strideB"), py::arg("strideC"),
-        "Strided batched GEMM: C[b] = A[b] @ B[b] for b in [0, batch_count)");
+        "[Alias for gemm_f32_f32_batched] Strided batched GEMM: C[b] = A[b] @ B[b] for b in [0, batch_count)");
 }
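Since both names bind the same `ops::matmul` overloads, old and new spellings are interchangeable at the call site:

```python
# The generic entry points under old and new names; a and b are
# float32 GPUArrays, out a preallocated float32 GPUArray.
c = native.gemm_f32_f32(a, b)         # was: native.matmul(a, b)
native.gemm_f32_f32_(a, b, out)       # was: native.matmul_(a, b, out)
c = native.gemm_tf32_f32(a, b, True)  # was: native.matmul_tf32(a, b, True)
```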
diff --git a/native/bindings/gemm/grouped.cpp b/native/bindings/gemm/grouped.cpp
index d2317c3..7606aa8 100644
--- a/native/bindings/gemm/grouped.cpp
+++ b/native/bindings/gemm/grouped.cpp
@@ -14,6 +14,10 @@ extern "C" {
 }

 void init_gemm_grouped(py::module_& m) {
+    // ============================================================
+    // Grouped GEMM for MoE: FP8 weights x BF16 activations -> BF16 output
+    // Functions already follow convention, just add _sm120 suffix where missing
+    // ============================================================
     m.def("grouped_gemm_init_lut", []() {
         cudaError_t err = pygpukit_grouped_gemm_init_lut();
         if (err != cudaSuccess) {
@@ -21,7 +25,8 @@ void init_gemm_grouped(py::module_& m) {
         }
     }, "Initialize FP8->BF16 LUT for grouped GEMM");

-    m.def("grouped_gemm_fp8_bf16", [](
+    // New name: grouped_gemm_fp8_bf16_sm120, alias: grouped_gemm_fp8_bf16
+    m.def("grouped_gemm_fp8_bf16_sm120", [](
         const GPUArray& A,
         const GPUArray& B_stacked,
         const GPUArray& B_scale,
@@ -70,9 +75,56 @@ void init_gemm_grouped(py::module_& m) {
             M, N, K, nullptr
         );

+        if (err != cudaSuccess) {
+            throw std::runtime_error("grouped_gemm_fp8_bf16_sm120 failed: " + std::string(cudaGetErrorString(err)));
+        }
+    }, py::arg("A"), py::arg("B_stacked"), py::arg("B_scale"), py::arg("C"), py::arg("row_expert_ids"),
+    "Grouped GEMM FP8->BF16 for SM120: C[M,N] = A[M,K] @ B_stacked[experts,N,K] with per-row expert IDs");
+
+    // Alias: grouped_gemm_fp8_bf16
+    m.def("grouped_gemm_fp8_bf16", [](
+        const GPUArray& A,
+        const GPUArray& B_stacked,
+        const GPUArray& B_scale,
+        GPUArray& C,
+        const GPUArray& row_expert_ids
+    ) {
+        if (A.dtype() != DataType::BFloat16) {
+            throw std::runtime_error("grouped_gemm_fp8_bf16: A must be bfloat16");
+        }
+        if (B_stacked.dtype() != DataType::UInt8) {
+            throw std::runtime_error("grouped_gemm_fp8_bf16: B_stacked must be uint8 (FP8)");
+        }
+        if (B_scale.dtype() != DataType::BFloat16) {
+            throw std::runtime_error("grouped_gemm_fp8_bf16: B_scale must be bfloat16");
+        }
+        if (C.dtype() != DataType::BFloat16) {
+            throw std::runtime_error("grouped_gemm_fp8_bf16: C must be bfloat16");
+        }
+        if (row_expert_ids.dtype() != DataType::Int32) {
+            throw std::runtime_error("grouped_gemm_fp8_bf16: row_expert_ids must be int32");
+        }
+        if (A.ndim() != 2 || B_stacked.ndim() != 3 || C.ndim() != 2) {
+            throw std::runtime_error("grouped_gemm_fp8_bf16: invalid dimensions");
+        }
+        int M = A.shape()[0];
+        int K = A.shape()[1];
+        int N = B_stacked.shape()[1];
+        if (B_stacked.shape()[2] != static_cast<size_t>(K)) {
+            throw std::runtime_error("grouped_gemm_fp8_bf16: K dimension mismatch");
+        }
+        if (C.shape()[0] != static_cast<size_t>(M) || C.shape()[1] != static_cast<size_t>(N)) {
+            throw std::runtime_error("grouped_gemm_fp8_bf16: output shape mismatch");
+        }
+        if (row_expert_ids.ndim() != 1 || row_expert_ids.shape()[0] != static_cast<size_t>(M)) {
+            throw std::runtime_error("grouped_gemm_fp8_bf16: row_expert_ids size mismatch");
+        }
+        cudaError_t err = pygpukit_grouped_gemm_fp8_bf16(
+            A.data(), B_stacked.data(), B_scale.data(), C.data(),
+            reinterpret_cast<const int32_t*>(row_expert_ids.data()),
+            M, N, K, nullptr);
         if (err != cudaSuccess) {
             throw std::runtime_error("grouped_gemm_fp8_bf16 failed: " + std::string(cudaGetErrorString(err)));
         }
     }, py::arg("A"), py::arg("B_stacked"), py::arg("B_scale"), py::arg("C"), py::arg("row_expert_ids"),
-    "Grouped GEMM for MoE: C[M,N] = A[M,K] @ B_stacked[experts,N,K] with per-row expert IDs");
+    "[Alias for grouped_gemm_fp8_bf16_sm120] Grouped GEMM for MoE");
 }
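The shape contract the binding enforces is worth spelling out: every row of `A` is routed to one expert's `[N,K]` weight slice. A sketch with placeholder allocators (the per-expert `B_scale` layout is not validated by this binding and is left unspecified here):

```python
# MoE grouped GEMM shape contract (placeholder allocators, not real API).
A = alloc_bf16((M, K))                     # BF16 activations
B_stacked = alloc_u8((num_experts, N, K))  # FP8 weights, one [N,K] slice per expert
B_scale = alloc_bf16(b_scale_shape)        # per-expert scales; layout not shown in this diff
C = alloc_bf16((M, N))
row_expert_ids = alloc_i32((M,))           # int32 expert id for each row of A
native.grouped_gemm_init_lut()
native.grouped_gemm_fp8_bf16_sm120(A, B_stacked, B_scale, C, row_expert_ids)
```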
diff --git a/native/bindings/gemm/int.cpp b/native/bindings/gemm/int.cpp
index 4f8b5f7..1eff4bd 100644
--- a/native/bindings/gemm/int.cpp
+++ b/native/bindings/gemm/int.cpp
@@ -28,14 +28,52 @@ extern "C" {
 }

 void init_gemm_int(py::module_& m) {
-    // Int8 GEMM
+    // ============================================================
+    // Int8 GEMM: Int8 x Int8 -> Int32 (SM120)
+    // New name: gemm_int8_int32_available, alias: int8_native_gemm_available
+    // ============================================================
+    m.def("gemm_int8_int32_available", []() {
+        return pygpukit_int8_native_gemm_available();
+    }, "Check if Int8 GEMM (Int32 output) is available on SM120");
     m.def("int8_native_gemm_available", []() {
         return pygpukit_int8_native_gemm_available();
-    }, "Check if native Int8 GEMM is available (uses dp4a CUDA cores)");
+    }, "[Alias for gemm_int8_int32_available] Check if native Int8 GEMM is available");

-    m.def("int8_native_gemm_sm120", [](
-        const GPUArray& A, const GPUArray& B, GPUArray& D
-    ) {
+    // New name: gemm_int8_int32_sm120, alias: int8_native_gemm_sm120
+    m.def("gemm_int8_int32_sm120", [](const GPUArray& A, const GPUArray& B, GPUArray& D) {
+        if (A.dtype() != DataType::Int8) {
+            throw std::runtime_error("gemm_int8_int32_sm120: A must be int8");
+        }
+        if (B.dtype() != DataType::Int8) {
+            throw std::runtime_error("gemm_int8_int32_sm120: B must be int8");
+        }
+        if (D.dtype() != DataType::Int32) {
+            throw std::runtime_error("gemm_int8_int32_sm120: D must be int32");
+        }
+        if (A.ndim() != 2 || B.ndim() != 2 || D.ndim() != 2) {
+            throw std::runtime_error("gemm_int8_int32_sm120: A[M,K], B[N,K], D[M,N] required");
+        }
+        int M = A.shape()[0];
+        int K = A.shape()[1];
+        int N = B.shape()[0];
+        if (B.shape()[1] != static_cast<size_t>(K)) {
+            throw std::runtime_error("gemm_int8_int32_sm120: K dimension mismatch");
+        }
+        if (D.shape()[0] != static_cast<size_t>(M) || D.shape()[1] != static_cast<size_t>(N)) {
+            throw std::runtime_error("gemm_int8_int32_sm120: output shape mismatch");
+        }
+        cudaError_t err = pygpukit_gemm_int8_native_sm120(
+            reinterpret_cast<const int8_t*>(A.data()),
+            reinterpret_cast<const int8_t*>(B.data()),
+            reinterpret_cast<int32_t*>(D.data()),
+            M, N, K, nullptr);
+        if (err != cudaSuccess) {
+            throw std::runtime_error("gemm_int8_int32_sm120 failed: " + std::string(cudaGetErrorString(err)));
+        }
+    }, py::arg("A"), py::arg("B"), py::arg("D"),
+    "GEMM Int8->Int32 for SM120: D[M,N] = A[M,K] @ B[N,K]^T using dp4a CUDA cores");
+
+    // Alias: int8_native_gemm_sm120
+    m.def("int8_native_gemm_sm120", [](const GPUArray& A, const GPUArray& B, GPUArray& D) {
         if (A.dtype() != DataType::Int8) {
             throw std::runtime_error("int8_native_gemm_sm120: A must be int8");
         }
@@ -48,37 +86,76 @@ void init_gemm_int(py::module_& m) {
         if (A.ndim() != 2 || B.ndim() != 2 || D.ndim() != 2) {
             throw std::runtime_error("int8_native_gemm_sm120: A[M,K], B[N,K], D[M,N] required");
         }
-
         int M = A.shape()[0];
         int K = A.shape()[1];
         int N = B.shape()[0];
-
         if (B.shape()[1] != static_cast<size_t>(K)) {
             throw std::runtime_error("int8_native_gemm_sm120: K dimension mismatch");
         }
         if (D.shape()[0] != static_cast<size_t>(M) || D.shape()[1] != static_cast<size_t>(N)) {
             throw std::runtime_error("int8_native_gemm_sm120: output shape mismatch");
         }
-
         cudaError_t err = pygpukit_gemm_int8_native_sm120(
             reinterpret_cast<const int8_t*>(A.data()),
             reinterpret_cast<const int8_t*>(B.data()),
             reinterpret_cast<int32_t*>(D.data()),
-            M, N, K,
-            nullptr
-        );
-
+            M, N, K, nullptr);
         if (err != cudaSuccess) {
             throw std::runtime_error("int8_native_gemm_sm120 failed: " + std::string(cudaGetErrorString(err)));
         }
     }, py::arg("A"), py::arg("B"), py::arg("D"),
-    "Native Int8 GEMM using dp4a: D[M,N] = A[M,K] @ B[N,K]^T with exact Int32 output");
+    "[Alias for gemm_int8_int32_sm120] Native Int8 GEMM using dp4a");

-    // Int4 GEMM
+    // ============================================================
+    // Int4 GEMM: Int4 x Int4 -> Int32/Int8 (SM120)
+    // New name: gemm_int4_int32_available, alias: int4_gemm_available
+    // ============================================================
+    m.def("gemm_int4_int32_available", []() {
+        return pygpukit_int4_gemm_sm120_available();
+    }, "Check if Int4 GEMM (Int32 output) is available on SM120");
     m.def("int4_gemm_available", []() {
         return pygpukit_int4_gemm_sm120_available();
-    }, "Check if Int4 GEMM is available (SM120 via Int8/FP8 approximation)");
+    }, "[Alias for gemm_int4_int32_available] Check if Int4 GEMM is available");

+    // New name: gemm_int4_int32_sm120, alias: int4_gemm_int32_sm120
+    m.def("gemm_int4_int32_sm120", [](
+        const GPUArray& A, const GPUArray& B, GPUArray& D,
+        float scale_A, float scale_B, float descale_D
+    ) {
+        if (A.dtype() != DataType::UInt8) {
+            throw std::runtime_error("gemm_int4_int32_sm120: A must be uint8 (packed int4)");
+        }
+        if (B.dtype() != DataType::UInt8) {
+            throw std::runtime_error("gemm_int4_int32_sm120: B must be uint8 (packed int4)");
+        }
+        if (D.dtype() != DataType::Int32) {
+            throw std::runtime_error("gemm_int4_int32_sm120: D must be int32");
+        }
+        if (A.ndim() != 2 || B.ndim() != 2 || D.ndim() != 2) {
+            throw std::runtime_error("gemm_int4_int32_sm120: A[M,K/2], B[N,K/2], D[M,N] required");
+        }
+        int M = A.shape()[0];
+        int K_packed = A.shape()[1];
+        int K = K_packed * 2;
+        int N = B.shape()[0];
+        if (B.shape()[1] != static_cast<size_t>(K_packed)) {
+            throw std::runtime_error("gemm_int4_int32_sm120: K dimension mismatch");
+        }
+        if (D.shape()[0] != static_cast<size_t>(M) || D.shape()[1] != static_cast<size_t>(N)) {
+            throw std::runtime_error("gemm_int4_int32_sm120: output shape mismatch");
+        }
+        cudaError_t err = pygpukit_gemm_int4_int4_int32_sm120(
+            reinterpret_cast<const uint8_t*>(A.data()),
+            reinterpret_cast<const uint8_t*>(B.data()),
+            reinterpret_cast<int32_t*>(D.data()),
+            M, N, K, scale_A, scale_B, descale_D, nullptr);
+        if (err != cudaSuccess) {
+            throw std::runtime_error("gemm_int4_int32_sm120 failed: " + std::string(cudaGetErrorString(err)));
+        }
+    }, py::arg("A"), py::arg("B"), py::arg("D"),
+    py::arg("scale_A") = 1.0f, py::arg("scale_B") = 1.0f, py::arg("descale_D") = 1.0f,
+    "GEMM Int4->Int32 for SM120: D[M,N] = A[M,K] @ B[N,K]^T (packed int4 input)");
+
+    // Alias: int4_gemm_int32_sm120
     m.def("int4_gemm_int32_sm120", [](
         const GPUArray& A, const GPUArray& B, GPUArray& D,
         float scale_A, float scale_B, float descale_D
@@ -95,35 +172,67 @@ void init_gemm_int(py::module_& m) {
         if (A.ndim() != 2 || B.ndim() != 2 || D.ndim() != 2) {
             throw std::runtime_error("int4_gemm_int32_sm120: A[M,K/2], B[N,K/2], D[M,N] required");
         }
-
         int M = A.shape()[0];
         int K_packed = A.shape()[1];
         int K = K_packed * 2;
         int N = B.shape()[0];
-
         if (B.shape()[1] != static_cast<size_t>(K_packed)) {
             throw std::runtime_error("int4_gemm_int32_sm120: K dimension mismatch");
         }
         if (D.shape()[0] != static_cast<size_t>(M) || D.shape()[1] != static_cast<size_t>(N)) {
             throw std::runtime_error("int4_gemm_int32_sm120: output shape mismatch");
         }
-
         cudaError_t err = pygpukit_gemm_int4_int4_int32_sm120(
             reinterpret_cast<const uint8_t*>(A.data()),
             reinterpret_cast<const uint8_t*>(B.data()),
             reinterpret_cast<int32_t*>(D.data()),
-            M, N, K,
-            scale_A, scale_B, descale_D,
-            nullptr
-        );
-
+            M, N, K, scale_A, scale_B, descale_D, nullptr);
         if (err != cudaSuccess) {
             throw std::runtime_error("int4_gemm_int32_sm120 failed: " + std::string(cudaGetErrorString(err)));
         }
     }, py::arg("A"), py::arg("B"), py::arg("D"),
     py::arg("scale_A") = 1.0f, py::arg("scale_B") = 1.0f, py::arg("descale_D") = 1.0f,
-    "Int4 GEMM via Int8/FP8: D[M,N] = A[M,K] @ B[N,K]^T with Int32 output. Input is packed int4.");
+    "[Alias for gemm_int4_int32_sm120] Int4 GEMM via Int8/FP8");

+    // New name: gemm_int4_int8_sm120, alias: int4_gemm_int8_sm120
+    m.def("gemm_int4_int8_sm120", [](
+        const GPUArray& A, const GPUArray& B, GPUArray& D,
+        float scale_A, float scale_B, float descale_D
+    ) {
+        if (A.dtype() != DataType::UInt8) {
+            throw std::runtime_error("gemm_int4_int8_sm120: A must be uint8 (packed int4)");
+        }
+        if (B.dtype() != DataType::UInt8) {
+            throw std::runtime_error("gemm_int4_int8_sm120: B must be uint8 (packed int4)");
+        }
+        if (D.dtype() != DataType::Int8) {
+            throw std::runtime_error("gemm_int4_int8_sm120: D must be int8");
+        }
+        if (A.ndim() != 2 || B.ndim() != 2 || D.ndim() != 2) {
+            throw std::runtime_error("gemm_int4_int8_sm120: A[M,K/2], B[N,K/2], D[M,N] required");
+        }
+        int M = A.shape()[0];
+        int K_packed = A.shape()[1];
+        int K = K_packed * 2;
+        int N = B.shape()[0];
+        if (B.shape()[1] != static_cast<size_t>(K_packed)) {
+            throw std::runtime_error("gemm_int4_int8_sm120: K dimension mismatch");
+        }
+        if (D.shape()[0] != static_cast<size_t>(M) || D.shape()[1] != static_cast<size_t>(N)) {
+            throw std::runtime_error("gemm_int4_int8_sm120: output shape mismatch");
+        }
+        cudaError_t err = pygpukit_gemm_int4_int4_int8_sm120(
+            reinterpret_cast<const uint8_t*>(A.data()),
+            reinterpret_cast<const uint8_t*>(B.data()),
+            reinterpret_cast<int8_t*>(D.data()),
+            M, N, K, scale_A, scale_B, descale_D, nullptr);
+        if (err != cudaSuccess) {
+            throw std::runtime_error("gemm_int4_int8_sm120 failed: " + std::string(cudaGetErrorString(err)));
+        }
+    }, py::arg("A"), py::arg("B"), py::arg("D"),
+    py::arg("scale_A") = 1.0f, py::arg("scale_B") = 1.0f, py::arg("descale_D") = 1.0f,
+    "GEMM Int4->Int8 for SM120: D[M,N] = A[M,K] @ B[N,K]^T (packed int4 input)");
+
+    // Alias: int4_gemm_int8_sm120
     m.def("int4_gemm_int8_sm120", [](
         const GPUArray& A, const GPUArray& B, GPUArray& D,
         float scale_A, float scale_B, float descale_D
@@ -140,32 +249,25 @@ void init_gemm_int(py::module_& m) {
         if (A.ndim() != 2 || B.ndim() != 2 || D.ndim() != 2) {
             throw std::runtime_error("int4_gemm_int8_sm120: A[M,K/2], B[N,K/2], D[M,N] required");
         }
-
         int M = A.shape()[0];
         int K_packed = A.shape()[1];
         int K = K_packed * 2;
         int N = B.shape()[0];
-
         if (B.shape()[1] != static_cast<size_t>(K_packed)) {
             throw std::runtime_error("int4_gemm_int8_sm120: K dimension mismatch");
         }
         if (D.shape()[0] != static_cast<size_t>(M) || D.shape()[1] != static_cast<size_t>(N)) {
             throw std::runtime_error("int4_gemm_int8_sm120: output shape mismatch");
         }
-
         cudaError_t err = pygpukit_gemm_int4_int4_int8_sm120(
             reinterpret_cast<const uint8_t*>(A.data()),
             reinterpret_cast<const uint8_t*>(B.data()),
             reinterpret_cast<int8_t*>(D.data()),
-            M, N, K,
-            scale_A, scale_B, descale_D,
-            nullptr
-        );
-
+            M, N, K, scale_A, scale_B, descale_D, nullptr);
         if (err != cudaSuccess) {
             throw std::runtime_error("int4_gemm_int8_sm120 failed: " + std::string(cudaGetErrorString(err)));
         }
     }, py::arg("A"), py::arg("B"), py::arg("D"),
     py::arg("scale_A") = 1.0f, py::arg("scale_B") = 1.0f, py::arg("descale_D") = 1.0f,
-    "Int4 GEMM via Int8/FP8: D[M,N] = A[M,K] @ B[N,K]^T with Int8 output. Input is packed int4.");
+    "[Alias for gemm_int4_int8_sm120] Int4 GEMM via Int8/FP8 with Int8 output");
 }
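The packed-int4 contract these bindings enforce is easy to get wrong from the Python side, so here is a minimal sketch (placeholder allocators): two 4-bit values share one byte along K, so the K axis is stored as K/2 uint8 columns and the binding reconstructs `K = K_packed * 2`.

```python
# Packed-int4 shape contract for the new entry points.
M, N, K = 32, 64, 128
A_packed = alloc_u8((M, K // 2))  # placeholder allocator; 2 int4 per byte
B_packed = alloc_u8((N, K // 2))
D = alloc_i32((M, N))
native.gemm_int4_int32_sm120(A_packed, B_packed, D,
                             scale_A=1.0, scale_B=1.0, descale_D=1.0)
```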
diff --git a/native/bindings/gemm/nvf4xbf16_bf16.cpp b/native/bindings/gemm/nvf4xbf16_bf16.cpp
index 3040ab7..cd07f2e 100644
--- a/native/bindings/gemm/nvf4xbf16_bf16.cpp
+++ b/native/bindings/gemm/nvf4xbf16_bf16.cpp
@@ -23,9 +23,16 @@ extern "C" {
 }

 void init_gemm_nvf4xbf16_bf16(py::module_& m) {
+    // ============================================================
+    // NVF4 (4-bit) GEMM for SM120 with BF16 I/O
+    // New name: gemm_nvf4_bf16_sm120_available, alias: nvf4_bf16_sm120_available
+    // ============================================================
+    m.def("gemm_nvf4_bf16_sm120_available", []() {
+        return pygpukit_nvf4_bf16_sm120_available();
+    }, "Check if NVF4 BF16 GEMM is available on SM120 (Blackwell GeForce)");
     m.def("nvf4_bf16_sm120_available", []() {
         return pygpukit_nvf4_bf16_sm120_available();
-    }, "Check if NVF4 BF16 GEMM is available on SM120");
+    }, "[Alias for gemm_nvf4_bf16_sm120_available] Check if NVF4 BF16 GEMM is available on SM120");

     m.def("gemm_nvf4_bf16_sm120", [](const GPUArray& A, const GPUArray& B, GPUArray& D) {
         if (A.dtype() != DataType::BFloat16 || B.dtype() != DataType::BFloat16 || D.dtype() != DataType::BFloat16) {
@@ -61,9 +68,13 @@ void init_gemm_nvf4xbf16_bf16(py::module_& m) {
     }, py::arg("A"), py::arg("B"), py::arg("D"),
     "NVF4 (4-bit) GEMM for SM120 with BF16 I/O: D = A @ B (BF16 -> NVF4 quantize -> GEMM -> BF16)");

+    // New name: gemm_nvf4_nvf4_sm120_available, alias: nvf4_nvf4_sm120_available
+    m.def("gemm_nvf4_nvf4_sm120_available", []() {
+        return pygpukit_nvf4_nvf4_sm120_available();
+    }, "Check if pure NVF4 GEMM is available on SM120 (Blackwell GeForce)");
     m.def("nvf4_nvf4_sm120_available", []() {
         return pygpukit_nvf4_nvf4_sm120_available();
-    }, "Check if pure NVF4 GEMM is available (SM120+)");
+    }, "[Alias for gemm_nvf4_nvf4_sm120_available] Check if pure NVF4 GEMM is available (SM120+)");

     m.def("benchmark_gemm_nvf4_sm120", [](GPUArray& D, int M, int N, int K) {
         if (D.dtype() != DataType::BFloat16) {
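Usage stays symmetric with the other backends: gate on the new probe, call the GEMM, and note that the old probe name remains as an alias.

```python
# Gate NVF4 GEMM use on the new availability probe.
if native.gemm_nvf4_bf16_sm120_available():
    native.gemm_nvf4_bf16_sm120(A, B, D)  # BF16 in -> NVF4 quantize -> BF16 out
```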
diff --git a/native/bindings/gemv/fp8xfp8_bf16.cpp b/native/bindings/gemv/fp8xfp8_bf16.cpp
index a96bd0e..be06063 100644
--- a/native/bindings/gemv/fp8xfp8_bf16.cpp
+++ b/native/bindings/gemv/fp8xfp8_bf16.cpp
@@ -20,6 +20,43 @@ namespace gemv {
 }

 void init_gemv_fp8xfp8_bf16(py::module_& m) {
+    // ============================================================
+    // FP8 GEMV: FP8 weights x BF16 activations -> BF16 output
+    // New name: gemv_fp8_bf16_sm120, alias: gemv_fp8_bf16_opt
+    // ============================================================
+    m.def("gemv_fp8_bf16_sm120", [](const GPUArray& A, const GPUArray& B_nk, const GPUArray& B_scale, GPUArray& C) {
+        if (A.dtype() != DataType::BFloat16 || C.dtype() != DataType::BFloat16) {
+            throw std::runtime_error("gemv_fp8_bf16_sm120: A and C must be bfloat16");
+        }
+        if (B_nk.dtype() != DataType::UInt8) {
+            throw std::runtime_error("gemv_fp8_bf16_sm120: B_nk must be uint8 (FP8 E4M3)");
+        }
+        if (B_scale.dtype() != DataType::BFloat16) {
+            throw std::runtime_error("gemv_fp8_bf16_sm120: B_scale must be bfloat16");
+        }
+        if (A.ndim() != 1 || B_nk.ndim() != 2 || C.ndim() != 1) {
+            throw std::runtime_error("gemv_fp8_bf16_sm120: A[K], B_nk[N,K], C[N] dimensions required");
+        }
+        int K = A.shape()[0];
+        int N = B_nk.shape()[0];
+        if (B_nk.shape()[1] != static_cast<size_t>(K)) {
+            throw std::runtime_error("gemv_fp8_bf16_sm120: K dimension mismatch");
+        }
+        if (C.shape()[0] != static_cast<size_t>(N)) {
+            throw std::runtime_error("gemv_fp8_bf16_sm120: N dimension mismatch");
+        }
+        cudaError_t err = pygpukit::ops::gemv::launch_gemv_fp8_opt(
+            reinterpret_cast<const __nv_bfloat16*>(A.data()),
+            reinterpret_cast<const uint8_t*>(B_nk.data()),
+            reinterpret_cast<const __nv_bfloat16*>(B_scale.data()),
+            reinterpret_cast<__nv_bfloat16*>(C.data()),
+            K, N, nullptr);
+        if (err != cudaSuccess) {
+            throw std::runtime_error("gemv_fp8_bf16_sm120 failed: " + std::string(cudaGetErrorString(err)));
+        }
+    }, py::arg("A"), py::arg("B_nk"), py::arg("B_scale"), py::arg("C"),
+    "GEMV FP8->BF16 for SM120: C[N] = A[K] @ B_nk[N,K]^T (warp-reduce, smem, vec4)");
+
+    // Alias: gemv_fp8_bf16_opt
     m.def("gemv_fp8_bf16_opt", [](const GPUArray& A, const GPUArray& B_nk, const GPUArray& B_scale, GPUArray& C) {
         if (A.dtype() != DataType::BFloat16 || C.dtype() != DataType::BFloat16) {
             throw std::runtime_error("gemv_fp8_bf16_opt: A and C must be bfloat16");
@@ -33,31 +70,61 @@ void init_gemv_fp8xfp8_bf16(py::module_& m) {
         if (A.ndim() != 1 || B_nk.ndim() != 2 || C.ndim() != 1) {
             throw std::runtime_error("gemv_fp8_bf16_opt: A[K], B_nk[N,K], C[N] dimensions required");
         }
-
         int K = A.shape()[0];
         int N = B_nk.shape()[0];
-
         if (B_nk.shape()[1] != static_cast<size_t>(K)) {
             throw std::runtime_error("gemv_fp8_bf16_opt: K dimension mismatch");
         }
         if (C.shape()[0] != static_cast<size_t>(N)) {
             throw std::runtime_error("gemv_fp8_bf16_opt: N dimension mismatch");
         }
-
         cudaError_t err = pygpukit::ops::gemv::launch_gemv_fp8_opt(
             reinterpret_cast<const __nv_bfloat16*>(A.data()),
             reinterpret_cast<const uint8_t*>(B_nk.data()),
             reinterpret_cast<const __nv_bfloat16*>(B_scale.data()),
             reinterpret_cast<__nv_bfloat16*>(C.data()),
-            K, N, nullptr
-        );
-
+            K, N, nullptr);
         if (err != cudaSuccess) {
             throw std::runtime_error("gemv_fp8_bf16_opt failed: " + std::string(cudaGetErrorString(err)));
         }
     }, py::arg("A"), py::arg("B_nk"), py::arg("B_scale"), py::arg("C"),
-    "Optimized FP8 GEMV: C[N] = A[K] @ B_nk[N,K]^T (warp-reduce, smem, vec4)");
+    "[Alias for gemv_fp8_bf16_sm120] Optimized FP8 GEMV");

+    // New name: gemv_fp8_bf16_batched_sm120, alias: gemv_fp8_bf16_opt_batched
+    m.def("gemv_fp8_bf16_batched_sm120", [](const GPUArray& A, const GPUArray& B_nk, const GPUArray& B_scale, GPUArray& C) {
+        if (A.dtype() != DataType::BFloat16 || C.dtype() != DataType::BFloat16) {
+            throw std::runtime_error("gemv_fp8_bf16_batched_sm120: A and C must be bfloat16");
+        }
+        if (B_nk.dtype() != DataType::UInt8) {
+            throw std::runtime_error("gemv_fp8_bf16_batched_sm120: B_nk must be uint8 (FP8 E4M3)");
+        }
+        if (B_scale.dtype() != DataType::BFloat16) {
+            throw std::runtime_error("gemv_fp8_bf16_batched_sm120: B_scale must be bfloat16");
+        }
+        if (A.ndim() != 2 || B_nk.ndim() != 2 || C.ndim() != 2) {
+            throw std::runtime_error("gemv_fp8_bf16_batched_sm120: A[M,K], B_nk[N,K], C[M,N] dimensions required");
+        }
+        int M = A.shape()[0];
+        int K = A.shape()[1];
+        int N = B_nk.shape()[0];
+        if (B_nk.shape()[1] != static_cast<size_t>(K)) {
+            throw std::runtime_error("gemv_fp8_bf16_batched_sm120: K dimension mismatch");
+        }
+        if (C.shape()[0] != static_cast<size_t>(M) || C.shape()[1] != static_cast<size_t>(N)) {
+            throw std::runtime_error("gemv_fp8_bf16_batched_sm120: output shape mismatch");
+        }
+        cudaError_t err = pygpukit::ops::gemv::launch_gemv_fp8_opt_batched(
+            reinterpret_cast<const __nv_bfloat16*>(A.data()),
+            reinterpret_cast<const uint8_t*>(B_nk.data()),
+            reinterpret_cast<const __nv_bfloat16*>(B_scale.data()),
+            reinterpret_cast<__nv_bfloat16*>(C.data()),
+            K, N, M, nullptr);
+        if (err != cudaSuccess) {
+            throw std::runtime_error("gemv_fp8_bf16_batched_sm120 failed: " + std::string(cudaGetErrorString(err)));
+        }
+    }, py::arg("A"), py::arg("B_nk"), py::arg("B_scale"), py::arg("C"),
+    "GEMV FP8->BF16 batched for SM120: C[M,N] = A[M,K] @ B_nk[N,K]^T");
+
+    // Alias: gemv_fp8_bf16_opt_batched
     m.def("gemv_fp8_bf16_opt_batched", [](const GPUArray& A, const GPUArray& B_nk, const GPUArray& B_scale, GPUArray& C) {
         if (A.dtype() != DataType::BFloat16 || C.dtype() != DataType::BFloat16) {
             throw std::runtime_error("gemv_fp8_bf16_opt_batched: A and C must be bfloat16");
@@ -71,29 +138,24 @@ void init_gemv_fp8xfp8_bf16(py::module_& m) {
         if (A.ndim() != 2 || B_nk.ndim() != 2 || C.ndim() != 2) {
             throw std::runtime_error("gemv_fp8_bf16_opt_batched: A[M,K], B_nk[N,K], C[M,N] dimensions required");
         }
-
         int M = A.shape()[0];
         int K = A.shape()[1];
         int N = B_nk.shape()[0];
-
         if (B_nk.shape()[1] != static_cast<size_t>(K)) {
             throw std::runtime_error("gemv_fp8_bf16_opt_batched: K dimension mismatch");
         }
         if (C.shape()[0] != static_cast<size_t>(M) || C.shape()[1] != static_cast<size_t>(N)) {
             throw std::runtime_error("gemv_fp8_bf16_opt_batched: output shape mismatch");
         }
-
         cudaError_t err = pygpukit::ops::gemv::launch_gemv_fp8_opt_batched(
             reinterpret_cast<const __nv_bfloat16*>(A.data()),
             reinterpret_cast<const uint8_t*>(B_nk.data()),
             reinterpret_cast<const __nv_bfloat16*>(B_scale.data()),
             reinterpret_cast<__nv_bfloat16*>(C.data()),
-            K, N, M, nullptr
-        );
-
+            K, N, M, nullptr);
         if (err != cudaSuccess) {
             throw std::runtime_error("gemv_fp8_bf16_opt_batched failed: " + std::string(cudaGetErrorString(err)));
         }
     }, py::arg("A"), py::arg("B_nk"), py::arg("B_scale"), py::arg("C"),
-    "Optimized batched FP8 GEMV: C[M,N] = A[M,K] @ B_nk[N,K]^T (warp-reduce, smem, vec4)");
+    "[Alias for gemv_fp8_bf16_batched_sm120] Optimized batched FP8 GEMV");
 }
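The single-row and batched variants share the same `[N,K]` weight layout; only the activation/output rank differs:

```python
# 1-D vs batched FP8 GEMV shape contracts under the new names.
native.gemv_fp8_bf16_sm120(a_k, B_nk, B_scale, c_n)            # A[K]   -> C[N]
native.gemv_fp8_bf16_batched_sm120(A_mk, B_nk, B_scale, C_mn)  # A[M,K] -> C[M,N]
```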
diff --git a/native/bindings/gemv/generic.cpp b/native/bindings/gemv/generic.cpp
index fe25ae4..9bc46f2 100644
--- a/native/bindings/gemv/generic.cpp
+++ b/native/bindings/gemv/generic.cpp
@@ -13,6 +13,36 @@ extern "C" {
 }

 void init_gemv_generic(py::module_& m) {
+    // ============================================================
+    // BF16 GEMV: BF16 x BF16 -> BF16 (SM120)
+    // New name: gemv_bf16_bf16_sm120, alias: gemv_bf16_opt_sm120
+    // ============================================================
+    m.def("gemv_bf16_bf16_sm120", [](const GPUArray& A, const GPUArray& B_nk, GPUArray& C) {
+        if (A.dtype() != DataType::BFloat16 || B_nk.dtype() != DataType::BFloat16 || C.dtype() != DataType::BFloat16) {
+            throw std::runtime_error("gemv_bf16_bf16_sm120: all inputs must be bfloat16");
+        }
+        if (A.ndim() != 1 || B_nk.ndim() != 2 || C.ndim() != 1) {
+            throw std::runtime_error("gemv_bf16_bf16_sm120: A[K], B_nk[N,K], C[N] dimensions required");
+        }
+        int K = A.shape()[0];
+        int N = B_nk.shape()[0];
+        if (B_nk.shape()[1] != static_cast<size_t>(K)) {
+            throw std::runtime_error("gemv_bf16_bf16_sm120: K dimension mismatch");
+        }
+        if (C.shape()[0] != static_cast<size_t>(N)) {
+            throw std::runtime_error("gemv_bf16_bf16_sm120: N dimension mismatch");
+        }
+        cudaError_t err = pygpukit_gemv_bf16_opt_sm120(
+            reinterpret_cast<const __nv_bfloat16*>(A.data()),
+            reinterpret_cast<const __nv_bfloat16*>(B_nk.data()),
+            reinterpret_cast<__nv_bfloat16*>(C.data()),
+            K, N, nullptr);
+        if (err != cudaSuccess) {
+            throw std::runtime_error("gemv_bf16_bf16_sm120 failed: " + std::string(cudaGetErrorString(err)));
+        }
+    }, py::arg("A"), py::arg("B_nk"), py::arg("C"),
+    "GEMV BF16->BF16 for SM120: C[N] = A[K] @ B_nk[N,K]^T (warp-reduce optimized)");
+
+    // Alias: gemv_bf16_opt_sm120
     m.def("gemv_bf16_opt_sm120", [](const GPUArray& A, const GPUArray& B_nk, GPUArray& C) {
         if (A.dtype() != DataType::BFloat16 || B_nk.dtype() != DataType::BFloat16 || C.dtype() != DataType::BFloat16) {
             throw std::runtime_error("gemv_bf16_opt_sm120: all inputs must be bfloat16");
@@ -20,31 +50,30 @@ void init_gemv_generic(py::module_& m) {
         if (A.ndim() != 1 || B_nk.ndim() != 2 || C.ndim() != 1) {
             throw std::runtime_error("gemv_bf16_opt_sm120: A[K], B_nk[N,K], C[N] dimensions required");
         }
-
         int K = A.shape()[0];
         int N = B_nk.shape()[0];
-
         if (B_nk.shape()[1] != static_cast<size_t>(K)) {
             throw std::runtime_error("gemv_bf16_opt_sm120: K dimension mismatch");
         }
         if (C.shape()[0] != static_cast<size_t>(N)) {
             throw std::runtime_error("gemv_bf16_opt_sm120: N dimension mismatch");
         }
-
         cudaError_t err = pygpukit_gemv_bf16_opt_sm120(
             reinterpret_cast<const __nv_bfloat16*>(A.data()),
             reinterpret_cast<const __nv_bfloat16*>(B_nk.data()),
             reinterpret_cast<__nv_bfloat16*>(C.data()),
-            K, N, nullptr
-        );
-
+            K, N, nullptr);
         if (err != cudaSuccess) {
             throw std::runtime_error("gemv_bf16_opt_sm120 failed: " + std::string(cudaGetErrorString(err)));
         }
     }, py::arg("A"), py::arg("B_nk"), py::arg("C"),
-    "Optimized BF16 GEMV: C[N] = A[K] @ B_nk[N,K]^T (warp-reduce, B[N,K] layout)");
+    "[Alias for gemv_bf16_bf16_sm120] Optimized BF16 GEMV");

+    // New name: gemv_bf16_bf16_available, alias: gemv_bf16_opt_available
+    m.def("gemv_bf16_bf16_available", []() {
+        return pygpukit_gemv_bf16_opt_sm120_available();
+    }, "Check if BF16 GEMV is available (SM80+)");
     m.def("gemv_bf16_opt_available", []() {
         return pygpukit_gemv_bf16_opt_sm120_available();
-    }, "Check if optimized BF16 GEMV is available (SM80+)");
+    }, "[Alias for gemv_bf16_bf16_available] Check if optimized BF16 GEMV is available");
 }
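As with the other backends, the probe and the kernel come as a pair:

```python
# BF16 GEMV under the new name, gated on its availability probe (SM80+).
if native.gemv_bf16_bf16_available():
    native.gemv_bf16_bf16_sm120(a_k, B_nk, c_n)  # C[N] = A[K] @ B_nk[N,K]^T
```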
diff --git a/native/bindings/gemv/nvf4xbf16_bf16.cpp b/native/bindings/gemv/nvf4xbf16_bf16.cpp
index 1957a69..9fe8638 100644
--- a/native/bindings/gemv/nvf4xbf16_bf16.cpp
+++ b/native/bindings/gemv/nvf4xbf16_bf16.cpp
@@ -22,16 +22,30 @@ extern "C" {
 }

 void init_gemv_nvf4xbf16_bf16(py::module_& m) {
+    // ============================================================
+    // NVF4 GEMV: NVF4 weights x BF16 activations -> BF16 output (SM120)
+    // New name: gemv_nvf4_bf16_sm120_available, alias: gemv_nvf4_available
+    // ============================================================
+    m.def("gemv_nvf4_bf16_sm120_available", []() {
+        return pygpukit_gemv_nvf4_available();
+    }, "Check if NVF4 GEMV is available on SM120 (Blackwell GeForce)");
     m.def("gemv_nvf4_available", []() {
         return pygpukit_gemv_nvf4_available();
-    }, "Check if NVF4 GEMV is available (SM120+)");
+    }, "[Alias for gemv_nvf4_bf16_sm120_available] Check if NVF4 GEMV is available");

+    // New name: gemv_nvf4_get_sizes, alias: nvf4_get_sizes
+    m.def("gemv_nvf4_get_sizes", [](int K, int N) {
+        size_t data_size, scale_size;
+        pygpukit_nvf4_get_sizes(K, N, &data_size, &scale_size);
+        return py::make_tuple(data_size, scale_size);
+    }, py::arg("K"), py::arg("N"),
+    "Get buffer sizes for NVF4 GEMV quantization: returns (data_size, scale_size)");
     m.def("nvf4_get_sizes", [](int K, int N) {
         size_t data_size, scale_size;
         pygpukit_nvf4_get_sizes(K, N, &data_size, &scale_size);
         return py::make_tuple(data_size, scale_size);
     }, py::arg("K"), py::arg("N"),
-    "Get buffer sizes for NVF4 quantization: returns (data_size, scale_size)");
+    "[Alias for gemv_nvf4_get_sizes] Get buffer sizes for NVF4 quantization");

     m.def("quantize_bf16_to_nvf4", [](const GPUArray& input, GPUArray& out_data, GPUArray& out_scale) {
         if (input.dtype() != DataType::BFloat16) {
@@ -77,6 +91,25 @@ void init_gemv_nvf4xbf16_bf16(py::module_& m) {
     }, py::arg("input"), py::arg("out_data"), py::arg("out_scale"),
     "Quantize BF16 weights to NVF4 format (row-major output [N,K/2]) for pure NVF4/NVF4 GEMV");

+    // New name: gemv_nvf4_bf16_sm120, alias: gemv_nvf4_bf16
+    m.def("gemv_nvf4_bf16_sm120", [](const GPUArray& A, const GPUArray& B_data, const GPUArray& B_scale, GPUArray& C, float alpha) {
+        if (A.dtype() != DataType::BFloat16 || C.dtype() != DataType::BFloat16) {
+            throw std::runtime_error("gemv_nvf4_bf16_sm120: A and C must be bfloat16");
+        }
+        if (A.ndim() != 1) {
+            throw std::runtime_error("gemv_nvf4_bf16_sm120: A must be 1D [K]");
+        }
+        int K = A.shape()[0];
+        int N = C.shape()[0];
+        cudaError_t err = pygpukit_gemv_nvf4_bf16(
+            A.data(), B_data.data(), B_scale.data(), C.data(),
+            K, N, alpha, nullptr);
+        if (err != cudaSuccess) {
+            throw std::runtime_error("gemv_nvf4_bf16_sm120 failed: " + std::string(cudaGetErrorString(err)));
+        }
+    }, py::arg("A"), py::arg("B_data"), py::arg("B_scale"), py::arg("C"), py::arg("alpha") = 1.0f,
+    "GEMV NVF4->BF16 for SM120: C[N] = alpha * A[K] @ B[K,N] (NVF4 quantized weights)");
+
+    // Alias: gemv_nvf4_bf16
     m.def("gemv_nvf4_bf16", [](const GPUArray& A, const GPUArray& B_data, const GPUArray& B_scale, GPUArray& C, float alpha) {
         if (A.dtype() != DataType::BFloat16 || C.dtype() != DataType::BFloat16) {
             throw std::runtime_error("gemv_nvf4_bf16: A and C must be bfloat16");
@@ -84,18 +117,14 @@ void init_gemv_nvf4xbf16_bf16(py::module_& m) {
         if (A.ndim() != 1) {
             throw std::runtime_error("gemv_nvf4_bf16: A must be 1D [K]");
         }
-
         int K = A.shape()[0];
         int N = C.shape()[0];
-
         cudaError_t err = pygpukit_gemv_nvf4_bf16(
             A.data(), B_data.data(), B_scale.data(), C.data(),
-            K, N, alpha, nullptr
-        );
-
+            K, N, alpha, nullptr);
         if (err != cudaSuccess) {
             throw std::runtime_error("gemv_nvf4_bf16 failed: " + std::string(cudaGetErrorString(err)));
         }
     }, py::arg("A"), py::arg("B_data"), py::arg("B_scale"), py::arg("C"), py::arg("alpha") = 1.0f,
-    "NVF4 GEMV for SM120: C[N] = alpha * A[K] @ B[K,N] (NVF4 quantized weights)");
+    "[Alias for gemv_nvf4_bf16_sm120] NVF4 GEMV for SM120");
 }
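Putting the pieces of this file together, the end-to-end NVF4 GEMV flow under the new names is: query sizes, allocate, quantize once, then run the GEMV. Buffer allocation below is a placeholder; `W_bf16` stands for the original `[N,K]` BF16 weight matrix.

```python
# End-to-end NVF4 GEMV flow (placeholder alloc_u8 allocator).
data_size, scale_size = native.gemv_nvf4_get_sizes(K, N)
B_data = alloc_u8((data_size,))
B_scale = alloc_u8((scale_size,))
native.quantize_bf16_to_nvf4(W_bf16, B_data, B_scale)  # one-time weight prep
native.gemv_nvf4_bf16_sm120(a_k, B_data, B_scale, c_n, alpha=1.0)
```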
diff --git a/native/bindings/gemv/nvf4xbf16_bf16.cpp b/native/bindings/gemv/nvf4xbf16_bf16.cpp
index 1957a69..9fe8638 100644
--- a/native/bindings/gemv/nvf4xbf16_bf16.cpp
+++ b/native/bindings/gemv/nvf4xbf16_bf16.cpp
@@ -22,16 +22,30 @@ extern "C" {
 }
 
 void init_gemv_nvf4xbf16_bf16(py::module_& m) {
+    // ============================================================
+    // NVF4 GEMV: NVF4 weights x BF16 activations -> BF16 output (SM120)
+    // New name: gemv_nvf4_bf16_sm120_available, alias: gemv_nvf4_available
+    // ============================================================
+    m.def("gemv_nvf4_bf16_sm120_available", []() {
+        return pygpukit_gemv_nvf4_available();
+    }, "Check if NVF4 GEMV is available on SM120 (Blackwell GeForce)");
     m.def("gemv_nvf4_available", []() {
         return pygpukit_gemv_nvf4_available();
-    }, "Check if NVF4 GEMV is available (SM120+)");
+    }, "[Alias for gemv_nvf4_bf16_sm120_available] Check if NVF4 GEMV is available");
+
+    // New name: gemv_nvf4_get_sizes, alias: nvf4_get_sizes
+    m.def("gemv_nvf4_get_sizes", [](int K, int N) {
+        size_t data_size, scale_size;
+        pygpukit_nvf4_get_sizes(K, N, &data_size, &scale_size);
+        return py::make_tuple(data_size, scale_size);
+    }, py::arg("K"), py::arg("N"),
+        "Get buffer sizes for NVF4 GEMV quantization: returns (data_size, scale_size)");
     m.def("nvf4_get_sizes", [](int K, int N) {
         size_t data_size, scale_size;
         pygpukit_nvf4_get_sizes(K, N, &data_size, &scale_size);
         return py::make_tuple(data_size, scale_size);
     }, py::arg("K"), py::arg("N"),
-        "Get buffer sizes for NVF4 quantization: returns (data_size, scale_size)");
+        "[Alias for gemv_nvf4_get_sizes] Get buffer sizes for NVF4 quantization");
 
     m.def("quantize_bf16_to_nvf4", [](const GPUArray& input, GPUArray& out_data, GPUArray& out_scale) {
         if (input.dtype() != DataType::BFloat16) {
@@ -77,6 +91,25 @@ void init_gemv_nvf4xbf16_bf16(py::module_& m) {
     }, py::arg("input"), py::arg("out_data"), py::arg("out_scale"),
         "Quantize BF16 weights to NVF4 format (row-major output [N,K/2]) for pure NVF4/NVF4 GEMV");
 
+    // New name: gemv_nvf4_bf16_sm120, alias: gemv_nvf4_bf16
+    m.def("gemv_nvf4_bf16_sm120", [](const GPUArray& A, const GPUArray& B_data, const GPUArray& B_scale, GPUArray& C, float alpha) {
+        if (A.dtype() != DataType::BFloat16 || C.dtype() != DataType::BFloat16) {
+            throw std::runtime_error("gemv_nvf4_bf16_sm120: A and C must be bfloat16");
+        }
+        if (A.ndim() != 1) {
+            throw std::runtime_error("gemv_nvf4_bf16_sm120: A must be 1D [K]");
+        }
+        int K = A.shape()[0];
+        int N = C.shape()[0];
+        cudaError_t err = pygpukit_gemv_nvf4_bf16(
+            A.data(), B_data.data(), B_scale.data(), C.data(),
+            K, N, alpha, nullptr);
+        if (err != cudaSuccess) {
+            throw std::runtime_error("gemv_nvf4_bf16_sm120 failed: " + std::string(cudaGetErrorString(err)));
+        }
+    }, py::arg("A"), py::arg("B_data"), py::arg("B_scale"), py::arg("C"), py::arg("alpha") = 1.0f,
+        "GEMV NVF4->BF16 for SM120: C[N] = alpha * A[K] @ B[K,N] (NVF4 quantized weights)");
+    // Alias: gemv_nvf4_bf16
     m.def("gemv_nvf4_bf16", [](const GPUArray& A, const GPUArray& B_data, const GPUArray& B_scale, GPUArray& C, float alpha) {
         if (A.dtype() != DataType::BFloat16 || C.dtype() != DataType::BFloat16) {
             throw std::runtime_error("gemv_nvf4_bf16: A and C must be bfloat16");
@@ -84,18 +117,14 @@ void init_gemv_nvf4xbf16_bf16(py::module_& m) {
         if (A.ndim() != 1) {
             throw std::runtime_error("gemv_nvf4_bf16: A must be 1D [K]");
         }
-
         int K = A.shape()[0];
         int N = C.shape()[0];
-
         cudaError_t err = pygpukit_gemv_nvf4_bf16(
             A.data(), B_data.data(), B_scale.data(), C.data(),
-            K, N, alpha, nullptr
-        );
-
+            K, N, alpha, nullptr);
         if (err != cudaSuccess) {
             throw std::runtime_error("gemv_nvf4_bf16 failed: " + std::string(cudaGetErrorString(err)));
         }
     }, py::arg("A"), py::arg("B_data"), py::arg("B_scale"), py::arg("C"), py::arg("alpha") = 1.0f,
-        "NVF4 GEMV for SM120: C[N] = alpha * A[K] @ B[K,N] (NVF4 quantized weights)");
+        "[Alias for gemv_nvf4_bf16_sm120] NVF4 GEMV for SM120");
 }
diff --git a/src/pygpukit/ops/basic.py b/src/pygpukit/ops/basic.py
index e616e63..e60c9ae 100644
--- a/src/pygpukit/ops/basic.py
+++ b/src/pygpukit/ops/basic.py
@@ -46,7 +46,7 @@
     kv_cache_update_gqa_ptr,
 )
 
-# Re-export matmul operations
+# Re-export matmul operations (both old and new standardized names)
 from pygpukit.ops.matmul import (
     batched_matmul,
     fp8_available,
@@ -57,14 +57,37 @@
     fp8_sm90_available,
     fp8_sm100_available,
     fp8_sm120_available,
+    # New standardized GEMM names
+    gemm_fp8_available,
+    gemm_fp8_f32_sm90,
+    gemm_fp8_f32_sm90_available,
+    gemm_fp8_f32_sm100,
+    gemm_fp8_f32_sm100_available,
+    gemm_fp8_f32_sm120,
+    gemm_fp8_f32_sm120_available,
+    gemm_fp8_fp8_blockwise_sm120,
+    gemm_fp8_fp8_get_scale_sizes,
+    gemm_fp8_fp8_sm120,
+    gemm_fp8_fp8_sm120_available,
+    gemm_nvf4_bf16_sm120,
+    gemm_nvf4_bf16_sm120_available,
+    gemm_w8a16_bf16_sm120,
+    gemm_w8a16_init_lut,
     # GEMV operations
     gemv_bf16,
+    gemv_bf16_bf16_sm120,  # New standardized name
     gemv_fp8_bf16,
     gemv_fp8_bf16_batched,
+    gemv_fp8_bf16_batched_sm120,  # New standardized name
+    gemv_fp8_bf16_sm120,  # New standardized name
    gemv_nvf4_available,
     gemv_nvf4_bf16,
+    gemv_nvf4_bf16_sm120,  # New standardized name
+    gemv_nvf4_bf16_sm120_available,  # New standardized name
+    gemv_nvf4_get_sizes,  # New standardized name
     # Grouped GEMM for MoE
     grouped_gemm_fp8_bf16,
+    grouped_gemm_fp8_bf16_sm120,  # New standardized name
     grouped_gemm_init_lut,
     linear_bias_gelu,
     matmul,
@@ -201,17 +224,45 @@
     "matmul_fp8_fp8_blockwise_sm120",
     "matmul_fp8_fp8_sm120",
     "nvf4_bf16_sm120_available",
-    # GEMV
+    # GEMV (old names)
     "gemv_bf16",
     "gemv_fp8_bf16",
     "gemv_fp8_bf16_batched",
     "gemv_nvf4_bf16",
     "gemv_nvf4_available",
-    # W8A16 GEMM
+    # GEMV (new standardized names)
+    "gemv_bf16_bf16_sm120",
+    "gemv_fp8_bf16_sm120",
+    "gemv_fp8_bf16_batched_sm120",
+    "gemv_nvf4_bf16_sm120",
+    "gemv_nvf4_bf16_sm120_available",
+    "gemv_nvf4_get_sizes",
+    # W8A16 GEMM (old name)
     "w8a16_gemm_sm120",
-    # Grouped GEMM for MoE
+    # W8A16 GEMM (new standardized names)
+    "gemm_w8a16_bf16_sm120",
+    "gemm_w8a16_init_lut",
+    # Grouped GEMM for MoE (old names)
     "grouped_gemm_fp8_bf16",
     "grouped_gemm_init_lut",
+    # Grouped GEMM (new standardized name)
+    "grouped_gemm_fp8_bf16_sm120",
+    # New standardized GEMM availability functions
+    "gemm_fp8_available",
+    "gemm_fp8_f32_sm90_available",
+    "gemm_fp8_f32_sm100_available",
+    "gemm_fp8_f32_sm120_available",
+    "gemm_fp8_fp8_sm120_available",
+    "gemm_fp8_fp8_get_scale_sizes",
+    "gemm_nvf4_bf16_sm120_available",
+    # New standardized GEMM functions
+    "gemm_fp8_f32_sm90",
+    "gemm_fp8_f32_sm100",
+    "gemm_fp8_f32_sm120",
+    "gemm_fp8_fp8_sm120",
+    "gemm_fp8_fp8_blockwise_sm120",
+    "gemm_nvf4_bf16_sm120",
+    # Utility functions
     "fp8_init_lut",
     "fp8_get_sizes",
     "nvf4_get_sizes",
diff --git a/src/pygpukit/ops/matmul.py b/src/pygpukit/ops/matmul.py
index c892e85..791667b 100644
--- a/src/pygpukit/ops/matmul.py
+++ b/src/pygpukit/ops/matmul.py
@@ -518,11 +518,20 @@ def fp8_available() -> bool:
         from pygpukit.core.backend import get_native_module
 
         native = get_native_module()
-        return native.fp8_available()
+        # Check all FP8 backends - return True if any is available
+        return (
+            native.gemm_fp8_f32_sm90_available()
+            or native.gemm_fp8_f32_sm100_available()
+            or native.gemm_fp8_f32_sm120_available()
+        )
     else:
         return False
 
 
+# Alias for standardized naming
+gemm_fp8_available = fp8_available
+
+
 def fp8_sm90_available() -> bool:
     """Check if FP8 GEMM is available on SM90 (Hopper).
@@ -535,11 +544,16 @@ def fp8_sm90_available() -> bool:
         from pygpukit.core.backend import get_native_module
 
         native = get_native_module()
-        return native.fp8_sm90_available()
+        # Use new standardized name
+        return native.gemm_fp8_f32_sm90_available()
     else:
         return False
 
 
+# Alias for standardized naming
+gemm_fp8_f32_sm90_available = fp8_sm90_available
+
+
 def fp8_sm100_available() -> bool:
     """Check if FP8 GEMM is available on SM100 (Blackwell datacenter).
@@ -555,11 +569,16 @@ def fp8_sm100_available() -> bool:
         from pygpukit.core.backend import get_native_module
 
         native = get_native_module()
-        return native.fp8_sm100_available()
+        # Use new standardized name
+        return native.gemm_fp8_f32_sm100_available()
     else:
         return False
 
 
+# Alias for standardized naming
+gemm_fp8_f32_sm100_available = fp8_sm100_available
+
+
 def fp8_sm120_available() -> bool:
     """Check if FP8 GEMM is available on SM120 (Blackwell GeForce).
@@ -574,11 +593,16 @@ def fp8_sm120_available() -> bool:
         from pygpukit.core.backend import get_native_module
 
         native = get_native_module()
-        return native.fp8_sm120_available()
+        # Use new standardized name
+        return native.gemm_fp8_f32_sm120_available()
     else:
         return False
 
 
+# Alias for standardized naming
+gemm_fp8_f32_sm120_available = fp8_sm120_available
+
+
 def fp8_fp8_sm120_available() -> bool:
     """Check if Pure FP8 I/O GEMM is available on SM120 (Blackwell GeForce).
@@ -593,11 +617,16 @@ def fp8_fp8_sm120_available() -> bool:
         from pygpukit.core.backend import get_native_module
 
         native = get_native_module()
-        return native.fp8_fp8_sm120_available()
+        # Use new standardized name
+        return native.gemm_fp8_fp8_sm120_available()
     else:
         return False
 
 
+# Alias for standardized naming
+gemm_fp8_fp8_sm120_available = fp8_fp8_sm120_available
+
+
 def matmul_fp8_fp8_sm120(
     a: GPUArray,
     b: GPUArray,
@@ -688,12 +717,16 @@ def _matmul_fp8_fp8_sm120_native(
     else:
         out_native = out._get_native()
 
-    # Call Pure FP8 GEMM
+    # Call Pure FP8 GEMM (use new standardized name)
     native.gemm_fp8_fp8_sm120(a_native, b_native, out_native)
 
     return out
 
 
+# Alias for standardized naming
+gemm_fp8_fp8_sm120 = matmul_fp8_fp8_sm120
+
+
 def fp8_fp8_get_scale_sizes(M: int, N: int, K: int) -> tuple[int, int]:
     """Get scale factor sizes for FP8 blockwise GEMM.
@@ -720,11 +753,16 @@ def fp8_fp8_get_scale_sizes(M: int, N: int, K: int) -> tuple[int, int]:
         from pygpukit.core.backend import get_native_module
 
         native = get_native_module()
-        return native.fp8_fp8_get_scale_sizes(M, N, K)
+        # Use new standardized name
+        return native.gemm_fp8_fp8_get_scale_sizes(M, N, K)
     else:
         return (0, 0)
 
 
+# Alias for standardized naming
+gemm_fp8_fp8_get_scale_sizes = fp8_fp8_get_scale_sizes
+
+
 def matmul_fp8_fp8_blockwise_sm120(
     a: GPUArray,
     b: GPUArray,
@@ -832,6 +870,10 @@ def _matmul_fp8_fp8_blockwise_sm120_native(
     return out
 
 
+# Alias for standardized naming
+gemm_fp8_fp8_blockwise_sm120 = matmul_fp8_fp8_blockwise_sm120
+
+
 def matmul_fp8_sm100(
     a: GPUArray,
     b: GPUArray,
@@ -919,12 +961,16 @@ def _matmul_fp8_sm100_native(
     else:
         out_native = out._get_native()
 
-    # Call FP8 GEMM
-    native.gemm_fp8_sm100(a_native, b_native, out_native)
+    # Call FP8 GEMM (use new standardized name)
+    native.gemm_fp8_f32_sm100(a_native, b_native, out_native)
 
     return out
 
 
+# Alias for standardized naming
+gemm_fp8_f32_sm100 = matmul_fp8_sm100
+
+
 def matmul_fp8_sm120(
     a: GPUArray,
     b: GPUArray,
@@ -1009,12 +1055,16 @@ def _matmul_fp8_sm120_native(
     else:
         out_native = out._get_native()
 
-    # Call FP8 GEMM
-    native.gemm_fp8_sm120(a_native, b_native, out_native)
+    # Call FP8 GEMM (use new standardized name)
+    native.gemm_fp8_f32_sm120(a_native, b_native, out_native)
 
     return out
 
 
+# Alias for standardized naming
+gemm_fp8_f32_sm120 = matmul_fp8_sm120
+
+
 def matmul_fp8_sm90(
     a: GPUArray,
     b: GPUArray,
@@ -1099,12 +1149,16 @@ def _matmul_fp8_sm90_native(
     else:
         out_native = out._get_native()
 
-    # Call FP8 GEMM
-    native.gemm_fp8_sm90(a_native, b_native, out_native)
+    # Call FP8 GEMM (use new standardized name)
+    native.gemm_fp8_f32_sm90(a_native, b_native, out_native)
 
     return out
 
 
+# Alias for standardized naming
+gemm_fp8_f32_sm90 = matmul_fp8_sm90
+
+
 def nvf4_bf16_sm120_available() -> bool:
     """Check if NVF4 (4-bit) BF16 GEMM is available on SM120 (Blackwell GeForce).
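
With this change the generic fp8_available() becomes an OR over the per-arch checks, so callers can gate once and then dispatch. A sketch of that pattern, using only names introduced above; the prints stand in for real kernel selection:

    # Sketch only: gate on the aggregated check, then pick the arch path.
    from pygpukit.ops import matmul as mm

    if mm.gemm_fp8_available():  # True if any of SM90/SM100/SM120 is present
        if mm.gemm_fp8_f32_sm120_available():
            print("FP8 GEMM: SM120 (Blackwell GeForce) path")
        elif mm.gemm_fp8_f32_sm100_available():
            print("FP8 GEMM: SM100 (Blackwell datacenter) path")
        else:
            print("FP8 GEMM: SM90 (Hopper) path")
    else:
        print("No FP8 backend; fall back to a BF16 matmul")
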
@@ -1120,11 +1174,16 @@ def nvf4_bf16_sm120_available() -> bool:
         from pygpukit.core.backend import get_native_module
 
         native = get_native_module()
-        return native.nvf4_bf16_sm120_available()
+        # Use new standardized name
+        return native.gemm_nvf4_bf16_sm120_available()
     else:
         return False
 
 
+# Alias for standardized naming
+gemm_nvf4_bf16_sm120_available = nvf4_bf16_sm120_available
+
+
 def matmul_nvf4_bf16_sm120(
     a: GPUArray,
     b: GPUArray,
@@ -1205,6 +1264,10 @@ def _matmul_nvf4_bf16_sm120_native(
     return out
 
 
+# Alias for standardized naming
+gemm_nvf4_bf16_sm120 = matmul_nvf4_bf16_sm120
+
+
 # ============================================================================
 # GEMV Operations (M=1 special case)
 # ============================================================================
@@ -1222,11 +1285,16 @@ def gemv_nvf4_available() -> bool:
         from pygpukit.core.backend import get_native_module
 
         native = get_native_module()
-        return native.gemv_nvf4_available()
+        # Use new standardized name
+        return native.gemv_nvf4_bf16_sm120_available()
     else:
         return False
 
 
+# Alias for standardized naming
+gemv_nvf4_bf16_sm120_available = gemv_nvf4_available
+
+
 def nvf4_get_sizes(K: int, N: int) -> tuple[int, int]:
     """Get buffer sizes for NVF4-quantized weights.
@@ -1249,6 +1317,10 @@ def nvf4_get_sizes(K: int, N: int) -> tuple[int, int]:
     return data_size, scale_size
 
 
+# Alias for standardized naming
+gemv_nvf4_get_sizes = nvf4_get_sizes
+
+
 def quantize_bf16_to_nvf4(
     input: GPUArray,
     out_data: GPUArray,
@@ -1383,13 +1455,18 @@ def gemv_nvf4_bf16(
     else:
         out_native = out._get_native()
 
-        native.gemv_nvf4_bf16(a_native, data_native, scale_native, out_native, alpha)
+        # Use new standardized name
+        native.gemv_nvf4_bf16_sm120(a_native, data_native, scale_native, out_native, alpha)
 
         return out
     else:
         raise RuntimeError("NVF4 GEMV requires native backend")
 
 
+# Alias for standardized naming
+gemv_nvf4_bf16_sm120 = gemv_nvf4_bf16
+
+
 def gemv_bf16(
     a: GPUArray,
     b: GPUArray,
@@ -1456,8 +1533,8 @@ def gemv_bf16(
     else:
         out_native = out._get_native()
 
-        # Use optimized kernel with B[N,K] layout
-        native.gemv_bf16_opt_sm120(a_native, b_native, out_native)
+        # Use optimized kernel with B[N,K] layout (new standardized name)
+        native.gemv_bf16_bf16_sm120(a_native, b_native, out_native)
 
         return out
     else:
@@ -1470,6 +1547,10 @@ def gemv_bf16(
     return from_numpy(result.astype(np.float16).view(np.uint16).astype(np.uint16))
 
 
+# Alias for standardized naming
+gemv_bf16_bf16_sm120 = gemv_bf16
+
+
 # Flag to track if FP8 LUT has been initialized
 _FP8_LUT_INITIALIZED = False
@@ -1507,10 +1588,15 @@ def w8a16_gemm_init_lut() -> None:
         from pygpukit.core.backend import get_native_module
 
         native = get_native_module()
-        native.w8a16_gemm_init_lut()
+        # Use new standardized name
+        native.gemm_w8a16_init_lut()
         _W8A16_GEMM_LUT_INITIALIZED = True
 
 
+# Alias for standardized naming
+gemm_w8a16_init_lut = w8a16_gemm_init_lut
+
+
 def gemv_fp8_bf16(
     a: GPUArray,
     b_nk: GPUArray,
@@ -1583,13 +1669,18 @@ def gemv_fp8_bf16(
     else:
         out_native = out._get_native()
 
-        native.gemv_fp8_bf16_opt(a_native, b_nk_native, b_scale_native, out_native)
+        # Use new standardized name
+        native.gemv_fp8_bf16_sm120(a_native, b_nk_native, b_scale_native, out_native)
 
         return out
     else:
         raise NotImplementedError("FP8 GEMV requires native GPU backend")
 
 
+# Alias for standardized naming
+gemv_fp8_bf16_sm120 = gemv_fp8_bf16
+
+
 def gemv_fp8_bf16_batched(
     a: GPUArray,
     b_nk: GPUArray,
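
A short sketch of the NVF4 buffer-size query through its new Python-level name. The sizes are backend-defined (which is why the query exists) and the call presumably needs the native backend; K and N below are arbitrary example dimensions:

    # Sketch only: query NVF4 buffer sizes via the new standardized name.
    from pygpukit.ops.matmul import gemv_nvf4_get_sizes  # alias of nvf4_get_sizes

    K, N = 4096, 11008
    data_size, scale_size = gemv_nvf4_get_sizes(K, N)
    # data holds the packed 4-bit weights ([N, K/2] row-major per the
    # binding docstring); scale holds the per-block scale factors.
    print(f"NVF4 buffers for [{N}, {K}]: data={data_size}, scale={scale_size}")
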
@@ -1665,13 +1756,18 @@ def gemv_fp8_bf16_batched(
     else:
         out_native = out._get_native()
 
-        native.gemv_fp8_bf16_opt_batched(a_native, b_nk_native, b_scale_native, out_native)
+        # Use new standardized name
+        native.gemv_fp8_bf16_batched_sm120(a_native, b_nk_native, b_scale_native, out_native)
 
         return out
     else:
         raise NotImplementedError("FP8 batched GEMV requires native GPU backend")
 
 
+# Alias for standardized naming
+gemv_fp8_bf16_batched_sm120 = gemv_fp8_bf16_batched
+
+
 def w8a16_gemm_sm120(
     a: GPUArray,
     b_fp8: GPUArray,
@@ -1747,13 +1843,18 @@ def w8a16_gemm_sm120(
     else:
         out_native = out._get_native()
 
-        native.w8a16_gemm_sm120(a_native, b_fp8_native, b_scale_native, out_native)
+        # Use new standardized name
+        native.gemm_w8a16_bf16_sm120(a_native, b_fp8_native, b_scale_native, out_native)
 
         return out
     else:
         raise NotImplementedError("W8A16 GEMM requires native GPU backend with SM120")
 
 
+# Alias for standardized naming
+gemm_w8a16_bf16_sm120 = w8a16_gemm_sm120
+
+
 # Track if grouped GEMM LUT is initialized
 _grouped_gemm_lut_initialized = False
 
@@ -1868,7 +1969,8 @@ def grouped_gemm_fp8_bf16(
     else:
         out_native = out._get_native()
 
-        native.grouped_gemm_fp8_bf16(
+        # Use new standardized name
+        native.grouped_gemm_fp8_bf16_sm120(
             a_native, b_stacked_native, b_scale_native, out_native, row_expert_ids_native
         )
 
@@ -1877,6 +1979,10 @@ def grouped_gemm_fp8_bf16(
         raise NotImplementedError("Grouped GEMM requires native GPU backend")
 
 
+# Alias for standardized naming
+grouped_gemm_fp8_bf16_sm120 = grouped_gemm_fp8_bf16
+
+
 def fp8_get_sizes(K: int, N: int) -> tuple[int, int, int]:
     """Get scale tensor dimensions for FP8 block quantization.
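
Closing the W8A16 path: both spellings of the LUT initializer are one function object sharing the module-level _W8A16_GEMM_LUT_INITIALIZED flag, so repeated calls are presumably no-ops after the first. A minimal sketch, assuming the native backend is built:

    # Sketch only: alias identity plus the guarded one-time LUT init.
    from pygpukit.ops.matmul import (
        gemm_w8a16_bf16_sm120,
        gemm_w8a16_init_lut,
        w8a16_gemm_init_lut,
        w8a16_gemm_sm120,
    )

    assert gemm_w8a16_bf16_sm120 is w8a16_gemm_sm120
    assert gemm_w8a16_init_lut is w8a16_gemm_init_lut
    gemm_w8a16_init_lut()   # builds the FP8->F32 LUT on first call
    w8a16_gemm_init_lut()   # same object; the shared flag is already set
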