diff --git a/bitsandbytes/backends/cpu/ops.py b/bitsandbytes/backends/cpu/ops.py
index def87045c..a2c74488b 100644
--- a/bitsandbytes/backends/cpu/ops.py
+++ b/bitsandbytes/backends/cpu/ops.py
@@ -219,6 +219,17 @@ def _(
         return out
 
 if has_avx512bf16():
+    gemm_4bit_forward_kernel = None
+    try:
+        from kernels import get_kernel
+
+        gemm_4bit_forward_kernel = get_kernel("kernels-community/quantization_bitsandbytes").gemm_4bit_forward
+    except Exception as exc:  # pragma: no cover - best-effort fallback
+        gemm_4bit_forward_kernel = None
+        logger.warning(
+            "Failed to load the CPU gemm_4bit kernel: %s. Make sure the `kernels` package is installed and `kernels>=0.11.1` (`pip install kernels`).",
+            exc,
+        )
 
     @register_kernel("bitsandbytes::gemv_4bit", "cpu")
     def _(
@@ -239,38 +250,42 @@ def _(
         final_out_shape = (*A.shape[:-1], shapeB[0])
         A = A.reshape(-1, A.shape[-1])
         out_shape = (*A.shape[:-1], shapeB[0])
-        out = torch.empty(out_shape, dtype=A.dtype, device=A.device)
-        M = A.shape[0]
-        N = shapeB[0]
-        K = A.shape[1]
-        x_strideM = A.stride(0)
-        out_strideM = out.stride(0)
-        if quant_type == "fp4":
-            lib.gemv_4bit_inference_cpu_fp4_bf16(
-                ct.c_int64(M),
-                ct.c_int64(N),
-                ct.c_int64(K),
-                get_ptr(A),
-                get_ptr(B),
-                get_ptr(absmax),
-                get_ptr(out),
-                ct.c_int64(blocksize),
-                ct.c_int64(x_strideM),
-                ct.c_int64(out_strideM),
-            )
-        elif quant_type == "nf4":
-            lib.gemv_4bit_inference_cpu_nf4_bf16(
-                ct.c_int64(M),
-                ct.c_int64(N),
-                ct.c_int64(K),
-                get_ptr(A),
-                get_ptr(B),
-                get_ptr(absmax),
-                get_ptr(out),
-                ct.c_int64(blocksize),
-                ct.c_int64(x_strideM),
-                ct.c_int64(out_strideM),
-            )
+        quant_type_num = 1 if quant_type == "fp4" else 0
+        if gemm_4bit_forward_kernel is not None:
+            out = gemm_4bit_forward_kernel(A, B, absmax, blocksize, quant_type_num)
+        else:
+            out = torch.empty(out_shape, dtype=A.dtype, device=A.device)
+            M = A.shape[0]
+            N = shapeB[0]
+            K = A.shape[1]
+            x_strideM = A.stride(0)
+            out_strideM = out.stride(0)
+            if quant_type == "fp4":
+                lib.gemv_4bit_inference_cpu_fp4_bf16(
+                    ct.c_int64(M),
+                    ct.c_int64(N),
+                    ct.c_int64(K),
+                    get_ptr(A),
+                    get_ptr(B),
+                    get_ptr(absmax),
+                    get_ptr(out),
+                    ct.c_int64(blocksize),
+                    ct.c_int64(x_strideM),
+                    ct.c_int64(out_strideM),
+                )
+            elif quant_type == "nf4":
+                lib.gemv_4bit_inference_cpu_nf4_bf16(
+                    ct.c_int64(M),
+                    ct.c_int64(N),
+                    ct.c_int64(K),
+                    get_ptr(A),
+                    get_ptr(B),
+                    get_ptr(absmax),
+                    get_ptr(out),
+                    ct.c_int64(blocksize),
+                    ct.c_int64(x_strideM),
+                    ct.c_int64(out_strideM),
+                )
 
         if dtype != torch.bfloat16:
             out = out.to(dtype)
diff --git a/pyproject.toml b/pyproject.toml
index 65f9314c5..f448a079e 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -45,7 +45,7 @@ classifiers = [
 dependencies = [
     "torch>=2.3,<3",
     "numpy>=1.17",
-    "packaging>=20.9"
+    "packaging>=20.9",
 ]
 
 [project.urls]
diff --git a/tests/test_ops.py b/tests/test_ops.py
index 8d9aa5ab2..3218b9215 100644
--- a/tests/test_ops.py
+++ b/tests/test_ops.py
@@ -237,7 +237,6 @@ def test_gemv_4bit(self, device, dtype, storage_dtype, quant_type, blocksize):
             quant_type=quant_type,
         )
         B_q, state = bitsandbytes.functional._convert_weight_packed_for_cpu(B_q, state)
-        B_q = B_q.t()
         absmax = state.absmax
 
         out = torch.ops.bitsandbytes.gemv_4bit.default(A, B_q, B.shape, absmax, code, blocksize)
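
Usage sketch (not part of the patch): a minimal end-to-end exercise of the patched CPU gemv_4bit op, mirroring the updated test above. It assumes an AVX512-BF16-capable CPU (the op is only registered there); the shapes and blocksize are illustrative, and using `state.code` for the `code` argument follows the test's pattern. With `kernels>=0.11.1` installed, the call routes through the hub `gemm_4bit_forward` kernel; otherwise it falls back to the bundled C++ path.

import torch
import bitsandbytes

# Illustrative shapes: one bf16 activation row against a 4096x1024 weight.
A = torch.randn(1, 1024, dtype=torch.bfloat16)
B = torch.randn(4096, 1024, dtype=torch.bfloat16)

# Quantize the weight to NF4, then repack it into the CPU kernel's layout.
# Note: no more B_q.t() after repacking, per the test change above.
B_q, state = bitsandbytes.functional.quantize_4bit(B, blocksize=64, quant_type="nf4")
B_q, state = bitsandbytes.functional._convert_weight_packed_for_cpu(B_q, state)

# Dispatches to the hub kernel when available, else the bundled C++ kernel.
out = torch.ops.bitsandbytes.gemv_4bit.default(A, B_q, B.shape, state.absmax, state.code, 64)
print(out.shape)  # torch.Size([1, 4096])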