# app.py — SQLify Streamlit application (169 lines, 6.04 KB in the source repository).
import streamlit as st
import os
import sqlite3
import pandas as pd
from utils.schema_reader import get_schema_metadata, format_schema_for_prompt
from utils.model_loader import load_hf_pipeline, generate_response
from utils.prompt_builder import build_prompt
from utils.parsers import parse_model_response
from utils.sql_validator import validate_sql
from utils.query_runner import execute_query
from utils.state import init_state, add_to_history, clear_history
# Configure the browser tab title/icon and use the wide page layout.
# NOTE(review): page_icon "π" looks like a mis-decoded emoji (UTF-8 bytes read
# as ISO-8859-7, remaining bytes lost); restore the intended icon.
st.set_page_config(page_title="SQLify", page_icon="π", layout="wide")
def load_env_variables(path=".env", override=True):
    """Load ``KEY=VALUE`` pairs from a dotenv-style file into ``os.environ``.

    Minimal fallback used when python-dotenv is not installed.

    Blank lines, ``#`` comments (including indented ones) and lines without an
    ``=`` are skipped. An optional leading ``export `` is ignored. Whitespace
    around the key and value is trimmed, then surrounding quote characters
    are removed from the value. Only the first ``=`` splits, so values may
    contain ``=``. A missing file is silently ignored.

    Args:
        path: File to read. Defaults to ``.env`` in the current directory.
        override: When True (the default, matching the original behavior),
            values from the file replace variables already in the
            environment; when False, existing variables are kept.
    """
    if not os.path.exists(path):
        return
    # utf-8-sig transparently drops a leading BOM that would otherwise become
    # part of the first key.
    with open(path, encoding="utf-8-sig") as f:
        for raw_line in f:
            line = raw_line.strip()
            # Strip before the comment test so indented comments are skipped too.
            if not line or line.startswith("#"):
                continue
            if line.startswith("export "):
                line = line[len("export "):]
            key, sep, val = line.partition("=")
            key = key.strip()
            if not sep or not key:
                # Not an assignment; the old tuple-unpack raised ValueError here.
                continue
            if override or key not in os.environ:
                # Trim whitespace first so `KEY= "x"` yields `x`, not ` "x`.
                os.environ[key] = val.strip().strip('"\'')
# Populate os.environ from a local .env (if any) before configuration is read.
load_env_variables()

# Runtime configuration. The "REPLACE_WITH_..." defaults are placeholders to
# be overridden via the environment or the .env file.
MODEL_ID_OR_PATH = os.getenv("MODEL_ID_OR_PATH", "REPLACE_WITH_MY_MODEL")
DATABASE_PATH = os.getenv("DATABASE_PATH", "REPLACE_WITH_MY_SQLITE_DB")

# Session-state setup from utils.state; presumably creates the
# `current_query` / `history` keys read further down — confirm there.
init_state()
# -----------------
# Sidebar Settings
# -----------------
with st.sidebar:
    st.title("SQLify Settings ⚙️")

    st.markdown("### Environment")
    st.text(f"Model: {MODEL_ID_OR_PATH}")
    st.text(f"Database: {DATABASE_PATH}")

    st.markdown("### Generation Settings")
    # These values are collected into the `settings` dict given to
    # generate_response() in the main area.
    # NOTE(review): with "Do Sample" on, Hugging Face generation usually
    # requires temperature > 0; confirm how utils.model_loader handles 0.0.
    max_new_tokens = st.number_input("Max New Tokens", min_value=10, max_value=2048, value=256)
    temperature = st.slider("Temperature", min_value=0.0, max_value=2.0, value=0.1, step=0.1)
    top_p = st.slider("Top P", min_value=0.0, max_value=1.0, value=0.95, step=0.05)
    do_sample = st.checkbox("Do Sample", value=False)

    st.markdown("### Schema Viewer")
    # schema_metadata / schema_text are reused by the main area (guard and
    # prompt building); on failure they fall back to empty values so the main
    # area can stop with a clear message.
    try:
        schema_metadata = get_schema_metadata(DATABASE_PATH)
        schema_text = format_schema_for_prompt(schema_metadata)
    except Exception as e:  # UI boundary: show any loader error to the user.
        st.error(f"Error loading schema: {e}")
        schema_metadata = None
        schema_text = ""
    else:
        # Rendering lives outside the try so its errors aren't misreported
        # as schema-loading failures.
        st.code(schema_text, language="sql")

    st.markdown("### Example Questions")
    st.markdown("- Show total sales by month")
    st.markdown("- How many users signed up last week?")
    st.markdown("- List the top 5 products by revenue")
# -----------------
# Main Area
# -----------------
# NOTE(review): the leading "π" in the title looks like a mis-decoded emoji
# (UTF-8 read as ISO-8859-7); restore the intended icon.
st.title("π SQLify")
st.markdown("Convert natural language business questions into safe, read-only SQLite queries.")

# Check the schema (already loaded in the sidebar) before loading the model,
# so a bad DATABASE_PATH is reported without first paying for a model load.
if not schema_metadata:
    st.error("Failed to read database schema. Please check DATABASE_PATH.")
    st.stop()

# Load Model
pipe, tokenizer = load_hf_pipeline(MODEL_ID_OR_PATH)
if not pipe:
    st.error("Model failed to load. Please check the model path in your environment variables.")
    st.stop()
# Input
question = st.text_input("Ask a question about your data:", placeholder="e.g. Show total sales by month")

# Action Buttons (the remaining columns only reserve horizontal space)
col1, col2, _ = st.columns([1, 1, 4])
with col1:
    generate_clicked = st.button("Generate SQL", type="primary")

# Generation Logic
question = question.strip()
if generate_clicked and not question:
    # Previously a click with a blank question silently did nothing.
    st.warning("Please enter a question first.")
elif generate_clicked:
    with st.spinner("Generating SQL..."):
        prompt = build_prompt(question, schema_text)
        settings = {
            "max_new_tokens": max_new_tokens,
            "temperature": temperature,
            "top_p": top_p,
            "do_sample": do_sample,
        }
        raw_output = generate_response(pipe, tokenizer, prompt, settings)

    if not raw_output:
        # Clear the previous result so it is not shown as the answer to
        # this new question.
        st.session_state.current_query = None
        st.error("The model returned an empty response. Try again or adjust the generation settings.")
    else:
        # NOTE(review): generation is expected to return only the new text
        # (return_full_text=False per the original comment); confirm in
        # utils.model_loader.
        parsed_response = parse_model_response(raw_output)
        current = {
            'question': question,
            'can_answer': parsed_response.can_answer,
            'sql': parsed_response.sql,
            'explanation': parsed_response.explanation,
            'is_valid': False,
            'error_message': "",
        }
        if parsed_response.can_answer:
            val_result = validate_sql(parsed_response.sql)
            current['is_valid'] = val_result['is_valid']
            current['sql'] = val_result['cleaned_sql']
            current['error_message'] = val_result['error_message']
        st.session_state.current_query = current
        # Record the SQL exactly as it is displayed and executed (the cleaned
        # version), not the raw model output.
        add_to_history(
            question,
            current['sql'],
            current['explanation'],
            current['is_valid'],
        )
# Display Current Query
# current_query persists in session state, so the result (and its Run button)
# survives the rerun triggered by clicking "Run Query".
if st.session_state.current_query:
    cq = st.session_state.current_query
    st.markdown("### Generated Result")
    if not cq['can_answer']:
        st.warning(f"**The model cannot answer this question based on the schema.**\n\nExplanation: {cq['explanation']}")
    else:
        st.markdown(f"**Explanation:** {cq['explanation']}")
        st.code(cq['sql'], language="sql")
        if not cq['is_valid']:
            st.error(f"**SQL Validation Failed:** {cq['error_message']}")
        else:
            st.success("SQL is valid and read-only.")
            # Show Run Query Button — only for SQL that passed validation.
            # NOTE(review): the trailing "π" looks like a mis-decoded emoji;
            # restore the intended icon.
            if st.button("Run Query π"):
                with st.spinner("Executing query..."):
                    df, err = execute_query(DATABASE_PATH, cq['sql'])
                if err:
                    st.error(f"Execution Error: {err}")
                else:
                    st.markdown("### Results")
                    st.dataframe(df, use_container_width=True)
# History
st.markdown("---")
col_h1, col_h2 = st.columns([4, 1])
with col_h1:
    st.markdown("### Recent Queries")
with col_h2:
    if st.button("Clear History"):
        clear_history()
        st.rerun()

if not st.session_state.history:
    st.info("No query history yet.")
else:
    for item in st.session_state.history:
        # ✅ for entries recorded as successful (valid SQL), ❌ otherwise.
        status = "✅" if item['success'] else "❌"
        with st.expander(f"{status} {item['question']}"):
            st.markdown(f"**Explanation:** {item['explanation']}")
            # Unanswerable questions may carry no SQL (presumably None/empty);
            # only render a code block when there is something to show.
            if item['sql']:
                st.code(item['sql'], language="sql")