Skip to content

Commit 8706b68

Browse files
authored
fix: upgrades to dataclasses for tool observations (#561)
1 parent 7619895 commit 8706b68

File tree

23 files changed

+1367
-396
lines changed

23 files changed

+1367
-396
lines changed

src/codegen/extensions/langchain/tools.py

Lines changed: 24 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ def __init__(self, codebase: Codebase) -> None:
6161

6262
def _run(self, filepath: str) -> str:
6363
result = view_file(self.codebase, filepath)
64-
return json.dumps(result, indent=2)
64+
return result.render()
6565

6666

6767
class ListDirectoryInput(BaseModel):
@@ -84,7 +84,7 @@ def __init__(self, codebase: Codebase) -> None:
8484

8585
def _run(self, dirpath: str = "./", depth: int = 1) -> str:
8686
result = list_directory(self.codebase, dirpath, depth)
87-
return json.dumps(result, indent=2)
87+
return result.render()
8888

8989

9090
class SearchInput(BaseModel):
@@ -107,7 +107,7 @@ def __init__(self, codebase: Codebase) -> None:
107107

108108
def _run(self, query: str, target_directories: Optional[list[str]] = None) -> str:
109109
result = search(self.codebase, query, target_directories)
110-
return json.dumps(result, indent=2)
110+
return result.render()
111111

112112

113113
class EditFileInput(BaseModel):
@@ -130,7 +130,7 @@ def __init__(self, codebase: Codebase) -> None:
130130

131131
def _run(self, filepath: str, content: str) -> str:
132132
result = edit_file(self.codebase, filepath, content)
133-
return json.dumps(result, indent=2)
133+
return result.render()
134134

135135

136136
class CreateFileInput(BaseModel):
@@ -153,7 +153,7 @@ def __init__(self, codebase: Codebase) -> None:
153153

154154
def _run(self, filepath: str, content: str = "") -> str:
155155
result = create_file(self.codebase, filepath, content)
156-
return json.dumps(result, indent=2)
156+
return result.render()
157157

158158

159159
class DeleteFileInput(BaseModel):
@@ -175,7 +175,7 @@ def __init__(self, codebase: Codebase) -> None:
175175

176176
def _run(self, filepath: str) -> str:
177177
result = delete_file(self.codebase, filepath)
178-
return json.dumps(result, indent=2)
178+
return result.render()
179179

180180

181181
class CommitTool(BaseTool):
@@ -190,7 +190,7 @@ def __init__(self, codebase: Codebase) -> None:
190190

191191
def _run(self) -> str:
192192
result = commit(self.codebase)
193-
return json.dumps(result, indent=2)
193+
return result.render()
194194

195195

196196
class RevealSymbolInput(BaseModel):
@@ -233,7 +233,7 @@ def _run(
233233
collect_dependencies=collect_dependencies,
234234
collect_usages=collect_usages,
235235
)
236-
return json.dumps(result, indent=2)
236+
return result.render()
237237

238238

239239
_SEMANTIC_EDIT_BRIEF = """Tool for file editing via an LLM delegate. Describe the changes you want to make and an expert will apply them to the file.
@@ -278,7 +278,7 @@ def __init__(self, codebase: Codebase) -> None:
278278
def _run(self, filepath: str, edit_content: str, start: int = 1, end: int = -1) -> str:
279279
# Create the the draft editor mini llm
280280
result = semantic_edit(self.codebase, filepath, edit_content, start=start, end=end)
281-
return json.dumps(result, indent=2)
281+
return result.render()
282282

283283

284284
class RenameFileInput(BaseModel):
@@ -301,7 +301,7 @@ def __init__(self, codebase: Codebase) -> None:
301301

302302
def _run(self, filepath: str, new_filepath: str) -> str:
303303
result = rename_file(self.codebase, filepath, new_filepath)
304-
return json.dumps(result, indent=2)
304+
return result.render()
305305

306306

307307
class MoveSymbolInput(BaseModel):
@@ -344,7 +344,7 @@ def _run(
344344
strategy=strategy,
345345
include_dependencies=include_dependencies,
346346
)
347-
return json.dumps(result, indent=2)
347+
return result.render()
348348

349349

350350
class SemanticSearchInput(BaseModel):
@@ -368,7 +368,7 @@ def __init__(self, codebase: Codebase) -> None:
368368

369369
def _run(self, query: str, k: int = 5, preview_length: int = 200) -> str:
370370
result = semantic_search(self.codebase, query, k=k, preview_length=preview_length)
371-
return json.dumps(result, indent=2)
371+
return result.render()
372372

373373

374374
########################################################################################################################
@@ -392,7 +392,7 @@ class RunBashCommandTool(BaseTool):
392392

393393
def _run(self, command: str, is_background: bool = False) -> str:
394394
result = run_bash_command(command, is_background)
395-
return json.dumps(result, indent=2)
395+
return result.render()
396396

397397

398398
########################################################################################################################
@@ -420,7 +420,7 @@ def __init__(self, codebase: Codebase) -> None:
420420

421421
def _run(self, title: str, body: str) -> str:
422422
result = create_pr(self.codebase, title, body)
423-
return json.dumps(result, indent=2)
423+
return result.render()
424424

425425

426426
class GithubViewPRInput(BaseModel):
@@ -442,6 +442,7 @@ def __init__(self, codebase: Codebase) -> None:
442442

443443
def _run(self, pr_id: int) -> str:
444444
result = view_pr(self.codebase, pr_id)
445+
return result.render()
445446
return json.dumps(result, indent=2)
446447

447448

@@ -465,7 +466,7 @@ def __init__(self, codebase: Codebase) -> None:
465466

466467
def _run(self, pr_number: int, body: str) -> str:
467468
result = create_pr_comment(self.codebase, pr_number, body)
468-
return json.dumps(result, indent=2)
469+
return result.render()
469470

470471

471472
class GithubCreatePRReviewCommentInput(BaseModel):
@@ -511,7 +512,7 @@ def _run(
511512
side=side,
512513
start_line=start_line,
513514
)
514-
return json.dumps(result, indent=2)
515+
return result.render()
515516

516517

517518
########################################################################################################################
@@ -538,7 +539,7 @@ def __init__(self, client: LinearClient) -> None:
538539

539540
def _run(self, issue_id: str) -> str:
540541
result = linear_get_issue_tool(self.client, issue_id)
541-
return json.dumps(result, indent=2)
542+
return result.render()
542543

543544

544545
class LinearGetIssueCommentsInput(BaseModel):
@@ -560,7 +561,7 @@ def __init__(self, client: LinearClient) -> None:
560561

561562
def _run(self, issue_id: str) -> str:
562563
result = linear_get_issue_comments_tool(self.client, issue_id)
563-
return json.dumps(result, indent=2)
564+
return result.render()
564565

565566

566567
class LinearCommentOnIssueInput(BaseModel):
@@ -583,7 +584,7 @@ def __init__(self, client: LinearClient) -> None:
583584

584585
def _run(self, issue_id: str, body: str) -> str:
585586
result = linear_comment_on_issue_tool(self.client, issue_id, body)
586-
return json.dumps(result, indent=2)
587+
return result.render()
587588

588589

589590
class LinearSearchIssuesInput(BaseModel):
@@ -606,7 +607,7 @@ def __init__(self, client: LinearClient) -> None:
606607

607608
def _run(self, query: str, limit: int = 10) -> str:
608609
result = linear_search_issues_tool(self.client, query, limit)
609-
return json.dumps(result, indent=2)
610+
return result.render()
610611

611612

612613
class LinearCreateIssueInput(BaseModel):
@@ -630,7 +631,7 @@ def __init__(self, client: LinearClient) -> None:
630631

631632
def _run(self, title: str, description: str | None = None, team_id: str | None = None) -> str:
632633
result = linear_create_issue_tool(self.client, title, description, team_id)
633-
return json.dumps(result, indent=2)
634+
return result.render()
634635

635636

636637
class LinearGetTeamsTool(BaseTool):
@@ -645,7 +646,7 @@ def __init__(self, client: LinearClient) -> None:
645646

646647
def _run(self) -> str:
647648
result = linear_get_teams_tool(self.client)
648-
return json.dumps(result, indent=2)
649+
return result.render()
649650

650651

651652
########################################################################################################################
@@ -678,6 +679,7 @@ def __init__(self, codebase: Codebase, say: Callable[[str], None]) -> None:
678679
self.codebase = codebase
679680

680681
def _run(self, content: str) -> str:
682+
# TODO - pull this out into a separate function
681683
print("> Adding links to message")
682684
content_formatted = add_links_to_message(content, self.codebase)
683685
print("> Sending message to Slack")

src/codegen/extensions/tools/bash.py

Lines changed: 58 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,11 @@
33
import re
44
import shlex
55
import subprocess
6-
from typing import Any
6+
from typing import ClassVar, Optional
7+
8+
from pydantic import Field
9+
10+
from .observation import Observation
711

812
# Whitelist of allowed commands and their flags
913
ALLOWED_COMMANDS = {
@@ -22,6 +26,28 @@
2226
}
2327

2428

29+
class RunBashCommandObservation(Observation):
30+
"""Response from running a bash command."""
31+
32+
stdout: Optional[str] = Field(
33+
default=None,
34+
description="Standard output from the command",
35+
)
36+
stderr: Optional[str] = Field(
37+
default=None,
38+
description="Standard error from the command",
39+
)
40+
command: str = Field(
41+
description="The command that was executed",
42+
)
43+
pid: Optional[int] = Field(
44+
default=None,
45+
description="Process ID for background commands",
46+
)
47+
48+
str_template: ClassVar[str] = "Command '{command}' completed"
49+
50+
2551
def validate_command(command: str) -> tuple[bool, str]:
2652
"""Validate if a command is safe to execute.
2753
@@ -90,23 +116,24 @@ def validate_command(command: str) -> tuple[bool, str]:
90116
return False, f"Failed to validate command: {e!s}"
91117

92118

93-
def run_bash_command(command: str, is_background: bool = False) -> dict[str, Any]:
119+
def run_bash_command(command: str, is_background: bool = False) -> RunBashCommandObservation:
94120
"""Run a bash command and return its output.
95121
96122
Args:
97123
command: The command to run
98124
is_background: Whether to run the command in the background
99125
100126
Returns:
101-
Dictionary containing the command output or error
127+
RunBashCommandObservation containing the command output or error
102128
"""
103129
# First validate the command
104130
is_valid, error_message = validate_command(command)
105131
if not is_valid:
106-
return {
107-
"status": "error",
108-
"error": f"Invalid command: {error_message}",
109-
}
132+
return RunBashCommandObservation(
133+
status="error",
134+
error=f"Invalid command: {error_message}",
135+
command=command,
136+
)
110137

111138
try:
112139
if is_background:
@@ -118,10 +145,11 @@ def run_bash_command(command: str, is_background: bool = False) -> dict[str, Any
118145
stderr=subprocess.PIPE,
119146
text=True,
120147
)
121-
return {
122-
"status": "success",
123-
"message": f"Command '{command}' started in background with PID {process.pid}",
124-
}
148+
return RunBashCommandObservation(
149+
status="success",
150+
command=command,
151+
pid=process.pid,
152+
)
125153

126154
# For foreground processes, we wait for completion
127155
result = subprocess.run(
@@ -132,20 +160,24 @@ def run_bash_command(command: str, is_background: bool = False) -> dict[str, Any
132160
check=True, # This will raise CalledProcessError if command fails
133161
)
134162

135-
return {
136-
"status": "success",
137-
"stdout": result.stdout,
138-
"stderr": result.stderr,
139-
}
163+
return RunBashCommandObservation(
164+
status="success",
165+
command=command,
166+
stdout=result.stdout,
167+
stderr=result.stderr,
168+
)
169+
140170
except subprocess.CalledProcessError as e:
141-
return {
142-
"status": "error",
143-
"error": f"Command failed with exit code {e.returncode}",
144-
"stdout": e.stdout,
145-
"stderr": e.stderr,
146-
}
171+
return RunBashCommandObservation(
172+
status="error",
173+
error=f"Command failed with exit code {e.returncode}",
174+
command=command,
175+
stdout=e.stdout,
176+
stderr=e.stderr,
177+
)
147178
except Exception as e:
148-
return {
149-
"status": "error",
150-
"error": f"Failed to run command: {e!s}",
151-
}
179+
return RunBashCommandObservation(
180+
status="error",
181+
error=f"Failed to run command: {e!s}",
182+
command=command,
183+
)
Lines changed: 29 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,42 @@
11
"""Tool for committing changes to disk."""
22

3-
from typing import Any
3+
from typing import ClassVar
4+
5+
from pydantic import Field
46

57
from codegen import Codebase
68

9+
from .observation import Observation
10+
11+
12+
class CommitObservation(Observation):
13+
"""Response from committing changes to disk."""
14+
15+
message: str = Field(
16+
description="Message describing the commit result",
17+
)
18+
19+
str_template: ClassVar[str] = "{message}"
20+
721

8-
def commit(codebase: Codebase) -> dict[str, Any]:
22+
def commit(codebase: Codebase) -> CommitObservation:
923
"""Commit any pending changes to disk.
1024
1125
Args:
1226
codebase: The codebase to operate on
1327
1428
Returns:
15-
Dict containing commit status
29+
CommitObservation containing commit status
1630
"""
17-
codebase.commit()
18-
return {"status": "success", "message": "Changes committed to disk"}
31+
try:
32+
codebase.commit()
33+
return CommitObservation(
34+
status="success",
35+
message="Changes committed to disk",
36+
)
37+
except Exception as e:
38+
return CommitObservation(
39+
status="error",
40+
error=f"Failed to commit changes: {e!s}",
41+
message="Failed to commit changes",
42+
)

0 commit comments

Comments
 (0)