
Commit add0b5b

[FIX] Add dummy get_input_embeddings to fix vLLM model type check (#971)
Signed-off-by: Allen Jia <kuafou@gmail.com>
1 parent: 231b2b5

File tree: 2 files changed, +16 -0 lines

tests/models/common/test_model_loader.py (4 additions, 0 deletions)

@@ -201,6 +201,10 @@ def test_register_model_vllm_wrapper_methods():
     with pytest.raises(NotImplementedError, match="JAX model"):
         instance.forward(input_ids=None, positions=None)

+    # `get_input_embeddings` should be unimplemented.
+    with pytest.raises(NotImplementedError, match="JAX model"):
+        instance.get_input_embeddings(input_ids=None, positions=None)
+
     # `load_weights` should be a no-op that returns None.
     assert instance.load_weights() is None
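
The stub never needs to run: the fact that a method whose body only raises NotImplementedError fixes the type check suggests vLLM is looking for the method's presence, not its behaviour. A minimal sketch of that idea, assuming a `runtime_checkable` Protocol; `SupportsTextGeneration` and `DummyWrapper` are illustrative stand-ins, not vLLM's actual interfaces:

# Illustrative sketch: `SupportsTextGeneration` is a made-up stand-in for the
# kind of structural interface vLLM checks models against, not its real API.
from typing import Optional, Protocol, runtime_checkable

import torch


@runtime_checkable
class SupportsTextGeneration(Protocol):
    def forward(self, input_ids: torch.Tensor,
                positions: torch.Tensor) -> torch.Tensor:
        ...

    def get_input_embeddings(
            self,
            input_ids: torch.Tensor,
            positions: torch.Tensor,
            inputs_embeds: Optional[torch.Tensor] = None) -> torch.Tensor:
        ...


class DummyWrapper(torch.nn.Module):
    """Stub whose methods exist only so the structural check passes."""

    def forward(self, input_ids, positions):
        raise NotImplementedError("This is a JAX model")

    def get_input_embeddings(self, input_ids, positions, inputs_embeds=None):
        raise NotImplementedError("This is a JAX model")


# isinstance() only checks that the methods are present, so the dummy class
# passes even though calling either method raises NotImplementedError.
assert isinstance(DummyWrapper(), SupportsTextGeneration)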

tpu_inference/models/common/model_loader.py (12 additions, 0 deletions)

@@ -421,6 +421,17 @@ def unimplemented_forward(
            "This is a JAX model and does not implement the PyTorch forward method."
        )

+    # Same as `forward`, this is a dummy method to satisfy vLLM's type checks.
+    def unimplemented_get_input_embeddings(
+        self,
+        input_ids: "torch.Tensor",
+        positions: "torch.Tensor",
+        inputs_embeds: Optional["torch.Tensor"] = None,
+    ) -> "torch.Tensor":
+        raise NotImplementedError(
+            "This is a JAX model and does not implement the PyTorch get_input_embeddings method."
+        )
+
     # We need a custom __init__ that only calls torch.nn.Module's init,
     # to avoid triggering JAX logic when vLLM inspects the class.
     def wrapper_init(self, *args, **kwargs):
@@ -434,6 +445,7 @@ def wrapper_init(self, *args, **kwargs):
        {
            "__init__": wrapper_init,
            "forward": unimplemented_forward,
+           "get_input_embeddings": unimplemented_get_input_embeddings,
            # Prevent vLLM from trying to load weights into this dummy class.
            "load_weights": lambda self, *args, **kwargs: None,
        })
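
For reference, the wrapper class above is assembled dynamically with type(), so supporting the check amounts to adding one more entry to the class dict. A self-contained sketch of that pattern, with illustrative names (`make_vllm_wrapper` and `JaxModelWrapper` are assumptions for this sketch, not the identifiers used in model_loader.py):

# Sketch of the dummy-wrapper pattern; `make_vllm_wrapper` / `JaxModelWrapper`
# are illustrative names, not the ones used in model_loader.py.
from typing import Optional

import torch


def make_vllm_wrapper(name: str = "JaxModelWrapper") -> type:
    def wrapper_init(self, *args, **kwargs):
        # Only initialise torch.nn.Module; JAX setup must not be triggered
        # when vLLM merely inspects or instantiates the class.
        torch.nn.Module.__init__(self)

    def unimplemented_forward(self, input_ids: "torch.Tensor",
                              positions: "torch.Tensor") -> "torch.Tensor":
        raise NotImplementedError(
            "This is a JAX model and does not implement the PyTorch forward method.")

    def unimplemented_get_input_embeddings(
            self,
            input_ids: "torch.Tensor",
            positions: "torch.Tensor",
            inputs_embeds: Optional["torch.Tensor"] = None) -> "torch.Tensor":
        raise NotImplementedError(
            "This is a JAX model and does not implement the PyTorch "
            "get_input_embeddings method.")

    # Build the class dynamically: torch.nn.Module as the base, plus the
    # dummy methods and a no-op load_weights.
    return type(name, (torch.nn.Module,), {
        "__init__": wrapper_init,
        "forward": unimplemented_forward,
        "get_input_embeddings": unimplemented_get_input_embeddings,
        "load_weights": lambda self, *args, **kwargs: None,
    })


if __name__ == "__main__":
    instance = make_vllm_wrapper()()
    assert instance.load_weights() is None     # no-op, returns None
    assert hasattr(instance, "get_input_embeddings")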
