Skip to content
Merged

Dev #129

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
129 changes: 115 additions & 14 deletions prometheus/docker/base_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@
from typing import Optional, Sequence

import docker
import pexpect

from prometheus.exceptions.docker_exception import DockerException
from prometheus.utils.logger_manager import get_thread_logger


Expand All @@ -18,6 +20,8 @@ class BaseContainer(ABC):
containers. It handles container lifecycle operations including building images, starting
containers, updating files, and cleanup. The class is designed to be extended for specific
container implementations that specifies the Dockerfile, how to build and how to run the test.

Now supports persistent shell for maintaining command execution context.
"""

client: docker.DockerClient = docker.from_env()
Expand All @@ -27,6 +31,7 @@ class BaseContainer(ABC):
project_path: Path
timeout: int = 300 # Timeout for commands in seconds
logger: logging.Logger
shell: Optional[pexpect.spawn] = None # Persistent shell

def __init__(
self,
Expand Down Expand Up @@ -56,6 +61,7 @@ def __init__(
self._logger.debug(f"Using workdir: {self.workdir}")

self.container = None
self.shell = None

@abstractmethod
def get_dockerfile_content(self) -> str:
Expand Down Expand Up @@ -103,6 +109,7 @@ def start_container(self):
"""Start a Docker container from the built image.

Starts a detached container with TTY enabled and mounts the Docker socket.
Also initializes the persistent shell.
"""
self._logger.info(f"Starting container from image {self.tag_name}")
self.container = self.client.containers.run(
Expand All @@ -114,6 +121,50 @@ def start_container(self):
volumes={"/var/run/docker.sock": {"bind": "/var/run/docker.sock", "mode": "rw"}},
)

# Initialize persistent shell
self._start_persistent_shell()

def _start_persistent_shell(self):
"""Start a persistent bash shell inside the container using pexpect."""
if not self.container:
self._logger.error("Container must be started before initializing shell")
return

self._logger.info("Starting persistent shell for interactive mode...")
try:
command = f"docker exec -it {self.container.id} /bin/bash"
self.shell = pexpect.spawn(command, encoding="utf-8", timeout=self.timeout)

# Wait for the initial shell prompt
self.shell.expect([r"\$", r"#"], timeout=60)

self._logger.info("Persistent shell is ready")
except pexpect.exceptions.TIMEOUT:
self._logger.error(
"Timeout waiting for shell prompt. The container might be slow to start or misconfigured."
)
if self.shell:
self.shell.close(force=True)
self.shell = None
raise DockerException("Timeout waiting for shell prompt.")
except Exception as e:
self._logger.error(f"Failed to start persistent shell: {e}")
if self.shell:
self.shell.close(force=True)
self.shell = None
raise DockerException(f"Failed to start persistent shell: {e}")

def _restart_shell_if_needed(self):
"""Restart the shell if it's not alive."""
if not self.shell or not self.shell.isalive():
self._logger.warning("Shell not found or died. Attempting to restart...")
if self.shell:
self.shell.close(force=True)
self._start_persistent_shell()

if self.shell is None:
raise DockerException("Failed to start or restart the persistent shell.")

def is_running(self) -> bool:
return bool(self.container)

Expand Down Expand Up @@ -156,6 +207,7 @@ def update_files(
self._logger.info("Files updated successfully")

def run_build(self) -> str:
"""Run build commands and return combined output."""
if not self.build_commands:
self._logger.error("No build commands defined")
return ""
Expand All @@ -167,6 +219,7 @@ def run_build(self) -> str:
return command_output

def run_test(self) -> str:
"""Run test commands and return combined output."""
if not self.test_commands:
self._logger.error("No test commands defined")
return ""
Expand All @@ -178,30 +231,70 @@ def run_test(self) -> str:
return command_output

def execute_command(self, command: str) -> str:
"""Execute a command in the running container.
"""Execute a command in the running container using persistent shell.

Args:
command: Command to execute in the container.

Returns:
str: Output of the command as a string.
str: Output of the command.
"""
timeout_msg = f"""
self._logger.debug(f"Executing command: {command}")

# Ensure shell is available
self._restart_shell_if_needed()

# Unique marker to identify command completion and exit code
marker = "---CMD_DONE---"
full_command = command.strip()
marker_command = f"echo {marker}$?"

try:
self.shell.sendline(full_command)
self.shell.sendline(marker_command)

# Wait for the marker with exit code
self.shell.expect(marker + r"(\d+)", timeout=self.timeout)
exit_code = int(self.shell.match.group(1))

# Get the output before the marker
output_before_marker = self.shell.before

# Clean up the output by removing command echoes
all_lines = output_before_marker.splitlines()
clean_lines = []
for line in all_lines:
stripped_line = line.strip()
# Ignore the line if it's an echo of our commands
if (
stripped_line != full_command
and marker_command not in stripped_line
and line not in ["\x1b[?2004l", "\x1b[?2004h"]
):
clean_lines.append(line)

cleaned_output = "\n".join(clean_lines).strip()

# Wait for the next shell prompt to ensure the shell is ready
self.shell.expect([r"\$", r"#"], timeout=10)

self._logger.debug(f"Command exit code: {exit_code}")
self._logger.debug(f"Command output:\n{cleaned_output}")

return cleaned_output

except pexpect.exceptions.TIMEOUT:
timeout_msg = f"""
*******************************************************************************
{command} timeout after {self.timeout} seconds
*******************************************************************************
"""
bash_cmd = ["/bin/bash", "-lc", command]
full_cmd = ["timeout", "-k", "5", f"{self.timeout}s", *bash_cmd]
self._logger.debug(f"Running command in container: {command}")
exec_result = self.container.exec_run(full_cmd, workdir=self.workdir)
exec_result_str = exec_result.output.decode("utf-8")

if exec_result.exit_code in (124, 137):
exec_result_str += timeout_msg
self._logger.error(f"Command '{command}' timed out after {self.timeout} seconds")
partial_output = getattr(self.shell, "before", "")
return f"Command '{command}' timed out after {self.timeout} seconds. Partial output:\n{partial_output}{timeout_msg}"

self._logger.debug(f"Command output:\n{exec_result_str}")
return exec_result_str
except Exception as e:
raise DockerException(f"Error executing command '{command}': {e}")

def reset_repository(self):
"""Reset the git repository in the container to a clean state."""
Expand All @@ -212,10 +305,18 @@ def reset_repository(self):
def cleanup(self):
"""Clean up container resources and temporary files.

Stops and removes the container, removes the Docker image,
Stops the persistent shell, stops and removes the container, removes the Docker image,
and deletes temporary project files.
"""
self._logger.info("Cleaning up container and temporary files")

# Close persistent shell first
if self.shell and self.shell.isalive():
self._logger.info("Closing persistent shell...")
self.shell.close(force=True)
self.shell = None

self._logger.info("Cleaning up container and temporary files")
if self.container:
self.container.stop(timeout=10)
self.container.remove(force=True)
Expand Down
4 changes: 4 additions & 0 deletions prometheus/exceptions/docker_exception.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
class DockerException(Exception):
"""Base class for Docker-related exceptions."""

pass
Original file line number Diff line number Diff line change
Expand Up @@ -53,4 +53,7 @@ def __call__(self, state: IssueVerifiedBugState):

return {
"reproducing_test_fail_log": output_state["reproducing_test_fail_log"],
"final_candidate_patches": [state["edit_patch"]]
if not bool(output_state["reproducing_test_fail_log"])
else [],
}
18 changes: 8 additions & 10 deletions prometheus/lang_graph/nodes/final_patch_selection_node.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
from typing import Sequence
from typing import Dict, Sequence

from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.prompts import ChatPromptTemplate
from pydantic import BaseModel, Field

from prometheus.lang_graph.subgraphs.issue_not_verified_bug_state import IssueNotVerifiedBugState
from prometheus.utils.issue_util import format_issue_info
from prometheus.utils.logger_manager import get_thread_logger

Expand Down Expand Up @@ -120,8 +119,8 @@ class FinalPatchSelectionNode:
{patches}
"""

def __init__(self, model: BaseChatModel, max_retries: int = 2):
self.max_retries = max_retries
def __init__(self, model: BaseChatModel, candidate_patch_key: str):
self.candidate_patch_key = candidate_patch_key
prompt = ChatPromptTemplate.from_messages(
[("system", self.SYS_PROMPT), ("human", "{human_prompt}")]
)
Expand All @@ -130,7 +129,7 @@ def __init__(self, model: BaseChatModel, max_retries: int = 2):
self._logger, file_handler = get_thread_logger(__name__)
self.majority_voting_times = 10

def format_human_message(self, patches: Sequence[str], state: IssueNotVerifiedBugState):
def format_human_message(self, patches: Sequence[str], state: Dict):
patches_str = ""
for index, patch in enumerate(patches):
patches_str += f"Patch at index {index}:\n"
Expand All @@ -148,12 +147,11 @@ def format_human_message(self, patches: Sequence[str], state: IssueNotVerifiedBu
patches=patches_str,
)

def __call__(self, state: IssueNotVerifiedBugState):
def __call__(self, state: Dict):
# Determine candidate patches
if "tested_patch_result" in state and state["tested_patch_result"]:
patches = [result.patch for result in state["tested_patch_result"] if result.passed]
else:
patches = state["deduplicated_patches"]
patches = state[self.candidate_patch_key]
self._logger.debug(f"Total candidate patches: {len(patches)}")
self._logger.debug(f"Candidate patches: {patches}")

# Handle the case with no candidate patches
if not patches:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ def __init__(
git_repo: GitRepository,
testing_patch_key: str,
is_testing_patch_list: bool,
return_str_patch: bool,
return_key: str,
):
self._logger, file_handler = get_thread_logger(__name__)
self.subgraph = GetPassRegressionTestPatchSubgraph(
Expand All @@ -30,6 +32,8 @@ def __init__(
self.git_repo = git_repo
self.testing_patch_key = testing_patch_key
self.is_testing_patch_list = is_testing_patch_list
self.return_str_patch = return_str_patch
self.return_key = return_key

def __call__(self, state: Dict):
self._logger.info("Enter get_pass_regression_test_patch_subgraph_node")
Expand All @@ -43,37 +47,37 @@ def __call__(self, state: Dict):

if not state["selected_regression_tests"]:
self._logger.info("No regression tests selected, skipping patch testing.")
return {
"tested_patch_result": [
TestedPatchResult(patch=patch, passed=True, regression_test_failure_log="")
for patch in testing_patch
]
}

try:
output_state = self.subgraph.invoke(
selected_regression_tests=state["selected_regression_tests"],
patches=testing_patch,
)
except GraphRecursionError:
# If the recursion limit is reached, return a failure result for each patch
self._logger.info("Recursion limit reached")
return {
"tested_patch_result": [
test_patch_results = [
TestedPatchResult(patch=patch, passed=True, regression_test_failure_log="")
for patch in testing_patch
]
else:
try:
output_state = self.subgraph.invoke(
selected_regression_tests=state["selected_regression_tests"],
patches=testing_patch,
)
self._logger.debug(f"tested_patch_result: {output_state['tested_patch_result']}")
test_patch_results = output_state["tested_patch_result"]
except GraphRecursionError:
# If the recursion limit is reached, return a failure result for each patch
self._logger.info("Recursion limit reached")
test_patch_results = [
TestedPatchResult(
patch=patch,
passed=False,
regression_test_failure_log="Fail to get regression test result. Please try again!",
)
for patch in state[self.testing_patch_key]
for patch in testing_patch
]
}
finally:
# Reset the git repository to its original state
self.git_repo.reset_repository()

self._logger.debug(f"tested_patch_result: {output_state['tested_patch_result']}")
finally:
# Reset the git repository to its original state
self.git_repo.reset_repository()

return {
"tested_patch_result": output_state["tested_patch_result"],
self.return_key: [result.patch for result in test_patch_results if result.passed]
if self.return_str_patch
else test_patch_results,
}
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ def __call__(self, state: IssueBugState):
issue_title=state["issue_title"],
issue_body=state["issue_body"],
issue_comments=state["issue_comments"],
number_of_candidate_patch=state["number_of_candidate_patch"],
run_regression_test=state["run_regression_test"],
run_existing_test=state["run_existing_test"],
reproduced_bug_file=state["reproduced_bug_file"],
Expand Down
Loading