diff --git a/src/brdata_rag_tools/models.py b/src/brdata_rag_tools/models.py index e014ffd..8523df4 100644 --- a/src/brdata_rag_tools/models.py +++ b/src/brdata_rag_tools/models.py @@ -2,7 +2,7 @@ import os from enum import Enum import tiktoken -from typing import List +from typing import Optional from openai import OpenAI import requests @@ -17,8 +17,28 @@ class LLMName(Enum): IGEL = "igel" BISON001 = "text-bison@001" + @staticmethod + def openai_models() -> list["LLMName"]: + """Return a list of available OpenAI models.""" + return [ + LLMName.GPT35TURBO, + LLMName.GPT35TURBO0613, + LLMName.GPT35TURBO1106, + LLMName.GPT40314, + LLMName.GPT40613, + LLMName.GPT4, + ] + + @staticmethod + def google_models() -> list["LLMName"]: + """Return a list of available Google models.""" + return [ + LLMName.BISON001 + ] + @property def max_input_tokens(self) -> int: + """Return context window size for the particular LLM.""" if self in [LLMName.GPT35TURBO0613, LLMName.GPT35TURBO]: return 4096 elif self in [LLMName.GPT4, LLMName.GPT40314, LLMName.GPT40613, LLMName.IGEL, LLMName.BISON001]: @@ -30,6 +50,28 @@ def max_input_tokens(self) -> int: logging.warning(f"Unknown context window size for LLM {self}. Opt for default context window size of 2048.") return 2048 + @property + def auth_token(self) -> str: + """Return the authentication token for the particular LLM.""" + if self in LLMName.openai_models(): + token = os.environ.get("OPENAI_TOKEN") + elif self == LLMName.IGEL: + token = os.environ.get("IGEL_TOKEN") + elif self in LLMName.google_models(): + token = os.environ.get("GOOGLE_TOKEN") + else: + raise ValueError(f"No auth_token provided for model {self.value}.") + + if token is None: + raise EnvironmentError( + ( + f"Please set enviornment variable for model {self.value}. " + f"See README for more information." + ) + ) + + return token + class Generator: """ @@ -74,19 +116,18 @@ class Generator: """ def __init__(self, - model: LLMName, - auth_token: str = None, - temperature: float = None, - max_new_tokens: int = None, - top_p: float = None, - top_k: int = None, - length_penalty: float = None, - number_of_responses: int = None, - max_token_length: int = None - ): - + model: LLMName, + auth_token: str = None, + temperature: float = None, + max_new_tokens: int = None, + top_p: float = None, + top_k: int = None, + length_penalty: float = None, + number_of_responses: int = None, + max_token_length: int = None + ): self.model: LLMName = model - self.auth_token: str = self.get_token(auth_token) + self.auth_token: str = auth_token or self.model.auth_token self.temperature: float = temperature self.max_new_tokens: int = max_new_tokens self.top_p: float = top_p @@ -95,31 +136,6 @@ def __init__(self, self.number_of_responses = number_of_responses self.max_token_length: int = max_token_length - def get_token(self, token: str) -> str: - """ - Returns the given auth_token or retrieves the appropriate auth_token based on the model value. - - :param token: The auth_token to be used for authentication. - :type token: str - :return: The retrieved auth_token or the given auth_token. - :rtype: str - :raises ValueError: If no auth_token is provided for the model value. - """ - if token is not None: - return token - else: - if self.model.value.startswith("gpt"): - token = os.environ.get("OPENAI_TOKEN") - elif self.model.value == "igel": - token = os.environ.get("IGEL_TOKEN") - if self.model.value.startswith("text-bison"): - token = os.environ.get("GOOGLE_TOKEN") - - if token is None: - raise ValueError(f"No auth_token provided for model {self.model.value}.") - else: - return token - def _estimate_tokens_openai(self, text: str) -> int: encoding = tiktoken.encoding_for_model(self.model.value) tokens = encoding.encode(text)