Skip to content

Commit 80ab177

Browse files
committed
add dummy get_input_embeddings to fix vllm model type check
Signed-off-by: Allen Jia <kuafou@gmail.com>
1 parent e8620b3 commit 80ab177

File tree

2 files changed

+33
-0
lines changed

2 files changed

+33
-0
lines changed

tests/test_vllm_wrapper.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
from tpu_inference.models.common.model_loader import register_model
2+
3+
4+
class DummyModel:
    """Minimal stand-in model used to exercise wrapper registration.

    Only the call signatures matter: both methods accept the keywords a
    real JAX model would and do nothing with them.
    """

    def __init__(self, vllm_config=None):
        # Accepts vllm_config like a real model constructor; ignored.
        pass

    def __call__(self, kv_caches=None, input_ids=None, attention_metadata=None):
        # No-op forward; returns None.
        pass
7+
8+
def test_vllm_wrapper_has_required_methods():
    """The registered wrapper class must satisfy vLLM's model type checks.

    Registers a dummy JAX-side model and verifies that the generated
    wrapper class exposes ``get_input_embeddings`` (required by vLLM's
    interface probing) and is accepted by ``is_vllm_model``.
    """
    register_model("DummyForCausalLM", DummyModel)

    from vllm.model_executor.models.registry import ModelRegistry
    wrapper_cls = ModelRegistry.models.get("DummyForCausalLM").model_cls
    assert hasattr(wrapper_cls, "get_input_embeddings")

    # The dummy method must raise: a silent return could be mistaken for a
    # working PyTorch implementation. The previous try/except swallowed the
    # exception without asserting it was actually raised.
    m = wrapper_cls()
    try:
        m.get_input_embeddings(input_ids=None, positions=None, inputs_embeds=None)
    except NotImplementedError:
        pass
    else:
        raise AssertionError(
            "get_input_embeddings should raise NotImplementedError")

    from vllm.model_executor.models.interfaces_base import is_vllm_model
    assert is_vllm_model(wrapper_cls)

tpu_inference/models/common/model_loader.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -410,6 +410,17 @@ def unimplemented_forward(
410410
"This is a JAX model and does not implement the PyTorch forward method."
411411
)
412412

413+
# Same as `forward`, this is a dummy method to satisfy vLLM's type checks.
def unimplemented_get_input_embeddings(
    self,
    input_ids: "torch.Tensor",
    positions: "torch.Tensor",
    inputs_embeds: Optional["torch.Tensor"] = None,
) -> "torch.Tensor":
    """Placeholder embedding lookup; every call is an error.

    Exists only so the wrapper class passes vLLM's interface probing —
    the real computation lives on the JAX side.
    """
    raise NotImplementedError(
        "This is a JAX model and does not implement the PyTorch get_input_embeddings method."
    )
423+
413424
# We need a custom __init__ that only calls torch.nn.Module's init,
414425
# to avoid triggering JAX logic when vLLM inspects the class.
415426
def wrapper_init(self, *args, **kwargs):
@@ -423,6 +434,7 @@ def wrapper_init(self, *args, **kwargs):
423434
{
424435
"__init__": wrapper_init,
425436
"forward": unimplemented_forward,
437+
"get_input_embeddings": unimplemented_get_input_embeddings,
426438
# Prevent vLLM from trying to load weights into this dummy class.
427439
"load_weights": lambda self, *args, **kwargs: None,
428440
})

0 commit comments

Comments
 (0)