
Conversation

@kuafou commented Oct 29, 2025

Description

This PR fixes a compatibility issue with recent vLLM changes that now require model classes to implement a get_input_embeddings() method.
Without this method, vLLM fails its interface validation during model registration, breaking TPU model integration.

To address this, we add a dummy get_input_embeddings() implementation to the vLLM-compatible wrapper class in
tpu_inference/models/common/model_loader.py.
Like the existing dummy forward() method, it exists only to satisfy vLLM's interface checks and raises
NotImplementedError if invoked, so importing or introspecting the wrapper never triggers JAX model initialization.
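A minimal sketch of this approach (the method names follow the PR description, but the wrapper-construction details are illustrative and not the actual model_loader.py code):

```python
def unimplemented_forward(self, *args, **kwargs):
    # Dummy forward(): present only so the wrapper satisfies vLLM's
    # interface checks; real execution goes through the JAX model.
    raise NotImplementedError("forward() is handled by the JAX model, not this wrapper.")

def unimplemented_get_input_embeddings(self, *args, **kwargs):
    # Dummy get_input_embeddings(): satisfies vLLM's interface validation
    # without ever initializing the JAX model.
    raise NotImplementedError("get_input_embeddings() is not supported on this wrapper.")

def make_wrapper_class(name: str) -> type:
    # The real loader builds the wrapper type dynamically; type() stands in
    # here for whatever machinery model_loader.py actually uses.
    return type(name, (), {
        "forward": unimplemented_forward,
        "get_input_embeddings": unimplemented_get_input_embeddings,
    })
```

Because the methods raise rather than compute, constructing and inspecting the wrapper stays side-effect free.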

Why this change is needed

  • vLLM recently introduced a strict requirement for model classes to define get_input_embeddings()
    (link).
  • TPU inference uses a dummy PyTorch wrapper to register JAX models into vLLM’s registry.
  • Since this wrapper lacked get_input_embeddings, vLLM failed model registration checks.

Implementation details

  • Added unimplemented_get_input_embeddings() dummy function to the wrapper type.
  • Registered it inside the dynamically created wrapper class.
  • Added a test tests/test_vllm_wrapper.py to ensure:
    • The wrapper defines get_input_embeddings().
    • The method raises NotImplementedError.
    • The class passes is_vllm_model() validation.
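The three assertions above could be sketched roughly as follows; is_vllm_model here is a stand-in attribute check (an assumption, since the real validation lives inside vLLM), and the test body is illustrative rather than the actual tests/test_vllm_wrapper.py:

```python
def is_vllm_model(model_cls: type) -> bool:
    # Stand-in for vLLM's interface validation; assumed to verify that the
    # required methods (forward, get_input_embeddings) are present and callable.
    return all(callable(getattr(model_cls, m, None))
               for m in ("forward", "get_input_embeddings"))

def check_wrapper_interface(wrapper_cls: type) -> None:
    # 1. The wrapper defines get_input_embeddings().
    assert hasattr(wrapper_cls, "get_input_embeddings")
    # 2. The method raises NotImplementedError if invoked.
    try:
        wrapper_cls().get_input_embeddings()
    except NotImplementedError:
        pass
    else:
        raise AssertionError("expected NotImplementedError")
    # 3. The class passes the (stand-in) is_vllm_model() validation.
    assert is_vllm_model(wrapper_cls)
```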

Related Issue

Fixes: #951

Tests

pytest -v tests/models/common/test_model_loader.py 

Checklist

Before submitting this PR, please make sure:

  • I have performed a self-review of my code.
  • I have added necessary comments to my code, particularly in hard-to-understand areas.
  • I have made or will make corresponding changes to any relevant documentation.

@kuafou kuafou force-pushed the qi/fix-vllm-model-wrapper branch from 79d6f2d to 80ab177 Compare October 29, 2025 23:22
@py4 py4 requested a review from karan October 30, 2025 19:30
Signed-off-by: Allen Jia <kuafou@gmail.com>
@kuafou kuafou force-pushed the qi/fix-vllm-model-wrapper branch from 80ab177 to 59e1a8e Compare November 5, 2025 18:23
Signed-off-by: Allen Jia <kuafou@gmail.com>
@kuafou kuafou force-pushed the qi/fix-vllm-model-wrapper branch from 59e1a8e to bc8fe75 Compare November 5, 2025 18:29
@karan (Collaborator) left a comment:

Thank you for the PR.


Development

Successfully merging this pull request may close these issues.

[Bug]: vllm model interface now requires get_input_embeddings

2 participants