Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion llm_engine/src/llm_processor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ pub async fn process_natural_language_query(
// Get model name from parameters or environment
let model_name = model
.or_else(|| env::var("LLM_MODEL").ok())
.unwrap_or_else(|| "gpt-4".to_string());
.unwrap_or_else(|| "gpt-3.5-turbo".to_string());

// Generate prompt with query and schema context
let prompt = generate_prompt(&query, db_schema.as_ref());
Expand Down
216 changes: 216 additions & 0 deletions llm_engine/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ use axum::{
response::IntoResponse,
};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use tower_http::cors::{CorsLayer, Any};
use tracing::{info, error};
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
Expand All @@ -33,6 +34,14 @@ struct QueryRequest {
model: Option<String>,
}

// Request model for visualization generation
#[derive(Deserialize)]
struct VisualizationRequest {
    query: String,         // original natural-language query the data came from
    results: Value,        // query result set (arbitrary JSON) to visualize
    model: Option<String>, // LLM model override; falls back to "gpt-3.5-turbo" when absent
}

// Response model
#[derive(Serialize)]
struct QueryResponse {
Expand All @@ -41,6 +50,14 @@ struct QueryResponse {
confidence: Option<f64>,
}

// Response model for visualization generation
#[derive(Serialize)]
struct VisualizationResponse {
    html_code: String,   // complete HTML document containing the Plotly.js chart
    explanation: String, // LLM's description of its visualization choices
    confidence: f64,     // fixed 0.8 default — the LLM does not report a confidence
}

// Error handling
enum AppError {
InternalError(String),
Expand Down Expand Up @@ -110,6 +127,204 @@ async fn process_query(
}
}

// Process visualization request
//
// POST handler: takes the original natural-language query plus its JSON
// result set and asks the LLM to produce a self-contained Plotly.js HTML
// visualization. Returns the HTML, an explanation, and a confidence score.
async fn generate_visualization(
    State(_state): State<Arc<AppState>>,
    Json(request): Json<VisualizationRequest>,
) -> Result<impl IntoResponse, AppError> {
    info!("Generating visualization for query: {}", request.query);

    // Model comes from the request, falling back to the same default the
    // /process-query endpoint uses.
    let model = request.model.unwrap_or_else(|| "gpt-3.5-turbo".to_string());

    // Pretty-print the result set for the prompt; fall back to Debug
    // formatting if serialization fails (it shouldn't for a `Value`).
    let results_json = serde_json::to_string_pretty(&request.results)
        .unwrap_or_else(|_| format!("{:?}", request.results));

    // System prompt for visualization generation.
    //
    // Fix: the original built this with a zero-argument `format!` whose
    // `\` line-continuations also swallowed every newline, jamming all of
    // the guidelines onto one run-on line. A plain literal with explicit
    // `\n` separators keeps each guideline on its own line.
    let system_prompt = "You are a data visualization expert. Your task is to create a Plotly.js visualization based on the provided data and query. \
        You must return a complete, valid HTML file that uses Plotly.js to visualize the data.\n\
        \n\
        STYLING GUIDELINES:\n\
        - Use Plotly.js to create the chart\n\
        - Use 14px font for axis labels, 18px for titles\n\
        - Use consistent margins and padding\n\
        - Use a neobrutalist style\n\
        - Include axis titles with units (if known)\n\
        - Rotate x-axis labels if they are dates or long strings\n\
        - Use tight layout with `autosize: true` and `responsive: true`\n\
        - Enable zoom and pan interactivity\n\
        - Enable tooltips on hover showing exact values and labels\n\
        - Use hovermode: 'closest'\n\
        \n\
        OUTPUT FORMAT:\n\
        Your response should be a complete HTML file that can be directly viewed in a browser.\n\
        Return valid HTML that includes the Plotly.js library from a CDN and creates the visualization.\n\
        Also include a brief explanation of the visualization choices you made.";

    let user_prompt = format!(
        "Natural Language Query: {}\n\nData Results:\n{}\n\nBased on this query and data, create a complete HTML file with a Plotly.js visualization.",
        request.query, results_json
    );

    // Call the LLM with the specialized prompt.
    let llm_response = call_llm_api(system_prompt, &user_prompt, &model).await
        .map_err(|e| {
            let error_msg = format!("Error calling LLM API: {}", e);
            error!("{}", error_msg);
            AppError::InternalError(error_msg)
        })?;

    // Extract the HTML code and explanation from the response.
    let (html_code, explanation, confidence) = parse_visualization_response(&llm_response);

    Ok(Json(VisualizationResponse {
        html_code,
        explanation,
        confidence,
    }))
}

// Call LLM API with system and user prompts
//
// Sends a two-message chat-completion request to the OpenAI endpoint and
// returns the content of the first choice. Requires the `LLM_API_KEY`
// environment variable to be set.
async fn call_llm_api(system_prompt: &str, user_prompt: &str, model_name: &str) -> Result<String, anyhow::Error> {
    // Wire-format types for the OpenAI chat-completions API.
    #[derive(Serialize, Deserialize)]
    struct Message {
        role: String,
        content: String,
    }

    #[derive(Serialize)]
    struct OpenAIRequest {
        model: String,
        messages: Vec<Message>,
        temperature: f64,
    }

    #[derive(Deserialize)]
    struct OpenAIChoice {
        message: Message,
    }

    #[derive(Deserialize)]
    struct OpenAIResponse {
        choices: Vec<OpenAIChoice>,
    }

    let api_key = env::var("LLM_API_KEY")
        .map_err(|_| anyhow::anyhow!("LLM_API_KEY environment variable not set"))?;

    let make_message = |role: &str, content: &str| Message {
        role: role.to_string(),
        content: content.to_string(),
    };

    let payload = OpenAIRequest {
        model: model_name.to_string(),
        messages: vec![
            make_message("system", system_prompt),
            make_message("user", user_prompt),
        ],
        temperature: 0.7, // Slightly higher temperature for more creative visualizations
    };

    let http_response = reqwest::Client::new()
        .post("https://api.openai.com/v1/chat/completions")
        .header("Authorization", format!("Bearer {}", api_key))
        .header("Content-Type", "application/json")
        .json(&payload)
        .send()
        .await?;

    if !http_response.status().is_success() {
        let error_text = http_response.text().await?;
        return Err(anyhow::anyhow!("LLM API returned error: {}", error_text));
    }

    let parsed: OpenAIResponse = http_response.json().await?;

    // The API should always return at least one choice; treat an empty
    // list as an error rather than panicking on an index.
    match parsed.choices.first() {
        Some(choice) => Ok(choice.message.content.clone()),
        None => Err(anyhow::anyhow!("LLM API returned empty choices")),
    }
}

// Parse LLM response to extract HTML, explanation and confidence
//
// Extraction strategy, in order of preference:
//   1. A fenced ```html code block — its contents become the HTML, any
//      text after the closing fence becomes the explanation.
//   2. Raw HTML starting at `<!DOCTYPE html>` or `<html>` — everything
//      before it becomes the explanation.
//   3. A bare <script>...</script> snippet — wrapped in a minimal HTML
//      shell that loads Plotly from a CDN.
//   4. Nothing recognizable — the whole response is HTML-escaped and
//      shown in an error page.
//
// Returns (html_code, explanation, confidence). Confidence is a fixed
// 0.8 because the LLM does not report one.
fn parse_visualization_response(response: &str) -> (String, String, f64) {
    const FENCE_OPEN: &str = "```html";
    const FENCE: &str = "```";
    const SCRIPT_CLOSE: &str = "</script>";

    let html_start_patterns = ["<!DOCTYPE html>", "<html>"];
    let mut html_code = String::new();
    let mut explanation = String::new();
    let confidence = 0.8; // Default confidence

    // 1. Fenced ```html code block.
    if let Some(open_idx) = response.find(FENCE_OPEN) {
        let content_start = open_idx + FENCE_OPEN.len();
        // Find the closing fence; if it is missing we fall through to the
        // later fallbacks with html_code still empty.
        if let Some(rel_close) = response[content_start..].find(FENCE) {
            let content_end = content_start + rel_close;
            html_code = response[content_start..content_end].trim().to_string();

            // Any text after the closing fence is treated as explanation.
            let after_fence = content_end + FENCE.len();
            if after_fence < response.len() {
                explanation = response[after_fence..].trim().to_string();
            }
        }
    }
    // 2. Direct HTML without a code fence.
    else {
        for pattern in html_start_patterns.iter() {
            if let Some(start_idx) = response.find(pattern) {
                html_code = response[start_idx..].trim().to_string();

                // Everything before the HTML is considered explanation.
                if start_idx > 0 {
                    explanation = response[..start_idx].trim().to_string();
                }
                break;
            }
        }
    }

    // 3. Bare <script> snippet: wrap it in a minimal Plotly HTML shell.
    if html_code.is_empty() {
        if let Some(script_start) = response.find("<script>") {
            if let Some(rel_end) = response[script_start..].find(SCRIPT_CLOSE) {
                let script_content =
                    &response[script_start..script_start + rel_end + SCRIPT_CLOSE.len()];
                html_code = format!(
                    "<!DOCTYPE html>\n<html>\n<head>\n<title>Visualization</title>\n<script src=\"https://cdn.plot.ly/plotly-latest.min.js\"></script>\n</head>\n<body>\n<div id=\"plot\"></div>\n{}\n</body>\n</html>",
                    script_content
                );

                // Everything else is explanation.
                explanation = response.replace(script_content, "").trim().to_string();
            }
        }
    }

    // 4. Nothing recognizable: escape the raw response into an error page.
    // Bug fix: '&' must be escaped, and escaped FIRST — otherwise the
    // later replacements inject ampersands that were never escaped, and
    // pre-existing entities (e.g. "&lt;") in the response would render
    // decoded in the browser.
    if html_code.is_empty() {
        let escaped = response
            .replace('&', "&amp;")
            .replace('<', "&lt;")
            .replace('>', "&gt;");
        html_code = format!(
            "<!DOCTYPE html>\n<html>\n<head>\n<title>Visualization Error</title>\n</head>\n<body>\n<h1>Could not generate visualization</h1>\n<pre>{}</pre>\n</body>\n</html>",
            escaped
        );
        explanation = "Could not parse LLM response into valid HTML visualization.".to_string();
    }

    // If explanation is empty, provide a default.
    if explanation.is_empty() {
        explanation = "Visualization generated from the provided data.".to_string();
    }

    (html_code, explanation, confidence)
}

#[tokio::main]
async fn main() {
// Load environment variables
Expand All @@ -131,6 +346,7 @@ async fn main() {
.route("/", get(root))
.route("/health", get(health_check))
.route("/process-query", post(process_query))
.route("/generate", post(generate_visualization))
.layer(
CorsLayer::new()
.allow_origin(Any)
Expand Down
87 changes: 87 additions & 0 deletions query_router/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,15 @@ struct TranslateAndExecuteRequest {
model: String,
}

// Request model for visualization generation
#[derive(Debug, Deserialize)]
struct VisualizationRequest {
natural_query: String,
results: Value,
#[serde(default = "default_model")]
Copy link

Copilot AI Apr 18, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The default model value returned by default_model ('gpt-3.5-turbo') is inconsistent with the default ('gpt-4') used in the llm_engine endpoint. Consider harmonizing these defaults to prevent unexpected behavior.

Copilot uses AI. Check for mistakes.
model: String,
}

/// Fallback model name used by serde when a request omits `model`.
fn default_model() -> String {
    String::from("gpt-3.5-turbo")
}
Expand All @@ -43,6 +52,14 @@ struct ResponseMetadata {
total_time_ms: u64,
}

// Response model for visualization generation
#[derive(Debug, Serialize)]
struct VisualizationResponse {
    html_code: String,          // HTML document produced by the LLM engine
    explanation: String,        // explanation text passed through from the engine
    metadata: ResponseMetadata, // confidence plus timing info for this request
}

// LLM Engine response structure
#[derive(Debug, Deserialize)]
struct LlmResponse {
Expand All @@ -51,6 +68,14 @@ struct LlmResponse {
confidence: f64,
}

// LLM Engine visualization response structure
// (mirrors the JSON body returned by the llm_engine /generate endpoint)
#[derive(Debug, Deserialize)]
struct LlmVisualizationResponse {
    html_code: String,   // generated HTML visualization
    explanation: String, // engine's explanation of the visualization
    confidence: f64,     // engine-reported confidence score
}

// Application state
struct AppState {
client: Client,
Expand Down Expand Up @@ -128,6 +153,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
let app = Router::new()
.route("/health", get(health_check))
.route("/translate-and-execute", post(translate_and_execute))
.route("/visualize", post(generate_visualization))
.with_state(state)
.layer(middleware);

Expand Down Expand Up @@ -262,3 +288,64 @@ async fn execute_sql_query(state: &AppState, sql_query: &str) -> Result<Value, A

Ok(result)
}

// Endpoint for generating visualizations from natural language and data
//
// Delegates to the LLM engine's /generate endpoint and wraps the result
// with timing metadata. No SQL is executed in this flow, so
// execution_time_ms is always 0.
async fn generate_visualization(
    State(state): State<Arc<AppState>>,
    Json(request): Json<VisualizationRequest>,
) -> Result<Json<VisualizationResponse>, AppError> {
    let request_started = Instant::now();

    // Time the LLM engine call separately so the metadata can report it.
    let llm_started = Instant::now();
    let engine_reply = call_llm_visualization_engine(&state, &request).await?;
    let llm_elapsed_ms = llm_started.elapsed().as_millis() as u64;

    let metadata = ResponseMetadata {
        confidence: engine_reply.confidence,
        execution_time_ms: 0, // No SQL execution in this flow
        llm_processing_time_ms: llm_elapsed_ms,
        total_time_ms: request_started.elapsed().as_millis() as u64,
    };

    Ok(Json(VisualizationResponse {
        html_code: engine_reply.html_code,
        explanation: engine_reply.explanation,
        metadata,
    }))
}

// Call LLM engine to generate visualization HTML/JS
//
// POSTs the query, result set, and model name to the engine's /generate
// endpoint and deserializes its JSON response.
async fn call_llm_visualization_engine(
    state: &AppState,
    request: &VisualizationRequest
) -> Result<LlmVisualizationResponse, AppError> {
    let endpoint = format!("{}/generate", state.llm_engine_url);

    // Field names here must match the llm_engine's request model.
    let body = json!({
        "query": request.natural_query,
        "results": request.results,
        "model": request.model,
    });

    let response = state
        .client
        .post(&endpoint)
        .json(&body)
        .send()
        .await
        .map_err(AppError::LlmEngineError)?;

    // Bind the status before consuming the body so it can be reported.
    let status = response.status();
    if !status.is_success() {
        let error_text = response
            .text()
            .await
            .unwrap_or_else(|_| "Unknown error".to_string());
        return Err(AppError::LlmResponseError(format!(
            "LLM engine returned error ({}): {}",
            status, error_text
        )));
    }

    response
        .json::<LlmVisualizationResponse>()
        .await
        .map_err(|e| AppError::LlmResponseError(format!("Failed to parse LLM visualization response: {}", e)))
}
Loading