diff --git a/chandra/model/__init__.py b/chandra/model/__init__.py
index 12af5e4..0909aa5 100644
--- a/chandra/model/__init__.py
+++ b/chandra/model/__init__.py
@@ -29,6 +29,7 @@ def generate(
         )
         bbox_scale = kwargs.pop("bbox_scale", settings.BBOX_SCALE)
         vllm_api_base = kwargs.pop("vllm_api_base", settings.VLLM_API_BASE)
+        vllm_api_key = kwargs.pop("vllm_api_key", settings.VLLM_API_KEY)

         if self.method == "vllm":
             results = generate_vllm(
@@ -36,6 +37,7 @@ def generate(
                 max_output_tokens=max_output_tokens,
                 bbox_scale=bbox_scale,
                 vllm_api_base=vllm_api_base,
+                vllm_api_key=vllm_api_key,
                 **kwargs,
             )
         else:
diff --git a/chandra/model/vllm.py b/chandra/model/vllm.py
index 94ad975..e7f4564 100644
--- a/chandra/model/vllm.py
+++ b/chandra/model/vllm.py
@@ -4,6 +4,7 @@
 from concurrent.futures import ThreadPoolExecutor
 from itertools import repeat
 from typing import List
+import logging

 from PIL import Image
 from openai import OpenAI
@@ -14,6 +15,9 @@
 from chandra.settings import settings


+logger = logging.getLogger(__name__)
+
+
 def image_to_base64(image: Image.Image) -> str:
     """Convert PIL Image to base64 string."""
     buffered = io.BytesIO()
@@ -30,11 +34,16 @@ def generate_vllm(
     max_failure_retries: int | None = None,
     bbox_scale: int = settings.BBOX_SCALE,
     vllm_api_base: str = settings.VLLM_API_BASE,
+    vllm_api_key: str = settings.VLLM_API_KEY,
     temperature: float = 0.0,
     top_p: float = 0.1,
 ) -> List[GenerationResult]:
+    if not vllm_api_base.endswith("/v1"):
+        # this can fail with
+        # Exception: Unexpected endpoint or method. (POST /chat/completions)
+        logger.warning(f"vllm_api_base does not end with '/v1': {vllm_api_base!r}")
     client = OpenAI(
-        api_key=settings.VLLM_API_KEY,
+        api_key=vllm_api_key,
         base_url=vllm_api_base,
         default_headers=custom_headers,
     )
@@ -78,6 +87,8 @@ def _generate(item: BatchInputItem, temperature, top_p) -> GenerationResult:
         temperature=temperature,
         top_p=top_p,
     )
+    if hasattr(completion, "error"):
+        raise Exception(completion.error)
     raw = completion.choices[0].message.content
     result = GenerationResult(
         raw=raw,
diff --git a/chandra/scripts/cli.py b/chandra/scripts/cli.py
index e7d0d83..1b3dd34 100755
--- a/chandra/scripts/cli.py
+++ b/chandra/scripts/cli.py
@@ -7,6 +7,7 @@
 from chandra.input import load_file
 from chandra.model import InferenceManager
 from chandra.model.schema import BatchInputItem
+from chandra.settings import settings


 def get_supported_files(input_path: Path) -> List[Path]:
@@ -20,6 +21,7 @@ def get_supported_files(input_path: Path) -> List[Path]:
         ".webp",
         ".tiff",
         ".bmp",
+        ".avif",
     }

     if input_path.is_file():
@@ -159,6 +161,18 @@ def save_merged_output(
     default=None,
     help="Maximum number of retries for vLLM inference.",
 )
+@click.option(
+    "--vllm-api-base",
+    type=str,
+    default=settings.VLLM_API_BASE,
+    help=f"default: {settings.VLLM_API_BASE!r}",
+)
+@click.option(
+    "--vllm-api-key",
+    type=str,
+    default=settings.VLLM_API_KEY,
+    help=f"default: {settings.VLLM_API_KEY!r}",
+)
 @click.option(
     "--include-images/--no-images",
     default=True,
@@ -193,6 +207,8 @@ def main(
     max_output_tokens: int,
     max_workers: int,
     max_retries: int,
+    vllm_api_base: str,
+    vllm_api_key: str,
     include_images: bool,
     include_headers_footers: bool,
     save_html: bool,
@@ -273,6 +289,10 @@ def main(
             generate_kwargs["max_workers"] = max_workers
         if max_retries is not None:
             generate_kwargs["max_retries"] = max_retries
+        if vllm_api_base is not None:
+            generate_kwargs["vllm_api_base"] = vllm_api_base
+        if vllm_api_key is not None:
+            generate_kwargs["vllm_api_key"] = vllm_api_key

         results = model.generate(batch, **generate_kwargs)
         all_results.extend(results)
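
For reviewers, a minimal sketch of the new override path. This is not part of the patch: the InferenceManager constructor arguments and the batch construction are assumptions (the diff only shows `model.generate(batch, **generate_kwargs)` in cli.py), and the URL and key values are placeholders. On the CLI, the same values arrive via the new --vllm-api-base and --vllm-api-key options.

    # Hypothetical usage sketch, not part of the patch: exercising the new
    # per-call overrides added above. InferenceManager construction and the
    # contents of `batch` are assumptions; only the vllm_api_base/vllm_api_key
    # plumbing is shown by the diff.
    from chandra.model import InferenceManager

    model = InferenceManager()  # constructor args not shown in this diff
    batch = [...]  # list of BatchInputItem, built as in cli.py (elided)
    results = model.generate(
        batch,
        vllm_api_base="http://localhost:8000/v1",  # new warning fires if this lacks "/v1"
        vllm_api_key="placeholder-key",            # previously read only from settings.VLLM_API_KEY
    )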