diff --git a/bird/llm/run/run_gpt.sh b/bird/llm/run/run_gpt.sh
index 47d408b8..69041c8c 100644
--- a/bird/llm/run/run_gpt.sh
+++ b/bird/llm/run/run_gpt.sh
@@ -8,6 +8,7 @@
 cot='True'
 no_cot='False'
 YOUR_API_KEY=''
+BASE_URL='https://api.openai.com/v1'
 
 engine1='code-davinci-002'
 engine2='text-davinci-003'
@@ -21,11 +22,11 @@
 data_kg_output_path='./exp_result/turbo_output_kg/'
 
 echo 'generate GPT3.5 batch with knowledge'
-python3 -u ./src/gpt_request.py --db_root_path ${db_root_path} --api_key ${YOUR_API_KEY} --mode ${mode} \
+python3 -u ./src/gpt_request.py --db_root_path ${db_root_path} --api_key ${YOUR_API_KEY} --base_url ${BASE_URL} --mode ${mode} \
 --engine ${engine3} --eval_path ${eval_path} --data_output_path ${data_kg_output_path} --use_knowledge ${use_knowledge} \
 --chain_of_thought ${no_cot}
 
 echo 'generate GPT3.5 batch without knowledge'
-python3 -u ./src/gpt_request.py --db_root_path ${db_root_path} --api_key ${YOUR_API_KEY} --mode ${mode} \
+python3 -u ./src/gpt_request.py --db_root_path ${db_root_path} --api_key ${YOUR_API_KEY} --base_url ${BASE_URL} --mode ${mode} \
 --engine ${engine3} --eval_path ${eval_path} --data_output_path ${data_output_path} --use_knowledge ${not_use_knowledge} \
 --chain_of_thought ${no_cot}
diff --git a/bird/llm/src/gpt_request.py b/bird/llm/src/gpt_request.py
index da85e336..f7d20fea 100644
--- a/bird/llm/src/gpt_request.py
+++ b/bird/llm/src/gpt_request.py
@@ -10,14 +10,13 @@
 from typing import Dict, List, Tuple
 
 import backoff
-import openai
+from openai import OpenAI
 import pandas as pd
 import sqlparse
 from tqdm import tqdm
 
-'''openai configure'''
-
-openai.debug=True
+# Initialize an OpenAI client instead of using the global openai module
+client = None  # We'll initialize this with the API key later
 
 def new_directory(path):
     if not os.path.exists(path):
@@ -147,51 +146,85 @@
     return combined_prompts
 
 def quota_giveup(e):
-    return isinstance(e, openai.error.RateLimitError) and "quota" in str(e)
+    return "quota" in str(e)
 
+# Updated backoff decorator to use more generic exception handling
 @backoff.on_exception(
     backoff.constant,
-    openai.error.OpenAIError,
+    Exception,  # Using a generic Exception instead of the OpenAI-specific error
     giveup=quota_giveup,
     raise_on_giveup=True,
     interval=20
 )
 def connect_gpt(engine, prompt, max_tokens, temperature, stop):
-    # print(prompt)
+    global client
     try:
-        result = openai.Completion.create(engine=engine, prompt=prompt, max_tokens=max_tokens, temperature=temperature, stop=stop)
+        # Use the new client-based API
+        if engine.startswith("code-"):
+            # For codex models, use the completions endpoint
+            response = client.completions.create(
+                model=engine,
+                prompt=prompt,
+                max_tokens=max_tokens,
+                temperature=temperature,
+                stop=stop
+            )
+            return response
+        else:
+            # For newer models, use the chat completions endpoint
+            response = client.chat.completions.create(
+                model=engine,
+                messages=[{"role": "user", "content": prompt}],
+                max_tokens=max_tokens,
+                temperature=temperature,
+                stop=stop
+            )
+            return response
     except Exception as e:
-        result = 'error:{}'.format(e)
-    return result
-def collect_response_from_gpt(db_path_list, question_list, api_key, engine, knowledge_list=None):
+        return f'error:{e}'
+
+def collect_response_from_gpt(db_path_list, question_list, api_key, base_url, engine, knowledge_list=None):
     '''
     :param db_path: str
     :param question_list: []
     :return: dict of responses collected from openai
     '''
+    global client
+    client = OpenAI(api_key=api_key, base_url=base_url)
+
     responses_dict = {}
     response_list = []
-    openai.api_key = api_key
+
     for i, question in tqdm(enumerate(question_list)):
         print('--------------------- processing {}th question ---------------------'.format(i))
         print('the question is: {}'.format(question))
-
+
         if knowledge_list:
             cur_prompt = generate_combined_prompts_one(db_path=db_path_list[i], question=question, knowledge=knowledge_list[i])
         else:
             cur_prompt = generate_combined_prompts_one(db_path=db_path_list[i], question=question)
         plain_result = connect_gpt(engine=engine, prompt=cur_prompt, max_tokens=256, temperature=0, stop=['--', '\n\n', ';', '#'])
-        # pdb.set_trace()
-        # plain_result = connect_gpt(engine=engine, prompt=cur_prompt, max_tokens=256, temperature=0, stop=[''])
-        # determine wheter the sql is wrong
-
-        if type(plain_result) == str:
+
+        # Parse the response based on its type
+        if isinstance(plain_result, str):
             sql = plain_result
         else:
-            sql = 'SELECT' + plain_result['choices'][0]['text']
-
-        # responses_dict[i] = sql
+            # Handle different response formats based on the API version
+            if hasattr(plain_result, 'choices') and hasattr(plain_result.choices[0], 'text'):
+                # Old completions API format
+                sql = 'SELECT' + plain_result.choices[0].text
+            elif hasattr(plain_result, 'choices') and hasattr(plain_result.choices[0], 'message'):
+                # New chat completions API format
+                message_content = plain_result.choices[0].message.content
+                if message_content.startswith('SELECT'):
+                    sql = message_content
+                else:
+                    sql = 'SELECT' + message_content
+            else:
+                # Fallback for unexpected response shapes
+                sql = f"Error: Unexpected response format: {plain_result}"
+
         db_id = db_path_list[i].split('/')[-1].split('.sqlite')[0]
         sql = sql + '\t----- bird -----\t' + db_id # to avoid unpredicted \t appearing in codex results
         response_list.append(sql)
@@ -245,6 +278,7 @@
     args_parser.add_argument('--db_root_path', type=str, default='')
     # args_parser.add_argument('--db_name', type=str, required=True)
     args_parser.add_argument('--api_key', type=str, required=True)
+    args_parser.add_argument('--base_url', type=str, default='https://api.openai.com/v1')
     args_parser.add_argument('--engine', type=str, required=True, default='code-davinci-002')
     args_parser.add_argument('--data_output_path', type=str)
     args_parser.add_argument('--chain_of_thought', type=str)
@@ -259,9 +293,9 @@
     assert len(question_list) == len(db_path_list) == len(knowledge_list)
 
     if args.use_knowledge == 'True':
-        responses = collect_response_from_gpt(db_path_list=db_path_list, question_list=question_list, api_key=args.api_key, engine=args.engine, knowledge_list=knowledge_list)
+        responses = collect_response_from_gpt(db_path_list=db_path_list, question_list=question_list, api_key=args.api_key, base_url=args.base_url, engine=args.engine, knowledge_list=knowledge_list)
     else:
-        responses = collect_response_from_gpt(db_path_list=db_path_list, question_list=question_list, api_key=args.api_key, engine=args.engine, knowledge_list=None)
+        responses = collect_response_from_gpt(db_path_list=db_path_list, question_list=question_list, api_key=args.api_key, base_url=args.base_url, engine=args.engine, knowledge_list=None)
 
     if args.chain_of_thought == 'True':
         output_name = args.data_output_path + 'predict_' + args.mode + '_cot.json'
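Below is a minimal standalone sketch of the client-based call pattern this patch migrates to, useful for sanity-checking credentials and a custom --base_url endpoint outside the harness. It assumes openai>=1.0 is installed; the API key, base URL, model name, and prompt are illustrative placeholders, not values taken from the repo.

# sanity_check.py: verify the migrated openai>=1.0 client calls in isolation.
# All values below are placeholders; swap in a real key, endpoint, and engine.
from openai import OpenAI

client = OpenAI(api_key="YOUR_API_KEY", base_url="https://api.openai.com/v1")

# Mirror the prompt shape implied by the "'SELECT' + ..." reassembly above:
# the prompt ends with "SELECT" so the model continues the SQL query.
prompt = "-- Using valid SQLite, answer the following question.\n-- How many users are there?\nSELECT"

# Non-codex engines take the chat path in connect_gpt.
response = client.chat.completions.create(
    model="gpt-3.5-turbo",  # placeholder engine name
    messages=[{"role": "user", "content": prompt}],
    max_tokens=256,
    temperature=0,
    stop=["--", "\n\n", ";", "#"],
)
print(response.choices[0].message.content)

The engine.startswith("code-") branch in connect_gpt maps codex engines to client.completions.create instead; since code-davinci-002 has been retired from the public OpenAI API, that branch is only reachable against a --base_url endpoint that still serves completion-style models.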