Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
170 changes: 119 additions & 51 deletions src/config/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,10 +76,13 @@
INGESTION_TIMEOUT = get_env_int("INGESTION_TIMEOUT", 3600)


def is_no_auth_mode():
"""Check if we're running in no-auth mode (OAuth credentials missing)"""
result = not (GOOGLE_OAUTH_CLIENT_ID and GOOGLE_OAUTH_CLIENT_SECRET)
return result
def is_no_auth_mode() -> bool:
    """Report whether the app is running without authentication.

    No-auth mode is active when either Google OAuth credential
    (client ID or client secret) is missing or empty.

    Returns:
        bool: True if OAuth credentials are missing, False otherwise.
    """
    # De Morgan equivalent of: not (CLIENT_ID and CLIENT_SECRET)
    return (not GOOGLE_OAUTH_CLIENT_ID) or (not GOOGLE_OAUTH_CLIENT_SECRET)


# Webhook configuration - must be set to enable webhooks
Expand All @@ -93,23 +96,23 @@ def is_no_auth_mode():
# kNN index parameter — presumably the HNSW `m` (connections per node) value
# used when building the index body; confirm against INDEX_BODY below.
KNN_M = 16
# Default embedding model name, used as a fallback when no model is configured.
EMBED_MODEL = "text-embedding-3-small"

OPENAI_EMBEDDING_DIMENSIONS = {
"text-embedding-3-small": 1536,
"text-embedding-3-large": 3072,
"text-embedding-ada-002": 1536,
}

WATSONX_EMBEDDING_DIMENSIONS = {
# IBM Models
"ibm/granite-embedding-107m-multilingual": 384,
"ibm/granite-embedding-278m-multilingual": 1024,
"ibm/slate-125m-english-rtrvr": 768,
"ibm/slate-125m-english-rtrvr-v2": 768,
"ibm/slate-30m-english-rtrvr": 384,
"ibm/slate-30m-english-rtrvr-v2": 384,
# Third Party Models
"intfloat/multilingual-e5-large": 1024,
"sentence-transformers/all-minilm-l6-v2": 384,
# Embedding vector dimension for each supported OpenAI embedding model.
OPENAI_EMBEDDING_DIMENSIONS: dict[str, int] = {
    "text-embedding-3-small": 1536,
    "text-embedding-3-large": 3072,
    "text-embedding-ada-002": 1536,
}

# Embedding vector dimension for each supported watsonx.ai embedding model.
WATSONX_EMBEDDING_DIMENSIONS: dict[str, int] = {
    # IBM Models
    "ibm/granite-embedding-107m-multilingual": 384,
    "ibm/granite-embedding-278m-multilingual": 1024,
    "ibm/slate-125m-english-rtrvr": 768,
    "ibm/slate-125m-english-rtrvr-v2": 768,
    "ibm/slate-30m-english-rtrvr": 384,
    "ibm/slate-30m-english-rtrvr-v2": 384,
    # Third Party Models
    "intfloat/multilingual-e5-large": 1024,
    "sentence-transformers/all-minilm-l6-v2": 384,
}

INDEX_BODY = {
Expand Down Expand Up @@ -299,15 +302,17 @@ async def get_langflow_api_key(force_regenerate: bool = False):


class AppClients:
def __init__(self):
self.opensearch = None
self.langflow_client = None
self.langflow_http_client = None
self._patched_async_client = None # Private attribute - single client for all providers
"""Manages application-wide client connections (OpenSearch, Langflow, OpenAI, etc.)."""

def __init__(self) -> None:
self.opensearch: any = None
self.langflow_client: any = None
self.langflow_http_client: any = None
self._patched_async_client: any = None # Private attribute - single client for all providers
self._client_init_lock = __import__('threading').Lock() # Lock for thread-safe initialization
self.docling_http_client = None
self.docling_http_client: any = None

async def initialize(self):
async def initialize(self) -> "AppClients":
# Initialize OpenSearch client
self.opensearch = AsyncOpenSearch(
hosts=[{"host": OPENSEARCH_HOST, "port": OPENSEARCH_PORT}],
Expand Down Expand Up @@ -389,8 +394,12 @@ async def initialize(self):

return self

async def ensure_langflow_client(self):
"""Ensure Langflow client exists; try to generate key and create client lazily."""
async def ensure_langflow_client(self) -> any:
"""Ensure Langflow client exists; try to generate key and create client lazily.

Returns:
AsyncOpenAI: The Langflow client instance, or None if initialization failed.
"""
if self.langflow_client is not None:
return self.langflow_client
# Try generating key again (with retries)
Expand Down Expand Up @@ -549,8 +558,11 @@ def patched_embedding_client(self):
"""Alias for patched_async_client - for backward compatibility with code expecting separate clients."""
return self.patched_async_client

async def refresh_patched_client(self):
"""Reset patched client so next use picks up updated provider credentials."""
async def refresh_patched_client(self) -> None:
"""Reset patched client so next use picks up updated provider credentials.

Closes the existing patched client and clears it for re-initialization.
"""
if self._patched_async_client is not None:
try:
await self._patched_async_client.close()
Expand All @@ -560,8 +572,11 @@ async def refresh_patched_client(self):
finally:
self._patched_async_client = None

async def cleanup(self):
"""Cleanup resources - should be called on application shutdown"""
async def cleanup(self) -> None:
"""Cleanup resources - should be called on application shutdown.

Closes all open client connections (OpenSearch, Langflow, OpenAI, etc.).
"""
# Close AsyncOpenAI client if it was created
if self._patched_async_client is not None:
try:
Expand Down Expand Up @@ -612,10 +627,21 @@ async def cleanup(self):
finally:
self.langflow_client = None

async def langflow_request(self, method: str, endpoint: str, **kwargs):
async def langflow_request(self, method: str, endpoint: str, **kwargs: any) -> any:
"""Central method for all Langflow API requests.

Retries once with a fresh API key on auth failures (401/403).

Args:
method: HTTP method (GET, POST, PUT, PATCH, DELETE).
endpoint: API endpoint path.
**kwargs: Additional arguments passed to the HTTP client.

Returns:
httpx.Response: The HTTP response object.

Raises:
ValueError: If no Langflow API key is available.
"""
api_key = await get_langflow_api_key()
if not api_key:
Expand Down Expand Up @@ -654,8 +680,14 @@ async def langflow_request(self, method: str, endpoint: str, **kwargs):

async def _create_langflow_global_variable(
self, name: str, value: str, modify: bool = False
):
"""Create a global variable in Langflow via API"""
) -> None:
"""Create a global variable in Langflow via API.

Args:
name: The variable name.
value: The variable value.
modify: If True, update the variable if it already exists.
"""
payload = {
"name": name,
"value": value,
Expand Down Expand Up @@ -699,8 +731,13 @@ async def _create_langflow_global_variable(
)
raise e

async def _update_langflow_global_variable(self, name: str, value: str):
"""Update an existing global variable in Langflow via API"""
async def _update_langflow_global_variable(self, name: str, value: str) -> None:
"""Update an existing global variable in Langflow via API.

Args:
name: The variable name.
value: The new variable value.
"""
try:
# First, get all variables to find the one with the matching name
get_response = await self.langflow_request("GET", "/api/v1/variables/")
Expand Down Expand Up @@ -765,8 +802,15 @@ async def _update_langflow_global_variable(self, name: str, value: str):
error=str(e),
)

def create_user_opensearch_client(self, jwt_token: str):
"""Create OpenSearch client with user's JWT token for OIDC auth"""
def create_user_opensearch_client(self, jwt_token: str) -> any:
"""Create OpenSearch client with user's JWT token for OIDC auth.

Args:
jwt_token: The user's JWT token for authentication.

Returns:
AsyncOpenSearch: The OpenSearch client instance.
"""
headers = {"Authorization": f"Bearer {jwt_token}"}

return AsyncOpenSearch(
Expand Down Expand Up @@ -840,32 +884,56 @@ def create_user_opensearch_client(self, jwt_token: str):


# Configuration access
def get_openrag_config():
"""Get current OpenRAG configuration."""
def get_openrag_config():
    """Get current OpenRAG configuration.

    Returns:
        OpenRAGConfig: The current configuration object, as provided by
        the module-level ``config_manager``.
    """
    # NOTE: no return annotation on purpose — the previous `-> any` annotated
    # the signature with the builtin `any()` function, which is not a type.
    # The concrete return type is project-defined, so it is documented above
    # rather than annotated.
    return config_manager.get_config()


# Expose configuration settings for backward compatibility and easy access
def get_provider_config():
"""Get provider configuration."""
def get_provider_config():
    """Get provider configuration.

    Returns:
        ProviderConfig: The ``provider`` section of the current OpenRAG
        configuration.
    """
    # NOTE: the previous `-> any` annotation used the builtin `any()`
    # function, which is not a valid type; annotation removed.
    return get_openrag_config().provider


def get_knowledge_config():
"""Get knowledge configuration."""
def get_knowledge_config():
    """Get knowledge configuration.

    Returns:
        KnowledgeConfig: The ``knowledge`` section of the current OpenRAG
        configuration.
    """
    # NOTE: the previous `-> any` annotation used the builtin `any()`
    # function, which is not a valid type; annotation removed.
    return get_openrag_config().knowledge


def get_agent_config():
"""Get agent configuration."""
def get_agent_config():
    """Get agent configuration.

    Returns:
        AgentConfig: The ``agent`` section of the current OpenRAG
        configuration.
    """
    # NOTE: the previous `-> any` annotation used the builtin `any()`
    # function, which is not a valid type; annotation removed.
    return get_openrag_config().agent


def get_embedding_model() -> str:
    """Return the embedding model OpenRAG should use.

    Returns:
        str: The configured embedding model name (falling back to
        ``EMBED_MODEL``) when ``DISABLE_INGEST_WITH_LANGFLOW`` is truthy;
        an empty string otherwise (i.e. when ingestion goes through
        Langflow and no model is needed here).
    """
    # Rewritten as a guard clause: the original one-liner relied on the
    # non-obvious precedence `(a or b) if c else d`, and its docstring
    # described the inverse condition. Behavior is unchanged.
    if not DISABLE_INGEST_WITH_LANGFLOW:
        return ""
    return get_openrag_config().knowledge.embedding_model or EMBED_MODEL


def get_index_name() -> str:
    """Return the currently configured knowledge index name.

    Returns:
        str: The index name from the knowledge configuration.
    """
    knowledge = get_openrag_config().knowledge
    return knowledge.index_name
Loading