@@ -421,6 +421,17 @@ def unimplemented_forward(
421421 "This is a JAX model and does not implement the PyTorch forward method."
422422 )
423423
# Dummy stand-in mirroring `forward`: vLLM's interface checks expect this
# attribute to exist on the wrapper class, but a JAX model never executes
# the PyTorch embedding path, so invoking it is always an error.
def unimplemented_get_input_embeddings(
    self,
    input_ids: "torch.Tensor",
    positions: "torch.Tensor",
    inputs_embeds: Optional["torch.Tensor"] = None,
) -> "torch.Tensor":
    """Always raise NotImplementedError — JAX models have no PyTorch embeddings."""
    raise NotImplementedError(
        "This is a JAX model and does not implement the PyTorch get_input_embeddings method."
    )
434+
424435 # We need a custom __init__ that only calls torch.nn.Module's init,
425436 # to avoid triggering JAX logic when vLLM inspects the class.
426437 def wrapper_init (self , * args , ** kwargs ):
@@ -434,6 +445,7 @@ def wrapper_init(self, *args, **kwargs):
434445 {
435446 "__init__" : wrapper_init ,
436447 "forward" : unimplemented_forward ,
448+ "get_input_embeddings" : unimplemented_get_input_embeddings ,
437449 # Prevent vLLM from trying to load weights into this dummy class.
438450 "load_weights" : lambda self , * args , ** kwargs : None ,
439451 })
0 commit comments