diff --git a/go/ai/exp/gen.go b/go/ai/exp/gen.go
index b27ace232f..84e32b2e00 100644
--- a/go/ai/exp/gen.go
+++ b/go/ai/exp/gen.go
@@ -73,16 +73,26 @@ type SessionFlowResult struct {
 type SessionFlowStreamChunk[Stream any] struct {
 	// Artifact contains a newly produced artifact.
 	Artifact *Artifact `json:"artifact,omitempty"`
-	// EndTurn signals that the session flow has finished processing the current input.
-	// When true, the client should stop iterating and may send the next input.
-	EndTurn bool `json:"endTurn,omitempty"`
 	// ModelChunk contains generation tokens from the model.
 	ModelChunk *ai.ModelResponseChunk `json:"modelChunk,omitempty"`
-	// SnapshotID contains the ID of a snapshot that was just persisted.
-	SnapshotID string `json:"snapshotId,omitempty"`
 	// Status contains user-defined structured status information.
 	// The Stream type parameter defines the shape of this data.
 	Status Stream `json:"status,omitempty"`
+	// TurnEnd signals that the session flow has finished processing the current
+	// turn. When non-nil, the client should stop iterating and may send the
+	// next input. It carries the snapshot ID (if any) and the number of inputs
+	// that were combined into this turn.
+	TurnEnd *TurnEnd `json:"turnEnd,omitempty"`
+}
+
+// TurnEnd is the payload of the end-of-turn control chunk; it carries per-turn metadata.
+type TurnEnd struct {
+	// SnapshotID contains the ID of the snapshot persisted at the end of
+	// this turn. Empty if no snapshot was created.
+	SnapshotID string `json:"snapshotId,omitempty"`
+	// InputCount is the number of client inputs that were combined into
+	// this turn: 1 under Run, possibly more under RunBatched. Always >= 1.
+	InputCount int `json:"inputCount"`
 }
 
 // Artifact represents a named collection of parts produced during a session.
diff --git a/go/ai/exp/option.go b/go/ai/exp/option.go
index c22e63956f..9220016c29 100644
--- a/go/ai/exp/option.go
+++ b/go/ai/exp/option.go
@@ -29,8 +29,9 @@ type SessionFlowOption[State any] interface {
 }
 
 type sessionFlowOptions[State any] struct {
-	store    SessionStore[State]
-	callback SnapshotCallback[State]
+	store         SessionStore[State]
+	callback      SnapshotCallback[State]
+	combineInputs bool // set by WithBatchedInputs; makes prompt-backed flows use RunBatched
 }
 
 func (o *sessionFlowOptions[State]) applySessionFlow(opts *sessionFlowOptions[State]) error {
@@ -46,6 +47,9 @@ func (o *sessionFlowOptions[State]) applySessionFlow(opts *sessionFlowOptions[St
 		}
 		opts.callback = o.callback
 	}
+	if o.combineInputs {
+		opts.combineInputs = true
+	}
 	return nil
 }
 
@@ -60,6 +64,13 @@ func WithSnapshotCallback[State any](cb SnapshotCallback[State]) SessionFlowOpti
 	return &sessionFlowOptions[State]{callback: cb}
 }
 
+// WithBatchedInputs enables input batching for prompt-backed session flows.
+// When set, [DefineSessionFlowFromPrompt] uses [SessionRunner.RunBatched]
+// instead of [SessionRunner.Run], combining queued inputs into a single turn.
+func WithBatchedInputs[State any]() SessionFlowOption[State] {
+	return &sessionFlowOptions[State]{combineInputs: true}
+}
+
 // WithSnapshotOn configures snapshots to be created only for the specified events.
 // For example, WithSnapshotOn[MyState](SnapshotEventTurnEnd) skips the
 // invocation-end snapshot.
diff --git a/go/ai/exp/session_flow.go b/go/ai/exp/session_flow.go
index ae61403931..c7ef920260 100644
--- a/go/ai/exp/session_flow.go
+++ b/go/ai/exp/session_flow.go
@@ -59,14 +59,17 @@ type SessionRunner[State any] struct {
 	lastSnapshot        *SessionSnapshot[State]
 	lastSnapshotVersion uint64
 	collectTurnOutput   func() any
+	turnInputCount      int // set by Run (1) or RunBatched (N), read by onEndTurn
 }
 
 // Run loops over the input channel, calling fn for each turn. Each turn is
 // wrapped in a trace span for observability. Input messages are automatically
-// added to the session before fn is called.
After fn returns successfully, an
-// EndTurn chunk is sent and a snapshot check is triggered.
+// added to the session before fn is called. After fn returns successfully, a
+// [TurnEnd] chunk is sent and a snapshot check is triggered.
 func (a *SessionRunner[State]) Run(ctx context.Context, fn func(ctx context.Context, input *SessionFlowInput) error) error {
 	for input := range a.InputCh {
+		a.turnInputCount = 1
+
 		spanMeta := &tracing.SpanMetadata{
 			Name:    fmt.Sprintf("sessionFlow/turn/%d", a.TurnIndex),
 			Type:    "flowStep",
@@ -97,6 +100,104 @@ func (a *SessionRunner[State]) Run(ctx context.Context, fn func(ctx context.Cont
 	return nil
 }
 
+// RunBatched is like [SessionRunner.Run] but combines queued inputs into a
+// single turn. When the server finishes processing a turn and additional
+// inputs are waiting in the channel, they are merged into one
+// [SessionFlowInput] (messages and tool restarts concatenated in arrival
+// order) and processed as a single turn. The resulting [TurnEnd] chunk
+// reports how many client inputs were combined via InputCount.
+//
+// A draining goroutine keeps the input channel flowing so that clients
+// sending multiple messages in rapid succession do not block. The
+// intermediary buffer is bounded (128 items) to provide back-pressure
+// for fast producers. The goroutine exits when the input channel is
+// closed or when RunBatched returns, whichever happens first, so an
+// error returned by fn does not leave it blocked forever.
+func (a *SessionRunner[State]) RunBatched(ctx context.Context, fn func(ctx context.Context, input *SessionFlowInput) error) error {
+	// done is closed when RunBatched returns; it releases the draining
+	// goroutine even if it is blocked receiving from InputCh or sending
+	// into a full buffer.
+	done := make(chan struct{})
+	defer close(done)
+
+	// Draining goroutine: forwards InputCh into a larger buffer so that
+	// the non-blocking drain below can capture all queued inputs.
+	buf := make(chan *SessionFlowInput, 128)
+	go func() {
+		defer close(buf)
+		for {
+			select {
+			case <-done:
+				return
+			case input, ok := <-a.InputCh:
+				if !ok {
+					return
+				}
+				select {
+				case buf <- input:
+				case <-done:
+					return
+				}
+			}
+		}
+	}()
+
+	for first := range buf {
+		// Non-blocking drain: combine any queued inputs with the first.
+		// The slices are copied so appending never mutates first's
+		// backing arrays.
+		combined := &SessionFlowInput{
+			Messages:     append([]*ai.Message(nil), first.Messages...),
+			ToolRestarts: append([]*ai.Part(nil), first.ToolRestarts...),
+		}
+		count := 1
+	drain:
+		for {
+			select {
+			case next, ok := <-buf:
+				if !ok {
+					break drain
+				}
+				combined.Messages = append(combined.Messages, next.Messages...)
+				combined.ToolRestarts = append(combined.ToolRestarts, next.ToolRestarts...)
+				count++
+			default:
+				break drain
+			}
+		}
+
+		a.turnInputCount = count
+
+		spanMeta := &tracing.SpanMetadata{
+			Name:    fmt.Sprintf("sessionFlow/turn/%d", a.TurnIndex),
+			Type:    "flowStep",
+			Subtype: "flowStep",
+		}
+
+		_, err := tracing.RunInNewSpan(ctx, spanMeta, combined,
+			func(ctx context.Context, input *SessionFlowInput) (any, error) {
+				a.AddMessages(input.Messages...)
+
+				if err := fn(ctx, input); err != nil {
+					return nil, err
+				}
+
+				a.onEndTurn(ctx)
+				a.TurnIndex++
+
+				if a.collectTurnOutput != nil {
+					return a.collectTurnOutput(), nil
+				}
+				return nil, nil
+			},
+		)
+		if err != nil {
+			return err
+		}
+	}
+	return nil
+}
+
 // Result returns an [SessionFlowResult] populated from the current session state:
 // the last message in the conversation history and all artifacts.
 // It is a convenience for custom session flows that don't need to construct the
@@ -281,8 +382,8 @@ func DefineSessionFlow[Stream, State any](
 			if chunk.Artifact != nil {
 				session.AddArtifacts(chunk.Artifact)
 			}
-			// Accumulate content chunks (exclude control signals from onEndTurn).
-			if !chunk.EndTurn && chunk.SnapshotID == "" {
+			// Accumulate content chunks (exclude TurnEnd control signals).
+			if chunk.TurnEnd == nil {
 				turnMu.Lock()
 				turnChunks = append(turnChunks, chunk)
 				turnMu.Unlock()
@@ -291,14 +392,20 @@ func DefineSessionFlow[Stream, State any](
 		}
 	}()
 
-	// Wire up onEndTurn: triggers snapshot + sends EndTurn chunk.
+	// Wire up onEndTurn: triggers snapshot + sends TurnEnd chunk.
 	// Writes through respCh to preserve ordering with user chunks.
 	agentSess.onEndTurn = func(turnCtx context.Context) {
 		snapshotID := agentSess.maybeSnapshot(turnCtx, SnapshotEventTurnEnd)
-		if snapshotID != "" {
-			respCh <- &SessionFlowStreamChunk[Stream]{SnapshotID: snapshotID}
+		inputCount := agentSess.turnInputCount
+		if inputCount == 0 {
+			inputCount = 1
+		}
+		respCh <- &SessionFlowStreamChunk[Stream]{
+			TurnEnd: &TurnEnd{
+				SnapshotID: snapshotID,
+				InputCount: inputCount,
+			},
 		}
-		respCh <- &SessionFlowStreamChunk[Stream]{EndTurn: true}
 	}
 
 	result, fnErr := fn(ctx, Responder[Stream](respCh), agentSess)
@@ -359,8 +466,21 @@ func DefineSessionFlowFromPrompt[State, PromptIn any](
 		panic(fmt.Sprintf("DefineSessionFlowFromPrompt: prompt %q not found", promptName))
 	}
 
+	// Pre-scan options for prompt-flow settings.
+	// NOTE(review): the error from applySessionFlow is discarded here;
+	// presumably DefineSessionFlow re-applies opts and reports it — confirm.
+	scanOpts := &sessionFlowOptions[State]{}
+	for _, opt := range opts {
+		opt.applySessionFlow(scanOpts)
+	}
+	combineInputs := scanOpts.combineInputs
+
 	fn := func(ctx context.Context, resp Responder[any], sess *SessionRunner[State]) (*SessionFlowResult, error) {
-		if err := sess.Run(ctx, func(ctx context.Context, input *SessionFlowInput) error {
+		run := sess.Run
+		if combineInputs {
+			run = sess.RunBatched
+		}
+		if err := run(ctx, func(ctx context.Context, input *SessionFlowInput) error {
 			// Resolve prompt input: session state override > default.
 			promptInput := defaultInput
 			if stored := sess.InputVariables(); stored != nil {
@@ -634,7 +754,7 @@ func (c *SessionFlowConnection[Stream, State]) Close() error {
 // Receive returns an iterator for receiving stream chunks.
 // Unlike the underlying BidiConnection.Receive, breaking out of this iterator
 // does not cancel the connection. This enables multi-turn patterns where the
-// caller breaks on EndTurn, sends the next input, then calls Receive again.
+// caller breaks on TurnEnd, sends the next input, then calls Receive again.
func (c *SessionFlowConnection[Stream, State]) Receive() iter.Seq2[*SessionFlowStreamChunk[Stream], error] { c.initReceiver() return func(yield func(*SessionFlowStreamChunk[Stream], error) bool) { diff --git a/go/ai/exp/session_flow_test.go b/go/ai/exp/session_flow_test.go index 1d8d0ddcb5..4a76486240 100644 --- a/go/ai/exp/session_flow_test.go +++ b/go/ai/exp/session_flow_test.go @@ -78,11 +78,11 @@ func TestSessionFlow_BasicMultiTurn(t *testing.T) { t.Fatalf("Receive error: %v", err) } turn1Chunks++ - if chunk.EndTurn { + if chunk.TurnEnd != nil { break } } - if turn1Chunks < 2 { // at least status + endTurn + if turn1Chunks < 2 { // at least status + turnEnd t.Errorf("expected at least 2 chunks in turn 1, got %d", turn1Chunks) } @@ -94,7 +94,7 @@ func TestSessionFlow_BasicMultiTurn(t *testing.T) { if err != nil { t.Fatalf("Receive error: %v", err) } - if chunk.EndTurn { + if chunk.TurnEnd != nil { break } } @@ -148,10 +148,10 @@ func TestSessionFlow_WithSessionStore(t *testing.T) { if err != nil { t.Fatalf("Receive error: %v", err) } - if chunk.SnapshotID != "" { - snapshotIDs = append(snapshotIDs, chunk.SnapshotID) - } - if chunk.EndTurn { + if chunk.TurnEnd != nil { + if chunk.TurnEnd.SnapshotID != "" { + snapshotIDs = append(snapshotIDs, chunk.TurnEnd.SnapshotID) + } break } } @@ -216,7 +216,7 @@ func TestSessionFlow_ResumeFromSnapshot(t *testing.T) { if err != nil { t.Fatalf("Receive error: %v", err) } - if chunk.EndTurn { + if chunk.TurnEnd != nil { break } } @@ -239,7 +239,7 @@ func TestSessionFlow_ResumeFromSnapshot(t *testing.T) { if err != nil { t.Fatalf("Receive error: %v", err) } - if chunk.EndTurn { + if chunk.TurnEnd != nil { break } } @@ -313,7 +313,7 @@ func TestSessionFlow_ClientManagedState(t *testing.T) { if err != nil { t.Fatalf("Receive error: %v", err) } - if chunk.EndTurn { + if chunk.TurnEnd != nil { break } } @@ -387,7 +387,7 @@ func TestSessionFlow_Artifacts(t *testing.T) { if chunk.Artifact != nil { receivedArtifacts = 
append(receivedArtifacts, chunk.Artifact) } - if chunk.EndTurn { + if chunk.TurnEnd != nil { break } } @@ -445,10 +445,10 @@ func TestSessionFlow_SnapshotCallback(t *testing.T) { if err != nil { t.Fatalf("Receive error on turn %d: %v", i, err) } - if chunk.SnapshotID != "" { - snapshotIDs = append(snapshotIDs, chunk.SnapshotID) - } - if chunk.EndTurn { + if chunk.TurnEnd != nil { + if chunk.TurnEnd.SnapshotID != "" { + snapshotIDs = append(snapshotIDs, chunk.TurnEnd.SnapshotID) + } break } } @@ -496,7 +496,7 @@ func TestSessionFlow_SendMessages(t *testing.T) { if err != nil { t.Fatalf("Receive error: %v", err) } - if chunk.EndTurn { + if chunk.TurnEnd != nil { break } } @@ -547,7 +547,7 @@ func TestSessionFlow_SessionContext(t *testing.T) { if err != nil { t.Fatalf("Receive error: %v", err) } - if chunk.EndTurn { + if chunk.TurnEnd != nil { break } } @@ -609,7 +609,7 @@ func TestSessionFlow_SetMessages(t *testing.T) { if err != nil { t.Fatalf("Receive error: %v", err) } - if chunk.EndTurn { + if chunk.TurnEnd != nil { break } } @@ -656,7 +656,7 @@ func TestSessionFlow_SnapshotIDInMessageMetadata(t *testing.T) { if err != nil { t.Fatalf("Receive error: %v", err) } - if chunk.EndTurn { + if chunk.TurnEnd != nil { break } } @@ -767,7 +767,7 @@ func TestSessionFlow_TurnSpanOutput(t *testing.T) { if err != nil { t.Fatalf("Receive error on turn %d: %v", turn, err) } - if chunk.EndTurn { + if chunk.TurnEnd != nil { break } } @@ -793,11 +793,8 @@ func TestSessionFlow_TurnSpanOutput(t *testing.T) { t.Errorf("turn %d: expected 3 chunks, got %d", i, len(chunks)) } for j, chunk := range chunks { - if chunk.EndTurn { - t.Errorf("turn %d, chunk %d: EndTurn should not be in turn output", i, j) - } - if chunk.SnapshotID != "" { - t.Errorf("turn %d, chunk %d: SnapshotID should not be in turn output", i, j) + if chunk.TurnEnd != nil { + t.Errorf("turn %d, chunk %d: TurnEnd should not be in turn output", i, j) } } } @@ -839,10 +836,10 @@ func 
TestSessionFlow_TurnSpanOutput_WithSnapshots(t *testing.T) { if err != nil { t.Fatalf("Receive error: %v", err) } - if chunk.SnapshotID != "" { - sawSnapshot = true - } - if chunk.EndTurn { + if chunk.TurnEnd != nil { + if chunk.TurnEnd.SnapshotID != "" { + sawSnapshot = true + } break } } @@ -850,7 +847,7 @@ func TestSessionFlow_TurnSpanOutput_WithSnapshots(t *testing.T) { conn.Output() if !sawSnapshot { - t.Fatal("expected a snapshot chunk on the stream") + t.Fatal("expected a snapshot ID in TurnEnd") } // Turn output should contain only the status chunk, not the snapshot/endTurn. @@ -938,7 +935,7 @@ func TestPromptAgent_Basic(t *testing.T) { if chunk.ModelChunk != nil { gotChunk = true } - if chunk.EndTurn { + if chunk.TurnEnd != nil { break } } @@ -954,7 +951,7 @@ func TestPromptAgent_Basic(t *testing.T) { if err != nil { t.Fatalf("Receive error: %v", err) } - if chunk.EndTurn { + if chunk.TurnEnd != nil { break } } @@ -1007,7 +1004,7 @@ func TestPromptAgent_PromptInputOverride(t *testing.T) { if err != nil { t.Fatalf("Receive error: %v", err) } - if chunk.EndTurn { + if chunk.TurnEnd != nil { break } } @@ -1085,7 +1082,7 @@ func TestPromptAgent_MultiTurnHistory(t *testing.T) { if chunk.ModelChunk != nil { turn1Response += chunk.ModelChunk.Text() } - if chunk.EndTurn { + if chunk.TurnEnd != nil { break } } @@ -1106,7 +1103,7 @@ func TestPromptAgent_MultiTurnHistory(t *testing.T) { if chunk.ModelChunk != nil { turn2Response += chunk.ModelChunk.Text() } - if chunk.EndTurn { + if chunk.TurnEnd != nil { break } } @@ -1160,7 +1157,7 @@ func TestPromptAgent_SnapshotPersistsPromptInput(t *testing.T) { if err != nil { t.Fatalf("Receive error: %v", err) } - if chunk.EndTurn { + if chunk.TurnEnd != nil { break } } @@ -1195,7 +1192,7 @@ func TestPromptAgent_SnapshotPersistsPromptInput(t *testing.T) { if err != nil { t.Fatalf("Receive error: %v", err) } - if chunk.EndTurn { + if chunk.TurnEnd != nil { break } } @@ -1286,7 +1283,7 @@ func TestPromptAgent_ToolLoopMessages(t 
*testing.T) {
 		if err != nil {
 			t.Fatalf("Receive error: %v", err)
 		}
-		if chunk.EndTurn {
+		if chunk.TurnEnd != nil {
 			break
 		}
 	}
@@ -1586,10 +1583,10 @@ func TestSessionFlow_MultiTurnSnapshotDedup(t *testing.T) {
 			if err != nil {
 				t.Fatalf("Receive error on turn %d: %v", i, err)
 			}
-			if chunk.SnapshotID != "" {
-				snapshotIDs = append(snapshotIDs, chunk.SnapshotID)
-			}
-			if chunk.EndTurn {
+			if chunk.TurnEnd != nil {
+				if chunk.TurnEnd.SnapshotID != "" {
+					snapshotIDs = append(snapshotIDs, chunk.TurnEnd.SnapshotID)
+				}
 				break
 			}
 		}
@@ -1666,3 +1663,131 @@ func TestSessionFlow_InvocationEndSnapshotWhenStateChangesAfterRun(t *testing.T)
 		t.Error("expected parent ID (turn-end snapshot)")
 	}
 }
+
+// TestSessionFlow_RunBatched verifies that inputs sent while a turn is still
+// being processed are all consumed by RunBatched and that the InputCount
+// values reported in TurnEnd chunks add up to the number of inputs sent.
+func TestSessionFlow_RunBatched(t *testing.T) {
+	ctx := context.Background()
+	reg := newTestRegistry(t)
+
+	processingTurn1 := make(chan struct{})
+	turn1Done := make(chan struct{})
+
+	var turnInputCounts []int
+	turn := 0
+	af := DefineSessionFlow(reg, "batchedFlow",
+		func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*SessionFlowResult, error) {
+			return nil, sess.RunBatched(ctx, func(ctx context.Context, input *SessionFlowInput) error {
+				if turn == 0 {
+					close(processingTurn1) // signal: first turn is processing
+					<-turn1Done            // wait for test to send more messages
+				}
+				turn++
+				for _, msg := range input.Messages {
+					sess.AddMessages(ai.NewModelTextMessage("echo: " + msg.Text()))
+				}
+				return nil
+			})
+		},
+	)
+
+	conn, err := af.StreamBidi(ctx)
+	if err != nil {
+		t.Fatalf("StreamBidi failed: %v", err)
+	}
+
+	// Send first message.
+	if err := conn.SendText("turn1"); err != nil {
+		t.Fatalf("SendText failed: %v", err)
+	}
+
+	// Wait for first turn to start processing.
+	<-processingTurn1
+
+	// Send multiple messages while turn 1 is busy.
+	if err := conn.SendText("msg2"); err != nil {
+		t.Fatalf("SendText failed: %v", err)
+	}
+	if err := conn.SendText("msg3"); err != nil {
+		t.Fatalf("SendText failed: %v", err)
+	}
+
+	// Unblock turn 1.
+	close(turn1Done)
+	conn.Close()
+
+	for chunk, err := range conn.Receive() {
+		if err != nil {
+			t.Fatalf("Receive error: %v", err)
+		}
+		if chunk.TurnEnd != nil {
+			turnInputCounts = append(turnInputCounts, chunk.TurnEnd.InputCount)
+		}
+	}
+
+	response, err := conn.Output()
+	if err != nil {
+		t.Fatalf("Output failed: %v", err)
+	}
+
+	// Turn 1 should have processed 1 input.
+	if len(turnInputCounts) < 1 {
+		t.Fatalf("expected at least 1 turn, got %d", len(turnInputCounts))
+	}
+	if turnInputCounts[0] != 1 {
+		t.Errorf("turn 1: expected inputCount=1, got %d", turnInputCounts[0])
+	}
+
+	// All inputs should be accounted for across turns.
+	totalInputs := 0
+	for _, c := range turnInputCounts {
+		totalInputs += c
+	}
+	if totalInputs != 3 {
+		t.Errorf("expected total inputCount=3 across all turns, got %d", totalInputs)
+	}
+
+	// All messages should be in the session: 3 user + 3 echoes = 6.
+	if got := len(response.State.Messages); got != 6 {
+		t.Errorf("expected 6 messages, got %d", got)
+	}
+}
+
+// TestSessionFlow_RunBatched_InputCount verifies that TurnEnd.InputCount is
+// always 1 when the flow uses the regular Run loop (no batching).
+func TestSessionFlow_RunBatched_InputCount(t *testing.T) {
+	ctx := context.Background()
+	reg := newTestRegistry(t)
+
+	af := DefineSessionFlow(reg, "inputCountFlow",
+		func(ctx context.Context, resp Responder[testStatus], sess *SessionRunner[testState]) (*SessionFlowResult, error) {
+			return nil, sess.Run(ctx, func(ctx context.Context, input *SessionFlowInput) error {
+				sess.AddMessages(ai.NewModelTextMessage("reply"))
+				return nil
+			})
+		},
+	)
+
+	conn, err := af.StreamBidi(ctx)
+	if err != nil {
+		t.Fatalf("StreamBidi failed: %v", err)
+	}
+
+	if err := conn.SendText("hello"); err != nil {
+		t.Fatalf("SendText failed: %v", err)
+	}
+	for chunk, err := range conn.Receive() {
+		if err != nil {
+			t.Fatalf("Receive error: %v", err)
+		}
+		if chunk.TurnEnd != nil {
+			if chunk.TurnEnd.InputCount != 1 {
+				t.Errorf("expected InputCount=1 for regular Run, got %d", chunk.TurnEnd.InputCount)
+			}
+			break
+		}
+	}
+	conn.Close()
+	conn.Output()
+}
diff --git a/go/genkit/genkit.go b/go/genkit/genkit.go
index 01fbcdc3fe..f588a55f42 100644
--- a/go/genkit/genkit.go
+++ b/go/genkit/genkit.go
@@ -472,7 +472,7 @@ func DefineBidiFlow[In, Out, StreamOut, StreamIn any](g *Genkit, name string, fn
 //	// Send a message and stream the response:
 //	conn.SendText("Hello!")
 //	for chunk, err := range conn.Receive() {
-//	    if chunk.EndTurn {
+//	    if chunk.TurnEnd != nil {
 //	        break
 //	    }
 //	    fmt.Print(chunk.ModelChunk.Text())
@@ -540,7 +540,7 @@ func DefineSessionFlow[Stream, State any](
 //	// Send a message and stream the response:
 //	conn.SendText("Hello!")
 //	for chunk, err := range conn.Receive() {
-//	    if chunk.EndTurn {
+//	    if chunk.TurnEnd != nil {
 //	        break
 //	    }
 //	    fmt.Print(chunk.ModelChunk.Text())
diff --git a/go/samples/custom-agent/main.go b/go/samples/custom-agent/main.go
index b8e2274e0c..a05c908557 100644
--- a/go/samples/custom-agent/main.go
+++ b/go/samples/custom-agent/main.go
@@ -104,10 +104,10 @@ func main() {
 			if chunk.ModelChunk != nil {
 				fmt.Print(chunk.ModelChunk.Text())
 			}
-			if chunk.SnapshotID != "" {
-				fmt.Printf("\n[snapshot: %s]", chunk.SnapshotID)
-			}
-			if chunk.EndTurn {
+			if chunk.TurnEnd !=
nil { + if chunk.TurnEnd.SnapshotID != "" { + fmt.Printf("\n[snapshot: %s]", chunk.TurnEnd.SnapshotID) + } fmt.Println() fmt.Println() break diff --git a/go/samples/prompt-agent/main.go b/go/samples/prompt-agent/main.go index e46ff2067e..b501872520 100644 --- a/go/samples/prompt-agent/main.go +++ b/go/samples/prompt-agent/main.go @@ -84,10 +84,10 @@ func main() { if chunk.ModelChunk != nil { fmt.Print(chunk.ModelChunk.Text()) } - if chunk.SnapshotID != "" { - fmt.Printf("\n[snapshot: %s]", chunk.SnapshotID) - } - if chunk.EndTurn { + if chunk.TurnEnd != nil { + if chunk.TurnEnd.SnapshotID != "" { + fmt.Printf("\n[snapshot: %s]", chunk.TurnEnd.SnapshotID) + } fmt.Println() fmt.Println() break