diff --git a/crates/jp_cli/src/cmd/query.rs b/crates/jp_cli/src/cmd/query.rs index 97a9ae75..309245f2 100644 --- a/crates/jp_cli/src/cmd/query.rs +++ b/crates/jp_cli/src/cmd/query.rs @@ -943,7 +943,7 @@ impl Query { .handle_tool_call( cfg, mcp_client, - root, + &root, is_tty, turn_state, request, diff --git a/crates/jp_cli/src/cmd/query/event.rs b/crates/jp_cli/src/cmd/query/event.rs index d2f411f6..c8d7f400 100644 --- a/crates/jp_cli/src/cmd/query/event.rs +++ b/crates/jp_cli/src/cmd/query/event.rs @@ -1,11 +1,11 @@ -use std::{env, fmt::Write, fs, path::PathBuf, time}; +use std::{env, fmt::Write, fs, path::Path, time}; use crossterm::style::Stylize as _; -use indexmap::IndexMap; +use indexmap::{IndexMap, IndexSet}; use jp_config::{ AppConfig, conversation::tool::{ - ToolConfigWithDefaults, + QuestionTarget, ToolConfigWithDefaults, style::{InlineResults, LinkStyle, ParametersStyle, TruncateLines}, }, style::{ @@ -21,7 +21,7 @@ use jp_llm::{ToolError, tool::ToolDefinition}; use jp_printer::PrinterWriter; use jp_term::osc::hyperlink; use jp_tool::{AnswerType, Question}; -use serde_json::Value; +use serde_json::{Value, json}; use super::{ResponseHandler, turn::TurnState}; use crate::Error; @@ -111,7 +111,7 @@ impl StreamEventHandler { &mut self, cfg: &AppConfig, mcp_client: &jp_mcp::Client, - root: PathBuf, + root: &Path, is_tty: bool, turn_state: &mut TurnState, call: ToolCallRequest, @@ -119,9 +119,42 @@ impl StreamEventHandler { mut writer: PrinterWriter<'_>, ) -> Result, Error> { let Some(tool_config) = cfg.conversation.tools.get(&call.name) else { - return Err(Error::NotFound("tool", call.name.clone())); + let response = ToolCallResponse { + id: call.id.clone(), + result: Err(format!("Tool '{}' not found.", call.name)), + }; + + self.tool_call_responses.push(response.clone()); + return Ok(None); }; + let mut arguments_without_tool_answers = call.arguments.clone(); + + // Remove the special `tool_answers` argument, if any. + let mut tool_answers = arguments_without_tool_answers + .remove("tool_answers") + .map_or(Ok(IndexMap::new()), |v| match v { + Value::Object(v) => Ok(v.into_iter().collect()), + _ => Err(ToolError::ToolCallFailed( + "`tool_answers` argument must be an object".to_owned(), + )), + })?; + + // Remove any pending questions for this tool call that are now + // answered. + let mut answered_questions = false; + if let Some(pending) = turn_state.pending_tool_call_questions.get_mut(&call.name) { + for question_id in tool_answers.keys() { + answered_questions = pending.shift_remove(question_id); + } + + if pending.is_empty() { + turn_state + .pending_tool_call_questions + .shift_remove(&call.name); + } + } + let editor = cfg.editor.path(); self.tool_calls.push(call.clone()); @@ -130,11 +163,12 @@ impl StreamEventHandler { tool_config.source(), tool_config.description().map(str::to_owned), tool_config.parameters().clone(), + tool_config.questions(), mcp_client, ) .await?; - if handler.render_tool_calls { + if handler.render_tool_calls && !answered_questions { let (_raw, args) = match &tool_config.style().parameters { ParametersStyle::Off => (false, ".".to_owned()), ParametersStyle::Json => { @@ -159,7 +193,7 @@ impl StreamEventHandler { let cmd = command.clone().command(); let name = tool_config.source().tool_name(); - match tool.format_args(name, &cmd, &call.arguments, &root)? { + match tool.format_args(name, &cmd, &arguments_without_tool_answers, root)? { Ok(args) if args.is_empty() => (false, ".".to_owned()), Ok(args) => (true, format!(":\n\n{args}")), result @ Err(_) => { @@ -180,16 +214,19 @@ impl StreamEventHandler { write!(writer, "\n\n")?; } - let mut answers = IndexMap::new(); loop { match tool .call( call.id.clone(), - Value::Object(call.arguments.clone()), - &answers, + Value::Object(arguments_without_tool_answers.clone()), + &tool_answers, + turn_state + .pending_tool_call_questions + .get(&call.name) + .unwrap_or(&IndexSet::new()), mcp_client, tool_config.clone(), - &root, + root, editor.as_deref(), writer, ) @@ -232,6 +269,61 @@ impl StreamEventHandler { answer.clone() } else if let Some(answer) = tool_config.get_answer(&question.id) { answer.clone() + } else if matches!( + tool_config.question_target(&question.id), + Some(QuestionTarget::Assistant) + ) { + // Keep track of pending questions for this tool call. + turn_state + .pending_tool_call_questions + .entry(call.name.clone()) + .or_default() + .insert(question.id.clone()); + + // Ask the assistant to answer the question + let mut args = call.arguments.clone(); + args.entry("tool_answers".to_owned()) + .and_modify(|v| match v { + Value::Object(_) => {} + _ => *v = json!({}), + }) + .or_insert_with(|| json!({})) + .as_object_mut() + .expect("tool_answers must be an object") + .insert(question.id.clone(), "".into()); + + let values = match question.answer_type { + AnswerType::Boolean => "any boolean type (true or false).".to_owned(), + AnswerType::Select { options } => indoc::formatdoc! {" + one of the following string values: + + - {} + ", options.join("\n- ")}, + AnswerType::Text => "any string.".to_owned(), + }; + + let response = ToolCallResponse { + id: call.id.clone(), + result: Ok(indoc::formatdoc! {" + Tool requires additional input before it can complete the request: + + {} + + Please re-run the tool with the following arguments: + + ```json + {} + ``` + + Where `` must be {values} + ", + question.text, + serde_json::to_string_pretty(&args)?}), + }; + + self.tool_call_responses.push(response.clone()); + + return Ok(None); } else if is_tty { let (answer, persist_level) = prompt_user(&question, writer)?; @@ -249,7 +341,7 @@ impl StreamEventHandler { question.default.unwrap_or_default() }; - answers.insert(question.id.clone(), answer); + tool_answers.insert(question.id.clone(), answer); } Err(e) => return Err(e.into()), } diff --git a/crates/jp_cli/src/cmd/query/turn.rs b/crates/jp_cli/src/cmd/query/turn.rs index f6f59a7e..641ee047 100644 --- a/crates/jp_cli/src/cmd/query/turn.rs +++ b/crates/jp_cli/src/cmd/query/turn.rs @@ -1,22 +1,56 @@ -use indexmap::IndexMap; +//! Utilities related to conversation turns. +//! +//! See [`TurnState`] for more details. + +use indexmap::{IndexMap, IndexSet}; use serde_json::Value; /// State that is persisted for the duration of a turn. +/// +/// A turn is one or more request-response cycle(s) between the user and the +/// assistant. +/// +/// A turn MUST be initiated by the user with a `ChatRequest`, which MUST be +/// followed by a `ChatResponse` and/or `ToolCallRequest` from the assistant. +/// +/// After a `ToolCallRequest`, the user MUST return a `ToolCallResponse`, after +/// which the assistant MUST return a `ChatResponse` and/or a `ToolCallRequest`. +/// +/// The turn CONTINUES as long as the assistant responds with at least one +/// `ToolCallRequest`. +/// +/// The turn ENDS when the assistant responds with a `ChatResponse` but no +/// `ToolCallRequest`. #[derive(Debug, Default)] pub struct TurnState { /// Tool answers that are instructed to be re-used for the duration of the /// turn. /// /// For example, if a tool `foo` asks a question `bar`, and the user - /// indicates that the same answer should be used during this turn, then - /// this map will contain a key `foo` with a value that contains a key - /// `bar` with the [`Value`] of the answer. + /// indicates that the same answer should be used during this turn, then + /// this map will contain a key `foo` with a value that contains a key `bar` + /// with the [`Value`] of the answer. pub persisted_tool_answers: IndexMap>, /// The number of times we've tried a request to the assistant. /// /// This is used when the assistant returns an error that is retryable. /// Every retry increments this counter, until a maximum number of retries - /// is reached. + /// is reached, after which the turn ends in an error. pub request_count: usize, + + /// A list of pending tool call questions. + /// + /// The key is the [`ToolCallRequest::name`], the value is a list of + /// question IDs that have not yet been answered. + /// + /// NOTE: In the future we could swap this to use `id` instead of `name`, + /// but that requires either the LLM to correctly return the ID of the + /// original tool call in the response, which might be fragile, or for us to + /// track more tool call state, which is a bit more complex. Returning the + /// name of the tool call is the simplest solution, and hasn't + /// caused any issues so far. + /// + /// [`ToolCallRequest::name`]: jp_conversation::event::ToolCallRequest::name + pub pending_tool_call_questions: IndexMap>, } diff --git a/crates/jp_config/src/conversation/tool.rs b/crates/jp_config/src/conversation/tool.rs index 83b26503..84d99f9e 100644 --- a/crates/jp_config/src/conversation/tool.rs +++ b/crates/jp_config/src/conversation/tool.rs @@ -221,23 +221,13 @@ pub struct ToolConfig { #[setting(nested)] pub style: Option, - /// Automated responses to tool questions. - /// - /// This allows configuring predefined answers to questions that the tool - /// may ask during execution (e.g., "overwrite existing file?"). When an - /// answer is configured for a specific question ID, the tool will use it - /// automatically instead of prompting the user interactively. + /// Configuration for questions that the tool may ask during execution. /// /// Question IDs are defined by the tool implementation and should be /// documented by the tool. For example, `fs_create_file` uses /// `overwrite_file` when a file already exists. - // TODO: We should add an enumeration of possible options: - // - // - Fixed answer - // - Prompt once per turn - // - Prompt once per conversation - #[setting(default = IndexMap::new())] - pub answers: IndexMap, + #[setting(nested, merge = merge_nested_indexmap)] + pub questions: IndexMap, } impl AssignKeyValue for PartialToolConfig { @@ -252,7 +242,7 @@ impl AssignKeyValue for PartialToolConfig { "run" => self.run = kv.try_some_from_str()?, "result" => self.result = kv.try_some_from_str()?, _ if kv.p("style") => self.style.assign(kv)?, - "answers" => self.answers = kv.try_object()?, + "questions" => self.questions = kv.try_object()?, _ => return missing_key(&kv), } @@ -287,21 +277,23 @@ impl PartialConfigDelta for PartialToolConfig { run: delta_opt(self.run.as_ref(), next.run), result: delta_opt(self.result.as_ref(), next.result), style: delta_opt_partial(self.style.as_ref(), next.style), - answers: match (&self.answers, next.answers) { - (Some(prev), Some(next)) => Some( - next.into_iter() - .filter_map(|(k, next)| { - let prev_val = prev.get(&k); - if prev_val.is_some_and(|prev| prev == &next) { - return None; - } - - Some((k, next)) - }) - .collect(), - ), - (_, next) => next, - }, + questions: next + .questions + .into_iter() + .filter_map(|(k, next)| { + let prev = self.questions.get(&k); + if prev.is_some_and(|prev| prev == &next) { + return None; + } + + let next = match prev { + Some(prev) => prev.delta(next), + None => next, + }; + + Some((k, next)) + }) + .collect(), } } } @@ -323,11 +315,11 @@ impl ToPartial for ToolConfig { run: partial_opts(self.run.as_ref(), defaults.run), result: partial_opts(self.result.as_ref(), defaults.result), style: partial_opt_config(self.style.as_ref(), defaults.style), - answers: if self.answers.is_empty() { - defaults.answers - } else { - Some(self.answers.clone()) - }, + questions: self + .questions + .iter() + .map(|(k, v)| (k.clone(), v.to_partial())) + .collect(), } } } @@ -975,16 +967,85 @@ impl ToolConfigWithDefaults { self.tool.style.as_ref().unwrap_or(&self.defaults.style) } + /// Return the questions configuration of the tool. + #[must_use] + pub const fn questions(&self) -> &IndexMap { + &self.tool.questions + } + + /// Return the question target for the given question ID. + #[must_use] + pub fn question_target(&self, question_id: &str) -> Option { + self.tool.questions.get(question_id).map(|q| q.target) + } + /// Get an automated answer for a question. /// /// Returns the configured answer if one exists for the given question ID, /// otherwise returns `None`. #[must_use] pub fn get_answer(&self, question_id: &str) -> Option<&Value> { - self.tool.answers.get(question_id) + self.tool + .questions + .get(question_id) + .and_then(|q| q.answer.as_ref()) + } +} + +/// Question configuration. +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Config)] +#[config(rename_all = "snake_case")] +pub struct QuestionConfig { + /// The target of the question. + /// + /// This determines whether the question is asked interactively to the user, + /// or sent to the assistant to be answered. + pub target: QuestionTarget, + + /// The fixed answer to the question. + /// + /// If this is set, the question will not be presented to the target, but + /// will always be answered with the given value. + // TODO: We should add an enumeration of possible options: + // + // - Fixed answer + // - Prompt once per turn + // - Prompt once per conversation + pub answer: Option, +} + +impl PartialConfigDelta for PartialQuestionConfig { + fn delta(&self, next: Self) -> Self { + Self { + target: delta_opt(self.target.as_ref(), next.target), + answer: delta_opt(self.answer.as_ref(), next.answer), + } + } +} + +impl ToPartial for QuestionConfig { + fn to_partial(&self) -> Self::Partial { + let defaults = Self::Partial::default(); + + Self::Partial { + target: partial_opt(&self.target, defaults.target), + answer: partial_opts(self.answer.as_ref(), defaults.answer), + } } } +/// The target of a question. +#[derive(Debug, Clone, Copy, PartialEq, Default, Serialize, Deserialize, ConfigEnum)] +#[serde(rename_all = "snake_case")] +pub enum QuestionTarget { + /// Ask the question to the user. + #[default] + User, + + /// Ask the question to the assistant. + Assistant, +} + #[cfg(test)] mod tests { use assert_matches::assert_matches; diff --git a/crates/jp_conversation/src/event/tool_call.rs b/crates/jp_conversation/src/event/tool_call.rs index 48b71d72..4a84d0f3 100644 --- a/crates/jp_conversation/src/event/tool_call.rs +++ b/crates/jp_conversation/src/event/tool_call.rs @@ -1,13 +1,13 @@ //! See [`ToolCallRequest`] and [`ToolCallResponse`]. -use serde::{Deserialize, Deserializer, Serialize, Serializer}; +use serde::{Deserialize, Deserializer, Serialize, Serializer, ser::SerializeStruct as _}; use serde_json::{Map, Value}; /// A tool call request event - requesting execution of a tool. /// /// This event is typically triggered by the assistant as part of its response, /// but can also be triggered automatically by the client. -#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +#[derive(Debug, Clone, PartialEq, Eq)] pub struct ToolCallRequest { /// Unique identifier for this tool call pub id: String, @@ -16,7 +16,6 @@ pub struct ToolCallRequest { pub name: String, /// Arguments to pass to the tool - #[serde(with = "jp_serde::repr::base64_json_map")] pub arguments: Map, } @@ -32,6 +31,75 @@ impl ToolCallRequest { } } +impl Serialize for ToolCallRequest { + fn serialize(&self, serializer: Ser) -> Result + where + Ser: Serializer, + { + #[derive(Serialize)] + #[serde(transparent)] + struct Wrapper<'a>( + #[serde(with = "jp_serde::repr::base64_json_map")] &'a Map, + ); + + let mut arguments = self.arguments.clone(); + let tool_answers = arguments + .remove("tool_answers") + .unwrap_or_default() + .as_object() + .cloned() + .unwrap_or_default(); + + let mut size_hint = 3; + if !tool_answers.is_empty() { + size_hint += 1; + } + + let mut state = serializer.serialize_struct("ToolCallRequest", size_hint)?; + + state.serialize_field("id", &self.id)?; + state.serialize_field("name", &self.name)?; + state.serialize_field("arguments", &Wrapper(&arguments))?; + + if !tool_answers.is_empty() { + state.serialize_field("tool_answers", &Wrapper(&tool_answers))?; + } + + state.end() + } +} + +impl<'de> Deserialize<'de> for ToolCallRequest { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + #[derive(Deserialize)] + #[allow(clippy::allow_attributes, clippy::missing_docs_in_private_items)] + struct Helper { + id: String, + name: String, + #[serde(default, with = "jp_serde::repr::base64_json_map")] + arguments: Map, + #[serde(default, with = "jp_serde::repr::base64_json_map")] + tool_answers: Map, + } + + let mut helper = Helper::deserialize(deserializer)?; + + helper.arguments.insert( + "tool_answers".to_owned(), + Value::Object(helper.tool_answers), + ); + + Ok(Self { + id: helper.id, + name: helper.name, + arguments: helper.arguments, + }) + } +} + /// A tool call response event - the result of executing a tool. /// /// This event MUST be in response to a `ToolCallRequest` event, with a matching @@ -65,7 +133,7 @@ impl ToolCallResponse { } } -// Custom serialization to maintain backward compatibility with the JSON format +// Custom serialization to make it easier to recognize errors. impl Serialize for ToolCallResponse { fn serialize(&self, serializer: S) -> Result where @@ -94,7 +162,7 @@ impl Serialize for ToolCallResponse { } } -// Custom deserialization to maintain backward compatibility with the JSON format +// Custom deserialization to make it easier to recognize errors. impl<'de> Deserialize<'de> for ToolCallResponse { fn deserialize(deserializer: D) -> Result where diff --git a/crates/jp_llm/src/query/structured.rs b/crates/jp_llm/src/query/structured.rs index ec49505d..0383b0e6 100644 --- a/crates/jp_llm/src/query/structured.rs +++ b/crates/jp_llm/src/query/structured.rs @@ -171,6 +171,7 @@ impl StructuredQuery { name: SCHEMA_TOOL_NAME.to_owned(), description: Some(description), parameters, + include_tool_answers_parameter: false, }) } diff --git a/crates/jp_llm/src/test.rs b/crates/jp_llm/src/test.rs index 2c60ef5c..158d2c5f 100644 --- a/crates/jp_llm/src/test.rs +++ b/crates/jp_llm/src/test.rs @@ -2,12 +2,14 @@ use std::{panic, sync::Arc}; use futures::TryStreamExt as _; use jp_config::{ - AppConfig, Config as _, PartialAppConfig, ToPartial as _, + AppConfig, PartialAppConfig, ToPartial as _, assistant::tool_choice::ToolChoice, - conversation::tool::{RunMode, ToolParameterConfig}, + conversation::tool::ToolParameterConfig, model::{ - id::{ModelIdConfig, Name, PartialModelIdConfig, PartialModelIdOrAliasConfig, ProviderId}, - parameters::{PartialCustomReasoningConfig, PartialReasoningConfig, ReasoningEffort}, + id::{ModelIdConfig, ModelIdOrAliasConfig, Name, PartialModelIdOrAliasConfig, ProviderId}, + parameters::{ + PartialCustomReasoningConfig, PartialReasoningConfig, ReasoningConfig, ReasoningEffort, + }, }, providers::llm::LlmProviderConfig, }; @@ -85,22 +87,16 @@ impl TestRequest { model: test_model_details(provider), query: ChatQuery { thread: ThreadBuilder::new() - .with_events( - ConversationStream::new({ - let mut cfg = PartialAppConfig::empty(); - cfg.conversation.tools.defaults.run = Some(RunMode::Ask); - cfg.assistant.model.parameters.reasoning = - Some(PartialReasoningConfig::Off); - cfg.assistant.model.id = PartialModelIdConfig { - provider: Some(provider), - name: Some("test".parse().unwrap()), - } - .into(); - - AppConfig::from_partial(cfg, vec![]).unwrap().into() - }) - .with_created_at(utc_datetime!(2020-01-01 0:00)), - ) + .with_events({ + let mut config = AppConfig::new_test(); + config.assistant.model.parameters.reasoning = Some(ReasoningConfig::Off); + config.assistant.model.id = ModelIdOrAliasConfig::Id(ModelIdConfig { + provider, + name: "test".parse().unwrap(), + }); + ConversationStream::new(config.into()) + .with_created_at(utc_datetime!(2020-01-01 0:00)) + }) .build() .unwrap(), tools: vec![], @@ -118,20 +114,15 @@ impl TestRequest { query: StructuredQuery::new( true.into(), ThreadBuilder::new() - .with_events( - ConversationStream::new({ - let mut cfg = PartialAppConfig::empty(); - cfg.conversation.tools.defaults.run = Some(RunMode::Ask); - cfg.assistant.model.id = PartialModelIdConfig { - provider: Some(provider), - name: Some("test".parse().unwrap()), - } - .into(); - - AppConfig::from_partial(cfg, vec![]).unwrap().into() - }) - .with_created_at(datetime!(2020-01-01 0:00 utc)), - ) + .with_events({ + let mut config = AppConfig::new_test(); + config.assistant.model.id = ModelIdOrAliasConfig::Id(ModelIdConfig { + provider, + name: "test".parse().unwrap(), + }); + ConversationStream::new(config.into()) + .with_created_at(utc_datetime!(2020-01-01 0:00)) + }) .build() .unwrap(), ), @@ -253,6 +244,7 @@ impl TestRequest { .into_iter() .map(|(k, v)| (k.to_owned(), v)) .collect(), + include_tool_answers_parameter: false, }); } diff --git a/crates/jp_llm/src/tool.rs b/crates/jp_llm/src/tool.rs index 2fcfba60..0ad9f4d4 100644 --- a/crates/jp_llm/src/tool.rs +++ b/crates/jp_llm/src/tool.rs @@ -1,10 +1,10 @@ use std::{fmt::Write, path::Path, sync::Arc}; use crossterm::style::Stylize as _; -use indexmap::IndexMap; +use indexmap::{IndexMap, IndexSet}; use jp_config::conversation::tool::{ - OneOrManyTypes, ResultMode, RunMode, ToolCommandConfig, ToolConfigWithDefaults, - ToolParameterConfig, ToolSource, item::ToolParameterItemConfig, + OneOrManyTypes, QuestionConfig, QuestionTarget, ResultMode, RunMode, ToolCommandConfig, + ToolConfigWithDefaults, ToolParameterConfig, ToolSource, item::ToolParameterItemConfig, }; use jp_conversation::event::ToolCallResponse; use jp_inquire::{InlineOption, InlineSelect}; @@ -13,7 +13,7 @@ use jp_mcp::{ id::{McpServerId, McpToolId}, }; use jp_printer::PrinterWriter; -use jp_tool::Outcome; +use jp_tool::{Action, Outcome}; use minijinja::Environment; use serde_json::{Map, Value, json}; use tracing::{error, info, trace}; @@ -32,6 +32,13 @@ pub struct ToolDefinition { pub name: String, pub description: Option, pub parameters: IndexMap, + + /// Whether the tool should include the `tool_answers` parameter. + /// + /// This is `true` for any tool that has its questions configured such that + /// at least one question has to be answered by the assistant instead of the + /// user. + pub include_tool_answers_parameter: bool, } impl ToolDefinition { @@ -40,6 +47,7 @@ impl ToolDefinition { source: &ToolSource, description: Option, parameters: IndexMap, + questions: &IndexMap, mcp_client: &jp_mcp::Client, ) -> Result { match &source { @@ -47,6 +55,7 @@ impl ToolDefinition { name.to_owned(), description, parameters, + questions, )), ToolSource::Mcp { server, tool } => { mcp_tool_definition( @@ -81,7 +90,7 @@ impl ToolDefinition { "arguments": arguments, }, "context": { - "format_parameters": true, + "action": Action::FormatArguments, "root": root.to_string_lossy(), }, }); @@ -94,6 +103,7 @@ impl ToolDefinition { id: String, mut arguments: Value, answers: &IndexMap, + pending_questions: &IndexSet, mcp_client: &jp_mcp::Client, config: ToolConfigWithDefaults, root: &Path, @@ -104,7 +114,7 @@ impl ToolDefinition { // If the tool call has answers to provide to the tool, it means the // tool already ran once, and we should not ask for confirmation again. - let run_mode = if answers.is_empty() { + let run_mode = if pending_questions.is_empty() { config.run() } else { RunMode::Unattended @@ -181,6 +191,7 @@ impl ToolDefinition { "answers": answers, }, "context": { + "action": Action::Run, "root": root.to_string_lossy().into_owned(), }, }); @@ -237,11 +248,25 @@ impl ToolDefinition { /// Return a map of parameter names to JSON schemas. #[must_use] pub fn to_parameters_map(&self) -> Map { - self.parameters + let mut map = self + .parameters .clone() .into_iter() .map(|(k, v)| (k, v.to_json_schema())) - .collect::>() + .collect::>(); + + if self.include_tool_answers_parameter { + map.insert( + "tool_answers".to_owned(), + json!({ + "type": ["object", "null"], + "additionalProperties": true, + "description": "Answers to the tool's questions. This should only be used if explicitly requested by the user.", + }), + ); + } + + map } /// Return a JSON schema for the parameters of the tool. @@ -254,7 +279,7 @@ impl ToolDefinition { .map(|(k, _)| k.clone()) .collect::>(); - serde_json::json!({ + json!({ "type": "object", "properties": self.to_parameters_map(), "additionalProperties": false, @@ -714,6 +739,7 @@ pub async fn tool_definitions( config.source(), config.description().map(str::to_owned), config.parameters().clone(), + config.questions(), mcp_client, ) .await?, @@ -727,11 +753,17 @@ fn local_tool_definition( name: String, description: Option, parameters: IndexMap, + questions: &IndexMap, ) -> ToolDefinition { + let include_tool_answers_parameter = questions + .iter() + .any(|(_, v)| v.target == QuestionTarget::Assistant); + ToolDefinition { name, description, parameters, + include_tool_answers_parameter, } } @@ -895,6 +927,7 @@ async fn mcp_tool_definition( name: name.to_owned(), description, parameters: params, + include_tool_answers_parameter: false, }) }