diff --git a/nemoguardrails/actions/llm/utils.py b/nemoguardrails/actions/llm/utils.py
index c36899bb8..620abbbe0 100644
--- a/nemoguardrails/actions/llm/utils.py
+++ b/nemoguardrails/actions/llm/utils.py
@@ -30,6 +30,7 @@
     tool_calls_var,
 )
 from nemoguardrails.integrations.langchain.message_utils import dicts_to_messages
+from nemoguardrails.llm.parameter_mapping import get_llm_provider, transform_llm_params
 from nemoguardrails.logging.callbacks import logging_callbacks
 from nemoguardrails.logging.explain import LLMCallInfo

@@ -97,9 +98,23 @@ async def llm_call(
     _setup_llm_call_info(llm, model_name, model_provider)
     all_callbacks = _prepare_callbacks(custom_callback_handlers)

-    generation_llm: Union[BaseLanguageModel, Runnable] = (
-        llm.bind(stop=stop, **llm_params) if llm_params and llm is not None else llm
-    )
+    if llm_params or stop:
+        params_to_transform = llm_params.copy() if llm_params else {}
+        if stop is not None:
+            params_to_transform["stop"] = stop
+
+        inferred_model_name = model_name or _infer_model_name(llm)
+        inferred_provider = model_provider or get_llm_provider(llm)
+        transformed_params = transform_llm_params(
+            params_to_transform,
+            provider=inferred_provider,
+            model_name=inferred_model_name,
+        )
+        generation_llm: Union[BaseLanguageModel, Runnable] = llm.bind(
+            **transformed_params
+        )
+    else:
+        generation_llm: Union[BaseLanguageModel, Runnable] = llm

     if isinstance(prompt, str):
         response = await _invoke_with_string_prompt(
diff --git a/nemoguardrails/llm/parameter_mapping.py b/nemoguardrails/llm/parameter_mapping.py
new file mode 100644
index 000000000..1f155df6d
--- /dev/null
+++ b/nemoguardrails/llm/parameter_mapping.py
@@ -0,0 +1,187 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Module for transforming LLM parameters between internal and provider-specific formats."""
+
+import logging
+from typing import Any, Dict, Optional
+
+from langchain.base_language import BaseLanguageModel
+
+log = logging.getLogger(__name__)
+
+_llm_parameter_mappings = {}
+
+PROVIDER_PARAMETER_MAPPINGS = {
+    "huggingface": {
+        "max_tokens": "max_new_tokens",
+    },
+    "google_vertexai": {
+        "max_tokens": "max_output_tokens",
+    },
+}
+
+
+def register_llm_parameter_mapping(
+    provider: str, model_name: str, parameter_mapping: Dict[str, Optional[str]]
+) -> None:
+    """Register a parameter mapping for a specific provider and model combination.
+
+    Args:
+        provider: The LLM provider name
+        model_name: The model name
+        parameter_mapping: The parameter mapping dictionary
+    """
+    key = (provider, model_name)
+    _llm_parameter_mappings[key] = parameter_mapping
+    log.debug("Registered parameter mapping for %s/%s", provider, model_name)
+
+
+def get_llm_parameter_mapping(
+    provider: str, model_name: str
+) -> Optional[Dict[str, Optional[str]]]:
+    """Get the registered parameter mapping for a provider and model combination.
+
+    Args:
+        provider: The LLM provider name
+        model_name: The model name
+
+    Returns:
+        The parameter mapping if registered, None otherwise
+    """
+    return _llm_parameter_mappings.get((provider, model_name))
+
+
+def _infer_provider_from_module(llm: BaseLanguageModel) -> Optional[str]:
+    """Infer provider name from the LLM's module path.
+
+    This function extracts the provider name from LangChain package naming conventions:
+    - langchain_openai -> openai
+    - langchain_anthropic -> anthropic
+    - langchain_google_genai -> google_genai
+    - langchain_nvidia_ai_endpoints -> nvidia_ai_endpoints
+    - langchain_community.chat_models.ollama -> ollama
+
+    Args:
+        llm: The LLM instance
+
+    Returns:
+        The inferred provider name, or None if it cannot be determined
+    """
+    module = type(llm).__module__
+
+    if module.startswith("langchain_"):
+        package = module.split(".")[0]
+        provider = package.replace("langchain_", "")
+
+        if provider == "community":
+            parts = module.split(".")
+            if len(parts) >= 3:
+                provider = parts[-1]
+                log.debug(
+                    "Inferred provider '%s' from community module %s", provider, module
+                )
+                return provider
+        else:
+            log.debug("Inferred provider '%s' from module %s", provider, module)
+            return provider
+
+    log.debug("Could not infer provider from module %s", module)
+    return None
+
+
+def get_llm_provider(llm: BaseLanguageModel) -> Optional[str]:
+    """Get the provider name for an LLM instance by inferring from module path.
+
+    This function extracts the provider name from LangChain package naming conventions.
+    See _infer_provider_from_module for details on the inference logic.
+
+    Args:
+        llm: The LLM instance
+
+    Returns:
+        The provider name if it can be inferred, None otherwise
+    """
+    return _infer_provider_from_module(llm)
+
+
+def transform_llm_params(
+    llm_params: Dict[str, Any],
+    provider: Optional[str] = None,
+    model_name: Optional[str] = None,
+    parameter_mapping: Optional[Dict[str, Optional[str]]] = None,
+) -> Dict[str, Any]:
+    """Transform LLM parameters using provider-specific or custom mappings.
+
+    Args:
+        llm_params: The original parameters dictionary
+        provider: Optional provider name
+        model_name: Optional model name
+        parameter_mapping: Custom mapping dictionary. If None, uses built-in provider mappings.
+            Key is the internal parameter name, value is the provider parameter name.
+            If value is None, the parameter is dropped.
+
+    Returns:
+        Transformed parameters dictionary
+    """
+    if not llm_params:
+        return llm_params
+
+    if parameter_mapping is not None:
+        return _apply_mapping(llm_params, parameter_mapping)
+
+    has_instance_mapping = (provider, model_name) in _llm_parameter_mappings
+    has_builtin_mapping = provider in PROVIDER_PARAMETER_MAPPINGS
+
+    if not has_instance_mapping and not has_builtin_mapping:
+        return llm_params
+
+    mapping = None
+    if has_instance_mapping:
+        mapping = _llm_parameter_mappings.get((provider, model_name))
+        log.debug("Using registered parameter mapping for %s/%s", provider, model_name)
+    if not mapping and has_builtin_mapping:
+        mapping = PROVIDER_PARAMETER_MAPPINGS[provider]
+        log.debug("Using built-in parameter mapping for provider: %s", provider)
+
+    return _apply_mapping(llm_params, mapping) if mapping else llm_params
+
+
+def _apply_mapping(
+    llm_params: Dict[str, Any], mapping: Dict[str, Optional[str]]
+) -> Dict[str, Any]:
+    """Apply parameter mapping transformation.
+
+    Args:
+        llm_params: The original parameters dictionary
+        mapping: The parameter mapping dictionary
+
+    Returns:
+        Transformed parameters dictionary
+    """
+    transformed_params = {}
+
+    for param_name, param_value in llm_params.items():
+        if param_name in mapping:
+            mapped_name = mapping[param_name]
+            if mapped_name is not None:
+                transformed_params[mapped_name] = param_value
+                log.debug("Mapped parameter %s -> %s", param_name, mapped_name)
+            else:
+                log.debug("Dropped parameter %s", param_name)
+        else:
+            transformed_params[param_name] = param_value
+
+    return transformed_params
diff --git a/nemoguardrails/rails/llm/config.py b/nemoguardrails/rails/llm/config.py
index 749ecfd32..2e1342fd9 100644
--- a/nemoguardrails/rails/llm/config.py
+++ b/nemoguardrails/rails/llm/config.py
@@ -123,6 +123,12 @@ class Model(BaseModel):
         description="Configuration parameters for reasoning LLMs.",
     )
     parameters: Dict[str, Any] = Field(default_factory=dict)
+    parameter_mapping: Optional[Dict[str, Optional[str]]] = Field(
+        default=None,
+        description="Optional parameter mapping to transform parameter names for provider-specific requirements. "
+        "Keys are internal parameter names, values are provider parameter names. "
+        "Set value to null to drop a parameter.",
+    )

     mode: Literal["chat", "text"] = Field(
         default="chat",
diff --git a/nemoguardrails/rails/llm/llmrails.py b/nemoguardrails/rails/llm/llmrails.py
index fe56bcf08..7eafaec9e 100644
--- a/nemoguardrails/rails/llm/llmrails.py
+++ b/nemoguardrails/rails/llm/llmrails.py
@@ -74,6 +74,7 @@
     ModelInitializationError,
     init_llm_model,
 )
+from nemoguardrails.llm.parameter_mapping import register_llm_parameter_mapping
 from nemoguardrails.logging.explain import ExplainInfo
 from nemoguardrails.logging.processing_log import compute_generation_log
 from nemoguardrails.logging.stats import LLMStats
@@ -443,11 +444,21 @@ def _init_llms(self):
         if self.llm:
             # If an LLM was provided via constructor, use it as the main LLM
             # Log a warning if a main LLM is also specified in the config
-            if any(model.type == "main" for model in self.config.models):
+            main_model = next(
+                (model for model in self.config.models if model.type == "main"), None
+            )
+            if main_model:
                 log.warning(
                     "Both an LLM was provided via constructor and a main LLM is specified in the config. "
                     "The LLM provided via constructor will be used and the main LLM from config will be ignored."
                 )
+                # Still register parameter mapping from config if available
+                if main_model.parameter_mapping and main_model.model:
+                    register_llm_parameter_mapping(
+                        main_model.engine,
+                        main_model.model,
+                        main_model.parameter_mapping,
+                    )

             self.runtime.register_action_param("llm", self.llm)
             self._configure_main_llm_streaming(self.llm)
@@ -465,6 +476,12 @@
                 mode="chat",
                 kwargs=kwargs,
             )
+            if main_model.parameter_mapping and main_model.model:
+                register_llm_parameter_mapping(
+                    main_model.engine,
+                    main_model.model,
+                    main_model.parameter_mapping,
+                )

             self.runtime.register_action_param("llm", self.llm)
             self._configure_main_llm_streaming(
@@ -500,6 +517,12 @@
                 kwargs=kwargs,
             )

+            if llm_config.parameter_mapping and llm_config.model:
+                register_llm_parameter_mapping(
+                    llm_config.engine,
+                    llm_config.model,
+                    llm_config.parameter_mapping,
+                )
             if llm_config.type == "main":
                 # If a main LLM was already injected, skip creating another
                 # one. Otherwise, create and register it.
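Note for reviewers: a minimal usage sketch of the new `nemoguardrails.llm.parameter_mapping` module, illustrating the behavior added above. The `"some-model"` name is a placeholder, not a model referenced by this change.

```python
from nemoguardrails.llm.parameter_mapping import (
    register_llm_parameter_mapping,
    transform_llm_params,
)

# Built-in provider mapping: "huggingface" renames max_tokens -> max_new_tokens.
print(transform_llm_params({"max_tokens": 64}, provider="huggingface"))
# {'max_new_tokens': 64}

# A mapping registered for a (provider, model) pair takes precedence over the
# built-in one and can drop parameters by mapping them to None.
register_llm_parameter_mapping(
    "huggingface", "some-model", {"max_tokens": "max_length", "seed": None}
)
print(
    transform_llm_params(
        {"max_tokens": 64, "seed": 7}, provider="huggingface", model_name="some-model"
    )
)
# {'max_length': 64}
```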
diff --git a/tests/test_llm_call_parameter_mapping.py b/tests/test_llm_call_parameter_mapping.py
new file mode 100644
index 000000000..d6811c1ec
--- /dev/null
+++ b/tests/test_llm_call_parameter_mapping.py
@@ -0,0 +1,130 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests for LLM parameter mapping integration in llm_call function."""
+
+from unittest.mock import AsyncMock, Mock
+
+import pytest
+
+from nemoguardrails.actions.llm.utils import llm_call
+
+
+class MockResponse:
+    """Mock response object."""
+
+    def __init__(self, content="Test response"):
+        self.content = content
+
+
+class MockHuggingFaceLLM:
+    """Mock HuggingFace LLM for testing parameter mapping."""
+
+    __module__ = "langchain_huggingface.llms"
+
+    def __init__(self):
+        self.model_name = "test-model"
+        self.bind = Mock(return_value=self)
+        self.ainvoke = AsyncMock(return_value=MockResponse())
+
+
+@pytest.mark.asyncio
+async def test_llm_call_with_registered_parameter_mapping():
+    """Test llm_call applies registered parameter mapping correctly."""
+    from nemoguardrails.llm.parameter_mapping import register_llm_parameter_mapping
+
+    mock_llm = MockHuggingFaceLLM()
+    register_llm_parameter_mapping(
+        "huggingface", "test-model", {"max_tokens": "max_new_tokens"}
+    )
+
+    result = await llm_call(
+        llm=mock_llm,
+        prompt="Test prompt",
+        llm_params={"max_tokens": 100, "temperature": 0.5},
+    )
+
+    mock_llm.bind.assert_called_once_with(max_new_tokens=100, temperature=0.5)
+    assert result == "Test response"
+
+
+@pytest.mark.asyncio
+async def test_llm_call_with_builtin_mapping():
+    """Test llm_call uses built-in provider mapping when no custom mapping provided."""
+    mock_llm = MockHuggingFaceLLM()
+
+    result = await llm_call(
+        llm=mock_llm,
+        prompt="Test prompt",
+        llm_params={"max_tokens": 50, "temperature": 0.7},
+    )
+
+    mock_llm.bind.assert_called_once_with(max_new_tokens=50, temperature=0.7)
+    assert result == "Test response"
+
+
+@pytest.mark.asyncio
+async def test_llm_call_with_dropped_parameter():
+    """Test llm_call drops parameters mapped to None."""
+    from nemoguardrails.llm.parameter_mapping import register_llm_parameter_mapping
+
+    mock_llm = MockHuggingFaceLLM()
+    register_llm_parameter_mapping(
+        "huggingface",
+        "test-model",
+        {"max_tokens": "max_new_tokens", "unsupported_param": None},
+    )
+
+    result = await llm_call(
+        llm=mock_llm,
+        prompt="Test prompt",
+        llm_params={"max_tokens": 100, "unsupported_param": "value"},
+    )
+
+    mock_llm.bind.assert_called_once_with(max_new_tokens=100)
+    assert result == "Test response"
+
+
+@pytest.mark.asyncio
+async def test_llm_call_without_params():
+    """Test llm_call works without llm_params."""
+    mock_llm = MockHuggingFaceLLM()
+
+    result = await llm_call(llm=mock_llm, prompt="Test prompt")
+
+    mock_llm.bind.assert_not_called()
+    mock_llm.ainvoke.assert_called_once()
+    assert result == "Test response"
+
+
+@pytest.mark.asyncio
+async def test_llm_call_with_stop_tokens():
+    """Test llm_call handles stop tokens correctly with parameter mapping."""
+    from nemoguardrails.llm.parameter_mapping import register_llm_parameter_mapping
+
+    mock_llm = MockHuggingFaceLLM()
+    register_llm_parameter_mapping(
+        "huggingface", "test-model", {"max_tokens": "max_new_tokens"}
+    )
+
+    result = await llm_call(
+        llm=mock_llm,
+        prompt="Test prompt",
+        stop=["END", "STOP"],
+        llm_params={"max_tokens": 100},
+    )
+
+    mock_llm.bind.assert_called_once_with(stop=["END", "STOP"], max_new_tokens=100)
+    assert result == "Test response"
diff --git a/tests/test_llmrails.py b/tests/test_llmrails.py
index 9b8a2b300..4d45b1161 100644
--- a/tests/test_llmrails.py
+++ b/tests/test_llmrails.py
@@ -14,16 +14,15 @@
 # limitations under the License.

 import os
-from typing import Any, Dict, List, Optional, Union
-from unittest.mock import MagicMock, patch
+from typing import Optional
+from unittest.mock import AsyncMock, MagicMock, Mock, patch

 import pytest
 from langchain_core.language_models import BaseChatModel

 from nemoguardrails import LLMRails, RailsConfig
+from nemoguardrails.llm.parameter_mapping import get_llm_parameter_mapping
 from nemoguardrails.logging.explain import ExplainInfo
-from nemoguardrails.rails.llm.config import Model
-from nemoguardrails.rails.llm.llmrails import get_action_details_from_flow_id
 from tests.utils import FakeLLM, clean_events, event_sequence_conforms


@@ -1187,3 +1186,85 @@ def test_explain_calls_ensure_explain_info():
     info = rails.explain()
     assert info == ExplainInfo()
     assert rails.explain_info == ExplainInfo()
+
+
+class MockHuggingFaceLLM:
+    """Mock HuggingFace LLM for testing."""
+
+    def __init__(self):
+        self.bind = Mock(return_value=self)
+        self.ainvoke = AsyncMock(return_value=Mock(content="Test response"))
+
+
+@patch("nemoguardrails.rails.llm.llmrails.init_llm_model")
+def test_parameter_mapping_registered_during_rails_initialization(mock_init_llm):
+    """Test that parameter mapping is registered when LLMRails is initialized."""
+    mock_llm = MockHuggingFaceLLM()
+    mock_init_llm.return_value = mock_llm
+
+    config = RailsConfig.from_content(
+        config={
+            "models": [
+                {
+                    "type": "main",
+                    "engine": "huggingface",
+                    "model": "test-model",
+                    "parameter_mapping": {
+                        "max_tokens": "max_new_tokens",
+                        "temperature": "temp",
+                        "unsupported_param": None,
+                    },
+                }
+            ]
+        }
+    )
+
+    rails = LLMRails(config=config)
+
+    registered_mapping = get_llm_parameter_mapping("huggingface", "test-model")
+    expected_mapping = {
+        "max_tokens": "max_new_tokens",
+        "temperature": "temp",
+        "unsupported_param": None,
+    }
+    assert registered_mapping == expected_mapping
+
+
+@patch("nemoguardrails.llm.models.initializer.init_llm_model")
+def test_parameter_mapping_with_provided_llm(mock_init_llm):
+    """Test parameter mapping when LLM is provided via constructor."""
+    mock_llm = MockHuggingFaceLLM()
+
+    config = RailsConfig.from_content(
+        config={
+            "models": [
+                {
+                    "type": "main",
+                    "engine": "huggingface",
+                    "model": "test-model",
+                    "parameter_mapping": {"max_tokens": "max_new_tokens"},
+                }
+            ]
+        }
+    )
+
+    rails = LLMRails(config=config, llm=mock_llm)
+
+    registered_mapping = get_llm_parameter_mapping("huggingface", "test-model")
+    assert registered_mapping == {"max_tokens": "max_new_tokens"}
+
+
+def test_no_parameter_mapping_in_config():
+    """Test behavior when no parameter mapping is specified."""
+    mock_llm = MockHuggingFaceLLM()
+
+    config = RailsConfig.from_content(
+        config={
+            "models": [{"type": "main", "engine": "openai", "model": "gpt-3.5-turbo"}]
+        }
+    )
+
+    rails = LLMRails(config=config, llm=mock_llm)
+
+    registered_mapping = get_llm_parameter_mapping("openai", "gpt-3.5-turbo")
+    assert registered_mapping is None
diff --git a/tests/test_parameter_mapping.py b/tests/test_parameter_mapping.py
new file mode 100644
index 000000000..4874128fb
--- /dev/null
+++ b/tests/test_parameter_mapping.py
@@ -0,0 +1,299 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests for LLM parameter mapping functionality."""
+
+from unittest.mock import Mock
+
+import pytest
+
+from nemoguardrails.llm.parameter_mapping import (
+    PROVIDER_PARAMETER_MAPPINGS,
+    get_llm_provider,
+    register_llm_parameter_mapping,
+    transform_llm_params,
+)
+
+
+class MockAnthropicLLM:
+    """Mock Anthropic LLM for testing."""
+
+    __module__ = "langchain_anthropic.chat_models"
+    model_name = "claude-3"
+
+
+class MockHuggingFacePipeline:
+    """Mock HuggingFace LLM for testing."""
+
+    __module__ = "langchain_huggingface.llms"
+    model_name = "gpt2"
+
+
+class MockGoogleVertexAI:
+    """Mock Google LLM for testing."""
+
+    __module__ = "langchain_google_vertexai.chat_models"
+    model_name = "gemini-pro"
+
+
+class MockUnknownLLM:
+    """Mock unknown LLM for testing."""
+
+    __module__ = "some_unknown_module"
+    model_name = "unknown-model"
+
+
+def test_infer_provider_from_llm_anthropic():
+    """Test provider inference for Anthropic models."""
+    llm = MockAnthropicLLM()
+    provider = get_llm_provider(llm)
+    assert provider == "anthropic"
+
+
+def test_infer_provider_from_llm_huggingface():
+    """Test provider inference for HuggingFace models."""
+    llm = MockHuggingFacePipeline()
+    provider = get_llm_provider(llm)
+    assert provider == "huggingface"
+
+
+def test_infer_provider_from_llm_google():
+    """Test provider inference for Google models."""
+    llm = MockGoogleVertexAI()
+    provider = get_llm_provider(llm)
+    assert provider == "google_vertexai"
+
+
+def test_infer_provider_from_llm_unknown():
+    """Test provider inference for unknown models."""
+    llm = MockUnknownLLM()
+    provider = get_llm_provider(llm)
+    assert provider is None
+
+
+def test_transform_llm_params_empty():
+    """Test transformation with empty parameters."""
+    result = transform_llm_params({})
+    assert result == {}
+
+
+def test_transform_llm_params_none():
+    """Test transformation with None parameters."""
+    result = transform_llm_params(None)
+    assert result is None
+
+
+def test_transform_llm_params_custom_mapping():
+    """Test transformation with custom parameter mapping."""
+    params = {"max_tokens": 100, "temperature": 0.7, "top_p": 0.9}
+    mapping = {"max_tokens": "max_new_tokens", "temperature": "temp", "top_p": None}
+
+    result = transform_llm_params(params, parameter_mapping=mapping)
+
+    expected = {"max_new_tokens": 100, "temp": 0.7}
+    assert result == expected
+
+
+def test_transform_llm_params_anthropic_builtin():
+    """Test transformation with built-in Anthropic mapping."""
+    llm = MockAnthropicLLM()
+    params = {"max_tokens": 100, "temperature": 0.7}
+    provider = get_llm_provider(llm)
+
+    result = transform_llm_params(params, provider=provider, model_name=llm.model_name)
+
+    expected = {"max_tokens": 100, "temperature": 0.7}
+    assert result == expected
+
+
+def test_transform_llm_params_huggingface_builtin():
+    """Test transformation with built-in HuggingFace mapping."""
+    llm = MockHuggingFacePipeline()
+    params = {"max_tokens": 50, "temperature": 0.5, "top_p": 0.8}
+    provider = get_llm_provider(llm)
+
+    result = transform_llm_params(params, provider=provider, model_name=llm.model_name)
+
+    expected = {"max_new_tokens": 50, "temperature": 0.5, "top_p": 0.8}
+    assert result == expected
+
+
+def test_transform_llm_params_google_builtin():
+    """Test transformation with built-in Google mapping."""
+    llm = MockGoogleVertexAI()
+    params = {"max_tokens": 200, "temperature": 1.0}
+    provider = get_llm_provider(llm)
+
+    result = transform_llm_params(params, provider=provider, model_name=llm.model_name)
+
+    expected = {"max_output_tokens": 200, "temperature": 1.0}
+    assert result == expected
+
+
+def test_transform_llm_params_unknown_provider():
+    """Test transformation with unknown provider returns unchanged params."""
+    llm = MockUnknownLLM()
+    params = {"max_tokens": 100, "temperature": 0.7}
+    provider = get_llm_provider(llm)
+
+    result = transform_llm_params(params, provider=provider, model_name=llm.model_name)
+
+    assert result == params
+
+
+def test_transform_llm_params_partial_mapping():
+    """Test transformation with partial parameter mapping."""
+    params = {"max_tokens": 100, "temperature": 0.7, "top_p": 0.9, "stop": ["END"]}
+    mapping = {"max_tokens": "max_length"}
+
+    result = transform_llm_params(params, parameter_mapping=mapping)
+
+    expected = {"max_length": 100, "temperature": 0.7, "top_p": 0.9, "stop": ["END"]}
+    assert result == expected
+
+
+def test_transform_llm_params_drop_parameter():
+    """Test dropping parameters by mapping to None."""
+    params = {"max_tokens": 100, "temperature": 0.7, "unsupported_param": "value"}
+    mapping = {"unsupported_param": None}
+
+    result = transform_llm_params(params, parameter_mapping=mapping)
+
+    expected = {"max_tokens": 100, "temperature": 0.7}
+    assert result == expected
+
+
+def test_provider_parameter_mappings_structure():
+    """Test that provider mappings have expected structure."""
+    assert "huggingface" in PROVIDER_PARAMETER_MAPPINGS
+    assert "google_vertexai" in PROVIDER_PARAMETER_MAPPINGS
+
+    assert PROVIDER_PARAMETER_MAPPINGS["huggingface"]["max_tokens"] == "max_new_tokens"
+    assert (
+        PROVIDER_PARAMETER_MAPPINGS["google_vertexai"]["max_tokens"]
+        == "max_output_tokens"
+    )
+
+
+def test_custom_mapping_overrides_builtin():
+    """Test that custom mapping overrides built-in provider mapping."""
+    params = {"max_tokens": 100}
+    custom_mapping = {"max_tokens": "custom_max_tokens"}
+
+    result = transform_llm_params(params, parameter_mapping=custom_mapping)
+
+    expected = {"custom_max_tokens": 100}
+    assert result == expected
+
+
+def test_registered_mapping_used_in_transform():
+    """Test that registered mapping is used automatically in transform_llm_params."""
+    llm = MockUnknownLLM()
+    params = {"max_tokens": 100, "temperature": 0.7}
+    mapping = {"max_tokens": "max_length", "temperature": "temp"}
+    provider = get_llm_provider(llm)
+
+    result = transform_llm_params(params, provider=provider, model_name=llm.model_name)
+    assert result == params
+
+    register_llm_parameter_mapping(provider, llm.model_name, mapping)
+
+    result = transform_llm_params(params, provider=provider, model_name=llm.model_name)
+    expected = {"max_length": 100, "temp": 0.7}
+    assert result == expected
+
+
+def test_registered_mapping_overrides_builtin():
+    """Test that registered mapping overrides built-in provider mapping."""
+    llm = MockHuggingFacePipeline()
+    params = {"max_tokens": 100}
+    provider = get_llm_provider(llm)
+
+    result = transform_llm_params(params, provider=provider, model_name=llm.model_name)
+    assert result == {"max_new_tokens": 100}
+
+    custom_mapping = {"max_tokens": "custom_max_tokens"}
+    register_llm_parameter_mapping(provider, llm.model_name, custom_mapping)
+
+    result = transform_llm_params(params, provider=provider, model_name=llm.model_name)
+    assert result == {"custom_max_tokens": 100}
+
+
+def test_infer_provider_community_models():
+    """Test provider inference for community models."""
+
+    class MockCommunityOllama:
+        __module__ = "langchain_community.chat_models.ollama"
+
+    class MockCommunityGooglePalm:
+        __module__ = "langchain_community.chat_models.google_palm"
+
+    ollama_llm = MockCommunityOllama()
+    provider = get_llm_provider(ollama_llm)
+    assert provider == "ollama"
+
+    palm_llm = MockCommunityGooglePalm()
+    provider = get_llm_provider(palm_llm)
+    assert provider == "google_palm"
+
+
+def test_infer_provider_google_variants():
+    """Test provider inference for different Google provider variants."""
+
+    class MockGoogleGenAI:
+        __module__ = "langchain_google_genai.chat_models"
+
+    class MockGoogleVertexAI:
+        __module__ = "langchain_google_vertexai.chat_models"
+
+    genai_llm = MockGoogleGenAI()
+    provider = get_llm_provider(genai_llm)
+    assert provider == "google_genai"
+
+    vertexai_llm = MockGoogleVertexAI()
+    provider = get_llm_provider(vertexai_llm)
+    assert provider == "google_vertexai"
+
+
+def test_transform_params_google_genai():
+    """Test parameter transformation for google_genai provider."""
+
+    class MockGoogleGenAI:
+        __module__ = "langchain_google_genai.chat_models"
+
+    llm = MockGoogleGenAI()
+    params = {"max_tokens": 150, "temperature": 0.8}
+
+    result = transform_llm_params(params, provider=get_llm_provider(llm))
+
+    expected = {"max_tokens": 150, "temperature": 0.8}
+    assert result == expected
+
+
+def test_transform_params_google_vertexai():
+    """Test parameter transformation for google_vertexai provider."""
+
+    class MockGoogleVertexAI:
+        __module__ = "langchain_google_vertexai.chat_models"
+        model_name = "gemini-pro"
+
+    llm = MockGoogleVertexAI()
+    params = {"max_tokens": 200, "temperature": 0.9}
+    provider = get_llm_provider(llm)
+
+    result = transform_llm_params(params, provider=provider, model_name=llm.model_name)
+
+    expected = {"max_output_tokens": 200, "temperature": 0.9}
+    assert result == expected
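As a usage note, the configuration path exercised by the tests above can be sketched as follows. This is an illustrative example only: `test-model` is a placeholder name, and actually initializing the model requires the corresponding provider package to be installed.

```python
from nemoguardrails import RailsConfig

# Mirrors the test configs above: parameter_mapping is declared per model.
config = RailsConfig.from_content(
    config={
        "models": [
            {
                "type": "main",
                "engine": "huggingface",
                "model": "test-model",
                "parameters": {"max_tokens": 100},
                "parameter_mapping": {
                    "max_tokens": "max_new_tokens",  # rename for this provider
                    "unsupported_param": None,  # a null/None value drops the parameter
                },
            }
        ]
    }
)

# Constructing LLMRails(config=config) registers the mapping during _init_llms,
# so subsequent llm_call invocations bind max_new_tokens=100 instead of max_tokens.
```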