diff --git a/tests/models/common/test_model_loader.py b/tests/models/common/test_model_loader.py
index c667e6ba4..dd1e1277e 100644
--- a/tests/models/common/test_model_loader.py
+++ b/tests/models/common/test_model_loader.py
@@ -201,6 +201,10 @@ def test_register_model_vllm_wrapper_methods():
     with pytest.raises(NotImplementedError, match="JAX model"):
         instance.forward(input_ids=None, positions=None)
 
+    # `get_input_embeddings` should be unimplemented.
+    with pytest.raises(NotImplementedError, match="JAX model"):
+        instance.get_input_embeddings(input_ids=None, positions=None)
+
     # `load_weights` should be a no-op that returns None.
     assert instance.load_weights() is None
 
diff --git a/tpu_inference/models/common/model_loader.py b/tpu_inference/models/common/model_loader.py
index 8c24ced9b..b33089715 100644
--- a/tpu_inference/models/common/model_loader.py
+++ b/tpu_inference/models/common/model_loader.py
@@ -421,6 +421,17 @@ def unimplemented_forward(
             "This is a JAX model and does not implement the PyTorch forward method."
         )
 
+    # Same as `forward`, this is a dummy method to satisfy vLLM's type checks.
+    def unimplemented_get_input_embeddings(
+        self,
+        input_ids: "torch.Tensor",
+        positions: "torch.Tensor",
+        inputs_embeds: Optional["torch.Tensor"] = None,
+    ) -> "torch.Tensor":
+        raise NotImplementedError(
+            "This is a JAX model and does not implement the PyTorch get_input_embeddings method."
+        )
+
     # We need a custom __init__ that only calls torch.nn.Module's init,
     # to avoid triggering JAX logic when vLLM inspects the class.
     def wrapper_init(self, *args, **kwargs):
@@ -434,6 +445,7 @@ def wrapper_init(self, *args, **kwargs):
         {
             "__init__": wrapper_init,
             "forward": unimplemented_forward,
+            "get_input_embeddings": unimplemented_get_input_embeddings,
            # Prevent vLLM from trying to load weights into this dummy class.
             "load_weights": lambda self, *args, **kwargs: None,
         })
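
For context, here is a minimal, self-contained sketch of the wrapper-class construction this diff touches. `make_dummy_vllm_wrapper` and the usage at the bottom are hypothetical illustrations, not the actual factory in `tpu_inference/models/common/model_loader.py`; the sketch only shows how attaching stub methods via `type()` lets vLLM inspect a torch.nn.Module class while every PyTorch entry point raises `NotImplementedError`.

```python
# Hypothetical sketch; names are illustrative, not the real factory.
from typing import Optional

import torch


def make_dummy_vllm_wrapper(name: str) -> type:
    """Build a torch.nn.Module subclass whose PyTorch entry points are stubs."""

    def wrapper_init(self, *args, **kwargs):
        # Only run torch.nn.Module's init so vLLM can inspect the class
        # without triggering any JAX logic.
        torch.nn.Module.__init__(self)

    def unimplemented_forward(self, input_ids, positions, *args, **kwargs):
        raise NotImplementedError(
            "This is a JAX model and does not implement the PyTorch forward method.")

    def unimplemented_get_input_embeddings(
            self,
            input_ids: "torch.Tensor",
            positions: "torch.Tensor",
            inputs_embeds: Optional["torch.Tensor"] = None,
    ) -> "torch.Tensor":
        raise NotImplementedError(
            "This is a JAX model and does not implement the PyTorch "
            "get_input_embeddings method.")

    return type(
        name,
        (torch.nn.Module,),
        {
            "__init__": wrapper_init,
            "forward": unimplemented_forward,
            "get_input_embeddings": unimplemented_get_input_embeddings,
            # Loading weights into the dummy class is a no-op.
            "load_weights": lambda self, *args, **kwargs: None,
        })


# Usage mirroring the updated test: the PyTorch methods raise, load_weights is a no-op.
if __name__ == "__main__":
    Wrapper = make_dummy_vllm_wrapper("JaxModelWrapper")
    instance = Wrapper()
    try:
        instance.get_input_embeddings(input_ids=None, positions=None)
    except NotImplementedError as exc:
        print(exc)
    assert instance.load_weights() is None
```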