diff --git a/bot/vikingbot/config/loader.py b/bot/vikingbot/config/loader.py index e57c3879..130e151b 100644 --- a/bot/vikingbot/config/loader.py +++ b/bot/vikingbot/config/loader.py @@ -4,11 +4,14 @@ import os from pathlib import Path from typing import Any + from loguru import logger + from vikingbot.config.schema import Config CONFIG_PATH = None + def get_config_path() -> Path: """Get the path to ov.conf config file. @@ -217,4 +220,4 @@ def camel_to_snake(name: str) -> str: def snake_to_camel(name: str) -> str: """Convert snake_case to camelCase.""" components = name.split("_") - return components[0] + "".join(x.title() for x in components[1:]) \ No newline at end of file + return components[0] + "".join(x.title() for x in components[1:]) diff --git a/bot/vikingbot/config/schema.py b/bot/vikingbot/config/schema.py index 90a2b10c..3dc1a8d5 100644 --- a/bot/vikingbot/config/schema.py +++ b/bot/vikingbot/config/schema.py @@ -40,8 +40,10 @@ class SandboxMode(str, Enum): SHARED = "shared" PER_CHANNEL = "per-channel" + class AgentMemoryMode(str, Enum): """Agent memory mode enumeration.""" + PER_SESSION = "per-session" SHARED = "shared" PER_CHANNEL = "per-channel" @@ -109,7 +111,10 @@ class FeishuChannelConfig(BaseChannelConfig): encrypt_key: str = "" verification_token: str = "" allow_from: list[str] = Field(default_factory=list) ## 允许更新Agent对话的Feishu用户ID列表 - thread_require_mention: bool = Field(default=True, description="话题群模式下是否需要@才响应:默认True=所有消息必须@才响应;False=新话题首条消息无需@,后续回复必须@") + thread_require_mention: bool = Field( + default=True, + description="话题群模式下是否需要@才响应:默认True=所有消息必须@才响应;False=新话题首条消息无需@,后续回复必须@", + ) def channel_id(self) -> str: # Use app_id directly as the ID @@ -396,7 +401,9 @@ class ProviderConfig(BaseModel): api_key: str = "" api_base: Optional[str] = None - extra_headers: Optional[dict[str, str]] = Field(default_factory=dict) # Custom headers (e.g. 
APP-Code for AiHubMix) + extra_headers: Optional[dict[str, str]] = Field( + default_factory=dict + ) # Custom headers (e.g. APP-Code for AiHubMix) class ProvidersConfig(BaseModel): @@ -734,4 +741,4 @@ def from_safe_name(safe_name: str): file_name_split = safe_name.split("__") return SessionKey( type=file_name_split[0], channel_id=file_name_split[1], chat_id=file_name_split[2] - ) \ No newline at end of file + ) diff --git a/bot/vikingbot/openviking_mount/ov_server.py b/bot/vikingbot/openviking_mount/ov_server.py index d3f6383c..514beb47 100644 --- a/bot/vikingbot/openviking_mount/ov_server.py +++ b/bot/vikingbot/openviking_mount/ov_server.py @@ -1,9 +1,9 @@ import asyncio import hashlib -from typing import List, Dict, Any, Optional +import time +from typing import Any, Dict, List, Optional from loguru import logger -import time import openviking as ov from vikingbot.config.loader import load_config @@ -99,9 +99,7 @@ async def find(self, query: str, target_uri: Optional[str] = None): return await self.client.find(query, target_uri=target_uri) return await self.client.find(query) - async def add_resource( - self, local_path: str, desc: str - ) -> Optional[Dict[str, Any]]: + async def add_resource(self, local_path: str, desc: str) -> Optional[Dict[str, Any]]: """添加资源到 Viking""" result = await self.client.add_resource(path=local_path, reason=desc) return result @@ -327,7 +325,9 @@ async def search_memory( async def grep(self, uri: str, pattern: str, case_insensitive: bool = False) -> Dict[str, Any]: """通过模式(正则表达式)搜索内容""" - return await self.client.grep(uri, pattern, case_insensitive=case_insensitive, node_limit=10) + return await self.client.grep( + uri, pattern, case_insensitive=case_insensitive, node_limit=10 + ) async def glob(self, pattern: str, uri: Optional[str] = None) -> Dict[str, Any]: """通过 glob 模式匹配文件""" @@ -337,7 +337,8 @@ async def commit(self, session_id: str, messages: list[dict[str, Any]], user_id: """提交会话""" import re import uuid - from 
openviking.message.part import Part, TextPart, ToolPart + + from openviking.message.part import TextPart, ToolPart user_exists = await self._check_user_exists(user_id) if not user_exists: diff --git a/openviking/models/embedder/__init__.py b/openviking/models/embedder/__init__.py index b418b809..02ccdfc0 100644 --- a/openviking/models/embedder/__init__.py +++ b/openviking/models/embedder/__init__.py @@ -25,7 +25,6 @@ ) from openviking.models.embedder.jina_embedders import JinaDenseEmbedder from openviking.models.embedder.openai_embedders import OpenAIDenseEmbedder -from openviking.models.embedder.voyage_embedders import VoyageDenseEmbedder from openviking.models.embedder.vikingdb_embedders import ( VikingDBDenseEmbedder, VikingDBHybridEmbedder, @@ -36,6 +35,7 @@ VolcengineHybridEmbedder, VolcengineSparseEmbedder, ) +from openviking.models.embedder.voyage_embedders import VoyageDenseEmbedder __all__ = [ # Base classes diff --git a/openviking/parse/parsers/excel.py b/openviking/parse/parsers/excel.py index 2da44ce3..a904f786 100644 --- a/openviking/parse/parsers/excel.py +++ b/openviking/parse/parsers/excel.py @@ -142,7 +142,9 @@ def _format_xls_cell(cell, wb, xlrd) -> str: dt = xlrd.xldate_as_tuple(cell.value, wb.datemode) # Include time component if non-zero if dt[3] or dt[4] or dt[5]: - return f"{dt[0]:04d}-{dt[1]:02d}-{dt[2]:02d} {dt[3]:02d}:{dt[4]:02d}:{dt[5]:02d}" + return ( + f"{dt[0]:04d}-{dt[1]:02d}-{dt[2]:02d} {dt[3]:02d}:{dt[4]:02d}:{dt[5]:02d}" + ) return f"{dt[0]:04d}-{dt[1]:02d}-{dt[2]:02d}" except Exception: return str(cell.value) @@ -151,8 +153,13 @@ def _format_xls_cell(cell, wb, xlrd) -> str: if cell.ctype == xlrd.XL_CELL_ERROR: # xlrd error code map error_map = { - 0x00: "#NULL!", 0x07: "#DIV/0!", 0x0F: "#VALUE!", - 0x17: "#REF!", 0x1D: "#NAME?", 0x24: "#NUM!", 0x2A: "#N/A", + 0x00: "#NULL!", + 0x07: "#DIV/0!", + 0x0F: "#VALUE!", + 0x17: "#REF!", + 0x1D: "#NAME?", + 0x24: "#NUM!", + 0x2A: "#N/A", } return error_map.get(cell.value, 
f"#ERR({cell.value})") if cell.ctype == xlrd.XL_CELL_NUMBER: diff --git a/openviking/parse/parsers/legacy_doc.py b/openviking/parse/parsers/legacy_doc.py index 95b770ae..025216f9 100644 --- a/openviking/parse/parsers/legacy_doc.py +++ b/openviking/parse/parsers/legacy_doc.py @@ -19,7 +19,7 @@ logger = get_logger(__name__) - # Max stream size to read (50MB) — prevents DoS from crafted files +# Max stream size to read (50MB) — prevents DoS from crafted files _MAX_STREAM_SIZE = 50 * 1024 * 1024 # Max character count sanity cap for ccpText _MAX_CCP_TEXT = 10_000_000 @@ -154,9 +154,7 @@ def _extract_from_ole(self, ole) -> str: if fc_clx <= 0 or lcb_clx <= 0 or fc_clx + lcb_clx > len(table_data): return self._simple_text_extract(word_doc, ccp_text) - return self._extract_via_clx( - word_doc, table_data, fc_clx, lcb_clx, ccp_text - ) + return self._extract_via_clx(word_doc, table_data, fc_clx, lcb_clx, ccp_text) def _simple_text_extract(self, word_doc: bytes, ccp_text: int) -> str: """ @@ -177,7 +175,10 @@ def _simple_text_extract(self, word_doc: bytes, ccp_text: int) -> str: raw = word_doc[text_start:end] text = raw.decode("utf-16-le", errors="replace") # Sanity: if mostly printable, it's likely correct - if sum(1 for c in text[:200] if c.isprintable() or c in "\n\r\t") > len(text[:200]) * 0.5: + if ( + sum(1 for c in text[:200] if c.isprintable() or c in "\n\r\t") + > len(text[:200]) * 0.5 + ): return self._clean_word_text(text) # Fall back to CP1252 single-byte @@ -277,7 +278,9 @@ def _extract_via_clx( raw = word_doc[byte_offset:byte_end] text_parts.append(self._decode_cp1252(raw)) else: - logger.warning(f"Piece {i} extends beyond stream ({byte_end} > {len(word_doc)})") + logger.warning( + f"Piece {i} extends beyond stream ({byte_end} > {len(word_doc)})" + ) else: # UTF-16LE byte_offset = fc_real @@ -286,7 +289,9 @@ def _extract_via_clx( raw = word_doc[byte_offset:byte_end] text_parts.append(raw.decode("utf-16-le", errors="replace")) else: - logger.warning(f"Piece {i} 
extends beyond stream ({byte_end} > {len(word_doc)})") + logger.warning( + f"Piece {i} extends beyond stream ({byte_end} > {len(word_doc)})" + ) chars_extracted += piece_char_count @@ -305,7 +310,7 @@ def _clean_word_text(text: str) -> str: """Normalize Word control characters to readable equivalents.""" text = text.replace("\r\n", "\n").replace("\r", "\n") # \x07 = cell/row end, \x0B = soft line break, \x0C = section break - text = text.replace("\x07", "\t").replace("\x0B", "\n").replace("\x0C", "\n\n") + text = text.replace("\x07", "\t").replace("\x0b", "\n").replace("\x0c", "\n\n") return text def _fallback_extract(self, path: Path) -> str: diff --git a/openviking/parse/registry.py b/openviking/parse/registry.py index 4c6b3e4a..40d351fb 100644 --- a/openviking/parse/registry.py +++ b/openviking/parse/registry.py @@ -19,14 +19,14 @@ # Import will be handled dynamically to avoid dependency issues from openviking.parse.parsers.html import HTMLParser + +# Import markitdown-inspired parsers +from openviking.parse.parsers.legacy_doc import LegacyDocParser from openviking.parse.parsers.markdown import MarkdownParser from openviking.parse.parsers.media import AudioParser, ImageParser, VideoParser from openviking.parse.parsers.pdf import PDFParser from openviking.parse.parsers.powerpoint import PowerPointParser from openviking.parse.parsers.text import TextParser - -# Import markitdown-inspired parsers -from openviking.parse.parsers.legacy_doc import LegacyDocParser from openviking.parse.parsers.word import WordParser from openviking.parse.parsers.zip_parser import ZipParser diff --git a/openviking/retrieve/hierarchical_retriever.py b/openviking/retrieve/hierarchical_retriever.py index cddc5fea..9debd0ce 100644 --- a/openviking/retrieve/hierarchical_retriever.py +++ b/openviking/retrieve/hierarchical_retriever.py @@ -18,6 +18,7 @@ from openviking.retrieve.retrieval_stats import get_stats_collector from openviking.server.identity import RequestContext, Role from 
openviking.storage import VikingDBManager, VikingDBManagerProxy +from openviking.storage.memory_relation_store import MemoryRelationStore, RelationType from openviking.storage.viking_fs import get_viking_fs from openviking.telemetry import get_current_telemetry from openviking.utils.time_utils import parse_iso_datetime @@ -56,6 +57,7 @@ def __init__( storage: VikingDBManager, embedder: Optional[Any], rerank_config: Optional[RerankConfig] = None, + memory_relation_store: Optional[MemoryRelationStore] = None, ): """Initialize hierarchical retriever with rerank_config. @@ -63,10 +65,13 @@ def __init__( storage: VikingVectorIndexBackend instance embedder: Embedder instance (supports dense/sparse/hybrid) rerank_config: Rerank configuration (optional, will fallback to vector search only) + memory_relation_store: Optional memory relation store for filtering + superseded memories during retrieval. """ self.vector_store = storage self.embedder = embedder self.rerank_config = rerank_config + self.memory_relation_store = memory_relation_store # Use rerank threshold if available, otherwise use a default self.threshold = rerank_config.threshold if rerank_config else 0 @@ -92,6 +97,7 @@ async def retrieve( score_threshold: Optional[float] = None, score_gte: bool = False, scope_dsl: Optional[Dict[str, Any]] = None, + follow_relations: bool = False, ) -> QueryResult: """ Execute hierarchical retrieval. @@ -102,12 +108,15 @@ async def retrieve( score_gte: True uses >=, False uses > grep_patterns: Keyword match pattern list scope_dsl: Additional scope constraints passed from public find/search filter + follow_relations: When True and a memory_relation_store is configured, + filter out superseded memories and surface related memories as + additional context (up to MAX_RELATIONS). 
""" t0 = time.monotonic() # Use custom threshold or default threshold effective_threshold = score_threshold if score_threshold is not None else self.threshold - # 创建 proxy 包装器,绑定当前 ctx + # Create proxy wrapper bound to current ctx vector_proxy = VikingDBManagerProxy(self.vector_store, ctx) target_dirs = [d for d in (query.target_directories or []) if d] @@ -173,7 +182,7 @@ async def retrieve( mode=mode, ) - # 从 global_results 中提取 level 2 的文件作为初始候选者 + # Extract level 2 files from global_results as initial candidates initial_candidates = [r for r in global_results if r.get("level", 2) == 2] # Step 4: Recursive search @@ -196,6 +205,10 @@ async def retrieve( # Step 6: Convert results matched = await self._convert_to_matched_contexts(candidates, ctx=ctx) + # Step 6b: Filter superseded memories and add related context + if follow_relations and self.memory_relation_store: + matched = await self._apply_memory_relations(matched) + final = matched[:limit] # Record retrieval stats for the observer. 
@@ -291,13 +304,11 @@ def _merge_starting_points( docs = [str(r.get("abstract", "")) for r in global_results] query_scores = self._rerank_scores(query, docs, default_scores) for i, r in enumerate(global_results): - # 只添加非 level 2 的项目到起始点 if r.get("level", 2) != 2: points.append((r["uri"], query_scores[i])) seen.add(r["uri"]) else: for r in global_results: - # 只添加非 level 2 的项目到起始点 if r.get("level", 2) != 2: points.append((r["uri"], r["_score"])) seen.add(r["uri"]) @@ -352,12 +363,11 @@ def passes_threshold(score: float) -> bool: prev_topk_uris: set = set() convergence_rounds = 0 - # 添加初始候选者(level 2 文件) + # Add initial candidates (level 2 files) if initial_candidates: for r in initial_candidates: uri = r.get("uri", "") if uri: - # 只添加 level 2 的文件 if r.get("level", 2) == 2: score = r.get("_score", 0.0) r["_final_score"] = score @@ -385,7 +395,7 @@ def passes_threshold(score: float) -> bool: results = await vector_proxy.search_children_in_tenant( parent_uri=current_uri, query_vector=query_vector, - sparse_query_vector=sparse_query_vector, # Pass sparse vector + sparse_query_vector=sparse_query_vector, context_type=context_type, target_directories=target_dirs, extra_filter=scope_dsl, @@ -417,7 +427,6 @@ def passes_threshold(score: float) -> bool: continue telemetry.count("vector.passed", 1) - # Deduplicate by URI and keep the highest-scored candidate. previous = collected_by_uri.get(uri) if previous is None or final_score > previous.get("_final_score", 0): r["_final_score"] = final_score @@ -428,7 +437,6 @@ def passes_threshold(score: float) -> bool: final_score, ) - # Only recurse into directories (L0/L1). L2 files are terminal hits. 
if uri not in visited and r.get("level", 2) != 2: heapq.heappush(dir_queue, (-final_score, uri)) @@ -442,7 +450,6 @@ def passes_threshold(score: float) -> bool: if current_topk_uris == prev_topk_uris and len(current_topk_uris) >= limit: convergence_rounds += 1 - if convergence_rounds >= self.MAX_CONVERGENCE_ROUNDS: break else: @@ -526,6 +533,50 @@ async def _convert_to_matched_contexts( results.sort(key=lambda x: x.score, reverse=True) return results + async def _apply_memory_relations(self, matched: List[MatchedContext]) -> List[MatchedContext]: + """Filter superseded memories and annotate with related-memory context. + + When a memory has been superseded (another memory has a ``supersedes`` + relation pointing to it), remove it from results so the caller only + sees the latest version. Additionally, surface ``related_to`` memories + as extra RelatedContext entries (up to MAX_RELATIONS per result). + """ + if not self.memory_relation_store: + return matched + + filtered: List[MatchedContext] = [] + for mc in matched: + # Strip level suffix to get the canonical URI for relation lookup. + canonical_uri = mc.uri + for suffix in (".abstract.md", ".overview.md"): + if canonical_uri.endswith(f"/{suffix}"): + canonical_uri = canonical_uri[: -(len(suffix) + 1)] + break + + if await self.memory_relation_store.is_superseded(canonical_uri): + logger.debug("[retrieve] Filtering superseded memory: %s", canonical_uri) + continue + + # Add related_to memories as additional context. 
+ related_rels = await self.memory_relation_store.query( + canonical_uri, relation_type=RelationType.RELATED_TO, direction="both" + ) + existing_uris = {r.uri for r in mc.relations} + added = 0 + for rel in related_rels: + if added >= self.MAX_RELATIONS: + break + peer_uri = rel.target_uri if rel.source_uri == canonical_uri else rel.source_uri + if peer_uri not in existing_uris: + mc.relations.append( + RelatedContext(uri=peer_uri, abstract=rel.metadata.get("abstract", "")) + ) + existing_uris.add(peer_uri) + added += 1 + + filtered.append(mc) + return filtered + @classmethod def _append_level_suffix(cls, uri: str, level: int) -> str: """Return user-facing URI with L0/L1 suffix reconstructed by level.""" diff --git a/openviking/server/app.py b/openviking/server/app.py index c22794e1..3b51e6c7 100644 --- a/openviking/server/app.py +++ b/openviking/server/app.py @@ -20,6 +20,7 @@ content_router, debug_router, filesystem_router, + memory_relations_router, observer_router, pack_router, relations_router, @@ -177,6 +178,7 @@ async def general_error_handler(request: Request, exc: Exception): app.include_router(debug_router) app.include_router(observer_router) app.include_router(tasks_router) + app.include_router(memory_relations_router) app.include_router(bot_router, prefix="/bot/v1") return app diff --git a/openviking/server/models.py b/openviking/server/models.py index 445da590..e10f1ed0 100644 --- a/openviking/server/models.py +++ b/openviking/server/models.py @@ -2,7 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 """Response models and error codes for OpenViking HTTP Server.""" -from typing import Any, Dict, Optional +from typing import Any, Dict, List, Optional from pydantic import BaseModel @@ -24,6 +24,24 @@ class Response(BaseModel): telemetry: Optional[Dict[str, Any]] = None +class MemoryRelationResponse(BaseModel): + """Response model for a single memory relation.""" + + id: str + source_uri: str + target_uri: str + relation_type: str + created_at: str + 
metadata: Optional[dict] = None + + +class MemoryRelationListResponse(BaseModel): + """Response model for a list of memory relations.""" + + relations: List[MemoryRelationResponse] + total: int + + # Error code to HTTP status code mapping ERROR_CODE_TO_HTTP_STATUS = { "OK": 200, diff --git a/openviking/server/routers/__init__.py b/openviking/server/routers/__init__.py index 12aa9f34..9ed5de14 100644 --- a/openviking/server/routers/__init__.py +++ b/openviking/server/routers/__init__.py @@ -14,6 +14,7 @@ from openviking.server.routers.search import router as search_router from openviking.server.routers.sessions import router as sessions_router from openviking.server.routers.system import router as system_router +from openviking.server.routers.memory_relations import router as memory_relations_router from openviking.server.routers.tasks import router as tasks_router __all__ = [ @@ -25,6 +26,7 @@ "content_router", "search_router", "relations_router", + "memory_relations_router", "sessions_router", "pack_router", "debug_router", diff --git a/openviking/server/routers/memory_relations.py b/openviking/server/routers/memory_relations.py new file mode 100644 index 00000000..3f56b3de --- /dev/null +++ b/openviking/server/routers/memory_relations.py @@ -0,0 +1,90 @@ +# Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd. +# SPDX-License-Identifier: Apache-2.0 +"""Memory relation endpoints for OpenViking HTTP Server. + +Provides typed relation queries between memories (supersedes, contradicts, +related_to, derived_from). Unlike resource-level relations in relations.py, +memory relations carry a semantic type for conflict detection and retrieval. 
+""" + +from typing import Optional + +from fastapi import APIRouter, Depends, Query + +from openviking.server.auth import get_request_context +from openviking.server.identity import RequestContext +from openviking.server.models import ErrorInfo, MemoryRelationListResponse, MemoryRelationResponse, Response +from openviking.storage.memory_relation_store import MemoryRelationStore, RelationType + +router = APIRouter(prefix="/api/v1/memories", tags=["memory-relations"]) + +# Module-level store reference, set during service initialization. +_relation_store: Optional[MemoryRelationStore] = None + + +def set_memory_relation_store(store: MemoryRelationStore) -> None: + """Wire the memory relation store into the router (called at startup).""" + global _relation_store + _relation_store = store + + +def _get_store() -> MemoryRelationStore: + if _relation_store is None: + raise RuntimeError("MemoryRelationStore not initialized") + return _relation_store + + +@router.get("/{uri:path}/relations") +async def get_memory_relations( + uri: str, + type: Optional[str] = Query(None, description="Filter by relation type"), + direction: str = Query("both", description="outgoing, incoming, or both"), + _ctx: RequestContext = Depends(get_request_context), +) -> Response: + """Get typed relations for a memory URI. + + Returns all relations where the given URI is either source or target, + optionally filtered by relation type and direction. + """ + store = _get_store() + + relation_type = None + if type: + try: + relation_type = RelationType(type) + except ValueError: + valid = ", ".join(t.value for t in RelationType) + return Response( + status="error", + error=ErrorInfo( + code="INVALID_ARGUMENT", + message=f"Invalid relation type: {type}. 
Valid types: {valid}", + ), + ) + + valid_directions = ("outgoing", "incoming", "both") + if direction not in valid_directions: + return Response( + status="error", + error=ErrorInfo( + code="INVALID_ARGUMENT", + message=f"Invalid direction: {direction}. Valid: {', '.join(valid_directions)}", + ), + ) + + relations = await store.query(uri, relation_type=relation_type, direction=direction) + + items = [ + MemoryRelationResponse( + id=r.id, + source_uri=r.source_uri, + target_uri=r.target_uri, + relation_type=r.relation_type.value, + created_at=r.created_at.isoformat(), + metadata=r.metadata, + ) + for r in relations + ] + + result = MemoryRelationListResponse(relations=items, total=len(items)) + return Response(status="ok", result=result.model_dump()) diff --git a/openviking/service/core.py b/openviking/service/core.py index b07bdb84..3a7d1f22 100644 --- a/openviking/service/core.py +++ b/openviking/service/core.py @@ -132,7 +132,9 @@ def _init_storage( logger.warning("AGFS client not initialized, skipping queue manager") # Initialize VikingDBManager with QueueManager - self._vikingdb_manager = VikingDBManager(vectordb_config=config.vectordb, queue_manager=self._queue_manager) + self._vikingdb_manager = VikingDBManager( + vectordb_config=config.vectordb, queue_manager=self._queue_manager + ) # Configure queues if QueueManager is available if self._queue_manager: diff --git a/openviking/session/compressor.py b/openviking/session/compressor.py index 7fd066c7..d7a7487d 100644 --- a/openviking/session/compressor.py +++ b/openviking/session/compressor.py @@ -63,9 +63,14 @@ def __init__( vikingdb: VikingDBManager, ): """Initialize session compressor.""" + from openviking.storage.memory_relation_store import MemoryRelationStore + self.vikingdb = vikingdb self.extractor = MemoryExtractor() - self.deduplicator = MemoryDeduplicator(vikingdb=vikingdb) + self.relation_store = MemoryRelationStore() + self.deduplicator = MemoryDeduplicator( + vikingdb=vikingdb, 
relation_store=self.relation_store + ) self._pending_semantic_changes: Dict[str, Dict[str, set]] = {} def _record_semantic_change( @@ -379,9 +384,7 @@ async def extract_long_term_memories( merged_text = ( f"{action.memory.abstract} {candidate.content}" ) - merged_embed = self.deduplicator.embedder.embed( - merged_text - ) + merged_embed = self.deduplicator.embedder.embed(merged_text) batch_memories.append( (merged_embed.dense_vector, action.memory) ) diff --git a/openviking/session/memory_deduplicator.py b/openviking/session/memory_deduplicator.py index 5b474fd0..208841e2 100644 --- a/openviking/session/memory_deduplicator.py +++ b/openviking/session/memory_deduplicator.py @@ -18,6 +18,11 @@ from openviking.prompts import render_prompt from openviking.server.identity import RequestContext from openviking.storage import VikingDBManager +from openviking.storage.memory_relation_store import ( + MemoryRelation, + MemoryRelationStore, + RelationType, +) from openviking.telemetry import get_current_telemetry from openviking_cli.utils import get_logger from openviking_cli.utils.config import get_openviking_config @@ -84,9 +89,18 @@ def _category_uri_prefix(category: str, user) -> str: def __init__( self, vikingdb: VikingDBManager, + relation_store: Optional[MemoryRelationStore] = None, ): - """Initialize deduplicator.""" + """Initialize deduplicator. + + Args: + vikingdb: VikingDB manager for vector search. + relation_store: Optional memory relation store. When provided, + MERGE decisions automatically create ``supersedes`` relations + and DELETE decisions create ``contradicts`` relations. 
+ """ self.vikingdb = vikingdb + self.relation_store = relation_store config = get_openviking_config() self.embedder = config.embedding.get_query_embedder() @@ -117,7 +131,7 @@ async def deduplicate( # Step 2: LLM decision decision, reason, actions = await self._llm_decision(candidate, similar_memories) - return DedupResult( + result = DedupResult( decision=decision, candidate=candidate, similar_memories=similar_memories, @@ -126,6 +140,47 @@ async def deduplicate( query_vector=query_vector, ) + # Step 3: Record relations for merge/delete actions + await self._record_dedup_relations(result) + + return result + + async def _record_dedup_relations(self, result: DedupResult) -> None: + """Create memory relations based on dedup actions. + + - MERGE -> supersedes (existing memory absorbs candidate and supersedes it) + - DELETE -> contradicts (audit trail before deletion) + """ + if not self.relation_store or not result.actions: + return + + candidate_uri = getattr(result.candidate, "uri", "") + if not candidate_uri: + # Build a provisional URI from category + abstract for tracking. + candidate_uri = ( + f"pending://{result.candidate.category.value}/{result.candidate.abstract[:40]}" + ) + + for action in result.actions: + if action.decision == MemoryActionDecision.MERGE: + # The existing memory survives with merged content, so it + # supersedes the candidate (which is not created). 
+ relation = MemoryRelation( + source_uri=action.memory.uri, + target_uri=candidate_uri, + relation_type=RelationType.SUPERSEDES, + metadata={"reason": action.reason, "dedup_decision": "merge"}, + ) + await self.relation_store.create(relation) + elif action.decision == MemoryActionDecision.DELETE: + relation = MemoryRelation( + source_uri=candidate_uri, + target_uri=action.memory.uri, + relation_type=RelationType.CONTRADICTS, + metadata={"reason": action.reason, "dedup_decision": "delete"}, + ) + await self.relation_store.create(relation) + async def _find_similar_memories( self, candidate: CandidateMemory, @@ -406,7 +461,7 @@ def _extract_facet_key(text: str) -> str: normalized = " ".join(str(text).strip().split()) # Prefer common separators used by extraction templates. - for sep in (":", ":", "-", "—"): + for sep in ("\uff1a", ":", "-", "\u2014"): if sep in normalized: left = normalized.split(sep, 1)[0].strip().lower() if left: diff --git a/openviking/storage/collection_schemas.py b/openviking/storage/collection_schemas.py index c46f2d76..3827ae0b 100644 --- a/openviking/storage/collection_schemas.py +++ b/openviking/storage/collection_schemas.py @@ -41,6 +41,42 @@ class CollectionSchemas: Centralized collection schema definitions. """ + @staticmethod + def memory_relation_collection(name: str) -> Dict[str, Any]: + """ + Get the schema for the memory relation collection. + + Stores typed edges between memory URIs for conflict detection, + supersession tracking, and related-context retrieval. + No vector field needed - relations are queried by URI and type. 
+ + Args: + name: Collection name + + Returns: + Schema definition for the memory relation collection + """ + return { + "CollectionName": name, + "Description": "Memory relation graph (supersedes, contradicts, related_to, derived_from)", + "Fields": [ + {"FieldName": "id", "FieldType": "string", "IsPrimaryKey": True}, + {"FieldName": "source_uri", "FieldType": "string"}, + {"FieldName": "target_uri", "FieldType": "string"}, + {"FieldName": "relation_type", "FieldType": "string"}, + {"FieldName": "created_at", "FieldType": "date_time"}, + {"FieldName": "metadata", "FieldType": "string"}, + {"FieldName": "account_id", "FieldType": "string"}, + ], + "ScalarIndex": [ + "source_uri", + "target_uri", + "relation_type", + "created_at", + "account_id", + ], + } + @staticmethod def context_collection(name: str, vector_dim: int) -> Dict[str, Any]: """ diff --git a/openviking/storage/memory_relation_store.py b/openviking/storage/memory_relation_store.py new file mode 100644 index 00000000..06bcfcca --- /dev/null +++ b/openviking/storage/memory_relation_store.py @@ -0,0 +1,189 @@ +# Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd. +# SPDX-License-Identifier: Apache-2.0 +""" +Memory relation store for OpenViking. + +Provides typed relation tracking between memories (supersedes, contradicts, +related_to, derived_from) stored as lightweight documents in VikingDB. +Unlike resource-level relations managed by VikingFS, memory relations carry +a semantic type and are designed for conflict detection and retrieval filtering. 
+""" + +from dataclasses import dataclass, field +from datetime import datetime, timezone +from enum import Enum +from typing import Any, Dict, List, Optional +from uuid import uuid4 + +from openviking_cli.utils import get_logger + +logger = get_logger(__name__) + + +class RelationType(str, Enum): + """Semantic relation type between two memories.""" + + SUPERSEDES = "supersedes" + CONTRADICTS = "contradicts" + RELATED_TO = "related_to" + DERIVED_FROM = "derived_from" + + +@dataclass +class MemoryRelation: + """A typed edge between two memory URIs.""" + + source_uri: str + target_uri: str + relation_type: RelationType + created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc)) + metadata: Dict[str, Any] = field(default_factory=dict) + id: str = field(default_factory=lambda: uuid4().hex) + account_id: Optional[str] = None + + def to_dict(self) -> Dict[str, Any]: + d = { + "id": self.id, + "source_uri": self.source_uri, + "target_uri": self.target_uri, + "relation_type": self.relation_type.value, + "created_at": self.created_at.isoformat(), + "metadata": self.metadata, + } + if self.account_id is not None: + d["account_id"] = self.account_id + return d + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "MemoryRelation": + created_at = data.get("created_at") + if isinstance(created_at, str): + created_at = datetime.fromisoformat(created_at) + elif not isinstance(created_at, datetime): + created_at = datetime.now(timezone.utc) + + return cls( + id=data.get("id", uuid4().hex), + source_uri=data["source_uri"], + target_uri=data["target_uri"], + relation_type=RelationType(data["relation_type"]), + created_at=created_at, + metadata=data.get("metadata", {}), + account_id=data.get("account_id"), + ) + + +class MemoryRelationStore: + """ + In-memory relation store backed by a simple list. + + Relations are stored as MemoryRelation objects and can be queried by + source URI, target URI, or relation type. 
This store is designed to + be lightweight - a future iteration can persist to VikingDB's collection + mechanism once the schema is validated. + """ + + def __init__(self) -> None: + self._relations: List[MemoryRelation] = [] + + async def create(self, relation: MemoryRelation) -> str: + """Store a new relation. Returns the relation ID.""" + # Prevent exact duplicates (same source, target, type). + for existing in self._relations: + if ( + existing.source_uri == relation.source_uri + and existing.target_uri == relation.target_uri + and existing.relation_type == relation.relation_type + ): + logger.debug( + "Duplicate relation skipped: %s --%s--> %s", + relation.source_uri, + relation.relation_type.value, + relation.target_uri, + ) + return existing.id + + self._relations.append(relation) + logger.debug( + "Created relation %s: %s --%s--> %s", + relation.id, + relation.source_uri, + relation.relation_type.value, + relation.target_uri, + ) + return relation.id + + async def query( + self, + uri: str, + relation_type: Optional[RelationType] = None, + direction: str = "outgoing", + ) -> List[MemoryRelation]: + """Query relations for a memory URI. + + Args: + uri: The memory URI to query. + relation_type: Filter by relation type (optional). + direction: "outgoing" (uri is source), "incoming" (uri is target), + or "both". + + Returns: + List of matching MemoryRelation objects. + """ + results: List[MemoryRelation] = [] + for rel in self._relations: + match = False + if direction in ("outgoing", "both") and rel.source_uri == uri: + match = True + if direction in ("incoming", "both") and rel.target_uri == uri: + match = True + if match and relation_type is not None and rel.relation_type != relation_type: + match = False + if match: + results.append(rel) + return results + + async def delete(self, relation_id: str) -> bool: + """Delete a relation by ID. 
Returns True if found and deleted.""" + for i, rel in enumerate(self._relations): + if rel.id == relation_id: + self._relations.pop(i) + logger.debug("Deleted relation %s", relation_id) + return True + return False + + async def delete_by_uri(self, uri: str) -> int: + """Delete all relations involving a URI (as source or target). + + Returns the number of relations deleted. + """ + before = len(self._relations) + self._relations = [ + r for r in self._relations if r.source_uri != uri and r.target_uri != uri + ] + deleted = before - len(self._relations) + if deleted: + logger.debug("Deleted %d relations for URI %s", deleted, uri) + return deleted + + async def get_superseded_uris(self, uri: str) -> List[str]: + """Get URIs that the given URI supersedes (outgoing supersedes edges). + + Useful during retrieval to filter out stale memories. + """ + return [ + r.target_uri + for r in self._relations + if r.source_uri == uri and r.relation_type == RelationType.SUPERSEDES + ] + + async def is_superseded(self, uri: str) -> bool: + """Check if a URI has been superseded by a newer memory.""" + return any( + r.target_uri == uri and r.relation_type == RelationType.SUPERSEDES + for r in self._relations + ) + + def count(self) -> int: + """Return total number of stored relations.""" + return len(self._relations) diff --git a/openviking/storage/vectordb/service/server_fastapi.py b/openviking/storage/vectordb/service/server_fastapi.py index 34574e60..2b6edb04 100644 --- a/openviking/storage/vectordb/service/server_fastapi.py +++ b/openviking/storage/vectordb/service/server_fastapi.py @@ -11,7 +11,7 @@ import random import time from contextlib import asynccontextmanager -from typing import Dict, Any +from typing import Any, Dict import uvicorn from fastapi import FastAPI, Request @@ -28,19 +28,19 @@ @asynccontextmanager async def lifespan(app: FastAPI): """Handle application startup and shutdown events. 
- + Manages resource initialization and cleanup, ensuring graceful shutdown by waiting for all active requests to complete. - + Args: app: The FastAPI application instance """ # Startup logger.info("============ VikingDB Server Starting =============") random.seed(time.time_ns()) - + yield - + # Shutdown logger.info("Waiting for active requests to complete...") while _active_requests > 0: @@ -61,31 +61,30 @@ async def lifespan(app: FastAPI): @app.exception_handler(VikingDBException) async def vikingdb_exception_handler(request: Request, exc: VikingDBException) -> JSONResponse: """Handle VikingDB-specific exceptions. - + Args: request: The incoming HTTP request exc: The VikingDBException that was raised - + Returns: JSONResponse with error details """ return JSONResponse( - status_code=200, - content=error_response(exc.message, exc.code.value, request=request) + status_code=200, content=error_response(exc.message, exc.code.value, request=request) ) @app.middleware("http") async def request_tracking_middleware(request: Request, call_next): """Middleware to track request processing time and active request count. - + Increments active request counter, measures processing time, and adds processing time header to response. - + Args: request: The incoming HTTP request call_next: The next middleware/handler in the chain - + Returns: Response with added X-Process-Time header """ @@ -118,7 +117,7 @@ async def request_tracking_middleware(request: Request, call_next): @app.get("/") async def root() -> Dict[str, str]: """Root endpoint providing basic server information. - + Returns: Dict containing server name and version """ @@ -128,7 +127,7 @@ async def root() -> Dict[str, str]: @app.get("/health") async def health() -> Dict[str, Any]: """Health check endpoint for monitoring server status. 
- + Returns: Dict containing health status and current active request count """ @@ -140,4 +139,4 @@ async def health() -> Dict[str, Any]: logger.info("Starting VikingDB server on 0.0.0.0:5000") uvicorn.run(app, host="0.0.0.0", port=5000, log_level="info") except Exception as e: - logger.error(f"Failed to start VikingDB server: {e}") \ No newline at end of file + logger.error(f"Failed to start VikingDB server: {e}") diff --git a/openviking/storage/viking_fs.py b/openviking/storage/viking_fs.py index 7c45f544..3c25269b 100644 --- a/openviking/storage/viking_fs.py +++ b/openviking/storage/viking_fs.py @@ -166,11 +166,13 @@ def __init__( rerank_config: Optional["RerankConfig"] = None, vector_store: Optional["VikingVectorIndexBackend"] = None, timeout: int = 10, + memory_relation_store: Optional[Any] = None, ): self.agfs = agfs self.query_embedder = query_embedder self.rerank_config = rerank_config self.vector_store = vector_store + self.memory_relation_store = memory_relation_store self._bound_ctx: contextvars.ContextVar[Optional[RequestContext]] = contextvars.ContextVar( "vikingfs_bound_ctx", default=None ) @@ -648,6 +650,7 @@ async def find( storage=storage, embedder=embedder, rerank_config=self.rerank_config, + memory_relation_store=self.memory_relation_store, ) # Infer context_type (None = search all types) @@ -793,6 +796,7 @@ async def search( storage=storage, embedder=embedder, rerank_config=self.rerank_config, + memory_relation_store=self.memory_relation_store, ) async def _execute(tq: TypedQuery): diff --git a/tests/misc/test_embedding_input_type.py b/tests/misc/test_embedding_input_type.py index 4ee8f660..34e4897f 100644 --- a/tests/misc/test_embedding_input_type.py +++ b/tests/misc/test_embedding_input_type.py @@ -9,8 +9,6 @@ from unittest.mock import MagicMock, patch -import pytest - from openviking_cli.utils.config.embedding_config import EmbeddingConfig, EmbeddingModelConfig @@ -86,7 +84,7 @@ def test_legacy_input_type_lowercase_normalization(self): ) 
assert config.input_type == "search_query" - def test_query_document_param_lowercase_normalization(self): + def test_query_document_param_lowercase_normalization_jina(self): """Query/document task values should be normalized to lowercase.""" config = EmbeddingModelConfig( model="jina-embeddings-v5-text-small", diff --git a/tests/session/test_memory_extractor_language.py b/tests/session/test_memory_extractor_language.py index d1d98021..22edc865 100644 --- a/tests/session/test_memory_extractor_language.py +++ b/tests/session/test_memory_extractor_language.py @@ -104,7 +104,12 @@ def test_detect_output_language_japanese_with_single_cyrillic(): def test_detect_output_language_russian_with_threshold(): """Russian text with sufficient Cyrillic chars should be detected as Russian.""" - messages = [_msg("user", "\u042d\u0442\u043e \u0440\u0443\u0441\u0441\u043a\u0438\u0439 \u0442\u0435\u043a\u0441\u0442")] + messages = [ + _msg( + "user", + "\u042d\u0442\u043e \u0440\u0443\u0441\u0441\u043a\u0438\u0439 \u0442\u0435\u043a\u0441\u0442", + ) + ] language = MemoryExtractor._detect_output_language(messages, fallback_language="en") assert language == "ru" diff --git a/tests/unit/session/test_dedup_relations.py b/tests/unit/session/test_dedup_relations.py new file mode 100644 index 00000000..971adccc --- /dev/null +++ b/tests/unit/session/test_dedup_relations.py @@ -0,0 +1,167 @@ +# Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd. 
# SPDX-License-Identifier: Apache-2.0
"""Integration test: dedup MERGE -> supersedes relation -> retrieval prefers newer memory."""

from unittest.mock import MagicMock

import pytest

from openviking.core.context import Context
from openviking.session.memory_deduplicator import (
    DedupDecision,
    DedupResult,
    ExistingMemoryAction,
    MemoryActionDecision,
    MemoryDeduplicator,
)
from openviking.storage.memory_relation_store import (
    MemoryRelationStore,
    RelationType,
)


@pytest.fixture
def relation_store():
    return MemoryRelationStore()


@pytest.fixture
def old_memory():
    ctx = Context(uri="viking://user/u1/memories/preferences/theme_dark")
    ctx.abstract = "User prefers dark mode"
    ctx.meta = {"_dedup_score": 0.95}
    return ctx


@pytest.fixture
def candidate_memory():
    """Mock CandidateMemory with the fields the deduplicator needs."""
    m = MagicMock()
    m.uri = ""
    m.category.value = "preferences"
    m.abstract = "User prefers light mode"
    m.content = "The user switched to light mode"
    m.overview = ""
    return m


def _dedup_result(candidate, old_memory, decision, action=None, reason=""):
    """Build a DedupResult with zero or one ExistingMemoryAction on *old_memory*."""
    actions = None
    if action is not None:
        actions = [ExistingMemoryAction(memory=old_memory, decision=action, reason=reason)]
    return DedupResult(
        decision=decision,
        candidate=candidate,
        similar_memories=[old_memory],
        actions=actions,
    )


def _deduplicator_with_store(store):
    """Create a MemoryDeduplicator via __new__ (skips __init__) with only relation_store set."""
    dedup = MemoryDeduplicator.__new__(MemoryDeduplicator)
    dedup.relation_store = store
    return dedup


class TestDedupRelationRecording:
    @pytest.mark.asyncio
    async def test_merge_creates_supersedes_relation(
        self, relation_store, old_memory, candidate_memory
    ):
        """When dedup decides MERGE, a supersedes relation should be created."""
        result = _dedup_result(
            candidate_memory,
            old_memory,
            DedupDecision.NONE,
            action=MemoryActionDecision.MERGE,
            reason="preference update",
        )

        dedup = _deduplicator_with_store(relation_store)
        await dedup._record_dedup_relations(result)

        assert relation_store.count() == 1
        rels = await relation_store.query(old_memory.uri, direction="incoming")
        assert len(rels) == 1
        assert rels[0].relation_type == RelationType.SUPERSEDES
        assert rels[0].target_uri == old_memory.uri

    @pytest.mark.asyncio
    async def test_delete_creates_contradicts_relation(
        self, relation_store, old_memory, candidate_memory
    ):
        """When dedup decides DELETE, a contradicts relation should be created."""
        result = _dedup_result(
            candidate_memory,
            old_memory,
            DedupDecision.CREATE,
            action=MemoryActionDecision.DELETE,
            reason="contradiction detected",
        )

        dedup = _deduplicator_with_store(relation_store)
        await dedup._record_dedup_relations(result)

        assert relation_store.count() == 1
        rels = await relation_store.query(old_memory.uri, direction="incoming")
        assert len(rels) == 1
        assert rels[0].relation_type == RelationType.CONTRADICTS

    @pytest.mark.asyncio
    async def test_no_store_no_crash(self, old_memory, candidate_memory):
        """Without a relation store, recording should be a no-op."""
        result = _dedup_result(
            candidate_memory,
            old_memory,
            DedupDecision.NONE,
            action=MemoryActionDecision.MERGE,
            reason="test",
        )

        dedup = _deduplicator_with_store(None)

        # Should not raise
        await dedup._record_dedup_relations(result)

    @pytest.mark.asyncio
    async def test_skip_creates_no_relations(self, relation_store, old_memory, candidate_memory):
        """SKIP decisions should not create any relations."""
        result = _dedup_result(candidate_memory, old_memory, DedupDecision.SKIP)

        dedup = _deduplicator_with_store(relation_store)
        await dedup._record_dedup_relations(result)

        assert relation_store.count() == 0

    @pytest.mark.asyncio
    async def test_superseded_memory_detectable(self, relation_store, old_memory, candidate_memory):
        """After MERGE, the old memory should be detected as superseded."""
        result = _dedup_result(
            candidate_memory,
            old_memory,
            DedupDecision.NONE,
            action=MemoryActionDecision.MERGE,
            reason="preference update",
        )

        dedup = _deduplicator_with_store(relation_store)
        await dedup._record_dedup_relations(result)

        # The old memory should now be flagged as superseded
        assert await relation_store.is_superseded(old_memory.uri) is True


# ---------------------------------------------------------------------------
# File: tests/unit/storage/__init__.py (new, empty package marker)
# ---------------------------------------------------------------------------

# ---------------------------------------------------------------------------
# File: tests/unit/storage/test_relation_store.py
# ---------------------------------------------------------------------------
# Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd.
# SPDX-License-Identifier: Apache-2.0

import pytest

from openviking.storage.memory_relation_store import (
    MemoryRelation,
    MemoryRelationStore,
    RelationType,
)


@pytest.fixture
def store():
    return MemoryRelationStore()


@pytest.fixture
def sample_relation():
    return MemoryRelation(
        source_uri="viking://user/u1/memories/preferences/theme_light",
        target_uri="viking://user/u1/memories/preferences/theme_dark",
        relation_type=RelationType.SUPERSEDES,
        metadata={"reason": "user changed preference"},
    )


class TestMemoryRelationStore:
    @pytest.mark.asyncio
    async def test_create_returns_id(self, store, sample_relation):
        rel_id = await store.create(sample_relation)
        assert rel_id == sample_relation.id
        assert store.count() == 1

    @pytest.mark.asyncio
    async def test_create_deduplicates(self, store, sample_relation):
        await store.create(sample_relation)
        dup = MemoryRelation(
            source_uri=sample_relation.source_uri,
            target_uri=sample_relation.target_uri,
            relation_type=sample_relation.relation_type,
        )
        returned_id = await store.create(dup)
        assert returned_id == sample_relation.id
        assert store.count() == 1

    @pytest.mark.asyncio
    async def test_query_outgoing(self, store, sample_relation):
        await store.create(sample_relation)
        results = await store.query(sample_relation.source_uri, direction="outgoing")
        assert len(results) == 1
        assert results[0].target_uri == sample_relation.target_uri

    @pytest.mark.asyncio
    async def test_query_incoming(self, store, sample_relation):
        await store.create(sample_relation)
        results = await store.query(sample_relation.target_uri, direction="incoming")
        assert len(results) == 1
        assert results[0].source_uri == sample_relation.source_uri

    @pytest.mark.asyncio
    async def test_query_both(self, store, sample_relation):
        await store.create(sample_relation)
        # Query by source
        results_src = await store.query(sample_relation.source_uri, direction="both")
        assert len(results_src) == 1
        # Query by target
        results_tgt = await store.query(sample_relation.target_uri, direction="both")
        assert len(results_tgt) == 1

    @pytest.mark.asyncio
    async def test_query_by_type(self, store, sample_relation):
        await store.create(sample_relation)
        # Add a different type
        related = MemoryRelation(
            source_uri=sample_relation.source_uri,
            target_uri="viking://user/u1/memories/preferences/font_size",
            relation_type=RelationType.RELATED_TO,
        )
        await store.create(related)
        assert store.count() == 2

        supersedes = await store.query(
            sample_relation.source_uri,
            relation_type=RelationType.SUPERSEDES,
            direction="outgoing",
        )
        assert len(supersedes) == 1
        assert supersedes[0].relation_type == RelationType.SUPERSEDES

    @pytest.mark.asyncio
    async def test_query_no_results(self, store):
        results = await store.query("viking://nonexistent")
        assert results == []

    @pytest.mark.asyncio
    async def test_delete_by_id(self, store, sample_relation):
        await store.create(sample_relation)
        assert store.count() == 1
        deleted = await store.delete(sample_relation.id)
        assert deleted is True
        assert store.count() == 0

    @pytest.mark.asyncio
    async def test_delete_nonexistent(self, store):
        deleted = await store.delete("nonexistent-id")
        assert deleted is False

    @pytest.mark.asyncio
    async def test_delete_by_uri(self, store, sample_relation):
        await store.create(sample_relation)
        # Add another relation referencing the same URI as target
        r2 = MemoryRelation(
            source_uri="viking://user/u1/memories/preferences/other",
            target_uri=sample_relation.target_uri,
            relation_type=RelationType.CONTRADICTS,
        )
        await store.create(r2)
        assert store.count() == 2

        deleted = await store.delete_by_uri(sample_relation.target_uri)
        assert deleted == 2
        assert store.count() == 0

    @pytest.mark.asyncio
    async def test_is_superseded(self, store, sample_relation):
        await store.create(sample_relation)
        assert await store.is_superseded(sample_relation.target_uri) is True
        assert await store.is_superseded(sample_relation.source_uri) is False
        assert await store.is_superseded("viking://nonexistent") is False

    @pytest.mark.asyncio
    async def test_get_superseded_uris(self, store, sample_relation):
        await store.create(sample_relation)
        superseded = await store.get_superseded_uris(sample_relation.source_uri)
        assert superseded == [sample_relation.target_uri]

    def test_from_dict_round_trip(self, sample_relation):
        # Pure serialization round-trip; no awaits, so a plain sync test.
        d = sample_relation.to_dict()
        restored = MemoryRelation.from_dict(d)
        assert restored.source_uri == sample_relation.source_uri
        assert restored.target_uri == sample_relation.target_uri
        assert restored.relation_type == sample_relation.relation_type
        assert restored.id == sample_relation.id

    @pytest.mark.asyncio
    async def test_all_relation_types(self, store):
        for rtype in RelationType:
            r = MemoryRelation(
                source_uri=f"viking://src/{rtype.value}",
                target_uri=f"viking://tgt/{rtype.value}",
                relation_type=rtype,
            )
            await store.create(r)
        # Derive the expected count from the enum so adding a member
        # does not silently break this test.
        assert store.count() == len(RelationType)
" * 30 + ) # 540 chars -> 180 estimated tokens, well over max_tokens=5 mock_client.embeddings.create.reset_mock() result = embedder.embed(text)