From dec794a7db5ebbd4eb936db818906c51e3e59d16 Mon Sep 17 00:00:00 2001
From: Lee Jun Hyuk
Date: Thu, 26 Mar 2026 20:26:47 +0900
Subject: [PATCH] feat(model): add quantization support for LLM2Vec text encoder

Add KIMODO_QUANTIZE env var to load the Llama-3-8B text encoder with
reduced precision via bitsandbytes:

  KIMODO_QUANTIZE=4bit - NF4 4-bit (~5GB VRAM, down from ~17GB)
  KIMODO_QUANTIZE=8bit - INT8 8-bit (~9GB VRAM)

This makes Kimodo usable on consumer GPUs (8-12GB) while retaining full
text-prompt support. The quantized model is pinned to its device to
avoid errors from .to() calls on quantized weights.

Requires: pip install bitsandbytes accelerate
---
 kimodo/model/llm2vec/llm2vec.py         |  6 ++++
 kimodo/model/llm2vec/llm2vec_wrapper.py | 46 ++++++++++++++++++++++++-
 2 files changed, 51 insertions(+), 1 deletion(-)

diff --git a/kimodo/model/llm2vec/llm2vec.py b/kimodo/model/llm2vec/llm2vec.py
index 6d01f57..8128fe5 100644
--- a/kimodo/model/llm2vec/llm2vec.py
+++ b/kimodo/model/llm2vec/llm2vec.py
@@ -87,6 +87,12 @@ def __init__(
         self.max_length = max_length
         self.doc_max_length = doc_max_length
         self.config = model.config
+        self._is_quantized = getattr(model, "is_quantized", False) or hasattr(model, "quantization_method")
+
+    def to(self, *args, **kwargs):
+        if self._is_quantized:
+            return self
+        return super().to(*args, **kwargs)
 
     @classmethod
     def _get_model_class(cls, config_class_name, enable_bidirectional):
diff --git a/kimodo/model/llm2vec/llm2vec_wrapper.py b/kimodo/model/llm2vec/llm2vec_wrapper.py
index eb33c87..dce2d80 100644
--- a/kimodo/model/llm2vec/llm2vec_wrapper.py
+++ b/kimodo/model/llm2vec/llm2vec_wrapper.py
@@ -9,6 +9,39 @@
 
 from .llm2vec import LLM2Vec
 
+# KIMODO_QUANTIZE options:
+#   "4bit" - NF4 4-bit quantization (~5GB VRAM for Llama-3-8B)
+#   "8bit" - INT8 8-bit quantization (~9GB VRAM for Llama-3-8B)
+#   unset  - no quantization, full precision (~17GB VRAM)
+QUANTIZE_PRESETS = {
+    "4bit": {
+        "load_in_4bit": True,
+        "bnb_4bit_compute_dtype": "float16",
+        "bnb_4bit_quant_type": "nf4",
+        "bnb_4bit_use_double_quant": True,
+    },
+    "8bit": {
+        "load_in_8bit": True,
+    },
+}
+
+
+def _build_quantization_config():
+    """Build a BitsAndBytes quantization config from the KIMODO_QUANTIZE env var."""
+    quantize = os.environ.get("KIMODO_QUANTIZE", "").lower()
+    if not quantize:
+        return None
+    if quantize not in QUANTIZE_PRESETS:
+        available = ", ".join(sorted(QUANTIZE_PRESETS))
+        raise ValueError(
+            f"Unknown KIMODO_QUANTIZE='{quantize}'. "
+            f"Available: {available}"
+        )
+    from transformers import BitsAndBytesConfig
+    kwargs = QUANTIZE_PRESETS[quantize].copy()
+    if "bnb_4bit_compute_dtype" in kwargs:
+        kwargs["bnb_4bit_compute_dtype"] = getattr(torch, kwargs["bnb_4bit_compute_dtype"])
+    return BitsAndBytesConfig(**kwargs)
 
 class LLM2VecEncoder:
     """LLM2Vec text embeddings."""
@@ -29,18 +62,29 @@ def __init__(
             base_model_name_or_path = os.path.join(os.environ["TEXT_ENCODERS_DIR"], base_model_name_or_path)
             peft_model_name_or_path = os.path.join(os.environ["TEXT_ENCODERS_DIR"], peft_model_name_or_path)
 
+        extra_kwargs = {}
+        quantization_config = _build_quantization_config()
+        if quantization_config is not None:
+            extra_kwargs["quantization_config"] = quantization_config
+            extra_kwargs["device_map"] = "auto"
+            mode = os.environ.get("KIMODO_QUANTIZE", "").lower()
+            print(f"[Kimodo] Using {mode} quantization for the text encoder to reduce VRAM usage")
+
         self.model = LLM2Vec.from_pretrained(
             base_model_name_or_path=base_model_name_or_path,
             peft_model_name_or_path=peft_model_name_or_path,
             torch_dtype=torch_dtype,
             cache_dir=cache_dir,
+            **extra_kwargs,
         )
         self.model.eval()
         for p in self.model.parameters():
             p.requires_grad = False
+        self._quantized = quantization_config is not None
 
     def to(self, device: torch.device):
-        self.model = self.model.to(device)
+        if not self._quantized:
+            self.model = self.model.to(device)
         return self
 
     def eval(self):
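
A quick smoke test for the 4-bit path. This is a sketch, not part of the patch:
it assumes LLM2VecEncoder accepts the same path arguments it forwards to
from_pretrained(), and the McGill-NLP checkpoint names are placeholders for
whatever text-encoder checkpoints Kimodo is actually configured with.

    import os
    os.environ["KIMODO_QUANTIZE"] = "4bit"  # or "8bit"; leave unset for full precision

    import torch
    from kimodo.model.llm2vec.llm2vec_wrapper import LLM2VecEncoder

    # Placeholder checkpoints; substitute Kimodo's configured paths.
    encoder = LLM2VecEncoder(
        base_model_name_or_path="McGill-NLP/LLM2Vec-Meta-Llama-3-8B-Instruct-mntp",
        peft_model_name_or_path="McGill-NLP/LLM2Vec-Meta-Llama-3-8B-Instruct-mntp-supervised",
    )

    # With quantization active, .to() is a no-op: device_map="auto" already
    # placed the weights, so this call must not raise.
    encoder.to(torch.device("cuda"))
    print(f"VRAM allocated: {torch.cuda.memory_allocated() / 1e9:.1f} GB")  # ~5 GB expected for 4bit

Note that _build_quantization_config() reads KIMODO_QUANTIZE at construction
time rather than import time, so a test can flip modes between encoder
instances without re-importing the module.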