Skip to content

Commit 494d22d

Browse files
committed
add dummy get_input_embeddings to fix vllm model type check
Signed-off-by: Allen Jia <kuafou@gmail.com>
1 parent 60c14f5 commit 494d22d

File tree

2 files changed

+33
-0
lines changed

2 files changed

+33
-0
lines changed

tests/test_vllm_wrapper.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
from tpu_inference.models.common.model_loader import register_model
2+
3+
4+
class DummyModel:
    """Minimal stand-in model used to exercise vLLM wrapper registration.

    Only the constructor and call signatures matter; both are no-ops.
    """

    def __init__(self, vllm_config=None):
        # Accept (and ignore) the config object vLLM passes at construction.
        pass

    def __call__(self, kv_caches=None, input_ids=None, attention_metadata=None):
        # No-op forward pass; tests only need the signature to exist.
        pass
7+
8+
def test_vllm_wrapper_has_required_methods():
    """The registered wrapper class must satisfy vLLM's model type checks.

    Registers a dummy JAX-side model, then verifies the generated wrapper:
    - exposes ``get_input_embeddings`` (required by vLLM's interface probe),
    - whose stub raises ``NotImplementedError`` when actually invoked,
    - and passes ``is_vllm_model``.
    """
    register_model("DummyForCausalLM", DummyModel)

    from vllm.model_executor.models.registry import ModelRegistry
    wrapper_cls = ModelRegistry.models.get("DummyForCausalLM").model_cls
    assert hasattr(wrapper_cls, "get_input_embeddings")

    m = wrapper_cls()
    # The stub must actually raise: a bare try/except-pass would also accept a
    # silent return, which would hide a broken (non-raising) stub.
    try:
        m.get_input_embeddings(input_ids=None, positions=None, inputs_embeds=None)
    except NotImplementedError:
        pass
    else:
        raise AssertionError(
            "wrapper get_input_embeddings should raise NotImplementedError")

    from vllm.model_executor.models.interfaces_base import is_vllm_model
    assert is_vllm_model(wrapper_cls)

tpu_inference/models/common/model_loader.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -415,6 +415,17 @@ def unimplemented_forward(
415415
"This is a JAX model and does not implement the PyTorch forward method."
416416
)
417417

418+
def unimplemented_get_input_embeddings(
    self,
    input_ids: "torch.Tensor",
    positions: "torch.Tensor",
    inputs_embeds: Optional["torch.Tensor"] = None,
) -> "torch.Tensor":
    """Dummy stand-in, like `forward`, so the class passes vLLM's type checks.

    Always raises: JAX-backed models have no PyTorch embedding lookup.
    """
    raise NotImplementedError(
        "This is a JAX model and does not implement the PyTorch get_input_embeddings method."
    )
428+
418429
# We need a custom __init__ that only calls torch.nn.Module's init,
419430
# to avoid triggering JAX logic when vLLM inspects the class.
420431
def wrapper_init(self, *args, **kwargs):
@@ -428,6 +439,7 @@ def wrapper_init(self, *args, **kwargs):
428439
{
429440
"__init__": wrapper_init,
430441
"forward": unimplemented_forward,
442+
"get_input_embeddings": unimplemented_get_input_embeddings,
431443
# Prevent vLLM from trying to load weights into this dummy class.
432444
"load_weights": lambda self, *args, **kwargs: None,
433445
})

0 commit comments

Comments
 (0)