From 6fd621f1de66ab740b0585733b1dc117e90dfbfa Mon Sep 17 00:00:00 2001
From: Samuel Marks <807580+SamuelMarks@users.noreply.github.com>
Date: Thu, 17 Apr 2025 11:02:33 -0600
Subject: [PATCH 1/3] [llama4 -> llama4_jax] Refactor to be a proper
 installable Python package ; [llama4_jax/pyproject.toml] Add missing
 dependency ; [llama4_jax/README.md] Document new usage

---
 {llama4 => llama4_jax}/.gitignore             |  2 +-
 {llama4 => llama4_jax}/README.md              |  8 +++---
 llama4_jax/llama4_jax/__init__.py             |  0
 .../llama4_jax/__main__.py                    |  1 -
 .../llama4_jax/chkpt_utils.py                 |  3 +-
 .../llama4_jax/decode_ragged_dot.py           |  0
 {llama4 => llama4_jax}/llama4_jax/model.py    |  7 +++--
 .../llama4_jax/ragged_attention.py            |  0
 {llama4 => llama4_jax}/pyproject.toml         |  3 +-
 llama4_jax/scripts/__init__.py                |  0
 .../scripts/convert_weights.py                | 28 ++++++++-----------
 .../scripts/download_model.py                 |  9 +++---
 .../scripts/quantize_model.py                 | 12 +++-----
 13 files changed, 32 insertions(+), 41 deletions(-)
 rename {llama4 => llama4_jax}/.gitignore (94%)
 rename {llama4 => llama4_jax}/README.md (98%)
 create mode 100644 llama4_jax/llama4_jax/__init__.py
 rename llama4/main.py => llama4_jax/llama4_jax/__main__.py (99%)
 rename {llama4 => llama4_jax}/llama4_jax/chkpt_utils.py (99%)
 rename {llama4 => llama4_jax}/llama4_jax/decode_ragged_dot.py (100%)
 rename {llama4 => llama4_jax}/llama4_jax/model.py (99%)
 rename {llama4 => llama4_jax}/llama4_jax/ragged_attention.py (100%)
 rename {llama4 => llama4_jax}/pyproject.toml (86%)
 create mode 100644 llama4_jax/scripts/__init__.py
 rename {llama4 => llama4_jax}/scripts/convert_weights.py (71%)
 rename {llama4 => llama4_jax}/scripts/download_model.py (77%)
 rename {llama4 => llama4_jax}/scripts/quantize_model.py (68%)

diff --git a/llama4/.gitignore b/llama4_jax/.gitignore
similarity index 94%
rename from llama4/.gitignore
rename to llama4_jax/.gitignore
index 3afd13a..630d63b 100644
--- a/llama4/.gitignore
+++ b/llama4_jax/.gitignore
@@ -11,4 +11,4 @@ __pycache__/
 build/**
 
 .venv
-.vscode
\ No newline at end of file
+.vscode
diff --git a/llama4/README.md b/llama4_jax/README.md
similarity index 98%
rename from llama4/README.md
rename to llama4_jax/README.md
index 53e2e9c..5e57309 100644
--- a/llama4/README.md
+++ b/llama4_jax/README.md
@@ -11,8 +11,8 @@
 This is a pure JAX implementation of Llama 4 inference, including a checkpoint
 converter for the weights. It currently runs on TPU. Support for GPU is
 in-progress.
 The entire model is defined in [model.py](llama4_jax/model.py) and invoked
-via [main.py](main.py). Among other things, the model code demonstrates:
+via `python3 -m llama4_jax`. Among other things, the model code demonstrates:
 * an MLA attention implementation;
 * expert and tensor-parallelism via JAX's
   [`shard_map`](https://docs.jax.dev/en/latest/sharded-computation.html#manual-parallelism-with-shard-map)
@@ -47,12 +47,12 @@ the full model. We've tested on v5e-64.
 
 Run on all hosts in the TPU cluster:
 ```
-$ python3 main.py
+$ python3 -m llama4_jax
 ```
 e.g. for Cloud TPU:
 ```
 $ gcloud compute tpus tpu-vm ssh {TPU_NAME} --worker=all \
-  --command="cd ~/llama4_jax && python3 main.py"
+  --command="cd ~/llama4_jax && python3 -m llama4_jax"
 ```
 
 Responses:
diff --git a/llama4_jax/llama4_jax/__init__.py b/llama4_jax/llama4_jax/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/llama4/main.py b/llama4_jax/llama4_jax/__main__.py
similarity index 99%
rename from llama4/main.py
rename to llama4_jax/llama4_jax/__main__.py
index aa61675..9e07e8b 100644
--- a/llama4/main.py
+++ b/llama4_jax/llama4_jax/__main__.py
@@ -15,7 +15,6 @@
 import dataclasses
 from etils import epath
 import json
-from pprint import pformat
 
 import jax
 from jax import numpy as jnp
diff --git a/llama4/llama4_jax/chkpt_utils.py b/llama4_jax/llama4_jax/chkpt_utils.py
similarity index 99%
rename from llama4/llama4_jax/chkpt_utils.py
rename to llama4_jax/llama4_jax/chkpt_utils.py
index 3370ff7..d79f85f 100644
--- a/llama4/llama4_jax/chkpt_utils.py
+++ b/llama4_jax/llama4_jax/chkpt_utils.py
@@ -18,7 +18,6 @@
 from concurrent.futures import ThreadPoolExecutor
 from pathlib import Path
 import dataclasses
-from typing import Any
 
 import jax
 from jax import numpy as jnp
@@ -26,7 +25,7 @@
 import torch
 from tqdm import tqdm
 
-from . import model as l4jax
+from llama4_jax import model as l4jax
 
 
 def quantize_model(ckpt_path: Path, quant_ckpt_path: Path):
diff --git a/llama4/llama4_jax/decode_ragged_dot.py b/llama4_jax/llama4_jax/decode_ragged_dot.py
similarity index 100%
rename from llama4/llama4_jax/decode_ragged_dot.py
rename to llama4_jax/llama4_jax/decode_ragged_dot.py
diff --git a/llama4/llama4_jax/model.py b/llama4_jax/llama4_jax/model.py
similarity index 99%
rename from llama4/llama4_jax/model.py
rename to llama4_jax/llama4_jax/model.py
index e163aa8..9cb941c 100644
--- a/llama4/llama4_jax/model.py
+++ b/llama4_jax/llama4_jax/model.py
@@ -32,10 +32,11 @@
 from jax.experimental.shard import auto_axes, reshard
 from etils import epath
 
-from . import ragged_attention
-from .decode_ragged_dot import decode_ragged_dot
+from llama4_jax import ragged_attention
+from llama4_jax.decode_ragged_dot import decode_ragged_dot
 
-map, builtin_map = jax.util.safe_map, map
+builtin_map = map
+map = jax.util.safe_map
 
 AxisName = str | tuple[str, ...] | None
 Axes = tuple[AxisName, ...]
diff --git a/llama4/llama4_jax/ragged_attention.py b/llama4_jax/llama4_jax/ragged_attention.py
similarity index 100%
rename from llama4/llama4_jax/ragged_attention.py
rename to llama4_jax/llama4_jax/ragged_attention.py
diff --git a/llama4/pyproject.toml b/llama4_jax/pyproject.toml
similarity index 86%
rename from llama4/pyproject.toml
rename to llama4_jax/pyproject.toml
index 1a0a29e..66f5e5b 100644
--- a/llama4/pyproject.toml
+++ b/llama4_jax/pyproject.toml
@@ -1,7 +1,7 @@
 [project]
 name = "llama4_jax"
 version = "0.1.0"
-description = ""
+description = "Pure JAX implementation of Llama 4 inference, including a checkpoint converter for the weights."
authors = [ { name = "Robert Dyro" }, ] @@ -13,6 +13,7 @@ dependencies = [ "jax", "torch", "transformers", # for the model config and the tokenizer + "tune_jax", "tqdm", "numpy", "orbax-checkpoint", diff --git a/llama4_jax/scripts/__init__.py b/llama4_jax/scripts/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/llama4/scripts/convert_weights.py b/llama4_jax/scripts/convert_weights.py similarity index 71% rename from llama4/scripts/convert_weights.py rename to llama4_jax/scripts/convert_weights.py index 2deb98f..9cd70d6 100644 --- a/llama4/scripts/convert_weights.py +++ b/llama4_jax/scripts/convert_weights.py @@ -1,31 +1,24 @@ #!/usr/bin/env python3 -import sys from pathlib import Path -from pprint import pprint from argparse import ArgumentParser import dataclasses +import os.path import shutil -try: - from llama4_jax import model as l4jax - from llama4_jax import chkpt_utils as utils -except ImportError: - sys.path.append(str(Path(__file__).parent.absolute())) - - from llama4_jax import model as l4jax - from llama4_jax import chkpt_utils as utils - from transformers import AutoConfig from safetensors import safe_open from tqdm import tqdm +from llama4_jax import model as l4jax +from llama4_jax import chkpt_utils as utils + def main(model_path: str | Path, ckpt_path: str | Path): model_path, ckpt_path = Path(model_path).expanduser(), Path(ckpt_path).expanduser() - files = list(model_path.glob("**/*safetensors")) + files = list(model_path.glob(os.path.join("**", "*safetensors"))) assert len(files) > 1 - config_files = list(model_path.glob("**/config.json")) + config_files = list(model_path.glob(os.path.join("**", "config.json"))) assert len(config_files) == 1, "Must have only one `config.json` file in the model path" config = AutoConfig.from_pretrained(config_files[0]).text_config cfg = l4jax.hf_to_jax_config(config) @@ -44,9 +37,9 @@ def main(model_path: str | Path, ckpt_path: str | Path): additional_files = ["config.json", "tokenizer.json", "tokenizer_config.json"] for additional_file in additional_files: - full_paths = list(model_path.glob(f"**/{additional_file}")) + full_paths = list(model_path.glob(os.path.join("**", additional_file))) if len(full_paths) != 1: - print(f"Found more than 1 file for {additional_file}") + print("Found more than 1 file for", additional_file) if len(full_paths) == 0: continue full_path = full_paths[0] @@ -56,11 +49,12 @@ def main(model_path: str | Path, ckpt_path: str | Path): if __name__ == "__main__": parser = ArgumentParser() parser.add_argument( - "--source-path", default="~/DeepSeek-R1-Distill-Llama-70B", required=True, help="HF model directory path" + "--source-path", default=os.path.join(os.path.expanduser("~"), "DeepSeek-R1-Distill-Llama-70B"), + required=True, help="HF model directory path" ) parser.add_argument( "--dest-path", - default="~/DeepSeek-R1-Distill-Llama-3.1-70B-Instruct", + default=os.path.join(os.path.expanduser("~"), "DeepSeek-R1-Distill-Llama-3.1-70B-Instruct"), required=True, help="JAX model model directory (to be created).", ) diff --git a/llama4/scripts/download_model.py b/llama4_jax/scripts/download_model.py similarity index 77% rename from llama4/scripts/download_model.py rename to llama4_jax/scripts/download_model.py index 7b921e2..b795c39 100644 --- a/llama4/scripts/download_model.py +++ b/llama4_jax/scripts/download_model.py @@ -1,5 +1,5 @@ #!/usr/bin/env python3 - +import os from argparse import ArgumentParser from pathlib import Path @@ -12,19 +12,20 @@ def main(model_id: str, dest_root_path: str | 
Path):
     local_dir = Path(dest_root_path).expanduser().absolute() / str(model_id).replace("/", "--")
     snapshot_download(repo_id=model_id, local_dir=local_dir)
 
 
 if __name__ == "__main__":
     parser = ArgumentParser()
     parser.add_argument(
-        "--model-id", required=True, help=f"HuggingFace model / repo id. Examples include: {example_models}"
+        "--model-id", required=True,
+        help=f"HuggingFace model / repo id. Examples include: {example_models}"
     )
     parser.add_argument(
         "--dest-root-path",
         required=True,
-        default="~/",
+        default=os.path.join(os.path.expanduser("~"), ""),
         help="Destination root directory, the model will be saved into its own directory.",
     )
     args = parser.parse_args()
diff --git a/llama4/scripts/quantize_model.py b/llama4_jax/scripts/quantize_model.py
similarity index 68%
rename from llama4/scripts/quantize_model.py
rename to llama4_jax/scripts/quantize_model.py
index 4fc0b2b..1fd23cc 100644
--- a/llama4/scripts/quantize_model.py
+++ b/llama4_jax/scripts/quantize_model.py
@@ -1,15 +1,10 @@
 #!/usr/bin/env python3
 
-import sys
+import os.path
 from pathlib import Path
 from argparse import ArgumentParser
 
-try:
-    from llama4_jax import chkpt_utils as utils
-except ImportError:
-    sys.path.append(str(Path(__file__).parent.absolute()))
-
-    from llama4_jax import chkpt_utils as utils
+from llama4_jax import chkpt_utils as utils
 
 
 def main(path: str | Path, suffix: str):
@@ -21,7 +16,8 @@
 if __name__ == "__main__":
     parser = ArgumentParser()
     parser.add_argument(
-        "--path", default="~/DeepSeek-R1-Distill-Llama-70B", required=True, help="Existing JAX model checkpoint path"
+        "--path", default=os.path.join(os.path.expanduser("~"), "DeepSeek-R1-Distill-Llama-70B"),
+        required=True, help="Existing JAX model checkpoint path"
     )
     parser.add_argument(
         "--suffix",

From d5c936a29b3fe134338144dedeecf66a9215d34a Mon Sep 17 00:00:00 2001
From: Samuel Marks <807580+SamuelMarks@users.noreply.github.com>
Date: Mon, 23 Jun 2025 20:56:13 -0700
Subject: [PATCH 2/3] [*.py] Use `os.path.join` / `Path` ; [**/pyproject.toml]
 Add `huggingface-hub` to deps ; sort deps

---
 deepseek_r1_jax/deepseek_r1_jax/chkpt_utils.py |  3 +++
 .../deepseek_r1_jax/decode_ragged_dot.py       |  1 +
 deepseek_r1_jax/deepseek_r1_jax/model.py       |  3 ++-
 .../scripts/get_params_metadata_hf.py          |  4 +++-
 llama3/main.py                                 |  6 ++++--
 llama3/pyproject.toml                          | 13 +++++++------
 llama3/scripts/convert_weights.py              |  5 +++--
 llama3/scripts/download_model.py               |  3 ++-
 llama4_jax/llama4_jax/__main__.py              |  6 ++++--
 llama4_jax/llama4_jax/chkpt_utils.py           |  2 ++
 llama4_jax/llama4_jax/model.py                 |  1 +
 llama4_jax/llama4_jax/ragged_attention.py      |  1 +
 llama4_jax/pyproject.toml                      | 13 +++++++------
 llama4_jax/scripts/convert_weights.py          | 16 ++++++----------
 qwen3/main.py                                  |  6 ++++--
 qwen3/pyproject.toml                           | 13 +++++++------
 qwen3/scripts/convert_weights.py               | 12 ++++++------
 qwen3/scripts/quantize_model.py                |  2 +-
 18 files changed, 64 insertions(+), 46 deletions(-)

diff --git a/deepseek_r1_jax/deepseek_r1_jax/chkpt_utils.py b/deepseek_r1_jax/deepseek_r1_jax/chkpt_utils.py
index 701e8fa..0c96b0e 100644
--- a/deepseek_r1_jax/deepseek_r1_jax/chkpt_utils.py
+++ b/deepseek_r1_jax/deepseek_r1_jax/chkpt_utils.py
@@ -21,6 +21,7 @@
 import re
 from concurrent.futures import ThreadPoolExecutor
 import math
+
 from etils import epath
 
 import jax
@@ -28,8 +29,10 @@
 from jax import numpy as jnp
 from jax.sharding import SingleDeviceSharding, NamedSharding, PartitionSpec as P
 import numpy as np
+
 import torch
 import torch.utils.dlpack
+
 from safetensors.torch import load_file
 
 from .third_party import modeling_deepseek as deepseek
diff --git a/deepseek_r1_jax/deepseek_r1_jax/decode_ragged_dot.py b/deepseek_r1_jax/deepseek_r1_jax/decode_ragged_dot.py
index 837ff75..895be1b 100644
--- a/deepseek_r1_jax/deepseek_r1_jax/decode_ragged_dot.py
+++ b/deepseek_r1_jax/deepseek_r1_jax/decode_ragged_dot.py
@@ -21,6 +21,7 @@
 from jax import random
 from jax.experimental import pallas as pl
 from jax.experimental.pallas import tpu as pltpu
+
 from tqdm import tqdm
 
 
diff --git a/deepseek_r1_jax/deepseek_r1_jax/model.py b/deepseek_r1_jax/deepseek_r1_jax/model.py
index 699eb16..1015378 100644
--- a/deepseek_r1_jax/deepseek_r1_jax/model.py
+++ b/deepseek_r1_jax/deepseek_r1_jax/model.py
@@ -20,11 +20,12 @@
 from functools import partial
 from typing import Callable
 import tempfile
-from etils import epath
 import gzip
 import json
 from pathlib import Path
 
+from etils import epath
+
 import jax
 import jax.numpy as jnp
 from jax import random
diff --git a/deepseek_r1_jax/scripts/get_params_metadata_hf.py b/deepseek_r1_jax/scripts/get_params_metadata_hf.py
index d2ff0a2..5cf0244 100644
--- a/deepseek_r1_jax/scripts/get_params_metadata_hf.py
+++ b/deepseek_r1_jax/scripts/get_params_metadata_hf.py
@@ -13,8 +13,10 @@
 # limitations under the License.
 
 import json
-import safetensors
 from pathlib import Path
+
+import safetensors
+
 from tqdm import tqdm
 
 if __name__ == "__main__":
diff --git a/llama3/main.py b/llama3/main.py
index 9b2cf4c..fdd9d15 100644
--- a/llama3/main.py
+++ b/llama3/main.py
@@ -13,6 +13,8 @@
 # limitations under the License.
 
 import dataclasses
+import os
+
 from etils import epath
 import json
 from pprint import pprint
@@ -41,9 +43,9 @@ def encode_input(tokenizer, texts: list[str], model_name: str, pad_id: int = 0):
 
 if __name__ == "__main__":
     jax.distributed.initialize()
-    quant = True
+    quant = os.environ.get("QUANT") in (None, '1', 't', 'T', 'True', 'true', 'TRUE')
 
-    ckpt_path = epath.Path("~/bucket/DeepSeek-R1-Distill-Llama-3.1-70B-Instruct").expanduser()
+    ckpt_path = epath.Path("~").expanduser() / "bucket" / "DeepSeek-R1-Distill-Llama-3.1-70B-Instruct"
     if quant:
         ckpt_path = ckpt_path.parent / f"{ckpt_path.name}-quant"
     tokenizer = l3jax.load_tokenizer(ckpt_path / "tokenizer.json", ckpt_path / "tokenizer_config.json")
diff --git a/llama3/pyproject.toml b/llama3/pyproject.toml
index 100b781..3d97322 100644
--- a/llama3/pyproject.toml
+++ b/llama3/pyproject.toml
@@ -10,15 +10,16 @@
 requires-python = ">=3.10"
 license = { text = "Apache-2.0" }
 dependencies = [
+    "datasets",
+    "etils",
+    "gcsfs",
+    "huggingface-hub",
     "jax",
-    "torch",
-    "transformers", # for the model config and the tokenizer
-    "tqdm",
     "numpy",
     "orbax-checkpoint",
-    "datasets",
-    "gcsfs",
-    "etils",
+    "torch",
+    "tqdm",
+    "transformers", # for the model config and the tokenizer
 ]
 
 # we don't need CUDA torch
diff --git a/llama3/scripts/convert_weights.py b/llama3/scripts/convert_weights.py
index 2912c22..5c71e59 100644
--- a/llama3/scripts/convert_weights.py
+++ b/llama3/scripts/convert_weights.py
@@ -5,6 +5,7 @@
 from argparse import ArgumentParser
 import dataclasses
 import shutil
+import os.path
 
 def main(model_path: str | Path, ckpt_path: str | Path):
     try:
@@ -21,9 +22,9 @@ def main(model_path: str | Path, ckpt_path: str | Path):
     from tqdm import tqdm
 
     model_path, ckpt_path = Path(model_path).expanduser(), Path(ckpt_path).expanduser()
-    files = list(model_path.glob("**/*safetensors"))
+    files = list(model_path.glob(os.path.join("**", "*safetensors")))
     assert len(files) > 1
-    config_files = list(model_path.glob("**/config.json"))
+    config_files = list(model_path.glob(os.path.join("**", "config.json")))
     assert len(config_files) == 1, "Must have only one `config.json` file in the model path"
     config = AutoConfig.from_pretrained(config_files[0])
     cfg = l3jax.llama_to_jax_config(config)
diff --git a/llama3/scripts/download_model.py b/llama3/scripts/download_model.py
index b609d63..eb6f41e 100644
--- a/llama3/scripts/download_model.py
+++ b/llama3/scripts/download_model.py
@@ -1,5 +1,6 @@
 #!/usr/bin/env python3
 
+import os.path
 from argparse import ArgumentParser
 from pathlib import Path
 
@@ -27,7 +28,7 @@ def main(model_id: str, dest_root_path: str | Path):
     parser.add_argument(
         "--dest-root-path",
         required=True,
-        default="~/",
+        default=os.path.join(os.path.expanduser("~"), ""),
         help="Destination root directory, the model will be saved into its own directory.",
     )
     args = parser.parse_args()
diff --git a/llama4_jax/llama4_jax/__main__.py b/llama4_jax/llama4_jax/__main__.py
index 9e07e8b..009a0f1 100644
--- a/llama4_jax/llama4_jax/__main__.py
+++ b/llama4_jax/llama4_jax/__main__.py
@@ -13,6 +13,8 @@
 # limitations under the License.
 
 import dataclasses
+import os
+
 from etils import epath
 import json
 
@@ -39,9 +41,9 @@ def encode_input(tokenizer, texts, pad_id: int = 0):
 
 if __name__ == "__main__":
     jax.distributed.initialize()
-    quant = True
+    quant = os.environ.get("QUANT") in (None, '1', 't', 'T', 'True', 'true', 'TRUE')
 
-    ckpt_path = epath.Path("~/bucket/Llama-4-Scout-Instruct").expanduser()
+    ckpt_path = epath.Path("~").expanduser() / "bucket" / "Llama-4-Scout-Instruct"
     if quant:
         ckpt_path = ckpt_path.parent / f"{ckpt_path.name}-quant"
     tokenizer = l4jax.load_tokenizer(ckpt_path / "tokenizer.json", ckpt_path / "tokenizer_config.json")
diff --git a/llama4_jax/llama4_jax/chkpt_utils.py b/llama4_jax/llama4_jax/chkpt_utils.py
index d79f85f..4507af7 100644
--- a/llama4_jax/llama4_jax/chkpt_utils.py
+++ b/llama4_jax/llama4_jax/chkpt_utils.py
@@ -22,7 +22,9 @@
 import jax
 from jax import numpy as jnp
 from jax.sharding import PartitionSpec as P
+
 import torch
+
 from tqdm import tqdm
 
 from llama4_jax import model as l4jax
diff --git a/llama4_jax/llama4_jax/model.py b/llama4_jax/llama4_jax/model.py
index 2a4b8f8..27ca52b 100644
--- a/llama4_jax/llama4_jax/model.py
+++ b/llama4_jax/llama4_jax/model.py
@@ -30,6 +30,7 @@
 from jax.experimental.shard_map import shard_map
 from jax.sharding import PartitionSpec as P, use_mesh
 from jax.experimental.shard import auto_axes, reshard
+
 from etils import epath
 
 from . 
import ragged_attention diff --git a/llama4_jax/llama4_jax/ragged_attention.py b/llama4_jax/llama4_jax/ragged_attention.py index 347c346..b2b0556 100644 --- a/llama4_jax/llama4_jax/ragged_attention.py +++ b/llama4_jax/llama4_jax/ragged_attention.py @@ -10,6 +10,7 @@ from jax.experimental.pallas import tpu as pltpu from jax.experimental.shard_map import shard_map from jax.sharding import Mesh, PartitionSpec as P, NamedSharding + import numpy as np NUM_LANES = 128 diff --git a/llama4_jax/pyproject.toml b/llama4_jax/pyproject.toml index 14718ba..28b4526 100644 --- a/llama4_jax/pyproject.toml +++ b/llama4_jax/pyproject.toml @@ -10,16 +10,17 @@ requires-python = ">=3.10" license = { text = "Apache-2.0" } dependencies = [ + "datasets", + "etils", + "gcsfs", + "huggingface-hub", "jax", + "numpy", + "orbax-checkpoint", "torch", + "tqdm", "transformers", # for the model config and the tokenizer "tune_jax", - "tqdm", - "numpy", - "orbax-checkpoint", - "datasets", - "gcsfs", - "etils", ] # we don't need CUDA torch diff --git a/llama4_jax/scripts/convert_weights.py b/llama4_jax/scripts/convert_weights.py index a5db741..b6b0ee8 100644 --- a/llama4_jax/scripts/convert_weights.py +++ b/llama4_jax/scripts/convert_weights.py @@ -1,22 +1,15 @@ #!/usr/bin/env python3 -import sys from pathlib import Path from argparse import ArgumentParser import dataclasses import shutil +from llama4_jax import model as l4jax +from llama4_jax import chkpt_utils as utils -def main(model_path: str | Path, ckpt_path: str | Path): - try: - from llama4_jax import model as l4jax - from llama4_jax import chkpt_utils as utils - except ImportError: - sys.path.append(str(Path(__file__).parents[1].absolute())) - - from llama4_jax import model as l4jax - from llama4_jax import chkpt_utils as utils +def main(model_path: str | Path, ckpt_path: str | Path): from transformers import AutoConfig from safetensors import safe_open from tqdm import tqdm @@ -44,6 +37,9 @@ def main(model_path: str | Path, ckpt_path: str | Path): additional_files = ["config.json", "tokenizer.json", "tokenizer_config.json"] for additional_file in additional_files: full_paths = list(model_path.glob(f"**/{additional_file}")) + PurePath('**/').full_match('*.py') + + full_paths = list(model_path.glob(Path("~").expanduser() / "DeepSeek-R1-Distill-Llama-70B")) if len(full_paths) != 1: print(f"Found more than 1 file for {additional_file}") if len(full_paths) == 0: diff --git a/qwen3/main.py b/qwen3/main.py index b768016..485ce14 100644 --- a/qwen3/main.py +++ b/qwen3/main.py @@ -13,6 +13,8 @@ # limitations under the License. 
 import dataclasses
+import os
+
 from etils import epath
 import json
 
@@ -38,9 +40,9 @@ def encode_input(tokenizer, texts, pad_id: int = 0):
 
 if __name__ == "__main__":
     #jax.distributed.initialize()  # if you want to run multi-host
-    quant = True
+    quant = os.environ.get("QUANT") in (None, '1', 't', 'T', 'True', 'true', 'TRUE')
 
-    ckpt_path = epath.Path("~/bucket/qwen3_jax/Qwen3-30B-A3B").expanduser()
+    ckpt_path = epath.Path("~").expanduser() / "bucket" / "qwen3_jax" / "Qwen3-30B-A3B"
     if quant:
         ckpt_path = ckpt_path.parent / f"{ckpt_path.name}-quant"
     tokenizer = q3jax.load_tokenizer(ckpt_path / "tokenizer.json", ckpt_path / "tokenizer_config.json")
diff --git a/qwen3/pyproject.toml b/qwen3/pyproject.toml
index d7fec1f..86abad0 100644
--- a/qwen3/pyproject.toml
+++ b/qwen3/pyproject.toml
@@ -10,15 +10,16 @@
 requires-python = ">=3.10"
 license = { text = "Apache-2.0" }
 dependencies = [
+    "datasets",
+    "etils",
+    "gcsfs",
+    "huggingface-hub",
     "jax",
-    "torch",
-    "transformers", # for the model config and the tokenizer
-    "tqdm",
     "numpy",
     "orbax-checkpoint",
-    "datasets",
-    "gcsfs",
-    "etils",
+    "torch",
+    "tqdm",
+    "transformers", # for the model config and the tokenizer
 ]
 
 # we don't need CUDA torch
diff --git a/qwen3/scripts/convert_weights.py b/qwen3/scripts/convert_weights.py
index a72b87c..2f03a3b 100644
--- a/qwen3/scripts/convert_weights.py
+++ b/qwen3/scripts/convert_weights.py
@@ -2,7 +2,7 @@
 
 import sys
 from pathlib import Path
-from pprint import pprint
+import os.path
 from argparse import ArgumentParser
 import dataclasses
 import shutil
@@ -22,9 +22,9 @@ def main(model_path: str | Path, ckpt_path: str | Path):
     from tqdm import tqdm
 
     model_path, ckpt_path = Path(model_path).expanduser(), Path(ckpt_path).expanduser()
-    files = list(model_path.glob("**/*safetensors"))
+    files = list(model_path.glob(os.path.join("**", "*safetensors")))
     assert len(files) > 1
-    config_files = list(model_path.glob("**/config.json"))
+    config_files = list(model_path.glob(os.path.join("**", "config.json")))
     assert len(config_files) == 1, "Must have only one `config.json` file in the model path"
     config = AutoConfig.from_pretrained(config_files[0])
     cfg = q3jax.hf_to_jax_config(config)
@@ -43,7 +43,7 @@ def main(model_path: str | Path, ckpt_path: str | Path):
 
     additional_files = ["config.json", "tokenizer.json", "tokenizer_config.json"]
     for additional_file in additional_files:
-        full_paths = list(model_path.glob(f"**/{additional_file}"))
+        full_paths = list(model_path.glob(os.path.join("**", additional_file)))
         if len(full_paths) != 1:
             print(f"Found more than 1 file for {additional_file}")
         if len(full_paths) == 0:
@@ -55,11 +55,11 @@ def main(model_path: str | Path, ckpt_path: str | Path):
 
 if __name__ == "__main__":
     parser = ArgumentParser()
     parser.add_argument(
-        "--source-path", default="~/Qwen3-30B-A3B", required=True, help="HF model directory path"
+        "--source-path", default=Path("~").expanduser() / "Qwen3-30B-A3B", required=True, help="HF model directory path"
     )
     parser.add_argument(
         "--dest-path",
-        default="~/qwen3_jax/Qwen3-30B-A3B",
+        default=Path("~").expanduser() / "qwen3_jax" / "Qwen3-30B-A3B",
         required=True,
         help="JAX model model directory (to be created).",
     )
diff --git a/qwen3/scripts/quantize_model.py b/qwen3/scripts/quantize_model.py
index 4b0a230..1594276 100644
--- a/qwen3/scripts/quantize_model.py
+++ b/qwen3/scripts/quantize_model.py
@@ -21,7 +21,7 @@ def main(path: str | Path, suffix: str):
 if __name__ == "__main__":
     parser = ArgumentParser()
     parser.add_argument(
-        "--path", default="~/Qwen3-30B-A3B", required=True, help="Existing JAX model checkpoint path"
+        "--path", default=Path("~").expanduser() / "Qwen3-30B-A3B", required=True, help="Existing JAX model checkpoint path"
     )
     parser.add_argument(
         "--suffix",

From 3812791e6375fce35340af6dd869037689c2fed3 Mon Sep 17 00:00:00 2001
From: Samuel Marks <807580+SamuelMarks@users.noreply.github.com>
Date: Fri, 27 Jun 2025 13:11:17 -0700
Subject: [PATCH 3/3] [llama4_jax/scripts/convert_weights.py] Remove debug code

---
 llama4_jax/scripts/convert_weights.py | 3 ---
 1 file changed, 3 deletions(-)

diff --git a/llama4_jax/scripts/convert_weights.py b/llama4_jax/scripts/convert_weights.py
index b6b0ee8..2ea8c94 100644
--- a/llama4_jax/scripts/convert_weights.py
+++ b/llama4_jax/scripts/convert_weights.py
@@ -36,9 +36,6 @@ def main(model_path: str | Path, ckpt_path: str | Path):
     additional_files = ["config.json", "tokenizer.json", "tokenizer_config.json"]
     for additional_file in additional_files:
         full_paths = list(model_path.glob(f"**/{additional_file}"))
-        PurePath('**/').full_match('*.py')
-
-        full_paths = list(model_path.glob(Path("~").expanduser() / "DeepSeek-R1-Distill-Llama-70B"))
         if len(full_paths) != 1:
             print(f"Found more than 1 file for {additional_file}")
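
Note: with the debug lines dropped by PATCH 3/3 as above, the additional-file
discovery loop in `llama4_jax/scripts/convert_weights.py` is intended to end up
as the minimal loop sketched below. This is a sketch of the final state only;
the copy step that follows `full_path = full_paths[0]` sits in unchanged diff
context and is elided here. The removed debug call could never have worked
anyway, since pathlib's `glob()` does not accept absolute patterns such as an
expanded home-directory path.

```python
additional_files = ["config.json", "tokenizer.json", "tokenizer_config.json"]
for additional_file in additional_files:
    # Search the whole HF snapshot tree for each auxiliary file.
    full_paths = list(model_path.glob(f"**/{additional_file}"))
    if len(full_paths) != 1:
        print(f"Found more than 1 file for {additional_file}")
    if len(full_paths) == 0:
        continue  # the file is absent from this snapshot; skip it
    full_path = full_paths[0]
    # ... copy full_path into ckpt_path (unchanged context, elided above) ...
```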