diff --git a/llm_engine/src/llm_processor.rs b/llm_engine/src/llm_processor.rs index 42aa4f6..f096577 100644 --- a/llm_engine/src/llm_processor.rs +++ b/llm_engine/src/llm_processor.rs @@ -46,7 +46,7 @@ pub async fn process_natural_language_query( // Get model name from parameters or environment let model_name = model .or_else(|| env::var("LLM_MODEL").ok()) - .unwrap_or_else(|| "gpt-4".to_string()); + .unwrap_or_else(|| "gpt-3.5-turbo".to_string()); // Generate prompt with query and schema context let prompt = generate_prompt(&query, db_schema.as_ref()); diff --git a/llm_engine/src/main.rs b/llm_engine/src/main.rs index 149f47e..9f3e98e 100644 --- a/llm_engine/src/main.rs +++ b/llm_engine/src/main.rs @@ -9,6 +9,7 @@ use axum::{ response::IntoResponse, }; use serde::{Deserialize, Serialize}; +use serde_json::Value; use tower_http::cors::{CorsLayer, Any}; use tracing::{info, error}; use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; @@ -33,6 +34,14 @@ struct QueryRequest { model: Option, } +// Request model for visualization generation +#[derive(Deserialize)] +struct VisualizationRequest { + query: String, + results: Value, + model: Option, +} + // Response model #[derive(Serialize)] struct QueryResponse { @@ -41,6 +50,14 @@ struct QueryResponse { confidence: Option, } +// Response model for visualization generation +#[derive(Serialize)] +struct VisualizationResponse { + html_code: String, + explanation: String, + confidence: f64, +} + // Error handling enum AppError { InternalError(String), @@ -110,6 +127,204 @@ async fn process_query( } } +// Process visualization request +async fn generate_visualization( + State(_state): State>, + Json(request): Json, +) -> Result { + info!("Generating visualization for query: {}", request.query); + + // Get the model name from the request or use a default + let model = request.model.unwrap_or_else(|| "gpt-3.5-turbo".to_string()); + + // Format the results data for the prompt + let results_json = 
serde_json::to_string_pretty(&request.results) + .unwrap_or_else(|_| format!("{:?}", request.results)); + + // Create the prompt for visualization generation + let system_prompt = format!( + "You are a data visualization expert. Your task is to create a Plotly.js visualization based on the provided data and query. \ + You must return a complete, valid HTML file that uses Plotly.js to visualize the data.\ + \ + STYLING GUIDELINES:\ + - Use Plotly.js to create the chart\ + - Use 14px font for axis labels, 18px for titles\ + - Use consistent margins and padding\ + - Use a neobrutalist style\ + - Include axis titles with units (if known)\ + - Rotate x-axis labels if they are dates or long strings\ + - Use tight layout with `autosize: true` and `responsive: true`\ + - Enable zoom and pan interactivity\ + - Enable tooltips on hover showing exact values and labels\ + - Use hovermode: 'closest'\ + \ + OUTPUT FORMAT:\ + Your response should be a complete HTML file that can be directly viewed in a browser.\ + Return valid HTML that includes the Plotly.js library from a CDN and creates the visualization.\ + Also include a brief explanation of the visualization choices you made." 
+    );
+
+    let user_prompt = format!(
+        "Natural Language Query: {}\n\nData Results:\n{}\n\nBased on this query and data, create a complete HTML file with a Plotly.js visualization.",
+        request.query, results_json
+    );
+
+    // Call the LLM with the specialized prompt
+    let llm_response = call_llm_api(&system_prompt, &user_prompt, &model).await
+        .map_err(|e| {
+            let error_msg = format!("Error calling LLM API: {}", e);
+            error!("{}", error_msg);
+            AppError::InternalError(error_msg)
+        })?;
+
+    // Extract the HTML code and explanation from the response
+    let (html_code, explanation, confidence) = parse_visualization_response(&llm_response);
+
+    Ok(Json(VisualizationResponse {
+        html_code,
+        explanation,
+        confidence,
+    }))
+}
+
+// Call LLM API with system and user prompts
+async fn call_llm_api(system_prompt: &str, user_prompt: &str, model_name: &str) -> anyhow::Result<String> {
+    let api_key = env::var("LLM_API_KEY").map_err(|_| anyhow::anyhow!("LLM_API_KEY environment variable not set"))?;
+
+    let client = reqwest::Client::new();
+
+    #[derive(Serialize, Deserialize)]
+    struct Message {
+        role: String,
+        content: String,
+    }
+
+    #[derive(Serialize)]
+    struct OpenAIRequest {
+        model: String,
+        messages: Vec<Message>,
+        temperature: f64,
+    }
+
+    let request = OpenAIRequest {
+        model: model_name.to_string(),
+        messages: vec![
+            Message {
+                role: "system".to_string(),
+                content: system_prompt.to_string(),
+            },
+            Message {
+                role: "user".to_string(),
+                content: user_prompt.to_string(),
+            },
+        ],
+        temperature: 0.7, // Slightly higher temperature for more creative visualizations
+    };
+
+    let response = client
+        .post("https://api.openai.com/v1/chat/completions")
+        .header("Authorization", format!("Bearer {}", api_key))
+        .header("Content-Type", "application/json")
+        .json(&request)
+        .send()
+        .await?;
+
+    if !response.status().is_success() {
+        let error_text = response.text().await?;
+        return Err(anyhow::anyhow!("LLM API returned error: {}", error_text));
+    }
+
+    #[derive(Deserialize)]
+    struct 
OpenAIChoice {
+        message: Message,
+    }
+
+    #[derive(Deserialize)]
+    struct OpenAIResponse {
+        choices: Vec<OpenAIChoice>,
+    }
+
+    let response_json: OpenAIResponse = response.json().await?;
+
+    if response_json.choices.is_empty() {
+        return Err(anyhow::anyhow!("LLM API returned empty choices"));
+    }
+
+    Ok(response_json.choices[0].message.content.clone())
+}
+
+// Parse LLM response to extract HTML, explanation and confidence
+fn parse_visualization_response(response: &str) -> (String, String, f64) {
+    // Try to extract HTML content - look for <!DOCTYPE html> or <html>
+    let html_start_patterns = ["<!DOCTYPE html>", "<html>"];
+    let mut html_code = String::new();
+    let mut explanation = String::new();
+    let confidence = 0.8; // Default confidence
+
+    // First, check if the response contains a code block with HTML
+    if let Some(html_block_start) = response.find("```html") {
+        // Find the end of the code block (next ```)
+        if let Some(html_block_end) = response[html_block_start + 6..].find("```") {
+            // Extract HTML content (skip the ```html and end ```)
+            let block_start_pos = html_block_start + "```html".len();
+            let block_end_pos = html_block_start + 6 + html_block_end;
+            html_code = response[block_start_pos..block_end_pos].trim().to_string();
+
+            // Look for explanation after the HTML block
+            if block_end_pos + 3 < response.len() {
+                explanation = response[block_end_pos + 3..].trim().to_string();
+            }
+        }
+    }
+    // If no code block, try to find direct HTML
+    else {
+        for pattern in html_start_patterns.iter() {
+            if let Some(start_idx) = response.find(pattern) {
+                html_code = response[start_idx..].trim().to_string();
+
+                // Everything before HTML is considered explanation
+                if start_idx > 0 {
+                    explanation = response[0..start_idx].trim().to_string();
+                }
+                break;
+            }
+        }
+    }
+
+    // If still no HTML found, look for any content between <script> tags
+    if html_code.is_empty() {
+        if let Some(script_start) = response.find("<script") {
+            if let Some(script_end) = response[script_start..].find("</script>") {
+                // Create a basic HTML wrapper around the script
+                let script_content = &response[script_start..script_start + script_end + 9];
+                html_code = format!(
+                    "<!DOCTYPE html>\n<html>\n<head>\n<title>Visualization</title>\n<script src=\"https://cdn.plot.ly/plotly-latest.min.js\"></script>\n</head>\n<body>\n<div id=\"plot\"></div>\n{}\n</body>\n</html>",
+                    script_content
+                );
+
+                // Everything else is explanation
+                explanation = response.replace(script_content, "").trim().to_string();
+            }
+        }
+    }
+
+    // If still nothing found, return the whole response as HTML with a warning
+    if html_code.is_empty() {
+        html_code = format!(
+            "<!DOCTYPE html>\n<html>\n<head>\n<title>Visualization Error</title>\n</head>\n<body>\n<h2>Could not generate visualization</h2>\n<pre>\n{}
\n\n", + response.replace("<", "<").replace(">", ">") + ); + explanation = "Could not parse LLM response into valid HTML visualization.".to_string(); + } + + // If explanation is empty, provide a default + if explanation.is_empty() { + explanation = "Visualization generated from the provided data.".to_string(); + } + + (html_code, explanation, confidence) +} + #[tokio::main] async fn main() { // Load environment variables @@ -131,6 +346,7 @@ async fn main() { .route("/", get(root)) .route("/health", get(health_check)) .route("/process-query", post(process_query)) + .route("/generate", post(generate_visualization)) .layer( CorsLayer::new() .allow_origin(Any) diff --git a/query_router/src/main.rs b/query_router/src/main.rs index 8594ebe..699b1aa 100644 --- a/query_router/src/main.rs +++ b/query_router/src/main.rs @@ -22,6 +22,15 @@ struct TranslateAndExecuteRequest { model: String, } +// Request model for visualization generation +#[derive(Debug, Deserialize)] +struct VisualizationRequest { + natural_query: String, + results: Value, + #[serde(default = "default_model")] + model: String, +} + fn default_model() -> String { "gpt-3.5-turbo".to_string() } @@ -43,6 +52,14 @@ struct ResponseMetadata { total_time_ms: u64, } +// Response model for visualization generation +#[derive(Debug, Serialize)] +struct VisualizationResponse { + html_code: String, + explanation: String, + metadata: ResponseMetadata, +} + // LLM Engine response structure #[derive(Debug, Deserialize)] struct LlmResponse { @@ -51,6 +68,14 @@ struct LlmResponse { confidence: f64, } +// LLM Engine visualization response structure +#[derive(Debug, Deserialize)] +struct LlmVisualizationResponse { + html_code: String, + explanation: String, + confidence: f64, +} + // Application state struct AppState { client: Client, @@ -128,6 +153,7 @@ async fn main() -> Result<(), Box> { let app = Router::new() .route("/health", get(health_check)) .route("/translate-and-execute", post(translate_and_execute)) + 
.route("/visualize", post(generate_visualization))
        .with_state(state)
        .layer(middleware);
@@ -262,3 +288,64 @@ async fn execute_sql_query(state: &AppState, sql_query: &str) -> Result<QueryResult, AppError> {
 }
+
+// Handle a visualization request by forwarding query + results to the LLM engine
+async fn generate_visualization(
+    State(state): State<Arc<AppState>>,
+    Json(request): Json<VisualizationRequest>,
+) -> Result<Json<VisualizationResponse>, AppError> {
+    let start_time = Instant::now();
+
+    // Call LLM engine to generate visualization
+    let llm_start_time = Instant::now();
+    let llm_response = call_llm_visualization_engine(&state, &request).await?;
+    let llm_processing_time = llm_start_time.elapsed().as_millis() as u64;
+
+    // Build the response
+    let total_time = start_time.elapsed().as_millis() as u64;
+
+    let response = VisualizationResponse {
+        html_code: llm_response.html_code,
+        explanation: llm_response.explanation,
+        metadata: ResponseMetadata {
+            confidence: llm_response.confidence,
+            execution_time_ms: 0, // No SQL execution in this flow
+            llm_processing_time_ms: llm_processing_time,
+            total_time_ms: total_time,
+        },
+    };
+
+    Ok(Json(response))
+}
+
+// Call LLM engine to generate visualization HTML/JS
+async fn call_llm_visualization_engine(
+    state: &AppState,
+    request: &VisualizationRequest
+) -> Result<LlmVisualizationResponse, AppError> {
+    let url = format!("{}/generate", state.llm_engine_url);
+
+    let llm_request = json!({
+        "query": request.natural_query,
+        "results": request.results,
+        "model": request.model
+    });
+
+    let response = state.client
+        .post(&url)
+        .json(&llm_request)
+        .send()
+        .await
+        .map_err(AppError::LlmEngineError)?;
+
+    if !response.status().is_success() {
+        let status = response.status();
+        let error_text = response.text().await.unwrap_or_else(|_| "Unknown error".to_string());
+        return Err(AppError::LlmResponseError(format!("LLM engine returned error ({}): {}", status, error_text)));
+    }
+
+    let llm_response = response.json::<LlmVisualizationResponse>().await
+        .map_err(|e| AppError::LlmResponseError(format!("Failed to parse LLM visualization response: {}", e)))?;
+
+    Ok(llm_response)
+}