diff --git a/src/alignrl/inference.py b/src/alignrl/inference.py index 6f39c31..f00a938 100644 --- a/src/alignrl/inference.py +++ b/src/alignrl/inference.py @@ -14,6 +14,8 @@ class InferenceConfig(BaseModel): max_tokens: int = 512 top_p: float = 0.9 backend: str = "unsloth" # "unsloth", "vllm", or "mlx" + max_seq_length: int = 2048 + load_in_4bit: bool = True def build_prompt(user_message: str, system: str | None = None) -> list[dict[str, str]]: @@ -48,8 +50,8 @@ def _load_unsloth(self) -> None: model_path = self.config.adapter_path or self.config.model_name self._model, self._tokenizer = FastLanguageModel.from_pretrained( model_name=model_path, - max_seq_length=2048, - load_in_4bit=True, + max_seq_length=self.config.max_seq_length, + load_in_4bit=self.config.load_in_4bit, ) from alignrl.config import ensure_chat_template