diff --git a/app.py b/app.py
index c597e10..75a5ad0 100644
--- a/app.py
+++ b/app.py
@@ -83,6 +83,16 @@
 # chat function to trigger inference
 def chatbot(query, history):
+    """Trigger inference for the chatbot.
+
+    Args:
+        query (str): The input query for the chatbot.
+        history (list): The history of previous interactions.
+
+    Returns:
+        str: The response generated by the chatbot for the input query.
+    """
+
     if verbose:
         start_time = time.time()
     response = query_engine.query(query)
diff --git a/faiss_vector_storage.py b/faiss_vector_storage.py
index 252904f..fe767f5 100644
--- a/faiss_vector_storage.py
+++ b/faiss_vector_storage.py
@@ -31,11 +31,28 @@ class FaissEmbeddingStorage:
     def __init__(self, data_dir, dimension=384):
+        """Initialize the storage with the given data directory and embedding dimension.
+
+        Args:
+            data_dir (str): The directory path where the data is located.
+            dimension (int, optional): The embedding dimension. Defaults to 384.
+        """
+
         self.d = dimension
         self.data_dir = data_dir
         self.index = self.initialize_index()

     def initialize_index(self):
+        """Initialize the index for vector storage.
+
+        If the "storage-default" directory exists and is not empty, the index is
+        loaded from the persisted storage and returned. Otherwise, new embeddings
+        are generated, a new index is created, and the index is persisted in the
+        "storage-default" directory.
+
+        Returns:
+            VectorStoreIndex: The initialized or loaded vector store index.
+        """
+
         if os.path.exists("storage-default") and os.listdir("storage-default"):
             print("Using the persisted value")
             vector_store = FaissVectorStore.from_persist_dir("storage-default")
@@ -55,4 +72,12 @@ def initialize_index(self):
         return index

     def get_query_engine(self):
+        """Return a query engine for the index.
+
+        The engine allows efficient querying of the data stored in the index.
+
+        Returns:
+            QueryEngine: A query engine for the index.
+        """
+
         return self.index.as_query_engine()
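For orientation, the documented FaissEmbeddingStorage flow might be exercised roughly as follows. This is a minimal sketch, not part of the patch: the class and method names come from the diff above, while the data directory path and the query string are hypothetical placeholders.

# Minimal usage sketch (assumed): "./dataset" and the query string are
# hypothetical; class and method names are taken from the diff above.
from faiss_vector_storage import FaissEmbeddingStorage

# On the first run this builds and persists the index under "storage-default";
# on later runs it reloads the persisted index (see initialize_index).
storage = FaissEmbeddingStorage(data_dir="./dataset", dimension=384)
query_engine = storage.get_query_engine()
response = query_engine.query("How do I configure the model?")
print(response)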
+ """ + model_kwargs = model_kwargs or {} model_kwargs.update({"n_ctx": context_window, "verbose": verbose}) @@ -191,12 +214,22 @@ def __init__( @classmethod def class_name(cls) -> str: - """Get class name.""" + """ Get the name of the class. + + This function returns the name of the class as a string. + + Returns: + str: The name of the class. + """ return "TrtLlmAPI" @property def metadata(self) -> LLMMetadata: - """LLM metadata.""" + """ Return LLM metadata. + + Returns: + LLMMetadata: An instance of LLMMetadata containing context_window, num_output, and model_name. + """ return LLMMetadata( context_window=self.context_window, num_output=self.max_new_tokens, @@ -205,12 +238,33 @@ def metadata(self) -> LLMMetadata: @llm_chat_callback() def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse: + """ Generate a chat response based on the input messages. + + Args: + messages (Sequence[ChatMessage]): The input messages for the chat. + **kwargs (Any): Additional keyword arguments for customization. + + Returns: + ChatResponse: The response generated based on the input messages. + This function takes a sequence of ChatMessage objects as input and generates a chat response based on these messages. It uses the messages to prompt the chat, completes the prompt, and then converts the completion response to a ChatResponse. + """ + prompt = self.messages_to_prompt(messages) completion_response = self.complete(prompt, formatted=True, **kwargs) return completion_response_to_chat_response(completion_response) @llm_completion_callback() def complete(self, prompt: str, **kwargs: Any) -> CompletionResponse: + """ Generate completion for the given prompt. + + Args: + prompt (str): The input prompt for which completion needs to be generated. + **kwargs (Any): Additional keyword arguments. + + Returns: + CompletionResponse: The response object containing the completion text and raw completion dictionary. + """ + self.generate_kwargs.update({"stream": False}) is_formatted = kwargs.pop("formatted", False) @@ -255,6 +309,18 @@ def complete(self, prompt: str, **kwargs: Any) -> CompletionResponse: def parse_input(self, input_text: str, tokenizer, end_id: int, remove_input_padding: bool): + """ Parse the input text using the provided tokenizer and return the input ids and input lengths. + + Args: + input_text (str): The input text to be tokenized. + tokenizer: The tokenizer object to encode the input text. + end_id (int): The end id for padding the input tokens. + remove_input_padding (bool): A flag indicating whether to remove input padding. + + Returns: + tuple: A tuple containing the input ids and input lengths. + """ + input_tokens = [] input_tokens.append( @@ -275,6 +341,18 @@ def parse_input(self, input_text: str, tokenizer, end_id: int, return input_ids, input_lengths def remove_extra_eos_ids(self, outputs): + """ Remove extra end-of-sequence (EOS) IDs from the outputs. + + This function reverses the 'outputs' list, removes any leading EOS IDs (value 2), and then reverses the list back. + Finally, it appends an EOS ID to the end of the 'outputs' list. + + Args: + outputs (list): The list of output IDs. + + Returns: + list: The modified 'outputs' list with extra EOS IDs removed and an additional EOS ID appended. 
+ """ + outputs.reverse() while outputs and outputs[0] == 2: outputs.pop(0) @@ -283,6 +361,21 @@ def remove_extra_eos_ids(self, outputs): return outputs def get_output(self, output_ids, input_lengths, max_output_len, tokenizer): + """ Generate the output text based on the given output_ids, input_lengths, max_output_len, and tokenizer. + + Args: + output_ids (Tensor): The output ids generated by the model. + input_lengths (Tensor): The lengths of the input sequences. + max_output_len (int): The maximum length of the output text. + tokenizer (Tokenizer): The tokenizer used to decode the output ids. + + Returns: + output_text (str): The decoded output text. + outputs (list): The list of output ids after removing extra eos ids. + This function iterates through the input lengths and generates the output text based on the given parameters. + It decodes the output ids using the provided tokenizer and returns the decoded output text along with the list of output ids. + """ + num_beams = output_ids.size(1) output_text = "" outputs = None @@ -297,10 +390,13 @@ def get_output(self, output_ids, input_lengths, max_output_len, tokenizer): return output_text, outputs def generate_completion_dict(self, text_str): - """ - Generate a dictionary for text completion details. + """ Generate a dictionary for text completion details. + + Args: + text_str (str): The input text for which completion details are to be generated. + Returns: - dict: A dictionary containing completion details. + dict: A dictionary containing completion details, including completion ID, creation time, model name, and text choices. """ completion_id: str = f"cmpl-{str(uuid.uuid4())}" created: int = int(time.time()) @@ -327,4 +423,14 @@ def generate_completion_dict(self, text_str): @llm_completion_callback() def stream_complete(self, prompt: str, **kwargs: Any) -> CompletionResponse: + """ Complete the stream based on the given prompt. + + Args: + prompt (str): The prompt for which the stream needs to be completed. + **kwargs: Additional keyword arguments. + + Returns: + CompletionResponse: The response containing the completed stream. + """ + pass