@@ -336,6 +336,29 @@ def _run(self, query: str, k: int = 5, preview_length: int = 200) -> str:
336336 return json .dumps (result , indent = 2 )
337337
338338
339+ class CreatePRInput (BaseModel ):
340+ """Input for creating a PR"""
341+
342+ title : str = Field (..., description = "The title of the PR" )
343+ body : str = Field (..., description = "The body of the PR" )
344+
345+
346+ class CreatePRTool (BaseTool ):
347+ """Tool for creating a PR."""
348+
349+ name : ClassVar [str ] = "create_pr"
350+ description : ClassVar [str ] = "Create a PR for the current branch"
351+ args_schema : ClassVar [type [BaseModel ]] = CreatePRInput
352+ codebase : Codebase = Field (exclude = True )
353+
354+ def __init__ (self , codebase : Codebase ) -> None :
355+ super ().__init__ (codebase = codebase )
356+
357+ def _run (self , title : str , body : str ) -> str :
358+ pr = self .codebase .create_pr (title = title , body = body )
359+ return pr .html_url
360+
361+
339362def get_workspace_tools (codebase : Codebase ) -> list ["BaseTool" ]:
340363 """Get all workspace tools initialized with a codebase.
341364
@@ -345,26 +368,15 @@ def get_workspace_tools(codebase: Codebase) -> list["BaseTool"]:
345368 Returns:
346369 List of initialized Langchain tools
347370 """
348- from .tools import (
349- CommitTool ,
350- CreateFileTool ,
351- DeleteFileTool ,
352- EditFileTool ,
353- ListDirectoryTool ,
354- RevealSymbolTool ,
355- SearchTool ,
356- SemanticEditTool ,
357- ViewFileTool ,
358- )
359-
360371 return [
361- ViewFileTool (codebase ),
362- ListDirectoryTool (codebase ),
363- SearchTool (codebase ),
364- EditFileTool (codebase ),
372+ CommitTool (codebase ),
365373 CreateFileTool (codebase ),
374+ CreatePRTool (codebase ),
366375 DeleteFileTool (codebase ),
367- CommitTool (codebase ),
376+ EditFileTool (codebase ),
377+ ListDirectoryTool (codebase ),
368378 RevealSymbolTool (codebase ),
379+ SearchTool (codebase ),
369380 SemanticEditTool (codebase ),
381+ ViewFileTool (codebase ),
370382 ]
0 commit comments