From 74828efc8f1b5dffd631c92f29c7571f08748e23 Mon Sep 17 00:00:00 2001 From: PhilippGawlik Date: Tue, 19 Dec 2023 12:12:42 +0100 Subject: [PATCH 1/2] =?UTF-8?q?Revert=20"Revert=20"=F0=9F=93=9D=20Docstrin?= =?UTF-8?q?g""?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This reverts commit cc7d868df1fc003e2f6089dfe8b28e1e0d9204fa. --- src/brdata_rag_tools/models.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/brdata_rag_tools/models.py b/src/brdata_rag_tools/models.py index e014ffd..1e46cf1 100644 --- a/src/brdata_rag_tools/models.py +++ b/src/brdata_rag_tools/models.py @@ -19,6 +19,7 @@ class LLMName(Enum): @property def max_input_tokens(self) -> int: + """Return context window size for the particular LLM.""" if self in [LLMName.GPT35TURBO0613, LLMName.GPT35TURBO]: return 4096 elif self in [LLMName.GPT4, LLMName.GPT40314, LLMName.GPT40613, LLMName.IGEL, LLMName.BISON001]: From 878c1e68c66ae7742dc8076ca07770b2ded1f206 Mon Sep 17 00:00:00 2001 From: PhilippGawlik Date: Tue, 19 Dec 2023 12:12:54 +0100 Subject: [PATCH 2/2] =?UTF-8?q?Revert=20"Revert=20"=E2=99=BB=EF=B8=8F=20LL?= =?UTF-8?q?MName=20should=20hold=20the=20configuration=20information""?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This reverts commit 276d15afdcf6547d36383c1d96cf1ce9efba1bd3. 
--- src/brdata_rag_tools/models.py | 91 ++++++++++++++++++++-------------- 1 file changed, 53 insertions(+), 38 deletions(-) diff --git a/src/brdata_rag_tools/models.py b/src/brdata_rag_tools/models.py index 1e46cf1..8523df4 100644 --- a/src/brdata_rag_tools/models.py +++ b/src/brdata_rag_tools/models.py @@ -2,7 +2,7 @@ import os from enum import Enum import tiktoken -from typing import List +from typing import Optional from openai import OpenAI import requests @@ -17,6 +17,25 @@ class LLMName(Enum): IGEL = "igel" BISON001 = "text-bison@001" + @staticmethod + def openai_models() -> list["LLMName"]: + """Return a list of available OpenAI models.""" + return [ + LLMName.GPT35TURBO, + LLMName.GPT35TURBO0613, + LLMName.GPT35TURBO1106, + LLMName.GPT40314, + LLMName.GPT40613, + LLMName.GPT4, + ] + + @staticmethod + def google_models() -> list["LLMName"]: + """Return a list of available Google models.""" + return [ + LLMName.BISON001 + ] + @property def max_input_tokens(self) -> int: """Return context window size for the particular LLM.""" @@ -31,6 +50,28 @@ def max_input_tokens(self) -> int: logging.warning(f"Unknown context window size for LLM {self}. Opt for default context window size of 2048.") return 2048 + @property + def auth_token(self) -> str: + """Return the authentication token for the particular LLM.""" + if self in LLMName.openai_models(): + token = os.environ.get("OPENAI_TOKEN") + elif self == LLMName.IGEL: + token = os.environ.get("IGEL_TOKEN") + elif self in LLMName.google_models(): + token = os.environ.get("GOOGLE_TOKEN") + else: + raise ValueError(f"No auth_token provided for model {self.value}.") + + if token is None: + raise EnvironmentError( + ( + f"Please set environment variable for model {self.value}. " + f"See README for more information." 
+ ) + ) + + return token + class Generator: """ @@ -75,19 +116,18 @@ class Generator: """ def __init__(self, - model: LLMName, - auth_token: str = None, - temperature: float = None, - max_new_tokens: int = None, - top_p: float = None, - top_k: int = None, - length_penalty: float = None, - number_of_responses: int = None, - max_token_length: int = None - ): - + model: LLMName, + auth_token: str = None, + temperature: float = None, + max_new_tokens: int = None, + top_p: float = None, + top_k: int = None, + length_penalty: float = None, + number_of_responses: int = None, + max_token_length: int = None + ): self.model: LLMName = model - self.auth_token: str = self.get_token(auth_token) + self.auth_token: str = auth_token or self.model.auth_token self.temperature: float = temperature self.max_new_tokens: int = max_new_tokens self.top_p: float = top_p @@ -96,31 +136,6 @@ def __init__(self, self.number_of_responses = number_of_responses self.max_token_length: int = max_token_length - def get_token(self, token: str) -> str: - """ - Returns the given auth_token or retrieves the appropriate auth_token based on the model value. - - :param token: The auth_token to be used for authentication. - :type token: str - :return: The retrieved auth_token or the given auth_token. - :rtype: str - :raises ValueError: If no auth_token is provided for the model value. - """ - if token is not None: - return token - else: - if self.model.value.startswith("gpt"): - token = os.environ.get("OPENAI_TOKEN") - elif self.model.value == "igel": - token = os.environ.get("IGEL_TOKEN") - if self.model.value.startswith("text-bison"): - token = os.environ.get("GOOGLE_TOKEN") - - if token is None: - raise ValueError(f"No auth_token provided for model {self.model.value}.") - else: - return token - def _estimate_tokens_openai(self, text: str) -> int: encoding = tiktoken.encoding_for_model(self.model.value) tokens = encoding.encode(text)