Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions tests/models/common/test_model_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,10 @@ def test_register_model_vllm_wrapper_methods():
with pytest.raises(NotImplementedError, match="JAX model"):
instance.forward(input_ids=None, positions=None)

# `get_input_embeddings` should be unimplemented.
with pytest.raises(NotImplementedError, match="JAX model"):
instance.get_input_embeddings(input_ids=None, positions=None)

# `load_weights` should be a no-op that returns None.
assert instance.load_weights() is None

Expand Down
12 changes: 12 additions & 0 deletions tpu_inference/models/common/model_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -421,6 +421,17 @@ def unimplemented_forward(
"This is a JAX model and does not implement the PyTorch forward method."
)

# Same as `forward`, this is a dummy method to satisfy vLLM's type checks.
def unimplemented_get_input_embeddings(
self,
input_ids: "torch.Tensor",
positions: "torch.Tensor",
inputs_embeds: Optional["torch.Tensor"] = None,
) -> "torch.Tensor":
raise NotImplementedError(
"This is a JAX model and does not implement the PyTorch get_input_embeddings method."
)

# We need a custom __init__ that only calls torch.nn.Module's init,
# to avoid triggering JAX logic when vLLM inspects the class.
def wrapper_init(self, *args, **kwargs):
Expand All @@ -434,6 +445,7 @@ def wrapper_init(self, *args, **kwargs):
{
"__init__": wrapper_init,
"forward": unimplemented_forward,
"get_input_embeddings": unimplemented_get_input_embeddings,
# Prevent vLLM from trying to load weights into this dummy class.
"load_weights": lambda self, *args, **kwargs: None,
})
Expand Down