Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions app.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,16 @@

# chat function to trigger inference
def chatbot(query, history):
""" Function to trigger inference for a chatbot.

Args:
query (str): The input query for the chatbot.
history (list): The history of previous interactions.

Returns:
str: The response generated by the chatbot based on the input query.
"""

if verbose:
start_time = time.time()
response = query_engine.query(query)
Expand Down
25 changes: 25 additions & 0 deletions faiss_vector_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,28 @@
class FaissEmbeddingStorage:

def __init__(self, data_dir, dimension=384):
    """Set up FAISS-backed embedding storage for a document directory.

    Args:
        data_dir (str): Directory containing the documents to index.
        dimension (int): Size of the embedding vectors. Defaults to 384.
    """
    # Record configuration before building the index:
    # NOTE(review): initialize_index() presumably reads self.d and
    # self.data_dir, so these assignments must precede the call — confirm.
    self.data_dir = data_dir
    self.d = dimension
    self.index = self.initialize_index()

def initialize_index(self):
""" Initialize the index for vector storage.

If the "storage-default" directory exists and is not empty, the function loads the index from the persisted storage
and returns it. If the directory does not exist or is empty, the function generates new values, creates a new index,
and persists it in the "storage-default" directory.

Returns:
VectorStoreIndex: The initialized or loaded vector store index.
"""

if os.path.exists("storage-default") and os.listdir("storage-default"):
print("Using the persisted value")
vector_store = FaissVectorStore.from_persist_dir("storage-default")
Expand All @@ -55,4 +72,12 @@ def initialize_index(self):
return index

def get_query_engine(self):
""" Returns a query engine for the index.

This method returns a query engine for the index, allowing for efficient querying of the data stored in the index.

Returns:
QueryEngine: A query engine for the index.
"""

return self.index.as_query_engine()
116 changes: 111 additions & 5 deletions trt_llama_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,29 @@ def __init__(
model_kwargs: Optional[Dict[str, Any]] = None,
verbose: bool = False
) -> None:
""" Initialize the object with the provided parameters.

Args:
model_path (Optional[str]): Path to the model.
engine_name (Optional[str]): Name of the engine.
tokenizer_dir (Optional[str]): Directory of the tokenizer.
temperature (float): Temperature for token generation.
max_new_tokens (int): Maximum number of new tokens to generate.
context_window (int): Context window size.
messages_to_prompt (Optional[Callable]): Function for prompting messages.
completion_to_prompt (Optional[Callable]): Function for prompting completions.
callback_manager (Optional[CallbackManager]): Manager for callbacks.
generate_kwargs (Optional[Dict[str, Any]]): Additional keyword arguments for generation.
model_kwargs (Optional[Dict[str, Any]]): Additional keyword arguments for the model.
verbose (bool): Verbosity flag.

Raises:
ValueError: If the provided model path does not exist.

Note:
The function initializes the object with the provided parameters and sets up the model configuration, tokenizer, and sampling configuration.
"""


model_kwargs = model_kwargs or {}
model_kwargs.update({"n_ctx": context_window, "verbose": verbose})
Expand Down Expand Up @@ -191,12 +214,22 @@ def __init__(

@classmethod
def class_name(cls) -> str:
    """Return the identifier used to register this LLM class.

    Returns:
        str: Always the literal string ``"TrtLlmAPI"``.
    """
    registered_name = "TrtLlmAPI"
    return registered_name

@property
def metadata(self) -> LLMMetadata:
"""LLM metadata."""
""" Return LLM metadata.

Returns:
LLMMetadata: An instance of LLMMetadata containing context_window, num_output, and model_name.
"""
return LLMMetadata(
context_window=self.context_window,
num_output=self.max_new_tokens,
Expand All @@ -205,12 +238,33 @@ def metadata(self) -> LLMMetadata:

@llm_chat_callback()
def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse:
    """Run one chat turn over the given message history.

    The message sequence is flattened into a single prompt via
    ``self.messages_to_prompt``, completed with ``self.complete`` (marked
    as already formatted), and the completion is converted back into a
    chat-style response.

    Args:
        messages (Sequence[ChatMessage]): The conversation so far.
        **kwargs (Any): Extra options forwarded to ``complete``.

    Returns:
        ChatResponse: The model's reply to the conversation.
    """
    rendered_prompt = self.messages_to_prompt(messages)
    completion = self.complete(rendered_prompt, formatted=True, **kwargs)
    return completion_response_to_chat_response(completion)

@llm_completion_callback()
def complete(self, prompt: str, **kwargs: Any) -> CompletionResponse:
""" Generate completion for the given prompt.

Args:
prompt (str): The input prompt for which completion needs to be generated.
**kwargs (Any): Additional keyword arguments.

Returns:
CompletionResponse: The response object containing the completion text and raw completion dictionary.
"""

self.generate_kwargs.update({"stream": False})

is_formatted = kwargs.pop("formatted", False)
Expand Down Expand Up @@ -255,6 +309,18 @@ def complete(self, prompt: str, **kwargs: Any) -> CompletionResponse:

def parse_input(self, input_text: str, tokenizer, end_id: int,
remove_input_padding: bool):
""" Parse the input text using the provided tokenizer and return the input ids and input lengths.

Args:
input_text (str): The input text to be tokenized.
tokenizer: The tokenizer object to encode the input text.
end_id (int): The end id for padding the input tokens.
remove_input_padding (bool): A flag indicating whether to remove input padding.

Returns:
tuple: A tuple containing the input ids and input lengths.
"""

input_tokens = []

input_tokens.append(
Expand All @@ -275,6 +341,18 @@ def parse_input(self, input_text: str, tokenizer, end_id: int,
return input_ids, input_lengths

def remove_extra_eos_ids(self, outputs):
""" Remove extra end-of-sequence (EOS) IDs from the outputs.

This function reverses the 'outputs' list, removes any leading EOS IDs (value 2), and then reverses the list back.
Finally, it appends an EOS ID to the end of the 'outputs' list.

Args:
outputs (list): The list of output IDs.

Returns:
list: The modified 'outputs' list with extra EOS IDs removed and an additional EOS ID appended.
"""

outputs.reverse()
while outputs and outputs[0] == 2:
outputs.pop(0)
Expand All @@ -283,6 +361,21 @@ def remove_extra_eos_ids(self, outputs):
return outputs

def get_output(self, output_ids, input_lengths, max_output_len, tokenizer):
""" Generate the output text based on the given output_ids, input_lengths, max_output_len, and tokenizer.

Args:
output_ids (Tensor): The output ids generated by the model.
input_lengths (Tensor): The lengths of the input sequences.
max_output_len (int): The maximum length of the output text.
tokenizer (Tokenizer): The tokenizer used to decode the output ids.

Returns:
output_text (str): The decoded output text.
outputs (list): The list of output ids after removing extra eos ids.
This function iterates through the input lengths and generates the output text based on the given parameters.
It decodes the output ids using the provided tokenizer and returns the decoded output text along with the list of output ids.
"""

num_beams = output_ids.size(1)
output_text = ""
outputs = None
Expand All @@ -297,10 +390,13 @@ def get_output(self, output_ids, input_lengths, max_output_len, tokenizer):
return output_text, outputs

def generate_completion_dict(self, text_str):
"""
Generate a dictionary for text completion details.
""" Generate a dictionary for text completion details.

Args:
text_str (str): The input text for which completion details are to be generated.

Returns:
dict: A dictionary containing completion details.
dict: A dictionary containing completion details, including completion ID, creation time, model name, and text choices.
"""
completion_id: str = f"cmpl-{str(uuid.uuid4())}"
created: int = int(time.time())
Expand All @@ -327,4 +423,14 @@ def generate_completion_dict(self, text_str):

@llm_completion_callback()
def stream_complete(self, prompt: str, **kwargs: Any) -> CompletionResponse:
    """Stream a completion for the given prompt (not supported).

    Streaming is not implemented for this TensorRT-LLM wrapper. The
    previous stub body was ``pass``, which silently returned ``None``
    even though the signature promises a ``CompletionResponse`` — callers
    iterating the result would fail far from the real cause. Raising
    explicitly surfaces the missing capability immediately.

    Args:
        prompt (str): The prompt for which a streamed completion is requested.
        **kwargs (Any): Additional generation options (ignored).

    Raises:
        NotImplementedError: Always; use ``complete`` for non-streaming
            generation instead.
    """
    raise NotImplementedError(
        "TrtLlmAPI.stream_complete is not implemented; use complete() instead."
    )