From 5da6eb626b356fd7b9fa97ae4ab93dfa368ebc33 Mon Sep 17 00:00:00 2001 From: Ewan Wallace Date: Thu, 27 Mar 2025 10:20:00 +0000 Subject: [PATCH 1/3] added llama3 via groq --- requirements.in | 1 + src/paper_query/constants/__init__.py | 10 ++++++++-- src/paper_query/constants/_api_keys.py | 2 ++ src/paper_query/constants/_constants.py | 4 ++++ src/paper_query/llm/models.py | 6 ++++-- 5 files changed, 19 insertions(+), 4 deletions(-) create mode 100644 src/paper_query/constants/_constants.py diff --git a/requirements.in b/requirements.in index 0200d84..3e8374c 100644 --- a/requirements.in +++ b/requirements.in @@ -1,6 +1,7 @@ streamlit langchain langchain-openai +langchain-groq langchain_community pypdf chromadb diff --git a/src/paper_query/constants/__init__.py b/src/paper_query/constants/__init__.py index fc54fff..5173d59 100644 --- a/src/paper_query/constants/__init__.py +++ b/src/paper_query/constants/__init__.py @@ -1,3 +1,9 @@ -from ._api_keys import OPENAI_API_KEY +from ._api_keys import GROQ_API_KEY, HUGGINGFACE_API_KEY, OPENAI_API_KEY +from ._constants import MODEL_PROVIDER_MAP -__all__ = ["OPENAI_API_KEY"] +__all__ = [ + "OPENAI_API_KEY", + "HUGGINGFACE_API_KEY", + "GROQ_API_KEY", + "MODEL_PROVIDER_MAP", +] diff --git a/src/paper_query/constants/_api_keys.py b/src/paper_query/constants/_api_keys.py index 6337aa4..1323d71 100644 --- a/src/paper_query/constants/_api_keys.py +++ b/src/paper_query/constants/_api_keys.py @@ -4,3 +4,5 @@ api_keys = load_api_keys() OPENAI_API_KEY = api_keys.get("OPENAI_API_KEY", None) +HUGGINGFACE_API_KEY = api_keys.get("HUGGINGFACE_API_KEY", None) +GROQ_API_KEY = api_keys.get("GROQ_API_KEY", None) diff --git a/src/paper_query/constants/_constants.py b/src/paper_query/constants/_constants.py new file mode 100644 index 0000000..039cd7b --- /dev/null +++ b/src/paper_query/constants/_constants.py @@ -0,0 +1,4 @@ +MODEL_PROVIDER_MAP = { + "llama-3.1-8b-instant": "groq", + "gpt-4o": "openai", +} diff --git a/src/paper_query/llm/models.py b/src/paper_query/llm/models.py index dd99729..15f749b 100644 --- a/src/paper_query/llm/models.py +++ b/src/paper_query/llm/models.py @@ -9,7 +9,9 @@ def get_model(model_name: str, model_provider: str): from paper_query.constants import OPENAI_API_KEY os.environ["OPENAI_API_KEY"] = OPENAI_API_KEY - else: - raise ValueError(f"API key not provided for {model_provider}") + if model_provider == "groq": + from paper_query.constants import GROQ_API_KEY + + os.environ["GROQ_API_KEY"] = GROQ_API_KEY return init_chat_model(model_name, model_provider=model_provider) From 53b34e6eb2a468220ee71d7110ff735a3f325b83 Mon Sep 17 00:00:00 2001 From: Ewan Wallace Date: Thu, 27 Mar 2025 10:22:12 +0000 Subject: [PATCH 2/3] added gpt-4o-mini --- src/paper_query/constants/_constants.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/paper_query/constants/_constants.py b/src/paper_query/constants/_constants.py index 039cd7b..ffbf3b0 100644 --- a/src/paper_query/constants/_constants.py +++ b/src/paper_query/constants/_constants.py @@ -1,4 +1,5 @@ MODEL_PROVIDER_MAP = { "llama-3.1-8b-instant": "groq", "gpt-4o": "openai", + "gpt-4o-mini": "openai", } From 07725a00896bd626cbbbf41c3a7f526c4ff8886c Mon Sep 17 00:00:00 2001 From: Ewan Wallace Date: Thu, 27 Mar 2025 16:29:20 +0000 Subject: [PATCH 3/3] model name infers model provider --- .../ui/components/sidebar_inputs.py | 25 +++++++++++-------- 1 file changed, 14 insertions(+), 11 deletions(-) diff --git a/src/paper_query/ui/components/sidebar_inputs.py b/src/paper_query/ui/components/sidebar_inputs.py index edfe9ce..7eef75b 100644 --- a/src/paper_query/ui/components/sidebar_inputs.py +++ b/src/paper_query/ui/components/sidebar_inputs.py @@ -4,6 +4,8 @@ import streamlit as st +from paper_query.constants import MODEL_PROVIDER_MAP + assets_dir = Path(__file__).resolve().parents[4] / "assets" @@ -13,19 +15,13 @@ def get_class_params(cls) -> list[str]: return [ name for name in params.keys() - if name - in ("model_name", "model_provider", "paper_path", "references_dir", "github_repo_url") + if name in ("model_name", "paper_path", "references_dir", "github_repo_url") ] -def model_name_input(name: str = "gpt-4o") -> str: +def model_name_input() -> str: """Get the model name from the sidebar.""" - return st.sidebar.text_input("Model Name", value=name) - - -def model_provider_input(provider: str = "openai") -> str: - """Get the model provider from the sidebar.""" - return st.sidebar.text_input("Model Provider", value=provider) + return st.sidebar.selectbox("Choose a model:", list(MODEL_PROVIDER_MAP.keys())) def paper_path_input() -> str: @@ -63,7 +59,7 @@ def get_param(param: str) -> str | list[str]: """Get the parameter value from the sidebar.""" param_functions = { "model_name": model_name_input, - "model_provider": model_provider_input, + # "model_provider": determined in get_chatbot_params "paper_path": paper_path_input, "references_dir": references_input, "github_repo_url": code_dir_input, @@ -76,4 +72,11 @@ def get_param(param: str) -> str | list[str]: def get_chatbot_params(selected_chatbot_class: type) -> dict: """Get the chatbot parameters from the sidebar.""" chatbot_params = get_class_params(selected_chatbot_class) - return {param: get_param(param) for param in chatbot_params} + params = {param: get_param(param) for param in chatbot_params} + + # Get the model provider based on the selected model name + if params["model_name"] not in MODEL_PROVIDER_MAP.keys(): + raise ValueError(f"Provider not found for {params['model_name']} in MODEL_PROVIDER_MAP.") + params["model_provider"] = MODEL_PROVIDER_MAP[params["model_name"]] + + return params