Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,18 @@
<a><img src="materials/intro.png" style="width: 100%; min-width: 300px; display: block; margin: auto;"></a>
</p>

## Quick Start

1. Use `uv sync` or `conda env create -f environment.yml` to initialize the Python environment.
2. Run `./scripts/init_dataset.sh` to prepare the dataset.
3. Run `./script/run_gpt.sh` to generate questions & have the LLM predict SQLs.
4. Run `./script/run_evaluation.sh` to get all scores.

To get a specific ex/res-v/soft-f1 score only, add the parameter 1/2/3 (respectively) at the end of the command.

**Important:** Before running all these scripts, remember to modify the **CAP_VARIABLES** to match your env.

---

## Overview
Here, we provide a Lite version of the development dataset: **Mini-Dev**. This mini-dev dataset is designed to facilitate efficient and cost-effective development cycles, especially for testing and refining SQL query generation models. This dataset results from community feedback, leading to the compilation of 500 high-quality text2sql pairs derived from 11 distinct databases in a development environment. To further enhance the practicality of the BIRD system in industry settings and support the development of text-to-SQL models, we make the Mini-Dev dataset available in both **MySQL** and **PostgreSQL**.
Expand Down
22 changes: 22 additions & 0 deletions environment.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
name: mini_dev
dependencies:
- python=3.11.5
- pip=24.0
- annotated-types=0.7.0
- anyio=4.4.0
- certifi=2024.6.2
- distro=1.9.0
- func_timeout=4.3.5
- h11=0.14.0
- httpcore=1.0.5
- httpx=0.27.0
- idna=3.7
- numpy=2.0.0
- openai=1.34.0
- psycopg2-binary=2.9.9
- pydantic=2.7.4
- pydantic-core=2.18.4
- pymysql=1.1.1
- sniffio=1.3.1
- tqdm=4.66.4
- typing-extensions=4.12.2
10 changes: 5 additions & 5 deletions evaluation/evaluation_ex.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import multiprocessing as mp
from func_timeout import func_timeout, FunctionTimedOut
from evaluation_utils import (
load_jsonl,
load_json_data,
execute_sql,
package_sqls,
sort_results,
Expand Down Expand Up @@ -34,10 +34,10 @@ def execute_model(
except KeyboardInterrupt:
sys.exit(0)
except FunctionTimedOut:
result = [(f"timeout",)]
result = [("timeout",)]
res = 0
except Exception as e:
result = [(f"error",)] # possibly len(query) > 512 or not executable
except Exception:
result = [("error",)] # possibly len(query) > 512 or not executable
res = 0
result = {"sql_idx": idx, "res": res}
return result
Expand Down Expand Up @@ -69,7 +69,7 @@ def run_sqls_parallel(
def compute_acc_by_diff(exec_results, diff_json_path):
num_queries = len(exec_results)
results = [res["res"] for res in exec_results]
contents = load_jsonl(diff_json_path)
contents = load_json_data(diff_json_path)
simple_results, moderate_results, challenging_results = [], [], []

for i, content in enumerate(contents):
Expand Down
10 changes: 5 additions & 5 deletions evaluation/evaluation_f1.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import multiprocessing as mp
from func_timeout import func_timeout, FunctionTimedOut
from evaluation_utils import (
load_jsonl,
load_json_data,
execute_sql,
package_sqls,
sort_results,
Expand Down Expand Up @@ -123,10 +123,10 @@ def execute_model(
except KeyboardInterrupt:
sys.exit(0)
except FunctionTimedOut:
result = [(f"timeout",)]
result = [("timeout",)]
res = 0
except Exception as e:
result = [(f"error",)] # possibly len(query) > 512 or not executable
except Exception:
result = [("error",)] # possibly len(query) > 512 or not executable
res = 0
# print(result)
# result = str(set([ret[0] for ret in result]))
Expand Down Expand Up @@ -161,7 +161,7 @@ def run_sqls_parallel(
def compute_f1_by_diff(exec_results, diff_json_path):
num_queries = len(exec_results)
results = [res["res"] for res in exec_results]
contents = load_jsonl(diff_json_path)
contents = load_json_data(diff_json_path)
simple_results, moderate_results, challenging_results = [], [], []

for i, content in enumerate(contents):
Expand Down
17 changes: 16 additions & 1 deletion evaluation/evaluation_utils.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,29 @@
import json
from pathlib import Path

import psycopg2
import pymysql
import sqlite3


def load_json_data(file_path):
    """Load records from a ``.json`` or ``.jsonl`` file, dispatching on suffix.

    Args:
        file_path: Path to the data file (``str`` or ``pathlib.Path``).

    Returns:
        The parsed contents, as returned by ``load_json`` / ``load_jsonl``.

    Raises:
        ValueError: If the file extension is neither ``.json`` nor ``.jsonl``.
    """
    file = Path(file_path)
    if file.suffix == '.json':
        return load_json(file)
    elif file.suffix == '.jsonl':
        return load_jsonl(file)
    else:
        # Name the offending suffix and path so the caller can fix the input.
        raise ValueError(
            f"Invalid file type: expected '.json' or '.jsonl', "
            f"got '{file.suffix}' for {file_path}"
        )


def load_jsonl(file_path):
    """Read a JSON Lines file and return its records as a list.

    Each line of the file is parsed as one standalone JSON document.
    """
    with open(file_path, "r") as fh:
        return [json.loads(row) for row in fh]


def load_json(dir):
with open(dir, "r") as j:
contents = json.loads(j.read())
Expand All @@ -29,7 +43,7 @@ def connect_postgresql():
# PyMySQL 1.1.1
def connect_mysql():
# Open database connection
# Connect to the database"
# Connect to the database
db = pymysql.connect(
host="localhost",
user="root",
Expand Down Expand Up @@ -119,6 +133,7 @@ def print_data(score_lists, count_lists, metric="F1 Score",result_log_file=None)

# Log to file in append mode
if result_log_file is not None:
Path(result_log_file).parent.mkdir(parents=True, exist_ok=True)
with open(result_log_file, "a") as log_file:
log_file.write(f"start calculate {metric}\n")
log_file.write("{:20} {:20} {:20} {:20} {:20}\n".format("", *levels))
Expand Down
10 changes: 5 additions & 5 deletions evaluation/evaluation_ves.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import multiprocessing as mp
from func_timeout import func_timeout, FunctionTimedOut
from evaluation_utils import (
load_jsonl,
load_json_data,
execute_sql,
package_sqls,
sort_results,
Expand Down Expand Up @@ -96,10 +96,10 @@ def execute_model(
except KeyboardInterrupt:
sys.exit(0)
except FunctionTimedOut:
result = [(f"timeout",)]
result = [("timeout",)]
reward = 0
except Exception as e:
result = [(f"error",)] # possibly len(query) > 512 or not executable
except Exception:
result = [("error",)] # possibly len(query) > 512 or not executable
reward = 0
result = {"sql_idx": idx, "reward": reward}
return result
Expand Down Expand Up @@ -148,7 +148,7 @@ def compute_ves(exec_results):

def compute_ves_by_diff(exec_results, diff_json_path):
num_queries = len(exec_results)
contents = load_jsonl(diff_json_path)
contents = load_json_data(diff_json_path)
simple_results, moderate_results, challenging_results = [], [], []
for i, content in enumerate(contents):
if content["difficulty"] == "simple":
Expand Down
66 changes: 0 additions & 66 deletions evaluation/run_evaluation.sh

This file was deleted.

28 changes: 0 additions & 28 deletions llm/run/run_gpt.sh

This file was deleted.

Loading