Skip to content
Merged
2 changes: 1 addition & 1 deletion crates/jp_cli/src/cmd/query.rs
Original file line number Diff line number Diff line change
Expand Up @@ -943,7 +943,7 @@ impl Query {
.handle_tool_call(
cfg,
mcp_client,
root,
&root,
is_tty,
turn_state,
request,
Expand Down
118 changes: 105 additions & 13 deletions crates/jp_cli/src/cmd/query/event.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
use std::{env, fmt::Write, fs, path::PathBuf, time};
use std::{env, fmt::Write, fs, path::Path, time};

use crossterm::style::Stylize as _;
use indexmap::IndexMap;
use indexmap::{IndexMap, IndexSet};
use jp_config::{
AppConfig,
conversation::tool::{
ToolConfigWithDefaults,
QuestionTarget, ToolConfigWithDefaults,
style::{InlineResults, LinkStyle, ParametersStyle, TruncateLines},
},
style::{
Expand All @@ -21,7 +21,7 @@ use jp_llm::{ToolError, tool::ToolDefinition};
use jp_printer::PrinterWriter;
use jp_term::osc::hyperlink;
use jp_tool::{AnswerType, Question};
use serde_json::Value;
use serde_json::{Value, json};

use super::{ResponseHandler, turn::TurnState};
use crate::Error;
Expand Down Expand Up @@ -111,17 +111,50 @@ impl StreamEventHandler {
&mut self,
cfg: &AppConfig,
mcp_client: &jp_mcp::Client,
root: PathBuf,
root: &Path,
is_tty: bool,
turn_state: &mut TurnState,
call: ToolCallRequest,
handler: &mut ResponseHandler,
mut writer: PrinterWriter<'_>,
) -> Result<Option<String>, Error> {
let Some(tool_config) = cfg.conversation.tools.get(&call.name) else {
return Err(Error::NotFound("tool", call.name.clone()));
let response = ToolCallResponse {
id: call.id.clone(),
result: Err(format!("Tool '{}' not found.", call.name)),
};

self.tool_call_responses.push(response.clone());
return Ok(None);
};

let mut arguments_without_tool_answers = call.arguments.clone();

// Remove the special `tool_answers` argument, if any.
let mut tool_answers = arguments_without_tool_answers
.remove("tool_answers")
.map_or(Ok(IndexMap::new()), |v| match v {
Value::Object(v) => Ok(v.into_iter().collect()),
_ => Err(ToolError::ToolCallFailed(
"`tool_answers` argument must be an object".to_owned(),
)),
})?;

// Remove any pending questions for this tool call that are now
// answered.
let mut answered_questions = false;
if let Some(pending) = turn_state.pending_tool_call_questions.get_mut(&call.name) {
for question_id in tool_answers.keys() {
answered_questions = pending.shift_remove(question_id);
}

if pending.is_empty() {
turn_state
.pending_tool_call_questions
.shift_remove(&call.name);
}
}

let editor = cfg.editor.path();

self.tool_calls.push(call.clone());
Expand All @@ -130,11 +163,12 @@ impl StreamEventHandler {
tool_config.source(),
tool_config.description().map(str::to_owned),
tool_config.parameters().clone(),
tool_config.questions(),
mcp_client,
)
.await?;

if handler.render_tool_calls {
if handler.render_tool_calls && !answered_questions {
let (_raw, args) = match &tool_config.style().parameters {
ParametersStyle::Off => (false, ".".to_owned()),
ParametersStyle::Json => {
Expand All @@ -159,7 +193,7 @@ impl StreamEventHandler {
let cmd = command.clone().command();
let name = tool_config.source().tool_name();

match tool.format_args(name, &cmd, &call.arguments, &root)? {
match tool.format_args(name, &cmd, &arguments_without_tool_answers, root)? {
Ok(args) if args.is_empty() => (false, ".".to_owned()),
Ok(args) => (true, format!(":\n\n{args}")),
result @ Err(_) => {
Expand All @@ -180,16 +214,19 @@ impl StreamEventHandler {
write!(writer, "\n\n")?;
}

let mut answers = IndexMap::new();
loop {
match tool
.call(
call.id.clone(),
Value::Object(call.arguments.clone()),
&answers,
Value::Object(arguments_without_tool_answers.clone()),
&tool_answers,
turn_state
.pending_tool_call_questions
.get(&call.name)
.unwrap_or(&IndexSet::new()),
mcp_client,
tool_config.clone(),
&root,
root,
editor.as_deref(),
writer,
)
Expand Down Expand Up @@ -232,6 +269,61 @@ impl StreamEventHandler {
answer.clone()
} else if let Some(answer) = tool_config.get_answer(&question.id) {
answer.clone()
} else if matches!(
tool_config.question_target(&question.id),
Some(QuestionTarget::Assistant)
) {
// Keep track of pending questions for this tool call.
turn_state
.pending_tool_call_questions
.entry(call.name.clone())
.or_default()
.insert(question.id.clone());

// Ask the assistant to answer the question
let mut args = call.arguments.clone();
args.entry("tool_answers".to_owned())
.and_modify(|v| match v {
Value::Object(_) => {}
_ => *v = json!({}),
})
.or_insert_with(|| json!({}))
.as_object_mut()
.expect("tool_answers must be an object")
.insert(question.id.clone(), "<ANSWER HERE>".into());

let values = match question.answer_type {
AnswerType::Boolean => "any boolean type (true or false).".to_owned(),
AnswerType::Select { options } => indoc::formatdoc! {"
one of the following string values:

- {}
", options.join("\n- ")},
AnswerType::Text => "any string.".to_owned(),
};

let response = ToolCallResponse {
id: call.id.clone(),
result: Ok(indoc::formatdoc! {"
Tool requires additional input before it can complete the request:

{}

Please re-run the tool with the following arguments:

```json
{}
```

Where `<ANSWER HERE>` must be {values}
",
question.text,
serde_json::to_string_pretty(&args)?}),
};

self.tool_call_responses.push(response.clone());

return Ok(None);
} else if is_tty {
let (answer, persist_level) = prompt_user(&question, writer)?;

Expand All @@ -249,7 +341,7 @@ impl StreamEventHandler {
question.default.unwrap_or_default()
};

answers.insert(question.id.clone(), answer);
tool_answers.insert(question.id.clone(), answer);
}
Err(e) => return Err(e.into()),
}
Expand Down
44 changes: 39 additions & 5 deletions crates/jp_cli/src/cmd/query/turn.rs
Original file line number Diff line number Diff line change
@@ -1,22 +1,56 @@
use indexmap::IndexMap;
//! Utilities related to conversation turns.
//!
//! See [`TurnState`] for more details.

use indexmap::{IndexMap, IndexSet};
use serde_json::Value;

/// State that is persisted for the duration of a turn.
///
/// A turn is one or more request-response cycle(s) between the user and the
/// assistant.
///
/// A turn MUST be initiated by the user with a `ChatRequest`, which MUST be
/// followed by a `ChatResponse` and/or `ToolCallRequest` from the assistant.
///
/// After a `ToolCallRequest`, the user MUST return a `ToolCallResponse`, after
/// which the assistant MUST return a `ChatResponse` and/or a `ToolCallRequest`.
///
/// The turn CONTINUES as long as the assistant responds with at least one
/// `ToolCallRequest`.
///
/// The turn ENDS when the assistant responds with a `ChatResponse` but no
/// `ToolCallRequest`.
#[derive(Debug, Default)]
pub struct TurnState {
/// Tool answers that are instructed to be re-used for the duration of the
/// turn.
///
/// For example, if a tool `foo` asks a question `bar`, and the user
/// indicates that the same answer should be used during this turn, then
/// this map will contain a key `foo` with a value that contains a key
/// `bar` with the [`Value`] of the answer.
/// indicates that the same answer should be used during this turn, then
/// this map will contain a key `foo` with a value that contains a key `bar`
/// with the [`Value`] of the answer.
pub persisted_tool_answers: IndexMap<String, IndexMap<String, Value>>,

/// The number of times we've tried a request to the assistant.
///
/// This is used when the assistant returns an error that is retryable.
/// Every retry increments this counter, until a maximum number of retries
/// is reached.
/// is reached, after which the turn ends in an error.
pub request_count: usize,

/// A list of pending tool call questions.
///
/// The key is the [`ToolCallRequest::name`], the value is a list of
/// question IDs that have not yet been answered.
///
/// NOTE: In the future we could swap this to use `id` instead of `name`,
/// but that requires either the LLM to correctly return the ID of the
/// original tool call in the response, which might be fragile, or for us to
/// track more tool call state, which is a bit more complex. Returning the
/// name of the tool call is the simplest solution, and hasn't
/// caused any issues so far.
///
/// [`ToolCallRequest::name`]: jp_conversation::event::ToolCallRequest::name
pub pending_tool_call_questions: IndexMap<String, IndexSet<String>>,
}
Loading
Loading