From 955f746d8e74f27cff68cb30c9537f15ef6f79fd Mon Sep 17 00:00:00 2001 From: Yue Pan <79363355+dcloud347@users.noreply.github.com> Date: Tue, 9 Sep 2025 12:14:30 +0800 Subject: [PATCH 1/3] fix: Refactor patch handling and state management in bug verification workflow --- .../bug_fix_verification_subgraph_node.py | 3 + .../nodes/final_patch_selection_node.py | 18 ++--- ...ass_regression_test_patch_subgraph_node.py | 50 ++++++------ .../nodes/issue_verified_bug_subgraph_node.py | 1 + .../nodes/patch_normalization_node.py | 13 ++-- .../nodes/run_regression_tests_node.py | 8 +- .../run_regression_tests_structure_node.py | 75 ++++++++++++------ .../subgraphs/issue_not_verified_bug_state.py | 2 +- .../issue_not_verified_bug_subgraph.py | 10 ++- .../subgraphs/issue_verified_bug_state.py | 4 + .../subgraphs/issue_verified_bug_subgraph.py | 76 ++++++++++++++----- prometheus/tools/web_search.py | 2 + 12 files changed, 174 insertions(+), 88 deletions(-) diff --git a/prometheus/lang_graph/nodes/bug_fix_verification_subgraph_node.py b/prometheus/lang_graph/nodes/bug_fix_verification_subgraph_node.py index 8bdee47..e741785 100644 --- a/prometheus/lang_graph/nodes/bug_fix_verification_subgraph_node.py +++ b/prometheus/lang_graph/nodes/bug_fix_verification_subgraph_node.py @@ -53,4 +53,7 @@ def __call__(self, state: IssueVerifiedBugState): return { "reproducing_test_fail_log": output_state["reproducing_test_fail_log"], + "final_candidate_patches": [state["edit_patch"]] + if not bool(output_state["reproducing_test_fail_log"]) + else [], } diff --git a/prometheus/lang_graph/nodes/final_patch_selection_node.py b/prometheus/lang_graph/nodes/final_patch_selection_node.py index 1355449..a3150f1 100644 --- a/prometheus/lang_graph/nodes/final_patch_selection_node.py +++ b/prometheus/lang_graph/nodes/final_patch_selection_node.py @@ -1,10 +1,9 @@ -from typing import Sequence +from typing import Dict, Sequence from langchain_core.language_models.chat_models import BaseChatModel from langchain_core.prompts import ChatPromptTemplate from pydantic import BaseModel, Field -from prometheus.lang_graph.subgraphs.issue_not_verified_bug_state import IssueNotVerifiedBugState from prometheus.utils.issue_util import format_issue_info from prometheus.utils.logger_manager import get_thread_logger @@ -120,8 +119,8 @@ class FinalPatchSelectionNode: {patches} """ - def __init__(self, model: BaseChatModel, max_retries: int = 2): - self.max_retries = max_retries + def __init__(self, model: BaseChatModel, candidate_patch_key: str): + self.candidate_patch_key = candidate_patch_key prompt = ChatPromptTemplate.from_messages( [("system", self.SYS_PROMPT), ("human", "{human_prompt}")] ) @@ -130,7 +129,7 @@ def __init__(self, model: BaseChatModel, max_retries: int = 2): self._logger, file_handler = get_thread_logger(__name__) self.majority_voting_times = 10 - def format_human_message(self, patches: Sequence[str], state: IssueNotVerifiedBugState): + def format_human_message(self, patches: Sequence[str], state: Dict): patches_str = "" for index, patch in enumerate(patches): patches_str += f"Patch at index {index}:\n" @@ -148,12 +147,11 @@ def format_human_message(self, patches: Sequence[str], state: IssueNotVerifiedBu patches=patches_str, ) - def __call__(self, state: IssueNotVerifiedBugState): + def __call__(self, state: Dict): # Determine candidate patches - if "tested_patch_result" in state and state["tested_patch_result"]: - patches = [result.patch for result in state["tested_patch_result"] if result.passed] - else: - patches = state["deduplicated_patches"] + patches = state[self.candidate_patch_key] + self._logger.debug(f"Total candidate patches: {len(patches)}") + self._logger.debug(f"Candidate patches: {patches}") # Handle the case with no candidate patches if not patches: diff --git a/prometheus/lang_graph/nodes/get_pass_regression_test_patch_subgraph_node.py b/prometheus/lang_graph/nodes/get_pass_regression_test_patch_subgraph_node.py index 8c6d038..9aa346f 100644 --- a/prometheus/lang_graph/nodes/get_pass_regression_test_patch_subgraph_node.py +++ b/prometheus/lang_graph/nodes/get_pass_regression_test_patch_subgraph_node.py @@ -20,6 +20,8 @@ def __init__( git_repo: GitRepository, testing_patch_key: str, is_testing_patch_list: bool, + return_str_patch: bool, + return_key: str, ): self._logger, file_handler = get_thread_logger(__name__) self.subgraph = GetPassRegressionTestPatchSubgraph( @@ -30,6 +32,8 @@ def __init__( self.git_repo = git_repo self.testing_patch_key = testing_patch_key self.is_testing_patch_list = is_testing_patch_list + self.return_str_patch = return_str_patch + self.return_key = return_key def __call__(self, state: Dict): self._logger.info("Enter get_pass_regression_test_patch_subgraph_node") @@ -43,37 +47,37 @@ def __call__(self, state: Dict): if not state["selected_regression_tests"]: self._logger.info("No regression tests selected, skipping patch testing.") - return { - "tested_patch_result": [ - TestedPatchResult(patch=patch, passed=True, regression_test_failure_log="") - for patch in testing_patch - ] - } - try: - output_state = self.subgraph.invoke( - selected_regression_tests=state["selected_regression_tests"], - patches=testing_patch, - ) - except GraphRecursionError: - # If the recursion limit is reached, return a failure result for each patch - self._logger.info("Recursion limit reached") - return { - "tested_patch_result": [ + test_patch_results = [ + TestedPatchResult(patch=patch, passed=True, regression_test_failure_log="") + for patch in testing_patch + ] + else: + try: + output_state = self.subgraph.invoke( + selected_regression_tests=state["selected_regression_tests"], + patches=testing_patch, + ) + self._logger.debug(f"tested_patch_result: {output_state['tested_patch_result']}") + test_patch_results = output_state["tested_patch_result"] + except GraphRecursionError: + # If the recursion limit is reached, return a failure result for each patch + self._logger.info("Recursion limit reached") + test_patch_results = [ TestedPatchResult( patch=patch, passed=False, regression_test_failure_log="Fail to get regression test result. Please try again!", ) - for patch in state[self.testing_patch_key] + for patch in testing_patch ] - } - finally: - # Reset the git repository to its original state - self.git_repo.reset_repository() - self._logger.debug(f"tested_patch_result: {output_state['tested_patch_result']}") + finally: + # Reset the git repository to its original state + self.git_repo.reset_repository() return { - "tested_patch_result": output_state["tested_patch_result"], + self.return_key: [result.patch for result in test_patch_results if result.passed] + if self.return_str_patch + else test_patch_results, } diff --git a/prometheus/lang_graph/nodes/issue_verified_bug_subgraph_node.py b/prometheus/lang_graph/nodes/issue_verified_bug_subgraph_node.py index 99e400c..ff10071 100644 --- a/prometheus/lang_graph/nodes/issue_verified_bug_subgraph_node.py +++ b/prometheus/lang_graph/nodes/issue_verified_bug_subgraph_node.py @@ -39,6 +39,7 @@ def __call__(self, state: IssueBugState): issue_title=state["issue_title"], issue_body=state["issue_body"], issue_comments=state["issue_comments"], + number_of_candidate_patch=state["number_of_candidate_patch"], run_regression_test=state["run_regression_test"], run_existing_test=state["run_existing_test"], reproduced_bug_file=state["reproduced_bug_file"], diff --git a/prometheus/lang_graph/nodes/patch_normalization_node.py b/prometheus/lang_graph/nodes/patch_normalization_node.py index 12d8471..ee2a2c4 100644 --- a/prometheus/lang_graph/nodes/patch_normalization_node.py +++ b/prometheus/lang_graph/nodes/patch_normalization_node.py @@ -9,7 +9,6 @@ from dataclasses import dataclass from typing import Dict, List, Sequence -from prometheus.lang_graph.subgraphs.issue_not_verified_bug_state import IssueNotVerifiedBugState from prometheus.utils.logger_manager import get_thread_logger @@ -37,8 +36,10 @@ class PatchNormalizationNode: Simplified approach without complex voting mechanisms. """ - def __init__(self): + def __init__(self, input_patch_key: str, return_key: str): self._logger, file_handler = get_thread_logger(__name__) + self.return_key = return_key + self.input_patch_key = input_patch_key def normalize_patch(self, raw_patch: str) -> str: """Normalize patch content for deduplication @@ -135,17 +136,17 @@ def deduplicate_patches(self, patches: Sequence[str]) -> List[NormalizedPatch]: return deduplicated - def __call__(self, state: IssueNotVerifiedBugState) -> Dict: + def __call__(self, state: Dict) -> Dict: """Node call interface Process edit_patches in state, return normalized, deduplicated patches """ - patches = state.get("edit_patches", []) + patches = state.get(self.input_patch_key, []) if not patches: self._logger.warning("No patches found to process") return { - "deduplicated_patches": [], + self.return_key: [], } self._logger.info(f"Starting to process {len(patches)} patches") @@ -161,5 +162,5 @@ def __call__(self, state: IssueNotVerifiedBugState) -> Dict: ) return { - "deduplicated_patches": deduplicated_patches, + self.return_key: deduplicated_patches, } diff --git a/prometheus/lang_graph/nodes/run_regression_tests_node.py b/prometheus/lang_graph/nodes/run_regression_tests_node.py index 5a8e7dd..390c800 100644 --- a/prometheus/lang_graph/nodes/run_regression_tests_node.py +++ b/prometheus/lang_graph/nodes/run_regression_tests_node.py @@ -33,8 +33,8 @@ class RunRegressionTestsNode: - Do NOT modify the core logic or parameters of the commands! - Do NOT attempt to fix bugs or modify test logic! - You MUST RUN ALL THE TESTS EXACTLY AS PROVIDED! -- Do NOT stop util all tests are run! -- DO NOT ASSUME ALL DEPENDENCIES ARE INSTALLED.! +- Do NOT stop until all tests are run! +- DO NOT ASSUME ALL DEPENDENCIES ARE INSTALLED! REMINDER: - Install dependencies if needed! @@ -76,7 +76,9 @@ def _init_tools(self): def format_human_message(self, state: RunRegressionTestsState) -> HumanMessage: return HumanMessage( - self.HUMAN_PROMPT.format(selected_regression_tests=state["selected_regression_tests"]) + self.HUMAN_PROMPT.format( + selected_regression_tests="\n".join(state["selected_regression_tests"]) + ) ) def __call__(self, state: RunRegressionTestsState): diff --git a/prometheus/lang_graph/nodes/run_regression_tests_structure_node.py b/prometheus/lang_graph/nodes/run_regression_tests_structure_node.py index 1248c0e..b8f3105 100644 --- a/prometheus/lang_graph/nodes/run_regression_tests_structure_node.py +++ b/prometheus/lang_graph/nodes/run_regression_tests_structure_node.py @@ -14,7 +14,7 @@ class RunRegressionTestsStructureOutput(BaseModel): description="List of test identifier of regression tests that passed (e.g., class name and method name)" ) regression_test_fail_log: str = Field( - description="If any test failed, contains the exact and complete test FAILURE log. Otherwise empty string" + description="Complete failure log if any test failed. Empty string if all tests passed or no tests were run." ) total_tests_run: int = Field( description="Total number of tests run, including both passed and failed tests, or 0 if no tests were run", @@ -31,53 +31,78 @@ class RunRegressionTestsStructuredNode: - Test summary showing "passed" or "PASSED" - Warning is ok - No "FAILURES" section -2. If a test fails, capture the exact and complete failure output. Otherwise empty string for failure log +2. If ANY test failed, capture the complete failure output; otherwise leave failure log empty 3. Return the exact test identifiers that passed -4. Count the total number of tests run. Only count tests that were actually executed! If tests were unable to run due to an error, do not count them! +4. Count the total number of tests run. Only count tests that were actually executed (Both Passed and Failed)! Regardless of pass or fail, count them if they were run. Return: - passed_regression_tests: List of test identifier of regression tests that passed (e.g., class name and method name) - regression_test_fail_log: empty string if all tests pass, exact complete test output if a test fails -- total_tests_run: Total number of tests run, including both passed and failed tests. If you can't find any test run, return 0 +- total_tests_run: Total number of tests run, including BOTH PASSED and FAILED tests. If you can't find any test run, return 0 Example 1: + + ``` Run Regression Tests Logs: ============================= test session starts ============================== -collecting ... collected 7 items - -test_file_operation.py::test_create_and_read_file PASSED [ 14%] -test_file_operation.py::test_read_file_nonexistent PASSED [ 28%] -test_file_operation.py::test_read_file_with_line_numbers PASSED [ 42%] -test_file_operation.py::test_delete PASSED [ 57%] -test_file_operation.py::test_delete_nonexistent PASSED [ 71%] -test_file_operation.py::test_edit_file PASSED [ 85%] -test_file_operation.py::test_create_file_already_exists PASSED [100%] - -============================== 7 passed in 1.53s =============================== +collected 6 items + +test_patch_util.py::test_get_updated_files_empty_diff PASSED [ 16%] +test_patch_util.py::test_get_updated_files_added_only FAILED [ 33%] +tests/utils/test_patch_util.py:13 (test_get_updated_files_added_only) +0 != 1 + +Expected:1 +Actual:0 +<点击以查看差异> + +def test_get_updated_files_added_only(): + diff = \""" + diff --git a/new_file.txt b/new_file.txt + new file mode 100644 + index 0000000..1234567 + --- /dev/null + +++ b/new_file.txt + @@ -0,0 +1 @@ + +New content + \""" + added, modified, removed = get_updated_files(diff) + assert len(added) == 1 +> assert len(modified) == 1 +E assert 0 == 1 +E + where 0 = len([]) + +test_patch_util.py:26: AssertionError + +test_patch_util.py::test_get_updated_files_modified_only PASSED [ 50%] +test_patch_util.py::test_get_updated_files_removed_only PASSED [ 66%] +test_patch_util.py::test_get_updated_files_multiple_changes PASSED [ 83%] +test_patch_util.py::test_get_updated_files_with_subfolders PASSED [100%] + +========================= 1 failed, 5 passed in 0.03s ========================== ``` Example 1 Output: {{ "passed_regression_tests": [ - "test_file_operation.py::test_create_and_read_file", - "test_file_operation.py::test_read_file_nonexistent", - "test_file_operation.py::test_read_file_with_line_numbers", - "test_file_operation.py::test_delete", - "test_file_operation.py::test_delete_nonexistent", - "test_file_operation.py::test_edit_file", - "test_file_operation.py::test_create_file_already_exists" + "test_patch_util.py::test_get_updated_files_empty_diff", + "test_patch_util.py::test_get_updated_files_modified_only", + "test_patch_util.py::test_get_updated_files_removed_only", + "test_patch_util.py::test_get_updated_files_multiple_changes", + "test_patch_util.py::test_get_updated_files_with_subfolders" ], - "reproducing_test_fail_log": "", - "total_tests_run": 7 + "regression_test_fail_log": "test_patch_util.py::test_get_updated_files_added_only FAILED [ 33%] tests/utils/test_patch_util.py:13 (test_get_updated_files_added_only) 0 != 1 Expected:1 Actual:0 <点击以查看差异> def test_get_updated_files_added_only(): diff = \\\"\"\" diff --git a/new_file.txt b/new_file.txt new file mode 100644 index 0000000..1234567 --- /dev/null +++ b/new_file.txt @@ -0,0 +1 @@ +New content \\\"\"\" added, modified, removed = get_updated_files(diff) assert len(added) == 1 > assert len(modified) == 1 E assert 0 == 1 E + where 0 = len([]) test_patch_util.py:26: AssertionError", + "total_tests_run": 6 }} + Important: - Only look at test pass/fail status - A single failing test means the test is not passing - Include complete test output in failure log - Do Not output any log when where is no test executed. ONLY output the log exact and complete test FAILURE log when test failure! -- Do not forget to return the total number of tests run! If tests were unable to run due to an error, do not count them! +- Only include tests that actually executed (Both Passed and Failed). If tests couldn't run due to setup errors, don't count them. - If you can't find any test run, return 0 for total number of tests run! """ HUMAN_PROMPT = """ diff --git a/prometheus/lang_graph/subgraphs/issue_not_verified_bug_state.py b/prometheus/lang_graph/subgraphs/issue_not_verified_bug_state.py index 6cb5ace..5e61f3b 100644 --- a/prometheus/lang_graph/subgraphs/issue_not_verified_bug_state.py +++ b/prometheus/lang_graph/subgraphs/issue_not_verified_bug_state.py @@ -24,7 +24,7 @@ class IssueNotVerifiedBugState(TypedDict): edit_patches: Annotated[Sequence[str], add] - deduplicated_patches: Sequence[str] + final_candidate_patches: Sequence[str] final_patch: str diff --git a/prometheus/lang_graph/subgraphs/issue_not_verified_bug_subgraph.py b/prometheus/lang_graph/subgraphs/issue_not_verified_bug_subgraph.py index 43a1753..15ea23e 100644 --- a/prometheus/lang_graph/subgraphs/issue_not_verified_bug_subgraph.py +++ b/prometheus/lang_graph/subgraphs/issue_not_verified_bug_subgraph.py @@ -65,19 +65,23 @@ def __init__( reset_edit_messages_node = ResetMessagesNode("edit_messages") # Patch Normalization Node - patch_normalization_node = PatchNormalizationNode() + patch_normalization_node = PatchNormalizationNode("edit_patches", "final_candidate_patches") # Get pass regression test patch subgraph node get_pass_regression_test_patch_subgraph_node = GetPassRegressionTestPatchSubgraphNode( model=base_model, container=container, git_repo=git_repo, - testing_patch_key="deduplicated_patches", + testing_patch_key="final_candidate_patches", is_testing_patch_list=True, + return_str_patch=True, + return_key="final_candidate_patches", ) # Final patch selection node - final_patch_selection_node = FinalPatchSelectionNode(advanced_model) + final_patch_selection_node = FinalPatchSelectionNode( + advanced_model, "final_candidate_patches" + ) workflow = StateGraph(IssueNotVerifiedBugState) diff --git a/prometheus/lang_graph/subgraphs/issue_verified_bug_state.py b/prometheus/lang_graph/subgraphs/issue_verified_bug_state.py index a31f69e..39fca82 100644 --- a/prometheus/lang_graph/subgraphs/issue_verified_bug_state.py +++ b/prometheus/lang_graph/subgraphs/issue_verified_bug_state.py @@ -1,3 +1,4 @@ +from operator import add from typing import Annotated, Mapping, Sequence, TypedDict from langchain_core.messages import BaseMessage @@ -15,6 +16,9 @@ class IssueVerifiedBugState(TypedDict): max_refined_query_loop: int refined_query: str + number_of_candidate_patch: int + final_candidate_patches: Annotated[Sequence[TestedPatchResult], add] + run_existing_test: bool run_regression_test: bool diff --git a/prometheus/lang_graph/subgraphs/issue_verified_bug_subgraph.py b/prometheus/lang_graph/subgraphs/issue_verified_bug_subgraph.py index de1a1c0..d920a19 100644 --- a/prometheus/lang_graph/subgraphs/issue_verified_bug_subgraph.py +++ b/prometheus/lang_graph/subgraphs/issue_verified_bug_subgraph.py @@ -14,6 +14,7 @@ from prometheus.lang_graph.nodes.context_retrieval_subgraph_node import ContextRetrievalSubgraphNode from prometheus.lang_graph.nodes.edit_message_node import EditMessageNode from prometheus.lang_graph.nodes.edit_node import EditNode +from prometheus.lang_graph.nodes.final_patch_selection_node import FinalPatchSelectionNode from prometheus.lang_graph.nodes.get_pass_regression_test_patch_subgraph_node import ( GetPassRegressionTestPatchSubgraphNode, ) @@ -23,6 +24,7 @@ from prometheus.lang_graph.nodes.issue_bug_analyzer_node import IssueBugAnalyzerNode from prometheus.lang_graph.nodes.issue_bug_context_message_node import IssueBugContextMessageNode from prometheus.lang_graph.nodes.noop_node import NoopNode +from prometheus.lang_graph.nodes.patch_normalization_node import PatchNormalizationNode from prometheus.lang_graph.nodes.run_existing_tests_subgraph_node import ( RunExistingTestsSubgraphNode, ) @@ -98,15 +100,16 @@ def __init__( git_diff_node = GitDiffNode(git_repo, "edit_patch") git_reset_node = GitResetNode(git_repo) - noop_node = NoopNode() - # Phase 5: Run Regression Tests if available + get_pass_regression_test_patch_branch_node = NoopNode() get_pass_regression_test_patch_subgraph_node = GetPassRegressionTestPatchSubgraphNode( model=base_model, container=container, git_repo=git_repo, testing_patch_key="edit_patch", is_testing_patch_list=False, + return_str_patch=False, + return_key="tested_patch_result", ) # Phase 6: Re-run test case that reproduces the bug @@ -114,6 +117,18 @@ def __init__( base_model, container, git_repo ) + # Select the best patch if the number of passed reproduction test patches >= candidate patches + final_patch_selection_branch_node = NoopNode() + + # Patch Normalization Node + patch_normalization_node = PatchNormalizationNode( + "final_candidate_patches", "final_candidate_patches" + ) + + final_patch_selection_node = FinalPatchSelectionNode( + advanced_model, "final_candidate_patches" + ) + # Phase 7: Optionally run existing tests run_existing_tests_branch_node = NoopNode() run_existing_tests_subgraph_node = RunExistingTestsSubgraphNode( @@ -140,14 +155,21 @@ def __init__( workflow.add_node("edit_tools", edit_tools) workflow.add_node("git_diff_node", git_diff_node) workflow.add_node("git_reset_node", git_reset_node) - workflow.add_node("noop_node", noop_node) + workflow.add_node("bug_fix_verification_subgraph_node", bug_fix_verification_subgraph_node) + + workflow.add_node("final_patch_selection_branch_node", final_patch_selection_branch_node) + workflow.add_node("patch_normalization_node", patch_normalization_node) + workflow.add_node("final_patch_selection_node", final_patch_selection_node) + + workflow.add_node( + "get_pass_regression_test_patch_branch_node", get_pass_regression_test_patch_branch_node + ) workflow.add_node( "get_pass_regression_test_patch_subgraph_node", get_pass_regression_test_patch_subgraph_node, ) - workflow.add_node("bug_fix_verification_subgraph_node", bug_fix_verification_subgraph_node) workflow.add_node("run_existing_tests_branch_node", run_existing_tests_branch_node) workflow.add_node("run_existing_tests_subgraph_node", run_existing_tests_subgraph_node) @@ -180,33 +202,52 @@ def __init__( workflow.add_conditional_edges( "git_reset_node", lambda state: bool(state["edit_patch"]), - {True: "noop_node", False: "issue_bug_analyzer_message_node"}, + {True: "bug_fix_verification_subgraph_node", False: "issue_bug_analyzer_message_node"}, ) + # If reproduction test fails, loop back to reanalyze the bug workflow.add_conditional_edges( - "noop_node", + "bug_fix_verification_subgraph_node", + lambda state: bool(state["reproducing_test_fail_log"]), + { + True: "issue_bug_analyzer_message_node", + False: "final_patch_selection_branch_node", + }, + ) + + # If the number of passed reproduction test patches >= candidate patches, select the best patch + workflow.add_conditional_edges( + "final_patch_selection_branch_node", + lambda state: len(state["final_candidate_patches"]) + >= state["number_of_candidate_patch"], + { + True: "patch_normalization_node", + False: "get_pass_regression_test_patch_branch_node", + }, + ) + + workflow.add_edge("patch_normalization_node", "final_patch_selection_node") + workflow.add_edge("final_patch_selection_node", END) + + # If selected regression tests are required to run, run them + workflow.add_conditional_edges( + "get_pass_regression_test_patch_branch_node", lambda state: state["run_regression_test"], { True: "get_pass_regression_test_patch_subgraph_node", - False: "bug_fix_verification_subgraph_node", + False: "run_existing_tests_branch_node", }, ) + workflow.add_conditional_edges( "get_pass_regression_test_patch_subgraph_node", lambda state: state["tested_patch_result"][0].passed, { - True: "bug_fix_verification_subgraph_node", + True: "run_existing_tests_branch_node", False: "issue_bug_analyzer_message_node", }, ) - # If test still fails, loop back to reanalyze the bug - workflow.add_conditional_edges( - "bug_fix_verification_subgraph_node", - lambda state: bool(state["reproducing_test_fail_log"]), - {True: "issue_bug_analyzer_message_node", False: "run_existing_tests_branch_node"}, - ) - # Optionally run existing tests suite workflow.add_conditional_edges( "run_existing_tests_branch_node", @@ -229,20 +270,21 @@ def invoke( issue_title: str, issue_body: str, issue_comments: Sequence[Mapping[str, str]], + number_of_candidate_patch: int, run_regression_test: bool, run_existing_test: bool, reproduced_bug_file: str, reproduced_bug_commands: Sequence[str], reproduced_bug_patch: str, selected_regression_tests: Sequence[str], - recursion_limit: int = 200, ): - config = {"recursion_limit": recursion_limit} + config = {"recursion_limit": (number_of_candidate_patch + 3) * 75 + 75} input_state = { "issue_title": issue_title, "issue_body": issue_body, "issue_comments": issue_comments, + "number_of_candidate_patch": number_of_candidate_patch, "run_regression_test": run_regression_test, "run_existing_test": run_existing_test, "reproduced_bug_file": reproduced_bug_file, diff --git a/prometheus/tools/web_search.py b/prometheus/tools/web_search.py index cfd38b2..ecefe3f 100644 --- a/prometheus/tools/web_search.py +++ b/prometheus/tools/web_search.py @@ -116,6 +116,8 @@ def web_search( "readthedocs.org", ] + self._logger.debug(f"Query: {query}") + # Call the Tavily API try: response = self.tavily_client.search( From ab9041d2a1ef570bee36933b91dc2de7e31e09e8 Mon Sep 17 00:00:00 2001 From: Yue Pan <79363355+dcloud347@users.noreply.github.com> Date: Thu, 11 Sep 2025 00:02:50 +0800 Subject: [PATCH 2/3] feat: Add persistent shell support for command execution in Docker containers --- prometheus/docker/base_container.py | 129 ++++++++++++++++++--- prometheus/exceptions/docker_exception.py | 4 + tests/docker/test_base_container.py | 135 +++++++++++++++++----- 3 files changed, 227 insertions(+), 41 deletions(-) create mode 100644 prometheus/exceptions/docker_exception.py diff --git a/prometheus/docker/base_container.py b/prometheus/docker/base_container.py index 8cc5c78..c4a99ad 100644 --- a/prometheus/docker/base_container.py +++ b/prometheus/docker/base_container.py @@ -7,7 +7,9 @@ from typing import Optional, Sequence import docker +import pexpect +from prometheus.exceptions.docker_exception import DockerException from prometheus.utils.logger_manager import get_thread_logger @@ -18,6 +20,8 @@ class BaseContainer(ABC): containers. It handles container lifecycle operations including building images, starting containers, updating files, and cleanup. The class is designed to be extended for specific container implementations that specifies the Dockerfile, how to build and how to run the test. + + Now supports persistent shell for maintaining command execution context. """ client: docker.DockerClient = docker.from_env() @@ -27,6 +31,7 @@ class BaseContainer(ABC): project_path: Path timeout: int = 300 # Timeout for commands in seconds logger: logging.Logger + shell: Optional[pexpect.spawn] = None # Persistent shell def __init__( self, @@ -56,6 +61,7 @@ def __init__( self._logger.debug(f"Using workdir: {self.workdir}") self.container = None + self.shell = None @abstractmethod def get_dockerfile_content(self) -> str: @@ -103,6 +109,7 @@ def start_container(self): """Start a Docker container from the built image. Starts a detached container with TTY enabled and mounts the Docker socket. + Also initializes the persistent shell. """ self._logger.info(f"Starting container from image {self.tag_name}") self.container = self.client.containers.run( @@ -114,6 +121,50 @@ def start_container(self): volumes={"/var/run/docker.sock": {"bind": "/var/run/docker.sock", "mode": "rw"}}, ) + # Initialize persistent shell + self._start_persistent_shell() + + def _start_persistent_shell(self): + """Start a persistent bash shell inside the container using pexpect.""" + if not self.container: + self._logger.error("Container must be started before initializing shell") + return + + self._logger.info("Starting persistent shell for interactive mode...") + try: + command = f"docker exec -it {self.container.id} /bin/bash" + self.shell = pexpect.spawn(command, encoding="utf-8", timeout=self.timeout) + + # Wait for the initial shell prompt + self.shell.expect([r"\$", r"#"], timeout=60) + + self._logger.info("Persistent shell is ready") + except pexpect.exceptions.TIMEOUT: + self._logger.error( + "Timeout waiting for shell prompt. The container might be slow to start or misconfigured." + ) + if self.shell: + self.shell.close(force=True) + self.shell = None + raise DockerException("Timeout waiting for shell prompt.") + except Exception as e: + self._logger.error(f"Failed to start persistent shell: {e}") + if self.shell: + self.shell.close(force=True) + self.shell = None + raise DockerException(f"Failed to start persistent shell: {e}") + + def _restart_shell_if_needed(self): + """Restart the shell if it's not alive.""" + if not self.shell or not self.shell.isalive(): + self._logger.warning("Shell not found or died. Attempting to restart...") + if self.shell: + self.shell.close(force=True) + self._start_persistent_shell() + + if self.shell is None: + raise DockerException("Failed to start or restart the persistent shell.") + def is_running(self) -> bool: return bool(self.container) @@ -156,6 +207,7 @@ def update_files( self._logger.info("Files updated successfully") def run_build(self) -> str: + """Run build commands and return combined output.""" if not self.build_commands: self._logger.error("No build commands defined") return "" @@ -167,6 +219,7 @@ def run_build(self) -> str: return command_output def run_test(self) -> str: + """Run test commands and return combined output.""" if not self.test_commands: self._logger.error("No test commands defined") return "" @@ -178,30 +231,70 @@ def run_test(self) -> str: return command_output def execute_command(self, command: str) -> str: - """Execute a command in the running container. + """Execute a command in the running container using persistent shell. Args: command: Command to execute in the container. Returns: - str: Output of the command as a string. + str: Output of the command. """ - timeout_msg = f""" + self._logger.debug(f"Executing command: {command}") + + # Ensure shell is available + self._restart_shell_if_needed() + + # Unique marker to identify command completion and exit code + marker = "---CMD_DONE---" + full_command = command.strip() + marker_command = f"echo {marker}$?" + + try: + self.shell.sendline(full_command) + self.shell.sendline(marker_command) + + # Wait for the marker with exit code + self.shell.expect(marker + r"(\d+)", timeout=self.timeout) + exit_code = int(self.shell.match.group(1)) + + # Get the output before the marker + output_before_marker = self.shell.before + + # Clean up the output by removing command echoes + all_lines = output_before_marker.splitlines() + clean_lines = [] + for line in all_lines: + stripped_line = line.strip() + # Ignore the line if it's an echo of our commands + if ( + stripped_line != full_command + and marker_command not in stripped_line + and line not in ["\x1b[?2004l", "\x1b[?2004h"] + ): + clean_lines.append(line) + + cleaned_output = "\n".join(clean_lines).strip() + + # Wait for the next shell prompt to ensure the shell is ready + self.shell.expect([r"\$", r"#"], timeout=10) + + self._logger.debug(f"Command exit code: {exit_code}") + self._logger.debug(f"Command output:\n{cleaned_output}") + + return cleaned_output + + except pexpect.exceptions.TIMEOUT: + timeout_msg = f""" ******************************************************************************* {command} timeout after {self.timeout} seconds ******************************************************************************* """ - bash_cmd = ["/bin/bash", "-lc", command] - full_cmd = ["timeout", "-k", "5", f"{self.timeout}s", *bash_cmd] - self._logger.debug(f"Running command in container: {command}") - exec_result = self.container.exec_run(full_cmd, workdir=self.workdir) - exec_result_str = exec_result.output.decode("utf-8") - - if exec_result.exit_code in (124, 137): - exec_result_str += timeout_msg + self._logger.error(f"Command '{command}' timed out after {self.timeout} seconds") + partial_output = getattr(self.shell, "before", "") + return f"Command '{command}' timed out after {self.timeout} seconds. Partial output:\n{partial_output}{timeout_msg}" - self._logger.debug(f"Command output:\n{exec_result_str}") - return exec_result_str + except Exception as e: + raise DockerException(f"Error executing command '{command}': {e}") def reset_repository(self): """Reset the git repository in the container to a clean state.""" @@ -212,10 +305,18 @@ def reset_repository(self): def cleanup(self): """Clean up container resources and temporary files. - Stops and removes the container, removes the Docker image, + Stops the persistent shell, stops and removes the container, removes the Docker image, and deletes temporary project files. """ self._logger.info("Cleaning up container and temporary files") + + # Close persistent shell first + if self.shell and self.shell.isalive(): + self._logger.info("Closing persistent shell...") + self.shell.close(force=True) + self.shell = None + + self._logger.info("Cleaning up container and temporary files") if self.container: self.container.stop(timeout=10) self.container.remove(force=True) diff --git a/prometheus/exceptions/docker_exception.py b/prometheus/exceptions/docker_exception.py new file mode 100644 index 0000000..847ea55 --- /dev/null +++ b/prometheus/exceptions/docker_exception.py @@ -0,0 +1,4 @@ +class DockerException(Exception): + """Base class for Docker-related exceptions.""" + + pass diff --git a/tests/docker/test_base_container.py b/tests/docker/test_base_container.py index 410cf4f..b7ac403 100644 --- a/tests/docker/test_base_container.py +++ b/tests/docker/test_base_container.py @@ -67,16 +67,25 @@ def test_build_docker_image(container, mock_docker_client): ) -def test_start_container(container, mock_docker_client): +@patch("prometheus.docker.base_container.pexpect.spawn") +def test_start_container(mock_spawn, container, mock_docker_client): """Test starting Docker container""" - # Setup mock + # Setup mock for pexpect shell + mock_shell = Mock() + mock_spawn.return_value = mock_shell + mock_shell.expect.return_value = 0 # Simulate successful prompt match + + # Setup mock for docker client mock_containers = Mock() mock_docker_client.containers = mock_containers + mock_container = Mock() + mock_container.id = "test_container_id" + mock_containers.run.return_value = mock_container # Execute container.start_container() - # Verify + # Verify docker container run was called mock_containers.run.assert_called_once_with( container.tag_name, detach=True, @@ -86,6 +95,14 @@ def test_start_container(container, mock_docker_client): volumes={"/var/run/docker.sock": {"bind": "/var/run/docker.sock", "mode": "rw"}}, ) + # Verify pexpect shell was started + mock_spawn.assert_called_once_with( + f"docker exec -it {mock_container.id} /bin/bash", + encoding="utf-8", + timeout=container.timeout, + ) + mock_shell.expect.assert_called() + def test_is_running(container): """Test is_running status check""" @@ -101,8 +118,7 @@ def test_update_files(container, temp_project_dir): """Test updating files in container""" # Setup container.container = Mock() - mock_execute = Mock() - container.execute_command = mock_execute + container.execute_command = Mock() # Create test files test_file1 = temp_project_dir / "dir1" / "test1.txt" @@ -119,39 +135,59 @@ def test_update_files(container, temp_project_dir): container.update_files(temp_project_dir, updated_files, removed_files) # Verify - mock_execute.assert_has_calls( + container.execute_command.assert_has_calls( [call("rm dir3/old.txt"), call("mkdir -p dir1"), call("mkdir -p dir2")] ) assert container.container.put_archive.called -def test_execute_command(container): - """Test executing command in container""" - # Setup - mock_container = Mock() - mock_exec_result = Mock() - mock_exec_result.exit_code = 0 - mock_exec_result.output = b"command output" - mock_container.exec_run.return_value = mock_exec_result - container.container = mock_container +@patch("prometheus.docker.base_container.pexpect.spawn") +def test_execute_command(mock_spawn, container): + """Test executing command in container using persistent shell""" + # Setup mock shell + mock_shell = Mock() + mock_spawn.return_value = mock_shell + + # Setup container and shell + container.container = Mock() + container.container.id = "test_container_id" + container.shell = mock_shell + mock_shell.isalive.return_value = True + + # Mock the shell interactions + mock_shell.match = Mock() + mock_shell.match.group.return_value = "0" # Exit code 0 + mock_shell.before = "test command\ncommand output" + + # Execute + result = container.execute_command("test command") + + # Verify shell interactions + assert mock_shell.sendline.call_count == 2 # Command + marker command + mock_shell.expect.assert_called() + + # The result should contain the cleaned output + assert "command output" in result + + +def test_execute_command_with_mock(container): + """Test executing command with direct mocking""" + # Setup - directly mock the execute_command method + container.execute_command = Mock(return_value="mocked output") + container.container = Mock() # Execute result = container.execute_command("test command") # Verify - mock_container.exec_run.assert_called_once_with( - ["timeout", "-k", "5", "300s", "/bin/bash", "-lc", "test command"], - workdir=container.workdir, - ) - assert result == "command output" + container.execute_command.assert_called_once_with("test command") + assert result == "mocked output" def test_reset_repository(container): """Test container reset repository""" - # Setup - Mock the execute_command method of the container itself + # Setup - Mock the execute_command method container.execute_command = Mock(return_value="Command output") - - # Also ensure the container has a valid container attribute (even if it's not used in this method) container.container = Mock() # Execute @@ -159,22 +195,29 @@ def test_reset_repository(container): # Verify - Check that execute_command was called twice with the correct commands assert container.execute_command.call_count == 2 - - # Check the specific calls expected_calls = [call("git reset --hard"), call("git clean -fd")] container.execute_command.assert_has_calls(expected_calls, any_order=False) -def test_cleanup(container, mock_docker_client): +@patch("prometheus.docker.base_container.pexpect.spawn") +def test_cleanup(mock_spawn, container, mock_docker_client): """Test cleanup of container resources""" # Setup mock_container = Mock() container.container = mock_container + # Setup mock shell + mock_shell = Mock() + mock_shell.isalive.return_value = True + container.shell = mock_shell + # Execute container.cleanup() - # Verify + # Verify shell cleanup + mock_shell.close.assert_called_once_with(force=True) + + # Verify container cleanup mock_container.stop.assert_called_once_with(timeout=10) mock_container.remove.assert_called_once_with(force=True) mock_docker_client.images.remove.assert_called_once_with(container.tag_name, force=True) @@ -213,3 +256,41 @@ def test_run_test(container): # Verify output format expected_output = "$ pytest tests/\nTest passed\n" assert test_output == expected_output + + +def test_run_build_no_commands(container): + """Test run_build when no build commands are defined""" + container.build_commands = None + result = container.run_build() + assert result == "" + + +def test_run_test_no_commands(container): + """Test run_test when no test commands are defined""" + container.test_commands = None + result = container.run_test() + assert result == "" + + +@patch("prometheus.docker.base_container.pexpect.spawn") +def test_restart_shell_if_needed(mock_spawn, container): + """Test shell restart functionality""" + # Setup + mock_shell_dead = Mock() + mock_shell_dead.isalive.return_value = False + + mock_shell_new = Mock() + mock_shell_new.expect.return_value = 0 + mock_spawn.return_value = mock_shell_new + + container.container = Mock() + container.container.id = "test_container_id" + container.shell = mock_shell_dead + + # Execute + container._restart_shell_if_needed() + + # Verify old shell was closed and new one started + mock_shell_dead.close.assert_called_once_with(force=True) + mock_spawn.assert_called_once() + assert container.shell == mock_shell_new From 5811c33e5d4363b4f1cc14e5a674c3afd9879bc9 Mon Sep 17 00:00:00 2001 From: Yue Pan <79363355+dcloud347@users.noreply.github.com> Date: Thu, 11 Sep 2025 10:42:01 +0800 Subject: [PATCH 3/3] fix: Add pexpect dependency for improved process handling --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index 522821b..b45cebd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,6 +31,7 @@ dependencies = [ "tavily-python>=0.5.1", "langchain-mcp-adapters>=0.1.9", "httpx==0.28.1", + "pexpect==4.9.0" ] requires-python = ">= 3.11"