-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathreact.py
More file actions
129 lines (110 loc) · 4.18 KB
/
react.py
File metadata and controls
129 lines (110 loc) · 4.18 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
"""Pattern: ReAct with the default Engine LLM path."""
from __future__ import annotations
import os
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any
from qitos import Action, AgentModule, Decision, StateSchema, ToolRegistry
from qitos.kit import (
CodingToolSet,
REACT_SYSTEM_PROMPT,
ReActTextParser,
format_action,
render_prompt,
)
from qitos.models import OpenAICompatibleModel
# Instruction handed to the agent, and the scratch directory it works in.
TASK = "Open buggy_module.py, fix add(a, b) so it returns a + b, then run verification."
WORKSPACE = Path("./playground/react_pattern")

# Model endpoint settings; each may be overridden through the environment.
MODEL_NAME = os.environ.get("QITOS_MODEL", "Qwen/Qwen3-8B")
MODEL_BASE_URL = os.environ.get("OPENAI_BASE_URL", "https://api.siliconflow.cn/v1/")

# Step budget used by default for both the state and the run call.
MAX_STEPS = 8
@dataclass
class ReactState(StateSchema):
    """Per-run state for the ReAct example agent.

    Adds to ``StateSchema`` a text log of the Thought/Action/Observation
    trajectory, the name of the file to repair, and the verification command
    that is shown to the model.
    """

    # Trajectory log lines; each instance gets its own fresh list.
    scratchpad: list[str] = field(default_factory=list)
    # Name of the file the agent is asked to fix.
    target_file: str = field(default="buggy_module.py")
    # Verification command presented to the model in every prompt.
    test_command: str = field(
        default='python -c "import buggy_module; assert buggy_module.add(20, 22) == 42"'
    )
def _squash(text: str) -> str:
    """Return *text* with all whitespace, quotes and backslashes removed.

    Lets a shell command be found inside a rendered action regardless of how
    the renderer quoted, escaped or spaced it.
    """
    return "".join(ch for ch in text if not ch.isspace() and ch not in "'\"\\")


def _returncode_is_zero(result: Any) -> bool:
    """Return True only if *result* is a dict whose ``returncode`` is 0.

    A missing ``returncode`` counts as failure. A ``None`` or non-numeric code
    also counts as failure instead of raising from ``int()``.
    """
    if not isinstance(result, dict):
        return False
    try:
        return int(result.get("returncode", 1)) == 0
    except (TypeError, ValueError):
        return False


class ReactAgent(AgentModule[ReactState, dict[str, Any], Action]):
    """ReAct agent that patches and verifies a file using the coding tool set.

    Each step the model sees the task, the verification command and the most
    recent trajectory lines. ``reduce`` appends Thought/Action/Observation
    entries and sets ``final_result`` once the verification command exits 0.
    """

    def __init__(self, llm: Any, workspace_root: str):
        """Register a trimmed-down coding tool set rooted at *workspace_root*.

        Args:
            llm: Model object passed straight through to ``AgentModule``.
            workspace_root: Directory the coding tools operate in.
        """
        registry = ToolRegistry()
        # Notebook, LSP, task and web tools are switched off; this example
        # only needs the basic file/command tooling.
        registry.include(
            CodingToolSet(
                workspace_root=workspace_root,
                include_notebook=False,
                enable_lsp=False,
                enable_tasks=False,
                enable_web=False,
                expose_modern_names=False,
            )
        )
        super().__init__(
            tool_registry=registry, llm=llm, model_parser=ReActTextParser()
        )

    def init_state(self, task: str, **kwargs: Any) -> ReactState:
        """Create the initial state; ``max_steps`` may be overridden via kwargs."""
        return ReactState(task=task, max_steps=int(kwargs.get("max_steps", MAX_STEPS)))

    def build_system_prompt(self, state: ReactState) -> str | None:
        """Render the stock ReAct system prompt with this agent's tool descriptions."""
        return render_prompt(
            REACT_SYSTEM_PROMPT,
            {"tool_schema": self.tool_registry.get_tool_descriptions()},
        )

    def prepare(self, state: ReactState) -> str:
        """Build the per-step user message from the current state."""
        lines = [
            f"Task: {state.task}",
            f"Target file: {state.target_file}",
            f"Verification command: {state.test_command}",
            f"Step: {state.current_step}/{state.max_steps}",
        ]
        if state.scratchpad:
            lines.append("Recent trajectory:")
            # Only the last 8 log lines are shown, keeping the prompt short.
            lines.extend(state.scratchpad[-8:])
        return "\n".join(lines)

    def reduce(
        self, state: ReactState, observation: dict[str, Any], decision: Decision[Action]
    ) -> ReactState:
        """Fold one step's decision and observation into *state*.

        Only the first action and the first action result are recorded.
        ``final_result`` is set only when that result has return code 0 AND the
        action ran ``state.test_command``. Checking the command keeps an
        unrelated command that exits 0 (e.g. a ``cat``) from being reported as
        a passed verification.

        Returns:
            The same (mutated) state object, with the scratchpad capped at 30 lines.
        """
        action_results = (
            observation.get("action_results", [])
            if isinstance(observation, dict)
            else []
        )
        first_action = None
        action_text = ""
        if decision.rationale:
            state.scratchpad.append(f"Thought: {decision.rationale}")
        if decision.actions:
            first_action = decision.actions[0]
            action_text = format_action(first_action)
            state.scratchpad.append(f"Action: {action_text}")
        if action_results:
            first = action_results[0]
            state.scratchpad.append(f"Observation: {first}")
            if (
                first_action is not None
                and _returncode_is_zero(first)
                and self._ran_verification(state, first_action, action_text)
            ):
                state.final_result = "Patch applied and verification passed."
        state.scratchpad = state.scratchpad[-30:]
        return state

    @staticmethod
    def _ran_verification(state: ReactState, action: Any, action_text: str) -> bool:
        """Return True if *action* appears to run ``state.test_command``.

        Searches the rendered action text plus the action's ``args`` attribute
        (when it has one), ignoring whitespace and quoting.
        NOTE(review): assumes the command string is carried verbatim in the
        action's arguments. A model that paraphrases the command (e.g.
        ``python3`` instead of ``python``) will not be recognised.
        """
        needle = _squash(state.test_command)
        haystack = _squash(f"{action_text} {getattr(action, 'args', '')}")
        return bool(needle) and needle in haystack
def build_model() -> OpenAICompatibleModel:
    """Create the OpenAI-compatible chat model used by this example.

    The API key is taken from ``OPENAI_API_KEY``, or from ``QITOS_API_KEY``
    when the former is unset or empty. Surrounding whitespace is stripped.

    Raises:
        ValueError: if the selected key is missing or blank.
    """
    candidates = (os.getenv("OPENAI_API_KEY"), os.getenv("QITOS_API_KEY"))
    # First non-empty candidate wins; strip afterwards, like the `or` chain.
    api_key = next((value for value in candidates if value), "").strip()
    if api_key:
        return OpenAICompatibleModel(
            model=MODEL_NAME,
            base_url=MODEL_BASE_URL,
            api_key=api_key,
            max_tokens=2048,
            temperature=0.2,
        )
    raise ValueError(
        "Set OPENAI_API_KEY or QITOS_API_KEY before running this example."
    )
def main() -> None:
    """Seed the demo workspace, run the ReAct agent and print the outcome."""
    WORKSPACE.mkdir(parents=True, exist_ok=True)

    # Seed the buggy module only if it is absent; an existing file is reused.
    buggy_file = WORKSPACE / "buggy_module.py"
    if not buggy_file.exists():
        buggy_file.write_text("def add(a, b):\n return a - b\n", encoding="utf-8")

    workspace_str = str(WORKSPACE)
    llm = build_model()
    agent = ReactAgent(llm=llm, workspace_root=workspace_str)
    run_kwargs = {
        "task": TASK,
        "workspace": workspace_str,
        "max_steps": MAX_STEPS,
        "return_state": True,
    }
    outcome = agent.run(**run_kwargs)

    print("workspace:", WORKSPACE)
    final_state = outcome.state
    print("final_result:", final_state.final_result)
    print("stop_reason:", final_state.stop_reason)
if __name__ == "__main__":
    # Run the demo only when executed as a script, never on import.
    main()