From 7a36af6e8f904023569977e503c72124b6d4c7cf Mon Sep 17 00:00:00 2001 From: John Aziz Date: Sat, 20 Jul 2024 16:32:57 +0000 Subject: [PATCH 01/16] refactor and add streaming functions --- src/backend/fastapi_app/rag_advanced.py | 170 ++++++++++++++++++----- src/backend/fastapi_app/rag_simple.py | 175 ++++++++++++++++++++---- 2 files changed, 286 insertions(+), 59 deletions(-) diff --git a/src/backend/fastapi_app/rag_advanced.py b/src/backend/fastapi_app/rag_advanced.py index 024a5fbd..20cacb15 100644 --- a/src/backend/fastapi_app/rag_advanced.py +++ b/src/backend/fastapi_app/rag_advanced.py @@ -1,19 +1,17 @@ -import pathlib from collections.abc import AsyncGenerator -from typing import ( - Any, -) +from typing import Any -from openai import AsyncAzureOpenAI, AsyncOpenAI -from openai.types.chat import ChatCompletion, ChatCompletionMessageParam +from openai import AsyncAzureOpenAI, AsyncOpenAI, AsyncStream +from openai.types.chat import ChatCompletion, ChatCompletionChunk, ChatCompletionMessageParam from openai_messages_token_helper import build_messages, get_token_limit -from .api_models import Message, RAGContext, RetrievalResponse, ThoughtStep -from .postgres_searcher import PostgresSearcher -from .query_rewriter import build_search_function, extract_search_arguments +from fastapi_app.api_models import Message, RAGContext, RetrievalResponse, ThoughtStep +from fastapi_app.postgres_searcher import PostgresSearcher +from fastapi_app.query_rewriter import build_search_function, extract_search_arguments +from fastapi_app.rag_simple import RAGChatBase -class AdvancedRAGChat: +class AdvancedRAGChat(RAGChatBase): def __init__( self, *, @@ -27,29 +25,21 @@ def __init__( self.chat_model = chat_model self.chat_deployment = chat_deployment self.chat_token_limit = get_token_limit(chat_model, default_to_minimum=True) - current_dir = pathlib.Path(__file__).parent - self.query_prompt_template = open(current_dir / "prompts/query.txt").read() - self.answer_prompt_template = open(current_dir / "prompts/answer.txt").read() async def run( - self, messages: list[ChatCompletionMessageParam], overrides: dict[str, Any] = {} - ) -> RetrievalResponse | AsyncGenerator[dict[str, Any], None]: - text_search = overrides.get("retrieval_mode") in ["text", "hybrid", None] - vector_search = overrides.get("retrieval_mode") in ["vectors", "hybrid", None] - top = overrides.get("top", 3) - - original_user_query = messages[-1]["content"] - if not isinstance(original_user_query, str): - raise ValueError("The most recent message content must be a string.") - past_messages = messages[:-1] + self, + messages: list[ChatCompletionMessageParam], + overrides: dict[str, Any] = {}, + ) -> RetrievalResponse: + chat_params = self.get_params(messages, overrides) # Generate an optimized keyword search query based on the chat history and the last question query_response_token_limit = 500 query_messages: list[ChatCompletionMessageParam] = build_messages( model=self.chat_model, system_prompt=self.query_prompt_template, - new_user_content=original_user_query, - past_messages=past_messages, + new_user_content=chat_params.original_user_query, + past_messages=chat_params.past_messages, max_tokens=self.chat_token_limit - query_response_token_limit, # TODO: count functions fallback_to_default=True, ) @@ -65,14 +55,14 @@ async def run( tool_choice="auto", ) - query_text, filters = extract_search_arguments(original_user_query, chat_completion) + query_text, filters = extract_search_arguments(chat_params.original_user_query, chat_completion) # Retrieve relevant items from the database with the GPT optimized query results = await self.searcher.search_and_embed( query_text, - top=top, - enable_vector_search=vector_search, - enable_text_search=text_search, + top=chat_params.top, + enable_vector_search=chat_params.enable_vector_search, + enable_text_search=chat_params.enable_text_search, filters=filters, ) @@ -84,8 +74,8 @@ async def run( contextual_messages: list[ChatCompletionMessageParam] = build_messages( model=self.chat_model, system_prompt=overrides.get("prompt_template") or self.answer_prompt_template, - new_user_content=original_user_query + "\n\nSources:\n" + content, - past_messages=past_messages, + new_user_content=chat_params.original_user_query + "\n\nSources:\n" + content, + past_messages=chat_params.past_messages, max_tokens=self.chat_token_limit - response_token_limit, fallback_to_default=True, ) @@ -99,6 +89,7 @@ async def run( n=1, stream=False, ) + first_choice_message = chat_completion_response.choices[0].message return RetrievalResponse( @@ -119,9 +110,9 @@ async def run( title="Search using generated search arguments", description=query_text, props={ - "top": top, - "vector_search": vector_search, - "text_search": text_search, + "top": chat_params.top, + "vector_search": chat_params.enable_vector_search, + "text_search": chat_params.enable_text_search, "filters": filters, }, ), @@ -141,3 +132,114 @@ async def run( ], ), ) + + async def run_stream( + self, + messages: list[ChatCompletionMessageParam], + overrides: dict[str, Any] = {}, + ) -> AsyncGenerator[RetrievalResponse | Message, None]: + chat_params = self.get_params(messages, overrides) + + # Generate an optimized keyword search query based on the chat history and the last question + query_response_token_limit = 500 + query_messages: list[ChatCompletionMessageParam] = build_messages( + model=self.chat_model, + system_prompt=self.query_prompt_template, + new_user_content=chat_params.original_user_query, + past_messages=chat_params.past_messages, + max_tokens=self.chat_token_limit - query_response_token_limit, # TODO: count functions + fallback_to_default=True, + ) + + chat_completion: ChatCompletion = await self.openai_chat_client.chat.completions.create( + messages=query_messages, + # Azure OpenAI takes the deployment name as the model name + model=self.chat_deployment if self.chat_deployment else self.chat_model, + temperature=0.0, # Minimize creativity for search query generation + max_tokens=query_response_token_limit, # Setting too low risks malformed JSON, too high risks performance + n=1, + tools=build_search_function(), + tool_choice="auto", + ) + + query_text, filters = extract_search_arguments(chat_params.original_user_query, chat_completion) + + # Retrieve relevant items from the database with the GPT optimized query + results = await self.searcher.search_and_embed( + query_text, + top=chat_params.top, + enable_vector_search=chat_params.enable_vector_search, + enable_text_search=chat_params.enable_text_search, + filters=filters, + ) + + sources_content = [f"[{(item.id)}]:{item.to_str_for_rag()}\n\n" for item in results] + content = "\n".join(sources_content) + + # Generate a contextual and content specific answer using the search results and chat history + response_token_limit = 1024 + contextual_messages: list[ChatCompletionMessageParam] = build_messages( + model=self.chat_model, + system_prompt=overrides.get("prompt_template") or self.answer_prompt_template, + new_user_content=chat_params.original_user_query + "\n\nSources:\n" + content, + past_messages=chat_params.past_messages, + max_tokens=self.chat_token_limit - response_token_limit, + fallback_to_default=True, + ) + + chat_completion_async_stream: AsyncStream[ + ChatCompletionChunk + ] = await self.openai_chat_client.chat.completions.create( + # Azure OpenAI takes the deployment name as the model name + model=self.chat_deployment if self.chat_deployment else self.chat_model, + messages=contextual_messages, + temperature=overrides.get("temperature", 0.3), + max_tokens=response_token_limit, + n=1, + stream=True, + ) + + yield RetrievalResponse( + message=Message(content="", role="assistant"), + context=RAGContext( + data_points={item.id: item.to_dict() for item in results}, + thoughts=[ + ThoughtStep( + title="Prompt to generate search arguments", + description=[str(message) for message in query_messages], + props=( + {"model": self.chat_model, "deployment": self.chat_deployment} + if self.chat_deployment + else {"model": self.chat_model} + ), + ), + ThoughtStep( + title="Search using generated search arguments", + description=query_text, + props={ + "top": chat_params.top, + "vector_search": chat_params.enable_vector_search, + "text_search": chat_params.enable_text_search, + "filters": filters, + }, + ), + ThoughtStep( + title="Search results", + description=[result.to_dict() for result in results], + ), + ThoughtStep( + title="Prompt to generate answer", + description=[str(message) for message in contextual_messages], + props=( + {"model": self.chat_model, "deployment": self.chat_deployment} + if self.chat_deployment + else {"model": self.chat_model} + ), + ), + ], + ), + ) + + async for response_chunk in chat_completion_async_stream: + yield Message(content=str(response_chunk.choices[0].delta.content), role="assistant") + return diff --git a/src/backend/fastapi_app/rag_simple.py b/src/backend/fastapi_app/rag_simple.py index f8db974e..e323ad7c 100644 --- a/src/backend/fastapi_app/rag_simple.py +++ b/src/backend/fastapi_app/rag_simple.py @@ -1,16 +1,69 @@ import pathlib +from abc import ABC, abstractmethod from collections.abc import AsyncGenerator from typing import Any -from openai import AsyncAzureOpenAI, AsyncOpenAI -from openai.types.chat import ChatCompletion, ChatCompletionMessageParam +from openai import AsyncAzureOpenAI, AsyncOpenAI, AsyncStream +from openai.types.chat import ChatCompletion, ChatCompletionChunk, ChatCompletionMessageParam from openai_messages_token_helper import build_messages, get_token_limit +from pydantic import BaseModel -from .api_models import Message, RAGContext, RetrievalResponse, ThoughtStep -from .postgres_searcher import PostgresSearcher +from fastapi_app.api_models import Message, RAGContext, RetrievalResponse, ThoughtStep +from fastapi_app.postgres_searcher import PostgresSearcher -class SimpleRAGChat: +class ChatParams(BaseModel): + top: int + temperature: float + enable_text_search: bool + enable_vector_search: bool + original_user_query: str + past_messages: list[ChatCompletionMessageParam] + + +class RAGChatBase(ABC): + current_dir = pathlib.Path(__file__).parent + query_prompt_template = open(current_dir / "prompts/query.txt").read() + answer_prompt_template = open(current_dir / "prompts/answer.txt").read() + + def get_params(self, messages: list[ChatCompletionMessageParam], overrides: dict[str, Any]) -> ChatParams: + enable_text_search = overrides.get("retrieval_mode") in ["text", "hybrid", None] + enable_vector_search = overrides.get("retrieval_mode") in ["vectors", "hybrid", None] + top: int = overrides.get("top", 3) + temperature: float = overrides.get("temperature", 0.3) + original_user_query = messages[-1]["content"] + if not isinstance(original_user_query, str): + raise ValueError("The most recent message content must be a string.") + past_messages = messages[:-1] + return ChatParams( + top=top, + temperature=temperature, + enable_text_search=enable_text_search, + enable_vector_search=enable_vector_search, + original_user_query=original_user_query, + past_messages=past_messages, + ) + + @abstractmethod + async def run( + self, + messages: list[ChatCompletionMessageParam], + overrides: dict[str, Any] = {}, + ) -> RetrievalResponse: + raise NotImplementedError + + @abstractmethod + async def run_stream( + self, + messages: list[ChatCompletionMessageParam], + overrides: dict[str, Any] = {}, + ) -> AsyncGenerator[RetrievalResponse | Message, None]: + raise NotImplementedError + if False: + yield 0 + + +class SimpleRAGChat(RAGChatBase): def __init__( self, *, @@ -24,24 +77,20 @@ def __init__( self.chat_model = chat_model self.chat_deployment = chat_deployment self.chat_token_limit = get_token_limit(chat_model, default_to_minimum=True) - current_dir = pathlib.Path(__file__).parent - self.answer_prompt_template = open(current_dir / "prompts/answer.txt").read() async def run( - self, messages: list[ChatCompletionMessageParam], overrides: dict[str, Any] = {} - ) -> RetrievalResponse | AsyncGenerator[dict[str, Any], None]: - text_search = overrides.get("retrieval_mode") in ["text", "hybrid", None] - vector_search = overrides.get("retrieval_mode") in ["vectors", "hybrid", None] - top = overrides.get("top", 3) - - original_user_query = messages[-1]["content"] - if not isinstance(original_user_query, str): - raise ValueError("The most recent message content must be a string.") - past_messages = messages[:-1] + self, + messages: list[ChatCompletionMessageParam], + overrides: dict[str, Any] = {}, + ) -> RetrievalResponse: + chat_params = self.get_params(messages, overrides) # Retrieve relevant items from the database results = await self.searcher.search_and_embed( - original_user_query, top=top, enable_vector_search=vector_search, enable_text_search=text_search + chat_params.original_user_query, + top=chat_params.top, + enable_vector_search=chat_params.enable_vector_search, + enable_text_search=chat_params.enable_text_search, ) sources_content = [f"[{(item.id)}]:{item.to_str_for_rag()}\n\n" for item in results] @@ -52,8 +101,8 @@ async def run( contextual_messages: list[ChatCompletionMessageParam] = build_messages( model=self.chat_model, system_prompt=overrides.get("prompt_template") or self.answer_prompt_template, - new_user_content=original_user_query + "\n\nSources:\n" + content, - past_messages=past_messages, + new_user_content=chat_params.original_user_query + "\n\nSources:\n" + content, + past_messages=chat_params.past_messages, max_tokens=self.chat_token_limit - response_token_limit, fallback_to_default=True, ) @@ -62,11 +111,12 @@ async def run( # Azure OpenAI takes the deployment name as the model name model=self.chat_deployment if self.chat_deployment else self.chat_model, messages=contextual_messages, - temperature=overrides.get("temperature", 0.3), + temperature=chat_params.temperature, max_tokens=response_token_limit, n=1, stream=False, ) + first_choice_message = chat_completion_response.choices[0].message return RetrievalResponse( @@ -76,11 +126,83 @@ async def run( thoughts=[ ThoughtStep( title="Search query for database", - description=original_user_query if text_search else None, + description=chat_params.original_user_query if chat_params.enable_text_search else None, + props={ + "top": chat_params.top, + "vector_search": chat_params.enable_vector_search, + "text_search": chat_params.enable_text_search, + }, + ), + ThoughtStep( + title="Search results", + description=[result.to_dict() for result in results], + ), + ThoughtStep( + title="Prompt to generate answer", + description=[str(message) for message in contextual_messages], + props=( + {"model": self.chat_model, "deployment": self.chat_deployment} + if self.chat_deployment + else {"model": self.chat_model} + ), + ), + ], + ), + ) + + async def run_stream( + self, + messages: list[ChatCompletionMessageParam], + overrides: dict[str, Any] = {}, + ) -> AsyncGenerator[RetrievalResponse | Message, None]: + chat_params = self.get_params(messages, overrides) + + # Retrieve relevant items from the database + results = await self.searcher.search_and_embed( + chat_params.original_user_query, + top=chat_params.top, + enable_vector_search=chat_params.enable_vector_search, + enable_text_search=chat_params.enable_text_search, + ) + + sources_content = [f"[{(item.id)}]:{item.to_str_for_rag()}\n\n" for item in results] + content = "\n".join(sources_content) + + # Generate a contextual and content specific answer using the search results and chat history + response_token_limit = 1024 + contextual_messages: list[ChatCompletionMessageParam] = build_messages( + model=self.chat_model, + system_prompt=overrides.get("prompt_template") or self.answer_prompt_template, + new_user_content=chat_params.original_user_query + "\n\nSources:\n" + content, + past_messages=chat_params.past_messages, + max_tokens=self.chat_token_limit - response_token_limit, + fallback_to_default=True, + ) + + chat_completion_async_stream: AsyncStream[ + ChatCompletionChunk + ] = await self.openai_chat_client.chat.completions.create( + # Azure OpenAI takes the deployment name as the model name + model=self.chat_deployment if self.chat_deployment else self.chat_model, + messages=contextual_messages, + temperature=overrides.get("temperature", 0.3), + max_tokens=response_token_limit, + n=1, + stream=True, + ) + + yield RetrievalResponse( + message=Message(content="", role="assistant"), + context=RAGContext( + data_points={item.id: item.to_dict() for item in results}, + thoughts=[ + ThoughtStep( + title="Search query for database", + description=chat_params.original_user_query if chat_params.enable_text_search else None, props={ - "top": top, - "vector_search": vector_search, - "text_search": text_search, + "top": chat_params.top, + "vector_search": chat_params.enable_vector_search, + "text_search": chat_params.enable_text_search, }, ), ThoughtStep( @@ -99,3 +221,6 @@ async def run( ], ), ) + async for response_chunk in chat_completion_async_stream: + yield Message(content=str(response_chunk.choices[0].delta.content), role="assistant") + return From 172104c7a470194ef66d96f16b99fa189f6764c5 Mon Sep 17 00:00:00 2001 From: John Aziz Date: Sat, 20 Jul 2024 17:40:47 +0000 Subject: [PATCH 02/16] refactor code --- src/backend/fastapi_app/rag_advanced.py | 98 +++++++++++-------------- src/backend/fastapi_app/rag_simple.py | 80 +++++++++++--------- 2 files changed, 86 insertions(+), 92 deletions(-) diff --git a/src/backend/fastapi_app/rag_advanced.py b/src/backend/fastapi_app/rag_advanced.py index 20cacb15..edcb52d5 100644 --- a/src/backend/fastapi_app/rag_advanced.py +++ b/src/backend/fastapi_app/rag_advanced.py @@ -6,9 +6,10 @@ from openai_messages_token_helper import build_messages, get_token_limit from fastapi_app.api_models import Message, RAGContext, RetrievalResponse, ThoughtStep +from fastapi_app.postgres_models import Item from fastapi_app.postgres_searcher import PostgresSearcher from fastapi_app.query_rewriter import build_search_function, extract_search_arguments -from fastapi_app.rag_simple import RAGChatBase +from fastapi_app.rag_simple import ChatParams, RAGChatBase class AdvancedRAGChat(RAGChatBase): @@ -26,15 +27,10 @@ def __init__( self.chat_deployment = chat_deployment self.chat_token_limit = get_token_limit(chat_model, default_to_minimum=True) - async def run( - self, - messages: list[ChatCompletionMessageParam], - overrides: dict[str, Any] = {}, - ) -> RetrievalResponse: - chat_params = self.get_params(messages, overrides) - - # Generate an optimized keyword search query based on the chat history and the last question - query_response_token_limit = 500 + async def generate_search_query( + self, chat_params: ChatParams, query_response_token_limit: int + ) -> tuple[list[ChatCompletionMessageParam], Any | str | None, list]: + """Generate an optimized keyword search query based on the chat history and the last question""" query_messages: list[ChatCompletionMessageParam] = build_messages( model=self.chat_model, system_prompt=self.query_prompt_template, @@ -57,6 +53,12 @@ async def run( query_text, filters = extract_search_arguments(chat_params.original_user_query, chat_completion) + return query_messages, query_text, filters + + async def retreive_and_build_context( + self, chat_params: ChatParams, query_text: str | Any | None, filters: list + ) -> tuple[list[ChatCompletionMessageParam], list[Item]]: + """Retrieve relevant items from the database and build a context for the chat model.""" # Retrieve relevant items from the database with the GPT optimized query results = await self.searcher.search_and_embed( query_text, @@ -70,22 +72,40 @@ async def run( content = "\n".join(sources_content) # Generate a contextual and content specific answer using the search results and chat history - response_token_limit = 1024 contextual_messages: list[ChatCompletionMessageParam] = build_messages( model=self.chat_model, - system_prompt=overrides.get("prompt_template") or self.answer_prompt_template, + system_prompt=chat_params.prompt_template, new_user_content=chat_params.original_user_query + "\n\nSources:\n" + content, past_messages=chat_params.past_messages, - max_tokens=self.chat_token_limit - response_token_limit, + max_tokens=self.chat_token_limit - chat_params.response_token_limit, fallback_to_default=True, ) + return contextual_messages, results + + async def run( + self, + messages: list[ChatCompletionMessageParam], + overrides: dict[str, Any] = {}, + ) -> RetrievalResponse: + chat_params = self.get_params(messages, overrides) + + # Generate an optimized keyword search query based on the chat history and the last question + query_messages, query_text, filters = await self.generate_search_query( + chat_params=chat_params, query_response_token_limit=500 + ) + + # Retrieve relevant items from the database with the GPT optimized query + # Generate a contextual and content specific answer using the search results and chat history + contextual_messages, results = await self.retreive_and_build_context( + chat_params=chat_params, query_text=query_text, filters=filters + ) chat_completion_response: ChatCompletion = await self.openai_chat_client.chat.completions.create( # Azure OpenAI takes the deployment name as the model name model=self.chat_deployment if self.chat_deployment else self.chat_model, messages=contextual_messages, - temperature=overrides.get("temperature", 0.3), - max_tokens=response_token_limit, + temperature=chat_params.temperature, + max_tokens=chat_params.response_token_limit, n=1, stream=False, ) @@ -141,50 +161,14 @@ async def run_stream( chat_params = self.get_params(messages, overrides) # Generate an optimized keyword search query based on the chat history and the last question - query_response_token_limit = 500 - query_messages: list[ChatCompletionMessageParam] = build_messages( - model=self.chat_model, - system_prompt=self.query_prompt_template, - new_user_content=chat_params.original_user_query, - past_messages=chat_params.past_messages, - max_tokens=self.chat_token_limit - query_response_token_limit, # TODO: count functions - fallback_to_default=True, + query_messages, query_text, filters = await self.generate_search_query( + chat_params=chat_params, query_response_token_limit=500 ) - chat_completion: ChatCompletion = await self.openai_chat_client.chat.completions.create( - messages=query_messages, - # Azure OpenAI takes the deployment name as the model name - model=self.chat_deployment if self.chat_deployment else self.chat_model, - temperature=0.0, # Minimize creativity for search query generation - max_tokens=query_response_token_limit, # Setting too low risks malformed JSON, too high risks performance - n=1, - tools=build_search_function(), - tool_choice="auto", - ) - - query_text, filters = extract_search_arguments(chat_params.original_user_query, chat_completion) - # Retrieve relevant items from the database with the GPT optimized query - results = await self.searcher.search_and_embed( - query_text, - top=chat_params.top, - enable_vector_search=chat_params.enable_vector_search, - enable_text_search=chat_params.enable_text_search, - filters=filters, - ) - - sources_content = [f"[{(item.id)}]:{item.to_str_for_rag()}\n\n" for item in results] - content = "\n".join(sources_content) - # Generate a contextual and content specific answer using the search results and chat history - response_token_limit = 1024 - contextual_messages: list[ChatCompletionMessageParam] = build_messages( - model=self.chat_model, - system_prompt=overrides.get("prompt_template") or self.answer_prompt_template, - new_user_content=chat_params.original_user_query + "\n\nSources:\n" + content, - past_messages=chat_params.past_messages, - max_tokens=self.chat_token_limit - response_token_limit, - fallback_to_default=True, + contextual_messages, results = await self.retreive_and_build_context( + chat_params=chat_params, query_text=query_text, filters=filters ) chat_completion_async_stream: AsyncStream[ @@ -193,8 +177,8 @@ async def run_stream( # Azure OpenAI takes the deployment name as the model name model=self.chat_deployment if self.chat_deployment else self.chat_model, messages=contextual_messages, - temperature=overrides.get("temperature", 0.3), - max_tokens=response_token_limit, + temperature=chat_params.temperature, + max_tokens=chat_params.response_token_limit, n=1, stream=True, ) diff --git a/src/backend/fastapi_app/rag_simple.py b/src/backend/fastapi_app/rag_simple.py index e323ad7c..36ee1913 100644 --- a/src/backend/fastapi_app/rag_simple.py +++ b/src/backend/fastapi_app/rag_simple.py @@ -9,16 +9,19 @@ from pydantic import BaseModel from fastapi_app.api_models import Message, RAGContext, RetrievalResponse, ThoughtStep +from fastapi_app.postgres_models import Item from fastapi_app.postgres_searcher import PostgresSearcher class ChatParams(BaseModel): - top: int - temperature: float + top: int = 3 + temperature: float = 0.3 + response_token_limit: int = 1024 enable_text_search: bool enable_vector_search: bool original_user_query: str past_messages: list[ChatCompletionMessageParam] + prompt_template: str class RAGChatBase(ABC): @@ -27,17 +30,24 @@ class RAGChatBase(ABC): answer_prompt_template = open(current_dir / "prompts/answer.txt").read() def get_params(self, messages: list[ChatCompletionMessageParam], overrides: dict[str, Any]) -> ChatParams: - enable_text_search = overrides.get("retrieval_mode") in ["text", "hybrid", None] - enable_vector_search = overrides.get("retrieval_mode") in ["vectors", "hybrid", None] top: int = overrides.get("top", 3) temperature: float = overrides.get("temperature", 0.3) + response_token_limit = 1024 + prompt_template = overrides.get("prompt_template") or self.answer_prompt_template + + enable_text_search = overrides.get("retrieval_mode") in ["text", "hybrid", None] + enable_vector_search = overrides.get("retrieval_mode") in ["vectors", "hybrid", None] + original_user_query = messages[-1]["content"] if not isinstance(original_user_query, str): raise ValueError("The most recent message content must be a string.") past_messages = messages[:-1] + return ChatParams( top=top, temperature=temperature, + response_token_limit=response_token_limit, + prompt_template=prompt_template, enable_text_search=enable_text_search, enable_vector_search=enable_vector_search, original_user_query=original_user_query, @@ -52,6 +62,15 @@ async def run( ) -> RetrievalResponse: raise NotImplementedError + @abstractmethod + async def retreive_and_build_context( + self, + chat_params: ChatParams, + *args, + **kwargs, + ) -> tuple[list[ChatCompletionMessageParam], list[Item]]: + raise NotImplementedError + @abstractmethod async def run_stream( self, @@ -78,12 +97,10 @@ def __init__( self.chat_deployment = chat_deployment self.chat_token_limit = get_token_limit(chat_model, default_to_minimum=True) - async def run( - self, - messages: list[ChatCompletionMessageParam], - overrides: dict[str, Any] = {}, - ) -> RetrievalResponse: - chat_params = self.get_params(messages, overrides) + async def retreive_and_build_context( + self, chat_params: ChatParams + ) -> tuple[list[ChatCompletionMessageParam], list[Item]]: + """Retrieve relevant items from the database and build a context for the chat model.""" # Retrieve relevant items from the database results = await self.searcher.search_and_embed( @@ -97,22 +114,33 @@ async def run( content = "\n".join(sources_content) # Generate a contextual and content specific answer using the search results and chat history - response_token_limit = 1024 contextual_messages: list[ChatCompletionMessageParam] = build_messages( model=self.chat_model, - system_prompt=overrides.get("prompt_template") or self.answer_prompt_template, + system_prompt=chat_params.prompt_template, new_user_content=chat_params.original_user_query + "\n\nSources:\n" + content, past_messages=chat_params.past_messages, - max_tokens=self.chat_token_limit - response_token_limit, + max_tokens=self.chat_token_limit - chat_params.response_token_limit, fallback_to_default=True, ) + return contextual_messages, results + + async def run( + self, + messages: list[ChatCompletionMessageParam], + overrides: dict[str, Any] = {}, + ) -> RetrievalResponse: + chat_params = self.get_params(messages, overrides) + + # Retrieve relevant items from the database + # Generate a contextual and content specific answer using the search results and chat history + contextual_messages, results = await self.retreive_and_build_context(chat_params=chat_params) chat_completion_response: ChatCompletion = await self.openai_chat_client.chat.completions.create( # Azure OpenAI takes the deployment name as the model name model=self.chat_deployment if self.chat_deployment else self.chat_model, messages=contextual_messages, temperature=chat_params.temperature, - max_tokens=response_token_limit, + max_tokens=chat_params.response_token_limit, n=1, stream=False, ) @@ -158,26 +186,8 @@ async def run_stream( chat_params = self.get_params(messages, overrides) # Retrieve relevant items from the database - results = await self.searcher.search_and_embed( - chat_params.original_user_query, - top=chat_params.top, - enable_vector_search=chat_params.enable_vector_search, - enable_text_search=chat_params.enable_text_search, - ) - - sources_content = [f"[{(item.id)}]:{item.to_str_for_rag()}\n\n" for item in results] - content = "\n".join(sources_content) - # Generate a contextual and content specific answer using the search results and chat history - response_token_limit = 1024 - contextual_messages: list[ChatCompletionMessageParam] = build_messages( - model=self.chat_model, - system_prompt=overrides.get("prompt_template") or self.answer_prompt_template, - new_user_content=chat_params.original_user_query + "\n\nSources:\n" + content, - past_messages=chat_params.past_messages, - max_tokens=self.chat_token_limit - response_token_limit, - fallback_to_default=True, - ) + contextual_messages, results = await self.retreive_and_build_context(chat_params=chat_params) chat_completion_async_stream: AsyncStream[ ChatCompletionChunk @@ -185,8 +195,8 @@ async def run_stream( # Azure OpenAI takes the deployment name as the model name model=self.chat_deployment if self.chat_deployment else self.chat_model, messages=contextual_messages, - temperature=overrides.get("temperature", 0.3), - max_tokens=response_token_limit, + temperature=chat_params.temperature, + max_tokens=chat_params.response_token_limit, n=1, stream=True, ) From 24ac36b2e4e490c396a8cc090716e7c5f4ee0705 Mon Sep 17 00:00:00 2001 From: John Aziz Date: Sat, 20 Jul 2024 20:26:49 +0000 Subject: [PATCH 03/16] implement chat/stream/ endpoint and fix empty choices error --- src/backend/fastapi_app/rag_simple.py | 4 +- src/backend/fastapi_app/routes/api_routes.py | 55 +++++++++++++++++++- 2 files changed, 57 insertions(+), 2 deletions(-) diff --git a/src/backend/fastapi_app/rag_simple.py b/src/backend/fastapi_app/rag_simple.py index 36ee1913..e6baa18c 100644 --- a/src/backend/fastapi_app/rag_simple.py +++ b/src/backend/fastapi_app/rag_simple.py @@ -232,5 +232,7 @@ async def run_stream( ), ) async for response_chunk in chat_completion_async_stream: - yield Message(content=str(response_chunk.choices[0].delta.content), role="assistant") + # first response has empty choices + if response_chunk.choices: + yield Message(content=str(response_chunk.choices[0].delta.content), role="assistant") return diff --git a/src/backend/fastapi_app/routes/api_routes.py b/src/backend/fastapi_app/routes/api_routes.py index b0f02189..7b7f376a 100644 --- a/src/backend/fastapi_app/routes/api_routes.py +++ b/src/backend/fastapi_app/routes/api_routes.py @@ -1,8 +1,13 @@ +import json +import logging +from collections.abc import AsyncGenerator + import fastapi from fastapi import HTTPException +from fastapi.responses import StreamingResponse from sqlalchemy import select -from fastapi_app.api_models import ChatRequest, ItemPublic, ItemWithDistance, RetrievalResponse +from fastapi_app.api_models import ChatRequest, ItemPublic, ItemWithDistance, Message, RetrievalResponse from fastapi_app.dependencies import ChatClient, CommonDeps, DBSession, EmbeddingsClient from fastapi_app.postgres_models import Item from fastapi_app.postgres_searcher import PostgresSearcher @@ -12,6 +17,18 @@ router = fastapi.APIRouter() +async def format_as_ndjson(r: AsyncGenerator[RetrievalResponse | Message, None]) -> AsyncGenerator[str, None]: + """ + Format the response as NDJSON + """ + try: + async for event in r: + yield json.dumps(event.model_dump(), ensure_ascii=False) + "\n" + except Exception as error: + logging.exception("Exception while generating response stream: %s", error) + yield json.dumps({"error": str(error)}, ensure_ascii=False) + "\n" + + @router.get("/items/{id}", response_model=ItemPublic) async def item_handler(database_session: DBSession, id: int) -> ItemPublic: """A simple API to get an item by ID.""" @@ -96,3 +113,39 @@ async def chat_handler( response = await run_ragchat(chat_request.messages, overrides=overrides) return response + + +@router.post("/chat/stream") +async def chat_stream_handler( + context: CommonDeps, + database_session: DBSession, + openai_embed: EmbeddingsClient, + openai_chat: ChatClient, + chat_request: ChatRequest, +): + overrides = chat_request.context.get("overrides", {}) + + searcher = PostgresSearcher( + db_session=database_session, + openai_embed_client=openai_embed.client, + embed_deployment=context.openai_embed_deployment, + embed_model=context.openai_embed_model, + embed_dimensions=context.openai_embed_dimensions, + ) + if overrides.get("use_advanced_flow"): + run_ragchat = AdvancedRAGChat( + searcher=searcher, + openai_chat_client=openai_chat.client, + chat_model=context.openai_chat_model, + chat_deployment=context.openai_chat_deployment, + ).run_stream + else: + run_ragchat = SimpleRAGChat( + searcher=searcher, + openai_chat_client=openai_chat.client, + chat_model=context.openai_chat_model, + chat_deployment=context.openai_chat_deployment, + ).run_stream + + result = run_ragchat(chat_request.messages, overrides=overrides) + return StreamingResponse(content=format_as_ndjson(result), media_type="application/x-ndjson") From 04cf3b2fd53c3b81c5a1e7742e0b5126024b6e8a Mon Sep 17 00:00:00 2001 From: John Aziz Date: Mon, 22 Jul 2024 00:07:24 +0000 Subject: [PATCH 04/16] add tests and fix db_session issue --- src/backend/fastapi_app/rag_advanced.py | 5 + src/backend/fastapi_app/rag_simple.py | 5 + tests/test_api_routes.py | 117 ++++++++++++++++++++++++ 3 files changed, 127 insertions(+) diff --git a/src/backend/fastapi_app/rag_advanced.py b/src/backend/fastapi_app/rag_advanced.py index edcb52d5..d4f37a07 100644 --- a/src/backend/fastapi_app/rag_advanced.py +++ b/src/backend/fastapi_app/rag_advanced.py @@ -183,6 +183,11 @@ async def run_stream( stream=True, ) + # Forcefully Close the database session before yielding the response + # Yielding keeps the connection open while streaming the response until the end + # The connection closes when it returns back to the context manger in the dependencies + await self.searcher.db_session.close() + yield RetrievalResponse( message=Message(content="", role="assistant"), context=RAGContext( diff --git a/src/backend/fastapi_app/rag_simple.py b/src/backend/fastapi_app/rag_simple.py index e6baa18c..f56ee44c 100644 --- a/src/backend/fastapi_app/rag_simple.py +++ b/src/backend/fastapi_app/rag_simple.py @@ -201,6 +201,11 @@ async def run_stream( stream=True, ) + # Forcefully Close the database session before yielding the response + # Yielding keeps the connection open while streaming the response until the end + # The connection closes when it returns back to the context manger in the dependencies + await self.searcher.db_session.close() + yield RetrievalResponse( message=Message(content="", role="assistant"), context=RAGContext( diff --git a/tests/test_api_routes.py b/tests/test_api_routes.py index 98e3ceeb..10e92233 100644 --- a/tests/test_api_routes.py +++ b/tests/test_api_routes.py @@ -227,6 +227,123 @@ async def test_simple_chat_flow(test_client): assert response_data["session_state"] is None +@pytest.mark.asyncio +async def test_simple_chat_streaming_flow(test_client): + """test the simple chat streaming flow route with hybrid retrieval mode""" + response = test_client.post( + "/chat/stream", + json={ + "context": { + "overrides": {"top": 1, "use_advanced_flow": False, "retrieval_mode": "hybrid", "temperature": 0.3} + }, + "messages": [{"content": "What is the capital of France?", "role": "user"}], + }, + ) + response_data = response.content.split(b"\n") + assert response.status_code == 200 + assert response.headers["Content-Type"] == "application/x-ndjson" + assert response_data[0] == ( + b'{"message": {"content": "", "role": "assistant"}, "context": {"data_points":' + + b' {"1": {"id": 1, "type": "Footwear", "brand": "Daybird", "name": "Wanderer B' + + b'lack Hiking Boots", "description": "Daybird\'s Wanderer Hiking Boots in s' + + b"leek black are perfect for all your outdoor adventures. These boots are made" + + b" with a waterproof leather upper and a durable rubber sole for superior trac" + + b"tion. With their cushioned insole and padded collar, these boots will keep y" + + b'ou comfortable all day long.", "price": 109.99}}, "thoughts": [{"title": "Se' + + b'arch query for database", "description": "What is the capital of France?", "' + + b'props": {"top": 1, "vector_search": true, "text_search": true}}, {"title": "' + + b'Search results", "description": [{"id": 1, "type": "Footwear", "brand": "Day' + + b'bird", "name": "Wanderer Black Hiking Boots", "description": "Daybird\'s ' + + b"Wanderer Hiking Boots in sleek black are perfect for all your outdoor advent" + + b"ures. These boots are made with a waterproof leather upper and a durable rub" + + b"ber sole for superior traction. With their cushioned insole and padded colla" + + b'r, these boots will keep you comfortable all day long.", "price": 109.99}], ' + + b'"props": {}}, {"title": "Prompt to generate answer", "description": ["{\'' + + b"role': 'system', 'content': \\\"Assistant helps customers with questio" + + b"ns about products.\\\\nRespond as if you are a salesperson helping a custo" + + b"mer in a store. Do NOT respond with tables.\\\\nAnswer ONLY with the produ" + + b"ct details listed in the products.\\\\nIf there isn't enough information b" + + b"elow, say you don't know.\\\\nDo not generate answers that don't use the s" + + b"ources below.\\\\nEach product has an ID in brackets followed by colon and" + + b" the product details.\\\\nAlways include the product ID for each product y" + + b"ou use in the response.\\\\nUse square brackets to reference the source, f" + + b"or example [52].\\\\nDon't combine citations, list each product separately" + + b", for example [27][51].\\\"}\", \"{'role': 'user', 'content': \\\"What is " + + b"the capital of France?\\\\n\\\\nSources:\\\\n[1]:Name:Wanderer Black Hikin" + + b"g Boots Description:Daybird's Wanderer Hiking Boots in sleek black are perfe" + + b"ct for all your outdoor adventures. These boots are made with a waterproof l" + + b"eather upper and a durable rubber sole for superior traction. With their cus" + + b"hioned insole and padded collar, these boots will keep you comfortable all d" + + b'ay long. Price:109.99 Brand:Daybird Type:Footwear\\\\n\\\\n\\"}"], "props' + + b'": {"model": "gpt-35-turbo", "deployment": "gpt-35-turbo"}}], "followup_ques' + + b'tions": null}, "session_state": null}' + ) + + +@pytest.mark.asyncio +async def test_advanved_chat_streaming_flow(test_client): + """test the advanced chat streaming flow route with hybrid retrieval mode""" + response = test_client.post( + "/chat/stream", + json={ + "context": { + "overrides": {"top": 1, "use_advanced_flow": True, "retrieval_mode": "hybrid", "temperature": 0.3} + }, + "messages": [{"content": "What is the capital of France?", "role": "user"}], + }, + ) + response_data = response.content.split(b"\n") + assert response.status_code == 200 + assert response.headers["Content-Type"] == "application/x-ndjson" + assert response_data[0] == ( + b'{"message": {"content": "", "role": "assistant"}, "context": {"data_points":' + + b' {"1": {"id": 1, "type": "Footwear", "brand": "Daybird", "name": "Wanderer B' + + b'lack Hiking Boots", "description": "Daybird\'s Wanderer Hiking Boots in s' + + b'leek black are perfect for all your outdoor adventures. These boots are made' + + b' with a waterproof leather upper and a durable rubber sole for superior trac' + + b'tion. With their cushioned insole and padded collar, these boots will keep y' + + b'ou comfortable all day long.", "price": 109.99}}, "thoughts": [{"title": "Pr' + + b'ompt to generate search arguments", "description": ["{\'role\': \'system\', ' + + b"'content': 'Below is a history of the conversation so far, and a new questio" + + b'n asked by the user that needs to be answered by searching database rows' + + b'.\\\\nYou have access to an Azure PostgreSQL database with an items table ' + + b'that has columns for title, description, brand, price, and type.\\\\nGener' + + b'ate a search query based on the conversation and the new question.\\\\nIf ' + + b'the question is not in English, translate the question to English before gen' + + b'erating the search query.\\\\nIf you cannot generate a search query, retur' + + b'n the original user question.\\\\nDO NOT return anything besides the query' + + b'.\'}", "{\'role\': \'user\', \'content\': \'What is the capital of Franc' + + b'e?\'}"], "props": {"model": "gpt-35-turbo", "deployment": "gpt-35-turbo"}' + + b'}, {"title": "Search using generated search arguments", "description": "The ' + + b'capital of France is Paris. [Benefit_Options-2.pdf].", "props": {"top": 1, "' + + b'vector_search": true, "text_search": true, "filters": []}}, {"title": "Searc' + + b'h results", "description": [{"id": 1, "type": "Footwear", "brand": "Daybird"' + + b', "name": "Wanderer Black Hiking Boots", "description": "Daybird\'s Wande' + + b'rer Hiking Boots in sleek black are perfect for all your outdoor adventures.' + + b' These boots are made with a waterproof leather upper and a durable rubber s' + + b'ole for superior traction. With their cushioned insole and padded collar, th' + + b'ese boots will keep you comfortable all day long.", "price": 109.99}], "prop' + + b's": {}}, {"title": "Prompt to generate answer", "description": ["{\'role\'' + + b': \'system\', \'content\': \\"Assistant helps customers with questions ab' + + b'out products.\\\\nRespond as if you are a salesperson helping a customer i' + + b'n a store. Do NOT respond with tables.\\\\nAnswer ONLY with the product de' + + b"tails listed in the products.\\\\nIf there isn't enough information below," + + b" say you don't know.\\\\nDo not generate answers that don't use the source" + + b's below.\\\\nEach product has an ID in brackets followed by colon and the ' + + b'product details.\\\\nAlways include the product ID for each product you us' + + b'e in the response.\\\\nUse square brackets to reference the source, for ex' + + b"ample [52].\\\\nDon't combine citations, list each product separately, for" + + b' example [27][51].\\"}", "{\'role\': \'user\', \'content\': \\"What is the c' + + b'apital of France?\\\\n\\\\nSources:\\\\n[1]:Name:Wanderer Black Hiking Boo' + + b"ts Description:Daybird's Wanderer Hiking Boots in sleek black are perfect fo" + + b'r all your outdoor adventures. These boots are made with a waterproof leathe' + + b'r upper and a durable rubber sole for superior traction. With their cushione' + + b'd insole and padded collar, these boots will keep you comfortable all day lo' + + b'ng. Price:109.99 Brand:Daybird Type:Footwear\\\\n\\\\n\\"}"], "props": {"' + + b'model": "gpt-35-turbo", "deployment": "gpt-35-turbo"}}], "followup_questions' + + b'": null}, "session_state": null}' + ) + @pytest.mark.asyncio async def test_advanced_chat_flow(test_client): """test the advanced chat flow route with hybrid retrieval mode""" From 6beb091d055dd35e83ef3213155861c8f3e301b5 Mon Sep 17 00:00:00 2001 From: John Aziz Date: Tue, 23 Jul 2024 01:21:14 +0000 Subject: [PATCH 05/16] fix typo and add missing check --- src/backend/fastapi_app/rag_advanced.py | 6 ++++-- src/backend/fastapi_app/rag_simple.py | 2 +- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/src/backend/fastapi_app/rag_advanced.py b/src/backend/fastapi_app/rag_advanced.py index d4f37a07..6d7d7f71 100644 --- a/src/backend/fastapi_app/rag_advanced.py +++ b/src/backend/fastapi_app/rag_advanced.py @@ -183,7 +183,7 @@ async def run_stream( stream=True, ) - # Forcefully Close the database session before yielding the response + # Forcefully close the database session before yielding the response # Yielding keeps the connection open while streaming the response until the end # The connection closes when it returns back to the context manger in the dependencies await self.searcher.db_session.close() @@ -230,5 +230,7 @@ async def run_stream( ) async for response_chunk in chat_completion_async_stream: - yield Message(content=str(response_chunk.choices[0].delta.content), role="assistant") + # first response has empty choices + if response_chunk.choices: + yield Message(content=str(response_chunk.choices[0].delta.content), role="assistant") return diff --git a/src/backend/fastapi_app/rag_simple.py b/src/backend/fastapi_app/rag_simple.py index f56ee44c..be05abd7 100644 --- a/src/backend/fastapi_app/rag_simple.py +++ b/src/backend/fastapi_app/rag_simple.py @@ -201,7 +201,7 @@ async def run_stream( stream=True, ) - # Forcefully Close the database session before yielding the response + # Forcefully close the database session before yielding the response # Yielding keeps the connection open while streaming the response until the end # The connection closes when it returns back to the context manger in the dependencies await self.searcher.db_session.close() From c48dac3d4ca4c8b2ea1efed3d402239397ef5f41 Mon Sep 17 00:00:00 2001 From: John Aziz Date: Tue, 23 Jul 2024 01:23:08 +0000 Subject: [PATCH 06/16] initial frontend --- src/frontend/src/api/models.ts | 6 +- src/frontend/src/pages/chat/Chat.tsx | 105 +++++++++++++++++++++++++-- 2 files changed, 105 insertions(+), 6 deletions(-) diff --git a/src/frontend/src/api/models.ts b/src/frontend/src/api/models.ts index deee7b68..cd7d3c3b 100644 --- a/src/frontend/src/api/models.ts +++ b/src/frontend/src/api/models.ts @@ -1,4 +1,4 @@ -import { AIChatCompletion } from "@microsoft/ai-chat-protocol"; +import { AIChatCompletion, AIChatCompletionDelta } from "@microsoft/ai-chat-protocol"; export const enum RetrievalMode { Hybrid = "hybrid", @@ -29,3 +29,7 @@ export type RAGContext = { export interface RAGChatCompletion extends AIChatCompletion { context: RAGContext; } + +export interface RAGChatCompletionDelta extends AIChatCompletionDelta { + context: RAGContext; +} diff --git a/src/frontend/src/pages/chat/Chat.tsx b/src/frontend/src/pages/chat/Chat.tsx index 6918cf76..8f19d3b4 100644 --- a/src/frontend/src/pages/chat/Chat.tsx +++ b/src/frontend/src/pages/chat/Chat.tsx @@ -1,11 +1,10 @@ import { useRef, useState, useEffect } from "react"; import { Panel, DefaultButton, TextField, SpinButton, Slider, Checkbox } from "@fluentui/react"; import { SparkleFilled } from "@fluentui/react-icons"; -import { AIChatMessage, AIChatProtocolClient } from "@microsoft/ai-chat-protocol"; import styles from "./Chat.module.css"; -import {RetrievalMode, RAGChatCompletion} from "../../api"; +import { RetrievalMode, RAGChatCompletion, RAGChatCompletionDelta } from "../../api"; import { Answer, AnswerError, AnswerLoading } from "../../components/Answer"; import { QuestionInput } from "../../components/QuestionInput"; import { ExampleList } from "../../components/Example"; @@ -22,11 +21,13 @@ const Chat = () => { const [retrieveCount, setRetrieveCount] = useState(3); const [retrievalMode, setRetrievalMode] = useState(RetrievalMode.Hybrid); const [useAdvancedFlow, setUseAdvancedFlow] = useState(true); + const [shouldStream, setShouldStream] = useState(true); const lastQuestionRef = useRef(""); const chatMessageStreamEnd = useRef(null); const [isLoading, setIsLoading] = useState(false); + const [isStreaming, setIsStreaming] = useState(false); const [error, setError] = useState(); const [activeCitation, setActiveCitation] = useState(); @@ -34,7 +35,63 @@ const Chat = () => { const [selectedAnswer, setSelectedAnswer] = useState(0); const [answers, setAnswers] = useState<[user: string, response: RAGChatCompletion][]>([]); + const [streamedAnswers, setStreamedAnswers] = useState<[user: string, response: RAGChatCompletion][]>([]); + const handleAsyncRequest = async ( + question: string, + answers: [string, RAGChatCompletionDelta][], + setStreamedAnswers: Function, + result: AsyncIterable + ) => { + let answer = ""; + let chatCompletion: RAGChatCompletion = { + context: { + data_points: {}, + followup_questions: null, + thoughts: [] + }, + message: { content: "", role: "assistant" }, + }; + const updateState = (newContent: string) => { + return new Promise(resolve => { + setTimeout(() => { + answer += newContent; + // We need to create a new object to trigger a re-render + const latestCompletion: RAGChatCompletionDelta = { + ...chatCompletion, + delta: { content: answer, role: chatCompletion.message.role } + }; + setStreamedAnswers([...answers, [question, latestCompletion]]); + resolve(null); + }, 33); + }); + }; + try { + setIsStreaming(true); + for await (const response of result) { + if (!response.delta) { + continue; + } + if (response.role) { + chatCompletion.message.role = response.delta.role; + } + if (response.content) { + setIsLoading(false); + await updateState(response.delta.content); + } + if (response.context) { + chatCompletion.context = { + ...chatCompletion.context, + ...response.context + }; + } + } + } finally { + setIsStreaming(false); + } + chatCompletion.message.content = answer; + return chatCompletion; + }; const makeApiRequest = async (question: string) => { lastQuestionRef.current = question; @@ -61,8 +118,14 @@ const Chat = () => { } }; const chatClient: AIChatProtocolClient = new AIChatProtocolClient("/chat"); - const result = await chatClient.getCompletion(allMessages, options) as RAGChatCompletion; - setAnswers([...answers, [question, result]]); + if (shouldStream) { + const result = await chatClient.getStreamedCompletion(allMessages, options); + const parsedResponse = await handleAsyncRequest(question, answers, setStreamedAnswers, result); + setAnswers([...answers, [question, parsedResponse]]); + } else { + const result = await chatClient.getCompletion(allMessages, options) as RAGChatCompletion; + setAnswers([...answers, [question, result]]); + } } catch (e) { setError(e); } finally { @@ -76,10 +139,13 @@ const Chat = () => { setActiveCitation(undefined); setActiveAnalysisPanelTab(undefined); setAnswers([]); + setStreamedAnswers([]); setIsLoading(false); + setIsStreaming(false); }; useEffect(() => chatMessageStreamEnd.current?.scrollIntoView({ behavior: "smooth" }), [isLoading]); + useEffect(() => chatMessageStreamEnd.current?.scrollIntoView({ behavior: "auto" }), [streamedAnswers]); const onPromptTemplateChange = (_ev?: React.FormEvent, newValue?: string) => { setPromptTemplate(newValue || ""); @@ -101,6 +167,10 @@ const Chat = () => { setUseAdvancedFlow(!!checked); } + const onShouldStreamChange = (_ev?: React.FormEvent, checked?: boolean) => { + setShouldStream(!!checked); + }; + const onExampleClicked = (example: string) => { makeApiRequest(example); }; @@ -143,7 +213,25 @@ const Chat = () => { ) : (
- {answers.map((answer, index) => ( + {isStreaming && streamedAnswers.map((streamedAnswer, index) => ( +
+ +
+ onShowCitation(c, index)} + onThoughtProcessClicked={() => onToggleTab(AnalysisPanelTabs.ThoughtProcessTab, index)} + onSupportingContentClicked={() => onToggleTab(AnalysisPanelTabs.SupportingContentTab, index)} + onFollowupQuestionClicked={q => makeApiRequest(q)} + /> +
+
+ ))} + {!isStreaming && + answers.map((answer, index) => (
@@ -257,6 +345,13 @@ const Chat = () => { snapToStep /> + +
From 3e50e6112b834bba3ea11f17e8bbd90d43c0a984 Mon Sep 17 00:00:00 2001 From: John Aziz Date: Tue, 23 Jul 2024 04:15:56 +0000 Subject: [PATCH 07/16] fix type for streaming to conform with Microsoft Chat Protocol --- src/backend/fastapi_app/api_models.py | 15 +- src/backend/fastapi_app/rag_advanced.py | 28 ++-- src/backend/fastapi_app/rag_simple.py | 30 ++-- src/backend/fastapi_app/routes/api_routes.py | 12 +- tests/test_api_routes.py | 158 +++++++++---------- 5 files changed, 138 insertions(+), 105 deletions(-) diff --git a/src/backend/fastapi_app/api_models.py b/src/backend/fastapi_app/api_models.py index 2e214a5e..db0ea2c9 100644 --- a/src/backend/fastapi_app/api_models.py +++ b/src/backend/fastapi_app/api_models.py @@ -1,12 +1,19 @@ +from enum import Enum from typing import Any from openai.types.chat import ChatCompletionMessageParam from pydantic import BaseModel +class AIChatRoles(str, Enum): + USER = "user" + ASSISTANT = "assistant" + SYSTEM = "system" + + class Message(BaseModel): content: str - role: str = "user" + role: AIChatRoles = AIChatRoles.USER class ChatRequest(BaseModel): @@ -32,6 +39,12 @@ class RetrievalResponse(BaseModel): session_state: Any | None = None +class RetrievalResponseDelta(BaseModel): + delta: Message | None = None + context: RAGContext | None = None + session_state: Any | None = None + + class ItemPublic(BaseModel): id: int type: str diff --git a/src/backend/fastapi_app/rag_advanced.py b/src/backend/fastapi_app/rag_advanced.py index 6d7d7f71..3fc27748 100644 --- a/src/backend/fastapi_app/rag_advanced.py +++ b/src/backend/fastapi_app/rag_advanced.py @@ -5,7 +5,14 @@ from openai.types.chat import ChatCompletion, ChatCompletionChunk, ChatCompletionMessageParam from openai_messages_token_helper import build_messages, get_token_limit -from fastapi_app.api_models import Message, RAGContext, RetrievalResponse, ThoughtStep +from fastapi_app.api_models import ( + AIChatRoles, + Message, + RAGContext, + RetrievalResponse, + RetrievalResponseDelta, + ThoughtStep, +) from fastapi_app.postgres_models import Item from fastapi_app.postgres_searcher import PostgresSearcher from fastapi_app.query_rewriter import build_search_function, extract_search_arguments @@ -110,10 +117,10 @@ async def run( stream=False, ) - first_choice_message = chat_completion_response.choices[0].message - return RetrievalResponse( - message=Message(content=str(first_choice_message.content), role=first_choice_message.role), + message=Message( + content=str(chat_completion_response.choices[0].message.content), role=AIChatRoles.ASSISTANT + ), context=RAGContext( data_points={item.id: item.to_dict() for item in results}, thoughts=[ @@ -157,7 +164,7 @@ async def run_stream( self, messages: list[ChatCompletionMessageParam], overrides: dict[str, Any] = {}, - ) -> AsyncGenerator[RetrievalResponse | Message, None]: + ) -> AsyncGenerator[RetrievalResponseDelta, None]: chat_params = self.get_params(messages, overrides) # Generate an optimized keyword search query based on the chat history and the last question @@ -188,8 +195,7 @@ async def run_stream( # The connection closes when it returns back to the context manger in the dependencies await self.searcher.db_session.close() - yield RetrievalResponse( - message=Message(content="", role="assistant"), + yield RetrievalResponseDelta( context=RAGContext( data_points={item.id: item.to_dict() for item in results}, thoughts=[ @@ -230,7 +236,9 @@ async def run_stream( ) async for response_chunk in chat_completion_async_stream: - # first response has empty choices - if response_chunk.choices: - yield Message(content=str(response_chunk.choices[0].delta.content), role="assistant") + # first response has empty choices and last response has empty content + if response_chunk.choices and response_chunk.choices[0].delta.content: + yield RetrievalResponseDelta( + delta=Message(content=str(response_chunk.choices[0].delta.content), role=AIChatRoles.ASSISTANT) + ) return diff --git a/src/backend/fastapi_app/rag_simple.py b/src/backend/fastapi_app/rag_simple.py index be05abd7..2280aa20 100644 --- a/src/backend/fastapi_app/rag_simple.py +++ b/src/backend/fastapi_app/rag_simple.py @@ -8,7 +8,14 @@ from openai_messages_token_helper import build_messages, get_token_limit from pydantic import BaseModel -from fastapi_app.api_models import Message, RAGContext, RetrievalResponse, ThoughtStep +from fastapi_app.api_models import ( + AIChatRoles, + Message, + RAGContext, + RetrievalResponse, + RetrievalResponseDelta, + ThoughtStep, +) from fastapi_app.postgres_models import Item from fastapi_app.postgres_searcher import PostgresSearcher @@ -76,7 +83,7 @@ async def run_stream( self, messages: list[ChatCompletionMessageParam], overrides: dict[str, Any] = {}, - ) -> AsyncGenerator[RetrievalResponse | Message, None]: + ) -> AsyncGenerator[RetrievalResponseDelta, None]: raise NotImplementedError if False: yield 0 @@ -145,10 +152,10 @@ async def run( stream=False, ) - first_choice_message = chat_completion_response.choices[0].message - return RetrievalResponse( - message=Message(content=str(first_choice_message.content), role=first_choice_message.role), + message=Message( + content=str(chat_completion_response.choices[0].message.content), role=AIChatRoles.ASSISTANT + ), context=RAGContext( data_points={item.id: item.to_dict() for item in results}, thoughts=[ @@ -182,7 +189,7 @@ async def run_stream( self, messages: list[ChatCompletionMessageParam], overrides: dict[str, Any] = {}, - ) -> AsyncGenerator[RetrievalResponse | Message, None]: + ) -> AsyncGenerator[RetrievalResponseDelta, None]: chat_params = self.get_params(messages, overrides) # Retrieve relevant items from the database @@ -206,8 +213,7 @@ async def run_stream( # The connection closes when it returns back to the context manger in the dependencies await self.searcher.db_session.close() - yield RetrievalResponse( - message=Message(content="", role="assistant"), + yield RetrievalResponseDelta( context=RAGContext( data_points={item.id: item.to_dict() for item in results}, thoughts=[ @@ -237,7 +243,9 @@ async def run_stream( ), ) async for response_chunk in chat_completion_async_stream: - # first response has empty choices - if response_chunk.choices: - yield Message(content=str(response_chunk.choices[0].delta.content), role="assistant") + # first response has empty choices and last response has empty content + if response_chunk.choices and response_chunk.choices[0].delta.content: + yield RetrievalResponseDelta( + delta=Message(content=str(response_chunk.choices[0].delta.content), role=AIChatRoles.ASSISTANT) + ) return diff --git a/src/backend/fastapi_app/routes/api_routes.py b/src/backend/fastapi_app/routes/api_routes.py index 7b7f376a..e531885e 100644 --- a/src/backend/fastapi_app/routes/api_routes.py +++ b/src/backend/fastapi_app/routes/api_routes.py @@ -7,7 +7,13 @@ from fastapi.responses import StreamingResponse from sqlalchemy import select -from fastapi_app.api_models import ChatRequest, ItemPublic, ItemWithDistance, Message, RetrievalResponse +from fastapi_app.api_models import ( + ChatRequest, + ItemPublic, + ItemWithDistance, + RetrievalResponse, + RetrievalResponseDelta, +) from fastapi_app.dependencies import ChatClient, CommonDeps, DBSession, EmbeddingsClient from fastapi_app.postgres_models import Item from fastapi_app.postgres_searcher import PostgresSearcher @@ -17,13 +23,13 @@ router = fastapi.APIRouter() -async def format_as_ndjson(r: AsyncGenerator[RetrievalResponse | Message, None]) -> AsyncGenerator[str, None]: +async def format_as_ndjson(r: AsyncGenerator[RetrievalResponseDelta, None]) -> AsyncGenerator[str, None]: """ Format the response as NDJSON """ try: async for event in r: - yield json.dumps(event.model_dump(), ensure_ascii=False) + "\n" + yield event.model_dump_json() + "\n" except Exception as error: logging.exception("Exception while generating response stream: %s", error) yield json.dumps({"error": str(error)}, ensure_ascii=False) + "\n" diff --git a/tests/test_api_routes.py b/tests/test_api_routes.py index 10e92233..7f138759 100644 --- a/tests/test_api_routes.py +++ b/tests/test_api_routes.py @@ -243,40 +243,39 @@ async def test_simple_chat_streaming_flow(test_client): assert response.status_code == 200 assert response.headers["Content-Type"] == "application/x-ndjson" assert response_data[0] == ( - b'{"message": {"content": "", "role": "assistant"}, "context": {"data_points":' - + b' {"1": {"id": 1, "type": "Footwear", "brand": "Daybird", "name": "Wanderer B' - + b'lack Hiking Boots", "description": "Daybird\'s Wanderer Hiking Boots in s' - + b"leek black are perfect for all your outdoor adventures. These boots are made" - + b" with a waterproof leather upper and a durable rubber sole for superior trac" - + b"tion. With their cushioned insole and padded collar, these boots will keep y" - + b'ou comfortable all day long.", "price": 109.99}}, "thoughts": [{"title": "Se' - + b'arch query for database", "description": "What is the capital of France?", "' - + b'props": {"top": 1, "vector_search": true, "text_search": true}}, {"title": "' - + b'Search results", "description": [{"id": 1, "type": "Footwear", "brand": "Day' - + b'bird", "name": "Wanderer Black Hiking Boots", "description": "Daybird\'s ' - + b"Wanderer Hiking Boots in sleek black are perfect for all your outdoor advent" - + b"ures. These boots are made with a waterproof leather upper and a durable rub" - + b"ber sole for superior traction. With their cushioned insole and padded colla" - + b'r, these boots will keep you comfortable all day long.", "price": 109.99}], ' - + b'"props": {}}, {"title": "Prompt to generate answer", "description": ["{\'' - + b"role': 'system', 'content': \\\"Assistant helps customers with questio" - + b"ns about products.\\\\nRespond as if you are a salesperson helping a custo" - + b"mer in a store. Do NOT respond with tables.\\\\nAnswer ONLY with the produ" - + b"ct details listed in the products.\\\\nIf there isn't enough information b" - + b"elow, say you don't know.\\\\nDo not generate answers that don't use the s" - + b"ources below.\\\\nEach product has an ID in brackets followed by colon and" - + b" the product details.\\\\nAlways include the product ID for each product y" - + b"ou use in the response.\\\\nUse square brackets to reference the source, f" - + b"or example [52].\\\\nDon't combine citations, list each product separately" - + b", for example [27][51].\\\"}\", \"{'role': 'user', 'content': \\\"What is " - + b"the capital of France?\\\\n\\\\nSources:\\\\n[1]:Name:Wanderer Black Hikin" - + b"g Boots Description:Daybird's Wanderer Hiking Boots in sleek black are perfe" - + b"ct for all your outdoor adventures. These boots are made with a waterproof l" - + b"eather upper and a durable rubber sole for superior traction. With their cus" - + b"hioned insole and padded collar, these boots will keep you comfortable all d" - + b'ay long. Price:109.99 Brand:Daybird Type:Footwear\\\\n\\\\n\\"}"], "props' - + b'": {"model": "gpt-35-turbo", "deployment": "gpt-35-turbo"}}], "followup_ques' - + b'tions": null}, "session_state": null}' + b'{"delta":null,"context":{"data_points":{"1":{"id":1,"type":"Footwear","brand' + + b'":"Daybird","name":"Wanderer Black Hiking Boots","description":"Daybird\'' + + b"s Wanderer Hiking Boots in sleek black are perfect for all your outdoor adve" + + b"ntures. These boots are made with a waterproof leather upper and a durable r" + + b"ubber sole for superior traction. With their cushioned insole and padded col" + + b'lar, these boots will keep you comfortable all day long.","price":109.99}},"' + + b'thoughts":[{"title":"Search query for database","description":"What is the c' + + b'apital of France?","props":{"top":1,"vector_search":true,"text_search":true}' + + b'},{"title":"Search results","description":[{"id":1,"type":"Footwear","brand"' + + b':"Daybird","name":"Wanderer Black Hiking Boots","description":"Daybird\'s' + + b" Wanderer Hiking Boots in sleek black are perfect for all your outdoor adven" + + b"tures. These boots are made with a waterproof leather upper and a durable ru" + + b"bber sole for superior traction. With their cushioned insole and padded coll" + + b'ar, these boots will keep you comfortable all day long.","price":109.99}],"p' + + b'rops":{}},{"title":"Prompt to generate answer","description":["{\'role\': ' + + b"'system', 'content': \\\"Assistant helps customers with questions abou" + + b"t products.\\\\nRespond as if you are a salesperson helping a customer in " + + b"a store. Do NOT respond with tables.\\\\nAnswer ONLY with the product deta" + + b"ils listed in the products.\\\\nIf there isn't enough information below, s" + + b"ay you don't know.\\\\nDo not generate answers that don't use the sources " + + b"below.\\\\nEach product has an ID in brackets followed by colon and the pr" + + b"oduct details.\\\\nAlways include the product ID for each product you use " + + b"in the response.\\\\nUse square brackets to reference the source, for exam" + + b"ple [52].\\\\nDon't combine citations, list each product separately, for e" + + b"xample [27][51].\\\"}\",\"{'role': 'user', 'content': \\\"What is the capi" + + b"tal of France?\\\\n\\\\nSources:\\\\n[1]:Name:Wanderer Black Hiking Boots " + + b"Description:Daybird's Wanderer Hiking Boots in sleek black are perfect for a" + + b"ll your outdoor adventures. These boots are made with a waterproof leather u" + + b"pper and a durable rubber sole for superior traction. With their cushioned i" + + b"nsole and padded collar, these boots will keep you comfortable all day long." + + b' Price:109.99 Brand:Daybird Type:Footwear\\\\n\\\\n\\"}"],"props":{"model' + + b'":"gpt-35-turbo","deployment":"gpt-35-turbo"}}],"followup_questions":null},"' + + b'session_state":null}' ) @@ -296,54 +295,53 @@ async def test_advanved_chat_streaming_flow(test_client): assert response.status_code == 200 assert response.headers["Content-Type"] == "application/x-ndjson" assert response_data[0] == ( - b'{"message": {"content": "", "role": "assistant"}, "context": {"data_points":' - + b' {"1": {"id": 1, "type": "Footwear", "brand": "Daybird", "name": "Wanderer B' - + b'lack Hiking Boots", "description": "Daybird\'s Wanderer Hiking Boots in s' - + b'leek black are perfect for all your outdoor adventures. These boots are made' - + b' with a waterproof leather upper and a durable rubber sole for superior trac' - + b'tion. With their cushioned insole and padded collar, these boots will keep y' - + b'ou comfortable all day long.", "price": 109.99}}, "thoughts": [{"title": "Pr' - + b'ompt to generate search arguments", "description": ["{\'role\': \'system\', ' - + b"'content': 'Below is a history of the conversation so far, and a new questio" - + b'n asked by the user that needs to be answered by searching database rows' - + b'.\\\\nYou have access to an Azure PostgreSQL database with an items table ' - + b'that has columns for title, description, brand, price, and type.\\\\nGener' - + b'ate a search query based on the conversation and the new question.\\\\nIf ' - + b'the question is not in English, translate the question to English before gen' - + b'erating the search query.\\\\nIf you cannot generate a search query, retur' - + b'n the original user question.\\\\nDO NOT return anything besides the query' - + b'.\'}", "{\'role\': \'user\', \'content\': \'What is the capital of Franc' - + b'e?\'}"], "props": {"model": "gpt-35-turbo", "deployment": "gpt-35-turbo"}' - + b'}, {"title": "Search using generated search arguments", "description": "The ' - + b'capital of France is Paris. [Benefit_Options-2.pdf].", "props": {"top": 1, "' - + b'vector_search": true, "text_search": true, "filters": []}}, {"title": "Searc' - + b'h results", "description": [{"id": 1, "type": "Footwear", "brand": "Daybird"' - + b', "name": "Wanderer Black Hiking Boots", "description": "Daybird\'s Wande' - + b'rer Hiking Boots in sleek black are perfect for all your outdoor adventures.' - + b' These boots are made with a waterproof leather upper and a durable rubber s' - + b'ole for superior traction. With their cushioned insole and padded collar, th' - + b'ese boots will keep you comfortable all day long.", "price": 109.99}], "prop' - + b's": {}}, {"title": "Prompt to generate answer", "description": ["{\'role\'' - + b': \'system\', \'content\': \\"Assistant helps customers with questions ab' - + b'out products.\\\\nRespond as if you are a salesperson helping a customer i' - + b'n a store. Do NOT respond with tables.\\\\nAnswer ONLY with the product de' - + b"tails listed in the products.\\\\nIf there isn't enough information below," - + b" say you don't know.\\\\nDo not generate answers that don't use the source" - + b's below.\\\\nEach product has an ID in brackets followed by colon and the ' - + b'product details.\\\\nAlways include the product ID for each product you us' - + b'e in the response.\\\\nUse square brackets to reference the source, for ex' - + b"ample [52].\\\\nDon't combine citations, list each product separately, for" - + b' example [27][51].\\"}", "{\'role\': \'user\', \'content\': \\"What is the c' - + b'apital of France?\\\\n\\\\nSources:\\\\n[1]:Name:Wanderer Black Hiking Boo' - + b"ts Description:Daybird's Wanderer Hiking Boots in sleek black are perfect fo" - + b'r all your outdoor adventures. These boots are made with a waterproof leathe' - + b'r upper and a durable rubber sole for superior traction. With their cushione' - + b'd insole and padded collar, these boots will keep you comfortable all day lo' - + b'ng. Price:109.99 Brand:Daybird Type:Footwear\\\\n\\\\n\\"}"], "props": {"' - + b'model": "gpt-35-turbo", "deployment": "gpt-35-turbo"}}], "followup_questions' - + b'": null}, "session_state": null}' + b'{"delta":null,"context":{"data_points":{"1":{"id":1,"type":"Footwear","brand' + + b'":"Daybird","name":"Wanderer Black Hiking Boots","description":"Daybird\'' + + b"s Wanderer Hiking Boots in sleek black are perfect for all your outdoor adve" + + b"ntures. These boots are made with a waterproof leather upper and a durable r" + + b"ubber sole for superior traction. With their cushioned insole and padded col" + + b'lar, these boots will keep you comfortable all day long.","price":109.99}},"' + + b'thoughts":[{"title":"Prompt to generate search arguments","description":' + + b"[\"{'role': 'system', 'content': 'Below is a history of the conversat" + + b"ion so far, and a new question asked by the user that needs to be answered b" + + b"y searching database rows.\\\\nYou have access to an Azure PostgreSQL data" + + b"base with an items table that has columns for title, description, brand, pri" + + b"ce, and type.\\\\nGenerate a search query based on the conversation and th" + + b"e new question.\\\\nIf the question is not in English, translate the quest" + + b"ion to English before generating the search query.\\\\nIf you cannot gener" + + b"ate a search query, return the original user question.\\\\nDO NOT return a" + + b"nything besides the query.'}\",\"{'role': 'user', 'content': 'What is " + + b'the capital of France?\'}"],"props":{"model":"gpt-35-turbo","deployment":' + + b'"gpt-35-turbo"}},{"title":"Search using generated search arguments","descrip' + + b'tion":"The capital of France is Paris. [Benefit_Options-2.pdf].","props":{"t' + + b'op":1,"vector_search":true,"text_search":true,"filters":[]}},{"title":"Searc' + + b'h results","description":[{"id":1,"type":"Footwear","brand":"Daybird","name"' + + b':"Wanderer Black Hiking Boots","description":"Daybird\'s Wanderer Hiking ' + + b"Boots in sleek black are perfect for all your outdoor adventures. These boot" + + b"s are made with a waterproof leather upper and a durable rubber sole for sup" + + b"erior traction. With their cushioned insole and padded collar, these boots w" + + b'ill keep you comfortable all day long.","price":109.99}],"props":{}},{"title' + + b'":"Prompt to generate answer","description":["{\'role\': \'system\', \'co' + + b"ntent': \\\"Assistant helps customers with questions about products.\\\\nRes" + + b"pond as if you are a salesperson helping a customer in a store. Do NOT respo" + + b"nd with tables.\\\\nAnswer ONLY with the product details listed in the pro" + + b"ducts.\\\\nIf there isn't enough information below, say you don't know.\\\\n" + + b"Do not generate answers that don't use the sources below.\\\\nEach product" + + b" has an ID in brackets followed by colon and the product details.\\\\nAlwa" + + b"ys include the product ID for each product you use in the response.\\\\nUs" + + b"e square brackets to reference the source, for example [52].\\\\nDon't com" + + b'bine citations, list each product separately, for example [27][51].\\"}",' + + b"\"{'role': 'user', 'content': \\\"What is the capital of France?\\\\n" + + b"\\\\nSources:\\\\n[1]:Name:Wanderer Black Hiking Boots Description:Daybird's" + + b" Wanderer Hiking Boots in sleek black are perfect for all your outdoor adven" + + b"tures. These boots are made with a waterproof leather upper and a durable ru" + + b"bber sole for superior traction. With their cushioned insole and padded coll" + + b"ar, these boots will keep you comfortable all day long. Price:109.99 Brand:D" + + b'aybird Type:Footwear\\\\n\\\\n\\"}"],"props":{"model":"gpt-35-turbo","dep' + + b'loyment":"gpt-35-turbo"}}],"followup_questions":null},"session_state":null}' ) + @pytest.mark.asyncio async def test_advanced_chat_flow(test_client): """test the advanced chat flow route with hybrid retrieval mode""" From 5ca4ac33045b723fc48f2e3edbbddb183f89d013 Mon Sep 17 00:00:00 2001 From: John Aziz Date: Tue, 23 Jul 2024 04:20:05 +0000 Subject: [PATCH 08/16] add working streaming frontend --- src/frontend/src/pages/chat/Chat.tsx | 53 +++++++++++----------------- 1 file changed, 21 insertions(+), 32 deletions(-) diff --git a/src/frontend/src/pages/chat/Chat.tsx b/src/frontend/src/pages/chat/Chat.tsx index 8f19d3b4..2f7569dc 100644 --- a/src/frontend/src/pages/chat/Chat.tsx +++ b/src/frontend/src/pages/chat/Chat.tsx @@ -5,6 +5,7 @@ import { SparkleFilled } from "@fluentui/react-icons"; import styles from "./Chat.module.css"; import { RetrievalMode, RAGChatCompletion, RAGChatCompletionDelta } from "../../api"; +import { AIChatProtocolClient, AIChatMessage } from "@microsoft/ai-chat-protocol"; import { Answer, AnswerError, AnswerLoading } from "../../components/Answer"; import { QuestionInput } from "../../components/QuestionInput"; import { ExampleList } from "../../components/Example"; @@ -37,12 +38,7 @@ const Chat = () => { const [answers, setAnswers] = useState<[user: string, response: RAGChatCompletion][]>([]); const [streamedAnswers, setStreamedAnswers] = useState<[user: string, response: RAGChatCompletion][]>([]); - const handleAsyncRequest = async ( - question: string, - answers: [string, RAGChatCompletionDelta][], - setStreamedAnswers: Function, - result: AsyncIterable - ) => { + const handleAsyncRequest = async (question: string, answers: [string, RAGChatCompletion][], result: AsyncIterable) => { let answer = ""; let chatCompletion: RAGChatCompletion = { context: { @@ -50,16 +46,16 @@ const Chat = () => { followup_questions: null, thoughts: [] }, - message: { content: "", role: "assistant" }, + message: { content: "", role: "assistant" } }; const updateState = (newContent: string) => { return new Promise(resolve => { setTimeout(() => { answer += newContent; // We need to create a new object to trigger a re-render - const latestCompletion: RAGChatCompletionDelta = { + const latestCompletion: RAGChatCompletion = { ...chatCompletion, - delta: { content: answer, role: chatCompletion.message.role } + message: { content: answer, role: chatCompletion.message.role } }; setStreamedAnswers([...answers, [question, latestCompletion]]); resolve(null); @@ -69,22 +65,19 @@ const Chat = () => { try { setIsStreaming(true); for await (const response of result) { - if (!response.delta) { - continue; - } - if (response.role) { - chatCompletion.message.role = response.delta.role; - } - if (response.content) { - setIsLoading(false); - await updateState(response.delta.content); - } if (response.context) { chatCompletion.context = { ...chatCompletion.context, ...response.context }; } + if (response.delta && response.delta.role) { + chatCompletion.message.role = response.delta.role; + } + if (response.delta && response.delta.content) { + setIsLoading(false); + await updateState(response.delta.content); + } } } finally { setIsStreaming(false); @@ -118,12 +111,12 @@ const Chat = () => { } }; const chatClient: AIChatProtocolClient = new AIChatProtocolClient("/chat"); - if (shouldStream) { - const result = await chatClient.getStreamedCompletion(allMessages, options); - const parsedResponse = await handleAsyncRequest(question, answers, setStreamedAnswers, result); + if (shouldStream) { + const result = (await chatClient.getStreamedCompletion(allMessages, options)) as AsyncIterable; + const parsedResponse = await handleAsyncRequest(question, answers, result); setAnswers([...answers, [question, parsedResponse]]); } else { - const result = await chatClient.getCompletion(allMessages, options) as RAGChatCompletion; + const result = (await chatClient.getCompletion(allMessages, options)) as RAGChatCompletion; setAnswers([...answers, [question, result]]); } } catch (e) { @@ -165,7 +158,7 @@ const Chat = () => { const onUseAdvancedFlowChange = (_ev?: React.FormEvent, checked?: boolean) => { setUseAdvancedFlow(!!checked); - } + }; const onShouldStreamChange = (_ev?: React.FormEvent, checked?: boolean) => { setShouldStream(!!checked); @@ -213,7 +206,8 @@ const Chat = () => {
) : (
- {isStreaming && streamedAnswers.map((streamedAnswer, index) => ( + {isStreaming && + streamedAnswers.map((streamedAnswer, index) => (
@@ -230,7 +224,7 @@ const Chat = () => {
))} - {!isStreaming && + {!isStreaming && answers.map((answer, index) => (
@@ -298,7 +292,6 @@ const Chat = () => { onRenderFooterContent={() => setIsConfigPanelOpen(false)}>Close} isFooterAtBottom={true} > - { onChange={onRetrieveCountChange} /> - setRetrievalMode(retrievalMode)} - /> - + setRetrievalMode(retrievalMode)} />

Settings for final chat completion:

@@ -351,7 +341,6 @@ const Chat = () => { label="Stream chat completion responses" onChange={onShouldStreamChange} /> -
From bcfabd0be773621b8950ffc10a23d2c56a0da90f Mon Sep 17 00:00:00 2001 From: John Aziz Date: Tue, 23 Jul 2024 15:00:28 +0300 Subject: [PATCH 09/16] Update src/backend/fastapi_app/rag_simple.py Co-authored-by: Pamela Fox --- src/backend/fastapi_app/rag_simple.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/backend/fastapi_app/rag_simple.py b/src/backend/fastapi_app/rag_simple.py index 2280aa20..c68bbb5e 100644 --- a/src/backend/fastapi_app/rag_simple.py +++ b/src/backend/fastapi_app/rag_simple.py @@ -70,7 +70,7 @@ async def run( raise NotImplementedError @abstractmethod - async def retreive_and_build_context( + async def retrieve_and_build_context( self, chat_params: ChatParams, *args, From f5a27339d694b668a7345444af7c2ef1782b2efb Mon Sep 17 00:00:00 2001 From: John Aziz Date: Tue, 23 Jul 2024 12:11:12 +0000 Subject: [PATCH 10/16] apply feedback from pr review --- src/backend/fastapi_app/rag_advanced.py | 27 ++++---- src/backend/fastapi_app/rag_base.py | 82 ++++++++++++++++++++++++ src/backend/fastapi_app/rag_simple.py | 83 ++----------------------- 3 files changed, 100 insertions(+), 92 deletions(-) create mode 100644 src/backend/fastapi_app/rag_base.py diff --git a/src/backend/fastapi_app/rag_advanced.py b/src/backend/fastapi_app/rag_advanced.py index 3fc27748..5232ab23 100644 --- a/src/backend/fastapi_app/rag_advanced.py +++ b/src/backend/fastapi_app/rag_advanced.py @@ -16,7 +16,7 @@ from fastapi_app.postgres_models import Item from fastapi_app.postgres_searcher import PostgresSearcher from fastapi_app.query_rewriter import build_search_function, extract_search_arguments -from fastapi_app.rag_simple import ChatParams, RAGChatBase +from fastapi_app.rag_base import ChatParams, RAGChatBase class AdvancedRAGChat(RAGChatBase): @@ -35,14 +35,14 @@ def __init__( self.chat_token_limit = get_token_limit(chat_model, default_to_minimum=True) async def generate_search_query( - self, chat_params: ChatParams, query_response_token_limit: int + self, original_user_query: str, past_messages: list[ChatCompletionMessageParam], query_response_token_limit: int ) -> tuple[list[ChatCompletionMessageParam], Any | str | None, list]: """Generate an optimized keyword search query based on the chat history and the last question""" query_messages: list[ChatCompletionMessageParam] = build_messages( model=self.chat_model, system_prompt=self.query_prompt_template, - new_user_content=chat_params.original_user_query, - past_messages=chat_params.past_messages, + new_user_content=original_user_query, + past_messages=past_messages, max_tokens=self.chat_token_limit - query_response_token_limit, # TODO: count functions fallback_to_default=True, ) @@ -58,11 +58,11 @@ async def generate_search_query( tool_choice="auto", ) - query_text, filters = extract_search_arguments(chat_params.original_user_query, chat_completion) + query_text, filters = extract_search_arguments(original_user_query, chat_completion) return query_messages, query_text, filters - async def retreive_and_build_context( + async def retrieve_and_build_context( self, chat_params: ChatParams, query_text: str | Any | None, filters: list ) -> tuple[list[ChatCompletionMessageParam], list[Item]]: """Retrieve relevant items from the database and build a context for the chat model.""" @@ -98,12 +98,14 @@ async def run( # Generate an optimized keyword search query based on the chat history and the last question query_messages, query_text, filters = await self.generate_search_query( - chat_params=chat_params, query_response_token_limit=500 + original_user_query=chat_params.original_user_query, + past_messages=chat_params.past_messages, + query_response_token_limit=500, ) # Retrieve relevant items from the database with the GPT optimized query # Generate a contextual and content specific answer using the search results and chat history - contextual_messages, results = await self.retreive_and_build_context( + contextual_messages, results = await self.retrieve_and_build_context( chat_params=chat_params, query_text=query_text, filters=filters ) @@ -167,14 +169,13 @@ async def run_stream( ) -> AsyncGenerator[RetrievalResponseDelta, None]: chat_params = self.get_params(messages, overrides) - # Generate an optimized keyword search query based on the chat history and the last question query_messages, query_text, filters = await self.generate_search_query( - chat_params=chat_params, query_response_token_limit=500 + original_user_query=chat_params.original_user_query, + past_messages=chat_params.past_messages, + query_response_token_limit=500, ) - # Retrieve relevant items from the database with the GPT optimized query - # Generate a contextual and content specific answer using the search results and chat history - contextual_messages, results = await self.retreive_and_build_context( + contextual_messages, results = await self.retrieve_and_build_context( chat_params=chat_params, query_text=query_text, filters=filters ) diff --git a/src/backend/fastapi_app/rag_base.py b/src/backend/fastapi_app/rag_base.py new file mode 100644 index 00000000..586401a2 --- /dev/null +++ b/src/backend/fastapi_app/rag_base.py @@ -0,0 +1,82 @@ +import pathlib +from abc import ABC, abstractmethod +from collections.abc import AsyncGenerator +from typing import Any + +from openai.types.chat import ChatCompletionMessageParam +from pydantic import BaseModel + +from fastapi_app.api_models import ( + RetrievalResponse, + RetrievalResponseDelta, +) +from fastapi_app.postgres_models import Item + + +class ChatParams(BaseModel): + top: int = 3 + temperature: float = 0.3 + response_token_limit: int = 1024 + enable_text_search: bool + enable_vector_search: bool + original_user_query: str + past_messages: list[ChatCompletionMessageParam] + prompt_template: str + + +class RAGChatBase(ABC): + current_dir = pathlib.Path(__file__).parent + query_prompt_template = open(current_dir / "prompts/query.txt").read() + answer_prompt_template = open(current_dir / "prompts/answer.txt").read() + + def get_params(self, messages: list[ChatCompletionMessageParam], overrides: dict[str, Any]) -> ChatParams: + top: int = overrides.get("top", 3) + temperature: float = overrides.get("temperature", 0.3) + response_token_limit = 1024 + prompt_template = overrides.get("prompt_template") or self.answer_prompt_template + + enable_text_search = overrides.get("retrieval_mode") in ["text", "hybrid", None] + enable_vector_search = overrides.get("retrieval_mode") in ["vectors", "hybrid", None] + + original_user_query = messages[-1]["content"] + if not isinstance(original_user_query, str): + raise ValueError("The most recent message content must be a string.") + past_messages = messages[:-1] + + return ChatParams( + top=top, + temperature=temperature, + response_token_limit=response_token_limit, + prompt_template=prompt_template, + enable_text_search=enable_text_search, + enable_vector_search=enable_vector_search, + original_user_query=original_user_query, + past_messages=past_messages, + ) + + @abstractmethod + async def retrieve_and_build_context( + self, + chat_params: ChatParams, + *args, + **kwargs, + ) -> tuple[list[ChatCompletionMessageParam], list[Item]]: + raise NotImplementedError + + @abstractmethod + async def run( + self, + messages: list[ChatCompletionMessageParam], + overrides: dict[str, Any] = {}, + ) -> RetrievalResponse: + raise NotImplementedError + + @abstractmethod + async def run_stream( + self, + messages: list[ChatCompletionMessageParam], + overrides: dict[str, Any] = {}, + ) -> AsyncGenerator[RetrievalResponseDelta, None]: + raise NotImplementedError + if False: + yield 0 diff --git a/src/backend/fastapi_app/rag_simple.py b/src/backend/fastapi_app/rag_simple.py index c68bbb5e..6fec8cdc 100644 --- a/src/backend/fastapi_app/rag_simple.py +++ b/src/backend/fastapi_app/rag_simple.py @@ -1,12 +1,9 @@ -import pathlib -from abc import ABC, abstractmethod from collections.abc import AsyncGenerator from typing import Any from openai import AsyncAzureOpenAI, AsyncOpenAI, AsyncStream from openai.types.chat import ChatCompletion, ChatCompletionChunk, ChatCompletionMessageParam from openai_messages_token_helper import build_messages, get_token_limit -from pydantic import BaseModel from fastapi_app.api_models import ( AIChatRoles, @@ -18,75 +15,7 @@ ) from fastapi_app.postgres_models import Item from fastapi_app.postgres_searcher import PostgresSearcher - - -class ChatParams(BaseModel): - top: int = 3 - temperature: float = 0.3 - response_token_limit: int = 1024 - enable_text_search: bool - enable_vector_search: bool - original_user_query: str - past_messages: list[ChatCompletionMessageParam] - prompt_template: str - - -class RAGChatBase(ABC): - current_dir = pathlib.Path(__file__).parent - query_prompt_template = open(current_dir / "prompts/query.txt").read() - answer_prompt_template = open(current_dir / "prompts/answer.txt").read() - - def get_params(self, messages: list[ChatCompletionMessageParam], overrides: dict[str, Any]) -> ChatParams: - top: int = overrides.get("top", 3) - temperature: float = overrides.get("temperature", 0.3) - response_token_limit = 1024 - prompt_template = overrides.get("prompt_template") or self.answer_prompt_template - - enable_text_search = overrides.get("retrieval_mode") in ["text", "hybrid", None] - enable_vector_search = overrides.get("retrieval_mode") in ["vectors", "hybrid", None] - - original_user_query = messages[-1]["content"] - if not isinstance(original_user_query, str): - raise ValueError("The most recent message content must be a string.") - past_messages = messages[:-1] - - return ChatParams( - top=top, - temperature=temperature, - response_token_limit=response_token_limit, - prompt_template=prompt_template, - enable_text_search=enable_text_search, - enable_vector_search=enable_vector_search, - original_user_query=original_user_query, - past_messages=past_messages, - ) - - @abstractmethod - async def run( - self, - messages: list[ChatCompletionMessageParam], - overrides: dict[str, Any] = {}, - ) -> RetrievalResponse: - raise NotImplementedError - - @abstractmethod - async def retrieve_and_build_context( - self, - chat_params: ChatParams, - *args, - **kwargs, - ) -> tuple[list[ChatCompletionMessageParam], list[Item]]: - raise NotImplementedError - - @abstractmethod - async def run_stream( - self, - messages: list[ChatCompletionMessageParam], - overrides: dict[str, Any] = {}, - ) -> AsyncGenerator[RetrievalResponseDelta, None]: - raise NotImplementedError - if False: - yield 0 +from fastapi_app.rag_base import ChatParams, RAGChatBase class SimpleRAGChat(RAGChatBase): @@ -104,7 +33,7 @@ def __init__( self.chat_deployment = chat_deployment self.chat_token_limit = get_token_limit(chat_model, default_to_minimum=True) - async def retreive_and_build_context( + async def retrieve_and_build_context( self, chat_params: ChatParams ) -> tuple[list[ChatCompletionMessageParam], list[Item]]: """Retrieve relevant items from the database and build a context for the chat model.""" @@ -138,9 +67,7 @@ async def run( ) -> RetrievalResponse: chat_params = self.get_params(messages, overrides) - # Retrieve relevant items from the database - # Generate a contextual and content specific answer using the search results and chat history - contextual_messages, results = await self.retreive_and_build_context(chat_params=chat_params) + contextual_messages, results = await self.retrieve_and_build_context(chat_params=chat_params) chat_completion_response: ChatCompletion = await self.openai_chat_client.chat.completions.create( # Azure OpenAI takes the deployment name as the model name @@ -192,9 +119,7 @@ async def run_stream( ) -> AsyncGenerator[RetrievalResponseDelta, None]: chat_params = self.get_params(messages, overrides) - # Retrieve relevant items from the database - # Generate a contextual and content specific answer using the search results and chat history - contextual_messages, results = await self.retreive_and_build_context(chat_params=chat_params) + contextual_messages, results = await self.retrieve_and_build_context(chat_params=chat_params) chat_completion_async_stream: AsyncStream[ ChatCompletionChunk From 266325630a6e929e853820473d04a2bac9b0c5b3 Mon Sep 17 00:00:00 2001 From: John Aziz Date: Tue, 23 Jul 2024 12:48:35 +0000 Subject: [PATCH 11/16] add pytest-snapshot --- requirements-dev.txt | 3 +- .../advanced_chat_flow_response.json | 68 ++++ .../advanced_chat_streaming_flow_response.txt | 2 + .../simple_chat_flow_response.json | 56 +++ .../simple_chat_streaming_flow_response.txt | 2 + tests/test_api_routes.py | 359 +----------------- 6 files changed, 150 insertions(+), 340 deletions(-) create mode 100644 tests/snapshots/test_api_routes/test_advanced_chat_flow/advanced_chat_flow_response.json create mode 100644 tests/snapshots/test_api_routes/test_advanved_chat_streaming_flow/advanced_chat_streaming_flow_response.txt create mode 100644 tests/snapshots/test_api_routes/test_simple_chat_flow/simple_chat_flow_response.json create mode 100644 tests/snapshots/test_api_routes/test_simple_chat_streaming_flow/simple_chat_streaming_flow_response.txt diff --git a/requirements-dev.txt b/requirements-dev.txt index bfe6862b..14ba3930 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1,9 +1,10 @@ -r src/backend/requirements.txt ruff +mypy pre-commit pip-tools pip-compile-cross-platform pytest pytest-cov pytest-asyncio -mypy +pytest-snapshot diff --git a/tests/snapshots/test_api_routes/test_advanced_chat_flow/advanced_chat_flow_response.json b/tests/snapshots/test_api_routes/test_advanced_chat_flow/advanced_chat_flow_response.json new file mode 100644 index 00000000..2e9eb3ae --- /dev/null +++ b/tests/snapshots/test_api_routes/test_advanced_chat_flow/advanced_chat_flow_response.json @@ -0,0 +1,68 @@ +{ + "message": { + "content": "The capital of France is Paris. [Benefit_Options-2.pdf].", + "role": "assistant" + }, + "context": { + "data_points": { + "1": { + "id": 1, + "type": "Footwear", + "brand": "Daybird", + "name": "Wanderer Black Hiking Boots", + "description": "Daybird's Wanderer Hiking Boots in sleek black are perfect for all your outdoor adventures. These boots are made with a waterproof leather upper and a durable rubber sole for superior traction. With their cushioned insole and padded collar, these boots will keep you comfortable all day long.", + "price": 109.99 + } + }, + "thoughts": [ + { + "title": "Prompt to generate search arguments", + "description": [ + "{'role': 'system', 'content': 'Below is a history of the conversation so far, and a new question asked by the user that needs to be answered by searching database rows.\\nYou have access to an Azure PostgreSQL database with an items table that has columns for title, description, brand, price, and type.\\nGenerate a search query based on the conversation and the new question.\\nIf the question is not in English, translate the question to English before generating the search query.\\nIf you cannot generate a search query, return the original user question.\\nDO NOT return anything besides the query.'}", + "{'role': 'user', 'content': 'What is the capital of France?'}" + ], + "props": { + "model": "gpt-35-turbo", + "deployment": "gpt-35-turbo" + } + }, + { + "title": "Search using generated search arguments", + "description": "The capital of France is Paris. [Benefit_Options-2.pdf].", + "props": { + "top": 1, + "vector_search": true, + "text_search": true, + "filters": [] + } + }, + { + "title": "Search results", + "description": [ + { + "id": 1, + "type": "Footwear", + "brand": "Daybird", + "name": "Wanderer Black Hiking Boots", + "description": "Daybird's Wanderer Hiking Boots in sleek black are perfect for all your outdoor adventures. These boots are made with a waterproof leather upper and a durable rubber sole for superior traction. With their cushioned insole and padded collar, these boots will keep you comfortable all day long.", + "price": 109.99 + } + ], + "props": {} + }, + { + "title": "Prompt to generate answer", + "description": [ + "{'role': 'system', 'content': \"Assistant helps customers with questions about products.\\nRespond as if you are a salesperson helping a customer in a store. Do NOT respond with tables.\\nAnswer ONLY with the product details listed in the products.\\nIf there isn't enough information below, say you don't know.\\nDo not generate answers that don't use the sources below.\\nEach product has an ID in brackets followed by colon and the product details.\\nAlways include the product ID for each product you use in the response.\\nUse square brackets to reference the source, for example [52].\\nDon't combine citations, list each product separately, for example [27][51].\"}", + "{'role': 'user', 'content': \"What is the capital of France?\\n\\nSources:\\n[1]:Name:Wanderer Black Hiking Boots Description:Daybird's Wanderer Hiking Boots in sleek black are perfect for all your outdoor adventures. These boots are made with a waterproof leather upper and a durable rubber sole for superior traction. With their cushioned insole and padded collar, these boots will keep you comfortable all day long. Price:109.99 Brand:Daybird Type:Footwear\\n\\n\"}" + ], + "props": { + "model": "gpt-35-turbo", + "deployment": "gpt-35-turbo" + } + } + ], + "followup_questions": null + }, + "session_state": null +} \ No newline at end of file diff --git a/tests/snapshots/test_api_routes/test_advanved_chat_streaming_flow/advanced_chat_streaming_flow_response.txt b/tests/snapshots/test_api_routes/test_advanved_chat_streaming_flow/advanced_chat_streaming_flow_response.txt new file mode 100644 index 00000000..8b65342f --- /dev/null +++ b/tests/snapshots/test_api_routes/test_advanved_chat_streaming_flow/advanced_chat_streaming_flow_response.txt @@ -0,0 +1,2 @@ +{"delta":null,"context":{"data_points":{"1":{"id":1,"type":"Footwear","brand":"Daybird","name":"Wanderer Black Hiking Boots","description":"Daybird's Wanderer Hiking Boots in sleek black are perfect for all your outdoor adventures. These boots are made with a waterproof leather upper and a durable rubber sole for superior traction. With their cushioned insole and padded collar, these boots will keep you comfortable all day long.","price":109.99}},"thoughts":[{"title":"Prompt to generate search arguments","description":["{'role': 'system', 'content': 'Below is a history of the conversation so far, and a new question asked by the user that needs to be answered by searching database rows.\\nYou have access to an Azure PostgreSQL database with an items table that has columns for title, description, brand, price, and type.\\nGenerate a search query based on the conversation and the new question.\\nIf the question is not in English, translate the question to English before generating the search query.\\nIf you cannot generate a search query, return the original user question.\\nDO NOT return anything besides the query.'}","{'role': 'user', 'content': 'What is the capital of France?'}"],"props":{"model":"gpt-35-turbo","deployment":"gpt-35-turbo"}},{"title":"Search using generated search arguments","description":"The capital of France is Paris. [Benefit_Options-2.pdf].","props":{"top":1,"vector_search":true,"text_search":true,"filters":[]}},{"title":"Search results","description":[{"id":1,"type":"Footwear","brand":"Daybird","name":"Wanderer Black Hiking Boots","description":"Daybird's Wanderer Hiking Boots in sleek black are perfect for all your outdoor adventures. These boots are made with a waterproof leather upper and a durable rubber sole for superior traction. With their cushioned insole and padded collar, these boots will keep you comfortable all day long.","price":109.99}],"props":{}},{"title":"Prompt to generate answer","description":["{'role': 'system', 'content': \"Assistant helps customers with questions about products.\\nRespond as if you are a salesperson helping a customer in a store. Do NOT respond with tables.\\nAnswer ONLY with the product details listed in the products.\\nIf there isn't enough information below, say you don't know.\\nDo not generate answers that don't use the sources below.\\nEach product has an ID in brackets followed by colon and the product details.\\nAlways include the product ID for each product you use in the response.\\nUse square brackets to reference the source, for example [52].\\nDon't combine citations, list each product separately, for example [27][51].\"}","{'role': 'user', 'content': \"What is the capital of France?\\n\\nSources:\\n[1]:Name:Wanderer Black Hiking Boots Description:Daybird's Wanderer Hiking Boots in sleek black are perfect for all your outdoor adventures. These boots are made with a waterproof leather upper and a durable rubber sole for superior traction. With their cushioned insole and padded collar, these boots will keep you comfortable all day long. Price:109.99 Brand:Daybird Type:Footwear\\n\\n\"}"],"props":{"model":"gpt-35-turbo","deployment":"gpt-35-turbo"}}],"followup_questions":null},"session_state":null} +{"delta":{"content":"The capital of France is Paris. [Benefit_Options-2.pdf].","role":"assistant"},"context":null,"session_state":null} diff --git a/tests/snapshots/test_api_routes/test_simple_chat_flow/simple_chat_flow_response.json b/tests/snapshots/test_api_routes/test_simple_chat_flow/simple_chat_flow_response.json new file mode 100644 index 00000000..d5ecba21 --- /dev/null +++ b/tests/snapshots/test_api_routes/test_simple_chat_flow/simple_chat_flow_response.json @@ -0,0 +1,56 @@ +{ + "message": { + "content": "The capital of France is Paris. [Benefit_Options-2.pdf].", + "role": "assistant" + }, + "context": { + "data_points": { + "1": { + "id": 1, + "type": "Footwear", + "brand": "Daybird", + "name": "Wanderer Black Hiking Boots", + "description": "Daybird's Wanderer Hiking Boots in sleek black are perfect for all your outdoor adventures. These boots are made with a waterproof leather upper and a durable rubber sole for superior traction. With their cushioned insole and padded collar, these boots will keep you comfortable all day long.", + "price": 109.99 + } + }, + "thoughts": [ + { + "title": "Search query for database", + "description": "What is the capital of France?", + "props": { + "top": 1, + "vector_search": true, + "text_search": true + } + }, + { + "title": "Search results", + "description": [ + { + "id": 1, + "type": "Footwear", + "brand": "Daybird", + "name": "Wanderer Black Hiking Boots", + "description": "Daybird's Wanderer Hiking Boots in sleek black are perfect for all your outdoor adventures. These boots are made with a waterproof leather upper and a durable rubber sole for superior traction. With their cushioned insole and padded collar, these boots will keep you comfortable all day long.", + "price": 109.99 + } + ], + "props": {} + }, + { + "title": "Prompt to generate answer", + "description": [ + "{'role': 'system', 'content': \"Assistant helps customers with questions about products.\\nRespond as if you are a salesperson helping a customer in a store. Do NOT respond with tables.\\nAnswer ONLY with the product details listed in the products.\\nIf there isn't enough information below, say you don't know.\\nDo not generate answers that don't use the sources below.\\nEach product has an ID in brackets followed by colon and the product details.\\nAlways include the product ID for each product you use in the response.\\nUse square brackets to reference the source, for example [52].\\nDon't combine citations, list each product separately, for example [27][51].\"}", + "{'role': 'user', 'content': \"What is the capital of France?\\n\\nSources:\\n[1]:Name:Wanderer Black Hiking Boots Description:Daybird's Wanderer Hiking Boots in sleek black are perfect for all your outdoor adventures. These boots are made with a waterproof leather upper and a durable rubber sole for superior traction. With their cushioned insole and padded collar, these boots will keep you comfortable all day long. Price:109.99 Brand:Daybird Type:Footwear\\n\\n\"}" + ], + "props": { + "model": "gpt-35-turbo", + "deployment": "gpt-35-turbo" + } + } + ], + "followup_questions": null + }, + "session_state": null +} \ No newline at end of file diff --git a/tests/snapshots/test_api_routes/test_simple_chat_streaming_flow/simple_chat_streaming_flow_response.txt b/tests/snapshots/test_api_routes/test_simple_chat_streaming_flow/simple_chat_streaming_flow_response.txt new file mode 100644 index 00000000..6251bd52 --- /dev/null +++ b/tests/snapshots/test_api_routes/test_simple_chat_streaming_flow/simple_chat_streaming_flow_response.txt @@ -0,0 +1,2 @@ +{"delta":null,"context":{"data_points":{"1":{"id":1,"type":"Footwear","brand":"Daybird","name":"Wanderer Black Hiking Boots","description":"Daybird's Wanderer Hiking Boots in sleek black are perfect for all your outdoor adventures. These boots are made with a waterproof leather upper and a durable rubber sole for superior traction. With their cushioned insole and padded collar, these boots will keep you comfortable all day long.","price":109.99}},"thoughts":[{"title":"Search query for database","description":"What is the capital of France?","props":{"top":1,"vector_search":true,"text_search":true}},{"title":"Search results","description":[{"id":1,"type":"Footwear","brand":"Daybird","name":"Wanderer Black Hiking Boots","description":"Daybird's Wanderer Hiking Boots in sleek black are perfect for all your outdoor adventures. These boots are made with a waterproof leather upper and a durable rubber sole for superior traction. With their cushioned insole and padded collar, these boots will keep you comfortable all day long.","price":109.99}],"props":{}},{"title":"Prompt to generate answer","description":["{'role': 'system', 'content': \"Assistant helps customers with questions about products.\\nRespond as if you are a salesperson helping a customer in a store. Do NOT respond with tables.\\nAnswer ONLY with the product details listed in the products.\\nIf there isn't enough information below, say you don't know.\\nDo not generate answers that don't use the sources below.\\nEach product has an ID in brackets followed by colon and the product details.\\nAlways include the product ID for each product you use in the response.\\nUse square brackets to reference the source, for example [52].\\nDon't combine citations, list each product separately, for example [27][51].\"}","{'role': 'user', 'content': \"What is the capital of France?\\n\\nSources:\\n[1]:Name:Wanderer Black Hiking Boots Description:Daybird's Wanderer Hiking Boots in sleek black are perfect for all your outdoor adventures. These boots are made with a waterproof leather upper and a durable rubber sole for superior traction. With their cushioned insole and padded collar, these boots will keep you comfortable all day long. Price:109.99 Brand:Daybird Type:Footwear\\n\\n\"}"],"props":{"model":"gpt-35-turbo","deployment":"gpt-35-turbo"}}],"followup_questions":null},"session_state":null} +{"delta":{"content":"The capital of France is Paris. [Benefit_Options-2.pdf].","role":"assistant"},"context":null,"session_state":null} diff --git a/tests/test_api_routes.py b/tests/test_api_routes.py index 7f138759..1d48ae05 100644 --- a/tests/test_api_routes.py +++ b/tests/test_api_routes.py @@ -1,3 +1,5 @@ +import json + import pytest from tests.data import test_data @@ -105,7 +107,7 @@ async def test_search_handler_422(test_client): @pytest.mark.asyncio -async def test_simple_chat_flow(test_client): +async def test_simple_chat_flow(test_client, snapshot): """test the simple chat flow route with hybrid retrieval mode""" response = test_client.post( "/chat", @@ -120,115 +122,11 @@ async def test_simple_chat_flow(test_client): assert response.status_code == 200 assert response.headers["Content-Type"] == "application/json" - assert response_data["message"]["content"] == "The capital of France is Paris. [Benefit_Options-2.pdf]." - assert response_data["message"]["role"] == "assistant" - assert response_data["context"]["data_points"] == { - "1": { - "id": 1, - "name": "Wanderer Black Hiking Boots", - "description": "Daybird's Wanderer Hiking Boots in sleek black are perfect for all " - "your outdoor adventures. These boots are made with a waterproof " - "leather upper and a durable rubber sole for superior traction. With " - "their cushioned insole and padded collar, these boots will keep you " - "comfortable all day long.", - "brand": "Daybird", - "price": 109.99, - "type": "Footwear", - } - } - assert response_data["context"]["thoughts"] == [ - { - "description": "What is the capital of France?", - "props": {"text_search": True, "top": 1, "vector_search": True}, - "title": "Search query for database", - }, - { - "description": [ - { - "brand": "Daybird", - "description": "Daybird's Wanderer Hiking Boots in sleek black are perfect for all your " - "outdoor adventures. These boots are made with a waterproof leather upper and a durable " - "rubber sole for superior traction. With their cushioned insole and padded collar, " - "these boots will keep you comfortable all day long.", - "id": 1, - "name": "Wanderer Black Hiking Boots", - "price": 109.99, - "type": "Footwear", - }, - ], - "props": {}, - "title": "Search results", - }, - { - "description": [ - "{'role': 'system', 'content': \"Assistant helps customers with questions about " - "products.\\nRespond as if you are a salesperson helping a customer in a store. " - "Do NOT respond with tables.\\nAnswer ONLY with the product details listed in the " - "products.\\nIf there isn't enough information below, say you don't know.\\nDo not " - "generate answers that don't use the sources below.\\nEach product has an ID in brackets " - "followed by colon and the product details.\\nAlways include the product ID for each product " - "you use in the response.\\nUse square brackets to reference the source, " - "for example [52].\\nDon't combine citations, list each product separately, for example [27][51].\"}", - "{'role': 'user', 'content': \"What is the capital of France?\\n\\nSources:\\n[1]:Name:Wanderer " - "Black Hiking Boots Description:Daybird's Wanderer Hiking Boots in sleek black are perfect for " - "all your outdoor adventures. These boots are made with a waterproof leather upper and a durable " - "rubber sole for superior traction. With their cushioned insole and padded collar, " - "these boots will keep you comfortable all day long. Price:109.99 Brand:Daybird " - 'Type:Footwear\\n\\n"}', - ], - "props": {"deployment": "gpt-35-turbo", "model": "gpt-35-turbo"}, - "title": "Prompt to generate answer", - }, - ] - assert response_data["context"]["thoughts"] == [ - { - "description": "What is the capital of France?", - "props": {"text_search": True, "top": 1, "vector_search": True}, - "title": "Search query for database", - }, - { - "description": [ - { - "brand": "Daybird", - "description": "Daybird's Wanderer Hiking Boots in sleek black are perfect for all " - "your outdoor adventures. These boots are made with a waterproof leather upper and " - "a durable rubber sole for superior traction. With their cushioned insole and padded " - "collar, these boots will keep you comfortable all day long.", - "id": 1, - "name": "Wanderer Black Hiking Boots", - "price": 109.99, - "type": "Footwear", - } - ], - "props": {}, - "title": "Search results", - }, - { - "description": [ - "{'role': 'system', 'content': \"Assistant helps customers with questions about " - "products.\\nRespond as if you are a salesperson helping a customer in a store. " - "Do NOT respond with tables.\\nAnswer ONLY with the product details listed in the " - "products.\\nIf there isn't enough information below, say you don't know.\\nDo not " - "generate answers that don't use the sources below.\\nEach product has an ID in brackets " - "followed by colon and the product details.\\nAlways include the product ID for each product " - "you use in the response.\\nUse square brackets to reference the source, " - "for example [52].\\nDon't combine citations, list each product separately, for example [27][51].\"}", - "{'role': 'user', 'content': \"What is the capital of France?\\n\\nSources:\\n[1]:Name:Wanderer " - "Black Hiking Boots Description:Daybird's Wanderer Hiking Boots in sleek black are perfect for " - "all your outdoor adventures. These boots are made with a waterproof leather upper and a durable " - "rubber sole for superior traction. With their cushioned insole and padded collar, " - "these boots will keep you comfortable all day long. Price:109.99 Brand:Daybird " - 'Type:Footwear\\n\\n"}', - ], - "props": {"deployment": "gpt-35-turbo", "model": "gpt-35-turbo"}, - "title": "Prompt to generate answer", - }, - ] - assert response_data["session_state"] is None + snapshot.assert_match(json.dumps(response_data, indent=4), "simple_chat_flow_response.json") @pytest.mark.asyncio -async def test_simple_chat_streaming_flow(test_client): +async def test_simple_chat_streaming_flow(test_client, snapshot): """test the simple chat streaming flow route with hybrid retrieval mode""" response = test_client.post( "/chat/stream", @@ -239,51 +137,17 @@ async def test_simple_chat_streaming_flow(test_client): "messages": [{"content": "What is the capital of France?", "role": "user"}], }, ) - response_data = response.content.split(b"\n") + response_data = response.content assert response.status_code == 200 assert response.headers["Content-Type"] == "application/x-ndjson" - assert response_data[0] == ( - b'{"delta":null,"context":{"data_points":{"1":{"id":1,"type":"Footwear","brand' - + b'":"Daybird","name":"Wanderer Black Hiking Boots","description":"Daybird\'' - + b"s Wanderer Hiking Boots in sleek black are perfect for all your outdoor adve" - + b"ntures. These boots are made with a waterproof leather upper and a durable r" - + b"ubber sole for superior traction. With their cushioned insole and padded col" - + b'lar, these boots will keep you comfortable all day long.","price":109.99}},"' - + b'thoughts":[{"title":"Search query for database","description":"What is the c' - + b'apital of France?","props":{"top":1,"vector_search":true,"text_search":true}' - + b'},{"title":"Search results","description":[{"id":1,"type":"Footwear","brand"' - + b':"Daybird","name":"Wanderer Black Hiking Boots","description":"Daybird\'s' - + b" Wanderer Hiking Boots in sleek black are perfect for all your outdoor adven" - + b"tures. These boots are made with a waterproof leather upper and a durable ru" - + b"bber sole for superior traction. With their cushioned insole and padded coll" - + b'ar, these boots will keep you comfortable all day long.","price":109.99}],"p' - + b'rops":{}},{"title":"Prompt to generate answer","description":["{\'role\': ' - + b"'system', 'content': \\\"Assistant helps customers with questions abou" - + b"t products.\\\\nRespond as if you are a salesperson helping a customer in " - + b"a store. Do NOT respond with tables.\\\\nAnswer ONLY with the product deta" - + b"ils listed in the products.\\\\nIf there isn't enough information below, s" - + b"ay you don't know.\\\\nDo not generate answers that don't use the sources " - + b"below.\\\\nEach product has an ID in brackets followed by colon and the pr" - + b"oduct details.\\\\nAlways include the product ID for each product you use " - + b"in the response.\\\\nUse square brackets to reference the source, for exam" - + b"ple [52].\\\\nDon't combine citations, list each product separately, for e" - + b"xample [27][51].\\\"}\",\"{'role': 'user', 'content': \\\"What is the capi" - + b"tal of France?\\\\n\\\\nSources:\\\\n[1]:Name:Wanderer Black Hiking Boots " - + b"Description:Daybird's Wanderer Hiking Boots in sleek black are perfect for a" - + b"ll your outdoor adventures. These boots are made with a waterproof leather u" - + b"pper and a durable rubber sole for superior traction. With their cushioned i" - + b"nsole and padded collar, these boots will keep you comfortable all day long." - + b' Price:109.99 Brand:Daybird Type:Footwear\\\\n\\\\n\\"}"],"props":{"model' - + b'":"gpt-35-turbo","deployment":"gpt-35-turbo"}}],"followup_questions":null},"' - + b'session_state":null}' - ) + snapshot.assert_match(response_data, "simple_chat_streaming_flow_response.txt") @pytest.mark.asyncio -async def test_advanved_chat_streaming_flow(test_client): - """test the advanced chat streaming flow route with hybrid retrieval mode""" +async def test_advanced_chat_flow(test_client, snapshot): + """test the advanced chat flow route with hybrid retrieval mode""" response = test_client.post( - "/chat/stream", + "/chat", json={ "context": { "overrides": {"top": 1, "use_advanced_flow": True, "retrieval_mode": "hybrid", "temperature": 0.3} @@ -291,62 +155,18 @@ async def test_advanved_chat_streaming_flow(test_client): "messages": [{"content": "What is the capital of France?", "role": "user"}], }, ) - response_data = response.content.split(b"\n") + response_data = response.json() + assert response.status_code == 200 - assert response.headers["Content-Type"] == "application/x-ndjson" - assert response_data[0] == ( - b'{"delta":null,"context":{"data_points":{"1":{"id":1,"type":"Footwear","brand' - + b'":"Daybird","name":"Wanderer Black Hiking Boots","description":"Daybird\'' - + b"s Wanderer Hiking Boots in sleek black are perfect for all your outdoor adve" - + b"ntures. These boots are made with a waterproof leather upper and a durable r" - + b"ubber sole for superior traction. With their cushioned insole and padded col" - + b'lar, these boots will keep you comfortable all day long.","price":109.99}},"' - + b'thoughts":[{"title":"Prompt to generate search arguments","description":' - + b"[\"{'role': 'system', 'content': 'Below is a history of the conversat" - + b"ion so far, and a new question asked by the user that needs to be answered b" - + b"y searching database rows.\\\\nYou have access to an Azure PostgreSQL data" - + b"base with an items table that has columns for title, description, brand, pri" - + b"ce, and type.\\\\nGenerate a search query based on the conversation and th" - + b"e new question.\\\\nIf the question is not in English, translate the quest" - + b"ion to English before generating the search query.\\\\nIf you cannot gener" - + b"ate a search query, return the original user question.\\\\nDO NOT return a" - + b"nything besides the query.'}\",\"{'role': 'user', 'content': 'What is " - + b'the capital of France?\'}"],"props":{"model":"gpt-35-turbo","deployment":' - + b'"gpt-35-turbo"}},{"title":"Search using generated search arguments","descrip' - + b'tion":"The capital of France is Paris. [Benefit_Options-2.pdf].","props":{"t' - + b'op":1,"vector_search":true,"text_search":true,"filters":[]}},{"title":"Searc' - + b'h results","description":[{"id":1,"type":"Footwear","brand":"Daybird","name"' - + b':"Wanderer Black Hiking Boots","description":"Daybird\'s Wanderer Hiking ' - + b"Boots in sleek black are perfect for all your outdoor adventures. These boot" - + b"s are made with a waterproof leather upper and a durable rubber sole for sup" - + b"erior traction. With their cushioned insole and padded collar, these boots w" - + b'ill keep you comfortable all day long.","price":109.99}],"props":{}},{"title' - + b'":"Prompt to generate answer","description":["{\'role\': \'system\', \'co' - + b"ntent': \\\"Assistant helps customers with questions about products.\\\\nRes" - + b"pond as if you are a salesperson helping a customer in a store. Do NOT respo" - + b"nd with tables.\\\\nAnswer ONLY with the product details listed in the pro" - + b"ducts.\\\\nIf there isn't enough information below, say you don't know.\\\\n" - + b"Do not generate answers that don't use the sources below.\\\\nEach product" - + b" has an ID in brackets followed by colon and the product details.\\\\nAlwa" - + b"ys include the product ID for each product you use in the response.\\\\nUs" - + b"e square brackets to reference the source, for example [52].\\\\nDon't com" - + b'bine citations, list each product separately, for example [27][51].\\"}",' - + b"\"{'role': 'user', 'content': \\\"What is the capital of France?\\\\n" - + b"\\\\nSources:\\\\n[1]:Name:Wanderer Black Hiking Boots Description:Daybird's" - + b" Wanderer Hiking Boots in sleek black are perfect for all your outdoor adven" - + b"tures. These boots are made with a waterproof leather upper and a durable ru" - + b"bber sole for superior traction. With their cushioned insole and padded coll" - + b"ar, these boots will keep you comfortable all day long. Price:109.99 Brand:D" - + b'aybird Type:Footwear\\\\n\\\\n\\"}"],"props":{"model":"gpt-35-turbo","dep' - + b'loyment":"gpt-35-turbo"}}],"followup_questions":null},"session_state":null}' - ) + assert response.headers["Content-Type"] == "application/json" + snapshot.assert_match(json.dumps(response_data, indent=4), "advanced_chat_flow_response.json") @pytest.mark.asyncio -async def test_advanced_chat_flow(test_client): - """test the advanced chat flow route with hybrid retrieval mode""" +async def test_advanved_chat_streaming_flow(test_client, snapshot): + """test the advanced chat streaming flow route with hybrid retrieval mode""" response = test_client.post( - "/chat", + "/chat/stream", json={ "context": { "overrides": {"top": 1, "use_advanced_flow": True, "retrieval_mode": "hybrid", "temperature": 0.3} @@ -354,149 +174,10 @@ async def test_advanced_chat_flow(test_client): "messages": [{"content": "What is the capital of France?", "role": "user"}], }, ) - response_data = response.json() - + response_data = response.content assert response.status_code == 200 - assert response.headers["Content-Type"] == "application/json" - assert response_data["message"]["content"] == "The capital of France is Paris. [Benefit_Options-2.pdf]." - assert response_data["message"]["role"] == "assistant" - assert response_data["context"]["data_points"] == { - "1": { - "id": 1, - "name": "Wanderer Black Hiking Boots", - "description": "Daybird's Wanderer Hiking Boots in sleek black are perfect for all " - "your outdoor adventures. These boots are made with a waterproof " - "leather upper and a durable rubber sole for superior traction. With " - "their cushioned insole and padded collar, these boots will keep you " - "comfortable all day long.", - "brand": "Daybird", - "price": 109.99, - "type": "Footwear", - } - } - assert response_data["context"]["thoughts"] == [ - { - "description": [ - "{'role': 'system', 'content': 'Below is a history of the " - "conversation so far, and a new question asked by the user that " - "needs to be answered by searching database rows.\\nYou have " - "access to an Azure PostgreSQL database with an items table that " - "has columns for title, description, brand, price, and " - "type.\\nGenerate a search query based on the conversation and the " - "new question.\\nIf the question is not in English, translate the " - "question to English before generating the search query.\\nIf you " - "cannot generate a search query, return the original user " - "question.\\nDO NOT return anything besides the query.'}", - "{'role': 'user', 'content': 'What is the capital of France?'}", - ], - "props": {"deployment": "gpt-35-turbo", "model": "gpt-35-turbo"}, - "title": "Prompt to generate search arguments", - }, - { - "description": "The capital of France is Paris. [Benefit_Options-2.pdf].", - "props": {"filters": [], "text_search": True, "top": 1, "vector_search": True}, - "title": "Search using generated search arguments", - }, - { - "description": [ - { - "brand": "Daybird", - "description": "Daybird's Wanderer Hiking Boots in sleek black are perfect for all your " - "outdoor adventures. These boots are made with a waterproof leather upper and a durable " - "rubber sole for superior traction. With their cushioned insole and padded collar, " - "these boots will keep you comfortable all day long.", - "id": 1, - "name": "Wanderer Black Hiking Boots", - "price": 109.99, - "type": "Footwear", - }, - ], - "props": {}, - "title": "Search results", - }, - { - "description": [ - "{'role': 'system', 'content': \"Assistant helps customers with questions about " - "products.\\nRespond as if you are a salesperson helping a customer in a store. " - "Do NOT respond with tables.\\nAnswer ONLY with the product details listed in the " - "products.\\nIf there isn't enough information below, say you don't know.\\nDo not " - "generate answers that don't use the sources below.\\nEach product has an ID in brackets " - "followed by colon and the product details.\\nAlways include the product ID for each product " - "you use in the response.\\nUse square brackets to reference the source, " - "for example [52].\\nDon't combine citations, list each product separately, for example [27][51].\"}", - "{'role': 'user', 'content': \"What is the capital of France?\\n\\nSources:\\n[1]:Name:Wanderer " - "Black Hiking Boots Description:Daybird's Wanderer Hiking Boots in sleek black are perfect for " - "all your outdoor adventures. These boots are made with a waterproof leather upper and a durable " - "rubber sole for superior traction. With their cushioned insole and padded collar, " - "these boots will keep you comfortable all day long. Price:109.99 Brand:Daybird " - 'Type:Footwear\\n\\n"}', - ], - "props": {"deployment": "gpt-35-turbo", "model": "gpt-35-turbo"}, - "title": "Prompt to generate answer", - }, - ] - assert response_data["context"]["thoughts"] == [ - { - "description": [ - "{'role': 'system', 'content': 'Below is a history of the " - "conversation so far, and a new question asked by the user that " - "needs to be answered by searching database rows.\\nYou have " - "access to an Azure PostgreSQL database with an items table that " - "has columns for title, description, brand, price, and " - "type.\\nGenerate a search query based on the conversation and the " - "new question.\\nIf the question is not in English, translate the " - "question to English before generating the search query.\\nIf you " - "cannot generate a search query, return the original user " - "question.\\nDO NOT return anything besides the query.'}", - "{'role': 'user', 'content': 'What is the capital of France?'}", - ], - "props": {"deployment": "gpt-35-turbo", "model": "gpt-35-turbo"}, - "title": "Prompt to generate search arguments", - }, - { - "description": "The capital of France is Paris. [Benefit_Options-2.pdf].", - "props": {"filters": [], "text_search": True, "top": 1, "vector_search": True}, - "title": "Search using generated search arguments", - }, - { - "description": [ - { - "brand": "Daybird", - "description": "Daybird's Wanderer Hiking Boots in sleek black are perfect for all " - "your outdoor adventures. These boots are made with a waterproof leather upper and " - "a durable rubber sole for superior traction. With their cushioned insole and padded " - "collar, these boots will keep you comfortable all day long.", - "id": 1, - "name": "Wanderer Black Hiking Boots", - "price": 109.99, - "type": "Footwear", - } - ], - "props": {}, - "title": "Search results", - }, - { - "description": [ - "{'role': 'system', 'content': \"Assistant helps customers with questions about " - "products.\\nRespond as if you are a salesperson helping a customer in a store. " - "Do NOT respond with tables.\\nAnswer ONLY with the product details listed in the " - "products.\\nIf there isn't enough information below, say you don't know.\\nDo not " - "generate answers that don't use the sources below.\\nEach product has an ID in brackets " - "followed by colon and the product details.\\nAlways include the product ID for each product " - "you use in the response.\\nUse square brackets to reference the source, " - "for example [52].\\nDon't combine citations, list each product separately, for example [27][51].\"}", - "{'role': 'user', 'content': \"What is the capital of France?\\n\\nSources:\\n[1]:Name:Wanderer " - "Black Hiking Boots Description:Daybird's Wanderer Hiking Boots in sleek black are perfect for " - "all your outdoor adventures. These boots are made with a waterproof leather upper and a durable " - "rubber sole for superior traction. With their cushioned insole and padded collar, " - "these boots will keep you comfortable all day long. Price:109.99 Brand:Daybird " - 'Type:Footwear\\n\\n"}', - ], - "props": {"deployment": "gpt-35-turbo", "model": "gpt-35-turbo"}, - "title": "Prompt to generate answer", - }, - ] - assert response_data["session_state"] is None + assert response.headers["Content-Type"] == "application/x-ndjson" + snapshot.assert_match(response_data, "advanced_chat_streaming_flow_response.txt") @pytest.mark.asyncio From 063bb9d7ff60d7ed0774a71da8df1fe5e906adca Mon Sep 17 00:00:00 2001 From: John Aziz Date: Tue, 23 Jul 2024 15:59:52 +0000 Subject: [PATCH 12/16] add type for chat overrides --- src/backend/fastapi_app/api_models.py | 20 +++++++++++- src/backend/fastapi_app/rag_advanced.py | 5 +-- src/backend/fastapi_app/rag_base.py | 33 ++++++++------------ src/backend/fastapi_app/rag_simple.py | 6 ++-- src/backend/fastapi_app/routes/api_routes.py | 8 ++--- src/frontend/src/api/models.ts | 10 +++++- src/frontend/src/pages/chat/Chat.tsx | 4 +-- 7 files changed, 53 insertions(+), 33 deletions(-) diff --git a/src/backend/fastapi_app/api_models.py b/src/backend/fastapi_app/api_models.py index db0ea2c9..6b5204e6 100644 --- a/src/backend/fastapi_app/api_models.py +++ b/src/backend/fastapi_app/api_models.py @@ -16,9 +16,27 @@ class Message(BaseModel): role: AIChatRoles = AIChatRoles.USER +class RetrievalMode(str, Enum): + TEXT = "text" + VECTORS = "vectors" + HYBRID = "hybrid" + + +class ChatRequestOverrides(BaseModel): + top: int = 3 + temperature: float = 0.3 + retrieval_mode: RetrievalMode = RetrievalMode.HYBRID + use_advanced_flow: bool = True + prompt_template: str | None = None + + +class ChatRequestContext(BaseModel): + overrides: ChatRequestOverrides + + class ChatRequest(BaseModel): messages: list[ChatCompletionMessageParam] - context: dict = {} + context: ChatRequestContext class ThoughtStep(BaseModel): diff --git a/src/backend/fastapi_app/rag_advanced.py b/src/backend/fastapi_app/rag_advanced.py index 5232ab23..1f9d8fc5 100644 --- a/src/backend/fastapi_app/rag_advanced.py +++ b/src/backend/fastapi_app/rag_advanced.py @@ -7,6 +7,7 @@ from fastapi_app.api_models import ( AIChatRoles, + ChatRequestOverrides, Message, RAGContext, RetrievalResponse, @@ -92,7 +93,7 @@ async def retrieve_and_build_context( async def run( self, messages: list[ChatCompletionMessageParam], - overrides: dict[str, Any] = {}, + overrides: ChatRequestOverrides, ) -> RetrievalResponse: chat_params = self.get_params(messages, overrides) @@ -165,7 +166,7 @@ async def run( async def run_stream( self, messages: list[ChatCompletionMessageParam], - overrides: dict[str, Any] = {}, + overrides: ChatRequestOverrides, ) -> AsyncGenerator[RetrievalResponseDelta, None]: chat_params = self.get_params(messages, overrides) diff --git a/src/backend/fastapi_app/rag_base.py b/src/backend/fastapi_app/rag_base.py index 586401a2..57247a0f 100644 --- a/src/backend/fastapi_app/rag_base.py +++ b/src/backend/fastapi_app/rag_base.py @@ -1,27 +1,20 @@ import pathlib from abc import ABC, abstractmethod from collections.abc import AsyncGenerator -from typing import Any from openai.types.chat import ChatCompletionMessageParam -from pydantic import BaseModel -from fastapi_app.api_models import ( - RetrievalResponse, - RetrievalResponseDelta, -) +from fastapi_app.api_models import ChatRequestOverrides, RetrievalResponse, RetrievalResponseDelta from fastapi_app.postgres_models import Item -class ChatParams(BaseModel): - top: int = 3 - temperature: float = 0.3 +class ChatParams(ChatRequestOverrides): + prompt_template: str response_token_limit: int = 1024 enable_text_search: bool enable_vector_search: bool original_user_query: str past_messages: list[ChatCompletionMessageParam] - prompt_template: str class RAGChatBase(ABC): @@ -29,14 +22,12 @@ class RAGChatBase(ABC): query_prompt_template = open(current_dir / "prompts/query.txt").read() answer_prompt_template = open(current_dir / "prompts/answer.txt").read() - def get_params(self, messages: list[ChatCompletionMessageParam], overrides: dict[str, Any]) -> ChatParams: - top: int = overrides.get("top", 3) - temperature: float = overrides.get("temperature", 0.3) + def get_params(self, messages: list[ChatCompletionMessageParam], overrides: ChatRequestOverrides) -> ChatParams: response_token_limit = 1024 - prompt_template = overrides.get("prompt_template") or self.answer_prompt_template + prompt_template = overrides.prompt_template or self.answer_prompt_template - enable_text_search = overrides.get("retrieval_mode") in ["text", "hybrid", None] - enable_vector_search = overrides.get("retrieval_mode") in ["vectors", "hybrid", None] + enable_text_search = overrides.retrieval_mode in ["text", "hybrid", None] + enable_vector_search = overrides.retrieval_mode in ["vectors", "hybrid", None] original_user_query = messages[-1]["content"] if not isinstance(original_user_query, str): @@ -44,8 +35,10 @@ def get_params(self, messages: list[ChatCompletionMessageParam], overrides: dict past_messages = messages[:-1] return ChatParams( - top=top, - temperature=temperature, + top=overrides.top, + temperature=overrides.temperature, + retrieval_mode=overrides.retrieval_mode, + use_advanced_flow=overrides.use_advanced_flow, response_token_limit=response_token_limit, prompt_template=prompt_template, enable_text_search=enable_text_search, @@ -67,7 +60,7 @@ async def retrieve_and_build_context( async def run( self, messages: list[ChatCompletionMessageParam], - overrides: dict[str, Any] = {}, + overrides: ChatRequestOverrides, ) -> RetrievalResponse: raise NotImplementedError @@ -75,7 +68,7 @@ async def run( async def run_stream( self, messages: list[ChatCompletionMessageParam], - overrides: dict[str, Any] = {}, + overrides: ChatRequestOverrides, ) -> AsyncGenerator[RetrievalResponseDelta, None]: raise NotImplementedError if False: diff --git a/src/backend/fastapi_app/rag_simple.py b/src/backend/fastapi_app/rag_simple.py index 6fec8cdc..f3a5754a 100644 --- a/src/backend/fastapi_app/rag_simple.py +++ b/src/backend/fastapi_app/rag_simple.py @@ -1,5 +1,4 @@ from collections.abc import AsyncGenerator -from typing import Any from openai import AsyncAzureOpenAI, AsyncOpenAI, AsyncStream from openai.types.chat import ChatCompletion, ChatCompletionChunk, ChatCompletionMessageParam @@ -7,6 +6,7 @@ from fastapi_app.api_models import ( AIChatRoles, + ChatRequestOverrides, Message, RAGContext, RetrievalResponse, @@ -63,7 +63,7 @@ async def retrieve_and_build_context( async def run( self, messages: list[ChatCompletionMessageParam], - overrides: dict[str, Any] = {}, + overrides: ChatRequestOverrides, ) -> RetrievalResponse: chat_params = self.get_params(messages, overrides) @@ -115,7 +115,7 @@ async def run( async def run_stream( self, messages: list[ChatCompletionMessageParam], - overrides: dict[str, Any] = {}, + overrides: ChatRequestOverrides, ) -> AsyncGenerator[RetrievalResponseDelta, None]: chat_params = self.get_params(messages, overrides) diff --git a/src/backend/fastapi_app/routes/api_routes.py b/src/backend/fastapi_app/routes/api_routes.py index e531885e..ac53497e 100644 --- a/src/backend/fastapi_app/routes/api_routes.py +++ b/src/backend/fastapi_app/routes/api_routes.py @@ -93,7 +93,7 @@ async def chat_handler( openai_chat: ChatClient, chat_request: ChatRequest, ): - overrides = chat_request.context.get("overrides", {}) + overrides = chat_request.context.overrides searcher = PostgresSearcher( db_session=database_session, @@ -102,7 +102,7 @@ async def chat_handler( embed_model=context.openai_embed_model, embed_dimensions=context.openai_embed_dimensions, ) - if overrides.get("use_advanced_flow"): + if overrides.use_advanced_flow: run_ragchat = AdvancedRAGChat( searcher=searcher, openai_chat_client=openai_chat.client, @@ -129,7 +129,7 @@ async def chat_stream_handler( openai_chat: ChatClient, chat_request: ChatRequest, ): - overrides = chat_request.context.get("overrides", {}) + overrides = chat_request.context.overrides searcher = PostgresSearcher( db_session=database_session, @@ -138,7 +138,7 @@ async def chat_stream_handler( embed_model=context.openai_embed_model, embed_dimensions=context.openai_embed_dimensions, ) - if overrides.get("use_advanced_flow"): + if overrides.use_advanced_flow: run_ragchat = AdvancedRAGChat( searcher=searcher, openai_chat_client=openai_chat.client, diff --git a/src/frontend/src/api/models.ts b/src/frontend/src/api/models.ts index cd7d3c3b..4e9c3e26 100644 --- a/src/frontend/src/api/models.ts +++ b/src/frontend/src/api/models.ts @@ -1,4 +1,4 @@ -import { AIChatCompletion, AIChatCompletionDelta } from "@microsoft/ai-chat-protocol"; +import { AIChatCompletion, AIChatCompletionDelta, AIChatCompletionOperationOptions } from "@microsoft/ai-chat-protocol"; export const enum RetrievalMode { Hybrid = "hybrid", @@ -14,6 +14,14 @@ export type ChatAppRequestOverrides = { prompt_template?: string; }; +export type ChatAppRequestContext = { + overrides: ChatAppRequestOverrides; +}; + +export interface ChatAppRequestOptions extends AIChatCompletionOperationOptions { + context: ChatAppRequestContext +} + export type Thoughts = { title: string; description: any; // It can be any output from the api diff --git a/src/frontend/src/pages/chat/Chat.tsx b/src/frontend/src/pages/chat/Chat.tsx index 2f7569dc..da0b6934 100644 --- a/src/frontend/src/pages/chat/Chat.tsx +++ b/src/frontend/src/pages/chat/Chat.tsx @@ -4,7 +4,7 @@ import { SparkleFilled } from "@fluentui/react-icons"; import styles from "./Chat.module.css"; -import { RetrievalMode, RAGChatCompletion, RAGChatCompletionDelta } from "../../api"; +import { RetrievalMode, RAGChatCompletion, RAGChatCompletionDelta, ChatAppRequestOptions } from "../../api"; import { AIChatProtocolClient, AIChatMessage } from "@microsoft/ai-chat-protocol"; import { Answer, AnswerError, AnswerLoading } from "../../components/Answer"; import { QuestionInput } from "../../components/QuestionInput"; @@ -99,7 +99,7 @@ const Chat = () => { { content: answer[1].message.content, role: "assistant" } ]); const allMessages: AIChatMessage[] = [...messages, { content: question, role: "user" }]; - const options = { + const options: ChatAppRequestOptions = { context: { overrides: { use_advanced_flow: useAdvancedFlow, From 3d726dde422f9aa5322ac8f72010bb6c466ecea4 Mon Sep 17 00:00:00 2001 From: Pamela Fox Date: Tue, 23 Jul 2024 20:09:24 +0000 Subject: [PATCH 13/16] Use jsonlines for gitattributes to work --- tests/test_api_routes.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/test_api_routes.py b/tests/test_api_routes.py index 1d48ae05..cd221fbb 100644 --- a/tests/test_api_routes.py +++ b/tests/test_api_routes.py @@ -140,7 +140,7 @@ async def test_simple_chat_streaming_flow(test_client, snapshot): response_data = response.content assert response.status_code == 200 assert response.headers["Content-Type"] == "application/x-ndjson" - snapshot.assert_match(response_data, "simple_chat_streaming_flow_response.txt") + snapshot.assert_match(response_data, "simple_chat_streaming_flow_response.jsonlines") @pytest.mark.asyncio @@ -163,7 +163,7 @@ async def test_advanced_chat_flow(test_client, snapshot): @pytest.mark.asyncio -async def test_advanved_chat_streaming_flow(test_client, snapshot): +async def test_advanced_chat_streaming_flow(test_client, snapshot): """test the advanced chat streaming flow route with hybrid retrieval mode""" response = test_client.post( "/chat/stream", @@ -177,7 +177,7 @@ async def test_advanved_chat_streaming_flow(test_client, snapshot): response_data = response.content assert response.status_code == 200 assert response.headers["Content-Type"] == "application/x-ndjson" - snapshot.assert_match(response_data, "advanced_chat_streaming_flow_response.txt") + snapshot.assert_match(response_data, "advanced_chat_streaming_flow_response.jsonlines") @pytest.mark.asyncio From ebdb48bd742ed10804bb6a5ee1a6d6953c65e814 Mon Sep 17 00:00:00 2001 From: Pamela Fox Date: Tue, 23 Jul 2024 20:58:44 +0000 Subject: [PATCH 14/16] Update test filenames and folder name --- .../advanced_chat_streaming_flow_response.jsonlines} | 0 ...response.txt => simple_chat_streaming_flow_response.jsonlines} | 0 2 files changed, 0 insertions(+), 0 deletions(-) rename tests/snapshots/test_api_routes/{test_advanved_chat_streaming_flow/advanced_chat_streaming_flow_response.txt => test_advanced_chat_streaming_flow/advanced_chat_streaming_flow_response.jsonlines} (100%) rename tests/snapshots/test_api_routes/test_simple_chat_streaming_flow/{simple_chat_streaming_flow_response.txt => simple_chat_streaming_flow_response.jsonlines} (100%) diff --git a/tests/snapshots/test_api_routes/test_advanved_chat_streaming_flow/advanced_chat_streaming_flow_response.txt b/tests/snapshots/test_api_routes/test_advanced_chat_streaming_flow/advanced_chat_streaming_flow_response.jsonlines similarity index 100% rename from tests/snapshots/test_api_routes/test_advanved_chat_streaming_flow/advanced_chat_streaming_flow_response.txt rename to tests/snapshots/test_api_routes/test_advanced_chat_streaming_flow/advanced_chat_streaming_flow_response.jsonlines diff --git a/tests/snapshots/test_api_routes/test_simple_chat_streaming_flow/simple_chat_streaming_flow_response.txt b/tests/snapshots/test_api_routes/test_simple_chat_streaming_flow/simple_chat_streaming_flow_response.jsonlines similarity index 100% rename from tests/snapshots/test_api_routes/test_simple_chat_streaming_flow/simple_chat_streaming_flow_response.txt rename to tests/snapshots/test_api_routes/test_simple_chat_streaming_flow/simple_chat_streaming_flow_response.jsonlines From 0227789ab3d715a9a0ddb0645d24711a8693fa49 Mon Sep 17 00:00:00 2001 From: Pamela Fox Date: Tue, 23 Jul 2024 21:23:04 +0000 Subject: [PATCH 15/16] Refactor to avoid error when streaming --- src/backend/fastapi_app/__init__.py | 2 +- src/backend/fastapi_app/api_models.py | 9 ++ src/backend/fastapi_app/rag_advanced.py | 142 +++++++------------ src/backend/fastapi_app/rag_base.py | 42 +++--- src/backend/fastapi_app/rag_simple.py | 84 +++++------ src/backend/fastapi_app/routes/api_routes.py | 44 +++--- 6 files changed, 141 insertions(+), 182 deletions(-) diff --git a/src/backend/fastapi_app/__init__.py b/src/backend/fastapi_app/__init__.py index 55a60334..077930ba 100644 --- a/src/backend/fastapi_app/__init__.py +++ b/src/backend/fastapi_app/__init__.py @@ -51,7 +51,7 @@ def create_app(testing: bool = False): else: if not testing: load_dotenv(override=True) - logging.basicConfig(level=logging.INFO) + logging.basicConfig(level=logging.DEBUG) # Turn off particularly noisy INFO level logs from Azure Core SDK: logging.getLogger("azure.core.pipeline.policies.http_logging_policy").setLevel(logging.WARNING) diff --git a/src/backend/fastapi_app/api_models.py b/src/backend/fastapi_app/api_models.py index 6b5204e6..c98ca76d 100644 --- a/src/backend/fastapi_app/api_models.py +++ b/src/backend/fastapi_app/api_models.py @@ -74,3 +74,12 @@ class ItemPublic(BaseModel): class ItemWithDistance(ItemPublic): distance: float + + +class ChatParams(ChatRequestOverrides): + prompt_template: str + response_token_limit: int = 1024 + enable_text_search: bool + enable_vector_search: bool + original_user_query: str + past_messages: list[ChatCompletionMessageParam] diff --git a/src/backend/fastapi_app/rag_advanced.py b/src/backend/fastapi_app/rag_advanced.py index 1f9d8fc5..ddbd65ce 100644 --- a/src/backend/fastapi_app/rag_advanced.py +++ b/src/backend/fastapi_app/rag_advanced.py @@ -7,7 +7,6 @@ from fastapi_app.api_models import ( AIChatRoles, - ChatRequestOverrides, Message, RAGContext, RetrievalResponse, @@ -63,10 +62,15 @@ async def generate_search_query( return query_messages, query_text, filters - async def retrieve_and_build_context( - self, chat_params: ChatParams, query_text: str | Any | None, filters: list - ) -> tuple[list[ChatCompletionMessageParam], list[Item]]: - """Retrieve relevant items from the database and build a context for the chat model.""" + async def prepare_context( + self, chat_params: ChatParams + ) -> tuple[list[ChatCompletionMessageParam], list[Item], list[ThoughtStep]]: + query_messages, query_text, filters = await self.generate_search_query( + original_user_query=chat_params.original_user_query, + past_messages=chat_params.past_messages, + query_response_token_limit=500, + ) + # Retrieve relevant items from the database with the GPT optimized query results = await self.searcher.search_and_embed( query_text, @@ -88,28 +92,41 @@ async def retrieve_and_build_context( max_tokens=self.chat_token_limit - chat_params.response_token_limit, fallback_to_default=True, ) - return contextual_messages, results - async def run( + thoughts = [ + ThoughtStep( + title="Prompt to generate search arguments", + description=[str(message) for message in query_messages], + props=( + {"model": self.chat_model, "deployment": self.chat_deployment} + if self.chat_deployment + else {"model": self.chat_model} + ), + ), + ThoughtStep( + title="Search using generated search arguments", + description=query_text, + props={ + "top": chat_params.top, + "vector_search": chat_params.enable_vector_search, + "text_search": chat_params.enable_text_search, + "filters": filters, + }, + ), + ThoughtStep( + title="Search results", + description=[result.to_dict() for result in results], + ), + ] + return contextual_messages, results, thoughts + + async def answer( self, - messages: list[ChatCompletionMessageParam], - overrides: ChatRequestOverrides, + chat_params: ChatParams, + contextual_messages: list[ChatCompletionMessageParam], + results: list[Item], + earlier_thoughts: list[ThoughtStep], ) -> RetrievalResponse: - chat_params = self.get_params(messages, overrides) - - # Generate an optimized keyword search query based on the chat history and the last question - query_messages, query_text, filters = await self.generate_search_query( - original_user_query=chat_params.original_user_query, - past_messages=chat_params.past_messages, - query_response_token_limit=500, - ) - - # Retrieve relevant items from the database with the GPT optimized query - # Generate a contextual and content specific answer using the search results and chat history - contextual_messages, results = await self.retrieve_and_build_context( - chat_params=chat_params, query_text=query_text, filters=filters - ) - chat_completion_response: ChatCompletion = await self.openai_chat_client.chat.completions.create( # Azure OpenAI takes the deployment name as the model name model=self.chat_deployment if self.chat_deployment else self.chat_model, @@ -126,30 +143,8 @@ async def run( ), context=RAGContext( data_points={item.id: item.to_dict() for item in results}, - thoughts=[ - ThoughtStep( - title="Prompt to generate search arguments", - description=[str(message) for message in query_messages], - props=( - {"model": self.chat_model, "deployment": self.chat_deployment} - if self.chat_deployment - else {"model": self.chat_model} - ), - ), - ThoughtStep( - title="Search using generated search arguments", - description=query_text, - props={ - "top": chat_params.top, - "vector_search": chat_params.enable_vector_search, - "text_search": chat_params.enable_text_search, - "filters": filters, - }, - ), - ThoughtStep( - title="Search results", - description=[result.to_dict() for result in results], - ), + thoughts=earlier_thoughts + + [ ThoughtStep( title="Prompt to generate answer", description=[str(message) for message in contextual_messages], @@ -163,23 +158,13 @@ async def run( ), ) - async def run_stream( + async def answer_stream( self, - messages: list[ChatCompletionMessageParam], - overrides: ChatRequestOverrides, + chat_params: ChatParams, + contextual_messages: list[ChatCompletionMessageParam], + results: list[Item], + earlier_thoughts: list[ThoughtStep], ) -> AsyncGenerator[RetrievalResponseDelta, None]: - chat_params = self.get_params(messages, overrides) - - query_messages, query_text, filters = await self.generate_search_query( - original_user_query=chat_params.original_user_query, - past_messages=chat_params.past_messages, - query_response_token_limit=500, - ) - - contextual_messages, results = await self.retrieve_and_build_context( - chat_params=chat_params, query_text=query_text, filters=filters - ) - chat_completion_async_stream: AsyncStream[ ChatCompletionChunk ] = await self.openai_chat_client.chat.completions.create( @@ -192,38 +177,11 @@ async def run_stream( stream=True, ) - # Forcefully close the database session before yielding the response - # Yielding keeps the connection open while streaming the response until the end - # The connection closes when it returns back to the context manger in the dependencies - await self.searcher.db_session.close() - yield RetrievalResponseDelta( context=RAGContext( data_points={item.id: item.to_dict() for item in results}, - thoughts=[ - ThoughtStep( - title="Prompt to generate search arguments", - description=[str(message) for message in query_messages], - props=( - {"model": self.chat_model, "deployment": self.chat_deployment} - if self.chat_deployment - else {"model": self.chat_model} - ), - ), - ThoughtStep( - title="Search using generated search arguments", - description=query_text, - props={ - "top": chat_params.top, - "vector_search": chat_params.enable_vector_search, - "text_search": chat_params.enable_text_search, - "filters": filters, - }, - ), - ThoughtStep( - title="Search results", - description=[result.to_dict() for result in results], - ), + thoughts=earlier_thoughts + + [ ThoughtStep( title="Prompt to generate answer", description=[str(message) for message in contextual_messages], diff --git a/src/backend/fastapi_app/rag_base.py b/src/backend/fastapi_app/rag_base.py index 57247a0f..f7f7bff4 100644 --- a/src/backend/fastapi_app/rag_base.py +++ b/src/backend/fastapi_app/rag_base.py @@ -4,19 +4,16 @@ from openai.types.chat import ChatCompletionMessageParam -from fastapi_app.api_models import ChatRequestOverrides, RetrievalResponse, RetrievalResponseDelta +from fastapi_app.api_models import ( + ChatParams, + ChatRequestOverrides, + RetrievalResponse, + RetrievalResponseDelta, + ThoughtStep, +) from fastapi_app.postgres_models import Item -class ChatParams(ChatRequestOverrides): - prompt_template: str - response_token_limit: int = 1024 - enable_text_search: bool - enable_vector_search: bool - original_user_query: str - past_messages: list[ChatCompletionMessageParam] - - class RAGChatBase(ABC): current_dir = pathlib.Path(__file__).parent query_prompt_template = open(current_dir / "prompts/query.txt").read() @@ -48,27 +45,28 @@ def get_params(self, messages: list[ChatCompletionMessageParam], overrides: Chat ) @abstractmethod - async def retrieve_and_build_context( - self, - chat_params: ChatParams, - *args, - **kwargs, - ) -> tuple[list[ChatCompletionMessageParam], list[Item]]: + async def prepare_context( + self, chat_params: ChatParams + ) -> tuple[list[ChatCompletionMessageParam], list[Item], list[ThoughtStep]]: raise NotImplementedError @abstractmethod - async def run( + async def answer( self, - messages: list[ChatCompletionMessageParam], - overrides: ChatRequestOverrides, + chat_params: ChatParams, + contextual_messages: list[ChatCompletionMessageParam], + results: list[Item], + earlier_thoughts: list[ThoughtStep], ) -> RetrievalResponse: raise NotImplementedError @abstractmethod - async def run_stream( + async def answer_stream( self, - messages: list[ChatCompletionMessageParam], - overrides: ChatRequestOverrides, + chat_params: ChatParams, + contextual_messages: list[ChatCompletionMessageParam], + results: list[Item], + earlier_thoughts: list[ThoughtStep], ) -> AsyncGenerator[RetrievalResponseDelta, None]: raise NotImplementedError if False: diff --git a/src/backend/fastapi_app/rag_simple.py b/src/backend/fastapi_app/rag_simple.py index f3a5754a..2e6d859e 100644 --- a/src/backend/fastapi_app/rag_simple.py +++ b/src/backend/fastapi_app/rag_simple.py @@ -6,7 +6,6 @@ from fastapi_app.api_models import ( AIChatRoles, - ChatRequestOverrides, Message, RAGContext, RetrievalResponse, @@ -33,9 +32,9 @@ def __init__( self.chat_deployment = chat_deployment self.chat_token_limit = get_token_limit(chat_model, default_to_minimum=True) - async def retrieve_and_build_context( + async def prepare_context( self, chat_params: ChatParams - ) -> tuple[list[ChatCompletionMessageParam], list[Item]]: + ) -> tuple[list[ChatCompletionMessageParam], list[Item], list[ThoughtStep]]: """Retrieve relevant items from the database and build a context for the chat model.""" # Retrieve relevant items from the database @@ -58,17 +57,31 @@ async def retrieve_and_build_context( max_tokens=self.chat_token_limit - chat_params.response_token_limit, fallback_to_default=True, ) - return contextual_messages, results - async def run( + thoughts = [ + ThoughtStep( + title="Search query for database", + description=chat_params.original_user_query, + props={ + "top": chat_params.top, + "vector_search": chat_params.enable_vector_search, + "text_search": chat_params.enable_text_search, + }, + ), + ThoughtStep( + title="Search results", + description=[result.to_dict() for result in results], + ), + ] + return contextual_messages, results, thoughts + + async def answer( self, - messages: list[ChatCompletionMessageParam], - overrides: ChatRequestOverrides, + chat_params: ChatParams, + contextual_messages: list[ChatCompletionMessageParam], + results: list[Item], + earlier_thoughts: list[ThoughtStep], ) -> RetrievalResponse: - chat_params = self.get_params(messages, overrides) - - contextual_messages, results = await self.retrieve_and_build_context(chat_params=chat_params) - chat_completion_response: ChatCompletion = await self.openai_chat_client.chat.completions.create( # Azure OpenAI takes the deployment name as the model name model=self.chat_deployment if self.chat_deployment else self.chat_model, @@ -85,20 +98,8 @@ async def run( ), context=RAGContext( data_points={item.id: item.to_dict() for item in results}, - thoughts=[ - ThoughtStep( - title="Search query for database", - description=chat_params.original_user_query if chat_params.enable_text_search else None, - props={ - "top": chat_params.top, - "vector_search": chat_params.enable_vector_search, - "text_search": chat_params.enable_text_search, - }, - ), - ThoughtStep( - title="Search results", - description=[result.to_dict() for result in results], - ), + thoughts=earlier_thoughts + + [ ThoughtStep( title="Prompt to generate answer", description=[str(message) for message in contextual_messages], @@ -112,15 +113,13 @@ async def run( ), ) - async def run_stream( + async def answer_stream( self, - messages: list[ChatCompletionMessageParam], - overrides: ChatRequestOverrides, + chat_params: ChatParams, + contextual_messages: list[ChatCompletionMessageParam], + results: list[Item], + earlier_thoughts: list[ThoughtStep], ) -> AsyncGenerator[RetrievalResponseDelta, None]: - chat_params = self.get_params(messages, overrides) - - contextual_messages, results = await self.retrieve_and_build_context(chat_params=chat_params) - chat_completion_async_stream: AsyncStream[ ChatCompletionChunk ] = await self.openai_chat_client.chat.completions.create( @@ -133,28 +132,11 @@ async def run_stream( stream=True, ) - # Forcefully close the database session before yielding the response - # Yielding keeps the connection open while streaming the response until the end - # The connection closes when it returns back to the context manger in the dependencies - await self.searcher.db_session.close() - yield RetrievalResponseDelta( context=RAGContext( data_points={item.id: item.to_dict() for item in results}, - thoughts=[ - ThoughtStep( - title="Search query for database", - description=chat_params.original_user_query if chat_params.enable_text_search else None, - props={ - "top": chat_params.top, - "vector_search": chat_params.enable_vector_search, - "text_search": chat_params.enable_text_search, - }, - ), - ThoughtStep( - title="Search results", - description=[result.to_dict() for result in results], - ), + thoughts=earlier_thoughts + + [ ThoughtStep( title="Prompt to generate answer", description=[str(message) for message in contextual_messages], diff --git a/src/backend/fastapi_app/routes/api_routes.py b/src/backend/fastapi_app/routes/api_routes.py index ac53497e..44e08320 100644 --- a/src/backend/fastapi_app/routes/api_routes.py +++ b/src/backend/fastapi_app/routes/api_routes.py @@ -93,8 +93,6 @@ async def chat_handler( openai_chat: ChatClient, chat_request: ChatRequest, ): - overrides = chat_request.context.overrides - searcher = PostgresSearcher( db_session=database_session, openai_embed_client=openai_embed.client, @@ -102,22 +100,28 @@ async def chat_handler( embed_model=context.openai_embed_model, embed_dimensions=context.openai_embed_dimensions, ) - if overrides.use_advanced_flow: - run_ragchat = AdvancedRAGChat( + rag_flow: SimpleRAGChat | AdvancedRAGChat + if chat_request.context.overrides.use_advanced_flow: + rag_flow = AdvancedRAGChat( searcher=searcher, openai_chat_client=openai_chat.client, chat_model=context.openai_chat_model, chat_deployment=context.openai_chat_deployment, - ).run + ) else: - run_ragchat = SimpleRAGChat( + rag_flow = SimpleRAGChat( searcher=searcher, openai_chat_client=openai_chat.client, chat_model=context.openai_chat_model, chat_deployment=context.openai_chat_deployment, - ).run + ) + + chat_params = rag_flow.get_params(chat_request.messages, chat_request.context.overrides) - response = await run_ragchat(chat_request.messages, overrides=overrides) + contextual_messages, results, thoughts = await rag_flow.prepare_context(chat_params) + response = await rag_flow.answer( + chat_params=chat_params, contextual_messages=contextual_messages, results=results, earlier_thoughts=thoughts + ) return response @@ -129,8 +133,6 @@ async def chat_stream_handler( openai_chat: ChatClient, chat_request: ChatRequest, ): - overrides = chat_request.context.overrides - searcher = PostgresSearcher( db_session=database_session, openai_embed_client=openai_embed.client, @@ -138,20 +140,30 @@ async def chat_stream_handler( embed_model=context.openai_embed_model, embed_dimensions=context.openai_embed_dimensions, ) - if overrides.use_advanced_flow: - run_ragchat = AdvancedRAGChat( + + rag_flow: SimpleRAGChat | AdvancedRAGChat + if chat_request.context.overrides.use_advanced_flow: + rag_flow = AdvancedRAGChat( searcher=searcher, openai_chat_client=openai_chat.client, chat_model=context.openai_chat_model, chat_deployment=context.openai_chat_deployment, - ).run_stream + ) else: - run_ragchat = SimpleRAGChat( + rag_flow = SimpleRAGChat( searcher=searcher, openai_chat_client=openai_chat.client, chat_model=context.openai_chat_model, chat_deployment=context.openai_chat_deployment, - ).run_stream + ) + + chat_params = rag_flow.get_params(chat_request.messages, chat_request.context.overrides) - result = run_ragchat(chat_request.messages, overrides=overrides) + # Intentionally do this before we stream down a response, to avoid using database connections during stream + # See https://github.com/tiangolo/fastapi/discussions/11321 + contextual_messages, results, thoughts = await rag_flow.prepare_context(chat_params) + + result = rag_flow.answer_stream( + chat_params=chat_params, contextual_messages=contextual_messages, results=results, earlier_thoughts=thoughts + ) return StreamingResponse(content=format_as_ndjson(result), media_type="application/x-ndjson") From 8b32800a1db67040a36425d70af3e11f67544c1b Mon Sep 17 00:00:00 2001 From: Pamela Fox Date: Tue, 23 Jul 2024 21:35:00 +0000 Subject: [PATCH 16/16] Change back to log info --- src/backend/fastapi_app/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/backend/fastapi_app/__init__.py b/src/backend/fastapi_app/__init__.py index 077930ba..55a60334 100644 --- a/src/backend/fastapi_app/__init__.py +++ b/src/backend/fastapi_app/__init__.py @@ -51,7 +51,7 @@ def create_app(testing: bool = False): else: if not testing: load_dotenv(override=True) - logging.basicConfig(level=logging.DEBUG) + logging.basicConfig(level=logging.INFO) # Turn off particularly noisy INFO level logs from Azure Core SDK: logging.getLogger("azure.core.pipeline.policies.http_logging_policy").setLevel(logging.WARNING)