@@ -415,6 +415,17 @@ def unimplemented_forward(
415415 "This is a JAX model and does not implement the PyTorch forward method."
416416 )
417417
def unimplemented_get_input_embeddings(
    self,
    input_ids: "torch.Tensor",
    positions: "torch.Tensor",
    inputs_embeds: Optional["torch.Tensor"] = None,
) -> "torch.Tensor":
    """Dummy stand-in for the PyTorch ``get_input_embeddings`` method.

    Like ``forward`` above, this exists only so the class satisfies vLLM's
    type checks; the model is implemented in JAX, so invoking it is always
    an error.

    Raises:
        NotImplementedError: unconditionally, on every call.
    """
    raise NotImplementedError(
        "This is a JAX model and does not implement the PyTorch get_input_embeddings method."
    )
428+
418429 # We need a custom __init__ that only calls torch.nn.Module's init,
419430 # to avoid triggering JAX logic when vLLM inspects the class.
420431 def wrapper_init (self , * args , ** kwargs ):
@@ -428,6 +439,7 @@ def wrapper_init(self, *args, **kwargs):
428439 {
429440 "__init__" : wrapper_init ,
430441 "forward" : unimplemented_forward ,
442+ "get_input_embeddings" : unimplemented_get_input_embeddings ,
431443 # Prevent vLLM from trying to load weights into this dummy class.
432444 "load_weights" : lambda self , * args , ** kwargs : None ,
433445 })
0 commit comments