|
| 1 | +import os |
| 2 | + |
1 | 3 | import jax |
2 | 4 | import jax.numpy as jnp |
3 | 5 | import torch |
4 | 6 | import torchax |
5 | 7 | from jax.sharding import Mesh, NamedSharding, PartitionSpec |
6 | 8 | from torch.nn import Parameter |
7 | 9 | from torch.utils import _pytree as pytree |
8 | | -from torchax.interop import torch_view |
| 10 | +from torchax.interop import jax_view, torch_view |
| 11 | +from torchax.ops.mappings import t2j |
9 | 12 | from vllm.lora.layers import (MergedColumnParallelLinearWithLoRA, |
10 | 13 | MergedQKVParallelLinearWithLoRA, |
11 | 14 | RowParallelLinearWithLoRA) |
@@ -81,9 +84,16 @@ def _tensor_is_in_cpu(tensor: torch.tensor) -> bool: |
81 | 84 |
|
82 | 85 | def _convert_to_torchax_and_shard(tensor: torch.Tensor, |
83 | 86 | sharding: NamedSharding) -> torch.Tensor: |
84 | | - np_tensor = tensor.detach().cpu().to(torch.float32).numpy() |
85 | | - dtype = TORCH_TO_JAX_DTYPE_MAP.get(tensor.dtype, jnp.float32) |
86 | | - return torch_view(jax.device_put(np_tensor, sharding).astype(dtype)) |
| 87 | + if os.getenv("VLLM_TPU_USE_PATHWAYS", False) and tensor is torch.Tensor: |
| 88 | + np_tensor = tensor.detach().cpu().to(torch.float32).numpy() |
| 89 | + dtype = TORCH_TO_JAX_DTYPE_MAP.get(tensor.dtype, jnp.float32) |
| 90 | + return torch_view(jax.device_put(np_tensor, sharding).astype(dtype)) |
| 91 | + else: |
| 92 | + if isinstance(tensor, torchax.tensor.Tensor): |
| 93 | + tensor = jax_view(tensor) |
| 94 | + else: |
| 95 | + tensor = t2j(tensor) |
| 96 | + return torch_view(_sharded_device_put(tensor, sharding)) |
87 | 97 |
|
88 | 98 |
|
89 | 99 | def _shard_tensor_to_tpu_replicated(tensor: torch.Tensor, |
|
0 commit comments