11 changes: 11 additions & 0 deletions litellm/experimental_flags.py
@@ -0,0 +1,11 @@
"""
Experimental feature flags for LiteLLM.

All flags default OFF. Enable by setting the corresponding environment variable.
"""

import os

ENABLE_PARALLEL_ACOMPLETIONS: bool = os.getenv(
"LITELLM_ENABLE_PARALLEL_ACOMPLETIONS", "0"
).lower() in ("1", "true", "yes", "on")
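
Note: the flag is evaluated once at module import, so the environment variable has to be set before litellm (or anything that imports it) is loaded. A minimal sketch of turning it on, based only on the parsing logic above:

import os

# Set the variable first; experimental_flags reads it at import time.
os.environ["LITELLM_ENABLE_PARALLEL_ACOMPLETIONS"] = "1"

from litellm.experimental_flags import ENABLE_PARALLEL_ACOMPLETIONS

# "1", "true", "yes", and "on" (case-insensitive) enable the flag; anything else leaves it off.
assert ENABLE_PARALLEL_ACOMPLETIONS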
158 changes: 111 additions & 47 deletions litellm/router.py
@@ -100,6 +100,13 @@
async_raise_no_deployment_exception,
send_llm_exception_alert,
)
from litellm.experimental_flags import ENABLE_PARALLEL_ACOMPLETIONS
from litellm.router_utils.parallel_acompletion import (
iter_parallel_acompletions as _iter_parallel_acompletions,
gather_parallel_acompletions as _gather_parallel_acompletions,
RouterParallelRequest,
RouterParallelResult,
)
from litellm.router_utils.pre_call_checks.prompt_caching_deployment_check import (
PromptCachingDeploymentCheck,
)
@@ -359,9 +366,9 @@ def __init__( # noqa: PLR0915
) # names of models under litellm_params. ex. azure/chatgpt-v-2
self.deployment_latency_map = {}
### CACHING ###
cache_type: Literal[
"local", "redis", "redis-semantic", "s3", "disk"
] = "local" # default to an in-memory cache
cache_type: Literal["local", "redis", "redis-semantic", "s3", "disk"] = (
"local" # default to an in-memory cache
)
redis_cache = None
cache_config: Dict[str, Any] = {}

@@ -403,9 +410,9 @@ def __init__( # noqa: PLR0915
self.default_max_parallel_requests = default_max_parallel_requests
self.provider_default_deployment_ids: List[str] = []
self.pattern_router = PatternMatchRouter()
self.team_pattern_routers: Dict[
str, PatternMatchRouter
] = {} # {"TEAM_ID": PatternMatchRouter}
self.team_pattern_routers: Dict[str, PatternMatchRouter] = (
{}
) # {"TEAM_ID": PatternMatchRouter}
self.auto_routers: Dict[str, "AutoRouter"] = {}

if model_list is not None:
@@ -587,9 +594,9 @@ def __init__( # noqa: PLR0915
)
)

self.model_group_retry_policy: Optional[
Dict[str, RetryPolicy]
] = model_group_retry_policy
self.model_group_retry_policy: Optional[Dict[str, RetryPolicy]] = (
model_group_retry_policy
)

self.allowed_fails_policy: Optional[AllowedFailsPolicy] = None
if allowed_fails_policy is not None:
@@ -782,9 +789,6 @@ def initialize_router_endpoints(self):
self.aget_responses = self.factory_function(
litellm.aget_responses, call_type="aget_responses"
)
self.acancel_responses = self.factory_function(
litellm.acancel_responses, call_type="acancel_responses"
)
self.adelete_responses = self.factory_function(
litellm.adelete_responses, call_type="adelete_responses"
)
@@ -876,6 +880,7 @@ def validate_fallbacks(self, fallback_param: Optional[List]):
def add_optional_pre_call_checks(
self, optional_pre_call_checks: Optional[OptionalPreCallChecks]
):

if optional_pre_call_checks is not None:
for pre_call_check in optional_pre_call_checks:
_callback: Optional[CustomLogger] = None
@@ -1211,7 +1216,10 @@ async def stream_with_fallbacks():

async def _acompletion(
self, model: str, messages: List[Dict[str, str]], **kwargs
) -> Union[ModelResponse, CustomStreamWrapper,]:
) -> Union[
ModelResponse,
CustomStreamWrapper,
]:
"""
- Get an available deployment
- call it with a semaphore over the call
@@ -2712,6 +2720,7 @@ async def _ageneric_api_call_with_fallbacks_helper(
passthrough_on_no_deployment = kwargs.pop("passthrough_on_no_deployment", False)
function_name = "_ageneric_api_call_with_fallbacks"
try:

parent_otel_span = _get_parent_otel_span_from_kwargs(kwargs)
try:
deployment = await self.async_get_available_deployment(
@@ -3155,9 +3164,9 @@ async def create_file_for_deployment(deployment: dict) -> OpenAIFileObject:
healthy_deployments=healthy_deployments, responses=responses
)
returned_response = cast(OpenAIFileObject, responses[0])
returned_response._hidden_params[
"model_file_id_mapping"
] = model_file_id_mapping
returned_response._hidden_params["model_file_id_mapping"] = (
model_file_id_mapping
)
return returned_response
except Exception as e:
verbose_router_logger.exception(
@@ -3483,7 +3492,6 @@ def factory_function(
"moderation",
"anthropic_messages",
"aresponses",
"acancel_responses",
"responses",
"aget_responses",
"adelete_responses",
@@ -3577,7 +3585,6 @@ async def async_wrapper(
)
elif call_type in (
"aget_responses",
"acancel_responses",
"adelete_responses",
"alist_input_items",
):
@@ -3625,7 +3632,7 @@ async def _init_responses_api_endpoints(
"""
Initialize the Responses API endpoints on the router.

GET, DELETE, CANCEL Responses API Requests encode the model_id in the response_id, this function decodes the response_id and sets the model to the model_id.
GET, DELETE Responses API Requests encode the model_id in the response_id, this function decodes the response_id and sets the model to the model_id.
"""
from litellm.responses.utils import ResponsesAPIRequestUtils

@@ -3720,11 +3727,11 @@ async def async_function_with_fallbacks_common_utils( # noqa: PLR0915

if isinstance(e, litellm.ContextWindowExceededError):
if context_window_fallbacks is not None:
context_window_fallback_model_group: Optional[
List[str]
] = self._get_fallback_model_group_from_fallbacks(
fallbacks=context_window_fallbacks,
model_group=model_group,
context_window_fallback_model_group: Optional[List[str]] = (
self._get_fallback_model_group_from_fallbacks(
fallbacks=context_window_fallbacks,
model_group=model_group,
)
)
if context_window_fallback_model_group is None:
raise original_exception
@@ -3756,11 +3763,11 @@ async def async_function_with_fallbacks_common_utils( # noqa: PLR0915
e.message += "\n{}".format(error_message)
elif isinstance(e, litellm.ContentPolicyViolationError):
if content_policy_fallbacks is not None:
content_policy_fallback_model_group: Optional[
List[str]
] = self._get_fallback_model_group_from_fallbacks(
fallbacks=content_policy_fallbacks,
model_group=model_group,
content_policy_fallback_model_group: Optional[List[str]] = (
self._get_fallback_model_group_from_fallbacks(
fallbacks=content_policy_fallbacks,
model_group=model_group,
)
)
if content_policy_fallback_model_group is None:
raise original_exception
@@ -4562,10 +4569,8 @@ async def async_deployment_callback_on_failure(
parent_otel_span=parent_otel_span,
ttl=RoutingArgs.ttl.value,
)

def _get_metadata_variable_name_from_kwargs(
self, kwargs: dict
) -> Literal["metadata", "litellm_metadata"]:

def _get_metadata_variable_name_from_kwargs(self, kwargs: dict) -> Literal["metadata", "litellm_metadata"]:
"""
Helper to return what the "metadata" field should be called in the request data

@@ -4992,26 +4997,26 @@ def init_auto_router_deployment(self, deployment: Deployment):
"""
from litellm.router_strategy.auto_router.auto_router import AutoRouter

auto_router_config_path: Optional[
str
] = deployment.litellm_params.auto_router_config_path
auto_router_config_path: Optional[str] = (
deployment.litellm_params.auto_router_config_path
)
auto_router_config: Optional[str] = deployment.litellm_params.auto_router_config
if auto_router_config_path is None and auto_router_config is None:
raise ValueError(
"auto_router_config_path or auto_router_config is required for auto-router deployments. Please set it in the litellm_params"
)

default_model: Optional[
str
] = deployment.litellm_params.auto_router_default_model
default_model: Optional[str] = (
deployment.litellm_params.auto_router_default_model
)
if default_model is None:
raise ValueError(
"auto_router_default_model is required for auto-router deployments. Please set it in the litellm_params"
)

embedding_model: Optional[
str
] = deployment.litellm_params.auto_router_embedding_model
embedding_model: Optional[str] = (
deployment.litellm_params.auto_router_embedding_model
)
if embedding_model is None:
raise ValueError(
"auto_router_embedding_model is required for auto-router deployments. Please set it in the litellm_params"
@@ -5674,11 +5679,11 @@ def _set_model_group_info( # noqa: PLR0915
)
if supported_openai_params is None:
supported_openai_params = []

# Get mode from database model_info if available, otherwise default to "chat"
db_model_info = model.get("model_info", {})
mode = db_model_info.get("mode", "chat")

model_info = ModelMapInfo(
key=model_group,
max_tokens=None,
@@ -6804,9 +6809,7 @@ async def async_get_healthy_deployments(
model=model,
request_kwargs=request_kwargs,
healthy_deployments=healthy_deployments,
metadata_variable_name=self._get_metadata_variable_name_from_kwargs(
request_kwargs
),
metadata_variable_name=self._get_metadata_variable_name_from_kwargs(request_kwargs),
)

if len(healthy_deployments) == 0:
@@ -7268,6 +7271,67 @@ def set_custom_routing_strategy(
CustomRoutingStrategy.async_get_available_deployment,
)

async def parallel_acompletions(
self,
requests: List[RouterParallelRequest],
*,
concurrency: Optional[int] = None,
return_exceptions: bool = True,
preserve_order: bool = False,
) -> List[RouterParallelResult]:
"""
Experimental: run multiple acompletion calls concurrently and collect results.

Requires env variable: LITELLM_ENABLE_PARALLEL_ACOMPLETIONS=1
"""
if not ENABLE_PARALLEL_ACOMPLETIONS:
raise RuntimeError(
"parallel_acompletions disabled; set LITELLM_ENABLE_PARALLEL_ACOMPLETIONS=1"
)
# Concurrency default tie-in
_concurrency = (
concurrency
if concurrency is not None
else (self.default_max_parallel_requests or 8)
)

return await _gather_parallel_acompletions(
self,
requests,
concurrency=_concurrency,
return_exceptions=return_exceptions,
preserve_order=preserve_order,
)

def iter_parallel_acompletions(
self,
requests: List[RouterParallelRequest],
*,
concurrency: Optional[int] = None,
return_exceptions: bool = True,
):
"""
Experimental: async iterator yielding each result as it finishes (completion order).

Requires env variable: LITELLM_ENABLE_PARALLEL_ACOMPLETIONS=1
"""
if not ENABLE_PARALLEL_ACOMPLETIONS:
raise RuntimeError(
"parallel_acompletions disabled; set LITELLM_ENABLE_PARALLEL_ACOMPLETIONS=1"
)
_concurrency = (
concurrency
if concurrency is not None
else (self.default_max_parallel_requests or 8)
)

return _iter_parallel_acompletions(
self,
requests,
concurrency=_concurrency,
return_exceptions=return_exceptions,
)

def flush_cache(self):
litellm.cache = None
self.cache.flush_cache()
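
For orientation, a rough usage sketch of the new gather-style parallel_acompletions API. The Router setup is standard LiteLLM configuration; the RouterParallelRequest and RouterParallelResult shapes live in litellm/router_utils/parallel_acompletion.py, which is not part of this diff, so the constructor arguments and result attributes below are assumptions rather than confirmed API:

import asyncio

from litellm import Router
from litellm.router_utils.parallel_acompletion import RouterParallelRequest

# Assumes LITELLM_ENABLE_PARALLEL_ACOMPLETIONS=1 was exported before litellm was imported.
router = Router(
    model_list=[
        {
            "model_name": "gpt-4o-mini",
            "litellm_params": {"model": "openai/gpt-4o-mini"},
        }
    ],
    default_max_parallel_requests=4,  # used as the concurrency default when none is passed
)

async def main() -> None:
    # Field names on RouterParallelRequest are assumed here for illustration only.
    requests = [
        RouterParallelRequest(
            model="gpt-4o-mini",
            messages=[{"role": "user", "content": f"Summarize item {i}"}],
        )
        for i in range(3)
    ]
    # concurrency falls back to default_max_parallel_requests (or 8) when omitted.
    results = await router.parallel_acompletions(requests, return_exceptions=True)
    for result in results:
        print(result)  # RouterParallelResult; its attributes are not shown in this diff

asyncio.run(main())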
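
The iterator variant yields each result as soon as it finishes instead of waiting for the whole batch; again a sketch under the same assumptions about the request and result types:

async def consume(router, requests) -> None:
    # iter_parallel_acompletions is documented as an async iterator in completion
    # order, so a slow deployment does not hold back results that are already done.
    async for result in router.iter_parallel_acompletions(requests, concurrency=4):
        # return_exceptions defaults to True; presumably failed calls are yielded
        # as results (mirroring asyncio.gather semantics) rather than raised.
        print(result)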