Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 15 additions & 5 deletions go/ai/exp/gen.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,16 +73,26 @@ type SessionFlowResult struct {
// SessionFlowStreamChunk is a single event streamed from a session flow to
// the client. The Stream type parameter defines the shape of user-defined
// Status data.
type SessionFlowStreamChunk[Stream any] struct {
	// Artifact contains a newly produced artifact.
	Artifact *Artifact `json:"artifact,omitempty"`
	// EndTurn signals that the session flow has finished processing the current input.
	// When true, the client should stop iterating and may send the next input.
	// NOTE(review): appears superseded by TurnEnd below, which carries the
	// same end-of-turn signal plus per-turn metadata — confirm whether any
	// producer still sets this flag before relying on it.
	EndTurn bool `json:"endTurn,omitempty"`
	// ModelChunk contains generation tokens from the model.
	ModelChunk *ai.ModelResponseChunk `json:"modelChunk,omitempty"`
	// SnapshotID contains the ID of a snapshot that was just persisted.
	SnapshotID string `json:"snapshotId,omitempty"`
	// Status contains user-defined structured status information.
	// The Stream type parameter defines the shape of this data.
	Status Stream `json:"status,omitempty"`
	// TurnEnd signals that the session flow has finished processing the current
	// turn. When non-nil, the client should stop iterating and may send the
	// next input. It carries the snapshot ID (if any) and the number of inputs
	// that were combined into this turn.
	TurnEnd *TurnEnd `json:"turnEnd,omitempty"`
}

// TurnEnd signals the completion of a turn and carries per-turn metadata.
// It is delivered to clients as the TurnEnd field of a
// SessionFlowStreamChunk.
type TurnEnd struct {
	// SnapshotID contains the ID of the snapshot persisted at the end of
	// this turn. Empty if no snapshot was created.
	SnapshotID string `json:"snapshotId,omitempty"`
	// InputCount is the number of client inputs that were combined into
	// this turn. Always >= 1. It is greater than 1 only when input
	// batching is enabled and multiple inputs were queued.
	InputCount int `json:"inputCount"`
}

// Artifact represents a named collection of parts produced during a session.
Expand Down
15 changes: 13 additions & 2 deletions go/ai/exp/option.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,9 @@ type SessionFlowOption[State any] interface {
}

type sessionFlowOptions[State any] struct {
store SessionStore[State]
callback SnapshotCallback[State]
store SessionStore[State]
callback SnapshotCallback[State]
combineInputs bool
}

func (o *sessionFlowOptions[State]) applySessionFlow(opts *sessionFlowOptions[State]) error {
Expand All @@ -46,6 +47,9 @@ func (o *sessionFlowOptions[State]) applySessionFlow(opts *sessionFlowOptions[St
}
opts.callback = o.callback
}
if o.combineInputs {
opts.combineInputs = true
}
return nil
}

Expand All @@ -60,6 +64,13 @@ func WithSnapshotCallback[State any](cb SnapshotCallback[State]) SessionFlowOpti
return &sessionFlowOptions[State]{callback: cb}
}

// WithBatchedInputs enables input batching for prompt-backed session flows.
// When this option is supplied, [DefineSessionFlowFromPrompt] drives the
// session with [SessionRunner.RunBatched] rather than [SessionRunner.Run],
// so that all queued inputs are merged and handled as one turn.
func WithBatchedInputs[State any]() SessionFlowOption[State] {
	opt := new(sessionFlowOptions[State])
	opt.combineInputs = true
	return opt
}

// WithSnapshotOn configures snapshots to be created only for the specified events.
// For example, WithSnapshotOn[MyState](SnapshotEventTurnEnd) skips the
// invocation-end snapshot.
Expand Down
116 changes: 106 additions & 10 deletions go/ai/exp/session_flow.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,14 +59,17 @@ type SessionRunner[State any] struct {
lastSnapshot *SessionSnapshot[State]
lastSnapshotVersion uint64
collectTurnOutput func() any
turnInputCount int // set by Run (1) or RunBatched (N), read by onEndTurn
}

// Run loops over the input channel, calling fn for each turn. Each turn is
// wrapped in a trace span for observability. Input messages are automatically
// added to the session before fn is called. After fn returns successfully, an
// EndTurn chunk is sent and a snapshot check is triggered.
// added to the session before fn is called. After fn returns successfully, a
// [TurnEnd] chunk is sent and a snapshot check is triggered.
func (a *SessionRunner[State]) Run(ctx context.Context, fn func(ctx context.Context, input *SessionFlowInput) error) error {
for input := range a.InputCh {
a.turnInputCount = 1

spanMeta := &tracing.SpanMetadata{
Name: fmt.Sprintf("sessionFlow/turn/%d", a.TurnIndex),
Type: "flowStep",
Expand Down Expand Up @@ -97,6 +100,82 @@ func (a *SessionRunner[State]) Run(ctx context.Context, fn func(ctx context.Cont
return nil
}

// RunBatched is like [Run] but combines queued inputs into a single turn.
// When the server finishes processing a turn and additional inputs are
// waiting in the channel, they are merged into one [SessionFlowInput]
// (messages concatenated) and processed as a single turn. The resulting
// [TurnEnd] chunk reports how many client inputs were combined via
// InputCount.
//
// A draining goroutine keeps the input channel flowing so that clients
// sending multiple messages in rapid succession do not block. The
// intermediary buffer is bounded (128 items) to provide back-pressure
// for fast producers. The goroutine exits when InputCh is closed or when
// ctx is cancelled.
func (a *SessionRunner[State]) RunBatched(ctx context.Context, fn func(ctx context.Context, input *SessionFlowInput) error) error {
	// Forwarder goroutine: copies InputCh into a larger bounded buffer so
	// that the non-blocking drain below can observe all queued inputs.
	// Both the receive and the forward select on ctx.Done(); without that,
	// a send on a full buf would block forever once this function stops
	// receiving (e.g. after fn returned an error and ctx was cancelled),
	// leaking the goroutine for the remaining lifetime of InputCh.
	buf := make(chan *SessionFlowInput, 128)
	go func() {
		defer close(buf)
		for {
			select {
			case input, ok := <-a.InputCh:
				if !ok {
					return
				}
				select {
				case buf <- input:
				case <-ctx.Done():
					return
				}
			case <-ctx.Done():
				return
			}
		}
	}()

	for first := range buf {
		// Non-blocking drain: merge every already-queued input into the
		// first one so they are processed as a single turn. Copies of the
		// slices are taken so later appends never alias the caller's data.
		combined := &SessionFlowInput{
			Messages:     append([]*ai.Message(nil), first.Messages...),
			ToolRestarts: append([]*ai.Part(nil), first.ToolRestarts...),
		}
		count := 1
	drain:
		for {
			select {
			case next, ok := <-buf:
				if !ok {
					break drain
				}
				combined.Messages = append(combined.Messages, next.Messages...)
				combined.ToolRestarts = append(combined.ToolRestarts, next.ToolRestarts...)
				count++
			default:
				break drain
			}
		}

		// Read by onEndTurn when it builds the TurnEnd chunk.
		a.turnInputCount = count

		spanMeta := &tracing.SpanMetadata{
			Name:    fmt.Sprintf("sessionFlow/turn/%d", a.TurnIndex),
			Type:    "flowStep",
			Subtype: "flowStep",
		}

		_, err := tracing.RunInNewSpan(ctx, spanMeta, combined,
			func(ctx context.Context, input *SessionFlowInput) (any, error) {
				a.AddMessages(input.Messages...)

				if err := fn(ctx, input); err != nil {
					return nil, err
				}

				a.onEndTurn(ctx)
				a.TurnIndex++

				if a.collectTurnOutput != nil {
					return a.collectTurnOutput(), nil
				}
				return nil, nil
			},
		)
		if err != nil {
			return err
		}
	}
	return nil
}

// Result returns an [SessionFlowResult] populated from the current session state:
// the last message in the conversation history and all artifacts.
// It is a convenience for custom session flows that don't need to construct the
Expand Down Expand Up @@ -281,8 +360,8 @@ func DefineSessionFlow[Stream, State any](
if chunk.Artifact != nil {
session.AddArtifacts(chunk.Artifact)
}
// Accumulate content chunks (exclude control signals from onEndTurn).
if !chunk.EndTurn && chunk.SnapshotID == "" {
// Accumulate content chunks (exclude TurnEnd control signals).
if chunk.TurnEnd == nil {
turnMu.Lock()
turnChunks = append(turnChunks, chunk)
turnMu.Unlock()
Expand All @@ -291,14 +370,20 @@ func DefineSessionFlow[Stream, State any](
}
}()

// Wire up onEndTurn: triggers snapshot + sends EndTurn chunk.
// Wire up onEndTurn: triggers snapshot + sends TurnEnd chunk.
// Writes through respCh to preserve ordering with user chunks.
agentSess.onEndTurn = func(turnCtx context.Context) {
snapshotID := agentSess.maybeSnapshot(turnCtx, SnapshotEventTurnEnd)
if snapshotID != "" {
respCh <- &SessionFlowStreamChunk[Stream]{SnapshotID: snapshotID}
inputCount := agentSess.turnInputCount
if inputCount == 0 {
inputCount = 1
}
respCh <- &SessionFlowStreamChunk[Stream]{
TurnEnd: &TurnEnd{
SnapshotID: snapshotID,
InputCount: inputCount,
},
}
respCh <- &SessionFlowStreamChunk[Stream]{EndTurn: true}
}

result, fnErr := fn(ctx, Responder[Stream](respCh), agentSess)
Expand Down Expand Up @@ -359,8 +444,19 @@ func DefineSessionFlowFromPrompt[State, PromptIn any](
panic(fmt.Sprintf("DefineSessionFlowFromPrompt: prompt %q not found", promptName))
}

// Pre-scan options for prompt-flow settings.
scanOpts := &sessionFlowOptions[State]{}
for _, opt := range opts {
opt.applySessionFlow(scanOpts)
}
combineInputs := scanOpts.combineInputs

fn := func(ctx context.Context, resp Responder[any], sess *SessionRunner[State]) (*SessionFlowResult, error) {
if err := sess.Run(ctx, func(ctx context.Context, input *SessionFlowInput) error {
run := sess.Run
if combineInputs {
run = sess.RunBatched
}
if err := run(ctx, func(ctx context.Context, input *SessionFlowInput) error {
// Resolve prompt input: session state override > default.
promptInput := defaultInput
if stored := sess.InputVariables(); stored != nil {
Expand Down Expand Up @@ -634,7 +730,7 @@ func (c *SessionFlowConnection[Stream, State]) Close() error {
// Receive returns an iterator for receiving stream chunks.
// Unlike the underlying BidiConnection.Receive, breaking out of this iterator
// does not cancel the connection. This enables multi-turn patterns where the
// caller breaks on EndTurn, sends the next input, then calls Receive again.
// caller breaks on TurnEnd, sends the next input, then calls Receive again.
func (c *SessionFlowConnection[Stream, State]) Receive() iter.Seq2[*SessionFlowStreamChunk[Stream], error] {
c.initReceiver()
return func(yield func(*SessionFlowStreamChunk[Stream], error) bool) {
Expand Down
Loading
Loading