From 6aad6f7c025ef5a56c277ef3dedf55d5955deabe Mon Sep 17 00:00:00 2001 From: Toby Date: Wed, 6 Nov 2024 11:00:37 +0000 Subject: [PATCH 1/3] feat: implement multiple named conversations - Replace single conversation file with a conversations directory structure - Add support for named conversations via -n/--name flag - Implement conversation name validation - Add test coverage for conversation management --- Cargo.lock | 13 +++++---- Cargo.toml | 1 + src/config/mod.rs | 12 ++++++-- src/config/prompt.rs | 68 +++++++++++++++++++++++++++++++++++++++----- src/main.rs | 47 ++++++++++++++++++++++++++++-- src/utils.rs | 12 ++++++++ 6 files changed, 135 insertions(+), 18 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 2b32973..837505b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -729,9 +729,9 @@ dependencies = [ [[package]] name = "regex" -version = "1.10.4" +version = "1.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c117dbdfde9c8308975b6a18d71f3f385c89461f7b3fb054288ecf2a2058ba4c" +checksum = "b544ef1b4eac5dc2db33ea63606ae9ffcfac26c1416a2806ae0bf5f56b201191" dependencies = [ "aho-corasick", "memchr", @@ -741,9 +741,9 @@ dependencies = [ [[package]] name = "regex-automata" -version = "0.4.6" +version = "0.4.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "86b83b8b9847f9bf95ef68afb0b8e6cdb80f498442f5179a29fad448fcc1eaea" +checksum = "368758f23274712b504848e9d5a6f010445cc8b87a7cdb4d7cbee666c1288da3" dependencies = [ "aho-corasick", "memchr", @@ -752,9 +752,9 @@ dependencies = [ [[package]] name = "regex-syntax" -version = "0.8.3" +version = "0.8.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "adad44e29e4c806119491a7f06f03de4d1af22c3a680dd47f1e6e179439d1f56" +checksum = "2b15c43186be67a4fd63bee50d0303afffcef381492ebe2c5d87f324e1b8815c" [[package]] name = "reqwest" @@ -987,6 +987,7 @@ dependencies = [ "env_logger", "glob", "log", + "regex", "reqwest", "serde", "serde_json", diff --git a/Cargo.toml b/Cargo.toml index e924ecc..cb92cf2 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -25,6 +25,7 @@ serde_json = "1" toml = "0" env_logger = "0" reqwest = { version = "0", default-features = false, features = ["http2", "json", "blocking", "multipart", "rustls-tls"] } +regex = "1.11.1" [dev-dependencies] tempfile = "3" diff --git a/src/config/mod.rs b/src/config/mod.rs index 358b904..a9f1b4d 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -5,7 +5,7 @@ use std::{path::PathBuf, process::Command}; use self::{ api::{api_keys_path, generate_api_keys_file, get_api_config}, - prompt::{generate_prompts_file, get_prompts, prompts_path}, + prompt::{generate_prompts_file, get_prompts, prompts_path, conversations_path}, }; use crate::utils::is_interactive; @@ -58,6 +58,12 @@ pub fn ensure_config_files() -> std::io::Result<()> { } }; + // Create the conversations directory if it doesn't exist + if !conversations_path().exists() { + std::fs::create_dir_all(conversations_path()) + .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, format!("Failed to create conversations directory: {}", e)))?; + } + Ok(()) } @@ -107,7 +113,7 @@ mod tests { config::{ api::{api_keys_path, default_timeout_seconds, Api, ApiConfig}, ensure_config_files, - prompt::{prompts_path, Prompt}, + prompt::{prompts_path, conversations_path, Prompt}, resolve_config_path, CUSTOM_CONFIG_ENV_VAR, DEFAULT_CONFIG_PATH, }, utils::IS_NONINTERACTIVE_ENV_VAR, @@ -175,6 +181,7 @@ mod tests { assert!(!api_keys_path.exists()); assert!(!prompts_path.exists()); + assert!(!conversations_path().exists()); let result = ensure_config_files(); @@ -187,6 +194,7 @@ mod tests { assert!(api_keys_path.exists()); assert!(prompts_path.exists()); + assert!(conversations_path().exists()); Ok(()) } diff --git a/src/config/prompt.rs b/src/config/prompt.rs index 699a87b..e38e3e1 100644 --- a/src/config/prompt.rs +++ b/src/config/prompt.rs @@ -9,7 +9,7 @@ use std::path::PathBuf; use crate::config::{api::Api, resolve_config_path}; const PROMPT_FILE: &str = "prompts.toml"; -const CONVERSATION_FILE: &str = "conversation.toml"; +const CONVERSATIONS_PATH: &str = "conversations/"; #[derive(Debug, Deserialize, Serialize, PartialEq, Clone)] pub struct Prompt { @@ -96,19 +96,31 @@ pub(super) fn prompts_path() -> PathBuf { resolve_config_path().join(PROMPT_FILE) } -pub fn conversation_file_path() -> PathBuf { - resolve_config_path().join(CONVERSATION_FILE) +// Get the path to the conversations directory +pub fn conversations_path() -> PathBuf { + resolve_config_path().join(CONVERSATIONS_PATH) } -pub fn get_last_conversation_as_prompt() -> Prompt { - let content = fs::read_to_string(conversation_file_path()).unwrap_or_else(|error| { +// Get the path to a specific conversation file +pub fn conversation_file_path(name: &str) -> PathBuf { + conversations_path().join(format!("{}.toml", name)) +} + +// Get the last conversation as a prompt, if it exists +pub fn get_last_conversation_as_prompt(name: &str) -> Option { + let file_path = conversation_file_path(name); + if !file_path.exists() { + return None; + } + + let content = fs::read_to_string(file_path).unwrap_or_else(|error| { panic!( "Could not read file {:?}, {:?}", - conversation_file_path(), + conversation_file_path(name), error ) }); - toml::from_str(&content).expect("failed to load the conversation file") + Some(toml::from_str(&content).expect("failed to load the conversation file")) } pub(super) fn generate_prompts_file() -> std::io::Result<()> { @@ -136,3 +148,45 @@ pub fn get_prompts() -> HashMap { .unwrap_or_else(|error| panic!("Could not read file {:?}, {:?}", prompts_path(), error)); toml::from_str(&content).expect("could not parse prompt file content") } + +#[cfg(test)] +mod tests { + use super::*; + use std::fs; + + #[test] + fn test_conversation_file_path() { + let name = "test_conversation"; + let file_path = conversation_file_path(name); + assert_eq!( + file_path.file_name().unwrap().to_str().unwrap(), + format!("{}.toml", name) + ); + assert_eq!(file_path.parent().unwrap(), conversations_path()); + } + + #[test] + fn test_get_last_conversation_as_prompt() { + let name = "test_conversation"; + let file_path = conversation_file_path(name); + let prompt = Prompt::default(); + + // Create a test conversation file + let toml_string = toml::to_string(&prompt).expect("Failed to serialize prompt"); + fs::write(&file_path, toml_string).expect("Failed to write test conversation file"); + + let loaded_prompt = get_last_conversation_as_prompt(name); + assert_eq!(loaded_prompt, Some(prompt)); + + // Clean up the test conversation file + fs::remove_file(&file_path).expect("Failed to remove test conversation file"); + } + + #[test] + fn test_get_last_conversation_as_prompt_missing_file() { + let name = "nonexistent_conversation"; + let loaded_prompt = get_last_conversation_as_prompt(name); + assert_eq!(loaded_prompt, None); + } + +} diff --git a/src/main.rs b/src/main.rs index 4c26682..2a52a3a 100644 --- a/src/main.rs +++ b/src/main.rs @@ -9,6 +9,7 @@ use crate::config::{ prompt::{conversation_file_path, get_last_conversation_as_prompt, get_prompts, Prompt}, }; use prompt_customization::customize_prompt; +use crate::utils::valid_conversation_name; use clap::{Args, Parser}; use log::debug; @@ -18,6 +19,7 @@ use std::io::{self, IsTerminal, Read, Write}; use text::process_input_with_request; const DEFAULT_PROMPT_NAME: &str = "default"; +const DEFAULT_CONVERSATION_NAME: &str = "default"; #[derive(Debug, Parser)] #[command( @@ -56,6 +58,9 @@ struct Cli { /// whether to repeat the input before the output, useful to extend instead of replacing #[arg(short, long)] repeat_input: bool, + /// conversation name + #[arg(short, long, value_parser = valid_conversation_name)] + name: Option, #[command(flatten)] prompt_params: PromptParams, } @@ -104,6 +109,7 @@ fn main() { } let args = Cli::parse(); + let name = args.name.as_deref().unwrap_or(DEFAULT_CONVERSATION_NAME); debug!("args: {:?}", args); @@ -118,14 +124,18 @@ fn main() { // if it doesn't use default prompt and treat that first arg as customization text get_default_and_or_custom_prompt(&args, &mut prompt_customizaton_text) } else { - prompt_customizaton_text = args.input_or_template_ref; + prompt_customizaton_text = args.input_or_template_ref.clone(); if args.input_if_template_ref.is_some() { panic!( "Invalid parameters, cannot provide a config ref when extending a conversation.\n\ Use `sc -e \".\"`" ); } - get_last_conversation_as_prompt() + + match get_last_conversation_as_prompt(name) { + Some(prompt) => prompt, + None => get_default_and_or_custom_prompt(&args, &mut prompt_customizaton_text), + } }; // if no text was piped, use the custom prompt as input @@ -149,7 +159,7 @@ fn main() { Ok(prompt) => { let toml_string = toml::to_string(&prompt).expect("Failed to serialize prompt after response."); - let mut file = fs::File::create(conversation_file_path()) + let mut file = fs::File::create(conversation_file_path(name)) .expect("Failed to create the conversation save file."); file.write_all(toml_string.as_bytes()) .expect("Failed to write to the conversation file."); @@ -204,3 +214,34 @@ fn get_default_and_or_custom_prompt( .expect(&prompt_not_found_error) } } + +#[cfg(test)] +mod tests { + + use super::*; + use crate::config::prompt::Prompt; + + #[test] + fn test_get_last_conversation_as_prompt_missing_file() { + let args = Cli { + input_or_template_ref: Some("test_input".to_string()), + input_if_template_ref: None, + extend_conversation: true, + repeat_input: false, + name: Some("nonexistent_conversation".to_string()), + prompt_params: PromptParams::default(), + }; + let mut prompt_customizaton_text = None; + let name = args.name.as_deref().unwrap_or(DEFAULT_CONVERSATION_NAME); + + let prompt = get_last_conversation_as_prompt(name); + + assert_eq!(prompt, None); + + let default_prompt = get_default_and_or_custom_prompt(&args, &mut prompt_customizaton_text); + assert_eq!(default_prompt, Prompt::default()); + assert_eq!(prompt_customizaton_text, Some("test_input".to_string())); + + } + +} diff --git a/src/utils.rs b/src/utils.rs index 1ce813a..5643f45 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -1,3 +1,5 @@ +use regex::Regex; + pub const IS_NONINTERACTIVE_ENV_VAR: &str = "SMARTCAT_NONINTERACTIVE"; /// clean error logging @@ -24,3 +26,13 @@ pub fn read_user_input() -> String { .expect("Failed to read line"); user_input.trim().to_string() } + +// Validate the conversation name +pub fn valid_conversation_name(s: &str) -> Result { + let re = Regex::new(r"^[a-zA-Z0-9_-]+$").unwrap(); + if re.is_match(s) { + Ok(s.to_string()) + } else { + Err(format!("Invalid name: {}", s)) + } +} From 9f918592c3590b66c06128e43ed092b5087de98f Mon Sep 17 00:00:00 2001 From: Toby Date: Wed, 6 Nov 2024 15:51:25 +0000 Subject: [PATCH 2/3] feat: maintain backwards compatibility with conversastion.toml - Use conversation.toml for storing latest conversation state - Remove unnecessary directory creation in config initialization - Update test coverage --- src/config/mod.rs | 12 +-- src/config/prompt.rs | 184 ++++++++++++++++++++++++++++++++++--------- src/main.rs | 142 ++++++++++++++++++++++++++------- src/utils.rs | 2 +- 4 files changed, 264 insertions(+), 76 deletions(-) diff --git a/src/config/mod.rs b/src/config/mod.rs index a9f1b4d..358b904 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -5,7 +5,7 @@ use std::{path::PathBuf, process::Command}; use self::{ api::{api_keys_path, generate_api_keys_file, get_api_config}, - prompt::{generate_prompts_file, get_prompts, prompts_path, conversations_path}, + prompt::{generate_prompts_file, get_prompts, prompts_path}, }; use crate::utils::is_interactive; @@ -58,12 +58,6 @@ pub fn ensure_config_files() -> std::io::Result<()> { } }; - // Create the conversations directory if it doesn't exist - if !conversations_path().exists() { - std::fs::create_dir_all(conversations_path()) - .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, format!("Failed to create conversations directory: {}", e)))?; - } - Ok(()) } @@ -113,7 +107,7 @@ mod tests { config::{ api::{api_keys_path, default_timeout_seconds, Api, ApiConfig}, ensure_config_files, - prompt::{prompts_path, conversations_path, Prompt}, + prompt::{prompts_path, Prompt}, resolve_config_path, CUSTOM_CONFIG_ENV_VAR, DEFAULT_CONFIG_PATH, }, utils::IS_NONINTERACTIVE_ENV_VAR, @@ -181,7 +175,6 @@ mod tests { assert!(!api_keys_path.exists()); assert!(!prompts_path.exists()); - assert!(!conversations_path().exists()); let result = ensure_config_files(); @@ -194,7 +187,6 @@ mod tests { assert!(api_keys_path.exists()); assert!(prompts_path.exists()); - assert!(conversations_path().exists()); Ok(()) } diff --git a/src/config/prompt.rs b/src/config/prompt.rs index e38e3e1..1ae56a8 100644 --- a/src/config/prompt.rs +++ b/src/config/prompt.rs @@ -9,7 +9,8 @@ use std::path::PathBuf; use crate::config::{api::Api, resolve_config_path}; const PROMPT_FILE: &str = "prompts.toml"; -const CONVERSATIONS_PATH: &str = "conversations/"; +const CONVERSATION_FILE: &str = "conversation.toml"; +const CONVERSATIONS_PATH: &str = "saved_conversations"; #[derive(Debug, Deserialize, Serialize, PartialEq, Clone)] pub struct Prompt { @@ -96,31 +97,66 @@ pub(super) fn prompts_path() -> PathBuf { resolve_config_path().join(PROMPT_FILE) } +pub fn conversation_file_path() -> PathBuf { + resolve_config_path().join(CONVERSATION_FILE) +} + // Get the path to the conversations directory pub fn conversations_path() -> PathBuf { resolve_config_path().join(CONVERSATIONS_PATH) } // Get the path to a specific conversation file -pub fn conversation_file_path(name: &str) -> PathBuf { +pub fn named_conversation_path(name: &str) -> PathBuf { conversations_path().join(format!("{}.toml", name)) } // Get the last conversation as a prompt, if it exists -pub fn get_last_conversation_as_prompt(name: &str) -> Option { - let file_path = conversation_file_path(name); - if !file_path.exists() { - return None; +pub fn get_last_conversation_as_prompt(name: Option<&str>) -> Option { + if let Some(name) = name { + let named_path = named_conversation_path(name); + if !named_path.exists() { + return None; + } + let content = fs::read_to_string(named_path) + .unwrap_or_else(|error| { + panic!( + "Could not read file {:?}, {:?}", + named_conversation_path(name), + error + ) + }); + Some(toml::from_str(&content).expect("failed to load the conversation file")) + } else { + let path = conversation_file_path(); + if !path.exists() { + return None; + } + let content = fs::read_to_string(path) + .unwrap_or_else(|error| { + panic!( + "Could not read file {:?}, {:?}", + conversation_file_path(), + error + ) + }); + Some(toml::from_str(&content).expect("failed to load the conversation file")) + } +} + +pub fn save_conversation(prompt: &Prompt, name: Option<&str>) -> std::io::Result<()> { + let toml_string = toml::to_string(prompt).expect("Failed to serialize prompt"); + + // Always save to conversation.toml + fs::write(conversation_file_path(), &toml_string)?; + + // If name is provided, also save to named conversation file + if let Some(name) = name { + fs::create_dir_all(conversations_path())?; + fs::write(named_conversation_path(name), &toml_string)?; } - let content = fs::read_to_string(file_path).unwrap_or_else(|error| { - panic!( - "Could not read file {:?}, {:?}", - conversation_file_path(name), - error - ) - }); - Some(toml::from_str(&content).expect("failed to load the conversation file")) + Ok(()) } pub(super) fn generate_prompts_file() -> std::io::Result<()> { @@ -153,40 +189,114 @@ pub fn get_prompts() -> HashMap { mod tests { use super::*; use std::fs; + use tempfile::tempdir; + use crate::config::prompt::Prompt; + use serial_test::serial; + + fn setup() -> tempfile::TempDir { + let temp_dir = tempdir().unwrap(); + std::env::set_var("SMARTCAT_CONFIG_PATH", temp_dir.path()); + temp_dir + } + + fn create_test_prompt() -> Prompt { + let mut prompt = Prompt::default(); + prompt.messages = vec![(Message::user("test"))]; + prompt + } #[test] - fn test_conversation_file_path() { - let name = "test_conversation"; - let file_path = conversation_file_path(name); - assert_eq!( - file_path.file_name().unwrap().to_str().unwrap(), - format!("{}.toml", name) - ); - assert_eq!(file_path.parent().unwrap(), conversations_path()); + #[serial] + fn test_get_and_save_default_conversation() { + let _temp_dir = setup(); + let test_prompt = create_test_prompt(); + + // Test saving conversation + save_conversation(&test_prompt, None).unwrap(); + assert!(conversation_file_path().exists()); + + // Test retrieving conversation + let loaded_prompt = get_last_conversation_as_prompt(None).unwrap(); + assert_eq!(loaded_prompt, test_prompt); } #[test] - fn test_get_last_conversation_as_prompt() { - let name = "test_conversation"; - let file_path = conversation_file_path(name); - let prompt = Prompt::default(); + #[serial] + fn test_get_and_save_named_conversation() { + let _temp_dir = setup(); + let test_prompt = create_test_prompt(); + let conv_name = "test_conversation"; - // Create a test conversation file - let toml_string = toml::to_string(&prompt).expect("Failed to serialize prompt"); - fs::write(&file_path, toml_string).expect("Failed to write test conversation file"); + // Test saving named conversation + save_conversation(&test_prompt, Some(conv_name)).unwrap(); + assert!(named_conversation_path(conv_name).exists()); + assert!(conversation_file_path().exists()); // Should also save to default location - let loaded_prompt = get_last_conversation_as_prompt(name); - assert_eq!(loaded_prompt, Some(prompt)); + // Test retrieving named conversation + let loaded_prompt = get_last_conversation_as_prompt(Some(conv_name)).unwrap(); + assert_eq!(loaded_prompt, test_prompt); + } - // Clean up the test conversation file - fs::remove_file(&file_path).expect("Failed to remove test conversation file"); + #[test] + #[serial] + fn test_nonexistent_conversation() { + let _temp_dir = setup(); + + // Test getting nonexistent default conversation + assert!(get_last_conversation_as_prompt(None).is_none()); + + // Test getting nonexistent named conversation + assert!(get_last_conversation_as_prompt(Some("nonexistent")).is_none()); + } + + #[test] + #[serial] + fn test_conversation_file_contents() { + let _temp_dir = setup(); + let test_prompt = create_test_prompt(); + let conv_name = "test_conversation"; + + // Save conversation + save_conversation(&test_prompt, Some(conv_name)).unwrap(); + + // Verify default and named files have identical content + let default_content = fs::read_to_string(conversation_file_path()).unwrap(); + let named_content = fs::read_to_string(named_conversation_path(conv_name)).unwrap(); + assert_eq!(default_content, named_content); + + // Verify content can be parsed back to original prompt + let parsed_prompt: Prompt = toml::from_str(&default_content).unwrap(); + assert_eq!(parsed_prompt, test_prompt); } #[test] - fn test_get_last_conversation_as_prompt_missing_file() { - let name = "nonexistent_conversation"; - let loaded_prompt = get_last_conversation_as_prompt(name); - assert_eq!(loaded_prompt, None); + #[serial] + fn test_generate_prompts_file() { + let _temp_dir = setup(); + + // Test file generation + generate_prompts_file().unwrap(); + assert!(prompts_path().exists()); + + // Verify file is valid TOML and contains expected content + let content = fs::read_to_string(prompts_path()).unwrap(); + let prompts: HashMap = toml::from_str(&content).unwrap(); + assert!(!prompts.is_empty()); } + #[test] + #[serial] + fn test_get_prompts() { + let _temp_dir = setup(); + + // Generate prompts file + generate_prompts_file().unwrap(); + + // Test loading prompts + let prompts = get_prompts(); + assert!(!prompts.is_empty()); + + // Verify at least one default prompt exists + assert!(prompts.contains_key("default")); + } } diff --git a/src/main.rs b/src/main.rs index 2a52a3a..6cd35c5 100644 --- a/src/main.rs +++ b/src/main.rs @@ -6,20 +6,18 @@ mod utils; use crate::config::{ api::Api, ensure_config_usable, - prompt::{conversation_file_path, get_last_conversation_as_prompt, get_prompts, Prompt}, + prompt::{get_last_conversation_as_prompt, save_conversation, get_prompts, Prompt}, }; use prompt_customization::customize_prompt; use crate::utils::valid_conversation_name; use clap::{Args, Parser}; use log::debug; -use std::fs; use std::io::{self, IsTerminal, Read, Write}; use text::process_input_with_request; const DEFAULT_PROMPT_NAME: &str = "default"; -const DEFAULT_CONVERSATION_NAME: &str = "default"; #[derive(Debug, Parser)] #[command( @@ -109,7 +107,6 @@ fn main() { } let args = Cli::parse(); - let name = args.name.as_deref().unwrap_or(DEFAULT_CONVERSATION_NAME); debug!("args: {:?}", args); @@ -119,12 +116,9 @@ fn main() { let is_piped = !stdin.is_terminal(); let mut prompt_customizaton_text: Option = None; - let prompt: Prompt = if !args.extend_conversation { - // try to get prompt matching the first arg and use second arg as customization text - // if it doesn't use default prompt and treat that first arg as customization text - get_default_and_or_custom_prompt(&args, &mut prompt_customizaton_text) - } else { + let prompt = if args.extend_conversation { prompt_customizaton_text = args.input_or_template_ref.clone(); + if args.input_if_template_ref.is_some() { panic!( "Invalid parameters, cannot provide a config ref when extending a conversation.\n\ @@ -132,10 +126,19 @@ fn main() { ); } - match get_last_conversation_as_prompt(name) { + match get_last_conversation_as_prompt(args.name.as_deref()) { Some(prompt) => prompt, - None => get_default_and_or_custom_prompt(&args, &mut prompt_customizaton_text), + None => { + if args.name.is_some() { + panic!("Named conversation does not exist: {}", args.name.unwrap()); + } + get_default_and_or_custom_prompt(&args, &mut prompt_customizaton_text) + } } + } else { + // try to get prompt matching the first arg and use second arg as customization text + // if it doesn't use default prompt and treat that first arg as customization text + get_default_and_or_custom_prompt(&args, &mut prompt_customizaton_text) }; // if no text was piped, use the custom prompt as input @@ -156,13 +159,9 @@ fn main() { debug!("{:?}", prompt); match process_input_with_request(prompt, input, &mut output, args.repeat_input) { - Ok(prompt) => { - let toml_string = - toml::to_string(&prompt).expect("Failed to serialize prompt after response."); - let mut file = fs::File::create(conversation_file_path(name)) - .expect("Failed to create the conversation save file."); - file.write_all(toml_string.as_bytes()) - .expect("Failed to write to the conversation file."); + Ok(new_prompt) => { + save_conversation(&new_prompt, args.name.as_deref()) + .expect("Failed to save conversation"); } Err(e) => { eprintln!("Error: {}", e); @@ -217,12 +216,28 @@ fn get_default_and_or_custom_prompt( #[cfg(test)] mod tests { - use super::*; - use crate::config::prompt::Prompt; + use crate::config::prompt::{Prompt, Message}; + use tempfile::tempdir; + use serial_test::serial; + + fn setup() -> tempfile::TempDir { + let temp_dir = tempdir().unwrap(); + std::env::set_var("SMARTCAT_CONFIG_PATH", temp_dir.path()); + temp_dir + } + + fn create_test_prompt() -> Prompt { + let mut prompt = Prompt::default(); + prompt.messages = vec![(Message::user("test"))]; + prompt + } #[test] - fn test_get_last_conversation_as_prompt_missing_file() { + #[serial] + fn test_cli_with_nonexistent_conversation() { + let _temp_dir = setup(); + let args = Cli { input_or_template_ref: Some("test_input".to_string()), input_if_template_ref: None, @@ -231,17 +246,88 @@ mod tests { name: Some("nonexistent_conversation".to_string()), prompt_params: PromptParams::default(), }; - let mut prompt_customizaton_text = None; - let name = args.name.as_deref().unwrap_or(DEFAULT_CONVERSATION_NAME); - let prompt = get_last_conversation_as_prompt(name); + // Test that getting a nonexistent conversation returns None + let prompt = get_last_conversation_as_prompt(args.name.as_deref()); + assert!(prompt.is_none()); + } - assert_eq!(prompt, None); + #[test] + #[serial] + fn test_cli_with_existing_conversation() { + let _temp_dir = setup(); + + // Create a test conversation + let test_prompt = create_test_prompt(); + save_conversation(&test_prompt, Some("test_conversation")).unwrap(); + + let args = Cli { + input_or_template_ref: Some("test_input".to_string()), + input_if_template_ref: None, + extend_conversation: true, + repeat_input: false, + name: Some("test_conversation".to_string()), + prompt_params: PromptParams::default(), + }; + + // Test retrieving the saved conversation + let prompt = get_last_conversation_as_prompt(args.name.as_deref()); + assert!(prompt.is_some()); + assert_eq!(prompt.unwrap(), test_prompt); + } + + #[test] + #[serial] + fn test_valid_conversation_name() { + assert!(valid_conversation_name("valid_name").is_ok()); + assert!(valid_conversation_name("valid-name").is_ok()); + assert!(valid_conversation_name("valid123").is_ok()); + assert!(valid_conversation_name("VALID_NAME").is_ok()); + + assert!(valid_conversation_name("invalid name").is_err()); + assert!(valid_conversation_name("invalid/name").is_err()); + assert!(valid_conversation_name("invalid.name").is_err()); + assert!(valid_conversation_name("").is_err()); + } + + #[test] + #[serial] + fn test_conversation_persistence() { + let _temp_dir = setup(); + let test_prompt = create_test_prompt(); + + // Test saving and loading default conversation + save_conversation(&test_prompt, None).unwrap(); + let loaded_prompt = get_last_conversation_as_prompt(None); + assert!(loaded_prompt.is_some()); + assert_eq!(loaded_prompt.unwrap(), test_prompt); + + // Test saving and loading named conversation + save_conversation(&test_prompt, Some("test_conv")).unwrap(); + let loaded_named_prompt = get_last_conversation_as_prompt(Some("test_conv")); + assert!(loaded_named_prompt.is_some()); + assert_eq!(loaded_named_prompt.unwrap(), test_prompt); + } + + #[test] + #[serial] + fn test_default_prompt_fallback() { + let _temp_dir = setup(); + let args = Cli { + input_or_template_ref: Some("test_input".to_string()), + input_if_template_ref: None, + extend_conversation: true, + repeat_input: false, + name: None, + prompt_params: PromptParams::default(), + }; - let default_prompt = get_default_and_or_custom_prompt(&args, &mut prompt_customizaton_text); - assert_eq!(default_prompt, Prompt::default()); - assert_eq!(prompt_customizaton_text, Some("test_input".to_string())); + let prompt = get_last_conversation_as_prompt(args.name.as_deref()); + assert!(prompt.is_none()); // Should be None when no conversation exists + // Verify the prompt customization text is set correctly + let prompt_customization_text = args.input_or_template_ref; + assert_eq!(prompt_customization_text, Some("test_input".to_string())); } } diff --git a/src/utils.rs b/src/utils.rs index 5643f45..fb676a4 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -33,6 +33,6 @@ pub fn valid_conversation_name(s: &str) -> Result { if re.is_match(s) { Ok(s.to_string()) } else { - Err(format!("Invalid name: {}", s)) + Err(format!("Invalid conversation name: {}. Use only letters, numbers, underscores, and hyphens.", s)) } } From dd479f6c52bb309163d73b99425ce620137b1836 Mon Sep 17 00:00:00 2001 From: Toby Date: Thu, 7 Nov 2024 09:01:17 +0000 Subject: [PATCH 3/3] fix: trim whitespace from conversation names before validation --- src/utils.rs | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/utils.rs b/src/utils.rs index fb676a4..2010a3a 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -29,10 +29,11 @@ pub fn read_user_input() -> String { // Validate the conversation name pub fn valid_conversation_name(s: &str) -> Result { + let trimmed = s.trim(); let re = Regex::new(r"^[a-zA-Z0-9_-]+$").unwrap(); - if re.is_match(s) { - Ok(s.to_string()) + if re.is_match(trimmed) { + Ok(trimmed.to_string()) } else { - Err(format!("Invalid conversation name: {}. Use only letters, numbers, underscores, and hyphens.", s)) + Err(format!("Invalid conversation name: {}. Use only letters, numbers, underscores, and hyphens.", trimmed)) } }