diff --git a/README.md b/README.md index 72c1243..794f2a6 100644 --- a/README.md +++ b/README.md @@ -24,6 +24,18 @@

+## Quick Start + +1. Use `uv sync` or `conda env create -f environment.yml` to initialize the Python environment. +2. Run `./scripts/init_dataset.sh` to prepare the dataset. +3. Run `./scripts/run_gpt.sh` to have the LLM predict SQL for the dataset questions. +4. Run `./scripts/run_evaluation.sh` to get all scores. + + To get a specific EX / R-VES / Soft F1 score, pass 1 / 2 / 3 as the final argument. + +**Important:** Before running these scripts, remember to modify the **UPPER_CASE** variables at the top of each script to match your environment. + +--- ## Overview Here, we provide a Lite version of the development dataset: **Mini-Dev**. This mini-dev dataset is designed to facilitate efficient and cost-effective development cycles, especially for testing and refining SQL query generation models. This dataset results from community feedback, leading to the compilation of 500 high-quality text2sql pairs derived from 11 distinct databases in a development environment. To further enhance the practicality of the BIRD system in industry settings and support the development of text-to-SQL models, we make the Mini-Dev dataset available in both **MySQL** and **PostgreSQL**. diff --git a/environment.yml b/environment.yml new file mode 100644 index 0000000..da07081 --- /dev/null +++ b/environment.yml @@ -0,0 +1,22 @@ +name: mini_dev +dependencies: + - python=3.11.5 + - pip=24.0 + - annotated-types=0.7.0 + - anyio=4.4.0 + - certifi=2024.6.2 + - distro=1.9.0 + - func_timeout=4.3.5 + - h11=0.14.0 + - httpcore=1.0.5 + - httpx=0.27.0 + - idna=3.7 + - numpy=2.0.0 + - openai=1.34.0 + - psycopg2-binary=2.9.9 + - pydantic=2.7.4 + - pydantic-core=2.18.4 + - pymysql=1.1.1 + - sniffio=1.3.1 + - tqdm=4.66.4 + - typing-extensions=4.12.2 \ No newline at end of file diff --git a/evaluation/evaluation_ex.py b/evaluation/evaluation_ex.py index 45cc3bb..32230c3 100644 --- a/evaluation/evaluation_ex.py +++ b/evaluation/evaluation_ex.py @@ -3,7 +3,7 @@ import multiprocessing as mp from func_timeout import func_timeout, FunctionTimedOut from evaluation_utils import ( - load_jsonl, + load_json_data, execute_sql, package_sqls, sort_results, @@ -34,10 +34,10 @@ def execute_model( except KeyboardInterrupt: sys.exit(0) except FunctionTimedOut: - result = [(f"timeout",)] + result = [("timeout",)] res = 0 - except Exception as e: - result = [(f"error",)] # possibly len(query) > 512 or not executable + except Exception: + result = [("error",)] # possibly len(query) > 512 or not executable res = 0 result = {"sql_idx": idx, "res": res} return result @@ -69,7 +69,7 @@ def run_sqls_parallel( def compute_acc_by_diff(exec_results, diff_json_path): num_queries = len(exec_results) results = [res["res"] for res in exec_results] - contents = load_jsonl(diff_json_path) + contents = load_json_data(diff_json_path) simple_results, moderate_results, challenging_results = [], [], [] for i, content in enumerate(contents): diff --git a/evaluation/evaluation_f1.py b/evaluation/evaluation_f1.py index 0b94582..bf369a8 100644 --- a/evaluation/evaluation_f1.py +++ b/evaluation/evaluation_f1.py @@ -3,7 +3,7 @@ import multiprocessing as mp from func_timeout import func_timeout, FunctionTimedOut from evaluation_utils import ( - load_jsonl, + load_json_data, execute_sql, package_sqls, sort_results, @@ -123,10 +123,10 @@ def execute_model( except KeyboardInterrupt: sys.exit(0) except FunctionTimedOut: - result = [(f"timeout",)] + result = [("timeout",)] res = 0 - except Exception as e: - result = [(f"error",)] # possibly len(query) > 512 or not executable + except Exception: + result = [("error",)] # possibly len(query) > 512
or not executable res = 0 # print(result) # result = str(set([ret[0] for ret in result])) @@ -161,7 +161,7 @@ def run_sqls_parallel( def compute_f1_by_diff(exec_results, diff_json_path): num_queries = len(exec_results) results = [res["res"] for res in exec_results] - contents = load_jsonl(diff_json_path) + contents = load_json_data(diff_json_path) simple_results, moderate_results, challenging_results = [], [], [] for i, content in enumerate(contents): diff --git a/evaluation/evaluation_utils.py b/evaluation/evaluation_utils.py index 1c2c9be..aee7b95 100644 --- a/evaluation/evaluation_utils.py +++ b/evaluation/evaluation_utils.py @@ -1,8 +1,21 @@ import json +from pathlib import Path + import psycopg2 import pymysql import sqlite3 + +def load_json_data(file_path): + file = Path(file_path) + if file.suffix == '.json': + return load_json(file) + elif file.suffix == '.jsonl': + return load_jsonl(file) + else: + raise ValueError('Invalid file type') + + def load_jsonl(file_path): data = [] with open(file_path, "r") as file: @@ -10,6 +23,7 @@ def load_jsonl(file_path): data.append(json.loads(line)) return data + def load_json(dir): with open(dir, "r") as j: contents = json.loads(j.read()) @@ -29,7 +43,7 @@ def connect_postgresql(): # PyMySQL 1.1.1 def connect_mysql(): # Open database connection - # Connect to the database" + # Connect to the database db = pymysql.connect( host="localhost", user="root", @@ -119,6 +133,7 @@ def print_data(score_lists, count_lists, metric="F1 Score",result_log_file=None) # Log to file in append mode if result_log_file is not None: + Path(result_log_file).parent.mkdir(parents=True, exist_ok=True) with open(result_log_file, "a") as log_file: log_file.write(f"start calculate {metric}\n") log_file.write("{:20} {:20} {:20} {:20} {:20}\n".format("", *levels)) diff --git a/evaluation/evaluation_ves.py b/evaluation/evaluation_ves.py index 637e3e5..e95e179 100644 --- a/evaluation/evaluation_ves.py +++ b/evaluation/evaluation_ves.py @@ -5,7 +5,7 @@ import multiprocessing as mp from func_timeout import func_timeout, FunctionTimedOut from evaluation_utils import ( - load_jsonl, + load_json_data, execute_sql, package_sqls, sort_results, @@ -96,10 +96,10 @@ def execute_model( except KeyboardInterrupt: sys.exit(0) except FunctionTimedOut: - result = [(f"timeout",)] + result = [("timeout",)] reward = 0 - except Exception as e: - result = [(f"error",)] # possibly len(query) > 512 or not executable + except Exception: + result = [("error",)] # possibly len(query) > 512 or not executable reward = 0 result = {"sql_idx": idx, "reward": reward} return result @@ -148,7 +148,7 @@ def compute_ves(exec_results): def compute_ves_by_diff(exec_results, diff_json_path): num_queries = len(exec_results) - contents = load_jsonl(diff_json_path) + contents = load_json_data(diff_json_path) simple_results, moderate_results, challenging_results = [], [], [] for i, content in enumerate(contents): if content["difficulty"] == "simple": diff --git a/evaluation/run_evaluation.sh b/evaluation/run_evaluation.sh deleted file mode 100644 index 8910ed0..0000000 --- a/evaluation/run_evaluation.sh +++ /dev/null @@ -1,66 +0,0 @@ -# DO NOT CHANGE THIS -db_root_path='../sqlite/dev_databases/' -num_cpus=16 -meta_time_out=30.0 -# DO NOT CHANGE THIS - -# ************************* # -predicted_sql_path='../sql_result/predict_mini_dev_gpt-4-32k_cot_SQLite.json' # Replace with your predict sql json path -# predicted_sql_path='../sql_result/predict_mini_dev_gpt-4-32k_cot_PostgreSQL.json' # Replace with your predict sql json 
path -# predicted_sql_path='../sql_result/predict_mini_dev_gpt-4-32k_cot_MySQL.json' # Replace with your predict sql json path - -sql_dialect="SQLite" # ONLY Modify this -# sql_dialect="PostgreSQL" # ONLY Modify this -# sql_dialect="MySQL" # ONLY Modify this -# ************************* # - -# DO NOT CHANGE THIS -# Extract the base filename without extension -base_name=$(basename "$predicted_sql_path" .json) -# Define the output log path -output_log_path="../eval_result/${base_name}.txt" - -case $sql_dialect in - "SQLite") - diff_json_path="../sqlite/mini_dev_sqlite.jsonl" - ground_truth_path="../sqlite/mini_dev_sqlite_gold.sql" - ;; - "PostgreSQL") - diff_json_path="../postgresql/mini_dev_postgresql.jsonl" - ground_truth_path="../postgresql/mini_dev_postgresql_gold.sql" - ;; - "MySQL") - diff_json_path="../mysql/mini_dev_mysql.jsonl" - ground_truth_path="../mysql/mini_dev_mysql_gold.sql" - ;; - *) - echo "Invalid SQL dialect: $sql_dialect" - exit 1 - ;; -esac -# DO NOT CHANGE THIS - -# Output the set paths -echo "Differential JSON Path: $diff_json_path" -echo "Ground Truth Path: $ground_truth_path" - - - - -echo "starting to compare with knowledge for ex, sql_dialect: ${sql_dialect}" -python3 -u ./evaluation_ex.py --db_root_path ${db_root_path} --predicted_sql_path ${predicted_sql_path} \ ---ground_truth_path ${ground_truth_path} --num_cpus ${num_cpus} --output_log_path ${output_log_path} \ ---diff_json_path ${diff_json_path} --meta_time_out ${meta_time_out} --sql_dialect ${sql_dialect} - - - -# echo "starting to compare with knowledge for R-VES, sql_dialect: ${sql_dialect}" -# python3 -u ./evaluation_ves.py --db_root_path ${db_root_path} --predicted_sql_path ${predicted_sql_path} \ -# --ground_truth_path ${ground_truth_path} --num_cpus ${num_cpus} --output_log_path ${output_log_path} \ -# --diff_json_path ${diff_json_path} --meta_time_out ${meta_time_out} --sql_dialect ${sql_dialect} - - -# echo "starting to compare with knowledge for soft-f1, sql_dialect: ${sql_dialect}" -# python3 -u ./evaluation_f1.py --db_root_path ${db_root_path} --predicted_sql_path ${predicted_sql_path} \ -# --ground_truth_path ${ground_truth_path} --num_cpus ${num_cpus} --output_log_path ${output_log_path} \ -# --diff_json_path ${diff_json_path} --meta_time_out ${meta_time_out} --sql_dialect ${sql_dialect} \ No newline at end of file diff --git a/llm/run/run_gpt.sh b/llm/run/run_gpt.sh deleted file mode 100644 index 9dfa358..0000000 --- a/llm/run/run_gpt.sh +++ /dev/null @@ -1,28 +0,0 @@ -eval_path='./data/mini_dev_sqlite.json' # _sqlite.json, _mysql.json, _postgresql.json -dev_path='./output/' -db_root_path='./data/dev_databases/' -use_knowledge='True' -mode='mini_dev' # dev, train, mini_dev -cot='True' - -YOUR_API_KEY='YOUR_API_KEY' - -# Choose the engine to run, e.g. gpt-4, gpt-4-32k, gpt-4-turbo, gpt-35-turbo, GPT35-turbo-instruct -engine='gpt-4-turbo' - -# Choose the number of threads to run in parallel, 1 for single thread -num_threads=3 - -# Choose the SQL dialect to run, e.g. 
SQLite, MySQL, PostgreSQL -# PLEASE NOTE: You have to setup the database information in table_schema.py -# if you want to run the evaluation script using MySQL or PostgreSQL -sql_dialect='SQLite' - -# Choose the output path for the generated SQL queries -data_output_path='./exp_result/turbo_output/' -data_kg_output_path='./exp_result/turbo_output_kg/' - -echo "generate $engine batch, run in $num_threads threads, with knowledge: $use_knowledge, with chain of thought: $cot" -python3 -u ./src/gpt_request.py --db_root_path ${db_root_path} --api_key ${YOUR_API_KEY} --mode ${mode} \ ---engine ${engine} --eval_path ${eval_path} --data_output_path ${data_kg_output_path} --use_knowledge ${use_knowledge} \ ---chain_of_thought ${cot} --num_process ${num_threads} --sql_dialect ${sql_dialect} \ No newline at end of file diff --git a/llm/src/gpt_request.py b/llm/src/gpt_request.py index e1d5cde..20a515c 100644 --- a/llm/src/gpt_request.py +++ b/llm/src/gpt_request.py @@ -1,240 +1,176 @@ #!/usr/bin/env python3 import argparse import json -import os -from openai import AzureOpenAI +from concurrent.futures import ThreadPoolExecutor, as_completed +from dataclasses import dataclass +from pathlib import Path + from tqdm import tqdm -import time -from concurrent.futures import ThreadPoolExecutor -import concurrent.futures +from llm_client import LLMClient from prompt import generate_combined_prompts_one -"""openai configure""" -api_version = "2024-02-01" -api_base = "https://gcrendpoint.azurewebsites.net" - - -def new_directory(path): - if not os.path.exists(path): - os.makedirs(path) - - -def connect_gpt(engine, prompt, max_tokens, temperature, stop, client): +@dataclass +class Config: """ - Function to connect to the GPT API and get the response. + Encapsulates all script configurations. """ - MAX_API_RETRY = 10 - for i in range(MAX_API_RETRY): - time.sleep(2) - try: - - if engine == "gpt-35-turbo-instruct": - result = client.completions.create( - model="gpt-3.5-turbo-instruct", - prompt=prompt, - max_tokens=max_tokens, - temperature=temperature, - stop=stop, - ) - result = result.choices[0].text - else: # gpt-4-turbo, gpt-4, gpt-4-32k, gpt-35-turbo - messages = [ - {"role": "user", "content": prompt}, - ] - result = client.chat.completions.create( - model=engine, - messages=messages, - temperature=temperature, - max_tokens=max_tokens, - stop=stop, - ) - break - except Exception as e: - result = "error:{}".format(e) - print(result) - time.sleep(4) - return result - - -def decouple_question_schema(datasets, db_root_path): - question_list = [] - db_path_list = [] - knowledge_list = [] - for i, data in enumerate(datasets): - question_list.append(data["question"]) - cur_db_path = db_root_path + data["db_id"] + "/" + data["db_id"] + ".sqlite" - db_path_list.append(cur_db_path) - knowledge_list.append(data["evidence"]) - - return question_list, db_path_list, knowledge_list - - -def generate_sql_file(sql_lst, output_path=None): + # API and Model Config + provider: str # 'azure' | 'openai' + base_url: str + api_key: str + api_version: str + model: str + + # Data and Path Config + eval_path: str + db_root_path: str + data_output_path: str + + # Execution Config + mode: str = "dev" + use_knowledge: bool = False + chain_of_thought: bool = False + num_threads: int = 3 + sql_dialect: str = "SQLite" + +@dataclass +class Task: """ - Function to save the SQL results to a file. + Represents a single task for the worker to process. 
""" - sql_lst.sort(key=lambda x: x[1]) - result = {} - for i, (sql, _) in enumerate(sql_lst): - result[i] = sql + index: int + question: str + db_path: str + prompt: str - if output_path: - directory_path = os.path.dirname(output_path) - new_directory(directory_path) - json.dump(result, open(output_path, "w"), indent=4) - return result - - -def init_client(api_key, api_version, engine): +def prepare_tasks(config: Config): """ - Initialize the AzureOpenAI client for a worker. + Prepares the list of tasks to be processed. """ - return AzureOpenAI( - api_key=api_key, - api_version=api_version, - base_url=f"{api_base}/openai/deployments/{engine}", - ) - - -def post_process_response(response, db_path): - sql = response if isinstance(response, str) else response.choices[0].message.content - db_id = db_path.split("/")[-1].split(".sqlite")[0] - sql = f"{sql}\t----- bird -----\t{db_id}" - return sql + with open(config.eval_path, 'r') as f: + eval_data = json.load(f) + + tasks = [] + for i, data in enumerate(eval_data): + db_path = Path(config.db_root_path )/ data["db_id"] / f"{data['db_id']}.sqlite" + knowledge = data.get("evidence") if config.use_knowledge else None + + prompt = generate_combined_prompts_one( + db_path=db_path, + question=data["question"], + sql_dialect=config.sql_dialect, + knowledge=knowledge, + ) + tasks.append(Task(index=i, question=data["question"], db_path=db_path, prompt=prompt)) + return tasks -def worker_function(question_data): - """ - Function to process each question, set up the client, - generate the prompt, and collect the GPT response. - """ - prompt, engine, client, db_path, question, i = question_data - response = connect_gpt(engine, prompt, 512, 0, ["--", "\n\n", ";", "#"], client) - sql = post_process_response(response, db_path) - print(f"Processed {i}th question: {question}") - return sql, i - - -def collect_response_from_gpt( - db_path_list, - question_list, - api_key, - engine, - sql_dialect, - num_threads=3, - knowledge_list=None, -): +def process_task(task: Task, client: LLMClient) -> tuple[str, int]: """ - Collect responses from GPT using multiple threads. + Worker function to process a single task. """ - client = init_client(api_key, api_version, engine) - - tasks = [ - ( - generate_combined_prompts_one( - db_path=db_path_list[i], - question=question_list[i], - sql_dialect=sql_dialect, - knowledge=knowledge_list[i], - ), - engine, - client, - db_path_list[i], - question_list[i], - i, - ) - for i in range(len(question_list)) - ] - responses = [] - with ThreadPoolExecutor(max_workers=num_threads) as executor: - future_to_task = { - executor.submit(worker_function, task): task for task in tasks - } - for future in tqdm( - concurrent.futures.as_completed(future_to_task), total=len(tasks) - ): - responses.append(future.result()) - return responses + response_text = client.ask(task.prompt) + db_id = Path(task.db_path).stem + sql_result = f"{response_text}\t----- bird -----\t{db_id}" + print(f"Processed {task.index}th question: {task.question}") + return sql_result, task.index -if __name__ == "__main__": +def generate_sql_file(results: list, output_path: Path): + """ + Saves the generated SQL results to a JSON file, sorted by index. 
+ """ + if not results: + return + + results.sort(key=lambda x: x[1]) + output_dict = {i: sql for i, (sql, _) in enumerate(results)} + output_path.parent.mkdir(parents=True, exist_ok=True) + + with open(output_path, "w") as f: + json.dump(output_dict, f, indent=4) + print(f"Successfully saved results to {output_path}") + +def load_config(): args_parser = argparse.ArgumentParser() - args_parser.add_argument("--eval_path", type=str, default="") - args_parser.add_argument("--mode", type=str, default="dev") - args_parser.add_argument("--test_path", type=str, default="") - args_parser.add_argument("--use_knowledge", type=str, default="False") - args_parser.add_argument("--db_root_path", type=str, default="") + # API and Model Config + args_parser.add_argument("--provider", type=str, default="openai") + args_parser.add_argument("--base_url", type=str, required=True) args_parser.add_argument("--api_key", type=str, required=True) + args_parser.add_argument("--api_version", type=str, default="") args_parser.add_argument( - "--engine", type=str, required=True, default="code-davinci-002" + "--model", type=str, required=True, default="code-davinci-002" ) + + # Data and Path Config + args_parser.add_argument("--eval_path", type=str, default="") + args_parser.add_argument("--db_root_path", type=str, default="") args_parser.add_argument("--data_output_path", type=str) - args_parser.add_argument("--chain_of_thought", type=str) - args_parser.add_argument("--num_processes", type=int, default=3) + + # Execution Config + args_parser.add_argument("--mode", type=str, default="dev") args_parser.add_argument("--sql_dialect", type=str, default="SQLite") + args_parser.add_argument("--num_threads", type=int, default=3) + args_parser.add_argument("--use_knowledge", type=str, default="False") + args_parser.add_argument("--chain_of_thought", type=str) + args = args_parser.parse_args() - eval_data = json.load(open(args.eval_path, "r")) + return Config( + provider=args.provider, + base_url=args.base_url, + api_key=args.api_key, + api_version=args.api_version, + model=args.model, + eval_path=args.eval_path, + db_root_path=args.db_root_path, + data_output_path=args.data_output_path, + mode=args.mode, + sql_dialect=args.sql_dialect, + num_threads=args.num_threads, + use_knowledge=args.use_knowledge, + chain_of_thought=args.chain_of_thought, + ) - question_list, db_path_list, knowledge_list = decouple_question_schema( - datasets=eval_data, db_root_path=args.db_root_path +def main(): + cfg = load_config() + llm = LLMClient( + provider=cfg.provider, + model=cfg.model, + api_key=cfg.api_key, + base_url=cfg.base_url, + api_version=cfg.api_version + ) + tasks = prepare_tasks(cfg) + all_responses = [] + + with ThreadPoolExecutor(max_workers=cfg.num_threads) as executor: + future_to_task = {executor.submit(process_task, task, llm): task for task in tasks} + + for future in tqdm(as_completed(future_to_task), total=len(tasks), desc="Generating SQL"): + try: + result = future.result() + all_responses.append(result) + except Exception as e: + print(f"A task generated an exception: {e}") + + cot_suffix = "_cot" if cfg.chain_of_thought else "" + output_file =( + Path(cfg.data_output_path) / + f"predict_{cfg.mode}_{cfg.model}{cot_suffix}_{cfg.sql_dialect}.json" ) - assert len(question_list) == len(db_path_list) == len(knowledge_list) - - if args.use_knowledge == "True": - responses = collect_response_from_gpt( - db_path_list, - question_list, - args.api_key, - args.engine, - args.sql_dialect, - args.num_processes, - knowledge_list, - ) - 
else: - responses = collect_response_from_gpt( - db_path_list, - question_list, - args.api_key, - args.engine, - args.sql_dialect, - args.num_processes, - ) - if args.chain_of_thought == "True": - output_name = ( - args.data_output_path - + "predict_" - + args.mode - + "_" - + args.engine - + "_cot" - + "_" - + args.sql_dialect - + ".json" - ) - else: - output_name = ( - args.data_output_path - + "predict_" - + args.mode - + "_" - + args.engine - + "_" - + args.sql_dialect - + ".json" - ) - generate_sql_file(sql_lst=responses, output_path=output_name) + generate_sql_file(results=all_responses, output_path=output_file) print( - "successfully collect results from {} for {} evaluation; SQL dialect {} Use knowledge: {}; Use COT: {}".format( - args.engine, - args.mode, - args.sql_dialect, - args.use_knowledge, - args.chain_of_thought, - ) + f"Successfully collected results from {cfg.model} for {cfg.mode} evaluation; " + f"SQL dialect {cfg.sql_dialect} Use knowledge: {cfg.use_knowledge}; " + f"Use COT: {cfg.chain_of_thought}" ) + + +if __name__ == "__main__": + main() diff --git a/llm/src/llm_client.py b/llm/src/llm_client.py new file mode 100644 index 0000000..7ed64f3 --- /dev/null +++ b/llm/src/llm_client.py @@ -0,0 +1,78 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +@Author : sz +@Project: mini_dev +@File : llm_client.py +@Time : 2025/6/9 10:12 +""" + +from openai import AzureOpenAI, OpenAI +from openai.types.chat import ChatCompletionUserMessageParam + + +class LLMClient: + """ + A thin LLM client wrapper. + Supports both OpenAI and Azure OpenAI endpoints. + """ + + def __init__( + self, + provider: str, + model: str, + api_key: str, + base_url: str | None = None, + api_version: str | None = None, + temperature: float = 0, + max_tokens: int = 512, + max_retries: int = 10, + stop=None, + ): + if provider == "azure": + self.client = AzureOpenAI( + api_key=api_key, + api_version=api_version or "2024-02-01", + base_url=base_url, + max_retries=max_retries, + ) + + else: + self.client = OpenAI( + api_key=api_key, + base_url=base_url, + max_retries=max_retries, + ) + + self.model = model + self.temperature = temperature + self.max_tokens = max_tokens + self.stop = stop or ["--", "\n\n", ";", "#"] + self.is_chat_model = "instruct" not in self.model.lower() + + def ask(self, prompt: str) -> str: + """ + Sends a prompt to the configured LLM and returns the text response.
+ """ + if self.is_chat_model: + messages = [ChatCompletionUserMessageParam(role="user", content=prompt)] + response = self.client.chat.completions.create( + model=self.model, + messages=messages, + temperature=self.temperature, + max_tokens=self.max_tokens, + stop=self.stop, + ) + result = response.choices[0].message.content + + else: # if model is an Instruct model + response = self.client.completions.create( + model=self.model, + prompt=prompt, + temperature=self.temperature, + max_tokens=self.max_tokens, + stop=self.stop, + ) + result = response.choices[0].text + + return result diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..1af41ac --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,33 @@ +[project] +name = "app" +version = "0.1.0" +description = "Add your description here" +readme = "README.md" +requires-python = ">=3.11" +dependencies = [ + "annotated-types>=0.7.0", + "anyio>=4.9.0", + "certifi>=2025.4.26", + "distro>=1.9.0", + "func-timeout>=4.3.5", + "h11>=0.16.0", + "httpcore>=1.0.9", + "httpx>=0.28.1", + "idna>=3.10", + "jiter>=0.10.0", + "numpy>=2.3.0", + "openai>=1.84.0", + "psycopg2-binary>=2.9.10", + "pydantic>=2.11.5", + "pydantic-core>=2.33.2", + "pymysql>=1.1.1", + "sniffio>=1.3.1", + "tenacity>=9.1.2", + "tqdm>=4.67.1", + "typing-extensions>=4.14.0", + "typing-inspection>=0.4.1", +] + +[[tool.uv.index]] +url = "https://mirrors.aliyun.com/pypi/simple" +default = true diff --git a/scripts/init_dataset.sh b/scripts/init_dataset.sh new file mode 100755 index 0000000..0f2274f --- /dev/null +++ b/scripts/init_dataset.sh @@ -0,0 +1,39 @@ +#!/usr/bin/env bash +set -eu + +# --- Data and Path Config --- +# dataset download url +PRIMARY_URL="https://bird-bench.oss-cn-beijing.aliyuncs.com/minidev.zip" +SECONDARY_URL="https://drive.google.com/file/d/1UJyA6I6pTmmhYpwdn8iT9QKrcJqSQAcX/view?usp=sharing" + +# dataset file name +MINIDEV_ZIP="llm/mini_dev_data/minidev.zip" + +# where to save the unzip dateset files +DATA_DIR="data" + +if [ ! -f "${MINIDEV_ZIP}" ]; then + echo "${MINIDEV_ZIP} not exist." + echo "Download dataset..." + + if curl -# -L -o "${MINIDEV_ZIP}" --connect-timeout 5 "${PRIMARY_URL}"; then + echo "Download finished." + echo + else + echo "Can't connect to Aliyun, change to Google Drive." + if curl -# -L -o "${MINIDEV_ZIP}" --connect-timeout 5 "${SECONDARY_URL}"; then + echo "Download finished." + echo + else + echo "Fail to download Mini_dev dataset, check your network" + exit 1 + fi + fi +fi + +echo "Unzip dataset..." 
+echo +mkdir -p "${DATA_DIR}" +unzip "${MINIDEV_ZIP}" -d "${DATA_DIR}" + +echo "Everything is ready~" diff --git a/scripts/run_evaluation.sh b/scripts/run_evaluation.sh new file mode 100755 index 0000000..02ce13a --- /dev/null +++ b/scripts/run_evaluation.sh @@ -0,0 +1,128 @@ +#!/usr/bin/env bash +set -eu + +# --- User-Modifiable Variables --- +# main data path +DATA_PATH="data/minidev/MINIDEV" + +# root path of the SQLite database files +DB_ROOT_PATH="${DATA_PATH}/dev_databases/" + +# number of CPUs for evaluation +NUM_CPUS=3 + +# maximum seconds allowed per SQL execution +META_TIME_OUT=30.0 + +# how to execute python script, depends on your package manager +EXEC_CMD="uv run" # Options: "uv run" | "conda run -n mini_dev" + +# which DB dialect you use +SQL_DIALECT="SQLite" # Options: "SQLite" | "PostgreSQL" | "MySQL" + +# path to the predicted SQL json generated by the LLM step, +# so keep it consistent with DATA_OUTPUT_PATH in run_gpt.sh +PREDICTED_SQL_PATH="exp_result/predict_mini_dev_deepseek-chat_cot_SQLite.json" + +# output log path, which shares its basename with the predicted SQL json +OUTPUT_LOG_PATH="eval_result/$(basename "$PREDICTED_SQL_PATH" .json).txt" + +# --- DO NOT CHANGE BELOW --- +EX_SCRIPT="evaluation/evaluation_ex.py" +R_VES_SCRIPT="evaluation/evaluation_ves.py" +F1_SCRIPT="evaluation/evaluation_f1.py" + +case $SQL_DIALECT in +"SQLite") + diff_json_path="${DATA_PATH}/mini_dev_sqlite.json" + ground_truth_path="${DATA_PATH}/mini_dev_sqlite_gold.sql" + ;; +"PostgreSQL") + diff_json_path="${DATA_PATH}/mini_dev_postgresql.json" + ground_truth_path="${DATA_PATH}/mini_dev_postgresql_gold.sql" + ;; +"MySQL") + diff_json_path="${DATA_PATH}/mini_dev_mysql.json" + ground_truth_path="${DATA_PATH}/mini_dev_mysql_gold.sql" + ;; +*) + echo "Invalid SQL dialect: $SQL_DIALECT" + exit 1 + ;; +esac +# --- DO NOT CHANGE ABOVE --- + +function run_eval() { + local eval_name="${1}" + local python_script="${2}" + + echo "Starting to compare with knowledge for ${eval_name}" + ${EXEC_CMD} python -u "${python_script}" \ + --db_root_path "${DB_ROOT_PATH}" \ + --predicted_sql_path "${PREDICTED_SQL_PATH}" \ + --ground_truth_path "${ground_truth_path}" \ + --num_cpus "${NUM_CPUS}" \ + --output_log_path "${OUTPUT_LOG_PATH}" \ + --diff_json_path "${diff_json_path}" \ + --meta_time_out "${META_TIME_OUT}" \ + --sql_dialect "${SQL_DIALECT}" +} + +function confirm() { + local evals="${1}" + local countdown=5 + + echo "Evaluation ${evals} will run next. Press Ctrl+C to cancel..."
+ for ((i = countdown; i > 0; i--)); do + echo "${i}" + sleep 1 + done +} + +function main() { + local param="${1:-0}" + + echo "Evaluation setup:" + echo " SQL Dialect: ${SQL_DIALECT}" + echo " Predicted SQL Path: ${PREDICTED_SQL_PATH}" + echo " Differential JSON Path: ${diff_json_path}" + echo " Ground Truth Path: ${ground_truth_path}" + echo " Output Log Path: ${OUTPUT_LOG_PATH}" + echo " CPUs: ${NUM_CPUS}" + echo " Timeout: ${META_TIME_OUT}s" + echo "" + + echo "You provided the parameter: ${1:-(none)}" + + case "$param" in + "0" | "all" | "") + confirm "EX, R-VES, Soft F1-Score" + run_eval "EX" "${EX_SCRIPT}" + run_eval "R-VES" "${R_VES_SCRIPT}" + run_eval "Soft F1-Score" "${F1_SCRIPT}" + ;; + "1" | "ex") + confirm "EX" + run_eval "EX" "${EX_SCRIPT}" + ;; + "2" | "ves") + confirm "R-VES" + run_eval "R-VES" "${R_VES_SCRIPT}" + ;; + "3" | "f1") + confirm "Soft F1-Score" + run_eval "Soft F1-Score" "${F1_SCRIPT}" + ;; + *) + echo "Error: Invalid parameter: '${param}'" + echo "Usage: $0 [0|all|1|ex|2|ves|3|f1]" + echo " 0 or all or None: Run all evaluations (EX, R-VES, F1); default" + echo " 1 or ex: Run EX evaluation only" + echo " 2 or ves: Run R-VES evaluation only" + echo " 3 or f1: Run F1 evaluation only" + exit 1 + ;; + esac +} + +main "$@" diff --git a/scripts/run_gpt.sh b/scripts/run_gpt.sh new file mode 100755 index 0000000..be3f343 --- /dev/null +++ b/scripts/run_gpt.sh @@ -0,0 +1,91 @@ +#!/usr/bin/env bash +set -eu + +# --- API and Model Config --- +# API provider format +PROVIDER="openai" # Options: 'azure' | 'openai' + +# LLM server URL, e.g. +# - azure: "https://gcrendpoint.azurewebsites.net/openai/deployments/{MODEL}" +# - deepseek: "https://api.deepseek.com" +BASE_URL="https://api.deepseek.com" + +API_KEY="" + +# only needs changing when using the Azure provider +API_VERSION="2024-02-01" + +# which model your LLM server serves, e.g. +# - azure: gpt-4, gpt-4-32k, gpt-4-turbo, gpt-35-turbo, GPT35-turbo-instruct +# - aliyun: qwq-plus, qwen-max, qwen3-235b-a22b +# - deepseek platform: deepseek-chat +# - local llm server: deepseek-ai/DeepSeek-R1, Qwen/Qwen3-32B +MODEL="Qwen/Qwen3-32B" + +# --- Data and Path Config --- +# evaluation question json file; depends on your DB type +EVAL_PATH="data/minidev/MINIDEV/mini_dev_sqlite.json" # _sqlite.json, _mysql.json, _postgresql.json + +# root path of the SQLite database files +DB_ROOT_PATH="data/minidev/MINIDEV/dev_databases/" + +# output path for the generated SQL queries +DATA_OUTPUT_PATH="exp_result/" + +# --- Execution Config --- +# task mode +MODE="mini_dev" # Options: "dev" | "train" | "mini_dev" + +# SQL dialect to run +# PLEASE NOTE: You have to setup the database information in table_schema.py +# if you want to run the evaluation script using MySQL or PostgreSQL +SQL_DIALECT="SQLite" # Options: "SQLite" | "PostgreSQL" | "MySQL" + +# number of threads to run in parallel, 1 for single thread +NUM_THREADS=6 + +# whether to include the evidence field from the question json (gives the LLM more context) +USE_KNOWLEDGE="True" + +# whether to use a chain-of-thought prompt +COT="True" + +# how to execute python script, depends on your package manager
EXEC_CMD="uv run" # Options: "uv run" | "conda run -n mini_dev" + +function main() { + echo "ICL setup:" + echo " LLM provider: ${PROVIDER}" + echo " Model: ${MODEL}" + echo " SQL Dialect: ${SQL_DIALECT}" + echo " Eval path: ${EVAL_PATH}" + echo " Output path: ${DATA_OUTPUT_PATH}" + echo " Threads: ${NUM_THREADS}" + echo " Use knowledge: ${USE_KNOWLEDGE}" + echo " With chain of thought: ${COT}" + echo "" + + echo "Task will run in 5 seconds. Press Ctrl+C to cancel..."
+ for ((i = 5; i > 0; i--)); do + echo "${i}" + sleep 1 + done + + echo "Starting to generate predicted SQL" + ${EXEC_CMD} python -u llm/src/gpt_request.py \ + --provider "${PROVIDER}" \ + --base_url "${BASE_URL}" \ + --api_key "${API_KEY}" \ + --api_version "${API_VERSION}" \ + --model "${MODEL}" \ + --eval_path "${EVAL_PATH}" \ + --db_root_path "${DB_ROOT_PATH}" \ + --data_output_path "${DATA_OUTPUT_PATH}" \ + --mode "${MODE}" \ + --sql_dialect "${SQL_DIALECT}" \ + --num_threads "${NUM_THREADS}" \ + --use_knowledge "${USE_KNOWLEDGE}" \ + --chain_of_thought "${COT}" +} + +main "$@"
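---

For reference, a minimal usage sketch of the new `LLMClient` wrapper added in `llm/src/llm_client.py`. This assumes it is run from `llm/src/`; the `base_url`, `api_key`, and prompt below are illustrative placeholders, not values fixed by this diff:

```python
# Minimal sketch: one prompt through the new LLMClient against an
# OpenAI-compatible endpoint (placeholder credentials and URL).
from llm_client import LLMClient

client = LLMClient(
    provider="openai",            # "azure" would route through AzureOpenAI instead
    model="deepseek-chat",        # no "instruct" in the name -> chat completions API
    api_key="sk-...",             # placeholder key
    base_url="https://api.deepseek.com",
)

# ask() returns the raw text response; generation halts at the default
# stop tokens ["--", "\n\n", ";", "#"], which truncates multi-statement output.
sql = client.ask("Write a single SQLite query that counts the rows in a table named users.")
print(sql)
```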
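The `load_jsonl` to `load_json_data` switch in the evaluation scripts is what lets `run_evaluation.sh` point at the `.json` difficulty files shipped in `minidev.zip` instead of the old `.jsonl` ones. A quick sketch of the suffix-based dispatch, assuming it is run from the `evaluation/` directory with the SQLite defaults above:

```python
from evaluation_utils import load_json_data

# Dispatches on file suffix: .json -> load_json, .jsonl -> load_jsonl;
# any other suffix raises ValueError.
contents = load_json_data("data/minidev/MINIDEV/mini_dev_sqlite.json")
print(contents[0]["difficulty"])  # "simple", "moderate", or "challenging"
```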
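Finally, `PREDICTED_SQL_PATH` in `scripts/run_evaluation.sh` must match the filename that `main()` in `gpt_request.py` derives from its config. A sketch of that derivation using the defaults shown in the two scripts, with the model swapped to `deepseek-chat` for illustration:

```python
from pathlib import Path

# Mirrors the naming logic in gpt_request.main():
# predict_{mode}_{model}{_cot}_{dialect}.json under DATA_OUTPUT_PATH.
mode, model, cot, dialect = "mini_dev", "deepseek-chat", True, "SQLite"
cot_suffix = "_cot" if cot else ""
print(Path("exp_result") / f"predict_{mode}_{model}{cot_suffix}_{dialect}.json")
# -> exp_result/predict_mini_dev_deepseek-chat_cot_SQLite.json
```

Note that a model id containing a slash (e.g. `Qwen/Qwen3-32B`) introduces a subdirectory into this path, so `PREDICTED_SQL_PATH` has to account for that.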