[FIX] Add dummy get_input_embeddings to fix vLLM model type check #971
Description
This PR fixes a compatibility issue with recent vLLM changes that require model classes to implement a `get_input_embeddings()` method. Without this method, vLLM fails its interface validation during model registration, breaking TPU model integration.

To address this, we add a dummy `get_input_embeddings()` implementation to the vLLM-compatible wrapper class in `tpu_inference/models/common/model_loader.py`. Similar to the existing dummy `forward()` method, this implementation only satisfies vLLM's type checks and raises `NotImplementedError` if invoked. This prevents JAX model initialization during import or introspection.

Why this change is needed
- Recent vLLM versions require model classes to implement `get_input_embeddings()` (link).
- Because our wrapper type did not define `get_input_embeddings`, vLLM failed its model registration checks.

Implementation details
- Added an `unimplemented_get_input_embeddings()` dummy function to the wrapper type.
- Added tests in `tests/test_vllm_wrapper.py` to ensure:
  - The wrapper exposes `get_input_embeddings()`.
  - Calling it raises `NotImplementedError`.
  - The wrapper passes `is_vllm_model()` validation.

Related Issue
Fixes: #951
Tests
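The checks listed under Implementation details might look roughly like this. The test names and the `make_wrapper()` helper are illustrative, not the actual contents of `tests/test_vllm_wrapper.py`; the inline `_Wrapper` class stands in for however the real tests obtain the wrapper type.

```python
def make_wrapper():
    """Hypothetical helper: builds a wrapper instance for the tests."""
    class _Wrapper:  # stand-in for the real vLLM-compatible wrapper
        def forward(self, *args, **kwargs):
            raise NotImplementedError

        def get_input_embeddings(self, input_ids):
            raise NotImplementedError

    return _Wrapper()


def test_exposes_get_input_embeddings():
    # vLLM's interface validation only needs the attribute to exist.
    assert callable(getattr(make_wrapper(), "get_input_embeddings", None))


def test_get_input_embeddings_raises():
    # The dummy must fail loudly if anything actually invokes it.
    try:
        make_wrapper().get_input_embeddings(input_ids=None)
    except NotImplementedError:
        pass
    else:
        raise AssertionError("expected NotImplementedError")
```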
Checklist
Before submitting this PR, please make sure: