Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 29 additions & 24 deletions aider/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -962,33 +962,38 @@ def get_io(pretty):
model_overrides[model_name] = {}
model_overrides[model_name].update(tags)

# Parse model names with suffixes and apply overrides
def parse_model_with_suffix(model_name, overrides):
"""Parse model name with optional :suffix and apply overrides."""
# Flatten the nested overrides mapping into a lookup table keyed by the full
# "base:suffix" string, so a user-supplied model name can be matched verbatim
# instead of being split on ":" at lookup time.
# Entries whose suffix table or per-suffix config is not a dict are skipped;
# when two entries yield the same key, the one encountered later wins.
override_index = {
    f"{base_model}:{suffix}": (base_model, cfg)
    for base_model, suffixes in model_overrides.items()
    if isinstance(suffixes, dict)
    for suffix, cfg in suffixes.items()
    if isinstance(cfg, dict)
}

def apply_model_overrides(model_name):
    """Resolve ``model_name`` against the configured "base:suffix" overrides.

    Returns a ``(effective_model_name, override_kwargs)`` tuple.

    When ``model_name`` is exactly one of the keys in ``override_index``, the
    result is the base model name together with a shallow copy of that
    entry's override dict. In every other case (including a falsy name such
    as ``None`` or ``""``) the name is handed back untouched with an empty
    override dict.
    """
    # Falsy names never reach the index lookup.
    matched = override_index.get(model_name) if model_name else None
    if matched:
        base_model, cfg = matched
        # Copy so callers can modify the returned kwargs without touching
        # the dict stored in the index.
        return base_model, cfg.copy()
    return model_name, {}

# Split on last colon to get model name and suffix
if ":" in model_name:
base_model, suffix = model_name.rsplit(":", 1)
else:
base_model, suffix = model_name, None

# Apply overrides if suffix exists
override_kwargs = {}
if suffix and base_model in overrides and suffix in overrides[base_model]:
override_kwargs = overrides[base_model][suffix].copy()

return base_model, override_kwargs

# Parse main model
main_model_name, main_model_overrides = parse_model_with_suffix(args.model, model_overrides)
weak_model_name, weak_model_overrides = parse_model_with_suffix(
args.weak_model, model_overrides
)
editor_model_name, editor_model_overrides = parse_model_with_suffix(
args.editor_model, model_overrides
)
# Apply overrides (if any) to the selected models
main_model_name, main_model_overrides = apply_model_overrides(args.model)
weak_model_name, weak_model_overrides = apply_model_overrides(args.weak_model)
editor_model_name, editor_model_overrides = apply_model_overrides(args.editor_model)

# Create weak model if specified with overrides
weak_model_obj = None
Expand Down
97 changes: 97 additions & 0 deletions tests/basic/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -1151,6 +1151,103 @@ async def test_model_precedence(self):
del os.environ["ANTHROPIC_API_KEY"]
del os.environ["OPENAI_API_KEY"]

async def test_model_overrides_suffix_applied(self):
    """A "base:suffix" model name resolves to the base model plus overrides.

    Writes an overrides file declaring a ``fast`` suffix for ``gpt-4o``, runs
    ``main`` with ``--model gpt-4o:fast`` and checks that ``Model`` was
    constructed at least once with the base name ``"gpt-4o"`` and
    ``override_kwargs == {"temperature": 0.1}``.
    """
    with GitTemporaryDirectory() as git_dir:
        git_dir = Path(git_dir)
        overrides_file = git_dir / ".aider.model.overrides.yml"
        # The nesting is significant: ``temperature`` must be a child of
        # ``fast`` and ``fast`` a child of ``gpt-4o``. With all keys at the
        # same indentation ``fast`` would map to null and no "gpt-4o:fast"
        # override could ever be produced.
        overrides_file.write_text("gpt-4o:\n  fast:\n    temperature: 0.1\n")

        with (
            patch("aider.models.Model") as MockModel,
            patch("aider.coders.Coder.create") as MockCoder,
        ):
            MockCoder.return_value = MagicMock()

            # Give the mocked model the attributes this test relies on
            # being present during startup.
            model_mock = MockModel.return_value
            model_mock.name = "gpt-4o"
            model_mock.info = {}
            model_mock.accepts_settings = []
            model_mock.weak_model_name = None
            model_mock.get_weak_model.return_value = None
            model_mock.validate_environment.return_value = {
                "missing_keys": [],
                "keys_in_environment": [],
            }

            await main(
                ["--model", "gpt-4o:fast", "--exit", "--yes", "--no-git"],
                input=DummyInput(),
                output=DummyOutput(),
                force_git_root=git_dir,
            )

            # Model may be constructed several times; it is enough that one
            # of those calls used the base name with the expected overrides.
            expected_overrides = {"temperature": 0.1}
            matched_call_found = any(
                args
                and args[0] == "gpt-4o"
                and kwargs.get("override_kwargs") == expected_overrides
                for args, kwargs in MockModel.call_args_list
            )

            self.assertTrue(
                matched_call_found,
                (
                    "Expected a Model call with base name 'gpt-4o' and override_kwargs"
                    " {'temperature': 0.1}"
                ),
            )

async def test_model_overrides_no_match_preserves_model_name(self):
    """A model name containing ":" is kept intact when no override matches.

    No overrides file is written here, so ``main`` is expected to construct
    ``Model`` at least once with the full, unsplit model name and an empty
    ``override_kwargs`` dict.
    """
    with GitTemporaryDirectory() as git_dir:
        git_dir = Path(git_dir)

        with (
            patch("aider.models.Model") as MockModel,
            patch("aider.coders.Coder.create") as MockCoder,
        ):
            MockCoder.return_value = MagicMock()

            # Give the mocked model the attributes this test relies on
            # being present during startup.
            model_mock = MockModel.return_value
            model_mock.name = "test-model"
            model_mock.info = {}
            model_mock.accepts_settings = []
            model_mock.weak_model_name = None
            model_mock.get_weak_model.return_value = None
            model_mock.validate_environment.return_value = {
                "missing_keys": [],
                "keys_in_environment": [],
            }

            # The part before the colon is a prefix, not a "base" model with
            # a suffix, so the name must reach Model unchanged.
            model_name = "hf:moonshotai/Kimi-K2-Thinking"

            await main(
                ["--model", model_name, "--exit", "--yes", "--no-git"],
                input=DummyInput(),
                output=DummyOutput(),
                force_git_root=git_dir,
            )

            # Any one matching construction call is sufficient.
            matched_call_found = any(
                args and args[0] == model_name and kwargs.get("override_kwargs") == {}
                for args, kwargs in MockModel.call_args_list
            )

            self.assertTrue(
                matched_call_found,
                (
                    "Expected a Model call with the full model name preserved and empty"
                    " override_kwargs"
                ),
            )

async def test_chat_language_spanish(self):
with GitTemporaryDirectory():
coder = await main(
Expand Down
Loading