diff --git a/.github/workflows/pytest_and_coverage.yml b/.github/workflows/pytest_and_coverage.yml index d3e11272..b6f94878 100644 --- a/.github/workflows/pytest_and_coverage.yml +++ b/.github/workflows/pytest_and_coverage.yml @@ -58,6 +58,9 @@ jobs: # JWT settings PROMETHEUS_JWT_SECRET_KEY: your_jwt_secret_key + # Athena memory service settings + PROMETHEUS_ATHENA_BASE_URL: http://localhost:9003/v0.1.0 + steps: - name: Check out code uses: actions/checkout@v4 diff --git a/docker-compose.yml b/docker-compose.yml index 7d6baa99..ceeb9a86 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -2,6 +2,10 @@ networks: prometheus_network: driver: bridge +volumes: + postgres_data: + neo4j_data: + services: neo4j: image: neo4j @@ -16,14 +20,13 @@ services: - NEO4J_dbms_memory_transaction_total_max=12G - NEO4J_db_transaction_timeout=600s volumes: - - ./data_neo4j:/data + - neo4j_data:/data healthcheck: test: ["CMD", "cypher-shell", "-u", "neo4j", "-p", "password", "--non-interactive", "RETURN 1;"] interval: 30s timeout: 60s retries: 3 - postgres: image: postgres container_name: postgres_container @@ -34,7 +37,7 @@ services: - POSTGRES_PASSWORD=password - POSTGRES_DB=postgres volumes: - - ./data_postgres:/var/lib/postgresql/data + - postgres_data:/var/lib/postgresql/data healthcheck: test: ["CMD-SHELL", "pg_isready -d postgres -U postgres"] interval: 30s diff --git a/example.env b/example.env index 3f7e6d19..d746b772 100644 --- a/example.env +++ b/example.env @@ -42,3 +42,6 @@ PROMETHEUS_DATABASE_URL=postgresql+asyncpg://postgres:password@postgres:5432/pos # JWT settings PROMETHEUS_JWT_SECRET_KEY=your_jwt_secret_key + +# Athena memory service settings +PROMETHEUS_ATHENA_BASE_URL=http://localhost:9003/v0.1.0 diff --git a/prometheus/app/api/routes/issue.py b/prometheus/app/api/routes/issue.py index c966b500..4cc3a97f 100644 --- a/prometheus/app/api/routes/issue.py +++ b/prometheus/app/api/routes/issue.py @@ -103,6 +103,7 @@ async def answer_issue(issue: IssueRequest, request: Request) -> Response[IssueR issue_service.answer_issue, repository=git_repository, knowledge_graph=knowledge_graph, + repository_id=repository.id, issue_title=issue.issue_title, issue_body=issue.issue_body, issue_comments=issue.issue_comments if issue.issue_comments else [], diff --git a/prometheus/app/services/issue_service.py b/prometheus/app/services/issue_service.py index fd6f54b7..e79cf907 100644 --- a/prometheus/app/services/issue_service.py +++ b/prometheus/app/services/issue_service.py @@ -34,6 +34,7 @@ def answer_issue( self, knowledge_graph: KnowledgeGraph, repository: GitRepository, + repository_id: int, issue_title: str, issue_body: str, issue_comments: Sequence[Mapping[str, str]], @@ -54,6 +55,7 @@ def answer_issue( Args: repository (GitRepository): The Git repository instance. + repository_id (int): The repository ID. knowledge_graph (KnowledgeGraph): The knowledge graph instance. issue_title (str): The title of the issue. issue_body (str): The body of the issue. @@ -113,6 +115,7 @@ def answer_issue( kg=knowledge_graph, git_repo=repository, container=container, + repository_id=repository_id, test_commands=test_commands, ) diff --git a/prometheus/configuration/config.py b/prometheus/configuration/config.py index 6175410a..01fe80af 100644 --- a/prometheus/configuration/config.py +++ b/prometheus/configuration/config.py @@ -65,5 +65,8 @@ class Settings(BaseSettings): # tool for Websearch TAVILY_API_KEY: str + # Athena semantic memory service + ATHENA_BASE_URL: str + settings = Settings() diff --git a/prometheus/lang_graph/graphs/issue_graph.py b/prometheus/lang_graph/graphs/issue_graph.py index 066be54e..c1246932 100644 --- a/prometheus/lang_graph/graphs/issue_graph.py +++ b/prometheus/lang_graph/graphs/issue_graph.py @@ -30,6 +30,7 @@ def __init__( kg: KnowledgeGraph, git_repo: GitRepository, container: BaseContainer, + repository_id: int, test_commands: Optional[Sequence[str]] = None, ): self.git_repo = git_repo @@ -42,6 +43,7 @@ def __init__( model=base_model, kg=kg, local_path=git_repo.playground_path, + repository_id=repository_id, ) # Subgraph node for handling bug issues @@ -51,6 +53,7 @@ def __init__( container=container, kg=kg, git_repo=git_repo, + repository_id=repository_id, test_commands=test_commands, ) @@ -60,6 +63,7 @@ def __init__( base_model=base_model, kg=kg, git_repo=git_repo, + repository_id=repository_id, ) # Create the state graph for the issue handling workflow diff --git a/prometheus/lang_graph/nodes/add_context_refined_query_message_node.py b/prometheus/lang_graph/nodes/add_context_refined_query_message_node.py new file mode 100644 index 00000000..10f94e8d --- /dev/null +++ b/prometheus/lang_graph/nodes/add_context_refined_query_message_node.py @@ -0,0 +1,45 @@ +import logging +import threading + +from langchain_core.messages import HumanMessage + +from prometheus.lang_graph.subgraphs.context_retrieval_state import ContextRetrievalState + + +class AddContextRefinedQueryMessageNode: + """Node for converting refined query to string and adding it to context_provider_messages.""" + + def __init__(self): + """Initialize the add context refined query message node.""" + self._logger = logging.getLogger(f"thread-{threading.get_ident()}.{__name__}") + + def __call__(self, state: ContextRetrievalState): + """ + Convert refined query to string and add to context_provider_messages. + + Args: + state: Current state containing refined_query + + Returns: + State update with context_provider_messages + """ + refined_query = state["refined_query"] + + # Build the query message + query_parts = [f"Essential query: {refined_query.essential_query}"] + + if refined_query.extra_requirements: + query_parts.append(f"\nExtra requirements: {refined_query.extra_requirements}") + + if refined_query.purpose: + query_parts.append(f"\nPurpose: {refined_query.purpose}") + + query_message = "".join(query_parts) + + self._logger.info("Creating context provider message from refined query") + self._logger.debug(f"Query message: {query_message}") + + # Create HumanMessage and add to context_provider_messages + human_message = HumanMessage(content=query_message) + + return {"context_provider_messages": [human_message]} diff --git a/prometheus/lang_graph/nodes/add_result_context_node.py b/prometheus/lang_graph/nodes/add_result_context_node.py new file mode 100644 index 00000000..836ad2ef --- /dev/null +++ b/prometheus/lang_graph/nodes/add_result_context_node.py @@ -0,0 +1,48 @@ +import logging +import threading + +from prometheus.lang_graph.subgraphs.context_retrieval_state import ContextRetrievalState +from prometheus.utils.knowledge_graph_utils import deduplicate_contexts, sort_contexts + + +class AddResultContextNode: + """Node for adding new_contexts to context and deduplicating the result.""" + + def __init__(self): + """Initialize the add result context node.""" + self._logger = logging.getLogger(f"thread-{threading.get_ident()}.{__name__}") + + def __call__(self, state: ContextRetrievalState): + """ + Add new_contexts to context and deduplicate. + + Args: + state: Current state containing context and new_contexts + + Returns: + State update with deduplicated context + """ + existing_context = state.get("context", []) + new_contexts = state.get("new_contexts", []) + + if not new_contexts: + self._logger.info("No new contexts to add") + return {"context": existing_context} + + self._logger.info( + f"Adding {len(new_contexts)} new contexts to {len(existing_context)} existing contexts" + ) + + # Combine existing and new contexts + combined_contexts = list(existing_context) + list(new_contexts) + + # Deduplicate + deduplicated_contexts = deduplicate_contexts(combined_contexts) + + self._logger.info( + f"After deduplication: {len(deduplicated_contexts)} total contexts " + f"(removed {len(combined_contexts) - len(deduplicated_contexts)} duplicates)" + ) + + # Sort contexts before returning + return {"context": sort_contexts(deduplicated_contexts)} diff --git a/prometheus/lang_graph/nodes/bug_get_regression_tests_subgraph_node.py b/prometheus/lang_graph/nodes/bug_get_regression_tests_subgraph_node.py index 7adc7d1c..a825ff9f 100644 --- a/prometheus/lang_graph/nodes/bug_get_regression_tests_subgraph_node.py +++ b/prometheus/lang_graph/nodes/bug_get_regression_tests_subgraph_node.py @@ -21,6 +21,7 @@ def __init__( container: BaseContainer, kg: KnowledgeGraph, git_repo: GitRepository, + repository_id: int, ): self._logger = logging.getLogger(f"thread-{threading.get_ident()}.{__name__}") self.subgraph = BugGetRegressionTestsSubgraph( @@ -29,6 +30,7 @@ def __init__( container=container, kg=kg, git_repo=git_repo, + repository_id=repository_id, ) def __call__(self, state: Dict): diff --git a/prometheus/lang_graph/nodes/bug_reproduction_subgraph_node.py b/prometheus/lang_graph/nodes/bug_reproduction_subgraph_node.py index 860ed623..2955bc5d 100644 --- a/prometheus/lang_graph/nodes/bug_reproduction_subgraph_node.py +++ b/prometheus/lang_graph/nodes/bug_reproduction_subgraph_node.py @@ -20,6 +20,7 @@ def __init__( container: BaseContainer, kg: KnowledgeGraph, git_repo: GitRepository, + repository_id: int, test_commands: Optional[Sequence[str]], ): self._logger = logging.getLogger(f"thread-{threading.get_ident()}.{__name__}") @@ -30,6 +31,7 @@ def __init__( container=container, kg=kg, git_repo=git_repo, + repository_id=repository_id, test_commands=test_commands, ) diff --git a/prometheus/lang_graph/nodes/build_and_test_subgraph_node.py b/prometheus/lang_graph/nodes/build_and_test_subgraph_node.py deleted file mode 100644 index 77a7483b..00000000 --- a/prometheus/lang_graph/nodes/build_and_test_subgraph_node.py +++ /dev/null @@ -1,76 +0,0 @@ -import logging -import threading -from typing import Optional, Sequence - -from langchain_core.language_models.chat_models import BaseChatModel - -from prometheus.docker.base_container import BaseContainer -from prometheus.graph.knowledge_graph import KnowledgeGraph -from prometheus.lang_graph.subgraphs.build_and_test_subgraph import BuildAndTestSubgraph -from prometheus.lang_graph.subgraphs.issue_bug_state import IssueBugState - - -class BuildAndTestSubgraphNode: - def __init__( - self, - container: BaseContainer, - model: BaseChatModel, - kg: KnowledgeGraph, - build_commands: Optional[Sequence[str]] = None, - test_commands: Optional[Sequence[str]] = None, - ): - self.build_and_test_subgraph = BuildAndTestSubgraph( - container=container, - model=model, - kg=kg, - build_commands=build_commands, - test_commands=test_commands, - ) - self._logger = logging.getLogger(f"thread-{threading.get_ident()}.{__name__}") - - def __call__(self, state: IssueBugState): - exist_build = None - build_command_summary = None - build_fail_log = None - exist_test = None - test_command_summary = None - existing_test_fail_log = None - - if "build_command_summary" in state and state["build_command_summary"]: - exist_build = state["exist_build"] - build_command_summary = state["build_command_summary"] - build_fail_log = state["build_fail_log"] - - if "test_command_summary" in state and state["test_command_summary"]: - exist_test = state["exist_test"] - test_command_summary = state["test_command_summary"] - existing_test_fail_log = state["existing_test_fail_log"] - - self._logger.info("Enters BuildAndTestSubgraphNode") - - output_state = self.build_and_test_subgraph.invoke( - run_build=state["run_build"], - run_existing_test=state["run_existing_test"], - exist_build=exist_build, - build_command_summary=build_command_summary, - build_fail_log=build_fail_log, - exist_test=exist_test, - test_command_summary=test_command_summary, - existing_test_fail_log=existing_test_fail_log, - ) - - self._logger.info(f"exist_build: {output_state['exist_build']}") - self._logger.info(f"build_command_summary:\n{output_state['build_command_summary']}") - self._logger.info(f"build_fail_log:\n{output_state['build_fail_log']}") - self._logger.info(f"exist_test: {output_state['exist_test']}") - self._logger.info(f"test_command_summary:\n{output_state['test_command_summary']}") - self._logger.info(f"existing_test_fail_log:\n{output_state['existing_test_fail_log']}") - - return { - "exist_build": output_state["exist_build"], - "build_command_summary": output_state["build_command_summary"], - "build_fail_log": output_state["build_fail_log"], - "exist_test": output_state["exist_test"], - "test_command_summary": output_state["test_command_summary"], - "existing_test_fail_log": output_state["existing_test_fail_log"], - } diff --git a/prometheus/lang_graph/nodes/context_extraction_node.py b/prometheus/lang_graph/nodes/context_extraction_node.py index 84bde2be..5d8f75ad 100644 --- a/prometheus/lang_graph/nodes/context_extraction_node.py +++ b/prometheus/lang_graph/nodes/context_extraction_node.py @@ -11,11 +11,6 @@ from prometheus.models.context import Context from prometheus.utils.file_utils import read_file_with_line_numbers from prometheus.utils.knowledge_graph_utils import deduplicate_contexts -from prometheus.utils.lang_graph_util import ( - extract_human_queries, - extract_last_tool_messages, - transform_tool_messages_to_str, -) SYS_PROMPT = """\ You are a context summary agent that summarizes code contexts which is relevant to a given query. @@ -65,38 +60,18 @@ ``` Your task is to summarize the relevant contexts to a given query and return it in the specified format. -REMEMBER: Every context object must have ALL four fields (reasoning, relative_path, start_line, end_line). """ HUMAN_MESSAGE = """\ -This is the original user query: +This is the query you need to answer: ---- BEGIN ORIGINAL QUERY --- -{original_query} ---- END ORIGINAL QUERY --- +--- BEGIN QUERY --- +{query} +--- END QUERY --- -The context or file content that you have seen so far (Some of the context may be IRRELEVANT to the query!!!): - ---- BEGIN CONTEXT --- -{context} ---- END CONTEXT --- - -REMEMBER: Your task is to summarize the relevant contexts to a given query and return it in the specified format! -EVERY context object MUST include: reasoning, relative_path, start_line, and end_line. -""" +{extra_requirements} -HUMAN_MESSAGE_WITH_REFINEMENT_QUERY = """\ -This is the original user query: - ---- BEGIN ORIGINAL QUERY --- -{original_query} ---- END ORIGINAL QUERY --- - -This is the refinement query. Please consider it together with the original query. It's really IMPORTANT!!! - ---- BEGIN REFINEMENT QUERY --- -{refinement_query} ---- END REFINEMENT QUERY --- +{purpose} The context or file content that you have seen so far (Some of the context may be IRRELEVANT to the query!!!): @@ -104,8 +79,7 @@ {context} --- END CONTEXT --- -REMEMBER: Your task is to summarize the relevant contexts to a given query and the refinement query, and return your response in the specified format! -EVERY context object MUST include: reasoning, relative_path, start_line, and end_line. +REMEMBER: Your task is to summarize the relevant contexts to the given query and return it in the specified format! """ @@ -143,43 +117,41 @@ def __init__(self, model: BaseChatModel, root_path: str): self.root_path = root_path self._logger = logging.getLogger(f"thread-{threading.get_ident()}.{__name__}") - def __call__(self, state: ContextRetrievalState): - """ - Extract relevant code contexts from the codebase based on the user query and existing context. - The final contexts are with line numbers. - """ - self._logger.info("Starting context extraction process") - # Get Context List with existing context - final_context = state.get("context", []) + def format_human_message(self, state: ContextRetrievalState): + refined_query = state["refined_query"] + explored_context = state["explored_context"] - # Transform the tool messages to a single string - full_context_str = transform_tool_messages_to_str( - extract_last_tool_messages(state["context_provider_messages"]) + query_str = refined_query.essential_query + extra_requirements_str = ( + f"--- BEGIN EXTRA REQUIREMENTS ---\n{refined_query.extra_requirements}\n--- END EXTRA REQUIREMENTS ---" + if refined_query.extra_requirements + else "" + ) + purpose_str = ( + f"--- BEGIN PURPOSE ---\n{refined_query.purpose}\n--- END PURPOSE ---" + if refined_query.purpose + else "" ) - # return existing context if no new context is available - if not full_context_str: - self._logger.debug( - "No context available from tool messages, returning existing context" - ) - return {"context": final_context} + # Format the human message + return HUMAN_MESSAGE.format( + query=query_str, + extra_requirements=extra_requirements_str, + purpose=purpose_str, + context="\n\n".join([str(context) for context in explored_context]), + ) - # Get last user query or refinement query - last_human_query = extract_human_queries(state["context_provider_messages"])[0] + def __call__(self, state: ContextRetrievalState): + """ + Extract relevant code contexts from the codebase based on the refined query and existing context. + The final contexts are with line numbers. + """ + if not state["explored_context"]: + self._logger.info("No explored_context available, skipping context extraction") + return {"new_contexts": []} - # Format the human message - # If there is no refinement query, use the original query only - if last_human_query.strip() == state["query"].strip(): - human_message = HUMAN_MESSAGE.format( - original_query=state["query"], - context=full_context_str, - ) - else: - human_message = HUMAN_MESSAGE_WITH_REFINEMENT_QUERY.format( - original_query=state["query"], - refinement_query=last_human_query, - context=full_context_str, - ) + # Get human message + human_message = self.format_human_message(state) # Log the human message for debugging self._logger.debug(human_message) @@ -187,7 +159,10 @@ def __call__(self, state: ContextRetrievalState): # Summarize the context based on the last messages and system prompt response = self.model.invoke({"human_prompt": human_message}) self._logger.debug(f"Model response: {response}") + + new_contexts = [] context_list = response.context + for context_ in context_list: if context_.start_line < 1 or context_.end_line < 1: self._logger.warning( @@ -219,9 +194,7 @@ def __call__(self, state: ContextRetrievalState): content=content, ) - final_context = final_context + [context] + new_contexts.append(context) - # Deduplicate contexts before returning - final_context = deduplicate_contexts(final_context) - self._logger.info(f"Context extraction complete, returning context {final_context}") - return {"context": final_context} + # return the new contexts after deduplication + return {"new_contexts": deduplicate_contexts(new_contexts)} diff --git a/prometheus/lang_graph/nodes/context_query_message_node.py b/prometheus/lang_graph/nodes/context_query_message_node.py deleted file mode 100644 index f1167a8a..00000000 --- a/prometheus/lang_graph/nodes/context_query_message_node.py +++ /dev/null @@ -1,17 +0,0 @@ -import logging -import threading - -from langchain_core.messages import HumanMessage - -from prometheus.lang_graph.subgraphs.context_retrieval_state import ContextRetrievalState - - -class ContextQueryMessageNode: - def __init__(self): - self._logger = logging.getLogger(f"thread-{threading.get_ident()}.{__name__}") - - def __call__(self, state: ContextRetrievalState): - human_message = HumanMessage(state["query"]) - self._logger.debug(f"Sending query to ContextProviderNode:\n{human_message}") - # The message will be added to the end of the context provider messages - return {"context_provider_messages": [human_message]} diff --git a/prometheus/lang_graph/nodes/context_refine_node.py b/prometheus/lang_graph/nodes/context_refine_node.py index 44ea36a8..829b16f3 100644 --- a/prometheus/lang_graph/nodes/context_refine_node.py +++ b/prometheus/lang_graph/nodes/context_refine_node.py @@ -2,18 +2,26 @@ import threading from langchain_core.language_models.chat_models import BaseChatModel -from langchain_core.messages import HumanMessage from langchain_core.prompts import ChatPromptTemplate from pydantic import BaseModel, Field from prometheus.graph.knowledge_graph import KnowledgeGraph from prometheus.lang_graph.subgraphs.context_retrieval_state import ContextRetrievalState +from prometheus.models.query import Query class ContextRefineStructuredOutput(BaseModel): reasoning: str = Field(description="Your step by step reasoning.") - refined_query: str = Field( - "Additional query to ask the ContextRetriever if the context is not enough. Empty otherwise." + query: str = Field( + description="The main query to ask the ContextRetriever (one sentence). Empty if no additional context is needed." + ) + extra_requirements: str = Field( + default="", + description="Optional additional requirements or fallback instructions (one sentence).", + ) + purpose: str = Field( + default="", + description="Optional brief explanation of why this context is needed (one sentence).", ) @@ -36,14 +44,22 @@ class ContextRefineNode: Provide your analysis in a structured format matching the ContextRefineStructuredOutput model. + Output Structure: + - **query**: The main request for additional context (one sentence). Set to empty string "" if no additional context is needed. + - **extra_requirements** (optional): Fallback instructions if the primary request cannot be fully satisfied. + - **purpose** (optional): Brief explanation of why this context is needed and how it will help complete the task. Use when it helps clarify the intent. + Example output: ```json {{ - "reasoning": "1. The current context includes the main function implementation but lacks details on helper functions it calls.\n2. The query requires understanding of how data is processed, which is not fully covered in the provided context.\n3. The documentation for the main function is missing, which could provide insights into its intended behavior.\n4. Therefore, additional context is needed to fully understand and address the user's query.", - "refined_query": "Please provide the implementation details of the helper functions called within the main function, as well as any relevant documentation that explains the overall data processing workflow." + "reasoning": "The current context lacks the test file content and shared test data definitions needed to extract the 8 relevant test cases.", + "query": "Please provide the full content of sklearn/feature_extraction/tests/test_text.py", + "extra_requirements": "If sending the full file is too large, please include at minimum: (a) the import statements at the top of the file, and (b) the definitions of ALL_FOOD_DOCS and JUNK_FOOD_DOCS, along with their line numbers.", + "purpose": "I need to extract the 8 relevant test cases with their exact line numbers and include all necessary imports and shared test data." }} ``` +IMPORTANT: Keep all fields (query, extra_requirements, purpose) CONCISE and SHORT - ideally ONE sentence each. PLEASE DO NOT INCLUDE ``` IN YOUR OUTPUT! """ @@ -94,7 +110,7 @@ def __init__(self, model: BaseChatModel, kg: KnowledgeGraph): def format_refine_message(self, state: ContextRetrievalState): original_query = state["query"] - context = "\n\n".join([str(context) for context in state["context"]]) + context = "\n\n".join([str(context) for context in state.get("context", [])]) return self.REFINE_PROMPT.format( file_tree=self.file_tree, original_query=original_query, @@ -104,21 +120,25 @@ def format_refine_message(self, state: ContextRetrievalState): def __call__(self, state: ContextRetrievalState): if "max_refined_query_loop" in state and state["max_refined_query_loop"] == 0: self._logger.info("Reached max_refined_query_loop, not asking for more context") - return {"refined_query": ""} + return {"refined_query": None} + # Format the human prompt human_prompt = self.format_refine_message(state) self._logger.debug(human_prompt) + + # Invoke the model response = self.model.invoke({"human_prompt": human_prompt}) self._logger.debug(response) - state_update = {"refined_query": response.refined_query} + refined_query = Query( + essential_query=response.query, + extra_requirements=response.extra_requirements, + purpose=response.purpose, + ) + + state_update = {"refined_query": refined_query} if "max_refined_query_loop" in state: state_update["max_refined_query_loop"] = state["max_refined_query_loop"] - 1 - if response.refined_query: - state_update["context_provider_messages"] = [ - HumanMessage(content=response.refined_query) - ] - return state_update diff --git a/prometheus/lang_graph/nodes/context_retrieval_subgraph_node.py b/prometheus/lang_graph/nodes/context_retrieval_subgraph_node.py index 0cdcf144..c01fcda3 100644 --- a/prometheus/lang_graph/nodes/context_retrieval_subgraph_node.py +++ b/prometheus/lang_graph/nodes/context_retrieval_subgraph_node.py @@ -18,12 +18,14 @@ def __init__( local_path: str, query_key_name: str, context_key_name: str, + repository_id: int, ): self._logger = logging.getLogger(f"thread-{threading.get_ident()}.{__name__}") self.context_retrieval_subgraph = ContextRetrievalSubgraph( model=model, kg=kg, local_path=local_path, + repository_id=repository_id, ) self.query_key_name = query_key_name self.context_key_name = context_key_name diff --git a/prometheus/lang_graph/nodes/issue_bug_subgraph_node.py b/prometheus/lang_graph/nodes/issue_bug_subgraph_node.py index 99a1e1bf..938f4605 100644 --- a/prometheus/lang_graph/nodes/issue_bug_subgraph_node.py +++ b/prometheus/lang_graph/nodes/issue_bug_subgraph_node.py @@ -24,6 +24,7 @@ def __init__( container: BaseContainer, kg: KnowledgeGraph, git_repo: GitRepository, + repository_id: int, test_commands: Optional[Sequence[str]] = None, ): self._logger = logging.getLogger(f"thread-{threading.get_ident()}.{__name__}") @@ -34,6 +35,7 @@ def __init__( container=container, kg=kg, git_repo=git_repo, + repository_id=repository_id, test_commands=test_commands, ) diff --git a/prometheus/lang_graph/nodes/issue_classification_subgraph_node.py b/prometheus/lang_graph/nodes/issue_classification_subgraph_node.py index a3b0dc07..a65e166f 100644 --- a/prometheus/lang_graph/nodes/issue_classification_subgraph_node.py +++ b/prometheus/lang_graph/nodes/issue_classification_subgraph_node.py @@ -16,12 +16,14 @@ def __init__( model: BaseChatModel, kg: KnowledgeGraph, local_path: str, + repository_id: int, ): self._logger = logging.getLogger(f"thread-{threading.get_ident()}.{__name__}") self.issue_classification_subgraph = IssueClassificationSubgraph( model=model, kg=kg, local_path=local_path, + repository_id=repository_id, ) def __call__(self, state: IssueState): diff --git a/prometheus/lang_graph/nodes/issue_not_verified_bug_subgraph_node.py b/prometheus/lang_graph/nodes/issue_not_verified_bug_subgraph_node.py index a770fb8d..3f4b9e3a 100644 --- a/prometheus/lang_graph/nodes/issue_not_verified_bug_subgraph_node.py +++ b/prometheus/lang_graph/nodes/issue_not_verified_bug_subgraph_node.py @@ -21,6 +21,7 @@ def __init__( kg: KnowledgeGraph, git_repo: GitRepository, container: BaseContainer, + repository_id: int, ): self._logger = logging.getLogger(f"thread-{threading.get_ident()}.{__name__}") @@ -30,6 +31,7 @@ def __init__( kg=kg, git_repo=git_repo, container=container, + repository_id=repository_id, ) self.git_repo = git_repo diff --git a/prometheus/lang_graph/nodes/issue_question_subgraph_node.py b/prometheus/lang_graph/nodes/issue_question_subgraph_node.py index 7ef12bfb..550d7547 100644 --- a/prometheus/lang_graph/nodes/issue_question_subgraph_node.py +++ b/prometheus/lang_graph/nodes/issue_question_subgraph_node.py @@ -22,6 +22,7 @@ def __init__( base_model: BaseChatModel, kg: KnowledgeGraph, git_repo: GitRepository, + repository_id: int, ): self._logger = logging.getLogger(f"thread-{threading.get_ident()}.{__name__}") self.issue_question_subgraph = IssueQuestionSubgraph( @@ -29,6 +30,7 @@ def __init__( base_model=base_model, kg=kg, git_repo=git_repo, + repository_id=repository_id, ) def __call__(self, state: IssueState): diff --git a/prometheus/lang_graph/nodes/issue_verified_bug_subgraph_node.py b/prometheus/lang_graph/nodes/issue_verified_bug_subgraph_node.py index 23762ac6..4988e1ae 100644 --- a/prometheus/lang_graph/nodes/issue_verified_bug_subgraph_node.py +++ b/prometheus/lang_graph/nodes/issue_verified_bug_subgraph_node.py @@ -23,6 +23,7 @@ def __init__( container: BaseContainer, kg: KnowledgeGraph, git_repo: GitRepository, + repository_id: int, ): self._logger = logging.getLogger(f"thread-{threading.get_ident()}.{__name__}") self.git_repo = git_repo @@ -32,6 +33,7 @@ def __init__( container=container, kg=kg, git_repo=git_repo, + repository_id=repository_id, ) def __call__(self, state: IssueBugState): diff --git a/prometheus/lang_graph/nodes/memory_retrieval_node.py b/prometheus/lang_graph/nodes/memory_retrieval_node.py new file mode 100644 index 00000000..35b187e9 --- /dev/null +++ b/prometheus/lang_graph/nodes/memory_retrieval_node.py @@ -0,0 +1,55 @@ +import logging +import threading + +from prometheus.lang_graph.subgraphs.context_retrieval_state import ContextRetrievalState +from prometheus.models.context import Context +from prometheus.utils.knowledge_graph_utils import deduplicate_contexts, sort_contexts +from prometheus.utils.memory_utils import retrieve_memory + + +class MemoryRetrievalNode: + """Node for retrieving contexts from Athena semantic memory.""" + + def __init__(self, repository_id: int): + """ + Initialize the memory retrieval node. + + Args: + repository_id: Repository identifier for memory storage + """ + self.repository_id = repository_id + self._logger = logging.getLogger(f"thread-{threading.get_ident()}.{__name__}") + + def __call__(self, state: ContextRetrievalState): + """ + Retrieve contexts from memory using the refined query. + + Args: + state: Current state containing the refined query + + Returns: + State update with memory_contexts + """ + refined_query = state["refined_query"] + + try: + self._logger.info( + f"Retrieving contexts from memory for query: {refined_query.essential_query}" + ) + results = retrieve_memory(repository_id=self.repository_id, query=refined_query) + except Exception as e: + self._logger.error(f"Failed to retrieve from memory: {e}") + # On error, return empty list to continue with normal flow + return {"explored_context": []} + + self._logger.debug(f"Retrieved contexts: {results}") + # Extract contexts from the result + memory_contexts = [] + for memory in results: + for context in memory["memory_context_contexts"]: + memory_contexts.append(Context(**context)) + self._logger.info(f"Retrieved {len(results)} memories from memory") + self._logger.info(f"Retrieved {len(memory_contexts)} contexts from memory") + + # Deduplicate contexts before returning + return {"explored_context": sort_contexts(deduplicate_contexts(memory_contexts))} diff --git a/prometheus/lang_graph/nodes/memory_storage_node.py b/prometheus/lang_graph/nodes/memory_storage_node.py new file mode 100644 index 00000000..119ed1b1 --- /dev/null +++ b/prometheus/lang_graph/nodes/memory_storage_node.py @@ -0,0 +1,51 @@ +import logging +import threading + +from prometheus.lang_graph.subgraphs.context_retrieval_state import ContextRetrievalState +from prometheus.utils.memory_utils import store_memory + + +class MemoryStorageNode: + """Node for storing contexts to Athena semantic memory.""" + + def __init__(self, repository_id: int): + """ + Initialize the memory storage node. + + Args: + repository_id: Repository identifier for memory storage + """ + self.repository_id = repository_id + self._logger = logging.getLogger(f"thread-{threading.get_ident()}.{__name__}") + + def __call__(self, state: ContextRetrievalState): + """ + Store newly extracted contexts to memory. + + Args: + state: Current state containing refined query and new contexts + + Returns: + Empty state update (storage is side-effect only) + """ + refined_query = state["refined_query"] + new_contexts = state["new_contexts"] + + self._logger.info( + f"Storing {len(new_contexts)} contexts to memory for query: {refined_query.essential_query}" + ) + + try: + store_memory( + repository_id=self.repository_id, + essential_query=refined_query.essential_query, + extra_requirements=refined_query.extra_requirements or "", + purpose=refined_query.purpose or "", + contexts=list(new_contexts), + ) + except Exception as e: + self._logger.error(f"Failed to store to memory: {e}") + # Don't fail the entire flow if memory storage fails + + self._logger.info("Successfully stored contexts to memory") + return None diff --git a/prometheus/lang_graph/nodes/transform_tool_messages_to_context_node.py b/prometheus/lang_graph/nodes/transform_tool_messages_to_context_node.py new file mode 100644 index 00000000..dd2ab894 --- /dev/null +++ b/prometheus/lang_graph/nodes/transform_tool_messages_to_context_node.py @@ -0,0 +1,45 @@ +import logging +import threading + +from prometheus.lang_graph.subgraphs.context_retrieval_state import ContextRetrievalState +from prometheus.utils.knowledge_graph_utils import deduplicate_contexts, sort_contexts +from prometheus.utils.lang_graph_util import ( + extract_last_tool_messages, + transform_tool_messages_to_context, +) + + +class TransformToolMessagesToContextNode: + """Node for transforming tool messages into Context objects and adding them to explored_context. + + This node extracts artifacts from tool messages (after the last human message), + converts them to Context objects using the knowledge graph data generator, + and adds them to the explored_context field in the state. + """ + + def __init__(self): + """Initialize the transform tool messages to context node.""" + self._logger = logging.getLogger(f"thread-{threading.get_ident()}.{__name__}") + + def __call__(self, state: ContextRetrievalState): + """ + Transform tool messages to Context objects and add to explored_context. + + Args: + state: Current state containing context_provider_messages + + Returns: + State update with Context objects added to explored_context + """ + # Extract tool messages from the message history + context_provider_messages = state.get("context_provider_messages", []) + tool_messages = extract_last_tool_messages(context_provider_messages) + + # Transform tool messages to Context objects + explored_context = transform_tool_messages_to_context(tool_messages) + + if not explored_context: + self._logger.info("No contexts extracted from tool messages") + return {"explored_context": []} + + return {"explored_context": sort_contexts(deduplicate_contexts(explored_context))} diff --git a/prometheus/lang_graph/subgraphs/bug_get_regression_tests_subgraph.py b/prometheus/lang_graph/subgraphs/bug_get_regression_tests_subgraph.py index a31029cd..d8b4c166 100644 --- a/prometheus/lang_graph/subgraphs/bug_get_regression_tests_subgraph.py +++ b/prometheus/lang_graph/subgraphs/bug_get_regression_tests_subgraph.py @@ -35,6 +35,7 @@ def __init__( container: BaseContainer, kg: KnowledgeGraph, git_repo: GitRepository, + repository_id: int, ): """ Initialize the run regression tests pipeline with all necessary parts. @@ -45,6 +46,7 @@ def __init__( container: Docker-based sandbox for running code. kg: Codebase knowledge graph used for context retrieval. git_repo: Git repository interface for codebase manipulation. + repository_id: Repository ID for memory storage. """ # Step 1: Generate initial system messages based on issue data @@ -57,6 +59,7 @@ def __init__( git_repo.playground_path, "select_regression_query", "select_regression_context", + repository_id, ) # Step 3: Select relevant regression tests based on the issue and retrieved context bug_get_regression_tests_selection_node = BugGetRegressionTestsSelectionNode( diff --git a/prometheus/lang_graph/subgraphs/bug_reproduction_subgraph.py b/prometheus/lang_graph/subgraphs/bug_reproduction_subgraph.py index 43e7d04f..911e403c 100644 --- a/prometheus/lang_graph/subgraphs/bug_reproduction_subgraph.py +++ b/prometheus/lang_graph/subgraphs/bug_reproduction_subgraph.py @@ -40,6 +40,7 @@ def __init__( container: BaseContainer, kg: KnowledgeGraph, git_repo: GitRepository, + repository_id: int, test_commands: Optional[Sequence[str]] = None, ): """ @@ -51,6 +52,7 @@ def __init__( container: Docker-based sandbox for running code. kg: Codebase knowledge graph used for context retrieval. git_repo: Git repository interface for codebase manipulation. + repository_id: Repository ID for memory storage. test_commands: Optional list of test commands to verify reproduction success. """ self.git_repo = git_repo @@ -65,6 +67,7 @@ def __init__( git_repo.playground_path, "bug_reproducing_query", "bug_reproducing_context", + repository_id, ) # Step 3: Write a patch to reproduce the bug diff --git a/prometheus/lang_graph/subgraphs/context_retrieval_state.py b/prometheus/lang_graph/subgraphs/context_retrieval_state.py index 4a1ca945..0e1984ce 100644 --- a/prometheus/lang_graph/subgraphs/context_retrieval_state.py +++ b/prometheus/lang_graph/subgraphs/context_retrieval_state.py @@ -4,6 +4,7 @@ from langgraph.graph.message import add_messages from prometheus.models.context import Context +from prometheus.models.query import Query class ContextRetrievalState(TypedDict): @@ -11,5 +12,12 @@ class ContextRetrievalState(TypedDict): max_refined_query_loop: int context_provider_messages: Annotated[Sequence[BaseMessage], add_messages] - refined_query: str - context: Sequence[Context] + refined_query: Query + + context: Sequence[Context] # Final contexts to return + + explored_context: Sequence[ + Context + ] # contexts explored during the process (both from memory and KG) + + new_contexts: Sequence[Context] # Newly extracted contexts (to be added to memory) diff --git a/prometheus/lang_graph/subgraphs/context_retrieval_subgraph.py b/prometheus/lang_graph/subgraphs/context_retrieval_subgraph.py index f0847185..eb65f80c 100644 --- a/prometheus/lang_graph/subgraphs/context_retrieval_subgraph.py +++ b/prometheus/lang_graph/subgraphs/context_retrieval_subgraph.py @@ -6,34 +6,64 @@ from langgraph.prebuilt import ToolNode, tools_condition from prometheus.graph.knowledge_graph import KnowledgeGraph +from prometheus.lang_graph.nodes.add_context_refined_query_message_node import ( + AddContextRefinedQueryMessageNode, +) +from prometheus.lang_graph.nodes.add_result_context_node import AddResultContextNode from prometheus.lang_graph.nodes.context_extraction_node import ContextExtractionNode from prometheus.lang_graph.nodes.context_provider_node import ContextProviderNode -from prometheus.lang_graph.nodes.context_query_message_node import ContextQueryMessageNode from prometheus.lang_graph.nodes.context_refine_node import ContextRefineNode +from prometheus.lang_graph.nodes.memory_retrieval_node import MemoryRetrievalNode +from prometheus.lang_graph.nodes.memory_storage_node import MemoryStorageNode from prometheus.lang_graph.nodes.reset_messages_node import ResetMessagesNode +from prometheus.lang_graph.nodes.transform_tool_messages_to_context_node import ( + TransformToolMessagesToContextNode, +) from prometheus.lang_graph.subgraphs.context_retrieval_state import ContextRetrievalState from prometheus.models.context import Context class ContextRetrievalSubgraph: """ - A LangGraph-based subgraph for retrieving relevant contextual information - (e.g., code, documentation, definitions) from a knowledge graph based on a query. - - This subgraph performs an iterative retrieval process: - 1. Constructs a context query message from the user prompt - 2. Uses tool-based retrieval (Neo4j-backed) to gather candidate context snippets - 3. Selects relevant context with LLM assistance - 4. Optionally refines the query and retries if necessary - 5. Outputs the final selected context - - Nodes: - - ContextQueryMessageNode: Converts user query to internal query prompt - - ContextProviderNode: Queries knowledge graph using structured tools - - ToolNode: Dynamically invokes retrieval tools based on tool condition - - ContextSelectionNode: Uses LLM to select useful context snippets - - ResetMessagesNode: Clears previous context messages - - ContextRefineNode: Decides whether to refine the query and retry + This class defines a LangGraph-based subgraph that retrieves relevant code contexts + using a memory-first strategy. It combines semantic memory (Athena) with knowledge + graph (Neo4j) retrieval to optimize for cost and speed. + + Workflow: + 1. Refine query into structured format (essential_query, extra_requirements, purpose) + 2. Try to retrieve from semantic memory (Athena) + 3. Extract relevant contexts from memory results + 4. If found → Store to memory and refine again (loop) + If not found → Fall back to Knowledge Graph retrieval + 5. KG retrieval: Query Neo4j → Extract contexts → Store to memory + 6. Loop back to refinement until max iterations + + Flow Diagram: + ┌──────────────┐ + │ Refine │◄─────────────┐ + │ Query │ │ + └──────┬───────┘ │ + │ │ + ┌──────▼───────┐ │ + │ Memory │ │ + │ Retrieval │ │ + └──────┬───────┘ │ + │ │ + ┌──────▼───────┐ │ + │ Extract │◄─────┐ │ + │ Contexts │ │ │ + └──────┬───────┘ │ │ + │ │ │ + ┌────────────┴────────┐ │ │ + │ │ │ │ + [has contexts?] │ │ │ + │ │ │ │ + ┌─────────▼─────┐ ┌───────▼─────┴───┐ │ + │ Store + │ │ KG Provider │ │ + │ Merge │ │ (with tools) │ │ + └─────────┬─────┘ └─────────────────┘ │ + │ │ + └───────────────────────────────────┘ """ def __init__( @@ -41,75 +71,113 @@ def __init__( model: BaseChatModel, kg: KnowledgeGraph, local_path: str, + repository_id: int, ): """ Initializes the context retrieval subgraph. Args: model (BaseChatModel): The LLM used for context selection and refinement. + kg (KnowledgeGraph): Knowledge graph instance local_path (str): Local path to the codebase for context extraction. + repository_id (int): Repository ID for memory storage """ - # Step 1: Generate an initial query from the user's input - context_query_message_node = ContextQueryMessageNode() - - # Step 2: Provide candidate context snippets using knowledge graph tools - context_provider_node = ContextProviderNode( - model, - kg, - local_path, - ) + # Step 1: Refine query into structured format + context_refine_node = ContextRefineNode(model=model, kg=kg) + + # Step 2: Retrieve contexts from semantic memory (Athena) + memory_retrieval_node = MemoryRetrievalNode(repository_id=repository_id) + + # Step 3: Extract relevant contexts from explored_context + context_extraction_node = ContextExtractionNode(model=model, root_path=local_path) + + # Step 4: Store new contexts to memory + memory_storage_node = MemoryStorageNode(repository_id=repository_id) - # Step 3: Add tool node to handle tool-based retrieval invocation dynamically - # The tool message will be added to the end of the context provider messages + # Step 5: Merge and deduplicate contexts + add_result_context_node = AddResultContextNode() + + # Step 6: Convert refined query to message for KG retrieval + add_context_refined_query_message_node = AddContextRefinedQueryMessageNode() + + # Step 7: Query knowledge graph (Neo4j) using LLM tools + context_provider_node = ContextProviderNode(model=model, kg=kg, local_path=local_path) context_provider_tools = ToolNode( tools=context_provider_node.tools, name="context_provider_tools", messages_key="context_provider_messages", ) + transform_tool_messages_to_context_node = TransformToolMessagesToContextNode() - # Step 4: Extract the Context - context_extraction_node = ContextExtractionNode(model, local_path) - - # Step 5: Reset tool messages to prepare for the next iteration (if needed) + # Step 8: Reset messages for next iteration reset_context_provider_messages_node = ResetMessagesNode("context_provider_messages") - # Step 6: Refine the query if needed and loop back - context_refine_node = ContextRefineNode(model, kg) - - # Construct the LangGraph workflow + # Define the state machine workflow = StateGraph(ContextRetrievalState) # Add all nodes to the graph - workflow.add_node("context_query_message_node", context_query_message_node) + workflow.add_node("context_refine_node", context_refine_node) + workflow.add_node("memory_retrieval_node", memory_retrieval_node) + workflow.add_node("context_extraction_node", context_extraction_node) + workflow.add_node("memory_storage_node", memory_storage_node) + workflow.add_node("add_result_context_node", add_result_context_node) + workflow.add_node( + "add_context_refined_query_message_node", add_context_refined_query_message_node + ) workflow.add_node("context_provider_node", context_provider_node) workflow.add_node("context_provider_tools", context_provider_tools) - workflow.add_node("context_extraction_node", context_extraction_node) + workflow.add_node( + "transform_tool_messages_to_context_node", transform_tool_messages_to_context_node + ) workflow.add_node( "reset_context_provider_messages_node", reset_context_provider_messages_node ) - workflow.add_node("context_refine_node", context_refine_node) - # Set the entry point for the workflow - workflow.set_entry_point("context_query_message_node") - # Define edges between nodes - workflow.add_edge("context_query_message_node", "context_provider_node") + # Define workflow edges + # Entry: Always start with query refinement + workflow.set_entry_point("context_refine_node") - # Conditional: Use tool node if tools_condition is satisfied + # After refine: Check if we have a valid query, if yes try memory first workflow.add_conditional_edges( - "context_provider_node", - functools.partial(tools_condition, messages_key="context_provider_messages"), - {"tools": "context_provider_tools", END: "context_extraction_node"}, + "context_refine_node", + lambda state: bool(state["refined_query"]) + and bool(state["refined_query"].essential_query.strip()), + {True: "memory_retrieval_node", False: END}, ) - workflow.add_edge("context_provider_tools", "context_provider_node") - workflow.add_edge("context_extraction_node", "reset_context_provider_messages_node") - workflow.add_edge("reset_context_provider_messages_node", "context_refine_node") - # If refined_query is non-empty, loop back to provider; else terminate + # After memory retrieval: Always extract contexts + workflow.add_edge("memory_retrieval_node", "context_extraction_node") + + # After extraction: Check if we found new contexts + # Yes → Store to memory and loop back (memory hit) + # No → Fall back to KG retrieval (memory miss) workflow.add_conditional_edges( - "context_refine_node", - lambda state: bool(state["refined_query"]), - {True: "context_provider_node", False: END}, + "context_extraction_node", + lambda state: len(state["new_contexts"]) > 0, + {True: "memory_storage_node", False: "reset_context_provider_messages_node"}, + ) + + # Memory hit path: Store → Merge → Refine again + workflow.add_edge("memory_storage_node", "add_result_context_node") + workflow.add_edge("add_result_context_node", "context_refine_node") + + # Memory miss path: Reset → Convert query → KG provider + workflow.add_edge( + "reset_context_provider_messages_node", "add_context_refined_query_message_node" + ) + workflow.add_edge("add_context_refined_query_message_node", "context_provider_node") + + # KG provider: Call tools if needed, otherwise extract directly + workflow.add_conditional_edges( + "context_provider_node", + functools.partial(tools_condition, messages_key="context_provider_messages"), + {"tools": "context_provider_tools", END: "transform_tool_messages_to_context_node"}, ) + # After KG provider (no tools): Transform tool messages to contexts + workflow.add_edge("transform_tool_messages_to_context_node", "context_extraction_node") + + # After executing tools: Loop back to provider (may call more tools) + workflow.add_edge("context_provider_tools", "context_provider_node") # Compile and store the subgraph self.subgraph = workflow.compile() @@ -127,7 +195,8 @@ def invoke(self, query: str, max_refined_query_loop: int) -> Dict[str, Sequence[ - "context" (Sequence[Context]): A list of selected context snippets relevant to the query. """ # Set the recursion limit based on the maximum number of refined query loops - config = {"recursion_limit": (max_refined_query_loop + 1) * 75} + max_refined_query_loop = max_refined_query_loop + 1 + config = {"recursion_limit": max_refined_query_loop * 40} input_state = { "query": query, diff --git a/prometheus/lang_graph/subgraphs/issue_bug_subgraph.py b/prometheus/lang_graph/subgraphs/issue_bug_subgraph.py index 350ba442..ec8c2f98 100644 --- a/prometheus/lang_graph/subgraphs/issue_bug_subgraph.py +++ b/prometheus/lang_graph/subgraphs/issue_bug_subgraph.py @@ -28,6 +28,7 @@ def __init__( container: BaseContainer, kg: KnowledgeGraph, git_repo: GitRepository, + repository_id: int, test_commands: Optional[Sequence[str]] = None, ): # Construct bug reproduction node @@ -37,6 +38,7 @@ def __init__( container=container, kg=kg, git_repo=git_repo, + repository_id=repository_id, test_commands=test_commands, ) # Construct bug regression tests subgraph node @@ -46,6 +48,7 @@ def __init__( container=container, kg=kg, git_repo=git_repo, + repository_id=repository_id, ) # Construct issue bug verified subgraph nodes @@ -55,6 +58,7 @@ def __init__( container=container, kg=kg, git_repo=git_repo, + repository_id=repository_id, ) # Construct issue not verified bug subgraph node issue_not_verified_bug_subgraph_node = IssueNotVerifiedBugSubgraphNode( @@ -63,6 +67,7 @@ def __init__( kg=kg, git_repo=git_repo, container=container, + repository_id=repository_id, ) # Construct issue bug responder node issue_bug_responder_node = IssueBugResponderNode(base_model) diff --git a/prometheus/lang_graph/subgraphs/issue_classification_subgraph.py b/prometheus/lang_graph/subgraphs/issue_classification_subgraph.py index 93422ac9..3b19a048 100644 --- a/prometheus/lang_graph/subgraphs/issue_classification_subgraph.py +++ b/prometheus/lang_graph/subgraphs/issue_classification_subgraph.py @@ -18,6 +18,7 @@ def __init__( model: BaseChatModel, kg: KnowledgeGraph, local_path: str, + repository_id: int, ): issue_classification_context_message_node = IssueClassificationContextMessageNode() context_retrieval_subgraph_node = ContextRetrievalSubgraphNode( @@ -26,6 +27,7 @@ def __init__( local_path=local_path, query_key_name="issue_classification_query", context_key_name="issue_classification_context", + repository_id=repository_id, ) issue_classifier_node = IssueClassifierNode(model) diff --git a/prometheus/lang_graph/subgraphs/issue_not_verified_bug_subgraph.py b/prometheus/lang_graph/subgraphs/issue_not_verified_bug_subgraph.py index efdc2441..038217b0 100644 --- a/prometheus/lang_graph/subgraphs/issue_not_verified_bug_subgraph.py +++ b/prometheus/lang_graph/subgraphs/issue_not_verified_bug_subgraph.py @@ -33,6 +33,7 @@ def __init__( kg: KnowledgeGraph, git_repo: GitRepository, container: BaseContainer, + repository_id: int, ): issue_bug_context_message_node = IssueBugContextMessageNode() context_retrieval_subgraph_node = ContextRetrievalSubgraphNode( @@ -41,6 +42,7 @@ def __init__( local_path=git_repo.playground_path, query_key_name="bug_fix_query", context_key_name="bug_fix_context", + repository_id=repository_id, ) issue_bug_analyzer_message_node = IssueBugAnalyzerMessageNode() diff --git a/prometheus/lang_graph/subgraphs/issue_question_subgraph.py b/prometheus/lang_graph/subgraphs/issue_question_subgraph.py index 569a4a5a..aada7acd 100644 --- a/prometheus/lang_graph/subgraphs/issue_question_subgraph.py +++ b/prometheus/lang_graph/subgraphs/issue_question_subgraph.py @@ -32,6 +32,7 @@ def __init__( base_model: BaseChatModel, kg: KnowledgeGraph, git_repo: GitRepository, + repository_id: int, ): # Step 1: Retrieve relevant context based on the issue details issue_question_context_message_node = IssueQuestionContextMessageNode() @@ -41,6 +42,7 @@ def __init__( local_path=git_repo.playground_path, query_key_name="question_query", context_key_name="question_context", + repository_id=repository_id, ) # Step 2: Send issue question analyze message diff --git a/prometheus/lang_graph/subgraphs/issue_verified_bug_subgraph.py b/prometheus/lang_graph/subgraphs/issue_verified_bug_subgraph.py index 6240763a..33efc5f1 100644 --- a/prometheus/lang_graph/subgraphs/issue_verified_bug_subgraph.py +++ b/prometheus/lang_graph/subgraphs/issue_verified_bug_subgraph.py @@ -56,6 +56,7 @@ def __init__( container: BaseContainer, kg: KnowledgeGraph, git_repo: GitRepository, + repository_id: int, ): """ Initialize the verified bug fix subgraph. @@ -77,6 +78,7 @@ def __init__( local_path=git_repo.playground_path, query_key_name="bug_fix_query", context_key_name="bug_fix_context", + repository_id=repository_id, ) # Phase 2: Analyze the bug and generate hypotheses diff --git a/prometheus/models/query.py b/prometheus/models/query.py new file mode 100644 index 00000000..e58ea412 --- /dev/null +++ b/prometheus/models/query.py @@ -0,0 +1,7 @@ +from pydantic import BaseModel + + +class Query(BaseModel): + essential_query: str + extra_requirements: str + purpose: str diff --git a/prometheus/utils/knowledge_graph_utils.py b/prometheus/utils/knowledge_graph_utils.py index b65ec891..f47e4257 100644 --- a/prometheus/utils/knowledge_graph_utils.py +++ b/prometheus/utils/knowledge_graph_utils.py @@ -164,3 +164,23 @@ def _analyze_context_relationship(context1: Context, context2: Context) -> str: # For all other cases (including partial overlaps), return separate return "separate" + + +def sort_contexts(contexts: List[Context]) -> List[Context]: + """ + Sort a list of Context objects by relative_path, then by start_line_number and end_line_number. + + Args: + contexts: List of Context objects to sort + + Returns: + Sorted list of Context objects + """ + return sorted( + contexts, + key=lambda ctx: ( + ctx.relative_path, + ctx.start_line_number if ctx.start_line_number is not None else float("inf"), + ctx.end_line_number if ctx.end_line_number is not None else float("inf"), + ), + ) diff --git a/prometheus/utils/lang_graph_util.py b/prometheus/utils/lang_graph_util.py index 3ba38b6c..31dace02 100644 --- a/prometheus/utils/lang_graph_util.py +++ b/prometheus/utils/lang_graph_util.py @@ -1,4 +1,4 @@ -from typing import Callable, Dict, Sequence +from typing import Callable, Dict, List, Sequence from langchain_core.messages import ( AIMessage, @@ -8,6 +8,7 @@ ) from langchain_core.output_parsers import StrOutputParser +from prometheus.models.context import Context from prometheus.utils.knowledge_graph_utils import knowledge_graph_data_for_context_generator @@ -66,7 +67,16 @@ def extract_last_tool_messages(messages: Sequence[BaseMessage]) -> Sequence[Tool return tool_messages -def transform_tool_messages_to_str(messages: Sequence[ToolMessage]) -> str: +def transform_tool_messages_to_context(messages: Sequence[ToolMessage]) -> List[Context]: + """ + Transform tool messages to Context objects and return them in explored_context. + + Args: + messages: Sequence of ToolMessage objects that may contain artifacts + + Returns: + Dictionary with 'explored_context' key containing list of Context objects + """ # Aggregate all artifacts from the tool messages total_artifacts = [] for message in messages: @@ -74,12 +84,8 @@ def transform_tool_messages_to_str(messages: Sequence[ToolMessage]) -> str: if message.artifact: total_artifacts.extend(message.artifact) - # Convert the aggregated artifacts to a string representation - result = "" - for context in knowledge_graph_data_for_context_generator(total_artifacts): - result += str(context) - result += "\n" - return result + # Convert the aggregated artifacts to Context objects using the knowledge graph generator + return list(knowledge_graph_data_for_context_generator(total_artifacts)) def get_last_message_content(messages: Sequence[BaseMessage]) -> str: diff --git a/prometheus/utils/memory_utils.py b/prometheus/utils/memory_utils.py new file mode 100644 index 00000000..b3ac40fa --- /dev/null +++ b/prometheus/utils/memory_utils.py @@ -0,0 +1,215 @@ +import logging +import threading +from typing import Any, Dict, List + +import requests + +from prometheus.configuration.config import settings +from prometheus.models.context import Context +from prometheus.models.query import Query + + +class AthenaMemoryClient: + """Client for interacting with Athena semantic memory service.""" + + def __init__( + self, + base_url: str, + ): + """ + Initialize Athena memory client. + + Args: + base_url: Base URL of the Athena service + """ + self.base_url = base_url.rstrip("/") + self.timeout = 30 + self._logger = logging.getLogger(f"thread-{threading.get_ident()}.{__name__}") + + def store_memory( + self, + repository_id: int, + essential_query: str, + extra_requirements: str, + purpose: str, + contexts: List[Context], + ) -> dict[str, Any]: + """ + Store content in semantic memory. + + Args: + repository_id: Repository identifier + essential_query: The main query + extra_requirements: Optional extra requirements + purpose: Optional purpose description + contexts: List of Context objects to store + + Returns: + Response from Athena service + + Raises: + requests.RequestException: If the request fails + """ + url = f"{self.base_url}/semantic-memory/store/" + + payload = { + "repository_id": repository_id, + "query": { + "essential_query": essential_query, + "extra_requirements": extra_requirements, + "purpose": purpose, + }, + "contexts": [context.model_dump() for context in contexts], + } + + self._logger.debug(f"Storing memory for repository {repository_id}") + try: + response = requests.post(url, json=payload, timeout=self.timeout) + response.raise_for_status() + except requests.RequestException as e: + self._logger.error(f"Failed to store memory for repository {repository_id}: {e}") + raise + result = response.json() + self._logger.debug(f"Successfully stored memory: {result}") + return result + + def retrieve_memory( + self, + repository_id: int, + query: Query, + ) -> List[Dict[str, Any]]: + """ + Retrieve content from semantic memory using a query. + + Args: + repository_id: Repository identifier + query: Query object with essential_query, extra_requirements, and purpose + + Returns: + Response from Athena service containing retrieved memories + + Raises: + requests.RequestException: If the request fails + """ + url = f"{self.base_url}/semantic-memory/retrieve/{repository_id}/" + + params = { + "essential_query": query.essential_query, + "extra_requirements": query.extra_requirements or "", + "purpose": query.purpose or "", + } + + self._logger.debug( + f"Retrieving memory for repository {repository_id} with query: {query.essential_query}" + ) + + try: + response = requests.get(url, params=params, timeout=self.timeout) + response.raise_for_status() + except requests.RequestException as e: + self._logger.error(f"Failed to retrieve memory for repository {repository_id}: {e}") + raise + + result = response.json() + self._logger.debug(f"Successfully retrieved {len(result.get('data', []))} memories") + return result["data"] + + def delete_repository_memory(self, repository_id: int) -> dict[str, Any]: + """ + Delete all memories for a repository. + + Args: + repository_id: Repository identifier + + Returns: + Response from Athena service + + Raises: + requests.RequestException: If the request fails + """ + url = f"{self.base_url}/semantic-memory/{repository_id}/" + + self._logger.debug(f"Deleting memory for repository {repository_id}") + try: + response = requests.delete(url, timeout=self.timeout) + response.raise_for_status() + except requests.RequestException as e: + self._logger.error(f"Failed to delete repository memory {repository_id}: {e}") + raise + result = response.json() + self._logger.debug(f"Successfully deleted repository memory: {result}") + return result + + +# Global instance with settings from config +athena_client = AthenaMemoryClient( + base_url=settings.ATHENA_BASE_URL, +) + + +def store_memory( + repository_id: int, + essential_query: str, + extra_requirements: str, + purpose: str, + contexts: List[Context], +) -> dict[str, Any]: + """ + Store contexts to semantic memory for a repository. + + Args: + repository_id: Repository identifier + essential_query: The main query that was used to retrieve these contexts + extra_requirements: Optional extra requirements for the query + purpose: Optional purpose description + contexts: List of Context objects to store + + Returns: + Response from Athena service + + Raises: + requests.RequestException: If the request fails + """ + return athena_client.store_memory( + repository_id=repository_id, + essential_query=essential_query, + extra_requirements=extra_requirements, + purpose=purpose, + contexts=contexts, + ) + + +def retrieve_memory( + repository_id: int, + query: Query, +) -> List[Dict[str, Any]]: + """ + Retrieve contexts from semantic memory using a query. + + Args: + repository_id: Repository identifier + query: Query object with essential_query, extra_requirements, and purpose + + Returns: + Response from Athena service containing retrieved contexts + + Raises: + requests.RequestException: If the request fails + """ + return athena_client.retrieve_memory(repository_id=repository_id, query=query) + + +def delete_repository_memory(repository_id: int) -> dict[str, Any]: + """ + Delete all memories for a repository. + + Args: + repository_id: Repository identifier + + Returns: + Response from Athena service + + Raises: + requests.RequestException: If the request fails + """ + return athena_client.delete_repository_memory(repository_id=repository_id) diff --git a/tests/app/services/test_issue_service.py b/tests/app/services/test_issue_service.py index 83efd60e..1e37f06c 100644 --- a/tests/app/services/test_issue_service.py +++ b/tests/app/services/test_issue_service.py @@ -65,6 +65,7 @@ async def test_answer_issue_with_general_container(issue_service, monkeypatch): result = issue_service.answer_issue( repository=repository, knowledge_graph=knowledge_graph, + repository_id=1, issue_title="Test Issue", issue_body="Test Body", issue_comments=[], @@ -89,6 +90,7 @@ async def test_answer_issue_with_general_container(issue_service, monkeypatch): kg=knowledge_graph, git_repo=repository, container=mock_container, + repository_id=1, test_commands=None, ) assert result == ("test_patch", True, True, True, "test_response", IssueType.BUG) @@ -126,6 +128,7 @@ async def test_answer_issue_with_user_defined_container(issue_service, monkeypat result = issue_service.answer_issue( repository=repository, knowledge_graph=knowledge_graph, + repository_id=1, issue_title="Test Issue", issue_body="Test Body", issue_comments=[], diff --git a/tests/lang_graph/graphs/test_issue_graph.py b/tests/lang_graph/graphs/test_issue_graph.py index 47eea660..a7f838ca 100644 --- a/tests/lang_graph/graphs/test_issue_graph.py +++ b/tests/lang_graph/graphs/test_issue_graph.py @@ -53,6 +53,7 @@ def test_issue_graph_basic_initialization( kg=mock_kg, git_repo=mock_git_repo, container=mock_container, + repository_id=1, ) assert graph.graph is not None diff --git a/tests/lang_graph/subgraphs/test_issue_bug_subgraph.py b/tests/lang_graph/subgraphs/test_issue_bug_subgraph.py index 1899e6d1..fd3a91a8 100644 --- a/tests/lang_graph/subgraphs/test_issue_bug_subgraph.py +++ b/tests/lang_graph/subgraphs/test_issue_bug_subgraph.py @@ -43,6 +43,7 @@ def test_issue_bug_subgraph_basic_initialization(mock_container, mock_kg, mock_g container=mock_container, kg=mock_kg, git_repo=mock_git_repo, + repository_id=1, ) # Verify the subgraph was created @@ -61,6 +62,7 @@ def test_issue_bug_subgraph_with_commands(mock_container, mock_kg, mock_git_repo container=mock_container, kg=mock_kg, git_repo=mock_git_repo, + repository_id=1, test_commands=test_commands, ) diff --git a/tests/lang_graph/subgraphs/test_issue_classification_subgraph.py b/tests/lang_graph/subgraphs/test_issue_classification_subgraph.py index e0a21531..d109d26c 100644 --- a/tests/lang_graph/subgraphs/test_issue_classification_subgraph.py +++ b/tests/lang_graph/subgraphs/test_issue_classification_subgraph.py @@ -36,6 +36,7 @@ def test_issue_classification_subgraph_basic_initialization(mock_kg, mock_git_re model=fake_model, kg=mock_kg, local_path=mock_git_repo.playground_path, + repository_id=1, ) # Verify the subgraph was created diff --git a/tests/lang_graph/subgraphs/test_issue_question_subgraph.py b/tests/lang_graph/subgraphs/test_issue_question_subgraph.py index 26350abd..313faec1 100644 --- a/tests/lang_graph/subgraphs/test_issue_question_subgraph.py +++ b/tests/lang_graph/subgraphs/test_issue_question_subgraph.py @@ -42,6 +42,7 @@ def test_issue_question_subgraph_basic_initialization(mock_container, mock_kg, m base_model=fake_base_model, kg=mock_kg, git_repo=mock_git_repo, + repository_id=1, ) # Verify the subgraph was created diff --git a/tests/utils/test_knowledge_graph_utils.py b/tests/utils/test_knowledge_graph_utils.py index 6928b3ed..a69ed99e 100644 --- a/tests/utils/test_knowledge_graph_utils.py +++ b/tests/utils/test_knowledge_graph_utils.py @@ -1,4 +1,8 @@ -from prometheus.utils.knowledge_graph_utils import knowledge_graph_data_for_context_generator +from prometheus.models.context import Context +from prometheus.utils.knowledge_graph_utils import ( + knowledge_graph_data_for_context_generator, + sort_contexts, +) def test_empty_data(): @@ -143,3 +147,114 @@ def test_complex_deduplication_scenario(): assert len(result) == 2 # Large context + separate comment assert "class MyClass:" in result[0].content assert result[1].content == "# Comment at end" + + +def test_sort_contexts_empty_list(): + """Test sorting an empty list""" + result = sort_contexts([]) + assert result == [] + + +def test_sort_contexts_by_relative_path(): + """Test sorting by relative path""" + contexts = [ + Context( + relative_path="src/z.py", content="content z", start_line_number=1, end_line_number=5 + ), + Context( + relative_path="src/a.py", content="content a", start_line_number=1, end_line_number=5 + ), + Context( + relative_path="src/m.py", content="content m", start_line_number=1, end_line_number=5 + ), + ] + result = sort_contexts(contexts) + assert result[0].relative_path == "src/a.py" + assert result[1].relative_path == "src/m.py" + assert result[2].relative_path == "src/z.py" + + +def test_sort_contexts_by_line_numbers(): + """Test sorting by line numbers within same file""" + contexts = [ + Context( + relative_path="test.py", content="content 3", start_line_number=20, end_line_number=25 + ), + Context( + relative_path="test.py", content="content 1", start_line_number=1, end_line_number=5 + ), + Context( + relative_path="test.py", content="content 2", start_line_number=10, end_line_number=15 + ), + ] + result = sort_contexts(contexts) + assert result[0].start_line_number == 1 + assert result[1].start_line_number == 10 + assert result[2].start_line_number == 20 + + +def test_sort_contexts_none_line_numbers(): + """Test sorting when line numbers are None (should appear last)""" + contexts = [ + Context( + relative_path="test.py", + content="content with lines", + start_line_number=10, + end_line_number=15, + ), + Context( + relative_path="test.py", + content="content no lines", + start_line_number=None, + end_line_number=None, + ), + Context( + relative_path="test.py", content="content first", start_line_number=1, end_line_number=5 + ), + ] + result = sort_contexts(contexts) + assert result[0].start_line_number == 1 + assert result[1].start_line_number == 10 + assert result[2].start_line_number is None + + +def test_sort_contexts_mixed_files_and_lines(): + """Test sorting with multiple files and different line numbers""" + contexts = [ + Context( + relative_path="b.py", content="b content 2", start_line_number=20, end_line_number=25 + ), + Context( + relative_path="a.py", content="a content 2", start_line_number=10, end_line_number=15 + ), + Context( + relative_path="b.py", content="b content 1", start_line_number=5, end_line_number=10 + ), + Context( + relative_path="a.py", content="a content 1", start_line_number=1, end_line_number=5 + ), + ] + result = sort_contexts(contexts) + assert result[0].relative_path == "a.py" and result[0].start_line_number == 1 + assert result[1].relative_path == "a.py" and result[1].start_line_number == 10 + assert result[2].relative_path == "b.py" and result[2].start_line_number == 5 + assert result[3].relative_path == "b.py" and result[3].start_line_number == 20 + + +def test_sort_contexts_end_line_number_tiebreaker(): + """Test sorting uses end_line_number as tiebreaker when start_line_number is same""" + contexts = [ + Context( + relative_path="test.py", content="content 3", start_line_number=10, end_line_number=30 + ), + Context( + relative_path="test.py", content="content 1", start_line_number=10, end_line_number=15 + ), + Context( + relative_path="test.py", content="content 2", start_line_number=10, end_line_number=20 + ), + ] + result = sort_contexts(contexts) + assert result[0].end_line_number == 15 + assert result[1].end_line_number == 20 + assert result[2].end_line_number == 30