79 changes: 47 additions & 32 deletions bitsandbytes/backends/cpu/ops.py
@@ -219,6 +219,17 @@ def _(
         return out

 if has_avx512bf16():
+    gemm_4bit_forward_kernel = None
+    try:
+        from kernels import get_kernel
+
+        gemm_4bit_forward_kernel = get_kernel("kernels-community/quantization_bitsandbytes").gemm_4bit_forward
+    except Exception as exc:  # pragma: no cover - best effort fallback
+        gemm_4bit_forward_kernel = None
+        logger.warning(
+            "Failed to load the CPU gemm_4bit kernel: %s. Falling back to the native implementation; install it with `pip install 'kernels>=0.11.1'`.",
+            exc,
+        )
+
     @register_kernel("bitsandbytes::gemv_4bit", "cpu")
     def _(
@@ -239,38 +250,42 @@ def _(
         final_out_shape = (*A.shape[:-1], shapeB[0])
         A = A.reshape(-1, A.shape[-1])
         out_shape = (*A.shape[:-1], shapeB[0])
-        out = torch.empty(out_shape, dtype=A.dtype, device=A.device)
-        M = A.shape[0]
-        N = shapeB[0]
-        K = A.shape[1]
-        x_strideM = A.stride(0)
-        out_strideM = out.stride(0)
-        if quant_type == "fp4":
-            lib.gemv_4bit_inference_cpu_fp4_bf16(
-                ct.c_int64(M),
-                ct.c_int64(N),
-                ct.c_int64(K),
-                get_ptr(A),
-                get_ptr(B),
-                get_ptr(absmax),
-                get_ptr(out),
-                ct.c_int64(blocksize),
-                ct.c_int64(x_strideM),
-                ct.c_int64(out_strideM),
-            )
-        elif quant_type == "nf4":
-            lib.gemv_4bit_inference_cpu_nf4_bf16(
-                ct.c_int64(M),
-                ct.c_int64(N),
-                ct.c_int64(K),
-                get_ptr(A),
-                get_ptr(B),
-                get_ptr(absmax),
-                get_ptr(out),
-                ct.c_int64(blocksize),
-                ct.c_int64(x_strideM),
-                ct.c_int64(out_strideM),
-            )
+        quant_type_num = 1 if quant_type == "fp4" else 0
+        if gemm_4bit_forward_kernel is not None:
+            out = gemm_4bit_forward_kernel(A, B, absmax, blocksize, quant_type_num)
+        else:
+            out = torch.empty(out_shape, dtype=A.dtype, device=A.device)
+            M = A.shape[0]
+            N = shapeB[0]
+            K = A.shape[1]
+            x_strideM = A.stride(0)
+            out_strideM = out.stride(0)
+            if quant_type == "fp4":
+                lib.gemv_4bit_inference_cpu_fp4_bf16(
+                    ct.c_int64(M),
+                    ct.c_int64(N),
+                    ct.c_int64(K),
+                    get_ptr(A),
+                    get_ptr(B),
+                    get_ptr(absmax),
+                    get_ptr(out),
+                    ct.c_int64(blocksize),
+                    ct.c_int64(x_strideM),
+                    ct.c_int64(out_strideM),
+                )
+            elif quant_type == "nf4":
+                lib.gemv_4bit_inference_cpu_nf4_bf16(
+                    ct.c_int64(M),
+                    ct.c_int64(N),
+                    ct.c_int64(K),
+                    get_ptr(A),
+                    get_ptr(B),
+                    get_ptr(absmax),
+                    get_ptr(out),
+                    ct.c_int64(blocksize),
+                    ct.c_int64(x_strideM),
+                    ct.c_int64(out_strideM),
+                )

         if dtype != torch.bfloat16:
             out = out.to(dtype)
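Net effect: when the optional Hub kernel loads, the registered gemv_4bit op dispatches to it; otherwise it falls back to the bundled ctypes kernels. Below is a minimal caller-side sketch modeled on the updated test; it assumes a CPU with AVX512-BF16 support (this registration only happens under has_avx512bf16()), and the shapes and blocksize are illustrative assumptions, not values from the PR.

import torch
import bitsandbytes.functional as F

# Sketch only: shapes, blocksize, and quant_type are illustrative.
blocksize, quant_type = 64, "nf4"
A = torch.randn(4, 256, dtype=torch.bfloat16)    # activations, M=4, K=256
B = torch.randn(512, 256, dtype=torch.bfloat16)  # weight, N=512, K=256
B_q, state = F.quantize_4bit(B, blocksize=blocksize, quant_type=quant_type)
B_q, state = F._convert_weight_packed_for_cpu(B_q, state)
# Dispatches to the Hub gemm_4bit_forward kernel when it loaded,
# else to lib.gemv_4bit_inference_cpu_{fp4,nf4}_bf16 via ctypes.
out = torch.ops.bitsandbytes.gemv_4bit.default(
    A, B_q, B.shape, state.absmax, state.code, blocksize
)
assert out.shape == (4, 512)

Note the integer encoding the new path passes to the Hub kernel: quant_type_num is 1 for fp4 and 0 for nf4.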
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -45,7 +45,7 @@ classifiers = [
 dependencies = [
     "torch>=2.3,<3",
     "numpy>=1.17",
-    "packaging>=20.9"
+    "packaging>=20.9",
 ]

 [project.urls]
1 change: 0 additions & 1 deletion tests/test_ops.py
@@ -237,7 +237,6 @@ def test_gemv_4bit(self, device, dtype, storage_dtype, quant_type, blocksize):
             quant_type=quant_type,
         )
         B_q, state = bitsandbytes.functional._convert_weight_packed_for_cpu(B_q, state)
-        B_q = B_q.t()
         absmax = state.absmax
         out = torch.ops.bitsandbytes.gemv_4bit.default(A, B_q, B.shape, absmax, code, blocksize)

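The dropped transpose tracks the packed-layout change: `_convert_weight_packed_for_cpu` apparently now returns `B_q` already in the orientation `gemv_4bit` consumes, so the extra `.t()` would hand the op a mismatched weight view.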