Skip to content

Commit 1e38b73

Browse files
committed
refactor: params to plan prompt
1 parent a53d47f commit 1e38b73

File tree

1 file changed

+53
-15
lines changed

1 file changed

+53
-15
lines changed

src/rai_core/rai/agents/langchain/core/megamind.py

Lines changed: 53 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -209,27 +209,64 @@ def get_initial_megamind_state(task: str):
209209
)
210210

211211

212+
@dataclass
213+
class PlanPrompts:
214+
"""Configurable prompts for the planning step."""
215+
216+
objective_template: str = "You are given objective to complete: {original_task}"
217+
steps_done_header: str = "Steps that were already done successfully:\n"
218+
next_step_prompt: str = "\nBased on that outcome and past steps come up with the next step and delegate it to selected agent."
219+
first_step_prompt: str = (
220+
"\nCome up with the first step and delegate it to selected agent."
221+
)
222+
completion_prompt: str = (
223+
"\n\nWhen you decide that the objective is completed return response to user."
224+
)
225+
226+
@classmethod
227+
def default(cls):
228+
"""Return default prompts."""
229+
return cls()
230+
231+
@classmethod
232+
def custom(cls, **kwargs):
233+
"""Create custom prompts with overrides."""
234+
default = cls.default()
235+
for key, value in kwargs.items():
236+
if hasattr(default, key):
237+
setattr(default, key, value)
238+
return default
239+
240+
212241
def plan_step(
213242
megamind_agent: BaseChatModel,
214243
state: MegamindState,
244+
prompts: Optional[PlanPrompts] = None,
215245
context_providers: Optional[List[ContextProvider]] = None,
216246
) -> MegamindState:
217247
"""Initial planning step."""
248+
if prompts is None:
249+
prompts = PlanPrompts.default()
250+
218251
if "original_task" not in state:
219252
state["original_task"] = state["messages"][0].content[0]["text"]
220253
if "steps_done" not in state:
221254
state["steps_done"] = []
222255
if "step" not in state:
223256
state["step"] = None
224257

225-
megamind_prompt = f"You are given objective to complete: {state['original_task']}"
258+
megamind_prompt = prompts.objective_template.format(
259+
original_task=state["original_task"]
260+
)
226261
if context_providers:
227262
for provider in context_providers:
228263
megamind_prompt += provider.get_context()
229264
megamind_prompt += "\n"
265+
266+
# Add completed steps if any
230267
if state["steps_done"]:
231268
megamind_prompt += "\n\n"
232-
megamind_prompt += "Steps that were already done successfully:\n"
269+
megamind_prompt += prompts.steps_done_header
233270
steps_done = "\n".join(
234271
[f"{i + 1}. {step}" for i, step in enumerate(state["steps_done"])]
235272
)
@@ -239,22 +276,17 @@ def plan_step(
239276
if state["step"]:
240277
if not state["step_success"]:
241278
raise ValueError("Step success should be specified at this point")
242-
243-
megamind_prompt += "\nBased on that outcome and past steps come up with the next step and delegate it to selected agent."
279+
megamind_prompt += prompts.next_step_prompt
244280

245281
else:
246-
megamind_prompt += "\n"
247-
megamind_prompt += (
248-
"Come up with the fist step and delegate it to selected agent."
249-
)
282+
megamind_prompt += prompts.first_step_prompt
283+
284+
megamind_prompt += prompts.completion_prompt
250285

251-
megamind_prompt += "\n\n"
252-
megamind_prompt += (
253-
"When you decide that the objective is completed return response to user."
254-
)
255286
messages = [
256287
HumanMultimodalMessage(content=megamind_prompt),
257288
]
289+
258290
# NOTE (jmatejcz) the response of megamind isnt appended to messages
259291
# as Command from handoff instantly transitions to next node
260292
megamind_agent.invoke({"messages": messages})
@@ -265,7 +297,8 @@ def create_megamind(
265297
megamind_llm: BaseChatModel,
266298
executors: List[Executor],
267299
megamind_system_prompt: Optional[str] = None,
268-
task_planning_prompt: Optional[str] = None,
300+
anylyzer_prompt: Optional[str] = None,
301+
plan_prompts: Optional[PlanPrompts] = None,
269302
context_providers: List[ContextProvider] = [],
270303
) -> CompiledStateGraph:
271304
"""Create a megamind langchain agent
@@ -292,7 +325,7 @@ def create_megamind(
292325
llm=executor.llm,
293326
tools=executor.tools,
294327
system_prompt=executor.system_prompt,
295-
planning_prompt=task_planning_prompt,
328+
planning_prompt=anylyzer_prompt,
296329
)
297330

298331
handoff_tools.append(
@@ -325,7 +358,12 @@ def create_megamind(
325358

326359
graph = StateGraph(MegamindState).add_node(
327360
"megamind",
328-
partial(plan_step, megamind_agent, context_providers=context_providers),
361+
partial(
362+
plan_step,
363+
megamind_agent,
364+
context_providers=context_providers,
365+
prompts=plan_prompts,
366+
),
329367
)
330368
for agent_name, agent in executor_agents.items():
331369
graph.add_node(agent_name, agent)

0 commit comments

Comments
 (0)