diff --git a/openviking/core/context.py b/openviking/core/context.py index 76308570..94d47b1f 100644 --- a/openviking/core/context.py +++ b/openviking/core/context.py @@ -56,7 +56,6 @@ def __init__( self, uri: str, parent_uri: Optional[str] = None, - temp_uri: Optional[str] = None, is_leaf: bool = False, abstract: str = "", context_type: Optional[str] = None, @@ -79,7 +78,6 @@ def __init__( self.id = id or str(uuid4()) self.uri = uri self.parent_uri = parent_uri - self.temp_uri = temp_uri self.is_leaf = is_leaf self.abstract = abstract self.context_type = context_type or self._derive_context_type() @@ -161,7 +159,6 @@ def to_dict(self) -> Dict[str, Any]: "id": self.id, "uri": self.uri, "parent_uri": self.parent_uri, - "temp_uri": self.temp_uri, "is_leaf": self.is_leaf, "abstract": self.abstract, "context_type": self.context_type, @@ -197,7 +194,6 @@ def from_dict(cls, data: Dict[str, Any]) -> "Context": obj = cls( uri=data["uri"], parent_uri=data.get("parent_uri"), - temp_uri=data.get("temp_uri"), is_leaf=data.get("is_leaf", False), abstract=data.get("abstract", ""), context_type=data.get("context_type"), diff --git a/openviking/parse/parsers/constants.py b/openviking/parse/parsers/constants.py index a817a8f2..311e545a 100644 --- a/openviking/parse/parsers/constants.py +++ b/openviking/parse/parsers/constants.py @@ -174,7 +174,6 @@ ".graphql", ".gql", ".prisma", - ".conf", } # Documentation file extensions for file type detection @@ -225,7 +224,6 @@ ".yarnrc", ".env", ".env.example", - ".jsonl", } # Common text encodings to try for encoding detection (in order of likelihood) diff --git a/openviking/parse/parsers/directory.py b/openviking/parse/parsers/directory.py index 371cf0f9..da27b659 100644 --- a/openviking/parse/parsers/directory.py +++ b/openviking/parse/parsers/directory.py @@ -123,11 +123,6 @@ async def parse( viking_fs = self._get_viking_fs() temp_uri = self._create_temp_uri() target_uri = f"{temp_uri}/{dir_name}" - logger.info( - f"Scanning directory: {source_path}, " - f"processable files: {len(processable_files)}, " - f"warnings: {warnings}" - ) await viking_fs.mkdir(temp_uri, exist_ok=True) await viking_fs.mkdir(target_uri, exist_ok=True) diff --git a/openviking/parse/parsers/upload_utils.py b/openviking/parse/parsers/upload_utils.py index d1870173..50c80ed0 100644 --- a/openviking/parse/parsers/upload_utils.py +++ b/openviking/parse/parsers/upload_utils.py @@ -40,7 +40,6 @@ "NEWS", "NOTICE", "TODO", - "BUILD", } diff --git a/openviking/parse/tree_builder.py b/openviking/parse/tree_builder.py index 18bcbf07..c8409f41 100644 --- a/openviking/parse/tree_builder.py +++ b/openviking/parse/tree_builder.py @@ -138,6 +138,16 @@ async def finalize_from_temp( base_uri = parent_uri or auto_base_uri # 3. 
Determine candidate_uri if to_uri: + # Exact target URI: must not exist yet + try: + await viking_fs.stat(to_uri, ctx=ctx) + # If we get here, it already exists + raise FileExistsError(f"Target URI already exists: {to_uri}") + except FileExistsError: + raise + except Exception: + # It doesn't exist, good to use + pass candidate_uri = to_uri else: if parent_uri: @@ -150,7 +160,34 @@ async def finalize_from_temp( raise ValueError(f"Parent URI is not a directory: {parent_uri}") candidate_uri = VikingURI(base_uri).join(final_doc_name).uri - final_uri = candidate_uri + final_uri = await self._resolve_unique_uri(candidate_uri, ctx=ctx) + + if final_uri != candidate_uri: + logger.info(f"[TreeBuilder] Resolved name conflict: {candidate_uri} -> {final_uri}") + else: + logger.info(f"[TreeBuilder] Finalizing from temp: {final_uri}") + + # 4. Move directory tree from temp to final location in AGFS + await self._move_temp_to_dest(viking_fs, temp_doc_uri, final_uri, ctx=ctx) + logger.info(f"[TreeBuilder] Moved temp tree: {temp_doc_uri} -> {final_uri}") + + # 5. Cleanup temporary root directory + try: + await viking_fs.delete_temp(temp_uri, ctx=ctx) + logger.info(f"[TreeBuilder] Cleaned up temp root: {temp_uri}") + except Exception as e: + logger.warning(f"[TreeBuilder] Failed to cleanup temp root: {e}") + + # 6. Enqueue to SemanticQueue for async semantic generation + if trigger_semantic: + try: + await self._enqueue_semantic_generation(final_uri, "resource", ctx=ctx) + logger.info(f"[TreeBuilder] Enqueued semantic generation for: {final_uri}") + except Exception as e: + logger.error( + f"[TreeBuilder] Failed to enqueue semantic generation: {e}", exc_info=True + ) + # 7. Return simple BuildingTree (no scanning needed) tree = BuildingTree( source_path=source_path, @@ -159,11 +196,39 @@ async def finalize_from_temp( tree._root_uri = final_uri # Create a minimal Context object for the root so that tree.root is not None - root_context = Context(uri=final_uri, temp_uri=temp_doc_uri) + root_context = Context(uri=final_uri) tree.add_context(root_context) return tree + async def _resolve_unique_uri( + self, uri: str, max_attempts: int = 100, ctx: Optional[RequestContext] = None + ) -> str: + """Return a URI that does not collide with an existing resource. + + If *uri* is free, return it unchanged. Otherwise append ``_1``, + ``_2``, … until a free name is found (like macOS Finder / Windows + Explorer). + """ + viking_fs = get_viking_fs() + + async def _exists(u: str) -> bool: + try: + await viking_fs.stat(u, ctx=ctx) + return True + except Exception: + return False + + if not await _exists(uri): + return uri + + for i in range(1, max_attempts + 1): + candidate = f"{uri}_{i}" + if not await _exists(candidate): + return candidate + + raise FileExistsError(f"Cannot resolve unique name for {uri} after {max_attempts} attempts") + async def _move_temp_to_dest( self, viking_fs, src_uri: str, dst_uri: str, ctx: RequestContext ) -> None: @@ -196,7 +261,7 @@ async def _ensure_parent_dirs(self, uri: str, ctx: RequestContext) -> None: logger.debug(f"Parent dir {parent_uri} may already exist: {e}") async def _enqueue_semantic_generation( - self, uri: str, final_uri: str, context_type: str, ctx: RequestContext + self, uri: str, context_type: str, ctx: RequestContext ) -> None: """ Enqueue a directory for semantic generation. 
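For reference, a minimal standalone sketch of the collision-resolution strategy `_resolve_unique_uri` implements. Note that the counter is appended after the full name, extension included (`report.md_1`, not `report_1.md`), and the `exists` probe below is a hypothetical stand-in for the `viking_fs.stat` call used in the patch:

```python
import asyncio
from typing import Awaitable, Callable

async def resolve_unique(
    uri: str, exists: Callable[[str], Awaitable[bool]], max_attempts: int = 100
) -> str:
    # Same strategy as _resolve_unique_uri: try uri, then uri_1, uri_2, ...
    if not await exists(uri):
        return uri
    for i in range(1, max_attempts + 1):
        candidate = f"{uri}_{i}"
        if not await exists(candidate):
            return candidate
    raise FileExistsError(f"Cannot resolve unique name for {uri} after {max_attempts} attempts")

async def demo() -> None:
    taken = {"viking://user/u1/docs/report.md", "viking://user/u1/docs/report.md_1"}

    async def exists(u: str) -> bool:
        return u in taken

    # report.md and report.md_1 are taken, so the first free name is report.md_2.
    print(await resolve_unique("viking://user/u1/docs/report.md", exists))

asyncio.run(demo())
```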
@@ -219,6 +284,32 @@ async def _enqueue_semantic_generation( user_id=ctx.user.user_id, agent_id=ctx.user.agent_id, role=ctx.role.value, - target_uri=final_uri, ) await semantic_queue.enqueue(msg) + + async def _load_content(self, uri: str, content_type: str) -> str: + """Helper to load content with proper type handling""" + import json + + if content_type == "abstract": + result = await get_viking_fs().abstract(uri) + elif content_type == "overview": + result = await get_viking_fs().overview(uri) + elif content_type == "detail": + result = await get_viking_fs().read_file(uri) + else: + return "" + + # Handle different return types + if isinstance(result, str): + return result + elif isinstance(result, bytes): + return result.decode("utf-8") + elif hasattr(result, "to_dict") and not isinstance(result, list): + # Handle FindResult by converting to dict (skip lists) + return str(result.to_dict()) + elif isinstance(result, list): + # Handle list results + return json.dumps(result) + else: + return str(result) diff --git a/openviking/server/models.py b/openviking/server/models.py index 26b6e5ec..4cb7f967 100644 --- a/openviking/server/models.py +++ b/openviking/server/models.py @@ -39,7 +39,6 @@ class Response(BaseModel): "INVALID_URI": 400, "NOT_FOUND": 404, "ALREADY_EXISTS": 409, - "CONFLICT": 409, "PERMISSION_DENIED": 403, "UNAUTHENTICATED": 401, "RESOURCE_EXHAUSTED": 429, diff --git a/openviking/service/core.py b/openviking/service/core.py index a19f3ba3..5764fed1 100644 --- a/openviking/service/core.py +++ b/openviking/service/core.py @@ -255,9 +255,7 @@ async def initialize(self) -> None: ) # Initialize processors - self._resource_processor = ResourceProcessor( - vikingdb=self._vikingdb_manager, - ) + self._resource_processor = ResourceProcessor(vikingdb=self._vikingdb_manager) self._skill_processor = SkillProcessor(vikingdb=self._vikingdb_manager) self._session_compressor = SessionCompressor(vikingdb=self._vikingdb_manager) diff --git a/openviking/session/compressor.py b/openviking/session/compressor.py index 0dd58d88..f5009897 100644 --- a/openviking/session/compressor.py +++ b/openviking/session/compressor.py @@ -78,57 +78,16 @@ async def _index_memory(self, memory: Context, ctx: RequestContext) -> bool: await self.extractor._enqueue_semantic_for_parent(memory.uri, ctx) return True - def _convert_to_temp_uri( - self, target_uri: str, user_temp_uri: Optional[str], agent_temp_uri: Optional[str] - ) -> str: - """Convert target URI to temp URI for COW pattern. - - Args: - target_uri: Target URI (e.g., viking://user/... or viking://agent/...) - user_temp_uri: Temp user URI (if available) - agent_temp_uri: Temp agent URI (if available) - - Returns: - Converted temp URI, or original URI if no temp available - """ - if not user_temp_uri and not agent_temp_uri: - return target_uri - - # Convert user URI - if target_uri.startswith("viking://user/") and user_temp_uri: - # viking://user/{user_space}/memories/... -> {user_temp_uri}/memories/... - parts = target_uri.split("/") - if len(parts) >= 5: - # parts[0]="viking:", parts[1]="", parts[2]="user", parts[3]="{user_space}", parts[4:]="memories/..." - rest = "/".join(parts[4:]) - return f"{user_temp_uri}/{rest}" - - # Convert agent URI - if target_uri.startswith("viking://agent/") and agent_temp_uri: - # viking://agent/{agent_space}/memories/... -> {agent_temp_uri}/memories/... - parts = target_uri.split("/") - if len(parts) >= 5: - # parts[0]="viking:", parts[1]="", parts[2]="agent", parts[3]="{agent_space}", parts[4:]="memories/..." 
- rest = "/".join(parts[4:]) - return f"{agent_temp_uri}/{rest}" - - return target_uri - async def _merge_into_existing( self, candidate: CandidateMemory, target_memory: Context, viking_fs, ctx: RequestContext, - user_temp_uri: Optional[str] = None, - agent_temp_uri: Optional[str] = None, ) -> bool: """Merge candidate content into an existing memory file.""" try: - # Convert target URI to temp URI for COW pattern - temp_uri = self._convert_to_temp_uri(target_memory.uri, user_temp_uri, agent_temp_uri) - - existing_content = await viking_fs.read_file(temp_uri, ctx=ctx) + existing_content = await viking_fs.read_file(target_memory.uri, ctx=ctx) payload = await self.extractor._merge_memory_bundle( existing_abstract=target_memory.abstract, existing_overview=(target_memory.meta or {}).get("overview") or "", @@ -142,41 +101,34 @@ async def _merge_into_existing( if not payload: return False - await viking_fs.write_file(temp_uri, payload.content, ctx=ctx) + await viking_fs.write_file(target_memory.uri, payload.content, ctx=ctx) target_memory.abstract = payload.abstract target_memory.meta = {**(target_memory.meta or {}), "overview": payload.overview} - logger.info("Merged memory %s with abstract %s", temp_uri, target_memory.abstract) + logger.info( + "Merged memory %s with abstract %s", target_memory.uri, target_memory.abstract + ) target_memory.set_vectorize(Vectorize(text=payload.content)) - # Note: vectorization will be handled by SemanticQueue after directory switch - # await self._index_memory(target_memory, ctx) + await self._index_memory(target_memory, ctx) return True except Exception as e: logger.error(f"Failed to merge memory {target_memory.uri}: {e}") return False async def _delete_existing_memory( - self, - memory: Context, - viking_fs, - ctx: RequestContext, - user_temp_uri: Optional[str] = None, - agent_temp_uri: Optional[str] = None, + self, memory: Context, viking_fs, ctx: RequestContext ) -> bool: """Hard delete an existing memory file and clean up its vector record.""" try: - # Convert target URI to temp URI for COW pattern - temp_uri = self._convert_to_temp_uri(memory.uri, user_temp_uri, agent_temp_uri) - - await viking_fs.rm(temp_uri, recursive=False, ctx=ctx) + await viking_fs.rm(memory.uri, recursive=False, ctx=ctx) except Exception as e: - logger.error(f"Failed to delete memory file {temp_uri}: {e}") + logger.error(f"Failed to delete memory file {memory.uri}: {e}") return False try: # rm() already syncs vector deletion in most cases; keep this as a safe fallback. - await self.vikingdb.delete_uris(ctx, [temp_uri]) + await self.vikingdb.delete_uris(ctx, [memory.uri]) except Exception as e: - logger.warning(f"Failed to remove vector record for {temp_uri}: {e}") + logger.warning(f"Failed to remove vector record for {memory.uri}: {e}") return True async def extract_long_term_memories( @@ -185,25 +137,9 @@ async def extract_long_term_memories( user: Optional["UserIdentifier"] = None, session_id: Optional[str] = None, ctx: Optional[RequestContext] = None, - user_temp_uri: Optional[str] = None, - agent_temp_uri: Optional[str] = None, strict_extract_errors: bool = False, ) -> List[Context]: - """Extract long-term memories from messages. - - Args: - messages: Messages to extract from - user: User identifier - session_id: Session ID - ctx: Request context - user_temp_uri: Temp user URI (for COW pattern). If provided, user memories - will be written to this temp location. - agent_temp_uri: Temp agent URI (for COW pattern). 
If provided, agent memories - will be written to this temp location. - - Returns: - List of extracted memories - """ + """Extract long-term memories from messages.""" if not messages: return [] @@ -234,19 +170,11 @@ async def extract_long_term_memories( for candidate in candidates: # Profile: skip dedup, always merge if candidate.category in ALWAYS_MERGE_CATEGORIES: - memory = await self.extractor.create_memory( - candidate, - user, - session_id, - ctx=ctx, - user_temp_uri=user_temp_uri, - agent_temp_uri=agent_temp_uri, - ) + memory = await self.extractor.create_memory(candidate, user, session_id, ctx=ctx) if memory: memories.append(memory) stats.created += 1 - # Note: vectorization will be handled by SemanticQueue after directory switch - # await self._index_memory(memory, ctx) + await self._index_memory(memory, ctx) else: stats.skipped += 1 continue @@ -285,11 +213,11 @@ async def extract_long_term_memories( ) if skill_name: memory = await self.extractor._merge_skill_memory( - skill_name, candidate, ctx=ctx, agent_temp_uri=agent_temp_uri + skill_name, candidate, ctx=ctx ) elif tool_name: memory = await self.extractor._merge_tool_memory( - tool_name, candidate, ctx=ctx, agent_temp_uri=agent_temp_uri + tool_name, candidate, ctx=ctx ) else: logger.warning("No tool_name or skill_name found, skipping") @@ -298,8 +226,7 @@ async def extract_long_term_memories( if memory: memories.append(memory) stats.merged += 1 - # Note: vectorization will be handled by SemanticQueue after directory switch - # await self._index_memory(memory, ctx) + await self._index_memory(memory, ctx) continue # Dedup check for other categories @@ -329,11 +256,7 @@ async def extract_long_term_memories( for action in actions: if action.decision == MemoryActionDecision.DELETE: if viking_fs and await self._delete_existing_memory( - action.memory, - viking_fs, - ctx=ctx, - user_temp_uri=user_temp_uri, - agent_temp_uri=agent_temp_uri, + action.memory, viking_fs, ctx=ctx ): stats.deleted += 1 else: @@ -341,17 +264,13 @@ async def extract_long_term_memories( elif action.decision == MemoryActionDecision.MERGE: if candidate.category in MERGE_SUPPORTED_CATEGORIES and viking_fs: if await self._merge_into_existing( - candidate, - action.memory, - viking_fs, - ctx=ctx, - user_temp_uri=user_temp_uri, - agent_temp_uri=agent_temp_uri, + candidate, action.memory, viking_fs, ctx=ctx ): stats.merged += 1 else: stats.skipped += 1 else: + # events/cases don't support MERGE, treat as SKIP stats.skipped += 1 continue @@ -360,60 +279,24 @@ async def extract_long_term_memories( for action in actions: if action.decision == MemoryActionDecision.DELETE: if viking_fs and await self._delete_existing_memory( - action.memory, - viking_fs, - ctx=ctx, - user_temp_uri=user_temp_uri, - agent_temp_uri=agent_temp_uri, + action.memory, viking_fs, ctx=ctx ): stats.deleted += 1 else: stats.skipped += 1 - memory = await self.extractor.create_memory( - candidate, - user, - session_id, - ctx=ctx, - user_temp_uri=user_temp_uri, - agent_temp_uri=agent_temp_uri, - ) + memory = await self.extractor.create_memory(candidate, user, session_id, ctx=ctx) if memory: memories.append(memory) stats.created += 1 - # Note: vectorization will be handled by SemanticQueue after directory switch - # await self._index_memory(memory, ctx) + await self._index_memory(memory, ctx) else: stats.skipped += 1 # Extract URIs used in messages, create relations used_uris = self._extract_used_uris(messages) if used_uris and memories: - # Convert memory URIs from temp to target for relation creation 
- target_memories = [] - for memory in memories: - # Create a copy with target URI - target_uri = memory.uri - # If memory.uri is a temp URI, convert it to target URI - if user_temp_uri and memory.uri.startswith(user_temp_uri): - target_uri = memory.uri.replace( - user_temp_uri, f"viking://user/{ctx.user.user_space_name()}" - ) - elif agent_temp_uri and memory.uri.startswith(agent_temp_uri): - target_uri = memory.uri.replace( - agent_temp_uri, f"viking://agent/{ctx.user.agent_space_name()}" - ) - - # Create a new Context with target URI for relation creation - target_memory = Context( - uri=target_uri, - context_type=memory.context_type, - abstract=memory.abstract, - meta=memory.meta, - ) - target_memories.append(target_memory) - - await self._create_relations(target_memories, used_uris, ctx=ctx) + await self._create_relations(memories, used_uris, ctx=ctx) logger.info( f"Memory extraction: created={stats.created}, " diff --git a/openviking/session/memory_deduplicator.py b/openviking/session/memory_deduplicator.py index ee28ff33..c119ecb8 100644 --- a/openviking/session/memory_deduplicator.py +++ b/openviking/session/memory_deduplicator.py @@ -117,19 +117,8 @@ async def deduplicate( async def _find_similar_memories( self, candidate: CandidateMemory, - user_temp_uri: Optional[str] = None, - agent_temp_uri: Optional[str] = None, ) -> List[Context]: - """Find similar existing memories using vector search. - - Args: - candidate: Candidate memory - user_temp_uri: Temp user URI (for COW pattern) - agent_temp_uri: Temp agent URI (for COW pattern) - - Returns: - List of similar memories with temp URIs (if temp URIs provided) - """ + """Find similar existing memories using vector search.""" if not self.embedder: return [] @@ -138,7 +127,6 @@ async def _find_similar_memories( embed_result: EmbedResult = self.embedder.embed(query_text) query_vector = embed_result.dense_vector - # Search target URI (not temp URI) because vectors are stored for target URIs category_uri_prefix = self._category_uri_prefix(candidate.category.value, candidate.user) owner = candidate.user @@ -189,25 +177,6 @@ async def _find_similar_memories( if context: # Keep retrieval score for later destructive-action guardrails. context.meta = {**(context.meta or {}), "_dedup_score": score} - - # Convert target URI to temp URI (for COW pattern) - if user_temp_uri or agent_temp_uri: - original_uri = context.uri - # Convert user URI - if user_temp_uri and original_uri.startswith("viking://user/"): - parts = original_uri.split("/") - if len(parts) >= 5: - rest = "/".join(parts[4:]) - context.uri = f"{user_temp_uri}/{rest}" - logger.debug(f"Converted URI: {original_uri} -> {context.uri}") - # Convert agent URI - elif agent_temp_uri and original_uri.startswith("viking://agent/"): - parts = original_uri.split("/") - if len(parts) >= 5: - rest = "/".join(parts[4:]) - context.uri = f"{agent_temp_uri}/{rest}" - logger.debug(f"Converted URI: {original_uri} -> {context.uri}") - similar.append(context) logger.debug("Dedup similar memories after threshold=%d", len(similar)) return similar diff --git a/openviking/session/memory_extractor.py b/openviking/session/memory_extractor.py index cf6b3606..4130880e 100644 --- a/openviking/session/memory_extractor.py +++ b/openviking/session/memory_extractor.py @@ -219,6 +219,7 @@ async def extract( context: dict, user: UserIdentifier, session_id: str, + *, strict: bool = False, ) -> List[CandidateMemory]: """Extract memory candidates from messages. 
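The bare `*` added to `extract` makes `strict` keyword-only, so a stray positional argument can no longer silently bind to it. A small sketch of the effect (simplified signature, not the real class):

```python
from typing import Dict, List

def extract(messages: List[Dict], context: Dict, user: str,
            session_id: str, *, strict: bool = False) -> List[Dict]:
    out = []
    for m in messages:
        if "content" not in m:
            if strict:
                raise ValueError(f"malformed message: {m!r}")
            continue  # lenient mode: skip malformed messages
        out.append(m)
    return out

msgs = [{"content": "hi"}]
extract(msgs, {}, "u1", "s1", strict=True)   # OK: intent is explicit
try:
    extract(msgs, {}, "u1", "s1", True)      # TypeError: strict is keyword-only
except TypeError as e:
    print(e)
```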
@@ -407,21 +408,8 @@ async def create_memory( user: str, session_id: str, ctx: RequestContext, - user_temp_uri: Optional[str] = None, - agent_temp_uri: Optional[str] = None, ) -> Optional[Context]: - """Create Context object from candidate and persist to AGFS as .md file. - - Args: - candidate: Candidate memory to create - user: User identifier - session_id: Session ID - ctx: Request context - user_temp_uri: Temp user URI (for COW pattern). If provided, user memories - will be written to this temp location. - agent_temp_uri: Temp agent URI (for COW pattern). If provided, agent memories - will be written to this temp location. - """ + """Create Context object from candidate and persist to AGFS as .md file.""" viking_fs = get_viking_fs() if not viking_fs: logger.warning("VikingFS not available, skipping memory creation") @@ -431,22 +419,14 @@ async def create_memory( # Special handling for profile: append to profile.md if candidate.category == MemoryCategory.PROFILE: - payload = await self._append_to_profile( - candidate, viking_fs, ctx=ctx, user_temp_uri=user_temp_uri - ) + payload = await self._append_to_profile(candidate, viking_fs, ctx=ctx) if not payload: return None user_space = ctx.user.user_space_name() - # Use temp user URI if provided (for COW pattern) - if user_temp_uri: - memory_uri = f"{user_temp_uri}/memories/profile.md" - parent_uri = f"{user_temp_uri}/memories" - else: - memory_uri = f"viking://user/{user_space}/memories/profile.md" - parent_uri = f"viking://user/{user_space}/memories" + memory_uri = f"viking://user/{user_space}/memories/profile.md" memory = Context( uri=memory_uri, - parent_uri=parent_uri, + parent_uri=f"viking://user/{user_space}/memories", is_leaf=True, abstract=payload.abstract, context_type=ContextType.MEMORY.value, @@ -467,17 +447,9 @@ async def create_memory( MemoryCategory.ENTITIES, MemoryCategory.EVENTS, ]: - # Use temp user URI if provided (for COW pattern) - if user_temp_uri: - parent_uri = f"{user_temp_uri}/{cat_dir}" - else: - parent_uri = f"viking://user/{ctx.user.user_space_name()}/{cat_dir}" + parent_uri = f"viking://user/{ctx.user.user_space_name()}/{cat_dir}" else: # CASES, PATTERNS - # Use temp agent URI if provided (for COW pattern) - if agent_temp_uri: - parent_uri = f"{agent_temp_uri}/{cat_dir}" - else: - parent_uri = f"viking://agent/{ctx.user.agent_space_name()}/{cat_dir}" + parent_uri = f"viking://agent/{ctx.user.agent_space_name()}/{cat_dir}" # Generate file URI (store directly as .md file, no directory creation) memory_id = f"mem_{str(uuid4())}" @@ -513,14 +485,9 @@ async def _append_to_profile( candidate: CandidateMemory, viking_fs, ctx: RequestContext, - user_temp_uri: Optional[str] = None, ) -> Optional[MergedMemoryPayload]: """Update user profile - always merge with existing content.""" - # Use temp user URI if provided (for COW pattern) - if user_temp_uri: - uri = f"{user_temp_uri}/memories/profile.md" - else: - uri = f"viking://user/{ctx.user.user_space_name()}/memories/profile.md" + uri = f"viking://user/{ctx.user.user_space_name()}/memories/profile.md" existing = "" try: existing = await viking_fs.read_file(uri, ctx=ctx) or "" @@ -620,11 +587,7 @@ async def _merge_memory_bundle( return None async def _merge_tool_memory( - self, - tool_name: str, - candidate: CandidateMemory, - ctx: "RequestContext", - agent_temp_uri: Optional[str] = None, + self, tool_name: str, candidate: CandidateMemory, ctx: "RequestContext" ) -> Optional[Context]: """合并 Tool Memory,统计数据用 Python 累加""" if not tool_name or not tool_name.strip(): @@ -632,11 
+595,7 @@ async def _merge_tool_memory( return None agent_space = ctx.user.agent_space_name() - # Use temp agent URI if provided (for COW pattern) - if agent_temp_uri: - uri = f"{agent_temp_uri}/memories/tools/{tool_name}.md" - else: - uri = f"viking://agent/{agent_space}/memories/tools/{tool_name}.md" + uri = f"viking://agent/{agent_space}/memories/tools/{tool_name}.md" viking_fs = get_viking_fs() if not viking_fs: @@ -692,7 +651,7 @@ async def _merge_tool_memory( tool_name, merged_stats, new_guidelines, fields=new_fields ) await viking_fs.write_file(uri=uri, content=merged_content, ctx=ctx) - return self._create_tool_context(uri, candidate, ctx, agent_temp_uri=agent_temp_uri) + return self._create_tool_context(uri, candidate, ctx) existing_stats = self._parse_tool_statistics(existing) merged_stats = self._merge_tool_statistics(existing_stats, new_stats) @@ -750,9 +709,7 @@ async def _merge_tool_memory( tool_name, merged_stats, merged_guidelines, fields=merged_fields ) await viking_fs.write_file(uri=uri, content=merged_content, ctx=ctx) - return self._create_tool_context( - uri, candidate, ctx, abstract_override=abstract_override, agent_temp_uri=agent_temp_uri - ) + return self._create_tool_context(uri, candidate, ctx, abstract_override=abstract_override) async def _enqueue_semantic_for_parent(self, file_uri: str, ctx: "RequestContext") -> None: """Enqueue semantic generation for parent directory.""" @@ -1168,18 +1125,12 @@ def _create_tool_context( candidate: CandidateMemory, ctx: "RequestContext", abstract_override: Optional[str] = None, - agent_temp_uri: Optional[str] = None, ) -> Context: """创建 Tool Memory 的 Context 对象""" agent_space = ctx.user.agent_space_name() - # Use temp agent URI if provided (for COW pattern) - if agent_temp_uri: - parent_uri = f"{agent_temp_uri}/memories/tools" - else: - parent_uri = f"viking://agent/{agent_space}/memories/tools" return Context( uri=uri, - parent_uri=parent_uri, + parent_uri=f"viking://agent/{agent_space}/memories/tools", is_leaf=True, abstract=abstract_override or candidate.abstract, context_type=ContextType.MEMORY.value, @@ -1211,11 +1162,7 @@ def _extract_tool_guidelines(self, content: str) -> str: return content.strip() async def _merge_skill_memory( - self, - skill_name: str, - candidate: CandidateMemory, - ctx: "RequestContext", - agent_temp_uri: Optional[str] = None, + self, skill_name: str, candidate: CandidateMemory, ctx: "RequestContext" ) -> Optional[Context]: """合并 Skill Memory,统计数据用 Python 累加""" if not skill_name or not skill_name.strip(): @@ -1223,11 +1170,7 @@ async def _merge_skill_memory( return None agent_space = ctx.user.agent_space_name() - # Use temp agent URI if provided (for COW pattern) - if agent_temp_uri: - uri = f"{agent_temp_uri}/memories/skills/{skill_name}.md" - else: - uri = f"viking://agent/{agent_space}/memories/skills/{skill_name}.md" + uri = f"viking://agent/{agent_space}/memories/skills/{skill_name}.md" viking_fs = get_viking_fs() if not viking_fs: @@ -1296,7 +1239,7 @@ async def _merge_skill_memory( skill_name, merged_stats, new_guidelines, fields=new_fields ) await viking_fs.write_file(uri=uri, content=merged_content, ctx=ctx) - return self._create_skill_context(uri, candidate, ctx, agent_temp_uri=agent_temp_uri) + return self._create_skill_context(uri, candidate, ctx) existing_stats = self._parse_skill_statistics(existing) merged_stats = self._merge_skill_statistics(existing_stats, new_stats) @@ -1358,9 +1301,7 @@ async def _merge_skill_memory( skill_name, merged_stats, merged_guidelines, 
fields=merged_fields ) await viking_fs.write_file(uri=uri, content=merged_content, ctx=ctx) - return self._create_skill_context( - uri, candidate, ctx, abstract_override=abstract_override, agent_temp_uri=agent_temp_uri - ) + return self._create_skill_context(uri, candidate, ctx, abstract_override=abstract_override) def _compute_skill_statistics_derived(self, stats: dict) -> dict: """计算 Skill 派生统计数据(成功率)""" @@ -1512,18 +1453,12 @@ def _create_skill_context( candidate: CandidateMemory, ctx: "RequestContext", abstract_override: Optional[str] = None, - agent_temp_uri: Optional[str] = None, ) -> Context: """创建 Skill Memory 的 Context 对象""" agent_space = ctx.user.agent_space_name() - # Use temp agent URI if provided (for COW pattern) - if agent_temp_uri: - parent_uri = f"{agent_temp_uri}/memories/skills" - else: - parent_uri = f"viking://agent/{agent_space}/memories/skills" return Context( uri=uri, - parent_uri=parent_uri, + parent_uri=f"viking://agent/{agent_space}/memories/skills", is_leaf=True, abstract=abstract_override or candidate.abstract, context_type=ContextType.MEMORY.value, diff --git a/openviking/session/session.py b/openviking/session/session.py index 330a943b..84b8a76e 100644 --- a/openviking/session/session.py +++ b/openviking/session/session.py @@ -7,10 +7,9 @@ import json import re -import time from dataclasses import dataclass, field from datetime import datetime -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple +from typing import TYPE_CHECKING, Any, Dict, List, Optional from uuid import uuid4 from openviking.message import Message, Part @@ -92,13 +91,6 @@ def __init__( self._stats: SessionStats = SessionStats() self._loaded = False - # Temp URI management for COW pattern - self._temp_base_uri: Optional[str] = None - self._session_temp_uri: Optional[str] = None - self._user_temp_uri: Optional[str] = None - self._agent_temp_uri: Optional[str] = None - self._temp_created_at: Optional[float] = None - logger.info(f"Session created: {self.session_id} for user {self.user}") async def load(self): @@ -302,210 +294,65 @@ def commit(self) -> Dict[str, Any]: logger.info(f"Session {self.session_id} committed") return result - def _create_temp_uris(self) -> Tuple[str, str, str, str]: - """Create temp URIs for session, user and agent directories. 
- - Temp URI structure matches target URI structure for Semantic DAG recursive processing: - - Session: viking://temp/session/{user_space}/{session_id}/commit_{uuid}/session/{user_space}/{session_id}/ - - User: viking://temp/session/{user_space}/{session_id}/commit_{uuid}/user/{user_space}/ - - Agent: viking://temp/session/{user_space}/{session_id}/commit_{uuid}/agent/{agent_space}/ - - Returns: - (temp_base_uri, session_temp_uri, user_temp_uri, agent_temp_uri) - """ - temp_base_uri = ( - f"viking://temp/session/" - f"{self.user.user_space_name()}/" - f"{self.session_id}/" - f"commit_{uuid4().hex[:8]}" - ) - - # Match target URI structure for Semantic DAG recursive processing - session_temp_uri = ( - f"{temp_base_uri}/session/{self.user.user_space_name()}/{self.session_id}" - ) - user_temp_uri = f"{temp_base_uri}/user/{self.user.user_space_name()}" - agent_temp_uri = f"{temp_base_uri}/agent/{self.user.agent_space_name()}" - - self._temp_base_uri = temp_base_uri - self._session_temp_uri = session_temp_uri - self._user_temp_uri = user_temp_uri - self._agent_temp_uri = agent_temp_uri - self._temp_created_at = time.time() - - return temp_base_uri, session_temp_uri, user_temp_uri, agent_temp_uri - - async def _cleanup_temp_uris(self) -> None: - """Clean up all temp directories after commit.""" - if self._temp_base_uri: - try: - await self._viking_fs.delete_temp(self._temp_base_uri, ctx=self.ctx) - logger.info(f"Cleaned up temp base: {self._temp_base_uri}") - except Exception as e: - logger.warning(f"Failed to cleanup temp {self._temp_base_uri}: {e}") - finally: - self._temp_base_uri = None - self._session_temp_uri = None - self._user_temp_uri = None - self._agent_temp_uri = None - self._temp_created_at = None - async def commit_async(self) -> Dict[str, Any]: - """Async commit session with Copy-on-Write pattern. - - Process: - 1. Copy: Copy existing session, user and agent directories to temp - 2. Write: Make all changes in temp - 3. Semantic: Trigger semantic processing - 4. 
Switch: Atomically switch from temp to target (handled by SemanticProcessor) - """ + """Async commit session: create archive, extract memories, persist.""" result = { "session_id": self.session_id, "status": "committed", "memories_extracted": 0, "active_count_updated": 0, "archived": False, - "temp_base_uri": None, - "session_temp_uri": None, - "user_temp_uri": None, - "agent_temp_uri": None, - "semantic_msg_id": None, "stats": None, } - if not self._messages: return result - # ========== Phase 1: Copy ========== - temp_base_uri, session_temp_uri, user_temp_uri, agent_temp_uri = self._create_temp_uris() - result["temp_base_uri"] = temp_base_uri - result["session_temp_uri"] = session_temp_uri - result["user_temp_uri"] = user_temp_uri - result["agent_temp_uri"] = agent_temp_uri - - try: - # 1.1 Copy existing session to temp - logger.info(f"Copying session {self.session_id} to temp: {session_temp_uri}") - try: - await self._viking_fs.copy_directory( - from_uri=self._session_uri, - to_uri=session_temp_uri, - ctx=self.ctx, - ) - logger.info(f"Session copied to temp: {session_temp_uri}") - except Exception as e: - if "not found" in str(e).lower(): - logger.info(f"Session {self.session_id} not found, creating new temp") - await self._viking_fs.mkdir(session_temp_uri, exist_ok=True, ctx=self.ctx) - else: - raise - - # 1.2 Copy existing user directory to temp - user_uri = f"viking://user/{self.user.user_space_name()}" - logger.info(f"Copying user directory to temp: {user_temp_uri}") - try: - await self._viking_fs.copy_directory( - from_uri=user_uri, - to_uri=user_temp_uri, - ctx=self.ctx, - ) - logger.info(f"User directory copied to temp: {user_temp_uri}") - except Exception as e: - if "not found" in str(e).lower(): - logger.info("User directory not found, creating new temp") - await self._viking_fs.mkdir(user_temp_uri, exist_ok=True, ctx=self.ctx) - else: - raise - - # 1.3 Copy existing agent directory to temp - agent_uri = f"viking://agent/{self.user.agent_space_name()}" - logger.info(f"Copying agent directory to temp: {agent_temp_uri}") - try: - await self._viking_fs.copy_directory( - from_uri=agent_uri, - to_uri=agent_temp_uri, - ctx=self.ctx, - ) - logger.info(f"Agent directory copied to temp: {agent_temp_uri}") - except Exception as e: - if "not found" in str(e).lower(): - logger.info("Agent directory not found, creating new temp") - await self._viking_fs.mkdir(agent_temp_uri, exist_ok=True, ctx=self.ctx) - else: - raise + # 1. 
Archive current messages
+        self._compression.compression_index += 1
+        messages_to_archive = self._messages.copy()
-        except Exception as e:
-            logger.error(f"Failed to copy directories to temp: {e}")
-            await self._cleanup_temp_uris()
-            raise
+        summary = await self._generate_archive_summary_async(messages_to_archive)
+        archive_abstract = self._extract_abstract_from_summary(summary)
+        archive_overview = summary
-        # ========== Phase 2: Write (all changes in temp) ==========
-        try:
-            # 2.1 Archive current messages to temp
-            self._compression.compression_index += 1
-            messages_to_archive = self._messages.copy()
+        await self._write_archive_async(
+            index=self._compression.compression_index,
+            messages=messages_to_archive,
+            abstract=archive_abstract,
+            overview=archive_overview,
+        )
-            await self._write_archive_to_temp(
-                temp_uri=session_temp_uri,
-                index=self._compression.compression_index,
-                messages=messages_to_archive,
-            )
+        self._compression.original_count += len(messages_to_archive)
+        result["archived"] = True
-            self._compression.original_count += len(messages_to_archive)
-            result["archived"] = True
+        self._messages.clear()
+        logger.info(
+            f"Archived: {len(messages_to_archive)} messages → history/archive_{self._compression.compression_index:03d}/"
+        )
-            self._messages.clear()
+        # 2. Extract long-term memories
+        if self._session_compressor:
             logger.info(
-                f"Archived: {len(messages_to_archive)} messages → "
-                f"{session_temp_uri}/history/archive_{self._compression.compression_index:03d}/"
+                f"Starting memory extraction from {len(messages_to_archive)} archived messages"
             )
-
-            # 2.2 Extract long-term memories (to temp user and agent directories)
-            if self._session_compressor:
-                logger.info(
-                    f"Starting memory extraction from {len(messages_to_archive)} archived messages"
-                )
-                memories = await self._session_compressor.extract_long_term_memories(
-                    messages=messages_to_archive,
-                    user=self.user,
-                    session_id=self.session_id,
-                    ctx=self.ctx,
-                    user_temp_uri=user_temp_uri,
-                    agent_temp_uri=agent_temp_uri,
-                )
-                logger.info(f"Extracted {len(memories)} memories to temp directories")
-                result["memories_extracted"] = len(memories)
-                self._stats.memories_extracted += len(memories)
-
-            # 2.3 Write current messages to temp
-            await self._write_messages_to_temp(session_temp_uri, self._messages)
-
-            logger.info(f"Session changes written to temp: {session_temp_uri}")
-            # 2.5 Update active_count
-            active_count_updated = await self._update_active_counts_async()
-            result["active_count_updated"] = active_count_updated
-        except Exception as e:
-            logger.error(f"Failed to write changes to temp: {e}")
-            await self._cleanup_temp_uris()
-            raise
-
-        # ========== Phase 3: Semantic = Switch ===========
-        try:
-            semantic_msg_ids = await self._enqueue_to_semantic_queue(
-                session_temp_uri=session_temp_uri,
-                user_temp_uri=user_temp_uri,
-                agent_temp_uri=agent_temp_uri,
+            memories = await self._session_compressor.extract_long_term_memories(
+                messages=messages_to_archive,
+                user=self.user,
+                session_id=self.session_id,
+                ctx=self.ctx,
             )
+            logger.info(f"Extracted {len(memories)} memories")
+            result["memories_extracted"] = len(memories)
+            self._stats.memories_extracted += len(memories)
+
+        # 3. Write current messages to AGFS
+        await self._write_to_agfs_async(self._messages)
+
+        # 4. Create relations
+        await self._write_relations_async()
-            logger.info(f"Session, user, agent enqueued to SemanticQueue: {semantic_msg_ids}")
-            result["semantic_msg_ids"] = semantic_msg_ids
-        except Exception as e:
-            logger.error(f"Failed to enqueue to SemanticQueue: {e}")
-            await self._cleanup_temp_uris()
-            raise
-
-        # ========== Update statistics ==========
+        # 5. Update active_count
+        active_count_updated = await self._update_active_counts_async()
+        result["active_count_updated"] = active_count_updated
+
+        # 6. Update statistics
         self._stats.compression_count = self._compression.compression_index
         result["stats"] = {
             "total_turns": self._stats.total_turns,
@@ -515,7 +362,7 @@ async def commit_async(self) -> Dict[str, Any]:
         }
         self._stats.total_tokens = 0

-        logger.info(f"Session {self.session_id} committed (async with COW pattern)")
+        logger.info(f"Session {self.session_id} committed (async)")
         return result

     def _update_active_counts(self) -> int:
@@ -699,105 +546,6 @@ def _write_archive(
         logger.debug(f"Written archive: {archive_uri}")

-    async def _write_archive_to_temp(
-        self,
-        temp_uri: str,
-        index: int,
-        messages: List[Message],
-    ) -> None:
-        """Write archive to temp directory.
-
-        Note: .abstract.md and .overview.md will be generated by Semantic DAG.
-        """
-        archive_uri = f"{temp_uri}/history/archive_{index:03d}"
-
-        lines = [m.to_jsonl() for m in messages]
-        await self._viking_fs.write_file(
-            uri=f"{archive_uri}/messages.jsonl",
-            content="\n".join(lines) + "\n",
-            ctx=self.ctx,
-        )
-
-        # Note: .abstract.md and .overview.md will be generated by Semantic DAG
-        # No need to manually create them here
-
-        logger.debug(f"Written archive to temp: {archive_uri}")
-
-    async def _write_messages_to_temp(self, temp_uri: str, messages: List[Message]) -> None:
-        """Write current messages to temp directory."""
-        lines = [m.to_jsonl() for m in messages]
-        content = "\n".join(lines) + "\n" if lines else ""
-
-        await self._viking_fs.write_file(
-            uri=f"{temp_uri}/messages.jsonl",
-            content=content,
-            ctx=self.ctx,
-        )
-
-    async def _enqueue_to_semantic_queue(
-        self,
-        session_temp_uri: str,
-        user_temp_uri: str,
-        agent_temp_uri: str,
-    ) -> List[str]:
-        """Enqueue session, user, and agent to SemanticQueue for L0/L1 generation.
-
-        The SemanticProcessor will handle:
-        1. Generate L0/L1 for session, user and agent directories
-        2. Atomically switch temp URIs to target URIs
-        3. Create usage relations
-        4. 
Clean up temp URIs - - Returns: - List of message IDs [session_msg_id, user_msg_id, agent_msg_id] - """ - from openviking.storage.queuefs import SemanticMsg, get_queue_manager - - queue_manager = get_queue_manager() - semantic_queue = queue_manager.get_queue(queue_manager.SEMANTIC, allow_create=True) - - user_target_uri = f"viking://user/{self.user.user_space_name()}" - agent_target_uri = f"viking://agent/{self.user.agent_space_name()}" - - session_msg = SemanticMsg( - uri=session_temp_uri, - context_type="memory", - target_uri=self._session_uri, - account_id=self.ctx.account_id, - user_id=self.ctx.user.user_id, - agent_id=self.ctx.user.agent_id, - role=self.ctx.role.value, - recursive=True, - ) - - user_msg = SemanticMsg( - uri=user_temp_uri, - context_type="memory", - target_uri=user_target_uri, - account_id=self.ctx.account_id, - user_id=self.ctx.user.user_id, - agent_id=self.ctx.user.agent_id, - role=self.ctx.role.value, - recursive=True, - ) - - agent_msg = SemanticMsg( - uri=agent_temp_uri, - context_type="memory", - target_uri=agent_target_uri, - account_id=self.ctx.account_id, - user_id=self.ctx.user.user_id, - agent_id=self.ctx.user.agent_id, - role=self.ctx.role.value, - recursive=True, - ) - - await semantic_queue.enqueue(session_msg) - await semantic_queue.enqueue(user_msg) - await semantic_queue.enqueue(agent_msg) - - return [session_msg.id, user_msg.id, agent_msg.id] - async def _write_archive_async( self, index: int, diff --git a/openviking/storage/collection_schemas.py b/openviking/storage/collection_schemas.py index 817eb02e..90c61d07 100644 --- a/openviking/storage/collection_schemas.py +++ b/openviking/storage/collection_schemas.py @@ -263,12 +263,3 @@ async def on_dequeue(self, data: Optional[Dict[str, Any]]) -> Optional[Dict[str, traceback.print_exc() self.report_error(str(e), data) return None - finally: - if embedding_msg and embedding_msg.semantic_msg_id: - from openviking.storage.queuefs.embedding_tracker import EmbeddingTaskTracker - - tracker = EmbeddingTaskTracker.get_instance() - try: - await tracker.decrement(embedding_msg.semantic_msg_id) - except Exception as tracker_err: - logger.warning(f"Failed to decrement embedding tracker: {tracker_err}") diff --git a/openviking/storage/queuefs/__init__.py b/openviking/storage/queuefs/__init__.py index 87514a83..b73a01d7 100644 --- a/openviking/storage/queuefs/__init__.py +++ b/openviking/storage/queuefs/__init__.py @@ -2,7 +2,6 @@ # SPDX-License-Identifier: Apache-2.0 from .embedding_msg import EmbeddingMsg from .embedding_queue import EmbeddingQueue -from .embedding_tracker import EmbeddingTaskTracker from .named_queue import NamedQueue, QueueError, QueueStatus from .queue_manager import QueueManager, get_queue_manager, init_queue_manager from .semantic_dag import SemanticDagExecutor @@ -19,7 +18,6 @@ "QueueError", "EmbeddingQueue", "EmbeddingMsg", - "EmbeddingTaskTracker", "SemanticQueue", "SemanticDagExecutor", "SemanticMsg", diff --git a/openviking/storage/queuefs/embedding_msg.py b/openviking/storage/queuefs/embedding_msg.py index 94e93a2c..19b8381e 100644 --- a/openviking/storage/queuefs/embedding_msg.py +++ b/openviking/storage/queuefs/embedding_msg.py @@ -2,7 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 import json from dataclasses import asdict, dataclass -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, List, Union from uuid import uuid4 @@ -10,18 +10,11 @@ class EmbeddingMsg: message: Union[str, List[Dict[str, Any]]] context_data: Dict[str, Any] - semantic_msg_id: 
Optional[str] = None - def __init__( - self, - message: Union[str, List[Dict[str, Any]]], - context_data: Dict[str, Any], - semantic_msg_id: Optional[str] = None, - ): + def __init__(self, message: Union[str, List[Dict[str, Any]]], context_data: Dict[str, Any]): self.id = str(uuid4()) self.message = message self.context_data = context_data - self.semantic_msg_id = semantic_msg_id def to_dict(self) -> Dict[str, Any]: """Convert embedding message to dictionary format.""" @@ -37,7 +30,6 @@ def from_dict(cls, data: Dict[str, Any]) -> "EmbeddingMsg": obj = EmbeddingMsg( message=data["message"], context_data=data["context_data"], - semantic_msg_id=data.get("semantic_msg_id"), ) obj.id = data.get("id", obj.id) return obj diff --git a/openviking/storage/queuefs/embedding_tracker.py b/openviking/storage/queuefs/embedding_tracker.py deleted file mode 100644 index 4149d8d1..00000000 --- a/openviking/storage/queuefs/embedding_tracker.py +++ /dev/null @@ -1,193 +0,0 @@ -# Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd. -# SPDX-License-Identifier: Apache-2.0 -"""Embedding Task Tracker for tracking embedding task completion status.""" - -import asyncio -from dataclasses import dataclass, field -from typing import Any, Callable, Dict, Optional - -from openviking_cli.utils.logger import get_logger - -logger = get_logger(__name__) - - -@dataclass -class EmbeddingTaskTracker: - """Track embedding task completion status for each SemanticMsg. - - This tracker maintains a global registry of embedding tasks associated - with each SemanticMsg. When all embedding tasks for a SemanticMsg are - completed, it triggers the registered callback and removes the entry. - """ - - _instance: Optional["EmbeddingTaskTracker"] = None - _lock: asyncio.Lock = field(default_factory=asyncio.Lock) - _tasks: Dict[str, Dict[str, Any]] = field(default_factory=dict) - - @classmethod - def get_instance(cls) -> "EmbeddingTaskTracker": - """Get the singleton instance of EmbeddingTaskTracker.""" - if cls._instance is None: - cls._instance = cls() - return cls._instance - - async def register( - self, - semantic_msg_id: str, - total_count: int, - on_complete: Optional[Callable[[], Any]] = None, - metadata: Optional[Dict[str, Any]] = None, - ) -> None: - """Register a SemanticMsg with its total embedding task count. - - Args: - semantic_msg_id: The ID of the SemanticMsg - total_count: Total number of embedding tasks for this SemanticMsg - on_complete: Optional callback when all tasks complete - metadata: Optional metadata to store with the task - """ - async with self._lock: - self._tasks[semantic_msg_id] = { - "remaining": total_count, - "total": total_count, - "on_complete": on_complete, - "metadata": metadata or {}, - } - logger.info( - f"Registered embedding tracker for SemanticMsg {semantic_msg_id}: " - f"{total_count} tasks" - ) - - if total_count <= 0 and on_complete: - del self._tasks[semantic_msg_id] - logger.info( - f"No embedding tasks for SemanticMsg {semantic_msg_id}, " - f"triggering on_complete immediately" - ) - - if total_count <= 0 and on_complete: - try: - result = on_complete() - if asyncio.iscoroutine(result): - await result - except Exception as e: - logger.error( - f"Error in completion callback for {semantic_msg_id}: {e}", - exc_info=True, - ) - - async def increment(self, semantic_msg_id: str) -> Optional[int]: - """Increment the remaining task count for a SemanticMsg. - - This method should be called when a new embedding task is added - for an already registered SemanticMsg. 
- - Args: - semantic_msg_id: The ID of the SemanticMsg - - Returns: - The remaining count after increment, or None if not found - """ - async with self._lock: - if semantic_msg_id not in self._tasks: - return None - - task_info = self._tasks[semantic_msg_id] - task_info["remaining"] += 1 - task_info["total"] += 1 - remaining = task_info["remaining"] - - return remaining - - async def decrement(self, semantic_msg_id: str) -> Optional[int]: - """Decrement the remaining task count for a SemanticMsg. - - This method should be called when an embedding task is completed. - When the count reaches zero, the registered callback is executed - and the entry is removed from the tracker. - - Args: - semantic_msg_id: The ID of the SemanticMsg - - Returns: - The remaining count after decrement, or None if not found - """ - on_complete = None - - async with self._lock: - if semantic_msg_id not in self._tasks: - return None - - task_info = self._tasks[semantic_msg_id] - task_info["remaining"] -= 1 - remaining = task_info["remaining"] - - if remaining <= 0: - on_complete = task_info.get("on_complete") - - del self._tasks[semantic_msg_id] - logger.info( - f"All embedding tasks({task_info['total']}) completed for SemanticMsg {semantic_msg_id}" - ) - - if on_complete: - try: - result = on_complete() - if asyncio.iscoroutine(result): - await result - except Exception as e: - logger.error( - f"Error in completion callback for {semantic_msg_id}: {e}", - exc_info=True, - ) - return remaining - - async def get_status(self, semantic_msg_id: str) -> Optional[Dict[str, Any]]: - """Get the current status of a SemanticMsg's embedding tasks. - - Args: - semantic_msg_id: The ID of the SemanticMsg - - Returns: - Dict with 'remaining', 'total', 'metadata' or None if not found - """ - async with self._lock: - if semantic_msg_id not in self._tasks: - return None - task_info = self._tasks[semantic_msg_id] - return { - "remaining": task_info["remaining"], - "total": task_info["total"], - "metadata": task_info.get("metadata", {}), - } - - async def remove(self, semantic_msg_id: str) -> bool: - """Remove a SemanticMsg from the tracker. - - Args: - semantic_msg_id: The ID of the SemanticMsg - - Returns: - True if removed, False if not found - """ - async with self._lock: - if semantic_msg_id in self._tasks: - del self._tasks[semantic_msg_id] - return True - return False - - async def get_all_tracked(self) -> Dict[str, Dict[str, Any]]: - """Get all currently tracked SemanticMsgs. 
- - Returns: - Dict of semantic_msg_id -> task info - """ - async with self._lock: - return { - msg_id: { - "remaining": info["remaining"], - "total": info["total"], - "metadata": info.get("metadata", {}), - } - for msg_id, info in self._tasks.items() - } diff --git a/openviking/storage/queuefs/semantic_dag.py b/openviking/storage/queuefs/semantic_dag.py index 397250c5..0307521f 100644 --- a/openviking/storage/queuefs/semantic_dag.py +++ b/openviking/storage/queuefs/semantic_dag.py @@ -48,19 +48,11 @@ def __init__( context_type: str, max_concurrent_llm: int, ctx: RequestContext, - incremental_update: bool = False, - target_uri: Optional[str] = None, - semantic_msg_id: Optional[str] = None, - recursive: bool = True, ): self._processor = processor self._context_type = context_type self._max_concurrent_llm = max_concurrent_llm self._ctx = ctx - self._incremental_update = incremental_update - self._target_uri = target_uri - self._semantic_msg_id = semantic_msg_id - self._recursive = recursive self._llm_sem = asyncio.Semaphore(max_concurrent_llm) self._viking_fs = get_viking_fs() self._nodes: Dict[str, DirNode] = {} @@ -87,10 +79,8 @@ async def _dispatch_dir(self, dir_uri: str, parent_uri: Optional[str]) -> None: children_dirs, file_paths = await self._list_dir(dir_uri) file_index = {path: idx for idx, path in enumerate(file_paths)} child_index = {path: idx for idx, path in enumerate(children_dirs)} - if self._recursive: - pending = len(children_dirs) + len(file_paths) - else: - pending = len(file_paths) + pending = len(children_dirs) + len(file_paths) + node = DirNode( uri=dir_uri, children_dirs=children_dirs, @@ -118,10 +108,8 @@ async def _dispatch_dir(self, dir_uri: str, parent_uri: Optional[str]) -> None: self._stats.in_progress_nodes += 1 asyncio.create_task(self._file_summary_task(dir_uri, file_path)) - if children_dirs: - if self._recursive: - for child_uri in children_dirs: - asyncio.create_task(self._dispatch_dir(child_uri, dir_uri)) + for child_uri in children_dirs: + asyncio.create_task(self._dispatch_dir(child_uri, dir_uri)) except Exception as e: logger.error(f"Failed to dispatch directory {dir_uri}: {e}", exc_info=True) if parent_uri: @@ -153,107 +141,13 @@ async def _list_dir(self, uri: str) -> tuple[list[str], list[str]]: return children_dirs, file_paths - def _get_target_file_path(self, current_uri: str) -> Optional[str]: - if not self._incremental_update or not self._target_uri or not self._root_uri: - logger.warning( - f"Invalid target_uri or root_uri for incremental update: target_uri={self._target_uri}, root_uri={self._root_uri}" - ) - return None - try: - relative_path = current_uri[len(self._root_uri) :] - if relative_path.startswith("/"): - relative_path = relative_path[1:] - return f"{self._target_uri}/{relative_path}" if relative_path else self._target_uri - except Exception: - return None - - async def _check_file_content_changed(self, file_path: str) -> bool: - target_path = self._get_target_file_path(file_path) - if not target_path: - return True - try: - current_content = await self._viking_fs.read_file(file_path, ctx=self._ctx) - target_content = await self._viking_fs.read_file(target_path, ctx=self._ctx) - return current_content != target_content - except Exception: - return True - - async def _read_existing_summary(self, file_path: str) -> Optional[Dict[str, str]]: - target_path = self._get_target_file_path(file_path) - if not target_path: - return None - try: - vector_store = self._viking_fs._get_vector_store() - if not vector_store: - return None - records = 
await vector_store.get_context_by_uri(
-                account_id=self._ctx.account_id,
-                uri=target_path,
-                limit=1,
-            )
-            if records and len(records) > 0:
-                record = records[0]
-                summary = record.get("abstract", "")
-                if summary:
-                    file_name = file_path.split("/")[-1]
-                    return {"name": file_name, "summary": summary}
-        except Exception:
-            pass
-        return None
-
-    async def _check_dir_children_changed(
-        self, dir_uri: str, current_files: List[str], current_dirs: List[str]
-    ) -> bool:
-        target_path = self._get_target_file_path(dir_uri)
-        if not target_path:
-            return True
-        try:
-            target_dirs, target_files = await self._list_dir(target_path)
-            current_file_names = {f.split("/")[-1] for f in current_files}
-            target_file_names = {f.split("/")[-1] for f in target_files}
-            if current_file_names != target_file_names:
-                return True
-            current_dir_names = {d.split("/")[-1] for d in current_dirs}
-            target_dir_names = {d.split("/")[-1] for d in target_dirs}
-            if current_dir_names != target_dir_names:
-                return True
-            for current_file in current_files:
-                if await self._check_file_content_changed(current_file):
-                    return True
-            return False
-        except Exception:
-            return True
-
-    async def _read_existing_overview_abstract(
-        self, dir_uri: str
-    ) -> tuple[Optional[str], Optional[str]]:
-        target_path = self._get_target_file_path(dir_uri)
-        if not target_path:
-            return None, None
-        try:
-            overview = await self._viking_fs.read_file(f"{target_path}/.overview.md", ctx=self._ctx)
-            abstract = await self._viking_fs.read_file(f"{target_path}/.abstract.md", ctx=self._ctx)
-            return overview, abstract
-        except Exception:
-            return None, None
-
     async def _file_summary_task(self, parent_uri: str, file_path: str) -> None:
         """Generate file summary and notify parent completion."""
         file_name = file_path.split("/")[-1]
-        need_vectorize = True
         try:
-            summary_dict = None
-            if self._incremental_update:
-                content_changed = await self._check_file_content_changed(file_path)
-
-                if not content_changed:
-                    summary_dict = await self._read_existing_summary(file_path)
-                    need_vectorize = False
-            if summary_dict is None:
-                summary_dict = await self._processor._generate_single_file_summary(
-                    file_path, llm_sem=self._llm_sem, ctx=self._ctx
-                )
+            summary_dict = await self._processor._generate_single_file_summary(
+                file_path, llm_sem=self._llm_sem, ctx=self._ctx
+            )
         except Exception as e:
             logger.warning(f"Failed to generate summary for {file_path}: {e}")
             summary_dict = {"name": file_name, "summary": ""}
@@ -261,21 +155,21 @@ async def _file_summary_task(self, parent_uri: str, file_path: str) -> None:
         self._stats.done_nodes += 1
         self._stats.in_progress_nodes = max(0, self._stats.in_progress_nodes - 1)
+        await self._on_file_done(parent_uri, file_path, summary_dict)
+
+        # Vectorize file as soon as summary is ready to avoid waiting for overview. 
try: - if need_vectorize: - asyncio.create_task( - self._processor._vectorize_single_file( - parent_uri=parent_uri, - context_type=self._context_type, - file_path=file_path, - summary_dict=summary_dict, - ctx=self._ctx, - semantic_msg_id=self._semantic_msg_id, - ) + asyncio.create_task( + self._processor._vectorize_single_file( + parent_uri=parent_uri, + context_type=self._context_type, + file_path=file_path, + summary_dict=summary_dict, + ctx=self._ctx, ) + ) except Exception as e: logger.error(f"Failed to schedule vectorization for {file_path}: {e}", exc_info=True) - await self._on_file_done(parent_uri, file_path, summary_dict) async def _on_file_done( self, parent_uri: str, file_path: str, summary_dict: Dict[str, str] @@ -347,27 +241,17 @@ async def _overview_task(self, dir_uri: str) -> None: node = self._nodes.get(dir_uri) if not node: return - need_vectorize = True + + async with node.lock: + file_summaries = self._finalize_file_summaries(node) + children_abstracts = self._finalize_children_abstracts(node) + try: - overview = None - abstract = None - if self._incremental_update: - children_changed = await self._check_dir_children_changed( - dir_uri, node.file_paths, node.children_dirs + async with self._llm_sem: + overview = await self._processor._generate_overview( + dir_uri, file_summaries, children_abstracts ) - - if not children_changed: - need_vectorize = False - overview, abstract = await self._read_existing_overview_abstract(dir_uri) - if overview is None or abstract is None: - async with node.lock: - file_summaries = self._finalize_file_summaries(node) - children_abstracts = self._finalize_children_abstracts(node) - async with self._llm_sem: - overview = await self._processor._generate_overview( - dir_uri, file_summaries, children_abstracts - ) - abstract = self._processor._extract_abstract_from_overview(overview) + abstract = self._processor._extract_abstract_from_overview(overview) try: await self._viking_fs.write_file(f"{dir_uri}/.overview.md", overview, ctx=self._ctx) @@ -376,19 +260,11 @@ async def _overview_task(self, dir_uri: str) -> None: logger.warning(f"Failed to write overview/abstract for {dir_uri}: {e}") try: - if need_vectorize: - asyncio.create_task( - self._processor._vectorize_directory( - dir_uri, - self._context_type, - abstract, - overview, - ctx=self._ctx, - semantic_msg_id=self._semantic_msg_id, - ) - ) + await self._processor._vectorize_directory_simple( + dir_uri, self._context_type, abstract, overview, ctx=self._ctx + ) except Exception as e: - logger.error(f"Failed to schedule vectorization for {dir_uri}: {e}", exc_info=True) + logger.error(f"Failed to vectorize directory {dir_uri}: {e}", exc_info=True) except Exception as e: logger.error(f"Failed to generate overview for {dir_uri}: {e}", exc_info=True) diff --git a/openviking/storage/queuefs/semantic_msg.py b/openviking/storage/queuefs/semantic_msg.py index 7517dd3a..5f7bd730 100644 --- a/openviking/storage/queuefs/semantic_msg.py +++ b/openviking/storage/queuefs/semantic_msg.py @@ -16,7 +16,7 @@ class SemanticMsg: Attributes: id: Unique identifier (UUID) uri: Directory URI to process - context_type: Type of context (resource, memory, skill, session) + context_type: Type of context (resource, memory, skill) status: Processing status (pending/processing/completed) timestamp: Creation timestamp recursive: Whether to recursively process subdirectories. 
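The executor's bottom-up ordering rests on the `pending` counter restored in `_dispatch_dir` (`pending = len(children_dirs) + len(file_paths)`): every finished file summary or child overview decrements it, and a directory's own overview runs only when it reaches zero. A reduced sketch of that pattern, with `DirNode` and the dispatch machinery stripped down to illustrative names:

```python
import asyncio
from dataclasses import dataclass, field
from typing import Dict, List

@dataclass
class Node:
    pending: int
    summaries: List[str] = field(default_factory=list)
    lock: asyncio.Lock = field(default_factory=asyncio.Lock)

async def child_done(nodes: Dict[str, Node], parent: str, summary: str) -> None:
    node = nodes[parent]
    async with node.lock:
        node.summaries.append(summary)
        node.pending -= 1
        ready = node.pending == 0
    if ready:  # last file/subdirectory finished: the parent's overview may run
        print(f"overview({parent}) from {node.summaries}")

async def demo() -> None:
    nodes = {"docs": Node(pending=3)}  # two files plus one subdirectory
    await asyncio.gather(
        child_done(nodes, "docs", "a.md"),
        child_done(nodes, "docs", "b.md"),
        child_done(nodes, "docs", "sub/"),
    )

asyncio.run(demo())
```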
@@ -27,7 +27,7 @@ class SemanticMsg: id: str # UUID uri: str # Directory URI - context_type: str # resource, memory, skill, session + context_type: str # resource, memory, skill status: str = "pending" # pending/processing/completed timestamp: int = int(datetime.now().timestamp()) recursive: bool = True # Whether to recursively process subdirectories @@ -37,7 +37,6 @@ class SemanticMsg: role: str = "root" # Additional flags skip_vectorization: bool = False - target_uri: str = "" def __init__( self, @@ -49,7 +48,6 @@ def __init__( agent_id: str = "default", role: str = "root", skip_vectorization: bool = False, - target_uri: str = "", ): self.id = str(uuid4()) self.uri = uri @@ -60,7 +58,6 @@ def __init__( self.agent_id = agent_id self.role = role self.skip_vectorization = skip_vectorization - self.target_uri = target_uri def to_dict(self) -> Dict[str, Any]: """Convert object to dictionary.""" @@ -96,7 +93,6 @@ def from_dict(cls, data: Dict[str, Any]) -> "SemanticMsg": agent_id=data.get("agent_id", "default"), role=data.get("role", "root"), skip_vectorization=data.get("skip_vectorization", False), - target_uri=data.get("target_uri", ""), ) if "id" in data and data["id"]: obj.id = data["id"] diff --git a/openviking/storage/queuefs/semantic_processor.py b/openviking/storage/queuefs/semantic_processor.py index 56a05b14..59700783 100644 --- a/openviking/storage/queuefs/semantic_processor.py +++ b/openviking/storage/queuefs/semantic_processor.py @@ -3,8 +3,7 @@ """SemanticProcessor: Processes messages from SemanticQueue, generates .abstract.md and .overview.md.""" import asyncio -from dataclasses import dataclass, field -from typing import Any, Awaitable, Callable, Dict, List, Optional, Set, Tuple +from typing import Any, Dict, List, Optional, Tuple from openviking.parse.parsers.constants import ( CODE_EXTENSIONS, @@ -30,22 +29,9 @@ from openviking_cli.utils.config import get_openviking_config from openviking_cli.utils.logger import get_logger -from .embedding_tracker import EmbeddingTaskTracker - logger = get_logger(__name__) -@dataclass -class DiffResult: - """Directory diff result for sync operations.""" - - added_files: List[str] = field(default_factory=list) - deleted_files: List[str] = field(default_factory=list) - updated_files: List[str] = field(default_factory=list) - added_dirs: List[str] = field(default_factory=list) - deleted_dirs: List[str] = field(default_factory=list) - - class SemanticProcessor(DequeueHandlerBase): """ Semantic processor, generates .abstract.md and .overview.md bottom-up. 
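The hunks below replace the diff-and-sync machinery with a single post-order walk: _collect_directory_info appends a directory only after recursing into its children, so each parent is processed after the subdirectories whose .abstract.md it aggregates. A stand-alone toy of that ordering guarantee, using plain dicts in place of viking_fs (illustrative only):

    def collect(tree: dict, uri: str, out: list) -> None:
        # tree maps a directory URI to (child_dir_uris, file_uris)
        children, files = tree[uri]
        for child in children:
            collect(tree, child, out)  # descend first ...
        out.append((uri, children, files))  # ... emit the directory itself last

    tree = {
        "root": (["root/a"], ["root/readme.md"]),
        "root/a": ([], ["root/a/notes.md"]),
    }
    order: list = []
    collect(tree, "root", order)
    assert [u for u, _, _ in order] == ["root/a", "root"]  # bottom-up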
@@ -117,21 +103,55 @@ def _detect_file_type(self, file_name: str) -> str: # Default to other return FILE_TYPE_OTHER - async def _check_file_content_changed( - self, file_path: str, target_file: str, ctx: Optional[RequestContext] = None - ) -> bool: - """Check if file content has changed compared to target file.""" + async def _enqueue_semantic_msg(self, msg: SemanticMsg) -> None: + """Enqueue a SemanticMsg to the semantic queue for processing.""" + from openviking.storage.queuefs import get_queue_manager + + queue_manager = get_queue_manager() + semantic_queue = queue_manager.get_queue(queue_manager.SEMANTIC) + # The queue manager returns SemanticQueue but method signature says NamedQueue + # We need to ignore the type error for the enqueue call + await semantic_queue.enqueue(msg) # type: ignore + logger.debug(f"Enqueued semantic message for processing: {msg.uri}") + + async def _collect_directory_info( + self, + uri: str, + result: List[Tuple[str, List[str], List[str]]], + ) -> None: + """Recursively collect directory info, post-order traversal ensures bottom-up order.""" viking_fs = get_viking_fs() + try: - current_content = await viking_fs.read_file(file_path, ctx=ctx) - target_content = await viking_fs.read_file(target_file, ctx=ctx) - return current_content != target_content - except Exception: - return True + entries = await viking_fs.ls(uri, ctx=self._current_ctx) + except Exception as e: + logger.warning(f"Failed to list directory {uri}: {e}") + return + + children_uris = [] + file_paths = [] + + for entry in entries: + name = entry.get("name", "") + if not name or name.startswith(".") or name in [".", ".."]: + continue + + item_uri = VikingURI(uri).join(name).uri + + if entry.get("isDir", False): + # Child directory + children_uris.append(item_uri) + # Recursively collect children + await self._collect_directory_info(item_uri, result) + else: + # File (not starting with .) 
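+                # (dotfiles such as .abstract.md / .overview.md were already
+                # skipped by the name.startswith(".") guard above)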
+ file_paths.append(item_uri) + + # Add current directory info + result.append((uri, children_uris, file_paths)) async def on_dequeue(self, data: Optional[Dict[str, Any]]) -> Optional[Dict[str, Any]]: """Process dequeued SemanticMsg, recursively process all subdirectories.""" - msg = None try: import json @@ -146,361 +166,147 @@ async def on_dequeue(self, data: Optional[Dict[str, Any]]) -> Optional[Dict[str, msg = SemanticMsg.from_dict(data) self._current_msg = msg self._current_ctx = self._ctx_from_semantic_msg(msg) - logger.info(f"Processing semantic generation for: {msg})") - - # Check if target_uri exists, auto-detect incremental update - is_incremental = False - viking_fs = get_viking_fs() - if msg.target_uri: - target_exists = await viking_fs.exists(msg.target_uri, ctx=self._current_ctx) - if target_exists: - is_incremental = True - logger.info(f"Target URI exists, using incremental update: {msg.target_uri}") - - tracker = EmbeddingTaskTracker.get_instance() - on_complete = self._create_sync_diff_callback( - root_uri=msg.uri, - target_uri=msg.target_uri, - ctx=self._current_ctx, - ) - # Register task with tracker, total_count=1 for root URI - await tracker.register( - semantic_msg_id=msg.id, - total_count=1, - on_complete=on_complete, - metadata={ - "uri": msg.uri, - }, - ) - executor = SemanticDagExecutor( - processor=self, - context_type=msg.context_type, - max_concurrent_llm=self.max_concurrent_llm, - ctx=self._current_ctx, - incremental_update=is_incremental, - target_uri=msg.target_uri, - semantic_msg_id=msg.id, - recursive=msg.recursive, + logger.info( + f"Processing semantic generation for: {msg.uri} (recursive={msg.recursive})" ) - self._dag_executor = executor - await executor.run(msg.uri) - logger.info(f"Completed semantic generation for: {msg.uri}") - self.report_success() - return None + + if msg.recursive: + executor = SemanticDagExecutor( + processor=self, + context_type=msg.context_type, + max_concurrent_llm=self.max_concurrent_llm, + ctx=self._current_ctx, + ) + self._dag_executor = executor + await executor.run(msg.uri) + logger.info(f"Completed semantic generation for: {msg.uri}") + self.report_success() + return None + else: + # Non-recursive processing: directly process this directory + children_uris = [] + file_paths = [] + + # Collect immediate children info only (no recursion) + viking_fs = get_viking_fs() + try: + entries = await viking_fs.ls(msg.uri, ctx=self._current_ctx) + for entry in entries: + name = entry.get("name", "") + if not name or name.startswith(".") or name in [".", ".."]: + continue + + item_uri = VikingURI(msg.uri).join(name).uri + + if entry.get("isDir", False): + children_uris.append(item_uri) + else: + file_paths.append(item_uri) + except Exception as e: + logger.warning(f"Failed to list directory {msg.uri}: {e}") + + # Process this directory + await self._process_single_directory( + uri=msg.uri, + context_type=msg.context_type, + children_uris=children_uris, + file_paths=file_paths, + ) + + logger.info(f"Completed semantic generation for: {msg.uri}") + self.report_success() + return None except Exception as e: logger.error(f"Failed to process semantic message: {e}", exc_info=True) self.report_error(str(e), data) return None finally: - # Decrement task counter for root URI - if msg is not None: - tracker = EmbeddingTaskTracker.get_instance() - await tracker.decrement( - semantic_msg_id=msg.id, - ) self._current_msg = None - self._current_ctx = None def get_dag_stats(self) -> Optional["DagStats"]: if not self._dag_executor: return None 
return self._dag_executor.get_stats() - def _create_sync_diff_callback( - self, - root_uri: str, - target_uri: str, - ctx: RequestContext, - ) -> Callable[[], Awaitable[None]]: - """ - Create a callback function to sync directory differences. - - This callback compares root_uri (new content) with target_uri (old content), - handles added/updated/deleted files, then cleans up root_uri. - - Args: - root_uri: Source directory URI (new content) - target_uri: Target directory URI (old content) - ctx: Request context (captured at callback creation time) - - Returns: - Async callback function - """ - - async def sync_diff_callback() -> None: - - try: - viking_fs = get_viking_fs() - - root_tree = await self._collect_tree_info(root_uri, ctx=ctx) - - target_tree = await self._collect_tree_info(target_uri, ctx=ctx) - diff = await self._compute_diff( - root_tree, target_tree, root_uri, target_uri, ctx=ctx - ) - logger.info( - f"[SyncDiff] Diff computed: " - f"added_files={len(diff.added_files)}, " - f"deleted_files={len(diff.deleted_files)}, " - f"updated_files={len(diff.updated_files)}, " - f"added_dirs={len(diff.added_dirs)}, " - f"deleted_dirs={len(diff.deleted_dirs)}" - ) - await self._execute_sync_operations(diff, root_uri, target_uri, ctx=ctx) - try: - await viking_fs.rm(root_uri, recursive=True, ctx=ctx) - except Exception as e: - logger.warning(f"[SyncDiff] Failed to delete root directory {root_uri}: {e}") - - except Exception as e: - logger.error( - f"[SyncDiff] Error in sync_diff_callback: " - f"root_uri={root_uri}, target_uri={target_uri} " - f"error={e}", - exc_info=True, - ) - - return sync_diff_callback - - async def _collect_tree_info( + async def _process_single_directory( self, uri: str, - ctx: Optional[RequestContext] = None, - ) -> Dict[str, Tuple[List[str], List[str]]]: - """ - Recursively collect directory tree information. - - Args: - uri: Directory URI - ctx: Request context - - Returns: - Dictionary: {dir_uri: ([subdir_uris], [file_uris])} - """ + context_type: str, + children_uris: List[str], + file_paths: List[str], + ) -> None: + """Process single directory, generate .abstract.md and .overview.md.""" viking_fs = get_viking_fs() - result: Dict[str, Tuple[List[str], List[str]]] = {} - total_dirs = 0 - total_files = 0 - - async def collect_recursive(current_uri: str, depth: int = 0) -> None: - nonlocal total_dirs, total_files - indent = " " * depth - try: - entries = await viking_fs.ls(current_uri, show_all_hidden=True, ctx=ctx) - except Exception as e: - logger.warning(f"[SyncDiff]{indent} Failed to list {current_uri}: {e}") - return - - sub_dirs: List[str] = [] - files: List[str] = [] - - for entry in entries: - name = entry.get("name", "") - if not name or name in [".", ".."]: - continue - if name.startswith(".") and name not in [".abstract.md", ".overview.md"]: - continue - - item_uri = VikingURI(current_uri).join(name).uri - - if entry.get("isDir", False): - sub_dirs.append(item_uri) - total_dirs += 1 - await collect_recursive(item_uri, depth + 1) - else: - files.append(item_uri) - total_files += 1 - - result[current_uri] = (sub_dirs, files) - - await collect_recursive(uri) - return result - - async def _compute_diff( - self, - root_tree: Dict[str, Tuple[List[str], List[str]]], - target_tree: Dict[str, Tuple[List[str], List[str]]], - root_uri: str, - target_uri: str, - ctx: Optional[RequestContext] = None, - ) -> DiffResult: - """ - Compute differences between two directory trees. 
- - Args: - root_tree: Directory tree from root_uri - target_tree: Directory tree from target_uri - root_uri: Source directory URI - target_uri: Target directory URI - ctx: Request context - - Returns: - DiffResult with added/deleted/updated files and directories - """ - def get_relative_path(uri: str, base_uri: str) -> str: - if uri.startswith(base_uri): - rel = uri[len(base_uri) :] - return rel.lstrip("/") - return uri - - root_files: Set[str] = set() - root_dirs: Set[str] = set() - target_files: Set[str] = set() - target_dirs: Set[str] = set() - - for dir_uri, (sub_dirs, files) in root_tree.items(): - rel_dir = get_relative_path(dir_uri, root_uri) - if rel_dir: - root_dirs.add(rel_dir) - for f in files: - root_files.add(get_relative_path(f, root_uri)) - for d in sub_dirs: - root_dirs.add(get_relative_path(d, root_uri)) - - for dir_uri, (sub_dirs, files) in target_tree.items(): - rel_dir = get_relative_path(dir_uri, target_uri) - if rel_dir: - target_dirs.add(rel_dir) - for f in files: - target_files.add(get_relative_path(f, target_uri)) - for d in sub_dirs: - target_dirs.add(get_relative_path(d, target_uri)) - - added_files_rel = root_files - target_files - deleted_files_rel = target_files - root_files - common_files = root_files & target_files - - added_dirs_rel = root_dirs - target_dirs - deleted_dirs_rel = target_dirs - root_dirs - - updated_files: List[str] = [] - for rel_file in common_files: - root_file = f"{root_uri}/{rel_file}" - target_file = f"{target_uri}/{rel_file}" - try: - if await self._check_file_content_changed(root_file, target_file, ctx=ctx): - updated_files.append(root_file) - except Exception as e: - logger.warning( - f"[SyncDiff] Failed to compare file content for {rel_file}: {e}, " - f"treating as unchanged" - ) + # 1. Collect .abstract.md from subdirectories (already processed earlier) + children_abstracts = await self._collect_children_abstracts(children_uris) - added_files = [f"{root_uri}/{f}" for f in added_files_rel] - deleted_files = [f"{target_uri}/{f}" for f in deleted_files_rel] - added_dirs = [f"{root_uri}/{d}" for d in added_dirs_rel] - deleted_dirs = [f"{target_uri}/{d}" for d in deleted_dirs_rel] - - result = DiffResult( - added_files=added_files, - deleted_files=deleted_files, - updated_files=updated_files, - added_dirs=added_dirs, - deleted_dirs=deleted_dirs, + # 2. Concurrently generate summaries for files in directory + file_summaries = await self._generate_file_summaries( + file_paths, context_type=context_type, parent_uri=uri, enqueue_files=True ) - return result - - async def _execute_sync_operations( - self, - diff: DiffResult, - root_uri: str, - target_uri: str, - ctx: Optional[RequestContext] = None, - ) -> None: - """ - Execute sync operations based on diff result. - - Processing order: - 1. Delete files in target that don't exist in root - 2. Move added/updated files from root to target - 3. Delete directories in target that don't exist in root - - Args: - diff: DiffResult containing operations to perform - root_uri: Source directory URI - target_uri: Target directory URI - ctx: Request context - """ - viking_fs = get_viking_fs() - - def map_to_target(root_item_uri: str) -> str: - if root_item_uri.startswith(root_uri): - rel = root_item_uri[len(root_uri) :] - return f"{target_uri}{rel}" if rel else target_uri - return root_item_uri + # 3. Generate .overview.md (contains brief description) + overview = await self._generate_overview(uri, file_summaries, children_abstracts) - total_deleted = 0 - total_moved = 0 - total_failed = 0 + # 4. 
Extract abstract from overview + abstract = self._extract_abstract_from_overview(overview) - for i, deleted_file in enumerate(diff.deleted_files, 1): - try: - await viking_fs.rm(deleted_file, ctx=ctx) - total_deleted += 1 - except Exception as e: - total_failed += 1 - logger.warning( - f"[SyncDiff] Failed to delete file [{i}/{len(diff.deleted_files)}]: {deleted_file}, error={e}" - ) - - for i, updated_file in enumerate(diff.updated_files, 1): - target_file = map_to_target(updated_file) - try: - await viking_fs.rm(target_file, ctx=ctx) - except Exception as e: - logger.warning( - f"[SyncDiff] Failed to remove old file [{i}/{len(diff.updated_files)}]: {target_file}, error={e}" - ) + # 5. Write files + await viking_fs.write_file(f"{uri}/.overview.md", overview, ctx=self._current_ctx) + await viking_fs.write_file(f"{uri}/.abstract.md", abstract, ctx=self._current_ctx) - files_to_move = diff.added_files + diff.updated_files - for i, root_file in enumerate(files_to_move, 1): - target_file = map_to_target(root_file) - try: - target_parent = VikingURI(target_file).parent - if target_parent: - try: - await viking_fs.mkdir(target_parent.uri, exist_ok=True, ctx=ctx) - except Exception as mkdir_error: - logger.debug( - f"[SyncDiff] Parent dir creation skipped (may already exist): {mkdir_error}" - ) - await viking_fs.mv(root_file, target_file, ctx=ctx) - total_moved += 1 - except Exception as e: - total_failed += 1 - logger.warning( - f"[SyncDiff] Failed to move file [{i}/{len(files_to_move)}]: " - f"{root_file} -> {target_file}, error={e}" - ) + logger.debug(f"Generated overview and abstract for {uri}") - for i, deleted_dir in enumerate( - sorted(diff.deleted_dirs, key=lambda x: x.count("/"), reverse=True), 1 - ): - try: - await viking_fs.rm(deleted_dir, recursive=True, ctx=ctx) - except Exception as e: - total_failed += 1 - logger.warning( - f"[SyncDiff] Failed to delete directory [{i}/{len(diff.deleted_dirs)}]: " - f"{deleted_dir}, error={e}" - ) + # 6. 
Vectorize directory + try: + await self._vectorize_directory_simple(uri, context_type, abstract, overview) + except Exception as e: + logger.error(f"Failed to vectorize directory {uri}: {e}", exc_info=True) - async def _collect_children_abstracts( - self, children_uris: List[str], ctx: Optional[RequestContext] = None - ) -> List[Dict[str, str]]: + async def _collect_children_abstracts(self, children_uris: List[str]) -> List[Dict[str, str]]: """Collect .abstract.md from subdirectories.""" viking_fs = get_viking_fs() results = [] for child_uri in children_uris: - abstract = await viking_fs.abstract(child_uri, ctx=ctx) + abstract = await viking_fs.abstract(child_uri, ctx=self._current_ctx) dir_name = child_uri.split("/")[-1] results.append({"name": dir_name, "abstract": abstract}) return results + async def _generate_file_summaries( + self, + file_paths: List[str], + context_type: Optional[str] = None, + parent_uri: Optional[str] = None, + enqueue_files: bool = False, + ) -> List[Dict[str, str]]: + """Concurrently generate file summaries.""" + if not file_paths: + return [] + + async def generate_one_summary(file_path: str) -> Dict[str, str]: + summary = await self._generate_single_file_summary(file_path, ctx=self._current_ctx) + if enqueue_files and context_type and parent_uri: + try: + await self._vectorize_single_file( + parent_uri=parent_uri, + context_type=context_type, + file_path=file_path, + summary_dict=summary, + ) + except Exception as e: + logger.error( + f"Failed to vectorize file {file_path}: {e}", + exc_info=True, + ) + return summary + + tasks = [generate_one_summary(fp) for fp in file_paths] + return await asyncio.gather(*tasks) + async def _generate_text_summary( self, file_path: str, @@ -701,14 +507,13 @@ def replace_index(match): logger.error(f"Failed to generate overview for {dir_uri}: {e}", exc_info=True) return f"# {dir_uri.split('/')[-1]}\n\nDirectory overview" - async def _vectorize_directory( + async def _vectorize_directory_simple( self, uri: str, context_type: str, abstract: str, overview: str, ctx: Optional[RequestContext] = None, - semantic_msg_id: Optional[str] = None, ) -> None: """Create directory Context and enqueue to EmbeddingQueue.""" @@ -718,12 +523,6 @@ async def _vectorize_directory( from openviking.utils.embedding_utils import vectorize_directory_meta - tracker = EmbeddingTaskTracker.get_instance() - # Increment task for .abstract.md - await tracker.increment(semantic_msg_id=semantic_msg_id) - # Increment task for .overview.md - await tracker.increment(semantic_msg_id=semantic_msg_id) - active_ctx = ctx or self._current_ctx await vectorize_directory_meta( uri=uri, @@ -731,25 +530,44 @@ async def _vectorize_directory( overview=overview, context_type=context_type, ctx=active_ctx, - semantic_msg_id=semantic_msg_id, ) + async def _vectorize_files( + self, + uri: str, + context_type: str, + file_paths: List[str], + file_summaries: List[Dict[str, str]], + ctx: Optional[RequestContext] = None, + ) -> None: + """Vectorize files in directory.""" + from openviking.storage.queuefs import get_queue_manager + + queue_manager = get_queue_manager() + embedding_queue = queue_manager.get_queue(queue_manager.EMBEDDING) + + for file_path, file_summary_dict in zip(file_paths, file_summaries): + await self._vectorize_single_file( + parent_uri=uri, + context_type=context_type, + file_path=file_path, + summary_dict=file_summary_dict, + embedding_queue=embedding_queue, + ctx=ctx, + ) + async def _vectorize_single_file( self, parent_uri: str, context_type: str, file_path: str, 
summary_dict: Dict[str, str], + embedding_queue: Optional[Any] = None, ctx: Optional[RequestContext] = None, - semantic_msg_id: Optional[str] = None, ) -> None: """Vectorize a single file using its content or summary.""" from openviking.utils.embedding_utils import vectorize_file - tracker = EmbeddingTaskTracker.get_instance() - await tracker.increment( - semantic_msg_id=semantic_msg_id, - ) active_ctx = ctx or self._current_ctx await vectorize_file( file_path=file_path, @@ -757,5 +575,4 @@ async def _vectorize_single_file( parent_uri=parent_uri, context_type=context_type, ctx=active_ctx, - semantic_msg_id=semantic_msg_id, ) diff --git a/openviking/storage/viking_fs.py b/openviking/storage/viking_fs.py index 9490fedc..ed65d83d 100644 --- a/openviking/storage/viking_fs.py +++ b/openviking/storage/viking_fs.py @@ -23,7 +23,6 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union from openviking.pyagfs.exceptions import AGFSHTTPError -from openviking.pyagfs.helpers import cp as agfs_cp from openviking.server.identity import RequestContext, Role from openviking.utils.time_utils import format_simplified, get_current_timestamp, parse_iso_datetime from openviking_cli.exceptions import NotFoundError @@ -361,22 +360,6 @@ async def stat(self, uri: str, ctx: Optional[RequestContext] = None) -> Dict[str path = self._uri_to_path(uri, ctx=ctx) return self.agfs.stat(path) - async def exists(self, uri: str, ctx: Optional[RequestContext] = None) -> bool: - """Check if a URI exists. - - Args: - uri: Viking URI - ctx: Request context - - Returns: - bool: True if the URI exists, False otherwise - """ - try: - await self.stat(uri, ctx=ctx) - return True - except Exception: - return False - async def glob( self, pattern: str, @@ -1484,29 +1467,6 @@ async def move_file( self.agfs.write(to_path, content) self.agfs.rm(from_path) - async def copy_directory( - self, - from_uri: str, - to_uri: str, - ctx: Optional[RequestContext] = None, - ) -> None: - """Copy directory recursively. - - Args: - from_uri: Source directory URI - to_uri: Destination directory URI - ctx: Request context - """ - self._ensure_access(from_uri, ctx) - self._ensure_access(to_uri, ctx) - - from_path = self._uri_to_path(from_uri, ctx=ctx) - to_path = self._uri_to_path(to_uri, ctx=ctx) - - await self._ensure_parent_dirs(to_path) - - await asyncio.to_thread(agfs_cp, self.agfs, from_path, to_path, recursive=True) - # ========== Temp File Operations (backward compatible) ========== def create_temp_uri(self) -> str: diff --git a/openviking/utils/embedding_utils.py b/openviking/utils/embedding_utils.py index 907c5dcd..1dffc1c8 100644 --- a/openviking/utils/embedding_utils.py +++ b/openviking/utils/embedding_utils.py @@ -116,7 +116,6 @@ async def vectorize_directory_meta( overview: str, context_type: str = "resource", ctx: Optional[RequestContext] = None, - semantic_msg_id: Optional[str] = None, ) -> None: """ Vectorize directory metadata (.abstract.md and .overview.md). 
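With semantic_msg_id removed from vectorize_directory_meta and vectorize_file, callers now pass only the content and a request context instead of threading a tracker id through every layer. A minimal call-site sketch against the new signature (the uri/abstract/overview values are placeholders):

    await vectorize_directory_meta(
        uri="viking://resource/docs/guide",
        abstract="One-paragraph summary of the directory.",
        overview="# guide\n\nLonger overview of the directory contents.",
        context_type="resource",
        ctx=ctx,
    )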
@@ -148,7 +147,6 @@ async def vectorize_directory_meta( context_abstract.set_vectorize(Vectorize(text=abstract)) msg_abstract = EmbeddingMsgConverter.from_context(context_abstract) if msg_abstract: - msg_abstract.semantic_msg_id = semantic_msg_id await embedding_queue.enqueue(msg_abstract) logger.debug(f"Enqueued directory L0 (abstract) for vectorization: {uri}") @@ -167,7 +165,6 @@ async def vectorize_directory_meta( context_overview.set_vectorize(Vectorize(text=overview)) msg_overview = EmbeddingMsgConverter.from_context(context_overview) if msg_overview: - msg_overview.semantic_msg_id = semantic_msg_id await embedding_queue.enqueue(msg_overview) logger.debug(f"Enqueued directory L1 (overview) for vectorization: {uri}") @@ -178,7 +175,6 @@ async def vectorize_file( parent_uri: str, context_type: str = "resource", ctx: Optional[RequestContext] = None, - semantic_msg_id: Optional[str] = None, ) -> None: """ Vectorize a single file. @@ -250,7 +246,6 @@ async def vectorize_file( if not embedding_msg: return - embedding_msg.semantic_msg_id = semantic_msg_id await embedding_queue.enqueue(embedding_msg) logger.debug(f"Enqueued file for vectorization: {file_path}") diff --git a/openviking/utils/resource_processor.py b/openviking/utils/resource_processor.py index f70c4e54..c43ea541 100644 --- a/openviking/utils/resource_processor.py +++ b/openviking/utils/resource_processor.py @@ -120,7 +120,7 @@ async def process_resource( "source_path": None, } - # ============ Phase 1: Parse source and writes to temp viking fs ============ + # ============ Phase 1: Parse source (Parser generates L0/L1 and writes to temp) ============ try: media_processor = self._get_media_processor() viking_fs = get_viking_fs() @@ -178,7 +178,6 @@ async def process_resource( ) if context_tree and context_tree.root: result["root_uri"] = context_tree.root.uri - result["temp_uri"] = context_tree.root.temp_uri except Exception as e: result["status"] = "error" result["errors"].append(f"Finalize from temp error: {e}") @@ -194,7 +193,6 @@ async def process_resource( # ============ Phase 4: Optional Steps ============ build_index = kwargs.get("build_index", True) - temp_uri_for_summarize = result.get("temp_uri") or parse_result.temp_dir_path if summarize: # Explicit summarization request. # If build_index is ALSO True, we want vectorization. @@ -205,7 +203,6 @@ async def process_resource( resource_uris=[result["root_uri"]], ctx=ctx, skip_vectorization=skip_vec, - temp_uris=[temp_uri_for_summarize], **kwargs, ) except Exception as e: @@ -217,11 +214,7 @@ async def process_resource( # We assume this means "Ingest and Index", which requires summarization. try: await self._get_summarizer().summarize( - resource_uris=[result["root_uri"]], - ctx=ctx, - skip_vectorization=False, - temp_uris=[temp_uri_for_summarize], - **kwargs, + resource_uris=[result["root_uri"]], ctx=ctx, skip_vectorization=False, **kwargs ) except Exception as e: logger.error(f"Auto-index failed: {e}") diff --git a/openviking/utils/summarizer.py b/openviking/utils/summarizer.py index 3d059f6f..a7477ba3 100644 --- a/openviking/utils/summarizer.py +++ b/openviking/utils/summarizer.py @@ -6,14 +6,13 @@ Handles summarization and key information extraction. 
""" -from typing import TYPE_CHECKING, Any, Dict, List - -from openviking.storage.queuefs import SemanticMsg, get_queue_manager +from typing import TYPE_CHECKING, Any, Dict, List, Optional from openviking_cli.utils import get_logger +from openviking.storage.queuefs import SemanticMsg, get_queue_manager if TYPE_CHECKING: - from openviking.parse.vlm import VLMProcessor from openviking.server.identity import RequestContext + from openviking.parse.vlm import VLMProcessor logger = get_logger(__name__) @@ -40,19 +39,8 @@ async def summarize( queue_manager = get_queue_manager() semantic_queue = queue_manager.get_queue(queue_manager.SEMANTIC, allow_create=True) - temp_uris = kwargs.get("temp_uris", []) - if temp_uris == []: - temp_uris = resource_uris - if len(temp_uris) != len(resource_uris): - logger.error( - f"temp_uris length ({len(temp_uris)}) must match resource_uris length ({len(resource_uris)})" - ) - return { - "status": "error", - "message": "temp_uris length must match resource_uris length", - } enqueued_count = 0 - for uri, temp_uri in zip(resource_uris, temp_uris): + for uri in resource_uris: # Determine context_type based on URI context_type = "resource" if uri.startswith("viking://memory/"): @@ -61,14 +49,13 @@ async def summarize( context_type = "skill" msg = SemanticMsg( - uri=temp_uri, + uri=uri, context_type=context_type, account_id=ctx.account_id, user_id=ctx.user.user_id, agent_id=ctx.user.agent_id, role=ctx.role.value, skip_vectorization=skip_vectorization, - target_uri=uri, ) await semantic_queue.enqueue(msg) enqueued_count += 1 diff --git a/openviking_cli/exceptions.py b/openviking_cli/exceptions.py index cd432552..807d317e 100644 --- a/openviking_cli/exceptions.py +++ b/openviking_cli/exceptions.py @@ -70,14 +70,6 @@ def __init__(self, resource: str, resource_type: str = "resource"): ) -class ConflictError(OpenVikingError): - """Resource conflict (e.g., locked by another operation).""" - - def __init__(self, message: str, resource: Optional[str] = None): - details = {"resource": resource} if resource else {} - super().__init__(message, code="CONFLICT", details=details) - - # ============= Authentication Errors ============= diff --git a/tests/unit/session/test_compressor_cow.py b/tests/unit/session/test_compressor_cow.py deleted file mode 100644 index cf6efe12..00000000 --- a/tests/unit/session/test_compressor_cow.py +++ /dev/null @@ -1,458 +0,0 @@ -# Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd. 
-# SPDX-License-Identifier: Apache-2.0 - -from unittest.mock import AsyncMock, MagicMock, patch - -import pytest - -from openviking.core.context import Context -from openviking.session.compressor import SessionCompressor -from openviking.session.memory_extractor import CandidateMemory, MemoryCategory -from openviking_cli.session.user_id import UserIdentifier - - -def _make_user() -> UserIdentifier: - return UserIdentifier("acc1", "test_user", "test_agent") - - -def _make_candidate(category: MemoryCategory = MemoryCategory.PREFERENCES) -> CandidateMemory: - return CandidateMemory( - category=category, - abstract="User prefers concise summaries", - overview="User asks for concise answers frequently.", - content="The user prefers concise summaries over long explanations.", - source_session="session_test", - user=_make_user(), - language="en", - ) - - -def _make_context(uri: str, abstract: str = "Existing memory") -> Context: - return Context( - uri=uri, - context_type="memory", - abstract=abstract, - meta={"overview": "Existing overview"}, - ) - - -class TestConvertToTempUri: - def test_user_uri_converted_to_temp_uri(self): - vikingdb = MagicMock() - compressor = SessionCompressor(vikingdb=vikingdb) - - user_space = _make_user().user_space_name() - target_uri = f"viking://user/{user_space}/memories/preferences/pref1.md" - user_temp_uri = "viking://user/temp_user_123" - - result = compressor._convert_to_temp_uri(target_uri, user_temp_uri, None) - - expected = f"{user_temp_uri}/memories/preferences/pref1.md" - assert result == expected - - def test_agent_uri_converted_to_temp_uri(self): - vikingdb = MagicMock() - compressor = SessionCompressor(vikingdb=vikingdb) - - agent_space = _make_user().agent_space_name() - target_uri = f"viking://agent/{agent_space}/memories/cases/case1.md" - agent_temp_uri = "viking://agent/temp_agent_456" - - result = compressor._convert_to_temp_uri(target_uri, None, agent_temp_uri) - - expected = f"{agent_temp_uri}/memories/cases/case1.md" - assert result == expected - - def test_no_temp_uri_returns_original_uri(self): - vikingdb = MagicMock() - compressor = SessionCompressor(vikingdb=vikingdb) - - user_space = _make_user().user_space_name() - target_uri = f"viking://user/{user_space}/memories/preferences/pref1.md" - - result = compressor._convert_to_temp_uri(target_uri, None, None) - - assert result == target_uri - - def test_mixed_uris_only_convert_matching_type(self): - vikingdb = MagicMock() - compressor = SessionCompressor(vikingdb=vikingdb) - - user_space = _make_user().user_space_name() - agent_space = _make_user().agent_space_name() - - user_uri = f"viking://user/{user_space}/memories/preferences/pref1.md" - agent_uri = f"viking://agent/{agent_space}/memories/cases/case1.md" - - user_temp_uri = "viking://user/temp_user_123" - - result_user = compressor._convert_to_temp_uri(user_uri, user_temp_uri, None) - result_agent = compressor._convert_to_temp_uri(agent_uri, user_temp_uri, None) - - assert result_user == f"{user_temp_uri}/memories/preferences/pref1.md" - assert result_agent == agent_uri - - def test_agent_uri_with_both_temp_uris(self): - vikingdb = MagicMock() - compressor = SessionCompressor(vikingdb=vikingdb) - - agent_space = _make_user().agent_space_name() - target_uri = f"viking://agent/{agent_space}/memories/patterns/pattern1.md" - user_temp_uri = "viking://user/temp_user_123" - agent_temp_uri = "viking://agent/temp_agent_456" - - result = compressor._convert_to_temp_uri(target_uri, user_temp_uri, agent_temp_uri) - - expected = 
f"{agent_temp_uri}/memories/patterns/pattern1.md" - assert result == expected - - def test_user_uri_with_both_temp_uris(self): - vikingdb = MagicMock() - compressor = SessionCompressor(vikingdb=vikingdb) - - user_space = _make_user().user_space_name() - target_uri = f"viking://user/{user_space}/memories/entities/entity1.md" - user_temp_uri = "viking://user/temp_user_123" - agent_temp_uri = "viking://agent/temp_agent_456" - - result = compressor._convert_to_temp_uri(target_uri, user_temp_uri, agent_temp_uri) - - expected = f"{user_temp_uri}/memories/entities/entity1.md" - assert result == expected - - def test_non_viking_uri_returns_unchanged(self): - vikingdb = MagicMock() - compressor = SessionCompressor(vikingdb=vikingdb) - - target_uri = "file:///some/local/path/memory.md" - user_temp_uri = "viking://user/temp_user_123" - - result = compressor._convert_to_temp_uri(target_uri, user_temp_uri, None) - - assert result == target_uri - - def test_short_uri_returns_unchanged(self): - vikingdb = MagicMock() - compressor = SessionCompressor(vikingdb=vikingdb) - - target_uri = "viking://user" - user_temp_uri = "viking://user/temp_user_123" - - result = compressor._convert_to_temp_uri(target_uri, user_temp_uri, None) - - assert result == target_uri - - -@pytest.mark.asyncio -class TestMergeIntoExisting: - async def test_merge_into_existing_success(self): - vikingdb = MagicMock() - compressor = SessionCompressor(vikingdb=vikingdb) - - user_space = _make_user().user_space_name() - target_uri = f"viking://user/{user_space}/memories/preferences/pref1.md" - user_temp_uri = "viking://user/temp_user_123" - temp_uri = f"{user_temp_uri}/memories/preferences/pref1.md" - - candidate = _make_candidate() - target_memory = _make_context(target_uri) - - viking_fs = AsyncMock() - viking_fs.read_file = AsyncMock(return_value="Existing content") - viking_fs.write_file = AsyncMock() - - mock_payload = MagicMock() - mock_payload.abstract = "Merged abstract" - mock_payload.overview = "Merged overview" - mock_payload.content = "Merged content" - - with patch.object( - compressor.extractor, - "_merge_memory_bundle", - AsyncMock(return_value=mock_payload), - ): - ctx = MagicMock() - result = await compressor._merge_into_existing( - candidate, - target_memory, - viking_fs, - ctx=ctx, - user_temp_uri=user_temp_uri, - agent_temp_uri=None, - ) - - assert result is True - viking_fs.read_file.assert_called_once() - viking_fs.write_file.assert_called_once() - assert target_memory.abstract == "Merged abstract" - assert target_memory.meta.get("overview") == "Merged overview" - - async def test_merge_into_existing_without_temp_uri(self): - vikingdb = MagicMock() - compressor = SessionCompressor(vikingdb=vikingdb) - - user_space = _make_user().user_space_name() - target_uri = f"viking://user/{user_space}/memories/preferences/pref1.md" - - candidate = _make_candidate() - target_memory = _make_context(target_uri) - - viking_fs = AsyncMock() - viking_fs.read_file = AsyncMock(return_value="Existing content") - viking_fs.write_file = AsyncMock() - - mock_payload = MagicMock() - mock_payload.abstract = "Merged abstract" - mock_payload.overview = "Merged overview" - mock_payload.content = "Merged content" - - with patch.object( - compressor.extractor, - "_merge_memory_bundle", - AsyncMock(return_value=mock_payload), - ): - ctx = MagicMock() - result = await compressor._merge_into_existing( - candidate, - target_memory, - viking_fs, - ctx=ctx, - user_temp_uri=None, - agent_temp_uri=None, - ) - - assert result is True - 
viking_fs.read_file.assert_called_once_with(target_uri, ctx=ctx) - - async def test_merge_into_existing_merge_bundle_returns_none(self): - vikingdb = MagicMock() - compressor = SessionCompressor(vikingdb=vikingdb) - - user_space = _make_user().user_space_name() - target_uri = f"viking://user/{user_space}/memories/preferences/pref1.md" - user_temp_uri = "viking://user/temp_user_123" - - candidate = _make_candidate() - target_memory = _make_context(target_uri) - - viking_fs = AsyncMock() - viking_fs.read_file = AsyncMock(return_value="Existing content") - - with patch.object( - compressor.extractor, - "_merge_memory_bundle", - AsyncMock(return_value=None), - ): - ctx = MagicMock() - result = await compressor._merge_into_existing( - candidate, - target_memory, - viking_fs, - ctx=ctx, - user_temp_uri=user_temp_uri, - agent_temp_uri=None, - ) - - assert result is False - viking_fs.write_file.assert_not_called() - - async def test_merge_into_existing_read_file_exception(self): - vikingdb = MagicMock() - compressor = SessionCompressor(vikingdb=vikingdb) - - user_space = _make_user().user_space_name() - target_uri = f"viking://user/{user_space}/memories/preferences/pref1.md" - user_temp_uri = "viking://user/temp_user_123" - - candidate = _make_candidate() - target_memory = _make_context(target_uri) - - viking_fs = AsyncMock() - viking_fs.read_file = AsyncMock(side_effect=Exception("Read error")) - - ctx = MagicMock() - result = await compressor._merge_into_existing( - candidate, - target_memory, - viking_fs, - ctx=ctx, - user_temp_uri=user_temp_uri, - agent_temp_uri=None, - ) - - assert result is False - - async def test_merge_into_existing_agent_uri(self): - vikingdb = MagicMock() - compressor = SessionCompressor(vikingdb=vikingdb) - - agent_space = _make_user().agent_space_name() - target_uri = f"viking://agent/{agent_space}/memories/cases/case1.md" - agent_temp_uri = "viking://agent/temp_agent_456" - - candidate = _make_candidate(category=MemoryCategory.CASES) - target_memory = _make_context(target_uri) - - viking_fs = AsyncMock() - viking_fs.read_file = AsyncMock(return_value="Existing content") - viking_fs.write_file = AsyncMock() - - mock_payload = MagicMock() - mock_payload.abstract = "Merged case abstract" - mock_payload.overview = "Merged case overview" - mock_payload.content = "Merged case content" - - with patch.object( - compressor.extractor, - "_merge_memory_bundle", - AsyncMock(return_value=mock_payload), - ): - ctx = MagicMock() - result = await compressor._merge_into_existing( - candidate, - target_memory, - viking_fs, - ctx=ctx, - user_temp_uri=None, - agent_temp_uri=agent_temp_uri, - ) - - assert result is True - expected_temp_uri = f"{agent_temp_uri}/memories/cases/case1.md" - viking_fs.read_file.assert_called_once_with(expected_temp_uri, ctx=ctx) - - -@pytest.mark.asyncio -class TestDeleteExistingMemory: - async def test_delete_existing_memory_success(self): - vikingdb = MagicMock() - vikingdb.delete_uris = AsyncMock() - compressor = SessionCompressor(vikingdb=vikingdb) - - user_space = _make_user().user_space_name() - target_uri = f"viking://user/{user_space}/memories/preferences/pref1.md" - user_temp_uri = "viking://user/temp_user_123" - temp_uri = f"{user_temp_uri}/memories/preferences/pref1.md" - - memory = _make_context(target_uri) - - viking_fs = AsyncMock() - viking_fs.rm = AsyncMock() - - ctx = MagicMock() - result = await compressor._delete_existing_memory( - memory, - viking_fs, - ctx=ctx, - user_temp_uri=user_temp_uri, - agent_temp_uri=None, - ) - - assert result is 
True - viking_fs.rm.assert_called_once_with(temp_uri, recursive=False, ctx=ctx) - vikingdb.delete_uris.assert_called_once_with(ctx, [temp_uri]) - - async def test_delete_existing_memory_without_temp_uri(self): - vikingdb = MagicMock() - vikingdb.delete_uris = AsyncMock() - compressor = SessionCompressor(vikingdb=vikingdb) - - user_space = _make_user().user_space_name() - target_uri = f"viking://user/{user_space}/memories/preferences/pref1.md" - - memory = _make_context(target_uri) - - viking_fs = AsyncMock() - viking_fs.rm = AsyncMock() - - ctx = MagicMock() - result = await compressor._delete_existing_memory( - memory, - viking_fs, - ctx=ctx, - user_temp_uri=None, - agent_temp_uri=None, - ) - - assert result is True - viking_fs.rm.assert_called_once_with(target_uri, recursive=False, ctx=ctx) - vikingdb.delete_uris.assert_called_once_with(ctx, [target_uri]) - - async def test_delete_existing_memory_rm_exception(self): - vikingdb = MagicMock() - compressor = SessionCompressor(vikingdb=vikingdb) - - user_space = _make_user().user_space_name() - target_uri = f"viking://user/{user_space}/memories/preferences/pref1.md" - user_temp_uri = "viking://user/temp_user_123" - - memory = _make_context(target_uri) - - viking_fs = AsyncMock() - viking_fs.rm = AsyncMock(side_effect=Exception("Delete error")) - - ctx = MagicMock() - result = await compressor._delete_existing_memory( - memory, - viking_fs, - ctx=ctx, - user_temp_uri=user_temp_uri, - agent_temp_uri=None, - ) - - assert result is False - - async def test_delete_existing_memory_vector_delete_exception(self): - vikingdb = MagicMock() - vikingdb.delete_uris = AsyncMock(side_effect=Exception("Vector delete error")) - compressor = SessionCompressor(vikingdb=vikingdb) - - user_space = _make_user().user_space_name() - target_uri = f"viking://user/{user_space}/memories/preferences/pref1.md" - user_temp_uri = "viking://user/temp_user_123" - - memory = _make_context(target_uri) - - viking_fs = AsyncMock() - viking_fs.rm = AsyncMock() - - ctx = MagicMock() - result = await compressor._delete_existing_memory( - memory, - viking_fs, - ctx=ctx, - user_temp_uri=user_temp_uri, - agent_temp_uri=None, - ) - - assert result is True - viking_fs.rm.assert_called_once() - vikingdb.delete_uris.assert_called_once() - - async def test_delete_existing_memory_agent_uri(self): - vikingdb = MagicMock() - vikingdb.delete_uris = AsyncMock() - compressor = SessionCompressor(vikingdb=vikingdb) - - agent_space = _make_user().agent_space_name() - target_uri = f"viking://agent/{agent_space}/memories/cases/case1.md" - agent_temp_uri = "viking://agent/temp_agent_456" - temp_uri = f"{agent_temp_uri}/memories/cases/case1.md" - - memory = _make_context(target_uri) - - viking_fs = AsyncMock() - viking_fs.rm = AsyncMock() - - ctx = MagicMock() - result = await compressor._delete_existing_memory( - memory, - viking_fs, - ctx=ctx, - user_temp_uri=None, - agent_temp_uri=agent_temp_uri, - ) - - assert result is True - viking_fs.rm.assert_called_once_with(temp_uri, recursive=False, ctx=ctx) - vikingdb.delete_uris.assert_called_once_with(ctx, [temp_uri]) diff --git a/tests/unit/session/test_deduplicator_uri.py b/tests/unit/session/test_deduplicator_uri.py deleted file mode 100644 index 46057b77..00000000 --- a/tests/unit/session/test_deduplicator_uri.py +++ /dev/null @@ -1,310 +0,0 @@ -# Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd. 
-# SPDX-License-Identifier: Apache-2.0 - -from unittest.mock import AsyncMock, MagicMock - -import pytest - -from openviking.session.memory_deduplicator import MemoryDeduplicator -from openviking.session.memory_extractor import CandidateMemory, MemoryCategory -from openviking_cli.session.user_id import UserIdentifier - - -class _DummyEmbedResult: - def __init__(self, dense_vector): - self.dense_vector = dense_vector - - -class _DummyEmbedder: - def embed(self, _text): - return _DummyEmbedResult([0.1, 0.2, 0.3]) - - -def _make_user() -> UserIdentifier: - return UserIdentifier("acc1", "test_user", "test_agent") - - -def _make_candidate(category: MemoryCategory = MemoryCategory.PREFERENCES) -> CandidateMemory: - return CandidateMemory( - category=category, - abstract="User prefers concise summaries", - overview="User asks for concise answers frequently.", - content="The user prefers concise summaries over long explanations.", - source_session="session_test", - user=_make_user(), - language="en", - ) - - -def _make_existing_user_memory(uri_suffix: str = "existing.md") -> dict: - user_space = _make_user().user_space_name() - return { - "id": f"uri_{uri_suffix}", - "uri": f"viking://user/{user_space}/memories/preferences/{uri_suffix}", - "context_type": "memory", - "level": 2, - "account_id": "acc1", - "owner_space": user_space, - "abstract": "Existing preference memory", - "category": "preferences", - "_score": 0.85, - } - - -def _make_existing_agent_memory(uri_suffix: str = "case1.md") -> dict: - user = _make_user() - agent_space = user.agent_space_name() - return { - "id": f"uri_{uri_suffix}", - "uri": f"viking://agent/{agent_space}/memories/cases/{uri_suffix}", - "context_type": "memory", - "level": 2, - "account_id": "acc1", - "owner_space": agent_space, - "abstract": "Existing case memory", - "category": "cases", - "_score": 0.90, - } - - -@pytest.mark.asyncio -class TestFindSimilarMemoriesURIConversion: - async def test_user_uri_converted_to_temp_uri(self): - vikingdb = MagicMock() - vikingdb.get_embedder.return_value = _DummyEmbedder() - vikingdb.search_similar_memories = AsyncMock( - return_value=[_make_existing_user_memory("pref1.md")] - ) - - dedup = MemoryDeduplicator(vikingdb=vikingdb) - candidate = _make_candidate() - - user_temp_uri = "viking://user/temp_user_123" - similar = await dedup._find_similar_memories( - candidate, - user_temp_uri=user_temp_uri, - agent_temp_uri=None, - ) - - assert len(similar) == 1 - user_space = _make_user().user_space_name() - original_uri = f"viking://user/{user_space}/memories/preferences/pref1.md" - expected_uri = f"{user_temp_uri}/memories/preferences/pref1.md" - assert similar[0].uri == expected_uri - assert similar[0].uri != original_uri - - async def test_agent_uri_converted_to_temp_uri(self): - vikingdb = MagicMock() - vikingdb.get_embedder.return_value = _DummyEmbedder() - vikingdb.search_similar_memories = AsyncMock( - return_value=[_make_existing_agent_memory("case1.md")] - ) - - dedup = MemoryDeduplicator(vikingdb=vikingdb) - candidate = _make_candidate(category=MemoryCategory.CASES) - - agent_temp_uri = "viking://agent/temp_agent_456" - similar = await dedup._find_similar_memories( - candidate, - user_temp_uri=None, - agent_temp_uri=agent_temp_uri, - ) - - assert len(similar) == 1 - user = _make_user() - agent_space = user.agent_space_name() - original_uri = f"viking://agent/{agent_space}/memories/cases/case1.md" - expected_uri = f"{agent_temp_uri}/memories/cases/case1.md" - assert similar[0].uri == expected_uri - assert similar[0].uri != 
original_uri - - async def test_no_conversion_when_no_temp_uri(self): - vikingdb = MagicMock() - vikingdb.get_embedder.return_value = _DummyEmbedder() - vikingdb.search_similar_memories = AsyncMock( - return_value=[_make_existing_user_memory("pref1.md")] - ) - - dedup = MemoryDeduplicator(vikingdb=vikingdb) - candidate = _make_candidate() - - similar = await dedup._find_similar_memories( - candidate, - user_temp_uri=None, - agent_temp_uri=None, - ) - - assert len(similar) == 1 - user_space = _make_user().user_space_name() - expected_uri = f"viking://user/{user_space}/memories/preferences/pref1.md" - assert similar[0].uri == expected_uri - - async def test_mixed_uris_only_convert_matching_type(self): - user_space = _make_user().user_space_name() - agent_space = _make_user().agent_space_name() - - vikingdb = MagicMock() - vikingdb.get_embedder.return_value = _DummyEmbedder() - vikingdb.search_similar_memories = AsyncMock( - return_value=[ - _make_existing_user_memory("pref1.md"), - _make_existing_agent_memory("case1.md"), - ] - ) - - dedup = MemoryDeduplicator(vikingdb=vikingdb) - candidate = _make_candidate() - - user_temp_uri = "viking://user/temp_user_123" - similar = await dedup._find_similar_memories( - candidate, - user_temp_uri=user_temp_uri, - agent_temp_uri=None, - ) - - assert len(similar) == 2 - uris = {m.uri for m in similar} - assert f"{user_temp_uri}/memories/preferences/pref1.md" in uris - assert f"viking://agent/{agent_space}/memories/cases/case1.md" in uris - - async def test_uri_conversion_preserves_meta_and_score(self): - vikingdb = MagicMock() - vikingdb.get_embedder.return_value = _DummyEmbedder() - vikingdb.search_similar_memories = AsyncMock( - return_value=[_make_existing_user_memory("pref1.md")] - ) - - dedup = MemoryDeduplicator(vikingdb=vikingdb) - candidate = _make_candidate() - - user_temp_uri = "viking://user/temp_user_123" - similar = await dedup._find_similar_memories( - candidate, - user_temp_uri=user_temp_uri, - agent_temp_uri=None, - ) - - assert len(similar) == 1 - assert similar[0].meta is not None - assert similar[0].meta.get("_dedup_score") == 0.85 - - -class TestExtractFacetKey: - def test_extract_with_chinese_colon(self): - result = MemoryDeduplicator._extract_facet_key("饮食偏好:喜欢吃苹果和草莓") - assert result == "饮食偏好" - - def test_extract_with_english_colon(self): - result = MemoryDeduplicator._extract_facet_key("User preference: dark mode enabled") - assert result == "user preference" - - def test_extract_with_hyphen(self): - result = MemoryDeduplicator._extract_facet_key("Coding style - prefer type hints") - assert result == "coding style" - - def test_extract_with_em_dash(self): - result = MemoryDeduplicator._extract_facet_key("Work schedule — remote on Fridays") - assert result == "work schedule" - - def test_extract_with_no_separator_returns_prefix(self): - result = MemoryDeduplicator._extract_facet_key( - "This is a long abstract without any separator" - ) - assert len(result) <= 24 - assert result == "this is a long abstract" - - def test_extract_with_empty_string(self): - result = MemoryDeduplicator._extract_facet_key("") - assert result == "" - - def test_extract_with_none(self): - result = MemoryDeduplicator._extract_facet_key(None) - assert result == "" - - def test_extract_normalizes_whitespace(self): - result = MemoryDeduplicator._extract_facet_key(" Multiple spaces : value ") - assert result == "multiple spaces" - - def test_extract_with_short_text_no_separator(self): - result = MemoryDeduplicator._extract_facet_key("Short") - assert result == 
"short" - - def test_extract_returns_lowercase(self): - result = MemoryDeduplicator._extract_facet_key("FOOD PREFERENCE: pizza") - assert result == "food preference" - - def test_extract_with_separator_at_start(self): - result = MemoryDeduplicator._extract_facet_key(": starts with separator") - assert result == ": starts with" - - def test_extract_with_multiple_separators_uses_first(self): - result = MemoryDeduplicator._extract_facet_key("Topic: Subtopic - Detail") - assert result == "topic" - - -class TestCosineSimilarity: - def test_identical_vectors(self): - vec = [1.0, 2.0, 3.0] - result = MemoryDeduplicator._cosine_similarity(vec, vec) - assert abs(result - 1.0) < 1e-9 - - def test_orthogonal_vectors(self): - vec_a = [1.0, 0.0] - vec_b = [0.0, 1.0] - result = MemoryDeduplicator._cosine_similarity(vec_a, vec_b) - assert abs(result) < 1e-9 - - def test_opposite_vectors(self): - vec_a = [1.0, 2.0, 3.0] - vec_b = [-1.0, -2.0, -3.0] - result = MemoryDeduplicator._cosine_similarity(vec_a, vec_b) - assert abs(result + 1.0) < 1e-9 - - def test_different_length_vectors(self): - vec_a = [1.0, 2.0, 3.0] - vec_b = [1.0, 2.0] - result = MemoryDeduplicator._cosine_similarity(vec_a, vec_b) - assert result == 0.0 - - def test_zero_vector_a(self): - vec_a = [0.0, 0.0, 0.0] - vec_b = [1.0, 2.0, 3.0] - result = MemoryDeduplicator._cosine_similarity(vec_a, vec_b) - assert result == 0.0 - - def test_zero_vector_b(self): - vec_a = [1.0, 2.0, 3.0] - vec_b = [0.0, 0.0, 0.0] - result = MemoryDeduplicator._cosine_similarity(vec_a, vec_b) - assert result == 0.0 - - def test_both_zero_vectors(self): - vec_a = [0.0, 0.0, 0.0] - vec_b = [0.0, 0.0, 0.0] - result = MemoryDeduplicator._cosine_similarity(vec_a, vec_b) - assert result == 0.0 - - def test_partial_similarity(self): - vec_a = [1.0, 0.0, 0.0] - vec_b = [1.0, 1.0, 0.0] - result = MemoryDeduplicator._cosine_similarity(vec_a, vec_b) - expected = 1.0 / (2.0**0.5) - assert abs(result - expected) < 1e-9 - - def test_negative_values(self): - vec_a = [1.0, -2.0, 3.0] - vec_b = [-1.0, 2.0, 3.0] - result = MemoryDeduplicator._cosine_similarity(vec_a, vec_b) - assert 0 < result < 1 - - def test_single_element_vectors(self): - vec_a = [5.0] - vec_b = [3.0] - result = MemoryDeduplicator._cosine_similarity(vec_a, vec_b) - assert abs(result - 1.0) < 1e-9 - - def test_large_vectors(self): - vec_a = [float(i) for i in range(100)] - vec_b = [float(i * 2) for i in range(100)] - result = MemoryDeduplicator._cosine_similarity(vec_a, vec_b) - assert abs(result - 1.0) < 1e-6 diff --git a/tests/unit/session/test_memory_extractor_tools.py b/tests/unit/session/test_memory_extractor_tools.py deleted file mode 100644 index 4d77aaab..00000000 --- a/tests/unit/session/test_memory_extractor_tools.py +++ /dev/null @@ -1,786 +0,0 @@ -# Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd. 
-# SPDX-License-Identifier: Apache-2.0 - -from unittest.mock import patch - -import pytest - -from openviking.session.memory_extractor import ( - FIELD_MAX_LENGTHS, - MemoryExtractor, -) - - -@pytest.fixture -def extractor(): - return MemoryExtractor() - - -class TestParseToolStatistics: - def test_parse_chinese_format_full(self, extractor): - content = """ -Tool: test_tool - -总调用次数: 100 -成功率: 85.0%(85 成功,15 失败) -平均耗时: 150.5ms -平均Token: 500 -""" - stats = extractor._parse_tool_statistics(content) - assert stats["total_calls"] == 100 - assert stats["success_count"] == 85 - assert stats["fail_count"] == 15 - assert stats["total_time_ms"] == 15050.0 - assert stats["total_tokens"] == 50000 - - def test_parse_chinese_format_with_colon(self, extractor): - content = """ -总调用次数:200 -成功率:90.5%(181 成功,19 失败) -平均耗时:200.0ms -平均Token:800 -""" - stats = extractor._parse_tool_statistics(content) - assert stats["total_calls"] == 200 - assert stats["success_count"] == 181 - assert stats["fail_count"] == 19 - - def test_parse_english_format_full(self, extractor): - content = """ -Tool: test_tool - -Based on 50 historical calls: -- Success rate: 80.0% (40 successful, 10 failed) -- Avg time: 1.5s, Avg tokens: 600 -""" - stats = extractor._parse_tool_statistics(content) - assert stats["total_calls"] == 50 - assert stats["success_count"] == 40 - assert stats["fail_count"] == 10 - assert stats["total_time_ms"] == 75000.0 - assert stats["total_tokens"] == 30000 - - def test_parse_english_format_ms(self, extractor): - content = """ -Based on 30 historical calls: -- Success rate: 90.0% (27 successful, 3 failed) -- Avg time: 250.5ms, Avg tokens: 400 -""" - stats = extractor._parse_tool_statistics(content) - assert stats["total_calls"] == 30 - assert stats["success_count"] == 27 - assert stats["fail_count"] == 3 - assert stats["total_time_ms"] == 7515.0 - assert stats["total_tokens"] == 12000 - - def test_parse_chinese_success_rate_only(self, extractor): - content = """ -总调用次数: 100 -成功率: 75.0% -""" - stats = extractor._parse_tool_statistics(content) - assert stats["total_calls"] == 100 - assert stats["success_count"] == 75 - assert stats["fail_count"] == 25 - - def test_parse_english_success_rate_only(self, extractor): - content = """ -Based on 80 historical calls: -- Success rate: 87.5% -""" - stats = extractor._parse_tool_statistics(content) - assert stats["total_calls"] == 80 - assert stats["success_count"] == 70 - assert stats["fail_count"] == 10 - - def test_parse_empty_content(self, extractor): - stats = extractor._parse_tool_statistics("") - assert stats["total_calls"] == 0 - assert stats["success_count"] == 0 - assert stats["fail_count"] == 0 - assert stats["total_time_ms"] == 0 - assert stats["total_tokens"] == 0 - - def test_parse_chinese_avg_time_seconds(self, extractor): - content = """ -总调用次数: 10 -平均耗时: 2.5s -""" - stats = extractor._parse_tool_statistics(content) - assert stats["total_calls"] == 10 - assert stats["total_time_ms"] == 25000.0 - - def test_parse_no_total_calls_infers_from_success_fail(self, extractor): - content = """ -成功率: 80.0%(40 成功,10 失败) -""" - stats = extractor._parse_tool_statistics(content) - assert stats["total_calls"] == 50 - assert stats["success_count"] == 40 - assert stats["fail_count"] == 10 - - -class TestMergeToolStatistics: - def test_merge_basic(self, extractor): - existing = { - "total_calls": 100, - "success_count": 80, - "fail_count": 20, - "total_time_ms": 10000.0, - "total_tokens": 50000, - } - new = { - "total_calls": 50, - "success_count": 45, - "fail_count": 5, - 
"total_time_ms": 5000.0, - "total_tokens": 25000, - } - merged = extractor._merge_tool_statistics(existing, new) - assert merged["total_calls"] == 150 - assert merged["success_count"] == 125 - assert merged["fail_count"] == 25 - assert merged["total_time_ms"] == 15000.0 - assert merged["total_tokens"] == 75000 - assert abs(merged["avg_time_ms"] - 100.0) < 0.01 - assert abs(merged["avg_tokens"] - 500.0) < 0.01 - assert abs(merged["success_rate"] - 0.8333) < 0.01 - - def test_merge_with_zero_existing(self, extractor): - existing = { - "total_calls": 0, - "success_count": 0, - "fail_count": 0, - "total_time_ms": 0, - "total_tokens": 0, - } - new = { - "total_calls": 10, - "success_count": 8, - "fail_count": 2, - "total_time_ms": 1000.0, - "total_tokens": 5000, - } - merged = extractor._merge_tool_statistics(existing, new) - assert merged["total_calls"] == 10 - assert merged["success_count"] == 8 - assert merged["fail_count"] == 2 - - def test_merge_with_zero_new(self, extractor): - existing = { - "total_calls": 20, - "success_count": 15, - "fail_count": 5, - "total_time_ms": 2000.0, - "total_tokens": 10000, - } - new = { - "total_calls": 0, - "success_count": 0, - "fail_count": 0, - "total_time_ms": 0, - "total_tokens": 0, - } - merged = extractor._merge_tool_statistics(existing, new) - assert merged["total_calls"] == 20 - assert merged["success_count"] == 15 - assert merged["fail_count"] == 5 - - def test_merge_both_zero(self, extractor): - existing = { - "total_calls": 0, - "success_count": 0, - "fail_count": 0, - "total_time_ms": 0, - "total_tokens": 0, - } - new = { - "total_calls": 0, - "success_count": 0, - "fail_count": 0, - "total_time_ms": 0, - "total_tokens": 0, - } - merged = extractor._merge_tool_statistics(existing, new) - assert merged["total_calls"] == 0 - assert merged["avg_time_ms"] == 0 - assert merged["avg_tokens"] == 0 - assert merged["success_rate"] == 0 - - -class TestGenerateToolMemoryContent: - def test_generate_basic(self, extractor): - with patch.object(extractor, "_get_tool_static_description", return_value="A test tool"): - stats = { - "total_calls": 100, - "success_count": 85, - "fail_count": 15, - "avg_time_ms": 150.5, - "avg_tokens": 500, - "success_rate": 0.85, - } - guidelines = "Use this tool for testing purposes." - content = extractor._generate_tool_memory_content("test_tool", stats, guidelines) - assert "Tool: test_tool" in content - assert "Based on 100 historical calls:" in content - assert "Success rate: 85.0%" in content - assert "85 successful, 15 failed" in content - assert "Use this tool for testing purposes." 
in content - - def test_generate_with_fields(self, extractor): - with patch.object(extractor, "_get_tool_static_description", return_value="A test tool"): - stats = { - "total_calls": 50, - "success_count": 40, - "fail_count": 10, - "avg_time_ms": 200.0, - "avg_tokens": 600, - "success_rate": 0.8, - } - fields = { - "best_for": "Data processing tasks", - "optimal_params": "batch_size=100", - "common_failures": "Timeout on large inputs", - "recommendation": "Use with small batches", - } - content = extractor._generate_tool_memory_content("test_tool", stats, "", fields=fields) - assert "Best for: Data processing tasks" in content - assert "Optimal params: batch_size=100" in content - assert "Common failures: Timeout on large inputs" in content - assert "Recommendation: Use with small batches" in content - - def test_generate_with_empty_fields(self, extractor): - with patch.object(extractor, "_get_tool_static_description", return_value="A test tool"): - stats = { - "total_calls": 10, - "success_count": 10, - "fail_count": 0, - "avg_time_ms": 100.0, - "avg_tokens": 300, - "success_rate": 1.0, - } - content = extractor._generate_tool_memory_content("test_tool", stats, "", fields={}) - assert "Best for: " in content - assert "Optimal params: " in content - - def test_generate_extracts_fields_from_guidelines(self, extractor): - with patch.object(extractor, "_get_tool_static_description", return_value="A test tool"): - stats = { - "total_calls": 20, - "success_count": 18, - "fail_count": 2, - "avg_time_ms": 50.0, - "avg_tokens": 200, - "success_rate": 0.9, - } - guidelines = """ -Best for: Quick data validation -Optimal params: strict_mode=true -Common failures: Invalid input format -Recommendation: Always validate input first -""" - content = extractor._generate_tool_memory_content("test_tool", stats, guidelines) - assert "Best for: Quick data validation" in content - assert "Optimal params: strict_mode=true" in content - - -class TestParseSkillStatistics: - def test_parse_chinese_format_full(self, extractor): - content = """ -Skill: test_skill - -总执行次数: 100 -成功率: 90.0%(90 成功,10 失败) -""" - stats = extractor._parse_skill_statistics(content) - assert stats["total_executions"] == 100 - assert stats["success_count"] == 90 - assert stats["fail_count"] == 10 - - def test_parse_chinese_format_with_colon(self, extractor): - content = """ -总执行次数:50 -成功率:80.0%(40 成功,10 失败) -""" - stats = extractor._parse_skill_statistics(content) - assert stats["total_executions"] == 50 - assert stats["success_count"] == 40 - assert stats["fail_count"] == 10 - - def test_parse_english_format_full(self, extractor): - content = """ -Skill: test_skill - -Based on 75 historical executions: -- Success rate: 85.0% (64 successful, 11 failed) -""" - stats = extractor._parse_skill_statistics(content) - assert stats["total_executions"] == 75 - assert stats["success_count"] == 64 - assert stats["fail_count"] == 11 - - def test_parse_english_success_rate_only(self, extractor): - content = """ -Based on 60 historical executions: -- Success rate: 75.0% -""" - stats = extractor._parse_skill_statistics(content) - assert stats["total_executions"] == 60 - assert stats["success_count"] == 45 - assert stats["fail_count"] == 15 - - def test_parse_empty_content(self, extractor): - stats = extractor._parse_skill_statistics("") - assert stats["total_executions"] == 0 - assert stats["success_count"] == 0 - assert stats["fail_count"] == 0 - - def test_parse_no_total_executions_infers_from_success_fail(self, extractor): - content = """ -成功率: 
70.0%(35 成功,15 失败) -""" - stats = extractor._parse_skill_statistics(content) - assert stats["total_executions"] == 50 - assert stats["success_count"] == 35 - assert stats["fail_count"] == 15 - - -class TestMergeSkillStatistics: - def test_merge_basic(self, extractor): - existing = { - "total_executions": 100, - "success_count": 90, - "fail_count": 10, - } - new = { - "total_executions": 50, - "success_count": 45, - "fail_count": 5, - } - merged = extractor._merge_skill_statistics(existing, new) - assert merged["total_executions"] == 150 - assert merged["success_count"] == 135 - assert merged["fail_count"] == 15 - assert abs(merged["success_rate"] - 0.9) < 0.01 - - def test_merge_with_zero_existing(self, extractor): - existing = { - "total_executions": 0, - "success_count": 0, - "fail_count": 0, - } - new = { - "total_executions": 20, - "success_count": 18, - "fail_count": 2, - } - merged = extractor._merge_skill_statistics(existing, new) - assert merged["total_executions"] == 20 - assert merged["success_count"] == 18 - assert merged["fail_count"] == 2 - - def test_merge_with_zero_new(self, extractor): - existing = { - "total_executions": 30, - "success_count": 25, - "fail_count": 5, - } - new = { - "total_executions": 0, - "success_count": 0, - "fail_count": 0, - } - merged = extractor._merge_skill_statistics(existing, new) - assert merged["total_executions"] == 30 - assert merged["success_count"] == 25 - assert merged["fail_count"] == 5 - - def test_merge_both_zero(self, extractor): - existing = { - "total_executions": 0, - "success_count": 0, - "fail_count": 0, - } - new = { - "total_executions": 0, - "success_count": 0, - "fail_count": 0, - } - merged = extractor._merge_skill_statistics(existing, new) - assert merged["total_executions"] == 0 - assert merged["success_rate"] == 0 - - -class TestGenerateSkillMemoryContent: - def test_generate_basic(self, extractor): - stats = { - "total_executions": 100, - "success_count": 90, - "fail_count": 10, - "success_rate": 0.9, - } - guidelines = "Use this skill for data processing." - content = extractor._generate_skill_memory_content("test_skill", stats, guidelines) - assert "Skill: test_skill" in content - assert "Based on 100 historical executions:" in content - assert "Success rate: 90.0%" in content - assert "90 successful, 10 failed" in content - assert "Use this skill for data processing." 
in content - - def test_generate_with_fields(self, extractor): - stats = { - "total_executions": 50, - "success_count": 45, - "fail_count": 5, - "success_rate": 0.9, - } - fields = { - "best_for": "Automated workflows", - "recommended_flow": "Step 1 -> Step 2 -> Step 3", - "key_dependencies": "Database connection", - "common_failures": "Network timeout", - "recommendation": "Use with retry logic", - } - content = extractor._generate_skill_memory_content("test_skill", stats, "", fields=fields) - assert "Best for: Automated workflows" in content - assert "Recommended flow: Step 1 -> Step 2 -> Step 3" in content - assert "Key dependencies: Database connection" in content - assert "Common failures: Network timeout" in content - assert "Recommendation: Use with retry logic" in content - - def test_generate_with_empty_fields(self, extractor): - stats = { - "total_executions": 10, - "success_count": 10, - "fail_count": 0, - "success_rate": 1.0, - } - content = extractor._generate_skill_memory_content("test_skill", stats, "", fields={}) - assert "Best for: " in content - assert "Recommended flow: " in content - - def test_generate_extracts_fields_from_guidelines(self, extractor): - stats = { - "total_executions": 20, - "success_count": 18, - "fail_count": 2, - "success_rate": 0.9, - } - guidelines = """ -Best for: Complex data transformations -Recommended flow: Validate -> Transform -> Store -Key dependencies: S3 bucket access -Common failures: Permission denied -Recommendation: Check permissions first -""" - content = extractor._generate_skill_memory_content("test_skill", stats, guidelines) - assert "Best for: Complex data transformations" in content - assert "Recommended flow: Validate -> Transform -> Store" in content - - -class TestMergeKvField: - @pytest.mark.asyncio - async def test_merge_both_empty(self, extractor): - result = await extractor._merge_kv_field("", "", "best_for") - assert result == "" - - @pytest.mark.asyncio - async def test_merge_existing_empty(self, extractor): - result = await extractor._merge_kv_field("", "new value", "best_for") - assert result == "new value" - - @pytest.mark.asyncio - async def test_merge_new_empty(self, extractor): - result = await extractor._merge_kv_field("existing value", "", "best_for") - assert result == "existing value" - - @pytest.mark.asyncio - async def test_merge_identical_values(self, extractor): - result = await extractor._merge_kv_field("same value", "same value", "best_for") - assert result == "same value" - - @pytest.mark.asyncio - async def test_merge_different_values(self, extractor): - result = await extractor._merge_kv_field("value A", "value B", "best_for") - assert "value A" in result - assert "value B" in result - assert ";" in result - - @pytest.mark.asyncio - async def test_merge_with_semicolon_separator(self, extractor): - result = await extractor._merge_kv_field("item1; item2", "item3", "best_for") - assert "item1" in result - assert "item2" in result - assert "item3" in result - - @pytest.mark.asyncio - async def test_merge_deduplicates(self, extractor): - result = await extractor._merge_kv_field("item1; item2", "item2; item3", "best_for") - assert result.count("item2") == 1 - assert "item1" in result - assert "item3" in result - - @pytest.mark.asyncio - async def test_merge_respects_max_length(self, extractor): - long_value = "x" * 600 - result = await extractor._merge_kv_field(long_value, "new", "best_for") - assert len(result) <= FIELD_MAX_LENGTHS["best_for"] - - @pytest.mark.asyncio - async def 
test_merge_with_newline_separator(self, extractor): - result = await extractor._merge_kv_field("item1\nitem2", "item3", "best_for") - assert "item1" in result - assert "item2" in result - assert "item3" in result - - -class TestSmartTruncate: - def test_no_truncation_needed(self, extractor): - text = "short text" - result = extractor._smart_truncate(text, 100) - assert result == text - - def test_truncate_at_semicolon(self, extractor): - text = "item1; item2; item3; item4; item5" - result = extractor._smart_truncate(text, 25) - assert len(result) <= 25 - assert result.endswith(";") or result.count(";") >= 1 - - def test_truncate_at_space(self, extractor): - text = "word1 word2 word3 word4 word5" - result = extractor._smart_truncate(text, 20) - assert len(result) <= 20 - - def test_truncate_fallback(self, extractor): - text = "abcdefghijklmnopqrstuvwxyz" - result = extractor._smart_truncate(text, 10) - assert len(result) == 10 - assert result == "abcdefghij" - - def test_truncate_empty_string(self, extractor): - result = extractor._smart_truncate("", 10) - assert result == "" - - def test_truncate_exact_length(self, extractor): - text = "exactly10!" - result = extractor._smart_truncate(text, 10) - assert result == text - - -class TestComputeStatisticsDerived: - def test_compute_with_calls(self, extractor): - stats = { - "total_calls": 100, - "success_count": 80, - "fail_count": 20, - "total_time_ms": 10000.0, - "total_tokens": 50000, - } - result = extractor._compute_statistics_derived(stats) - assert abs(result["avg_time_ms"] - 100.0) < 0.01 - assert abs(result["avg_tokens"] - 500.0) < 0.01 - assert abs(result["success_rate"] - 0.8) < 0.01 - - def test_compute_with_zero_calls(self, extractor): - stats = { - "total_calls": 0, - "success_count": 0, - "fail_count": 0, - "total_time_ms": 0, - "total_tokens": 0, - } - result = extractor._compute_statistics_derived(stats) - assert result["avg_time_ms"] == 0 - assert result["avg_tokens"] == 0 - assert result["success_rate"] == 0 - - def test_compute_preserves_original_values(self, extractor): - stats = { - "total_calls": 50, - "success_count": 40, - "fail_count": 10, - "total_time_ms": 5000.0, - "total_tokens": 25000, - } - result = extractor._compute_statistics_derived(stats) - assert result["total_calls"] == 50 - assert result["success_count"] == 40 - assert result["fail_count"] == 10 - assert result["total_time_ms"] == 5000.0 - assert result["total_tokens"] == 25000 - - -class TestFormatDuration: - def test_format_zero(self, extractor): - result = extractor._format_duration(0) - assert result == "0s" - - def test_format_milliseconds(self, extractor): - result = extractor._format_duration(500) - assert result == "500ms" - - def test_format_seconds(self, extractor): - result = extractor._format_duration(1500) - assert result == "1.5s" - - def test_format_large_seconds(self, extractor): - result = extractor._format_duration(10000) - assert result == "10.0s" - - def test_format_none(self, extractor): - result = extractor._format_duration(None) - assert result == "N/A" - - def test_format_negative(self, extractor): - result = extractor._format_duration(-100) - assert result == "0s" - - def test_format_invalid_type(self, extractor): - result = extractor._format_duration("invalid") - assert result == "N/A" - - def test_format_exactly_one_second(self, extractor): - result = extractor._format_duration(1000) - assert result == "1.0s" - - def test_format_just_under_one_second(self, extractor): - result = extractor._format_duration(999) - assert result == 
"999ms" - - -class TestExtractContentField: - def test_extract_with_chinese_colon(self, extractor): - content = "Best for:数据处理任务" - result = extractor._extract_content_field(content, ["Best for"]) - assert result == "数据处理任务" - - def test_extract_with_english_colon(self, extractor): - content = "Best for: data processing tasks" - result = extractor._extract_content_field(content, ["Best for"]) - assert result == "data processing tasks" - - def test_extract_with_multiple_keys(self, extractor): - content = "最佳场景: 快速验证" - result = extractor._extract_content_field(content, ["Best for", "最佳场景"]) - assert result == "快速验证" - - def test_extract_not_found(self, extractor): - content = "Some other content" - result = extractor._extract_content_field(content, ["Best for"]) - assert result == "" - - def test_extract_empty_content(self, extractor): - result = extractor._extract_content_field("", ["Best for"]) - assert result == "" - - -class TestCompactBlock: - def test_compact_basic(self, extractor): - text = "Line 1\nLine 2\nLine 3" - result = extractor._compact_block(text) - assert result == "Line 1; Line 2; Line 3" - - def test_compact_with_prefixes(self, extractor): - text = "> Point 1\n- Point 2\n* Point 3" - result = extractor._compact_block(text) - assert "Point 1" in result - assert "Point 2" in result - assert "Point 3" in result - - def test_compact_empty(self, extractor): - result = extractor._compact_block("") - assert result == "" - - def test_compact_whitespace_only(self, extractor): - result = extractor._compact_block(" \n \n ") - assert result == "" - - -class TestExtractToolMemoryContextFieldsFromText: - def test_extract_all_fields(self, extractor): - text = """ -Best for: Data processing -Optimal params: batch_size=100 -Common failures: Timeout -Recommendation: Use small batches -""" - result = extractor._extract_tool_memory_context_fields_from_text(text) - assert result["best_for"] == "Data processing" - assert result["optimal_params"] == "batch_size=100" - assert result["common_failures"] == "Timeout" - assert result["recommendation"] == "Use small batches" - - def test_extract_partial_fields(self, extractor): - text = """ -Best for: Testing -Recommendation: Run in dev mode -""" - result = extractor._extract_tool_memory_context_fields_from_text(text) - assert result["best_for"] == "Testing" - assert result["optimal_params"] == "" - assert result["common_failures"] == "" - assert result["recommendation"] == "Run in dev mode" - - def test_extract_chinese_fields(self, extractor): - text = """ -最佳场景: 数据处理 -最优参数: 批量大小=100 -常见失败: 超时 -推荐: 使用小批量 -""" - result = extractor._extract_tool_memory_context_fields_from_text(text) - assert result["best_for"] == "数据处理" - assert result["optimal_params"] == "批量大小=100" - assert result["common_failures"] == "超时" - assert result["recommendation"] == "使用小批量" - - -class TestExtractSkillMemoryContextFieldsFromText: - def test_extract_all_fields(self, extractor): - text = """ -Best for: Automated workflows -Recommended flow: Step1 -> Step2 -> Step3 -Key dependencies: Database -Common failures: Connection error -Recommendation: Use connection pool -""" - result = extractor._extract_skill_memory_context_fields_from_text(text) - assert result["best_for"] == "Automated workflows" - assert result["recommended_flow"] == "Step1 -> Step2 -> Step3" - assert result["key_dependencies"] == "Database" - assert result["common_failures"] == "Connection error" - assert result["recommendation"] == "Use connection pool" - - def test_extract_chinese_fields(self, extractor): - 
text = """ -最佳场景: 自动化工作流 -推荐流程: 步骤1 -> 步骤2 -> 步骤3 -关键依赖: 数据库 -常见失败: 连接错误 -推荐: 使用连接池 -""" - result = extractor._extract_skill_memory_context_fields_from_text(text) - assert result["best_for"] == "自动化工作流" - assert result["recommended_flow"] == "步骤1 -> 步骤2 -> 步骤3" - assert result["key_dependencies"] == "数据库" - assert result["common_failures"] == "连接错误" - assert result["recommendation"] == "使用连接池" - - -class TestFormatMs: - def test_format_zero(self, extractor): - result = extractor._format_ms(0) - assert result == "0.000ms" - - def test_format_normal_value(self, extractor): - result = extractor._format_ms(123.456) - assert result == "123.456ms" - - def test_format_very_small_value(self, extractor): - result = extractor._format_ms(0.000123) - assert "ms" in result - assert float(result.replace("ms", "")) > 0 - - def test_format_large_value(self, extractor): - result = extractor._format_ms(9999.999) - assert result == "9999.999ms" diff --git a/tests/unit/session/test_session_cow.py b/tests/unit/session/test_session_cow.py deleted file mode 100644 index 26ed353d..00000000 --- a/tests/unit/session/test_session_cow.py +++ /dev/null @@ -1,495 +0,0 @@ -# Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd. -# SPDX-License-Identifier: Apache-2.0 -"""Unit tests for Session COW (Copy-on-Write) mode and async commit functionality.""" - -import time -from unittest.mock import AsyncMock, MagicMock, patch - -import pytest - -from openviking.server.identity import RequestContext, Role -from openviking.session.session import Session -from openviking_cli.session.user_id import UserIdentifier - - -def _make_user() -> UserIdentifier: - return UserIdentifier("test_account", "test_user", "test_agent") - - -def _make_session(viking_fs: MagicMock = None, session_id: str = "test_session_123") -> Session: - user = _make_user() - ctx = RequestContext(user=user, role=Role.ROOT) - fs = viking_fs or MagicMock() - return Session( - viking_fs=fs, - user=user, - ctx=ctx, - session_id=session_id, - ) - - -class TestCreateTempUris: - """Tests for _create_temp_uris() method.""" - - def test_returns_tuple_of_four_uris(self): - session = _make_session() - result = session._create_temp_uris() - - assert isinstance(result, tuple) - assert len(result) == 4 - - def test_temp_base_uri_format(self): - session = _make_session(session_id="sess_abc") - user = session.user - temp_base, _, _, _ = session._create_temp_uris() - - assert temp_base.startswith("viking://temp/session/") - assert f"/{user.user_space_name()}/" in temp_base - assert "/sess_abc/" in temp_base - assert "/commit_" in temp_base - - def test_session_temp_uri_structure(self): - session = _make_session(session_id="sess_abc") - user = session.user - temp_base, session_temp, _, _ = session._create_temp_uris() - - assert session_temp.startswith(temp_base) - assert "/session/" in session_temp - assert f"/{user.user_space_name()}/sess_abc" in session_temp - - def test_user_temp_uri_structure(self): - session = _make_session() - user = session.user - temp_base, _, user_temp, _ = session._create_temp_uris() - - assert user_temp.startswith(temp_base) - assert "/user/" in user_temp - assert user_temp.endswith(f"/user/{user.user_space_name()}") - - def test_agent_temp_uri_structure(self): - session = _make_session() - user = session.user - temp_base, _, _, agent_temp = session._create_temp_uris() - - assert agent_temp.startswith(temp_base) - assert "/agent/" in agent_temp - assert agent_temp.endswith(f"/agent/{user.agent_space_name()}") - - def 
test_sets_internal_state(self): - session = _make_session() - temp_base, session_temp, user_temp, agent_temp = session._create_temp_uris() - - assert session._temp_base_uri == temp_base - assert session._session_temp_uri == session_temp - assert session._user_temp_uri == user_temp - assert session._agent_temp_uri == agent_temp - assert session._temp_created_at is not None - - def test_temp_created_at_is_recent(self): - session = _make_session() - before = time.time() - session._create_temp_uris() - after = time.time() - - assert before <= session._temp_created_at <= after - - def test_commit_uuid_is_8_chars(self): - session = _make_session() - temp_base, _, _, _ = session._create_temp_uris() - - commit_part = temp_base.split("/commit_")[-1] - assert len(commit_part) == 8 - assert all(c in "0123456789abcdef" for c in commit_part) - - def test_multiple_calls_generate_different_uuids(self): - session = _make_session() - temp_base1, _, _, _ = session._create_temp_uris() - temp_base2, _, _, _ = session._create_temp_uris() - - assert temp_base1 != temp_base2 - - -class TestCleanupTempUris: - """Tests for _cleanup_temp_uris() method.""" - - @pytest.mark.asyncio - async def test_calls_delete_temp_on_viking_fs(self): - viking_fs = MagicMock() - viking_fs.delete_temp = AsyncMock() - session = _make_session(viking_fs=viking_fs) - - session._create_temp_uris() - saved_temp_base = session._temp_base_uri - await session._cleanup_temp_uris() - - viking_fs.delete_temp.assert_called_once() - call_args = viking_fs.delete_temp.call_args - assert call_args[0][0] == saved_temp_base - - @pytest.mark.asyncio - async def test_resets_internal_state(self): - viking_fs = MagicMock() - viking_fs.delete_temp = AsyncMock() - session = _make_session(viking_fs=viking_fs) - - session._create_temp_uris() - await session._cleanup_temp_uris() - - assert session._temp_base_uri is None - assert session._session_temp_uri is None - assert session._user_temp_uri is None - assert session._agent_temp_uri is None - assert session._temp_created_at is None - - @pytest.mark.asyncio - async def test_no_cleanup_when_no_temp_uri(self): - viking_fs = MagicMock() - viking_fs.delete_temp = AsyncMock() - session = _make_session(viking_fs=viking_fs) - - await session._cleanup_temp_uris() - - viking_fs.delete_temp.assert_not_called() - - @pytest.mark.asyncio - async def test_handles_delete_exception(self): - viking_fs = MagicMock() - viking_fs.delete_temp = AsyncMock(side_effect=Exception("Delete failed")) - session = _make_session(viking_fs=viking_fs) - - session._create_temp_uris() - await session._cleanup_temp_uris() - - assert session._temp_base_uri is None - - @pytest.mark.asyncio - async def test_passes_ctx_to_delete_temp(self): - viking_fs = MagicMock() - viking_fs.delete_temp = AsyncMock() - session = _make_session(viking_fs=viking_fs) - - session._create_temp_uris() - await session._cleanup_temp_uris() - - call_kwargs = viking_fs.delete_temp.call_args[1] - assert "ctx" in call_kwargs - assert call_kwargs["ctx"] == session.ctx - - -class TestEnqueueToSemanticQueue: - """Tests for _enqueue_to_semantic_queue() method.""" - - @pytest.mark.asyncio - async def test_returns_list_of_three_msg_ids(self): - session = _make_session() - - mock_queue = MagicMock() - mock_queue.enqueue = AsyncMock() - - mock_queue_manager = MagicMock() - mock_queue_manager.SEMANTIC = "semantic" - mock_queue_manager.get_queue = MagicMock(return_value=mock_queue) - - with patch( - "openviking.storage.queuefs.get_queue_manager", - return_value=mock_queue_manager, - ): 
- result = await session._enqueue_to_semantic_queue( - session_temp_uri="viking://temp/session/test", - user_temp_uri="viking://temp/user/test", - agent_temp_uri="viking://temp/agent/test", - ) - - assert isinstance(result, list) - assert len(result) == 3 - - @pytest.mark.asyncio - async def test_session_msg_has_correct_target_uri(self): - session = _make_session(session_id="sess_xyz") - user = session.user - - enqueued_msgs = [] - - async def capture_enqueue(msg): - enqueued_msgs.append(msg) - - mock_queue = MagicMock() - mock_queue.enqueue = capture_enqueue - - mock_queue_manager = MagicMock() - mock_queue_manager.SEMANTIC = "semantic" - mock_queue_manager.get_queue = MagicMock(return_value=mock_queue) - - session_temp = f"viking://temp/session/{user.user_space_name()}/sess_xyz/commit_abc123/session/{user.user_space_name()}/sess_xyz" - - with patch( - "openviking.storage.queuefs.get_queue_manager", - return_value=mock_queue_manager, - ): - await session._enqueue_to_semantic_queue( - session_temp_uri=session_temp, - user_temp_uri="viking://temp/user/test", - agent_temp_uri="viking://temp/agent/test", - ) - - session_msg = enqueued_msgs[0] - expected_target = f"viking://session/{user.user_space_name()}/sess_xyz" - assert session_msg.target_uri == expected_target - assert session_msg.uri == session_temp - - @pytest.mark.asyncio - async def test_user_msg_has_correct_target_uri(self): - session = _make_session() - user = session.user - - enqueued_msgs = [] - - async def capture_enqueue(msg): - enqueued_msgs.append(msg) - - mock_queue = MagicMock() - mock_queue.enqueue = capture_enqueue - - mock_queue_manager = MagicMock() - mock_queue_manager.SEMANTIC = "semantic" - mock_queue_manager.get_queue = MagicMock(return_value=mock_queue) - - user_temp = f"viking://temp/session/{user.user_space_name()}/sess/commit_abc/user/{user.user_space_name()}" - - with patch( - "openviking.storage.queuefs.get_queue_manager", - return_value=mock_queue_manager, - ): - await session._enqueue_to_semantic_queue( - session_temp_uri="viking://temp/session/test", - user_temp_uri=user_temp, - agent_temp_uri="viking://temp/agent/test", - ) - - user_msg = enqueued_msgs[1] - expected_target = f"viking://user/{user.user_space_name()}" - assert user_msg.target_uri == expected_target - assert user_msg.uri == user_temp - - @pytest.mark.asyncio - async def test_agent_msg_has_correct_target_uri(self): - session = _make_session() - user = session.user - - enqueued_msgs = [] - - async def capture_enqueue(msg): - enqueued_msgs.append(msg) - - mock_queue = MagicMock() - mock_queue.enqueue = capture_enqueue - - mock_queue_manager = MagicMock() - mock_queue_manager.SEMANTIC = "semantic" - mock_queue_manager.get_queue = MagicMock(return_value=mock_queue) - - agent_temp = f"viking://temp/session/{user.user_space_name()}/sess/commit_abc/agent/{user.agent_space_name()}" - - with patch( - "openviking.storage.queuefs.get_queue_manager", - return_value=mock_queue_manager, - ): - await session._enqueue_to_semantic_queue( - session_temp_uri="viking://temp/session/test", - user_temp_uri="viking://temp/user/test", - agent_temp_uri=agent_temp, - ) - - agent_msg = enqueued_msgs[2] - expected_target = f"viking://agent/{user.agent_space_name()}" - assert agent_msg.target_uri == expected_target - assert agent_msg.uri == agent_temp - - @pytest.mark.asyncio - async def test_all_msgs_have_context_type_memory(self): - session = _make_session() - - enqueued_msgs = [] - - async def capture_enqueue(msg): - enqueued_msgs.append(msg) - - mock_queue = 
MagicMock() - mock_queue.enqueue = capture_enqueue - - mock_queue_manager = MagicMock() - mock_queue_manager.SEMANTIC = "semantic" - mock_queue_manager.get_queue = MagicMock(return_value=mock_queue) - - with patch( - "openviking.storage.queuefs.get_queue_manager", - return_value=mock_queue_manager, - ): - await session._enqueue_to_semantic_queue( - session_temp_uri="viking://temp/session/test", - user_temp_uri="viking://temp/user/test", - agent_temp_uri="viking://temp/agent/test", - ) - - for msg in enqueued_msgs: - assert msg.context_type == "memory" - - @pytest.mark.asyncio - async def test_all_msgs_have_recursive_true(self): - session = _make_session() - - enqueued_msgs = [] - - async def capture_enqueue(msg): - enqueued_msgs.append(msg) - - mock_queue = MagicMock() - mock_queue.enqueue = capture_enqueue - - mock_queue_manager = MagicMock() - mock_queue_manager.SEMANTIC = "semantic" - mock_queue_manager.get_queue = MagicMock(return_value=mock_queue) - - with patch( - "openviking.storage.queuefs.get_queue_manager", - return_value=mock_queue_manager, - ): - await session._enqueue_to_semantic_queue( - session_temp_uri="viking://temp/session/test", - user_temp_uri="viking://temp/user/test", - agent_temp_uri="viking://temp/agent/test", - ) - - for msg in enqueued_msgs: - assert msg.recursive is True - - @pytest.mark.asyncio - async def test_msgs_have_correct_user_context(self): - user = _make_user() - session = _make_session() - - enqueued_msgs = [] - - async def capture_enqueue(msg): - enqueued_msgs.append(msg) - - mock_queue = MagicMock() - mock_queue.enqueue = capture_enqueue - - mock_queue_manager = MagicMock() - mock_queue_manager.SEMANTIC = "semantic" - mock_queue_manager.get_queue = MagicMock(return_value=mock_queue) - - with patch( - "openviking.storage.queuefs.get_queue_manager", - return_value=mock_queue_manager, - ): - await session._enqueue_to_semantic_queue( - session_temp_uri="viking://temp/session/test", - user_temp_uri="viking://temp/user/test", - agent_temp_uri="viking://temp/agent/test", - ) - - for msg in enqueued_msgs: - assert msg.account_id == user.account_id - assert msg.user_id == user.user_id - assert msg.agent_id == user.agent_id - - -class TestTempUriStructureMatchesTarget: - """Tests for temp URI structure matching target URI structure.""" - - def test_session_temp_uri_contains_target_path(self): - session = _make_session(session_id="sess_123") - user = session.user - - temp_base, session_temp, _, _ = session._create_temp_uris() - - target_path = f"/session/{user.user_space_name()}/sess_123" - assert session_temp.endswith(target_path) - - def test_user_temp_uri_contains_target_path(self): - session = _make_session() - user = session.user - - temp_base, _, user_temp, _ = session._create_temp_uris() - - target_path = f"/user/{user.user_space_name()}" - assert user_temp.endswith(target_path) - - def test_agent_temp_uri_contains_target_path(self): - session = _make_session() - user = session.user - - temp_base, _, _, agent_temp = session._create_temp_uris() - - target_path = f"/agent/{user.agent_space_name()}" - assert agent_temp.endswith(target_path) - - def test_all_temp_uris_share_same_base(self): - session = _make_session() - - temp_base, session_temp, user_temp, agent_temp = session._create_temp_uris() - - assert session_temp.startswith(temp_base) - assert user_temp.startswith(temp_base) - assert agent_temp.startswith(temp_base) - - def test_temp_uri_structure_allows_semantic_dag_recursive_processing(self): - session = _make_session(session_id="sess_xyz") - user 
= session.user
-
-        temp_base, session_temp, user_temp, agent_temp = session._create_temp_uris()
-
-        assert "/session/" in session_temp
-        assert f"/{user.user_space_name()}/sess_xyz" in session_temp
-
-        assert "/user/" in user_temp
-        assert f"/{user.user_space_name()}" in user_temp
-
-        assert "/agent/" in agent_temp
-        assert f"/{user.agent_space_name()}" in agent_temp
-
-
-class TestTempUriWithDifferentUsers:
-    """Tests for temp URI generation with different user configurations."""
-
-    def test_different_user_space_names(self):
-        user1 = UserIdentifier("acc1", "alice", "agent1")
-        user2 = UserIdentifier("acc2", "bob", "agent2")
-
-        session1 = Session(
-            viking_fs=MagicMock(),
-            user=user1,
-            ctx=RequestContext(user=user1, role=Role.ROOT),
-            session_id="sess1",
-        )
-        session2 = Session(
-            viking_fs=MagicMock(),
-            user=user2,
-            ctx=RequestContext(user=user2, role=Role.ROOT),
-            session_id="sess2",
-        )
-
-        _, session_temp1, user_temp1, agent_temp1 = session1._create_temp_uris()
-        _, session_temp2, user_temp2, agent_temp2 = session2._create_temp_uris()
-
-        assert "alice" in session_temp1
-        assert "bob" in session_temp2
-        assert user_temp1 != user_temp2
-        assert agent_temp1 != agent_temp2
-
-    def test_agent_space_name_is_hashed(self):
-        user = UserIdentifier("acc", "myuser", "myagent")
-        session = Session(
-            viking_fs=MagicMock(),
-            user=user,
-            ctx=RequestContext(user=user, role=Role.ROOT),
-            session_id="sess",
-        )
-
-        _, _, _, agent_temp = session._create_temp_uris()
-
-        assert "myagent" not in agent_temp
-        assert user.agent_space_name() in agent_temp
-        assert len(user.agent_space_name()) == 12
diff --git a/tests/unit/storage/queuefs/__init__.py b/tests/unit/storage/queuefs/__init__.py
deleted file mode 100644
index e69de29b..00000000
diff --git a/tests/unit/storage/queuefs/test_dag_incremental.py b/tests/unit/storage/queuefs/test_dag_incremental.py
deleted file mode 100644
index 08082ce7..00000000
--- a/tests/unit/storage/queuefs/test_dag_incremental.py
+++ /dev/null
@@ -1,585 +0,0 @@
-# Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd.
-# SPDX-License-Identifier: Apache-2.0 -"""Unit tests for SemanticDagExecutor incremental update and content change detection.""" - -from unittest.mock import AsyncMock, MagicMock, patch - -import pytest - -from openviking.server.identity import RequestContext, Role -from openviking.storage.queuefs.semantic_dag import SemanticDagExecutor -from openviking_cli.session.user_id import UserIdentifier - - -@pytest.fixture -def mock_processor(): - """Create a mock SemanticProcessor.""" - processor = MagicMock() - processor._generate_single_file_summary = AsyncMock( - return_value={"name": "test.py", "summary": "test summary"} - ) - processor._generate_overview = AsyncMock(return_value="test overview") - processor._extract_abstract_from_overview = MagicMock(return_value="test abstract") - processor._vectorize_single_file = AsyncMock() - processor._vectorize_directory = AsyncMock() - return processor - - -@pytest.fixture -def mock_viking_fs(): - """Create a mock VikingFS.""" - fs = MagicMock() - fs.ls = AsyncMock(return_value=[]) - fs.read_file = AsyncMock(return_value="") - fs.write_file = AsyncMock() - fs._get_vector_store = MagicMock(return_value=None) - return fs - - -@pytest.fixture -def mock_vector_store(): - """Create a mock VectorStore.""" - store = MagicMock() - store.get_context_by_uri = AsyncMock(return_value=[]) - return store - - -@pytest.fixture -def mock_context(): - """Create a mock RequestContext.""" - user = MagicMock(spec=UserIdentifier) - user.account_id = "test_account" - user.user_id = "test_user" - return RequestContext(user=user, role=Role.USER) - - -@pytest.fixture -def executor(mock_processor, mock_context, mock_viking_fs): - """Create a SemanticDagExecutor instance for testing.""" - with patch( - "openviking.storage.queuefs.semantic_dag.get_viking_fs", return_value=mock_viking_fs - ): - executor = SemanticDagExecutor( - processor=mock_processor, - context_type="resource", - max_concurrent_llm=5, - ctx=mock_context, - incremental_update=True, - target_uri="viking://resource/target", - recursive=True, - ) - return executor - - -class TestGetTargetFilePath: - """Tests for _get_target_file_path() method.""" - - def test_returns_none_when_incremental_update_disabled( - self, mock_processor, mock_context, mock_viking_fs - ): - with patch( - "openviking.storage.queuefs.semantic_dag.get_viking_fs", return_value=mock_viking_fs - ): - executor = SemanticDagExecutor( - processor=mock_processor, - context_type="resource", - max_concurrent_llm=5, - ctx=mock_context, - incremental_update=False, - target_uri="viking://resource/target", - ) - executor._root_uri = "viking://resource/root" - result = executor._get_target_file_path("viking://resource/root/file.py") - assert result is None - - def test_returns_none_when_target_uri_is_none( - self, mock_processor, mock_context, mock_viking_fs - ): - with patch( - "openviking.storage.queuefs.semantic_dag.get_viking_fs", return_value=mock_viking_fs - ): - executor = SemanticDagExecutor( - processor=mock_processor, - context_type="resource", - max_concurrent_llm=5, - ctx=mock_context, - incremental_update=True, - target_uri=None, - ) - executor._root_uri = "viking://resource/root" - result = executor._get_target_file_path("viking://resource/root/file.py") - assert result is None - - def test_returns_none_when_root_uri_is_none(self, executor): - executor._root_uri = None - result = executor._get_target_file_path("viking://resource/root/file.py") - assert result is None - - def test_returns_target_path_for_file_in_root(self, executor): - 
executor._root_uri = "viking://resource/root" - result = executor._get_target_file_path("viking://resource/root/file.py") - assert result == "viking://resource/target/file.py" - - def test_returns_target_path_for_file_in_subdirectory(self, executor): - executor._root_uri = "viking://resource/root" - result = executor._get_target_file_path("viking://resource/root/subdir/file.py") - assert result == "viking://resource/target/subdir/file.py" - - def test_returns_target_uri_when_current_uri_equals_root(self, executor): - executor._root_uri = "viking://resource/root" - result = executor._get_target_file_path("viking://resource/root") - assert result == "viking://resource/target" - - def test_handles_nested_paths(self, executor): - executor._root_uri = "viking://resource/root" - result = executor._get_target_file_path("viking://resource/root/a/b/c/file.py") - assert result == "viking://resource/target/a/b/c/file.py" - - def test_handles_path_prefix_matching(self, executor): - executor._root_uri = "viking://resource/root" - result = executor._get_target_file_path("viking://resource/rootdir/file.py") - assert result == "viking://resource/target/dir/file.py" - - def test_returns_none_on_exception(self, executor): - executor._root_uri = "viking://resource/root" - result = executor._get_target_file_path(None) - assert result is None - - -class TestCheckFileContentChanged: - """Tests for _check_file_content_changed() method.""" - - @pytest.mark.asyncio - async def test_returns_true_when_target_path_is_none(self, executor): - executor._root_uri = None - result = await executor._check_file_content_changed("viking://resource/root/file.py") - assert result is True - - @pytest.mark.asyncio - async def test_returns_true_when_content_differs(self, executor, mock_viking_fs): - executor._viking_fs = mock_viking_fs - executor._root_uri = "viking://resource/root" - - mock_viking_fs.read_file = AsyncMock(side_effect=["current content", "target content"]) - - result = await executor._check_file_content_changed("viking://resource/root/file.py") - assert result is True - - @pytest.mark.asyncio - async def test_returns_false_when_content_identical(self, executor, mock_viking_fs): - executor._viking_fs = mock_viking_fs - executor._root_uri = "viking://resource/root" - - mock_viking_fs.read_file = AsyncMock(return_value="same content") - - result = await executor._check_file_content_changed("viking://resource/root/file.py") - assert result is False - - @pytest.mark.asyncio - async def test_returns_true_on_read_exception(self, executor, mock_viking_fs): - executor._viking_fs = mock_viking_fs - executor._root_uri = "viking://resource/root" - - mock_viking_fs.read_file = AsyncMock(side_effect=Exception("read error")) - - result = await executor._check_file_content_changed("viking://resource/root/file.py") - assert result is True - - @pytest.mark.asyncio - async def test_calls_read_file_with_correct_paths(self, executor, mock_viking_fs): - executor._viking_fs = mock_viking_fs - executor._root_uri = "viking://resource/root" - - mock_viking_fs.read_file = AsyncMock(return_value="content") - - await executor._check_file_content_changed("viking://resource/root/subdir/file.py") - - assert mock_viking_fs.read_file.call_count == 2 - calls = mock_viking_fs.read_file.call_args_list - assert calls[0][0][0] == "viking://resource/root/subdir/file.py" - assert calls[1][0][0] == "viking://resource/target/subdir/file.py" - - -class TestReadExistingSummary: - """Tests for _read_existing_summary() method.""" - - @pytest.mark.asyncio - 
async def test_returns_none_when_target_path_is_none(self, executor): - executor._root_uri = None - result = await executor._read_existing_summary("viking://resource/root/file.py") - assert result is None - - @pytest.mark.asyncio - async def test_returns_none_when_vector_store_is_none(self, executor, mock_viking_fs): - executor._viking_fs = mock_viking_fs - executor._root_uri = "viking://resource/root" - mock_viking_fs._get_vector_store = MagicMock(return_value=None) - - result = await executor._read_existing_summary("viking://resource/root/file.py") - assert result is None - - @pytest.mark.asyncio - async def test_returns_summary_dict_when_record_exists( - self, executor, mock_viking_fs, mock_vector_store - ): - executor._viking_fs = mock_viking_fs - executor._root_uri = "viking://resource/root" - mock_viking_fs._get_vector_store = MagicMock(return_value=mock_vector_store) - - mock_vector_store.get_context_by_uri = AsyncMock( - return_value=[{"abstract": "existing summary content"}] - ) - - result = await executor._read_existing_summary("viking://resource/root/subdir/file.py") - - assert result is not None - assert result["name"] == "file.py" - assert result["summary"] == "existing summary content" - - @pytest.mark.asyncio - async def test_returns_none_when_no_records(self, executor, mock_viking_fs, mock_vector_store): - executor._viking_fs = mock_viking_fs - executor._root_uri = "viking://resource/root" - mock_viking_fs._get_vector_store = MagicMock(return_value=mock_vector_store) - - mock_vector_store.get_context_by_uri = AsyncMock(return_value=[]) - - result = await executor._read_existing_summary("viking://resource/root/file.py") - assert result is None - - @pytest.mark.asyncio - async def test_returns_none_when_abstract_is_empty( - self, executor, mock_viking_fs, mock_vector_store - ): - executor._viking_fs = mock_viking_fs - executor._root_uri = "viking://resource/root" - mock_viking_fs._get_vector_store = MagicMock(return_value=mock_vector_store) - - mock_vector_store.get_context_by_uri = AsyncMock(return_value=[{"abstract": ""}]) - - result = await executor._read_existing_summary("viking://resource/root/file.py") - assert result is None - - @pytest.mark.asyncio - async def test_returns_none_on_exception(self, executor, mock_viking_fs, mock_vector_store): - executor._viking_fs = mock_viking_fs - executor._root_uri = "viking://resource/root" - mock_viking_fs._get_vector_store = MagicMock(return_value=mock_vector_store) - - mock_vector_store.get_context_by_uri = AsyncMock(side_effect=Exception("db error")) - - result = await executor._read_existing_summary("viking://resource/root/file.py") - assert result is None - - @pytest.mark.asyncio - async def test_calls_vector_store_with_correct_uri( - self, executor, mock_viking_fs, mock_vector_store, mock_context - ): - executor._viking_fs = mock_viking_fs - executor._root_uri = "viking://resource/root" - mock_viking_fs._get_vector_store = MagicMock(return_value=mock_vector_store) - - mock_vector_store.get_context_by_uri = AsyncMock(return_value=[{"abstract": "summary"}]) - - await executor._read_existing_summary("viking://resource/root/file.py") - - mock_vector_store.get_context_by_uri.assert_called_once_with( - account_id=mock_context.account_id, - uri="viking://resource/target/file.py", - limit=1, - ) - - -class TestCheckDirChildrenChanged: - """Tests for _check_dir_children_changed() method.""" - - @pytest.mark.asyncio - async def test_returns_true_when_target_path_is_none(self, executor): - executor._root_uri = None - result = await 
executor._check_dir_children_changed( - "viking://resource/root", ["file1.py"], ["dir1"] - ) - assert result is True - - @pytest.mark.asyncio - async def test_returns_false_when_children_identical(self, executor, mock_viking_fs): - executor._viking_fs = mock_viking_fs - executor._root_uri = "viking://resource/root" - - current_files = ["viking://resource/root/file1.py", "viking://resource/root/file2.py"] - current_dirs = ["viking://resource/root/dir1", "viking://resource/root/dir2"] - - executor._list_dir = AsyncMock( - side_effect=[ - ( - ["viking://resource/target/dir1", "viking://resource/target/dir2"], - ["viking://resource/target/file1.py", "viking://resource/target/file2.py"], - ), - ([], []), - ] - ) - mock_viking_fs.read_file = AsyncMock(return_value="same content") - - result = await executor._check_dir_children_changed( - "viking://resource/root", current_files, current_dirs - ) - assert result is False - - @pytest.mark.asyncio - async def test_returns_true_when_file_names_differ(self, executor, mock_viking_fs): - executor._viking_fs = mock_viking_fs - executor._root_uri = "viking://resource/root" - - current_files = ["viking://resource/root/file1.py"] - current_dirs = [] - - mock_viking_fs.ls = AsyncMock( - return_value=[ - {"name": "file2.py", "isDir": False}, - ] - ) - - result = await executor._check_dir_children_changed( - "viking://resource/root", current_files, current_dirs - ) - assert result is True - - @pytest.mark.asyncio - async def test_returns_true_when_dir_names_differ(self, executor, mock_viking_fs): - executor._viking_fs = mock_viking_fs - executor._root_uri = "viking://resource/root" - - current_files = [] - current_dirs = ["viking://resource/root/dir1"] - - mock_viking_fs.ls = AsyncMock( - return_value=[ - {"name": "dir2", "isDir": True}, - ] - ) - - result = await executor._check_dir_children_changed( - "viking://resource/root", current_files, current_dirs - ) - assert result is True - - @pytest.mark.asyncio - async def test_returns_true_when_file_content_changed(self, executor, mock_viking_fs): - executor._viking_fs = mock_viking_fs - executor._root_uri = "viking://resource/root" - - current_files = ["viking://resource/root/file1.py"] - current_dirs = [] - - mock_viking_fs.ls = AsyncMock( - side_effect=[ - [{"name": "file1.py", "isDir": False}], - [], - ] - ) - mock_viking_fs.read_file = AsyncMock(side_effect=["old content", "new content"]) - - result = await executor._check_dir_children_changed( - "viking://resource/root", current_files, current_dirs - ) - assert result is True - - @pytest.mark.asyncio - async def test_returns_true_on_exception(self, executor, mock_viking_fs): - executor._viking_fs = mock_viking_fs - executor._root_uri = "viking://resource/root" - - mock_viking_fs.ls = AsyncMock(side_effect=Exception("ls error")) - - result = await executor._check_dir_children_changed( - "viking://resource/root", ["file1.py"], [] - ) - assert result is True - - @pytest.mark.asyncio - async def test_handles_empty_directories(self, executor, mock_viking_fs): - executor._viking_fs = mock_viking_fs - executor._root_uri = "viking://resource/root" - - mock_viking_fs.ls = AsyncMock(return_value=[]) - - result = await executor._check_dir_children_changed("viking://resource/root", [], []) - assert result is False - - @pytest.mark.asyncio - async def test_ignores_dot_files_in_comparison(self, executor, mock_viking_fs): - executor._viking_fs = mock_viking_fs - executor._root_uri = "viking://resource/root" - - current_files = ["viking://resource/root/file1.py"] - 
current_dirs = [] - - executor._list_dir = AsyncMock( - return_value=( - [], - ["viking://resource/target/file1.py"], - ) - ) - - result = await executor._check_dir_children_changed( - "viking://resource/root", current_files, current_dirs - ) - assert result is False - - -class TestReadExistingOverviewAbstract: - """Tests for _read_existing_overview_abstract() method.""" - - @pytest.mark.asyncio - async def test_returns_none_tuple_when_target_path_is_none(self, executor): - executor._root_uri = None - result = await executor._read_existing_overview_abstract("viking://resource/root") - assert result == (None, None) - - @pytest.mark.asyncio - async def test_returns_overview_and_abstract_when_files_exist(self, executor, mock_viking_fs): - executor._viking_fs = mock_viking_fs - executor._root_uri = "viking://resource/root" - - mock_viking_fs.read_file = AsyncMock(side_effect=["overview content", "abstract content"]) - - result = await executor._read_existing_overview_abstract("viking://resource/root/dir") - - assert result == ("overview content", "abstract content") - calls = mock_viking_fs.read_file.call_args_list - assert calls[0][0][0] == "viking://resource/target/dir/.overview.md" - assert calls[1][0][0] == "viking://resource/target/dir/.abstract.md" - - @pytest.mark.asyncio - async def test_returns_none_tuple_on_exception(self, executor, mock_viking_fs): - executor._viking_fs = mock_viking_fs - executor._root_uri = "viking://resource/root" - - mock_viking_fs.read_file = AsyncMock(side_effect=Exception("read error")) - - result = await executor._read_existing_overview_abstract("viking://resource/root/dir") - assert result == (None, None) - - @pytest.mark.asyncio - async def test_handles_missing_overview_file(self, executor, mock_viking_fs): - executor._viking_fs = mock_viking_fs - executor._root_uri = "viking://resource/root" - - mock_viking_fs.read_file = AsyncMock( - side_effect=[Exception("not found"), "abstract content"] - ) - - result = await executor._read_existing_overview_abstract("viking://resource/root/dir") - assert result == (None, None) - - @pytest.mark.asyncio - async def test_handles_missing_abstract_file(self, executor, mock_viking_fs): - executor._viking_fs = mock_viking_fs - executor._root_uri = "viking://resource/root" - - mock_viking_fs.read_file = AsyncMock( - side_effect=["overview content", Exception("not found")] - ) - - result = await executor._read_existing_overview_abstract("viking://resource/root/dir") - assert result == (None, None) - - @pytest.mark.asyncio - async def test_calls_read_file_with_context(self, executor, mock_viking_fs, mock_context): - executor._viking_fs = mock_viking_fs - executor._root_uri = "viking://resource/root" - - mock_viking_fs.read_file = AsyncMock(return_value="content") - - await executor._read_existing_overview_abstract("viking://resource/root/dir") - - for call in mock_viking_fs.read_file.call_args_list: - assert "ctx" in call[1] - assert call[1]["ctx"] == mock_context - - @pytest.mark.asyncio - async def test_handles_nested_directory_path(self, executor, mock_viking_fs): - executor._viking_fs = mock_viking_fs - executor._root_uri = "viking://resource/root" - - mock_viking_fs.read_file = AsyncMock(return_value="content") - - await executor._read_existing_overview_abstract("viking://resource/root/a/b/c") - - calls = mock_viking_fs.read_file.call_args_list - assert calls[0][0][0] == "viking://resource/target/a/b/c/.overview.md" - assert calls[1][0][0] == "viking://resource/target/a/b/c/.abstract.md" - - -class 
TestIncrementalUpdateIntegration:
-    """Integration tests for incremental update scenarios."""
-
-    @pytest.mark.asyncio
-    async def test_full_incremental_flow_no_changes(
-        self, executor, mock_viking_fs, mock_vector_store
-    ):
-        executor._viking_fs = mock_viking_fs
-        executor._root_uri = "viking://resource/root"
-        mock_viking_fs._get_vector_store = MagicMock(return_value=mock_vector_store)
-
-        mock_viking_fs.ls = AsyncMock(return_value=[])
-        mock_viking_fs.read_file = AsyncMock(return_value="same content")
-        mock_vector_store.get_context_by_uri = AsyncMock(
-            return_value=[{"abstract": "existing summary"}]
-        )
-
-        content_changed = await executor._check_file_content_changed(
-            "viking://resource/root/file.py"
-        )
-        assert content_changed is False
-
-        summary = await executor._read_existing_summary("viking://resource/root/file.py")
-        assert summary is not None
-        assert summary["summary"] == "existing summary"
-
-    @pytest.mark.asyncio
-    async def test_full_incremental_flow_with_changes(
-        self, executor, mock_viking_fs, mock_vector_store
-    ):
-        executor._viking_fs = mock_viking_fs
-        executor._root_uri = "viking://resource/root"
-        mock_viking_fs._get_vector_store = MagicMock(return_value=mock_vector_store)
-
-        mock_viking_fs.read_file = AsyncMock(side_effect=["new content", "old content"])
-
-        content_changed = await executor._check_file_content_changed(
-            "viking://resource/root/file.py"
-        )
-        assert content_changed is True
-
-    @pytest.mark.asyncio
-    async def test_directory_change_detection_flow(self, executor, mock_viking_fs):
-        executor._viking_fs = mock_viking_fs
-        executor._root_uri = "viking://resource/root"
-
-        executor._list_dir = AsyncMock(
-            side_effect=[
-                ([], ["viking://resource/target/file1.py", "viking://resource/target/file2.py"]),
-                ([], []),
-            ]
-        )
-        mock_viking_fs.read_file = AsyncMock(return_value="same content")
-
-        current_files = ["viking://resource/root/file1.py", "viking://resource/root/file2.py"]
-        changed = await executor._check_dir_children_changed(
-            "viking://resource/root", current_files, []
-        )
-        assert changed is False
-
-    @pytest.mark.asyncio
-    async def test_overview_abstract_read_flow(self, executor, mock_viking_fs):
-        executor._viking_fs = mock_viking_fs
-        executor._root_uri = "viking://resource/root"
-
-        mock_viking_fs.read_file = AsyncMock(side_effect=["existing overview", "existing abstract"])
-
-        overview, abstract = await executor._read_existing_overview_abstract(
-            "viking://resource/root/subdir"
-        )
-        assert overview == "existing overview"
-        assert abstract == "existing abstract"
diff --git a/tests/unit/storage/queuefs/test_embedding_msg.py b/tests/unit/storage/queuefs/test_embedding_msg.py
deleted file mode 100644
index 3311ae96..00000000
--- a/tests/unit/storage/queuefs/test_embedding_msg.py
+++ /dev/null
@@ -1,262 +0,0 @@
-# Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd.
diff --git a/tests/unit/storage/queuefs/test_embedding_msg.py b/tests/unit/storage/queuefs/test_embedding_msg.py
deleted file mode 100644
index 3311ae96..00000000
--- a/tests/unit/storage/queuefs/test_embedding_msg.py
+++ /dev/null
@@ -1,262 +0,0 @@
-# Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd.
-# SPDX-License-Identifier: Apache-2.0
-import json
-
-import pytest
-
-from openviking.storage.queuefs.embedding_msg import EmbeddingMsg
-
-
-class TestEmbeddingMsg:
-    """Unit tests for EmbeddingMsg class."""
-
-    def test_semantic_msg_id_serialization(self):
-        """Test semantic_msg_id field serialization via to_dict()."""
-        msg = EmbeddingMsg(
-            message="test message",
-            context_data={"key": "value"},
-            semantic_msg_id="semantic-123",
-        )
-        result = msg.to_dict()
-        assert result["semantic_msg_id"] == "semantic-123"
-        assert result["message"] == "test message"
-        assert result["context_data"] == {"key": "value"}
-        assert hasattr(msg, "id")
-        assert msg.id is not None
-
-    def test_semantic_msg_id_deserialization(self):
-        """Test semantic_msg_id field deserialization via from_dict()."""
-        data = {
-            "id": "test-id-123",
-            "message": "test message",
-            "context_data": {"key": "value"},
-            "semantic_msg_id": "semantic-456",
-        }
-        msg = EmbeddingMsg.from_dict(data)
-        assert msg.semantic_msg_id == "semantic-456"
-        assert msg.id == "test-id-123"
-        assert msg.message == "test message"
-        assert msg.context_data == {"key": "value"}
-
-    def test_from_dict_missing_semantic_msg_id_defaults_to_none(self):
-        """Test from_dict() compatibility with old format (missing semantic_msg_id)."""
-        data = {
-            "id": "test-id-789",
-            "message": "legacy message",
-            "context_data": {"legacy": True},
-        }
-        msg = EmbeddingMsg.from_dict(data)
-        assert msg.semantic_msg_id is None
-        assert msg.id == "test-id-789"
-        assert msg.message == "legacy message"
-
-    def test_from_dict_semantic_msg_id_explicit_none(self):
-        """Test from_dict() with explicit None for semantic_msg_id."""
-        data = {
-            "id": "test-id-none",
-            "message": "message with None",
-            "context_data": {},
-            "semantic_msg_id": None,
-        }
-        msg = EmbeddingMsg.from_dict(data)
-        assert msg.semantic_msg_id is None
-
-    def test_to_json_with_semantic_msg_id(self):
-        """Test to_json() method with semantic_msg_id."""
-        msg = EmbeddingMsg(
-            message="json test",
-            context_data={"json_key": "json_value"},
-            semantic_msg_id="semantic-json",
-        )
-        json_str = msg.to_json()
-        parsed = json.loads(json_str)
-        assert parsed["semantic_msg_id"] == "semantic-json"
-        assert parsed["message"] == "json test"
-        assert parsed["context_data"] == {"json_key": "json_value"}
-
-    def test_to_json_without_semantic_msg_id(self):
-        """Test to_json() method without semantic_msg_id (None)."""
-        msg = EmbeddingMsg(
-            message="json test no id",
-            context_data={"key": "value"},
-        )
-        json_str = msg.to_json()
-        parsed = json.loads(json_str)
-        assert parsed["semantic_msg_id"] is None
-
-    def test_from_json_with_semantic_msg_id(self):
-        """Test from_json() method with semantic_msg_id."""
-        json_str = json.dumps(
-            {
-                "id": "json-id-123",
-                "message": "from json",
-                "context_data": {"from": "json"},
-                "semantic_msg_id": "semantic-from-json",
-            }
-        )
-        msg = EmbeddingMsg.from_json(json_str)
-        assert msg.semantic_msg_id == "semantic-from-json"
-        assert msg.id == "json-id-123"
-        assert msg.message == "from json"
-
-    def test_from_json_missing_semantic_msg_id(self):
-        """Test from_json() with missing semantic_msg_id (backward compatibility)."""
-        json_str = json.dumps(
-            {
-                "id": "json-id-456",
-                "message": "legacy json",
-                "context_data": {"legacy": True},
-            }
-        )
-        msg = EmbeddingMsg.from_json(json_str)
-        assert msg.semantic_msg_id is None
-        assert msg.message == "legacy json"
-
-    def test_from_json_invalid_json_raises_value_error(self):
-        """Test from_json() raises ValueError for invalid JSON."""
-        with pytest.raises(ValueError, match="Invalid JSON string"):
-            EmbeddingMsg.from_json("not a valid json")
-
-    def test_message_field_string_type(self):
-        """Test message field with string type."""
-        msg = EmbeddingMsg(
-            message="simple string message",
-            context_data={},
-        )
-        assert isinstance(msg.message, str)
-        assert msg.message == "simple string message"
-
-    def test_message_field_list_of_dicts_type(self):
-        """Test message field with List[Dict] type."""
-        message_list = [
-            {"role": "user", "content": "Hello"},
-            {"role": "assistant", "content": "Hi there"},
-        ]
-        msg = EmbeddingMsg(
-            message=message_list,
-            context_data={"conversation": True},
-        )
-        assert isinstance(msg.message, list)
-        assert len(msg.message) == 2
-        assert msg.message[0]["role"] == "user"
-        assert msg.message[1]["content"] == "Hi there"
-
-    def test_message_list_serialization_deserialization(self):
-        """Test serialization and deserialization with List[Dict] message."""
-        message_list = [
-            {"role": "user", "content": "Question?"},
-            {"role": "assistant", "content": "Answer."},
-        ]
-        msg = EmbeddingMsg(
-            message=message_list,
-            context_data={"type": "qa"},
-            semantic_msg_id="qa-123",
-        )
-        json_str = msg.to_json()
-        restored = EmbeddingMsg.from_json(json_str)
-        assert isinstance(restored.message, list)
-        assert len(restored.message) == 2
-        assert restored.message[0]["role"] == "user"
-        assert restored.semantic_msg_id == "qa-123"
-
-    def test_id_auto_generated(self):
-        """Test that id is auto-generated as UUID."""
-        msg = EmbeddingMsg(
-            message="test",
-            context_data={},
-        )
-        assert msg.id is not None
-        assert len(msg.id) == 36
-        assert msg.id.count("-") == 4
-
-    def test_from_dict_preserves_or_generates_id(self):
-        """Test from_dict() preserves provided id or generates new one."""
-        data_with_id = {
-            "id": "preserved-id",
-            "message": "test",
-            "context_data": {},
-        }
-        msg = EmbeddingMsg.from_dict(data_with_id)
-        assert msg.id == "preserved-id"
-
-        data_without_id = {
-            "message": "test",
-            "context_data": {},
-        }
-        msg = EmbeddingMsg.from_dict(data_without_id)
-        assert msg.id is not None
-        assert len(msg.id) == 36
-
-    def test_empty_context_data(self):
-        """Test with empty context_data."""
-        msg = EmbeddingMsg(
-            message="test",
-            context_data={},
-            semantic_msg_id="empty-ctx",
-        )
-        result = msg.to_dict()
-        assert result["context_data"] == {}
-        assert result["semantic_msg_id"] == "empty-ctx"
-
-    def test_complex_context_data(self):
-        """Test with complex nested context_data."""
-        complex_data = {
-            "nested": {
-                "level1": {
-                    "level2": ["a", "b", "c"],
-                },
-            },
-            "list": [1, 2, 3],
-            "string": "value",
-        }
-        msg = EmbeddingMsg(
-            message="complex test",
-            context_data=complex_data,
-            semantic_msg_id="complex-123",
-        )
-        json_str = msg.to_json()
-        restored = EmbeddingMsg.from_json(json_str)
-        assert restored.context_data["nested"]["level1"]["level2"] == ["a", "b", "c"]
-        assert restored.semantic_msg_id == "complex-123"
-
-    def test_semantic_msg_id_empty_string(self):
-        """Test semantic_msg_id with empty string."""
-        msg = EmbeddingMsg(
-            message="test",
-            context_data={},
-            semantic_msg_id="",
-        )
-        assert msg.semantic_msg_id == ""
-        result = msg.to_dict()
-        assert result["semantic_msg_id"] == ""
-
-    def test_roundtrip_string_message(self):
-        """Test complete roundtrip with string message."""
-        original = EmbeddingMsg(
-            message="roundtrip test",
-            context_data={"key": "value"},
-            semantic_msg_id="roundtrip-id",
-        )
-        json_str = original.to_json()
-        restored = EmbeddingMsg.from_json(json_str)
-        assert restored.message == original.message
-        assert restored.context_data == original.context_data
-        assert restored.semantic_msg_id == original.semantic_msg_id
-        assert restored.id is not None
-        assert len(restored.id) == 36
-
-    def test_roundtrip_list_message(self):
-        """Test complete roundtrip with List[Dict] message."""
-        message_list = [
-            {"type": "text", "content": "part1"},
-            {"type": "code", "content": "print('hello')"},
-        ]
-        original = EmbeddingMsg(
-            message=message_list,
-            context_data={"format": "mixed"},
-            semantic_msg_id="list-msg-id",
-        )
-        json_str = original.to_json()
-        restored = EmbeddingMsg.from_json(json_str)
-        assert restored.message == original.message
-        assert restored.semantic_msg_id == "list-msg-id"
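
For quick reference, the EmbeddingMsg behavior the deleted file covered reduces to a dict/JSON round-trip with an optional semantic_msg_id. A minimal sketch using only calls exercised by the tests above (field values illustrative):

    from openviking.storage.queuefs.embedding_msg import EmbeddingMsg

    msg = EmbeddingMsg(
        message=[{"role": "user", "content": "Hello"}],  # str or list of dicts
        context_data={"conversation": True},
        semantic_msg_id="semantic-123",  # optional; old payloads restore as None
    )
    restored = EmbeddingMsg.from_json(msg.to_json())
    assert restored.semantic_msg_id == "semantic-123"
    assert len(restored.id) == 36  # id is an auto-generated UUID when not provided
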
diff --git a/tests/unit/storage/queuefs/test_embedding_tracker.py b/tests/unit/storage/queuefs/test_embedding_tracker.py
deleted file mode 100644
index 1850fef0..00000000
--- a/tests/unit/storage/queuefs/test_embedding_tracker.py
+++ /dev/null
@@ -1,554 +0,0 @@
-# Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd.
-# SPDX-License-Identifier: Apache-2.0
-
-"""Unit tests for EmbeddingTaskTracker."""
-
-import asyncio
-
-import pytest
-
-from openviking.storage.queuefs.embedding_tracker import EmbeddingTaskTracker
-
-
-def reset_singleton():
-    """Reset the singleton instance for testing."""
-    EmbeddingTaskTracker._instance = None
-
-
-@pytest.fixture(autouse=True)
-def clean_singleton():
-    """Reset singleton before and after each test."""
-    reset_singleton()
-    yield
-    reset_singleton()
-
-
-@pytest.fixture
-def tracker() -> EmbeddingTaskTracker:
-    """Create a fresh tracker instance for each test."""
-    return EmbeddingTaskTracker()
-
-
-# ── Singleton Pattern Tests ──
-
-
-def test_singleton_returns_same_instance():
-    """Test that get_instance() returns the same instance."""
-    instance1 = EmbeddingTaskTracker.get_instance()
-    instance2 = EmbeddingTaskTracker.get_instance()
-    assert instance1 is instance2
-
-
-def test_singleton_persists_across_calls():
-    """Test that singleton persists across multiple get_instance calls."""
-    instance1 = EmbeddingTaskTracker.get_instance()
-    instance2 = EmbeddingTaskTracker.get_instance()
-    instance3 = EmbeddingTaskTracker.get_instance()
-    assert instance1 is instance2 is instance3
-
-
-# ── Register Tests ──
-
-
-@pytest.mark.asyncio
-async def test_register_task(tracker: EmbeddingTaskTracker):
-    """Test registering a task with valid count."""
-    await tracker.register("msg-1", 5)
-    status = await tracker.get_status("msg-1")
-    assert status is not None
-    assert status["remaining"] == 5
-    assert status["total"] == 5
-
-
-@pytest.mark.asyncio
-async def test_register_with_metadata(tracker: EmbeddingTaskTracker):
-    """Test registering a task with metadata."""
-    metadata = {"key": "value", "count": 10}
-    await tracker.register("msg-2", 3, metadata=metadata)
-    status = await tracker.get_status("msg-2")
-    assert status is not None
-    assert status["metadata"] == metadata
-
-
-@pytest.mark.asyncio
-async def test_register_with_callback(tracker: EmbeddingTaskTracker):
-    """Test registering a task with on_complete callback."""
-    callback_called = []
-
-    async def on_complete():
-        callback_called.append(True)
-
-    await tracker.register("msg-3", 1, on_complete=on_complete)
-    await tracker.decrement("msg-3")
-    assert len(callback_called) == 1
-
-
-@pytest.mark.asyncio
-async def test_register_with_sync_callback(tracker: EmbeddingTaskTracker):
-    """Test registering a task with synchronous callback."""
-    callback_called = []
-
-    def on_complete():
-        callback_called.append(True)
-
-    await tracker.register("msg-4", 1, on_complete=on_complete)
-    await tracker.decrement("msg-4")
-    assert len(callback_called) == 1
-
-
-@pytest.mark.asyncio
-async def test_register_with_zero_count_does_nothing(tracker: EmbeddingTaskTracker):
-    """Test that registering with total_count=0 does nothing."""
-    await tracker.register("msg-5", 0)
-    status = await tracker.get_status("msg-5")
-    assert status is None
-
-
-@pytest.mark.asyncio
-async def test_register_with_negative_count_does_nothing(tracker: EmbeddingTaskTracker):
-    """Test that registering with negative total_count does nothing."""
-    await tracker.register("msg-6", -5)
-    status = await tracker.get_status("msg-6")
-    assert status is None
-
-
-@pytest.mark.asyncio
-async def test_register_overwrites_existing(tracker: EmbeddingTaskTracker):
-    """Test that registering with same ID overwrites existing entry."""
-    await tracker.register("msg-7", 5)
-    await tracker.register("msg-7", 10)
-    status = await tracker.get_status("msg-7")
-    assert status is not None
-    assert status["remaining"] == 10
-    assert status["total"] == 10
-
-
-# ── Increment Tests ──
-
-
-@pytest.mark.asyncio
-async def test_increment_existing_task(tracker: EmbeddingTaskTracker):
-    """Test incrementing an existing task."""
-    await tracker.register("msg-10", 5)
-    result = await tracker.increment("msg-10")
-    assert result == 6
-    status = await tracker.get_status("msg-10")
-    assert status["remaining"] == 6
-    assert status["total"] == 6
-
-
-@pytest.mark.asyncio
-async def test_increment_multiple_times(tracker: EmbeddingTaskTracker):
-    """Test incrementing a task multiple times."""
-    await tracker.register("msg-11", 2)
-    await tracker.increment("msg-11")
-    await tracker.increment("msg-11")
-    await tracker.increment("msg-11")
-    status = await tracker.get_status("msg-11")
-    assert status["remaining"] == 5
-    assert status["total"] == 5
-
-
-@pytest.mark.asyncio
-async def test_increment_nonexistent_task_returns_none(tracker: EmbeddingTaskTracker):
-    """Test incrementing a non-existent task returns None."""
-    result = await tracker.increment("nonexistent")
-    assert result is None
-
-
-# ── Decrement Tests ──
-
-
-@pytest.mark.asyncio
-async def test_decrement_existing_task(tracker: EmbeddingTaskTracker):
-    """Test decrementing an existing task."""
-    await tracker.register("msg-20", 5)
-    result = await tracker.decrement("msg-20")
-    assert result == 4
-    status = await tracker.get_status("msg-20")
-    assert status["remaining"] == 4
-
-
-@pytest.mark.asyncio
-async def test_decrement_multiple_times(tracker: EmbeddingTaskTracker):
-    """Test decrementing a task multiple times."""
-    await tracker.register("msg-21", 3)
-    await tracker.decrement("msg-21")
-    await tracker.decrement("msg-21")
-    status = await tracker.get_status("msg-21")
-    assert status["remaining"] == 1
-
-
-@pytest.mark.asyncio
-async def test_decrement_nonexistent_task_returns_none(tracker: EmbeddingTaskTracker):
-    """Test decrementing a non-existent task returns None."""
-    result = await tracker.decrement("nonexistent")
-    assert result is None
-
-
-@pytest.mark.asyncio
-async def test_decrement_to_zero_removes_task(tracker: EmbeddingTaskTracker):
-    """Test that decrementing to zero removes the task."""
-    await tracker.register("msg-22", 1)
-    result = await tracker.decrement("msg-22")
-    assert result == 0
-    status = await tracker.get_status("msg-22")
-    assert status is None
-
-
-@pytest.mark.asyncio
-async def test_decrement_triggers_callback_on_completion(tracker: EmbeddingTaskTracker):
-    """Test that callback is triggered when count reaches zero."""
-    callback_called = []
-
-    async def on_complete():
-        callback_called.append("async")
-
-    await tracker.register("msg-23", 2, on_complete=on_complete)
-    await tracker.decrement("msg-23")
-    assert len(callback_called) == 0
-    await tracker.decrement("msg-23")
-    assert len(callback_called) == 1
-    assert callback_called[0] == "async"
-
-
-@pytest.mark.asyncio
-async def test_decrement_sync_callback_on_completion(tracker: EmbeddingTaskTracker):
-    """Test that sync callback is triggered when count reaches zero."""
-    callback_called = []
-
-    def on_complete():
-        callback_called.append("sync")
-
-    await tracker.register("msg-24", 1, on_complete=on_complete)
-    await tracker.decrement("msg-24")
-    assert len(callback_called) == 1
-    assert callback_called[0] == "sync"
-
-
-@pytest.mark.asyncio
-async def test_decrement_callback_error_is_handled(tracker: EmbeddingTaskTracker):
-    """Test that callback errors are handled gracefully."""
-
-    async def on_complete():
-        raise ValueError("Callback error")
-
-    await tracker.register("msg-25", 1, on_complete=on_complete)
-    result = await tracker.decrement("msg-25")
-    assert result == 0
-
-
-@pytest.mark.asyncio
-async def test_decrement_below_zero_removes_task(tracker: EmbeddingTaskTracker):
-    """Test that decrementing below zero still removes the task."""
-    await tracker.register("msg-26", 1)
-    await tracker.decrement("msg-26")
-    status = await tracker.get_status("msg-26")
-    assert status is None
-
-
-# ── Get Status Tests ──
-
-
-@pytest.mark.asyncio
-async def test_get_status_existing_task(tracker: EmbeddingTaskTracker):
-    """Test getting status of an existing task."""
-    await tracker.register("msg-30", 5, metadata={"key": "value"})
-    status = await tracker.get_status("msg-30")
-    assert status is not None
-    assert status["remaining"] == 5
-    assert status["total"] == 5
-    assert status["metadata"] == {"key": "value"}
-
-
-@pytest.mark.asyncio
-async def test_get_status_nonexistent_task(tracker: EmbeddingTaskTracker):
-    """Test getting status of a non-existent task."""
-    status = await tracker.get_status("nonexistent")
-    assert status is None
-
-
-@pytest.mark.asyncio
-async def test_get_status_reflects_changes(tracker: EmbeddingTaskTracker):
-    """Test that get_status reflects increment/decrement changes."""
-    await tracker.register("msg-31", 5)
-    await tracker.increment("msg-31")
-    status = await tracker.get_status("msg-31")
-    assert status["remaining"] == 6
-    assert status["total"] == 6
-    await tracker.decrement("msg-31")
-    status = await tracker.get_status("msg-31")
-    assert status["remaining"] == 5
-
-
-# ── Remove Tests ──
-
-
-@pytest.mark.asyncio
-async def test_remove_existing_task(tracker: EmbeddingTaskTracker):
-    """Test removing an existing task."""
-    await tracker.register("msg-40", 5)
-    result = await tracker.remove("msg-40")
-    assert result is True
-    status = await tracker.get_status("msg-40")
-    assert status is None
-
-
-@pytest.mark.asyncio
-async def test_remove_nonexistent_task(tracker: EmbeddingTaskTracker):
-    """Test removing a non-existent task."""
-    result = await tracker.remove("nonexistent")
-    assert result is False
-
-
-@pytest.mark.asyncio
-async def test_remove_does_not_trigger_callback(tracker: EmbeddingTaskTracker):
-    """Test that remove does not trigger the on_complete callback."""
-    callback_called = []
-
-    async def on_complete():
-        callback_called.append(True)
-
-    await tracker.register("msg-41", 5, on_complete=on_complete)
-    await tracker.remove("msg-41")
-    assert len(callback_called) == 0
-
-
-# ── Get All Tracked Tests ──
-
-
-@pytest.mark.asyncio
-async def test_get_all_tracked_empty(tracker: EmbeddingTaskTracker):
-    """Test get_all_tracked when no tasks are registered."""
-    all_tasks = await tracker.get_all_tracked()
-    assert all_tasks == {}
-
-
-@pytest.mark.asyncio
-async def test_get_all_tracked_single_task(tracker: EmbeddingTaskTracker):
-    """Test get_all_tracked with a single task."""
-    await tracker.register("msg-50", 5, metadata={"key": "value"})
-    all_tasks = await tracker.get_all_tracked()
-    assert len(all_tasks) == 1
-    assert "msg-50" in all_tasks
-    assert all_tasks["msg-50"]["remaining"] == 5
-    assert all_tasks["msg-50"]["total"] == 5
-    assert all_tasks["msg-50"]["metadata"] == {"key": "value"}
-
-
-@pytest.mark.asyncio
-async def test_get_all_tracked_multiple_tasks(tracker: EmbeddingTaskTracker):
-    """Test get_all_tracked with multiple tasks."""
-    await tracker.register("msg-51", 3)
-    await tracker.register("msg-52", 5)
-    await tracker.register("msg-53", 7)
-    all_tasks = await tracker.get_all_tracked()
-    assert len(all_tasks) == 3
-    assert "msg-51" in all_tasks
-    assert "msg-52" in all_tasks
-    assert "msg-53" in all_tasks
-
-
-@pytest.mark.asyncio
-async def test_get_all_tracked_excludes_on_complete(tracker: EmbeddingTaskTracker):
-    """Test that get_all_tracked does not include on_complete callback."""
-    await tracker.register("msg-54", 5, on_complete=lambda: None)
-    all_tasks = await tracker.get_all_tracked()
-    assert "on_complete" not in all_tasks["msg-54"]
-
-
-@pytest.mark.asyncio
-async def test_get_all_tracked_returns_copy(tracker: EmbeddingTaskTracker):
-    """Test that get_all_tracked returns a copy, not internal state."""
-    await tracker.register("msg-55", 5)
-    all_tasks = await tracker.get_all_tracked()
-    all_tasks["msg-55"]["remaining"] = 999
-    status = await tracker.get_status("msg-55")
-    assert status["remaining"] == 5
-
-
-# ── Concurrency Tests ──
-
-
-@pytest.mark.asyncio
-async def test_concurrent_register(tracker: EmbeddingTaskTracker):
-    """Test concurrent register operations."""
-
-    async def register_task(msg_id: str):
-        await tracker.register(msg_id, 5)
-
-    await asyncio.gather(
-        register_task("msg-60"),
-        register_task("msg-61"),
-        register_task("msg-62"),
-    )
-    all_tasks = await tracker.get_all_tracked()
-    assert len(all_tasks) == 3
-
-
-@pytest.mark.asyncio
-async def test_concurrent_increment(tracker: EmbeddingTaskTracker):
-    """Test concurrent increment operations."""
-    await tracker.register("msg-70", 1)
-
-    async def increment_task():
-        await tracker.increment("msg-70")
-
-    await asyncio.gather(
-        increment_task(),
-        increment_task(),
-        increment_task(),
-    )
-    status = await tracker.get_status("msg-70")
-    assert status["remaining"] == 4
-    assert status["total"] == 4
-
-
-@pytest.mark.asyncio
-async def test_concurrent_decrement(tracker: EmbeddingTaskTracker):
-    """Test concurrent decrement operations."""
-    await tracker.register("msg-71", 3)
-
-    async def decrement_task():
-        await tracker.decrement("msg-71")
-
-    await asyncio.gather(
-        decrement_task(),
-        decrement_task(),
-        decrement_task(),
-    )
-    status = await tracker.get_status("msg-71")
-    assert status is None
-
-
-@pytest.mark.asyncio
-async def test_concurrent_mixed_operations(tracker: EmbeddingTaskTracker):
-    """Test concurrent mixed operations (increment and decrement)."""
-    await tracker.register("msg-72", 5)
-
-    async def increment():
-        await tracker.increment("msg-72")
-
-    async def decrement():
-        await tracker.decrement("msg-72")
-
-    await asyncio.gather(
-        increment(),
-        increment(),
-        decrement(),
-        decrement(),
-        decrement(),
-    )
-    status = await tracker.get_status("msg-72")
-    assert status is not None
-
-
-@pytest.mark.asyncio
-async def test_concurrent_register_and_decrement(tracker: EmbeddingTaskTracker):
-    """Test concurrent register and decrement operations."""
-    callback_called = []
-
-    async def on_complete():
-        callback_called.append(True)
-
-    await tracker.register("msg-73", 1, on_complete=on_complete)
-    await tracker.decrement("msg-73")
-    assert len(callback_called) == 1
-
-
-@pytest.mark.asyncio
-async def test_concurrent_callback_execution(tracker: EmbeddingTaskTracker):
-    """Test that callbacks are executed correctly under concurrency."""
-    callback_count = []
-
-    async def make_callback(msg_id: str):
-        async def on_complete():
-            callback_count.append(msg_id)
-
-        return on_complete
-
-    async def register_and_complete(msg_id: str):
-        callback = await make_callback(msg_id)
-        await tracker.register(msg_id, 1, on_complete=callback)
-        await tracker.decrement(msg_id)
-
-    await asyncio.gather(
-        register_and_complete("msg-80"),
-        register_and_complete("msg-81"),
-        register_and_complete("msg-82"),
-    )
-    assert len(callback_count) == 3
-
-
-# ── Edge Cases Tests ──
-
-
-@pytest.mark.asyncio
-async def test_multiple_decrements_to_zero(tracker: EmbeddingTaskTracker):
-    """Test multiple decrements that bring count to exactly zero."""
-    callback_called = []
-
-    async def on_complete():
-        callback_called.append(True)
-
-    await tracker.register("msg-90", 3, on_complete=on_complete)
-    await tracker.decrement("msg-90")
-    await tracker.decrement("msg-90")
-    assert len(callback_called) == 0
-    await tracker.decrement("msg-90")
-    assert len(callback_called) == 1
-
-
-@pytest.mark.asyncio
-async def test_decrement_after_increment(tracker: EmbeddingTaskTracker):
-    """Test decrement after increment maintains correct count."""
-    await tracker.register("msg-91", 2)
-    await tracker.increment("msg-91")
-    await tracker.decrement("msg-91")
-    status = await tracker.get_status("msg-91")
-    assert status["remaining"] == 2
-    assert status["total"] == 3
-
-
-@pytest.mark.asyncio
-async def test_empty_metadata(tracker: EmbeddingTaskTracker):
-    """Test that empty metadata is handled correctly."""
-    await tracker.register("msg-92", 5, metadata={})
-    status = await tracker.get_status("msg-92")
-    assert status["metadata"] == {}
-
-
-@pytest.mark.asyncio
-async def test_none_metadata(tracker: EmbeddingTaskTracker):
-    """Test that None metadata defaults to empty dict."""
-    await tracker.register("msg-93", 5, metadata=None)
-    status = await tracker.get_status("msg-93")
-    assert status["metadata"] == {}
-
-
-@pytest.mark.asyncio
-async def test_none_callback(tracker: EmbeddingTaskTracker):
-    """Test that None callback is handled correctly."""
-    await tracker.register("msg-94", 1, on_complete=None)
-    result = await tracker.decrement("msg-94")
-    assert result == 0
-
-
-@pytest.mark.asyncio
-async def test_large_count(tracker: EmbeddingTaskTracker):
-    """Test with large task count."""
-    large_count = 10000
-    await tracker.register("msg-95", large_count)
-    status = await tracker.get_status("msg-95")
-    assert status["remaining"] == large_count
-    assert status["total"] == large_count
-
-
-@pytest.mark.asyncio
-async def test_special_characters_in_id(tracker: EmbeddingTaskTracker):
-    """Test with special characters in semantic_msg_id."""
-    special_id = "msg-with-special_chars.123!@#$%"
-    await tracker.register(special_id, 5)
-    status = await tracker.get_status(special_id)
-    assert status is not None
-    assert status["remaining"] == 5
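
The tracker contract the deleted tests documented is a per-message countdown: register a total, decrement once per finished item, and the on_complete callback (sync or async) fires exactly once when the count reaches zero, after which the entry is gone. A usage sketch under those semantics (IDs and counts illustrative):

    import asyncio
    from openviking.storage.queuefs.embedding_tracker import EmbeddingTaskTracker

    async def main():
        tracker = EmbeddingTaskTracker.get_instance()  # process-wide singleton
        done = asyncio.Event()
        # Track three pending embedding items for one semantic message.
        await tracker.register("msg-1", 3, on_complete=done.set)
        for _ in range(3):
            await tracker.decrement("msg-1")  # entry is removed on reaching zero
        assert done.is_set()
        assert await tracker.get_status("msg-1") is None

    asyncio.run(main())
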
diff --git a/tests/unit/storage/queuefs/test_processor_incremental.py b/tests/unit/storage/queuefs/test_processor_incremental.py
deleted file mode 100644
index 01c9a57e..00000000
--- a/tests/unit/storage/queuefs/test_processor_incremental.py
+++ /dev/null
@@ -1,868 +0,0 @@
-# Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd.
-# SPDX-License-Identifier: Apache-2.0
-"""Unit tests for SemanticProcessor incremental update and diff calculation.
-
-Tests for:
-- _detect_file_type(): Detect file type based on extension
-- _collect_tree_info(): Collect directory tree information
-- _compute_diff(): Compute directory differences
-- _check_file_content_changed(): Check file content changes
-- _execute_sync_operations(): Execute sync operations
-- _create_sync_diff_callback(): Create sync diff callback
-"""
-
-from unittest.mock import AsyncMock, MagicMock, patch
-
-import pytest
-
-from openviking.parse.parsers.constants import (
-    FILE_TYPE_CODE,
-    FILE_TYPE_DOCUMENTATION,
-    FILE_TYPE_OTHER,
-)
-from openviking.server.identity import RequestContext, Role
-from openviking.storage.queuefs.semantic_processor import DiffResult, SemanticProcessor
-from openviking_cli.session.user_id import UserIdentifier
-
-
-class FakeVikingFS:
-    """Fake VikingFS for testing."""
-
-    def __init__(self):
-        self._tree = {}
-        self._file_contents = {}
-        self.deleted_files = []
-        self.deleted_dirs = []
-        self.moved_files = []
-        self.created_dirs = []
-
-    def set_tree(self, tree):
-        self._tree = tree
-
-    def set_file_contents(self, contents):
-        self._file_contents = contents
-
-    async def ls(self, uri, ctx=None):
-        return self._tree.get(uri.rstrip("/"), [])
-
-    async def read_file(self, uri, ctx=None):
-        return self._file_contents.get(uri, "")
-
-    async def rm(self, uri, recursive=False, ctx=None):
-        if recursive:
-            self.deleted_dirs.append(uri)
-        else:
-            self.deleted_files.append(uri)
-
-    async def mv(self, src, dst, ctx=None):
-        self.moved_files.append((src, dst))
-
-    async def mkdir(self, uri, exist_ok=True, ctx=None):
-        self.created_dirs.append(uri)
-
-
-@pytest.fixture
-def processor():
-    """Create a SemanticProcessor instance for testing."""
-    return SemanticProcessor(max_concurrent_llm=10)
-
-
-@pytest.fixture
-def fake_fs():
-    """Create a fake VikingFS instance."""
-    return FakeVikingFS()
-
-
-@pytest.fixture
-def ctx():
-    """Create a RequestContext for testing."""
-    return RequestContext(
-        user=UserIdentifier("test_account", "test_user", "test_agent"),
-        role=Role.USER,
-    )
-
-
-class TestDetectFileType:
-    """Test cases for _detect_file_type() method."""
-
-    def test_detect_python_file(self, processor):
-        result = processor._detect_file_type("main.py")
-        assert result == FILE_TYPE_CODE
-
-    def test_detect_javascript_file(self, processor):
-        result = processor._detect_file_type("app.js")
-        assert result == FILE_TYPE_CODE
-
-    def test_detect_typescript_file(self, processor):
-        result = processor._detect_file_type("utils.ts")
-        assert result == FILE_TYPE_CODE
-
-    def test_detect_java_file(self, processor):
-        result = processor._detect_file_type("Main.java")
-        assert result == FILE_TYPE_CODE
-
-    def test_detect_go_file(self, processor):
-        result = processor._detect_file_type("server.go")
-        assert result == FILE_TYPE_CODE
-
-    def test_detect_rust_file(self, processor):
-        result = processor._detect_file_type("main.rs")
-        assert result == FILE_TYPE_CODE
-
-    def test_detect_c_file(self, processor):
-        result = processor._detect_file_type("program.c")
-        assert result == FILE_TYPE_CODE
-
-    def test_detect_cpp_file(self, processor):
-        result = processor._detect_file_type("module.cpp")
-        assert result == FILE_TYPE_CODE
-
-    def test_detect_markdown_file(self, processor):
-        result = processor._detect_file_type("README.md")
-        assert result == FILE_TYPE_DOCUMENTATION
-
-    def test_detect_rst_file(self, processor):
-        result = processor._detect_file_type("docs.rst")
-        assert result == FILE_TYPE_DOCUMENTATION
-
-    def test_detect_txt_file(self, processor):
-        result = processor._detect_file_type("notes.txt")
-        assert result == FILE_TYPE_DOCUMENTATION
-
-    def test_detect_json_file(self, processor):
-        result = processor._detect_file_type("config.json")
-        assert result == FILE_TYPE_CODE
-
-    def test_detect_yaml_file(self, processor):
-        result = processor._detect_file_type("settings.yaml")
-        assert result == FILE_TYPE_CODE
-
-    def test_detect_unknown_extension(self, processor):
-        result = processor._detect_file_type("data.xyz")
-        assert result == FILE_TYPE_OTHER
-
-    def test_detect_no_extension(self, processor):
-        result = processor._detect_file_type("Makefile")
-        assert result == FILE_TYPE_OTHER
-
-    def test_detect_uppercase_extension(self, processor):
-        result = processor._detect_file_type("SCRIPT.PY")
-        assert result == FILE_TYPE_CODE
-
-    def test_detect_mixed_case_extension(self, processor):
-        result = processor._detect_file_type("ReadMe.Md")
-        assert result == FILE_TYPE_DOCUMENTATION
-
-    def test_detect_path_with_dots(self, processor):
-        result = processor._detect_file_type("src/utils/helper.py")
-        assert result == FILE_TYPE_CODE
-
-
-class TestCollectTreeInfo:
-    """Test cases for _collect_tree_info() method."""
-
-    @pytest.mark.asyncio
-    async def test_collect_empty_directory(self, processor, fake_fs, ctx):
-        fake_fs.set_tree({"viking://temp/empty": []})
-        processor._current_ctx = ctx
-
-        with patch(
-            "openviking.storage.queuefs.semantic_processor.get_viking_fs", return_value=fake_fs
-        ):
-            result = await processor._collect_tree_info("viking://temp/empty")
-
-        assert result == {"viking://temp/empty": ([], [])}
-
-    @pytest.mark.asyncio
-    async def test_collect_directory_with_files(self, processor, fake_fs, ctx):
-        fake_fs.set_tree(
-            {
-                "viking://temp/dir": [
-                    {"name": "file1.txt", "isDir": False},
-                    {"name": "file2.py", "isDir": False},
-                ]
-            }
-        )
-        processor._current_ctx = ctx
-
-        with patch(
-            "openviking.storage.queuefs.semantic_processor.get_viking_fs", return_value=fake_fs
-        ):
-            result = await processor._collect_tree_info("viking://temp/dir")
-
-        assert "viking://temp/dir" in result
-        sub_dirs, files = result["viking://temp/dir"]
-        assert sub_dirs == []
-        assert len(files) == 2
-        assert "viking://temp/dir/file1.txt" in files
-        assert "viking://temp/dir/file2.py" in files
-
-    @pytest.mark.asyncio
-    async def test_collect_directory_with_subdirs(self, processor, fake_fs, ctx):
-        fake_fs.set_tree(
-            {
-                "viking://temp/root": [
-                    {"name": "subdir1", "isDir": True},
-                    {"name": "subdir2", "isDir": True},
-                ],
-                "viking://temp/root/subdir1": [
-                    {"name": "file1.txt", "isDir": False},
-                ],
-                "viking://temp/root/subdir2": [
-                    {"name": "file2.txt", "isDir": False},
-                ],
-            }
-        )
-        processor._current_ctx = ctx
-
-        with patch(
-            "openviking.storage.queuefs.semantic_processor.get_viking_fs", return_value=fake_fs
-        ):
-            result = await processor._collect_tree_info("viking://temp/root")
-
-        assert "viking://temp/root" in result
-        assert "viking://temp/root/subdir1" in result
-        assert "viking://temp/root/subdir2" in result
-
-    @pytest.mark.asyncio
-    async def test_collect_nested_directories(self, processor, fake_fs, ctx):
-        fake_fs.set_tree(
-            {
-                "viking://temp/root": [
-                    {"name": "level1", "isDir": True},
-                ],
-                "viking://temp/root/level1": [
-                    {"name": "level2", "isDir": True},
-                ],
-                "viking://temp/root/level1/level2": [
-                    {"name": "deep_file.txt", "isDir": False},
-                ],
-            }
-        )
-        processor._current_ctx = ctx
-
-        with patch(
-            "openviking.storage.queuefs.semantic_processor.get_viking_fs", return_value=fake_fs
-        ):
-            result = await processor._collect_tree_info("viking://temp/root")
-
-        assert "viking://temp/root" in result
-        assert "viking://temp/root/level1" in result
-        assert "viking://temp/root/level1/level2" in result
-
-    @pytest.mark.asyncio
-    async def test_collect_skips_hidden_files(self, processor, fake_fs, ctx):
-        fake_fs.set_tree(
-            {
-                "viking://temp/dir": [
-                    {"name": ".hidden", "isDir": False},
-                    {"name": "visible.txt", "isDir": False},
-                ]
-            }
-        )
-        processor._current_ctx = ctx
-
-        with patch(
-            "openviking.storage.queuefs.semantic_processor.get_viking_fs", return_value=fake_fs
-        ):
-            result = await processor._collect_tree_info("viking://temp/dir")
-
-        _, files = result["viking://temp/dir"]
-        assert len(files) == 1
-        assert "viking://temp/dir/visible.txt" in files
-
-    @pytest.mark.asyncio
-    async def test_collect_skips_dot_and_dotdot(self, processor, fake_fs, ctx):
-        fake_fs.set_tree(
-            {
-                "viking://temp/dir": [
-                    {"name": ".", "isDir": True},
-                    {"name": "..", "isDir": True},
-                    {"name": "file.txt", "isDir": False},
-                ]
-            }
-        )
-        processor._current_ctx = ctx
-
-        with patch(
-            "openviking.storage.queuefs.semantic_processor.get_viking_fs", return_value=fake_fs
-        ):
-            result = await processor._collect_tree_info("viking://temp/dir")
-
-        sub_dirs, files = result["viking://temp/dir"]
-        assert sub_dirs == []
-        assert len(files) == 1
-
-    @pytest.mark.asyncio
-    async def test_collect_handles_ls_error(self, processor, fake_fs, ctx):
-        fake_fs.ls = AsyncMock(side_effect=Exception("LS error"))
-        processor._current_ctx = ctx
-
-        with patch(
-            "openviking.storage.queuefs.semantic_processor.get_viking_fs", return_value=fake_fs
-        ):
-            result = await processor._collect_tree_info("viking://temp/dir")
-
-        assert result == {}
-
-
-class TestComputeDiff:
-    """Test cases for _compute_diff() method."""
-
-    @pytest.mark.asyncio
-    async def test_compute_diff_no_changes(self, processor, fake_fs, ctx):
-        root_tree = {
-            "viking://temp/root": ([], ["viking://temp/root/file.txt"]),
-        }
-        target_tree = {
-            "viking://target/root": ([], ["viking://target/root/file.txt"]),
-        }
-        processor._current_ctx = ctx
-
-        with patch(
-            "openviking.storage.queuefs.semantic_processor.get_viking_fs", return_value=fake_fs
-        ):
-            diff = await processor._compute_diff(
-                root_tree, target_tree, "viking://temp/root", "viking://target/root"
-            )
-
-        assert diff.added_files == []
-        assert diff.deleted_files == []
-        assert diff.updated_files == []
-        assert diff.added_dirs == []
-        assert diff.deleted_dirs == []
-
-    @pytest.mark.asyncio
-    async def test_compute_diff_added_files(self, processor, fake_fs, ctx):
-        root_tree = {
-            "viking://temp/root": ([], ["viking://temp/root/new_file.txt"]),
-        }
-        target_tree = {
-            "viking://target/root": ([], []),
-        }
-        processor._current_ctx = ctx
-
-        with patch(
-            "openviking.storage.queuefs.semantic_processor.get_viking_fs", return_value=fake_fs
-        ):
-            diff = await processor._compute_diff(
-                root_tree, target_tree, "viking://temp/root", "viking://target/root"
-            )
-
-        assert len(diff.added_files) == 1
-        assert "viking://temp/root/new_file.txt" in diff.added_files
-        assert diff.deleted_files == []
-
-    @pytest.mark.asyncio
-    async def test_compute_diff_deleted_files(self, processor, fake_fs, ctx):
-        root_tree = {
-            "viking://temp/root": ([], []),
-        }
-        target_tree = {
-            "viking://target/root": ([], ["viking://target/root/old_file.txt"]),
-        }
-        processor._current_ctx = ctx
-
-        with patch(
-            "openviking.storage.queuefs.semantic_processor.get_viking_fs", return_value=fake_fs
-        ):
-            diff = await processor._compute_diff(
-                root_tree, target_tree, "viking://temp/root", "viking://target/root"
-            )
-
-        assert diff.added_files == []
-        assert len(diff.deleted_files) == 1
-        assert "viking://target/root/old_file.txt" in diff.deleted_files
-
-    @pytest.mark.asyncio
-    async def test_compute_diff_updated_files(self, processor, fake_fs, ctx):
-        root_tree = {
-            "viking://temp/root": ([], ["viking://temp/root/file.txt"]),
-        }
-        target_tree = {
-            "viking://target/root": ([], ["viking://target/root/file.txt"]),
-        }
-        fake_fs.set_file_contents(
-            {
-                "viking://temp/root/file.txt": "new content",
-                "viking://target/root/file.txt": "old content",
-            }
-        )
-        processor._current_ctx = ctx
-
-        with patch(
-            "openviking.storage.queuefs.semantic_processor.get_viking_fs", return_value=fake_fs
-        ):
-            diff = await processor._compute_diff(
-                root_tree, target_tree, "viking://temp/root", "viking://target/root"
-            )
-
-        assert len(diff.updated_files) == 1
-        assert "viking://temp/root/file.txt" in diff.updated_files
-
-    @pytest.mark.asyncio
-    async def test_compute_diff_unchanged_files(self, processor, fake_fs, ctx):
-        root_tree = {
-            "viking://temp/root": ([], ["viking://temp/root/file.txt"]),
-        }
-        target_tree = {
-            "viking://target/root": ([], ["viking://target/root/file.txt"]),
-        }
-        fake_fs.set_file_contents(
-            {
-                "viking://temp/root/file.txt": "same content",
-                "viking://target/root/file.txt": "same content",
-            }
-        )
-        processor._current_ctx = ctx
-
-        with patch(
-            "openviking.storage.queuefs.semantic_processor.get_viking_fs", return_value=fake_fs
-        ):
-            diff = await processor._compute_diff(
-                root_tree, target_tree, "viking://temp/root", "viking://target/root"
-            )
-
-        assert diff.updated_files == []
-
-    @pytest.mark.asyncio
-    async def test_compute_diff_added_dirs(self, processor, fake_fs, ctx):
-        root_tree = {
-            "viking://temp/root": (["viking://temp/root/new_dir"], []),
-            "viking://temp/root/new_dir": ([], []),
-        }
-        target_tree = {
-            "viking://target/root": ([], []),
-        }
-        processor._current_ctx = ctx
-
-        with patch(
-            "openviking.storage.queuefs.semantic_processor.get_viking_fs", return_value=fake_fs
-        ):
-            diff = await processor._compute_diff(
-                root_tree, target_tree, "viking://temp/root", "viking://target/root"
-            )
-
-        assert len(diff.added_dirs) == 1
-        assert "viking://temp/root/new_dir" in diff.added_dirs
-
-    @pytest.mark.asyncio
-    async def test_compute_diff_deleted_dirs(self, processor, fake_fs, ctx):
-        root_tree = {
-            "viking://temp/root": ([], []),
-        }
-        target_tree = {
-            "viking://target/root": (["viking://target/root/old_dir"], []),
-            "viking://target/root/old_dir": ([], []),
-        }
-        processor._current_ctx = ctx
-
-        with patch(
-            "openviking.storage.queuefs.semantic_processor.get_viking_fs", return_value=fake_fs
-        ):
-            diff = await processor._compute_diff(
-                root_tree, target_tree, "viking://temp/root", "viking://target/root"
-            )
-
-        assert len(diff.deleted_dirs) == 1
-        assert "viking://target/root/old_dir" in diff.deleted_dirs
-
-    @pytest.mark.asyncio
-    async def test_compute_diff_mixed_changes(self, processor, fake_fs, ctx):
-        root_tree = {
-            "viking://temp/root": (
-                ["viking://temp/root/new_dir"],
-                ["viking://temp/root/new_file.txt", "viking://temp/root/updated.txt"],
-            ),
-            "viking://temp/root/new_dir": ([], []),
-        }
-        target_tree = {
-            "viking://target/root": (
-                ["viking://target/root/old_dir"],
-                ["viking://target/root/updated.txt", "viking://target/root/deleted.txt"],
-            ),
-            "viking://target/root/old_dir": ([], []),
-        }
-        fake_fs.set_file_contents(
-            {
-                "viking://temp/root/updated.txt": "new content",
-                "viking://target/root/updated.txt": "old content",
-            }
-        )
-        processor._current_ctx = ctx
-
-        with patch(
-            "openviking.storage.queuefs.semantic_processor.get_viking_fs", return_value=fake_fs
-        ):
-            diff = await processor._compute_diff(
-                root_tree, target_tree, "viking://temp/root", "viking://target/root"
-            )
-
-        assert len(diff.added_files) == 1
-        assert len(diff.deleted_files) == 1
-        assert len(diff.updated_files) == 1
-        assert len(diff.added_dirs) == 1
-        assert len(diff.deleted_dirs) == 1
-
-
-class TestCheckFileContentChanged:
-    """Test cases for _check_file_content_changed() method."""
-
-    @pytest.mark.asyncio
-    async def test_content_changed(self, processor, fake_fs, ctx):
-        fake_fs.set_file_contents(
-            {
-                "viking://temp/file.txt": "new content",
-                "viking://target/file.txt": "old content",
-            }
-        )
-        processor._current_ctx = ctx
-
-        with patch(
-            "openviking.storage.queuefs.semantic_processor.get_viking_fs", return_value=fake_fs
-        ):
-            result = await processor._check_file_content_changed(
-                "viking://temp/file.txt", "viking://target/file.txt"
-            )
-
-        assert result is True
-
-    @pytest.mark.asyncio
-    async def test_content_unchanged(self, processor, fake_fs, ctx):
-        fake_fs.set_file_contents(
-            {
-                "viking://temp/file.txt": "same content",
-                "viking://target/file.txt": "same content",
-            }
-        )
-        processor._current_ctx = ctx
-
-        with patch(
-            "openviking.storage.queuefs.semantic_processor.get_viking_fs", return_value=fake_fs
-        ):
-            result = await processor._check_file_content_changed(
-                "viking://temp/file.txt", "viking://target/file.txt"
-            )
-
-        assert result is False
-
-    @pytest.mark.asyncio
-    async def test_content_changed_empty_files(self, processor, fake_fs, ctx):
-        fake_fs.set_file_contents(
-            {
-                "viking://temp/file.txt": "",
-                "viking://target/file.txt": "",
-            }
-        )
-        processor._current_ctx = ctx
-
-        with patch(
-            "openviking.storage.queuefs.semantic_processor.get_viking_fs", return_value=fake_fs
-        ):
-            result = await processor._check_file_content_changed(
-                "viking://temp/file.txt", "viking://target/file.txt"
-            )
-
-        assert result is False
-
-    @pytest.mark.asyncio
-    async def test_content_changed_one_empty(self, processor, fake_fs, ctx):
-        fake_fs.set_file_contents(
-            {
-                "viking://temp/file.txt": "content",
-                "viking://target/file.txt": "",
-            }
-        )
-        processor._current_ctx = ctx
-
-        with patch(
-            "openviking.storage.queuefs.semantic_processor.get_viking_fs", return_value=fake_fs
-        ):
-            result = await processor._check_file_content_changed(
-                "viking://temp/file.txt", "viking://target/file.txt"
-            )
-
-        assert result is True
-
-    @pytest.mark.asyncio
-    async def test_content_changed_on_exception(self, processor, fake_fs, ctx):
-        fake_fs.read_file = AsyncMock(side_effect=Exception("Read error"))
-        processor._current_ctx = ctx
-
-        with patch(
-            "openviking.storage.queuefs.semantic_processor.get_viking_fs", return_value=fake_fs
-        ):
-            result = await processor._check_file_content_changed(
-                "viking://temp/file.txt", "viking://target/file.txt"
-            )
-
-        assert result is True
-
-
-class TestExecuteSyncOperations:
-    """Test cases for _execute_sync_operations() method."""
-
-    @pytest.mark.asyncio
-    async def test_execute_delete_files(self, processor, fake_fs, ctx):
-        diff = DiffResult(
-            added_files=[],
-            deleted_files=["viking://target/deleted.txt"],
-            updated_files=[],
-            added_dirs=[],
-            deleted_dirs=[],
-        )
-        processor._current_ctx = ctx
-
-        with patch(
-            "openviking.storage.queuefs.semantic_processor.get_viking_fs", return_value=fake_fs
-        ):
-            await processor._execute_sync_operations(
-                diff, "viking://temp/root", "viking://target/root"
-            )
-
-        assert "viking://target/deleted.txt" in fake_fs.deleted_files
-
-    @pytest.mark.asyncio
-    async def test_execute_move_added_files(self, processor, fake_fs, ctx):
-        diff = DiffResult(
-            added_files=["viking://temp/root/new.txt"],
-            deleted_files=[],
-            updated_files=[],
-            added_dirs=[],
-            deleted_dirs=[],
-        )
-        processor._current_ctx = ctx
-
-        mock_viking_uri = MagicMock()
-        mock_viking_uri.parent = None
-
-        with patch(
-            "openviking.storage.queuefs.semantic_processor.get_viking_fs", return_value=fake_fs
-        ):
-            with patch(
-                "openviking.storage.queuefs.semantic_processor.VikingURI",
-                return_value=mock_viking_uri,
-            ):
-                await processor._execute_sync_operations(
-                    diff, "viking://temp/root", "viking://target/root"
-                )
-
-        assert ("viking://temp/root/new.txt", "viking://target/root/new.txt") in fake_fs.moved_files
-
-    @pytest.mark.asyncio
-    async def test_execute_move_updated_files(self, processor, fake_fs, ctx):
-        diff = DiffResult(
-            added_files=[],
-            deleted_files=[],
-            updated_files=["viking://temp/root/updated.txt"],
-            added_dirs=[],
-            deleted_dirs=[],
-        )
-        processor._current_ctx = ctx
-
-        mock_viking_uri = MagicMock()
-        mock_viking_uri.parent = None
-
-        with patch(
-            "openviking.storage.queuefs.semantic_processor.get_viking_fs", return_value=fake_fs
-        ):
-            with patch(
-                "openviking.storage.queuefs.semantic_processor.VikingURI",
-                return_value=mock_viking_uri,
-            ):
-                await processor._execute_sync_operations(
-                    diff, "viking://temp/root", "viking://target/root"
-                )
-
-        assert "viking://target/root/updated.txt" in fake_fs.deleted_files
-        assert (
-            "viking://temp/root/updated.txt",
-            "viking://target/root/updated.txt",
-        ) in fake_fs.moved_files
-
-    @pytest.mark.asyncio
-    async def test_execute_delete_dirs(self, processor, fake_fs, ctx):
-        diff = DiffResult(
-            added_files=[],
-            deleted_files=[],
-            updated_files=[],
-            added_dirs=[],
-            deleted_dirs=["viking://target/root/old_dir"],
-        )
-        processor._current_ctx = ctx
-
-        with patch(
-            "openviking.storage.queuefs.semantic_processor.get_viking_fs", return_value=fake_fs
-        ):
-            await processor._execute_sync_operations(
-                diff, "viking://temp/root", "viking://target/root"
-            )
-
-        assert "viking://target/root/old_dir" in fake_fs.deleted_dirs
-
-    @pytest.mark.asyncio
-    async def test_execute_delete_dirs_deepest_first(self, processor, fake_fs, ctx):
-        diff = DiffResult(
-            added_files=[],
-            deleted_files=[],
-            updated_files=[],
-            added_dirs=[],
-            deleted_dirs=[
-                "viking://target/root/level1",
-                "viking://target/root/level1/level2",
-                "viking://target/root/level1/level2/level3",
-            ],
-        )
-        processor._current_ctx = ctx
-
-        with patch(
-            "openviking.storage.queuefs.semantic_processor.get_viking_fs", return_value=fake_fs
-        ):
-            await processor._execute_sync_operations(
-                diff, "viking://temp/root", "viking://target/root"
-            )
-
-        assert len(fake_fs.deleted_dirs) == 3
-        deepest_first = sorted(fake_fs.deleted_dirs, key=lambda x: x.count("/"), reverse=True)
-        assert fake_fs.deleted_dirs == deepest_first
-
-    @pytest.mark.asyncio
-    async def test_execute_creates_parent_dirs(self, processor, fake_fs, ctx):
-        diff = DiffResult(
-            added_files=["viking://temp/root/subdir/new.txt"],
-            deleted_files=[],
-            updated_files=[],
-            added_dirs=[],
-            deleted_dirs=[],
-        )
-        processor._current_ctx = ctx
-
-        mock_parent = MagicMock()
-        mock_parent.uri = "viking://target/root/subdir"
-        mock_viking_uri = MagicMock()
-        mock_viking_uri.parent = mock_parent
-
-        with patch(
-            "openviking.storage.queuefs.semantic_processor.get_viking_fs", return_value=fake_fs
-        ):
-            with patch(
-                "openviking.storage.queuefs.semantic_processor.VikingURI",
-                return_value=mock_viking_uri,
-            ):
-                await processor._execute_sync_operations(
-                    diff, "viking://temp/root", "viking://target/root"
-                )
-
-        assert "viking://target/root/subdir" in fake_fs.created_dirs
-
-
-class TestCreateSyncDiffCallback:
-    """Test cases for _create_sync_diff_callback() method."""
-
-    @pytest.mark.asyncio
-    async def test_callback_returns_callable(self, processor):
-        callback = processor._create_sync_diff_callback(
-            "viking://temp/root", "viking://target/root"
-        )
-        assert callable(callback)
-
-    @pytest.mark.asyncio
-    async def test_callback_is_async(self, processor):
-        callback = processor._create_sync_diff_callback(
-            "viking://temp/root", "viking://target/root"
-        )
-        import asyncio
-
-        assert asyncio.iscoroutinefunction(callback)
-
-    @pytest.mark.asyncio
-    async def test_callback_collects_tree_info(self, processor, fake_fs, ctx):
-        fake_fs.set_tree(
-            {
-                "viking://temp/root": [
-                    {"name": "file.txt", "isDir": False},
-                ],
-                "viking://target/root": [
-                    {"name": "file.txt", "isDir": False},
-                ],
-            }
-        )
-        fake_fs.set_file_contents(
-            {
-                "viking://temp/root/file.txt": "content",
-                "viking://target/root/file.txt": "content",
-            }
-        )
-        processor._current_ctx = ctx
-
-        with patch(
-            "openviking.storage.queuefs.semantic_processor.get_viking_fs", return_value=fake_fs
-        ):
-            callback = processor._create_sync_diff_callback(
-                "viking://temp/root", "viking://target/root"
-            )
-            await callback()
-
-    @pytest.mark.asyncio
-    async def test_callback_handles_exception(self, processor, fake_fs, ctx):
-        fake_fs.ls = AsyncMock(side_effect=Exception("Test error"))
-        processor._current_ctx = ctx
-
-        with patch(
-            "openviking.storage.queuefs.semantic_processor.get_viking_fs", return_value=fake_fs
-        ):
-            callback = processor._create_sync_diff_callback(
-                "viking://temp/root", "viking://target/root"
-            )
-            await callback()
-
-    @pytest.mark.asyncio
-    async def test_callback_deletes_root_after_sync(self, processor, fake_fs, ctx):
-        fake_fs.set_tree(
-            {
-                "viking://temp/root": [],
-                "viking://target/root": [],
-            }
-        )
-        processor._current_ctx = ctx
-
-        with patch(
-            "openviking.storage.queuefs.semantic_processor.get_viking_fs", return_value=fake_fs
-        ):
-            callback = processor._create_sync_diff_callback(
-                "viking://temp/root", "viking://target/root"
-            )
-            await callback()
-
-        assert "viking://temp/root" in fake_fs.deleted_dirs
-
-
-class TestDiffResult:
-    """Test cases for DiffResult dataclass."""
-
-    def test_diff_result_default_values(self):
-        diff = DiffResult()
-        assert diff.added_files == []
-        assert diff.deleted_files == []
-        assert diff.updated_files == []
-        assert diff.added_dirs == []
-        assert diff.deleted_dirs == []
-
-    def test_diff_result_with_values(self):
-        diff = DiffResult(
-            added_files=["a.txt"],
-            deleted_files=["b.txt"],
-            updated_files=["c.txt"],
-            added_dirs=["dir1"],
-            deleted_dirs=["dir2"],
-        )
-        assert diff.added_files == ["a.txt"]
-        assert diff.deleted_files == ["b.txt"]
-        assert diff.updated_files == ["c.txt"]
-        assert diff.added_dirs == ["dir1"]
-        assert diff.deleted_dirs == ["dir2"]
-
-    def test_diff_result_modifiable(self):
-        diff = DiffResult()
-        diff.added_files.append("new.txt")
-        assert "new.txt" in diff.added_files
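
DiffResult itself is just five lists, and _compute_diff fills them by comparing the temp tree against the target: names only on the new side are additions, names only on the old side are deletions, and shared names become updates when their content differs. A reduced sketch of that set logic (keying by relative name and the content_differs predicate are assumptions for illustration, not quotes of the implementation):

    from openviking.storage.queuefs.semantic_processor import DiffResult

    def compute_diff_sketch(new_files, old_files, content_differs):
        # new_files/old_files: sets of relative file names on each side;
        # content_differs: hypothetical predicate for names on both sides.
        common = new_files & old_files
        return DiffResult(
            added_files=sorted(new_files - old_files),
            deleted_files=sorted(old_files - new_files),
            updated_files=sorted(n for n in common if content_differs(n)),
        )
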
diff --git a/tests/unit/storage/queuefs/test_semantic_msg.py b/tests/unit/storage/queuefs/test_semantic_msg.py
deleted file mode 100644
index 37551109..00000000
--- a/tests/unit/storage/queuefs/test_semantic_msg.py
+++ /dev/null
@@ -1,410 +0,0 @@
-# Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd.
-# SPDX-License-Identifier: Apache-2.0
-"""Unit tests for SemanticMsg dataclass, focusing on new fields target_uri and skip_vectorization."""
-
-import json
-
-import pytest
-
-from openviking.storage.queuefs.semantic_msg import SemanticMsg
-
-
-class TestTargetUriField:
-    """Tests for target_uri field serialization and deserialization."""
-
-    def test_target_uri_default_value(self):
-        msg = SemanticMsg(uri="viking://resource/test", context_type="resource")
-        assert msg.target_uri == ""
-
-    def test_target_uri_set_in_constructor(self):
-        msg = SemanticMsg(
-            uri="viking://resource/temp",
-            context_type="resource",
-            target_uri="viking://resource/target",
-        )
-        assert msg.target_uri == "viking://resource/target"
-
-    def test_target_uri_serialization_to_dict(self):
-        msg = SemanticMsg(
-            uri="viking://resource/temp",
-            context_type="resource",
-            target_uri="viking://resource/target",
-        )
-        data = msg.to_dict()
-        assert "target_uri" in data
-        assert data["target_uri"] == "viking://resource/target"
-
-    def test_target_uri_deserialization_from_dict(self):
-        data = {
-            "uri": "viking://resource/temp",
-            "context_type": "resource",
-            "target_uri": "viking://resource/target",
-        }
-        msg = SemanticMsg.from_dict(data)
-        assert msg.target_uri == "viking://resource/target"
-
-    def test_target_uri_empty_string_serialization(self):
-        msg = SemanticMsg(
-            uri="viking://resource/test",
-            context_type="resource",
-            target_uri="",
-        )
-        data = msg.to_dict()
-        assert data["target_uri"] == ""
-
-    def test_target_uri_with_memory_context(self):
-        msg = SemanticMsg(
-            uri="viking://memory/temp/session",
-            context_type="memory",
-            target_uri="viking://session/abc123",
-        )
-        data = msg.to_dict()
-        msg_restored = SemanticMsg.from_dict(data)
-        assert msg_restored.target_uri == "viking://session/abc123"
-
-
-class TestSkipVectorizationField:
-    """Tests for skip_vectorization field serialization and deserialization."""
-
-    def test_skip_vectorization_default_value(self):
-        msg = SemanticMsg(uri="viking://resource/test", context_type="resource")
-        assert msg.skip_vectorization is False
-
-    def test_skip_vectorization_set_true_in_constructor(self):
-        msg = SemanticMsg(
-            uri="viking://resource/test",
-            context_type="resource",
-            skip_vectorization=True,
-        )
-        assert msg.skip_vectorization is True
-
-    def test_skip_vectorization_serialization_to_dict(self):
-        msg = SemanticMsg(
-            uri="viking://resource/test",
-            context_type="resource",
-            skip_vectorization=True,
-        )
-        data = msg.to_dict()
-        assert "skip_vectorization" in data
-        assert data["skip_vectorization"] is True
-
-    def test_skip_vectorization_deserialization_from_dict(self):
-        data = {
-            "uri": "viking://resource/test",
-            "context_type": "resource",
-            "skip_vectorization": True,
-        }
-        msg = SemanticMsg.from_dict(data)
-        assert msg.skip_vectorization is True
-
-    def test_skip_vectorization_false_serialization(self):
-        msg = SemanticMsg(
-            uri="viking://resource/test",
-            context_type="resource",
-            skip_vectorization=False,
-        )
-        data = msg.to_dict()
-        assert data["skip_vectorization"] is False
-
-    def test_skip_vectorization_round_trip(self):
-        original = SemanticMsg(
-            uri="viking://resource/test",
-            context_type="resource",
-            skip_vectorization=True,
-        )
-        restored = SemanticMsg.from_dict(original.to_dict())
-        assert restored.skip_vectorization == original.skip_vectorization
-
-
-class TestFromDictBackwardCompatibility:
-    """Tests for from_dict() backward compatibility with old format missing new fields."""
-
-    def test_missing_target_uri_uses_default(self):
-        data = {
-            "uri": "viking://resource/test",
-            "context_type": "resource",
-        }
-        msg = SemanticMsg.from_dict(data)
-        assert msg.target_uri == ""
-
-    def test_missing_skip_vectorization_uses_default(self):
-        data = {
-            "uri": "viking://resource/test",
-            "context_type": "resource",
-        }
-        msg = SemanticMsg.from_dict(data)
-        assert msg.skip_vectorization is False
-
-    def test_missing_both_new_fields_uses_defaults(self):
-        data = {
-            "uri": "viking://resource/test",
-            "context_type": "resource",
-            "recursive": True,
-            "account_id": "test_account",
-        }
-        msg = SemanticMsg.from_dict(data)
-        assert msg.target_uri == ""
-        assert msg.skip_vectorization is False
-
-    def test_old_format_with_all_legacy_fields(self):
-        data = {
-            "id": "legacy-id-123",
-            "uri": "viking://resource/test",
-            "context_type": "resource",
-            "status": "pending",
-            "recursive": False,
-            "account_id": "account1",
-            "user_id": "user1",
-            "agent_id": "agent1",
-            "role": "admin",
-        }
-        msg = SemanticMsg.from_dict(data)
-        assert msg.id == "legacy-id-123"
-        assert msg.uri == "viking://resource/test"
-        assert msg.context_type == "resource"
-        assert msg.status == "pending"
-        assert msg.recursive is False
-        assert msg.target_uri == ""
-        assert msg.skip_vectorization is False
-
-    def test_partial_new_fields_only_target_uri(self):
-        data = {
-            "uri": "viking://resource/test",
-            "context_type": "resource",
-            "target_uri": "viking://resource/target",
-        }
-        msg = SemanticMsg.from_dict(data)
-        assert msg.target_uri == "viking://resource/target"
-        assert msg.skip_vectorization is False
-
-    def test_partial_new_fields_only_skip_vectorization(self):
-        data = {
-            "uri": "viking://resource/test",
-            "context_type": "resource",
-            "skip_vectorization": True,
-        }
-        msg = SemanticMsg.from_dict(data)
-        assert msg.target_uri == ""
-        assert msg.skip_vectorization is True
-
-
-class TestToJsonFromJson:
-    """Tests for to_json() and from_json() methods."""
-
-    def test_to_json_returns_valid_json_string(self):
-        msg = SemanticMsg(
-            uri="viking://resource/test",
-            context_type="resource",
-        )
-        json_str = msg.to_json()
-        assert isinstance(json_str, str)
-        parsed = json.loads(json_str)
-        assert isinstance(parsed, dict)
-
-    def test_from_json_creates_valid_object(self):
-        json_str = '{"uri": "viking://resource/test", "context_type": "resource"}'
-        msg = SemanticMsg.from_json(json_str)
-        assert msg.uri == "viking://resource/test"
-        assert msg.context_type == "resource"
-
-    def test_to_json_and_from_json_round_trip(self):
-        original = SemanticMsg(
-            uri="viking://resource/temp",
-            context_type="memory",
-            target_uri="viking://session/abc",
-            skip_vectorization=True,
-            recursive=False,
-            account_id="test_account",
-            user_id="test_user",
-            agent_id="test_agent",
-            role="admin",
-        )
-        json_str = original.to_json()
-        restored = SemanticMsg.from_json(json_str)
-
-        assert restored.uri == original.uri
-        assert restored.context_type == original.context_type
-        assert restored.target_uri == original.target_uri
-        assert restored.skip_vectorization == original.skip_vectorization
-        assert restored.recursive == original.recursive
-        assert restored.account_id == original.account_id
-        assert restored.user_id == original.user_id
-        assert restored.agent_id == original.agent_id
-        assert restored.role == original.role
-
-    def test_from_json_with_new_fields(self):
-        json_str = json.dumps(
-            {
-                "uri": "viking://resource/test",
-                "context_type": "resource",
-                "target_uri": "viking://resource/target",
-                "skip_vectorization": True,
-            }
-        )
-        msg = SemanticMsg.from_json(json_str)
-        assert msg.target_uri == "viking://resource/target"
-        assert msg.skip_vectorization is True
-
-    def test_from_json_without_new_fields(self):
-        json_str = json.dumps(
-            {
-                "uri": "viking://resource/test",
-                "context_type": "resource",
-            }
-        )
-        msg = SemanticMsg.from_json(json_str)
-        assert msg.target_uri == ""
-        assert msg.skip_vectorization is False
-
-    def test_from_json_invalid_json_raises_value_error(self):
-        with pytest.raises(ValueError, match="Invalid JSON string"):
-            SemanticMsg.from_json("not a valid json")
-
-    def test_from_json_missing_required_fields_raises_value_error(self):
-        json_str = '{"uri": "viking://resource/test"}'
-        with pytest.raises(ValueError, match="Missing required fields"):
-            SemanticMsg.from_json(json_str)
-
-
-class TestRequiredFieldsValidation:
-    """Tests for required field validation (uri and context_type)."""
-
-    def test_missing_uri_raises_value_error(self):
-        data = {"context_type": "resource"}
-        with pytest.raises(ValueError, match="Missing required fields"):
-            SemanticMsg.from_dict(data)
-
-    def test_missing_context_type_raises_value_error(self):
-        data = {"uri": "viking://resource/test"}
-        with pytest.raises(ValueError, match="Missing required fields"):
-            SemanticMsg.from_dict(data)
-
-    def test_missing_both_required_fields_raises_value_error(self):
-        data = {"target_uri": "viking://resource/target"}
-        with pytest.raises(ValueError, match="Missing required fields"):
-            SemanticMsg.from_dict(data)
-
-    def test_empty_uri_raises_value_error(self):
-        data = {"uri": "", "context_type": "resource"}
-        with pytest.raises(ValueError, match="Missing required fields"):
-            SemanticMsg.from_dict(data)
-
-    def test_empty_context_type_raises_value_error(self):
-        data = {"uri": "viking://resource/test", "context_type": ""}
-        with pytest.raises(ValueError, match="Missing required fields"):
-            SemanticMsg.from_dict(data)
-
-    def test_none_uri_raises_value_error(self):
-        data = {"uri": None, "context_type": "resource"}
-        with pytest.raises(ValueError, match="Missing required fields"):
-            SemanticMsg.from_dict(data)
-
-    def test_none_context_type_raises_value_error(self):
-        data = {"uri": "viking://resource/test", "context_type": None}
-        with pytest.raises(ValueError, match="Missing required fields"):
-            SemanticMsg.from_dict(data)
-
-    def test_empty_dict_raises_value_error(self):
-        with pytest.raises(ValueError, match="Data dictionary is empty"):
-            SemanticMsg.from_dict({})
-
-    def test_valid_minimal_data_succeeds(self):
-        data = {"uri": "viking://resource/test", "context_type": "resource"}
-        msg = SemanticMsg.from_dict(data)
-        assert msg.uri == "viking://resource/test"
-        assert msg.context_type == "resource"
-
-    def test_error_message_lists_all_missing_fields(self):
-        data = {"skip_vectorization": True}
-        with pytest.raises(ValueError) as exc_info:
-            SemanticMsg.from_dict(data)
-        error_msg = str(exc_info.value)
-        assert "uri" in error_msg
-        assert "context_type" in error_msg
-
-
-class TestEdgeCases:
-    """Tests for edge cases and boundary conditions."""
-
-    def test_target_uri_with_special_characters(self):
-        special_uri = "viking://resource/test%20space?query=value&other=123"
-        msg = SemanticMsg(
-            uri="viking://resource/temp",
-            context_type="resource",
-            target_uri=special_uri,
-        )
-        restored = SemanticMsg.from_dict(msg.to_dict())
-        assert restored.target_uri == special_uri
-
-    def test_target_uri_with_unicode(self):
-        unicode_uri = "viking://resource/测试/目录"
-        msg = SemanticMsg(
-            uri="viking://resource/temp",
-            context_type="resource",
-            target_uri=unicode_uri,
-        )
-        restored = SemanticMsg.from_dict(msg.to_dict())
-        assert restored.target_uri == unicode_uri
-
-    def test_target_uri_with_long_path(self):
-        long_path = "viking://resource/" + "/".join(["dir"] * 100)
-        msg = SemanticMsg(
-            uri="viking://resource/temp",
-            context_type="resource",
-            target_uri=long_path,
-        )
-        restored = SemanticMsg.from_dict(msg.to_dict())
-        assert restored.target_uri == long_path
-
-    def test_preserves_existing_id_in_from_dict(self):
-        original = SemanticMsg(
-            uri="viking://resource/test",
-            context_type="resource",
-        )
-        original_id = original.id
-        data = original.to_dict()
-        restored = SemanticMsg.from_dict(data)
-        assert restored.id == original_id
-
-    def test_from_dict_overwrites_id_if_provided(self):
-        data = {
-            "id": "custom-id-123",
-            "uri": "viking://resource/test",
-            "context_type": "resource",
-        }
-        msg = SemanticMsg.from_dict(data)
-        assert msg.id == "custom-id-123"
-
-    def test_from_dict_preserves_status_and_timestamp(self):
-        data = {
-            "uri": "viking://resource/test",
-            "context_type": "resource",
-            "status": "completed",
-            "timestamp": 1700000000,
-        }
-        msg = SemanticMsg.from_dict(data)
-        assert msg.status == "completed"
-        assert msg.timestamp == 1700000000
-
-    def test_all_context_types(self):
-        context_types = ["resource", "memory", "skill", "session"]
-        for ctx_type in context_types:
-            msg = SemanticMsg(
-                uri=f"viking://{ctx_type}/test",
-                context_type=ctx_type,
-            )
-            assert msg.context_type == ctx_type
-            restored = SemanticMsg.from_dict(msg.to_dict())
-            assert restored.context_type == ctx_type
-
-    def test_extra_fields_in_dict_are_ignored(self):
-        data = {
-            "uri": "viking://resource/test",
-            "context_type": "resource",
-            "extra_field": "should_be_ignored",
-            "another_extra": 12345,
-        }
-        msg = SemanticMsg.from_dict(data)
-        assert msg.uri == "viking://resource/test"
-        assert not hasattr(msg, "extra_field")
-        assert not hasattr(msg, "another_extra")
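
The SemanticMsg coverage above reduces to two newer optional fields with backward-compatible defaults; a compact round-trip sketch (URIs illustrative):

    from openviking.storage.queuefs.semantic_msg import SemanticMsg

    msg = SemanticMsg(
        uri="viking://resource/temp",
        context_type="resource",  # uri and context_type are the required fields
        target_uri="viking://resource/target",
        skip_vectorization=True,
    )
    restored = SemanticMsg.from_json(msg.to_json())
    assert restored.target_uri == "viking://resource/target"

    # Old-format payloads omit both fields and fall back to the defaults.
    legacy = SemanticMsg.from_dict({"uri": "viking://resource/x", "context_type": "resource"})
    assert legacy.target_uri == "" and legacy.skip_vectorization is False
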
diff --git a/tests/unit/storage/test_viking_fs_new.py b/tests/unit/storage/test_viking_fs_new.py
deleted file mode 100644
index 6661cc9f..00000000
--- a/tests/unit/storage/test_viking_fs_new.py
+++ /dev/null
@@ -1,203 +0,0 @@
-# Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd.
-# SPDX-License-Identifier: Apache-2.0
-"""Unit tests for VikingFS new methods.
-
-Tests for:
-- exists(): Check if URI exists
-- copy_directory(): Recursively copy directory
-- delete_temp(): Delete temporary directory
-"""
-
-import contextvars
-from unittest.mock import AsyncMock, MagicMock, patch
-
-import pytest
-
-from openviking.storage.viking_fs import VikingFS
-
-
-def _create_viking_fs_mock():
-    """Create a VikingFS instance with mocked AGFS backend."""
-    fs = VikingFS.__new__(VikingFS)
-    fs.agfs = MagicMock()
-    fs.query_embedder = None
-    fs.vector_store = None
-    fs._uri_prefix = "viking://"
-    fs._bound_ctx = contextvars.ContextVar("vikingfs_bound_ctx", default=None)
-    return fs
-
-
-@pytest.mark.asyncio
-class TestVikingFSExists:
-    """Test cases for VikingFS.exists() method."""
-
-    async def test_exists_returns_true_when_uri_exists(self):
-        """exists() should return True when URI exists."""
-        fs = _create_viking_fs_mock()
-        fs.stat = AsyncMock(return_value={"name": "test_file.txt", "isDir": False})
-
-        result = await fs.exists("viking://temp/test_file.txt")
-
-        assert result is True
-        fs.stat.assert_awaited_once_with("viking://temp/test_file.txt", ctx=None)
-
-    async def test_exists_returns_false_when_uri_not_found(self):
-        """exists() should return False when URI does not exist."""
-        fs = _create_viking_fs_mock()
-        fs.stat = AsyncMock(side_effect=FileNotFoundError("Not found"))
-
-        result = await fs.exists("viking://temp/nonexistent.txt")
-
-        assert result is False
-        fs.stat.assert_awaited_once_with("viking://temp/nonexistent.txt", ctx=None)
-
-    async def test_exists_returns_false_on_any_exception(self):
-        """exists() should return False on any exception, not just FileNotFoundError."""
-        fs = _create_viking_fs_mock()
-        fs.stat = AsyncMock(side_effect=PermissionError("Access denied"))
-
-        result = await fs.exists("viking://temp/protected.txt")
-
-        assert result is False
-
-
-@pytest.mark.asyncio
-class TestVikingFSCopyDirectory:
-    """Test cases for VikingFS.copy_directory() method."""
-
-    async def test_copy_directory_recursive(self):
-        """copy_directory() should recursively copy directory contents."""
-        fs = _create_viking_fs_mock()
-        fs._ensure_access = MagicMock()
-        fs._uri_to_path = MagicMock(
-            side_effect=lambda uri, ctx=None: uri.replace("viking://", "/local/")
-        )
-        fs._ensure_parent_dirs = AsyncMock()
-
-        mock_agfs_cp = MagicMock()
-
-        with patch("openviking.storage.viking_fs.agfs_cp", mock_agfs_cp):
-            await fs.copy_directory(
-                "viking://temp/source_dir/",
-                "viking://temp/dest_dir/",
-            )
-
-        fs._ensure_access.assert_any_call("viking://temp/source_dir/", None)
-        fs._ensure_access.assert_any_call("viking://temp/dest_dir/", None)
-        fs._ensure_parent_dirs.assert_awaited_once_with("/local/temp/dest_dir/")
-        mock_agfs_cp.assert_called_once_with(
-            fs.agfs,
-            "/local/temp/source_dir/",
-            "/local/temp/dest_dir/",
-            recursive=True,
-        )
-
-    async def test_copy_directory_with_context(self):
-        """copy_directory() should pass context to helper methods."""
-        from openviking.server.identity import RequestContext, Role
-        from openviking_cli.session.user_id import UserIdentifier
-
-        fs = _create_viking_fs_mock()
-        fs._ensure_access = MagicMock()
-        fs._uri_to_path = MagicMock(
-            side_effect=lambda uri, ctx=None: uri.replace("viking://", "/local/")
-        )
-        fs._ensure_parent_dirs = AsyncMock()
-
-        ctx = RequestContext(
-            user=UserIdentifier("acc1", "user1", "agent1"),
-            role=Role.USER,
-        )
-
-        mock_agfs_cp = MagicMock()
-
-        with patch("openviking.storage.viking_fs.agfs_cp", mock_agfs_cp):
-            await fs.copy_directory(
-                "viking://temp/source/",
-                "viking://temp/dest/",
-                ctx=ctx,
-            )
-
-        fs._ensure_access.assert_any_call("viking://temp/source/", ctx)
-        fs._ensure_access.assert_any_call("viking://temp/dest/", ctx)
-
-
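`test_copy_directory_recursive` and `test_copy_directory_with_context` fix the call sequence precisely: access checks on both URIs, URI-to-path translation, parent-directory creation for the destination, then one recursive `agfs_cp`. A sketch reconstructed from those mocked calls (the helper signatures are inferred from the mocks, not confirmed against the real class):

```python
async def copy_directory(self, src_uri: str, dst_uri: str, ctx=None) -> None:
    # Permission checks operate on viking:// URIs ...
    self._ensure_access(src_uri, ctx)
    self._ensure_access(dst_uri, ctx)
    # ... while the copy itself runs on translated local paths.
    src_path = self._uri_to_path(src_uri, ctx=ctx)
    dst_path = self._uri_to_path(dst_uri, ctx=ctx)
    await self._ensure_parent_dirs(dst_path)
    # agfs_cp is the module-level helper the tests patch at
    # "openviking.storage.viking_fs.agfs_cp"; one recursive call
    # copies the whole tree.
    agfs_cp(self.agfs, src_path, dst_path, recursive=True)
```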
-@pytest.mark.asyncio
-class TestVikingFSDeleteTemp:
-    """Test cases for VikingFS.delete_temp() method."""
-
-    async def test_delete_temp_removes_directory_and_contents(self):
-        """delete_temp() should remove directory and all its contents."""
-        fs = _create_viking_fs_mock()
-        fs._uri_to_path = MagicMock(return_value="/local/temp/test_temp")
-
-        fs._ls_entries = MagicMock(
-            return_value=[
-                {"name": "file1.txt", "isDir": False},
-                {"name": "subdir", "isDir": True},
-            ]
-        )
-
-        fs.agfs.rm = MagicMock()
-
-        call_count = [0]
-
-        async def mock_delete_temp(uri, ctx=None):
-            call_count[0] += 1
-            if call_count[0] == 1:
-                fs._ls_entries.return_value = [
-                    {"name": "nested_file.txt", "isDir": False},
-                ]
-                await fs.delete_temp(uri, ctx=ctx)
-            else:
-                fs._ls_entries.return_value = []
-
-        original_delete_temp = fs.delete_temp
-        fs.delete_temp = mock_delete_temp
-
-        await original_delete_temp("viking://temp/test_temp/")
-
-        assert fs.agfs.rm.call_count >= 1
-
-    async def test_delete_temp_handles_empty_directory(self):
-        """delete_temp() should handle empty directory gracefully."""
-        fs = _create_viking_fs_mock()
-        fs._uri_to_path = MagicMock(return_value="/local/temp/empty_temp")
-        fs._ls_entries = MagicMock(return_value=[])
-        fs.agfs.rm = MagicMock()
-
-        await fs.delete_temp("viking://temp/empty_temp/")
-
-        fs.agfs.rm.assert_called_once_with("/local/temp/empty_temp")
-
-    async def test_delete_temp_skips_dot_entries(self):
-        """delete_temp() should skip . and .. entries."""
-        fs = _create_viking_fs_mock()
-        fs._uri_to_path = MagicMock(return_value="/local/temp/test_temp")
-        fs._ls_entries = MagicMock(
-            return_value=[
-                {"name": ".", "isDir": True},
-                {"name": "..", "isDir": True},
-                {"name": "actual_file.txt", "isDir": False},
-            ]
-        )
-        fs.agfs.rm = MagicMock()
-
-        await fs.delete_temp("viking://temp/test_temp/")
-
-        rm_calls = [call[0][0] for call in fs.agfs.rm.call_args_list]
-        assert "/local/temp/test_temp/actual_file.txt" in rm_calls
-        assert "/local/temp/test_temp/." not in rm_calls
-        assert "/local/temp/test_temp/.." not in rm_calls
-
-    async def test_delete_temp_logs_warning_on_error(self):
-        """delete_temp() should log warning but not raise on error."""
-        fs = _create_viking_fs_mock()
-        fs._uri_to_path = MagicMock(return_value="/local/temp/error_temp")
-        fs._ls_entries = MagicMock(side_effect=Exception("AGFS error"))
-
-        with patch("openviking.storage.viking_fs.logger") as mock_logger:
-            await fs.delete_temp("viking://temp/error_temp/")
-
-        mock_logger.warning.assert_called_once()
-        assert "Failed to delete temp" in mock_logger.warning.call_args[0][0]
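Read together, the four `TestVikingFSDeleteTemp` cases describe a best-effort, depth-first removal: list entries, skip `.` and `..`, delete files, recurse into subdirectories, remove the then-empty directory, and warn (never raise) on failure. A sketch consistent with those assertions; the way `_ls_entries` and `_uri_to_path` are called here is inferred from the mocks, not confirmed:

```python
import logging

logger = logging.getLogger(__name__)  # stands in for the module logger the tests patch


async def delete_temp(self, uri: str, ctx=None) -> None:
    try:
        path = self._uri_to_path(uri, ctx=ctx).rstrip("/")
        for entry in self._ls_entries(path):
            if entry["name"] in (".", ".."):
                continue  # never walk into self/parent
            if entry["isDir"]:
                # Depth-first: empty out subdirectories before removing them.
                await self.delete_temp(f"{uri.rstrip('/')}/{entry['name']}/", ctx=ctx)
            else:
                self.agfs.rm(f"{path}/{entry['name']}")
        self.agfs.rm(path)  # directory is empty at this point
    except Exception as e:
        # Cleanup is best-effort: warn and swallow, per
        # test_delete_temp_logs_warning_on_error.
        logger.warning(f"Failed to delete temp {uri}: {e}")
```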