diff --git a/mlx_lm/tool_parsers/json_tools.py b/mlx_lm/tool_parsers/json_tools.py index 27a9caa44..1f8a4fb19 100644 --- a/mlx_lm/tool_parsers/json_tools.py +++ b/mlx_lm/tool_parsers/json_tools.py @@ -1,6 +1,6 @@ # Copyright © 2025 Apple Inc. -import json +import json_repair tool_call_start = "" @@ -8,4 +8,4 @@ def parse_tool_call(text, tools=None): - return json.loads(text.strip()) + return json_repair.loads(text.strip()) diff --git a/setup.py b/setup.py index b4d45da6b..f6d5c575b 100644 --- a/setup.py +++ b/setup.py @@ -31,6 +31,7 @@ "protobuf", "pyyaml", "jinja2", + "json-repair>=0.58.7", ], packages=[ "mlx_lm", diff --git a/tests/test_tool_parsing.py b/tests/test_tool_parsing.py index 08d452dc8..cd133ea0b 100644 --- a/tests/test_tool_parsing.py +++ b/tests/test_tool_parsing.py @@ -148,6 +148,74 @@ def test_parsers(self): } self.assertEqual(tool_call, expected) + def test_json_tools_repairs_malformed_json(self): + tools_multiply = [ + { + "type": "function", + "function": { + "name": "multiply", + "description": "Multiply two numbers.", + "parameters": { + "type": "object", + "required": ["a", "b"], + "properties": { + "a": {"type": "number", "description": "a is a number"}, + "b": {"type": "number", "description": "b is a number"}, + }, + }, + }, + } + ] + expected_multiply = { + "name": "multiply", + "arguments": {"a": 12234585, "b": 48838483920}, + } + malformed_multiply = [ + ( + "trailing_comma", + '{"name": "multiply", "arguments": {"a": 12234585, "b": 48838483920,},}', + ), + ( + "single_quoted", + "{'name': 'multiply', 'arguments': {'a': 12234585, 'b': 48838483920}}", + ), + ] + for label, text in malformed_multiply: + with self.subTest(case=label): + self.assertEqual( + json_tools.parse_tool_call(text, tools_multiply), + expected_multiply, + ) + + tools_temp = [ + { + "type": "function", + "function": { + "name": "get_current_temperature", + "description": "Get the current temperature.", + "parameters": { + "type": "object", + "required": ["location"], + "properties": { + "location": {"type": "str", "description": "The location."}, + }, + }, + }, + } + ] + expected_temp = { + "name": "get_current_temperature", + "arguments": {"location": "London"}, + } + text = ( + '{"name": "get_current_temperature", ' + '"arguments": {"location": "London",},}' + ) + self.assertEqual( + json_tools.parse_tool_call(text, tools_temp), + expected_temp, + ) + def test_qwen3_coder_single_quoted_params(self): tools = [ {