From 4d8ad5f7cbd4e9a9cc5a312c90384af70de62bb7 Mon Sep 17 00:00:00 2001 From: Jack Oh Date: Tue, 8 Jul 2025 00:38:02 +0000 Subject: [PATCH 1/7] Experimenting with caching the processed dataset to GCS --- pyproject.toml | 3 +- torchprime/data/dataset.py | 53 +++++++++++++++++-- .../gcs_cache_loading/preprocess_dataset.py | 52 ++++++++++++++++++ torchprime/launcher/cli.py | 40 ++++++++++++++ .../configs/dataset/wikitext.yaml | 1 + 5 files changed, 143 insertions(+), 6 deletions(-) create mode 100644 torchprime/experimental/gcs_cache_loading/preprocess_dataset.py diff --git a/pyproject.toml b/pyproject.toml index 0f812f71..01bfbf94 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,7 +25,8 @@ dependencies = [ "benchmark-db-writer @ git+https://github.com/AI-Hypercomputer/aotc.git@2ff16e670df20b497ddaf1f86920dbb5dd9f0c8f#subdirectory=src/aotc/benchmark_db_writer", "dacite==1.9.2", "click~=8.1.8", - "google-cloud-storage==2.19.0" + "google-cloud-storage==2.19.0", + "gcsfs" ] [project.optional-dependencies] diff --git a/torchprime/data/dataset.py b/torchprime/data/dataset.py index ed9aacdd..5dcf65f5 100644 --- a/torchprime/data/dataset.py +++ b/torchprime/data/dataset.py @@ -1,11 +1,14 @@ """Utilities for preparing datasets for basic training tasks.""" import json +import logging import fsspec -from datasets import Dataset, DatasetDict, load_dataset +from datasets import Dataset, DatasetDict, load_dataset, load_from_disk from transformers.tokenization_utils import PreTrainedTokenizerBase +logger = logging.getLogger(__name__) + def _load_json_dataset(path: str, split: str) -> Dataset: """Load a dataset from a JSON Lines file. @@ -33,6 +36,8 @@ def _load_hf_dataset( config: str | None, split: str, cache_dir: str | None, + num_proc: int | None, + streaming: bool = False, ) -> Dataset: """Download and return a dataset from Hugging Face Hub. @@ -41,12 +46,21 @@ def _load_hf_dataset( config: Optional configuration name. split: Split to load. cache_dir: Directory where the dataset cache should live. + num_proc: Number of processes to use for dataset operations. + streaming: Whether to stream the dataset. Returns: The loaded ``Dataset`` instance for ``split``. """ - data = load_dataset(name, config, split=split, cache_dir=cache_dir) + data = load_dataset( + name, + config, + split=split, + cache_dir=cache_dir, + num_proc=num_proc, + streaming=streaming, + ) assert isinstance(data, Dataset | DatasetDict) if isinstance(data, DatasetDict): data = data[split] @@ -59,6 +73,8 @@ def load_hf_or_json_dataset( file_dataset_path: str | None = None, split: str = "train", cache_dir: str | None = None, + num_proc: int | None = None, + streaming: bool = False, ): """Loads a dataset either from Hugging Face Hub or a local/remote JSONL file. @@ -72,12 +88,21 @@ def load_hf_or_json_dataset( file_dataset_path: Optional path to a JSONL file (local or remote). split: Dataset split to load (default is "train"). cache_dir: Optional directory to use for dataset caching (HF only). + num_proc: Number of processes to use for dataset operations (HF only). + streaming: Whether to stream the dataset (HF only). Returns: A HuggingFace ``Dataset`` instance.
""" if hf_dataset_name: - data = _load_hf_dataset(hf_dataset_name, hf_dataset_config_name, split, cache_dir) + data = _load_hf_dataset( + hf_dataset_name, + hf_dataset_config_name, + split, + cache_dir, + num_proc, + streaming, + ) elif file_dataset_path: data = _load_json_dataset(file_dataset_path, split) else: @@ -89,6 +114,7 @@ def load_hf_or_json_dataset( def make_train_dataset( + cached_dataset_path: str | None = None, hf_dataset_name: str | None = None, hf_dataset_config_name: str | None = None, file_dataset_path: str | None = None, @@ -97,6 +123,9 @@ def make_train_dataset( *, tokenizer: PreTrainedTokenizerBase, block_size: int, + text_column: str = "text", + num_proc: int | None = None, + streaming: bool = False, ) -> Dataset: """Loads and tokenizes a dataset, then chunks it into fixed-size blocks for training. @@ -106,6 +135,7 @@ def make_train_dataset( for efficient language modeling, especially on accelerators like TPUs. Args: + cached_dataset_path: Optional path to a pre-processed, cached dataset. hf_dataset_name: Optional Hugging Face dataset name. (e.g., "wikitext"). hf_dataset_config_name: Optional HF dataset config name. (e.g., "wikitext-103-raw-v1"). file_dataset_path: Optional path or ``gs://`` URI to a JSONL dataset. @@ -113,24 +143,37 @@ def make_train_dataset( cache_dir: Optional directory for HF dataset cache. tokenizer: A Hugging Face tokenizer used to tokenize the input text. block_size: The fixed length of each chunked training example. + text_column: The name of the column containing the text to be tokenized. + num_proc: Number of processes to use for dataset operations. + streaming: Whether to stream the dataset. Returns: A `Dataset` object containing tokenized and block-wise grouped training examples, each with keys `"input_ids"` and `"labels"`. """ + if cached_dataset_path: + logger.info(f"Loading cached dataset from: {cached_dataset_path}") + # `load_from_disk` works seamlessly with local paths and GCS URIs. + data = load_from_disk(cached_dataset_path) + data.set_format("torch") + return data + data = load_hf_or_json_dataset( hf_dataset_name=hf_dataset_name, hf_dataset_config_name=hf_dataset_config_name, file_dataset_path=file_dataset_path, split=split, cache_dir=cache_dir, + num_proc=num_proc, + streaming=streaming, ) column_names = list(data.features) data = data.map( - lambda samples: tokenizer(samples["text"]), + lambda samples: tokenizer(samples[text_column]), batched=True, remove_columns=column_names, + num_proc=num_proc, ) def group_texts(examples): @@ -155,5 +198,5 @@ def group_texts(examples): result["labels"] = result["input_ids"].copy() return result - data = data.map(group_texts, batched=True) + data = data.map(group_texts, batched=True, num_proc=num_proc) return data diff --git a/torchprime/experimental/gcs_cache_loading/preprocess_dataset.py b/torchprime/experimental/gcs_cache_loading/preprocess_dataset.py new file mode 100644 index 00000000..e8ba93db --- /dev/null +++ b/torchprime/experimental/gcs_cache_loading/preprocess_dataset.py @@ -0,0 +1,52 @@ +import logging + +from transformers import AutoTokenizer + +from torchprime.data.dataset import make_train_dataset + +logger = logging.getLogger(__name__) + + +def main( + dataset_name: str, + dataset_config_name: str | None, + tokenizer_name: str, + output_path: str, + block_size: int, + num_proc: int, + split: str, + text_column: str, +) -> None: + """Main function to preprocess a dataset and save it to a specified location. + + Args: + dataset_name: Name of the Hugging Face dataset. 
+ dataset_config_name: Optional configuration name for the dataset. + tokenizer_name: Name of the Hugging Face tokenizer. + output_path: Path to save the processed dataset. + block_size: Sequence length for packing. + num_proc: Number of processes for mapping. + split: Dataset split to process. + text_column: The column containing text data. + """ + logger.info("Starting dataset preprocessing...") + + logger.info(f"Loading tokenizer: {tokenizer_name}") + tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + + processed_dataset = make_train_dataset( + hf_dataset_name=dataset_name, + hf_dataset_config_name=dataset_config_name, + split=split, + tokenizer=tokenizer, + block_size=block_size, + text_column=text_column, + num_proc=num_proc, + streaming=False, + ) + + logger.info(f"Saving processed dataset to: {output_path}") + processed_dataset.save_to_disk(output_path) + logger.info("Preprocessing complete.") \ No newline at end of file diff --git a/torchprime/launcher/cli.py b/torchprime/launcher/cli.py index a5134337..39702253 100644 --- a/torchprime/launcher/cli.py +++ b/torchprime/launcher/cli.py @@ -77,6 +77,46 @@ def cli(ctx, interactive): ctx.obj["interactive"] = interactive +@cli.command() +@click.option("--dataset-name", required=True, help="Name of the Hugging Face dataset.") +@click.option( + "--dataset-config-name", + default=None, + help="Configuration name of the Hugging Face dataset.", +) +@click.option( + "--tokenizer-name", required=True, help="Name of the Hugging Face tokenizer." +) +@click.option( + "--output-path", + required=True, + help="Path to save the processed dataset (local or GCS).", +) +@click.option("--block-size", type=int, default=4096, help="Sequence length for packing.") +@click.option("--num-proc", type=int, default=8, help="Number of processes for mapping.") +@click.option("--split", default="train", help="Dataset split to process.") +@click.option("--text-column", default="text", help="The column containing text data.") +def preprocess( + dataset_name, + dataset_config_name, + tokenizer_name, + output_path, + block_size, + num_proc, + split, + text_column, +): + """Preprocesses a dataset and saves it to a specified location.""" + from torchprime.experimental.gcs_cache_loading.preprocess_dataset import ( + main as preprocess_main, + ) + + preprocess_main( + dataset_name, dataset_config_name, tokenizer_name, output_path, block_size, + num_proc, split, text_column + ) + + @cli.command() @click.option("--cluster", required=True, help="Name of the XPK cluster") @click.option("--project", required=True, help="GCP project the cluster belongs to") diff --git a/torchprime/torch_xla_models/configs/dataset/wikitext.yaml b/torchprime/torch_xla_models/configs/dataset/wikitext.yaml index 0c21363c..c11ea435 100644 --- a/torchprime/torch_xla_models/configs/dataset/wikitext.yaml +++ b/torchprime/torch_xla_models/configs/dataset/wikitext.yaml @@ -4,3 +4,4 @@ hf_dataset_config_name: wikitext-2-raw-v1 split: train block_size: 8192 cache_dir: /tmp/ +cached_dataset_path: null From d07fd4c1c36f012b4213e3cadd2a0e4975604f0f Mon Sep 17 00:00:00 2001 From: Jack Oh Date: Wed, 23 Jul 2025 01:23:18 +0000 Subject: [PATCH 2/7] Added benchmark to measure initial dataset loading time and batch data loading time --- torchprime/data/dataset.py | 29 +++++--- .../preprocess.py} | 11 +-- torchprime/launcher/cli.py | 35 +++++++--- torchprime/metrics/metrics.py | 10 +++
.../torch_xla_models/configs/default.yaml | 2 + torchprime/torch_xla_models/train.py | 16 ++++- .../torch_xla_models/trainer/base_trainer.py | 69 ++++++++++++++++++- .../utils/data_load_benchmark_logger.py | 41 +++++++++++ 8 files changed, 185 insertions(+), 28 deletions(-) rename torchprime/{experimental/gcs_cache_loading/preprocess_dataset.py => data/preprocess.py} (83%) create mode 100644 torchprime/utils/data_load_benchmark_logger.py diff --git a/torchprime/data/dataset.py b/torchprime/data/dataset.py index 5dcf65f5..884737b1 100644 --- a/torchprime/data/dataset.py +++ b/torchprime/data/dataset.py @@ -4,6 +4,7 @@ import logging import fsspec +import torch_xla.runtime as xr from datasets import Dataset, DatasetDict, load_dataset, load_from_disk from transformers.tokenization_utils import PreTrainedTokenizerBase @@ -36,7 +37,6 @@ def _load_hf_dataset( config: str | None, split: str, cache_dir: str | None, - num_proc: int | None, streaming: bool = False, ) -> Dataset: """Download and return a dataset from Hugging Face Hub. @@ -46,7 +46,6 @@ def _load_hf_dataset( config: Optional configuration name. split: Split to load. cache_dir: Directory where the dataset cache should live. - num_proc: Number of processes to use for dataset operations. streaming: Whether to stream the dataset. Returns: @@ -58,7 +57,6 @@ def _load_hf_dataset( config, split=split, cache_dir=cache_dir, - num_proc=num_proc, streaming=streaming, ) assert isinstance(data, Dataset | DatasetDict) @@ -73,7 +71,6 @@ def load_hf_or_json_dataset( file_dataset_path: str | None = None, split: str = "train", cache_dir: str | None = None, - num_proc: int | None = None, streaming: bool = False, ): """Loads a dataset either from Hugging Face Hub or a local/remote JSONL file. @@ -88,7 +85,6 @@ def load_hf_or_json_dataset( file_dataset_path: Optional path to a JSONL file (local or remote). split: Dataset split to load (default is "train"). cache_dir: Optional directory to use for dataset caching (HF only). - num_proc: Number of processes to use for dataset operations (HF only). streaming: Whether to stream the dataset (HF only). Returns: @@ -100,7 +96,6 @@ def load_hf_or_json_dataset( hf_dataset_config_name, split, cache_dir, - num_proc, streaming, ) elif file_dataset_path: @@ -124,8 +119,8 @@ def make_train_dataset( tokenizer: PreTrainedTokenizerBase, block_size: int, text_column: str = "text", - num_proc: int | None = None, streaming: bool = False, + num_proc: int | None = None, ) -> Dataset: """Loads and tokenizes a dataset, then chunks it into fixed-size blocks for training. @@ -144,8 +139,8 @@ def make_train_dataset( tokenizer: A Hugging Face tokenizer used to tokenize the input text. block_size: The fixed length of each chunked training example. text_column: The name of the column containing the text to be tokenized. - num_proc: Number of processes to use for dataset operations. streaming: Whether to stream the dataset. + num_proc: Number of processes for multiprocessing. Returns: A `Dataset` object containing tokenized and block-wise grouped training examples, @@ -155,19 +150,35 @@ def make_train_dataset( logger.info(f"Loading cached dataset from: {cached_dataset_path}") # `load_from_disk` works seamlessly with local paths and GCS URIs. data = load_from_disk(cached_dataset_path) + # In a distributed environment, ensure each process gets a unique shard of the + # dataset to avoid redundant work and OOM errors. 
+ if xr.world_size() > 1: + logger.info( + f"Sharding cached dataset for worker {xr.process_ordinal()} of {xr.world_size()}" + ) + data = data.shard(num_shards=xr.world_size(), index=xr.process_ordinal()) data.set_format("torch") return data + logger.info("No `cached_dataset_path` provided. Processing dataset on-the-fly...") + data = load_hf_or_json_dataset( hf_dataset_name=hf_dataset_name, hf_dataset_config_name=hf_dataset_config_name, file_dataset_path=file_dataset_path, split=split, cache_dir=cache_dir, - num_proc=num_proc, streaming=streaming, ) + # In a distributed environment, ensure each process gets a unique shard of the + # dataset to avoid redundant work during on-the-fly preprocessing. + if xr.world_size() > 1 and not streaming: + logger.info( + f"Sharding dataset for worker {xr.process_ordinal()} of {xr.world_size()}" + ) + data = data.shard(num_shards=xr.world_size(), index=xr.process_ordinal()) + column_names = list(data.features) data = data.map( lambda samples: tokenizer(samples[text_column]), diff --git a/torchprime/experimental/gcs_cache_loading/preprocess_dataset.py b/torchprime/data/preprocess.py similarity index 83% rename from torchprime/experimental/gcs_cache_loading/preprocess_dataset.py rename to torchprime/data/preprocess.py index e8ba93db..c6055858 100644 --- a/torchprime/experimental/gcs_cache_loading/preprocess_dataset.py +++ b/torchprime/data/preprocess.py @@ -13,9 +13,10 @@ def main( tokenizer_name: str, output_path: str, block_size: int, - num_proc: int, split: str, text_column: str, + cache_dir: str | None, + num_workers: int = 1, ) -> None: """Main function to preprocess a dataset and save it to a specified location. @@ -25,9 +26,10 @@ def main( tokenizer_name: Name of the Hugging Face tokenizer. output_path: Path to save the processed dataset. block_size: Sequence length for packing. - num_proc: Number of processes for mapping. split: Dataset split to process. text_column: The column containing text data. + cache_dir: Directory to cache the raw dataset downloads. + num_workers: Number of processes for parallel processing. """ logger.info("Starting dataset preprocessing...") @@ -43,10 +45,11 @@ def main( tokenizer=tokenizer, block_size=block_size, text_column=text_column, - num_proc=num_proc, streaming=False, + cache_dir=cache_dir, + num_proc=num_workers, ) logger.info(f"Saving processed dataset to: {output_path}") processed_dataset.save_to_disk(output_path) - logger.info("Preprocessing complete.") \ No newline at end of file + logger.info("Preprocessing complete.") diff --git a/torchprime/launcher/cli.py b/torchprime/launcher/cli.py index 39702253..a03c1ade 100644 --- a/torchprime/launcher/cli.py +++ b/torchprime/launcher/cli.py @@ -92,29 +92,42 @@ def cli(ctx, interactive): required=True, help="Path to save the processed dataset (local or GCS).", ) -@click.option("--block-size", type=int, default=4096, help="Sequence length for packing.") -@click.option("--num-proc", type=int, default=8, help="Number of processes for mapping.") +@click.option( + "--block-size", type=int, default=4096, help="Sequence length for packing." +) @click.option("--split", default="train", help="Dataset split to process.") @click.option("--text-column", default="text", help="The column containing text data.") +@click.option( + "--cache-dir", default=None, help="Directory to cache the raw dataset downloads." 
+) +@click.option("--num-workers", type=int, default=50, help="Number of worker processes for preprocessing.") def preprocess( dataset_name, dataset_config_name, tokenizer_name, output_path, block_size, - num_proc, split, text_column, + cache_dir, + num_workers, ): - """Preprocesses a dataset and saves it to a specified location.""" - from torchprime.experimental.gcs_cache_loading.preprocess_dataset import ( - main as preprocess_main, - ) + """Preprocesses a dataset and saves it to a specified location.""" + from torchprime.data.preprocess import ( + main as preprocess_main, + ) - preprocess_main( - dataset_name, dataset_config_name, tokenizer_name, output_path, block_size, - num_proc, split, text_column - ) + preprocess_main( + dataset_name, + dataset_config_name, + tokenizer_name, + output_path, + block_size, + split, + text_column, + cache_dir, + num_workers, + ) @cli.command() diff --git a/torchprime/metrics/metrics.py b/torchprime/metrics/metrics.py index fa13ac1b..5ac5b817 100644 --- a/torchprime/metrics/metrics.py +++ b/torchprime/metrics/metrics.py @@ -19,6 +19,9 @@ class Metrics: step_execution_time: timedelta | None """The average time to execute a training step.""" + dataset_load_time: timedelta | None + """The time it took to load and process the dataset.""" + mfu: float | None """Model FLOPs Utilization.""" @@ -60,10 +63,14 @@ def __init__(self): self.mfu = None self.tokens_per_second = None self.num_steps = None + self.dataset_load_time: float | None = None def log_step_execution_time(self, step_execution_time: float): self.step_execution_time = step_execution_time + def log_dataset_load_time(self, dataset_load_time: float): + self.dataset_load_time = dataset_load_time + def log_mfu(self, mfu: float): self.mfu = mfu @@ -80,6 +87,9 @@ def finalize(self) -> Metrics: step_execution_time=timedelta(seconds=self.step_execution_time) if self.step_execution_time else None, + dataset_load_time=timedelta(seconds=self.dataset_load_time) + if self.dataset_load_time is not None + else None, mfu=self.mfu, tokens_per_second=self.tokens_per_second, num_steps=self.num_steps, diff --git a/torchprime/torch_xla_models/configs/default.yaml b/torchprime/torch_xla_models/configs/default.yaml index 0cc84745..adf4e6ff 100644 --- a/torchprime/torch_xla_models/configs/default.yaml +++ b/torchprime/torch_xla_models/configs/default.yaml @@ -24,6 +24,8 @@ profile_end_step: null # when using tp run to launch the run using XPK profile_dir: profile +# Default path for preprocessed data, can be overridden in dataset-specific configs +cached_dataset_path: null # This might be overwritten when using tp run to launch the run using XPK output_dir: outputs diff --git a/torchprime/torch_xla_models/train.py b/torchprime/torch_xla_models/train.py index a8dcabef..826aad22 100644 --- a/torchprime/torch_xla_models/train.py +++ b/torchprime/torch_xla_models/train.py @@ -2,6 +2,7 @@ import logging import sys +from timeit import default_timer as timer import datasets import hydra @@ -69,12 +70,25 @@ def main(config: omegaconf.DictConfig): trainer_cls = torchprime.torch_xla_models.trainer.TRAINERS.get( config.task.name, torchprime.torch_xla_models.trainer.Trainer ) + + load_time_start = timer() data = retry.retry(lambda: dataset_fn(**config.dataset, tokenizer=tokenizer)) + load_time_end = timer() + load_time_seconds = load_time_end - load_time_start dataset_name = getattr(config.dataset, "hf_dataset_name", None) or getattr( config.dataset, "file_dataset_path", "unknown" ) - logger.info("Loaded dataset `%s`, size=%d (packed) samples",
dataset_name, len(data)) + num_tokens = len(data) * config.dataset.block_size + tokens_per_second = num_tokens / load_time_seconds + logger.info("--- Dataset Loading Benchmark ---") + logger.info(" Dataset: %s", dataset_name) + logger.info(" Num samples: %d", len(data)) + logger.info(" Total tokens: %d", num_tokens) + logger.info(f" Load time: {load_time_seconds:.2f} seconds") + logger.info(f" Tokens/sec: {tokens_per_second:,.2f}") + logger.info("---------------------------------") + metrics_logger.log_dataset_load_time(load_time_seconds) trainer = trainer_cls( model=model, diff --git a/torchprime/torch_xla_models/trainer/base_trainer.py b/torchprime/torch_xla_models/trainer/base_trainer.py index 65e502eb..91692056 100644 --- a/torchprime/torch_xla_models/trainer/base_trainer.py +++ b/torchprime/torch_xla_models/trainer/base_trainer.py @@ -17,6 +17,7 @@ from pathlib import Path from timeit import default_timer as timer +import numpy as np import torch import torch.nn.utils as nn_utils import torch_xla @@ -49,6 +50,7 @@ setup_sharding_and_mesh, ) from torchprime.torch_xla_models.topology import get_num_slices +from torchprime.utils.data_load_benchmark_logger import DataLoadBenchmarkLogger from torchprime.utils.profiling import ensure_profile_end_step logger = logging.getLogger(__name__) @@ -91,10 +93,15 @@ def __init__( self.device = xm.xla_device() self.global_batch_size = self.config.task.global_batch_size self.train_dataset = train_dataset + self.dataloader_wait_times = [] # Initialize tensorboard metrics writer self._initialize_tensorboard_writer() + self.benchmark_logger = DataLoadBenchmarkLogger( + self.config.output_dir, "dataloader_benchmark.csv" + ) + # -- Model transformations -- # Recursively replace `nn.Linear` layers with einsum operations in the model. # Without this patch, an `nn.Linear` module will flatten non-contracting dimensions @@ -157,6 +164,7 @@ def _create_optimizer(config, model_parameters) -> torch.optim.Optimizer: def __del__(self): # Close TensorBoard writer on destruction. self.summary_writer.close() + self.benchmark_logger.close() def _initialize_tensorboard_writer(self): run_name = self.config.run_name @@ -196,16 +204,35 @@ def _get_train_dataloader(self) -> pl.MpDeviceLoader: # Each process will load the global batch, then discard the unneeded parts. batch_size = self.global_batch_size + # A good starting point for num_workers is the number of CPU cores per host. + # Setting this to 0 disables parallel data loading. + num_workers = getattr( + self.config.task, "dataloader_num_workers", os.cpu_count() or 0 + ) + + # # To avoid frequent synchronizations, set batches_per_execution to a larger + # # value. This allows the data loader to prefetch multiple batches + # # asynchronously. A good default is the number of logging steps. + # batches_per_execution = getattr( + # self.config.task, "batches_per_execution", self.config.logging_steps + # ) + # logger.info("Dataloader batches_per_execution: %d", batches_per_execution) + dataloader = DataLoader( self.train_dataset, # Data collator will default to DataCollatorWithPadding, so we change it. 
collate_fn=default_data_collator, + num_workers=num_workers, + persistent_workers=True, batch_size=batch_size, sampler=sampler, drop_last=True, ) loader = pl.MpDeviceLoader( - dataloader, self.device, input_sharding=self.input_sharding_spec + dataloader, + self.device, + input_sharding=self.input_sharding_spec, + # batches_per_execution=batches_per_execution, ) return loader @@ -235,6 +262,7 @@ def train_loop(self) -> None: epoch = 0 for step in range(max_step): + wait_start_time = timer() try: batch = next(train_iterator) except StopIteration: @@ -243,9 +271,31 @@ def train_loop(self) -> None: train_iterator = iter(train_loader) batch = next(train_iterator) + # Log batch shapes on the master worker to debug potential recompilations. + # This helps verify if input tensor shapes are consistent across steps. + if xm.is_master_ordinal(): + batch_shapes = {k: v.shape for k, v in batch.items()} + logger.info(f"Step {step} batch shapes: {batch_shapes}") + + wait_end_time = timer() + batch_wait_time = wait_end_time - wait_start_time + batch_wait_time_ms = batch_wait_time * 1000 + trace_start_time = timer() loss, grad_norm = self.train_step(batch) trace_end_time = timer() + compute_time_ms = (trace_end_time - trace_start_time) * 1000 + + self.dataloader_wait_times.append(batch_wait_time) + self.benchmark_logger.log_step( + epoch=step / steps_per_epoch, + step=step, + wait_time_ms=batch_wait_time_ms, + compute_time_ms=compute_time_ms, + ) + logger.info( + f"Epoch: {epoch:.4f}, step: {step}, batch loading time: {batch_wait_time_ms:.2f} ms" + ) if step % self.config.logging_steps == 0: @@ -254,14 +304,23 @@ def step_closure( ): loss = loss.detach().item() grad_norm = grad_norm.detach().item() + compute_time_ms = (trace_end_time - trace_start_time) * 1000 + + # A moving average of wait time over the last logging window. + wait_time_ms = ( + np.mean(self.dataloader_wait_times[-self.config.logging_steps :]) * 1000 + ) + step_time_ms = compute_time_ms + wait_time_ms logger.info( - "Epoch: %.4f, step: %d, loss: %.4f, grad_norm: %.4f, lr: %.2e, trace time: %.2f ms", + "Epoch: %.4f, step: %d, loss: %.4f, grad_norm: %.4f, lr: %.2e, step time: %.2f ms (compute: %.2f, wait: %.2f)", step / steps_per_epoch, step, loss, grad_norm, lr, - (trace_end_time - trace_start_time) * 1000, + step_time_ms, + compute_time_ms, + wait_time_ms, ) self._log_to_tensorboard(epoch, step, loss, lr, grad_norm) if math.isnan(loss): @@ -333,6 +392,10 @@ def finalize_training(self, metrics_logger) -> None: # Print and save metrics metrics = metrics_logger.finalize() logger.info("***** train metrics *****\n%s", metrics) + + # Ensure benchmark log is flushed and closed properly + logger.info("Saving data loading time benchmark log...") + self.benchmark_logger.close() metrics.save(Path(self.config.output_dir) / "train_metrics.json") # Save the hydra config diff --git a/torchprime/utils/data_load_benchmark_logger.py b/torchprime/utils/data_load_benchmark_logger.py new file mode 100644 index 00000000..89f8b676 --- /dev/null +++ b/torchprime/utils/data_load_benchmark_logger.py @@ -0,0 +1,41 @@ +import csv +from pathlib import Path + + +class DataLoadBenchmarkLogger: + """A simple logger for writing data loading benchmark data to a CSV file.""" + + def __init__(self, output_dir: str, filename: str): + """Initializes the logger. + + Args: + output_dir: The directory where the log file will be saved. + filename: The name of the CSV file. 
+ """ + self.output_path = Path(output_dir) / filename + self.file = None + self.writer = None + + def log_step(self, **kwargs): + """Logs a single step of benchmark data. + + The first call to this method determines the CSV header from the keys + of the provided keyword arguments. + """ + with open(self.output_path, "a", newline="") as csvfile: + fieldnames = list(kwargs.keys()) + writer = csv.DictWriter(csvfile, fieldnames=fieldnames) + + # Write header if the file is empty + if csvfile.tell() == 0: + writer.writeheader() + + self.writer.writerow(kwargs) + + def writerow(self, row): + with open(self.output_path, "a", newline="") as csvfile: + writer = csv.DictWriter(csvfile, fieldnames=list(row.keys())) + writer.writerow(row) + + def __del__(self): + pass From 7a658d155be365c05a6b16ac4edd43f15ca46c83 Mon Sep 17 00:00:00 2001 From: Jack Oh Date: Thu, 24 Jul 2025 16:27:47 +0000 Subject: [PATCH 3/7] Fix error in benchmark logger --- .../utils/data_load_benchmark_logger.py | 40 ++++++++++++------- 1 file changed, 25 insertions(+), 15 deletions(-) diff --git a/torchprime/utils/data_load_benchmark_logger.py b/torchprime/utils/data_load_benchmark_logger.py index 89f8b676..a0c6c35e 100644 --- a/torchprime/utils/data_load_benchmark_logger.py +++ b/torchprime/utils/data_load_benchmark_logger.py @@ -12,30 +12,40 @@ def __init__(self, output_dir: str, filename: str): output_dir: The directory where the log file will be saved. filename: The name of the CSV file. """ - self.output_path = Path(output_dir) / filename - self.file = None + self.output_path = Path(output_dir) + self.output_path.mkdir(parents=True, exist_ok=True) + self.output_path /= filename + # Open in append mode to support resuming training. + self.file = open(self.output_path, "a", newline="") self.writer = None + # Check if we need to write a header. If file is not empty, header is assumed to exist. + self.header_written = self.file.tell() > 0 def log_step(self, **kwargs): """Logs a single step of benchmark data. - The first call to this method determines the CSV header from the keys - of the provided keyword arguments. + The first call to this method also writes the CSV header. """ - with open(self.output_path, "a", newline="") as csvfile: - fieldnames = list(kwargs.keys()) - writer = csv.DictWriter(csvfile, fieldnames=fieldnames) + if self.file is None: + # Logger has been closed. 
+ return - # Write header if the file is empty - if csvfile.tell() == 0: - writer.writeheader() + if self.writer is None: + fieldnames = list(kwargs.keys()) + self.writer = csv.DictWriter(self.file, fieldnames=fieldnames) + if not self.header_written: + self.writer.writeheader() + self.header_written = True self.writer.writerow(kwargs) + self.file.flush() - def writerow(self, row): - with open(self.output_path, "a", newline="") as csvfile: - writer = csv.DictWriter(csvfile, fieldnames=list(row.keys())) - writer.writerow(row) + def close(self): + """Closes the underlying file.""" + if self.file: + self.file.close() + self.file = None + self.writer = None def __del__(self): - pass + self.close() From e27d94dbb5c0f40a97ea2ae7cfb0b3ec726e48bf Mon Sep 17 00:00:00 2001 From: Jack Oh Date: Thu, 24 Jul 2025 23:10:58 +0000 Subject: [PATCH 4/7] Clean up preprocessing --- torchprime/launcher/cli.py | 47 ++++++++++++++----- .../configs/dataset/wikitext.yaml | 1 - 2 files changed, 34 insertions(+), 14 deletions(-) diff --git a/torchprime/launcher/cli.py b/torchprime/launcher/cli.py index a03c1ade..36916bc4 100644 --- a/torchprime/launcher/cli.py +++ b/torchprime/launcher/cli.py @@ -4,6 +4,7 @@ import getpass import json +import logging import os import re import subprocess @@ -19,6 +20,7 @@ from dataclasses_json import dataclass_json from pathspec import PathSpec from pathspec.patterns import GitWildMatchPattern # type: ignore +from transformers import AutoTokenizer from watchdog.events import FileSystemEventHandler from watchdog.observers import Observer @@ -113,22 +115,41 @@ def preprocess( num_workers, ): """Preprocesses a dataset and saves it to a specified location.""" - from torchprime.data.preprocess import ( - main as preprocess_main, - ) + from torchprime.data.dataset import make_train_dataset - preprocess_main( - dataset_name, - dataset_config_name, - tokenizer_name, - output_path, - block_size, - split, - text_column, - cache_dir, - num_workers, + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + handlers=[logging.StreamHandler(sys.stdout)], + ) + logger = logging.getLogger(__name__) + logger.setLevel(logging.INFO) + + logger.info("Starting dataset preprocessing...") + + logger.info(f"Loading tokenizer: {tokenizer_name}") + tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + + logger.info("Loading and preprocessing raw dataset...") + processed_dataset = make_train_dataset( + hf_dataset_name=dataset_name, + hf_dataset_config_name=dataset_config_name, + split=split, + tokenizer=tokenizer, + block_size=block_size, + text_column=text_column, + streaming=False, + cache_dir=cache_dir, + num_proc=num_workers, ) + logger.info("Preprocessing finished. 
Now saving to disk...") + logger.info(f"Saving processed dataset to: {output_path}") + processed_dataset.save_to_disk(output_path) + logger.info("Preprocessing complete.") + @cli.command() @click.option("--cluster", required=True, help="Name of the XPK cluster") diff --git a/torchprime/torch_xla_models/configs/dataset/wikitext.yaml b/torchprime/torch_xla_models/configs/dataset/wikitext.yaml index c11ea435..0c21363c 100644 --- a/torchprime/torch_xla_models/configs/dataset/wikitext.yaml +++ b/torchprime/torch_xla_models/configs/dataset/wikitext.yaml @@ -4,4 +4,3 @@ hf_dataset_config_name: wikitext-2-raw-v1 split: train block_size: 8192 cache_dir: /tmp/ -cached_dataset_path: null # <--- Add this line with a default of null From de4d18ce542c5931c52e70f5aefbee3f59d11ca1 Mon Sep 17 00:00:00 2001 From: Jack Oh Date: Fri, 25 Jul 2025 00:42:46 +0000 Subject: [PATCH 5/7] Refactoring --- torchprime/data/dataset.py | 2 + .../torch_xla_models/trainer/base_trainer.py | 39 +++---------- .../utils/data_load_benchmark_logger.py | 55 ++++++------------- 3 files changed, 27 insertions(+), 69 deletions(-) diff --git a/torchprime/data/dataset.py b/torchprime/data/dataset.py index 884737b1..6dfcaaa1 100644 --- a/torchprime/data/dataset.py +++ b/torchprime/data/dataset.py @@ -121,6 +121,7 @@ def make_train_dataset( text_column: str = "text", streaming: bool = False, num_proc: int | None = None, + **kwargs, ) -> Dataset: """Loads and tokenizes a dataset, then chunks it into fixed-size blocks for training. @@ -141,6 +142,7 @@ def make_train_dataset( text_column: The name of the column containing the text to be tokenized. streaming: Whether to stream the dataset. num_proc: Number of processes for multiprocessing. + **kwargs: Unused keyword arguments. Returns: A `Dataset` object containing tokenized and block-wise grouped training examples, diff --git a/torchprime/torch_xla_models/trainer/base_trainer.py b/torchprime/torch_xla_models/trainer/base_trainer.py index 91692056..ec2f0918 100644 --- a/torchprime/torch_xla_models/trainer/base_trainer.py +++ b/torchprime/torch_xla_models/trainer/base_trainer.py @@ -99,7 +99,9 @@ def __init__( self._initialize_tensorboard_writer() self.benchmark_logger = DataLoadBenchmarkLogger( - self.config.output_dir, "dataloader_benchmark.csv" + self.config.output_dir, + "dataloader_benchmark.csv", + fieldnames=["epoch", "step", "wait_time_ms", "compute_time_ms"], ) # -- Model transformations -- @@ -164,7 +166,6 @@ def _create_optimizer(config, model_parameters) -> torch.optim.Optimizer: def __del__(self): # Close TensorBoard writer on destruction. self.summary_writer.close() - self.benchmark_logger.close() def _initialize_tensorboard_writer(self): run_name = self.config.run_name @@ -204,26 +205,10 @@ def _get_train_dataloader(self) -> pl.MpDeviceLoader: # Each process will load the global batch, then discard the unneeded parts. batch_size = self.global_batch_size - # A good starting point for num_workers is the number of CPU cores per host. - # Setting this to 0 disables parallel data loading. - num_workers = getattr( - self.config.task, "dataloader_num_workers", os.cpu_count() or 0 - ) - - # # To avoid frequent synchronizations, set batches_per_execution to a larger - # # value. This allows the data loader to prefetch multiple batches - # # asynchronously. A good default is the number of logging steps. 
- # batches_per_execution = getattr( - # self.config.task, "batches_per_execution", self.config.logging_steps - # ) - # logger.info("Dataloader batches_per_execution: %d", batches_per_execution) - dataloader = DataLoader( self.train_dataset, # Data collator will default to DataCollatorWithPadding, so we change it. collate_fn=default_data_collator, - num_workers=num_workers, - persistent_workers=True, batch_size=batch_size, sampler=sampler, drop_last=True, @@ -232,7 +217,6 @@ def _get_train_dataloader(self) -> pl.MpDeviceLoader: dataloader, self.device, input_sharding=self.input_sharding_spec, - # batches_per_execution=batches_per_execution, ) return loader @@ -271,12 +255,6 @@ def train_loop(self) -> None: train_iterator = iter(train_loader) batch = next(train_iterator) - # Log batch shapes on the master worker to debug potential recompilations. - # This helps verify if input tensor shapes are consistent across steps. - if xm.is_master_ordinal(): - batch_shapes = {k: v.shape for k, v in batch.items()} - logger.info(f"Step {step} batch shapes: {batch_shapes}") - wait_end_time = timer() batch_wait_time = wait_end_time - wait_start_time batch_wait_time_ms = batch_wait_time * 1000 @@ -300,7 +278,7 @@ def train_loop(self) -> None: if step % self.config.logging_steps == 0: def step_closure( - epoch, step, loss, grad_norm, trace_start_time, trace_end_time, lr + fractional_epoch, step, loss, grad_norm, trace_start_time, trace_end_time, lr ): loss = loss.detach().item() grad_norm = grad_norm.detach().item() @@ -313,7 +291,7 @@ def step_closure( step_time_ms = compute_time_ms + wait_time_ms logger.info( "Epoch: %.4f, step: %d, loss: %.4f, grad_norm: %.4f, lr: %.2e, step time: %.2f ms (compute: %.2f, wait: %.2f)", - step / steps_per_epoch, + fractional_epoch, step, loss, grad_norm, @@ -322,14 +300,14 @@ def step_closure( compute_time_ms, wait_time_ms, ) - self._log_to_tensorboard(epoch, step, loss, lr, grad_norm) + self._log_to_tensorboard(fractional_epoch, step, loss, lr, grad_norm) if math.isnan(loss): raise ValueError(f"Loss is NaN at step {step}") xm.add_step_closure( step_closure, args=( - epoch, + step / steps_per_epoch, step, loss, grad_norm, @@ -393,9 +371,8 @@ def finalize_training(self, metrics_logger) -> None: metrics = metrics_logger.finalize() logger.info("***** train metrics *****\n%s", metrics) - # Ensure benchmark log is flushed and closed properly + # The benchmark logger now handles file operations within each log_step call. logger.info("Saving data loading time benchmark log...") - self.benchmark_logger.close() metrics.save(Path(self.config.output_dir) / "train_metrics.json") # Save the hydra config diff --git a/torchprime/utils/data_load_benchmark_logger.py b/torchprime/utils/data_load_benchmark_logger.py index a0c6c35e..658bc07d 100644 --- a/torchprime/utils/data_load_benchmark_logger.py +++ b/torchprime/utils/data_load_benchmark_logger.py @@ -1,51 +1,30 @@ import csv from pathlib import Path +from typing import Any class DataLoadBenchmarkLogger: """A simple logger for writing data loading benchmark data to a CSV file.""" - def __init__(self, output_dir: str, filename: str): + def __init__(self, output_dir: str, filename: str, fieldnames: list[str]): """Initializes the logger. Args: output_dir: The directory where the log file will be saved. filename: The name of the CSV file. + fieldnames: The list of column names for the CSV file. 
""" - self.output_path = Path(output_dir) - self.output_path.mkdir(parents=True, exist_ok=True) - self.output_path /= filename - # Open in append mode to support resuming training. - self.file = open(self.output_path, "a", newline="") - self.writer = None - # Check if we need to write a header. If file is not empty, header is assumed to exist. - self.header_written = self.file.tell() > 0 - - def log_step(self, **kwargs): - """Logs a single step of benchmark data. - - The first call to this method also writes the CSV header. - """ - if self.file is None: - # Logger has been closed. - return - - if self.writer is None: - fieldnames = list(kwargs.keys()) - self.writer = csv.DictWriter(self.file, fieldnames=fieldnames) - if not self.header_written: - self.writer.writeheader() - self.header_written = True - - self.writer.writerow(kwargs) - self.file.flush() - - def close(self): - """Closes the underlying file.""" - if self.file: - self.file.close() - self.file = None - self.writer = None - - def __del__(self): - self.close() + self.output_path = Path(output_dir) / filename + self.fieldnames = fieldnames + # Ensure the output directory exists. + self.output_path.parent.mkdir(parents=True, exist_ok=True) + + def log_step(self, **kwargs: Any): + """Logs a single step of benchmark data.""" + file_exists = self.output_path.exists() + + with self.output_path.open("a", newline="") as f: + writer = csv.DictWriter(f, fieldnames=self.fieldnames) + if not file_exists or f.tell() == 0: + writer.writeheader() + writer.writerow(kwargs) From 7d9a85c51b5e8dc075bf60d52ce76afb4f2a260b Mon Sep 17 00:00:00 2001 From: Jack Oh Date: Fri, 25 Jul 2025 01:34:27 +0000 Subject: [PATCH 6/7] Fix a bug regarding yaml file and remove double sharding --- torchprime/data/dataset.py | 17 ----------------- .../torch_xla_models/configs/default.yaml | 6 ++++-- 2 files changed, 4 insertions(+), 19 deletions(-) diff --git a/torchprime/data/dataset.py b/torchprime/data/dataset.py index 6dfcaaa1..c961784a 100644 --- a/torchprime/data/dataset.py +++ b/torchprime/data/dataset.py @@ -4,7 +4,6 @@ import logging import fsspec -import torch_xla.runtime as xr from datasets import Dataset, DatasetDict, load_dataset, load_from_disk from transformers.tokenization_utils import PreTrainedTokenizerBase @@ -152,14 +151,6 @@ def make_train_dataset( logger.info(f"Loading cached dataset from: {cached_dataset_path}") # `load_from_disk` works seamlessly with local paths and GCS URIs. data = load_from_disk(cached_dataset_path) - # In a distributed environment, ensure each process gets a unique shard of the - # dataset to avoid redundant work and OOM errors. - if xr.world_size() > 1: - logger.info( - f"Sharding cached dataset for worker {xr.process_ordinal()} of {xr.world_size()}" - ) - data = data.shard(num_shards=xr.world_size(), index=xr.process_ordinal()) - data.set_format("torch") return data logger.info("No `cached_dataset_path` provided. Processing dataset on-the-fly...") @@ -173,14 +164,6 @@ def make_train_dataset( streaming=streaming, ) - # In a distributed environment, ensure each process gets a unique shard of the - # dataset to avoid redundant work during on-the-fly preprocessing. 
- if xr.world_size() > 1 and not streaming: - logger.info( - f"Sharding dataset for worker {xr.process_ordinal()} of {xr.world_size()}" - ) - data = data.shard(num_shards=xr.world_size(), index=xr.process_ordinal()) - column_names = list(data.features) data = data.map( lambda samples: tokenizer(samples[text_column]), diff --git a/torchprime/torch_xla_models/configs/default.yaml b/torchprime/torch_xla_models/configs/default.yaml index adf4e6ff..52e1138e 100644 --- a/torchprime/torch_xla_models/configs/default.yaml +++ b/torchprime/torch_xla_models/configs/default.yaml @@ -24,8 +24,6 @@ profile_end_step: null # when using tp run to launch the run using XPK profile_dir: profile -# Default path for preprocessed data, can be overridden in dataset-specific configs -cached_dataset_path: null # This might be overwritten when using tp run to launch the run using XPK output_dir: outputs @@ -33,6 +31,10 @@ output_dir: outputs # If unspecified, defaults to the current date and time. run_name: null +dataset: + # Default path for preprocessed data, can be overridden in dataset-specific configs + cached_dataset_path: null + # The virtual device mesh shape to use within a TPU slice. This is also called # the "ICI mesh", since devices within a slice enjoy a faster network called # "Inter-Chip Interconnect". From a2e7cefd7831c7ea6b89ce21a51ea527144c87be Mon Sep 17 00:00:00 2001 From: Jack Oh Date: Mon, 28 Jul 2025 20:20:44 +0000 Subject: [PATCH 7/7] Remove unused preprocess.py --- torchprime/data/preprocess.py | 55 ----------------------------------- 1 file changed, 55 deletions(-) delete mode 100644 torchprime/data/preprocess.py diff --git a/torchprime/data/preprocess.py b/torchprime/data/preprocess.py deleted file mode 100644 index c6055858..00000000 --- a/torchprime/data/preprocess.py +++ /dev/null @@ -1,55 +0,0 @@ -import logging - -from transformers import AutoTokenizer - -from torchprime.data.dataset import make_train_dataset - -logger = logging.getLogger(__name__) - - -def main( - dataset_name: str, - dataset_config_name: str | None, - tokenizer_name: str, - output_path: str, - block_size: int, - split: str, - text_column: str, - cache_dir: str | None, - num_workers: int = 1, -) -> None: - """Main function to preprocess a dataset and save it to a specified location. - - Args: - dataset_name: Name of the Hugging Face dataset. - dataset_config_name: Optional configuration name for the dataset. - tokenizer_name: Name of the Hugging Face tokenizer. - output_path: Path to save the processed dataset. - block_size: Sequence length for packing. - split: Dataset split to process. - text_column: The column containing text data. - cache_dir: Directory to cache the raw dataset downloads. - num_workers: Number of processes for parallel processing. - """ - logger.info("Starting dataset preprocessing...") - - logger.info(f"Loading tokenizer: {tokenizer_name}") - tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) - if tokenizer.pad_token is None: - tokenizer.pad_token = tokenizer.eos_token - - processed_dataset = make_train_dataset( - hf_dataset_name=dataset_name, - hf_dataset_config_name=dataset_config_name, - split=split, - tokenizer=tokenizer, - block_size=block_size, - text_column=text_column, - streaming=False, - cache_dir=cache_dir, - num_proc=num_workers, - ) - - logger.info(f"Saving processed dataset to: {output_path}") - processed_dataset.save_to_disk(output_path) - logger.info("Preprocessing complete.")
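
For reference, a minimal end-to-end usage sketch of the workflow this series builds. It assumes the launcher is installed as the `tp` entry point (the configs reference `tp run`), and that the Hydra override path follows the `dataset:` section added to default.yaml in PATCH 6; the GCS bucket, tokenizer name, and worker count below are illustrative placeholders, not values taken from the patches.

    # One-time preprocessing: tokenize, pack into block_size chunks, save to GCS.
    tp preprocess \
      --dataset-name wikitext \
      --dataset-config-name wikitext-2-raw-v1 \
      --tokenizer-name gpt2 \
      --block-size 8192 \
      --num-workers 8 \
      --output-path gs://my-bucket/preprocessed/wikitext-8192

    # Training run that loads the cached dataset instead of tokenizing on the fly.
    python torchprime/torch_xla_models/train.py \
      dataset=wikitext \
      dataset.cached_dataset_path=gs://my-bucket/preprocessed/wikitext-8192

When `dataset.cached_dataset_path` is set, `make_train_dataset` returns immediately after `load_from_disk`, so the dataset-load benchmark added in PATCH 2 measures only the cache read, which is the number to compare against on-the-fly preprocessing.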