Commit 9bd0122

Add version filter for Llama parallel tool calling
Only Llama 4+ models support parallel tool calling, based on testing.

Parallel tool calling support:
- Llama 4+ - SUPPORTED (tested and verified with real OCI API)
- All Llama 3.x (3.0, 3.1, 3.2, 3.3) - BLOCKED
- Cohere - BLOCKED (existing behavior)
- Other models (xAI Grok, OpenAI, Mistral) - SUPPORTED

Implementation:
- Added a _supports_parallel_tool_calls() helper method with regex version parsing
- Updated bind_tools() to validate the model version before enabling parallel calls
- Raises a clear error message: "only available for Llama 4+ models"

Unit tests added (8 tests, all mocked, no OCI connection):
- test_version_filter_llama_3_0_blocked
- test_version_filter_llama_3_1_blocked
- test_version_filter_llama_3_2_blocked
- test_version_filter_llama_3_3_blocked (Llama 3.3 doesn't support it either)
- test_version_filter_llama_4_allowed
- test_version_filter_other_models_allowed
- test_version_filter_supports_parallel_tool_calls_method
- Plus existing parallel tool calling tests updated to use Llama 4
1 parent cf65baa · commit 9bd0122
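To illustrate the behavior this commit introduces, here is a minimal, hypothetical usage sketch. It assumes ChatOCIGenAI is importable from the module changed below; the get_weather tool and the mocked client are stand-ins, mirroring the setup used in the commit's unit tests:

from unittest.mock import MagicMock

from langchain_oci.chat_models.oci_generative_ai import ChatOCIGenAI

def get_weather(city: str) -> str:
    """Hypothetical example tool."""
    return f"Sunny in {city}"

# Llama 4+: parallel tool calling binds as before.
llm_v4 = ChatOCIGenAI(
    model_id="meta.llama-4-maverick-17b-128e-instruct-fp8",
    client=MagicMock(),  # mocked client, no OCI connection needed
)
llm_v4.bind_tools([get_weather], parallel_tool_calls=True)  # OK

# Llama 3.x: bind_tools now fails fast instead of sending a request
# the model cannot honor.
llm_v3 = ChatOCIGenAI(
    model_id="meta.llama-3.3-70b-instruct",
    client=MagicMock(),
)
try:
    llm_v3.bind_tools([get_weather], parallel_tool_calls=True)
except ValueError as err:
    print(err)  # "...only available for Llama 4+ models..."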

File tree

2 files changed: +193, -7 lines


libs/oci/langchain_oci/chat_models/oci_generative_ai.py

Lines changed: 55 additions & 0 deletions

@@ -1200,6 +1200,49 @@ def _prepare_request(
 
         return request
 
+    def _supports_parallel_tool_calls(self, model_id: str) -> bool:
+        """Check if the model supports parallel tool calling.
+
+        Parallel tool calling is supported for:
+        - Llama 4+ only (tested and verified)
+        - Other GenericChatRequest models (xAI Grok, OpenAI, Mistral)
+
+        Not supported for:
+        - All Llama 3.x versions (3.0, 3.1, 3.2, 3.3)
+        - Cohere models
+
+        Args:
+            model_id: The model identifier (e.g., "meta.llama-4-maverick-17b-128e-instruct-fp8")
+
+        Returns:
+            bool: True if the model supports parallel tool calling, False otherwise.
+        """
+        import re
+
+        # Extract provider from model_id (e.g., "meta" from "meta.llama-4-maverick-17b-128e-instruct-fp8")
+        provider = model_id.split(".")[0].lower()
+
+        # Cohere models don't support parallel tool calling
+        if provider == "cohere":
+            return False
+
+        # For Meta/Llama models, check the version
+        if provider == "meta" and "llama" in model_id.lower():
+            # Extract the version number (e.g., "4" from "meta.llama-4-maverick-17b-128e-instruct-fp8")
+            version_match = re.search(r"llama-(\d+)", model_id.lower())
+            if version_match:
+                major = int(version_match.group(1))
+
+                # Only Llama 4+ supports parallel tool calling;
+                # Llama 3.x (including 3.3) does NOT support it, based on testing
+                if major >= 4:
+                    return True
+
+            return False
+
+        # Other GenericChatRequest models (xAI Grok, OpenAI, Mistral) support it
+        return True
+
     def bind_tools(
         self,
         tools: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]],
@@ -1251,6 +1294,18 @@ def bind_tools(
             else self.parallel_tool_calls
         )
         if use_parallel:
+            # Validate that the model supports parallel tool calling
+            if not self._supports_parallel_tool_calls(self.model_id):
+                if "llama" in self.model_id.lower():
+                    raise ValueError(
+                        f"Parallel tool calls are not supported for {self.model_id}. "
+                        "This feature is only available for Llama 4+ models. "
+                        "Llama 3.x models (including 3.3) do not support parallel tool calling."
+                    )
+                else:
+                    raise ValueError(
+                        f"Parallel tool calls are not supported for {self.model_id}."
+                    )
             kwargs["is_parallel_tool_calls"] = True
 
         return super().bind(tools=formatted_tools, **kwargs)
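The whole version gate above rests on a single regex over the lowercased model ID. A standalone sketch of just that parsing, runnable without OCI (the llama_major_version helper is hypothetical; the model IDs are taken from this diff):

import re

def llama_major_version(model_id: str):
    """Return the major Llama version parsed from an OCI model ID, or None."""
    match = re.search(r"llama-(\d+)", model_id.lower())
    return int(match.group(1)) if match else None

assert llama_major_version("meta.llama-4-maverick-17b-128e-instruct-fp8") == 4
assert llama_major_version("meta.llama-3.3-70b-instruct") == 3  # every 3.x parses to major 3
assert llama_major_version("cohere.command-r-plus") is None  # no "llama-" segment

Because the regex captures only the digits before the first dot or dash after "llama-", every 3.x variant collapses to major version 3, which is what lets one check block 3.0 through 3.3 uniformly.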

libs/oci/tests/unit_tests/chat_models/test_parallel_tool_calling.py

Lines changed: 138 additions & 7 deletions

@@ -11,7 +11,7 @@ def test_parallel_tool_calls_class_level():
     """Test class-level parallel_tool_calls parameter."""
     oci_gen_ai_client = MagicMock()
     llm = ChatOCIGenAI(
-        model_id="meta.llama-3.3-70b-instruct",
+        model_id="meta.llama-4-maverick-17b-128e-instruct-fp8",
         parallel_tool_calls=True,
         client=oci_gen_ai_client
     )
@@ -23,7 +23,7 @@ def test_parallel_tool_calls_default_false():
     """Test that parallel_tool_calls defaults to False."""
     oci_gen_ai_client = MagicMock()
     llm = ChatOCIGenAI(
-        model_id="meta.llama-3.3-70b-instruct",
+        model_id="meta.llama-4-maverick-17b-128e-instruct-fp8",
         client=oci_gen_ai_client
     )
     assert llm.parallel_tool_calls is False
@@ -34,7 +34,7 @@ def test_parallel_tool_calls_bind_tools_explicit_true():
     """Test parallel_tool_calls=True in bind_tools."""
     oci_gen_ai_client = MagicMock()
     llm = ChatOCIGenAI(
-        model_id="meta.llama-3.3-70b-instruct",
+        model_id="meta.llama-4-maverick-17b-128e-instruct-fp8",
         client=oci_gen_ai_client
     )
 
@@ -59,7 +59,7 @@ def test_parallel_tool_calls_bind_tools_explicit_false():
     """Test parallel_tool_calls=False in bind_tools."""
     oci_gen_ai_client = MagicMock()
     llm = ChatOCIGenAI(
-        model_id="meta.llama-3.3-70b-instruct",
+        model_id="meta.llama-4-maverick-17b-128e-instruct-fp8",
         client=oci_gen_ai_client
     )
 
@@ -81,7 +81,7 @@ def test_parallel_tool_calls_bind_tools_uses_class_default():
     """Test that bind_tools uses class default when not specified."""
     oci_gen_ai_client = MagicMock()
     llm = ChatOCIGenAI(
-        model_id="meta.llama-3.3-70b-instruct",
+        model_id="meta.llama-4-maverick-17b-128e-instruct-fp8",
         parallel_tool_calls=True,  # Set class default
         client=oci_gen_ai_client
     )
@@ -102,7 +102,7 @@ def test_parallel_tool_calls_bind_tools_overrides_class_default():
     """Test that bind_tools parameter overrides class default."""
     oci_gen_ai_client = MagicMock()
     llm = ChatOCIGenAI(
-        model_id="meta.llama-3.3-70b-instruct",
+        model_id="meta.llama-4-maverick-17b-128e-instruct-fp8",
         parallel_tool_calls=True,  # Set class default to True
         client=oci_gen_ai_client
     )
@@ -125,7 +125,7 @@ def test_parallel_tool_calls_passed_to_oci_api_meta():
 
     oci_gen_ai_client = MagicMock()
     llm = ChatOCIGenAI(
-        model_id="meta.llama-3.3-70b-instruct",
+        model_id="meta.llama-4-maverick-17b-128e-instruct-fp8",
         client=oci_gen_ai_client
     )
 
@@ -197,3 +197,134 @@ def tool1(x: int) -> int:
         stream=False,
         **llm_with_tools.kwargs
     )
+
+
+@pytest.mark.requires("oci")
+def test_version_filter_llama_3_0_blocked():
+    """Test that Llama 3.0 models are blocked from parallel tool calling."""
+    oci_gen_ai_client = MagicMock()
+    llm = ChatOCIGenAI(
+        model_id="meta.llama-3-70b-instruct",
+        client=oci_gen_ai_client
+    )
+
+    def tool1(x: int) -> int:
+        """Tool 1."""
+        return x + 1
+
+    # Should raise ValueError when trying to enable parallel tool calling
+    with pytest.raises(ValueError, match="Llama 4\\+"):
+        llm.bind_tools([tool1], parallel_tool_calls=True)
+
+
+@pytest.mark.requires("oci")
+def test_version_filter_llama_3_1_blocked():
+    """Test that Llama 3.1 models are blocked from parallel tool calling."""
+    oci_gen_ai_client = MagicMock()
+    llm = ChatOCIGenAI(
+        model_id="meta.llama-3.1-70b-instruct",
+        client=oci_gen_ai_client
+    )
+
+    def tool1(x: int) -> int:
+        """Tool 1."""
+        return x + 1
+
+    # Should raise ValueError
+    with pytest.raises(ValueError, match="Llama 4\\+"):
+        llm.bind_tools([tool1], parallel_tool_calls=True)
+
+
+@pytest.mark.requires("oci")
+def test_version_filter_llama_3_2_blocked():
+    """Test that Llama 3.2 models are blocked from parallel tool calling."""
+    oci_gen_ai_client = MagicMock()
+    llm = ChatOCIGenAI(
+        model_id="meta.llama-3.2-11b-vision-instruct",
+        client=oci_gen_ai_client
+    )
+
+    def tool1(x: int) -> int:
+        """Tool 1."""
+        return x + 1
+
+    # Should raise ValueError
+    with pytest.raises(ValueError, match="Llama 4\\+"):
+        llm.bind_tools([tool1], parallel_tool_calls=True)
+
+
+@pytest.mark.requires("oci")
+def test_version_filter_llama_3_3_blocked():
+    """Test that Llama 3.3 models are blocked from parallel tool calling."""
+    oci_gen_ai_client = MagicMock()
+    llm = ChatOCIGenAI(
+        model_id="meta.llama-3.3-70b-instruct",
+        client=oci_gen_ai_client
+    )
+
+    def tool1(x: int) -> int:
+        """Tool 1."""
+        return x + 1
+
+    # Should raise ValueError - Llama 3.3 doesn't actually support parallel calls
+    with pytest.raises(ValueError, match="Llama 4\\+"):
+        llm.bind_tools([tool1], parallel_tool_calls=True)
+
+
+@pytest.mark.requires("oci")
+def test_version_filter_llama_4_allowed():
+    """Test that Llama 4 models are allowed parallel tool calling."""
+    oci_gen_ai_client = MagicMock()
+    llm = ChatOCIGenAI(
+        model_id="meta.llama-4-maverick-17b-128e-instruct-fp8",
+        client=oci_gen_ai_client
+    )
+
+    def tool1(x: int) -> int:
+        """Tool 1."""
+        return x + 1
+
+    # Should NOT raise ValueError
+    llm_with_tools = llm.bind_tools([tool1], parallel_tool_calls=True)
+    assert llm_with_tools.kwargs.get("is_parallel_tool_calls") is True
+
+
+@pytest.mark.requires("oci")
+def test_version_filter_other_models_allowed():
+    """Test that other GenericChatRequest models are allowed parallel tool calling."""
+    oci_gen_ai_client = MagicMock()
+
+    # Test with xAI Grok
+    llm_grok = ChatOCIGenAI(
+        model_id="xai.grok-4-fast",
+        client=oci_gen_ai_client
+    )
+
+    def tool1(x: int) -> int:
+        """Tool 1."""
+        return x + 1
+
+    # Should NOT raise ValueError for Grok
+    llm_with_tools = llm_grok.bind_tools([tool1], parallel_tool_calls=True)
+    assert llm_with_tools.kwargs.get("is_parallel_tool_calls") is True
+
+
+@pytest.mark.requires("oci")
+def test_version_filter_supports_parallel_tool_calls_method():
+    """Test the _supports_parallel_tool_calls method directly."""
+    oci_gen_ai_client = MagicMock()
+    llm = ChatOCIGenAI(
+        model_id="meta.llama-4-maverick-17b-128e-instruct-fp8",
+        client=oci_gen_ai_client
+    )
+
+    # Test various model IDs
+    assert llm._supports_parallel_tool_calls("meta.llama-4-maverick-17b-128e-instruct-fp8") is True
+    assert llm._supports_parallel_tool_calls("meta.llama-3.3-70b-instruct") is False  # Llama 3.3 NOT supported
+    assert llm._supports_parallel_tool_calls("meta.llama-3.2-11b-vision-instruct") is False
+    assert llm._supports_parallel_tool_calls("meta.llama-3.1-70b-instruct") is False
+    assert llm._supports_parallel_tool_calls("meta.llama-3-70b-instruct") is False
+    assert llm._supports_parallel_tool_calls("cohere.command-r-plus") is False
+    assert llm._supports_parallel_tool_calls("xai.grok-4-fast") is True
+    assert llm._supports_parallel_tool_calls("openai.gpt-4") is True
+    assert llm._supports_parallel_tool_calls("mistral.mistral-large") is True
