diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 9e27b432..fa78ff60 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -1,6 +1,7 @@ name: 02. Main Branch Checks on: + workflow_dispatch: push: branches: [ main ] paths-ignore: @@ -16,6 +17,11 @@ on: - '.gitignore' - '.editorconfig' +permissions: + actions: read + contents: read + security-events: write + jobs: test-full: uses: ./.github/workflows/_test_full.yml diff --git a/.github/workflows/pr.yml b/.github/workflows/pr.yml index eba8e742..113a89dd 100644 --- a/.github/workflows/pr.yml +++ b/.github/workflows/pr.yml @@ -1,6 +1,7 @@ name: 01. Pull Request Checks on: + workflow_dispatch: pull_request: branches: [ main, develop ] paths-ignore: diff --git a/.github/workflows/schedule.yml b/.github/workflows/schedule.yml index 49c9d67e..0b2a2f2a 100644 --- a/.github/workflows/schedule.yml +++ b/.github/workflows/schedule.yml @@ -4,6 +4,11 @@ on: schedule: - cron: '0 0 * * 0' # Run at 00:00 on Sunday +permissions: + actions: read + contents: read + security-events: write + jobs: security-scan: uses: ./.github/workflows/_codeql.yml diff --git a/docs/en/configuration/configuration.md b/docs/en/configuration/configuration.md index d15e0f9f..5b29e8cc 100644 --- a/docs/en/configuration/configuration.md +++ b/docs/en/configuration/configuration.md @@ -207,6 +207,9 @@ client = ov.AsyncOpenViking(config=config) } ``` +Notes: +- `storage.vectordb.sparse_weight` controls hybrid (dense + sparse) indexing/search. It only takes effect when you use a hybrid index; set it > 0 to enable sparse signals. 
+ ## Related Documentation - [Embedding Configuration](./embedding.md) - Embedding setup diff --git a/docs/zh/configuration/configuration.md b/docs/zh/configuration/configuration.md index 93f5e21f..bef6327b 100644 --- a/docs/zh/configuration/configuration.md +++ b/docs/zh/configuration/configuration.md @@ -294,6 +294,9 @@ client = ov.AsyncOpenViking(config=config) } ``` +说明: +- `storage.vectordb.sparse_weight` 用于混合(dense + sparse)索引/检索的权重,仅在使用 hybrid 索引时生效;设置为 > 0 才会启用 sparse 信号。 + ## 相关文档 - [Embedding 配置](./embedding.md) - Embedding 设置 diff --git a/openviking/agfs_manager.py b/openviking/agfs_manager.py index 6e3d1812..66e96a74 100644 --- a/openviking/agfs_manager.py +++ b/openviking/agfs_manager.py @@ -38,14 +38,19 @@ class AGFSManager: manager.start() # 2. S3 backend + from openviking.utils.config.agfs_config import AGFSConfig, S3Config + config = AGFSConfig( path="./data", port=8080, backend="s3", - s3_bucket="my-bucket", - s3_region="us-east-1", - s3_access_key="your-access-key", - s3_secret_key="your-secret-key", + s3=S3Config( + bucket="my-bucket", + region="us-east-1", + access_key="your-access-key", + secret_key="your-secret-key", + endpoint="https://s3.amazonaws.com" + ), log_level="debug" ) manager = AGFSManager(config=config) @@ -73,13 +78,13 @@ def __init__( self.port = config.port self.log_level = config.log_level self.backend = config.backend - self.s3_bucket = config.s3_bucket - self.s3_region = config.s3_region - self.s3_access_key = config.s3_access_key - self.s3_secret_key = config.s3_secret_key - self.s3_endpoint = config.s3_endpoint - self.s3_prefix = config.s3_prefix - self.s3_use_ssl = config.s3_use_ssl + self.s3_bucket = config.s3.bucket + self.s3_region = config.s3.region + self.s3_access_key = config.s3.access_key + self.s3_secret_key = config.s3.secret_key + self.s3_endpoint = config.s3.endpoint + self.s3_prefix = config.s3.prefix + self.s3_use_ssl = config.s3.use_ssl self.process: Optional[subprocess.Popen] = None self.config_file: 
Optional[Path] = None diff --git a/openviking/session/session.py b/openviking/session/session.py index 225b07e4..689dbe47 100644 --- a/openviking/session/session.py +++ b/openviking/session/session.py @@ -5,7 +5,6 @@ Session as Context: Sessions integrated into L0/L1/L2 system. """ -import asyncio import json import re from dataclasses import dataclass, field @@ -14,7 +13,7 @@ from uuid import uuid4 from openviking.message import Message, Part -from openviking.utils import get_logger +from openviking.utils import get_logger, run_async from openviking.utils.config import get_openviking_config if TYPE_CHECKING: @@ -25,20 +24,6 @@ logger = get_logger(__name__) -def _run_async(coro): - """Run async coroutine.""" - try: - loop = asyncio.get_running_loop() - # If already in event loop, use nest_asyncio or return directly - import nest_asyncio - - nest_asyncio.apply() - return loop.run_until_complete(coro) - except RuntimeError: - # No running event loop, use asyncio.run() - return asyncio.run(coro) - - @dataclass class SessionCompression: """Session compression information.""" @@ -109,7 +94,7 @@ def load(self): return try: - content = _run_async( + content = run_async( self._viking_fs.read_file(f"{self._session_uri}/messages.jsonl") ) self._messages = [ @@ -123,7 +108,7 @@ def load(self): # Restore compression_index (scan history directory) try: - history_items = _run_async(self._viking_fs.ls(f"{self._session_uri}/history")) + history_items = run_async(self._viking_fs.ls(f"{self._session_uri}/history")) archives = [ item["name"] for item in history_items if item["name"].startswith("archive_") ] @@ -254,7 +239,7 @@ def commit(self) -> Dict[str, Any]: logger.info( f"Starting memory extraction from {len(messages_to_archive)} archived messages" ) - memories = _run_async( + memories = run_async( self._session_compressor.extract_long_term_memories( messages=messages_to_archive, user=self.user, @@ -298,7 +283,7 @@ def _update_active_counts(self) -> int: for usage in 
self._usage_records: try: - _run_async( + run_async( storage.update( collection="context", filter={"uri": usage.uri}, @@ -334,7 +319,7 @@ def get_context_for_search( summaries = [] if self.compression.compression_index > 0: try: - history_items = _run_async(self._viking_fs.ls(f"{self._session_uri}/history")) + history_items = run_async(self._viking_fs.ls(f"{self._session_uri}/history")) query_lower = query.lower() # Collect all archives with relevance scores @@ -344,7 +329,7 @@ def get_context_for_search( if name and name.startswith("archive_"): overview_uri = f"{self._session_uri}/history/{name}/.overview.md" try: - overview = _run_async(self._viking_fs.read_file(overview_uri)) + overview = run_async(self._viking_fs.read_file(overview_uri)) # Calculate relevance by keyword matching score = 0 if query_lower in overview.lower(): @@ -397,7 +382,7 @@ def _generate_archive_summary(self, messages: List[Message]) -> str: "compression.structured_summary", {"messages": formatted}, ) - return _run_async(vlm.get_completion_async(prompt)) + return run_async(vlm.get_completion_async(prompt)) except Exception as e: logger.warning(f"LLM summary failed: {e}") @@ -420,15 +405,15 @@ def _write_archive( # Write messages.jsonl lines = [m.to_jsonl() for m in messages] - _run_async( + run_async( viking_fs.write_file( uri=f"{archive_uri}/messages.jsonl", content="\n".join(lines) + "\n", ) ) - _run_async(viking_fs.write_file(uri=f"{archive_uri}/.abstract.md", content=abstract)) - _run_async(viking_fs.write_file(uri=f"{archive_uri}/.overview.md", content=overview)) + run_async(viking_fs.write_file(uri=f"{archive_uri}/.abstract.md", content=abstract)) + run_async(viking_fs.write_file(uri=f"{archive_uri}/.overview.md", content=overview)) logger.debug(f"Written archive: {archive_uri}") @@ -446,7 +431,7 @@ def _write_to_agfs(self, messages: List[Message]) -> None: lines = [m.to_jsonl() for m in messages] content = "\n".join(lines) + "\n" if lines else "" - _run_async( + run_async( 
viking_fs.write_file( uri=f"{self._session_uri}/messages.jsonl", content=content, @@ -454,13 +439,13 @@ def _write_to_agfs(self, messages: List[Message]) -> None: ) # Update L0/L1 - _run_async( + run_async( viking_fs.write_file( uri=f"{self._session_uri}/.abstract.md", content=abstract, ) ) - _run_async( + run_async( viking_fs.write_file( uri=f"{self._session_uri}/.overview.md", content=overview, @@ -471,7 +456,7 @@ def _append_to_jsonl(self, msg: Message) -> None: """Append to messages.jsonl.""" if not self._viking_fs: return - _run_async( + run_async( self._viking_fs.append_file( f"{self._session_uri}/messages.jsonl", msg.to_jsonl() + "\n", @@ -485,7 +470,7 @@ def _update_message_in_jsonl(self) -> None: lines = [m.to_jsonl() for m in self._messages] content = "\n".join(lines) + "\n" - _run_async( + run_async( self._viking_fs.write_file( f"{self._session_uri}/messages.jsonl", content, @@ -516,7 +501,7 @@ def _save_tool_result( "status": status, "time": {"created": datetime.now().isoformat()}, } - _run_async( + run_async( self._viking_fs.write_file( f"{self._session_uri}/tools/{tool_id}/tool.json", json.dumps(tool_data, ensure_ascii=False), @@ -563,7 +548,7 @@ def _write_relations(self) -> None: viking_fs = self._viking_fs for usage in self._usage_records: try: - _run_async(viking_fs.link(self._session_uri, usage.uri)) + run_async(viking_fs.link(self._session_uri, usage.uri)) logger.debug(f"Created relation: {self._session_uri} -> {usage.uri}") except Exception as e: logger.warning(f"Failed to create relation to {usage.uri}: {e}") diff --git a/openviking/storage/observers/README.md b/openviking/storage/observers/README.md index 43f7c229..071bbacf 100644 --- a/openviking/storage/observers/README.md +++ b/openviking/storage/observers/README.md @@ -38,7 +38,7 @@ Monitors queue system status (Embedding, Semantic, and custom queues). 
import openviking as ov client = ov.OpenViking(path="./data") -print(client.observers["queue"]) +print(client.observer.queue) # Output: # Queue Pending In Progress Processed Errors Total # Embedding 5 2 100 0 107 @@ -58,7 +58,7 @@ Monitors VikingDB collection status (index count and vector count per collection import openviking as ov client = ov.OpenViking(path="./data") -print(client.observers["vikingdb"]) +print(client.observer.vikingdb) # Output: # Collection Index Count Vector Count Status # context 1 69 OK diff --git a/openviking/storage/observers/async_utils.py b/openviking/storage/observers/async_utils.py deleted file mode 100644 index 5b2e30cd..00000000 --- a/openviking/storage/observers/async_utils.py +++ /dev/null @@ -1,40 +0,0 @@ -# Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd. -# SPDX-License-Identifier: Apache-2.0 -""" -Async helper utilities for observers. -""" - -import asyncio -import threading -from typing import Any, Awaitable, Callable - - -def run_coroutine_sync(coro_factory: Callable[[], Awaitable[Any]]) -> Any: - """ - Run a coroutine from sync code. - - If an event loop is already running in this thread, execute the coroutine - in a dedicated thread with its own event loop. 
- """ - - try: - asyncio.get_running_loop() - except RuntimeError: - return asyncio.run(coro_factory()) - - result: dict = {} - - def runner() -> None: - try: - result["value"] = asyncio.run(coro_factory()) - except Exception as exc: - result["error"] = exc - - thread = threading.Thread(target=runner, daemon=True) - thread.start() - thread.join() - - if "error" in result: - raise result["error"] - - return result.get("value") diff --git a/openviking/storage/observers/queue_observer.py b/openviking/storage/observers/queue_observer.py index b7efa8b8..fec514d4 100644 --- a/openviking/storage/observers/queue_observer.py +++ b/openviking/storage/observers/queue_observer.py @@ -8,10 +8,10 @@ from typing import Dict -from openviking.storage.observers.async_utils import run_coroutine_sync from openviking.storage.observers.base_observer import BaseObserver from openviking.storage.queuefs.named_queue import QueueStatus from openviking.storage.queuefs.queue_manager import QueueManager +from openviking.utils import run_async from openviking.utils.logger import get_logger logger = get_logger(__name__) @@ -32,7 +32,7 @@ async def get_status_table_async(self) -> str: return self._format_status_as_table(statuses) def get_status_table(self) -> str: - return run_coroutine_sync(self.get_status_table_async) + return run_async(self.get_status_table_async()) def __str__(self) -> str: return self.get_status_table() diff --git a/openviking/storage/observers/vikingdb_observer.py b/openviking/storage/observers/vikingdb_observer.py index 5ede61ba..91b7d38f 100644 --- a/openviking/storage/observers/vikingdb_observer.py +++ b/openviking/storage/observers/vikingdb_observer.py @@ -8,9 +8,9 @@ from typing import Dict -from openviking.storage.observers.async_utils import run_coroutine_sync from openviking.storage.observers.base_observer import BaseObserver from openviking.storage.vikingdb_manager import VikingDBManager +from openviking.utils import run_async from openviking.utils.logger import 
get_logger logger = get_logger(__name__) @@ -39,7 +39,7 @@ async def get_status_table_async(self) -> str: return self._format_status_as_table(statuses) def get_status_table(self) -> str: - return run_coroutine_sync(self.get_status_table_async) + return run_async(self.get_status_table_async()) def __str__(self) -> str: return self.get_status_table() @@ -156,7 +156,7 @@ def has_errors(self) -> bool: try: if not self._vikingdb_manager: return True - run_coroutine_sync(self._vikingdb_manager.health_check) + run_async(self._vikingdb_manager.health_check()) return False except Exception as e: logger.error(f"VikingDB health check failed: {e}") diff --git a/openviking/storage/queuefs/embedding_queue.py b/openviking/storage/queuefs/embedding_queue.py index 6af3d3d0..b6dc771a 100644 --- a/openviking/storage/queuefs/embedding_queue.py +++ b/openviking/storage/queuefs/embedding_queue.py @@ -16,7 +16,7 @@ class EmbeddingQueue(NamedQueue): Supports direct enqueue and dequeue of EmbeddingMsg objects. """ - async def enqueue(self, msg: EmbeddingMsg | None) -> str: + async def enqueue(self, msg: Optional[EmbeddingMsg]) -> str: """Serialize EmbeddingMsg object and store in queue.""" if msg is None: logger.warning("Embedding message is None, skipping enqueuing") @@ -28,17 +28,26 @@ async def dequeue(self) -> Optional[EmbeddingMsg]: data_dict = await super().dequeue() if not data_dict: return None - if "data" in data_dict and isinstance(data_dict["data"], str): - try: - return EmbeddingMsg.from_json(data_dict["data"]) - except Exception as e: - logger.debug(f"[EmbeddingQueue] Failed to parse message data: {e}") - return None + if "data" in data_dict: + if isinstance(data_dict["data"], str): + try: + return EmbeddingMsg.from_json(data_dict["data"]) + except Exception as e: + logger.debug(f"[EmbeddingQueue] Failed to parse message data: {e}") + return None + elif isinstance(data_dict["data"], dict): + try: + return EmbeddingMsg.from_dict(data_dict["data"]) + except Exception as e: + 
logger.debug( + f"[EmbeddingQueue] Failed to create EmbeddingMsg from data dict: {e}" + ) + return None + # Otherwise try to convert directly from dict try: return EmbeddingMsg.from_dict(data_dict) - except Exception as e: - logger.debug(f"[EmbeddingQueue] Failed to create EmbeddingMsg from dict: {e}") + except Exception: return None async def peek(self) -> Optional[EmbeddingMsg]: @@ -47,11 +56,17 @@ async def peek(self) -> Optional[EmbeddingMsg]: if not data_dict: return None - if "data" in data_dict and isinstance(data_dict["data"], str): - try: - return EmbeddingMsg.from_json(data_dict["data"]) - except Exception: - return None + if "data" in data_dict: + if isinstance(data_dict["data"], str): + try: + return EmbeddingMsg.from_json(data_dict["data"]) + except Exception: + return None + elif isinstance(data_dict["data"], dict): + try: + return EmbeddingMsg.from_dict(data_dict["data"]) + except Exception: + return None try: return EmbeddingMsg.from_dict(data_dict) diff --git a/openviking/storage/vectordb/collection/vikingdb_clients.py b/openviking/storage/vectordb/collection/vikingdb_clients.py new file mode 100644 index 00000000..287d039a --- /dev/null +++ b/openviking/storage/vectordb/collection/vikingdb_clients.py @@ -0,0 +1,100 @@ +# Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd. 
class VikingDBClient:
    """HTTP client for a privately deployed VikingDB service.

    Every request goes to ``host`` with JSON ``Accept`` and ``Content-Type``
    headers; the caller-supplied ``headers`` are merged on top of those
    defaults and therefore win on conflict.
    """

    def __init__(self, host: str, headers: Optional[Dict[str, str]] = None):
        """Initialize VikingDB client.

        Args:
            host: VikingDB service host (e.g., "http://localhost:8080").
                Trailing slashes are stripped before it is stored.
            headers: Custom headers merged into every request.

        Raises:
            ValueError: If ``host`` is missing, not a string, or empty once
                trailing slashes are removed.
        """
        # Validate before normalising: calling None.rstrip() would surface as
        # an AttributeError instead of the intended "Host is required" error.
        normalized = host.rstrip("/") if isinstance(host, str) else ""
        if not normalized:
            raise ValueError("Host is required for VikingDBClient")

        self.host = normalized
        self.headers = headers or {}

    def do_req(
        self,
        method: str,
        path: str = "/",
        req_params: Optional[Dict[str, Any]] = None,
        req_body: Optional[Dict[str, Any]] = None,
        timeout: Optional[float] = None,
    ) -> "requests.Response":
        """Perform an HTTP request against the VikingDB service.

        Args:
            method: HTTP method (GET, POST, etc.).
            path: Request path; a leading "/" is added when missing.
            req_params: Query parameters.
            req_body: Request body, serialized with ``json.dumps``. ``None``
                sends no body.
            timeout: Request timeout in seconds. ``None`` (the default) uses
                the module-level ``DEFAULT_TIMEOUT``.

        Returns:
            The ``requests.Response`` object. The HTTP status is not checked
            here; callers inspect ``status_code`` themselves.

        Raises:
            Exception: Whatever building or sending the request raises; it is
                logged and re-raised unchanged.
        """
        if not path.startswith("/"):
            path = "/" + path

        url = f"{self.host}{path}"
        headers = {
            "Accept": "application/json",
            "Content-Type": "application/json",
        }
        headers.update(self.headers)

        try:
            return requests.request(
                method=method,
                url=url,
                headers=headers,
                params=req_params,
                data=json.dumps(req_body) if req_body is not None else None,
                timeout=DEFAULT_TIMEOUT if timeout is None else timeout,
            )
        except Exception as e:
            logger.error(f"Request to {url} failed: {e}")
            # Bare raise keeps the original traceback intact.
            raise
class VikingDBCollection(ICollection):
    """VikingDB collection implementation for private deployment.

    Requests are sent through a ``VikingDBClient``. Console-style calls
    (collection and index metadata) are looked up in ``VIKINGDB_APIS`` and
    scoped with ``ProjectName``/``CollectionName``; data calls post to
    ``/api/vikingdb/data/...`` paths and are scoped with
    ``project``/``collection_name``. Creating, altering or dropping
    collections and indexes is not supported here: those methods raise
    ``NotImplementedError``.
    """

    def __init__(
        self,
        host: str,
        headers: Optional[Dict[str, str]] = None,
        meta_data: Optional[Dict[str, Any]] = None,
    ):
        """Bind the collection to a host.

        Args:
            host: VikingDB service host, passed to ``VikingDBClient``.
            headers: Custom headers, passed to ``VikingDBClient``.
            meta_data: Collection metadata. ``ProjectName`` (default
                ``"default"``) and ``CollectionName`` (default ``""``) are
                read from it.
        """
        super().__init__()
        self.client = VikingDBClient(host, headers)
        self.meta_data = meta_data if meta_data is not None else {}
        self.project_name = self.meta_data.get("ProjectName", "default")
        self.collection_name = self.meta_data.get("CollectionName", "")

    # ------------------------------------------------------------------
    # Request helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _unwrap(response: Any, label: str, keys: List[str]) -> Any:
        """Extract the payload section of a response.

        Returns the value stored under the first of ``keys`` present in the
        decoded JSON object. Returns ``{}`` when the status is not 200, the
        body is not valid JSON, the body is not a JSON object, or none of
        ``keys`` is present.
        """
        if response.status_code != 200:
            logger.error(f"Request to {label} failed: {response.text}")
            return {}
        try:
            payload = response.json()
        except ValueError as exc:
            # ValueError is the common base of the JSON decode errors that
            # response.json() can raise, including json.JSONDecodeError.
            logger.error(f"Request to {label} returned invalid JSON: {exc}")
            return {}
        if not isinstance(payload, dict):
            # Valid JSON that is not an object (list, null, ...) has no keys
            # to look up; treat it like an empty result instead of crashing.
            return {}
        for key in keys:
            if key in payload:
                return payload[key]
        return {}

    def _console_scope(self) -> Dict[str, Any]:
        """Return the project/collection selector used by console calls."""
        return {
            "ProjectName": self.project_name,
            "CollectionName": self.collection_name,
        }

    def _data_scope(self) -> Dict[str, Any]:
        """Return the project/collection selector used by data calls."""
        return {
            "project": self.project_name,
            "collection_name": self.collection_name,
        }

    def _console_post(self, data: Dict[str, Any], action: str):
        """Call a console action and return its ``Result`` (else ``data``) section."""
        path, method = VIKINGDB_APIS[action]
        response = self.client.do_req(method, path=path, req_body=data)
        return self._unwrap(response, action, ["Result", "data"])

    def _console_get(self, params: Optional[Dict[str, Any]], action: str):
        """Call a read-style console action and return its ``Result`` section."""
        if params is None:
            params = {}
        path, method = VIKINGDB_APIS[action]
        # Console GET actions are actually POSTs in VikingDB API, so the
        # parameters travel in the request body.
        response = self.client.do_req(method, path=path, req_body=params)
        return self._unwrap(response, action, ["Result"])

    def _data_post(self, path: str, data: Dict[str, Any]):
        """POST ``data`` to a data path and return the ``result`` section."""
        response = self.client.do_req("POST", path, req_body=data)
        return self._unwrap(response, path, ["result"])

    def _data_get(self, path: str, params: Dict[str, Any]):
        """GET a data path with query ``params`` and return the ``result`` section."""
        response = self.client.do_req("GET", path, req_params=params)
        return self._unwrap(response, path, ["result"])

    # ------------------------------------------------------------------
    # Collection and index metadata
    # ------------------------------------------------------------------

    def update(self, fields: Optional[Dict[str, Any]] = None, description: Optional[str] = None):
        """Update the collection's fields and/or description.

        ``fields`` is sent only when truthy; ``description`` is sent whenever
        it is not ``None`` (so an empty string is sent).
        """
        data = self._console_scope()
        if fields:
            data["Fields"] = fields
        if description is not None:
            data["Description"] = description

        return self._console_post(data, action="UpdateVikingdbCollection")

    def get_meta_data(self):
        """Fetch the collection metadata from the server."""
        return self._console_get(self._console_scope(), action="GetVikingdbCollection")

    def close(self):
        """No-op."""
        pass

    def drop(self):
        """Not supported; always raises ``NotImplementedError``."""
        raise NotImplementedError("collection should be managed manually")

    def create_index(self, index_name: str, meta_data: Dict[str, Any]):
        """Not supported; always raises ``NotImplementedError``."""
        raise NotImplementedError("index should be pre-created")

    def has_index(self, index_name: str):
        """Return whether ``index_name`` is in the listing from ``list_indexes``.

        NOTE(review): membership is only tested when ``list_indexes`` returns
        a list. ``list_indexes`` returns whatever the server stored under
        ``Result``; if that is a dict, this always reports ``False`` -- confirm
        against the real ListIndex response shape.
        """
        indexes = self.list_indexes()
        return index_name in indexes if isinstance(indexes, list) else False

    def get_index(self, index_name: str):
        """Alias for ``get_index_meta_data``."""
        return self.get_index_meta_data(index_name)

    def list_indexes(self):
        """Return the ``Result`` section of the ListIndex call for this collection."""
        return self._console_get(self._console_scope(), action="ListVikingdbIndex")

    def update_index(
        self,
        index_name: str,
        scalar_index: Optional[Dict[str, Any]] = None,
        description: Optional[str] = None,
    ):
        """Not supported; always raises ``NotImplementedError``."""
        raise NotImplementedError("index should be managed manually")

    def get_index_meta_data(self, index_name: str):
        """Return the ``Result`` section of the GetIndex call for ``index_name``."""
        params = self._console_scope()
        params["IndexName"] = index_name
        return self._console_get(params, action="GetVikingdbIndex")

    def drop_index(self, index_name: str):
        """Not supported; always raises ``NotImplementedError``."""
        raise NotImplementedError("index should be managed manually")

    # ------------------------------------------------------------------
    # Data operations
    # ------------------------------------------------------------------

    def upsert_data(self, data_list: List[Dict[str, Any]], ttl: int = 0):
        """Upsert ``data_list`` into the collection with the given ``ttl``."""
        data = self._data_scope()
        data["data"] = data_list
        data["ttl"] = ttl
        return self._data_post("/api/vikingdb/data/upsert", data)

    def fetch_data(self, primary_keys: List[Any]) -> FetchDataInCollectionResult:
        """Fetch records by primary key and parse the response."""
        data = self._data_scope()
        data["ids"] = primary_keys
        resp_data = self._data_post("/api/vikingdb/data/fetch_in_collection", data)
        return self._parse_fetch_result(resp_data)

    def delete_data(self, primary_keys: List[Any]):
        """Delete the records with the given primary keys."""
        data = self._data_scope()
        data["ids"] = primary_keys
        return self._data_post("/api/vikingdb/data/delete", data)

    def delete_all_data(self):
        """Delete every record by posting ``del_all: True`` to the delete path."""
        data = self._data_scope()
        data["del_all"] = True
        return self._data_post("/api/vikingdb/data/delete", data)

    def _parse_fetch_result(self, data: Dict[str, Any]) -> FetchDataInCollectionResult:
        """Build a ``FetchDataInCollectionResult`` from a fetch response section.

        Only keys present in ``data`` are applied; a ``null`` list is treated
        as empty.
        """
        result = FetchDataInCollectionResult()
        if not isinstance(data, dict):
            return result
        if "fetch" in data:
            result.items = [
                DataItem(id=item.get("id"), fields=item.get("fields"))
                for item in (data["fetch"] or [])
            ]
        if "ids_not_exist" in data:
            result.ids_not_exist = data["ids_not_exist"] or []
        return result

    def _parse_search_result(self, data: Dict[str, Any]) -> SearchResult:
        """Build a ``SearchResult`` from a search response section.

        ``result.data`` is only set when ``data`` is a dict containing a
        ``data`` key; a ``null`` list is treated as empty.
        """
        result = SearchResult()
        if isinstance(data, dict) and "data" in data:
            result.data = [
                SearchItemResult(
                    id=item.get("id"),
                    fields=item.get("fields"),
                    score=item.get("score"),
                )
                for item in (data["data"] or [])
            ]
        return result

    def _search(
        self,
        endpoint: str,
        index_name: str,
        criteria: Dict[str, Any],
        limit: int,
        offset: int,
        filters: Optional[Dict[str, Any]],
        output_fields: Optional[List[str]],
    ) -> SearchResult:
        """Post a search request and parse the response.

        The payload is the data scope, ``index_name``, the search-specific
        ``criteria``, then ``filter``/``output_fields``/``limit``/``offset``,
        posted to ``/api/vikingdb/data/search/<endpoint>``.
        """
        payload = self._data_scope()
        payload["index_name"] = index_name
        payload.update(criteria)
        payload["filter"] = filters
        payload["output_fields"] = output_fields
        payload["limit"] = limit
        payload["offset"] = offset
        resp_data = self._data_post(f"/api/vikingdb/data/search/{endpoint}", payload)
        return self._parse_search_result(resp_data)

    def search_by_vector(
        self,
        index_name: str,
        dense_vector: Optional[List[float]] = None,
        limit: int = 10,
        offset: int = 0,
        filters: Optional[Dict[str, Any]] = None,
        sparse_vector: Optional[Dict[str, float]] = None,
        output_fields: Optional[List[str]] = None,
    ) -> SearchResult:
        """Search by dense vector; a missing ``sparse_vector`` is sent as ``{}``."""
        criteria = {
            "dense_vector": dense_vector,
            "sparse_vector": sparse_vector or {},
        }
        return self._search(
            "vector", index_name, criteria, limit, offset, filters, output_fields
        )

    def search_by_id(
        self,
        index_name: str,
        id: Any,
        limit: int = 10,
        offset: int = 0,
        filters: Optional[Dict[str, Any]] = None,
        output_fields: Optional[List[str]] = None,
    ) -> SearchResult:
        """Search using the record identified by ``id``."""
        return self._search(
            "id", index_name, {"id": id}, limit, offset, filters, output_fields
        )

    def search_by_multimodal(
        self,
        index_name: str,
        text: Optional[str] = None,
        image: Optional[Any] = None,
        video: Optional[Any] = None,
        limit: int = 10,
        offset: int = 0,
        filters: Optional[Dict[str, Any]] = None,
        output_fields: Optional[List[str]] = None,
    ) -> SearchResult:
        """Search with ``text``, ``image`` and/or ``video`` inputs."""
        criteria = {"text": text, "image": image, "video": video}
        return self._search(
            "multi_modal", index_name, criteria, limit, offset, filters, output_fields
        )

    def search_by_random(
        self,
        index_name: str,
        limit: int = 10,
        offset: int = 0,
        filters: Optional[Dict[str, Any]] = None,
        output_fields: Optional[List[str]] = None,
    ) -> SearchResult:
        """Call the ``random`` search endpoint (no search-specific criteria)."""
        return self._search(
            "random", index_name, {}, limit, offset, filters, output_fields
        )

    def search_by_keywords(
        self,
        index_name: str,
        keywords: Optional[List[str]] = None,
        query: Optional[str] = None,
        limit: int = 10,
        offset: int = 0,
        filters: Optional[Dict[str, Any]] = None,
        output_fields: Optional[List[str]] = None,
    ) -> SearchResult:
        """Search with ``keywords`` and/or a ``query`` string."""
        criteria = {"keywords": keywords, "query": query}
        return self._search(
            "keywords", index_name, criteria, limit, offset, filters, output_fields
        )

    def search_by_scalar(
        self,
        index_name: str,
        field: str,
        order: Optional[str] = "desc",
        limit: int = 10,
        offset: int = 0,
        filters: Optional[Dict[str, Any]] = None,
        output_fields: Optional[List[str]] = None,
    ) -> SearchResult:
        """Search on scalar ``field`` with the given ``order``."""
        criteria = {"field": field, "order": order}
        return self._search(
            "scalar", index_name, criteria, limit, offset, filters, output_fields
        )

    def aggregate_data(
        self,
        index_name: str,
        op: str = "count",
        field: Optional[str] = None,
        filters: Optional[Dict[str, Any]] = None,
        cond: Optional[Dict[str, Any]] = None,
    ) -> AggregateResult:
        """Run an aggregation and parse the response.

        NOTE(review): ``cond`` is accepted but never sent to the server.
        Where it belongs in the request is not visible from this code, so it
        is left out rather than guessed -- confirm against the aggregate API.
        """
        data = self._data_scope()
        data["index_name"] = index_name
        data["agg"] = {
            "op": op,
            "field": field,
        }
        data["filter"] = filters
        resp_data = self._data_post("/api/vikingdb/data/aggregate", data)
        return self._parse_aggregate_result(resp_data, op, field)

    def _parse_aggregate_result(
        self, data: Dict[str, Any], op: str, field: Optional[str]
    ) -> AggregateResult:
        """Build an ``AggregateResult``; ``agg`` is ``data["agg"]`` if present, else ``data``."""
        result = AggregateResult(op=op, field=field)
        if isinstance(data, dict):
            result.agg = data["agg"] if "agg" in data else data
        return result
def get_or_create_vikingdb_project(
    project_name: str = "default", config: Optional[Dict[str, Any]] = None
):
    """Get or create a VikingDB project for private deployment.

    Args:
        project_name: Project name.
        config: Configuration dict with keys:
            - Host: VikingDB service host (required, non-empty)
            - Headers: Custom headers for authentication/context (optional)

    Returns:
        VikingDBProject instance.

    Raises:
        ValueError: If ``config`` is ``None`` or has no truthy ``Host``.
    """
    if config is None:
        raise ValueError("config is required")

    host = config.get("Host")
    headers = config.get("Headers")

    if not host:
        raise ValueError("config must contain 'Host'")

    return VikingDBProject(host=host, headers=headers, project_name=project_name)


class VikingDBProject:
    """VikingDB project class for private deployment.

    Looks collections up on the server and wraps each one in a
    ``Collection`` backed by a ``VikingDBCollection``. Collections are never
    created here: ``create_collection`` always raises ``NotImplementedError``.
    """

    def __init__(
        self, host: str, headers: Optional[Dict[str, str]] = None, project_name: str = "default"
    ):
        """Initialize VikingDB project.

        Args:
            host: VikingDB service host.
            headers: Custom headers for requests.
            project_name: Project name.
        """
        self.host = host
        self.headers = headers
        self.project_name = project_name

        logger.info(f"Initialized VikingDB project: {project_name} with host {host}")

    def close(self):
        """Close project (no-op)."""
        pass

    def _call(self, action: str, data: Dict[str, Any]):
        """Send ``data`` to the console ``action`` and return the raw response.

        A fresh ``VikingDBClient`` is built per call from the stored host and
        headers.
        """
        client = VikingDBClient(self.host, self.headers)
        path, method = VIKINGDB_APIS[action]
        return client.do_req(method, path=path, req_body=data)

    def _describe_collection(self, collection_name: str):
        """Issue the GetCollection call for ``collection_name``."""
        data = {"ProjectName": self.project_name, "CollectionName": collection_name}
        return self._call("GetVikingdbCollection", data)

    def _wrap(self, meta_data: Dict[str, Any]) -> Collection:
        """Wrap server-side collection metadata in a ``Collection``."""
        return Collection(
            VikingDBCollection(host=self.host, headers=self.headers, meta_data=meta_data)
        )

    def has_collection(self, collection_name: str) -> bool:
        """Check if collection exists: True when GetCollection answers 200."""
        return self._describe_collection(collection_name).status_code == 200

    def get_collection(self, collection_name: str) -> Optional[Collection]:
        """Get collection by name by calling API.

        Returns ``None`` when the server does not answer 200, when the
        response carries no ``Result`` metadata, or when the response cannot
        be turned into a collection (the failure is logged).
        """
        response = self._describe_collection(collection_name)
        if response.status_code != 200:
            return None

        try:
            meta_data = response.json().get("Result", {})
            if not meta_data:
                return None
            return self._wrap(meta_data)
        except Exception as exc:
            # Best-effort lookup: still report "not found", but leave a trace
            # instead of swallowing the failure silently.
            logger.warning(f"Failed to load collection {collection_name}: {exc}")
            return None

    def _get_collections(self) -> List[Dict[str, Any]]:
        """Return the raw ``Collections`` entries listed by the server.

        Returns an empty list when the request fails, the response cannot be
        read, or the listing is missing or ``null``.
        """
        response = self._call("ListVikingdbCollection", {"ProjectName": self.project_name})
        if response.status_code != 200:
            logger.error(f"List collections failed: {response.text}")
            return []
        try:
            # "or" guards against explicit JSON nulls, which .get() defaults
            # do not cover.
            result = response.json().get("Result") or {}
            return result.get("Collections") or []
        except Exception as exc:
            logger.error(f"List collections returned an unreadable response: {exc}")
            return []

    def _named_collections(self) -> Dict[str, Dict[str, Any]]:
        """Map collection name to its raw entry, skipping unnamed or malformed entries."""
        named = {}
        for entry in self._get_collections():
            name = entry.get("CollectionName") if isinstance(entry, dict) else None
            if name:
                named[name] = entry
        return named

    def list_collections(self) -> List[str]:
        """List all collection names from server."""
        return list(self._named_collections())

    def get_collections(self) -> Dict[str, Collection]:
        """Get all collections from server, keyed by collection name."""
        return {name: self._wrap(entry) for name, entry in self._named_collections().items()}

    def create_collection(self, collection_name: str, meta_data: Dict[str, Any]) -> Collection:
        """Not supported: collection should be pre-created."""
        raise NotImplementedError("collection should be pre-created")

    def get_or_create_collection(
        self, collection_name: str, meta_data: Optional[Dict[str, Any]] = None
    ) -> Collection:
        """Get or create collection.

        Args:
            collection_name: Collection name.
            meta_data: Collection metadata (required if not exists).

        Returns:
            Collection instance when the collection already exists.

        Raises:
            ValueError: If the collection does not exist and ``meta_data`` is
                ``None``.
            NotImplementedError: If the collection does not exist and
                ``meta_data`` is given, because ``create_collection`` always
                raises.
        """
        collection = self.get_collection(collection_name)
        if collection:
            return collection

        if meta_data is None:
            raise ValueError(f"meta_data is required to create collection {collection_name}")

        return self.create_collection(collection_name, meta_data)

    def drop_collection(self, collection_name: str):
        """Drop specified collection.

        Logs a warning and returns when the collection does not exist;
        otherwise delegates to the collection's ``drop()``.
        """
        collection = self.get_collection(collection_name)
        if not collection:
            logger.warning(f"Collection {collection_name} does not exist")
            return

        collection.drop()
        logger.info(f"Dropped VikingDB collection: {collection_name}")
config.sparse_weight if config.backend == "volcengine": if not ( @@ -108,6 +109,24 @@ def __init__( logger.info( f"VectorDB backend initialized in Volcengine mode: region={volc_config['Region']}" ) + elif config.backend == "vikingdb": + if not config.vikingdb.host: + raise ValueError("VikingDB backend requires a valid host") + # VikingDB private deployment mode + self._mode = config.backend + viking_config = { + "Host": config.vikingdb.host, + "Headers": config.vikingdb.headers, + } + + from openviking.storage.vectordb.project.vikingdb_project import ( + get_or_create_vikingdb_project, + ) + + self.project = get_or_create_vikingdb_project( + project_name=self.DEFAULT_PROJECT_NAME, config=viking_config + ) + logger.info(f"VikingDB backend initialized in private mode: {config.vikingdb.host}") elif config.backend == "http": if not config.url: raise ValueError("HTTP backend requires a valid URL") @@ -242,15 +261,19 @@ async def create_collection(self, name: str, schema: Dict[str, Any]) -> bool: scalar_index_fields.append(field_name) # Create default index for the collection + use_sparse = self.sparse_weight > 0.0 index_meta = { "IndexName": self.DEFAULT_INDEX_NAME, "VectorIndex": { - "IndexType": "flat_hybrid", + "IndexType": "flat_hybrid" if use_sparse else "flat", "Distance": distance, "Quant": "int8", }, "ScalarIndex": scalar_index_fields, } + if use_sparse: + index_meta["VectorIndex"]["EnableSparse"] = True + index_meta["VectorIndex"]["SearchWithSparseLogitAlpha"] = self.sparse_weight logger.info(f"Creating index with meta: {index_meta}") collection.create_index(self.DEFAULT_INDEX_NAME, index_meta) diff --git a/openviking/sync_client.py b/openviking/sync_client.py index 0709e534..9a5da535 100644 --- a/openviking/sync_client.py +++ b/openviking/sync_client.py @@ -4,13 +4,13 @@ Synchronous OpenViking client implementation. 
""" -import asyncio from typing import TYPE_CHECKING, Any, Dict, List, Optional if TYPE_CHECKING: from openviking.session import Session from openviking.async_client import AsyncOpenViking +from openviking.utils import run_async class SyncOpenViking: @@ -25,7 +25,7 @@ def __init__(self, **kwargs): def initialize(self) -> None: """Initialize OpenViking storage and indexes.""" - asyncio.run(self._async_client.initialize()) + run_async(self._async_client.initialize()) self._initialized = True def session(self, session_id: Optional[str] = None) -> "Session": @@ -42,7 +42,7 @@ def add_resource( timeout: float = None, ) -> Dict[str, Any]: """Add resource to OpenViking (resources scope only)""" - return asyncio.run( + return run_async( self._async_client.add_resource(path, target, reason, instruction, wait, timeout) ) @@ -53,7 +53,7 @@ def add_skill( timeout: float = None, ) -> Dict[str, Any]: """Add skill to OpenViking.""" - return asyncio.run(self._async_client.add_skill(data, wait=wait, timeout=timeout)) + return run_async(self._async_client.add_skill(data, wait=wait, timeout=timeout)) def search( self, @@ -65,7 +65,7 @@ def search( filter: Optional[Dict] = None, ): """Execute complex retrieval (intent analysis, hierarchical retrieval).""" - return asyncio.run( + return run_async( self._async_client.search(query, target_uri, session, limit, score_threshold, filter) ) @@ -77,19 +77,19 @@ def find( score_threshold: Optional[float] = None, ): """Quick retrieval""" - return asyncio.run(self._async_client.find(query, target_uri, limit, score_threshold)) + return run_async(self._async_client.find(query, target_uri, limit, score_threshold)) def abstract(self, uri: str) -> str: """Read L0 abstract""" - return asyncio.run(self._async_client.abstract(uri)) + return run_async(self._async_client.abstract(uri)) def overview(self, uri: str) -> str: """Read L1 overview""" - return asyncio.run(self._async_client.overview(uri)) + return run_async(self._async_client.overview(uri)) def 
read(self, uri: str) -> str: """Read file""" - return asyncio.run(self._async_client.read(uri)) + return run_async(self._async_client.read(uri)) def ls(self, uri: str, **kwargs) -> List[Any]: """ @@ -100,65 +100,65 @@ def ls(self, uri: str, **kwargs) -> List[Any]: simple: Return only relative path list (bool, default: False) recursive: List all subdirectories recursively (bool, default: False) """ - return asyncio.run(self._async_client.ls(uri, **kwargs)) + return run_async(self._async_client.ls(uri, **kwargs)) def link(self, from_uri: str, uris: Any, reason: str = "") -> None: """Create relation""" - return asyncio.run(self._async_client.link(from_uri, uris, reason)) + return run_async(self._async_client.link(from_uri, uris, reason)) def unlink(self, from_uri: str, uri: str) -> None: """Delete relation""" - return asyncio.run(self._async_client.unlink(from_uri, uri)) + return run_async(self._async_client.unlink(from_uri, uri)) def export_ovpack(self, uri: str, to: str) -> str: """Export .ovpack file""" - return asyncio.run(self._async_client.export_ovpack(uri, to)) + return run_async(self._async_client.export_ovpack(uri, to)) def import_ovpack( self, file_path: str, target: str, force: bool = False, vectorize: bool = True ) -> str: """Import .ovpack file (triggers vectorization by default)""" - return asyncio.run(self._async_client.import_ovpack(file_path, target, force, vectorize)) + return run_async(self._async_client.import_ovpack(file_path, target, force, vectorize)) def close(self) -> None: """Close OpenViking and release resources.""" - return asyncio.run(self._async_client.close()) + return run_async(self._async_client.close()) def relations(self, uri: str) -> List[Dict[str, Any]]: """Get relations""" - return asyncio.run(self._async_client.relations(uri)) + return run_async(self._async_client.relations(uri)) def rm(self, uri: str, recursive: bool = False) -> None: """Delete resource""" - return asyncio.run(self._async_client.rm(uri, recursive)) + return 
run_async(self._async_client.rm(uri, recursive)) def wait_processed(self, timeout: float = None) -> None: """Wait for all async operations to complete""" - return asyncio.run(self._async_client.wait_processed(timeout)) + return run_async(self._async_client.wait_processed(timeout)) def grep(self, uri: str, pattern: str, case_insensitive: bool = False) -> Dict: """Content search""" - return asyncio.run(self._async_client.grep(uri, pattern, case_insensitive)) + return run_async(self._async_client.grep(uri, pattern, case_insensitive)) def glob(self, pattern: str, uri: str = "viking://") -> Dict: """File pattern matching""" - return asyncio.run(self._async_client.glob(pattern, uri)) + return run_async(self._async_client.glob(pattern, uri)) def mv(self, from_uri: str, to_uri: str) -> None: """Move resource""" - return asyncio.run(self._async_client.mv(from_uri, to_uri)) + return run_async(self._async_client.mv(from_uri, to_uri)) def tree(self, uri: str) -> Dict: """Get directory tree""" - return asyncio.run(self._async_client.tree(uri)) + return run_async(self._async_client.tree(uri)) def stat(self, uri: str) -> Dict: """Get resource status""" - return asyncio.run(self._async_client.stat(uri)) + return run_async(self._async_client.stat(uri)) def mkdir(self, uri: str) -> None: """Create directory""" - return asyncio.run(self._async_client.mkdir(uri)) + return run_async(self._async_client.mkdir(uri)) def get_status(self): """Get system status. 
@@ -196,4 +196,4 @@ def _session_compressor(self): @classmethod def reset(cls) -> None: """Reset singleton (for testing).""" - return asyncio.run(AsyncOpenViking.reset()) + return run_async(AsyncOpenViking.reset()) diff --git a/openviking/utils/__init__.py b/openviking/utils/__init__.py index 0bb8fbfe..939b8fd1 100644 --- a/openviking/utils/__init__.py +++ b/openviking/utils/__init__.py @@ -2,6 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 """Utility functions and helpers.""" +from openviking.utils.async_utils import run_async from openviking.utils.llm import StructuredLLM, parse_json_from_response, parse_json_to_model from openviking.utils.logger import default_logger, get_logger from openviking.utils.uri import VikingURI @@ -13,4 +14,5 @@ "StructuredLLM", "parse_json_from_response", "parse_json_to_model", + "run_async", ] diff --git a/openviking/utils/async_utils.py b/openviking/utils/async_utils.py new file mode 100644 index 00000000..8e47034e --- /dev/null +++ b/openviking/utils/async_utils.py @@ -0,0 +1,35 @@ +# Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd. +# SPDX-License-Identifier: Apache-2.0 +""" +Async helper utilities for running coroutines from sync code. +""" + +import asyncio +from typing import Coroutine, TypeVar + +T = TypeVar("T") + + +def run_async(coro: Coroutine[None, None, T]) -> T: + """ + Run async coroutine from sync code, handling nested event loops. + + This function safely runs a coroutine whether or not there's already + a running event loop (e.g., when called from within an MCP server). 
+ + Args: + coro: The coroutine to run + + Returns: + The result of the coroutine + """ + try: + loop = asyncio.get_running_loop() + # Already in event loop, use nest_asyncio to allow nested calls + import nest_asyncio + + nest_asyncio.apply() + return loop.run_until_complete(coro) + except RuntimeError: + # No running event loop, use asyncio.run() + return asyncio.run(coro) diff --git a/openviking/utils/config/agfs_config.py b/openviking/utils/config/agfs_config.py index 454b2355..f7b963eb 100644 --- a/openviking/utils/config/agfs_config.py +++ b/openviking/utils/config/agfs_config.py @@ -5,63 +5,88 @@ from pydantic import BaseModel, Field, model_validator -class AGFSConfig(BaseModel): - """Configuration for AGFS (Agent Global File System).""" - - path: str = Field(default="./data", description="AGFS data storage path") - - port: int = Field(default=8080, description="AGFS service port") - - log_level: str = Field(default="warn", description="AGFS log level") +class S3Config(BaseModel): + """Configuration for S3 backend.""" - url: Optional[str] = Field( - default="http://localhost:8080", description="AGFS service URL for service mode" - ) + bucket: Optional[str] = Field(default=None, description="S3 bucket name") - backend: str = Field( - default="local", description="AGFS storage backend: 'local' | 's3' | 'memory'" - ) - - timeout: int = Field(default=10, description="AGFS request timeout (seconds)") - - retry_times: int = Field(default=3, description="AGFS retry times on failure") - - # S3 backend configuration - # These settings are used when backend is set to 's3'. - # AGFS will act as a gateway to the specified S3 bucket. 
- s3_bucket: Optional[str] = Field(default=None, description="S3 bucket name") - - s3_region: Optional[str] = Field( + region: Optional[str] = Field( default=None, description="AWS region where the bucket is located (e.g., us-east-1, cn-beijing)", ) - s3_access_key: Optional[str] = Field( + access_key: Optional[str] = Field( default=None, description="S3 access key ID. If not provided, AGFS may attempt to use environment variables or IAM roles.", ) - s3_secret_key: Optional[str] = Field( + secret_key: Optional[str] = Field( default=None, description="S3 secret access key corresponding to the access key ID.", ) - s3_endpoint: Optional[str] = Field( + endpoint: Optional[str] = Field( default=None, description="Custom S3 endpoint URL. Required for S3-compatible services like MinIO or LocalStack. " "Leave empty for standard AWS S3.", ) - s3_prefix: Optional[str] = Field( + prefix: Optional[str] = Field( default="", description="Optional key prefix for namespace isolation. All objects will be stored under this prefix.", ) - s3_use_ssl: bool = Field( + use_ssl: bool = Field( default=True, description="Enable/Disable SSL (HTTPS) for S3 connections. 
Set to False for local testing without HTTPS.", ) + def validate_config(self): + """Validate S3 configuration completeness""" + missing = [] + if not self.bucket: + missing.append("bucket") + if not self.endpoint: + missing.append("endpoint") + if not self.region: + missing.append("region") + if not self.access_key: + missing.append("access_key") + if not self.secret_key: + missing.append("secret_key") + + if missing: + raise ValueError(f"S3 backend requires the following fields: {', '.join(missing)}") + + return self + + +class AGFSConfig(BaseModel): + """Configuration for AGFS (Agent Global File System).""" + + path: str = Field(default="./data", description="AGFS data storage path") + + port: int = Field(default=8080, description="AGFS service port") + + log_level: str = Field(default="warn", description="AGFS log level") + + url: Optional[str] = Field( + default="http://localhost:8080", description="AGFS service URL for service mode" + ) + + backend: str = Field( + default="local", description="AGFS storage backend: 'local' | 's3' | 'memory'" + ) + + timeout: int = Field(default=10, description="AGFS request timeout (seconds)") + + retry_times: int = Field(default=3, description="AGFS retry times on failure") + + # S3 backend configuration + # These settings are used when backend is set to 's3'. + # AGFS will act as a gateway to the specified S3 bucket. 
+ s3: S3Config = Field(default_factory=lambda: S3Config(), description="S3 backend configuration") + @model_validator(mode="after") def validate_config(self): """Validate configuration completeness and consistency""" @@ -75,21 +100,7 @@ def validate_config(self): raise ValueError("AGFS local backend requires 'path' to be set") elif self.backend == "s3": - missing = [] - if not self.s3_bucket: - missing.append("s3_bucket") - if not self.s3_endpoint: - missing.append("s3_endpoint") - if not self.s3_region: - missing.append("s3_region") - if not self.s3_access_key: - missing.append("s3_access_key") - if not self.s3_secret_key: - missing.append("s3_secret_key") - - if missing: - raise ValueError( - f"AGFS S3 backend requires the following fields: {', '.join(missing)}" - ) + # Validate S3 configuration + self.s3.validate_config() return self diff --git a/openviking/utils/config/open_viking_config.py b/openviking/utils/config/open_viking_config.py index ada70a23..c851b579 100644 --- a/openviking/utils/config/open_viking_config.py +++ b/openviking/utils/config/open_viking_config.py @@ -292,7 +292,7 @@ def initialize_openviking_config( # Embedded mode: local storage config.storage.agfs.backend = config.storage.agfs.backend or "local" config.storage.agfs.path = path - config.storage.vectordb.backend = "local" + config.storage.vectordb.backend = config.storage.vectordb.backend or "local" config.storage.vectordb.path = path elif vectordb_url and agfs_url: # Service mode: remote services @@ -301,6 +301,10 @@ def initialize_openviking_config( config.storage.vectordb.backend = "http" config.storage.vectordb.url = vectordb_url + # Ensure vector dimension is synced if not set in storage + if config.storage.vectordb.dimension == 0: + config.storage.vectordb.dimension = config.embedding.dimension + # Validate configuration if not is_valid_openviking_config(config): raise ValueError("Invalid OpenViking configuration") diff --git a/openviking/utils/config/vectordb_config.py 
b/openviking/utils/config/vectordb_config.py index dc32c7e4..8786abca 100644 --- a/openviking/utils/config/vectordb_config.py +++ b/openviking/utils/config/vectordb_config.py @@ -1,6 +1,6 @@ # Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd. # SPDX-License-Identifier: Apache-2.0 -from typing import Optional +from typing import Dict, Optional from pydantic import BaseModel, Field, model_validator @@ -16,6 +16,15 @@ class VolcengineConfig(BaseModel): host: Optional[str] = Field(default=None, description="Volcengine VikingDB host (optional)") +class VikingDBConfig(BaseModel): + """Configuration for VikingDB private deployment.""" + + host: Optional[str] = Field(default=None, description="VikingDB service host") + headers: Optional[Dict[str, str]] = Field( + default_factory=dict, description="Custom headers for requests" + ) + + class VectorDBBackendConfig(BaseModel): """ Configuration for VectorDB backend. @@ -43,22 +52,36 @@ class VectorDBBackendConfig(BaseModel): description="Distance metric for vector similarity search (e.g., 'cosine', 'l2', 'ip')", ) - vector_dim: int = Field( + dimension: int = Field( default=0, description="Dimension of vector embeddings", ) + sparse_weight: float = Field( + default=0.0, + description=( + "Sparse weight for hybrid vector search. " + "When > 0, sparse vectors are used for index build and search." 
+ ), + ) + volcengine: Optional[VolcengineConfig] = Field( default_factory=lambda: VolcengineConfig(), description="Volcengine VikingDB configuration for 'volcengine' type", ) + # VikingDB private deployment mode + vikingdb: Optional[VikingDBConfig] = Field( + default_factory=lambda: VikingDBConfig(), + description="VikingDB private deployment configuration for 'vikingdb' type", + ) + @model_validator(mode="after") def validate_config(self): """Validate configuration completeness and consistency""" - if self.backend not in ["local", "http", "volcengine"]: + if self.backend not in ["local", "http", "volcengine", "vikingdb"]: raise ValueError( - f"Invalid VectorDB backend: '{self.backend}'. Must be one of: 'local', 'http', 'volcengine'" + f"Invalid VectorDB backend: '{self.backend}'. Must be one of: 'local', 'http', 'volcengine', 'vikingdb'" ) if self.backend == "local": @@ -75,4 +98,8 @@ def validate_config(self): if not self.volcengine.region: raise ValueError("VectorDB volcengine backend requires 'region' to be set") + elif self.backend == "vikingdb": + if not self.vikingdb or not self.vikingdb.host: + raise ValueError("VectorDB vikingdb backend requires 'host' to be set") + return self diff --git a/src/index/detail/meta/vector_index_meta.cpp b/src/index/detail/meta/vector_index_meta.cpp index 238b1d11..a182bde0 100644 --- a/src/index/detail/meta/vector_index_meta.cpp +++ b/src/index/detail/meta/vector_index_meta.cpp @@ -48,10 +48,6 @@ int VectorIndexMeta::init_from_json(const JsonValue& json) { search_with_sparse_logit_alpha = json["SearchWithSparseLogitAlpha"].GetFloat(); } - if (json.HasMember("IndexWithSparseLogitAlpha")) { - index_with_sparse_logit_alpha = - json["IndexWithSparseLogitAlpha"].GetFloat(); - } return 0; } @@ -102,9 +98,7 @@ int VectorIndexMeta::save_to_json(JsonPrettyWriter& writer) { writer.Bool(enable_sparse); writer.Key("SearchWithSparseLogitAlpha"); writer.Double(search_with_sparse_logit_alpha); - writer.Key("IndexWithSparseLogitAlpha"); - 
writer.Double(index_with_sparse_logit_alpha); return 0; } -} // namespace vectordb \ No newline at end of file +} // namespace vectordb diff --git a/src/index/detail/vector/common/bruteforce.h b/src/index/detail/vector/common/bruteforce.h index 7069f555..2a065acf 100644 --- a/src/index/detail/vector/common/bruteforce.h +++ b/src/index/detail/vector/common/bruteforce.h @@ -178,6 +178,11 @@ class BruteforceSearch { std::vector& scores) const { if (!query_data) return; + if (k == 0) { + labels.clear(); + scores.clear(); + return; + } if (current_count_ == 0) return; @@ -204,17 +209,8 @@ class BruteforceSearch { continue; } - float dist = dist_func(encoded_query.data(), ptr, dist_params); - - if (sparse_index_) { - float sparse_dist = 0; - if (query_sparse_view) { - sparse_dist = - sparse_index_->sparse_head_output(*query_sparse_view, i); - } - dist = dist * (1 - meta_->search_with_sparse_logit_alpha) + - sparse_dist * meta_->search_with_sparse_logit_alpha; - } + float dist = compute_score(encoded_query.data(), ptr, query_sparse_view, + i, dist_func, dist_params); uint64_t label; std::memcpy(&label, ptr + vector_byte_size_, sizeof(uint64_t)); @@ -307,6 +303,7 @@ class BruteforceSearch { private: void setup_metric() { + reverse_query_score_ = (meta_->distance_type == "l2"); if (meta_->quantization_type == "int8") { if (meta_->distance_type == "l2") space_ = std::make_unique(meta_->dimension); @@ -357,6 +354,26 @@ class BruteforceSearch { return view; } + float compute_score( + const void* encoded_query, const char* data_ptr, + const std::shared_ptr& query_sparse_view, + size_t idx, MetricFunc dist_func, + void* dist_params) const { + float dense_raw = dist_func(encoded_query, data_ptr, dist_params); + float dense_score = + reverse_query_score_ ? 
(1.0f - dense_raw) : dense_raw; + if (!sparse_index_ || !query_sparse_view || + meta_->search_with_sparse_logit_alpha <= 0.0f) { + return dense_score; + } + float sparse_raw = + sparse_index_->sparse_head_output(*query_sparse_view, idx); + float sparse_score = + reverse_query_score_ ? (1.0f - sparse_raw) : sparse_raw; + float alpha = meta_->search_with_sparse_logit_alpha; + return dense_score * (1.0f - alpha) + sparse_score * alpha; + } + std::shared_ptr meta_; char* data_buffer_ = nullptr; size_t capacity_ = 0; @@ -370,6 +387,7 @@ class BruteforceSearch { std::unique_ptr> space_; std::unique_ptr quantizer_; std::unique_ptr sparse_index_; + bool reverse_query_score_ = false; }; } // namespace vectordb diff --git a/src/index/detail/vector/common/space_int8.h b/src/index/detail/vector/common/space_int8.h index 9e15acf5..36dda0e9 100644 --- a/src/index/detail/vector/common/space_int8.h +++ b/src/index/detail/vector/common/space_int8.h @@ -153,7 +153,7 @@ static float l2_distance_int8(const void* v1, const void* v2, float real_ip = static_cast(ip) * scale1 * scale2; float dist = norm_sq1 + norm_sq2 - 2.0f * real_ip; - return 1.0f - std::max(0.0f, dist); + return std::max(0.0f, dist); } class InnerProductSpaceInt8 : public VectorSpace { diff --git a/src/index/detail/vector/common/space_l2.h b/src/index/detail/vector/common/space_l2.h index 7b0e50b4..a6bef28d 100644 --- a/src/index/detail/vector/common/space_l2.h +++ b/src/index/detail/vector/common/space_l2.h @@ -18,7 +18,7 @@ static float l2_sqr_ref(const void* v1, const void* v2, const void* params) { float diff = pv1[i] - pv2[i]; res += diff * diff; } - return 1.0f - res; + return res; } #if defined(OV_SIMD_AVX512) @@ -46,7 +46,7 @@ static float l2_sqr_avx512(const void* v1, const void* v2, const void* params) { res += diff * diff; } - return 1.0f - res; + return res; } #endif @@ -84,7 +84,7 @@ static float l2_sqr_avx(const void* v1, const void* v2, const void* params) { res += diff * diff; } - return 1.0f - res; + 
return res; } #endif @@ -117,7 +117,7 @@ static float l2_sqr_sse(const void* v1, const void* v2, const void* params) { res += diff * diff; } - return 1.0f - res; + return res; } #endif diff --git a/src/index/detail/vector/sparse_retrieval/sparse_data_holder.h b/src/index/detail/vector/sparse_retrieval/sparse_data_holder.h index 382fcd73..03720bc1 100644 --- a/src/index/detail/vector/sparse_retrieval/sparse_data_holder.h +++ b/src/index/detail/vector/sparse_retrieval/sparse_data_holder.h @@ -142,7 +142,7 @@ class SparseDataHolder { float sparse_head_dot_product_logit(const SparseDatapointView& x, const DocID docid) { - return 1.0f - sparse_holder_.sparse_dot_product_reduce(x, docid); + return sparse_holder_.sparse_dot_product_reduce(x, docid); } size_t rows() { diff --git a/tests/client/test_resource_management.py b/tests/client/test_resource_management.py index c82fd8a4..6fac382f 100644 --- a/tests/client/test_resource_management.py +++ b/tests/client/test_resource_management.py @@ -41,9 +41,9 @@ async def test_add_resource_without_wait( ) assert "root_uri" in result - # In async mode, status can be monitored via observers - observers = client.observers - assert "queue" in observers + # In async mode, status can be monitored via observer + observer = client.observer + assert observer.queue is not None async def test_add_resource_with_target( self, client: AsyncOpenViking, sample_markdown_file: Path diff --git a/tests/misc/test_config_validation.py b/tests/misc/test_config_validation.py index 35e50b0b..0e5686e7 100644 --- a/tests/misc/test_config_validation.py +++ b/tests/misc/test_config_validation.py @@ -5,7 +5,7 @@ import sys -from openviking.utils.config.agfs_config import AGFSConfig +from openviking.utils.config.agfs_config import AGFSConfig, S3Config from openviking.utils.config.embedding_config import EmbeddingConfig, EmbeddingModelConfig from openviking.utils.config.vectordb_config import VectorDBBackendConfig from openviking.utils.config.vlm_config import 
VLMConfig @@ -46,10 +46,13 @@ def test_agfs_validation(): try: config = AGFSConfig( backend="s3", - s3_bucket="my-bucket", - s3_region="us-west-1", - s3_access_key="fake-access-key-for-testing", - s3_secret_key="fake-secret-key-for-testing-12345", + s3=S3Config( + bucket="my-bucket", + region="us-west-1", + access_key="fake-access-key-for-testing", + secret_key="fake-secret-key-for-testing-12345", + endpoint="https://s3.amazonaws.com" + ), ) print(" Pass") except ValueError as e: diff --git a/tests/misc/test_vikingdb_observer.py b/tests/misc/test_vikingdb_observer.py index 0cfbfa16..35318fc5 100644 --- a/tests/misc/test_vikingdb_observer.py +++ b/tests/misc/test_vikingdb_observer.py @@ -24,43 +24,44 @@ async def test_vikingdb_observer(): # Test observer access print("\n1. Test observer access:") - print(f"Available observers: {list(client.observers.keys())}") + print(f"Observer service: {client.observer}") # Test QueueObserver print("\n2. Test QueueObserver:") - queue_observer = client.observers["queue"] - print(f"Type: {type(queue_observer)}") - print(f"Is healthy: {queue_observer.is_healthy()}") - print(f"Has errors: {queue_observer.has_errors()}") + queue_status = client.observer.queue + print(f"Type: {type(queue_status)}") + print(f"Is healthy: {queue_status.is_healthy}") + print(f"Has errors: {queue_status.has_errors}") # Test direct print print("\n3. Test direct print QueueObserver:") - print(queue_observer) + print(queue_status) # Test VikingDBObserver print("\n4. Test VikingDBObserver:") - vikingdb_observer = client.observers["vikingdb"] - print(f"Type: {type(vikingdb_observer)}") - print(f"Is healthy: {vikingdb_observer.is_healthy()}") - print(f"Has errors: {vikingdb_observer.has_errors()}") + vikingdb_status = client.observer.vikingdb + print(f"Type: {type(vikingdb_status)}") + print(f"Is healthy: {vikingdb_status.is_healthy}") + print(f"Has errors: {vikingdb_status.has_errors}") # Test direct print print("\n5. 
Test direct print VikingDBObserver:") - print(vikingdb_observer) - - # Test get status table - print("\n6. Test get status table:") - status_table = vikingdb_observer.get_status_table() - print(f"Status table type: {type(status_table)}") - print(f"Status table length: {len(status_table)}") - - # Test observer properties - print("\n7. Test observer properties:") - for name, observer in client.observers.items(): + print(vikingdb_status) + + # Test status string + print("\n6. Test status string:") + print(f"Status type: {type(vikingdb_status.status)}") + print(f"Status length: {len(vikingdb_status.status)}") + + # Test system status + print("\n7. Test system status:") + system_status = client.observer.system + print(f"System is_healthy: {system_status.is_healthy}") + for name, component in system_status.components.items(): print(f"\n{name}:") - print(f" is_healthy: {observer.is_healthy()}") - print(f" has_errors: {observer.has_errors()}") - print(f" str(observer): {str(observer)[:100]}...") + print(f" is_healthy: {component.is_healthy}") + print(f" has_errors: {component.has_errors}") + print(f" status: {component.status[:100]}...") print("\n=== All tests completed ===") @@ -80,7 +81,7 @@ def test_sync_client(): """Test sync client""" print("\n=== Test sync client ===") - client = ov.OpenViking(path="./test_data_sync") + client = ov.OpenViking(path="./test_data") try: # Initialize @@ -88,15 +89,15 @@ def test_sync_client(): print("Sync client initialized successfully") # Test observer access - print(f"Available observers: {list(client.observers.keys())}") + print(f"Observer service: {client.observer}") # Test QueueObserver print("\nQueueObserver status:") - print(client.observers["queue"]) + print(client.observer.queue) # Test VikingDBObserver print("\nVikingDBObserver status:") - print(client.observers["vikingdb"]) + print(client.observer.vikingdb) print("\n=== Sync client test completed ===") diff --git a/tests/vectordb/test_recall.py b/tests/vectordb/test_recall.py 
index cdc5c39e..b3230069 100644 --- a/tests/vectordb/test_recall.py +++ b/tests/vectordb/test_recall.py @@ -236,6 +236,189 @@ def test_ip_recall_topk(self): ) print("✓ IP Recall verified") + def test_search_limit_zero(self): + """Test search with limit=0 returns empty result without error""" + print("\n=== Test: Search limit=0 ===") + + dim = 8 + meta_data = { + "CollectionName": "test_limit_zero", + "Fields": [ + {"FieldName": "id", "FieldType": "int64", "IsPrimaryKey": True}, + {"FieldName": "vector", "FieldType": "vector", "Dim": dim}, + ], + } + + collection = self.register_collection( + get_or_create_local_collection(meta_data=meta_data, path=TEST_DB_PATH) + ) + + data = [{"id": 0, "vector": [0.1] * dim}, {"id": 1, "vector": [0.2] * dim}] + collection.upsert_data(data) + + collection.create_index( + "idx_limit_zero", + { + "IndexName": "idx_limit_zero", + "VectorIndex": {"IndexType": "flat", "Distance": "l2"}, + }, + ) + + result = collection.search_by_vector("idx_limit_zero", dense_vector=[0.1] * dim, limit=0) + + self.assertEqual(len(result.data), 0, "limit=0 should return empty results") + print("✓ limit=0 returns empty results") + + def test_sparse_vector_recall(self): + """Test sparse vector recall in hybrid index""" + print("\n=== Test: Sparse Vector Recall ===") + + dim = 4 + meta_data = { + "CollectionName": "test_sparse_recall", + "Fields": [ + {"FieldName": "id", "FieldType": "int64", "IsPrimaryKey": True}, + {"FieldName": "vector", "FieldType": "vector", "Dim": dim}, + {"FieldName": "sparse_vector", "FieldType": "sparse_vector"}, + ], + } + + collection = self.register_collection( + get_or_create_local_collection(meta_data=meta_data, path=TEST_DB_PATH) + ) + + dense_vec = [0.1] * dim + data = [ + {"id": 0, "vector": dense_vec, "sparse_vector": {"t1": 1.0}}, + {"id": 1, "vector": dense_vec, "sparse_vector": {"t1": 0.5}}, + {"id": 2, "vector": dense_vec, "sparse_vector": {"t2": 1.0}}, + ] + collection.upsert_data(data) + + collection.create_index( + 
"idx_sparse", + { + "IndexName": "idx_sparse", + "VectorIndex": { + "IndexType": "flat_hybrid", + "Distance": "ip", + "SearchWithSparseLogitAlpha": 1.0, + }, + }, + ) + + result = collection.search_by_vector( + "idx_sparse", + dense_vector=dense_vec, + sparse_vector={"t1": 1.0}, + limit=3, + ) + result_ids = [item.id for item in result.data] + + self.assertEqual(result_ids, [0, 1, 2], "Sparse ranking should match dot product order") + print("✓ Sparse vector recall verified", result) + + def test_sparse_vector_recall_l2(self): + """Test sparse vector recall with L2 distance in hybrid index""" + print("\n=== Test: Sparse Vector Recall (L2) ===") + + dim = 4 + meta_data = { + "CollectionName": "test_sparse_recall_l2", + "Fields": [ + {"FieldName": "id", "FieldType": "int64", "IsPrimaryKey": True}, + {"FieldName": "vector", "FieldType": "vector", "Dim": dim}, + {"FieldName": "sparse_vector", "FieldType": "sparse_vector"}, + ], + } + + collection = self.register_collection( + get_or_create_local_collection(meta_data=meta_data, path=TEST_DB_PATH) + ) + + dense_vec = [0.1] * dim + data = [ + {"id": 0, "vector": dense_vec, "sparse_vector": {"t1": 1.0}}, + {"id": 1, "vector": dense_vec, "sparse_vector": {"t1": 0.5}}, + {"id": 2, "vector": dense_vec, "sparse_vector": {"t2": 1.0}}, + ] + collection.upsert_data(data) + + collection.create_index( + "idx_sparse_l2", + { + "IndexName": "idx_sparse_l2", + "VectorIndex": { + "IndexType": "flat_hybrid", + "Distance": "l2", + "SearchWithSparseLogitAlpha": 1.0, + }, + }, + ) + + result = collection.search_by_vector( + "idx_sparse_l2", + dense_vector=dense_vec, + sparse_vector={"t1": 1.0}, + limit=3, + ) + result_ids = [item.id for item in result.data] + + self.assertEqual(result_ids, [0, 1, 2], "Sparse L2 ranking should favor closest match") + print("✓ Sparse vector recall (L2) verified", result) + + def test_hybrid_dense_sparse_mix(self): + """Test hybrid scoring combines dense and sparse signals""" + print("\n=== Test: Hybrid 
Dense+Sparse Mix ===") + + dim = 4 + meta_data = { + "CollectionName": "test_hybrid_mix", + "Fields": [ + {"FieldName": "id", "FieldType": "int64", "IsPrimaryKey": True}, + {"FieldName": "vector", "FieldType": "vector", "Dim": dim}, + {"FieldName": "sparse_vector", "FieldType": "sparse_vector"}, + ], + } + + collection = self.register_collection( + get_or_create_local_collection(meta_data=meta_data, path=TEST_DB_PATH) + ) + + data = [ + {"id": 0, "vector": [0.9, 0.0, 0.0, 0.0], "sparse_vector": {"t1": 0.1}}, + {"id": 1, "vector": [0.2, 0.0, 0.0, 0.0], "sparse_vector": {"t1": 1.0}}, + {"id": 2, "vector": [0.1, 0.0, 0.0, 0.0], "sparse_vector": {"t1": 0.8}}, + ] + collection.upsert_data(data) + + collection.create_index( + "idx_hybrid_mix", + { + "IndexName": "idx_hybrid_mix", + "VectorIndex": { + "IndexType": "flat_hybrid", + "Distance": "ip", + "SearchWithSparseLogitAlpha": 0.5, + }, + }, + ) + + result = collection.search_by_vector( + "idx_hybrid_mix", + dense_vector=[1.0, 0.0, 0.0, 0.0], + sparse_vector={"t1": 1.0}, + limit=3, + ) + result_ids = [item.id for item in result.data] + + self.assertEqual( + result_ids, + [1, 0, 2], + "Hybrid ranking should reflect combined dense and sparse scores", + ) + print("✓ Hybrid dense+sparse mix verified") + if __name__ == "__main__": unittest.main() diff --git a/tests/vectordb/test_vikingdb_project.py b/tests/vectordb/test_vikingdb_project.py new file mode 100644 index 00000000..14f5867a --- /dev/null +++ b/tests/vectordb/test_vikingdb_project.py @@ -0,0 +1,96 @@ +# Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd. 
+# SPDX-License-Identifier: Apache-2.0 +import json +import sys +import unittest + +from openviking.storage.vectordb.collection.collection import Collection +from openviking.storage.vectordb.collection.vikingdb_collection import VikingDBCollection +from openviking.storage.vectordb.project.vikingdb_project import get_or_create_vikingdb_project + + +@unittest.skipUnless(sys.platform == "darwin", "Only run on macOS") +class TestVikingDBProject(unittest.TestCase): + """ + Unit tests for VikingDB Project and Collection implementation for private deployment. + """ + + def setUp(self): + self.config = { + "Host": "http://localhost:8080", + "Headers": { + "X-Top-Account-Id": "1", + "X-Top-User-Id": "1000", + "X-Top-IdentityName": "test-user", + "X-Top-Role-Id": "data", + }, + } + self.project_name = "test_project" + meta_data = { + "Fields": [ + {"FieldName": "id", "FieldType": "string", "IsPrimaryKey": True}, + {"FieldName": "vector", "FieldType": "vector", "Dim": 128}, + {"FieldName": "text", "FieldType": "string"}, + ] + } + self.meta_data = meta_data + + def test_create_vikingdb_project(self): + """Test project initialization.""" + project = get_or_create_vikingdb_project(self.project_name, self.config) + self.assertEqual(project.project_name, self.project_name) + self.assertEqual(project.host, self.config["Host"]) + self.assertEqual(project.headers, self.config["Headers"]) + + def test_create_collection(self): + """Test collection creation with custom headers.""" + project = get_or_create_vikingdb_project(self.project_name, self.config) + meta_data = self.meta_data + + collection = project.create_collection("test_coll", meta_data) + + self.assertIsNotNone(collection) + self.assertIn("test_coll", project.list_collections()) + + def test_upsert_data(self): + """Test data upsert with custom headers and path.""" + project = get_or_create_vikingdb_project(self.project_name, self.config) + + # Get existing or create new collection + meta_data = self.meta_data + collection = 
project.get_or_create_collection("test_coll", meta_data) + + data = [{"id": "1", "vector": [0.1] * 128, "text": "123"}] + res = collection.upsert_data(data) + self.assertIsNone(res) + + def test_fetch_data(self): + """Test data fetching.""" + project = get_or_create_vikingdb_project(self.project_name, self.config) + + collection = project.get_or_create_collection("test_coll", self.meta_data) + + # Upsert some data first to fetch it + data = [{"id": "1", "vector": [0.1] * 128, "text": "hello"}] + collection.upsert_data(data) + + result = collection.fetch_data(["1"]) + + self.assertEqual(len(result.items), 1) + self.assertEqual(result.items[0].id, "1") + self.assertEqual(result.items[0].fields["text"], "hello") + + def test_drop_collection(self): + """Test collection dropping.""" + project = get_or_create_vikingdb_project(self.project_name, self.config) + + collection = project.get_or_create_collection("test_coll", self.meta_data) + if not collection: + assert False, "Collection should exist after creation" + + collection.drop() + collection = project.get_collection("test_coll") + self.assertIsNone(collection) + +if __name__ == "__main__": + unittest.main()