Skip to content

Commit 18f1655

Browse files
authored
delete_node_by_prams for filter && simple support (#558)
1 parent 7d16794 commit 18f1655

File tree

2 files changed

+283
-22
lines changed

2 files changed

+283
-22
lines changed

src/memos/graph_dbs/neo4j.py

Lines changed: 115 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -812,6 +812,7 @@ def get_by_metadata(
812812
user_name: str | None = None,
813813
filter: dict | None = None,
814814
knowledgebase_ids: list[str] | None = None,
815+
user_name_flag: bool = True,
815816
) -> list[str]:
816817
"""
817818
TODO:
@@ -876,11 +877,19 @@ def get_by_metadata(
876877
raise ValueError(f"Unsupported operator: {op}")
877878

878879
# Build user_name filter with knowledgebase_ids support (OR relationship) using common method
879-
user_name_conditions, user_name_params = self._build_user_name_and_kb_ids_conditions_cypher(
880-
user_name=user_name,
881-
knowledgebase_ids=knowledgebase_ids,
882-
default_user_name=self.config.user_name,
883-
node_alias="n",
880+
user_name_conditions = []
881+
user_name_params = {}
882+
if user_name_flag:
883+
user_name_conditions, user_name_params = (
884+
self._build_user_name_and_kb_ids_conditions_cypher(
885+
user_name=user_name,
886+
knowledgebase_ids=knowledgebase_ids,
887+
default_user_name=self.config.user_name,
888+
node_alias="n",
889+
)
890+
)
891+
print(
892+
f"[get_by_metadata] user_name_conditions: {user_name_conditions},user_name_params: {user_name_params}"
884893
)
885894

886895
# Add user_name WHERE clause
@@ -1425,7 +1434,7 @@ def build_filter_condition(condition_dict: dict, param_counter: list) -> tuple[s
14251434
# Use datetime() function for date comparisons
14261435
if key in ("created_at", "updated_at") or key.endswith("_at"):
14271436
condition_parts.append(
1428-
f"{node_alias}.{key} {cypher_op} datetime(${param_name})"
1437+
f"datetime({node_alias}.{key}) {cypher_op} datetime(${param_name})"
14291438
)
14301439
else:
14311440
condition_parts.append(
@@ -1482,6 +1491,12 @@ def build_filter_condition(condition_dict: dict, param_counter: list) -> tuple[s
14821491
if condition_str:
14831492
filter_conditions.append(f"({condition_str})")
14841493
filter_params.update(params)
1494+
else:
1495+
# Handle simple dict without "and" or "or" (e.g., {"id": "xxx"})
1496+
condition_str, params = build_filter_condition(filter, param_counter)
1497+
if condition_str:
1498+
filter_conditions.append(condition_str)
1499+
filter_params.update(params)
14851500

14861501
return filter_conditions, filter_params
14871502

@@ -1505,3 +1520,97 @@ def _parse_node(self, node_data: dict[str, Any]) -> dict[str, Any]:
15051520
break
15061521
node["sources"][idx] = json.loads(node["sources"][idx])
15071522
return {"id": node.pop("id"), "memory": node.pop("memory", ""), "metadata": node}
1523+
1524+
def delete_node_by_prams(
1525+
self,
1526+
memory_ids: list[str] | None = None,
1527+
file_ids: list[str] | None = None,
1528+
filter: dict | None = None,
1529+
) -> int:
1530+
"""
1531+
Delete nodes by memory_ids, file_ids, or filter.
1532+
1533+
Args:
1534+
memory_ids (list[str], optional): List of memory node IDs to delete.
1535+
file_ids (list[str], optional): List of file node IDs to delete.
1536+
filter (dict, optional): Filter dictionary to query matching nodes for deletion.
1537+
1538+
Returns:
1539+
int: Number of nodes deleted.
1540+
"""
1541+
# Collect all node IDs to delete
1542+
ids_to_delete = set()
1543+
1544+
# Add memory_ids if provided
1545+
if memory_ids and len(memory_ids) > 0:
1546+
ids_to_delete.update(memory_ids)
1547+
1548+
# Add file_ids if provided (treating them as node IDs)
1549+
if file_ids and len(file_ids) > 0:
1550+
ids_to_delete.update(file_ids)
1551+
1552+
# Query nodes by filter if provided
1553+
if filter:
1554+
# Use get_by_metadata with empty filters list and filter
1555+
filter_ids = self.get_by_metadata(
1556+
filters=[],
1557+
user_name=None,
1558+
filter=filter,
1559+
knowledgebase_ids=None,
1560+
user_name_flag=False,
1561+
)
1562+
ids_to_delete.update(filter_ids)
1563+
1564+
# If no IDs to delete, return 0
1565+
if not ids_to_delete:
1566+
logger.warning("[delete_node_by_prams] No nodes to delete")
1567+
return 0
1568+
1569+
# Convert to list for easier handling
1570+
ids_list = list(ids_to_delete)
1571+
logger.info(f"[delete_node_by_prams] Deleting {len(ids_list)} nodes: {ids_list}")
1572+
1573+
# Build WHERE condition for collected IDs (query n.id)
1574+
ids_where = "n.id IN $ids_to_delete"
1575+
params = {"ids_to_delete": ids_list}
1576+
1577+
# Calculate total count for logging
1578+
total_count = len(ids_list)
1579+
logger.info(
1580+
f"[delete_node_by_prams] Deleting nodes - memory_ids: {memory_ids}, file_ids: {file_ids}, filter: {filter}"
1581+
)
1582+
print(
1583+
f"[delete_node_by_prams] Deleting {total_count} nodes - memory_ids: {memory_ids}, file_ids: {file_ids}, filter: {filter}"
1584+
)
1585+
1586+
# First count matching nodes to get accurate count
1587+
count_query = f"MATCH (n:Memory) WHERE {ids_where} RETURN count(n) AS node_count"
1588+
logger.info(f"[delete_node_by_prams] count_query: {count_query}")
1589+
print(f"[delete_node_by_prams] count_query: {count_query}")
1590+
1591+
# Then delete nodes
1592+
delete_query = f"MATCH (n:Memory) WHERE {ids_where} DETACH DELETE n"
1593+
logger.info(f"[delete_node_by_prams] delete_query: {delete_query}")
1594+
print(f"[delete_node_by_prams] delete_query: {delete_query}")
1595+
1596+
deleted_count = 0
1597+
try:
1598+
with self.driver.session(database=self.db_name) as session:
1599+
# Count nodes before deletion
1600+
count_result = session.run(count_query, **params)
1601+
count_record = count_result.single()
1602+
expected_count = total_count
1603+
if count_record:
1604+
expected_count = count_record["node_count"] or total_count
1605+
1606+
# Delete nodes
1607+
session.run(delete_query, **params)
1608+
# Use the count from before deletion as the actual deleted count
1609+
deleted_count = expected_count
1610+
1611+
except Exception as e:
1612+
logger.error(f"[delete_node_by_prams] Failed to delete nodes: {e}", exc_info=True)
1613+
raise
1614+
1615+
logger.info(f"[delete_node_by_prams] Successfully deleted {deleted_count} nodes")
1616+
return deleted_count

0 commit comments

Comments
 (0)