diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 4a80f38..c9b3076 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,3 +1,4 @@ +exclude: ^vectorstore/ repos: - repo: https://github.com/pre-commit/pre-commit-hooks rev: v4.5.0 diff --git a/src/paper_query/base_chatbot.py b/src/paper_query/base_chatbot.py index b39fc94..2fa8122 100644 --- a/src/paper_query/base_chatbot.py +++ b/src/paper_query/base_chatbot.py @@ -15,7 +15,7 @@ def main(): parser.add_argument( "--model", type=str, - default="gpt-4o", + default="gpt-4.1", help="Model name to use for the chatbot", ) parser.add_argument( diff --git a/src/paper_query/code_query_chatbot.py b/src/paper_query/code_query_chatbot.py index 104691a..683df5f 100644 --- a/src/paper_query/code_query_chatbot.py +++ b/src/paper_query/code_query_chatbot.py @@ -18,7 +18,7 @@ def main(): parser.add_argument( "--model", type=str, - default="gpt-4o", + default="gpt-4.1", help="Model name to use for the chatbot", ) parser.add_argument( diff --git a/src/paper_query/constants/__init__.py b/src/paper_query/constants/__init__.py index 77efa44..02ab920 100644 --- a/src/paper_query/constants/__init__.py +++ b/src/paper_query/constants/__init__.py @@ -1,6 +1,6 @@ from ._api_keys import GROQ_API_KEY, HUGGINGFACE_API_KEY, OPENAI_API_KEY from ._paths import PERSIST_DIRECTORY, assets_dir, data_dir, project_dir, src_dir, test_dir -from ._strings import RAG_DOC_ID +from ._strings import RAG_DOC_ID, STREAMLIT_CHEAP_MODEL, STREAMLIT_EXPENSIVE_MODEL __all__ = [ "OPENAI_API_KEY", @@ -13,4 +13,6 @@ "data_dir", "assets_dir", "RAG_DOC_ID", + "STREAMLIT_CHEAP_MODEL", + "STREAMLIT_EXPENSIVE_MODEL", ] diff --git a/src/paper_query/constants/_strings.py b/src/paper_query/constants/_strings.py index 70b6935..908f71f 100644 --- a/src/paper_query/constants/_strings.py +++ b/src/paper_query/constants/_strings.py @@ -1 +1,4 @@ RAG_DOC_ID = "file" + +STREAMLIT_CHEAP_MODEL = "GPT-4.1-nano" +STREAMLIT_EXPENSIVE_MODEL = "GPT-4.1" 
diff --git a/src/paper_query/hybrid_query_chatbot.py b/src/paper_query/hybrid_query_chatbot.py index eacc113..fda4b66 100644 --- a/src/paper_query/hybrid_query_chatbot.py +++ b/src/paper_query/hybrid_query_chatbot.py @@ -18,7 +18,7 @@ def main(): parser.add_argument( "--model", type=str, - default="gpt-4o", + default="gpt-4.1", help="Model name to use for the chatbot", ) parser.add_argument( diff --git a/src/paper_query/paper_query_chatbot.py b/src/paper_query/paper_query_chatbot.py index 20de261..8fea67b 100644 --- a/src/paper_query/paper_query_chatbot.py +++ b/src/paper_query/paper_query_chatbot.py @@ -18,7 +18,7 @@ def main(): parser.add_argument( "--model", type=str, - default="gpt-4o", + default="gpt-4.1", help="Model name to use for the chatbot", ) parser.add_argument( diff --git a/src/paper_query/paper_query_plus_chatbot.py b/src/paper_query/paper_query_plus_chatbot.py index 8ff401e..3ef13cd 100644 --- a/src/paper_query/paper_query_plus_chatbot.py +++ b/src/paper_query/paper_query_plus_chatbot.py @@ -19,7 +19,7 @@ def main(): parser.add_argument( "--model", type=str, - default="gpt-4o", + default="gpt-4.1", help="Model name to use for the chatbot", ) parser.add_argument( diff --git a/src/paper_query/ui/components/chat_interface.py b/src/paper_query/ui/components/chat_interface.py index 3f241ee..aa8ba4d 100644 --- a/src/paper_query/ui/components/chat_interface.py +++ b/src/paper_query/ui/components/chat_interface.py @@ -6,24 +6,33 @@ def display_chat_interface() -> None: if "messages" not in st.session_state: st.session_state.messages = [] - for message in st.session_state.messages: - with st.chat_message(message["role"]): - st.markdown(message["content"]) + message_container = st.container() - if "chatbot_confirmed" in st.session_state and st.session_state.chatbot_confirmed: + # Display all past messages in the message container + with message_container: + for message in st.session_state.messages: + with st.chat_message(message["role"]): + 
st.markdown(message["content"]) + + # Create the input at the bottom + if "chatbot_ready" in st.session_state and st.session_state.chatbot_ready: if user_input := st.chat_input("What is your question?", key="user_input"): - st.chat_message("user").markdown(user_input) + # Add user message to UI + with message_container: + st.chat_message("user").markdown(user_input) st.session_state.messages.append({"role": "user", "content": user_input}) - with st.chat_message("assistant"): - message_placeholder = st.empty() - full_response = "" + # Add assistant response to UI + with message_container: + with st.chat_message("assistant"): + message_placeholder = st.empty() + full_response = "" - for response_chunk in st.session_state.chatbot.stream_response(user_input): - full_response += response_chunk - message_placeholder.markdown(full_response) + for response_chunk in st.session_state.chatbot.stream_response(user_input): + full_response += response_chunk + message_placeholder.markdown(full_response) - message_placeholder.markdown(full_response) + message_placeholder.markdown(full_response) st.session_state.messages.append({"role": "assistant", "content": full_response}) else: diff --git a/src/paper_query/ui/components/sidebar_inputs.py b/src/paper_query/ui/components/sidebar_inputs.py index 2f39400..d30d552 100644 --- a/src/paper_query/ui/components/sidebar_inputs.py +++ b/src/paper_query/ui/components/sidebar_inputs.py @@ -17,7 +17,7 @@ def get_class_params(cls) -> list[str]: ] -def model_name_input(name: str = "gpt-4o") -> str: +def model_name_input(name: str = "gpt-4.1") -> str: """Get the model name from the sidebar.""" return st.sidebar.text_input("Model Name", value=name, key="model_name_input") diff --git a/src/paper_query/ui/components/text.py b/src/paper_query/ui/components/text.py new file mode 100644 index 0000000..d5a9086 --- /dev/null +++ b/src/paper_query/ui/components/text.py @@ -0,0 +1,55 @@ +from paper_query.constants import STREAMLIT_EXPENSIVE_MODEL + +ABOUT 
= f""" +**StrainRelief is a tool for calculating ligand strain energy with quantum mechanical +accuracy**. + +##### What is ligand strain energy? +Ligand strain energy is the energy difference between the bound and unbound conformations +of a ligand. It's an important component in structure-based small molecule drug design. + +##### How does StrainRelief work? +StrainRelief uses a MACE Neural Network Potential (NNP) trained on a large database of +Density Functional Theory (DFT) calculations to estimate ligand strain of neutral molecules +with quantum accuracy. + +##### About this chatbot +This chatbot is built using a hybrid retrieval and cached augmented generation (RAG/CAG) +approach: + +1. The full StrainRelief [paper](https://arxiv.org/abs/2503.13352) is loaded and cached +in the context window for all queries +2. Reference papers cited in StrainRelief are embedded and available for retrieval +3. The StrainRelief code [repository](https://github.com/prescient-design/StrainRelief) +is embedded and available for retrieval + +The chatbot currently has a naive modular framework. When you ask a question, the +system: +- Retrieves relevant information from the references and code +- Combines this with the full paper context +- Uses the LLM to generate a response based on all available information + +The chatbot uses the following components: +- **LLM**: {STREAMLIT_EXPENSIVE_MODEL} from OpenAI for generating responses +- **Embedding**: OpenAI embeddings for vector search +- **Vector Database**: ChromaDB for storing and retrieving embedded documents + +Feel free to ask about the StrainRelief methodology, implementation details, or +how to use the tool for drug discovery applications. + """ + +ABSTRACT = """ +:gray[**Abstract**: Ligand strain energy, the energy difference between the +bound and unbound conformations of a ligand, is an important component of +structure-based small molecule drug design. 
A large majority of observed +ligands in protein-small molecule co-crystal structures bind in low-strain +conformations, making strain energy a useful filter for structure-based drug +design. In this work we present a tool for calculating ligand strain with a +high accuracy. StrainRelief uses a MACE Neural Network Potential (NNP), +trained on a large database of Density Functional Theory (DFT) calculations +to estimate ligand strain of neutral molecules with quantum accuracy. We show +that this tool estimates strain energy differences relative to DFT to within +1.4 kcal/mol, more accurately than alternative NNPs. These results highlight +the utility of NNPs in drug discovery, and provide a useful tool for drug +discovery teams.] +""" diff --git a/src/paper_query/ui/components/validate_key.py b/src/paper_query/ui/components/validate_key.py new file mode 100644 index 0000000..47d2658 --- /dev/null +++ b/src/paper_query/ui/components/validate_key.py @@ -0,0 +1,25 @@ +import streamlit as st +from loguru import logger +from openai import OpenAI + +from paper_query.constants import STREAMLIT_CHEAP_MODEL, STREAMLIT_EXPENSIVE_MODEL + + +def validate_openai_api_key(api_key: str): + """Validates the OpenAI API key and updates the session state accordingly.""" + if api_key and api_key != st.session_state.last_validated_key: + try: + client = OpenAI(api_key=api_key) + client.models.list() + logger.debug("API key validation successful.") + st.session_state.model_name = STREAMLIT_EXPENSIVE_MODEL + st.session_state.last_validated_key = api_key + except Exception as e: + logger.error(f"API key validation failed: {e}") + st.sidebar.error("Invalid API key. 
Please check your OpenAI API key.") + st.session_state.model_name = STREAMLIT_CHEAP_MODEL + st.session_state.last_validated_key = None # Reset if validation fails + elif not api_key: + # Reset to cheap model if key is cleared + st.session_state.model_name = STREAMLIT_CHEAP_MODEL + st.session_state.last_validated_key = None diff --git a/src/paper_query/ui/custom_app.py b/src/paper_query/ui/custom_app.py index 6967f07..4f61fc8 100644 --- a/src/paper_query/ui/custom_app.py +++ b/src/paper_query/ui/custom_app.py @@ -20,7 +20,7 @@ def streamlit_chatbot(): chatbot_args = get_chatbot_params(selected_chatbot_class) if st.sidebar.button("Confirm Chatbot", key="confirm_chatbot_button"): - st.session_state.chatbot_confirmed = True + st.session_state.chatbot_ready = True st.session_state.chatbot = selected_chatbot_class(**chatbot_args) st.sidebar.success(f"{selected_label} is ready!") st.title(f"{selected_label} Chatbot") diff --git a/src/paper_query/ui/strain_relief_app.py b/src/paper_query/ui/strain_relief_app.py index c6bde2c..c4eead8 100644 --- a/src/paper_query/ui/strain_relief_app.py +++ b/src/paper_query/ui/strain_relief_app.py @@ -4,59 +4,71 @@ from loguru import logger from paper_query.chatbots import HybridQueryChatbot -from paper_query.constants import assets_dir +from paper_query.constants import STREAMLIT_CHEAP_MODEL, STREAMLIT_EXPENSIVE_MODEL, assets_dir from paper_query.ui.components.chat_interface import display_chat_interface +from paper_query.ui.components.text import ABOUT, ABSTRACT +from paper_query.ui.components.validate_key import validate_openai_api_key # Configure logger to use DEBUG level logger.remove() logger.add(sys.stderr, level="DEBUG") +def initialize_session_state(): + """Initialize session state variables.""" + if "chatbot_ready" not in st.session_state: + st.session_state.chatbot_ready = True + + if "chatbot" not in st.session_state: + st.session_state.chatbot = None + + if "model_name" not in st.session_state: + st.session_state.model_name = 
STREAMLIT_CHEAP_MODEL + + def strain_relief_chatbot(): """Chatbot for the StrainRelief paper.""" - st.session_state.chatbot_confirmed = True - if "chatbot" not in st.session_state: - st.session_state.chatbot = HybridQueryChatbot( - model_name="gpt-4o", - model_provider="openai", - paper_path=str(assets_dir / "strainrelief_preprint.pdf"), - references_dir=str(assets_dir / "references"), - ) + initialize_session_state() st.title("The StrainRelief Chatbot") + chat_tab, about_tab = st.tabs(["Chat", "About"]) - st.markdown( - "This retrieval augmented generation (RAG) chatbot is designed to answer questions about " - "the StrainRelief. The chatbot has access to the [paper](https://arxiv.org/abs/2503.13352)," - " all references, and the code " - "[repository](https://github.com/prescient-design/StrainRelief)." + st.sidebar.title("API Configuration") + # Enter API key in sidebar + openai_api_key = st.sidebar.text_input( + "OpenAI API Key", + type="password", + help="If you don't have an API key, you can get one from [OpenAI](https://platform.openai.com/api-keys).", + key="api_input", ) - if "messages" not in st.session_state: - st.markdown( - ":gray[**Abstract**: Ligand strain energy, the energy difference between the bound and " - "unbound conformations of a ligand, is an important component of structure-based small " - "molecule drug design. A large majority of observed ligands in protein-small molecule " - "co-crystal structures bind in low-strain conformations, making strain energy a useful " - "filter for structure-based drug design. In this work we present a tool for " - "calculating ligand strain with a high accuracy. StrainRelief uses a MACE Neural " - "Network Potential (NNP), trained on a large database of Density Functional Theory " - "(DFT) calculations to estimate ligand strain of neutral molecules with quantum " - "accuracy. 
We show that this tool estimates strain energy differences relative to DFT " - "to within 1.4 kcal/mol, more accurately than alternative NNPs. These results " - "highlight the utility of NNPs in drug discovery, and provide a useful tool for drug " - "discovery teams.]" - ) - - display_chat_interface() + validate_openai_api_key(openai_api_key) + # Display current model + st.sidebar.markdown(f"Using **{st.session_state.model_name}** model.") -if __name__ == "__main__": - if sys.platform != "linux": # Skip for GitHub actions - # Get API keys from Streamlit secrets - from paper_query import constants + st.session_state.chatbot = HybridQueryChatbot( + model_name=st.session_state.model_name.lower(), + model_provider="openai", + paper_path=str(assets_dir / "strainrelief_preprint.pdf"), + references_dir=str(assets_dir / "references"), + ) + + with chat_tab: + if "messages" not in st.session_state: + st.markdown(ABSTRACT) + + # Show info message only when using nano model + if st.session_state.model_name == STREAMLIT_CHEAP_MODEL: + st.info( + f"You are currently using {STREAMLIT_CHEAP_MODEL}. Add a valid OpenAI API key " + f"to access the more powerful {STREAMLIT_EXPENSIVE_MODEL} model." 
+ ) - constants.OPENAI_API_KEY = st.secrets["OPENAI_API_KEY"] - constants.GROQ_API_KEY = st.secrets["GROQ_API_KEY"] - constants.HUGGINGFACE_API_KEY = st.secrets["HUGGINGFACE_API_KEY"] + display_chat_interface() + with about_tab: + st.markdown(ABOUT) + + +if __name__ == "__main__": strain_relief_chatbot() diff --git a/test/data/test_loaders.py b/test/data/test_loaders.py index 572be37..32bbef6 100644 --- a/test/data/test_loaders.py +++ b/test/data/test_loaders.py @@ -22,7 +22,7 @@ def test_pypdf_loader_w_images(test_assets_dir): """Test the pypdf_loader_w_images function.""" path = test_assets_dir / "example_pdf.pdf" # TODO: change to free model - doc = pypdf_loader_w_images(path, "gpt-4o-mini", "openai") + doc = pypdf_loader_w_images(path, "gpt-4.1-nano", "openai") assert isinstance(doc, Document) diff --git a/test/ui/components/test_sidebar_inputs.py b/test/ui/components/test_sidebar_inputs.py index 54bb26d..73be478 100644 --- a/test/ui/components/test_sidebar_inputs.py +++ b/test/ui/components/test_sidebar_inputs.py @@ -13,8 +13,8 @@ def test_model_name_input(): - assert model_name_input() == "gpt-4o" - assert model_name_input("gpt-4o-mini") == "gpt-4o-mini" + assert model_name_input() == "gpt-4.1" + assert model_name_input("gpt-4.1-nano") == "gpt-4.1-nano" def test_model_provider_input(): @@ -38,7 +38,7 @@ def test_code_dir_input(): def test_get_param(): - assert get_param("model_name") == "gpt-4o" + assert get_param("model_name") == "gpt-4.1" assert get_param("model_provider") == "openai" assert get_param("paper_path") == str(assets_dir / "strainrelief_preprint.pdf") assert get_param("references_dir") == str(assets_dir / "references") @@ -48,4 +48,4 @@ def test_get_param(): def test_get_chatbot_params(): - assert get_chatbot_params(BaseChatbot) == {"model_name": "gpt-4o", "model_provider": "openai"} + assert get_chatbot_params(BaseChatbot) == {"model_name": "gpt-4.1", "model_provider": "openai"} diff --git a/test/ui/components/test_validate_key.py 
b/test/ui/components/test_validate_key.py new file mode 100644 index 0000000..b5c323c --- /dev/null +++ b/test/ui/components/test_validate_key.py @@ -0,0 +1,22 @@ +import pytest +import streamlit as st +from paper_query.constants import OPENAI_API_KEY, STREAMLIT_CHEAP_MODEL, STREAMLIT_EXPENSIVE_MODEL +from paper_query.ui.components.validate_key import validate_openai_api_key + + +@pytest.mark.app +@pytest.mark.parametrize( + "api_key, last_key, model_name", + [ + (OPENAI_API_KEY, OPENAI_API_KEY, STREAMLIT_EXPENSIVE_MODEL), + (None, None, STREAMLIT_CHEAP_MODEL), + ("invalid_key", None, STREAMLIT_CHEAP_MODEL), + ], +) +def test_validate_openai_api_key_correct(api_key, last_key, model_name): + """Test the OpenAI API key validation.""" + st.session_state.last_validated_key = True + + validate_openai_api_key(api_key) + assert st.session_state.last_validated_key == last_key + assert st.session_state.model_name == model_name diff --git a/test/ui/test_custom_app.py b/test/ui/test_custom_app.py index 40871aa..b6b98ef 100644 --- a/test/ui/test_custom_app.py +++ b/test/ui/test_custom_app.py @@ -31,7 +31,7 @@ def test_confirm_chatbot(app): @pytest.mark.app def test_model_selection(app): """Test model selection text input.""" - assert app.sidebar.text_input("model_name_input").value == "gpt-4o" + assert app.sidebar.text_input("model_name_input").value == "gpt-4.1" app.sidebar.text_input("model_name_input").set_value(MODEL_NAME) assert app.sidebar.text_input("model_name_input").value == MODEL_NAME diff --git a/tutorials/langchain_chatbot_v3.ipynb b/tutorials/langchain_chatbot_v3.ipynb deleted file mode 100644 index fe7e97a..0000000 --- a/tutorials/langchain_chatbot_v3.ipynb +++ /dev/null @@ -1,412 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Build a Chatbot\n", - "\n", - "We'll go over an example of how to design and implement an LLM-powered chatbot. 
This chatbot will be able to have a conversation and remember previous interactions with a chat model.\n", - "\n", - "from [this](https://python.langchain.com/docs/tutorials/chatbot/) LangChain **v0.3** tutorial." - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "0.3.21\n", - "0.3.47\n" - ] - } - ], - "source": [ - "import langgraph\n", - "import langchain\n", - "import langchain_core\n", - "from langchain_openai import ChatOpenAI\n", - "from langchain_core.messages import HumanMessage, AIMessage\n", - "\n", - "import os\n", - "\n", - "from paper_query.constants.api_keys import OPENAI_API_KEY\n", - "os.environ[\"OPENAI_API_KEY\"] = OPENAI_API_KEY\n", - "\n", - "print(langchain.__version__)\n", - "print(langchain_core.__version__)" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [], - "source": [ - "from langchain.chat_models import init_chat_model\n", - "\n", - "model = init_chat_model(\"gpt-4o-mini\", model_provider=\"openai\")" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "AIMessage(content='Hi Bob! 
How can I assist you today?', additional_kwargs={'refusal': None}, response_metadata={'token_usage': {'completion_tokens': 11, 'prompt_tokens': 11, 'total_tokens': 22, 'completion_tokens_details': {'accepted_prediction_tokens': 0, 'audio_tokens': 0, 'reasoning_tokens': 0, 'rejected_prediction_tokens': 0}, 'prompt_tokens_details': {'audio_tokens': 0, 'cached_tokens': 0}}, 'model_name': 'gpt-4o-mini-2024-07-18', 'system_fingerprint': 'fp_e4fa3702df', 'id': 'chatcmpl-BEEbdfOpQvGtB0LTEcFpTG7btW6z5', 'finish_reason': 'stop', 'logprobs': None}, id='run-dc4f33e1-62f8-4962-8f20-039be16a97ee-0', usage_metadata={'input_tokens': 11, 'output_tokens': 11, 'total_tokens': 22, 'input_token_details': {'audio': 0, 'cache_read': 0}, 'output_token_details': {'audio': 0, 'reasoning': 0}})" - ] - }, - "execution_count": 4, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "model.invoke([HumanMessage(content=\"Hi! I'm Bob\")])\n", - "# invoking the model is easy, but it currently has no concept of state i.e. it has no memory and cannot answer follow-up questions\n", - "# HumanMessage: HumanMessages are messages that are passed in from a human to the model." - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "AIMessage(content=\"I'm sorry, but I don't know your name. 
If you'd like to share it, feel free!\", additional_kwargs={'refusal': None}, response_metadata={'token_usage': {'completion_tokens': 21, 'prompt_tokens': 11, 'total_tokens': 32, 'completion_tokens_details': {'accepted_prediction_tokens': 0, 'audio_tokens': 0, 'reasoning_tokens': 0, 'rejected_prediction_tokens': 0}, 'prompt_tokens_details': {'audio_tokens': 0, 'cached_tokens': 0}}, 'model_name': 'gpt-4o-mini-2024-07-18', 'system_fingerprint': 'fp_b8bc95a0ac', 'id': 'chatcmpl-BEEbjrDKCnvmNzwWfVHjJyLdv4gCh', 'finish_reason': 'stop', 'logprobs': None}, id='run-43de1432-dc17-451d-ab44-7b37dc0f90ec-0', usage_metadata={'input_tokens': 11, 'output_tokens': 21, 'total_tokens': 32, 'input_token_details': {'audio': 0, 'cache_read': 0}, 'output_token_details': {'audio': 0, 'reasoning': 0}})" - ] - }, - "execution_count": 5, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "model.invoke([HumanMessage(content=\"What's my name?\")])\n", - "# no concept of state / memory. It is STATELESS" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "AIMessage(content='Your name is Bob! 
How can I assist you today, Bob?', additional_kwargs={'refusal': None}, response_metadata={'token_usage': {'completion_tokens': 15, 'prompt_tokens': 33, 'total_tokens': 48, 'completion_tokens_details': {'accepted_prediction_tokens': 0, 'audio_tokens': 0, 'reasoning_tokens': 0, 'rejected_prediction_tokens': 0}, 'prompt_tokens_details': {'audio_tokens': 0, 'cached_tokens': 0}}, 'model_name': 'gpt-4o-mini-2024-07-18', 'system_fingerprint': 'fp_b8bc95a0ac', 'id': 'chatcmpl-BEEbrnT4Qla9qzMAWsfxMgg17Cz0E', 'finish_reason': 'stop', 'logprobs': None}, id='run-6de0ffeb-0c42-4a1f-8c8d-e823988699f3-0', usage_metadata={'input_tokens': 33, 'output_tokens': 15, 'total_tokens': 48, 'input_token_details': {'audio': 0, 'cache_read': 0}, 'output_token_details': {'audio': 0, 'reasoning': 0}})" - ] - }, - "execution_count": 6, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "# you can add multiple messages at once, which the model uses as \"memory\"\n", - "model.invoke(\n", - " [\n", - " HumanMessage(content=\"Hi! I'm Bob\"),\n", - " AIMessage(content=\"Hello Bob! How can I assist you today?\"),\n", - " HumanMessage(content=\"What's my name?\"),\n", - " ]\n", - ")\n", - "# For a chatbot we want the model to remeber previous messages and answers automatically" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": {}, - "outputs": [], - "source": [ - "from langgraph.checkpoint.memory import MemorySaver\n", - "from langgraph.graph import START, MessagesState, StateGraph\n", - "# LangGraph has a built-in persistence layer i.e. 
it is STATEFUL.\n", - "# Wrapping out model in a minimal LangGraph app allows message history.\n", - "\n", - "# Define the function that calls the model\n", - "def call_model(state: MessagesState):\n", - " response = model.invoke(state[\"messages\"])\n", - " return {\"messages\": response}\n", - "\n", - "# Define a new graph and the (single) node in the graph\n", - "workflow = StateGraph(state_schema=MessagesState)\n", - "workflow.add_edge(START, \"model\")\n", - "workflow.add_node(\"model\", call_model)\n", - "\n", - "# Add memory\n", - "memory = MemorySaver()\n", - "app = workflow.compile(checkpointer=memory)" - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "==================================\u001b[1m Ai Message \u001b[0m==================================\n", - "\n", - "Hello again, Bob! What’s on your mind today?\n", - "==================================\u001b[1m Ai Message \u001b[0m==================================\n", - "\n", - "Your name is Bob! What would you like to discuss?\n" - ] - } - ], - "source": [ - "# Different states are identified using a \"thread_id\". Different states are not aware of one another.\n", - "config = {\"configurable\": {\"thread_id\": \"abc123\"}}\n", - "\n", - "query = \"Hi! 
I'm Bob.\"\n", - "\n", - "input_messages = [HumanMessage(query)]\n", - "output = app.invoke({\"messages\": input_messages}, config)\n", - "output[\"messages\"][-1].pretty_print()\n", - "\n", - "query = \"What's my name?\"\n", - "\n", - "input_messages = [HumanMessage(query)]\n", - "output = app.invoke({\"messages\": input_messages}, config)\n", - "output[\"messages\"][-1].pretty_print()\n", - "# \"app\" is STATEFUL and remebers previous messages" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "==================================\u001b[1m Ai Message \u001b[0m==================================\n", - "\n", - "I don't have access to personal information or previous interactions, so I don't know your name. If you'd like to share it or if there's something else you'd like to discuss, feel free to let me know!\n" - ] - } - ], - "source": [ - "config = {\"configurable\": {\"thread_id\": \"abc234\"}}\n", - "# Change the thread_id to start a new conversation with a new history.\n", - "query = \"What's my name?\"\n", - "\n", - "input_messages = [HumanMessage(query)]\n", - "output = app.invoke({\"messages\": input_messages}, config)\n", - "output[\"messages\"][-1].pretty_print()" - ] - }, - { - "cell_type": "code", - "execution_count": 15, - "metadata": {}, - "outputs": [], - "source": [ - "from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder\n", - "\n", - "prompt_template = ChatPromptTemplate.from_messages(\n", - " [\n", - " (\n", - " \"system\",\n", - " \"You talk like a {job}. 
Answer all questions to the best of your ability.\",\n", - " ),\n", - " MessagesPlaceholder(variable_name=\"messages\"),\n", - " ]\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": 19, - "metadata": {}, - "outputs": [], - "source": [ - "from typing import Sequence\n", - "\n", - "from langchain_core.messages import BaseMessage\n", - "from langgraph.graph.message import add_messages\n", - "from typing_extensions import Annotated, TypedDict\n", - "\n", - "\n", - "class State(TypedDict):\n", - " messages: Annotated[Sequence[BaseMessage], add_messages]\n", - " job: str\n", - "\n", - "\n", - "workflow = StateGraph(state_schema=State)\n", - "\n", - "\n", - "def call_model(state: State):\n", - " prompt = prompt_template.invoke(state)\n", - " response = model.invoke(prompt)\n", - " return {\"messages\": [response]}\n", - "\n", - "\n", - "workflow.add_edge(START, \"model\")\n", - "workflow.add_node(\"model\", call_model)\n", - "\n", - "memory = MemorySaver()\n", - "app = workflow.compile(checkpointer=memory)" - ] - }, - { - "cell_type": "code", - "execution_count": 21, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "==================================\u001b[1m Ai Message \u001b[0m==================================\n", - "\n", - "Ahoy there, Jim! What be on ye mind today, matey?\n" - ] - } - ], - "source": [ - "config = {\"configurable\": {\"thread_id\": \"abc345\"}}\n", - "query = \"Hi! 
I'm Jim.\"\n", - "\n", - "input_messages = [HumanMessage(query)]\n", - "output = app.invoke({\"messages\": input_messages, \"job\": \"pirate\"}, config)\n", - "output[\"messages\"][-1].pretty_print()" - ] - }, - { - "cell_type": "code", - "execution_count": 25, - "metadata": {}, - "outputs": [], - "source": [ - "from langchain_core.messages import trim_messages, SystemMessage\n", - "\n", - "trimmer = trim_messages(\n", - " max_tokens=65,\n", - " strategy=\"last\",\n", - " token_counter=model,\n", - " include_system=True,\n", - " allow_partial=False,\n", - " start_on=\"human\",\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": 23, - "metadata": {}, - "outputs": [], - "source": [ - "workflow = StateGraph(state_schema=State)\n", - "\n", - "\n", - "def call_model(state: State):\n", - " trimmed_messages = trimmer.invoke(state[\"messages\"])\n", - " prompt = prompt_template.invoke(\n", - " {\"messages\": trimmed_messages, \"job\": state[\"job\"]}\n", - " )\n", - " response = model.invoke(prompt)\n", - " return {\"messages\": [response]}\n", - "\n", - "\n", - "workflow.add_edge(START, \"model\")\n", - "workflow.add_node(\"model\", call_model)\n", - "\n", - "memory = MemorySaver()\n", - "app = workflow.compile(checkpointer=memory)" - ] - }, - { - "cell_type": "code", - "execution_count": 26, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "==================================\u001b[1m Ai Message \u001b[0m==================================\n", - "\n", - "Ahoy there, Jim! I be glad t’ meet ye on this fine day! What be troublin’ yer sails, matey?\n" - ] - } - ], - "source": [ - "messages = [SystemMessage(content=\"You talk like a pirate. 
Answer all questions to the best of your ability.\")]\n", - "\n", - "input_messages = messages + [HumanMessage(query)]\n", - "output = app.invoke(\n", - " {\"messages\": input_messages, \"job\": \"pirate\"},\n", - " config,\n", - ")\n", - "output[\"messages\"][-1].pretty_print()" - ] - }, - { - "cell_type": "code", - "execution_count": 30, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "|Ah|oy| there|,| Jim|!| What| brings| ye| to| these| tre|acher|ous| waters|?| Be| ye| look|in|’| for| treasure| or| seek|in|'| knowledge|,| mate|y|?| Arr|r|!||" - ] - } - ], - "source": [ - "input_messages = [HumanMessage(query)]\n", - "for chunk, metadata in app.stream(\n", - " {\"messages\": input_messages, \"job\": \"pirate\"},\n", - " config,\n", - " stream_mode=\"messages\",\n", - "):\n", - " if isinstance(chunk, AIMessage): # Filter to just model responses\n", - " print(chunk.content, end=\"|\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "paper_query_v3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.11.11" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -}