Merged
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -152,3 +152,4 @@ project/
*.sif
MNIST/
CIFAR10/
models/
146 changes: 146 additions & 0 deletions examples/EuroPar/README.md
@@ -0,0 +1,146 @@
# xFFL — EuroPar 2026 — FL + *SDP Experiments

This repository contains the source code required to reproduce the large-scale pre-training experiments presented at **EuroPar 2026**, based on **federated-learning-augmented sharded data parallelism** (FL + *SDP).

The experiments train **Llama 3.1-8B** on the **clean_mc4_it** dataset using the xFFL framework on a large HPC system.

> 🧪 Original experiments were executed on **128 Leonardo HPC nodes (512 NVIDIA A100 GPUs)**.

## Overview

This repository provides everything needed to reproduce the training runs evaluated in the paper.

### Included components

* **Four configuration files**, one for each parallelization strategy:

* `config_FSDP.py` — Fully Sharded Data Parallel (FSDP)
* `config_HSDP.py` — Hierarchical Sharded Data Parallel (HSDP)
* `config_FL+FSDP.py` — Federated Learning + FSDP
* `config_FL+HSDP.py` — Federated Learning + HSDP

Each configuration defines:

* model parameters
* dataset paths
* optimizer settings
* distributed training options
* logging configuration

* **Training script (`training.py`)**

Implements a full xFFL-compliant large-scale LLM pre-training pipeline with:

* deterministic initialization for reproducibility
* distributed setup (FSDP / HSDP / FL+FSDP / FL+HSDP)
* dataset loading and preprocessing
* optional Weights & Biases logging
* multi-node orchestration via xFFL
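
The configuration files express these settings as Python dataclasses. A minimal illustrative sketch of the idea follows; the class and field names here are simplified stand-ins (real configs subclass `xffl.custom.config.XFFLConfig`), though the default values mirror those in `config_FL+FSDP.py`:

```python
from dataclasses import dataclass, field
from typing import Mapping

# Illustrative only: real configs subclass xffl.custom.config.XFFLConfig
@dataclass
class ExampleConfig:
    model_path: str = "/path/to/models/llama3.1-8b-init"  # model parameters
    dataset_path: str = "/path/to/data/clean_mc4_it"      # dataset paths
    learning_rate: float = 3e-4                           # optimizer settings
    batch_sizes: Mapping[str, int] = field(
        default_factory=lambda: {"train": 2, "val": 2}    # distributed training options
    )
    seed: int = 42                                        # reproducibility
    wandb_mode: str = "offline"                           # logging configuration

cfg = ExampleConfig()
print(cfg.learning_rate)  # 0.0003
```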

## Training Setup

* **Model:** Llama 3.1-8B
* **Dataset:** clean_mc4_it
* **Precision:** bfloat16
* **Training type:** full pre-training (not fine-tuning)

⚠️ Parameter-efficient methods such as LoRA or QLoRA are **not used**.

## Repository Structure

```
EuroPar/
├── config_FSDP.py # LLaMA 3.1-8B + clean_mc4_it using FSDP
├── config_HSDP.py # LLaMA 3.1-8B + clean_mc4_it using HSDP
├── config_FL+FSDP.py # LLaMA 3.1-8B + clean_mc4_it using FL+FSDP
├── config_FL+HSDP.py # LLaMA 3.1-8B + clean_mc4_it using FL+HSDP
├── training.py # Main distributed training script
```

## Requirements

### Software

* Python environment compatible with xFFL
* xFFL installed and properly configured
* PyTorch with distributed support
* Access to the Llama 3.1-8B weights
* Access to the clean_mc4_it dataset

### Hardware

This code targets large HPC systems.

At minimum:

* Multi-node GPU cluster
* High-speed interconnect (e.g., InfiniBand)
* SLURM or compatible scheduler
* Sufficient storage bandwidth for dataset streaming

## Installation

Install and configure xFFL according to its official documentation.

Ensure that:

* distributed communication works across nodes
* filesystem paths are accessible from all nodes
* required datasets and model checkpoints are available
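
A quick way to verify the last two points is to resolve the configured paths on every node before launching. The sketch below uses placeholder paths (substitute the values from your configuration files):

```python
from pathlib import Path

# Placeholder paths: substitute the values from your config files
paths = {
    "model": Path("/path/to/models/llama3.1-8b-init"),
    "dataset": Path("/path/to/data/clean_mc4_it"),
}

def check_paths(paths):
    """Return the subset of named paths that are missing on this node."""
    return {name: p for name, p in paths.items() if not p.exists()}

missing = check_paths(paths)
if missing:
    print("Missing on this node:", ", ".join(missing))
```

Run it (e.g., via `srun`) on each allocated node so a broken mount is caught before the full job starts.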

## Running the Experiments

After allocating compute nodes and activating your environment:

### Basic execution example

```bash
xffl exec training.py config_FSDP.py
```

### What happens during execution

xFFL will automatically:

* detect the allocated resources (e.g., via SLURM)
* initialize distributed training
* load the Llama 3.1-8B model from the configured path
* wrap the model with the selected parallelization strategy
* load the dataset
* start synchronized multi-node training
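
Resource detection under SLURM typically relies on environment variables such as `SLURM_NNODES`. The following is a hedged sketch of that idea, not xFFL's actual implementation:

```python
import os

def slurm_world_size(gpus_per_node: int = 4) -> int:
    """Derive the total number of ranks from a SLURM allocation.

    Falls back to a single node when not running under SLURM.
    """
    nnodes = int(os.environ.get("SLURM_NNODES", "1"))
    return nnodes * gpus_per_node

# e.g. 128 Leonardo nodes x 4 A100 GPUs each = 512 ranks
print(slurm_world_size())
```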

## Configuration

All experiment-specific parameters are defined in the configuration files.

Before running:

* Update dataset paths
* Update model checkpoint paths
* Adjust output/logging directories
* Verify cluster-specific settings

⚠️ Paths must be valid on all participating nodes.
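
Among these parameters is the learning-rate schedule: the provided configs implement a linear warmup followed by cosine decay (`_get_llama31_cosine_schedule`, with `warmup_frac = 0.1`). Its shape can be sketched standalone:

```python
import math

def lr_factor(step: int, total_steps: int, warmup_frac: float = 0.1) -> float:
    """Multiplicative LR factor: linear warmup, then cosine decay to zero."""
    warmup_steps = int(total_steps * warmup_frac)
    if step < warmup_steps:
        return step / max(1, warmup_steps)
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return 0.5 * (1 + math.cos(math.pi * progress))

# Factor rises to 1.0 at the end of warmup, then decays toward 0
print(lr_factor(0, 128), lr_factor(12, 128), lr_factor(128, 128))
```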


## Reproducibility Notes

* Deterministic initialization is enabled where possible
* Performance and scalability results depend on hardware topology
* Exact reproduction requires a system comparable to the Leonardo cluster

## Limitations

* Designed for **pre-training workloads**, not fine-tuning
* Assumes availability of large GPU clusters
* Not optimized for single-node execution

## Adapting the Code

Although tailored for pre-training, the pipeline can be extended to support:

* fine-tuning
* alternative datasets
* different model sizes
* custom distributed strategies
170 changes: 170 additions & 0 deletions examples/EuroPar/config_FL+FSDP.py
@@ -0,0 +1,170 @@
"""Configuration file for the xFFL-LLM example"""

import logging
import math
from dataclasses import dataclass, field
from functools import partial
from pathlib import Path
from typing import Callable, Mapping, Sequence, Type

import torch
from torch import nn
from torch.distributed.fsdp import MixedPrecision
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from torch.optim import AdamW, Optimizer
from torch.optim.lr_scheduler import LambdaLR, LRScheduler
from transformers import default_data_collator
from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaForCausalLM

from xffl.custom.config import DatasetInfo, ModelInfo, XFFLConfig
from xffl.distributed.distributed_state import DistributedState
from xffl.learning.data import load_datasets_from_disk

# Constants
LLAMA3_1_8B: str = "llama3.1-8b-init"
CLEAN_MC4_IT: str = "clean_mc4_it"

BASE_PATH: str = "/leonardo_scratch/fast/uToID_bench/xffl"


@dataclass
class llama(ModelInfo):

# LLM loading from saved model
@staticmethod
def _load_llm_from_checkpoint(
config: XFFLConfig, state: DistributedState
) -> nn.Module:
return LlamaForCausalLM.from_pretrained(
pretrained_model_name_or_path=str(config.model_info.path),
use_cache=True,
local_files_only=True, # Most HPCs do not have internet access from the nodes
attn_implementation=config.model_info.attention,
dtype=torch.bfloat16, # Model is loaded in torch.bfloat16 (from the JSON file) - also "auto"
device_map=state.init_device,
use_safetensors=True,
low_cpu_mem_usage=True,
tie_word_embeddings=True,
)

# Auto wrap policy
@staticmethod
def llama_fsdp_wrap_policy():
return partial(
transformer_auto_wrap_policy,
transformer_layer_cls={
LlamaDecoderLayer,
},
)

name: str = LLAMA3_1_8B
attention: str = "sdpa" # "flash_attention_2"
model: Callable = _load_llm_from_checkpoint
decoder_layer: Type = LlamaDecoderLayer
wrapping_policy: Callable = llama_fsdp_wrap_policy
mixed_precision: MixedPrecision = field(
default_factory=lambda: MixedPrecision(
param_dtype=torch.bfloat16,
reduce_dtype=torch.bfloat16,
buffer_dtype=torch.bfloat16,
)
)
path: str = BASE_PATH + "/models/" + name


@dataclass
class cleanmc4it(DatasetInfo):

@staticmethod
def _get_cleanmc4it_splits(config: XFFLConfig, state: DistributedState):
return load_datasets_from_disk(
splits={"train": "train", "val": "val"},
base_path=Path(str(config.dataset_info.path)),
) # Original LLaMA training packs the datasets

name: str = CLEAN_MC4_IT
splits: Callable = _get_cleanmc4it_splits
batch_sizes: Mapping[str, int] = field(
default_factory=lambda: {"train": 2, "val": 2}
)
subsampling: Mapping[str, int] = field(
default_factory=lambda: {"train": 65536, "val": 4096}
)
workers: int = 2
collate_fn: Callable = default_data_collator
path: str = BASE_PATH + "/data/" + CLEAN_MC4_IT


# XFFL configuration
@dataclass
class xffl_config(XFFLConfig):

# Optimizer
@staticmethod
def _get_optimizer(model: nn.Module, config: XFFLConfig) -> Optimizer:
return AdamW(
params=model.parameters(),
lr=config.learning_rate, # type: ignore
weight_decay=config.weight_decay, # type: ignore
betas=config.betas, # type: ignore
fused=True, # Supported only on torch.float64, torch.float32, torch.float16, and torch.bfloat16
)

# Default
model_info: ModelInfo = field(default_factory=llama)
dataset_info: DatasetInfo = field(default_factory=cleanmc4it)
optimizer: Callable[[nn.Module, XFFLConfig], Optimizer] = _get_optimizer

# General
loglevel: int = logging.DEBUG
seed: int = 42
federated: int = 4
federated_batches: int = 8

# Learning
learning_rate: float = 3e-4
epochs: int = 1

# WandB
wandb_entity: str = "alpha-unito"
wandb_project: str = "FL+DP"
wandb_group: str = "FL+FSDP_new"
wandb_notes: str = "EuroPar 2026 experiments"
wandb_tags: Sequence[str] = field(default_factory=lambda: ["xFFL", "EuroPar"])
wandb_mode: str = "offline"

# Learning rate scheduler
@staticmethod
def _get_llama31_cosine_schedule(
optimizer: Optimizer, total_steps: int, config: XFFLConfig
) -> LRScheduler:
"""
Scheduler stile LLaMA3.1 semplificato: warmup -> cosine decay.

Args:
optimizer: torch.optim.Optimizer
total_steps (int): passi totali (es. 128)
lr_max (float): learning rate massimo
warmup_frac (float): frazione di warmup (default 5%)
"""
warmup_steps = int(total_steps * config.warmup_frac) # type: ignore
decay_steps = total_steps - warmup_steps

def lr_lambda(step):
if step < warmup_steps:
# Linear warmup
return step / max(1, warmup_steps)
else:
# Cosine decay
progress = (step - warmup_steps) / max(1, decay_steps)
return 0.5 * (1 + math.cos(math.pi * progress))

return LambdaLR(optimizer, lr_lambda=lr_lambda)

# Advanced configuration
lr_scheduler: Callable = _get_llama31_cosine_schedule

# Custom - optimizer
weight_decay: float = 0.1
betas: Sequence[float] = (0.9, 0.95)
warmup_frac: float = 0.1