Commits (34)
2f415a6
feat: added smaller qwen models for debugging
guicho271828 Sep 4, 2025
36202ba
feat(vllm): copied from huggingface
guicho271828 Aug 28, 2025
43dc0e8
fix(vllm): remove alora and cache
guicho271828 Aug 28, 2025
153f58b
fix(vllm): remove tool calls
guicho271828 Sep 2, 2025
1abb626
fix(vllm): finished the implementation with limited functionality: fr…
guicho271828 Aug 28, 2025
6b703d1
fix(vllm): passing mypy and linter
guicho271828 Sep 5, 2025
756e750
fix(vllm): added vllm optional dep in pyproject.toml
guicho271828 Aug 28, 2025
1db23e5
feat(vllm test): copied from huggingface
guicho271828 Sep 2, 2025
7483077
fix(vllm test): implemented the test
guicho271828 Sep 4, 2025
19a2adb
test: require V0 in vllm test
guicho271828 Sep 5, 2025
b8b41d4
refactor: ctx to chat conversion function
guicho271828 Sep 10, 2025
68925dd
refactor: use_alora function
guicho271828 Sep 24, 2025
7e77f0e
refactor: moved _extract_model_tool_requests to mellea.backends.utils
guicho271828 Sep 10, 2025
a972c17
feat(vllm): added tool calls
guicho271828 Sep 10, 2025
cc0ff3d
test(tools): run test with mistral
guicho271828 Sep 10, 2025
974c2c2
fix(vllm): rename model_options -> engine_args
guicho271828 Sep 19, 2025
1e2ec28
fix(vllm): use FancyLogger
guicho271828 Sep 19, 2025
e72b7e9
fix(vllm): ignore type checking for vllm and msgspec
guicho271828 Sep 24, 2025
0db3171
fix(vllm): fixed the backend name in the log
guicho271828 Sep 24, 2025
c1ebd6d
feat(vllm): asynchronous call support
guicho271828 Sep 24, 2025
048e90d
test(vllm): asynchronous call support
guicho271828 Sep 24, 2025
d274720
fix(vllm): avoid unnecessary incremental processing in non-streaming …
guicho271828 Sep 24, 2025
30edeee
feat(sglang): copied from vllm
guicho271828 Sep 24, 2025
45af48a
fix(sglang): pyproject.toml
guicho271828 Sep 24, 2025
a56d3c9
fix(sglang): made it work
guicho271828 Sep 24, 2025
9b55985
fix(sglang): fixed the backend name in the log
guicho271828 Sep 24, 2025
1bc6efd
test(sglang): copied and modified from vllm tests
guicho271828 Sep 24, 2025
9b1cae6
fix(sglang): sglang requires nest_asyncio patch
guicho271828 Sep 24, 2025
cc9a691
feat(sglang): asynchronous support
guicho271828 Sep 24, 2025
00fd955
test(sglang): asynchronous support
guicho271828 Sep 24, 2025
6a40d50
fix(sglang): ignore streaming option
guicho271828 Sep 24, 2025
e194bf9
docs(async): explain the asynchronous streaming steps
guicho271828 Sep 25, 2025
4f28307
debug
Sep 25, 2025
f92bd9b
debug
Sep 30, 2025
97 changes: 22 additions & 75 deletions mellea/backends/huggingface.py
@@ -39,9 +39,9 @@
add_tools_from_context_actions,
add_tools_from_model_options,
convert_tools_to_json,
parse_tools,
)
from mellea.backends.types import ModelOption
from mellea.backends.utils import extract_model_tool_requests, to_chat, use_alora
from mellea.helpers.async_helpers import send_to_queue
from mellea.helpers.fancy_logger import FancyLogger
from mellea.stdlib.base import (
@@ -196,24 +196,22 @@ def generate_from_context(
# Upsert model options.
model_opts = self._simplify_and_merge(model_options)

# See `docs/dev/requirement_aLoRA_rerouting.md` for an explanation of the following code block.
if issubclass(type(action), Requirement):
# The general rule is that we reroute to the alora if it exists.
reroute_to_alora = self.get_alora("constraint") is not None
# However, there are some exceptions:
if not self.default_to_constraint_checking_alora:
reroute_to_alora = False
if issubclass(type(action), LLMaJRequirement):
reroute_to_alora = False
if issubclass(type(action), ALoraRequirement):
reroute_to_alora = True
if reroute_to_alora:
return self._generate_from_context_alora(
action, ctx, format=format, model_options=model_opts
)
return self._generate_from_context_standard(
action, ctx, format=format, model_options=model_opts, tool_calls=tool_calls
)
if use_alora(
action,
self.get_alora("constraint"),
self.default_to_constraint_checking_alora,
):
return self._generate_from_context_alora(
action, ctx, format=format, model_options=model_opts
)
else:
return self._generate_from_context_standard(
action,
ctx,
format=format,
model_options=model_opts,
tool_calls=tool_calls,
)
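For context, a minimal sketch of what the extracted `use_alora` helper might look like, reconstructed from the removed inline rerouting rules above; the exact signature in `mellea.backends.utils` and the import path are assumptions:

```python
# Sketch only: reconstructed from the removed inline logic; the import path
# and the exact signature in mellea.backends.utils are assumptions.
from mellea.stdlib.requirement import (  # assumed import path
    ALoraRequirement,
    LLMaJRequirement,
    Requirement,
)


def use_alora(
    action,  # the Component/CBlock being generated from
    constraint_alora,  # result of backend.get_alora("constraint"), or None
    default_to_constraint_checking_alora: bool,
) -> bool:
    """Return True if a Requirement should be rerouted to the constraint ALora."""
    if not issubclass(type(action), Requirement):
        return False
    # General rule: reroute to the alora if it exists.
    reroute = constraint_alora is not None
    # Exceptions:
    if not default_to_constraint_checking_alora:
        reroute = False
    if issubclass(type(action), LLMaJRequirement):
        reroute = False
    if issubclass(type(action), ALoraRequirement):
        reroute = True
    return reroute
```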

def _generate_from_context_alora(
self,
@@ -275,35 +273,8 @@ def _generate_from_context_standard(
# If the Context is a ChatHistory then we will pretty-print each content as a message and then use apply_chat_template.
# Otherwise, we will linearize the context and treat it as a raw input.
if ctx.is_chat_context:
linearized_ctx = ctx.render_for_generation()
assert linearized_ctx is not None, (
"If ctx.is_chat_context, then the context should be linearizable."
)
ctx_as_message_list: list[Message] = self.formatter.to_chat_messages(
linearized_ctx
)
# add action
ctx_as_message_list.extend(self.formatter.to_chat_messages([action]))
ctx_as_conversation = [
{"role": m.role, "content": m.content} for m in ctx_as_message_list
]

# Check that we didn't accidentally end up with CBlocks.
for msg in ctx_as_conversation:
for v in msg.values():
if "CBlock" in v:
FancyLogger.get_logger().error(
f"Found the string `CBlock` in what should've been a stringified context: {ctx_as_conversation}"
)

# handle custom system prompts. It's important that we do this before the _parse_and_**clean**_model_options step.
system_prompt = model_options.get(ModelOption.SYSTEM_PROMPT, None)
if system_prompt is not None:
system_msg: dict[str, str] = {
"role": "system",
"content": system_prompt,
}
ctx_as_conversation.insert(0, system_msg)
ctx_as_chat = to_chat(action, ctx, self.formatter, system_prompt)
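Similarly, the new `to_chat` helper presumably packages the conversion steps removed above; a sketch assembled from that inline logic (signature, types, and import paths are assumptions):

```python
# Sketch only: assembled from the removed inline conversion logic; the exact
# signature of mellea.backends.utils.to_chat and import paths are assumptions.
from mellea.backends.formatter import Formatter  # assumed import path
from mellea.helpers.fancy_logger import FancyLogger
from mellea.stdlib.base import CBlock, Component, Context, Message


def to_chat(
    action: Component | CBlock,
    ctx: Context,
    formatter: Formatter,
    system_prompt: str | None,
) -> list[dict[str, str]]:
    """Linearize a chat context plus the current action into role/content dicts."""
    linearized_ctx = ctx.render_for_generation()
    assert linearized_ctx is not None, (
        "If ctx.is_chat_context, then the context should be linearizable."
    )
    messages: list[Message] = formatter.to_chat_messages(linearized_ctx)
    # Add the current action as the final message.
    messages.extend(formatter.to_chat_messages([action]))
    chat = [{"role": m.role, "content": m.content} for m in messages]

    # Check that we didn't accidentally end up with CBlocks.
    for msg in chat:
        for v in msg.values():
            if "CBlock" in v:
                FancyLogger.get_logger().error(
                    f"Found the string `CBlock` in what should've been a "
                    f"stringified context: {chat}"
                )

    # Custom system prompts go first; this must happen before model options
    # are parsed and cleaned.
    if system_prompt is not None:
        chat.insert(0, {"role": "system", "content": system_prompt})
    return chat
```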

# Append tool call information if applicable.
tools: dict[str, Callable] = dict()
@@ -328,7 +299,7 @@ def _generate_from_context_standard(
set_seed(seed)

input_ids = self._tokenizer.apply_chat_template( # type: ignore
ctx_as_conversation,
ctx_as_chat,
tools=convert_tools_to_json(tools), # type: ignore
return_tensors="pt",
**self._make_backend_specific_and_remove(model_options),
@@ -388,7 +359,7 @@ def _generate_from_context_standard(
)

output = ModelOutputThunk(None)
output._context = linearized_ctx
output._context = ctx.render_for_generation()
output._action = action
output._model_options = model_options

@@ -397,7 +368,7 @@ def _generate_from_context_standard(
output._process = functools.partial(self.processing, input_ids=input_ids)
output._post_process = functools.partial(
self.post_processing,
conversation=ctx_as_conversation,
conversation=ctx_as_chat,
input_ids=input_ids,
tool_calls=tool_calls,
tools=tools,
@@ -486,7 +457,7 @@ async def post_processing(

# Only scan for tools if we are not doing structured output and tool calls were provided to the model.
if format is None and tool_calls:
mot.tool_calls = self._extract_model_tool_requests(tools, mot.value)
mot.tool_calls = extract_model_tool_requests(tools, mot.value)

assert mot._action is not None, (
"ModelOutputThunks should have their action assigned during generation"
@@ -667,30 +638,6 @@ def _make_backend_specific_and_remove(
)
return ModelOption.remove_special_keys(backend_specific)

def _extract_model_tool_requests(
self, tools: dict[str, Callable], decoded_result: str
) -> dict[str, ModelToolCall] | None:
model_tool_calls: dict[str, ModelToolCall] = dict()
for tool_name, tool_args in parse_tools(decoded_result):
func = tools.get(tool_name)
if func is None:
FancyLogger.get_logger().warning(
f"model attempted to call a non-existing function: {tool_name}"
)
continue

# Clean up the function args slightly. Some models seem to
# hallucinate parameters when none are required.
sig = inspect.signature(func)
if len(sig.parameters) == 0:
tool_args = {}

model_tool_calls[tool_name] = ModelToolCall(tool_name, func, tool_args)

if len(model_tool_calls) > 0:
return model_tool_calls
return None
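Per the import added at the top of this file and the call site in `post_processing`, the method removed above presumably becomes a module-level function in `mellea/backends/utils.py` with the same body; a sketch under that assumption:

```python
# Sketch only: assumed module-level form in mellea/backends/utils.py; the body
# matches the removed method above, and the import paths are assumptions.
import inspect
from collections.abc import Callable

from mellea.backends.tools import parse_tools  # assumed import path
from mellea.helpers.fancy_logger import FancyLogger
from mellea.stdlib.base import ModelToolCall  # assumed import path


def extract_model_tool_requests(
    tools: dict[str, Callable], decoded_result: str
) -> dict[str, ModelToolCall] | None:
    """Parse the model's decoded output for tool calls it wants to make."""
    model_tool_calls: dict[str, ModelToolCall] = dict()
    for tool_name, tool_args in parse_tools(decoded_result):
        func = tools.get(tool_name)
        if func is None:
            FancyLogger.get_logger().warning(
                f"model attempted to call a non-existing function: {tool_name}"
            )
            continue

        # Clean up the function args slightly. Some models seem to
        # hallucinate parameters when none are required.
        sig = inspect.signature(func)
        if len(sig.parameters) == 0:
            tool_args = {}

        model_tool_calls[tool_name] = ModelToolCall(tool_name, func, tool_args)

    return model_tool_calls if len(model_tool_calls) > 0 else None
```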

# region ALora loading, unloading, and utility functions.
def add_alora(self, alora: HFAlora):
"""Loads an ALora for this backend.
4 changes: 4 additions & 0 deletions mellea/backends/model_ids.py
@@ -118,6 +118,10 @@ class ModelIdentifier
#### Qwen models ####
#####################

QWEN3_0_6B = ModelIdentifier(hf_model_name="Qwen/Qwen3-0.6B", ollama_name="qwen3:0.6b")

QWEN3_1_7B = ModelIdentifier(hf_model_name="Qwen/Qwen3-1.7B", ollama_name="qwen3:1.7b")

QWEN3_8B = ModelIdentifier(hf_model_name="Qwen/Qwen3-8B", ollama_name="qwen3:8b")

QWEN3_14B = ModelIdentifier(hf_model_name="Qwen/Qwen3-14B", ollama_name="qwen3:14b")
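The two new small checkpoints give a lightweight target for the debugging use case named in the first commit. A hedged usage sketch — `start_session(model_id=...)` and the session's `chat` method are assumptions about the current API; verify before relying on it:

```python
# Hedged usage sketch: smoke-test against the new 0.6B model for fast,
# low-memory debugging. The start_session keyword and chat() return shape
# are assumptions, not confirmed by this diff.
from mellea import start_session
from mellea.backends import model_ids

m = start_session(model_id=model_ids.QWEN3_0_6B)
print(m.chat("Reply with the single word: pong").content)
```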