From 5883fede06bc9cb308eddd36a7d1665738460fff Mon Sep 17 00:00:00 2001
From: Allen Jia
Date: Wed, 29 Oct 2025 22:37:29 +0000
Subject: [PATCH 1/2] add dummy get_input_embeddings to fix vllm model type check

Signed-off-by: Allen Jia
---
 tests/test_vllm_wrapper.py                  | 21 +++++++++++++++++++++
 tpu_inference/models/common/model_loader.py | 12 ++++++++++++
 2 files changed, 33 insertions(+)
 create mode 100644 tests/test_vllm_wrapper.py

diff --git a/tests/test_vllm_wrapper.py b/tests/test_vllm_wrapper.py
new file mode 100644
index 000000000..61fa54ffc
--- /dev/null
+++ b/tests/test_vllm_wrapper.py
@@ -0,0 +1,21 @@
+from tpu_inference.models.common.model_loader import register_model
+
+
+class DummyModel:
+    def __init__(self, vllm_config=None): pass
+    def __call__(self, kv_caches=None, input_ids=None, attention_metadata=None): pass
+
+def test_vllm_wrapper_has_required_methods():
+    register_model("DummyForCausalLM", DummyModel)
+
+    from vllm.model_executor.models.registry import ModelRegistry
+    wrapper_cls = ModelRegistry.models.get("DummyForCausalLM").model_cls
+    assert hasattr(wrapper_cls, "get_input_embeddings")
+    m = wrapper_cls()
+    try:
+        m.get_input_embeddings(input_ids=None, positions=None, inputs_embeds=None)
+    except NotImplementedError:
+        pass
+
+    from vllm.model_executor.models.interfaces_base import is_vllm_model
+    assert is_vllm_model(wrapper_cls)
diff --git a/tpu_inference/models/common/model_loader.py b/tpu_inference/models/common/model_loader.py
index 8c24ced9b..b33089715 100644
--- a/tpu_inference/models/common/model_loader.py
+++ b/tpu_inference/models/common/model_loader.py
@@ -421,6 +421,17 @@ def unimplemented_forward(
             "This is a JAX model and does not implement the PyTorch forward method."
         )
 
+    # Same as `forward`, this is a dummy method to satisfy vLLM's type checks.
+    def unimplemented_get_input_embeddings(
+        self,
+        input_ids: "torch.Tensor",
+        positions: "torch.Tensor",
+        inputs_embeds: Optional["torch.Tensor"] = None,
+    ) -> "torch.Tensor":
+        raise NotImplementedError(
+            "This is a JAX model and does not implement the PyTorch get_input_embeddings method."
+        )
+
     # We need a custom __init__ that only calls torch.nn.Module's init,
     # to avoid triggering JAX logic when vLLM inspects the class.
     def wrapper_init(self, *args, **kwargs):
@@ -434,6 +445,7 @@ def wrapper_init(self, *args, **kwargs):
         {
             "__init__": wrapper_init,
             "forward": unimplemented_forward,
+            "get_input_embeddings": unimplemented_get_input_embeddings,
             # Prevent vLLM from trying to load weights into this dummy class.
             "load_weights": lambda self, *args, **kwargs: None,
         })

From 85606aa7760dac61c5dff2fccad16bd2d55a5b9b Mon Sep 17 00:00:00 2001
From: Allen Jia
Date: Wed, 5 Nov 2025 18:22:47 +0000
Subject: [PATCH 2/2] add test to test_model_loader.py

Signed-off-by: Allen Jia
---
 tests/models/common/test_model_loader.py |  4 ++++
 tests/test_vllm_wrapper.py               | 21 ---------------------
 2 files changed, 4 insertions(+), 21 deletions(-)
 delete mode 100644 tests/test_vllm_wrapper.py

diff --git a/tests/models/common/test_model_loader.py b/tests/models/common/test_model_loader.py
index c667e6ba4..dd1e1277e 100644
--- a/tests/models/common/test_model_loader.py
+++ b/tests/models/common/test_model_loader.py
@@ -201,6 +201,10 @@ def test_register_model_vllm_wrapper_methods():
     with pytest.raises(NotImplementedError, match="JAX model"):
         instance.forward(input_ids=None, positions=None)
 
+    # `get_input_embeddings` should be unimplemented.
+    with pytest.raises(NotImplementedError, match="JAX model"):
+        instance.get_input_embeddings(input_ids=None, positions=None)
+
     # `load_weights` should be a no-op that returns None.
     assert instance.load_weights() is None
 
diff --git a/tests/test_vllm_wrapper.py b/tests/test_vllm_wrapper.py
deleted file mode 100644
index 61fa54ffc..000000000
--- a/tests/test_vllm_wrapper.py
+++ /dev/null
@@ -1,21 +0,0 @@
-from tpu_inference.models.common.model_loader import register_model
-
-
-class DummyModel:
-    def __init__(self, vllm_config=None): pass
-    def __call__(self, kv_caches=None, input_ids=None, attention_metadata=None): pass
-
-def test_vllm_wrapper_has_required_methods():
-    register_model("DummyForCausalLM", DummyModel)
-
-    from vllm.model_executor.models.registry import ModelRegistry
-    wrapper_cls = ModelRegistry.models.get("DummyForCausalLM").model_cls
-    assert hasattr(wrapper_cls, "get_input_embeddings")
-    m = wrapper_cls()
-    try:
-        m.get_input_embeddings(input_ids=None, positions=None, inputs_embeds=None)
-    except NotImplementedError:
-        pass
-
-    from vllm.model_executor.models.interfaces_base import is_vllm_model
-    assert is_vllm_model(wrapper_cls)