@@ -410,6 +410,17 @@ def unimplemented_forward(
410410 "This is a JAX model and does not implement the PyTorch forward method."
411411 )
412412
# Same as `forward`, this is a dummy method that exists only to satisfy
# vLLM's type checks; it never produces embeddings.
def unimplemented_get_input_embeddings(
    self,
    input_ids: "torch.Tensor",
    positions: Optional["torch.Tensor"] = None,
    inputs_embeds: Optional["torch.Tensor"] = None,
) -> "torch.Tensor":
    """Placeholder ``get_input_embeddings`` for the JAX-backed wrapper class.

    Every argument is accepted and ignored; the method unconditionally
    raises so that an accidental call on the PyTorch path fails loudly.

    ``positions`` defaults to ``None`` so that a call passing only
    ``input_ids`` reaches the explanatory ``NotImplementedError`` below
    rather than failing earlier with a missing-argument ``TypeError``.
    NOTE(review): vLLM's model interface appears to invoke this method
    with ``input_ids`` alone -- confirm against the vLLM version in use.

    Args:
        input_ids: Ignored.
        positions: Ignored.
        inputs_embeds: Ignored.

    Raises:
        NotImplementedError: Always.
    """
    raise NotImplementedError(
        "This is a JAX model and does not implement the PyTorch get_input_embeddings method."
    )
423+
413424 # We need a custom __init__ that only calls torch.nn.Module's init,
414425 # to avoid triggering JAX logic when vLLM inspects the class.
415426 def wrapper_init (self , * args , ** kwargs ):
@@ -423,6 +434,7 @@ def wrapper_init(self, *args, **kwargs):
423434 {
424435 "__init__" : wrapper_init ,
425436 "forward" : unimplemented_forward ,
437+ "get_input_embeddings" : unimplemented_get_input_embeddings ,
426438 # Prevent vLLM from trying to load weights into this dummy class.
427439 "load_weights" : lambda self , * args , ** kwargs : None ,
428440 })
0 commit comments