Skip to content

Commit 0c66fde

Browse files
authored
Fix lora layers (#1068)
Signed-off-by: Richard Liu <ricliu@google.com>
1 parent 9e29186 commit 0c66fde

File tree

1 file changed

+14
-4
lines changed

1 file changed

+14
-4
lines changed

tpu_inference/layers/vllm/sharding.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,14 @@
1+
import os
2+
13
import jax
24
import jax.numpy as jnp
35
import torch
46
import torchax
57
from jax.sharding import Mesh, NamedSharding, PartitionSpec
68
from torch.nn import Parameter
79
from torch.utils import _pytree as pytree
8-
from torchax.interop import torch_view
10+
from torchax.interop import jax_view, torch_view
11+
from torchax.ops.mappings import t2j
912
from vllm.lora.layers import (MergedColumnParallelLinearWithLoRA,
1013
MergedQKVParallelLinearWithLoRA,
1114
RowParallelLinearWithLoRA)
@@ -81,9 +84,16 @@ def _tensor_is_in_cpu(tensor: torch.tensor) -> bool:
8184

8285
def _convert_to_torchax_and_shard(tensor: torch.Tensor,
                                  sharding: NamedSharding) -> torch.Tensor:
    """Convert a torch tensor into a torchax tensor sharded per `sharding`.

    Args:
        tensor: A plain `torch.Tensor` or a `torchax.tensor.Tensor`.
        sharding: The JAX `NamedSharding` to place the result with.

    Returns:
        A torchax view of the sharded JAX array.
    """
    # BUG FIX(review): the original condition was `tensor is torch.Tensor`,
    # an identity comparison between an *instance* and the *class* object —
    # always False, so the Pathways branch was dead code. `type(tensor) is
    # torch.Tensor` restores the apparent intent: take this path only for
    # plain torch tensors (excluding torchax subclasses, which the else
    # branch handles via `jax_view`). TODO confirm against Pathways usage.
    # NOTE(review): `os.getenv` returns a string, so any non-empty value —
    # including "0" or "false" — enables this path; confirm that is intended.
    if os.getenv("VLLM_TPU_USE_PATHWAYS", False) and type(tensor) is torch.Tensor:
        # Round-trip through a float32 numpy array for device placement,
        # then restore the original dtype (default float32 when the torch
        # dtype has no entry in the map).
        np_tensor = tensor.detach().cpu().to(torch.float32).numpy()
        dtype = TORCH_TO_JAX_DTYPE_MAP.get(tensor.dtype, jnp.float32)
        return torch_view(jax.device_put(np_tensor, sharding).astype(dtype))
    else:
        # Non-Pathways path: obtain a JAX array, then shard it in place.
        if isinstance(tensor, torchax.tensor.Tensor):
            # Already a torchax tensor — unwrap to its underlying JAX array.
            tensor = jax_view(tensor)
        else:
            # Plain torch tensor — convert torch -> jax.
            tensor = t2j(tensor)
        return torch_view(_sharded_device_put(tensor, sharding))
8797

8898

8999
def _shard_tensor_to_tpu_replicated(tensor: torch.Tensor,

0 commit comments

Comments
 (0)