Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions mlx_lm/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
trim_prompt_cache,
)
from .sample_utils import make_logits_processors, make_sampler
from .tool_call_dedup import ToolCallDedup
from .utils import _parse_size, load, sharded_load


Expand Down Expand Up @@ -1507,6 +1508,7 @@ def keepalive_callback(processed_tokens, total_tokens):
made_tool_call = False
tool_calls = []
tool_text = ""
_dedup = ToolCallDedup()
tool_idx = 0

def format_tool_call(tool_call):
Expand Down Expand Up @@ -1575,6 +1577,12 @@ def parse_tools(tool_calls):
in_tool_call = True
elif in_tool_call:
if gen.text == ctx.tool_call_end:
if _dedup.is_duplicate(tool_text):
finish_reason = "tool_calls"
ctx.stop()
tool_text = ""
in_tool_call = False
break
tool_calls.append(tool_text)
tool_text = ""
in_tool_call = False
Expand Down
46 changes: 46 additions & 0 deletions mlx_lm/tool_call_dedup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
"""Consecutive duplicate tool call detection.

Tracks raw tool call text as it is appended during generation.
When the same text appears consecutively, signals the caller to stop
generation early — preventing degenerate loops where the model
produces identical tool calls until max_tokens.

See: https://github.com/ml-explore/mlx-lm/issues/613
"""

from __future__ import annotations

import logging

logger = logging.getLogger(__name__)

# Maximum number of characters of a tool call echoed into a log record.
_MAX_LOG_LEN = 120


class ToolCallDedup:
    """Detect consecutive duplicate tool calls during generation.

    The detector remembers the raw text of the most recent call it accepted
    and reports a duplicate when the next call is character-for-character
    identical. Only *consecutive* repeats count: an A-B-A sequence is not
    flagged, and texts differing only in whitespace are treated as distinct.

    Usage::

        dedup = ToolCallDedup()
        # After each tool_call_end token:
        if dedup.is_duplicate(tool_text):
            # stop generation
            ...
        else:
            tool_calls.append(tool_text)
    """

    def __init__(self) -> None:
        # Raw text of the last non-duplicate call; None until the first call.
        self._prev: str | None = None

    def is_duplicate(self, tool_text: str) -> bool:
        """Return True if *tool_text* matches the previous call exactly.

        A non-matching *tool_text* is remembered as the new reference for the
        next check. When a duplicate is found the stored reference is left
        untouched (so further identical calls keep being reported) and a
        warning is logged.

        Args:
            tool_text: Raw text of the tool call that was just completed.

        Returns:
            True when *tool_text* equals the previously remembered call,
            False otherwise (including on the very first call).
        """
        if self._prev is not None and tool_text == self._prev:
            preview = tool_text[:_MAX_LOG_LEN]
            if len(tool_text) > _MAX_LOG_LEN:
                # Make it visible in the log that the text was cut short.
                preview += "..."
            # %r escapes newlines/control characters in model-generated text
            # so a single warning always stays on a single log line.
            logger.warning(
                "Consecutive duplicate tool call detected, stopping: %r",
                preview,
            )
            return True
        self._prev = tool_text
        return False
70 changes: 70 additions & 0 deletions tests/test_tool_call_dedup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
"""Unit tests for ToolCallDedup."""

import logging
import unittest

from mlx_lm.tool_call_dedup import ToolCallDedup


class TestToolCallDedup(unittest.TestCase):
    """Behavioural checks for consecutive-repeat detection in ToolCallDedup."""

    def setUp(self):
        # Each test gets its own detector so no state leaks between cases.
        self.detector = ToolCallDedup()

    def test_first_call_never_duplicate(self):
        """The very first call has nothing to be compared against."""
        verdict = self.detector.is_duplicate('{"name": "run", "arguments": {}}')
        self.assertFalse(verdict)

    def test_different_calls_not_duplicate(self):
        """Two calls with different arguments are both accepted."""
        payloads = (
            '{"name": "run", "arguments": {"cmd": "ls"}}',
            '{"name": "run", "arguments": {"cmd": "pwd"}}',
        )
        for payload in payloads:
            self.assertFalse(self.detector.is_duplicate(payload))

    def test_consecutive_identical_is_duplicate(self):
        """Repeating the same text back-to-back is reported."""
        payload = '{"name": "run", "arguments": {"cmd": "ls"}}'
        first = self.detector.is_duplicate(payload)
        second = self.detector.is_duplicate(payload)
        self.assertFalse(first)
        self.assertTrue(second)

    def test_non_consecutive_identical_not_duplicate(self):
        """A-B-A pattern should NOT trigger (only consecutive)."""
        call_a = '{"name": "run", "arguments": {"cmd": "ls"}}'
        call_b = '{"name": "run", "arguments": {"cmd": "pwd"}}'
        # The final call_a follows call_b, so it is not a consecutive repeat.
        for payload in (call_a, call_b, call_a):
            self.assertFalse(self.detector.is_duplicate(payload))

    def test_whitespace_difference_not_duplicate(self):
        """Exact text comparison: whitespace matters."""
        compact = '{"name":"run"}'
        spaced = '{"name": "run"}'
        self.assertFalse(self.detector.is_duplicate(compact))
        self.assertFalse(self.detector.is_duplicate(spaced))

    def test_logs_warning_on_duplicate(self):
        """A detected duplicate emits a WARNING mentioning 'duplicate'."""
        payload = '{"name": "run", "arguments": {}}'
        self.detector.is_duplicate(payload)
        with self.assertLogs("mlx_lm.tool_call_dedup", level="WARNING") as captured:
            self.detector.is_duplicate(payload)
        mentions = ["duplicate" in line.lower() for line in captured.output]
        self.assertTrue(any(mentions))

    def test_prev_not_updated_on_duplicate(self):
        """After duplicate detected, prev stays the same for next check."""
        payload = '{"name": "run", "arguments": {}}'
        self.detector.is_duplicate(payload)
        # Both the second and the third identical call must be flagged.
        self.assertTrue(self.detector.is_duplicate(payload))
        self.assertTrue(self.detector.is_duplicate(payload))

    def test_prev_updates_on_new_call(self):
        """A new, different call becomes the reference for the next check."""
        call_a = '{"name": "a"}'
        call_b = '{"name": "b"}'
        self.detector.is_duplicate(call_a)
        self.detector.is_duplicate(call_b)
        # call_b is now the remembered text, so repeating it is a duplicate.
        self.assertTrue(self.detector.is_duplicate(call_b))


if __name__ == "__main__":
unittest.main()
Loading