diff --git a/src/paper_query/constants/__init__.py b/src/paper_query/constants/__init__.py index 77efa44..c777d87 100644 --- a/src/paper_query/constants/__init__.py +++ b/src/paper_query/constants/__init__.py @@ -1,11 +1,12 @@ from ._api_keys import GROQ_API_KEY, HUGGINGFACE_API_KEY, OPENAI_API_KEY +from ._constants import MODEL_PROVIDER_MAP, RAG_DOC_ID from ._paths import PERSIST_DIRECTORY, assets_dir, data_dir, project_dir, src_dir, test_dir -from ._strings import RAG_DOC_ID __all__ = [ "OPENAI_API_KEY", "HUGGINGFACE_API_KEY", "GROQ_API_KEY", + "MODEL_PROVIDER_MAP", "PERSIST_DIRECTORY", "project_dir", "src_dir", diff --git a/src/paper_query/constants/_constants.py b/src/paper_query/constants/_constants.py new file mode 100644 index 0000000..05ce492 --- /dev/null +++ b/src/paper_query/constants/_constants.py @@ -0,0 +1,7 @@ +MODEL_PROVIDER_MAP = { + "llama-3.1-8b-instant": "groq", + "gpt-4o": "openai", + "gpt-4o-mini": "openai", +} + +RAG_DOC_ID = "file" diff --git a/src/paper_query/ui/components/sidebar_inputs.py b/src/paper_query/ui/components/sidebar_inputs.py index 2f39400..608fd5e 100644 --- a/src/paper_query/ui/components/sidebar_inputs.py +++ b/src/paper_query/ui/components/sidebar_inputs.py @@ -3,7 +3,7 @@ import streamlit as st -from paper_query.constants import assets_dir +from paper_query.constants import MODEL_PROVIDER_MAP, assets_dir def get_class_params(cls) -> list[str]: @@ -12,19 +12,15 @@ def get_class_params(cls) -> list[str]: return [ name for name in params.keys() - if name - in ("model_name", "model_provider", "paper_path", "references_dir", "github_repo_url") + if name in ("model_name", "paper_path", "references_dir", "github_repo_url") ] -def model_name_input(name: str = "gpt-4o") -> str: +def model_name_input() -> str: """Get the model name from the sidebar.""" - return st.sidebar.text_input("Model Name", value=name, key="model_name_input") - - -def model_provider_input(provider: str = "openai") -> str: - """Get the model provider from the sidebar.""" - return st.sidebar.text_input("Model Provider", value=provider, key="model_provider_input") + return st.sidebar.selectbox( + "Choose a model:", list(MODEL_PROVIDER_MAP.keys()), key="model_name_input" + ) def paper_path_input() -> str: @@ -62,7 +58,7 @@ def get_param(param: str) -> str | list[str]: """Get the parameter value from the sidebar.""" param_functions = { "model_name": model_name_input, - "model_provider": model_provider_input, + # "model_provider": determined in get_chatbot_params "paper_path": paper_path_input, "references_dir": references_input, "github_repo_url": code_dir_input, @@ -75,4 +71,11 @@ def get_param(param: str) -> str | list[str]: def get_chatbot_params(selected_chatbot_class: type) -> dict: """Get the chatbot parameters from the sidebar.""" chatbot_params = get_class_params(selected_chatbot_class) - return {param: get_param(param) for param in chatbot_params} + params = {param: get_param(param) for param in chatbot_params} + + # Get the model provider based on the selected model name + if params["model_name"] not in MODEL_PROVIDER_MAP.keys(): + raise ValueError(f"Provider not found for {params['model_name']} in MODEL_PROVIDER_MAP.") + params["model_provider"] = MODEL_PROVIDER_MAP[params["model_name"]] + + return params