forked from Y-Research-SBU/QuantAgent
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathgraph_setup.py
More file actions
105 lines (79 loc) · 3.39 KB
/
graph_setup.py
File metadata and controls
105 lines (79 loc) · 3.39 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
from typing import Dict, Any
from langchain_openai import ChatOpenAI
from langgraph.graph import END, StateGraph, START
from langgraph.prebuilt import ToolNode
from graph_util import TechnicalTools
import default_config
import os
from langchain.chat_models import init_chat_model
from langchain_core.messages import AnyMessage
from langchain_core.runnables import RunnableConfig
from langgraph.prebuilt.chat_agent_executor import AgentState
from langgraph.prebuilt import create_react_agent
from langgraph.graph import END, StateGraph, START
from typing import TypedDict, List
import random
from IPython.display import Image, display
import pandas as pd
from graph_util import *
# from langchain_community.chat_models import ChatOpenAI
from langchain_openai import ChatOpenAI
from langchain_core.prompts import ChatPromptTemplate
from langgraph.graph import StateGraph
from langgraph.prebuilt import ToolNode
from agent_state import IndicatorAgentState
from indicator_agent import *
from decision_agent import *
from pattern_agent import *
from trend_agent import *
class SetGraph:
    """Build the LangGraph workflow that chains the analysis agents.

    The graph produced by :meth:`set_graph` runs the indicator, pattern and
    trend agents one after another and then hands over to the decision maker.
    """

    def __init__(
        self,
        agent_llm: ChatOpenAI,
        graph_llm: ChatOpenAI,
        toolkit: TechnicalTools,
        tool_nodes: Dict[str, ToolNode],
    ):
        """Store the collaborators used when the graph is assembled.

        Args:
            agent_llm: Model handed to the pattern and trend agent factories.
            graph_llm: Model handed to every agent factory and to the final
                trade decider.
            toolkit: Tool collection forwarded to the agent factories.
            tool_nodes: Tool nodes keyed by agent type; ``set_graph`` reads
                the ``"indicator"``, ``"pattern"`` and ``"trend"`` entries.
        """
        self.agent_llm = agent_llm
        self.graph_llm = graph_llm
        self.toolkit = toolkit
        self.tool_nodes = tool_nodes

    def set_graph(self):
        """Assemble and compile the agent pipeline.

        The edges form a single linear chain::

            START -> Indicator Agent -> Pattern Agent -> Trend Agent
                  -> Decision Maker -> END

        Returns:
            Whatever ``StateGraph.compile()`` returns for the assembled graph.

        Raises:
            KeyError: If ``self.tool_nodes`` has no entry for one of the
                agent types. Raised before any agent is constructed.
        """
        agent_order = ("indicator", "pattern", "trend")

        # Resolve every tool node up front so that a missing entry fails
        # immediately instead of after some agents have already been built.
        tool_nodes = {name: self.tool_nodes[name] for name in agent_order}

        # The indicator factory takes only the graph model; the pattern and
        # trend factories take both models.
        agent_nodes = {
            "indicator": create_indicator_agent(self.graph_llm, self.toolkit),
            "pattern": create_pattern_agent(
                self.agent_llm, self.graph_llm, self.toolkit
            ),
            "trend": create_trend_agent(
                self.agent_llm, self.graph_llm, self.toolkit
            ),
        }
        decision_agent_node = create_final_trade_decider(self.graph_llm)

        graph = StateGraph(IndicatorAgentState)

        # Register each agent node together with its tool node.
        # NOTE(review): no edge added in this method targets the
        # "<agent>_tools" nodes; confirm whether they are meant to be wired
        # in (e.g. via conditional edges) or are reached some other way.
        agent_labels = []
        for name in agent_order:
            label = f"{name.capitalize()} Agent"
            agent_labels.append(label)
            graph.add_node(label, agent_nodes[name])
            graph.add_node(f"{name}_tools", tool_nodes[name])
        graph.add_node("Decision Maker", decision_agent_node)

        # Connect consecutive stops of the linear pipeline.
        chain = [START, *agent_labels, "Decision Maker", END]
        for source, target in zip(chain, chain[1:]):
            graph.add_edge(source, target)

        return graph.compile()