Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 61 additions & 38 deletions examples/python/qqmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,48 +33,71 @@ def test_qqmm():
[64, 128, 256, 1024, 1024 * 8], # N
[64, 128, 256, 1024, 1024 * 8], # K
)
layouts = ["TN", "NT", "TT", "NN"]
for group_size, mode, bits in tests:
for M, N, K in product(*shapes):
for dtype in dtypes:
x = mx.random.normal(shape=(M, K), key=k1, dtype=dtype)
w = mx.random.normal(shape=(N, K), key=k2, dtype=dtype)
w_q, scales_w = mx.quantize(w, group_size, bits, mode=mode)
w_dq = mx.dequantize(
w_q,
scales_w,
group_size=group_size,
bits=bits,
mode=mode,
dtype=dtype,
)
y_q = mx.qqmm(
x,
w_q,
scales_w,
group_size=group_size,
bits=bits,
mode=mode,
)
x_q, scales_x = mx.quantize(
x, group_size=group_size, bits=bits, mode=mode
)
x_dq = mx.dequantize(
x_q,
scales_x,
group_size=group_size,
bits=bits,
mode=mode,
dtype=dtype,
)
y_hat = mx.matmul(x_dq, mx.transpose(w_dq))
ulp = ulp_bf16_at(y_hat)
error = (y_q - y_hat).abs()
if not (mx.logical_or(error < 1e-3, error <= ulp).all()):
raise AssertionError(
f"qqmm test failed for shape {(M, N, K)}, "
f"group_size={group_size}, bits={bits}, "
f"mode={mode}, dtype={dtype}"
for layout in layouts:
if layout == "NT":
x_shape = (M, K)
w_shape = (N, K)
elif layout == "TN":
x_shape = (K, M)
w_shape = (K, N)
elif layout == "TT":
x_shape = (K, M)
w_shape = (N, K)
else: # "NN"
x_shape = (M, K)
w_shape = (K, N)

x = mx.random.normal(shape=x_shape, key=k1, dtype=dtype)
w = mx.random.normal(shape=w_shape, key=k2, dtype=dtype)

if layout == "TT":
x = mx.transpose(x)
elif layout == "TN":
w = mx.transpose(w)
x = mx.transpose(x)
elif layout == "NN":
w = mx.transpose(w)

y_q = mx.qqmm(
x,
w,
group_size=group_size,
bits=bits,
mode=mode,
)
w_q, scales_w = mx.quantize(w, group_size, bits, mode=mode)
w_dq = mx.dequantize(
w_q,
scales_w,
group_size=group_size,
bits=bits,
mode=mode,
dtype=dtype,
)
x_q, scales_x = mx.quantize(
x, group_size=group_size, bits=bits, mode=mode
)
x_dq = mx.dequantize(
x_q,
scales_x,
group_size=group_size,
bits=bits,
mode=mode,
dtype=dtype,
)
y_hat = mx.matmul(x_dq, mx.transpose(w_dq))
ulp = ulp_bf16_at(y_hat)
error = (y_q - y_hat).abs()
if not (mx.logical_or(error < 1e-3, error <= ulp).all()):
raise AssertionError(
f"qqmm test failed for shape {(M, N, K)}, "
f"group_size={group_size}, bits={bits}, "
f"mode={mode}, dtype={dtype}, layout={layout}"
)


def test_qqmm_vjp():
Expand Down
127 changes: 127 additions & 0 deletions mlx/backend/cuda/ptx.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
#pragma once

#include <cuda.h>
#include <cuda_runtime.h>

namespace mlx::core {

namespace ptx {

#if (CUDART_VERSION >= 12080) && (__CUDA_ARCH__ >= 1000) && \
defined(__CUDA_ARCH_SPECIFIC__)

// Initialize a shared-memory mbarrier with `count` expected participants.
// `mbar` must point to 8-byte-aligned shared memory owned by this CTA.
// https://docs.nvidia.com/cuda/parallel-thread-execution/#parallel-synchronization-and-communication-instructions-mbarrier-init
__device__ __forceinline__ void mbarrier_init(uint64_t* mbar, uint32_t count) {
  // Shared-window addresses fit in 32 bits; make the narrowing explicit.
  uint32_t mbar_ptr = static_cast<uint32_t>(__cvta_generic_to_shared(mbar));
  asm volatile("mbarrier.init.shared.b64 [%0], %1;"
               :
               : "r"(mbar_ptr), "r"(count)
               : "memory");
}

// Invalidate a previously initialized mbarrier so its shared-memory
// storage may be reused for other data.
__device__ __forceinline__ void mbarrier_invalidate(uint64_t* mbar) {
  uint32_t bar_addr = __cvta_generic_to_shared(mbar);
  asm volatile("mbarrier.inval.shared.b64 [%0];" : : "r"(bar_addr) : "memory");
}

// Signal arrival at the mbarrier without touching the expected
// transaction count (used by non-master threads).
__device__ __forceinline__ void mbarrier_arrive(uint64_t* mbar) {
  uint32_t bar_addr = __cvta_generic_to_shared(mbar);
  asm volatile("mbarrier.arrive.shared.b64 _, [%0];"
               :
               : "r"(bar_addr)
               : "memory");
}

// Arrive at the mbarrier and add `tx_count` bytes to its expected
// transaction count (issued by the master thread before async copies).
__device__ __forceinline__ void mbarrier_arrive_expect_tx(
    uint64_t* mbar,
    uint32_t tx_count) {
  uint32_t bar_addr = __cvta_generic_to_shared(mbar);
  asm volatile("mbarrier.arrive.expect_tx.shared.b64 _, [%0], %1;"
               :
               : "r"(bar_addr), "r"(tx_count)
               : "memory");
}

// Block the calling thread until the mbarrier completes the phase selected
// by `parity` (phases alternate 0/1 across barrier completions).
// Implemented as a spin loop on mbarrier.try_wait, which may time out and
// must therefore be retried until the predicate reports completion.
// https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-asynchronous-copy-completion-mechanisms-mbarrier
__device__ __forceinline__ void mbarrier_wait_parity(
    uint64_t* mbar,
    uint32_t parity) {
  uint32_t mbar_ptr = __cvta_generic_to_shared(mbar);
  // The braces scope the predicate register and the WAIT_LOOP label to this
  // asm block, so the wrapper can be inlined at multiple sites.
  asm volatile(
      "{\n\t"
      ".reg .pred P;\n\t"
      "WAIT_LOOP:\n\t"
      "mbarrier.try_wait.parity.shared.b64 P, [%0], %1;\n\t"
      "@!P bra WAIT_LOOP;\n\t"
      "}\n\t"
      :
      : "r"(mbar_ptr), "r"(parity)
      : "memory");
}

// Launch an asynchronous 2D bulk tensor copy from global memory (described
// by `tensor_map`) into shared memory at `dst_shmem`, starting at tile
// coordinates (tile_x, tile_y). Completion is reported on `mbar` via the
// complete_tx::bytes mechanism.
__device__ __forceinline__ void cp_async_bulk_tensor_2d_global_to_shared(
    void* dst_shmem,
    const CUtensorMap* tensor_map,
    uint32_t tile_x,
    uint32_t tile_y,
    uint64_t* mbar) {
  uint32_t bar_addr = __cvta_generic_to_shared(mbar);
  uint32_t smem_addr = __cvta_generic_to_shared(dst_shmem);
  asm volatile(
      "cp.async.bulk.tensor.2d.shared::cluster.global.tile"
      ".mbarrier::complete_tx::bytes [%0], [%1, {%2, %3}], [%4];"
      :
      : "r"(smem_addr), "l"(tensor_map), "r"(tile_x), "r"(tile_y), "r"(bar_addr)
      : "memory");
}

// Launch an asynchronous 2D bulk tensor copy from shared memory at
// `src_shmem` into global memory described by the tensor map, at element
// offsets (offset_x, offset_y). The copy joins the current bulk async
// group; pair with cp_async_bulk_commit_group / cp_async_bulk_wait_group_read.
// NOTE(review): this takes the tensor map as `const uint64_t*` while the
// global->shared variant takes `const CUtensorMap*` — presumably the same
// object viewed through different pointer types; consider unifying.
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-tensor
// shared::cta -> global
__device__ __forceinline__ void cp_async_bulk_tensor_2d_shared_to_global(
    const uint64_t* tensor_map_ptr,
    const uint32_t offset_x,
    const uint32_t offset_y,
    uint64_t* src_shmem) {
  uint32_t src_shmem_ptr = __cvta_generic_to_shared(src_shmem);
  asm volatile(
      "cp.async.bulk.tensor.2d.global.shared::cta.bulk_group [%0, {%1, %2}], [%3];" ::
      "l"(tensor_map_ptr),
      "r"(offset_x),
      "r"(offset_y),
      "r"(src_shmem_ptr)
      : "memory");
}

// Wait until at most N of the most recently committed bulk async-copy
// groups still have pending reads from shared memory (N == 0 waits for
// all of them).
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-wait-group
template <int N>
__device__ __forceinline__ void cp_async_bulk_wait_group_read() {
  static_assert(N >= 0, "pending group count must be non-negative");
  // The "n" constraint folds N into the instruction as an immediate, so any
  // count is supported; the previous if-constexpr chain silently compiled to
  // a no-op for unlisted values such as N == 3. The "memory" clobber keeps
  // the compiler from moving shared-memory accesses across the wait.
  asm volatile("cp.async.bulk.wait_group.read %0;" : : "n"(N) : "memory");
}

// Commit all previously issued (and uncommitted) bulk async copies into a
// new bulk async-copy group, to be awaited with cp_async_bulk_wait_group_read.
// The "memory" clobber prevents the compiler from reordering memory accesses
// across the group boundary.
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-commit-group
__device__ __forceinline__ void cp_async_bulk_commit_group() {
  asm volatile("cp.async.bulk.commit_group;" : : : "memory");
}

// Creates a memory ordering barrier between generic and async proxies for
// CTA-shared memory. The "memory" clobber is required so the compiler does
// not reorder shared-memory accesses across the fence, which would defeat it.
// details:
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#proxies
__device__ __forceinline__ void fence_proxy_async_shared_cta() {
  asm volatile("fence.proxy.async.shared::cta;" : : : "memory");
}

#endif // (CUDART_VERSION >= 12080) && (__CUDA_ARCH__ >= 1000) &&
       // defined(__CUDA_ARCH_SPECIFIC__)
} // namespace ptx
} // namespace mlx::core
Loading
Loading