diff --git a/openviking/core/context.py b/openviking/core/context.py
index 94d47b1f..76308570 100644
--- a/openviking/core/context.py
+++ b/openviking/core/context.py
@@ -56,6 +56,7 @@ def __init__(
         self,
         uri: str,
         parent_uri: Optional[str] = None,
+        temp_uri: Optional[str] = None,
         is_leaf: bool = False,
         abstract: str = "",
         context_type: Optional[str] = None,
@@ -78,6 +79,7 @@ def __init__(
         self.id = id or str(uuid4())
         self.uri = uri
         self.parent_uri = parent_uri
+        self.temp_uri = temp_uri
         self.is_leaf = is_leaf
         self.abstract = abstract
         self.context_type = context_type or self._derive_context_type()
@@ -159,6 +161,7 @@ def to_dict(self) -> Dict[str, Any]:
             "id": self.id,
             "uri": self.uri,
             "parent_uri": self.parent_uri,
+            "temp_uri": self.temp_uri,
             "is_leaf": self.is_leaf,
             "abstract": self.abstract,
             "context_type": self.context_type,
@@ -194,6 +197,7 @@ def from_dict(cls, data: Dict[str, Any]) -> "Context":
         obj = cls(
             uri=data["uri"],
             parent_uri=data.get("parent_uri"),
+            temp_uri=data.get("temp_uri"),
             is_leaf=data.get("is_leaf", False),
             abstract=data.get("abstract", ""),
             context_type=data.get("context_type"),
diff --git a/openviking/parse/parsers/constants.py b/openviking/parse/parsers/constants.py
index 311e545a..a817a8f2 100644
--- a/openviking/parse/parsers/constants.py
+++ b/openviking/parse/parsers/constants.py
@@ -174,6 +174,7 @@
     ".graphql",
     ".gql",
     ".prisma",
+    ".conf",
 }
 
 # Documentation file extensions for file type detection
@@ -224,6 +225,7 @@
     ".yarnrc",
     ".env",
     ".env.example",
+    ".jsonl",
 }
 
 # Common text encodings to try for encoding detection (in order of likelihood)
diff --git a/openviking/parse/parsers/directory.py b/openviking/parse/parsers/directory.py
index da27b659..371cf0f9 100644
--- a/openviking/parse/parsers/directory.py
+++ b/openviking/parse/parsers/directory.py
@@ -123,6 +123,11 @@ async def parse(
         viking_fs = self._get_viking_fs()
         temp_uri = self._create_temp_uri()
         target_uri = f"{temp_uri}/{dir_name}"
+        logger.info(
+            f"Scanning directory: {source_path}, "
+            f"processable files: {len(processable_files)}, "
+            f"warnings: {warnings}"
+        )
 
         await viking_fs.mkdir(temp_uri, exist_ok=True)
         await viking_fs.mkdir(target_uri, exist_ok=True)
diff --git a/openviking/parse/parsers/upload_utils.py b/openviking/parse/parsers/upload_utils.py
index 50c80ed0..d1870173 100644
--- a/openviking/parse/parsers/upload_utils.py
+++ b/openviking/parse/parsers/upload_utils.py
@@ -40,6 +40,7 @@
     "NEWS",
     "NOTICE",
     "TODO",
+    "BUILD",
 }
 
 
diff --git a/openviking/parse/tree_builder.py b/openviking/parse/tree_builder.py
index c8409f41..18bcbf07 100644
--- a/openviking/parse/tree_builder.py
+++ b/openviking/parse/tree_builder.py
@@ -138,16 +138,6 @@ async def finalize_from_temp(
         base_uri = parent_uri or auto_base_uri
 
         # 3. Determine candidate_uri
         if to_uri:
-            # Exact target URI: must not exist yet
-            try:
-                await viking_fs.stat(to_uri, ctx=ctx)
-                # If we get here, it already exists
-                raise FileExistsError(f"Target URI already exists: {to_uri}")
-            except FileExistsError:
-                raise
-            except Exception:
-                # It doesn't exist, good to use
-                pass
             candidate_uri = to_uri
         else:
             if parent_uri:
@@ -160,34 +150,7 @@ async def finalize_from_temp(
                     raise ValueError(f"Parent URI is not a directory: {parent_uri}")
             candidate_uri = VikingURI(base_uri).join(final_doc_name).uri
 
-        final_uri = await self._resolve_unique_uri(candidate_uri, ctx=ctx)
-
-        if final_uri != candidate_uri:
-            logger.info(f"[TreeBuilder] Resolved name conflict: {candidate_uri} -> {final_uri}")
-        else:
-            logger.info(f"[TreeBuilder] Finalizing from temp: {final_uri}")
-
-        # 4. Move directory tree from temp to final location in AGFS
-        await self._move_temp_to_dest(viking_fs, temp_doc_uri, final_uri, ctx=ctx)
-        logger.info(f"[TreeBuilder] Moved temp tree: {temp_doc_uri} -> {final_uri}")
-
-        # 5. Cleanup temporary root directory
-        try:
-            await viking_fs.delete_temp(temp_uri, ctx=ctx)
-            logger.info(f"[TreeBuilder] Cleaned up temp root: {temp_uri}")
-        except Exception as e:
-            logger.warning(f"[TreeBuilder] Failed to cleanup temp root: {e}")
-
-        # 6. Enqueue to SemanticQueue for async semantic generation
-        if trigger_semantic:
-            try:
-                await self._enqueue_semantic_generation(final_uri, "resource", ctx=ctx)
-                logger.info(f"[TreeBuilder] Enqueued semantic generation for: {final_uri}")
-            except Exception as e:
-                logger.error(
-                    f"[TreeBuilder] Failed to enqueue semantic generation: {e}", exc_info=True
-                )
-
+        final_uri = candidate_uri
         # 7. Return simple BuildingTree (no scanning needed)
         tree = BuildingTree(
             source_path=source_path,
@@ -196,39 +159,11 @@ async def finalize_from_temp(
         tree._root_uri = final_uri
 
         # Create a minimal Context object for the root so that tree.root is not None
-        root_context = Context(uri=final_uri)
+        root_context = Context(uri=final_uri, temp_uri=temp_doc_uri)
         tree.add_context(root_context)
 
         return tree
 
-    async def _resolve_unique_uri(
-        self, uri: str, max_attempts: int = 100, ctx: Optional[RequestContext] = None
-    ) -> str:
-        """Return a URI that does not collide with an existing resource.
-
-        If *uri* is free, return it unchanged. Otherwise append ``_1``,
-        ``_2``, … until a free name is found (like macOS Finder / Windows
-        Explorer).
-        """
-        viking_fs = get_viking_fs()
-
-        async def _exists(u: str) -> bool:
-            try:
-                await viking_fs.stat(u, ctx=ctx)
-                return True
-            except Exception:
-                return False
-
-        if not await _exists(uri):
-            return uri
-
-        for i in range(1, max_attempts + 1):
-            candidate = f"{uri}_{i}"
-            if not await _exists(candidate):
-                return candidate
-
-        raise FileExistsError(f"Cannot resolve unique name for {uri} after {max_attempts} attempts")
-
     async def _move_temp_to_dest(
         self, viking_fs, src_uri: str, dst_uri: str, ctx: RequestContext
     ) -> None:
@@ -261,7 +196,7 @@ async def _ensure_parent_dirs(self, uri: str, ctx: RequestContext) -> None:
             logger.debug(f"Parent dir {parent_uri} may already exist: {e}")
 
     async def _enqueue_semantic_generation(
-        self, uri: str, context_type: str, ctx: RequestContext
+        self, uri: str, final_uri: str, context_type: str, ctx: RequestContext
     ) -> None:
         """
         Enqueue a directory for semantic generation.
@@ -284,32 +219,6 @@ async def _enqueue_semantic_generation(
             user_id=ctx.user.user_id,
             agent_id=ctx.user.agent_id,
             role=ctx.role.value,
+            target_uri=final_uri,
         )
         await semantic_queue.enqueue(msg)
-
-    async def _load_content(self, uri: str, content_type: str) -> str:
-        """Helper to load content with proper type handling"""
-        import json
-
-        if content_type == "abstract":
-            result = await get_viking_fs().abstract(uri)
-        elif content_type == "overview":
-            result = await get_viking_fs().overview(uri)
-        elif content_type == "detail":
-            result = await get_viking_fs().read_file(uri)
-        else:
-            return ""
-
-        # Handle different return types
-        if isinstance(result, str):
-            return result
-        elif isinstance(result, bytes):
-            return result.decode("utf-8")
-        elif hasattr(result, "to_dict") and not isinstance(result, list):
-            # Handle FindResult by converting to dict (skip lists)
-            return str(result.to_dict())
-        elif isinstance(result, list):
-            # Handle list results
-            return json.dumps(result)
-        else:
-            return str(result)
diff --git a/openviking/server/models.py b/openviking/server/models.py
index 4cb7f967..26b6e5ec 100644
--- a/openviking/server/models.py
+++ b/openviking/server/models.py
@@ -39,6 +39,7 @@ class Response(BaseModel):
     "INVALID_URI": 400,
     "NOT_FOUND": 404,
     "ALREADY_EXISTS": 409,
+    "CONFLICT": 409,
     "PERMISSION_DENIED": 403,
     "UNAUTHENTICATED": 401,
     "RESOURCE_EXHAUSTED": 429,
diff --git a/openviking/service/core.py b/openviking/service/core.py
index 5764fed1..a19f3ba3 100644
--- a/openviking/service/core.py
+++ b/openviking/service/core.py
@@ -255,7 +255,9 @@ async def initialize(self) -> None:
         )
 
         # Initialize processors
-        self._resource_processor = ResourceProcessor(vikingdb=self._vikingdb_manager)
+        self._resource_processor = ResourceProcessor(
+            vikingdb=self._vikingdb_manager,
+        )
         self._skill_processor = SkillProcessor(vikingdb=self._vikingdb_manager)
         self._session_compressor = SessionCompressor(vikingdb=self._vikingdb_manager)
 
diff --git a/openviking/session/compressor.py b/openviking/session/compressor.py
index f5009897..0dd58d88 100644
--- a/openviking/session/compressor.py
+++ b/openviking/session/compressor.py
@@ -78,16 +78,57 @@ async def _index_memory(self, memory: Context, ctx: RequestContext) -> bool:
         await self.extractor._enqueue_semantic_for_parent(memory.uri, ctx)
         return True
 
+    def _convert_to_temp_uri(
+        self, target_uri: str, user_temp_uri: Optional[str], agent_temp_uri: Optional[str]
+    ) -> str:
+        """Convert target URI to temp URI for COW pattern.
+
+        Args:
+            target_uri: Target URI (e.g., viking://user/... or viking://agent/...)
+            user_temp_uri: Temp user URI (if available)
+            agent_temp_uri: Temp agent URI (if available)
+
+        Returns:
+            Converted temp URI, or original URI if no temp available
+        """
+        if not user_temp_uri and not agent_temp_uri:
+            return target_uri
+
+        # Convert user URI
+        if target_uri.startswith("viking://user/") and user_temp_uri:
+            # viking://user/{user_space}/memories/... -> {user_temp_uri}/memories/...
+            parts = target_uri.split("/")
+            if len(parts) >= 5:
+                # parts[0]="viking:", parts[1]="", parts[2]="user", parts[3]="{user_space}", parts[4:]="memories/..."
+                rest = "/".join(parts[4:])
+                return f"{user_temp_uri}/{rest}"
+
+        # Convert agent URI
+        if target_uri.startswith("viking://agent/") and agent_temp_uri:
+            # viking://agent/{agent_space}/memories/... -> {agent_temp_uri}/memories/...
+            parts = target_uri.split("/")
+            if len(parts) >= 5:
+                # parts[0]="viking:", parts[1]="", parts[2]="agent", parts[3]="{agent_space}", parts[4:]="memories/..."
+                rest = "/".join(parts[4:])
+                return f"{agent_temp_uri}/{rest}"
+
+        return target_uri
+
     async def _merge_into_existing(
         self,
         candidate: CandidateMemory,
         target_memory: Context,
         viking_fs,
         ctx: RequestContext,
+        user_temp_uri: Optional[str] = None,
+        agent_temp_uri: Optional[str] = None,
     ) -> bool:
         """Merge candidate content into an existing memory file."""
         try:
-            existing_content = await viking_fs.read_file(target_memory.uri, ctx=ctx)
+            # Convert target URI to temp URI for COW pattern
+            temp_uri = self._convert_to_temp_uri(target_memory.uri, user_temp_uri, agent_temp_uri)
+
+            existing_content = await viking_fs.read_file(temp_uri, ctx=ctx)
             payload = await self.extractor._merge_memory_bundle(
                 existing_abstract=target_memory.abstract,
                 existing_overview=(target_memory.meta or {}).get("overview") or "",
@@ -101,34 +142,41 @@
             if not payload:
                 return False
 
-            await viking_fs.write_file(target_memory.uri, payload.content, ctx=ctx)
+            await viking_fs.write_file(temp_uri, payload.content, ctx=ctx)
             target_memory.abstract = payload.abstract
             target_memory.meta = {**(target_memory.meta or {}), "overview": payload.overview}
-            logger.info(
-                "Merged memory %s with abstract %s", target_memory.uri, target_memory.abstract
-            )
+            logger.info("Merged memory %s with abstract %s", temp_uri, target_memory.abstract)
             target_memory.set_vectorize(Vectorize(text=payload.content))
-            await self._index_memory(target_memory, ctx)
+            # Note: vectorization will be handled by SemanticQueue after directory switch
+            # await self._index_memory(target_memory, ctx)
             return True
         except Exception as e:
             logger.error(f"Failed to merge memory {target_memory.uri}: {e}")
             return False
 
     async def _delete_existing_memory(
-        self, memory: Context, viking_fs, ctx: RequestContext
+        self,
+        memory: Context,
+        viking_fs,
+        ctx: RequestContext,
+        user_temp_uri: Optional[str] = None,
+        agent_temp_uri: Optional[str] = None,
     ) -> bool:
         """Hard delete an existing memory file and clean up its vector record."""
         try:
-            await viking_fs.rm(memory.uri, recursive=False, ctx=ctx)
+            # Convert target URI to temp URI for COW pattern
+            temp_uri = self._convert_to_temp_uri(memory.uri, user_temp_uri, agent_temp_uri)
+
+            await viking_fs.rm(temp_uri, recursive=False, ctx=ctx)
         except Exception as e:
-            logger.error(f"Failed to delete memory file {memory.uri}: {e}")
+            logger.error(f"Failed to delete memory file {temp_uri}: {e}")
             return False
         try:
             # rm() already syncs vector deletion in most cases; keep this as a safe fallback.
-            await self.vikingdb.delete_uris(ctx, [memory.uri])
+            await self.vikingdb.delete_uris(ctx, [temp_uri])
         except Exception as e:
-            logger.warning(f"Failed to remove vector record for {memory.uri}: {e}")
+            logger.warning(f"Failed to remove vector record for {temp_uri}: {e}")
         return True
 
     async def extract_long_term_memories(
@@ -137,9 +185,25 @@
         self,
         user: Optional["UserIdentifier"] = None,
         session_id: Optional[str] = None,
         ctx: Optional[RequestContext] = None,
+        user_temp_uri: Optional[str] = None,
+        agent_temp_uri: Optional[str] = None,
         strict_extract_errors: bool = False,
     ) -> List[Context]:
-        """Extract long-term memories from messages."""
+        """Extract long-term memories from messages.
+
+        Args:
+            messages: Messages to extract from
+            user: User identifier
+            session_id: Session ID
+            ctx: Request context
+            user_temp_uri: Temp user URI (for COW pattern). If provided, user memories
+                will be written to this temp location.
+            agent_temp_uri: Temp agent URI (for COW pattern). If provided, agent memories
+                will be written to this temp location.
+
+        Returns:
+            List of extracted memories
+        """
         if not messages:
             return []
 
@@ -170,11 +234,19 @@ async def extract_long_term_memories(
         for candidate in candidates:
             # Profile: skip dedup, always merge
             if candidate.category in ALWAYS_MERGE_CATEGORIES:
-                memory = await self.extractor.create_memory(candidate, user, session_id, ctx=ctx)
+                memory = await self.extractor.create_memory(
+                    candidate,
+                    user,
+                    session_id,
+                    ctx=ctx,
+                    user_temp_uri=user_temp_uri,
+                    agent_temp_uri=agent_temp_uri,
+                )
                 if memory:
                     memories.append(memory)
                     stats.created += 1
-                    await self._index_memory(memory, ctx)
+                    # Note: vectorization will be handled by SemanticQueue after directory switch
+                    # await self._index_memory(memory, ctx)
                 else:
                     stats.skipped += 1
                 continue
@@ -213,11 +285,11 @@
                 )
                 if skill_name:
                     memory = await self.extractor._merge_skill_memory(
-                        skill_name, candidate, ctx=ctx
+                        skill_name, candidate, ctx=ctx, agent_temp_uri=agent_temp_uri
                    )
                 elif tool_name:
                     memory = await self.extractor._merge_tool_memory(
-                        tool_name, candidate, ctx=ctx
+                        tool_name, candidate, ctx=ctx, agent_temp_uri=agent_temp_uri
                     )
                 else:
                     logger.warning("No tool_name or skill_name found, skipping")
@@ -226,7 +298,8 @@
                 if memory:
                     memories.append(memory)
                     stats.merged += 1
-                    await self._index_memory(memory, ctx)
+                    # Note: vectorization will be handled by SemanticQueue after directory switch
+                    # await self._index_memory(memory, ctx)
                 continue
 
             # Dedup check for other categories
@@ -256,7 +329,11 @@
                 for action in actions:
                     if action.decision == MemoryActionDecision.DELETE:
                         if viking_fs and await self._delete_existing_memory(
-                            action.memory, viking_fs, ctx=ctx
+                            action.memory,
+                            viking_fs,
+                            ctx=ctx,
+                            user_temp_uri=user_temp_uri,
+                            agent_temp_uri=agent_temp_uri,
                         ):
                             stats.deleted += 1
                         else:
@@ -264,13 +341,17 @@
                     elif action.decision == MemoryActionDecision.MERGE:
                         if candidate.category in MERGE_SUPPORTED_CATEGORIES and viking_fs:
                             if await self._merge_into_existing(
-                                candidate, action.memory, viking_fs, ctx=ctx
+                                candidate,
+                                action.memory,
+                                viking_fs,
+                                ctx=ctx,
+                                user_temp_uri=user_temp_uri,
+                                agent_temp_uri=agent_temp_uri,
                             ):
                                 stats.merged += 1
                             else:
                                 stats.skipped += 1
                         else:
-                            # events/cases don't support MERGE, treat as SKIP
                             stats.skipped += 1
                 continue
 
@@ -279,24 +360,60 @@
             for action in actions:
                 if action.decision == MemoryActionDecision.DELETE:
                     if viking_fs and await self._delete_existing_memory(
-                        action.memory, viking_fs, ctx=ctx
+                        action.memory,
+                        viking_fs,
+                        ctx=ctx,
+                        user_temp_uri=user_temp_uri,
+                        agent_temp_uri=agent_temp_uri,
                     ):
                         stats.deleted += 1
                     else:
                         stats.skipped += 1
 
-            memory = await self.extractor.create_memory(candidate, user, session_id, ctx=ctx)
+            memory = await self.extractor.create_memory(
+                candidate,
+                user,
+                session_id,
+                ctx=ctx,
+                user_temp_uri=user_temp_uri,
+                agent_temp_uri=agent_temp_uri,
+            )
             if memory:
                 memories.append(memory)
                 stats.created += 1
-                await self._index_memory(memory, ctx)
+                # Note: vectorization will be handled by SemanticQueue after directory switch
+                # await self._index_memory(memory, ctx)
            else:
                 stats.skipped += 1
 
         # Extract URIs used in messages, create relations
         used_uris = self._extract_used_uris(messages)
         if used_uris and memories:
-            await self._create_relations(memories, used_uris, ctx=ctx)
+            # Convert memory URIs from temp to target for relation creation
+            target_memories = []
+            for memory in memories:
+                # Create a copy with target URI
+                target_uri = memory.uri
+                # If memory.uri is a temp URI, convert it to target URI
+                if user_temp_uri and memory.uri.startswith(user_temp_uri):
+                    target_uri = memory.uri.replace(
+                        user_temp_uri, f"viking://user/{ctx.user.user_space_name()}"
+                    )
+                elif agent_temp_uri and memory.uri.startswith(agent_temp_uri):
+                    target_uri = memory.uri.replace(
+                        agent_temp_uri, f"viking://agent/{ctx.user.agent_space_name()}"
+                    )
+
+                # Create a new Context with target URI for relation creation
+                target_memory = Context(
+                    uri=target_uri,
+                    context_type=memory.context_type,
+                    abstract=memory.abstract,
+                    meta=memory.meta,
+                )
+                target_memories.append(target_memory)
+
+            await self._create_relations(target_memories, used_uris, ctx=ctx)
 
         logger.info(
             f"Memory extraction: created={stats.created}, "
diff --git a/openviking/session/memory_deduplicator.py b/openviking/session/memory_deduplicator.py
index c119ecb8..ee28ff33 100644
--- a/openviking/session/memory_deduplicator.py
+++ b/openviking/session/memory_deduplicator.py
@@ -117,8 +117,19 @@ async def deduplicate(
     async def _find_similar_memories(
         self,
         candidate: CandidateMemory,
+        user_temp_uri: Optional[str] = None,
+        agent_temp_uri: Optional[str] = None,
     ) -> List[Context]:
-        """Find similar existing memories using vector search."""
+        """Find similar existing memories using vector search.
+
+        Args:
+            candidate: Candidate memory
+            user_temp_uri: Temp user URI (for COW pattern)
+            agent_temp_uri: Temp agent URI (for COW pattern)
+
+        Returns:
+            List of similar memories with temp URIs (if temp URIs provided)
+        """
         if not self.embedder:
             return []
 
@@ -127,6 +138,7 @@ async def _find_similar_memories(
         embed_result: EmbedResult = self.embedder.embed(query_text)
         query_vector = embed_result.dense_vector
 
+        # Search target URI (not temp URI) because vectors are stored for target URIs
         category_uri_prefix = self._category_uri_prefix(candidate.category.value, candidate.user)
         owner = candidate.user
 
@@ -177,6 +189,25 @@
             if context:
                 # Keep retrieval score for later destructive-action guardrails.
                 context.meta = {**(context.meta or {}), "_dedup_score": score}
+
+                # Convert target URI to temp URI (for COW pattern)
+                if user_temp_uri or agent_temp_uri:
+                    original_uri = context.uri
+                    # Convert user URI
+                    if user_temp_uri and original_uri.startswith("viking://user/"):
+                        parts = original_uri.split("/")
+                        if len(parts) >= 5:
+                            rest = "/".join(parts[4:])
+                            context.uri = f"{user_temp_uri}/{rest}"
+                            logger.debug(f"Converted URI: {original_uri} -> {context.uri}")
+                    # Convert agent URI
+                    elif agent_temp_uri and original_uri.startswith("viking://agent/"):
+                        parts = original_uri.split("/")
+                        if len(parts) >= 5:
+                            rest = "/".join(parts[4:])
+                            context.uri = f"{agent_temp_uri}/{rest}"
+                            logger.debug(f"Converted URI: {original_uri} -> {context.uri}")
+
                 similar.append(context)
         logger.debug("Dedup similar memories after threshold=%d", len(similar))
         return similar
diff --git a/openviking/session/memory_extractor.py b/openviking/session/memory_extractor.py
index 4130880e..cf6b3606 100644
--- a/openviking/session/memory_extractor.py
+++ b/openviking/session/memory_extractor.py
@@ -219,7 +219,6 @@ async def extract(
         context: dict,
         user: UserIdentifier,
         session_id: str,
-        *,
         strict: bool = False,
     ) -> List[CandidateMemory]:
         """Extract memory candidates from messages.
@@ -408,8 +407,21 @@ async def create_memory(
         user: str,
         session_id: str,
         ctx: RequestContext,
+        user_temp_uri: Optional[str] = None,
+        agent_temp_uri: Optional[str] = None,
     ) -> Optional[Context]:
-        """Create Context object from candidate and persist to AGFS as .md file."""
+        """Create Context object from candidate and persist to AGFS as .md file.
+
+        Args:
+            candidate: Candidate memory to create
+            user: User identifier
+            session_id: Session ID
+            ctx: Request context
+            user_temp_uri: Temp user URI (for COW pattern). If provided, user memories
+                will be written to this temp location.
+            agent_temp_uri: Temp agent URI (for COW pattern). If provided, agent memories
+                will be written to this temp location.
+        """
         viking_fs = get_viking_fs()
         if not viking_fs:
             logger.warning("VikingFS not available, skipping memory creation")
@@ -419,14 +431,22 @@ async def create_memory(
         # Special handling for profile: append to profile.md
         if candidate.category == MemoryCategory.PROFILE:
-            payload = await self._append_to_profile(candidate, viking_fs, ctx=ctx)
+            payload = await self._append_to_profile(
+                candidate, viking_fs, ctx=ctx, user_temp_uri=user_temp_uri
+            )
             if not payload:
                 return None
             user_space = ctx.user.user_space_name()
-            memory_uri = f"viking://user/{user_space}/memories/profile.md"
+            # Use temp user URI if provided (for COW pattern)
+            if user_temp_uri:
+                memory_uri = f"{user_temp_uri}/memories/profile.md"
+                parent_uri = f"{user_temp_uri}/memories"
+            else:
+                memory_uri = f"viking://user/{user_space}/memories/profile.md"
+                parent_uri = f"viking://user/{user_space}/memories"
             memory = Context(
                 uri=memory_uri,
-                parent_uri=f"viking://user/{user_space}/memories",
+                parent_uri=parent_uri,
                 is_leaf=True,
                 abstract=payload.abstract,
                 context_type=ContextType.MEMORY.value,
@@ -447,9 +467,17 @@ async def create_memory(
             MemoryCategory.ENTITIES,
             MemoryCategory.EVENTS,
         ]:
-            parent_uri = f"viking://user/{ctx.user.user_space_name()}/{cat_dir}"
+            # Use temp user URI if provided (for COW pattern)
+            if user_temp_uri:
+                parent_uri = f"{user_temp_uri}/{cat_dir}"
+            else:
+                parent_uri = f"viking://user/{ctx.user.user_space_name()}/{cat_dir}"
         else:  # CASES, PATTERNS
-            parent_uri = f"viking://agent/{ctx.user.agent_space_name()}/{cat_dir}"
+            # Use temp agent URI if provided (for COW pattern)
+            if agent_temp_uri:
+                parent_uri = f"{agent_temp_uri}/{cat_dir}"
+            else:
+                parent_uri = f"viking://agent/{ctx.user.agent_space_name()}/{cat_dir}"
 
         # Generate file URI (store directly as .md file, no directory creation)
         memory_id = f"mem_{str(uuid4())}"
@@ -485,9 +513,14 @@ async def _append_to_profile(
         candidate: CandidateMemory,
         viking_fs,
         ctx: RequestContext,
+        user_temp_uri: Optional[str] = None,
     ) -> Optional[MergedMemoryPayload]:
         """Update user profile - always merge with existing content."""
-        uri = f"viking://user/{ctx.user.user_space_name()}/memories/profile.md"
+        # Use temp user URI if provided (for COW pattern)
+        if user_temp_uri:
+            uri = f"{user_temp_uri}/memories/profile.md"
+        else:
+            uri = f"viking://user/{ctx.user.user_space_name()}/memories/profile.md"
         existing = ""
         try:
             existing = await viking_fs.read_file(uri, ctx=ctx) or ""
@@ -587,7 +620,11 @@ async def _merge_memory_bundle(
             return None
 
     async def _merge_tool_memory(
-        self, tool_name: str, candidate: CandidateMemory, ctx: "RequestContext"
+        self,
+        tool_name: str,
+        candidate: CandidateMemory,
+        ctx: "RequestContext",
+        agent_temp_uri: Optional[str] = None,
     ) -> Optional[Context]:
         """Merge Tool Memory; accumulate statistics in Python."""
         if not tool_name or not tool_name.strip():
@@ -595,7 +632,11 @@ async def _merge_tool_memory(
             return None
 
         agent_space = ctx.user.agent_space_name()
-        uri = f"viking://agent/{agent_space}/memories/tools/{tool_name}.md"
+        # Use temp agent URI if provided (for COW pattern)
+        if agent_temp_uri:
+            uri = f"{agent_temp_uri}/memories/tools/{tool_name}.md"
+        else:
+            uri = f"viking://agent/{agent_space}/memories/tools/{tool_name}.md"
 
         viking_fs = get_viking_fs()
         if not viking_fs:
@@ -651,7 +692,7 @@ async def _merge_tool_memory(
                 tool_name, merged_stats, new_guidelines, fields=new_fields
             )
             await viking_fs.write_file(uri=uri, content=merged_content, ctx=ctx)
-            return self._create_tool_context(uri, candidate, ctx)
+            return self._create_tool_context(uri, candidate, ctx, agent_temp_uri=agent_temp_uri)
 
         existing_stats = self._parse_tool_statistics(existing)
         merged_stats = self._merge_tool_statistics(existing_stats, new_stats)
@@ -709,7 +750,9 @@ async def _merge_tool_memory(
             tool_name, merged_stats, merged_guidelines, fields=merged_fields
         )
         await viking_fs.write_file(uri=uri, content=merged_content, ctx=ctx)
-        return self._create_tool_context(uri, candidate, ctx, abstract_override=abstract_override)
+        return self._create_tool_context(
+            uri, candidate, ctx, abstract_override=abstract_override, agent_temp_uri=agent_temp_uri
+        )
 
     async def _enqueue_semantic_for_parent(self, file_uri: str, ctx: "RequestContext") -> None:
         """Enqueue semantic generation for parent directory."""
@@ -1125,12 +1168,18 @@ def _create_tool_context(
         candidate: CandidateMemory,
         ctx: "RequestContext",
         abstract_override: Optional[str] = None,
+        agent_temp_uri: Optional[str] = None,
     ) -> Context:
         """Create a Context object for the Tool Memory."""
         agent_space = ctx.user.agent_space_name()
+        # Use temp agent URI if provided (for COW pattern)
+        if agent_temp_uri:
+            parent_uri = f"{agent_temp_uri}/memories/tools"
+        else:
+            parent_uri = f"viking://agent/{agent_space}/memories/tools"
         return Context(
             uri=uri,
-            parent_uri=f"viking://agent/{agent_space}/memories/tools",
+            parent_uri=parent_uri,
             is_leaf=True,
             abstract=abstract_override or candidate.abstract,
             context_type=ContextType.MEMORY.value,
@@ -1162,7 +1211,11 @@ def _extract_tool_guidelines(self, content: str) -> str:
         return content.strip()
 
     async def _merge_skill_memory(
-        self, skill_name: str, candidate: CandidateMemory, ctx: "RequestContext"
+        self,
+        skill_name: str,
+        candidate: CandidateMemory,
+        ctx: "RequestContext",
+        agent_temp_uri: Optional[str] = None,
     ) -> Optional[Context]:
         """Merge Skill Memory; accumulate statistics in Python."""
         if not skill_name or not skill_name.strip():
@@ -1170,7 +1223,11 @@ async def _merge_skill_memory(
             return None
 
         agent_space = ctx.user.agent_space_name()
-        uri = f"viking://agent/{agent_space}/memories/skills/{skill_name}.md"
+        # Use temp agent URI if provided (for COW pattern)
+        if agent_temp_uri:
+            uri = f"{agent_temp_uri}/memories/skills/{skill_name}.md"
+        else:
+            uri = f"viking://agent/{agent_space}/memories/skills/{skill_name}.md"
 
         viking_fs = get_viking_fs()
         if not viking_fs:
@@ -1239,7 +1296,7 @@ async def _merge_skill_memory(
                 skill_name, merged_stats, new_guidelines, fields=new_fields
             )
             await viking_fs.write_file(uri=uri, content=merged_content, ctx=ctx)
-            return self._create_skill_context(uri, candidate, ctx)
+            return self._create_skill_context(uri, candidate, ctx, agent_temp_uri=agent_temp_uri)
 
         existing_stats = self._parse_skill_statistics(existing)
         merged_stats = self._merge_skill_statistics(existing_stats, new_stats)
@@ -1301,7 +1358,9 @@ async def _merge_skill_memory(
             skill_name, merged_stats, merged_guidelines, fields=merged_fields
        )
        await viking_fs.write_file(uri=uri, content=merged_content, ctx=ctx)
-        return self._create_skill_context(uri, candidate, ctx, abstract_override=abstract_override)
+        return self._create_skill_context(
+            uri, candidate, ctx, abstract_override=abstract_override, agent_temp_uri=agent_temp_uri
+        )
 
     def _compute_skill_statistics_derived(self, stats: dict) -> dict:
         """Compute derived Skill statistics (success rate)."""
@@ -1453,12 +1512,18 @@ def _create_skill_context(
         candidate: CandidateMemory,
         ctx: "RequestContext",
         abstract_override: Optional[str] = None,
+        agent_temp_uri: Optional[str] = None,
     ) -> Context:
         """Create a Context object for the Skill Memory."""
         agent_space = ctx.user.agent_space_name()
+        # Use temp agent URI if provided (for COW pattern)
+        if agent_temp_uri:
+            parent_uri = f"{agent_temp_uri}/memories/skills"
+        else:
+            parent_uri = f"viking://agent/{agent_space}/memories/skills"
         return Context(
             uri=uri,
-            parent_uri=f"viking://agent/{agent_space}/memories/skills",
+            parent_uri=parent_uri,
             is_leaf=True,
             abstract=abstract_override or candidate.abstract,
             context_type=ContextType.MEMORY.value,
diff --git a/openviking/session/session.py b/openviking/session/session.py
index 9f2d7766..0a51aa60 100644
--- a/openviking/session/session.py
+++ b/openviking/session/session.py
@@ -7,9 +7,10 @@
 
 import json
 import re
+import time
 from dataclasses import dataclass, field
 from datetime import datetime
-from typing import TYPE_CHECKING, Any, Dict, List, Optional
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
 from uuid import uuid4
 
 from openviking.message import Message, Part
@@ -91,6 +92,13 @@ def __init__(
         self._stats: SessionStats = SessionStats()
         self._loaded = False
 
+        # Temp URI management for COW pattern
+        self._temp_base_uri: Optional[str] = None
+        self._session_temp_uri: Optional[str] = None
+        self._user_temp_uri: Optional[str] = None
+        self._agent_temp_uri: Optional[str] = None
+        self._temp_created_at: Optional[float] = None
+
         logger.info(f"Session created: {self.session_id} for user {self.user}")
 
     async def load(self):
@@ -294,69 +302,214 @@ def commit(self) -> Dict[str, Any]:
         logger.info(f"Session {self.session_id} committed")
         return result
 
+    def _create_temp_uris(self) -> Tuple[str, str, str, str]:
+        """Create temp URIs for session, user and agent directories.
+
+        Temp URI structure matches target URI structure for Semantic DAG recursive processing:
+        - Session: viking://temp/session/{user_space}/{session_id}/commit_{uuid}/session/{user_space}/{session_id}/
+        - User: viking://temp/session/{user_space}/{session_id}/commit_{uuid}/user/{user_space}/
+        - Agent: viking://temp/session/{user_space}/{session_id}/commit_{uuid}/agent/{agent_space}/
+
+        Returns:
+            (temp_base_uri, session_temp_uri, user_temp_uri, agent_temp_uri)
+        """
+        temp_base_uri = (
+            f"viking://temp/session/"
+            f"{self.user.user_space_name()}/"
+            f"{self.session_id}/"
+            f"commit_{uuid4().hex[:8]}"
+        )
+
+        # Match target URI structure for Semantic DAG recursive processing
+        session_temp_uri = (
+            f"{temp_base_uri}/session/{self.user.user_space_name()}/{self.session_id}"
+        )
+        user_temp_uri = f"{temp_base_uri}/user/{self.user.user_space_name()}"
+        agent_temp_uri = f"{temp_base_uri}/agent/{self.user.agent_space_name()}"
+
+        self._temp_base_uri = temp_base_uri
+        self._session_temp_uri = session_temp_uri
+        self._user_temp_uri = user_temp_uri
+        self._agent_temp_uri = agent_temp_uri
+        self._temp_created_at = time.time()
+
+        return temp_base_uri, session_temp_uri, user_temp_uri, agent_temp_uri
+
+    async def _cleanup_temp_uris(self) -> None:
+        """Clean up all temp directories after commit."""
+        if self._temp_base_uri:
+            try:
+                await self._viking_fs.delete_temp(self._temp_base_uri, ctx=self.ctx)
+                logger.info(f"Cleaned up temp base: {self._temp_base_uri}")
+            except Exception as e:
+                logger.warning(f"Failed to cleanup temp {self._temp_base_uri}: {e}")
+            finally:
+                self._temp_base_uri = None
+                self._session_temp_uri = None
+                self._user_temp_uri = None
+                self._agent_temp_uri = None
+                self._temp_created_at = None
+
     async def commit_async(self) -> Dict[str, Any]:
-        """Async commit session: create archive, extract memories, persist."""
+        """Async commit session with Copy-on-Write pattern.
+
+        Process:
+        1. Copy: Copy existing session, user and agent directories to temp
+        2. Write: Make all changes in temp
+        3. Semantic: Trigger semantic processing
+        4. Switch: Atomically switch from temp to target (handled by SemanticProcessor)
+        """
         result = {
             "session_id": self.session_id,
             "status": "committed",
             "memories_extracted": 0,
             "active_count_updated": 0,
             "archived": False,
+            "temp_base_uri": None,
+            "session_temp_uri": None,
+            "user_temp_uri": None,
+            "agent_temp_uri": None,
+            "semantic_msg_ids": None,
             "stats": None,
         }
+
         if not self._messages:
             return result
 
-        # 1. Archive current messages
-        self._compression.compression_index += 1
-        messages_to_archive = self._messages.copy()
-
-        summary = await self._generate_archive_summary_async(messages_to_archive)
-        archive_abstract = self._extract_abstract_from_summary(summary)
-        archive_overview = summary
+        # ========== Phase 1: Copy ==========
+        temp_base_uri, session_temp_uri, user_temp_uri, agent_temp_uri = self._create_temp_uris()
+        result["temp_base_uri"] = temp_base_uri
+        result["session_temp_uri"] = session_temp_uri
+        result["user_temp_uri"] = user_temp_uri
+        result["agent_temp_uri"] = agent_temp_uri
 
-        await self._write_archive_async(
-            index=self._compression.compression_index,
-            messages=messages_to_archive,
-            abstract=archive_abstract,
-            overview=archive_overview,
-        )
+        try:
+            # 1.1 Copy existing session to temp
+            logger.info(f"Copying session {self.session_id} to temp: {session_temp_uri}")
+            try:
+                await self._viking_fs.copy_directory(
+                    from_uri=self._session_uri,
+                    to_uri=session_temp_uri,
+                    ctx=self.ctx,
+                )
+                logger.info(f"Session copied to temp: {session_temp_uri}")
+            except Exception as e:
+                if "not found" in str(e).lower():
+                    logger.info(f"Session {self.session_id} not found, creating new temp")
+                    await self._viking_fs.mkdir(session_temp_uri, exist_ok=True, ctx=self.ctx)
+                else:
+                    raise
+
+            # 1.2 Copy existing user directory to temp
+            user_uri = f"viking://user/{self.user.user_space_name()}"
+            logger.info(f"Copying user directory to temp: {user_temp_uri}")
+            try:
+                await self._viking_fs.copy_directory(
+                    from_uri=user_uri,
+                    to_uri=user_temp_uri,
+                    ctx=self.ctx,
+                )
+                logger.info(f"User directory copied to temp: {user_temp_uri}")
+            except Exception as e:
+                if "not found" in str(e).lower():
+                    logger.info("User directory not found, creating new temp")
+                    await self._viking_fs.mkdir(user_temp_uri, exist_ok=True, ctx=self.ctx)
+                else:
+                    raise
+
+            # 1.3 Copy existing agent directory to temp
+            agent_uri = f"viking://agent/{self.user.agent_space_name()}"
+            logger.info(f"Copying agent directory to temp: {agent_temp_uri}")
+            try:
+                await self._viking_fs.copy_directory(
+                    from_uri=agent_uri,
+                    to_uri=agent_temp_uri,
+                    ctx=self.ctx,
+                )
+                logger.info(f"Agent directory copied to temp: {agent_temp_uri}")
+            except Exception as e:
+                if "not found" in str(e).lower():
+                    logger.info("Agent directory not found, creating new temp")
+                    await self._viking_fs.mkdir(agent_temp_uri, exist_ok=True, ctx=self.ctx)
+                else:
+                    raise
 
-        self._compression.original_count += len(messages_to_archive)
-        result["archived"] = True
+        except Exception as e:
+            logger.error(f"Failed to copy directories to temp: {e}")
+            await self._cleanup_temp_uris()
+            raise
 
-        self._messages.clear()
-        logger.info(
-            f"Archived: {len(messages_to_archive)} messages → history/archive_{self._compression.compression_index:03d}/"
-        )
+        # ========== Phase 2: Write (all changes in temp) ==========
+        try:
+            # 2.1 Archive current messages to temp
+            self._compression.compression_index += 1
+            messages_to_archive = self._messages.copy()
 
-        # 2. Extract long-term memories
-        if self._session_compressor:
-            logger.info(
-                f"Starting memory extraction from {len(messages_to_archive)} archived messages"
-            )
-            memories = await self._session_compressor.extract_long_term_memories(
-                messages=messages_to_archive,
-                user=self.user,
-                session_id=self.session_id,
-                ctx=self.ctx,
-                strict_extract_errors=True,
-            )
-            logger.info(f"Extracted {len(memories)} memories")
-            result["memories_extracted"] = len(memories)
-            self._stats.memories_extracted += len(memories)
+            await self._write_archive_to_temp(
+                temp_uri=session_temp_uri,
+                index=self._compression.compression_index,
+                messages=messages_to_archive,
+            )
 
-        # 3. Write current messages to AGFS
-        await self._write_to_agfs_async(self._messages)
+            self._compression.original_count += len(messages_to_archive)
+            result["archived"] = True
 
-        # 4. Create relations
-        await self._write_relations_async()
+            self._messages.clear()
+            logger.info(
+                f"Archived: {len(messages_to_archive)} messages → "
+                f"{session_temp_uri}/history/archive_{self._compression.compression_index:03d}/"
+            )
 
-        # 5. Update active_count
-        active_count_updated = await self._update_active_counts_async()
-        result["active_count_updated"] = active_count_updated
+            # 2.2 Extract long-term memories (to temp user and agent directories)
+            if self._session_compressor:
+                logger.info(
+                    f"Starting memory extraction from {len(messages_to_archive)} archived messages"
+                )
+                memories = await self._session_compressor.extract_long_term_memories(
+                    messages=messages_to_archive,
+                    user=self.user,
+                    session_id=self.session_id,
+                    ctx=self.ctx,
+                    user_temp_uri=user_temp_uri,
+                    agent_temp_uri=agent_temp_uri,
+                )
+                logger.info(f"Extracted {len(memories)} memories to temp directories")
+                result["memories_extracted"] = len(memories)
+                self._stats.memories_extracted += len(memories)
 
-        # 6. Update statistics
+            # 2.3 Write current messages to temp
+            await self._write_messages_to_temp(session_temp_uri, self._messages)
+
+            logger.info(f"Session changes written to temp: {session_temp_uri}")
+            # 2.4 Update active_count
+            active_count_updated = await self._update_active_counts_async()
+            result["active_count_updated"] = active_count_updated
+        except Exception as e:
+            logger.error(f"Failed to write changes to temp: {e}")
+            await self._cleanup_temp_uris()
+            raise
+
+        # ========== Phase 3: Semantic + Switch ==========
+        try:
+            semantic_msg_ids = await self._enqueue_to_semantic_queue(
+                session_temp_uri=session_temp_uri,
+                user_temp_uri=user_temp_uri,
+                agent_temp_uri=agent_temp_uri,
+            )
+
+            logger.info(f"Session, user, agent enqueued to SemanticQueue: {semantic_msg_ids}")
+            result["semantic_msg_ids"] = semantic_msg_ids
+
+        except Exception as e:
+            logger.error(f"Failed to enqueue to SemanticQueue: {e}")
+            await self._cleanup_temp_uris()
+            raise
+
+        # ========== Update statistics ==========
         self._stats.compression_count = self._compression.compression_index
         result["stats"] = {
             "total_turns": self._stats.total_turns,
@@ -366,7 +519,7 @@ async def commit_async(self) -> Dict[str, Any]:
         }
         self._stats.total_tokens = 0
 
-        logger.info(f"Session {self.session_id} committed (async)")
+        logger.info(f"Session {self.session_id} committed (async with COW pattern)")
         return result
 
     def _update_active_counts(self) -> int:
@@ -550,6 +703,105 @@ def _write_archive(
 
         logger.debug(f"Written archive: {archive_uri}")
 
+    async def _write_archive_to_temp(
+        self,
+        temp_uri: str,
+        index: int,
+        messages: List[Message],
+    ) -> None:
+        """Write archive to temp directory.
+
+        Note: .abstract.md and .overview.md will be generated by Semantic DAG.
+        """
+        archive_uri = f"{temp_uri}/history/archive_{index:03d}"
+
+        lines = [m.to_jsonl() for m in messages]
+        await self._viking_fs.write_file(
+            uri=f"{archive_uri}/messages.jsonl",
+            content="\n".join(lines) + "\n",
+            ctx=self.ctx,
+        )
+
+        # Note: .abstract.md and .overview.md will be generated by Semantic DAG
+        # No need to manually create them here
+
+        logger.debug(f"Written archive to temp: {archive_uri}")
+
+    async def _write_messages_to_temp(self, temp_uri: str, messages: List[Message]) -> None:
+        """Write current messages to temp directory."""
+        lines = [m.to_jsonl() for m in messages]
+        content = "\n".join(lines) + "\n" if lines else ""
+
+        await self._viking_fs.write_file(
+            uri=f"{temp_uri}/messages.jsonl",
+            content=content,
+            ctx=self.ctx,
+        )
+
+    async def _enqueue_to_semantic_queue(
+        self,
+        session_temp_uri: str,
+        user_temp_uri: str,
+        agent_temp_uri: str,
+    ) -> List[str]:
+        """Enqueue session, user, and agent to SemanticQueue for L0/L1 generation.
+
+        The SemanticProcessor will handle:
+        1. Generate L0/L1 for session, user and agent directories
+        2. Atomically switch temp URIs to target URIs
+        3. Create usage relations
+        4. Clean up temp URIs
+
+        Returns:
+            List of message IDs [session_msg_id, user_msg_id, agent_msg_id]
+        """
+        from openviking.storage.queuefs import SemanticMsg, get_queue_manager
+
+        queue_manager = get_queue_manager()
+        semantic_queue = queue_manager.get_queue(queue_manager.SEMANTIC, allow_create=True)
+
+        user_target_uri = f"viking://user/{self.user.user_space_name()}"
+        agent_target_uri = f"viking://agent/{self.user.agent_space_name()}"
+
+        session_msg = SemanticMsg(
+            uri=session_temp_uri,
+            context_type="memory",
+            target_uri=self._session_uri,
+            account_id=self.ctx.account_id,
+            user_id=self.ctx.user.user_id,
+            agent_id=self.ctx.user.agent_id,
+            role=self.ctx.role.value,
+            recursive=True,
+        )
+
+        user_msg = SemanticMsg(
+            uri=user_temp_uri,
+            context_type="memory",
+            target_uri=user_target_uri,
+            account_id=self.ctx.account_id,
+            user_id=self.ctx.user.user_id,
+            agent_id=self.ctx.user.agent_id,
+            role=self.ctx.role.value,
+            recursive=True,
+        )
+
+        agent_msg = SemanticMsg(
+            uri=agent_temp_uri,
+            context_type="memory",
+            target_uri=agent_target_uri,
+            account_id=self.ctx.account_id,
+            user_id=self.ctx.user.user_id,
+            agent_id=self.ctx.user.agent_id,
+            role=self.ctx.role.value,
+            recursive=True,
+        )
+
+        await semantic_queue.enqueue(session_msg)
+        await semantic_queue.enqueue(user_msg)
+        await semantic_queue.enqueue(agent_msg)
+
+        return [session_msg.id, user_msg.id, agent_msg.id]
+
     async def _write_archive_async(
         self,
         index: int,
diff --git a/openviking/storage/collection_schemas.py b/openviking/storage/collection_schemas.py
index 90c61d07..817eb02e 100644
--- a/openviking/storage/collection_schemas.py
+++ b/openviking/storage/collection_schemas.py
@@ -263,3 +263,12 @@ async def on_dequeue(self, data: Optional[Dict[str, Any]]) -> Optional[Dict[str,
             traceback.print_exc()
             self.report_error(str(e), data)
             return None
+        finally:
+            if embedding_msg and embedding_msg.semantic_msg_id:
+                from openviking.storage.queuefs.embedding_tracker import EmbeddingTaskTracker
+
+                tracker = EmbeddingTaskTracker.get_instance()
+                try:
+                    await tracker.decrement(embedding_msg.semantic_msg_id)
+                except Exception as tracker_err:
+                    logger.warning(f"Failed to decrement embedding tracker: {tracker_err}")
diff --git a/openviking/storage/queuefs/__init__.py b/openviking/storage/queuefs/__init__.py
index b73a01d7..87514a83 100644
--- a/openviking/storage/queuefs/__init__.py
+++ b/openviking/storage/queuefs/__init__.py
@@ -2,6 +2,7 @@
 # SPDX-License-Identifier: Apache-2.0
 from .embedding_msg import EmbeddingMsg
 from .embedding_queue import EmbeddingQueue
+from .embedding_tracker import EmbeddingTaskTracker
 from .named_queue import NamedQueue, QueueError, QueueStatus
 from .queue_manager import QueueManager, get_queue_manager, init_queue_manager
 from .semantic_dag import SemanticDagExecutor
@@ -18,6 +19,7 @@
     "QueueError",
     "EmbeddingQueue",
     "EmbeddingMsg",
+    "EmbeddingTaskTracker",
     "SemanticQueue",
     "SemanticDagExecutor",
     "SemanticMsg",
diff --git a/openviking/storage/queuefs/embedding_msg.py b/openviking/storage/queuefs/embedding_msg.py
index 19b8381e..94e93a2c 100644
--- a/openviking/storage/queuefs/embedding_msg.py
+++ b/openviking/storage/queuefs/embedding_msg.py
@@ -2,7 +2,7 @@
 # SPDX-License-Identifier: Apache-2.0
 import json
 from dataclasses import asdict, dataclass
-from typing import Any, Dict, List, Union
+from typing import Any, Dict, List, Optional, Union
 from uuid import uuid4
 
 
@@ -10,11 +10,18 @@ class EmbeddingMsg:
     message: Union[str, List[Dict[str, Any]]]
     context_data: Dict[str, Any]
+    semantic_msg_id: Optional[str] = None
 
-    def __init__(self, message: Union[str, List[Dict[str, Any]]], context_data: Dict[str, Any]):
+    def __init__(
+        self,
+        message: Union[str, List[Dict[str, Any]]],
+        context_data: Dict[str, Any],
+        semantic_msg_id: Optional[str] = None,
+    ):
         self.id = str(uuid4())
         self.message = message
         self.context_data = context_data
+        self.semantic_msg_id = semantic_msg_id
 
     def to_dict(self) -> Dict[str, Any]:
         """Convert embedding message to dictionary format."""
@@ -30,6 +37,7 @@ def from_dict(cls, data: Dict[str, Any]) -> "EmbeddingMsg":
         obj = EmbeddingMsg(
             message=data["message"],
             context_data=data["context_data"],
+            semantic_msg_id=data.get("semantic_msg_id"),
         )
         obj.id = data.get("id", obj.id)
         return obj
diff --git a/openviking/storage/queuefs/embedding_tracker.py b/openviking/storage/queuefs/embedding_tracker.py
new file mode 100644
index 00000000..4149d8d1
--- /dev/null
+++ b/openviking/storage/queuefs/embedding_tracker.py
@@ -0,0 +1,193 @@
+# Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd.
+# SPDX-License-Identifier: Apache-2.0
+"""Embedding Task Tracker for tracking embedding task completion status."""
+
+import asyncio
+from dataclasses import dataclass, field
+from typing import Any, Callable, ClassVar, Dict, Optional
+
+from openviking_cli.utils.logger import get_logger
+
+logger = get_logger(__name__)
+
+
+@dataclass
+class EmbeddingTaskTracker:
+    """Track embedding task completion status for each SemanticMsg.
+
+    This tracker maintains a global registry of embedding tasks associated
+    with each SemanticMsg. When all embedding tasks for a SemanticMsg are
+    completed, it triggers the registered callback and removes the entry.
+    """
+
+    _instance: ClassVar[Optional["EmbeddingTaskTracker"]] = None
+    _lock: asyncio.Lock = field(default_factory=asyncio.Lock)
+    _tasks: Dict[str, Dict[str, Any]] = field(default_factory=dict)
+
+    @classmethod
+    def get_instance(cls) -> "EmbeddingTaskTracker":
+        """Get the singleton instance of EmbeddingTaskTracker."""
+        if cls._instance is None:
+            cls._instance = cls()
+        return cls._instance
+
+    async def register(
+        self,
+        semantic_msg_id: str,
+        total_count: int,
+        on_complete: Optional[Callable[[], Any]] = None,
+        metadata: Optional[Dict[str, Any]] = None,
+    ) -> None:
+        """Register a SemanticMsg with its total embedding task count.
+
+        Args:
+            semantic_msg_id: The ID of the SemanticMsg
+            total_count: Total number of embedding tasks for this SemanticMsg
+            on_complete: Optional callback when all tasks complete
+            metadata: Optional metadata to store with the task
+        """
+        async with self._lock:
+            self._tasks[semantic_msg_id] = {
+                "remaining": total_count,
+                "total": total_count,
+                "on_complete": on_complete,
+                "metadata": metadata or {},
+            }
+            logger.info(
+                f"Registered embedding tracker for SemanticMsg {semantic_msg_id}: "
+                f"{total_count} tasks"
+            )
+
+            if total_count <= 0:
+                del self._tasks[semantic_msg_id]
+                logger.info(
+                    f"No embedding tasks for SemanticMsg {semantic_msg_id}, "
+                    f"triggering on_complete immediately"
+                )
+
+        if total_count <= 0 and on_complete:
+            try:
+                result = on_complete()
+                if asyncio.iscoroutine(result):
+                    await result
+            except Exception as e:
+                logger.error(
+                    f"Error in completion callback for {semantic_msg_id}: {e}",
+                    exc_info=True,
+                )
+
+    async def increment(self, semantic_msg_id: str) -> Optional[int]:
+        """Increment the remaining task count for a SemanticMsg.
+
+        This method should be called when a new embedding task is added
+        for an already registered SemanticMsg.
+
+        Args:
+            semantic_msg_id: The ID of the SemanticMsg
+
+        Returns:
+            The remaining count after increment, or None if not found
+        """
+        async with self._lock:
+            if semantic_msg_id not in self._tasks:
+                return None
+
+            task_info = self._tasks[semantic_msg_id]
+            task_info["remaining"] += 1
+            task_info["total"] += 1
+            remaining = task_info["remaining"]
+
+        return remaining
+
+    async def decrement(self, semantic_msg_id: str) -> Optional[int]:
+        """Decrement the remaining task count for a SemanticMsg.
+
+        This method should be called when an embedding task is completed.
+        When the count reaches zero, the registered callback is executed
+        and the entry is removed from the tracker.
+
+        Args:
+            semantic_msg_id: The ID of the SemanticMsg
+
+        Returns:
+            The remaining count after decrement, or None if not found
+        """
+        on_complete = None
+
+        async with self._lock:
+            if semantic_msg_id not in self._tasks:
+                return None
+
+            task_info = self._tasks[semantic_msg_id]
+            task_info["remaining"] -= 1
+            remaining = task_info["remaining"]
+
+            if remaining <= 0:
+                on_complete = task_info.get("on_complete")
+
+                del self._tasks[semantic_msg_id]
+                logger.info(
+                    f"All embedding tasks ({task_info['total']}) completed for SemanticMsg {semantic_msg_id}"
+                )
+
+        if on_complete:
+            try:
+                result = on_complete()
+                if asyncio.iscoroutine(result):
+                    await result
+            except Exception as e:
+                logger.error(
+                    f"Error in completion callback for {semantic_msg_id}: {e}",
+                    exc_info=True,
+                )
+        return remaining
+
+    async def get_status(self, semantic_msg_id: str) -> Optional[Dict[str, Any]]:
+        """Get the current status of a SemanticMsg's embedding tasks.
+
+        Args:
+            semantic_msg_id: The ID of the SemanticMsg
+
+        Returns:
+            Dict with 'remaining', 'total', 'metadata' or None if not found
+        """
+        async with self._lock:
+            if semantic_msg_id not in self._tasks:
+                return None
+            task_info = self._tasks[semantic_msg_id]
+            return {
+                "remaining": task_info["remaining"],
+                "total": task_info["total"],
+                "metadata": task_info.get("metadata", {}),
+            }
+
+    async def remove(self, semantic_msg_id: str) -> bool:
+        """Remove a SemanticMsg from the tracker.
+
+        Args:
+            semantic_msg_id: The ID of the SemanticMsg
+
+        Returns:
+            True if removed, False if not found
+        """
+        async with self._lock:
+            if semantic_msg_id in self._tasks:
+                del self._tasks[semantic_msg_id]
+                return True
+            return False
+
+    async def get_all_tracked(self) -> Dict[str, Dict[str, Any]]:
+        """Get all currently tracked SemanticMsgs.
+
+        Returns:
+            Dict of semantic_msg_id -> task info
+        """
+        async with self._lock:
+            return {
+                msg_id: {
+                    "remaining": info["remaining"],
+                    "total": info["total"],
+                    "metadata": info.get("metadata", {}),
+                }
+                for msg_id, info in self._tasks.items()
+            }
diff --git a/openviking/storage/queuefs/semantic_dag.py b/openviking/storage/queuefs/semantic_dag.py
index 0307521f..397250c5 100644
--- a/openviking/storage/queuefs/semantic_dag.py
+++ b/openviking/storage/queuefs/semantic_dag.py
@@ -48,11 +48,19 @@ def __init__(
         self,
         processor,
         context_type: str,
         max_concurrent_llm: int,
         ctx: RequestContext,
+        incremental_update: bool = False,
+        target_uri: Optional[str] = None,
+        semantic_msg_id: Optional[str] = None,
+        recursive: bool = True,
     ):
         self._processor = processor
         self._context_type = context_type
         self._max_concurrent_llm = max_concurrent_llm
         self._ctx = ctx
+        self._incremental_update = incremental_update
+        self._target_uri = target_uri
+        self._semantic_msg_id = semantic_msg_id
+        self._recursive = recursive
         self._llm_sem = asyncio.Semaphore(max_concurrent_llm)
         self._viking_fs = get_viking_fs()
         self._nodes: Dict[str, DirNode] = {}
@@ -79,8 +87,10 @@ async def _dispatch_dir(self, dir_uri: str, parent_uri: Optional[str]) -> None:
             children_dirs, file_paths = await self._list_dir(dir_uri)
             file_index = {path: idx for idx, path in enumerate(file_paths)}
             child_index = {path: idx for idx, path in enumerate(children_dirs)}
-            pending = len(children_dirs) + len(file_paths)
-
+            if self._recursive:
+                pending = len(children_dirs) + len(file_paths)
+            else:
+                pending = len(file_paths)
             node = DirNode(
                 uri=dir_uri,
                 children_dirs=children_dirs,
@@ -108,8 +118,10 @@ async def _dispatch_dir(self, dir_uri: str, parent_uri: Optional[str]) -> None:
                 self._stats.in_progress_nodes += 1
                 asyncio.create_task(self._file_summary_task(dir_uri, file_path))
 
-            for child_uri in children_dirs:
-                asyncio.create_task(self._dispatch_dir(child_uri, dir_uri))
+            if children_dirs and self._recursive:
+                for child_uri in children_dirs:
+                    asyncio.create_task(self._dispatch_dir(child_uri, dir_uri))
         except Exception as e:
             logger.error(f"Failed to dispatch directory {dir_uri}: {e}", exc_info=True)
             if parent_uri:
@@ -141,13 +153,107 @@ async def _list_dir(self, uri: str) -> tuple[list[str], list[str]]:
         return children_dirs, file_paths
 
+    def _get_target_file_path(self, current_uri: str) -> Optional[str]:
+        if not self._incremental_update:
+            return None
+        if not self._target_uri or not self._root_uri:
+            logger.warning(
+                f"Invalid target_uri or root_uri for incremental update: target_uri={self._target_uri}, root_uri={self._root_uri}"
+            )
+            return None
+        try:
+            relative_path = current_uri[len(self._root_uri) :]
+            if relative_path.startswith("/"):
+                relative_path = relative_path[1:]
+            return f"{self._target_uri}/{relative_path}" if relative_path else self._target_uri
+        except Exception:
+            return None
+
+    async def _check_file_content_changed(self, file_path: str) -> bool:
+        target_path = self._get_target_file_path(file_path)
+        if not target_path:
+            return True
+        try:
+            current_content = await self._viking_fs.read_file(file_path, ctx=self._ctx)
+            target_content = await self._viking_fs.read_file(target_path, ctx=self._ctx)
+            return current_content != target_content
current_content != target_content + except Exception: + return True + + async def _read_existing_summary(self, file_path: str) -> Optional[Dict[str, str]]: + target_path = self._get_target_file_path(file_path) + if not target_path: + return None + try: + vector_store = self._viking_fs._get_vector_store() + if not vector_store: + return None + records = await vector_store.get_context_by_uri( + account_id=self._ctx.account_id, + uri=target_path, + limit=1, + ) + if records and len(records) > 0: + record = records[0] + summary = record.get("abstract", "") + if summary: + file_name = file_path.split("/")[-1] + return {"name": file_name, "summary": summary} + except Exception: + pass + return None + + async def _check_dir_children_changed( + self, dir_uri: str, current_files: List[str], current_dirs: List[str] + ) -> bool: + target_path = self._get_target_file_path(dir_uri) + if not target_path: + return True + try: + target_dirs, target_files = await self._list_dir(target_path) + current_file_names = {f.split("/")[-1] for f in current_files} + target_file_names = {f.split("/")[-1] for f in target_files} + if current_file_names != target_file_names: + return True + current_dir_names = {d.split("/")[-1] for d in current_dirs} + target_dir_names = {d.split("/")[-1] for d in target_dirs} + if current_dir_names != target_dir_names: + return True + for current_file in current_files: + if await self._check_file_content_changed(current_file): + return True + return False + except Exception: + return True + + async def _read_existing_overview_abstract( + self, dir_uri: str + ) -> tuple[Optional[str], Optional[str]]: + target_path = self._get_target_file_path(dir_uri) + if not target_path: + return None, None + try: + overview = await self._viking_fs.read_file(f"{target_path}/.overview.md", ctx=self._ctx) + abstract = await self._viking_fs.read_file(f"{target_path}/.abstract.md", ctx=self._ctx) + return overview, abstract + except Exception: + return None, None + async def _file_summary_task(self, parent_uri: str, file_path: str) -> None: """Generate file summary and notify parent completion.""" + file_name = file_path.split("/")[-1] + need_vectorize = True try: - summary_dict = await self._processor._generate_single_file_summary( - file_path, llm_sem=self._llm_sem, ctx=self._ctx - ) + summary_dict = None + if self._incremental_update: + content_changed = await self._check_file_content_changed(file_path) + + if not content_changed: + summary_dict = await self._read_existing_summary(file_path) + need_vectorize = False + if summary_dict is None: + summary_dict = await self._processor._generate_single_file_summary( + file_path, llm_sem=self._llm_sem, ctx=self._ctx + ) except Exception as e: logger.warning(f"Failed to generate summary for {file_path}: {e}") summary_dict = {"name": file_name, "summary": ""} @@ -155,21 +261,21 @@ async def _file_summary_task(self, parent_uri: str, file_path: str) -> None: self._stats.done_nodes += 1 self._stats.in_progress_nodes = max(0, self._stats.in_progress_nodes - 1) - await self._on_file_done(parent_uri, file_path, summary_dict) - - # Vectorize file as soon as summary is ready to avoid waiting for overview. 
try: - asyncio.create_task( - self._processor._vectorize_single_file( - parent_uri=parent_uri, - context_type=self._context_type, - file_path=file_path, - summary_dict=summary_dict, - ctx=self._ctx, + if need_vectorize: + asyncio.create_task( + self._processor._vectorize_single_file( + parent_uri=parent_uri, + context_type=self._context_type, + file_path=file_path, + summary_dict=summary_dict, + ctx=self._ctx, + semantic_msg_id=self._semantic_msg_id, + ) ) - ) except Exception as e: logger.error(f"Failed to schedule vectorization for {file_path}: {e}", exc_info=True) + await self._on_file_done(parent_uri, file_path, summary_dict) async def _on_file_done( self, parent_uri: str, file_path: str, summary_dict: Dict[str, str] @@ -241,17 +347,27 @@ async def _overview_task(self, dir_uri: str) -> None: node = self._nodes.get(dir_uri) if not node: return - - async with node.lock: - file_summaries = self._finalize_file_summaries(node) - children_abstracts = self._finalize_children_abstracts(node) - + need_vectorize = True try: - async with self._llm_sem: - overview = await self._processor._generate_overview( - dir_uri, file_summaries, children_abstracts + overview = None + abstract = None + if self._incremental_update: + children_changed = await self._check_dir_children_changed( + dir_uri, node.file_paths, node.children_dirs ) - abstract = self._processor._extract_abstract_from_overview(overview) + + if not children_changed: + need_vectorize = False + overview, abstract = await self._read_existing_overview_abstract(dir_uri) + if overview is None or abstract is None: + async with node.lock: + file_summaries = self._finalize_file_summaries(node) + children_abstracts = self._finalize_children_abstracts(node) + async with self._llm_sem: + overview = await self._processor._generate_overview( + dir_uri, file_summaries, children_abstracts + ) + abstract = self._processor._extract_abstract_from_overview(overview) try: await self._viking_fs.write_file(f"{dir_uri}/.overview.md", overview, ctx=self._ctx) @@ -260,11 +376,19 @@ async def _overview_task(self, dir_uri: str) -> None: logger.warning(f"Failed to write overview/abstract for {dir_uri}: {e}") try: - await self._processor._vectorize_directory_simple( - dir_uri, self._context_type, abstract, overview, ctx=self._ctx - ) + if need_vectorize: + asyncio.create_task( + self._processor._vectorize_directory( + dir_uri, + self._context_type, + abstract, + overview, + ctx=self._ctx, + semantic_msg_id=self._semantic_msg_id, + ) + ) except Exception as e: - logger.error(f"Failed to vectorize directory {dir_uri}: {e}", exc_info=True) + logger.error(f"Failed to schedule vectorization for {dir_uri}: {e}", exc_info=True) except Exception as e: logger.error(f"Failed to generate overview for {dir_uri}: {e}", exc_info=True) diff --git a/openviking/storage/queuefs/semantic_msg.py b/openviking/storage/queuefs/semantic_msg.py index 5f7bd730..7517dd3a 100644 --- a/openviking/storage/queuefs/semantic_msg.py +++ b/openviking/storage/queuefs/semantic_msg.py @@ -16,7 +16,7 @@ class SemanticMsg: Attributes: id: Unique identifier (UUID) uri: Directory URI to process - context_type: Type of context (resource, memory, skill) + context_type: Type of context (resource, memory, skill, session) status: Processing status (pending/processing/completed) timestamp: Creation timestamp recursive: Whether to recursively process subdirectories. 
@@ -27,7 +27,7 @@ class SemanticMsg: id: str # UUID uri: str # Directory URI - context_type: str # resource, memory, skill + context_type: str # resource, memory, skill, session status: str = "pending" # pending/processing/completed timestamp: int = int(datetime.now().timestamp()) recursive: bool = True # Whether to recursively process subdirectories @@ -37,6 +37,7 @@ class SemanticMsg: role: str = "root" # Additional flags skip_vectorization: bool = False + target_uri: str = "" def __init__( self, @@ -48,6 +49,7 @@ def __init__( agent_id: str = "default", role: str = "root", skip_vectorization: bool = False, + target_uri: str = "", ): self.id = str(uuid4()) self.uri = uri @@ -58,6 +60,7 @@ def __init__( self.agent_id = agent_id self.role = role self.skip_vectorization = skip_vectorization + self.target_uri = target_uri def to_dict(self) -> Dict[str, Any]: """Convert object to dictionary.""" @@ -93,6 +96,7 @@ def from_dict(cls, data: Dict[str, Any]) -> "SemanticMsg": agent_id=data.get("agent_id", "default"), role=data.get("role", "root"), skip_vectorization=data.get("skip_vectorization", False), + target_uri=data.get("target_uri", ""), ) if "id" in data and data["id"]: obj.id = data["id"] diff --git a/openviking/storage/queuefs/semantic_processor.py b/openviking/storage/queuefs/semantic_processor.py index 59700783..56a05b14 100644 --- a/openviking/storage/queuefs/semantic_processor.py +++ b/openviking/storage/queuefs/semantic_processor.py @@ -3,7 +3,8 @@ """SemanticProcessor: Processes messages from SemanticQueue, generates .abstract.md and .overview.md.""" import asyncio -from typing import Any, Dict, List, Optional, Tuple +from dataclasses import dataclass, field +from typing import Any, Awaitable, Callable, Dict, List, Optional, Set, Tuple from openviking.parse.parsers.constants import ( CODE_EXTENSIONS, @@ -29,9 +30,22 @@ from openviking_cli.utils.config import get_openviking_config from openviking_cli.utils.logger import get_logger +from .embedding_tracker import EmbeddingTaskTracker + logger = get_logger(__name__) +@dataclass +class DiffResult: + """Directory diff result for sync operations.""" + + added_files: List[str] = field(default_factory=list) + deleted_files: List[str] = field(default_factory=list) + updated_files: List[str] = field(default_factory=list) + added_dirs: List[str] = field(default_factory=list) + deleted_dirs: List[str] = field(default_factory=list) + + class SemanticProcessor(DequeueHandlerBase): """ Semantic processor, generates .abstract.md and .overview.md bottom-up. 
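Editor's note: `DiffResult` is plain data over relative paths; a small sketch of the set arithmetic `_compute_diff` uses to populate it (paths illustrative):

    old_files = {"a.md", "b.md"}
    new_files = {"b.md", "c.md"}
    diff = DiffResult(
        added_files=sorted(new_files - old_files),    # ["c.md"]
        deleted_files=sorted(old_files - new_files),  # ["a.md"]
    )
    assert diff.updated_files == []  # default_factory=list gives a fresh list per instance
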
@@ -103,55 +117,21 @@ def _detect_file_type(self, file_name: str) -> str:
         # Default to other
         return FILE_TYPE_OTHER
 
-    async def _enqueue_semantic_msg(self, msg: SemanticMsg) -> None:
-        """Enqueue a SemanticMsg to the semantic queue for processing."""
-        from openviking.storage.queuefs import get_queue_manager
-
-        queue_manager = get_queue_manager()
-        semantic_queue = queue_manager.get_queue(queue_manager.SEMANTIC)
-        # The queue manager returns SemanticQueue but method signature says NamedQueue
-        # We need to ignore the type error for the enqueue call
-        await semantic_queue.enqueue(msg)  # type: ignore
-        logger.debug(f"Enqueued semantic message for processing: {msg.uri}")
-
-    async def _collect_directory_info(
-        self,
-        uri: str,
-        result: List[Tuple[str, List[str], List[str]]],
-    ) -> None:
-        """Recursively collect directory info, post-order traversal ensures bottom-up order."""
+    async def _check_file_content_changed(
+        self, file_path: str, target_file: str, ctx: Optional[RequestContext] = None
+    ) -> bool:
+        """Check if file content has changed compared to target file."""
         viking_fs = get_viking_fs()
         try:
-            entries = await viking_fs.ls(uri, ctx=self._current_ctx)
-        except Exception as e:
-            logger.warning(f"Failed to list directory {uri}: {e}")
-            return
-
-        children_uris = []
-        file_paths = []
-
-        for entry in entries:
-            name = entry.get("name", "")
-            if not name or name.startswith(".") or name in [".", ".."]:
-                continue
-
-            item_uri = VikingURI(uri).join(name).uri
-
-            if entry.get("isDir", False):
-                # Child directory
-                children_uris.append(item_uri)
-                # Recursively collect children
-                await self._collect_directory_info(item_uri, result)
-            else:
-                # File (not starting with .)
-                file_paths.append(item_uri)
-
-        # Add current directory info
-        result.append((uri, children_uris, file_paths))
+            current_content = await viking_fs.read_file(file_path, ctx=ctx)
+            target_content = await viking_fs.read_file(target_file, ctx=ctx)
+            return current_content != target_content
+        except Exception:
+            return True
 
     async def on_dequeue(self, data: Optional[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
         """Process dequeued SemanticMsg, recursively process all subdirectories."""
+        msg = None
         try:
             import json
 
@@ -166,147 +146,361 @@
             msg = SemanticMsg.from_dict(data)
             self._current_msg = msg
             self._current_ctx = self._ctx_from_semantic_msg(msg)
-            logger.info(
-                f"Processing semantic generation for: {msg.uri} (recursive={msg.recursive})"
+            logger.info(f"Processing semantic generation for: {msg}")
+
+            # Check if target_uri exists, auto-detect incremental update
+            is_incremental = False
+            viking_fs = get_viking_fs()
+            if msg.target_uri:
+                target_exists = await viking_fs.exists(msg.target_uri, ctx=self._current_ctx)
+                if target_exists:
+                    is_incremental = True
+                    logger.info(f"Target URI exists, using incremental update: {msg.target_uri}")
+
+            tracker = EmbeddingTaskTracker.get_instance()
+            on_complete = self._create_sync_diff_callback(
+                root_uri=msg.uri,
+                target_uri=msg.target_uri,
+                ctx=self._current_ctx,
             )
-
-            if msg.recursive:
-                executor = SemanticDagExecutor(
-                    processor=self,
-                    context_type=msg.context_type,
-                    max_concurrent_llm=self.max_concurrent_llm,
-                    ctx=self._current_ctx,
-                )
-                self._dag_executor = executor
-                await executor.run(msg.uri)
-                logger.info(f"Completed semantic generation for: {msg.uri}")
-                self.report_success()
-                return None
-            else:
-                # Non-recursive processing: directly process this directory
-                children_uris = []
-
file_paths = [] - - # Collect immediate children info only (no recursion) - viking_fs = get_viking_fs() - try: - entries = await viking_fs.ls(msg.uri, ctx=self._current_ctx) - for entry in entries: - name = entry.get("name", "") - if not name or name.startswith(".") or name in [".", ".."]: - continue - - item_uri = VikingURI(msg.uri).join(name).uri - - if entry.get("isDir", False): - children_uris.append(item_uri) - else: - file_paths.append(item_uri) - except Exception as e: - logger.warning(f"Failed to list directory {msg.uri}: {e}") - - # Process this directory - await self._process_single_directory( - uri=msg.uri, - context_type=msg.context_type, - children_uris=children_uris, - file_paths=file_paths, - ) - - logger.info(f"Completed semantic generation for: {msg.uri}") - self.report_success() - return None + # Register task with tracker, total_count=1 for root URI + await tracker.register( + semantic_msg_id=msg.id, + total_count=1, + on_complete=on_complete, + metadata={ + "uri": msg.uri, + }, + ) + executor = SemanticDagExecutor( + processor=self, + context_type=msg.context_type, + max_concurrent_llm=self.max_concurrent_llm, + ctx=self._current_ctx, + incremental_update=is_incremental, + target_uri=msg.target_uri, + semantic_msg_id=msg.id, + recursive=msg.recursive, + ) + self._dag_executor = executor + await executor.run(msg.uri) + logger.info(f"Completed semantic generation for: {msg.uri}") + self.report_success() + return None except Exception as e: logger.error(f"Failed to process semantic message: {e}", exc_info=True) self.report_error(str(e), data) return None finally: + # Decrement task counter for root URI + if msg is not None: + tracker = EmbeddingTaskTracker.get_instance() + await tracker.decrement( + semantic_msg_id=msg.id, + ) self._current_msg = None + self._current_ctx = None def get_dag_stats(self) -> Optional["DagStats"]: if not self._dag_executor: return None return self._dag_executor.get_stats() - async def _process_single_directory( + def _create_sync_diff_callback( + self, + root_uri: str, + target_uri: str, + ctx: RequestContext, + ) -> Callable[[], Awaitable[None]]: + """ + Create a callback function to sync directory differences. + + This callback compares root_uri (new content) with target_uri (old content), + handles added/updated/deleted files, then cleans up root_uri. 
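+
+        Illustrative call pattern (URIs hypothetical):
+
+            callback = self._create_sync_diff_callback(root_uri, target_uri, ctx)
+            await callback()  # diff -> sync files into target -> remove root_uri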
+ + Args: + root_uri: Source directory URI (new content) + target_uri: Target directory URI (old content) + ctx: Request context (captured at callback creation time) + + Returns: + Async callback function + """ + + async def sync_diff_callback() -> None: + + try: + viking_fs = get_viking_fs() + + root_tree = await self._collect_tree_info(root_uri, ctx=ctx) + + target_tree = await self._collect_tree_info(target_uri, ctx=ctx) + diff = await self._compute_diff( + root_tree, target_tree, root_uri, target_uri, ctx=ctx + ) + logger.info( + f"[SyncDiff] Diff computed: " + f"added_files={len(diff.added_files)}, " + f"deleted_files={len(diff.deleted_files)}, " + f"updated_files={len(diff.updated_files)}, " + f"added_dirs={len(diff.added_dirs)}, " + f"deleted_dirs={len(diff.deleted_dirs)}" + ) + await self._execute_sync_operations(diff, root_uri, target_uri, ctx=ctx) + try: + await viking_fs.rm(root_uri, recursive=True, ctx=ctx) + except Exception as e: + logger.warning(f"[SyncDiff] Failed to delete root directory {root_uri}: {e}") + + except Exception as e: + logger.error( + f"[SyncDiff] Error in sync_diff_callback: " + f"root_uri={root_uri}, target_uri={target_uri} " + f"error={e}", + exc_info=True, + ) + + return sync_diff_callback + + async def _collect_tree_info( self, uri: str, - context_type: str, - children_uris: List[str], - file_paths: List[str], - ) -> None: - """Process single directory, generate .abstract.md and .overview.md.""" + ctx: Optional[RequestContext] = None, + ) -> Dict[str, Tuple[List[str], List[str]]]: + """ + Recursively collect directory tree information. + + Args: + uri: Directory URI + ctx: Request context + + Returns: + Dictionary: {dir_uri: ([subdir_uris], [file_uris])} + """ viking_fs = get_viking_fs() + result: Dict[str, Tuple[List[str], List[str]]] = {} + total_dirs = 0 + total_files = 0 + + async def collect_recursive(current_uri: str, depth: int = 0) -> None: + nonlocal total_dirs, total_files + indent = " " * depth + try: + entries = await viking_fs.ls(current_uri, show_all_hidden=True, ctx=ctx) + except Exception as e: + logger.warning(f"[SyncDiff]{indent} Failed to list {current_uri}: {e}") + return + + sub_dirs: List[str] = [] + files: List[str] = [] + + for entry in entries: + name = entry.get("name", "") + if not name or name in [".", ".."]: + continue + if name.startswith(".") and name not in [".abstract.md", ".overview.md"]: + continue + + item_uri = VikingURI(current_uri).join(name).uri + + if entry.get("isDir", False): + sub_dirs.append(item_uri) + total_dirs += 1 + await collect_recursive(item_uri, depth + 1) + else: + files.append(item_uri) + total_files += 1 + + result[current_uri] = (sub_dirs, files) + + await collect_recursive(uri) + return result + + async def _compute_diff( + self, + root_tree: Dict[str, Tuple[List[str], List[str]]], + target_tree: Dict[str, Tuple[List[str], List[str]]], + root_uri: str, + target_uri: str, + ctx: Optional[RequestContext] = None, + ) -> DiffResult: + """ + Compute differences between two directory trees. + + Args: + root_tree: Directory tree from root_uri + target_tree: Directory tree from target_uri + root_uri: Source directory URI + target_uri: Target directory URI + ctx: Request context + + Returns: + DiffResult with added/deleted/updated files and directories + """ - # 1. 
Collect .abstract.md from subdirectories (already processed earlier) - children_abstracts = await self._collect_children_abstracts(children_uris) + def get_relative_path(uri: str, base_uri: str) -> str: + if uri.startswith(base_uri): + rel = uri[len(base_uri) :] + return rel.lstrip("/") + return uri + + root_files: Set[str] = set() + root_dirs: Set[str] = set() + target_files: Set[str] = set() + target_dirs: Set[str] = set() + + for dir_uri, (sub_dirs, files) in root_tree.items(): + rel_dir = get_relative_path(dir_uri, root_uri) + if rel_dir: + root_dirs.add(rel_dir) + for f in files: + root_files.add(get_relative_path(f, root_uri)) + for d in sub_dirs: + root_dirs.add(get_relative_path(d, root_uri)) + + for dir_uri, (sub_dirs, files) in target_tree.items(): + rel_dir = get_relative_path(dir_uri, target_uri) + if rel_dir: + target_dirs.add(rel_dir) + for f in files: + target_files.add(get_relative_path(f, target_uri)) + for d in sub_dirs: + target_dirs.add(get_relative_path(d, target_uri)) + + added_files_rel = root_files - target_files + deleted_files_rel = target_files - root_files + common_files = root_files & target_files + + added_dirs_rel = root_dirs - target_dirs + deleted_dirs_rel = target_dirs - root_dirs + + updated_files: List[str] = [] + for rel_file in common_files: + root_file = f"{root_uri}/{rel_file}" + target_file = f"{target_uri}/{rel_file}" + try: + if await self._check_file_content_changed(root_file, target_file, ctx=ctx): + updated_files.append(root_file) + except Exception as e: + logger.warning( + f"[SyncDiff] Failed to compare file content for {rel_file}: {e}, " + f"treating as unchanged" + ) - # 2. Concurrently generate summaries for files in directory - file_summaries = await self._generate_file_summaries( - file_paths, context_type=context_type, parent_uri=uri, enqueue_files=True + added_files = [f"{root_uri}/{f}" for f in added_files_rel] + deleted_files = [f"{target_uri}/{f}" for f in deleted_files_rel] + added_dirs = [f"{root_uri}/{d}" for d in added_dirs_rel] + deleted_dirs = [f"{target_uri}/{d}" for d in deleted_dirs_rel] + + result = DiffResult( + added_files=added_files, + deleted_files=deleted_files, + updated_files=updated_files, + added_dirs=added_dirs, + deleted_dirs=deleted_dirs, ) - # 3. Generate .overview.md (contains brief description) - overview = await self._generate_overview(uri, file_summaries, children_abstracts) + return result - # 4. Extract abstract from overview - abstract = self._extract_abstract_from_overview(overview) + async def _execute_sync_operations( + self, + diff: DiffResult, + root_uri: str, + target_uri: str, + ctx: Optional[RequestContext] = None, + ) -> None: + """ + Execute sync operations based on diff result. - # 5. Write files - await viking_fs.write_file(f"{uri}/.overview.md", overview, ctx=self._current_ctx) - await viking_fs.write_file(f"{uri}/.abstract.md", abstract, ctx=self._current_ctx) + Processing order: + 1. Delete files in target that don't exist in root + 2. Move added/updated files from root to target + 3. Delete directories in target that don't exist in root - logger.debug(f"Generated overview and abstract for {uri}") + Args: + diff: DiffResult containing operations to perform + root_uri: Source directory URI + target_uri: Target directory URI + ctx: Request context + """ + viking_fs = get_viking_fs() - # 6. 
Vectorize directory - try: - await self._vectorize_directory_simple(uri, context_type, abstract, overview) - except Exception as e: - logger.error(f"Failed to vectorize directory {uri}: {e}", exc_info=True) + def map_to_target(root_item_uri: str) -> str: + if root_item_uri.startswith(root_uri): + rel = root_item_uri[len(root_uri) :] + return f"{target_uri}{rel}" if rel else target_uri + return root_item_uri + + total_deleted = 0 + total_moved = 0 + total_failed = 0 + + for i, deleted_file in enumerate(diff.deleted_files, 1): + try: + await viking_fs.rm(deleted_file, ctx=ctx) + total_deleted += 1 + except Exception as e: + total_failed += 1 + logger.warning( + f"[SyncDiff] Failed to delete file [{i}/{len(diff.deleted_files)}]: {deleted_file}, error={e}" + ) + + for i, updated_file in enumerate(diff.updated_files, 1): + target_file = map_to_target(updated_file) + try: + await viking_fs.rm(target_file, ctx=ctx) + except Exception as e: + logger.warning( + f"[SyncDiff] Failed to remove old file [{i}/{len(diff.updated_files)}]: {target_file}, error={e}" + ) - async def _collect_children_abstracts(self, children_uris: List[str]) -> List[Dict[str, str]]: + files_to_move = diff.added_files + diff.updated_files + for i, root_file in enumerate(files_to_move, 1): + target_file = map_to_target(root_file) + try: + target_parent = VikingURI(target_file).parent + if target_parent: + try: + await viking_fs.mkdir(target_parent.uri, exist_ok=True, ctx=ctx) + except Exception as mkdir_error: + logger.debug( + f"[SyncDiff] Parent dir creation skipped (may already exist): {mkdir_error}" + ) + await viking_fs.mv(root_file, target_file, ctx=ctx) + total_moved += 1 + except Exception as e: + total_failed += 1 + logger.warning( + f"[SyncDiff] Failed to move file [{i}/{len(files_to_move)}]: " + f"{root_file} -> {target_file}, error={e}" + ) + + for i, deleted_dir in enumerate( + sorted(diff.deleted_dirs, key=lambda x: x.count("/"), reverse=True), 1 + ): + try: + await viking_fs.rm(deleted_dir, recursive=True, ctx=ctx) + except Exception as e: + total_failed += 1 + logger.warning( + f"[SyncDiff] Failed to delete directory [{i}/{len(diff.deleted_dirs)}]: " + f"{deleted_dir}, error={e}" + ) + + async def _collect_children_abstracts( + self, children_uris: List[str], ctx: Optional[RequestContext] = None + ) -> List[Dict[str, str]]: """Collect .abstract.md from subdirectories.""" viking_fs = get_viking_fs() results = [] for child_uri in children_uris: - abstract = await viking_fs.abstract(child_uri, ctx=self._current_ctx) + abstract = await viking_fs.abstract(child_uri, ctx=ctx) dir_name = child_uri.split("/")[-1] results.append({"name": dir_name, "abstract": abstract}) return results - async def _generate_file_summaries( - self, - file_paths: List[str], - context_type: Optional[str] = None, - parent_uri: Optional[str] = None, - enqueue_files: bool = False, - ) -> List[Dict[str, str]]: - """Concurrently generate file summaries.""" - if not file_paths: - return [] - - async def generate_one_summary(file_path: str) -> Dict[str, str]: - summary = await self._generate_single_file_summary(file_path, ctx=self._current_ctx) - if enqueue_files and context_type and parent_uri: - try: - await self._vectorize_single_file( - parent_uri=parent_uri, - context_type=context_type, - file_path=file_path, - summary_dict=summary, - ) - except Exception as e: - logger.error( - f"Failed to vectorize file {file_path}: {e}", - exc_info=True, - ) - return summary - - tasks = [generate_one_summary(fp) for fp in file_paths] - return await 
asyncio.gather(*tasks) - async def _generate_text_summary( self, file_path: str, @@ -507,13 +701,14 @@ def replace_index(match): logger.error(f"Failed to generate overview for {dir_uri}: {e}", exc_info=True) return f"# {dir_uri.split('/')[-1]}\n\nDirectory overview" - async def _vectorize_directory_simple( + async def _vectorize_directory( self, uri: str, context_type: str, abstract: str, overview: str, ctx: Optional[RequestContext] = None, + semantic_msg_id: Optional[str] = None, ) -> None: """Create directory Context and enqueue to EmbeddingQueue.""" @@ -523,6 +718,12 @@ async def _vectorize_directory_simple( from openviking.utils.embedding_utils import vectorize_directory_meta + tracker = EmbeddingTaskTracker.get_instance() + # Increment task for .abstract.md + await tracker.increment(semantic_msg_id=semantic_msg_id) + # Increment task for .overview.md + await tracker.increment(semantic_msg_id=semantic_msg_id) + active_ctx = ctx or self._current_ctx await vectorize_directory_meta( uri=uri, @@ -530,44 +731,25 @@ async def _vectorize_directory_simple( overview=overview, context_type=context_type, ctx=active_ctx, + semantic_msg_id=semantic_msg_id, ) - async def _vectorize_files( - self, - uri: str, - context_type: str, - file_paths: List[str], - file_summaries: List[Dict[str, str]], - ctx: Optional[RequestContext] = None, - ) -> None: - """Vectorize files in directory.""" - from openviking.storage.queuefs import get_queue_manager - - queue_manager = get_queue_manager() - embedding_queue = queue_manager.get_queue(queue_manager.EMBEDDING) - - for file_path, file_summary_dict in zip(file_paths, file_summaries): - await self._vectorize_single_file( - parent_uri=uri, - context_type=context_type, - file_path=file_path, - summary_dict=file_summary_dict, - embedding_queue=embedding_queue, - ctx=ctx, - ) - async def _vectorize_single_file( self, parent_uri: str, context_type: str, file_path: str, summary_dict: Dict[str, str], - embedding_queue: Optional[Any] = None, ctx: Optional[RequestContext] = None, + semantic_msg_id: Optional[str] = None, ) -> None: """Vectorize a single file using its content or summary.""" from openviking.utils.embedding_utils import vectorize_file + tracker = EmbeddingTaskTracker.get_instance() + await tracker.increment( + semantic_msg_id=semantic_msg_id, + ) active_ctx = ctx or self._current_ctx await vectorize_file( file_path=file_path, @@ -575,4 +757,5 @@ async def _vectorize_single_file( parent_uri=parent_uri, context_type=context_type, ctx=active_ctx, + semantic_msg_id=semantic_msg_id, ) diff --git a/openviking/storage/viking_fs.py b/openviking/storage/viking_fs.py index bda478ae..dda70dbd 100644 --- a/openviking/storage/viking_fs.py +++ b/openviking/storage/viking_fs.py @@ -23,6 +23,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union from openviking.pyagfs.exceptions import AGFSHTTPError +from openviking.pyagfs.helpers import cp as agfs_cp from openviking.server.identity import RequestContext, Role from openviking.utils.time_utils import format_simplified, get_current_timestamp, parse_iso_datetime from openviking_cli.exceptions import NotFoundError @@ -332,6 +333,22 @@ async def stat(self, uri: str, ctx: Optional[RequestContext] = None) -> Dict[str path = self._uri_to_path(uri, ctx=ctx) return self.agfs.stat(path) + async def exists(self, uri: str, ctx: Optional[RequestContext] = None) -> bool: + """Check if a URI exists. 
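+
+        Note: implemented as a stat() probe, so any failure (including
+        permission or transport errors) is reported as non-existence.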
+ + Args: + uri: Viking URI + ctx: Request context + + Returns: + bool: True if the URI exists, False otherwise + """ + try: + await self.stat(uri, ctx=ctx) + return True + except Exception: + return False + async def glob( self, pattern: str, @@ -1445,6 +1462,29 @@ async def move_file( self.agfs.write(to_path, content) self.agfs.rm(from_path) + async def copy_directory( + self, + from_uri: str, + to_uri: str, + ctx: Optional[RequestContext] = None, + ) -> None: + """Copy directory recursively. + + Args: + from_uri: Source directory URI + to_uri: Destination directory URI + ctx: Request context + """ + self._ensure_access(from_uri, ctx) + self._ensure_access(to_uri, ctx) + + from_path = self._uri_to_path(from_uri, ctx=ctx) + to_path = self._uri_to_path(to_uri, ctx=ctx) + + await self._ensure_parent_dirs(to_path) + + await asyncio.to_thread(agfs_cp, self.agfs, from_path, to_path, recursive=True) + # ========== Temp File Operations (backward compatible) ========== def create_temp_uri(self) -> str: diff --git a/openviking/utils/embedding_utils.py b/openviking/utils/embedding_utils.py index 1dffc1c8..907c5dcd 100644 --- a/openviking/utils/embedding_utils.py +++ b/openviking/utils/embedding_utils.py @@ -116,6 +116,7 @@ async def vectorize_directory_meta( overview: str, context_type: str = "resource", ctx: Optional[RequestContext] = None, + semantic_msg_id: Optional[str] = None, ) -> None: """ Vectorize directory metadata (.abstract.md and .overview.md). @@ -147,6 +148,7 @@ async def vectorize_directory_meta( context_abstract.set_vectorize(Vectorize(text=abstract)) msg_abstract = EmbeddingMsgConverter.from_context(context_abstract) if msg_abstract: + msg_abstract.semantic_msg_id = semantic_msg_id await embedding_queue.enqueue(msg_abstract) logger.debug(f"Enqueued directory L0 (abstract) for vectorization: {uri}") @@ -165,6 +167,7 @@ async def vectorize_directory_meta( context_overview.set_vectorize(Vectorize(text=overview)) msg_overview = EmbeddingMsgConverter.from_context(context_overview) if msg_overview: + msg_overview.semantic_msg_id = semantic_msg_id await embedding_queue.enqueue(msg_overview) logger.debug(f"Enqueued directory L1 (overview) for vectorization: {uri}") @@ -175,6 +178,7 @@ async def vectorize_file( parent_uri: str, context_type: str = "resource", ctx: Optional[RequestContext] = None, + semantic_msg_id: Optional[str] = None, ) -> None: """ Vectorize a single file. 
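Editor's note: between these two hunks it may help to see the intended flow end-to-end; a hedged sketch of how the tracking id travels into the embedding queue (URIs and `msg` illustrative; the `summary_dict` keyword is assumed from the processor-side call):

    await vectorize_file(
        file_path="viking://user/acc/.tmp/doc/ch1.md",   # illustrative temp file
        summary_dict={"name": "ch1.md", "summary": "..."},
        parent_uri="viking://user/acc/.tmp/doc",
        context_type="resource",
        ctx=ctx,
        semantic_msg_id=msg.id,
    )
    # inside vectorize_file, per the next hunk:
    #   embedding_msg.semantic_msg_id = semantic_msg_id
    #   await embedding_queue.enqueue(embedding_msg)
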
@@ -246,6 +250,7 @@ async def vectorize_file(
     if not embedding_msg:
         return
 
+    embedding_msg.semantic_msg_id = semantic_msg_id
     await embedding_queue.enqueue(embedding_msg)
     logger.debug(f"Enqueued file for vectorization: {file_path}")
 
diff --git a/openviking/utils/resource_processor.py b/openviking/utils/resource_processor.py
index c43ea541..f70c4e54 100644
--- a/openviking/utils/resource_processor.py
+++ b/openviking/utils/resource_processor.py
@@ -120,7 +120,7 @@ async def process_resource(
             "source_path": None,
         }
 
-        # ============ Phase 1: Parse source (Parser generates L0/L1 and writes to temp) ============
+        # ============ Phase 1: Parse source and write to temp VikingFS ============
         try:
             media_processor = self._get_media_processor()
             viking_fs = get_viking_fs()
@@ -178,6 +178,7 @@ async def process_resource(
             )
             if context_tree and context_tree.root:
                 result["root_uri"] = context_tree.root.uri
+                result["temp_uri"] = context_tree.root.temp_uri
         except Exception as e:
             result["status"] = "error"
             result["errors"].append(f"Finalize from temp error: {e}")
@@ -193,6 +194,7 @@ async def process_resource(
         # ============ Phase 4: Optional Steps ============
         build_index = kwargs.get("build_index", True)
 
+        temp_uri_for_summarize = result.get("temp_uri") or parse_result.temp_dir_path
         if summarize:
             # Explicit summarization request.
             # If build_index is ALSO True, we want vectorization.
@@ -203,6 +205,7 @@ async def process_resource(
                     resource_uris=[result["root_uri"]],
                     ctx=ctx,
                     skip_vectorization=skip_vec,
+                    temp_uris=[temp_uri_for_summarize],
                     **kwargs,
                 )
             except Exception as e:
@@ -214,7 +217,11 @@ async def process_resource(
             # We assume this means "Ingest and Index", which requires summarization.
             try:
                 await self._get_summarizer().summarize(
-                    resource_uris=[result["root_uri"]], ctx=ctx, skip_vectorization=False, **kwargs
+                    resource_uris=[result["root_uri"]],
+                    ctx=ctx,
+                    skip_vectorization=False,
+                    temp_uris=[temp_uri_for_summarize],
+                    **kwargs,
                 )
             except Exception as e:
                 logger.error(f"Auto-index failed: {e}")
diff --git a/openviking/utils/summarizer.py b/openviking/utils/summarizer.py
index a7477ba3..3d059f6f 100644
--- a/openviking/utils/summarizer.py
+++ b/openviking/utils/summarizer.py
@@ -6,13 +6,14 @@
 
 Handles summarization and key information extraction.
""" -from typing import TYPE_CHECKING, Any, Dict, List, Optional -from openviking_cli.utils import get_logger +from typing import TYPE_CHECKING, Any, Dict, List + from openviking.storage.queuefs import SemanticMsg, get_queue_manager +from openviking_cli.utils import get_logger if TYPE_CHECKING: - from openviking.server.identity import RequestContext from openviking.parse.vlm import VLMProcessor + from openviking.server.identity import RequestContext logger = get_logger(__name__) @@ -39,8 +40,19 @@ async def summarize( queue_manager = get_queue_manager() semantic_queue = queue_manager.get_queue(queue_manager.SEMANTIC, allow_create=True) + temp_uris = kwargs.get("temp_uris", []) + if temp_uris == []: + temp_uris = resource_uris + if len(temp_uris) != len(resource_uris): + logger.error( + f"temp_uris length ({len(temp_uris)}) must match resource_uris length ({len(resource_uris)})" + ) + return { + "status": "error", + "message": "temp_uris length must match resource_uris length", + } enqueued_count = 0 - for uri in resource_uris: + for uri, temp_uri in zip(resource_uris, temp_uris): # Determine context_type based on URI context_type = "resource" if uri.startswith("viking://memory/"): @@ -49,13 +61,14 @@ async def summarize( context_type = "skill" msg = SemanticMsg( - uri=uri, + uri=temp_uri, context_type=context_type, account_id=ctx.account_id, user_id=ctx.user.user_id, agent_id=ctx.user.agent_id, role=ctx.role.value, skip_vectorization=skip_vectorization, + target_uri=uri, ) await semantic_queue.enqueue(msg) enqueued_count += 1 diff --git a/openviking_cli/exceptions.py b/openviking_cli/exceptions.py index 807d317e..cd432552 100644 --- a/openviking_cli/exceptions.py +++ b/openviking_cli/exceptions.py @@ -70,6 +70,14 @@ def __init__(self, resource: str, resource_type: str = "resource"): ) +class ConflictError(OpenVikingError): + """Resource conflict (e.g., locked by another operation).""" + + def __init__(self, message: str, resource: Optional[str] = None): + details = {"resource": resource} if resource else {} + super().__init__(message, code="CONFLICT", details=details) + + # ============= Authentication Errors ============= diff --git a/tests/unit/session/test_compressor_cow.py b/tests/unit/session/test_compressor_cow.py new file mode 100644 index 00000000..cf6efe12 --- /dev/null +++ b/tests/unit/session/test_compressor_cow.py @@ -0,0 +1,458 @@ +# Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd. 
+# SPDX-License-Identifier: Apache-2.0 + +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from openviking.core.context import Context +from openviking.session.compressor import SessionCompressor +from openviking.session.memory_extractor import CandidateMemory, MemoryCategory +from openviking_cli.session.user_id import UserIdentifier + + +def _make_user() -> UserIdentifier: + return UserIdentifier("acc1", "test_user", "test_agent") + + +def _make_candidate(category: MemoryCategory = MemoryCategory.PREFERENCES) -> CandidateMemory: + return CandidateMemory( + category=category, + abstract="User prefers concise summaries", + overview="User asks for concise answers frequently.", + content="The user prefers concise summaries over long explanations.", + source_session="session_test", + user=_make_user(), + language="en", + ) + + +def _make_context(uri: str, abstract: str = "Existing memory") -> Context: + return Context( + uri=uri, + context_type="memory", + abstract=abstract, + meta={"overview": "Existing overview"}, + ) + + +class TestConvertToTempUri: + def test_user_uri_converted_to_temp_uri(self): + vikingdb = MagicMock() + compressor = SessionCompressor(vikingdb=vikingdb) + + user_space = _make_user().user_space_name() + target_uri = f"viking://user/{user_space}/memories/preferences/pref1.md" + user_temp_uri = "viking://user/temp_user_123" + + result = compressor._convert_to_temp_uri(target_uri, user_temp_uri, None) + + expected = f"{user_temp_uri}/memories/preferences/pref1.md" + assert result == expected + + def test_agent_uri_converted_to_temp_uri(self): + vikingdb = MagicMock() + compressor = SessionCompressor(vikingdb=vikingdb) + + agent_space = _make_user().agent_space_name() + target_uri = f"viking://agent/{agent_space}/memories/cases/case1.md" + agent_temp_uri = "viking://agent/temp_agent_456" + + result = compressor._convert_to_temp_uri(target_uri, None, agent_temp_uri) + + expected = f"{agent_temp_uri}/memories/cases/case1.md" + assert result == expected + + def test_no_temp_uri_returns_original_uri(self): + vikingdb = MagicMock() + compressor = SessionCompressor(vikingdb=vikingdb) + + user_space = _make_user().user_space_name() + target_uri = f"viking://user/{user_space}/memories/preferences/pref1.md" + + result = compressor._convert_to_temp_uri(target_uri, None, None) + + assert result == target_uri + + def test_mixed_uris_only_convert_matching_type(self): + vikingdb = MagicMock() + compressor = SessionCompressor(vikingdb=vikingdb) + + user_space = _make_user().user_space_name() + agent_space = _make_user().agent_space_name() + + user_uri = f"viking://user/{user_space}/memories/preferences/pref1.md" + agent_uri = f"viking://agent/{agent_space}/memories/cases/case1.md" + + user_temp_uri = "viking://user/temp_user_123" + + result_user = compressor._convert_to_temp_uri(user_uri, user_temp_uri, None) + result_agent = compressor._convert_to_temp_uri(agent_uri, user_temp_uri, None) + + assert result_user == f"{user_temp_uri}/memories/preferences/pref1.md" + assert result_agent == agent_uri + + def test_agent_uri_with_both_temp_uris(self): + vikingdb = MagicMock() + compressor = SessionCompressor(vikingdb=vikingdb) + + agent_space = _make_user().agent_space_name() + target_uri = f"viking://agent/{agent_space}/memories/patterns/pattern1.md" + user_temp_uri = "viking://user/temp_user_123" + agent_temp_uri = "viking://agent/temp_agent_456" + + result = compressor._convert_to_temp_uri(target_uri, user_temp_uri, agent_temp_uri) + + expected = 
f"{agent_temp_uri}/memories/patterns/pattern1.md" + assert result == expected + + def test_user_uri_with_both_temp_uris(self): + vikingdb = MagicMock() + compressor = SessionCompressor(vikingdb=vikingdb) + + user_space = _make_user().user_space_name() + target_uri = f"viking://user/{user_space}/memories/entities/entity1.md" + user_temp_uri = "viking://user/temp_user_123" + agent_temp_uri = "viking://agent/temp_agent_456" + + result = compressor._convert_to_temp_uri(target_uri, user_temp_uri, agent_temp_uri) + + expected = f"{user_temp_uri}/memories/entities/entity1.md" + assert result == expected + + def test_non_viking_uri_returns_unchanged(self): + vikingdb = MagicMock() + compressor = SessionCompressor(vikingdb=vikingdb) + + target_uri = "file:///some/local/path/memory.md" + user_temp_uri = "viking://user/temp_user_123" + + result = compressor._convert_to_temp_uri(target_uri, user_temp_uri, None) + + assert result == target_uri + + def test_short_uri_returns_unchanged(self): + vikingdb = MagicMock() + compressor = SessionCompressor(vikingdb=vikingdb) + + target_uri = "viking://user" + user_temp_uri = "viking://user/temp_user_123" + + result = compressor._convert_to_temp_uri(target_uri, user_temp_uri, None) + + assert result == target_uri + + +@pytest.mark.asyncio +class TestMergeIntoExisting: + async def test_merge_into_existing_success(self): + vikingdb = MagicMock() + compressor = SessionCompressor(vikingdb=vikingdb) + + user_space = _make_user().user_space_name() + target_uri = f"viking://user/{user_space}/memories/preferences/pref1.md" + user_temp_uri = "viking://user/temp_user_123" + temp_uri = f"{user_temp_uri}/memories/preferences/pref1.md" + + candidate = _make_candidate() + target_memory = _make_context(target_uri) + + viking_fs = AsyncMock() + viking_fs.read_file = AsyncMock(return_value="Existing content") + viking_fs.write_file = AsyncMock() + + mock_payload = MagicMock() + mock_payload.abstract = "Merged abstract" + mock_payload.overview = "Merged overview" + mock_payload.content = "Merged content" + + with patch.object( + compressor.extractor, + "_merge_memory_bundle", + AsyncMock(return_value=mock_payload), + ): + ctx = MagicMock() + result = await compressor._merge_into_existing( + candidate, + target_memory, + viking_fs, + ctx=ctx, + user_temp_uri=user_temp_uri, + agent_temp_uri=None, + ) + + assert result is True + viking_fs.read_file.assert_called_once() + viking_fs.write_file.assert_called_once() + assert target_memory.abstract == "Merged abstract" + assert target_memory.meta.get("overview") == "Merged overview" + + async def test_merge_into_existing_without_temp_uri(self): + vikingdb = MagicMock() + compressor = SessionCompressor(vikingdb=vikingdb) + + user_space = _make_user().user_space_name() + target_uri = f"viking://user/{user_space}/memories/preferences/pref1.md" + + candidate = _make_candidate() + target_memory = _make_context(target_uri) + + viking_fs = AsyncMock() + viking_fs.read_file = AsyncMock(return_value="Existing content") + viking_fs.write_file = AsyncMock() + + mock_payload = MagicMock() + mock_payload.abstract = "Merged abstract" + mock_payload.overview = "Merged overview" + mock_payload.content = "Merged content" + + with patch.object( + compressor.extractor, + "_merge_memory_bundle", + AsyncMock(return_value=mock_payload), + ): + ctx = MagicMock() + result = await compressor._merge_into_existing( + candidate, + target_memory, + viking_fs, + ctx=ctx, + user_temp_uri=None, + agent_temp_uri=None, + ) + + assert result is True + 
viking_fs.read_file.assert_called_once_with(target_uri, ctx=ctx) + + async def test_merge_into_existing_merge_bundle_returns_none(self): + vikingdb = MagicMock() + compressor = SessionCompressor(vikingdb=vikingdb) + + user_space = _make_user().user_space_name() + target_uri = f"viking://user/{user_space}/memories/preferences/pref1.md" + user_temp_uri = "viking://user/temp_user_123" + + candidate = _make_candidate() + target_memory = _make_context(target_uri) + + viking_fs = AsyncMock() + viking_fs.read_file = AsyncMock(return_value="Existing content") + + with patch.object( + compressor.extractor, + "_merge_memory_bundle", + AsyncMock(return_value=None), + ): + ctx = MagicMock() + result = await compressor._merge_into_existing( + candidate, + target_memory, + viking_fs, + ctx=ctx, + user_temp_uri=user_temp_uri, + agent_temp_uri=None, + ) + + assert result is False + viking_fs.write_file.assert_not_called() + + async def test_merge_into_existing_read_file_exception(self): + vikingdb = MagicMock() + compressor = SessionCompressor(vikingdb=vikingdb) + + user_space = _make_user().user_space_name() + target_uri = f"viking://user/{user_space}/memories/preferences/pref1.md" + user_temp_uri = "viking://user/temp_user_123" + + candidate = _make_candidate() + target_memory = _make_context(target_uri) + + viking_fs = AsyncMock() + viking_fs.read_file = AsyncMock(side_effect=Exception("Read error")) + + ctx = MagicMock() + result = await compressor._merge_into_existing( + candidate, + target_memory, + viking_fs, + ctx=ctx, + user_temp_uri=user_temp_uri, + agent_temp_uri=None, + ) + + assert result is False + + async def test_merge_into_existing_agent_uri(self): + vikingdb = MagicMock() + compressor = SessionCompressor(vikingdb=vikingdb) + + agent_space = _make_user().agent_space_name() + target_uri = f"viking://agent/{agent_space}/memories/cases/case1.md" + agent_temp_uri = "viking://agent/temp_agent_456" + + candidate = _make_candidate(category=MemoryCategory.CASES) + target_memory = _make_context(target_uri) + + viking_fs = AsyncMock() + viking_fs.read_file = AsyncMock(return_value="Existing content") + viking_fs.write_file = AsyncMock() + + mock_payload = MagicMock() + mock_payload.abstract = "Merged case abstract" + mock_payload.overview = "Merged case overview" + mock_payload.content = "Merged case content" + + with patch.object( + compressor.extractor, + "_merge_memory_bundle", + AsyncMock(return_value=mock_payload), + ): + ctx = MagicMock() + result = await compressor._merge_into_existing( + candidate, + target_memory, + viking_fs, + ctx=ctx, + user_temp_uri=None, + agent_temp_uri=agent_temp_uri, + ) + + assert result is True + expected_temp_uri = f"{agent_temp_uri}/memories/cases/case1.md" + viking_fs.read_file.assert_called_once_with(expected_temp_uri, ctx=ctx) + + +@pytest.mark.asyncio +class TestDeleteExistingMemory: + async def test_delete_existing_memory_success(self): + vikingdb = MagicMock() + vikingdb.delete_uris = AsyncMock() + compressor = SessionCompressor(vikingdb=vikingdb) + + user_space = _make_user().user_space_name() + target_uri = f"viking://user/{user_space}/memories/preferences/pref1.md" + user_temp_uri = "viking://user/temp_user_123" + temp_uri = f"{user_temp_uri}/memories/preferences/pref1.md" + + memory = _make_context(target_uri) + + viking_fs = AsyncMock() + viking_fs.rm = AsyncMock() + + ctx = MagicMock() + result = await compressor._delete_existing_memory( + memory, + viking_fs, + ctx=ctx, + user_temp_uri=user_temp_uri, + agent_temp_uri=None, + ) + + assert result is 
True + viking_fs.rm.assert_called_once_with(temp_uri, recursive=False, ctx=ctx) + vikingdb.delete_uris.assert_called_once_with(ctx, [temp_uri]) + + async def test_delete_existing_memory_without_temp_uri(self): + vikingdb = MagicMock() + vikingdb.delete_uris = AsyncMock() + compressor = SessionCompressor(vikingdb=vikingdb) + + user_space = _make_user().user_space_name() + target_uri = f"viking://user/{user_space}/memories/preferences/pref1.md" + + memory = _make_context(target_uri) + + viking_fs = AsyncMock() + viking_fs.rm = AsyncMock() + + ctx = MagicMock() + result = await compressor._delete_existing_memory( + memory, + viking_fs, + ctx=ctx, + user_temp_uri=None, + agent_temp_uri=None, + ) + + assert result is True + viking_fs.rm.assert_called_once_with(target_uri, recursive=False, ctx=ctx) + vikingdb.delete_uris.assert_called_once_with(ctx, [target_uri]) + + async def test_delete_existing_memory_rm_exception(self): + vikingdb = MagicMock() + compressor = SessionCompressor(vikingdb=vikingdb) + + user_space = _make_user().user_space_name() + target_uri = f"viking://user/{user_space}/memories/preferences/pref1.md" + user_temp_uri = "viking://user/temp_user_123" + + memory = _make_context(target_uri) + + viking_fs = AsyncMock() + viking_fs.rm = AsyncMock(side_effect=Exception("Delete error")) + + ctx = MagicMock() + result = await compressor._delete_existing_memory( + memory, + viking_fs, + ctx=ctx, + user_temp_uri=user_temp_uri, + agent_temp_uri=None, + ) + + assert result is False + + async def test_delete_existing_memory_vector_delete_exception(self): + vikingdb = MagicMock() + vikingdb.delete_uris = AsyncMock(side_effect=Exception("Vector delete error")) + compressor = SessionCompressor(vikingdb=vikingdb) + + user_space = _make_user().user_space_name() + target_uri = f"viking://user/{user_space}/memories/preferences/pref1.md" + user_temp_uri = "viking://user/temp_user_123" + + memory = _make_context(target_uri) + + viking_fs = AsyncMock() + viking_fs.rm = AsyncMock() + + ctx = MagicMock() + result = await compressor._delete_existing_memory( + memory, + viking_fs, + ctx=ctx, + user_temp_uri=user_temp_uri, + agent_temp_uri=None, + ) + + assert result is True + viking_fs.rm.assert_called_once() + vikingdb.delete_uris.assert_called_once() + + async def test_delete_existing_memory_agent_uri(self): + vikingdb = MagicMock() + vikingdb.delete_uris = AsyncMock() + compressor = SessionCompressor(vikingdb=vikingdb) + + agent_space = _make_user().agent_space_name() + target_uri = f"viking://agent/{agent_space}/memories/cases/case1.md" + agent_temp_uri = "viking://agent/temp_agent_456" + temp_uri = f"{agent_temp_uri}/memories/cases/case1.md" + + memory = _make_context(target_uri) + + viking_fs = AsyncMock() + viking_fs.rm = AsyncMock() + + ctx = MagicMock() + result = await compressor._delete_existing_memory( + memory, + viking_fs, + ctx=ctx, + user_temp_uri=None, + agent_temp_uri=agent_temp_uri, + ) + + assert result is True + viking_fs.rm.assert_called_once_with(temp_uri, recursive=False, ctx=ctx) + vikingdb.delete_uris.assert_called_once_with(ctx, [temp_uri]) diff --git a/tests/unit/session/test_deduplicator_uri.py b/tests/unit/session/test_deduplicator_uri.py new file mode 100644 index 00000000..46057b77 --- /dev/null +++ b/tests/unit/session/test_deduplicator_uri.py @@ -0,0 +1,310 @@ +# Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd. 
+# SPDX-License-Identifier: Apache-2.0 + +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from openviking.session.memory_deduplicator import MemoryDeduplicator +from openviking.session.memory_extractor import CandidateMemory, MemoryCategory +from openviking_cli.session.user_id import UserIdentifier + + +class _DummyEmbedResult: + def __init__(self, dense_vector): + self.dense_vector = dense_vector + + +class _DummyEmbedder: + def embed(self, _text): + return _DummyEmbedResult([0.1, 0.2, 0.3]) + + +def _make_user() -> UserIdentifier: + return UserIdentifier("acc1", "test_user", "test_agent") + + +def _make_candidate(category: MemoryCategory = MemoryCategory.PREFERENCES) -> CandidateMemory: + return CandidateMemory( + category=category, + abstract="User prefers concise summaries", + overview="User asks for concise answers frequently.", + content="The user prefers concise summaries over long explanations.", + source_session="session_test", + user=_make_user(), + language="en", + ) + + +def _make_existing_user_memory(uri_suffix: str = "existing.md") -> dict: + user_space = _make_user().user_space_name() + return { + "id": f"uri_{uri_suffix}", + "uri": f"viking://user/{user_space}/memories/preferences/{uri_suffix}", + "context_type": "memory", + "level": 2, + "account_id": "acc1", + "owner_space": user_space, + "abstract": "Existing preference memory", + "category": "preferences", + "_score": 0.85, + } + + +def _make_existing_agent_memory(uri_suffix: str = "case1.md") -> dict: + user = _make_user() + agent_space = user.agent_space_name() + return { + "id": f"uri_{uri_suffix}", + "uri": f"viking://agent/{agent_space}/memories/cases/{uri_suffix}", + "context_type": "memory", + "level": 2, + "account_id": "acc1", + "owner_space": agent_space, + "abstract": "Existing case memory", + "category": "cases", + "_score": 0.90, + } + + +@pytest.mark.asyncio +class TestFindSimilarMemoriesURIConversion: + async def test_user_uri_converted_to_temp_uri(self): + vikingdb = MagicMock() + vikingdb.get_embedder.return_value = _DummyEmbedder() + vikingdb.search_similar_memories = AsyncMock( + return_value=[_make_existing_user_memory("pref1.md")] + ) + + dedup = MemoryDeduplicator(vikingdb=vikingdb) + candidate = _make_candidate() + + user_temp_uri = "viking://user/temp_user_123" + similar = await dedup._find_similar_memories( + candidate, + user_temp_uri=user_temp_uri, + agent_temp_uri=None, + ) + + assert len(similar) == 1 + user_space = _make_user().user_space_name() + original_uri = f"viking://user/{user_space}/memories/preferences/pref1.md" + expected_uri = f"{user_temp_uri}/memories/preferences/pref1.md" + assert similar[0].uri == expected_uri + assert similar[0].uri != original_uri + + async def test_agent_uri_converted_to_temp_uri(self): + vikingdb = MagicMock() + vikingdb.get_embedder.return_value = _DummyEmbedder() + vikingdb.search_similar_memories = AsyncMock( + return_value=[_make_existing_agent_memory("case1.md")] + ) + + dedup = MemoryDeduplicator(vikingdb=vikingdb) + candidate = _make_candidate(category=MemoryCategory.CASES) + + agent_temp_uri = "viking://agent/temp_agent_456" + similar = await dedup._find_similar_memories( + candidate, + user_temp_uri=None, + agent_temp_uri=agent_temp_uri, + ) + + assert len(similar) == 1 + user = _make_user() + agent_space = user.agent_space_name() + original_uri = f"viking://agent/{agent_space}/memories/cases/case1.md" + expected_uri = f"{agent_temp_uri}/memories/cases/case1.md" + assert similar[0].uri == expected_uri + assert similar[0].uri != 
original_uri + + async def test_no_conversion_when_no_temp_uri(self): + vikingdb = MagicMock() + vikingdb.get_embedder.return_value = _DummyEmbedder() + vikingdb.search_similar_memories = AsyncMock( + return_value=[_make_existing_user_memory("pref1.md")] + ) + + dedup = MemoryDeduplicator(vikingdb=vikingdb) + candidate = _make_candidate() + + similar = await dedup._find_similar_memories( + candidate, + user_temp_uri=None, + agent_temp_uri=None, + ) + + assert len(similar) == 1 + user_space = _make_user().user_space_name() + expected_uri = f"viking://user/{user_space}/memories/preferences/pref1.md" + assert similar[0].uri == expected_uri + + async def test_mixed_uris_only_convert_matching_type(self): + user_space = _make_user().user_space_name() + agent_space = _make_user().agent_space_name() + + vikingdb = MagicMock() + vikingdb.get_embedder.return_value = _DummyEmbedder() + vikingdb.search_similar_memories = AsyncMock( + return_value=[ + _make_existing_user_memory("pref1.md"), + _make_existing_agent_memory("case1.md"), + ] + ) + + dedup = MemoryDeduplicator(vikingdb=vikingdb) + candidate = _make_candidate() + + user_temp_uri = "viking://user/temp_user_123" + similar = await dedup._find_similar_memories( + candidate, + user_temp_uri=user_temp_uri, + agent_temp_uri=None, + ) + + assert len(similar) == 2 + uris = {m.uri for m in similar} + assert f"{user_temp_uri}/memories/preferences/pref1.md" in uris + assert f"viking://agent/{agent_space}/memories/cases/case1.md" in uris + + async def test_uri_conversion_preserves_meta_and_score(self): + vikingdb = MagicMock() + vikingdb.get_embedder.return_value = _DummyEmbedder() + vikingdb.search_similar_memories = AsyncMock( + return_value=[_make_existing_user_memory("pref1.md")] + ) + + dedup = MemoryDeduplicator(vikingdb=vikingdb) + candidate = _make_candidate() + + user_temp_uri = "viking://user/temp_user_123" + similar = await dedup._find_similar_memories( + candidate, + user_temp_uri=user_temp_uri, + agent_temp_uri=None, + ) + + assert len(similar) == 1 + assert similar[0].meta is not None + assert similar[0].meta.get("_dedup_score") == 0.85 + + +class TestExtractFacetKey: + def test_extract_with_chinese_colon(self): + result = MemoryDeduplicator._extract_facet_key("饮食偏好:喜欢吃苹果和草莓") + assert result == "饮食偏好" + + def test_extract_with_english_colon(self): + result = MemoryDeduplicator._extract_facet_key("User preference: dark mode enabled") + assert result == "user preference" + + def test_extract_with_hyphen(self): + result = MemoryDeduplicator._extract_facet_key("Coding style - prefer type hints") + assert result == "coding style" + + def test_extract_with_em_dash(self): + result = MemoryDeduplicator._extract_facet_key("Work schedule — remote on Fridays") + assert result == "work schedule" + + def test_extract_with_no_separator_returns_prefix(self): + result = MemoryDeduplicator._extract_facet_key( + "This is a long abstract without any separator" + ) + assert len(result) <= 24 + assert result == "this is a long abstract" + + def test_extract_with_empty_string(self): + result = MemoryDeduplicator._extract_facet_key("") + assert result == "" + + def test_extract_with_none(self): + result = MemoryDeduplicator._extract_facet_key(None) + assert result == "" + + def test_extract_normalizes_whitespace(self): + result = MemoryDeduplicator._extract_facet_key(" Multiple spaces : value ") + assert result == "multiple spaces" + + def test_extract_with_short_text_no_separator(self): + result = MemoryDeduplicator._extract_facet_key("Short") + assert result == 
"short" + + def test_extract_returns_lowercase(self): + result = MemoryDeduplicator._extract_facet_key("FOOD PREFERENCE: pizza") + assert result == "food preference" + + def test_extract_with_separator_at_start(self): + result = MemoryDeduplicator._extract_facet_key(": starts with separator") + assert result == ": starts with" + + def test_extract_with_multiple_separators_uses_first(self): + result = MemoryDeduplicator._extract_facet_key("Topic: Subtopic - Detail") + assert result == "topic" + + +class TestCosineSimilarity: + def test_identical_vectors(self): + vec = [1.0, 2.0, 3.0] + result = MemoryDeduplicator._cosine_similarity(vec, vec) + assert abs(result - 1.0) < 1e-9 + + def test_orthogonal_vectors(self): + vec_a = [1.0, 0.0] + vec_b = [0.0, 1.0] + result = MemoryDeduplicator._cosine_similarity(vec_a, vec_b) + assert abs(result) < 1e-9 + + def test_opposite_vectors(self): + vec_a = [1.0, 2.0, 3.0] + vec_b = [-1.0, -2.0, -3.0] + result = MemoryDeduplicator._cosine_similarity(vec_a, vec_b) + assert abs(result + 1.0) < 1e-9 + + def test_different_length_vectors(self): + vec_a = [1.0, 2.0, 3.0] + vec_b = [1.0, 2.0] + result = MemoryDeduplicator._cosine_similarity(vec_a, vec_b) + assert result == 0.0 + + def test_zero_vector_a(self): + vec_a = [0.0, 0.0, 0.0] + vec_b = [1.0, 2.0, 3.0] + result = MemoryDeduplicator._cosine_similarity(vec_a, vec_b) + assert result == 0.0 + + def test_zero_vector_b(self): + vec_a = [1.0, 2.0, 3.0] + vec_b = [0.0, 0.0, 0.0] + result = MemoryDeduplicator._cosine_similarity(vec_a, vec_b) + assert result == 0.0 + + def test_both_zero_vectors(self): + vec_a = [0.0, 0.0, 0.0] + vec_b = [0.0, 0.0, 0.0] + result = MemoryDeduplicator._cosine_similarity(vec_a, vec_b) + assert result == 0.0 + + def test_partial_similarity(self): + vec_a = [1.0, 0.0, 0.0] + vec_b = [1.0, 1.0, 0.0] + result = MemoryDeduplicator._cosine_similarity(vec_a, vec_b) + expected = 1.0 / (2.0**0.5) + assert abs(result - expected) < 1e-9 + + def test_negative_values(self): + vec_a = [1.0, -2.0, 3.0] + vec_b = [-1.0, 2.0, 3.0] + result = MemoryDeduplicator._cosine_similarity(vec_a, vec_b) + assert 0 < result < 1 + + def test_single_element_vectors(self): + vec_a = [5.0] + vec_b = [3.0] + result = MemoryDeduplicator._cosine_similarity(vec_a, vec_b) + assert abs(result - 1.0) < 1e-9 + + def test_large_vectors(self): + vec_a = [float(i) for i in range(100)] + vec_b = [float(i * 2) for i in range(100)] + result = MemoryDeduplicator._cosine_similarity(vec_a, vec_b) + assert abs(result - 1.0) < 1e-6 diff --git a/tests/unit/session/test_memory_extractor_tools.py b/tests/unit/session/test_memory_extractor_tools.py new file mode 100644 index 00000000..4d77aaab --- /dev/null +++ b/tests/unit/session/test_memory_extractor_tools.py @@ -0,0 +1,786 @@ +# Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd. 
+# SPDX-License-Identifier: Apache-2.0 + +from unittest.mock import patch + +import pytest + +from openviking.session.memory_extractor import ( + FIELD_MAX_LENGTHS, + MemoryExtractor, +) + + +@pytest.fixture +def extractor(): + return MemoryExtractor() + + +class TestParseToolStatistics: + def test_parse_chinese_format_full(self, extractor): + content = """ +Tool: test_tool + +总调用次数: 100 +成功率: 85.0%(85 成功,15 失败) +平均耗时: 150.5ms +平均Token: 500 +""" + stats = extractor._parse_tool_statistics(content) + assert stats["total_calls"] == 100 + assert stats["success_count"] == 85 + assert stats["fail_count"] == 15 + assert stats["total_time_ms"] == 15050.0 + assert stats["total_tokens"] == 50000 + + def test_parse_chinese_format_with_colon(self, extractor): + content = """ +总调用次数:200 +成功率:90.5%(181 成功,19 失败) +平均耗时:200.0ms +平均Token:800 +""" + stats = extractor._parse_tool_statistics(content) + assert stats["total_calls"] == 200 + assert stats["success_count"] == 181 + assert stats["fail_count"] == 19 + + def test_parse_english_format_full(self, extractor): + content = """ +Tool: test_tool + +Based on 50 historical calls: +- Success rate: 80.0% (40 successful, 10 failed) +- Avg time: 1.5s, Avg tokens: 600 +""" + stats = extractor._parse_tool_statistics(content) + assert stats["total_calls"] == 50 + assert stats["success_count"] == 40 + assert stats["fail_count"] == 10 + assert stats["total_time_ms"] == 75000.0 + assert stats["total_tokens"] == 30000 + + def test_parse_english_format_ms(self, extractor): + content = """ +Based on 30 historical calls: +- Success rate: 90.0% (27 successful, 3 failed) +- Avg time: 250.5ms, Avg tokens: 400 +""" + stats = extractor._parse_tool_statistics(content) + assert stats["total_calls"] == 30 + assert stats["success_count"] == 27 + assert stats["fail_count"] == 3 + assert stats["total_time_ms"] == 7515.0 + assert stats["total_tokens"] == 12000 + + def test_parse_chinese_success_rate_only(self, extractor): + content = """ +总调用次数: 100 +成功率: 75.0% +""" + stats = extractor._parse_tool_statistics(content) + assert stats["total_calls"] == 100 + assert stats["success_count"] == 75 + assert stats["fail_count"] == 25 + + def test_parse_english_success_rate_only(self, extractor): + content = """ +Based on 80 historical calls: +- Success rate: 87.5% +""" + stats = extractor._parse_tool_statistics(content) + assert stats["total_calls"] == 80 + assert stats["success_count"] == 70 + assert stats["fail_count"] == 10 + + def test_parse_empty_content(self, extractor): + stats = extractor._parse_tool_statistics("") + assert stats["total_calls"] == 0 + assert stats["success_count"] == 0 + assert stats["fail_count"] == 0 + assert stats["total_time_ms"] == 0 + assert stats["total_tokens"] == 0 + + def test_parse_chinese_avg_time_seconds(self, extractor): + content = """ +总调用次数: 10 +平均耗时: 2.5s +""" + stats = extractor._parse_tool_statistics(content) + assert stats["total_calls"] == 10 + assert stats["total_time_ms"] == 25000.0 + + def test_parse_no_total_calls_infers_from_success_fail(self, extractor): + content = """ +成功率: 80.0%(40 成功,10 失败) +""" + stats = extractor._parse_tool_statistics(content) + assert stats["total_calls"] == 50 + assert stats["success_count"] == 40 + assert stats["fail_count"] == 10 + + +class TestMergeToolStatistics: + def test_merge_basic(self, extractor): + existing = { + "total_calls": 100, + "success_count": 80, + "fail_count": 20, + "total_time_ms": 10000.0, + "total_tokens": 50000, + } + new = { + "total_calls": 50, + "success_count": 45, + "fail_count": 5, + 
"total_time_ms": 5000.0, + "total_tokens": 25000, + } + merged = extractor._merge_tool_statistics(existing, new) + assert merged["total_calls"] == 150 + assert merged["success_count"] == 125 + assert merged["fail_count"] == 25 + assert merged["total_time_ms"] == 15000.0 + assert merged["total_tokens"] == 75000 + assert abs(merged["avg_time_ms"] - 100.0) < 0.01 + assert abs(merged["avg_tokens"] - 500.0) < 0.01 + assert abs(merged["success_rate"] - 0.8333) < 0.01 + + def test_merge_with_zero_existing(self, extractor): + existing = { + "total_calls": 0, + "success_count": 0, + "fail_count": 0, + "total_time_ms": 0, + "total_tokens": 0, + } + new = { + "total_calls": 10, + "success_count": 8, + "fail_count": 2, + "total_time_ms": 1000.0, + "total_tokens": 5000, + } + merged = extractor._merge_tool_statistics(existing, new) + assert merged["total_calls"] == 10 + assert merged["success_count"] == 8 + assert merged["fail_count"] == 2 + + def test_merge_with_zero_new(self, extractor): + existing = { + "total_calls": 20, + "success_count": 15, + "fail_count": 5, + "total_time_ms": 2000.0, + "total_tokens": 10000, + } + new = { + "total_calls": 0, + "success_count": 0, + "fail_count": 0, + "total_time_ms": 0, + "total_tokens": 0, + } + merged = extractor._merge_tool_statistics(existing, new) + assert merged["total_calls"] == 20 + assert merged["success_count"] == 15 + assert merged["fail_count"] == 5 + + def test_merge_both_zero(self, extractor): + existing = { + "total_calls": 0, + "success_count": 0, + "fail_count": 0, + "total_time_ms": 0, + "total_tokens": 0, + } + new = { + "total_calls": 0, + "success_count": 0, + "fail_count": 0, + "total_time_ms": 0, + "total_tokens": 0, + } + merged = extractor._merge_tool_statistics(existing, new) + assert merged["total_calls"] == 0 + assert merged["avg_time_ms"] == 0 + assert merged["avg_tokens"] == 0 + assert merged["success_rate"] == 0 + + +class TestGenerateToolMemoryContent: + def test_generate_basic(self, extractor): + with patch.object(extractor, "_get_tool_static_description", return_value="A test tool"): + stats = { + "total_calls": 100, + "success_count": 85, + "fail_count": 15, + "avg_time_ms": 150.5, + "avg_tokens": 500, + "success_rate": 0.85, + } + guidelines = "Use this tool for testing purposes." + content = extractor._generate_tool_memory_content("test_tool", stats, guidelines) + assert "Tool: test_tool" in content + assert "Based on 100 historical calls:" in content + assert "Success rate: 85.0%" in content + assert "85 successful, 15 failed" in content + assert "Use this tool for testing purposes." 
in content + + def test_generate_with_fields(self, extractor): + with patch.object(extractor, "_get_tool_static_description", return_value="A test tool"): + stats = { + "total_calls": 50, + "success_count": 40, + "fail_count": 10, + "avg_time_ms": 200.0, + "avg_tokens": 600, + "success_rate": 0.8, + } + fields = { + "best_for": "Data processing tasks", + "optimal_params": "batch_size=100", + "common_failures": "Timeout on large inputs", + "recommendation": "Use with small batches", + } + content = extractor._generate_tool_memory_content("test_tool", stats, "", fields=fields) + assert "Best for: Data processing tasks" in content + assert "Optimal params: batch_size=100" in content + assert "Common failures: Timeout on large inputs" in content + assert "Recommendation: Use with small batches" in content + + def test_generate_with_empty_fields(self, extractor): + with patch.object(extractor, "_get_tool_static_description", return_value="A test tool"): + stats = { + "total_calls": 10, + "success_count": 10, + "fail_count": 0, + "avg_time_ms": 100.0, + "avg_tokens": 300, + "success_rate": 1.0, + } + content = extractor._generate_tool_memory_content("test_tool", stats, "", fields={}) + assert "Best for: " in content + assert "Optimal params: " in content + + def test_generate_extracts_fields_from_guidelines(self, extractor): + with patch.object(extractor, "_get_tool_static_description", return_value="A test tool"): + stats = { + "total_calls": 20, + "success_count": 18, + "fail_count": 2, + "avg_time_ms": 50.0, + "avg_tokens": 200, + "success_rate": 0.9, + } + guidelines = """ +Best for: Quick data validation +Optimal params: strict_mode=true +Common failures: Invalid input format +Recommendation: Always validate input first +""" + content = extractor._generate_tool_memory_content("test_tool", stats, guidelines) + assert "Best for: Quick data validation" in content + assert "Optimal params: strict_mode=true" in content + + +class TestParseSkillStatistics: + def test_parse_chinese_format_full(self, extractor): + content = """ +Skill: test_skill + +总执行次数: 100 +成功率: 90.0%(90 成功,10 失败) +""" + stats = extractor._parse_skill_statistics(content) + assert stats["total_executions"] == 100 + assert stats["success_count"] == 90 + assert stats["fail_count"] == 10 + + def test_parse_chinese_format_with_colon(self, extractor): + content = """ +总执行次数:50 +成功率:80.0%(40 成功,10 失败) +""" + stats = extractor._parse_skill_statistics(content) + assert stats["total_executions"] == 50 + assert stats["success_count"] == 40 + assert stats["fail_count"] == 10 + + def test_parse_english_format_full(self, extractor): + content = """ +Skill: test_skill + +Based on 75 historical executions: +- Success rate: 85.0% (64 successful, 11 failed) +""" + stats = extractor._parse_skill_statistics(content) + assert stats["total_executions"] == 75 + assert stats["success_count"] == 64 + assert stats["fail_count"] == 11 + + def test_parse_english_success_rate_only(self, extractor): + content = """ +Based on 60 historical executions: +- Success rate: 75.0% +""" + stats = extractor._parse_skill_statistics(content) + assert stats["total_executions"] == 60 + assert stats["success_count"] == 45 + assert stats["fail_count"] == 15 + + def test_parse_empty_content(self, extractor): + stats = extractor._parse_skill_statistics("") + assert stats["total_executions"] == 0 + assert stats["success_count"] == 0 + assert stats["fail_count"] == 0 + + def test_parse_no_total_executions_infers_from_success_fail(self, extractor): + content = """ +成功率: 
70.0%(35 成功,15 失败) +""" + stats = extractor._parse_skill_statistics(content) + assert stats["total_executions"] == 50 + assert stats["success_count"] == 35 + assert stats["fail_count"] == 15 + + +class TestMergeSkillStatistics: + def test_merge_basic(self, extractor): + existing = { + "total_executions": 100, + "success_count": 90, + "fail_count": 10, + } + new = { + "total_executions": 50, + "success_count": 45, + "fail_count": 5, + } + merged = extractor._merge_skill_statistics(existing, new) + assert merged["total_executions"] == 150 + assert merged["success_count"] == 135 + assert merged["fail_count"] == 15 + assert abs(merged["success_rate"] - 0.9) < 0.01 + + def test_merge_with_zero_existing(self, extractor): + existing = { + "total_executions": 0, + "success_count": 0, + "fail_count": 0, + } + new = { + "total_executions": 20, + "success_count": 18, + "fail_count": 2, + } + merged = extractor._merge_skill_statistics(existing, new) + assert merged["total_executions"] == 20 + assert merged["success_count"] == 18 + assert merged["fail_count"] == 2 + + def test_merge_with_zero_new(self, extractor): + existing = { + "total_executions": 30, + "success_count": 25, + "fail_count": 5, + } + new = { + "total_executions": 0, + "success_count": 0, + "fail_count": 0, + } + merged = extractor._merge_skill_statistics(existing, new) + assert merged["total_executions"] == 30 + assert merged["success_count"] == 25 + assert merged["fail_count"] == 5 + + def test_merge_both_zero(self, extractor): + existing = { + "total_executions": 0, + "success_count": 0, + "fail_count": 0, + } + new = { + "total_executions": 0, + "success_count": 0, + "fail_count": 0, + } + merged = extractor._merge_skill_statistics(existing, new) + assert merged["total_executions"] == 0 + assert merged["success_rate"] == 0 + + +class TestGenerateSkillMemoryContent: + def test_generate_basic(self, extractor): + stats = { + "total_executions": 100, + "success_count": 90, + "fail_count": 10, + "success_rate": 0.9, + } + guidelines = "Use this skill for data processing." + content = extractor._generate_skill_memory_content("test_skill", stats, guidelines) + assert "Skill: test_skill" in content + assert "Based on 100 historical executions:" in content + assert "Success rate: 90.0%" in content + assert "90 successful, 10 failed" in content + assert "Use this skill for data processing." 
in content + + def test_generate_with_fields(self, extractor): + stats = { + "total_executions": 50, + "success_count": 45, + "fail_count": 5, + "success_rate": 0.9, + } + fields = { + "best_for": "Automated workflows", + "recommended_flow": "Step 1 -> Step 2 -> Step 3", + "key_dependencies": "Database connection", + "common_failures": "Network timeout", + "recommendation": "Use with retry logic", + } + content = extractor._generate_skill_memory_content("test_skill", stats, "", fields=fields) + assert "Best for: Automated workflows" in content + assert "Recommended flow: Step 1 -> Step 2 -> Step 3" in content + assert "Key dependencies: Database connection" in content + assert "Common failures: Network timeout" in content + assert "Recommendation: Use with retry logic" in content + + def test_generate_with_empty_fields(self, extractor): + stats = { + "total_executions": 10, + "success_count": 10, + "fail_count": 0, + "success_rate": 1.0, + } + content = extractor._generate_skill_memory_content("test_skill", stats, "", fields={}) + assert "Best for: " in content + assert "Recommended flow: " in content + + def test_generate_extracts_fields_from_guidelines(self, extractor): + stats = { + "total_executions": 20, + "success_count": 18, + "fail_count": 2, + "success_rate": 0.9, + } + guidelines = """ +Best for: Complex data transformations +Recommended flow: Validate -> Transform -> Store +Key dependencies: S3 bucket access +Common failures: Permission denied +Recommendation: Check permissions first +""" + content = extractor._generate_skill_memory_content("test_skill", stats, guidelines) + assert "Best for: Complex data transformations" in content + assert "Recommended flow: Validate -> Transform -> Store" in content + + +class TestMergeKvField: + @pytest.mark.asyncio + async def test_merge_both_empty(self, extractor): + result = await extractor._merge_kv_field("", "", "best_for") + assert result == "" + + @pytest.mark.asyncio + async def test_merge_existing_empty(self, extractor): + result = await extractor._merge_kv_field("", "new value", "best_for") + assert result == "new value" + + @pytest.mark.asyncio + async def test_merge_new_empty(self, extractor): + result = await extractor._merge_kv_field("existing value", "", "best_for") + assert result == "existing value" + + @pytest.mark.asyncio + async def test_merge_identical_values(self, extractor): + result = await extractor._merge_kv_field("same value", "same value", "best_for") + assert result == "same value" + + @pytest.mark.asyncio + async def test_merge_different_values(self, extractor): + result = await extractor._merge_kv_field("value A", "value B", "best_for") + assert "value A" in result + assert "value B" in result + assert ";" in result + + @pytest.mark.asyncio + async def test_merge_with_semicolon_separator(self, extractor): + result = await extractor._merge_kv_field("item1; item2", "item3", "best_for") + assert "item1" in result + assert "item2" in result + assert "item3" in result + + @pytest.mark.asyncio + async def test_merge_deduplicates(self, extractor): + result = await extractor._merge_kv_field("item1; item2", "item2; item3", "best_for") + assert result.count("item2") == 1 + assert "item1" in result + assert "item3" in result + + @pytest.mark.asyncio + async def test_merge_respects_max_length(self, extractor): + long_value = "x" * 600 + result = await extractor._merge_kv_field(long_value, "new", "best_for") + assert len(result) <= FIELD_MAX_LENGTHS["best_for"] + + @pytest.mark.asyncio + async def 
test_merge_with_newline_separator(self, extractor): + result = await extractor._merge_kv_field("item1\nitem2", "item3", "best_for") + assert "item1" in result + assert "item2" in result + assert "item3" in result + + +class TestSmartTruncate: + def test_no_truncation_needed(self, extractor): + text = "short text" + result = extractor._smart_truncate(text, 100) + assert result == text + + def test_truncate_at_semicolon(self, extractor): + text = "item1; item2; item3; item4; item5" + result = extractor._smart_truncate(text, 25) + assert len(result) <= 25 + assert result.endswith(";") or result.count(";") >= 1 + + def test_truncate_at_space(self, extractor): + text = "word1 word2 word3 word4 word5" + result = extractor._smart_truncate(text, 20) + assert len(result) <= 20 + + def test_truncate_fallback(self, extractor): + text = "abcdefghijklmnopqrstuvwxyz" + result = extractor._smart_truncate(text, 10) + assert len(result) == 10 + assert result == "abcdefghij" + + def test_truncate_empty_string(self, extractor): + result = extractor._smart_truncate("", 10) + assert result == "" + + def test_truncate_exact_length(self, extractor): + text = "exactly10!" + result = extractor._smart_truncate(text, 10) + assert result == text + + +class TestComputeStatisticsDerived: + def test_compute_with_calls(self, extractor): + stats = { + "total_calls": 100, + "success_count": 80, + "fail_count": 20, + "total_time_ms": 10000.0, + "total_tokens": 50000, + } + result = extractor._compute_statistics_derived(stats) + assert abs(result["avg_time_ms"] - 100.0) < 0.01 + assert abs(result["avg_tokens"] - 500.0) < 0.01 + assert abs(result["success_rate"] - 0.8) < 0.01 + + def test_compute_with_zero_calls(self, extractor): + stats = { + "total_calls": 0, + "success_count": 0, + "fail_count": 0, + "total_time_ms": 0, + "total_tokens": 0, + } + result = extractor._compute_statistics_derived(stats) + assert result["avg_time_ms"] == 0 + assert result["avg_tokens"] == 0 + assert result["success_rate"] == 0 + + def test_compute_preserves_original_values(self, extractor): + stats = { + "total_calls": 50, + "success_count": 40, + "fail_count": 10, + "total_time_ms": 5000.0, + "total_tokens": 25000, + } + result = extractor._compute_statistics_derived(stats) + assert result["total_calls"] == 50 + assert result["success_count"] == 40 + assert result["fail_count"] == 10 + assert result["total_time_ms"] == 5000.0 + assert result["total_tokens"] == 25000 + + +class TestFormatDuration: + def test_format_zero(self, extractor): + result = extractor._format_duration(0) + assert result == "0s" + + def test_format_milliseconds(self, extractor): + result = extractor._format_duration(500) + assert result == "500ms" + + def test_format_seconds(self, extractor): + result = extractor._format_duration(1500) + assert result == "1.5s" + + def test_format_large_seconds(self, extractor): + result = extractor._format_duration(10000) + assert result == "10.0s" + + def test_format_none(self, extractor): + result = extractor._format_duration(None) + assert result == "N/A" + + def test_format_negative(self, extractor): + result = extractor._format_duration(-100) + assert result == "0s" + + def test_format_invalid_type(self, extractor): + result = extractor._format_duration("invalid") + assert result == "N/A" + + def test_format_exactly_one_second(self, extractor): + result = extractor._format_duration(1000) + assert result == "1.0s" + + def test_format_just_under_one_second(self, extractor): + result = extractor._format_duration(999) + assert result == 
"999ms" + + +class TestExtractContentField: + def test_extract_with_chinese_colon(self, extractor): + content = "Best for:数据处理任务" + result = extractor._extract_content_field(content, ["Best for"]) + assert result == "数据处理任务" + + def test_extract_with_english_colon(self, extractor): + content = "Best for: data processing tasks" + result = extractor._extract_content_field(content, ["Best for"]) + assert result == "data processing tasks" + + def test_extract_with_multiple_keys(self, extractor): + content = "最佳场景: 快速验证" + result = extractor._extract_content_field(content, ["Best for", "最佳场景"]) + assert result == "快速验证" + + def test_extract_not_found(self, extractor): + content = "Some other content" + result = extractor._extract_content_field(content, ["Best for"]) + assert result == "" + + def test_extract_empty_content(self, extractor): + result = extractor._extract_content_field("", ["Best for"]) + assert result == "" + + +class TestCompactBlock: + def test_compact_basic(self, extractor): + text = "Line 1\nLine 2\nLine 3" + result = extractor._compact_block(text) + assert result == "Line 1; Line 2; Line 3" + + def test_compact_with_prefixes(self, extractor): + text = "> Point 1\n- Point 2\n* Point 3" + result = extractor._compact_block(text) + assert "Point 1" in result + assert "Point 2" in result + assert "Point 3" in result + + def test_compact_empty(self, extractor): + result = extractor._compact_block("") + assert result == "" + + def test_compact_whitespace_only(self, extractor): + result = extractor._compact_block(" \n \n ") + assert result == "" + + +class TestExtractToolMemoryContextFieldsFromText: + def test_extract_all_fields(self, extractor): + text = """ +Best for: Data processing +Optimal params: batch_size=100 +Common failures: Timeout +Recommendation: Use small batches +""" + result = extractor._extract_tool_memory_context_fields_from_text(text) + assert result["best_for"] == "Data processing" + assert result["optimal_params"] == "batch_size=100" + assert result["common_failures"] == "Timeout" + assert result["recommendation"] == "Use small batches" + + def test_extract_partial_fields(self, extractor): + text = """ +Best for: Testing +Recommendation: Run in dev mode +""" + result = extractor._extract_tool_memory_context_fields_from_text(text) + assert result["best_for"] == "Testing" + assert result["optimal_params"] == "" + assert result["common_failures"] == "" + assert result["recommendation"] == "Run in dev mode" + + def test_extract_chinese_fields(self, extractor): + text = """ +最佳场景: 数据处理 +最优参数: 批量大小=100 +常见失败: 超时 +推荐: 使用小批量 +""" + result = extractor._extract_tool_memory_context_fields_from_text(text) + assert result["best_for"] == "数据处理" + assert result["optimal_params"] == "批量大小=100" + assert result["common_failures"] == "超时" + assert result["recommendation"] == "使用小批量" + + +class TestExtractSkillMemoryContextFieldsFromText: + def test_extract_all_fields(self, extractor): + text = """ +Best for: Automated workflows +Recommended flow: Step1 -> Step2 -> Step3 +Key dependencies: Database +Common failures: Connection error +Recommendation: Use connection pool +""" + result = extractor._extract_skill_memory_context_fields_from_text(text) + assert result["best_for"] == "Automated workflows" + assert result["recommended_flow"] == "Step1 -> Step2 -> Step3" + assert result["key_dependencies"] == "Database" + assert result["common_failures"] == "Connection error" + assert result["recommendation"] == "Use connection pool" + + def test_extract_chinese_fields(self, extractor): + 
text = """ +最佳场景: 自动化工作流 +推荐流程: 步骤1 -> 步骤2 -> 步骤3 +关键依赖: 数据库 +常见失败: 连接错误 +推荐: 使用连接池 +""" + result = extractor._extract_skill_memory_context_fields_from_text(text) + assert result["best_for"] == "自动化工作流" + assert result["recommended_flow"] == "步骤1 -> 步骤2 -> 步骤3" + assert result["key_dependencies"] == "数据库" + assert result["common_failures"] == "连接错误" + assert result["recommendation"] == "使用连接池" + + +class TestFormatMs: + def test_format_zero(self, extractor): + result = extractor._format_ms(0) + assert result == "0.000ms" + + def test_format_normal_value(self, extractor): + result = extractor._format_ms(123.456) + assert result == "123.456ms" + + def test_format_very_small_value(self, extractor): + result = extractor._format_ms(0.000123) + assert "ms" in result + assert float(result.replace("ms", "")) > 0 + + def test_format_large_value(self, extractor): + result = extractor._format_ms(9999.999) + assert result == "9999.999ms" diff --git a/tests/unit/session/test_session_cow.py b/tests/unit/session/test_session_cow.py new file mode 100644 index 00000000..26ed353d --- /dev/null +++ b/tests/unit/session/test_session_cow.py @@ -0,0 +1,495 @@ +# Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd. +# SPDX-License-Identifier: Apache-2.0 +"""Unit tests for Session COW (Copy-on-Write) mode and async commit functionality.""" + +import time +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from openviking.server.identity import RequestContext, Role +from openviking.session.session import Session +from openviking_cli.session.user_id import UserIdentifier + + +def _make_user() -> UserIdentifier: + return UserIdentifier("test_account", "test_user", "test_agent") + + +def _make_session(viking_fs: MagicMock = None, session_id: str = "test_session_123") -> Session: + user = _make_user() + ctx = RequestContext(user=user, role=Role.ROOT) + fs = viking_fs or MagicMock() + return Session( + viking_fs=fs, + user=user, + ctx=ctx, + session_id=session_id, + ) + + +class TestCreateTempUris: + """Tests for _create_temp_uris() method.""" + + def test_returns_tuple_of_four_uris(self): + session = _make_session() + result = session._create_temp_uris() + + assert isinstance(result, tuple) + assert len(result) == 4 + + def test_temp_base_uri_format(self): + session = _make_session(session_id="sess_abc") + user = session.user + temp_base, _, _, _ = session._create_temp_uris() + + assert temp_base.startswith("viking://temp/session/") + assert f"/{user.user_space_name()}/" in temp_base + assert "/sess_abc/" in temp_base + assert "/commit_" in temp_base + + def test_session_temp_uri_structure(self): + session = _make_session(session_id="sess_abc") + user = session.user + temp_base, session_temp, _, _ = session._create_temp_uris() + + assert session_temp.startswith(temp_base) + assert "/session/" in session_temp + assert f"/{user.user_space_name()}/sess_abc" in session_temp + + def test_user_temp_uri_structure(self): + session = _make_session() + user = session.user + temp_base, _, user_temp, _ = session._create_temp_uris() + + assert user_temp.startswith(temp_base) + assert "/user/" in user_temp + assert user_temp.endswith(f"/user/{user.user_space_name()}") + + def test_agent_temp_uri_structure(self): + session = _make_session() + user = session.user + temp_base, _, _, agent_temp = session._create_temp_uris() + + assert agent_temp.startswith(temp_base) + assert "/agent/" in agent_temp + assert agent_temp.endswith(f"/agent/{user.agent_space_name()}") + + def test_sets_internal_state(self): 
+ session = _make_session() + temp_base, session_temp, user_temp, agent_temp = session._create_temp_uris() + + assert session._temp_base_uri == temp_base + assert session._session_temp_uri == session_temp + assert session._user_temp_uri == user_temp + assert session._agent_temp_uri == agent_temp + assert session._temp_created_at is not None + + def test_temp_created_at_is_recent(self): + session = _make_session() + before = time.time() + session._create_temp_uris() + after = time.time() + + assert before <= session._temp_created_at <= after + + def test_commit_uuid_is_8_chars(self): + session = _make_session() + temp_base, _, _, _ = session._create_temp_uris() + + commit_part = temp_base.split("/commit_")[-1] + assert len(commit_part) == 8 + assert all(c in "0123456789abcdef" for c in commit_part) + + def test_multiple_calls_generate_different_uuids(self): + session = _make_session() + temp_base1, _, _, _ = session._create_temp_uris() + temp_base2, _, _, _ = session._create_temp_uris() + + assert temp_base1 != temp_base2 + + +class TestCleanupTempUris: + """Tests for _cleanup_temp_uris() method.""" + + @pytest.mark.asyncio + async def test_calls_delete_temp_on_viking_fs(self): + viking_fs = MagicMock() + viking_fs.delete_temp = AsyncMock() + session = _make_session(viking_fs=viking_fs) + + session._create_temp_uris() + saved_temp_base = session._temp_base_uri + await session._cleanup_temp_uris() + + viking_fs.delete_temp.assert_called_once() + call_args = viking_fs.delete_temp.call_args + assert call_args[0][0] == saved_temp_base + + @pytest.mark.asyncio + async def test_resets_internal_state(self): + viking_fs = MagicMock() + viking_fs.delete_temp = AsyncMock() + session = _make_session(viking_fs=viking_fs) + + session._create_temp_uris() + await session._cleanup_temp_uris() + + assert session._temp_base_uri is None + assert session._session_temp_uri is None + assert session._user_temp_uri is None + assert session._agent_temp_uri is None + assert session._temp_created_at is None + + @pytest.mark.asyncio + async def test_no_cleanup_when_no_temp_uri(self): + viking_fs = MagicMock() + viking_fs.delete_temp = AsyncMock() + session = _make_session(viking_fs=viking_fs) + + await session._cleanup_temp_uris() + + viking_fs.delete_temp.assert_not_called() + + @pytest.mark.asyncio + async def test_handles_delete_exception(self): + viking_fs = MagicMock() + viking_fs.delete_temp = AsyncMock(side_effect=Exception("Delete failed")) + session = _make_session(viking_fs=viking_fs) + + session._create_temp_uris() + await session._cleanup_temp_uris() + + assert session._temp_base_uri is None + + @pytest.mark.asyncio + async def test_passes_ctx_to_delete_temp(self): + viking_fs = MagicMock() + viking_fs.delete_temp = AsyncMock() + session = _make_session(viking_fs=viking_fs) + + session._create_temp_uris() + await session._cleanup_temp_uris() + + call_kwargs = viking_fs.delete_temp.call_args[1] + assert "ctx" in call_kwargs + assert call_kwargs["ctx"] == session.ctx + + +class TestEnqueueToSemanticQueue: + """Tests for _enqueue_to_semantic_queue() method.""" + + @pytest.mark.asyncio + async def test_returns_list_of_three_msg_ids(self): + session = _make_session() + + mock_queue = MagicMock() + mock_queue.enqueue = AsyncMock() + + mock_queue_manager = MagicMock() + mock_queue_manager.SEMANTIC = "semantic" + mock_queue_manager.get_queue = MagicMock(return_value=mock_queue) + + with patch( + "openviking.storage.queuefs.get_queue_manager", + return_value=mock_queue_manager, + ): + result = await 
session._enqueue_to_semantic_queue( + session_temp_uri="viking://temp/session/test", + user_temp_uri="viking://temp/user/test", + agent_temp_uri="viking://temp/agent/test", + ) + + assert isinstance(result, list) + assert len(result) == 3 + + @pytest.mark.asyncio + async def test_session_msg_has_correct_target_uri(self): + session = _make_session(session_id="sess_xyz") + user = session.user + + enqueued_msgs = [] + + async def capture_enqueue(msg): + enqueued_msgs.append(msg) + + mock_queue = MagicMock() + mock_queue.enqueue = capture_enqueue + + mock_queue_manager = MagicMock() + mock_queue_manager.SEMANTIC = "semantic" + mock_queue_manager.get_queue = MagicMock(return_value=mock_queue) + + session_temp = f"viking://temp/session/{user.user_space_name()}/sess_xyz/commit_abc123/session/{user.user_space_name()}/sess_xyz" + + with patch( + "openviking.storage.queuefs.get_queue_manager", + return_value=mock_queue_manager, + ): + await session._enqueue_to_semantic_queue( + session_temp_uri=session_temp, + user_temp_uri="viking://temp/user/test", + agent_temp_uri="viking://temp/agent/test", + ) + + session_msg = enqueued_msgs[0] + expected_target = f"viking://session/{user.user_space_name()}/sess_xyz" + assert session_msg.target_uri == expected_target + assert session_msg.uri == session_temp + + @pytest.mark.asyncio + async def test_user_msg_has_correct_target_uri(self): + session = _make_session() + user = session.user + + enqueued_msgs = [] + + async def capture_enqueue(msg): + enqueued_msgs.append(msg) + + mock_queue = MagicMock() + mock_queue.enqueue = capture_enqueue + + mock_queue_manager = MagicMock() + mock_queue_manager.SEMANTIC = "semantic" + mock_queue_manager.get_queue = MagicMock(return_value=mock_queue) + + user_temp = f"viking://temp/session/{user.user_space_name()}/sess/commit_abc/user/{user.user_space_name()}" + + with patch( + "openviking.storage.queuefs.get_queue_manager", + return_value=mock_queue_manager, + ): + await session._enqueue_to_semantic_queue( + session_temp_uri="viking://temp/session/test", + user_temp_uri=user_temp, + agent_temp_uri="viking://temp/agent/test", + ) + + user_msg = enqueued_msgs[1] + expected_target = f"viking://user/{user.user_space_name()}" + assert user_msg.target_uri == expected_target + assert user_msg.uri == user_temp + + @pytest.mark.asyncio + async def test_agent_msg_has_correct_target_uri(self): + session = _make_session() + user = session.user + + enqueued_msgs = [] + + async def capture_enqueue(msg): + enqueued_msgs.append(msg) + + mock_queue = MagicMock() + mock_queue.enqueue = capture_enqueue + + mock_queue_manager = MagicMock() + mock_queue_manager.SEMANTIC = "semantic" + mock_queue_manager.get_queue = MagicMock(return_value=mock_queue) + + agent_temp = f"viking://temp/session/{user.user_space_name()}/sess/commit_abc/agent/{user.agent_space_name()}" + + with patch( + "openviking.storage.queuefs.get_queue_manager", + return_value=mock_queue_manager, + ): + await session._enqueue_to_semantic_queue( + session_temp_uri="viking://temp/session/test", + user_temp_uri="viking://temp/user/test", + agent_temp_uri=agent_temp, + ) + + agent_msg = enqueued_msgs[2] + expected_target = f"viking://agent/{user.agent_space_name()}" + assert agent_msg.target_uri == expected_target + assert agent_msg.uri == agent_temp + + @pytest.mark.asyncio + async def test_all_msgs_have_context_type_memory(self): + session = _make_session() + + enqueued_msgs = [] + + async def capture_enqueue(msg): + enqueued_msgs.append(msg) + + mock_queue = MagicMock() + 
mock_queue.enqueue = capture_enqueue + + mock_queue_manager = MagicMock() + mock_queue_manager.SEMANTIC = "semantic" + mock_queue_manager.get_queue = MagicMock(return_value=mock_queue) + + with patch( + "openviking.storage.queuefs.get_queue_manager", + return_value=mock_queue_manager, + ): + await session._enqueue_to_semantic_queue( + session_temp_uri="viking://temp/session/test", + user_temp_uri="viking://temp/user/test", + agent_temp_uri="viking://temp/agent/test", + ) + + for msg in enqueued_msgs: + assert msg.context_type == "memory" + + @pytest.mark.asyncio + async def test_all_msgs_have_recursive_true(self): + session = _make_session() + + enqueued_msgs = [] + + async def capture_enqueue(msg): + enqueued_msgs.append(msg) + + mock_queue = MagicMock() + mock_queue.enqueue = capture_enqueue + + mock_queue_manager = MagicMock() + mock_queue_manager.SEMANTIC = "semantic" + mock_queue_manager.get_queue = MagicMock(return_value=mock_queue) + + with patch( + "openviking.storage.queuefs.get_queue_manager", + return_value=mock_queue_manager, + ): + await session._enqueue_to_semantic_queue( + session_temp_uri="viking://temp/session/test", + user_temp_uri="viking://temp/user/test", + agent_temp_uri="viking://temp/agent/test", + ) + + for msg in enqueued_msgs: + assert msg.recursive is True + + @pytest.mark.asyncio + async def test_msgs_have_correct_user_context(self): + user = _make_user() + session = _make_session() + + enqueued_msgs = [] + + async def capture_enqueue(msg): + enqueued_msgs.append(msg) + + mock_queue = MagicMock() + mock_queue.enqueue = capture_enqueue + + mock_queue_manager = MagicMock() + mock_queue_manager.SEMANTIC = "semantic" + mock_queue_manager.get_queue = MagicMock(return_value=mock_queue) + + with patch( + "openviking.storage.queuefs.get_queue_manager", + return_value=mock_queue_manager, + ): + await session._enqueue_to_semantic_queue( + session_temp_uri="viking://temp/session/test", + user_temp_uri="viking://temp/user/test", + agent_temp_uri="viking://temp/agent/test", + ) + + for msg in enqueued_msgs: + assert msg.account_id == user.account_id + assert msg.user_id == user.user_id + assert msg.agent_id == user.agent_id + + +class TestTempUriStructureMatchesTarget: + """Tests for temp URI structure matching target URI structure.""" + + def test_session_temp_uri_contains_target_path(self): + session = _make_session(session_id="sess_123") + user = session.user + + temp_base, session_temp, _, _ = session._create_temp_uris() + + target_path = f"/session/{user.user_space_name()}/sess_123" + assert session_temp.endswith(target_path) + + def test_user_temp_uri_contains_target_path(self): + session = _make_session() + user = session.user + + temp_base, _, user_temp, _ = session._create_temp_uris() + + target_path = f"/user/{user.user_space_name()}" + assert user_temp.endswith(target_path) + + def test_agent_temp_uri_contains_target_path(self): + session = _make_session() + user = session.user + + temp_base, _, _, agent_temp = session._create_temp_uris() + + target_path = f"/agent/{user.agent_space_name()}" + assert agent_temp.endswith(target_path) + + def test_all_temp_uris_share_same_base(self): + session = _make_session() + + temp_base, session_temp, user_temp, agent_temp = session._create_temp_uris() + + assert session_temp.startswith(temp_base) + assert user_temp.startswith(temp_base) + assert agent_temp.startswith(temp_base) + + def test_temp_uri_structure_allows_semantic_dag_recursive_processing(self): + session = _make_session(session_id="sess_xyz") + user = session.user 
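+        # Each temp subtree ends with the same path shape as its final destination
+        # (session/user/agent), which is what allows the semantic DAG to recurse
+        # over the temp tree exactly as it would over the committed tree.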
+ + temp_base, session_temp, user_temp, agent_temp = session._create_temp_uris() + + assert "/session/" in session_temp + assert f"/{user.user_space_name()}/sess_xyz" in session_temp + + assert "/user/" in user_temp + assert f"/{user.user_space_name()}" in user_temp + + assert "/agent/" in agent_temp + assert f"/{user.agent_space_name()}" in agent_temp + + +class TestTempUriWithDifferentUsers: + """Tests for temp URI generation with different user configurations.""" + + def test_different_user_space_names(self): + user1 = UserIdentifier("acc1", "alice", "agent1") + user2 = UserIdentifier("acc2", "bob", "agent2") + + session1 = Session( + viking_fs=MagicMock(), + user=user1, + ctx=RequestContext(user=user1, role=Role.ROOT), + session_id="sess1", + ) + session2 = Session( + viking_fs=MagicMock(), + user=user2, + ctx=RequestContext(user=user2, role=Role.ROOT), + session_id="sess2", + ) + + _, session_temp1, user_temp1, agent_temp1 = session1._create_temp_uris() + _, session_temp2, user_temp2, agent_temp2 = session2._create_temp_uris() + + assert "alice" in session_temp1 + assert "bob" in session_temp2 + assert user_temp1 != user_temp2 + assert agent_temp1 != agent_temp2 + + def test_agent_space_name_is_hashed(self): + user = UserIdentifier("acc", "myuser", "myagent") + session = Session( + viking_fs=MagicMock(), + user=user, + ctx=RequestContext(user=user, role=Role.ROOT), + session_id="sess", + ) + + _, _, _, agent_temp = session._create_temp_uris() + + assert "myagent" not in agent_temp + assert user.agent_space_name() in agent_temp + assert len(user.agent_space_name()) == 12 diff --git a/tests/unit/storage/queuefs/__init__.py b/tests/unit/storage/queuefs/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/unit/storage/queuefs/test_dag_incremental.py b/tests/unit/storage/queuefs/test_dag_incremental.py new file mode 100644 index 00000000..08082ce7 --- /dev/null +++ b/tests/unit/storage/queuefs/test_dag_incremental.py @@ -0,0 +1,585 @@ +# Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd. 
+# SPDX-License-Identifier: Apache-2.0 +"""Unit tests for SemanticDagExecutor incremental update and content change detection.""" + +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from openviking.server.identity import RequestContext, Role +from openviking.storage.queuefs.semantic_dag import SemanticDagExecutor +from openviking_cli.session.user_id import UserIdentifier + + +@pytest.fixture +def mock_processor(): + """Create a mock SemanticProcessor.""" + processor = MagicMock() + processor._generate_single_file_summary = AsyncMock( + return_value={"name": "test.py", "summary": "test summary"} + ) + processor._generate_overview = AsyncMock(return_value="test overview") + processor._extract_abstract_from_overview = MagicMock(return_value="test abstract") + processor._vectorize_single_file = AsyncMock() + processor._vectorize_directory = AsyncMock() + return processor + + +@pytest.fixture +def mock_viking_fs(): + """Create a mock VikingFS.""" + fs = MagicMock() + fs.ls = AsyncMock(return_value=[]) + fs.read_file = AsyncMock(return_value="") + fs.write_file = AsyncMock() + fs._get_vector_store = MagicMock(return_value=None) + return fs + + +@pytest.fixture +def mock_vector_store(): + """Create a mock VectorStore.""" + store = MagicMock() + store.get_context_by_uri = AsyncMock(return_value=[]) + return store + + +@pytest.fixture +def mock_context(): + """Create a mock RequestContext.""" + user = MagicMock(spec=UserIdentifier) + user.account_id = "test_account" + user.user_id = "test_user" + return RequestContext(user=user, role=Role.USER) + + +@pytest.fixture +def executor(mock_processor, mock_context, mock_viking_fs): + """Create a SemanticDagExecutor instance for testing.""" + with patch( + "openviking.storage.queuefs.semantic_dag.get_viking_fs", return_value=mock_viking_fs + ): + executor = SemanticDagExecutor( + processor=mock_processor, + context_type="resource", + max_concurrent_llm=5, + ctx=mock_context, + incremental_update=True, + target_uri="viking://resource/target", + recursive=True, + ) + return executor + + +class TestGetTargetFilePath: + """Tests for _get_target_file_path() method.""" + + def test_returns_none_when_incremental_update_disabled( + self, mock_processor, mock_context, mock_viking_fs + ): + with patch( + "openviking.storage.queuefs.semantic_dag.get_viking_fs", return_value=mock_viking_fs + ): + executor = SemanticDagExecutor( + processor=mock_processor, + context_type="resource", + max_concurrent_llm=5, + ctx=mock_context, + incremental_update=False, + target_uri="viking://resource/target", + ) + executor._root_uri = "viking://resource/root" + result = executor._get_target_file_path("viking://resource/root/file.py") + assert result is None + + def test_returns_none_when_target_uri_is_none( + self, mock_processor, mock_context, mock_viking_fs + ): + with patch( + "openviking.storage.queuefs.semantic_dag.get_viking_fs", return_value=mock_viking_fs + ): + executor = SemanticDagExecutor( + processor=mock_processor, + context_type="resource", + max_concurrent_llm=5, + ctx=mock_context, + incremental_update=True, + target_uri=None, + ) + executor._root_uri = "viking://resource/root" + result = executor._get_target_file_path("viking://resource/root/file.py") + assert result is None + + def test_returns_none_when_root_uri_is_none(self, executor): + executor._root_uri = None + result = executor._get_target_file_path("viking://resource/root/file.py") + assert result is None + + def test_returns_target_path_for_file_in_root(self, executor): + 
executor._root_uri = "viking://resource/root" + result = executor._get_target_file_path("viking://resource/root/file.py") + assert result == "viking://resource/target/file.py" + + def test_returns_target_path_for_file_in_subdirectory(self, executor): + executor._root_uri = "viking://resource/root" + result = executor._get_target_file_path("viking://resource/root/subdir/file.py") + assert result == "viking://resource/target/subdir/file.py" + + def test_returns_target_uri_when_current_uri_equals_root(self, executor): + executor._root_uri = "viking://resource/root" + result = executor._get_target_file_path("viking://resource/root") + assert result == "viking://resource/target" + + def test_handles_nested_paths(self, executor): + executor._root_uri = "viking://resource/root" + result = executor._get_target_file_path("viking://resource/root/a/b/c/file.py") + assert result == "viking://resource/target/a/b/c/file.py" + + def test_handles_path_prefix_matching(self, executor): + executor._root_uri = "viking://resource/root" + result = executor._get_target_file_path("viking://resource/rootdir/file.py") + assert result == "viking://resource/target/dir/file.py" + + def test_returns_none_on_exception(self, executor): + executor._root_uri = "viking://resource/root" + result = executor._get_target_file_path(None) + assert result is None + + +class TestCheckFileContentChanged: + """Tests for _check_file_content_changed() method.""" + + @pytest.mark.asyncio + async def test_returns_true_when_target_path_is_none(self, executor): + executor._root_uri = None + result = await executor._check_file_content_changed("viking://resource/root/file.py") + assert result is True + + @pytest.mark.asyncio + async def test_returns_true_when_content_differs(self, executor, mock_viking_fs): + executor._viking_fs = mock_viking_fs + executor._root_uri = "viking://resource/root" + + mock_viking_fs.read_file = AsyncMock(side_effect=["current content", "target content"]) + + result = await executor._check_file_content_changed("viking://resource/root/file.py") + assert result is True + + @pytest.mark.asyncio + async def test_returns_false_when_content_identical(self, executor, mock_viking_fs): + executor._viking_fs = mock_viking_fs + executor._root_uri = "viking://resource/root" + + mock_viking_fs.read_file = AsyncMock(return_value="same content") + + result = await executor._check_file_content_changed("viking://resource/root/file.py") + assert result is False + + @pytest.mark.asyncio + async def test_returns_true_on_read_exception(self, executor, mock_viking_fs): + executor._viking_fs = mock_viking_fs + executor._root_uri = "viking://resource/root" + + mock_viking_fs.read_file = AsyncMock(side_effect=Exception("read error")) + + result = await executor._check_file_content_changed("viking://resource/root/file.py") + assert result is True + + @pytest.mark.asyncio + async def test_calls_read_file_with_correct_paths(self, executor, mock_viking_fs): + executor._viking_fs = mock_viking_fs + executor._root_uri = "viking://resource/root" + + mock_viking_fs.read_file = AsyncMock(return_value="content") + + await executor._check_file_content_changed("viking://resource/root/subdir/file.py") + + assert mock_viking_fs.read_file.call_count == 2 + calls = mock_viking_fs.read_file.call_args_list + assert calls[0][0][0] == "viking://resource/root/subdir/file.py" + assert calls[1][0][0] == "viking://resource/target/subdir/file.py" + + +class TestReadExistingSummary: + """Tests for _read_existing_summary() method.""" + + @pytest.mark.asyncio + 
async def test_returns_none_when_target_path_is_none(self, executor): + executor._root_uri = None + result = await executor._read_existing_summary("viking://resource/root/file.py") + assert result is None + + @pytest.mark.asyncio + async def test_returns_none_when_vector_store_is_none(self, executor, mock_viking_fs): + executor._viking_fs = mock_viking_fs + executor._root_uri = "viking://resource/root" + mock_viking_fs._get_vector_store = MagicMock(return_value=None) + + result = await executor._read_existing_summary("viking://resource/root/file.py") + assert result is None + + @pytest.mark.asyncio + async def test_returns_summary_dict_when_record_exists( + self, executor, mock_viking_fs, mock_vector_store + ): + executor._viking_fs = mock_viking_fs + executor._root_uri = "viking://resource/root" + mock_viking_fs._get_vector_store = MagicMock(return_value=mock_vector_store) + + mock_vector_store.get_context_by_uri = AsyncMock( + return_value=[{"abstract": "existing summary content"}] + ) + + result = await executor._read_existing_summary("viking://resource/root/subdir/file.py") + + assert result is not None + assert result["name"] == "file.py" + assert result["summary"] == "existing summary content" + + @pytest.mark.asyncio + async def test_returns_none_when_no_records(self, executor, mock_viking_fs, mock_vector_store): + executor._viking_fs = mock_viking_fs + executor._root_uri = "viking://resource/root" + mock_viking_fs._get_vector_store = MagicMock(return_value=mock_vector_store) + + mock_vector_store.get_context_by_uri = AsyncMock(return_value=[]) + + result = await executor._read_existing_summary("viking://resource/root/file.py") + assert result is None + + @pytest.mark.asyncio + async def test_returns_none_when_abstract_is_empty( + self, executor, mock_viking_fs, mock_vector_store + ): + executor._viking_fs = mock_viking_fs + executor._root_uri = "viking://resource/root" + mock_viking_fs._get_vector_store = MagicMock(return_value=mock_vector_store) + + mock_vector_store.get_context_by_uri = AsyncMock(return_value=[{"abstract": ""}]) + + result = await executor._read_existing_summary("viking://resource/root/file.py") + assert result is None + + @pytest.mark.asyncio + async def test_returns_none_on_exception(self, executor, mock_viking_fs, mock_vector_store): + executor._viking_fs = mock_viking_fs + executor._root_uri = "viking://resource/root" + mock_viking_fs._get_vector_store = MagicMock(return_value=mock_vector_store) + + mock_vector_store.get_context_by_uri = AsyncMock(side_effect=Exception("db error")) + + result = await executor._read_existing_summary("viking://resource/root/file.py") + assert result is None + + @pytest.mark.asyncio + async def test_calls_vector_store_with_correct_uri( + self, executor, mock_viking_fs, mock_vector_store, mock_context + ): + executor._viking_fs = mock_viking_fs + executor._root_uri = "viking://resource/root" + mock_viking_fs._get_vector_store = MagicMock(return_value=mock_vector_store) + + mock_vector_store.get_context_by_uri = AsyncMock(return_value=[{"abstract": "summary"}]) + + await executor._read_existing_summary("viking://resource/root/file.py") + + mock_vector_store.get_context_by_uri.assert_called_once_with( + account_id=mock_context.account_id, + uri="viking://resource/target/file.py", + limit=1, + ) + + +class TestCheckDirChildrenChanged: + """Tests for _check_dir_children_changed() method.""" + + @pytest.mark.asyncio + async def test_returns_true_when_target_path_is_none(self, executor): + executor._root_uri = None + result = await 
executor._check_dir_children_changed( + "viking://resource/root", ["file1.py"], ["dir1"] + ) + assert result is True + + @pytest.mark.asyncio + async def test_returns_false_when_children_identical(self, executor, mock_viking_fs): + executor._viking_fs = mock_viking_fs + executor._root_uri = "viking://resource/root" + + current_files = ["viking://resource/root/file1.py", "viking://resource/root/file2.py"] + current_dirs = ["viking://resource/root/dir1", "viking://resource/root/dir2"] + + executor._list_dir = AsyncMock( + side_effect=[ + ( + ["viking://resource/target/dir1", "viking://resource/target/dir2"], + ["viking://resource/target/file1.py", "viking://resource/target/file2.py"], + ), + ([], []), + ] + ) + mock_viking_fs.read_file = AsyncMock(return_value="same content") + + result = await executor._check_dir_children_changed( + "viking://resource/root", current_files, current_dirs + ) + assert result is False + + @pytest.mark.asyncio + async def test_returns_true_when_file_names_differ(self, executor, mock_viking_fs): + executor._viking_fs = mock_viking_fs + executor._root_uri = "viking://resource/root" + + current_files = ["viking://resource/root/file1.py"] + current_dirs = [] + + mock_viking_fs.ls = AsyncMock( + return_value=[ + {"name": "file2.py", "isDir": False}, + ] + ) + + result = await executor._check_dir_children_changed( + "viking://resource/root", current_files, current_dirs + ) + assert result is True + + @pytest.mark.asyncio + async def test_returns_true_when_dir_names_differ(self, executor, mock_viking_fs): + executor._viking_fs = mock_viking_fs + executor._root_uri = "viking://resource/root" + + current_files = [] + current_dirs = ["viking://resource/root/dir1"] + + mock_viking_fs.ls = AsyncMock( + return_value=[ + {"name": "dir2", "isDir": True}, + ] + ) + + result = await executor._check_dir_children_changed( + "viking://resource/root", current_files, current_dirs + ) + assert result is True + + @pytest.mark.asyncio + async def test_returns_true_when_file_content_changed(self, executor, mock_viking_fs): + executor._viking_fs = mock_viking_fs + executor._root_uri = "viking://resource/root" + + current_files = ["viking://resource/root/file1.py"] + current_dirs = [] + + mock_viking_fs.ls = AsyncMock( + side_effect=[ + [{"name": "file1.py", "isDir": False}], + [], + ] + ) + mock_viking_fs.read_file = AsyncMock(side_effect=["old content", "new content"]) + + result = await executor._check_dir_children_changed( + "viking://resource/root", current_files, current_dirs + ) + assert result is True + + @pytest.mark.asyncio + async def test_returns_true_on_exception(self, executor, mock_viking_fs): + executor._viking_fs = mock_viking_fs + executor._root_uri = "viking://resource/root" + + mock_viking_fs.ls = AsyncMock(side_effect=Exception("ls error")) + + result = await executor._check_dir_children_changed( + "viking://resource/root", ["file1.py"], [] + ) + assert result is True + + @pytest.mark.asyncio + async def test_handles_empty_directories(self, executor, mock_viking_fs): + executor._viking_fs = mock_viking_fs + executor._root_uri = "viking://resource/root" + + mock_viking_fs.ls = AsyncMock(return_value=[]) + + result = await executor._check_dir_children_changed("viking://resource/root", [], []) + assert result is False + + @pytest.mark.asyncio + async def test_ignores_dot_files_in_comparison(self, executor, mock_viking_fs): + executor._viking_fs = mock_viking_fs + executor._root_uri = "viking://resource/root" + + current_files = ["viking://resource/root/file1.py"] + 
current_dirs = [] + + executor._list_dir = AsyncMock( + return_value=( + [], + ["viking://resource/target/file1.py"], + ) + ) + + result = await executor._check_dir_children_changed( + "viking://resource/root", current_files, current_dirs + ) + assert result is False + + +class TestReadExistingOverviewAbstract: + """Tests for _read_existing_overview_abstract() method.""" + + @pytest.mark.asyncio + async def test_returns_none_tuple_when_target_path_is_none(self, executor): + executor._root_uri = None + result = await executor._read_existing_overview_abstract("viking://resource/root") + assert result == (None, None) + + @pytest.mark.asyncio + async def test_returns_overview_and_abstract_when_files_exist(self, executor, mock_viking_fs): + executor._viking_fs = mock_viking_fs + executor._root_uri = "viking://resource/root" + + mock_viking_fs.read_file = AsyncMock(side_effect=["overview content", "abstract content"]) + + result = await executor._read_existing_overview_abstract("viking://resource/root/dir") + + assert result == ("overview content", "abstract content") + calls = mock_viking_fs.read_file.call_args_list + assert calls[0][0][0] == "viking://resource/target/dir/.overview.md" + assert calls[1][0][0] == "viking://resource/target/dir/.abstract.md" + + @pytest.mark.asyncio + async def test_returns_none_tuple_on_exception(self, executor, mock_viking_fs): + executor._viking_fs = mock_viking_fs + executor._root_uri = "viking://resource/root" + + mock_viking_fs.read_file = AsyncMock(side_effect=Exception("read error")) + + result = await executor._read_existing_overview_abstract("viking://resource/root/dir") + assert result == (None, None) + + @pytest.mark.asyncio + async def test_handles_missing_overview_file(self, executor, mock_viking_fs): + executor._viking_fs = mock_viking_fs + executor._root_uri = "viking://resource/root" + + mock_viking_fs.read_file = AsyncMock( + side_effect=[Exception("not found"), "abstract content"] + ) + + result = await executor._read_existing_overview_abstract("viking://resource/root/dir") + assert result == (None, None) + + @pytest.mark.asyncio + async def test_handles_missing_abstract_file(self, executor, mock_viking_fs): + executor._viking_fs = mock_viking_fs + executor._root_uri = "viking://resource/root" + + mock_viking_fs.read_file = AsyncMock( + side_effect=["overview content", Exception("not found")] + ) + + result = await executor._read_existing_overview_abstract("viking://resource/root/dir") + assert result == (None, None) + + @pytest.mark.asyncio + async def test_calls_read_file_with_context(self, executor, mock_viking_fs, mock_context): + executor._viking_fs = mock_viking_fs + executor._root_uri = "viking://resource/root" + + mock_viking_fs.read_file = AsyncMock(return_value="content") + + await executor._read_existing_overview_abstract("viking://resource/root/dir") + + for call in mock_viking_fs.read_file.call_args_list: + assert "ctx" in call[1] + assert call[1]["ctx"] == mock_context + + @pytest.mark.asyncio + async def test_handles_nested_directory_path(self, executor, mock_viking_fs): + executor._viking_fs = mock_viking_fs + executor._root_uri = "viking://resource/root" + + mock_viking_fs.read_file = AsyncMock(return_value="content") + + await executor._read_existing_overview_abstract("viking://resource/root/a/b/c") + + calls = mock_viking_fs.read_file.call_args_list + assert calls[0][0][0] == "viking://resource/target/a/b/c/.overview.md" + assert calls[1][0][0] == "viking://resource/target/a/b/c/.abstract.md" + + +class 
TestIncrementalUpdateIntegration: + """Integration tests for incremental update scenarios.""" + + @pytest.mark.asyncio + async def test_full_incremental_flow_no_changes( + self, executor, mock_viking_fs, mock_vector_store + ): + executor._viking_fs = mock_viking_fs + executor._root_uri = "viking://resource/root" + mock_viking_fs._get_vector_store = MagicMock(return_value=mock_vector_store) + + mock_viking_fs.ls = AsyncMock(return_value=[]) + mock_viking_fs.read_file = AsyncMock(return_value="same content") + mock_vector_store.get_context_by_uri = AsyncMock( + return_value=[{"abstract": "existing summary"}] + ) + + content_changed = await executor._check_file_content_changed( + "viking://resource/root/file.py" + ) + assert content_changed is False + + summary = await executor._read_existing_summary("viking://resource/root/file.py") + assert summary is not None + assert summary["summary"] == "existing summary" + + @pytest.mark.asyncio + async def test_full_incremental_flow_with_changes( + self, executor, mock_viking_fs, mock_vector_store + ): + executor._viking_fs = mock_viking_fs + executor._root_uri = "viking://resource/root" + mock_viking_fs._get_vector_store = MagicMock(return_value=mock_vector_store) + + mock_viking_fs.read_file = AsyncMock(side_effect=["new content", "old content"]) + + content_changed = await executor._check_file_content_changed( + "viking://resource/root/file.py" + ) + assert content_changed is True + + @pytest.mark.asyncio + async def test_directory_change_detection_flow(self, executor, mock_viking_fs): + executor._viking_fs = mock_viking_fs + executor._root_uri = "viking://resource/root" + + executor._list_dir = AsyncMock( + side_effect=[ + ([], ["viking://resource/target/file1.py", "viking://resource/target/file2.py"]), + ([], []), + ] + ) + mock_viking_fs.read_file = AsyncMock(return_value="same content") + + current_files = ["viking://resource/root/file1.py", "viking://resource/root/file2.py"] + changed = await executor._check_dir_children_changed( + "viking://resource/root", current_files, [] + ) + assert changed is False + + @pytest.mark.asyncio + async def test_overview_abstract_read_flow(self, executor, mock_viking_fs): + executor._viking_fs = mock_viking_fs + executor._root_uri = "viking://resource/root" + + mock_viking_fs.read_file = AsyncMock(side_effect=["existing overview", "existing abstract"]) + + overview, abstract = await executor._read_existing_overview_abstract( + "viking://resource/root/subdir" + ) + assert overview == "existing overview" + assert abstract == "existing abstract" diff --git a/tests/unit/storage/queuefs/test_embedding_msg.py b/tests/unit/storage/queuefs/test_embedding_msg.py new file mode 100644 index 00000000..3311ae96 --- /dev/null +++ b/tests/unit/storage/queuefs/test_embedding_msg.py @@ -0,0 +1,262 @@ +# Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd. 
+# SPDX-License-Identifier: Apache-2.0 +import json + +import pytest + +from openviking.storage.queuefs.embedding_msg import EmbeddingMsg + + +class TestEmbeddingMsg: + """Unit tests for EmbeddingMsg class.""" + + def test_semantic_msg_id_serialization(self): + """Test semantic_msg_id field serialization via to_dict().""" + msg = EmbeddingMsg( + message="test message", + context_data={"key": "value"}, + semantic_msg_id="semantic-123", + ) + result = msg.to_dict() + assert result["semantic_msg_id"] == "semantic-123" + assert result["message"] == "test message" + assert result["context_data"] == {"key": "value"} + assert hasattr(msg, "id") + assert msg.id is not None + + def test_semantic_msg_id_deserialization(self): + """Test semantic_msg_id field deserialization via from_dict().""" + data = { + "id": "test-id-123", + "message": "test message", + "context_data": {"key": "value"}, + "semantic_msg_id": "semantic-456", + } + msg = EmbeddingMsg.from_dict(data) + assert msg.semantic_msg_id == "semantic-456" + assert msg.id == "test-id-123" + assert msg.message == "test message" + assert msg.context_data == {"key": "value"} + + def test_from_dict_missing_semantic_msg_id_defaults_to_none(self): + """Test from_dict() compatibility with old format (missing semantic_msg_id).""" + data = { + "id": "test-id-789", + "message": "legacy message", + "context_data": {"legacy": True}, + } + msg = EmbeddingMsg.from_dict(data) + assert msg.semantic_msg_id is None + assert msg.id == "test-id-789" + assert msg.message == "legacy message" + + def test_from_dict_semantic_msg_id_explicit_none(self): + """Test from_dict() with explicit None for semantic_msg_id.""" + data = { + "id": "test-id-none", + "message": "message with None", + "context_data": {}, + "semantic_msg_id": None, + } + msg = EmbeddingMsg.from_dict(data) + assert msg.semantic_msg_id is None + + def test_to_json_with_semantic_msg_id(self): + """Test to_json() method with semantic_msg_id.""" + msg = EmbeddingMsg( + message="json test", + context_data={"json_key": "json_value"}, + semantic_msg_id="semantic-json", + ) + json_str = msg.to_json() + parsed = json.loads(json_str) + assert parsed["semantic_msg_id"] == "semantic-json" + assert parsed["message"] == "json test" + assert parsed["context_data"] == {"json_key": "json_value"} + + def test_to_json_without_semantic_msg_id(self): + """Test to_json() method without semantic_msg_id (None).""" + msg = EmbeddingMsg( + message="json test no id", + context_data={"key": "value"}, + ) + json_str = msg.to_json() + parsed = json.loads(json_str) + assert parsed["semantic_msg_id"] is None + + def test_from_json_with_semantic_msg_id(self): + """Test from_json() method with semantic_msg_id.""" + json_str = json.dumps( + { + "id": "json-id-123", + "message": "from json", + "context_data": {"from": "json"}, + "semantic_msg_id": "semantic-from-json", + } + ) + msg = EmbeddingMsg.from_json(json_str) + assert msg.semantic_msg_id == "semantic-from-json" + assert msg.id == "json-id-123" + assert msg.message == "from json" + + def test_from_json_missing_semantic_msg_id(self): + """Test from_json() with missing semantic_msg_id (backward compatibility).""" + json_str = json.dumps( + { + "id": "json-id-456", + "message": "legacy json", + "context_data": {"legacy": True}, + } + ) + msg = EmbeddingMsg.from_json(json_str) + assert msg.semantic_msg_id is None + assert msg.message == "legacy json" + + def test_from_json_invalid_json_raises_value_error(self): + """Test from_json() raises ValueError for invalid JSON.""" + with 
+            EmbeddingMsg.from_json("not a valid json")
+
+    def test_message_field_string_type(self):
+        """Test message field with string type."""
+        msg = EmbeddingMsg(
+            message="simple string message",
+            context_data={},
+        )
+        assert isinstance(msg.message, str)
+        assert msg.message == "simple string message"
+
+    def test_message_field_list_of_dicts_type(self):
+        """Test message field with List[Dict] type."""
+        message_list = [
+            {"role": "user", "content": "Hello"},
+            {"role": "assistant", "content": "Hi there"},
+        ]
+        msg = EmbeddingMsg(
+            message=message_list,
+            context_data={"conversation": True},
+        )
+        assert isinstance(msg.message, list)
+        assert len(msg.message) == 2
+        assert msg.message[0]["role"] == "user"
+        assert msg.message[1]["content"] == "Hi there"
+
+    def test_message_list_serialization_deserialization(self):
+        """Test serialization and deserialization with List[Dict] message."""
+        message_list = [
+            {"role": "user", "content": "Question?"},
+            {"role": "assistant", "content": "Answer."},
+        ]
+        msg = EmbeddingMsg(
+            message=message_list,
+            context_data={"type": "qa"},
+            semantic_msg_id="qa-123",
+        )
+        json_str = msg.to_json()
+        restored = EmbeddingMsg.from_json(json_str)
+        assert isinstance(restored.message, list)
+        assert len(restored.message) == 2
+        assert restored.message[0]["role"] == "user"
+        assert restored.semantic_msg_id == "qa-123"
+
+    def test_id_auto_generated(self):
+        """Test that id is auto-generated as UUID."""
+        msg = EmbeddingMsg(
+            message="test",
+            context_data={},
+        )
+        assert msg.id is not None
+        assert len(msg.id) == 36
+        assert msg.id.count("-") == 4
+
+    def test_from_dict_preserves_or_generates_id(self):
+        """Test from_dict() preserves provided id or generates new one."""
+        data_with_id = {
+            "id": "preserved-id",
+            "message": "test",
+            "context_data": {},
+        }
+        msg = EmbeddingMsg.from_dict(data_with_id)
+        assert msg.id == "preserved-id"
+
+        data_without_id = {
+            "message": "test",
+            "context_data": {},
+        }
+        msg = EmbeddingMsg.from_dict(data_without_id)
+        assert msg.id is not None
+        assert len(msg.id) == 36
+
+    def test_empty_context_data(self):
+        """Test with empty context_data."""
+        msg = EmbeddingMsg(
+            message="test",
+            context_data={},
+            semantic_msg_id="empty-ctx",
+        )
+        result = msg.to_dict()
+        assert result["context_data"] == {}
+        assert result["semantic_msg_id"] == "empty-ctx"
+
+    def test_complex_context_data(self):
+        """Test with complex nested context_data."""
+        complex_data = {
+            "nested": {
+                "level1": {
+                    "level2": ["a", "b", "c"],
+                },
+            },
+            "list": [1, 2, 3],
+            "string": "value",
+        }
+        msg = EmbeddingMsg(
+            message="complex test",
+            context_data=complex_data,
+            semantic_msg_id="complex-123",
+        )
+        json_str = msg.to_json()
+        restored = EmbeddingMsg.from_json(json_str)
+        assert restored.context_data["nested"]["level1"]["level2"] == ["a", "b", "c"]
+        assert restored.semantic_msg_id == "complex-123"
+
+    def test_semantic_msg_id_empty_string(self):
+        """Test semantic_msg_id with empty string."""
+        msg = EmbeddingMsg(
+            message="test",
+            context_data={},
+            semantic_msg_id="",
+        )
+        assert msg.semantic_msg_id == ""
+        result = msg.to_dict()
+        assert result["semantic_msg_id"] == ""
+
+    def test_roundtrip_string_message(self):
+        """Test complete roundtrip with string message."""
+        original = EmbeddingMsg(
+            message="roundtrip test",
+            context_data={"key": "value"},
+            semantic_msg_id="roundtrip-id",
+        )
+        json_str = original.to_json()
+        restored = EmbeddingMsg.from_json(json_str)
+        assert restored.message == original.message
+        assert restored.context_data == original.context_data
+        assert restored.semantic_msg_id == original.semantic_msg_id
+        assert restored.id is not None
+        assert len(restored.id) == 36
+
+    def test_roundtrip_list_message(self):
+        """Test complete roundtrip with List[Dict] message."""
+        message_list = [
+            {"type": "text", "content": "part1"},
+            {"type": "code", "content": "print('hello')"},
+        ]
+        original = EmbeddingMsg(
+            message=message_list,
+            context_data={"format": "mixed"},
+            semantic_msg_id="list-msg-id",
+        )
+        json_str = original.to_json()
+        restored = EmbeddingMsg.from_json(json_str)
+        assert restored.message == original.message
+        assert restored.semantic_msg_id == "list-msg-id"
diff --git a/tests/unit/storage/queuefs/test_embedding_tracker.py b/tests/unit/storage/queuefs/test_embedding_tracker.py
new file mode 100644
index 00000000..1850fef0
--- /dev/null
+++ b/tests/unit/storage/queuefs/test_embedding_tracker.py
@@ -0,0 +1,554 @@
+# Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd.
+# SPDX-License-Identifier: Apache-2.0
+
+"""Unit tests for EmbeddingTaskTracker."""
+
+import asyncio
+
+import pytest
+
+from openviking.storage.queuefs.embedding_tracker import EmbeddingTaskTracker
+
+
+def reset_singleton():
+    """Reset the singleton instance for testing."""
+    EmbeddingTaskTracker._instance = None
+
+
+@pytest.fixture(autouse=True)
+def clean_singleton():
+    """Reset singleton before and after each test."""
+    reset_singleton()
+    yield
+    reset_singleton()
+
+
+@pytest.fixture
+def tracker() -> EmbeddingTaskTracker:
+    """Create a fresh tracker instance for each test."""
+    return EmbeddingTaskTracker()
+
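+# The semantics exercised below (inferred from these tests, not a published
+# spec): register(msg_id, total, metadata=None, on_complete=None) starts a
+# countdown; increment()/decrement() adjust "remaining" (increment also bumps
+# "total"); when "remaining" reaches zero the entry is removed and on_complete
+# (sync or async) fires exactly once, while remove() discards the entry
+# without firing it. A minimal sketch:
+#
+#     await tracker.register("m", 2, on_complete=done)
+#     await tracker.decrement("m")   # -> 1, no callback yet
+#     await tracker.decrement("m")   # -> 0, entry removed, done() fires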
+
+# ── Singleton Pattern Tests ──
+
+
+def test_singleton_returns_same_instance():
+    """Test that get_instance() returns the same instance."""
+    instance1 = EmbeddingTaskTracker.get_instance()
+    instance2 = EmbeddingTaskTracker.get_instance()
+    assert instance1 is instance2
+
+
+def test_singleton_persists_across_calls():
+    """Test that singleton persists across multiple get_instance calls."""
+    instance1 = EmbeddingTaskTracker.get_instance()
+    instance2 = EmbeddingTaskTracker.get_instance()
+    instance3 = EmbeddingTaskTracker.get_instance()
+    assert instance1 is instance2 is instance3
+
+
+# ── Register Tests ──
+
+
+@pytest.mark.asyncio
+async def test_register_task(tracker: EmbeddingTaskTracker):
+    """Test registering a task with valid count."""
+    await tracker.register("msg-1", 5)
+    status = await tracker.get_status("msg-1")
+    assert status is not None
+    assert status["remaining"] == 5
+    assert status["total"] == 5
+
+
+@pytest.mark.asyncio
+async def test_register_with_metadata(tracker: EmbeddingTaskTracker):
+    """Test registering a task with metadata."""
+    metadata = {"key": "value", "count": 10}
+    await tracker.register("msg-2", 3, metadata=metadata)
+    status = await tracker.get_status("msg-2")
+    assert status is not None
+    assert status["metadata"] == metadata
+
+
+@pytest.mark.asyncio
+async def test_register_with_callback(tracker: EmbeddingTaskTracker):
+    """Test registering a task with on_complete callback."""
+    callback_called = []
+
+    async def on_complete():
+        callback_called.append(True)
+
+    await tracker.register("msg-3", 1, on_complete=on_complete)
+    await tracker.decrement("msg-3")
+    assert len(callback_called) == 1
+
+
+@pytest.mark.asyncio
+async def test_register_with_sync_callback(tracker: EmbeddingTaskTracker):
+    """Test registering a task with synchronous callback."""
+    callback_called = []
+
+    def on_complete():
+        callback_called.append(True)
+
+    await tracker.register("msg-4", 1, on_complete=on_complete)
+    await tracker.decrement("msg-4")
+    assert len(callback_called) == 1
+
+
+@pytest.mark.asyncio
+async def test_register_with_zero_count_does_nothing(tracker: EmbeddingTaskTracker):
+    """Test that registering with total_count=0 does nothing."""
+    await tracker.register("msg-5", 0)
+    status = await tracker.get_status("msg-5")
+    assert status is None
+
+
+@pytest.mark.asyncio
+async def test_register_with_negative_count_does_nothing(tracker: EmbeddingTaskTracker):
+    """Test that registering with negative total_count does nothing."""
+    await tracker.register("msg-6", -5)
+    status = await tracker.get_status("msg-6")
+    assert status is None
+
+
+@pytest.mark.asyncio
+async def test_register_overwrites_existing(tracker: EmbeddingTaskTracker):
+    """Test that registering with same ID overwrites existing entry."""
+    await tracker.register("msg-7", 5)
+    await tracker.register("msg-7", 10)
+    status = await tracker.get_status("msg-7")
+    assert status is not None
+    assert status["remaining"] == 10
+    assert status["total"] == 10
+
+
+# ── Increment Tests ──
+
+
+@pytest.mark.asyncio
+async def test_increment_existing_task(tracker: EmbeddingTaskTracker):
+    """Test incrementing an existing task."""
+    await tracker.register("msg-10", 5)
+    result = await tracker.increment("msg-10")
+    assert result == 6
+    status = await tracker.get_status("msg-10")
+    assert status["remaining"] == 6
+    assert status["total"] == 6
+
+
+@pytest.mark.asyncio
+async def test_increment_multiple_times(tracker: EmbeddingTaskTracker):
+    """Test incrementing a task multiple times."""
+    await tracker.register("msg-11", 2)
+    await tracker.increment("msg-11")
+    await tracker.increment("msg-11")
+    await tracker.increment("msg-11")
+    status = await tracker.get_status("msg-11")
+    assert status["remaining"] == 5
+    assert status["total"] == 5
+
+
+@pytest.mark.asyncio
+async def test_increment_nonexistent_task_returns_none(tracker: EmbeddingTaskTracker):
+    """Test incrementing a non-existent task returns None."""
+    result = await tracker.increment("nonexistent")
+    assert result is None
+
+
+# ── Decrement Tests ──
+
+
+@pytest.mark.asyncio
+async def test_decrement_existing_task(tracker: EmbeddingTaskTracker):
+    """Test decrementing an existing task."""
+    await tracker.register("msg-20", 5)
+    result = await tracker.decrement("msg-20")
+    assert result == 4
+    status = await tracker.get_status("msg-20")
+    assert status["remaining"] == 4
+
+
+@pytest.mark.asyncio
+async def test_decrement_multiple_times(tracker: EmbeddingTaskTracker):
+    """Test decrementing a task multiple times."""
+    await tracker.register("msg-21", 3)
+    await tracker.decrement("msg-21")
+    await tracker.decrement("msg-21")
+    status = await tracker.get_status("msg-21")
+    assert status["remaining"] == 1
+
+
+@pytest.mark.asyncio
+async def test_decrement_nonexistent_task_returns_none(tracker: EmbeddingTaskTracker):
+    """Test decrementing a non-existent task returns None."""
+    result = await tracker.decrement("nonexistent")
+    assert result is None
+
+
+@pytest.mark.asyncio
+async def test_decrement_to_zero_removes_task(tracker: EmbeddingTaskTracker):
+    """Test that decrementing to zero removes the task."""
+    await tracker.register("msg-22", 1)
+    result = await tracker.decrement("msg-22")
+    assert result == 0
+    status = await tracker.get_status("msg-22")
+    assert status is None
+
+
+@pytest.mark.asyncio
+async def test_decrement_triggers_callback_on_completion(tracker: EmbeddingTaskTracker):
+ """Test that callback is triggered when count reaches zero.""" + callback_called = [] + + async def on_complete(): + callback_called.append("async") + + await tracker.register("msg-23", 2, on_complete=on_complete) + await tracker.decrement("msg-23") + assert len(callback_called) == 0 + await tracker.decrement("msg-23") + assert len(callback_called) == 1 + assert callback_called[0] == "async" + + +@pytest.mark.asyncio +async def test_decrement_sync_callback_on_completion(tracker: EmbeddingTaskTracker): + """Test that sync callback is triggered when count reaches zero.""" + callback_called = [] + + def on_complete(): + callback_called.append("sync") + + await tracker.register("msg-24", 1, on_complete=on_complete) + await tracker.decrement("msg-24") + assert len(callback_called) == 1 + assert callback_called[0] == "sync" + + +@pytest.mark.asyncio +async def test_decrement_callback_error_is_handled(tracker: EmbeddingTaskTracker): + """Test that callback errors are handled gracefully.""" + + async def on_complete(): + raise ValueError("Callback error") + + await tracker.register("msg-25", 1, on_complete=on_complete) + result = await tracker.decrement("msg-25") + assert result == 0 + + +@pytest.mark.asyncio +async def test_decrement_below_zero_removes_task(tracker: EmbeddingTaskTracker): + """Test that decrementing below zero still removes the task.""" + await tracker.register("msg-26", 1) + await tracker.decrement("msg-26") + status = await tracker.get_status("msg-26") + assert status is None + + +# ── Get Status Tests ── + + +@pytest.mark.asyncio +async def test_get_status_existing_task(tracker: EmbeddingTaskTracker): + """Test getting status of an existing task.""" + await tracker.register("msg-30", 5, metadata={"key": "value"}) + status = await tracker.get_status("msg-30") + assert status is not None + assert status["remaining"] == 5 + assert status["total"] == 5 + assert status["metadata"] == {"key": "value"} + + +@pytest.mark.asyncio +async def test_get_status_nonexistent_task(tracker: EmbeddingTaskTracker): + """Test getting status of a non-existent task.""" + status = await tracker.get_status("nonexistent") + assert status is None + + +@pytest.mark.asyncio +async def test_get_status_reflects_changes(tracker: EmbeddingTaskTracker): + """Test that get_status reflects increment/decrement changes.""" + await tracker.register("msg-31", 5) + await tracker.increment("msg-31") + status = await tracker.get_status("msg-31") + assert status["remaining"] == 6 + assert status["total"] == 6 + await tracker.decrement("msg-31") + status = await tracker.get_status("msg-31") + assert status["remaining"] == 5 + + +# ── Remove Tests ── + + +@pytest.mark.asyncio +async def test_remove_existing_task(tracker: EmbeddingTaskTracker): + """Test removing an existing task.""" + await tracker.register("msg-40", 5) + result = await tracker.remove("msg-40") + assert result is True + status = await tracker.get_status("msg-40") + assert status is None + + +@pytest.mark.asyncio +async def test_remove_nonexistent_task(tracker: EmbeddingTaskTracker): + """Test removing a non-existent task.""" + result = await tracker.remove("nonexistent") + assert result is False + + +@pytest.mark.asyncio +async def test_remove_does_not_trigger_callback(tracker: EmbeddingTaskTracker): + """Test that remove does not trigger the on_complete callback.""" + callback_called = [] + + async def on_complete(): + callback_called.append(True) + + await tracker.register("msg-41", 5, on_complete=on_complete) + await tracker.remove("msg-41") + assert 
+
+
+# ── Get All Tracked Tests ──
+
+
+@pytest.mark.asyncio
+async def test_get_all_tracked_empty(tracker: EmbeddingTaskTracker):
+    """Test get_all_tracked when no tasks are registered."""
+    all_tasks = await tracker.get_all_tracked()
+    assert all_tasks == {}
+
+
+@pytest.mark.asyncio
+async def test_get_all_tracked_single_task(tracker: EmbeddingTaskTracker):
+    """Test get_all_tracked with a single task."""
+    await tracker.register("msg-50", 5, metadata={"key": "value"})
+    all_tasks = await tracker.get_all_tracked()
+    assert len(all_tasks) == 1
+    assert "msg-50" in all_tasks
+    assert all_tasks["msg-50"]["remaining"] == 5
+    assert all_tasks["msg-50"]["total"] == 5
+    assert all_tasks["msg-50"]["metadata"] == {"key": "value"}
+
+
+@pytest.mark.asyncio
+async def test_get_all_tracked_multiple_tasks(tracker: EmbeddingTaskTracker):
+    """Test get_all_tracked with multiple tasks."""
+    await tracker.register("msg-51", 3)
+    await tracker.register("msg-52", 5)
+    await tracker.register("msg-53", 7)
+    all_tasks = await tracker.get_all_tracked()
+    assert len(all_tasks) == 3
+    assert "msg-51" in all_tasks
+    assert "msg-52" in all_tasks
+    assert "msg-53" in all_tasks
+
+
+@pytest.mark.asyncio
+async def test_get_all_tracked_excludes_on_complete(tracker: EmbeddingTaskTracker):
+    """Test that get_all_tracked does not include on_complete callback."""
+    await tracker.register("msg-54", 5, on_complete=lambda: None)
+    all_tasks = await tracker.get_all_tracked()
+    assert "on_complete" not in all_tasks["msg-54"]
+
+
+@pytest.mark.asyncio
+async def test_get_all_tracked_returns_copy(tracker: EmbeddingTaskTracker):
+    """Test that get_all_tracked returns a copy, not internal state."""
+    await tracker.register("msg-55", 5)
+    all_tasks = await tracker.get_all_tracked()
+    all_tasks["msg-55"]["remaining"] = 999
+    status = await tracker.get_status("msg-55")
+    assert status["remaining"] == 5
+
+
+# ── Concurrency Tests ──
+
+
+@pytest.mark.asyncio
+async def test_concurrent_register(tracker: EmbeddingTaskTracker):
+    """Test concurrent register operations."""
+
+    async def register_task(msg_id: str):
+        await tracker.register(msg_id, 5)
+
+    await asyncio.gather(
+        register_task("msg-60"),
+        register_task("msg-61"),
+        register_task("msg-62"),
+    )
+    all_tasks = await tracker.get_all_tracked()
+    assert len(all_tasks) == 3
+
+
+@pytest.mark.asyncio
+async def test_concurrent_increment(tracker: EmbeddingTaskTracker):
+    """Test concurrent increment operations."""
+    await tracker.register("msg-70", 1)
+
+    async def increment_task():
+        await tracker.increment("msg-70")
+
+    await asyncio.gather(
+        increment_task(),
+        increment_task(),
+        increment_task(),
+    )
+    status = await tracker.get_status("msg-70")
+    assert status["remaining"] == 4
+    assert status["total"] == 4
+
+
+@pytest.mark.asyncio
+async def test_concurrent_decrement(tracker: EmbeddingTaskTracker):
+    """Test concurrent decrement operations."""
+    await tracker.register("msg-71", 3)
+
+    async def decrement_task():
+        await tracker.decrement("msg-71")
+
+    await asyncio.gather(
+        decrement_task(),
+        decrement_task(),
+        decrement_task(),
+    )
+    status = await tracker.get_status("msg-71")
+    assert status is None
+
+
+@pytest.mark.asyncio
+async def test_concurrent_mixed_operations(tracker: EmbeddingTaskTracker):
+    """Test concurrent mixed operations (increment and decrement)."""
+    await tracker.register("msg-72", 5)
+
+    async def increment():
+        await tracker.increment("msg-72")
+
+    async def decrement():
+        await tracker.decrement("msg-72")
+
+    await asyncio.gather(
+        increment(),
+        increment(),
+        decrement(),
+        decrement(),
+        decrement(),
+    )
+    status = await tracker.get_status("msg-72")
+    assert status is not None
+    assert status["remaining"] == 4  # 5 + 2 - 3; the count never reaches zero mid-run
+
+
+@pytest.mark.asyncio
+async def test_concurrent_register_and_decrement(tracker: EmbeddingTaskTracker):
+    """Test register followed immediately by decrement."""
+    callback_called = []
+
+    async def on_complete():
+        callback_called.append(True)
+
+    await tracker.register("msg-73", 1, on_complete=on_complete)
+    await tracker.decrement("msg-73")
+    assert len(callback_called) == 1
+
+
+@pytest.mark.asyncio
+async def test_concurrent_callback_execution(tracker: EmbeddingTaskTracker):
+    """Test that callbacks are executed correctly under concurrency."""
+    callback_count = []
+
+    async def make_callback(msg_id: str):
+        async def on_complete():
+            callback_count.append(msg_id)
+
+        return on_complete
+
+    async def register_and_complete(msg_id: str):
+        callback = await make_callback(msg_id)
+        await tracker.register(msg_id, 1, on_complete=callback)
+        await tracker.decrement(msg_id)
+
+    await asyncio.gather(
+        register_and_complete("msg-80"),
+        register_and_complete("msg-81"),
+        register_and_complete("msg-82"),
+    )
+    assert len(callback_count) == 3
+
+
+# ── Edge Cases Tests ──
+
+
+@pytest.mark.asyncio
+async def test_multiple_decrements_to_zero(tracker: EmbeddingTaskTracker):
+    """Test multiple decrements that bring count to exactly zero."""
+    callback_called = []
+
+    async def on_complete():
+        callback_called.append(True)
+
+    await tracker.register("msg-90", 3, on_complete=on_complete)
+    await tracker.decrement("msg-90")
+    await tracker.decrement("msg-90")
+    assert len(callback_called) == 0
+    await tracker.decrement("msg-90")
+    assert len(callback_called) == 1
+
+
+@pytest.mark.asyncio
+async def test_decrement_after_increment(tracker: EmbeddingTaskTracker):
+    """Test decrement after increment maintains correct count."""
+    await tracker.register("msg-91", 2)
+    await tracker.increment("msg-91")
+    await tracker.decrement("msg-91")
+    status = await tracker.get_status("msg-91")
+    assert status["remaining"] == 2
+    assert status["total"] == 3
+
+
+@pytest.mark.asyncio
+async def test_empty_metadata(tracker: EmbeddingTaskTracker):
+    """Test that empty metadata is handled correctly."""
+    await tracker.register("msg-92", 5, metadata={})
+    status = await tracker.get_status("msg-92")
+    assert status["metadata"] == {}
+
+
+@pytest.mark.asyncio
+async def test_none_metadata(tracker: EmbeddingTaskTracker):
+    """Test that None metadata defaults to empty dict."""
+    await tracker.register("msg-93", 5, metadata=None)
+    status = await tracker.get_status("msg-93")
+    assert status["metadata"] == {}
+
+
+@pytest.mark.asyncio
+async def test_none_callback(tracker: EmbeddingTaskTracker):
+    """Test that None callback is handled correctly."""
+    await tracker.register("msg-94", 1, on_complete=None)
+    result = await tracker.decrement("msg-94")
+    assert result == 0
+
+
+@pytest.mark.asyncio
+async def test_large_count(tracker: EmbeddingTaskTracker):
+    """Test with large task count."""
+    large_count = 10000
+    await tracker.register("msg-95", large_count)
+    status = await tracker.get_status("msg-95")
+    assert status["remaining"] == large_count
+    assert status["total"] == large_count
+
+
+@pytest.mark.asyncio
+async def test_special_characters_in_id(tracker: EmbeddingTaskTracker):
+    """Test with special characters in semantic_msg_id."""
+    special_id = "msg-with-special_chars.123!@#$%"
+    await tracker.register(special_id, 5)
+    status = await tracker.get_status(special_id)
+    assert status is not None
+    assert status["remaining"] == 5
diff --git a/tests/unit/storage/queuefs/test_processor_incremental.py b/tests/unit/storage/queuefs/test_processor_incremental.py
new file mode 100644
index 00000000..01c9a57e
--- /dev/null
+++ b/tests/unit/storage/queuefs/test_processor_incremental.py
@@ -0,0 +1,868 @@
+# Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd.
+# SPDX-License-Identifier: Apache-2.0
+"""Unit tests for SemanticProcessor incremental update and diff calculation.
+
+Tests for:
+- _detect_file_type(): Detect file type based on extension
+- _collect_tree_info(): Collect directory tree information
+- _compute_diff(): Compute directory differences
+- _check_file_content_changed(): Check file content changes
+- _execute_sync_operations(): Execute sync operations
+- _create_sync_diff_callback(): Create sync diff callback
+"""
+
+from unittest.mock import AsyncMock, MagicMock, patch
+
+import pytest
+
+from openviking.parse.parsers.constants import (
+    FILE_TYPE_CODE,
+    FILE_TYPE_DOCUMENTATION,
+    FILE_TYPE_OTHER,
+)
+from openviking.server.identity import RequestContext, Role
+from openviking.storage.queuefs.semantic_processor import DiffResult, SemanticProcessor
+from openviking_cli.session.user_id import UserIdentifier
+
+
+class FakeVikingFS:
+    """Fake VikingFS for testing."""
+
+    def __init__(self):
+        self._tree = {}
+        self._file_contents = {}
+        self.deleted_files = []
+        self.deleted_dirs = []
+        self.moved_files = []
+        self.created_dirs = []
+
+    def set_tree(self, tree):
+        self._tree = tree
+
+    def set_file_contents(self, contents):
+        self._file_contents = contents
+
+    async def ls(self, uri, ctx=None):
+        return self._tree.get(uri.rstrip("/"), [])
+
+    async def read_file(self, uri, ctx=None):
+        return self._file_contents.get(uri, "")
+
+    async def rm(self, uri, recursive=False, ctx=None):
+        if recursive:
+            self.deleted_dirs.append(uri)
+        else:
+            self.deleted_files.append(uri)
+
+    async def mv(self, src, dst, ctx=None):
+        self.moved_files.append((src, dst))
+
+    async def mkdir(self, uri, exist_ok=True, ctx=None):
+        self.created_dirs.append(uri)
+
+
+@pytest.fixture
+def processor():
+    """Create a SemanticProcessor instance for testing."""
+    return SemanticProcessor(max_concurrent_llm=10)
+
+
+@pytest.fixture
+def fake_fs():
+    """Create a fake VikingFS instance."""
+    return FakeVikingFS()
+
+
+@pytest.fixture
+def ctx():
+    """Create a RequestContext for testing."""
+    return RequestContext(
+        user=UserIdentifier("test_account", "test_user", "test_agent"),
+        role=Role.USER,
+    )
+
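+# _detect_file_type() is assumed to dispatch on the case-insensitive file
+# extension alone, mapping code/config extensions to FILE_TYPE_CODE, document
+# extensions to FILE_TYPE_DOCUMENTATION, and anything else (including
+# extension-less names such as "Makefile") to FILE_TYPE_OTHER:
+#
+#     processor._detect_file_type("main.py")    # -> FILE_TYPE_CODE
+#     processor._detect_file_type("README.md")  # -> FILE_TYPE_DOCUMENTATION
+#     processor._detect_file_type("data.xyz")   # -> FILE_TYPE_OTHER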
processor._detect_file_type("program.c") + assert result == FILE_TYPE_CODE + + def test_detect_cpp_file(self, processor): + result = processor._detect_file_type("module.cpp") + assert result == FILE_TYPE_CODE + + def test_detect_markdown_file(self, processor): + result = processor._detect_file_type("README.md") + assert result == FILE_TYPE_DOCUMENTATION + + def test_detect_rst_file(self, processor): + result = processor._detect_file_type("docs.rst") + assert result == FILE_TYPE_DOCUMENTATION + + def test_detect_txt_file(self, processor): + result = processor._detect_file_type("notes.txt") + assert result == FILE_TYPE_DOCUMENTATION + + def test_detect_json_file(self, processor): + result = processor._detect_file_type("config.json") + assert result == FILE_TYPE_CODE + + def test_detect_yaml_file(self, processor): + result = processor._detect_file_type("settings.yaml") + assert result == FILE_TYPE_CODE + + def test_detect_unknown_extension(self, processor): + result = processor._detect_file_type("data.xyz") + assert result == FILE_TYPE_OTHER + + def test_detect_no_extension(self, processor): + result = processor._detect_file_type("Makefile") + assert result == FILE_TYPE_OTHER + + def test_detect_uppercase_extension(self, processor): + result = processor._detect_file_type("SCRIPT.PY") + assert result == FILE_TYPE_CODE + + def test_detect_mixed_case_extension(self, processor): + result = processor._detect_file_type("ReadMe.Md") + assert result == FILE_TYPE_DOCUMENTATION + + def test_detect_path_with_dots(self, processor): + result = processor._detect_file_type("src/utils/helper.py") + assert result == FILE_TYPE_CODE + + +class TestCollectTreeInfo: + """Test cases for _collect_tree_info() method.""" + + @pytest.mark.asyncio + async def test_collect_empty_directory(self, processor, fake_fs, ctx): + fake_fs.set_tree({"viking://temp/empty": []}) + processor._current_ctx = ctx + + with patch( + "openviking.storage.queuefs.semantic_processor.get_viking_fs", return_value=fake_fs + ): + result = await processor._collect_tree_info("viking://temp/empty") + + assert result == {"viking://temp/empty": ([], [])} + + @pytest.mark.asyncio + async def test_collect_directory_with_files(self, processor, fake_fs, ctx): + fake_fs.set_tree( + { + "viking://temp/dir": [ + {"name": "file1.txt", "isDir": False}, + {"name": "file2.py", "isDir": False}, + ] + } + ) + processor._current_ctx = ctx + + with patch( + "openviking.storage.queuefs.semantic_processor.get_viking_fs", return_value=fake_fs + ): + result = await processor._collect_tree_info("viking://temp/dir") + + assert "viking://temp/dir" in result + sub_dirs, files = result["viking://temp/dir"] + assert sub_dirs == [] + assert len(files) == 2 + assert "viking://temp/dir/file1.txt" in files + assert "viking://temp/dir/file2.py" in files + + @pytest.mark.asyncio + async def test_collect_directory_with_subdirs(self, processor, fake_fs, ctx): + fake_fs.set_tree( + { + "viking://temp/root": [ + {"name": "subdir1", "isDir": True}, + {"name": "subdir2", "isDir": True}, + ], + "viking://temp/root/subdir1": [ + {"name": "file1.txt", "isDir": False}, + ], + "viking://temp/root/subdir2": [ + {"name": "file2.txt", "isDir": False}, + ], + } + ) + processor._current_ctx = ctx + + with patch( + "openviking.storage.queuefs.semantic_processor.get_viking_fs", return_value=fake_fs + ): + result = await processor._collect_tree_info("viking://temp/root") + + assert "viking://temp/root" in result + assert "viking://temp/root/subdir1" in result + assert "viking://temp/root/subdir2" in 
+
+    @pytest.mark.asyncio
+    async def test_collect_nested_directories(self, processor, fake_fs, ctx):
+        fake_fs.set_tree(
+            {
+                "viking://temp/root": [
+                    {"name": "level1", "isDir": True},
+                ],
+                "viking://temp/root/level1": [
+                    {"name": "level2", "isDir": True},
+                ],
+                "viking://temp/root/level1/level2": [
+                    {"name": "deep_file.txt", "isDir": False},
+                ],
+            }
+        )
+        processor._current_ctx = ctx
+
+        with patch(
+            "openviking.storage.queuefs.semantic_processor.get_viking_fs", return_value=fake_fs
+        ):
+            result = await processor._collect_tree_info("viking://temp/root")
+
+        assert "viking://temp/root" in result
+        assert "viking://temp/root/level1" in result
+        assert "viking://temp/root/level1/level2" in result
+
+    @pytest.mark.asyncio
+    async def test_collect_skips_hidden_files(self, processor, fake_fs, ctx):
+        fake_fs.set_tree(
+            {
+                "viking://temp/dir": [
+                    {"name": ".hidden", "isDir": False},
+                    {"name": "visible.txt", "isDir": False},
+                ]
+            }
+        )
+        processor._current_ctx = ctx
+
+        with patch(
+            "openviking.storage.queuefs.semantic_processor.get_viking_fs", return_value=fake_fs
+        ):
+            result = await processor._collect_tree_info("viking://temp/dir")
+
+        _, files = result["viking://temp/dir"]
+        assert len(files) == 1
+        assert "viking://temp/dir/visible.txt" in files
+
+    @pytest.mark.asyncio
+    async def test_collect_skips_dot_and_dotdot(self, processor, fake_fs, ctx):
+        fake_fs.set_tree(
+            {
+                "viking://temp/dir": [
+                    {"name": ".", "isDir": True},
+                    {"name": "..", "isDir": True},
+                    {"name": "file.txt", "isDir": False},
+                ]
+            }
+        )
+        processor._current_ctx = ctx
+
+        with patch(
+            "openviking.storage.queuefs.semantic_processor.get_viking_fs", return_value=fake_fs
+        ):
+            result = await processor._collect_tree_info("viking://temp/dir")
+
+        sub_dirs, files = result["viking://temp/dir"]
+        assert sub_dirs == []
+        assert len(files) == 1
+
+    @pytest.mark.asyncio
+    async def test_collect_handles_ls_error(self, processor, fake_fs, ctx):
+        fake_fs.ls = AsyncMock(side_effect=Exception("LS error"))
+        processor._current_ctx = ctx
+
+        with patch(
+            "openviking.storage.queuefs.semantic_processor.get_viking_fs", return_value=fake_fs
+        ):
+            result = await processor._collect_tree_info("viking://temp/dir")
+
+        assert result == {}
+
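+# _compute_diff() is assumed to match temp-tree entries against target-tree
+# entries by relative path: present only in temp -> added, present only in
+# target -> deleted, present in both -> updated only if contents differ, with
+# directories classified the same way. A sketch of the expected result shape:
+#
+#     temp:   root -> ([new_dir], [new.txt, same.txt])
+#     target: root -> ([old_dir], [same.txt, gone.txt])
+#     diff:   added_files=[new.txt], deleted_files=[gone.txt],
+#             added_dirs=[new_dir], deleted_dirs=[old_dir]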
assert "viking://temp/root/new_file.txt" in diff.added_files + assert diff.deleted_files == [] + + @pytest.mark.asyncio + async def test_compute_diff_deleted_files(self, processor, fake_fs, ctx): + root_tree = { + "viking://temp/root": ([], []), + } + target_tree = { + "viking://target/root": ([], ["viking://target/root/old_file.txt"]), + } + processor._current_ctx = ctx + + with patch( + "openviking.storage.queuefs.semantic_processor.get_viking_fs", return_value=fake_fs + ): + diff = await processor._compute_diff( + root_tree, target_tree, "viking://temp/root", "viking://target/root" + ) + + assert diff.added_files == [] + assert len(diff.deleted_files) == 1 + assert "viking://target/root/old_file.txt" in diff.deleted_files + + @pytest.mark.asyncio + async def test_compute_diff_updated_files(self, processor, fake_fs, ctx): + root_tree = { + "viking://temp/root": ([], ["viking://temp/root/file.txt"]), + } + target_tree = { + "viking://target/root": ([], ["viking://target/root/file.txt"]), + } + fake_fs.set_file_contents( + { + "viking://temp/root/file.txt": "new content", + "viking://target/root/file.txt": "old content", + } + ) + processor._current_ctx = ctx + + with patch( + "openviking.storage.queuefs.semantic_processor.get_viking_fs", return_value=fake_fs + ): + diff = await processor._compute_diff( + root_tree, target_tree, "viking://temp/root", "viking://target/root" + ) + + assert len(diff.updated_files) == 1 + assert "viking://temp/root/file.txt" in diff.updated_files + + @pytest.mark.asyncio + async def test_compute_diff_unchanged_files(self, processor, fake_fs, ctx): + root_tree = { + "viking://temp/root": ([], ["viking://temp/root/file.txt"]), + } + target_tree = { + "viking://target/root": ([], ["viking://target/root/file.txt"]), + } + fake_fs.set_file_contents( + { + "viking://temp/root/file.txt": "same content", + "viking://target/root/file.txt": "same content", + } + ) + processor._current_ctx = ctx + + with patch( + "openviking.storage.queuefs.semantic_processor.get_viking_fs", return_value=fake_fs + ): + diff = await processor._compute_diff( + root_tree, target_tree, "viking://temp/root", "viking://target/root" + ) + + assert diff.updated_files == [] + + @pytest.mark.asyncio + async def test_compute_diff_added_dirs(self, processor, fake_fs, ctx): + root_tree = { + "viking://temp/root": (["viking://temp/root/new_dir"], []), + "viking://temp/root/new_dir": ([], []), + } + target_tree = { + "viking://target/root": ([], []), + } + processor._current_ctx = ctx + + with patch( + "openviking.storage.queuefs.semantic_processor.get_viking_fs", return_value=fake_fs + ): + diff = await processor._compute_diff( + root_tree, target_tree, "viking://temp/root", "viking://target/root" + ) + + assert len(diff.added_dirs) == 1 + assert "viking://temp/root/new_dir" in diff.added_dirs + + @pytest.mark.asyncio + async def test_compute_diff_deleted_dirs(self, processor, fake_fs, ctx): + root_tree = { + "viking://temp/root": ([], []), + } + target_tree = { + "viking://target/root": (["viking://target/root/old_dir"], []), + "viking://target/root/old_dir": ([], []), + } + processor._current_ctx = ctx + + with patch( + "openviking.storage.queuefs.semantic_processor.get_viking_fs", return_value=fake_fs + ): + diff = await processor._compute_diff( + root_tree, target_tree, "viking://temp/root", "viking://target/root" + ) + + assert len(diff.deleted_dirs) == 1 + assert "viking://target/root/old_dir" in diff.deleted_dirs + + @pytest.mark.asyncio + async def test_compute_diff_mixed_changes(self, 
+        root_tree = {
+            "viking://temp/root": (
+                ["viking://temp/root/new_dir"],
+                ["viking://temp/root/new_file.txt", "viking://temp/root/updated.txt"],
+            ),
+            "viking://temp/root/new_dir": ([], []),
+        }
+        target_tree = {
+            "viking://target/root": (
+                ["viking://target/root/old_dir"],
+                ["viking://target/root/updated.txt", "viking://target/root/deleted.txt"],
+            ),
+            "viking://target/root/old_dir": ([], []),
+        }
+        fake_fs.set_file_contents(
+            {
+                "viking://temp/root/updated.txt": "new content",
+                "viking://target/root/updated.txt": "old content",
+            }
+        )
+        processor._current_ctx = ctx
+
+        with patch(
+            "openviking.storage.queuefs.semantic_processor.get_viking_fs", return_value=fake_fs
+        ):
+            diff = await processor._compute_diff(
+                root_tree, target_tree, "viking://temp/root", "viking://target/root"
+            )
+
+        assert len(diff.added_files) == 1
+        assert len(diff.deleted_files) == 1
+        assert len(diff.updated_files) == 1
+        assert len(diff.added_dirs) == 1
+        assert len(diff.deleted_dirs) == 1
+
+
+class TestCheckFileContentChanged:
+    """Test cases for _check_file_content_changed() method."""
+
+    @pytest.mark.asyncio
+    async def test_content_changed(self, processor, fake_fs, ctx):
+        fake_fs.set_file_contents(
+            {
+                "viking://temp/file.txt": "new content",
+                "viking://target/file.txt": "old content",
+            }
+        )
+        processor._current_ctx = ctx
+
+        with patch(
+            "openviking.storage.queuefs.semantic_processor.get_viking_fs", return_value=fake_fs
+        ):
+            result = await processor._check_file_content_changed(
+                "viking://temp/file.txt", "viking://target/file.txt"
+            )
+
+        assert result is True
+
+    @pytest.mark.asyncio
+    async def test_content_unchanged(self, processor, fake_fs, ctx):
+        fake_fs.set_file_contents(
+            {
+                "viking://temp/file.txt": "same content",
+                "viking://target/file.txt": "same content",
+            }
+        )
+        processor._current_ctx = ctx
+
+        with patch(
+            "openviking.storage.queuefs.semantic_processor.get_viking_fs", return_value=fake_fs
+        ):
+            result = await processor._check_file_content_changed(
+                "viking://temp/file.txt", "viking://target/file.txt"
+            )
+
+        assert result is False
+
+    @pytest.mark.asyncio
+    async def test_content_changed_empty_files(self, processor, fake_fs, ctx):
+        fake_fs.set_file_contents(
+            {
+                "viking://temp/file.txt": "",
+                "viking://target/file.txt": "",
+            }
+        )
+        processor._current_ctx = ctx
+
+        with patch(
+            "openviking.storage.queuefs.semantic_processor.get_viking_fs", return_value=fake_fs
+        ):
+            result = await processor._check_file_content_changed(
+                "viking://temp/file.txt", "viking://target/file.txt"
+            )
+
+        assert result is False
+
+    @pytest.mark.asyncio
+    async def test_content_changed_one_empty(self, processor, fake_fs, ctx):
+        fake_fs.set_file_contents(
+            {
+                "viking://temp/file.txt": "content",
+                "viking://target/file.txt": "",
+            }
+        )
+        processor._current_ctx = ctx
+
+        with patch(
+            "openviking.storage.queuefs.semantic_processor.get_viking_fs", return_value=fake_fs
+        ):
+            result = await processor._check_file_content_changed(
+                "viking://temp/file.txt", "viking://target/file.txt"
+            )
+
+        assert result is True
+
+    @pytest.mark.asyncio
+    async def test_content_changed_on_exception(self, processor, fake_fs, ctx):
+        fake_fs.read_file = AsyncMock(side_effect=Exception("Read error"))
+        processor._current_ctx = ctx
+
+        with patch(
+            "openviking.storage.queuefs.semantic_processor.get_viking_fs", return_value=fake_fs
+        ):
+            result = await processor._check_file_content_changed(
+                "viking://temp/file.txt", "viking://target/file.txt"
+            )
+
+        assert result is True
+
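+# _execute_sync_operations() is expected to replay a DiffResult against the
+# target tree (again inferred from the assertions below): rm() each deleted
+# file, mv() added files from the temp tree into place (mkdir()-ing missing
+# parents first), rm()-then-mv() updated files, and rm(recursive=True) deleted
+# directories deepest-first so children are removed before their parents.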
+
+class TestExecuteSyncOperations:
+    """Test cases for _execute_sync_operations() method."""
+
+    @pytest.mark.asyncio
+    async def test_execute_delete_files(self, processor, fake_fs, ctx):
+        diff = DiffResult(
+            added_files=[],
+            deleted_files=["viking://target/deleted.txt"],
+            updated_files=[],
+            added_dirs=[],
+            deleted_dirs=[],
+        )
+        processor._current_ctx = ctx
+
+        with patch(
+            "openviking.storage.queuefs.semantic_processor.get_viking_fs", return_value=fake_fs
+        ):
+            await processor._execute_sync_operations(
+                diff, "viking://temp/root", "viking://target/root"
+            )
+
+        assert "viking://target/deleted.txt" in fake_fs.deleted_files
+
+    @pytest.mark.asyncio
+    async def test_execute_move_added_files(self, processor, fake_fs, ctx):
+        diff = DiffResult(
+            added_files=["viking://temp/root/new.txt"],
+            deleted_files=[],
+            updated_files=[],
+            added_dirs=[],
+            deleted_dirs=[],
+        )
+        processor._current_ctx = ctx
+
+        mock_viking_uri = MagicMock()
+        mock_viking_uri.parent = None
+
+        with patch(
+            "openviking.storage.queuefs.semantic_processor.get_viking_fs", return_value=fake_fs
+        ):
+            with patch(
+                "openviking.storage.queuefs.semantic_processor.VikingURI",
+                return_value=mock_viking_uri,
+            ):
+                await processor._execute_sync_operations(
+                    diff, "viking://temp/root", "viking://target/root"
+                )
+
+        assert ("viking://temp/root/new.txt", "viking://target/root/new.txt") in fake_fs.moved_files
+
+    @pytest.mark.asyncio
+    async def test_execute_move_updated_files(self, processor, fake_fs, ctx):
+        diff = DiffResult(
+            added_files=[],
+            deleted_files=[],
+            updated_files=["viking://temp/root/updated.txt"],
+            added_dirs=[],
+            deleted_dirs=[],
+        )
+        processor._current_ctx = ctx
+
+        mock_viking_uri = MagicMock()
+        mock_viking_uri.parent = None
+
+        with patch(
+            "openviking.storage.queuefs.semantic_processor.get_viking_fs", return_value=fake_fs
+        ):
+            with patch(
+                "openviking.storage.queuefs.semantic_processor.VikingURI",
+                return_value=mock_viking_uri,
+            ):
+                await processor._execute_sync_operations(
+                    diff, "viking://temp/root", "viking://target/root"
+                )
+
+        assert "viking://target/root/updated.txt" in fake_fs.deleted_files
+        assert (
+            "viking://temp/root/updated.txt",
+            "viking://target/root/updated.txt",
+        ) in fake_fs.moved_files
+
+    @pytest.mark.asyncio
+    async def test_execute_delete_dirs(self, processor, fake_fs, ctx):
+        diff = DiffResult(
+            added_files=[],
+            deleted_files=[],
+            updated_files=[],
+            added_dirs=[],
+            deleted_dirs=["viking://target/root/old_dir"],
+        )
+        processor._current_ctx = ctx
+
+        with patch(
+            "openviking.storage.queuefs.semantic_processor.get_viking_fs", return_value=fake_fs
+        ):
+            await processor._execute_sync_operations(
+                diff, "viking://temp/root", "viking://target/root"
+            )
+
+        assert "viking://target/root/old_dir" in fake_fs.deleted_dirs
+
+    @pytest.mark.asyncio
+    async def test_execute_delete_dirs_deepest_first(self, processor, fake_fs, ctx):
+        diff = DiffResult(
+            added_files=[],
+            deleted_files=[],
+            updated_files=[],
+            added_dirs=[],
+            deleted_dirs=[
+                "viking://target/root/level1",
+                "viking://target/root/level1/level2",
+                "viking://target/root/level1/level2/level3",
+            ],
+        )
+        processor._current_ctx = ctx
+
+        with patch(
+            "openviking.storage.queuefs.semantic_processor.get_viking_fs", return_value=fake_fs
+        ):
+            await processor._execute_sync_operations(
+                diff, "viking://temp/root", "viking://target/root"
+            )
+
+        assert len(fake_fs.deleted_dirs) == 3
+        deepest_first = sorted(fake_fs.deleted_dirs, key=lambda x: x.count("/"), reverse=True)
+        assert fake_fs.deleted_dirs == deepest_first
+
+    @pytest.mark.asyncio
+    async def test_execute_creates_parent_dirs(self, processor, fake_fs, ctx):
+        diff = DiffResult(
+            added_files=["viking://temp/root/subdir/new.txt"],
+            deleted_files=[],
+            updated_files=[],
+            added_dirs=[],
+            deleted_dirs=[],
+        )
+        processor._current_ctx = ctx
+
+        mock_parent = MagicMock()
+        mock_parent.uri = "viking://target/root/subdir"
+        mock_viking_uri = MagicMock()
+        mock_viking_uri.parent = mock_parent
+
+        with patch(
+            "openviking.storage.queuefs.semantic_processor.get_viking_fs", return_value=fake_fs
+        ):
+            with patch(
+                "openviking.storage.queuefs.semantic_processor.VikingURI",
+                return_value=mock_viking_uri,
+            ):
+                await processor._execute_sync_operations(
+                    diff, "viking://temp/root", "viking://target/root"
+                )
+
+        assert "viking://target/root/subdir" in fake_fs.created_dirs
+
+
+class TestCreateSyncDiffCallback:
+    """Test cases for _create_sync_diff_callback() method."""
+
+    @pytest.mark.asyncio
+    async def test_callback_returns_callable(self, processor):
+        callback = processor._create_sync_diff_callback(
+            "viking://temp/root", "viking://target/root"
+        )
+        assert callable(callback)
+
+    @pytest.mark.asyncio
+    async def test_callback_is_async(self, processor):
+        callback = processor._create_sync_diff_callback(
+            "viking://temp/root", "viking://target/root"
+        )
+        import asyncio
+
+        assert asyncio.iscoroutinefunction(callback)
+
+    @pytest.mark.asyncio
+    async def test_callback_collects_tree_info(self, processor, fake_fs, ctx):
+        fake_fs.set_tree(
+            {
+                "viking://temp/root": [
+                    {"name": "file.txt", "isDir": False},
+                ],
+                "viking://target/root": [
+                    {"name": "file.txt", "isDir": False},
+                ],
+            }
+        )
+        fake_fs.set_file_contents(
+            {
+                "viking://temp/root/file.txt": "content",
+                "viking://target/root/file.txt": "content",
+            }
+        )
+        processor._current_ctx = ctx
+
+        with patch(
+            "openviking.storage.queuefs.semantic_processor.get_viking_fs", return_value=fake_fs
+        ):
+            callback = processor._create_sync_diff_callback(
+                "viking://temp/root", "viking://target/root"
+            )
+            await callback()
+
+    @pytest.mark.asyncio
+    async def test_callback_handles_exception(self, processor, fake_fs, ctx):
+        fake_fs.ls = AsyncMock(side_effect=Exception("Test error"))
+        processor._current_ctx = ctx
+
+        with patch(
+            "openviking.storage.queuefs.semantic_processor.get_viking_fs", return_value=fake_fs
+        ):
+            callback = processor._create_sync_diff_callback(
+                "viking://temp/root", "viking://target/root"
+            )
+            await callback()
+
+    @pytest.mark.asyncio
+    async def test_callback_deletes_root_after_sync(self, processor, fake_fs, ctx):
+        fake_fs.set_tree(
+            {
+                "viking://temp/root": [],
+                "viking://target/root": [],
+            }
+        )
+        processor._current_ctx = ctx
+
+        with patch(
+            "openviking.storage.queuefs.semantic_processor.get_viking_fs", return_value=fake_fs
+        ):
+            callback = processor._create_sync_diff_callback(
+                "viking://temp/root", "viking://target/root"
+            )
+            await callback()
+
+        assert "viking://temp/root" in fake_fs.deleted_dirs
+
+
+class TestDiffResult:
+    """Test cases for DiffResult dataclass."""
+
+    def test_diff_result_default_values(self):
+        diff = DiffResult()
+        assert diff.added_files == []
+        assert diff.deleted_files == []
+        assert diff.updated_files == []
+        assert diff.added_dirs == []
+        assert diff.deleted_dirs == []
+
+    def test_diff_result_with_values(self):
+        diff = DiffResult(
+            added_files=["a.txt"],
+            deleted_files=["b.txt"],
+            updated_files=["c.txt"],
+            added_dirs=["dir1"],
+            deleted_dirs=["dir2"],
+        )
+        assert diff.added_files == ["a.txt"]
["a.txt"] + assert diff.deleted_files == ["b.txt"] + assert diff.updated_files == ["c.txt"] + assert diff.added_dirs == ["dir1"] + assert diff.deleted_dirs == ["dir2"] + + def test_diff_result_modifiable(self): + diff = DiffResult() + diff.added_files.append("new.txt") + assert "new.txt" in diff.added_files diff --git a/tests/unit/storage/queuefs/test_semantic_msg.py b/tests/unit/storage/queuefs/test_semantic_msg.py new file mode 100644 index 00000000..37551109 --- /dev/null +++ b/tests/unit/storage/queuefs/test_semantic_msg.py @@ -0,0 +1,410 @@ +# Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd. +# SPDX-License-Identifier: Apache-2.0 +"""Unit tests for SemanticMsg dataclass, focusing on new fields target_uri and skip_vectorization.""" + +import json + +import pytest + +from openviking.storage.queuefs.semantic_msg import SemanticMsg + + +class TestTargetUriField: + """Tests for target_uri field serialization and deserialization.""" + + def test_target_uri_default_value(self): + msg = SemanticMsg(uri="viking://resource/test", context_type="resource") + assert msg.target_uri == "" + + def test_target_uri_set_in_constructor(self): + msg = SemanticMsg( + uri="viking://resource/temp", + context_type="resource", + target_uri="viking://resource/target", + ) + assert msg.target_uri == "viking://resource/target" + + def test_target_uri_serialization_to_dict(self): + msg = SemanticMsg( + uri="viking://resource/temp", + context_type="resource", + target_uri="viking://resource/target", + ) + data = msg.to_dict() + assert "target_uri" in data + assert data["target_uri"] == "viking://resource/target" + + def test_target_uri_deserialization_from_dict(self): + data = { + "uri": "viking://resource/temp", + "context_type": "resource", + "target_uri": "viking://resource/target", + } + msg = SemanticMsg.from_dict(data) + assert msg.target_uri == "viking://resource/target" + + def test_target_uri_empty_string_serialization(self): + msg = SemanticMsg( + uri="viking://resource/test", + context_type="resource", + target_uri="", + ) + data = msg.to_dict() + assert data["target_uri"] == "" + + def test_target_uri_with_memory_context(self): + msg = SemanticMsg( + uri="viking://memory/temp/session", + context_type="memory", + target_uri="viking://session/abc123", + ) + data = msg.to_dict() + msg_restored = SemanticMsg.from_dict(data) + assert msg_restored.target_uri == "viking://session/abc123" + + +class TestSkipVectorizationField: + """Tests for skip_vectorization field serialization and deserialization.""" + + def test_skip_vectorization_default_value(self): + msg = SemanticMsg(uri="viking://resource/test", context_type="resource") + assert msg.skip_vectorization is False + + def test_skip_vectorization_set_true_in_constructor(self): + msg = SemanticMsg( + uri="viking://resource/test", + context_type="resource", + skip_vectorization=True, + ) + assert msg.skip_vectorization is True + + def test_skip_vectorization_serialization_to_dict(self): + msg = SemanticMsg( + uri="viking://resource/test", + context_type="resource", + skip_vectorization=True, + ) + data = msg.to_dict() + assert "skip_vectorization" in data + assert data["skip_vectorization"] is True + + def test_skip_vectorization_deserialization_from_dict(self): + data = { + "uri": "viking://resource/test", + "context_type": "resource", + "skip_vectorization": True, + } + msg = SemanticMsg.from_dict(data) + assert msg.skip_vectorization is True + + def test_skip_vectorization_false_serialization(self): + msg = SemanticMsg( + 
uri="viking://resource/test", + context_type="resource", + skip_vectorization=False, + ) + data = msg.to_dict() + assert data["skip_vectorization"] is False + + def test_skip_vectorization_round_trip(self): + original = SemanticMsg( + uri="viking://resource/test", + context_type="resource", + skip_vectorization=True, + ) + restored = SemanticMsg.from_dict(original.to_dict()) + assert restored.skip_vectorization == original.skip_vectorization + + +class TestFromDictBackwardCompatibility: + """Tests for from_dict() backward compatibility with old format missing new fields.""" + + def test_missing_target_uri_uses_default(self): + data = { + "uri": "viking://resource/test", + "context_type": "resource", + } + msg = SemanticMsg.from_dict(data) + assert msg.target_uri == "" + + def test_missing_skip_vectorization_uses_default(self): + data = { + "uri": "viking://resource/test", + "context_type": "resource", + } + msg = SemanticMsg.from_dict(data) + assert msg.skip_vectorization is False + + def test_missing_both_new_fields_uses_defaults(self): + data = { + "uri": "viking://resource/test", + "context_type": "resource", + "recursive": True, + "account_id": "test_account", + } + msg = SemanticMsg.from_dict(data) + assert msg.target_uri == "" + assert msg.skip_vectorization is False + + def test_old_format_with_all_legacy_fields(self): + data = { + "id": "legacy-id-123", + "uri": "viking://resource/test", + "context_type": "resource", + "status": "pending", + "recursive": False, + "account_id": "account1", + "user_id": "user1", + "agent_id": "agent1", + "role": "admin", + } + msg = SemanticMsg.from_dict(data) + assert msg.id == "legacy-id-123" + assert msg.uri == "viking://resource/test" + assert msg.context_type == "resource" + assert msg.status == "pending" + assert msg.recursive is False + assert msg.target_uri == "" + assert msg.skip_vectorization is False + + def test_partial_new_fields_only_target_uri(self): + data = { + "uri": "viking://resource/test", + "context_type": "resource", + "target_uri": "viking://resource/target", + } + msg = SemanticMsg.from_dict(data) + assert msg.target_uri == "viking://resource/target" + assert msg.skip_vectorization is False + + def test_partial_new_fields_only_skip_vectorization(self): + data = { + "uri": "viking://resource/test", + "context_type": "resource", + "skip_vectorization": True, + } + msg = SemanticMsg.from_dict(data) + assert msg.target_uri == "" + assert msg.skip_vectorization is True + + +class TestToJsonFromJson: + """Tests for to_json() and from_json() methods.""" + + def test_to_json_returns_valid_json_string(self): + msg = SemanticMsg( + uri="viking://resource/test", + context_type="resource", + ) + json_str = msg.to_json() + assert isinstance(json_str, str) + parsed = json.loads(json_str) + assert isinstance(parsed, dict) + + def test_from_json_creates_valid_object(self): + json_str = '{"uri": "viking://resource/test", "context_type": "resource"}' + msg = SemanticMsg.from_json(json_str) + assert msg.uri == "viking://resource/test" + assert msg.context_type == "resource" + + def test_to_json_and_from_json_round_trip(self): + original = SemanticMsg( + uri="viking://resource/temp", + context_type="memory", + target_uri="viking://session/abc", + skip_vectorization=True, + recursive=False, + account_id="test_account", + user_id="test_user", + agent_id="test_agent", + role="admin", + ) + json_str = original.to_json() + restored = SemanticMsg.from_json(json_str) + + assert restored.uri == original.uri + assert restored.context_type == 
+        assert restored.target_uri == original.target_uri
+        assert restored.skip_vectorization == original.skip_vectorization
+        assert restored.recursive == original.recursive
+        assert restored.account_id == original.account_id
+        assert restored.user_id == original.user_id
+        assert restored.agent_id == original.agent_id
+        assert restored.role == original.role
+
+    def test_from_json_with_new_fields(self):
+        json_str = json.dumps(
+            {
+                "uri": "viking://resource/test",
+                "context_type": "resource",
+                "target_uri": "viking://resource/target",
+                "skip_vectorization": True,
+            }
+        )
+        msg = SemanticMsg.from_json(json_str)
+        assert msg.target_uri == "viking://resource/target"
+        assert msg.skip_vectorization is True
+
+    def test_from_json_without_new_fields(self):
+        json_str = json.dumps(
+            {
+                "uri": "viking://resource/test",
+                "context_type": "resource",
+            }
+        )
+        msg = SemanticMsg.from_json(json_str)
+        assert msg.target_uri == ""
+        assert msg.skip_vectorization is False
+
+    def test_from_json_invalid_json_raises_value_error(self):
+        with pytest.raises(ValueError, match="Invalid JSON string"):
+            SemanticMsg.from_json("not a valid json")
+
+    def test_from_json_missing_required_fields_raises_value_error(self):
+        json_str = '{"uri": "viking://resource/test"}'
+        with pytest.raises(ValueError, match="Missing required fields"):
+            SemanticMsg.from_json(json_str)
+
+
+class TestRequiredFieldsValidation:
+    """Tests for required field validation (uri and context_type)."""
+
+    def test_missing_uri_raises_value_error(self):
+        data = {"context_type": "resource"}
+        with pytest.raises(ValueError, match="Missing required fields"):
+            SemanticMsg.from_dict(data)
+
+    def test_missing_context_type_raises_value_error(self):
+        data = {"uri": "viking://resource/test"}
+        with pytest.raises(ValueError, match="Missing required fields"):
+            SemanticMsg.from_dict(data)
+
+    def test_missing_both_required_fields_raises_value_error(self):
+        data = {"target_uri": "viking://resource/target"}
+        with pytest.raises(ValueError, match="Missing required fields"):
+            SemanticMsg.from_dict(data)
+
+    def test_empty_uri_raises_value_error(self):
+        data = {"uri": "", "context_type": "resource"}
+        with pytest.raises(ValueError, match="Missing required fields"):
+            SemanticMsg.from_dict(data)
+
+    def test_empty_context_type_raises_value_error(self):
+        data = {"uri": "viking://resource/test", "context_type": ""}
+        with pytest.raises(ValueError, match="Missing required fields"):
+            SemanticMsg.from_dict(data)
+
+    def test_none_uri_raises_value_error(self):
+        data = {"uri": None, "context_type": "resource"}
+        with pytest.raises(ValueError, match="Missing required fields"):
+            SemanticMsg.from_dict(data)
+
+    def test_none_context_type_raises_value_error(self):
+        data = {"uri": "viking://resource/test", "context_type": None}
+        with pytest.raises(ValueError, match="Missing required fields"):
+            SemanticMsg.from_dict(data)
+
+    def test_empty_dict_raises_value_error(self):
+        with pytest.raises(ValueError, match="Data dictionary is empty"):
+            SemanticMsg.from_dict({})
+
+    def test_valid_minimal_data_succeeds(self):
+        data = {"uri": "viking://resource/test", "context_type": "resource"}
+        msg = SemanticMsg.from_dict(data)
+        assert msg.uri == "viking://resource/test"
+        assert msg.context_type == "resource"
+
+    def test_error_message_lists_all_missing_fields(self):
+        data = {"skip_vectorization": True}
+        with pytest.raises(ValueError) as exc_info:
+            SemanticMsg.from_dict(data)
+        error_msg = str(exc_info.value)
+        assert "uri" in error_msg
+        assert "context_type" in error_msg
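+
+# from_dict() validation contract exercised above: an empty dict raises
+# ValueError("Data dictionary is empty"); a missing, empty, or None "uri" or
+# "context_type" raises ValueError naming every missing field; the newer
+# fields remain optional. A sketch:
+#
+#     msg = SemanticMsg.from_dict(
+#         {"uri": "viking://resource/x", "context_type": "resource"}
+#     )
+#     assert msg.target_uri == "" and msg.skip_vectorization is False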
assert "context_type" in error_msg + + +class TestEdgeCases: + """Tests for edge cases and boundary conditions.""" + + def test_target_uri_with_special_characters(self): + special_uri = "viking://resource/test%20space?query=value&other=123" + msg = SemanticMsg( + uri="viking://resource/temp", + context_type="resource", + target_uri=special_uri, + ) + restored = SemanticMsg.from_dict(msg.to_dict()) + assert restored.target_uri == special_uri + + def test_target_uri_with_unicode(self): + unicode_uri = "viking://resource/测试/目录" + msg = SemanticMsg( + uri="viking://resource/temp", + context_type="resource", + target_uri=unicode_uri, + ) + restored = SemanticMsg.from_dict(msg.to_dict()) + assert restored.target_uri == unicode_uri + + def test_target_uri_with_long_path(self): + long_path = "viking://resource/" + "/".join(["dir"] * 100) + msg = SemanticMsg( + uri="viking://resource/temp", + context_type="resource", + target_uri=long_path, + ) + restored = SemanticMsg.from_dict(msg.to_dict()) + assert restored.target_uri == long_path + + def test_preserves_existing_id_in_from_dict(self): + original = SemanticMsg( + uri="viking://resource/test", + context_type="resource", + ) + original_id = original.id + data = original.to_dict() + restored = SemanticMsg.from_dict(data) + assert restored.id == original_id + + def test_from_dict_overwrites_id_if_provided(self): + data = { + "id": "custom-id-123", + "uri": "viking://resource/test", + "context_type": "resource", + } + msg = SemanticMsg.from_dict(data) + assert msg.id == "custom-id-123" + + def test_from_dict_preserves_status_and_timestamp(self): + data = { + "uri": "viking://resource/test", + "context_type": "resource", + "status": "completed", + "timestamp": 1700000000, + } + msg = SemanticMsg.from_dict(data) + assert msg.status == "completed" + assert msg.timestamp == 1700000000 + + def test_all_context_types(self): + context_types = ["resource", "memory", "skill", "session"] + for ctx_type in context_types: + msg = SemanticMsg( + uri=f"viking://{ctx_type}/test", + context_type=ctx_type, + ) + assert msg.context_type == ctx_type + restored = SemanticMsg.from_dict(msg.to_dict()) + assert restored.context_type == ctx_type + + def test_extra_fields_in_dict_are_ignored(self): + data = { + "uri": "viking://resource/test", + "context_type": "resource", + "extra_field": "should_be_ignored", + "another_extra": 12345, + } + msg = SemanticMsg.from_dict(data) + assert msg.uri == "viking://resource/test" + assert not hasattr(msg, "extra_field") + assert not hasattr(msg, "another_extra") diff --git a/tests/unit/storage/test_viking_fs_new.py b/tests/unit/storage/test_viking_fs_new.py new file mode 100644 index 00000000..6661cc9f --- /dev/null +++ b/tests/unit/storage/test_viking_fs_new.py @@ -0,0 +1,203 @@ +# Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd. +# SPDX-License-Identifier: Apache-2.0 +"""Unit tests for VikingFS new methods. 
diff --git a/tests/unit/storage/test_viking_fs_new.py b/tests/unit/storage/test_viking_fs_new.py
new file mode 100644
index 00000000..6661cc9f
--- /dev/null
+++ b/tests/unit/storage/test_viking_fs_new.py
@@ -0,0 +1,203 @@
+# Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd.
+# SPDX-License-Identifier: Apache-2.0
+"""Unit tests for new VikingFS methods.
+
+Tests for:
+- exists(): check whether a URI exists
+- copy_directory(): recursively copy a directory
+- delete_temp(): delete a temporary directory
+"""
+
+import contextvars
+from unittest.mock import AsyncMock, MagicMock, patch
+
+import pytest
+
+from openviking.storage.viking_fs import VikingFS
+
+
+def _create_viking_fs_mock():
+    """Create a VikingFS instance with a mocked AGFS backend.
+
+    VikingFS.__new__ is used deliberately so that __init__ (and the real
+    AGFS wiring it performs) never runs; only the attributes the methods
+    under test touch are stubbed in directly.
+    """
+    fs = VikingFS.__new__(VikingFS)
+    fs.agfs = MagicMock()
+    fs.query_embedder = None
+    fs.vector_store = None
+    fs._uri_prefix = "viking://"
+    fs._bound_ctx = contextvars.ContextVar("vikingfs_bound_ctx", default=None)
+    return fs
+
+
+@pytest.mark.asyncio
+class TestVikingFSExists:
+    """Test cases for VikingFS.exists()."""
+
+    async def test_exists_returns_true_when_uri_exists(self):
+        """exists() should return True when the URI exists."""
+        fs = _create_viking_fs_mock()
+        fs.stat = AsyncMock(return_value={"name": "test_file.txt", "isDir": False})
+
+        result = await fs.exists("viking://temp/test_file.txt")
+
+        assert result is True
+        fs.stat.assert_awaited_once_with("viking://temp/test_file.txt", ctx=None)
+
+    async def test_exists_returns_false_when_uri_not_found(self):
+        """exists() should return False when the URI does not exist."""
+        fs = _create_viking_fs_mock()
+        fs.stat = AsyncMock(side_effect=FileNotFoundError("Not found"))
+
+        result = await fs.exists("viking://temp/nonexistent.txt")
+
+        assert result is False
+        fs.stat.assert_awaited_once_with("viking://temp/nonexistent.txt", ctx=None)
+
+    async def test_exists_returns_false_on_any_exception(self):
+        """exists() should return False on any exception, not just FileNotFoundError."""
+        fs = _create_viking_fs_mock()
+        fs.stat = AsyncMock(side_effect=PermissionError("Access denied"))
+
+        result = await fs.exists("viking://temp/protected.txt")
+
+        assert result is False
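+
+    # The behaviour pinned down above is consistent with a plain try/except
+    # around stat(); a sketch (an assumption, not necessarily the actual
+    # implementation):
+    #
+    #     async def exists(self, uri, ctx=None):
+    #         try:
+    #             await self.stat(uri, ctx=ctx)
+    #             return True
+    #         except Exception:
+    #             return False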
+
+
+@pytest.mark.asyncio
+class TestVikingFSCopyDirectory:
+    """Test cases for VikingFS.copy_directory()."""
+
+    async def test_copy_directory_recursive(self):
+        """copy_directory() should recursively copy directory contents."""
+        fs = _create_viking_fs_mock()
+        fs._ensure_access = MagicMock()
+        fs._uri_to_path = MagicMock(
+            side_effect=lambda uri, ctx=None: uri.replace("viking://", "/local/")
+        )
+        fs._ensure_parent_dirs = AsyncMock()
+
+        mock_agfs_cp = MagicMock()
+
+        with patch("openviking.storage.viking_fs.agfs_cp", mock_agfs_cp):
+            await fs.copy_directory(
+                "viking://temp/source_dir/",
+                "viking://temp/dest_dir/",
+            )
+
+        fs._ensure_access.assert_any_call("viking://temp/source_dir/", None)
+        fs._ensure_access.assert_any_call("viking://temp/dest_dir/", None)
+        fs._ensure_parent_dirs.assert_awaited_once_with("/local/temp/dest_dir/")
+        mock_agfs_cp.assert_called_once_with(
+            fs.agfs,
+            "/local/temp/source_dir/",
+            "/local/temp/dest_dir/",
+            recursive=True,
+        )
+
+    async def test_copy_directory_with_context(self):
+        """copy_directory() should pass the request context to helper methods."""
+        from openviking.server.identity import RequestContext, Role
+        from openviking_cli.session.user_id import UserIdentifier
+
+        fs = _create_viking_fs_mock()
+        fs._ensure_access = MagicMock()
+        fs._uri_to_path = MagicMock(
+            side_effect=lambda uri, ctx=None: uri.replace("viking://", "/local/")
+        )
+        fs._ensure_parent_dirs = AsyncMock()
+
+        ctx = RequestContext(
+            user=UserIdentifier("acc1", "user1", "agent1"),
+            role=Role.USER,
+        )
+
+        mock_agfs_cp = MagicMock()
+
+        with patch("openviking.storage.viking_fs.agfs_cp", mock_agfs_cp):
+            await fs.copy_directory(
+                "viking://temp/source/",
+                "viking://temp/dest/",
+                ctx=ctx,
+            )
+
+        fs._ensure_access.assert_any_call("viking://temp/source/", ctx)
+        fs._ensure_access.assert_any_call("viking://temp/dest/", ctx)
+
+
+@pytest.mark.asyncio
+class TestVikingFSDeleteTemp:
+    """Test cases for VikingFS.delete_temp()."""
+
+    async def test_delete_temp_removes_directory_and_contents(self):
+        """delete_temp() should remove a directory and all of its contents."""
+        fs = _create_viking_fs_mock()
+        fs._uri_to_path = MagicMock(return_value="/local/temp/test_temp")
+
+        fs._ls_entries = MagicMock(
+            return_value=[
+                {"name": "file1.txt", "isDir": False},
+                {"name": "subdir", "isDir": True},
+            ]
+        )
+
+        fs.agfs.rm = MagicMock()
+
+        call_count = [0]
+
+        # Stand-in for the recursive call: the real delete_temp() runs once at
+        # the top level below, and this mock intercepts the recursion into
+        # "subdir", shrinking the fake listing on each pass so the recursion
+        # terminates.
+        async def mock_delete_temp(uri, ctx=None):
+            call_count[0] += 1
+            if call_count[0] == 1:
+                fs._ls_entries.return_value = [
+                    {"name": "nested_file.txt", "isDir": False},
+                ]
+                await fs.delete_temp(uri, ctx=ctx)
+            else:
+                fs._ls_entries.return_value = []
+
+        original_delete_temp = fs.delete_temp
+        fs.delete_temp = mock_delete_temp
+
+        await original_delete_temp("viking://temp/test_temp/")
+
+        assert fs.agfs.rm.call_count >= 1
+
+    async def test_delete_temp_handles_empty_directory(self):
+        """delete_temp() should handle an empty directory gracefully."""
+        fs = _create_viking_fs_mock()
+        fs._uri_to_path = MagicMock(return_value="/local/temp/empty_temp")
+        fs._ls_entries = MagicMock(return_value=[])
+        fs.agfs.rm = MagicMock()
+
+        await fs.delete_temp("viking://temp/empty_temp/")
+
+        fs.agfs.rm.assert_called_once_with("/local/temp/empty_temp")
+
+    async def test_delete_temp_skips_dot_entries(self):
+        """delete_temp() should skip the . and .. entries."""
+        fs = _create_viking_fs_mock()
+        fs._uri_to_path = MagicMock(return_value="/local/temp/test_temp")
+        fs._ls_entries = MagicMock(
+            return_value=[
+                {"name": ".", "isDir": True},
+                {"name": "..", "isDir": True},
+                {"name": "actual_file.txt", "isDir": False},
+            ]
+        )
+        fs.agfs.rm = MagicMock()
+
+        await fs.delete_temp("viking://temp/test_temp/")
+
+        rm_calls = [call[0][0] for call in fs.agfs.rm.call_args_list]
+        assert "/local/temp/test_temp/actual_file.txt" in rm_calls
+        assert "/local/temp/test_temp/." not in rm_calls
+        assert "/local/temp/test_temp/.." not in rm_calls
+
+    async def test_delete_temp_logs_warning_on_error(self):
+        """delete_temp() should log a warning rather than raise on error."""
+        fs = _create_viking_fs_mock()
+        fs._uri_to_path = MagicMock(return_value="/local/temp/error_temp")
+        fs._ls_entries = MagicMock(side_effect=Exception("AGFS error"))
+
+        with patch("openviking.storage.viking_fs.logger") as mock_logger:
+            await fs.delete_temp("viking://temp/error_temp/")
+
+        mock_logger.warning.assert_called_once()
+        # call_args[0][0] is the first positional argument, i.e. the message.
+        assert "Failed to delete temp" in mock_logger.warning.call_args[0][0]
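+
+# Best-effort contract pinned down by the tests above (a summary, not the
+# implementation): delete_temp() lists entries, skips "." and "..", recurses
+# into subdirectories via self.delete_temp(), removes files, then removes the
+# directory itself with its trailing slash stripped; backend failures are
+# logged as "Failed to delete temp ..." instead of being raised.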