diff --git a/README.md b/README.md index deaa4a33..dea1117f 100644 --- a/README.md +++ b/README.md @@ -1,20 +1,89 @@ -![code coverage](https://github.com/Pantheon-temple/Prometheus/raw/coverage-badge/coverage.svg) +![Code Coverage](https://github.com/Pantheon-temple/Prometheus/raw/coverage-badge/coverage.svg) -Prometheus is a FastAPI backend service that performs intelligent codebase-level operations including answering questions, resolving issues, and reviewing pull requests. At its core, it implements a multi-agent approach governed by a state machine that ensures code quality through automated reviews, build verification, and test execution. +# Prometheus -Prometheus can be connected to other services provided in the `Pantheon-temple` orgnization, for example: -* Connect it to `Pantheon-temple/Argus-GitHub` to automatically pull the latest changes from your GitHub repository, answer/fix issues, and review pull requests. -* Connect it to `Pantheon-temple/Hermes` to have chat interface with your codebase (GitHub or local). -* Connect it your own service, by sending requests to the FastAPI endpoints. +Prometheus is a FastAPI-based backend service designed to perform intelligent codebase-level operations, including +answering questions, resolving issues, and reviewing pull requests. At its core, it implements a multi-agent approach +governed by a state machine to ensure code quality through automated reviews, build verification, and test execution. -# Quick start +## ๐Ÿš€ Features -In this project we use `docker-compose.yml`(for Linux), or `docker-compose.win_mac.yml` (for Windows or MacOS). You should update the `example_settings.toml` with the API keys, and then rename the file to `settings.toml`. +- **Codebase Analysis**: Answer questions about your codebase and provide insights. +- **Issue Resolution**: Automatically resolve issues in your repository. +- **Pull Request Reviews**: Perform intelligent reviews of pull requests to ensure code quality. 
+- **Multi-Agent System**: Uses a state machine to coordinate multiple agents for efficient task execution. +- **Integration with External Services**: Seamlessly connects with other services in the `Pantheon-temple` organization. -Now, simply run `docker compose up`, and you can access Promtheus at `http://localhost:9001` and the OpenAPI docs at `http://localhost:9001/docs` +--- +## 🔗 Integrations -``` +Prometheus can be connected to other services for extended functionality: + +- **[Argus-GitHub](https://github.com/Pantheon-temple/Argus-GitHub)**: Automatically pull the latest changes from your + GitHub repository, answer/fix issues, and review pull requests. +- **[Hermes](https://github.com/Pantheon-temple/Hermes)**: Enable a chat interface to interact with your codebase ( + GitHub or local). +- **Custom Services**: Connect your own services by sending requests to the FastAPI endpoints. + +--- + +## ⚙️ Quick Start + +### ✅ Prerequisites + +- Docker +- Docker Compose +- API keys (e.g. OpenAI, Anthropic, Google Gemini) + +--- + +### 📦 Setup + +1. Clone the repository: + ```bash + git clone https://github.com/Pantheon-temple/Prometheus.git + cd Prometheus + ``` + +2. Copy the `example.env` file to `.env` and update it with your API keys and other required configurations: + ```bash + cp example.env .env + ``` + +3. Start the services using Docker Compose: + + - **Linux (includes PostgreSQL)**: + ```bash + docker-compose up --build + ``` + + - **macOS / Windows**: + + > ⚠️ `docker-compose.win_mac.yml` does **not include PostgreSQL**. If you don't have PostgreSQL on your device, + you may have to start the PostgreSQL container manually **before starting services** by following the "Database + Setup" section below. + + ```bash + docker-compose -f docker-compose.win_mac.yml up --build + ``` + +4. 
Access Prometheus: + - Service: [http://localhost:9001](http://localhost:9001) + - OpenAPI Docs: [http://localhost:9001/docs](http://localhost:9001/docs) + +--- + +## ๐Ÿ—„๏ธ Database Setup + +### PostgreSQL + +> โš ๏ธ If you're using `docker-compose.win_mac.yml`, you may have to manually start PostgreSQL before launching +> Prometheus: + +Run the following command to start a PostgreSQL container: + +```bash docker run -d \ -p 5432:5432 \ -e POSTGRES_USER=postgres \ @@ -23,7 +92,11 @@ docker run -d \ postgres ``` -``` +### Neo4j + +Run the following command to start a Neo4j container: + +```bash docker run -d \ -p 7474:7474 \ -p 7687:7687 \ @@ -33,4 +106,92 @@ docker run -d \ -e NEO4J_dbms_memory_heap_max__size=8G \ -e NEO4J_dbms_memory_pagecache_size=4G \ neo4j -``` \ No newline at end of file +``` + +Verify Neo4J at: [http://localhost:7474](http://localhost:7474) + +--- + +## โš™๏ธ Configuration + +Set the following variables in your `.env` file: + +### ๐Ÿ”น Neo4j + +* `PROMETHEUS_NEO4J_URI` +* `PROMETHEUS_NEO4J_USERNAME` +* `PROMETHEUS_NEO4J_PASSWORD` + +### ๐Ÿ”น LLM Models + +* `PROMETHEUS_ADVANCED_MODEL` +* `PROMETHEUS_BASE_MODEL` +* API Keys: + + * `PROMETHEUS_OPENAI_API_KEY` + * `PROMETHEUS_ANTHROPIC_API_KEY` + * `PROMETHEUS_GEMINI_API_KEY` + * `PROMETHEUS_OPENROUTER_API_KEY` + +### ๐Ÿ”น Other Settings + +* `PROMETHEUS_WORKING_DIRECTORY` +* `PROMETHEUS_GITHUB_ACCESS_TOKEN` +* `PROMETHEUS_KNOWLEDGE_GRAPH_MAX_AST_DEPTH` +* `PROMETHEUS_NEO4J_BATCH_SIZE` +* `PROMETHEUS_POSTGRES_URI` + +--- + +## ๐Ÿงช Development + +### Requirements + +* Python 3.11+ + +### Steps + +1. Install dependencies: + + ```bash + pip install hatchling + pip install . + pip install .[test] + ``` + +2. Run tests: + + ```bash + coverage run --source=prometheus -m pytest -v -s -m "not git" + ``` + +3. Generate coverage report: + + ```bash + coverage report -m + ``` +4. Generate HTML report: + + ```bash + coverage html + open htmlcov/index.html + ``` + +5. 
Start dev server: + + ```bash + uvicorn prometheus.app.main:app --host 0.0.0.0 --port 9001 + ``` + +--- + +## ๐Ÿ“„ License + +Licensed under the [Apache License 2.0](LICENSE). + +--- + +## ๐Ÿ“ฌ Contact + +For questions or support, please open an issue in +the [GitHub repository](https://github.com/Pantheon-temple/Prometheus/issues). diff --git a/docker-compose.win_mac.yml b/docker-compose.win_mac.yml index ca70e6b1..55b9508d 100644 --- a/docker-compose.win_mac.yml +++ b/docker-compose.win_mac.yml @@ -32,24 +32,24 @@ services: ports: - "9001:9001" environment: - - PROMETHEUS_LOGGING_LEVEL="DEBUG" - - PROMETHEUS_NEO4J_URI="bolt://neo4j:7687" - - PROMETHEUS_NEO4J_USERNAME=neo4j - - PROMETHEUS_NEO4J_PASSWORD=password - - PROMETHEUS_NEO4J_BATCH_SIZE=1000 - - PROMETHEUS_KNOWLEDGE_GRAPH_MAX_AST_DEPTH=3 - - PROMETHEUS_KNOWLEDGE_GRAPH_CHUNK_SIZE=10000 - - PROMETHEUS_KNOWLEDGE_GRAPH_CHUNK_OVERLAP=1000 - - PROMETHEUS_MAX_TOKEN_PER_NEO4J_RESULT=10000 + - PROMETHEUS_LOGGING_LEVEL=${PROMETHEUS_LOGGING_LEVEL} + - PROMETHEUS_NEO4J_URI=${PROMETHEUS_NEO4J_URI} + - PROMETHEUS_NEO4J_USERNAME=${PROMETHEUS_NEO4J_USERNAME} + - PROMETHEUS_NEO4J_PASSWORD=${PROMETHEUS_NEO4J_PASSWORD} + - PROMETHEUS_NEO4J_BATCH_SIZE=${PROMETHEUS_NEO4J_BATCH_SIZE} + - PROMETHEUS_KNOWLEDGE_GRAPH_MAX_AST_DEPTH=${PROMETHEUS_KNOWLEDGE_GRAPH_MAX_AST_DEPTH} + - PROMETHEUS_KNOWLEDGE_GRAPH_CHUNK_SIZE=${PROMETHEUS_KNOWLEDGE_GRAPH_CHUNK_SIZE} + - PROMETHEUS_KNOWLEDGE_GRAPH_CHUNK_OVERLAP=${PROMETHEUS_KNOWLEDGE_GRAPH_CHUNK_OVERLAP} + - PROMETHEUS_MAX_TOKEN_PER_NEO4J_RESULT=${PROMETHEUS_MAX_TOKEN_PER_NEO4J_RESULT} - PROMETHEUS_ADVANCED_MODEL=${PROMETHEUS_ADVANCED_MODEL} - PROMETHEUS_BASE_MODEL=${PROMETHEUS_BASE_MODEL} - PROMETHEUS_ANTHROPIC_API_KEY=${PROMETHEUS_ANTHROPIC_API_KEY} - PROMETHEUS_GEMINI_API_KEY=${PROMETHEUS_GEMINI_API_KEY} - PROMETHEUS_OPENAI_API_KEY=${PROMETHEUS_OPENAI_API_KEY} - PROMETHEUS_OPENROUTER_API_KEY=${PROMETHEUS_OPENROUTER_API_KEY} - - PROMETHEUS_WORKING_DIRECTORY="/app/working_dir" + - 
PROMETHEUS_WORKING_DIRECTORY=${PROMETHEUS_WORKING_DIRECTORY} - PROMETHEUS_GITHUB_ACCESS_TOKEN=${PROMETHEUS_GITHUB_ACCESS_TOKEN} - - PROMETHEUS_POSTGRES_URI="postgresql://postgres:password@postgres:5432/postgres?sslmode=disable" + - PROMETHEUS_POSTGRES_URI=${PROMETHEUS_POSTGRES_URI} networks: - prometheus_network volumes: diff --git a/docker-compose.yml b/docker-compose.yml index 4be0c3f3..ed86f6f0 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -43,24 +43,24 @@ services: ports: - "9001:9001" environment: - - PROMETHEUS_LOGGING_LEVEL="DEBUG" - - PROMETHEUS_NEO4J_URI="bolt://localhost:7687" - - PROMETHEUS_NEO4J_USERNAME=neo4j - - PROMETHEUS_NEO4J_PASSWORD=password - - PROMETHEUS_NEO4J_BATCH_SIZE=1000 - - PROMETHEUS_KNOWLEDGE_GRAPH_MAX_AST_DEPTH=3 - - PROMETHEUS_KNOWLEDGE_GRAPH_CHUNK_SIZE=10000 - - PROMETHEUS_KNOWLEDGE_GRAPH_CHUNK_OVERLAP=1000 - - PROMETHEUS_MAX_TOKEN_PER_NEO4J_RESULT=10000 + - PROMETHEUS_LOGGING_LEVEL=${PROMETHEUS_LOGGING_LEVEL} + - PROMETHEUS_NEO4J_URI=${PROMETHEUS_NEO4J_URI} + - PROMETHEUS_NEO4J_USERNAME=${PROMETHEUS_NEO4J_USERNAME} + - PROMETHEUS_NEO4J_PASSWORD=${PROMETHEUS_NEO4J_PASSWORD} + - PROMETHEUS_NEO4J_BATCH_SIZE=${PROMETHEUS_NEO4J_BATCH_SIZE} + - PROMETHEUS_KNOWLEDGE_GRAPH_MAX_AST_DEPTH=${PROMETHEUS_KNOWLEDGE_GRAPH_MAX_AST_DEPTH} + - PROMETHEUS_KNOWLEDGE_GRAPH_CHUNK_SIZE=${PROMETHEUS_KNOWLEDGE_GRAPH_CHUNK_SIZE} + - PROMETHEUS_KNOWLEDGE_GRAPH_CHUNK_OVERLAP=${PROMETHEUS_KNOWLEDGE_GRAPH_CHUNK_OVERLAP} + - PROMETHEUS_MAX_TOKEN_PER_NEO4J_RESULT=${PROMETHEUS_MAX_TOKEN_PER_NEO4J_RESULT} - PROMETHEUS_ADVANCED_MODEL=${PROMETHEUS_ADVANCED_MODEL} - PROMETHEUS_BASE_MODEL=${PROMETHEUS_BASE_MODEL} - PROMETHEUS_ANTHROPIC_API_KEY=${PROMETHEUS_ANTHROPIC_API_KEY} - PROMETHEUS_GEMINI_API_KEY=${PROMETHEUS_GEMINI_API_KEY} - PROMETHEUS_OPENAI_API_KEY=${PROMETHEUS_OPENAI_API_KEY} - PROMETHEUS_OPENROUTER_API_KEY=${PROMETHEUS_OPENROUTER_API_KEY} - - PROMETHEUS_WORKING_DIRECTORY="/app/working_dir" + - 
PROMETHEUS_WORKING_DIRECTORY=${PROMETHEUS_WORKING_DIRECTORY} - PROMETHEUS_GITHUB_ACCESS_TOKEN=${PROMETHEUS_GITHUB_ACCESS_TOKEN} - - PROMETHEUS_POSTGRES_URI="postgresql://postgres:password@localhost:5432/postgres?sslmode=disable" + - PROMETHEUS_POSTGRES_URI=${PROMETHEUS_POSTGRES_URI} volumes: - .:/app - /var/run/docker.sock:/var/run/docker.sock diff --git a/example.env b/example.env new file mode 100644 index 00000000..c20ab69e --- /dev/null +++ b/example.env @@ -0,0 +1,29 @@ +# Logging +PROMETHEUS_LOGGING_LEVEL=DEBUG + +# Neo4j settings +PROMETHEUS_NEO4J_URI=bolt://localhost:7687 +PROMETHEUS_NEO4J_USERNAME=neo4j +PROMETHEUS_NEO4J_PASSWORD=password +PROMETHEUS_NEO4J_BATCH_SIZE=1000 + +# Knowledge Graph settings +PROMETHEUS_WORKING_DIRECTORY=/tmp/ +PROMETHEUS_KNOWLEDGE_GRAPH_MAX_AST_DEPTH=3 +PROMETHEUS_KNOWLEDGE_GRAPH_CHUNK_SIZE=10000 +PROMETHEUS_KNOWLEDGE_GRAPH_CHUNK_OVERLAP=1000 +PROMETHEUS_MAX_TOKEN_PER_NEO4J_RESULT=10000 + +# LLM API keys and model settings +PROMETHEUS_ADVANCED_MODEL=gpt-4o +PROMETHEUS_BASE_MODEL=gpt-4o +PROMETHEUS_ANTHROPIC_API_KEY=anthropic_api_key +PROMETHEUS_GEMINI_API_KEY=gemini_api_key +PROMETHEUS_OPENAI_API_KEY=openai_api_key +PROMETHEUS_OPENROUTER_API_KEY=openrouter_api_key + +# GitHub settings +PROMETHEUS_GITHUB_ACCESS_TOKEN=github_access_token + +# PostgreSQL settings +PROMETHEUS_POSTGRES_URI=postgresql://postgres:password@localhost:5432/postgres?sslmode=disable \ No newline at end of file diff --git a/example_settings.toml b/example_settings.toml deleted file mode 100644 index 81fdcfb5..00000000 --- a/example_settings.toml +++ /dev/null @@ -1,29 +0,0 @@ -[default] -logging_level = "DEBUG" - -# neo4j settings -neo4j_uri = "bolt://localhost:7687" -neo4j_username = "neo4j" -neo4j_password = "password" -neo4j_batch_size = 1000 - -# knowledge_graph settings -working_directory = "/tmp/" -knowledge_graph_max_ast_depth = 3 -knowledge_graph_chunk_size = 10000 -knowledge_graph_chunk_overlap = 1000 -max_token_per_neo4j_result = 10000 - -# LLM 
settings -advanced_model = "claude-3-5-sonnet-20241022" -base_model = "claude-3-5-haiku-20241022" -anthropic_api_key = "anthropic_api_key" -gemini_api_key = "gemini_api_key" -openai_api_key = "openai_api_key" -openrouter_api_key = "openrouter_api_key" - -# github settings -github_access_token = "github_access_token" - -# postgres settings -postgres_uri = "postgresql://postgres:password@localhost:5432/postgres?sslmode=disable" diff --git a/prometheus/app/api/issue.py b/prometheus/app/api/issue.py index 6dcbe79e..aa47733f 100644 --- a/prometheus/app/api/issue.py +++ b/prometheus/app/api/issue.py @@ -10,123 +10,126 @@ class IssueRequest(BaseModel): - issue_number: int = Field(description="The number of the issue", examples=[42]) - issue_title: str = Field( - description="The title of the issue", examples=["There is a memory leak"] - ) - issue_body: str = Field( - description="The description of the issue", examples=["foo/bar.c is causing a memory leak"] - ) - issue_comments: Optional[Sequence[Mapping[str, str]]] = Field( - default=None, - description="Comments on the issue", - examples=[ - [ - {"username": "user1", "comment": "I've experienced this issue as well."}, - {"username": "user2", "comment": "A potential fix is to adjust the memory settings."}, - ] - ], - ) - issue_type: IssueType = Field( - default=IssueType.AUTO, - description="The type of the issue, set to auto if you do not know", - examples=[IssueType.AUTO], - ) - run_build: Optional[bool] = Field( - default=False, - description="When editing the code, whenver we should run the build to verify the fix", - examples=[False], - ) - run_existing_test: Optional[bool] = Field( - default=False, - description="When editing the code, whenver we should run the existing test to verify the fix", - examples=[False], - ) - number_of_candidate_patch: Optional[int] = Field( - default=4, - description="When the patch is not verfied (through build or test), number of candidate patches we generate to select the best one", 
- examples=[4], - ) - dockerfile_content: Optional[str] = Field( - default=None, - description="Specify the containerized enviroment with dockerfile content", - examples=["FROM python:3.11\nWORKDIR /app\nCOPY . /app"], - ) - image_name: Optional[str] = Field( - default=None, - description="Specify the containerized enviroment with image name that should be pulled from dockerhub", - examples=["python:3.11-slim"], - ) - workdir: Optional[str] = Field( - default=None, - description="If you specified the container environment, you must also specify the workdir", - examples=["/app"], - ) - build_commands: Optional[Sequence[str]] = Field( - default=None, - description="If you specified dockerfile_content and run_build is True, you must also specify the build commands.", - examples=[["pip install -r requirements.txt", "python -m build"]], - ) - test_commands: Optional[Sequence[str]] = Field( - default=None, - description="If you specified dockerfile_content and run_test is True, you must also specify the test commands.", - examples=[["pytest ."]], - ) - push_to_remote: Optional[bool] = Field( - default=False, - description="When editing the code, whenver we should push the changes to a remote branch", - examples=[True], - ) + issue_number: int = Field(description="The number of the issue", examples=[42]) + issue_title: str = Field( + description="The title of the issue", examples=["There is a memory leak"] + ) + issue_body: str = Field( + description="The description of the issue", examples=["foo/bar.c is causing a memory leak"] + ) + issue_comments: Optional[Sequence[Mapping[str, str]]] = Field( + default=None, + description="Comments on the issue", + examples=[ + [ + {"username": "user1", "comment": "I've experienced this issue as well."}, + { + "username": "user2", + "comment": "A potential fix is to adjust the memory settings.", + }, + ] + ], + ) + issue_type: IssueType = Field( + default=IssueType.AUTO, + description="The type of the issue, set to auto if you do not 
know", + examples=[IssueType.AUTO], + ) + run_build: Optional[bool] = Field( + default=False, + description="When editing the code, whenver we should run the build to verify the fix", + examples=[False], + ) + run_existing_test: Optional[bool] = Field( + default=False, + description="When editing the code, whenver we should run the existing test to verify the fix", + examples=[False], + ) + number_of_candidate_patch: Optional[int] = Field( + default=4, + description="When the patch is not verfied (through build or test), number of candidate patches we generate to select the best one", + examples=[4], + ) + dockerfile_content: Optional[str] = Field( + default=None, + description="Specify the containerized enviroment with dockerfile content", + examples=["FROM python:3.11\nWORKDIR /app\nCOPY . /app"], + ) + image_name: Optional[str] = Field( + default=None, + description="Specify the containerized enviroment with image name that should be pulled from dockerhub", + examples=["python:3.11-slim"], + ) + workdir: Optional[str] = Field( + default=None, + description="If you specified the container environment, you must also specify the workdir", + examples=["/app"], + ) + build_commands: Optional[Sequence[str]] = Field( + default=None, + description="If you specified dockerfile_content and run_build is True, you must also specify the build commands.", + examples=[["pip install -r requirements.txt", "python -m build"]], + ) + test_commands: Optional[Sequence[str]] = Field( + default=None, + description="If you specified dockerfile_content and run_test is True, you must also specify the test commands.", + examples=[["pytest ."]], + ) + push_to_remote: Optional[bool] = Field( + default=False, + description="When editing the code, whenver we should push the changes to a remote branch", + examples=[True], + ) @router.post( - "/answer/", - summary="Process and generate a response for an issue", - description="Analyzes an issue, generates patches if needed, runs optional builds 
and tests, and can push changes to a remote branch.", - response_description="Returns the patch, test results, and issue response", + "/answer/", + summary="Process and generate a response for an issue", + description="Analyzes an issue, generates patches if needed, runs optional builds and tests, and can push changes to a remote branch.", + response_description="Returns the patch, test results, and issue response", ) def answer_issue(issue: IssueRequest, request: Request): - if not request.app.state.service_coordinator.exists_knowledge_graph(): - raise HTTPException( - status_code=404, - detail="A repository is not uploaded, use /repository/ endpoint to upload one", - ) + if not request.app.state.service_coordinator.exists_knowledge_graph(): + raise HTTPException( + status_code=404, + detail="A repository is not uploaded, use /repository/ endpoint to upload one", + ) - if issue.dockerfile_content or issue.image_name: - if issue.workdir is None: - raise HTTPException( - status_code=400, - detail="workdir must be provided for user defined environment", - ) + if issue.dockerfile_content or issue.image_name: + if issue.workdir is None: + raise HTTPException( + status_code=400, + detail="workdir must be provided for user defined environment", + ) - ( - remote_branch_name, - patch, - passed_reproducing_test, - passed_build, - passed_existing_test, - issue_response, - ) = request.app.state.service_coordinator.answer_issue( - issue_number=issue.issue_number, - issue_title=issue.issue_title, - issue_body=issue.issue_body, - issue_comments=issue.issue_comments if issue.issue_comments else [], - issue_type=issue.issue_type, - run_build=issue.run_build, - run_existing_test=issue.run_existing_test, - number_of_candidate_patch=issue.number_of_candidate_patch, - dockerfile_content=issue.dockerfile_content, - image_name=issue.image_name, - workdir=issue.workdir, - build_commands=issue.build_commands, - test_commands=issue.test_commands, - push_to_remote=issue.push_to_remote, - ) 
- return { - "patch": patch, - "passed_reproducing_test": passed_reproducing_test, - "passed_build": passed_build, - "passed_existing_test": passed_existing_test, - "issue_response": issue_response, - "remote_branch_name": remote_branch_name, - } + ( + remote_branch_name, + patch, + passed_reproducing_test, + passed_build, + passed_existing_test, + issue_response, + ) = request.app.state.service_coordinator.answer_issue( + issue_number=issue.issue_number, + issue_title=issue.issue_title, + issue_body=issue.issue_body, + issue_comments=issue.issue_comments if issue.issue_comments else [], + issue_type=issue.issue_type, + run_build=issue.run_build, + run_existing_test=issue.run_existing_test, + number_of_candidate_patch=issue.number_of_candidate_patch, + dockerfile_content=issue.dockerfile_content, + image_name=issue.image_name, + workdir=issue.workdir, + build_commands=issue.build_commands, + test_commands=issue.test_commands, + push_to_remote=issue.push_to_remote, + ) + return { + "patch": patch, + "passed_reproducing_test": passed_reproducing_test, + "passed_build": passed_build, + "passed_existing_test": passed_existing_test, + "issue_response": issue_response, + "remote_branch_name": remote_branch_name, + } diff --git a/prometheus/app/api/repository.py b/prometheus/app/api/repository.py index 97ce4013..65655cf4 100644 --- a/prometheus/app/api/repository.py +++ b/prometheus/app/api/repository.py @@ -7,77 +7,77 @@ @router.get( - "/local/", - description=""" + "/local/", + description=""" Upload a local codebase to Prometheus. 
""", - responses={ - 404: {"description": "Local repository not found"}, - 200: {"description": "Repository uploaded successfully"}, - }, + responses={ + 404: {"description": "Local repository not found"}, + 200: {"description": "Repository uploaded successfully"}, + }, ) def upload_local_repository(local_repository: str, request: Request): - local_path = Path(local_repository) + local_path = Path(local_repository) - if not local_path.exists(): - raise HTTPException( - status_code=404, - detail=f"Local repository not found at path: {local_repository}", - ) + if not local_path.exists(): + raise HTTPException( + status_code=404, + detail=f"Local repository not found at path: {local_repository}", + ) - request.app.state.service_coordinator.upload_local_repository(local_path) + request.app.state.service_coordinator.upload_local_repository(local_path) - return {"message": "Repository uploaded successfully"} + return {"message": "Repository uploaded successfully"} @router.get( - "/github/", - description=""" + "/github/", + description=""" Upload a GitHub repository to Prometheus, default to the latest commit in the main branch. """, ) def upload_github_repository(https_url: str, request: Request): - try: - request.app.state.service_coordinator.upload_github_repository(https_url) - except git.exc.GitCommandError: - raise HTTPException(status_code=400, detail=f"Unable to clone {https_url}") + try: + request.app.state.service_coordinator.upload_github_repository(https_url) + except git.exc.GitCommandError: + raise HTTPException(status_code=400, detail=f"Unable to clone {https_url}") @router.get( - "/github_commit/", - description=""" + "/github_commit/", + description=""" Upload a GitHub repository at a specific commit to Prometheus. 
""", ) def upload_github_repository_at_commit(https_url: str, commit_id: str, request: Request): - try: - request.app.state.service_coordinator.upload_github_repository(https_url, commit_id) - except git.exc.GitCommandError: - raise HTTPException( - status_code=400, detail=f"Unable to clone {https_url} with commit {commit_id}" - ) + try: + request.app.state.service_coordinator.upload_github_repository(https_url, commit_id) + except git.exc.GitCommandError: + raise HTTPException( + status_code=400, detail=f"Unable to clone {https_url} with commit {commit_id}" + ) @router.get( - "/delete/", - description=""" + "/delete/", + description=""" Delete the repository uploaded to Prometheus, along with other information. """, ) def delete(request: Request): - if not request.app.state.service_coordinator.exists_knowledge_graph(): - return {"message": "No knowledge graph to delete"} + if not request.app.state.service_coordinator.exists_knowledge_graph(): + return {"message": "No knowledge graph to delete"} - request.app.state.service_coordinator.clear() - return {"message": "Successfully deleted knowledge graph"} + request.app.state.service_coordinator.clear() + return {"message": "Successfully deleted knowledge graph"} @router.get( - "/exists/", - description=""" + "/exists/", + description=""" If there is a codebase uploaded to Promtheus. """, - response_model=bool, + response_model=bool, ) def knowledge_graph_exists(request: Request) -> bool: - return request.app.state.service_coordinator.exists_knowledge_graph() + return request.app.state.service_coordinator.exists_knowledge_graph() diff --git a/prometheus/app/dependencies.py b/prometheus/app/dependencies.py index 60d631f9..0ac6bad1 100644 --- a/prometheus/app/dependencies.py +++ b/prometheus/app/dependencies.py @@ -12,59 +12,61 @@ def initialize_services() -> ServiceCoordinator: - """Initializes and configures the complete prometheus service stack. + """Initializes and configures the complete prometheus service stack. 
- This function creates and configures all required services for prometheus - operation, using settings from the configuration module. It ensures proper - initialization order and service dependencies. + This function creates and configures all required services for prometheus + operation, using settings from the configuration module. It ensures proper + initialization order and service dependencies. - Note: - This function assumes all required settings are properly configured in - the settings module using Dynaconf. The following settings are required: - - NEO4J_URI, NEO4J_USERNAME, NEO4J_PASSWORD - - LITELLM_MODEL - - NEO4J_BATCH_SIZE - - KNOWLEDGE_GRAPH_MAX_AST_DEPTH - - WORKING_DIRECTORY - - GITHUB_ACCESS_TOKEN + Note: + This function assumes all required settings are properly configured in + the settings module using Dynaconf. The following settings are required: + - NEO4J_URI, NEO4J_USERNAME, NEO4J_PASSWORD + - LITELLM_MODEL + - NEO4J_BATCH_SIZE + - KNOWLEDGE_GRAPH_MAX_AST_DEPTH + - WORKING_DIRECTORY + - GITHUB_ACCESS_TOKEN - Returns: - A fully configured ServiceCoordinator instance managing all services. 
- """ - neo4j_service = Neo4jService(settings.NEO4J_URI, settings.NEO4J_USERNAME, settings.NEO4J_PASSWORD) - llm_service = LLMService( - settings.ADVANCED_MODEL, - settings.BASE_MODEL, - getattr(settings, "OPENAI_API_KEY", None), - getattr(settings, "ANTHROPIC_API_KEY", None), - getattr(settings, "GEMINI_API_KEY", None), - getattr(settings, "OPENROUTER_API_KEY", None), - ) - knowledge_graph_service = KnowledgeGraphService( - neo4j_service, - settings.NEO4J_BATCH_SIZE, - settings.KNOWLEDGE_GRAPH_MAX_AST_DEPTH, - settings.KNOWLEDGE_GRAPH_CHUNK_SIZE, - settings.KNOWLEDGE_GRAPH_CHUNK_OVERLAP, - ) - resposistory_service = RepositoryService(knowledge_graph_service, settings.WORKING_DIRECTORY) - issue_service = IssueService( - knowledge_graph_service, - resposistory_service, - neo4j_service, - llm_service, - settings.MAX_TOKEN_PER_NEO4J_RESULT, - ) + Returns: + A fully configured ServiceCoordinator instance managing all services. + """ + neo4j_service = Neo4jService( + settings.NEO4J_URI, settings.NEO4J_USERNAME, settings.NEO4J_PASSWORD + ) + llm_service = LLMService( + settings.ADVANCED_MODEL, + settings.BASE_MODEL, + getattr(settings, "OPENAI_API_KEY", None), + getattr(settings, "ANTHROPIC_API_KEY", None), + getattr(settings, "GEMINI_API_KEY", None), + getattr(settings, "OPENROUTER_API_KEY", None), + ) + knowledge_graph_service = KnowledgeGraphService( + neo4j_service, + settings.NEO4J_BATCH_SIZE, + settings.KNOWLEDGE_GRAPH_MAX_AST_DEPTH, + settings.KNOWLEDGE_GRAPH_CHUNK_SIZE, + settings.KNOWLEDGE_GRAPH_CHUNK_OVERLAP, + ) + repository_service = RepositoryService(knowledge_graph_service, settings.WORKING_DIRECTORY) + issue_service = IssueService( + knowledge_graph_service, + repository_service, + neo4j_service, + llm_service, + settings.MAX_TOKEN_PER_NEO4J_RESULT, + ) - service_coordinator = ServiceCoordinator( - issue_service, - knowledge_graph_service, - llm_service, - neo4j_service, - resposistory_service, - settings.MAX_TOKEN_PER_NEO4J_RESULT, - 
settings.GITHUB_ACCESS_TOKEN, - Path(settings.WORKING_DIRECTORY).absolute(), - ) + service_coordinator = ServiceCoordinator( + issue_service, + knowledge_graph_service, + llm_service, + neo4j_service, + repository_service, + settings.MAX_TOKEN_PER_NEO4J_RESULT, + settings.GITHUB_ACCESS_TOKEN, + Path(settings.WORKING_DIRECTORY).absolute(), + ) - return service_coordinator + return service_coordinator diff --git a/prometheus/app/main.py b/prometheus/app/main.py index b35f1443..c20afbb9 100644 --- a/prometheus/app/main.py +++ b/prometheus/app/main.py @@ -8,7 +8,7 @@ from prometheus.app.api import issue, repository from prometheus.configuration.config import settings -# Create a logger for application's namespace +# Create a logger for the application's namespace logger = logging.getLogger("prometheus") formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s") console_handler = logging.StreamHandler() @@ -17,7 +17,6 @@ logger.setLevel(getattr(logging, settings.LOGGING_LEVEL)) logger.propagate = False - logger.info(f"LOGGING_LEVEL={settings.LOGGING_LEVEL}") logger.info(f"ADVANCED_MODEL={settings.ADVANCED_MODEL}") logger.info(f"BASE_MODEL={settings.BASE_MODEL}") @@ -31,10 +30,10 @@ @asynccontextmanager async def lifespan(app: FastAPI): - app.state.service_coordinator = dependencies.initialize_services() - yield - app.state.service_coordinator.clear() - app.state.service_coordinator.close() + app.state.service_coordinator = dependencies.initialize_services() + yield + app.state.service_coordinator.clear() + app.state.service_coordinator.close() app = FastAPI(lifespan=lifespan) @@ -45,4 +44,4 @@ async def lifespan(app: FastAPI): @app.get("/health", tags=["health"]) def health_check(): - return {"status": "healthy", "timestamp": datetime.now(timezone.utc).isoformat()} + return {"status": "healthy", "timestamp": datetime.now(timezone.utc).isoformat()} diff --git a/prometheus/app/services/issue_service.py b/prometheus/app/services/issue_service.py 
index cb0fdfb9..e980706b 100644 --- a/prometheus/app/services/issue_service.py +++ b/prometheus/app/services/issue_service.py @@ -11,84 +11,84 @@ class IssueService: - def __init__( - self, - kg_service: KnowledgeGraphService, - repository_service: RepositoryService, - neo4j_service: Neo4jService, - llm_service: LLMService, - max_token_per_neo4j_result: int, - ): - self.kg_service = kg_service - self.repository_service = repository_service - self.neo4j_service = neo4j_service - self.llm_service = llm_service - self.max_token_per_neo4j_result = max_token_per_neo4j_result + def __init__( + self, + kg_service: KnowledgeGraphService, + repository_service: RepositoryService, + neo4j_service: Neo4jService, + llm_service: LLMService, + max_token_per_neo4j_result: int, + ): + self.kg_service = kg_service + self.repository_service = repository_service + self.neo4j_service = neo4j_service + self.llm_service = llm_service + self.max_token_per_neo4j_result = max_token_per_neo4j_result - def answer_issue( - self, - issue_title: str, - issue_body: str, - issue_comments: Sequence[Mapping[str, str]], - issue_type: IssueType, - run_build: bool, - run_existing_test: bool, - number_of_candidate_patch: int, - dockerfile_content: Optional[str] = None, - image_name: Optional[str] = None, - workdir: Optional[str] = None, - build_commands: Optional[Sequence[str]] = None, - test_commands: Optional[Sequence[str]] = None, - ): - if dockerfile_content or image_name: - container = UserDefinedContainer( - self.kg_service.kg.get_local_path(), - workdir, - build_commands, - test_commands, - dockerfile_content, - image_name, - ) - else: - container = GeneralContainer(self.kg_service.kg.get_local_path()) + def answer_issue( + self, + issue_title: str, + issue_body: str, + issue_comments: Sequence[Mapping[str, str]], + issue_type: IssueType, + run_build: bool, + run_existing_test: bool, + number_of_candidate_patch: int, + dockerfile_content: Optional[str] = None, + image_name: Optional[str] = None, 
+ workdir: Optional[str] = None, + build_commands: Optional[Sequence[str]] = None, + test_commands: Optional[Sequence[str]] = None, + ): + if dockerfile_content or image_name: + container = UserDefinedContainer( + self.kg_service.kg.get_local_path(), + workdir, + build_commands, + test_commands, + dockerfile_content, + image_name, + ) + else: + container = GeneralContainer(self.kg_service.kg.get_local_path()) - issue_graph = IssueGraph( - advanced_model=self.llm_service.advanced_model, - base_model=self.llm_service.base_model, - kg=self.kg_service.kg, - git_repo=self.repository_service.git_repo, - neo4j_driver=self.neo4j_service.neo4j_driver, - max_token_per_neo4j_result=self.max_token_per_neo4j_result, - container=container, - build_commands=build_commands, - test_commands=test_commands, - ) + issue_graph = IssueGraph( + advanced_model=self.llm_service.advanced_model, + base_model=self.llm_service.base_model, + kg=self.kg_service.kg, + git_repo=self.repository_service.git_repo, + neo4j_driver=self.neo4j_service.neo4j_driver, + max_token_per_neo4j_result=self.max_token_per_neo4j_result, + container=container, + build_commands=build_commands, + test_commands=test_commands, + ) - output_state = issue_graph.invoke( - issue_title, - issue_body, - issue_comments, - issue_type, - run_build, - run_existing_test, - number_of_candidate_patch, - ) + output_state = issue_graph.invoke( + issue_title, + issue_body, + issue_comments, + issue_type, + run_build, + run_existing_test, + number_of_candidate_patch, + ) - if output_state["issue_type"] == IssueType.BUG: - return ( - output_state["edit_patch"], - output_state["passed_reproducing_test"], - output_state["passed_build"], - output_state["passed_existing_test"], - output_state["issue_response"], - ) - elif output_state["issue_type"] == IssueType.QUESTION: - return ( - "", - False, - False, - False, - output_state["issue_response"], - ) + if output_state["issue_type"] == IssueType.BUG: + return ( + output_state["edit_patch"], 
+ output_state["passed_reproducing_test"], + output_state["passed_build"], + output_state["passed_existing_test"], + output_state["issue_response"], + ) + elif output_state["issue_type"] == IssueType.QUESTION: + return ( + "", + False, + False, + False, + output_state["issue_response"], + ) - return "", False, False, False, "" + return "", False, False, False, "" diff --git a/prometheus/app/services/knowledge_graph_service.py b/prometheus/app/services/knowledge_graph_service.py index 11bd9971..5ccaee41 100644 --- a/prometheus/app/services/knowledge_graph_service.py +++ b/prometheus/app/services/knowledge_graph_service.py @@ -9,78 +9,78 @@ class KnowledgeGraphService: - """Manages the lifecycle and operations of Knowledge Graphs. - - This service handles the creation, persistence, and management of Knowledge Graphs - that represent the whole codebase structures. It provides capabilities for building graphs - from codebase, storing them in Neo4j, and managing their lifecycle. - """ - - def __init__( - self, - neo4j_service: Neo4jService, - neo4j_batch_size: int, - max_ast_depth: int, - chunk_size: int, - chunk_overlap: int, - ): - """Initializes the Knowledge Graph service. - - Args: - neo4j_service: Service providing Neo4j database access. - neo4j_batch_size: Number of nodes to process in each Neo4j batch operation. - max_ast_depth: Maximum depth to traverse when building AST representations. - chunk_size: Chunk size for processing text files. - chunk_overlap: Overlap size for processing text files. - """ - self.kg_handler = knowledge_graph_handler.KnowledgeGraphHandler( - neo4j_service.neo4j_driver, neo4j_batch_size - ) - self.max_ast_depth = max_ast_depth - self.chunk_size = chunk_size - self.chunk_overlap = chunk_overlap - self.kg = self._load_existing_knowledge_graph() - - def _load_existing_knowledge_graph(self) -> Optional[KnowledgeGraph]: - """Attempts to load an existing Knowledge Graph from Neo4j. 
- - Returns: - KnowledgeGraph if one exists in the database, None otherwise. - """ - if self.kg_handler.knowledge_graph_exists(): - return self.kg_handler.read_knowledge_graph() - return None - - def get_local_path(self) -> Path: - if self.kg: - return self.kg.get_local_path() - return None - - def build_and_save_knowledge_graph( - self, path: Path, https_url: Optional[str] = None, commit_id: Optional[str] = None - ): - """Builds a new Knowledge Graph from source code and saves it to Neo4j. - - Creates a new Knowledge Graph representation of the codebase at the specified path, - optionally associating it with a repository URL and commit. Any existing - Knowledge Graph will be cleared before building the new one. - - Args: - path: Path to the source code directory to analyze. - https_url: Optional HTTPS URL of the repository. - commit_id: Optional commit identifier for version tracking. - """ - if self.exists(): - self.clear() + """Manages the lifecycle and operations of Knowledge Graphs. - kg = KnowledgeGraph(self.max_ast_depth, self.chunk_size, self.chunk_overlap) - kg.build_graph(path, https_url, commit_id) - self.kg = kg - self.kg_handler.write_knowledge_graph(kg) - - def exists(self) -> bool: - return self.kg_handler.knowledge_graph_exists() + This service handles the creation, persistence, and management of Knowledge Graphs + that represent the whole codebase structures. It provides capabilities for building graphs + from codebase, storing them in Neo4j, and managing their lifecycle. + """ - def clear(self): - self.kg_handler.clear_knowledge_graph() - self.kg = None + def __init__( + self, + neo4j_service: Neo4jService, + neo4j_batch_size: int, + max_ast_depth: int, + chunk_size: int, + chunk_overlap: int, + ): + """Initializes the Knowledge Graph service. + + Args: + neo4j_service: Service providing Neo4j database access. + neo4j_batch_size: Number of nodes to process in each Neo4j batch operation. 
+ max_ast_depth: Maximum depth to traverse when building AST representations. + chunk_size: Chunk size for processing text files. + chunk_overlap: Overlap size for processing text files. + """ + self.kg_handler = knowledge_graph_handler.KnowledgeGraphHandler( + neo4j_service.neo4j_driver, neo4j_batch_size + ) + self.max_ast_depth = max_ast_depth + self.chunk_size = chunk_size + self.chunk_overlap = chunk_overlap + self.kg = self._load_existing_knowledge_graph() + + def _load_existing_knowledge_graph(self) -> Optional[KnowledgeGraph]: + """Attempts to load an existing Knowledge Graph from Neo4j. + + Returns: + KnowledgeGraph if one exists in the database, None otherwise. + """ + if self.kg_handler.knowledge_graph_exists(): + return self.kg_handler.read_knowledge_graph() + return None + + def get_local_path(self) -> Path: + if self.kg: + return self.kg.get_local_path() + return None + + def build_and_save_knowledge_graph( + self, path: Path, https_url: Optional[str] = None, commit_id: Optional[str] = None + ): + """Builds a new Knowledge Graph from source code and saves it to Neo4j. + + Creates a new Knowledge Graph representation of the codebase at the specified path, + optionally associating it with a repository URL and commit. Any existing + Knowledge Graph will be cleared before building the new one. + + Args: + path: Path to the source code directory to analyze. + https_url: Optional HTTPS URL of the repository. + commit_id: Optional commit identifier for version tracking. 
+ """ + if self.exists(): + self.clear() + + kg = KnowledgeGraph(self.max_ast_depth, self.chunk_size, self.chunk_overlap) + kg.build_graph(path, https_url, commit_id) + self.kg = kg + self.kg_handler.write_knowledge_graph(kg) + + def exists(self) -> bool: + return self.kg_handler.knowledge_graph_exists() + + def clear(self): + self.kg_handler.clear_knowledge_graph() + self.kg = None diff --git a/prometheus/app/services/llm_service.py b/prometheus/app/services/llm_service.py index 4dc448a2..6d99d4e9 100644 --- a/prometheus/app/services/llm_service.py +++ b/prometheus/app/services/llm_service.py @@ -7,56 +7,72 @@ class CustomChatOpenAI(ChatOpenAI): - def bind_tools(self, tools, tool_choice=None, **kwargs): - kwargs["parallel_tool_calls"] = False - return super().bind_tools(tools, tool_choice=tool_choice, **kwargs) + def bind_tools(self, tools, tool_choice=None, **kwargs): + kwargs["parallel_tool_calls"] = False + return super().bind_tools(tools, tool_choice=tool_choice, **kwargs) class LLMService: - def __init__( - self, - advanced_model_name: str, - base_model_name: str, + def __init__( + self, + advanced_model_name: str, + base_model_name: str, + openai_api_key: Optional[str] = None, + anthropic_api_key: Optional[str] = None, + gemini_api_key: Optional[str] = None, + open_router_api_key: Optional[str] = None, + ): + self.advanced_model = get_model( + advanced_model_name, + openai_api_key, + anthropic_api_key, + gemini_api_key, + open_router_api_key, + ) + self.base_model = get_model( + base_model_name, openai_api_key, anthropic_api_key, gemini_api_key, open_router_api_key + ) + + +def get_model( + model_name: str, openai_api_key: Optional[str] = None, anthropic_api_key: Optional[str] = None, gemini_api_key: Optional[str] = None, open_router_api_key: Optional[str] = None, - ): - self.advanced_model = get_model( - advanced_model_name, openai_api_key, anthropic_api_key, gemini_api_key, open_router_api_key - ) - self.base_model = get_model( - base_model_name, 
openai_api_key, anthropic_api_key, gemini_api_key, open_router_api_key - ) - - -def get_model( - model_name: str, - openai_api_key: Optional[str] = None, - anthropic_api_key: Optional[str] = None, - gemini_api_key: Optional[str] = None, - open_router_api_key: Optional[str] = None, ) -> BaseChatModel: - if "/" in model_name: - return CustomChatOpenAI( - model=model_name, - api_key=open_router_api_key, - base_url="https://openrouter.ai/api/v1", - temperature=0.0, - max_tokens=None, - max_retries=3, - ) - elif "claude" in model_name: - return ChatAnthropic( - model=model_name, api_key=anthropic_api_key, temperature=0.0, max_tokens=8192, max_retries=3 - ) - elif "gpt" in model_name: - return CustomChatOpenAI( - model=model_name, api_key=openai_api_key, temperature=0.0, max_tokens=None, max_retries=3 - ) - elif "gemini" in model_name: - return ChatGoogleGenerativeAI( - model=model_name, api_key=gemini_api_key, temperature=0.0, max_tokens=None, max_retries=3 - ) - else: - raise ValueError(f"Unknown model name: {model_name}") + if "/" in model_name: + return CustomChatOpenAI( + model=model_name, + api_key=open_router_api_key, + base_url="https://openrouter.ai/api/v1", + temperature=0.0, + max_tokens=None, + max_retries=3, + ) + elif "claude" in model_name: + return ChatAnthropic( + model=model_name, + api_key=anthropic_api_key, + temperature=0.0, + max_tokens=8192, + max_retries=3, + ) + elif "gpt" in model_name: + return CustomChatOpenAI( + model=model_name, + api_key=openai_api_key, + temperature=0.0, + max_tokens=None, + max_retries=3, + ) + elif "gemini" in model_name: + return ChatGoogleGenerativeAI( + model=model_name, + api_key=gemini_api_key, + temperature=0.0, + max_tokens=None, + max_retries=3, + ) + else: + raise ValueError(f"Unknown model name: {model_name}") diff --git a/prometheus/app/services/neo4j_service.py b/prometheus/app/services/neo4j_service.py index 77caf9f8..51be9c28 100644 --- a/prometheus/app/services/neo4j_service.py +++ 
b/prometheus/app/services/neo4j_service.py @@ -4,11 +4,11 @@ class Neo4jService: - def __init__(self, neo4j_uri: str, neo4j_username: str, neo4j_password: str): - self.neo4j_driver = GraphDatabase.driver( - neo4j_uri, - auth=(neo4j_username, neo4j_password), - ) - - def close(self): - self.neo4j_driver.close() + def __init__(self, neo4j_uri: str, neo4j_username: str, neo4j_password: str): + self.neo4j_driver = GraphDatabase.driver( + neo4j_uri, + auth=(neo4j_username, neo4j_password), + ) + + def close(self): + self.neo4j_driver.close() diff --git a/prometheus/app/services/postgres_service.py b/prometheus/app/services/postgres_service.py index 36ceac2b..dc5bc7bf 100644 --- a/prometheus/app/services/postgres_service.py +++ b/prometheus/app/services/postgres_service.py @@ -9,71 +9,71 @@ class PostgresService: - """Manages PostgreSQL database operations for conversation storage. + """Manages PostgreSQL database operations for conversation storage. - This service handles database connections and provides methods to interact - with stored conversation data. The main functionality is to provide checkpointer - for LangGraph. - """ + This service handles database connections and provides methods to interact + with stored conversation data. The main functionality is to provide checkpointer + for LangGraph. 
+ """ - def __init__(self, postgres_uri: str): - self.postgres_conn = Connection.connect( - postgres_uri, - autocommit=True, - prepare_threshold=0, - row_factory=dict_row, - ) - self.checkpointer = PostgresSaver(self.postgres_conn) - self.checkpointer.setup() + def __init__(self, postgres_uri: str): + self.postgres_conn = Connection.connect( + postgres_uri, + autocommit=True, + prepare_threshold=0, + row_factory=dict_row, + ) + self.checkpointer = PostgresSaver(self.postgres_conn) + self.checkpointer.setup() - def close(self): - self.postgres_conn.close() + def close(self): + self.postgres_conn.close() - def get_all_thread_ids(self) -> list[str]: - """Retrieves all unique conversation thread IDs from the database. + def get_all_thread_ids(self) -> list[str]: + """Retrieves all unique conversation thread IDs from the database. - Returns a list of thread IDs sorted by timestamp in descending order - (most recent first), with duplicates removed. + Returns a list of thread IDs sorted by timestamp in descending order + (most recent first), with duplicates removed. - Returns: - List of unique thread identifiers as strings. - """ - all_checkpoints = list(self.checkpointer.list(None)) - all_checkpoints = sorted( - all_checkpoints, key=lambda x: datetime.fromisoformat(x.checkpoint["ts"]), reverse=True - ) - unique_thread_ids = [] - seen = set() - for checkpoint in all_checkpoints: - thread_id = checkpoint.config["configurable"]["thread_id"] - if thread_id not in seen: - unique_thread_ids.append(thread_id) - seen.add(thread_id) + Returns: + List of unique thread identifiers as strings. 
+ """ + all_checkpoints = list(self.checkpointer.list(None)) + all_checkpoints = sorted( + all_checkpoints, key=lambda x: datetime.fromisoformat(x.checkpoint["ts"]), reverse=True + ) + unique_thread_ids = [] + seen = set() + for checkpoint in all_checkpoints: + thread_id = checkpoint.config["configurable"]["thread_id"] + if thread_id not in seen: + unique_thread_ids.append(thread_id) + seen.add(thread_id) - return unique_thread_ids + return unique_thread_ids - def get_messages(self, conversation_id: str) -> list[dict[str, str]]: - """Retrieves all messages for a specific conversation thread. + def get_messages(self, conversation_id: str) -> list[dict[str, str]]: + """Retrieves all messages for a specific conversation thread. - Fetches and formats the complete message history for a given conversation, - including both user and assistant messages, but excludes tool messages or - call to using tools. + Fetches and formats the complete message history for a given conversation, + including both user and assistant messages, but excludes tool messages or + call to using tools. - Args: - conversation_id: Unique identifier for the conversation thread. + Args: + conversation_id: Unique identifier for the conversation thread. 
- Returns: - List of message dictionaries, where each dictionary contains: - - 'role': Either 'user' or 'assistant' - - 'text': The message content - """ - read_config = {"configurable": {"thread_id": conversation_id}} - checkpoint = self.checkpointer.get(read_config) - messages = [] - messages.append({"role": "user", "text": checkpoint["channel_values"]["query"]}) - for message in checkpoint["channel_values"]["messages"]: - if isinstance(message, HumanMessage): - messages.append({"role": "user", "text": message.content}) - elif isinstance(message, AIMessage) and not message.additional_kwargs: - messages.append({"role": "assistant", "text": message.content}) - return messages + Returns: + List of message dictionaries, where each dictionary contains: + - 'role': Either 'user' or 'assistant' + - 'text': The message content + """ + read_config = {"configurable": {"thread_id": conversation_id}} + checkpoint = self.checkpointer.get(read_config) + messages = [] + messages.append({"role": "user", "text": checkpoint["channel_values"]["query"]}) + for message in checkpoint["channel_values"]["messages"]: + if isinstance(message, HumanMessage): + messages.append({"role": "user", "text": message.content}) + elif isinstance(message, AIMessage) and not message.additional_kwargs: + messages.append({"role": "assistant", "text": message.content}) + return messages diff --git a/prometheus/app/services/repository_service.py b/prometheus/app/services/repository_service.py index f2c00ed5..8a02717d 100644 --- a/prometheus/app/services/repository_service.py +++ b/prometheus/app/services/repository_service.py @@ -10,84 +10,86 @@ class RepositoryService: - """Manages repository operations. - - This service provides functionality for Git repository operations including - cloning repositories, managing commits, pushing changes, and maintaining - a clean working directory. It integrates with a knowledge graph service - to track repository state and avoid redundant operations. 
- """ - - def __init__( - self, - kg_service: KnowledgeGraphService, - working_dir: str, - ): - """Initializes the repository service. - - Args: - kg_service: Knowledge graph service instance for codebase tracking. - working_dir: Base directory for repository operations. A 'repositories' - subdirectory will be created under this path. - """ - self.kg_service = kg_service - self.target_directory = Path(working_dir) / "repositories" - self.target_directory.mkdir(parents=True, exist_ok=True) - self.git_repo = self._load_existing_git_repo() - - def _load_existing_git_repo(self): - if self.kg_service.get_local_path() and self.kg_service.get_local_path().exists(): - return GitRepository(str(self.kg_service.get_local_path()), None, copy_to_working_dir=False) - return None - - def get_working_dir(self): - if self.git_repo: - return self.git_repo.get_working_directory() - return None - - def clone_github_repo( - self, github_token: str, https_url: str, commit_id: Optional[str] = None - ) -> Path: - """Clones a GitHub repository to the local workspace. - - Clones the specified repository and optionally checks out a specific commit. - If the repository is already present and matches the requested state, - the operation may be skipped. - - Args: - github_token: GitHub access token for authentication. - https_url: HTTPS URL of the GitHub repository. - commit_id: Optional specific commit to check out. - - Returns: - Path to the local repository directory. - """ - self.git_repo = GitRepository( - https_url, self.target_directory, github_access_token=github_token - ) - if commit_id: - self.git_repo.checkout_commit(commit_id) - local_path = self.git_repo.get_working_directory() - return local_path - - def push_change_to_remote(self, commit_message: str, patch: str): - """Pushes local changes to a new remote branch. - - Creates a new branch with a unique name, commits the current changes, - and pushes them to the remote repository. 
Branch names are prefixed with - 'prometheus_fix_' and include a unique identifier. - - Args: - commit_message: Message to use for the commit. - - Returns: - Name of the created branch. + """Manages repository operations. + + This service provides functionality for Git repository operations including + cloning repositories, managing commits, pushing changes, and maintaining + a clean working directory. It integrates with a knowledge graph service + to track repository state and avoid redundant operations. """ - branch_name = f"prometheus_fix_{uuid.uuid4().hex[:10]}" - self.git_repo.create_and_push_branch(branch_name, commit_message, patch) - return branch_name - - def clean(self): - self.git_repo = None - shutil.rmtree(self.target_directory) - self.target_directory.mkdir(parents=True) + + def __init__( + self, + kg_service: KnowledgeGraphService, + working_dir: str, + ): + """Initializes the repository service. + + Args: + kg_service: Knowledge graph service instance for codebase tracking. + working_dir: Base directory for repository operations. A 'repositories' + subdirectory will be created under this path. + """ + self.kg_service = kg_service + self.target_directory = Path(working_dir) / "repositories" + self.target_directory.mkdir(parents=True, exist_ok=True) + self.git_repo = self._load_existing_git_repo() + + def _load_existing_git_repo(self): + if self.kg_service.get_local_path() and self.kg_service.get_local_path().exists(): + return GitRepository( + str(self.kg_service.get_local_path()), None, copy_to_working_dir=False + ) + return None + + def get_working_dir(self): + if self.git_repo: + return self.git_repo.get_working_directory() + return None + + def clone_github_repo( + self, github_token: str, https_url: str, commit_id: Optional[str] = None + ) -> Path: + """Clones a GitHub repository to the local workspace. + + Clones the specified repository and optionally checks out a specific commit. 
+ If the repository is already present and matches the requested state, + the operation may be skipped. + + Args: + github_token: GitHub access token for authentication. + https_url: HTTPS URL of the GitHub repository. + commit_id: Optional specific commit to check out. + + Returns: + Path to the local repository directory. + """ + self.git_repo = GitRepository( + https_url, self.target_directory, github_access_token=github_token + ) + if commit_id: + self.git_repo.checkout_commit(commit_id) + local_path = self.git_repo.get_working_directory() + return local_path + + def push_change_to_remote(self, commit_message: str, patch: str): + """Pushes local changes to a new remote branch. + + Creates a new branch with a unique name, commits the current changes, + and pushes them to the remote repository. Branch names are prefixed with + 'prometheus_fix_' and include a unique identifier. + + Args: + commit_message: Message to use for the commit. + + Returns: + Name of the created branch. + """ + branch_name = f"prometheus_fix_{uuid.uuid4().hex[:10]}" + self.git_repo.create_and_push_branch(branch_name, commit_message, patch) + return branch_name + + def clean(self): + self.git_repo = None + shutil.rmtree(self.target_directory) + self.target_directory.mkdir(parents=True) diff --git a/prometheus/app/services/service_coordinator.py b/prometheus/app/services/service_coordinator.py index b834a692..3f4467e0 100644 --- a/prometheus/app/services/service_coordinator.py +++ b/prometheus/app/services/service_coordinator.py @@ -21,154 +21,161 @@ class ServiceCoordinator: - """Coordinates operations between various prometheus services. - - This class serves as the central orchestrator for all service interactions, - managing the lifecycle of various operations including codebase analysis, - issue handling, and conversation management. It ensures proper initialization, - coordination, and cleanup of all dependent services. 
- """ - - def __init__( - self, - issue_service: IssueService, - knowledge_graph_service: KnowledgeGraphService, - llm_service: LLMService, - neo4j_service: Neo4jService, - repository_service: RepositoryService, - max_token_per_neo4j_result: int, - github_token: str, - working_directory: Path, - ): - """Initializes the service coordinator with required services. - - Args: - issue_service: Service for issue handling. - knowledge_graph_service: Service for knowledge graph operations. - llm_service: Service for language model operations. - neo4j_service: Service for Neo4j database operations. - repository_service: Service for repository management. - max_token_per_neo4j_result: Maximum number of tokens per Neo4j result. - github_token: GitHub access token for repository operations. - working_directory: Working directory for all Prometheus related files. - """ - self.issue_service = issue_service - self.knowledge_graph_service = knowledge_graph_service - self.llm_service = llm_service - self.neo4j_service = neo4j_service - self.repository_service = repository_service - self.max_token_per_neo4j_result = max_token_per_neo4j_result - self.github_token = github_token - self.working_directory = working_directory - self.answer_issue_log_dir = self.working_directory / "answer_issue_logs" - self.answer_issue_log_dir.mkdir(parents=True, exist_ok=True) - self._logger = logging.getLogger("prometheus.app.services.service_coordinator") - - if self.knowledge_graph_service.get_local_path() != self.repository_service.get_working_dir(): - self._logger.critical( - f"Knowledge graph and repository working directories do not match: {self.knowledge_graph_service.get_local_path()} vs {self.repository_service.get_working_dir()}. Resetting all services." 
- ) - self.clear() - - def answer_issue( - self, - issue_number: int, - issue_title: str, - issue_body: str, - issue_comments: Sequence[Mapping[str, str]], - issue_type: IssueType, - run_build: bool, - run_existing_test: bool, - number_of_candidate_patch: int, - dockerfile_content: Optional[str] = None, - image_name: Optional[str] = None, - workdir: Optional[str] = None, - build_commands: Optional[Sequence[str]] = None, - test_commands: Optional[Sequence[str]] = None, - push_to_remote: Optional[bool] = None, - ): - logger = logging.getLogger("prometheus") - formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s") - timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") - log_file = self.answer_issue_log_dir / f"{timestamp}.log" - file_handler = logging.FileHandler(log_file) - file_handler.setFormatter(formatter) - logger.addHandler(file_handler) - - try: - patch, passed_reproducing_test, passed_build, passed_existing_test, issue_response = ( - self.issue_service.answer_issue( - issue_title=issue_title, - issue_body=issue_body, - issue_comments=issue_comments, - issue_type=issue_type, - run_build=run_build, - run_existing_test=run_existing_test, - number_of_candidate_patch=number_of_candidate_patch, - dockerfile_content=dockerfile_content, - image_name=image_name, - workdir=workdir, - build_commands=build_commands, - test_commands=test_commands, - ) - ) + """Coordinates operations between various prometheus services. 
- remote_branch_name = None - if patch and push_to_remote: - remote_branch_name = self.repository_service.push_change_to_remote( - f"Fixes #{issue_number}", patch - ) - return ( - remote_branch_name, - patch, - passed_reproducing_test, - passed_build, - passed_existing_test, - issue_response, - ) - except Exception as e: - logger.error(f"Error in answer_issue: {str(e)}\n{traceback.format_exc()}") - return None, None, False, False, False, None - finally: - logger.removeHandler(file_handler) - file_handler.close() - - def exists_knowledge_graph(self) -> bool: - return self.knowledge_graph_service.exists() - - def upload_local_repository(self, path: Path): - """Uploads a local repository. - - Args: - path: Path to the local repository directory. + This class serves as the central orchestrator for all service interactions, + managing the lifecycle of various operations including codebase analysis, + issue handling, and conversation management. It ensures proper initialization, + coordination, and cleanup of all dependent services. """ - self.clear() - self.knowledge_graph_service.build_and_save_knowledge_graph(path) - def upload_github_repository(self, https_url: str, commit_id: Optional[str] = None): - """Uploads a GitHub repository. - - Args: - https_url: HTTPS URL of the GitHub repository. - commit_id: Optional specific commit to analyze. - """ - self.clear() - saved_path = self.repository_service.clone_github_repo(self.github_token, https_url, commit_id) - self.knowledge_graph_service.build_and_save_knowledge_graph(saved_path, https_url, commit_id) + def __init__( + self, + issue_service: IssueService, + knowledge_graph_service: KnowledgeGraphService, + llm_service: LLMService, + neo4j_service: Neo4jService, + repository_service: RepositoryService, + max_token_per_neo4j_result: int, + github_token: str, + working_directory: Path, + ): + """Initializes the service coordinator with required services. + + Args: + issue_service: Service for issue handling. 
+ knowledge_graph_service: Service for knowledge graph operations. + llm_service: Service for language model operations. + neo4j_service: Service for Neo4j database operations. + repository_service: Service for repository management. + max_token_per_neo4j_result: Maximum number of tokens per Neo4j result. + github_token: GitHub access token for repository operations. + working_directory: Working directory for all Prometheus related files. + """ + self.issue_service = issue_service + self.knowledge_graph_service = knowledge_graph_service + self.llm_service = llm_service + self.neo4j_service = neo4j_service + self.repository_service = repository_service + self.max_token_per_neo4j_result = max_token_per_neo4j_result + self.github_token = github_token + self.working_directory = working_directory + self.answer_issue_log_dir = self.working_directory / "answer_issue_logs" + self.answer_issue_log_dir.mkdir(parents=True, exist_ok=True) + self._logger = logging.getLogger("prometheus.app.services.service_coordinator") + + if ( + self.knowledge_graph_service.get_local_path() + != self.repository_service.get_working_dir() + ): + self._logger.critical( + f"Knowledge graph and repository working directories do not match: {self.knowledge_graph_service.get_local_path()} vs {self.repository_service.get_working_dir()}. Resetting all services." 
+ ) + self.clear() + + def answer_issue( + self, + issue_number: int, + issue_title: str, + issue_body: str, + issue_comments: Sequence[Mapping[str, str]], + issue_type: IssueType, + run_build: bool, + run_existing_test: bool, + number_of_candidate_patch: int, + dockerfile_content: Optional[str] = None, + image_name: Optional[str] = None, + workdir: Optional[str] = None, + build_commands: Optional[Sequence[str]] = None, + test_commands: Optional[Sequence[str]] = None, + push_to_remote: Optional[bool] = None, + ): + logger = logging.getLogger("prometheus") + formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s") + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + log_file = self.answer_issue_log_dir / f"{timestamp}.log" + file_handler = logging.FileHandler(log_file) + file_handler.setFormatter(formatter) + logger.addHandler(file_handler) + + try: + patch, passed_reproducing_test, passed_build, passed_existing_test, issue_response = ( + self.issue_service.answer_issue( + issue_title=issue_title, + issue_body=issue_body, + issue_comments=issue_comments, + issue_type=issue_type, + run_build=run_build, + run_existing_test=run_existing_test, + number_of_candidate_patch=number_of_candidate_patch, + dockerfile_content=dockerfile_content, + image_name=image_name, + workdir=workdir, + build_commands=build_commands, + test_commands=test_commands, + ) + ) + + remote_branch_name = None + if patch and push_to_remote: + remote_branch_name = self.repository_service.push_change_to_remote( + f"Fixes #{issue_number}", patch + ) + return ( + remote_branch_name, + patch, + passed_reproducing_test, + passed_build, + passed_existing_test, + issue_response, + ) + except Exception as e: + logger.error(f"Error in answer_issue: {str(e)}\n{traceback.format_exc()}") + return None, None, False, False, False, None + finally: + logger.removeHandler(file_handler) + file_handler.close() + + def exists_knowledge_graph(self) -> bool: + return 
self.knowledge_graph_service.exists() + + def upload_local_repository(self, path: Path): + """Uploads a local repository. + + Args: + path: Path to the local repository directory. + """ + self.clear() + self.knowledge_graph_service.build_and_save_knowledge_graph(path) + + def upload_github_repository(self, https_url: str, commit_id: Optional[str] = None): + """Uploads a GitHub repository. + + Args: + https_url: HTTPS URL of the GitHub repository. + commit_id: Optional specific commit to analyze. + """ + self.clear() + saved_path = self.repository_service.clone_github_repo( + self.github_token, https_url, commit_id + ) + self.knowledge_graph_service.build_and_save_knowledge_graph( + saved_path, https_url, commit_id + ) - def clear(self): - """Clears all service state and working directories. + def clear(self): + """Clears all service state and working directories. - Resets the knowledge graph, cleans repository working directory, - and reinitializes subgraph services. - """ - self.knowledge_graph_service.clear() - self.repository_service.clean() + Resets the knowledge graph, cleans repository working directory, + and reinitializes subgraph services. + """ + self.knowledge_graph_service.clear() + self.repository_service.clean() - def close(self): - """Closes all database connections and releases resources. + def close(self): + """Closes all database connections and releases resources. - This method should be called when the coordinator is no longer needed - to ensure proper cleanup of database connections and resources. - """ - self.neo4j_service.close() + This method should be called when the coordinator is no longer needed + to ensure proper cleanup of database connections and resources. 
+ """ + self.neo4j_service.close() diff --git a/prometheus/configuration/config.py b/prometheus/configuration/config.py index a9021fc2..ea0b2860 100644 --- a/prometheus/configuration/config.py +++ b/prometheus/configuration/config.py @@ -1,7 +1,3 @@ from dynaconf import Dynaconf -settings = Dynaconf( - envvar_prefix="PROMETHEUS", - settings_files=["settings.toml"], - environments=True, -) +settings = Dynaconf(envvar_prefix="PROMETHEUS", load_dotenv=True) diff --git a/prometheus/docker/base_container.py b/prometheus/docker/base_container.py index c96e665b..c09f5187 100644 --- a/prometheus/docker/base_container.py +++ b/prometheus/docker/base_container.py @@ -10,181 +10,181 @@ class BaseContainer(ABC): - """An abstract base class for managing Docker containers with file synchronization capabilities. + """An abstract base class for managing Docker containers with file synchronization capabilities. - This class provides core functionality for creating, managing, and interacting with Docker - containers. It handles container lifecycle operations including building images, starting - containers, updating files, and cleanup. The class is designed to be extended for specific - container implementations that specifies the Dockerfile, how to build and how to run the test. - """ - - client: docker.DockerClient = docker.from_env() - tag_name: str - workdir: str = "/app" - container: docker.models.containers.Container - project_path: Path - timeout: int = 120 - logger: logging.Logger - - def __init__(self, project_path: Path, workdir: Optional[str] = None): - """Initialize the container with a project directory. - - Creates a temporary copy of the project directory to work with. - - Args: - project_path: Path to the project directory to be containerized. 
- """ - self._logger = logging.getLogger(f"{self.__class__.__module__}.{self.__class__.__name__}") - temp_dir = Path(tempfile.mkdtemp()) - temp_project_path = temp_dir / project_path.name - shutil.copytree(project_path, temp_project_path) - self.project_path = temp_project_path.absolute() - self._logger.info(f"Created temporary project directory: {self.project_path}") - - if workdir: - self.workdir = workdir - self._logger.debug(f"Using workdir: {self.workdir}") - - self.container = None - - @abstractmethod - def get_dockerfile_content(self) -> str: - """Get the content of the Dockerfile for building the container image. - - Returns: - str: Content of the Dockerfile as a string. - """ - pass - - def build_docker_image(self): - """Build a Docker image using the Dockerfile content. - - Creates a Dockerfile in the project directory and builds a Docker image - using the specified tag name. - """ - dockerfile_content = self.get_dockerfile_content() - dockerfile_path = self.project_path / "prometheus.Dockerfile" - dockerfile_path.write_text(dockerfile_content) - self._logger.info(f"Building docker image {self.tag_name}") - self.client.images.build( - path=str(self.project_path), dockerfile=dockerfile_path.name, tag=self.tag_name - ) - - def start_container(self): - """Start a Docker container from the built image. - - Starts a detached container with TTY enabled and mounts the Docker socket. 
- """ - self._logger.info(f"Starting container from image {self.tag_name}") - self.container = self.client.containers.run( - self.tag_name, - detach=True, - tty=True, - network_mode="host", - environment={"PYTHONPATH": f"{self.workdir}:$PYTHONPATH"}, - volumes={"/var/run/docker.sock": {"bind": "/var/run/docker.sock", "mode": "rw"}}, - ) - - def is_running(self) -> bool: - return bool(self.container) - - def update_files( - self, project_root_path: Path, updated_files: Sequence[Path], removed_files: Sequence[Path] - ): - """Update files in the running container with files from a local directory. - - Creates a tar archive of the new files and copies them into the workdir of the container. - - Args: - new_project_path: Path to the directory containing new files. - """ - if not project_root_path.is_absolute(): - raise ValueError("project_root_path {project_root_path} must be a absolute path") - - self._logger.info("Updating files in the container after edits.") - for file in removed_files: - self._logger.info(f"Removing file {file} in the container") - self.execute_command(f"rm {file}") - - parent_dirs = {str(file.parent) for file in updated_files} - for dir_path in sorted(parent_dirs): - self._logger.info(f"Creating directory {dir_path} in the container") - self.execute_command(f"mkdir -p {dir_path}") - - with tempfile.NamedTemporaryFile() as temp_tar: - with tarfile.open(fileobj=temp_tar, mode="w") as tar: - for file in updated_files: - local_absolute_file = project_root_path / file - self._logger.info(f"Updating {file} in the container") - tar.add(local_absolute_file, arcname=str(file)) - - temp_tar.seek(0) - - self.container.put_archive(self.workdir, temp_tar.read()) - - self._logger.info("Files updated successfully") - - @abstractmethod - def run_build(self): - """Run build commands in the container. - - This method should be implemented by subclasses to define build steps. 
+ This class provides core functionality for creating, managing, and interacting with Docker + containers. It handles container lifecycle operations including building images, starting + containers, updating files, and cleanup. The class is designed to be extended for specific + container implementations that specifies the Dockerfile, how to build and how to run the test. """ - pass - - @abstractmethod - def run_test(self): - """Run test commands in the container. - - This method should be implemented by subclasses to define test steps. - """ - pass - - def execute_command(self, command: str) -> str: - """Execute a command in the running container. - Args: - command: Command to execute in the container. - - Returns: - str: Output of the command as a string. - """ - timeout_msg = f""" + client: docker.DockerClient = docker.from_env() + tag_name: str + workdir: str = "/app" + container: docker.models.containers.Container + project_path: Path + timeout: int = 120 + logger: logging.Logger + + def __init__(self, project_path: Path, workdir: Optional[str] = None): + """Initialize the container with a project directory. + + Creates a temporary copy of the project directory to work with. + + Args: + project_path: Path to the project directory to be containerized. + """ + self._logger = logging.getLogger(f"{self.__class__.__module__}.{self.__class__.__name__}") + temp_dir = Path(tempfile.mkdtemp()) + temp_project_path = temp_dir / project_path.name + shutil.copytree(project_path, temp_project_path) + self.project_path = temp_project_path.absolute() + self._logger.info(f"Created temporary project directory: {self.project_path}") + + if workdir: + self.workdir = workdir + self._logger.debug(f"Using workdir: {self.workdir}") + + self.container = None + + @abstractmethod + def get_dockerfile_content(self) -> str: + """Get the content of the Dockerfile for building the container image. + + Returns: + str: Content of the Dockerfile as a string. 
+ """ + pass + + def build_docker_image(self): + """Build a Docker image using the Dockerfile content. + + Creates a Dockerfile in the project directory and builds a Docker image + using the specified tag name. + """ + dockerfile_content = self.get_dockerfile_content() + dockerfile_path = self.project_path / "prometheus.Dockerfile" + dockerfile_path.write_text(dockerfile_content) + self._logger.info(f"Building docker image {self.tag_name}") + self.client.images.build( + path=str(self.project_path), dockerfile=dockerfile_path.name, tag=self.tag_name + ) + + def start_container(self): + """Start a Docker container from the built image. + + Starts a detached container with TTY enabled and mounts the Docker socket. + """ + self._logger.info(f"Starting container from image {self.tag_name}") + self.container = self.client.containers.run( + self.tag_name, + detach=True, + tty=True, + network_mode="host", + environment={"PYTHONPATH": f"{self.workdir}:$PYTHONPATH"}, + volumes={"/var/run/docker.sock": {"bind": "/var/run/docker.sock", "mode": "rw"}}, + ) + + def is_running(self) -> bool: + return bool(self.container) + + def update_files( + self, project_root_path: Path, updated_files: Sequence[Path], removed_files: Sequence[Path] + ): + """Update files in the running container with files from a local directory. + + Creates a tar archive of the new files and copies them into the workdir of the container. + + Args: + new_project_path: Path to the directory containing new files. 
+ """ + if not project_root_path.is_absolute(): + raise ValueError("project_root_path {project_root_path} must be a absolute path") + + self._logger.info("Updating files in the container after edits.") + for file in removed_files: + self._logger.info(f"Removing file {file} in the container") + self.execute_command(f"rm {file}") + + parent_dirs = {str(file.parent) for file in updated_files} + for dir_path in sorted(parent_dirs): + self._logger.info(f"Creating directory {dir_path} in the container") + self.execute_command(f"mkdir -p {dir_path}") + + with tempfile.NamedTemporaryFile() as temp_tar: + with tarfile.open(fileobj=temp_tar, mode="w") as tar: + for file in updated_files: + local_absolute_file = project_root_path / file + self._logger.info(f"Updating {file} in the container") + tar.add(local_absolute_file, arcname=str(file)) + + temp_tar.seek(0) + + self.container.put_archive(self.workdir, temp_tar.read()) + + self._logger.info("Files updated successfully") + + @abstractmethod + def run_build(self): + """Run build commands in the container. + + This method should be implemented by subclasses to define build steps. + """ + pass + + @abstractmethod + def run_test(self): + """Run test commands in the container. + + This method should be implemented by subclasses to define test steps. + """ + pass + + def execute_command(self, command: str) -> str: + """Execute a command in the running container. + + Args: + command: Command to execute in the container. + + Returns: + str: Output of the command as a string. 
+ """ + timeout_msg = f""" ******************************************************************************* {command} timeout after {self.timeout} seconds ******************************************************************************* """ - timeout_command = f"timeout -k 5 {self.timeout}s {command}" - command = f'/bin/bash -l -c "{timeout_command}"' - self._logger.debug(f"Running command in container: {command}") - exec_result = self.container.exec_run(command, workdir=self.workdir) - exec_result_str = exec_result.output.decode("utf-8") - - if exec_result.exit_code in (124, 137): - exec_result_str += timeout_msg - - self._logger.debug(f"Command output:\n{exec_result_str}") - return exec_result_str - - def restart_container(self): - self._logger.info("Restarting the container") - if self.container: - self.container.stop(timeout=10) - self.container.remove(force=True) - - self.start_container() - - def cleanup(self): - """Clean up container resources and temporary files. - - Stops and removes the container, removes the Docker image, - and deletes temporary project files. 
- """ - self._logger.info("Cleaning up container and temporary files") - if self.container: - self.container.stop(timeout=10) - self.container.remove(force=True) - self.container = None - self.client.images.remove(self.tag_name, force=True) - - shutil.rmtree(self.project_path) + timeout_command = f"timeout -k 5 {self.timeout}s {command}" + command = f'/bin/bash -l -c "{timeout_command}"' + self._logger.debug(f"Running command in container: {command}") + exec_result = self.container.exec_run(command, workdir=self.workdir) + exec_result_str = exec_result.output.decode("utf-8") + + if exec_result.exit_code in (124, 137): + exec_result_str += timeout_msg + + self._logger.debug(f"Command output:\n{exec_result_str}") + return exec_result_str + + def restart_container(self): + self._logger.info("Restarting the container") + if self.container: + self.container.stop(timeout=10) + self.container.remove(force=True) + + self.start_container() + + def cleanup(self): + """Clean up container resources and temporary files. + + Stops and removes the container, removes the Docker image, + and deletes temporary project files. + """ + self._logger.info("Cleaning up container and temporary files") + if self.container: + self.container.stop(timeout=10) + self.container.remove(force=True) + self.container = None + self.client.images.remove(self.tag_name, force=True) + + shutil.rmtree(self.project_path) diff --git a/prometheus/docker/general_container.py b/prometheus/docker/general_container.py index 257e7513..10b9bd4b 100644 --- a/prometheus/docker/general_container.py +++ b/prometheus/docker/general_container.py @@ -5,49 +5,49 @@ class GeneralContainer(BaseContainer): - """A general-purpose container with a comprehensive development environment. - - This container provides a full Ubuntu-based development environment with common - development tools and languages pre-installed, including Python, Node.js, Java, - and various build tools. 
It's designed to be a flexible container that can - handle various types of projects through direct command execution rather than - predefined build and test methods. - - The container includes: - - Build tools (gcc, g++, cmake, make) - - Programming languages (Python 3, Node.js, Java) - - Development tools (git, gdb) - - Database clients (PostgreSQL, MySQL, SQLite) - - Text editors (vim, nano) - - Docker CLI for container management - - Various utility tools (curl, wget, zip, etc.) - - Unlike specialized containers, this container does not implement run_build() or - run_test() methods. Instead, the agent will use execute_command() directly for - custom build and test operations. - """ - - def __init__(self, project_path: Path): - """Initialize the general container with a unique tag name. - - Args: - project_path (Path): Path to the project directory to be containerized. + """A general-purpose container with a comprehensive development environment. + + This container provides a full Ubuntu-based development environment with common + development tools and languages pre-installed, including Python, Node.js, Java, + and various build tools. It's designed to be a flexible container that can + handle various types of projects through direct command execution rather than + predefined build and test methods. + + The container includes: + - Build tools (gcc, g++, cmake, make) + - Programming languages (Python 3, Node.js, Java) + - Development tools (git, gdb) + - Database clients (PostgreSQL, MySQL, SQLite) + - Text editors (vim, nano) + - Docker CLI for container management + - Various utility tools (curl, wget, zip, etc.) + + Unlike specialized containers, this container does not implement run_build() or + run_test() methods. Instead, the agent will use execute_command() directly for + custom build and test operations. 
""" - super().__init__(project_path) - self.tag_name = f"prometheus_general_container_{uuid.uuid4().hex[:10]}" - def get_dockerfile_content(self) -> str: - """Get the Dockerfile content for the general-purpose container. + def __init__(self, project_path: Path): + """Initialize the general container with a unique tag name. - The Dockerfile sets up an Ubuntu-based environment with a comprehensive - set of development tools and languages installed. It includes Python, - Node.js, Java, and various build tools, making it suitable for different - types of projects. + Args: + project_path (Path): Path to the project directory to be containerized. + """ + super().__init__(project_path) + self.tag_name = f"prometheus_general_container_{uuid.uuid4().hex[:10]}" - Returns: - str: Content of the Dockerfile as a string. - """ - DOCKERFILE_CONTENT = """\ + def get_dockerfile_content(self) -> str: + """Get the Dockerfile content for the general-purpose container. + + The Dockerfile sets up an Ubuntu-based environment with a comprehensive + set of development tools and languages installed. It includes Python, + Node.js, Java, and various build tools, making it suitable for different + types of projects. + + Returns: + str: Content of the Dockerfile as a string. + """ + DOCKERFILE_CONTENT = """\ FROM ubuntu:24.04 # Avoid timezone prompts during package installation @@ -101,32 +101,32 @@ def get_dockerfile_content(self) -> str: # Copy project files COPY . /app/ """ - return DOCKERFILE_CONTENT - - def run_build(self): - """Not implemented for GeneralContainer. - - This method is intentionally not implemented as the GeneralContainer is designed - to use execute_command() directly for custom build operations. - - Raises: - NotImplementedError: Always raises this exception to indicate that direct - command execution should be used instead. 
- """ - raise NotImplementedError( - "GeneralContainer does not support run_build, use execute_command directly" - ) - - def run_test(self): - """Not implemented for GeneralContainer. - - This method is intentionally not implemented as the GeneralContainer is designed - to use execute_command() directly for custom test operations. - - Raises: - NotImplementedError: Always raises this exception to indicate that direct - command execution should be used instead. - """ - raise NotImplementedError( - "GeneralContainer does not support run_test, use execute_command directly" - ) + return DOCKERFILE_CONTENT + + def run_build(self): + """Not implemented for GeneralContainer. + + This method is intentionally not implemented as the GeneralContainer is designed + to use execute_command() directly for custom build operations. + + Raises: + NotImplementedError: Always raises this exception to indicate that direct + command execution should be used instead. + """ + raise NotImplementedError( + "GeneralContainer does not support run_build, use execute_command directly" + ) + + def run_test(self): + """Not implemented for GeneralContainer. + + This method is intentionally not implemented as the GeneralContainer is designed + to use execute_command() directly for custom test operations. + + Raises: + NotImplementedError: Always raises this exception to indicate that direct + command execution should be used instead. 
+ """ + raise NotImplementedError( + "GeneralContainer does not support run_test, use execute_command directly" + ) diff --git a/prometheus/docker/user_defined_container.py b/prometheus/docker/user_defined_container.py index 83157c14..025ab9cb 100644 --- a/prometheus/docker/user_defined_container.py +++ b/prometheus/docker/user_defined_container.py @@ -6,55 +6,55 @@ class UserDefinedContainer(BaseContainer): - def __init__( - self, - project_path: Path, - workdir: Optional[str] = None, - build_commands: Optional[Sequence[str]] = None, - test_commands: Optional[Sequence[str]] = None, - dockerfile_content: Optional[str] = None, - image_name: Optional[str] = None, - ): - super().__init__(project_path, workdir) - - assert bool(dockerfile_content) != bool(image_name), ( - "Exactly one of dockerfile_content or image_name must be provided" - ) - - self.tag_name = f"prometheus_user_defined_container_{uuid.uuid4().hex[:10]}" - self.build_commands = build_commands - self.test_commands = test_commands - self.dockerfile_content = dockerfile_content - self.image_name = image_name - - def get_dockerfile_content(self) -> str: - return self.dockerfile_content - - def build_docker_image(self): - if self.dockerfile_content: - super().build_docker_image() - else: - self._logger.info(f"Pulling docker image: {self.image_name}") - pulled_image = self.client.images.pull(self.image_name) - self._logger.info(f"Tagging pulled image as: {self.tag_name}") - pulled_image.tag(repository=self.tag_name) - - def run_build(self) -> str: - if self.build_commands is None: - raise ValueError("build_commands is None. The user did not provide build commands.") - - command_output = "" - for build_command in self.build_commands: - command_output += f"$ {build_command}\n" - command_output += f"{self.execute_command(build_command)}\n" - return command_output - - def run_test(self) -> str: - if self.test_commands is None: - raise ValueError("test_commands is None. 
The user did not provide build commands.") - - command_output = "" - for test_command in self.test_commands: - command_output += f"$ {test_command}\n" - command_output += f"{self.execute_command(test_command)}\n" - return command_output + def __init__( + self, + project_path: Path, + workdir: Optional[str] = None, + build_commands: Optional[Sequence[str]] = None, + test_commands: Optional[Sequence[str]] = None, + dockerfile_content: Optional[str] = None, + image_name: Optional[str] = None, + ): + super().__init__(project_path, workdir) + + assert bool(dockerfile_content) != bool(image_name), ( + "Exactly one of dockerfile_content or image_name must be provided" + ) + + self.tag_name = f"prometheus_user_defined_container_{uuid.uuid4().hex[:10]}" + self.build_commands = build_commands + self.test_commands = test_commands + self.dockerfile_content = dockerfile_content + self.image_name = image_name + + def get_dockerfile_content(self) -> str: + return self.dockerfile_content + + def build_docker_image(self): + if self.dockerfile_content: + super().build_docker_image() + else: + self._logger.info(f"Pulling docker image: {self.image_name}") + pulled_image = self.client.images.pull(self.image_name) + self._logger.info(f"Tagging pulled image as: {self.tag_name}") + pulled_image.tag(repository=self.tag_name) + + def run_build(self) -> str: + if self.build_commands is None: + raise ValueError("build_commands is None. The user did not provide build commands.") + + command_output = "" + for build_command in self.build_commands: + command_output += f"$ {build_command}\n" + command_output += f"{self.execute_command(build_command)}\n" + return command_output + + def run_test(self) -> str: + if self.test_commands is None: + raise ValueError("test_commands is None. 
The user did not provide build commands.") + + command_output = "" + for test_command in self.test_commands: + command_output += f"$ {test_command}\n" + command_output += f"{self.execute_command(test_command)}\n" + return command_output diff --git a/prometheus/git/git_repository.py b/prometheus/git/git_repository.py index 30341c8d..aa5412da 100644 --- a/prometheus/git/git_repository.py +++ b/prometheus/git/git_repository.py @@ -11,145 +11,145 @@ class GitRepository: - """A class for managing Git repositories with support for both local and remote operations. - - This class provides a unified interface for working with Git repositories, - whether they are local or remote (HTTPS). It supports common Git operations - such as cloning, checking out commits, switching branches, and pushing changes. - For remote repositories, it handles authentication using GitHub access tokens.. - """ - - def __init__( - self, - address: str, - working_directory: Path, - copy_to_working_dir: bool = True, - github_access_token: Optional[str] = None, - ): - """Initialize a GitRepository instance. - - Args: - address: Either a local path to a Git repository or an HTTPS URL - for a remote repository. - working_directory: Directory where the repository will be stored - or copied to. - copy_to_working_dir: If True, creates a copy of a local - repository in the working directory. Defaults to True. - github_access_token: GitHub access token for authentication with remote - repositories. Required if address is an HTTPS URL. Defaults to None. 
- """ - self._logger = logging.getLogger("prometheus.git.git_repository") - - # Configure git command to use our logger - g = Git() - type(g).GIT_PYTHON_TRACE = "full" - git_cmd_logger = logging.getLogger("git.cmd") - - # Ensure git command output goes to our logger - for handler in git_cmd_logger.handlers: - git_cmd_logger.removeHandler(handler) - git_cmd_logger.parent = self._logger - git_cmd_logger.propagate = True - - if address.startswith("https://"): - if github_access_token is None: - raise ValueError("github_access_token is required for https repository") - self.repo = self._clone_repository(address, github_access_token, working_directory) - self._set_default_branch() - else: - local_path = address - if copy_to_working_dir: - local_path = working_directory / os.path.basename(address) - shutil.copytree(address, local_path) - try: - self.repo = Repo(local_path) - self._set_default_branch() - except InvalidGitRepositoryError: - self.repo = Repo.init(local_path) - self._set_default_branch() - - def _set_default_branch(self): - try: - self.default_branch = ( - self.repo.remote().refs["HEAD"].reference.name.replace("refs/heads/", "") - ) - except ValueError: - self.default_branch = self.repo.active_branch.name - - def _clone_repository( - self, https_url: str, github_access_token: str, target_directory: Path - ) -> Repo: - """Clone a remote repository using HTTPS authentication. - - Args: - https_url: HTTPS URL of the remote repository. - github_access_token: GitHub access token for authentication. - target_directory: Directory where the repository will be cloned. - - Returns: - Repo: GitPython Repo object representing the cloned repository. 
- """ - self._original_https_url = https_url - https_url = https_url.replace("https://", f"https://{github_access_token}@") - repo_name = https_url.split("/")[-1].split(".")[0] - local_path = target_directory / repo_name - if local_path.exists(): - shutil.rmtree(local_path) - - return Repo.clone_from(https_url, local_path) - - def checkout_commit(self, commit_sha: str): - self.repo.git.checkout(commit_sha) - - def switch_branch(self, branch_name: str): - self.repo.git.checkout(branch_name) - - def pull(self): - self.repo.git.pull() - - def get_diff(self, excluded_files: Optional[Sequence[str]] = None) -> str: - self.repo.git.add("-A") - if excluded_files: - self.repo.git.reset(excluded_files) - diff = self.repo.git.diff("--staged") - if diff and not diff.endswith("\n"): - diff += "\n" - self.repo.git.reset() - return diff - - def get_working_directory(self) -> Path: - return Path(self.repo.working_dir).absolute() - - def reset_repository(self): - self.repo.git.reset("--hard") - self.repo.git.clean("-fd") - - def remove_repository(self): - if self.repo is not None: - shutil.rmtree(self.repo.working_dir) - self.repo = None - - def create_and_push_branch(self, branch_name: str, commit_message: str, patch: str): - """Create a new branch, commit changes, and push to remote. - - This method creates a new branch, switches to it, stages all changes, - commits them with the provided message, and pushes the branch to the - remote repository. - - Args: - branch_name: Name of the new branch to create. - commit_message: Message for the commit. + """A class for managing Git repositories with support for both local and remote operations. + + This class provides a unified interface for working with Git repositories, + whether they are local or remote (HTTPS). It supports common Git operations + such as cloning, checking out commits, switching branches, and pushing changes. + For remote repositories, it handles authentication using GitHub access tokens.. 
""" - with tempfile.NamedTemporaryFile(mode="w", suffix=".patch") as tmp_file: - tmp_file.write(patch) - tmp_file.flush() - - new_branch = self.repo.create_head(branch_name) - new_branch.checkout() - - self.repo.git.apply(tmp_file.name) - self.repo.git.add(A=True) - self.repo.index.commit(commit_message) - self.repo.git.push("--set-upstream", "origin", branch_name) - self.reset_repository() - self.switch_branch(self.default_branch) + + def __init__( + self, + address: str, + working_directory: Path, + copy_to_working_dir: bool = True, + github_access_token: Optional[str] = None, + ): + """Initialize a GitRepository instance. + + Args: + address: Either a local path to a Git repository or an HTTPS URL + for a remote repository. + working_directory: Directory where the repository will be stored + or copied to. + copy_to_working_dir: If True, creates a copy of a local + repository in the working directory. Defaults to True. + github_access_token: GitHub access token for authentication with remote + repositories. Required if address is an HTTPS URL. Defaults to None. 
+ """ + self._logger = logging.getLogger("prometheus.git.git_repository") + + # Configure git command to use our logger + g = Git() + type(g).GIT_PYTHON_TRACE = "full" + git_cmd_logger = logging.getLogger("git.cmd") + + # Ensure git command output goes to our logger + for handler in git_cmd_logger.handlers: + git_cmd_logger.removeHandler(handler) + git_cmd_logger.parent = self._logger + git_cmd_logger.propagate = True + + if address.startswith("https://"): + if github_access_token is None: + raise ValueError("github_access_token is required for https repository") + self.repo = self._clone_repository(address, github_access_token, working_directory) + self._set_default_branch() + else: + local_path = address + if copy_to_working_dir: + local_path = working_directory / os.path.basename(address) + shutil.copytree(address, local_path) + try: + self.repo = Repo(local_path) + self._set_default_branch() + except InvalidGitRepositoryError: + self.repo = Repo.init(local_path) + self._set_default_branch() + + def _set_default_branch(self): + try: + self.default_branch = ( + self.repo.remote().refs["HEAD"].reference.name.replace("refs/heads/", "") + ) + except ValueError: + self.default_branch = self.repo.active_branch.name + + def _clone_repository( + self, https_url: str, github_access_token: str, target_directory: Path + ) -> Repo: + """Clone a remote repository using HTTPS authentication. + + Args: + https_url: HTTPS URL of the remote repository. + github_access_token: GitHub access token for authentication. + target_directory: Directory where the repository will be cloned. + + Returns: + Repo: GitPython Repo object representing the cloned repository. 
+ """ + self._original_https_url = https_url + https_url = https_url.replace("https://", f"https://{github_access_token}@") + repo_name = https_url.split("/")[-1].split(".")[0] + local_path = target_directory / repo_name + if local_path.exists(): + shutil.rmtree(local_path) + + return Repo.clone_from(https_url, local_path) + + def checkout_commit(self, commit_sha: str): + self.repo.git.checkout(commit_sha) + + def switch_branch(self, branch_name: str): + self.repo.git.checkout(branch_name) + + def pull(self): + self.repo.git.pull() + + def get_diff(self, excluded_files: Optional[Sequence[str]] = None) -> str: + self.repo.git.add("-A") + if excluded_files: + self.repo.git.reset(excluded_files) + diff = self.repo.git.diff("--staged") + if diff and not diff.endswith("\n"): + diff += "\n" + self.repo.git.reset() + return diff + + def get_working_directory(self) -> Path: + return Path(self.repo.working_dir).absolute() + + def reset_repository(self): + self.repo.git.reset("--hard") + self.repo.git.clean("-fd") + + def remove_repository(self): + if self.repo is not None: + shutil.rmtree(self.repo.working_dir) + self.repo = None + + def create_and_push_branch(self, branch_name: str, commit_message: str, patch: str): + """Create a new branch, commit changes, and push to remote. + + This method creates a new branch, switches to it, stages all changes, + commits them with the provided message, and pushes the branch to the + remote repository. + + Args: + branch_name: Name of the new branch to create. + commit_message: Message for the commit. 
+ """ + with tempfile.NamedTemporaryFile(mode="w", suffix=".patch") as tmp_file: + tmp_file.write(patch) + tmp_file.flush() + + new_branch = self.repo.create_head(branch_name) + new_branch.checkout() + + self.repo.git.apply(tmp_file.name) + self.repo.git.add(A=True) + self.repo.index.commit(commit_message) + self.repo.git.push("--set-upstream", "origin", branch_name) + self.reset_repository() + self.switch_branch(self.default_branch) diff --git a/prometheus/graph/file_graph_builder.py b/prometheus/graph/file_graph_builder.py index 80569e2e..32413814 100644 --- a/prometheus/graph/file_graph_builder.py +++ b/prometheus/graph/file_graph_builder.py @@ -8,194 +8,196 @@ from langchain_text_splitters import RecursiveCharacterTextSplitter from prometheus.graph.graph_types import ( - ASTNode, - KnowledgeGraphEdge, - KnowledgeGraphEdgeType, - KnowledgeGraphNode, - TextNode, + ASTNode, + KnowledgeGraphEdge, + KnowledgeGraphEdgeType, + KnowledgeGraphNode, + TextNode, ) from prometheus.parser import tree_sitter_parser class FileGraphBuilder: - """A class for building knowledge graphs from individual files. + """A class for building knowledge graphs from individual files. - This class processes files and creates knowledge graph representations using different - strategies based on the file type. For source code files, it uses tree-sitter to - create an Abstract Syntax Tree (AST) representation. For markdown files, it creates - a chain of text nodes based on the document's structure. + This class processes files and creates knowledge graph representations using different + strategies based on the file type. For source code files, it uses tree-sitter to + create an Abstract Syntax Tree (AST) representation. For markdown files, it creates + a chain of text nodes based on the document's structure. - The resulting knowledge graph consists of nodes (KnowledgeGraphNode) connected by - edges (KnowledgeGraphEdge) with different relationship types (KnowledgeGraphEdgeType). 
- """ - - def __init__(self, max_ast_depth: int, chunk_size: int, chunk_overlap: int): - """Initialize the FileGraphBuilder. - - Args: - max_ast_depth: Maximum depth to traverse in the AST when processing source code files. - Higher values create more detailed but larger graphs. - chunk_size: The chunk size for text files. - chunk_overlap: The overlap size for text files. - """ - self.max_ast_depth = max_ast_depth - self.chunk_size = chunk_size - self.chunk_overlap = chunk_overlap - - def supports_file(self, file: Path) -> bool: - """Checks if we support building knowledge graph for this file.""" - if tree_sitter_parser.supports_file(file): - return True - - if file.suffix in [".md", ".txt", ".rst"]: - return True - - return False - - def build_file_graph( - self, parent_node: KnowledgeGraphNode, file: Path, next_node_id: int - ) -> Tuple[int, Sequence[KnowledgeGraphNode], Sequence[KnowledgeGraphEdge]]: - """Build knowledge graph for a single file. - - Args: - parent_node: The parent knowledge graph node that represent the file. - The node attribute should have type FileNode. - file: The file to build knowledge graph. - next_node_id: The next available node id. - - Returns: - A tuple of (next_node_id, kg_nodes, kg_edges), where next_node_id is the - new next_node_id, kg_nodes is a list of all nodes created for the file, - and kg_edges is a list of all edges created for this file. - """ - # In this case, it is a file that tree sitter can parse (source code) - if tree_sitter_parser.supports_file(file): - return self._tree_sitter_file_graph(parent_node, file, next_node_id) - - if file.suffix in [".md", ".txt", ".rst"]: - return self._text_file_graph(parent_node, file, next_node_id) - - def _tree_sitter_file_graph( - self, parent_node: KnowledgeGraphNode, file: Path, next_node_id: int - ) -> Tuple[int, Sequence[KnowledgeGraphNode], Sequence[KnowledgeGraphEdge]]: - """Parse a file to a tree-sitter graph. 
- - Tree-sitter is an abstract syntax tree (AST) parser to parse source code. We simply - use the tree representation as the knowledge graph to represent this file. - In the AST, nodes have PARENT_OF relationship, if one node is a parent of another - node. For example, 'class_definition' can be a parent of 'function_definition'. - The root ASTNode is connected to the parent_node using the HAS_AST relationship. - - Args: - parent_node: The parent knowledge graph node that represent the file. - The node attribute should have type FileNode. - file: The file to build knowledge graph. - next_node_id: The next available node id. - - Returns: - A tuple of (next_node_id, kg_nodes, kg_edges), where next_node_id is the - new next_node_id, kg_nodes is a list of all nodes created for the file, - and kg_edges is a list of all edges created for this file. + The resulting knowledge graph consists of nodes (KnowledgeGraphNode) connected by + edges (KnowledgeGraphEdge) with different relationship types (KnowledgeGraphEdgeType). 
""" - tree_sitter_nodes = [] - tree_sitter_edges = [] - - tree = tree_sitter_parser.parse(file) - if tree.root_node.has_error or tree.root_node.child_count == 0: - return next_node_id, tree_sitter_nodes, tree_sitter_edges - - ast_root_node = ASTNode( - type=tree.root_node.type, - start_line=tree.root_node.start_point[0] + 1, - end_line=tree.root_node.end_point[0] + 1, - text=tree.root_node.text.decode("utf-8"), - ) - kg_ast_root_node = KnowledgeGraphNode(next_node_id, ast_root_node) - next_node_id += 1 - tree_sitter_nodes.append(kg_ast_root_node) - tree_sitter_edges.append( - KnowledgeGraphEdge(parent_node, kg_ast_root_node, KnowledgeGraphEdgeType.has_ast) - ) - - node_stack = deque() - node_stack.append((tree.root_node, kg_ast_root_node, 1)) - while node_stack: - tree_sitter_node, kg_node, depth = node_stack.pop() - - if depth > self.max_ast_depth: - continue - - for tree_sitter_child_node in tree_sitter_node.children: - child_ast_node = ASTNode( - type=tree_sitter_child_node.type, - start_line=tree_sitter_child_node.start_point[0] + 1, - end_line=tree_sitter_child_node.end_point[0] + 1, - text=tree_sitter_child_node.text.decode("utf-8"), + + def __init__(self, max_ast_depth: int, chunk_size: int, chunk_overlap: int): + """Initialize the FileGraphBuilder. + + Args: + max_ast_depth: Maximum depth to traverse in the AST when processing source code files. + Higher values create more detailed but larger graphs. + chunk_size: The chunk size for text files. + chunk_overlap: The overlap size for text files. 
+ """ + self.max_ast_depth = max_ast_depth + self.chunk_size = chunk_size + self.chunk_overlap = chunk_overlap + + def supports_file(self, file: Path) -> bool: + """Checks if we support building knowledge graph for this file.""" + if tree_sitter_parser.supports_file(file): + return True + + if file.suffix in [".md", ".txt", ".rst"]: + return True + + return False + + def build_file_graph( + self, parent_node: KnowledgeGraphNode, file: Path, next_node_id: int + ) -> Tuple[int, Sequence[KnowledgeGraphNode], Sequence[KnowledgeGraphEdge]]: + """Build knowledge graph for a single file. + + Args: + parent_node: The parent knowledge graph node that represent the file. + The node attribute should have type FileNode. + file: The file to build knowledge graph. + next_node_id: The next available node id. + + Returns: + A tuple of (next_node_id, kg_nodes, kg_edges), where next_node_id is the + new next_node_id, kg_nodes is a list of all nodes created for the file, + and kg_edges is a list of all edges created for this file. + """ + # In this case, it is a file that tree sitter can parse (source code) + if tree_sitter_parser.supports_file(file): + return self._tree_sitter_file_graph(parent_node, file, next_node_id) + + if file.suffix in [".md", ".txt", ".rst"]: + return self._text_file_graph(parent_node, file, next_node_id) + + def _tree_sitter_file_graph( + self, parent_node: KnowledgeGraphNode, file: Path, next_node_id: int + ) -> Tuple[int, Sequence[KnowledgeGraphNode], Sequence[KnowledgeGraphEdge]]: + """Parse a file to a tree-sitter graph. + + Tree-sitter is an abstract syntax tree (AST) parser to parse source code. We simply + use the tree representation as the knowledge graph to represent this file. + In the AST, nodes have PARENT_OF relationship, if one node is a parent of another + node. For example, 'class_definition' can be a parent of 'function_definition'. + The root ASTNode is connected to the parent_node using the HAS_AST relationship. 
+ + Args: + parent_node: The parent knowledge graph node that represent the file. + The node attribute should have type FileNode. + file: The file to build knowledge graph. + next_node_id: The next available node id. + + Returns: + A tuple of (next_node_id, kg_nodes, kg_edges), where next_node_id is the + new next_node_id, kg_nodes is a list of all nodes created for the file, + and kg_edges is a list of all edges created for this file. + """ + tree_sitter_nodes = [] + tree_sitter_edges = [] + + tree = tree_sitter_parser.parse(file) + if tree.root_node.has_error or tree.root_node.child_count == 0: + return next_node_id, tree_sitter_nodes, tree_sitter_edges + + ast_root_node = ASTNode( + type=tree.root_node.type, + start_line=tree.root_node.start_point[0] + 1, + end_line=tree.root_node.end_point[0] + 1, + text=tree.root_node.text.decode("utf-8"), ) - kg_child_ast_node = KnowledgeGraphNode(next_node_id, child_ast_node) + kg_ast_root_node = KnowledgeGraphNode(next_node_id, ast_root_node) next_node_id += 1 - - tree_sitter_nodes.append(kg_child_ast_node) + tree_sitter_nodes.append(kg_ast_root_node) tree_sitter_edges.append( - KnowledgeGraphEdge(kg_node, kg_child_ast_node, KnowledgeGraphEdgeType.parent_of) + KnowledgeGraphEdge(parent_node, kg_ast_root_node, KnowledgeGraphEdgeType.has_ast) ) - node_stack.append((tree_sitter_child_node, kg_child_ast_node, depth + 1)) - return next_node_id, tree_sitter_nodes, tree_sitter_edges - - def _text_file_graph( - self, parent_node: KnowledgeGraphNode, file: Path, next_node_id: int - ) -> Tuple[int, Sequence[KnowledgeGraphNode], Sequence[KnowledgeGraphEdge]]: - text_splitter = RecursiveCharacterTextSplitter( - chunk_size=self.chunk_size, chunk_overlap=self.chunk_overlap, length_function=len - ) - text = file.open(encoding="utf-8").read() - documents = text_splitter.create_documents([text]) - return self._documents_to_file_graph(documents, parent_node, next_node_id) - - def _documents_to_file_graph( - self, - documents: 
Sequence[Document], - parent_node: KnowledgeGraphNode, - next_node_id: int, - ) -> Tuple[int, Sequence[KnowledgeGraphNode], Sequence[KnowledgeGraphEdge]]: - """Convert the parsed langchain documents to a knowledge graph. - - The parsed document will form a chain of nodes, where all nodes are connected - to the parent_node using the HAS_TEXT relationship. The nodes are connected using - the NEXT_CHUNK relationship in chronological order. - - Args: - documents: The langchain documents used to create the TextNode. - parent_node: The parent knowledge graph node that represent the file. - The node attribute should have type FileNode. - next_node_id: The next available node id. - - Returns: - A tuple of (next_node_id, kg_nodes, kg_edges), where next_node_id is the - new next_node_id, kg_nodes is a list of all nodes created for the file, - and kg_edges is a list of all edges created for this file. - """ - document_nodes = [] - document_edges = [] - - previous_node = None - for document in documents: - text_node = TextNode( - text=document.page_content, - metadata=str(document.metadata) if document.metadata else "", - ) - kg_text_node = KnowledgeGraphNode(next_node_id, text_node) - next_node_id += 1 - document_nodes.append(kg_text_node) - document_edges.append( - KnowledgeGraphEdge(parent_node, kg_text_node, KnowledgeGraphEdgeType.has_text) - ) - - if previous_node: - document_edges.append( - KnowledgeGraphEdge(previous_node, kg_text_node, KnowledgeGraphEdgeType.next_chunk) + node_stack = deque() + node_stack.append((tree.root_node, kg_ast_root_node, 1)) + while node_stack: + tree_sitter_node, kg_node, depth = node_stack.pop() + + if depth > self.max_ast_depth: + continue + + for tree_sitter_child_node in tree_sitter_node.children: + child_ast_node = ASTNode( + type=tree_sitter_child_node.type, + start_line=tree_sitter_child_node.start_point[0] + 1, + end_line=tree_sitter_child_node.end_point[0] + 1, + text=tree_sitter_child_node.text.decode("utf-8"), + ) + 
kg_child_ast_node = KnowledgeGraphNode(next_node_id, child_ast_node) + next_node_id += 1 + + tree_sitter_nodes.append(kg_child_ast_node) + tree_sitter_edges.append( + KnowledgeGraphEdge(kg_node, kg_child_ast_node, KnowledgeGraphEdgeType.parent_of) + ) + + node_stack.append((tree_sitter_child_node, kg_child_ast_node, depth + 1)) + return next_node_id, tree_sitter_nodes, tree_sitter_edges + + def _text_file_graph( + self, parent_node: KnowledgeGraphNode, file: Path, next_node_id: int + ) -> Tuple[int, Sequence[KnowledgeGraphNode], Sequence[KnowledgeGraphEdge]]: + text_splitter = RecursiveCharacterTextSplitter( + chunk_size=self.chunk_size, chunk_overlap=self.chunk_overlap, length_function=len ) - - previous_node = kg_text_node - return next_node_id, document_nodes, document_edges + text = file.open(encoding="utf-8").read() + documents = text_splitter.create_documents([text]) + return self._documents_to_file_graph(documents, parent_node, next_node_id) + + def _documents_to_file_graph( + self, + documents: Sequence[Document], + parent_node: KnowledgeGraphNode, + next_node_id: int, + ) -> Tuple[int, Sequence[KnowledgeGraphNode], Sequence[KnowledgeGraphEdge]]: + """Convert the parsed langchain documents to a knowledge graph. + + The parsed document will form a chain of nodes, where all nodes are connected + to the parent_node using the HAS_TEXT relationship. The nodes are connected using + the NEXT_CHUNK relationship in chronological order. + + Args: + documents: The langchain documents used to create the TextNode. + parent_node: The parent knowledge graph node that represent the file. + The node attribute should have type FileNode. + next_node_id: The next available node id. + + Returns: + A tuple of (next_node_id, kg_nodes, kg_edges), where next_node_id is the + new next_node_id, kg_nodes is a list of all nodes created for the file, + and kg_edges is a list of all edges created for this file. 
+ """ + document_nodes = [] + document_edges = [] + + previous_node = None + for document in documents: + text_node = TextNode( + text=document.page_content, + metadata=str(document.metadata) if document.metadata else "", + ) + kg_text_node = KnowledgeGraphNode(next_node_id, text_node) + next_node_id += 1 + document_nodes.append(kg_text_node) + document_edges.append( + KnowledgeGraphEdge(parent_node, kg_text_node, KnowledgeGraphEdgeType.has_text) + ) + + if previous_node: + document_edges.append( + KnowledgeGraphEdge( + previous_node, kg_text_node, KnowledgeGraphEdgeType.next_chunk + ) + ) + + previous_node = kg_text_node + return next_node_id, document_nodes, document_edges diff --git a/prometheus/graph/graph_types.py b/prometheus/graph/graph_types.py index b545df84..512c26b9 100644 --- a/prometheus/graph/graph_types.py +++ b/prometheus/graph/graph_types.py @@ -6,221 +6,221 @@ class CodeBaseSourceEnum(enum.StrEnum): - """Enum of all knowledge graph edge types""" + """Enum of all knowledge graph edge types""" - github = "GitHub" - local = "local" + github = "GitHub" + local = "local" @dataclasses.dataclass(frozen=True) class MetadataNode: - """A node for the codebase metadata. + """A node for the codebase metadata. - Attributes: - codebase_source: Where the codebase originated from. - local_path: The local path where the codebase is stored. - https_url: The HTTPS URL of the remote repository. - commit_id: The commit ID of the remote repository. - """ + Attributes: + codebase_source: Where the codebase originated from. + local_path: The local path where the codebase is stored. + https_url: The HTTPS URL of the remote repository. + commit_id: The commit ID of the remote repository. 
+ """ - codebase_source: CodeBaseSourceEnum + codebase_source: CodeBaseSourceEnum - local_path: str + local_path: str - https_url: str - commit_id: str + https_url: str + commit_id: str - def to_neo4j_node(self) -> "Neo4jMetadataNode": - return Neo4jMetadataNode( - codebase_source=self.codebase_source, - local_path=self.local_path, - https_url=self.https_url, - commit_id=self.commit_id, - ) + def to_neo4j_node(self) -> "Neo4jMetadataNode": + return Neo4jMetadataNode( + codebase_source=self.codebase_source, + local_path=self.local_path, + https_url=self.https_url, + commit_id=self.commit_id, + ) - @classmethod - def from_neo4j_metadata_node(cls, node: "Neo4jMetadataNode") -> "MetadataNode": - return cls( - codebase_source=node["codebase_source"], - local_path=node["local_path"], - https_url=node["https_url"], - commit_id=node["commit_id"], - ) + @classmethod + def from_neo4j_metadata_node(cls, node: "Neo4jMetadataNode") -> "MetadataNode": + return cls( + codebase_source=node["codebase_source"], + local_path=node["local_path"], + https_url=node["https_url"], + commit_id=node["commit_id"], + ) @dataclasses.dataclass(frozen=True) class FileNode: - """A node representing a file/dir. + """A node representing a file/dir. - Attributes: - basename: The basename of a file/dir, like 'bar.py' or 'foo'. - relative_path: The relative path from the root path, like 'foo/bar/baz.java'. - """ + Attributes: + basename: The basename of a file/dir, like 'bar.py' or 'foo'. + relative_path: The relative path from the root path, like 'foo/bar/baz.java'. + """ - basename: str - relative_path: str + basename: str + relative_path: str @dataclasses.dataclass(frozen=True) class ASTNode: - """A node representing a tree-sitter node. + """A node representing a tree-sitter node. - Attributes: - type: The tree-sitter node type. - start_line: The starting line number. 0-indexed and inclusive. - end_line: The ending line number. 0-indexed and inclusive. 
- text: The source code correcpsonding to the node. - """ + Attributes: + type: The tree-sitter node type. + start_line: The starting line number. 0-indexed and inclusive. + end_line: The ending line number. 0-indexed and inclusive. + text: The source code correcpsonding to the node. + """ - type: str - start_line: int - end_line: int - text: str + type: str + start_line: int + end_line: int + text: str @dataclasses.dataclass(frozen=True) class TextNode: - """A node representing a piece of text. + """A node representing a piece of text. - Attributes: - text: A string. - metadata: The metadata about the string. - """ + Attributes: + text: A string. + metadata: The metadata about the string. + """ - text: str - metadata: str + text: str + metadata: str @dataclasses.dataclass(frozen=True) class KnowledgeGraphNode: - """A node in the knowledge graph. - - Attributes: - node_id: A id that uniquely identifies a node in the graph. - node: The node itself, can be a FileNode, ASTNode or TextNode. - """ - - node_id: int - node: Union[FileNode, ASTNode, TextNode] - - def to_neo4j_node(self) -> Union["Neo4jFileNode", "Neo4jASTNode", "Neo4jTextNode"]: - """Convert the KnowledgeGraphNode into a Neo4j node format.""" - match self.node: - case FileNode(): - return Neo4jFileNode( - node_id=self.node_id, - basename=self.node.basename, - relative_path=self.node.relative_path, + """A node in the knowledge graph. + + Attributes: + node_id: A id that uniquely identifies a node in the graph. + node: The node itself, can be a FileNode, ASTNode or TextNode. 
+ """ + + node_id: int + node: Union[FileNode, ASTNode, TextNode] + + def to_neo4j_node(self) -> Union["Neo4jFileNode", "Neo4jASTNode", "Neo4jTextNode"]: + """Convert the KnowledgeGraphNode into a Neo4j node format.""" + match self.node: + case FileNode(): + return Neo4jFileNode( + node_id=self.node_id, + basename=self.node.basename, + relative_path=self.node.relative_path, + ) + case ASTNode(): + return Neo4jASTNode( + node_id=self.node_id, + type=self.node.type, + start_line=self.node.start_line, + end_line=self.node.end_line, + text=self.node.text, + ) + case TextNode(): + return Neo4jTextNode( + node_id=self.node_id, + text=self.node.text, + metadata=self.node.metadata, + ) + case _: + raise ValueError("Unknown KnowledgeGraphNode.node type") + + @classmethod + def from_neo4j_file_node(cls, node: "Neo4jFileNode") -> "KnowledgeGraphNode": + return cls( + node_id=node["node_id"], + node=FileNode( + basename=node["basename"], + relative_path=node["relative_path"], + ), ) - case ASTNode(): - return Neo4jASTNode( - node_id=self.node_id, - type=self.node.type, - start_line=self.node.start_line, - end_line=self.node.end_line, - text=self.node.text, + + @classmethod + def from_neo4j_ast_node(cls, node: "Neo4jASTNode") -> "KnowledgeGraphNode": + return cls( + node_id=node["node_id"], + node=ASTNode( + type=node["type"], + start_line=node["start_line"], + end_line=node["end_line"], + text=node["text"], + ), ) - case TextNode(): - return Neo4jTextNode( - node_id=self.node_id, - text=self.node.text, - metadata=self.node.metadata, + + @classmethod + def from_neo4j_text_node(cls, node: "Neo4jTextNode") -> "KnowledgeGraphNode": + return cls( + node_id=node["node_id"], + node=TextNode(text=node["text"], metadata=node["metadata"]), ) - case _: - raise ValueError("Unknown KnowledgeGraphNode.node type") - - @classmethod - def from_neo4j_file_node(cls, node: "Neo4jFileNode") -> "KnowledgeGraphNode": - return cls( - node_id=node["node_id"], - node=FileNode( - 
basename=node["basename"], - relative_path=node["relative_path"], - ), - ) - - @classmethod - def from_neo4j_ast_node(cls, node: "Neo4jASTNode") -> "KnowledgeGraphNode": - return cls( - node_id=node["node_id"], - node=ASTNode( - type=node["type"], - start_line=node["start_line"], - end_line=node["end_line"], - text=node["text"], - ), - ) - - @classmethod - def from_neo4j_text_node(cls, node: "Neo4jTextNode") -> "KnowledgeGraphNode": - return cls( - node_id=node["node_id"], - node=TextNode(text=node["text"], metadata=node["metadata"]), - ) class KnowledgeGraphEdgeType(enum.StrEnum): - """Enum of all knowledge graph edge types""" + """Enum of all knowledge graph edge types""" - parent_of = "PARENT_OF" - has_file = "HAS_FILE" - has_ast = "HAS_AST" - has_text = "HAS_TEXT" - next_chunk = "NEXT_CHUNK" + parent_of = "PARENT_OF" + has_file = "HAS_FILE" + has_ast = "HAS_AST" + has_text = "HAS_TEXT" + next_chunk = "NEXT_CHUNK" @dataclasses.dataclass(frozen=True) class KnowledgeGraphEdge: - """An edge in the knowledge graph. - - Attributes: - source: The source knowledge graph node. - target: The target knowledge graph node. - type: The knowledge graph edge type. 
- """ - - source: KnowledgeGraphNode - target: KnowledgeGraphNode - type: KnowledgeGraphEdgeType - - def to_neo4j_edge( - self, - ) -> Union[ - "Neo4jHasFileEdge", - "Neo4jHasASTEdge", - "Neo4jParentOfEdge", - "Neo4jHasTextEdge", - "Neo4jNextChunkEdge", - ]: - """Convert the KnowledgeGraphEdge into a Neo4j edge format.""" - match self.type: - case KnowledgeGraphEdgeType.has_file: - return Neo4jHasFileEdge( - source=self.source.to_neo4j_node(), - target=self.target.to_neo4j_node(), - ) - case KnowledgeGraphEdgeType.has_ast: - return Neo4jHasASTEdge( - source=self.source.to_neo4j_node(), - target=self.target.to_neo4j_node(), - ) - case KnowledgeGraphEdgeType.parent_of: - return Neo4jParentOfEdge( - source=self.source.to_neo4j_node(), - target=self.target.to_neo4j_node(), - ) - case KnowledgeGraphEdgeType.has_text: - return Neo4jHasTextEdge( - source=self.source.to_neo4j_node(), - target=self.target.to_neo4j_node(), - ) - case KnowledgeGraphEdgeType.next_chunk: - return Neo4jNextChunkEdge( - source=self.source.to_neo4j_node(), - target=self.target.to_neo4j_node(), - ) - case _: - raise ValueError(f"Unknown edge type: {self.type}") + """An edge in the knowledge graph. + + Attributes: + source: The source knowledge graph node. + target: The target knowledge graph node. + type: The knowledge graph edge type. 
+ """ + + source: KnowledgeGraphNode + target: KnowledgeGraphNode + type: KnowledgeGraphEdgeType + + def to_neo4j_edge( + self, + ) -> Union[ + "Neo4jHasFileEdge", + "Neo4jHasASTEdge", + "Neo4jParentOfEdge", + "Neo4jHasTextEdge", + "Neo4jNextChunkEdge", + ]: + """Convert the KnowledgeGraphEdge into a Neo4j edge format.""" + match self.type: + case KnowledgeGraphEdgeType.has_file: + return Neo4jHasFileEdge( + source=self.source.to_neo4j_node(), + target=self.target.to_neo4j_node(), + ) + case KnowledgeGraphEdgeType.has_ast: + return Neo4jHasASTEdge( + source=self.source.to_neo4j_node(), + target=self.target.to_neo4j_node(), + ) + case KnowledgeGraphEdgeType.parent_of: + return Neo4jParentOfEdge( + source=self.source.to_neo4j_node(), + target=self.target.to_neo4j_node(), + ) + case KnowledgeGraphEdgeType.has_text: + return Neo4jHasTextEdge( + source=self.source.to_neo4j_node(), + target=self.target.to_neo4j_node(), + ) + case KnowledgeGraphEdgeType.next_chunk: + return Neo4jNextChunkEdge( + source=self.source.to_neo4j_node(), + target=self.target.to_neo4j_node(), + ) + case _: + raise ValueError(f"Unknown edge type: {self.type}") ############################################################################### @@ -229,52 +229,52 @@ def to_neo4j_edge( class Neo4jMetadataNode(TypedDict): - codebase_source: str - local_path: str - https_url: str - commit_id: str + codebase_source: str + local_path: str + https_url: str + commit_id: str class Neo4jFileNode(TypedDict): - node_id: int - basename: str - relative_path: str + node_id: int + basename: str + relative_path: str class Neo4jASTNode(TypedDict): - node_id: int - type: str - start_line: int - end_line: int - text: str + node_id: int + type: str + start_line: int + end_line: int + text: str class Neo4jTextNode(TypedDict): - node_id: int - text: str - metadata: str + node_id: int + text: str + metadata: str class Neo4jHasFileEdge(TypedDict): - source: Neo4jFileNode - target: Neo4jFileNode + source: Neo4jFileNode + 
target: Neo4jFileNode class Neo4jHasASTEdge(TypedDict): - source: Neo4jFileNode - target: Neo4jASTNode + source: Neo4jFileNode + target: Neo4jASTNode class Neo4jParentOfEdge(TypedDict): - source: Neo4jASTNode - target: Neo4jASTNode + source: Neo4jASTNode + target: Neo4jASTNode class Neo4jHasTextEdge(TypedDict): - source: Neo4jFileNode - target: Neo4jTextNode + source: Neo4jFileNode + target: Neo4jTextNode class Neo4jNextChunkEdge(TypedDict): - source: Neo4jTextNode - target: Neo4jTextNode + source: Neo4jTextNode + target: Neo4jTextNode diff --git a/prometheus/graph/knowledge_graph.py b/prometheus/graph/knowledge_graph.py index 77b9c72c..d00c35ea 100644 --- a/prometheus/graph/knowledge_graph.py +++ b/prometheus/graph/knowledge_graph.py @@ -27,428 +27,436 @@ from prometheus.graph.file_graph_builder import FileGraphBuilder from prometheus.graph.graph_types import ( - ASTNode, - CodeBaseSourceEnum, - FileNode, - KnowledgeGraphEdge, - KnowledgeGraphEdgeType, - KnowledgeGraphNode, - MetadataNode, - Neo4jASTNode, - Neo4jFileNode, - Neo4jHasASTEdge, - Neo4jHasFileEdge, - Neo4jHasTextEdge, - Neo4jMetadataNode, - Neo4jNextChunkEdge, - Neo4jParentOfEdge, - Neo4jTextNode, - TextNode, + ASTNode, + CodeBaseSourceEnum, + FileNode, + KnowledgeGraphEdge, + KnowledgeGraphEdgeType, + KnowledgeGraphNode, + MetadataNode, + Neo4jASTNode, + Neo4jFileNode, + Neo4jHasASTEdge, + Neo4jHasFileEdge, + Neo4jHasTextEdge, + Neo4jMetadataNode, + Neo4jNextChunkEdge, + Neo4jParentOfEdge, + Neo4jTextNode, + TextNode, ) class KnowledgeGraph: - def __init__( - self, - max_ast_depth: int, - chunk_size: int, - chunk_overlap: int, - next_node_id: int = 0, - root_node: Optional[KnowledgeGraphNode] = None, - metadata_node: Optional[MetadataNode] = None, - knowledge_graph_nodes: Optional[Sequence[KnowledgeGraphNode]] = None, - knowledge_graph_edges: Optional[Sequence[KnowledgeGraphEdge]] = None, - ): - """Initializes the knowledge graph. 
- - Args: - max_ast_depth: The maximum depth of tree-sitter nodes to parse. - chunk_size: The chunk size for text files. - chunk_overlap: The overlap size for text files. - next_node_id: The next available node id. - root_node: The root node for the knowledge graph. - metadata_node: The metadata node for the knowledge graph. - knowledge_graph_nodes: The initial list of knowledge graph nodes. - knowledge_graph_edges: The initial list of knowledge graph edges. - """ - self.max_ast_depth = max_ast_depth - self._next_node_id = next_node_id - self._root_node = root_node - self._metadata_node = metadata_node - self._knowledge_graph_nodes = knowledge_graph_nodes if knowledge_graph_nodes is not None else [] - self._knowledge_graph_edges = knowledge_graph_edges if knowledge_graph_edges is not None else [] - self._file_graph_builder = FileGraphBuilder(max_ast_depth, chunk_size, chunk_overlap) - self._logger = logging.getLogger("prometheus.graph.knowledge_graph") - - def build_graph( - self, root_dir: Path, https_url: Optional[str] = None, commit_id: Optional[str] = None - ): - """Builds knowledege graph for a codebase at a location. - - Args: - root_dir: The codebase root directory. 
- """ - root_dir = root_dir.absolute() - gitignore_parser = igittigitt.IgnoreParser() - gitignore_parser.parse_rule_files(root_dir) - gitignore_parser.add_rule(".git", root_dir) - - if https_url is not None: - metadata_node = MetadataNode( - codebase_source=CodeBaseSourceEnum.github, - local_path=str(root_dir), - https_url=https_url, - commit_id=commit_id, - ) - else: - metadata_node = MetadataNode( - codebase_source=CodeBaseSourceEnum.local, - local_path=str(root_dir), - https_url=None, - commit_id=None, - ) - self._metadata_node = metadata_node - - # The root node for the whole graph - root_dir_node = FileNode(basename=root_dir.name, relative_path=".") - kg_root_dir_node = KnowledgeGraphNode(self._next_node_id, root_dir_node) - self._next_node_id += 1 - self._knowledge_graph_nodes.append(kg_root_dir_node) - self._root_node = kg_root_dir_node - - file_stack = deque() - file_stack.append((root_dir, kg_root_dir_node)) - - # Now we traverse the file system to parse all the files and create all relationships - while file_stack: - file, kg_file_path_node = file_stack.pop() - - # The file is a directory, we create FileNode for all children files. 
- if file.is_dir(): - self._logger.info(f"Processing directory {file}") - for child_file in sorted(file.iterdir()): - if gitignore_parser.match(child_file): - self._logger.info(f"Skipping {child_file} because it is ignored") - continue - - child_file_node = FileNode( - basename=child_file.name, - relative_path=child_file.relative_to(root_dir).as_posix(), - ) - kg_child_file_node = KnowledgeGraphNode(self._next_node_id, child_file_node) - self._next_node_id += 1 - self._knowledge_graph_nodes.append(kg_child_file_node) - self._knowledge_graph_edges.append( - KnowledgeGraphEdge( - kg_file_path_node, - kg_child_file_node, - KnowledgeGraphEdgeType.has_file, - ) - ) - - file_stack.append((child_file, kg_child_file_node)) - continue - - # The file is a file that file_graph_builder supports, it means that we can - # build a knowledge graph over it. - if self._file_graph_builder.supports_file(file): - self._logger.info(f"Processing file {file}") - try: - next_node_id, kg_nodes, kg_edges = self._file_graph_builder.build_file_graph( - kg_file_path_node, file, self._next_node_id - ) - except UnicodeDecodeError: - self._logger.warning(f"UnicodeDecodeError when processing {file}") - continue + def __init__( + self, + max_ast_depth: int, + chunk_size: int, + chunk_overlap: int, + next_node_id: int = 0, + root_node: Optional[KnowledgeGraphNode] = None, + metadata_node: Optional[MetadataNode] = None, + knowledge_graph_nodes: Optional[Sequence[KnowledgeGraphNode]] = None, + knowledge_graph_edges: Optional[Sequence[KnowledgeGraphEdge]] = None, + ): + """Initializes the knowledge graph. + + Args: + max_ast_depth: The maximum depth of tree-sitter nodes to parse. + chunk_size: The chunk size for text files. + chunk_overlap: The overlap size for text files. + next_node_id: The next available node id. + root_node: The root node for the knowledge graph. + metadata_node: The metadata node for the knowledge graph. + knowledge_graph_nodes: The initial list of knowledge graph nodes. 
+ knowledge_graph_edges: The initial list of knowledge graph edges. + """ + self.max_ast_depth = max_ast_depth self._next_node_id = next_node_id - self._knowledge_graph_nodes.extend(kg_nodes) - self._knowledge_graph_edges.extend(kg_edges) - continue - - # This can only happend for files that are not supported by file_graph_builder. - self._logger.info(f"Skip parsing {file} because it is not supported") - - @classmethod - def from_neo4j( - cls, - metadata_node: MetadataNode, - file_nodes: Sequence[KnowledgeGraphNode], - ast_nodes: Sequence[KnowledgeGraphNode], - text_nodes: Sequence[KnowledgeGraphNode], - parent_of_edges_ids: Sequence[Mapping[str, int]], - has_file_edges_ids: Sequence[Mapping[str, int]], - has_ast_edges_ids: Sequence[Mapping[str, int]], - has_text_edges_ids: Sequence[Mapping[str, int]], - next_chunk_edges_ids: Sequence[Mapping[str, int]], - ): - """Creates a knowledge graph from nodes and edges stored in neo4j.""" - # All nodes - knowledge_graph_nodes = [x for x in itertools.chain(file_nodes, ast_nodes, text_nodes)] - - # All edges - node_id_to_node = {x.node_id: x for x in knowledge_graph_nodes} - parent_of_edges = [ - KnowledgeGraphEdge( - node_id_to_node[parent_of_edge_ids["source_id"]], - node_id_to_node[parent_of_edge_ids["target_id"]], - KnowledgeGraphEdgeType.parent_of, - ) - for parent_of_edge_ids in parent_of_edges_ids - ] - has_file_edges = [ - KnowledgeGraphEdge( - node_id_to_node[has_file_edge_ids["source_id"]], - node_id_to_node[has_file_edge_ids["target_id"]], - KnowledgeGraphEdgeType.has_file, - ) - for has_file_edge_ids in has_file_edges_ids - ] - has_ast_edges = [ - KnowledgeGraphEdge( - node_id_to_node[has_ast_edge_ids["source_id"]], - node_id_to_node[has_ast_edge_ids["target_id"]], - KnowledgeGraphEdgeType.has_ast, - ) - for has_ast_edge_ids in has_ast_edges_ids - ] - has_text_edges = [ - KnowledgeGraphEdge( - node_id_to_node[has_text_edge_ids["source_id"]], - node_id_to_node[has_text_edge_ids["target_id"]], - 
KnowledgeGraphEdgeType.has_text, - ) - for has_text_edge_ids in has_text_edges_ids - ] - next_chunk_edges = [ - KnowledgeGraphEdge( - node_id_to_node[next_chunk_edge_ids["source_id"]], - node_id_to_node[next_chunk_edge_ids["target_id"]], - KnowledgeGraphEdgeType.next_chunk, - ) - for next_chunk_edge_ids in next_chunk_edges_ids - ] - knowledge_graph_edges = [ - x - for x in itertools.chain( - parent_of_edges, has_file_edges, has_ast_edges, has_text_edges, next_chunk_edges - ) - ] - - # Root node - root_node = None - for node in knowledge_graph_nodes: - if node.node_id == 0: - root_node = node - break - if root_node is None: - raise ValueError("Node with node_id 0 not found.") - - return cls( - max_ast_depth=-1, - chunk_size=-1, - chunk_overlap=-1, - next_node_id=len(knowledge_graph_nodes), - root_node=root_node, - metadata_node=metadata_node, - knowledge_graph_nodes=knowledge_graph_nodes, - knowledge_graph_edges=knowledge_graph_edges, - ) - - def get_file_tree(self, max_depth: int = 5, max_lines: int = 5000) -> str: - """Generate a tree-like string representation of the file structure. - - Creates an ASCII tree visualization of the file hierarchy, similar to the Unix 'tree' - command output. The tree is generated using Unicode box-drawing characters and - indentation to show the hierarchical relationship between files and directories. - - Example: - project/ - โ”œโ”€โ”€ src/ - โ”‚ โ”œโ”€โ”€ main.py - โ”‚ โ””โ”€โ”€ utils/ - โ”‚ โ”œโ”€โ”€ helpers.py - โ”‚ โ””โ”€โ”€ config.py - โ””โ”€โ”€ tests/ - โ”œโ”€โ”€ test_main.py - โ””โ”€โ”€ test_utils.py - - Args: - max_depth: Maximum depth of the tree to display. Nodes beyond this depth will - be omitted. Defaults to 5. - max_lines: Maximum number of lines in the output string. Useful for truncating - very large trees. Defaults to 5000. - - Returns: - str: A string representation of the file tree, where each line represents a file - or directory, with appropriate indentation and connecting lines showing - the hierarchy. 
- """ - file_node_adjacency_dict = self._get_file_node_adjacency_dict() - - stack = deque() - stack.append((self._root_node, 0, "", None)) - result_lines = [] - - SPACE = " " - BRANCH = "| " - TEE = "โ”œโ”€โ”€ " - LAST = "โ””โ”€โ”€ " - while stack and (len(result_lines)) < max_lines: - file_node, depth, prefix, is_last = stack.pop() - - # Skip if we've exceeded max_depth - if depth > max_depth: - continue - - pointer = LAST if is_last else TEE - line_prefix = "" if depth == 0 else prefix + pointer - result_lines.append(line_prefix + file_node.node.basename) - - sorted_children_file_node = sorted( - file_node_adjacency_dict[file_node], key=lambda x: x.node.basename - ) - for i in range(len(sorted_children_file_node) - 1, -1, -1): - extension = SPACE if is_last else BRANCH - new_prefix = "" if depth == 0 else prefix + extension - stack.append( - ( - sorted_children_file_node[i], - depth + 1, - new_prefix, - i == len(sorted_children_file_node) - 1, - ) + self._root_node = root_node + self._metadata_node = metadata_node + self._knowledge_graph_nodes = ( + knowledge_graph_nodes if knowledge_graph_nodes is not None else [] ) - return "\n".join(result_lines) - - def get_all_ast_node_types(self) -> Sequence[str]: - ast_node_types = set() - for ast_node in self.get_ast_nodes(): - ast_node_types.add(ast_node.node.type) - return list(ast_node_types) - - def is_built_from_local_codebase(self) -> bool: - return self._metadata_node.codebase_source == CodeBaseSourceEnum.local - - def is_built_from_github(self) -> bool: - return self._metadata_node.codebase_source == CodeBaseSourceEnum.github - - def get_local_path(self) -> Path: - return Path(self._metadata_node.local_path).absolute() - - def get_codebase_https_url(self) -> str: - return self._metadata_node.https_url - - def get_codebase_commit_id(self) -> str: - return self._metadata_node.commit_id - - def _get_file_node_adjacency_dict( - self, - ) -> Mapping[KnowledgeGraphNode, Sequence[KnowledgeGraphNode]]: - 
file_node_adjacency_dict = defaultdict(list) - for has_file_edge in self.get_has_file_edges(): - file_node_adjacency_dict[has_file_edge.source].append(has_file_edge.target) - return file_node_adjacency_dict - - def get_metadata_node(self) -> MetadataNode: - return self._metadata_node - - def get_file_nodes(self) -> Sequence[KnowledgeGraphNode]: - return [ - kg_node for kg_node in self._knowledge_graph_nodes if isinstance(kg_node.node, FileNode) - ] - - def get_ast_nodes(self) -> Sequence[KnowledgeGraphNode]: - return [kg_node for kg_node in self._knowledge_graph_nodes if isinstance(kg_node.node, ASTNode)] - - def get_text_nodes(self) -> Sequence[KnowledgeGraphNode]: - return [ - kg_node for kg_node in self._knowledge_graph_nodes if isinstance(kg_node.node, TextNode) - ] - - def get_has_ast_edges(self) -> Sequence[KnowledgeGraphEdge]: - return [ - kg_edge - for kg_edge in self._knowledge_graph_edges - if kg_edge.type == KnowledgeGraphEdgeType.has_ast - ] - - def get_has_file_edges(self) -> Sequence[KnowledgeGraphEdge]: - return [ - kg_edge - for kg_edge in self._knowledge_graph_edges - if kg_edge.type == KnowledgeGraphEdgeType.has_file - ] - - def get_has_text_edges(self) -> Sequence[KnowledgeGraphEdge]: - return [ - kg_edge - for kg_edge in self._knowledge_graph_edges - if kg_edge.type == KnowledgeGraphEdgeType.has_text - ] - - def get_next_chunk_edges(self) -> Sequence[KnowledgeGraphEdge]: - return [ - kg_edge - for kg_edge in self._knowledge_graph_edges - if kg_edge.type == KnowledgeGraphEdgeType.next_chunk - ] - - def get_parent_of_edges(self) -> Sequence[KnowledgeGraphEdge]: - return [ - kg_edge - for kg_edge in self._knowledge_graph_edges - if kg_edge.type == KnowledgeGraphEdgeType.parent_of - ] - - def get_neo4j_metadata_node(self) -> Neo4jMetadataNode: - return self._metadata_node.to_neo4j_node() - - def get_neo4j_file_nodes(self) -> Sequence[Neo4jFileNode]: - return [kg_node.to_neo4j_node() for kg_node in self.get_file_nodes()] - - def 
get_neo4j_ast_nodes(self) -> Sequence[Neo4jASTNode]: - return [kg_node.to_neo4j_node() for kg_node in self.get_ast_nodes()] - - def get_neo4j_text_nodes(self) -> Sequence[Neo4jTextNode]: - return [kg_node.to_neo4j_node() for kg_node in self.get_text_nodes()] - - def get_neo4j_has_ast_edges(self) -> Sequence[Neo4jHasASTEdge]: - return [kg_edge.to_neo4j_edge() for kg_edge in self.get_has_ast_edges()] - - def get_neo4j_has_file_edges(self) -> Sequence[Neo4jHasFileEdge]: - return [kg_edge.to_neo4j_edge() for kg_edge in self.get_has_file_edges()] - - def get_neo4j_has_text_edges(self) -> Sequence[Neo4jHasTextEdge]: - return [kg_edge.to_neo4j_edge() for kg_edge in self.get_has_text_edges()] - - def get_neo4j_next_chunk_edges(self) -> Sequence[Neo4jNextChunkEdge]: - return [kg_edge.to_neo4j_edge() for kg_edge in self.get_next_chunk_edges()] - - def get_neo4j_parent_of_edges(self) -> Sequence[Neo4jParentOfEdge]: - return [kg_edge.to_neo4j_edge() for kg_edge in self.get_parent_of_edges()] - - def __eq__(self, other: "KnowledgeGraph") -> bool: - if not isinstance(other, KnowledgeGraph): - return False - - if self._metadata_node != other._metadata_node: - return False - - self._knowledge_graph_nodes.sort(key=lambda x: x.node_id) - other._knowledge_graph_nodes.sort(key=lambda x: x.node_id) - - for self_kg_node, other_kg_node in itertools.zip_longest( - self._knowledge_graph_nodes, other._knowledge_graph_nodes, fillvalue=None - ): - if self_kg_node != other_kg_node: - return False + self._knowledge_graph_edges = ( + knowledge_graph_edges if knowledge_graph_edges is not None else [] + ) + self._file_graph_builder = FileGraphBuilder(max_ast_depth, chunk_size, chunk_overlap) + self._logger = logging.getLogger("prometheus.graph.knowledge_graph") - self._knowledge_graph_edges.sort(key=lambda x: (x.source.node_id, x.target.node_id, x.type)) - other._knowledge_graph_edges.sort(key=lambda x: (x.source.node_id, x.target.node_id, x.type)) - for self_kg_edge, other_kg_edge in 
itertools.zip_longest( - self._knowledge_graph_edges, other._knowledge_graph_edges, fillvalue=None + def build_graph( + self, root_dir: Path, https_url: Optional[str] = None, commit_id: Optional[str] = None + ): + """Builds knowledege graph for a codebase at a location. + + Args: + root_dir: The codebase root directory. + """ + root_dir = root_dir.absolute() + gitignore_parser = igittigitt.IgnoreParser() + gitignore_parser.parse_rule_files(root_dir) + gitignore_parser.add_rule(".git", root_dir) + + if https_url is not None: + metadata_node = MetadataNode( + codebase_source=CodeBaseSourceEnum.github, + local_path=str(root_dir), + https_url=https_url, + commit_id=commit_id, + ) + else: + metadata_node = MetadataNode( + codebase_source=CodeBaseSourceEnum.local, + local_path=str(root_dir), + https_url=None, + commit_id=None, + ) + self._metadata_node = metadata_node + + # The root node for the whole graph + root_dir_node = FileNode(basename=root_dir.name, relative_path=".") + kg_root_dir_node = KnowledgeGraphNode(self._next_node_id, root_dir_node) + self._next_node_id += 1 + self._knowledge_graph_nodes.append(kg_root_dir_node) + self._root_node = kg_root_dir_node + + file_stack = deque() + file_stack.append((root_dir, kg_root_dir_node)) + + # Now we traverse the file system to parse all the files and create all relationships + while file_stack: + file, kg_file_path_node = file_stack.pop() + + # The file is a directory, we create FileNode for all children files. 
+ if file.is_dir(): + self._logger.info(f"Processing directory {file}") + for child_file in sorted(file.iterdir()): + if gitignore_parser.match(child_file): + self._logger.info(f"Skipping {child_file} because it is ignored") + continue + + child_file_node = FileNode( + basename=child_file.name, + relative_path=child_file.relative_to(root_dir).as_posix(), + ) + kg_child_file_node = KnowledgeGraphNode(self._next_node_id, child_file_node) + self._next_node_id += 1 + self._knowledge_graph_nodes.append(kg_child_file_node) + self._knowledge_graph_edges.append( + KnowledgeGraphEdge( + kg_file_path_node, + kg_child_file_node, + KnowledgeGraphEdgeType.has_file, + ) + ) + + file_stack.append((child_file, kg_child_file_node)) + continue + + # The file is a file that file_graph_builder supports, it means that we can + # build a knowledge graph over it. + if self._file_graph_builder.supports_file(file): + self._logger.info(f"Processing file {file}") + try: + next_node_id, kg_nodes, kg_edges = self._file_graph_builder.build_file_graph( + kg_file_path_node, file, self._next_node_id + ) + except UnicodeDecodeError: + self._logger.warning(f"UnicodeDecodeError when processing {file}") + continue + self._next_node_id = next_node_id + self._knowledge_graph_nodes.extend(kg_nodes) + self._knowledge_graph_edges.extend(kg_edges) + continue + + # This can only happend for files that are not supported by file_graph_builder. 
+ self._logger.info(f"Skip parsing {file} because it is not supported") + + @classmethod + def from_neo4j( + cls, + metadata_node: MetadataNode, + file_nodes: Sequence[KnowledgeGraphNode], + ast_nodes: Sequence[KnowledgeGraphNode], + text_nodes: Sequence[KnowledgeGraphNode], + parent_of_edges_ids: Sequence[Mapping[str, int]], + has_file_edges_ids: Sequence[Mapping[str, int]], + has_ast_edges_ids: Sequence[Mapping[str, int]], + has_text_edges_ids: Sequence[Mapping[str, int]], + next_chunk_edges_ids: Sequence[Mapping[str, int]], ): - if self_kg_edge != other_kg_edge: - return False + """Creates a knowledge graph from nodes and edges stored in neo4j.""" + # All nodes + knowledge_graph_nodes = [x for x in itertools.chain(file_nodes, ast_nodes, text_nodes)] + + # All edges + node_id_to_node = {x.node_id: x for x in knowledge_graph_nodes} + parent_of_edges = [ + KnowledgeGraphEdge( + node_id_to_node[parent_of_edge_ids["source_id"]], + node_id_to_node[parent_of_edge_ids["target_id"]], + KnowledgeGraphEdgeType.parent_of, + ) + for parent_of_edge_ids in parent_of_edges_ids + ] + has_file_edges = [ + KnowledgeGraphEdge( + node_id_to_node[has_file_edge_ids["source_id"]], + node_id_to_node[has_file_edge_ids["target_id"]], + KnowledgeGraphEdgeType.has_file, + ) + for has_file_edge_ids in has_file_edges_ids + ] + has_ast_edges = [ + KnowledgeGraphEdge( + node_id_to_node[has_ast_edge_ids["source_id"]], + node_id_to_node[has_ast_edge_ids["target_id"]], + KnowledgeGraphEdgeType.has_ast, + ) + for has_ast_edge_ids in has_ast_edges_ids + ] + has_text_edges = [ + KnowledgeGraphEdge( + node_id_to_node[has_text_edge_ids["source_id"]], + node_id_to_node[has_text_edge_ids["target_id"]], + KnowledgeGraphEdgeType.has_text, + ) + for has_text_edge_ids in has_text_edges_ids + ] + next_chunk_edges = [ + KnowledgeGraphEdge( + node_id_to_node[next_chunk_edge_ids["source_id"]], + node_id_to_node[next_chunk_edge_ids["target_id"]], + KnowledgeGraphEdgeType.next_chunk, + ) + for next_chunk_edge_ids 
in next_chunk_edges_ids + ] + knowledge_graph_edges = [ + x + for x in itertools.chain( + parent_of_edges, has_file_edges, has_ast_edges, has_text_edges, next_chunk_edges + ) + ] + + # Root node + root_node = None + for node in knowledge_graph_nodes: + if node.node_id == 0: + root_node = node + break + if root_node is None: + raise ValueError("Node with node_id 0 not found.") + + return cls( + max_ast_depth=-1, + chunk_size=-1, + chunk_overlap=-1, + next_node_id=len(knowledge_graph_nodes), + root_node=root_node, + metadata_node=metadata_node, + knowledge_graph_nodes=knowledge_graph_nodes, + knowledge_graph_edges=knowledge_graph_edges, + ) + + def get_file_tree(self, max_depth: int = 5, max_lines: int = 5000) -> str: + """Generate a tree-like string representation of the file structure. + + Creates an ASCII tree visualization of the file hierarchy, similar to the Unix 'tree' + command output. The tree is generated using Unicode box-drawing characters and + indentation to show the hierarchical relationship between files and directories. + + Example: + project/ + โ”œโ”€โ”€ src/ + โ”‚ โ”œโ”€โ”€ main.py + โ”‚ โ””โ”€โ”€ utils/ + โ”‚ โ”œโ”€โ”€ helpers.py + โ”‚ โ””โ”€โ”€ config.py + โ””โ”€โ”€ tests/ + โ”œโ”€โ”€ test_main.py + โ””โ”€โ”€ test_utils.py + + Args: + max_depth: Maximum depth of the tree to display. Nodes beyond this depth will + be omitted. Defaults to 5. + max_lines: Maximum number of lines in the output string. Useful for truncating + very large trees. Defaults to 5000. + + Returns: + str: A string representation of the file tree, where each line represents a file + or directory, with appropriate indentation and connecting lines showing + the hierarchy. 
+ """ + file_node_adjacency_dict = self._get_file_node_adjacency_dict() + + stack = deque() + stack.append((self._root_node, 0, "", None)) + result_lines = [] + + SPACE = " " + BRANCH = "| " + TEE = "โ”œโ”€โ”€ " + LAST = "โ””โ”€โ”€ " + while stack and (len(result_lines)) < max_lines: + file_node, depth, prefix, is_last = stack.pop() + + # Skip if we've exceeded max_depth + if depth > max_depth: + continue + + pointer = LAST if is_last else TEE + line_prefix = "" if depth == 0 else prefix + pointer + result_lines.append(line_prefix + file_node.node.basename) + + sorted_children_file_node = sorted( + file_node_adjacency_dict[file_node], key=lambda x: x.node.basename + ) + for i in range(len(sorted_children_file_node) - 1, -1, -1): + extension = SPACE if is_last else BRANCH + new_prefix = "" if depth == 0 else prefix + extension + stack.append( + ( + sorted_children_file_node[i], + depth + 1, + new_prefix, + i == len(sorted_children_file_node) - 1, + ) + ) + return "\n".join(result_lines) + + def get_all_ast_node_types(self) -> Sequence[str]: + ast_node_types = set() + for ast_node in self.get_ast_nodes(): + ast_node_types.add(ast_node.node.type) + return list(ast_node_types) + + def is_built_from_local_codebase(self) -> bool: + return self._metadata_node.codebase_source == CodeBaseSourceEnum.local + + def is_built_from_github(self) -> bool: + return self._metadata_node.codebase_source == CodeBaseSourceEnum.github + + def get_local_path(self) -> Path: + return Path(self._metadata_node.local_path).absolute() + + def get_codebase_https_url(self) -> str: + return self._metadata_node.https_url + + def get_codebase_commit_id(self) -> str: + return self._metadata_node.commit_id + + def _get_file_node_adjacency_dict( + self, + ) -> Mapping[KnowledgeGraphNode, Sequence[KnowledgeGraphNode]]: + file_node_adjacency_dict = defaultdict(list) + for has_file_edge in self.get_has_file_edges(): + file_node_adjacency_dict[has_file_edge.source].append(has_file_edge.target) + return 
file_node_adjacency_dict + + def get_metadata_node(self) -> MetadataNode: + return self._metadata_node + + def get_file_nodes(self) -> Sequence[KnowledgeGraphNode]: + return [ + kg_node for kg_node in self._knowledge_graph_nodes if isinstance(kg_node.node, FileNode) + ] + + def get_ast_nodes(self) -> Sequence[KnowledgeGraphNode]: + return [ + kg_node for kg_node in self._knowledge_graph_nodes if isinstance(kg_node.node, ASTNode) + ] + + def get_text_nodes(self) -> Sequence[KnowledgeGraphNode]: + return [ + kg_node for kg_node in self._knowledge_graph_nodes if isinstance(kg_node.node, TextNode) + ] + + def get_has_ast_edges(self) -> Sequence[KnowledgeGraphEdge]: + return [ + kg_edge + for kg_edge in self._knowledge_graph_edges + if kg_edge.type == KnowledgeGraphEdgeType.has_ast + ] + + def get_has_file_edges(self) -> Sequence[KnowledgeGraphEdge]: + return [ + kg_edge + for kg_edge in self._knowledge_graph_edges + if kg_edge.type == KnowledgeGraphEdgeType.has_file + ] + + def get_has_text_edges(self) -> Sequence[KnowledgeGraphEdge]: + return [ + kg_edge + for kg_edge in self._knowledge_graph_edges + if kg_edge.type == KnowledgeGraphEdgeType.has_text + ] + + def get_next_chunk_edges(self) -> Sequence[KnowledgeGraphEdge]: + return [ + kg_edge + for kg_edge in self._knowledge_graph_edges + if kg_edge.type == KnowledgeGraphEdgeType.next_chunk + ] + + def get_parent_of_edges(self) -> Sequence[KnowledgeGraphEdge]: + return [ + kg_edge + for kg_edge in self._knowledge_graph_edges + if kg_edge.type == KnowledgeGraphEdgeType.parent_of + ] + + def get_neo4j_metadata_node(self) -> Neo4jMetadataNode: + return self._metadata_node.to_neo4j_node() + + def get_neo4j_file_nodes(self) -> Sequence[Neo4jFileNode]: + return [kg_node.to_neo4j_node() for kg_node in self.get_file_nodes()] + + def get_neo4j_ast_nodes(self) -> Sequence[Neo4jASTNode]: + return [kg_node.to_neo4j_node() for kg_node in self.get_ast_nodes()] + + def get_neo4j_text_nodes(self) -> Sequence[Neo4jTextNode]: + return 
[kg_node.to_neo4j_node() for kg_node in self.get_text_nodes()] + + def get_neo4j_has_ast_edges(self) -> Sequence[Neo4jHasASTEdge]: + return [kg_edge.to_neo4j_edge() for kg_edge in self.get_has_ast_edges()] + + def get_neo4j_has_file_edges(self) -> Sequence[Neo4jHasFileEdge]: + return [kg_edge.to_neo4j_edge() for kg_edge in self.get_has_file_edges()] + + def get_neo4j_has_text_edges(self) -> Sequence[Neo4jHasTextEdge]: + return [kg_edge.to_neo4j_edge() for kg_edge in self.get_has_text_edges()] + + def get_neo4j_next_chunk_edges(self) -> Sequence[Neo4jNextChunkEdge]: + return [kg_edge.to_neo4j_edge() for kg_edge in self.get_next_chunk_edges()] + + def get_neo4j_parent_of_edges(self) -> Sequence[Neo4jParentOfEdge]: + return [kg_edge.to_neo4j_edge() for kg_edge in self.get_parent_of_edges()] + + def __eq__(self, other: "KnowledgeGraph") -> bool: + if not isinstance(other, KnowledgeGraph): + return False + + if self._metadata_node != other._metadata_node: + return False + + self._knowledge_graph_nodes.sort(key=lambda x: x.node_id) + other._knowledge_graph_nodes.sort(key=lambda x: x.node_id) + + for self_kg_node, other_kg_node in itertools.zip_longest( + self._knowledge_graph_nodes, other._knowledge_graph_nodes, fillvalue=None + ): + if self_kg_node != other_kg_node: + return False + + self._knowledge_graph_edges.sort(key=lambda x: (x.source.node_id, x.target.node_id, x.type)) + other._knowledge_graph_edges.sort( + key=lambda x: (x.source.node_id, x.target.node_id, x.type) + ) + for self_kg_edge, other_kg_edge in itertools.zip_longest( + self._knowledge_graph_edges, other._knowledge_graph_edges, fillvalue=None + ): + if self_kg_edge != other_kg_edge: + return False - return True + return True diff --git a/prometheus/lang_graph/graphs/issue_graph.py b/prometheus/lang_graph/graphs/issue_graph.py index baac95e6..834be6be 100644 --- a/prometheus/lang_graph/graphs/issue_graph.py +++ b/prometheus/lang_graph/graphs/issue_graph.py @@ -10,101 +10,101 @@ from 
prometheus.lang_graph.graphs.issue_state import IssueState, IssueType from prometheus.lang_graph.nodes.issue_bug_subgraph_node import IssueBugSubgraphNode from prometheus.lang_graph.nodes.issue_classification_subgraph_node import ( - IssueClassificationSubgraphNode, + IssueClassificationSubgraphNode, ) from prometheus.lang_graph.nodes.noop_node import NoopNode class IssueGraph: - def __init__( - self, - advanced_model: BaseChatModel, - base_model: BaseChatModel, - kg: KnowledgeGraph, - git_repo: GitRepository, - neo4j_driver: neo4j.Driver, - max_token_per_neo4j_result: int, - container: BaseContainer, - build_commands: Optional[Sequence[str]] = None, - test_commands: Optional[Sequence[str]] = None, - ): - self.git_repo = git_repo + def __init__( + self, + advanced_model: BaseChatModel, + base_model: BaseChatModel, + kg: KnowledgeGraph, + git_repo: GitRepository, + neo4j_driver: neo4j.Driver, + max_token_per_neo4j_result: int, + container: BaseContainer, + build_commands: Optional[Sequence[str]] = None, + test_commands: Optional[Sequence[str]] = None, + ): + self.git_repo = git_repo - issue_type_branch_node = NoopNode() - issue_classification_subgraph_node = IssueClassificationSubgraphNode( - model=base_model, - kg=kg, - neo4j_driver=neo4j_driver, - max_token_per_neo4j_result=max_token_per_neo4j_result, - ) - issue_bug_subgraph_node = IssueBugSubgraphNode( - advanced_model=advanced_model, - base_model=base_model, - container=container, - kg=kg, - git_repo=git_repo, - neo4j_driver=neo4j_driver, - max_token_per_neo4j_result=max_token_per_neo4j_result, - build_commands=build_commands, - test_commands=test_commands, - ) + issue_type_branch_node = NoopNode() + issue_classification_subgraph_node = IssueClassificationSubgraphNode( + model=base_model, + kg=kg, + neo4j_driver=neo4j_driver, + max_token_per_neo4j_result=max_token_per_neo4j_result, + ) + issue_bug_subgraph_node = IssueBugSubgraphNode( + advanced_model=advanced_model, + base_model=base_model, + 
container=container, + kg=kg, + git_repo=git_repo, + neo4j_driver=neo4j_driver, + max_token_per_neo4j_result=max_token_per_neo4j_result, + build_commands=build_commands, + test_commands=test_commands, + ) - workflow = StateGraph(IssueState) + workflow = StateGraph(IssueState) - workflow.add_node("issue_type_branch_node", issue_type_branch_node) - workflow.add_node("issue_classification_subgraph_node", issue_classification_subgraph_node) - workflow.add_node("issue_bug_subgraph_node", issue_bug_subgraph_node) + workflow.add_node("issue_type_branch_node", issue_type_branch_node) + workflow.add_node("issue_classification_subgraph_node", issue_classification_subgraph_node) + workflow.add_node("issue_bug_subgraph_node", issue_bug_subgraph_node) - workflow.set_entry_point("issue_type_branch_node") - workflow.add_conditional_edges( - "issue_type_branch_node", - lambda state: state["issue_type"], - { - IssueType.AUTO: "issue_classification_subgraph_node", - IssueType.BUG: "issue_bug_subgraph_node", - IssueType.FEATURE: END, - IssueType.DOCUMENTATION: END, - IssueType.QUESTION: END, - }, - ) - workflow.add_conditional_edges( - "issue_classification_subgraph_node", - lambda state: state["issue_type"], - { - IssueType.BUG: "issue_bug_subgraph_node", - IssueType.FEATURE: END, - IssueType.DOCUMENTATION: END, - IssueType.QUESTION: END, - }, - ) - workflow.add_edge("issue_bug_subgraph_node", END) + workflow.set_entry_point("issue_type_branch_node") + workflow.add_conditional_edges( + "issue_type_branch_node", + lambda state: state["issue_type"], + { + IssueType.AUTO: "issue_classification_subgraph_node", + IssueType.BUG: "issue_bug_subgraph_node", + IssueType.FEATURE: END, + IssueType.DOCUMENTATION: END, + IssueType.QUESTION: END, + }, + ) + workflow.add_conditional_edges( + "issue_classification_subgraph_node", + lambda state: state["issue_type"], + { + IssueType.BUG: "issue_bug_subgraph_node", + IssueType.FEATURE: END, + IssueType.DOCUMENTATION: END, + IssueType.QUESTION: END, + 
}, + ) + workflow.add_edge("issue_bug_subgraph_node", END) - self.graph = workflow.compile() + self.graph = workflow.compile() - def invoke( - self, - issue_title: str, - issue_body: str, - issue_comments: Sequence[Mapping[str, str]], - issue_type: IssueType, - run_build: bool, - run_existing_test: bool, - number_of_candidate_patch: int, - ): - config = None + def invoke( + self, + issue_title: str, + issue_body: str, + issue_comments: Sequence[Mapping[str, str]], + issue_type: IssueType, + run_build: bool, + run_existing_test: bool, + number_of_candidate_patch: int, + ): + config = None - input_state = { - "issue_title": issue_title, - "issue_body": issue_body, - "issue_comments": issue_comments, - "issue_type": issue_type, - "run_build": run_build, - "run_existing_test": run_existing_test, - "number_of_candidate_patch": number_of_candidate_patch, - } + input_state = { + "issue_title": issue_title, + "issue_body": issue_body, + "issue_comments": issue_comments, + "issue_type": issue_type, + "run_build": run_build, + "run_existing_test": run_existing_test, + "number_of_candidate_patch": number_of_candidate_patch, + } - output_state = self.graph.invoke(input_state, config) + output_state = self.graph.invoke(input_state, config) - self.git_repo.reset_repository() + self.git_repo.reset_repository() - return output_state + return output_state diff --git a/prometheus/lang_graph/graphs/issue_state.py b/prometheus/lang_graph/graphs/issue_state.py index e579f204..ee460f03 100644 --- a/prometheus/lang_graph/graphs/issue_state.py +++ b/prometheus/lang_graph/graphs/issue_state.py @@ -3,27 +3,27 @@ class IssueType(StrEnum): - AUTO = "auto" - BUG = "bug" - FEATURE = "feature" - DOCUMENTATION = "documentation" - QUESTION = "question" + AUTO = "auto" + BUG = "bug" + FEATURE = "feature" + DOCUMENTATION = "documentation" + QUESTION = "question" class IssueState(TypedDict): - # Attributes provided by the user - issue_title: str - issue_body: str - issue_comments: 
Sequence[Mapping[str, str]] - issue_type: IssueType - run_build: bool - run_existing_test: bool - number_of_candidate_patch: int + # Attributes provided by the user + issue_title: str + issue_body: str + issue_comments: Sequence[Mapping[str, str]] + issue_type: IssueType + run_build: bool + run_existing_test: bool + number_of_candidate_patch: int - edit_patch: str + edit_patch: str - passed_reproducing_test: bool - passed_build: bool - passed_existing_test: bool + passed_reproducing_test: bool + passed_build: bool + passed_existing_test: bool - issue_response: str + issue_response: str diff --git a/prometheus/lang_graph/nodes/bug_fix_verification_subgraph_node.py b/prometheus/lang_graph/nodes/bug_fix_verification_subgraph_node.py index eba5c83c..25620677 100644 --- a/prometheus/lang_graph/nodes/bug_fix_verification_subgraph_node.py +++ b/prometheus/lang_graph/nodes/bug_fix_verification_subgraph_node.py @@ -8,34 +8,36 @@ class BugFixVerificationSubgraphNode: - def __init__( - self, - model: BaseChatModel, - container: BaseContainer, - ): - self._logger = logging.getLogger( - "prometheus.lang_graph.nodes.bug_fix_verification_subgraph_node" - ) - self.subgraph = BugFixVerificationSubgraph( - model=model, - container=container, - ) - - def __call__(self, state: IssueBugState): - self._logger.info("Enter bug_fix_verification_subgraph_node") - self._logger.debug(f"reproduced_bug_file: {state['reproduced_bug_file']}") - self._logger.debug(f"reproduced_bug_commands: {state['reproduced_bug_commands']}") - - output_state = self.subgraph.invoke( - reproduced_bug_file=state["reproduced_bug_file"], - reproduced_bug_commands=state["reproduced_bug_commands"], - ) - - self._logger.info( - f"Passing bug reproducing test: {not bool(output_state['reproducing_test_fail_log'])}" - ) - self._logger.debug(f"reproducing_test_fail_log: {output_state['reproducing_test_fail_log']}") - - return { - "reproducing_test_fail_log": output_state["reproducing_test_fail_log"], - } + def __init__( + 
self, + model: BaseChatModel, + container: BaseContainer, + ): + self._logger = logging.getLogger( + "prometheus.lang_graph.nodes.bug_fix_verification_subgraph_node" + ) + self.subgraph = BugFixVerificationSubgraph( + model=model, + container=container, + ) + + def __call__(self, state: IssueBugState): + self._logger.info("Enter bug_fix_verification_subgraph_node") + self._logger.debug(f"reproduced_bug_file: {state['reproduced_bug_file']}") + self._logger.debug(f"reproduced_bug_commands: {state['reproduced_bug_commands']}") + + output_state = self.subgraph.invoke( + reproduced_bug_file=state["reproduced_bug_file"], + reproduced_bug_commands=state["reproduced_bug_commands"], + ) + + self._logger.info( + f"Passing bug reproducing test: {not bool(output_state['reproducing_test_fail_log'])}" + ) + self._logger.debug( + f"reproducing_test_fail_log: {output_state['reproducing_test_fail_log']}" + ) + + return { + "reproducing_test_fail_log": output_state["reproducing_test_fail_log"], + } diff --git a/prometheus/lang_graph/nodes/bug_fix_verify_node.py b/prometheus/lang_graph/nodes/bug_fix_verify_node.py index ad459db4..ae397d22 100644 --- a/prometheus/lang_graph/nodes/bug_fix_verify_node.py +++ b/prometheus/lang_graph/nodes/bug_fix_verify_node.py @@ -11,7 +11,7 @@ class BugFixVerifyNode: - SYS_PROMPT = """\ + SYS_PROMPT = """\ You are a bug fix verification agent. Your role is to verify whether a bug has been fixed by running the reproduction steps and reporting the results accurately. Your tasks are to: @@ -37,7 +37,7 @@ class BugFixVerifyNode: Remember: Your only job is to execute the commands and report results faithfully. Do not offer suggestions, analyze results, or try to fix issues. 
""" - HUMAN_PROMPT = """\ + HUMAN_PROMPT = """\ Reproducing bug file: {reproduced_bug_file} @@ -45,39 +45,39 @@ class BugFixVerifyNode: {reproduced_bug_commands} """ - def __init__(self, model: BaseChatModel, container: BaseContainer): - self.tools = self._init_tools(container) - self.model_with_tools = model.bind_tools(self.tools) - self.system_prompt = SystemMessage(self.SYS_PROMPT) - self._logger = logging.getLogger("prometheus.lang_graph.nodes.bug_reproducing_verify_node") - - def _init_tools(self, container: BaseContainer): - tools = [] - - run_command_fn = functools.partial(container_command.run_command, container=container) - run_command_tool = StructuredTool.from_function( - func=run_command_fn, - name=container_command.run_command.__name__, - description=container_command.RUN_COMMAND_DESCRIPTION, - args_schema=container_command.RunCommandInput, - ) - tools.append(run_command_tool) - - return tools - - def format_human_message(self, state: BugFixVerficationState) -> HumanMessage: - return HumanMessage( - self.HUMAN_PROMPT.format( - reproduced_bug_file=state["reproduced_bug_file"], - reproduced_bug_commands=state["reproduced_bug_commands"], - ) - ) - - def __call__(self, state: BugFixVerficationState): - human_message = self.format_human_message(state) - message_history = [self.system_prompt, human_message] + state["bug_fix_verify_messages"] - - response = self.model_with_tools.invoke(message_history) - - self._logger.debug(response) - return {"bug_fix_verify_messages": [response]} + def __init__(self, model: BaseChatModel, container: BaseContainer): + self.tools = self._init_tools(container) + self.model_with_tools = model.bind_tools(self.tools) + self.system_prompt = SystemMessage(self.SYS_PROMPT) + self._logger = logging.getLogger("prometheus.lang_graph.nodes.bug_reproducing_verify_node") + + def _init_tools(self, container: BaseContainer): + tools = [] + + run_command_fn = functools.partial(container_command.run_command, container=container) + 
run_command_tool = StructuredTool.from_function( + func=run_command_fn, + name=container_command.run_command.__name__, + description=container_command.RUN_COMMAND_DESCRIPTION, + args_schema=container_command.RunCommandInput, + ) + tools.append(run_command_tool) + + return tools + + def format_human_message(self, state: BugFixVerficationState) -> HumanMessage: + return HumanMessage( + self.HUMAN_PROMPT.format( + reproduced_bug_file=state["reproduced_bug_file"], + reproduced_bug_commands=state["reproduced_bug_commands"], + ) + ) + + def __call__(self, state: BugFixVerficationState): + human_message = self.format_human_message(state) + message_history = [self.system_prompt, human_message] + state["bug_fix_verify_messages"] + + response = self.model_with_tools.invoke(message_history) + + self._logger.debug(response) + return {"bug_fix_verify_messages": [response]} diff --git a/prometheus/lang_graph/nodes/bug_fix_verify_structured_node.py b/prometheus/lang_graph/nodes/bug_fix_verify_structured_node.py index 5d3c66dc..ac4fb533 100644 --- a/prometheus/lang_graph/nodes/bug_fix_verify_structured_node.py +++ b/prometheus/lang_graph/nodes/bug_fix_verify_structured_node.py @@ -9,13 +9,13 @@ class BugFixVerifyStructureOutput(BaseModel): - reproducing_test_fail_log: str = Field( - description="If the test failed, contains the complete test failure log. Otherwise empty string" - ) + reproducing_test_fail_log: str = Field( + description="If the test failed, contains the complete test failure log. Otherwise empty string" + ) class BugFixVerifyStructuredNode: - SYS_PROMPT = """\ + SYS_PROMPT = """\ You are a test result parser. Your only task is to check if the bug reproducing test now passes after code changes. 
Your task is to: @@ -84,20 +84,22 @@ def test_empty_array_parsing(): - Any error or failure means the bug isn't fixed yet """.replace("{", "{{").replace("}", "}}") - def __init__(self, model: BaseChatModel): - prompt = ChatPromptTemplate.from_messages( - [("system", self.SYS_PROMPT), ("human", "{bug_reproducing_logs}")] - ) - structured_llm = model.with_structured_output(BugFixVerifyStructureOutput) - self.model = prompt | structured_llm - self._logger = logging.getLogger("prometheus.lang_graph.nodes.bug_fix_verify_structured_node") - - def __call__(self, state: BugFixVerficationState): - bug_fix_verify_message = get_last_message_content(state["bug_fix_verify_messages"]) - response = self.model.invoke({"bug_reproducing_logs": bug_fix_verify_message}) - - self._logger.debug(response) - - return { - "reproducing_test_fail_log": response.reproducing_test_fail_log, - } + def __init__(self, model: BaseChatModel): + prompt = ChatPromptTemplate.from_messages( + [("system", self.SYS_PROMPT), ("human", "{bug_reproducing_logs}")] + ) + structured_llm = model.with_structured_output(BugFixVerifyStructureOutput) + self.model = prompt | structured_llm + self._logger = logging.getLogger( + "prometheus.lang_graph.nodes.bug_fix_verify_structured_node" + ) + + def __call__(self, state: BugFixVerficationState): + bug_fix_verify_message = get_last_message_content(state["bug_fix_verify_messages"]) + response = self.model.invoke({"bug_reproducing_logs": bug_fix_verify_message}) + + self._logger.debug(response) + + return { + "reproducing_test_fail_log": response.reproducing_test_fail_log, + } diff --git a/prometheus/lang_graph/nodes/bug_reproducing_execute_node.py b/prometheus/lang_graph/nodes/bug_reproducing_execute_node.py index 1cab808b..4434984a 100644 --- a/prometheus/lang_graph/nodes/bug_reproducing_execute_node.py +++ b/prometheus/lang_graph/nodes/bug_reproducing_execute_node.py @@ -14,7 +14,7 @@ class BugReproducingExecuteNode: - SYS_PROMPT = """\ + SYS_PROMPT = """\ You are a 
testing expert focused solely on executing THE SINGLE bug reproduction test file. Your only goal is to run the test file created by the previous agent and return its output as it is. @@ -28,7 +28,7 @@ class BugReproducingExecuteNode: * STOP TRYING IF THE TEST EXECUTES. """ - HUMAN_PROMPT = """\ + HUMAN_PROMPT = """\ ISSUE INFORMATION: Title: {title} Description: {body} @@ -41,77 +41,79 @@ class BugReproducingExecuteNode: {test_commands} """ - def __init__( - self, - model: BaseChatModel, - container: BaseContainer, - test_commands: Optional[Sequence[str]] = None, - ): - self.test_commands = test_commands - self.tools = self._init_tools(container) - self.model_with_tools = model.bind_tools(self.tools) - self.system_prompt = SystemMessage(self.SYS_PROMPT) - self._logger = logging.getLogger("prometheus.lang_graph.nodes.bug_reproducing_execute_node") - - def _init_tools(self, container: BaseContainer): - tools = [] - - run_command_fn = functools.partial(container_command.run_command, container=container) - run_command_tool = StructuredTool.from_function( - func=run_command_fn, - name=container_command.run_command.__name__, - description=container_command.RUN_COMMAND_DESCRIPTION, - args_schema=container_command.RunCommandInput, - ) - tools.append(run_command_tool) - - return tools - - def added_test_filename(self, state: BugReproductionState) -> str: - added_files, modified_file, removed_files = get_updated_files(state["bug_reproducing_patch"]) - if removed_files: - raise ValueError("The bug reproducing patch delete files") - if modified_file: - raise ValueError("The bug reproducing patch modified existing files") - if len(added_files) != 1: - raise ValueError("The bug reproducing patch added not one files") - return added_files[0] - - def format_human_message( - self, state: BugReproductionState, reproduced_bug_file: str - ) -> HumanMessage: - test_commands_str = "" - if self.test_commands: - test_commands_str = format_test_commands(self.test_commands) - return 
HumanMessage( - self.HUMAN_PROMPT.format( - title=state["issue_title"], - body=state["issue_body"], - comments=state["issue_comments"], - reproduced_bug_file=reproduced_bug_file, - test_commands=test_commands_str, - ) - ) - - def __call__(self, state: BugReproductionState): - try: - reproduced_bug_file = self.added_test_filename(state) - except ValueError as e: - self._logger.error(f"Error in bug reproducing execute node: {e}") - return { - "bug_reproducing_execute_messages": [ - AIMessage(f"THE TEST WAS NOT EXECUTED BECAUSE OF AN ERROR: {str(e)}") - ], - } - - message_history = [ - self.system_prompt, - self.format_human_message(state, reproduced_bug_file), - ] + state["bug_reproducing_execute_messages"] - - response = self.model_with_tools.invoke(message_history) - self._logger.debug(response) - return { - "bug_reproducing_execute_messages": [response], - "reproduced_bug_file": reproduced_bug_file, - } + def __init__( + self, + model: BaseChatModel, + container: BaseContainer, + test_commands: Optional[Sequence[str]] = None, + ): + self.test_commands = test_commands + self.tools = self._init_tools(container) + self.model_with_tools = model.bind_tools(self.tools) + self.system_prompt = SystemMessage(self.SYS_PROMPT) + self._logger = logging.getLogger("prometheus.lang_graph.nodes.bug_reproducing_execute_node") + + def _init_tools(self, container: BaseContainer): + tools = [] + + run_command_fn = functools.partial(container_command.run_command, container=container) + run_command_tool = StructuredTool.from_function( + func=run_command_fn, + name=container_command.run_command.__name__, + description=container_command.RUN_COMMAND_DESCRIPTION, + args_schema=container_command.RunCommandInput, + ) + tools.append(run_command_tool) + + return tools + + def added_test_filename(self, state: BugReproductionState) -> str: + added_files, modified_file, removed_files = get_updated_files( + state["bug_reproducing_patch"] + ) + if removed_files: + raise ValueError("The bug 
reproducing patch delete files") + if modified_file: + raise ValueError("The bug reproducing patch modified existing files") + if len(added_files) != 1: + raise ValueError("The bug reproducing patch added not one files") + return added_files[0] + + def format_human_message( + self, state: BugReproductionState, reproduced_bug_file: str + ) -> HumanMessage: + test_commands_str = "" + if self.test_commands: + test_commands_str = format_test_commands(self.test_commands) + return HumanMessage( + self.HUMAN_PROMPT.format( + title=state["issue_title"], + body=state["issue_body"], + comments=state["issue_comments"], + reproduced_bug_file=reproduced_bug_file, + test_commands=test_commands_str, + ) + ) + + def __call__(self, state: BugReproductionState): + try: + reproduced_bug_file = self.added_test_filename(state) + except ValueError as e: + self._logger.error(f"Error in bug reproducing execute node: {e}") + return { + "bug_reproducing_execute_messages": [ + AIMessage(f"THE TEST WAS NOT EXECUTED BECAUSE OF AN ERROR: {str(e)}") + ], + } + + message_history = [ + self.system_prompt, + self.format_human_message(state, reproduced_bug_file), + ] + state["bug_reproducing_execute_messages"] + + response = self.model_with_tools.invoke(message_history) + self._logger.debug(response) + return { + "bug_reproducing_execute_messages": [response], + "reproduced_bug_file": reproduced_bug_file, + } diff --git a/prometheus/lang_graph/nodes/bug_reproducing_file_node.py b/prometheus/lang_graph/nodes/bug_reproducing_file_node.py index f7c6c578..25275b84 100644 --- a/prometheus/lang_graph/nodes/bug_reproducing_file_node.py +++ b/prometheus/lang_graph/nodes/bug_reproducing_file_node.py @@ -12,7 +12,7 @@ class BugReproducingFileNode: - SYS_PROMPT = """\ + SYS_PROMPT = """\ You are a test file manager. Your task is to save the provided bug reproducing code in the project. You should: 1. 
Examine the project structure to identify existing test file naming patterns and test folder organization @@ -27,7 +27,7 @@ class BugReproducingFileNode: Respond with the created file's relative path. """ - HUMAN_PROMPT = """\ + HUMAN_PROMPT = """\ Save this bug reproducing code in the project: {bug_reproducing_code} @@ -35,61 +35,63 @@ class BugReproducingFileNode: {project_structure} """ - def __init__( - self, - model: BaseChatModel, - kg: KnowledgeGraph, - ): - self.kg = kg - self.tools = self._init_tools(str(kg.get_local_path())) - self.model_with_tools = model.bind_tools(self.tools) - self.system_prompt = SystemMessage(self.SYS_PROMPT) - self._logger = logging.getLogger("prometheus.lang_graph.nodes.bug_reproducing_file_node") - - def _init_tools(self, root_path: str): - """Initializes file operation tools with the given root path. - - Args: - root_path: Base directory path for all file operations. - - Returns: - List of StructuredTool instances configured for file operations. - """ - tools = [] - - read_file_fn = functools.partial(file_operation.read_file, root_path=root_path) - read_file_tool = StructuredTool.from_function( - func=read_file_fn, - name=file_operation.read_file.__name__, - description=file_operation.READ_FILE_DESCRIPTION, - args_schema=file_operation.ReadFileInput, - ) - tools.append(read_file_tool) - - create_file_fn = functools.partial(file_operation.create_file, root_path=root_path) - create_file_tool = StructuredTool.from_function( - func=create_file_fn, - name=file_operation.create_file.__name__, - description=file_operation.CREATE_FILE_DESCRIPTION, - args_schema=file_operation.CreateFileInput, - ) - tools.append(create_file_tool) - - return tools - - def format_human_message(self, state: BugReproductionState) -> HumanMessage: - return HumanMessage( - self.HUMAN_PROMPT.format( - bug_reproducing_code=get_last_message_content(state["bug_reproducing_write_messages"]), - project_structure=self.kg.get_file_tree(), - ) - ) - - def 
__call__(self, state: BugReproductionState): - message_history = [self.system_prompt, self.format_human_message(state)] + state[ - "bug_reproducing_file_messages" - ] - - response = self.model_with_tools.invoke(message_history) - self._logger.debug(response) - return {"bug_reproducing_file_messages": [response]} + def __init__( + self, + model: BaseChatModel, + kg: KnowledgeGraph, + ): + self.kg = kg + self.tools = self._init_tools(str(kg.get_local_path())) + self.model_with_tools = model.bind_tools(self.tools) + self.system_prompt = SystemMessage(self.SYS_PROMPT) + self._logger = logging.getLogger("prometheus.lang_graph.nodes.bug_reproducing_file_node") + + def _init_tools(self, root_path: str): + """Initializes file operation tools with the given root path. + + Args: + root_path: Base directory path for all file operations. + + Returns: + List of StructuredTool instances configured for file operations. + """ + tools = [] + + read_file_fn = functools.partial(file_operation.read_file, root_path=root_path) + read_file_tool = StructuredTool.from_function( + func=read_file_fn, + name=file_operation.read_file.__name__, + description=file_operation.READ_FILE_DESCRIPTION, + args_schema=file_operation.ReadFileInput, + ) + tools.append(read_file_tool) + + create_file_fn = functools.partial(file_operation.create_file, root_path=root_path) + create_file_tool = StructuredTool.from_function( + func=create_file_fn, + name=file_operation.create_file.__name__, + description=file_operation.CREATE_FILE_DESCRIPTION, + args_schema=file_operation.CreateFileInput, + ) + tools.append(create_file_tool) + + return tools + + def format_human_message(self, state: BugReproductionState) -> HumanMessage: + return HumanMessage( + self.HUMAN_PROMPT.format( + bug_reproducing_code=get_last_message_content( + state["bug_reproducing_write_messages"] + ), + project_structure=self.kg.get_file_tree(), + ) + ) + + def __call__(self, state: BugReproductionState): + message_history = [self.system_prompt, 
self.format_human_message(state)] + state[ + "bug_reproducing_file_messages" + ] + + response = self.model_with_tools.invoke(message_history) + self._logger.debug(response) + return {"bug_reproducing_file_messages": [response]} diff --git a/prometheus/lang_graph/nodes/bug_reproducing_structured_node.py b/prometheus/lang_graph/nodes/bug_reproducing_structured_node.py index 51343a1b..9fa204aa 100644 --- a/prometheus/lang_graph/nodes/bug_reproducing_structured_node.py +++ b/prometheus/lang_graph/nodes/bug_reproducing_structured_node.py @@ -8,25 +8,25 @@ from prometheus.lang_graph.subgraphs.bug_reproduction_state import BugReproductionState from prometheus.utils.issue_util import format_issue_comments from prometheus.utils.lang_graph_util import ( - format_agent_tool_message_history, - get_last_message_content, + format_agent_tool_message_history, + get_last_message_content, ) class BugReproducingStructuredOutput(BaseModel): - reproduced_bug: bool = Field( - description="True ONLY if test fails as described in the issue and uses provided examples if any exist" - ) - reproduced_bug_failure_log: str = Field( - description="Complete test execution log. If test passes, include explanation that test should fail to demonstrate the bug" - ) - reproduced_bug_commands: Sequence[str] = Field( - description="A list of commands run to the single file to reproduce the bug" - ) + reproduced_bug: bool = Field( + description="True ONLY if test fails as described in the issue and uses provided examples if any exist" + ) + reproduced_bug_failure_log: str = Field( + description="Complete test execution log. If test passes, include explanation that test should fail to demonstrate the bug" + ) + reproduced_bug_commands: Sequence[str] = Field( + description="A list of commands run to the single file to reproduce the bug" + ) class BugReproducingStructuredNode: - SYS_PROMPT = """\ + SYS_PROMPT = """\ You are an agent that verifies if a test properly reproduces a reported bug. You analyze: 1. 
Issue description and any provided examples 2. The test code written to reproduce the bug @@ -115,7 +115,7 @@ def test_array_pop(): } """.replace("{", "{{").replace("}", "}}") - HUMAN_PROMPT = """\ + HUMAN_PROMPT = """\ ISSUE INFORMATION: Title: {title} Description: {body} @@ -128,31 +128,33 @@ def test_array_pop(): {bug_reproducing_log} """ - def __init__(self, model: BaseChatModel): - prompt = ChatPromptTemplate.from_messages( - [("system", self.SYS_PROMPT), ("human", "{bug_reproducing_info}")] - ) - structured_llm = model.with_structured_output(BugReproducingStructuredOutput) - self.model = prompt | structured_llm - self._logger = logging.getLogger("prometheus.lang_graph.nodes.bug_reproducing_structured_node") - - def __call__(self, state: BugReproductionState): - bug_reproducing_log = format_agent_tool_message_history( - state["bug_reproducing_execute_messages"] - ) - bug_reproducing_info = self.HUMAN_PROMPT.format( - title=state["issue_title"], - body=state["issue_body"], - comments=format_issue_comments(state["issue_comments"]), - bug_reproducing_code=get_last_message_content(state["bug_reproducing_write_messages"]), - bug_reproducing_log=bug_reproducing_log, - ) - - response = self.model.invoke({"bug_reproducing_info": bug_reproducing_info}) - self._logger.debug(response) - - return { - "reproduced_bug": response.reproduced_bug, - "reproduced_bug_failure_log": response.reproduced_bug_failure_log, - "reproduced_bug_commands": response.reproduced_bug_commands, - } + def __init__(self, model: BaseChatModel): + prompt = ChatPromptTemplate.from_messages( + [("system", self.SYS_PROMPT), ("human", "{bug_reproducing_info}")] + ) + structured_llm = model.with_structured_output(BugReproducingStructuredOutput) + self.model = prompt | structured_llm + self._logger = logging.getLogger( + "prometheus.lang_graph.nodes.bug_reproducing_structured_node" + ) + + def __call__(self, state: BugReproductionState): + bug_reproducing_log = format_agent_tool_message_history( + 
state["bug_reproducing_execute_messages"] + ) + bug_reproducing_info = self.HUMAN_PROMPT.format( + title=state["issue_title"], + body=state["issue_body"], + comments=format_issue_comments(state["issue_comments"]), + bug_reproducing_code=get_last_message_content(state["bug_reproducing_write_messages"]), + bug_reproducing_log=bug_reproducing_log, + ) + + response = self.model.invoke({"bug_reproducing_info": bug_reproducing_info}) + self._logger.debug(response) + + return { + "reproduced_bug": response.reproduced_bug, + "reproduced_bug_failure_log": response.reproduced_bug_failure_log, + "reproduced_bug_commands": response.reproduced_bug_commands, + } diff --git a/prometheus/lang_graph/nodes/bug_reproducing_write_message_node.py b/prometheus/lang_graph/nodes/bug_reproducing_write_message_node.py index 5f76ffd4..852142e2 100644 --- a/prometheus/lang_graph/nodes/bug_reproducing_write_message_node.py +++ b/prometheus/lang_graph/nodes/bug_reproducing_write_message_node.py @@ -7,7 +7,7 @@ class BugReproducingWriteMessageNode: - FIRST_HUMAN_PROMPT = """\ + FIRST_HUMAN_PROMPT = """\ {issue_info} Bug reproducing context: @@ -16,36 +16,36 @@ class BugReproducingWriteMessageNode: Now generate the complete self-contained test case that reproduces the bug with the same error/exception. """ - FOLLOWUP_HUMAN_PROMPT = """\ + FOLLOWUP_HUMAN_PROMPT = """\ Your previous test case failed to reproduce the bug. Here is the failure log: {reproduced_bug_failure_log} Now think about what went wrong and generate the complete self-contained test case that reproduces the bug with the same error/exception again. 
""" - def __init__(self): - self._logger = logging.getLogger( - "prometheus.lang_graph.nodes.bug_reproducing_write_message_node" - ) + def __init__(self): + self._logger = logging.getLogger( + "prometheus.lang_graph.nodes.bug_reproducing_write_message_node" + ) - def format_human_message(self, state: BugReproductionState): - if "reproduced_bug_failure_log" in state and state["reproduced_bug_failure_log"]: - return HumanMessage( - self.FOLLOWUP_HUMAN_PROMPT.format( - reproduced_bug_failure_log=state["reproduced_bug_failure_log"], + def format_human_message(self, state: BugReproductionState): + if "reproduced_bug_failure_log" in state and state["reproduced_bug_failure_log"]: + return HumanMessage( + self.FOLLOWUP_HUMAN_PROMPT.format( + reproduced_bug_failure_log=state["reproduced_bug_failure_log"], + ) + ) + + return HumanMessage( + self.FIRST_HUMAN_PROMPT.format( + issue_info=format_issue_info( + state["issue_title"], state["issue_body"], state["issue_comments"] + ), + bug_reproducing_context="\n\n".join(state["bug_reproducing_context"]), + ) ) - ) - - return HumanMessage( - self.FIRST_HUMAN_PROMPT.format( - issue_info=format_issue_info( - state["issue_title"], state["issue_body"], state["issue_comments"] - ), - bug_reproducing_context="\n\n".join(state["bug_reproducing_context"]), - ) - ) - - def __call__(self, state: BugReproductionState): - human_message = self.format_human_message(state) - self._logger.debug(f"Sending message to BugReproducingWriteNode:\n{human_message}") - return {"bug_reproducing_write_messages": [human_message]} + + def __call__(self, state: BugReproductionState): + human_message = self.format_human_message(state) + self._logger.debug(f"Sending message to BugReproducingWriteNode:\n{human_message}") + return {"bug_reproducing_write_messages": [human_message]} diff --git a/prometheus/lang_graph/nodes/bug_reproducing_write_node.py b/prometheus/lang_graph/nodes/bug_reproducing_write_node.py index 153a0ab8..bf55bad3 100644 --- 
a/prometheus/lang_graph/nodes/bug_reproducing_write_node.py +++ b/prometheus/lang_graph/nodes/bug_reproducing_write_node.py @@ -12,7 +12,7 @@ class BugReproducingWriteNode: - SYS_PROMPT = '''\ + SYS_PROMPT = '''\ You are a QA automation expert who writes focused a single test case to reproduce software bugs. Given an issue description, create a single minimal test with MINIMAL number of assertions that demonstrates the problem. @@ -112,38 +112,38 @@ def test_empty_array_parsing(parser): ''' - def __init__(self, model: BaseChatModel, kg: KnowledgeGraph): - self.tools = self._init_tools(str(kg.get_local_path())) - self.system_prompt = SystemMessage(self.SYS_PROMPT) - self.model_with_tools = model.bind_tools(self.tools) - self._logger = logging.getLogger("prometheus.lang_graph.nodes.bug_reproducing_write_node") - - def _init_tools(self, root_path: str): - """Initializes file operation tools with the given root path. - - Args: - root_path: Base directory path for all file operations. - - Returns: - List of StructuredTool instances configured for file operations. 
- """ - tools = [] - - read_file_fn = functools.partial(file_operation.read_file, root_path=root_path) - read_file_tool = StructuredTool.from_function( - func=read_file_fn, - name=file_operation.read_file.__name__, - description=file_operation.READ_FILE_DESCRIPTION, - args_schema=file_operation.ReadFileInput, - ) - tools.append(read_file_tool) - - return tools - - def __call__(self, state: BugReproductionState): - message_history = [self.system_prompt] + state["bug_reproducing_write_messages"] - truncated_message_history = truncate_messages(message_history) - response = self.model_with_tools.invoke(truncated_message_history) - - self._logger.debug(response) - return {"bug_reproducing_write_messages": [response]} + def __init__(self, model: BaseChatModel, kg: KnowledgeGraph): + self.tools = self._init_tools(str(kg.get_local_path())) + self.system_prompt = SystemMessage(self.SYS_PROMPT) + self.model_with_tools = model.bind_tools(self.tools) + self._logger = logging.getLogger("prometheus.lang_graph.nodes.bug_reproducing_write_node") + + def _init_tools(self, root_path: str): + """Initializes file operation tools with the given root path. + + Args: + root_path: Base directory path for all file operations. + + Returns: + List of StructuredTool instances configured for file operations. 
+ """ + tools = [] + + read_file_fn = functools.partial(file_operation.read_file, root_path=root_path) + read_file_tool = StructuredTool.from_function( + func=read_file_fn, + name=file_operation.read_file.__name__, + description=file_operation.READ_FILE_DESCRIPTION, + args_schema=file_operation.ReadFileInput, + ) + tools.append(read_file_tool) + + return tools + + def __call__(self, state: BugReproductionState): + message_history = [self.system_prompt] + state["bug_reproducing_write_messages"] + truncated_message_history = truncate_messages(message_history) + response = self.model_with_tools.invoke(truncated_message_history) + + self._logger.debug(response) + return {"bug_reproducing_write_messages": [response]} diff --git a/prometheus/lang_graph/nodes/bug_reproduction_subgraph_node.py b/prometheus/lang_graph/nodes/bug_reproduction_subgraph_node.py index 872f907a..39b0d0a5 100644 --- a/prometheus/lang_graph/nodes/bug_reproduction_subgraph_node.py +++ b/prometheus/lang_graph/nodes/bug_reproduction_subgraph_node.py @@ -13,49 +13,51 @@ class BugReproductionSubgraphNode: - def __init__( - self, - advanced_model: BaseChatModel, - base_model: BaseChatModel, - container: BaseContainer, - kg: KnowledgeGraph, - git_repo: GitRepository, - neo4j_driver: neo4j.Driver, - max_token_per_neo4j_result: int, - test_commands: Optional[Sequence[str]], - ): - self._logger = logging.getLogger("prometheus.lang_graph.nodes.bug_reproduction_subgraph_node") - self.git_repo = git_repo - self.bug_reproduction_subgraph = BugReproductionSubgraph( - advanced_model=advanced_model, - base_model=base_model, - container=container, - kg=kg, - git_repo=git_repo, - neo4j_driver=neo4j_driver, - max_token_per_neo4j_result=max_token_per_neo4j_result, - test_commands=test_commands, - ) + def __init__( + self, + advanced_model: BaseChatModel, + base_model: BaseChatModel, + container: BaseContainer, + kg: KnowledgeGraph, + git_repo: GitRepository, + neo4j_driver: neo4j.Driver, + max_token_per_neo4j_result: 
int, + test_commands: Optional[Sequence[str]], + ): + self._logger = logging.getLogger( + "prometheus.lang_graph.nodes.bug_reproduction_subgraph_node" + ) + self.git_repo = git_repo + self.bug_reproduction_subgraph = BugReproductionSubgraph( + advanced_model=advanced_model, + base_model=base_model, + container=container, + kg=kg, + git_repo=git_repo, + neo4j_driver=neo4j_driver, + max_token_per_neo4j_result=max_token_per_neo4j_result, + test_commands=test_commands, + ) - def __call__(self, state: IssueBugState): - self._logger.info("Enter bug_reproduction_subgraph") + def __call__(self, state: IssueBugState): + self._logger.info("Enter bug_reproduction_subgraph") - try: - output_state = self.bug_reproduction_subgraph.invoke( - state["issue_title"], - state["issue_body"], - state["issue_comments"], - ) - except GraphRecursionError: - self._logger.info("Recursion limit reached, returning reproduced_bug=False") - self.git_repo.reset_repository() - return {"reproduced_bug": False} + try: + output_state = self.bug_reproduction_subgraph.invoke( + state["issue_title"], + state["issue_body"], + state["issue_comments"], + ) + except GraphRecursionError: + self._logger.info("Recursion limit reached, returning reproduced_bug=False") + self.git_repo.reset_repository() + return {"reproduced_bug": False} - self._logger.info(f"reproduced_bug: {output_state['reproduced_bug']}") - self._logger.info(f"reproduced_bug_file: {output_state['reproduced_bug_file']}") - self._logger.info(f"reproduced_bug_commands: {output_state['reproduced_bug_commands']}") - return { - "reproduced_bug": output_state["reproduced_bug"], - "reproduced_bug_file": output_state["reproduced_bug_file"], - "reproduced_bug_commands": output_state["reproduced_bug_commands"], - } + self._logger.info(f"reproduced_bug: {output_state['reproduced_bug']}") + self._logger.info(f"reproduced_bug_file: {output_state['reproduced_bug_file']}") + self._logger.info(f"reproduced_bug_commands: 
{output_state['reproduced_bug_commands']}") + return { + "reproduced_bug": output_state["reproduced_bug"], + "reproduced_bug_file": output_state["reproduced_bug_file"], + "reproduced_bug_commands": output_state["reproduced_bug_commands"], + } diff --git a/prometheus/lang_graph/nodes/build_and_test_subgraph_node.py b/prometheus/lang_graph/nodes/build_and_test_subgraph_node.py index fce816dc..7f84c90b 100644 --- a/prometheus/lang_graph/nodes/build_and_test_subgraph_node.py +++ b/prometheus/lang_graph/nodes/build_and_test_subgraph_node.py @@ -10,66 +10,66 @@ class BuildAndTestSubgraphNode: - def __init__( - self, - container: BaseContainer, - model: BaseChatModel, - kg: KnowledgeGraph, - build_commands: Optional[Sequence[str]] = None, - test_commands: Optional[Sequence[str]] = None, - ): - self.build_and_test_subgraph = BuildAndTestSubgraph( - container=container, - model=model, - kg=kg, - build_commands=build_commands, - test_commands=test_commands, - ) - self._logger = logging.getLogger("prometheus.lang_graph.nodes.build_and_test_subgraph_node") + def __init__( + self, + container: BaseContainer, + model: BaseChatModel, + kg: KnowledgeGraph, + build_commands: Optional[Sequence[str]] = None, + test_commands: Optional[Sequence[str]] = None, + ): + self.build_and_test_subgraph = BuildAndTestSubgraph( + container=container, + model=model, + kg=kg, + build_commands=build_commands, + test_commands=test_commands, + ) + self._logger = logging.getLogger("prometheus.lang_graph.nodes.build_and_test_subgraph_node") - def __call__(self, state: IssueBugState): - exist_build = None - build_command_summary = None - build_fail_log = None - exist_test = None - test_command_summary = None - existing_test_fail_log = None + def __call__(self, state: IssueBugState): + exist_build = None + build_command_summary = None + build_fail_log = None + exist_test = None + test_command_summary = None + existing_test_fail_log = None - if "build_command_summary" in state and 
state["build_command_summary"]: - exist_build = state["exist_build"] - build_command_summary = state["build_command_summary"] - build_fail_log = state["build_fail_log"] + if "build_command_summary" in state and state["build_command_summary"]: + exist_build = state["exist_build"] + build_command_summary = state["build_command_summary"] + build_fail_log = state["build_fail_log"] - if "test_command_summary" in state and state["test_command_summary"]: - exist_test = state["exist_test"] - test_command_summary = state["test_command_summary"] - existing_test_fail_log = state["existing_test_fail_log"] + if "test_command_summary" in state and state["test_command_summary"]: + exist_test = state["exist_test"] + test_command_summary = state["test_command_summary"] + existing_test_fail_log = state["existing_test_fail_log"] - self._logger.info("Enters BuildAndTestSubgraphNode") + self._logger.info("Enters BuildAndTestSubgraphNode") - output_state = self.build_and_test_subgraph.invoke( - run_build=state["run_build"], - run_existing_test=state["run_existing_test"], - exist_build=exist_build, - build_command_summary=build_command_summary, - build_fail_log=build_fail_log, - exist_test=exist_test, - test_command_summary=test_command_summary, - existing_test_fail_log=existing_test_fail_log, - ) + output_state = self.build_and_test_subgraph.invoke( + run_build=state["run_build"], + run_existing_test=state["run_existing_test"], + exist_build=exist_build, + build_command_summary=build_command_summary, + build_fail_log=build_fail_log, + exist_test=exist_test, + test_command_summary=test_command_summary, + existing_test_fail_log=existing_test_fail_log, + ) - self._logger.info(f"exist_build: {output_state['exist_build']}") - self._logger.info(f"build_command_summary:\n{output_state['build_command_summary']}") - self._logger.info(f"build_fail_log:\n{output_state['build_fail_log']}") - self._logger.info(f"exist_test: {output_state['exist_test']}") - 
self._logger.info(f"test_command_summary:\n{output_state['test_command_summary']}") - self._logger.info(f"existing_test_fail_log:\n{output_state['existing_test_fail_log']}") + self._logger.info(f"exist_build: {output_state['exist_build']}") + self._logger.info(f"build_command_summary:\n{output_state['build_command_summary']}") + self._logger.info(f"build_fail_log:\n{output_state['build_fail_log']}") + self._logger.info(f"exist_test: {output_state['exist_test']}") + self._logger.info(f"test_command_summary:\n{output_state['test_command_summary']}") + self._logger.info(f"existing_test_fail_log:\n{output_state['existing_test_fail_log']}") - return { - "exist_build": output_state["exist_build"], - "build_command_summary": output_state["build_command_summary"], - "build_fail_log": output_state["build_fail_log"], - "exist_test": output_state["exist_test"], - "test_command_summary": output_state["test_command_summary"], - "existing_test_fail_log": output_state["existing_test_fail_log"], - } + return { + "exist_build": output_state["exist_build"], + "build_command_summary": output_state["build_command_summary"], + "build_fail_log": output_state["build_fail_log"], + "exist_test": output_state["exist_test"], + "test_command_summary": output_state["test_command_summary"], + "existing_test_fail_log": output_state["existing_test_fail_log"], + } diff --git a/prometheus/lang_graph/nodes/context_provider_node.py b/prometheus/lang_graph/nodes/context_provider_node.py index 0d65d33e..d9637757 100644 --- a/prometheus/lang_graph/nodes/context_provider_node.py +++ b/prometheus/lang_graph/nodes/context_provider_node.py @@ -20,20 +20,20 @@ class ContextProviderNode: - """Provides contextual information from a codebase using knowledge graph search. + """Provides contextual information from a codebase using knowledge graph search. - This class implements a systematic approach to finding relevant code context - by searching through a Neo4j knowledge graph representation of a codebase. 
- It uses a combination of file structure navigation, AST analysis, and text - search to gather comprehensive context for queries. + This class implements a systematic approach to finding relevant code context + by searching through a Neo4j knowledge graph representation of a codebase. + It uses a combination of file structure navigation, AST analysis, and text + search to gather comprehensive context for queries. - The knowledge graph contains three main types of nodes: - - FileNode: Represents files and directories - - ASTNode: Represents syntactic elements from the code - - TextNode: Represents documentation and text content - """ + The knowledge graph contains three main types of nodes: + - FileNode: Represents files and directories + - ASTNode: Represents syntactic elements from the code + - TextNode: Represents documentation and text content + """ - SYS_PROMPT = """\ + SYS_PROMPT = """\ You are a context gatherer that searches a Neo4j knowledge graph representation of a codebase. Your role is to efficiently find relevant code and documentation context based on user queries. @@ -83,246 +83,246 @@ class ContextProviderNode: Available AST node types for code structure search: {ast_node_types} """ - def __init__( - self, - model: BaseChatModel, - kg: KnowledgeGraph, - neo4j_driver: neo4j.Driver, - max_token_per_result: int, - ): - """Initializes the ContextProviderNode with model, knowledge graph, and database connection. - - Sets up the context provider with necessary prompts, graph traversal tools, - and logging configuration. Initializes the system prompt with the current - file tree structure from the knowledge graph. - - Args: - model: Language model instance that will be used for query analysis and - context finding. Must be a BaseChatModel implementation that supports - tool binding. - kg: Knowledge graph instance containing the processed codebase structure. - Used to obtain the file tree for system prompts. 
- neo4j_driver: Neo4j driver instance for executing graph queries. This - driver should be properly configured with authentication and - connection details. - max_token_per_result: Maximum number of tokens per retrieved Neo4j result. - """ - self.neo4j_driver = neo4j_driver - self.max_token_per_result = max_token_per_result - - ast_node_types_str = ", ".join(kg.get_all_ast_node_types()) - self.system_prompt = SystemMessage( - self.SYS_PROMPT.format(file_tree=kg.get_file_tree(), ast_node_types=ast_node_types_str) - ) - self.tools = self._init_tools() - self.model_with_tools = model.bind_tools(self.tools) - - self._logger = logging.getLogger("prometheus.lang_graph.nodes.context_provider_node") - - def _init_tools(self): - """Initializes KnowledgeGraph traversal tools. - - - Returns: - List of StructuredTool instances configured for KnowledgeGraph traversal. - """ - tools = [] - - find_file_node_with_basename_fn = functools.partial( - graph_traversal.find_file_node_with_basename, - driver=self.neo4j_driver, - max_token_per_result=self.max_token_per_result, - ) - find_file_node_with_basename_tool = StructuredTool.from_function( - func=find_file_node_with_basename_fn, - name=graph_traversal.find_file_node_with_basename.__name__, - description=graph_traversal.FIND_FILE_NODE_WITH_BASENAME_DESCRIPTION, - args_schema=graph_traversal.FindFileNodeWithBasenameInput, - response_format="content_and_artifact", - ) - tools.append(find_file_node_with_basename_tool) - - find_file_node_with_relative_path_fn = functools.partial( - graph_traversal.find_file_node_with_relative_path, - driver=self.neo4j_driver, - max_token_per_result=self.max_token_per_result, - ) - find_file_node_with_relative_path_tool = StructuredTool.from_function( - func=find_file_node_with_relative_path_fn, - name=graph_traversal.find_file_node_with_relative_path.__name__, - description=graph_traversal.FIND_FILE_NODE_WITH_RELATIVE_PATH_DESCRIPTION, - args_schema=graph_traversal.FindFileNodeWithRelativePathInput, - 
response_format="content_and_artifact", - ) - tools.append(find_file_node_with_relative_path_tool) - - find_ast_node_with_text_in_file_with_basename_fn = functools.partial( - graph_traversal.find_ast_node_with_text_in_file_with_basename, - driver=self.neo4j_driver, - max_token_per_result=self.max_token_per_result, - ) - find_ast_node_with_text_in_file_with_basename_tool = StructuredTool.from_function( - func=find_ast_node_with_text_in_file_with_basename_fn, - name=graph_traversal.find_ast_node_with_text_in_file_with_basename.__name__, - description=graph_traversal.FIND_AST_NODE_WITH_TEXT_IN_FILE_WITH_BASENAME_DESCRIPTION, - args_schema=graph_traversal.FindASTNodeWithTextInFileWithBasenameInput, - response_format="content_and_artifact", - ) - tools.append(find_ast_node_with_text_in_file_with_basename_tool) - - find_ast_node_with_text_in_file_with_relative_path_fn = functools.partial( - graph_traversal.find_ast_node_with_text_in_file_with_relative_path, - driver=self.neo4j_driver, - max_token_per_result=self.max_token_per_result, - ) - find_ast_node_with_text_in_file_with_relative_path_tool = StructuredTool.from_function( - func=find_ast_node_with_text_in_file_with_relative_path_fn, - name=graph_traversal.find_ast_node_with_text_in_file_with_relative_path.__name__, - description=graph_traversal.FIND_AST_NODE_WITH_TEXT_IN_FILE_WITH_RELATIVE_PATH_DESCRIPTION, - args_schema=graph_traversal.FindASTNodeWithTextInFileWithRelativePathInput, - response_format="content_and_artifact", - ) - tools.append(find_ast_node_with_text_in_file_with_relative_path_tool) - - find_ast_node_with_type_in_file_with_basename_fn = functools.partial( - graph_traversal.find_ast_node_with_type_in_file_with_basename, - driver=self.neo4j_driver, - max_token_per_result=self.max_token_per_result, - ) - find_ast_node_with_type_in_file_with_basename_tool = StructuredTool.from_function( - func=find_ast_node_with_type_in_file_with_basename_fn, - 
name=graph_traversal.find_ast_node_with_type_in_file_with_basename.__name__, - description=graph_traversal.FIND_AST_NODE_WITH_TYPE_IN_FILE_WITH_BASENAME_DESCRIPTION, - args_schema=graph_traversal.FindASTNodeWithTypeInFileWithBasenameInput, - response_format="content_and_artifact", - ) - tools.append(find_ast_node_with_type_in_file_with_basename_tool) - - find_ast_node_with_type_in_file_with_relative_path_fn = functools.partial( - graph_traversal.find_ast_node_with_type_in_file_with_relative_path, - driver=self.neo4j_driver, - max_token_per_result=self.max_token_per_result, - ) - find_ast_node_with_type_in_file_with_relative_path_tool = StructuredTool.from_function( - func=find_ast_node_with_type_in_file_with_relative_path_fn, - name=graph_traversal.find_ast_node_with_type_in_file_with_relative_path.__name__, - description=graph_traversal.FIND_AST_NODE_WITH_TYPE_IN_FILE_WITH_RELATIVE_PATH_DESCRIPTION, - args_schema=graph_traversal.FindASTNodeWithTypeInFileWithRelativePathInput, - response_format="content_and_artifact", - ) - tools.append(find_ast_node_with_type_in_file_with_relative_path_tool) - - find_text_node_with_text_fn = functools.partial( - graph_traversal.find_text_node_with_text, - driver=self.neo4j_driver, - max_token_per_result=self.max_token_per_result, - ) - find_text_node_with_text_tool = StructuredTool.from_function( - func=find_text_node_with_text_fn, - name=graph_traversal.find_text_node_with_text.__name__, - description=graph_traversal.FIND_TEXT_NODE_WITH_TEXT_DESCRIPTION, - args_schema=graph_traversal.FindTextNodeWithTextInput, - response_format="content_and_artifact", - ) - tools.append(find_text_node_with_text_tool) - - find_text_node_with_text_in_file_fn = functools.partial( - graph_traversal.find_text_node_with_text_in_file, - driver=self.neo4j_driver, - max_token_per_result=self.max_token_per_result, - ) - find_text_node_with_text_in_file_tool = StructuredTool.from_function( - func=find_text_node_with_text_in_file_fn, - 
name=graph_traversal.find_text_node_with_text_in_file.__name__, - description=graph_traversal.FIND_TEXT_NODE_WITH_TEXT_IN_FILE_DESCRIPTION, - args_schema=graph_traversal.FindTextNodeWithTextInFileInput, - response_format="content_and_artifact", - ) - tools.append(find_text_node_with_text_in_file_tool) - - get_next_text_node_with_node_id_fn = functools.partial( - graph_traversal.get_next_text_node_with_node_id, - driver=self.neo4j_driver, - max_token_per_result=self.max_token_per_result, - ) - get_next_text_node_with_node_id_tool = StructuredTool.from_function( - func=get_next_text_node_with_node_id_fn, - name=graph_traversal.get_next_text_node_with_node_id.__name__, - description=graph_traversal.GET_NEXT_TEXT_NODE_WITH_NODE_ID_DESCRIPTION, - args_schema=graph_traversal.GetNextTextNodeWithNodeIdInput, - response_format="content_and_artifact", - ) - tools.append(get_next_text_node_with_node_id_tool) - - preview_file_content_with_basename_fn = functools.partial( - graph_traversal.preview_file_content_with_basename, - driver=self.neo4j_driver, - max_token_per_result=self.max_token_per_result, - ) - preview_file_content_with_basename_tool = StructuredTool.from_function( - func=preview_file_content_with_basename_fn, - name=graph_traversal.preview_file_content_with_basename.__name__, - description=graph_traversal.PREVIEW_FILE_CONTENT_WITH_BASENAME_DESCRIPTION, - args_schema=graph_traversal.PreviewFileContentWithBasenameInput, - response_format="content_and_artifact", - ) - tools.append(preview_file_content_with_basename_tool) - - preview_file_content_with_relative_path_fn = functools.partial( - graph_traversal.preview_file_content_with_relative_path, - driver=self.neo4j_driver, - max_token_per_result=self.max_token_per_result, - ) - preview_file_content_with_relative_path_tool = StructuredTool.from_function( - func=preview_file_content_with_relative_path_fn, - name=graph_traversal.preview_file_content_with_relative_path.__name__, - 
description=graph_traversal.PREVIEW_FILE_CONTENT_WITH_RELATIVE_PATH_DESCRIPTION, - args_schema=graph_traversal.PreviewFileContentWithRelativePathInput, - response_format="content_and_artifact", - ) - tools.append(preview_file_content_with_relative_path_tool) - - read_code_with_basename_fn = functools.partial( - graph_traversal.read_code_with_basename, - driver=self.neo4j_driver, - max_token_per_result=self.max_token_per_result, - ) - read_code_with_basename_tool = StructuredTool.from_function( - func=read_code_with_basename_fn, - name=graph_traversal.read_code_with_basename.__name__, - description=graph_traversal.READ_CODE_WITH_BASENAME_DESCRIPTION, - args_schema=graph_traversal.ReadCodeWithBasenameInput, - response_format="content_and_artifact", - ) - tools.append(read_code_with_basename_tool) - - read_code_with_relative_path_fn = functools.partial( - graph_traversal.read_code_with_relative_path, - driver=self.neo4j_driver, - max_token_per_result=self.max_token_per_result, - ) - read_code_with_relative_path_tool = StructuredTool.from_function( - func=read_code_with_relative_path_fn, - name=graph_traversal.read_code_with_relative_path.__name__, - description=graph_traversal.READ_CODE_WITH_RELATIVE_PATH_DESCRIPTION, - args_schema=graph_traversal.ReadCodeWithRelativePathInput, - response_format="content_and_artifact", - ) - tools.append(read_code_with_relative_path_tool) - - return tools - - def __call__(self, state: Dict): - """Processes the current state and traverse the knowledge graph to retrieve context. - - Args: - state: Current state containing the human query and preivous context_messages. - - Returns: - Dictionary that will update the state with the model's response messages. 
- """ - message_history = [self.system_prompt] + state["context_provider_messages"] - truncated_message_history = truncate_messages(message_history) - response = self.model_with_tools.invoke(truncated_message_history) - self._logger.debug(response) - return {"context_provider_messages": [response]} + def __init__( + self, + model: BaseChatModel, + kg: KnowledgeGraph, + neo4j_driver: neo4j.Driver, + max_token_per_result: int, + ): + """Initializes the ContextProviderNode with model, knowledge graph, and database connection. + + Sets up the context provider with necessary prompts, graph traversal tools, + and logging configuration. Initializes the system prompt with the current + file tree structure from the knowledge graph. + + Args: + model: Language model instance that will be used for query analysis and + context finding. Must be a BaseChatModel implementation that supports + tool binding. + kg: Knowledge graph instance containing the processed codebase structure. + Used to obtain the file tree for system prompts. + neo4j_driver: Neo4j driver instance for executing graph queries. This + driver should be properly configured with authentication and + connection details. + max_token_per_result: Maximum number of tokens per retrieved Neo4j result. + """ + self.neo4j_driver = neo4j_driver + self.max_token_per_result = max_token_per_result + + ast_node_types_str = ", ".join(kg.get_all_ast_node_types()) + self.system_prompt = SystemMessage( + self.SYS_PROMPT.format(file_tree=kg.get_file_tree(), ast_node_types=ast_node_types_str) + ) + self.tools = self._init_tools() + self.model_with_tools = model.bind_tools(self.tools) + + self._logger = logging.getLogger("prometheus.lang_graph.nodes.context_provider_node") + + def _init_tools(self): + """Initializes KnowledgeGraph traversal tools. + + + Returns: + List of StructuredTool instances configured for KnowledgeGraph traversal. 
+ """ + tools = [] + + find_file_node_with_basename_fn = functools.partial( + graph_traversal.find_file_node_with_basename, + driver=self.neo4j_driver, + max_token_per_result=self.max_token_per_result, + ) + find_file_node_with_basename_tool = StructuredTool.from_function( + func=find_file_node_with_basename_fn, + name=graph_traversal.find_file_node_with_basename.__name__, + description=graph_traversal.FIND_FILE_NODE_WITH_BASENAME_DESCRIPTION, + args_schema=graph_traversal.FindFileNodeWithBasenameInput, + response_format="content_and_artifact", + ) + tools.append(find_file_node_with_basename_tool) + + find_file_node_with_relative_path_fn = functools.partial( + graph_traversal.find_file_node_with_relative_path, + driver=self.neo4j_driver, + max_token_per_result=self.max_token_per_result, + ) + find_file_node_with_relative_path_tool = StructuredTool.from_function( + func=find_file_node_with_relative_path_fn, + name=graph_traversal.find_file_node_with_relative_path.__name__, + description=graph_traversal.FIND_FILE_NODE_WITH_RELATIVE_PATH_DESCRIPTION, + args_schema=graph_traversal.FindFileNodeWithRelativePathInput, + response_format="content_and_artifact", + ) + tools.append(find_file_node_with_relative_path_tool) + + find_ast_node_with_text_in_file_with_basename_fn = functools.partial( + graph_traversal.find_ast_node_with_text_in_file_with_basename, + driver=self.neo4j_driver, + max_token_per_result=self.max_token_per_result, + ) + find_ast_node_with_text_in_file_with_basename_tool = StructuredTool.from_function( + func=find_ast_node_with_text_in_file_with_basename_fn, + name=graph_traversal.find_ast_node_with_text_in_file_with_basename.__name__, + description=graph_traversal.FIND_AST_NODE_WITH_TEXT_IN_FILE_WITH_BASENAME_DESCRIPTION, + args_schema=graph_traversal.FindASTNodeWithTextInFileWithBasenameInput, + response_format="content_and_artifact", + ) + tools.append(find_ast_node_with_text_in_file_with_basename_tool) + + 
find_ast_node_with_text_in_file_with_relative_path_fn = functools.partial( + graph_traversal.find_ast_node_with_text_in_file_with_relative_path, + driver=self.neo4j_driver, + max_token_per_result=self.max_token_per_result, + ) + find_ast_node_with_text_in_file_with_relative_path_tool = StructuredTool.from_function( + func=find_ast_node_with_text_in_file_with_relative_path_fn, + name=graph_traversal.find_ast_node_with_text_in_file_with_relative_path.__name__, + description=graph_traversal.FIND_AST_NODE_WITH_TEXT_IN_FILE_WITH_RELATIVE_PATH_DESCRIPTION, + args_schema=graph_traversal.FindASTNodeWithTextInFileWithRelativePathInput, + response_format="content_and_artifact", + ) + tools.append(find_ast_node_with_text_in_file_with_relative_path_tool) + + find_ast_node_with_type_in_file_with_basename_fn = functools.partial( + graph_traversal.find_ast_node_with_type_in_file_with_basename, + driver=self.neo4j_driver, + max_token_per_result=self.max_token_per_result, + ) + find_ast_node_with_type_in_file_with_basename_tool = StructuredTool.from_function( + func=find_ast_node_with_type_in_file_with_basename_fn, + name=graph_traversal.find_ast_node_with_type_in_file_with_basename.__name__, + description=graph_traversal.FIND_AST_NODE_WITH_TYPE_IN_FILE_WITH_BASENAME_DESCRIPTION, + args_schema=graph_traversal.FindASTNodeWithTypeInFileWithBasenameInput, + response_format="content_and_artifact", + ) + tools.append(find_ast_node_with_type_in_file_with_basename_tool) + + find_ast_node_with_type_in_file_with_relative_path_fn = functools.partial( + graph_traversal.find_ast_node_with_type_in_file_with_relative_path, + driver=self.neo4j_driver, + max_token_per_result=self.max_token_per_result, + ) + find_ast_node_with_type_in_file_with_relative_path_tool = StructuredTool.from_function( + func=find_ast_node_with_type_in_file_with_relative_path_fn, + name=graph_traversal.find_ast_node_with_type_in_file_with_relative_path.__name__, + 
description=graph_traversal.FIND_AST_NODE_WITH_TYPE_IN_FILE_WITH_RELATIVE_PATH_DESCRIPTION, + args_schema=graph_traversal.FindASTNodeWithTypeInFileWithRelativePathInput, + response_format="content_and_artifact", + ) + tools.append(find_ast_node_with_type_in_file_with_relative_path_tool) + + find_text_node_with_text_fn = functools.partial( + graph_traversal.find_text_node_with_text, + driver=self.neo4j_driver, + max_token_per_result=self.max_token_per_result, + ) + find_text_node_with_text_tool = StructuredTool.from_function( + func=find_text_node_with_text_fn, + name=graph_traversal.find_text_node_with_text.__name__, + description=graph_traversal.FIND_TEXT_NODE_WITH_TEXT_DESCRIPTION, + args_schema=graph_traversal.FindTextNodeWithTextInput, + response_format="content_and_artifact", + ) + tools.append(find_text_node_with_text_tool) + + find_text_node_with_text_in_file_fn = functools.partial( + graph_traversal.find_text_node_with_text_in_file, + driver=self.neo4j_driver, + max_token_per_result=self.max_token_per_result, + ) + find_text_node_with_text_in_file_tool = StructuredTool.from_function( + func=find_text_node_with_text_in_file_fn, + name=graph_traversal.find_text_node_with_text_in_file.__name__, + description=graph_traversal.FIND_TEXT_NODE_WITH_TEXT_IN_FILE_DESCRIPTION, + args_schema=graph_traversal.FindTextNodeWithTextInFileInput, + response_format="content_and_artifact", + ) + tools.append(find_text_node_with_text_in_file_tool) + + get_next_text_node_with_node_id_fn = functools.partial( + graph_traversal.get_next_text_node_with_node_id, + driver=self.neo4j_driver, + max_token_per_result=self.max_token_per_result, + ) + get_next_text_node_with_node_id_tool = StructuredTool.from_function( + func=get_next_text_node_with_node_id_fn, + name=graph_traversal.get_next_text_node_with_node_id.__name__, + description=graph_traversal.GET_NEXT_TEXT_NODE_WITH_NODE_ID_DESCRIPTION, + args_schema=graph_traversal.GetNextTextNodeWithNodeIdInput, + 
response_format="content_and_artifact", + ) + tools.append(get_next_text_node_with_node_id_tool) + + preview_file_content_with_basename_fn = functools.partial( + graph_traversal.preview_file_content_with_basename, + driver=self.neo4j_driver, + max_token_per_result=self.max_token_per_result, + ) + preview_file_content_with_basename_tool = StructuredTool.from_function( + func=preview_file_content_with_basename_fn, + name=graph_traversal.preview_file_content_with_basename.__name__, + description=graph_traversal.PREVIEW_FILE_CONTENT_WITH_BASENAME_DESCRIPTION, + args_schema=graph_traversal.PreviewFileContentWithBasenameInput, + response_format="content_and_artifact", + ) + tools.append(preview_file_content_with_basename_tool) + + preview_file_content_with_relative_path_fn = functools.partial( + graph_traversal.preview_file_content_with_relative_path, + driver=self.neo4j_driver, + max_token_per_result=self.max_token_per_result, + ) + preview_file_content_with_relative_path_tool = StructuredTool.from_function( + func=preview_file_content_with_relative_path_fn, + name=graph_traversal.preview_file_content_with_relative_path.__name__, + description=graph_traversal.PREVIEW_FILE_CONTENT_WITH_RELATIVE_PATH_DESCRIPTION, + args_schema=graph_traversal.PreviewFileContentWithRelativePathInput, + response_format="content_and_artifact", + ) + tools.append(preview_file_content_with_relative_path_tool) + + read_code_with_basename_fn = functools.partial( + graph_traversal.read_code_with_basename, + driver=self.neo4j_driver, + max_token_per_result=self.max_token_per_result, + ) + read_code_with_basename_tool = StructuredTool.from_function( + func=read_code_with_basename_fn, + name=graph_traversal.read_code_with_basename.__name__, + description=graph_traversal.READ_CODE_WITH_BASENAME_DESCRIPTION, + args_schema=graph_traversal.ReadCodeWithBasenameInput, + response_format="content_and_artifact", + ) + tools.append(read_code_with_basename_tool) + + read_code_with_relative_path_fn = 
functools.partial( + graph_traversal.read_code_with_relative_path, + driver=self.neo4j_driver, + max_token_per_result=self.max_token_per_result, + ) + read_code_with_relative_path_tool = StructuredTool.from_function( + func=read_code_with_relative_path_fn, + name=graph_traversal.read_code_with_relative_path.__name__, + description=graph_traversal.READ_CODE_WITH_RELATIVE_PATH_DESCRIPTION, + args_schema=graph_traversal.ReadCodeWithRelativePathInput, + response_format="content_and_artifact", + ) + tools.append(read_code_with_relative_path_tool) + + return tools + + def __call__(self, state: Dict): + """Processes the current state and traverse the knowledge graph to retrieve context. + + Args: + state: Current state containing the human query and preivous context_messages. + + Returns: + Dictionary that will update the state with the model's response messages. + """ + message_history = [self.system_prompt] + state["context_provider_messages"] + truncated_message_history = truncate_messages(message_history) + response = self.model_with_tools.invoke(truncated_message_history) + self._logger.debug(response) + return {"context_provider_messages": [response]} diff --git a/prometheus/lang_graph/nodes/context_query_message_node.py b/prometheus/lang_graph/nodes/context_query_message_node.py index dc764223..d77eec34 100644 --- a/prometheus/lang_graph/nodes/context_query_message_node.py +++ b/prometheus/lang_graph/nodes/context_query_message_node.py @@ -6,10 +6,10 @@ class ContextQueryMessageNode: - def __init__(self): - self._logger = logging.getLogger("prometheus.lang_graph.nodes.context_query_message_node") + def __init__(self): + self._logger = logging.getLogger("prometheus.lang_graph.nodes.context_query_message_node") - def __call__(self, state: ContextRetrievalState): - human_message = HumanMessage(state["query"]) - self._logger.debug(f"Sending query to ContextProviderNode:\n{human_message}") - return {"context_provider_messages": [human_message]} + def __call__(self, 
state: ContextRetrievalState): + human_message = HumanMessage(state["query"]) + self._logger.debug(f"Sending query to ContextProviderNode:\n{human_message}") + return {"context_provider_messages": [human_message]} diff --git a/prometheus/lang_graph/nodes/context_refine_node.py b/prometheus/lang_graph/nodes/context_refine_node.py index 96ad8885..c21375eb 100644 --- a/prometheus/lang_graph/nodes/context_refine_node.py +++ b/prometheus/lang_graph/nodes/context_refine_node.py @@ -10,14 +10,14 @@ class ContextRefineStructuredOutput(BaseModel): - reasoning: str = Field(description="Your step by step reasoning.") - refined_query: str = Field( - "Additional query to ask the ContextRetriever if the context is not enough. Empty otherwise." - ) + reasoning: str = Field(description="Your step by step reasoning.") + refined_query: str = Field( + "Additional query to ask the ContextRetriever if the context is not enough. Empty otherwise." + ) class ContextRefineNode: - SYS_PROMPT = """\ + SYS_PROMPT = """\ You are an intelligent assistant specialized in analyzing code context to determine if additional source code or documentation from the codebase is necessary to fulfill the user's query. 
@@ -44,7 +44,7 @@ class ContextRefineStructuredOutput(BaseModel): {file_tree} """ - REFINE_PROMPT = """\ + REFINE_PROMPT = """\ This is the original user query: {original_query} @@ -68,41 +68,43 @@ class ContextRefineStructuredOutput(BaseModel): - Consider both code and documentation that might be relevant """ - def __init__(self, model: BaseChatModel, kg: KnowledgeGraph): - prompt = ChatPromptTemplate.from_messages( - [ - ("system", self.SYS_PROMPT.format(file_tree=kg.get_file_tree())), - ("human", "{human_prompt}"), - ] - ) - structured_llm = model.with_structured_output(ContextRefineStructuredOutput) - self.model = prompt | structured_llm - self._logger = logging.getLogger("prometheus.lang_graph.nodes.context_refine_node") - - def format_refine_message(self, state: ContextRetrievalState): - original_query = state["query"] - context = "\n\n".join(state["context"]) - return self.REFINE_PROMPT.format( - original_query=original_query, - context=context, - ) - - def __call__(self, state: ContextRetrievalState): - if "max_refined_query_loop" in state and state["max_refined_query_loop"] == 0: - self._logger.info("Reached max_refined_query_loop, not asking for more context") - return {"refined_query": ""} - - human_prompt = self.format_refine_message(state) - self._logger.debug(human_prompt) - response = self.model.invoke({"human_prompt": human_prompt}) - self._logger.debug(response) - - state_update = {"refined_query": response.refined_query} - - if "max_refined_query_loop" in state: - state_update["max_refined_query_loop"] = state["max_refined_query_loop"] - 1 - - if response.refined_query: - state_update["context_provider_messages"] = [HumanMessage(content=response.refined_query)] - - return state_update + def __init__(self, model: BaseChatModel, kg: KnowledgeGraph): + prompt = ChatPromptTemplate.from_messages( + [ + ("system", self.SYS_PROMPT.format(file_tree=kg.get_file_tree())), + ("human", "{human_prompt}"), + ] + ) + structured_llm = 
model.with_structured_output(ContextRefineStructuredOutput) + self.model = prompt | structured_llm + self._logger = logging.getLogger("prometheus.lang_graph.nodes.context_refine_node") + + def format_refine_message(self, state: ContextRetrievalState): + original_query = state["query"] + context = "\n\n".join(state["context"]) + return self.REFINE_PROMPT.format( + original_query=original_query, + context=context, + ) + + def __call__(self, state: ContextRetrievalState): + if "max_refined_query_loop" in state and state["max_refined_query_loop"] == 0: + self._logger.info("Reached max_refined_query_loop, not asking for more context") + return {"refined_query": ""} + + human_prompt = self.format_refine_message(state) + self._logger.debug(human_prompt) + response = self.model.invoke({"human_prompt": human_prompt}) + self._logger.debug(response) + + state_update = {"refined_query": response.refined_query} + + if "max_refined_query_loop" in state: + state_update["max_refined_query_loop"] = state["max_refined_query_loop"] - 1 + + if response.refined_query: + state_update["context_provider_messages"] = [ + HumanMessage(content=response.refined_query) + ] + + return state_update diff --git a/prometheus/lang_graph/nodes/context_retrieval_subgraph_node.py b/prometheus/lang_graph/nodes/context_retrieval_subgraph_node.py index 5a36c148..847ddd4d 100644 --- a/prometheus/lang_graph/nodes/context_retrieval_subgraph_node.py +++ b/prometheus/lang_graph/nodes/context_retrieval_subgraph_node.py @@ -9,29 +9,31 @@ class ContextRetrievalSubgraphNode: - def __init__( - self, - model: BaseChatModel, - kg: KnowledgeGraph, - neo4j_driver: neo4j.Driver, - max_token_per_neo4j_result: int, - query_key_name: str, - context_key_name: str, - ): - self._logger = logging.getLogger("prometheus.lang_graph.nodes.context_retrieval_subgraph_node") - self.context_retrevial_subgraph = ContextRetrievalSubgraph( - model=model, - kg=kg, - neo4j_driver=neo4j_driver, - 
max_token_per_neo4j_result=max_token_per_neo4j_result, - ) - self.query_key_name = query_key_name - self.context_key_name = context_key_name + def __init__( + self, + model: BaseChatModel, + kg: KnowledgeGraph, + neo4j_driver: neo4j.Driver, + max_token_per_neo4j_result: int, + query_key_name: str, + context_key_name: str, + ): + self._logger = logging.getLogger( + "prometheus.lang_graph.nodes.context_retrieval_subgraph_node" + ) + self.context_retrevial_subgraph = ContextRetrievalSubgraph( + model=model, + kg=kg, + neo4j_driver=neo4j_driver, + max_token_per_neo4j_result=max_token_per_neo4j_result, + ) + self.query_key_name = query_key_name + self.context_key_name = context_key_name - def __call__(self, state: Dict): - self._logger.info("Enter context retrieval subgraph") - output_state = self.context_retrevial_subgraph.invoke( - state[self.query_key_name], state["max_refined_query_loop"] - ) + def __call__(self, state: Dict): + self._logger.info("Enter context retrieval subgraph") + output_state = self.context_retrevial_subgraph.invoke( + state[self.query_key_name], state["max_refined_query_loop"] + ) - return {self.context_key_name: output_state["context"]} + return {self.context_key_name: output_state["context"]} diff --git a/prometheus/lang_graph/nodes/context_selection_node.py b/prometheus/lang_graph/nodes/context_selection_node.py index e2fc6141..815c8d9e 100644 --- a/prometheus/lang_graph/nodes/context_selection_node.py +++ b/prometheus/lang_graph/nodes/context_selection_node.py @@ -10,14 +10,14 @@ class ContextSelectionStructuredOutput(BaseModel): - reasoning: str = Field( - description="Your step-by-step reasoning why the context is relevant to the query" - ) - relevant: bool = Field(description="If the context is relevant to the query") + reasoning: str = Field( + description="Your step-by-step reasoning why the context is relevant to the query" + ) + relevant: bool = Field(description="If the context is relevant to the query") class ContextSelectionNode: 
- SYS_PROMPT = """\ + SYS_PROMPT = """\ You are a context selection agent that evaluates if a piece of code context is relevant to a given query. Your goal is to determine if the context directly answers the query requirements. Your evaluation must consider two key aspects: @@ -73,7 +73,7 @@ def validate_password(password: str, hash: str): Your task is to analyze the context and provide a similar structured output with detailed reasoning and a relevance decision. """.replace("{", "{{").replace("}", "}}") - HUMAN_PROMPT = """\ + HUMAN_PROMPT = """\ Query: {query} @@ -83,26 +83,28 @@ def validate_password(password: str, hash: str): Please classify if the found context is relevant to the query. """ - def __init__(self, model: BaseChatModel): - prompt = ChatPromptTemplate.from_messages( - [("system", self.SYS_PROMPT), ("human", "{human_prompt}")] - ) - structured_llm = model.with_structured_output(ContextSelectionStructuredOutput) - self.model = prompt | structured_llm - self._logger = logging.getLogger("prometheus.lang_graph.nodes.context_selection_node") - - def format_human_prompt(self, state: ContextRetrievalState, search_result: str) -> str: - context_info = self.HUMAN_PROMPT.format(query=state["query"], context=search_result) - return context_info - - def __call__(self, state: ContextRetrievalState): - context_list = state.get("context", []) - for tool_message in extract_last_tool_messages(state["context_provider_messages"]): - for search_result in neo4j_data_for_context_generator(tool_message.artifact): - human_prompt = self.format_human_prompt(state, search_result) - response = self.model.invoke({"human_prompt": human_prompt}) - self._logger.debug(f"Is this search result {search_result} relevant?: {response.relevant}") - if response.relevant: - context_list.append(search_result) - self._logger.info(f"Context selection complete, returning context {context_list}") - return {"context": context_list} + def __init__(self, model: BaseChatModel): + prompt = 
ChatPromptTemplate.from_messages( + [("system", self.SYS_PROMPT), ("human", "{human_prompt}")] + ) + structured_llm = model.with_structured_output(ContextSelectionStructuredOutput) + self.model = prompt | structured_llm + self._logger = logging.getLogger("prometheus.lang_graph.nodes.context_selection_node") + + def format_human_prompt(self, state: ContextRetrievalState, search_result: str) -> str: + context_info = self.HUMAN_PROMPT.format(query=state["query"], context=search_result) + return context_info + + def __call__(self, state: ContextRetrievalState): + context_list = state.get("context", []) + for tool_message in extract_last_tool_messages(state["context_provider_messages"]): + for search_result in neo4j_data_for_context_generator(tool_message.artifact): + human_prompt = self.format_human_prompt(state, search_result) + response = self.model.invoke({"human_prompt": human_prompt}) + self._logger.debug( + f"Is this search result {search_result} relevant?: {response.relevant}" + ) + if response.relevant: + context_list.append(search_result) + self._logger.info(f"Context selection complete, returning context {context_list}") + return {"context": context_list} diff --git a/prometheus/lang_graph/nodes/edit_message_node.py b/prometheus/lang_graph/nodes/edit_message_node.py index 2a1d42af..ec1f2479 100644 --- a/prometheus/lang_graph/nodes/edit_message_node.py +++ b/prometheus/lang_graph/nodes/edit_message_node.py @@ -8,7 +8,7 @@ class EditMessageNode: - FIRST_HUMAN_PROMPT = """\ + FIRST_HUMAN_PROMPT = """\ {issue_info} Bug Context: @@ -20,7 +20,7 @@ class EditMessageNode: Please implement these changes precisely, following the exact specifications from the analyzer. """ - FOLLOWUP_HUMAN_PROMPT = """\ + FOLLOWUP_HUMAN_PROMPT = """\ The edit that you generated following error: {edit_error} @@ -31,39 +31,39 @@ class EditMessageNode: specific issues that caused the previous error. 
""" - def __init__(self): - self._logger = logging.getLogger("prometheus.lang_graph.nodes.edit_message_node") + def __init__(self): + self._logger = logging.getLogger("prometheus.lang_graph.nodes.edit_message_node") - def format_human_message(self, state: Dict): - edit_error = "" - if "reproducing_test_fail_log" in state and state["reproducing_test_fail_log"]: - edit_error = ( - f"Your failed to pass the bug exposing test cases:\n{state['reproducing_test_fail_log']}" - ) - elif "build_fail_log" in state and state["build_fail_log"]: - edit_error = f"Your failed to pass the build:\n{state['build_fail_log']}" - elif "existing_test_fail_log" in state and state["existing_test_fail_log"]: - edit_error = f"Your failed to existing test cases:\n{state['existing_test_fail_log']}" + def format_human_message(self, state: Dict): + edit_error = "" + if "reproducing_test_fail_log" in state and state["reproducing_test_fail_log"]: + edit_error = f"Your failed to pass the bug exposing test cases:\n{state['reproducing_test_fail_log']}" + elif "build_fail_log" in state and state["build_fail_log"]: + edit_error = f"Your failed to pass the build:\n{state['build_fail_log']}" + elif "existing_test_fail_log" in state and state["existing_test_fail_log"]: + edit_error = f"Your failed to existing test cases:\n{state['existing_test_fail_log']}" - if edit_error: - return HumanMessage( - self.FOLLOWUP_HUMAN_PROMPT.format( - edit_error=edit_error, - bug_analyzer_message=get_last_message_content(state["issue_bug_analyzer_messages"]), - ) - ) + if edit_error: + return HumanMessage( + self.FOLLOWUP_HUMAN_PROMPT.format( + edit_error=edit_error, + bug_analyzer_message=get_last_message_content( + state["issue_bug_analyzer_messages"] + ), + ) + ) - return HumanMessage( - self.FIRST_HUMAN_PROMPT.format( - issue_info=format_issue_info( - state["issue_title"], state["issue_body"], state["issue_comments"] - ), - bug_fix_context="\n\n".join(state["bug_fix_context"]), - 
bug_analyzer_message=get_last_message_content(state["issue_bug_analyzer_messages"]), - ) - ) + return HumanMessage( + self.FIRST_HUMAN_PROMPT.format( + issue_info=format_issue_info( + state["issue_title"], state["issue_body"], state["issue_comments"] + ), + bug_fix_context="\n\n".join(state["bug_fix_context"]), + bug_analyzer_message=get_last_message_content(state["issue_bug_analyzer_messages"]), + ) + ) - def __call__(self, state: Dict): - human_message = self.format_human_message(state) - self._logger.debug(f"Sending message to EditNode:\n{human_message}") - return {"edit_messages": [human_message]} + def __call__(self, state: Dict): + human_message = self.format_human_message(state) + self._logger.debug(f"Sending message to EditNode:\n{human_message}") + return {"edit_messages": [human_message]} diff --git a/prometheus/lang_graph/nodes/edit_node.py b/prometheus/lang_graph/nodes/edit_node.py index d181afdf..6ba03215 100644 --- a/prometheus/lang_graph/nodes/edit_node.py +++ b/prometheus/lang_graph/nodes/edit_node.py @@ -20,7 +20,7 @@ class EditNode: - SYS_PROMPT = """\ + SYS_PROMPT = """\ You are a specialized editing agent responsible for implementing precise changes to files. Your primary focus is on accurately implementing the code changes that have been analyzed and proposed by the user. @@ -117,76 +117,76 @@ def other_method(): 6. Verify uniqueness of matches before changes """ - def __init__(self, model: BaseChatModel, kg: KnowledgeGraph): - self.system_prompt = SystemMessage(self.SYS_PROMPT) - self.tools = self._init_tools(kg.get_local_path()) - self.model_with_tools = model.bind_tools(self.tools) - self._logger = logging.getLogger("prometheus.lang_graph.nodes.edit_node") - - def _init_tools(self, root_path: str): - """Initializes file operation tools with the given root path. - - Args: - root_path: Base directory path for all file operations. - - Returns: - List of StructuredTool instances configured for file operations. 
- """ - tools = [] - - read_file_fn = functools.partial(file_operation.read_file, root_path=root_path) - read_file_tool = StructuredTool.from_function( - func=read_file_fn, - name=file_operation.read_file.__name__, - description=file_operation.READ_FILE_DESCRIPTION, - args_schema=file_operation.ReadFileInput, - ) - tools.append(read_file_tool) - - read_file_with_line_numbers_fn = functools.partial( - file_operation.read_file_with_line_numbers, root_path=root_path - ) - read_file_with_line_numbers_tool = StructuredTool.from_function( - func=read_file_with_line_numbers_fn, - name=file_operation.read_file_with_line_numbers.__name__, - description=file_operation.READ_FILE_WITH_LINE_NUMBERS_DESCRIPTION, - args_schema=file_operation.ReadFileWithLineNumbersInput, - ) - tools.append(read_file_with_line_numbers_tool) - - create_file_fn = functools.partial(file_operation.create_file, root_path=root_path) - create_file_tool = StructuredTool.from_function( - func=create_file_fn, - name=file_operation.create_file.__name__, - description=file_operation.CREATE_FILE_DESCRIPTION, - args_schema=file_operation.CreateFileInput, - ) - tools.append(create_file_tool) - - delete_fn = functools.partial(file_operation.delete, root_path=root_path) - delete_tool = StructuredTool.from_function( - func=delete_fn, - name=file_operation.delete.__name__, - description=file_operation.DELETE_DESCRIPTION, - args_schema=file_operation.DeleteInput, - ) - tools.append(delete_tool) - - edit_file_fn = functools.partial(file_operation.edit_file, root_path=root_path) - edit_file_tool = StructuredTool.from_function( - func=edit_file_fn, - name=file_operation.edit_file.__name__, - description=file_operation.EDIT_FILE_DESCRIPTION, - args_schema=file_operation.EditFileInput, - ) - tools.append(edit_file_tool) - - return tools - - def __call__(self, state: Dict): - message_history = [self.system_prompt] + state["edit_messages"] - truncated_message_history = truncate_messages(message_history) - response = 
self.model_with_tools.invoke(truncated_message_history) - - self._logger.debug(response) - return {"edit_messages": [response]} + def __init__(self, model: BaseChatModel, kg: KnowledgeGraph): + self.system_prompt = SystemMessage(self.SYS_PROMPT) + self.tools = self._init_tools(kg.get_local_path()) + self.model_with_tools = model.bind_tools(self.tools) + self._logger = logging.getLogger("prometheus.lang_graph.nodes.edit_node") + + def _init_tools(self, root_path: str): + """Initializes file operation tools with the given root path. + + Args: + root_path: Base directory path for all file operations. + + Returns: + List of StructuredTool instances configured for file operations. + """ + tools = [] + + read_file_fn = functools.partial(file_operation.read_file, root_path=root_path) + read_file_tool = StructuredTool.from_function( + func=read_file_fn, + name=file_operation.read_file.__name__, + description=file_operation.READ_FILE_DESCRIPTION, + args_schema=file_operation.ReadFileInput, + ) + tools.append(read_file_tool) + + read_file_with_line_numbers_fn = functools.partial( + file_operation.read_file_with_line_numbers, root_path=root_path + ) + read_file_with_line_numbers_tool = StructuredTool.from_function( + func=read_file_with_line_numbers_fn, + name=file_operation.read_file_with_line_numbers.__name__, + description=file_operation.READ_FILE_WITH_LINE_NUMBERS_DESCRIPTION, + args_schema=file_operation.ReadFileWithLineNumbersInput, + ) + tools.append(read_file_with_line_numbers_tool) + + create_file_fn = functools.partial(file_operation.create_file, root_path=root_path) + create_file_tool = StructuredTool.from_function( + func=create_file_fn, + name=file_operation.create_file.__name__, + description=file_operation.CREATE_FILE_DESCRIPTION, + args_schema=file_operation.CreateFileInput, + ) + tools.append(create_file_tool) + + delete_fn = functools.partial(file_operation.delete, root_path=root_path) + delete_tool = StructuredTool.from_function( + func=delete_fn, + 
name=file_operation.delete.__name__, + description=file_operation.DELETE_DESCRIPTION, + args_schema=file_operation.DeleteInput, + ) + tools.append(delete_tool) + + edit_file_fn = functools.partial(file_operation.edit_file, root_path=root_path) + edit_file_tool = StructuredTool.from_function( + func=edit_file_fn, + name=file_operation.edit_file.__name__, + description=file_operation.EDIT_FILE_DESCRIPTION, + args_schema=file_operation.EditFileInput, + ) + tools.append(edit_file_tool) + + return tools + + def __call__(self, state: Dict): + message_history = [self.system_prompt] + state["edit_messages"] + truncated_message_history = truncate_messages(message_history) + response = self.model_with_tools.invoke(truncated_message_history) + + self._logger.debug(response) + return {"edit_messages": [response]} diff --git a/prometheus/lang_graph/nodes/final_patch_selection_node.py b/prometheus/lang_graph/nodes/final_patch_selection_node.py index e9ee8d48..260ddb1d 100644 --- a/prometheus/lang_graph/nodes/final_patch_selection_node.py +++ b/prometheus/lang_graph/nodes/final_patch_selection_node.py @@ -9,14 +9,14 @@ class FinalPatchSelectionStructuredOutput(BaseModel): - reasoning: str = Field( - description="Your step-by-step reasoning why the selected patch is the best" - ) - patch_index: int = Field(description="The patch index that you select") + reasoning: str = Field( + description="Your step-by-step reasoning why the selected patch is the best" + ) + patch_index: int = Field(description="The patch index that you select") class FinalPatchSelectionNode: - SYS_PROMPT = """\ + SYS_PROMPT = """\ You are an expert programming assistant specialized in evaluating and selecting the best patch among multiple options. Your goal is to analyze each patch and select the most appropriate one based on the following prioritized criteria: 1. 
EFFECTIVENESS: The patch must correctly fix the reported issue @@ -109,7 +109,7 @@ class FinalPatchSelectionNode: - Pay attention to the git diff format, including file paths, chunk headers, and line numbers """.replace("{", "{{").replace("}", "}}") - HUMAN_PROMPT = """\ + HUMAN_PROMPT = """\ {issue_info} Bug Context: @@ -119,40 +119,40 @@ class FinalPatchSelectionNode: {patches} """ - def __init__(self, model: BaseChatModel, max_retries: int = 2): - self.max_retries = max_retries - prompt = ChatPromptTemplate.from_messages( - [("system", self.SYS_PROMPT), ("human", "{human_prompt}")] - ) - structured_llm = model.with_structured_output(FinalPatchSelectionStructuredOutput) - self.model = prompt | structured_llm - self._logger = logging.getLogger("prometheus.lang_graph.nodes.final_patch_selection_node") - - def format_human_message(self, state: Dict): - patches = "" - for index, patch in enumerate(state["edit_patches"]): - patches += f"Patch at index {index}:\n" - patches += f"{patch}\n\n" - patches += f"You must select a patch with index from 0 to {len(state['edit_patches']) - 1}, and provide your reasoning." 
- - return self.HUMAN_PROMPT.format( - issue_info=format_issue_info( - state["issue_title"], state["issue_body"], state["issue_comments"] - ), - bug_fix_context="\n\n".join(state["bug_fix_context"]), - patches=patches, - ) - - def __call__(self, state: Dict): - human_prompt = self.format_human_message(state) - for try_index in range(self.max_retries): - response = self.model.invoke({"human_prompt": human_prompt}) - self._logger.info(f"FinalPatchSelectionNode response at {try_index} try:\n{response}") - - if response.patch_index >= 0 and response.patch_index < len(state["edit_patches"]): - return {"final_patch": state["edit_patches"][response.patch_index]} - - self._logger.info( - "FinalPatchSelectionNode failed to select a patch with correct index, defaulting to 0" - ) - return {"final_patch": state["edit_patches"][0]} + def __init__(self, model: BaseChatModel, max_retries: int = 2): + self.max_retries = max_retries + prompt = ChatPromptTemplate.from_messages( + [("system", self.SYS_PROMPT), ("human", "{human_prompt}")] + ) + structured_llm = model.with_structured_output(FinalPatchSelectionStructuredOutput) + self.model = prompt | structured_llm + self._logger = logging.getLogger("prometheus.lang_graph.nodes.final_patch_selection_node") + + def format_human_message(self, state: Dict): + patches = "" + for index, patch in enumerate(state["edit_patches"]): + patches += f"Patch at index {index}:\n" + patches += f"{patch}\n\n" + patches += f"You must select a patch with index from 0 to {len(state['edit_patches']) - 1}, and provide your reasoning." 
+ + return self.HUMAN_PROMPT.format( + issue_info=format_issue_info( + state["issue_title"], state["issue_body"], state["issue_comments"] + ), + bug_fix_context="\n\n".join(state["bug_fix_context"]), + patches=patches, + ) + + def __call__(self, state: Dict): + human_prompt = self.format_human_message(state) + for try_index in range(self.max_retries): + response = self.model.invoke({"human_prompt": human_prompt}) + self._logger.info(f"FinalPatchSelectionNode response at {try_index} try:\n{response}") + + if response.patch_index >= 0 and response.patch_index < len(state["edit_patches"]): + return {"final_patch": state["edit_patches"][response.patch_index]} + + self._logger.info( + "FinalPatchSelectionNode failed to select a patch with correct index, defaulting to 0" + ) + return {"final_patch": state["edit_patches"][0]} diff --git a/prometheus/lang_graph/nodes/general_build_node.py b/prometheus/lang_graph/nodes/general_build_node.py index b6d31a39..1d3454f5 100644 --- a/prometheus/lang_graph/nodes/general_build_node.py +++ b/prometheus/lang_graph/nodes/general_build_node.py @@ -12,7 +12,7 @@ class GeneralBuildNode: - SYS_PROMPT = """\ + SYS_PROMPT = """\ You are a build system expert responsible for building software projects in an docker container. You are positioned at the root of the codebase where you can run commands. Your goal is to determine the correct build approach and execute the build. 
@@ -41,43 +41,45 @@ class GeneralBuildNode: - Simply execute the build and report that it was run """ - def __init__(self, model: BaseChatModel, container: BaseContainer, kg: KnowledgeGraph): - self.kg = kg - self.tools = self._init_tools(container) - self.model_with_tools = model.bind_tools(self.tools) - self.system_prompt = SystemMessage(self.SYS_PROMPT) - self._logger = logging.getLogger("prometheus.lang_graph.nodes.general_build_node") + def __init__(self, model: BaseChatModel, container: BaseContainer, kg: KnowledgeGraph): + self.kg = kg + self.tools = self._init_tools(container) + self.model_with_tools = model.bind_tools(self.tools) + self.system_prompt = SystemMessage(self.SYS_PROMPT) + self._logger = logging.getLogger("prometheus.lang_graph.nodes.general_build_node") - def _init_tools(self, container: BaseContainer): - tools = [] + def _init_tools(self, container: BaseContainer): + tools = [] - run_command_fn = functools.partial(container_command.run_command, container=container) - run_command_tool = StructuredTool.from_function( - func=run_command_fn, - name=container_command.run_command.__name__, - description=container_command.RUN_COMMAND_DESCRIPTION, - args_schema=container_command.RunCommandInput, - ) - tools.append(run_command_tool) + run_command_fn = functools.partial(container_command.run_command, container=container) + run_command_tool = StructuredTool.from_function( + func=run_command_fn, + name=container_command.run_command.__name__, + description=container_command.RUN_COMMAND_DESCRIPTION, + args_schema=container_command.RunCommandInput, + ) + tools.append(run_command_tool) - return tools + return tools - def format_human_message(self, state: BuildAndTestState) -> HumanMessage: - message = f"The (incomplete) project structure is:\n{self.kg.get_file_tree()}" - if "build_command_summary" in state and state["build_command_summary"]: - message += f"\n\nThe previous build summary is:\n{state['build_command_summary']}" - return HumanMessage(message) + 
def format_human_message(self, state: BuildAndTestState) -> HumanMessage: + message = f"The (incomplete) project structure is:\n{self.kg.get_file_tree()}" + if "build_command_summary" in state and state["build_command_summary"]: + message += f"\n\nThe previous build summary is:\n{state['build_command_summary']}" + return HumanMessage(message) - def __call__(self, state: BuildAndTestState): - if "exist_build" in state and not state["exist_build"]: - self._logger.info("exist_build is false, skipping build.") - return { - "build_messages": [AIMessage(content="Previous agent determined there is no build system.")] - } + def __call__(self, state: BuildAndTestState): + if "exist_build" in state and not state["exist_build"]: + self._logger.info("exist_build is false, skipping build.") + return { + "build_messages": [ + AIMessage(content="Previous agent determined there is no build system.") + ] + } - message_history = [self.system_prompt, self.format_human_message(state)] + state[ - "build_messages" - ] - response = self.model_with_tools.invoke(message_history) - self._logger.debug(response) - return {"build_messages": [response]} + message_history = [self.system_prompt, self.format_human_message(state)] + state[ + "build_messages" + ] + response = self.model_with_tools.invoke(message_history) + self._logger.debug(response) + return {"build_messages": [response]} diff --git a/prometheus/lang_graph/nodes/general_build_structured_node.py b/prometheus/lang_graph/nodes/general_build_structured_node.py index 5b49c7bd..af9058bc 100644 --- a/prometheus/lang_graph/nodes/general_build_structured_node.py +++ b/prometheus/lang_graph/nodes/general_build_structured_node.py @@ -17,39 +17,39 @@ class BuildStructuredOutput(BaseModel): - """Structured output model for build analysis results. + """Structured output model for build analysis results. - Attributes: - exist_build: Boolean indicating presence of a build system. 
- command_summary: Detailed description of build system and required commands. - fail_log: Error logs from failed builds, empty string if successful. - """ + Attributes: + exist_build: Boolean indicating presence of a build system. + command_summary: Detailed description of build system and required commands. + fail_log: Error logs from failed builds, empty string if successful. + """ - exist_build: bool = Field( - description="Indicates if there is any build system present in the project" - ) - command_summary: str = Field( - description="Summary of the build system and list of commands required to build the project" - ) - fail_log: str = Field( - description="Contains the error logs if build failed, empty string if successful" - ) + exist_build: bool = Field( + description="Indicates if there is any build system present in the project" + ) + command_summary: str = Field( + description="Summary of the build system and list of commands required to build the project" + ) + fail_log: str = Field( + description="Contains the error logs if build failed, empty string if successful" + ) class GeneralBuildStructuredNode: - """Analyzes and summarizes build execution attempts for software projects. + """Analyzes and summarizes build execution attempts for software projects. - This class processes build execution histories to provide structured analysis - of build systems, required build steps, and any failures encountered. It can - identify and analyze various build systems including CMake, npm, Maven, and others. + This class processes build execution histories to provide structured analysis + of build systems, required build steps, and any failures encountered. It can + identify and analyze various build systems including CMake, npm, Maven, and others. - The analysis covers three main areas: - 1. Build system detection and classification - 2. Required build steps and commands - 3. Build failure analysis and error logging - """ + The analysis covers three main areas: + 1. 
Build system detection and classification + 2. Required build steps and commands + 3. Build failure analysis and error logging + """ - SYS_PROMPT = """\ + SYS_PROMPT = """\ You are a build system expert analyzing build attempts for software projects. You'll review a history of commands executed by an agent that attempted to build the project. Examine this build history to: @@ -221,45 +221,47 @@ class GeneralBuildStructuredNode: } """.replace("{", "{{").replace("}", "}}") - def __init__(self, model: BaseChatModel): - """Initializes the GeneralBuildSummarizationNode with an analysis model. - - Sets up the build summarizer with a prompt template and structured output - model for analyzing build execution histories. - - Args: - model: Language model instance that will be used for analyzing build - histories and generating structured summaries. Must be a - BaseChatModel implementation. - """ - prompt = ChatPromptTemplate.from_messages( - [("system", self.SYS_PROMPT), ("human", "{build_history}")] - ) - structured_llm = model.with_structured_output(BuildStructuredOutput) - self.model = prompt | structured_llm - self._logger = logging.getLogger("prometheus.lang_graph.nodes.general_build_structured_node") - - def __call__(self, state: BuildAndTestState): - """Processes build state to generate structured build analysis. - - Analyzes the build execution history to determine build system presence, - required commands, and any failures encountered. - - Args: - state: Current state containing build execution messages and history. 
- - Returns: - Dictionary that updates the statecontaining: - - exist_build: Boolean indicating if a build system exists - - build_command_summary: String describing build system and required commands - - build_fail_log: String containing error logs (empty if successful) - """ - build_history = format_agent_tool_message_history(state["build_messages"]) - response = self.model.invoke({"build_history": build_history}) - self._logger.debug(response) - - return { - "exist_build": response.exist_build, - "build_command_summary": response.command_summary, - "build_fail_log": response.fail_log, - } + def __init__(self, model: BaseChatModel): + """Initializes the GeneralBuildSummarizationNode with an analysis model. + + Sets up the build summarizer with a prompt template and structured output + model for analyzing build execution histories. + + Args: + model: Language model instance that will be used for analyzing build + histories and generating structured summaries. Must be a + BaseChatModel implementation. + """ + prompt = ChatPromptTemplate.from_messages( + [("system", self.SYS_PROMPT), ("human", "{build_history}")] + ) + structured_llm = model.with_structured_output(BuildStructuredOutput) + self.model = prompt | structured_llm + self._logger = logging.getLogger( + "prometheus.lang_graph.nodes.general_build_structured_node" + ) + + def __call__(self, state: BuildAndTestState): + """Processes build state to generate structured build analysis. + + Analyzes the build execution history to determine build system presence, + required commands, and any failures encountered. + + Args: + state: Current state containing build execution messages and history. 
+ + Returns: + Dictionary that updates the state containing: + - exist_build: Boolean indicating if a build system exists + - build_command_summary: String describing build system and required commands + - build_fail_log: String containing error logs (empty if successful) + """ + build_history = format_agent_tool_message_history(state["build_messages"]) + response = self.model.invoke({"build_history": build_history}) + self._logger.debug(response) + + return { + "exist_build": response.exist_build, + "build_command_summary": response.command_summary, + "build_fail_log": response.fail_log, + } diff --git a/prometheus/lang_graph/nodes/general_test_node.py b/prometheus/lang_graph/nodes/general_test_node.py index 9fcabfc5..d35c1ee8 100644 --- a/prometheus/lang_graph/nodes/general_test_node.py +++ b/prometheus/lang_graph/nodes/general_test_node.py @@ -12,15 +12,15 @@ class GeneralTestNode: - """Executes tests for software projects in containerized environments. + """Executes tests for software projects in containerized environments. - This class provides automated test execution capabilities by analyzing project - structures, detecting testing frameworks, and running appropriate test commands - in an Ubuntu container. It focuses solely on test execution without attempting - to analyze or fix any test failures. - """ + This class provides automated test execution capabilities by analyzing project + structures, detecting testing frameworks, and running appropriate test commands + in an Ubuntu container. It focuses solely on test execution without attempting + to analyze or fix any test failures. + """ - SYS_PROMPT = """\ + SYS_PROMPT = """\ You are a testing expert responsible for running tests for software projects in an Ubuntu container. You are positioned at the root of the codebase where you can run commands. Your goal is to determine the correct testing framework and execute the tests.
@@ -58,45 +58,45 @@ class GeneralTestNode: - Simply execute the tests and report that they were run """ - def __init__(self, model: BaseChatModel, container: BaseContainer, kg: KnowledgeGraph): - self.kg = kg - self.tools = self._init_tools(container) - self.model_with_tools = model.bind_tools(self.tools) - self.system_prompt = SystemMessage(self.SYS_PROMPT) - self._logger = logging.getLogger("prometheus.lang_graph.nodes.general_test_node") - - def _init_tools(self, container: BaseContainer): - tools = [] - - run_command_fn = functools.partial(container_command.run_command, container=container) - run_command_tool = StructuredTool.from_function( - func=run_command_fn, - name=container_command.run_command.__name__, - description=container_command.RUN_COMMAND_DESCRIPTION, - args_schema=container_command.RunCommandInput, - ) - tools.append(run_command_tool) - - return tools - - def format_human_message(self, state: BuildAndTestState) -> HumanMessage: - message = f"The (incomplete) project structure is:\n{self.kg.get_file_tree()}" - if "test_command_summary" in state and state["test_command_summary"]: - message += f"\n\nThe previous test summary is:\n{state['test_command_summary']}" - return HumanMessage(message) - - def __call__(self, state: BuildAndTestState): - if "exist_test" in state and not state["exist_test"]: - self._logger.info("exist_test is false, skipping test.") - return { - "build_messages": [ - AIMessage(content="Previous agent determined there is no test framework.") + def __init__(self, model: BaseChatModel, container: BaseContainer, kg: KnowledgeGraph): + self.kg = kg + self.tools = self._init_tools(container) + self.model_with_tools = model.bind_tools(self.tools) + self.system_prompt = SystemMessage(self.SYS_PROMPT) + self._logger = logging.getLogger("prometheus.lang_graph.nodes.general_test_node") + + def _init_tools(self, container: BaseContainer): + tools = [] + + run_command_fn = functools.partial(container_command.run_command, 
container=container) + run_command_tool = StructuredTool.from_function( + func=run_command_fn, + name=container_command.run_command.__name__, + description=container_command.RUN_COMMAND_DESCRIPTION, + args_schema=container_command.RunCommandInput, + ) + tools.append(run_command_tool) + + return tools + + def format_human_message(self, state: BuildAndTestState) -> HumanMessage: + message = f"The (incomplete) project structure is:\n{self.kg.get_file_tree()}" + if "test_command_summary" in state and state["test_command_summary"]: + message += f"\n\nThe previous test summary is:\n{state['test_command_summary']}" + return HumanMessage(message) + + def __call__(self, state: BuildAndTestState): + if "exist_test" in state and not state["exist_test"]: + self._logger.info("exist_test is false, skipping test.") + return { + "build_messages": [ + AIMessage(content="Previous agent determined there is no test framework.") + ] + } + + message_history = [self.system_prompt, self.format_human_message(state)] + state[ + "test_messages" ] - } - - message_history = [self.system_prompt, self.format_human_message(state)] + state[ - "test_messages" - ] - response = self.model_with_tools.invoke(message_history) - self._logger.debug(response) - return {"test_messages": [response]} + response = self.model_with_tools.invoke(message_history) + self._logger.debug(response) + return {"test_messages": [response]} diff --git a/prometheus/lang_graph/nodes/general_test_structured_node.py b/prometheus/lang_graph/nodes/general_test_structured_node.py index 42416419..78758070 100644 --- a/prometheus/lang_graph/nodes/general_test_structured_node.py +++ b/prometheus/lang_graph/nodes/general_test_structured_node.py @@ -17,34 +17,34 @@ class TestStructuredOutput(BaseModel): - """Structured output model for test analysis results. + """Structured output model for test analysis results. - Attributes: - exist_test: Boolean indicating presence of a test framework. 
- command_summary: Detailed description of test framework and required commands. - fail_log: Detailed test failure logs and messages, empty string if all tests passed. - """ + Attributes: + exist_test: Boolean indicating presence of a test framework. + command_summary: Detailed description of test framework and required commands. + fail_log: Detailed test failure logs and messages, empty string if all tests passed. + """ - exist_test: bool = Field( - description="Indicates if there is any test framework present in the project" - ) - command_summary: str = Field( - description="Summary of the test framework and list of commands required to run tests" - ) - fail_log: str = Field( - description="Contains the test failure logs if any tests failed, empty string if all passed" - ) + exist_test: bool = Field( + description="Indicates if there is any test framework present in the project" + ) + command_summary: str = Field( + description="Summary of the test framework and list of commands required to run tests" + ) + fail_log: str = Field( + description="Contains the test failure logs if any tests failed, empty string if all passed" + ) class GeneralTestStructuredNode: - """Analyzes and summarizes test execution attempts for software projects. + """Analyzes and summarizes test execution attempts for software projects. - This class processes test execution histories to provide structured analysis - of test frameworks, required test steps, and any failures encountered. It can - identify and analyze various testing frameworks including pytest, Jest, and others. - """ + This class processes test execution histories to provide structured analysis + of test frameworks, required test steps, and any failures encountered. It can + identify and analyze various testing frameworks including pytest, Jest, and others. + """ - SYS_PROMPT = """\ + SYS_PROMPT = """\ You are a testing expert analyzing test execution history for software projects. 
You'll review a history of commands executed by an agent that attempted to run the tests. Examine this test history to: @@ -270,44 +270,44 @@ def test_user_permissions(): } """.replace("{", "{{").replace("}", "}}") - def __init__(self, model: BaseChatModel): - """Initializes the GeneralTestSummarizationNode with an analysis model. - - Sets up the test summarizer with a prompt template and structured output - model for analyzing test execution histories. - - Args: - model: Language model instance that will be used for analyzing test - histories and generating structured summaries. Must be a - BaseChatModel implementation. - """ - prompt = ChatPromptTemplate.from_messages( - [("system", self.SYS_PROMPT), ("human", "{test_history}")] - ) - structured_llm = model.with_structured_output(TestStructuredOutput) - self.model = prompt | structured_llm - self._logger = logging.getLogger("prometheus.lang_graph.nodes.general_test_structured_node") - - def __call__(self, state: BuildAndTestState): - """Processes test state to generate structured test analysis. - - Analyzes the test execution history to determine test framework presence, - required commands, and any failures encountered. - - Args: - state: Current state containing test execution messages and history. 
- - Returns: - Dictionary that updates the state containing: - - exist_test: Boolean indicating if a test framework exists - - test_command_summary: String describing test framework and required commands - - existing_test_fail_log: String containing test failure details (empty if all passed) - """ - test_history = format_agent_tool_message_history(state["test_messages"]) - response = self.model.invoke({"test_history": test_history}) - self._logger.debug(response) - return { - "exist_test": response.exist_test, - "test_command_summary": response.command_summary, - "existing_test_fail_log": response.fail_log, - } + def __init__(self, model: BaseChatModel): + """Initializes the GeneralTestSummarizationNode with an analysis model. + + Sets up the test summarizer with a prompt template and structured output + model for analyzing test execution histories. + + Args: + model: Language model instance that will be used for analyzing test + histories and generating structured summaries. Must be a + BaseChatModel implementation. + """ + prompt = ChatPromptTemplate.from_messages( + [("system", self.SYS_PROMPT), ("human", "{test_history}")] + ) + structured_llm = model.with_structured_output(TestStructuredOutput) + self.model = prompt | structured_llm + self._logger = logging.getLogger("prometheus.lang_graph.nodes.general_test_structured_node") + + def __call__(self, state: BuildAndTestState): + """Processes test state to generate structured test analysis. + + Analyzes the test execution history to determine test framework presence, + required commands, and any failures encountered. + + Args: + state: Current state containing test execution messages and history. 
+ + Returns: + Dictionary that updates the state containing: + - exist_test: Boolean indicating if a test framework exists + - test_command_summary: String describing test framework and required commands + - existing_test_fail_log: String containing test failure details (empty if all passed) + """ + test_history = format_agent_tool_message_history(state["test_messages"]) + response = self.model.invoke({"test_history": test_history}) + self._logger.debug(response) + return { + "exist_test": response.exist_test, + "test_command_summary": response.command_summary, + "existing_test_fail_log": response.fail_log, + } diff --git a/prometheus/lang_graph/nodes/git_diff_node.py b/prometheus/lang_graph/nodes/git_diff_node.py index 8f913b38..4a8be50c 100644 --- a/prometheus/lang_graph/nodes/git_diff_node.py +++ b/prometheus/lang_graph/nodes/git_diff_node.py @@ -13,56 +13,56 @@ class GitDiffNode: - """Generates Git diffs for project modifications. + """Generates Git diffs for project modifications. - This class handles the generation of Git diffs to track changes made to a project. - It works with a GitRepository instance to access the project's Git operations - and create patch output. The node is typically used as part of an automated - workflow to capture code modifications made during issue resolution. - """ + This class handles the generation of Git diffs to track changes made to a project. + It works with a GitRepository instance to access the project's Git operations + and create patch output. The node is typically used as part of an automated + workflow to capture code modifications made during issue resolution. 
+ """ - def __init__( - self, - git_repo: GitRepository, - state_patch_name: str, - state_excluded_files_key: Optional[str] = None, - return_list: bool = False, - ): - self.git_repo = git_repo - self.state_patch_name = state_patch_name - self.state_excluded_files_key = state_excluded_files_key - self.return_list = return_list - self._logger = logging.getLogger("prometheus.lang_graph.nodes.git_diff_node") + def __init__( + self, + git_repo: GitRepository, + state_patch_name: str, + state_excluded_files_key: Optional[str] = None, + return_list: bool = False, + ): + self.git_repo = git_repo + self.state_patch_name = state_patch_name + self.state_excluded_files_key = state_excluded_files_key + self.return_list = return_list + self._logger = logging.getLogger("prometheus.lang_graph.nodes.git_diff_node") - def __call__(self, state: Dict): - """Generates a Git diff for the current project state. + def __call__(self, state: Dict): + """Generates a Git diff for the current project state. - Creates a Git repository instance for the project path specified in the - state and generates a diff of any uncommitted changes. + Creates a Git repository instance for the project path specified in the + state and generates a diff of any uncommitted changes. - Args: - state: Current state containing project information, including the - project_path key specifying the Git repository location. + Args: + state: Current state containing project information, including the + project_path key specifying the Git repository location. - Returns: - Dictionary that update the state containing: - - patch: String containing the Git diff output showing all changes made to the project. 
- """ - excluded_files = None - if ( - self.state_excluded_files_key - and self.state_excluded_files_key in state - and state[self.state_excluded_files_key] - ): - excluded_files = state[self.state_excluded_files_key] - if isinstance(excluded_files, str): - excluded_files = [excluded_files] - self._logger.debug( - f"Excluding the following files when generating the patch: {excluded_files}" - ) - patch = self.git_repo.get_diff(excluded_files) - self._logger.info(f"Generated patch:\n{patch}") + Returns: + Dictionary that update the state containing: + - patch: String containing the Git diff output showing all changes made to the project. + """ + excluded_files = None + if ( + self.state_excluded_files_key + and self.state_excluded_files_key in state + and state[self.state_excluded_files_key] + ): + excluded_files = state[self.state_excluded_files_key] + if isinstance(excluded_files, str): + excluded_files = [excluded_files] + self._logger.debug( + f"Excluding the following files when generating the patch: {excluded_files}" + ) + patch = self.git_repo.get_diff(excluded_files) + self._logger.info(f"Generated patch:\n{patch}") - if self.return_list: - patch = [patch] - return {self.state_patch_name: patch} + if self.return_list: + patch = [patch] + return {self.state_patch_name: patch} diff --git a/prometheus/lang_graph/nodes/git_reset_node.py b/prometheus/lang_graph/nodes/git_reset_node.py index c4c31251..b9d63b30 100644 --- a/prometheus/lang_graph/nodes/git_reset_node.py +++ b/prometheus/lang_graph/nodes/git_reset_node.py @@ -4,13 +4,13 @@ class GitResetNode: - def __init__( - self, - git_repo: GitRepository, - ): - self.git_repo = git_repo - self._logger = logging.getLogger("prometheus.lang_graph.nodes.git_reset_node") + def __init__( + self, + git_repo: GitRepository, + ): + self.git_repo = git_repo + self._logger = logging.getLogger("prometheus.lang_graph.nodes.git_reset_node") - def __call__(self, _): - self._logger.debug("Resetting the git repository") - 
self.git_repo.reset_repository() + def __call__(self, _): + self._logger.debug("Resetting the git repository") + self.git_repo.reset_repository() diff --git a/prometheus/lang_graph/nodes/issue_bug_analyzer_message_node.py b/prometheus/lang_graph/nodes/issue_bug_analyzer_message_node.py index a1ae5f19..c5f63d28 100644 --- a/prometheus/lang_graph/nodes/issue_bug_analyzer_message_node.py +++ b/prometheus/lang_graph/nodes/issue_bug_analyzer_message_node.py @@ -7,7 +7,7 @@ class IssueBugAnalyzerMessageNode: - FIRST_HUMAN_PROMPT = """\ + FIRST_HUMAN_PROMPT = """\ I am going to share details about an issue reported to a codebase and its related bug context. Please analyze this bug and provide a high-level description of what needs to be changed: @@ -40,7 +40,7 @@ class IssueBugAnalyzerMessageNode: {bug_fix_context} """ - FOLLOWUP_HUMAN_PROMPT = """\ + FOLLOWUP_HUMAN_PROMPT = """\ Given your suggestion, the edit agent generated the following patch: {edit_patch} @@ -63,36 +63,40 @@ class IssueBugAnalyzerMessageNode: Do NOT provide actual code snippets or diffs. Focus on describing what needs to be changed. 
""" - def __init__(self): - self._logger = logging.getLogger("prometheus.lang_graph.nodes.issue_bug_analyzer_message_node") - - def format_human_message(self, state: Dict): - edit_error = "" - if "reproducing_test_fail_log" in state and state["reproducing_test_fail_log"]: - edit_error = f"The patch failed to pass the bug exposing test cases:\n{state['reproducing_test_fail_log']}" - elif "build_fail_log" in state and state["build_fail_log"]: - edit_error = f"The patch failed to pass the build:\n{state['build_fail_log']}" - elif "existing_test_fail_log" in state and state["existing_test_fail_log"]: - edit_error = f"The patch failed to existing test cases:\n{state['existing_test_fail_log']}" - - if not edit_error: - return HumanMessage( - self.FIRST_HUMAN_PROMPT.format( - issue_info=format_issue_info( - state["issue_title"], state["issue_body"], state["issue_comments"] - ), - bug_fix_context="\n\n".join(state["bug_fix_context"]), + def __init__(self): + self._logger = logging.getLogger( + "prometheus.lang_graph.nodes.issue_bug_analyzer_message_node" ) - ) - - return HumanMessage( - self.FOLLOWUP_HUMAN_PROMPT.format( - edit_patch=state["edit_patch"], - edit_error=edit_error, - ) - ) - - def __call__(self, state: Dict): - human_message = self.format_human_message(state) - self._logger.debug(f"Sending message to IssueBugAnalyzerNode:\n{human_message}") - return {"issue_bug_analyzer_messages": [human_message]} + + def format_human_message(self, state: Dict): + edit_error = "" + if "reproducing_test_fail_log" in state and state["reproducing_test_fail_log"]: + edit_error = f"The patch failed to pass the bug exposing test cases:\n{state['reproducing_test_fail_log']}" + elif "build_fail_log" in state and state["build_fail_log"]: + edit_error = f"The patch failed to pass the build:\n{state['build_fail_log']}" + elif "existing_test_fail_log" in state and state["existing_test_fail_log"]: + edit_error = ( + f"The patch failed to existing test 
cases:\n{state['existing_test_fail_log']}" + ) + + if not edit_error: + return HumanMessage( + self.FIRST_HUMAN_PROMPT.format( + issue_info=format_issue_info( + state["issue_title"], state["issue_body"], state["issue_comments"] + ), + bug_fix_context="\n\n".join(state["bug_fix_context"]), + ) + ) + + return HumanMessage( + self.FOLLOWUP_HUMAN_PROMPT.format( + edit_patch=state["edit_patch"], + edit_error=edit_error, + ) + ) + + def __call__(self, state: Dict): + human_message = self.format_human_message(state) + self._logger.debug(f"Sending message to IssueBugAnalyzerNode:\n{human_message}") + return {"issue_bug_analyzer_messages": [human_message]} diff --git a/prometheus/lang_graph/nodes/issue_bug_analyzer_node.py b/prometheus/lang_graph/nodes/issue_bug_analyzer_node.py index 0e38104d..4fd09e4f 100644 --- a/prometheus/lang_graph/nodes/issue_bug_analyzer_node.py +++ b/prometheus/lang_graph/nodes/issue_bug_analyzer_node.py @@ -8,7 +8,7 @@ class IssueBugAnalyzerNode: - SYS_PROMPT = """\ + SYS_PROMPT = """\ You are an expert software engineer specializing in bug analysis and fixes. Your role is to: 1. Carefully analyze reported software issues and bugs by: @@ -42,15 +42,15 @@ class IssueBugAnalyzerNode: rather than implementation details. 
""" - def __init__(self, model: BaseChatModel): - self.system_prompt = SystemMessage(self.SYS_PROMPT) - self.model = model - self._logger = logging.getLogger("prometheus.lang_graph.nodes.issue_bug_analyzer_node") + def __init__(self, model: BaseChatModel): + self.system_prompt = SystemMessage(self.SYS_PROMPT) + self.model = model + self._logger = logging.getLogger("prometheus.lang_graph.nodes.issue_bug_analyzer_node") - def __call__(self, state: Dict): - message_history = [self.system_prompt] + state["issue_bug_analyzer_messages"] - truncated_message_history = truncate_messages(message_history) - response = self.model.invoke(truncated_message_history) + def __call__(self, state: Dict): + message_history = [self.system_prompt] + state["issue_bug_analyzer_messages"] + truncated_message_history = truncate_messages(message_history) + response = self.model.invoke(truncated_message_history) - self._logger.debug(response) - return {"issue_bug_analyzer_messages": [response]} + self._logger.debug(response) + return {"issue_bug_analyzer_messages": [response]} diff --git a/prometheus/lang_graph/nodes/issue_bug_context_message_node.py b/prometheus/lang_graph/nodes/issue_bug_context_message_node.py index bdd2d57b..238d84fb 100644 --- a/prometheus/lang_graph/nodes/issue_bug_context_message_node.py +++ b/prometheus/lang_graph/nodes/issue_bug_context_message_node.py @@ -5,8 +5,8 @@ class IssueBugContextMessageNode: - BUG_FIX_QUERY = ( - """\ + BUG_FIX_QUERY = ( + """\ {issue_info} Find all relevant source code context and documentation needed to understand and fix this issue. 
@@ -18,18 +18,20 @@ class IssueBugContextMessageNode: Skip any test files """.replace("{", "{{") - .replace("}", "}}") - .replace("{{issue_info}}", "{issue_info}") - ) + .replace("}", "}}") + .replace("{{issue_info}}", "{issue_info}") + ) - def __init__(self): - self._logger = logging.getLogger("prometheus.lang_graph.nodes.issue_bug_context_message_node") + def __init__(self): + self._logger = logging.getLogger( + "prometheus.lang_graph.nodes.issue_bug_context_message_node" + ) - def __call__(self, state: Dict): - bug_fix_query = self.BUG_FIX_QUERY.format( - issue_info=format_issue_info( - state["issue_title"], state["issue_body"], state["issue_comments"] - ), - ) - self._logger.debug(f"Sending query to context provider:\n{bug_fix_query}") - return {"bug_fix_query": bug_fix_query} + def __call__(self, state: Dict): + bug_fix_query = self.BUG_FIX_QUERY.format( + issue_info=format_issue_info( + state["issue_title"], state["issue_body"], state["issue_comments"] + ), + ) + self._logger.debug(f"Sending query to context provider:\n{bug_fix_query}") + return {"bug_fix_query": bug_fix_query} diff --git a/prometheus/lang_graph/nodes/issue_bug_reproduction_context_message_node.py b/prometheus/lang_graph/nodes/issue_bug_reproduction_context_message_node.py index df6e5f5c..e951e2d4 100644 --- a/prometheus/lang_graph/nodes/issue_bug_reproduction_context_message_node.py +++ b/prometheus/lang_graph/nodes/issue_bug_reproduction_context_message_node.py @@ -5,7 +5,7 @@ class IssueBugReproductionContextMessageNode: - BUG_REPRODUCING_QUERY = """\ + BUG_REPRODUCING_QUERY = """\ {issue_info} OBJECTIVE: Find three relevant existing test cases that demonstrates similar functionality to the reported bug, @@ -107,16 +107,16 @@ def test_file_permission_denied(self, mock_open, mock_access): Find the THREE most relevant test cases with complete context, ensuring ALL necessary imports are included at the start of each test file. 
""" - def __init__(self): - self._logger = logging.getLogger( - "prometheus.lang_graph.nodes.issue_bug_reproduction_context_message_node" - ) - - def __call__(self, state: BugReproductionState): - bug_reproducing_query = self.BUG_REPRODUCING_QUERY.format( - issue_info=format_issue_info( - state["issue_title"], state["issue_body"], state["issue_comments"] - ), - ) - self._logger.debug(f"Sending query to context provider subgraph:\n{bug_reproducing_query}") - return {"bug_reproducing_query": bug_reproducing_query} + def __init__(self): + self._logger = logging.getLogger( + "prometheus.lang_graph.nodes.issue_bug_reproduction_context_message_node" + ) + + def __call__(self, state: BugReproductionState): + bug_reproducing_query = self.BUG_REPRODUCING_QUERY.format( + issue_info=format_issue_info( + state["issue_title"], state["issue_body"], state["issue_comments"] + ), + ) + self._logger.debug(f"Sending query to context provider subgraph:\n{bug_reproducing_query}") + return {"bug_reproducing_query": bug_reproducing_query} diff --git a/prometheus/lang_graph/nodes/issue_bug_responder_node.py b/prometheus/lang_graph/nodes/issue_bug_responder_node.py index a05f8b6a..1be5f22e 100644 --- a/prometheus/lang_graph/nodes/issue_bug_responder_node.py +++ b/prometheus/lang_graph/nodes/issue_bug_responder_node.py @@ -8,7 +8,7 @@ class IssueBugResponderNode: - SYS_PROMPT = """\ + SYS_PROMPT = """\ You are the final agent in a multi-agent bug fixing system. Users report issues on GitHub/GitLab, and our system works to fix them. Your role is to compose the response that will be posted back to the issue thread. @@ -34,7 +34,7 @@ class IssueBugResponderNode: Format your response as a properly structured comment. 
""" - HUMAN_PROMPT = """\ + HUMAN_PROMPT = """\ {issue_info} Generated patch: @@ -44,42 +44,42 @@ class IssueBugResponderNode: {verification} """ - def __init__(self, model: BaseChatModel): - self.system_prompt = SystemMessage(self.SYS_PROMPT) - self.model = model + def __init__(self, model: BaseChatModel): + self.system_prompt = SystemMessage(self.SYS_PROMPT) + self.model = model - self._logger = logging.getLogger("prometheus.lang_graph.nodes.issue_bug_responder_node") + self._logger = logging.getLogger("prometheus.lang_graph.nodes.issue_bug_responder_node") - def format_human_message(self, state: Dict) -> HumanMessage: - verification_messages = [] + def format_human_message(self, state: Dict) -> HumanMessage: + verification_messages = [] - # We only report successful verifications that were performed - if state["passed_reproducing_test"]: - verification_messages.append("โœ“ The bug reproducing test passed") + # We only report successful verifications that were performed + if state["passed_reproducing_test"]: + verification_messages.append("โœ“ The bug reproducing test passed") - if state["passed_build"]: - verification_messages.append("โœ“ Build passes successfully") + if state["passed_build"]: + verification_messages.append("โœ“ Build passes successfully") - if state["passed_existing_test"]: - verification_messages.append("โœ“ All existing tests pass successfully") + if state["passed_existing_test"]: + verification_messages.append("โœ“ All existing tests pass successfully") - verification_summary = "\n".join(verification_messages) + verification_summary = "\n".join(verification_messages) - formatted_message = self.HUMAN_PROMPT.format( - issue_info=format_issue_info( - state["issue_title"], state["issue_body"], state["issue_comments"] - ), - edit_patch=state["edit_patch"], - verification=verification_summary, - ) + formatted_message = self.HUMAN_PROMPT.format( + issue_info=format_issue_info( + state["issue_title"], state["issue_body"], state["issue_comments"] + 
), + edit_patch=state["edit_patch"], + verification=verification_summary, + ) - return HumanMessage(content=formatted_message) + return HumanMessage(content=formatted_message) - def __call__(self, state: Dict): - messages = [ - self.system_prompt, - self.format_human_message(state), - ] - response = self.model.invoke(messages) - self._logger.debug(response) - return {"issue_response": response.content} + def __call__(self, state: Dict): + messages = [ + self.system_prompt, + self.format_human_message(state), + ] + response = self.model.invoke(messages) + self._logger.debug(response) + return {"issue_response": response.content} diff --git a/prometheus/lang_graph/nodes/issue_bug_subgraph_node.py b/prometheus/lang_graph/nodes/issue_bug_subgraph_node.py index d5236be5..49ee14cf 100644 --- a/prometheus/lang_graph/nodes/issue_bug_subgraph_node.py +++ b/prometheus/lang_graph/nodes/issue_bug_subgraph_node.py @@ -13,68 +13,68 @@ class IssueBugSubgraphNode: - def __init__( - self, - advanced_model: BaseChatModel, - base_model: BaseChatModel, - container: BaseContainer, - kg: KnowledgeGraph, - git_repo: GitRepository, - neo4j_driver: neo4j.Driver, - max_token_per_neo4j_result: int, - build_commands: Optional[Sequence[str]] = None, - test_commands: Optional[Sequence[str]] = None, - ): - self._logger = logging.getLogger("prometheus.lang_graph.nodes.issue_bug_subgraph_node") - self.container = container - self.issue_bug_subgraph = IssueBugSubgraph( - advanced_model=advanced_model, - base_model=base_model, - container=container, - kg=kg, - git_repo=git_repo, - neo4j_driver=neo4j_driver, - max_token_per_neo4j_result=max_token_per_neo4j_result, - build_commands=build_commands, - test_commands=test_commands, - ) + def __init__( + self, + advanced_model: BaseChatModel, + base_model: BaseChatModel, + container: BaseContainer, + kg: KnowledgeGraph, + git_repo: GitRepository, + neo4j_driver: neo4j.Driver, + max_token_per_neo4j_result: int, + build_commands: Optional[Sequence[str]] = 
None, + test_commands: Optional[Sequence[str]] = None, + ): + self._logger = logging.getLogger("prometheus.lang_graph.nodes.issue_bug_subgraph_node") + self.container = container + self.issue_bug_subgraph = IssueBugSubgraph( + advanced_model=advanced_model, + base_model=base_model, + container=container, + kg=kg, + git_repo=git_repo, + neo4j_driver=neo4j_driver, + max_token_per_neo4j_result=max_token_per_neo4j_result, + build_commands=build_commands, + test_commands=test_commands, + ) - def __call__(self, state: IssueState): - self.container.build_docker_image() - self.container.start_container() + def __call__(self, state: IssueState): + self.container.build_docker_image() + self.container.start_container() - self._logger.info("Enter IssueBugSubgraphNode") + self._logger.info("Enter IssueBugSubgraphNode") - try: - output_state = self.issue_bug_subgraph.invoke( - issue_title=state["issue_title"], - issue_body=state["issue_body"], - issue_comments=state["issue_comments"], - run_build=state["run_build"], - run_existing_test=state["run_existing_test"], - number_of_candidate_patch=state["number_of_candidate_patch"], - ) + try: + output_state = self.issue_bug_subgraph.invoke( + issue_title=state["issue_title"], + issue_body=state["issue_body"], + issue_comments=state["issue_comments"], + run_build=state["run_build"], + run_existing_test=state["run_existing_test"], + number_of_candidate_patch=state["number_of_candidate_patch"], + ) - self._logger.info(f"Generated patch:\n{output_state['edit_patch']}") - self._logger.info(f"passed_reproducing_test: {output_state['passed_reproducing_test']}") - self._logger.info(f"passed_build: {output_state['passed_build']}") - self._logger.info(f"passed_existing_test: {output_state['passed_existing_test']}") - self._logger.info(f"issue_response:\n{output_state['issue_response']}") - return { - "edit_patch": output_state["edit_patch"], - "passed_reproducing_test": output_state["passed_reproducing_test"], - "passed_build": 
output_state["passed_build"], - "passed_existing_test": output_state["passed_existing_test"], - "issue_response": output_state["issue_response"], - } - except GraphRecursionError: - self._logger.critical("Please increase the recursion limit of IssueBugSubgraph") - return { - "edit_patch": "", - "passed_reproducing_test": False, - "passed_build": False, - "passed_existing_test": False, - "issue_response": "", - } - finally: - self.container.cleanup() + self._logger.info(f"Generated patch:\n{output_state['edit_patch']}") + self._logger.info(f"passed_reproducing_test: {output_state['passed_reproducing_test']}") + self._logger.info(f"passed_build: {output_state['passed_build']}") + self._logger.info(f"passed_existing_test: {output_state['passed_existing_test']}") + self._logger.info(f"issue_response:\n{output_state['issue_response']}") + return { + "edit_patch": output_state["edit_patch"], + "passed_reproducing_test": output_state["passed_reproducing_test"], + "passed_build": output_state["passed_build"], + "passed_existing_test": output_state["passed_existing_test"], + "issue_response": output_state["issue_response"], + } + except GraphRecursionError: + self._logger.critical("Please increase the recursion limit of IssueBugSubgraph") + return { + "edit_patch": "", + "passed_reproducing_test": False, + "passed_build": False, + "passed_existing_test": False, + "issue_response": "", + } + finally: + self.container.cleanup() diff --git a/prometheus/lang_graph/nodes/issue_classification_context_message_node.py b/prometheus/lang_graph/nodes/issue_classification_context_message_node.py index b7e56399..a90e484a 100644 --- a/prometheus/lang_graph/nodes/issue_classification_context_message_node.py +++ b/prometheus/lang_graph/nodes/issue_classification_context_message_node.py @@ -5,7 +5,7 @@ class IssueClassificationContextMessageNode: - ISSUE_CLASSIFICATION_QUERY = """\ + ISSUE_CLASSIFICATION_QUERY = """\ {issue_info} OBJECTIVE: Find ALL self-contained context needed to 
accurately classify this issue as a bug, feature request, documentation update, or question. @@ -71,16 +71,16 @@ class IssueClassificationContextMessageNode: 6. Integration patterns """ - def __init__(self): - self._logger = logging.getLogger( - "prometheus.lang_graph.nodes.issue_classification_context_message_node" - ) + def __init__(self): + self._logger = logging.getLogger( + "prometheus.lang_graph.nodes.issue_classification_context_message_node" + ) - def __call__(self, state: IssueClassificationState): - issue_classification_query = self.ISSUE_CLASSIFICATION_QUERY.format( - issue_info=format_issue_info( - state["issue_title"], state["issue_body"], state["issue_comments"] - ), - ) - self._logger.debug(f"Sending query to context provider:\n{issue_classification_query}") - return {"issue_classification_query": issue_classification_query} + def __call__(self, state: IssueClassificationState): + issue_classification_query = self.ISSUE_CLASSIFICATION_QUERY.format( + issue_info=format_issue_info( + state["issue_title"], state["issue_body"], state["issue_comments"] + ), + ) + self._logger.debug(f"Sending query to context provider:\n{issue_classification_query}") + return {"issue_classification_query": issue_classification_query} diff --git a/prometheus/lang_graph/nodes/issue_classification_subgraph_node.py b/prometheus/lang_graph/nodes/issue_classification_subgraph_node.py index 53b323de..07229cd3 100644 --- a/prometheus/lang_graph/nodes/issue_classification_subgraph_node.py +++ b/prometheus/lang_graph/nodes/issue_classification_subgraph_node.py @@ -6,32 +6,32 @@ from prometheus.graph.knowledge_graph import KnowledgeGraph from prometheus.lang_graph.graphs.issue_state import IssueState from prometheus.lang_graph.subgraphs.issue_classification_subgraph import ( - IssueClassificationSubgraph, + IssueClassificationSubgraph, ) class IssueClassificationSubgraphNode: - def __init__( - self, - model: BaseChatModel, - kg: KnowledgeGraph, - neo4j_driver: neo4j.Driver, - 
max_token_per_neo4j_result: int, - ): - self._logger = logging.getLogger( - "prometheus.lang_graph.nodes.issue_classification_subgraph_node" - ) - self.issue_classification_subgraph = IssueClassificationSubgraph( - model=model, - kg=kg, - neo4j_driver=neo4j_driver, - max_token_per_neo4j_result=max_token_per_neo4j_result, - ) + def __init__( + self, + model: BaseChatModel, + kg: KnowledgeGraph, + neo4j_driver: neo4j.Driver, + max_token_per_neo4j_result: int, + ): + self._logger = logging.getLogger( + "prometheus.lang_graph.nodes.issue_classification_subgraph_node" + ) + self.issue_classification_subgraph = IssueClassificationSubgraph( + model=model, + kg=kg, + neo4j_driver=neo4j_driver, + max_token_per_neo4j_result=max_token_per_neo4j_result, + ) - def __call__(self, state: IssueState): - self._logger.info("Enter IssueClassificationSubgraphNode") - issue_type = self.issue_classification_subgraph.invoke( - state["issue_title"], state["issue_body"], state["issue_comments"] - ) - self._logger.info(f"issue_type: {issue_type}") - return {"issue_type": issue_type} + def __call__(self, state: IssueState): + self._logger.info("Enter IssueClassificationSubgraphNode") + issue_type = self.issue_classification_subgraph.invoke( + state["issue_title"], state["issue_body"], state["issue_comments"] + ) + self._logger.info(f"issue_type: {issue_type}") + return {"issue_type": issue_type} diff --git a/prometheus/lang_graph/nodes/issue_classifier_node.py b/prometheus/lang_graph/nodes/issue_classifier_node.py index 53cb2ad1..9ffd2a5b 100644 --- a/prometheus/lang_graph/nodes/issue_classifier_node.py +++ b/prometheus/lang_graph/nodes/issue_classifier_node.py @@ -10,13 +10,13 @@ class IssueClassifierOutput(BaseModel): - issue_type: IssueType = Field( - description='Classified issue type, can only be one of "bug", "feature", "documentation", or "question"', - ) + issue_type: IssueType = Field( + description='Classified issue type, can only be one of "bug", "feature", "documentation", or 
"question"', + ) class IssueClassifierNode: - SYS_PROMPT = """\ + SYS_PROMPT = """\ You are an expert GitHub issue classifier. Your task is to analyze GitHub issues and classify them into exactly one of these categories: - "bug": Issues reporting software defects, unexpected behavior, or crashes - "feature": Feature requests, enhancements, or new functionality proposals @@ -111,32 +111,32 @@ class IssueClassifierNode: Analyze the provided issue and respond with a JSON object containing only the issue_type field with one of the four allowed values: "bug", "feature", "documentation", or "question". """.replace("{", "{{").replace("}", "}}") - ISSUE_CLASSIFICATION_CONTEXT = """\ + ISSUE_CLASSIFICATION_CONTEXT = """\ {issue_info} Issue classification context: {issue_classification_context} """ - def __init__(self, model: BaseChatModel): - prompt = ChatPromptTemplate.from_messages( - [("system", self.SYS_PROMPT), ("human", "{context_info}")] - ) - structured_llm = model.with_structured_output(IssueClassifierOutput) - self.model = prompt | structured_llm - self._logger = logging.getLogger("prometheus.lang_graph.nodes.issue_classifier_node") - - def format_context_info(self, state: IssueClassificationState) -> str: - context_info = self.ISSUE_CLASSIFICATION_CONTEXT.format( - issue_info=format_issue_info( - state["issue_title"], state["issue_body"], state["issue_comments"] - ), - issue_classification_context="\n\n".join(state["issue_classification_context"]), - ) - return context_info - - def __call__(self, state: IssueClassificationState): - context_info = self.format_context_info(state) - response = self.model.invoke({"context_info": context_info}) - self._logger.info(f"IssueClassifierNode classified the issue as: {response.issue_type}") - return {"issue_type": response.issue_type} + def __init__(self, model: BaseChatModel): + prompt = ChatPromptTemplate.from_messages( + [("system", self.SYS_PROMPT), ("human", "{context_info}")] + ) + structured_llm = 
model.with_structured_output(IssueClassifierOutput) + self.model = prompt | structured_llm + self._logger = logging.getLogger("prometheus.lang_graph.nodes.issue_classifier_node") + + def format_context_info(self, state: IssueClassificationState) -> str: + context_info = self.ISSUE_CLASSIFICATION_CONTEXT.format( + issue_info=format_issue_info( + state["issue_title"], state["issue_body"], state["issue_comments"] + ), + issue_classification_context="\n\n".join(state["issue_classification_context"]), + ) + return context_info + + def __call__(self, state: IssueClassificationState): + context_info = self.format_context_info(state) + response = self.model.invoke({"context_info": context_info}) + self._logger.info(f"IssueClassifierNode classified the issue as: {response.issue_type}") + return {"issue_type": response.issue_type} diff --git a/prometheus/lang_graph/nodes/issue_not_verified_bug_subgraph_node.py b/prometheus/lang_graph/nodes/issue_not_verified_bug_subgraph_node.py index 562a171a..c0ce236c 100644 --- a/prometheus/lang_graph/nodes/issue_not_verified_bug_subgraph_node.py +++ b/prometheus/lang_graph/nodes/issue_not_verified_bug_subgraph_node.py @@ -7,47 +7,47 @@ from prometheus.git.git_repository import GitRepository from prometheus.graph.knowledge_graph import KnowledgeGraph from prometheus.lang_graph.subgraphs.issue_not_verified_bug_subgraph import ( - IssueNotVerifiedBugSubgraph, + IssueNotVerifiedBugSubgraph, ) class IssueNotVerifiedBugSubgraphNode: - def __init__( - self, - advanced_model: BaseChatModel, - base_model: BaseChatModel, - kg: KnowledgeGraph, - git_repo: GitRepository, - neo4j_driver: neo4j.Driver, - max_token_per_neo4j_result: int, - ): - self._logger = logging.getLogger( - "prometheus.lang_graph.nodes.issue_not_verified_bug_subgraph_node" - ) - self.issue_not_verified_bug_subgraph = IssueNotVerifiedBugSubgraph( - advanced_model=advanced_model, - base_model=base_model, - kg=kg, - git_repo=git_repo, - neo4j_driver=neo4j_driver, - 
max_token_per_neo4j_result=max_token_per_neo4j_result, - ) - - def __call__(self, state: Dict): - self._logger.info("Enter IssueNotVerifiedBugSubgraphNode") - - output_state = self.issue_not_verified_bug_subgraph.invoke( - issue_title=state["issue_title"], - issue_body=state["issue_body"], - issue_comments=state["issue_comments"], - number_of_candidate_patch=state["number_of_candidate_patch"], - ) - - self._logger.info(f"final_patch:\n{output_state['final_patch']}") - - return { - "edit_patch": output_state["final_patch"], - "passed_reproducing_test": False, - "passed_build": False, - "passed_existing_test": False, - } + def __init__( + self, + advanced_model: BaseChatModel, + base_model: BaseChatModel, + kg: KnowledgeGraph, + git_repo: GitRepository, + neo4j_driver: neo4j.Driver, + max_token_per_neo4j_result: int, + ): + self._logger = logging.getLogger( + "prometheus.lang_graph.nodes.issue_not_verified_bug_subgraph_node" + ) + self.issue_not_verified_bug_subgraph = IssueNotVerifiedBugSubgraph( + advanced_model=advanced_model, + base_model=base_model, + kg=kg, + git_repo=git_repo, + neo4j_driver=neo4j_driver, + max_token_per_neo4j_result=max_token_per_neo4j_result, + ) + + def __call__(self, state: Dict): + self._logger.info("Enter IssueNotVerifiedBugSubgraphNode") + + output_state = self.issue_not_verified_bug_subgraph.invoke( + issue_title=state["issue_title"], + issue_body=state["issue_body"], + issue_comments=state["issue_comments"], + number_of_candidate_patch=state["number_of_candidate_patch"], + ) + + self._logger.info(f"final_patch:\n{output_state['final_patch']}") + + return { + "edit_patch": output_state["final_patch"], + "passed_reproducing_test": False, + "passed_build": False, + "passed_existing_test": False, + } diff --git a/prometheus/lang_graph/nodes/issue_verified_bug_subgraph_node.py b/prometheus/lang_graph/nodes/issue_verified_bug_subgraph_node.py index a3b0b734..ead733dd 100644 --- 
a/prometheus/lang_graph/nodes/issue_verified_bug_subgraph_node.py +++ b/prometheus/lang_graph/nodes/issue_verified_bug_subgraph_node.py @@ -12,64 +12,68 @@ class IssueVerifiedBugSubgraphNode: - def __init__( - self, - advanced_model: BaseChatModel, - base_model: BaseChatModel, - container: BaseContainer, - kg: KnowledgeGraph, - git_repo: GitRepository, - neo4j_driver: neo4j.Driver, - max_token_per_neo4j_result: int, - build_commands: Optional[Sequence[str]] = None, - test_commands: Optional[Sequence[str]] = None, - ): - self._logger = logging.getLogger("prometheus.lang_graph.nodes.issue_verified_bug_subgraph_node") - self.git_repo = git_repo - self.issue_reproduced_bug_subgraph = IssueVerifiedBugSubgraph( - advanced_model=advanced_model, - base_model=base_model, - container=container, - kg=kg, - git_repo=git_repo, - neo4j_driver=neo4j_driver, - max_token_per_neo4j_result=max_token_per_neo4j_result, - build_commands=build_commands, - test_commands=test_commands, - ) + def __init__( + self, + advanced_model: BaseChatModel, + base_model: BaseChatModel, + container: BaseContainer, + kg: KnowledgeGraph, + git_repo: GitRepository, + neo4j_driver: neo4j.Driver, + max_token_per_neo4j_result: int, + build_commands: Optional[Sequence[str]] = None, + test_commands: Optional[Sequence[str]] = None, + ): + self._logger = logging.getLogger( + "prometheus.lang_graph.nodes.issue_verified_bug_subgraph_node" + ) + self.git_repo = git_repo + self.issue_reproduced_bug_subgraph = IssueVerifiedBugSubgraph( + advanced_model=advanced_model, + base_model=base_model, + container=container, + kg=kg, + git_repo=git_repo, + neo4j_driver=neo4j_driver, + max_token_per_neo4j_result=max_token_per_neo4j_result, + build_commands=build_commands, + test_commands=test_commands, + ) - def __call__(self, state: Dict): - self._logger.info("Enter IssueVerifiedBugSubgraphNode") - try: - output_state = self.issue_reproduced_bug_subgraph.invoke( - issue_title=state["issue_title"], - 
issue_body=state["issue_body"], - issue_comments=state["issue_comments"], - run_build=state["run_build"], - run_existing_test=state["run_existing_test"], - reproduced_bug_file=state["reproduced_bug_file"], - reproduced_bug_commands=state["reproduced_bug_commands"], - ) - except GraphRecursionError: - self._logger.info("Recursion limit reached") - self.git_repo.reset_repository() - return { - "edit_patch": "", - "passed_reproducing_test": False, - "passed_build": False, - "passed_existing_test": False, - } + def __call__(self, state: Dict): + self._logger.info("Enter IssueVerifiedBugSubgraphNode") + try: + output_state = self.issue_reproduced_bug_subgraph.invoke( + issue_title=state["issue_title"], + issue_body=state["issue_body"], + issue_comments=state["issue_comments"], + run_build=state["run_build"], + run_existing_test=state["run_existing_test"], + reproduced_bug_file=state["reproduced_bug_file"], + reproduced_bug_commands=state["reproduced_bug_commands"], + ) + except GraphRecursionError: + self._logger.info("Recursion limit reached") + self.git_repo.reset_repository() + return { + "edit_patch": "", + "passed_reproducing_test": False, + "passed_build": False, + "passed_existing_test": False, + } - passed_reproducing_test = not bool(output_state["reproducing_test_fail_log"]) - passed_build = state["run_build"] and not output_state["build_fail_log"] - passed_existing_test = state["run_existing_test"] and not output_state["existing_test_fail_log"] - self._logger.info(f"edit_patch: {output_state['edit_patch']}") - self._logger.info(f"passed_reproducing_test: {passed_reproducing_test}") - self._logger.info(f"passed_build: {passed_build}") - self._logger.info(f"passed_existing_test: {passed_existing_test}") - return { - "edit_patch": output_state["edit_patch"], - "passed_reproducing_test": passed_reproducing_test, - "passed_build": passed_build, - "passed_existing_test": passed_existing_test, - } + passed_reproducing_test = not 
bool(output_state["reproducing_test_fail_log"]) + passed_build = state["run_build"] and not output_state["build_fail_log"] + passed_existing_test = ( + state["run_existing_test"] and not output_state["existing_test_fail_log"] + ) + self._logger.info(f"edit_patch: {output_state['edit_patch']}") + self._logger.info(f"passed_reproducing_test: {passed_reproducing_test}") + self._logger.info(f"passed_build: {passed_build}") + self._logger.info(f"passed_existing_test: {passed_existing_test}") + return { + "edit_patch": output_state["edit_patch"], + "passed_reproducing_test": passed_reproducing_test, + "passed_build": passed_build, + "passed_existing_test": passed_existing_test, + } diff --git a/prometheus/lang_graph/nodes/noop_node.py b/prometheus/lang_graph/nodes/noop_node.py index 99141223..e21b3f70 100644 --- a/prometheus/lang_graph/nodes/noop_node.py +++ b/prometheus/lang_graph/nodes/noop_node.py @@ -10,30 +10,30 @@ class NoopNode: - """No-operation node that routes workflow without processing. + """No-operation node that routes workflow without processing. - This class implements a pass-through node that accepts any input and returns - None without performing any operations. It serves as a routing element in - node graphs where a connection point is needed but no actual processing - should occur. - """ + This class implements a pass-through node that accepts any input and returns + None without performing any operations. It serves as a routing element in + node graphs where a connection point is needed but no actual processing + should occur. + """ - def __init__(self): - self._logger = logging.getLogger("prometheus.lang_graph.nodes.noop_node") + def __init__(self): + self._logger = logging.getLogger("prometheus.lang_graph.nodes.noop_node") - def __call__(self, state: Dict) -> None: - """Routes the workflow without performing any operations. + def __call__(self, state: Dict) -> None: + """Routes the workflow without performing any operations. 
- Accepts any input and returns None, serving as a pass-through point - in the workflow. + Accepts any input and returns None, serving as a pass-through point + in the workflow. - Args: - _: Any state - not used by this node. + Args: + _: Any state - not used by this node. - Returns: - None always, as this node performs no operations. - """ - for key, value in state.items(): - if isinstance(value, bool) or isinstance(value, int): - self._logger.debug(f"State {key}: {value}") - return + Returns: + None always, as this node performs no operations. + """ + for key, value in state.items(): + if isinstance(value, bool) or isinstance(value, int): + self._logger.debug(f"State {key}: {value}") + return diff --git a/prometheus/lang_graph/nodes/reset_messages_node.py b/prometheus/lang_graph/nodes/reset_messages_node.py index cf5a0bb6..bf87c134 100644 --- a/prometheus/lang_graph/nodes/reset_messages_node.py +++ b/prometheus/lang_graph/nodes/reset_messages_node.py @@ -15,44 +15,44 @@ class ResetMessagesNode: - """Resets message states for workflow loop iterations. + """Resets message states for workflow loop iterations. - This class provides functionality to clear or reset message states between - workflow loop iterations. It handles both list-type message collections - (by clearing the list) and string-type messages (by returning an empty - string state). + This class provides functionality to clear or reset message states between + workflow loop iterations. It handles both list-type message collections + (by clearing the list) and string-type messages (by returning an empty + string state). - The node is typically used in loops where: - - Message history needs to be cleared for the next iteration - - The same state key is reused throughout the loop - """ - - def __init__(self, message_state_key: str): - """Initializes the ResetMessagesNode with a target state key. - - Args: - message_state_key: String identifying which state attribute should - be reset during node execution. 
+ The node is typically used in loops where: + - Message history needs to be cleared for the next iteration + - The same state key is reused throughout the loop """ - self.message_state_key = message_state_key - self._logger = logging.getLogger("prometheus.lang_graph.nodes.reset_messages_node") - - def __call__(self, state: Dict): - """Resets the specified message state for the next iteration. - Handles two types of message states: - 1. List-type: Clears the list in place - 2. String-type: Returns a new state with empty string - - Args: - state: Current workflow state containing the message attribute to be reset. - - Returns: - For string-type messages: Dictionary with empty string state - For list-type messages: None (list is cleared in place) - """ - self._logger.debug(f"Resetting {self.message_state_key} in state.") - if isinstance(state[self.message_state_key], list): - state[self.message_state_key].clear() - elif isinstance(state[self.message_state_key], str): - return {self.message_state_key: ""} + def __init__(self, message_state_key: str): + """Initializes the ResetMessagesNode with a target state key. + + Args: + message_state_key: String identifying which state attribute should + be reset during node execution. + """ + self.message_state_key = message_state_key + self._logger = logging.getLogger("prometheus.lang_graph.nodes.reset_messages_node") + + def __call__(self, state: Dict): + """Resets the specified message state for the next iteration. + + Handles two types of message states: + 1. List-type: Clears the list in place + 2. String-type: Returns a new state with empty string + + Args: + state: Current workflow state containing the message attribute to be reset. 
+ + Returns: + For string-type messages: Dictionary with empty string state + For list-type messages: None (list is cleared in place) + """ + self._logger.debug(f"Resetting {self.message_state_key} in state.") + if isinstance(state[self.message_state_key], list): + state[self.message_state_key].clear() + elif isinstance(state[self.message_state_key], str): + return {self.message_state_key: ""} diff --git a/prometheus/lang_graph/nodes/update_container_node.py b/prometheus/lang_graph/nodes/update_container_node.py index eb68ae59..5c473dc6 100644 --- a/prometheus/lang_graph/nodes/update_container_node.py +++ b/prometheus/lang_graph/nodes/update_container_node.py @@ -15,38 +15,38 @@ class UpdateContainerNode: - """Synchronizes agent-made changes with container filesystem. + """Synchronizes agent-made changes with container filesystem. - This class handles the synchronization of file changes between an agent's - workspace and a container environment. It ensures that any modifications - made by AI agents (such as code fixes or edits) are properly reflected - in the container where builds, tests, or other operations may occur. - """ - - def __init__(self, container: BaseContainer, git_repo: GitRepository): - """Initializes the UpdateContainerNode with a target container. - - Args: - container: Container instance that will receive file updates. Must - be a subclass of BaseContainer implementing the update_files method. - git_repo: The local git repository used to retrieve the project's files. + This class handles the synchronization of file changes between an agent's + workspace and a container environment. It ensures that any modifications + made by AI agents (such as code fixes or edits) are properly reflected + in the container where builds, tests, or other operations may occur. 
""" - self.container = container - self.git_repo = git_repo - self._logger = logging.getLogger("prometheus.lang_graph.nodes.update_container_node") - - def __call__(self, _: Dict): - """Synchronizes the current project state with the container.""" - if self.container.is_running(): - self._logger.info("Copy over all updated files to the container") - all_files_patch = self.git_repo.get_diff() - self.container.restart_container() - added_files, modified_file, removed_files = get_updated_files(all_files_patch) - self.container.update_files( - self.git_repo.get_working_directory(), added_files + modified_file, removed_files - ) - else: - self._logger.info( - "Not updating files in docker container because it is not running, " - "most likely due to run_build and run_test are both false." - ) + + def __init__(self, container: BaseContainer, git_repo: GitRepository): + """Initializes the UpdateContainerNode with a target container. + + Args: + container: Container instance that will receive file updates. Must + be a subclass of BaseContainer implementing the update_files method. + git_repo: The local git repository used to retrieve the project's files. + """ + self.container = container + self.git_repo = git_repo + self._logger = logging.getLogger("prometheus.lang_graph.nodes.update_container_node") + + def __call__(self, _: Dict): + """Synchronizes the current project state with the container.""" + if self.container.is_running(): + self._logger.info("Copy over all updated files to the container") + all_files_patch = self.git_repo.get_diff() + self.container.restart_container() + added_files, modified_file, removed_files = get_updated_files(all_files_patch) + self.container.update_files( + self.git_repo.get_working_directory(), added_files + modified_file, removed_files + ) + else: + self._logger.info( + "Not updating files in docker container because it is not running, " + "most likely due to run_build and run_test are both false." 
+ ) diff --git a/prometheus/lang_graph/nodes/user_defined_build_node.py b/prometheus/lang_graph/nodes/user_defined_build_node.py index 9d261215..aa8dc40c 100644 --- a/prometheus/lang_graph/nodes/user_defined_build_node.py +++ b/prometheus/lang_graph/nodes/user_defined_build_node.py @@ -8,15 +8,16 @@ class UserDefinedBuildNode: - def __init__(self, container: BaseContainer): - self.container = container - self._logger = logging.getLogger("prometheus.lang_graph.nodes.user_defined_build_node") + def __init__(self, container: BaseContainer): + self.container = container + self._logger = logging.getLogger("prometheus.lang_graph.nodes.user_defined_build_node") - def __call__(self, _: Any): - build_output = self.container.run_build() - self._logger.debug(f"UserDefinedBuildNode response:\n{build_output}") + def __call__(self, _: Any): + build_output = self.container.run_build() + self._logger.debug(f"UserDefinedBuildNode response:\n{build_output}") - tool_message = ToolMessage( - content=build_output, tool_call_id=f"user_defined_build_commands_{uuid.uuid4().hex[:10]}" - ) - return {"build_messages": [tool_message]} + tool_message = ToolMessage( + content=build_output, + tool_call_id=f"user_defined_build_commands_{uuid.uuid4().hex[:10]}", + ) + return {"build_messages": [tool_message]} diff --git a/prometheus/lang_graph/nodes/user_defined_test_node.py b/prometheus/lang_graph/nodes/user_defined_test_node.py index e61f13ff..7f3e6e61 100644 --- a/prometheus/lang_graph/nodes/user_defined_test_node.py +++ b/prometheus/lang_graph/nodes/user_defined_test_node.py @@ -8,15 +8,15 @@ class UserDefinedTestNode: - def __init__(self, container: BaseContainer): - self.container = container - self._logger = logging.getLogger("prometheus.lang_graph.nodes.user_defined_test_node") + def __init__(self, container: BaseContainer): + self.container = container + self._logger = logging.getLogger("prometheus.lang_graph.nodes.user_defined_test_node") - def __call__(self, _: Any): - test_output = 
self.container.run_test() - self._logger.debug(f"UserDefinedTestNode response:\n{test_output}") + def __call__(self, _: Any): + test_output = self.container.run_test() + self._logger.debug(f"UserDefinedTestNode response:\n{test_output}") - tool_message = ToolMessage( - test_output, tool_call_id=f"user_defined_test_commands_{uuid.uuid4().hex[:10]}" - ) - return {"test_messages": [tool_message]} + tool_message = ToolMessage( + test_output, tool_call_id=f"user_defined_test_commands_{uuid.uuid4().hex[:10]}" + ) + return {"test_messages": [tool_message]} diff --git a/prometheus/lang_graph/subgraphs/bug_fix_verification_state.py b/prometheus/lang_graph/subgraphs/bug_fix_verification_state.py index 192bb393..ac493fb7 100644 --- a/prometheus/lang_graph/subgraphs/bug_fix_verification_state.py +++ b/prometheus/lang_graph/subgraphs/bug_fix_verification_state.py @@ -5,9 +5,9 @@ class BugFixVerficationState(TypedDict): - reproduced_bug_file: str - reproduced_bug_commands: Sequence[str] + reproduced_bug_file: str + reproduced_bug_commands: Sequence[str] - bug_fix_verify_messages: Annotated[Sequence[BaseMessage], add_messages] + bug_fix_verify_messages: Annotated[Sequence[BaseMessage], add_messages] - reproducing_test_fail_log: str + reproducing_test_fail_log: str diff --git a/prometheus/lang_graph/subgraphs/bug_fix_verification_subgraph.py b/prometheus/lang_graph/subgraphs/bug_fix_verification_subgraph.py index 53764bee..e0434a51 100644 --- a/prometheus/lang_graph/subgraphs/bug_fix_verification_subgraph.py +++ b/prometheus/lang_graph/subgraphs/bug_fix_verification_subgraph.py @@ -12,50 +12,50 @@ class BugFixVerificationSubgraph: - def __init__( - self, - model: BaseChatModel, - container: BaseContainer, - ): - bug_fix_verify_node = BugFixVerifyNode(model, container) - bug_fix_verify_tools = ToolNode( - tools=bug_fix_verify_node.tools, - name="bug_fix_verify_tools", - messages_key="bug_fix_verify_messages", - ) - bug_fix_verify_structured_node = BugFixVerifyStructuredNode(model) 
- - workflow = StateGraph(BugFixVerficationState) - - workflow.add_node("bug_fix_verify_node", bug_fix_verify_node) - workflow.add_node("bug_fix_verify_tools", bug_fix_verify_tools) - workflow.add_node("bug_fix_verify_structured_node", bug_fix_verify_structured_node) - - workflow.set_entry_point("bug_fix_verify_node") - - workflow.add_conditional_edges( - "bug_fix_verify_node", - functools.partial(tools_condition, messages_key="bug_fix_verify_messages"), - { - "tools": "bug_fix_verify_tools", - END: "bug_fix_verify_structured_node", - }, - ) - workflow.add_edge("bug_fix_verify_tools", "bug_fix_verify_node") - workflow.add_edge("bug_fix_verify_structured_node", END) - - self.subgraph = workflow.compile() - - def invoke(self, reproduced_bug_file: str, reproduced_bug_commands: Sequence[str]): - config = None - - input_state = { - "reproduced_bug_file": reproduced_bug_file, - "reproduced_bug_commands": reproduced_bug_commands, - } - - output_state = self.subgraph.invoke(input_state, config) - - return { - "reproducing_test_fail_log": output_state["reproducing_test_fail_log"], - } + def __init__( + self, + model: BaseChatModel, + container: BaseContainer, + ): + bug_fix_verify_node = BugFixVerifyNode(model, container) + bug_fix_verify_tools = ToolNode( + tools=bug_fix_verify_node.tools, + name="bug_fix_verify_tools", + messages_key="bug_fix_verify_messages", + ) + bug_fix_verify_structured_node = BugFixVerifyStructuredNode(model) + + workflow = StateGraph(BugFixVerficationState) + + workflow.add_node("bug_fix_verify_node", bug_fix_verify_node) + workflow.add_node("bug_fix_verify_tools", bug_fix_verify_tools) + workflow.add_node("bug_fix_verify_structured_node", bug_fix_verify_structured_node) + + workflow.set_entry_point("bug_fix_verify_node") + + workflow.add_conditional_edges( + "bug_fix_verify_node", + functools.partial(tools_condition, messages_key="bug_fix_verify_messages"), + { + "tools": "bug_fix_verify_tools", + END: "bug_fix_verify_structured_node", + }, + ) + 
workflow.add_edge("bug_fix_verify_tools", "bug_fix_verify_node") + workflow.add_edge("bug_fix_verify_structured_node", END) + + self.subgraph = workflow.compile() + + def invoke(self, reproduced_bug_file: str, reproduced_bug_commands: Sequence[str]): + config = None + + input_state = { + "reproduced_bug_file": reproduced_bug_file, + "reproduced_bug_commands": reproduced_bug_commands, + } + + output_state = self.subgraph.invoke(input_state, config) + + return { + "reproducing_test_fail_log": output_state["reproducing_test_fail_log"], + } diff --git a/prometheus/lang_graph/subgraphs/bug_reproduction_state.py b/prometheus/lang_graph/subgraphs/bug_reproduction_state.py index acaba9e8..5109324e 100644 --- a/prometheus/lang_graph/subgraphs/bug_reproduction_state.py +++ b/prometheus/lang_graph/subgraphs/bug_reproduction_state.py @@ -5,22 +5,22 @@ class BugReproductionState(TypedDict): - issue_title: str - issue_body: str - issue_comments: Sequence[Mapping[str, str]] + issue_title: str + issue_body: str + issue_comments: Sequence[Mapping[str, str]] - max_refined_query_loop: int + max_refined_query_loop: int - bug_reproducing_query: str - bug_reproducing_context: Sequence[str] + bug_reproducing_query: str + bug_reproducing_context: Sequence[str] - bug_reproducing_write_messages: Annotated[Sequence[BaseMessage], add_messages] - bug_reproducing_file_messages: Annotated[Sequence[BaseMessage], add_messages] - bug_reproducing_execute_messages: Annotated[Sequence[BaseMessage], add_messages] + bug_reproducing_write_messages: Annotated[Sequence[BaseMessage], add_messages] + bug_reproducing_file_messages: Annotated[Sequence[BaseMessage], add_messages] + bug_reproducing_execute_messages: Annotated[Sequence[BaseMessage], add_messages] - bug_reproducing_patch: str + bug_reproducing_patch: str - reproduced_bug: bool - reproduced_bug_failure_log: str - reproduced_bug_file: str - reproduced_bug_commands: Sequence[str] + reproduced_bug: bool + reproduced_bug_failure_log: str + 
reproduced_bug_file: str + reproduced_bug_commands: Sequence[str] diff --git a/prometheus/lang_graph/subgraphs/bug_reproduction_subgraph.py b/prometheus/lang_graph/subgraphs/bug_reproduction_subgraph.py index 992cbee1..8a4893ff 100644 --- a/prometheus/lang_graph/subgraphs/bug_reproduction_subgraph.py +++ b/prometheus/lang_graph/subgraphs/bug_reproduction_subgraph.py @@ -14,14 +14,14 @@ from prometheus.lang_graph.nodes.bug_reproducing_file_node import BugReproducingFileNode from prometheus.lang_graph.nodes.bug_reproducing_structured_node import BugReproducingStructuredNode from prometheus.lang_graph.nodes.bug_reproducing_write_message_node import ( - BugReproducingWriteMessageNode, + BugReproducingWriteMessageNode, ) from prometheus.lang_graph.nodes.bug_reproducing_write_node import BugReproducingWriteNode from prometheus.lang_graph.nodes.context_retrieval_subgraph_node import ContextRetrievalSubgraphNode from prometheus.lang_graph.nodes.git_diff_node import GitDiffNode from prometheus.lang_graph.nodes.git_reset_node import GitResetNode from prometheus.lang_graph.nodes.issue_bug_reproduction_context_message_node import ( - IssueBugReproductionContextMessageNode, + IssueBugReproductionContextMessageNode, ) from prometheus.lang_graph.nodes.reset_messages_node import ResetMessagesNode from prometheus.lang_graph.nodes.update_container_node import UpdateContainerNode @@ -29,166 +29,173 @@ class BugReproductionSubgraph: - def __init__( - self, - advanced_model: BaseChatModel, - base_model: BaseChatModel, - container: BaseContainer, - kg: KnowledgeGraph, - git_repo: GitRepository, - neo4j_driver: neo4j.Driver, - max_token_per_neo4j_result: int, - test_commands: Optional[Sequence[str]] = None, - ): - self.git_repo = git_repo - - issue_bug_reproduction_context_message_node = IssueBugReproductionContextMessageNode() - context_retrieval_subgraph_node = ContextRetrievalSubgraphNode( - base_model, - kg, - neo4j_driver, - max_token_per_neo4j_result, - "bug_reproducing_query", - 
"bug_reproducing_context", - ) - - bug_reproducing_write_message_node = BugReproducingWriteMessageNode() - bug_reproducing_write_node = BugReproducingWriteNode(advanced_model, kg) - bug_reproducing_write_tools = ToolNode( - tools=bug_reproducing_write_node.tools, - name="bug_reproducing_write_tools", - messages_key="bug_reproducing_write_messages", - ) - bug_reproducing_file_node = BugReproducingFileNode(base_model, kg) - bug_reproducing_file_tools = ToolNode( - tools=bug_reproducing_file_node.tools, - name="bug_reproducing_file_tools", - messages_key="bug_reproducing_file_messages", - ) - git_diff_node = GitDiffNode(git_repo, "bug_reproducing_patch") - update_container_node = UpdateContainerNode(container, git_repo) - bug_reproducing_execute_node = BugReproducingExecuteNode(base_model, container, test_commands) - bug_reproducing_execute_tools = ToolNode( - tools=bug_reproducing_execute_node.tools, - name="bug_reproducing_execute_tools", - messages_key="bug_reproducing_execute_messages", - ) - bug_reproducing_structured_node = BugReproducingStructuredNode(advanced_model) - reset_bug_reproducing_file_messages_node = ResetMessagesNode("bug_reproducing_file_messages") - reset_bug_reproducing_execute_messages_node = ResetMessagesNode( - "bug_reproducing_execute_messages" - ) - git_reset_node = GitResetNode(git_repo) - - workflow = StateGraph(BugReproductionState) - - workflow.add_node( - "issue_bug_reproduction_context_message_node", issue_bug_reproduction_context_message_node - ) - workflow.add_node("context_retrieval_subgraph_node", context_retrieval_subgraph_node) - - workflow.add_node("bug_reproducing_write_message_node", bug_reproducing_write_message_node) - workflow.add_node("bug_reproducing_write_node", bug_reproducing_write_node) - workflow.add_node("bug_reproducing_write_tools", bug_reproducing_write_tools) - workflow.add_node("bug_reproducing_file_node", bug_reproducing_file_node) - workflow.add_node("bug_reproducing_file_tools", bug_reproducing_file_tools) - 
workflow.add_node("git_diff_node", git_diff_node) - workflow.add_node("update_container_node", update_container_node) - workflow.add_node("bug_reproducing_execute_node", bug_reproducing_execute_node) - workflow.add_node("bug_reproducing_execute_tools", bug_reproducing_execute_tools) - workflow.add_node("bug_reproducing_structured_node", bug_reproducing_structured_node) - - workflow.add_node( - "reset_bug_reproducing_file_messages_node", reset_bug_reproducing_file_messages_node - ) - workflow.add_node( - "reset_bug_reproducing_execute_messages_node", reset_bug_reproducing_execute_messages_node - ) - workflow.add_node("git_reset_node", git_reset_node) - - workflow.set_entry_point("issue_bug_reproduction_context_message_node") - workflow.add_edge( - "issue_bug_reproduction_context_message_node", "context_retrieval_subgraph_node" - ) - workflow.add_edge("context_retrieval_subgraph_node", "bug_reproducing_write_message_node") - - workflow.add_edge("bug_reproducing_write_message_node", "bug_reproducing_write_node") - workflow.add_conditional_edges( - "bug_reproducing_write_node", - functools.partial(tools_condition, messages_key="bug_reproducing_write_messages"), - { - "tools": "bug_reproducing_write_tools", - END: "bug_reproducing_file_node", - }, - ) - workflow.add_edge("bug_reproducing_write_tools", "bug_reproducing_write_node") - workflow.add_conditional_edges( - "bug_reproducing_file_node", - functools.partial(tools_condition, messages_key="bug_reproducing_file_messages"), - { - "tools": "bug_reproducing_file_tools", - END: "git_diff_node", - }, - ) - workflow.add_edge("bug_reproducing_file_tools", "bug_reproducing_file_node") - workflow.add_edge("git_diff_node", "update_container_node") - workflow.add_edge("update_container_node", "bug_reproducing_execute_node") - workflow.add_conditional_edges( - "bug_reproducing_execute_node", - functools.partial(tools_condition, messages_key="bug_reproducing_execute_messages"), - { - "tools": "bug_reproducing_execute_tools", - 
END: "bug_reproducing_structured_node", - }, - ) - workflow.add_edge("bug_reproducing_execute_tools", "bug_reproducing_execute_node") - workflow.add_conditional_edges( - "bug_reproducing_structured_node", - lambda state: state["reproduced_bug"], - {True: END, False: "reset_bug_reproducing_file_messages_node"}, - ) - - workflow.add_edge( - "reset_bug_reproducing_file_messages_node", "reset_bug_reproducing_execute_messages_node" - ) - workflow.add_edge( - "reset_bug_reproducing_execute_messages_node", - "git_reset_node", - ) - workflow.add_edge( - "git_reset_node", - "bug_reproducing_write_message_node", - ) - - self.subgraph = workflow.compile() - - def invoke( - self, - issue_title: str, - issue_body: str, - issue_comments: Sequence[Mapping[str, str]], - recursion_limit: int = 50, - ): - config = {"recursion_limit": recursion_limit} - - input_state = { - "issue_title": issue_title, - "issue_body": issue_body, - "issue_comments": issue_comments, - "max_refined_query_loop": 1, - } - - try: - output_state = self.subgraph.invoke(input_state, config) - return { - "reproduced_bug": output_state["reproduced_bug"], - "reproduced_bug_file": output_state["reproduced_bug_file"], - "reproduced_bug_commands": output_state["reproduced_bug_commands"], - } - except GraphRecursionError: - self.git_repo.reset_repository() - return { - "reproduced_bug": False, - "reproduced_bug_file": "", - "reproduced_bug_commands": "", - } + def __init__( + self, + advanced_model: BaseChatModel, + base_model: BaseChatModel, + container: BaseContainer, + kg: KnowledgeGraph, + git_repo: GitRepository, + neo4j_driver: neo4j.Driver, + max_token_per_neo4j_result: int, + test_commands: Optional[Sequence[str]] = None, + ): + self.git_repo = git_repo + + issue_bug_reproduction_context_message_node = IssueBugReproductionContextMessageNode() + context_retrieval_subgraph_node = ContextRetrievalSubgraphNode( + base_model, + kg, + neo4j_driver, + max_token_per_neo4j_result, + "bug_reproducing_query", + 
"bug_reproducing_context", + ) + + bug_reproducing_write_message_node = BugReproducingWriteMessageNode() + bug_reproducing_write_node = BugReproducingWriteNode(advanced_model, kg) + bug_reproducing_write_tools = ToolNode( + tools=bug_reproducing_write_node.tools, + name="bug_reproducing_write_tools", + messages_key="bug_reproducing_write_messages", + ) + bug_reproducing_file_node = BugReproducingFileNode(base_model, kg) + bug_reproducing_file_tools = ToolNode( + tools=bug_reproducing_file_node.tools, + name="bug_reproducing_file_tools", + messages_key="bug_reproducing_file_messages", + ) + git_diff_node = GitDiffNode(git_repo, "bug_reproducing_patch") + update_container_node = UpdateContainerNode(container, git_repo) + bug_reproducing_execute_node = BugReproducingExecuteNode( + base_model, container, test_commands + ) + bug_reproducing_execute_tools = ToolNode( + tools=bug_reproducing_execute_node.tools, + name="bug_reproducing_execute_tools", + messages_key="bug_reproducing_execute_messages", + ) + bug_reproducing_structured_node = BugReproducingStructuredNode(advanced_model) + reset_bug_reproducing_file_messages_node = ResetMessagesNode( + "bug_reproducing_file_messages" + ) + reset_bug_reproducing_execute_messages_node = ResetMessagesNode( + "bug_reproducing_execute_messages" + ) + git_reset_node = GitResetNode(git_repo) + + workflow = StateGraph(BugReproductionState) + + workflow.add_node( + "issue_bug_reproduction_context_message_node", + issue_bug_reproduction_context_message_node, + ) + workflow.add_node("context_retrieval_subgraph_node", context_retrieval_subgraph_node) + + workflow.add_node("bug_reproducing_write_message_node", bug_reproducing_write_message_node) + workflow.add_node("bug_reproducing_write_node", bug_reproducing_write_node) + workflow.add_node("bug_reproducing_write_tools", bug_reproducing_write_tools) + workflow.add_node("bug_reproducing_file_node", bug_reproducing_file_node) + workflow.add_node("bug_reproducing_file_tools", 
bug_reproducing_file_tools) + workflow.add_node("git_diff_node", git_diff_node) + workflow.add_node("update_container_node", update_container_node) + workflow.add_node("bug_reproducing_execute_node", bug_reproducing_execute_node) + workflow.add_node("bug_reproducing_execute_tools", bug_reproducing_execute_tools) + workflow.add_node("bug_reproducing_structured_node", bug_reproducing_structured_node) + + workflow.add_node( + "reset_bug_reproducing_file_messages_node", reset_bug_reproducing_file_messages_node + ) + workflow.add_node( + "reset_bug_reproducing_execute_messages_node", + reset_bug_reproducing_execute_messages_node, + ) + workflow.add_node("git_reset_node", git_reset_node) + + workflow.set_entry_point("issue_bug_reproduction_context_message_node") + workflow.add_edge( + "issue_bug_reproduction_context_message_node", "context_retrieval_subgraph_node" + ) + workflow.add_edge("context_retrieval_subgraph_node", "bug_reproducing_write_message_node") + + workflow.add_edge("bug_reproducing_write_message_node", "bug_reproducing_write_node") + workflow.add_conditional_edges( + "bug_reproducing_write_node", + functools.partial(tools_condition, messages_key="bug_reproducing_write_messages"), + { + "tools": "bug_reproducing_write_tools", + END: "bug_reproducing_file_node", + }, + ) + workflow.add_edge("bug_reproducing_write_tools", "bug_reproducing_write_node") + workflow.add_conditional_edges( + "bug_reproducing_file_node", + functools.partial(tools_condition, messages_key="bug_reproducing_file_messages"), + { + "tools": "bug_reproducing_file_tools", + END: "git_diff_node", + }, + ) + workflow.add_edge("bug_reproducing_file_tools", "bug_reproducing_file_node") + workflow.add_edge("git_diff_node", "update_container_node") + workflow.add_edge("update_container_node", "bug_reproducing_execute_node") + workflow.add_conditional_edges( + "bug_reproducing_execute_node", + functools.partial(tools_condition, messages_key="bug_reproducing_execute_messages"), + { + "tools": 
"bug_reproducing_execute_tools", + END: "bug_reproducing_structured_node", + }, + ) + workflow.add_edge("bug_reproducing_execute_tools", "bug_reproducing_execute_node") + workflow.add_conditional_edges( + "bug_reproducing_structured_node", + lambda state: state["reproduced_bug"], + {True: END, False: "reset_bug_reproducing_file_messages_node"}, + ) + + workflow.add_edge( + "reset_bug_reproducing_file_messages_node", + "reset_bug_reproducing_execute_messages_node", + ) + workflow.add_edge( + "reset_bug_reproducing_execute_messages_node", + "git_reset_node", + ) + workflow.add_edge( + "git_reset_node", + "bug_reproducing_write_message_node", + ) + + self.subgraph = workflow.compile() + + def invoke( + self, + issue_title: str, + issue_body: str, + issue_comments: Sequence[Mapping[str, str]], + recursion_limit: int = 50, + ): + config = {"recursion_limit": recursion_limit} + + input_state = { + "issue_title": issue_title, + "issue_body": issue_body, + "issue_comments": issue_comments, + "max_refined_query_loop": 1, + } + + try: + output_state = self.subgraph.invoke(input_state, config) + return { + "reproduced_bug": output_state["reproduced_bug"], + "reproduced_bug_file": output_state["reproduced_bug_file"], + "reproduced_bug_commands": output_state["reproduced_bug_commands"], + } + except GraphRecursionError: + self.git_repo.reset_repository() + return { + "reproduced_bug": False, + "reproduced_bug_file": "", + "reproduced_bug_commands": "", + } diff --git a/prometheus/lang_graph/subgraphs/build_and_test_state.py b/prometheus/lang_graph/subgraphs/build_and_test_state.py index b700ed55..0dcafe62 100644 --- a/prometheus/lang_graph/subgraphs/build_and_test_state.py +++ b/prometheus/lang_graph/subgraphs/build_and_test_state.py @@ -5,15 +5,15 @@ class BuildAndTestState(TypedDict): - run_build: bool - run_existing_test: bool + run_build: bool + run_existing_test: bool - build_messages: Annotated[Sequence[BaseMessage], add_messages] - exist_build: bool - 
build_command_summary: str - build_fail_log: str + build_messages: Annotated[Sequence[BaseMessage], add_messages] + exist_build: bool + build_command_summary: str + build_fail_log: str - test_messages: Annotated[Sequence[BaseMessage], add_messages] - exist_test: bool - test_command_summary: str - existing_test_fail_log: str + test_messages: Annotated[Sequence[BaseMessage], add_messages] + exist_test: bool + test_command_summary: str + existing_test_fail_log: str diff --git a/prometheus/lang_graph/subgraphs/build_and_test_subgraph.py b/prometheus/lang_graph/subgraphs/build_and_test_subgraph.py index aff4c701..67cce0f3 100644 --- a/prometheus/lang_graph/subgraphs/build_and_test_subgraph.py +++ b/prometheus/lang_graph/subgraphs/build_and_test_subgraph.py @@ -9,7 +9,7 @@ from prometheus.graph.knowledge_graph import KnowledgeGraph from prometheus.lang_graph.nodes.general_build_node import GeneralBuildNode from prometheus.lang_graph.nodes.general_build_structured_node import ( - GeneralBuildStructuredNode, + GeneralBuildStructuredNode, ) from prometheus.lang_graph.nodes.general_test_node import GeneralTestNode from prometheus.lang_graph.nodes.general_test_structured_node import GeneralTestStructuredNode @@ -20,129 +20,129 @@ class BuildAndTestSubgraph: - def __init__( - self, - container: BaseContainer, - model: BaseChatModel, - kg: KnowledgeGraph, - build_commands: Optional[Sequence[str]] = None, - test_commands: Optional[Sequence[str]] = None, - ): - build_branch_node = NoopNode() - if build_commands: - build_node = UserDefinedBuildNode(container) - else: - build_node = GeneralBuildNode(model, container, kg) - build_tools = ToolNode( - tools=build_node.tools, - name="build_tools", - messages_key="build_messages", - ) - general_build_structured_node = GeneralBuildStructuredNode(model) + def __init__( + self, + container: BaseContainer, + model: BaseChatModel, + kg: KnowledgeGraph, + build_commands: Optional[Sequence[str]] = None, + test_commands: Optional[Sequence[str]] 
= None, + ): + build_branch_node = NoopNode() + if build_commands: + build_node = UserDefinedBuildNode(container) + else: + build_node = GeneralBuildNode(model, container, kg) + build_tools = ToolNode( + tools=build_node.tools, + name="build_tools", + messages_key="build_messages", + ) + general_build_structured_node = GeneralBuildStructuredNode(model) - test_branch_node = NoopNode() - if test_commands: - test_node = UserDefinedTestNode(container) - else: - test_node = GeneralTestNode(model, container, kg) - test_tools = ToolNode( - tools=test_node.tools, - name="test_tools", - messages_key="test_messages", - ) - general_test_structured_node = GeneralTestStructuredNode(model) + test_branch_node = NoopNode() + if test_commands: + test_node = UserDefinedTestNode(container) + else: + test_node = GeneralTestNode(model, container, kg) + test_tools = ToolNode( + tools=test_node.tools, + name="test_tools", + messages_key="test_messages", + ) + general_test_structured_node = GeneralTestStructuredNode(model) - workflow = StateGraph(BuildAndTestState) - workflow.add_node("build_branch_node", build_branch_node) - workflow.add_node("build_node", build_node) - if not build_commands: - workflow.add_node("build_tools", build_tools) - workflow.add_node("general_build_structured_node", general_build_structured_node) + workflow = StateGraph(BuildAndTestState) + workflow.add_node("build_branch_node", build_branch_node) + workflow.add_node("build_node", build_node) + if not build_commands: + workflow.add_node("build_tools", build_tools) + workflow.add_node("general_build_structured_node", general_build_structured_node) - workflow.add_node("test_branch_node", test_branch_node) - workflow.add_node("test_node", test_node) - if not test_commands: - workflow.add_node("test_tools", test_tools) - workflow.add_node("general_test_structured_node", general_test_structured_node) + workflow.add_node("test_branch_node", test_branch_node) + workflow.add_node("test_node", test_node) + if not 
test_commands: + workflow.add_node("test_tools", test_tools) + workflow.add_node("general_test_structured_node", general_test_structured_node) - workflow.set_entry_point("build_branch_node") - workflow.add_conditional_edges( - "build_branch_node", - lambda state: state["run_build"], - {True: "build_node", False: "test_branch_node"}, - ) - if build_commands: - workflow.add_edge("build_node", "general_build_structured_node") - else: - workflow.add_conditional_edges( - "build_node", - functools.partial(tools_condition, messages_key="build_messages"), - { - "tools": "build_tools", - END: "general_build_structured_node", - }, - ) - workflow.add_edge("build_tools", "build_node") - workflow.add_edge("general_build_structured_node", "test_node") + workflow.set_entry_point("build_branch_node") + workflow.add_conditional_edges( + "build_branch_node", + lambda state: state["run_build"], + {True: "build_node", False: "test_branch_node"}, + ) + if build_commands: + workflow.add_edge("build_node", "general_build_structured_node") + else: + workflow.add_conditional_edges( + "build_node", + functools.partial(tools_condition, messages_key="build_messages"), + { + "tools": "build_tools", + END: "general_build_structured_node", + }, + ) + workflow.add_edge("build_tools", "build_node") + workflow.add_edge("general_build_structured_node", "test_node") - workflow.add_conditional_edges( - "test_branch_node", - lambda state: state["run_existing_test"], - {True: "test_node", False: END}, - ) - if test_commands: - workflow.add_edge("test_node", "general_test_structured_node") - else: - workflow.add_conditional_edges( - "test_node", - functools.partial(tools_condition, messages_key="test_messages"), - { - "tools": "test_tools", - END: "general_test_structured_node", - }, - ) - workflow.add_edge("test_tools", "test_node") - workflow.add_edge("general_test_structured_node", END) + workflow.add_conditional_edges( + "test_branch_node", + lambda state: state["run_existing_test"], + {True: 
"test_node", False: END}, + ) + if test_commands: + workflow.add_edge("test_node", "general_test_structured_node") + else: + workflow.add_conditional_edges( + "test_node", + functools.partial(tools_condition, messages_key="test_messages"), + { + "tools": "test_tools", + END: "general_test_structured_node", + }, + ) + workflow.add_edge("test_tools", "test_node") + workflow.add_edge("general_test_structured_node", END) - self.subgraph = workflow.compile() + self.subgraph = workflow.compile() - def invoke( - self, - run_build: bool, - run_existing_test: bool, - exist_build: Optional[bool] = None, - build_command_summary: Optional[str] = None, - build_fail_log: Optional[str] = None, - exist_test: Optional[bool] = None, - test_command_summary: Optional[str] = None, - existing_test_fail_log: Optional[str] = None, - ) -> Mapping[str, Union[bool, str]]: - config = None + def invoke( + self, + run_build: bool, + run_existing_test: bool, + exist_build: Optional[bool] = None, + build_command_summary: Optional[str] = None, + build_fail_log: Optional[str] = None, + exist_test: Optional[bool] = None, + test_command_summary: Optional[str] = None, + existing_test_fail_log: Optional[str] = None, + ) -> Mapping[str, Union[bool, str]]: + config = None - input_state = { - "run_build": run_build, - "run_existing_test": run_existing_test, - } - if exist_build: - input_state["exist_build"] = exist_build - if build_command_summary: - input_state["build_command_summary"] = build_command_summary - if build_fail_log: - input_state["build_fail_log"] = build_fail_log - if exist_test: - input_state["exist_test"] = exist_test - if test_command_summary: - input_state["test_command_summary"] = test_command_summary - if existing_test_fail_log: - input_state["existing_test_fail_log"] = existing_test_fail_log + input_state = { + "run_build": run_build, + "run_existing_test": run_existing_test, + } + if exist_build: + input_state["exist_build"] = exist_build + if build_command_summary: + 
input_state["build_command_summary"] = build_command_summary + if build_fail_log: + input_state["build_fail_log"] = build_fail_log + if exist_test: + input_state["exist_test"] = exist_test + if test_command_summary: + input_state["test_command_summary"] = test_command_summary + if existing_test_fail_log: + input_state["existing_test_fail_log"] = existing_test_fail_log - output_state = self.subgraph.invoke(input_state, config) + output_state = self.subgraph.invoke(input_state, config) - return { - "exist_build": output_state.get("exist_build", False), - "build_command_summary": output_state.get("build_command_summary", ""), - "build_fail_log": output_state.get("build_fail_log", ""), - "exist_test": output_state.get("exist_test", False), - "test_command_summary": output_state.get("test_command_summary", ""), - "existing_test_fail_log": output_state.get("existing_test_fail_log", ""), - } + return { + "exist_build": output_state.get("exist_build", False), + "build_command_summary": output_state.get("build_command_summary", ""), + "build_fail_log": output_state.get("build_fail_log", ""), + "exist_test": output_state.get("exist_test", False), + "test_command_summary": output_state.get("test_command_summary", ""), + "existing_test_fail_log": output_state.get("existing_test_fail_log", ""), + } diff --git a/prometheus/lang_graph/subgraphs/context_retrieval_state.py b/prometheus/lang_graph/subgraphs/context_retrieval_state.py index a662b4a2..dd11adb2 100644 --- a/prometheus/lang_graph/subgraphs/context_retrieval_state.py +++ b/prometheus/lang_graph/subgraphs/context_retrieval_state.py @@ -5,9 +5,9 @@ class ContextRetrievalState(TypedDict): - query: str - max_refined_query_loop: int + query: str + max_refined_query_loop: int - context_provider_messages: Annotated[Sequence[BaseMessage], add_messages] - refined_query: str - context: Sequence[str] + context_provider_messages: Annotated[Sequence[BaseMessage], add_messages] + refined_query: str + context: Sequence[str] diff --git 
a/prometheus/lang_graph/subgraphs/context_retrieval_subgraph.py b/prometheus/lang_graph/subgraphs/context_retrieval_subgraph.py index 7b6b9931..a44c07b3 100644 --- a/prometheus/lang_graph/subgraphs/context_retrieval_subgraph.py +++ b/prometheus/lang_graph/subgraphs/context_retrieval_subgraph.py @@ -16,57 +16,61 @@ class ContextRetrievalSubgraph: - def __init__( - self, - model: BaseChatModel, - kg: KnowledgeGraph, - neo4j_driver: neo4j.Driver, - max_token_per_neo4j_result: int, - ): - context_query_message_node = ContextQueryMessageNode() - context_provider_node = ContextProviderNode(model, kg, neo4j_driver, max_token_per_neo4j_result) - context_provider_tools = ToolNode( - tools=context_provider_node.tools, - name="context_provider_tools", - messages_key="context_provider_messages", - ) - context_selection_node = ContextSelectionNode(model) - reset_context_provider_messages_node = ResetMessagesNode("context_provider_messages") - context_refine_node = ContextRefineNode(model, kg) + def __init__( + self, + model: BaseChatModel, + kg: KnowledgeGraph, + neo4j_driver: neo4j.Driver, + max_token_per_neo4j_result: int, + ): + context_query_message_node = ContextQueryMessageNode() + context_provider_node = ContextProviderNode( + model, kg, neo4j_driver, max_token_per_neo4j_result + ) + context_provider_tools = ToolNode( + tools=context_provider_node.tools, + name="context_provider_tools", + messages_key="context_provider_messages", + ) + context_selection_node = ContextSelectionNode(model) + reset_context_provider_messages_node = ResetMessagesNode("context_provider_messages") + context_refine_node = ContextRefineNode(model, kg) - workflow = StateGraph(ContextRetrievalState) + workflow = StateGraph(ContextRetrievalState) - workflow.add_node("context_query_message_node", context_query_message_node) - workflow.add_node("context_provider_node", context_provider_node) - workflow.add_node("context_provider_tools", context_provider_tools) - 
workflow.add_node("context_selection_node", context_selection_node) - workflow.add_node("reset_context_provider_messages_node", reset_context_provider_messages_node) - workflow.add_node("context_refine_node", context_refine_node) + workflow.add_node("context_query_message_node", context_query_message_node) + workflow.add_node("context_provider_node", context_provider_node) + workflow.add_node("context_provider_tools", context_provider_tools) + workflow.add_node("context_selection_node", context_selection_node) + workflow.add_node( + "reset_context_provider_messages_node", reset_context_provider_messages_node + ) + workflow.add_node("context_refine_node", context_refine_node) - workflow.set_entry_point("context_query_message_node") - workflow.add_edge("context_query_message_node", "context_provider_node") - workflow.add_conditional_edges( - "context_provider_node", - functools.partial(tools_condition, messages_key="context_provider_messages"), - {"tools": "context_provider_tools", END: "context_selection_node"}, - ) - workflow.add_edge("context_provider_tools", "context_provider_node") - workflow.add_edge("context_selection_node", "reset_context_provider_messages_node") - workflow.add_edge("reset_context_provider_messages_node", "context_refine_node") - workflow.add_conditional_edges( - "context_refine_node", - lambda state: bool(state["refined_query"]), - {True: "context_provider_node", False: END}, - ) + workflow.set_entry_point("context_query_message_node") + workflow.add_edge("context_query_message_node", "context_provider_node") + workflow.add_conditional_edges( + "context_provider_node", + functools.partial(tools_condition, messages_key="context_provider_messages"), + {"tools": "context_provider_tools", END: "context_selection_node"}, + ) + workflow.add_edge("context_provider_tools", "context_provider_node") + workflow.add_edge("context_selection_node", "reset_context_provider_messages_node") + workflow.add_edge("reset_context_provider_messages_node", 
"context_refine_node") + workflow.add_conditional_edges( + "context_refine_node", + lambda state: bool(state["refined_query"]), + {True: "context_provider_node", False: END}, + ) - self.subgraph = workflow.compile() + self.subgraph = workflow.compile() - def invoke( - self, query: str, max_refined_query_loop: int, recursion_limit: int = 999 - ) -> Sequence[str]: - config = {"recursion_limit": recursion_limit} + def invoke( + self, query: str, max_refined_query_loop: int, recursion_limit: int = 999 + ) -> Sequence[str]: + config = {"recursion_limit": recursion_limit} - input_state = {"query": query, "max_refined_query_loop": max_refined_query_loop} + input_state = {"query": query, "max_refined_query_loop": max_refined_query_loop} - output_state = self.subgraph.invoke(input_state, config) - return {"context": output_state["context"]} + output_state = self.subgraph.invoke(input_state, config) + return {"context": output_state["context"]} diff --git a/prometheus/lang_graph/subgraphs/issue_bug_state.py b/prometheus/lang_graph/subgraphs/issue_bug_state.py index 4195eed1..38cef40c 100644 --- a/prometheus/lang_graph/subgraphs/issue_bug_state.py +++ b/prometheus/lang_graph/subgraphs/issue_bug_state.py @@ -2,21 +2,21 @@ class IssueBugState(TypedDict): - issue_title: str - issue_body: str - issue_comments: Sequence[Mapping[str, str]] - run_build: bool - run_existing_test: bool - number_of_candidate_patch: int + issue_title: str + issue_body: str + issue_comments: Sequence[Mapping[str, str]] + run_build: bool + run_existing_test: bool + number_of_candidate_patch: int - reproduced_bug: bool - reproduced_bug_file: str - reproduced_bug_commands: Sequence[str] + reproduced_bug: bool + reproduced_bug_file: str + reproduced_bug_commands: Sequence[str] - edit_patch: str + edit_patch: str - passed_reproducing_test: bool - passed_build: bool - passed_existing_test: bool + passed_reproducing_test: bool + passed_build: bool + passed_existing_test: bool - issue_response: str + 
issue_response: str diff --git a/prometheus/lang_graph/subgraphs/issue_bug_subgraph.py b/prometheus/lang_graph/subgraphs/issue_bug_subgraph.py index e88d8d41..086cd8af 100644 --- a/prometheus/lang_graph/subgraphs/issue_bug_subgraph.py +++ b/prometheus/lang_graph/subgraphs/issue_bug_subgraph.py @@ -10,112 +10,119 @@ from prometheus.lang_graph.nodes.bug_reproduction_subgraph_node import BugReproductionSubgraphNode from prometheus.lang_graph.nodes.issue_bug_responder_node import IssueBugResponderNode from prometheus.lang_graph.nodes.issue_not_verified_bug_subgraph_node import ( - IssueNotVerifiedBugSubgraphNode, + IssueNotVerifiedBugSubgraphNode, ) from prometheus.lang_graph.nodes.issue_verified_bug_subgraph_node import ( - IssueVerifiedBugSubgraphNode, + IssueVerifiedBugSubgraphNode, ) from prometheus.lang_graph.subgraphs.issue_bug_state import IssueBugState class IssueBugSubgraph: - def __init__( - self, - advanced_model: BaseChatModel, - base_model: BaseChatModel, - container: BaseContainer, - kg: KnowledgeGraph, - git_repo: GitRepository, - neo4j_driver: neo4j.Driver, - max_token_per_neo4j_result: int, - build_commands: Optional[Sequence[str]] = None, - test_commands: Optional[Sequence[str]] = None, - ): - bug_reproduction_subgraph_node = BugReproductionSubgraphNode( - advanced_model=advanced_model, - base_model=base_model, - container=container, - kg=kg, - git_repo=git_repo, - neo4j_driver=neo4j_driver, - max_token_per_neo4j_result=max_token_per_neo4j_result, - test_commands=test_commands, - ) - - issue_verified_bug_subgraph_node = IssueVerifiedBugSubgraphNode( - advanced_model=advanced_model, - base_model=base_model, - container=container, - kg=kg, - git_repo=git_repo, - neo4j_driver=neo4j_driver, - max_token_per_neo4j_result=max_token_per_neo4j_result, - build_commands=build_commands, - test_commands=test_commands, - ) - issue_not_verified_bug_subgraph_node = IssueNotVerifiedBugSubgraphNode( - advanced_model=advanced_model, - base_model=base_model, - kg=kg, - 
git_repo=git_repo, - neo4j_driver=neo4j_driver, - max_token_per_neo4j_result=max_token_per_neo4j_result, - ) - - issue_bug_responder_node = IssueBugResponderNode(base_model) - - workflow = StateGraph(IssueBugState) - - workflow.add_node("bug_reproduction_subgraph_node", bug_reproduction_subgraph_node) - - workflow.add_node("issue_verified_bug_subgraph_node", issue_verified_bug_subgraph_node) - workflow.add_node("issue_not_verified_bug_subgraph_node", issue_not_verified_bug_subgraph_node) - - workflow.add_node("issue_bug_responder_node", issue_bug_responder_node) - - workflow.set_entry_point("bug_reproduction_subgraph_node") - - workflow.add_conditional_edges( - "bug_reproduction_subgraph_node", - lambda state: state["reproduced_bug"] or state["run_build"] or state["run_existing_test"], - {True: "issue_verified_bug_subgraph_node", False: "issue_not_verified_bug_subgraph_node"}, - ) - workflow.add_conditional_edges( - "issue_verified_bug_subgraph_node", - lambda state: bool(state["edit_patch"]), - {True: "issue_bug_responder_node", False: "issue_not_verified_bug_subgraph_node"}, - ) - workflow.add_edge("issue_not_verified_bug_subgraph_node", "issue_bug_responder_node") - workflow.add_edge("issue_bug_responder_node", END) - - self.subgraph = workflow.compile() - - def invoke( - self, - issue_title: str, - issue_body: str, - issue_comments: Sequence[Mapping[str, str]], - run_build: bool, - run_existing_test: bool, - number_of_candidate_patch: int, - recursion_limit: int = 30, - ): - config = {"recursion_limit": recursion_limit} - - input_state = { - "issue_title": issue_title, - "issue_body": issue_body, - "issue_comments": issue_comments, - "run_build": run_build, - "run_existing_test": run_existing_test, - "number_of_candidate_patch": number_of_candidate_patch, - } - - output_state = self.subgraph.invoke(input_state, config) - return { - "edit_patch": output_state["edit_patch"], - "passed_reproducing_test": output_state["passed_reproducing_test"], - "passed_build": 
output_state["passed_build"], - "passed_existing_test": output_state["passed_existing_test"], - "issue_response": output_state["issue_response"], - } + def __init__( + self, + advanced_model: BaseChatModel, + base_model: BaseChatModel, + container: BaseContainer, + kg: KnowledgeGraph, + git_repo: GitRepository, + neo4j_driver: neo4j.Driver, + max_token_per_neo4j_result: int, + build_commands: Optional[Sequence[str]] = None, + test_commands: Optional[Sequence[str]] = None, + ): + bug_reproduction_subgraph_node = BugReproductionSubgraphNode( + advanced_model=advanced_model, + base_model=base_model, + container=container, + kg=kg, + git_repo=git_repo, + neo4j_driver=neo4j_driver, + max_token_per_neo4j_result=max_token_per_neo4j_result, + test_commands=test_commands, + ) + + issue_verified_bug_subgraph_node = IssueVerifiedBugSubgraphNode( + advanced_model=advanced_model, + base_model=base_model, + container=container, + kg=kg, + git_repo=git_repo, + neo4j_driver=neo4j_driver, + max_token_per_neo4j_result=max_token_per_neo4j_result, + build_commands=build_commands, + test_commands=test_commands, + ) + issue_not_verified_bug_subgraph_node = IssueNotVerifiedBugSubgraphNode( + advanced_model=advanced_model, + base_model=base_model, + kg=kg, + git_repo=git_repo, + neo4j_driver=neo4j_driver, + max_token_per_neo4j_result=max_token_per_neo4j_result, + ) + + issue_bug_responder_node = IssueBugResponderNode(base_model) + + workflow = StateGraph(IssueBugState) + + workflow.add_node("bug_reproduction_subgraph_node", bug_reproduction_subgraph_node) + + workflow.add_node("issue_verified_bug_subgraph_node", issue_verified_bug_subgraph_node) + workflow.add_node( + "issue_not_verified_bug_subgraph_node", issue_not_verified_bug_subgraph_node + ) + + workflow.add_node("issue_bug_responder_node", issue_bug_responder_node) + + workflow.set_entry_point("bug_reproduction_subgraph_node") + + workflow.add_conditional_edges( + "bug_reproduction_subgraph_node", + lambda state: 
state["reproduced_bug"] + or state["run_build"] + or state["run_existing_test"], + { + True: "issue_verified_bug_subgraph_node", + False: "issue_not_verified_bug_subgraph_node", + }, + ) + workflow.add_conditional_edges( + "issue_verified_bug_subgraph_node", + lambda state: bool(state["edit_patch"]), + {True: "issue_bug_responder_node", False: "issue_not_verified_bug_subgraph_node"}, + ) + workflow.add_edge("issue_not_verified_bug_subgraph_node", "issue_bug_responder_node") + workflow.add_edge("issue_bug_responder_node", END) + + self.subgraph = workflow.compile() + + def invoke( + self, + issue_title: str, + issue_body: str, + issue_comments: Sequence[Mapping[str, str]], + run_build: bool, + run_existing_test: bool, + number_of_candidate_patch: int, + recursion_limit: int = 30, + ): + config = {"recursion_limit": recursion_limit} + + input_state = { + "issue_title": issue_title, + "issue_body": issue_body, + "issue_comments": issue_comments, + "run_build": run_build, + "run_existing_test": run_existing_test, + "number_of_candidate_patch": number_of_candidate_patch, + } + + output_state = self.subgraph.invoke(input_state, config) + return { + "edit_patch": output_state["edit_patch"], + "passed_reproducing_test": output_state["passed_reproducing_test"], + "passed_build": output_state["passed_build"], + "passed_existing_test": output_state["passed_existing_test"], + "issue_response": output_state["issue_response"], + } diff --git a/prometheus/lang_graph/subgraphs/issue_classification_state.py b/prometheus/lang_graph/subgraphs/issue_classification_state.py index ad99ad8d..bd4e2741 100644 --- a/prometheus/lang_graph/subgraphs/issue_classification_state.py +++ b/prometheus/lang_graph/subgraphs/issue_classification_state.py @@ -2,15 +2,15 @@ class IssueClassificationState(TypedDict): - # Attributes provided by the user - issue_title: str - issue_body: str - issue_comments: Sequence[Mapping[str, str]] + # Attributes provided by the user + issue_title: str + issue_body: 
str + issue_comments: Sequence[Mapping[str, str]] - max_refined_query_loop: int + max_refined_query_loop: int - # Attributes generated and by the subgraph - issue_classification_query: str - issue_classification_context: Sequence[str] + # Attributes generated and by the subgraph + issue_classification_query: str + issue_classification_context: Sequence[str] - issue_type: str + issue_type: str diff --git a/prometheus/lang_graph/subgraphs/issue_classification_subgraph.py b/prometheus/lang_graph/subgraphs/issue_classification_subgraph.py index 9b187841..c6b2b73f 100644 --- a/prometheus/lang_graph/subgraphs/issue_classification_subgraph.py +++ b/prometheus/lang_graph/subgraphs/issue_classification_subgraph.py @@ -7,64 +7,64 @@ from prometheus.graph.knowledge_graph import KnowledgeGraph from prometheus.lang_graph.nodes.context_retrieval_subgraph_node import ContextRetrievalSubgraphNode from prometheus.lang_graph.nodes.issue_classification_context_message_node import ( - IssueClassificationContextMessageNode, + IssueClassificationContextMessageNode, ) from prometheus.lang_graph.nodes.issue_classifier_node import IssueClassifierNode from prometheus.lang_graph.subgraphs.issue_classification_state import IssueClassificationState class IssueClassificationSubgraph: - def __init__( - self, - model: BaseChatModel, - kg: KnowledgeGraph, - neo4j_driver: neo4j.Driver, - max_token_per_neo4j_result: int, - ): - issue_classification_context_message_node = IssueClassificationContextMessageNode() - context_retrieval_subgraph_node = ContextRetrievalSubgraphNode( - model=model, - kg=kg, - neo4j_driver=neo4j_driver, - max_token_per_neo4j_result=max_token_per_neo4j_result, - query_key_name="issue_classification_query", - context_key_name="issue_classification_context", - ) - issue_classifier_node = IssueClassifierNode(model) + def __init__( + self, + model: BaseChatModel, + kg: KnowledgeGraph, + neo4j_driver: neo4j.Driver, + max_token_per_neo4j_result: int, + ): + 
issue_classification_context_message_node = IssueClassificationContextMessageNode() + context_retrieval_subgraph_node = ContextRetrievalSubgraphNode( + model=model, + kg=kg, + neo4j_driver=neo4j_driver, + max_token_per_neo4j_result=max_token_per_neo4j_result, + query_key_name="issue_classification_query", + context_key_name="issue_classification_context", + ) + issue_classifier_node = IssueClassifierNode(model) - workflow = StateGraph(IssueClassificationState) - workflow.add_node( - "issue_classification_context_message_node", issue_classification_context_message_node - ) - workflow.add_node("context_retrieval_subgraph_node", context_retrieval_subgraph_node) - workflow.add_node("issue_classifier_node", issue_classifier_node) + workflow = StateGraph(IssueClassificationState) + workflow.add_node( + "issue_classification_context_message_node", issue_classification_context_message_node + ) + workflow.add_node("context_retrieval_subgraph_node", context_retrieval_subgraph_node) + workflow.add_node("issue_classifier_node", issue_classifier_node) - workflow.set_entry_point("issue_classification_context_message_node") - workflow.add_edge( - "issue_classification_context_message_node", "context_retrieval_subgraph_node" - ) - workflow.add_edge("context_retrieval_subgraph_node", "issue_classifier_node") - workflow.add_edge("issue_classifier_node", END) + workflow.set_entry_point("issue_classification_context_message_node") + workflow.add_edge( + "issue_classification_context_message_node", "context_retrieval_subgraph_node" + ) + workflow.add_edge("context_retrieval_subgraph_node", "issue_classifier_node") + workflow.add_edge("issue_classifier_node", END) - self.subgraph = workflow.compile() + self.subgraph = workflow.compile() - def invoke( - self, - issue_title: str, - issue_body: str, - issue_comments: Sequence[Mapping[str, str]], - ) -> str: - config = None + def invoke( + self, + issue_title: str, + issue_body: str, + issue_comments: Sequence[Mapping[str, str]], + ) -> 
str: + config = None - input_state = { - "issue_title": issue_title, - "issue_body": issue_body, - "issue_comments": issue_comments, - "max_refined_query_loop": 1, - } + input_state = { + "issue_title": issue_title, + "issue_body": issue_body, + "issue_comments": issue_comments, + "max_refined_query_loop": 1, + } - output_state = self.subgraph.invoke( - input_state, - config, - ) - return output_state["issue_type"] + output_state = self.subgraph.invoke( + input_state, + config, + ) + return output_state["issue_type"] diff --git a/prometheus/lang_graph/subgraphs/issue_not_verified_bug_state.py b/prometheus/lang_graph/subgraphs/issue_not_verified_bug_state.py index 668be190..d7d8cd1d 100644 --- a/prometheus/lang_graph/subgraphs/issue_not_verified_bug_state.py +++ b/prometheus/lang_graph/subgraphs/issue_not_verified_bug_state.py @@ -6,19 +6,19 @@ class IssueNotVerifiedBugState(TypedDict): - issue_title: str - issue_body: str - issue_comments: Sequence[Mapping[str, str]] - number_of_candidate_patch: int + issue_title: str + issue_body: str + issue_comments: Sequence[Mapping[str, str]] + number_of_candidate_patch: int - max_refined_query_loop: int + max_refined_query_loop: int - bug_fix_query: str - bug_fix_context: Sequence[str] + bug_fix_query: str + bug_fix_context: Sequence[str] - issue_bug_analyzer_messages: Annotated[Sequence[BaseMessage], add_messages] - edit_messages: Annotated[Sequence[BaseMessage], add_messages] + issue_bug_analyzer_messages: Annotated[Sequence[BaseMessage], add_messages] + edit_messages: Annotated[Sequence[BaseMessage], add_messages] - edit_patches: Annotated[Sequence[str], add] + edit_patches: Annotated[Sequence[str], add] - final_patch: str + final_patch: str diff --git a/prometheus/lang_graph/subgraphs/issue_not_verified_bug_subgraph.py b/prometheus/lang_graph/subgraphs/issue_not_verified_bug_subgraph.py index e81a35c0..6346ea29 100644 --- a/prometheus/lang_graph/subgraphs/issue_not_verified_bug_subgraph.py +++ 
b/prometheus/lang_graph/subgraphs/issue_not_verified_bug_subgraph.py @@ -22,111 +22,111 @@ class IssueNotVerifiedBugSubgraph: - def __init__( - self, - advanced_model: BaseChatModel, - base_model: BaseChatModel, - kg: KnowledgeGraph, - git_repo: GitRepository, - neo4j_driver: neo4j.Driver, - max_token_per_neo4j_result: int, - ): - issue_bug_context_message_node = IssueBugContextMessageNode() - context_retrieval_subgraph_node = ContextRetrievalSubgraphNode( - model=base_model, - kg=kg, - neo4j_driver=neo4j_driver, - max_token_per_neo4j_result=max_token_per_neo4j_result, - query_key_name="bug_fix_query", - context_key_name="bug_fix_context", - ) - - issue_bug_analyzer_message_node = IssueBugAnalyzerMessageNode() - issue_bug_analyzer_node = IssueBugAnalyzerNode(advanced_model) - - edit_message_node = EditMessageNode() - edit_node = EditNode(advanced_model, kg) - edit_tools = ToolNode( - tools=edit_node.tools, - name="edit_tools", - messages_key="edit_messages", - ) - git_diff_node = GitDiffNode(git_repo, "edit_patches", return_list=True) - - git_reset_node = GitResetNode(git_repo) - reset_issue_bug_analyzer_messages_node = ResetMessagesNode("issue_bug_analyzer_messages") - reset_edit_messages_node = ResetMessagesNode("edit_messages") - - final_patch_selection_node = FinalPatchSelectionNode(advanced_model) - - workflow = StateGraph(IssueNotVerifiedBugState) - - workflow.add_node("issue_bug_context_message_node", issue_bug_context_message_node) - workflow.add_node("context_retrieval_subgraph_node", context_retrieval_subgraph_node) - - workflow.add_node("issue_bug_analyzer_message_node", issue_bug_analyzer_message_node) - workflow.add_node("issue_bug_analyzer_node", issue_bug_analyzer_node) - - workflow.add_node("edit_message_node", edit_message_node) - workflow.add_node("edit_node", edit_node) - workflow.add_node("edit_tools", edit_tools) - workflow.add_node("git_diff_node", git_diff_node) - - workflow.add_node("git_reset_node", git_reset_node) - workflow.add_node( - 
"reset_issue_bug_analyzer_messages_node", reset_issue_bug_analyzer_messages_node - ) - workflow.add_node("reset_edit_messages_node", reset_edit_messages_node) - - workflow.add_node("final_patch_selection_node", final_patch_selection_node) - - workflow.set_entry_point("issue_bug_context_message_node") - workflow.add_edge("issue_bug_context_message_node", "context_retrieval_subgraph_node") - workflow.add_edge("context_retrieval_subgraph_node", "issue_bug_analyzer_message_node") - workflow.add_edge("issue_bug_analyzer_message_node", "issue_bug_analyzer_node") - workflow.add_edge("issue_bug_analyzer_node", "edit_message_node") - - workflow.add_edge("edit_message_node", "edit_node") - workflow.add_conditional_edges( - "edit_node", - functools.partial(tools_condition, messages_key="edit_messages"), - {"tools": "edit_tools", END: "git_diff_node"}, - ) - workflow.add_edge("edit_tools", "edit_node") - - workflow.add_conditional_edges( - "git_diff_node", - lambda state: len(state["edit_patches"]) < state["number_of_candidate_patch"], - {True: "git_reset_node", False: "final_patch_selection_node"}, - ) - - workflow.add_edge("git_reset_node", "reset_issue_bug_analyzer_messages_node") - workflow.add_edge("reset_issue_bug_analyzer_messages_node", "reset_edit_messages_node") - workflow.add_edge("reset_edit_messages_node", "issue_bug_analyzer_message_node") - - workflow.add_edge("final_patch_selection_node", END) - - self.subgraph = workflow.compile() - - def invoke( - self, - issue_title: str, - issue_body: str, - issue_comments: Sequence[Mapping[str, str]], - number_of_candidate_patch: int, - recursion_limit: int = 999, - ): - config = {"recursion_limit": recursion_limit} - - input_state = { - "issue_title": issue_title, - "issue_body": issue_body, - "issue_comments": issue_comments, - "number_of_candidate_patch": number_of_candidate_patch, - "max_refined_query_loop": 3, - } - - output_state = self.subgraph.invoke(input_state, config) - return { - "final_patch": 
output_state["final_patch"], - } + def __init__( + self, + advanced_model: BaseChatModel, + base_model: BaseChatModel, + kg: KnowledgeGraph, + git_repo: GitRepository, + neo4j_driver: neo4j.Driver, + max_token_per_neo4j_result: int, + ): + issue_bug_context_message_node = IssueBugContextMessageNode() + context_retrieval_subgraph_node = ContextRetrievalSubgraphNode( + model=base_model, + kg=kg, + neo4j_driver=neo4j_driver, + max_token_per_neo4j_result=max_token_per_neo4j_result, + query_key_name="bug_fix_query", + context_key_name="bug_fix_context", + ) + + issue_bug_analyzer_message_node = IssueBugAnalyzerMessageNode() + issue_bug_analyzer_node = IssueBugAnalyzerNode(advanced_model) + + edit_message_node = EditMessageNode() + edit_node = EditNode(advanced_model, kg) + edit_tools = ToolNode( + tools=edit_node.tools, + name="edit_tools", + messages_key="edit_messages", + ) + git_diff_node = GitDiffNode(git_repo, "edit_patches", return_list=True) + + git_reset_node = GitResetNode(git_repo) + reset_issue_bug_analyzer_messages_node = ResetMessagesNode("issue_bug_analyzer_messages") + reset_edit_messages_node = ResetMessagesNode("edit_messages") + + final_patch_selection_node = FinalPatchSelectionNode(advanced_model) + + workflow = StateGraph(IssueNotVerifiedBugState) + + workflow.add_node("issue_bug_context_message_node", issue_bug_context_message_node) + workflow.add_node("context_retrieval_subgraph_node", context_retrieval_subgraph_node) + + workflow.add_node("issue_bug_analyzer_message_node", issue_bug_analyzer_message_node) + workflow.add_node("issue_bug_analyzer_node", issue_bug_analyzer_node) + + workflow.add_node("edit_message_node", edit_message_node) + workflow.add_node("edit_node", edit_node) + workflow.add_node("edit_tools", edit_tools) + workflow.add_node("git_diff_node", git_diff_node) + + workflow.add_node("git_reset_node", git_reset_node) + workflow.add_node( + "reset_issue_bug_analyzer_messages_node", reset_issue_bug_analyzer_messages_node + ) + 
workflow.add_node("reset_edit_messages_node", reset_edit_messages_node) + + workflow.add_node("final_patch_selection_node", final_patch_selection_node) + + workflow.set_entry_point("issue_bug_context_message_node") + workflow.add_edge("issue_bug_context_message_node", "context_retrieval_subgraph_node") + workflow.add_edge("context_retrieval_subgraph_node", "issue_bug_analyzer_message_node") + workflow.add_edge("issue_bug_analyzer_message_node", "issue_bug_analyzer_node") + workflow.add_edge("issue_bug_analyzer_node", "edit_message_node") + + workflow.add_edge("edit_message_node", "edit_node") + workflow.add_conditional_edges( + "edit_node", + functools.partial(tools_condition, messages_key="edit_messages"), + {"tools": "edit_tools", END: "git_diff_node"}, + ) + workflow.add_edge("edit_tools", "edit_node") + + workflow.add_conditional_edges( + "git_diff_node", + lambda state: len(state["edit_patches"]) < state["number_of_candidate_patch"], + {True: "git_reset_node", False: "final_patch_selection_node"}, + ) + + workflow.add_edge("git_reset_node", "reset_issue_bug_analyzer_messages_node") + workflow.add_edge("reset_issue_bug_analyzer_messages_node", "reset_edit_messages_node") + workflow.add_edge("reset_edit_messages_node", "issue_bug_analyzer_message_node") + + workflow.add_edge("final_patch_selection_node", END) + + self.subgraph = workflow.compile() + + def invoke( + self, + issue_title: str, + issue_body: str, + issue_comments: Sequence[Mapping[str, str]], + number_of_candidate_patch: int, + recursion_limit: int = 999, + ): + config = {"recursion_limit": recursion_limit} + + input_state = { + "issue_title": issue_title, + "issue_body": issue_body, + "issue_comments": issue_comments, + "number_of_candidate_patch": number_of_candidate_patch, + "max_refined_query_loop": 3, + } + + output_state = self.subgraph.invoke(input_state, config) + return { + "final_patch": output_state["final_patch"], + } diff --git 
a/prometheus/lang_graph/subgraphs/issue_verified_bug_state.py b/prometheus/lang_graph/subgraphs/issue_verified_bug_state.py index a3fea66b..227f8291 100644 --- a/prometheus/lang_graph/subgraphs/issue_verified_bug_state.py +++ b/prometheus/lang_graph/subgraphs/issue_verified_bug_state.py @@ -5,33 +5,33 @@ class IssueVerifiedBugState(TypedDict): - issue_title: str - issue_body: str - issue_comments: Sequence[Mapping[str, str]] - run_build: bool - run_existing_test: bool + issue_title: str + issue_body: str + issue_comments: Sequence[Mapping[str, str]] + run_build: bool + run_existing_test: bool - reproduced_bug_file: str - reproduced_bug_commands: Sequence[str] - max_refined_query_loop: int + reproduced_bug_file: str + reproduced_bug_commands: Sequence[str] + max_refined_query_loop: int - bug_fix_query: str - bug_fix_context: Sequence[str] + bug_fix_query: str + bug_fix_context: Sequence[str] - issue_bug_analyzer_messages: Annotated[Sequence[BaseMessage], add_messages] - edit_messages: Annotated[Sequence[BaseMessage], add_messages] + issue_bug_analyzer_messages: Annotated[Sequence[BaseMessage], add_messages] + edit_messages: Annotated[Sequence[BaseMessage], add_messages] - edit_patch: str + edit_patch: str - reproducing_test_fail_log: str + reproducing_test_fail_log: str - exist_build: bool - build_command_summary: str - build_fail_log: str + exist_build: bool + build_command_summary: str + build_fail_log: str - exist_test: bool - test_command_summary: str - existing_test_fail_log: str + exist_test: bool + test_command_summary: str + existing_test_fail_log: str - max_refined_query_loop: int - refined_query: str + max_refined_query_loop: int + refined_query: str diff --git a/prometheus/lang_graph/subgraphs/issue_verified_bug_subgraph.py b/prometheus/lang_graph/subgraphs/issue_verified_bug_subgraph.py index 43b87db3..8f777848 100644 --- a/prometheus/lang_graph/subgraphs/issue_verified_bug_subgraph.py +++ b/prometheus/lang_graph/subgraphs/issue_verified_bug_subgraph.py 
@@ -10,7 +10,7 @@ from prometheus.git.git_repository import GitRepository from prometheus.graph.knowledge_graph import KnowledgeGraph from prometheus.lang_graph.nodes.bug_fix_verification_subgraph_node import ( - BugFixVerificationSubgraphNode, + BugFixVerificationSubgraphNode, ) from prometheus.lang_graph.nodes.build_and_test_subgraph_node import BuildAndTestSubgraphNode from prometheus.lang_graph.nodes.context_retrieval_subgraph_node import ContextRetrievalSubgraphNode @@ -26,137 +26,137 @@ class IssueVerifiedBugSubgraph: - def __init__( - self, - advanced_model: BaseChatModel, - base_model: BaseChatModel, - container: BaseContainer, - kg: KnowledgeGraph, - git_repo: GitRepository, - neo4j_driver: neo4j.Driver, - max_token_per_neo4j_result: int, - build_commands: Optional[Sequence[str]] = None, - test_commands: Optional[Sequence[str]] = None, - ): - issue_bug_context_message_node = IssueBugContextMessageNode() - context_retrieval_subgraph_node = ContextRetrievalSubgraphNode( - model=base_model, - kg=kg, - neo4j_driver=neo4j_driver, - max_token_per_neo4j_result=max_token_per_neo4j_result, - query_key_name="bug_fix_query", - context_key_name="bug_fix_context", - ) - - issue_bug_analyzer_message_node = IssueBugAnalyzerMessageNode() - issue_bug_analyzer_node = IssueBugAnalyzerNode(advanced_model) - - edit_message_node = EditMessageNode() - edit_node = EditNode(advanced_model, kg) - edit_tools = ToolNode( - tools=edit_node.tools, - name="edit_tools", - messages_key="edit_messages", - ) - git_diff_node = GitDiffNode(git_repo, "edit_patch", "reproduced_bug_file") - update_container_node = UpdateContainerNode(container, git_repo) - - bug_fix_verification_subgraph_node = BugFixVerificationSubgraphNode( - base_model, - container, - ) - build_or_test_branch_node = NoopNode() - build_and_test_subgraph_node = BuildAndTestSubgraphNode( - container, - advanced_model, - kg, - build_commands, - test_commands, - ) - - workflow = StateGraph(IssueVerifiedBugState) - - 
workflow.add_node("issue_bug_context_message_node", issue_bug_context_message_node) - workflow.add_node("context_retrieval_subgraph_node", context_retrieval_subgraph_node) - - workflow.add_node("issue_bug_analyzer_message_node", issue_bug_analyzer_message_node) - workflow.add_node("issue_bug_analyzer_node", issue_bug_analyzer_node) - - workflow.add_node("edit_message_node", edit_message_node) - workflow.add_node("edit_node", edit_node) - workflow.add_node("edit_tools", edit_tools) - workflow.add_node("git_diff_node", git_diff_node) - workflow.add_node("update_container_node", update_container_node) - - workflow.add_node("bug_fix_verification_subgraph_node", bug_fix_verification_subgraph_node) - workflow.add_node("build_or_test_branch_node", build_or_test_branch_node) - workflow.add_node("build_and_test_subgraph_node", build_and_test_subgraph_node) - - workflow.set_entry_point("issue_bug_context_message_node") - workflow.add_edge("issue_bug_context_message_node", "context_retrieval_subgraph_node") - workflow.add_edge("context_retrieval_subgraph_node", "issue_bug_analyzer_message_node") - - workflow.add_edge("issue_bug_analyzer_message_node", "issue_bug_analyzer_node") - workflow.add_edge("issue_bug_analyzer_node", "edit_message_node") - - workflow.add_edge("edit_message_node", "edit_node") - workflow.add_conditional_edges( - "edit_node", - functools.partial(tools_condition, messages_key="edit_messages"), - {"tools": "edit_tools", END: "git_diff_node"}, - ) - workflow.add_edge("edit_tools", "edit_node") - workflow.add_edge("git_diff_node", "update_container_node") - workflow.add_edge("update_container_node", "bug_fix_verification_subgraph_node") - - workflow.add_conditional_edges( - "bug_fix_verification_subgraph_node", - lambda state: bool(state["reproducing_test_fail_log"]), - {True: "issue_bug_analyzer_message_node", False: "build_or_test_branch_node"}, - ) - workflow.add_conditional_edges( - "build_or_test_branch_node", - lambda state: state["run_build"] or 
state["run_existing_test"], - {True: "build_and_test_subgraph_node", False: END}, - ) - workflow.add_conditional_edges( - "build_and_test_subgraph_node", - lambda state: bool(state["build_fail_log"]) or bool(state["existing_test_fail_log"]), - {True: "issue_bug_analyzer_message_node", False: END}, - ) - - self.subgraph = workflow.compile() - - def invoke( - self, - issue_title: str, - issue_body: str, - issue_comments: Sequence[Mapping[str, str]], - run_build: bool, - run_existing_test: bool, - reproduced_bug_file: str, - reproduced_bug_commands: Sequence[str], - recursion_limit: int = 80, - ): - config = {"recursion_limit": recursion_limit} - - input_state = { - "issue_title": issue_title, - "issue_body": issue_body, - "issue_comments": issue_comments, - "run_build": run_build, - "run_existing_test": run_existing_test, - "reproduced_bug_file": reproduced_bug_file, - "reproduced_bug_commands": reproduced_bug_commands, - "max_refined_query_loop": 3, - } - - output_state = self.subgraph.invoke(input_state, config) - return { - "edit_patch": output_state["edit_patch"], - "reproducing_test_fail_log": output_state["reproducing_test_fail_log"], - "exist_build": output_state.get("exist_build", False), - "build_fail_log": output_state.get("build_fail_log", ""), - "exist_test": output_state.get("exist_test", False), - "existing_test_fail_log": output_state.get("existing_test_fail_log", ""), - } + def __init__( + self, + advanced_model: BaseChatModel, + base_model: BaseChatModel, + container: BaseContainer, + kg: KnowledgeGraph, + git_repo: GitRepository, + neo4j_driver: neo4j.Driver, + max_token_per_neo4j_result: int, + build_commands: Optional[Sequence[str]] = None, + test_commands: Optional[Sequence[str]] = None, + ): + issue_bug_context_message_node = IssueBugContextMessageNode() + context_retrieval_subgraph_node = ContextRetrievalSubgraphNode( + model=base_model, + kg=kg, + neo4j_driver=neo4j_driver, + max_token_per_neo4j_result=max_token_per_neo4j_result, + 
query_key_name="bug_fix_query", + context_key_name="bug_fix_context", + ) + + issue_bug_analyzer_message_node = IssueBugAnalyzerMessageNode() + issue_bug_analyzer_node = IssueBugAnalyzerNode(advanced_model) + + edit_message_node = EditMessageNode() + edit_node = EditNode(advanced_model, kg) + edit_tools = ToolNode( + tools=edit_node.tools, + name="edit_tools", + messages_key="edit_messages", + ) + git_diff_node = GitDiffNode(git_repo, "edit_patch", "reproduced_bug_file") + update_container_node = UpdateContainerNode(container, git_repo) + + bug_fix_verification_subgraph_node = BugFixVerificationSubgraphNode( + base_model, + container, + ) + build_or_test_branch_node = NoopNode() + build_and_test_subgraph_node = BuildAndTestSubgraphNode( + container, + advanced_model, + kg, + build_commands, + test_commands, + ) + + workflow = StateGraph(IssueVerifiedBugState) + + workflow.add_node("issue_bug_context_message_node", issue_bug_context_message_node) + workflow.add_node("context_retrieval_subgraph_node", context_retrieval_subgraph_node) + + workflow.add_node("issue_bug_analyzer_message_node", issue_bug_analyzer_message_node) + workflow.add_node("issue_bug_analyzer_node", issue_bug_analyzer_node) + + workflow.add_node("edit_message_node", edit_message_node) + workflow.add_node("edit_node", edit_node) + workflow.add_node("edit_tools", edit_tools) + workflow.add_node("git_diff_node", git_diff_node) + workflow.add_node("update_container_node", update_container_node) + + workflow.add_node("bug_fix_verification_subgraph_node", bug_fix_verification_subgraph_node) + workflow.add_node("build_or_test_branch_node", build_or_test_branch_node) + workflow.add_node("build_and_test_subgraph_node", build_and_test_subgraph_node) + + workflow.set_entry_point("issue_bug_context_message_node") + workflow.add_edge("issue_bug_context_message_node", "context_retrieval_subgraph_node") + workflow.add_edge("context_retrieval_subgraph_node", "issue_bug_analyzer_message_node") + + 
workflow.add_edge("issue_bug_analyzer_message_node", "issue_bug_analyzer_node") + workflow.add_edge("issue_bug_analyzer_node", "edit_message_node") + + workflow.add_edge("edit_message_node", "edit_node") + workflow.add_conditional_edges( + "edit_node", + functools.partial(tools_condition, messages_key="edit_messages"), + {"tools": "edit_tools", END: "git_diff_node"}, + ) + workflow.add_edge("edit_tools", "edit_node") + workflow.add_edge("git_diff_node", "update_container_node") + workflow.add_edge("update_container_node", "bug_fix_verification_subgraph_node") + + workflow.add_conditional_edges( + "bug_fix_verification_subgraph_node", + lambda state: bool(state["reproducing_test_fail_log"]), + {True: "issue_bug_analyzer_message_node", False: "build_or_test_branch_node"}, + ) + workflow.add_conditional_edges( + "build_or_test_branch_node", + lambda state: state["run_build"] or state["run_existing_test"], + {True: "build_and_test_subgraph_node", False: END}, + ) + workflow.add_conditional_edges( + "build_and_test_subgraph_node", + lambda state: bool(state["build_fail_log"]) or bool(state["existing_test_fail_log"]), + {True: "issue_bug_analyzer_message_node", False: END}, + ) + + self.subgraph = workflow.compile() + + def invoke( + self, + issue_title: str, + issue_body: str, + issue_comments: Sequence[Mapping[str, str]], + run_build: bool, + run_existing_test: bool, + reproduced_bug_file: str, + reproduced_bug_commands: Sequence[str], + recursion_limit: int = 80, + ): + config = {"recursion_limit": recursion_limit} + + input_state = { + "issue_title": issue_title, + "issue_body": issue_body, + "issue_comments": issue_comments, + "run_build": run_build, + "run_existing_test": run_existing_test, + "reproduced_bug_file": reproduced_bug_file, + "reproduced_bug_commands": reproduced_bug_commands, + "max_refined_query_loop": 3, + } + + output_state = self.subgraph.invoke(input_state, config) + return { + "edit_patch": output_state["edit_patch"], + 
"reproducing_test_fail_log": output_state["reproducing_test_fail_log"], + "exist_build": output_state.get("exist_build", False), + "build_fail_log": output_state.get("build_fail_log", ""), + "exist_test": output_state.get("exist_test", False), + "existing_test_fail_log": output_state.get("existing_test_fail_log", ""), + } diff --git a/prometheus/neo4j/knowledge_graph_handler.py b/prometheus/neo4j/knowledge_graph_handler.py index 5b774a3a..d6ce9640 100644 --- a/prometheus/neo4j/knowledge_graph_handler.py +++ b/prometheus/neo4j/knowledge_graph_handler.py @@ -6,314 +6,318 @@ from neo4j import GraphDatabase, ManagedTransaction from prometheus.graph.graph_types import ( - KnowledgeGraphNode, - MetadataNode, - Neo4jASTNode, - Neo4jFileNode, - Neo4jHasASTEdge, - Neo4jHasFileEdge, - Neo4jHasTextEdge, - Neo4jMetadataNode, - Neo4jNextChunkEdge, - Neo4jParentOfEdge, - Neo4jTextNode, + KnowledgeGraphNode, + MetadataNode, + Neo4jASTNode, + Neo4jFileNode, + Neo4jHasASTEdge, + Neo4jHasFileEdge, + Neo4jHasTextEdge, + Neo4jMetadataNode, + Neo4jNextChunkEdge, + Neo4jParentOfEdge, + Neo4jTextNode, ) from prometheus.graph.knowledge_graph import KnowledgeGraph class KnowledgeGraphHandler: - """The handler to writing the Knowledge graph to neo4j.""" - - def __init__(self, driver: GraphDatabase.driver, batch_size: int): - """ - Args: - driver: The neo4j driver. - batch_size: The maximum number of nodes/edges written to neo4j each time. - """ - self.driver = driver - self.batch_size = batch_size - - self._logger = logging.getLogger("prometheus.neo4j.knowledge_graph_handler") - - def _init_database(self, tx: ManagedTransaction): - """Initialization of the neo4j database.""" - - # Create constraints for node_id attributes. - # It also means that node_id will be indexed. 
- queries = [ - "CREATE CONSTRAINT unique_file_node_id IF NOT EXISTS " - "FOR (n:FileNode) REQUIRE n.node_id IS UNIQUE", - "CREATE CONSTRAINT unique_ast_node_id IF NOT EXISTS " - "FOR (n:ASTNode) REQUIRE n.node_id IS UNIQUE", - "CREATE CONSTRAINT unique_text_node_id IF NOT EXISTS " - "FOR (n:TextNode) REQUIRE n.node_id IS UNIQUE", - ] - for query in queries: - tx.run(query) - - def _write_meta_node(self, tx: ManagedTransaction, metadata_node: Neo4jMetadataNode): - """Write Neo4jMetadataNode to neo4j.""" - self._logger.debug("Writing MetadataNode to neo4j") - query = """ + """The handler to writing the Knowledge graph to neo4j.""" + + def __init__(self, driver: GraphDatabase.driver, batch_size: int): + """ + Args: + driver: The neo4j driver. + batch_size: The maximum number of nodes/edges written to neo4j each time. + """ + self.driver = driver + self.batch_size = batch_size + + self._logger = logging.getLogger("prometheus.neo4j.knowledge_graph_handler") + + def _init_database(self, tx: ManagedTransaction): + """Initialization of the neo4j database.""" + + # Create constraints for node_id attributes. + # It also means that node_id will be indexed. 
+ queries = [ + "CREATE CONSTRAINT unique_file_node_id IF NOT EXISTS " + "FOR (n:FileNode) REQUIRE n.node_id IS UNIQUE", + "CREATE CONSTRAINT unique_ast_node_id IF NOT EXISTS " + "FOR (n:ASTNode) REQUIRE n.node_id IS UNIQUE", + "CREATE CONSTRAINT unique_text_node_id IF NOT EXISTS " + "FOR (n:TextNode) REQUIRE n.node_id IS UNIQUE", + ] + for query in queries: + tx.run(query) + + def _write_meta_node(self, tx: ManagedTransaction, metadata_node: Neo4jMetadataNode): + """Write Neo4jMetadataNode to neo4j.""" + self._logger.debug("Writing MetadataNode to neo4j") + query = """ CREATE (a:MetadataNode {codebase_source: $codebase_source, local_path: $local_path, https_url: $https_url, commit_id: $commit_id}) """ - tx.run( - query, - codebase_source=metadata_node["codebase_source"], - local_path=metadata_node["local_path"], - https_url=metadata_node["https_url"], - commit_id=metadata_node["commit_id"], - ) - - def _write_file_nodes(self, tx: ManagedTransaction, file_nodes: Sequence[Neo4jFileNode]): - """Write Neo4jFileNode to neo4j.""" - self._logger.debug(f"Writing {len(file_nodes)} FileNode to neo4j") - query = """ + tx.run( + query, + codebase_source=metadata_node["codebase_source"], + local_path=metadata_node["local_path"], + https_url=metadata_node["https_url"], + commit_id=metadata_node["commit_id"], + ) + + def _write_file_nodes(self, tx: ManagedTransaction, file_nodes: Sequence[Neo4jFileNode]): + """Write Neo4jFileNode to neo4j.""" + self._logger.debug(f"Writing {len(file_nodes)} FileNode to neo4j") + query = """ UNWIND $file_nodes AS file_node CREATE (a:FileNode {node_id: file_node.node_id, basename: file_node.basename, relative_path: file_node.relative_path}) """ - for i in range(0, len(file_nodes), self.batch_size): - file_nodes_batch = file_nodes[i : i + self.batch_size] - tx.run(query, file_nodes=file_nodes_batch) - - def _write_ast_nodes(self, tx: ManagedTransaction, ast_nodes: Sequence[Neo4jASTNode]): - """Write Neo4jASTNode to neo4j.""" - 
self._logger.debug(f"Writing {len(ast_nodes)} ASTNode to neo4j") - query = """ + for i in range(0, len(file_nodes), self.batch_size): + file_nodes_batch = file_nodes[i : i + self.batch_size] + tx.run(query, file_nodes=file_nodes_batch) + + def _write_ast_nodes(self, tx: ManagedTransaction, ast_nodes: Sequence[Neo4jASTNode]): + """Write Neo4jASTNode to neo4j.""" + self._logger.debug(f"Writing {len(ast_nodes)} ASTNode to neo4j") + query = """ UNWIND $ast_nodes AS ast_node CREATE (a:ASTNode {node_id: ast_node.node_id, start_line: ast_node.start_line, end_line: ast_node.end_line, type: ast_node.type, text: ast_node.text}) """ - for i in range(0, len(ast_nodes), self.batch_size): - ast_nodes_batch = ast_nodes[i : i + self.batch_size] - tx.run(query, ast_nodes=ast_nodes_batch) - - def _write_text_nodes(self, tx: ManagedTransaction, text_nodes: Sequence[Neo4jTextNode]): - """Write Neo4jTextNode to neo4j.""" - self._logger.debug(f"Writing {len(text_nodes)} TextNode to neo4j") - query = """ + for i in range(0, len(ast_nodes), self.batch_size): + ast_nodes_batch = ast_nodes[i : i + self.batch_size] + tx.run(query, ast_nodes=ast_nodes_batch) + + def _write_text_nodes(self, tx: ManagedTransaction, text_nodes: Sequence[Neo4jTextNode]): + """Write Neo4jTextNode to neo4j.""" + self._logger.debug(f"Writing {len(text_nodes)} TextNode to neo4j") + query = """ UNWIND $text_nodes AS text_node CREATE (a:TextNode {node_id: text_node.node_id, text: text_node.text, metadata: text_node.metadata}) """ - for i in range(0, len(text_nodes), self.batch_size): - text_nodes_batch = text_nodes[i : i + self.batch_size] - tx.run(query, text_nodes=text_nodes_batch) - - def _write_has_file_edges( - self, tx: ManagedTransaction, has_file_edges: Sequence[Neo4jHasFileEdge] - ): - """Write Neo4jHasFileEdge to neo4j.""" - self._logger.debug(f"Writing {len(has_file_edges)} HasFileEdge to neo4j") - query = """ + for i in range(0, len(text_nodes), self.batch_size): + text_nodes_batch = text_nodes[i : i + 
self.batch_size] + tx.run(query, text_nodes=text_nodes_batch) + + def _write_has_file_edges( + self, tx: ManagedTransaction, has_file_edges: Sequence[Neo4jHasFileEdge] + ): + """Write Neo4jHasFileEdge to neo4j.""" + self._logger.debug(f"Writing {len(has_file_edges)} HasFileEdge to neo4j") + query = """ UNWIND $edges AS edge MATCH (source:FileNode), (target:FileNode) WHERE source.node_id = edge.source.node_id AND target.node_id = edge.target.node_id CREATE (source) -[:HAS_FILE]-> (target) """ - for i in range(0, len(has_file_edges), self.batch_size): - has_file_edges_batch = has_file_edges[i : i + self.batch_size] - tx.run(query, edges=has_file_edges_batch) - - def _write_has_ast_edges(self, tx: ManagedTransaction, has_ast_edges: Sequence[Neo4jHasASTEdge]): - """Write Neo4jHasASTEdge to neo4j.""" - self._logger.debug(f"Writing {len(has_ast_edges)} HasASTEdge to neo4j") - query = """ + for i in range(0, len(has_file_edges), self.batch_size): + has_file_edges_batch = has_file_edges[i : i + self.batch_size] + tx.run(query, edges=has_file_edges_batch) + + def _write_has_ast_edges( + self, tx: ManagedTransaction, has_ast_edges: Sequence[Neo4jHasASTEdge] + ): + """Write Neo4jHasASTEdge to neo4j.""" + self._logger.debug(f"Writing {len(has_ast_edges)} HasASTEdge to neo4j") + query = """ UNWIND $edges AS edge MATCH (source:FileNode), (target:ASTNode) WHERE source.node_id = edge.source.node_id AND target.node_id = edge.target.node_id CREATE (source) -[:HAS_AST]-> (target) """ - for i in range(0, len(has_ast_edges), self.batch_size): - has_ast_edges_batch = has_ast_edges[i : i + self.batch_size] - tx.run(query, edges=has_ast_edges_batch) - - def _write_has_text_edges( - self, tx: ManagedTransaction, has_text_edges: Sequence[Neo4jHasTextEdge] - ): - """Write Neo4jHasTextEdge to neo4j.""" - self._logger.debug(f"Writing {len(has_text_edges)} HasTextEdges to neo4j") - query = """ + for i in range(0, len(has_ast_edges), self.batch_size): + has_ast_edges_batch = has_ast_edges[i : i 
+ self.batch_size] + tx.run(query, edges=has_ast_edges_batch) + + def _write_has_text_edges( + self, tx: ManagedTransaction, has_text_edges: Sequence[Neo4jHasTextEdge] + ): + """Write Neo4jHasTextEdge to neo4j.""" + self._logger.debug(f"Writing {len(has_text_edges)} HasTextEdges to neo4j") + query = """ UNWIND $edges AS edge MATCH (source:FileNode), (target:TextNode) WHERE source.node_id = edge.source.node_id AND target.node_id = edge.target.node_id CREATE (source) -[:HAS_TEXT]-> (target) """ - for i in range(0, len(has_text_edges), self.batch_size): - has_text_edges_batch = has_text_edges[i : i + self.batch_size] - tx.run(query, edges=has_text_edges_batch) - - def _write_parent_of_edges( - self, tx: ManagedTransaction, parent_of_edges: Sequence[Neo4jParentOfEdge] - ): - """Write Neo4jParentOfEdge to neo4j.""" - self._logger.debug(f"Writing {len(parent_of_edges)} ParentOfEdge to neo4j") - query = """ + for i in range(0, len(has_text_edges), self.batch_size): + has_text_edges_batch = has_text_edges[i : i + self.batch_size] + tx.run(query, edges=has_text_edges_batch) + + def _write_parent_of_edges( + self, tx: ManagedTransaction, parent_of_edges: Sequence[Neo4jParentOfEdge] + ): + """Write Neo4jParentOfEdge to neo4j.""" + self._logger.debug(f"Writing {len(parent_of_edges)} ParentOfEdge to neo4j") + query = """ UNWIND $edges AS edge MATCH (source:ASTNode), (target:ASTNode) WHERE source.node_id = edge.source.node_id AND target.node_id = edge.target.node_id CREATE (source) -[:PARENT_OF]-> (target) """ - for i in range(0, len(parent_of_edges), self.batch_size): - parent_of_edges_batch = parent_of_edges[i : i + self.batch_size] - tx.run(query, edges=parent_of_edges_batch) - - def _write_next_chunk_edges( - self, tx: ManagedTransaction, next_chunk_edges: Sequence[Neo4jNextChunkEdge] - ): - """Write Neo4jNextChunkEdge to neo4j.""" - self._logger.debug(f"Writing {len(next_chunk_edges)} NextChunkEdge to neo4j") - query = """ + for i in range(0, len(parent_of_edges), 
self.batch_size): + parent_of_edges_batch = parent_of_edges[i : i + self.batch_size] + tx.run(query, edges=parent_of_edges_batch) + + def _write_next_chunk_edges( + self, tx: ManagedTransaction, next_chunk_edges: Sequence[Neo4jNextChunkEdge] + ): + """Write Neo4jNextChunkEdge to neo4j.""" + self._logger.debug(f"Writing {len(next_chunk_edges)} NextChunkEdge to neo4j") + query = """ UNWIND $edges AS edge MATCH (source:TextNode), (target:TextNode) WHERE source.node_id = edge.source.node_id AND target.node_id = edge.target.node_id CREATE (source) -[:NEXT_CHUNK]-> (target) """ - for i in range(0, len(next_chunk_edges), self.batch_size): - next_chunk_edges_batch = next_chunk_edges[i : i + self.batch_size] - tx.run(query, edges=next_chunk_edges_batch) - - def write_knowledge_graph(self, kg: KnowledgeGraph): - """Write the knowledge graph to neo4j. - - Args: - kg: The knowledge graph to write to neo4j. - """ - self._logger.info("Writing knowledge graph to neo4j") - with self.driver.session() as session: - session.execute_write(self._init_database) - - session.execute_write(self._write_meta_node, kg.get_neo4j_metadata_node()) - session.execute_write(self._write_file_nodes, kg.get_neo4j_file_nodes()) - session.execute_write(self._write_ast_nodes, kg.get_neo4j_ast_nodes()) - session.execute_write(self._write_text_nodes, kg.get_neo4j_text_nodes()) - - session.execute_write(self._write_has_ast_edges, kg.get_neo4j_has_ast_edges()) - session.execute_write(self._write_has_file_edges, kg.get_neo4j_has_file_edges()) - session.execute_write(self._write_has_text_edges, kg.get_neo4j_has_text_edges()) - session.execute_write(self._write_next_chunk_edges, kg.get_neo4j_next_chunk_edges()) - session.execute_write(self._write_parent_of_edges, kg.get_neo4j_parent_of_edges()) - - def _read_metadata_node(self, tx: ManagedTransaction) -> MetadataNode: - """Read MetadataNode from neo4j.""" - query = "MATCH (n:MetadataNode) RETURN n.codebase_source AS codebase_source, n.local_path AS local_path, 
n.https_url AS https_url, n.commit_id AS commit_id" - result = tx.run(query) - return [MetadataNode.from_neo4j_metadata_node(record.data()) for record in result][0] - - def _read_file_nodes(self, tx: ManagedTransaction) -> Sequence[KnowledgeGraphNode]: - """Read FileNode from neo4j.""" - query = "MATCH (n:FileNode) RETURN n.node_id AS node_id, n.basename AS basename, n.relative_path AS relative_path" - result = tx.run(query) - return [KnowledgeGraphNode.from_neo4j_file_node(record.data()) for record in result] - - def _read_ast_nodes(self, tx: ManagedTransaction) -> Sequence[KnowledgeGraphNode]: - """Read ASTNode from neo4j.""" - query = "MATCH (n:ASTNode) RETURN n.node_id AS node_id, n.start_line AS start_line, n.end_line AS end_line, n.type AS type, n.text AS text" - result = tx.run(query) - return [KnowledgeGraphNode.from_neo4j_ast_node(record.data()) for record in result] - - def _read_text_nodes(self, tx: ManagedTransaction) -> Sequence[KnowledgeGraphNode]: - """Read TextNode from neo4j.""" - query = "MATCH (n:TextNode) RETURN n.node_id AS node_id, n.text AS text, n.metadata AS metadata" - result = tx.run(query) - return [KnowledgeGraphNode.from_neo4j_text_node(record.data()) for record in result] - - def _read_parent_of_edges(self, tx: ManagedTransaction) -> Sequence[Mapping[str, int]]: - """Read PARENT_OF edges from neo4j.""" - query = """ + for i in range(0, len(next_chunk_edges), self.batch_size): + next_chunk_edges_batch = next_chunk_edges[i : i + self.batch_size] + tx.run(query, edges=next_chunk_edges_batch) + + def write_knowledge_graph(self, kg: KnowledgeGraph): + """Write the knowledge graph to neo4j. + + Args: + kg: The knowledge graph to write to neo4j. 
+ """ + self._logger.info("Writing knowledge graph to neo4j") + with self.driver.session() as session: + session.execute_write(self._init_database) + + session.execute_write(self._write_meta_node, kg.get_neo4j_metadata_node()) + session.execute_write(self._write_file_nodes, kg.get_neo4j_file_nodes()) + session.execute_write(self._write_ast_nodes, kg.get_neo4j_ast_nodes()) + session.execute_write(self._write_text_nodes, kg.get_neo4j_text_nodes()) + + session.execute_write(self._write_has_ast_edges, kg.get_neo4j_has_ast_edges()) + session.execute_write(self._write_has_file_edges, kg.get_neo4j_has_file_edges()) + session.execute_write(self._write_has_text_edges, kg.get_neo4j_has_text_edges()) + session.execute_write(self._write_next_chunk_edges, kg.get_neo4j_next_chunk_edges()) + session.execute_write(self._write_parent_of_edges, kg.get_neo4j_parent_of_edges()) + + def _read_metadata_node(self, tx: ManagedTransaction) -> MetadataNode: + """Read MetadataNode from neo4j.""" + query = "MATCH (n:MetadataNode) RETURN n.codebase_source AS codebase_source, n.local_path AS local_path, n.https_url AS https_url, n.commit_id AS commit_id" + result = tx.run(query) + return [MetadataNode.from_neo4j_metadata_node(record.data()) for record in result][0] + + def _read_file_nodes(self, tx: ManagedTransaction) -> Sequence[KnowledgeGraphNode]: + """Read FileNode from neo4j.""" + query = "MATCH (n:FileNode) RETURN n.node_id AS node_id, n.basename AS basename, n.relative_path AS relative_path" + result = tx.run(query) + return [KnowledgeGraphNode.from_neo4j_file_node(record.data()) for record in result] + + def _read_ast_nodes(self, tx: ManagedTransaction) -> Sequence[KnowledgeGraphNode]: + """Read ASTNode from neo4j.""" + query = "MATCH (n:ASTNode) RETURN n.node_id AS node_id, n.start_line AS start_line, n.end_line AS end_line, n.type AS type, n.text AS text" + result = tx.run(query) + return [KnowledgeGraphNode.from_neo4j_ast_node(record.data()) for record in result] + + def 
_read_text_nodes(self, tx: ManagedTransaction) -> Sequence[KnowledgeGraphNode]: + """Read TextNode from neo4j.""" + query = ( + "MATCH (n:TextNode) RETURN n.node_id AS node_id, n.text AS text, n.metadata AS metadata" + ) + result = tx.run(query) + return [KnowledgeGraphNode.from_neo4j_text_node(record.data()) for record in result] + + def _read_parent_of_edges(self, tx: ManagedTransaction) -> Sequence[Mapping[str, int]]: + """Read PARENT_OF edges from neo4j.""" + query = """ MATCH (source:ASTNode) -[:PARENT_OF]-> (target:ASTNode) RETURN source.node_id AS source_id, target.node_id AS target_id """ - result = tx.run(query) - return [record.data() for record in result] + result = tx.run(query) + return [record.data() for record in result] - def _read_has_file_edges(self, tx: ManagedTransaction) -> Sequence[Mapping[str, int]]: - """Read HAS_FILE edges from neo4j.""" - query = """ + def _read_has_file_edges(self, tx: ManagedTransaction) -> Sequence[Mapping[str, int]]: + """Read HAS_FILE edges from neo4j.""" + query = """ MATCH (source:FileNode) -[:HAS_FILE]-> (target:FileNode) RETURN source.node_id AS source_id, target.node_id AS target_id """ - result = tx.run(query) - return [record.data() for record in result] + result = tx.run(query) + return [record.data() for record in result] - def _read_has_ast_edges(self, tx: ManagedTransaction) -> Sequence[Mapping[str, int]]: - """Read HAS_AST edges from neo4j.""" - query = """ + def _read_has_ast_edges(self, tx: ManagedTransaction) -> Sequence[Mapping[str, int]]: + """Read HAS_AST edges from neo4j.""" + query = """ MATCH (source:FileNode) -[:HAS_AST]-> (target:ASTNode) RETURN source.node_id AS source_id, target.node_id AS target_id """ - result = tx.run(query) - return [record.data() for record in result] + result = tx.run(query) + return [record.data() for record in result] - def _read_has_text_edges(self, tx: ManagedTransaction) -> Sequence[Mapping[str, int]]: - """Read HAS_TEXT edges from neo4j.""" - query = """ + def 
_read_has_text_edges(self, tx: ManagedTransaction) -> Sequence[Mapping[str, int]]: + """Read HAS_TEXT edges from neo4j.""" + query = """ MATCH (source:FileNode) -[:HAS_TEXT]-> (target:TextNode) RETURN source.node_id AS source_id, target.node_id AS target_id """ - result = tx.run(query) - return [record.data() for record in result] + result = tx.run(query) + return [record.data() for record in result] - def _read_next_chunk_edges(self, tx: ManagedTransaction) -> Sequence[Mapping[str, int]]: - """Read NEXT_CHUNK edges from neo4j.""" - query = """ + def _read_next_chunk_edges(self, tx: ManagedTransaction) -> Sequence[Mapping[str, int]]: + """Read NEXT_CHUNK edges from neo4j.""" + query = """ MATCH (source:TextNode) -[:NEXT_CHUNK]-> (target:TextNode) RETURN source.node_id AS source_id, target.node_id AS target_id """ - result = tx.run(query) - return [record.data() for record in result] - - def read_knowledge_graph(self) -> KnowledgeGraph: - """Read KnowledgeGraph from neo4j.""" - self._logger.info("Reading knowledge graph from neo4j") - with self.driver.session() as session: - return KnowledgeGraph.from_neo4j( - session.execute_read(self._read_metadata_node), - session.execute_read(self._read_file_nodes), - session.execute_read(self._read_ast_nodes), - session.execute_read(self._read_text_nodes), - session.execute_read(self._read_parent_of_edges), - session.execute_read(self._read_has_file_edges), - session.execute_read(self._read_has_ast_edges), - session.execute_read(self._read_has_text_edges), - session.execute_read(self._read_next_chunk_edges), - ) - - def knowledge_graph_exists(self) -> bool: - """Check if the knowledge graph exists in the Neo4j database. - - Returns: - True if there are any nodes in the database, False otherwise. 
- """ - query = """ + result = tx.run(query) + return [record.data() for record in result] + + def read_knowledge_graph(self) -> KnowledgeGraph: + """Read KnowledgeGraph from neo4j.""" + self._logger.info("Reading knowledge graph from neo4j") + with self.driver.session() as session: + return KnowledgeGraph.from_neo4j( + session.execute_read(self._read_metadata_node), + session.execute_read(self._read_file_nodes), + session.execute_read(self._read_ast_nodes), + session.execute_read(self._read_text_nodes), + session.execute_read(self._read_parent_of_edges), + session.execute_read(self._read_has_file_edges), + session.execute_read(self._read_has_ast_edges), + session.execute_read(self._read_has_text_edges), + session.execute_read(self._read_next_chunk_edges), + ) + + def knowledge_graph_exists(self) -> bool: + """Check if the knowledge graph exists in the Neo4j database. + + Returns: + True if there are any nodes in the database, False otherwise. + """ + query = """ MATCH (n) RETURN COUNT(n) > 0 AS graph_exists """ - with self.driver.session() as session: - result = session.run(query) - record = result.single() - return record["graph_exists"] if record else False + with self.driver.session() as session: + result = session.run(query) + record = result.single() + return record["graph_exists"] if record else False - def verify_empty(self, tx: ManagedTransaction): - query = """ + def verify_empty(self, tx: ManagedTransaction): + query = """ MATCH (n) RETURN count(n) as count """ - result = tx.run(query) - return result.single()["count"] == 0 + result = tx.run(query) + return result.single()["count"] == 0 - def clear_knowledge_graph(self): - """Clear the knowledge graph from neo4j.""" - query = """ + def clear_knowledge_graph(self): + """Clear the knowledge graph from neo4j.""" + query = """ MATCH (n) DETACH DELETE n """ - self._logger.info("Deleting knowledge graph from neo4j") - with self.driver.session() as session: - session.run(query) + self._logger.info("Deleting 
knowledge graph from neo4j") + with self.driver.session() as session: + session.run(query) - max_retries = 3 - for attempt in range(max_retries): - if session.execute_read(self.verify_empty): - break + max_retries = 3 + for attempt in range(max_retries): + if session.execute_read(self.verify_empty): + break - self._logger.warning(f"Database not empty after attempt {attempt + 1}, retrying...") - session.run(query) + self._logger.warning(f"Database not empty after attempt {attempt + 1}, retrying...") + session.run(query) diff --git a/prometheus/parser/file_types.py b/prometheus/parser/file_types.py index c42a9665..20ceb5c1 100644 --- a/prometheus/parser/file_types.py +++ b/prometheus/parser/file_types.py @@ -3,48 +3,48 @@ class FileType(enum.StrEnum): - """Enum of all tree-sitter supported file types""" + """Enum of all tree-sitter supported file types""" - BASH = "bash" - C = "c" - CSHARP = "csharp" - CPP = "cpp" - GO = "go" - JAVA = "java" - JAVASCRIPT = "javascript" - KOTLIN = "kotlin" - PHP = "php" - PYTHON = "python" - SQL = "sql" - YAML = "yaml" - UNKNOWN = "UNKNOWN" + BASH = "bash" + C = "c" + CSHARP = "csharp" + CPP = "cpp" + GO = "go" + JAVA = "java" + JAVASCRIPT = "javascript" + KOTLIN = "kotlin" + PHP = "php" + PYTHON = "python" + SQL = "sql" + YAML = "yaml" + UNKNOWN = "UNKNOWN" - @classmethod - def from_path(cls, path: Path): - match path.suffix: - case ".sh": - return cls.BASH - case ".c": - return cls.C - case ".cs": - return cls.CSHARP - case ".cpp" | ".cc" | ".cxx": - return cls.CPP - case ".go": - return cls.GO - case ".java": - return cls.JAVA - case ".js": - return cls.JAVASCRIPT - case ".kt": - return cls.KOTLIN - case ".php": - return cls.PHP - case ".py": - return cls.PYTHON - case ".sql": - return cls.SQL - case ".yaml" | ".yml": - return cls.YAML - case _: - return cls.UNKNOWN + @classmethod + def from_path(cls, path: Path): + match path.suffix: + case ".sh": + return cls.BASH + case ".c": + return cls.C + case ".cs": + return cls.CSHARP + 
case ".cpp" | ".cc" | ".cxx": + return cls.CPP + case ".go": + return cls.GO + case ".java": + return cls.JAVA + case ".js": + return cls.JAVASCRIPT + case ".kt": + return cls.KOTLIN + case ".php": + return cls.PHP + case ".py": + return cls.PYTHON + case ".sql": + return cls.SQL + case ".yaml" | ".yml": + return cls.YAML + case _: + return cls.UNKNOWN diff --git a/prometheus/parser/tree_sitter_parser.py b/prometheus/parser/tree_sitter_parser.py index d8f17284..c83553c8 100644 --- a/prometheus/parser/tree_sitter_parser.py +++ b/prometheus/parser/tree_sitter_parser.py @@ -18,61 +18,61 @@ class FileNotSupportedError(Exception): - """Exception raised when attempting to parse an unsupported file type. + """Exception raised when attempting to parse an unsupported file type. - This exception is raised when the parser encounters a file type that - is not supported by the tree-sitter parser implementation. - """ + This exception is raised when the parser encounters a file type that + is not supported by the tree-sitter parser implementation. + """ - pass + pass FILE_TYPE_TO_LANG = { - FileType.BASH: "bash", - FileType.C: "c", - FileType.CSHARP: "c_sharp", - FileType.CPP: "cpp", - FileType.GO: "go", - FileType.JAVA: "java", - FileType.JAVASCRIPT: "javascript", - FileType.KOTLIN: "kotlin", - FileType.PHP: "php", - FileType.PYTHON: "python", - FileType.SQL: "sql", - FileType.YAML: "yaml", + FileType.BASH: "bash", + FileType.C: "c", + FileType.CSHARP: "c_sharp", + FileType.CPP: "cpp", + FileType.GO: "go", + FileType.JAVA: "java", + FileType.JAVASCRIPT: "javascript", + FileType.KOTLIN: "kotlin", + FileType.PHP: "php", + FileType.PYTHON: "python", + FileType.SQL: "sql", + FileType.YAML: "yaml", } def supports_file(file: Path) -> bool: - """Checks if a given file type is supported by the parser. + """Checks if a given file type is supported by the parser. - Args: - file: A Path object representing the file to check. + Args: + file: A Path object representing the file to check. 
- Returns: - bool: True if the file type is supported, False otherwise. - """ - file_type = FileType.from_path(file) - return file_type in FILE_TYPE_TO_LANG + Returns: + bool: True if the file type is supported, False otherwise. + """ + file_type = FileType.from_path(file) + return file_type in FILE_TYPE_TO_LANG def parse(file: Path) -> Tree: - """Parses a source code file using the appropriate tree-sitter parser. + """Parses a source code file using the appropriate tree-sitter parser. - Args: - file: A Path object representing the file to parse. + Args: + file: A Path object representing the file to parse. - Returns: - Tree: A tree-sitter Tree object representing the parsed syntax tree. + Returns: + Tree: A tree-sitter Tree object representing the parsed syntax tree. - Raises: - FileNotSupportedError: If the file type is not supported by the parser. - """ - file_type = FileType.from_path(file) - lang = FILE_TYPE_TO_LANG.get(file_type, None) - if lang is None: - raise FileNotSupportedError(f"{file_type.value} is not supported by tree_sitter_parser") + Raises: + FileNotSupportedError: If the file type is not supported by the parser. 
+ """ + file_type = FileType.from_path(file) + lang = FILE_TYPE_TO_LANG.get(file_type, None) + if lang is None: + raise FileNotSupportedError(f"{file_type.value} is not supported by tree_sitter_parser") - lang_parser = get_parser(lang) - with file.open("rb") as f: - return lang_parser.parse(f.read()) + lang_parser = get_parser(lang) + with file.open("rb") as f: + return lang_parser.parse(f.read()) diff --git a/prometheus/tools/container_command.py b/prometheus/tools/container_command.py index b7987374..42459ad8 100644 --- a/prometheus/tools/container_command.py +++ b/prometheus/tools/container_command.py @@ -4,7 +4,7 @@ class RunCommandInput(BaseModel): - command: str = Field("The shell command to be run in the container") + command: str = Field("The shell command to be run in the container") RUN_COMMAND_DESCRIPTION = """\ @@ -14,4 +14,4 @@ class RunCommandInput(BaseModel): def run_command(command: str, container: GeneralContainer) -> str: - return container.execute_command(command) + return container.execute_command(command) diff --git a/prometheus/tools/file_operation.py b/prometheus/tools/file_operation.py index 0133374c..4122d0c2 100644 --- a/prometheus/tools/file_operation.py +++ b/prometheus/tools/file_operation.py @@ -11,7 +11,7 @@ class ReadFileInput(BaseModel): - relative_path: str = Field("The relative path of the file to read") + relative_path: str = Field("The relative path of the file to read") READ_FILE_DESCRIPTION = """\ @@ -22,25 +22,25 @@ class ReadFileInput(BaseModel): def read_file(relative_path: str, root_path: str, n_lines: int = 1000) -> str: - if os.path.isabs(relative_path): - return f"relative_path: {relative_path} is a abolsute path, not relative path." + if os.path.isabs(relative_path): + return f"relative_path: {relative_path} is a abolsute path, not relative path." - file_path = Path(os.path.join(root_path, relative_path)) - if not file_path.exists(): - return f"The file {relative_path} does not exist." 
+ file_path = Path(os.path.join(root_path, relative_path)) + if not file_path.exists(): + return f"The file {relative_path} does not exist." - with file_path.open() as f: - lines = f.readlines() + with file_path.open() as f: + lines = f.readlines() - return pre_append_line_numbers("".join(lines[:n_lines]), 1) + return pre_append_line_numbers("".join(lines[:n_lines]), 1) class ReadFileWithLineNumbersInput(BaseModel): - relative_path: str = Field( - description="The relative path of the file to read, eg. foo/bar/test.py, not absolute path" - ) - start_line: int = Field(description="The start line number to read, 1-indexed and inclusive") - end_line: int = Field(description="The ending line number to read, 1-indexed and exclusive") + relative_path: str = Field( + description="The relative path of the file to read, eg. foo/bar/test.py, not absolute path" + ) + start_line: int = Field(description="The start line number to read, 1-indexed and inclusive") + end_line: int = Field(description="The ending line number to read, 1-indexed and exclusive") READ_FILE_WITH_LINE_NUMBERS_DESCRIPTION = """\ @@ -51,36 +51,34 @@ class ReadFileWithLineNumbersInput(BaseModel): def read_file_with_line_numbers( - relative_path: str, root_path: str, start_line: int, end_line: int + relative_path: str, root_path: str, start_line: int, end_line: int ) -> str: - if os.path.isabs(relative_path): - return f"relative_path: {relative_path} is a abolsute path, not relative path." + if os.path.isabs(relative_path): + return f"relative_path: {relative_path} is a abolsute path, not relative path." - file_path = Path(os.path.join(root_path, relative_path)) - if not file_path.exists(): - return f"The file {relative_path} does not exist." + file_path = Path(os.path.join(root_path, relative_path)) + if not file_path.exists(): + return f"The file {relative_path} does not exist." 
- if end_line < start_line: - return ( - f"The end line number {end_line} must be greater than the start line number {start_line}." - ) + if end_line < start_line: + return f"The end line number {end_line} must be greater than the start line number {start_line}." - zero_based_start_line = start_line - 1 - zero_based_end_line = end_line - 1 + zero_based_start_line = start_line - 1 + zero_based_end_line = end_line - 1 - with file_path.open() as f: - lines = f.readlines() + with file_path.open() as f: + lines = f.readlines() - return pre_append_line_numbers( - "".join(lines[zero_based_start_line:zero_based_end_line]), start_line - ) + return pre_append_line_numbers( + "".join(lines[zero_based_start_line:zero_based_end_line]), start_line + ) class CreateFileInput(BaseModel): - relative_path: str = Field( - description="The relative path of the file to create, eg. foo/bar/test.py, not absolute path" - ) - content: str = Field(description="The content of the file to create") + relative_path: str = Field( + description="The relative path of the file to create, eg. foo/bar/test.py, not absolute path" + ) + content: str = Field(description="The content of the file to create") CREATE_FILE_DESCRIPTION = """\ @@ -91,22 +89,22 @@ class CreateFileInput(BaseModel): def create_file(relative_path: str, root_path: str, content: str) -> str: - if os.path.isabs(relative_path): - return f"relative_path: {relative_path} is a abolsute path, not relative path." + if os.path.isabs(relative_path): + return f"relative_path: {relative_path} is a abolsute path, not relative path." - file_path = Path(os.path.join(root_path, relative_path)) - if file_path.exists(): - return f"The file {relative_path} already exists." + file_path = Path(os.path.join(root_path, relative_path)) + if file_path.exists(): + return f"The file {relative_path} already exists." - file_path.parent.mkdir(parents=True, exist_ok=True) - file_path.write_text(content) - return f"The file {relative_path} has been created." 
+ file_path.parent.mkdir(parents=True, exist_ok=True) + file_path.write_text(content) + return f"The file {relative_path} has been created." class DeleteInput(BaseModel): - relative_path: str = Field( - description="The relative path of the file/dir to delete, eg. foo/bar/test.py, not absolute path" - ) + relative_path: str = Field( + description="The relative path of the file/dir to delete, eg. foo/bar/test.py, not absolute path" + ) DELETE_DESCRIPTION = """\ @@ -117,31 +115,31 @@ class DeleteInput(BaseModel): def delete(relative_path: str, root_path: str) -> str: - if os.path.isabs(relative_path): - return f"relative_path: {relative_path} is a abolsute path, not relative path." + if os.path.isabs(relative_path): + return f"relative_path: {relative_path} is a abolsute path, not relative path." - file_path = Path(os.path.join(root_path, relative_path)) - if not file_path.exists(): - return f"The file {relative_path} does not exist." + file_path = Path(os.path.join(root_path, relative_path)) + if not file_path.exists(): + return f"The file {relative_path} does not exist." - if file_path.is_dir(): - shutil.rmtree(file_path) - return f"The directory {relative_path} has been deleted." + if file_path.is_dir(): + shutil.rmtree(file_path) + return f"The directory {relative_path} has been deleted." - file_path.unlink() - return f"The file {relative_path} has been deleted." + file_path.unlink() + return f"The file {relative_path} has been deleted." class EditFileInput(BaseModel): - relative_path: str = Field( - description="The relative path of the file to edit, eg. foo/bar/test.py, not absolute path" - ) - old_content: str = Field( - description="The exact string content to be replaced in the file. Must match exactly one occurrence in the file" - ) - new_content: str = Field( - description="The new content that will replace the old_content in the file" - ) + relative_path: str = Field( + description="The relative path of the file to edit, eg. 
foo/bar/test.py, not absolute path" + ) + old_content: str = Field( + description="The exact string content to be replaced in the file. Must match exactly one occurrence in the file" + ) + new_content: str = Field( + description="The new content that will replace the old_content in the file" + ) EDIT_FILE_DESCRIPTION = """\ @@ -163,27 +161,27 @@ class EditFileInput(BaseModel): def edit_file(relative_path: str, root_path: str, old_content: str, new_content: str) -> str: - if os.path.isabs(relative_path): - return f"relative_path: {relative_path} is a abolsute path, not relative path." + if os.path.isabs(relative_path): + return f"relative_path: {relative_path} is a abolsute path, not relative path." - file_path = Path(os.path.join(root_path, relative_path)) - if not file_path.exists(): - return f"The file {relative_path} does not exist." + file_path = Path(os.path.join(root_path, relative_path)) + if not file_path.exists(): + return f"The file {relative_path} does not exist." - content = file_path.read_text() + content = file_path.read_text() - occurrences = content.count(old_content) + occurrences = content.count(old_content) - if occurrences == 0: - return f"No match found for the specified content in {relative_path}. Please verify the content to replace." + if occurrences == 0: + return f"No match found for the specified content in {relative_path}. Please verify the content to replace." - if occurrences > 1: - return ( - f"Found {occurrences} occurrences of the specified content in {relative_path}. " - "Please provide more context to ensure a unique match." - ) + if occurrences > 1: + return ( + f"Found {occurrences} occurrences of the specified content in {relative_path}. " + "Please provide more context to ensure a unique match." 
+ ) - new_content_full = content.replace(old_content, new_content) - file_path.write_text(new_content_full) + new_content_full = content.replace(old_content, new_content) + file_path.write_text(new_content_full) - return f"Successfully edited {relative_path}." + return f"Successfully edited {relative_path}." diff --git a/prometheus/tools/graph_traversal.py b/prometheus/tools/graph_traversal.py index 2adb0ee7..d1b59aa2 100644 --- a/prometheus/tools/graph_traversal.py +++ b/prometheus/tools/graph_traversal.py @@ -14,7 +14,7 @@ class FindFileNodeWithBasenameInput(BaseModel): - basename: str = Field("The basename of FileNode to search for") + basename: str = Field("The basename of FileNode to search for") FIND_FILE_NODE_WITH_BASENAME_DESCRIPTION = """\ @@ -27,19 +27,19 @@ class FindFileNodeWithBasenameInput(BaseModel): def find_file_node_with_basename( - basename: str, driver: GraphDatabase.driver, max_token_per_result: int + basename: str, driver: GraphDatabase.driver, max_token_per_result: int ) -> str: - query = f""" + query = f""" MATCH (f:FileNode {{ basename: '{basename}' }}) RETURN f AS FileNode ORDER BY f.node_id LIMIT {MAX_RESULT} """ - return neo4j_util.run_neo4j_query(query, driver, max_token_per_result) + return neo4j_util.run_neo4j_query(query, driver, max_token_per_result) class FindFileNodeWithRelativePathInput(BaseModel): - relative_path: str = Field("The relative_path of FileNode to search for") + relative_path: str = Field("The relative_path of FileNode to search for") FIND_FILE_NODE_WITH_RELATIVE_PATH_DESCRIPTION = """\ @@ -52,15 +52,15 @@ class FindFileNodeWithRelativePathInput(BaseModel): def find_file_node_with_relative_path( - relative_path: str, driver: GraphDatabase.driver, max_token_per_result: int + relative_path: str, driver: GraphDatabase.driver, max_token_per_result: int ) -> str: - query = f"""\ + query = f"""\ MATCH (f:FileNode {{ relative_path: '{relative_path}' }}) RETURN f AS FileNode ORDER BY f.node_id LIMIT {MAX_RESULT} """ - return 
neo4j_util.run_neo4j_query(query, driver, max_token_per_result) + return neo4j_util.run_neo4j_query(query, driver, max_token_per_result) ############################################################################### @@ -69,8 +69,8 @@ def find_file_node_with_relative_path( class FindASTNodeWithTextInFileWithBasenameInput(BaseModel): - text: str = Field("Search ASTNode that exactly contains this text.") - basename: str = Field("The basename of file/directory to search under for ASTNodes.") + text: str = Field("Search ASTNode that exactly contains this text.") + basename: str = Field("The basename of file/directory to search under for ASTNodes.") FIND_AST_NODE_WITH_TEXT_IN_FILE_WITH_BASENAME_DESCRIPTION = """\ @@ -83,21 +83,21 @@ class FindASTNodeWithTextInFileWithBasenameInput(BaseModel): def find_ast_node_with_text_in_file_with_basename( - text: str, basename: str, driver: GraphDatabase.driver, max_token_per_result: int + text: str, basename: str, driver: GraphDatabase.driver, max_token_per_result: int ) -> str: - query = f"""\ + query = f"""\ MATCH (f:FileNode) -[:HAS_FILE*0..]-> (c:FileNode) -[:HAS_AST]-> (:ASTNode) -[:PARENT_OF*0..]-> (a:ASTNode) WHERE f.basename = '{basename}' AND a.text CONTAINS '{text}' RETURN c as FileNode, a AS ASTNode ORDER BY SIZE(a.text) LIMIT {MAX_RESULT} """ - return neo4j_util.run_neo4j_query(query, driver, max_token_per_result) + return neo4j_util.run_neo4j_query(query, driver, max_token_per_result) class FindASTNodeWithTextInFileWithRelativePathInput(BaseModel): - text: str = Field("Search ASTNode that exactly contains this text.") - relative_path: str = Field("The relative path of file/directory to search under for ASTNodes.") + text: str = Field("Search ASTNode that exactly contains this text.") + relative_path: str = Field("The relative path of file/directory to search under for ASTNodes.") FIND_AST_NODE_WITH_TEXT_IN_FILE_WITH_RELATIVE_PATH_DESCRIPTION = """\ @@ -110,21 +110,21 @@ class 
FindASTNodeWithTextInFileWithRelativePathInput(BaseModel): def find_ast_node_with_text_in_file_with_relative_path( - text: str, relative_path: str, driver: GraphDatabase.driver, max_token_per_result: int + text: str, relative_path: str, driver: GraphDatabase.driver, max_token_per_result: int ) -> str: - query = f"""\ + query = f"""\ MATCH (f:FileNode) -[:HAS_FILE*0..]-> (c:FileNode) -[:HAS_AST]-> (:ASTNode) -[:PARENT_OF*0..]-> (a:ASTNode) WHERE f.relative_path = '{relative_path}' AND a.text CONTAINS '{text}' RETURN c as FileNode, a AS ASTNode ORDER BY SIZE(a.text) LIMIT {MAX_RESULT} """ - return neo4j_util.run_neo4j_query(query, driver, max_token_per_result) + return neo4j_util.run_neo4j_query(query, driver, max_token_per_result) class FindASTNodeWithTypeInFileWithBasenameInput(BaseModel): - type: str = Field("Search ASTNode with this tree-sitter node type.") - basename: str = Field("The basename of file/directory to search under for ASTNodes.") + type: str = Field("Search ASTNode with this tree-sitter node type.") + basename: str = Field("The basename of file/directory to search under for ASTNodes.") FIND_AST_NODE_WITH_TYPE_IN_FILE_WITH_BASENAME_DESCRIPTION = """\ @@ -135,21 +135,21 @@ class FindASTNodeWithTypeInFileWithBasenameInput(BaseModel): def find_ast_node_with_type_in_file_with_basename( - type: str, basename: str, driver: GraphDatabase.driver, max_token_per_result: int + type: str, basename: str, driver: GraphDatabase.driver, max_token_per_result: int ) -> str: - query = f"""\ + query = f"""\ MATCH (f:FileNode) -[:HAS_FILE*0..]-> (c:FileNode) -[:HAS_AST]-> (:ASTNode) -[:PARENT_OF*0..]-> (a:ASTNode) WHERE f.basename = '{basename}' AND a.type = '{type}' RETURN c as FileNode, a AS ASTNode ORDER BY SIZE(a.text) LIMIT {MAX_RESULT} """ - return neo4j_util.run_neo4j_query(query, driver, max_token_per_result) + return neo4j_util.run_neo4j_query(query, driver, max_token_per_result) class FindASTNodeWithTypeInFileWithRelativePathInput(BaseModel): - type: str = 
Field("Search ASTNode with this tree-sitter node type.") - relative_path: str = Field("The relative path of file/directory to search under for ASTNodes.") + type: str = Field("Search ASTNode with this tree-sitter node type.") + relative_path: str = Field("The relative path of file/directory to search under for ASTNodes.") FIND_AST_NODE_WITH_TYPE_IN_FILE_WITH_RELATIVE_PATH_DESCRIPTION = """\ @@ -160,16 +160,16 @@ class FindASTNodeWithTypeInFileWithRelativePathInput(BaseModel): def find_ast_node_with_type_in_file_with_relative_path( - type: str, relative_path: str, driver: GraphDatabase.driver, max_token_per_result: int + type: str, relative_path: str, driver: GraphDatabase.driver, max_token_per_result: int ) -> str: - query = f"""\ + query = f"""\ MATCH (f:FileNode) -[:HAS_FILE*0..]-> (c:FileNode) -[:HAS_AST]-> (:ASTNode) -[:PARENT_OF*0..]-> (a:ASTNode) WHERE f.relative_path = '{relative_path}' AND a.type = '{type}' RETURN c as FileNode, a AS ASTNode ORDER BY SIZE(a.text) LIMIT {MAX_RESULT} """ - return neo4j_util.run_neo4j_query(query, driver, max_token_per_result) + return neo4j_util.run_neo4j_query(query, driver, max_token_per_result) ############################################################################### @@ -178,7 +178,7 @@ def find_ast_node_with_type_in_file_with_relative_path( class FindTextNodeWithTextInput(BaseModel): - text: str = Field("Search TextNode that exactly contains this text.") + text: str = Field("Search TextNode that exactly contains this text.") FIND_TEXT_NODE_WITH_TEXT_DESCRIPTION = """\ @@ -190,21 +190,21 @@ class FindTextNodeWithTextInput(BaseModel): def find_text_node_with_text( - text: str, driver: GraphDatabase.driver, max_token_per_result: int + text: str, driver: GraphDatabase.driver, max_token_per_result: int ) -> str: - query = f"""\ + query = f"""\ MATCH (f:FileNode) -[:HAS_TEXT]-> (t:TextNode) WHERE t.text CONTAINS '{text}' RETURN f as FileNode, t AS TextNode ORDER BY t.node_id LIMIT {MAX_RESULT} """ - return 
neo4j_util.run_neo4j_query(query, driver, max_token_per_result) + return neo4j_util.run_neo4j_query(query, driver, max_token_per_result) class FindTextNodeWithTextInFileInput(BaseModel): - text: str = Field("Search TextNode that exactly contains this text.") - basename: str = Field("The basename of FileNode to search TextNode.") + text: str = Field("Search TextNode that exactly contains this text.") + basename: str = Field("The basename of FileNode to search TextNode.") FIND_TEXT_NODE_WITH_TEXT_IN_FILE_DESCRIPTION = """\ @@ -218,20 +218,20 @@ class FindTextNodeWithTextInFileInput(BaseModel): def find_text_node_with_text_in_file( - text: str, basename: str, driver: GraphDatabase.driver, max_token_per_result: int + text: str, basename: str, driver: GraphDatabase.driver, max_token_per_result: int ) -> str: - query = f"""\ + query = f"""\ MATCH (f:FileNode) -[:HAS_TEXT]-> (t:TextNode) WHERE f.basename = '{basename}' AND t.text CONTAINS '{text}' RETURN f as FileNode, t AS TextNode ORDER BY t.node_id LIMIT {MAX_RESULT} """ - return neo4j_util.run_neo4j_query(query, driver, max_token_per_result) + return neo4j_util.run_neo4j_query(query, driver, max_token_per_result) class GetNextTextNodeWithNodeIdInput(BaseModel): - node_id: int = Field("Get the next TextNode of this given node_id.") + node_id: int = Field("Get the next TextNode of this given node_id.") GET_NEXT_TEXT_NODE_WITH_NODE_ID_DESCRIPTION = """\ @@ -241,13 +241,13 @@ class GetNextTextNodeWithNodeIdInput(BaseModel): def get_next_text_node_with_node_id( - node_id: int, driver: GraphDatabase.driver, max_token_per_result: int + node_id: int, driver: GraphDatabase.driver, max_token_per_result: int ) -> str: - query = f"""\ + query = f"""\ MATCH (f:FileNode) -[:HAS_TEXT]-> (a:TextNode {{ node_id: {node_id} }}) -[:NEXT_CHUNK]-> (b:TextNode) RETURN f as FileNode, b AS TextNode """ - return neo4j_util.run_neo4j_query(query, driver, max_token_per_result) + return neo4j_util.run_neo4j_query(query, driver, 
max_token_per_result) ############################################################################### @@ -256,7 +256,7 @@ def get_next_text_node_with_node_id( class PreviewFileContentWithBasenameInput(BaseModel): - basename: str = Field("The basename of FileNode to preview.") + basename: str = Field("The basename of FileNode to preview.") PREVIEW_FILE_CONTENT_WITH_BASENAME_DESCRIPTION = """\ @@ -270,9 +270,9 @@ class PreviewFileContentWithBasenameInput(BaseModel): def preview_file_content_with_basename( - basename: str, driver: GraphDatabase.driver, max_token_per_result: int + basename: str, driver: GraphDatabase.driver, max_token_per_result: int ) -> str: - source_code_query = f"""\ + source_code_query = f"""\ MATCH (f:FileNode {{ basename: '{basename}' }}) -[:HAS_AST]-> (a:ASTNode) WITH f, apoc.text.split(a.text, '\\R') AS lines RETURN @@ -285,20 +285,20 @@ def preview_file_content_with_basename( ORDER BY f.node_id """ - text_query = f"""\ + text_query = f"""\ MATCH (f:FileNode {{ basename: '{basename}' }}) -[:HAS_TEXT]-> (t:TextNode) WHERE NOT EXISTS((:TextNode) -[:NEXT_CHUNK]-> (t)) RETURN f as FileNode, t.text AS preview ORDER BY f.node_id """ - if tree_sitter_parser.supports_file(Path(basename)): - return neo4j_util.run_neo4j_query(source_code_query, driver, max_token_per_result) - return neo4j_util.run_neo4j_query(text_query, driver, max_token_per_result) + if tree_sitter_parser.supports_file(Path(basename)): + return neo4j_util.run_neo4j_query(source_code_query, driver, max_token_per_result) + return neo4j_util.run_neo4j_query(text_query, driver, max_token_per_result) class PreviewFileContentWithRelativePathInput(BaseModel): - relative_path: str = Field("The relative path of FileNode to preview.") + relative_path: str = Field("The relative path of FileNode to preview.") PREVIEW_FILE_CONTENT_WITH_RELATIVE_PATH_DESCRIPTION = """\ @@ -312,9 +312,9 @@ class PreviewFileContentWithRelativePathInput(BaseModel): def preview_file_content_with_relative_path( - 
relative_path: str, driver: GraphDatabase.driver, max_token_per_result: int + relative_path: str, driver: GraphDatabase.driver, max_token_per_result: int ) -> str: - source_code_query = f"""\ + source_code_query = f"""\ MATCH (f:FileNode {{ relative_path: '{relative_path}' }}) -[:HAS_AST]-> (a:ASTNode) WITH f, apoc.text.split(a.text, '\\R') AS lines RETURN @@ -327,22 +327,22 @@ def preview_file_content_with_relative_path( ORDER BY f.node_id """ - text_query = f"""\ + text_query = f"""\ MATCH (f:FileNode {{ relative_path: '{relative_path}' }}) -[:HAS_TEXT]-> (t:TextNode) WHERE NOT EXISTS((:TextNode) -[:NEXT_CHUNK]-> (t)) RETURN f as FileNode, t.text AS preview ORDER BY f.node_id """ - if tree_sitter_parser.supports_file(Path(relative_path)): - return neo4j_util.run_neo4j_query(source_code_query, driver, max_token_per_result) - return neo4j_util.run_neo4j_query(text_query, driver, max_token_per_result) + if tree_sitter_parser.supports_file(Path(relative_path)): + return neo4j_util.run_neo4j_query(source_code_query, driver, max_token_per_result) + return neo4j_util.run_neo4j_query(text_query, driver, max_token_per_result) class ReadCodeWithBasenameInput(BaseModel): - basename: str = Field("The basename of FileNode to read.") - start_line: int = Field("The starting line number, 1-indexed and inclusive.") - end_line: int = Field("The ending line number, 1-indexed and exclusive.") + basename: str = Field("The basename of FileNode to read.") + start_line: int = Field("The starting line number, 1-indexed and inclusive.") + end_line: int = Field("The ending line number, 1-indexed and exclusive.") READ_CODE_WITH_BASENAME_DESCRIPTION = """\ @@ -362,16 +362,16 @@ class ReadCodeWithBasenameInput(BaseModel): def read_code_with_basename( - basename: str, - start_line: int, - end_line: int, - driver: GraphDatabase.driver, - max_token_per_result: int, + basename: str, + start_line: int, + end_line: int, + driver: GraphDatabase.driver, + max_token_per_result: int, ) -> str: - if 
end_line < start_line: - return f"end_line {end_line} must be greater than start_line {start_line}" + if end_line < start_line: + return f"end_line {end_line} must be greater than start_line {start_line}" - source_code_query = f"""\ + source_code_query = f"""\ MATCH (f:FileNode {{ basename: '{basename}' }}) -[:HAS_AST]-> (a:ASTNode) WITH f, apoc.text.split(a.text, '\\R') AS lines RETURN @@ -384,13 +384,13 @@ def read_code_with_basename( ORDER BY f.node_id """ - return neo4j_util.run_neo4j_query(source_code_query, driver, max_token_per_result) + return neo4j_util.run_neo4j_query(source_code_query, driver, max_token_per_result) class ReadCodeWithRelativePathInput(BaseModel): - relative_path: str = Field("The relative path of FileNode to read from root of codebase.") - start_line: int = Field("The starting line number, 1-indexed and inclusive.") - end_line: int = Field("The ending line number, 1-indexed and exclusive.") + relative_path: str = Field("The relative path of FileNode to read from root of codebase.") + start_line: int = Field("The starting line number, 1-indexed and inclusive.") + end_line: int = Field("The ending line number, 1-indexed and exclusive.") READ_CODE_WITH_RELATIVE_PATH_DESCRIPTION = """\ @@ -411,16 +411,16 @@ class ReadCodeWithRelativePathInput(BaseModel): def read_code_with_relative_path( - relative_path: str, - start_line: int, - end_line: int, - driver: GraphDatabase.driver, - max_token_per_result: int, + relative_path: str, + start_line: int, + end_line: int, + driver: GraphDatabase.driver, + max_token_per_result: int, ) -> str: - if end_line < start_line: - return f"end_line {end_line} must be greater than start_line {start_line}" + if end_line < start_line: + return f"end_line {end_line} must be greater than start_line {start_line}" - source_code_query = f"""\ + source_code_query = f"""\ MATCH (f:FileNode {{ relative_path: '{relative_path}' }}) -[:HAS_AST]-> (a:ASTNode) WITH f, apoc.text.split(a.text, '\\R') AS lines RETURN @@ -433,4 
+433,4 @@ def read_code_with_relative_path( ORDER BY f.node_id """ - return neo4j_util.run_neo4j_query(source_code_query, driver, max_token_per_result) + return neo4j_util.run_neo4j_query(source_code_query, driver, max_token_per_result) diff --git a/prometheus/utils/issue_util.py b/prometheus/utils/issue_util.py index e3decc96..e5df2a64 100644 --- a/prometheus/utils/issue_util.py +++ b/prometheus/utils/issue_util.py @@ -2,28 +2,28 @@ def format_issue_comments(issue_comments: Sequence[Mapping[str, str]]): - """Formats a sequence of issue comments into a readable string. + """Formats a sequence of issue comments into a readable string. - Combines multiple issue comments with their associated usernames into a - formatted string suitable for inclusion in the response context. + Combines multiple issue comments with their associated usernames into a + formatted string suitable for inclusion in the response context. - Args: - issue_comments: Sequence of mappings containing 'username' and 'comment' - keys for each issue comment. + Args: + issue_comments: Sequence of mappings containing 'username' and 'comment' + keys for each issue comment. - Returns: - Formatted string containing all comments with usernames, separated by newlines. - """ - formatted_issue_comments = [] - for issue_comment in issue_comments: - formatted_issue_comments.append(f"{issue_comment['username']}: {issue_comment['comment']}") - return "\n\n".join(formatted_issue_comments) + Returns: + Formatted string containing all comments with usernames, separated by newlines. 
+ """ + formatted_issue_comments = [] + for issue_comment in issue_comments: + formatted_issue_comments.append(f"{issue_comment['username']}: {issue_comment['comment']}") + return "\n\n".join(formatted_issue_comments) def format_issue_info( - issue_title: str, issue_body: str, issue_comments: Sequence[Mapping[str, str]] + issue_title: str, issue_body: str, issue_comments: Sequence[Mapping[str, str]] ) -> str: - return f"""\ + return f"""\ Issue title: {issue_title} @@ -35,5 +35,5 @@ def format_issue_info( def format_test_commands(test_commands: Sequence[str]) -> str: - test_commands_with_prefix = [f"$ {test_command}" for test_command in test_commands] - return "\n".join(test_commands_with_prefix) + test_commands_with_prefix = [f"$ {test_command}" for test_command in test_commands] + return "\n".join(test_commands_with_prefix) diff --git a/prometheus/utils/lang_graph_util.py b/prometheus/utils/lang_graph_util.py index e3891268..b5b4192c 100644 --- a/prometheus/utils/lang_graph_util.py +++ b/prometheus/utils/lang_graph_util.py @@ -2,130 +2,130 @@ import tiktoken from langchain_core.messages import ( - AIMessage, - BaseMessage, - HumanMessage, - SystemMessage, - ToolMessage, - trim_messages, + AIMessage, + BaseMessage, + HumanMessage, + SystemMessage, + ToolMessage, + trim_messages, ) from langchain_core.output_parsers import StrOutputParser def check_remaining_steps( - state: Dict, - router: Callable[..., str], - min_remaining_steps: int, - remaining_steps_key: str = "remaining_steps", + state: Dict, + router: Callable[..., str], + min_remaining_steps: int, + remaining_steps_key: str = "remaining_steps", ) -> str: - original_route = router(state) - if state[remaining_steps_key] > min_remaining_steps: - return original_route - else: - return "low_remaining_steps" + original_route = router(state) + if state[remaining_steps_key] > min_remaining_steps: + return original_route + else: + return "low_remaining_steps" def str_token_counter(text: str) -> int: - enc = 
tiktoken.get_encoding("o200k_base") - return len(enc.encode(text)) + enc = tiktoken.get_encoding("o200k_base") + return len(enc.encode(text)) def tiktoken_counter(messages: Sequence[BaseMessage]) -> int: - """Approximately reproduce https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb - - For simplicity only supports str Message.contents. - """ - output_parser = StrOutputParser() - num_tokens = 3 # every reply is primed with <|start|>assistant<|message|> - tokens_per_message = 3 - tokens_per_name = 1 - for msg in messages: - if isinstance(msg, HumanMessage): - role = "user" - elif isinstance(msg, AIMessage): - role = "assistant" - elif isinstance(msg, ToolMessage): - role = "tool" - elif isinstance(msg, SystemMessage): - role = "system" - else: - raise ValueError(f"Unsupported messages type {msg.__class__}") - msg_content = output_parser.invoke(msg) - num_tokens += tokens_per_message + str_token_counter(role) + str_token_counter(msg_content) - if msg.name: - num_tokens += tokens_per_name + str_token_counter(msg.name) - return num_tokens + """Approximately reproduce https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb + + For simplicity only supports str Message.contents. 
+ """ + output_parser = StrOutputParser() + num_tokens = 3 # every reply is primed with <|start|>assistant<|message|> + tokens_per_message = 3 + tokens_per_name = 1 + for msg in messages: + if isinstance(msg, HumanMessage): + role = "user" + elif isinstance(msg, AIMessage): + role = "assistant" + elif isinstance(msg, ToolMessage): + role = "tool" + elif isinstance(msg, SystemMessage): + role = "system" + else: + raise ValueError(f"Unsupported messages type {msg.__class__}") + msg_content = output_parser.invoke(msg) + num_tokens += tokens_per_message + str_token_counter(role) + str_token_counter(msg_content) + if msg.name: + num_tokens += tokens_per_name + str_token_counter(msg.name) + return num_tokens def truncate_messages( - messages: Sequence[BaseMessage], max_tokens: int = 100000 + messages: Sequence[BaseMessage], max_tokens: int = 100000 ) -> Sequence[BaseMessage]: - return trim_messages( - messages, - token_counter=tiktoken_counter, - strategy="last", - max_tokens=max_tokens, - start_on="human", - end_on=("human", "tool"), - include_system=True, - ) + return trim_messages( + messages, + token_counter=tiktoken_counter, + strategy="last", + max_tokens=max_tokens, + start_on="human", + end_on=("human", "tool"), + include_system=True, + ) def extract_ai_responses(messages: Sequence[BaseMessage]) -> Sequence[str]: - ai_responses = [] - output_parser = StrOutputParser() - for index, message in enumerate(messages): - if isinstance(message, AIMessage) and ( - index == len(messages) - 1 or isinstance(messages[index + 1], HumanMessage) - ): - ai_responses.append(output_parser.invoke(message)) - return ai_responses + ai_responses = [] + output_parser = StrOutputParser() + for index, message in enumerate(messages): + if isinstance(message, AIMessage) and ( + index == len(messages) - 1 or isinstance(messages[index + 1], HumanMessage) + ): + ai_responses.append(output_parser.invoke(message)) + return ai_responses def extract_human_queries(messages: Sequence[BaseMessage]) 
-> Sequence[str]: - human_queries = [] - output_parser = StrOutputParser() - for message in messages: - if isinstance(message, HumanMessage): - human_queries.append(output_parser.invoke(message)) - return human_queries + human_queries = [] + output_parser = StrOutputParser() + for message in messages: + if isinstance(message, HumanMessage): + human_queries.append(output_parser.invoke(message)) + return human_queries def extract_last_tool_messages(messages: Sequence[BaseMessage]) -> Sequence[ToolMessage]: - tool_messages = [] - last_human_index = -1 - for i in range(len(messages) - 1, -1, -1): - if isinstance(messages[i], HumanMessage): - last_human_index = i - break + tool_messages = [] + last_human_index = -1 + for i in range(len(messages) - 1, -1, -1): + if isinstance(messages[i], HumanMessage): + last_human_index = i + break - if last_human_index == -1: - return [] + if last_human_index == -1: + return [] - for message in messages[last_human_index + 1 :]: - if isinstance(message, ToolMessage): - tool_messages.append(message) - return tool_messages + for message in messages[last_human_index + 1 :]: + if isinstance(message, ToolMessage): + tool_messages.append(message) + return tool_messages def get_last_message_content(messages: Sequence[BaseMessage]) -> str: - output_parser = StrOutputParser() - return output_parser.invoke(messages[-1]) + output_parser = StrOutputParser() + return output_parser.invoke(messages[-1]) def format_agent_tool_message_history(messages: Sequence[BaseMessage]) -> str: - formatted_messages = [] - for message in messages: - if isinstance(message, AIMessage): - if message.content: - formatted_messages.append(f"Assistant internal thought: {message.content}") - if ( - message.additional_kwargs - and "tool_calls" in message.additional_kwargs - and message.additional_kwargs["tool_calls"] - ): - for tool_call in message.additional_kwargs["tool_calls"]: - formatted_messages.append(f"Assistant executed tool: {tool_call['function']}") - elif 
isinstance(message, ToolMessage): - formatted_messages.append(f"Tool output: {message.content}") - return "\n\n".join(formatted_messages) + formatted_messages = [] + for message in messages: + if isinstance(message, AIMessage): + if message.content: + formatted_messages.append(f"Assistant internal thought: {message.content}") + if ( + message.additional_kwargs + and "tool_calls" in message.additional_kwargs + and message.additional_kwargs["tool_calls"] + ): + for tool_call in message.additional_kwargs["tool_calls"]: + formatted_messages.append(f"Assistant executed tool: {tool_call['function']}") + elif isinstance(message, ToolMessage): + formatted_messages.append(f"Tool output: {message.content}") + return "\n\n".join(formatted_messages) diff --git a/prometheus/utils/neo4j_util.py b/prometheus/utils/neo4j_util.py index 974fd467..a928e541 100644 --- a/prometheus/utils/neo4j_util.py +++ b/prometheus/utils/neo4j_util.py @@ -8,70 +8,70 @@ def format_neo4j_data(data: Sequence[Mapping[str, Any]], max_token_per_result: int) -> str: - """Format a Neo4j result into a string. + """Format a Neo4j result into a string. - Args: - result: The result from a Neo4j query. - max_token_per_result: Maximum number of tokens per result. + Args: + result: The result from a Neo4j query. + max_token_per_result: Maximum number of tokens per result. - Returns: - A string representation of the result. - """ - if not data: - return EMPTY_DATA_MESSAGE + Returns: + A string representation of the result. 
+ """ + if not data: + return EMPTY_DATA_MESSAGE - output = "" - for index, row_result in enumerate(data): - output += f"Result {index + 1}:\n" - for key in sorted(row_result.keys()): - output += f"{key}: {str(row_result[key])}\n" - output += "\n\n" - return truncate_text(output.strip(), max_token_per_result) + output = "" + for index, row_result in enumerate(data): + output += f"Result {index + 1}:\n" + for key in sorted(row_result.keys()): + output += f"{key}: {str(row_result[key])}\n" + output += "\n\n" + return truncate_text(output.strip(), max_token_per_result) def neo4j_data_for_context_generator(data: Optional[Sequence[Mapping[str, Any]]]) -> Iterator[str]: - if data is None: - return + if data is None: + return - for search_result in data: - search_result_keys = search_result.keys() - if len(search_result_keys) == 1: - continue + for search_result in data: + search_result_keys = search_result.keys() + if len(search_result_keys) == 1: + continue - search_result_components = [f"File: {search_result['FileNode']['relative_path']}"] - for key in search_result: - if key == "FileNode": - continue + search_result_components = [f"File: {search_result['FileNode']['relative_path']}"] + for key in search_result: + if key == "FileNode": + continue - if "start_line" in search_result[key] and "end_line" in search_result[key]: - search_result_components.append( - f"Line number range: {search_result[key]['start_line']} - {search_result[key]['end_line']}" - ) - search_result[key].pop("start_line") - search_result[key].pop("end_line") + if "start_line" in search_result[key] and "end_line" in search_result[key]: + search_result_components.append( + f"Line number range: {search_result[key]['start_line']} - {search_result[key]['end_line']}" + ) + search_result[key].pop("start_line") + search_result[key].pop("end_line") - search_result_components.append(f"{key}: {search_result[key]}") - yield "\n".join(search_result_components) + search_result_components.append(f"{key}: 
{search_result[key]}") + yield "\n".join(search_result_components) def run_neo4j_query( - query: str, driver: neo4j.GraphDatabase.driver, max_token_per_result: int + query: str, driver: neo4j.GraphDatabase.driver, max_token_per_result: int ) -> Tuple[str, Sequence[Mapping[str, Any]]]: - """Run a read-only Neo4j query and format the result into a string. + """Run a read-only Neo4j query and format the result into a string. - Args: - query: The query to run. - driver: The Neo4j driver to use. - max_token_per_result: Maximum number of tokens per result. + Args: + query: The query to run. + driver: The Neo4j driver to use. + max_token_per_result: Maximum number of tokens per result. - Returns: - A string representation of the result. - """ + Returns: + A string representation of the result. + """ - def query_transaction(tx): - result = tx.run(query) - data = result.data() - return format_neo4j_data(data, max_token_per_result), data + def query_transaction(tx): + result = tx.run(query) + data = result.data() + return format_neo4j_data(data, max_token_per_result), data - with driver.session() as session: - return session.execute_read(query_transaction) + with driver.session() as session: + return session.execute_read(query_transaction) diff --git a/prometheus/utils/patch_util.py b/prometheus/utils/patch_util.py index 39324c2c..e97c14d0 100644 --- a/prometheus/utils/patch_util.py +++ b/prometheus/utils/patch_util.py @@ -5,18 +5,18 @@ def get_updated_files(diff: str) -> Tuple[Sequence[Path], Sequence[Path], Sequence[Path]]: - patch = PatchSet(diff) - added_files = [] - modified_files = [] - removed_files = [] + patch = PatchSet(diff) + added_files = [] + modified_files = [] + removed_files = [] - for added_file in patch.added_files: - added_files.append(Path(added_file.path)) + for added_file in patch.added_files: + added_files.append(Path(added_file.path)) - for modified_file in patch.modified_files: - modified_files.append(Path(modified_file.path)) + for modified_file in 
patch.modified_files: + modified_files.append(Path(modified_file.path)) - for removed_file in patch.removed_files: - removed_files.append(Path(removed_file.path)) + for removed_file in patch.removed_files: + removed_files.append(Path(removed_file.path)) - return added_files, modified_files, removed_files + return added_files, modified_files, removed_files diff --git a/prometheus/utils/str_util.py b/prometheus/utils/str_util.py index 409f1387..1f584efd 100644 --- a/prometheus/utils/str_util.py +++ b/prometheus/utils/str_util.py @@ -5,11 +5,11 @@ @lru_cache(maxsize=1) def get_tokenizer(encoding: str = "o200k_base") -> tiktoken.Encoding: - return tiktoken.get_encoding(encoding) + return tiktoken.get_encoding(encoding) def pre_append_line_numbers(text: str, start_line: int) -> str: - return "\n".join([f"{start_line + i}. {line}" for i, line in enumerate(text.splitlines())]) + return "\n".join([f"{start_line + i}. {line}" for i, line in enumerate(text.splitlines())]) TRUNCATED_TEXT = "... Output has been truncated becuase it is too long, please narrow down your query if you wish to see more" @@ -17,11 +17,11 @@ def pre_append_line_numbers(text: str, start_line: int) -> str: def truncate_text(text: str, max_token: int) -> str: - encoder = get_tokenizer() - tokens = encoder.encode(text) - if len(tokens) <= max_token: - return text - - truncated_tokens = tokens[: max_token - TRUNCATED_TEXT_LEN] - truncated_text = encoder.decode(truncated_tokens) - return truncated_text + TRUNCATED_TEXT + encoder = get_tokenizer() + tokens = encoder.encode(text) + if len(tokens) <= max_token: + return text + + truncated_tokens = tokens[: max_token - TRUNCATED_TEXT_LEN] + truncated_text = encoder.decode(truncated_tokens) + return truncated_text + TRUNCATED_TEXT diff --git a/pyproject.toml b/pyproject.toml index d3cbb1a4..3a983f79 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,7 +36,7 @@ test = [ ] [tool.ruff] -indent-width = 2 +indent-width = 4 line-length = 100 [tool.ruff.lint] 
diff --git a/tests/app/api/test_issue.py b/tests/app/api/test_issue.py index 26ef6ac5..ae5e7c1a 100644 --- a/tests/app/api/test_issue.py +++ b/tests/app/api/test_issue.py @@ -14,107 +14,107 @@ @pytest.fixture def mock_service_coordinator(): - service_coordinator = mock.MagicMock() - app.state.service_coordinator = service_coordinator - yield service_coordinator + service_coordinator = mock.MagicMock() + app.state.service_coordinator = service_coordinator + yield service_coordinator def test_answer_issue(mock_service_coordinator): - mock_service_coordinator.exists_knowledge_graph.return_value = True - mock_service_coordinator.answer_issue.return_value = ( - "feature/fix-42", # remote_branch_name - "test patch", # patch - True, # passed_reproducing_test - True, # passed_build - True, # passed_existing_test - "Issue fixed", # issue_response - ) - - response = client.post( - "/issue/answer", - json={ - "issue_number": 42, - "issue_title": "Test Issue", - "issue_body": "Test description", - }, - ) - - assert response.status_code == 200 - assert response.json() == { - "remote_branch_name": "feature/fix-42", - "patch": "test patch", - "passed_reproducing_test": True, - "passed_build": True, - "passed_existing_test": True, - "issue_response": "Issue fixed", - } + mock_service_coordinator.exists_knowledge_graph.return_value = True + mock_service_coordinator.answer_issue.return_value = ( + "feature/fix-42", # remote_branch_name + "test patch", # patch + True, # passed_reproducing_test + True, # passed_build + True, # passed_existing_test + "Issue fixed", # issue_response + ) + + response = client.post( + "/issue/answer", + json={ + "issue_number": 42, + "issue_title": "Test Issue", + "issue_body": "Test description", + }, + ) + + assert response.status_code == 200 + assert response.json() == { + "remote_branch_name": "feature/fix-42", + "patch": "test patch", + "passed_reproducing_test": True, + "passed_build": True, + "passed_existing_test": True, + "issue_response": "Issue 
fixed", + } def test_answer_issue_no_repository(mock_service_coordinator): - mock_service_coordinator.exists_knowledge_graph.return_value = False + mock_service_coordinator.exists_knowledge_graph.return_value = False - response = client.post( - "/issue/answer", - json={"issue_number": 42, "issue_title": "Test Issue", "issue_body": "Test description"}, - ) + response = client.post( + "/issue/answer", + json={"issue_number": 42, "issue_title": "Test Issue", "issue_body": "Test description"}, + ) - assert response.status_code == 404 + assert response.status_code == 404 def test_answer_issue_invalid_container_config(mock_service_coordinator): - mock_service_coordinator.exists_knowledge_graph.return_value = True + mock_service_coordinator.exists_knowledge_graph.return_value = True - response = client.post( - "/issue/answer", - json={ - "issue_number": 42, - "issue_title": "Test Issue", - "issue_body": "Test description", - "dockerfile_content": "FROM python:3.11", - "workdir": None, - }, - ) + response = client.post( + "/issue/answer", + json={ + "issue_number": 42, + "issue_title": "Test Issue", + "issue_body": "Test description", + "dockerfile_content": "FROM python:3.11", + "workdir": None, + }, + ) - assert response.status_code == 400 + assert response.status_code == 400 def test_answer_issue_with_container(mock_service_coordinator): - mock_service_coordinator.exists_knowledge_graph.return_value = True - mock_service_coordinator.answer_issue.return_value = ( - "feature/fix-42", - "test patch", - True, - True, - True, - "Issue fixed", - ) - - test_payload = { - "issue_number": 42, - "issue_title": "Test Issue", - "issue_body": "Test description", - "dockerfile_content": "FROM python:3.11", - "workdir": "/app", - "build_commands": ["pip install -r requirements.txt"], - "test_commands": ["pytest ."], - } - - response = client.post("/issue/answer", json=test_payload) - - assert response.status_code == 200 - mock_service_coordinator.answer_issue.assert_called_once_with( 
- issue_number=42, - issue_title="Test Issue", - issue_body="Test description", - issue_comments=[], - issue_type=IssueType.AUTO, - run_build=False, - run_existing_test=False, - number_of_candidate_patch=4, - dockerfile_content="FROM python:3.11", - image_name=None, - workdir="/app", - build_commands=["pip install -r requirements.txt"], - test_commands=["pytest ."], - push_to_remote=False, - ) + mock_service_coordinator.exists_knowledge_graph.return_value = True + mock_service_coordinator.answer_issue.return_value = ( + "feature/fix-42", + "test patch", + True, + True, + True, + "Issue fixed", + ) + + test_payload = { + "issue_number": 42, + "issue_title": "Test Issue", + "issue_body": "Test description", + "dockerfile_content": "FROM python:3.11", + "workdir": "/app", + "build_commands": ["pip install -r requirements.txt"], + "test_commands": ["pytest ."], + } + + response = client.post("/issue/answer", json=test_payload) + + assert response.status_code == 200 + mock_service_coordinator.answer_issue.assert_called_once_with( + issue_number=42, + issue_title="Test Issue", + issue_body="Test description", + issue_comments=[], + issue_type=IssueType.AUTO, + run_build=False, + run_existing_test=False, + number_of_candidate_patch=4, + dockerfile_content="FROM python:3.11", + image_name=None, + workdir="/app", + build_commands=["pip install -r requirements.txt"], + test_commands=["pytest ."], + push_to_remote=False, + ) diff --git a/tests/app/api/test_repository.py b/tests/app/api/test_repository.py index b53cce52..3f983899 100644 --- a/tests/app/api/test_repository.py +++ b/tests/app/api/test_repository.py @@ -14,30 +14,30 @@ @pytest.fixture def mock_service_coordinator(): - service_coordinator = mock.MagicMock() - app.state.service_coordinator = service_coordinator - yield service_coordinator + service_coordinator = mock.MagicMock() + app.state.service_coordinator = service_coordinator + yield service_coordinator def 
test_upload_local_repository(mock_service_coordinator): - mock_service_coordinator.upload_local_repository.return_value = None - response = client.get( - "/repository/local", - params={"local_repository": test_project_paths.TEST_PROJECT_PATH.absolute().as_posix()}, - ) - assert response.status_code == 200 + mock_service_coordinator.upload_local_repository.return_value = None + response = client.get( + "/repository/local", + params={"local_repository": test_project_paths.TEST_PROJECT_PATH.absolute().as_posix()}, + ) + assert response.status_code == 200 def test_upload_fake_local_repository(): - response = client.get( - "/repository/local/", - params={"local_repository": "/foo/bar/"}, - ) - assert response.status_code == 404 + response = client.get( + "/repository/local/", + params={"local_repository": "/foo/bar/"}, + ) + assert response.status_code == 404 def test_delete(mock_service_coordinator): - mock_service_coordinator.exists_knowledge_graph.return_value = True - mock_service_coordinator.clear.return_value = None - response = client.get("repository/delete") - assert response.status_code == 200 + mock_service_coordinator.exists_knowledge_graph.return_value = True + mock_service_coordinator.clear.return_value = None + response = client.get("repository/delete") + assert response.status_code == 200 diff --git a/tests/app/services/test_issue_service.py b/tests/app/services/test_issue_service.py index 8daf8dc4..c3d0291a 100644 --- a/tests/app/services/test_issue_service.py +++ b/tests/app/services/test_issue_service.py @@ -13,165 +13,167 @@ @pytest.fixture def mock_kg_service(): - service = create_autospec(KnowledgeGraphService, instance=True) - service.kg = Mock(name="mock_knowledge_graph") - service.kg.get_local_path.return_value = Path("/mock/path") - return service + service = create_autospec(KnowledgeGraphService, instance=True) + service.kg = Mock(name="mock_knowledge_graph") + service.kg.get_local_path.return_value = Path("/mock/path") + return service 
@pytest.fixture def mock_repository_service(): - service = create_autospec(RepositoryService, instance=True) - service.git_repo = Mock(name="mock_git_repo") - return service + service = create_autospec(RepositoryService, instance=True) + service.git_repo = Mock(name="mock_git_repo") + return service @pytest.fixture def mock_neo4j_service(): - service = create_autospec(Neo4jService, instance=True) - service.neo4j_driver = Mock(name="mock_neo4j_driver") - return service + service = create_autospec(Neo4jService, instance=True) + service.neo4j_driver = Mock(name="mock_neo4j_driver") + return service @pytest.fixture def mock_llm_service(): - service = create_autospec(LLMService, instance=True) - service.advanced_model = "gpt-4" - service.base_model = "gpt-3.5-turbo" - return service + service = create_autospec(LLMService, instance=True) + service.advanced_model = "gpt-4" + service.base_model = "gpt-3.5-turbo" + return service @pytest.fixture def issue_service(mock_kg_service, mock_repository_service, mock_neo4j_service, mock_llm_service): - return IssueService( - mock_kg_service, - mock_repository_service, - mock_neo4j_service, - mock_llm_service, - max_token_per_neo4j_result=1000, - ) + return IssueService( + mock_kg_service, + mock_repository_service, + mock_neo4j_service, + mock_llm_service, + max_token_per_neo4j_result=1000, + ) def test_answer_issue_with_general_container(issue_service, monkeypatch): - # Setup - mock_issue_graph = Mock() - mock_issue_graph_class = Mock(return_value=mock_issue_graph) - monkeypatch.setattr("prometheus.app.services.issue_service.IssueGraph", mock_issue_graph_class) - - mock_container = Mock() - mock_general_container_class = Mock(return_value=mock_container) - monkeypatch.setattr( - "prometheus.app.services.issue_service.GeneralContainer", mock_general_container_class - ) - - # Mock output state for bug type - mock_output_state = { - "issue_type": IssueType.BUG, - "edit_patch": "test_patch", - "passed_reproducing_test": True, - 
"passed_build": True, - "passed_existing_test": True, - "issue_response": "test_response", - } - mock_issue_graph.invoke.return_value = mock_output_state - - # Exercise - result = issue_service.answer_issue( - issue_title="Test Issue", - issue_body="Test Body", - issue_comments=[], - issue_type=IssueType.BUG, - run_build=True, - run_existing_test=True, - number_of_candidate_patch=1, - ) - - # Verify - mock_general_container_class.assert_called_once_with(issue_service.kg_service.kg.get_local_path()) - mock_issue_graph_class.assert_called_once_with( - advanced_model=issue_service.llm_service.advanced_model, - base_model=issue_service.llm_service.base_model, - kg=issue_service.kg_service.kg, - git_repo=issue_service.repository_service.git_repo, - neo4j_driver=issue_service.neo4j_service.neo4j_driver, - max_token_per_neo4j_result=issue_service.max_token_per_neo4j_result, - container=mock_container, - build_commands=None, - test_commands=None, - ) - assert result == ("test_patch", True, True, True, "test_response") + # Setup + mock_issue_graph = Mock() + mock_issue_graph_class = Mock(return_value=mock_issue_graph) + monkeypatch.setattr("prometheus.app.services.issue_service.IssueGraph", mock_issue_graph_class) + + mock_container = Mock() + mock_general_container_class = Mock(return_value=mock_container) + monkeypatch.setattr( + "prometheus.app.services.issue_service.GeneralContainer", mock_general_container_class + ) + + # Mock output state for bug type + mock_output_state = { + "issue_type": IssueType.BUG, + "edit_patch": "test_patch", + "passed_reproducing_test": True, + "passed_build": True, + "passed_existing_test": True, + "issue_response": "test_response", + } + mock_issue_graph.invoke.return_value = mock_output_state + + # Exercise + result = issue_service.answer_issue( + issue_title="Test Issue", + issue_body="Test Body", + issue_comments=[], + issue_type=IssueType.BUG, + run_build=True, + run_existing_test=True, + number_of_candidate_patch=1, + ) + + # Verify + 
mock_general_container_class.assert_called_once_with( + issue_service.kg_service.kg.get_local_path() + ) + mock_issue_graph_class.assert_called_once_with( + advanced_model=issue_service.llm_service.advanced_model, + base_model=issue_service.llm_service.base_model, + kg=issue_service.kg_service.kg, + git_repo=issue_service.repository_service.git_repo, + neo4j_driver=issue_service.neo4j_service.neo4j_driver, + max_token_per_neo4j_result=issue_service.max_token_per_neo4j_result, + container=mock_container, + build_commands=None, + test_commands=None, + ) + assert result == ("test_patch", True, True, True, "test_response") def test_answer_issue_with_user_defined_container(issue_service, monkeypatch): - # Setup - mock_issue_graph = Mock() - mock_issue_graph_class = Mock(return_value=mock_issue_graph) - monkeypatch.setattr("prometheus.app.services.issue_service.IssueGraph", mock_issue_graph_class) - - mock_container = Mock() - mock_user_container_class = Mock(return_value=mock_container) - monkeypatch.setattr( - "prometheus.app.services.issue_service.UserDefinedContainer", mock_user_container_class - ) - - # Mock output state for question type - mock_output_state = {"issue_type": IssueType.QUESTION, "issue_response": "test_response"} - mock_issue_graph.invoke.return_value = mock_output_state - - # Exercise - result = issue_service.answer_issue( - issue_title="Test Issue", - issue_body="Test Body", - issue_comments=[], - issue_type=IssueType.QUESTION, - run_build=True, - run_existing_test=True, - number_of_candidate_patch=1, - dockerfile_content="FROM python:3.8", - image_name="test-image", - workdir="/app", - build_commands=["pip install -r requirements.txt"], - test_commands=["pytest"], - ) - - # Verify - mock_user_container_class.assert_called_once_with( - issue_service.kg_service.kg.get_local_path(), - "/app", - ["pip install -r requirements.txt"], - ["pytest"], - "FROM python:3.8", - "test-image", - ) - assert result == ("", False, False, False, "test_response") + # 
Setup + mock_issue_graph = Mock() + mock_issue_graph_class = Mock(return_value=mock_issue_graph) + monkeypatch.setattr("prometheus.app.services.issue_service.IssueGraph", mock_issue_graph_class) + + mock_container = Mock() + mock_user_container_class = Mock(return_value=mock_container) + monkeypatch.setattr( + "prometheus.app.services.issue_service.UserDefinedContainer", mock_user_container_class + ) + + # Mock output state for question type + mock_output_state = {"issue_type": IssueType.QUESTION, "issue_response": "test_response"} + mock_issue_graph.invoke.return_value = mock_output_state + + # Exercise + result = issue_service.answer_issue( + issue_title="Test Issue", + issue_body="Test Body", + issue_comments=[], + issue_type=IssueType.QUESTION, + run_build=True, + run_existing_test=True, + number_of_candidate_patch=1, + dockerfile_content="FROM python:3.8", + image_name="test-image", + workdir="/app", + build_commands=["pip install -r requirements.txt"], + test_commands=["pytest"], + ) + + # Verify + mock_user_container_class.assert_called_once_with( + issue_service.kg_service.kg.get_local_path(), + "/app", + ["pip install -r requirements.txt"], + ["pytest"], + "FROM python:3.8", + "test-image", + ) + assert result == ("", False, False, False, "test_response") def test_answer_issue_unknown_type(issue_service, monkeypatch): - # Setup - mock_issue_graph = Mock() - mock_issue_graph_class = Mock(return_value=mock_issue_graph) - monkeypatch.setattr("prometheus.app.services.issue_service.IssueGraph", mock_issue_graph_class) - - mock_container = Mock() - mock_user_container_class = Mock(return_value=mock_container) - monkeypatch.setattr( - "prometheus.app.services.issue_service.GeneralContainer", mock_user_container_class - ) - - # Mock output state for unknown type - mock_output_state = {"issue_type": "UNKNOWN"} - mock_issue_graph.invoke.return_value = mock_output_state - - # Exercise - result = issue_service.answer_issue( - issue_title="Test Issue", - 
issue_body="Test Body", - issue_comments=[], - issue_type="UNKNOWN", - run_build=True, - run_existing_test=True, - number_of_candidate_patch=1, - ) - - # Verify - assert result == ("", False, False, False, "") + # Setup + mock_issue_graph = Mock() + mock_issue_graph_class = Mock(return_value=mock_issue_graph) + monkeypatch.setattr("prometheus.app.services.issue_service.IssueGraph", mock_issue_graph_class) + + mock_container = Mock() + mock_user_container_class = Mock(return_value=mock_container) + monkeypatch.setattr( + "prometheus.app.services.issue_service.GeneralContainer", mock_user_container_class + ) + + # Mock output state for unknown type + mock_output_state = {"issue_type": "UNKNOWN"} + mock_issue_graph.invoke.return_value = mock_output_state + + # Exercise + result = issue_service.answer_issue( + issue_title="Test Issue", + issue_body="Test Body", + issue_comments=[], + issue_type="UNKNOWN", + run_build=True, + run_existing_test=True, + number_of_candidate_patch=1, + ) + + # Verify + assert result == ("", False, False, False, "") diff --git a/tests/app/services/test_knowledge_graph_service.py b/tests/app/services/test_knowledge_graph_service.py index b3ef9bea..0a6ec1f7 100644 --- a/tests/app/services/test_knowledge_graph_service.py +++ b/tests/app/services/test_knowledge_graph_service.py @@ -10,131 +10,152 @@ @pytest.fixture def mock_neo4j_service(): - service = create_autospec(Neo4jService, instance=True) - service.neo4j_driver = Mock(name="mock_neo4j_driver") - return service + service = create_autospec(Neo4jService, instance=True) + service.neo4j_driver = Mock(name="mock_neo4j_driver") + return service @pytest.fixture def mock_knowledge_graph(): - return create_autospec(KnowledgeGraph, instance=True) + return create_autospec(KnowledgeGraph, instance=True) @pytest.fixture def mock_kg_handler(monkeypatch): - mock = create_autospec(KnowledgeGraphHandler, instance=True) - monkeypatch.setattr( - 
"prometheus.neo4j.knowledge_graph_handler.KnowledgeGraphHandler", lambda *args, **kwargs: mock - ) - return mock + mock = create_autospec(KnowledgeGraphHandler, instance=True) + monkeypatch.setattr( + "prometheus.neo4j.knowledge_graph_handler.KnowledgeGraphHandler", + lambda *args, **kwargs: mock, + ) + return mock def test_init_with_existing_graph(mock_neo4j_service, mock_kg_handler, mock_knowledge_graph): - # Setup - mock_kg_handler.knowledge_graph_exists.return_value = True - mock_kg_handler.read_knowledge_graph.return_value = mock_knowledge_graph + # Setup + mock_kg_handler.knowledge_graph_exists.return_value = True + mock_kg_handler.read_knowledge_graph.return_value = mock_knowledge_graph - # Exercise - service = KnowledgeGraphService( - mock_neo4j_service, neo4j_batch_size=100, max_ast_depth=5, chunk_size=1000, chunk_overlap=100 - ) + # Exercise + service = KnowledgeGraphService( + mock_neo4j_service, + neo4j_batch_size=100, + max_ast_depth=5, + chunk_size=1000, + chunk_overlap=100, + ) - # Verify - mock_kg_handler.knowledge_graph_exists.assert_called_once() - mock_kg_handler.read_knowledge_graph.assert_called_once() - assert service.kg == mock_knowledge_graph + # Verify + mock_kg_handler.knowledge_graph_exists.assert_called_once() + mock_kg_handler.read_knowledge_graph.assert_called_once() + assert service.kg == mock_knowledge_graph def test_init_without_existing_graph(mock_neo4j_service, mock_kg_handler): - # Setup - mock_kg_handler.knowledge_graph_exists.return_value = False + # Setup + mock_kg_handler.knowledge_graph_exists.return_value = False - # Exercise - service = KnowledgeGraphService( - mock_neo4j_service, neo4j_batch_size=100, max_ast_depth=5, chunk_size=1000, chunk_overlap=100 - ) + # Exercise + service = KnowledgeGraphService( + mock_neo4j_service, + neo4j_batch_size=100, + max_ast_depth=5, + chunk_size=1000, + chunk_overlap=100, + ) - # Verify - mock_kg_handler.knowledge_graph_exists.assert_called_once() - 
mock_kg_handler.read_knowledge_graph.assert_not_called() - assert service.kg is None + # Verify + mock_kg_handler.knowledge_graph_exists.assert_called_once() + mock_kg_handler.read_knowledge_graph.assert_not_called() + assert service.kg is None def test_build_and_save_new_knowledge_graph(mock_neo4j_service, mock_kg_handler, monkeypatch): - # Setup - mock_kg_handler.knowledge_graph_exists.return_value = False - - # Mock KnowledgeGraph constructor and instance - mock_kg = Mock() - mock_kg_class = Mock(return_value=mock_kg) - monkeypatch.setattr( - "prometheus.app.services.knowledge_graph_service.KnowledgeGraph", mock_kg_class - ) - mock_kg_class.build_graph.return_value = None - mock_path = "/foo/bar" - mock_kg.get_local_path.return_value = mock_path - - # Exercise - service = KnowledgeGraphService( - mock_neo4j_service, neo4j_batch_size=100, max_ast_depth=5, chunk_size=1000, chunk_overlap=100 - ) - service.build_and_save_knowledge_graph(mock_path) - - # Verify - mock_kg_class.assert_called_once_with( - service.max_ast_depth, service.chunk_size, service.chunk_overlap - ) - mock_kg.build_graph.assert_called_once_with(mock_path, None, None) - mock_kg_handler.write_knowledge_graph.assert_called_once_with(mock_kg) - assert service.kg == mock_kg + # Setup + mock_kg_handler.knowledge_graph_exists.return_value = False + + # Mock KnowledgeGraph constructor and instance + mock_kg = Mock() + mock_kg_class = Mock(return_value=mock_kg) + monkeypatch.setattr( + "prometheus.app.services.knowledge_graph_service.KnowledgeGraph", mock_kg_class + ) + mock_kg_class.build_graph.return_value = None + mock_path = "/foo/bar" + mock_kg.get_local_path.return_value = mock_path + + # Exercise + service = KnowledgeGraphService( + mock_neo4j_service, + neo4j_batch_size=100, + max_ast_depth=5, + chunk_size=1000, + chunk_overlap=100, + ) + service.build_and_save_knowledge_graph(mock_path) + + # Verify + mock_kg_class.assert_called_once_with( + service.max_ast_depth, service.chunk_size, 
service.chunk_overlap + ) + mock_kg.build_graph.assert_called_once_with(mock_path, None, None) + mock_kg_handler.write_knowledge_graph.assert_called_once_with(mock_kg) + assert service.kg == mock_kg def test_build_and_save_clear_existing_knowledge_graph( - mock_neo4j_service, mock_kg_handler, monkeypatch + mock_neo4j_service, mock_kg_handler, monkeypatch ): - # Setup - mock_kg_handler.knowledge_graph_exists.return_value = True - - # Mock KnowledgeGraph constructor and instance - mock_kg = Mock() - mock_kg_class = Mock(return_value=mock_kg) - monkeypatch.setattr( - "prometheus.app.services.knowledge_graph_service.KnowledgeGraph", mock_kg_class - ) - mock_kg_class.build_graph.return_value = None - mock_path = "/foo/bar" - mock_kg.get_local_path.return_value = mock_path - - # Exercise - service = KnowledgeGraphService( - mock_neo4j_service, neo4j_batch_size=100, max_ast_depth=5, chunk_size=1000, chunk_overlap=100 - ) - service.build_and_save_knowledge_graph(mock_path) - - # Verify - mock_kg_handler.clear_knowledge_graph.assert_called_once() - mock_kg_class.assert_called_once_with( - service.max_ast_depth, service.chunk_size, service.chunk_overlap - ) - mock_kg.build_graph.assert_called_once_with(mock_path, None, None) - mock_kg_handler.write_knowledge_graph.assert_called_once_with(mock_kg) - assert service.kg == mock_kg + # Setup + mock_kg_handler.knowledge_graph_exists.return_value = True + + # Mock KnowledgeGraph constructor and instance + mock_kg = Mock() + mock_kg_class = Mock(return_value=mock_kg) + monkeypatch.setattr( + "prometheus.app.services.knowledge_graph_service.KnowledgeGraph", mock_kg_class + ) + mock_kg_class.build_graph.return_value = None + mock_path = "/foo/bar" + mock_kg.get_local_path.return_value = mock_path + + # Exercise + service = KnowledgeGraphService( + mock_neo4j_service, + neo4j_batch_size=100, + max_ast_depth=5, + chunk_size=1000, + chunk_overlap=100, + ) + service.build_and_save_knowledge_graph(mock_path) + + # Verify + 
mock_kg_handler.clear_knowledge_graph.assert_called_once() + mock_kg_class.assert_called_once_with( + service.max_ast_depth, service.chunk_size, service.chunk_overlap + ) + mock_kg.build_graph.assert_called_once_with(mock_path, None, None) + mock_kg_handler.write_knowledge_graph.assert_called_once_with(mock_kg) + assert service.kg == mock_kg def test_clear_calls_handler_and_resets_kg( - mock_neo4j_service, mock_kg_handler, mock_knowledge_graph + mock_neo4j_service, mock_kg_handler, mock_knowledge_graph ): - # Setup - mock_kg_handler.knowledge_graph_exists.return_value = True - mock_kg_handler.read_knowledge_graph.return_value = mock_knowledge_graph - mock_knowledge_graph.get_local_path.return_value = "/mock/path" - - # Exercise - service = KnowledgeGraphService( - mock_neo4j_service, neo4j_batch_size=100, max_ast_depth=5, chunk_size=1000, chunk_overlap=100 - ) - service.clear() - - # Verify - mock_kg_handler.clear_knowledge_graph.assert_called_once() - assert service.kg is None + # Setup + mock_kg_handler.knowledge_graph_exists.return_value = True + mock_kg_handler.read_knowledge_graph.return_value = mock_knowledge_graph + mock_knowledge_graph.get_local_path.return_value = "/mock/path" + + # Exercise + service = KnowledgeGraphService( + mock_neo4j_service, + neo4j_batch_size=100, + max_ast_depth=5, + chunk_size=1000, + chunk_overlap=100, + ) + service.clear() + + # Verify + mock_kg_handler.clear_knowledge_graph.assert_called_once() + assert service.kg is None diff --git a/tests/app/services/test_llm_service.py b/tests/app/services/test_llm_service.py index 180dfce8..054d1814 100644 --- a/tests/app/services/test_llm_service.py +++ b/tests/app/services/test_llm_service.py @@ -7,113 +7,113 @@ @pytest.fixture def mock_chat_openai(): - with patch("prometheus.app.services.llm_service.ChatOpenAI") as mock: - yield mock + with patch("prometheus.app.services.llm_service.ChatOpenAI") as mock: + yield mock @pytest.fixture def mock_custom_chat_openai(): - with 
patch("prometheus.app.services.llm_service.CustomChatOpenAI") as mock: - yield mock + with patch("prometheus.app.services.llm_service.CustomChatOpenAI") as mock: + yield mock @pytest.fixture def mock_chat_anthropic(): - with patch("prometheus.app.services.llm_service.ChatAnthropic") as mock: - yield mock + with patch("prometheus.app.services.llm_service.ChatAnthropic") as mock: + yield mock @pytest.fixture def mock_chat_google(): - with patch("prometheus.app.services.llm_service.ChatGoogleGenerativeAI") as mock: - yield mock + with patch("prometheus.app.services.llm_service.ChatGoogleGenerativeAI") as mock: + yield mock def test_llm_service_init(mock_custom_chat_openai, mock_chat_anthropic): - # Setup - mock_gpt_instance = Mock() - mock_claude_instance = Mock() - mock_custom_chat_openai.return_value = mock_gpt_instance - mock_chat_anthropic.return_value = mock_claude_instance - - # Exercise - service = LLMService( - advanced_model_name="gpt-4", - base_model_name="claude-2.1", - openai_api_key="openai-key", - anthropic_api_key="anthropic-key", - ) - - # Verify - assert service.advanced_model == mock_gpt_instance - assert service.base_model == mock_claude_instance - mock_custom_chat_openai.assert_called_once_with( - model="gpt-4", api_key="openai-key", temperature=0.0, max_tokens=None, max_retries=3 - ) - mock_chat_anthropic.assert_called_once_with( - model="claude-2.1", api_key="anthropic-key", temperature=0.0, max_tokens=8192, max_retries=3 - ) + # Setup + mock_gpt_instance = Mock() + mock_claude_instance = Mock() + mock_custom_chat_openai.return_value = mock_gpt_instance + mock_chat_anthropic.return_value = mock_claude_instance + + # Exercise + service = LLMService( + advanced_model_name="gpt-4", + base_model_name="claude-2.1", + openai_api_key="openai-key", + anthropic_api_key="anthropic-key", + ) + + # Verify + assert service.advanced_model == mock_gpt_instance + assert service.base_model == mock_claude_instance + mock_custom_chat_openai.assert_called_once_with( 
+ model="gpt-4", api_key="openai-key", temperature=0.0, max_tokens=None, max_retries=3 + ) + mock_chat_anthropic.assert_called_once_with( + model="claude-2.1", api_key="anthropic-key", temperature=0.0, max_tokens=8192, max_retries=3 + ) def test_get_model_openrouter(mock_chat_openai): - # Exercise - get_model(model_name="openrouter/model", open_router_api_key="openrouter-key") + # Exercise + get_model(model_name="openrouter/model", open_router_api_key="openrouter-key") - # Verify - mock_chat_openai.assert_called_once_with( - model="openrouter/model", - api_key="openrouter-key", - base_url="https://openrouter.ai/api/v1", - temperature=0.0, - max_tokens=None, - max_retries=3, - ) + # Verify + mock_chat_openai.assert_called_once_with( + model="openrouter/model", + api_key="openrouter-key", + base_url="https://openrouter.ai/api/v1", + temperature=0.0, + max_tokens=None, + max_retries=3, + ) def test_get_model_claude(mock_chat_anthropic): - # Exercise - get_model(model_name="claude-2.1", anthropic_api_key="anthropic-key") + # Exercise + get_model(model_name="claude-2.1", anthropic_api_key="anthropic-key") - # Verify - mock_chat_anthropic.assert_called_once_with( - model="claude-2.1", api_key="anthropic-key", temperature=0.0, max_tokens=8192, max_retries=3 - ) + # Verify + mock_chat_anthropic.assert_called_once_with( + model="claude-2.1", api_key="anthropic-key", temperature=0.0, max_tokens=8192, max_retries=3 + ) def test_get_model_gpt(mock_custom_chat_openai): - # Exercise - get_model(model_name="gpt-4", openai_api_key="openai-key") + # Exercise + get_model(model_name="gpt-4", openai_api_key="openai-key") - # Verify - mock_custom_chat_openai.assert_called_once_with( - model="gpt-4", api_key="openai-key", temperature=0.0, max_tokens=None, max_retries=3 - ) + # Verify + mock_custom_chat_openai.assert_called_once_with( + model="gpt-4", api_key="openai-key", temperature=0.0, max_tokens=None, max_retries=3 + ) def test_get_model_gemini(mock_chat_google): - # Exercise - 
get_model(model_name="gemini-pro", gemini_api_key="gemini-key") + # Exercise + get_model(model_name="gemini-pro", gemini_api_key="gemini-key") - # Verify - mock_chat_google.assert_called_once_with( - model="gemini-pro", api_key="gemini-key", temperature=0.0, max_tokens=None, max_retries=3 - ) + # Verify + mock_chat_google.assert_called_once_with( + model="gemini-pro", api_key="gemini-key", temperature=0.0, max_tokens=None, max_retries=3 + ) def test_get_model_unknown(): - # Exercise & Verify - with pytest.raises(ValueError, match="Unknown model name: unknown-model"): - get_model("unknown-model") + # Exercise & Verify + with pytest.raises(ValueError, match="Unknown model name: unknown-model"): + get_model("unknown-model") def test_custom_chat_openai_bind_tools(): - # Setup - model = CustomChatOpenAI(api_key="test-key") - mock_tools = [Mock()] + # Setup + model = CustomChatOpenAI(api_key="test-key") + mock_tools = [Mock()] - # Exercise - with patch("prometheus.app.services.llm_service.ChatOpenAI.bind_tools") as mock_bind: - model.bind_tools(mock_tools) + # Exercise + with patch("prometheus.app.services.llm_service.ChatOpenAI.bind_tools") as mock_bind: + model.bind_tools(mock_tools) - # Verify - mock_bind.assert_called_once_with(mock_tools, tool_choice=None, parallel_tool_calls=False) + # Verify + mock_bind.assert_called_once_with(mock_tools, tool_choice=None, parallel_tool_calls=False) diff --git a/tests/app/services/test_neo4j_service.py b/tests/app/services/test_neo4j_service.py index 2841fd43..2a795d83 100644 --- a/tests/app/services/test_neo4j_service.py +++ b/tests/app/services/test_neo4j_service.py @@ -6,12 +6,12 @@ @pytest.mark.slow def test_neo4j_service(neo4j_container_with_kg_fixture): # noqa: F811 - neo4j_container, _ = neo4j_container_with_kg_fixture - neo4j_service = Neo4jService( - neo4j_container.get_connection_url(), neo4j_container.username, neo4j_container.password - ) - assert neo4j_service.neo4j_driver is not None - try: - 
neo4j_service.neo4j_driver.verify_connectivity() - except Exception as e: - pytest.fail(f"Connection verification failed: {e}") + neo4j_container, _ = neo4j_container_with_kg_fixture + neo4j_service = Neo4jService( + neo4j_container.get_connection_url(), neo4j_container.username, neo4j_container.password + ) + assert neo4j_service.neo4j_driver is not None + try: + neo4j_service.neo4j_driver.verify_connectivity() + except Exception as e: + pytest.fail(f"Connection verification failed: {e}") diff --git a/tests/app/services/test_repository_service.py b/tests/app/services/test_repository_service.py index 0923d83d..99d1af27 100644 --- a/tests/app/services/test_repository_service.py +++ b/tests/app/services/test_repository_service.py @@ -11,70 +11,70 @@ @pytest.fixture def mock_kg_service(): - service = create_autospec(KnowledgeGraphService, instance=True) - service.get_local_path.return_value = None - return service + service = create_autospec(KnowledgeGraphService, instance=True) + service.get_local_path.return_value = None + return service @pytest.fixture def mock_git_repository(): - repo = create_autospec(GitRepository, instance=True) - repo.get_working_directory.return_value = Path("/test/working/dir/repositories/repo") - return repo + repo = create_autospec(GitRepository, instance=True) + repo.get_working_directory.return_value = Path("/test/working/dir/repositories/repo") + return repo @pytest.fixture def mock_knowledge_graph(): - kg = create_autospec(KnowledgeGraph, instance=True) - kg.is_built_from_github.return_value = False - return kg + kg = create_autospec(KnowledgeGraph, instance=True) + kg.is_built_from_github.return_value = False + return kg @pytest.fixture def service(mock_kg_service): - working_dir = Path("/test/working/dir") - # Don't mock GitRepository in the fixture anymore - with patch("pathlib.Path.mkdir"): - return RepositoryService( - kg_service=mock_kg_service, - working_dir=working_dir, - ) + working_dir = Path("/test/working/dir") + # Don't 
mock GitRepository in the fixture anymore + with patch("pathlib.Path.mkdir"): + return RepositoryService( + kg_service=mock_kg_service, + working_dir=working_dir, + ) def test_clone_new_github_repo(service, mock_kg_service, mock_git_repository): - # Setup - test_url = "https://github.com/test/repo" - test_commit = "abc123" - test_github_token = "test_token" - expected_path = Path("/test/working/dir/repositories/repo") - mock_kg_service.local_path = None + # Setup + test_url = "https://github.com/test/repo" + test_commit = "abc123" + test_github_token = "test_token" + expected_path = Path("/test/working/dir/repositories/repo") + mock_kg_service.local_path = None - # Mock the GitRepository class creation - with patch( - "prometheus.app.services.repository_service.GitRepository", return_value=mock_git_repository - ) as mock_git_class: - # Exercise - result_path = service.clone_github_repo(test_github_token, test_url, test_commit) + # Mock the GitRepository class creation + with patch( + "prometheus.app.services.repository_service.GitRepository", return_value=mock_git_repository + ) as mock_git_class: + # Exercise + result_path = service.clone_github_repo(test_github_token, test_url, test_commit) - # Verify - # Check that GitRepository was instantiated with correct parameters - mock_git_class.assert_called_once_with( - test_url, service.target_directory, github_access_token=test_github_token - ) + # Verify + # Check that GitRepository was instantiated with correct parameters + mock_git_class.assert_called_once_with( + test_url, service.target_directory, github_access_token=test_github_token + ) - # Verify checkout_commit was called - mock_git_repository.checkout_commit.assert_called_once_with(test_commit) + # Verify checkout_commit was called + mock_git_repository.checkout_commit.assert_called_once_with(test_commit) - # Verify the returned path matches expected - assert result_path == expected_path + # Verify the returned path matches expected + assert result_path == 
expected_path def test_clean_working_directory(service): - # Setup - with patch("shutil.rmtree") as mock_rmtree, patch("pathlib.Path.mkdir") as mock_mkdir: - # Exercise - service.clean() - - # Verify - mock_rmtree.assert_called_once_with(service.target_directory) - mock_mkdir.assert_called_once_with(parents=True) + # Setup + with patch("shutil.rmtree") as mock_rmtree, patch("pathlib.Path.mkdir") as mock_mkdir: + # Exercise + service.clean() + + # Verify + mock_rmtree.assert_called_once_with(service.target_directory) + mock_mkdir.assert_called_once_with(parents=True) diff --git a/tests/app/services/test_service_coordinator.py b/tests/app/services/test_service_coordinator.py index 03d6ce63..71652628 100644 --- a/tests/app/services/test_service_coordinator.py +++ b/tests/app/services/test_service_coordinator.py @@ -14,102 +14,102 @@ @pytest.fixture def mock_services(): - llm_service = Mock(spec=LLMService) - llm_service.model = Mock() - - issue_service = Mock(sepc=IssueService) - knowledge_graph_service = Mock(spec=KnowledgeGraphService) - knowledge_graph_service.kg = None - knowledge_graph_service.get_local_path.return_value = None - neo4j_service = Mock(spec=Neo4jService) - repository_service = Mock(spec=RepositoryService) - repository_service.get_working_dir.return_value = None - - yield { - "issue_service": issue_service, - "knowledge_graph_service": knowledge_graph_service, - "llm_service": llm_service, - "neo4j_service": neo4j_service, - "max_token_per_neo4j_result": 1000, - "repository_service": repository_service, - "github_token": "test_token", - "working_directory": Path(tempfile.TemporaryDirectory().name), - } + llm_service = Mock(spec=LLMService) + llm_service.model = Mock() + + issue_service = Mock(sepc=IssueService) + knowledge_graph_service = Mock(spec=KnowledgeGraphService) + knowledge_graph_service.kg = None + knowledge_graph_service.get_local_path.return_value = None + neo4j_service = Mock(spec=Neo4jService) + repository_service = 
Mock(spec=RepositoryService) + repository_service.get_working_dir.return_value = None + + yield { + "issue_service": issue_service, + "knowledge_graph_service": knowledge_graph_service, + "llm_service": llm_service, + "neo4j_service": neo4j_service, + "max_token_per_neo4j_result": 1000, + "repository_service": repository_service, + "github_token": "test_token", + "working_directory": Path(tempfile.TemporaryDirectory().name), + } @pytest.fixture def service_coordinator(mock_services): - coordinator = ServiceCoordinator( - mock_services["issue_service"], - mock_services["knowledge_graph_service"], - mock_services["llm_service"], - mock_services["neo4j_service"], - mock_services["repository_service"], - mock_services["max_token_per_neo4j_result"], - mock_services["github_token"], - mock_services["working_directory"], - ) - return coordinator + coordinator = ServiceCoordinator( + mock_services["issue_service"], + mock_services["knowledge_graph_service"], + mock_services["llm_service"], + mock_services["neo4j_service"], + mock_services["repository_service"], + mock_services["max_token_per_neo4j_result"], + mock_services["github_token"], + mock_services["working_directory"], + ) + return coordinator def test_initialization(service_coordinator, mock_services): - """Test that services are properly initialized""" - assert service_coordinator.knowledge_graph_service == mock_services["knowledge_graph_service"] - assert service_coordinator.llm_service == mock_services["llm_service"] - assert service_coordinator.neo4j_service == mock_services["neo4j_service"] - assert service_coordinator.repository_service == mock_services["repository_service"] + """Test that services are properly initialized""" + assert service_coordinator.knowledge_graph_service == mock_services["knowledge_graph_service"] + assert service_coordinator.llm_service == mock_services["llm_service"] + assert service_coordinator.neo4j_service == mock_services["neo4j_service"] + assert 
service_coordinator.repository_service == mock_services["repository_service"] def test_upload_local_repository(service_coordinator, mock_services): - """Test local repository upload process""" - test_path = Path("/test/path") + """Test local repository upload process""" + test_path = Path("/test/path") - service_coordinator.upload_local_repository(test_path) + service_coordinator.upload_local_repository(test_path) - # Verify the sequence of operations - mock_services["knowledge_graph_service"].build_and_save_knowledge_graph.assert_called_once_with( - test_path - ) + # Verify the sequence of operations + mock_services["knowledge_graph_service"].build_and_save_knowledge_graph.assert_called_once_with( + test_path + ) def test_upload_github_repository(service_coordinator, mock_services): - """Test GitHub repository upload process""" - test_url = "https://github.com/test/repo" - test_commit = "abc123" - saved_path = Path("/saved/path") - mock_services["repository_service"].clone_github_repo.return_value = saved_path + """Test GitHub repository upload process""" + test_url = "https://github.com/test/repo" + test_commit = "abc123" + saved_path = Path("/saved/path") + mock_services["repository_service"].clone_github_repo.return_value = saved_path - service_coordinator.upload_github_repository(test_url, test_commit) + service_coordinator.upload_github_repository(test_url, test_commit) - # Verify the sequence of operations - mock_services["repository_service"].clone_github_repo.assert_called_once_with( - mock_services["github_token"], test_url, test_commit - ) - mock_services["knowledge_graph_service"].build_and_save_knowledge_graph.assert_called_once_with( - saved_path, test_url, test_commit - ) + # Verify the sequence of operations + mock_services["repository_service"].clone_github_repo.assert_called_once_with( + mock_services["github_token"], test_url, test_commit + ) + mock_services["knowledge_graph_service"].build_and_save_knowledge_graph.assert_called_once_with( + 
saved_path, test_url, test_commit + ) def test_clear(service_coordinator, mock_services): - """Test clearing of services""" - service_coordinator.clear() + """Test clearing of services""" + service_coordinator.clear() - mock_services["knowledge_graph_service"].clear.assert_called_once() - mock_services["repository_service"].clean.assert_called_once() + mock_services["knowledge_graph_service"].clear.assert_called_once() + mock_services["repository_service"].clean.assert_called_once() def test_close(service_coordinator, mock_services): - """Test proper closing of services""" - service_coordinator.close() + """Test proper closing of services""" + service_coordinator.close() - mock_services["neo4j_service"].close.assert_called_once() + mock_services["neo4j_service"].close.assert_called_once() def test_exists_knowledge_graph(service_coordinator, mock_services): - """Test knowledge graph existence check""" - mock_services["knowledge_graph_service"].exists.return_value = True + """Test knowledge graph existence check""" + mock_services["knowledge_graph_service"].exists.return_value = True - result = service_coordinator.exists_knowledge_graph() + result = service_coordinator.exists_knowledge_graph() - mock_services["knowledge_graph_service"].exists.assert_called_once() - assert result is True + mock_services["knowledge_graph_service"].exists.assert_called_once() + assert result is True diff --git a/tests/app/test_main.py b/tests/app/test_main.py index 2ed55b02..a9a2aed5 100644 --- a/tests/app/test_main.py +++ b/tests/app/test_main.py @@ -7,59 +7,59 @@ @pytest.fixture(autouse=True) def mock_settings(): - """ - Create a mock settings object and patch it before any tests run. - The autouse=True ensures this fixture runs for all tests in the module. 
- """ - mock_settings = { - "LOGGING_LEVEL": "INFO", - "NEO4J_URI": "bolt://localhost:7687", - "NEO4J_USERNAME": "neo4j", - "NEO4J_PASSWORD": "password", - "NEO4J_BATCH_SIZE": 1000, - "KNOWLEDGE_GRAPH_MAX_AST_DEPTH": 5, - "KNOWLEDGE_GRAPH_CHUNK_SIZE": 10000, - "KNOWLEDGE_GRAPH_CHUNK_OVERLAP": 1000, - "MAX_TOKEN_PER_NEO4J_RESULT": 5000, - "ADVANCED_MODEL": "claude-3-5-sonnet-latest", - "BASE_MODEL": "claude-3-5-haiku-latest", - "ANTHROPIC_API_KEY": "ANTHROPIC_API_KEY", - "WORKING_DIRECTORY": "/tmp", - "GITHUB_ACCESS_TOKEN": "GITHUB_ACCESS_TOKEN", - "POSTGRES_URI": "postgresql://postgres:password@localhost:5432/postgres?sslmode=disable", - } + """ + Create a mock settings object and patch it before any tests run. + The autouse=True ensures this fixture runs for all tests in the module. + """ + mock_settings = { + "LOGGING_LEVEL": "INFO", + "NEO4J_URI": "bolt://localhost:7687", + "NEO4J_USERNAME": "neo4j", + "NEO4J_PASSWORD": "password", + "NEO4J_BATCH_SIZE": 1000, + "KNOWLEDGE_GRAPH_MAX_AST_DEPTH": 5, + "KNOWLEDGE_GRAPH_CHUNK_SIZE": 10000, + "KNOWLEDGE_GRAPH_CHUNK_OVERLAP": 1000, + "MAX_TOKEN_PER_NEO4J_RESULT": 5000, + "ADVANCED_MODEL": "claude-3-5-sonnet-latest", + "BASE_MODEL": "claude-3-5-haiku-latest", + "ANTHROPIC_API_KEY": "ANTHROPIC_API_KEY", + "WORKING_DIRECTORY": "/tmp", + "GITHUB_ACCESS_TOKEN": "GITHUB_ACCESS_TOKEN", + "POSTGRES_URI": "postgresql://postgres:password@localhost:5432/postgres?sslmode=disable", + } - # Create a DynaconfDict that properly implements attribute access - settings_obj = DynaconfDict(mock_settings) + # Create a DynaconfDict that properly implements attribute access + settings_obj = DynaconfDict(mock_settings) - # Ensure all settings are accessible as attributes - for key, value in mock_settings.items(): - setattr(settings_obj, key, value) + # Ensure all settings are accessible as attributes + for key, value in mock_settings.items(): + setattr(settings_obj, key, value) - # Patch the settings before importing the app - with 
patch("prometheus.configuration.config.settings", settings_obj): - yield settings_obj + # Patch the settings before importing the app + with patch("prometheus.configuration.config.settings", settings_obj): + yield settings_obj @pytest.fixture def mock_dependencies(): - """Mock the service coordinator dependencies""" - mock_coordinator = MagicMock() - with patch("prometheus.app.dependencies.initialize_services", return_value=mock_coordinator): - yield mock_coordinator + """Mock the service coordinator dependencies""" + mock_coordinator = MagicMock() + with patch("prometheus.app.dependencies.initialize_services", return_value=mock_coordinator): + yield mock_coordinator @pytest.fixture def test_client(mock_settings, mock_dependencies): - """Create a TestClient instance with mocked settings and dependencies""" - # Import app here to ensure settings are properly mocked - from prometheus.app.main import app + """Create a TestClient instance with mocked settings and dependencies""" + # Import app here to ensure settings are properly mocked + from prometheus.app.main import app - with TestClient(app) as client: - yield client + with TestClient(app) as client: + yield client def test_app_initialization(test_client, mock_dependencies): - """Test that the app initializes correctly with mocked dependencies""" - assert test_client.app.state.service_coordinator is not None - assert test_client.app.state.service_coordinator == mock_dependencies + """Test that the app initializes correctly with mocked dependencies""" + assert test_client.app.state.service_coordinator is not None + assert test_client.app.state.service_coordinator == mock_dependencies diff --git a/tests/docker/test_base_container.py b/tests/docker/test_base_container.py index b660c526..400293af 100644 --- a/tests/docker/test_base_container.py +++ b/tests/docker/test_base_container.py @@ -9,172 +9,172 @@ class TestContainer(BaseContainer): - """Concrete implementation of BaseContainer for testing.""" + """Concrete 
implementation of BaseContainer for testing.""" - def get_dockerfile_content(self) -> str: - return "FROM python:3.9\nWORKDIR /app\nCOPY . /app/" + def get_dockerfile_content(self) -> str: + return "FROM python:3.9\nWORKDIR /app\nCOPY . /app/" - def run_build(self): - """Simple implementation for testing""" - return self.execute_command("echo 'Building...'") + def run_build(self): + """Simple implementation for testing""" + return self.execute_command("echo 'Building...'") - def run_test(self): - """Simple implementation for testing""" - return self.execute_command("echo 'Testing...'") + def run_test(self): + """Simple implementation for testing""" + return self.execute_command("echo 'Testing...'") @pytest.fixture def temp_project_dir(): - # Create a temporary directory with some test files - temp_dir = Path(tempfile.mkdtemp()) - test_file = temp_dir / "test.txt" - test_file.write_text("test content") + # Create a temporary directory with some test files + temp_dir = Path(tempfile.mkdtemp()) + test_file = temp_dir / "test.txt" + test_file.write_text("test content") - yield temp_dir + yield temp_dir - # Cleanup - shutil.rmtree(temp_dir) + # Cleanup + shutil.rmtree(temp_dir) @pytest.fixture def mock_docker_client(): - with patch.object(BaseContainer, "client", new_callable=Mock) as mock_client: - yield mock_client + with patch.object(BaseContainer, "client", new_callable=Mock) as mock_client: + yield mock_client @pytest.fixture def container(temp_project_dir, mock_docker_client): - container = TestContainer(temp_project_dir) - container.tag_name = "test_container_tag" - return container + container = TestContainer(temp_project_dir) + container.tag_name = "test_container_tag" + return container def test_get_dockerfile_content(container): - """Test that get_dockerfile_content returns expected content""" - dockerfile_content = container.get_dockerfile_content() + """Test that get_dockerfile_content returns expected content""" + dockerfile_content = 
container.get_dockerfile_content() - assert "FROM python:3.9" in dockerfile_content - assert "WORKDIR /app" in dockerfile_content - assert "COPY . /app/" in dockerfile_content + assert "FROM python:3.9" in dockerfile_content + assert "WORKDIR /app" in dockerfile_content + assert "COPY . /app/" in dockerfile_content def test_build_docker_image(container, mock_docker_client): - """Test building Docker image""" - # Execute - container.build_docker_image() + """Test building Docker image""" + # Execute + container.build_docker_image() - # Verify - assert (container.project_path / "prometheus.Dockerfile").exists() - mock_docker_client.images.build.assert_called_once_with( - path=str(container.project_path), dockerfile="prometheus.Dockerfile", tag=container.tag_name - ) + # Verify + assert (container.project_path / "prometheus.Dockerfile").exists() + mock_docker_client.images.build.assert_called_once_with( + path=str(container.project_path), dockerfile="prometheus.Dockerfile", tag=container.tag_name + ) def test_start_container(container, mock_docker_client): - """Test starting Docker container""" - # Setup mock - mock_containers = Mock() - mock_docker_client.containers = mock_containers + """Test starting Docker container""" + # Setup mock + mock_containers = Mock() + mock_docker_client.containers = mock_containers - # Execute - container.start_container() + # Execute + container.start_container() - # Verify - mock_containers.run.assert_called_once_with( - container.tag_name, - detach=True, - tty=True, - network_mode="host", - environment={"PYTHONPATH": f"{container.workdir}:$PYTHONPATH"}, - volumes={"/var/run/docker.sock": {"bind": "/var/run/docker.sock", "mode": "rw"}}, - ) + # Verify + mock_containers.run.assert_called_once_with( + container.tag_name, + detach=True, + tty=True, + network_mode="host", + environment={"PYTHONPATH": f"{container.workdir}:$PYTHONPATH"}, + volumes={"/var/run/docker.sock": {"bind": "/var/run/docker.sock", "mode": "rw"}}, + ) def 
test_is_running(container): - """Test is_running status check""" - # Test when container is None - assert not container.is_running() + """Test is_running status check""" + # Test when container is None + assert not container.is_running() - # Test when container exists - container.container = Mock() - assert container.is_running() + # Test when container exists + container.container = Mock() + assert container.is_running() def test_update_files(container, temp_project_dir): - """Test updating files in container""" - # Setup - container.container = Mock() - mock_execute = Mock() - container.execute_command = mock_execute + """Test updating files in container""" + # Setup + container.container = Mock() + mock_execute = Mock() + container.execute_command = mock_execute - # Create test files - test_file1 = temp_project_dir / "dir1" / "test1.txt" - test_file2 = temp_project_dir / "dir2" / "test2.txt" - test_file1.parent.mkdir(parents=True) - test_file2.parent.mkdir(parents=True) - test_file1.write_text("test1") - test_file2.write_text("test2") + # Create test files + test_file1 = temp_project_dir / "dir1" / "test1.txt" + test_file2 = temp_project_dir / "dir2" / "test2.txt" + test_file1.parent.mkdir(parents=True) + test_file2.parent.mkdir(parents=True) + test_file1.write_text("test1") + test_file2.write_text("test2") - updated_files = [Path("dir1/test1.txt"), Path("dir2/test2.txt")] - removed_files = [Path("dir3/old.txt")] + updated_files = [Path("dir1/test1.txt"), Path("dir2/test2.txt")] + removed_files = [Path("dir3/old.txt")] - # Execute - container.update_files(temp_project_dir, updated_files, removed_files) + # Execute + container.update_files(temp_project_dir, updated_files, removed_files) - # Verify - mock_execute.assert_has_calls( - [call("rm dir3/old.txt"), call("mkdir -p dir1"), call("mkdir -p dir2")] - ) - assert container.container.put_archive.called + # Verify + mock_execute.assert_has_calls( + [call("rm dir3/old.txt"), call("mkdir -p dir1"), call("mkdir -p 
dir2")] + ) + assert container.container.put_archive.called def test_execute_command(container): - """Test executing command in container""" - # Setup - mock_container = Mock() - mock_exec_result = Mock() - mock_exec_result.exit_code = 0 - mock_exec_result.output = b"command output" - mock_container.exec_run.return_value = mock_exec_result - container.container = mock_container + """Test executing command in container""" + # Setup + mock_container = Mock() + mock_exec_result = Mock() + mock_exec_result.exit_code = 0 + mock_exec_result.output = b"command output" + mock_container.exec_run.return_value = mock_exec_result + container.container = mock_container - # Execute - result = container.execute_command("test command") + # Execute + result = container.execute_command("test command") - # Verify - mock_container.exec_run.assert_called_once_with( - '/bin/bash -l -c "timeout -k 5 120s test command"', workdir=container.workdir - ) - assert result == "command output" + # Verify + mock_container.exec_run.assert_called_once_with( + '/bin/bash -l -c "timeout -k 5 120s test command"', workdir=container.workdir + ) + assert result == "command output" def test_restart_container(container): - """Test container restart""" - # Setup - mock_container = Mock() - container.container = mock_container - container.start_container = Mock() + """Test container restart""" + # Setup + mock_container = Mock() + container.container = mock_container + container.start_container = Mock() - # Execute - container.restart_container() + # Execute + container.restart_container() - # Verify - mock_container.stop.assert_called_once_with(timeout=10) - mock_container.remove.assert_called_once_with(force=True) - container.start_container.assert_called_once() + # Verify + mock_container.stop.assert_called_once_with(timeout=10) + mock_container.remove.assert_called_once_with(force=True) + container.start_container.assert_called_once() def test_cleanup(container, mock_docker_client): - """Test cleanup of 
container resources""" - # Setup - mock_container = Mock() - container.container = mock_container - - # Execute - container.cleanup() - - # Verify - mock_container.stop.assert_called_once_with(timeout=10) - mock_container.remove.assert_called_once_with(force=True) - mock_docker_client.images.remove.assert_called_once_with(container.tag_name, force=True) - assert not container.project_path.exists() + """Test cleanup of container resources""" + # Setup + mock_container = Mock() + container.container = mock_container + + # Execute + container.cleanup() + + # Verify + mock_container.stop.assert_called_once_with(timeout=10) + mock_container.remove.assert_called_once_with(force=True) + mock_docker_client.images.remove.assert_called_once_with(container.tag_name, force=True) + assert not container.project_path.exists() diff --git a/tests/docker/test_general_container.py b/tests/docker/test_general_container.py index f85d6e1c..2b2066cd 100644 --- a/tests/docker/test_general_container.py +++ b/tests/docker/test_general_container.py @@ -9,48 +9,48 @@ @pytest.fixture def temp_project_dir(): - # Create a temporary directory with some test files - temp_dir = Path(tempfile.mkdtemp()) - test_file = temp_dir / "test.txt" - test_file.write_text("test content") + # Create a temporary directory with some test files + temp_dir = Path(tempfile.mkdtemp()) + test_file = temp_dir / "test.txt" + test_file.write_text("test content") - yield temp_dir + yield temp_dir - # Cleanup - shutil.rmtree(temp_dir) + # Cleanup + shutil.rmtree(temp_dir) @pytest.fixture def container(temp_project_dir): - return GeneralContainer(temp_project_dir) + return GeneralContainer(temp_project_dir) def test_initialization(container, temp_project_dir): - """Test that the container is initialized correctly""" - assert isinstance(container.tag_name, str) - assert container.tag_name.startswith("prometheus_general_container_") - assert container.project_path != temp_project_dir - assert (container.project_path / 
"test.txt").exists() + """Test that the container is initialized correctly""" + assert isinstance(container.tag_name, str) + assert container.tag_name.startswith("prometheus_general_container_") + assert container.project_path != temp_project_dir + assert (container.project_path / "test.txt").exists() def test_get_dockerfile_content(container): - dockerfile_content = container.get_dockerfile_content() + dockerfile_content = container.get_dockerfile_content() - assert dockerfile_content + assert dockerfile_content - assert "FROM ubuntu:24.04" in dockerfile_content - assert "WORKDIR /app" in dockerfile_content - assert "RUN apt-get update" in dockerfile_content - assert "COPY . /app/" in dockerfile_content + assert "FROM ubuntu:24.04" in dockerfile_content + assert "WORKDIR /app" in dockerfile_content + assert "RUN apt-get update" in dockerfile_content + assert "COPY . /app/" in dockerfile_content def test_run_build_raises_not_implemented(container): - """Test that run_build raises NotImplementedError""" - with pytest.raises(NotImplementedError): - container.run_build() + """Test that run_build raises NotImplementedError""" + with pytest.raises(NotImplementedError): + container.run_build() def test_run_test_raises_not_implemented(container): - """Test that run_test raises NotImplementedError""" - with pytest.raises(NotImplementedError): - container.run_test() + """Test that run_test raises NotImplementedError""" + with pytest.raises(NotImplementedError): + container.run_test() diff --git a/tests/docker/test_user_defined_container.py b/tests/docker/test_user_defined_container.py index aa9d8449..dbc2447b 100644 --- a/tests/docker/test_user_defined_container.py +++ b/tests/docker/test_user_defined_container.py @@ -10,77 +10,77 @@ @pytest.fixture def temp_project_dir(): - # Create a temporary directory with some test files - temp_dir = Path(tempfile.mkdtemp()) - test_file = temp_dir / "test.txt" - test_file.write_text("test content") + # Create a temporary directory with 
some test files + temp_dir = Path(tempfile.mkdtemp()) + test_file = temp_dir / "test.txt" + test_file.write_text("test content") - yield temp_dir + yield temp_dir - # Cleanup - shutil.rmtree(temp_dir) + # Cleanup + shutil.rmtree(temp_dir) @pytest.fixture def container(temp_project_dir): - return UserDefinedContainer( - temp_project_dir, - "/app", - ["pip install -r requirements.txt", "python setup.py build"], - ["pytest tests/"], - "FROM python:3.9\nWORKDIR /app\nCOPY . /app/", - None, - ) + return UserDefinedContainer( + temp_project_dir, + "/app", + ["pip install -r requirements.txt", "python setup.py build"], + ["pytest tests/"], + "FROM python:3.9\nWORKDIR /app\nCOPY . /app/", + None, + ) def test_initialization(container, temp_project_dir): - """Test that the container is initialized correctly""" - assert isinstance(container.tag_name, str) - assert container.tag_name.startswith("prometheus_user_defined_container_") - assert container.project_path != temp_project_dir - assert (container.project_path / "test.txt").exists() + """Test that the container is initialized correctly""" + assert isinstance(container.tag_name, str) + assert container.tag_name.startswith("prometheus_user_defined_container_") + assert container.project_path != temp_project_dir + assert (container.project_path / "test.txt").exists() def test_get_dockerfile_content(container): - dockerfile_content = container.get_dockerfile_content() + dockerfile_content = container.get_dockerfile_content() - assert dockerfile_content + assert dockerfile_content - # Check for key elements in the Dockerfile - assert "FROM python:3.9" in dockerfile_content - assert "WORKDIR /app" in dockerfile_content - assert "COPY . /app/" in dockerfile_content + # Check for key elements in the Dockerfile + assert "FROM python:3.9" in dockerfile_content + assert "WORKDIR /app" in dockerfile_content + assert "COPY . 
/app/" in dockerfile_content def test_run_build(container): - """Test that build commands are executed correctly""" - container.execute_command = Mock() - container.execute_command.side_effect = ["Output 1", "Output 2"] + """Test that build commands are executed correctly""" + container.execute_command = Mock() + container.execute_command.side_effect = ["Output 1", "Output 2"] - build_output = container.run_build() + build_output = container.run_build() - # Verify execute_command was called for each build command - assert container.execute_command.call_count == 2 - container.execute_command.assert_any_call("pip install -r requirements.txt") - container.execute_command.assert_any_call("python setup.py build") + # Verify execute_command was called for each build command + assert container.execute_command.call_count == 2 + container.execute_command.assert_any_call("pip install -r requirements.txt") + container.execute_command.assert_any_call("python setup.py build") - # Verify output format - expected_output = ( - "$ pip install -r requirements.txt\nOutput 1\n$ python setup.py build\nOutput 2\n" - ) - assert build_output == expected_output + # Verify output format + expected_output = ( + "$ pip install -r requirements.txt\nOutput 1\n$ python setup.py build\nOutput 2\n" + ) + assert build_output == expected_output def test_run_test(container): - """Test that test commands are executed correctly""" - container.execute_command = Mock() - container.execute_command.return_value = "Test passed" + """Test that test commands are executed correctly""" + container.execute_command = Mock() + container.execute_command.return_value = "Test passed" - test_output = container.run_test() + test_output = container.run_test() - # Verify execute_command was called for the test command - container.execute_command.assert_called_once_with("pytest tests/") + # Verify execute_command was called for the test command + container.execute_command.assert_called_once_with("pytest tests/") - # 
Verify output format - expected_output = "$ pytest tests/\nTest passed\n" - assert test_output == expected_output + # Verify output format + expected_output = "$ pytest tests/\nTest passed\n" + assert test_output == expected_output diff --git a/tests/git/test_git_repository.py b/tests/git/test_git_repository.py index 5d3ab854..a48e04c1 100644 --- a/tests/git/test_git_repository.py +++ b/tests/git/test_git_repository.py @@ -10,87 +10,87 @@ @pytest.mark.skipif( - sys.platform.startswith("win"), - reason="Test fails on Windows because of cptree in git_repo_fixture", + sys.platform.startswith("win"), + reason="Test fails on Windows because of cptree in git_repo_fixture", ) @pytest.mark.git def test_init_with_https_url(git_repo_fixture): # noqa: F811 - with mock.patch("git.Repo.clone_from") as mock_clone_from, mock.patch("shutil.rmtree"): - repo = git_repo_fixture - mock_clone_from.return_value = repo + with mock.patch("git.Repo.clone_from") as mock_clone_from, mock.patch("shutil.rmtree"): + repo = git_repo_fixture + mock_clone_from.return_value = repo - access_token = "access_token" - https_url = "https://github.com/foo/bar.git" - target_directory = test_project_paths.TEST_PROJECT_PATH + access_token = "access_token" + https_url = "https://github.com/foo/bar.git" + target_directory = test_project_paths.TEST_PROJECT_PATH - GitRepository( - address=https_url, working_directory=target_directory, github_access_token=access_token - ) + GitRepository( + address=https_url, working_directory=target_directory, github_access_token=access_token + ) - mock_clone_from.assert_called_once_with( - f"https://{access_token}@github.com/foo/bar.git", target_directory / "bar" - ) + mock_clone_from.assert_called_once_with( + f"https://{access_token}@github.com/foo/bar.git", target_directory / "bar" + ) @pytest.mark.skipif( - sys.platform.startswith("win"), - reason="Test fails on Windows because of cptree in git_repo_fixture", + sys.platform.startswith("win"), + reason="Test fails on 
Windows because of cptree in git_repo_fixture", ) @pytest.mark.git def test_checkout_commit(git_repo_fixture): # noqa: F811 - git_repo = GitRepository( - address=git_repo_fixture.working_dir, - working_directory=Path("/foo/bar"), - copy_to_working_dir=False, - ) + git_repo = GitRepository( + address=git_repo_fixture.working_dir, + working_directory=Path("/foo/bar"), + copy_to_working_dir=False, + ) - commit_sha = "293551b7bd9572b63018c9ed2bccea0f37726805" - assert git_repo.repo.head.commit.hexsha != commit_sha - git_repo.checkout_commit(commit_sha) - assert git_repo.repo.head.commit.hexsha == commit_sha + commit_sha = "293551b7bd9572b63018c9ed2bccea0f37726805" + assert git_repo.repo.head.commit.hexsha != commit_sha + git_repo.checkout_commit(commit_sha) + assert git_repo.repo.head.commit.hexsha == commit_sha @pytest.mark.skipif( - sys.platform.startswith("win"), - reason="Test fails on Windows because of cptree in git_repo_fixture", + sys.platform.startswith("win"), + reason="Test fails on Windows because of cptree in git_repo_fixture", ) @pytest.mark.git def test_switch_branch(git_repo_fixture): # noqa: F811 - git_repo = GitRepository( - address=git_repo_fixture.working_dir, - working_directory=Path("/foo/bar"), - copy_to_working_dir=False, - ) + git_repo = GitRepository( + address=git_repo_fixture.working_dir, + working_directory=Path("/foo/bar"), + copy_to_working_dir=False, + ) - branch_name = "dev" - assert git_repo.repo.active_branch.name != branch_name - git_repo.switch_branch(branch_name) - assert git_repo.repo.active_branch.name == branch_name + branch_name = "dev" + assert git_repo.repo.active_branch.name != branch_name + git_repo.switch_branch(branch_name) + assert git_repo.repo.active_branch.name == branch_name @pytest.mark.skipif( - sys.platform.startswith("win"), - reason="Test fails on Windows because of cptree in git_repo_fixture", + sys.platform.startswith("win"), + reason="Test fails on Windows because of cptree in git_repo_fixture", ) 
@pytest.mark.git def test_get_diff(git_repo_fixture): # noqa: F811 - local_path = Path(git_repo_fixture.working_dir).absolute() - test_file = local_path / "test.c" - - # Initialize repository - git_repo = GitRepository( - address=str(local_path), working_directory=Path("/foo/bar"), copy_to_working_dir=False - ) - - # Create a change by modifying test.c - original_content = test_file.read_text() - new_content = "int main() { return 0; }\n" - test_file.write_text(new_content) - - # Get diff without exclusions - diff = git_repo.get_diff() - assert diff is not None - expected_diff = """\ + local_path = Path(git_repo_fixture.working_dir).absolute() + test_file = local_path / "test.c" + + # Initialize repository + git_repo = GitRepository( + address=str(local_path), working_directory=Path("/foo/bar"), copy_to_working_dir=False + ) + + # Create a change by modifying test.c + original_content = test_file.read_text() + new_content = "int main() { return 0; }\n" + test_file.write_text(new_content) + + # Get diff without exclusions + diff = git_repo.get_diff() + assert diff is not None + expected_diff = """\ diff --git a/test.c b/test.c index 79a1160..76e8197 100644 --- a/test.c @@ -105,31 +105,31 @@ def test_get_diff(git_repo_fixture): # noqa: F811 \ No newline at end of file +int main() { return 0; } """ - assert diff == expected_diff + assert diff == expected_diff - # Test with excluded files - diff_with_exclusion = git_repo.get_diff(excluded_files=["test.c"]) - assert diff_with_exclusion == "" + # Test with excluded files + diff_with_exclusion = git_repo.get_diff(excluded_files=["test.c"]) + assert diff_with_exclusion == "" - # Cleanup - restore original content - test_file.write_text(original_content) + # Cleanup - restore original content + test_file.write_text(original_content) @pytest.mark.skipif( - sys.platform.startswith("win"), - reason="Test fails on Windows because of cptree in git_repo_fixture", + sys.platform.startswith("win"), + reason="Test fails on Windows 
because of cptree in git_repo_fixture", ) @pytest.mark.git def test_remove_repository(git_repo_fixture): # noqa: F811 - with mock.patch("shutil.rmtree") as mock_rmtree: - local_path = git_repo_fixture.working_dir + with mock.patch("shutil.rmtree") as mock_rmtree: + local_path = git_repo_fixture.working_dir - git_repo = GitRepository( - address=local_path, working_directory=Path("/foo/bar"), copy_to_working_dir=False - ) - assert git_repo.repo is not None + git_repo = GitRepository( + address=local_path, working_directory=Path("/foo/bar"), copy_to_working_dir=False + ) + assert git_repo.repo is not None - git_repo.remove_repository() + git_repo.remove_repository() - mock_rmtree.assert_called_once_with(local_path) - assert git_repo.repo is None + mock_rmtree.assert_called_once_with(local_path) + assert git_repo.repo is None diff --git a/tests/graph/test_file_graph_builder.py b/tests/graph/test_file_graph_builder.py index 87e07411..5f90c6f1 100644 --- a/tests/graph/test_file_graph_builder.py +++ b/tests/graph/test_file_graph_builder.py @@ -1,104 +1,104 @@ from prometheus.graph.file_graph_builder import FileGraphBuilder from prometheus.graph.graph_types import ( - ASTNode, - KnowledgeGraphEdgeType, - KnowledgeGraphNode, - TextNode, + ASTNode, + KnowledgeGraphEdgeType, + KnowledgeGraphNode, + TextNode, ) from tests.test_utils import test_project_paths def test_supports_file(): - file_graph_builder = FileGraphBuilder(0, 0, 0) + file_graph_builder = FileGraphBuilder(0, 0, 0) - assert file_graph_builder.supports_file(test_project_paths.C_FILE) - assert file_graph_builder.supports_file(test_project_paths.JAVA_FILE) - assert file_graph_builder.supports_file(test_project_paths.MD_FILE) - assert file_graph_builder.supports_file(test_project_paths.PYTHON_FILE) + assert file_graph_builder.supports_file(test_project_paths.C_FILE) + assert file_graph_builder.supports_file(test_project_paths.JAVA_FILE) + assert file_graph_builder.supports_file(test_project_paths.MD_FILE) + assert 
file_graph_builder.supports_file(test_project_paths.PYTHON_FILE) - assert file_graph_builder.supports_file(test_project_paths.DUMMY_FILE) is False + assert file_graph_builder.supports_file(test_project_paths.DUMMY_FILE) is False def test_build_python_file_graph(): - file_graph_builder = FileGraphBuilder(1000, 1000, 100) - - parent_kg_node = KnowledgeGraphNode(0, None) - next_node_id, kg_nodes, kg_edges = file_graph_builder.build_file_graph( - parent_kg_node, test_project_paths.PYTHON_FILE, 0 - ) - - assert next_node_id == 11 - assert len(kg_nodes) == 11 - assert len(kg_edges) == 11 - - # Test if some of the nodes exists - argument_list_ast_node = ASTNode( - type="argument_list", start_line=1, end_line=1, text='("Hello world!")' - ) - string_ast_node = ASTNode(type="string", start_line=1, end_line=1, text='"Hello world!"') - - found_argument_list_ast_node = False - for kg_node in kg_nodes: - if kg_node.node == argument_list_ast_node: - found_argument_list_ast_node = True - assert found_argument_list_ast_node - - found_string_ast_node = False - for kg_node in kg_nodes: - if kg_node.node == string_ast_node: - found_string_ast_node = True - assert found_string_ast_node - - # Test if some of the edges exists - found_edge = False - for kg_edge in kg_edges: - if ( - kg_edge.source.node == argument_list_ast_node - and kg_edge.target.node == string_ast_node - and kg_edge.type == KnowledgeGraphEdgeType.parent_of - ): - found_edge = True - assert found_edge + file_graph_builder = FileGraphBuilder(1000, 1000, 100) + + parent_kg_node = KnowledgeGraphNode(0, None) + next_node_id, kg_nodes, kg_edges = file_graph_builder.build_file_graph( + parent_kg_node, test_project_paths.PYTHON_FILE, 0 + ) + + assert next_node_id == 11 + assert len(kg_nodes) == 11 + assert len(kg_edges) == 11 + + # Test if some of the nodes exists + argument_list_ast_node = ASTNode( + type="argument_list", start_line=1, end_line=1, text='("Hello world!")' + ) + string_ast_node = ASTNode(type="string", 
start_line=1, end_line=1, text='"Hello world!"') + + found_argument_list_ast_node = False + for kg_node in kg_nodes: + if kg_node.node == argument_list_ast_node: + found_argument_list_ast_node = True + assert found_argument_list_ast_node + + found_string_ast_node = False + for kg_node in kg_nodes: + if kg_node.node == string_ast_node: + found_string_ast_node = True + assert found_string_ast_node + + # Test if some of the edges exists + found_edge = False + for kg_edge in kg_edges: + if ( + kg_edge.source.node == argument_list_ast_node + and kg_edge.target.node == string_ast_node + and kg_edge.type == KnowledgeGraphEdgeType.parent_of + ): + found_edge = True + assert found_edge def test_build_text_file_graph(): - file_graph_builder = FileGraphBuilder(1000, 100, 10) - - parent_kg_node = KnowledgeGraphNode(0, None) - next_node_id, kg_nodes, kg_edges = file_graph_builder.build_file_graph( - parent_kg_node, test_project_paths.MD_FILE, 0 - ) - - assert next_node_id == 2 - assert len(kg_nodes) == 2 - assert len(kg_edges) == 3 - - # Test if some of the nodes exists - text_node_1 = TextNode( - text="# A\n\nText under header A.\n\n## B\n\nText under header B.\n\n## C\n\nText under header C.\n\n### D", - metadata="", - ) - text_node_2 = TextNode(text="### D\n\nText under header D.", metadata="") - - found_text_node_1 = False - for kg_node in kg_nodes: - if kg_node.node == text_node_1: - found_text_node_1 = True - assert found_text_node_1 - - found_text_node_2 = False - for kg_node in kg_nodes: - if kg_node.node == text_node_2: - found_text_node_2 = True - assert found_text_node_2 - - # Test if some of the edges exists - found_edge = False - for kg_edge in kg_edges: - if ( - kg_edge.source.node == text_node_1 - and kg_edge.target.node == text_node_2 - and kg_edge.type == KnowledgeGraphEdgeType.next_chunk - ): - found_edge = True - assert found_edge + file_graph_builder = FileGraphBuilder(1000, 100, 10) + + parent_kg_node = KnowledgeGraphNode(0, None) + next_node_id, kg_nodes, 
kg_edges = file_graph_builder.build_file_graph( + parent_kg_node, test_project_paths.MD_FILE, 0 + ) + + assert next_node_id == 2 + assert len(kg_nodes) == 2 + assert len(kg_edges) == 3 + + # Test if some of the nodes exists + text_node_1 = TextNode( + text="# A\n\nText under header A.\n\n## B\n\nText under header B.\n\n## C\n\nText under header C.\n\n### D", + metadata="", + ) + text_node_2 = TextNode(text="### D\n\nText under header D.", metadata="") + + found_text_node_1 = False + for kg_node in kg_nodes: + if kg_node.node == text_node_1: + found_text_node_1 = True + assert found_text_node_1 + + found_text_node_2 = False + for kg_node in kg_nodes: + if kg_node.node == text_node_2: + found_text_node_2 = True + assert found_text_node_2 + + # Test if some of the edges exists + found_edge = False + for kg_edge in kg_edges: + if ( + kg_edge.source.node == text_node_1 + and kg_edge.target.node == text_node_2 + and kg_edge.type == KnowledgeGraphEdgeType.next_chunk + ): + found_edge = True + assert found_edge diff --git a/tests/graph/test_graph_types.py b/tests/graph/test_graph_types.py index 1b3afd9c..c6b7a239 100644 --- a/tests/graph/test_graph_types.py +++ b/tests/graph/test_graph_types.py @@ -1,263 +1,263 @@ from prometheus.graph.graph_types import ( - ASTNode, - FileNode, - KnowledgeGraphEdge, - KnowledgeGraphEdgeType, - KnowledgeGraphNode, - Neo4jASTNode, - Neo4jFileNode, - Neo4jTextNode, - TextNode, + ASTNode, + FileNode, + KnowledgeGraphEdge, + KnowledgeGraphEdgeType, + KnowledgeGraphNode, + Neo4jASTNode, + Neo4jFileNode, + Neo4jTextNode, + TextNode, ) def test_to_neo4j_file_node(): - basename = "foo" - relative_path = "foo/bar/baz.py" - node_id = 1 + basename = "foo" + relative_path = "foo/bar/baz.py" + node_id = 1 - file_node = FileNode(basename, relative_path) - knowldege_graph_node = KnowledgeGraphNode(node_id, file_node) - neo4j_file_node = knowldege_graph_node.to_neo4j_node() + file_node = FileNode(basename, relative_path) + knowldege_graph_node = 
KnowledgeGraphNode(node_id, file_node) + neo4j_file_node = knowldege_graph_node.to_neo4j_node() - assert isinstance(neo4j_file_node, dict) + assert isinstance(neo4j_file_node, dict) - assert "node_id" in neo4j_file_node - assert "basename" in neo4j_file_node - assert "relative_path" in neo4j_file_node + assert "node_id" in neo4j_file_node + assert "basename" in neo4j_file_node + assert "relative_path" in neo4j_file_node - assert neo4j_file_node["node_id"] == node_id - assert neo4j_file_node["basename"] == basename - assert neo4j_file_node["relative_path"] == relative_path + assert neo4j_file_node["node_id"] == node_id + assert neo4j_file_node["basename"] == basename + assert neo4j_file_node["relative_path"] == relative_path def test_to_neo4j_ast_node(): - type = "method_declaration" - start_line = 1 - end_line = 5 - text = "print('Hello world')" - node_id = 1 + type = "method_declaration" + start_line = 1 + end_line = 5 + text = "print('Hello world')" + node_id = 1 - ast_node = ASTNode(type, start_line, end_line, text) - knowldege_graph_node = KnowledgeGraphNode(node_id, ast_node) - neo4j_ast_node = knowldege_graph_node.to_neo4j_node() + ast_node = ASTNode(type, start_line, end_line, text) + knowldege_graph_node = KnowledgeGraphNode(node_id, ast_node) + neo4j_ast_node = knowldege_graph_node.to_neo4j_node() - assert isinstance(neo4j_ast_node, dict) + assert isinstance(neo4j_ast_node, dict) - assert "node_id" in neo4j_ast_node - assert "type" in neo4j_ast_node - assert "start_line" in neo4j_ast_node - assert "end_line" in neo4j_ast_node - assert "text" in neo4j_ast_node + assert "node_id" in neo4j_ast_node + assert "type" in neo4j_ast_node + assert "start_line" in neo4j_ast_node + assert "end_line" in neo4j_ast_node + assert "text" in neo4j_ast_node - assert neo4j_ast_node["node_id"] == node_id - assert neo4j_ast_node["type"] == type - assert neo4j_ast_node["start_line"] == start_line - assert neo4j_ast_node["end_line"] == end_line - assert neo4j_ast_node["text"] == 
text + assert neo4j_ast_node["node_id"] == node_id + assert neo4j_ast_node["type"] == type + assert neo4j_ast_node["start_line"] == start_line + assert neo4j_ast_node["end_line"] == end_line + assert neo4j_ast_node["text"] == text def test_to_neo4j_text_node(): - text = "Hello world" - metadata = "metadata" - node_id = 1 + text = "Hello world" + metadata = "metadata" + node_id = 1 - text_node = TextNode(text, metadata) - knowldege_graph_node = KnowledgeGraphNode(node_id, text_node) - neo4j_text_node = knowldege_graph_node.to_neo4j_node() + text_node = TextNode(text, metadata) + knowldege_graph_node = KnowledgeGraphNode(node_id, text_node) + neo4j_text_node = knowldege_graph_node.to_neo4j_node() - assert isinstance(neo4j_text_node, dict) + assert isinstance(neo4j_text_node, dict) - assert "node_id" in neo4j_text_node - assert "text" in neo4j_text_node - assert "metadata" in neo4j_text_node + assert "node_id" in neo4j_text_node + assert "text" in neo4j_text_node + assert "metadata" in neo4j_text_node - assert neo4j_text_node["node_id"] == node_id - assert neo4j_text_node["text"] == text - assert neo4j_text_node["metadata"] == metadata + assert neo4j_text_node["node_id"] == node_id + assert neo4j_text_node["text"] == text + assert neo4j_text_node["metadata"] == metadata def test_to_neo4j_has_file_edge(): - source_basename = "source" - source_relative_path = "foo/bar/source.py" - source_node_id = 1 - target_basename = "target" - target_relative_path = "foo/bar/target.py" - target_node_id = 10 + source_basename = "source" + source_relative_path = "foo/bar/source.py" + source_node_id = 1 + target_basename = "target" + target_relative_path = "foo/bar/target.py" + target_node_id = 10 - source_file_node = FileNode(source_basename, source_relative_path) - source_knowledge_graph_node = KnowledgeGraphNode(source_node_id, source_file_node) - target_file_node = FileNode(target_basename, target_relative_path) - target_knowledge_graph_node = KnowledgeGraphNode(target_node_id, 
target_file_node) - knowledge_graph_edge = KnowledgeGraphEdge( - source_knowledge_graph_node, - target_knowledge_graph_node, - KnowledgeGraphEdgeType.has_file, - ) - neo4j_has_file_edge = knowledge_graph_edge.to_neo4j_edge() + source_file_node = FileNode(source_basename, source_relative_path) + source_knowledge_graph_node = KnowledgeGraphNode(source_node_id, source_file_node) + target_file_node = FileNode(target_basename, target_relative_path) + target_knowledge_graph_node = KnowledgeGraphNode(target_node_id, target_file_node) + knowledge_graph_edge = KnowledgeGraphEdge( + source_knowledge_graph_node, + target_knowledge_graph_node, + KnowledgeGraphEdgeType.has_file, + ) + neo4j_has_file_edge = knowledge_graph_edge.to_neo4j_edge() - assert isinstance(neo4j_has_file_edge, dict) + assert isinstance(neo4j_has_file_edge, dict) - assert "source" in neo4j_has_file_edge - assert "target" in neo4j_has_file_edge + assert "source" in neo4j_has_file_edge + assert "target" in neo4j_has_file_edge - assert neo4j_has_file_edge["source"] == source_knowledge_graph_node.to_neo4j_node() - assert neo4j_has_file_edge["target"] == target_knowledge_graph_node.to_neo4j_node() + assert neo4j_has_file_edge["source"] == source_knowledge_graph_node.to_neo4j_node() + assert neo4j_has_file_edge["target"] == target_knowledge_graph_node.to_neo4j_node() def test_to_neo4j_has_ast_edge(): - source_basename = "source" - source_relative_path = "foo/bar/source.py" - source_node_id = 1 - target_type = "return_statement" - target_start_line = 7 - target_end_line = 9 - target_text = "return True" - target_node_id = 10 - - source_file_node = FileNode(source_basename, source_relative_path) - source_knowledge_graph_node = KnowledgeGraphNode(source_node_id, source_file_node) - target_ast_node = ASTNode(target_type, target_start_line, target_end_line, target_text) - target_knowledge_graph_node = KnowledgeGraphNode(target_node_id, target_ast_node) - knowledge_graph_edge = KnowledgeGraphEdge( - 
source_knowledge_graph_node, - target_knowledge_graph_node, - KnowledgeGraphEdgeType.has_ast, - ) - neo4j_has_ast_edge = knowledge_graph_edge.to_neo4j_edge() - - assert isinstance(neo4j_has_ast_edge, dict) - - assert "source" in neo4j_has_ast_edge - assert "target" in neo4j_has_ast_edge - - assert neo4j_has_ast_edge["source"] == source_knowledge_graph_node.to_neo4j_node() - assert neo4j_has_ast_edge["target"] == target_knowledge_graph_node.to_neo4j_node() + source_basename = "source" + source_relative_path = "foo/bar/source.py" + source_node_id = 1 + target_type = "return_statement" + target_start_line = 7 + target_end_line = 9 + target_text = "return True" + target_node_id = 10 + + source_file_node = FileNode(source_basename, source_relative_path) + source_knowledge_graph_node = KnowledgeGraphNode(source_node_id, source_file_node) + target_ast_node = ASTNode(target_type, target_start_line, target_end_line, target_text) + target_knowledge_graph_node = KnowledgeGraphNode(target_node_id, target_ast_node) + knowledge_graph_edge = KnowledgeGraphEdge( + source_knowledge_graph_node, + target_knowledge_graph_node, + KnowledgeGraphEdgeType.has_ast, + ) + neo4j_has_ast_edge = knowledge_graph_edge.to_neo4j_edge() + + assert isinstance(neo4j_has_ast_edge, dict) + + assert "source" in neo4j_has_ast_edge + assert "target" in neo4j_has_ast_edge + + assert neo4j_has_ast_edge["source"] == source_knowledge_graph_node.to_neo4j_node() + assert neo4j_has_ast_edge["target"] == target_knowledge_graph_node.to_neo4j_node() def test_to_neo4j_parent_of_edge(): - source_type = "method_declaration" - source_start_line = 1 - source_end_line = 5 - source_text = "print('Hello world')" - source_node_id = 1 - target_type = "return_statement" - target_start_line = 7 - target_end_line = 9 - target_text = "return True" - target_node_id = 10 - - source_ast_node = ASTNode(source_type, source_start_line, source_end_line, source_text) - source_knowledge_graph_node = KnowledgeGraphNode(source_node_id, 
source_ast_node) - target_ast_node = ASTNode(target_type, target_start_line, target_end_line, target_text) - target_knowledge_graph_node = KnowledgeGraphNode(target_node_id, target_ast_node) - knowledge_graph_edge = KnowledgeGraphEdge( - source_knowledge_graph_node, - target_knowledge_graph_node, - KnowledgeGraphEdgeType.parent_of, - ) - neo4j_parent_of_edge = knowledge_graph_edge.to_neo4j_edge() - - assert isinstance(neo4j_parent_of_edge, dict) - - assert "source" in neo4j_parent_of_edge - assert "target" in neo4j_parent_of_edge - - assert neo4j_parent_of_edge["source"] == source_knowledge_graph_node.to_neo4j_node() - assert neo4j_parent_of_edge["target"] == target_knowledge_graph_node.to_neo4j_node() + source_type = "method_declaration" + source_start_line = 1 + source_end_line = 5 + source_text = "print('Hello world')" + source_node_id = 1 + target_type = "return_statement" + target_start_line = 7 + target_end_line = 9 + target_text = "return True" + target_node_id = 10 + + source_ast_node = ASTNode(source_type, source_start_line, source_end_line, source_text) + source_knowledge_graph_node = KnowledgeGraphNode(source_node_id, source_ast_node) + target_ast_node = ASTNode(target_type, target_start_line, target_end_line, target_text) + target_knowledge_graph_node = KnowledgeGraphNode(target_node_id, target_ast_node) + knowledge_graph_edge = KnowledgeGraphEdge( + source_knowledge_graph_node, + target_knowledge_graph_node, + KnowledgeGraphEdgeType.parent_of, + ) + neo4j_parent_of_edge = knowledge_graph_edge.to_neo4j_edge() + + assert isinstance(neo4j_parent_of_edge, dict) + + assert "source" in neo4j_parent_of_edge + assert "target" in neo4j_parent_of_edge + + assert neo4j_parent_of_edge["source"] == source_knowledge_graph_node.to_neo4j_node() + assert neo4j_parent_of_edge["target"] == target_knowledge_graph_node.to_neo4j_node() def test_to_neo4j_has_text_edge(): - source_basename = "source" - source_relative_path = "foo/bar/source.py" - source_node_id = 1 - 
target_text = "Hello world" - target_metadata = "metadata" - target_node_id = 10 + source_basename = "source" + source_relative_path = "foo/bar/source.py" + source_node_id = 1 + target_text = "Hello world" + target_metadata = "metadata" + target_node_id = 10 - source_file_node = FileNode(source_basename, source_relative_path) - source_knowledge_graph_node = KnowledgeGraphNode(source_node_id, source_file_node) - target_text_node = TextNode(target_text, target_metadata) - target_knowledge_graph_node = KnowledgeGraphNode(target_node_id, target_text_node) - knowledge_graph_edge = KnowledgeGraphEdge( - source_knowledge_graph_node, - target_knowledge_graph_node, - KnowledgeGraphEdgeType.has_text, - ) - neo4j_has_text_edge = knowledge_graph_edge.to_neo4j_edge() + source_file_node = FileNode(source_basename, source_relative_path) + source_knowledge_graph_node = KnowledgeGraphNode(source_node_id, source_file_node) + target_text_node = TextNode(target_text, target_metadata) + target_knowledge_graph_node = KnowledgeGraphNode(target_node_id, target_text_node) + knowledge_graph_edge = KnowledgeGraphEdge( + source_knowledge_graph_node, + target_knowledge_graph_node, + KnowledgeGraphEdgeType.has_text, + ) + neo4j_has_text_edge = knowledge_graph_edge.to_neo4j_edge() - assert isinstance(neo4j_has_text_edge, dict) + assert isinstance(neo4j_has_text_edge, dict) - assert "source" in neo4j_has_text_edge - assert "target" in neo4j_has_text_edge + assert "source" in neo4j_has_text_edge + assert "target" in neo4j_has_text_edge - assert neo4j_has_text_edge["source"] == source_knowledge_graph_node.to_neo4j_node() - assert neo4j_has_text_edge["target"] == target_knowledge_graph_node.to_neo4j_node() + assert neo4j_has_text_edge["source"] == source_knowledge_graph_node.to_neo4j_node() + assert neo4j_has_text_edge["target"] == target_knowledge_graph_node.to_neo4j_node() def test_to_neo4j_next_chunk_edge(): - source_text = "Hello" - source_metadata = "meta" - source_node_id = 1 - target_text = 
"world" - target_metadata = "data" - target_node_id = 10 + source_text = "Hello" + source_metadata = "meta" + source_node_id = 1 + target_text = "world" + target_metadata = "data" + target_node_id = 10 - source_text_node = TextNode(source_text, source_metadata) - source_knowledge_graph_node = KnowledgeGraphNode(source_node_id, source_text_node) - target_text_node = TextNode(target_text, target_metadata) - target_knowledge_graph_node = KnowledgeGraphNode(target_node_id, target_text_node) - knowledge_graph_edge = KnowledgeGraphEdge( - source_knowledge_graph_node, - target_knowledge_graph_node, - KnowledgeGraphEdgeType.has_text, - ) - neo4j_next_chunk_edge = knowledge_graph_edge.to_neo4j_edge() + source_text_node = TextNode(source_text, source_metadata) + source_knowledge_graph_node = KnowledgeGraphNode(source_node_id, source_text_node) + target_text_node = TextNode(target_text, target_metadata) + target_knowledge_graph_node = KnowledgeGraphNode(target_node_id, target_text_node) + knowledge_graph_edge = KnowledgeGraphEdge( + source_knowledge_graph_node, + target_knowledge_graph_node, + KnowledgeGraphEdgeType.has_text, + ) + neo4j_next_chunk_edge = knowledge_graph_edge.to_neo4j_edge() - assert isinstance(neo4j_next_chunk_edge, dict) + assert isinstance(neo4j_next_chunk_edge, dict) - assert "source" in neo4j_next_chunk_edge - assert "target" in neo4j_next_chunk_edge + assert "source" in neo4j_next_chunk_edge + assert "target" in neo4j_next_chunk_edge - assert neo4j_next_chunk_edge["source"] == source_knowledge_graph_node.to_neo4j_node() - assert neo4j_next_chunk_edge["target"] == target_knowledge_graph_node.to_neo4j_node() + assert neo4j_next_chunk_edge["source"] == source_knowledge_graph_node.to_neo4j_node() + assert neo4j_next_chunk_edge["target"] == target_knowledge_graph_node.to_neo4j_node() def test_from_neo4j_file_node(): - node_id = 10 - basename = "foo.py" - relative_path = "bar/baz/foo.py" - neo4j_file_node = Neo4jFileNode(node_id=node_id, basename=basename, 
relative_path=relative_path) - knowledge_graph_node = KnowledgeGraphNode.from_neo4j_file_node(neo4j_file_node) + node_id = 10 + basename = "foo.py" + relative_path = "bar/baz/foo.py" + neo4j_file_node = Neo4jFileNode(node_id=node_id, basename=basename, relative_path=relative_path) + knowledge_graph_node = KnowledgeGraphNode.from_neo4j_file_node(neo4j_file_node) - expected_file_node = FileNode(basename, relative_path) - expected_knowledge_graph_node = KnowledgeGraphNode(node_id, expected_file_node) - assert knowledge_graph_node == expected_knowledge_graph_node + expected_file_node = FileNode(basename, relative_path) + expected_knowledge_graph_node = KnowledgeGraphNode(node_id, expected_file_node) + assert knowledge_graph_node == expected_knowledge_graph_node def test_from_neo4j_ast_node(): - node_id = 15 - type = "string_literal" - start_line = 5 - end_line = 6 - text = '"hello world"' - neo4j_ast_node = Neo4jASTNode( - node_id=node_id, type=type, start_line=start_line, end_line=end_line, text=text - ) - knowledge_graph_node = KnowledgeGraphNode.from_neo4j_ast_node(neo4j_ast_node) + node_id = 15 + type = "string_literal" + start_line = 5 + end_line = 6 + text = '"hello world"' + neo4j_ast_node = Neo4jASTNode( + node_id=node_id, type=type, start_line=start_line, end_line=end_line, text=text + ) + knowledge_graph_node = KnowledgeGraphNode.from_neo4j_ast_node(neo4j_ast_node) - expected_ast_node = ASTNode(type, start_line, end_line, text) - expected_knowledge_graph_node = KnowledgeGraphNode(node_id, expected_ast_node) - assert knowledge_graph_node == expected_knowledge_graph_node + expected_ast_node = ASTNode(type, start_line, end_line, text) + expected_knowledge_graph_node = KnowledgeGraphNode(node_id, expected_ast_node) + assert knowledge_graph_node == expected_knowledge_graph_node def test_from_neo4j_text_node(): - node_id = 20 - text = "hello world" - metadata = '{"foo": "bar"}' - neo4j_text_node = Neo4jTextNode(node_id=node_id, text=text, metadata=metadata) - 
knowledge_graph_node = KnowledgeGraphNode.from_neo4j_text_node(neo4j_text_node) - - expected_text_node = TextNode(text, metadata) - expected_knowledge_graph_node = KnowledgeGraphNode(node_id, expected_text_node) - assert knowledge_graph_node == expected_knowledge_graph_node + node_id = 20 + text = "hello world" + metadata = '{"foo": "bar"}' + neo4j_text_node = Neo4jTextNode(node_id=node_id, text=text, metadata=metadata) + knowledge_graph_node = KnowledgeGraphNode.from_neo4j_text_node(neo4j_text_node) + + expected_text_node = TextNode(text, metadata) + expected_knowledge_graph_node = KnowledgeGraphNode(node_id, expected_text_node) + assert knowledge_graph_node == expected_knowledge_graph_node diff --git a/tests/graph/test_knowledge_graph.py b/tests/graph/test_knowledge_graph.py index 2b1d22c4..d61a12e7 100644 --- a/tests/graph/test_knowledge_graph.py +++ b/tests/graph/test_knowledge_graph.py @@ -7,34 +7,34 @@ def test_build_graph(): - knowledge_graph = KnowledgeGraph(1000, 100, 10) - knowledge_graph.build_graph(test_project_paths.TEST_PROJECT_PATH) + knowledge_graph = KnowledgeGraph(1000, 100, 10) + knowledge_graph.build_graph(test_project_paths.TEST_PROJECT_PATH) - assert knowledge_graph._next_node_id == 95 - assert knowledge_graph._metadata_node is not None - assert knowledge_graph.is_built_from_local_codebase() - assert knowledge_graph._metadata_node.local_path == str(test_project_paths.TEST_PROJECT_PATH) - # 9 FileNode - # 84 ASTnode - # 2 TextNode - assert len(knowledge_graph._knowledge_graph_nodes) == 95 - assert len(knowledge_graph._knowledge_graph_edges) == 95 + assert knowledge_graph._next_node_id == 95 + assert knowledge_graph._metadata_node is not None + assert knowledge_graph.is_built_from_local_codebase() + assert knowledge_graph._metadata_node.local_path == str(test_project_paths.TEST_PROJECT_PATH) + # 9 FileNode + # 84 ASTnode + # 2 TextNode + assert len(knowledge_graph._knowledge_graph_nodes) == 95 + assert len(knowledge_graph._knowledge_graph_edges) 
== 95 - assert len(knowledge_graph.get_file_nodes()) == 9 - assert len(knowledge_graph.get_ast_nodes()) == 84 - assert len(knowledge_graph.get_text_nodes()) == 2 - assert len(knowledge_graph.get_parent_of_edges()) == 81 - assert len(knowledge_graph.get_has_file_edges()) == 8 - assert len(knowledge_graph.get_has_ast_edges()) == 3 - assert len(knowledge_graph.get_has_text_edges()) == 2 - assert len(knowledge_graph.get_next_chunk_edges()) == 1 + assert len(knowledge_graph.get_file_nodes()) == 9 + assert len(knowledge_graph.get_ast_nodes()) == 84 + assert len(knowledge_graph.get_text_nodes()) == 2 + assert len(knowledge_graph.get_parent_of_edges()) == 81 + assert len(knowledge_graph.get_has_file_edges()) == 8 + assert len(knowledge_graph.get_has_ast_edges()) == 3 + assert len(knowledge_graph.get_has_text_edges()) == 2 + assert len(knowledge_graph.get_next_chunk_edges()) == 1 def test_get_file_tree(): - knowledge_graph = KnowledgeGraph(1000, 1000, 100) - knowledge_graph.build_graph(test_project_paths.TEST_PROJECT_PATH) - file_tree = knowledge_graph.get_file_tree() - expected_file_tree = """\ + knowledge_graph = KnowledgeGraph(1000, 1000, 100) + knowledge_graph.build_graph(test_project_paths.TEST_PROJECT_PATH) + file_tree = knowledge_graph.get_file_tree() + expected_file_tree = """\ test_project โ”œโ”€โ”€ .gitignore โ”œโ”€โ”€ bar @@ -44,27 +44,27 @@ def test_get_file_tree(): | โ”œโ”€โ”€ test.dummy | โ””โ”€โ”€ test.md โ””โ”€โ”€ test.c""" - assert file_tree == expected_file_tree + assert file_tree == expected_file_tree def test_get_file_tree_depth_one(): - knowledge_graph = KnowledgeGraph(1000, 1000, 100) - knowledge_graph.build_graph(test_project_paths.TEST_PROJECT_PATH) - file_tree = knowledge_graph.get_file_tree(max_depth=1) - expected_file_tree = """\ + knowledge_graph = KnowledgeGraph(1000, 1000, 100) + knowledge_graph.build_graph(test_project_paths.TEST_PROJECT_PATH) + file_tree = knowledge_graph.get_file_tree(max_depth=1) + expected_file_tree = """\ test_project 
โ”œโ”€โ”€ .gitignore โ”œโ”€โ”€ bar โ”œโ”€โ”€ foo โ””โ”€โ”€ test.c""" - assert file_tree == expected_file_tree + assert file_tree == expected_file_tree def test_get_file_tree_depth_two_max_seven_lines(): - knowledge_graph = KnowledgeGraph(1000, 1000, 100) - knowledge_graph.build_graph(test_project_paths.TEST_PROJECT_PATH) - file_tree = knowledge_graph.get_file_tree(max_depth=2, max_lines=7) - expected_file_tree = """\ + knowledge_graph = KnowledgeGraph(1000, 1000, 100) + knowledge_graph.build_graph(test_project_paths.TEST_PROJECT_PATH) + file_tree = knowledge_graph.get_file_tree(max_depth=2, max_lines=7) + expected_file_tree = """\ test_project โ”œโ”€โ”€ .gitignore โ”œโ”€โ”€ bar @@ -72,14 +72,14 @@ def test_get_file_tree_depth_two_max_seven_lines(): | โ””โ”€โ”€ test.py โ”œโ”€โ”€ foo | โ”œโ”€โ”€ test.dummy""" - assert file_tree == expected_file_tree + assert file_tree == expected_file_tree @pytest.mark.slow def test_from_neo4j(neo4j_container_with_kg_fixture): # noqa: F811 - neo4j_container, kg = neo4j_container_with_kg_fixture - driver = neo4j_container.get_driver() - handler = KnowledgeGraphHandler(driver, 100) - read_kg = handler.read_knowledge_graph() + neo4j_container, kg = neo4j_container_with_kg_fixture + driver = neo4j_container.get_driver() + handler = KnowledgeGraphHandler(driver, 100) + read_kg = handler.read_knowledge_graph() - assert read_kg == kg + assert read_kg == kg diff --git a/tests/lang_graph/graphs/test_issue_graph.py b/tests/lang_graph/graphs/test_issue_graph.py index dfb1efc8..837e80b3 100644 --- a/tests/lang_graph/graphs/test_issue_graph.py +++ b/tests/lang_graph/graphs/test_issue_graph.py @@ -12,54 +12,54 @@ @pytest.fixture def mock_advanced_model(): - return Mock(spec=BaseChatModel) + return Mock(spec=BaseChatModel) @pytest.fixture def mock_base_model(): - return Mock(spec=BaseChatModel) + return Mock(spec=BaseChatModel) @pytest.fixture def mock_kg(): - kg = Mock(spec=KnowledgeGraph) - kg.get_all_ast_node_types.return_value = ["FunctionDef", 
"ClassDef", "Module", "Import", "Call"] - return kg + kg = Mock(spec=KnowledgeGraph) + kg.get_all_ast_node_types.return_value = ["FunctionDef", "ClassDef", "Module", "Import", "Call"] + return kg @pytest.fixture def mock_git_repo(): - return Mock(spec=GitRepository) + return Mock(spec=GitRepository) @pytest.fixture def mock_neo4j_driver(): - return Mock(spec=neo4j.Driver) + return Mock(spec=neo4j.Driver) @pytest.fixture def mock_container(): - return Mock(spec=BaseContainer) + return Mock(spec=BaseContainer) def test_issue_graph_basic_initialization( - mock_advanced_model, - mock_base_model, - mock_kg, - mock_git_repo, - mock_neo4j_driver, - mock_container, + mock_advanced_model, + mock_base_model, + mock_kg, + mock_git_repo, + mock_neo4j_driver, + mock_container, ): - """Test that IssueGraph initializes correctly with basic components.""" - graph = IssueGraph( - advanced_model=mock_advanced_model, - base_model=mock_base_model, - kg=mock_kg, - git_repo=mock_git_repo, - neo4j_driver=mock_neo4j_driver, - max_token_per_neo4j_result=1000, - container=mock_container, - ) - - assert graph.graph is not None - assert graph.git_repo == mock_git_repo + """Test that IssueGraph initializes correctly with basic components.""" + graph = IssueGraph( + advanced_model=mock_advanced_model, + base_model=mock_base_model, + kg=mock_kg, + git_repo=mock_git_repo, + neo4j_driver=mock_neo4j_driver, + max_token_per_neo4j_result=1000, + container=mock_container, + ) + + assert graph.graph is not None + assert graph.git_repo == mock_git_repo diff --git a/tests/lang_graph/nodes/test_bug_fix_verify_node.py b/tests/lang_graph/nodes/test_bug_fix_verify_node.py index 21b6f2eb..e9d1075d 100644 --- a/tests/lang_graph/nodes/test_bug_fix_verify_node.py +++ b/tests/lang_graph/nodes/test_bug_fix_verify_node.py @@ -11,42 +11,42 @@ @pytest.fixture def mock_container(): - return Mock(spec=BaseContainer) + return Mock(spec=BaseContainer) @pytest.fixture def test_state(): - return BugFixVerficationState( - { 
- "reproduced_bug_file": "test_bug.py", - "reproduced_bug_commands": ["python test_bug.py", "./run_test.sh"], - "bug_fix_verify_messages": [AIMessage(content="Previous verification result")], - } - ) + return BugFixVerficationState( + { + "reproduced_bug_file": "test_bug.py", + "reproduced_bug_commands": ["python test_bug.py", "./run_test.sh"], + "bug_fix_verify_messages": [AIMessage(content="Previous verification result")], + } + ) @pytest.fixture def fake_llm(): - return FakeListChatWithToolsModel(responses=["Test execution completed"]) + return FakeListChatWithToolsModel(responses=["Test execution completed"]) def test_format_human_message(mock_container, fake_llm, test_state): - """Test human message formatting.""" - node = BugFixVerifyNode(fake_llm, mock_container) - message = node.format_human_message(test_state) + """Test human message formatting.""" + node = BugFixVerifyNode(fake_llm, mock_container) + message = node.format_human_message(test_state) - assert isinstance(message, HumanMessage) - assert "test_bug.py" in message.content - assert "python test_bug.py" in message.content - assert "./run_test.sh" in message.content + assert isinstance(message, HumanMessage) + assert "test_bug.py" in message.content + assert "python test_bug.py" in message.content + assert "./run_test.sh" in message.content def test_call_method(mock_container, fake_llm, test_state): - """Test the __call__ method execution.""" - node = BugFixVerifyNode(fake_llm, mock_container) + """Test the __call__ method execution.""" + node = BugFixVerifyNode(fake_llm, mock_container) - result = node(test_state) + result = node(test_state) - assert "bug_fix_verify_messages" in result - assert len(result["bug_fix_verify_messages"]) == 1 - assert result["bug_fix_verify_messages"][0].content == "Test execution completed" + assert "bug_fix_verify_messages" in result + assert len(result["bug_fix_verify_messages"]) == 1 + assert result["bug_fix_verify_messages"][0].content == "Test execution completed" 
diff --git a/tests/lang_graph/nodes/test_bug_reproducing_execute_node.py b/tests/lang_graph/nodes/test_bug_reproducing_execute_node.py index 8e360f3f..f6afcf69 100644 --- a/tests/lang_graph/nodes/test_bug_reproducing_execute_node.py +++ b/tests/lang_graph/nodes/test_bug_reproducing_execute_node.py @@ -11,64 +11,64 @@ @pytest.fixture def mock_container(): - return Mock(spec=BaseContainer) + return Mock(spec=BaseContainer) @pytest.fixture def test_state(): - return BugReproductionState( - { - "issue_title": "Test Bug", - "issue_body": "Bug description", - "issue_comments": ["Comment 1", "Comment 2"], - "bug_context": "Context of the bug", - "bug_reproducing_write_messages": [AIMessage(content="patch")], - "bug_reproducing_file_messages": [AIMessage(content="path")], - "bug_reproducing_execute_messages": [], - "bug_reproducing_patch": "--- /dev/null\n+++ b/newfile\n@@ -0,0 +1 @@\n+content", - } - ) + return BugReproductionState( + { + "issue_title": "Test Bug", + "issue_body": "Bug description", + "issue_comments": ["Comment 1", "Comment 2"], + "bug_context": "Context of the bug", + "bug_reproducing_write_messages": [AIMessage(content="patch")], + "bug_reproducing_file_messages": [AIMessage(content="path")], + "bug_reproducing_execute_messages": [], + "bug_reproducing_patch": "--- /dev/null\n+++ b/newfile\n@@ -0,0 +1 @@\n+content", + } + ) def test_format_human_message_with_test_commands(mock_container, test_state): - """Test message formatting with provided test commands.""" - fake_llm = FakeListChatWithToolsModel(responses=["test"]) - test_commands = ["pytest", "python -m unittest"] - node = BugReproducingExecuteNode(fake_llm, mock_container, test_commands) + """Test message formatting with provided test commands.""" + fake_llm = FakeListChatWithToolsModel(responses=["test"]) + test_commands = ["pytest", "python -m unittest"] + node = BugReproducingExecuteNode(fake_llm, mock_container, test_commands) - message = node.format_human_message(test_state, 
"/foo/bar/test.py") + message = node.format_human_message(test_state, "/foo/bar/test.py") - assert isinstance(message, HumanMessage) - assert "Test Bug" in message.content - assert "Bug description" in message.content - assert "Comment 1" in message.content - assert "Comment 2" in message.content - assert "pytest" in message.content - assert "/foo/bar/test.py" in message.content + assert isinstance(message, HumanMessage) + assert "Test Bug" in message.content + assert "Bug description" in message.content + assert "Comment 1" in message.content + assert "Comment 2" in message.content + assert "pytest" in message.content + assert "/foo/bar/test.py" in message.content def test_format_human_message_without_test_commands(mock_container, test_state): - """Test message formatting without test commands.""" - fake_llm = FakeListChatWithToolsModel(responses=["test"]) - node = BugReproducingExecuteNode(fake_llm, mock_container) + """Test message formatting without test commands.""" + fake_llm = FakeListChatWithToolsModel(responses=["test"]) + node = BugReproducingExecuteNode(fake_llm, mock_container) - message = node.format_human_message(test_state, "/foo/bar/test.py") + message = node.format_human_message(test_state, "/foo/bar/test.py") - assert isinstance(message, HumanMessage) - assert "Test Bug" in message.content - assert "Bug description" in message.content - assert "User provided test commands:\n" in message.content - assert "/foo/bar/test.py" in message.content + assert isinstance(message, HumanMessage) + assert "Test Bug" in message.content + assert "Bug description" in message.content + assert "User provided test commands:\n" in message.content + assert "/foo/bar/test.py" in message.content def test_call_method(mock_container, test_state): - """Test the __call__ method execution.""" - fake_response = "Test execution completed" - fake_llm = FakeListChatWithToolsModel(responses=[fake_response]) - node = BugReproducingExecuteNode(fake_llm, mock_container) + """Test the 
__call__ method execution.""" + fake_response = "Test execution completed" + fake_llm = FakeListChatWithToolsModel(responses=[fake_response]) + node = BugReproducingExecuteNode(fake_llm, mock_container) - result = node(test_state) + result = node(test_state) - assert "bug_reproducing_execute_messages" in result - assert len(result["bug_reproducing_execute_messages"]) == 1 - assert result["bug_reproducing_execute_messages"][0].content == fake_response + assert "bug_reproducing_execute_messages" in result + assert len(result["bug_reproducing_execute_messages"]) == 1 + assert result["bug_reproducing_execute_messages"][0].content == fake_response diff --git a/tests/lang_graph/nodes/test_bug_reproducing_file_node.py b/tests/lang_graph/nodes/test_bug_reproducing_file_node.py index 9f04f670..da08e0c4 100644 --- a/tests/lang_graph/nodes/test_bug_reproducing_file_node.py +++ b/tests/lang_graph/nodes/test_bug_reproducing_file_node.py @@ -11,50 +11,52 @@ @pytest.fixture def mock_kg(): - kg = Mock(spec=KnowledgeGraph) - kg.get_local_path.return_value = "/test/path" - kg.get_file_tree.return_value = "test_dir/\n test_file.py" - return kg + kg = Mock(spec=KnowledgeGraph) + kg.get_local_path.return_value = "/test/path" + kg.get_file_tree.return_value = "test_dir/\n test_file.py" + return kg @pytest.fixture def fake_llm(): - return FakeListChatWithToolsModel(responses=["test_output.py"]) + return FakeListChatWithToolsModel(responses=["test_output.py"]) @pytest.fixture def basic_state(): - return BugReproductionState( - { - "bug_reproducing_write_messages": [AIMessage(content="def test_bug():\n assert 1 == 2")], - "bug_reproducing_file_messages": [], - } - ) + return BugReproductionState( + { + "bug_reproducing_write_messages": [ + AIMessage(content="def test_bug():\n assert 1 == 2") + ], + "bug_reproducing_file_messages": [], + } + ) def test_initialization(mock_kg, fake_llm): - """Test basic initialization of BugReproducingFileNode.""" - node = BugReproducingFileNode(fake_llm, 
mock_kg) + """Test basic initialization of BugReproducingFileNode.""" + node = BugReproducingFileNode(fake_llm, mock_kg) - assert isinstance(node.system_prompt, SystemMessage) - assert len(node.tools) == 2 # read_file, create_file - assert mock_kg.get_local_path.called + assert isinstance(node.system_prompt, SystemMessage) + assert len(node.tools) == 2 # read_file, create_file + assert mock_kg.get_local_path.called def test_format_human_message(mock_kg, fake_llm, basic_state): - """Test human message formatting with bug file.""" - node = BugReproducingFileNode(fake_llm, mock_kg) - message = node.format_human_message(basic_state) + """Test human message formatting with bug file.""" + node = BugReproducingFileNode(fake_llm, mock_kg) + message = node.format_human_message(basic_state) - assert isinstance(message, HumanMessage) - assert "def test_bug():" in message.content + assert isinstance(message, HumanMessage) + assert "def test_bug():" in message.content def test_call_method(mock_kg, fake_llm, basic_state): - """Test the __call__ method execution.""" - node = BugReproducingFileNode(fake_llm, mock_kg) - result = node(basic_state) + """Test the __call__ method execution.""" + node = BugReproducingFileNode(fake_llm, mock_kg) + result = node(basic_state) - assert "bug_reproducing_file_messages" in result - assert len(result["bug_reproducing_file_messages"]) == 1 - assert result["bug_reproducing_file_messages"][0].content == "test_output.py" + assert "bug_reproducing_file_messages" in result + assert len(result["bug_reproducing_file_messages"]) == 1 + assert result["bug_reproducing_file_messages"][0].content == "test_output.py" diff --git a/tests/lang_graph/nodes/test_bug_reproducing_write_node.py b/tests/lang_graph/nodes/test_bug_reproducing_write_node.py index 427848ff..fa56fe23 100644 --- a/tests/lang_graph/nodes/test_bug_reproducing_write_node.py +++ b/tests/lang_graph/nodes/test_bug_reproducing_write_node.py @@ -11,34 +11,34 @@ @pytest.fixture def mock_kg(): - kg 
= Mock(spec=KnowledgeGraph) - kg.get_local_path.return_value = "/foo/bar" - return kg + kg = Mock(spec=KnowledgeGraph) + kg.get_local_path.return_value = "/foo/bar" + return kg @pytest.fixture def test_state(): - return BugReproductionState( - { - "issue_title": "Test Bug", - "issue_body": "Bug description", - "issue_comments": ["Comment 1", "Comment 2"], - "bug_context": "Context of the bug", - "reproduced_bug_file": "test/file.py", - "reproduced_bug_failure_log": "Test failure log", - "bug_reproducing_write_messages": [HumanMessage("assert x == 10")], - } - ) + return BugReproductionState( + { + "issue_title": "Test Bug", + "issue_body": "Bug description", + "issue_comments": ["Comment 1", "Comment 2"], + "bug_context": "Context of the bug", + "reproduced_bug_file": "test/file.py", + "reproduced_bug_failure_log": "Test failure log", + "bug_reproducing_write_messages": [HumanMessage("assert x == 10")], + } + ) def test_call_method(mock_kg, test_state): - """Test the __call__ method execution.""" - fake_response = "Created test file" - fake_llm = FakeListChatWithToolsModel(responses=[fake_response]) - node = BugReproducingWriteNode(fake_llm, mock_kg) + """Test the __call__ method execution.""" + fake_response = "Created test file" + fake_llm = FakeListChatWithToolsModel(responses=[fake_response]) + node = BugReproducingWriteNode(fake_llm, mock_kg) - result = node(test_state) + result = node(test_state) - assert "bug_reproducing_write_messages" in result - assert len(result["bug_reproducing_write_messages"]) == 1 - assert result["bug_reproducing_write_messages"][0].content == fake_response + assert "bug_reproducing_write_messages" in result + assert len(result["bug_reproducing_write_messages"]) == 1 + assert result["bug_reproducing_write_messages"][0].content == fake_response diff --git a/tests/lang_graph/nodes/test_context_provider_node.py b/tests/lang_graph/nodes/test_context_provider_node.py index db697592..905442cf 100644 --- 
a/tests/lang_graph/nodes/test_context_provider_node.py +++ b/tests/lang_graph/nodes/test_context_provider_node.py @@ -8,25 +8,25 @@ @pytest.mark.slow def test_context_provider_node_basic_query(neo4j_container_with_kg_fixture): # noqa: F811 - """Test basic query handling with the ContextProviderNode.""" - neo4j_container, kg = neo4j_container_with_kg_fixture - fake_response = "Fake response" - fake_llm = FakeListChatWithToolsModel(responses=[fake_response]) - node = ContextProviderNode( - model=fake_llm, kg=kg, neo4j_driver=neo4j_container.get_driver(), max_token_per_result=1000 - ) + """Test basic query handling with the ContextProviderNode.""" + neo4j_container, kg = neo4j_container_with_kg_fixture + fake_response = "Fake response" + fake_llm = FakeListChatWithToolsModel(responses=[fake_response]) + node = ContextProviderNode( + model=fake_llm, kg=kg, neo4j_driver=neo4j_container.get_driver(), max_token_per_result=1000 + ) - test_messages = [ - AIMessage(content="This code handles file processing"), - ToolMessage(content="Found implementation in utils.py", tool_call_id="test_tool_call_1"), - ] - test_state = { - "original_query": "How does the error handling work?", - "context_provider_messages": test_messages, - } + test_messages = [ + AIMessage(content="This code handles file processing"), + ToolMessage(content="Found implementation in utils.py", tool_call_id="test_tool_call_1"), + ] + test_state = { + "original_query": "How does the error handling work?", + "context_provider_messages": test_messages, + } - result = node(test_state) + result = node(test_state) - assert "context_provider_messages" in result - assert len(result["context_provider_messages"]) == 1 - assert result["context_provider_messages"][0].content == fake_response + assert "context_provider_messages" in result + assert len(result["context_provider_messages"]) == 1 + assert result["context_provider_messages"][0].content == fake_response diff --git 
a/tests/lang_graph/nodes/test_edit_message_node.py b/tests/lang_graph/nodes/test_edit_message_node.py index f7df2e42..2f803685 100644 --- a/tests/lang_graph/nodes/test_edit_message_node.py +++ b/tests/lang_graph/nodes/test_edit_message_node.py @@ -8,110 +8,110 @@ @pytest.fixture def edit_node(): - return EditMessageNode() + return EditMessageNode() @pytest.fixture def base_state(): - return { - "issue_title": "Test Bug", - "issue_body": "This is a test bug description", - "issue_comments": [ - {"username": "user1", "comment": "Comment 1"}, - {"username": "user2", "comment": "Comment 2"}, - ], - "bug_fix_context": ["Context 1", "Context 2"], - "issue_bug_analyzer_messages": ["Analysis message"], - } + return { + "issue_title": "Test Bug", + "issue_body": "This is a test bug description", + "issue_comments": [ + {"username": "user1", "comment": "Comment 1"}, + {"username": "user2", "comment": "Comment 2"}, + ], + "bug_fix_context": ["Context 1", "Context 2"], + "issue_bug_analyzer_messages": ["Analysis message"], + } def test_first_message_formatting(edit_node, base_state): - # Using context managers for patching - with patch( - "prometheus.lang_graph.nodes.edit_message_node.format_issue_info" - ) as mock_format_issue: + # Using context managers for patching with patch( - "prometheus.lang_graph.nodes.edit_message_node.get_last_message_content" - ) as mock_last_message: - mock_format_issue.return_value = "Formatted Issue Info" - mock_last_message.return_value = "Last Analysis Message" + "prometheus.lang_graph.nodes.edit_message_node.format_issue_info" + ) as mock_format_issue: + with patch( + "prometheus.lang_graph.nodes.edit_message_node.get_last_message_content" + ) as mock_last_message: + mock_format_issue.return_value = "Formatted Issue Info" + mock_last_message.return_value = "Last Analysis Message" - result = edit_node(base_state) + result = edit_node(base_state) - assert isinstance(result, dict) - assert "edit_messages" in result - assert 
len(result["edit_messages"]) == 1 - assert isinstance(result["edit_messages"][0], HumanMessage) + assert isinstance(result, dict) + assert "edit_messages" in result + assert len(result["edit_messages"]) == 1 + assert isinstance(result["edit_messages"][0], HumanMessage) - message_content = result["edit_messages"][0].content - assert "Formatted Issue Info" in message_content - assert "Context 1\n\nContext 2" in message_content - assert "Last Analysis Message" in message_content + message_content = result["edit_messages"][0].content + assert "Formatted Issue Info" in message_content + assert "Context 1\n\nContext 2" in message_content + assert "Last Analysis Message" in message_content def test_followup_message_with_build_fail(edit_node, base_state): - # Add build failure to state - base_state["build_fail_log"] = "Build failed: error in compilation" + # Add build failure to state + base_state["build_fail_log"] = "Build failed: error in compilation" - with patch( - "prometheus.lang_graph.nodes.edit_message_node.get_last_message_content" - ) as mock_last_message: - mock_last_message.return_value = "Last Analysis Message" + with patch( + "prometheus.lang_graph.nodes.edit_message_node.get_last_message_content" + ) as mock_last_message: + mock_last_message.return_value = "Last Analysis Message" - result = edit_node(base_state) - message_content = result["edit_messages"][0].content + result = edit_node(base_state) + message_content = result["edit_messages"][0].content - assert "Build failed: error in compilation" in message_content - assert "Please implement these revised changes carefully" in message_content + assert "Build failed: error in compilation" in message_content + assert "Please implement these revised changes carefully" in message_content def test_followup_message_with_test_fail(edit_node, base_state): - # Add test failure to state - base_state["reproducing_test_fail_log"] = "Test failed: assertion error" + # Add test failure to state + 
base_state["reproducing_test_fail_log"] = "Test failed: assertion error" - with patch( - "prometheus.lang_graph.nodes.edit_message_node.get_last_message_content" - ) as mock_last_message: - mock_last_message.return_value = "Last Analysis Message" + with patch( + "prometheus.lang_graph.nodes.edit_message_node.get_last_message_content" + ) as mock_last_message: + mock_last_message.return_value = "Last Analysis Message" - result = edit_node(base_state) - message_content = result["edit_messages"][0].content + result = edit_node(base_state) + message_content = result["edit_messages"][0].content - assert "Test failed: assertion error" in message_content - assert "Please implement these revised changes carefully" in message_content + assert "Test failed: assertion error" in message_content + assert "Please implement these revised changes carefully" in message_content def test_followup_message_with_existing_test_fail(edit_node, base_state): - # Add existing test failure to state - base_state["existing_test_fail_log"] = "Existing test failed" + # Add existing test failure to state + base_state["existing_test_fail_log"] = "Existing test failed" - with patch( - "prometheus.lang_graph.nodes.edit_message_node.get_last_message_content" - ) as mock_last_message: - mock_last_message.return_value = "Last Analysis Message" + with patch( + "prometheus.lang_graph.nodes.edit_message_node.get_last_message_content" + ) as mock_last_message: + mock_last_message.return_value = "Last Analysis Message" - result = edit_node(base_state) - message_content = result["edit_messages"][0].content + result = edit_node(base_state) + message_content = result["edit_messages"][0].content - assert "Existing test failed" in message_content - assert "Please implement these revised changes carefully" in message_content + assert "Existing test failed" in message_content + assert "Please implement these revised changes carefully" in message_content def test_error_priority(edit_node, base_state): - # Add 
multiple error types to test priority handling - base_state["reproducing_test_fail_log"] = "Test failed" - base_state["build_fail_log"] = "Build failed" - base_state["existing_test_fail_log"] = "Existing test failed" - - with patch( - "prometheus.lang_graph.nodes.edit_message_node.get_last_message_content" - ) as mock_last_message: - mock_last_message.return_value = "Last Analysis Message" - - result = edit_node(base_state) - message_content = result["edit_messages"][0].content - - # Should prioritize reproducing test failure - assert "Test failed" in message_content - assert "Build failed" not in message_content - assert "Existing test failed" not in message_content + # Add multiple error types to test priority handling + base_state["reproducing_test_fail_log"] = "Test failed" + base_state["build_fail_log"] = "Build failed" + base_state["existing_test_fail_log"] = "Existing test failed" + + with patch( + "prometheus.lang_graph.nodes.edit_message_node.get_last_message_content" + ) as mock_last_message: + mock_last_message.return_value = "Last Analysis Message" + + result = edit_node(base_state) + message_content = result["edit_messages"][0].content + + # Should prioritize reproducing test failure + assert "Test failed" in message_content + assert "Build failed" not in message_content + assert "Existing test failed" not in message_content diff --git a/tests/lang_graph/nodes/test_edit_node.py b/tests/lang_graph/nodes/test_edit_node.py index 6909b5b0..cc3f5bf5 100644 --- a/tests/lang_graph/nodes/test_edit_node.py +++ b/tests/lang_graph/nodes/test_edit_node.py @@ -10,32 +10,32 @@ @pytest.fixture def mock_kg(): - kg = Mock(spec=KnowledgeGraph) - kg.get_local_path.return_value = "/mock/path" - return kg + kg = Mock(spec=KnowledgeGraph) + kg.get_local_path.return_value = "/mock/path" + return kg @pytest.fixture def fake_llm(): - return FakeListChatWithToolsModel(responses=["File edit completed successfully"]) + return FakeListChatWithToolsModel(responses=["File edit 
completed successfully"]) def test_init_edit_node(mock_kg, fake_llm): - """Test EditNode initialization.""" - node = EditNode(fake_llm, mock_kg) + """Test EditNode initialization.""" + node = EditNode(fake_llm, mock_kg) - assert isinstance(node.system_prompt, SystemMessage) - assert len(node.tools) == 5 # Should have 5 file operation tools - assert node.model_with_tools is not None + assert isinstance(node.system_prompt, SystemMessage) + assert len(node.tools) == 5 # Should have 5 file operation tools + assert node.model_with_tools is not None def test_call_method_basic(mock_kg, fake_llm): - """Test basic call functionality without tool execution.""" - node = EditNode(fake_llm, mock_kg) - state = {"edit_messages": [HumanMessage(content="Make the following changes: ...")]} + """Test basic call functionality without tool execution.""" + node = EditNode(fake_llm, mock_kg) + state = {"edit_messages": [HumanMessage(content="Make the following changes: ...")]} - result = node(state) + result = node(state) - assert "edit_messages" in result - assert len(result["edit_messages"]) == 1 - assert result["edit_messages"][0].content == "File edit completed successfully" + assert "edit_messages" in result + assert len(result["edit_messages"]) == 1 + assert result["edit_messages"][0].content == "File edit completed successfully" diff --git a/tests/lang_graph/nodes/test_general_build_node.py b/tests/lang_graph/nodes/test_general_build_node.py index 48b8bd76..743bcff3 100644 --- a/tests/lang_graph/nodes/test_general_build_node.py +++ b/tests/lang_graph/nodes/test_general_build_node.py @@ -12,51 +12,53 @@ @pytest.fixture def mock_container(): - return Mock(spec=BaseContainer) + return Mock(spec=BaseContainer) @pytest.fixture def mock_kg(): - kg = Mock(spec=KnowledgeGraph) - kg.get_file_tree.return_value = ".\nโ”œโ”€โ”€ src\nโ”‚ โ””โ”€โ”€ main.py\nโ””โ”€โ”€ build.gradle" - return kg + kg = Mock(spec=KnowledgeGraph) + kg.get_file_tree.return_value = ".\nโ”œโ”€โ”€ src\nโ”‚ โ””โ”€โ”€ 
main.py\nโ””โ”€โ”€ build.gradle" + return kg @pytest.fixture def fake_llm(): - return FakeListChatWithToolsModel(responses=["Build command executed successfully"]) + return FakeListChatWithToolsModel(responses=["Build command executed successfully"]) def test_format_human_message_basic(mock_container, mock_kg, fake_llm): - """Test basic human message formatting.""" - node = GeneralBuildNode(fake_llm, mock_container, mock_kg) - state = BuildAndTestState({}) + """Test basic human message formatting.""" + node = GeneralBuildNode(fake_llm, mock_container, mock_kg) + state = BuildAndTestState({}) - message = node.format_human_message(state) + message = node.format_human_message(state) - assert isinstance(message, HumanMessage) - assert "project structure is:" in message.content - assert mock_kg.get_file_tree() in message.content + assert isinstance(message, HumanMessage) + assert "project structure is:" in message.content + assert mock_kg.get_file_tree() in message.content def test_format_human_message_with_build_summary(mock_container, mock_kg, fake_llm): - """Test message formatting with build command summary.""" - node = GeneralBuildNode(fake_llm, mock_container, mock_kg) - state = BuildAndTestState({"build_command_summary": "Previous build used gradle"}) + """Test message formatting with build command summary.""" + node = GeneralBuildNode(fake_llm, mock_container, mock_kg) + state = BuildAndTestState({"build_command_summary": "Previous build used gradle"}) - message = node.format_human_message(state) + message = node.format_human_message(state) - assert "Previous build used gradle" in message.content - assert "The previous build summary is:" in message.content + assert "Previous build used gradle" in message.content + assert "The previous build summary is:" in message.content def test_call_method_with_no_build(mock_container, mock_kg, fake_llm): - """Test __call__ method when exist_build is False.""" - node = GeneralBuildNode(fake_llm, mock_container, mock_kg) - 
state = BuildAndTestState({"exist_build": False}) + """Test __call__ method when exist_build is False.""" + node = GeneralBuildNode(fake_llm, mock_container, mock_kg) + state = BuildAndTestState({"exist_build": False}) - result = node(state) + result = node(state) - assert "build_messages" in result - assert len(result["build_messages"]) == 1 - assert "Previous agent determined there is no build system" in result["build_messages"][0].content + assert "build_messages" in result + assert len(result["build_messages"]) == 1 + assert ( + "Previous agent determined there is no build system" in result["build_messages"][0].content + ) diff --git a/tests/lang_graph/nodes/test_general_test_node.py b/tests/lang_graph/nodes/test_general_test_node.py index 8f18c76d..b218bc4e 100644 --- a/tests/lang_graph/nodes/test_general_test_node.py +++ b/tests/lang_graph/nodes/test_general_test_node.py @@ -12,62 +12,64 @@ @pytest.fixture def mock_container(): - return Mock(spec=BaseContainer) + return Mock(spec=BaseContainer) @pytest.fixture def mock_kg(): - kg = Mock(spec=KnowledgeGraph) - kg.get_file_tree.return_value = "./\nโ”œโ”€โ”€ tests/\nโ”‚ โ””โ”€โ”€ test_main.py" - return kg + kg = Mock(spec=KnowledgeGraph) + kg.get_file_tree.return_value = "./\nโ”œโ”€โ”€ tests/\nโ”‚ โ””โ”€โ”€ test_main.py" + return kg @pytest.fixture def fake_llm(): - return FakeListChatWithToolsModel(responses=["Tests executed successfully"]) + return FakeListChatWithToolsModel(responses=["Tests executed successfully"]) @pytest.fixture def basic_state(): - return BuildAndTestState({"exist_test": True, "test_messages": [], "test_command_summary": None}) + return BuildAndTestState( + {"exist_test": True, "test_messages": [], "test_command_summary": None} + ) def test_format_human_message(mock_container, mock_kg, fake_llm, basic_state): - """Test basic message formatting.""" - node = GeneralTestNode(fake_llm, mock_container, mock_kg) - message = node.format_human_message(basic_state) + """Test basic message 
formatting.""" + node = GeneralTestNode(fake_llm, mock_container, mock_kg) + message = node.format_human_message(basic_state) - assert isinstance(message, HumanMessage) - assert mock_kg.get_file_tree() in message.content + assert isinstance(message, HumanMessage) + assert mock_kg.get_file_tree() in message.content def test_call_with_no_tests(mock_container, mock_kg, fake_llm): - """Test behavior when no tests exist.""" - node = GeneralTestNode(fake_llm, mock_container, mock_kg) - state = BuildAndTestState({"exist_test": False}) + """Test behavior when no tests exist.""" + node = GeneralTestNode(fake_llm, mock_container, mock_kg) + state = BuildAndTestState({"exist_test": False}) - result = node(state) + result = node(state) - assert "build_messages" in result - assert "no test framework" in result["build_messages"][0].content + assert "build_messages" in result + assert "no test framework" in result["build_messages"][0].content def test_call_normal_execution(mock_container, mock_kg, fake_llm, basic_state): - """Test normal execution flow.""" - node = GeneralTestNode(fake_llm, mock_container, mock_kg) + """Test normal execution flow.""" + node = GeneralTestNode(fake_llm, mock_container, mock_kg) - result = node(basic_state) + result = node(basic_state) - assert "test_messages" in result - assert len(result["test_messages"]) == 1 - assert result["test_messages"][0].content == "Tests executed successfully" + assert "test_messages" in result + assert len(result["test_messages"]) == 1 + assert result["test_messages"][0].content == "Tests executed successfully" def test_format_human_message_with_summary(mock_container, mock_kg, fake_llm): - """Test message formatting with test summary.""" - node = GeneralTestNode(fake_llm, mock_container, mock_kg) - state = BuildAndTestState({"test_command_summary": "Previous test used pytest"}) + """Test message formatting with test summary.""" + node = GeneralTestNode(fake_llm, mock_container, mock_kg) + state = 
BuildAndTestState({"test_command_summary": "Previous test used pytest"}) - message = node.format_human_message(state) + message = node.format_human_message(state) - assert "Previous test used pytest" in message.content + assert "Previous test used pytest" in message.content diff --git a/tests/lang_graph/nodes/test_git_diff_node.py b/tests/lang_graph/nodes/test_git_diff_node.py index 598c65b7..457bfbc6 100644 --- a/tests/lang_graph/nodes/test_git_diff_node.py +++ b/tests/lang_graph/nodes/test_git_diff_node.py @@ -8,28 +8,28 @@ @pytest.fixture def mock_git_repo(): - git_repo = Mock(spec=GitRepository) - git_repo.get_diff.return_value = "sample diff content" - return git_repo + git_repo = Mock(spec=GitRepository) + git_repo.get_diff.return_value = "sample diff content" + return git_repo def test_git_diff_node(mock_git_repo): - node = GitDiffNode(mock_git_repo, "patch") + node = GitDiffNode(mock_git_repo, "patch") - # Execute - result = node({}) + # Execute + result = node({}) - # Assert - assert result == {"patch": "sample diff content"} - mock_git_repo.get_diff.assert_called_with(None) + # Assert + assert result == {"patch": "sample diff content"} + mock_git_repo.get_diff.assert_called_with(None) def test_git_diff_node_with_excluded_files(mock_git_repo): - node = GitDiffNode(mock_git_repo, "patch", "excluded_file") + node = GitDiffNode(mock_git_repo, "patch", "excluded_file") - # Execute - result = node({"excluded_file": "/foo/bar.py"}) + # Execute + result = node({"excluded_file": "/foo/bar.py"}) - # Assert - assert result == {"patch": "sample diff content"} - mock_git_repo.get_diff.assert_called_with(["/foo/bar.py"]) + # Assert + assert result == {"patch": "sample diff content"} + mock_git_repo.get_diff.assert_called_with(["/foo/bar.py"]) diff --git a/tests/lang_graph/nodes/test_issue_bug_analyzer_node.py b/tests/lang_graph/nodes/test_issue_bug_analyzer_node.py index cd50b494..38842e7e 100644 --- a/tests/lang_graph/nodes/test_issue_bug_analyzer_node.py +++ 
b/tests/lang_graph/nodes/test_issue_bug_analyzer_node.py @@ -7,16 +7,16 @@ @pytest.fixture def fake_llm(): - return FakeListChatWithToolsModel(responses=["Bug analysis completed successfully"]) + return FakeListChatWithToolsModel(responses=["Bug analysis completed successfully"]) def test_call_method_basic(fake_llm): - """Test basic call functionality.""" - node = IssueBugAnalyzerNode(fake_llm) - state = {"issue_bug_analyzer_messages": [HumanMessage(content="Please analyze this bug: ...")]} + """Test basic call functionality.""" + node = IssueBugAnalyzerNode(fake_llm) + state = {"issue_bug_analyzer_messages": [HumanMessage(content="Please analyze this bug: ...")]} - result = node(state) + result = node(state) - assert "issue_bug_analyzer_messages" in result - assert len(result["issue_bug_analyzer_messages"]) == 1 - assert result["issue_bug_analyzer_messages"][0].content == "Bug analysis completed successfully" + assert "issue_bug_analyzer_messages" in result + assert len(result["issue_bug_analyzer_messages"]) == 1 + assert result["issue_bug_analyzer_messages"][0].content == "Bug analysis completed successfully" diff --git a/tests/lang_graph/nodes/test_issue_bug_responder_node.py b/tests/lang_graph/nodes/test_issue_bug_responder_node.py index 44cc4923..061622c2 100644 --- a/tests/lang_graph/nodes/test_issue_bug_responder_node.py +++ b/tests/lang_graph/nodes/test_issue_bug_responder_node.py @@ -8,105 +8,105 @@ @pytest.fixture def fake_llm(): - return FakeListChatWithToolsModel( - responses=["Thank you for reporting this issue. The fix has been implemented and verified."] - ) + return FakeListChatWithToolsModel( + responses=["Thank you for reporting this issue. 
The fix has been implemented and verified."] + ) @pytest.fixture def basic_state(): - return IssueBugState( - { - "issue_title": "Test Bug", - "issue_body": "Found a bug in the code", - "issue_comments": [ - {"username": "user1", "comment": "This affects my workflow"}, - {"username": "user2", "comment": "Same issue here"}, - ], - "edit_messages": [AIMessage("I have fixed the bug")], - "edit_patch": "Fixed array index calculation", - "passed_reproducing_test": True, - "passed_build": True, - "passed_existing_test": True, - } - ) + return IssueBugState( + { + "issue_title": "Test Bug", + "issue_body": "Found a bug in the code", + "issue_comments": [ + {"username": "user1", "comment": "This affects my workflow"}, + {"username": "user2", "comment": "Same issue here"}, + ], + "edit_messages": [AIMessage("I have fixed the bug")], + "edit_patch": "Fixed array index calculation", + "passed_reproducing_test": True, + "passed_build": True, + "passed_existing_test": True, + } + ) def test_format_human_message_basic(fake_llm, basic_state): - """Test basic human message formatting.""" - node = IssueBugResponderNode(fake_llm) - message = node.format_human_message(basic_state) + """Test basic human message formatting.""" + node = IssueBugResponderNode(fake_llm) + message = node.format_human_message(basic_state) - assert "Test Bug" in message.content - assert "Found a bug in the code" in message.content - assert "user1" in message.content - assert "user2" in message.content - assert "Fixed array index" in message.content + assert "Test Bug" in message.content + assert "Found a bug in the code" in message.content + assert "user1" in message.content + assert "user2" in message.content + assert "Fixed array index" in message.content def test_format_human_message_verification(fake_llm, basic_state): - """Test verification message formatting.""" - node = IssueBugResponderNode(fake_llm) - message = node.format_human_message(basic_state) + """Test verification message formatting.""" + 
node = IssueBugResponderNode(fake_llm) + message = node.format_human_message(basic_state) - assert "โœ“ The bug reproducing test passed" in message.content - assert "โœ“ Build passes successfully" in message.content - assert "โœ“ All existing tests pass successfully" in message.content + assert "โœ“ The bug reproducing test passed" in message.content + assert "โœ“ Build passes successfully" in message.content + assert "โœ“ All existing tests pass successfully" in message.content def test_format_human_message_no_verification(fake_llm): - """Test message formatting without verifications.""" - state = IssueBugState( - { - "issue_title": "Test Bug", - "issue_body": "Bug description", - "issue_comments": [], - "edit_messages": [AIMessage("I have fixed the bug")], - "edit_patch": "Fixed array index calculation", - "passed_reproducing_test": False, - "passed_build": False, - "passed_existing_test": False, - } - ) - - node = IssueBugResponderNode(fake_llm) - message = node.format_human_message(state) - - assert "โœ“ The bug reproducing test passed" not in message.content - assert "โœ“ Build passes successfully" not in message.content - assert "โœ“ All existing tests pass successfully" not in message.content + """Test message formatting without verifications.""" + state = IssueBugState( + { + "issue_title": "Test Bug", + "issue_body": "Bug description", + "issue_comments": [], + "edit_messages": [AIMessage("I have fixed the bug")], + "edit_patch": "Fixed array index calculation", + "passed_reproducing_test": False, + "passed_build": False, + "passed_existing_test": False, + } + ) + + node = IssueBugResponderNode(fake_llm) + message = node.format_human_message(state) + + assert "โœ“ The bug reproducing test passed" not in message.content + assert "โœ“ Build passes successfully" not in message.content + assert "โœ“ All existing tests pass successfully" not in message.content def test_format_human_message_partial_verification(fake_llm): - """Test message formatting with 
partial verifications.""" - state = IssueBugState( - { - "issue_title": "Test Bug", - "issue_body": "Bug description", - "issue_comments": [], - "edit_messages": [AIMessage("I have fixed the bug")], - "edit_patch": "Fixed array index calculation", - "passed_reproducing_test": True, - "passed_build": False, - "passed_existing_test": True, - } - ) - - node = IssueBugResponderNode(fake_llm) - message = node.format_human_message(state) - - assert "โœ“ The bug reproducing test passed" in message.content - assert "โœ“ Build passes successfully" not in message.content - assert "โœ“ All existing tests pass successfully" in message.content + """Test message formatting with partial verifications.""" + state = IssueBugState( + { + "issue_title": "Test Bug", + "issue_body": "Bug description", + "issue_comments": [], + "edit_messages": [AIMessage("I have fixed the bug")], + "edit_patch": "Fixed array index calculation", + "passed_reproducing_test": True, + "passed_build": False, + "passed_existing_test": True, + } + ) + + node = IssueBugResponderNode(fake_llm) + message = node.format_human_message(state) + + assert "โœ“ The bug reproducing test passed" in message.content + assert "โœ“ Build passes successfully" not in message.content + assert "โœ“ All existing tests pass successfully" in message.content def test_call_method(fake_llm, basic_state): - """Test the call method execution.""" - node = IssueBugResponderNode(fake_llm) - result = node(basic_state) - - assert "issue_response" in result - assert ( - result["issue_response"] - == "Thank you for reporting this issue. The fix has been implemented and verified." - ) + """Test the call method execution.""" + node = IssueBugResponderNode(fake_llm) + result = node(basic_state) + + assert "issue_response" in result + assert ( + result["issue_response"] + == "Thank you for reporting this issue. The fix has been implemented and verified." 
+ ) diff --git a/tests/lang_graph/nodes/test_reset_messages_node.py b/tests/lang_graph/nodes/test_reset_messages_node.py index e74ce8f1..5d77be7e 100644 --- a/tests/lang_graph/nodes/test_reset_messages_node.py +++ b/tests/lang_graph/nodes/test_reset_messages_node.py @@ -4,17 +4,17 @@ def test_reset_messages_node(): - reset_build_messages_node = ResetMessagesNode("build_messages") + reset_build_messages_node = ResetMessagesNode("build_messages") - state = { - "build_messages": [HumanMessage(content="message 1"), HumanMessage(content="message 2")], - "test_messages": [HumanMessage(content="message 3"), HumanMessage(content="message 4")], - } + state = { + "build_messages": [HumanMessage(content="message 1"), HumanMessage(content="message 2")], + "test_messages": [HumanMessage(content="message 3"), HumanMessage(content="message 4")], + } - reset_build_messages_node(state) + reset_build_messages_node(state) - assert "build_messages" in state - assert len(state["build_messages"]) == 0 + assert "build_messages" in state + assert len(state["build_messages"]) == 0 - assert "test_messages" in state - assert len(state["test_messages"]) == 2 + assert "test_messages" in state + assert len(state["test_messages"]) == 2 diff --git a/tests/lang_graph/nodes/test_update_container_node.py b/tests/lang_graph/nodes/test_update_container_node.py index de85e4b0..2a7d946c 100644 --- a/tests/lang_graph/nodes/test_update_container_node.py +++ b/tests/lang_graph/nodes/test_update_container_node.py @@ -7,18 +7,18 @@ def test_update_container_node(): - mocked_container = Mock(spec=GeneralContainer) - mocked_container.is_running.return_value = True - mocked_git_repo = Mock(spec=GitRepository) - mocked_git_repo.get_diff.return_value = "--- /dev/null\n+++ b/newfile\n@@ -0,0 +1 @@\n+content" - mocked_git_repo.get_working_directory.return_value = Path("/test/working/dir/repositories/repo") - update_container_node = UpdateContainerNode(mocked_container, mocked_git_repo) + mocked_container = 
Mock(spec=GeneralContainer) + mocked_container.is_running.return_value = True + mocked_git_repo = Mock(spec=GitRepository) + mocked_git_repo.get_diff.return_value = "--- /dev/null\n+++ b/newfile\n@@ -0,0 +1 @@\n+content" + mocked_git_repo.get_working_directory.return_value = Path("/test/working/dir/repositories/repo") + update_container_node = UpdateContainerNode(mocked_container, mocked_git_repo) - update_container_node(None) + update_container_node(None) - assert mocked_git_repo.get_diff.call_count == 1 - assert mocked_container.is_running.call_count == 1 - assert mocked_container.update_files.call_count == 1 - mocked_container.update_files.assert_called_with( - Path("/test/working/dir/repositories/repo"), [Path("newfile")], [] - ) + assert mocked_git_repo.get_diff.call_count == 1 + assert mocked_container.is_running.call_count == 1 + assert mocked_container.update_files.call_count == 1 + mocked_container.update_files.assert_called_with( + Path("/test/working/dir/repositories/repo"), [Path("newfile")], [] + ) diff --git a/tests/lang_graph/nodes/test_user_defined_build_node.py b/tests/lang_graph/nodes/test_user_defined_build_node.py index 1ed8f83d..9374bb82 100644 --- a/tests/lang_graph/nodes/test_user_defined_build_node.py +++ b/tests/lang_graph/nodes/test_user_defined_build_node.py @@ -9,24 +9,24 @@ @pytest.fixture def mock_container(): - container = Mock(spec=BaseContainer) - container.run_build.return_value = "Build successful" - return container + container = Mock(spec=BaseContainer) + container.run_build.return_value = "Build successful" + return container @pytest.fixture def build_node(mock_container): - return UserDefinedBuildNode(container=mock_container) + return UserDefinedBuildNode(container=mock_container) def test_successful_build(build_node, mock_container): - expected_output = "Build successful" + expected_output = "Build successful" - result = build_node(None) + result = build_node(None) - assert isinstance(result, dict) - assert "build_messages" 
in result - assert len(result["build_messages"]) == 1 - assert isinstance(result["build_messages"][0], ToolMessage) - assert result["build_messages"][0].content == expected_output - mock_container.run_build.assert_called_once() + assert isinstance(result, dict) + assert "build_messages" in result + assert len(result["build_messages"]) == 1 + assert isinstance(result["build_messages"][0], ToolMessage) + assert result["build_messages"][0].content == expected_output + mock_container.run_build.assert_called_once() diff --git a/tests/lang_graph/nodes/test_user_defined_test_node.py b/tests/lang_graph/nodes/test_user_defined_test_node.py index 4b2f0c69..06b3ec60 100644 --- a/tests/lang_graph/nodes/test_user_defined_test_node.py +++ b/tests/lang_graph/nodes/test_user_defined_test_node.py @@ -9,24 +9,24 @@ @pytest.fixture def mock_container(): - container = Mock(spec=BaseContainer) - container.run_test.return_value = "Test successful" - return container + container = Mock(spec=BaseContainer) + container.run_test.return_value = "Test successful" + return container @pytest.fixture def test_node(mock_container): - return UserDefinedTestNode(container=mock_container) + return UserDefinedTestNode(container=mock_container) def test_successful_test(test_node, mock_container): - expected_output = "Test successful" + expected_output = "Test successful" - result = test_node(None) + result = test_node(None) - assert isinstance(result, dict) - assert "test_messages" in result - assert len(result["test_messages"]) == 1 - assert isinstance(result["test_messages"][0], ToolMessage) - assert result["test_messages"][0].content == expected_output - mock_container.run_test.assert_called_once() + assert isinstance(result, dict) + assert "test_messages" in result + assert len(result["test_messages"]) == 1 + assert isinstance(result["test_messages"][0], ToolMessage) + assert result["test_messages"][0].content == expected_output + mock_container.run_test.assert_called_once() diff --git 
a/tests/lang_graph/subgraphs/test_bug_fix_verification_subgraph.py b/tests/lang_graph/subgraphs/test_bug_fix_verification_subgraph.py index 14ca4346..f7e5f25a 100644 --- a/tests/lang_graph/subgraphs/test_bug_fix_verification_subgraph.py +++ b/tests/lang_graph/subgraphs/test_bug_fix_verification_subgraph.py @@ -10,18 +10,18 @@ @pytest.fixture def mock_container(): - return Mock(spec=BaseContainer) + return Mock(spec=BaseContainer) @pytest.fixture def mock_checkpointer(): - return Mock(spec=BaseCheckpointSaver) + return Mock(spec=BaseCheckpointSaver) def test_bug_fix_verification_subgraph_basic_initialization(mock_container): - """Test that BugFixVerificationSubgraph initializes correctly with basic components.""" - fake_model = FakeListChatWithToolsModel(responses=[]) + """Test that BugFixVerificationSubgraph initializes correctly with basic components.""" + fake_model = FakeListChatWithToolsModel(responses=[]) - subgraph = BugFixVerificationSubgraph(fake_model, mock_container) + subgraph = BugFixVerificationSubgraph(fake_model, mock_container) - assert subgraph.subgraph is not None + assert subgraph.subgraph is not None diff --git a/tests/lang_graph/subgraphs/test_bug_reproduction_subgraph.py b/tests/lang_graph/subgraphs/test_bug_reproduction_subgraph.py index 9c678190..eee385f9 100644 --- a/tests/lang_graph/subgraphs/test_bug_reproduction_subgraph.py +++ b/tests/lang_graph/subgraphs/test_bug_reproduction_subgraph.py @@ -12,44 +12,44 @@ @pytest.fixture def mock_container(): - return Mock(spec=BaseContainer) + return Mock(spec=BaseContainer) @pytest.fixture def mock_kg(): - kg = Mock(spec=KnowledgeGraph) - kg.get_all_ast_node_types.return_value = ["FunctionDef", "ClassDef", "Module", "Import", "Call"] - return kg + kg = Mock(spec=KnowledgeGraph) + kg.get_all_ast_node_types.return_value = ["FunctionDef", "ClassDef", "Module", "Import", "Call"] + return kg @pytest.fixture def mock_git_repo(): - return Mock(spec=GitRepository) + return Mock(spec=GitRepository) 
@pytest.fixture def mock_neo4j_driver(): - return Mock(spec=neo4j.Driver) + return Mock(spec=neo4j.Driver) def test_bug_reproduction_subgraph_basic_initialization( - mock_container, mock_kg, mock_git_repo, mock_neo4j_driver + mock_container, mock_kg, mock_git_repo, mock_neo4j_driver ): - """Test that BugReproductionSubgraph initializes correctly with basic components.""" - # Initialize fake model with empty responses - fake_advanced_model = FakeListChatWithToolsModel(responses=[]) - fake_base_model = FakeListChatWithToolsModel(responses=[]) - - # Initialize the subgraph - subgraph = BugReproductionSubgraph( - fake_advanced_model, - fake_base_model, - mock_container, - mock_kg, - mock_git_repo, - mock_neo4j_driver, - 0, - ) - - # Verify the subgraph was created - assert subgraph.subgraph is not None + """Test that BugReproductionSubgraph initializes correctly with basic components.""" + # Initialize fake model with empty responses + fake_advanced_model = FakeListChatWithToolsModel(responses=[]) + fake_base_model = FakeListChatWithToolsModel(responses=[]) + + # Initialize the subgraph + subgraph = BugReproductionSubgraph( + fake_advanced_model, + fake_base_model, + mock_container, + mock_kg, + mock_git_repo, + mock_neo4j_driver, + 0, + ) + + # Verify the subgraph was created + assert subgraph.subgraph is not None diff --git a/tests/lang_graph/subgraphs/test_build_and_test_subgraph.py b/tests/lang_graph/subgraphs/test_build_and_test_subgraph.py index 6586fc66..f955e73a 100644 --- a/tests/lang_graph/subgraphs/test_build_and_test_subgraph.py +++ b/tests/lang_graph/subgraphs/test_build_and_test_subgraph.py @@ -10,38 +10,38 @@ @pytest.fixture def mock_container(): - return Mock(spec=BaseContainer) + return Mock(spec=BaseContainer) @pytest.fixture def mock_kg(): - return Mock(spec=KnowledgeGraph) + return Mock(spec=KnowledgeGraph) def test_build_and_test_subgraph_basic_initialization(mock_container, mock_kg): - """Test that BuildAndTestSubgraph initializes correctly with 
basic components.""" - # Initialize fake model with empty responses - fake_model = FakeListChatWithToolsModel(responses=[]) + """Test that BuildAndTestSubgraph initializes correctly with basic components.""" + # Initialize fake model with empty responses + fake_model = FakeListChatWithToolsModel(responses=[]) - # Initialize the subgraph - subgraph = BuildAndTestSubgraph(container=mock_container, model=fake_model, kg=mock_kg) + # Initialize the subgraph + subgraph = BuildAndTestSubgraph(container=mock_container, model=fake_model, kg=mock_kg) - # Verify the subgraph was created - assert subgraph.subgraph is not None + # Verify the subgraph was created + assert subgraph.subgraph is not None def test_build_and_test_subgraph_with_commands(mock_container, mock_kg): - """Test that BuildAndTestSubgraph initializes correctly with build and test commands.""" - fake_model = FakeListChatWithToolsModel(responses=[]) - build_commands = ["make build"] - test_commands = ["make test"] - - subgraph = BuildAndTestSubgraph( - container=mock_container, - model=fake_model, - kg=mock_kg, - build_commands=build_commands, - test_commands=test_commands, - ) - - assert subgraph.subgraph is not None + """Test that BuildAndTestSubgraph initializes correctly with build and test commands.""" + fake_model = FakeListChatWithToolsModel(responses=[]) + build_commands = ["make build"] + test_commands = ["make test"] + + subgraph = BuildAndTestSubgraph( + container=mock_container, + model=fake_model, + kg=mock_kg, + build_commands=build_commands, + test_commands=test_commands, + ) + + assert subgraph.subgraph is not None diff --git a/tests/lang_graph/subgraphs/test_issue_bug_subgraph.py b/tests/lang_graph/subgraphs/test_issue_bug_subgraph.py index a53b7a76..5ab87faf 100644 --- a/tests/lang_graph/subgraphs/test_issue_bug_subgraph.py +++ b/tests/lang_graph/subgraphs/test_issue_bug_subgraph.py @@ -12,69 +12,69 @@ @pytest.fixture def mock_container(): - return Mock(spec=BaseContainer) + return 
Mock(spec=BaseContainer) @pytest.fixture def mock_kg(): - kg = Mock(spec=KnowledgeGraph) - # Configure the mock to return a list of AST node types - kg.get_all_ast_node_types.return_value = ["FunctionDef", "ClassDef", "Module", "Import", "Call"] - return kg + kg = Mock(spec=KnowledgeGraph) + # Configure the mock to return a list of AST node types + kg.get_all_ast_node_types.return_value = ["FunctionDef", "ClassDef", "Module", "Import", "Call"] + return kg @pytest.fixture def mock_git_repo(): - return Mock(spec=GitRepository) + return Mock(spec=GitRepository) @pytest.fixture def mock_neo4j_driver(): - return Mock(spec=neo4j.Driver) + return Mock(spec=neo4j.Driver) def test_issue_bug_subgraph_basic_initialization( - mock_container, mock_kg, mock_git_repo, mock_neo4j_driver + mock_container, mock_kg, mock_git_repo, mock_neo4j_driver ): - """Test that IssueBugSubgraph initializes correctly with basic components.""" - # Initialize fake model with empty responses - fake_advanced_model = FakeListChatWithToolsModel(responses=[]) - fake_base_model = FakeListChatWithToolsModel(responses=[]) - - # Initialize the subgraph with required parameters - subgraph = IssueBugSubgraph( - advanced_model=fake_advanced_model, - base_model=fake_base_model, - container=mock_container, - kg=mock_kg, - git_repo=mock_git_repo, - neo4j_driver=mock_neo4j_driver, - max_token_per_neo4j_result=1000, - ) - - # Verify the subgraph was created - assert subgraph.subgraph is not None + """Test that IssueBugSubgraph initializes correctly with basic components.""" + # Initialize fake model with empty responses + fake_advanced_model = FakeListChatWithToolsModel(responses=[]) + fake_base_model = FakeListChatWithToolsModel(responses=[]) + + # Initialize the subgraph with required parameters + subgraph = IssueBugSubgraph( + advanced_model=fake_advanced_model, + base_model=fake_base_model, + container=mock_container, + kg=mock_kg, + git_repo=mock_git_repo, + neo4j_driver=mock_neo4j_driver, + 
max_token_per_neo4j_result=1000, + ) + + # Verify the subgraph was created + assert subgraph.subgraph is not None def test_issue_bug_subgraph_with_commands( - mock_container, mock_kg, mock_git_repo, mock_neo4j_driver + mock_container, mock_kg, mock_git_repo, mock_neo4j_driver ): - """Test that IssueBugSubgraph initializes correctly with build and test commands.""" - fake_advanced_model = FakeListChatWithToolsModel(responses=[]) - fake_base_model = FakeListChatWithToolsModel(responses=[]) - build_commands = ["make build"] - test_commands = ["make test"] - - subgraph = IssueBugSubgraph( - advanced_model=fake_advanced_model, - base_model=fake_base_model, - container=mock_container, - kg=mock_kg, - git_repo=mock_git_repo, - neo4j_driver=mock_neo4j_driver, - max_token_per_neo4j_result=1000, - build_commands=build_commands, - test_commands=test_commands, - ) - - assert subgraph.subgraph is not None + """Test that IssueBugSubgraph initializes correctly with build and test commands.""" + fake_advanced_model = FakeListChatWithToolsModel(responses=[]) + fake_base_model = FakeListChatWithToolsModel(responses=[]) + build_commands = ["make build"] + test_commands = ["make test"] + + subgraph = IssueBugSubgraph( + advanced_model=fake_advanced_model, + base_model=fake_base_model, + container=mock_container, + kg=mock_kg, + git_repo=mock_git_repo, + neo4j_driver=mock_neo4j_driver, + max_token_per_neo4j_result=1000, + build_commands=build_commands, + test_commands=test_commands, + ) + + assert subgraph.subgraph is not None diff --git a/tests/lang_graph/subgraphs/test_issue_classification_subgraph.py b/tests/lang_graph/subgraphs/test_issue_classification_subgraph.py index f9dec95e..59047678 100644 --- a/tests/lang_graph/subgraphs/test_issue_classification_subgraph.py +++ b/tests/lang_graph/subgraphs/test_issue_classification_subgraph.py @@ -5,33 +5,36 @@ from prometheus.graph.knowledge_graph import KnowledgeGraph from prometheus.lang_graph.subgraphs.issue_classification_subgraph 
import ( - IssueClassificationSubgraph, + IssueClassificationSubgraph, ) from tests.test_utils.util import FakeListChatWithToolsModel @pytest.fixture def mock_kg(): - kg = Mock(spec=KnowledgeGraph) - # Configure the mock to return a list of AST node types - kg.get_all_ast_node_types.return_value = ["FunctionDef", "ClassDef", "Module", "Import", "Call"] - return kg + kg = Mock(spec=KnowledgeGraph) + # Configure the mock to return a list of AST node types + kg.get_all_ast_node_types.return_value = ["FunctionDef", "ClassDef", "Module", "Import", "Call"] + return kg @pytest.fixture def mock_neo4j_driver(): - return Mock(spec=neo4j.Driver) + return Mock(spec=neo4j.Driver) def test_issue_classification_subgraph_basic_initialization(mock_kg, mock_neo4j_driver): - """Test that IssueClassificationSubgraph initializes correctly with basic components.""" - # Initialize fake model with empty responses - fake_model = FakeListChatWithToolsModel(responses=[]) - - # Initialize the subgraph with required parameters - subgraph = IssueClassificationSubgraph( - model=fake_model, kg=mock_kg, neo4j_driver=mock_neo4j_driver, max_token_per_neo4j_result=1000 - ) - - # Verify the subgraph was created - assert subgraph.subgraph is not None + """Test that IssueClassificationSubgraph initializes correctly with basic components.""" + # Initialize fake model with empty responses + fake_model = FakeListChatWithToolsModel(responses=[]) + + # Initialize the subgraph with required parameters + subgraph = IssueClassificationSubgraph( + model=fake_model, + kg=mock_kg, + neo4j_driver=mock_neo4j_driver, + max_token_per_neo4j_result=1000, + ) + + # Verify the subgraph was created + assert subgraph.subgraph is not None diff --git a/tests/neo4j/test_knowledge_graph_handler.py b/tests/neo4j/test_knowledge_graph_handler.py index 01f27491..c51ae850 100644 --- a/tests/neo4j/test_knowledge_graph_handler.py +++ b/tests/neo4j/test_knowledge_graph_handler.py @@ -5,129 +5,129 @@ from 
prometheus.neo4j.knowledge_graph_handler import KnowledgeGraphHandler from tests.test_utils import test_project_paths from tests.test_utils.fixtures import ( # noqa: F401 - empty_neo4j_container_fixture, - neo4j_container_with_kg_fixture, + empty_neo4j_container_fixture, + neo4j_container_with_kg_fixture, ) @pytest.mark.slow def test_num_metadata_node(neo4j_container_with_kg_fixture): # noqa: F811 - neo4j_container, _ = neo4j_container_with_kg_fixture - handler = KnowledgeGraphHandler(neo4j_container.get_driver(), 100) + neo4j_container, _ = neo4j_container_with_kg_fixture + handler = KnowledgeGraphHandler(neo4j_container.get_driver(), 100) - with neo4j_container.get_driver() as driver: - with driver.session() as session: - read_metadata_node = session.execute_read(handler._read_metadata_node) - assert isinstance(read_metadata_node, MetadataNode) + with neo4j_container.get_driver() as driver: + with driver.session() as session: + read_metadata_node = session.execute_read(handler._read_metadata_node) + assert isinstance(read_metadata_node, MetadataNode) @pytest.mark.slow def test_num_ast_nodes(neo4j_container_with_kg_fixture): # noqa: F811 - neo4j_container, _ = neo4j_container_with_kg_fixture - handler = KnowledgeGraphHandler(neo4j_container.get_driver(), 100) + neo4j_container, _ = neo4j_container_with_kg_fixture + handler = KnowledgeGraphHandler(neo4j_container.get_driver(), 100) - with neo4j_container.get_driver() as driver: - with driver.session() as session: - read_ast_nodes = session.execute_read(handler._read_ast_nodes) - assert len(read_ast_nodes) == 84 + with neo4j_container.get_driver() as driver: + with driver.session() as session: + read_ast_nodes = session.execute_read(handler._read_ast_nodes) + assert len(read_ast_nodes) == 84 @pytest.mark.slow def test_num_file_nodes(neo4j_container_with_kg_fixture): # noqa: F811 - neo4j_container, _ = neo4j_container_with_kg_fixture - handler = KnowledgeGraphHandler(neo4j_container.get_driver(), 100) + 
neo4j_container, _ = neo4j_container_with_kg_fixture + handler = KnowledgeGraphHandler(neo4j_container.get_driver(), 100) - with neo4j_container.get_driver() as driver: - with driver.session() as session: - read_file_nodes = session.execute_read(handler._read_file_nodes) - assert len(read_file_nodes) == 9 + with neo4j_container.get_driver() as driver: + with driver.session() as session: + read_file_nodes = session.execute_read(handler._read_file_nodes) + assert len(read_file_nodes) == 9 @pytest.mark.slow def test_num_text_nodes(neo4j_container_with_kg_fixture): # noqa: F811 - neo4j_container, _ = neo4j_container_with_kg_fixture - handler = KnowledgeGraphHandler(neo4j_container.get_driver(), 100) + neo4j_container, _ = neo4j_container_with_kg_fixture + handler = KnowledgeGraphHandler(neo4j_container.get_driver(), 100) - with neo4j_container.get_driver() as driver: - with driver.session() as session: - read_text_nodes = session.execute_read(handler._read_text_nodes) - assert len(read_text_nodes) == 2 + with neo4j_container.get_driver() as driver: + with driver.session() as session: + read_text_nodes = session.execute_read(handler._read_text_nodes) + assert len(read_text_nodes) == 2 @pytest.mark.slow def test_num_parent_of_edges(neo4j_container_with_kg_fixture): # noqa: F811 - neo4j_container, _ = neo4j_container_with_kg_fixture - handler = KnowledgeGraphHandler(neo4j_container.get_driver(), 100) + neo4j_container, _ = neo4j_container_with_kg_fixture + handler = KnowledgeGraphHandler(neo4j_container.get_driver(), 100) - with neo4j_container.get_driver() as driver: - with driver.session() as session: - read_parent_of_edges = session.execute_read(handler._read_parent_of_edges) - assert len(read_parent_of_edges) == 81 + with neo4j_container.get_driver() as driver: + with driver.session() as session: + read_parent_of_edges = session.execute_read(handler._read_parent_of_edges) + assert len(read_parent_of_edges) == 81 @pytest.mark.slow def 
test_num_has_file_edges(neo4j_container_with_kg_fixture): # noqa: F811 - neo4j_container, _ = neo4j_container_with_kg_fixture - handler = KnowledgeGraphHandler(neo4j_container.get_driver(), 100) + neo4j_container, _ = neo4j_container_with_kg_fixture + handler = KnowledgeGraphHandler(neo4j_container.get_driver(), 100) - with neo4j_container.get_driver() as driver: - with driver.session() as session: - read_has_file_edges = session.execute_read(handler._read_has_file_edges) - assert len(read_has_file_edges) == 8 + with neo4j_container.get_driver() as driver: + with driver.session() as session: + read_has_file_edges = session.execute_read(handler._read_has_file_edges) + assert len(read_has_file_edges) == 8 @pytest.mark.slow def test_num_has_ast_edges(neo4j_container_with_kg_fixture): # noqa: F811 - neo4j_container, _ = neo4j_container_with_kg_fixture - handler = KnowledgeGraphHandler(neo4j_container.get_driver(), 100) + neo4j_container, _ = neo4j_container_with_kg_fixture + handler = KnowledgeGraphHandler(neo4j_container.get_driver(), 100) - with neo4j_container.get_driver() as driver: - with driver.session() as session: - read_has_ast_edges = session.execute_read(handler._read_has_ast_edges) - assert len(read_has_ast_edges) == 3 + with neo4j_container.get_driver() as driver: + with driver.session() as session: + read_has_ast_edges = session.execute_read(handler._read_has_ast_edges) + assert len(read_has_ast_edges) == 3 @pytest.mark.slow def test_num_has_text_edges(neo4j_container_with_kg_fixture): # noqa: F811 - neo4j_container, _ = neo4j_container_with_kg_fixture - handler = KnowledgeGraphHandler(neo4j_container.get_driver(), 100) + neo4j_container, _ = neo4j_container_with_kg_fixture + handler = KnowledgeGraphHandler(neo4j_container.get_driver(), 100) - with neo4j_container.get_driver() as driver: - with driver.session() as session: - read_has_text_edges = session.execute_read(handler._read_has_text_edges) - assert len(read_has_text_edges) == 2 + with 
neo4j_container.get_driver() as driver: + with driver.session() as session: + read_has_text_edges = session.execute_read(handler._read_has_text_edges) + assert len(read_has_text_edges) == 2 @pytest.mark.slow def test_num_next_chunk_edges(neo4j_container_with_kg_fixture): # noqa: F811 - neo4j_container, _ = neo4j_container_with_kg_fixture - handler = KnowledgeGraphHandler(neo4j_container.get_driver(), 100) + neo4j_container, _ = neo4j_container_with_kg_fixture + handler = KnowledgeGraphHandler(neo4j_container.get_driver(), 100) - with neo4j_container.get_driver() as driver: - with driver.session() as session: - read_next_chunk_edges = session.execute_read(handler._read_next_chunk_edges) - assert len(read_next_chunk_edges) == 1 + with neo4j_container.get_driver() as driver: + with driver.session() as session: + read_next_chunk_edges = session.execute_read(handler._read_next_chunk_edges) + assert len(read_next_chunk_edges) == 1 @pytest.mark.slow def test_knowledge_graph_exists(neo4j_container_with_kg_fixture): # noqa: F811 - neo4j_container, _ = neo4j_container_with_kg_fixture - handler = KnowledgeGraphHandler(neo4j_container.get_driver(), 100) + neo4j_container, _ = neo4j_container_with_kg_fixture + handler = KnowledgeGraphHandler(neo4j_container.get_driver(), 100) - assert handler.knowledge_graph_exists() + assert handler.knowledge_graph_exists() @pytest.mark.slow def test_clear_knowledge_graph(empty_neo4j_container_fixture): # noqa: F811 - kg = KnowledgeGraph(1000, 1000, 100) - kg.build_graph(test_project_paths.TEST_PROJECT_PATH) + kg = KnowledgeGraph(1000, 1000, 100) + kg.build_graph(test_project_paths.TEST_PROJECT_PATH) - driver = empty_neo4j_container_fixture.get_driver() - handler = KnowledgeGraphHandler(driver, 100) - handler.write_knowledge_graph(kg) + driver = empty_neo4j_container_fixture.get_driver() + handler = KnowledgeGraphHandler(driver, 100) + handler.write_knowledge_graph(kg) - assert handler.knowledge_graph_exists() + assert 
handler.knowledge_graph_exists() - handler.clear_knowledge_graph() + handler.clear_knowledge_graph() - assert not handler.knowledge_graph_exists() + assert not handler.knowledge_graph_exists() diff --git a/tests/parser/test_file_types.py b/tests/parser/test_file_types.py index 8463fffe..25503e59 100644 --- a/tests/parser/test_file_types.py +++ b/tests/parser/test_file_types.py @@ -6,29 +6,29 @@ @pytest.mark.parametrize( - "file_path,expected_type", - [ - ("script.sh", FileType.BASH), - ("program.c", FileType.C), - ("app.cs", FileType.CSHARP), - ("class.cpp", FileType.CPP), - ("header.cc", FileType.CPP), - ("file.cxx", FileType.CPP), - ("main.go", FileType.GO), - ("Class.java", FileType.JAVA), - ("app.js", FileType.JAVASCRIPT), - ("Service.kt", FileType.KOTLIN), - ("index.php", FileType.PHP), - ("script.py", FileType.PYTHON), - ("query.sql", FileType.SQL), - ("config.yaml", FileType.YAML), - ("docker-compose.yml", FileType.YAML), - ("readme.md", FileType.UNKNOWN), - ("Makefile", FileType.UNKNOWN), - ("", FileType.UNKNOWN), - ], + "file_path,expected_type", + [ + ("script.sh", FileType.BASH), + ("program.c", FileType.C), + ("app.cs", FileType.CSHARP), + ("class.cpp", FileType.CPP), + ("header.cc", FileType.CPP), + ("file.cxx", FileType.CPP), + ("main.go", FileType.GO), + ("Class.java", FileType.JAVA), + ("app.js", FileType.JAVASCRIPT), + ("Service.kt", FileType.KOTLIN), + ("index.php", FileType.PHP), + ("script.py", FileType.PYTHON), + ("query.sql", FileType.SQL), + ("config.yaml", FileType.YAML), + ("docker-compose.yml", FileType.YAML), + ("readme.md", FileType.UNKNOWN), + ("Makefile", FileType.UNKNOWN), + ("", FileType.UNKNOWN), + ], ) def test_file_type_from_path(file_path: str, expected_type: FileType): - """Test that file extensions are correctly mapped to FileTypes""" - path = Path(file_path) - assert FileType.from_path(path) == expected_type + """Test that file extensions are correctly mapped to FileTypes""" + path = Path(file_path) + assert 
FileType.from_path(path) == expected_type diff --git a/tests/parser/test_tree_sitter_parser.py b/tests/parser/test_tree_sitter_parser.py index bdc1ff98..43132a21 100644 --- a/tests/parser/test_tree_sitter_parser.py +++ b/tests/parser/test_tree_sitter_parser.py @@ -6,81 +6,81 @@ from prometheus.parser.file_types import FileType from prometheus.parser.tree_sitter_parser import ( - FILE_TYPE_TO_LANG, - FileNotSupportedError, - parse, - supports_file, + FILE_TYPE_TO_LANG, + FileNotSupportedError, + parse, + supports_file, ) # Test fixtures @pytest.fixture def mock_python_file(): - return Path("test.py") + return Path("test.py") @pytest.fixture def mock_unsupported_file(): - return Path("test.unsupported") + return Path("test.unsupported") @pytest.fixture def mock_tree(): - tree = create_autospec(Tree, instance=True) - return tree + tree = create_autospec(Tree, instance=True) + return tree # Test supports_file function def test_supports_file_with_supported_type(mock_python_file): - with patch("prometheus.parser.file_types.FileType.from_path") as mock_from_path: - mock_from_path.return_value = FileType.PYTHON - assert supports_file(mock_python_file) is True - mock_from_path.assert_called_once_with(mock_python_file) + with patch("prometheus.parser.file_types.FileType.from_path") as mock_from_path: + mock_from_path.return_value = FileType.PYTHON + assert supports_file(mock_python_file) is True + mock_from_path.assert_called_once_with(mock_python_file) def test_supports_file_with_unsupported_type(mock_unsupported_file): - with patch("prometheus.parser.file_types.FileType.from_path") as mock_from_path: - mock_from_path.return_value = None - assert supports_file(mock_unsupported_file) is False - mock_from_path.assert_called_once_with(mock_unsupported_file) + with patch("prometheus.parser.file_types.FileType.from_path") as mock_from_path: + mock_from_path.return_value = None + assert supports_file(mock_unsupported_file) is False + 
mock_from_path.assert_called_once_with(mock_unsupported_file) def test_parse_python_file_successfully(mock_python_file, mock_tree): - mock_content = b'print("hello")' - m = mock_open(read_data=mock_content) + mock_content = b'print("hello")' + m = mock_open(read_data=mock_content) - with ( - patch("prometheus.parser.file_types.FileType.from_path") as mock_from_path, - patch("prometheus.parser.tree_sitter_parser.get_parser") as mock_get_parser, - patch.object(Path, "open", m), - ): - # Setup mocks - mock_from_path.return_value = FileType.PYTHON - mock_parser = mock_get_parser.return_value - mock_parser.parse.return_value = mock_tree + with ( + patch("prometheus.parser.file_types.FileType.from_path") as mock_from_path, + patch("prometheus.parser.tree_sitter_parser.get_parser") as mock_get_parser, + patch.object(Path, "open", m), + ): + # Setup mocks + mock_from_path.return_value = FileType.PYTHON + mock_parser = mock_get_parser.return_value + mock_parser.parse.return_value = mock_tree - # Test parse function - result = parse(mock_python_file) + # Test parse function + result = parse(mock_python_file) - # Verify results and interactions - assert result == mock_tree - mock_from_path.assert_called_once_with(mock_python_file) - mock_get_parser.assert_called_once_with("python") - mock_parser.parse.assert_called_once_with(mock_content) + # Verify results and interactions + assert result == mock_tree + mock_from_path.assert_called_once_with(mock_python_file) + mock_get_parser.assert_called_once_with("python") + mock_parser.parse.assert_called_once_with(mock_content) def test_parse_unsupported_file_raises_error(mock_unsupported_file): - with pytest.raises(FileNotSupportedError): - parse(mock_unsupported_file) + with pytest.raises(FileNotSupportedError): + parse(mock_unsupported_file) def test_parse_all_supported_languages(): - """Test that we can get parsers for all supported languages.""" - with patch("tree_sitter_languages.get_parser") as mock_get_parser: - mock_parser = 
MagicMock() - mock_get_parser.return_value = mock_parser - - for lang in FILE_TYPE_TO_LANG.values(): - mock_get_parser(lang) - mock_get_parser.assert_called_with(lang) - mock_get_parser.reset_mock() + """Test that we can get parsers for all supported languages.""" + with patch("tree_sitter_languages.get_parser") as mock_get_parser: + mock_parser = MagicMock() + mock_get_parser.return_value = mock_parser + + for lang in FILE_TYPE_TO_LANG.values(): + mock_get_parser(lang) + mock_get_parser.assert_called_with(lang) + mock_get_parser.reset_mock() diff --git a/tests/test_utils/fixtures.py b/tests/test_utils/fixtures.py index c2dddadd..737da765 100644 --- a/tests/test_utils/fixtures.py +++ b/tests/test_utils/fixtures.py @@ -18,42 +18,42 @@ @pytest.fixture(scope="session") def neo4j_container_with_kg_fixture(): - kg = KnowledgeGraph(1000, 100, 10) - kg.build_graph(test_project_paths.TEST_PROJECT_PATH) - container = ( - Neo4jContainer(image=NEO4J_IMAGE, username=NEO4J_USERNAME, password=NEO4J_PASSWORD) - .with_env("NEO4J_PLUGINS", '["apoc"]') - .with_name(f"neo4j_container_with_kg_{uuid.uuid4().hex[:12]}") - ) - with container as neo4j_container: - driver = neo4j_container.get_driver() - handler = KnowledgeGraphHandler(driver, 100) - handler.write_knowledge_graph(kg) - yield neo4j_container, kg + kg = KnowledgeGraph(1000, 100, 10) + kg.build_graph(test_project_paths.TEST_PROJECT_PATH) + container = ( + Neo4jContainer(image=NEO4J_IMAGE, username=NEO4J_USERNAME, password=NEO4J_PASSWORD) + .with_env("NEO4J_PLUGINS", '["apoc"]') + .with_name(f"neo4j_container_with_kg_{uuid.uuid4().hex[:12]}") + ) + with container as neo4j_container: + driver = neo4j_container.get_driver() + handler = KnowledgeGraphHandler(driver, 100) + handler.write_knowledge_graph(kg) + yield neo4j_container, kg @pytest.fixture(scope="session") def empty_neo4j_container_fixture(): - container = ( - Neo4jContainer(image=NEO4J_IMAGE, username=NEO4J_USERNAME, password=NEO4J_PASSWORD) - 
.with_env("NEO4J_PLUGINS", '["apoc"]') - .with_name(f"empty_neo4j_container_{uuid.uuid4().hex[:12]}") - ) - with container as neo4j_container: - yield neo4j_container + container = ( + Neo4jContainer(image=NEO4J_IMAGE, username=NEO4J_USERNAME, password=NEO4J_PASSWORD) + .with_env("NEO4J_PLUGINS", '["apoc"]') + .with_name(f"empty_neo4j_container_{uuid.uuid4().hex[:12]}") + ) + with container as neo4j_container: + yield neo4j_container @pytest.fixture(scope="function") def git_repo_fixture(): - temp_dir = Path(tempfile.mkdtemp()) - temp_project_dir = temp_dir / "test_project" - original_project_path = test_project_paths.TEST_PROJECT_PATH - - try: - shutil.copytree(original_project_path, temp_project_dir) - shutil.move(temp_project_dir / test_project_paths.GIT_DIR.name, temp_project_dir / ".git") - - repo = Repo(temp_project_dir) - yield repo - finally: - shutil.rmtree(temp_project_dir) + temp_dir = Path(tempfile.mkdtemp()) + temp_project_dir = temp_dir / "test_project" + original_project_path = test_project_paths.TEST_PROJECT_PATH + + try: + shutil.copytree(original_project_path, temp_project_dir) + shutil.move(temp_project_dir / test_project_paths.GIT_DIR.name, temp_project_dir / ".git") + + repo = Repo(temp_project_dir) + yield repo + finally: + shutil.rmtree(temp_project_dir) diff --git a/tests/test_utils/util.py b/tests/test_utils/util.py index 607cea35..c0d4c38f 100644 --- a/tests/test_utils/util.py +++ b/tests/test_utils/util.py @@ -3,11 +3,11 @@ class FakeListChatWithToolsModel(FakeListChatModel): - def bind_tools(self, tools=None, tool_choice=None): - return self + def bind_tools(self, tools=None, tool_choice=None): + return self def clean_neo4j_container(neo4j_container: Neo4jContainer): - with neo4j_container.get_driver() as driver: - with driver.session() as session: - session.run("MATCH (n) DETACH DELETE n") + with neo4j_container.get_driver() as driver: + with driver.session() as session: + session.run("MATCH (n) DETACH DELETE n") diff --git 
a/tests/tools/test_file_operation.py b/tests/tools/test_file_operation.py index b3cb206a..d172fd78 100644 --- a/tests/tools/test_file_operation.py +++ b/tests/tools/test_file_operation.py @@ -1,121 +1,121 @@ import pytest from prometheus.tools.file_operation import ( - create_file, - delete, - edit_file, - read_file, - read_file_with_line_numbers, + create_file, + delete, + edit_file, + read_file, + read_file_with_line_numbers, ) @pytest.fixture def temp_test_dir(tmp_path): - """Create a temporary test directory.""" - test_dir = tmp_path / "test_files" - test_dir.mkdir() - yield test_dir - # Cleanup happens automatically after tests due to tmp_path fixture + """Create a temporary test directory.""" + test_dir = tmp_path / "test_files" + test_dir.mkdir() + yield test_dir + # Cleanup happens automatically after tests due to tmp_path fixture def test_create_and_read_file(temp_test_dir): - """Test creating a file and reading its contents.""" - test_file = temp_test_dir / "test.txt" - content = "line 1\nline 2\nline 3" + """Test creating a file and reading its contents.""" + test_file = temp_test_dir / "test.txt" + content = "line 1\nline 2\nline 3" - # Test create_file - result = create_file("test.txt", str(temp_test_dir), content) - assert test_file.exists() - assert test_file.read_text() == content - assert result == "The file test.txt has been created." + # Test create_file + result = create_file("test.txt", str(temp_test_dir), content) + assert test_file.exists() + assert test_file.read_text() == content + assert result == "The file test.txt has been created." - # Test read_file - result = read_file("test.txt", str(temp_test_dir)) - expected = "1. line 1\n2. line 2\n3. line 3" - assert result == expected + # Test read_file + result = read_file("test.txt", str(temp_test_dir)) + expected = "1. line 1\n2. line 2\n3. 
line 3" + assert result == expected def test_read_file_nonexistent(temp_test_dir): - """Test reading a nonexistent file.""" - result = read_file("nonexistent_file.txt", str(temp_test_dir)) - assert result == "The file nonexistent_file.txt does not exist." + """Test reading a nonexistent file.""" + result = read_file("nonexistent_file.txt", str(temp_test_dir)) + assert result == "The file nonexistent_file.txt does not exist." def test_read_file_with_line_numbers(temp_test_dir): - """Test reading specific line ranges from a file.""" - content = "line 1\nline 2\nline 3\nline 4\nline 5" - create_file("test_lines.txt", str(temp_test_dir), content) + """Test reading specific line ranges from a file.""" + content = "line 1\nline 2\nline 3\nline 4\nline 5" + create_file("test_lines.txt", str(temp_test_dir), content) - # Test reading specific lines - result = read_file_with_line_numbers("test_lines.txt", str(temp_test_dir), 2, 4) - expected = "2. line 2\n3. line 3" - assert result == expected + # Test reading specific lines + result = read_file_with_line_numbers("test_lines.txt", str(temp_test_dir), 2, 4) + expected = "2. line 2\n3. line 3" + assert result == expected - # Test invalid range - result = read_file_with_line_numbers("test_lines.txt", str(temp_test_dir), 4, 2) - assert result == "The end line number 2 must be greater than the start line number 4." + # Test invalid range + result = read_file_with_line_numbers("test_lines.txt", str(temp_test_dir), 4, 2) + assert result == "The end line number 2 must be greater than the start line number 4." def test_delete(temp_test_dir): - """Test file and directory deletion.""" - # Test file deletion - test_file = temp_test_dir / "to_delete.txt" - create_file("to_delete.txt", str(temp_test_dir), "content") - assert test_file.exists() - result = delete("to_delete.txt", str(temp_test_dir)) - assert result == "The file to_delete.txt has been deleted." 
- assert not test_file.exists() - - # Test directory deletion - test_subdir = temp_test_dir / "subdir" - test_subdir.mkdir() - create_file("subdir/file.txt", str(temp_test_dir), "content") - result = delete("subdir", str(temp_test_dir)) - assert result == "The directory subdir has been deleted." - assert not test_subdir.exists() + """Test file and directory deletion.""" + # Test file deletion + test_file = temp_test_dir / "to_delete.txt" + create_file("to_delete.txt", str(temp_test_dir), "content") + assert test_file.exists() + result = delete("to_delete.txt", str(temp_test_dir)) + assert result == "The file to_delete.txt has been deleted." + assert not test_file.exists() + + # Test directory deletion + test_subdir = temp_test_dir / "subdir" + test_subdir.mkdir() + create_file("subdir/file.txt", str(temp_test_dir), "content") + result = delete("subdir", str(temp_test_dir)) + assert result == "The directory subdir has been deleted." + assert not test_subdir.exists() def test_delete_nonexistent(temp_test_dir): - """Test deleting a nonexistent path.""" - result = delete("nonexistent_path", str(temp_test_dir)) - assert result == "The file nonexistent_path does not exist." + """Test deleting a nonexistent path.""" + result = delete("nonexistent_path", str(temp_test_dir)) + assert result == "The file nonexistent_path does not exist." def test_edit_file(temp_test_dir): - """Test editing specific lines in a file.""" - # Test case 1: Successfully edit a single occurrence - initial_content = "line 1\nline 2\nline 3\nline 4\nline 5" - create_file("edit_test.txt", str(temp_test_dir), initial_content) - result = edit_file("edit_test.txt", str(temp_test_dir), "line 2", "new line 2") - assert result == "Successfully edited edit_test.txt." - - # Test case 2: Absolute path error - result = edit_file("/edit_test.txt", str(temp_test_dir), "line 2", "new line 2") - assert result == "relative_path: /edit_test.txt is a abolsute path, not relative path." 
- - # Test case 3: File doesn't exist - result = edit_file("nonexistent.txt", str(temp_test_dir), "line 2", "new line 2") - assert result == "The file nonexistent.txt does not exist." - - # Test case 4: No matches found - result = edit_file("edit_test.txt", str(temp_test_dir), "nonexistent line", "new content") - assert ( - result - == "No match found for the specified content in edit_test.txt. Please verify the content to replace." - ) - - # Test case 5: Multiple occurrences - duplicate_content = "line 1\nline 2\nline 2\nline 3" - create_file("duplicate_test.txt", str(temp_test_dir), duplicate_content) - result = edit_file("duplicate_test.txt", str(temp_test_dir), "line 2", "new line 2") - assert ( - result - == "Found 2 occurrences of the specified content in duplicate_test.txt. Please provide more context to ensure a unique match." - ) + """Test editing specific lines in a file.""" + # Test case 1: Successfully edit a single occurrence + initial_content = "line 1\nline 2\nline 3\nline 4\nline 5" + create_file("edit_test.txt", str(temp_test_dir), initial_content) + result = edit_file("edit_test.txt", str(temp_test_dir), "line 2", "new line 2") + assert result == "Successfully edited edit_test.txt." + + # Test case 2: Absolute path error + result = edit_file("/edit_test.txt", str(temp_test_dir), "line 2", "new line 2") + assert result == "relative_path: /edit_test.txt is a abolsute path, not relative path." + + # Test case 3: File doesn't exist + result = edit_file("nonexistent.txt", str(temp_test_dir), "line 2", "new line 2") + assert result == "The file nonexistent.txt does not exist." + + # Test case 4: No matches found + result = edit_file("edit_test.txt", str(temp_test_dir), "nonexistent line", "new content") + assert ( + result + == "No match found for the specified content in edit_test.txt. Please verify the content to replace." 
+ ) + + # Test case 5: Multiple occurrences + duplicate_content = "line 1\nline 2\nline 2\nline 3" + create_file("duplicate_test.txt", str(temp_test_dir), duplicate_content) + result = edit_file("duplicate_test.txt", str(temp_test_dir), "line 2", "new line 2") + assert ( + result + == "Found 2 occurrences of the specified content in duplicate_test.txt. Please provide more context to ensure a unique match." + ) def test_create_file_already_exists(temp_test_dir): - """Test creating a file that already exists.""" - create_file("existing.txt", str(temp_test_dir), "content") - result = create_file("existing.txt", str(temp_test_dir), "new content") - assert result == "The file existing.txt already exists." + """Test creating a file that already exists.""" + create_file("existing.txt", str(temp_test_dir), "content") + result = create_file("existing.txt", str(temp_test_dir), "new content") + assert result == "The file existing.txt already exists." diff --git a/tests/tools/test_graph_traversal.py b/tests/tools/test_graph_traversal.py index 803a4784..24d5ae71 100644 --- a/tests/tools/test_graph_traversal.py +++ b/tests/tools/test_graph_traversal.py @@ -7,271 +7,277 @@ @pytest.mark.slow def test_find_file_node_with_basename(neo4j_container_with_kg_fixture): # noqa: F811 - neo4j_container, _ = neo4j_container_with_kg_fixture + neo4j_container, _ = neo4j_container_with_kg_fixture - with neo4j_container.get_driver() as driver: - result = graph_traversal.find_file_node_with_basename( - test_project_paths.PYTHON_FILE.name, driver, 1000 - ) + with neo4j_container.get_driver() as driver: + result = graph_traversal.find_file_node_with_basename( + test_project_paths.PYTHON_FILE.name, driver, 1000 + ) - basename = test_project_paths.PYTHON_FILE.name - relative_path = str( - test_project_paths.PYTHON_FILE.relative_to(test_project_paths.TEST_PROJECT_PATH).as_posix() - ) + basename = test_project_paths.PYTHON_FILE.name + relative_path = str( + 
test_project_paths.PYTHON_FILE.relative_to(test_project_paths.TEST_PROJECT_PATH).as_posix() + ) - result_data = result[1] - assert len(result_data) == 1 - assert "FileNode" in result_data[0] - assert result_data[0]["FileNode"].get("basename", "") == basename - assert result_data[0]["FileNode"].get("relative_path", "") == relative_path + result_data = result[1] + assert len(result_data) == 1 + assert "FileNode" in result_data[0] + assert result_data[0]["FileNode"].get("basename", "") == basename + assert result_data[0]["FileNode"].get("relative_path", "") == relative_path @pytest.mark.slow def test_find_file_node_with_relative_path(neo4j_container_with_kg_fixture): # noqa: F811 - neo4j_container, _ = neo4j_container_with_kg_fixture + neo4j_container, _ = neo4j_container_with_kg_fixture - relative_path = str( - test_project_paths.MD_FILE.relative_to(test_project_paths.TEST_PROJECT_PATH).as_posix() - ) - with neo4j_container.get_driver() as driver: - result = graph_traversal.find_file_node_with_relative_path(relative_path, driver, 1000) + relative_path = str( + test_project_paths.MD_FILE.relative_to(test_project_paths.TEST_PROJECT_PATH).as_posix() + ) + with neo4j_container.get_driver() as driver: + result = graph_traversal.find_file_node_with_relative_path(relative_path, driver, 1000) - basename = test_project_paths.MD_FILE.name + basename = test_project_paths.MD_FILE.name - result_data = result[1] - assert len(result_data) == 1 - assert "FileNode" in result_data[0] - assert result_data[0]["FileNode"].get("basename", "") == basename - assert result_data[0]["FileNode"].get("relative_path", "") == relative_path + result_data = result[1] + assert len(result_data) == 1 + assert "FileNode" in result_data[0] + assert result_data[0]["FileNode"].get("basename", "") == basename + assert result_data[0]["FileNode"].get("relative_path", "") == relative_path @pytest.mark.slow def test_find_ast_node_with_text_in_file_with_basename(neo4j_container_with_kg_fixture): # noqa: F811 - 
neo4j_container, _ = neo4j_container_with_kg_fixture + neo4j_container, _ = neo4j_container_with_kg_fixture - basename = test_project_paths.PYTHON_FILE.name - with neo4j_container.get_driver() as driver: - result = graph_traversal.find_ast_node_with_text_in_file_with_basename( - "Hello world!", basename, driver, 1000 - ) + basename = test_project_paths.PYTHON_FILE.name + with neo4j_container.get_driver() as driver: + result = graph_traversal.find_ast_node_with_text_in_file_with_basename( + "Hello world!", basename, driver, 1000 + ) - result_data = result[1] - assert len(result_data) > 0 - for result_row in result_data: - assert "ASTNode" in result_row - assert "Hello world!" in result_row["ASTNode"].get("text", "") - assert "FileNode" in result_row - assert result_row["FileNode"].get("basename", "") == basename + result_data = result[1] + assert len(result_data) > 0 + for result_row in result_data: + assert "ASTNode" in result_row + assert "Hello world!" in result_row["ASTNode"].get("text", "") + assert "FileNode" in result_row + assert result_row["FileNode"].get("basename", "") == basename @pytest.mark.slow def test_find_ast_node_with_text_in_file_with_relative_path(neo4j_container_with_kg_fixture): # noqa: F811 - neo4j_container, _ = neo4j_container_with_kg_fixture - - relative_path = str( - test_project_paths.C_FILE.relative_to(test_project_paths.TEST_PROJECT_PATH).as_posix() - ) - with neo4j_container.get_driver() as driver: - result = graph_traversal.find_ast_node_with_text_in_file_with_relative_path( - "Hello world!", relative_path, driver, 1000 + neo4j_container, _ = neo4j_container_with_kg_fixture + + relative_path = str( + test_project_paths.C_FILE.relative_to(test_project_paths.TEST_PROJECT_PATH).as_posix() ) + with neo4j_container.get_driver() as driver: + result = graph_traversal.find_ast_node_with_text_in_file_with_relative_path( + "Hello world!", relative_path, driver, 1000 + ) - result_data = result[1] - assert len(result_data) > 0 - for result_row 
in result_data: - assert "ASTNode" in result_row - assert "Hello world!" in result_row["ASTNode"].get("text", "") - assert "FileNode" in result_row - assert result_row["FileNode"].get("relative_path", "") == relative_path + result_data = result[1] + assert len(result_data) > 0 + for result_row in result_data: + assert "ASTNode" in result_row + assert "Hello world!" in result_row["ASTNode"].get("text", "") + assert "FileNode" in result_row + assert result_row["FileNode"].get("relative_path", "") == relative_path @pytest.mark.slow def test_find_ast_node_with_type_in_file_with_basename(neo4j_container_with_kg_fixture): # noqa: F811 - neo4j_container, _ = neo4j_container_with_kg_fixture + neo4j_container, _ = neo4j_container_with_kg_fixture - basename = test_project_paths.C_FILE.name - node_type = "function_definition" - with neo4j_container.get_driver() as driver: - result = graph_traversal.find_ast_node_with_type_in_file_with_basename( - node_type, basename, driver, 1000 - ) + basename = test_project_paths.C_FILE.name + node_type = "function_definition" + with neo4j_container.get_driver() as driver: + result = graph_traversal.find_ast_node_with_type_in_file_with_basename( + node_type, basename, driver, 1000 + ) - result_data = result[1] - assert len(result_data) > 0 - for result_row in result_data: - assert "ASTNode" in result_row - assert result_row["ASTNode"].get("type", "") == node_type - assert "FileNode" in result_row - assert result_row["FileNode"].get("basename", "") == basename + result_data = result[1] + assert len(result_data) > 0 + for result_row in result_data: + assert "ASTNode" in result_row + assert result_row["ASTNode"].get("type", "") == node_type + assert "FileNode" in result_row + assert result_row["FileNode"].get("basename", "") == basename @pytest.mark.slow def test_find_ast_node_with_type_in_file_with_relative_path(neo4j_container_with_kg_fixture): # noqa: F811 - neo4j_container, _ = neo4j_container_with_kg_fixture - - relative_path = str( - 
test_project_paths.JAVA_FILE.relative_to(test_project_paths.TEST_PROJECT_PATH).as_posix() - ) - node_type = "string_literal" - with neo4j_container.get_driver() as driver: - result = graph_traversal.find_ast_node_with_type_in_file_with_relative_path( - node_type, relative_path, driver, 1000 + neo4j_container, _ = neo4j_container_with_kg_fixture + + relative_path = str( + test_project_paths.JAVA_FILE.relative_to(test_project_paths.TEST_PROJECT_PATH).as_posix() ) + node_type = "string_literal" + with neo4j_container.get_driver() as driver: + result = graph_traversal.find_ast_node_with_type_in_file_with_relative_path( + node_type, relative_path, driver, 1000 + ) - result_data = result[1] - assert len(result_data) > 0 - for result_row in result_data: - assert "ASTNode" in result_row - assert result_row["ASTNode"].get("type", "") == node_type - assert "FileNode" in result_row - assert result_row["FileNode"].get("relative_path", "") == relative_path + result_data = result[1] + assert len(result_data) > 0 + for result_row in result_data: + assert "ASTNode" in result_row + assert result_row["ASTNode"].get("type", "") == node_type + assert "FileNode" in result_row + assert result_row["FileNode"].get("relative_path", "") == relative_path @pytest.mark.slow def test_find_text_node_with_text(neo4j_container_with_kg_fixture): # noqa: F811 - neo4j_container, _ = neo4j_container_with_kg_fixture + neo4j_container, _ = neo4j_container_with_kg_fixture - text = "Text under header C" - with neo4j_container.get_driver() as driver: - result = graph_traversal.find_text_node_with_text(text, driver, 1000) + text = "Text under header C" + with neo4j_container.get_driver() as driver: + result = graph_traversal.find_text_node_with_text(text, driver, 1000) - result_data = result[1] - assert len(result_data) > 0 - for result_row in result_data: - assert "TextNode" in result_row - assert text in result_row["TextNode"].get("text", "") - assert "FileNode" in result_row - assert 
result_row["FileNode"].get("relative_path", "") == "foo/test.md" + result_data = result[1] + assert len(result_data) > 0 + for result_row in result_data: + assert "TextNode" in result_row + assert text in result_row["TextNode"].get("text", "") + assert "FileNode" in result_row + assert result_row["FileNode"].get("relative_path", "") == "foo/test.md" @pytest.mark.slow def test_find_text_node_with_text_in_file(neo4j_container_with_kg_fixture): # noqa: F811 - neo4j_container, _ = neo4j_container_with_kg_fixture + neo4j_container, _ = neo4j_container_with_kg_fixture - basename = test_project_paths.MD_FILE.name - text = "Text under header B" - with neo4j_container.get_driver() as driver: - result = graph_traversal.find_text_node_with_text_in_file(text, basename, driver, 1000) + basename = test_project_paths.MD_FILE.name + text = "Text under header B" + with neo4j_container.get_driver() as driver: + result = graph_traversal.find_text_node_with_text_in_file(text, basename, driver, 1000) - result_data = result[1] - assert len(result_data) > 0 - for result_row in result_data: - assert "TextNode" in result_row - assert text in result_row["TextNode"].get("text", "") - assert "FileNode" in result_row - assert result_row["FileNode"].get("basename", "") == basename + result_data = result[1] + assert len(result_data) > 0 + for result_row in result_data: + assert "TextNode" in result_row + assert text in result_row["TextNode"].get("text", "") + assert "FileNode" in result_row + assert result_row["FileNode"].get("basename", "") == basename @pytest.mark.slow def test_get_next_text_node_with_node_id(neo4j_container_with_kg_fixture): # noqa: F811 - neo4j_container, _ = neo4j_container_with_kg_fixture + neo4j_container, _ = neo4j_container_with_kg_fixture - node_id = 36 - with neo4j_container.get_driver() as driver: - result = graph_traversal.get_next_text_node_with_node_id(node_id, driver, 1000) + node_id = 36 + with neo4j_container.get_driver() as driver: + result = 
graph_traversal.get_next_text_node_with_node_id(node_id, driver, 1000) - result_data = result[1] - assert len(result_data) > 0 - for result_row in result_data: - assert "TextNode" in result_row - assert "Text under header D" in result_row["TextNode"].get("text", "") - assert "FileNode" in result_row - assert result_row["FileNode"].get("relative_path", "") == "foo/test.md" + result_data = result[1] + assert len(result_data) > 0 + for result_row in result_data: + assert "TextNode" in result_row + assert "Text under header D" in result_row["TextNode"].get("text", "") + assert "FileNode" in result_row + assert result_row["FileNode"].get("relative_path", "") == "foo/test.md" @pytest.mark.slow def test_preview_file_content_with_basename(neo4j_container_with_kg_fixture): # noqa: F811 - neo4j_container, _ = neo4j_container_with_kg_fixture + neo4j_container, _ = neo4j_container_with_kg_fixture - basename = test_project_paths.PYTHON_FILE.name - with neo4j_container.get_driver() as driver: - result = graph_traversal.preview_file_content_with_basename(basename, driver, 1000) + basename = test_project_paths.PYTHON_FILE.name + with neo4j_container.get_driver() as driver: + result = graph_traversal.preview_file_content_with_basename(basename, driver, 1000) - result_data = result[1] - assert len(result_data) > 0 - for result_row in result_data: - assert "preview" in result_row - assert 'print("Hello world!")' in result_row["preview"].get("text", "") - assert "FileNode" in result_row - assert result_row["FileNode"].get("basename", "") == basename + result_data = result[1] + assert len(result_data) > 0 + for result_row in result_data: + assert "preview" in result_row + assert 'print("Hello world!")' in result_row["preview"].get("text", "") + assert "FileNode" in result_row + assert result_row["FileNode"].get("basename", "") == basename - basename = test_project_paths.MD_FILE.name - with neo4j_container.get_driver() as driver: - result = 
graph_traversal.preview_file_content_with_basename(basename, driver, 1000) + basename = test_project_paths.MD_FILE.name + with neo4j_container.get_driver() as driver: + result = graph_traversal.preview_file_content_with_basename(basename, driver, 1000) - result_data = result[1] - assert len(result_data) > 0 - for result_row in result_data: - assert "preview" in result_row - assert "Text under header A" in result_row["preview"] - assert "FileNode" in result_row - assert result_row["FileNode"].get("basename", "") == basename + result_data = result[1] + assert len(result_data) > 0 + for result_row in result_data: + assert "preview" in result_row + assert "Text under header A" in result_row["preview"] + assert "FileNode" in result_row + assert result_row["FileNode"].get("basename", "") == basename @pytest.mark.slow def test_preview_file_content_with_relative_path(neo4j_container_with_kg_fixture): # noqa: F811 - neo4j_container, _ = neo4j_container_with_kg_fixture - - relative_path = str( - test_project_paths.PYTHON_FILE.relative_to(test_project_paths.TEST_PROJECT_PATH).as_posix() - ) - with neo4j_container.get_driver() as driver: - result = graph_traversal.preview_file_content_with_relative_path(relative_path, driver, 1000) - - result_data = result[1] - assert len(result_data) > 0 - for result_row in result_data: - assert "preview" in result_row - assert 'print("Hello world!")' in result_row["preview"].get("text", "") - assert "FileNode" in result_row - assert result_row["FileNode"].get("relative_path", "") == relative_path - - relative_path = str( - test_project_paths.MD_FILE.relative_to(test_project_paths.TEST_PROJECT_PATH).as_posix() - ) - with neo4j_container.get_driver() as driver: - result = graph_traversal.preview_file_content_with_relative_path(relative_path, driver, 1000) - - result_data = result[1] - assert len(result_data) > 0 - for result_row in result_data: - assert "preview" in result_row - assert "Text under header A" in result_row["preview"] - assert 
"FileNode" in result_row - assert result_row["FileNode"].get("relative_path", "") == relative_path + neo4j_container, _ = neo4j_container_with_kg_fixture + + relative_path = str( + test_project_paths.PYTHON_FILE.relative_to(test_project_paths.TEST_PROJECT_PATH).as_posix() + ) + with neo4j_container.get_driver() as driver: + result = graph_traversal.preview_file_content_with_relative_path( + relative_path, driver, 1000 + ) + + result_data = result[1] + assert len(result_data) > 0 + for result_row in result_data: + assert "preview" in result_row + assert 'print("Hello world!")' in result_row["preview"].get("text", "") + assert "FileNode" in result_row + assert result_row["FileNode"].get("relative_path", "") == relative_path + + relative_path = str( + test_project_paths.MD_FILE.relative_to(test_project_paths.TEST_PROJECT_PATH).as_posix() + ) + with neo4j_container.get_driver() as driver: + result = graph_traversal.preview_file_content_with_relative_path( + relative_path, driver, 1000 + ) + + result_data = result[1] + assert len(result_data) > 0 + for result_row in result_data: + assert "preview" in result_row + assert "Text under header A" in result_row["preview"] + assert "FileNode" in result_row + assert result_row["FileNode"].get("relative_path", "") == relative_path @pytest.mark.slow def test_read_code_with_basename(neo4j_container_with_kg_fixture): # noqa: F811 - neo4j_container, _ = neo4j_container_with_kg_fixture + neo4j_container, _ = neo4j_container_with_kg_fixture - basename = test_project_paths.JAVA_FILE.name - with neo4j_container.get_driver() as driver: - result = graph_traversal.read_code_with_basename(basename, 2, 3, driver, 1000) + basename = test_project_paths.JAVA_FILE.name + with neo4j_container.get_driver() as driver: + result = graph_traversal.read_code_with_basename(basename, 2, 3, driver, 1000) - result_data = result[1] - assert len(result_data) > 0 - for result_row in result_data: - assert "SelectedLines" in result_row - assert "public static 
void main(String[] args) {" in result_row["SelectedLines"].get("text", "") - assert "FileNode" in result_row - assert result_row["FileNode"].get("basename", "") == basename + result_data = result[1] + assert len(result_data) > 0 + for result_row in result_data: + assert "SelectedLines" in result_row + assert "public static void main(String[] args) {" in result_row["SelectedLines"].get( + "text", "" + ) + assert "FileNode" in result_row + assert result_row["FileNode"].get("basename", "") == basename @pytest.mark.slow def test_read_code_with_relative_path(neo4j_container_with_kg_fixture): # noqa: F811 - neo4j_container, _ = neo4j_container_with_kg_fixture - - relative_path = str( - test_project_paths.C_FILE.relative_to(test_project_paths.TEST_PROJECT_PATH).as_posix() - ) - with neo4j_container.get_driver() as driver: - result = graph_traversal.read_code_with_relative_path(relative_path, 5, 6, driver, 1000) - - result_data = result[1] - assert len(result_data) > 0 - for result_row in result_data: - assert "SelectedLines" in result_row - assert "return 0;" in result_row["SelectedLines"].get("text", "") - assert "FileNode" in result_row - assert result_row["FileNode"].get("relative_path", "") == relative_path + neo4j_container, _ = neo4j_container_with_kg_fixture + + relative_path = str( + test_project_paths.C_FILE.relative_to(test_project_paths.TEST_PROJECT_PATH).as_posix() + ) + with neo4j_container.get_driver() as driver: + result = graph_traversal.read_code_with_relative_path(relative_path, 5, 6, driver, 1000) + + result_data = result[1] + assert len(result_data) > 0 + for result_row in result_data: + assert "SelectedLines" in result_row + assert "return 0;" in result_row["SelectedLines"].get("text", "") + assert "FileNode" in result_row + assert result_row["FileNode"].get("relative_path", "") == relative_path diff --git a/tests/utils/test_issue_util.py b/tests/utils/test_issue_util.py index 6d1c5d4f..3e975deb 100644 --- a/tests/utils/test_issue_util.py +++ 
b/tests/utils/test_issue_util.py @@ -1,32 +1,32 @@ from prometheus.utils.issue_util import ( - format_issue_comments, - format_issue_info, - format_test_commands, + format_issue_comments, + format_issue_info, + format_test_commands, ) def test_format_issue_comments(): - comments = [ - {"username": "alice", "comment": "This looks good!"}, - {"username": "bob", "comment": "Can we add tests?"}, - ] + comments = [ + {"username": "alice", "comment": "This looks good!"}, + {"username": "bob", "comment": "Can we add tests?"}, + ] - result = format_issue_comments(comments) - expected = "alice: This looks good!\n\nbob: Can we add tests?" + result = format_issue_comments(comments) + expected = "alice: This looks good!\n\nbob: Can we add tests?" - assert result == expected + assert result == expected def test_format_issue_info(): - title = "Bug in login flow" - body = "Users can't login on mobile devices" - comments = [ - {"username": "alice", "comment": "I can reproduce this"}, - {"username": "bob", "comment": "Working on a fix"}, - ] - - result = format_issue_info(title, body, comments) - expected = """\ + title = "Bug in login flow" + body = "Users can't login on mobile devices" + comments = [ + {"username": "alice", "comment": "I can reproduce this"}, + {"username": "bob", "comment": "Working on a fix"}, + ] + + result = format_issue_info(title, body, comments) + expected = """\ Issue title: Bug in login flow @@ -38,13 +38,13 @@ def test_format_issue_info(): bob: Working on a fix""" - assert result == expected + assert result == expected def test_format_test_commands(): - commands = ["pytest test_login.py", "pytest test_auth.py -v"] + commands = ["pytest test_login.py", "pytest test_auth.py -v"] - result = format_test_commands(commands) - expected = "$ pytest test_login.py\n$ pytest test_auth.py -v" + result = format_test_commands(commands) + expected = "$ pytest test_login.py\n$ pytest test_auth.py -v" - assert result == expected + assert result == expected diff --git 
a/tests/utils/test_lang_graph_util.py b/tests/utils/test_lang_graph_util.py index ccac1e1e..fe7d0a3e 100644 --- a/tests/utils/test_lang_graph_util.py +++ b/tests/utils/test_lang_graph_util.py @@ -1,148 +1,148 @@ from langchain_core.messages import ( - AIMessage, - HumanMessage, - SystemMessage, - ToolMessage, + AIMessage, + HumanMessage, + SystemMessage, + ToolMessage, ) from prometheus.utils.lang_graph_util import ( - check_remaining_steps, - extract_ai_responses, - extract_human_queries, - extract_last_tool_messages, - format_agent_tool_message_history, - get_last_message_content, - str_token_counter, - tiktoken_counter, - truncate_messages, + check_remaining_steps, + extract_ai_responses, + extract_human_queries, + extract_last_tool_messages, + format_agent_tool_message_history, + get_last_message_content, + str_token_counter, + tiktoken_counter, + truncate_messages, ) # Test check_remaining_steps def test_check_remaining_steps(): - def mock_router(state): - return "next_step" + def mock_router(state): + return "next_step" - state_enough_steps = {"remaining_steps": 5} - state_low_steps = {"remaining_steps": 2} + state_enough_steps = {"remaining_steps": 5} + state_low_steps = {"remaining_steps": 2} - assert check_remaining_steps(state_enough_steps, mock_router, 3) == "next_step" - assert check_remaining_steps(state_low_steps, mock_router, 3) == "low_remaining_steps" + assert check_remaining_steps(state_enough_steps, mock_router, 3) == "next_step" + assert check_remaining_steps(state_low_steps, mock_router, 3) == "low_remaining_steps" # Test str_token_counter def test_str_token_counter(): - text = "Hello, world!" - token_count = str_token_counter(text) - assert isinstance(token_count, int) - assert token_count > 0 + text = "Hello, world!" 
+ token_count = str_token_counter(text) + assert isinstance(token_count, int) + assert token_count > 0 # Test tiktoken_counter def test_tiktoken_counter(): - messages = [ - SystemMessage(content="System message"), - HumanMessage(content="Human message"), - AIMessage(content="AI response"), - ToolMessage(content="Tool message", tool_call_id="call_1"), - ] + messages = [ + SystemMessage(content="System message"), + HumanMessage(content="Human message"), + AIMessage(content="AI response"), + ToolMessage(content="Tool message", tool_call_id="call_1"), + ] - token_count = tiktoken_counter(messages) - assert isinstance(token_count, int) - assert token_count > 0 + token_count = tiktoken_counter(messages) + assert isinstance(token_count, int) + assert token_count > 0 # Test truncate_messages def test_truncate_messages(): - messages = [ - SystemMessage(content="System message"), - HumanMessage(content="Human message 1"), - AIMessage(content="AI response 1"), - HumanMessage(content="Human message 2"), - AIMessage(content="AI response 2"), - ] + messages = [ + SystemMessage(content="System message"), + HumanMessage(content="Human message 1"), + AIMessage(content="AI response 1"), + HumanMessage(content="Human message 2"), + AIMessage(content="AI response 2"), + ] - truncated = truncate_messages(messages, max_tokens=1000) - assert isinstance(truncated, (list, tuple)) - assert len(truncated) <= len(messages) + truncated = truncate_messages(messages, max_tokens=1000) + assert isinstance(truncated, (list, tuple)) + assert len(truncated) <= len(messages) # Test extract_ai_responses def test_extract_ai_responses(): - messages = [ - HumanMessage(content="Human 1"), - AIMessage(content="AI 1"), - HumanMessage(content="Human 2"), - AIMessage(content="AI 2"), - ] + messages = [ + HumanMessage(content="Human 1"), + AIMessage(content="AI 1"), + HumanMessage(content="Human 2"), + AIMessage(content="AI 2"), + ] - responses = extract_ai_responses(messages) - assert len(responses) == 2 - 
assert "AI 1" in responses - assert "AI 2" in responses + responses = extract_ai_responses(messages) + assert len(responses) == 2 + assert "AI 1" in responses + assert "AI 2" in responses # Test extract_human_queries def test_extract_human_queries(): - messages = [ - SystemMessage(content="System"), - HumanMessage(content="Human 1"), - AIMessage(content="AI 1"), - HumanMessage(content="Human 2"), - ] + messages = [ + SystemMessage(content="System"), + HumanMessage(content="Human 1"), + AIMessage(content="AI 1"), + HumanMessage(content="Human 2"), + ] - queries = extract_human_queries(messages) - assert len(queries) == 2 - assert "Human 1" in queries - assert "Human 2" in queries + queries = extract_human_queries(messages) + assert len(queries) == 2 + assert "Human 1" in queries + assert "Human 2" in queries # Test extract_last_tool_messages def test_extract_last_tool_messages(): - messages = [ - HumanMessage(content="Human 1"), - ToolMessage(content="Tool 1", tool_call_id="call_1"), - AIMessage(content="AI 1"), - HumanMessage(content="Human 2"), - ToolMessage(content="Tool 2", tool_call_id="call_2"), - ToolMessage(content="Tool 3", tool_call_id="call_3"), - ] + messages = [ + HumanMessage(content="Human 1"), + ToolMessage(content="Tool 1", tool_call_id="call_1"), + AIMessage(content="AI 1"), + HumanMessage(content="Human 2"), + ToolMessage(content="Tool 2", tool_call_id="call_2"), + ToolMessage(content="Tool 3", tool_call_id="call_3"), + ] - tool_messages = extract_last_tool_messages(messages) - assert len(tool_messages) == 2 - assert all(isinstance(msg, ToolMessage) for msg in tool_messages) - assert tool_messages[-1].content == "Tool 3" + tool_messages = extract_last_tool_messages(messages) + assert len(tool_messages) == 2 + assert all(isinstance(msg, ToolMessage) for msg in tool_messages) + assert tool_messages[-1].content == "Tool 3" # Test get_last_message_content def test_get_last_message_content(): - messages = [ - HumanMessage(content="Human"), - 
AIMessage(content="AI"), - ToolMessage(content="Last message", tool_call_id="call_1"), - ] + messages = [ + HumanMessage(content="Human"), + AIMessage(content="AI"), + ToolMessage(content="Last message", tool_call_id="call_1"), + ] - content = get_last_message_content(messages) - assert content == "Last message" + content = get_last_message_content(messages) + assert content == "Last message" def test_format_agent_tool_message_history(): - messages = [ - AIMessage(content="Let me analyze this"), - AIMessage( - content="I'll use a tool for this", - additional_kwargs={"tool_calls": [{"function": "analyze_data"}]}, - ), - ToolMessage(content="Analysis results: Success", tool_call_id="call_1"), - ] - - result = format_agent_tool_message_history(messages) - - expected = ( - "Assistant internal thought: Let me analyze this\n\n" - "Assistant internal thought: I'll use a tool for this\n\n" - "Assistant executed tool: analyze_data\n\n" - "Tool output: Analysis results: Success" - ) - - assert result == expected + messages = [ + AIMessage(content="Let me analyze this"), + AIMessage( + content="I'll use a tool for this", + additional_kwargs={"tool_calls": [{"function": "analyze_data"}]}, + ), + ToolMessage(content="Analysis results: Success", tool_call_id="call_1"), + ] + + result = format_agent_tool_message_history(messages) + + expected = ( + "Assistant internal thought: Let me analyze this\n\n" + "Assistant internal thought: I'll use a tool for this\n\n" + "Assistant executed tool: analyze_data\n\n" + "Tool output: Analysis results: Success" + ) + + assert result == expected diff --git a/tests/utils/test_neo4j_util.py b/tests/utils/test_neo4j_util.py index e0eb86b3..5aef7ee9 100644 --- a/tests/utils/test_neo4j_util.py +++ b/tests/utils/test_neo4j_util.py @@ -6,144 +6,142 @@ class MockResult: - def __init__(self, data_list): - self._data = data_list + def __init__(self, data_list): + self._data = data_list - def data(self): - return self._data + def data(self): + return 
self._data class MockSession: - def __init__(self): - self.execute_read = MagicMock() + def __init__(self): + self.execute_read = MagicMock() - def __enter__(self): - return self + def __enter__(self): + return self - def __exit__(self, exc_type, exc_val, exc_tb): - pass + def __exit__(self, exc_type, exc_val, exc_tb): + pass class MockDriver: - def __init__(self, session): - self._session = session + def __init__(self, session): + self._session = session - def session(self): - return self._session + def session(self): + return self._session @pytest.fixture def mock_neo4j_driver(): - session = MockSession() - driver = MockDriver(session) - return driver, session + session = MockSession() + driver = MockDriver(session) + return driver, session def test_format_neo4j_data_single_row(): - data = [{"name": "John", "age": 30}] + data = [{"name": "John", "age": 30}] - formatted = format_neo4j_data(data, 1000) - expected = "Result 1:\nage: 30\nname: John" + formatted = format_neo4j_data(data, 1000) + expected = "Result 1:\nage: 30\nname: John" - assert formatted == expected + assert formatted == expected def test_format_neo4j_result_multiple_rows(): - data = [{"name": "John", "age": 30}, {"name": "Jane", "age": 25}] + data = [{"name": "John", "age": 30}, {"name": "Jane", "age": 25}] - formatted = format_neo4j_data(data, 1000) - expected = "Result 1:\nage: 30\nname: John\n\n\nResult 2:\nage: 25\nname: Jane" + formatted = format_neo4j_data(data, 1000) + expected = "Result 1:\nage: 30\nname: John\n\n\nResult 2:\nage: 25\nname: Jane" - assert formatted == expected + assert formatted == expected def test_format_neo4j_result_empty(): - data = [] - formatted = format_neo4j_data(data, 1000) - assert formatted == EMPTY_DATA_MESSAGE + data = [] + formatted = format_neo4j_data(data, 1000) + assert formatted == EMPTY_DATA_MESSAGE def test_format_neo4j_result_different_keys(): - data = [{"name": "John", "age": 30}, {"city": "New York", "country": "USA"}] + data = [{"name": "John", 
"age": 30}, {"city": "New York", "country": "USA"}] - formatted = format_neo4j_data(data, 1000) - expected = "Result 1:\nage: 30\nname: John\n\n\nResult 2:\ncity: New York\ncountry: USA" + formatted = format_neo4j_data(data, 1000) + expected = "Result 1:\nage: 30\nname: John\n\n\nResult 2:\ncity: New York\ncountry: USA" - assert formatted == expected + assert formatted == expected def test_format_neo4j_result_complex_values(): - data = [ - {"numbers": [1, 2, 3], "metadata": {"type": "user", "active": True}, "date": "2024-01-01"} - ] + data = [ + {"numbers": [1, 2, 3], "metadata": {"type": "user", "active": True}, "date": "2024-01-01"} + ] - formatted = format_neo4j_data(data, 1000) - expected = ( - "Result 1:\ndate: 2024-01-01\nmetadata: {'type': 'user', 'active': True}\nnumbers: [1, 2, 3]" - ) + formatted = format_neo4j_data(data, 1000) + expected = "Result 1:\ndate: 2024-01-01\nmetadata: {'type': 'user', 'active': True}\nnumbers: [1, 2, 3]" - assert formatted == expected + assert formatted == expected def test_run_neo4j_query_success(mock_neo4j_driver): - driver, session = mock_neo4j_driver + driver, session = mock_neo4j_driver - # Create the expected formatted result - test_data = [{"name": "John", "age": 30}] + # Create the expected formatted result + test_data = [{"name": "John", "age": 30}] - # Mock the transaction execution to return the formatted result - def execute_read_side_effect(query_func): - # Create a mock transaction - tx = Mock() - # Set up the mock result - tx.run.return_value = MockResult(test_data) - # Execute the query function with our mock transaction - return query_func(tx) + # Mock the transaction execution to return the formatted result + def execute_read_side_effect(query_func): + # Create a mock transaction + tx = Mock() + # Set up the mock result + tx.run.return_value = MockResult(test_data) + # Execute the query function with our mock transaction + return query_func(tx) - # Set up the session mock to use our side effect - 
session.execute_read.side_effect = execute_read_side_effect + # Set up the session mock to use our side effect + session.execute_read.side_effect = execute_read_side_effect - # Run the query - query = "MATCH (n:Person) RETURN n.name as name, n.age as age" - result = run_neo4j_query(query, driver, 1000) + # Run the query + query = "MATCH (n:Person) RETURN n.name as name, n.age as age" + result = run_neo4j_query(query, driver, 1000) - expected_str = "Result 1:\nage: 30\nname: John" - # Verify results - assert result[0] == expected_str - assert result[1] == test_data + expected_str = "Result 1:\nage: 30\nname: John" + # Verify results + assert result[0] == expected_str + assert result[1] == test_data - # Verify the session was used correctly - session.execute_read.assert_called_once() + # Verify the session was used correctly + session.execute_read.assert_called_once() def test_run_neo4j_query_empty_result(mock_neo4j_driver): - driver, session = mock_neo4j_driver + driver, session = mock_neo4j_driver - def execute_read_side_effect(func): - mock_tx = Mock() - mock_tx.run.return_value = MockResult([]) - return func(mock_tx) + def execute_read_side_effect(func): + mock_tx = Mock() + mock_tx.run.return_value = MockResult([]) + return func(mock_tx) - session.execute_read.side_effect = execute_read_side_effect + session.execute_read.side_effect = execute_read_side_effect - result = run_neo4j_query("MATCH (n:NonExistent) RETURN n", driver, 1000) - assert result[0] == EMPTY_DATA_MESSAGE - assert result[1] == [] + result = run_neo4j_query("MATCH (n:NonExistent) RETURN n", driver, 1000) + assert result[0] == EMPTY_DATA_MESSAGE + assert result[1] == [] def test_run_neo4j_query_multiple_results(mock_neo4j_driver): - driver, session = mock_neo4j_driver + driver, session = mock_neo4j_driver - test_data = [{"name": "John", "age": 30}, {"name": "Jane", "age": 25}] + test_data = [{"name": "John", "age": 30}, {"name": "Jane", "age": 25}] - def execute_read_side_effect(func): - mock_tx 
= Mock() - mock_tx.run.return_value = MockResult(test_data) - return func(mock_tx) + def execute_read_side_effect(func): + mock_tx = Mock() + mock_tx.run.return_value = MockResult(test_data) + return func(mock_tx) - session.execute_read.side_effect = execute_read_side_effect + session.execute_read.side_effect = execute_read_side_effect - result = run_neo4j_query("MATCH (n:Person) RETURN n.name as name, n.age as age", driver, 1000) - expected_str = "Result 1:\nage: 30\nname: John\n\n\nResult 2:\nage: 25\nname: Jane" - assert result[0] == expected_str - assert result[1] == test_data + result = run_neo4j_query("MATCH (n:Person) RETURN n.name as name, n.age as age", driver, 1000) + expected_str = "Result 1:\nage: 30\nname: John\n\n\nResult 2:\nage: 25\nname: Jane" + assert result[0] == expected_str + assert result[1] == test_data diff --git a/tests/utils/test_patch_util.py b/tests/utils/test_patch_util.py index e54ae5f1..08b313aa 100644 --- a/tests/utils/test_patch_util.py +++ b/tests/utils/test_patch_util.py @@ -4,15 +4,15 @@ def test_get_updated_files_empty_diff(): - diff = "" - added, modified, removed = get_updated_files(diff) - assert len(added) == 0 - assert len(modified) == 0 - assert len(removed) == 0 + diff = "" + added, modified, removed = get_updated_files(diff) + assert len(added) == 0 + assert len(modified) == 0 + assert len(removed) == 0 def test_get_updated_files_added_only(): - diff = """ + diff = """ diff --git a/new_file.txt b/new_file.txt new file mode 100644 index 0000000..1234567 @@ -21,15 +21,15 @@ def test_get_updated_files_added_only(): @@ -0,0 +1 @@ +New content """ - added, modified, removed = get_updated_files(diff) - assert len(added) == 1 - assert len(modified) == 0 - assert len(removed) == 0 - assert added[0] == Path("new_file.txt") + added, modified, removed = get_updated_files(diff) + assert len(added) == 1 + assert len(modified) == 0 + assert len(removed) == 0 + assert added[0] == Path("new_file.txt") def 
test_get_updated_files_modified_only(): - diff = """ + diff = """ diff --git a/modified_file.txt b/modified_file.txt index 1234567..89abcdef --- a/modified_file.txt @@ -38,15 +38,15 @@ def test_get_updated_files_modified_only(): -Old content +Modified content """ - added, modified, removed = get_updated_files(diff) - assert len(added) == 0 - assert len(modified) == 1 - assert len(removed) == 0 - assert modified[0] == Path("modified_file.txt") + added, modified, removed = get_updated_files(diff) + assert len(added) == 0 + assert len(modified) == 1 + assert len(removed) == 0 + assert modified[0] == Path("modified_file.txt") def test_get_updated_files_removed_only(): - diff = """ + diff = """ diff --git a/removed_file.txt b/removed_file.txt deleted file mode 100644 index 1234567..0000000 @@ -55,15 +55,15 @@ def test_get_updated_files_removed_only(): @@ -1 +0,0 @@ -Content to be removed """ - added, modified, removed = get_updated_files(diff) - assert len(added) == 0 - assert len(modified) == 0 - assert len(removed) == 1 - assert removed[0] == Path("removed_file.txt") + added, modified, removed = get_updated_files(diff) + assert len(added) == 0 + assert len(modified) == 0 + assert len(removed) == 1 + assert removed[0] == Path("removed_file.txt") def test_get_updated_files_multiple_changes(): - diff = """ + diff = """ diff --git a/new_file.txt b/new_file.txt new file mode 100644 index 0000000..1234567 @@ -86,17 +86,17 @@ def test_get_updated_files_multiple_changes(): @@ -1 +0,0 @@ -Content to be removed """ - added, modified, removed = get_updated_files(diff) - assert len(added) == 1 - assert len(modified) == 1 - assert len(removed) == 1 - assert added[0] == Path("new_file.txt") - assert modified[0] == Path("modified_file.txt") - assert removed[0] == Path("removed_file.txt") + added, modified, removed = get_updated_files(diff) + assert len(added) == 1 + assert len(modified) == 1 + assert len(removed) == 1 + assert added[0] == Path("new_file.txt") + assert modified[0] == 
Path("modified_file.txt") + assert removed[0] == Path("removed_file.txt") def test_get_updated_files_with_subfolders(): - diff = """ + diff = """ diff --git a/folder1/new_file.txt b/folder1/new_file.txt new file mode 100644 index 0000000..1234567 @@ -112,9 +112,9 @@ def test_get_updated_files_with_subfolders(): -Old content +Modified content """ - added, modified, removed = get_updated_files(diff) - assert len(added) == 1 - assert len(modified) == 1 - assert len(removed) == 0 - assert added[0] == Path("folder1/new_file.txt") - assert modified[0] == Path("folder2/subfolder/modified_file.txt") + added, modified, removed = get_updated_files(diff) + assert len(added) == 1 + assert len(modified) == 1 + assert len(removed) == 0 + assert added[0] == Path("folder1/new_file.txt") + assert modified[0] == Path("folder2/subfolder/modified_file.txt") diff --git a/tests/utils/test_str_util.py b/tests/utils/test_str_util.py index fb5e25f7..2ebfb264 100644 --- a/tests/utils/test_str_util.py +++ b/tests/utils/test_str_util.py @@ -1,55 +1,55 @@ from prometheus.utils.str_util import ( - TRUNCATED_TEXT, - get_tokenizer, - pre_append_line_numbers, - truncate_text, + TRUNCATED_TEXT, + get_tokenizer, + pre_append_line_numbers, + truncate_text, ) def test_single_line(): - text = "Hello world" - result = pre_append_line_numbers(text, start_line=1) - assert result == "1. Hello world" + text = "Hello world" + result = pre_append_line_numbers(text, start_line=1) + assert result == "1. Hello world" def test_multiple_lines(): - text = "First line\nSecond line\nThird line" - result = pre_append_line_numbers(text, start_line=1) - assert result == "1. First line\n2. Second line\n3. Third line" + text = "First line\nSecond line\nThird line" + result = pre_append_line_numbers(text, start_line=1) + assert result == "1. First line\n2. Second line\n3. 
Third line" def test_empty_string(): - text = "" - result = pre_append_line_numbers(text, start_line=1) - assert result == "" + text = "" + result = pre_append_line_numbers(text, start_line=1) + assert result == "" def test_no_truncation_needed(): - text = "Short text" - result = truncate_text(text, max_token=100) - assert result == text + text = "Short text" + result = truncate_text(text, max_token=100) + assert result == text def test_truncation(): - # Generate text that will definitely need truncation - long_text = "Hello world! " * 50 - max_tokens = 50 - result = truncate_text(long_text, max_tokens) + # Generate text that will definitely need truncation + long_text = "Hello world! " * 50 + max_tokens = 50 + result = truncate_text(long_text, max_tokens) - # Verify length - assert len(get_tokenizer().encode(result)) <= max_tokens - # Verify warning message - assert result.endswith(TRUNCATED_TEXT) - # Verify some original content remains - assert result.startswith("Hello world!") + # Verify length + assert len(get_tokenizer().encode(result)) <= max_tokens + # Verify warning message + assert result.endswith(TRUNCATED_TEXT) + # Verify some original content remains + assert result.startswith("Hello world!") def test_truncation_with_empty_string(): - result = truncate_text("", max_token=10) - assert result == "" + result = truncate_text("", max_token=10) + assert result == "" def test_truncation_with_unicode_handling(): - text = "Hello ๐Ÿ‘‹ World ๐ŸŒ" - result = truncate_text(text, max_token=100) - assert result == text + text = "Hello ๐Ÿ‘‹ World ๐ŸŒ" + result = truncate_text(text, max_token=100) + assert result == text