From af201496d6117959ff35cd0fe9707ad1a1dd99a1 Mon Sep 17 00:00:00 2001 From: yepper Date: Thu, 12 Mar 2026 09:29:03 +0800 Subject: [PATCH 1/5] feat(resource): implement incremental update with COW pattern Add support for incremental updates using copy-on-write pattern. Key changes include: - Add ResourceLockManager for managing concurrent updates - Introduce EmbeddingTaskTracker to track embedding task completion - Modify TreeBuilder to support temp URIs and skip conflict resolution - Update SemanticDagExecutor to handle incremental updates - Extend memory extractor/compressor to work with temp URIs - Add exists() method to VikingFS for URI existence checks - Update Context to include temp_uri field --- openviking/core/context.py | 4 + openviking/parse/parsers/constants.py | 3 + openviking/parse/parsers/directory.py | 5 + openviking/parse/parsers/upload_utils.py | 1 + openviking/parse/tree_builder.py | 101 +-- openviking/resource/__init__.py | 16 + openviking/resource/resource_lock.py | 414 ++++++++++++ openviking/server/models.py | 1 + openviking/service/core.py | 16 +- openviking/session/compressor.py | 149 ++++- openviking/session/memory_deduplicator.py | 39 +- openviking/session/memory_extractor.py | 88 ++- openviking/session/session.py | 357 +++++++++-- openviking/storage/collection_schemas.py | 8 + openviking/storage/queuefs/__init__.py | 2 + openviking/storage/queuefs/embedding_msg.py | 12 +- .../storage/queuefs/embedding_tracker.py | 181 ++++++ openviking/storage/queuefs/semantic_dag.py | 174 ++++- openviking/storage/queuefs/semantic_msg.py | 8 +- .../storage/queuefs/semantic_processor.py | 596 +++++++++++++----- openviking/storage/viking_fs.py | 15 + openviking/utils/embedding_utils.py | 5 + openviking/utils/resource_processor.py | 7 +- openviking/utils/summarizer.py | 13 +- openviking_cli/exceptions.py | 8 + 25 files changed, 1825 insertions(+), 398 deletions(-) create mode 100644 openviking/resource/__init__.py create mode 100644 openviking/resource/resource_lock.py create mode 100644 openviking/storage/queuefs/embedding_tracker.py diff --git a/openviking/core/context.py b/openviking/core/context.py index 94d47b1f..76308570 100644 --- a/openviking/core/context.py +++ b/openviking/core/context.py @@ -56,6 +56,7 @@ def __init__( self, uri: str, parent_uri: Optional[str] = None, + temp_uri: Optional[str] = None, is_leaf: bool = False, abstract: str = "", context_type: Optional[str] = None, @@ -78,6 +79,7 @@ def __init__( self.id = id or str(uuid4()) self.uri = uri self.parent_uri = parent_uri + self.temp_uri = temp_uri self.is_leaf = is_leaf self.abstract = abstract self.context_type = context_type or self._derive_context_type() @@ -159,6 +161,7 @@ def to_dict(self) -> Dict[str, Any]: "id": self.id, "uri": self.uri, "parent_uri": self.parent_uri, + "temp_uri": self.temp_uri, "is_leaf": self.is_leaf, "abstract": self.abstract, "context_type": self.context_type, @@ -194,6 +197,7 @@ def from_dict(cls, data: Dict[str, Any]) -> "Context": obj = cls( uri=data["uri"], parent_uri=data.get("parent_uri"), + temp_uri=data.get("temp_uri"), is_leaf=data.get("is_leaf", False), abstract=data.get("abstract", ""), context_type=data.get("context_type"), diff --git a/openviking/parse/parsers/constants.py b/openviking/parse/parsers/constants.py index 311e545a..843665e1 100644 --- a/openviking/parse/parsers/constants.py +++ b/openviking/parse/parsers/constants.py @@ -174,6 +174,8 @@ ".graphql", ".gql", ".prisma", + ".thrift", + ".conf" } # Documentation file extensions for file type detection @@ -224,6 
+226,7 @@ ".yarnrc", ".env", ".env.example", + ".jsonl", } # Common text encodings to try for encoding detection (in order of likelihood) diff --git a/openviking/parse/parsers/directory.py b/openviking/parse/parsers/directory.py index da27b659..371cf0f9 100644 --- a/openviking/parse/parsers/directory.py +++ b/openviking/parse/parsers/directory.py @@ -123,6 +123,11 @@ async def parse( viking_fs = self._get_viking_fs() temp_uri = self._create_temp_uri() target_uri = f"{temp_uri}/{dir_name}" + logger.info( + f"Scanning directory: {source_path}, " + f"processable files: {len(processable_files)}, " + f"warnings: {warnings}" + ) await viking_fs.mkdir(temp_uri, exist_ok=True) await viking_fs.mkdir(target_uri, exist_ok=True) diff --git a/openviking/parse/parsers/upload_utils.py b/openviking/parse/parsers/upload_utils.py index 50c80ed0..d1870173 100644 --- a/openviking/parse/parsers/upload_utils.py +++ b/openviking/parse/parsers/upload_utils.py @@ -40,6 +40,7 @@ "NEWS", "NOTICE", "TODO", + "BUILD", } diff --git a/openviking/parse/tree_builder.py b/openviking/parse/tree_builder.py index c8409f41..87684ccf 100644 --- a/openviking/parse/tree_builder.py +++ b/openviking/parse/tree_builder.py @@ -138,16 +138,6 @@ async def finalize_from_temp( base_uri = parent_uri or auto_base_uri # 3. Determine candidate_uri if to_uri: - # Exact target URI: must not exist yet - try: - await viking_fs.stat(to_uri, ctx=ctx) - # If we get here, it already exists - raise FileExistsError(f"Target URI already exists: {to_uri}") - except FileExistsError: - raise - except Exception: - # It doesn't exist, good to use - pass candidate_uri = to_uri else: if parent_uri: @@ -160,34 +150,7 @@ async def finalize_from_temp( raise ValueError(f"Parent URI is not a directory: {parent_uri}") candidate_uri = VikingURI(base_uri).join(final_doc_name).uri - final_uri = await self._resolve_unique_uri(candidate_uri, ctx=ctx) - - if final_uri != candidate_uri: - logger.info(f"[TreeBuilder] Resolved name conflict: {candidate_uri} -> {final_uri}") - else: - logger.info(f"[TreeBuilder] Finalizing from temp: {final_uri}") - - # 4. Move directory tree from temp to final location in AGFS - await self._move_temp_to_dest(viking_fs, temp_doc_uri, final_uri, ctx=ctx) - logger.info(f"[TreeBuilder] Moved temp tree: {temp_doc_uri} -> {final_uri}") - - # 5. Cleanup temporary root directory - try: - await viking_fs.delete_temp(temp_uri, ctx=ctx) - logger.info(f"[TreeBuilder] Cleaned up temp root: {temp_uri}") - except Exception as e: - logger.warning(f"[TreeBuilder] Failed to cleanup temp root: {e}") - - # 6. Enqueue to SemanticQueue for async semantic generation - if trigger_semantic: - try: - await self._enqueue_semantic_generation(final_uri, "resource", ctx=ctx) - logger.info(f"[TreeBuilder] Enqueued semantic generation for: {final_uri}") - except Exception as e: - logger.error( - f"[TreeBuilder] Failed to enqueue semantic generation: {e}", exc_info=True - ) - + final_uri = candidate_uri # 7. 
Return simple BuildingTree (no scanning needed) tree = BuildingTree( source_path=source_path, @@ -196,38 +159,12 @@ async def finalize_from_temp( tree._root_uri = final_uri # Create a minimal Context object for the root so that tree.root is not None - root_context = Context(uri=final_uri) + root_context = Context(uri=final_uri, temp_uri=temp_doc_uri) tree.add_context(root_context) return tree - async def _resolve_unique_uri( - self, uri: str, max_attempts: int = 100, ctx: Optional[RequestContext] = None - ) -> str: - """Return a URI that does not collide with an existing resource. - - If *uri* is free, return it unchanged. Otherwise append ``_1``, - ``_2``, … until a free name is found (like macOS Finder / Windows - Explorer). - """ - viking_fs = get_viking_fs() - - async def _exists(u: str) -> bool: - try: - await viking_fs.stat(u, ctx=ctx) - return True - except Exception: - return False - - if not await _exists(uri): - return uri - - for i in range(1, max_attempts + 1): - candidate = f"{uri}_{i}" - if not await _exists(candidate): - return candidate - - raise FileExistsError(f"Cannot resolve unique name for {uri} after {max_attempts} attempts") + async def _move_temp_to_dest( self, viking_fs, src_uri: str, dst_uri: str, ctx: RequestContext @@ -261,7 +198,7 @@ async def _ensure_parent_dirs(self, uri: str, ctx: RequestContext) -> None: logger.debug(f"Parent dir {parent_uri} may already exist: {e}") async def _enqueue_semantic_generation( - self, uri: str, context_type: str, ctx: RequestContext + self, uri: str, final_uri: str, context_type: str, ctx: RequestContext ) -> None: """ Enqueue a directory for semantic generation. @@ -284,32 +221,6 @@ async def _enqueue_semantic_generation( user_id=ctx.user.user_id, agent_id=ctx.user.agent_id, role=ctx.role.value, + target_uri=final_uri, ) - await semantic_queue.enqueue(msg) - - async def _load_content(self, uri: str, content_type: str) -> str: - """Helper to load content with proper type handling""" - import json - - if content_type == "abstract": - result = await get_viking_fs().abstract(uri) - elif content_type == "overview": - result = await get_viking_fs().overview(uri) - elif content_type == "detail": - result = await get_viking_fs().read_file(uri) - else: - return "" - - # Handle different return types - if isinstance(result, str): - return result - elif isinstance(result, bytes): - return result.decode("utf-8") - elif hasattr(result, "to_dict") and not isinstance(result, list): - # Handle FindResult by converting to dict (skip lists) - return str(result.to_dict()) - elif isinstance(result, list): - # Handle list results - return json.dumps(result) - else: - return str(result) + await semantic_queue.enqueue(msg) \ No newline at end of file diff --git a/openviking/resource/__init__.py b/openviking/resource/__init__.py new file mode 100644 index 00000000..a6f8e73c --- /dev/null +++ b/openviking/resource/__init__.py @@ -0,0 +1,16 @@ +# Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd. 
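The `temp_uri` field added to `Context` above is what lets downstream stages find the staged copy once `finalize_from_temp` stops moving trees eagerly. A quick round-trip sketch (URIs hypothetical):

```python
from openviking.core.context import Context

root = Context(
    uri="viking://default/resources/my-repo",         # final destination
    temp_uri="viking://temp/parse/ab12cd34/my-repo",  # staged COW copy (hypothetical)
)
# temp_uri survives serialization, so later stages (e.g. the Semantic DAG)
# can locate the staged tree from a persisted Context.
assert Context.from_dict(root.to_dict()).temp_uri == root.temp_uri
```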
+# SPDX-License-Identifier: Apache-2.0
+"""Resource management modules for incremental updates."""
+
+from openviking.resource.resource_lock import (
+    ResourceLockManager,
+    ResourceLockConflictError,
+    ResourceLockError,
+)
+
+__all__ = [
+    "ResourceLockManager",
+    "ResourceLockConflictError",
+    "ResourceLockError",
+]
diff --git a/openviking/resource/resource_lock.py b/openviking/resource/resource_lock.py
new file mode 100644
index 00000000..2f8f12ad
--- /dev/null
+++ b/openviking/resource/resource_lock.py
@@ -0,0 +1,414 @@
+# Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd.
+# SPDX-License-Identifier: Apache-2.0
+"""
+Resource-level mutex lock management.
+
+Implements resource URI-level mutual exclusion to prevent concurrent operations
+on the same resource. Uses file-based locks stored in the AGFS filesystem.
+"""
+
+import json
+import time
+import uuid
+from contextlib import contextmanager
+from dataclasses import dataclass, field
+from typing import Any, Dict, Optional
+
+from openviking_cli.utils import get_logger
+
+logger = get_logger(__name__)
+
+
+@dataclass
+class LockInfo:
+    """Lock metadata stored in lock file."""
+
+    lock_id: str
+    resource_uri: str
+    operation: str
+    created_at: float
+    expires_at: Optional[float] = None
+    metadata: Dict[str, Any] = field(default_factory=dict)
+
+    def to_dict(self) -> Dict[str, Any]:
+        return {
+            "lock_id": self.lock_id,
+            "resource_uri": self.resource_uri,
+            "operation": self.operation,
+            "created_at": self.created_at,
+            "expires_at": self.expires_at,
+            "metadata": self.metadata,
+        }
+
+    @classmethod
+    def from_dict(cls, data: Dict[str, Any]) -> "LockInfo":
+        return cls(**data)
+
+    def is_expired(self) -> bool:
+        if self.expires_at is None:
+            return False
+        return time.time() > self.expires_at
+
+
+class ResourceLockError(Exception):
+    """Base exception for resource lock errors."""
+
+
+class ResourceLockConflictError(ResourceLockError):
+    """Raised when attempting to lock a resource that is already locked."""
+
+    def __init__(self, resource_uri: str, lock_info: Optional[LockInfo] = None):
+        self.resource_uri = resource_uri
+        self.lock_info = lock_info
+        message = f"Resource '{resource_uri}' is locked"
+        if lock_info:
+            message += f" by operation '{lock_info.operation}' (lock_id: {lock_info.lock_id})"
+        super().__init__(message)
+
+
+class ResourceLockManager:
+    """
+    Manages resource-level mutex locks using file-based storage.
+
+    Lock files are stored under the `.locks/` directory in the AGFS root.
+    Each lock file is named after the resource URI (with path separators replaced).
+
+    Features:
+    - Best-effort lock acquisition via lock-file creation (check-then-write,
+      so not strictly atomic under concurrent acquirers)
+    - Lock expiration detection
+    - Automatic cleanup of expired locks
+    - Service restart cleanup
+    """
+
+    LOCK_DIR = ".locks"
+    LOCK_FILE_SUFFIX = ".lock"
+    DEFAULT_TTL = 3600
+    AGFS_MOUNT_PATH = "/local"
+
+    def __init__(self, agfs: Any, default_ttl: Optional[int] = None):
+        """
+        Initialize ResourceLockManager.
+
+        Args:
+            agfs: AGFS client instance
+            default_ttl: Default lock TTL in seconds (default: 3600)
+        """
+        self._agfs = agfs
+        self._default_ttl = default_ttl or self.DEFAULT_TTL
+        self._lock_dir_path = f"{self.AGFS_MOUNT_PATH}/{self.LOCK_DIR}"
+
+    def _get_lock_file_path(self, resource_uri: str) -> str:
+        """
+        Get lock file path for a resource URI.
+ + Args: + resource_uri: Resource URI (e.g., "viking://default/resources/my-repo") + + Returns: + Lock file path (e.g., "/local/.locks/viking___default___resources___my-repo.lock") + """ + safe_uri = resource_uri.replace("://", "___").replace("/", "___").replace(".", "_") + return f"{self._lock_dir_path}/{safe_uri}{self.LOCK_FILE_SUFFIX}" + + def _ensure_lock_dir(self) -> None: + """Ensure lock directory exists.""" + try: + if not self.exists(self._lock_dir_path): + self._agfs.mkdir(self._lock_dir_path) + logger.info(f"Created lock directory: {self._lock_dir_path}") + except Exception as e: + logger.warning(f"Failed to ensure lock directory: {e}") + + def acquire_lock( + self, + resource_uri: str, + operation: str, + ttl: Optional[int] = None, + metadata: Optional[Dict[str, Any]] = None, + ) -> LockInfo: + """ + Acquire a lock on a resource URI. + + Args: + resource_uri: Resource URI to lock + operation: Operation name (e.g., "incremental_update", "full_update") + ttl: Lock TTL in seconds (default: use default_ttl) + metadata: Additional metadata to store with lock + + Returns: + LockInfo for the acquired lock + + Raises: + ResourceLockConflictError: If resource is already locked + """ + self._ensure_lock_dir() + + lock_file = self._get_lock_file_path(resource_uri) + ttl = ttl or self._default_ttl + + current_time = time.time() + lock_info = LockInfo( + lock_id=str(uuid.uuid4()), + resource_uri=resource_uri, + operation=operation, + created_at=current_time, + expires_at=current_time + ttl if ttl > 0 else None, + metadata=metadata or {}, + ) + + try: + if self.exists(lock_file): + existing_lock = self._read_lock(lock_file) + if existing_lock and not existing_lock.is_expired(): + logger.warning( + f"Lock conflict: resource={resource_uri}, " + f"existing_lock_id={existing_lock.lock_id}, " + f"operation={existing_lock.operation}" + ) + raise ResourceLockConflictError(resource_uri, existing_lock) + + logger.info(f"Removing expired lock: {lock_file}") + self._agfs.rm(lock_file) + + self._agfs.write(lock_file, json.dumps(lock_info.to_dict()).encode('utf-8')) + + logger.info( + f"Acquired lock: resource={resource_uri}, " + f"lock_id={lock_info.lock_id}, " + f"operation={operation}, " + f"ttl={ttl}s" + ) + + return lock_info + + except ResourceLockConflictError: + raise + except Exception as e: + logger.error(f"Failed to acquire lock for {resource_uri}: {e}") + raise ResourceLockError(f"Failed to acquire lock: {e}") from e + + def release_lock(self, resource_uri: str, lock_id: Optional[str] = None) -> bool: + """ + Release a lock on a resource URI. + + Args: + resource_uri: Resource URI to unlock + lock_id: Optional lock ID to verify ownership + + Returns: + True if lock was released, False if lock didn't exist + """ + lock_file = self._get_lock_file_path(resource_uri) + + try: + if not self.exists(lock_file): + return False + + if lock_id: + existing_lock = self._read_lock(lock_file) + if existing_lock and existing_lock.lock_id != lock_id: + logger.warning( + f"Lock ID mismatch: expected={lock_id}, " + f"actual={existing_lock.lock_id}" + ) + return False + + self._agfs.rm(lock_file) + logger.info(f"Released lock: resource={resource_uri}, lock_id={lock_id}") + return True + + except Exception as e: + logger.error(f"Failed to release lock for {resource_uri}: {e}") + return False + + def is_locked(self, resource_uri: str) -> bool: + """ + Check if a resource URI is locked. 
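Taken together, `acquire_lock`/`release_lock` are meant to bracket an update; a minimal usage sketch (the AGFS client and the resource URI are hypothetical):

```python
from openviking.resource.resource_lock import (
    ResourceLockConflictError,
    ResourceLockManager,
)

def guarded_update(agfs_client) -> None:
    manager = ResourceLockManager(agfs=agfs_client, default_ttl=600)
    uri = "viking://default/resources/my-repo"  # hypothetical resource
    try:
        lock_info = manager.acquire_lock(uri, operation="incremental_update")
    except ResourceLockConflictError as e:
        # Another operation holds the lock; likely surfaced to clients as
        # HTTP 409 via the new "CONFLICT" status code added below.
        print(f"resource busy, try later: {e}")
        return
    try:
        ...  # mutate the resource while holding the lock
    finally:
        # Passing lock_id guards against releasing a lock we no longer own
        # (e.g. ours expired and another operation re-acquired it).
        manager.release_lock(uri, lock_id=lock_info.lock_id)
```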
+ + Args: + resource_uri: Resource URI to check + + Returns: + True if resource is locked, False otherwise + """ + lock_file = self._get_lock_file_path(resource_uri) + + try: + if not self.exists(lock_file): + return False + + lock_info = self._read_lock(lock_file) + if not lock_info: + return False + + if lock_info.is_expired(): + logger.info(f"Found expired lock: {lock_file}") + self._agfs.rm(lock_file) + return False + + return True + + except Exception as e: + logger.error(f"Failed to check lock for {resource_uri}: {e}") + return False + + def get_lock_info(self, resource_uri: str) -> Optional[LockInfo]: + """ + Get lock information for a resource URI. + + Args: + resource_uri: Resource URI to check + + Returns: + LockInfo if resource is locked, None otherwise + """ + lock_file = self._get_lock_file_path(resource_uri) + + try: + if not self.exists(lock_file): + return None + + lock_info = self._read_lock(lock_file) + if not lock_info: + return None + + if lock_info.is_expired(): + logger.info(f"Found expired lock: {lock_file}") + self._agfs.rm(lock_file) + return None + + return lock_info + + except Exception as e: + logger.error(f"Failed to get lock info for {resource_uri}: {e}") + return None + + def _read_lock(self, lock_file: str) -> Optional[LockInfo]: + """Read lock information from a lock file.""" + try: + data = self._agfs.read(lock_file) + lock_dict = json.loads(data.decode('utf-8')) + return LockInfo.from_dict(lock_dict) + except Exception as e: + logger.error(f"Failed to read lock file {lock_file}: {e}") + return None + + def cleanup_expired_locks(self) -> int: + """ + Clean up all expired locks. + + Returns: + Number of locks cleaned up + """ + cleaned = 0 + + try: + if not self.exists(self._lock_dir_path): + return 0 + + lock_files = self._agfs.ls(self._lock_dir_path) + + for file_info in lock_files: + lock_file = file_info.get("name", "") + if not lock_file.endswith(self.LOCK_FILE_SUFFIX): + continue + + lock_path = f"{self._lock_dir_path}/{lock_file}" + lock_info = self._read_lock(lock_path) + + if not lock_info or lock_info.is_expired(): + self._agfs.rm(lock_path) + cleaned += 1 + logger.info(f"Cleaned up expired lock: {lock_path}") + + if cleaned > 0: + logger.info(f"Cleaned up {cleaned} expired locks") + + return cleaned + + except Exception as e: + logger.error(f"Failed to cleanup expired locks: {e}") + return cleaned + + def cleanup_all_locks(self) -> int: + """ + Clean up all locks (for service restart). + + Returns: + Number of locks cleaned up + """ + cleaned = 0 + + try: + if not self.exists(self._lock_dir_path): + return 0 + + lock_files = self._agfs.ls(self._lock_dir_path) + + for file_info in lock_files: + lock_file = file_info.get("name", "") + if not lock_file.endswith(self.LOCK_FILE_SUFFIX): + continue + + lock_path = f"{self._lock_dir_path}/{lock_file}" + self._agfs.rm(lock_path) + cleaned += 1 + + if cleaned > 0: + logger.info(f"Cleaned up {cleaned} locks on service restart") + + return cleaned + + except Exception as e: + logger.error(f"Failed to cleanup all locks: {e}") + return cleaned + + def exists(self, uri: str) -> bool: + """ + Check if a URI exists using AGFS stat interface. 
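Beyond the startup sweep wired into `service/core.py` below, `cleanup_expired_locks` also suits a periodic janitor; an optional sketch (interval arbitrary, `manager` is a `ResourceLockManager`):

```python
import asyncio

async def lock_janitor(manager, interval_s: float = 300.0) -> None:
    """Periodically drop locks whose TTL has elapsed."""
    while True:
        # cleanup_expired_locks() is synchronous; keep it off the event loop.
        removed = await asyncio.to_thread(manager.cleanup_expired_locks)
        if removed:
            print(f"lock janitor removed {removed} expired locks")
        await asyncio.sleep(interval_s)
```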
+ + Args: + uri: URI to check (e.g., "viking://default/resources/my-repo") + + Returns: + True if URI exists, False otherwise + """ + try: + self._agfs.stat(uri) + return True + except Exception as e: + return False + + @contextmanager + def lock( + self, + resource_uri: str, + operation: str, + ttl: Optional[int] = None, + metadata: Optional[Dict[str, Any]] = None, + ): + """ + Context manager for acquiring and releasing a lock. + + Args: + resource_uri: Resource URI to lock + operation: Operation name + ttl: Lock TTL in seconds + metadata: Additional metadata + + Yields: + LockInfo for the acquired lock + + Raises: + ResourceLockConflictError: If resource is already locked + """ + lock_info = self.acquire_lock(resource_uri, operation, ttl, metadata) + try: + yield lock_info + finally: + self.release_lock(resource_uri, lock_info.lock_id) diff --git a/openviking/server/models.py b/openviking/server/models.py index 4cb7f967..26b6e5ec 100644 --- a/openviking/server/models.py +++ b/openviking/server/models.py @@ -39,6 +39,7 @@ class Response(BaseModel): "INVALID_URI": 400, "NOT_FOUND": 404, "ALREADY_EXISTS": 409, + "CONFLICT": 409, "PERMISSION_DENIED": 403, "UNAUTHENTICATED": 401, "RESOURCE_EXHAUSTED": 429, diff --git a/openviking/service/core.py b/openviking/service/core.py index 5764fed1..9ab1de9a 100644 --- a/openviking/service/core.py +++ b/openviking/service/core.py @@ -11,6 +11,7 @@ from openviking.agfs_manager import AGFSManager from openviking.core.directories import DirectoryInitializer +from openviking.resource.resource_lock import ResourceLockManager from openviking.server.identity import RequestContext, Role from openviking.service.debug_service import DebugService from openviking.service.fs_service import FSService @@ -77,6 +78,7 @@ def __init__( self._session_compressor: Optional[SessionCompressor] = None self._transaction_manager: Optional[TransactionManager] = None self._directory_initializer: Optional[DirectoryInitializer] = None + self._lock_manager: Optional[ResourceLockManager] = None # Sub-services self._fs_service = FSService() @@ -142,6 +144,11 @@ def _init_storage( # Initialize TransactionManager self._transaction_manager = init_transaction_manager(agfs=self._agfs_client) + + # Initialize ResourceLockManager + if self._agfs_client: + self._lock_manager = ResourceLockManager(agfs=self._agfs_client) + logger.info("ResourceLockManager initialized") @property def _agfs(self) -> Any: @@ -254,8 +261,15 @@ async def initialize(self) -> None: user_count, ) + # Clean up all locks on service startup + if self._lock_manager: + cleaned_count = self._lock_manager.cleanup_all_locks() + logger.info(f"Cleaned up {cleaned_count} locks on service startup") + # Initialize processors - self._resource_processor = ResourceProcessor(vikingdb=self._vikingdb_manager) + self._resource_processor = ResourceProcessor( + vikingdb=self._vikingdb_manager, + ) self._skill_processor = SkillProcessor(vikingdb=self._vikingdb_manager) self._session_compressor = SessionCompressor(vikingdb=self._vikingdb_manager) diff --git a/openviking/session/compressor.py b/openviking/session/compressor.py index 3f7eb432..31c75710 100644 --- a/openviking/session/compressor.py +++ b/openviking/session/compressor.py @@ -78,16 +78,62 @@ async def _index_memory(self, memory: Context, ctx: RequestContext) -> bool: await self.extractor._enqueue_semantic_for_parent(memory.uri, ctx) return True + def _convert_to_temp_uri( + self, + target_uri: str, + user_temp_uri: Optional[str], + agent_temp_uri: Optional[str] + ) -> str: + 
"""Convert target URI to temp URI for COW pattern. + + Args: + target_uri: Target URI (e.g., viking://user/... or viking://agent/...) + user_temp_uri: Temp user URI (if available) + agent_temp_uri: Temp agent URI (if available) + + Returns: + Converted temp URI, or original URI if no temp available + """ + if not user_temp_uri and not agent_temp_uri: + return target_uri + + # Convert user URI + if target_uri.startswith("viking://user/") and user_temp_uri: + # viking://user/{user_space}/memories/... -> {user_temp_uri}/memories/... + parts = target_uri.split("/") + if len(parts) >= 5: + # parts[0]="viking:", parts[1]="", parts[2]="user", parts[3]="{user_space}", parts[4:]="memories/..." + rest = "/".join(parts[4:]) + return f"{user_temp_uri}/{rest}" + + # Convert agent URI + if target_uri.startswith("viking://agent/") and agent_temp_uri: + # viking://agent/{agent_space}/memories/... -> {agent_temp_uri}/memories/... + parts = target_uri.split("/") + if len(parts) >= 5: + # parts[0]="viking:", parts[1]="", parts[2]="agent", parts[3]="{agent_space}", parts[4:]="memories/..." + rest = "/".join(parts[4:]) + return f"{agent_temp_uri}/{rest}" + + return target_uri + async def _merge_into_existing( self, candidate: CandidateMemory, target_memory: Context, viking_fs, ctx: RequestContext, + user_temp_uri: Optional[str] = None, + agent_temp_uri: Optional[str] = None, ) -> bool: """Merge candidate content into an existing memory file.""" try: - existing_content = await viking_fs.read_file(target_memory.uri, ctx=ctx) + # Convert target URI to temp URI for COW pattern + temp_uri = self._convert_to_temp_uri( + target_memory.uri, user_temp_uri, agent_temp_uri + ) + + existing_content = await viking_fs.read_file(temp_uri, ctx=ctx) payload = await self.extractor._merge_memory_bundle( existing_abstract=target_memory.abstract, existing_overview=(target_memory.meta or {}).get("overview") or "", @@ -101,34 +147,42 @@ async def _merge_into_existing( if not payload: return False - await viking_fs.write_file(target_memory.uri, payload.content, ctx=ctx) + await viking_fs.write_file(temp_uri, payload.content, ctx=ctx) target_memory.abstract = payload.abstract target_memory.meta = {**(target_memory.meta or {}), "overview": payload.overview} logger.info( - "Merged memory %s with abstract %s", target_memory.uri, target_memory.abstract + "Merged memory %s with abstract %s", temp_uri, target_memory.abstract ) target_memory.set_vectorize(Vectorize(text=payload.content)) - await self._index_memory(target_memory, ctx) + # Note: vectorization will be handled by SemanticQueue after directory switch + # await self._index_memory(target_memory, ctx) return True except Exception as e: logger.error(f"Failed to merge memory {target_memory.uri}: {e}") return False async def _delete_existing_memory( - self, memory: Context, viking_fs, ctx: RequestContext + self, memory: Context, viking_fs, ctx: RequestContext, + user_temp_uri: Optional[str] = None, + agent_temp_uri: Optional[str] = None, ) -> bool: """Hard delete an existing memory file and clean up its vector record.""" try: - await viking_fs.rm(memory.uri, recursive=False, ctx=ctx) + # Convert target URI to temp URI for COW pattern + temp_uri = self._convert_to_temp_uri( + memory.uri, user_temp_uri, agent_temp_uri + ) + + await viking_fs.rm(temp_uri, recursive=False, ctx=ctx) except Exception as e: - logger.error(f"Failed to delete memory file {memory.uri}: {e}") + logger.error(f"Failed to delete memory file {temp_uri}: {e}") return False try: # rm() already syncs vector deletion in 
most cases; keep this as a safe fallback. - await self.vikingdb.delete_uris(ctx, [memory.uri]) + await self.vikingdb.delete_uris(ctx, [temp_uri]) except Exception as e: - logger.warning(f"Failed to remove vector record for {memory.uri}: {e}") + logger.warning(f"Failed to remove vector record for {temp_uri}: {e}") return True async def extract_long_term_memories( @@ -137,8 +191,24 @@ async def extract_long_term_memories( user: Optional["UserIdentifier"] = None, session_id: Optional[str] = None, ctx: Optional[RequestContext] = None, + user_temp_uri: Optional[str] = None, + agent_temp_uri: Optional[str] = None, ) -> List[Context]: - """Extract long-term memories from messages.""" + """Extract long-term memories from messages. + + Args: + messages: Messages to extract from + user: User identifier + session_id: Session ID + ctx: Request context + user_temp_uri: Temp user URI (for COW pattern). If provided, user memories + will be written to this temp location. + agent_temp_uri: Temp agent URI (for COW pattern). If provided, agent memories + will be written to this temp location. + + Returns: + List of extracted memories + """ if not messages: return [] @@ -164,11 +234,15 @@ async def extract_long_term_memories( for candidate in candidates: # Profile: skip dedup, always merge if candidate.category in ALWAYS_MERGE_CATEGORIES: - memory = await self.extractor.create_memory(candidate, user, session_id, ctx=ctx) + memory = await self.extractor.create_memory( + candidate, user, session_id, ctx=ctx, + user_temp_uri=user_temp_uri, agent_temp_uri=agent_temp_uri + ) if memory: memories.append(memory) stats.created += 1 - await self._index_memory(memory, ctx) + # Note: vectorization will be handled by SemanticQueue after directory switch + # await self._index_memory(memory, ctx) else: stats.skipped += 1 continue @@ -207,11 +281,11 @@ async def extract_long_term_memories( ) if skill_name: memory = await self.extractor._merge_skill_memory( - skill_name, candidate, ctx=ctx + skill_name, candidate, ctx=ctx, agent_temp_uri=agent_temp_uri ) elif tool_name: memory = await self.extractor._merge_tool_memory( - tool_name, candidate, ctx=ctx + tool_name, candidate, ctx=ctx, agent_temp_uri=agent_temp_uri ) else: logger.warning("No tool_name or skill_name found, skipping") @@ -220,7 +294,8 @@ async def extract_long_term_memories( if memory: memories.append(memory) stats.merged += 1 - await self._index_memory(memory, ctx) + # Note: vectorization will be handled by SemanticQueue after directory switch + # await self._index_memory(memory, ctx) continue # Dedup check for other categories @@ -250,7 +325,9 @@ async def extract_long_term_memories( for action in actions: if action.decision == MemoryActionDecision.DELETE: if viking_fs and await self._delete_existing_memory( - action.memory, viking_fs, ctx=ctx + action.memory, viking_fs, ctx=ctx, + user_temp_uri=user_temp_uri, + agent_temp_uri=agent_temp_uri ): stats.deleted += 1 else: @@ -258,13 +335,14 @@ async def extract_long_term_memories( elif action.decision == MemoryActionDecision.MERGE: if candidate.category in MERGE_SUPPORTED_CATEGORIES and viking_fs: if await self._merge_into_existing( - candidate, action.memory, viking_fs, ctx=ctx + candidate, action.memory, viking_fs, ctx=ctx, + user_temp_uri=user_temp_uri, + agent_temp_uri=agent_temp_uri ): stats.merged += 1 else: stats.skipped += 1 else: - # events/cases don't support MERGE, treat as SKIP stats.skipped += 1 continue @@ -273,24 +351,51 @@ async def extract_long_term_memories( for action in actions: if 
action.decision == MemoryActionDecision.DELETE:
                    if viking_fs and await self._delete_existing_memory(
-                        action.memory, viking_fs, ctx=ctx
+                        action.memory, viking_fs, ctx=ctx,
+                        user_temp_uri=user_temp_uri,
+                        agent_temp_uri=agent_temp_uri
                     ):
                        stats.deleted += 1
                    else:
                        stats.skipped += 1
-            memory = await self.extractor.create_memory(candidate, user, session_id, ctx=ctx)
+            memory = await self.extractor.create_memory(
+                candidate, user, session_id, ctx=ctx,
+                user_temp_uri=user_temp_uri, agent_temp_uri=agent_temp_uri
+            )
             if memory:
                 memories.append(memory)
                 stats.created += 1
-                await self._index_memory(memory, ctx)
+                # Note: vectorization will be handled by SemanticQueue after directory switch
+                # await self._index_memory(memory, ctx)
             else:
                 stats.skipped += 1

         # Extract URIs used in messages, create relations
         used_uris = self._extract_used_uris(messages)
         if used_uris and memories:
-            await self._create_relations(memories, used_uris, ctx=ctx)
+            # Convert memory URIs from temp to target for relation creation
+            target_memories = []
+            for memory in memories:
+                # Create a copy with target URI
+                target_uri = memory.uri
+                # If memory.uri is a temp URI, convert it to target URI
+                if user_temp_uri and memory.uri.startswith(user_temp_uri):
+                    target_uri = memory.uri.replace(user_temp_uri, f"viking://user/{ctx.user.user_space_name()}")
+                elif agent_temp_uri and memory.uri.startswith(agent_temp_uri):
+                    target_uri = memory.uri.replace(agent_temp_uri, f"viking://agent/{ctx.user.agent_space_name()}")
+
+                # Create a new Context with target URI for relation creation
+                # (Context is already imported at module level)
+                target_memory = Context(
+                    uri=target_uri,
+                    context_type=memory.context_type,
+                    abstract=memory.abstract,
+                    meta=memory.meta,
+                )
+                target_memories.append(target_memory)
+
+            await self._create_relations(target_memories, used_uris, ctx=ctx)

         logger.info(
             f"Memory extraction: created={stats.created}, "
diff --git a/openviking/session/memory_deduplicator.py b/openviking/session/memory_deduplicator.py
index c119ecb8..df99fb84 100644
--- a/openviking/session/memory_deduplicator.py
+++ b/openviking/session/memory_deduplicator.py
@@ -117,8 +117,19 @@ async def deduplicate(
     async def _find_similar_memories(
         self,
         candidate: CandidateMemory,
+        user_temp_uri: Optional[str] = None,
+        agent_temp_uri: Optional[str] = None,
     ) -> List[Context]:
-        """Find similar existing memories using vector search."""
+        """Find similar existing memories using vector search.
+
+        Args:
+            candidate: Candidate memory
+            user_temp_uri: Temp user URI (for COW pattern)
+            agent_temp_uri: Temp agent URI (for COW pattern)
+
+        Returns:
+            List of similar memories with temp URIs (if temp URIs provided)
+        """
         if not self.embedder:
             return []
@@ -127,6 +138,7 @@
         embed_result: EmbedResult = self.embedder.embed(query_text)
         query_vector = embed_result.dense_vector
+        # Search target URI (not temp URI) because vectors are stored for target URIs
         category_uri_prefix = self._category_uri_prefix(candidate.category.value, candidate.user)
         owner = candidate.user
@@ -177,6 +189,31 @@
         if context:
             # Keep retrieval score for later destructive-action guardrails.
context.meta = {**(context.meta or {}), "_dedup_score": score} + + # Convert target URI to temp URI (for COW pattern) + if user_temp_uri or agent_temp_uri: + original_uri = context.uri + # Convert user URI + if (user_temp_uri and + original_uri.startswith("viking://user/")): + parts = original_uri.split("/") + if len(parts) >= 5: + rest = "/".join(parts[4:]) + context.uri = f"{user_temp_uri}/{rest}" + logger.debug( + f"Converted URI: {original_uri} -> {context.uri}" + ) + # Convert agent URI + elif (agent_temp_uri and + original_uri.startswith("viking://agent/")): + parts = original_uri.split("/") + if len(parts) >= 5: + rest = "/".join(parts[4:]) + context.uri = f"{agent_temp_uri}/{rest}" + logger.debug( + f"Converted URI: {original_uri} -> {context.uri}" + ) + similar.append(context) logger.debug("Dedup similar memories after threshold=%d", len(similar)) return similar diff --git a/openviking/session/memory_extractor.py b/openviking/session/memory_extractor.py index cf9ba87c..5199b6dd 100644 --- a/openviking/session/memory_extractor.py +++ b/openviking/session/memory_extractor.py @@ -391,8 +391,21 @@ async def create_memory( user: str, session_id: str, ctx: RequestContext, + user_temp_uri: Optional[str] = None, + agent_temp_uri: Optional[str] = None, ) -> Optional[Context]: - """Create Context object from candidate and persist to AGFS as .md file.""" + """Create Context object from candidate and persist to AGFS as .md file. + + Args: + candidate: Candidate memory to create + user: User identifier + session_id: Session ID + ctx: Request context + user_temp_uri: Temp user URI (for COW pattern). If provided, user memories + will be written to this temp location. + agent_temp_uri: Temp agent URI (for COW pattern). If provided, agent memories + will be written to this temp location. 
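The target-to-temp rewrite applied inline above (and via `_convert_to_temp_uri` in the compressor) reduces to a small pure function; a sketch with hypothetical space names and temp root:

```python
def to_temp_uri(target_uri: str, user_temp_uri: str) -> str:
    """Map viking://user/{space}/<rest> to {user_temp_uri}/<rest>."""
    parts = target_uri.split("/")
    # parts: ["viking:", "", "user", "<user_space>", <rest...>]
    if target_uri.startswith("viking://user/") and len(parts) >= 5:
        return f"{user_temp_uri}/{'/'.join(parts[4:])}"
    return target_uri

temp_root = "viking://temp/session/alice/s1/commit_1a2b3c4d/user/alice"  # hypothetical
assert (
    to_temp_uri("viking://user/alice/memories/profile.md", temp_root)
    == f"{temp_root}/memories/profile.md"
)
```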
+ """ viking_fs = get_viking_fs() if not viking_fs: logger.warning("VikingFS not available, skipping memory creation") @@ -402,14 +415,22 @@ async def create_memory( # Special handling for profile: append to profile.md if candidate.category == MemoryCategory.PROFILE: - payload = await self._append_to_profile(candidate, viking_fs, ctx=ctx) + payload = await self._append_to_profile( + candidate, viking_fs, ctx=ctx, user_temp_uri=user_temp_uri + ) if not payload: return None user_space = ctx.user.user_space_name() - memory_uri = f"viking://user/{user_space}/memories/profile.md" + # Use temp user URI if provided (for COW pattern) + if user_temp_uri: + memory_uri = f"{user_temp_uri}/memories/profile.md" + parent_uri = f"{user_temp_uri}/memories" + else: + memory_uri = f"viking://user/{user_space}/memories/profile.md" + parent_uri = f"viking://user/{user_space}/memories" memory = Context( uri=memory_uri, - parent_uri=f"viking://user/{user_space}/memories", + parent_uri=parent_uri, is_leaf=True, abstract=payload.abstract, context_type=ContextType.MEMORY.value, @@ -430,9 +451,17 @@ async def create_memory( MemoryCategory.ENTITIES, MemoryCategory.EVENTS, ]: - parent_uri = f"viking://user/{ctx.user.user_space_name()}/{cat_dir}" + # Use temp user URI if provided (for COW pattern) + if user_temp_uri: + parent_uri = f"{user_temp_uri}/{cat_dir}" + else: + parent_uri = f"viking://user/{ctx.user.user_space_name()}/{cat_dir}" else: # CASES, PATTERNS - parent_uri = f"viking://agent/{ctx.user.agent_space_name()}/{cat_dir}" + # Use temp agent URI if provided (for COW pattern) + if agent_temp_uri: + parent_uri = f"{agent_temp_uri}/{cat_dir}" + else: + parent_uri = f"viking://agent/{ctx.user.agent_space_name()}/{cat_dir}" # Generate file URI (store directly as .md file, no directory creation) memory_id = f"mem_{str(uuid4())}" @@ -468,9 +497,14 @@ async def _append_to_profile( candidate: CandidateMemory, viking_fs, ctx: RequestContext, + user_temp_uri: Optional[str] = None, ) -> Optional[MergedMemoryPayload]: """Update user profile - always merge with existing content.""" - uri = f"viking://user/{ctx.user.user_space_name()}/memories/profile.md" + # Use temp user URI if provided (for COW pattern) + if user_temp_uri: + uri = f"{user_temp_uri}/memories/profile.md" + else: + uri = f"viking://user/{ctx.user.user_space_name()}/memories/profile.md" existing = "" try: existing = await viking_fs.read_file(uri, ctx=ctx) or "" @@ -570,7 +604,7 @@ async def _merge_memory_bundle( return None async def _merge_tool_memory( - self, tool_name: str, candidate: CandidateMemory, ctx: "RequestContext" + self, tool_name: str, candidate: CandidateMemory, ctx: "RequestContext", agent_temp_uri: Optional[str] = None ) -> Optional[Context]: """合并 Tool Memory,统计数据用 Python 累加""" if not tool_name or not tool_name.strip(): @@ -578,7 +612,11 @@ async def _merge_tool_memory( return None agent_space = ctx.user.agent_space_name() - uri = f"viking://agent/{agent_space}/memories/tools/{tool_name}.md" + # Use temp agent URI if provided (for COW pattern) + if agent_temp_uri: + uri = f"{agent_temp_uri}/memories/tools/{tool_name}.md" + else: + uri = f"viking://agent/{agent_space}/memories/tools/{tool_name}.md" viking_fs = get_viking_fs() if not viking_fs: @@ -634,7 +672,7 @@ async def _merge_tool_memory( tool_name, merged_stats, new_guidelines, fields=new_fields ) await viking_fs.write_file(uri=uri, content=merged_content, ctx=ctx) - return self._create_tool_context(uri, candidate, ctx) + return self._create_tool_context(uri, candidate, ctx, 
agent_temp_uri=agent_temp_uri) existing_stats = self._parse_tool_statistics(existing) merged_stats = self._merge_tool_statistics(existing_stats, new_stats) @@ -692,7 +730,7 @@ async def _merge_tool_memory( tool_name, merged_stats, merged_guidelines, fields=merged_fields ) await viking_fs.write_file(uri=uri, content=merged_content, ctx=ctx) - return self._create_tool_context(uri, candidate, ctx, abstract_override=abstract_override) + return self._create_tool_context(uri, candidate, ctx, abstract_override=abstract_override, agent_temp_uri=agent_temp_uri) async def _enqueue_semantic_for_parent(self, file_uri: str, ctx: "RequestContext") -> None: """Enqueue semantic generation for parent directory.""" @@ -1108,12 +1146,18 @@ def _create_tool_context( candidate: CandidateMemory, ctx: "RequestContext", abstract_override: Optional[str] = None, + agent_temp_uri: Optional[str] = None, ) -> Context: """创建 Tool Memory 的 Context 对象""" agent_space = ctx.user.agent_space_name() + # Use temp agent URI if provided (for COW pattern) + if agent_temp_uri: + parent_uri = f"{agent_temp_uri}/memories/tools" + else: + parent_uri = f"viking://agent/{agent_space}/memories/tools" return Context( uri=uri, - parent_uri=f"viking://agent/{agent_space}/memories/tools", + parent_uri=parent_uri, is_leaf=True, abstract=abstract_override or candidate.abstract, context_type=ContextType.MEMORY.value, @@ -1145,7 +1189,7 @@ def _extract_tool_guidelines(self, content: str) -> str: return content.strip() async def _merge_skill_memory( - self, skill_name: str, candidate: CandidateMemory, ctx: "RequestContext" + self, skill_name: str, candidate: CandidateMemory, ctx: "RequestContext", agent_temp_uri: Optional[str] = None ) -> Optional[Context]: """合并 Skill Memory,统计数据用 Python 累加""" if not skill_name or not skill_name.strip(): @@ -1153,7 +1197,11 @@ async def _merge_skill_memory( return None agent_space = ctx.user.agent_space_name() - uri = f"viking://agent/{agent_space}/memories/skills/{skill_name}.md" + # Use temp agent URI if provided (for COW pattern) + if agent_temp_uri: + uri = f"{agent_temp_uri}/memories/skills/{skill_name}.md" + else: + uri = f"viking://agent/{agent_space}/memories/skills/{skill_name}.md" viking_fs = get_viking_fs() if not viking_fs: @@ -1222,7 +1270,7 @@ async def _merge_skill_memory( skill_name, merged_stats, new_guidelines, fields=new_fields ) await viking_fs.write_file(uri=uri, content=merged_content, ctx=ctx) - return self._create_skill_context(uri, candidate, ctx) + return self._create_skill_context(uri, candidate, ctx, agent_temp_uri=agent_temp_uri) existing_stats = self._parse_skill_statistics(existing) merged_stats = self._merge_skill_statistics(existing_stats, new_stats) @@ -1284,7 +1332,7 @@ async def _merge_skill_memory( skill_name, merged_stats, merged_guidelines, fields=merged_fields ) await viking_fs.write_file(uri=uri, content=merged_content, ctx=ctx) - return self._create_skill_context(uri, candidate, ctx, abstract_override=abstract_override) + return self._create_skill_context(uri, candidate, ctx, abstract_override=abstract_override, agent_temp_uri=agent_temp_uri) def _compute_skill_statistics_derived(self, stats: dict) -> dict: """计算 Skill 派生统计数据(成功率)""" @@ -1436,12 +1484,18 @@ def _create_skill_context( candidate: CandidateMemory, ctx: "RequestContext", abstract_override: Optional[str] = None, + agent_temp_uri: Optional[str] = None, ) -> Context: """创建 Skill Memory 的 Context 对象""" agent_space = ctx.user.agent_space_name() + # Use temp agent URI if provided (for COW pattern) + if 
agent_temp_uri: + parent_uri = f"{agent_temp_uri}/memories/skills" + else: + parent_uri = f"viking://agent/{agent_space}/memories/skills" return Context( uri=uri, - parent_uri=f"viking://agent/{agent_space}/memories/skills", + parent_uri=parent_uri, is_leaf=True, abstract=abstract_override or candidate.abstract, context_type=ContextType.MEMORY.value, diff --git a/openviking/session/session.py b/openviking/session/session.py index 243069a1..95f4ada1 100644 --- a/openviking/session/session.py +++ b/openviking/session/session.py @@ -7,9 +7,10 @@ import json import re +import time from dataclasses import dataclass, field from datetime import datetime -from typing import TYPE_CHECKING, Any, Dict, List, Optional +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple from uuid import uuid4 from openviking.message import Message, Part @@ -90,6 +91,13 @@ def __init__( self._compression: SessionCompression = SessionCompression() self._stats: SessionStats = SessionStats() self._loaded = False + + # Temp URI management for COW pattern + self._temp_base_uri: Optional[str] = None + self._session_temp_uri: Optional[str] = None + self._user_temp_uri: Optional[str] = None + self._agent_temp_uri: Optional[str] = None + self._temp_created_at: Optional[float] = None logger.info(f"Session created: {self.session_id} for user {self.user}") @@ -294,68 +302,218 @@ def commit(self) -> Dict[str, Any]: logger.info(f"Session {self.session_id} committed") return result + def _create_temp_uris(self) -> Tuple[str, str, str, str]: + """Create temp URIs for session, user and agent directories. + + Temp URI structure matches target URI structure for Semantic DAG recursive processing: + - Session: viking://temp/session/{user_space}/{session_id}/commit_{uuid}/session/{user_space}/{session_id}/ + - User: viking://temp/session/{user_space}/{session_id}/commit_{uuid}/user/{user_space}/ + - Agent: viking://temp/session/{user_space}/{session_id}/commit_{uuid}/agent/{agent_space}/ + + Returns: + (temp_base_uri, session_temp_uri, user_temp_uri, agent_temp_uri) + """ + temp_base_uri = ( + f"viking://temp/session/" + f"{self.user.user_space_name()}/" + f"{self.session_id}/" + f"commit_{uuid4().hex[:8]}" + ) + + # Match target URI structure for Semantic DAG recursive processing + session_temp_uri = ( + f"{temp_base_uri}/session/" + f"{self.user.user_space_name()}/" + f"{self.session_id}" + ) + user_temp_uri = ( + f"{temp_base_uri}/user/" + f"{self.user.user_space_name()}" + ) + agent_temp_uri = ( + f"{temp_base_uri}/agent/" + f"{self.user.agent_space_name()}" + ) + + self._temp_base_uri = temp_base_uri + self._session_temp_uri = session_temp_uri + self._user_temp_uri = user_temp_uri + self._agent_temp_uri = agent_temp_uri + self._temp_created_at = time.time() + + return temp_base_uri, session_temp_uri, user_temp_uri, agent_temp_uri + + async def _cleanup_temp_uris(self) -> None: + """Clean up all temp directories after commit.""" + if self._temp_base_uri: + try: + await self._viking_fs.delete_temp(self._temp_base_uri, ctx=self.ctx) + logger.info(f"Cleaned up temp base: {self._temp_base_uri}") + except Exception as e: + logger.warning(f"Failed to cleanup temp {self._temp_base_uri}: {e}") + finally: + self._temp_base_uri = None + self._session_temp_uri = None + self._user_temp_uri = None + self._agent_temp_uri = None + self._temp_created_at = None + async def commit_async(self) -> Dict[str, Any]: - """Async commit session: create archive, extract memories, persist.""" + """Async commit session with Copy-on-Write pattern. 
+
+        Process:
+        1. Copy: Copy existing session, user and agent directories to temp
+        2. Write: Make all changes in temp
+        3. Semantic: Trigger semantic processing
+        4. Switch: Atomically switch from temp to target (handled by SemanticProcessor)
+        """
         result = {
             "session_id": self.session_id,
             "status": "committed",
             "memories_extracted": 0,
             "active_count_updated": 0,
             "archived": False,
+            "temp_base_uri": None,
+            "session_temp_uri": None,
+            "user_temp_uri": None,
+            "agent_temp_uri": None,
+            "semantic_msg_ids": None,
             "stats": None,
         }
+
         if not self._messages:
             return result
-
-        # 1. Archive current messages
-        self._compression.compression_index += 1
-        messages_to_archive = self._messages.copy()
-
-        summary = await self._generate_archive_summary_async(messages_to_archive)
-        archive_abstract = self._extract_abstract_from_summary(summary)
-        archive_overview = summary
-
-        await self._write_archive_async(
-            index=self._compression.compression_index,
-            messages=messages_to_archive,
-            abstract=archive_abstract,
-            overview=archive_overview,
-        )
-
-        self._compression.original_count += len(messages_to_archive)
-        result["archived"] = True
-
-        self._messages.clear()
-        logger.info(
-            f"Archived: {len(messages_to_archive)} messages → history/archive_{self._compression.compression_index:03d}/"
-        )
-
-        # 2. Extract long-term memories
-        if self._session_compressor:
+
+        # ========== Phase 1: Copy ==========
+        temp_base_uri, session_temp_uri, user_temp_uri, agent_temp_uri = self._create_temp_uris()
+        result["temp_base_uri"] = temp_base_uri
+        result["session_temp_uri"] = session_temp_uri
+        result["user_temp_uri"] = user_temp_uri
+        result["agent_temp_uri"] = agent_temp_uri
+
+        try:
+            # 1.1 Copy existing session to temp
+            logger.info(f"Copying session {self.session_id} to temp: {session_temp_uri}")
+            try:
+                await self._viking_fs.copy_directory(
+                    from_uri=self._session_uri,
+                    to_uri=session_temp_uri,
+                    ctx=self.ctx,
+                )
+                logger.info(f"Session copied to temp: {session_temp_uri}")
+            except Exception as e:
+                if "not found" in str(e).lower():
+                    logger.info(f"Session {self.session_id} not found, creating new temp")
+                    await self._viking_fs.mkdir(session_temp_uri, exist_ok=True, ctx=self.ctx)
+                else:
+                    raise
+
+            # 1.2 Copy existing user directory to temp
+            user_uri = f"viking://user/{self.user.user_space_name()}"
+            logger.info(f"Copying user directory to temp: {user_temp_uri}")
+            try:
+                await self._viking_fs.copy_directory(
+                    from_uri=user_uri,
+                    to_uri=user_temp_uri,
+                    ctx=self.ctx,
+                )
+                logger.info(f"User directory copied to temp: {user_temp_uri}")
+            except Exception as e:
+                if "not found" in str(e).lower():
+                    logger.info("User directory not found, creating new temp")
+                    await self._viking_fs.mkdir(user_temp_uri, exist_ok=True, ctx=self.ctx)
+                else:
+                    raise
+
+            # 1.3 Copy existing agent directory to temp
+            agent_uri = f"viking://agent/{self.user.agent_space_name()}"
+            logger.info(f"Copying agent directory to temp: {agent_temp_uri}")
+            try:
+                await self._viking_fs.copy_directory(
+                    from_uri=agent_uri,
+                    to_uri=agent_temp_uri,
+                    ctx=self.ctx,
+                )
+                logger.info(f"Agent directory copied to temp: {agent_temp_uri}")
+            except Exception as e:
+                if "not found" in str(e).lower():
+                    logger.info("Agent directory not found, creating new temp")
+                    await self._viking_fs.mkdir(agent_temp_uri, exist_ok=True, ctx=self.ctx)
+                else:
+                    raise
+
+        except Exception as e:
+            logger.error(f"Failed to copy directories to temp: {e}")
+            await self._cleanup_temp_uris()
+            raise
+
+        # ========== Phase 2: Write (all changes in temp) ==========
+        try:
+            #
2.1 Archive current messages to temp + self._compression.compression_index += 1 + messages_to_archive = self._messages.copy() + + await self._write_archive_to_temp( + temp_uri=session_temp_uri, + index=self._compression.compression_index, + messages=messages_to_archive, + ) + + self._compression.original_count += len(messages_to_archive) + result["archived"] = True + + self._messages.clear() logger.info( - f"Starting memory extraction from {len(messages_to_archive)} archived messages" + f"Archived: {len(messages_to_archive)} messages → " + f"{session_temp_uri}/history/archive_{self._compression.compression_index:03d}/" ) - memories = await self._session_compressor.extract_long_term_memories( - messages=messages_to_archive, - user=self.user, - session_id=self.session_id, - ctx=self.ctx, + + # 2.2 Extract long-term memories (to temp user and agent directories) + if self._session_compressor: + logger.info( + f"Starting memory extraction from {len(messages_to_archive)} archived messages" + ) + memories = await self._session_compressor.extract_long_term_memories( + messages=messages_to_archive, + user=self.user, + session_id=self.session_id, + ctx=self.ctx, + user_temp_uri=user_temp_uri, + agent_temp_uri=agent_temp_uri, + ) + logger.info(f"Extracted {len(memories)} memories to temp directories") + result["memories_extracted"] = len(memories) + self._stats.memories_extracted += len(memories) + + # 2.3 Write current messages to temp + await self._write_messages_to_temp(session_temp_uri, self._messages) + + logger.info(f"Session changes written to temp: {session_temp_uri}") + # 2.5 Update active_count + active_count_updated = await self._update_active_counts_async() + result["active_count_updated"] = active_count_updated + except Exception as e: + logger.error(f"Failed to write changes to temp: {e}") + await self._cleanup_temp_uris() + raise + + # ========== Phase 3: Semantic = Switch =========== + try: + semantic_msg_ids = await self._enqueue_to_semantic_queue( + session_temp_uri=session_temp_uri, + user_temp_uri=user_temp_uri, + agent_temp_uri=agent_temp_uri, ) - logger.info(f"Extracted {len(memories)} memories") - result["memories_extracted"] = len(memories) - self._stats.memories_extracted += len(memories) - - # 3. Write current messages to AGFS - await self._write_to_agfs_async(self._messages) - - # 4. Create relations - await self._write_relations_async() - - # 5. Update active_count - active_count_updated = await self._update_active_counts_async() - result["active_count_updated"] = active_count_updated - - # 6. Update statistics + + logger.info(f"Session, user, agent enqueued to SemanticQueue: {semantic_msg_ids}") + result["semantic_msg_ids"] = semantic_msg_ids + + except Exception as e: + logger.error(f"Failed to enqueue to SemanticQueue: {e}") + await self._cleanup_temp_uris() + raise + + # ========== Update statistics ========== self._stats.compression_count = self._compression.compression_index result["stats"] = { "total_turns": self._stats.total_turns, @@ -365,7 +523,7 @@ async def commit_async(self) -> Dict[str, Any]: } self._stats.total_tokens = 0 - logger.info(f"Session {self.session_id} committed (async)") + logger.info(f"Session {self.session_id} committed (async with COW pattern)") return result def _update_active_counts(self) -> int: @@ -549,6 +707,105 @@ def _write_archive( logger.debug(f"Written archive: {archive_uri}") + async def _write_archive_to_temp( + self, + temp_uri: str, + index: int, + messages: List[Message], + ) -> None: + """Write archive to temp directory. 
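Condensed, the commit above funnels every phase failure into the same temp cleanup, so the live trees are never half-written. A simplified sketch (the three helpers are hypothetical stand-ins for the inline phase blocks):

```python
async def copy_targets_to_temp(s): ...     # stand-in for the Phase 1 block above
async def write_changes_to_temp(s): ...    # stand-in for the Phase 2 block above
async def enqueue_semantic_switch(s): ...  # stand-in for the Phase 3 block above

async def commit_cow(session) -> None:
    # Commit-scoped temp tree mirroring the targets, e.g.
    # viking://temp/session/alice/s1/commit_1a2b3c4d/user/alice
    session._create_temp_uris()
    try:
        await copy_targets_to_temp(session)     # Phase 1: Copy session/user/agent
        await write_changes_to_temp(session)    # Phase 2: archive + memories in temp
        await enqueue_semantic_switch(session)  # Phase 3: semantic gen, then atomic switch
    except Exception:
        # Any failure discards temp; the live targets were never touched.
        await session._cleanup_temp_uris()
        raise
```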
+ + Note: .abstract.md and .overview.md will be generated by Semantic DAG. + """ + archive_uri = f"{temp_uri}/history/archive_{index:03d}" + + lines = [m.to_jsonl() for m in messages] + await self._viking_fs.write_file( + uri=f"{archive_uri}/messages.jsonl", + content="\n".join(lines) + "\n", + ctx=self.ctx, + ) + + # Note: .abstract.md and .overview.md will be generated by Semantic DAG + # No need to manually create them here + + logger.debug(f"Written archive to temp: {archive_uri}") + + async def _write_messages_to_temp(self, temp_uri: str, messages: List[Message]) -> None: + """Write current messages to temp directory.""" + lines = [m.to_jsonl() for m in messages] + content = "\n".join(lines) + "\n" if lines else "" + + await self._viking_fs.write_file( + uri=f"{temp_uri}/messages.jsonl", + content=content, + ctx=self.ctx, + ) + + async def _enqueue_to_semantic_queue( + self, + session_temp_uri: str, + user_temp_uri: str, + agent_temp_uri: str, + ) -> List[str]: + """Enqueue session, user, and agent to SemanticQueue for L0/L1 generation. + + The SemanticProcessor will handle: + 1. Generate L0/L1 for session, user and agent directories + 2. Atomically switch temp URIs to target URIs + 3. Create usage relations + 4. Clean up temp URIs + + Returns: + List of message IDs [session_msg_id, user_msg_id, agent_msg_id] + """ + from openviking.storage.queuefs import SemanticMsg, get_queue_manager + + queue_manager = get_queue_manager() + semantic_queue = queue_manager.get_queue(queue_manager.SEMANTIC, allow_create=True) + + user_target_uri = f"viking://user/{self.user.user_space_name()}" + agent_target_uri = f"viking://agent/{self.user.agent_space_name()}" + + session_msg = SemanticMsg( + uri=session_temp_uri, + context_type="memory", + target_uri=self._session_uri, + account_id=self.ctx.account_id, + user_id=self.ctx.user.user_id, + agent_id=self.ctx.user.agent_id, + role=self.ctx.role.value, + recursive=True, + ) + + user_msg = SemanticMsg( + uri=user_temp_uri, + context_type="memory", + target_uri=user_target_uri, + account_id=self.ctx.account_id, + user_id=self.ctx.user.user_id, + agent_id=self.ctx.user.agent_id, + role=self.ctx.role.value, + recursive=True, + ) + + agent_msg = SemanticMsg( + uri=agent_temp_uri, + context_type="memory", + target_uri=agent_target_uri, + account_id=self.ctx.account_id, + user_id=self.ctx.user.user_id, + agent_id=self.ctx.user.agent_id, + role=self.ctx.role.value, + recursive=True, + ) + + await semantic_queue.enqueue(session_msg) + await semantic_queue.enqueue(user_msg) + await semantic_queue.enqueue(agent_msg) + + return [session_msg.id, user_msg.id, agent_msg.id] + async def _write_archive_async( self, index: int, diff --git a/openviking/storage/collection_schemas.py b/openviking/storage/collection_schemas.py index 90c61d07..1087f42f 100644 --- a/openviking/storage/collection_schemas.py +++ b/openviking/storage/collection_schemas.py @@ -263,3 +263,11 @@ async def on_dequeue(self, data: Optional[Dict[str, Any]]) -> Optional[Dict[str, traceback.print_exc() self.report_error(str(e), data) return None + finally: + if embedding_msg and embedding_msg.semantic_msg_id: + from openviking.storage.queuefs.embedding_tracker import EmbeddingTaskTracker + tracker = EmbeddingTaskTracker.get_instance() + try: + await tracker.decrement(embedding_msg.semantic_msg_id) + except Exception as tracker_err: + logger.warning(f"Failed to decrement embedding tracker: {tracker_err}") diff --git a/openviking/storage/queuefs/__init__.py b/openviking/storage/queuefs/__init__.py index 
b73a01d7..87514a83 100644 --- a/openviking/storage/queuefs/__init__.py +++ b/openviking/storage/queuefs/__init__.py @@ -2,6 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 from .embedding_msg import EmbeddingMsg from .embedding_queue import EmbeddingQueue +from .embedding_tracker import EmbeddingTaskTracker from .named_queue import NamedQueue, QueueError, QueueStatus from .queue_manager import QueueManager, get_queue_manager, init_queue_manager from .semantic_dag import SemanticDagExecutor @@ -18,6 +19,7 @@ "QueueError", "EmbeddingQueue", "EmbeddingMsg", + "EmbeddingTaskTracker", "SemanticQueue", "SemanticDagExecutor", "SemanticMsg", diff --git a/openviking/storage/queuefs/embedding_msg.py b/openviking/storage/queuefs/embedding_msg.py index 19b8381e..94e93a2c 100644 --- a/openviking/storage/queuefs/embedding_msg.py +++ b/openviking/storage/queuefs/embedding_msg.py @@ -2,7 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 import json from dataclasses import asdict, dataclass -from typing import Any, Dict, List, Union +from typing import Any, Dict, List, Optional, Union from uuid import uuid4 @@ -10,11 +10,18 @@ class EmbeddingMsg: message: Union[str, List[Dict[str, Any]]] context_data: Dict[str, Any] + semantic_msg_id: Optional[str] = None - def __init__(self, message: Union[str, List[Dict[str, Any]]], context_data: Dict[str, Any]): + def __init__( + self, + message: Union[str, List[Dict[str, Any]]], + context_data: Dict[str, Any], + semantic_msg_id: Optional[str] = None, + ): self.id = str(uuid4()) self.message = message self.context_data = context_data + self.semantic_msg_id = semantic_msg_id def to_dict(self) -> Dict[str, Any]: """Convert embedding message to dictionary format.""" @@ -30,6 +37,7 @@ def from_dict(cls, data: Dict[str, Any]) -> "EmbeddingMsg": obj = EmbeddingMsg( message=data["message"], context_data=data["context_data"], + semantic_msg_id=data.get("semantic_msg_id"), ) obj.id = data.get("id", obj.id) return obj diff --git a/openviking/storage/queuefs/embedding_tracker.py b/openviking/storage/queuefs/embedding_tracker.py new file mode 100644 index 00000000..a78d72b3 --- /dev/null +++ b/openviking/storage/queuefs/embedding_tracker.py @@ -0,0 +1,181 @@ +# Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd. +# SPDX-License-Identifier: Apache-2.0 +"""Embedding Task Tracker for tracking embedding task completion status.""" + +import asyncio +from dataclasses import dataclass, field +from typing import Any, Callable, Dict, Optional + +from openviking_cli.utils.logger import get_logger + +logger = get_logger(__name__) + + +@dataclass +class EmbeddingTaskTracker: + """Track embedding task completion status for each SemanticMsg. + + This tracker maintains a global registry of embedding tasks associated + with each SemanticMsg. When all embedding tasks for a SemanticMsg are + completed, it triggers the registered callback and removes the entry. + """ + + _instance: Optional["EmbeddingTaskTracker"] = None + _lock: asyncio.Lock = field(default_factory=asyncio.Lock) + _tasks: Dict[str, Dict[str, Any]] = field(default_factory=dict) + + @classmethod + def get_instance(cls) -> "EmbeddingTaskTracker": + """Get the singleton instance of EmbeddingTaskTracker.""" + if cls._instance is None: + cls._instance = cls() + return cls._instance + + async def register( + self, + semantic_msg_id: str, + total_count: int, + on_complete: Optional[Callable[[], Any]] = None, + metadata: Optional[Dict[str, Any]] = None, + ) -> None: + """Register a SemanticMsg with its total embedding task count. 
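The new optional `semantic_msg_id` field threads the originating SemanticMsg through each embedding task; a small sketch (IDs hypothetical):

```python
from openviking.storage.queuefs.embedding_msg import EmbeddingMsg

msg = EmbeddingMsg(
    message="abstract text to embed",
    context_data={"uri": "viking://user/alice/memories/profile.md"},  # hypothetical
    semantic_msg_id="sem-123",  # lets the worker report completion to the tracker
)
# from_dict tolerates older payloads that predate the field.
legacy = EmbeddingMsg.from_dict({"message": "m", "context_data": {}})
assert legacy.semantic_msg_id is None
```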
+ + Args: + semantic_msg_id: The ID of the SemanticMsg + total_count: Total number of embedding tasks for this SemanticMsg + on_complete: Optional callback when all tasks complete + metadata: Optional metadata to store with the task + """ + if total_count <= 0: + return + + async with self._lock: + self._tasks[semantic_msg_id] = { + "remaining": total_count, + "total": total_count, + "on_complete": on_complete, + "metadata": metadata or {}, + } + logger.info( + f"Registered embedding tracker for SemanticMsg {semantic_msg_id}: " + f"{total_count} tasks" + ) + + async def increment(self, semantic_msg_id: str) -> Optional[int]: + """Increment the remaining task count for a SemanticMsg. + + This method should be called when a new embedding task is added + for an already registered SemanticMsg. + + Args: + semantic_msg_id: The ID of the SemanticMsg + + Returns: + The remaining count after increment, or None if not found + """ + async with self._lock: + if semantic_msg_id not in self._tasks: + return None + + task_info = self._tasks[semantic_msg_id] + task_info["remaining"] += 1 + task_info["total"] += 1 + remaining = task_info["remaining"] + + return remaining + + async def decrement(self, semantic_msg_id: str) -> Optional[int]: + """Decrement the remaining task count for a SemanticMsg. + + This method should be called when an embedding task is completed. + When the count reaches zero, the registered callback is executed + and the entry is removed from the tracker. + + Args: + semantic_msg_id: The ID of the SemanticMsg + + Returns: + The remaining count after decrement, or None if not found + """ + on_complete = None + metadata = None + + async with self._lock: + if semantic_msg_id not in self._tasks: + return None + + task_info = self._tasks[semantic_msg_id] + task_info["remaining"] -= 1 + remaining = task_info["remaining"] + + if remaining <= 0: + on_complete = task_info.get("on_complete") + metadata = task_info.get("metadata", {}) + + del self._tasks[semantic_msg_id] + logger.info( + f"All embedding tasks completed for SemanticMsg {semantic_msg_id}" + ) + + + if on_complete: + try: + result = on_complete() + if asyncio.iscoroutine(result): + await result + except Exception as e: + logger.error( + f"Error in completion callback for {semantic_msg_id}: {e}", + exc_info=True, + ) + return remaining + + async def get_status(self, semantic_msg_id: str) -> Optional[Dict[str, Any]]: + """Get the current status of a SemanticMsg's embedding tasks. + + Args: + semantic_msg_id: The ID of the SemanticMsg + + Returns: + Dict with 'remaining', 'total', 'metadata' or None if not found + """ + async with self._lock: + if semantic_msg_id not in self._tasks: + return None + task_info = self._tasks[semantic_msg_id] + return { + "remaining": task_info["remaining"], + "total": task_info["total"], + "metadata": task_info.get("metadata", {}), + } + + async def remove(self, semantic_msg_id: str) -> bool: + """Remove a SemanticMsg from the tracker. + + Args: + semantic_msg_id: The ID of the SemanticMsg + + Returns: + True if removed, False if not found + """ + async with self._lock: + if semantic_msg_id in self._tasks: + del self._tasks[semantic_msg_id] + return True + return False + + async def get_all_tracked(self) -> Dict[str, Dict[str, Any]]: + """Get all currently tracked SemanticMsgs. 
+ + Returns: + Dict of semantic_msg_id -> task info + """ + async with self._lock: + return { + msg_id: { + "remaining": info["remaining"], + "total": info["total"], + "metadata": info.get("metadata", {}), + } + for msg_id, info in self._tasks.items() + } diff --git a/openviking/storage/queuefs/semantic_dag.py b/openviking/storage/queuefs/semantic_dag.py index 0307521f..bc91c2a7 100644 --- a/openviking/storage/queuefs/semantic_dag.py +++ b/openviking/storage/queuefs/semantic_dag.py @@ -48,11 +48,19 @@ def __init__( context_type: str, max_concurrent_llm: int, ctx: RequestContext, + incremental_update: bool = False, + target_uri: Optional[str] = None, + semantic_msg_id: Optional[str] = None, + recursive: bool = True, ): self._processor = processor self._context_type = context_type self._max_concurrent_llm = max_concurrent_llm self._ctx = ctx + self._incremental_update = incremental_update + self._target_uri = target_uri + self._semantic_msg_id = semantic_msg_id + self._recursive = recursive self._llm_sem = asyncio.Semaphore(max_concurrent_llm) self._viking_fs = get_viking_fs() self._nodes: Dict[str, DirNode] = {} @@ -62,7 +70,7 @@ def __init__( self._stats = DagStats() async def run(self, root_uri: str) -> None: - """Run DAG execution starting from root_uri.""" + """Run DAG execution starting from root_uri.""" self._root_uri = root_uri self._root_done = asyncio.Event() await self._dispatch_dir(root_uri, parent_uri=None) @@ -79,8 +87,10 @@ async def _dispatch_dir(self, dir_uri: str, parent_uri: Optional[str]) -> None: children_dirs, file_paths = await self._list_dir(dir_uri) file_index = {path: idx for idx, path in enumerate(file_paths)} child_index = {path: idx for idx, path in enumerate(children_dirs)} - pending = len(children_dirs) + len(file_paths) - + if self._recursive: + pending = len(children_dirs) + len(file_paths) + else: + pending = len(file_paths) node = DirNode( uri=dir_uri, children_dirs=children_dirs, @@ -108,8 +118,10 @@ async def _dispatch_dir(self, dir_uri: str, parent_uri: Optional[str]) -> None: self._stats.in_progress_nodes += 1 asyncio.create_task(self._file_summary_task(dir_uri, file_path)) - for child_uri in children_dirs: - asyncio.create_task(self._dispatch_dir(child_uri, dir_uri)) + if children_dirs: + if self._recursive: + for child_uri in children_dirs: + asyncio.create_task(self._dispatch_dir(child_uri, dir_uri)) except Exception as e: logger.error(f"Failed to dispatch directory {dir_uri}: {e}", exc_info=True) if parent_uri: @@ -141,13 +153,101 @@ async def _list_dir(self, uri: str) -> tuple[list[str], list[str]]: return children_dirs, file_paths + def _get_target_file_path(self, current_uri: str) -> Optional[str]: + if not self._incremental_update or not self._target_uri or not self._root_uri: + logger.warning(f"Invalid target_uri or root_uri for incremental update: target_uri={self._target_uri}, root_uri={self._root_uri}") + return None + try: + relative_path = current_uri[len(self._root_uri):] + if relative_path.startswith("/"): + relative_path = relative_path[1:] + return f"{self._target_uri}/{relative_path}" if relative_path else self._target_uri + except Exception: + return None + + async def _check_file_content_changed(self, file_path: str) -> bool: + target_path = self._get_target_file_path(file_path) + if not target_path: + return True + try: + current_content = await self._viking_fs.read_file(file_path, ctx=self._ctx) + target_content = await self._viking_fs.read_file(target_path, ctx=self._ctx) + return current_content != target_content + except 
Exception: + return True + + async def _read_existing_summary(self, file_path: str) -> Optional[Dict[str, str]]: + target_path = self._get_target_file_path(file_path) + if not target_path: + return None + try: + vector_store = self._viking_fs._get_vector_store() + if not vector_store: + return None + records = await vector_store.get_context_by_uri( + account_id=self._ctx.account_id, + uri=target_path, + limit=1, + ) + if records and len(records) > 0: + record = records[0] + summary = record.get("abstract", "") + if summary: + file_name = file_path.split("/")[-1] + return {"name": file_name, "summary": summary} + except Exception: + pass + return None + + async def _check_dir_children_changed(self, dir_uri: str, current_files: List[str], current_dirs: List[str]) -> bool: + target_path = self._get_target_file_path(dir_uri) + if not target_path: + return True + try: + target_dirs, target_files = await self._list_dir(target_path) + current_file_names = {f.split("/")[-1] for f in current_files} + target_file_names = {f.split("/")[-1] for f in target_files} + if current_file_names != target_file_names: + return True + current_dir_names = {d.split("/")[-1] for d in current_dirs} + target_dir_names = {d.split("/")[-1] for d in target_dirs} + if current_dir_names != target_dir_names: + return True + for current_file in current_files: + if await self._check_file_content_changed(current_file): + return True + return False + except Exception: + return True + + async def _read_existing_overview_abstract(self, dir_uri: str) -> tuple[Optional[str], Optional[str]]: + target_path = self._get_target_file_path(dir_uri) + if not target_path: + return None, None + try: + overview = await self._viking_fs.read_file(f"{target_path}/.overview.md", ctx=self._ctx) + abstract = await self._viking_fs.read_file(f"{target_path}/.abstract.md", ctx=self._ctx) + return overview, abstract + except Exception: + return None, None + async def _file_summary_task(self, parent_uri: str, file_path: str) -> None: """Generate file summary and notify parent completion.""" + file_name = file_path.split("/")[-1] + need_vectorize = True try: - summary_dict = await self._processor._generate_single_file_summary( - file_path, llm_sem=self._llm_sem, ctx=self._ctx - ) + summary_dict = None + if self._incremental_update: + content_changed = await self._check_file_content_changed(file_path) + + if not content_changed: + summary_dict = await self._read_existing_summary(file_path) + need_vectorize = False + if summary_dict is None: + summary_dict = await self._processor._generate_single_file_summary( + file_path, llm_sem=self._llm_sem, ctx=self._ctx + ) except Exception as e: logger.warning(f"Failed to generate summary for {file_path}: {e}") summary_dict = {"name": file_name, "summary": ""} @@ -155,21 +255,21 @@ async def _file_summary_task(self, parent_uri: str, file_path: str) -> None: self._stats.done_nodes += 1 self._stats.in_progress_nodes = max(0, self._stats.in_progress_nodes - 1) - await self._on_file_done(parent_uri, file_path, summary_dict) - - # Vectorize file as soon as summary is ready to avoid waiting for overview. 
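+            # need_vectorize stays False only when an unchanged file's existing
+            # summary was reused above, i.e. its embedding is already indexed
+            # under target_uri and must not be enqueued again.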
try: - asyncio.create_task( - self._processor._vectorize_single_file( - parent_uri=parent_uri, - context_type=self._context_type, - file_path=file_path, - summary_dict=summary_dict, - ctx=self._ctx, + if need_vectorize: + asyncio.create_task( + self._processor._vectorize_single_file( + parent_uri=parent_uri, + context_type=self._context_type, + file_path=file_path, + summary_dict=summary_dict, + ctx=self._ctx, + semantic_msg_id=self._semantic_msg_id, + ) ) - ) except Exception as e: logger.error(f"Failed to schedule vectorization for {file_path}: {e}", exc_info=True) + await self._on_file_done(parent_uri, file_path, summary_dict) async def _on_file_done( self, parent_uri: str, file_path: str, summary_dict: Dict[str, str] @@ -241,17 +341,27 @@ async def _overview_task(self, dir_uri: str) -> None: node = self._nodes.get(dir_uri) if not node: return - - async with node.lock: - file_summaries = self._finalize_file_summaries(node) - children_abstracts = self._finalize_children_abstracts(node) - + need_vectorize = True try: - async with self._llm_sem: - overview = await self._processor._generate_overview( - dir_uri, file_summaries, children_abstracts + overview = None + abstract = None + if self._incremental_update: + children_changed = await self._check_dir_children_changed( + dir_uri, node.file_paths, node.children_dirs ) - abstract = self._processor._extract_abstract_from_overview(overview) + + if not children_changed: + need_vectorize = False + overview, abstract = await self._read_existing_overview_abstract(dir_uri) + if overview is None or abstract is None: + async with node.lock: + file_summaries = self._finalize_file_summaries(node) + children_abstracts = self._finalize_children_abstracts(node) + async with self._llm_sem: + overview = await self._processor._generate_overview( + dir_uri, file_summaries, children_abstracts + ) + abstract = self._processor._extract_abstract_from_overview(overview) try: await self._viking_fs.write_file(f"{dir_uri}/.overview.md", overview, ctx=self._ctx) @@ -260,9 +370,11 @@ async def _overview_task(self, dir_uri: str) -> None: logger.warning(f"Failed to write overview/abstract for {dir_uri}: {e}") try: - await self._processor._vectorize_directory_simple( - dir_uri, self._context_type, abstract, overview, ctx=self._ctx - ) + if need_vectorize: + await self._processor._vectorize_directory( + dir_uri, self._context_type, abstract, overview, ctx=self._ctx, + semantic_msg_id=self._semantic_msg_id, + ) except Exception as e: logger.error(f"Failed to vectorize directory {dir_uri}: {e}", exc_info=True) diff --git a/openviking/storage/queuefs/semantic_msg.py b/openviking/storage/queuefs/semantic_msg.py index 5f7bd730..7517dd3a 100644 --- a/openviking/storage/queuefs/semantic_msg.py +++ b/openviking/storage/queuefs/semantic_msg.py @@ -16,7 +16,7 @@ class SemanticMsg: Attributes: id: Unique identifier (UUID) uri: Directory URI to process - context_type: Type of context (resource, memory, skill) + context_type: Type of context (resource, memory, skill, session) status: Processing status (pending/processing/completed) timestamp: Creation timestamp recursive: Whether to recursively process subdirectories. 
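Worth spelling out, since this counting protocol is what makes the final switch safe: the processor registers one slot for the DAG traversal itself, each vectorize call adds a slot before enqueuing, and the embedding consumer removes one slot per processed message in its finally block. A toy timeline against the tracker API introduced above (the message id and counts are invented for illustration):

    # Toy timeline of the EmbeddingTaskTracker protocol (illustrative only).
    import asyncio

    from openviking.storage.queuefs import EmbeddingTaskTracker

    async def demo() -> None:
        tracker = EmbeddingTaskTracker.get_instance()
        done = asyncio.Event()

        # One slot for the DAG traversal itself, so the callback cannot
        # fire while files are still being discovered.
        await tracker.register("msg-1", total_count=1, on_complete=done.set)

        await tracker.increment("msg-1")  # embedding task enqueued for a file
        await tracker.increment("msg-1")  # embedding task enqueued for a directory

        await tracker.decrement("msg-1")  # DAG traversal done (releases its slot)
        await tracker.decrement("msg-1")  # first embedding consumed
        await tracker.decrement("msg-1")  # last embedding consumed -> callback fires

        assert done.is_set()

    asyncio.run(demo())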
@@ -27,7 +27,7 @@ class SemanticMsg: id: str # UUID uri: str # Directory URI - context_type: str # resource, memory, skill + context_type: str # resource, memory, skill, session status: str = "pending" # pending/processing/completed timestamp: int = int(datetime.now().timestamp()) recursive: bool = True # Whether to recursively process subdirectories @@ -37,6 +37,7 @@ class SemanticMsg: role: str = "root" # Additional flags skip_vectorization: bool = False + target_uri: str = "" def __init__( self, @@ -48,6 +49,7 @@ def __init__( agent_id: str = "default", role: str = "root", skip_vectorization: bool = False, + target_uri: str = "", ): self.id = str(uuid4()) self.uri = uri @@ -58,6 +60,7 @@ def __init__( self.agent_id = agent_id self.role = role self.skip_vectorization = skip_vectorization + self.target_uri = target_uri def to_dict(self) -> Dict[str, Any]: """Convert object to dictionary.""" @@ -93,6 +96,7 @@ def from_dict(cls, data: Dict[str, Any]) -> "SemanticMsg": agent_id=data.get("agent_id", "default"), role=data.get("role", "root"), skip_vectorization=data.get("skip_vectorization", False), + target_uri=data.get("target_uri", ""), ) if "id" in data and data["id"]: obj.id = data["id"] diff --git a/openviking/storage/queuefs/semantic_processor.py b/openviking/storage/queuefs/semantic_processor.py index 59700783..39a5f4a6 100644 --- a/openviking/storage/queuefs/semantic_processor.py +++ b/openviking/storage/queuefs/semantic_processor.py @@ -3,7 +3,8 @@ """SemanticProcessor: Processes messages from SemanticQueue, generates .abstract.md and .overview.md.""" import asyncio -from typing import Any, Dict, List, Optional, Tuple +from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional, Set, Tuple,Callable,Awaitable from openviking.parse.parsers.constants import ( CODE_EXTENSIONS, @@ -29,9 +30,22 @@ from openviking_cli.utils.config import get_openviking_config from openviking_cli.utils.logger import get_logger + +from .embedding_tracker import EmbeddingTaskTracker + logger = get_logger(__name__) +@dataclass +class DiffResult: + """Directory diff result for sync operations.""" + added_files: List[str] = field(default_factory=list) + deleted_files: List[str] = field(default_factory=list) + updated_files: List[str] = field(default_factory=list) + added_dirs: List[str] = field(default_factory=list) + deleted_dirs: List[str] = field(default_factory=list) + + class SemanticProcessor(DequeueHandlerBase): """ Semantic processor, generates .abstract.md and .overview.md bottom-up. 
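The DiffResult above is filled by plain set arithmetic over paths made relative to each root. A condensed sketch of that computation, assuming relative paths and a precomputed change map (the real _compute_diff also tracks added and deleted directories):

    # Condensed model of the diff computation (illustrative, not the private API).
    from dataclasses import dataclass, field
    from typing import Dict, List, Set

    @dataclass
    class FileDiff:
        added: List[str] = field(default_factory=list)
        deleted: List[str] = field(default_factory=list)
        updated: List[str] = field(default_factory=list)

    def diff_file_sets(
        root_files: Set[str],      # paths relative to the temp root (new content)
        target_files: Set[str],    # paths relative to the live target (old content)
        changed: Dict[str, bool],  # rel path -> content differs, precomputed
    ) -> FileDiff:
        common = root_files & target_files
        return FileDiff(
            added=sorted(root_files - target_files),
            deleted=sorted(target_files - root_files),
            # comparison failures default to "unchanged", matching the
            # processor's fallback when a file cannot be read.
            updated=sorted(p for p in common if changed.get(p, False)),
        )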
@@ -78,6 +92,61 @@ def _ctx_from_semantic_msg(msg: SemanticMsg) -> RequestContext: role=role, ) + async def _acquire_path_lock( + self, + resource_uri: str, + metadata: Optional[Dict[str, Any]] = None, + ) -> Optional[str]: + """Acquire path lock to prevent concurrent processing of the same path.""" + from openviking.resource.resource_lock import ResourceLockManager, ResourceLockConflictError + + viking_fs = get_viking_fs() + if not hasattr(viking_fs, 'agfs') or not viking_fs.agfs: + logger.warning("Cannot acquire path lock: agfs not available") + return None + + lock_manager = ResourceLockManager(viking_fs.agfs) + + try: + lock_info = lock_manager.acquire_lock( + resource_uri=resource_uri, + operation="path_processing", + metadata=metadata or {}, + ) + logger.info(f"Acquired path lock for {resource_uri}, lock_id={lock_info.lock_id}") + return lock_info.lock_id + except ResourceLockConflictError as e: + logger.warning(f"Path lock conflict for {resource_uri}: {e}") + raise + except Exception as e: + logger.error(f"Failed to acquire path lock for {resource_uri}: {e}") + return None + + async def _release_path_lock( + self, + resource_uri: str, + lock_id: Optional[str], + ) -> bool: + """Release path lock.""" + if not resource_uri or not lock_id: + return False + + from openviking.resource.resource_lock import ResourceLockManager + + viking_fs = get_viking_fs() + if not hasattr(viking_fs, 'agfs') or not viking_fs.agfs: + return False + + lock_manager = ResourceLockManager(viking_fs.agfs) + success = lock_manager.release_lock(resource_uri, lock_id) + + if success: + logger.info(f"Released path lock for {resource_uri}, lock_id={lock_id}") + else: + logger.warning(f"Failed to release path lock for {resource_uri}, lock_id={lock_id}") + + return success + def _detect_file_type(self, file_name: str) -> str: """ Detect file type based on extension using constants from code parser. 
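Callers are expected to pair every successful acquire with a release, even when the update fails. A minimal usage sketch, assuming an agfs handle is at hand (the helper name is illustrative):

    # Usage sketch for the path lock (hypothetical helper function).
    from openviking.resource.resource_lock import (
        ResourceLockConflictError,
        ResourceLockManager,
    )

    def update_resource_exclusively(agfs, resource_uri: str) -> None:
        manager = ResourceLockManager(agfs)
        try:
            info = manager.acquire_lock(resource_uri, operation="incremental_update")
        except ResourceLockConflictError:
            raise  # another update is in flight; surfaced to the caller as CONFLICT
        try:
            ...  # perform the update while holding the lock
        finally:
            manager.release_lock(resource_uri, info.lock_id)

ResourceLockManager also provides a lock() context manager that wraps the same acquire/release pair.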
@@ -103,55 +172,21 @@ def _detect_file_type(self, file_name: str) -> str:
         # Default to other
         return FILE_TYPE_OTHER
 
-    async def _enqueue_semantic_msg(self, msg: SemanticMsg) -> None:
-        """Enqueue a SemanticMsg to the semantic queue for processing."""
-        from openviking.storage.queuefs import get_queue_manager
-
-        queue_manager = get_queue_manager()
-        semantic_queue = queue_manager.get_queue(queue_manager.SEMANTIC)
-        # The queue manager returns SemanticQueue but method signature says NamedQueue
-        # We need to ignore the type error for the enqueue call
-        await semantic_queue.enqueue(msg)  # type: ignore
-        logger.debug(f"Enqueued semantic message for processing: {msg.uri}")
-
-    async def _collect_directory_info(
-        self,
-        uri: str,
-        result: List[Tuple[str, List[str], List[str]]],
-    ) -> None:
-        """Recursively collect directory info, post-order traversal ensures bottom-up order."""
+    async def _check_file_content_changed(self, file_path: str, target_file: str) -> bool:
+        """Check if file content has changed compared to target file."""
         viking_fs = get_viking_fs()
-        try:
-            entries = await viking_fs.ls(uri, ctx=self._current_ctx)
-        except Exception as e:
-            logger.warning(f"Failed to list directory {uri}: {e}")
-            return
-
-        children_uris = []
-        file_paths = []
-
-        for entry in entries:
-            name = entry.get("name", "")
-            if not name or name.startswith(".") or name in [".", ".."]:
-                continue
-
-            item_uri = VikingURI(uri).join(name).uri
-
-            if entry.get("isDir", False):
-                # Child directory
-                children_uris.append(item_uri)
-                # Recursively collect children
-                await self._collect_directory_info(item_uri, result)
-            else:
-                # File (not starting with .)
-                file_paths.append(item_uri)
-
-        # Add current directory info
-        result.append((uri, children_uris, file_paths))
+        try:
+            current_content = await viking_fs.read_file(file_path, ctx=self._current_ctx)
+            target_content = await viking_fs.read_file(target_file, ctx=self._current_ctx)
+            return current_content != target_content
+        except Exception:
+            return True
 
     async def on_dequeue(self, data: Optional[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
         """Process dequeued SemanticMsg, recursively process all subdirectories."""
+        target_lock_id: Optional[str] = None
+        source_lock_id: Optional[str] = None
+
         try:
             import json
 
@@ -167,103 +202,360 @@ async def on_dequeue(self, data: Optional[Dict[str, Any]]) -> Optional[Dict[str,
             self._current_msg = msg
             self._current_ctx = self._ctx_from_semantic_msg(msg)
             logger.info(
-                f"Processing semantic generation for: {msg.uri} (recursive={msg.recursive})"
+                f"Processing semantic generation for: {msg}"
             )
 
-            if msg.recursive:
-                executor = SemanticDagExecutor(
-                    processor=self,
-                    context_type=msg.context_type,
-                    max_concurrent_llm=self.max_concurrent_llm,
-                    ctx=self._current_ctx,
+            # Check if target_uri exists, auto-detect incremental update
+            is_incremental = False
+            viking_fs = get_viking_fs()
+            if msg.target_uri:
+                target_exists = await viking_fs.exists(msg.target_uri, ctx=self._current_ctx)
+                if target_exists:
+                    is_incremental = True
+                    logger.info(f"Target URI exists, using incremental update: {msg.target_uri}")
+
+            # Acquire target_uri path lock
+            if msg.target_uri:
+                target_lock_id = await self._acquire_path_lock(
+                    resource_uri=msg.target_uri,
+                    metadata={"msg_id": msg.id, "uri": msg.uri},
                 )
-                self._dag_executor = executor
-                await executor.run(msg.uri)
-                logger.info(f"Completed semantic generation for: {msg.uri}")
-                self.report_success()
-                return None
-            else:
-                # Non-recursive processing: directly process this directory
-                children_uris = []
-                file_paths = []
-                # Collect immediate children info only (no recursion)
-                viking_fs = get_viking_fs()
-                try:
-                    entries = await viking_fs.ls(msg.uri, ctx=self._current_ctx)
-                    for entry in entries:
-                        name = entry.get("name", "")
-                        if not name or name.startswith(".") or name in [".", ".."]:
-                            continue
-
-                        item_uri = VikingURI(msg.uri).join(name).uri
-
-                        if entry.get("isDir", False):
-                            children_uris.append(item_uri)
-                        else:
-                            file_paths.append(item_uri)
-                except Exception as e:
-                    logger.warning(f"Failed to list directory {msg.uri}: {e}")
-
-                # Process this directory
-                await self._process_single_directory(
-                    uri=msg.uri,
-                    context_type=msg.context_type,
-                    children_uris=children_uris,
-                    file_paths=file_paths,
+            # Acquire uri path lock if uri != target_uri
+            if msg.uri != msg.target_uri and msg.uri:
+                source_lock_id = await self._acquire_path_lock(
+                    resource_uri=msg.uri,
+                    metadata={"msg_id": msg.id, "target_uri": msg.target_uri},
                 )
-                logger.info(f"Completed semantic generation for: {msg.uri}")
-                self.report_success()
-                return None
-
+            tracker = EmbeddingTaskTracker.get_instance()
+            on_complete = self._create_sync_diff_callback(
+                root_uri=msg.uri,
+                target_uri=msg.target_uri,
+                root_lock_id=source_lock_id,
+                target_lock_id=target_lock_id,
+            )
+            await tracker.register(
+                semantic_msg_id=msg.id,
+                total_count=1,
+                on_complete=on_complete,
+                metadata={
+                    "uri": msg.uri,
+                    "root_lock_id": source_lock_id,
+                    "target_lock_id": target_lock_id,
+                }
+            )
+            executor = SemanticDagExecutor(
+                processor=self,
+                context_type=msg.context_type,
+                max_concurrent_llm=self.max_concurrent_llm,
+                ctx=self._current_ctx,
+                incremental_update=is_incremental,
+                target_uri=msg.target_uri,
+                semantic_msg_id=msg.id,
+                recursive=msg.recursive,
+            )
+            self._dag_executor = executor
+            await executor.run(msg.uri)
+            logger.info(f"Completed semantic generation for: {msg.uri}")
+            self.report_success()
+            return None
+
         except Exception as e:
             logger.error(f"Failed to process semantic message: {e}", exc_info=True)
             self.report_error(str(e), data)
             return None
         finally:
+            if self._current_msg is not None:
+                tracker = EmbeddingTaskTracker.get_instance()
+                await tracker.decrement(semantic_msg_id=self._current_msg.id)
             self._current_msg = None
+            self._current_ctx = None
 
     def get_dag_stats(self) -> Optional["DagStats"]:
         if not self._dag_executor:
             return None
         return self._dag_executor.get_stats()
+
+    def _create_sync_diff_callback(
+        self,
+        root_uri: str,
+        target_uri: str,
+        root_lock_id: Optional[str] = None,
+        target_lock_id: Optional[str] = None,
+    ) -> Callable[[], Awaitable[None]]:
+        """
+        Create a callback function to sync directory differences.
+
+        This callback compares root_uri (new content) with target_uri (old content),
+        handles added/updated/deleted files, then cleans up root_uri and releases lock.
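+        The tracker runs this callback once every embedding task recorded for the
+        originating SemanticMsg has completed, that is, when the count reaches zero.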
+
+        Args:
+            root_uri: Source directory URI (new content)
+            target_uri: Target directory URI (old content)
+            root_lock_id: Lock ID for root_uri
+            target_lock_id: Lock ID for target_uri
+
+        Returns:
+            Async callback function
+        """
+
+        async def sync_diff_callback() -> None:
+            try:
+                viking_fs = get_viking_fs()
+
+                root_tree = await self._collect_tree_info(root_uri)
+                target_tree = await self._collect_tree_info(target_uri)
+                diff = await self._compute_diff(root_tree, target_tree, root_uri, target_uri)
+                logger.info(
+                    f"[SyncDiff] Diff computed: "
+                    f"added_files={len(diff.added_files)}, "
+                    f"deleted_files={len(diff.deleted_files)}, "
+                    f"updated_files={len(diff.updated_files)}, "
+                    f"added_dirs={len(diff.added_dirs)}, "
+                    f"deleted_dirs={len(diff.deleted_dirs)}"
+                )
+                await self._execute_sync_operations(diff, root_uri, target_uri)
+                try:
+                    await viking_fs.rm(root_uri, recursive=True, ctx=self._current_ctx)
+                except Exception as e:
+                    logger.warning(f"[SyncDiff] Failed to delete root directory {root_uri}: {e}")
+                try:
+                    await self._release_path_lock(root_uri, root_lock_id)
+                    if target_uri != root_uri:
+                        await self._release_path_lock(target_uri, target_lock_id)
+                except Exception as e:
+                    logger.error(
+                        f"[SyncDiff] Error releasing locks: {e}",
+                        exc_info=True
+                    )
+
+            except Exception as e:
+                logger.error(
+                    f"[SyncDiff] Error in sync_diff_callback: "
+                    f"root_uri={root_uri}, target_uri={target_uri}, "
+                    f"error={e}",
+                    exc_info=True
+                )
+
+        return sync_diff_callback
 
-    async def _process_single_directory(
+    async def _collect_tree_info(
         self,
         uri: str,
-        context_type: str,
-        children_uris: List[str],
-        file_paths: List[str],
-    ) -> None:
-        """Process single directory, generate .abstract.md and .overview.md."""
+    ) -> Dict[str, Tuple[List[str], List[str]]]:
+        """
+        Recursively collect directory tree information.
+
+        Args:
+            uri: Directory URI
+
+        Returns:
+            Dictionary: {dir_uri: ([subdir_uris], [file_uris])}
+        """
         viking_fs = get_viking_fs()
+        result: Dict[str, Tuple[List[str], List[str]]] = {}
+        total_dirs = 0
+        total_files = 0
+
+        async def collect_recursive(current_uri: str, depth: int = 0) -> None:
+            nonlocal total_dirs, total_files
+            indent = " " * depth
+            try:
+                entries = await viking_fs.ls(current_uri, ctx=self._current_ctx)
+            except Exception as e:
+                logger.warning(f"[SyncDiff]{indent} Failed to list {current_uri}: {e}")
+                return
+
+            sub_dirs: List[str] = []
+            files: List[str] = []
+
+            for entry in entries:
+                name = entry.get("name", "")
+                if not name or name.startswith(".") or name in [".", ".."]:
+                    continue
+
+                item_uri = VikingURI(current_uri).join(name).uri
+
+                if entry.get("isDir", False):
+                    sub_dirs.append(item_uri)
+                    total_dirs += 1
+                    await collect_recursive(item_uri, depth + 1)
+                else:
+                    files.append(item_uri)
+                    total_files += 1
+
+            result[current_uri] = (sub_dirs, files)
+
+        await collect_recursive(uri)
+        return result
+
+    async def _compute_diff(
+        self,
+        root_tree: Dict[str, Tuple[List[str], List[str]]],
+        target_tree: Dict[str, Tuple[List[str], List[str]]],
+        root_uri: str,
+        target_uri: str,
+    ) -> DiffResult:
+        """
+        Compute differences between two directory trees.
-        # 1. Collect .abstract.md from subdirectories (already processed earlier)
-        children_abstracts = await self._collect_children_abstracts(children_uris)
-
-        # 2.
Concurrently generate summaries for files in directory - file_summaries = await self._generate_file_summaries( - file_paths, context_type=context_type, parent_uri=uri, enqueue_files=True + Returns: + DiffResult with added/deleted/updated files and directories + """ + def get_relative_path(uri: str, base_uri: str) -> str: + if uri.startswith(base_uri): + rel = uri[len(base_uri):] + return rel.lstrip("/") + return uri + + root_files: Set[str] = set() + root_dirs: Set[str] = set() + target_files: Set[str] = set() + target_dirs: Set[str] = set() + + for dir_uri, (sub_dirs, files) in root_tree.items(): + rel_dir = get_relative_path(dir_uri, root_uri) + if rel_dir: + root_dirs.add(rel_dir) + for f in files: + root_files.add(get_relative_path(f, root_uri)) + for d in sub_dirs: + root_dirs.add(get_relative_path(d, root_uri)) + + for dir_uri, (sub_dirs, files) in target_tree.items(): + rel_dir = get_relative_path(dir_uri, target_uri) + if rel_dir: + target_dirs.add(rel_dir) + for f in files: + target_files.add(get_relative_path(f, target_uri)) + for d in sub_dirs: + target_dirs.add(get_relative_path(d, target_uri)) + + added_files_rel = root_files - target_files + deleted_files_rel = target_files - root_files + common_files = root_files & target_files + + added_dirs_rel = root_dirs - target_dirs + deleted_dirs_rel = target_dirs - root_dirs + + updated_files: List[str] = [] + for rel_file in common_files: + root_file = f"{root_uri}/{rel_file}" + target_file = f"{target_uri}/{rel_file}" + try: + if await self._check_file_content_changed(root_file, target_file): + updated_files.append(root_file) + except Exception as e: + logger.warning( + f"[SyncDiff] Failed to compare file content for {rel_file}: {e}, " + f"treating as unchanged" + ) + + added_files = [f"{root_uri}/{f}" for f in added_files_rel] + deleted_files = [f"{target_uri}/{f}" for f in deleted_files_rel] + added_dirs = [f"{root_uri}/{d}" for d in added_dirs_rel] + deleted_dirs = [f"{target_uri}/{d}" for d in deleted_dirs_rel] + + result = DiffResult( + added_files=added_files, + deleted_files=deleted_files, + updated_files=updated_files, + added_dirs=added_dirs, + deleted_dirs=deleted_dirs, ) + + return result - # 3. Generate .overview.md (contains brief description) - overview = await self._generate_overview(uri, file_summaries, children_abstracts) - - # 4. Extract abstract from overview - abstract = self._extract_abstract_from_overview(overview) + async def _execute_sync_operations( + self, + diff: DiffResult, + root_uri: str, + target_uri: str, + ) -> None: + """ + Execute sync operations based on diff result. - # 5. Write files - await viking_fs.write_file(f"{uri}/.overview.md", overview, ctx=self._current_ctx) - await viking_fs.write_file(f"{uri}/.abstract.md", abstract, ctx=self._current_ctx) + Processing order: + 1. Delete files in target that don't exist in root + 2. Move added/updated files from root to target + 3. 
Delete directories in target that don't exist in root - logger.debug(f"Generated overview and abstract for {uri}") + Args: + diff: DiffResult containing operations to perform + root_uri: Source directory URI + target_uri: Target directory URI + """ + viking_fs = get_viking_fs() + + def map_to_target(root_item_uri: str) -> str: + if root_item_uri.startswith(root_uri): + rel = root_item_uri[len(root_uri):] + return f"{target_uri}{rel}" if rel else target_uri + return root_item_uri + + total_deleted = 0 + total_moved = 0 + total_failed = 0 + + for i, deleted_file in enumerate(diff.deleted_files, 1): + try: + await viking_fs.rm(deleted_file, ctx=self._current_ctx) + total_deleted += 1 + except Exception as e: + total_failed += 1 + logger.warning( + f"[SyncDiff] Failed to delete file [{i}/{len(diff.deleted_files)}]: {deleted_file}, error={e}" + ) - # 6. Vectorize directory - try: - await self._vectorize_directory_simple(uri, context_type, abstract, overview) - except Exception as e: - logger.error(f"Failed to vectorize directory {uri}: {e}", exc_info=True) + for i, updated_file in enumerate(diff.updated_files, 1): + target_file = map_to_target(updated_file) + try: + await viking_fs.rm(target_file, ctx=self._current_ctx) + except Exception as e: + logger.warning( + f"[SyncDiff] Failed to remove old file [{i}/{len(diff.updated_files)}]: {target_file}, error={e}" + ) + + files_to_move = diff.added_files + diff.updated_files + for i, root_file in enumerate(files_to_move, 1): + target_file = map_to_target(root_file) + try: + target_parent = VikingURI(target_file).parent + if target_parent: + try: + await viking_fs.mkdir(target_parent.uri, exist_ok=True, ctx=self._current_ctx) + except Exception as mkdir_error: + logger.debug(f"[SyncDiff] Parent dir creation skipped (may already exist): {mkdir_error}") + await viking_fs.mv(root_file, target_file, ctx=self._current_ctx) + total_moved += 1 + except Exception as e: + total_failed += 1 + logger.warning( + f"[SyncDiff] Failed to move file [{i}/{len(files_to_move)}]: " + f"{root_file} -> {target_file}, error={e}" + ) + + for i, deleted_dir in enumerate( + sorted(diff.deleted_dirs, key=lambda x: x.count("/"), reverse=True), 1 + ): + try: + await viking_fs.rm(deleted_dir, recursive=True, ctx=self._current_ctx) + except Exception as e: + total_failed += 1 + logger.warning( + f"[SyncDiff] Failed to delete directory [{i}/{len(diff.deleted_dirs)}]: " + f"{deleted_dir}, error={e}" + ) async def _collect_children_abstracts(self, children_uris: List[str]) -> List[Dict[str, str]]: """Collect .abstract.md from subdirectories.""" @@ -276,37 +568,6 @@ async def _collect_children_abstracts(self, children_uris: List[str]) -> List[Di results.append({"name": dir_name, "abstract": abstract}) return results - async def _generate_file_summaries( - self, - file_paths: List[str], - context_type: Optional[str] = None, - parent_uri: Optional[str] = None, - enqueue_files: bool = False, - ) -> List[Dict[str, str]]: - """Concurrently generate file summaries.""" - if not file_paths: - return [] - - async def generate_one_summary(file_path: str) -> Dict[str, str]: - summary = await self._generate_single_file_summary(file_path, ctx=self._current_ctx) - if enqueue_files and context_type and parent_uri: - try: - await self._vectorize_single_file( - parent_uri=parent_uri, - context_type=context_type, - file_path=file_path, - summary_dict=summary, - ) - except Exception as e: - logger.error( - f"Failed to vectorize file {file_path}: {e}", - exc_info=True, - ) - return summary - - tasks = 
[generate_one_summary(fp) for fp in file_paths] - return await asyncio.gather(*tasks) - async def _generate_text_summary( self, file_path: str, @@ -507,13 +768,14 @@ def replace_index(match): logger.error(f"Failed to generate overview for {dir_uri}: {e}", exc_info=True) return f"# {dir_uri.split('/')[-1]}\n\nDirectory overview" - async def _vectorize_directory_simple( + async def _vectorize_directory( self, uri: str, context_type: str, abstract: str, overview: str, ctx: Optional[RequestContext] = None, + semantic_msg_id: Optional[str] = None, ) -> None: """Create directory Context and enqueue to EmbeddingQueue.""" @@ -522,6 +784,13 @@ async def _vectorize_directory_simple( return from openviking.utils.embedding_utils import vectorize_directory_meta + tracker = EmbeddingTaskTracker.get_instance() + await tracker.increment( + semantic_msg_id=semantic_msg_id, + ) + await tracker.increment( + semantic_msg_id=semantic_msg_id + ) active_ctx = ctx or self._current_ctx await vectorize_directory_meta( @@ -530,44 +799,24 @@ async def _vectorize_directory_simple( overview=overview, context_type=context_type, ctx=active_ctx, + semantic_msg_id=semantic_msg_id, ) - async def _vectorize_files( - self, - uri: str, - context_type: str, - file_paths: List[str], - file_summaries: List[Dict[str, str]], - ctx: Optional[RequestContext] = None, - ) -> None: - """Vectorize files in directory.""" - from openviking.storage.queuefs import get_queue_manager - - queue_manager = get_queue_manager() - embedding_queue = queue_manager.get_queue(queue_manager.EMBEDDING) - - for file_path, file_summary_dict in zip(file_paths, file_summaries): - await self._vectorize_single_file( - parent_uri=uri, - context_type=context_type, - file_path=file_path, - summary_dict=file_summary_dict, - embedding_queue=embedding_queue, - ctx=ctx, - ) - async def _vectorize_single_file( self, parent_uri: str, context_type: str, file_path: str, summary_dict: Dict[str, str], - embedding_queue: Optional[Any] = None, ctx: Optional[RequestContext] = None, + semantic_msg_id: Optional[str] = None, ) -> None: """Vectorize a single file using its content or summary.""" from openviking.utils.embedding_utils import vectorize_file - + tracker = EmbeddingTaskTracker.get_instance() + await tracker.increment( + semantic_msg_id=semantic_msg_id, + ) active_ctx = ctx or self._current_ctx await vectorize_file( file_path=file_path, @@ -575,4 +824,5 @@ async def _vectorize_single_file( parent_uri=parent_uri, context_type=context_type, ctx=active_ctx, - ) + semantic_msg_id=semantic_msg_id, + ) \ No newline at end of file diff --git a/openviking/storage/viking_fs.py b/openviking/storage/viking_fs.py index bda478ae..6fc7398c 100644 --- a/openviking/storage/viking_fs.py +++ b/openviking/storage/viking_fs.py @@ -332,6 +332,21 @@ async def stat(self, uri: str, ctx: Optional[RequestContext] = None) -> Dict[str path = self._uri_to_path(uri, ctx=ctx) return self.agfs.stat(path) + async def exists(self, uri: str, ctx: Optional[RequestContext] = None) -> bool: + """Check if a URI exists. 
+
+        Args:
+            uri: Viking URI
+            ctx: Request context
+
+        Returns:
+            bool: True if the URI exists, False otherwise
+        """
+        try:
+            await self.stat(uri, ctx=ctx)
+            return True
+        except Exception:
+            return False
 
     async def glob(
         self,
         pattern: str,
diff --git a/openviking/utils/embedding_utils.py b/openviking/utils/embedding_utils.py
index 1dffc1c8..907c5dcd 100644
--- a/openviking/utils/embedding_utils.py
+++ b/openviking/utils/embedding_utils.py
@@ -116,6 +116,7 @@ async def vectorize_directory_meta(
     overview: str,
     context_type: str = "resource",
     ctx: Optional[RequestContext] = None,
+    semantic_msg_id: Optional[str] = None,
 ) -> None:
     """
     Vectorize directory metadata (.abstract.md and .overview.md).
@@ -147,6 +148,7 @@ async def vectorize_directory_meta(
         context_abstract.set_vectorize(Vectorize(text=abstract))
         msg_abstract = EmbeddingMsgConverter.from_context(context_abstract)
         if msg_abstract:
+            msg_abstract.semantic_msg_id = semantic_msg_id
             await embedding_queue.enqueue(msg_abstract)
             logger.debug(f"Enqueued directory L0 (abstract) for vectorization: {uri}")
 
@@ -165,6 +167,7 @@ async def vectorize_directory_meta(
         context_overview.set_vectorize(Vectorize(text=overview))
         msg_overview = EmbeddingMsgConverter.from_context(context_overview)
         if msg_overview:
+            msg_overview.semantic_msg_id = semantic_msg_id
             await embedding_queue.enqueue(msg_overview)
             logger.debug(f"Enqueued directory L1 (overview) for vectorization: {uri}")
 
@@ -175,6 +178,7 @@ async def vectorize_file(
     parent_uri: str,
     context_type: str = "resource",
     ctx: Optional[RequestContext] = None,
+    semantic_msg_id: Optional[str] = None,
 ) -> None:
     """
     Vectorize a single file.
@@ -246,6 +250,7 @@ async def vectorize_file(
     if not embedding_msg:
         return
 
+    embedding_msg.semantic_msg_id = semantic_msg_id
     await embedding_queue.enqueue(embedding_msg)
     logger.debug(f"Enqueued file for vectorization: {file_path}")
 
diff --git a/openviking/utils/resource_processor.py b/openviking/utils/resource_processor.py
index c43ea541..6fd43e17 100644
--- a/openviking/utils/resource_processor.py
+++ b/openviking/utils/resource_processor.py
@@ -120,7 +120,7 @@ async def process_resource(
         "source_path": None,
     }
 
-    # ============ Phase 1: Parse source (Parser generates L0/L1 and writes to temp) ============
+    # ============ Phase 1: Parse source and write to temp VikingFS ============
     try:
         media_processor = self._get_media_processor()
         viking_fs = get_viking_fs()
@@ -178,6 +178,7 @@ async def process_resource(
         )
         if context_tree and context_tree.root:
             result["root_uri"] = context_tree.root.uri
+            result["temp_uri"] = context_tree.root.temp_uri
     except Exception as e:
         result["status"] = "error"
         result["errors"].append(f"Finalize from temp error: {e}")
@@ -193,6 +194,7 @@ async def process_resource(
 
     # ============ Phase 4: Optional Steps ============
     build_index = kwargs.get("build_index", True)
+    temp_uri_for_summarize = result.get("temp_uri") or parse_result.temp_dir_path
     if summarize:
         # Explicit summarization request.
         # If build_index is ALSO True, we want vectorization.
@@ -203,6 +205,7 @@ async def process_resource(
                 resource_uris=[result["root_uri"]],
                 ctx=ctx,
                 skip_vectorization=skip_vec,
+                temp_uris=[temp_uri_for_summarize],
                 **kwargs,
             )
         except Exception as e:
@@ -214,7 +217,7 @@ async def process_resource(
         # We assume this means "Ingest and Index", which requires summarization.
try: await self._get_summarizer().summarize( - resource_uris=[result["root_uri"]], ctx=ctx, skip_vectorization=False, **kwargs + resource_uris=[result["root_uri"]], ctx=ctx, skip_vectorization=False,temp_uris = [temp_uri_for_summarize], **kwargs ) except Exception as e: logger.error(f"Auto-index failed: {e}") diff --git a/openviking/utils/summarizer.py b/openviking/utils/summarizer.py index a7477ba3..ff1f6d25 100644 --- a/openviking/utils/summarizer.py +++ b/openviking/utils/summarizer.py @@ -39,8 +39,16 @@ async def summarize( queue_manager = get_queue_manager() semantic_queue = queue_manager.get_queue(queue_manager.SEMANTIC, allow_create=True) + temp_uris = kwargs.get("temp_uris", []) + if temp_uris == []: + temp_uris = resource_uris + if len(temp_uris) != len(resource_uris): + logger.error( + f"temp_uris length ({len(temp_uris)}) must match resource_uris length ({len(resource_uris)})" + ) + return {"status": "error", "message": "temp_uris length must match resource_uris length"} enqueued_count = 0 - for uri in resource_uris: + for uri, temp_uri in zip(resource_uris, temp_uris): # Determine context_type based on URI context_type = "resource" if uri.startswith("viking://memory/"): @@ -49,13 +57,14 @@ async def summarize( context_type = "skill" msg = SemanticMsg( - uri=uri, + uri=temp_uri, context_type=context_type, account_id=ctx.account_id, user_id=ctx.user.user_id, agent_id=ctx.user.agent_id, role=ctx.role.value, skip_vectorization=skip_vectorization, + target_uri=uri, ) await semantic_queue.enqueue(msg) enqueued_count += 1 diff --git a/openviking_cli/exceptions.py b/openviking_cli/exceptions.py index 807d317e..cd432552 100644 --- a/openviking_cli/exceptions.py +++ b/openviking_cli/exceptions.py @@ -70,6 +70,14 @@ def __init__(self, resource: str, resource_type: str = "resource"): ) +class ConflictError(OpenVikingError): + """Resource conflict (e.g., locked by another operation).""" + + def __init__(self, message: str, resource: Optional[str] = None): + details = {"resource": resource} if resource else {} + super().__init__(message, code="CONFLICT", details=details) + + # ============= Authentication Errors ============= From 3e86358c24ba930fd5a2fd16d134a054828d8ed7 Mon Sep 17 00:00:00 2001 From: yepper Date: Thu, 12 Mar 2026 09:57:02 +0800 Subject: [PATCH 2/5] refactor(resource_lock): clean up imports and improve code formatting style: fix code formatting and whitespace issues across multiple files feat(viking_fs): add copy_directory method for recursive directory copying refactor(session): simplify temp URI creation and cleanup logic style(memory_extractor): improve code formatting and line wrapping refactor(semantic_processor): clean up imports and improve sync diff logic style(compressor): fix code formatting and line wrapping refactor(embedding_tracker): clean up code and improve logging style(session): fix code formatting and whitespace issues refactor(resource_lock): improve error handling and code organization --- openviking/parse/parsers/constants.py | 3 +- openviking/parse/tree_builder.py | 4 +- openviking/resource/__init__.py | 2 +- openviking/resource/resource_lock.py | 166 +++++++++--------- openviking/service/core.py | 2 +- openviking/session/compressor.py | 88 ++++++---- openviking/session/memory_deduplicator.py | 22 +-- openviking/session/memory_extractor.py | 22 ++- openviking/session/session.py | 96 +++++----- openviking/storage/collection_schemas.py | 1 + .../storage/queuefs/embedding_tracker.py | 65 ++++--- openviking/storage/queuefs/semantic_dag.py | 22 ++- 
.../storage/queuefs/semantic_processor.py | 120 ++++++------- openviking/storage/viking_fs.py | 25 +++ openviking/utils/resource_processor.py | 8 +- openviking/utils/summarizer.py | 12 +- 16 files changed, 350 insertions(+), 308 deletions(-) diff --git a/openviking/parse/parsers/constants.py b/openviking/parse/parsers/constants.py index 843665e1..a817a8f2 100644 --- a/openviking/parse/parsers/constants.py +++ b/openviking/parse/parsers/constants.py @@ -174,8 +174,7 @@ ".graphql", ".gql", ".prisma", - ".thrift", - ".conf" + ".conf", } # Documentation file extensions for file type detection diff --git a/openviking/parse/tree_builder.py b/openviking/parse/tree_builder.py index 87684ccf..18bcbf07 100644 --- a/openviking/parse/tree_builder.py +++ b/openviking/parse/tree_builder.py @@ -164,8 +164,6 @@ async def finalize_from_temp( return tree - - async def _move_temp_to_dest( self, viking_fs, src_uri: str, dst_uri: str, ctx: RequestContext ) -> None: @@ -223,4 +221,4 @@ async def _enqueue_semantic_generation( role=ctx.role.value, target_uri=final_uri, ) - await semantic_queue.enqueue(msg) \ No newline at end of file + await semantic_queue.enqueue(msg) diff --git a/openviking/resource/__init__.py b/openviking/resource/__init__.py index a6f8e73c..8cbaa6b9 100644 --- a/openviking/resource/__init__.py +++ b/openviking/resource/__init__.py @@ -3,9 +3,9 @@ """Resource management modules for incremental updates.""" from openviking.resource.resource_lock import ( - ResourceLockManager, ResourceLockConflictError, ResourceLockError, + ResourceLockManager, ) __all__ = [ diff --git a/openviking/resource/resource_lock.py b/openviking/resource/resource_lock.py index 2f8f12ad..c24a4994 100644 --- a/openviking/resource/resource_lock.py +++ b/openviking/resource/resource_lock.py @@ -8,17 +8,13 @@ """ import json -import os import time import uuid from contextlib import contextmanager from dataclasses import dataclass, field -from datetime import datetime -from pathlib import Path from typing import Any, Dict, Optional from openviking_cli.utils import get_logger -from openviking_cli.utils.uri import VikingURI logger = get_logger(__name__) @@ -26,14 +22,14 @@ @dataclass class LockInfo: """Lock metadata stored in lock file.""" - + lock_id: str resource_uri: str operation: str created_at: float expires_at: Optional[float] = None metadata: Dict[str, Any] = field(default_factory=dict) - + def to_dict(self) -> Dict[str, Any]: return { "lock_id": self.lock_id, @@ -43,11 +39,11 @@ def to_dict(self) -> Dict[str, Any]: "expires_at": self.expires_at, "metadata": self.metadata, } - + @classmethod def from_dict(cls, data: Dict[str, Any]) -> "LockInfo": return cls(**data) - + def is_expired(self) -> bool: if self.expires_at is None: return False @@ -56,12 +52,13 @@ def is_expired(self) -> bool: class ResourceLockError(Exception): """Base exception for resource lock errors.""" + pass class ResourceLockConflictError(ResourceLockError): """Raised when attempting to lock a resource that is already locked.""" - + def __init__(self, resource_uri: str, lock_info: Optional[LockInfo] = None): self.resource_uri = resource_uri self.lock_info = lock_info @@ -74,26 +71,26 @@ def __init__(self, resource_uri: str, lock_info: Optional[LockInfo] = None): class ResourceLockManager: """ Manages resource-level mutex locks using file-based storage. - + Lock files are stored under `.locks/` directory in the AGFS root. Each lock file is named after the resource URI (with path separators replaced). 
- + Features: - Atomic lock acquisition via file creation - Lock expiration detection - Automatic cleanup of expired locks - Service restart cleanup """ - + LOCK_DIR = ".locks" LOCK_FILE_SUFFIX = ".lock" DEFAULT_TTL = 3600 AGFS_MOUNT_PATH = "/local" - + def __init__(self, agfs: Any, default_ttl: Optional[int] = None): """ Initialize ResourceLockManager. - + Args: agfs: AGFS client instance default_ttl: Default lock TTL in seconds (default: 3600) @@ -101,20 +98,20 @@ def __init__(self, agfs: Any, default_ttl: Optional[int] = None): self._agfs = agfs self._default_ttl = default_ttl or self.DEFAULT_TTL self._lock_dir_path = f"{self.AGFS_MOUNT_PATH}/{self.LOCK_DIR}" - + def _get_lock_file_path(self, resource_uri: str) -> str: """ Get lock file path for a resource URI. - + Args: resource_uri: Resource URI (e.g., "viking://default/resources/my-repo") - + Returns: Lock file path (e.g., "/local/.locks/viking___default___resources___my-repo.lock") """ safe_uri = resource_uri.replace("://", "___").replace("/", "___").replace(".", "_") return f"{self._lock_dir_path}/{safe_uri}{self.LOCK_FILE_SUFFIX}" - + def _ensure_lock_dir(self) -> None: """Ensure lock directory exists.""" try: @@ -123,7 +120,7 @@ def _ensure_lock_dir(self) -> None: logger.info(f"Created lock directory: {self._lock_dir_path}") except Exception as e: logger.warning(f"Failed to ensure lock directory: {e}") - + def acquire_lock( self, resource_uri: str, @@ -133,24 +130,24 @@ def acquire_lock( ) -> LockInfo: """ Acquire a lock on a resource URI. - + Args: resource_uri: Resource URI to lock operation: Operation name (e.g., "incremental_update", "full_update") ttl: Lock TTL in seconds (default: use default_ttl) metadata: Additional metadata to store with lock - + Returns: LockInfo for the acquired lock - + Raises: ResourceLockConflictError: If resource is already locked """ self._ensure_lock_dir() - + lock_file = self._get_lock_file_path(resource_uri) ttl = ttl or self._default_ttl - + current_time = time.time() lock_info = LockInfo( lock_id=str(uuid.uuid4()), @@ -160,7 +157,7 @@ def acquire_lock( expires_at=current_time + ttl if ttl > 0 else None, metadata=metadata or {}, ) - + try: if self.exists(lock_file): existing_lock = self._read_lock(lock_file) @@ -171,219 +168,218 @@ def acquire_lock( f"operation={existing_lock.operation}" ) raise ResourceLockConflictError(resource_uri, existing_lock) - + logger.info(f"Removing expired lock: {lock_file}") self._agfs.rm(lock_file) - - self._agfs.write(lock_file, json.dumps(lock_info.to_dict()).encode('utf-8')) - + + self._agfs.write(lock_file, json.dumps(lock_info.to_dict()).encode("utf-8")) + logger.info( f"Acquired lock: resource={resource_uri}, " f"lock_id={lock_info.lock_id}, " f"operation={operation}, " f"ttl={ttl}s" ) - + return lock_info - + except ResourceLockConflictError: raise except Exception as e: logger.error(f"Failed to acquire lock for {resource_uri}: {e}") raise ResourceLockError(f"Failed to acquire lock: {e}") from e - + def release_lock(self, resource_uri: str, lock_id: Optional[str] = None) -> bool: """ Release a lock on a resource URI. 
- + Args: resource_uri: Resource URI to unlock lock_id: Optional lock ID to verify ownership - + Returns: True if lock was released, False if lock didn't exist """ lock_file = self._get_lock_file_path(resource_uri) - + try: if not self.exists(lock_file): return False - + if lock_id: existing_lock = self._read_lock(lock_file) if existing_lock and existing_lock.lock_id != lock_id: logger.warning( - f"Lock ID mismatch: expected={lock_id}, " - f"actual={existing_lock.lock_id}" + f"Lock ID mismatch: expected={lock_id}, actual={existing_lock.lock_id}" ) return False - + self._agfs.rm(lock_file) logger.info(f"Released lock: resource={resource_uri}, lock_id={lock_id}") return True - + except Exception as e: logger.error(f"Failed to release lock for {resource_uri}: {e}") return False - + def is_locked(self, resource_uri: str) -> bool: """ Check if a resource URI is locked. - + Args: resource_uri: Resource URI to check - + Returns: True if resource is locked, False otherwise """ lock_file = self._get_lock_file_path(resource_uri) - + try: if not self.exists(lock_file): return False - + lock_info = self._read_lock(lock_file) if not lock_info: return False - + if lock_info.is_expired(): logger.info(f"Found expired lock: {lock_file}") self._agfs.rm(lock_file) return False - + return True - + except Exception as e: logger.error(f"Failed to check lock for {resource_uri}: {e}") return False - + def get_lock_info(self, resource_uri: str) -> Optional[LockInfo]: """ Get lock information for a resource URI. - + Args: resource_uri: Resource URI to check - + Returns: LockInfo if resource is locked, None otherwise """ lock_file = self._get_lock_file_path(resource_uri) - + try: if not self.exists(lock_file): return None - + lock_info = self._read_lock(lock_file) if not lock_info: return None - + if lock_info.is_expired(): logger.info(f"Found expired lock: {lock_file}") self._agfs.rm(lock_file) return None - + return lock_info - + except Exception as e: logger.error(f"Failed to get lock info for {resource_uri}: {e}") return None - + def _read_lock(self, lock_file: str) -> Optional[LockInfo]: """Read lock information from a lock file.""" try: data = self._agfs.read(lock_file) - lock_dict = json.loads(data.decode('utf-8')) + lock_dict = json.loads(data.decode("utf-8")) return LockInfo.from_dict(lock_dict) except Exception as e: logger.error(f"Failed to read lock file {lock_file}: {e}") return None - + def cleanup_expired_locks(self) -> int: """ Clean up all expired locks. - + Returns: Number of locks cleaned up """ cleaned = 0 - + try: if not self.exists(self._lock_dir_path): return 0 - + lock_files = self._agfs.ls(self._lock_dir_path) - + for file_info in lock_files: lock_file = file_info.get("name", "") if not lock_file.endswith(self.LOCK_FILE_SUFFIX): continue - + lock_path = f"{self._lock_dir_path}/{lock_file}" lock_info = self._read_lock(lock_path) - + if not lock_info or lock_info.is_expired(): self._agfs.rm(lock_path) cleaned += 1 logger.info(f"Cleaned up expired lock: {lock_path}") - + if cleaned > 0: logger.info(f"Cleaned up {cleaned} expired locks") - + return cleaned - + except Exception as e: logger.error(f"Failed to cleanup expired locks: {e}") return cleaned - + def cleanup_all_locks(self) -> int: """ Clean up all locks (for service restart). 
- + Returns: Number of locks cleaned up """ cleaned = 0 - + try: if not self.exists(self._lock_dir_path): return 0 - + lock_files = self._agfs.ls(self._lock_dir_path) - + for file_info in lock_files: lock_file = file_info.get("name", "") if not lock_file.endswith(self.LOCK_FILE_SUFFIX): continue - + lock_path = f"{self._lock_dir_path}/{lock_file}" self._agfs.rm(lock_path) cleaned += 1 - + if cleaned > 0: logger.info(f"Cleaned up {cleaned} locks on service restart") - + return cleaned - + except Exception as e: logger.error(f"Failed to cleanup all locks: {e}") return cleaned - + def exists(self, uri: str) -> bool: """ Check if a URI exists using AGFS stat interface. - + Args: uri: URI to check (e.g., "viking://default/resources/my-repo") - + Returns: True if URI exists, False otherwise """ try: self._agfs.stat(uri) return True - except Exception as e: + except Exception: return False - + @contextmanager def lock( self, @@ -394,16 +390,16 @@ def lock( ): """ Context manager for acquiring and releasing a lock. - + Args: resource_uri: Resource URI to lock operation: Operation name ttl: Lock TTL in seconds metadata: Additional metadata - + Yields: LockInfo for the acquired lock - + Raises: ResourceLockConflictError: If resource is already locked """ diff --git a/openviking/service/core.py b/openviking/service/core.py index 9ab1de9a..e8957540 100644 --- a/openviking/service/core.py +++ b/openviking/service/core.py @@ -144,7 +144,7 @@ def _init_storage( # Initialize TransactionManager self._transaction_manager = init_transaction_manager(agfs=self._agfs_client) - + # Initialize ResourceLockManager if self._agfs_client: self._lock_manager = ResourceLockManager(agfs=self._agfs_client) diff --git a/openviking/session/compressor.py b/openviking/session/compressor.py index 31c75710..f38cb296 100644 --- a/openviking/session/compressor.py +++ b/openviking/session/compressor.py @@ -79,24 +79,21 @@ async def _index_memory(self, memory: Context, ctx: RequestContext) -> bool: return True def _convert_to_temp_uri( - self, - target_uri: str, - user_temp_uri: Optional[str], - agent_temp_uri: Optional[str] + self, target_uri: str, user_temp_uri: Optional[str], agent_temp_uri: Optional[str] ) -> str: """Convert target URI to temp URI for COW pattern. - + Args: target_uri: Target URI (e.g., viking://user/... or viking://agent/...) user_temp_uri: Temp user URI (if available) agent_temp_uri: Temp agent URI (if available) - + Returns: Converted temp URI, or original URI if no temp available """ if not user_temp_uri and not agent_temp_uri: return target_uri - + # Convert user URI if target_uri.startswith("viking://user/") and user_temp_uri: # viking://user/{user_space}/memories/... -> {user_temp_uri}/memories/... @@ -105,7 +102,7 @@ def _convert_to_temp_uri( # parts[0]="viking:", parts[1]="", parts[2]="user", parts[3]="{user_space}", parts[4:]="memories/..." rest = "/".join(parts[4:]) return f"{user_temp_uri}/{rest}" - + # Convert agent URI if target_uri.startswith("viking://agent/") and agent_temp_uri: # viking://agent/{agent_space}/memories/... -> {agent_temp_uri}/memories/... @@ -114,7 +111,7 @@ def _convert_to_temp_uri( # parts[0]="viking:", parts[1]="", parts[2]="agent", parts[3]="{agent_space}", parts[4:]="memories/..." 
rest = "/".join(parts[4:]) return f"{agent_temp_uri}/{rest}" - + return target_uri async def _merge_into_existing( @@ -129,10 +126,8 @@ async def _merge_into_existing( """Merge candidate content into an existing memory file.""" try: # Convert target URI to temp URI for COW pattern - temp_uri = self._convert_to_temp_uri( - target_memory.uri, user_temp_uri, agent_temp_uri - ) - + temp_uri = self._convert_to_temp_uri(target_memory.uri, user_temp_uri, agent_temp_uri) + existing_content = await viking_fs.read_file(temp_uri, ctx=ctx) payload = await self.extractor._merge_memory_bundle( existing_abstract=target_memory.abstract, @@ -150,9 +145,7 @@ async def _merge_into_existing( await viking_fs.write_file(temp_uri, payload.content, ctx=ctx) target_memory.abstract = payload.abstract target_memory.meta = {**(target_memory.meta or {}), "overview": payload.overview} - logger.info( - "Merged memory %s with abstract %s", temp_uri, target_memory.abstract - ) + logger.info("Merged memory %s with abstract %s", temp_uri, target_memory.abstract) target_memory.set_vectorize(Vectorize(text=payload.content)) # Note: vectorization will be handled by SemanticQueue after directory switch # await self._index_memory(target_memory, ctx) @@ -162,17 +155,18 @@ async def _merge_into_existing( return False async def _delete_existing_memory( - self, memory: Context, viking_fs, ctx: RequestContext, + self, + memory: Context, + viking_fs, + ctx: RequestContext, user_temp_uri: Optional[str] = None, agent_temp_uri: Optional[str] = None, ) -> bool: """Hard delete an existing memory file and clean up its vector record.""" try: # Convert target URI to temp URI for COW pattern - temp_uri = self._convert_to_temp_uri( - memory.uri, user_temp_uri, agent_temp_uri - ) - + temp_uri = self._convert_to_temp_uri(memory.uri, user_temp_uri, agent_temp_uri) + await viking_fs.rm(temp_uri, recursive=False, ctx=ctx) except Exception as e: logger.error(f"Failed to delete memory file {temp_uri}: {e}") @@ -195,7 +189,7 @@ async def extract_long_term_memories( agent_temp_uri: Optional[str] = None, ) -> List[Context]: """Extract long-term memories from messages. - + Args: messages: Messages to extract from user: User identifier @@ -205,7 +199,7 @@ async def extract_long_term_memories( will be written to this temp location. agent_temp_uri: Temp agent URI (for COW pattern). If provided, agent memories will be written to this temp location. 
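                Illustrative example (hypothetical URIs): with
                user_temp_uri="viking://temp/session/u1/s1/commit_ab12cd34/user/u1",
                a memory targeting "viking://user/u1/memories/preferences.md" is
                written to
                "viking://temp/session/u1/s1/commit_ab12cd34/user/u1/memories/preferences.md"
                and only reaches the target URI once the semantic queue switches
                the temp directory over.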
- + Returns: List of extracted memories """ @@ -235,8 +229,12 @@ async def extract_long_term_memories( # Profile: skip dedup, always merge if candidate.category in ALWAYS_MERGE_CATEGORIES: memory = await self.extractor.create_memory( - candidate, user, session_id, ctx=ctx, - user_temp_uri=user_temp_uri, agent_temp_uri=agent_temp_uri + candidate, + user, + session_id, + ctx=ctx, + user_temp_uri=user_temp_uri, + agent_temp_uri=agent_temp_uri, ) if memory: memories.append(memory) @@ -325,9 +323,11 @@ async def extract_long_term_memories( for action in actions: if action.decision == MemoryActionDecision.DELETE: if viking_fs and await self._delete_existing_memory( - action.memory, viking_fs, ctx=ctx, + action.memory, + viking_fs, + ctx=ctx, user_temp_uri=user_temp_uri, - agent_temp_uri=agent_temp_uri + agent_temp_uri=agent_temp_uri, ): stats.deleted += 1 else: @@ -335,9 +335,12 @@ async def extract_long_term_memories( elif action.decision == MemoryActionDecision.MERGE: if candidate.category in MERGE_SUPPORTED_CATEGORIES and viking_fs: if await self._merge_into_existing( - candidate, action.memory, viking_fs, ctx=ctx, + candidate, + action.memory, + viking_fs, + ctx=ctx, user_temp_uri=user_temp_uri, - agent_temp_uri=agent_temp_uri + agent_temp_uri=agent_temp_uri, ): stats.merged += 1 else: @@ -351,17 +354,23 @@ async def extract_long_term_memories( for action in actions: if action.decision == MemoryActionDecision.DELETE: if viking_fs and await self._delete_existing_memory( - action.memory, viking_fs, ctx=ctx, + action.memory, + viking_fs, + ctx=ctx, user_temp_uri=user_temp_uri, - agent_temp_uri=agent_temp_uri + agent_temp_uri=agent_temp_uri, ): stats.deleted += 1 else: stats.skipped += 1 memory = await self.extractor.create_memory( - candidate, user, session_id, ctx=ctx, - user_temp_uri=user_temp_uri, agent_temp_uri=agent_temp_uri + candidate, + user, + session_id, + ctx=ctx, + user_temp_uri=user_temp_uri, + agent_temp_uri=agent_temp_uri, ) if memory: memories.append(memory) @@ -381,12 +390,15 @@ async def extract_long_term_memories( target_uri = memory.uri # If memory.uri is a temp URI, convert it to target URI if user_temp_uri and memory.uri.startswith(user_temp_uri): - target_uri = memory.uri.replace(user_temp_uri, f"viking://user/{ctx.user.user_space_name()}") + target_uri = memory.uri.replace( + user_temp_uri, f"viking://user/{ctx.user.user_space_name()}" + ) elif agent_temp_uri and memory.uri.startswith(agent_temp_uri): - target_uri = memory.uri.replace(agent_temp_uri, f"viking://agent/{ctx.user.agent_space_name()}") - + target_uri = memory.uri.replace( + agent_temp_uri, f"viking://agent/{ctx.user.agent_space_name()}" + ) + # Create a new Context with target URI for relation creation - from openviking_cli.context import Context target_memory = Context( uri=target_uri, context_type=memory.context_type, @@ -394,7 +406,7 @@ async def extract_long_term_memories( meta=memory.meta, ) target_memories.append(target_memory) - + await self._create_relations(target_memories, used_uris, ctx=ctx) logger.info( diff --git a/openviking/session/memory_deduplicator.py b/openviking/session/memory_deduplicator.py index df99fb84..ee28ff33 100644 --- a/openviking/session/memory_deduplicator.py +++ b/openviking/session/memory_deduplicator.py @@ -121,12 +121,12 @@ async def _find_similar_memories( agent_temp_uri: Optional[str] = None, ) -> List[Context]: """Find similar existing memories using vector search. 
- + Args: candidate: Candidate memory user_temp_uri: Temp user URI (for COW pattern) agent_temp_uri: Temp agent URI (for COW pattern) - + Returns: List of similar memories with temp URIs (if temp URIs provided) """ @@ -189,31 +189,25 @@ async def _find_similar_memories( if context: # Keep retrieval score for later destructive-action guardrails. context.meta = {**(context.meta or {}), "_dedup_score": score} - + # Convert target URI to temp URI (for COW pattern) if user_temp_uri or agent_temp_uri: original_uri = context.uri # Convert user URI - if (user_temp_uri and - original_uri.startswith("viking://user/")): + if user_temp_uri and original_uri.startswith("viking://user/"): parts = original_uri.split("/") if len(parts) >= 5: rest = "/".join(parts[4:]) context.uri = f"{user_temp_uri}/{rest}" - logger.debug( - f"Converted URI: {original_uri} -> {context.uri}" - ) + logger.debug(f"Converted URI: {original_uri} -> {context.uri}") # Convert agent URI - elif (agent_temp_uri and - original_uri.startswith("viking://agent/")): + elif agent_temp_uri and original_uri.startswith("viking://agent/"): parts = original_uri.split("/") if len(parts) >= 5: rest = "/".join(parts[4:]) context.uri = f"{agent_temp_uri}/{rest}" - logger.debug( - f"Converted URI: {original_uri} -> {context.uri}" - ) - + logger.debug(f"Converted URI: {original_uri} -> {context.uri}") + similar.append(context) logger.debug("Dedup similar memories after threshold=%d", len(similar)) return similar diff --git a/openviking/session/memory_extractor.py b/openviking/session/memory_extractor.py index 5199b6dd..7cfea14a 100644 --- a/openviking/session/memory_extractor.py +++ b/openviking/session/memory_extractor.py @@ -395,7 +395,7 @@ async def create_memory( agent_temp_uri: Optional[str] = None, ) -> Optional[Context]: """Create Context object from candidate and persist to AGFS as .md file. 
-
+
        Args:
            candidate: Candidate memory to create
            user: User identifier
@@ -604,7 +604,11 @@ async def _merge_memory_bundle(
         return None
 
     async def _merge_tool_memory(
-        self, tool_name: str, candidate: CandidateMemory, ctx: "RequestContext", agent_temp_uri: Optional[str] = None
+        self,
+        tool_name: str,
+        candidate: CandidateMemory,
+        ctx: "RequestContext",
+        agent_temp_uri: Optional[str] = None,
     ) -> Optional[Context]:
         """Merge Tool Memory; statistics are accumulated in Python."""
         if not tool_name or not tool_name.strip():
@@ -730,7 +734,9 @@ async def _merge_tool_memory(
             tool_name, merged_stats, merged_guidelines, fields=merged_fields
         )
         await viking_fs.write_file(uri=uri, content=merged_content, ctx=ctx)
-        return self._create_tool_context(uri, candidate, ctx, abstract_override=abstract_override, agent_temp_uri=agent_temp_uri)
+        return self._create_tool_context(
+            uri, candidate, ctx, abstract_override=abstract_override, agent_temp_uri=agent_temp_uri
+        )
 
     async def _enqueue_semantic_for_parent(self, file_uri: str, ctx: "RequestContext") -> None:
         """Enqueue semantic generation for parent directory."""
@@ -1189,7 +1195,11 @@ def _extract_tool_guidelines(self, content: str) -> str:
         return content.strip()
 
     async def _merge_skill_memory(
-        self, skill_name: str, candidate: CandidateMemory, ctx: "RequestContext", agent_temp_uri: Optional[str] = None
+        self,
+        skill_name: str,
+        candidate: CandidateMemory,
+        ctx: "RequestContext",
+        agent_temp_uri: Optional[str] = None,
     ) -> Optional[Context]:
         """Merge Skill Memory; statistics are accumulated in Python."""
         if not skill_name or not skill_name.strip():
@@ -1332,7 +1342,9 @@ async def _merge_skill_memory(
             skill_name, merged_stats, merged_guidelines, fields=merged_fields
         )
         await viking_fs.write_file(uri=uri, content=merged_content, ctx=ctx)
-        return self._create_skill_context(uri, candidate, ctx, abstract_override=abstract_override, agent_temp_uri=agent_temp_uri)
+        return self._create_skill_context(
+            uri, candidate, ctx, abstract_override=abstract_override, agent_temp_uri=agent_temp_uri
+        )
 
     def _compute_skill_statistics_derived(self, stats: dict) -> dict:
         """Compute derived Skill statistics (success rate)."""
diff --git a/openviking/session/session.py b/openviking/session/session.py
index 95f4ada1..330a943b 100644
--- a/openviking/session/session.py
+++ b/openviking/session/session.py
@@ -91,7 +91,7 @@ def __init__(
         self._compression: SessionCompression = SessionCompression()
         self._stats: SessionStats = SessionStats()
         self._loaded = False
-
+
         # Temp URI management for COW pattern
         self._temp_base_uri: Optional[str] = None
         self._session_temp_uri: Optional[str] = None
@@ -304,12 +304,12 @@ def commit(self) -> Dict[str, Any]:
 
     def _create_temp_uris(self) -> Tuple[str, str, str, str]:
         """Create temp URIs for session, user and agent directories.
- + Temp URI structure matches target URI structure for Semantic DAG recursive processing: - Session: viking://temp/session/{user_space}/{session_id}/commit_{uuid}/session/{user_space}/{session_id}/ - User: viking://temp/session/{user_space}/{session_id}/commit_{uuid}/user/{user_space}/ - Agent: viking://temp/session/{user_space}/{session_id}/commit_{uuid}/agent/{agent_space}/ - + Returns: (temp_base_uri, session_temp_uri, user_temp_uri, agent_temp_uri) """ @@ -319,30 +319,22 @@ def _create_temp_uris(self) -> Tuple[str, str, str, str]: f"{self.session_id}/" f"commit_{uuid4().hex[:8]}" ) - + # Match target URI structure for Semantic DAG recursive processing session_temp_uri = ( - f"{temp_base_uri}/session/" - f"{self.user.user_space_name()}/" - f"{self.session_id}" - ) - user_temp_uri = ( - f"{temp_base_uri}/user/" - f"{self.user.user_space_name()}" + f"{temp_base_uri}/session/{self.user.user_space_name()}/{self.session_id}" ) - agent_temp_uri = ( - f"{temp_base_uri}/agent/" - f"{self.user.agent_space_name()}" - ) - + user_temp_uri = f"{temp_base_uri}/user/{self.user.user_space_name()}" + agent_temp_uri = f"{temp_base_uri}/agent/{self.user.agent_space_name()}" + self._temp_base_uri = temp_base_uri self._session_temp_uri = session_temp_uri self._user_temp_uri = user_temp_uri self._agent_temp_uri = agent_temp_uri self._temp_created_at = time.time() - + return temp_base_uri, session_temp_uri, user_temp_uri, agent_temp_uri - + async def _cleanup_temp_uris(self) -> None: """Clean up all temp directories after commit.""" if self._temp_base_uri: @@ -360,7 +352,7 @@ async def _cleanup_temp_uris(self) -> None: async def commit_async(self) -> Dict[str, Any]: """Async commit session with Copy-on-Write pattern. - + Process: 1. Copy: Copy existing session, user and agent directories to temp 2. 
Write: Make all changes in temp @@ -380,17 +372,17 @@ async def commit_async(self) -> Dict[str, Any]: "semantic_msg_id": None, "stats": None, } - + if not self._messages: return result - + # ========== Phase 1: Copy ========== temp_base_uri, session_temp_uri, user_temp_uri, agent_temp_uri = self._create_temp_uris() result["temp_base_uri"] = temp_base_uri result["session_temp_uri"] = session_temp_uri result["user_temp_uri"] = user_temp_uri result["agent_temp_uri"] = agent_temp_uri - + try: # 1.1 Copy existing session to temp logger.info(f"Copying session {self.session_id} to temp: {session_temp_uri}") @@ -407,7 +399,7 @@ async def commit_async(self) -> Dict[str, Any]: await self._viking_fs.mkdir(session_temp_uri, exist_ok=True, ctx=self.ctx) else: raise - + # 1.2 Copy existing user directory to temp user_uri = f"viking://user/{self.user.user_space_name()}" logger.info(f"Copying user directory to temp: {user_temp_uri}") @@ -420,11 +412,11 @@ async def commit_async(self) -> Dict[str, Any]: logger.info(f"User directory copied to temp: {user_temp_uri}") except Exception as e: if "not found" in str(e).lower(): - logger.info(f"User directory not found, creating new temp") + logger.info("User directory not found, creating new temp") await self._viking_fs.mkdir(user_temp_uri, exist_ok=True, ctx=self.ctx) else: raise - + # 1.3 Copy existing agent directory to temp agent_uri = f"viking://agent/{self.user.agent_space_name()}" logger.info(f"Copying agent directory to temp: {agent_temp_uri}") @@ -437,37 +429,37 @@ async def commit_async(self) -> Dict[str, Any]: logger.info(f"Agent directory copied to temp: {agent_temp_uri}") except Exception as e: if "not found" in str(e).lower(): - logger.info(f"Agent directory not found, creating new temp") + logger.info("Agent directory not found, creating new temp") await self._viking_fs.mkdir(agent_temp_uri, exist_ok=True, ctx=self.ctx) else: raise - + except Exception as e: logger.error(f"Failed to copy directories to temp: {e}") await self._cleanup_temp_uris() raise - + # ========== Phase 2: Write (all changes in temp) ========== try: # 2.1 Archive current messages to temp self._compression.compression_index += 1 messages_to_archive = self._messages.copy() - + await self._write_archive_to_temp( temp_uri=session_temp_uri, index=self._compression.compression_index, messages=messages_to_archive, ) - + self._compression.original_count += len(messages_to_archive) result["archived"] = True - + self._messages.clear() logger.info( f"Archived: {len(messages_to_archive)} messages → " f"{session_temp_uri}/history/archive_{self._compression.compression_index:03d}/" ) - + # 2.2 Extract long-term memories (to temp user and agent directories) if self._session_compressor: logger.info( @@ -484,10 +476,10 @@ async def commit_async(self) -> Dict[str, Any]: logger.info(f"Extracted {len(memories)} memories to temp directories") result["memories_extracted"] = len(memories) self._stats.memories_extracted += len(memories) - + # 2.3 Write current messages to temp await self._write_messages_to_temp(session_temp_uri, self._messages) - + logger.info(f"Session changes written to temp: {session_temp_uri}") # 2.5 Update active_count active_count_updated = await self._update_active_counts_async() @@ -496,7 +488,7 @@ async def commit_async(self) -> Dict[str, Any]: logger.error(f"Failed to write changes to temp: {e}") await self._cleanup_temp_uris() raise - + # ========== Phase 3: Semantic = Switch =========== try: semantic_msg_ids = await self._enqueue_to_semantic_queue( @@ -504,15 +496,15 @@ 
async def commit_async(self) -> Dict[str, Any]: user_temp_uri=user_temp_uri, agent_temp_uri=agent_temp_uri, ) - + logger.info(f"Session, user, agent enqueued to SemanticQueue: {semantic_msg_ids}") result["semantic_msg_ids"] = semantic_msg_ids - + except Exception as e: logger.error(f"Failed to enqueue to SemanticQueue: {e}") await self._cleanup_temp_uris() raise - + # ========== Update statistics ========== self._stats.compression_count = self._compression.compression_index result["stats"] = { @@ -714,28 +706,28 @@ async def _write_archive_to_temp( messages: List[Message], ) -> None: """Write archive to temp directory. - + Note: .abstract.md and .overview.md will be generated by Semantic DAG. """ archive_uri = f"{temp_uri}/history/archive_{index:03d}" - + lines = [m.to_jsonl() for m in messages] await self._viking_fs.write_file( uri=f"{archive_uri}/messages.jsonl", content="\n".join(lines) + "\n", ctx=self.ctx, ) - + # Note: .abstract.md and .overview.md will be generated by Semantic DAG # No need to manually create them here - + logger.debug(f"Written archive to temp: {archive_uri}") async def _write_messages_to_temp(self, temp_uri: str, messages: List[Message]) -> None: """Write current messages to temp directory.""" lines = [m.to_jsonl() for m in messages] content = "\n".join(lines) + "\n" if lines else "" - + await self._viking_fs.write_file( uri=f"{temp_uri}/messages.jsonl", content=content, @@ -749,24 +741,24 @@ async def _enqueue_to_semantic_queue( agent_temp_uri: str, ) -> List[str]: """Enqueue session, user, and agent to SemanticQueue for L0/L1 generation. - + The SemanticProcessor will handle: 1. Generate L0/L1 for session, user and agent directories 2. Atomically switch temp URIs to target URIs 3. Create usage relations 4. Clean up temp URIs - + Returns: List of message IDs [session_msg_id, user_msg_id, agent_msg_id] """ from openviking.storage.queuefs import SemanticMsg, get_queue_manager - + queue_manager = get_queue_manager() semantic_queue = queue_manager.get_queue(queue_manager.SEMANTIC, allow_create=True) - + user_target_uri = f"viking://user/{self.user.user_space_name()}" agent_target_uri = f"viking://agent/{self.user.agent_space_name()}" - + session_msg = SemanticMsg( uri=session_temp_uri, context_type="memory", @@ -777,7 +769,7 @@ async def _enqueue_to_semantic_queue( role=self.ctx.role.value, recursive=True, ) - + user_msg = SemanticMsg( uri=user_temp_uri, context_type="memory", @@ -788,7 +780,7 @@ async def _enqueue_to_semantic_queue( role=self.ctx.role.value, recursive=True, ) - + agent_msg = SemanticMsg( uri=agent_temp_uri, context_type="memory", @@ -799,11 +791,11 @@ async def _enqueue_to_semantic_queue( role=self.ctx.role.value, recursive=True, ) - + await semantic_queue.enqueue(session_msg) await semantic_queue.enqueue(user_msg) await semantic_queue.enqueue(agent_msg) - + return [session_msg.id, user_msg.id, agent_msg.id] async def _write_archive_async( diff --git a/openviking/storage/collection_schemas.py b/openviking/storage/collection_schemas.py index 1087f42f..817eb02e 100644 --- a/openviking/storage/collection_schemas.py +++ b/openviking/storage/collection_schemas.py @@ -266,6 +266,7 @@ async def on_dequeue(self, data: Optional[Dict[str, Any]]) -> Optional[Dict[str, finally: if embedding_msg and embedding_msg.semantic_msg_id: from openviking.storage.queuefs.embedding_tracker import EmbeddingTaskTracker + tracker = EmbeddingTaskTracker.get_instance() try: await tracker.decrement(embedding_msg.semantic_msg_id) diff --git 
a/openviking/storage/queuefs/embedding_tracker.py b/openviking/storage/queuefs/embedding_tracker.py index a78d72b3..b5317c7c 100644 --- a/openviking/storage/queuefs/embedding_tracker.py +++ b/openviking/storage/queuefs/embedding_tracker.py @@ -14,23 +14,23 @@ @dataclass class EmbeddingTaskTracker: """Track embedding task completion status for each SemanticMsg. - + This tracker maintains a global registry of embedding tasks associated with each SemanticMsg. When all embedding tasks for a SemanticMsg are completed, it triggers the registered callback and removes the entry. """ - + _instance: Optional["EmbeddingTaskTracker"] = None _lock: asyncio.Lock = field(default_factory=asyncio.Lock) _tasks: Dict[str, Dict[str, Any]] = field(default_factory=dict) - + @classmethod def get_instance(cls) -> "EmbeddingTaskTracker": """Get the singleton instance of EmbeddingTaskTracker.""" if cls._instance is None: cls._instance = cls() return cls._instance - + async def register( self, semantic_msg_id: str, @@ -39,7 +39,7 @@ async def register( metadata: Optional[Dict[str, Any]] = None, ) -> None: """Register a SemanticMsg with its total embedding task count. - + Args: semantic_msg_id: The ID of the SemanticMsg total_count: Total number of embedding tasks for this SemanticMsg @@ -48,7 +48,7 @@ async def register( """ if total_count <= 0: return - + async with self._lock: self._tasks[semantic_msg_id] = { "remaining": total_count, @@ -60,64 +60,59 @@ async def register( f"Registered embedding tracker for SemanticMsg {semantic_msg_id}: " f"{total_count} tasks" ) - + async def increment(self, semantic_msg_id: str) -> Optional[int]: """Increment the remaining task count for a SemanticMsg. - + This method should be called when a new embedding task is added for an already registered SemanticMsg. - + Args: semantic_msg_id: The ID of the SemanticMsg - + Returns: The remaining count after increment, or None if not found """ async with self._lock: if semantic_msg_id not in self._tasks: return None - + task_info = self._tasks[semantic_msg_id] task_info["remaining"] += 1 task_info["total"] += 1 remaining = task_info["remaining"] - + return remaining - + async def decrement(self, semantic_msg_id: str) -> Optional[int]: """Decrement the remaining task count for a SemanticMsg. - + This method should be called when an embedding task is completed. When the count reaches zero, the registered callback is executed and the entry is removed from the tracker. - + Args: semantic_msg_id: The ID of the SemanticMsg - + Returns: The remaining count after decrement, or None if not found """ on_complete = None - metadata = None - + async with self._lock: if semantic_msg_id not in self._tasks: return None - + task_info = self._tasks[semantic_msg_id] task_info["remaining"] -= 1 remaining = task_info["remaining"] - + if remaining <= 0: on_complete = task_info.get("on_complete") - metadata = task_info.get("metadata", {}) - + del self._tasks[semantic_msg_id] - logger.info( - f"All embedding tasks completed for SemanticMsg {semantic_msg_id}" - ) - - + logger.info(f"All embedding tasks completed for SemanticMsg {semantic_msg_id}") + if on_complete: try: result = on_complete() @@ -129,13 +124,13 @@ async def decrement(self, semantic_msg_id: str) -> Optional[int]: exc_info=True, ) return remaining - + async def get_status(self, semantic_msg_id: str) -> Optional[Dict[str, Any]]: """Get the current status of a SemanticMsg's embedding tasks. 
- + Args: semantic_msg_id: The ID of the SemanticMsg - + Returns: Dict with 'remaining', 'total', 'metadata' or None if not found """ @@ -148,13 +143,13 @@ async def get_status(self, semantic_msg_id: str) -> Optional[Dict[str, Any]]: "total": task_info["total"], "metadata": task_info.get("metadata", {}), } - + async def remove(self, semantic_msg_id: str) -> bool: """Remove a SemanticMsg from the tracker. - + Args: semantic_msg_id: The ID of the SemanticMsg - + Returns: True if removed, False if not found """ @@ -163,10 +158,10 @@ async def remove(self, semantic_msg_id: str) -> bool: del self._tasks[semantic_msg_id] return True return False - + async def get_all_tracked(self) -> Dict[str, Dict[str, Any]]: """Get all currently tracked SemanticMsgs. - + Returns: Dict of semantic_msg_id -> task info """ diff --git a/openviking/storage/queuefs/semantic_dag.py b/openviking/storage/queuefs/semantic_dag.py index bc91c2a7..f7c0384a 100644 --- a/openviking/storage/queuefs/semantic_dag.py +++ b/openviking/storage/queuefs/semantic_dag.py @@ -70,7 +70,7 @@ def __init__( self._stats = DagStats() async def run(self, root_uri: str) -> None: - """Run DAG execution starting from root_uri.""" + """Run DAG execution starting from root_uri.""" self._root_uri = root_uri self._root_done = asyncio.Event() await self._dispatch_dir(root_uri, parent_uri=None) @@ -155,10 +155,12 @@ async def _list_dir(self, uri: str) -> tuple[list[str], list[str]]: def _get_target_file_path(self, current_uri: str) -> Optional[str]: if not self._incremental_update or not self._target_uri or not self._root_uri: - logger.warning(f"Invalid target_uri or root_uri for incremental update: target_uri={self._target_uri}, root_uri={self._root_uri}") + logger.warning( + f"Invalid target_uri or root_uri for incremental update: target_uri={self._target_uri}, root_uri={self._root_uri}" + ) return None try: - relative_path = current_uri[len(self._root_uri):] + relative_path = current_uri[len(self._root_uri) :] if relative_path.startswith("/"): relative_path = relative_path[1:] return f"{self._target_uri}/{relative_path}" if relative_path else self._target_uri @@ -199,7 +201,9 @@ async def _read_existing_summary(self, file_path: str) -> Optional[Dict[str, str pass return None - async def _check_dir_children_changed(self, dir_uri: str, current_files: List[str], current_dirs: List[str]) -> bool: + async def _check_dir_children_changed( + self, dir_uri: str, current_files: List[str], current_dirs: List[str] + ) -> bool: target_path = self._get_target_file_path(dir_uri) if not target_path: return True @@ -220,7 +224,9 @@ async def _check_dir_children_changed(self, dir_uri: str, current_files: List[st except Exception: return True - async def _read_existing_overview_abstract(self, dir_uri: str) -> tuple[Optional[str], Optional[str]]: + async def _read_existing_overview_abstract( + self, dir_uri: str + ) -> tuple[Optional[str], Optional[str]]: target_path = self._get_target_file_path(dir_uri) if not target_path: return None, None @@ -372,7 +378,11 @@ async def _overview_task(self, dir_uri: str) -> None: try: if need_vectorize: await self._processor._vectorize_directory( - dir_uri, self._context_type, abstract, overview, ctx=self._ctx, + dir_uri, + self._context_type, + abstract, + overview, + ctx=self._ctx, semantic_msg_id=self._semantic_msg_id, ) except Exception as e: diff --git a/openviking/storage/queuefs/semantic_processor.py b/openviking/storage/queuefs/semantic_processor.py index 39a5f4a6..c8ee7c43 100644 --- 
a/openviking/storage/queuefs/semantic_processor.py
+++ b/openviking/storage/queuefs/semantic_processor.py
@@ -4,7 +4,7 @@
 import asyncio
 from dataclasses import dataclass, field
-from typing import Any, Dict, List, Optional, Set, Tuple,Callable,Awaitable
+from typing import Any, Awaitable, Callable, Dict, List, Optional, Set, Tuple
 
 from openviking.parse.parsers.constants import (
     CODE_EXTENSIONS,
@@ -30,7 +30,6 @@
 from openviking_cli.utils.config import get_openviking_config
 from openviking_cli.utils.logger import get_logger
 
-
 from .embedding_tracker import EmbeddingTaskTracker
 
 logger = get_logger(__name__)
@@ -39,6 +38,7 @@
 @dataclass
 class DiffResult:
     """Directory diff result for sync operations."""
+
     added_files: List[str] = field(default_factory=list)
     deleted_files: List[str] = field(default_factory=list)
     updated_files: List[str] = field(default_factory=list)
@@ -98,15 +98,15 @@ async def _acquire_path_lock(
         metadata: Optional[Dict[str, Any]] = None,
     ) -> Optional[str]:
         """Acquire path lock to prevent concurrent processing of the same path."""
-        from openviking.resource.resource_lock import ResourceLockManager, ResourceLockConflictError
-
+        from openviking.resource.resource_lock import ResourceLockConflictError, ResourceLockManager
+
         viking_fs = get_viking_fs()
-        if not hasattr(viking_fs, 'agfs') or not viking_fs.agfs:
+        if not hasattr(viking_fs, "agfs") or not viking_fs.agfs:
             logger.warning("Cannot acquire path lock: agfs not available")
             return None
-
+
         lock_manager = ResourceLockManager(viking_fs.agfs)
-
+
         try:
             lock_info = lock_manager.acquire_lock(
                 resource_uri=resource_uri,
@@ -130,21 +130,21 @@ async def _release_path_lock(
         """Release path lock."""
         if not resource_uri or not lock_id:
             return False
-
+
         from openviking.resource.resource_lock import ResourceLockManager
-
+
         viking_fs = get_viking_fs()
-        if not hasattr(viking_fs, 'agfs') or not viking_fs.agfs:
+        if not hasattr(viking_fs, "agfs") or not viking_fs.agfs:
             return False
-
+
         lock_manager = ResourceLockManager(viking_fs.agfs)
         success = lock_manager.release_lock(resource_uri, lock_id)
-
+
         if success:
             logger.info(f"Released path lock for {resource_uri}, lock_id={lock_id}")
         else:
             logger.warning(f"Failed to release path lock for {resource_uri}, lock_id={lock_id}")
-
+
         return success
 
     def _detect_file_type(self, file_name: str) -> str:
@@ -186,7 +186,7 @@ async def on_dequeue(self, data: Optional[Dict[str, Any]]) -> Optional[Dict[str,
         """Process dequeued SemanticMsg, recursively process all subdirectories."""
         target_lock_id: Optional[str] = None
         source_lock_id: Optional[str] = None
-
+
         try:
             import json
 
@@ -201,9 +201,7 @@ async def on_dequeue(self, data: Optional[Dict[str, Any]]) -> Optional[Dict[str,
             msg = SemanticMsg.from_dict(data)
             self._current_msg = msg
             self._current_ctx = self._ctx_from_semantic_msg(msg)
-            logger.info(
-                f"Processing semantic generation for: {msg})"
-            )
+            logger.info(f"Processing semantic generation for: {msg}")
 
             # Check if target_uri exists, auto-detect incremental update
             is_incremental = False
@@ -243,7 +241,7 @@
                     "uri": msg.uri,
                     "root_lock_id": source_lock_id,
                     "target_lock_id": target_lock_id,
-                }
+                },
             )
             executor = SemanticDagExecutor(
                 processor=self,
@@ -260,7 +258,7 @@
             logger.info(f"Completed semantic generation for: {msg.uri}")
             self.report_success()
             return None
-
+
         except Exception as e:
             logger.error(f"Failed to process semantic message: {e}",
exc_info=True) self.report_error(str(e), data) @@ -277,7 +275,7 @@ def get_dag_stats(self) -> Optional["DagStats"]: if not self._dag_executor: return None return self._dag_executor.get_stats() - + def _create_sync_diff_callback( self, root_uri: str, @@ -300,14 +298,14 @@ def _create_sync_diff_callback( Returns: Async callback function """ - + async def sync_diff_callback() -> None: - + try: viking_fs = get_viking_fs() - + root_tree = await self._collect_tree_info(root_uri) - + target_tree = await self._collect_tree_info(target_uri) diff = await self._compute_diff(root_tree, target_tree, root_uri, target_uri) logger.info( @@ -328,19 +326,16 @@ async def sync_diff_callback() -> None: if target_uri != root_uri: await self._release_path_lock(target_uri, target_lock_id) except Exception as e: - logger.error( - f"[SyncDiff] Error releasing locks: {e}", - exc_info=True - ) - + logger.error(f"[SyncDiff] Error releasing locks: {e}", exc_info=True) + except Exception as e: logger.error( f"[SyncDiff] Error in sync_diff_callback: " f"root_uri={root_uri}, target_uri={target_uri}" f"error={e}", - exc_info=True + exc_info=True, ) - + return sync_diff_callback async def _collect_tree_info( @@ -360,7 +355,7 @@ async def _collect_tree_info( result: Dict[str, Tuple[List[str], List[str]]] = {} total_dirs = 0 total_files = 0 - + async def collect_recursive(current_uri: str, depth: int = 0) -> None: nonlocal total_dirs, total_files indent = " " * depth @@ -369,17 +364,17 @@ async def collect_recursive(current_uri: str, depth: int = 0) -> None: except Exception as e: logger.warning(f"[SyncDiff]{indent} Failed to list {current_uri}: {e}") return - + sub_dirs: List[str] = [] files: List[str] = [] - + for entry in entries: name = entry.get("name", "") if not name or name.startswith(".") or name in [".", ".."]: continue - + item_uri = VikingURI(current_uri).join(name).uri - + if entry.get("isDir", False): sub_dirs.append(item_uri) total_dirs += 1 @@ -387,9 +382,9 @@ async def collect_recursive(current_uri: str, depth: int = 0) -> None: else: files.append(item_uri) total_files += 1 - + result[current_uri] = (sub_dirs, files) - + await collect_recursive(uri) return result @@ -412,17 +407,18 @@ async def _compute_diff( Returns: DiffResult with added/deleted/updated files and directories """ + def get_relative_path(uri: str, base_uri: str) -> str: if uri.startswith(base_uri): - rel = uri[len(base_uri):] + rel = uri[len(base_uri) :] return rel.lstrip("/") return uri - + root_files: Set[str] = set() root_dirs: Set[str] = set() target_files: Set[str] = set() target_dirs: Set[str] = set() - + for dir_uri, (sub_dirs, files) in root_tree.items(): rel_dir = get_relative_path(dir_uri, root_uri) if rel_dir: @@ -431,7 +427,7 @@ def get_relative_path(uri: str, base_uri: str) -> str: root_files.add(get_relative_path(f, root_uri)) for d in sub_dirs: root_dirs.add(get_relative_path(d, root_uri)) - + for dir_uri, (sub_dirs, files) in target_tree.items(): rel_dir = get_relative_path(dir_uri, target_uri) if rel_dir: @@ -440,14 +436,14 @@ def get_relative_path(uri: str, base_uri: str) -> str: target_files.add(get_relative_path(f, target_uri)) for d in sub_dirs: target_dirs.add(get_relative_path(d, target_uri)) - + added_files_rel = root_files - target_files deleted_files_rel = target_files - root_files common_files = root_files & target_files - + added_dirs_rel = root_dirs - target_dirs deleted_dirs_rel = target_dirs - root_dirs - + updated_files: List[str] = [] for rel_file in common_files: root_file = f"{root_uri}/{rel_file}" @@ 
-460,12 +456,12 @@ def get_relative_path(uri: str, base_uri: str) -> str: f"[SyncDiff] Failed to compare file content for {rel_file}: {e}, " f"treating as unchanged" ) - + added_files = [f"{root_uri}/{f}" for f in added_files_rel] deleted_files = [f"{target_uri}/{f}" for f in deleted_files_rel] added_dirs = [f"{root_uri}/{d}" for d in added_dirs_rel] deleted_dirs = [f"{target_uri}/{d}" for d in deleted_dirs_rel] - + result = DiffResult( added_files=added_files, deleted_files=deleted_files, @@ -473,7 +469,7 @@ def get_relative_path(uri: str, base_uri: str) -> str: added_dirs=added_dirs, deleted_dirs=deleted_dirs, ) - + return result async def _execute_sync_operations( @@ -496,17 +492,17 @@ async def _execute_sync_operations( target_uri: Target directory URI """ viking_fs = get_viking_fs() - + def map_to_target(root_item_uri: str) -> str: if root_item_uri.startswith(root_uri): - rel = root_item_uri[len(root_uri):] + rel = root_item_uri[len(root_uri) :] return f"{target_uri}{rel}" if rel else target_uri return root_item_uri - + total_deleted = 0 total_moved = 0 total_failed = 0 - + for i, deleted_file in enumerate(diff.deleted_files, 1): try: await viking_fs.rm(deleted_file, ctx=self._current_ctx) @@ -525,7 +521,7 @@ def map_to_target(root_item_uri: str) -> str: logger.warning( f"[SyncDiff] Failed to remove old file [{i}/{len(diff.updated_files)}]: {target_file}, error={e}" ) - + files_to_move = diff.added_files + diff.updated_files for i, root_file in enumerate(files_to_move, 1): target_file = map_to_target(root_file) @@ -533,9 +529,13 @@ def map_to_target(root_item_uri: str) -> str: target_parent = VikingURI(target_file).parent if target_parent: try: - await viking_fs.mkdir(target_parent.uri, exist_ok=True, ctx=self._current_ctx) + await viking_fs.mkdir( + target_parent.uri, exist_ok=True, ctx=self._current_ctx + ) except Exception as mkdir_error: - logger.debug(f"[SyncDiff] Parent dir creation skipped (may already exist): {mkdir_error}") + logger.debug( + f"[SyncDiff] Parent dir creation skipped (may already exist): {mkdir_error}" + ) await viking_fs.mv(root_file, target_file, ctx=self._current_ctx) total_moved += 1 except Exception as e: @@ -544,7 +544,7 @@ def map_to_target(root_item_uri: str) -> str: f"[SyncDiff] Failed to move file [{i}/{len(files_to_move)}]: " f"{root_file} -> {target_file}, error={e}" ) - + for i, deleted_dir in enumerate( sorted(diff.deleted_dirs, key=lambda x: x.count("/"), reverse=True), 1 ): @@ -784,13 +784,12 @@ async def _vectorize_directory( return from openviking.utils.embedding_utils import vectorize_directory_meta + tracker = EmbeddingTaskTracker.get_instance() await tracker.increment( semantic_msg_id=semantic_msg_id, ) - await tracker.increment( - semantic_msg_id=semantic_msg_id - ) + await tracker.increment(semantic_msg_id=semantic_msg_id) active_ctx = ctx or self._current_ctx await vectorize_directory_meta( @@ -813,6 +812,7 @@ async def _vectorize_single_file( ) -> None: """Vectorize a single file using its content or summary.""" from openviking.utils.embedding_utils import vectorize_file + tracker = EmbeddingTaskTracker.get_instance() await tracker.increment( semantic_msg_id=semantic_msg_id, @@ -825,4 +825,4 @@ async def _vectorize_single_file( context_type=context_type, ctx=active_ctx, semantic_msg_id=semantic_msg_id, - ) \ No newline at end of file + ) diff --git a/openviking/storage/viking_fs.py b/openviking/storage/viking_fs.py index 6fc7398c..dda70dbd 100644 --- a/openviking/storage/viking_fs.py +++ b/openviking/storage/viking_fs.py @@ -23,6 +23,7 
@@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union from openviking.pyagfs.exceptions import AGFSHTTPError +from openviking.pyagfs.helpers import cp as agfs_cp from openviking.server.identity import RequestContext, Role from openviking.utils.time_utils import format_simplified, get_current_timestamp, parse_iso_datetime from openviking_cli.exceptions import NotFoundError @@ -347,6 +348,7 @@ async def exists(self, uri: str, ctx: Optional[RequestContext] = None) -> bool: return True except Exception: return False + async def glob( self, pattern: str, @@ -1460,6 +1462,29 @@ async def move_file( self.agfs.write(to_path, content) self.agfs.rm(from_path) + async def copy_directory( + self, + from_uri: str, + to_uri: str, + ctx: Optional[RequestContext] = None, + ) -> None: + """Copy directory recursively. + + Args: + from_uri: Source directory URI + to_uri: Destination directory URI + ctx: Request context + """ + self._ensure_access(from_uri, ctx) + self._ensure_access(to_uri, ctx) + + from_path = self._uri_to_path(from_uri, ctx=ctx) + to_path = self._uri_to_path(to_uri, ctx=ctx) + + await self._ensure_parent_dirs(to_path) + + await asyncio.to_thread(agfs_cp, self.agfs, from_path, to_path, recursive=True) + # ========== Temp File Operations (backward compatible) ========== def create_temp_uri(self) -> str: diff --git a/openviking/utils/resource_processor.py b/openviking/utils/resource_processor.py index 6fd43e17..f70c4e54 100644 --- a/openviking/utils/resource_processor.py +++ b/openviking/utils/resource_processor.py @@ -205,7 +205,7 @@ async def process_resource( resource_uris=[result["root_uri"]], ctx=ctx, skip_vectorization=skip_vec, - temp_uris = [temp_uri_for_summarize] + temp_uris=[temp_uri_for_summarize], **kwargs, ) except Exception as e: @@ -217,7 +217,11 @@ async def process_resource( # We assume this means "Ingest and Index", which requires summarization. try: await self._get_summarizer().summarize( - resource_uris=[result["root_uri"]], ctx=ctx, skip_vectorization=False,temp_uris = [temp_uri_for_summarize], **kwargs + resource_uris=[result["root_uri"]], + ctx=ctx, + skip_vectorization=False, + temp_uris=[temp_uri_for_summarize], + **kwargs, ) except Exception as e: logger.error(f"Auto-index failed: {e}") diff --git a/openviking/utils/summarizer.py b/openviking/utils/summarizer.py index ff1f6d25..3d059f6f 100644 --- a/openviking/utils/summarizer.py +++ b/openviking/utils/summarizer.py @@ -6,13 +6,14 @@ Handles summarization and key information extraction. 
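When temp_uris are supplied for the copy-on-write pattern they must pair
one-to-one with resource_uris; each (uri, temp_uri) pair is enqueued to the
semantic queue. An illustrative call, assuming a summarizer obtained via
_get_summarizer() and hypothetical URIs:

    await summarizer.summarize(
        resource_uris=["viking://default/resources/my-repo"],
        ctx=ctx,
        skip_vectorization=False,
        temp_uris=["viking://temp/resources/my-repo/commit_ab12cd34"],
    )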
""" -from typing import TYPE_CHECKING, Any, Dict, List, Optional -from openviking_cli.utils import get_logger +from typing import TYPE_CHECKING, Any, Dict, List + from openviking.storage.queuefs import SemanticMsg, get_queue_manager +from openviking_cli.utils import get_logger if TYPE_CHECKING: - from openviking.server.identity import RequestContext from openviking.parse.vlm import VLMProcessor + from openviking.server.identity import RequestContext logger = get_logger(__name__) @@ -46,7 +47,10 @@ async def summarize( logger.error( f"temp_uris length ({len(temp_uris)}) must match resource_uris length ({len(resource_uris)})" ) - return {"status": "error", "message": "temp_uris length must match resource_uris length"} + return { + "status": "error", + "message": "temp_uris length must match resource_uris length", + } enqueued_count = 0 for uri, temp_uri in zip(resource_uris, temp_uris): # Determine context_type based on URI From 66c67cd7c2c6a9bdb51fbd943458d22f5de19a9b Mon Sep 17 00:00:00 2001 From: yepper Date: Thu, 12 Mar 2026 17:57:15 +0800 Subject: [PATCH 3/5] refactor(storage): remove resource lock and improve semantic processing - Remove ResourceLockManager and related lock handling code - Simplify semantic processor by removing path locking mechanism - Improve error handling and logging in sync operations - Add new test files for storage components - Clean up unused imports and update dependencies --- openviking/resource/__init__.py | 16 - openviking/resource/resource_lock.py | 410 --------- openviking/service/core.py | 12 - .../storage/queuefs/embedding_tracker.py | 23 +- openviking/storage/queuefs/semantic_dag.py | 18 +- .../storage/queuefs/semantic_processor.py | 159 +--- tests/unit/session/test_compressor_cow.py | 458 ++++++++++ tests/unit/session/test_deduplicator_uri.py | 301 +++++++ .../session/test_memory_extractor_tools.py | 789 +++++++++++++++++ tests/unit/session/test_session_cow.py | 495 +++++++++++ tests/unit/storage/queuefs/__init__.py | 0 .../storage/queuefs/test_dag_incremental.py | 560 ++++++++++++ .../storage/queuefs/test_embedding_msg.py | 257 ++++++ .../storage/queuefs/test_embedding_tracker.py | 554 ++++++++++++ .../queuefs/test_processor_incremental.py | 832 ++++++++++++++++++ .../unit/storage/queuefs/test_semantic_msg.py | 406 +++++++++ tests/unit/storage/test_viking_fs_new.py | 195 ++++ 17 files changed, 4922 insertions(+), 563 deletions(-) delete mode 100644 openviking/resource/__init__.py delete mode 100644 openviking/resource/resource_lock.py create mode 100644 tests/unit/session/test_compressor_cow.py create mode 100644 tests/unit/session/test_deduplicator_uri.py create mode 100644 tests/unit/session/test_memory_extractor_tools.py create mode 100644 tests/unit/session/test_session_cow.py create mode 100644 tests/unit/storage/queuefs/__init__.py create mode 100644 tests/unit/storage/queuefs/test_dag_incremental.py create mode 100644 tests/unit/storage/queuefs/test_embedding_msg.py create mode 100644 tests/unit/storage/queuefs/test_embedding_tracker.py create mode 100644 tests/unit/storage/queuefs/test_processor_incremental.py create mode 100644 tests/unit/storage/queuefs/test_semantic_msg.py create mode 100644 tests/unit/storage/test_viking_fs_new.py diff --git a/openviking/resource/__init__.py b/openviking/resource/__init__.py deleted file mode 100644 index 8cbaa6b9..00000000 --- a/openviking/resource/__init__.py +++ /dev/null @@ -1,16 +0,0 @@ -# Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd. 
-# SPDX-License-Identifier: Apache-2.0 -"""Resource management modules for incremental updates.""" - -from openviking.resource.resource_lock import ( - ResourceLockConflictError, - ResourceLockError, - ResourceLockManager, -) - -__all__ = [ - "ResourceLockManager", - "ResourceLockConflictError", - "ResourceLockError", - "UpdateContext", -] diff --git a/openviking/resource/resource_lock.py b/openviking/resource/resource_lock.py deleted file mode 100644 index c24a4994..00000000 --- a/openviking/resource/resource_lock.py +++ /dev/null @@ -1,410 +0,0 @@ -# Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd. -# SPDX-License-Identifier: Apache-2.0 -""" -Resource-level mutex lock management. - -Implements resource URI-level mutual exclusion to prevent concurrent operations -on the same resource. Uses file-based locks stored in the AGFS filesystem. -""" - -import json -import time -import uuid -from contextlib import contextmanager -from dataclasses import dataclass, field -from typing import Any, Dict, Optional - -from openviking_cli.utils import get_logger - -logger = get_logger(__name__) - - -@dataclass -class LockInfo: - """Lock metadata stored in lock file.""" - - lock_id: str - resource_uri: str - operation: str - created_at: float - expires_at: Optional[float] = None - metadata: Dict[str, Any] = field(default_factory=dict) - - def to_dict(self) -> Dict[str, Any]: - return { - "lock_id": self.lock_id, - "resource_uri": self.resource_uri, - "operation": self.operation, - "created_at": self.created_at, - "expires_at": self.expires_at, - "metadata": self.metadata, - } - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "LockInfo": - return cls(**data) - - def is_expired(self) -> bool: - if self.expires_at is None: - return False - return time.time() > self.expires_at - - -class ResourceLockError(Exception): - """Base exception for resource lock errors.""" - - pass - - -class ResourceLockConflictError(ResourceLockError): - """Raised when attempting to lock a resource that is already locked.""" - - def __init__(self, resource_uri: str, lock_info: Optional[LockInfo] = None): - self.resource_uri = resource_uri - self.lock_info = lock_info - message = f"Resource '{resource_uri}' is locked" - if lock_info: - message += f" by operation '{lock_info.operation}' (lock_id: {lock_info.lock_id})" - super().__init__(message) - - -class ResourceLockManager: - """ - Manages resource-level mutex locks using file-based storage. - - Lock files are stored under `.locks/` directory in the AGFS root. - Each lock file is named after the resource URI (with path separators replaced). - - Features: - - Atomic lock acquisition via file creation - - Lock expiration detection - - Automatic cleanup of expired locks - - Service restart cleanup - """ - - LOCK_DIR = ".locks" - LOCK_FILE_SUFFIX = ".lock" - DEFAULT_TTL = 3600 - AGFS_MOUNT_PATH = "/local" - - def __init__(self, agfs: Any, default_ttl: Optional[int] = None): - """ - Initialize ResourceLockManager. - - Args: - agfs: AGFS client instance - default_ttl: Default lock TTL in seconds (default: 3600) - """ - self._agfs = agfs - self._default_ttl = default_ttl or self.DEFAULT_TTL - self._lock_dir_path = f"{self.AGFS_MOUNT_PATH}/{self.LOCK_DIR}" - - def _get_lock_file_path(self, resource_uri: str) -> str: - """ - Get lock file path for a resource URI. 
- - Args: - resource_uri: Resource URI (e.g., "viking://default/resources/my-repo") - - Returns: - Lock file path (e.g., "/local/.locks/viking___default___resources___my-repo.lock") - """ - safe_uri = resource_uri.replace("://", "___").replace("/", "___").replace(".", "_") - return f"{self._lock_dir_path}/{safe_uri}{self.LOCK_FILE_SUFFIX}" - - def _ensure_lock_dir(self) -> None: - """Ensure lock directory exists.""" - try: - if not self.exists(self._lock_dir_path): - self._agfs.mkdir(self._lock_dir_path) - logger.info(f"Created lock directory: {self._lock_dir_path}") - except Exception as e: - logger.warning(f"Failed to ensure lock directory: {e}") - - def acquire_lock( - self, - resource_uri: str, - operation: str, - ttl: Optional[int] = None, - metadata: Optional[Dict[str, Any]] = None, - ) -> LockInfo: - """ - Acquire a lock on a resource URI. - - Args: - resource_uri: Resource URI to lock - operation: Operation name (e.g., "incremental_update", "full_update") - ttl: Lock TTL in seconds (default: use default_ttl) - metadata: Additional metadata to store with lock - - Returns: - LockInfo for the acquired lock - - Raises: - ResourceLockConflictError: If resource is already locked - """ - self._ensure_lock_dir() - - lock_file = self._get_lock_file_path(resource_uri) - ttl = ttl or self._default_ttl - - current_time = time.time() - lock_info = LockInfo( - lock_id=str(uuid.uuid4()), - resource_uri=resource_uri, - operation=operation, - created_at=current_time, - expires_at=current_time + ttl if ttl > 0 else None, - metadata=metadata or {}, - ) - - try: - if self.exists(lock_file): - existing_lock = self._read_lock(lock_file) - if existing_lock and not existing_lock.is_expired(): - logger.warning( - f"Lock conflict: resource={resource_uri}, " - f"existing_lock_id={existing_lock.lock_id}, " - f"operation={existing_lock.operation}" - ) - raise ResourceLockConflictError(resource_uri, existing_lock) - - logger.info(f"Removing expired lock: {lock_file}") - self._agfs.rm(lock_file) - - self._agfs.write(lock_file, json.dumps(lock_info.to_dict()).encode("utf-8")) - - logger.info( - f"Acquired lock: resource={resource_uri}, " - f"lock_id={lock_info.lock_id}, " - f"operation={operation}, " - f"ttl={ttl}s" - ) - - return lock_info - - except ResourceLockConflictError: - raise - except Exception as e: - logger.error(f"Failed to acquire lock for {resource_uri}: {e}") - raise ResourceLockError(f"Failed to acquire lock: {e}") from e - - def release_lock(self, resource_uri: str, lock_id: Optional[str] = None) -> bool: - """ - Release a lock on a resource URI. - - Args: - resource_uri: Resource URI to unlock - lock_id: Optional lock ID to verify ownership - - Returns: - True if lock was released, False if lock didn't exist - """ - lock_file = self._get_lock_file_path(resource_uri) - - try: - if not self.exists(lock_file): - return False - - if lock_id: - existing_lock = self._read_lock(lock_file) - if existing_lock and existing_lock.lock_id != lock_id: - logger.warning( - f"Lock ID mismatch: expected={lock_id}, actual={existing_lock.lock_id}" - ) - return False - - self._agfs.rm(lock_file) - logger.info(f"Released lock: resource={resource_uri}, lock_id={lock_id}") - return True - - except Exception as e: - logger.error(f"Failed to release lock for {resource_uri}: {e}") - return False - - def is_locked(self, resource_uri: str) -> bool: - """ - Check if a resource URI is locked. 
- - Args: - resource_uri: Resource URI to check - - Returns: - True if resource is locked, False otherwise - """ - lock_file = self._get_lock_file_path(resource_uri) - - try: - if not self.exists(lock_file): - return False - - lock_info = self._read_lock(lock_file) - if not lock_info: - return False - - if lock_info.is_expired(): - logger.info(f"Found expired lock: {lock_file}") - self._agfs.rm(lock_file) - return False - - return True - - except Exception as e: - logger.error(f"Failed to check lock for {resource_uri}: {e}") - return False - - def get_lock_info(self, resource_uri: str) -> Optional[LockInfo]: - """ - Get lock information for a resource URI. - - Args: - resource_uri: Resource URI to check - - Returns: - LockInfo if resource is locked, None otherwise - """ - lock_file = self._get_lock_file_path(resource_uri) - - try: - if not self.exists(lock_file): - return None - - lock_info = self._read_lock(lock_file) - if not lock_info: - return None - - if lock_info.is_expired(): - logger.info(f"Found expired lock: {lock_file}") - self._agfs.rm(lock_file) - return None - - return lock_info - - except Exception as e: - logger.error(f"Failed to get lock info for {resource_uri}: {e}") - return None - - def _read_lock(self, lock_file: str) -> Optional[LockInfo]: - """Read lock information from a lock file.""" - try: - data = self._agfs.read(lock_file) - lock_dict = json.loads(data.decode("utf-8")) - return LockInfo.from_dict(lock_dict) - except Exception as e: - logger.error(f"Failed to read lock file {lock_file}: {e}") - return None - - def cleanup_expired_locks(self) -> int: - """ - Clean up all expired locks. - - Returns: - Number of locks cleaned up - """ - cleaned = 0 - - try: - if not self.exists(self._lock_dir_path): - return 0 - - lock_files = self._agfs.ls(self._lock_dir_path) - - for file_info in lock_files: - lock_file = file_info.get("name", "") - if not lock_file.endswith(self.LOCK_FILE_SUFFIX): - continue - - lock_path = f"{self._lock_dir_path}/{lock_file}" - lock_info = self._read_lock(lock_path) - - if not lock_info or lock_info.is_expired(): - self._agfs.rm(lock_path) - cleaned += 1 - logger.info(f"Cleaned up expired lock: {lock_path}") - - if cleaned > 0: - logger.info(f"Cleaned up {cleaned} expired locks") - - return cleaned - - except Exception as e: - logger.error(f"Failed to cleanup expired locks: {e}") - return cleaned - - def cleanup_all_locks(self) -> int: - """ - Clean up all locks (for service restart). - - Returns: - Number of locks cleaned up - """ - cleaned = 0 - - try: - if not self.exists(self._lock_dir_path): - return 0 - - lock_files = self._agfs.ls(self._lock_dir_path) - - for file_info in lock_files: - lock_file = file_info.get("name", "") - if not lock_file.endswith(self.LOCK_FILE_SUFFIX): - continue - - lock_path = f"{self._lock_dir_path}/{lock_file}" - self._agfs.rm(lock_path) - cleaned += 1 - - if cleaned > 0: - logger.info(f"Cleaned up {cleaned} locks on service restart") - - return cleaned - - except Exception as e: - logger.error(f"Failed to cleanup all locks: {e}") - return cleaned - - def exists(self, uri: str) -> bool: - """ - Check if a URI exists using AGFS stat interface. 
-
-        Args:
-            uri: URI to check (e.g., "viking://default/resources/my-repo")
-
-        Returns:
-            True if URI exists, False otherwise
-        """
-        try:
-            self._agfs.stat(uri)
-            return True
-        except Exception:
-            return False
-
-    @contextmanager
-    def lock(
-        self,
-        resource_uri: str,
-        operation: str,
-        ttl: Optional[int] = None,
-        metadata: Optional[Dict[str, Any]] = None,
-    ):
-        """
-        Context manager for acquiring and releasing a lock.
-
-        Args:
-            resource_uri: Resource URI to lock
-            operation: Operation name
-            ttl: Lock TTL in seconds
-            metadata: Additional metadata
-
-        Yields:
-            LockInfo for the acquired lock
-
-        Raises:
-            ResourceLockConflictError: If resource is already locked
-        """
-        lock_info = self.acquire_lock(resource_uri, operation, ttl, metadata)
-        try:
-            yield lock_info
-        finally:
-            self.release_lock(resource_uri, lock_info.lock_id)
diff --git a/openviking/service/core.py b/openviking/service/core.py
index e8957540..a19f3ba3 100644
--- a/openviking/service/core.py
+++ b/openviking/service/core.py
@@ -11,7 +11,6 @@
 from openviking.agfs_manager import AGFSManager
 from openviking.core.directories import DirectoryInitializer
-from openviking.resource.resource_lock import ResourceLockManager
 from openviking.server.identity import RequestContext, Role
 from openviking.service.debug_service import DebugService
 from openviking.service.fs_service import FSService
@@ -78,7 +77,6 @@ def __init__(
         self._session_compressor: Optional[SessionCompressor] = None
         self._transaction_manager: Optional[TransactionManager] = None
         self._directory_initializer: Optional[DirectoryInitializer] = None
-        self._lock_manager: Optional[ResourceLockManager] = None
 
         # Sub-services
         self._fs_service = FSService()
@@ -145,11 +143,6 @@ def _init_storage(
         # Initialize TransactionManager
         self._transaction_manager = init_transaction_manager(agfs=self._agfs_client)
 
-        # Initialize ResourceLockManager
-        if self._agfs_client:
-            self._lock_manager = ResourceLockManager(agfs=self._agfs_client)
-            logger.info("ResourceLockManager initialized")
-
     @property
     def _agfs(self) -> Any:
         """Internal access to AGFS client for APIKeyManager."""
@@ -261,11 +254,6 @@ async def initialize(self) -> None:
             user_count,
         )
 
-        # Clean up all locks on service startup
-        if self._lock_manager:
-            cleaned_count = self._lock_manager.cleanup_all_locks()
-            logger.info(f"Cleaned up {cleaned_count} locks on service startup")
-
         # Initialize processors
         self._resource_processor = ResourceProcessor(
             vikingdb=self._vikingdb_manager,
diff --git a/openviking/storage/queuefs/embedding_tracker.py b/openviking/storage/queuefs/embedding_tracker.py
index b5317c7c..c9fd6ab6 100644
--- a/openviking/storage/queuefs/embedding_tracker.py
+++ b/openviking/storage/queuefs/embedding_tracker.py
@@ -46,9 +46,6 @@ async def register(
             on_complete: Optional callback when all tasks complete
             metadata: Optional metadata to store with the task
         """
-        if total_count <= 0:
-            return
-
         async with self._lock:
             self._tasks[semantic_msg_id] = {
                 "remaining": total_count,
@@ -61,6 +58,22 @@
                 f"{total_count} tasks"
             )
 
+        if total_count <= 0 and on_complete:
+            del self._tasks[semantic_msg_id]
+            logger.info(
+                f"No embedding tasks for SemanticMsg {semantic_msg_id}, "
+                f"triggering on_complete immediately"
+            )
+            try:
+                result = on_complete()
+                if asyncio.iscoroutine(result):
+                    await result
+            except Exception as e:
+                logger.error(
+                    f"Error in completion callback for {semantic_msg_id}: {e}",
+                    exc_info=True,
+                )
+
     async def increment(self, semantic_msg_id:
str) -> Optional[int]: """Increment the remaining task count for a SemanticMsg. @@ -111,7 +126,7 @@ async def decrement(self, semantic_msg_id: str) -> Optional[int]: on_complete = task_info.get("on_complete") del self._tasks[semantic_msg_id] - logger.info(f"All embedding tasks completed for SemanticMsg {semantic_msg_id}") + logger.info(f"All embedding tasks({task_info['total']}) completed for SemanticMsg {semantic_msg_id}") if on_complete: try: diff --git a/openviking/storage/queuefs/semantic_dag.py b/openviking/storage/queuefs/semantic_dag.py index f7c0384a..397250c5 100644 --- a/openviking/storage/queuefs/semantic_dag.py +++ b/openviking/storage/queuefs/semantic_dag.py @@ -377,16 +377,18 @@ async def _overview_task(self, dir_uri: str) -> None: try: if need_vectorize: - await self._processor._vectorize_directory( - dir_uri, - self._context_type, - abstract, - overview, - ctx=self._ctx, - semantic_msg_id=self._semantic_msg_id, + asyncio.create_task( + self._processor._vectorize_directory( + dir_uri, + self._context_type, + abstract, + overview, + ctx=self._ctx, + semantic_msg_id=self._semantic_msg_id, + ) ) except Exception as e: - logger.error(f"Failed to vectorize directory {dir_uri}: {e}", exc_info=True) + logger.error(f"Failed to schedule vectorization for {dir_uri}: {e}", exc_info=True) except Exception as e: logger.error(f"Failed to generate overview for {dir_uri}: {e}", exc_info=True) diff --git a/openviking/storage/queuefs/semantic_processor.py b/openviking/storage/queuefs/semantic_processor.py index c8ee7c43..f60d37eb 100644 --- a/openviking/storage/queuefs/semantic_processor.py +++ b/openviking/storage/queuefs/semantic_processor.py @@ -92,61 +92,6 @@ def _ctx_from_semantic_msg(msg: SemanticMsg) -> RequestContext: role=role, ) - async def _acquire_path_lock( - self, - resource_uri: str, - metadata: Optional[Dict[str, Any]] = None, - ) -> Optional[str]: - """Acquire path lock to prevent concurrent processing of the same path.""" - from openviking.resource.resource_lock import ResourceLockConflictError, ResourceLockManager - - viking_fs = get_viking_fs() - if not hasattr(viking_fs, "agfs") or not viking_fs.agfs: - logger.warning("Cannot acquire path lock: agfs not available") - return None - - lock_manager = ResourceLockManager(viking_fs.agfs) - - try: - lock_info = lock_manager.acquire_lock( - resource_uri=resource_uri, - operation="path_processing", - metadata=metadata or {}, - ) - logger.info(f"Acquired path lock for {resource_uri}, lock_id={lock_info.lock_id}") - return lock_info.lock_id - except ResourceLockConflictError as e: - logger.warning(f"Path lock conflict for {resource_uri}: {e}") - raise - except Exception as e: - logger.error(f"Failed to acquire path lock for {resource_uri}: {e}") - return None - - async def _release_path_lock( - self, - resource_uri: str, - lock_id: Optional[str], - ) -> bool: - """Release path lock.""" - if not resource_uri or not lock_id: - return False - - from openviking.resource.resource_lock import ResourceLockManager - - viking_fs = get_viking_fs() - if not hasattr(viking_fs, "agfs") or not viking_fs.agfs: - return False - - lock_manager = ResourceLockManager(viking_fs.agfs) - success = lock_manager.release_lock(resource_uri, lock_id) - - if success: - logger.info(f"Released path lock for {resource_uri}, lock_id={lock_id}") - else: - logger.warning(f"Failed to release path lock for {resource_uri}, lock_id={lock_id}") - - return success - def _detect_file_type(self, file_name: str) -> str: """ Detect file type based on extension using 
constants from code parser. @@ -172,21 +117,21 @@ def _detect_file_type(self, file_name: str) -> str: # Default to other return FILE_TYPE_OTHER - async def _check_file_content_changed(self, file_path: str, target_file: str) -> bool: + async def _check_file_content_changed( + self, file_path: str, target_file: str, ctx: Optional[RequestContext] = None + ) -> bool: """Check if file content has changed compared to target file.""" viking_fs = get_viking_fs() try: - current_content = await viking_fs.read_file(file_path, ctx=self._current_ctx) - target_content = await viking_fs.read_file(target_file, ctx=self._current_ctx) + current_content = await viking_fs.read_file(file_path, ctx=ctx) + target_content = await viking_fs.read_file(target_file, ctx=ctx) return current_content != target_content except Exception: return True async def on_dequeue(self, data: Optional[Dict[str, Any]]) -> Optional[Dict[str, Any]]: """Process dequeued SemanticMsg, recursively process all subdirectories.""" - target_lock_id: Optional[str] = None - source_lock_id: Optional[str] = None - + msg = None try: import json @@ -212,35 +157,19 @@ async def on_dequeue(self, data: Optional[Dict[str, Any]]) -> Optional[Dict[str, is_incremental = True logger.info(f"Target URI exists, using incremental update: {msg.target_uri}") - # Acquire target_uri path lock - if msg.target_uri: - target_lock_id = await self._acquire_path_lock( - resource_uri=msg.target_uri, - metadata={"msg_id": msg.id, "uri": msg.uri}, - ) - - # Acquire uri path lock if uri != target_uri - if msg.uri != msg.target_uri and msg.uri: - source_lock_id = await self._acquire_path_lock( - resource_uri=msg.uri, - metadata={"msg_id": msg.id, "target_uri": msg.target_uri}, - ) - tracker = EmbeddingTaskTracker.get_instance() on_complete = self._create_sync_diff_callback( root_uri=msg.uri, target_uri=msg.target_uri, - root_lock_id=source_lock_id, - target_lock_id=target_lock_id, + ctx=self._current_ctx, ) + # Register task with tracker, total_count=1 for root URI await tracker.register( semantic_msg_id=msg.id, total_count=1, on_complete=on_complete, metadata={ "uri": msg.uri, - "root_lock_id": source_lock_id, - "target_lock_id": target_lock_id, }, ) executor = SemanticDagExecutor( @@ -264,10 +193,12 @@ async def on_dequeue(self, data: Optional[Dict[str, Any]]) -> Optional[Dict[str, self.report_error(str(e), data) return None finally: - tracker = EmbeddingTaskTracker.get_instance() - await tracker.decrement( - semantic_msg_id=msg.id, - ) + # Decrement task counter for root URI + if msg is not None: + tracker = EmbeddingTaskTracker.get_instance() + await tracker.decrement( + semantic_msg_id=msg.id, + ) self._current_msg = None self._current_ctx = None @@ -280,20 +211,18 @@ def _create_sync_diff_callback( self, root_uri: str, target_uri: str, - root_lock_id: Optional[str] = None, - target_lock_id: Optional[str] = None, + ctx: RequestContext, ) -> Callable[[], Awaitable[None]]: """ Create a callback function to sync directory differences. This callback compares root_uri (new content) with target_uri (old content), - handles added/updated/deleted files, then cleans up root_uri and releases lock. + handles added/updated/deleted files, then cleans up root_uri. 
Args: root_uri: Source directory URI (new content) target_uri: Target directory URI (old content) - root_lock_id: Lock ID for root_uri - target_lock_id: Lock ID for target_uri + ctx: Request context (captured at callback creation time) Returns: Async callback function @@ -304,10 +233,10 @@ async def sync_diff_callback() -> None: try: viking_fs = get_viking_fs() - root_tree = await self._collect_tree_info(root_uri) + root_tree = await self._collect_tree_info(root_uri, ctx=ctx) - target_tree = await self._collect_tree_info(target_uri) - diff = await self._compute_diff(root_tree, target_tree, root_uri, target_uri) + target_tree = await self._collect_tree_info(target_uri, ctx=ctx) + diff = await self._compute_diff(root_tree, target_tree, root_uri, target_uri, ctx=ctx) logger.info( f"[SyncDiff] Diff computed: " f"added_files={len(diff.added_files)}, " @@ -316,22 +245,16 @@ async def sync_diff_callback() -> None: f"added_dirs={len(diff.added_dirs)}, " f"deleted_dirs={len(diff.deleted_dirs)}" ) - await self._execute_sync_operations(diff, root_uri, target_uri) + await self._execute_sync_operations(diff, root_uri, target_uri, ctx=ctx) try: - await viking_fs.rm(root_uri, recursive=True, ctx=self._current_ctx) + await viking_fs.rm(root_uri, recursive=True, ctx=ctx) except Exception as e: logger.warning(f"[SyncDiff] Failed to delete root directory {root_uri}: {e}") - try: - await self._release_path_lock(root_uri, root_lock_id) - if target_uri != root_uri: - await self._release_path_lock(target_uri, target_lock_id) - except Exception as e: - logger.error(f"[SyncDiff] Error releasing locks: {e}", exc_info=True) except Exception as e: logger.error( f"[SyncDiff] Error in sync_diff_callback: " - f"root_uri={root_uri}, target_uri={target_uri}" + f"root_uri={root_uri}, target_uri={target_uri} " f"error={e}", exc_info=True, ) @@ -341,12 +264,14 @@ async def sync_diff_callback() -> None: async def _collect_tree_info( self, uri: str, + ctx: Optional[RequestContext] = None, ) -> Dict[str, Tuple[List[str], List[str]]]: """ Recursively collect directory tree information. Args: uri: Directory URI + ctx: Request context Returns: Dictionary: {dir_uri: ([subdir_uris], [file_uris])} @@ -360,7 +285,7 @@ async def collect_recursive(current_uri: str, depth: int = 0) -> None: nonlocal total_dirs, total_files indent = " " * depth try: - entries = await viking_fs.ls(current_uri, ctx=self._current_ctx) + entries = await viking_fs.ls(current_uri, show_all_hidden=True, ctx=ctx) except Exception as e: logger.warning(f"[SyncDiff]{indent} Failed to list {current_uri}: {e}") return @@ -370,7 +295,9 @@ async def collect_recursive(current_uri: str, depth: int = 0) -> None: for entry in entries: name = entry.get("name", "") - if not name or name.startswith(".") or name in [".", ".."]: + if not name or name in [".", ".."]: + continue + if name.startswith(".") and name not in [".abstract.md", ".overview.md"]: continue item_uri = VikingURI(current_uri).join(name).uri @@ -394,6 +321,7 @@ async def _compute_diff( target_tree: Dict[str, Tuple[List[str], List[str]]], root_uri: str, target_uri: str, + ctx: Optional[RequestContext] = None, ) -> DiffResult: """ Compute differences between two directory trees. 
@@ -403,6 +331,7 @@ async def _compute_diff( target_tree: Directory tree from target_uri root_uri: Source directory URI target_uri: Target directory URI + ctx: Request context Returns: DiffResult with added/deleted/updated files and directories @@ -449,7 +378,7 @@ def get_relative_path(uri: str, base_uri: str) -> str: root_file = f"{root_uri}/{rel_file}" target_file = f"{target_uri}/{rel_file}" try: - if await self._check_file_content_changed(root_file, target_file): + if await self._check_file_content_changed(root_file, target_file, ctx=ctx): updated_files.append(root_file) except Exception as e: logger.warning( @@ -477,6 +406,7 @@ async def _execute_sync_operations( diff: DiffResult, root_uri: str, target_uri: str, + ctx: Optional[RequestContext] = None, ) -> None: """ Execute sync operations based on diff result. @@ -490,6 +420,7 @@ async def _execute_sync_operations( diff: DiffResult containing operations to perform root_uri: Source directory URI target_uri: Target directory URI + ctx: Request context """ viking_fs = get_viking_fs() @@ -505,7 +436,7 @@ def map_to_target(root_item_uri: str) -> str: for i, deleted_file in enumerate(diff.deleted_files, 1): try: - await viking_fs.rm(deleted_file, ctx=self._current_ctx) + await viking_fs.rm(deleted_file, ctx=ctx) total_deleted += 1 except Exception as e: total_failed += 1 @@ -516,7 +447,7 @@ def map_to_target(root_item_uri: str) -> str: for i, updated_file in enumerate(diff.updated_files, 1): target_file = map_to_target(updated_file) try: - await viking_fs.rm(target_file, ctx=self._current_ctx) + await viking_fs.rm(target_file, ctx=ctx) except Exception as e: logger.warning( f"[SyncDiff] Failed to remove old file [{i}/{len(diff.updated_files)}]: {target_file}, error={e}" @@ -530,13 +461,13 @@ def map_to_target(root_item_uri: str) -> str: if target_parent: try: await viking_fs.mkdir( - target_parent.uri, exist_ok=True, ctx=self._current_ctx + target_parent.uri, exist_ok=True, ctx=ctx ) except Exception as mkdir_error: logger.debug( f"[SyncDiff] Parent dir creation skipped (may already exist): {mkdir_error}" ) - await viking_fs.mv(root_file, target_file, ctx=self._current_ctx) + await viking_fs.mv(root_file, target_file, ctx=ctx) total_moved += 1 except Exception as e: total_failed += 1 @@ -549,7 +480,7 @@ def map_to_target(root_item_uri: str) -> str: sorted(diff.deleted_dirs, key=lambda x: x.count("/"), reverse=True), 1 ): try: - await viking_fs.rm(deleted_dir, recursive=True, ctx=self._current_ctx) + await viking_fs.rm(deleted_dir, recursive=True, ctx=ctx) except Exception as e: total_failed += 1 logger.warning( @@ -557,13 +488,15 @@ def map_to_target(root_item_uri: str) -> str: f"{deleted_dir}, error={e}" ) - async def _collect_children_abstracts(self, children_uris: List[str]) -> List[Dict[str, str]]: + async def _collect_children_abstracts( + self, children_uris: List[str], ctx: Optional[RequestContext] = None + ) -> List[Dict[str, str]]: """Collect .abstract.md from subdirectories.""" viking_fs = get_viking_fs() results = [] for child_uri in children_uris: - abstract = await viking_fs.abstract(child_uri, ctx=self._current_ctx) + abstract = await viking_fs.abstract(child_uri, ctx=ctx) dir_name = child_uri.split("/")[-1] results.append({"name": dir_name, "abstract": abstract}) return results @@ -786,9 +719,9 @@ async def _vectorize_directory( from openviking.utils.embedding_utils import vectorize_directory_meta tracker = EmbeddingTaskTracker.get_instance() - await tracker.increment( - semantic_msg_id=semantic_msg_id, - ) + # Increment 
task for .abstract.md + await tracker.increment(semantic_msg_id=semantic_msg_id) + # Increment task for .overview.md await tracker.increment(semantic_msg_id=semantic_msg_id) active_ctx = ctx or self._current_ctx diff --git a/tests/unit/session/test_compressor_cow.py b/tests/unit/session/test_compressor_cow.py new file mode 100644 index 00000000..cf6efe12 --- /dev/null +++ b/tests/unit/session/test_compressor_cow.py @@ -0,0 +1,458 @@ +# Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd. +# SPDX-License-Identifier: Apache-2.0 + +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from openviking.core.context import Context +from openviking.session.compressor import SessionCompressor +from openviking.session.memory_extractor import CandidateMemory, MemoryCategory +from openviking_cli.session.user_id import UserIdentifier + + +def _make_user() -> UserIdentifier: + return UserIdentifier("acc1", "test_user", "test_agent") + + +def _make_candidate(category: MemoryCategory = MemoryCategory.PREFERENCES) -> CandidateMemory: + return CandidateMemory( + category=category, + abstract="User prefers concise summaries", + overview="User asks for concise answers frequently.", + content="The user prefers concise summaries over long explanations.", + source_session="session_test", + user=_make_user(), + language="en", + ) + + +def _make_context(uri: str, abstract: str = "Existing memory") -> Context: + return Context( + uri=uri, + context_type="memory", + abstract=abstract, + meta={"overview": "Existing overview"}, + ) + + +class TestConvertToTempUri: + def test_user_uri_converted_to_temp_uri(self): + vikingdb = MagicMock() + compressor = SessionCompressor(vikingdb=vikingdb) + + user_space = _make_user().user_space_name() + target_uri = f"viking://user/{user_space}/memories/preferences/pref1.md" + user_temp_uri = "viking://user/temp_user_123" + + result = compressor._convert_to_temp_uri(target_uri, user_temp_uri, None) + + expected = f"{user_temp_uri}/memories/preferences/pref1.md" + assert result == expected + + def test_agent_uri_converted_to_temp_uri(self): + vikingdb = MagicMock() + compressor = SessionCompressor(vikingdb=vikingdb) + + agent_space = _make_user().agent_space_name() + target_uri = f"viking://agent/{agent_space}/memories/cases/case1.md" + agent_temp_uri = "viking://agent/temp_agent_456" + + result = compressor._convert_to_temp_uri(target_uri, None, agent_temp_uri) + + expected = f"{agent_temp_uri}/memories/cases/case1.md" + assert result == expected + + def test_no_temp_uri_returns_original_uri(self): + vikingdb = MagicMock() + compressor = SessionCompressor(vikingdb=vikingdb) + + user_space = _make_user().user_space_name() + target_uri = f"viking://user/{user_space}/memories/preferences/pref1.md" + + result = compressor._convert_to_temp_uri(target_uri, None, None) + + assert result == target_uri + + def test_mixed_uris_only_convert_matching_type(self): + vikingdb = MagicMock() + compressor = SessionCompressor(vikingdb=vikingdb) + + user_space = _make_user().user_space_name() + agent_space = _make_user().agent_space_name() + + user_uri = f"viking://user/{user_space}/memories/preferences/pref1.md" + agent_uri = f"viking://agent/{agent_space}/memories/cases/case1.md" + + user_temp_uri = "viking://user/temp_user_123" + + result_user = compressor._convert_to_temp_uri(user_uri, user_temp_uri, None) + result_agent = compressor._convert_to_temp_uri(agent_uri, user_temp_uri, None) + + assert result_user == f"{user_temp_uri}/memories/preferences/pref1.md" + assert 
result_agent == agent_uri + + def test_agent_uri_with_both_temp_uris(self): + vikingdb = MagicMock() + compressor = SessionCompressor(vikingdb=vikingdb) + + agent_space = _make_user().agent_space_name() + target_uri = f"viking://agent/{agent_space}/memories/patterns/pattern1.md" + user_temp_uri = "viking://user/temp_user_123" + agent_temp_uri = "viking://agent/temp_agent_456" + + result = compressor._convert_to_temp_uri(target_uri, user_temp_uri, agent_temp_uri) + + expected = f"{agent_temp_uri}/memories/patterns/pattern1.md" + assert result == expected + + def test_user_uri_with_both_temp_uris(self): + vikingdb = MagicMock() + compressor = SessionCompressor(vikingdb=vikingdb) + + user_space = _make_user().user_space_name() + target_uri = f"viking://user/{user_space}/memories/entities/entity1.md" + user_temp_uri = "viking://user/temp_user_123" + agent_temp_uri = "viking://agent/temp_agent_456" + + result = compressor._convert_to_temp_uri(target_uri, user_temp_uri, agent_temp_uri) + + expected = f"{user_temp_uri}/memories/entities/entity1.md" + assert result == expected + + def test_non_viking_uri_returns_unchanged(self): + vikingdb = MagicMock() + compressor = SessionCompressor(vikingdb=vikingdb) + + target_uri = "file:///some/local/path/memory.md" + user_temp_uri = "viking://user/temp_user_123" + + result = compressor._convert_to_temp_uri(target_uri, user_temp_uri, None) + + assert result == target_uri + + def test_short_uri_returns_unchanged(self): + vikingdb = MagicMock() + compressor = SessionCompressor(vikingdb=vikingdb) + + target_uri = "viking://user" + user_temp_uri = "viking://user/temp_user_123" + + result = compressor._convert_to_temp_uri(target_uri, user_temp_uri, None) + + assert result == target_uri + + +@pytest.mark.asyncio +class TestMergeIntoExisting: + async def test_merge_into_existing_success(self): + vikingdb = MagicMock() + compressor = SessionCompressor(vikingdb=vikingdb) + + user_space = _make_user().user_space_name() + target_uri = f"viking://user/{user_space}/memories/preferences/pref1.md" + user_temp_uri = "viking://user/temp_user_123" + temp_uri = f"{user_temp_uri}/memories/preferences/pref1.md" + + candidate = _make_candidate() + target_memory = _make_context(target_uri) + + viking_fs = AsyncMock() + viking_fs.read_file = AsyncMock(return_value="Existing content") + viking_fs.write_file = AsyncMock() + + mock_payload = MagicMock() + mock_payload.abstract = "Merged abstract" + mock_payload.overview = "Merged overview" + mock_payload.content = "Merged content" + + with patch.object( + compressor.extractor, + "_merge_memory_bundle", + AsyncMock(return_value=mock_payload), + ): + ctx = MagicMock() + result = await compressor._merge_into_existing( + candidate, + target_memory, + viking_fs, + ctx=ctx, + user_temp_uri=user_temp_uri, + agent_temp_uri=None, + ) + + assert result is True + viking_fs.read_file.assert_called_once() + viking_fs.write_file.assert_called_once() + assert target_memory.abstract == "Merged abstract" + assert target_memory.meta.get("overview") == "Merged overview" + + async def test_merge_into_existing_without_temp_uri(self): + vikingdb = MagicMock() + compressor = SessionCompressor(vikingdb=vikingdb) + + user_space = _make_user().user_space_name() + target_uri = f"viking://user/{user_space}/memories/preferences/pref1.md" + + candidate = _make_candidate() + target_memory = _make_context(target_uri) + + viking_fs = AsyncMock() + viking_fs.read_file = AsyncMock(return_value="Existing content") + viking_fs.write_file = AsyncMock() + + mock_payload = 
MagicMock() + mock_payload.abstract = "Merged abstract" + mock_payload.overview = "Merged overview" + mock_payload.content = "Merged content" + + with patch.object( + compressor.extractor, + "_merge_memory_bundle", + AsyncMock(return_value=mock_payload), + ): + ctx = MagicMock() + result = await compressor._merge_into_existing( + candidate, + target_memory, + viking_fs, + ctx=ctx, + user_temp_uri=None, + agent_temp_uri=None, + ) + + assert result is True + viking_fs.read_file.assert_called_once_with(target_uri, ctx=ctx) + + async def test_merge_into_existing_merge_bundle_returns_none(self): + vikingdb = MagicMock() + compressor = SessionCompressor(vikingdb=vikingdb) + + user_space = _make_user().user_space_name() + target_uri = f"viking://user/{user_space}/memories/preferences/pref1.md" + user_temp_uri = "viking://user/temp_user_123" + + candidate = _make_candidate() + target_memory = _make_context(target_uri) + + viking_fs = AsyncMock() + viking_fs.read_file = AsyncMock(return_value="Existing content") + + with patch.object( + compressor.extractor, + "_merge_memory_bundle", + AsyncMock(return_value=None), + ): + ctx = MagicMock() + result = await compressor._merge_into_existing( + candidate, + target_memory, + viking_fs, + ctx=ctx, + user_temp_uri=user_temp_uri, + agent_temp_uri=None, + ) + + assert result is False + viking_fs.write_file.assert_not_called() + + async def test_merge_into_existing_read_file_exception(self): + vikingdb = MagicMock() + compressor = SessionCompressor(vikingdb=vikingdb) + + user_space = _make_user().user_space_name() + target_uri = f"viking://user/{user_space}/memories/preferences/pref1.md" + user_temp_uri = "viking://user/temp_user_123" + + candidate = _make_candidate() + target_memory = _make_context(target_uri) + + viking_fs = AsyncMock() + viking_fs.read_file = AsyncMock(side_effect=Exception("Read error")) + + ctx = MagicMock() + result = await compressor._merge_into_existing( + candidate, + target_memory, + viking_fs, + ctx=ctx, + user_temp_uri=user_temp_uri, + agent_temp_uri=None, + ) + + assert result is False + + async def test_merge_into_existing_agent_uri(self): + vikingdb = MagicMock() + compressor = SessionCompressor(vikingdb=vikingdb) + + agent_space = _make_user().agent_space_name() + target_uri = f"viking://agent/{agent_space}/memories/cases/case1.md" + agent_temp_uri = "viking://agent/temp_agent_456" + + candidate = _make_candidate(category=MemoryCategory.CASES) + target_memory = _make_context(target_uri) + + viking_fs = AsyncMock() + viking_fs.read_file = AsyncMock(return_value="Existing content") + viking_fs.write_file = AsyncMock() + + mock_payload = MagicMock() + mock_payload.abstract = "Merged case abstract" + mock_payload.overview = "Merged case overview" + mock_payload.content = "Merged case content" + + with patch.object( + compressor.extractor, + "_merge_memory_bundle", + AsyncMock(return_value=mock_payload), + ): + ctx = MagicMock() + result = await compressor._merge_into_existing( + candidate, + target_memory, + viking_fs, + ctx=ctx, + user_temp_uri=None, + agent_temp_uri=agent_temp_uri, + ) + + assert result is True + expected_temp_uri = f"{agent_temp_uri}/memories/cases/case1.md" + viking_fs.read_file.assert_called_once_with(expected_temp_uri, ctx=ctx) + + +@pytest.mark.asyncio +class TestDeleteExistingMemory: + async def test_delete_existing_memory_success(self): + vikingdb = MagicMock() + vikingdb.delete_uris = AsyncMock() + compressor = SessionCompressor(vikingdb=vikingdb) + + user_space = _make_user().user_space_name() + 
target_uri = f"viking://user/{user_space}/memories/preferences/pref1.md" + user_temp_uri = "viking://user/temp_user_123" + temp_uri = f"{user_temp_uri}/memories/preferences/pref1.md" + + memory = _make_context(target_uri) + + viking_fs = AsyncMock() + viking_fs.rm = AsyncMock() + + ctx = MagicMock() + result = await compressor._delete_existing_memory( + memory, + viking_fs, + ctx=ctx, + user_temp_uri=user_temp_uri, + agent_temp_uri=None, + ) + + assert result is True + viking_fs.rm.assert_called_once_with(temp_uri, recursive=False, ctx=ctx) + vikingdb.delete_uris.assert_called_once_with(ctx, [temp_uri]) + + async def test_delete_existing_memory_without_temp_uri(self): + vikingdb = MagicMock() + vikingdb.delete_uris = AsyncMock() + compressor = SessionCompressor(vikingdb=vikingdb) + + user_space = _make_user().user_space_name() + target_uri = f"viking://user/{user_space}/memories/preferences/pref1.md" + + memory = _make_context(target_uri) + + viking_fs = AsyncMock() + viking_fs.rm = AsyncMock() + + ctx = MagicMock() + result = await compressor._delete_existing_memory( + memory, + viking_fs, + ctx=ctx, + user_temp_uri=None, + agent_temp_uri=None, + ) + + assert result is True + viking_fs.rm.assert_called_once_with(target_uri, recursive=False, ctx=ctx) + vikingdb.delete_uris.assert_called_once_with(ctx, [target_uri]) + + async def test_delete_existing_memory_rm_exception(self): + vikingdb = MagicMock() + compressor = SessionCompressor(vikingdb=vikingdb) + + user_space = _make_user().user_space_name() + target_uri = f"viking://user/{user_space}/memories/preferences/pref1.md" + user_temp_uri = "viking://user/temp_user_123" + + memory = _make_context(target_uri) + + viking_fs = AsyncMock() + viking_fs.rm = AsyncMock(side_effect=Exception("Delete error")) + + ctx = MagicMock() + result = await compressor._delete_existing_memory( + memory, + viking_fs, + ctx=ctx, + user_temp_uri=user_temp_uri, + agent_temp_uri=None, + ) + + assert result is False + + async def test_delete_existing_memory_vector_delete_exception(self): + vikingdb = MagicMock() + vikingdb.delete_uris = AsyncMock(side_effect=Exception("Vector delete error")) + compressor = SessionCompressor(vikingdb=vikingdb) + + user_space = _make_user().user_space_name() + target_uri = f"viking://user/{user_space}/memories/preferences/pref1.md" + user_temp_uri = "viking://user/temp_user_123" + + memory = _make_context(target_uri) + + viking_fs = AsyncMock() + viking_fs.rm = AsyncMock() + + ctx = MagicMock() + result = await compressor._delete_existing_memory( + memory, + viking_fs, + ctx=ctx, + user_temp_uri=user_temp_uri, + agent_temp_uri=None, + ) + + assert result is True + viking_fs.rm.assert_called_once() + vikingdb.delete_uris.assert_called_once() + + async def test_delete_existing_memory_agent_uri(self): + vikingdb = MagicMock() + vikingdb.delete_uris = AsyncMock() + compressor = SessionCompressor(vikingdb=vikingdb) + + agent_space = _make_user().agent_space_name() + target_uri = f"viking://agent/{agent_space}/memories/cases/case1.md" + agent_temp_uri = "viking://agent/temp_agent_456" + temp_uri = f"{agent_temp_uri}/memories/cases/case1.md" + + memory = _make_context(target_uri) + + viking_fs = AsyncMock() + viking_fs.rm = AsyncMock() + + ctx = MagicMock() + result = await compressor._delete_existing_memory( + memory, + viking_fs, + ctx=ctx, + user_temp_uri=None, + agent_temp_uri=agent_temp_uri, + ) + + assert result is True + viking_fs.rm.assert_called_once_with(temp_uri, recursive=False, ctx=ctx) + 
vikingdb.delete_uris.assert_called_once_with(ctx, [temp_uri]) diff --git a/tests/unit/session/test_deduplicator_uri.py b/tests/unit/session/test_deduplicator_uri.py new file mode 100644 index 00000000..900375cd --- /dev/null +++ b/tests/unit/session/test_deduplicator_uri.py @@ -0,0 +1,301 @@ +# Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd. +# SPDX-License-Identifier: Apache-2.0 + +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from openviking.core.context import Context +from openviking.session.memory_deduplicator import MemoryDeduplicator +from openviking.session.memory_extractor import CandidateMemory, MemoryCategory +from openviking_cli.session.user_id import UserIdentifier + + +class _DummyEmbedResult: + def __init__(self, dense_vector): + self.dense_vector = dense_vector + + +class _DummyEmbedder: + def embed(self, _text): + return _DummyEmbedResult([0.1, 0.2, 0.3]) + + +def _make_user() -> UserIdentifier: + return UserIdentifier("acc1", "test_user", "test_agent") + + +def _make_candidate(category: MemoryCategory = MemoryCategory.PREFERENCES) -> CandidateMemory: + return CandidateMemory( + category=category, + abstract="User prefers concise summaries", + overview="User asks for concise answers frequently.", + content="The user prefers concise summaries over long explanations.", + source_session="session_test", + user=_make_user(), + language="en", + ) + + +def _make_existing_user_memory(uri_suffix: str = "existing.md") -> dict: + user_space = _make_user().user_space_name() + return { + "id": f"uri_{uri_suffix}", + "uri": f"viking://user/{user_space}/memories/preferences/{uri_suffix}", + "context_type": "memory", + "level": 2, + "account_id": "acc1", + "owner_space": user_space, + "abstract": "Existing preference memory", + "category": "preferences", + "_score": 0.85, + } + + +def _make_existing_agent_memory(uri_suffix: str = "case1.md") -> dict: + user = _make_user() + agent_space = user.agent_space_name() + return { + "id": f"uri_{uri_suffix}", + "uri": f"viking://agent/{agent_space}/memories/cases/{uri_suffix}", + "context_type": "memory", + "level": 2, + "account_id": "acc1", + "owner_space": agent_space, + "abstract": "Existing case memory", + "category": "cases", + "_score": 0.90, + } + + +@pytest.mark.asyncio +class TestFindSimilarMemoriesURIConversion: + async def test_user_uri_converted_to_temp_uri(self): + vikingdb = MagicMock() + vikingdb.get_embedder.return_value = _DummyEmbedder() + vikingdb.search_similar_memories = AsyncMock(return_value=[_make_existing_user_memory("pref1.md")]) + + dedup = MemoryDeduplicator(vikingdb=vikingdb) + candidate = _make_candidate() + + user_temp_uri = "viking://user/temp_user_123" + similar = await dedup._find_similar_memories( + candidate, + user_temp_uri=user_temp_uri, + agent_temp_uri=None, + ) + + assert len(similar) == 1 + user_space = _make_user().user_space_name() + original_uri = f"viking://user/{user_space}/memories/preferences/pref1.md" + expected_uri = f"{user_temp_uri}/memories/preferences/pref1.md" + assert similar[0].uri == expected_uri + assert similar[0].uri != original_uri + + async def test_agent_uri_converted_to_temp_uri(self): + vikingdb = MagicMock() + vikingdb.get_embedder.return_value = _DummyEmbedder() + vikingdb.search_similar_memories = AsyncMock(return_value=[_make_existing_agent_memory("case1.md")]) + + dedup = MemoryDeduplicator(vikingdb=vikingdb) + candidate = _make_candidate(category=MemoryCategory.CASES) + + agent_temp_uri = "viking://agent/temp_agent_456" + similar = await 
dedup._find_similar_memories( + candidate, + user_temp_uri=None, + agent_temp_uri=agent_temp_uri, + ) + + assert len(similar) == 1 + user = _make_user() + agent_space = user.agent_space_name() + original_uri = f"viking://agent/{agent_space}/memories/cases/case1.md" + expected_uri = f"{agent_temp_uri}/memories/cases/case1.md" + assert similar[0].uri == expected_uri + assert similar[0].uri != original_uri + + async def test_no_conversion_when_no_temp_uri(self): + vikingdb = MagicMock() + vikingdb.get_embedder.return_value = _DummyEmbedder() + vikingdb.search_similar_memories = AsyncMock(return_value=[_make_existing_user_memory("pref1.md")]) + + dedup = MemoryDeduplicator(vikingdb=vikingdb) + candidate = _make_candidate() + + similar = await dedup._find_similar_memories( + candidate, + user_temp_uri=None, + agent_temp_uri=None, + ) + + assert len(similar) == 1 + user_space = _make_user().user_space_name() + expected_uri = f"viking://user/{user_space}/memories/preferences/pref1.md" + assert similar[0].uri == expected_uri + + async def test_mixed_uris_only_convert_matching_type(self): + user_space = _make_user().user_space_name() + agent_space = _make_user().agent_space_name() + + vikingdb = MagicMock() + vikingdb.get_embedder.return_value = _DummyEmbedder() + vikingdb.search_similar_memories = AsyncMock( + return_value=[ + _make_existing_user_memory("pref1.md"), + _make_existing_agent_memory("case1.md"), + ] + ) + + dedup = MemoryDeduplicator(vikingdb=vikingdb) + candidate = _make_candidate() + + user_temp_uri = "viking://user/temp_user_123" + similar = await dedup._find_similar_memories( + candidate, + user_temp_uri=user_temp_uri, + agent_temp_uri=None, + ) + + assert len(similar) == 2 + uris = {m.uri for m in similar} + assert f"{user_temp_uri}/memories/preferences/pref1.md" in uris + assert f"viking://agent/{agent_space}/memories/cases/case1.md" in uris + + async def test_uri_conversion_preserves_meta_and_score(self): + vikingdb = MagicMock() + vikingdb.get_embedder.return_value = _DummyEmbedder() + vikingdb.search_similar_memories = AsyncMock(return_value=[_make_existing_user_memory("pref1.md")]) + + dedup = MemoryDeduplicator(vikingdb=vikingdb) + candidate = _make_candidate() + + user_temp_uri = "viking://user/temp_user_123" + similar = await dedup._find_similar_memories( + candidate, + user_temp_uri=user_temp_uri, + agent_temp_uri=None, + ) + + assert len(similar) == 1 + assert similar[0].meta is not None + assert similar[0].meta.get("_dedup_score") == 0.85 + + +class TestExtractFacetKey: + def test_extract_with_chinese_colon(self): + result = MemoryDeduplicator._extract_facet_key("饮食偏好:喜欢吃苹果和草莓") + assert result == "饮食偏好" + + def test_extract_with_english_colon(self): + result = MemoryDeduplicator._extract_facet_key("User preference: dark mode enabled") + assert result == "user preference" + + def test_extract_with_hyphen(self): + result = MemoryDeduplicator._extract_facet_key("Coding style - prefer type hints") + assert result == "coding style" + + def test_extract_with_em_dash(self): + result = MemoryDeduplicator._extract_facet_key("Work schedule — remote on Fridays") + assert result == "work schedule" + + def test_extract_with_no_separator_returns_prefix(self): + result = MemoryDeduplicator._extract_facet_key("This is a long abstract without any separator") + assert len(result) <= 24 + assert result == "this is a long abstract" + + def test_extract_with_empty_string(self): + result = MemoryDeduplicator._extract_facet_key("") + assert result == "" + + def test_extract_with_none(self): 
+ result = MemoryDeduplicator._extract_facet_key(None) + assert result == "" + + def test_extract_normalizes_whitespace(self): + result = MemoryDeduplicator._extract_facet_key(" Multiple spaces : value ") + assert result == "multiple spaces" + + def test_extract_with_short_text_no_separator(self): + result = MemoryDeduplicator._extract_facet_key("Short") + assert result == "short" + + def test_extract_returns_lowercase(self): + result = MemoryDeduplicator._extract_facet_key("FOOD PREFERENCE: pizza") + assert result == "food preference" + + def test_extract_with_separator_at_start(self): + result = MemoryDeduplicator._extract_facet_key(": starts with separator") + assert result == ": starts with" + + def test_extract_with_multiple_separators_uses_first(self): + result = MemoryDeduplicator._extract_facet_key("Topic: Subtopic - Detail") + assert result == "topic" + + +class TestCosineSimilarity: + def test_identical_vectors(self): + vec = [1.0, 2.0, 3.0] + result = MemoryDeduplicator._cosine_similarity(vec, vec) + assert abs(result - 1.0) < 1e-9 + + def test_orthogonal_vectors(self): + vec_a = [1.0, 0.0] + vec_b = [0.0, 1.0] + result = MemoryDeduplicator._cosine_similarity(vec_a, vec_b) + assert abs(result) < 1e-9 + + def test_opposite_vectors(self): + vec_a = [1.0, 2.0, 3.0] + vec_b = [-1.0, -2.0, -3.0] + result = MemoryDeduplicator._cosine_similarity(vec_a, vec_b) + assert abs(result + 1.0) < 1e-9 + + def test_different_length_vectors(self): + vec_a = [1.0, 2.0, 3.0] + vec_b = [1.0, 2.0] + result = MemoryDeduplicator._cosine_similarity(vec_a, vec_b) + assert result == 0.0 + + def test_zero_vector_a(self): + vec_a = [0.0, 0.0, 0.0] + vec_b = [1.0, 2.0, 3.0] + result = MemoryDeduplicator._cosine_similarity(vec_a, vec_b) + assert result == 0.0 + + def test_zero_vector_b(self): + vec_a = [1.0, 2.0, 3.0] + vec_b = [0.0, 0.0, 0.0] + result = MemoryDeduplicator._cosine_similarity(vec_a, vec_b) + assert result == 0.0 + + def test_both_zero_vectors(self): + vec_a = [0.0, 0.0, 0.0] + vec_b = [0.0, 0.0, 0.0] + result = MemoryDeduplicator._cosine_similarity(vec_a, vec_b) + assert result == 0.0 + + def test_partial_similarity(self): + vec_a = [1.0, 0.0, 0.0] + vec_b = [1.0, 1.0, 0.0] + result = MemoryDeduplicator._cosine_similarity(vec_a, vec_b) + expected = 1.0 / (2.0 ** 0.5) + assert abs(result - expected) < 1e-9 + + def test_negative_values(self): + vec_a = [1.0, -2.0, 3.0] + vec_b = [-1.0, 2.0, 3.0] + result = MemoryDeduplicator._cosine_similarity(vec_a, vec_b) + assert 0 < result < 1 + + def test_single_element_vectors(self): + vec_a = [5.0] + vec_b = [3.0] + result = MemoryDeduplicator._cosine_similarity(vec_a, vec_b) + assert abs(result - 1.0) < 1e-9 + + def test_large_vectors(self): + vec_a = [float(i) for i in range(100)] + vec_b = [float(i * 2) for i in range(100)] + result = MemoryDeduplicator._cosine_similarity(vec_a, vec_b) + assert abs(result - 1.0) < 1e-6 diff --git a/tests/unit/session/test_memory_extractor_tools.py b/tests/unit/session/test_memory_extractor_tools.py new file mode 100644 index 00000000..cd8b6189 --- /dev/null +++ b/tests/unit/session/test_memory_extractor_tools.py @@ -0,0 +1,789 @@ +# Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd. 
+# SPDX-License-Identifier: Apache-2.0 + +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from openviking.session.memory_extractor import ( + FIELD_MAX_LENGTH, + FIELD_MAX_LENGTHS, + MemoryExtractor, + MemoryCategory, + ToolSkillCandidateMemory, +) + + +@pytest.fixture +def extractor(): + return MemoryExtractor() + + +class TestParseToolStatistics: + def test_parse_chinese_format_full(self, extractor): + content = """ +Tool: test_tool + +总调用次数: 100 +成功率: 85.0%(85 成功,15 失败) +平均耗时: 150.5ms +平均Token: 500 +""" + stats = extractor._parse_tool_statistics(content) + assert stats["total_calls"] == 100 + assert stats["success_count"] == 85 + assert stats["fail_count"] == 15 + assert stats["total_time_ms"] == 15050.0 + assert stats["total_tokens"] == 50000 + + def test_parse_chinese_format_with_colon(self, extractor): + content = """ +总调用次数:200 +成功率:90.5%(181 成功,19 失败) +平均耗时:200.0ms +平均Token:800 +""" + stats = extractor._parse_tool_statistics(content) + assert stats["total_calls"] == 200 + assert stats["success_count"] == 181 + assert stats["fail_count"] == 19 + + def test_parse_english_format_full(self, extractor): + content = """ +Tool: test_tool + +Based on 50 historical calls: +- Success rate: 80.0% (40 successful, 10 failed) +- Avg time: 1.5s, Avg tokens: 600 +""" + stats = extractor._parse_tool_statistics(content) + assert stats["total_calls"] == 50 + assert stats["success_count"] == 40 + assert stats["fail_count"] == 10 + assert stats["total_time_ms"] == 75000.0 + assert stats["total_tokens"] == 30000 + + def test_parse_english_format_ms(self, extractor): + content = """ +Based on 30 historical calls: +- Success rate: 90.0% (27 successful, 3 failed) +- Avg time: 250.5ms, Avg tokens: 400 +""" + stats = extractor._parse_tool_statistics(content) + assert stats["total_calls"] == 30 + assert stats["success_count"] == 27 + assert stats["fail_count"] == 3 + assert stats["total_time_ms"] == 7515.0 + assert stats["total_tokens"] == 12000 + + def test_parse_chinese_success_rate_only(self, extractor): + content = """ +总调用次数: 100 +成功率: 75.0% +""" + stats = extractor._parse_tool_statistics(content) + assert stats["total_calls"] == 100 + assert stats["success_count"] == 75 + assert stats["fail_count"] == 25 + + def test_parse_english_success_rate_only(self, extractor): + content = """ +Based on 80 historical calls: +- Success rate: 87.5% +""" + stats = extractor._parse_tool_statistics(content) + assert stats["total_calls"] == 80 + assert stats["success_count"] == 70 + assert stats["fail_count"] == 10 + + def test_parse_empty_content(self, extractor): + stats = extractor._parse_tool_statistics("") + assert stats["total_calls"] == 0 + assert stats["success_count"] == 0 + assert stats["fail_count"] == 0 + assert stats["total_time_ms"] == 0 + assert stats["total_tokens"] == 0 + + def test_parse_chinese_avg_time_seconds(self, extractor): + content = """ +总调用次数: 10 +平均耗时: 2.5s +""" + stats = extractor._parse_tool_statistics(content) + assert stats["total_calls"] == 10 + assert stats["total_time_ms"] == 25000.0 + + def test_parse_no_total_calls_infers_from_success_fail(self, extractor): + content = """ +成功率: 80.0%(40 成功,10 失败) +""" + stats = extractor._parse_tool_statistics(content) + assert stats["total_calls"] == 50 + assert stats["success_count"] == 40 + assert stats["fail_count"] == 10 + + +class TestMergeToolStatistics: + def test_merge_basic(self, extractor): + existing = { + "total_calls": 100, + "success_count": 80, + "fail_count": 20, + "total_time_ms": 10000.0, + 
"total_tokens": 50000, + } + new = { + "total_calls": 50, + "success_count": 45, + "fail_count": 5, + "total_time_ms": 5000.0, + "total_tokens": 25000, + } + merged = extractor._merge_tool_statistics(existing, new) + assert merged["total_calls"] == 150 + assert merged["success_count"] == 125 + assert merged["fail_count"] == 25 + assert merged["total_time_ms"] == 15000.0 + assert merged["total_tokens"] == 75000 + assert abs(merged["avg_time_ms"] - 100.0) < 0.01 + assert abs(merged["avg_tokens"] - 500.0) < 0.01 + assert abs(merged["success_rate"] - 0.8333) < 0.01 + + def test_merge_with_zero_existing(self, extractor): + existing = { + "total_calls": 0, + "success_count": 0, + "fail_count": 0, + "total_time_ms": 0, + "total_tokens": 0, + } + new = { + "total_calls": 10, + "success_count": 8, + "fail_count": 2, + "total_time_ms": 1000.0, + "total_tokens": 5000, + } + merged = extractor._merge_tool_statistics(existing, new) + assert merged["total_calls"] == 10 + assert merged["success_count"] == 8 + assert merged["fail_count"] == 2 + + def test_merge_with_zero_new(self, extractor): + existing = { + "total_calls": 20, + "success_count": 15, + "fail_count": 5, + "total_time_ms": 2000.0, + "total_tokens": 10000, + } + new = { + "total_calls": 0, + "success_count": 0, + "fail_count": 0, + "total_time_ms": 0, + "total_tokens": 0, + } + merged = extractor._merge_tool_statistics(existing, new) + assert merged["total_calls"] == 20 + assert merged["success_count"] == 15 + assert merged["fail_count"] == 5 + + def test_merge_both_zero(self, extractor): + existing = { + "total_calls": 0, + "success_count": 0, + "fail_count": 0, + "total_time_ms": 0, + "total_tokens": 0, + } + new = { + "total_calls": 0, + "success_count": 0, + "fail_count": 0, + "total_time_ms": 0, + "total_tokens": 0, + } + merged = extractor._merge_tool_statistics(existing, new) + assert merged["total_calls"] == 0 + assert merged["avg_time_ms"] == 0 + assert merged["avg_tokens"] == 0 + assert merged["success_rate"] == 0 + + +class TestGenerateToolMemoryContent: + def test_generate_basic(self, extractor): + with patch.object(extractor, "_get_tool_static_description", return_value="A test tool"): + stats = { + "total_calls": 100, + "success_count": 85, + "fail_count": 15, + "avg_time_ms": 150.5, + "avg_tokens": 500, + "success_rate": 0.85, + } + guidelines = "Use this tool for testing purposes." + content = extractor._generate_tool_memory_content("test_tool", stats, guidelines) + assert "Tool: test_tool" in content + assert "Based on 100 historical calls:" in content + assert "Success rate: 85.0%" in content + assert "85 successful, 15 failed" in content + assert "Use this tool for testing purposes." 
in content + + def test_generate_with_fields(self, extractor): + with patch.object(extractor, "_get_tool_static_description", return_value="A test tool"): + stats = { + "total_calls": 50, + "success_count": 40, + "fail_count": 10, + "avg_time_ms": 200.0, + "avg_tokens": 600, + "success_rate": 0.8, + } + fields = { + "best_for": "Data processing tasks", + "optimal_params": "batch_size=100", + "common_failures": "Timeout on large inputs", + "recommendation": "Use with small batches", + } + content = extractor._generate_tool_memory_content("test_tool", stats, "", fields=fields) + assert "Best for: Data processing tasks" in content + assert "Optimal params: batch_size=100" in content + assert "Common failures: Timeout on large inputs" in content + assert "Recommendation: Use with small batches" in content + + def test_generate_with_empty_fields(self, extractor): + with patch.object(extractor, "_get_tool_static_description", return_value="A test tool"): + stats = { + "total_calls": 10, + "success_count": 10, + "fail_count": 0, + "avg_time_ms": 100.0, + "avg_tokens": 300, + "success_rate": 1.0, + } + content = extractor._generate_tool_memory_content("test_tool", stats, "", fields={}) + assert "Best for: " in content + assert "Optimal params: " in content + + def test_generate_extracts_fields_from_guidelines(self, extractor): + with patch.object(extractor, "_get_tool_static_description", return_value="A test tool"): + stats = { + "total_calls": 20, + "success_count": 18, + "fail_count": 2, + "avg_time_ms": 50.0, + "avg_tokens": 200, + "success_rate": 0.9, + } + guidelines = """ +Best for: Quick data validation +Optimal params: strict_mode=true +Common failures: Invalid input format +Recommendation: Always validate input first +""" + content = extractor._generate_tool_memory_content("test_tool", stats, guidelines) + assert "Best for: Quick data validation" in content + assert "Optimal params: strict_mode=true" in content + + +class TestParseSkillStatistics: + def test_parse_chinese_format_full(self, extractor): + content = """ +Skill: test_skill + +总执行次数: 100 +成功率: 90.0%(90 成功,10 失败) +""" + stats = extractor._parse_skill_statistics(content) + assert stats["total_executions"] == 100 + assert stats["success_count"] == 90 + assert stats["fail_count"] == 10 + + def test_parse_chinese_format_with_colon(self, extractor): + content = """ +总执行次数:50 +成功率:80.0%(40 成功,10 失败) +""" + stats = extractor._parse_skill_statistics(content) + assert stats["total_executions"] == 50 + assert stats["success_count"] == 40 + assert stats["fail_count"] == 10 + + def test_parse_english_format_full(self, extractor): + content = """ +Skill: test_skill + +Based on 75 historical executions: +- Success rate: 85.0% (64 successful, 11 failed) +""" + stats = extractor._parse_skill_statistics(content) + assert stats["total_executions"] == 75 + assert stats["success_count"] == 64 + assert stats["fail_count"] == 11 + + def test_parse_english_success_rate_only(self, extractor): + content = """ +Based on 60 historical executions: +- Success rate: 75.0% +""" + stats = extractor._parse_skill_statistics(content) + assert stats["total_executions"] == 60 + assert stats["success_count"] == 45 + assert stats["fail_count"] == 15 + + def test_parse_empty_content(self, extractor): + stats = extractor._parse_skill_statistics("") + assert stats["total_executions"] == 0 + assert stats["success_count"] == 0 + assert stats["fail_count"] == 0 + + def test_parse_no_total_executions_infers_from_success_fail(self, extractor): + content = """ +成功率: 
70.0%(35 成功,15 失败) +""" + stats = extractor._parse_skill_statistics(content) + assert stats["total_executions"] == 50 + assert stats["success_count"] == 35 + assert stats["fail_count"] == 15 + + +class TestMergeSkillStatistics: + def test_merge_basic(self, extractor): + existing = { + "total_executions": 100, + "success_count": 90, + "fail_count": 10, + } + new = { + "total_executions": 50, + "success_count": 45, + "fail_count": 5, + } + merged = extractor._merge_skill_statistics(existing, new) + assert merged["total_executions"] == 150 + assert merged["success_count"] == 135 + assert merged["fail_count"] == 15 + assert abs(merged["success_rate"] - 0.9) < 0.01 + + def test_merge_with_zero_existing(self, extractor): + existing = { + "total_executions": 0, + "success_count": 0, + "fail_count": 0, + } + new = { + "total_executions": 20, + "success_count": 18, + "fail_count": 2, + } + merged = extractor._merge_skill_statistics(existing, new) + assert merged["total_executions"] == 20 + assert merged["success_count"] == 18 + assert merged["fail_count"] == 2 + + def test_merge_with_zero_new(self, extractor): + existing = { + "total_executions": 30, + "success_count": 25, + "fail_count": 5, + } + new = { + "total_executions": 0, + "success_count": 0, + "fail_count": 0, + } + merged = extractor._merge_skill_statistics(existing, new) + assert merged["total_executions"] == 30 + assert merged["success_count"] == 25 + assert merged["fail_count"] == 5 + + def test_merge_both_zero(self, extractor): + existing = { + "total_executions": 0, + "success_count": 0, + "fail_count": 0, + } + new = { + "total_executions": 0, + "success_count": 0, + "fail_count": 0, + } + merged = extractor._merge_skill_statistics(existing, new) + assert merged["total_executions"] == 0 + assert merged["success_rate"] == 0 + + +class TestGenerateSkillMemoryContent: + def test_generate_basic(self, extractor): + stats = { + "total_executions": 100, + "success_count": 90, + "fail_count": 10, + "success_rate": 0.9, + } + guidelines = "Use this skill for data processing." + content = extractor._generate_skill_memory_content("test_skill", stats, guidelines) + assert "Skill: test_skill" in content + assert "Based on 100 historical executions:" in content + assert "Success rate: 90.0%" in content + assert "90 successful, 10 failed" in content + assert "Use this skill for data processing." 
in content + + def test_generate_with_fields(self, extractor): + stats = { + "total_executions": 50, + "success_count": 45, + "fail_count": 5, + "success_rate": 0.9, + } + fields = { + "best_for": "Automated workflows", + "recommended_flow": "Step 1 -> Step 2 -> Step 3", + "key_dependencies": "Database connection", + "common_failures": "Network timeout", + "recommendation": "Use with retry logic", + } + content = extractor._generate_skill_memory_content("test_skill", stats, "", fields=fields) + assert "Best for: Automated workflows" in content + assert "Recommended flow: Step 1 -> Step 2 -> Step 3" in content + assert "Key dependencies: Database connection" in content + assert "Common failures: Network timeout" in content + assert "Recommendation: Use with retry logic" in content + + def test_generate_with_empty_fields(self, extractor): + stats = { + "total_executions": 10, + "success_count": 10, + "fail_count": 0, + "success_rate": 1.0, + } + content = extractor._generate_skill_memory_content("test_skill", stats, "", fields={}) + assert "Best for: " in content + assert "Recommended flow: " in content + + def test_generate_extracts_fields_from_guidelines(self, extractor): + stats = { + "total_executions": 20, + "success_count": 18, + "fail_count": 2, + "success_rate": 0.9, + } + guidelines = """ +Best for: Complex data transformations +Recommended flow: Validate -> Transform -> Store +Key dependencies: S3 bucket access +Common failures: Permission denied +Recommendation: Check permissions first +""" + content = extractor._generate_skill_memory_content("test_skill", stats, guidelines) + assert "Best for: Complex data transformations" in content + assert "Recommended flow: Validate -> Transform -> Store" in content + + +class TestMergeKvField: + @pytest.mark.asyncio + async def test_merge_both_empty(self, extractor): + result = await extractor._merge_kv_field("", "", "best_for") + assert result == "" + + @pytest.mark.asyncio + async def test_merge_existing_empty(self, extractor): + result = await extractor._merge_kv_field("", "new value", "best_for") + assert result == "new value" + + @pytest.mark.asyncio + async def test_merge_new_empty(self, extractor): + result = await extractor._merge_kv_field("existing value", "", "best_for") + assert result == "existing value" + + @pytest.mark.asyncio + async def test_merge_identical_values(self, extractor): + result = await extractor._merge_kv_field("same value", "same value", "best_for") + assert result == "same value" + + @pytest.mark.asyncio + async def test_merge_different_values(self, extractor): + result = await extractor._merge_kv_field("value A", "value B", "best_for") + assert "value A" in result + assert "value B" in result + assert ";" in result + + @pytest.mark.asyncio + async def test_merge_with_semicolon_separator(self, extractor): + result = await extractor._merge_kv_field("item1; item2", "item3", "best_for") + assert "item1" in result + assert "item2" in result + assert "item3" in result + + @pytest.mark.asyncio + async def test_merge_deduplicates(self, extractor): + result = await extractor._merge_kv_field("item1; item2", "item2; item3", "best_for") + assert result.count("item2") == 1 + assert "item1" in result + assert "item3" in result + + @pytest.mark.asyncio + async def test_merge_respects_max_length(self, extractor): + long_value = "x" * 600 + result = await extractor._merge_kv_field(long_value, "new", "best_for") + assert len(result) <= FIELD_MAX_LENGTHS["best_for"] + + @pytest.mark.asyncio + async def 
test_merge_with_newline_separator(self, extractor): + result = await extractor._merge_kv_field("item1\nitem2", "item3", "best_for") + assert "item1" in result + assert "item2" in result + assert "item3" in result + + +class TestSmartTruncate: + def test_no_truncation_needed(self, extractor): + text = "short text" + result = extractor._smart_truncate(text, 100) + assert result == text + + def test_truncate_at_semicolon(self, extractor): + text = "item1; item2; item3; item4; item5" + result = extractor._smart_truncate(text, 25) + assert len(result) <= 25 + assert result.endswith(";") or result.count(";") >= 1 + + def test_truncate_at_space(self, extractor): + text = "word1 word2 word3 word4 word5" + result = extractor._smart_truncate(text, 20) + assert len(result) <= 20 + + def test_truncate_fallback(self, extractor): + text = "abcdefghijklmnopqrstuvwxyz" + result = extractor._smart_truncate(text, 10) + assert len(result) == 10 + assert result == "abcdefghij" + + def test_truncate_empty_string(self, extractor): + result = extractor._smart_truncate("", 10) + assert result == "" + + def test_truncate_exact_length(self, extractor): + text = "exactly10!" + result = extractor._smart_truncate(text, 10) + assert result == text + + +class TestComputeStatisticsDerived: + def test_compute_with_calls(self, extractor): + stats = { + "total_calls": 100, + "success_count": 80, + "fail_count": 20, + "total_time_ms": 10000.0, + "total_tokens": 50000, + } + result = extractor._compute_statistics_derived(stats) + assert abs(result["avg_time_ms"] - 100.0) < 0.01 + assert abs(result["avg_tokens"] - 500.0) < 0.01 + assert abs(result["success_rate"] - 0.8) < 0.01 + + def test_compute_with_zero_calls(self, extractor): + stats = { + "total_calls": 0, + "success_count": 0, + "fail_count": 0, + "total_time_ms": 0, + "total_tokens": 0, + } + result = extractor._compute_statistics_derived(stats) + assert result["avg_time_ms"] == 0 + assert result["avg_tokens"] == 0 + assert result["success_rate"] == 0 + + def test_compute_preserves_original_values(self, extractor): + stats = { + "total_calls": 50, + "success_count": 40, + "fail_count": 10, + "total_time_ms": 5000.0, + "total_tokens": 25000, + } + result = extractor._compute_statistics_derived(stats) + assert result["total_calls"] == 50 + assert result["success_count"] == 40 + assert result["fail_count"] == 10 + assert result["total_time_ms"] == 5000.0 + assert result["total_tokens"] == 25000 + + +class TestFormatDuration: + def test_format_zero(self, extractor): + result = extractor._format_duration(0) + assert result == "0s" + + def test_format_milliseconds(self, extractor): + result = extractor._format_duration(500) + assert result == "500ms" + + def test_format_seconds(self, extractor): + result = extractor._format_duration(1500) + assert result == "1.5s" + + def test_format_large_seconds(self, extractor): + result = extractor._format_duration(10000) + assert result == "10.0s" + + def test_format_none(self, extractor): + result = extractor._format_duration(None) + assert result == "N/A" + + def test_format_negative(self, extractor): + result = extractor._format_duration(-100) + assert result == "0s" + + def test_format_invalid_type(self, extractor): + result = extractor._format_duration("invalid") + assert result == "N/A" + + def test_format_exactly_one_second(self, extractor): + result = extractor._format_duration(1000) + assert result == "1.0s" + + def test_format_just_under_one_second(self, extractor): + result = extractor._format_duration(999) + assert result == 
"999ms" + + +class TestExtractContentField: + def test_extract_with_chinese_colon(self, extractor): + content = "Best for:数据处理任务" + result = extractor._extract_content_field(content, ["Best for"]) + assert result == "数据处理任务" + + def test_extract_with_english_colon(self, extractor): + content = "Best for: data processing tasks" + result = extractor._extract_content_field(content, ["Best for"]) + assert result == "data processing tasks" + + def test_extract_with_multiple_keys(self, extractor): + content = "最佳场景: 快速验证" + result = extractor._extract_content_field(content, ["Best for", "最佳场景"]) + assert result == "快速验证" + + def test_extract_not_found(self, extractor): + content = "Some other content" + result = extractor._extract_content_field(content, ["Best for"]) + assert result == "" + + def test_extract_empty_content(self, extractor): + result = extractor._extract_content_field("", ["Best for"]) + assert result == "" + + +class TestCompactBlock: + def test_compact_basic(self, extractor): + text = "Line 1\nLine 2\nLine 3" + result = extractor._compact_block(text) + assert result == "Line 1; Line 2; Line 3" + + def test_compact_with_prefixes(self, extractor): + text = "> Point 1\n- Point 2\n* Point 3" + result = extractor._compact_block(text) + assert "Point 1" in result + assert "Point 2" in result + assert "Point 3" in result + + def test_compact_empty(self, extractor): + result = extractor._compact_block("") + assert result == "" + + def test_compact_whitespace_only(self, extractor): + result = extractor._compact_block(" \n \n ") + assert result == "" + + +class TestExtractToolMemoryContextFieldsFromText: + def test_extract_all_fields(self, extractor): + text = """ +Best for: Data processing +Optimal params: batch_size=100 +Common failures: Timeout +Recommendation: Use small batches +""" + result = extractor._extract_tool_memory_context_fields_from_text(text) + assert result["best_for"] == "Data processing" + assert result["optimal_params"] == "batch_size=100" + assert result["common_failures"] == "Timeout" + assert result["recommendation"] == "Use small batches" + + def test_extract_partial_fields(self, extractor): + text = """ +Best for: Testing +Recommendation: Run in dev mode +""" + result = extractor._extract_tool_memory_context_fields_from_text(text) + assert result["best_for"] == "Testing" + assert result["optimal_params"] == "" + assert result["common_failures"] == "" + assert result["recommendation"] == "Run in dev mode" + + def test_extract_chinese_fields(self, extractor): + text = """ +最佳场景: 数据处理 +最优参数: 批量大小=100 +常见失败: 超时 +推荐: 使用小批量 +""" + result = extractor._extract_tool_memory_context_fields_from_text(text) + assert result["best_for"] == "数据处理" + assert result["optimal_params"] == "批量大小=100" + assert result["common_failures"] == "超时" + assert result["recommendation"] == "使用小批量" + + +class TestExtractSkillMemoryContextFieldsFromText: + def test_extract_all_fields(self, extractor): + text = """ +Best for: Automated workflows +Recommended flow: Step1 -> Step2 -> Step3 +Key dependencies: Database +Common failures: Connection error +Recommendation: Use connection pool +""" + result = extractor._extract_skill_memory_context_fields_from_text(text) + assert result["best_for"] == "Automated workflows" + assert result["recommended_flow"] == "Step1 -> Step2 -> Step3" + assert result["key_dependencies"] == "Database" + assert result["common_failures"] == "Connection error" + assert result["recommendation"] == "Use connection pool" + + def test_extract_chinese_fields(self, extractor): + 
text = """ +最佳场景: 自动化工作流 +推荐流程: 步骤1 -> 步骤2 -> 步骤3 +关键依赖: 数据库 +常见失败: 连接错误 +推荐: 使用连接池 +""" + result = extractor._extract_skill_memory_context_fields_from_text(text) + assert result["best_for"] == "自动化工作流" + assert result["recommended_flow"] == "步骤1 -> 步骤2 -> 步骤3" + assert result["key_dependencies"] == "数据库" + assert result["common_failures"] == "连接错误" + assert result["recommendation"] == "使用连接池" + + +class TestFormatMs: + def test_format_zero(self, extractor): + result = extractor._format_ms(0) + assert result == "0.000ms" + + def test_format_normal_value(self, extractor): + result = extractor._format_ms(123.456) + assert result == "123.456ms" + + def test_format_very_small_value(self, extractor): + result = extractor._format_ms(0.000123) + assert "ms" in result + assert float(result.replace("ms", "")) > 0 + + def test_format_large_value(self, extractor): + result = extractor._format_ms(9999.999) + assert result == "9999.999ms" diff --git a/tests/unit/session/test_session_cow.py b/tests/unit/session/test_session_cow.py new file mode 100644 index 00000000..26ed353d --- /dev/null +++ b/tests/unit/session/test_session_cow.py @@ -0,0 +1,495 @@ +# Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd. +# SPDX-License-Identifier: Apache-2.0 +"""Unit tests for Session COW (Copy-on-Write) mode and async commit functionality.""" + +import time +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from openviking.server.identity import RequestContext, Role +from openviking.session.session import Session +from openviking_cli.session.user_id import UserIdentifier + + +def _make_user() -> UserIdentifier: + return UserIdentifier("test_account", "test_user", "test_agent") + + +def _make_session(viking_fs: MagicMock = None, session_id: str = "test_session_123") -> Session: + user = _make_user() + ctx = RequestContext(user=user, role=Role.ROOT) + fs = viking_fs or MagicMock() + return Session( + viking_fs=fs, + user=user, + ctx=ctx, + session_id=session_id, + ) + + +class TestCreateTempUris: + """Tests for _create_temp_uris() method.""" + + def test_returns_tuple_of_four_uris(self): + session = _make_session() + result = session._create_temp_uris() + + assert isinstance(result, tuple) + assert len(result) == 4 + + def test_temp_base_uri_format(self): + session = _make_session(session_id="sess_abc") + user = session.user + temp_base, _, _, _ = session._create_temp_uris() + + assert temp_base.startswith("viking://temp/session/") + assert f"/{user.user_space_name()}/" in temp_base + assert "/sess_abc/" in temp_base + assert "/commit_" in temp_base + + def test_session_temp_uri_structure(self): + session = _make_session(session_id="sess_abc") + user = session.user + temp_base, session_temp, _, _ = session._create_temp_uris() + + assert session_temp.startswith(temp_base) + assert "/session/" in session_temp + assert f"/{user.user_space_name()}/sess_abc" in session_temp + + def test_user_temp_uri_structure(self): + session = _make_session() + user = session.user + temp_base, _, user_temp, _ = session._create_temp_uris() + + assert user_temp.startswith(temp_base) + assert "/user/" in user_temp + assert user_temp.endswith(f"/user/{user.user_space_name()}") + + def test_agent_temp_uri_structure(self): + session = _make_session() + user = session.user + temp_base, _, _, agent_temp = session._create_temp_uris() + + assert agent_temp.startswith(temp_base) + assert "/agent/" in agent_temp + assert agent_temp.endswith(f"/agent/{user.agent_space_name()}") + + def test_sets_internal_state(self): 
+ session = _make_session() + temp_base, session_temp, user_temp, agent_temp = session._create_temp_uris() + + assert session._temp_base_uri == temp_base + assert session._session_temp_uri == session_temp + assert session._user_temp_uri == user_temp + assert session._agent_temp_uri == agent_temp + assert session._temp_created_at is not None + + def test_temp_created_at_is_recent(self): + session = _make_session() + before = time.time() + session._create_temp_uris() + after = time.time() + + assert before <= session._temp_created_at <= after + + def test_commit_uuid_is_8_chars(self): + session = _make_session() + temp_base, _, _, _ = session._create_temp_uris() + + commit_part = temp_base.split("/commit_")[-1] + assert len(commit_part) == 8 + assert all(c in "0123456789abcdef" for c in commit_part) + + def test_multiple_calls_generate_different_uuids(self): + session = _make_session() + temp_base1, _, _, _ = session._create_temp_uris() + temp_base2, _, _, _ = session._create_temp_uris() + + assert temp_base1 != temp_base2 + + +class TestCleanupTempUris: + """Tests for _cleanup_temp_uris() method.""" + + @pytest.mark.asyncio + async def test_calls_delete_temp_on_viking_fs(self): + viking_fs = MagicMock() + viking_fs.delete_temp = AsyncMock() + session = _make_session(viking_fs=viking_fs) + + session._create_temp_uris() + saved_temp_base = session._temp_base_uri + await session._cleanup_temp_uris() + + viking_fs.delete_temp.assert_called_once() + call_args = viking_fs.delete_temp.call_args + assert call_args[0][0] == saved_temp_base + + @pytest.mark.asyncio + async def test_resets_internal_state(self): + viking_fs = MagicMock() + viking_fs.delete_temp = AsyncMock() + session = _make_session(viking_fs=viking_fs) + + session._create_temp_uris() + await session._cleanup_temp_uris() + + assert session._temp_base_uri is None + assert session._session_temp_uri is None + assert session._user_temp_uri is None + assert session._agent_temp_uri is None + assert session._temp_created_at is None + + @pytest.mark.asyncio + async def test_no_cleanup_when_no_temp_uri(self): + viking_fs = MagicMock() + viking_fs.delete_temp = AsyncMock() + session = _make_session(viking_fs=viking_fs) + + await session._cleanup_temp_uris() + + viking_fs.delete_temp.assert_not_called() + + @pytest.mark.asyncio + async def test_handles_delete_exception(self): + viking_fs = MagicMock() + viking_fs.delete_temp = AsyncMock(side_effect=Exception("Delete failed")) + session = _make_session(viking_fs=viking_fs) + + session._create_temp_uris() + await session._cleanup_temp_uris() + + assert session._temp_base_uri is None + + @pytest.mark.asyncio + async def test_passes_ctx_to_delete_temp(self): + viking_fs = MagicMock() + viking_fs.delete_temp = AsyncMock() + session = _make_session(viking_fs=viking_fs) + + session._create_temp_uris() + await session._cleanup_temp_uris() + + call_kwargs = viking_fs.delete_temp.call_args[1] + assert "ctx" in call_kwargs + assert call_kwargs["ctx"] == session.ctx + + +class TestEnqueueToSemanticQueue: + """Tests for _enqueue_to_semantic_queue() method.""" + + @pytest.mark.asyncio + async def test_returns_list_of_three_msg_ids(self): + session = _make_session() + + mock_queue = MagicMock() + mock_queue.enqueue = AsyncMock() + + mock_queue_manager = MagicMock() + mock_queue_manager.SEMANTIC = "semantic" + mock_queue_manager.get_queue = MagicMock(return_value=mock_queue) + + with patch( + "openviking.storage.queuefs.get_queue_manager", + return_value=mock_queue_manager, + ): + result = await 
session._enqueue_to_semantic_queue( + session_temp_uri="viking://temp/session/test", + user_temp_uri="viking://temp/user/test", + agent_temp_uri="viking://temp/agent/test", + ) + + assert isinstance(result, list) + assert len(result) == 3 + + @pytest.mark.asyncio + async def test_session_msg_has_correct_target_uri(self): + session = _make_session(session_id="sess_xyz") + user = session.user + + enqueued_msgs = [] + + async def capture_enqueue(msg): + enqueued_msgs.append(msg) + + mock_queue = MagicMock() + mock_queue.enqueue = capture_enqueue + + mock_queue_manager = MagicMock() + mock_queue_manager.SEMANTIC = "semantic" + mock_queue_manager.get_queue = MagicMock(return_value=mock_queue) + + session_temp = f"viking://temp/session/{user.user_space_name()}/sess_xyz/commit_abc123/session/{user.user_space_name()}/sess_xyz" + + with patch( + "openviking.storage.queuefs.get_queue_manager", + return_value=mock_queue_manager, + ): + await session._enqueue_to_semantic_queue( + session_temp_uri=session_temp, + user_temp_uri="viking://temp/user/test", + agent_temp_uri="viking://temp/agent/test", + ) + + session_msg = enqueued_msgs[0] + expected_target = f"viking://session/{user.user_space_name()}/sess_xyz" + assert session_msg.target_uri == expected_target + assert session_msg.uri == session_temp + + @pytest.mark.asyncio + async def test_user_msg_has_correct_target_uri(self): + session = _make_session() + user = session.user + + enqueued_msgs = [] + + async def capture_enqueue(msg): + enqueued_msgs.append(msg) + + mock_queue = MagicMock() + mock_queue.enqueue = capture_enqueue + + mock_queue_manager = MagicMock() + mock_queue_manager.SEMANTIC = "semantic" + mock_queue_manager.get_queue = MagicMock(return_value=mock_queue) + + user_temp = f"viking://temp/session/{user.user_space_name()}/sess/commit_abc/user/{user.user_space_name()}" + + with patch( + "openviking.storage.queuefs.get_queue_manager", + return_value=mock_queue_manager, + ): + await session._enqueue_to_semantic_queue( + session_temp_uri="viking://temp/session/test", + user_temp_uri=user_temp, + agent_temp_uri="viking://temp/agent/test", + ) + + user_msg = enqueued_msgs[1] + expected_target = f"viking://user/{user.user_space_name()}" + assert user_msg.target_uri == expected_target + assert user_msg.uri == user_temp + + @pytest.mark.asyncio + async def test_agent_msg_has_correct_target_uri(self): + session = _make_session() + user = session.user + + enqueued_msgs = [] + + async def capture_enqueue(msg): + enqueued_msgs.append(msg) + + mock_queue = MagicMock() + mock_queue.enqueue = capture_enqueue + + mock_queue_manager = MagicMock() + mock_queue_manager.SEMANTIC = "semantic" + mock_queue_manager.get_queue = MagicMock(return_value=mock_queue) + + agent_temp = f"viking://temp/session/{user.user_space_name()}/sess/commit_abc/agent/{user.agent_space_name()}" + + with patch( + "openviking.storage.queuefs.get_queue_manager", + return_value=mock_queue_manager, + ): + await session._enqueue_to_semantic_queue( + session_temp_uri="viking://temp/session/test", + user_temp_uri="viking://temp/user/test", + agent_temp_uri=agent_temp, + ) + + agent_msg = enqueued_msgs[2] + expected_target = f"viking://agent/{user.agent_space_name()}" + assert agent_msg.target_uri == expected_target + assert agent_msg.uri == agent_temp + + @pytest.mark.asyncio + async def test_all_msgs_have_context_type_memory(self): + session = _make_session() + + enqueued_msgs = [] + + async def capture_enqueue(msg): + enqueued_msgs.append(msg) + + mock_queue = MagicMock() + 
mock_queue.enqueue = capture_enqueue + + mock_queue_manager = MagicMock() + mock_queue_manager.SEMANTIC = "semantic" + mock_queue_manager.get_queue = MagicMock(return_value=mock_queue) + + with patch( + "openviking.storage.queuefs.get_queue_manager", + return_value=mock_queue_manager, + ): + await session._enqueue_to_semantic_queue( + session_temp_uri="viking://temp/session/test", + user_temp_uri="viking://temp/user/test", + agent_temp_uri="viking://temp/agent/test", + ) + + for msg in enqueued_msgs: + assert msg.context_type == "memory" + + @pytest.mark.asyncio + async def test_all_msgs_have_recursive_true(self): + session = _make_session() + + enqueued_msgs = [] + + async def capture_enqueue(msg): + enqueued_msgs.append(msg) + + mock_queue = MagicMock() + mock_queue.enqueue = capture_enqueue + + mock_queue_manager = MagicMock() + mock_queue_manager.SEMANTIC = "semantic" + mock_queue_manager.get_queue = MagicMock(return_value=mock_queue) + + with patch( + "openviking.storage.queuefs.get_queue_manager", + return_value=mock_queue_manager, + ): + await session._enqueue_to_semantic_queue( + session_temp_uri="viking://temp/session/test", + user_temp_uri="viking://temp/user/test", + agent_temp_uri="viking://temp/agent/test", + ) + + for msg in enqueued_msgs: + assert msg.recursive is True + + @pytest.mark.asyncio + async def test_msgs_have_correct_user_context(self): + user = _make_user() + session = _make_session() + + enqueued_msgs = [] + + async def capture_enqueue(msg): + enqueued_msgs.append(msg) + + mock_queue = MagicMock() + mock_queue.enqueue = capture_enqueue + + mock_queue_manager = MagicMock() + mock_queue_manager.SEMANTIC = "semantic" + mock_queue_manager.get_queue = MagicMock(return_value=mock_queue) + + with patch( + "openviking.storage.queuefs.get_queue_manager", + return_value=mock_queue_manager, + ): + await session._enqueue_to_semantic_queue( + session_temp_uri="viking://temp/session/test", + user_temp_uri="viking://temp/user/test", + agent_temp_uri="viking://temp/agent/test", + ) + + for msg in enqueued_msgs: + assert msg.account_id == user.account_id + assert msg.user_id == user.user_id + assert msg.agent_id == user.agent_id + + +class TestTempUriStructureMatchesTarget: + """Tests for temp URI structure matching target URI structure.""" + + def test_session_temp_uri_contains_target_path(self): + session = _make_session(session_id="sess_123") + user = session.user + + temp_base, session_temp, _, _ = session._create_temp_uris() + + target_path = f"/session/{user.user_space_name()}/sess_123" + assert session_temp.endswith(target_path) + + def test_user_temp_uri_contains_target_path(self): + session = _make_session() + user = session.user + + temp_base, _, user_temp, _ = session._create_temp_uris() + + target_path = f"/user/{user.user_space_name()}" + assert user_temp.endswith(target_path) + + def test_agent_temp_uri_contains_target_path(self): + session = _make_session() + user = session.user + + temp_base, _, _, agent_temp = session._create_temp_uris() + + target_path = f"/agent/{user.agent_space_name()}" + assert agent_temp.endswith(target_path) + + def test_all_temp_uris_share_same_base(self): + session = _make_session() + + temp_base, session_temp, user_temp, agent_temp = session._create_temp_uris() + + assert session_temp.startswith(temp_base) + assert user_temp.startswith(temp_base) + assert agent_temp.startswith(temp_base) + + def test_temp_uri_structure_allows_semantic_dag_recursive_processing(self): + session = _make_session(session_id="sess_xyz") + user = session.user 
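+        # The temp tree mirrors the target tree's trailing segments
+        # (session/<space>/<id>, user/<space>, agent/<space>), so the
+        # semantic DAG can recurse over the temp copy and map each node
+        # back to its eventual target URI.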
+ + temp_base, session_temp, user_temp, agent_temp = session._create_temp_uris() + + assert "/session/" in session_temp + assert f"/{user.user_space_name()}/sess_xyz" in session_temp + + assert "/user/" in user_temp + assert f"/{user.user_space_name()}" in user_temp + + assert "/agent/" in agent_temp + assert f"/{user.agent_space_name()}" in agent_temp + + +class TestTempUriWithDifferentUsers: + """Tests for temp URI generation with different user configurations.""" + + def test_different_user_space_names(self): + user1 = UserIdentifier("acc1", "alice", "agent1") + user2 = UserIdentifier("acc2", "bob", "agent2") + + session1 = Session( + viking_fs=MagicMock(), + user=user1, + ctx=RequestContext(user=user1, role=Role.ROOT), + session_id="sess1", + ) + session2 = Session( + viking_fs=MagicMock(), + user=user2, + ctx=RequestContext(user=user2, role=Role.ROOT), + session_id="sess2", + ) + + _, session_temp1, user_temp1, agent_temp1 = session1._create_temp_uris() + _, session_temp2, user_temp2, agent_temp2 = session2._create_temp_uris() + + assert "alice" in session_temp1 + assert "bob" in session_temp2 + assert user_temp1 != user_temp2 + assert agent_temp1 != agent_temp2 + + def test_agent_space_name_is_hashed(self): + user = UserIdentifier("acc", "myuser", "myagent") + session = Session( + viking_fs=MagicMock(), + user=user, + ctx=RequestContext(user=user, role=Role.ROOT), + session_id="sess", + ) + + _, _, _, agent_temp = session._create_temp_uris() + + assert "myagent" not in agent_temp + assert user.agent_space_name() in agent_temp + assert len(user.agent_space_name()) == 12 diff --git a/tests/unit/storage/queuefs/__init__.py b/tests/unit/storage/queuefs/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/unit/storage/queuefs/test_dag_incremental.py b/tests/unit/storage/queuefs/test_dag_incremental.py new file mode 100644 index 00000000..f2f0cca9 --- /dev/null +++ b/tests/unit/storage/queuefs/test_dag_incremental.py @@ -0,0 +1,560 @@ +# Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd. 
+# SPDX-License-Identifier: Apache-2.0 +"""Unit tests for SemanticDagExecutor incremental update and content change detection.""" + +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from openviking.server.identity import RequestContext, Role +from openviking.storage.queuefs.semantic_dag import SemanticDagExecutor +from openviking_cli.session.user_id import UserIdentifier + + +@pytest.fixture +def mock_processor(): + """Create a mock SemanticProcessor.""" + processor = MagicMock() + processor._generate_single_file_summary = AsyncMock(return_value={"name": "test.py", "summary": "test summary"}) + processor._generate_overview = AsyncMock(return_value="test overview") + processor._extract_abstract_from_overview = MagicMock(return_value="test abstract") + processor._vectorize_single_file = AsyncMock() + processor._vectorize_directory = AsyncMock() + return processor + + +@pytest.fixture +def mock_viking_fs(): + """Create a mock VikingFS.""" + fs = MagicMock() + fs.ls = AsyncMock(return_value=[]) + fs.read_file = AsyncMock(return_value="") + fs.write_file = AsyncMock() + fs._get_vector_store = MagicMock(return_value=None) + return fs + + +@pytest.fixture +def mock_vector_store(): + """Create a mock VectorStore.""" + store = MagicMock() + store.get_context_by_uri = AsyncMock(return_value=[]) + return store + + +@pytest.fixture +def mock_context(): + """Create a mock RequestContext.""" + user = MagicMock(spec=UserIdentifier) + user.account_id = "test_account" + user.user_id = "test_user" + return RequestContext(user=user, role=Role.USER) + + +@pytest.fixture +def executor(mock_processor, mock_context, mock_viking_fs): + """Create a SemanticDagExecutor instance for testing.""" + with patch("openviking.storage.queuefs.semantic_dag.get_viking_fs", return_value=mock_viking_fs): + executor = SemanticDagExecutor( + processor=mock_processor, + context_type="resource", + max_concurrent_llm=5, + ctx=mock_context, + incremental_update=True, + target_uri="viking://resource/target", + recursive=True, + ) + return executor + + +class TestGetTargetFilePath: + """Tests for _get_target_file_path() method.""" + + def test_returns_none_when_incremental_update_disabled(self, mock_processor, mock_context, mock_viking_fs): + with patch("openviking.storage.queuefs.semantic_dag.get_viking_fs", return_value=mock_viking_fs): + executor = SemanticDagExecutor( + processor=mock_processor, + context_type="resource", + max_concurrent_llm=5, + ctx=mock_context, + incremental_update=False, + target_uri="viking://resource/target", + ) + executor._root_uri = "viking://resource/root" + result = executor._get_target_file_path("viking://resource/root/file.py") + assert result is None + + def test_returns_none_when_target_uri_is_none(self, mock_processor, mock_context, mock_viking_fs): + with patch("openviking.storage.queuefs.semantic_dag.get_viking_fs", return_value=mock_viking_fs): + executor = SemanticDagExecutor( + processor=mock_processor, + context_type="resource", + max_concurrent_llm=5, + ctx=mock_context, + incremental_update=True, + target_uri=None, + ) + executor._root_uri = "viking://resource/root" + result = executor._get_target_file_path("viking://resource/root/file.py") + assert result is None + + def test_returns_none_when_root_uri_is_none(self, executor): + executor._root_uri = None + result = executor._get_target_file_path("viking://resource/root/file.py") + assert result is None + + def test_returns_target_path_for_file_in_root(self, executor): + executor._root_uri = "viking://resource/root" 
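+        # The path relative to _root_uri is re-rooted under target_uri.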
+ result = executor._get_target_file_path("viking://resource/root/file.py") + assert result == "viking://resource/target/file.py" + + def test_returns_target_path_for_file_in_subdirectory(self, executor): + executor._root_uri = "viking://resource/root" + result = executor._get_target_file_path("viking://resource/root/subdir/file.py") + assert result == "viking://resource/target/subdir/file.py" + + def test_returns_target_uri_when_current_uri_equals_root(self, executor): + executor._root_uri = "viking://resource/root" + result = executor._get_target_file_path("viking://resource/root") + assert result == "viking://resource/target" + + def test_handles_nested_paths(self, executor): + executor._root_uri = "viking://resource/root" + result = executor._get_target_file_path("viking://resource/root/a/b/c/file.py") + assert result == "viking://resource/target/a/b/c/file.py" + + def test_handles_path_prefix_matching(self, executor): + executor._root_uri = "viking://resource/root" + result = executor._get_target_file_path("viking://resource/rootdir/file.py") + assert result == "viking://resource/target/dir/file.py" + + def test_returns_none_on_exception(self, executor): + executor._root_uri = "viking://resource/root" + result = executor._get_target_file_path(None) + assert result is None + + +class TestCheckFileContentChanged: + """Tests for _check_file_content_changed() method.""" + + @pytest.mark.asyncio + async def test_returns_true_when_target_path_is_none(self, executor): + executor._root_uri = None + result = await executor._check_file_content_changed("viking://resource/root/file.py") + assert result is True + + @pytest.mark.asyncio + async def test_returns_true_when_content_differs(self, executor, mock_viking_fs): + executor._viking_fs = mock_viking_fs + executor._root_uri = "viking://resource/root" + + mock_viking_fs.read_file = AsyncMock(side_effect=["current content", "target content"]) + + result = await executor._check_file_content_changed("viking://resource/root/file.py") + assert result is True + + @pytest.mark.asyncio + async def test_returns_false_when_content_identical(self, executor, mock_viking_fs): + executor._viking_fs = mock_viking_fs + executor._root_uri = "viking://resource/root" + + mock_viking_fs.read_file = AsyncMock(return_value="same content") + + result = await executor._check_file_content_changed("viking://resource/root/file.py") + assert result is False + + @pytest.mark.asyncio + async def test_returns_true_on_read_exception(self, executor, mock_viking_fs): + executor._viking_fs = mock_viking_fs + executor._root_uri = "viking://resource/root" + + mock_viking_fs.read_file = AsyncMock(side_effect=Exception("read error")) + + result = await executor._check_file_content_changed("viking://resource/root/file.py") + assert result is True + + @pytest.mark.asyncio + async def test_calls_read_file_with_correct_paths(self, executor, mock_viking_fs): + executor._viking_fs = mock_viking_fs + executor._root_uri = "viking://resource/root" + + mock_viking_fs.read_file = AsyncMock(return_value="content") + + await executor._check_file_content_changed("viking://resource/root/subdir/file.py") + + assert mock_viking_fs.read_file.call_count == 2 + calls = mock_viking_fs.read_file.call_args_list + assert calls[0][0][0] == "viking://resource/root/subdir/file.py" + assert calls[1][0][0] == "viking://resource/target/subdir/file.py" + + +class TestReadExistingSummary: + """Tests for _read_existing_summary() method.""" + + @pytest.mark.asyncio + async def 
test_returns_none_when_target_path_is_none(self, executor): + executor._root_uri = None + result = await executor._read_existing_summary("viking://resource/root/file.py") + assert result is None + + @pytest.mark.asyncio + async def test_returns_none_when_vector_store_is_none(self, executor, mock_viking_fs): + executor._viking_fs = mock_viking_fs + executor._root_uri = "viking://resource/root" + mock_viking_fs._get_vector_store = MagicMock(return_value=None) + + result = await executor._read_existing_summary("viking://resource/root/file.py") + assert result is None + + @pytest.mark.asyncio + async def test_returns_summary_dict_when_record_exists(self, executor, mock_viking_fs, mock_vector_store): + executor._viking_fs = mock_viking_fs + executor._root_uri = "viking://resource/root" + mock_viking_fs._get_vector_store = MagicMock(return_value=mock_vector_store) + + mock_vector_store.get_context_by_uri = AsyncMock( + return_value=[{"abstract": "existing summary content"}] + ) + + result = await executor._read_existing_summary("viking://resource/root/subdir/file.py") + + assert result is not None + assert result["name"] == "file.py" + assert result["summary"] == "existing summary content" + + @pytest.mark.asyncio + async def test_returns_none_when_no_records(self, executor, mock_viking_fs, mock_vector_store): + executor._viking_fs = mock_viking_fs + executor._root_uri = "viking://resource/root" + mock_viking_fs._get_vector_store = MagicMock(return_value=mock_vector_store) + + mock_vector_store.get_context_by_uri = AsyncMock(return_value=[]) + + result = await executor._read_existing_summary("viking://resource/root/file.py") + assert result is None + + @pytest.mark.asyncio + async def test_returns_none_when_abstract_is_empty(self, executor, mock_viking_fs, mock_vector_store): + executor._viking_fs = mock_viking_fs + executor._root_uri = "viking://resource/root" + mock_viking_fs._get_vector_store = MagicMock(return_value=mock_vector_store) + + mock_vector_store.get_context_by_uri = AsyncMock(return_value=[{"abstract": ""}]) + + result = await executor._read_existing_summary("viking://resource/root/file.py") + assert result is None + + @pytest.mark.asyncio + async def test_returns_none_on_exception(self, executor, mock_viking_fs, mock_vector_store): + executor._viking_fs = mock_viking_fs + executor._root_uri = "viking://resource/root" + mock_viking_fs._get_vector_store = MagicMock(return_value=mock_vector_store) + + mock_vector_store.get_context_by_uri = AsyncMock(side_effect=Exception("db error")) + + result = await executor._read_existing_summary("viking://resource/root/file.py") + assert result is None + + @pytest.mark.asyncio + async def test_calls_vector_store_with_correct_uri(self, executor, mock_viking_fs, mock_vector_store, mock_context): + executor._viking_fs = mock_viking_fs + executor._root_uri = "viking://resource/root" + mock_viking_fs._get_vector_store = MagicMock(return_value=mock_vector_store) + + mock_vector_store.get_context_by_uri = AsyncMock(return_value=[{"abstract": "summary"}]) + + await executor._read_existing_summary("viking://resource/root/file.py") + + mock_vector_store.get_context_by_uri.assert_called_once_with( + account_id=mock_context.account_id, + uri="viking://resource/target/file.py", + limit=1, + ) + + +class TestCheckDirChildrenChanged: + """Tests for _check_dir_children_changed() method.""" + + @pytest.mark.asyncio + async def test_returns_true_when_target_path_is_none(self, executor): + executor._root_uri = None + result = await 
executor._check_dir_children_changed( + "viking://resource/root", ["file1.py"], ["dir1"] + ) + assert result is True + + @pytest.mark.asyncio + async def test_returns_false_when_children_identical(self, executor, mock_viking_fs): + executor._viking_fs = mock_viking_fs + executor._root_uri = "viking://resource/root" + + current_files = ["viking://resource/root/file1.py", "viking://resource/root/file2.py"] + current_dirs = ["viking://resource/root/dir1", "viking://resource/root/dir2"] + + executor._list_dir = AsyncMock( + side_effect=[ + (["viking://resource/target/dir1", "viking://resource/target/dir2"], ["viking://resource/target/file1.py", "viking://resource/target/file2.py"]), + ([], []), + ] + ) + mock_viking_fs.read_file = AsyncMock(return_value="same content") + + result = await executor._check_dir_children_changed( + "viking://resource/root", current_files, current_dirs + ) + assert result is False + + @pytest.mark.asyncio + async def test_returns_true_when_file_names_differ(self, executor, mock_viking_fs): + executor._viking_fs = mock_viking_fs + executor._root_uri = "viking://resource/root" + + current_files = ["viking://resource/root/file1.py"] + current_dirs = [] + + mock_viking_fs.ls = AsyncMock( + return_value=[ + {"name": "file2.py", "isDir": False}, + ] + ) + + result = await executor._check_dir_children_changed( + "viking://resource/root", current_files, current_dirs + ) + assert result is True + + @pytest.mark.asyncio + async def test_returns_true_when_dir_names_differ(self, executor, mock_viking_fs): + executor._viking_fs = mock_viking_fs + executor._root_uri = "viking://resource/root" + + current_files = [] + current_dirs = ["viking://resource/root/dir1"] + + mock_viking_fs.ls = AsyncMock( + return_value=[ + {"name": "dir2", "isDir": True}, + ] + ) + + result = await executor._check_dir_children_changed( + "viking://resource/root", current_files, current_dirs + ) + assert result is True + + @pytest.mark.asyncio + async def test_returns_true_when_file_content_changed(self, executor, mock_viking_fs): + executor._viking_fs = mock_viking_fs + executor._root_uri = "viking://resource/root" + + current_files = ["viking://resource/root/file1.py"] + current_dirs = [] + + mock_viking_fs.ls = AsyncMock( + side_effect=[ + [{"name": "file1.py", "isDir": False}], + [], + ] + ) + mock_viking_fs.read_file = AsyncMock(side_effect=["old content", "new content"]) + + result = await executor._check_dir_children_changed( + "viking://resource/root", current_files, current_dirs + ) + assert result is True + + @pytest.mark.asyncio + async def test_returns_true_on_exception(self, executor, mock_viking_fs): + executor._viking_fs = mock_viking_fs + executor._root_uri = "viking://resource/root" + + mock_viking_fs.ls = AsyncMock(side_effect=Exception("ls error")) + + result = await executor._check_dir_children_changed( + "viking://resource/root", ["file1.py"], [] + ) + assert result is True + + @pytest.mark.asyncio + async def test_handles_empty_directories(self, executor, mock_viking_fs): + executor._viking_fs = mock_viking_fs + executor._root_uri = "viking://resource/root" + + mock_viking_fs.ls = AsyncMock(return_value=[]) + + result = await executor._check_dir_children_changed( + "viking://resource/root", [], [] + ) + assert result is False + + @pytest.mark.asyncio + async def test_ignores_dot_files_in_comparison(self, executor, mock_viking_fs): + executor._viking_fs = mock_viking_fs + executor._root_uri = "viking://resource/root" + + current_files = ["viking://resource/root/file1.py"] + 
current_dirs = [] + + executor._list_dir = AsyncMock( + return_value=( + [], + ["viking://resource/target/file1.py"], + ) + ) + + result = await executor._check_dir_children_changed( + "viking://resource/root", current_files, current_dirs + ) + assert result is False + + +class TestReadExistingOverviewAbstract: + """Tests for _read_existing_overview_abstract() method.""" + + @pytest.mark.asyncio + async def test_returns_none_tuple_when_target_path_is_none(self, executor): + executor._root_uri = None + result = await executor._read_existing_overview_abstract("viking://resource/root") + assert result == (None, None) + + @pytest.mark.asyncio + async def test_returns_overview_and_abstract_when_files_exist(self, executor, mock_viking_fs): + executor._viking_fs = mock_viking_fs + executor._root_uri = "viking://resource/root" + + mock_viking_fs.read_file = AsyncMock(side_effect=["overview content", "abstract content"]) + + result = await executor._read_existing_overview_abstract("viking://resource/root/dir") + + assert result == ("overview content", "abstract content") + calls = mock_viking_fs.read_file.call_args_list + assert calls[0][0][0] == "viking://resource/target/dir/.overview.md" + assert calls[1][0][0] == "viking://resource/target/dir/.abstract.md" + + @pytest.mark.asyncio + async def test_returns_none_tuple_on_exception(self, executor, mock_viking_fs): + executor._viking_fs = mock_viking_fs + executor._root_uri = "viking://resource/root" + + mock_viking_fs.read_file = AsyncMock(side_effect=Exception("read error")) + + result = await executor._read_existing_overview_abstract("viking://resource/root/dir") + assert result == (None, None) + + @pytest.mark.asyncio + async def test_handles_missing_overview_file(self, executor, mock_viking_fs): + executor._viking_fs = mock_viking_fs + executor._root_uri = "viking://resource/root" + + mock_viking_fs.read_file = AsyncMock( + side_effect=[Exception("not found"), "abstract content"] + ) + + result = await executor._read_existing_overview_abstract("viking://resource/root/dir") + assert result == (None, None) + + @pytest.mark.asyncio + async def test_handles_missing_abstract_file(self, executor, mock_viking_fs): + executor._viking_fs = mock_viking_fs + executor._root_uri = "viking://resource/root" + + mock_viking_fs.read_file = AsyncMock( + side_effect=["overview content", Exception("not found")] + ) + + result = await executor._read_existing_overview_abstract("viking://resource/root/dir") + assert result == (None, None) + + @pytest.mark.asyncio + async def test_calls_read_file_with_context(self, executor, mock_viking_fs, mock_context): + executor._viking_fs = mock_viking_fs + executor._root_uri = "viking://resource/root" + + mock_viking_fs.read_file = AsyncMock(return_value="content") + + await executor._read_existing_overview_abstract("viking://resource/root/dir") + + for call in mock_viking_fs.read_file.call_args_list: + assert "ctx" in call[1] + assert call[1]["ctx"] == mock_context + + @pytest.mark.asyncio + async def test_handles_nested_directory_path(self, executor, mock_viking_fs): + executor._viking_fs = mock_viking_fs + executor._root_uri = "viking://resource/root" + + mock_viking_fs.read_file = AsyncMock(return_value="content") + + await executor._read_existing_overview_abstract("viking://resource/root/a/b/c") + + calls = mock_viking_fs.read_file.call_args_list + assert calls[0][0][0] == "viking://resource/target/a/b/c/.overview.md" + assert calls[1][0][0] == "viking://resource/target/a/b/c/.abstract.md" + + +class 
TestIncrementalUpdateIntegration: + """Integration tests for incremental update scenarios.""" + + @pytest.mark.asyncio + async def test_full_incremental_flow_no_changes(self, executor, mock_viking_fs, mock_vector_store): + executor._viking_fs = mock_viking_fs + executor._root_uri = "viking://resource/root" + mock_viking_fs._get_vector_store = MagicMock(return_value=mock_vector_store) + + mock_viking_fs.ls = AsyncMock(return_value=[]) + mock_viking_fs.read_file = AsyncMock(return_value="same content") + mock_vector_store.get_context_by_uri = AsyncMock( + return_value=[{"abstract": "existing summary"}] + ) + + content_changed = await executor._check_file_content_changed("viking://resource/root/file.py") + assert content_changed is False + + summary = await executor._read_existing_summary("viking://resource/root/file.py") + assert summary is not None + assert summary["summary"] == "existing summary" + + @pytest.mark.asyncio + async def test_full_incremental_flow_with_changes(self, executor, mock_viking_fs, mock_vector_store): + executor._viking_fs = mock_viking_fs + executor._root_uri = "viking://resource/root" + mock_viking_fs._get_vector_store = MagicMock(return_value=mock_vector_store) + + mock_viking_fs.read_file = AsyncMock(side_effect=["new content", "old content"]) + + content_changed = await executor._check_file_content_changed("viking://resource/root/file.py") + assert content_changed is True + + @pytest.mark.asyncio + async def test_directory_change_detection_flow(self, executor, mock_viking_fs): + executor._viking_fs = mock_viking_fs + executor._root_uri = "viking://resource/root" + + executor._list_dir = AsyncMock( + side_effect=[ + ([], ["viking://resource/target/file1.py", "viking://resource/target/file2.py"]), + ([], []), + ] + ) + mock_viking_fs.read_file = AsyncMock(return_value="same content") + + current_files = ["viking://resource/root/file1.py", "viking://resource/root/file2.py"] + changed = await executor._check_dir_children_changed( + "viking://resource/root", current_files, [] + ) + assert changed is False + + @pytest.mark.asyncio + async def test_overview_abstract_read_flow(self, executor, mock_viking_fs): + executor._viking_fs = mock_viking_fs + executor._root_uri = "viking://resource/root" + + mock_viking_fs.read_file = AsyncMock( + side_effect=["existing overview", "existing abstract"] + ) + + overview, abstract = await executor._read_existing_overview_abstract( + "viking://resource/root/subdir" + ) + assert overview == "existing overview" + assert abstract == "existing abstract" diff --git a/tests/unit/storage/queuefs/test_embedding_msg.py b/tests/unit/storage/queuefs/test_embedding_msg.py new file mode 100644 index 00000000..0ff42657 --- /dev/null +++ b/tests/unit/storage/queuefs/test_embedding_msg.py @@ -0,0 +1,257 @@ +# Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd. 
+# SPDX-License-Identifier: Apache-2.0 +import json +import pytest + +from openviking.storage.queuefs.embedding_msg import EmbeddingMsg + + +class TestEmbeddingMsg: + """Unit tests for EmbeddingMsg class.""" + + def test_semantic_msg_id_serialization(self): + """Test semantic_msg_id field serialization via to_dict().""" + msg = EmbeddingMsg( + message="test message", + context_data={"key": "value"}, + semantic_msg_id="semantic-123", + ) + result = msg.to_dict() + assert result["semantic_msg_id"] == "semantic-123" + assert result["message"] == "test message" + assert result["context_data"] == {"key": "value"} + assert hasattr(msg, "id") + assert msg.id is not None + + def test_semantic_msg_id_deserialization(self): + """Test semantic_msg_id field deserialization via from_dict().""" + data = { + "id": "test-id-123", + "message": "test message", + "context_data": {"key": "value"}, + "semantic_msg_id": "semantic-456", + } + msg = EmbeddingMsg.from_dict(data) + assert msg.semantic_msg_id == "semantic-456" + assert msg.id == "test-id-123" + assert msg.message == "test message" + assert msg.context_data == {"key": "value"} + + def test_from_dict_missing_semantic_msg_id_defaults_to_none(self): + """Test from_dict() compatibility with old format (missing semantic_msg_id).""" + data = { + "id": "test-id-789", + "message": "legacy message", + "context_data": {"legacy": True}, + } + msg = EmbeddingMsg.from_dict(data) + assert msg.semantic_msg_id is None + assert msg.id == "test-id-789" + assert msg.message == "legacy message" + + def test_from_dict_semantic_msg_id_explicit_none(self): + """Test from_dict() with explicit None for semantic_msg_id.""" + data = { + "id": "test-id-none", + "message": "message with None", + "context_data": {}, + "semantic_msg_id": None, + } + msg = EmbeddingMsg.from_dict(data) + assert msg.semantic_msg_id is None + + def test_to_json_with_semantic_msg_id(self): + """Test to_json() method with semantic_msg_id.""" + msg = EmbeddingMsg( + message="json test", + context_data={"json_key": "json_value"}, + semantic_msg_id="semantic-json", + ) + json_str = msg.to_json() + parsed = json.loads(json_str) + assert parsed["semantic_msg_id"] == "semantic-json" + assert parsed["message"] == "json test" + assert parsed["context_data"] == {"json_key": "json_value"} + + def test_to_json_without_semantic_msg_id(self): + """Test to_json() method without semantic_msg_id (None).""" + msg = EmbeddingMsg( + message="json test no id", + context_data={"key": "value"}, + ) + json_str = msg.to_json() + parsed = json.loads(json_str) + assert parsed["semantic_msg_id"] is None + + def test_from_json_with_semantic_msg_id(self): + """Test from_json() method with semantic_msg_id.""" + json_str = json.dumps({ + "id": "json-id-123", + "message": "from json", + "context_data": {"from": "json"}, + "semantic_msg_id": "semantic-from-json", + }) + msg = EmbeddingMsg.from_json(json_str) + assert msg.semantic_msg_id == "semantic-from-json" + assert msg.id == "json-id-123" + assert msg.message == "from json" + + def test_from_json_missing_semantic_msg_id(self): + """Test from_json() with missing semantic_msg_id (backward compatibility).""" + json_str = json.dumps({ + "id": "json-id-456", + "message": "legacy json", + "context_data": {"legacy": True}, + }) + msg = EmbeddingMsg.from_json(json_str) + assert msg.semantic_msg_id is None + assert msg.message == "legacy json" + + def test_from_json_invalid_json_raises_value_error(self): + """Test from_json() raises ValueError for invalid JSON.""" + with 
pytest.raises(ValueError, match="Invalid JSON string"): + EmbeddingMsg.from_json("not a valid json") + + def test_message_field_string_type(self): + """Test message field with string type.""" + msg = EmbeddingMsg( + message="simple string message", + context_data={}, + ) + assert isinstance(msg.message, str) + assert msg.message == "simple string message" + + def test_message_field_list_of_dicts_type(self): + """Test message field with List[Dict] type.""" + message_list = [ + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi there"}, + ] + msg = EmbeddingMsg( + message=message_list, + context_data={"conversation": True}, + ) + assert isinstance(msg.message, list) + assert len(msg.message) == 2 + assert msg.message[0]["role"] == "user" + assert msg.message[1]["content"] == "Hi there" + + def test_message_list_serialization_deserialization(self): + """Test serialization and deserialization with List[Dict] message.""" + message_list = [ + {"role": "user", "content": "Question?"}, + {"role": "assistant", "content": "Answer."}, + ] + msg = EmbeddingMsg( + message=message_list, + context_data={"type": "qa"}, + semantic_msg_id="qa-123", + ) + json_str = msg.to_json() + restored = EmbeddingMsg.from_json(json_str) + assert isinstance(restored.message, list) + assert len(restored.message) == 2 + assert restored.message[0]["role"] == "user" + assert restored.semantic_msg_id == "qa-123" + + def test_id_auto_generated(self): + """Test that id is auto-generated as UUID.""" + msg = EmbeddingMsg( + message="test", + context_data={}, + ) + assert msg.id is not None + assert len(msg.id) == 36 + assert msg.id.count("-") == 4 + + def test_from_dict_preserves_or_generates_id(self): + """Test from_dict() preserves provided id or generates new one.""" + data_with_id = { + "id": "preserved-id", + "message": "test", + "context_data": {}, + } + msg = EmbeddingMsg.from_dict(data_with_id) + assert msg.id == "preserved-id" + + data_without_id = { + "message": "test", + "context_data": {}, + } + msg = EmbeddingMsg.from_dict(data_without_id) + assert msg.id is not None + assert len(msg.id) == 36 + + def test_empty_context_data(self): + """Test with empty context_data.""" + msg = EmbeddingMsg( + message="test", + context_data={}, + semantic_msg_id="empty-ctx", + ) + result = msg.to_dict() + assert result["context_data"] == {} + assert result["semantic_msg_id"] == "empty-ctx" + + def test_complex_context_data(self): + """Test with complex nested context_data.""" + complex_data = { + "nested": { + "level1": { + "level2": ["a", "b", "c"], + }, + }, + "list": [1, 2, 3], + "string": "value", + } + msg = EmbeddingMsg( + message="complex test", + context_data=complex_data, + semantic_msg_id="complex-123", + ) + json_str = msg.to_json() + restored = EmbeddingMsg.from_json(json_str) + assert restored.context_data["nested"]["level1"]["level2"] == ["a", "b", "c"] + assert restored.semantic_msg_id == "complex-123" + + def test_semantic_msg_id_empty_string(self): + """Test semantic_msg_id with empty string.""" + msg = EmbeddingMsg( + message="test", + context_data={}, + semantic_msg_id="", + ) + assert msg.semantic_msg_id == "" + result = msg.to_dict() + assert result["semantic_msg_id"] == "" + + def test_roundtrip_string_message(self): + """Test complete roundtrip with string message.""" + original = EmbeddingMsg( + message="roundtrip test", + context_data={"key": "value"}, + semantic_msg_id="roundtrip-id", + ) + json_str = original.to_json() + restored = EmbeddingMsg.from_json(json_str) + assert 
restored.message == original.message + assert restored.context_data == original.context_data + assert restored.semantic_msg_id == original.semantic_msg_id + assert restored.id is not None + assert len(restored.id) == 36 + + def test_roundtrip_list_message(self): + """Test complete roundtrip with List[Dict] message.""" + message_list = [ + {"type": "text", "content": "part1"}, + {"type": "code", "content": "print('hello')"}, + ] + original = EmbeddingMsg( + message=message_list, + context_data={"format": "mixed"}, + semantic_msg_id="list-msg-id", + ) + json_str = original.to_json() + restored = EmbeddingMsg.from_json(json_str) + assert restored.message == original.message + assert restored.semantic_msg_id == "list-msg-id" diff --git a/tests/unit/storage/queuefs/test_embedding_tracker.py b/tests/unit/storage/queuefs/test_embedding_tracker.py new file mode 100644 index 00000000..1850fef0 --- /dev/null +++ b/tests/unit/storage/queuefs/test_embedding_tracker.py @@ -0,0 +1,554 @@ +# Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd. +# SPDX-License-Identifier: Apache-2.0 + +"""Unit tests for EmbeddingTaskTracker.""" + +import asyncio + +import pytest + +from openviking.storage.queuefs.embedding_tracker import EmbeddingTaskTracker + + +def reset_singleton(): + """Reset the singleton instance for testing.""" + EmbeddingTaskTracker._instance = None + + +@pytest.fixture(autouse=True) +def clean_singleton(): + """Reset singleton before and after each test.""" + reset_singleton() + yield + reset_singleton() + + +@pytest.fixture +def tracker() -> EmbeddingTaskTracker: + """Create a fresh tracker instance for each test.""" + return EmbeddingTaskTracker() + + +# ── Singleton Pattern Tests ── + + +def test_singleton_returns_same_instance(): + """Test that get_instance() returns the same instance.""" + instance1 = EmbeddingTaskTracker.get_instance() + instance2 = EmbeddingTaskTracker.get_instance() + assert instance1 is instance2 + + +def test_singleton_persists_across_calls(): + """Test that singleton persists across multiple get_instance calls.""" + instance1 = EmbeddingTaskTracker.get_instance() + instance2 = EmbeddingTaskTracker.get_instance() + instance3 = EmbeddingTaskTracker.get_instance() + assert instance1 is instance2 is instance3 + + +# ── Register Tests ── + + +@pytest.mark.asyncio +async def test_register_task(tracker: EmbeddingTaskTracker): + """Test registering a task with valid count.""" + await tracker.register("msg-1", 5) + status = await tracker.get_status("msg-1") + assert status is not None + assert status["remaining"] == 5 + assert status["total"] == 5 + + +@pytest.mark.asyncio +async def test_register_with_metadata(tracker: EmbeddingTaskTracker): + """Test registering a task with metadata.""" + metadata = {"key": "value", "count": 10} + await tracker.register("msg-2", 3, metadata=metadata) + status = await tracker.get_status("msg-2") + assert status is not None + assert status["metadata"] == metadata + + +@pytest.mark.asyncio +async def test_register_with_callback(tracker: EmbeddingTaskTracker): + """Test registering a task with on_complete callback.""" + callback_called = [] + + async def on_complete(): + callback_called.append(True) + + await tracker.register("msg-3", 1, on_complete=on_complete) + await tracker.decrement("msg-3") + assert len(callback_called) == 1 + + +@pytest.mark.asyncio +async def test_register_with_sync_callback(tracker: EmbeddingTaskTracker): + """Test registering a task with synchronous callback.""" + callback_called = [] + + def on_complete(): + 
callback_called.append(True) + + await tracker.register("msg-4", 1, on_complete=on_complete) + await tracker.decrement("msg-4") + assert len(callback_called) == 1 + + +@pytest.mark.asyncio +async def test_register_with_zero_count_does_nothing(tracker: EmbeddingTaskTracker): + """Test that registering with total_count=0 does nothing.""" + await tracker.register("msg-5", 0) + status = await tracker.get_status("msg-5") + assert status is None + + +@pytest.mark.asyncio +async def test_register_with_negative_count_does_nothing(tracker: EmbeddingTaskTracker): + """Test that registering with negative total_count does nothing.""" + await tracker.register("msg-6", -5) + status = await tracker.get_status("msg-6") + assert status is None + + +@pytest.mark.asyncio +async def test_register_overwrites_existing(tracker: EmbeddingTaskTracker): + """Test that registering with same ID overwrites existing entry.""" + await tracker.register("msg-7", 5) + await tracker.register("msg-7", 10) + status = await tracker.get_status("msg-7") + assert status is not None + assert status["remaining"] == 10 + assert status["total"] == 10 + + +# ── Increment Tests ── + + +@pytest.mark.asyncio +async def test_increment_existing_task(tracker: EmbeddingTaskTracker): + """Test incrementing an existing task.""" + await tracker.register("msg-10", 5) + result = await tracker.increment("msg-10") + assert result == 6 + status = await tracker.get_status("msg-10") + assert status["remaining"] == 6 + assert status["total"] == 6 + + +@pytest.mark.asyncio +async def test_increment_multiple_times(tracker: EmbeddingTaskTracker): + """Test incrementing a task multiple times.""" + await tracker.register("msg-11", 2) + await tracker.increment("msg-11") + await tracker.increment("msg-11") + await tracker.increment("msg-11") + status = await tracker.get_status("msg-11") + assert status["remaining"] == 5 + assert status["total"] == 5 + + +@pytest.mark.asyncio +async def test_increment_nonexistent_task_returns_none(tracker: EmbeddingTaskTracker): + """Test incrementing a non-existent task returns None.""" + result = await tracker.increment("nonexistent") + assert result is None + + +# ── Decrement Tests ── + + +@pytest.mark.asyncio +async def test_decrement_existing_task(tracker: EmbeddingTaskTracker): + """Test decrementing an existing task.""" + await tracker.register("msg-20", 5) + result = await tracker.decrement("msg-20") + assert result == 4 + status = await tracker.get_status("msg-20") + assert status["remaining"] == 4 + + +@pytest.mark.asyncio +async def test_decrement_multiple_times(tracker: EmbeddingTaskTracker): + """Test decrementing a task multiple times.""" + await tracker.register("msg-21", 3) + await tracker.decrement("msg-21") + await tracker.decrement("msg-21") + status = await tracker.get_status("msg-21") + assert status["remaining"] == 1 + + +@pytest.mark.asyncio +async def test_decrement_nonexistent_task_returns_none(tracker: EmbeddingTaskTracker): + """Test decrementing a non-existent task returns None.""" + result = await tracker.decrement("nonexistent") + assert result is None + + +@pytest.mark.asyncio +async def test_decrement_to_zero_removes_task(tracker: EmbeddingTaskTracker): + """Test that decrementing to zero removes the task.""" + await tracker.register("msg-22", 1) + result = await tracker.decrement("msg-22") + assert result == 0 + status = await tracker.get_status("msg-22") + assert status is None + + +@pytest.mark.asyncio +async def test_decrement_triggers_callback_on_completion(tracker: EmbeddingTaskTracker): 
+ """Test that callback is triggered when count reaches zero.""" + callback_called = [] + + async def on_complete(): + callback_called.append("async") + + await tracker.register("msg-23", 2, on_complete=on_complete) + await tracker.decrement("msg-23") + assert len(callback_called) == 0 + await tracker.decrement("msg-23") + assert len(callback_called) == 1 + assert callback_called[0] == "async" + + +@pytest.mark.asyncio +async def test_decrement_sync_callback_on_completion(tracker: EmbeddingTaskTracker): + """Test that sync callback is triggered when count reaches zero.""" + callback_called = [] + + def on_complete(): + callback_called.append("sync") + + await tracker.register("msg-24", 1, on_complete=on_complete) + await tracker.decrement("msg-24") + assert len(callback_called) == 1 + assert callback_called[0] == "sync" + + +@pytest.mark.asyncio +async def test_decrement_callback_error_is_handled(tracker: EmbeddingTaskTracker): + """Test that callback errors are handled gracefully.""" + + async def on_complete(): + raise ValueError("Callback error") + + await tracker.register("msg-25", 1, on_complete=on_complete) + result = await tracker.decrement("msg-25") + assert result == 0 + + +@pytest.mark.asyncio +async def test_decrement_below_zero_removes_task(tracker: EmbeddingTaskTracker): + """Test that decrementing below zero still removes the task.""" + await tracker.register("msg-26", 1) + await tracker.decrement("msg-26") + status = await tracker.get_status("msg-26") + assert status is None + + +# ── Get Status Tests ── + + +@pytest.mark.asyncio +async def test_get_status_existing_task(tracker: EmbeddingTaskTracker): + """Test getting status of an existing task.""" + await tracker.register("msg-30", 5, metadata={"key": "value"}) + status = await tracker.get_status("msg-30") + assert status is not None + assert status["remaining"] == 5 + assert status["total"] == 5 + assert status["metadata"] == {"key": "value"} + + +@pytest.mark.asyncio +async def test_get_status_nonexistent_task(tracker: EmbeddingTaskTracker): + """Test getting status of a non-existent task.""" + status = await tracker.get_status("nonexistent") + assert status is None + + +@pytest.mark.asyncio +async def test_get_status_reflects_changes(tracker: EmbeddingTaskTracker): + """Test that get_status reflects increment/decrement changes.""" + await tracker.register("msg-31", 5) + await tracker.increment("msg-31") + status = await tracker.get_status("msg-31") + assert status["remaining"] == 6 + assert status["total"] == 6 + await tracker.decrement("msg-31") + status = await tracker.get_status("msg-31") + assert status["remaining"] == 5 + + +# ── Remove Tests ── + + +@pytest.mark.asyncio +async def test_remove_existing_task(tracker: EmbeddingTaskTracker): + """Test removing an existing task.""" + await tracker.register("msg-40", 5) + result = await tracker.remove("msg-40") + assert result is True + status = await tracker.get_status("msg-40") + assert status is None + + +@pytest.mark.asyncio +async def test_remove_nonexistent_task(tracker: EmbeddingTaskTracker): + """Test removing a non-existent task.""" + result = await tracker.remove("nonexistent") + assert result is False + + +@pytest.mark.asyncio +async def test_remove_does_not_trigger_callback(tracker: EmbeddingTaskTracker): + """Test that remove does not trigger the on_complete callback.""" + callback_called = [] + + async def on_complete(): + callback_called.append(True) + + await tracker.register("msg-41", 5, on_complete=on_complete) + await tracker.remove("msg-41") + assert 
len(callback_called) == 0 + + +# ── Get All Tracked Tests ── + + +@pytest.mark.asyncio +async def test_get_all_tracked_empty(tracker: EmbeddingTaskTracker): + """Test get_all_tracked when no tasks are registered.""" + all_tasks = await tracker.get_all_tracked() + assert all_tasks == {} + + +@pytest.mark.asyncio +async def test_get_all_tracked_single_task(tracker: EmbeddingTaskTracker): + """Test get_all_tracked with a single task.""" + await tracker.register("msg-50", 5, metadata={"key": "value"}) + all_tasks = await tracker.get_all_tracked() + assert len(all_tasks) == 1 + assert "msg-50" in all_tasks + assert all_tasks["msg-50"]["remaining"] == 5 + assert all_tasks["msg-50"]["total"] == 5 + assert all_tasks["msg-50"]["metadata"] == {"key": "value"} + + +@pytest.mark.asyncio +async def test_get_all_tracked_multiple_tasks(tracker: EmbeddingTaskTracker): + """Test get_all_tracked with multiple tasks.""" + await tracker.register("msg-51", 3) + await tracker.register("msg-52", 5) + await tracker.register("msg-53", 7) + all_tasks = await tracker.get_all_tracked() + assert len(all_tasks) == 3 + assert "msg-51" in all_tasks + assert "msg-52" in all_tasks + assert "msg-53" in all_tasks + + +@pytest.mark.asyncio +async def test_get_all_tracked_excludes_on_complete(tracker: EmbeddingTaskTracker): + """Test that get_all_tracked does not include on_complete callback.""" + await tracker.register("msg-54", 5, on_complete=lambda: None) + all_tasks = await tracker.get_all_tracked() + assert "on_complete" not in all_tasks["msg-54"] + + +@pytest.mark.asyncio +async def test_get_all_tracked_returns_copy(tracker: EmbeddingTaskTracker): + """Test that get_all_tracked returns a copy, not internal state.""" + await tracker.register("msg-55", 5) + all_tasks = await tracker.get_all_tracked() + all_tasks["msg-55"]["remaining"] = 999 + status = await tracker.get_status("msg-55") + assert status["remaining"] == 5 + + +# ── Concurrency Tests ── + + +@pytest.mark.asyncio +async def test_concurrent_register(tracker: EmbeddingTaskTracker): + """Test concurrent register operations.""" + + async def register_task(msg_id: str): + await tracker.register(msg_id, 5) + + await asyncio.gather( + register_task("msg-60"), + register_task("msg-61"), + register_task("msg-62"), + ) + all_tasks = await tracker.get_all_tracked() + assert len(all_tasks) == 3 + + +@pytest.mark.asyncio +async def test_concurrent_increment(tracker: EmbeddingTaskTracker): + """Test concurrent increment operations.""" + await tracker.register("msg-70", 1) + + async def increment_task(): + await tracker.increment("msg-70") + + await asyncio.gather( + increment_task(), + increment_task(), + increment_task(), + ) + status = await tracker.get_status("msg-70") + assert status["remaining"] == 4 + assert status["total"] == 4 + + +@pytest.mark.asyncio +async def test_concurrent_decrement(tracker: EmbeddingTaskTracker): + """Test concurrent decrement operations.""" + await tracker.register("msg-71", 3) + + async def decrement_task(): + await tracker.decrement("msg-71") + + await asyncio.gather( + decrement_task(), + decrement_task(), + decrement_task(), + ) + status = await tracker.get_status("msg-71") + assert status is None + + +@pytest.mark.asyncio +async def test_concurrent_mixed_operations(tracker: EmbeddingTaskTracker): + """Test concurrent mixed operations (increment and decrement).""" + await tracker.register("msg-72", 5) + + async def increment(): + await tracker.increment("msg-72") + + async def decrement(): + await tracker.decrement("msg-72") + + await 
asyncio.gather(
+        increment(),
+        increment(),
+        decrement(),
+        decrement(),
+        decrement(),
+    )
+    status = await tracker.get_status("msg-72")
+    assert status is not None
+
+
+@pytest.mark.asyncio
+async def test_register_then_immediate_decrement(tracker: EmbeddingTaskTracker):
+    """Test that a decrement issued right after register triggers the callback."""
+    callback_called = []
+
+    async def on_complete():
+        callback_called.append(True)
+
+    await tracker.register("msg-73", 1, on_complete=on_complete)
+    await tracker.decrement("msg-73")
+    assert len(callback_called) == 1
+
+
+@pytest.mark.asyncio
+async def test_concurrent_callback_execution(tracker: EmbeddingTaskTracker):
+    """Test that callbacks are executed correctly under concurrency."""
+    callback_count = []
+
+    async def make_callback(msg_id: str):
+        async def on_complete():
+            callback_count.append(msg_id)
+
+        return on_complete
+
+    async def register_and_complete(msg_id: str):
+        callback = await make_callback(msg_id)
+        await tracker.register(msg_id, 1, on_complete=callback)
+        await tracker.decrement(msg_id)
+
+    await asyncio.gather(
+        register_and_complete("msg-80"),
+        register_and_complete("msg-81"),
+        register_and_complete("msg-82"),
+    )
+    assert len(callback_count) == 3
+
+
+# ── Edge Cases Tests ──
+
+
+@pytest.mark.asyncio
+async def test_multiple_decrements_to_zero(tracker: EmbeddingTaskTracker):
+    """Test multiple decrements that bring count to exactly zero."""
+    callback_called = []
+
+    async def on_complete():
+        callback_called.append(True)
+
+    await tracker.register("msg-90", 3, on_complete=on_complete)
+    await tracker.decrement("msg-90")
+    await tracker.decrement("msg-90")
+    assert len(callback_called) == 0
+    await tracker.decrement("msg-90")
+    assert len(callback_called) == 1
+
+
+@pytest.mark.asyncio
+async def test_decrement_after_increment(tracker: EmbeddingTaskTracker):
+    """Test decrement after increment maintains correct count."""
+    await tracker.register("msg-91", 2)
+    await tracker.increment("msg-91")
+    await tracker.decrement("msg-91")
+    status = await tracker.get_status("msg-91")
+    assert status["remaining"] == 2
+    assert status["total"] == 3
+
+
+@pytest.mark.asyncio
+async def test_empty_metadata(tracker: EmbeddingTaskTracker):
+    """Test that empty metadata is handled correctly."""
+    await tracker.register("msg-92", 5, metadata={})
+    status = await tracker.get_status("msg-92")
+    assert status["metadata"] == {}
+
+
+@pytest.mark.asyncio
+async def test_none_metadata(tracker: EmbeddingTaskTracker):
+    """Test that None metadata defaults to empty dict."""
+    await tracker.register("msg-93", 5, metadata=None)
+    status = await tracker.get_status("msg-93")
+    assert status["metadata"] == {}
+
+
+@pytest.mark.asyncio
+async def test_none_callback(tracker: EmbeddingTaskTracker):
+    """Test that None callback is handled correctly."""
+    await tracker.register("msg-94", 1, on_complete=None)
+    result = await tracker.decrement("msg-94")
+    assert result == 0
+
+
+@pytest.mark.asyncio
+async def test_large_count(tracker: EmbeddingTaskTracker):
+    """Test with large task count."""
+    large_count = 10000
+    await tracker.register("msg-95", large_count)
+    status = await tracker.get_status("msg-95")
+    assert status["remaining"] == large_count
+    assert status["total"] == large_count
+
+
+@pytest.mark.asyncio
+async def test_special_characters_in_id(tracker: EmbeddingTaskTracker):
+    """Test with special characters in semantic_msg_id."""
+    special_id = "msg-with-special_chars.123!@#$%"
+    await tracker.register(special_id, 5)
+    status = await
tracker.get_status(special_id) + assert status is not None + assert status["remaining"] == 5 diff --git a/tests/unit/storage/queuefs/test_processor_incremental.py b/tests/unit/storage/queuefs/test_processor_incremental.py new file mode 100644 index 00000000..c87ebcf9 --- /dev/null +++ b/tests/unit/storage/queuefs/test_processor_incremental.py @@ -0,0 +1,832 @@ +# Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd. +# SPDX-License-Identifier: Apache-2.0 +"""Unit tests for SemanticProcessor incremental update and diff calculation. + +Tests for: +- _detect_file_type(): Detect file type based on extension +- _collect_tree_info(): Collect directory tree information +- _compute_diff(): Compute directory differences +- _check_file_content_changed(): Check file content changes +- _execute_sync_operations(): Execute sync operations +- _create_sync_diff_callback(): Create sync diff callback +""" + +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from openviking.parse.parsers.constants import ( + FILE_TYPE_CODE, + FILE_TYPE_DOCUMENTATION, + FILE_TYPE_OTHER, +) +from openviking.server.identity import RequestContext, Role +from openviking.storage.queuefs.semantic_processor import DiffResult, SemanticProcessor +from openviking_cli.session.user_id import UserIdentifier +from openviking_cli.utils import VikingURI + + +class FakeVikingFS: + """Fake VikingFS for testing.""" + + def __init__(self): + self._tree = {} + self._file_contents = {} + self.deleted_files = [] + self.deleted_dirs = [] + self.moved_files = [] + self.created_dirs = [] + + def set_tree(self, tree): + self._tree = tree + + def set_file_contents(self, contents): + self._file_contents = contents + + async def ls(self, uri, ctx=None): + return self._tree.get(uri.rstrip("/"), []) + + async def read_file(self, uri, ctx=None): + return self._file_contents.get(uri, "") + + async def rm(self, uri, recursive=False, ctx=None): + if recursive: + self.deleted_dirs.append(uri) + else: + self.deleted_files.append(uri) + + async def mv(self, src, dst, ctx=None): + self.moved_files.append((src, dst)) + + async def mkdir(self, uri, exist_ok=True, ctx=None): + self.created_dirs.append(uri) + + +@pytest.fixture +def processor(): + """Create a SemanticProcessor instance for testing.""" + return SemanticProcessor(max_concurrent_llm=10) + + +@pytest.fixture +def fake_fs(): + """Create a fake VikingFS instance.""" + return FakeVikingFS() + + +@pytest.fixture +def ctx(): + """Create a RequestContext for testing.""" + return RequestContext( + user=UserIdentifier("test_account", "test_user", "test_agent"), + role=Role.USER, + ) + + +class TestDetectFileType: + """Test cases for _detect_file_type() method.""" + + def test_detect_python_file(self, processor): + result = processor._detect_file_type("main.py") + assert result == FILE_TYPE_CODE + + def test_detect_javascript_file(self, processor): + result = processor._detect_file_type("app.js") + assert result == FILE_TYPE_CODE + + def test_detect_typescript_file(self, processor): + result = processor._detect_file_type("utils.ts") + assert result == FILE_TYPE_CODE + + def test_detect_java_file(self, processor): + result = processor._detect_file_type("Main.java") + assert result == FILE_TYPE_CODE + + def test_detect_go_file(self, processor): + result = processor._detect_file_type("server.go") + assert result == FILE_TYPE_CODE + + def test_detect_rust_file(self, processor): + result = processor._detect_file_type("main.rs") + assert result == FILE_TYPE_CODE + + def 
test_detect_c_file(self, processor): + result = processor._detect_file_type("program.c") + assert result == FILE_TYPE_CODE + + def test_detect_cpp_file(self, processor): + result = processor._detect_file_type("module.cpp") + assert result == FILE_TYPE_CODE + + def test_detect_markdown_file(self, processor): + result = processor._detect_file_type("README.md") + assert result == FILE_TYPE_DOCUMENTATION + + def test_detect_rst_file(self, processor): + result = processor._detect_file_type("docs.rst") + assert result == FILE_TYPE_DOCUMENTATION + + def test_detect_txt_file(self, processor): + result = processor._detect_file_type("notes.txt") + assert result == FILE_TYPE_DOCUMENTATION + + def test_detect_json_file(self, processor): + result = processor._detect_file_type("config.json") + assert result == FILE_TYPE_CODE + + def test_detect_yaml_file(self, processor): + result = processor._detect_file_type("settings.yaml") + assert result == FILE_TYPE_CODE + + def test_detect_unknown_extension(self, processor): + result = processor._detect_file_type("data.xyz") + assert result == FILE_TYPE_OTHER + + def test_detect_no_extension(self, processor): + result = processor._detect_file_type("Makefile") + assert result == FILE_TYPE_OTHER + + def test_detect_uppercase_extension(self, processor): + result = processor._detect_file_type("SCRIPT.PY") + assert result == FILE_TYPE_CODE + + def test_detect_mixed_case_extension(self, processor): + result = processor._detect_file_type("ReadMe.Md") + assert result == FILE_TYPE_DOCUMENTATION + + def test_detect_path_with_dots(self, processor): + result = processor._detect_file_type("src/utils/helper.py") + assert result == FILE_TYPE_CODE + + +class TestCollectTreeInfo: + """Test cases for _collect_tree_info() method.""" + + @pytest.mark.asyncio + async def test_collect_empty_directory(self, processor, fake_fs, ctx): + fake_fs.set_tree({"viking://temp/empty": []}) + processor._current_ctx = ctx + + with patch( + "openviking.storage.queuefs.semantic_processor.get_viking_fs", return_value=fake_fs + ): + result = await processor._collect_tree_info("viking://temp/empty") + + assert result == {"viking://temp/empty": ([], [])} + + @pytest.mark.asyncio + async def test_collect_directory_with_files(self, processor, fake_fs, ctx): + fake_fs.set_tree({ + "viking://temp/dir": [ + {"name": "file1.txt", "isDir": False}, + {"name": "file2.py", "isDir": False}, + ] + }) + processor._current_ctx = ctx + + with patch( + "openviking.storage.queuefs.semantic_processor.get_viking_fs", return_value=fake_fs + ): + result = await processor._collect_tree_info("viking://temp/dir") + + assert "viking://temp/dir" in result + sub_dirs, files = result["viking://temp/dir"] + assert sub_dirs == [] + assert len(files) == 2 + assert "viking://temp/dir/file1.txt" in files + assert "viking://temp/dir/file2.py" in files + + @pytest.mark.asyncio + async def test_collect_directory_with_subdirs(self, processor, fake_fs, ctx): + fake_fs.set_tree({ + "viking://temp/root": [ + {"name": "subdir1", "isDir": True}, + {"name": "subdir2", "isDir": True}, + ], + "viking://temp/root/subdir1": [ + {"name": "file1.txt", "isDir": False}, + ], + "viking://temp/root/subdir2": [ + {"name": "file2.txt", "isDir": False}, + ], + }) + processor._current_ctx = ctx + + with patch( + "openviking.storage.queuefs.semantic_processor.get_viking_fs", return_value=fake_fs + ): + result = await processor._collect_tree_info("viking://temp/root") + + assert "viking://temp/root" in result + assert "viking://temp/root/subdir1" in result + 
assert "viking://temp/root/subdir2" in result + + @pytest.mark.asyncio + async def test_collect_nested_directories(self, processor, fake_fs, ctx): + fake_fs.set_tree({ + "viking://temp/root": [ + {"name": "level1", "isDir": True}, + ], + "viking://temp/root/level1": [ + {"name": "level2", "isDir": True}, + ], + "viking://temp/root/level1/level2": [ + {"name": "deep_file.txt", "isDir": False}, + ], + }) + processor._current_ctx = ctx + + with patch( + "openviking.storage.queuefs.semantic_processor.get_viking_fs", return_value=fake_fs + ): + result = await processor._collect_tree_info("viking://temp/root") + + assert "viking://temp/root" in result + assert "viking://temp/root/level1" in result + assert "viking://temp/root/level1/level2" in result + + @pytest.mark.asyncio + async def test_collect_skips_hidden_files(self, processor, fake_fs, ctx): + fake_fs.set_tree({ + "viking://temp/dir": [ + {"name": ".hidden", "isDir": False}, + {"name": "visible.txt", "isDir": False}, + ] + }) + processor._current_ctx = ctx + + with patch( + "openviking.storage.queuefs.semantic_processor.get_viking_fs", return_value=fake_fs + ): + result = await processor._collect_tree_info("viking://temp/dir") + + _, files = result["viking://temp/dir"] + assert len(files) == 1 + assert "viking://temp/dir/visible.txt" in files + + @pytest.mark.asyncio + async def test_collect_skips_dot_and_dotdot(self, processor, fake_fs, ctx): + fake_fs.set_tree({ + "viking://temp/dir": [ + {"name": ".", "isDir": True}, + {"name": "..", "isDir": True}, + {"name": "file.txt", "isDir": False}, + ] + }) + processor._current_ctx = ctx + + with patch( + "openviking.storage.queuefs.semantic_processor.get_viking_fs", return_value=fake_fs + ): + result = await processor._collect_tree_info("viking://temp/dir") + + sub_dirs, files = result["viking://temp/dir"] + assert sub_dirs == [] + assert len(files) == 1 + + @pytest.mark.asyncio + async def test_collect_handles_ls_error(self, processor, fake_fs, ctx): + fake_fs.ls = AsyncMock(side_effect=Exception("LS error")) + processor._current_ctx = ctx + + with patch( + "openviking.storage.queuefs.semantic_processor.get_viking_fs", return_value=fake_fs + ): + result = await processor._collect_tree_info("viking://temp/dir") + + assert result == {} + + +class TestComputeDiff: + """Test cases for _compute_diff() method.""" + + @pytest.mark.asyncio + async def test_compute_diff_no_changes(self, processor, fake_fs, ctx): + root_tree = { + "viking://temp/root": ([], ["viking://temp/root/file.txt"]), + } + target_tree = { + "viking://target/root": ([], ["viking://target/root/file.txt"]), + } + processor._current_ctx = ctx + + with patch( + "openviking.storage.queuefs.semantic_processor.get_viking_fs", return_value=fake_fs + ): + diff = await processor._compute_diff( + root_tree, target_tree, "viking://temp/root", "viking://target/root" + ) + + assert diff.added_files == [] + assert diff.deleted_files == [] + assert diff.updated_files == [] + assert diff.added_dirs == [] + assert diff.deleted_dirs == [] + + @pytest.mark.asyncio + async def test_compute_diff_added_files(self, processor, fake_fs, ctx): + root_tree = { + "viking://temp/root": ([], ["viking://temp/root/new_file.txt"]), + } + target_tree = { + "viking://target/root": ([], []), + } + processor._current_ctx = ctx + + with patch( + "openviking.storage.queuefs.semantic_processor.get_viking_fs", return_value=fake_fs + ): + diff = await processor._compute_diff( + root_tree, target_tree, "viking://temp/root", "viking://target/root" + ) + + assert 
len(diff.added_files) == 1 + assert "viking://temp/root/new_file.txt" in diff.added_files + assert diff.deleted_files == [] + + @pytest.mark.asyncio + async def test_compute_diff_deleted_files(self, processor, fake_fs, ctx): + root_tree = { + "viking://temp/root": ([], []), + } + target_tree = { + "viking://target/root": ([], ["viking://target/root/old_file.txt"]), + } + processor._current_ctx = ctx + + with patch( + "openviking.storage.queuefs.semantic_processor.get_viking_fs", return_value=fake_fs + ): + diff = await processor._compute_diff( + root_tree, target_tree, "viking://temp/root", "viking://target/root" + ) + + assert diff.added_files == [] + assert len(diff.deleted_files) == 1 + assert "viking://target/root/old_file.txt" in diff.deleted_files + + @pytest.mark.asyncio + async def test_compute_diff_updated_files(self, processor, fake_fs, ctx): + root_tree = { + "viking://temp/root": ([], ["viking://temp/root/file.txt"]), + } + target_tree = { + "viking://target/root": ([], ["viking://target/root/file.txt"]), + } + fake_fs.set_file_contents({ + "viking://temp/root/file.txt": "new content", + "viking://target/root/file.txt": "old content", + }) + processor._current_ctx = ctx + + with patch( + "openviking.storage.queuefs.semantic_processor.get_viking_fs", return_value=fake_fs + ): + diff = await processor._compute_diff( + root_tree, target_tree, "viking://temp/root", "viking://target/root" + ) + + assert len(diff.updated_files) == 1 + assert "viking://temp/root/file.txt" in diff.updated_files + + @pytest.mark.asyncio + async def test_compute_diff_unchanged_files(self, processor, fake_fs, ctx): + root_tree = { + "viking://temp/root": ([], ["viking://temp/root/file.txt"]), + } + target_tree = { + "viking://target/root": ([], ["viking://target/root/file.txt"]), + } + fake_fs.set_file_contents({ + "viking://temp/root/file.txt": "same content", + "viking://target/root/file.txt": "same content", + }) + processor._current_ctx = ctx + + with patch( + "openviking.storage.queuefs.semantic_processor.get_viking_fs", return_value=fake_fs + ): + diff = await processor._compute_diff( + root_tree, target_tree, "viking://temp/root", "viking://target/root" + ) + + assert diff.updated_files == [] + + @pytest.mark.asyncio + async def test_compute_diff_added_dirs(self, processor, fake_fs, ctx): + root_tree = { + "viking://temp/root": (["viking://temp/root/new_dir"], []), + "viking://temp/root/new_dir": ([], []), + } + target_tree = { + "viking://target/root": ([], []), + } + processor._current_ctx = ctx + + with patch( + "openviking.storage.queuefs.semantic_processor.get_viking_fs", return_value=fake_fs + ): + diff = await processor._compute_diff( + root_tree, target_tree, "viking://temp/root", "viking://target/root" + ) + + assert len(diff.added_dirs) == 1 + assert "viking://temp/root/new_dir" in diff.added_dirs + + @pytest.mark.asyncio + async def test_compute_diff_deleted_dirs(self, processor, fake_fs, ctx): + root_tree = { + "viking://temp/root": ([], []), + } + target_tree = { + "viking://target/root": (["viking://target/root/old_dir"], []), + "viking://target/root/old_dir": ([], []), + } + processor._current_ctx = ctx + + with patch( + "openviking.storage.queuefs.semantic_processor.get_viking_fs", return_value=fake_fs + ): + diff = await processor._compute_diff( + root_tree, target_tree, "viking://temp/root", "viking://target/root" + ) + + assert len(diff.deleted_dirs) == 1 + assert "viking://target/root/old_dir" in diff.deleted_dirs + + @pytest.mark.asyncio + async def 
test_compute_diff_mixed_changes(self, processor, fake_fs, ctx): + root_tree = { + "viking://temp/root": ( + ["viking://temp/root/new_dir"], + ["viking://temp/root/new_file.txt", "viking://temp/root/updated.txt"], + ), + "viking://temp/root/new_dir": ([], []), + } + target_tree = { + "viking://target/root": ( + ["viking://target/root/old_dir"], + ["viking://target/root/updated.txt", "viking://target/root/deleted.txt"], + ), + "viking://target/root/old_dir": ([], []), + } + fake_fs.set_file_contents({ + "viking://temp/root/updated.txt": "new content", + "viking://target/root/updated.txt": "old content", + }) + processor._current_ctx = ctx + + with patch( + "openviking.storage.queuefs.semantic_processor.get_viking_fs", return_value=fake_fs + ): + diff = await processor._compute_diff( + root_tree, target_tree, "viking://temp/root", "viking://target/root" + ) + + assert len(diff.added_files) == 1 + assert len(diff.deleted_files) == 1 + assert len(diff.updated_files) == 1 + assert len(diff.added_dirs) == 1 + assert len(diff.deleted_dirs) == 1 + + +class TestCheckFileContentChanged: + """Test cases for _check_file_content_changed() method.""" + + @pytest.mark.asyncio + async def test_content_changed(self, processor, fake_fs, ctx): + fake_fs.set_file_contents({ + "viking://temp/file.txt": "new content", + "viking://target/file.txt": "old content", + }) + processor._current_ctx = ctx + + with patch( + "openviking.storage.queuefs.semantic_processor.get_viking_fs", return_value=fake_fs + ): + result = await processor._check_file_content_changed( + "viking://temp/file.txt", "viking://target/file.txt" + ) + + assert result is True + + @pytest.mark.asyncio + async def test_content_unchanged(self, processor, fake_fs, ctx): + fake_fs.set_file_contents({ + "viking://temp/file.txt": "same content", + "viking://target/file.txt": "same content", + }) + processor._current_ctx = ctx + + with patch( + "openviking.storage.queuefs.semantic_processor.get_viking_fs", return_value=fake_fs + ): + result = await processor._check_file_content_changed( + "viking://temp/file.txt", "viking://target/file.txt" + ) + + assert result is False + + @pytest.mark.asyncio + async def test_content_changed_empty_files(self, processor, fake_fs, ctx): + fake_fs.set_file_contents({ + "viking://temp/file.txt": "", + "viking://target/file.txt": "", + }) + processor._current_ctx = ctx + + with patch( + "openviking.storage.queuefs.semantic_processor.get_viking_fs", return_value=fake_fs + ): + result = await processor._check_file_content_changed( + "viking://temp/file.txt", "viking://target/file.txt" + ) + + assert result is False + + @pytest.mark.asyncio + async def test_content_changed_one_empty(self, processor, fake_fs, ctx): + fake_fs.set_file_contents({ + "viking://temp/file.txt": "content", + "viking://target/file.txt": "", + }) + processor._current_ctx = ctx + + with patch( + "openviking.storage.queuefs.semantic_processor.get_viking_fs", return_value=fake_fs + ): + result = await processor._check_file_content_changed( + "viking://temp/file.txt", "viking://target/file.txt" + ) + + assert result is True + + @pytest.mark.asyncio + async def test_content_changed_on_exception(self, processor, fake_fs, ctx): + fake_fs.read_file = AsyncMock(side_effect=Exception("Read error")) + processor._current_ctx = ctx + + with patch( + "openviking.storage.queuefs.semantic_processor.get_viking_fs", return_value=fake_fs + ): + result = await processor._check_file_content_changed( + "viking://temp/file.txt", "viking://target/file.txt" + ) + + assert 
result is True + + +class TestExecuteSyncOperations: + """Test cases for _execute_sync_operations() method.""" + + @pytest.mark.asyncio + async def test_execute_delete_files(self, processor, fake_fs, ctx): + diff = DiffResult( + added_files=[], + deleted_files=["viking://target/deleted.txt"], + updated_files=[], + added_dirs=[], + deleted_dirs=[], + ) + processor._current_ctx = ctx + + with patch( + "openviking.storage.queuefs.semantic_processor.get_viking_fs", return_value=fake_fs + ): + await processor._execute_sync_operations( + diff, "viking://temp/root", "viking://target/root" + ) + + assert "viking://target/deleted.txt" in fake_fs.deleted_files + + @pytest.mark.asyncio + async def test_execute_move_added_files(self, processor, fake_fs, ctx): + diff = DiffResult( + added_files=["viking://temp/root/new.txt"], + deleted_files=[], + updated_files=[], + added_dirs=[], + deleted_dirs=[], + ) + processor._current_ctx = ctx + + mock_viking_uri = MagicMock() + mock_viking_uri.parent = None + + with patch( + "openviking.storage.queuefs.semantic_processor.get_viking_fs", return_value=fake_fs + ): + with patch( + "openviking.storage.queuefs.semantic_processor.VikingURI", return_value=mock_viking_uri + ): + await processor._execute_sync_operations( + diff, "viking://temp/root", "viking://target/root" + ) + + assert ("viking://temp/root/new.txt", "viking://target/root/new.txt") in fake_fs.moved_files + + @pytest.mark.asyncio + async def test_execute_move_updated_files(self, processor, fake_fs, ctx): + diff = DiffResult( + added_files=[], + deleted_files=[], + updated_files=["viking://temp/root/updated.txt"], + added_dirs=[], + deleted_dirs=[], + ) + processor._current_ctx = ctx + + mock_viking_uri = MagicMock() + mock_viking_uri.parent = None + + with patch( + "openviking.storage.queuefs.semantic_processor.get_viking_fs", return_value=fake_fs + ): + with patch( + "openviking.storage.queuefs.semantic_processor.VikingURI", return_value=mock_viking_uri + ): + await processor._execute_sync_operations( + diff, "viking://temp/root", "viking://target/root" + ) + + assert "viking://target/root/updated.txt" in fake_fs.deleted_files + assert ("viking://temp/root/updated.txt", "viking://target/root/updated.txt") in fake_fs.moved_files + + @pytest.mark.asyncio + async def test_execute_delete_dirs(self, processor, fake_fs, ctx): + diff = DiffResult( + added_files=[], + deleted_files=[], + updated_files=[], + added_dirs=[], + deleted_dirs=["viking://target/root/old_dir"], + ) + processor._current_ctx = ctx + + with patch( + "openviking.storage.queuefs.semantic_processor.get_viking_fs", return_value=fake_fs + ): + await processor._execute_sync_operations( + diff, "viking://temp/root", "viking://target/root" + ) + + assert "viking://target/root/old_dir" in fake_fs.deleted_dirs + + @pytest.mark.asyncio + async def test_execute_delete_dirs_deepest_first(self, processor, fake_fs, ctx): + diff = DiffResult( + added_files=[], + deleted_files=[], + updated_files=[], + added_dirs=[], + deleted_dirs=[ + "viking://target/root/level1", + "viking://target/root/level1/level2", + "viking://target/root/level1/level2/level3", + ], + ) + processor._current_ctx = ctx + + with patch( + "openviking.storage.queuefs.semantic_processor.get_viking_fs", return_value=fake_fs + ): + await processor._execute_sync_operations( + diff, "viking://temp/root", "viking://target/root" + ) + + assert len(fake_fs.deleted_dirs) == 3 + deepest_first = sorted(fake_fs.deleted_dirs, key=lambda x: x.count("/"), reverse=True) + assert fake_fs.deleted_dirs 
== deepest_first + + @pytest.mark.asyncio + async def test_execute_creates_parent_dirs(self, processor, fake_fs, ctx): + diff = DiffResult( + added_files=["viking://temp/root/subdir/new.txt"], + deleted_files=[], + updated_files=[], + added_dirs=[], + deleted_dirs=[], + ) + processor._current_ctx = ctx + + mock_parent = MagicMock() + mock_parent.uri = "viking://target/root/subdir" + mock_viking_uri = MagicMock() + mock_viking_uri.parent = mock_parent + + with patch( + "openviking.storage.queuefs.semantic_processor.get_viking_fs", return_value=fake_fs + ): + with patch( + "openviking.storage.queuefs.semantic_processor.VikingURI", return_value=mock_viking_uri + ): + await processor._execute_sync_operations( + diff, "viking://temp/root", "viking://target/root" + ) + + assert "viking://target/root/subdir" in fake_fs.created_dirs + + +class TestCreateSyncDiffCallback: + """Test cases for _create_sync_diff_callback() method.""" + + @pytest.mark.asyncio + async def test_callback_returns_callable(self, processor): + callback = processor._create_sync_diff_callback( + "viking://temp/root", "viking://target/root" + ) + assert callable(callback) + + @pytest.mark.asyncio + async def test_callback_is_async(self, processor): + callback = processor._create_sync_diff_callback( + "viking://temp/root", "viking://target/root" + ) + import asyncio + assert asyncio.iscoroutinefunction(callback) + + @pytest.mark.asyncio + async def test_callback_collects_tree_info(self, processor, fake_fs, ctx): + fake_fs.set_tree({ + "viking://temp/root": [ + {"name": "file.txt", "isDir": False}, + ], + "viking://target/root": [ + {"name": "file.txt", "isDir": False}, + ], + }) + fake_fs.set_file_contents({ + "viking://temp/root/file.txt": "content", + "viking://target/root/file.txt": "content", + }) + processor._current_ctx = ctx + + with patch( + "openviking.storage.queuefs.semantic_processor.get_viking_fs", return_value=fake_fs + ): + callback = processor._create_sync_diff_callback( + "viking://temp/root", "viking://target/root" + ) + await callback() + + @pytest.mark.asyncio + async def test_callback_handles_exception(self, processor, fake_fs, ctx): + fake_fs.ls = AsyncMock(side_effect=Exception("Test error")) + processor._current_ctx = ctx + + with patch( + "openviking.storage.queuefs.semantic_processor.get_viking_fs", return_value=fake_fs + ): + callback = processor._create_sync_diff_callback( + "viking://temp/root", "viking://target/root" + ) + await callback() + + @pytest.mark.asyncio + async def test_callback_deletes_root_after_sync(self, processor, fake_fs, ctx): + fake_fs.set_tree({ + "viking://temp/root": [], + "viking://target/root": [], + }) + processor._current_ctx = ctx + + with patch( + "openviking.storage.queuefs.semantic_processor.get_viking_fs", return_value=fake_fs + ): + callback = processor._create_sync_diff_callback( + "viking://temp/root", "viking://target/root" + ) + await callback() + + assert "viking://temp/root" in fake_fs.deleted_dirs + + +class TestDiffResult: + """Test cases for DiffResult dataclass.""" + + def test_diff_result_default_values(self): + diff = DiffResult() + assert diff.added_files == [] + assert diff.deleted_files == [] + assert diff.updated_files == [] + assert diff.added_dirs == [] + assert diff.deleted_dirs == [] + + def test_diff_result_with_values(self): + diff = DiffResult( + added_files=["a.txt"], + deleted_files=["b.txt"], + updated_files=["c.txt"], + added_dirs=["dir1"], + deleted_dirs=["dir2"], + ) + assert diff.added_files == ["a.txt"] + assert diff.deleted_files == 
["b.txt"] + assert diff.updated_files == ["c.txt"] + assert diff.added_dirs == ["dir1"] + assert diff.deleted_dirs == ["dir2"] + + def test_diff_result_modifiable(self): + diff = DiffResult() + diff.added_files.append("new.txt") + assert "new.txt" in diff.added_files diff --git a/tests/unit/storage/queuefs/test_semantic_msg.py b/tests/unit/storage/queuefs/test_semantic_msg.py new file mode 100644 index 00000000..84dee8eb --- /dev/null +++ b/tests/unit/storage/queuefs/test_semantic_msg.py @@ -0,0 +1,406 @@ +# Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd. +# SPDX-License-Identifier: Apache-2.0 +"""Unit tests for SemanticMsg dataclass, focusing on new fields target_uri and skip_vectorization.""" + +import json + +import pytest + +from openviking.storage.queuefs.semantic_msg import SemanticMsg + + +class TestTargetUriField: + """Tests for target_uri field serialization and deserialization.""" + + def test_target_uri_default_value(self): + msg = SemanticMsg(uri="viking://resource/test", context_type="resource") + assert msg.target_uri == "" + + def test_target_uri_set_in_constructor(self): + msg = SemanticMsg( + uri="viking://resource/temp", + context_type="resource", + target_uri="viking://resource/target", + ) + assert msg.target_uri == "viking://resource/target" + + def test_target_uri_serialization_to_dict(self): + msg = SemanticMsg( + uri="viking://resource/temp", + context_type="resource", + target_uri="viking://resource/target", + ) + data = msg.to_dict() + assert "target_uri" in data + assert data["target_uri"] == "viking://resource/target" + + def test_target_uri_deserialization_from_dict(self): + data = { + "uri": "viking://resource/temp", + "context_type": "resource", + "target_uri": "viking://resource/target", + } + msg = SemanticMsg.from_dict(data) + assert msg.target_uri == "viking://resource/target" + + def test_target_uri_empty_string_serialization(self): + msg = SemanticMsg( + uri="viking://resource/test", + context_type="resource", + target_uri="", + ) + data = msg.to_dict() + assert data["target_uri"] == "" + + def test_target_uri_with_memory_context(self): + msg = SemanticMsg( + uri="viking://memory/temp/session", + context_type="memory", + target_uri="viking://session/abc123", + ) + data = msg.to_dict() + msg_restored = SemanticMsg.from_dict(data) + assert msg_restored.target_uri == "viking://session/abc123" + + +class TestSkipVectorizationField: + """Tests for skip_vectorization field serialization and deserialization.""" + + def test_skip_vectorization_default_value(self): + msg = SemanticMsg(uri="viking://resource/test", context_type="resource") + assert msg.skip_vectorization is False + + def test_skip_vectorization_set_true_in_constructor(self): + msg = SemanticMsg( + uri="viking://resource/test", + context_type="resource", + skip_vectorization=True, + ) + assert msg.skip_vectorization is True + + def test_skip_vectorization_serialization_to_dict(self): + msg = SemanticMsg( + uri="viking://resource/test", + context_type="resource", + skip_vectorization=True, + ) + data = msg.to_dict() + assert "skip_vectorization" in data + assert data["skip_vectorization"] is True + + def test_skip_vectorization_deserialization_from_dict(self): + data = { + "uri": "viking://resource/test", + "context_type": "resource", + "skip_vectorization": True, + } + msg = SemanticMsg.from_dict(data) + assert msg.skip_vectorization is True + + def test_skip_vectorization_false_serialization(self): + msg = SemanticMsg( + uri="viking://resource/test", + context_type="resource", + 
skip_vectorization=False, + ) + data = msg.to_dict() + assert data["skip_vectorization"] is False + + def test_skip_vectorization_round_trip(self): + original = SemanticMsg( + uri="viking://resource/test", + context_type="resource", + skip_vectorization=True, + ) + restored = SemanticMsg.from_dict(original.to_dict()) + assert restored.skip_vectorization == original.skip_vectorization + + +class TestFromDictBackwardCompatibility: + """Tests for from_dict() backward compatibility with old format missing new fields.""" + + def test_missing_target_uri_uses_default(self): + data = { + "uri": "viking://resource/test", + "context_type": "resource", + } + msg = SemanticMsg.from_dict(data) + assert msg.target_uri == "" + + def test_missing_skip_vectorization_uses_default(self): + data = { + "uri": "viking://resource/test", + "context_type": "resource", + } + msg = SemanticMsg.from_dict(data) + assert msg.skip_vectorization is False + + def test_missing_both_new_fields_uses_defaults(self): + data = { + "uri": "viking://resource/test", + "context_type": "resource", + "recursive": True, + "account_id": "test_account", + } + msg = SemanticMsg.from_dict(data) + assert msg.target_uri == "" + assert msg.skip_vectorization is False + + def test_old_format_with_all_legacy_fields(self): + data = { + "id": "legacy-id-123", + "uri": "viking://resource/test", + "context_type": "resource", + "status": "pending", + "recursive": False, + "account_id": "account1", + "user_id": "user1", + "agent_id": "agent1", + "role": "admin", + } + msg = SemanticMsg.from_dict(data) + assert msg.id == "legacy-id-123" + assert msg.uri == "viking://resource/test" + assert msg.context_type == "resource" + assert msg.status == "pending" + assert msg.recursive is False + assert msg.target_uri == "" + assert msg.skip_vectorization is False + + def test_partial_new_fields_only_target_uri(self): + data = { + "uri": "viking://resource/test", + "context_type": "resource", + "target_uri": "viking://resource/target", + } + msg = SemanticMsg.from_dict(data) + assert msg.target_uri == "viking://resource/target" + assert msg.skip_vectorization is False + + def test_partial_new_fields_only_skip_vectorization(self): + data = { + "uri": "viking://resource/test", + "context_type": "resource", + "skip_vectorization": True, + } + msg = SemanticMsg.from_dict(data) + assert msg.target_uri == "" + assert msg.skip_vectorization is True + + +class TestToJsonFromJson: + """Tests for to_json() and from_json() methods.""" + + def test_to_json_returns_valid_json_string(self): + msg = SemanticMsg( + uri="viking://resource/test", + context_type="resource", + ) + json_str = msg.to_json() + assert isinstance(json_str, str) + parsed = json.loads(json_str) + assert isinstance(parsed, dict) + + def test_from_json_creates_valid_object(self): + json_str = '{"uri": "viking://resource/test", "context_type": "resource"}' + msg = SemanticMsg.from_json(json_str) + assert msg.uri == "viking://resource/test" + assert msg.context_type == "resource" + + def test_to_json_and_from_json_round_trip(self): + original = SemanticMsg( + uri="viking://resource/temp", + context_type="memory", + target_uri="viking://session/abc", + skip_vectorization=True, + recursive=False, + account_id="test_account", + user_id="test_user", + agent_id="test_agent", + role="admin", + ) + json_str = original.to_json() + restored = SemanticMsg.from_json(json_str) + + assert restored.uri == original.uri + assert restored.context_type == original.context_type + assert restored.target_uri == 
original.target_uri + assert restored.skip_vectorization == original.skip_vectorization + assert restored.recursive == original.recursive + assert restored.account_id == original.account_id + assert restored.user_id == original.user_id + assert restored.agent_id == original.agent_id + assert restored.role == original.role + + def test_from_json_with_new_fields(self): + json_str = json.dumps({ + "uri": "viking://resource/test", + "context_type": "resource", + "target_uri": "viking://resource/target", + "skip_vectorization": True, + }) + msg = SemanticMsg.from_json(json_str) + assert msg.target_uri == "viking://resource/target" + assert msg.skip_vectorization is True + + def test_from_json_without_new_fields(self): + json_str = json.dumps({ + "uri": "viking://resource/test", + "context_type": "resource", + }) + msg = SemanticMsg.from_json(json_str) + assert msg.target_uri == "" + assert msg.skip_vectorization is False + + def test_from_json_invalid_json_raises_value_error(self): + with pytest.raises(ValueError, match="Invalid JSON string"): + SemanticMsg.from_json("not a valid json") + + def test_from_json_missing_required_fields_raises_value_error(self): + json_str = '{"uri": "viking://resource/test"}' + with pytest.raises(ValueError, match="Missing required fields"): + SemanticMsg.from_json(json_str) + + +class TestRequiredFieldsValidation: + """Tests for required field validation (uri and context_type).""" + + def test_missing_uri_raises_value_error(self): + data = {"context_type": "resource"} + with pytest.raises(ValueError, match="Missing required fields"): + SemanticMsg.from_dict(data) + + def test_missing_context_type_raises_value_error(self): + data = {"uri": "viking://resource/test"} + with pytest.raises(ValueError, match="Missing required fields"): + SemanticMsg.from_dict(data) + + def test_missing_both_required_fields_raises_value_error(self): + data = {"target_uri": "viking://resource/target"} + with pytest.raises(ValueError, match="Missing required fields"): + SemanticMsg.from_dict(data) + + def test_empty_uri_raises_value_error(self): + data = {"uri": "", "context_type": "resource"} + with pytest.raises(ValueError, match="Missing required fields"): + SemanticMsg.from_dict(data) + + def test_empty_context_type_raises_value_error(self): + data = {"uri": "viking://resource/test", "context_type": ""} + with pytest.raises(ValueError, match="Missing required fields"): + SemanticMsg.from_dict(data) + + def test_none_uri_raises_value_error(self): + data = {"uri": None, "context_type": "resource"} + with pytest.raises(ValueError, match="Missing required fields"): + SemanticMsg.from_dict(data) + + def test_none_context_type_raises_value_error(self): + data = {"uri": "viking://resource/test", "context_type": None} + with pytest.raises(ValueError, match="Missing required fields"): + SemanticMsg.from_dict(data) + + def test_empty_dict_raises_value_error(self): + with pytest.raises(ValueError, match="Data dictionary is empty"): + SemanticMsg.from_dict({}) + + def test_valid_minimal_data_succeeds(self): + data = {"uri": "viking://resource/test", "context_type": "resource"} + msg = SemanticMsg.from_dict(data) + assert msg.uri == "viking://resource/test" + assert msg.context_type == "resource" + + def test_error_message_lists_all_missing_fields(self): + data = {"skip_vectorization": True} + with pytest.raises(ValueError) as exc_info: + SemanticMsg.from_dict(data) + error_msg = str(exc_info.value) + assert "uri" in error_msg + assert "context_type" in error_msg + + +class TestEdgeCases: + 
"""Tests for edge cases and boundary conditions.""" + + def test_target_uri_with_special_characters(self): + special_uri = "viking://resource/test%20space?query=value&other=123" + msg = SemanticMsg( + uri="viking://resource/temp", + context_type="resource", + target_uri=special_uri, + ) + restored = SemanticMsg.from_dict(msg.to_dict()) + assert restored.target_uri == special_uri + + def test_target_uri_with_unicode(self): + unicode_uri = "viking://resource/测试/目录" + msg = SemanticMsg( + uri="viking://resource/temp", + context_type="resource", + target_uri=unicode_uri, + ) + restored = SemanticMsg.from_dict(msg.to_dict()) + assert restored.target_uri == unicode_uri + + def test_target_uri_with_long_path(self): + long_path = "viking://resource/" + "/".join(["dir"] * 100) + msg = SemanticMsg( + uri="viking://resource/temp", + context_type="resource", + target_uri=long_path, + ) + restored = SemanticMsg.from_dict(msg.to_dict()) + assert restored.target_uri == long_path + + def test_preserves_existing_id_in_from_dict(self): + original = SemanticMsg( + uri="viking://resource/test", + context_type="resource", + ) + original_id = original.id + data = original.to_dict() + restored = SemanticMsg.from_dict(data) + assert restored.id == original_id + + def test_from_dict_overwrites_id_if_provided(self): + data = { + "id": "custom-id-123", + "uri": "viking://resource/test", + "context_type": "resource", + } + msg = SemanticMsg.from_dict(data) + assert msg.id == "custom-id-123" + + def test_from_dict_preserves_status_and_timestamp(self): + data = { + "uri": "viking://resource/test", + "context_type": "resource", + "status": "completed", + "timestamp": 1700000000, + } + msg = SemanticMsg.from_dict(data) + assert msg.status == "completed" + assert msg.timestamp == 1700000000 + + def test_all_context_types(self): + context_types = ["resource", "memory", "skill", "session"] + for ctx_type in context_types: + msg = SemanticMsg( + uri=f"viking://{ctx_type}/test", + context_type=ctx_type, + ) + assert msg.context_type == ctx_type + restored = SemanticMsg.from_dict(msg.to_dict()) + assert restored.context_type == ctx_type + + def test_extra_fields_in_dict_are_ignored(self): + data = { + "uri": "viking://resource/test", + "context_type": "resource", + "extra_field": "should_be_ignored", + "another_extra": 12345, + } + msg = SemanticMsg.from_dict(data) + assert msg.uri == "viking://resource/test" + assert not hasattr(msg, "extra_field") + assert not hasattr(msg, "another_extra") diff --git a/tests/unit/storage/test_viking_fs_new.py b/tests/unit/storage/test_viking_fs_new.py new file mode 100644 index 00000000..b754dd64 --- /dev/null +++ b/tests/unit/storage/test_viking_fs_new.py @@ -0,0 +1,195 @@ +# Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd. +# SPDX-License-Identifier: Apache-2.0 +"""Unit tests for VikingFS new methods. 
+ +Tests for: +- exists(): Check if URI exists +- copy_directory(): Recursively copy directory +- delete_temp(): Delete temporary directory +""" + +import contextvars +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from openviking.storage.viking_fs import VikingFS + + +def _create_viking_fs_mock(): + """Create a VikingFS instance with mocked AGFS backend.""" + fs = VikingFS.__new__(VikingFS) + fs.agfs = MagicMock() + fs.query_embedder = None + fs.vector_store = None + fs._uri_prefix = "viking://" + fs._bound_ctx = contextvars.ContextVar("vikingfs_bound_ctx", default=None) + return fs + + +@pytest.mark.asyncio +class TestVikingFSExists: + """Test cases for VikingFS.exists() method.""" + + async def test_exists_returns_true_when_uri_exists(self): + """exists() should return True when URI exists.""" + fs = _create_viking_fs_mock() + fs.stat = AsyncMock(return_value={"name": "test_file.txt", "isDir": False}) + + result = await fs.exists("viking://temp/test_file.txt") + + assert result is True + fs.stat.assert_awaited_once_with("viking://temp/test_file.txt", ctx=None) + + async def test_exists_returns_false_when_uri_not_found(self): + """exists() should return False when URI does not exist.""" + fs = _create_viking_fs_mock() + fs.stat = AsyncMock(side_effect=FileNotFoundError("Not found")) + + result = await fs.exists("viking://temp/nonexistent.txt") + + assert result is False + fs.stat.assert_awaited_once_with("viking://temp/nonexistent.txt", ctx=None) + + async def test_exists_returns_false_on_any_exception(self): + """exists() should return False on any exception, not just FileNotFoundError.""" + fs = _create_viking_fs_mock() + fs.stat = AsyncMock(side_effect=PermissionError("Access denied")) + + result = await fs.exists("viking://temp/protected.txt") + + assert result is False + + +@pytest.mark.asyncio +class TestVikingFSCopyDirectory: + """Test cases for VikingFS.copy_directory() method.""" + + async def test_copy_directory_recursive(self): + """copy_directory() should recursively copy directory contents.""" + fs = _create_viking_fs_mock() + fs._ensure_access = MagicMock() + fs._uri_to_path = MagicMock(side_effect=lambda uri, ctx=None: uri.replace("viking://", "/local/")) + fs._ensure_parent_dirs = AsyncMock() + + mock_agfs_cp = MagicMock() + + with patch("openviking.storage.viking_fs.agfs_cp", mock_agfs_cp): + await fs.copy_directory( + "viking://temp/source_dir/", + "viking://temp/dest_dir/", + ) + + fs._ensure_access.assert_any_call("viking://temp/source_dir/", None) + fs._ensure_access.assert_any_call("viking://temp/dest_dir/", None) + fs._ensure_parent_dirs.assert_awaited_once_with("/local/temp/dest_dir/") + mock_agfs_cp.assert_called_once_with( + fs.agfs, + "/local/temp/source_dir/", + "/local/temp/dest_dir/", + recursive=True, + ) + + async def test_copy_directory_with_context(self): + """copy_directory() should pass context to helper methods.""" + from openviking.server.identity import RequestContext, Role + from openviking_cli.session.user_id import UserIdentifier + + fs = _create_viking_fs_mock() + fs._ensure_access = MagicMock() + fs._uri_to_path = MagicMock(side_effect=lambda uri, ctx=None: uri.replace("viking://", "/local/")) + fs._ensure_parent_dirs = AsyncMock() + + ctx = RequestContext( + user=UserIdentifier("acc1", "user1", "agent1"), + role=Role.USER, + ) + + mock_agfs_cp = MagicMock() + + with patch("openviking.storage.viking_fs.agfs_cp", mock_agfs_cp): + await fs.copy_directory( + "viking://temp/source/", + "viking://temp/dest/", + ctx=ctx, + 
) + + fs._ensure_access.assert_any_call("viking://temp/source/", ctx) + fs._ensure_access.assert_any_call("viking://temp/dest/", ctx) + + +@pytest.mark.asyncio +class TestVikingFSDeleteTemp: + """Test cases for VikingFS.delete_temp() method.""" + + async def test_delete_temp_removes_directory_and_contents(self): + """delete_temp() should remove directory and all its contents.""" + fs = _create_viking_fs_mock() + fs._uri_to_path = MagicMock(return_value="/local/temp/test_temp") + + fs._ls_entries = MagicMock(return_value=[ + {"name": "file1.txt", "isDir": False}, + {"name": "subdir", "isDir": True}, + ]) + + fs.agfs.rm = MagicMock() + + call_count = [0] + + async def mock_delete_temp(uri, ctx=None): + call_count[0] += 1 + if call_count[0] == 1: + fs._ls_entries.return_value = [ + {"name": "nested_file.txt", "isDir": False}, + ] + await fs.delete_temp(uri, ctx=ctx) + else: + fs._ls_entries.return_value = [] + + original_delete_temp = fs.delete_temp + fs.delete_temp = mock_delete_temp + + await original_delete_temp("viking://temp/test_temp/") + + assert fs.agfs.rm.call_count >= 1 + + async def test_delete_temp_handles_empty_directory(self): + """delete_temp() should handle empty directory gracefully.""" + fs = _create_viking_fs_mock() + fs._uri_to_path = MagicMock(return_value="/local/temp/empty_temp") + fs._ls_entries = MagicMock(return_value=[]) + fs.agfs.rm = MagicMock() + + await fs.delete_temp("viking://temp/empty_temp/") + + fs.agfs.rm.assert_called_once_with("/local/temp/empty_temp") + + async def test_delete_temp_skips_dot_entries(self): + """delete_temp() should skip . and .. entries.""" + fs = _create_viking_fs_mock() + fs._uri_to_path = MagicMock(return_value="/local/temp/test_temp") + fs._ls_entries = MagicMock(return_value=[ + {"name": ".", "isDir": True}, + {"name": "..", "isDir": True}, + {"name": "actual_file.txt", "isDir": False}, + ]) + fs.agfs.rm = MagicMock() + + await fs.delete_temp("viking://temp/test_temp/") + + rm_calls = [call[0][0] for call in fs.agfs.rm.call_args_list] + assert "/local/temp/test_temp/actual_file.txt" in rm_calls + assert "/local/temp/test_temp/." not in rm_calls + assert "/local/temp/test_temp/.." not in rm_calls + + async def test_delete_temp_logs_warning_on_error(self): + """delete_temp() should log warning but not raise on error.""" + fs = _create_viking_fs_mock() + fs._uri_to_path = MagicMock(return_value="/local/temp/error_temp") + fs._ls_entries = MagicMock(side_effect=Exception("AGFS error")) + + with patch("openviking.storage.viking_fs.logger") as mock_logger: + await fs.delete_temp("viking://temp/error_temp/") + + mock_logger.warning.assert_called_once() + assert "Failed to delete temp" in mock_logger.warning.call_args[0][0] From dc317f9b3e705d2b15ba4b663c564821a025b1f1 Mon Sep 17 00:00:00 2001 From: yepper Date: Thu, 12 Mar 2026 18:19:00 +0800 Subject: [PATCH 4/5] style(tests): clean up unused imports in test files Remove unused imports across multiple test files to improve code cleanliness and reduce potential confusion. This includes removing unused mock objects, context classes, and constants that are not referenced in the tests. 
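Reviewer note: the EmbeddingTaskTracker tests earlier in this series pin down a fairly complete contract (per-message remaining/total counters, deep-copied status snapshots, a completion callback that may be sync or async and whose errors are swallowed, removal that never fires the callback). The sketch below is a minimal reconstruction from those assertions only; EmbeddingTaskTrackerSketch and all of its internals are invented names, the behavior of decrement() on an unknown id (returning None) is an assumption, and the shipped class in openviking/storage/queuefs/embedding_tracker.py also logs completion and may differ in detail.

import asyncio
import copy
import inspect
from typing import Any, Callable, Dict, Optional


class EmbeddingTaskTrackerSketch:
    """Counts outstanding embedding tasks per SemanticMsg id."""

    def __init__(self) -> None:
        self._tasks: Dict[str, Dict[str, Any]] = {}
        self._lock = asyncio.Lock()  # serializes concurrent mutations

    async def register(
        self,
        msg_id: str,
        count: int,
        metadata: Optional[Dict[str, Any]] = None,
        on_complete: Optional[Callable[[], Any]] = None,
    ) -> None:
        async with self._lock:
            self._tasks[msg_id] = {
                "remaining": count,
                "total": count,
                "metadata": metadata or {},
                "on_complete": on_complete,
            }

    async def increment(self, msg_id: str) -> None:
        async with self._lock:
            self._tasks[msg_id]["remaining"] += 1
            self._tasks[msg_id]["total"] += 1

    async def decrement(self, msg_id: str) -> Optional[int]:
        async with self._lock:
            task = self._tasks.get(msg_id)
            if task is None:
                return None
            task["remaining"] -= 1
            if task["remaining"] > 0:
                return task["remaining"]
            on_complete = task.get("on_complete")
            del self._tasks[msg_id]  # reaching zero untracks the task
        if on_complete is not None:
            try:
                result = on_complete()
                if inspect.isawaitable(result):  # sync or async callback
                    await result
            except Exception:
                pass  # callback failures must not break the caller
        return 0

    async def get_status(self, msg_id: str) -> Optional[Dict[str, Any]]:
        async with self._lock:
            task = self._tasks.get(msg_id)
            if task is None:
                return None
            # Copy, minus the callback, so callers cannot mutate internal state.
            return {k: copy.deepcopy(v) for k, v in task.items() if k != "on_complete"}

    async def remove(self, msg_id: str) -> bool:
        async with self._lock:  # intentionally never fires on_complete
            return self._tasks.pop(msg_id, None) is not None

    async def get_all_tracked(self) -> Dict[str, Dict[str, Any]]:
        async with self._lock:
            return {
                mid: {k: copy.deepcopy(v) for k, v in t.items() if k != "on_complete"}
                for mid, t in self._tasks.items()
            }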
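The TestComputeDiff cases likewise read as a plain set comparison between the temp tree and the target tree, with URIs rebased from one root onto the other by prefix replacement. A hedged sketch under those assumptions; compute_diff, DiffResultSketch, _rebase, and the caller-supplied content_changed coroutine are all stand-ins for the real SemanticProcessor helpers, not their actual signatures.

from dataclasses import dataclass, field
from typing import Awaitable, Callable, Dict, List, Tuple


@dataclass
class DiffResultSketch:
    added_files: List[str] = field(default_factory=list)
    deleted_files: List[str] = field(default_factory=list)
    updated_files: List[str] = field(default_factory=list)
    added_dirs: List[str] = field(default_factory=list)
    deleted_dirs: List[str] = field(default_factory=list)


Tree = Dict[str, Tuple[List[str], List[str]]]  # dir uri -> (sub_dirs, files)


def _rebase(uri: str, src_root: str, dst_root: str) -> str:
    # Map a URI under one root onto the other root by prefix replacement.
    return dst_root + uri[len(src_root):]


async def compute_diff(
    root_tree: Tree,
    target_tree: Tree,
    root_uri: str,
    target_uri: str,
    content_changed: Callable[[str, str], Awaitable[bool]],
) -> DiffResultSketch:
    diff = DiffResultSketch()
    root_files = {f for _, files in root_tree.values() for f in files}
    target_files = {f for _, files in target_tree.values() for f in files}

    for f in sorted(root_files):
        twin = _rebase(f, root_uri, target_uri)
        if twin not in target_files:
            diff.added_files.append(f)       # exists only under the temp root
        elif await content_changed(f, twin):
            diff.updated_files.append(f)     # in both trees, content differs
    diff.deleted_files = [
        f for f in sorted(target_files)
        if _rebase(f, target_uri, root_uri) not in root_files
    ]

    root_dirs = set(root_tree) - {root_uri}
    target_dirs = set(target_tree) - {target_uri}
    diff.added_dirs = [
        d for d in sorted(root_dirs)
        if _rebase(d, root_uri, target_uri) not in target_dirs
    ]
    diff.deleted_dirs = [
        d for d in sorted(target_dirs)
        if _rebase(d, target_uri, root_uri) not in root_dirs
    ]
    return diff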
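Finally, TestVikingFSExists implies an exists() of roughly the following shape. This is a guess at the implementation rather than a quote of it, but the tests do require that any stat failure, not only FileNotFoundError, reads as absence:

async def exists(self, uri: str, ctx=None) -> bool:
    # Best-effort existence check: any stat error counts as "does not exist".
    try:
        await self.stat(uri, ctx=ctx)
        return True
    except Exception:
        return False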
--- tests/unit/session/test_deduplicator_uri.py | 1 - tests/unit/session/test_memory_extractor_tools.py | 5 +---- tests/unit/storage/queuefs/test_embedding_msg.py | 1 + tests/unit/storage/queuefs/test_processor_incremental.py | 1 - 4 files changed, 2 insertions(+), 6 deletions(-) diff --git a/tests/unit/session/test_deduplicator_uri.py b/tests/unit/session/test_deduplicator_uri.py index 900375cd..ba5c5c8d 100644 --- a/tests/unit/session/test_deduplicator_uri.py +++ b/tests/unit/session/test_deduplicator_uri.py @@ -5,7 +5,6 @@ import pytest -from openviking.core.context import Context from openviking.session.memory_deduplicator import MemoryDeduplicator from openviking.session.memory_extractor import CandidateMemory, MemoryCategory from openviking_cli.session.user_id import UserIdentifier diff --git a/tests/unit/session/test_memory_extractor_tools.py b/tests/unit/session/test_memory_extractor_tools.py index cd8b6189..4d77aaab 100644 --- a/tests/unit/session/test_memory_extractor_tools.py +++ b/tests/unit/session/test_memory_extractor_tools.py @@ -1,16 +1,13 @@ # Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd. # SPDX-License-Identifier: Apache-2.0 -from unittest.mock import AsyncMock, MagicMock, patch +from unittest.mock import patch import pytest from openviking.session.memory_extractor import ( - FIELD_MAX_LENGTH, FIELD_MAX_LENGTHS, MemoryExtractor, - MemoryCategory, - ToolSkillCandidateMemory, ) diff --git a/tests/unit/storage/queuefs/test_embedding_msg.py b/tests/unit/storage/queuefs/test_embedding_msg.py index 0ff42657..af4d073e 100644 --- a/tests/unit/storage/queuefs/test_embedding_msg.py +++ b/tests/unit/storage/queuefs/test_embedding_msg.py @@ -1,6 +1,7 @@ # Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd. # SPDX-License-Identifier: Apache-2.0 import json + import pytest from openviking.storage.queuefs.embedding_msg import EmbeddingMsg diff --git a/tests/unit/storage/queuefs/test_processor_incremental.py b/tests/unit/storage/queuefs/test_processor_incremental.py index c87ebcf9..133c4649 100644 --- a/tests/unit/storage/queuefs/test_processor_incremental.py +++ b/tests/unit/storage/queuefs/test_processor_incremental.py @@ -23,7 +23,6 @@ from openviking.server.identity import RequestContext, Role from openviking.storage.queuefs.semantic_processor import DiffResult, SemanticProcessor from openviking_cli.session.user_id import UserIdentifier -from openviking_cli.utils import VikingURI class FakeVikingFS: From 0bd87d85d356fb3982860f976c665e98f1a56333 Mon Sep 17 00:00:00 2001 From: yepper Date: Thu, 12 Mar 2026 18:23:32 +0800 Subject: [PATCH 5/5] style: reformat code for better readability and consistency Refactor long lines and adjust formatting to improve code readability. 
Changes include:

- Breaking long lines to adhere to line length limits
- Reformatting dictionary and list literals for consistency
- Adjusting indentation in multi-line statements
---
 .../storage/queuefs/embedding_tracker.py      |   4 +-
 .../storage/queuefs/semantic_processor.py     |   8 +-
 tests/unit/session/test_deduplicator_uri.py   |  22 +-
 .../storage/queuefs/test_dag_incremental.py   |  65 ++++--
 .../storage/queuefs/test_embedding_msg.py     |  26 ++-
 .../queuefs/test_processor_incremental.py     | 217 ++++++++++--------
 .../unit/storage/queuefs/test_semantic_msg.py |  24 +-
 tests/unit/storage/test_viking_fs_new.py      |  30 ++-
 8 files changed, 243 insertions(+), 153 deletions(-)

diff --git a/openviking/storage/queuefs/embedding_tracker.py b/openviking/storage/queuefs/embedding_tracker.py
index c9fd6ab6..4149d8d1 100644
--- a/openviking/storage/queuefs/embedding_tracker.py
+++ b/openviking/storage/queuefs/embedding_tracker.py
@@ -126,7 +126,9 @@ async def decrement(self, semantic_msg_id: str) -> Optional[int]:
             on_complete = task_info.get("on_complete")
             del self._tasks[semantic_msg_id]

-            logger.info(f"All embedding tasks({task_info['total']}) completed for SemanticMsg {semantic_msg_id}")
+            logger.info(
+                f"All embedding tasks({task_info['total']}) completed for SemanticMsg {semantic_msg_id}"
+            )

             if on_complete:
                 try:
diff --git a/openviking/storage/queuefs/semantic_processor.py b/openviking/storage/queuefs/semantic_processor.py
index f60d37eb..56a05b14 100644
--- a/openviking/storage/queuefs/semantic_processor.py
+++ b/openviking/storage/queuefs/semantic_processor.py
@@ -236,7 +236,9 @@ async def sync_diff_callback() -> None:
             root_tree = await self._collect_tree_info(root_uri, ctx=ctx)
             target_tree = await self._collect_tree_info(target_uri, ctx=ctx)

-            diff = await self._compute_diff(root_tree, target_tree, root_uri, target_uri, ctx=ctx)
+            diff = await self._compute_diff(
+                root_tree, target_tree, root_uri, target_uri, ctx=ctx
+            )
             logger.info(
                 f"[SyncDiff] Diff computed: "
                 f"added_files={len(diff.added_files)}, "
@@ -460,9 +462,7 @@ def map_to_target(root_item_uri: str) -> str:
                 target_parent = VikingURI(target_file).parent
                 if target_parent:
                     try:
-                        await viking_fs.mkdir(
-                            target_parent.uri, exist_ok=True, ctx=ctx
-                        )
+                        await viking_fs.mkdir(target_parent.uri, exist_ok=True, ctx=ctx)
                     except Exception as mkdir_error:
                         logger.debug(
                             f"[SyncDiff] Parent dir creation skipped (may already exist): {mkdir_error}"
diff --git a/tests/unit/session/test_deduplicator_uri.py b/tests/unit/session/test_deduplicator_uri.py
index ba5c5c8d..46057b77 100644
--- a/tests/unit/session/test_deduplicator_uri.py
+++ b/tests/unit/session/test_deduplicator_uri.py
@@ -72,7 +72,9 @@ class TestFindSimilarMemoriesURIConversion:
     async def test_user_uri_converted_to_temp_uri(self):
         vikingdb = MagicMock()
         vikingdb.get_embedder.return_value = _DummyEmbedder()
-        vikingdb.search_similar_memories = AsyncMock(return_value=[_make_existing_user_memory("pref1.md")])
+        vikingdb.search_similar_memories = AsyncMock(
+            return_value=[_make_existing_user_memory("pref1.md")]
+        )
         dedup = MemoryDeduplicator(vikingdb=vikingdb)

         candidate = _make_candidate()
@@ -94,7 +96,9 @@ async def test_user_uri_converted_to_temp_uri(self):
     async def test_agent_uri_converted_to_temp_uri(self):
         vikingdb = MagicMock()
         vikingdb.get_embedder.return_value = _DummyEmbedder()
-        vikingdb.search_similar_memories = AsyncMock(return_value=[_make_existing_agent_memory("case1.md")])
+        vikingdb.search_similar_memories = AsyncMock(
+            return_value=[_make_existing_agent_memory("case1.md")]
+        )
         dedup = MemoryDeduplicator(vikingdb=vikingdb)

         candidate = _make_candidate(category=MemoryCategory.CASES)
@@ -117,7 +121,9 @@ async def test_agent_uri_converted_to_temp_uri(self):
     async def test_no_conversion_when_no_temp_uri(self):
         vikingdb = MagicMock()
         vikingdb.get_embedder.return_value = _DummyEmbedder()
-        vikingdb.search_similar_memories = AsyncMock(return_value=[_make_existing_user_memory("pref1.md")])
+        vikingdb.search_similar_memories = AsyncMock(
+            return_value=[_make_existing_user_memory("pref1.md")]
+        )
         dedup = MemoryDeduplicator(vikingdb=vikingdb)

         candidate = _make_candidate()
@@ -164,7 +170,9 @@ async def test_mixed_uris_only_convert_matching_type(self):
     async def test_uri_conversion_preserves_meta_and_score(self):
         vikingdb = MagicMock()
         vikingdb.get_embedder.return_value = _DummyEmbedder()
-        vikingdb.search_similar_memories = AsyncMock(return_value=[_make_existing_user_memory("pref1.md")])
+        vikingdb.search_similar_memories = AsyncMock(
+            return_value=[_make_existing_user_memory("pref1.md")]
+        )
         dedup = MemoryDeduplicator(vikingdb=vikingdb)

         candidate = _make_candidate()
@@ -199,7 +207,9 @@ def test_extract_with_em_dash(self):
         assert result == "work schedule"

     def test_extract_with_no_separator_returns_prefix(self):
-        result = MemoryDeduplicator._extract_facet_key("This is a long abstract without any separator")
+        result = MemoryDeduplicator._extract_facet_key(
+            "This is a long abstract without any separator"
+        )
         assert len(result) <= 24
         assert result == "this is a long abstract"

@@ -278,7 +288,7 @@ def test_partial_similarity(self):
         vec_a = [1.0, 0.0, 0.0]
         vec_b = [1.0, 1.0, 0.0]
         result = MemoryDeduplicator._cosine_similarity(vec_a, vec_b)
-        expected = 1.0 / (2.0 ** 0.5)
+        expected = 1.0 / (2.0**0.5)
         assert abs(result - expected) < 1e-9

     def test_negative_values(self):
diff --git a/tests/unit/storage/queuefs/test_dag_incremental.py b/tests/unit/storage/queuefs/test_dag_incremental.py
index f2f0cca9..08082ce7 100644
--- a/tests/unit/storage/queuefs/test_dag_incremental.py
+++ b/tests/unit/storage/queuefs/test_dag_incremental.py
@@ -15,7 +15,9 @@ def mock_processor():
     """Create a mock SemanticProcessor."""
     processor = MagicMock()
-    processor._generate_single_file_summary = AsyncMock(return_value={"name": "test.py", "summary": "test summary"})
+    processor._generate_single_file_summary = AsyncMock(
+        return_value={"name": "test.py", "summary": "test summary"}
+    )
     processor._generate_overview = AsyncMock(return_value="test overview")
     processor._extract_abstract_from_overview = MagicMock(return_value="test abstract")
     processor._vectorize_single_file = AsyncMock()

@@ -54,7 +56,9 @@ def mock_context():
 @pytest.fixture
 def executor(mock_processor, mock_context, mock_viking_fs):
     """Create a SemanticDagExecutor instance for testing."""
-    with patch("openviking.storage.queuefs.semantic_dag.get_viking_fs", return_value=mock_viking_fs):
+    with patch(
+        "openviking.storage.queuefs.semantic_dag.get_viking_fs", return_value=mock_viking_fs
+    ):
         executor = SemanticDagExecutor(
             processor=mock_processor,
             context_type="resource",
@@ -70,8 +74,12 @@ class TestGetTargetFilePath:
     """Tests for _get_target_file_path() method."""

-    def test_returns_none_when_incremental_update_disabled(self, mock_processor, mock_context, mock_viking_fs):
-        with patch("openviking.storage.queuefs.semantic_dag.get_viking_fs", return_value=mock_viking_fs):
+    def test_returns_none_when_incremental_update_disabled(
+        self, mock_processor, mock_context, mock_viking_fs
+    ):
+        with patch(
+            "openviking.storage.queuefs.semantic_dag.get_viking_fs", return_value=mock_viking_fs
+        ):
             executor = SemanticDagExecutor(
                 processor=mock_processor,
                 context_type="resource",
@@ -84,8 +92,12 @@ def test_returns_none_when_incremental_update_disabled(self, mock_processor, moc
         result = executor._get_target_file_path("viking://resource/root/file.py")
         assert result is None

-    def test_returns_none_when_target_uri_is_none(self, mock_processor, mock_context, mock_viking_fs):
-        with patch("openviking.storage.queuefs.semantic_dag.get_viking_fs", return_value=mock_viking_fs):
+    def test_returns_none_when_target_uri_is_none(
+        self, mock_processor, mock_context, mock_viking_fs
+    ):
+        with patch(
+            "openviking.storage.queuefs.semantic_dag.get_viking_fs", return_value=mock_viking_fs
+        ):
             executor = SemanticDagExecutor(
                 processor=mock_processor,
                 context_type="resource",
@@ -207,7 +219,9 @@ async def test_returns_none_when_vector_store_is_none(self, executor, mock_vikin
         assert result is None

     @pytest.mark.asyncio
-    async def test_returns_summary_dict_when_record_exists(self, executor, mock_viking_fs, mock_vector_store):
+    async def test_returns_summary_dict_when_record_exists(
+        self, executor, mock_viking_fs, mock_vector_store
+    ):
         executor._viking_fs = mock_viking_fs
         executor._root_uri = "viking://resource/root"
         mock_viking_fs._get_vector_store = MagicMock(return_value=mock_vector_store)
@@ -234,7 +248,9 @@ async def test_returns_none_when_no_records(self, executor, mock_viking_fs, mock
         assert result is None

     @pytest.mark.asyncio
-    async def test_returns_none_when_abstract_is_empty(self, executor, mock_viking_fs, mock_vector_store):
+    async def test_returns_none_when_abstract_is_empty(
+        self, executor, mock_viking_fs, mock_vector_store
+    ):
         executor._viking_fs = mock_viking_fs
         executor._root_uri = "viking://resource/root"
         mock_viking_fs._get_vector_store = MagicMock(return_value=mock_vector_store)
@@ -256,7 +272,9 @@ async def test_returns_none_on_exception(self, executor, mock_viking_fs, mock_ve
         assert result is None

     @pytest.mark.asyncio
-    async def test_calls_vector_store_with_correct_uri(self, executor, mock_viking_fs, mock_vector_store, mock_context):
+    async def test_calls_vector_store_with_correct_uri(
+        self, executor, mock_viking_fs, mock_vector_store, mock_context
+    ):
         executor._viking_fs = mock_viking_fs
         executor._root_uri = "viking://resource/root"
         mock_viking_fs._get_vector_store = MagicMock(return_value=mock_vector_store)
@@ -293,7 +311,10 @@ async def test_returns_false_when_children_identical(self, executor, mock_viking
         executor._list_dir = AsyncMock(
             side_effect=[
-                (["viking://resource/target/dir1", "viking://resource/target/dir2"], ["viking://resource/target/file1.py", "viking://resource/target/file2.py"]),
+                (
+                    ["viking://resource/target/dir1", "viking://resource/target/dir2"],
+                    ["viking://resource/target/file1.py", "viking://resource/target/file2.py"],
+                ),
                 ([], []),
             ]
         )

@@ -382,9 +403,7 @@ async def test_handles_empty_directories(self, executor, mock_viking_fs):

         mock_viking_fs.ls = AsyncMock(return_value=[])

-        result = await executor._check_dir_children_changed(
-            "viking://resource/root", [], []
-        )
+        result = await executor._check_dir_children_changed("viking://resource/root", [], [])
         assert result is False

     @pytest.mark.asyncio
@@ -496,7 +515,9 @@ class TestIncrementalUpdateIntegration:
     """Integration tests for incremental update scenarios."""

     @pytest.mark.asyncio
-    async def test_full_incremental_flow_no_changes(self, executor, mock_viking_fs, mock_vector_store):
+    async def test_full_incremental_flow_no_changes(
+        self, executor, mock_viking_fs, mock_vector_store
+    ):
         executor._viking_fs = mock_viking_fs
         executor._root_uri = "viking://resource/root"
         mock_viking_fs._get_vector_store = MagicMock(return_value=mock_vector_store)
@@ -507,7 +528,9 @@ async def test_full_incremental_flow_no_changes(self, executor, mock_viking_fs,
             return_value=[{"abstract": "existing summary"}]
         )

-        content_changed = await executor._check_file_content_changed("viking://resource/root/file.py")
+        content_changed = await executor._check_file_content_changed(
+            "viking://resource/root/file.py"
+        )
         assert content_changed is False

         summary = await executor._read_existing_summary("viking://resource/root/file.py")
@@ -515,14 +538,18 @@ async def test_full_incremental_flow_no_changes(self, executor, mock_viking_fs,
         assert summary["summary"] == "existing summary"

     @pytest.mark.asyncio
-    async def test_full_incremental_flow_with_changes(self, executor, mock_viking_fs, mock_vector_store):
+    async def test_full_incremental_flow_with_changes(
+        self, executor, mock_viking_fs, mock_vector_store
+    ):
         executor._viking_fs = mock_viking_fs
         executor._root_uri = "viking://resource/root"
         mock_viking_fs._get_vector_store = MagicMock(return_value=mock_vector_store)

         mock_viking_fs.read_file = AsyncMock(side_effect=["new content", "old content"])

-        content_changed = await executor._check_file_content_changed("viking://resource/root/file.py")
+        content_changed = await executor._check_file_content_changed(
+            "viking://resource/root/file.py"
+        )
         assert content_changed is True

     @pytest.mark.asyncio
@@ -549,9 +576,7 @@ async def test_overview_abstract_read_flow(self, executor, mock_viking_fs):
         executor._viking_fs = mock_viking_fs
         executor._root_uri = "viking://resource/root"

-        mock_viking_fs.read_file = AsyncMock(
-            side_effect=["existing overview", "existing abstract"]
-        )
+        mock_viking_fs.read_file = AsyncMock(side_effect=["existing overview", "existing abstract"])

         overview, abstract = await executor._read_existing_overview_abstract(
             "viking://resource/root/subdir"
diff --git a/tests/unit/storage/queuefs/test_embedding_msg.py b/tests/unit/storage/queuefs/test_embedding_msg.py
index af4d073e..3311ae96 100644
--- a/tests/unit/storage/queuefs/test_embedding_msg.py
+++ b/tests/unit/storage/queuefs/test_embedding_msg.py
@@ -86,12 +86,14 @@ def test_to_json_without_semantic_msg_id(self):

     def test_from_json_with_semantic_msg_id(self):
         """Test from_json() method with semantic_msg_id."""
-        json_str = json.dumps({
-            "id": "json-id-123",
-            "message": "from json",
-            "context_data": {"from": "json"},
-            "semantic_msg_id": "semantic-from-json",
-        })
+        json_str = json.dumps(
+            {
+                "id": "json-id-123",
+                "message": "from json",
+                "context_data": {"from": "json"},
+                "semantic_msg_id": "semantic-from-json",
+            }
+        )
         msg = EmbeddingMsg.from_json(json_str)
         assert msg.semantic_msg_id == "semantic-from-json"
         assert msg.id == "json-id-123"
@@ -99,11 +101,13 @@ def test_from_json_with_semantic_msg_id(self):

     def test_from_json_missing_semantic_msg_id(self):
         """Test from_json() with missing semantic_msg_id (backward compatibility)."""
-        json_str = json.dumps({
-            "id": "json-id-456",
-            "message": "legacy json",
-            "context_data": {"legacy": True},
-        })
+        json_str = json.dumps(
+            {
+                "id": "json-id-456",
+                "message": "legacy json",
+                "context_data": {"legacy": True},
+            }
+        )
         msg = EmbeddingMsg.from_json(json_str)
         assert msg.semantic_msg_id is None
         assert msg.message == "legacy json"
diff --git a/tests/unit/storage/queuefs/test_processor_incremental.py b/tests/unit/storage/queuefs/test_processor_incremental.py
index 133c4649..01c9a57e 100644
--- a/tests/unit/storage/queuefs/test_processor_incremental.py
+++ b/tests/unit/storage/queuefs/test_processor_incremental.py
@@ -175,12 +175,14 @@ async def test_collect_empty_directory(self, processor, fake_fs, ctx):

     @pytest.mark.asyncio
     async def test_collect_directory_with_files(self, processor, fake_fs, ctx):
-        fake_fs.set_tree({
-            "viking://temp/dir": [
-                {"name": "file1.txt", "isDir": False},
-                {"name": "file2.py", "isDir": False},
-            ]
-        })
+        fake_fs.set_tree(
+            {
+                "viking://temp/dir": [
+                    {"name": "file1.txt", "isDir": False},
+                    {"name": "file2.py", "isDir": False},
+                ]
+            }
+        )
         processor._current_ctx = ctx

         with patch(
@@ -197,18 +199,20 @@ async def test_collect_directory_with_files(self, processor, fake_fs, ctx):

     @pytest.mark.asyncio
     async def test_collect_directory_with_subdirs(self, processor, fake_fs, ctx):
-        fake_fs.set_tree({
-            "viking://temp/root": [
-                {"name": "subdir1", "isDir": True},
-                {"name": "subdir2", "isDir": True},
-            ],
-            "viking://temp/root/subdir1": [
-                {"name": "file1.txt", "isDir": False},
-            ],
-            "viking://temp/root/subdir2": [
-                {"name": "file2.txt", "isDir": False},
-            ],
-        })
+        fake_fs.set_tree(
+            {
+                "viking://temp/root": [
+                    {"name": "subdir1", "isDir": True},
+                    {"name": "subdir2", "isDir": True},
+                ],
+                "viking://temp/root/subdir1": [
+                    {"name": "file1.txt", "isDir": False},
+                ],
+                "viking://temp/root/subdir2": [
+                    {"name": "file2.txt", "isDir": False},
+                ],
+            }
+        )
         processor._current_ctx = ctx

         with patch(
@@ -222,17 +226,19 @@ async def test_collect_directory_with_subdirs(self, processor, fake_fs, ctx):

     @pytest.mark.asyncio
     async def test_collect_nested_directories(self, processor, fake_fs, ctx):
-        fake_fs.set_tree({
-            "viking://temp/root": [
-                {"name": "level1", "isDir": True},
-            ],
-            "viking://temp/root/level1": [
-                {"name": "level2", "isDir": True},
-            ],
-            "viking://temp/root/level1/level2": [
-                {"name": "deep_file.txt", "isDir": False},
-            ],
-        })
+        fake_fs.set_tree(
+            {
+                "viking://temp/root": [
+                    {"name": "level1", "isDir": True},
+                ],
+                "viking://temp/root/level1": [
+                    {"name": "level2", "isDir": True},
+                ],
+                "viking://temp/root/level1/level2": [
+                    {"name": "deep_file.txt", "isDir": False},
+                ],
+            }
+        )
         processor._current_ctx = ctx

         with patch(
@@ -246,12 +252,14 @@ async def test_collect_nested_directories(self, processor, fake_fs, ctx):

     @pytest.mark.asyncio
     async def test_collect_skips_hidden_files(self, processor, fake_fs, ctx):
-        fake_fs.set_tree({
-            "viking://temp/dir": [
-                {"name": ".hidden", "isDir": False},
-                {"name": "visible.txt", "isDir": False},
-            ]
-        })
+        fake_fs.set_tree(
+            {
+                "viking://temp/dir": [
+                    {"name": ".hidden", "isDir": False},
+                    {"name": "visible.txt", "isDir": False},
+                ]
+            }
+        )
         processor._current_ctx = ctx

         with patch(
@@ -265,13 +273,15 @@ async def test_collect_skips_hidden_files(self, processor, fake_fs, ctx):

     @pytest.mark.asyncio
     async def test_collect_skips_dot_and_dotdot(self, processor, fake_fs, ctx):
-        fake_fs.set_tree({
-            "viking://temp/dir": [
-                {"name": ".", "isDir": True},
-                {"name": "..", "isDir": True},
-                {"name": "file.txt", "isDir": False},
-            ]
-        })
+        fake_fs.set_tree(
+            {
+                "viking://temp/dir": [
+                    {"name": ".", "isDir": True},
+                    {"name": "..", "isDir": True},
+                    {"name": "file.txt", "isDir": False},
+                ]
+            }
+        )
         processor._current_ctx = ctx

         with patch(
@@ -372,10 +382,12 @@ async def test_compute_diff_updated_files(self, processor, fake_fs, ctx):
         target_tree = {
             "viking://target/root": ([], ["viking://target/root/file.txt"]),
         }
-        fake_fs.set_file_contents({
-            "viking://temp/root/file.txt": "new content",
-            "viking://target/root/file.txt": "old content",
-        })
+        fake_fs.set_file_contents(
+            {
+                "viking://temp/root/file.txt": "new content",
+                "viking://target/root/file.txt": "old content",
+            }
+        )
         processor._current_ctx = ctx

         with patch(
@@ -396,10 +408,12 @@ async def test_compute_diff_unchanged_files(self, processor, fake_fs, ctx):
         target_tree = {
             "viking://target/root": ([], ["viking://target/root/file.txt"]),
         }
-        fake_fs.set_file_contents({
-            "viking://temp/root/file.txt": "same content",
-            "viking://target/root/file.txt": "same content",
-        })
+        fake_fs.set_file_contents(
+            {
+                "viking://temp/root/file.txt": "same content",
+                "viking://target/root/file.txt": "same content",
+            }
+        )
         processor._current_ctx = ctx

         with patch(
@@ -469,10 +483,12 @@ async def test_compute_diff_mixed_changes(self, processor, fake_fs, ctx):
             ),
             "viking://target/root/old_dir": ([], []),
         }
-        fake_fs.set_file_contents({
-            "viking://temp/root/updated.txt": "new content",
-            "viking://target/root/updated.txt": "old content",
-        })
+        fake_fs.set_file_contents(
+            {
+                "viking://temp/root/updated.txt": "new content",
+                "viking://target/root/updated.txt": "old content",
+            }
+        )
         processor._current_ctx = ctx

         with patch(
@@ -494,10 +510,12 @@ class TestCheckFileContentChanged:

     @pytest.mark.asyncio
     async def test_content_changed(self, processor, fake_fs, ctx):
-        fake_fs.set_file_contents({
-            "viking://temp/file.txt": "new content",
-            "viking://target/file.txt": "old content",
-        })
+        fake_fs.set_file_contents(
+            {
+                "viking://temp/file.txt": "new content",
+                "viking://target/file.txt": "old content",
+            }
+        )
         processor._current_ctx = ctx

         with patch(
@@ -511,10 +529,12 @@ async def test_content_changed(self, processor, fake_fs, ctx):

     @pytest.mark.asyncio
     async def test_content_unchanged(self, processor, fake_fs, ctx):
-        fake_fs.set_file_contents({
-            "viking://temp/file.txt": "same content",
-            "viking://target/file.txt": "same content",
-        })
+        fake_fs.set_file_contents(
+            {
+                "viking://temp/file.txt": "same content",
+                "viking://target/file.txt": "same content",
+            }
+        )
         processor._current_ctx = ctx

         with patch(
@@ -528,10 +548,12 @@ async def test_content_unchanged(self, processor, fake_fs, ctx):

     @pytest.mark.asyncio
     async def test_content_changed_empty_files(self, processor, fake_fs, ctx):
-        fake_fs.set_file_contents({
-            "viking://temp/file.txt": "",
-            "viking://target/file.txt": "",
-        })
+        fake_fs.set_file_contents(
+            {
+                "viking://temp/file.txt": "",
+                "viking://target/file.txt": "",
+            }
+        )
         processor._current_ctx = ctx

         with patch(
@@ -545,10 +567,12 @@ async def test_content_changed_empty_files(self, processor, fake_fs, ctx):

     @pytest.mark.asyncio
     async def test_content_changed_one_empty(self, processor, fake_fs, ctx):
-        fake_fs.set_file_contents({
-            "viking://temp/file.txt": "content",
-            "viking://target/file.txt": "",
-        })
+        fake_fs.set_file_contents(
+            {
+                "viking://temp/file.txt": "content",
+                "viking://target/file.txt": "",
+            }
+        )
         processor._current_ctx = ctx

         with patch(
@@ -616,7 +640,8 @@ async def test_execute_move_added_files(self, processor, fake_fs, ctx):
             "openviking.storage.queuefs.semantic_processor.get_viking_fs", return_value=fake_fs
         ):
             with patch(
-                "openviking.storage.queuefs.semantic_processor.VikingURI", return_value=mock_viking_uri
+                "openviking.storage.queuefs.semantic_processor.VikingURI",
+                return_value=mock_viking_uri,
             ):
                 await processor._execute_sync_operations(
                     diff, "viking://temp/root", "viking://target/root"
@@ -642,14 +667,18 @@ async def test_execute_move_updated_files(self, processor, fake_fs, ctx):
             "openviking.storage.queuefs.semantic_processor.get_viking_fs", return_value=fake_fs
         ):
             with patch(
-                "openviking.storage.queuefs.semantic_processor.VikingURI", return_value=mock_viking_uri
+                "openviking.storage.queuefs.semantic_processor.VikingURI",
+                return_value=mock_viking_uri,
             ):
                 await processor._execute_sync_operations(
                     diff, "viking://temp/root", "viking://target/root"
                 )

         assert "viking://target/root/updated.txt" in fake_fs.deleted_files
-        assert ("viking://temp/root/updated.txt", "viking://target/root/updated.txt") in fake_fs.moved_files
+        assert (
+            "viking://temp/root/updated.txt",
+            "viking://target/root/updated.txt",
+        ) in fake_fs.moved_files

     @pytest.mark.asyncio
     async def test_execute_delete_dirs(self, processor, fake_fs, ctx):
@@ -717,7 +746,8 @@ async def test_execute_creates_parent_dirs(self, processor, fake_fs, ctx):
             "openviking.storage.queuefs.semantic_processor.get_viking_fs", return_value=fake_fs
         ):
             with patch(
-                "openviking.storage.queuefs.semantic_processor.VikingURI", return_value=mock_viking_uri
+                "openviking.storage.queuefs.semantic_processor.VikingURI",
+                return_value=mock_viking_uri,
             ):
                 await processor._execute_sync_operations(
                     diff, "viking://temp/root", "viking://target/root"
@@ -742,22 +772,27 @@ async def test_callback_is_async(self, processor):
             "viking://temp/root", "viking://target/root"
         )
         import asyncio
+
         assert asyncio.iscoroutinefunction(callback)

     @pytest.mark.asyncio
     async def test_callback_collects_tree_info(self, processor, fake_fs, ctx):
-        fake_fs.set_tree({
-            "viking://temp/root": [
-                {"name": "file.txt", "isDir": False},
-            ],
-            "viking://target/root": [
-                {"name": "file.txt", "isDir": False},
-            ],
-        })
-        fake_fs.set_file_contents({
-            "viking://temp/root/file.txt": "content",
-            "viking://target/root/file.txt": "content",
-        })
+        fake_fs.set_tree(
+            {
+                "viking://temp/root": [
+                    {"name": "file.txt", "isDir": False},
+                ],
+                "viking://target/root": [
+                    {"name": "file.txt", "isDir": False},
+                ],
+            }
+        )
+        fake_fs.set_file_contents(
+            {
+                "viking://temp/root/file.txt": "content",
+                "viking://target/root/file.txt": "content",
+            }
+        )
         processor._current_ctx = ctx

         with patch(
@@ -783,10 +818,12 @@ async def test_callback_handles_exception(self, processor, fake_fs, ctx):

     @pytest.mark.asyncio
     async def test_callback_deletes_root_after_sync(self, processor, fake_fs, ctx):
-        fake_fs.set_tree({
-            "viking://temp/root": [],
-            "viking://target/root": [],
-        })
+        fake_fs.set_tree(
+            {
+                "viking://temp/root": [],
+                "viking://target/root": [],
+            }
+        )
         processor._current_ctx = ctx

         with patch(
diff --git a/tests/unit/storage/queuefs/test_semantic_msg.py b/tests/unit/storage/queuefs/test_semantic_msg.py
index 84dee8eb..37551109 100644
--- a/tests/unit/storage/queuefs/test_semantic_msg.py
+++ b/tests/unit/storage/queuefs/test_semantic_msg.py
@@ -233,21 +233,25 @@ def test_to_json_and_from_json_round_trip(self):
         assert restored.role == original.role

     def test_from_json_with_new_fields(self):
-        json_str = json.dumps({
-            "uri": "viking://resource/test",
-            "context_type": "resource",
-            "target_uri": "viking://resource/target",
-            "skip_vectorization": True,
-        })
+        json_str = json.dumps(
+            {
+                "uri": "viking://resource/test",
+                "context_type": "resource",
+                "target_uri": "viking://resource/target",
+                "skip_vectorization": True,
+            }
+        )
         msg = SemanticMsg.from_json(json_str)
         assert msg.target_uri == "viking://resource/target"
         assert msg.skip_vectorization is True

     def test_from_json_without_new_fields(self):
-        json_str = json.dumps({
-            "uri": "viking://resource/test",
-            "context_type": "resource",
-        })
+        json_str = json.dumps(
+            {
+                "uri": "viking://resource/test",
+                "context_type": "resource",
+            }
+        )
         msg = SemanticMsg.from_json(json_str)
         assert msg.target_uri == ""
         assert msg.skip_vectorization is False
diff --git a/tests/unit/storage/test_viking_fs_new.py b/tests/unit/storage/test_viking_fs_new.py
index b754dd64..6661cc9f 100644
--- a/tests/unit/storage/test_viking_fs_new.py
+++ b/tests/unit/storage/test_viking_fs_new.py
@@ -69,7 +69,9 @@ async def test_copy_directory_recursive(self):
         """copy_directory() should recursively copy directory contents."""
         fs = _create_viking_fs_mock()
         fs._ensure_access = MagicMock()
-        fs._uri_to_path = MagicMock(side_effect=lambda uri, ctx=None: uri.replace("viking://", "/local/"))
+        fs._uri_to_path = MagicMock(
+            side_effect=lambda uri, ctx=None: uri.replace("viking://", "/local/")
+        )
         fs._ensure_parent_dirs = AsyncMock()

         mock_agfs_cp = MagicMock()
@@ -97,7 +99,9 @@ async def test_copy_directory_with_context(self):

         fs = _create_viking_fs_mock()
         fs._ensure_access = MagicMock()
-        fs._uri_to_path = MagicMock(side_effect=lambda uri, ctx=None: uri.replace("viking://", "/local/"))
+        fs._uri_to_path = MagicMock(
+            side_effect=lambda uri, ctx=None: uri.replace("viking://", "/local/")
+        )
         fs._ensure_parent_dirs = AsyncMock()

         ctx = RequestContext(
@@ -127,10 +131,12 @@ async def test_delete_temp_removes_directory_and_contents(self):

         fs = _create_viking_fs_mock()
         fs._uri_to_path = MagicMock(return_value="/local/temp/test_temp")
-        fs._ls_entries = MagicMock(return_value=[
-            {"name": "file1.txt", "isDir": False},
-            {"name": "subdir", "isDir": True},
-        ])
+        fs._ls_entries = MagicMock(
+            return_value=[
+                {"name": "file1.txt", "isDir": False},
+                {"name": "subdir", "isDir": True},
+            ]
+        )

         fs.agfs.rm = MagicMock()

@@ -168,11 +174,13 @@ async def test_delete_temp_skips_dot_entries(self):
         """delete_temp() should skip . and .. entries."""
         fs = _create_viking_fs_mock()
         fs._uri_to_path = MagicMock(return_value="/local/temp/test_temp")
-        fs._ls_entries = MagicMock(return_value=[
-            {"name": ".", "isDir": True},
-            {"name": "..", "isDir": True},
-            {"name": "actual_file.txt", "isDir": False},
-        ])
+        fs._ls_entries = MagicMock(
+            return_value=[
+                {"name": ".", "isDir": True},
+                {"name": "..", "isDir": True},
+                {"name": "actual_file.txt", "isDir": False},
+            ]
+        )
         fs.agfs.rm = MagicMock()

         await fs.delete_temp("viking://temp/test_temp/")