9 changes: 9 additions & 0 deletions native/bindings/ops_bindings.cpp
@@ -31,4 +31,13 @@ void init_ops_bindings(py::module_& m) {
m.def("matmul_", py::overload_cast<const GPUArray&, const GPUArray&, GPUArray&>(&ops::matmul),
py::arg("a"), py::arg("b"), py::arg("out"),
"Matrix multiplication with output array");

// TF32 variants
m.def("matmul_tf32", py::overload_cast<const GPUArray&, const GPUArray&, bool>(&ops::matmul),
py::arg("a"), py::arg("b"), py::arg("use_tf32"),
"Matrix multiplication with explicit TF32 control");

m.def("matmul_tf32_", py::overload_cast<const GPUArray&, const GPUArray&, GPUArray&, bool>(&ops::matmul),
py::arg("a"), py::arg("b"), py::arg("out"), py::arg("use_tf32"),
"Matrix multiplication with explicit TF32 control and output array");
}
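Since `ops::matmul` is overloaded, `py::overload_cast<...>` is what pins each binding to a single C++ signature. For readers less familiar with pybind11, the same bindings written with plain function-pointer casts would look like the sketch below (not part of the PR; the return types follow the declarations in `basic.cuh`):

```cpp
// Equivalent bindings using explicit function-pointer casts (sketch only).
// matmul_tf32 resolves to the overload that returns a new GPUArray;
// matmul_tf32_ resolves to the void overload writing into a caller-provided output.
m.def("matmul_tf32",
      static_cast<GPUArray (*)(const GPUArray&, const GPUArray&, bool)>(&ops::matmul),
      py::arg("a"), py::arg("b"), py::arg("use_tf32"));
m.def("matmul_tf32_",
      static_cast<void (*)(const GPUArray&, const GPUArray&, GPUArray&, bool)>(&ops::matmul),
      py::arg("a"), py::arg("b"), py::arg("out"), py::arg("use_tf32"));
```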
144 changes: 144 additions & 0 deletions native/ops/basic.cu
@@ -939,5 +939,149 @@ GPUArray matmul(const GPUArray& a, const GPUArray& b) {
    return c;
}

// Internal helper: matmul with explicit TF32 control
static void matmul_impl(const GPUArray& a, const GPUArray& b, GPUArray& c, bool use_tf32_explicit) {
    validate_matmul_shapes(a, b, "matmul");
    validate_same_dtype(a, b, "matmul");

    size_t M = a.shape()[0];
    size_t K = a.shape()[1];
    size_t N = b.shape()[1];

    if (c.shape()[0] != M || c.shape()[1] != N) {
        throw std::runtime_error("matmul output shape mismatch");
    }

    // Check GPU compute capability for TF32 support
    int device;
    cudaGetDevice(&device);
    cudaDeviceProp prop;
    cudaGetDeviceProperties(&prop, device);
    int sm_version = prop.major * 10 + prop.minor;

    // TF32 only works with float32 and SM >= 80
    bool tf32_enabled = use_tf32_explicit &&
                        (a.dtype() == DataType::Float32) &&
                        (sm_version >= 80);

    if (use_tf32_explicit && !tf32_enabled) {
        if (a.dtype() != DataType::Float32) {
            throw std::runtime_error("TF32 matmul requires float32 dtype");
        }
        if (sm_version < 80) {
            throw std::runtime_error("TF32 matmul requires SM >= 80 (Ampere or newer)");
        }
    }

    // Use the TF32 kernels only when explicitly requested, and only when the
    // matrices are large enough or match the 16x8 / 16x16 single-tile case
    bool use_tf32 = tf32_enabled &&
                    ((M >= OPTIMIZED_MATMUL_THRESHOLD &&
                      N >= OPTIMIZED_MATMUL_THRESHOLD &&
                      K >= OPTIMIZED_MATMUL_THRESHOLD) ||
                     (M == 16 && (N == 8 || N == 16)));

    bool use_optimized = !use_tf32 &&
                         (a.dtype() == DataType::Float32) &&
                         (M >= OPTIMIZED_MATMUL_THRESHOLD ||
                          N >= OPTIMIZED_MATMUL_THRESHOLD ||
                          K >= OPTIMIZED_MATMUL_THRESHOLD);

    bool use_tiled = !use_optimized && !use_tf32 &&
                     (M >= TILED_MATMUL_THRESHOLD ||
                      N >= TILED_MATMUL_THRESHOLD ||
                      K >= TILED_MATMUL_THRESHOLD);

    if (use_tf32) {
        // TF32 TensorCore kernels
        if (M == 16 && (N == 8 || N == 16)) {
            tf32::launch_single_tile_verified(
                static_cast<const float*>(a.data()),
                static_cast<const float*>(b.data()),
                static_cast<float*>(c.data()),
                M, N, K);
        } else {
            tf32::launch_sgemm_tf32(
                static_cast<const float*>(a.data()),
                static_cast<const float*>(b.data()),
                static_cast<float*>(c.data()),
                M, N, K);
        }
    } else if (use_optimized) {
        ampere::launch_sgemm_ampere(
            static_cast<const float*>(a.data()),
            static_cast<const float*>(b.data()),
            static_cast<float*>(c.data()),
            M, N, K);
    } else if (use_tiled) {
        dim3 block_size(TILE_N / THREAD_N, TILE_M / THREAD_M);
        dim3 grid_size(
            (N + TILE_N - 1) / TILE_N,
            (M + TILE_M - 1) / TILE_M
        );

        switch (a.dtype()) {
            case DataType::Float32:
                matmul_f32_tiled_kernel<<<grid_size, block_size>>>(
                    static_cast<const float*>(a.data()),
                    static_cast<const float*>(b.data()),
                    static_cast<float*>(c.data()),
                    M, N, K);
                break;
            case DataType::Float64:
                matmul_f64_tiled_kernel<<<grid_size, block_size>>>(
                    static_cast<const double*>(a.data()),
                    static_cast<const double*>(b.data()),
                    static_cast<double*>(c.data()),
                    M, N, K);
                break;
            default:
                throw std::runtime_error("matmul only supports float32 and float64");
        }
    } else {
        dim3 block_size(BLOCK_SIZE, BLOCK_SIZE);
        dim3 grid_size(
            (N + BLOCK_SIZE - 1) / BLOCK_SIZE,
            (M + BLOCK_SIZE - 1) / BLOCK_SIZE
        );

        switch (a.dtype()) {
            case DataType::Float32:
                matmul_f32_l2opt_kernel<<<grid_size, block_size>>>(
                    static_cast<const float*>(a.data()),
                    static_cast<const float*>(b.data()),
                    static_cast<float*>(c.data()),
                    M, N, K);
                break;
            case DataType::Float64:
                matmul_f64_l2opt_kernel<<<grid_size, block_size>>>(
                    static_cast<const double*>(a.data()),
                    static_cast<const double*>(b.data()),
                    static_cast<double*>(c.data()),
                    M, N, K);
                break;
            default:
                throw std::runtime_error("matmul only supports float32 and float64");
        }
    }

    sync_and_check("matmul kernel failed");
}

void matmul(const GPUArray& a, const GPUArray& b, GPUArray& c, bool use_tf32) {
    matmul_impl(a, b, c, use_tf32);
}

GPUArray matmul(const GPUArray& a, const GPUArray& b, bool use_tf32) {
    validate_matmul_shapes(a, b, "matmul");
    validate_same_dtype(a, b, "matmul");

    size_t M = a.shape()[0];
    size_t N = b.shape()[1];

    GPUArray c({M, N}, a.dtype());
    matmul_impl(a, b, c, use_tf32);
    return c;
}

} // namespace ops
} // namespace pygpukit
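For context on what the TF32 path buys: TF32 keeps the fp32 exponent range but rounds the mantissa to 10 bits, which lets the multiply run on TensorCores (SM >= 80) while accumulation stays in full fp32. The `tf32::launch_single_tile_verified` and `tf32::launch_sgemm_tf32` kernels are not part of this diff; the block below is a minimal, self-contained WMMA sketch of one 16x16x8 TF32 tile, shown only to illustrate the technique the dispatch above selects:

```cuda
#include <mma.h>
using namespace nvcuda;

// One warp computes a single 16x16 fp32 output tile from a 16x8 A tile and an
// 8x16 B tile on the TF32 TensorCore path (requires sm_80 or newer).
// Launch with a single warp, e.g. tf32_tile_sketch<<<1, 32>>>(dA, dB, dC);
__global__ void tf32_tile_sketch(const float* A, const float* B, float* C) {
    wmma::fragment<wmma::matrix_a, 16, 16, 8, wmma::precision::tf32, wmma::row_major> a_frag;
    wmma::fragment<wmma::matrix_b, 16, 16, 8, wmma::precision::tf32, wmma::row_major> b_frag;
    wmma::fragment<wmma::accumulator, 16, 16, 8, float> c_frag;

    wmma::fill_fragment(c_frag, 0.0f);
    wmma::load_matrix_sync(a_frag, A, 8);   // A is 16x8, leading dimension K = 8
    wmma::load_matrix_sync(b_frag, B, 16);  // B is 8x16, leading dimension N = 16

    // Round the loaded fp32 values to TF32 before the tensor-core MMA.
    for (int i = 0; i < a_frag.num_elements; ++i)
        a_frag.x[i] = wmma::__float_to_tf32(a_frag.x[i]);
    for (int i = 0; i < b_frag.num_elements; ++i)
        b_frag.x[i] = wmma::__float_to_tf32(b_frag.x[i]);

    wmma::mma_sync(c_frag, a_frag, b_frag, c_frag);
    wmma::store_matrix_sync(C, c_frag, 16, wmma::mem_row_major);  // C is 16x16
}
```

Note how the dispatch in `matmul_impl` only takes the TF32 path when the caller explicitly passes `use_tf32=true` and the problem is either large in every dimension or exactly the 16x8 / 16x16 single-tile shape; everything else continues to use the Ampere, tiled, or baseline kernels.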
7 changes: 7 additions & 0 deletions native/ops/basic.cuh
@@ -16,10 +16,17 @@ void mul(const GPUArray& a, const GPUArray& b, GPUArray& c);
// a: (M, K), b: (K, N), c: (M, N)
void matmul(const GPUArray& a, const GPUArray& b, GPUArray& c);

// Matrix multiplication with explicit TF32 control
// use_tf32: request the TF32 TensorCore path (requires float32 and SM >= 80;
// small matrices still fall back to the standard kernels)
void matmul(const GPUArray& a, const GPUArray& b, GPUArray& c, bool use_tf32);

// Convenience functions that return new arrays
GPUArray add(const GPUArray& a, const GPUArray& b);
GPUArray mul(const GPUArray& a, const GPUArray& b);
GPUArray matmul(const GPUArray& a, const GPUArray& b);

// Matmul with explicit TF32 control
GPUArray matmul(const GPUArray& a, const GPUArray& b, bool use_tf32);

} // namespace ops
} // namespace pygpukit
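A host-side usage sketch of the new overloads (the include path and the example shapes are assumptions; the `GPUArray` constructor form is taken from `basic.cu`, and the float32 / SM >= 80 requirements from the checks above):

```cpp
#include "ops/basic.cuh"  // include path is an assumption

using namespace pygpukit;

void tf32_matmul_example() {
    // Two float32 device matrices (contents left uninitialized in this sketch).
    GPUArray a({1024, 1024}, DataType::Float32);
    GPUArray b({1024, 1024}, DataType::Float32);

    // Existing behavior: heuristic kernel selection, no explicit TF32 control.
    GPUArray c = ops::matmul(a, b);

    // Explicit TF32 request: throws std::runtime_error if the dtype is not
    // float32 or the device is older than SM 80 (Ampere).
    GPUArray d = ops::matmul(a, b, /*use_tf32=*/true);

    // Output-parameter variant writing into a preallocated (M, N) array.
    GPUArray out({1024, 1024}, DataType::Float32);
    ops::matmul(a, b, out, /*use_tf32=*/true);
}
```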