diff --git a/Cargo.lock b/Cargo.lock index 2b32973..837505b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -729,9 +729,9 @@ dependencies = [ [[package]] name = "regex" -version = "1.10.4" +version = "1.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c117dbdfde9c8308975b6a18d71f3f385c89461f7b3fb054288ecf2a2058ba4c" +checksum = "b544ef1b4eac5dc2db33ea63606ae9ffcfac26c1416a2806ae0bf5f56b201191" dependencies = [ "aho-corasick", "memchr", @@ -741,9 +741,9 @@ dependencies = [ [[package]] name = "regex-automata" -version = "0.4.6" +version = "0.4.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "86b83b8b9847f9bf95ef68afb0b8e6cdb80f498442f5179a29fad448fcc1eaea" +checksum = "368758f23274712b504848e9d5a6f010445cc8b87a7cdb4d7cbee666c1288da3" dependencies = [ "aho-corasick", "memchr", @@ -752,9 +752,9 @@ dependencies = [ [[package]] name = "regex-syntax" -version = "0.8.3" +version = "0.8.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "adad44e29e4c806119491a7f06f03de4d1af22c3a680dd47f1e6e179439d1f56" +checksum = "2b15c43186be67a4fd63bee50d0303afffcef381492ebe2c5d87f324e1b8815c" [[package]] name = "reqwest" @@ -987,6 +987,7 @@ dependencies = [ "env_logger", "glob", "log", + "regex", "reqwest", "serde", "serde_json", diff --git a/Cargo.toml b/Cargo.toml index e924ecc..cb92cf2 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -25,6 +25,7 @@ serde_json = "1" toml = "0" env_logger = "0" reqwest = { version = "0", default-features = false, features = ["http2", "json", "blocking", "multipart", "rustls-tls"] } +regex = "1.11.1" [dev-dependencies] tempfile = "3" diff --git a/src/config/prompt.rs b/src/config/prompt.rs index 699a87b..1ae56a8 100644 --- a/src/config/prompt.rs +++ b/src/config/prompt.rs @@ -10,6 +10,7 @@ use crate::config::{api::Api, resolve_config_path}; const PROMPT_FILE: &str = "prompts.toml"; const CONVERSATION_FILE: &str = "conversation.toml"; +const CONVERSATIONS_PATH: &str = "saved_conversations"; #[derive(Debug, Deserialize, Serialize, PartialEq, Clone)] pub struct Prompt { @@ -100,15 +101,62 @@ pub fn conversation_file_path() -> PathBuf { resolve_config_path().join(CONVERSATION_FILE) } -pub fn get_last_conversation_as_prompt() -> Prompt { - let content = fs::read_to_string(conversation_file_path()).unwrap_or_else(|error| { - panic!( - "Could not read file {:?}, {:?}", - conversation_file_path(), - error - ) - }); - toml::from_str(&content).expect("failed to load the conversation file") +// Get the path to the conversations directory +pub fn conversations_path() -> PathBuf { + resolve_config_path().join(CONVERSATIONS_PATH) +} + +// Get the path to a specific conversation file +pub fn named_conversation_path(name: &str) -> PathBuf { + conversations_path().join(format!("{}.toml", name)) +} + +// Get the last conversation as a prompt, if it exists +pub fn get_last_conversation_as_prompt(name: Option<&str>) -> Option { + if let Some(name) = name { + let named_path = named_conversation_path(name); + if !named_path.exists() { + return None; + } + let content = fs::read_to_string(named_path) + .unwrap_or_else(|error| { + panic!( + "Could not read file {:?}, {:?}", + named_conversation_path(name), + error + ) + }); + Some(toml::from_str(&content).expect("failed to load the conversation file")) + } else { + let path = conversation_file_path(); + if !path.exists() { + return None; + } + let content = fs::read_to_string(path) + .unwrap_or_else(|error| { + panic!( + "Could not read file {:?}, {:?}", + conversation_file_path(), + error + ) + }); + Some(toml::from_str(&content).expect("failed to load the conversation file")) + } +} + +pub fn save_conversation(prompt: &Prompt, name: Option<&str>) -> std::io::Result<()> { + let toml_string = toml::to_string(prompt).expect("Failed to serialize prompt"); + + // Always save to conversation.toml + fs::write(conversation_file_path(), &toml_string)?; + + // If name is provided, also save to named conversation file + if let Some(name) = name { + fs::create_dir_all(conversations_path())?; + fs::write(named_conversation_path(name), &toml_string)?; + } + + Ok(()) } pub(super) fn generate_prompts_file() -> std::io::Result<()> { @@ -136,3 +184,119 @@ pub fn get_prompts() -> HashMap { .unwrap_or_else(|error| panic!("Could not read file {:?}, {:?}", prompts_path(), error)); toml::from_str(&content).expect("could not parse prompt file content") } + +#[cfg(test)] +mod tests { + use super::*; + use std::fs; + use tempfile::tempdir; + use crate::config::prompt::Prompt; + use serial_test::serial; + + fn setup() -> tempfile::TempDir { + let temp_dir = tempdir().unwrap(); + std::env::set_var("SMARTCAT_CONFIG_PATH", temp_dir.path()); + temp_dir + } + + fn create_test_prompt() -> Prompt { + let mut prompt = Prompt::default(); + prompt.messages = vec![(Message::user("test"))]; + prompt + } + + #[test] + #[serial] + fn test_get_and_save_default_conversation() { + let _temp_dir = setup(); + let test_prompt = create_test_prompt(); + + // Test saving conversation + save_conversation(&test_prompt, None).unwrap(); + assert!(conversation_file_path().exists()); + + // Test retrieving conversation + let loaded_prompt = get_last_conversation_as_prompt(None).unwrap(); + assert_eq!(loaded_prompt, test_prompt); + } + + #[test] + #[serial] + fn test_get_and_save_named_conversation() { + let _temp_dir = setup(); + let test_prompt = create_test_prompt(); + let conv_name = "test_conversation"; + + // Test saving named conversation + save_conversation(&test_prompt, Some(conv_name)).unwrap(); + assert!(named_conversation_path(conv_name).exists()); + assert!(conversation_file_path().exists()); // Should also save to default location + + // Test retrieving named conversation + let loaded_prompt = get_last_conversation_as_prompt(Some(conv_name)).unwrap(); + assert_eq!(loaded_prompt, test_prompt); + } + + #[test] + #[serial] + fn test_nonexistent_conversation() { + let _temp_dir = setup(); + + // Test getting nonexistent default conversation + assert!(get_last_conversation_as_prompt(None).is_none()); + + // Test getting nonexistent named conversation + assert!(get_last_conversation_as_prompt(Some("nonexistent")).is_none()); + } + + #[test] + #[serial] + fn test_conversation_file_contents() { + let _temp_dir = setup(); + let test_prompt = create_test_prompt(); + let conv_name = "test_conversation"; + + // Save conversation + save_conversation(&test_prompt, Some(conv_name)).unwrap(); + + // Verify default and named files have identical content + let default_content = fs::read_to_string(conversation_file_path()).unwrap(); + let named_content = fs::read_to_string(named_conversation_path(conv_name)).unwrap(); + assert_eq!(default_content, named_content); + + // Verify content can be parsed back to original prompt + let parsed_prompt: Prompt = toml::from_str(&default_content).unwrap(); + assert_eq!(parsed_prompt, test_prompt); + } + + #[test] + #[serial] + fn test_generate_prompts_file() { + let _temp_dir = setup(); + + // Test file generation + generate_prompts_file().unwrap(); + assert!(prompts_path().exists()); + + // Verify file is valid TOML and contains expected content + let content = fs::read_to_string(prompts_path()).unwrap(); + let prompts: HashMap = toml::from_str(&content).unwrap(); + assert!(!prompts.is_empty()); + } + + #[test] + #[serial] + fn test_get_prompts() { + let _temp_dir = setup(); + + // Generate prompts file + generate_prompts_file().unwrap(); + + // Test loading prompts + let prompts = get_prompts(); + assert!(!prompts.is_empty()); + + // Verify at least one default prompt exists + assert!(prompts.contains_key("default")); + } +} diff --git a/src/main.rs b/src/main.rs index 4c26682..6cd35c5 100644 --- a/src/main.rs +++ b/src/main.rs @@ -6,13 +6,13 @@ mod utils; use crate::config::{ api::Api, ensure_config_usable, - prompt::{conversation_file_path, get_last_conversation_as_prompt, get_prompts, Prompt}, + prompt::{get_last_conversation_as_prompt, save_conversation, get_prompts, Prompt}, }; use prompt_customization::customize_prompt; +use crate::utils::valid_conversation_name; use clap::{Args, Parser}; use log::debug; -use std::fs; use std::io::{self, IsTerminal, Read, Write}; use text::process_input_with_request; @@ -56,6 +56,9 @@ struct Cli { /// whether to repeat the input before the output, useful to extend instead of replacing #[arg(short, long)] repeat_input: bool, + /// conversation name + #[arg(short, long, value_parser = valid_conversation_name)] + name: Option, #[command(flatten)] prompt_params: PromptParams, } @@ -113,19 +116,29 @@ fn main() { let is_piped = !stdin.is_terminal(); let mut prompt_customizaton_text: Option = None; - let prompt: Prompt = if !args.extend_conversation { - // try to get prompt matching the first arg and use second arg as customization text - // if it doesn't use default prompt and treat that first arg as customization text - get_default_and_or_custom_prompt(&args, &mut prompt_customizaton_text) - } else { - prompt_customizaton_text = args.input_or_template_ref; + let prompt = if args.extend_conversation { + prompt_customizaton_text = args.input_or_template_ref.clone(); + if args.input_if_template_ref.is_some() { panic!( "Invalid parameters, cannot provide a config ref when extending a conversation.\n\ Use `sc -e \".\"`" ); } - get_last_conversation_as_prompt() + + match get_last_conversation_as_prompt(args.name.as_deref()) { + Some(prompt) => prompt, + None => { + if args.name.is_some() { + panic!("Named conversation does not exist: {}", args.name.unwrap()); + } + get_default_and_or_custom_prompt(&args, &mut prompt_customizaton_text) + } + } + } else { + // try to get prompt matching the first arg and use second arg as customization text + // if it doesn't use default prompt and treat that first arg as customization text + get_default_and_or_custom_prompt(&args, &mut prompt_customizaton_text) }; // if no text was piped, use the custom prompt as input @@ -146,13 +159,9 @@ fn main() { debug!("{:?}", prompt); match process_input_with_request(prompt, input, &mut output, args.repeat_input) { - Ok(prompt) => { - let toml_string = - toml::to_string(&prompt).expect("Failed to serialize prompt after response."); - let mut file = fs::File::create(conversation_file_path()) - .expect("Failed to create the conversation save file."); - file.write_all(toml_string.as_bytes()) - .expect("Failed to write to the conversation file."); + Ok(new_prompt) => { + save_conversation(&new_prompt, args.name.as_deref()) + .expect("Failed to save conversation"); } Err(e) => { eprintln!("Error: {}", e); @@ -204,3 +213,121 @@ fn get_default_and_or_custom_prompt( .expect(&prompt_not_found_error) } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::config::prompt::{Prompt, Message}; + use tempfile::tempdir; + use serial_test::serial; + + fn setup() -> tempfile::TempDir { + let temp_dir = tempdir().unwrap(); + std::env::set_var("SMARTCAT_CONFIG_PATH", temp_dir.path()); + temp_dir + } + + fn create_test_prompt() -> Prompt { + let mut prompt = Prompt::default(); + prompt.messages = vec![(Message::user("test"))]; + prompt + } + + #[test] + #[serial] + fn test_cli_with_nonexistent_conversation() { + let _temp_dir = setup(); + + let args = Cli { + input_or_template_ref: Some("test_input".to_string()), + input_if_template_ref: None, + extend_conversation: true, + repeat_input: false, + name: Some("nonexistent_conversation".to_string()), + prompt_params: PromptParams::default(), + }; + + // Test that getting a nonexistent conversation returns None + let prompt = get_last_conversation_as_prompt(args.name.as_deref()); + assert!(prompt.is_none()); + } + + #[test] + #[serial] + fn test_cli_with_existing_conversation() { + let _temp_dir = setup(); + + // Create a test conversation + let test_prompt = create_test_prompt(); + save_conversation(&test_prompt, Some("test_conversation")).unwrap(); + + let args = Cli { + input_or_template_ref: Some("test_input".to_string()), + input_if_template_ref: None, + extend_conversation: true, + repeat_input: false, + name: Some("test_conversation".to_string()), + prompt_params: PromptParams::default(), + }; + + // Test retrieving the saved conversation + let prompt = get_last_conversation_as_prompt(args.name.as_deref()); + assert!(prompt.is_some()); + assert_eq!(prompt.unwrap(), test_prompt); + } + + #[test] + #[serial] + fn test_valid_conversation_name() { + assert!(valid_conversation_name("valid_name").is_ok()); + assert!(valid_conversation_name("valid-name").is_ok()); + assert!(valid_conversation_name("valid123").is_ok()); + assert!(valid_conversation_name("VALID_NAME").is_ok()); + + assert!(valid_conversation_name("invalid name").is_err()); + assert!(valid_conversation_name("invalid/name").is_err()); + assert!(valid_conversation_name("invalid.name").is_err()); + assert!(valid_conversation_name("").is_err()); + } + + #[test] + #[serial] + fn test_conversation_persistence() { + let _temp_dir = setup(); + let test_prompt = create_test_prompt(); + + // Test saving and loading default conversation + save_conversation(&test_prompt, None).unwrap(); + let loaded_prompt = get_last_conversation_as_prompt(None); + assert!(loaded_prompt.is_some()); + assert_eq!(loaded_prompt.unwrap(), test_prompt); + + // Test saving and loading named conversation + save_conversation(&test_prompt, Some("test_conv")).unwrap(); + let loaded_named_prompt = get_last_conversation_as_prompt(Some("test_conv")); + assert!(loaded_named_prompt.is_some()); + assert_eq!(loaded_named_prompt.unwrap(), test_prompt); + } + + #[test] + #[serial] + fn test_default_prompt_fallback() { + let _temp_dir = setup(); + let args = Cli { + input_or_template_ref: Some("test_input".to_string()), + input_if_template_ref: None, + extend_conversation: true, + repeat_input: false, + name: None, + prompt_params: PromptParams::default(), + }; + + let prompt = get_last_conversation_as_prompt(args.name.as_deref()); + assert!(prompt.is_none()); // Should be None when no conversation exists + + // Verify the prompt customization text is set correctly + let prompt_customization_text = args.input_or_template_ref; + assert_eq!(prompt_customization_text, Some("test_input".to_string())); + } + +} diff --git a/src/utils.rs b/src/utils.rs index 1ce813a..2010a3a 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -1,3 +1,5 @@ +use regex::Regex; + pub const IS_NONINTERACTIVE_ENV_VAR: &str = "SMARTCAT_NONINTERACTIVE"; /// clean error logging @@ -24,3 +26,14 @@ pub fn read_user_input() -> String { .expect("Failed to read line"); user_input.trim().to_string() } + +// Validate the conversation name +pub fn valid_conversation_name(s: &str) -> Result { + let trimmed = s.trim(); + let re = Regex::new(r"^[a-zA-Z0-9_-]+$").unwrap(); + if re.is_match(trimmed) { + Ok(trimmed.to_string()) + } else { + Err(format!("Invalid conversation name: {}. Use only letters, numbers, underscores, and hyphens.", trimmed)) + } +}