From 008b18c13f079520fa7854cbc1204e9ca8f92c39 Mon Sep 17 00:00:00 2001 From: "zhencheng.cato" Date: Mon, 10 Nov 2025 16:42:49 +0800 Subject: [PATCH] feat: add conditional_edge options --- graph/conditional_edge_test.go | 176 +++++++++++++++++++++++++++++++++ graph/graph.go | 69 ++++++++++++- 2 files changed, 242 insertions(+), 3 deletions(-) create mode 100644 graph/conditional_edge_test.go diff --git a/graph/conditional_edge_test.go b/graph/conditional_edge_test.go new file mode 100644 index 0000000..8f9705d --- /dev/null +++ b/graph/conditional_edge_test.go @@ -0,0 +1,176 @@ +package graph + +import ( + "context" + "testing" + + "github.com/tmc/langchaingo/llms" + "github.com/tmc/langchaingo/schema" +) + +func TestConditionalEdge(t *testing.T) { + tests := []struct { + name string + setupGraph func() *MessageGraph + input []llms.MessageContent + expectedOutput string + expectError bool + }{ + { + name: "simple conditional edge", + setupGraph: func() *MessageGraph { + g := NewMessageGraph() + + // Add a node that checks the input + g.AddNode("checker", func(ctx context.Context, state []llms.MessageContent) ([]llms.MessageContent, error) { + // Simple logic: if message contains "hello", go to hello_node, otherwise go to goodbye_node + if len(state) > 0 { + content := state[0].Parts[0].(llms.TextContent).Text + if content == "hello" { + return append(state, llms.TextParts(schema.ChatMessageTypeAI, "check: hello")), nil + } + } + return append(state, llms.TextParts(schema.ChatMessageTypeAI, "check: other")), nil + }) + + g.AddNode("hello_node", func(ctx context.Context, state []llms.MessageContent) ([]llms.MessageContent, error) { + return append(state, llms.TextParts(schema.ChatMessageTypeAI, "Hello there!")), nil + }) + + g.AddNode("goodbye_node", func(ctx context.Context, state []llms.MessageContent) ([]llms.MessageContent, error) { + return append(state, llms.TextParts(schema.ChatMessageTypeAI, "Goodbye!")), nil + }) + + // Add conditional edge from checker + g.AddConditionalEdge("checker", func(ctx context.Context, state []llms.MessageContent) (string, error) { + if len(state) > 1 { + content := state[1].Parts[0].(llms.TextContent).Text + if content == "check: hello" { + return "hello_node", nil + } + } + return "goodbye_node", nil + }, nil) + + // Add edges from hello_node and goodbye_node to END + g.AddEdge("hello_node", END) + g.AddEdge("goodbye_node", END) + + g.SetEntryPoint("checker") + return g + }, + input: []llms.MessageContent{llms.TextParts(schema.ChatMessageTypeHuman, "hello")}, + expectedOutput: "Hello there!", + expectError: false, + }, + { + name: "conditional edge with path map", + setupGraph: func() *MessageGraph { + g := NewMessageGraph() + + g.AddNode("decision", func(ctx context.Context, state []llms.MessageContent) ([]llms.MessageContent, error) { + if len(state) > 0 { + content := state[0].Parts[0].(llms.TextContent).Text + if content == "A" { + return append(state, llms.TextParts(schema.ChatMessageTypeAI, "option_A")), nil + } + return append(state, llms.TextParts(schema.ChatMessageTypeAI, "option_B")), nil + } + return state, nil + }) + + g.AddNode("path_a", func(ctx context.Context, state []llms.MessageContent) ([]llms.MessageContent, error) { + return append(state, llms.TextParts(schema.ChatMessageTypeAI, "You chose path A")), nil + }) + + g.AddNode("path_b", func(ctx context.Context, state []llms.MessageContent) ([]llms.MessageContent, error) { + return append(state, llms.TextParts(schema.ChatMessageTypeAI, "You chose path B")), nil + }) + + // Add conditional edge with path mapping + g.AddConditionalEdge("decision", func(ctx context.Context, state []llms.MessageContent) (string, error) { + if len(state) > 1 { + content := state[1].Parts[0].(llms.TextContent).Text + return content, nil // Returns "option_A" or "option_B" + } + return "option_B", nil + }, map[string]string{ + "option_A": "path_a", + "option_B": "path_b", + }) + + g.AddEdge("path_a", END) + g.AddEdge("path_b", END) + + g.SetEntryPoint("decision") + return g + }, + input: []llms.MessageContent{llms.TextParts(schema.ChatMessageTypeHuman, "A")}, + expectedOutput: "You chose path A", + expectError: false, + }, + { + name: "conditional edge returning END", + setupGraph: func() *MessageGraph { + g := NewMessageGraph() + + g.AddNode("start", func(ctx context.Context, state []llms.MessageContent) ([]llms.MessageContent, error) { + return append(state, llms.TextParts(schema.ChatMessageTypeAI, "start_processed")), nil + }) + + // Add conditional edge that can return END + g.AddConditionalEdge("start", func(ctx context.Context, state []llms.MessageContent) (string, error) { + return END, nil + }, nil) + + g.SetEntryPoint("start") + return g + }, + input: []llms.MessageContent{llms.TextParts(schema.ChatMessageTypeHuman, "test")}, + expectedOutput: "start_processed", + expectError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + g := tt.setupGraph() + runnable, err := g.Compile() + if err != nil { + t.Fatalf("Failed to compile graph: %v", err) + } + + result, err := runnable.Invoke(context.Background(), tt.input) + if tt.expectError { + if err == nil { + t.Errorf("Expected error but got none") + } + return + } + + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + // Check if the expected output is in the result + found := false + for _, msg := range result { + for _, part := range msg.Parts { + if textPart, ok := part.(llms.TextContent); ok { + if textPart.Text == tt.expectedOutput { + found = true + break + } + } + } + if found { + break + } + } + + if !found { + t.Errorf("Expected output '%s' not found in result", tt.expectedOutput) + } + }) + } +} diff --git a/graph/graph.go b/graph/graph.go index 421b69e..0904cd3 100644 --- a/graph/graph.go +++ b/graph/graph.go @@ -41,6 +41,18 @@ type Edge struct { To string } +// ConditionalEdge represents a conditional edge in the message graph. +type ConditionalEdge struct { + // From is the name of the node from which the edge originates. + From string + + // Path is the function that determines the next node based on the current state. + Path func(ctx context.Context, state []llms.MessageContent) (string, error) + + // PathMap is an optional mapping of path results to node names. + PathMap map[string]string +} + // MessageGraph represents a message graph. type MessageGraph struct { // nodes is a map of node names to their corresponding Node objects. @@ -49,6 +61,9 @@ type MessageGraph struct { // edges is a slice of Edge objects representing the connections between nodes. edges []Edge + // conditionalEdges is a slice of ConditionalEdge objects representing conditional connections. + conditionalEdges []ConditionalEdge + // entryPoint is the name of the entry point node in the graph. entryPoint string } @@ -56,7 +71,9 @@ type MessageGraph struct { // NewMessageGraph creates a new instance of MessageGraph. func NewMessageGraph() *MessageGraph { return &MessageGraph{ - nodes: make(map[string]Node), + nodes: make(map[string]Node), + edges: []Edge{}, + conditionalEdges: []ConditionalEdge{}, } } @@ -76,6 +93,17 @@ func (g *MessageGraph) AddEdge(from, to string) { }) } +// AddConditionalEdge adds a conditional edge from the source node. +// The path function determines the next node based on the current state. +// If pathMap is provided, the result of path function will be mapped to actual node names. +func (g *MessageGraph) AddConditionalEdge(source string, path func(ctx context.Context, state []llms.MessageContent) (string, error), pathMap map[string]string) { + g.conditionalEdges = append(g.conditionalEdges, ConditionalEdge{ + From: source, + Path: path, + PathMap: pathMap, + }) +} + // SetEntryPoint sets the entry point node name for the message graph. func (g *MessageGraph) SetEntryPoint(name string) { g.entryPoint = name @@ -99,8 +127,6 @@ func (g *MessageGraph) Compile() (*Runnable, error) { }, nil } -// Invoke executes the compiled message graph with the given input messages. -// It returns the resulting messages and an error if any occurs during the execution. // Invoke executes the compiled message graph with the given input messages. // It returns the resulting messages and an error if any occurs during the execution. func (r *Runnable) Invoke(ctx context.Context, messages []llms.MessageContent) ([]llms.MessageContent, error) { @@ -123,6 +149,43 @@ func (r *Runnable) Invoke(ctx context.Context, messages []llms.MessageContent) ( return nil, fmt.Errorf("error in node %s: %w", currentNode, err) } + // First check for conditional edges + foundConditional := false + for _, condEdge := range r.graph.conditionalEdges { + if condEdge.From == currentNode { + // Execute the conditional path function + pathResult, err := condEdge.Path(ctx, state) + if err != nil { + return nil, fmt.Errorf("error in conditional edge from %s: %w", currentNode, err) + } + + // Apply path mapping if provided + nextNode := pathResult + if condEdge.PathMap != nil { + if mappedNode, ok := condEdge.PathMap[pathResult]; ok { + nextNode = mappedNode + } + } + + if nextNode == END { + currentNode = END + } else { + // Verify the target node exists + if _, ok := r.graph.nodes[nextNode]; !ok { + return nil, fmt.Errorf("%w: %s (target of conditional edge from %s)", ErrNodeNotFound, nextNode, currentNode) + } + currentNode = nextNode + } + foundConditional = true + break + } + } + + if foundConditional { + continue + } + + // If no conditional edge found, check for regular edges foundNext := false for _, edge := range r.graph.edges { if edge.From == currentNode {