From 8f03f02888e7944c4673f9b1d9081a130325364c Mon Sep 17 00:00:00 2001 From: piz Date: Thu, 24 Apr 2025 05:47:52 +0000 Subject: [PATCH 1/2] backport to 2.7 test --- torchprime/launcher/Dockerfile | 2 +- torchprime/torch_xla_models/offloading.py | 4 ++-- torchprime/torch_xla_models/remat_all.py | 4 ++-- torchprime/torch_xla_models/train.py | 9 ++++++++- 4 files changed, 13 insertions(+), 6 deletions(-) diff --git a/torchprime/launcher/Dockerfile b/torchprime/launcher/Dockerfile index a1107ba5..6c60c842 100644 --- a/torchprime/launcher/Dockerfile +++ b/torchprime/launcher/Dockerfile @@ -1,6 +1,6 @@ # syntax=docker/dockerfile:experimental # Use torch_xla Python 3.10 as the base image -FROM us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.10_tpuvm_20250410 +FROM us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.7.0_3.10_tpuvm ARG USE_TRANSFORMERS=false ARG USE_LOCAL_WHEEL=false diff --git a/torchprime/torch_xla_models/offloading.py b/torchprime/torch_xla_models/offloading.py index ccef61cf..56b33850 100644 --- a/torchprime/torch_xla_models/offloading.py +++ b/torchprime/torch_xla_models/offloading.py @@ -44,7 +44,7 @@ def remat_all_and_offload_these_inputs( *, num_fwd_outputs, names_to_offload: Sequence[str], - static_lifetime_input_indices=None, + # static_lifetime_input_indices=None, ): """Partition the graph to rematerialize forward activations and offload forward inputs to host. @@ -72,7 +72,7 @@ def remat_all_and_offload_these_inputs( joint_module, _joint_inputs, num_fwd_outputs=num_fwd_outputs, - static_lifetime_input_indices=static_lifetime_input_indices, + # static_lifetime_input_indices=static_lifetime_input_indices, ) with torch.device(input_device): fw_example_args = _make_arguments(fwd) diff --git a/torchprime/torch_xla_models/remat_all.py b/torchprime/torch_xla_models/remat_all.py index 94266416..4f303134 100644 --- a/torchprime/torch_xla_models/remat_all.py +++ b/torchprime/torch_xla_models/remat_all.py @@ -9,7 +9,7 @@ def remat_all_partition_fn( _joint_inputs, *, num_fwd_outputs, - static_lifetime_input_indices=None, + # static_lifetime_input_indices=None, ): """ remat_all_partition_fn is a graph partition function that closely matches the @@ -30,7 +30,7 @@ def remat_all_partition_fn( joint_module, _joint_inputs, num_fwd_outputs=num_fwd_outputs, - static_lifetime_input_indices=static_lifetime_input_indices, + # static_lifetime_input_indices=static_lifetime_input_indices, ) diff --git a/torchprime/torch_xla_models/train.py b/torchprime/torch_xla_models/train.py index 944dbaa9..2ab9eaba 100644 --- a/torchprime/torch_xla_models/train.py +++ b/torchprime/torch_xla_models/train.py @@ -19,9 +19,9 @@ import transformers from datasets import load_dataset from omegaconf import DictConfig, OmegaConf +from packaging import version from torch import nn from torch.utils.data import DataLoader, Dataset, IterableDataset -from torch_xla._internal.jax_workarounds import jax_env_context from torch_xla.distributed.fsdp import checkpoint_module from torch_xla.distributed.spmd.xla_sharding import apply_xla_patch_to_nn_linear from transformers import ( @@ -44,6 +44,13 @@ from torchprime.torch_xla_models import offloading, remat_all, scan_layers from torchprime.torch_xla_models.topology import get_mesh, is_1d_sharding +if version.parse(torch_xla.__version__.split("+")[0]) >= version.parse("2.8.0"): + from torch_xla._internal.jax_workarounds import jax_env_context +else: + from torch_xla.experimental.custom_kernel import _jax_env_context as jax_env_context + + + check_min_version("4.39.3") logger = logging.getLogger(__name__) From d2c06af74c12e93e2f2bc27b863d3fb5fd842b35 Mon Sep 17 00:00:00 2001 From: piz Date: Thu, 24 Apr 2025 16:32:36 +0000 Subject: [PATCH 2/2] reinstall torchax --- torchprime/launcher/Dockerfile | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/torchprime/launcher/Dockerfile b/torchprime/launcher/Dockerfile index 6c60c842..232739ca 100644 --- a/torchprime/launcher/Dockerfile +++ b/torchprime/launcher/Dockerfile @@ -24,6 +24,15 @@ RUN update-alternatives --install /usr/bin/python3 python3 /usr/local/bin/python WORKDIR /workspaces +# Install torchax +RUN git clone --depth 1 https://github.com/pytorch/xla.git +WORKDIR /workspaces/xla/torchax +RUN pip install torch_xla[pallas] \ + -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \ + -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html +RUN pip install -e . + + # Install torchprime # Optimization: we rerun `pip install -e .` only if `pyproject.toml` changes. # Copy only the installation-related files first to make Docker cache them separately.