Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 59 additions & 0 deletions hackathon/api/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,7 @@ def _extract_sample_data(answer: str, correct_answer) -> tuple[float, list[AISam
async def provider_run(ai_provider: AIProvider, body: AIRunBody, api_key: str = Header(default=None)):
_validate_body_model(ai_provider, body)


param = ProviderParam(
sample_id=body.sample_id,
provider_model=body.provider_model,
Expand All @@ -336,6 +337,64 @@ async def provider_run(ai_provider: AIProvider, body: AIRunBody, api_key: str =
top_p=body.top_p,
top_k=body.top_k,
)

from dataclasses import dataclass, asdict
# import json
# with open(f'C:/dev/hackathon-quantminds-2023/tmp/tmp.txt', 'w') as f:
# json.dump(param.dict(), f)
# # param = ProviderParam(**param.dict())
# from dacite import from_dict
# param = from_dict(data_class=ProviderParam, data=param.dict())
provider = get_provider(ai_provider, api_key)
question = provider.build_question(prompt=param.prompt, context=param.context)
provider_answers = await get_provider(ai_provider, api_key).run([param])
answer = provider_answers[0].answer if provider_answers else ""

param = ProviderParam(
sample_id=body.sample_id,
provider_model=body.provider_model,
prompt=None,
context=None,
seed=body.seed,
temperature=body.temperature,
top_p=body.top_p,
top_k=body.top_k,
# question=(f"I am about to ask a question. Before responding, can you please start by repeating for me this current prompt, before giving your answer to the question, as I cannot see my own prompt easily.\n"
# f"My question is: given the following QUESTION, and the following ANSWER that you gave,"
# f"can you please explain to me how you came to the ANSWER. "
# f"\n"
# f"QUESTION: {question}\n"
# f"ANSWER: {answer}"),
question="When you answer, please begin by repeating the content of this prompt in full, from the next line.\n\n"
"Below I will refer to a question I asked you as 'QUESTION', and an answer you gave as 'ANSWER'. Can you please explain the answer you gave for each key.\n"
"Please repeat the full JSON from 'ANSWER' in your current answer, or if there is no JSON in 'ANSWER', say so.\n\n"
f"QUESTION: {question}\n\n"
f"ANSWER: {answer}"
)

# param = ProviderParam(
# sample_id=body.sample_id,
# provider_model=body.provider_model,
# prompt=None,
# context=None,
# seed=body.seed,
# temperature=body.temperature,
# top_p=body.top_p,
# top_k=body.top_k,
# # question=(f"I am about to ask a question. Before responding, can you please start by repeating for me this current prompt, before giving your answer to the question, as I cannot see my own prompt easily.\n"
# # f"My question is: given the following QUESTION, and the following ANSWER that you gave,"
# # f"can you please explain to me how you came to the ANSWER. "
# # f"\n"
# # f"QUESTION: {question}\n"
# # f"ANSWER: {answer}"),
# question="When you answer, please begin by repeating the content of this prompt in full, from the next line.\n\n"
# "Please check that the following JSON is consistent with the following term sheet and make any modifications that are required to the JSON values (not keys) to make it consistent.\n"
# "Please output only the JSON, modified if necessary:\n"
# #"Please repeat the full JSON from 'ANSWER' in your current answer, or if there is no JSON in 'ANSWER', say so.\n\n"
# f"JSON: {answer}\n\n"
# f"TERM SHEET: {body.input}"
# )

provider_answers = await get_provider(ai_provider, api_key).run([param])
answer = provider_answers[0].answer if provider_answers else ""

Expand Down
6 changes: 4 additions & 2 deletions hackathon/providers/base_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

import abc
import asyncio
from dataclasses import dataclass
from dataclasses import dataclass, asdict
from typing import Final, Optional


Expand All @@ -28,7 +28,9 @@ class ProviderParam:
temperature: Optional[float] = None
top_p: Optional[float] = None
top_k: Optional[int] = None

question: Optional[str] = None
# def dict(self):
# return {k: str(v) for k, v in asdict(self).items()}

@dataclass
class ProviderAnswer:
Expand Down
7 changes: 5 additions & 2 deletions hackathon/providers/openai_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@


class OpenAIProvider(BaseProvider):
REQUEST_TIMEOUT: Final[int] = 60
REQUEST_TIMEOUT: Final[int] = 120

async def run(self, params: list[ProviderParam]) -> list[ProviderAnswer]:
async with aiohttp.ClientSession() as session:
Expand All @@ -43,7 +43,10 @@ async def get_answer(self, param: ProviderParam) -> ProviderAnswer:

@retry(reraise=True, stop=stop_after_attempt(BaseProvider.RETRY_ATTEMPT))
async def _openai_create(self, param: ProviderParam):
question = self.build_question(prompt=param.prompt, context=param.context)
if not param.question:
question = self.build_question(prompt=param.prompt, context=param.context)
else:
question = param.question
messages = [{"role": "user", "content": question}]
return await openai.ChatCompletion.acreate(
model=param.provider_model,
Expand Down