From bd8b77c3844cb048cff9ed2c88d70b3466ae15b4 Mon Sep 17 00:00:00 2001 From: "chenxiaofei.cxf" Date: Thu, 19 Mar 2026 23:22:10 +0800 Subject: [PATCH] feat(rerank): add OpenAI-compatible rerank provider - Add OpenAIRerankClient using standard flat request/response format compatible with DashScope compatible-api and other OpenAI/Cohere-style rerank APIs (no input/output wrappers) - Fix silent data corruption: add index bounds-checking so out-of-bounds or missing index returns None with a warning - Add provider allow-list validation in RerankConfig ('vikingdb'|'openai') - Remove unnecessary getattr() in RerankClient.from_config() - Update ov.conf.example: keep vikingdb (doubao) as primary rerank config, add rerank_openai_example section for DashScope qwen3-rerank - Update docs (en/zh): add OpenAI-compatible provider example alongside existing volcengine example in configuration guide and schema - Add 24 tests covering success, edge cases, and factory dispatch Co-Authored-By: Claude Sonnet 4.6 --- docs/en/guides/01-configuration.md | 24 +- docs/zh/guides/01-configuration.md | 24 +- examples/ov.conf.example | 13 +- openviking_cli/utils/config/rerank_config.py | 31 ++- openviking_cli/utils/rerank.py | 5 + openviking_cli/utils/rerank_openai.py | 122 +++++++++ tests/misc/test_rerank_openai.py | 249 +++++++++++++++++++ 7 files changed, 458 insertions(+), 10 deletions(-) create mode 100644 openviking_cli/utils/rerank_openai.py create mode 100644 tests/misc/test_rerank_openai.py diff --git a/docs/en/guides/01-configuration.md b/docs/en/guides/01-configuration.md index dacb665c..edbbffd4 100644 --- a/docs/en/guides/01-configuration.md +++ b/docs/en/guides/01-configuration.md @@ -410,11 +410,27 @@ Reranking model for search result refinement. } ``` +**OpenAI-compatible provider (e.g. DashScope qwen3-rerank):** + +```json +{ + "rerank": { + "provider": "openai", + "api_key": "your-api-key", + "api_base": "https://dashscope.aliyuncs.com/compatible-api/v1/reranks", + "model": "qwen3-rerank", + "threshold": 0.1 + } +} +``` + | Parameter | Type | Description | |-----------|------|-------------| -| `provider` | str | `"volcengine"` | +| `provider` | str | `"volcengine"` or `"openai"` | | `api_key` | str | API key | | `model` | str | Model name | +| `api_base` | str | Endpoint URL (openai provider only) | +| `threshold` | float | Score threshold; results below this are filtered out. Default: `0.1` | If rerank is not configured, search uses vector similarity only. @@ -729,9 +745,11 @@ For details on the lock mechanism, see [Path Locks and Crash Recovery](../concep "stream": false }, "rerank": { - "provider": "volcengine", + "provider": "volcengine|openai", "api_key": "string", - "model": "string" + "model": "string", + "api_base": "string", + "threshold": 0.1 }, "storage": { "workspace": "string", diff --git a/docs/zh/guides/01-configuration.md b/docs/zh/guides/01-configuration.md index bf6a8ea2..87e596c4 100644 --- a/docs/zh/guides/01-configuration.md +++ b/docs/zh/guides/01-configuration.md @@ -383,11 +383,27 @@ AST 提取支持:Python、JavaScript/TypeScript、Rust、Go、Java、C/C++。 } ``` +**OpenAI 兼容提供方(如 DashScope qwen3-rerank):** + +```json +{ + "rerank": { + "provider": "openai", + "api_key": "your-api-key", + "api_base": "https://dashscope.aliyuncs.com/compatible-api/v1/reranks", + "model": "qwen3-rerank", + "threshold": 0.1 + } +} +``` + | 参数 | 类型 | 说明 | |------|------|------| -| `provider` | str | `"volcengine"` | +| `provider` | str | `"volcengine"` 或 `"openai"` | | `api_key` | str | API Key | | `model` | str | 模型名称 | +| `api_base` | str | 接口地址(openai 提供方专用) | +| `threshold` | float | 分数阈值,低于此值的结果会被过滤。默认:`0.1` | 如果未配置 Rerank,搜索仅使用向量相似度。 @@ -704,9 +720,11 @@ HTTP 客户端(`SyncHTTPClient` / `AsyncHTTPClient`)和 CLI 工具连接远 "stream": false }, "rerank": { - "provider": "volcengine", + "provider": "volcengine|openai", "api_key": "string", - "model": "string" + "model": "string", + "api_base": "string", + "threshold": 0.1 }, "storage": { "workspace": "string", diff --git a/examples/ov.conf.example b/examples/ov.conf.example index 249c0a13..e2573abc 100644 --- a/examples/ov.conf.example +++ b/examples/ov.conf.example @@ -64,13 +64,22 @@ "thinking": false }, "rerank": { - "ak": null, - "sk": null, + "provider": "vikingdb", + "ak": "{your-ak}", + "sk": "{your-sk}", "host": "api-vikingdb.vikingdb.cn-beijing.volces.com", "model_name": "doubao-seed-rerank", "model_version": "251028", "threshold": 0.1 }, + "rerank_openai_example": { + "_comment": "For OpenAI-compatible rerank providers (e.g. DashScope qwen3-rerank):", + "provider": "openai", + "api_key": "{your-api-key}", + "api_base": "https://dashscope.aliyuncs.com/compatible-api/v1/reranks", + "model": "qwen3-rerank", + "threshold": 0.1 + }, "auto_generate_l0": true, "auto_generate_l1": true, "default_search_mode": "thinking", diff --git a/openviking_cli/utils/config/rerank_config.py b/openviking_cli/utils/config/rerank_config.py index 076ef9f6..91c0b023 100644 --- a/openviking_cli/utils/config/rerank_config.py +++ b/openviking_cli/utils/config/rerank_config.py @@ -2,12 +2,15 @@ # SPDX-License-Identifier: Apache-2.0 from typing import Optional -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, model_validator class RerankConfig(BaseModel): - """Configuration for VikingDB Rerank API.""" + """Configuration for rerank API (VikingDB or OpenAI-compatible providers).""" + provider: str = Field(default="vikingdb", description="Rerank provider: 'vikingdb' or 'openai'") + + # VikingDB fields ak: Optional[str] = Field(default=None, description="VikingDB Access Key") sk: Optional[str] = Field(default=None, description="VikingDB Secret Key") host: str = Field( @@ -15,12 +18,36 @@ class RerankConfig(BaseModel): ) model_name: str = Field(default="doubao-seed-rerank", description="Rerank model name") model_version: str = Field(default="251028", description="Rerank model version") + + # OpenAI-compatible fields + api_key: Optional[str] = Field( + default=None, description="Bearer token for OpenAI-compatible providers" + ) + api_base: Optional[str] = Field(default=None, description="Custom endpoint URL") + model: Optional[str] = Field( + default=None, description="Model name for OpenAI-compatible providers" + ) + threshold: float = Field( default=0.1, description="Relevance threshold (score > threshold is relevant)" ) model_config = {"extra": "forbid"} + @model_validator(mode="after") + def validate_provider_fields(self) -> "RerankConfig": + allowed = ["vikingdb", "openai"] + if self.provider not in allowed: + raise ValueError(f"Rerank provider must be one of {allowed}, got '{self.provider}'") + if self.provider == "openai": + if not self.api_key or not self.api_base: + raise ValueError( + "OpenAI-compatible rerank provider requires 'api_key' and 'api_base'" + ) + return self + def is_available(self) -> bool: """Check if rerank is configured.""" + if self.provider == "openai": + return self.api_key is not None and self.api_base is not None return self.ak is not None and self.sk is not None diff --git a/openviking_cli/utils/rerank.py b/openviking_cli/utils/rerank.py index 10192c97..8ca90602 100644 --- a/openviking_cli/utils/rerank.py +++ b/openviking_cli/utils/rerank.py @@ -156,6 +156,11 @@ def from_config(cls, config) -> Optional["RerankClient"]: if not config or not config.is_available(): return None + if config.provider == "openai": + from openviking_cli.utils.rerank_openai import OpenAIRerankClient + + return OpenAIRerankClient.from_config(config) + return cls( ak=config.ak, sk=config.sk, diff --git a/openviking_cli/utils/rerank_openai.py b/openviking_cli/utils/rerank_openai.py new file mode 100644 index 00000000..28557b53 --- /dev/null +++ b/openviking_cli/utils/rerank_openai.py @@ -0,0 +1,122 @@ +# Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd. +# SPDX-License-Identifier: Apache-2.0 +""" +OpenAI-compatible Rerank API Client. + +Supports third-party rerank services like Alibaba Cloud DashScope (qwen3-rerank) +via api_key + api_base configuration. +""" + +from typing import List, Optional + +import requests + +from openviking_cli.utils.logger import get_logger + +logger = get_logger(__name__) + + +class OpenAIRerankClient: + """ + OpenAI-compatible rerank API client using Bearer token auth. + + Compatible with services like Alibaba Cloud DashScope. + """ + + def __init__(self, api_key: str, api_base: str, model_name: str): + """ + Initialize OpenAI-compatible rerank client. + + Args: + api_key: Bearer token for authentication + api_base: Full endpoint URL for the rerank API + model_name: Model name to use for reranking + """ + self.api_key = api_key + self.api_base = api_base + self.model_name = model_name + + def rerank_batch(self, query: str, documents: List[str]) -> Optional[List[float]]: + """ + Batch rerank documents against a query. + + Args: + query: Query text + documents: List of document texts to rank + + Returns: + List of rerank scores for each document (same order as input), + or None when rerank fails and the caller should fall back + """ + if not documents: + return [] + + req_body = { + "model": self.model_name, + "query": query, + "documents": documents, + } + + try: + response = requests.post( + url=self.api_base, + headers={ + "Authorization": f"Bearer {self.api_key}", + "Content-Type": "application/json", + }, + json=req_body, + timeout=30, + ) + response.raise_for_status() + result = response.json() + + # Standard OpenAI/Cohere rerank format: results[].{index, relevance_score} + results = result.get("results") + if not results: + logger.warning(f"[OpenAIRerankClient] Unexpected response format: {result}") + return None + + if len(results) != len(documents): + logger.warning( + "[OpenAIRerankClient] Unexpected rerank result length: expected=%s actual=%s", + len(documents), + len(results), + ) + return None + + # Results may not be in original order — sort by index + scores = [0.0] * len(documents) + for item in results: + idx = item.get("index") + if idx is None or not (0 <= idx < len(documents)): + logger.warning( + "[OpenAIRerankClient] Out-of-bounds or missing index in result: %s", item + ) + return None + scores[idx] = item.get("relevance_score", 0.0) + + logger.debug(f"[OpenAIRerankClient] Reranked {len(documents)} documents") + return scores + + except Exception as e: + logger.error(f"[OpenAIRerankClient] Rerank failed: {e}") + return None + + @classmethod + def from_config(cls, config) -> Optional["OpenAIRerankClient"]: + """ + Create OpenAIRerankClient from RerankConfig. + + Args: + config: RerankConfig instance with provider='openai' + + Returns: + OpenAIRerankClient instance or None if config is not available + """ + if not config or not config.is_available(): + return None + return cls( + api_key=config.api_key, + api_base=config.api_base, + model_name=config.model or "qwen3-rerank", + ) diff --git a/tests/misc/test_rerank_openai.py b/tests/misc/test_rerank_openai.py new file mode 100644 index 00000000..dd033b96 --- /dev/null +++ b/tests/misc/test_rerank_openai.py @@ -0,0 +1,249 @@ +# Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd. +# SPDX-License-Identifier: Apache-2.0 +"""Tests for OpenAI-compatible rerank client and factory dispatch.""" + +from unittest.mock import MagicMock, patch + +import pytest +from pydantic import ValidationError + +from openviking_cli.utils.config.rerank_config import RerankConfig +from openviking_cli.utils.rerank import RerankClient +from openviking_cli.utils.rerank_openai import OpenAIRerankClient + + +class TestOpenAIRerankClient: + def _make_client(self): + return OpenAIRerankClient( + api_key="test-key", + api_base="https://dashscope.aliyuncs.com/api/v1/services/rerank", + model_name="qwen3-rerank", + ) + + def test_rerank_batch_success(self): + client = self._make_client() + mock_response = MagicMock() + mock_response.json.return_value = { + "results": [ + {"index": 0, "relevance_score": 0.9}, + {"index": 1, "relevance_score": 0.3}, + {"index": 2, "relevance_score": 0.7}, + ] + } + mock_response.raise_for_status = MagicMock() + + with patch("openviking_cli.utils.rerank_openai.requests.post", return_value=mock_response): + scores = client.rerank_batch("test query", ["doc1", "doc2", "doc3"]) + + assert scores == [0.9, 0.3, 0.7] + + def test_rerank_batch_out_of_order_results(self): + """Results returned out-of-order should be re-ordered by index.""" + client = self._make_client() + mock_response = MagicMock() + mock_response.json.return_value = { + "results": [ + {"index": 2, "relevance_score": 0.7}, + {"index": 0, "relevance_score": 0.9}, + {"index": 1, "relevance_score": 0.3}, + ] + } + mock_response.raise_for_status = MagicMock() + + with patch("openviking_cli.utils.rerank_openai.requests.post", return_value=mock_response): + scores = client.rerank_batch("test query", ["doc1", "doc2", "doc3"]) + + assert scores == [0.9, 0.3, 0.7] + + def test_rerank_batch_empty_documents(self): + client = self._make_client() + scores = client.rerank_batch("query", []) + assert scores == [] + + def test_rerank_batch_unexpected_format_returns_none(self): + client = self._make_client() + mock_response = MagicMock() + mock_response.json.return_value = {"unexpected": "format"} + mock_response.raise_for_status = MagicMock() + + with patch("openviking_cli.utils.rerank_openai.requests.post", return_value=mock_response): + result = client.rerank_batch("query", ["doc1"]) + + assert result is None + + def test_rerank_batch_length_mismatch_returns_none(self): + client = self._make_client() + mock_response = MagicMock() + mock_response.json.return_value = { + "results": [ + {"index": 0, "relevance_score": 0.9}, + ] + } + mock_response.raise_for_status = MagicMock() + + with patch("openviking_cli.utils.rerank_openai.requests.post", return_value=mock_response): + result = client.rerank_batch("query", ["doc1", "doc2"]) + + assert result is None + + def test_rerank_batch_out_of_bounds_index_returns_none(self): + """An index that is >= len(documents) should return None.""" + client = self._make_client() + mock_response = MagicMock() + mock_response.json.return_value = { + "results": [ + {"index": 5, "relevance_score": 0.9}, # only 1 doc + ] + } + mock_response.raise_for_status = MagicMock() + + with patch("openviking_cli.utils.rerank_openai.requests.post", return_value=mock_response): + result = client.rerank_batch("query", ["doc1"]) + + assert result is None + + def test_rerank_batch_missing_index_field_returns_none(self): + """A result item with no 'index' key should return None.""" + client = self._make_client() + mock_response = MagicMock() + mock_response.json.return_value = { + "results": [ + {"relevance_score": 0.9}, # missing 'index' + ] + } + mock_response.raise_for_status = MagicMock() + + with patch("openviking_cli.utils.rerank_openai.requests.post", return_value=mock_response): + result = client.rerank_batch("query", ["doc1"]) + + assert result is None + + def test_rerank_batch_http_error_returns_none(self): + client = self._make_client() + + with patch( + "openviking_cli.utils.rerank_openai.requests.post", + side_effect=Exception("connection error"), + ): + result = client.rerank_batch("query", ["doc1"]) + + assert result is None + + def test_rerank_batch_sends_correct_request(self): + client = self._make_client() + mock_response = MagicMock() + mock_response.json.return_value = {"results": [{"index": 0, "relevance_score": 0.8}]} + mock_response.raise_for_status = MagicMock() + + with patch( + "openviking_cli.utils.rerank_openai.requests.post", return_value=mock_response + ) as mock_post: + client.rerank_batch("my query", ["doc1"]) + + call_kwargs = mock_post.call_args + assert call_kwargs.kwargs["url"] == "https://dashscope.aliyuncs.com/api/v1/services/rerank" + assert call_kwargs.kwargs["headers"]["Authorization"] == "Bearer test-key" + body = call_kwargs.kwargs["json"] + assert body["model"] == "qwen3-rerank" + assert body["query"] == "my query" + assert body["documents"] == ["doc1"] + + def test_from_config(self): + config = RerankConfig( + provider="openai", + api_key="my-key", + api_base="https://example.com/rerank", + model="qwen3-rerank", + ) + client = OpenAIRerankClient.from_config(config) + assert isinstance(client, OpenAIRerankClient) + assert client.api_key == "my-key" + assert client.api_base == "https://example.com/rerank" + assert client.model_name == "qwen3-rerank" + + def test_from_config_default_model(self): + config = RerankConfig( + provider="openai", + api_key="my-key", + api_base="https://example.com/rerank", + ) + client = OpenAIRerankClient.from_config(config) + assert client.model_name == "qwen3-rerank" + + def test_from_config_unavailable_returns_none(self): + result = OpenAIRerankClient.from_config(None) + assert result is None + + +class TestRerankClientFactoryDispatch: + def test_factory_dispatches_to_openai_client(self): + config = RerankConfig( + provider="openai", + api_key="test-key", + api_base="https://example.com/rerank", + model="qwen3-rerank", + ) + client = RerankClient.from_config(config) + assert isinstance(client, OpenAIRerankClient) + + def test_factory_dispatches_to_vikingdb_client(self): + config = RerankConfig( + provider="vikingdb", + ak="test-ak", + sk="test-sk", + ) + client = RerankClient.from_config(config) + assert isinstance(client, RerankClient) + assert not isinstance(client, OpenAIRerankClient) + + def test_factory_defaults_to_vikingdb(self): + """Config without provider field defaults to vikingdb.""" + config = RerankConfig(ak="test-ak", sk="test-sk") + client = RerankClient.from_config(config) + assert isinstance(client, RerankClient) + assert not isinstance(client, OpenAIRerankClient) + + def test_factory_returns_none_for_none_config(self): + assert RerankClient.from_config(None) is None + + def test_factory_returns_none_for_unavailable_vikingdb_config(self): + config = RerankConfig() # no ak/sk + assert RerankClient.from_config(config) is None + + def test_factory_returns_none_for_unavailable_openai_config(self): + # This should raise validation error since openai requires api_key + api_base + with pytest.raises(ValidationError): + RerankConfig(provider="openai") + + +class TestRerankConfig: + def test_vikingdb_is_available(self): + config = RerankConfig(ak="ak", sk="sk") + assert config.is_available() is True + + def test_vikingdb_not_available_without_credentials(self): + config = RerankConfig() + assert config.is_available() is False + + def test_openai_is_available(self): + config = RerankConfig( + provider="openai", + api_key="key", + api_base="https://example.com/rerank", + ) + assert config.is_available() is True + + def test_openai_requires_api_key_and_api_base(self): + with pytest.raises(ValidationError): + RerankConfig(provider="openai", api_key="key") + + with pytest.raises(ValidationError): + RerankConfig(provider="openai", api_base="https://example.com/rerank") + + def test_default_provider_is_vikingdb(self): + config = RerankConfig() + assert config.provider == "vikingdb" + + def test_unknown_provider_raises_value_error(self): + with pytest.raises(ValueError, match="provider"): + RerankConfig(provider="cohere", ak="ak", sk="sk")