diff --git a/hackathon/api/routes.py b/hackathon/api/routes.py index bb06318..5080700 100644 --- a/hackathon/api/routes.py +++ b/hackathon/api/routes.py @@ -326,6 +326,7 @@ def _extract_sample_data(answer: str, correct_answer) -> tuple[float, list[AISam async def provider_run(ai_provider: AIProvider, body: AIRunBody, api_key: str = Header(default=None)): _validate_body_model(ai_provider, body) + param = ProviderParam( sample_id=body.sample_id, provider_model=body.provider_model, @@ -336,6 +337,64 @@ async def provider_run(ai_provider: AIProvider, body: AIRunBody, api_key: str = top_p=body.top_p, top_k=body.top_k, ) + + from dataclasses import dataclass, asdict + # import json + # with open(f'C:/dev/hackathon-quantminds-2023/tmp/tmp.txt', 'w') as f: + # json.dump(param.dict(), f) + # # param = ProviderParam(**param.dict()) + # from dacite import from_dict + # param = from_dict(data_class=ProviderParam, data=param.dict()) + provider = get_provider(ai_provider, api_key) + question = provider.build_question(prompt=param.prompt, context=param.context) + provider_answers = await get_provider(ai_provider, api_key).run([param]) + answer = provider_answers[0].answer if provider_answers else "" + + param = ProviderParam( + sample_id=body.sample_id, + provider_model=body.provider_model, + prompt=None, + context=None, + seed=body.seed, + temperature=body.temperature, + top_p=body.top_p, + top_k=body.top_k, + # question=(f"I am about to ask a question. Before responding, can you please start by repeating for me this current prompt, before giving your answer to the question, as I cannot see my own prompt easily.\n" + # f"My question is: given the following QUESTION, and the following ANSWER that you gave," + # f"can you please explain to me how you came to the ANSWER. " + # f"\n" + # f"QUESTION: {question}\n" + # f"ANSWER: {answer}"), + question="When you answer, please begin by repeating the content of this prompt in full, from the next line.\n\n" + "Below I will refer to a question I asked you as 'QUESTION', and an answer you gave as 'ANSWER'. Can you please explain the answer you gave for each key.\n" + "Please repeat the full JSON from 'ANSWER' in your current answer, or if there is no JSON in 'ANSWER', say so.\n\n" + f"QUESTION: {question}\n\n" + f"ANSWER: {answer}" + ) + + # param = ProviderParam( + # sample_id=body.sample_id, + # provider_model=body.provider_model, + # prompt=None, + # context=None, + # seed=body.seed, + # temperature=body.temperature, + # top_p=body.top_p, + # top_k=body.top_k, + # # question=(f"I am about to ask a question. Before responding, can you please start by repeating for me this current prompt, before giving your answer to the question, as I cannot see my own prompt easily.\n" + # # f"My question is: given the following QUESTION, and the following ANSWER that you gave," + # # f"can you please explain to me how you came to the ANSWER. " + # # f"\n" + # # f"QUESTION: {question}\n" + # # f"ANSWER: {answer}"), + # question="When you answer, please begin by repeating the content of this prompt in full, from the next line.\n\n" + # "Please check that the following JSON is consistent with the following term sheet and make any modifications that are required to the JSON values (not keys) to make it consistent.\n" + # "Please output only the JSON, modified if necessary:\n" + # #"Please repeat the full JSON from 'ANSWER' in your current answer, or if there is no JSON in 'ANSWER', say so.\n\n" + # f"JSON: {answer}\n\n" + # f"TERM SHEET: {body.input}" + # ) + provider_answers = await get_provider(ai_provider, api_key).run([param]) answer = provider_answers[0].answer if provider_answers else "" diff --git a/hackathon/providers/base_provider.py b/hackathon/providers/base_provider.py index b7a8e59..d2352e0 100644 --- a/hackathon/providers/base_provider.py +++ b/hackathon/providers/base_provider.py @@ -14,7 +14,7 @@ import abc import asyncio -from dataclasses import dataclass +from dataclasses import dataclass, asdict from typing import Final, Optional @@ -28,7 +28,9 @@ class ProviderParam: temperature: Optional[float] = None top_p: Optional[float] = None top_k: Optional[int] = None - + question: Optional[str] = None + # def dict(self): + # return {k: str(v) for k, v in asdict(self).items()} @dataclass class ProviderAnswer: diff --git a/hackathon/providers/openai_provider.py b/hackathon/providers/openai_provider.py index e1a1e12..af9318b 100644 --- a/hackathon/providers/openai_provider.py +++ b/hackathon/providers/openai_provider.py @@ -23,7 +23,7 @@ class OpenAIProvider(BaseProvider): - REQUEST_TIMEOUT: Final[int] = 60 + REQUEST_TIMEOUT: Final[int] = 120 async def run(self, params: list[ProviderParam]) -> list[ProviderAnswer]: async with aiohttp.ClientSession() as session: @@ -43,7 +43,10 @@ async def get_answer(self, param: ProviderParam) -> ProviderAnswer: @retry(reraise=True, stop=stop_after_attempt(BaseProvider.RETRY_ATTEMPT)) async def _openai_create(self, param: ProviderParam): - question = self.build_question(prompt=param.prompt, context=param.context) + if not param.question: + question = self.build_question(prompt=param.prompt, context=param.context) + else: + question = param.question messages = [{"role": "user", "content": question}] return await openai.ChatCompletion.acreate( model=param.provider_model,