21 changes: 18 additions & 3 deletions nemoguardrails/actions/llm/utils.py
@@ -30,6 +30,7 @@
tool_calls_var,
)
from nemoguardrails.integrations.langchain.message_utils import dicts_to_messages
from nemoguardrails.llm.parameter_mapping import get_llm_provider, transform_llm_params
from nemoguardrails.logging.callbacks import logging_callbacks
from nemoguardrails.logging.explain import LLMCallInfo

@@ -97,9 +98,23 @@ async def llm_call(
_setup_llm_call_info(llm, model_name, model_provider)
all_callbacks = _prepare_callbacks(custom_callback_handlers)

generation_llm: Union[BaseLanguageModel, Runnable] = (
llm.bind(stop=stop, **llm_params) if llm_params and llm is not None else llm
)
if llm_params or stop:
params_to_transform = llm_params.copy() if llm_params else {}
if stop is not None:
params_to_transform["stop"] = stop

inferred_model_name = model_name or _infer_model_name(llm)
inferred_provider = model_provider or get_llm_provider(llm)
transformed_params = transform_llm_params(
params_to_transform,
provider=inferred_provider,
model_name=inferred_model_name,
)
generation_llm: Union[BaseLanguageModel, Runnable] = llm.bind(
**transformed_params
)
else:
generation_llm: Union[BaseLanguageModel, Runnable] = llm

if isinstance(prompt, str):
response = await _invoke_with_string_prompt(
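For reference, a minimal sketch of what this new branch does with the merged parameters; the parameter values, provider, and model name below are illustrative assumptions, not values taken from the PR:

# Sketch of the new llm_call parameter path (assumed values, not from the PR).
from nemoguardrails.llm.parameter_mapping import transform_llm_params

llm_params = {"temperature": 0.2, "max_tokens": 100}  # hypothetical caller params
stop = ["\n\n"]

params_to_transform = dict(llm_params)
if stop is not None:
    params_to_transform["stop"] = stop

# For a HuggingFace-backed model, the built-in mapping renames max_tokens.
transformed = transform_llm_params(
    params_to_transform, provider="huggingface", model_name="some-model"
)
# transformed == {"temperature": 0.2, "stop": ["\n\n"], "max_new_tokens": 100}
# generation_llm = llm.bind(**transformed)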
187 changes: 187 additions & 0 deletions nemoguardrails/llm/parameter_mapping.py
@@ -0,0 +1,187 @@
# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Module for transforming LLM parameters between internal and provider-specific formats."""

import logging
from typing import Any, Dict, Optional

from langchain.base_language import BaseLanguageModel

log = logging.getLogger(__name__)

_llm_parameter_mappings = {}

PROVIDER_PARAMETER_MAPPINGS = {
"huggingface": {
"max_tokens": "max_new_tokens",
},
"google_vertexai": {
"max_tokens": "max_output_tokens",
},
}


def register_llm_parameter_mapping(
provider: str, model_name: str, parameter_mapping: Dict[str, Optional[str]]
) -> None:
"""Register a parameter mapping for a specific provider and model combination.

Args:
provider: The LLM provider name
model_name: The model name
parameter_mapping: The parameter mapping dictionary
"""
key = (provider, model_name)
_llm_parameter_mappings[key] = parameter_mapping
log.debug("Registered parameter mapping for %s/%s", provider, model_name)


def get_llm_parameter_mapping(
provider: str, model_name: str
) -> Optional[Dict[str, Optional[str]]]:
"""Get the registered parameter mapping for a provider and model combination.

Args:
provider: The LLM provider name
model_name: The model name

Returns:
The parameter mapping if registered, None otherwise
"""
return _llm_parameter_mappings.get((provider, model_name))


def _infer_provider_from_module(llm: BaseLanguageModel) -> Optional[str]:
"""Infer provider name from the LLM's module path.

This function extracts the provider name from LangChain package naming conventions:
- langchain_openai -> openai
- langchain_anthropic -> anthropic
- langchain_google_genai -> google_genai
- langchain_nvidia_ai_endpoints -> nvidia_ai_endpoints
- langchain_community.chat_models.ollama -> ollama

Args:
llm: The LLM instance

Returns:
The inferred provider name, or None if it cannot be determined
"""
module = type(llm).__module__

if module.startswith("langchain_"):
package = module.split(".")[0]
provider = package.replace("langchain_", "")

if provider == "community":
parts = module.split(".")
if len(parts) >= 3:
provider = parts[-1]
log.debug(
"Inferred provider '%s' from community module %s", provider, module
)
return provider
else:
log.debug("Inferred provider '%s' from module %s", provider, module)
return provider

log.debug("Could not infer provider from module %s", module)
return None


def get_llm_provider(llm: BaseLanguageModel) -> Optional[str]:
"""Get the provider name for an LLM instance by inferring from module path.

This function extracts the provider name from LangChain package naming conventions.
See _infer_provider_from_module for details on the inference logic.

Args:
llm: The LLM instance

Returns:
The provider name if it can be inferred, None otherwise
"""
return _infer_provider_from_module(llm)


def transform_llm_params(
llm_params: Dict[str, Any],
provider: Optional[str] = None,
model_name: Optional[str] = None,
parameter_mapping: Optional[Dict[str, Optional[str]]] = None,
) -> Dict[str, Any]:
"""Transform LLM parameters using provider-specific or custom mappings.

Args:
llm_params: The original parameters dictionary
provider: Optional provider name
model_name: Optional model name
parameter_mapping: Custom mapping dictionary. If None, uses built-in provider mappings.
Key is the internal parameter name, value is the provider parameter name.
If value is None, the parameter is dropped.

Returns:
Transformed parameters dictionary
"""
if not llm_params:
return llm_params

if parameter_mapping is not None:
return _apply_mapping(llm_params, parameter_mapping)

has_instance_mapping = (provider, model_name) in _llm_parameter_mappings
has_builtin_mapping = provider in PROVIDER_PARAMETER_MAPPINGS

if not has_instance_mapping and not has_builtin_mapping:
return llm_params

mapping = None
if has_instance_mapping:
mapping = _llm_parameter_mappings.get((provider, model_name))
log.debug("Using registered parameter mapping for %s/%s", provider, model_name)
if not mapping and has_builtin_mapping:
mapping = PROVIDER_PARAMETER_MAPPINGS[provider]
log.debug("Using built-in parameter mapping for provider: %s", provider)

return _apply_mapping(llm_params, mapping) if mapping else llm_params


def _apply_mapping(
llm_params: Dict[str, Any], mapping: Dict[str, Optional[str]]
) -> Dict[str, Any]:
"""Apply parameter mapping transformation.

Args:
llm_params: The original parameters dictionary
mapping: The parameter mapping dictionary

Returns:
Transformed parameters dictionary
"""
transformed_params = {}

for param_name, param_value in llm_params.items():
if param_name in mapping:
mapped_name = mapping[param_name]
if mapped_name is not None:
transformed_params[mapped_name] = param_value
log.debug("Mapped parameter %s -> %s", param_name, mapped_name)
else:
log.debug("Dropped parameter %s", param_name)
else:
transformed_params[param_name] = param_value

return transformed_params
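
Taken together, the new module can be exercised roughly as follows; the provider and model names are hypothetical:

from nemoguardrails.llm.parameter_mapping import (
    register_llm_parameter_mapping,
    transform_llm_params,
)

# Custom per-model mapping: rename max_tokens, drop seed (a None value drops it).
register_llm_parameter_mapping(
    "my_provider", "my-model", {"max_tokens": "max_completion_tokens", "seed": None}
)

params = {"max_tokens": 256, "seed": 42, "temperature": 0.7}
transform_llm_params(params, provider="my_provider", model_name="my-model")
# -> {"max_completion_tokens": 256, "temperature": 0.7}

# Passing parameter_mapping explicitly bypasses both the registry and the built-ins.
transform_llm_params(params, parameter_mapping={"max_tokens": "num_predict"})
# -> {"num_predict": 256, "seed": 42, "temperature": 0.7}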
6 changes: 6 additions & 0 deletions nemoguardrails/rails/llm/config.py
@@ -123,6 +123,12 @@ class Model(BaseModel):
description="Configuration parameters for reasoning LLMs.",
)
parameters: Dict[str, Any] = Field(default_factory=dict)
parameter_mapping: Optional[Dict[str, Optional[str]]] = Field(
default=None,
description="Optional parameter mapping to transform parameter names for provider-specific requirements. "
"Keys are internal parameter names, values are provider parameter names. "
"Set value to null to drop a parameter.",
)

mode: Literal["chat", "text"] = Field(
default="chat",
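In configuration terms, the new field could be populated like this (a sketch; the engine and model values are placeholders):

from nemoguardrails.rails.llm.config import Model

# Hypothetical main-model entry; engine/model are placeholders.
main_model = Model(
    type="main",
    engine="huggingface_endpoint",
    model="some-model",
    parameter_mapping={"max_tokens": "max_new_tokens", "seed": None},  # rename / drop
)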
25 changes: 24 additions & 1 deletion nemoguardrails/rails/llm/llmrails.py
@@ -74,6 +74,7 @@
ModelInitializationError,
init_llm_model,
)
from nemoguardrails.llm.parameter_mapping import register_llm_parameter_mapping
from nemoguardrails.logging.explain import ExplainInfo
from nemoguardrails.logging.processing_log import compute_generation_log
from nemoguardrails.logging.stats import LLMStats
@@ -443,11 +444,21 @@ def _init_llms(self):
if self.llm:
# If an LLM was provided via constructor, use it as the main LLM
# Log a warning if a main LLM is also specified in the config
if any(model.type == "main" for model in self.config.models):
main_model = next(
(model for model in self.config.models if model.type == "main"), None
)
if main_model:
log.warning(
"Both an LLM was provided via constructor and a main LLM is specified in the config. "
"The LLM provided via constructor will be used and the main LLM from config will be ignored."
)
# Still register parameter mapping from config if available
if main_model.parameter_mapping and main_model.model:
register_llm_parameter_mapping(
main_model.engine,
main_model.model,
main_model.parameter_mapping,
)
self.runtime.register_action_param("llm", self.llm)

self._configure_main_llm_streaming(self.llm)
@@ -465,6 +476,12 @@
mode="chat",
kwargs=kwargs,
)
if main_model.parameter_mapping and main_model.model:
register_llm_parameter_mapping(
main_model.engine,
main_model.model,
main_model.parameter_mapping,
)
self.runtime.register_action_param("llm", self.llm)

self._configure_main_llm_streaming(
@@ -500,6 +517,12 @@ def _init_llms(self):
kwargs=kwargs,
)

if llm_config.parameter_mapping and llm_config.model:
register_llm_parameter_mapping(
llm_config.engine,
llm_config.model,
llm_config.parameter_mapping,
)
if llm_config.type == "main":
# If a main LLM was already injected, skip creating another
# one. Otherwise, create and register it.
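End to end, the registration wired in above is driven by the config. A hedged sketch follows; the engine and model are placeholders, and a real provider plus credentials would be needed for the model to actually initialize:

from nemoguardrails import LLMRails, RailsConfig

YAML_CONTENT = """
models:
  - type: main
    engine: huggingface_endpoint   # placeholder
    model: some-model              # placeholder
    parameter_mapping:
      max_tokens: max_new_tokens
"""

config = RailsConfig.from_content(yaml_content=YAML_CONTENT)
rails = LLMRails(config)  # _init_llms() registers the mapping for later llm_call use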