3 changes: 2 additions & 1 deletion pyproject.toml
@@ -25,7 +25,8 @@ dependencies = [
"benchmark-db-writer @ git+https://github.com/AI-Hypercomputer/aotc.git@2ff16e670df20b497ddaf1f86920dbb5dd9f0c8f#subdirectory=src/aotc/benchmark_db_writer",
"dacite==1.9.2",
"click~=8.1.8",
"google-cloud-storage==2.19.0"
"google-cloud-storage==2.19.0",
"gcsfs"
]

[project.optional-dependencies]
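
For context: `gcsfs` is what registers the `gs://` protocol with `fsspec`, which is what lets the `load_from_disk`/`save_to_disk` calls below accept GCS URIs directly. A minimal sketch, assuming a hypothetical bucket:

# With gcsfs installed, fsspec can open gs:// URIs like local files.
# The bucket and object below are placeholders, not part of this PR.
import fsspec

with fsspec.open("gs://my-bucket/data/train.jsonl", "rt") as f:
    print(f.readline())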
49 changes: 44 additions & 5 deletions torchprime/data/dataset.py
@@ -1,11 +1,14 @@
"""Utilities for preparing datasets for basic training tasks."""

import json
import logging

import fsspec
- from datasets import Dataset, DatasetDict, load_dataset
+ from datasets import Dataset, DatasetDict, load_dataset, load_from_disk
from transformers.tokenization_utils import PreTrainedTokenizerBase

logger = logging.getLogger(__name__)


def _load_json_dataset(path: str, split: str) -> Dataset:
"""Load a dataset from a JSON Lines file.
@@ -33,6 +36,7 @@ def _load_hf_dataset(
config: str | None,
split: str,
cache_dir: str | None,
streaming: bool = False,
) -> Dataset:
"""Download and return a dataset from Hugging Face Hub.

@@ -41,12 +45,19 @@
config: Optional configuration name.
split: Split to load.
cache_dir: Directory where the dataset cache should live.
streaming: Whether to stream the dataset.

Returns:
The loaded ``Dataset`` instance for ``split``.
"""

- data = load_dataset(name, config, split=split, cache_dir=cache_dir)
+ data = load_dataset(
+   name,
+   config,
+   split=split,
+   cache_dir=cache_dir,
+   streaming=streaming,
+ )
assert isinstance(data, Dataset | DatasetDict)
if isinstance(data, DatasetDict):
data = data[split]
@@ -59,6 +70,7 @@ def load_hf_or_json_dataset(
file_dataset_path: str | None = None,
split: str = "train",
cache_dir: str | None = None,
streaming: bool = False,
):
"""Loads a dataset either from Hugging Face Hub or a local/remote JSONL file.

Expand All @@ -72,12 +84,19 @@ def load_hf_or_json_dataset(
file_dataset_path: Optional path to a JSONL file (local or remote).
split: Dataset split to load (default is "train").
cache_dir: Optional directory to use for dataset caching (HF only).
streaming: Whether to stream the dataset (HF only).

Returns:
A HuggingFace ``Dataset`` instance.
"""
if hf_dataset_name:
- data = _load_hf_dataset(hf_dataset_name, hf_dataset_config_name, split, cache_dir)
+ data = _load_hf_dataset(
+   hf_dataset_name,
+   hf_dataset_config_name,
+   split,
+   cache_dir,
+   streaming,
+ )
elif file_dataset_path:
data = _load_json_dataset(file_dataset_path, split)
else:
@@ -89,6 +108,7 @@


def make_train_dataset(
cached_dataset_path: str | None = None,
hf_dataset_name: str | None = None,
hf_dataset_config_name: str | None = None,
file_dataset_path: str | None = None,
@@ -97,6 +117,10 @@
*,
tokenizer: PreTrainedTokenizerBase,
block_size: int,
text_column: str = "text",
streaming: bool = False,
num_proc: int | None = None,
**kwargs,
) -> Dataset:
"""Loads and tokenizes a dataset, then chunks it into fixed-size blocks for training.

@@ -106,31 +130,46 @@
for efficient language modeling, especially on accelerators like TPUs.

Args:
cached_dataset_path: Optional path to a pre-processed, cached dataset.
hf_dataset_name: Optional Hugging Face dataset name. (e.g., "wikitext").
hf_dataset_config_name: Optional HF dataset config name. (e.g., "wikitext-103-raw-v1").
file_dataset_path: Optional path or ``gs://`` URI to a JSONL dataset.
split: Dataset split to load from HF. (e.g., "train", "validation").
cache_dir: Optional directory for HF dataset cache.
tokenizer: A Hugging Face tokenizer used to tokenize the input text.
block_size: The fixed length of each chunked training example.
text_column: The name of the column containing the text to be tokenized.
streaming: Whether to stream the dataset.
num_proc: Number of processes for multiprocessing.
**kwargs: Unused keyword arguments.

Returns:
A `Dataset` object containing tokenized and block-wise grouped training examples,
each with keys `"input_ids"` and `"labels"`.
"""
if cached_dataset_path:
logger.info(f"Loading cached dataset from: {cached_dataset_path}")
# `load_from_disk` works seamlessly with local paths and GCS URIs.
data = load_from_disk(cached_dataset_path)
return data

logger.info("No `cached_dataset_path` provided. Processing dataset on-the-fly...")

data = load_hf_or_json_dataset(
hf_dataset_name=hf_dataset_name,
hf_dataset_config_name=hf_dataset_config_name,
file_dataset_path=file_dataset_path,
split=split,
cache_dir=cache_dir,
streaming=streaming,
)

column_names = list(data.features)
data = data.map(
- lambda samples: tokenizer(samples["text"]),
+ lambda samples: tokenizer(samples[text_column]),
batched=True,
remove_columns=column_names,
num_proc=num_proc,
)

def group_texts(examples):
@@ -155,5 +194,5 @@ def group_texts(examples):
result["labels"] = result["input_ids"].copy()
return result

- data = data.map(group_texts, batched=True)
+ data = data.map(group_texts, batched=True, num_proc=num_proc)
return data
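
Taken together, `make_train_dataset` now has a fast path and a slow path; a minimal sketch of both, with the tokenizer name and dataset paths as placeholders:

# Fast path vs. on-the-fly processing. Names and paths are placeholders.
from transformers import AutoTokenizer

from torchprime.data.dataset import make_train_dataset

tokenizer = AutoTokenizer.from_pretrained("gpt2")

# Fast path: load a pre-tokenized, pre-packed dataset; all map() work is skipped.
cached = make_train_dataset(
    cached_dataset_path="gs://my-bucket/datasets/wikitext-packed",
    tokenizer=tokenizer,
    block_size=4096,
)

# Slow path: download, tokenize, and pack on the fly with 8 worker processes.
fresh = make_train_dataset(
    hf_dataset_name="wikitext",
    hf_dataset_config_name="wikitext-103-raw-v1",
    tokenizer=tokenizer,
    block_size=4096,
    num_proc=8,
)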
74 changes: 74 additions & 0 deletions torchprime/launcher/cli.py
@@ -4,6 +4,7 @@

import getpass
import json
import logging
import os
import re
import subprocess
@@ -19,6 +20,7 @@
from dataclasses_json import dataclass_json
from pathspec import PathSpec
from pathspec.patterns import GitWildMatchPattern # type: ignore
from transformers import AutoTokenizer
from watchdog.events import FileSystemEventHandler
from watchdog.observers import Observer

@@ -77,6 +79,78 @@ def cli(ctx, interactive):
ctx.obj["interactive"] = interactive


@cli.command()
@click.option("--dataset-name", required=True, help="Name of the Hugging Face dataset.")
@click.option(
"--dataset-config-name",
default=None,
help="Configuration name of the Hugging Face dataset.",
)
@click.option(
"--tokenizer-name", required=True, help="Name of the Hugging Face tokenizer."
)
@click.option(
"--output-path",
required=True,
help="Path to save the processed dataset (local or GCS).",
)
@click.option(
"--block-size", type=int, default=4096, help="Sequence length for packing."
)
@click.option("--split", default="train", help="Dataset split to process.")
@click.option("--text-column", default="text", help="The column containing text data.")
@click.option(
"--cache-dir", default=None, help="Directory to cache the raw dataset downloads."
)
@click.option("--num-workers", type=int, default=50, help="Number of Dataflow workers.")
def preprocess(
dataset_name,
dataset_config_name,
tokenizer_name,
output_path,
block_size,
split,
text_column,
cache_dir,
num_workers,
):
"""Preprocesses a dataset and saves it to a specified location."""
from torchprime.data.dataset import make_train_dataset

logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
handlers=[logging.StreamHandler(sys.stdout)],
)
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

logger.info("Starting dataset preprocessing...")

logger.info(f"Loading tokenizer: {tokenizer_name}")
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token

logger.info("Loading and preprocessing raw dataset...")
processed_dataset = make_train_dataset(
hf_dataset_name=dataset_name,
hf_dataset_config_name=dataset_config_name,
split=split,
tokenizer=tokenizer,
block_size=block_size,
text_column=text_column,
streaming=False,
cache_dir=cache_dir,
num_proc=num_workers,
)

logger.info("Preprocessing finished. Now saving to disk...")
logger.info(f"Saving processed dataset to: {output_path}")
processed_dataset.save_to_disk(output_path)
logger.info("Preprocessing complete.")


@cli.command()
@click.option("--cluster", required=True, help="Name of the XPK cluster")
@click.option("--project", required=True, help="GCP project the cluster belongs to")
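The new subcommand can be exercised end-to-end without assuming a particular console-script name by driving the click group in-process; a sketch, with dataset, tokenizer, and output path as placeholders:

# Invoke the new `preprocess` subcommand via click's test runner.
# All names and paths below are placeholders.
from click.testing import CliRunner

from torchprime.launcher.cli import cli

result = CliRunner().invoke(
    cli,
    [
        "preprocess",
        "--dataset-name", "wikitext",
        "--dataset-config-name", "wikitext-103-raw-v1",
        "--tokenizer-name", "gpt2",
        "--output-path", "gs://my-bucket/datasets/wikitext-packed",
        "--block-size", "4096",
        "--num-workers", "8",
    ],
)
assert result.exit_code == 0, result.output
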
10 changes: 10 additions & 0 deletions torchprime/metrics/metrics.py
@@ -19,6 +19,9 @@ class Metrics:
step_execution_time: timedelta | None
"""The average time to execute a training step."""

dataset_load_time: timedelta | None
"""The time it took to load and process the dataset."""

mfu: float | None
"""Model FLOPs Utilization."""

@@ -60,10 +63,14 @@ def __init__(self):
self.mfu = None
self.tokens_per_second = None
self.num_steps = None
self.dataset_load_time: float | None = None

def log_step_execution_time(self, step_execution_time: float):
self.step_execution_time = step_execution_time

def log_dataset_load_time(self, dataset_load_time: float):
self.dataset_load_time = dataset_load_time

def log_mfu(self, mfu: float):
self.mfu = mfu

@@ -80,6 +87,9 @@ def finalize(self) -> Metrics:
step_execution_time=timedelta(seconds=self.step_execution_time)
if self.step_execution_time
else None,
dataset_load_time=timedelta(seconds=self.dataset_load_time)
if self.dataset_load_time is not None
else None,
mfu=self.mfu,
tokens_per_second=self.tokens_per_second,
num_steps=self.num_steps,
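The new field follows the existing pattern: seconds are logged as a float and surfaced as a `timedelta` by `finalize()`. A minimal sketch of the intended flow, assuming the logger class in this module is named `MetricsLogger` and with the load function as a stand-in:

# Record a dataset load duration and read it back as a timedelta.
from timeit import default_timer as timer

from torchprime.metrics.metrics import MetricsLogger


def load_dataset_somehow():
    return list(range(1000))  # stand-in for a real dataset load


metrics_logger = MetricsLogger()
start = timer()
data = load_dataset_somehow()
metrics_logger.log_dataset_load_time(timer() - start)

metrics = metrics_logger.finalize()
print(metrics.dataset_load_time)  # a datetime.timedelta, or None if never logged
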
4 changes: 4 additions & 0 deletions torchprime/torch_xla_models/configs/default.yaml
@@ -31,6 +31,10 @@ output_dir: outputs
# If unspecified, defaults to the current date and time.
run_name: null

dataset:
# Default path for preprocessed data, can be overridden in dataset-specific configs
cached_dataset_path: null

# The virtual device mesh shape to use within a TPU slice. This is also called
# the "ICI mesh", since devices within a slice enjoy a faster network called
# "Inter-Chip Interconnect".
16 changes: 15 additions & 1 deletion torchprime/torch_xla_models/train.py
@@ -2,6 +2,7 @@

import logging
import sys
from timeit import default_timer as timer

import datasets
import hydra
@@ -69,12 +70,25 @@ def main(config: omegaconf.DictConfig):
trainer_cls = torchprime.torch_xla_models.trainer.TRAINERS.get(
config.task.name, torchprime.torch_xla_models.trainer.Trainer
)

load_time_start = timer()
data = retry.retry(lambda: dataset_fn(**config.dataset, tokenizer=tokenizer))
load_time_end = timer()
load_time_seconds = load_time_end - load_time_start

dataset_name = getattr(config.dataset, "hf_dataset_name", None) or getattr(
config.dataset, "file_dataset_path", "unknown"
)
logger.info("Loaded dataset `%s`, size=%d (packed) samples", dataset_name, len(data))
num_tokens = len(data) * config.dataset.block_size
tokens_per_second = num_tokens / load_time_seconds
logger.info("--- Dataset Loading Benchmark ---")
logger.info(" Dataset: %s", dataset_name)
logger.info(" Num samples: %d", len(data))
logger.info(" Total tokens: %d", num_tokens)
logger.info(f" Load time: {load_time_seconds:.2f} seconds")
logger.info(f" Tokens/sec: {tokens_per_second:,.2f}")
logger.info("---------------------------------")
metrics_logger.log_dataset_load_time(load_time_seconds)

trainer = trainer_cls(
model=model,