forked from madroidmaq/mlx-omni-server
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathfunction_calling_benchmark.py
More file actions
101 lines (82 loc) · 3.04 KB
/
function_calling_benchmark.py
File metadata and controls
101 lines (82 loc) · 3.04 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import asyncio
import json
import weave
from datasets import load_dataset
from openai import OpenAI
from tqdm import tqdm
# Register this run under a Weave project so every prediction is traced.
weave.init("mlx-omni-function-calling-benchmark")
# OpenAI-compatible client pointed at a locally running mlx-omni-server.
# Swap base_url to the commented value to benchmark an Ollama server instead.
client = OpenAI(
    base_url="http://localhost:10240/v1",
    # base_url="http://localhost:11434/v1",
    api_key="mlx-omni-server",  # not-needed
)
class FunctionCallingModel(weave.Model):
    """Weave model that sends one benchmark example to the chat endpoint and
    grades whether the server produced the expected tool call."""

    # Model identifier forwarded to the OpenAI-compatible server.
    model_name: str

    @weave.op()
    def predict(self, messages, tools, tool_calls, target_name) -> dict:
        """Run one chat completion and compare the first generated tool call
        against the expected function name.

        Args:
            messages: Chat messages for the request.
            tools: Tool/function schemas offered to the model.
            tool_calls: Ground-truth tool calls from the dataset row.
            target_name: Name of the function the model is expected to call.

        Returns:
            A result dict consumed by ``tool_call_score``; every branch now
            includes both the "successful" and "is_tool_call" keys.
        """
        response = client.chat.completions.create(
            model=self.model_name, messages=messages, tools=tools, tool_choice="auto"
        )
        message = response.choices[0].message

        if not tool_calls:
            # Row has no ground-truth calls: nothing to grade.
            # FIX: include "is_tool_call" so the scorer's lookup never
            # raises KeyError on this branch.
            return {
                "expected": [],
                "function_name": None,
                "successful": False,
                "is_tool_call": False,
                "message": "no expected tool calls",
            }

        if not message.tool_calls:
            # Model answered with plain text instead of emitting a tool call.
            return {
                "expected": tool_calls,
                "function_name": target_name,
                "successful": False,
                "is_tool_call": False,
                "content": message.content,
                "message": f"not a tool call, message content: {message.content}",
            }

        # Only the first generated call is graded, mirroring the single
        # expected call per dataset example.
        actual_calls = message.tool_calls[0]
        return {
            "expected": tool_calls,
            "actual_calls": actual_calls,
            "is_tool_call": True,
            "function_name": target_name,
            "successful": actual_calls.function.name == target_name,
            "message": f"gen tool calls, expected {target_name} but actual calls is {actual_calls.function.name}",
        }
@weave.op()
def tool_call_score(output: dict) -> dict:
    """Score one prediction: did the model emit a tool call, and did the
    called function name match the expected one?

    FIX: use ``dict.get`` for "is_tool_call" as well — results from the
    "no expected tool calls" branch of ``predict`` omit that key, and the
    original ``output["is_tool_call"]`` lookup raised KeyError there.
    """
    return {
        "is_matched": bool(output.get("successful", False)),
        "is_tool_call": bool(output.get("is_tool_call", False)),
    }
def run_eval(model_name: str):
    """Evaluate *model_name* on the glaive function-calling test split.

    Builds one evaluation example per dataset row that carries at least one
    ground-truth tool call, runs a weave Evaluation scored by
    ``tool_call_score``, and prints the aggregated results.
    """
    dataset = load_dataset("madroid/glaive-function-calling-openai", split="test")

    examples = []
    for i, example in enumerate(
        tqdm(dataset, desc="Processing examples", unit="example")
    ):
        data = json.loads(example["json"])
        expected_calls = data.get("tool_calls") or []
        if not expected_calls:
            # Skip rows without ground-truth calls. This also guards the
            # case where "tool_calls" exists but is an empty list, which
            # made the original expected_calls[0] raise IndexError.
            continue
        target_name = expected_calls[0]["function"]["name"]
        examples.append(
            {
                "id": str(i),
                "messages": data["messages"],
                "tools": data["tools"],
                "tool_calls": expected_calls,
                "target_name": target_name,
            }
        )

    evaluation = weave.Evaluation(
        name="function_call_eval",
        dataset=examples,
        scorers=[tool_call_score],
    )
    model = FunctionCallingModel(name="my_func_call_model", model_name=model_name)
    results = asyncio.run(evaluation.evaluate(model))
    print(results)
if __name__ == "__main__":
    # Benchmark the local MLX model; point this at "llama3.2:3b" (with the
    # Ollama base_url above) to benchmark an Ollama server instead.
    benchmark_model = "mlx-community/Llama-3.2-3B-Instruct-4bit"
    run_eval(benchmark_model)