Skip to content

Commit 2bfc95d

Browse files
google-genai-botcopybara-github
authored andcommitted
feat: HITL/Wire up tool confirmation support
This is a port of the python implementation and part of the "human in the loop" workflow. PiperOrigin-RevId: 827953462
1 parent 5e68159 commit 2bfc95d

File tree

8 files changed

+662
-25
lines changed

8 files changed

+662
-25
lines changed

core/src/main/java/com/google/adk/flows/llmflows/Contents.java

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,9 @@ private ImmutableList<Content> getContents(
109109
if (!isEventBelongsToBranch(currentBranch, event)) {
110110
continue;
111111
}
112+
if (isRequestConfirmationEvent(event)) {
113+
continue;
114+
}
112115

113116
// TODO: Skip auth events.
114117

@@ -511,4 +514,21 @@ private static boolean hasContentWithNonEmptyParts(Event event) {
511514
.map(list -> !list.isEmpty()) // Optional<Boolean>
512515
.orElse(false);
513516
}
517+
518+
/** Checks if the event is a request confirmation event. */
519+
private static boolean isRequestConfirmationEvent(Event event) {
520+
return event.content().flatMap(Content::parts).stream()
521+
.flatMap(List::stream)
522+
// return event.content().flatMap(Content::parts).orElse(ImmutableList.of()).stream()
523+
.anyMatch(
524+
part ->
525+
part.functionCall()
526+
.flatMap(FunctionCall::name)
527+
.map(Functions.REQUEST_CONFIRMATION_FUNCTION_CALL_NAME::equals)
528+
.orElse(false)
529+
|| part.functionResponse()
530+
.flatMap(FunctionResponse::name)
531+
.map(Functions.REQUEST_CONFIRMATION_FUNCTION_CALL_NAME::equals)
532+
.orElse(false));
533+
}
514534
}

core/src/main/java/com/google/adk/flows/llmflows/Functions.java

Lines changed: 93 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717

1818
package com.google.adk.flows.llmflows;
1919

20+
import static com.google.common.collect.ImmutableMap.toImmutableMap;
21+
2022
import com.google.adk.Telemetry;
2123
import com.google.adk.agents.ActiveStreamingTool;
2224
import com.google.adk.agents.Callbacks.AfterToolCallback;
@@ -27,6 +29,7 @@
2729
import com.google.adk.events.EventActions;
2830
import com.google.adk.tools.BaseTool;
2931
import com.google.adk.tools.FunctionTool;
32+
import com.google.adk.tools.ToolConfirmation;
3033
import com.google.adk.tools.ToolContext;
3134
import com.google.common.base.VerifyException;
3235
import com.google.common.collect.ImmutableList;
@@ -52,14 +55,14 @@
5255
import java.util.Optional;
5356
import java.util.Set;
5457
import java.util.UUID;
55-
import org.jspecify.annotations.Nullable;
5658
import org.slf4j.Logger;
5759
import org.slf4j.LoggerFactory;
5860

5961
/** Utility class for handling function calls. */
6062
public final class Functions {
6163

6264
private static final String AF_FUNCTION_CALL_ID_PREFIX = "adk-";
65+
static final String REQUEST_CONFIRMATION_FUNCTION_CALL_NAME = "adk_request_confirmation";
6366
private static final Logger logger = LoggerFactory.getLogger(Functions.class);
6467

6568
/** Generates a unique ID for a function call. */
@@ -122,6 +125,15 @@ public static void populateClientFunctionCallId(Event modelResponseEvent) {
122125
/** Handles standard, non-streaming function calls. */
123126
public static Maybe<Event> handleFunctionCalls(
124127
InvocationContext invocationContext, Event functionCallEvent, Map<String, BaseTool> tools) {
128+
return handleFunctionCalls(invocationContext, functionCallEvent, tools, ImmutableMap.of());
129+
}
130+
131+
/** Handles standard, non-streaming function calls with tool confirmations. */
132+
public static Maybe<Event> handleFunctionCalls(
133+
InvocationContext invocationContext,
134+
Event functionCallEvent,
135+
Map<String, BaseTool> tools,
136+
Map<String, ToolConfirmation> toolConfirmations) {
125137
ImmutableList<FunctionCall> functionCalls = functionCallEvent.functionCalls();
126138

127139
List<Maybe<Event>> functionResponseEvents = new ArrayList<>();
@@ -134,9 +146,10 @@ public static Maybe<Event> handleFunctionCalls(
134146
ToolContext toolContext =
135147
ToolContext.builder(invocationContext)
136148
.functionCallId(functionCall.id().orElse(""))
149+
.toolConfirmation(toolConfirmations.get(functionCall.id().orElse(null)))
137150
.build();
138151

139-
Map<String, Object> functionArgs = functionCall.args().orElse(new HashMap<>());
152+
Map<String, Object> functionArgs = functionCall.args().orElse(ImmutableMap.of());
140153

141154
Maybe<Map<String, Object>> maybeFunctionResult =
142155
maybeInvokeBeforeToolCall(invocationContext, tool, functionArgs, toolContext)
@@ -192,10 +205,12 @@ public static Maybe<Event> handleFunctionCalls(
192205
if (events.isEmpty()) {
193206
return Maybe.empty();
194207
}
195-
Event mergedEvent = Functions.mergeParallelFunctionResponseEvents(events);
196-
if (mergedEvent == null) {
208+
Optional<Event> maybeMergedEvent =
209+
Functions.mergeParallelFunctionResponseEvents(events);
210+
if (maybeMergedEvent.isEmpty()) {
197211
return Maybe.empty();
198212
}
213+
var mergedEvent = maybeMergedEvent.get();
199214

200215
if (events.size() > 1) {
201216
Tracer tracer = Telemetry.getTracer();
@@ -288,7 +303,7 @@ public static Maybe<Event> handleFunctionCallsLive(
288303
if (events.isEmpty()) {
289304
return Maybe.empty();
290305
}
291-
return Maybe.just(Functions.mergeParallelFunctionResponseEvents(events));
306+
return Maybe.just(Functions.mergeParallelFunctionResponseEvents(events).orElse(null));
292307
});
293308
}
294309

@@ -387,13 +402,13 @@ public static Set<String> getLongRunningFunctionCalls(
387402
return longRunningFunctionCalls;
388403
}
389404

390-
private static @Nullable Event mergeParallelFunctionResponseEvents(
405+
private static Optional<Event> mergeParallelFunctionResponseEvents(
391406
List<Event> functionResponseEvents) {
392407
if (functionResponseEvents.isEmpty()) {
393-
return null;
408+
return Optional.empty();
394409
}
395410
if (functionResponseEvents.size() == 1) {
396-
return functionResponseEvents.get(0);
411+
return Optional.of(functionResponseEvents.get(0));
397412
}
398413
// Use the first event as the base for common attributes
399414
Event baseEvent = functionResponseEvents.get(0);
@@ -410,15 +425,16 @@ public static Set<String> getLongRunningFunctionCalls(
410425
mergedActionsBuilder.merge(event.actions());
411426
}
412427

413-
return Event.builder()
414-
.id(Event.generateEventId())
415-
.invocationId(baseEvent.invocationId())
416-
.author(baseEvent.author())
417-
.branch(baseEvent.branch())
418-
.content(Optional.of(Content.builder().role("user").parts(mergedParts).build()))
419-
.actions(mergedActionsBuilder.build())
420-
.timestamp(baseEvent.timestamp())
421-
.build();
428+
return Optional.of(
429+
Event.builder()
430+
.id(Event.generateEventId())
431+
.invocationId(baseEvent.invocationId())
432+
.author(baseEvent.author())
433+
.branch(baseEvent.branch())
434+
.content(Optional.of(Content.builder().role("user").parts(mergedParts).build()))
435+
.actions(mergedActionsBuilder.build())
436+
.timestamp(baseEvent.timestamp())
437+
.build());
422438
}
423439

424440
private static Maybe<Map<String, Object>> maybeInvokeBeforeToolCall(
@@ -563,5 +579,65 @@ private static Event buildResponseEvent(
563579
}
564580
}
565581

582+
/**
583+
* Generates a request confirmation event from a function response event.
584+
*
585+
* @param invocationContext The invocation context.
586+
* @param functionCallEvent The event containing the original function call.
587+
* @param functionResponseEvent The event containing the function response.
588+
* @return An optional event containing the request confirmation function call.
589+
*/
590+
public static Optional<Event> generateRequestConfirmationEvent(
591+
InvocationContext invocationContext, Event functionCallEvent, Event functionResponseEvent) {
592+
if (functionResponseEvent.actions().requestedToolConfirmations().isEmpty()) {
593+
return Optional.empty();
594+
}
595+
596+
List<Part> parts = new ArrayList<>();
597+
Set<String> longRunningToolIds = new HashSet<>();
598+
ImmutableMap<String, FunctionCall> functionCallsById =
599+
functionCallEvent.functionCalls().stream()
600+
.filter(fc -> fc.id().isPresent())
601+
.collect(toImmutableMap(fc -> fc.id().get(), fc -> fc));
602+
603+
for (Map.Entry<String, ToolConfirmation> entry :
604+
functionResponseEvent.actions().requestedToolConfirmations().entrySet().stream()
605+
.filter(fc -> functionCallsById.containsKey(fc.getKey()))
606+
.collect(toImmutableMap(Map.Entry::getKey, Map.Entry::getValue))
607+
.entrySet()) {
608+
609+
FunctionCall requestConfirmationFunctionCall =
610+
FunctionCall.builder()
611+
.name(REQUEST_CONFIRMATION_FUNCTION_CALL_NAME)
612+
.args(
613+
ImmutableMap.of(
614+
"originalFunctionCall",
615+
functionCallsById.get(entry.getKey()),
616+
"toolConfirmation",
617+
entry.getValue()))
618+
.id(generateClientFunctionCallId())
619+
.build();
620+
621+
longRunningToolIds.add(requestConfirmationFunctionCall.id().get());
622+
parts.add(Part.builder().functionCall(requestConfirmationFunctionCall).build());
623+
}
624+
625+
if (parts.isEmpty()) {
626+
return Optional.empty();
627+
}
628+
629+
var contentBuilder = Content.builder().parts(parts);
630+
functionResponseEvent.content().flatMap(Content::role).ifPresent(contentBuilder::role);
631+
632+
return Optional.of(
633+
Event.builder()
634+
.invocationId(invocationContext.invocationId())
635+
.author(invocationContext.agent().name())
636+
.branch(invocationContext.branch())
637+
.content(contentBuilder.build())
638+
.longRunningToolIds(longRunningToolIds)
639+
.build());
640+
}
641+
566642
private Functions() {}
567643
}

0 commit comments

Comments
 (0)