diff --git a/memori/agents/memory_agent.py b/memori/agents/memory_agent.py index 59bcd7e2..180ab9d5 100644 --- a/memori/agents/memory_agent.py +++ b/memori/agents/memory_agent.py @@ -70,15 +70,41 @@ def _detect_database_type(self, db_manager): logger.debug(f"MemoryAgent: Detected database type: {self._database_type}") return self._database_type - SYSTEM_PROMPT = """You are an advanced Memory Processing Agent responsible for analyzing conversations and extracting structured information with intelligent classification and conscious context detection. + SYSTEM_PROMPT = """You are an advanced Memory Processing Agent responsible for analyzing conversations and extracting structured information with intelligent classification, conscious context detection, and comprehensive entity extraction for graph-based search. Your primary functions: 1. **Intelligent Classification**: Categorize memories with enhanced classification system 2. **Conscious Context Detection**: Identify user context information for immediate promotion -3. **Entity Extraction**: Extract comprehensive entities and keywords +3. **Entity Extraction**: Extract comprehensive entities and keywords for graph building 4. **Deduplication**: Identify and handle duplicate information 5. **Context Filtering**: Determine what should be stored vs filtered out +**ENTITY EXTRACTION FOR GRAPH SEARCH (CRITICAL):** +Extract ALL relevant entities across 6 types: +- **person**: Names, authors, team members, users (e.g., "Alice", "Bob", "the team lead") +- **technology**: Technologies, tools, libraries, languages (e.g., "Python", "Docker", "JWT", "FastAPI") +- **topic**: Concepts, subjects, themes (e.g., "authentication", "API design", "performance optimization") +- **skill**: Skills, abilities, competencies (e.g., "debugging", "code review", "system design") +- **project**: Projects, repositories, systems (e.g., "user-dashboard", "payment-service") +- **keyword**: Important terms, acronyms, specific details (e.g., "rate limiting", "OAuth2", "Redis cache") + +**ENTITY EXTRACTION RULES:** +- Extract 5-10 high-quality entities per memory +- Include both explicit mentions AND implicit references +- Normalize technical terms (e.g., "jwt" → "JWT", "docker" → "Docker") +- Extract person names even if only first name mentioned +- Include technology stack components +- Extract project names and system names +- Include important concepts and patterns discussed + +**WHY THIS MATTERS:** +Entities enable graph-based search to: +- Find related memories via shared entities +- Discover context through entity relationships +- Answer queries like "Show me everything about JWT" +- Connect memories discussing similar technologies +- Build knowledge graphs of user's work and interests + **ENHANCED CLASSIFICATION SYSTEM:** **CONSCIOUS_INFO** (Auto-promote to short-term context): diff --git a/memori/agents/retrieval_agent.py b/memori/agents/retrieval_agent.py index 2006cbb8..34c8bff6 100644 --- a/memori/agents/retrieval_agent.py +++ b/memori/agents/retrieval_agent.py @@ -24,37 +24,84 @@ class MemorySearchEngine: Uses OpenAI Structured Outputs to understand queries and plan searches. """ - SYSTEM_PROMPT = """You are a Memory Search Agent responsible for understanding user queries and planning effective memory retrieval strategies. + SYSTEM_PROMPT = """You are an advanced Memory Search Agent with graph-based retrieval capabilities. You understand memory relationships and can plan multi-hop graph traversals. Your primary functions: 1. 
**Analyze Query Intent**: Understand what the user is actually looking for 2. **Extract Search Parameters**: Identify key entities, topics, and concepts -3. **Plan Search Strategy**: Recommend the best approach to find relevant memories -4. **Filter Recommendations**: Suggest appropriate filters for category, importance, etc. - -**MEMORY CATEGORIES AVAILABLE:** -- **fact**: Factual information, definitions, technical details, specific data points -- **preference**: User preferences, likes/dislikes, settings, personal choices, opinions -- **skill**: Skills, abilities, competencies, learning progress, expertise levels -- **context**: Project context, work environment, current situations, background info -- **rule**: Rules, policies, procedures, guidelines, constraints - -**SEARCH STRATEGIES:** -- **keyword_search**: Direct keyword/phrase matching in content -- **entity_search**: Search by specific entities (people, technologies, topics) -- **category_filter**: Filter by memory categories -- **importance_filter**: Filter by importance levels -- **temporal_filter**: Search within specific time ranges -- **semantic_search**: Conceptual/meaning-based search - -**QUERY INTERPRETATION GUIDELINES:** -- "What did I learn about X?" → Focus on facts and skills related to X -- "My preferences for Y" → Focus on preference category -- "Rules about Z" → Focus on rule category -- "Recent work on A" → Temporal filter + context/skill categories -- "Important information about B" → Importance filter + keyword search - -Be strategic and comprehensive in your search planning.""" +3. **Plan Graph Strategy**: Choose the best graph traversal approach +4. **Configure Graph Expansion**: Set hop distance, relationship strength, and traversal strategy + +**MEMORY CATEGORIES:** +- **fact**: Factual information, definitions, technical details +- **preference**: User preferences, likes/dislikes, opinions +- **skill**: Skills, abilities, expertise levels +- **context**: Project context, work environment, background +- **rule**: Rules, policies, procedures, guidelines + +**GRAPH SEARCH STRATEGIES (choose one):** + +1. **text_only** (~30ms) + - Use for: Simple keyword searches, no context needed + - Example: "Find memories containing 'API key'" + +2. **entity_first** (~100ms) + - Use for: Entity-tagged searches (people, technologies, topics) + - Example: "Show me everything about JWT", "What did Alice say?" + +3. **graph_expansion_1hop** (~150ms) + - Use for: Find directly related context (1-hop away) + - Example: "Tell me about X and related topics" + +4. **graph_expansion_2hop** (~300ms) + - Use for: Deep context discovery (2-hop away) + - Example: "Find everything connected to this project" + +5. **graph_walk_contextual** (~350ms) + - Use for: "Related to X" queries requiring max depth + - Example: "All memories related to authentication system" + +6. **entity_cluster_discovery** (~200ms) + - Use for: Multi-entity queries, finding shared context + - Example: "Memories about both JWT and OAuth" + +7. 
**category_focused_graph** (~180ms) + - Use for: Category filter + graph expansion + - Example: "Recent facts about Docker, with related context" + +**RELATIONSHIP TYPES (for filtering):** +- semantic_similarity: Similar topics/concepts +- causality: Cause and effect relationships +- reference: One memory references another +- elaboration: Provides more detail +- supports: Reinforces/validates information +- prerequisite: Required knowledge +- temporal: Time-based relationships +- related_entity: Share common entities + +**GRAPH EXPANSION PARAMETERS:** +- **hop_distance**: 0-3 (0=no expansion, 1-3=multi-hop) +- **min_relationship_strength**: 0.0-1.0 (filter weak relationships) +- **expansion_strategy**: breadth_first, depth_first, strongest_first, entity_guided +- **require_entity_overlap**: true/false (only traverse to memories sharing entities) + +**QUERY → STRATEGY EXAMPLES:** + +"What is JWT?" → text_only (simple definition) +"Show me everything about JWT" → entity_first (entity-tagged search) +"JWT and related auth topics" → graph_expansion_1hop (1-hop from JWT) +"All authentication-related memories" → graph_walk_contextual (deep walk) +"Memories about JWT and OAuth together" → entity_cluster_discovery +"Recent Docker facts with context" → category_focused_graph + +**COMPLEX QUERY PLANNING:** +- Single entity + context needed → entity_first or graph_expansion_1hop +- Multiple related entities → entity_cluster_discovery +- "Everything about X" → graph_walk_contextual with max hops +- Category-specific with context → category_focused_graph +- Simple lookup → text_only + +Be strategic and choose the right balance between speed and depth.""" def __init__( self, @@ -129,8 +176,12 @@ def plan_search(self, query: str, context: str | None = None) -> MemorySearchQue logger.debug(f"Using cached search plan for: {query}") return cached_result - # Prepare the prompt - prompt = f"User query: {query}" + # Prepare the prompt - clean query first to prevent duplication + cleaned_query = query.strip() + while cleaned_query.lower().startswith("user query:"): + cleaned_query = cleaned_query[11:].strip() + + prompt = f"User query: {cleaned_query}" if context: prompt += f"\nAdditional context: {context}" @@ -635,8 +686,12 @@ def _plan_search_with_fallback_parsing(self, query: str) -> MemorySearchQuery: but doesn't support structured outputs (like Ollama, local models, etc.) 
""" try: - # Prepare the prompt from raw query - prompt = f"User query: {query}" + # Prepare the prompt from raw query - clean query first to prevent duplication + cleaned_query = query.strip() + while cleaned_query.lower().startswith("user query:"): + cleaned_query = cleaned_query[11:].strip() + + prompt = f"User query: {cleaned_query}" # Enhanced system prompt for JSON output json_system_prompt = ( @@ -729,6 +784,25 @@ def _create_search_query_from_dict( except ValueError: logger.debug(f"Invalid category filter '{cat_str}', skipping") + # Handle search_strategy - LLM might return list or string + raw_strategy = data.get("search_strategy", "text_only") + if isinstance(raw_strategy, list): + # Take first strategy if list provided + strategy_str = raw_strategy[0] if raw_strategy else "text_only" + else: + strategy_str = raw_strategy + + # Convert to enum, fallback to TEXT_ONLY if invalid + from ..utils.pydantic_models import SearchStrategy + try: + if isinstance(strategy_str, str): + search_strategy_enum = SearchStrategy(strategy_str.lower()) + else: + search_strategy_enum = SearchStrategy.TEXT_ONLY + except (ValueError, AttributeError): + logger.debug(f"Invalid search strategy '{strategy_str}', using TEXT_ONLY") + search_strategy_enum = SearchStrategy.TEXT_ONLY + # Create search query object with proper validation search_query = MemorySearchQuery( query_text=data.get("query_text", original_query), @@ -739,7 +813,7 @@ def _create_search_query_from_dict( min_importance=max( 0.0, min(1.0, float(data.get("min_importance", 0.0))) ), - search_strategy=data.get("search_strategy", ["keyword_search"]), + search_strategy=search_strategy_enum, expected_result_types=data.get("expected_result_types", ["any"]), ) diff --git a/memori/core/conversation.py b/memori/core/conversation.py index 26393419..6fc9a4fc 100644 --- a/memori/core/conversation.py +++ b/memori/core/conversation.py @@ -332,7 +332,7 @@ def _build_auto_context_prompt(self, context: list[dict[str, Any]]) -> str: for mem in context: if isinstance(mem, dict): content = mem.get("searchable_content", "") or mem.get("summary", "") - category = mem.get("category_primary", "") + category = mem.get("category_primary") or "" # Ensure string, not None # Skip duplicates (case-insensitive) content_key = content.lower().strip() @@ -340,7 +340,7 @@ def _build_auto_context_prompt(self, context: list[dict[str, Any]]) -> str: continue seen_content.add(content_key) - if category.startswith("essential_"): + if category and category.startswith("essential_"): context_prompt += f"[{category.upper()}] {content}\n" else: context_prompt += f"- {content}\n" diff --git a/memori/core/memory.py b/memori/core/memory.py index 843faee2..29dab509 100644 --- a/memori/core/memory.py +++ b/memori/core/memory.py @@ -7,7 +7,7 @@ import time import uuid from datetime import datetime -from typing import Any +from typing import Any, Optional from loguru import logger @@ -26,7 +26,12 @@ from ..database.sqlalchemy_manager import SQLAlchemyDatabaseManager from ..utils.exceptions import DatabaseError, MemoriError from ..utils.logging import LoggingManager -from ..utils.pydantic_models import ConversationContext +from ..utils.pydantic_models import ( + ConversationContext, + GraphExpansionConfig, + ScoringWeights, + SearchStrategy, +) from .conversation import ConversationManager @@ -67,6 +72,7 @@ def __init__( database_prefix: str | None = None, # Database name prefix database_suffix: str | None = None, # Database name suffix conscious_memory_limit: int = 10, # Limit for conscious 
memory processing + graph_search: bool = False, # Use graph-based search by default ): """ Initialize Memori memory system v1.0. @@ -97,6 +103,8 @@ def __init__( enable_auto_creation: Enable automatic database creation if database doesn't exist database_prefix: Optional prefix for database name (for multi-tenant setups) database_suffix: Optional suffix for database name (e.g., 'dev', 'prod', 'test') + conscious_memory_limit: Maximum number of memories to process for conscious context + graph_search: Use graph-based search by default (enables entity and relationship traversal) """ self.database_connect = database_connect self.template = template @@ -111,6 +119,7 @@ def __init__( self.schema_init = schema_init self.database_prefix = database_prefix self.database_suffix = database_suffix + self.graph_search = graph_search # Validate conscious_memory_limit parameter if not isinstance(conscious_memory_limit, int) or conscious_memory_limit < 1: raise ValueError("conscious_memory_limit must be a positive integer") @@ -246,6 +255,53 @@ def __init__( logger.info( f"Agents initialized successfully with model: {effective_model}" ) + + # Initialize graph-based search components + try: + from ..database.graph_search_service import GraphSearchService + from ..processors import ( + EntityExtractionService, + RelationshipDetectionService, + ) + + self.graph_search_service = GraphSearchService(self.db_manager) + + # Initialize entity extraction (using smaller model for speed) + if self.provider_config: + client = self.provider_config.create_client() + else: + import openai + + client = openai.OpenAI(api_key=self.openai_api_key) + + self.entity_extractor = EntityExtractionService( + client=client, model="gpt-4o-mini" + ) + + self.relationship_detector = RelationshipDetectionService( + db_manager=self.db_manager + ) + + logger.info("Graph components initialized successfully") + + # Connect graph components to database manager + if hasattr(self.db_manager, 'set_graph_components'): + self.db_manager.set_graph_components( + graph_search_service=self.graph_search_service, + entity_extractor=self.entity_extractor, + relationship_detector=self.relationship_detector + ) + logger.debug("Graph components connected to database manager") + + except Exception as graph_error: + logger.warning( + f"Graph components initialization failed: {graph_error}. " + "Graph-based search will not be available." + ) + self.graph_search_service = None + self.entity_extractor = None + self.relationship_detector = None + except ImportError as e: logger.warning( f"Failed to import LLM agents: {e}. Memory ingestion disabled." 
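For orientation, a minimal usage sketch of the new flag follows; the connection string and query text are illustrative placeholders, and it assumes the package's usual top-level export (from memori import Memori) together with the constructor parameters shown in this hunk:

from memori import Memori

# Opt in to the entity/relationship traversal added in this change.
memori = Memori(
    database_connect="sqlite:///memori_graph.db",  # placeholder path
    graph_search=True,
)

# search() routes through graph search when enough memories exist and the graph
# components initialized successfully; otherwise it falls back to text retrieval.
results = memori.search("JWT authentication decisions", limit=5)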
@@ -807,6 +863,17 @@ def disable_interceptor(self, interceptor_name: str = None) -> bool: def _inject_openai_context(self, kwargs): """Inject context for OpenAI calls based on ingest mode using ConversationManager""" try: + # Check if we're in a recursion guard - if so, skip injection entirely + if hasattr(self, "_in_context_retrieval") and self._in_context_retrieval: + logger.debug("OpenAI context injection skipped - inside retrieval operation") + return kwargs + + # Check for system-only messages (internal operations) - skip injection + messages = kwargs.get("messages", []) + if messages and len(messages) == 1 and messages[0].get("role") == "system": + logger.debug("OpenAI context injection skipped - internal system operation") + return kwargs + # Check for deferred conscious initialization self._check_deferred_initialization() @@ -1194,7 +1261,36 @@ def _get_auto_ingest_context(self, user_input: str) -> list[dict[str, Any]]: f"Auto-ingest: Starting context retrieval for query: '{user_input[:50]}...' in namespace: '{self.namespace}'" ) - # Always try direct database search first as it's more reliable + # Use graph search if enabled, otherwise use direct database search + if self.graph_search: + logger.debug("Auto-ingest: Using graph-based search (graph_search=True)") + try: + # Use the search() method which has graph search logic + results = self.search(query=user_input, limit=5) + logger.debug( + f"Auto-ingest: Graph search returned {len(results) if results else 0} results" + ) + + if results: + for i, result in enumerate( + results[:3] + ): # Log first 3 results for debugging + logger.debug( + f"Auto-ingest: Result {i+1}: {type(result)} with keys: {list(result.keys()) if isinstance(result, dict) else 'N/A'}" + ) + + # Add search metadata + for result in results: + if isinstance(result, dict): + result["retrieval_method"] = "graph_search" + result["retrieval_query"] = user_input + return results + except Exception as graph_search_e: + logger.warning(f"Auto-ingest: Graph search failed: {graph_search_e}") + logger.debug("Auto-ingest: Falling back to direct database search") + results = [] + + # Direct database search (default or fallback) logger.debug("Auto-ingest: Using direct database search (primary method)") logger.debug( f"Auto-ingest: Database manager type: {type(self.db_manager).__name__}" @@ -2455,12 +2551,33 @@ def add(self, text: str, metadata: dict[str, Any] | None = None) -> str: metadata=metadata or {"type": "manual_memory", "source": "add_method"}, ) + def _get_memory_count(self) -> int: + """Get total count of long-term memories in database""" + try: + with self.db_manager.get_session() as session: + from ..database.models import LongTermMemory + + count = ( + session.query(LongTermMemory) + .filter(LongTermMemory.namespace == self.namespace) + .count() + ) + return count + except Exception as e: + logger.warning(f"Failed to get memory count: {e}") + return 0 + def search(self, query: str, limit: int = 5) -> list[dict[str, Any]]: """ Search for memories/conversations based on a query. This is a unified method that works with both SQL and MongoDB backends. 
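As a rough consumption sketch (the query text is illustrative, and the graph-specific keys are only populated when the graph path described below produced the hit):

results = memori.search("everything related to the payment-service refactor", limit=5)
for r in results:
    print(r.get("summary"), r.get("category_primary"))
    # Graph metadata attached by the graph search path:
    print(r.get("hop_distance"), r.get("shared_entities"), r.get("match_reason"))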
+ **Automatic Search Strategy:** + - When memories < 10: Uses text-based search (graph not useful yet) + - When memories >= 10: Uses graph-based search with entity and relationship traversal + - If graph_search=True is explicitly set, respects that setting + Args: query: Search query string limit: Maximum number of results to return @@ -2473,8 +2590,154 @@ def search(self, query: str, limit: int = 5) -> list[dict[str, Any]]: return [] try: - # Use the existing retrieve_context method for consistency - return self.retrieve_context(query, limit=limit) + # Clean the query - remove "User query:" prefix if present (prevents recursion issues) + cleaned_query = query.strip() + while cleaned_query.lower().startswith("user query:"): + cleaned_query = cleaned_query[11:].strip() + + # Use cleaned query for search + search_query = cleaned_query if cleaned_query else query + + # Determine if we should use graph search (automatic or explicit) + memory_count = self._get_memory_count() + use_graph = False + + if hasattr(self, 'graph_search_service') and self.graph_search_service: + if hasattr(self, 'graph_search') and self.graph_search: + # Explicit graph_search=True parameter + use_graph = memory_count >= 5 + if memory_count < 5: + logger.debug( + f"graph_search=True but only {memory_count} memories (need 5+), using text search" + ) + elif memory_count >= 10: + # Automatic graph search when enough memories + use_graph = True + logger.debug( + f"Auto-enabling graph search ({memory_count} memories >= 10 threshold)" + ) + + # Use graph search if conditions are met + if use_graph: + try: + logger.debug(f"Using graph-based search for query: '{search_query}'") + + entity_filters: list[str] = [] + category_filters: list[str] = [] + graph_strategy = SearchStrategy.GRAPH_EXPANSION_1HOP + hop_distance = 1 + scoring_weights = None + graph_config: Optional[GraphExpansionConfig] = GraphExpansionConfig( + enabled=True, + hop_distance=hop_distance, + min_relationship_strength=0.2, + ) + + if self.search_engine: + try: + plan = self.search_engine.plan_search(search_query) + entity_filters = plan.entity_filters or [] + category_filters = [c.value for c in plan.category_filters] + + raw_strategies = plan.search_strategy or [] + + def _normalize_strategy(value: Any) -> str: + if hasattr(value, "value"): + return str(value.value) + if isinstance(value, str): + return value + return str(value) + + requested_strategies = { + _normalize_strategy(s).lower() for s in raw_strategies + } + + if "text_only" in requested_strategies: + graph_strategy = SearchStrategy.TEXT_ONLY + graph_config = None + entity_filters = [] + category_filters = [] + elif { + "graph_walk", + "graph_traversal", + "graph_context", + } & requested_strategies: + graph_strategy = SearchStrategy.GRAPH_WALK_CONTEXTUAL + hop_distance = 3 + elif "graph_expansion" in requested_strategies: + graph_strategy = SearchStrategy.GRAPH_EXPANSION_2HOP + hop_distance = 2 + elif "entity_search" in requested_strategies and entity_filters: + graph_strategy = SearchStrategy.ENTITY_FIRST + graph_config = None + elif ( + "category_filter" in requested_strategies + and category_filters + ): + graph_strategy = SearchStrategy.CATEGORY_FOCUSED_GRAPH + hop_distance = 1 + + if graph_config is not None and graph_strategy != SearchStrategy.TEXT_ONLY: + graph_config = graph_config.copy(update={"hop_distance": hop_distance}) + + except Exception as plan_error: + logger.debug( + f"Search planner unavailable or failed: {plan_error}" + ) + + if graph_strategy == SearchStrategy.TEXT_ONLY: + 
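+                             # Planner decided graph traversal adds no value for this query, so try the
+                             # cheaper text retrieval path first and only fall through to graph search
+                             # if it returns nothing.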
logger.debug("Plan requested TEXT_ONLY search; using direct retrieval") + text_results = self.retrieve_context(search_query, limit=limit) + if text_results: + return text_results + logger.debug("TEXT_ONLY retrieval returned no results, continuing to graph fallback") + + # Use GraphSearchService directly + graph_results = self.graph_search_service.search( + query_text=search_query, + strategy=graph_strategy, + namespace=self.namespace, + entities=entity_filters, + categories=category_filters, + graph_expansion=graph_config, + scoring_weights=scoring_weights, + max_results=limit + ) + + if graph_results: + logger.debug(f"Graph search returned {len(graph_results)} results") + + # Convert GraphSearchResult objects to dict format + results = [] + for r in graph_results: + results.append({ + 'memory_id': r.memory_id, + 'summary': r.summary, + 'content': r.content, + 'importance_score': r.importance_score, + 'created_at': r.timestamp, + 'category_primary': r.category.value if r.category else None, + 'composite_score': r.composite_score, + 'search_strategy': 'graph_search', + # Graph metadata + 'hop_distance': r.hop_distance, + 'shared_entities': r.shared_entities, + 'match_reason': r.match_reason, + 'graph_strength_score': r.graph_strength_score, + 'entity_overlap_score': r.entity_overlap_score, + }) + + return results + + logger.debug("Graph search returned no results, falling back to text search") + + except Exception as graph_error: + logger.warning(f"Graph search failed, falling back to text search: {graph_error}") + logger.debug("Graph search error details", exc_info=True) + + # Fall back to traditional text search + return self.retrieve_context(search_query, limit=limit) + except Exception as e: logger.error(f"Search failed: {e}") return [] diff --git a/memori/database/graph_queries/__init__.py b/memori/database/graph_queries/__init__.py new file mode 100644 index 00000000..3e0803ed --- /dev/null +++ b/memori/database/graph_queries/__init__.py @@ -0,0 +1,50 @@ +""" +Graph Query Builders for Database-Agnostic Graph Traversal +Provides query builders for PostgreSQL, MySQL, SQLite, and MongoDB +""" + +from .base import GraphQueryBuilder +from .postgresql import PostgreSQLGraphQueryBuilder +from .mysql import MySQLGraphQueryBuilder +from .sqlite import SQLiteGraphQueryBuilder +from .mongodb import MongoDBGraphQueryBuilder + + +def get_query_builder(dialect: str) -> GraphQueryBuilder: + """ + Factory function to get appropriate query builder for database dialect + + Args: + dialect: Database dialect name (postgresql, mysql, sqlite, mongodb) + + Returns: + Appropriate GraphQueryBuilder instance + + Raises: + ValueError: If dialect is not supported + """ + dialect = dialect.lower() + + if dialect == "postgresql": + return PostgreSQLGraphQueryBuilder() + elif dialect == "mysql": + return MySQLGraphQueryBuilder() + elif dialect == "sqlite": + return SQLiteGraphQueryBuilder() + elif dialect == "mongodb": + return MongoDBGraphQueryBuilder() + else: + raise ValueError( + f"Unsupported database dialect: {dialect}. 
" + f"Supported: postgresql, mysql, sqlite, mongodb" + ) + + +__all__ = [ + "GraphQueryBuilder", + "PostgreSQLGraphQueryBuilder", + "MySQLGraphQueryBuilder", + "SQLiteGraphQueryBuilder", + "MongoDBGraphQueryBuilder", + "get_query_builder", +] diff --git a/memori/database/graph_queries/base.py b/memori/database/graph_queries/base.py new file mode 100644 index 00000000..aff8426b --- /dev/null +++ b/memori/database/graph_queries/base.py @@ -0,0 +1,167 @@ +""" +Base Graph Query Builder Interface +Defines the contract for all database-specific query builders +""" + +from abc import ABC, abstractmethod +from typing import Any, Dict, List, Optional + + +class GraphQueryBuilder(ABC): + """Abstract base class for database-specific graph query builders""" + + @abstractmethod + def build_entity_search_query( + self, + entities: List[str], + entity_types: Optional[List[str]] = None, + namespace: str = "default", + min_relevance: float = 0.0, + limit: int = 50, + ) -> tuple[str, Dict[str, Any]]: + """ + Build query to find memories by entity values + + Args: + entities: List of entity values to search for + entity_types: Optional filter by entity types + namespace: Memory namespace + min_relevance: Minimum relevance score threshold + limit: Maximum results + + Returns: + Tuple of (query_string, params_dict) + """ + pass + + @abstractmethod + def build_graph_expansion_query( + self, + seed_memory_ids: List[str], + hop_distance: int, + min_strength: float, + relationship_types: Optional[List[str]] = None, + namespace: str = "default", + limit_per_hop: int = 10, + ) -> tuple[str, Dict[str, Any]]: + """ + Build query to expand from seed memories via graph relationships + + Args: + seed_memory_ids: Starting memory IDs + hop_distance: Number of hops to traverse (1-3) + min_strength: Minimum relationship strength + relationship_types: Optional filter by relationship types + namespace: Memory namespace + limit_per_hop: Maximum results per hop level + + Returns: + Tuple of (query_string, params_dict) + """ + pass + + @abstractmethod + def build_entity_cluster_query( + self, + entities: List[str], + namespace: str = "default", + min_shared_entities: int = 2, + limit: int = 50, + ) -> tuple[str, Dict[str, Any]]: + """ + Build query to find memories that share multiple entities + + Args: + entities: List of entity values + namespace: Memory namespace + min_shared_entities: Minimum number of shared entities + limit: Maximum results + + Returns: + Tuple of (query_string, params_dict) + """ + pass + + @abstractmethod + def build_relationship_discovery_query( + self, + memory_id: str, + relationship_types: Optional[List[str]] = None, + min_strength: float = 0.5, + namespace: str = "default", + limit: int = 20, + ) -> tuple[str, Dict[str, Any]]: + """ + Build query to find all direct relationships for a memory + + Args: + memory_id: Memory ID to find relationships for + relationship_types: Optional filter by types + min_strength: Minimum relationship strength + namespace: Memory namespace + limit: Maximum results + + Returns: + Tuple of (query_string, params_dict) + """ + pass + + @abstractmethod + def build_path_finding_query( + self, + source_memory_id: str, + target_memory_id: str, + max_depth: int = 3, + namespace: str = "default", + ) -> tuple[str, Dict[str, Any]]: + """ + Build query to find paths between two memories + + Args: + source_memory_id: Starting memory + target_memory_id: Destination memory + max_depth: Maximum path length + namespace: Memory namespace + + Returns: + Tuple of (query_string, 
params_dict) + """ + pass + + @abstractmethod + def build_shared_entities_query( + self, + memory_id: str, + namespace: str = "default", + min_overlap: int = 1, + limit: int = 50, + ) -> tuple[str, Dict[str, Any]]: + """ + Build query to find memories sharing entities with given memory + + Args: + memory_id: Memory ID to compare against + namespace: Memory namespace + min_overlap: Minimum number of shared entities + limit: Maximum results + + Returns: + Tuple of (query_string, params_dict) + """ + pass + + def supports_recursive_cte(self) -> bool: + """Whether this database supports recursive CTEs""" + return True + + def get_parameter_placeholder(self, param_name: str) -> str: + """ + Get database-specific parameter placeholder + + Args: + param_name: Parameter name + + Returns: + Placeholder string (e.g., ?, %s, :param_name) + """ + return f":{param_name}" diff --git a/memori/database/graph_queries/mongodb.py b/memori/database/graph_queries/mongodb.py new file mode 100644 index 00000000..eef77afb --- /dev/null +++ b/memori/database/graph_queries/mongodb.py @@ -0,0 +1,335 @@ +""" +MongoDB Graph Query Builder +Uses aggregation pipelines and $graphLookup for graph traversal +""" + +from typing import Any, Dict, List, Optional + +from .base import GraphQueryBuilder + + +class MongoDBGraphQueryBuilder(GraphQueryBuilder): + """MongoDB-specific graph query builder using aggregation framework""" + + def build_entity_search_query( + self, + entities: List[str], + entity_types: Optional[List[str]] = None, + namespace: str = "default", + min_relevance: float = 0.0, + limit: int = 50, + ) -> tuple[Dict[str, Any], Dict[str, Any]]: + """Build MongoDB aggregation pipeline for entity search""" + + # Normalize entities + normalized_entities = [e.lower() for e in entities] + + # Build aggregation pipeline + pipeline = [ + { + "$match": { + "namespace": namespace, + "normalized_value": {"$in": normalized_entities}, + "relevance_score": {"$gte": min_relevance}, + } + } + ] + + if entity_types: + pipeline[0]["$match"]["entity_type"] = {"$in": entity_types} + + pipeline.extend( + [ + { + "$group": { + "_id": {"memory_id": "$memory_id", "memory_type": "$memory_type"}, + "entity_match_count": {"$sum": 1}, + "avg_relevance": {"$avg": "$relevance_score"}, + "matched_entities": {"$addToSet": "$entity_value"}, + } + }, + {"$sort": {"entity_match_count": -1, "avg_relevance": -1}}, + {"$limit": limit}, + { + "$project": { + "memory_id": "$_id.memory_id", + "memory_type": "$_id.memory_type", + "entity_match_count": 1, + "avg_relevance": 1, + "matched_entities": 1, + "_id": 0, + } + }, + ] + ) + + # MongoDB doesn't use traditional params, return pipeline + return {"__mongodb_pipeline": pipeline}, {} + + def build_graph_expansion_query( + self, + seed_memory_ids: List[str], + hop_distance: int, + min_strength: float, + relationship_types: Optional[List[str]] = None, + namespace: str = "default", + limit_per_hop: int = 10, + ) -> tuple[Dict[str, Any], Dict[str, Any]]: + """Build MongoDB $graphLookup pipeline for graph expansion""" + + match_condition = { + "namespace": namespace, + "strength": {"$gte": min_strength}, + } + + if relationship_types: + match_condition["relationship_type"] = {"$in": relationship_types} + + pipeline = [ + {"$match": {"memory_id": {"$in": seed_memory_ids}}}, + { + "$graphLookup": { + "from": "memory_relationships", + "startWith": "$memory_id", + "connectFromField": "target_memory_id", + "connectToField": "source_memory_id", + "as": "graph_path", + "maxDepth": hop_distance - 1, + "depthField": 
"hop", + "restrictSearchWithMatch": match_condition, + } + }, + {"$unwind": "$graph_path"}, + { + "$group": { + "_id": "$graph_path.target_memory_id", + "hop": {"$min": "$graph_path.hop"}, + "cumulative_strength": {"$max": "$graph_path.strength"}, + "relationship_types": {"$addToSet": "$graph_path.relationship_type"}, + } + }, + {"$sort": {"hop": 1, "cumulative_strength": -1}}, + {"$limit": limit_per_hop * hop_distance}, + { + "$project": { + "memory_id": "$_id", + "hop": {"$add": ["$hop", 1]}, # Adjust hop count + "cumulative_strength": 1, + "relationship_types": 1, + "_id": 0, + } + }, + ] + + return {"__mongodb_pipeline": pipeline}, {} + + def build_entity_cluster_query( + self, + entities: List[str], + namespace: str = "default", + min_shared_entities: int = 2, + limit: int = 50, + ) -> tuple[Dict[str, Any], Dict[str, Any]]: + """Find memories sharing multiple entities""" + + normalized_entities = [e.lower() for e in entities] + + pipeline = [ + { + "$match": { + "namespace": namespace, + "normalized_value": {"$in": normalized_entities}, + } + }, + { + "$group": { + "_id": {"memory_id": "$memory_id", "memory_type": "$memory_type"}, + "shared_entity_count": {"$addToSet": "$normalized_value"}, + "avg_relevance": {"$avg": "$relevance_score"}, + "shared_entities": {"$addToSet": "$entity_value"}, + } + }, + {"$match": {"shared_entity_count": {"$gte": min_shared_entities}}}, + { + "$addFields": { + "shared_entity_count": {"$size": "$shared_entity_count"} + } + }, + {"$sort": {"shared_entity_count": -1, "avg_relevance": -1}}, + {"$limit": limit}, + { + "$project": { + "memory_id": "$_id.memory_id", + "memory_type": "$_id.memory_type", + "shared_entity_count": 1, + "avg_relevance": 1, + "shared_entities": 1, + "_id": 0, + } + }, + ] + + return {"__mongodb_pipeline": pipeline}, {} + + def build_relationship_discovery_query( + self, + memory_id: str, + relationship_types: Optional[List[str]] = None, + min_strength: float = 0.5, + namespace: str = "default", + limit: int = 20, + ) -> tuple[Dict[str, Any], Dict[str, Any]]: + """Find all direct relationships for a memory""" + + match_condition = { + "$or": [ + {"source_memory_id": memory_id}, + {"target_memory_id": memory_id}, + ], + "namespace": namespace, + "strength": {"$gte": min_strength}, + } + + if relationship_types: + match_condition["relationship_type"] = {"$in": relationship_types} + + pipeline = [ + {"$match": match_condition}, + {"$sort": {"strength": -1}}, + {"$limit": limit}, + { + "$addFields": { + "related_memory_id": { + "$cond": { + "if": {"$eq": ["$source_memory_id", memory_id]}, + "then": "$target_memory_id", + "else": "$source_memory_id", + } + } + } + }, + ] + + return {"__mongodb_pipeline": pipeline}, {} + + def build_path_finding_query( + self, + source_memory_id: str, + target_memory_id: str, + max_depth: int = 3, + namespace: str = "default", + ) -> tuple[Dict[str, Any], Dict[str, Any]]: + """Find paths between two memories using $graphLookup""" + + pipeline = [ + {"$match": {"source_memory_id": source_memory_id, "namespace": namespace}}, + { + "$graphLookup": { + "from": "memory_relationships", + "startWith": "$target_memory_id", + "connectFromField": "target_memory_id", + "connectToField": "source_memory_id", + "as": "path", + "maxDepth": max_depth - 1, + "depthField": "depth", + } + }, + { + "$match": { + "path.target_memory_id": target_memory_id, + } + }, + {"$limit": 5}, + { + "$project": { + "path": 1, + "depth": {"$size": "$path"}, + "total_strength": {"$multiply": ["$strength", "$path.strength"]}, + } + }, + 
{"$sort": {"depth": 1, "total_strength": -1}}, + ] + + return {"__mongodb_pipeline": pipeline}, {} + + def build_shared_entities_query( + self, + memory_id: str, + namespace: str = "default", + min_overlap: int = 1, + limit: int = 50, + ) -> tuple[Dict[str, Any], Dict[str, Any]]: + """Find memories sharing entities with given memory""" + + pipeline = [ + # Step 1: Get source memory entities + { + "$match": { + "memory_id": memory_id, + "namespace": namespace, + } + }, + { + "$group": { + "_id": None, + "source_entities": {"$addToSet": "$normalized_value"}, + } + }, + # Step 2: Find memories with matching entities + { + "$lookup": { + "from": "memory_entities", + "let": {"source_ents": "$source_entities"}, + "pipeline": [ + { + "$match": { + "$expr": { + "$and": [ + {"$ne": ["$memory_id", memory_id]}, + {"$eq": ["$namespace", namespace]}, + {"$in": ["$normalized_value", "$$source_ents"]}, + ] + } + } + }, + { + "$group": { + "_id": { + "memory_id": "$memory_id", + "memory_type": "$memory_type", + }, + "shared_count": {"$sum": 1}, + "avg_relevance": {"$avg": "$relevance_score"}, + "shared_entities": {"$addToSet": "$entity_value"}, + } + }, + {"$match": {"shared_count": {"$gte": min_overlap}}}, + {"$sort": {"shared_count": -1, "avg_relevance": -1}}, + {"$limit": limit}, + ], + "as": "matches", + } + }, + {"$unwind": "$matches"}, + {"$replaceRoot": {"newRoot": "$matches"}}, + { + "$project": { + "memory_id": "$_id.memory_id", + "memory_type": "$_id.memory_type", + "shared_count": 1, + "avg_relevance": 1, + "shared_entities": 1, + "_id": 0, + } + }, + ] + + return {"__mongodb_pipeline": pipeline}, {} + + def get_parameter_placeholder(self, param_name: str) -> str: + """MongoDB doesn't use SQL-style placeholders""" + return f"${param_name}" + + def supports_recursive_cte(self) -> bool: + """MongoDB doesn't use CTEs, but has $graphLookup""" + return False diff --git a/memori/database/graph_queries/mysql.py b/memori/database/graph_queries/mysql.py new file mode 100644 index 00000000..9203fffa --- /dev/null +++ b/memori/database/graph_queries/mysql.py @@ -0,0 +1,272 @@ +""" +MySQL Graph Query Builder +Uses temporary tables and iterative queries for graph traversal +""" + +from typing import Any, Dict, List, Optional + +from .base import GraphQueryBuilder + + +class MySQLGraphQueryBuilder(GraphQueryBuilder): + """MySQL-specific graph query builder""" + + def supports_recursive_cte(self) -> bool: # type: ignore[override] + return False + + def build_entity_search_query( + self, + entities: List[str], + entity_types: Optional[List[str]] = None, + namespace: str = "default", + min_relevance: float = 0.0, + limit: int = 50, + ) -> tuple[str, Dict[str, Any]]: + """Find memories by entity values""" + + normalized_entities = [e.lower() for e in entities] + + # Build IN clause for entities + entity_placeholders = ",".join([f"%s" for _ in normalized_entities]) + + query = f""" + SELECT + me.memory_id, + me.memory_type, + COUNT(*) as entity_match_count, + AVG(me.relevance_score) as avg_relevance, + GROUP_CONCAT(DISTINCT me.entity_value) as matched_entities + FROM memory_entities me + WHERE me.namespace = %s + AND me.normalized_value IN ({entity_placeholders}) + AND me.relevance_score >= %s + """ + + params_list = [namespace] + normalized_entities + [min_relevance] + + if entity_types: + type_placeholders = ",".join(["%s" for _ in entity_types]) + query += f" AND me.entity_type IN ({type_placeholders})" + params_list.extend(entity_types) + + query += """ + GROUP BY me.memory_id, me.memory_type + ORDER BY 
entity_match_count DESC, avg_relevance DESC + LIMIT %s + """ + + params_list.append(limit) + + # Convert to dict for consistency + params = {"__mysql_params": params_list} + + return query, params + + def build_graph_expansion_query( + self, + seed_memory_ids: List[str], + hop_distance: int, + min_strength: float, + relationship_types: Optional[List[str]] = None, + namespace: str = "default", + limit_per_hop: int = 10, + ) -> tuple[str, Dict[str, Any]]: + """ + Note: MySQL recursive CTE support is limited + This returns a query for 1-hop expansion + Multi-hop should be handled iteratively by the service layer + """ + + seed_placeholders = ",".join(["%s" for _ in seed_memory_ids]) + + query = f""" + SELECT DISTINCT + CASE + WHEN mr.source_memory_id IN ({seed_placeholders}) THEN mr.target_memory_id + ELSE mr.source_memory_id + END as memory_id, + CASE + WHEN mr.source_memory_id IN ({seed_placeholders}) THEN mr.source_memory_id + ELSE mr.target_memory_id + END as via_memory_id, + mr.strength as edge_strength, + mr.relationship_type + FROM memory_relationships mr + WHERE (mr.source_memory_id IN ({seed_placeholders}) + OR mr.target_memory_id IN ({seed_placeholders})) + AND mr.namespace = %s + AND mr.strength >= %s + """ + + params_list = ( + seed_memory_ids + + seed_memory_ids + + seed_memory_ids + + seed_memory_ids + + [namespace, min_strength] + ) + + if relationship_types: + type_placeholders = ",".join(["%s" for _ in relationship_types]) + query += f" AND mr.relationship_type IN ({type_placeholders})" + params_list.extend(relationship_types) + + query += """ + ORDER BY edge_strength DESC + LIMIT %s + """ + + params_list.append(limit_per_hop) + + params = {"__mysql_params": params_list} + + return query, params + + def build_entity_cluster_query( + self, + entities: List[str], + namespace: str = "default", + min_shared_entities: int = 2, + limit: int = 50, + ) -> tuple[str, Dict[str, Any]]: + """Find memories sharing multiple entities""" + + normalized_entities = [e.lower() for e in entities] + entity_placeholders = ",".join(["%s" for _ in normalized_entities]) + + query = f""" + SELECT + me.memory_id, + me.memory_type, + COUNT(DISTINCT me.normalized_value) as shared_entity_count, + AVG(me.relevance_score) as avg_relevance, + GROUP_CONCAT(DISTINCT me.entity_value) as shared_entities + FROM memory_entities me + WHERE me.namespace = %s + AND me.normalized_value IN ({entity_placeholders}) + GROUP BY me.memory_id, me.memory_type + HAVING COUNT(DISTINCT me.normalized_value) >= %s + ORDER BY shared_entity_count DESC, avg_relevance DESC + LIMIT %s + """ + + params_list = [namespace] + normalized_entities + [min_shared_entities, limit] + params = {"__mysql_params": params_list} + + return query, params + + def build_relationship_discovery_query( + self, + memory_id: str, + relationship_types: Optional[List[str]] = None, + min_strength: float = 0.5, + namespace: str = "default", + limit: int = 20, + ) -> tuple[str, Dict[str, Any]]: + """Find all direct relationships for a memory""" + + query = """ + SELECT + mr.relationship_id, + mr.source_memory_id, + mr.target_memory_id, + mr.relationship_type, + mr.strength, + mr.reasoning, + mr.shared_entity_count, + CASE + WHEN mr.source_memory_id = %s THEN mr.target_memory_id + ELSE mr.source_memory_id + END as related_memory_id + FROM memory_relationships mr + WHERE (mr.source_memory_id = %s OR mr.target_memory_id = %s) + AND mr.namespace = %s + AND mr.strength >= %s + """ + + params_list = [memory_id, memory_id, memory_id, namespace, min_strength] + + if 
relationship_types: + type_placeholders = ",".join(["%s" for _ in relationship_types]) + query += f" AND mr.relationship_type IN ({type_placeholders})" + params_list.extend(relationship_types) + + query += """ + ORDER BY mr.strength DESC + LIMIT %s + """ + + params_list.append(limit) + params = {"__mysql_params": params_list} + + return query, params + + def build_path_finding_query( + self, + source_memory_id: str, + target_memory_id: str, + max_depth: int = 3, + namespace: str = "default", + ) -> tuple[str, Dict[str, Any]]: + """Note: Path finding in MySQL requires iterative queries""" + + # Return simple 1-hop check + query = """ + SELECT + source_memory_id, + target_memory_id, + relationship_type, + strength, + 1 as depth + FROM memory_relationships + WHERE ((source_memory_id = %s AND target_memory_id = %s) + OR (source_memory_id = %s AND target_memory_id = %s)) + AND namespace = %s + LIMIT 1 + """ + + params_list = [source_memory_id, target_memory_id, target_memory_id, source_memory_id, namespace] + params = {"__mysql_params": params_list} + + return query, params + + def build_shared_entities_query( + self, + memory_id: str, + namespace: str = "default", + min_overlap: int = 1, + limit: int = 50, + ) -> tuple[str, Dict[str, Any]]: + """Find memories sharing entities with given memory""" + + query = """ + SELECT + me2.memory_id, + me2.memory_type, + COUNT(DISTINCT me2.normalized_value) as shared_count, + AVG(me2.relevance_score) as avg_relevance, + GROUP_CONCAT(DISTINCT me2.entity_value) as shared_entities + FROM memory_entities me1 + JOIN memory_entities me2 ON me1.normalized_value = me2.normalized_value + WHERE me1.memory_id = %s + AND me2.memory_id != %s + AND me1.namespace = %s + AND me2.namespace = %s + GROUP BY me2.memory_id, me2.memory_type + HAVING COUNT(DISTINCT me2.normalized_value) >= %s + ORDER BY shared_count DESC, avg_relevance DESC + LIMIT %s + """ + + params_list = [memory_id, memory_id, namespace, namespace, min_overlap, limit] + params = {"__mysql_params": params_list} + + return query, params + + def get_parameter_placeholder(self, param_name: str) -> str: + """MySQL uses %s placeholders""" + return "%s" + + def supports_recursive_cte(self) -> bool: + """MySQL 8.0+ supports recursive CTEs but with limitations""" + return False # Return False to use iterative approach diff --git a/memori/database/graph_queries/postgresql.py b/memori/database/graph_queries/postgresql.py new file mode 100644 index 00000000..86f6370e --- /dev/null +++ b/memori/database/graph_queries/postgresql.py @@ -0,0 +1,351 @@ +""" +PostgreSQL Graph Query Builder +Uses recursive CTEs for efficient graph traversal +""" + +from typing import Any, Dict, List, Optional + +from .base import GraphQueryBuilder + + +class PostgreSQLGraphQueryBuilder(GraphQueryBuilder): + """PostgreSQL-specific graph query builder with recursive CTEs""" + + def build_entity_search_query( + self, + entities: List[str], + entity_types: Optional[List[str]] = None, + namespace: str = "default", + min_relevance: float = 0.0, + limit: int = 50, + ) -> tuple[str, Dict[str, Any]]: + """Find memories by entity values""" + + # Normalize entities for case-insensitive search + normalized_entities = [e.lower() for e in entities] + + query = """ + SELECT DISTINCT + me.memory_id, + me.memory_type, + COUNT(*) as entity_match_count, + AVG(me.relevance_score) as avg_relevance, + array_agg(DISTINCT me.entity_value) as matched_entities + FROM memory_entities me + WHERE me.namespace = :namespace + AND me.normalized_value = ANY(:entities) + AND 
me.relevance_score >= :min_relevance + """ + + if entity_types: + query += " AND me.entity_type = ANY(:entity_types)" + + query += """ + GROUP BY me.memory_id, me.memory_type + ORDER BY entity_match_count DESC, avg_relevance DESC + LIMIT :limit + """ + + params = { + "namespace": namespace, + "entities": normalized_entities, + "min_relevance": min_relevance, + "limit": limit, + } + + if entity_types: + params["entity_types"] = entity_types + + return query, params + + def build_graph_expansion_query( + self, + seed_memory_ids: List[str], + hop_distance: int, + min_strength: float, + relationship_types: Optional[List[str]] = None, + namespace: str = "default", + limit_per_hop: int = 10, + ) -> tuple[str, Dict[str, Any]]: + """Expand via graph relationships using recursive CTE""" + + if not seed_memory_ids: + return "SELECT NULL WHERE FALSE", {} + + seed_placeholders = ", ".join( + [f":seed_{i}" for i in range(len(seed_memory_ids))] + ) + + query = f""" + WITH RECURSIVE graph_walk AS ( + -- Base case: seed memories at hop 0 + SELECT + seed_id as memory_id, + 0 as hop, + 1.0::double precision as cumulative_strength, + ARRAY[seed_id]::text[] as path, + NULL::text as relationship_type + FROM unnest(ARRAY[{seed_placeholders}]) as seed_id + + UNION ALL + + -- Recursive case: follow relationships + SELECT + CASE + WHEN mr.source_memory_id = gw.memory_id THEN mr.target_memory_id + ELSE mr.source_memory_id + END as memory_id, + gw.hop + 1 as hop, + gw.cumulative_strength * mr.strength as cumulative_strength, + gw.path || CASE + WHEN mr.source_memory_id = gw.memory_id THEN mr.target_memory_id + ELSE mr.source_memory_id + END as path, + mr.relationship_type + FROM graph_walk gw + JOIN memory_relationships mr ON ( + (mr.source_memory_id = gw.memory_id OR mr.target_memory_id = gw.memory_id) + AND mr.namespace = :namespace + AND mr.strength >= :min_strength + """ + + if relationship_types: + query += " AND mr.relationship_type = ANY(:rel_types)" + + query += """ + ) + WHERE gw.hop < :max_hops + AND NOT ( + CASE + WHEN mr.source_memory_id = gw.memory_id THEN mr.target_memory_id + ELSE mr.source_memory_id + END = ANY(gw.path) + ) + ) + SELECT DISTINCT + memory_id, + hop, + MAX(cumulative_strength) as max_strength, + array_agg(DISTINCT relationship_type) FILTER (WHERE relationship_type IS NOT NULL) as relationship_types, + array_agg(DISTINCT path) as paths + FROM graph_walk + WHERE hop > 0 + GROUP BY memory_id, hop + ORDER BY hop ASC, max_strength DESC + LIMIT :limit + """ + + params = { + "namespace": namespace, + "min_strength": min_strength, + "max_hops": hop_distance, + "limit": limit_per_hop * hop_distance, + } + + for i, seed in enumerate(seed_memory_ids): + params[f"seed_{i}"] = seed + + if relationship_types: + params["rel_types"] = relationship_types + + return query, params + + def build_entity_cluster_query( + self, + entities: List[str], + namespace: str = "default", + min_shared_entities: int = 2, + limit: int = 50, + ) -> tuple[str, Dict[str, Any]]: + """Find memories sharing multiple entities""" + + normalized_entities = [e.lower() for e in entities] + + query = """ + SELECT + me.memory_id, + me.memory_type, + COUNT(DISTINCT me.normalized_value) as shared_entity_count, + AVG(me.relevance_score) as avg_relevance, + array_agg(DISTINCT me.entity_value) as shared_entities + FROM memory_entities me + WHERE me.namespace = :namespace + AND me.normalized_value = ANY(:entities) + GROUP BY me.memory_id, me.memory_type + HAVING COUNT(DISTINCT me.normalized_value) >= :min_shared + ORDER BY 
shared_entity_count DESC, avg_relevance DESC + LIMIT :limit + """ + + params = { + "namespace": namespace, + "entities": normalized_entities, + "min_shared": min_shared_entities, + "limit": limit, + } + + return query, params + + def build_relationship_discovery_query( + self, + memory_id: str, + relationship_types: Optional[List[str]] = None, + min_strength: float = 0.5, + namespace: str = "default", + limit: int = 20, + ) -> tuple[str, Dict[str, Any]]: + """Find all direct relationships for a memory""" + + query = """ + SELECT + mr.relationship_id, + mr.source_memory_id, + mr.target_memory_id, + mr.relationship_type, + mr.strength, + mr.reasoning, + mr.shared_entity_count, + CASE + WHEN mr.source_memory_id = :memory_id THEN mr.target_memory_id + ELSE mr.source_memory_id + END as related_memory_id + FROM memory_relationships mr + WHERE (mr.source_memory_id = :memory_id OR mr.target_memory_id = :memory_id) + AND mr.namespace = :namespace + AND mr.strength >= :min_strength + """ + + if relationship_types: + query += " AND mr.relationship_type = ANY(:rel_types)" + + query += """ + ORDER BY mr.strength DESC + LIMIT :limit + """ + + params = { + "memory_id": memory_id, + "namespace": namespace, + "min_strength": min_strength, + "limit": limit, + } + + if relationship_types: + params["rel_types"] = relationship_types + + return query, params + + def build_path_finding_query( + self, + source_memory_id: str, + target_memory_id: str, + max_depth: int = 3, + namespace: str = "default", + ) -> tuple[str, Dict[str, Any]]: + """Find paths between two memories using recursive CTE""" + + query = """ + WITH RECURSIVE path_search AS ( + -- Base case: start from source + SELECT + :source as current_id, + :target as target_id, + ARRAY[:source]::text[] as path, + ARRAY[]::text[] as relationship_path, + 0 as depth, + 1.0::double precision as total_strength + + UNION ALL + + -- Recursive case: follow relationships + SELECT + CASE + WHEN mr.source_memory_id = ps.current_id THEN mr.target_memory_id + ELSE mr.source_memory_id + END as current_id, + ps.target_id, + ps.path || CASE + WHEN mr.source_memory_id = ps.current_id THEN mr.target_memory_id + ELSE mr.source_memory_id + END as path, + ps.relationship_path || mr.relationship_type as relationship_path, + ps.depth + 1 as depth, + ps.total_strength * mr.strength as total_strength + FROM path_search ps + JOIN memory_relationships mr ON ( + (mr.source_memory_id = ps.current_id OR mr.target_memory_id = ps.current_id) + AND mr.namespace = :namespace + ) + WHERE ps.depth < :max_depth + AND ps.current_id != ps.target_id + AND NOT ( + CASE + WHEN mr.source_memory_id = ps.current_id THEN mr.target_memory_id + ELSE mr.source_memory_id + END = ANY(ps.path) + ) + ) + SELECT + path, + relationship_path, + depth, + total_strength + FROM path_search + WHERE current_id = target_id + ORDER BY depth ASC, total_strength DESC + LIMIT 5 + """ + + params = { + "source": source_memory_id, + "target": target_memory_id, + "namespace": namespace, + "max_depth": max_depth, + } + + return query, params + + def build_shared_entities_query( + self, + memory_id: str, + namespace: str = "default", + min_overlap: int = 1, + limit: int = 50, + ) -> tuple[str, Dict[str, Any]]: + """Find memories sharing entities with given memory""" + + query = """ + WITH source_entities AS ( + SELECT normalized_value + FROM memory_entities + WHERE memory_id = :memory_id + AND namespace = :namespace + ) + SELECT + me.memory_id, + me.memory_type, + COUNT(DISTINCT me.normalized_value) as shared_count, + 
AVG(me.relevance_score) as avg_relevance, + array_agg(DISTINCT me.entity_value) as shared_entities + FROM memory_entities me + JOIN source_entities se ON me.normalized_value = se.normalized_value + WHERE me.memory_id != :memory_id + AND me.namespace = :namespace + GROUP BY me.memory_id, me.memory_type + HAVING COUNT(DISTINCT me.normalized_value) >= :min_overlap + ORDER BY shared_count DESC, avg_relevance DESC + LIMIT :limit + """ + + params = { + "memory_id": memory_id, + "namespace": namespace, + "min_overlap": min_overlap, + "limit": limit, + } + + return query, params + + def get_parameter_placeholder(self, param_name: str) -> str: + """PostgreSQL uses :param_name style placeholders""" + return f":{param_name}" diff --git a/memori/database/graph_queries/sqlite.py b/memori/database/graph_queries/sqlite.py new file mode 100644 index 00000000..1eb0abc5 --- /dev/null +++ b/memori/database/graph_queries/sqlite.py @@ -0,0 +1,262 @@ +""" +SQLite Graph Query Builder +Uses simple CTEs and iterative queries +""" + +from typing import Any, Dict, List, Optional + +from .base import GraphQueryBuilder + + +class SQLiteGraphQueryBuilder(GraphQueryBuilder): + """SQLite-specific graph query builder""" + + def supports_recursive_cte(self) -> bool: # type: ignore[override] + return False + + def build_entity_search_query( + self, + entities: List[str], + entity_types: Optional[List[str]] = None, + namespace: str = "default", + min_relevance: float = 0.0, + limit: int = 50, + ) -> tuple[str, Dict[str, Any]]: + """Find memories by entity values""" + + normalized_entities = [e.lower() for e in entities] + entity_placeholders = ",".join(["?" for _ in normalized_entities]) + + query = f""" + SELECT + me.memory_id, + me.memory_type, + COUNT(*) as entity_match_count, + AVG(me.relevance_score) as avg_relevance, + GROUP_CONCAT(DISTINCT me.entity_value) as matched_entities + FROM memory_entities me + WHERE me.namespace = ? + AND me.normalized_value IN ({entity_placeholders}) + AND me.relevance_score >= ? + """ + + params_list = [namespace] + normalized_entities + [min_relevance] + + if entity_types: + type_placeholders = ",".join(["?" for _ in entity_types]) + query += f" AND me.entity_type IN ({type_placeholders})" + params_list.extend(entity_types) + + query += """ + GROUP BY me.memory_id, me.memory_type + ORDER BY entity_match_count DESC, avg_relevance DESC + LIMIT ? + """ + + params_list.append(limit) + params = {"__sqlite_params": params_list} + + return query, params + + def build_graph_expansion_query( + self, + seed_memory_ids: List[str], + hop_distance: int, + min_strength: float, + relationship_types: Optional[List[str]] = None, + namespace: str = "default", + limit_per_hop: int = 10, + ) -> tuple[str, Dict[str, Any]]: + """1-hop expansion for SQLite (multi-hop handled iteratively)""" + + seed_placeholders = ",".join(["?" for _ in seed_memory_ids]) + + query = f""" + SELECT DISTINCT + CASE + WHEN mr.source_memory_id IN ({seed_placeholders}) THEN mr.target_memory_id + ELSE mr.source_memory_id + END as memory_id, + CASE + WHEN mr.source_memory_id IN ({seed_placeholders}) THEN mr.source_memory_id + ELSE mr.target_memory_id + END as via_memory_id, + mr.strength as edge_strength, + mr.relationship_type + FROM memory_relationships mr + WHERE (mr.source_memory_id IN ({seed_placeholders}) + OR mr.target_memory_id IN ({seed_placeholders})) + AND mr.namespace = ? + AND mr.strength >= ? 
+ """ + + params_list = ( + seed_memory_ids + + seed_memory_ids + + seed_memory_ids + + seed_memory_ids + + [namespace, min_strength] + ) + + if relationship_types: + type_placeholders = ",".join(["?" for _ in relationship_types]) + query += f" AND mr.relationship_type IN ({type_placeholders})" + params_list.extend(relationship_types) + + query += """ + ORDER BY edge_strength DESC + LIMIT ? + """ + + params_list.append(limit_per_hop) + params = {"__sqlite_params": params_list} + + return query, params + + def build_entity_cluster_query( + self, + entities: List[str], + namespace: str = "default", + min_shared_entities: int = 2, + limit: int = 50, + ) -> tuple[str, Dict[str, Any]]: + """Find memories sharing multiple entities""" + + normalized_entities = [e.lower() for e in entities] + entity_placeholders = ",".join(["?" for _ in normalized_entities]) + + query = f""" + SELECT + me.memory_id, + me.memory_type, + COUNT(DISTINCT me.normalized_value) as shared_entity_count, + AVG(me.relevance_score) as avg_relevance, + GROUP_CONCAT(DISTINCT me.entity_value) as shared_entities + FROM memory_entities me + WHERE me.namespace = ? + AND me.normalized_value IN ({entity_placeholders}) + GROUP BY me.memory_id, me.memory_type + HAVING COUNT(DISTINCT me.normalized_value) >= ? + ORDER BY shared_entity_count DESC, avg_relevance DESC + LIMIT ? + """ + + params_list = [namespace] + normalized_entities + [min_shared_entities, limit] + params = {"__sqlite_params": params_list} + + return query, params + + def build_relationship_discovery_query( + self, + memory_id: str, + relationship_types: Optional[List[str]] = None, + min_strength: float = 0.5, + namespace: str = "default", + limit: int = 20, + ) -> tuple[str, Dict[str, Any]]: + """Find all direct relationships for a memory""" + + query = """ + SELECT + mr.relationship_id, + mr.source_memory_id, + mr.target_memory_id, + mr.relationship_type, + mr.strength, + mr.reasoning, + mr.shared_entity_count, + CASE + WHEN mr.source_memory_id = ? THEN mr.target_memory_id + ELSE mr.source_memory_id + END as related_memory_id + FROM memory_relationships mr + WHERE (mr.source_memory_id = ? OR mr.target_memory_id = ?) + AND mr.namespace = ? + AND mr.strength >= ? + """ + + params_list = [memory_id, memory_id, memory_id, namespace, min_strength] + + if relationship_types: + type_placeholders = ",".join(["?" for _ in relationship_types]) + query += f" AND mr.relationship_type IN ({type_placeholders})" + params_list.extend(relationship_types) + + query += """ + ORDER BY mr.strength DESC + LIMIT ? + """ + + params_list.append(limit) + params = {"__sqlite_params": params_list} + + return query, params + + def build_path_finding_query( + self, + source_memory_id: str, + target_memory_id: str, + max_depth: int = 3, + namespace: str = "default", + ) -> tuple[str, Dict[str, Any]]: + """Simple 1-hop path check for SQLite""" + + query = """ + SELECT + source_memory_id, + target_memory_id, + relationship_type, + strength, + 1 as depth + FROM memory_relationships + WHERE ((source_memory_id = ? AND target_memory_id = ?) + OR (source_memory_id = ? AND target_memory_id = ?)) + AND namespace = ? 
+ LIMIT 1 + """ + + params_list = [source_memory_id, target_memory_id, target_memory_id, source_memory_id, namespace] + params = {"__sqlite_params": params_list} + + return query, params + + def build_shared_entities_query( + self, + memory_id: str, + namespace: str = "default", + min_overlap: int = 1, + limit: int = 50, + ) -> tuple[str, Dict[str, Any]]: + """Find memories sharing entities with given memory""" + + query = """ + SELECT + me2.memory_id, + me2.memory_type, + COUNT(DISTINCT me2.normalized_value) as shared_count, + AVG(me2.relevance_score) as avg_relevance, + GROUP_CONCAT(DISTINCT me2.entity_value, ', ') as shared_entities + FROM memory_entities me1 + JOIN memory_entities me2 ON me1.normalized_value = me2.normalized_value + WHERE me1.memory_id = ? + AND me2.memory_id != ? + AND me1.namespace = ? + AND me2.namespace = ? + GROUP BY me2.memory_id, me2.memory_type + HAVING COUNT(DISTINCT me2.normalized_value) >= ? + ORDER BY shared_count DESC, avg_relevance DESC + LIMIT ? + """ + + params_list = [memory_id, memory_id, namespace, namespace, min_overlap, limit] + params = {"__sqlite_params": params_list} + + return query, params + + def get_parameter_placeholder(self, param_name: str) -> str: + """SQLite uses ? placeholders""" + return "?" + + def supports_recursive_cte(self) -> bool: + """SQLite supports recursive CTEs but we use simple iterative approach""" + return False diff --git a/memori/database/graph_search_service.py b/memori/database/graph_search_service.py new file mode 100644 index 00000000..49df2a20 --- /dev/null +++ b/memori/database/graph_search_service.py @@ -0,0 +1,965 @@ +""" +Graph-Based Memory Search Service +Implements 7 search strategies with graph expansion and composite scoring +""" + +from datetime import datetime +from typing import Any, Dict, List, Optional + +from loguru import logger + +from memori.database.graph_queries import get_query_builder +from memori.utils.pydantic_models import ( + ExpansionStrategy, + GraphExpansionConfig, + GraphSearchResult, + GraphTraversalPath, + RelationshipType, + SearchStrategy, + ScoringWeights, +) + + +class GraphSearchService: + """ + Core service for graph-based memory search + Supports 7 search strategies with configurable scoring + """ + + def __init__(self, database_manager): + """ + Initialize graph search service + + Args: + database_manager: Database manager with engine/session + """ + self.db_manager = database_manager + self.dialect = database_manager.engine.dialect.name + self.query_builder = get_query_builder(self.dialect) + + # Performance statistics + self.stats = { + "total_searches": 0, + "strategy_usage": {}, + "avg_response_time_ms": 0, + } + + logger.info(f"GraphSearchService initialized for {self.dialect}") + + def search( + self, + query_text: str, + strategy: SearchStrategy, + namespace: str = "default", + entities: Optional[List[str]] = None, + categories: Optional[List[str]] = None, + graph_expansion: Optional[GraphExpansionConfig] = None, + scoring_weights: Optional[ScoringWeights] = None, + max_results: int = 10, + ) -> List[GraphSearchResult]: + """ + Main search entry point with strategy routing + + Args: + query_text: Search query string + strategy: Search strategy to use + namespace: Memory namespace + entities: Optional entity filters + categories: Optional category filters + graph_expansion: Graph expansion configuration + scoring_weights: Scoring weights + max_results: Maximum results to return + + Returns: + List of GraphSearchResult objects + """ + start_time = datetime.now() + 
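+        # Record per-strategy usage so get_stats() can report how often each
+        # strategy is exercised and its moving-average latency.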
self.stats["total_searches"] += 1 + self.stats["strategy_usage"][strategy] = ( + self.stats["strategy_usage"].get(strategy, 0) + 1 + ) + + logger.debug( + f"Graph search: strategy={strategy}, namespace={namespace}, " + f"entities={entities}, max_results={max_results}" + ) + + try: + # Route to appropriate strategy + if strategy == SearchStrategy.TEXT_ONLY: + results = self._text_only_search( + query_text, namespace, categories, max_results + ) + + elif strategy == SearchStrategy.ENTITY_FIRST: + results = self.entity_first_search( + entities or [], + namespace, + categories, + max_results, + ) + + elif strategy == SearchStrategy.GRAPH_EXPANSION_1HOP: + results = self.search_with_expansion( + query_text, + entities or [], + categories or [], + namespace, + expand_hops=1, + min_strength=graph_expansion.min_relationship_strength + if graph_expansion + else 0.2, + limit=max_results, + ) + + elif strategy == SearchStrategy.GRAPH_EXPANSION_2HOP: + results = self.search_with_expansion( + query_text, + entities or [], + categories or [], + namespace, + expand_hops=2, + min_strength=graph_expansion.min_relationship_strength + if graph_expansion + else 0.2, + limit=max_results, + ) + + elif strategy == SearchStrategy.GRAPH_WALK_CONTEXTUAL: + results = self.graph_walk( + entities or [], + namespace, + max_depth=3, + min_strength=graph_expansion.min_relationship_strength + if graph_expansion + else 0.2, + limit=max_results, + ) + + elif strategy == SearchStrategy.ENTITY_CLUSTER_DISCOVERY: + results = self.entity_cluster_discovery( + entities or [], + namespace, + min_shared=2, + limit=max_results, + ) + + elif strategy == SearchStrategy.CATEGORY_FOCUSED_GRAPH: + results = self.category_focused_graph_search( + query_text, + categories or [], + namespace, + expand_hops=1, + limit=max_results, + ) + + else: + logger.warning(f"Unknown strategy: {strategy}, falling back to TEXT_ONLY") + results = self._text_only_search( + query_text, namespace, categories, max_results + ) + + # Apply composite scoring + if scoring_weights: + results = self._apply_composite_scoring(results, scoring_weights) + + # Sort by composite score + results = sorted(results, key=lambda x: x.composite_score, reverse=True)[ + :max_results + ] + + # Calculate response time + elapsed_ms = (datetime.now() - start_time).total_seconds() * 1000 + self.stats["avg_response_time_ms"] = ( + self.stats["avg_response_time_ms"] * 0.9 + elapsed_ms * 0.1 + ) # Moving average + + logger.info( + f"Search completed: strategy={strategy}, results={len(results)}, " + f"time={elapsed_ms:.1f}ms" + ) + + return results + + except Exception as e: + logger.error(f"Graph search failed: {e}") + raise + + def search_with_expansion( + self, + query_text: str, + entities: List[str], + categories: List[str], + namespace: str, + expand_hops: int = 1, + min_strength: float = 0.5, + limit: int = 10, + ) -> List[GraphSearchResult]: + """ + Search with graph expansion from seed memories + + Steps: + 1. Find seed memories (text + entity search) + 2. Expand via graph relationships + 3. Combine and deduplicate results + 4. Calculate composite scores + 5. 
Return top results with graph metadata + """ + # Step 1: Find seed memories + seed_memories = self._find_seed_memories( + query_text, entities, categories, namespace, limit=20 + ) + + if not seed_memories: + logger.debug( + "No seed memories found via entities, falling back to text search for seeds" + ) + # Fallback: Use text-based search to find seed memories + seed_memories = self._find_text_based_seeds( + query_text, namespace, categories, limit=10 + ) + + if not seed_memories: + logger.debug( + "No seed memories found via text search either, graph search cannot proceed" + ) + return [] + + seed_memory_ids = [m["memory_id"] for m in seed_memories] + logger.debug(f"Starting graph expansion from {len(seed_memory_ids)} seed memories") + + # Step 2: Expand via graph + expanded_memories = self._expand_via_graph( + seed_memory_ids, + hop_distance=expand_hops, + min_strength=min_strength, + namespace=namespace, + limit_per_hop=limit, + ) + + # Step 3: Combine seed and expanded memories + all_memories = seed_memories + expanded_memories + + # Step 4: Fetch full memory data and enrich with graph metadata + results = self._enrich_with_memory_data(all_memories, namespace) + + # Step 5: Generate match reasons + results = self._generate_match_reasons(results, query_text, entities) + + return results + + def entity_first_search( + self, + entities: List[str], + namespace: str, + categories: Optional[List[str]] = None, + limit: int = 50, + ) -> List[GraphSearchResult]: + """Search by entity tags first, then expand""" + if not entities: + return [] + + # Build and execute entity search query + query, params = self.query_builder.build_entity_search_query( + entities=entities, + entity_types=None, + namespace=namespace, + min_relevance=0.3, + limit=limit, + ) + + with self.db_manager.get_session() as session: + raw_results = self._execute_query(session, query, params) + + # Convert to GraphSearchResult + results = [] + for row in raw_results: + result = GraphSearchResult( + memory_id=row["memory_id"], + content="", # Will be filled by _enrich_with_memory_data + summary="", + entity_overlap_score=min(1.0, row.get("entity_match_count", 0) / len(entities)), + hop_distance=0, + shared_entities=( + row.get("matched_entities", "").split(", ") + if isinstance(row.get("matched_entities"), str) + else row.get("matched_entities", []) + ), + match_reason=f"Matched {row.get('entity_match_count', 0)} entities", + ) + results.append(result) + + # Enrich with full memory data + results = self._enrich_with_memory_data( + [{"memory_id": r.memory_id, "hop": 0} for r in results], + namespace, + ) + + return results + + def entity_cluster_discovery( + self, + entities: List[str], + namespace: str, + min_shared: int = 2, + limit: int = 50, + ) -> List[GraphSearchResult]: + """Find memories that share multiple entities (cluster discovery)""" + if not entities or len(entities) < min_shared: + return [] + + # Build and execute entity cluster query + query, params = self.query_builder.build_entity_cluster_query( + entities=entities, + namespace=namespace, + min_shared_entities=min_shared, + limit=limit, + ) + + with self.db_manager.get_session() as session: + raw_results = self._execute_query(session, query, params) + + # Convert to results + results = [] + for row in raw_results: + shared_entities = ( + row.get("shared_entities", "").split(", ") + if isinstance(row.get("shared_entities"), str) + else row.get("shared_entities", []) + ) + + result = GraphSearchResult( + memory_id=row["memory_id"], + content="", + summary="", + 
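+                # content/summary are left empty here and filled in by
+                # _enrich_with_memory_data below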
entity_overlap_score=min( + 1.0, row.get("shared_entity_count", 0) / len(entities) + ), + hop_distance=0, + shared_entities=shared_entities, + match_reason=f"Shared {row.get('shared_entity_count', 0)} entities: {', '.join(shared_entities[:3])}", + ) + results.append(result) + + # Enrich with full data + results = self._enrich_with_memory_data( + [{"memory_id": r.memory_id, "hop": 0} for r in results], + namespace, + ) + + return results + + def graph_walk( + self, + entities: List[str], + namespace: str, + max_depth: int = 3, + min_strength: float = 0.5, + limit: int = 20, + ) -> List[GraphSearchResult]: + """ + Contextual graph walk - follow relationships from entity-tagged memories + + This is the most powerful strategy for "find everything related to X" queries + """ + # First find memories with these entities + seed_results = self.entity_first_search( + entities, namespace, categories=None, limit=10 + ) + + if not seed_results: + return [] + + seed_ids = [r.memory_id for r in seed_results] + + # Perform multi-hop expansion + expanded = self._expand_via_graph( + seed_ids, + hop_distance=max_depth, + min_strength=min_strength, + namespace=namespace, + limit_per_hop=limit, + ) + + # Combine seeds and expanded + all_results = seed_results + self._enrich_with_memory_data(expanded, namespace) + + return all_results[:limit] + + def category_focused_graph_search( + self, + query_text: str, + categories: List[str], + namespace: str, + expand_hops: int = 1, + limit: int = 50, + ) -> List[GraphSearchResult]: + """Search within specific categories, then expand via graph""" + + # Find seed memories in categories + seeds = self._find_seed_memories( + query_text, [], categories, namespace, limit=20 + ) + + if not seeds: + return [] + + # Expand from seeds + return self.search_with_expansion( + query_text=query_text, + entities=[], + categories=categories, + namespace=namespace, + expand_hops=expand_hops, + min_strength=0.5, + limit=limit, + ) + + # ==================== Helper Methods ==================== + + def _text_only_search( + self, + query_text: str, + namespace: str, + categories: Optional[List[str]], + limit: int, + ) -> List[GraphSearchResult]: + """Fallback to traditional text search (no graph)""" + try: + from .search_service import SearchService + + # Get a session from db_manager + with self.db_manager.SessionLocal() as session: + search_service = SearchService( + session=session, + database_type=self.dialect + ) + + results = search_service.search_memories( + query=query_text, + namespace=namespace, + limit=limit, + category_filter=categories, + ) + + # Convert to GraphSearchResult format + graph_results = [] + for result in results: + if isinstance(result, dict): + composite = result.get("composite_score") + if composite is None: + composite = result.get("search_score", 0.5) + text_score = result.get("search_score", composite) + created_at = result.get("created_at", datetime.now()) + + graph_results.append( + GraphSearchResult( + memory_id=result.get("memory_id", ""), + content=result.get("processed_data", ""), + summary=result.get("summary", ""), + category=result.get("category_primary"), + composite_score=composite, + text_relevance_score=text_score, + entity_overlap_score=0.0, + graph_strength_score=0.0, + importance_score=result.get("importance_score", 0.5), + recency_score=0.5, + hop_distance=0, + graph_paths=[], + shared_entities=[], + timestamp=created_at, + access_count=result.get("access_count") or 0, + last_accessed=result.get("last_accessed"), + ) + ) + + logger.debug( + 
f"TEXT_ONLY search returned {len(graph_results)} results for query: '{query_text[:50]}...'" + ) + return graph_results + + except Exception as e: + logger.error(f"TEXT_ONLY search failed: {e}") + return [] + + def _find_seed_memories( + self, + query_text: str, + entities: List[str], + categories: List[str], + namespace: str, + limit: int, + ) -> List[Dict[str, Any]]: + """Find initial seed memories for graph expansion""" + + seed_map: Dict[str, Dict[str, Any]] = {} + + def upsert_seed(memory_id: str, **metadata: Any) -> None: + existing = seed_map.get(memory_id) + if existing: + # Keep the strongest signals from any strategy + existing["text_score"] = max( + existing.get("text_score", 0.0), metadata.get("text_score", 0.0) + ) + existing["entity_overlap_score"] = max( + existing.get("entity_overlap_score", 0.0), + metadata.get("entity_overlap_score", 0.0), + ) + existing.setdefault("sources", set()).update(metadata.get("sources", set())) + if metadata.get("matched_entities"): + existing["matched_entities"] = metadata["matched_entities"] + if metadata.get("category"): + existing["category"] = metadata["category"] + if metadata.get("created_at"): + existing["created_at"] = metadata["created_at"] + if metadata.get("importance_score") is not None: + existing["importance_score"] = metadata["importance_score"] + else: + seed_map[memory_id] = { + "memory_id": memory_id, + "hop": 0, + "text_score": metadata.get("text_score", 0.0), + "entity_overlap_score": metadata.get("entity_overlap_score", 0.0), + "matched_entities": metadata.get("matched_entities", []), + "sources": metadata.get("sources", set()), + "created_at": metadata.get("created_at"), + "importance_score": metadata.get("importance_score"), + "category": metadata.get("category"), + } + + # Strategy 1: entity-driven seeds + if entities: + query, params = self.query_builder.build_entity_search_query( + entities=entities, + namespace=namespace, + min_relevance=0.3, + limit=max(limit, len(entities) * 5), + ) + + with self.db_manager.get_session() as session: + entity_results = self._execute_query(session, query, params) + + for row in entity_results: + matched_entities = row.get("matched_entities") + if isinstance(matched_entities, str): + matched_entities = [value.strip() for value in matched_entities.split(",") if value.strip()] + entity_score = 0.0 + if entities: + entity_score = min( + 1.0, + row.get("entity_match_count", 0) / max(1, len(entities)), + ) + + upsert_seed( + row["memory_id"], + entity_overlap_score=entity_score, + matched_entities=matched_entities or [], + sources={"entity"}, + ) + + # Strategy 2: text/category seeds via FTS/LIKE (SearchService) + if query_text: + text_seeds = self._find_text_based_seeds( + query_text=query_text, + namespace=namespace, + categories=categories, + limit=limit, + ) + for seed in text_seeds: + upsert_seed( + seed["memory_id"], + text_score=seed.get("text_score", 0.0), + importance_score=seed.get("importance_score"), + created_at=seed.get("created_at"), + category=seed.get("category"), + sources={"text"}, + ) + + # Convert to list sorted by combined strength + seeds: List[Dict[str, Any]] = [] + for memory_id, data in seed_map.items(): + sources = data.pop("sources", set()) + combined_score = max(data.get("text_score", 0.0), data.get("entity_overlap_score", 0.0)) + seeds.append( + { + "memory_id": memory_id, + "hop": 0, + "cumulative_strength": max(0.3, combined_score), + "text_score": data.get("text_score", 0.0), + "entity_overlap_score": data.get("entity_overlap_score", 0.0), + 
"matched_entities": data.get("matched_entities", []), + "created_at": data.get("created_at"), + "importance_score": data.get("importance_score"), + "category": data.get("category"), + "sources": list(sources), + } + ) + + seeds.sort(key=lambda s: (s.get("text_score", 0.0) + s.get("entity_overlap_score", 0.0)), reverse=True) + return seeds[:limit] + + def _find_text_based_seeds( + self, + query_text: str, + namespace: str, + categories: Optional[List[str]], + limit: int, + ) -> List[Dict[str, Any]]: + """ + Find seed memories using text-based search when entity search fails + + This is a fallback for when: + - No entities were extracted from the query + - Entity search found no matching memories + """ + if not query_text: + return [] + + try: + # Use the search service for text-based search + from .search_service import SearchService + + # Get a session from db_manager + with self.db_manager.SessionLocal() as session: + search_service = SearchService( + session=session, + database_type=self.dialect + ) + + results = search_service.search_memories( + query=query_text, + namespace=namespace, + limit=limit, + category_filter=categories, + ) + + # Convert to seed format + seeds = [] + for result in results: + if isinstance(result, dict) and "memory_id" in result: + seeds.append( + { + "memory_id": result["memory_id"], + "hop": 0, + "source": "text", + "text_score": result.get("composite_score") + or result.get("search_score", 0.0), + "importance_score": result.get("importance_score"), + "created_at": result.get("created_at"), + "category": result.get("category_primary"), + } + ) + + logger.debug( + f"Text-based seed search found {len(seeds)} memories for query: '{query_text[:50]}...'" + ) + return seeds + + except Exception as e: + logger.warning(f"Text-based seed search failed: {e}") + return [] + + def _expand_via_graph( + self, + seed_memory_ids: List[str], + hop_distance: int, + min_strength: float, + namespace: str, + limit_per_hop: int, + ) -> List[Dict[str, Any]]: + """Expand from seed memories via graph relationships""" + + if hop_distance == 0 or not seed_memory_ids: + return [] + + supports_recursive = getattr(self.query_builder, "supports_recursive_cte", lambda: False)() + + if supports_recursive: + query, params = self.query_builder.build_graph_expansion_query( + seed_memory_ids=seed_memory_ids, + hop_distance=hop_distance, + min_strength=min_strength, + relationship_types=None, + namespace=namespace, + limit_per_hop=limit_per_hop, + ) + + with self.db_manager.get_session() as session: + rows = self._execute_query(session, query, params) + + normalized: List[Dict[str, Any]] = [] + for row in rows: + normalized.append( + { + "memory_id": row.get("memory_id"), + "hop": row.get("hop", 1), + "cumulative_strength": row.get("max_strength") + or row.get("cumulative_strength") + or row.get("edge_strength", 0.0), + "relationship_types": row.get("relationship_types", []), + "paths": row.get("paths"), + } + ) + + logger.debug(f"Graph expansion (recursive) found {len(normalized)} memories") + return normalized + + # Iterative expansion for databases without recursive CTE support (SQLite/MySQL) + visited = set(seed_memory_ids) + frontier = list(seed_memory_ids) + expansions: List[Dict[str, Any]] = [] + + for hop in range(1, hop_distance + 1): + if not frontier: + break + + query, params = self.query_builder.build_graph_expansion_query( + seed_memory_ids=frontier, + hop_distance=1, + min_strength=min_strength, + relationship_types=None, + namespace=namespace, + limit_per_hop=limit_per_hop, + ) + 
+ with self.db_manager.get_session() as session: + rows = self._execute_query(session, query, params) + + next_frontier: List[str] = [] + hop_results: Dict[str, Dict[str, Any]] = {} + + for row in rows: + candidate_id = row.get("memory_id") + if not candidate_id or candidate_id in visited: + continue + + edge_strength = row.get("edge_strength") or row.get("cumulative_strength") or 0.0 + if edge_strength < min_strength: + continue + + adjusted_strength = min(1.0, edge_strength * (0.85 ** (hop - 1))) + + if candidate_id not in hop_results or adjusted_strength > hop_results[candidate_id]["cumulative_strength"]: + hop_results[candidate_id] = { + "memory_id": candidate_id, + "hop": hop, + "cumulative_strength": adjusted_strength, + "edge_strength": edge_strength, + "via": row.get("via_memory_id") + or row.get("source_memory_id") + or row.get("related_memory_id"), + "relationship_type": row.get("relationship_type"), + } + + # Order by strength and enforce per-hop limit + ordered = sorted( + hop_results.values(), + key=lambda item: item["cumulative_strength"], + reverse=True, + )[:limit_per_hop] + + for item in ordered: + expansions.append(item) + visited.add(item["memory_id"]) + next_frontier.append(item["memory_id"]) + + frontier = next_frontier + + logger.debug(f"Graph expansion (iterative) accumulated {len(expansions)} memories") + return expansions + + def _enrich_with_memory_data( + self, + memory_refs: List[Dict[str, Any]], + namespace: str, + ) -> List[GraphSearchResult]: + """Fetch full memory data and create GraphSearchResult objects""" + + if not memory_refs: + return [] + + memory_ids = [m["memory_id"] for m in memory_refs] + + with self.db_manager.get_session() as session: + query = """ + SELECT memory_id, searchable_content as content, summary, + category_primary as category, importance_score, + created_at, access_count, last_accessed, + 'short_term' as memory_type + FROM short_term_memory + WHERE memory_id IN ({}) + AND namespace = ? + + UNION ALL + + SELECT memory_id, searchable_content as content, summary, + category_primary as category, importance_score, + created_at, access_count, last_accessed, + 'long_term' as memory_type + FROM long_term_memory + WHERE memory_id IN ({}) + AND namespace = ? + """.format( + ",".join(["?" for _ in memory_ids]), + ",".join(["?" 
for _ in memory_ids]), + ) + + params = memory_ids + [namespace] + memory_ids + [namespace] + memory_data = self._execute_query(session, query, {"__sqlite_params": params}) + + memory_lookup = {m["memory_id"]: m for m in memory_data} + + def _calculate_recency_score(created_at: Any) -> float: + try: + if not created_at: + return 0.0 + if isinstance(created_at, str): + created_dt = datetime.fromisoformat(created_at.replace("Z", "+00:00")) + else: + created_dt = created_at + days_old = max(0, (datetime.now(tz=getattr(created_dt, "tzinfo", None)) - created_dt).days) + return max(0.0, 1 - (days_old / 45)) + except Exception: + return 0.0 + + results: List[GraphSearchResult] = [] + for ref in memory_refs: + mem = memory_lookup.get(ref["memory_id"]) + if not mem: + continue + + category_value = mem.get("category") + if category_value: + from ..utils.pydantic_models import MemoryCategoryType + + try: + category_value = MemoryCategoryType(category_value) + except (ValueError, AttributeError): + category_value = None + + created_at = mem.get("created_at") + recency_score = _calculate_recency_score(created_at) + + text_score = ref.get("text_score", 0.0) + entity_score = ref.get("entity_overlap_score", 0.0) + graph_strength = ref.get("cumulative_strength", 0.0) + + combined_score = max(text_score, graph_strength, entity_score) + + result = GraphSearchResult( + memory_id=ref["memory_id"], + content=mem.get("content", ""), + summary=mem.get("summary", ""), + category=category_value, + composite_score=max(combined_score, ref.get("composite_score", 0.0)), + text_relevance_score=text_score, + entity_overlap_score=entity_score, + graph_strength_score=graph_strength, + importance_score=mem.get("importance_score", 0.5), + recency_score=recency_score, + hop_distance=ref.get("hop", 0), + graph_paths=[], + shared_entities=ref.get("matched_entities", []), + connected_via=[ref.get("via")] if ref.get("via") else [], + match_reason=ref.get("match_reason", ""), + timestamp=created_at, + access_count=mem.get("access_count") or 0, + last_accessed=mem.get("last_accessed"), + ) + results.append(result) + + return results + + def _apply_composite_scoring( + self, + results: List[GraphSearchResult], + weights: ScoringWeights, + ) -> List[GraphSearchResult]: + """Apply composite scoring with configurable weights""" + + for result in results: + result.composite_score = ( + result.text_relevance_score * weights.text_relevance + + result.entity_overlap_score * weights.entity_overlap + + result.graph_strength_score * weights.graph_strength + + result.importance_score * weights.importance + + result.recency_score * weights.recency + ) + + return results + + def _generate_match_reasons( + self, + results: List[GraphSearchResult], + query_text: str, + entities: List[str], + ) -> List[GraphSearchResult]: + """Generate human-readable match explanations""" + + for result in results: + reasons = [] + + if result.hop_distance == 0: + reasons.append("Direct match") + else: + reasons.append(f"{result.hop_distance}-hop connection") + + if result.shared_entities: + reasons.append( + f"Shares entities: {', '.join(result.shared_entities[:3])}" + ) + + if result.entity_overlap_score > 0.7: + reasons.append("Strong entity overlap") + + if result.text_relevance_score > 0.3: + reasons.append("Textually relevant") + + if result.connected_via: + reasons.append(f"Connected via {', '.join(result.connected_via)}") + + result.match_reason = " | ".join(reasons) if reasons else "Related memory" + + return results + + def _execute_query( + self, 
session, query: str, params: Dict[str, Any] + ) -> List[Dict[str, Any]]: + """Execute database query and return results as dicts""" + from sqlalchemy import text + + # Handle different parameter formats + if "__mysql_params" in params or "__sqlite_params" in params: + # Convert positional params (?) to named params for SQLAlchemy 2.0+ + param_list = params.get("__mysql_params") or params.get("__sqlite_params") + + # Convert ? placeholders to named parameters + named_params = {} + modified_query = query + param_index = 0 + + # Replace each ? with a named parameter + while "?" in modified_query and param_index < len(param_list): + param_name = f"param_{param_index}" + modified_query = modified_query.replace("?", f":{param_name}", 1) + named_params[param_name] = param_list[param_index] + param_index += 1 + + query_text = text(modified_query) + result = session.execute(query_text, named_params) + elif "__mongodb_pipeline" in params: + # MongoDB aggregation - different handling + # This would use pymongo instead + raise NotImplementedError("MongoDB queries not yet implemented") + else: + # Named parameters (PostgreSQL style) - use text() + query_text = text(query) + result = session.execute(query_text, params) + + # Convert to list of dicts + columns = result.keys() if hasattr(result, "keys") else [] + return [dict(zip(columns, row)) for row in result] + + def get_stats(self) -> Dict[str, Any]: + """Get search statistics""" + return self.stats.copy() diff --git a/memori/database/migrations/001_add_graph_tables.py b/memori/database/migrations/001_add_graph_tables.py new file mode 100644 index 00000000..d2e552b9 --- /dev/null +++ b/memori/database/migrations/001_add_graph_tables.py @@ -0,0 +1,252 @@ +""" +Migration: Add Graph Tables for Entity and Relationship Storage +Version: 001 +Date: 2025-10-03 +Description: Adds memory_entities and memory_relationships tables for graph-based search +""" + +from datetime import datetime +from loguru import logger + + +def upgrade(engine): + """Add graph tables to existing database""" + dialect = engine.dialect.name + logger.info(f"Running migration 001 on {dialect} database") + + with engine.connect() as conn: + try: + # Step 1: Create memory_entities table + if dialect == "sqlite": + conn.execute( + """ + CREATE TABLE IF NOT EXISTS memory_entities ( + entity_id TEXT PRIMARY KEY, + memory_id TEXT NOT NULL, + memory_type TEXT NOT NULL, + entity_type TEXT NOT NULL, + entity_value TEXT NOT NULL, + normalized_value TEXT NOT NULL, + relevance_score REAL DEFAULT 0.5, + namespace TEXT NOT NULL DEFAULT 'default', + frequency INTEGER DEFAULT 1, + created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + context TEXT + ) + """ + ) + elif dialect == "mysql": + conn.execute( + """ + CREATE TABLE IF NOT EXISTS memory_entities ( + entity_id VARCHAR(255) PRIMARY KEY, + memory_id VARCHAR(255) NOT NULL, + memory_type VARCHAR(50) NOT NULL, + entity_type VARCHAR(100) NOT NULL, + entity_value VARCHAR(500) NOT NULL, + normalized_value VARCHAR(500) NOT NULL, + relevance_score FLOAT DEFAULT 0.5, + namespace VARCHAR(255) NOT NULL DEFAULT 'default', + frequency INT DEFAULT 1, + created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + context TEXT + ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci + """ + ) + elif dialect == "postgresql": + conn.execute( + """ + CREATE TABLE IF NOT EXISTS memory_entities ( + entity_id VARCHAR(255) PRIMARY KEY, + memory_id VARCHAR(255) NOT NULL, + memory_type VARCHAR(50) NOT NULL, + entity_type VARCHAR(100) NOT NULL, + 
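+                        -- entity_value keeps the original casing; normalized_value is
+                        -- the lowercased form used for case-insensitive matching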
entity_value VARCHAR(500) NOT NULL, + normalized_value VARCHAR(500) NOT NULL, + relevance_score FLOAT DEFAULT 0.5, + namespace VARCHAR(255) NOT NULL DEFAULT 'default', + frequency INTEGER DEFAULT 1, + created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + context TEXT + ) + """ + ) + + # Step 2: Create entity indexes + _create_entity_indexes(conn, dialect) + + # Step 3: Create memory_relationships table + if dialect == "sqlite": + conn.execute( + """ + CREATE TABLE IF NOT EXISTS memory_relationships ( + relationship_id TEXT PRIMARY KEY, + source_memory_id TEXT NOT NULL, + target_memory_id TEXT NOT NULL, + source_memory_type TEXT NOT NULL, + target_memory_type TEXT NOT NULL, + relationship_type TEXT NOT NULL, + strength REAL NOT NULL DEFAULT 0.5, + bidirectional INTEGER DEFAULT 1, + namespace TEXT NOT NULL DEFAULT 'default', + created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + last_strengthened TIMESTAMP, + access_count INTEGER DEFAULT 0, + reasoning TEXT, + shared_entity_count INTEGER DEFAULT 0, + metadata_json TEXT + ) + """ + ) + elif dialect == "mysql": + conn.execute( + """ + CREATE TABLE IF NOT EXISTS memory_relationships ( + relationship_id VARCHAR(255) PRIMARY KEY, + source_memory_id VARCHAR(255) NOT NULL, + target_memory_id VARCHAR(255) NOT NULL, + source_memory_type VARCHAR(50) NOT NULL, + target_memory_type VARCHAR(50) NOT NULL, + relationship_type VARCHAR(100) NOT NULL, + strength FLOAT NOT NULL DEFAULT 0.5, + bidirectional BOOLEAN DEFAULT TRUE, + namespace VARCHAR(255) NOT NULL DEFAULT 'default', + created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + last_strengthened TIMESTAMP NULL, + access_count INT DEFAULT 0, + reasoning TEXT, + shared_entity_count INT DEFAULT 0, + metadata_json JSON + ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci + """ + ) + elif dialect == "postgresql": + conn.execute( + """ + CREATE TABLE IF NOT EXISTS memory_relationships ( + relationship_id VARCHAR(255) PRIMARY KEY, + source_memory_id VARCHAR(255) NOT NULL, + target_memory_id VARCHAR(255) NOT NULL, + source_memory_type VARCHAR(50) NOT NULL, + target_memory_type VARCHAR(50) NOT NULL, + relationship_type VARCHAR(100) NOT NULL, + strength FLOAT NOT NULL DEFAULT 0.5, + bidirectional BOOLEAN DEFAULT TRUE, + namespace VARCHAR(255) NOT NULL DEFAULT 'default', + created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + last_strengthened TIMESTAMP, + access_count INTEGER DEFAULT 0, + reasoning TEXT, + shared_entity_count INTEGER DEFAULT 0, + metadata_json JSONB + ) + """ + ) + + # Step 4: Create relationship indexes + _create_relationship_indexes(conn, dialect) + + conn.commit() + logger.success("Migration 001 completed successfully") + + except Exception as e: + logger.error(f"Migration 001 failed: {e}") + conn.rollback() + raise + + +def downgrade(engine): + """Remove graph tables""" + dialect = engine.dialect.name + logger.info(f"Reverting migration 001 on {dialect} database") + + with engine.connect() as conn: + try: + conn.execute("DROP TABLE IF EXISTS memory_relationships") + conn.execute("DROP TABLE IF EXISTS memory_entities") + conn.commit() + logger.success("Migration 001 reverted successfully") + + except Exception as e: + logger.error(f"Migration 001 revert failed: {e}") + conn.rollback() + raise + + +def _create_entity_indexes(conn, dialect): + """Create indexes for memory_entities table""" + indexes = [ + "CREATE INDEX IF NOT EXISTS idx_entity_memory ON memory_entities(memory_id, memory_type)", + "CREATE INDEX IF NOT EXISTS idx_entity_type ON 
memory_entities(entity_type)", + "CREATE INDEX IF NOT EXISTS idx_entity_value ON memory_entities(entity_value)", + "CREATE INDEX IF NOT EXISTS idx_entity_normalized ON memory_entities(normalized_value)", + "CREATE INDEX IF NOT EXISTS idx_entity_namespace ON memory_entities(namespace)", + "CREATE INDEX IF NOT EXISTS idx_entity_relevance ON memory_entities(relevance_score)", + "CREATE INDEX IF NOT EXISTS idx_entity_type_value ON memory_entities(entity_type, normalized_value)", + "CREATE INDEX IF NOT EXISTS idx_entity_namespace_type ON memory_entities(namespace, entity_type)", + ] + + # Compound index for optimal queries + if dialect == "sqlite": + indexes.append( + "CREATE INDEX IF NOT EXISTS idx_entity_compound ON memory_entities(namespace, entity_type, normalized_value, relevance_score)" + ) + else: + indexes.append( + "CREATE INDEX idx_entity_compound ON memory_entities(namespace, entity_type, normalized_value, relevance_score)" + ) + + for index_sql in indexes: + try: + conn.execute(index_sql) + except Exception as e: + logger.warning(f"Could not create index: {e}") + + +def _create_relationship_indexes(conn, dialect): + """Create indexes for memory_relationships table""" + indexes = [ + "CREATE INDEX IF NOT EXISTS idx_rel_source ON memory_relationships(source_memory_id, source_memory_type)", + "CREATE INDEX IF NOT EXISTS idx_rel_target ON memory_relationships(target_memory_id, target_memory_type)", + "CREATE INDEX IF NOT EXISTS idx_rel_type ON memory_relationships(relationship_type)", + "CREATE INDEX IF NOT EXISTS idx_rel_strength ON memory_relationships(strength)", + "CREATE INDEX IF NOT EXISTS idx_rel_namespace ON memory_relationships(namespace)", + "CREATE INDEX IF NOT EXISTS idx_rel_bidirectional ON memory_relationships(bidirectional)", + "CREATE INDEX IF NOT EXISTS idx_rel_source_type ON memory_relationships(source_memory_id, relationship_type)", + "CREATE INDEX IF NOT EXISTS idx_rel_target_type ON memory_relationships(target_memory_id, relationship_type)", + "CREATE INDEX IF NOT EXISTS idx_rel_entity_count ON memory_relationships(shared_entity_count)", + ] + + # Compound indexes for graph traversal + if dialect == "sqlite": + indexes.extend( + [ + "CREATE INDEX IF NOT EXISTS idx_rel_compound_source ON memory_relationships(source_memory_id, relationship_type, strength)", + "CREATE INDEX IF NOT EXISTS idx_rel_compound_target ON memory_relationships(target_memory_id, relationship_type, strength)", + "CREATE INDEX IF NOT EXISTS idx_rel_namespace_type ON memory_relationships(namespace, relationship_type, strength)", + ] + ) + else: + indexes.extend( + [ + "CREATE INDEX idx_rel_compound_source ON memory_relationships(source_memory_id, relationship_type, strength)", + "CREATE INDEX idx_rel_compound_target ON memory_relationships(target_memory_id, relationship_type, strength)", + "CREATE INDEX idx_rel_namespace_type ON memory_relationships(namespace, relationship_type, strength)", + ] + ) + + for index_sql in indexes: + try: + conn.execute(index_sql) + except Exception as e: + logger.warning(f"Could not create index: {e}") + + +def get_version(): + """Return migration version""" + return "001" + + +def get_description(): + """Return migration description""" + return "Add graph tables for entity and relationship storage" diff --git a/memori/database/migrations/__init__.py b/memori/database/migrations/__init__.py new file mode 100644 index 00000000..e7e93826 --- /dev/null +++ b/memori/database/migrations/__init__.py @@ -0,0 +1,250 @@ +""" +Database Migration Manager for Memori +Handles 
schema migrations across different database backends +""" + +import importlib +import os +from pathlib import Path +from typing import List + +from loguru import logger + + +class MigrationManager: + """Manages database schema migrations""" + + def __init__(self, engine): + """ + Initialize migration manager + + Args: + engine: SQLAlchemy engine instance + """ + self.engine = engine + self.migrations_dir = Path(__file__).parent + self.available_migrations = self._discover_migrations() + + def _discover_migrations(self) -> List[str]: + """Discover all migration files in the migrations directory""" + migrations = [] + for file in sorted(self.migrations_dir.glob("*.py")): + if file.name != "__init__.py" and not file.name.startswith("_"): + migration_name = file.stem + migrations.append(migration_name) + + logger.debug(f"Discovered {len(migrations)} migrations: {migrations}") + return migrations + + def _load_migration(self, migration_name: str): + """Load a migration module""" + try: + module = importlib.import_module( + f"memori.database.migrations.{migration_name}" + ) + return module + except ImportError as e: + logger.error(f"Failed to load migration {migration_name}: {e}") + raise + + def _create_migration_tracking_table(self): + """Create table to track applied migrations""" + dialect = self.engine.dialect.name + + with self.engine.connect() as conn: + if dialect == "sqlite": + conn.execute( + """ + CREATE TABLE IF NOT EXISTS schema_migrations ( + version TEXT PRIMARY KEY, + description TEXT, + applied_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP + ) + """ + ) + elif dialect == "mysql": + conn.execute( + """ + CREATE TABLE IF NOT EXISTS schema_migrations ( + version VARCHAR(255) PRIMARY KEY, + description TEXT, + applied_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP + ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 + """ + ) + elif dialect == "postgresql": + conn.execute( + """ + CREATE TABLE IF NOT EXISTS schema_migrations ( + version VARCHAR(255) PRIMARY KEY, + description TEXT, + applied_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP + ) + """ + ) + + conn.commit() + logger.debug("Migration tracking table ready") + + def _get_applied_migrations(self) -> List[str]: + """Get list of already applied migrations""" + self._create_migration_tracking_table() + + with self.engine.connect() as conn: + result = conn.execute("SELECT version FROM schema_migrations ORDER BY version") + return [row[0] for row in result] + + def _mark_migration_applied(self, version: str, description: str): + """Mark a migration as applied""" + with self.engine.connect() as conn: + if self.engine.dialect.name == "sqlite": + conn.execute( + "INSERT OR IGNORE INTO schema_migrations (version, description) VALUES (?, ?)", + (version, description), + ) + elif self.engine.dialect.name == "mysql": + conn.execute( + "INSERT IGNORE INTO schema_migrations (version, description) VALUES (%s, %s)", + (version, description), + ) + elif self.engine.dialect.name == "postgresql": + conn.execute( + "INSERT INTO schema_migrations (version, description) VALUES (%s, %s) ON CONFLICT (version) DO NOTHING", + (version, description), + ) + + conn.commit() + + def _unmark_migration(self, version: str): + """Remove migration from applied list""" + with self.engine.connect() as conn: + conn.execute( + "DELETE FROM schema_migrations WHERE version = ?", (version,) + ) + conn.commit() + + def pending_migrations(self) -> List[str]: + """Get list of pending (not yet applied) migrations""" + applied = self._get_applied_migrations() + return [m for m 
in self.available_migrations if m not in applied] + + def upgrade(self, target: str = None): + """ + Apply pending migrations + + Args: + target: Optional specific migration to upgrade to (defaults to latest) + """ + pending = self.pending_migrations() + + if not pending: + logger.info("No pending migrations") + return + + if target: + # Only apply migrations up to target + pending = [m for m in pending if m <= target] + + logger.info(f"Applying {len(pending)} migration(s)") + + for migration_name in pending: + logger.info(f"Applying migration: {migration_name}") + module = self._load_migration(migration_name) + + try: + module.upgrade(self.engine) + version = module.get_version() + description = module.get_description() + self._mark_migration_applied(version, description) + logger.success(f"✓ {migration_name} applied") + + except Exception as e: + logger.error(f"✗ {migration_name} failed: {e}") + raise + + logger.success("All migrations applied successfully") + + def downgrade(self, target: str): + """ + Rollback migrations to a specific version + + Args: + target: Migration version to rollback to + """ + applied = self._get_applied_migrations() + to_revert = [m for m in reversed(applied) if m > target] + + if not to_revert: + logger.info("No migrations to revert") + return + + logger.warning(f"Reverting {len(to_revert)} migration(s)") + + for migration_name in to_revert: + logger.info(f"Reverting migration: {migration_name}") + module = self._load_migration(migration_name) + + try: + module.downgrade(self.engine) + version = module.get_version() + self._unmark_migration(version) + logger.success(f"✓ {migration_name} reverted") + + except Exception as e: + logger.error(f"✗ {migration_name} revert failed: {e}") + raise + + logger.success("Migration rollback completed") + + def status(self): + """Print migration status""" + applied = self._get_applied_migrations() + pending = self.pending_migrations() + + logger.info("=== Migration Status ===") + logger.info(f"Database: {self.engine.dialect.name}") + + if applied: + logger.info(f"\n✓ Applied migrations ({len(applied)}):") + for migration in applied: + logger.info(f" - {migration}") + + if pending: + logger.warning(f"\n⚠ Pending migrations ({len(pending)}):") + for migration in pending: + logger.warning(f" - {migration}") + else: + logger.success("\n✓ Database is up to date") + + +def run_migrations(engine): + """ + Convenience function to run all pending migrations + + Args: + engine: SQLAlchemy engine instance + """ + manager = MigrationManager(engine) + manager.upgrade() + + +def get_migration_status(engine): + """ + Get migration status for a database + + Args: + engine: SQLAlchemy engine instance + + Returns: + dict: Migration status information + """ + manager = MigrationManager(engine) + applied = manager._get_applied_migrations() + pending = manager.pending_migrations() + + return { + "applied": applied, + "pending": pending, + "database": engine.dialect.name, + "up_to_date": len(pending) == 0, + } diff --git a/memori/database/models.py b/memori/database/models.py index 0aa89ad1..bee14fcb 100644 --- a/memori/database/models.py +++ b/memori/database/models.py @@ -180,6 +180,98 @@ class LongTermMemory(Base): ) +class MemoryEntity(Base): + """Entity extraction table for graph-based search""" + + __tablename__ = "memory_entities" + + entity_id = Column(String(255), primary_key=True) + memory_id = Column(String(255), nullable=False) + memory_type = Column(String(50), nullable=False) # 'short_term' or 'long_term' + entity_type = 
Column(String(100), nullable=False) # person, tech, topic, etc. + entity_value = Column(String(500), nullable=False) + normalized_value = Column(String(500), nullable=False) # Lowercase normalized + relevance_score = Column(Float, default=0.5) + namespace = Column(String(255), nullable=False, default="default") + frequency = Column(Integer, default=1) # How many times mentioned + created_at = Column(DateTime, nullable=False, default=datetime.utcnow) + context = Column(Text) # Additional context about the entity + + # Indexes for fast entity-based search + __table_args__ = ( + Index("idx_entity_memory", "memory_id", "memory_type"), + Index("idx_entity_type", "entity_type"), + Index("idx_entity_value", "entity_value"), + Index("idx_entity_normalized", "normalized_value"), + Index("idx_entity_namespace", "namespace"), + Index("idx_entity_relevance", "relevance_score"), + Index("idx_entity_type_value", "entity_type", "normalized_value"), + Index("idx_entity_namespace_type", "namespace", "entity_type"), + Index( + "idx_entity_compound", + "namespace", + "entity_type", + "normalized_value", + "relevance_score", + ), + ) + + +class MemoryRelationshipDB(Base): + """Relationship graph table for memory connections""" + + __tablename__ = "memory_relationships" + + relationship_id = Column(String(255), primary_key=True) + source_memory_id = Column(String(255), nullable=False) + target_memory_id = Column(String(255), nullable=False) + source_memory_type = Column(String(50), nullable=False) # 'short_term' or 'long_term' + target_memory_type = Column(String(50), nullable=False) + relationship_type = Column( + String(100), nullable=False + ) # semantic_similarity, causality, etc. + strength = Column(Float, nullable=False, default=0.5) # 0.0-1.0 + bidirectional = Column(Boolean, default=True) + namespace = Column(String(255), nullable=False, default="default") + created_at = Column(DateTime, nullable=False, default=datetime.utcnow) + last_strengthened = Column(DateTime) # When strength was updated + access_count = Column(Integer, default=0) # How often traversed + reasoning = Column(Text) # Why this relationship exists + shared_entity_count = Column(Integer, default=0) # Number of shared entities + metadata_json = Column(JSON) # Additional relationship metadata + + # Indexes for graph traversal + __table_args__ = ( + Index("idx_rel_source", "source_memory_id", "source_memory_type"), + Index("idx_rel_target", "target_memory_id", "target_memory_type"), + Index("idx_rel_type", "relationship_type"), + Index("idx_rel_strength", "strength"), + Index("idx_rel_namespace", "namespace"), + Index("idx_rel_bidirectional", "bidirectional"), + Index("idx_rel_source_type", "source_memory_id", "relationship_type"), + Index("idx_rel_target_type", "target_memory_id", "relationship_type"), + Index( + "idx_rel_compound_source", + "source_memory_id", + "relationship_type", + "strength", + ), + Index( + "idx_rel_compound_target", + "target_memory_id", + "relationship_type", + "strength", + ), + Index( + "idx_rel_namespace_type", + "namespace", + "relationship_type", + "strength", + ), + Index("idx_rel_entity_count", "shared_entity_count"), + ) + + # Database-specific configurations def configure_mysql_fulltext(engine): """Configure MySQL FULLTEXT indexes""" diff --git a/memori/database/search_service.py b/memori/database/search_service.py index bd559c35..81b697bc 100644 --- a/memori/database/search_service.py +++ b/memori/database/search_service.py @@ -1,24 +1,52 @@ """ SQLAlchemy-based search service for Memori v2.0 -Provides 
cross-database full-text search capabilities +Provides cross-database full-text search capabilities with graph integration """ +import re from datetime import datetime -from typing import Any +from typing import Any, Optional from loguru import logger from sqlalchemy import and_, desc, or_, text from sqlalchemy.orm import Session from .models import LongTermMemory, ShortTermMemory +from memori.utils.pydantic_models import ( + GraphExpansionConfig, + SearchStrategy, + ScoringWeights, +) class SearchService: - """Cross-database search service using SQLAlchemy""" + """Cross-database search service using SQLAlchemy with graph capabilities""" - def __init__(self, session: Session, database_type: str): + def __init__( + self, + session: Session, + database_type: str, + graph_search_service: Optional[Any] = None, + ): self.session = session self.database_type = database_type + self.graph_search_service = graph_search_service + + # ------------------------------------------------------------------ + # Query preparation helpers + # ------------------------------------------------------------------ + + @staticmethod + def _sanitize_query_tokens(query: str) -> list[str]: + """Tokenize and strip punctuation for cross-database FTS usage""" + + tokens = [] + for raw in re.split(r"\s+", query.strip()): + cleaned = raw.strip().strip(",.?!;:\"'()[]{}<>") + cleaned = re.sub(r"[^0-9A-Za-z_]+", "", cleaned) + if cleaned: + tokens.append(cleaned.lower()) + return tokens def search_memories( self, @@ -62,6 +90,8 @@ def search_memories( ) try: + sanitized_tokens = self._sanitize_query_tokens(query) + match_query = " ".join(sanitized_tokens) if sanitized_tokens else query.strip() # Try database-specific full-text search first if self.database_type == "sqlite": logger.debug("[SEARCH] Strategy: SQLite FTS5") @@ -172,8 +202,27 @@ def _search_sqlite_fts( f"Search scope - short_term: {search_short_term}, long_term: {search_long_term}" ) - # Build FTS query - fts_query = f'"{query.strip()}"' + # Build FTS query using tokenized terms so multi-word queries do not + # degrade into slow LIKE fallbacks. Terms longer than two characters + # get prefix matching for better recall. + tokens = self._sanitize_query_tokens(query) + long_terms = [t for t in tokens if len(t) > 2] + processed_terms = [] + + for term in long_terms: + sanitized = term.replace('"', '""') + processed_terms.append(f"{sanitized}*") + + if not processed_terms and tokens: + # If everything was short (e.g. 
"AI"), keep at least one token + sanitized = tokens[0].replace('"', '""') + processed_terms.append(f"{sanitized}*") + + if processed_terms: + fts_query = " OR ".join(processed_terms) + else: + fts_query = query.strip() + logger.debug(f"FTS query built: {fts_query}") # Build category filter @@ -227,7 +276,7 @@ def _search_sqlite_fts( LEFT JOIN long_term_memory lt ON fts.memory_id = lt.memory_id AND fts.memory_type = 'long_term' WHERE memory_search_fts MATCH :fts_query AND fts.namespace = :namespace {category_clause} - ORDER BY search_score, importance_score DESC + ORDER BY search_score DESC, importance_score DESC LIMIT {limit} """ @@ -312,7 +361,7 @@ def _search_mysql_fulltext( try: # Build category filter clause category_clause = "" - params = {"query": query} + params = {"query": match_query} if category_filter: category_placeholders = ",".join( [f":cat_{i}" for i in range(len(category_filter))] @@ -384,7 +433,7 @@ def _search_mysql_fulltext( try: # Build category filter clause category_clause = "" - params = {"query": query} + params = {"query": match_query} if category_filter: category_placeholders = ",".join( [f":cat_{i}" for i in range(len(category_filter))] @@ -478,6 +527,13 @@ def _search_postgresql_fts( results = [] try: + sanitized_tokens = self._sanitize_query_tokens(query) + if sanitized_tokens: + ts_terms = [f"{term}:*" for term in sanitized_tokens] + tsquery_text = " & ".join(ts_terms) + else: + tsquery_text = query.strip() or "" + # Apply limit proportionally between memory types short_limit = ( limit // 2 if search_short_term and search_long_term else limit @@ -486,9 +542,9 @@ def _search_postgresql_fts( limit - short_limit if search_short_term and search_long_term else limit ) - # Prepare query for tsquery - handle spaces and special characters - # Convert simple query to tsquery format (join words with &) - tsquery_text = " & ".join(query.split()) + if not tsquery_text: + logger.debug("Empty tsquery after sanitization, skipping PostgreSQL FTS") + return [] # Search short-term memory if requested if search_short_term: @@ -842,3 +898,206 @@ def _calculate_recency_score(self, created_at) -> float: return max(0, 1 - (days_old / 30)) # Full score for recent, 0 after 30 days except: return 0.0 + + # ==================== Graph-Enhanced Search Methods ==================== + + def search_with_graph( + self, + query: str, + namespace: str = "default", + strategy: SearchStrategy = SearchStrategy.GRAPH_EXPANSION_1HOP, + entities: Optional[list[str]] = None, + category_filter: Optional[list[str]] = None, + graph_expansion: Optional[GraphExpansionConfig] = None, + scoring_weights: Optional[ScoringWeights] = None, + limit: int = 10, + ) -> list[dict[str, Any]]: + """ + Search memories using graph-based strategies + + Args: + query: Search query text + namespace: Memory namespace + strategy: Graph search strategy to use + entities: Optional entity filters + category_filter: Optional category filters + graph_expansion: Graph expansion configuration + scoring_weights: Scoring weights for composite scoring + limit: Maximum results + + Returns: + List of memory dictionaries with graph metadata + """ + if not self.graph_search_service: + logger.warning( + "GraphSearchService not initialized, falling back to text-only search" + ) + return self.search_memories( + query=query, + namespace=namespace, + category_filter=category_filter, + limit=limit, + ) + + logger.info( + f"Graph-enhanced search: strategy={strategy}, query='{query[:50]}...', " + f"namespace={namespace}, entities={entities}" + ) + 
+ try: + # Use GraphSearchService for graph-based search + graph_results = self.graph_search_service.search( + query_text=query, + strategy=strategy, + namespace=namespace, + entities=entities or [], + categories=category_filter or [], + graph_expansion=graph_expansion, + scoring_weights=scoring_weights, + max_results=limit, + ) + + # Convert GraphSearchResult objects to dict format + results = [] + for result in graph_results: + result_dict = { + "memory_id": result.memory_id, + "memory_type": "long_term", # Graph results come from both + "processed_data": {"content": result.content}, + "summary": result.summary, + "importance_score": result.importance_score, + "created_at": result.timestamp, + "category_primary": result.category, + "search_score": result.composite_score, + "composite_score": result.composite_score, + "search_strategy": f"graph_{strategy.value}", + # Graph-specific metadata + "hop_distance": result.hop_distance, + "shared_entities": result.shared_entities, + "match_reason": result.match_reason, + "graph_strength_score": result.graph_strength_score, + "entity_overlap_score": result.entity_overlap_score, + "text_relevance_score": result.text_relevance_score, + } + results.append(result_dict) + + logger.info( + f"Graph search completed: {len(results)} results with strategy={strategy}" + ) + return results + + except Exception as e: + logger.error(f"Graph search failed: {e}, falling back to text search") + logger.debug("Graph search error details", exc_info=True) + # Fallback to traditional search + return self.search_memories( + query=query, + namespace=namespace, + category_filter=category_filter, + limit=limit, + ) + + def search_by_entities( + self, + entities: list[str], + namespace: str = "default", + category_filter: Optional[list[str]] = None, + expand_graph: bool = True, + limit: int = 20, + ) -> list[dict[str, Any]]: + """ + Search memories by entity tags + + Args: + entities: List of entity values to search for + namespace: Memory namespace + category_filter: Optional category filters + expand_graph: Whether to expand via graph relationships + limit: Maximum results + + Returns: + List of memory dictionaries + """ + if not self.graph_search_service: + logger.warning("GraphSearchService not available for entity search") + return [] + + strategy = ( + SearchStrategy.GRAPH_EXPANSION_1HOP + if expand_graph + else SearchStrategy.ENTITY_FIRST + ) + + return self.search_with_graph( + query="", + namespace=namespace, + strategy=strategy, + entities=entities, + category_filter=category_filter, + limit=limit, + ) + + def find_related_memories( + self, + memory_id: str, + namespace: str = "default", + max_hops: int = 2, + min_strength: float = 0.5, + limit: int = 10, + ) -> list[dict[str, Any]]: + """ + Find memories related to a specific memory via graph relationships + + Args: + memory_id: Source memory ID + namespace: Memory namespace + max_hops: Maximum hop distance (1-3) + min_strength: Minimum relationship strength + limit: Maximum results + + Returns: + List of related memory dictionaries + """ + if not self.graph_search_service: + logger.warning("GraphSearchService not available for related memory search") + return [] + + try: + # Use graph walk strategy starting from this memory + graph_expansion = GraphExpansionConfig( + enabled=True, + hop_distance=max_hops, + min_relationship_strength=min_strength, + ) + + # Get related memories via graph expansion + results = self.graph_search_service.search_with_expansion( + query_text="", + entities=[], + categories=[], + 
namespace=namespace, + expand_hops=max_hops, + min_strength=min_strength, + limit=limit, + ) + + # Convert to dict format + related = [] + for result in results: + if result.memory_id != memory_id: # Exclude source memory + related.append( + { + "memory_id": result.memory_id, + "summary": result.summary, + "hop_distance": result.hop_distance, + "relationship_strength": result.graph_strength_score, + "shared_entities": result.shared_entities, + "match_reason": result.match_reason, + } + ) + + return related[:limit] + + except Exception as e: + logger.error(f"Failed to find related memories for {memory_id}: {e}") + return [] diff --git a/memori/database/sqlalchemy_manager.py b/memori/database/sqlalchemy_manager.py index d4b5b9f4..adba4024 100644 --- a/memori/database/sqlalchemy_manager.py +++ b/memori/database/sqlalchemy_manager.py @@ -63,8 +63,45 @@ def __init__( # Initialize query parameter translator for cross-database compatibility self.query_translator = QueryParameterTranslator(self.database_type) + # Graph search components (set later via set_graph_components) + self.graph_search_service = None + self.entity_extractor = None + self.relationship_detector = None + logger.info(f"Initialized SQLAlchemy database manager for {self.database_type}") + def set_graph_components( + self, + graph_search_service: Any = None, + entity_extractor: Any = None, + relationship_detector: Any = None, + ): + """ + Set graph search components after initialization + + Args: + graph_search_service: GraphSearchService instance + entity_extractor: EntityExtractionService instance + relationship_detector: RelationshipDetectionService instance + """ + self.graph_search_service = graph_search_service + self.entity_extractor = entity_extractor + self.relationship_detector = relationship_detector + logger.info("Graph search components configured for database manager") + + def get_session(self): + """ + Get a new database session (context manager) + + Returns: + SQLAlchemy session context manager + + Example: + with db_manager.get_session() as session: + results = session.query(Model).all() + """ + return self.SessionLocal() + def _validate_database_dependencies(self, database_connect: str): """Validate that required database drivers are installed""" if database_connect.startswith("mysql:") or database_connect.startswith( @@ -485,9 +522,14 @@ def _get_search_service(self) -> SearchService: logger.error("Failed to create database session") return None - search_service = SearchService(session, self.database_type) + search_service = SearchService( + session, + self.database_type, + graph_search_service=self.graph_search_service + ) logger.debug( - f"Created new search service instance for database type: {self.database_type}" + f"Created new search service instance for database type: {self.database_type} " + f"(graph_enabled={self.graph_search_service is not None})" ) return search_service @@ -618,13 +660,68 @@ def store_long_term_memory_enhanced( session.commit() logger.debug(f"Stored enhanced long-term memory {memory_id}") - return memory_id except SQLAlchemyError as e: session.rollback() logger.error(f"Failed to store enhanced long-term memory: {e}") raise DatabaseError(f"Failed to store enhanced long-term memory: {e}") + # Automatic graph building (outside transaction to not fail memory storage) + try: + # Extract entities - but skip if already extracted + if self.entity_extractor: + # Check if entities already exist for this memory + with self.SessionLocal() as check_session: + from sqlalchemy import text + 
existing_count = check_session.execute( + text("SELECT COUNT(*) FROM memory_entities WHERE memory_id = :memory_id"), + {"memory_id": memory_id} + ).scalar() + + # Skip extraction if entities already exist + if existing_count > 0: + logger.debug( + f"Skipping entity extraction for {memory_id[:8]}... - {existing_count} entities already exist" + ) + else: + logger.debug(f"Extracting entities for new memory {memory_id[:8]}...") + entities = self.entity_extractor.extract_entities( + memory_id=memory_id, + memory_type="long_term", + content=memory.content, + namespace=namespace + ) + + if entities: + with self.SessionLocal() as entity_session: + saved_count = self.entity_extractor.save_entities( + entities, entity_session + ) + logger.debug( + f"Extracted and saved {saved_count} entities for memory {memory_id[:8]}..." + ) + + # Detect relationships + if self.relationship_detector: + relationships = self.relationship_detector.detect_relationships_for_memory( + memory_id=memory_id, + memory_type="long_term", + namespace=namespace, + max_candidates=50 + ) + logger.debug( + f"Created {len(relationships)} relationships for memory {memory_id[:8]}..." + ) + + except Exception as graph_error: + # Don't fail memory storage if graph building fails + logger.warning( + f"Graph building failed for memory {memory_id[:8]}...: {graph_error}" + ) + logger.debug("Graph building error details", exc_info=True) + + return memory_id + def search_memories( self, query: str, diff --git a/memori/database/templates/schemas/basic.sql b/memori/database/templates/schemas/basic.sql index 43aac7d9..74b4478d 100644 --- a/memori/database/templates/schemas/basic.sql +++ b/memori/database/templates/schemas/basic.sql @@ -84,6 +84,46 @@ CREATE TABLE IF NOT EXISTS long_term_memory ( conscious_processed BOOLEAN DEFAULT 0 -- Processed for conscious context extraction ); +-- ====================================== +-- GRAPH-BASED SEARCH TABLES +-- ====================================== + +-- Memory Entities Table +-- Stores extracted entities for graph-based search and relationship building +CREATE TABLE IF NOT EXISTS memory_entities ( + entity_id TEXT PRIMARY KEY, + memory_id TEXT NOT NULL, + memory_type TEXT NOT NULL, -- 'short_term' or 'long_term' + entity_type TEXT NOT NULL, -- person, technology, topic, skill, project, keyword + entity_value TEXT NOT NULL, + normalized_value TEXT NOT NULL, -- Lowercase for case-insensitive matching + relevance_score REAL DEFAULT 0.5, + namespace TEXT NOT NULL DEFAULT 'default', + frequency INTEGER DEFAULT 1, -- How many times mentioned + created_at TIMESTAMP NOT NULL, + context TEXT -- Additional context about this entity +); + +-- Memory Relationships Table +-- Stores graph relationships between memories for advanced traversal +CREATE TABLE IF NOT EXISTS memory_relationships ( + relationship_id TEXT PRIMARY KEY, + source_memory_id TEXT NOT NULL, + target_memory_id TEXT NOT NULL, + source_memory_type TEXT NOT NULL, -- 'short_term' or 'long_term' + target_memory_type TEXT NOT NULL, + relationship_type TEXT NOT NULL, -- semantic_similarity, causality, reference, etc. 
+ strength REAL NOT NULL DEFAULT 0.5, -- 0.0-1.0 + bidirectional BOOLEAN DEFAULT 1, + namespace TEXT NOT NULL DEFAULT 'default', + created_at TIMESTAMP NOT NULL, + last_strengthened TIMESTAMP, -- When strength was last updated + access_count INTEGER DEFAULT 0, -- How often traversed + reasoning TEXT, -- Why this relationship exists + shared_entity_count INTEGER DEFAULT 0, -- Number of shared entities + metadata_json TEXT DEFAULT '{}' -- Additional relationship metadata +); + -- Performance Indexes -- Chat History Indexes @@ -119,6 +159,31 @@ CREATE INDEX IF NOT EXISTS idx_long_term_conscious_processed ON long_term_memory CREATE INDEX IF NOT EXISTS idx_long_term_duplicates ON long_term_memory(processed_for_duplicates); CREATE INDEX IF NOT EXISTS idx_long_term_confidence ON long_term_memory(confidence_score); +-- Graph Entity Indexes (9 indexes for fast entity search) +CREATE INDEX IF NOT EXISTS idx_entity_memory ON memory_entities(memory_id, memory_type); +CREATE INDEX IF NOT EXISTS idx_entity_type ON memory_entities(entity_type); +CREATE INDEX IF NOT EXISTS idx_entity_value ON memory_entities(entity_value); +CREATE INDEX IF NOT EXISTS idx_entity_normalized ON memory_entities(normalized_value); +CREATE INDEX IF NOT EXISTS idx_entity_namespace ON memory_entities(namespace); +CREATE INDEX IF NOT EXISTS idx_entity_relevance ON memory_entities(relevance_score); +CREATE INDEX IF NOT EXISTS idx_entity_type_value ON memory_entities(entity_type, normalized_value); +CREATE INDEX IF NOT EXISTS idx_entity_namespace_type ON memory_entities(namespace, entity_type); +CREATE INDEX IF NOT EXISTS idx_entity_compound ON memory_entities(namespace, entity_type, normalized_value, relevance_score); + +-- Graph Relationship Indexes (12 indexes for fast graph traversal) +CREATE INDEX IF NOT EXISTS idx_rel_source ON memory_relationships(source_memory_id, source_memory_type); +CREATE INDEX IF NOT EXISTS idx_rel_target ON memory_relationships(target_memory_id, target_memory_type); +CREATE INDEX IF NOT EXISTS idx_rel_type ON memory_relationships(relationship_type); +CREATE INDEX IF NOT EXISTS idx_rel_strength ON memory_relationships(strength); +CREATE INDEX IF NOT EXISTS idx_rel_namespace ON memory_relationships(namespace); +CREATE INDEX IF NOT EXISTS idx_rel_bidirectional ON memory_relationships(bidirectional); +CREATE INDEX IF NOT EXISTS idx_rel_source_type ON memory_relationships(source_memory_id, relationship_type); +CREATE INDEX IF NOT EXISTS idx_rel_target_type ON memory_relationships(target_memory_id, relationship_type); +CREATE INDEX IF NOT EXISTS idx_rel_compound_source ON memory_relationships(source_memory_id, relationship_type, strength); +CREATE INDEX IF NOT EXISTS idx_rel_compound_target ON memory_relationships(target_memory_id, relationship_type, strength); +CREATE INDEX IF NOT EXISTS idx_rel_namespace_type ON memory_relationships(namespace, relationship_type, strength); +CREATE INDEX IF NOT EXISTS idx_rel_entity_count ON memory_relationships(shared_entity_count); + -- Full-Text Search Support (SQLite FTS5) -- Enables advanced text search capabilities CREATE VIRTUAL TABLE IF NOT EXISTS memory_search_fts USING fts5( diff --git a/memori/processors/__init__.py b/memori/processors/__init__.py new file mode 100644 index 00000000..93b40d3c --- /dev/null +++ b/memori/processors/__init__.py @@ -0,0 +1,9 @@ +""" +Memory Processing Components +Entity extraction and relationship detection for graph building +""" + +from .entity_extraction import EntityExtractionService +from .relationship_detection import 
RelationshipDetectionService + +__all__ = ["EntityExtractionService", "RelationshipDetectionService"] diff --git a/memori/processors/entity_extraction.py b/memori/processors/entity_extraction.py new file mode 100644 index 00000000..d2c48528 --- /dev/null +++ b/memori/processors/entity_extraction.py @@ -0,0 +1,376 @@ +""" +Entity Extraction Service +Extracts entities from memory content to populate the memory graph +""" + +import uuid +from datetime import datetime +from typing import Any, Dict, List, Optional + +import openai +from loguru import logger +from pydantic import BaseModel, Field + +from memori.database.models import MemoryEntity +from memori.utils.pydantic_models import EntityType + + +class ExtractedEntityWithMetadata(BaseModel): + """Structured entity extraction output""" + + entity_type: str = Field(description="Type of entity (person, technology, topic, etc.)") + entity_value: str = Field(description="The actual entity value/name") + relevance_score: float = Field( + ge=0.0, le=1.0, description="Relevance to the memory (0-1)" + ) + context: Optional[str] = Field( + default=None, description="Brief context about this entity in the memory" + ) + + +class EntityExtractionResult(BaseModel): + """Complete entity extraction result""" + + entities: List[ExtractedEntityWithMetadata] = Field( + description="List of extracted entities" + ) + extraction_confidence: float = Field( + ge=0.0, + le=1.0, + default=0.8, + description="Overall confidence in extraction", + ) + + +class EntityExtractionService: + """ + Service for extracting entities from memory content using LLM structured outputs + """ + + EXTRACTION_PROMPT = """You are an entity extraction specialist. Extract all relevant entities from the given memory content. + +**ENTITY TYPES TO EXTRACT:** + +1. **person** - People, names, authors, developers, team members + Examples: "John Smith", "Alice", "the team lead" + +2. **technology** - Technologies, tools, libraries, frameworks, languages + Examples: "Python", "Docker", "JWT", "PostgreSQL", "FastAPI" + +3. **topic** - Topics, concepts, subjects, themes + Examples: "authentication", "machine learning", "database design" + +4. **skill** - Skills, abilities, competencies, expertise areas + Examples: "API development", "code review", "debugging" + +5. **project** - Projects, repositories, applications, systems + Examples: "user-dashboard", "payment-service", "mobile-app" + +6. 
**keyword** - Important keywords, terms, acronyms + Examples: "API key", "rate limiting", "CI/CD" + +**EXTRACTION GUIDELINES:** +- Extract entities that are central to the memory's meaning +- Assign relevance scores based on importance to the memory +- Provide brief context about how the entity appears in the memory +- Normalize entity values (e.g., "jwt" → "JWT", "docker" → "Docker") +- Extract both explicit mentions and implicit references +- Prioritize quality over quantity (5-10 high-quality entities is better than 20 low-quality) + +**RELEVANCE SCORING:** +- 1.0: Central to the memory's core meaning +- 0.8: Important supporting entity +- 0.6: Relevant but not critical +- 0.4: Mentioned but peripheral +- 0.2: Barely relevant + +Extract entities now.""" + + def __init__( + self, + client: openai.OpenAI, + model: str = "gpt-4o-mini", + batch_size: int = 10, + ): + """ + Initialize entity extraction service + + Args: + client: OpenAI client instance + model: Model to use for extraction (default: gpt-4o-mini for speed) + batch_size: Number of memories to process in batch + """ + self.client = client + self.model = model + self.batch_size = batch_size + + # Statistics + self.stats = { + "total_extractions": 0, + "total_entities": 0, + "avg_entities_per_memory": 0.0, + "extraction_errors": 0, + } + + logger.info(f"EntityExtractionService initialized with model={model}") + + def extract_entities( + self, + memory_id: str, + memory_type: str, + content: str, + namespace: str = "default", + ) -> List[MemoryEntity]: + """ + Extract entities from a single memory + + Args: + memory_id: Memory identifier + memory_type: 'short_term' or 'long_term' + content: Memory content text + namespace: Memory namespace + + Returns: + List of MemoryEntity objects ready to insert + """ + try: + # Call LLM for extraction (marked as system to avoid context injection loops) + completion = self.client.beta.chat.completions.parse( + model=self.model, + messages=[ + { + "role": "system", + "content": f"{self.EXTRACTION_PROMPT}\n\nExtract entities from this memory:\n\n{content}", + }, + ], + response_format=EntityExtractionResult, + temperature=0.3, + ) + + if completion.choices[0].message.refusal: + logger.warning(f"Entity extraction refused for {memory_id}") + return [] + + result: EntityExtractionResult = completion.choices[0].message.parsed + + # Convert to MemoryEntity objects + entities = [] + for extracted in result.entities: + entity = MemoryEntity( + entity_id=str(uuid.uuid4()), + memory_id=memory_id, + memory_type=memory_type, + entity_type=extracted.entity_type, + entity_value=extracted.entity_value, + normalized_value=extracted.entity_value.lower().strip(), + relevance_score=extracted.relevance_score, + namespace=namespace, + frequency=1, + created_at=datetime.utcnow(), + context=extracted.context, + ) + entities.append(entity) + + # Update stats + self.stats["total_extractions"] += 1 + self.stats["total_entities"] += len(entities) + self.stats["avg_entities_per_memory"] = ( + self.stats["total_entities"] / self.stats["total_extractions"] + ) + + logger.debug( + f"Extracted {len(entities)} entities from memory {memory_id[:8]}..." 
+ ) + + return entities + + except Exception as e: + logger.error(f"Entity extraction failed for {memory_id}: {e}") + self.stats["extraction_errors"] += 1 + return [] + + def extract_entities_batch( + self, + memories: List[Dict[str, Any]], + namespace: str = "default", + ) -> Dict[str, List[MemoryEntity]]: + """ + Extract entities from multiple memories in batch + + Args: + memories: List of memory dicts with 'memory_id', 'content', 'memory_type' + namespace: Memory namespace + + Returns: + Dict mapping memory_id to list of MemoryEntity objects + """ + results = {} + + for memory in memories: + memory_id = memory.get("memory_id") + memory_type = memory.get("memory_type", "long_term") + content = memory.get("content", "") + + if not memory_id or not content: + logger.warning(f"Skipping memory with missing data: {memory}") + continue + + entities = self.extract_entities( + memory_id=memory_id, + memory_type=memory_type, + content=content, + namespace=namespace, + ) + + results[memory_id] = entities + + logger.info( + f"Batch extraction: {len(results)} memories, " + f"{sum(len(e) for e in results.values())} total entities" + ) + + return results + + def save_entities( + self, + entities: List[MemoryEntity], + session: Any, + ) -> int: + """ + Save extracted entities to database + + Args: + entities: List of MemoryEntity objects + session: SQLAlchemy session + + Returns: + Number of entities saved + """ + if not entities: + return 0 + + try: + # Check for duplicates and merge + existing_map = {} + for entity in entities: + # Query existing entity for this memory + entity value + existing = ( + session.query(MemoryEntity) + .filter( + MemoryEntity.memory_id == entity.memory_id, + MemoryEntity.normalized_value == entity.normalized_value, + ) + .first() + ) + + if existing: + # Update frequency and relevance + existing.frequency += 1 + existing.relevance_score = max( + existing.relevance_score, entity.relevance_score + ) + existing_map[entity.entity_id] = existing + else: + # Add new entity + session.add(entity) + + session.commit() + logger.debug(f"Saved {len(entities)} entities to database") + return len(entities) + + except Exception as e: + logger.error(f"Failed to save entities: {e}") + session.rollback() + return 0 + + def backfill_entities( + self, + db_manager: Any, + namespace: str = "default", + batch_size: int = 50, + limit: Optional[int] = None, + ) -> Dict[str, int]: + """ + Backfill entities for existing memories that don't have entities + + Args: + db_manager: Database manager with session + namespace: Memory namespace + batch_size: Memories to process per batch + limit: Optional limit on total memories to process + + Returns: + Stats dict with counts + """ + logger.info(f"Starting entity backfill for namespace={namespace}") + + stats = { + "memories_processed": 0, + "entities_created": 0, + "errors": 0, + } + + with db_manager.get_session() as session: + # Find memories without entities + query = """ + SELECT m.memory_id, m.searchable_content as content, 'long_term' as memory_type + FROM long_term_memory m + LEFT JOIN memory_entities e ON m.memory_id = e.memory_id + WHERE m.namespace = :namespace + AND e.entity_id IS NULL + + UNION ALL + + SELECT m.memory_id, m.searchable_content as content, 'short_term' as memory_type + FROM short_term_memory m + LEFT JOIN memory_entities e ON m.memory_id = e.memory_id + WHERE m.namespace = :namespace + AND e.entity_id IS NULL + """ + + if limit: + query += f" LIMIT {limit}" + + from sqlalchemy import text + + result = 
session.execute(text(query), {"namespace": namespace}) + memories_to_process = [ + {"memory_id": row[0], "content": row[1], "memory_type": row[2]} + for row in result + ] + + logger.info(f"Found {len(memories_to_process)} memories needing entities") + + # Process in batches + for i in range(0, len(memories_to_process), batch_size): + batch = memories_to_process[i : i + batch_size] + + try: + # Extract entities for batch + batch_results = self.extract_entities_batch(batch, namespace) + + # Save all entities + for memory_id, entities in batch_results.items(): + saved = self.save_entities(entities, session) + stats["entities_created"] += saved + stats["memories_processed"] += 1 + + logger.info( + f"Backfill progress: {stats['memories_processed']}/{len(memories_to_process)} memories" + ) + + except Exception as e: + logger.error(f"Batch backfill failed: {e}") + stats["errors"] += 1 + continue + + logger.success( + f"Backfill complete: {stats['memories_processed']} memories, " + f"{stats['entities_created']} entities created" + ) + + return stats + + def get_stats(self) -> Dict[str, Any]: + """Get extraction statistics""" + return self.stats.copy() diff --git a/memori/processors/relationship_detection.py b/memori/processors/relationship_detection.py new file mode 100644 index 00000000..a14efc55 --- /dev/null +++ b/memori/processors/relationship_detection.py @@ -0,0 +1,412 @@ +""" +Relationship Detection Service +Automatically detects and creates relationships between memories +""" + +import uuid +from datetime import datetime +from typing import Any, Dict, List, Optional + +from loguru import logger + +from memori.database.models import MemoryEntity, MemoryRelationshipDB +from memori.utils.pydantic_models import RelationshipType + + +class RelationshipDetectionService: + """ + Service for detecting relationships between memories + Uses entity overlap, semantic similarity, and heuristics + """ + + def __init__(self, db_manager: Any, threshold_config: Optional[Dict[str, float]] = None): + """ + Initialize relationship detection service + + Args: + db_manager: Database manager instance + threshold_config: Optional thresholds for relationship creation + """ + self.db_manager = db_manager + + # Configurable thresholds + self.thresholds = threshold_config or { + "min_entity_overlap": 1, # Minimum shared entities (lowered for small databases) + "min_strength": 0.2, # Minimum relationship strength (lowered to be more permissive) + "entity_overlap_weight": 0.6, # Weight for entity overlap + "temporal_proximity_weight": 0.2, # Weight for time proximity + "category_match_weight": 0.2, # Weight for category match + } + + # Statistics + self.stats = { + "relationships_created": 0, + "relationships_updated": 0, + "memories_analyzed": 0, + } + + logger.info("RelationshipDetectionService initialized") + + def detect_relationships_for_memory( + self, + memory_id: str, + memory_type: str, + namespace: str = "default", + max_candidates: int = 50, + ) -> List[MemoryRelationshipDB]: + """ + Detect relationships for a newly added memory + + Args: + memory_id: Memory identifier + memory_type: 'short_term' or 'long_term' + namespace: Memory namespace + max_candidates: Maximum candidate memories to compare + + Returns: + List of detected relationships + """ + logger.debug(f"Detecting relationships for memory {memory_id[:8]}...") + + relationships = [] + + with self.db_manager.get_session() as session: + # Step 1: Get entities for this memory + source_entities = ( + session.query(MemoryEntity) + .filter( + 
MemoryEntity.memory_id == memory_id, + MemoryEntity.namespace == namespace, + ) + .all() + ) + + if not source_entities: + logger.debug(f"No entities found for memory {memory_id}, skipping") + return [] + + source_entity_values = {e.normalized_value for e in source_entities} + + # Step 2: Find candidate memories with overlapping entities + candidates = self._find_candidate_memories( + session, + source_entity_values, + memory_id, + namespace, + max_candidates, + ) + + logger.debug(f"Found {len(candidates)} candidate memories for comparison") + + # Step 3: Calculate relationship strength for each candidate + for candidate in candidates: + relationship = self._calculate_relationship( + source_memory_id=memory_id, + source_memory_type=memory_type, + target_memory_id=candidate["memory_id"], + target_memory_type=candidate["memory_type"], + shared_entities=candidate["shared_entities"], + namespace=namespace, + ) + + if relationship and relationship.strength >= self.thresholds["min_strength"]: + relationships.append(relationship) + + # Step 4: Save relationships + if relationships: + self._save_relationships(session, relationships) + + self.stats["memories_analyzed"] += 1 + self.stats["relationships_created"] += len(relationships) + + logger.info( + f"Created {len(relationships)} relationships for memory {memory_id[:8]}..." + ) + + return relationships + + def _find_candidate_memories( + self, + session: Any, + source_entities: set, + exclude_memory_id: str, + namespace: str, + limit: int, + ) -> List[Dict[str, Any]]: + """Find memories with overlapping entities""" + + from sqlalchemy import func, text + + # Build entity filter + entity_list = list(source_entities) + entity_placeholders = ",".join([f":entity_{i}" for i in range(len(entity_list))]) + + # Some dialects (PostgreSQL) use string_agg instead of GROUP_CONCAT + dialect = getattr(session.bind.dialect, "name", "sqlite") + if dialect == "postgresql": + aggregate_expr = "string_agg(e.normalized_value, ',' ORDER BY e.normalized_value)" + else: + aggregate_expr = "GROUP_CONCAT(e.normalized_value, ',')" + + count_expr = "COUNT(DISTINCT e.normalized_value)" + + query = f""" + SELECT + e.memory_id, + e.memory_type, + {aggregate_expr} as shared_entities, + {count_expr} as overlap_count + FROM memory_entities e + WHERE e.namespace = :namespace + AND e.memory_id != :exclude_memory_id + AND e.normalized_value IN ({entity_placeholders}) + GROUP BY e.memory_id, e.memory_type + HAVING {count_expr} >= :min_overlap + ORDER BY overlap_count DESC + LIMIT :limit_val + """ + + # Build params dict with named parameters + params = { + "namespace": namespace, + "exclude_memory_id": exclude_memory_id, + "min_overlap": self.thresholds["min_entity_overlap"], + "limit_val": limit, + } + + # Add entity parameters + for i, entity in enumerate(entity_list): + params[f"entity_{i}"] = entity + + result = session.execute(text(query), params) + + candidates = [] + for row in result: + candidates.append( + { + "memory_id": row[0], + "memory_type": row[1], + "shared_entities": row[2].split(",") if row[2] else [], + "overlap_count": row[3], + } + ) + + return candidates + + def _calculate_relationship( + self, + source_memory_id: str, + source_memory_type: str, + target_memory_id: str, + target_memory_type: str, + shared_entities: List[str], + namespace: str, + ) -> Optional[MemoryRelationshipDB]: + """ + Calculate relationship strength and type between two memories + + Uses a composite scoring approach: + - Entity overlap (primary signal) + - Temporal proximity (memories 
created around same time) + - Category match (memories in same category) + """ + + # Calculate entity overlap score (0-1). Cap at 3 shared entities so that + # even modest overlap produces a meaningful signal. + overlap_count = len(shared_entities) + entity_overlap_score = min(1.0, overlap_count / 3.0) + + # Baseline strength ensures that a single shared entity still produces a + # traversable edge once graph thresholds are applied. Remaining weight is + # driven by actual overlap so denser connections still rank higher. + base_strength = 0.2 + overlap_component = ( + entity_overlap_score * self.thresholds["entity_overlap_weight"] + ) + strength = min(1.0, base_strength + overlap_component) + + # TODO: incorporate temporal and category proximity when available. For + # now the baseline + overlap component keeps edges discoverable without + # inflating low-value links. + + # Determine relationship type based on shared entities + relationship_type = self._infer_relationship_type(shared_entities, overlap_count) + + # Create relationship + relationship = MemoryRelationshipDB( + relationship_id=str(uuid.uuid4()), + source_memory_id=source_memory_id, + target_memory_id=target_memory_id, + source_memory_type=source_memory_type, + target_memory_type=target_memory_type, + relationship_type=relationship_type, + strength=strength, + bidirectional=True, + namespace=namespace, + created_at=datetime.utcnow(), + shared_entity_count=overlap_count, + reasoning=f"Shared {overlap_count} entities: {', '.join(shared_entities[:3])}", + ) + + return relationship + + def _infer_relationship_type( + self, shared_entities: List[str], overlap_count: int + ) -> str: + """ + Infer relationship type based on shared entities and context + + For now, uses simple heuristics: + - High overlap (4+) → semantic_similarity + - Medium overlap (2-3) → related_entity + - Specific patterns → other types (future enhancement) + """ + + if overlap_count >= 4: + return RelationshipType.SEMANTIC_SIMILARITY.value + + # Default to related_entity for shared entities + return RelationshipType.RELATED_ENTITY.value + + def _save_relationships( + self, session: Any, relationships: List[MemoryRelationshipDB] + ) -> int: + """Save relationships to database, handling duplicates""" + + saved_count = 0 + + for rel in relationships: + # Check if relationship already exists (bidirectional check) + existing = ( + session.query(MemoryRelationshipDB) + .filter( + ( + (MemoryRelationshipDB.source_memory_id == rel.source_memory_id) + & (MemoryRelationshipDB.target_memory_id == rel.target_memory_id) + ) + | ( + (MemoryRelationshipDB.source_memory_id == rel.target_memory_id) + & (MemoryRelationshipDB.target_memory_id == rel.source_memory_id) + ) + ) + .first() + ) + + if existing: + # Update existing relationship + existing.strength = max(existing.strength, rel.strength) + existing.shared_entity_count = max( + existing.shared_entity_count, rel.shared_entity_count + ) + existing.last_strengthened = datetime.utcnow() + self.stats["relationships_updated"] += 1 + else: + # Create new relationship + session.add(rel) + saved_count += 1 + + session.commit() + return saved_count + + def backfill_relationships( + self, + namespace: str = "default", + batch_size: int = 100, + limit: Optional[int] = None, + ) -> Dict[str, int]: + """ + Backfill relationships for all existing memories + + Args: + namespace: Memory namespace + batch_size: Memories to process per batch + limit: Optional limit on total memories + + Returns: + Stats dict with counts + """ + 
logger.info(f"Starting relationship backfill for namespace={namespace}") + + stats = { + "memories_processed": 0, + "relationships_created": 0, + } + + with self.db_manager.get_session() as session: + # Get all memories with entities + from sqlalchemy import text + + query = """ + SELECT DISTINCT e.memory_id, e.memory_type + FROM memory_entities e + WHERE e.namespace = :namespace + """ + + if limit: + query += f" LIMIT {limit}" + + result = session.execute(text(query), {"namespace": namespace}) + memories = [{"memory_id": row[0], "memory_type": row[1]} for row in result] + + logger.info(f"Found {len(memories)} memories to process") + + # Process each memory + for i, memory in enumerate(memories): + try: + relationships = self.detect_relationships_for_memory( + memory_id=memory["memory_id"], + memory_type=memory["memory_type"], + namespace=namespace, + ) + + stats["relationships_created"] += len(relationships) + stats["memories_processed"] += 1 + + if (i + 1) % 10 == 0: + logger.info( + f"Backfill progress: {i+1}/{len(memories)} memories processed" + ) + + except Exception as e: + logger.error(f"Failed to process memory {memory['memory_id']}: {e}") + continue + + logger.success( + f"Backfill complete: {stats['memories_processed']} memories, " + f"{stats['relationships_created']} relationships created" + ) + + return stats + + def strengthen_relationship( + self, relationship_id: str, strength_increase: float = 0.1 + ) -> bool: + """ + Strengthen an existing relationship (e.g., when traversed/accessed) + + Args: + relationship_id: Relationship identifier + strength_increase: Amount to increase strength by + + Returns: + True if successful + """ + with self.db_manager.get_session() as session: + rel = ( + session.query(MemoryRelationshipDB) + .filter(MemoryRelationshipDB.relationship_id == relationship_id) + .first() + ) + + if rel: + rel.strength = min(1.0, rel.strength + strength_increase) + rel.access_count += 1 + rel.last_strengthened = datetime.utcnow() + session.commit() + return True + + return False + + def get_stats(self) -> Dict[str, Any]: + """Get detection statistics""" + return self.stats.copy() diff --git a/memori/utils/graph_hooks.py b/memori/utils/graph_hooks.py new file mode 100644 index 00000000..9b9a2dfc --- /dev/null +++ b/memori/utils/graph_hooks.py @@ -0,0 +1,226 @@ +""" +Graph Building Hooks +Automatic entity extraction and relationship detection when memories are stored +""" + +from typing import Any, Dict, Optional + +from loguru import logger + + +class MemoryStorageHook: + """ + Hook system for automatic graph building when memories are stored + + Usage: + hook = MemoryStorageHook( + entity_extractor=mem.entity_extractor, + relationship_detector=mem.relationship_detector, + enabled=True + ) + + # After storing a memory + hook.process_memory( + memory_id=memory_id, + memory_type="long_term", + content=content, + namespace=namespace + ) + """ + + def __init__( + self, + entity_extractor: Any = None, + relationship_detector: Any = None, + enabled: bool = True, + async_processing: bool = False, + ): + """ + Initialize storage hook + + Args: + entity_extractor: EntityExtractionService instance + relationship_detector: RelationshipDetectionService instance + enabled: Enable automatic graph building + async_processing: Process in background (future enhancement) + """ + self.entity_extractor = entity_extractor + self.relationship_detector = relationship_detector + self.enabled = enabled + self.async_processing = async_processing + + # Statistics + self.stats = { + 
"memories_processed": 0, + "entities_extracted": 0, + "relationships_created": 0, + "errors": 0, + } + + logger.info( + f"MemoryStorageHook initialized (enabled={enabled}, " + f"async={async_processing})" + ) + + def process_memory( + self, + memory_id: str, + memory_type: str, + content: str, + namespace: str = "default", + db_session: Any = None, + ) -> Dict[str, int]: + """ + Process a newly stored memory to build graph + + Args: + memory_id: Memory identifier + memory_type: 'short_term' or 'long_term' + content: Memory content text + namespace: Memory namespace + db_session: Optional database session for batch operations + + Returns: + Stats dict with counts + """ + if not self.enabled: + return {"entities": 0, "relationships": 0} + + if not self.entity_extractor or not self.relationship_detector: + logger.warning("Graph components not initialized, skipping graph building") + return {"entities": 0, "relationships": 0} + + stats = {"entities": 0, "relationships": 0} + + try: + # Step 1: Extract entities + logger.debug(f"Extracting entities for memory {memory_id[:8]}...") + entities = self.entity_extractor.extract_entities( + memory_id=memory_id, + memory_type=memory_type, + content=content, + namespace=namespace, + ) + + # Step 2: Save entities + if entities and db_session: + saved = self.entity_extractor.save_entities(entities, db_session) + stats["entities"] = saved + self.stats["entities_extracted"] += saved + elif entities: + # Need to create session + logger.warning("No db_session provided, entities not saved") + + # Step 3: Detect relationships + logger.debug(f"Detecting relationships for memory {memory_id[:8]}...") + relationships = self.relationship_detector.detect_relationships_for_memory( + memory_id=memory_id, + memory_type=memory_type, + namespace=namespace, + max_candidates=50, + ) + + stats["relationships"] = len(relationships) + self.stats["relationships_created"] += len(relationships) + self.stats["memories_processed"] += 1 + + logger.info( + f"Graph built for {memory_id[:8]}: " + f"{stats['entities']} entities, {stats['relationships']} relationships" + ) + + return stats + + except Exception as e: + logger.error(f"Graph building failed for {memory_id}: {e}") + self.stats["errors"] += 1 + return {"entities": 0, "relationships": 0} + + def process_memory_batch( + self, + memories: list[Dict[str, Any]], + namespace: str = "default", + ) -> Dict[str, int]: + """ + Process multiple memories in batch + + Args: + memories: List of memory dicts with 'memory_id', 'content', 'memory_type' + namespace: Memory namespace + + Returns: + Aggregate stats + """ + if not self.enabled: + return {"total_entities": 0, "total_relationships": 0} + + total_stats = {"total_entities": 0, "total_relationships": 0} + + for memory in memories: + stats = self.process_memory( + memory_id=memory.get("memory_id"), + memory_type=memory.get("memory_type", "long_term"), + content=memory.get("content", ""), + namespace=namespace, + ) + + total_stats["total_entities"] += stats.get("entities", 0) + total_stats["total_relationships"] += stats.get("relationships", 0) + + logger.info( + f"Batch processing complete: {len(memories)} memories, " + f"{total_stats['total_entities']} entities, " + f"{total_stats['total_relationships']} relationships" + ) + + return total_stats + + def get_stats(self) -> Dict[str, int]: + """Get hook statistics""" + return self.stats.copy() + + def reset_stats(self): + """Reset statistics""" + self.stats = { + "memories_processed": 0, + "entities_extracted": 0, + 
"relationships_created": 0, + "errors": 0, + } + + def enable(self): + """Enable automatic graph building""" + self.enabled = True + logger.info("MemoryStorageHook enabled") + + def disable(self): + """Disable automatic graph building""" + self.enabled = False + logger.info("MemoryStorageHook disabled") + + +def create_storage_hook(memori_instance: Any, enabled: bool = True) -> MemoryStorageHook: + """ + Create a storage hook from Memori instance + + Args: + memori_instance: Memori class instance + enabled: Enable automatic processing + + Returns: + Configured MemoryStorageHook + + Example: + hook = create_storage_hook(mem) + + # Use in storage flow + memory_id = mem.store_memory(content) + hook.process_memory(memory_id, "long_term", content, namespace) + """ + hook = MemoryStorageHook( + entity_extractor=getattr(memori_instance, "entity_extractor", None), + relationship_detector=getattr(memori_instance, "relationship_detector", None), + enabled=enabled, + ) + + return hook diff --git a/memori/utils/pydantic_models.py b/memori/utils/pydantic_models.py index c83fd0dc..20273a66 100644 --- a/memori/utils/pydantic_models.py +++ b/memori/utils/pydantic_models.py @@ -58,6 +58,42 @@ class EntityType(str, Enum): keyword = "keyword" +class SearchStrategy(str, Enum): + """Graph-based search strategies for memory retrieval""" + + TEXT_ONLY = "text_only" # Traditional text search only (~30ms) + ENTITY_FIRST = "entity_first" # Search by entity tags first (~100ms) + GRAPH_EXPANSION_1HOP = "graph_expansion_1hop" # 1-hop graph traversal (~150ms) + GRAPH_EXPANSION_2HOP = "graph_expansion_2hop" # 2-hop graph traversal (~300ms) + GRAPH_WALK_CONTEXTUAL = "graph_walk_contextual" # Walk relationships (~350ms) + ENTITY_CLUSTER_DISCOVERY = "entity_cluster_discovery" # Find entity clusters (~200ms) + CATEGORY_FOCUSED_GRAPH = "category_focused_graph" # Category + graph (~180ms) + + +class RelationshipType(str, Enum): + """Types of relationships between memories in the graph""" + + SEMANTIC_SIMILARITY = "semantic_similarity" # Similar topics/concepts + CAUSALITY = "causality" # Cause and effect + REFERENCE = "reference" # One references the other + ELABORATION = "elaboration" # Provides more detail + CONTRADICTION = "contradiction" # Conflicting information + SUPPORTS = "supports" # Reinforces/validates + PREREQUISITE = "prerequisite" # Required knowledge + TEMPORAL = "temporal" # Time-based relationship + RELATED_ENTITY = "related_entity" # Share entities + TOPIC_CONTINUATION = "topic_continuation" # Continue same topic + + +class ExpansionStrategy(str, Enum): + """Graph traversal strategies""" + + BREADTH_FIRST = "breadth_first" # BFS traversal + DEPTH_FIRST = "depth_first" # DFS traversal + STRONGEST_FIRST = "strongest_first" # Follow strongest relationships + ENTITY_GUIDED = "entity_guided" # Follow entity overlaps + + # Define constrained types using Annotated ConfidenceScore = Annotated[float, Field(ge=0.0, le=1.0)] ImportanceScore = Annotated[float, Field(ge=0.0, le=1.0)] @@ -142,8 +178,114 @@ class MemoryImportance(BaseModel): ) +class GraphExpansionConfig(BaseModel): + """Configuration for graph-based memory expansion""" + + enabled: bool = Field(default=False, description="Enable graph expansion") + hop_distance: int = Field( + default=1, + ge=0, + le=3, + description="Number of hops to traverse (0-3)" + ) + min_relationship_strength: float = Field( + default=0.2, + ge=0.0, + le=1.0, + description="Minimum relationship strength threshold" + ) + expansion_strategy: ExpansionStrategy = Field( + 
default=ExpansionStrategy.BREADTH_FIRST, + description="Graph traversal strategy" + ) + relationship_type_filters: list[RelationshipType] | None = Field( + default=None, + description="Filter by specific relationship types" + ) + require_entity_overlap: bool = Field( + default=False, + description="Require shared entities for traversal" + ) + max_results_per_hop: int = Field( + default=10, + ge=1, + le=50, + description="Maximum results to return per hop level" + ) + + +class ScoringWeights(BaseModel): + """Weights for composite scoring in graph search""" + + text_relevance: float = Field( + default=0.35, + ge=0.0, + le=1.0, + description="Weight for text match relevance" + ) + entity_overlap: float = Field( + default=0.25, + ge=0.0, + le=1.0, + description="Weight for shared entity overlap" + ) + graph_strength: float = Field( + default=0.20, + ge=0.0, + le=1.0, + description="Weight for relationship strength" + ) + importance: float = Field( + default=0.15, + ge=0.0, + le=1.0, + description="Weight for memory importance" + ) + recency: float = Field( + default=0.05, + ge=0.0, + le=1.0, + description="Weight for temporal recency" + ) + + def model_post_init(self, __context) -> None: + """Normalize weights to sum to 1.0""" + total = ( + self.text_relevance + + self.entity_overlap + + self.graph_strength + + self.importance + + self.recency + ) + + # If weights don't sum to ~1.0, normalize them + if not (0.99 <= total <= 1.01): + # Avoid division by zero + if total > 0: + # Normalize all weights proportionally + self.text_relevance /= total + self.entity_overlap /= total + self.graph_strength /= total + self.importance /= total + self.recency /= total + + from loguru import logger + logger.debug( + f"Normalized scoring weights from {total:.3f} to 1.0: " + f"text={self.text_relevance:.3f}, entity={self.entity_overlap:.3f}, " + f"graph={self.graph_strength:.3f}, importance={self.importance:.3f}, recency={self.recency:.3f}" + ) + else: + # If all weights are 0, use defaults + self.text_relevance = 0.35 + self.entity_overlap = 0.25 + self.graph_strength = 0.20 + self.importance = 0.15 + self.recency = 0.05 + + class MemorySearchQuery(BaseModel): - """Structured query for memory search""" + """Structured query for memory search with graph-aware capabilities""" # Query components query_text: str = Field(description="Original query text") @@ -163,13 +305,133 @@ class MemorySearchQuery(BaseModel): default=0.0, description="Minimum importance score" ) - # Search strategy - search_strategy: list[str] = Field( - default_factory=list, description="Recommended search strategies" + # Graph-based search parameters + search_strategy: SearchStrategy = Field( + default=SearchStrategy.TEXT_ONLY, + description="Primary search strategy to use" + ) + graph_expansion: GraphExpansionConfig = Field( + default_factory=GraphExpansionConfig, + description="Graph expansion configuration" + ) + scoring_weights: ScoringWeights = Field( + default_factory=ScoringWeights, + description="Composite scoring weights" ) + + # Result preferences expected_result_types: list[str] = Field( default_factory=list, description="Expected types of results" ) + max_results: int = Field( + default=10, + ge=1, + le=100, + description="Maximum number of results to return" + ) + include_graph_metadata: bool = Field( + default=True, + description="Include graph traversal metadata in results" + ) + + +class GraphTraversalPath(BaseModel): + """Represents a path through the memory graph""" + + memory_ids: list[str] = Field(description="Ordered 
list of memory IDs in the path") + relationship_types: list[RelationshipType] = Field( + description="Types of relationships in the path" + ) + total_strength: float = Field( + ge=0.0, + le=1.0, + description="Combined strength of relationships in path" + ) + hop_count: int = Field(ge=0, description="Number of hops in this path") + + +class GraphSearchResult(BaseModel): + """Enhanced search result with graph metadata""" + + # Core memory data + memory_id: str = Field(description="Unique memory identifier") + content: str = Field(description="Memory content") + summary: str = Field(description="Memory summary") + category: MemoryCategoryType | None = Field( + default=None, + description="Memory category" + ) + + # Scoring components + composite_score: float = Field( + ge=0.0, + le=1.0, + description="Final composite relevance score" + ) + text_relevance_score: float = Field( + default=0.0, + ge=0.0, + le=1.0, + description="Text match relevance" + ) + entity_overlap_score: float = Field( + default=0.0, + ge=0.0, + le=1.0, + description="Entity overlap score" + ) + graph_strength_score: float = Field( + default=0.0, + ge=0.0, + le=1.0, + description="Graph relationship strength" + ) + importance_score: float = Field( + default=0.5, + ge=0.0, + le=1.0, + description="Memory importance" + ) + recency_score: float = Field( + default=0.5, + ge=0.0, + le=1.0, + description="Temporal recency" + ) + + # Graph metadata + hop_distance: int = Field( + default=0, + ge=0, + description="Distance from seed memory (0 = direct match)" + ) + shared_entities: list[str] = Field( + default_factory=list, + description="Entities shared with query or seed memories" + ) + graph_paths: list[GraphTraversalPath] = Field( + default_factory=list, + description="Paths through graph to this result" + ) + connected_via: list[str] = Field( + default_factory=list, + description="Memory IDs this result is connected through" + ) + + # Match explanation + match_reason: str = Field( + default="", + description="Human-readable explanation of why this matched" + ) + relationship_summary: str = Field( + default="", + description="Summary of graph relationships" + ) + + # Metadata + timestamp: datetime | None = Field(default=None, description="Memory creation time") + access_count: int = Field(default=0, description="Number of times accessed") + last_accessed: datetime | None = Field(default=None, description="Last access time") class MemoryRelationship(BaseModel):