Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 54 additions & 38 deletions src/brdata_rag_tools/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import os
from enum import Enum
import tiktoken
from typing import List
from typing import Optional
from openai import OpenAI
import requests

Expand All @@ -17,8 +17,28 @@ class LLMName(Enum):
IGEL = "igel"
BISON001 = "text-bison@001"

@staticmethod
def openai_models() -> list["LLMName"]:
"""Return a list of available OpenAI models."""
return [
LLMName.GPT35TURBO,
LLMName.GPT35TURBO0613,
LLMName.GPT35TURBO1106,
LLMName.GPT40314,
LLMName.GPT40613,
LLMName.GPT4,
]

@staticmethod
def google_models() -> list["LLMName"]:
"""Return a list of available Google models."""
return [
LLMName.BISON001
]

@property
def max_input_tokens(self) -> int:
"""Return context window size for the particular LLM."""
if self in [LLMName.GPT35TURBO0613, LLMName.GPT35TURBO]:
return 4096
elif self in [LLMName.GPT4, LLMName.GPT40314, LLMName.GPT40613, LLMName.IGEL, LLMName.BISON001]:
Expand All @@ -30,6 +50,28 @@ def max_input_tokens(self) -> int:
logging.warning(f"Unknown context window size for LLM {self}. Opt for default context window size of 2048.")
return 2048

@property
def auth_token(self) -> str:
"""Return the authentication token for the particular LLM."""
if self in LLMName.openai_models():
token = os.environ.get("OPENAI_TOKEN")
elif self == LLMName.IGEL:
token = os.environ.get("IGEL_TOKEN")
elif self in LLMName.google_models():
token = os.environ.get("GOOGLE_TOKEN")
else:
raise ValueError(f"No auth_token provided for model {self.value}.")

if token is None:
raise EnvironmentError(
(
f"Please set enviornment variable for model {self.value}. "
f"See README for more information."
)
)

return token


class Generator:
"""
Expand Down Expand Up @@ -74,19 +116,18 @@ class Generator:
"""

def __init__(self,
model: LLMName,
auth_token: str = None,
temperature: float = None,
max_new_tokens: int = None,
top_p: float = None,
top_k: int = None,
length_penalty: float = None,
number_of_responses: int = None,
max_token_length: int = None
):

model: LLMName,
auth_token: str = None,
temperature: float = None,
max_new_tokens: int = None,
top_p: float = None,
top_k: int = None,
length_penalty: float = None,
number_of_responses: int = None,
max_token_length: int = None
):
self.model: LLMName = model
self.auth_token: str = self.get_token(auth_token)
self.auth_token: str = auth_token or self.model.auth_token
self.temperature: float = temperature
self.max_new_tokens: int = max_new_tokens
self.top_p: float = top_p
Expand All @@ -95,31 +136,6 @@ def __init__(self,
self.number_of_responses = number_of_responses
self.max_token_length: int = max_token_length

def get_token(self, token: str) -> str:
"""
Returns the given auth_token or retrieves the appropriate auth_token based on the model value.

:param token: The auth_token to be used for authentication.
:type token: str
:return: The retrieved auth_token or the given auth_token.
:rtype: str
:raises ValueError: If no auth_token is provided for the model value.
"""
if token is not None:
return token
else:
if self.model.value.startswith("gpt"):
token = os.environ.get("OPENAI_TOKEN")
elif self.model.value == "igel":
token = os.environ.get("IGEL_TOKEN")
if self.model.value.startswith("text-bison"):
token = os.environ.get("GOOGLE_TOKEN")

if token is None:
raise ValueError(f"No auth_token provided for model {self.model.value}.")
else:
return token

def _estimate_tokens_openai(self, text: str) -> int:
encoding = tiktoken.encoding_for_model(self.model.value)
tokens = encoding.encode(text)
Expand Down