5 changes: 3 additions & 2 deletions bird/llm/run/run_gpt.sh
@@ -8,6 +8,7 @@ cot='True'
 no_cot='False'
 
 YOUR_API_KEY=''
+BASE_URL='https://api.openai.com/v1'
 
 engine1='code-davinci-002'
 engine2='text-davinci-003'
@@ -21,11 +22,11 @@ data_kg_output_path='./exp_result/turbo_output_kg/'
 
 
 echo 'generate GPT3.5 batch with knowledge'
-python3 -u ./src/gpt_request.py --db_root_path ${db_root_path} --api_key ${YOUR_API_KEY} --mode ${mode} \
+python3 -u ./src/gpt_request.py --db_root_path ${db_root_path} --api_key ${YOUR_API_KEY} --base_url ${BASE_URL} --mode ${mode} \
 --engine ${engine3} --eval_path ${eval_path} --data_output_path ${data_kg_output_path} --use_knowledge ${use_knowledge} \
 --chain_of_thought ${no_cot}
 
 echo 'generate GPT3.5 batch without knowledge'
-python3 -u ./src/gpt_request.py --db_root_path ${db_root_path} --api_key ${YOUR_API_KEY} --mode ${mode} \
+python3 -u ./src/gpt_request.py --db_root_path ${db_root_path} --api_key ${YOUR_API_KEY} --base_url ${BASE_URL} --mode ${mode} \
 --engine ${engine3} --eval_path ${eval_path} --data_output_path ${data_output_path} --use_knowledge ${not_use_knowledge} \
 --chain_of_thought ${no_cot}
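Note: the run script only forwards BASE_URL, so a cheap way to catch a bad key or endpoint before launching the whole batch is a one-off authenticated call. A minimal sketch, assuming the same openai v1 Python client the changes below use; check_endpoint and the placeholder values are hypothetical, not part of this change:

from openai import OpenAI

def check_endpoint(api_key: str, base_url: str) -> None:
    # One cheap authenticated request; raises if the key or URL is wrong.
    client = OpenAI(api_key=api_key, base_url=base_url)
    models = client.models.list()
    print('endpoint OK, serves:', [m.id for m in models.data])

check_endpoint(api_key='sk-...', base_url='https://api.openai.com/v1')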
80 changes: 57 additions & 23 deletions bird/llm/src/gpt_request.py
@@ -10,14 +10,13 @@
 from typing import Dict, List, Tuple
 
 import backoff
-import openai
+from openai import OpenAI
 import pandas as pd
 import sqlparse
 from tqdm import tqdm
-'''openai configure'''
-
-openai.debug=True
+
+# Initialize OpenAI client instead of using the global openai module
+client = None  # We'll initialize this with the API key later
 
 def new_directory(path):
     if not os.path.exists(path):
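The module-level client = None placeholder is an init-once pattern: the client is constructed a single time once the API key and base URL are known (in collect_response_from_gpt below), and every later request reuses it. A minimal sketch of the same pattern with an accessor so misuse fails loudly; init_client and get_client are hypothetical names, not in this diff:

from openai import OpenAI

_client = None  # created lazily, once credentials are known

def init_client(api_key: str, base_url: str) -> None:
    global _client
    _client = OpenAI(api_key=api_key, base_url=base_url)

def get_client() -> OpenAI:
    # Fail with a clear message instead of an AttributeError on None.
    if _client is None:
        raise RuntimeError('init_client() must be called first')
    return _client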
@@ -147,51 +146,85 @@ def generate_combined_prompts_one(db_path, question, knowledge=None):
     return combined_prompts
 
 def quota_giveup(e):
-    return isinstance(e, openai.error.RateLimitError) and "quota" in str(e)
+    return "quota" in str(e)
 
+# Updated backoff decorator to use more generic exception handling
 @backoff.on_exception(
     backoff.constant,
-    openai.error.OpenAIError,
+    Exception,  # Using generic Exception instead of OpenAI specific error
     giveup=quota_giveup,
     raise_on_giveup=True,
     interval=20
 )
 def connect_gpt(engine, prompt, max_tokens, temperature, stop):
     # print(prompt)
+    global client
     try:
-        result = openai.Completion.create(engine=engine, prompt=prompt, max_tokens=max_tokens, temperature=temperature, stop=stop)
+        # Use the new client-based API
+        if engine.startswith("code-"):
+            # For codex models, use the completions endpoint
+            response = client.completions.create(
+                model=engine,
+                prompt=prompt,
+                max_tokens=max_tokens,
+                temperature=temperature,
+                stop=stop
+            )
+            return response
+        else:
+            # For newer models, use the chat completions endpoint
+            response = client.chat.completions.create(
+                model=engine,
+                messages=[{"role": "user", "content": prompt}],
+                max_tokens=max_tokens,
+                temperature=temperature,
+                stop=stop
+            )
+            return response
     except Exception as e:
-        result = 'error:{}'.format(e)
-        return result
-def collect_response_from_gpt(db_path_list, question_list, api_key, engine, knowledge_list=None):
+        return f'error:{e}'
+
+def collect_response_from_gpt(db_path_list, question_list, api_key, base_url, engine, knowledge_list=None):
     '''
     :param db_path: str
     :param question_list: []
     :return: dict of responses collected from openai
     '''
+    global client
+    client = OpenAI(api_key=api_key, base_url=base_url)
 
     responses_dict = {}
     response_list = []
-    openai.api_key = api_key
 
     for i, question in tqdm(enumerate(question_list)):
         print('--------------------- processing {}th question ---------------------'.format(i))
         print('the question is: {}'.format(question))
 
         if knowledge_list:
             cur_prompt = generate_combined_prompts_one(db_path=db_path_list[i], question=question, knowledge=knowledge_list[i])
         else:
             cur_prompt = generate_combined_prompts_one(db_path=db_path_list[i], question=question)
 
         plain_result = connect_gpt(engine=engine, prompt=cur_prompt, max_tokens=256, temperature=0, stop=['--', '\n\n', ';', '#'])
-        # pdb.set_trace()
-        # plain_result = connect_gpt(engine=engine, prompt=cur_prompt, max_tokens=256, temperature=0, stop=['</s>'])
-        # determine wheter the sql is wrong
-
-        if type(plain_result) == str:
+
+        # Parse the response based on its type
+        if isinstance(plain_result, str):
             sql = plain_result
         else:
-            sql = 'SELECT' + plain_result['choices'][0]['text']
-
-            # responses_dict[i] = sql
+            # Handle different response formats based on the API version
+            if hasattr(plain_result, 'choices') and hasattr(plain_result.choices[0], 'text'):
+                # Old completions API format
+                sql = 'SELECT' + plain_result.choices[0].text
+            elif hasattr(plain_result, 'choices') and hasattr(plain_result.choices[0], 'message'):
+                # New chat completions API format
+                message_content = plain_result.choices[0].message.content
+                if message_content.startswith('SELECT'):
+                    sql = message_content
+                else:
+                    sql = 'SELECT' + message_content
+            else:
+                # Fallback
+                sql = f"Error: Unexpected response format: {plain_result}"
 
         db_id = db_path_list[i].split('/')[-1].split('.sqlite')[0]
         sql = sql + '\t----- bird -----\t' + db_id  # to avoid unpredicted \t appearing in codex results
         response_list.append(sql)
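The branching above normalizes two response shapes: legacy completions carry the text in choices[0].text, chat completions in choices[0].message.content, and 'SELECT' is re-attached because the prompt is evidently built to end with SELECT, so completion models continue mid-statement. The same logic as a standalone helper, easier to unit-test against stub objects; extract_sql is a hypothetical name, not in this diff:

def extract_sql(plain_result) -> str:
    # Plain strings are the 'error:...' values returned by connect_gpt.
    if isinstance(plain_result, str):
        return plain_result
    choice = plain_result.choices[0]
    if hasattr(choice, 'text') and choice.text is not None:
        # Legacy completions API: the model continues the prompt
        return 'SELECT' + choice.text
    if hasattr(choice, 'message'):
        content = choice.message.content
        return content if content.startswith('SELECT') else 'SELECT' + content
    return f'Error: Unexpected response format: {plain_result}'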
@@ -245,6 +278,7 @@ def generate_sql_file(sql_lst, output_path=None):
     args_parser.add_argument('--db_root_path', type=str, default='')
     # args_parser.add_argument('--db_name', type=str, required=True)
     args_parser.add_argument('--api_key', type=str, required=True)
+    args_parser.add_argument('--base_url', type=str, default='https://api.openai.com/v1')
     args_parser.add_argument('--engine', type=str, required=True, default='code-davinci-002')
     args_parser.add_argument('--data_output_path', type=str)
     args_parser.add_argument('--chain_of_thought', type=str)
@@ -259,9 +293,9 @@ def generate_sql_file(sql_lst, output_path=None):
     assert len(question_list) == len(db_path_list) == len(knowledge_list)
 
     if args.use_knowledge == 'True':
-        responses = collect_response_from_gpt(db_path_list=db_path_list, question_list=question_list, api_key=args.api_key, engine=args.engine, knowledge_list=knowledge_list)
+        responses = collect_response_from_gpt(db_path_list=db_path_list, question_list=question_list, api_key=args.api_key, base_url=args.base_url, engine=args.engine, knowledge_list=knowledge_list)
     else:
-        responses = collect_response_from_gpt(db_path_list=db_path_list, question_list=question_list, api_key=args.api_key, engine=args.engine, knowledge_list=None)
+        responses = collect_response_from_gpt(db_path_list=db_path_list, question_list=question_list, api_key=args.api_key, base_url=args.base_url, engine=args.engine, knowledge_list=None)
 
     if args.chain_of_thought == 'True':
         output_name = args.data_output_path + 'predict_' + args.mode + '_cot.json'
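With --base_url defaulting to https://api.openai.com/v1, existing invocations are unchanged, and the same script can now target any OpenAI-compatible server. A sketch of such an invocation, driven from Python for consistency with the rest of the file; every path, the localhost URL, and the 'EMPTY' key are hypothetical examples:

import subprocess

subprocess.run([
    'python3', '-u', './src/gpt_request.py',
    '--db_root_path', './data/dev_databases/',   # hypothetical path
    '--api_key', 'EMPTY',                        # many local servers ignore the key
    '--base_url', 'http://localhost:8000/v1',    # e.g. a vLLM-style endpoint
    '--engine', 'gpt-3.5-turbo',
    '--eval_path', './data/dev.json',            # hypothetical path
    '--data_output_path', './exp_result/turbo_output/',
    '--mode', 'dev',
    '--use_knowledge', 'False',
    '--chain_of_thought', 'False',
], check=True)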