From bdae708bcd20cbfb0b2173d168d9718589258e6a Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E6=9C=B1=E5=A4=A7=E6=B4=8B?= <714403855@qq.com>
Date: Sun, 30 Nov 2025 00:28:18 +0800
Subject: [PATCH] delete_node_by_prams for filter && simple support

---
 src/memos/graph_dbs/neo4j.py   | 116 +++++++++++++++++++--
 src/memos/graph_dbs/polardb.py | 179 ++++++++++++++++++++++++++++++---
 2 files changed, 273 insertions(+), 22 deletions(-)

diff --git a/src/memos/graph_dbs/neo4j.py b/src/memos/graph_dbs/neo4j.py
index e934d3a1..5ba1f116 100644
--- a/src/memos/graph_dbs/neo4j.py
+++ b/src/memos/graph_dbs/neo4j.py
@@ -812,6 +812,7 @@ def get_by_metadata(
         user_name: str | None = None,
         filter: dict | None = None,
         knowledgebase_ids: list[str] | None = None,
+        user_name_flag: bool = True,
     ) -> list[str]:
         """
         TODO:
@@ -876,11 +877,19 @@
             raise ValueError(f"Unsupported operator: {op}")
 
         # Build user_name filter with knowledgebase_ids support (OR relationship) using common method
-        user_name_conditions, user_name_params = self._build_user_name_and_kb_ids_conditions_cypher(
-            user_name=user_name,
-            knowledgebase_ids=knowledgebase_ids,
-            default_user_name=self.config.user_name,
-            node_alias="n",
+        user_name_conditions = []
+        user_name_params = {}
+        if user_name_flag:
+            user_name_conditions, user_name_params = (
+                self._build_user_name_and_kb_ids_conditions_cypher(
+                    user_name=user_name,
+                    knowledgebase_ids=knowledgebase_ids,
+                    default_user_name=self.config.user_name,
+                    node_alias="n",
+                )
+            )
+        logger.debug(
+            f"[get_by_metadata] user_name_conditions: {user_name_conditions}, user_name_params: {user_name_params}"
         )
 
         # Add user_name WHERE clause
@@ -1425,7 +1434,7 @@ def build_filter_condition(condition_dict: dict, param_counter: list) -> tuple[s
                     # Use datetime() function for date comparisons
                     if key in ("created_at", "updated_at") or key.endswith("_at"):
                         condition_parts.append(
-                            f"{node_alias}.{key} {cypher_op} datetime(${param_name})"
+                            f"datetime({node_alias}.{key}) {cypher_op} datetime(${param_name})"
                         )
                     else:
                         condition_parts.append(
@@ -1482,6 +1491,12 @@ def build_filter_condition(condition_dict: dict, param_counter: list) -> tuple[s
                 if condition_str:
                     filter_conditions.append(f"({condition_str})")
                     filter_params.update(params)
+        else:
+            # Handle simple dict without "and" or "or" (e.g., {"id": "xxx"})
+            condition_str, params = build_filter_condition(filter, param_counter)
+            if condition_str:
+                filter_conditions.append(condition_str)
+                filter_params.update(params)
 
         return filter_conditions, filter_params
 
@@ -1505,3 +1520,92 @@ def _parse_node(self, node_data: dict[str, Any]) -> dict[str, Any]:
                 break
             node["sources"][idx] = json.loads(node["sources"][idx])
         return {"id": node.pop("id"), "memory": node.pop("memory", ""), "metadata": node}
+
+    def delete_node_by_prams(
+        self,
+        memory_ids: list[str] | None = None,
+        file_ids: list[str] | None = None,
+        filter: dict | None = None,
+    ) -> int:
+        """
+        Delete nodes by memory_ids, file_ids, or filter.
+
+        Args:
+            memory_ids (list[str], optional): List of memory node IDs to delete.
+            file_ids (list[str], optional): List of file node IDs to delete.
+            filter (dict, optional): Filter dictionary to query matching nodes for deletion.
+
+        Returns:
+            int: Number of nodes deleted.
+ """ + # Collect all node IDs to delete + ids_to_delete = set() + + # Add memory_ids if provided + if memory_ids and len(memory_ids) > 0: + ids_to_delete.update(memory_ids) + + # Add file_ids if provided (treating them as node IDs) + if file_ids and len(file_ids) > 0: + ids_to_delete.update(file_ids) + + # Query nodes by filter if provided + if filter: + # Use get_by_metadata with empty filters list and filter + filter_ids = self.get_by_metadata( + filters=[], + user_name=None, + filter=filter, + knowledgebase_ids=None, + user_name_flag=False, + ) + ids_to_delete.update(filter_ids) + + # If no IDs to delete, return 0 + if not ids_to_delete: + logger.warning("[delete_node_by_prams] No nodes to delete") + return 0 + + # Convert to list for easier handling + ids_list = list(ids_to_delete) + logger.info(f"[delete_node_by_prams] Deleting {len(ids_list)} nodes: {ids_list}") + + # Build WHERE condition for collected IDs (query n.id) + ids_where = "n.id IN $ids_to_delete" + params = {"ids_to_delete": ids_list} + + # Calculate total count for logging + total_count = len(ids_list) + logger.info( + f"[delete_node_by_prams] Deleting nodes - memory_ids: {memory_ids}, file_ids: {file_ids}, filter: {filter}" + ) + print( + f"[delete_node_by_prams] Deleting {total_count} nodes - memory_ids: {memory_ids}, file_ids: {file_ids}, filter: {filter}" + ) + + # First count matching nodes to get accurate count + count_query = f"MATCH (n:Memory) WHERE {ids_where} RETURN count(n) AS node_count" + logger.info(f"[delete_node_by_prams] count_query: {count_query}") + print(f"[delete_node_by_prams] count_query: {count_query}") + + # Then delete nodes + delete_query = f"MATCH (n:Memory) WHERE {ids_where} DETACH DELETE n" + logger.info(f"[delete_node_by_prams] delete_query: {delete_query}") + print(f"[delete_node_by_prams] delete_query: {delete_query}") + + deleted_count = 0 + try: + with self.driver.session(database=self.db_name) as session: + # Count nodes before deletion + count_result = session.run(count_query, **params) + count_record = count_result.single() + expected_count = total_count + if count_record: + expected_count = count_record["node_count"] or total_count + + # Delete nodes + session.run(delete_query, **params) + # Use the count from before deletion as the actual deleted count + deleted_count = expected_count + + except Exception as e: + logger.error(f"[delete_node_by_prams] Failed to delete nodes: {e}", exc_info=True) + raise + + logger.info(f"[delete_node_by_prams] Successfully deleted {deleted_count} nodes") + return deleted_count diff --git a/src/memos/graph_dbs/polardb.py b/src/memos/graph_dbs/polardb.py index a7e60704..bfde8c80 100644 --- a/src/memos/graph_dbs/polardb.py +++ b/src/memos/graph_dbs/polardb.py @@ -1619,6 +1619,7 @@ def get_by_metadata( user_name: str | None = None, filter: dict | None = None, knowledgebase_ids: list | None = None, + user_name_flag: bool = True, ) -> list[str]: """ Retrieve node IDs that match given metadata filters. 
@@ -1693,11 +1694,14 @@
             raise ValueError(f"Unsupported operator: {op}")
 
         # Build user_name filter with knowledgebase_ids support (OR relationship) using common method
-        user_name_conditions = self._build_user_name_and_kb_ids_conditions_cypher(
-            user_name=user_name,
-            knowledgebase_ids=knowledgebase_ids,
-            default_user_name=self._get_config_value("user_name"),
-        )
+        user_name_conditions = []
+        if user_name_flag:
+            user_name_conditions = self._build_user_name_and_kb_ids_conditions_cypher(
+                user_name=user_name,
+                knowledgebase_ids=knowledgebase_ids,
+                default_user_name=self._get_config_value("user_name"),
+            )
+        logger.debug(f"[get_by_metadata] user_name_conditions: {user_name_conditions}")
 
         # Add user_name WHERE clause
         if user_name_conditions:
@@ -1709,16 +1713,26 @@
         # Build filter conditions using common method
         filter_where_clause = self._build_filter_conditions_cypher(filter)
 
-        where_str = " AND ".join(where_conditions) + filter_where_clause
+        # Build WHERE clause: if where_conditions is empty, filter_where_clause should not have " AND " prefix
+        if where_conditions:
+            where_str = " AND ".join(where_conditions) + filter_where_clause
+        else:
+            # If no other conditions, remove " AND " prefix from filter_where_clause if present
+            if filter_where_clause.startswith(" AND "):
+                where_str = filter_where_clause[5:]  # Remove " AND " prefix
+            else:
+                where_str = filter_where_clause
 
         # Use cypher query
+        # Only include WHERE clause if where_str is not empty
+        where_clause = f"WHERE {where_str}" if where_str else ""
         cypher_query = f"""
-            SELECT * FROM cypher('{self.db_name}_graph', $$
-                MATCH (n:Memory)
-                WHERE {where_str}
-                RETURN n.id AS id
-            $$) AS (id agtype)
-        """
+        SELECT * FROM cypher('{self.db_name}_graph', $$
+            MATCH (n:Memory)
+            {where_clause}
+            RETURN n.id AS id
+        $$) AS (id agtype)
+        """
 
         ids = []
         conn = self._get_connection()
@@ -3253,6 +3267,7 @@ def _build_user_name_and_kb_ids_conditions_cypher(
         """
         user_name_conditions = []
         effective_user_name = user_name if user_name else default_user_name
+        logger.debug(f"[_build_user_name_and_kb_ids_conditions_cypher] effective_user_name: {effective_user_name}")
 
         if effective_user_name:
             escaped_user_name = effective_user_name.replace("'", "''")
@@ -3505,6 +3520,11 @@ def build_cypher_filter_condition(condition_dict: dict) -> str:
                     and_conditions.append(f"({condition_str})")
             if and_conditions:
                 filter_where_clause = " AND " + " AND ".join(and_conditions)
+        else:
+            # Handle simple dict without "and" or "or" (e.g., {"id": "xxx"})
+            condition_str = build_cypher_filter_condition(filter)
+            if condition_str:
+                filter_where_clause = " AND " + condition_str
 
         return filter_where_clause
 
@@ -3654,11 +3674,11 @@ def build_filter_condition(condition_dict: dict) -> str:
                     if isinstance(op_value, str):
                         escaped_value = escape_sql_string(op_value)
                         condition_parts.append(
-                            f"ag_catalog.agtype_access_operator(VARIADIC ARRAY[properties, '\"info\"'::ag_catalog.agtype, '\"{info_field}\"'::ag_catalog.agtype) @> '\"{escaped_value}\"'::agtype"
+                            f"ag_catalog.agtype_access_operator(VARIADIC ARRAY[properties, '\"info\"'::ag_catalog.agtype, '\"{info_field}\"'::ag_catalog.agtype]) @> '\"{escaped_value}\"'::agtype"
                         )
                     else:
                         condition_parts.append(
-                            f"ag_catalog.agtype_access_operator(VARIADIC ARRAY[properties, '\"info\"'::ag_catalog.agtype, '\"{info_field}\"'::ag_catalog.agtype) @> {op_value}::agtype"
+                            f"ag_catalog.agtype_access_operator(VARIADIC ARRAY[properties, '\"info\"'::ag_catalog.agtype, '\"{info_field}\"'::ag_catalog.agtype]) @> {op_value}::agtype"
                         )
                 else:
                     # Direct property access
@@ -3684,11 +3704,11 @@ def build_filter_condition(condition_dict: dict) -> str:
                             .replace("_", "\\_")
                         )
                         condition_parts.append(
-                            f"ag_catalog.agtype_access_operator(VARIADIC ARRAY[properties, '\"info\"'::ag_catalog.agtype, '\"{info_field}\"'::ag_catalog.agtype)::text LIKE '%{escaped_value}%'"
+                            f"ag_catalog.agtype_access_operator(VARIADIC ARRAY[properties, '\"info\"'::ag_catalog.agtype, '\"{info_field}\"'::ag_catalog.agtype])::text LIKE '%{escaped_value}%'"
                         )
                     else:
                         condition_parts.append(
-                            f"ag_catalog.agtype_access_operator(VARIADIC ARRAY[properties, '\"info\"'::ag_catalog.agtype, '\"{info_field}\"'::ag_catalog.agtype)::text LIKE '%{op_value}%'"
+                            f"ag_catalog.agtype_access_operator(VARIADIC ARRAY[properties, '\"info\"'::ag_catalog.agtype, '\"{info_field}\"'::ag_catalog.agtype])::text LIKE '%{op_value}%'"
                         )
                 else:
                     # Direct property access
@@ -3752,6 +3772,11 @@ def build_filter_condition(condition_dict: dict) -> str:
                 condition_str = build_filter_condition(condition)
                 if condition_str:
                     filter_conditions.append(f"({condition_str})")
+        else:
+            # Handle simple dict without "and" or "or" (e.g., {"id": "xxx"})
+            condition_str = build_filter_condition(filter)
+            if condition_str:
+                filter_conditions.append(condition_str)
 
         return filter_conditions
 
@@ -3823,3 +3848,125 @@ def process_condition(condition):
                 return new_condition
 
         return process_condition(filter_dict)
+
+    @timed
+    def delete_node_by_prams(
+        self,
+        memory_ids: list[str] | None = None,
+        file_ids: list[str] | None = None,
+        filter: dict | None = None,
+    ) -> int:
+        """
+        Delete nodes by memory_ids, file_ids, or filter.
+
+        Args:
+            memory_ids (list[str], optional): List of memory node IDs to delete.
+            file_ids (list[str], optional): List of file node IDs to delete.
+            filter (dict, optional): Filter dictionary to query matching nodes for deletion.
+
+        Returns:
+            int: Number of nodes deleted.
+        """
+        # Collect all node IDs to delete
+        ids_to_delete = set()
+
+        # Add memory_ids if provided
+        if memory_ids:
+            ids_to_delete.update(memory_ids)
+
+        # Add file_ids if provided (treating them as node IDs)
+        if file_ids:
+            ids_to_delete.update(file_ids)
+
+        # Query nodes by filter if provided
+        if filter:
+            # Parse filter to validate and transform field names (e.g., add "info." prefix if needed)
+            parsed_filter = self.parse_filter(filter)
+            if parsed_filter:
+                # Use get_by_metadata with empty filters list and parsed filter
+                filter_ids = self.get_by_metadata(
+                    filters=[],
+                    user_name=None,
+                    filter=parsed_filter,
+                    knowledgebase_ids=None,
+                    user_name_flag=False,
+                )
+                ids_to_delete.update(filter_ids)
+            else:
+                logger.warning(
+                    "[delete_node_by_prams] Filter parsed to None, skipping filter query"
+                )
+
+        # If no IDs to delete, return 0
+        if not ids_to_delete:
+            logger.warning("[delete_node_by_prams] No nodes to delete")
+            return 0
+
+        # Convert to list for easier handling
+        ids_list = list(ids_to_delete)
+        logger.info(f"[delete_node_by_prams] Deleting {len(ids_list)} nodes: {ids_list}")
+
+        # Build WHERE condition for collected IDs (query n.id)
+        id_conditions = []
+        for node_id in ids_list:
+            # Escape single quotes in node IDs
+            escaped_id = str(node_id).replace("'", "\\'")
+            id_conditions.append(f"'{escaped_id}'")
+
+        # Build WHERE clause for IDs
+        ids_where = f"n.id IN [{', '.join(id_conditions)}]"
+
+        # Use Cypher DELETE query
+        # First count matching nodes to get accurate count
+        count_query = f"""
+        SELECT * FROM cypher('{self.db_name}_graph', $$
+            MATCH (n:Memory)
+            WHERE {ids_where}
+            RETURN count(n) AS node_count
+        $$) AS (node_count agtype)
+        """
+        logger.info(f"[delete_node_by_prams] count_query: {count_query}")
+
+        # Then delete nodes
+        delete_query = f"""
+        SELECT * FROM cypher('{self.db_name}_graph', $$
+            MATCH (n:Memory)
+            WHERE {ids_where}
+            DETACH DELETE n
+        $$) AS (result agtype)
+        """
+
+        # Total requested deletions (fallback if the count query returns nothing)
+        total_count = len(ids_list)
+        logger.info(
+            f"[delete_node_by_prams] Deleting nodes - memory_ids: {memory_ids}, file_ids: {file_ids}, filter: {filter}"
+        )
+        logger.info(f"[delete_node_by_prams] delete_query: {delete_query}")
+
+        conn = self._get_connection()
+        deleted_count = 0
+        try:
+            with conn.cursor() as cursor:
+                # Count nodes before deletion
+                cursor.execute(count_query)
+                count_results = cursor.fetchall()
+                expected_count = total_count
+                if count_results:
+                    count_str = str(count_results[0][0])
+                    count_str = count_str.strip('"').strip("'")
+                    expected_count = int(count_str) if count_str.isdigit() else total_count
+
+                # Delete nodes
+                cursor.execute(delete_query)
+                # Use the count from before deletion as the actual deleted count
+                deleted_count = expected_count
+            conn.commit()
+        except Exception as e:
+            logger.error(f"[delete_node_by_prams] Failed to delete nodes: {e}", exc_info=True)
+            conn.rollback()
+            raise
+        finally:
+            self._return_connection(conn)
+
+        logger.info(f"[delete_node_by_prams] Successfully deleted {deleted_count} nodes")
+        return deleted_count
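
Reviewer usage sketch (not part of the patch). A minimal example of how the new
delete_node_by_prams API might be called; the `graph` object and its
construction are assumptions for illustration, and only the filter shapes the
patch itself exercises ({"id": ...} plus the existing {"and": [...]} /
{"or": [...]} forms) are shown:

    # `graph` is assumed to be an already-configured Neo4j- or PolarDB-backed
    # store exposing the delete_node_by_prams method added by this patch.

    # 1) Delete by explicit node IDs; memory_ids and file_ids are unioned
    #    into one ID set before a single DETACH DELETE is issued.
    deleted = graph.delete_node_by_prams(memory_ids=["mem-001"], file_ids=["file-001"])

    # 2) Delete by a simple flat filter -- the new "simple" support: a dict
    #    without a top-level "and"/"or" key is now handled by the filter builders.
    deleted = graph.delete_node_by_prams(filter={"id": "mem-002"})

    # 3) Compound filters still work; the ID lookup calls get_by_metadata with
    #    user_name_flag=False, so no user_name scoping is applied to the match.
    deleted = graph.delete_node_by_prams(filter={"or": [{"id": "mem-003"}, {"id": "mem-004"}]})
    print(f"deleted {deleted} nodes")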
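
Generated-query illustration (derived from the diff; the ID values are
hypothetical). For ids_list = ["mem-001", "mem-002"], the PolarDB path builds
and runs:

    SELECT * FROM cypher('<db_name>_graph', $$
        MATCH (n:Memory)
        WHERE n.id IN ['mem-001', 'mem-002']
        DETACH DELETE n
    $$) AS (result agtype)

while the Neo4j path parameterizes the same match as
"MATCH (n:Memory) WHERE n.id IN $ids_to_delete DETACH DELETE n". Note that the
two backends differ in atomicity: the PolarDB version runs the count and the
delete on one connection with an explicit commit/rollback, whereas the Neo4j
version issues two separate session.run calls.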