
Commit ba1fcd8

[TPU] add tpu_inference (vllm-project#27277)
Signed-off-by: Johnny Yang <johnnyyang@google.com>
1 parent 56539cd commit ba1fcd8

4 files changed: +5 -13 lines changed


requirements/tpu.txt

Lines changed: 1 addition & 3 deletions

@@ -12,6 +12,4 @@ ray[data]
 setuptools==78.1.0
 nixl==0.3.0
 tpu_info==0.4.0
-
-# Install torch_xla
-torch_xla[tpu, pallas]==2.8.0
+tpu-inference==0.11.1
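
The requirements change above drops the explicit torch_xla pin and instead pulls in the tpu-inference package, pinned to 0.11.1. As a rough illustration, a deployment script could confirm the pin before launching vLLM; only the package name and the 0.11.1 version come from the diff above, everything else in this sketch (including the check itself) is hypothetical and uses only the standard library:

# Hypothetical pre-flight check: verify the pinned tpu-inference wheel is installed.
# Only the package name and the 0.11.1 pin come from requirements/tpu.txt above.
from importlib import metadata

REQUIRED = "0.11.1"

try:
    installed = metadata.version("tpu-inference")
except metadata.PackageNotFoundError:
    raise SystemExit("tpu-inference not installed; run: pip install -r requirements/tpu.txt")

if installed != REQUIRED:
    print(f"warning: tpu-inference {installed} installed, requirements pin {REQUIRED}")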

vllm/distributed/device_communicators/tpu_communicator.py

Lines changed: 0 additions & 8 deletions

@@ -97,11 +97,3 @@ def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
     def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
         assert dim == -1, "TPUs only support dim=-1 for all-gather."
         return xm.all_gather(input_, dim=dim)
-
-
-if USE_TPU_INFERENCE:
-    from tpu_inference.distributed.device_communicators import (
-        TpuCommunicator as TpuInferenceCommunicator,
-    )
-
-    TpuCommunicator = TpuInferenceCommunicator  # type: ignore
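
For context, the deleted lines implemented a module-level override: when the optional tpu_inference backend was available, the TpuCommunicator name was rebound to the backend's class, and this commit removes that override from the communicator module. A minimal sketch of the rebinding pattern itself, with hypothetical names (BaseCommunicator, fast_backend) rather than the real vLLM or tpu_inference APIs:

# Sketch of the conditional module-level override pattern the removed lines used.
# BaseCommunicator and fast_backend are hypothetical stand-ins, not vLLM APIs;
# in the real code the flag is USE_TPU_INFERENCE, computed in vllm/platforms/tpu.py.
class BaseCommunicator:
    """Default implementation, used when the optional backend is absent."""

    def all_gather(self, values: list[int]) -> list[int]:
        return list(values)

try:
    import fast_backend  # noqa: F401  hypothetical optional dependency
    _BACKEND_AVAILABLE = True
except ImportError:
    _BACKEND_AVAILABLE = False

if _BACKEND_AVAILABLE:
    from fast_backend.communicators import Communicator as _FastCommunicator

    BaseCommunicator = _FastCommunicator  # type: ignore  # rebind the module-level name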

vllm/platforms/tpu.py

Lines changed: 3 additions & 1 deletion

@@ -267,7 +267,9 @@ def check_max_model_len(cls, max_model_len: int) -> int:
 
 
 try:
-    from tpu_inference.platforms import TpuPlatform as TpuInferencePlatform
+    from tpu_inference.platforms.tpu_platforms import (
+        TpuPlatform as TpuInferencePlatform,
+    )
 
     TpuPlatform = TpuInferencePlatform  # type: ignore
     USE_TPU_INFERENCE = True
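
The hunk only changes the import path (tpu_inference.platforms -> tpu_inference.platforms.tpu_platforms); the surrounding structure, where a successful import rebinds TpuPlatform and sets USE_TPU_INFERENCE, stays the same. A self-contained sketch of that optional-import guard, with a hypothetical package standing in for tpu_inference and an except branch that is assumed, since the diff only shows the try body:

# Sketch of an optional-backend import guard mirroring the pattern visible in the hunk.
# `optional_tpu_backend` is hypothetical, and the except branch is an assumption;
# the hunk above only shows the try body.
class TpuPlatform:
    """Default platform used when the optional backend is not installed."""

USE_TPU_INFERENCE = False

try:
    from optional_tpu_backend.platforms.tpu_platforms import (
        TpuPlatform as _BackendPlatform,
    )

    TpuPlatform = _BackendPlatform  # type: ignore
    USE_TPU_INFERENCE = True
except ImportError:
    pass  # keep the default TpuPlatform and leave USE_TPU_INFERENCE as False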

vllm/v1/worker/tpu_worker.py

Lines changed: 1 addition & 1 deletion

@@ -346,6 +346,6 @@ def apply_model(self, fn: Callable[[nn.Module], _R]) -> _R:
 
 
 if USE_TPU_INFERENCE:
-    from tpu_inference.worker import TPUWorker as TpuInferenceWorker
+    from tpu_inference.worker.tpu_worker import TPUWorker as TpuInferenceWorker
 
     TPUWorker = TpuInferenceWorker  # type: ignore
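
Both the platform and worker changes switch from importing a name out of a package (tpu_inference.platforms, tpu_inference.worker) to importing it from a concrete submodule (tpu_inference.platforms.tpu_platforms, tpu_inference.worker.tpu_worker), presumably so the import no longer depends on the package __init__ re-exporting the class. After a rebind like TPUWorker = TpuInferenceWorker, one hypothetical way to confirm at runtime which implementation a name actually resolves to is to inspect the class's defining module:

# Hypothetical diagnostic: report where a possibly-rebound class is defined.
# `SomeWorker` is a stand-in; with vLLM you would pass the class you imported.
def describe_class(cls: type) -> str:
    """Return 'module:QualifiedName' so an override is visible at a glance."""
    return f"{cls.__module__}:{cls.__qualname__}"

class SomeWorker:
    """Placeholder class standing in for a possibly-overridden worker."""

if __name__ == "__main__":
    print(describe_class(SomeWorker))  # e.g. "__main__:SomeWorker"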
