Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
176 changes: 176 additions & 0 deletions graph/conditional_edge_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,176 @@
package graph

import (
"context"
"testing"

"github.com/tmc/langchaingo/llms"
"github.com/tmc/langchaingo/schema"
)

func TestConditionalEdge(t *testing.T) {
tests := []struct {
name string
setupGraph func() *MessageGraph
input []llms.MessageContent
expectedOutput string
expectError bool
}{
{
name: "simple conditional edge",
setupGraph: func() *MessageGraph {
g := NewMessageGraph()

// Add a node that checks the input
g.AddNode("checker", func(ctx context.Context, state []llms.MessageContent) ([]llms.MessageContent, error) {
// Simple logic: if message contains "hello", go to hello_node, otherwise go to goodbye_node
if len(state) > 0 {
content := state[0].Parts[0].(llms.TextContent).Text
if content == "hello" {
return append(state, llms.TextParts(schema.ChatMessageTypeAI, "check: hello")), nil
}
}
return append(state, llms.TextParts(schema.ChatMessageTypeAI, "check: other")), nil
})

g.AddNode("hello_node", func(ctx context.Context, state []llms.MessageContent) ([]llms.MessageContent, error) {
return append(state, llms.TextParts(schema.ChatMessageTypeAI, "Hello there!")), nil
})

g.AddNode("goodbye_node", func(ctx context.Context, state []llms.MessageContent) ([]llms.MessageContent, error) {
return append(state, llms.TextParts(schema.ChatMessageTypeAI, "Goodbye!")), nil
})

// Add conditional edge from checker
g.AddConditionalEdge("checker", func(ctx context.Context, state []llms.MessageContent) (string, error) {
if len(state) > 1 {
content := state[1].Parts[0].(llms.TextContent).Text
if content == "check: hello" {
return "hello_node", nil
}
}
return "goodbye_node", nil
}, nil)

// Add edges from hello_node and goodbye_node to END
g.AddEdge("hello_node", END)
g.AddEdge("goodbye_node", END)

g.SetEntryPoint("checker")
return g
},
input: []llms.MessageContent{llms.TextParts(schema.ChatMessageTypeHuman, "hello")},
expectedOutput: "Hello there!",
expectError: false,
},
{
name: "conditional edge with path map",
setupGraph: func() *MessageGraph {
g := NewMessageGraph()

g.AddNode("decision", func(ctx context.Context, state []llms.MessageContent) ([]llms.MessageContent, error) {
if len(state) > 0 {
content := state[0].Parts[0].(llms.TextContent).Text
if content == "A" {
return append(state, llms.TextParts(schema.ChatMessageTypeAI, "option_A")), nil
}
return append(state, llms.TextParts(schema.ChatMessageTypeAI, "option_B")), nil
}
return state, nil
})

g.AddNode("path_a", func(ctx context.Context, state []llms.MessageContent) ([]llms.MessageContent, error) {
return append(state, llms.TextParts(schema.ChatMessageTypeAI, "You chose path A")), nil
})

g.AddNode("path_b", func(ctx context.Context, state []llms.MessageContent) ([]llms.MessageContent, error) {
return append(state, llms.TextParts(schema.ChatMessageTypeAI, "You chose path B")), nil
})

// Add conditional edge with path mapping
g.AddConditionalEdge("decision", func(ctx context.Context, state []llms.MessageContent) (string, error) {
if len(state) > 1 {
content := state[1].Parts[0].(llms.TextContent).Text
return content, nil // Returns "option_A" or "option_B"
}
return "option_B", nil
}, map[string]string{
"option_A": "path_a",
"option_B": "path_b",
})

g.AddEdge("path_a", END)
g.AddEdge("path_b", END)

g.SetEntryPoint("decision")
return g
},
input: []llms.MessageContent{llms.TextParts(schema.ChatMessageTypeHuman, "A")},
expectedOutput: "You chose path A",
expectError: false,
},
{
name: "conditional edge returning END",
setupGraph: func() *MessageGraph {
g := NewMessageGraph()

g.AddNode("start", func(ctx context.Context, state []llms.MessageContent) ([]llms.MessageContent, error) {
return append(state, llms.TextParts(schema.ChatMessageTypeAI, "start_processed")), nil
})

// Add conditional edge that can return END
g.AddConditionalEdge("start", func(ctx context.Context, state []llms.MessageContent) (string, error) {
return END, nil
}, nil)

g.SetEntryPoint("start")
return g
},
input: []llms.MessageContent{llms.TextParts(schema.ChatMessageTypeHuman, "test")},
expectedOutput: "start_processed",
expectError: false,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
g := tt.setupGraph()
runnable, err := g.Compile()
if err != nil {
t.Fatalf("Failed to compile graph: %v", err)
}

result, err := runnable.Invoke(context.Background(), tt.input)
if tt.expectError {
if err == nil {
t.Errorf("Expected error but got none")
}
return
}

if err != nil {
t.Fatalf("Unexpected error: %v", err)
}

// Check if the expected output is in the result
found := false
for _, msg := range result {
for _, part := range msg.Parts {
if textPart, ok := part.(llms.TextContent); ok {
if textPart.Text == tt.expectedOutput {
found = true
break
}
}
}
if found {
break
}
}

if !found {
t.Errorf("Expected output '%s' not found in result", tt.expectedOutput)
}
})
}
}
69 changes: 66 additions & 3 deletions graph/graph.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,18 @@ type Edge struct {
To string
}

// ConditionalEdge represents a conditional edge in the message graph.
type ConditionalEdge struct {
// From is the name of the node from which the edge originates.
From string

// Path is the function that determines the next node based on the current state.
Path func(ctx context.Context, state []llms.MessageContent) (string, error)

// PathMap is an optional mapping of path results to node names.
PathMap map[string]string
}

// MessageGraph represents a message graph.
type MessageGraph struct {
// nodes is a map of node names to their corresponding Node objects.
Expand All @@ -49,14 +61,19 @@ type MessageGraph struct {
// edges is a slice of Edge objects representing the connections between nodes.
edges []Edge

// conditionalEdges is a slice of ConditionalEdge objects representing conditional connections.
conditionalEdges []ConditionalEdge

// entryPoint is the name of the entry point node in the graph.
entryPoint string
}

// NewMessageGraph creates a new instance of MessageGraph.
func NewMessageGraph() *MessageGraph {
return &MessageGraph{
nodes: make(map[string]Node),
nodes: make(map[string]Node),
edges: []Edge{},
conditionalEdges: []ConditionalEdge{},
}
}

Expand All @@ -76,6 +93,17 @@ func (g *MessageGraph) AddEdge(from, to string) {
})
}

// AddConditionalEdge adds a conditional edge from the source node.
// The path function determines the next node based on the current state.
// If pathMap is provided, the result of path function will be mapped to actual node names.
func (g *MessageGraph) AddConditionalEdge(source string, path func(ctx context.Context, state []llms.MessageContent) (string, error), pathMap map[string]string) {
g.conditionalEdges = append(g.conditionalEdges, ConditionalEdge{
From: source,
Path: path,
PathMap: pathMap,
})
}

// SetEntryPoint sets the entry point node name for the message graph.
func (g *MessageGraph) SetEntryPoint(name string) {
g.entryPoint = name
Expand All @@ -99,8 +127,6 @@ func (g *MessageGraph) Compile() (*Runnable, error) {
}, nil
}

// Invoke executes the compiled message graph with the given input messages.
// It returns the resulting messages and an error if any occurs during the execution.
// Invoke executes the compiled message graph with the given input messages.
// It returns the resulting messages and an error if any occurs during the execution.
func (r *Runnable) Invoke(ctx context.Context, messages []llms.MessageContent) ([]llms.MessageContent, error) {
Expand All @@ -123,6 +149,43 @@ func (r *Runnable) Invoke(ctx context.Context, messages []llms.MessageContent) (
return nil, fmt.Errorf("error in node %s: %w", currentNode, err)
}

// First check for conditional edges
foundConditional := false
for _, condEdge := range r.graph.conditionalEdges {
if condEdge.From == currentNode {
// Execute the conditional path function
pathResult, err := condEdge.Path(ctx, state)
if err != nil {
return nil, fmt.Errorf("error in conditional edge from %s: %w", currentNode, err)
}

// Apply path mapping if provided
nextNode := pathResult
if condEdge.PathMap != nil {
if mappedNode, ok := condEdge.PathMap[pathResult]; ok {
nextNode = mappedNode
}
}

if nextNode == END {
currentNode = END
} else {
// Verify the target node exists
if _, ok := r.graph.nodes[nextNode]; !ok {
return nil, fmt.Errorf("%w: %s (target of conditional edge from %s)", ErrNodeNotFound, nextNode, currentNode)
}
currentNode = nextNode
}
foundConditional = true
break
}
}

if foundConditional {
continue
}

// If no conditional edge found, check for regular edges
foundNext := false
for _, edge := range r.graph.edges {
if edge.From == currentNode {
Expand Down