From fd2c1988d385530a29648d41e117720e1143231d Mon Sep 17 00:00:00 2001 From: Gianluca Mittone Date: Wed, 18 Feb 2026 20:48:46 +0100 Subject: [PATCH 1/5] Prepare EuroPar experiments --- .gitignore | 1 + examples/intra-silo/03_LLM/config.py | 192 ++++++++++++++----------- examples/intra-silo/03_LLM/training.py | 3 +- examples/scripts/env/leonardo.sh | 3 +- xffl/distributed/aggregation.py | 19 ++- xffl/learning/modelling.py | 7 +- xffl/learning/utils.py | 22 ++- 7 files changed, 149 insertions(+), 98 deletions(-) diff --git a/.gitignore b/.gitignore index 4d272b4..e9b410e 100644 --- a/.gitignore +++ b/.gitignore @@ -152,3 +152,4 @@ project/ *.sif MNIST/ CIFAR10/ +models/ diff --git a/examples/intra-silo/03_LLM/config.py b/examples/intra-silo/03_LLM/config.py index fbb071b..2ebb41a 100644 --- a/examples/intra-silo/03_LLM/config.py +++ b/examples/intra-silo/03_LLM/config.py @@ -4,16 +4,18 @@ import math import os from dataclasses import dataclass, field +from functools import partial from pathlib import Path from typing import Callable, Mapping, Sequence, Type import torch from torch import nn from torch.distributed.fsdp import MixedPrecision +from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy from torch.optim import AdamW, Optimizer from torch.optim.lr_scheduler import LambdaLR, LRScheduler from transformers import AutoModelForCausalLM, default_data_collator -from transformers.models.llama.modeling_llama import LlamaDecoderLayer +from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaForCausalLM from transformers.models.mixtral.modeling_mixtral import MixtralDecoderLayer from xffl.custom.config import DatasetInfo, ModelInfo, XFFLConfig @@ -21,91 +23,51 @@ from xffl.learning.data import load_datasets_from_disk # Constants -TINY_RANDOM_LLAMA_3: str = "tiny_random_Llama-3" -LLAMA3_1_8B: str = "llama3.1-8b" +TINY_RANDOM_LLAMA_3: str = "tiny-random-llama-3" +LLAMA3_1_8B: str = "llama3.1-8b-init" LLAMA3_1_70B: str = "llama3.1-70b" MIXTRAL_8x7b_v0_1: str = "mixtral-8x7b-v0.1" CLEAN_MC4_IT: str = "clean_mc4_it" -BASE_PATH: str = str(os.getcwd()) + "/xffl" - - -# LLM loading from saved model -def _load_llm_from_checkpoint(config: XFFLConfig, state: DistributedState) -> nn.Module: - return AutoModelForCausalLM.from_pretrained( - pretrained_model_name_or_path=str(config.model_info.path), - use_cache=False, - local_files_only=True, # Most HPCs do not have internet access from the nodes - attn_implementation=config.model_info.attention, - dtype=torch.float32, # Model is loaded in torch.bfloat16 (from the JSON file) - also "auto" - device_map=state.init_device, - use_safetensors=True, - ) - - -def _get_llama31_cosine_schedule( - optimizer: Optimizer, total_steps: int, config: XFFLConfig -) -> LRScheduler: - """ - Scheduler stile LLaMA3.1 semplificato: warmup -> cosine decay. - - Args: - optimizer: torch.optim.Optimizer - total_steps (int): passi totali (es. 
128) - lr_max (float): learning rate massimo - warmup_frac (float): frazione di warmup (default 5%) - """ - warmup_steps = int(total_steps * config.warmup_frac) # type: ignore - decay_steps = total_steps - warmup_steps - - def lr_lambda(step): - if step < warmup_steps: - # warmup lineare - return step / max(1, warmup_steps) - else: - # decadimento coseno - progress = (step - warmup_steps) / max(1, decay_steps) - return 0.5 * (1 + math.cos(math.pi * progress)) - - return LambdaLR(optimizer, lr_lambda=lambda step: lr_lambda(step)) - - -# Optimizer -def _get_optimizer(model: nn.Module, config: XFFLConfig) -> Optimizer: - return AdamW( - params=model.parameters(), - lr=config.learning_rate, # type: ignore - weight_decay=config.weight_decay, # type: ignore - betas=config.betas, # type: ignore - # foreach=True, # Optimizes performances but uses more memory - fused=True, # Supported only on torch.float64, torch.float32, torch.float16, and torch.bfloat16 - ) +BASE_PATH: str = "/leonardo_scratch/fast/uToID_bench/xffl" @dataclass class llama(ModelInfo): - name: str = TINY_RANDOM_LLAMA_3 - attention: str = "sdpa" - model: Callable = _load_llm_from_checkpoint - decoder_layer: Type = LlamaDecoderLayer - activation_checkpointing: bool = True - mixed_precision: MixedPrecision = field( - default_factory=lambda: MixedPrecision( - param_dtype=torch.bfloat16, - reduce_dtype=torch.bfloat16, - buffer_dtype=torch.bfloat16, - # cast_forward_inputs=True, + + # LLM loading from saved model + @staticmethod + def _load_llm_from_checkpoint( + config: XFFLConfig, state: DistributedState + ) -> nn.Module: + return LlamaForCausalLM.from_pretrained( + pretrained_model_name_or_path=str(config.model_info.path), + use_cache=True, + local_files_only=True, # Most HPCs do not have internet access from the nodes + attn_implementation=config.model_info.attention, + dtype=torch.bfloat16, # Model is loaded in torch.bfloat16 (from the JSON file) - also "auto" # This slows down model loading + device_map=state.init_device, + use_safetensors=True, + low_cpu_mem_usage=True, + tie_word_embeddings=True, ) - ) - path: str = BASE_PATH + "/model/" + name + # Auto wrap policy + @staticmethod + def llama_fsdp_wrap_policy(): + return partial( + transformer_auto_wrap_policy, + transformer_layer_cls={ + LlamaDecoderLayer, + }, + ) -@dataclass -class mixtral(ModelInfo): - name: str = MIXTRAL_8x7b_v0_1 - attention: str = "sdpa" + name: str = LLAMA3_1_8B + attention: str = "sdpa" # "flash_attention_2" model: Callable = _load_llm_from_checkpoint - decoder_layer: Type = MixtralDecoderLayer + decoder_layer: Type = LlamaDecoderLayer + wrapping_policy: Callable = llama_fsdp_wrap_policy + activation_checkpointing: bool = False # True mixed_precision: MixedPrecision = field( default_factory=lambda: MixedPrecision( param_dtype=torch.bfloat16, @@ -114,7 +76,24 @@ class mixtral(ModelInfo): # cast_forward_inputs=True, ) ) - path: str = BASE_PATH + "/model/" + name + path: str = BASE_PATH + "/models/" + name + + +# @dataclass +# class mixtral(ModelInfo): +# name: str = MIXTRAL_8x7b_v0_1 +# attention: str = "sdpa" +# model: Callable = _load_llm_from_checkpoint +# decoder_layer: Type = MixtralDecoderLayer +# mixed_precision: MixedPrecision = field( +# default_factory=lambda: MixedPrecision( +# param_dtype=torch.bfloat16, +# reduce_dtype=torch.bfloat16, +# buffer_dtype=torch.bfloat16, +# # cast_forward_inputs=True, +# ) +# ) +# path: str = BASE_PATH + "/model/" + name @dataclass @@ -130,26 +109,41 @@ def _get_cleanmc4it_splits(config: XFFLConfig, state: 
DistributedState): name: str = CLEAN_MC4_IT splits: Callable = _get_cleanmc4it_splits batch_sizes: Mapping[str, int] = field( - default_factory=lambda: {"train": 4, "val": 1} + default_factory=lambda: {"train": 2, "val": 2} ) - subsampling: int = 16 + subsampling: int = 1024 workers: int = 2 collate_fn: Callable = default_data_collator - path: str = BASE_PATH + "/dataset/" + CLEAN_MC4_IT + path: str = BASE_PATH + "/data/" + CLEAN_MC4_IT # XFFL configuration @dataclass class xffl_config(XFFLConfig): + # Optimizer + @staticmethod + def _get_optimizer(model: nn.Module, config: XFFLConfig) -> Optimizer: + return AdamW( + params=model.parameters(), + lr=config.learning_rate, # type: ignore + weight_decay=config.weight_decay, # type: ignore + betas=config.betas, # type: ignore + # foreach=True, # Optimizes performances but uses more memory + fused=True, # Supported only on torch.float64, torch.float32, torch.float16, and torch.bfloat16 + ) + # Default model_info: ModelInfo = field(default_factory=llama) dataset_info: DatasetInfo = field(default_factory=cleanmc4it) optimizer: Callable[[nn.Module, XFFLConfig], Optimizer] = _get_optimizer # General - loglevel: int = logging.INFO + loglevel: int = logging.DEBUG seed: int = 42 + hsdp: int = 4 + federated: int = 4 + federated_batches: int = 8 # Learning learning_rate: float = 3e-4 @@ -157,14 +151,42 @@ class xffl_config(XFFLConfig): # WandB wandb_entity: str = "alpha-unito" - wandb_project: str = "xFFL playground" - wandb_group: str = "02_CNN" - wandb_name: str = "Example" - wandb_notes: str = "Example run of xFFL with a CNN" + wandb_project: str = "FL+DP" + wandb_group: str = "FL+HSDP" + wandb_name: str = "Prova" + wandb_notes: str = "Example run of xFFL with a LLM" wandb_tags: Sequence[str] = field( - default_factory=lambda: ["xFFL", "example", "MLP"] + default_factory=lambda: ["xFFL", "example", "LLM"] ) - wandb_mode: str = "online" + wandb_mode: str = "offline" + + # Learning rate scheduler + @staticmethod + def _get_llama31_cosine_schedule( + optimizer: Optimizer, total_steps: int, config: XFFLConfig + ) -> LRScheduler: + """ + Scheduler stile LLaMA3.1 semplificato: warmup -> cosine decay. + + Args: + optimizer: torch.optim.Optimizer + total_steps (int): passi totali (es. 
128) + lr_max (float): learning rate massimo + warmup_frac (float): frazione di warmup (default 5%) + """ + warmup_steps = int(total_steps * config.warmup_frac) # type: ignore + decay_steps = total_steps - warmup_steps + + def lr_lambda(step): + if step < warmup_steps: + # warmup lineare + return step / max(1, warmup_steps) + else: + # decadimento coseno + progress = (step - warmup_steps) / max(1, decay_steps) + return 0.5 * (1 + math.cos(math.pi * progress)) + + return LambdaLR(optimizer, lr_lambda=lambda step: lr_lambda(step)) # Advanced configuration lr_scheduler: Callable = _get_llama31_cosine_schedule diff --git a/examples/intra-silo/03_LLM/training.py b/examples/intra-silo/03_LLM/training.py index 5a5a41b..a4abbd0 100644 --- a/examples/intra-silo/03_LLM/training.py +++ b/examples/intra-silo/03_LLM/training.py @@ -53,6 +53,7 @@ def pretraining(config: XFFLConfig) -> None: ) # Large data preloading in background + start_time: float = time.perf_counter() if state.node_local_rank == 0: utils.preload(files=[config.model_info.path, config.dataset_info.path]) @@ -65,7 +66,7 @@ def pretraining(config: XFFLConfig) -> None: f"Model loading time: {(time.perf_counter() - start_time):.2f} seconds" ) logger.debug( - f"Training {config.model_info.name}: {(utils.get_model_size(model=model) / 1e6):.2f} million trainable parameters" + f"Training {config.model_info.name}: {(utils.get_model_size(model=model, state=state) / 1e6):.2f} million trainable parameters" ) # Dataset loading diff --git a/examples/scripts/env/leonardo.sh b/examples/scripts/env/leonardo.sh index 90563dd..f47857a 100644 --- a/examples/scripts/env/leonardo.sh +++ b/examples/scripts/env/leonardo.sh @@ -1,7 +1,8 @@ #!/bin/bash ulimit -n 131072 -module load cuda/12.2 nccl/2.22.3-1--gcc--12.2.0-cuda-12.2-spack0.22 +#module load cuda/12.2 nccl/2.22.3-1--gcc--12.2.0-cuda-12.2-spack0.22 +#module load cuda/12.2 gcc/12.2.0 python/3.11.7 export PYTHONUNBUFFERED=1 diff --git a/xffl/distributed/aggregation.py b/xffl/distributed/aggregation.py index be002c2..1e81d4e 100644 --- a/xffl/distributed/aggregation.py +++ b/xffl/distributed/aggregation.py @@ -239,7 +239,12 @@ def get_average_distributed_loss( assert state.federated_group is not None scale_factor: int = state.federated_local_size[state.federated_rank] - group: Optional[ProcessGroup] = state.federated_group[state.federated_rank] + + group: Optional[ProcessGroup] = ( + state.federated_group[0] + if state.streams is None + else state.federated_group[state.federated_rank] + ) else: assert state.world_size is not None @@ -327,7 +332,7 @@ def layer_by_layer_optimized( use_multiple_cuda_streams=use_multiple_cuda_streams, state=state ) - bucket_size: int = get_model_size(model=model) // stream_number + bucket_size: int = get_model_size(model=model, state=state) // stream_number parameter_counter: int = 0 mapping: List[Tuple[Tuple[int, ...], Tensor, ContextManager, int]] = [] @@ -509,7 +514,7 @@ def bucket_optimized_flatten( param_list: List[Tensor] = list(model.parameters()) - bucket_size: int = get_model_size(model=model) // stream_number + bucket_size: int = get_model_size(model=model, state=state) // stream_number parameter_counter: int = 0 buckets: List[List[int]] = [[] for _ in range(stream_number)] @@ -577,7 +582,7 @@ def bucket_optimized_coalesced( param_list: List[Tensor] = list(model.parameters()) - bucket_size: int = get_model_size(model=model) // stream_number + bucket_size: int = get_model_size(model=model, state=state) // stream_number parameter_counter: int = 0 buckets: 
List[List[int]] = [[] for _ in range(stream_number)] @@ -688,7 +693,7 @@ def layer_by_layer_optimized_( use_multiple_cuda_streams=use_multiple_cuda_streams, state=state ) - bucket_size: int = get_model_size(model=model) // stream_number + bucket_size: int = get_model_size(model=model, state=state) // stream_number parameter_counter: int = 0 for layer in model.parameters(): @@ -834,7 +839,7 @@ def bucket_optimized_flatten_( param_list: List[Tensor] = list(model.parameters()) - bucket_size: int = get_model_size(model=model) // stream_number + bucket_size: int = get_model_size(model=model, state=state) // stream_number parameter_counter: int = 0 buckets: List[List[int]] = [[] for _ in range(stream_number)] @@ -896,7 +901,7 @@ def bucket_optimized_coalesced_( param_list: List[Tensor] = list(model.parameters()) - bucket_size: int = get_model_size(model=model) // stream_number + bucket_size: int = get_model_size(model=model, state=state) // stream_number parameter_counter: int = 0 buckets: List[List[int]] = [[] for _ in range(stream_number)] diff --git a/xffl/learning/modelling.py b/xffl/learning/modelling.py index 85c7099..d714a24 100644 --- a/xffl/learning/modelling.py +++ b/xffl/learning/modelling.py @@ -105,9 +105,12 @@ def create_fsdp_model( model = FullyShardedDataParallel( module=_module, sharding_strategy=get_appropriate_sharding_strategy(state=state), - auto_wrap_policy=_wrapping_policy, + auto_wrap_policy=( + _wrapping_policy() if _wrapping_policy is not None else None + ), device_id=state.current_device, - forward_prefetch=True, + backward_prefetch=None, # True + forward_prefetch=False, # True limit_all_gathers=False, mixed_precision=_mixed_precision, sync_module_states=bool(state.meta_initialization), diff --git a/xffl/learning/utils.py b/xffl/learning/utils.py index 6b61ea0..719113a 100644 --- a/xffl/learning/utils.py +++ b/xffl/learning/utils.py @@ -92,7 +92,7 @@ def set_nondeterministic_execution() -> None: torch.use_deterministic_algorithms(mode=False) -def get_model_size(model: nn.Module) -> int: +def get_model_size(model: nn.Module, state: DistributedState) -> int: """Returns the model's trainable parameters number :param model: PyTorch model @@ -100,7 +100,25 @@ def get_model_size(model: nn.Module) -> int: :return: Number of trainable parameters :rtype: int """ - return sum(p.numel() for p in model.parameters() if p.requires_grad) + params: int = sum(p.numel() for p in model.parameters() if p.requires_grad) + if state is not None: + + if state.is_hsdp_setup(): + assert state.replica_local_size is not None + + params *= state.replica_local_size + + elif state.is_fsdp_setup: + assert state.world_size is not None + + params *= state.world_size + + if state.is_federated_scaling_setup: + assert state.federated_world_size is not None + + params //= state.federated_world_size + + return params def get_model_size_in_bits(model: nn.Module) -> int: From aa95a675c9e033525e7ba3e5eaa078572ab9cece Mon Sep 17 00:00:00 2001 From: Gianluca Mittone Date: Mon, 16 Mar 2026 10:44:30 +0100 Subject: [PATCH 2/5] Reproducibility code for EuroPar2026 --- examples/EuroPar/README.md | 146 +++++++++++++++++++++ examples/EuroPar/config_FL+FSDP.py | 172 ++++++++++++++++++++++++ examples/EuroPar/config_FL+HSDP.py | 173 +++++++++++++++++++++++++ examples/EuroPar/config_FSDP.py | 170 ++++++++++++++++++++++++ examples/EuroPar/config_HSDP.py | 171 ++++++++++++++++++++++++ examples/EuroPar/plots/time_to_perp.py | 165 +++++++++++++++++++++++ examples/EuroPar/training.py | 134 +++++++++++++++++++ xffl/cli/exec.py | 
1 + xffl/distributed/aggregation.py | 19 +-- xffl/distributed/distributed.py | 7 +- xffl/learning/modelling.py | 14 +- xffl/learning/processing.py | 51 ++++---- xffl/learning/utils.py | 6 +- xffl/utils/utils.py | 2 +- 14 files changed, 1173 insertions(+), 58 deletions(-) create mode 100644 examples/EuroPar/README.md create mode 100644 examples/EuroPar/config_FL+FSDP.py create mode 100644 examples/EuroPar/config_FL+HSDP.py create mode 100644 examples/EuroPar/config_FSDP.py create mode 100644 examples/EuroPar/config_HSDP.py create mode 100644 examples/EuroPar/plots/time_to_perp.py create mode 100644 examples/EuroPar/training.py diff --git a/examples/EuroPar/README.md b/examples/EuroPar/README.md new file mode 100644 index 0000000..c5798ef --- /dev/null +++ b/examples/EuroPar/README.md @@ -0,0 +1,146 @@ +# xFFL β€” EuroPar 2026 β€” FL + *SDP Experiments + +This repository contains the source code required to reproduce the large-scale pre-training experiments presented at **EuroPar 2026**, based on **federated-learning-augmented sharded data parallelism** (FL + *SDP). + +The experiments train **Llama 3.1-8B** on the **clean_mc4_it** dataset using the xFFL framework on a large HPC system. + +> πŸ§ͺ Original experiments were executed on **128 Leonardo HPC nodes (512 NVIDIA A100 GPUs)**. + +## Overview + +This repository provides everything needed to reproduce the training runs evaluated in the paper. + +### Included components + +* **Four configuration files**, one for each parallelization strategy: + + * `config_FSDP.py` β€” Fully Sharded Data Parallel (FSDP) + * `config_HSDP.py` β€” Hierarchical Sharded Data Parallel (HSDP) + * `config_FL+FSDP.py` β€” Federated Learning + FSDP + * `config_FL+HSDP.py` β€” Federated Learning + HSDP + + Each configuration defines: + + * model parameters + * dataset paths + * optimizer settings + * distributed training options + * logging configuration + +* **Training script (`training.py`)** + + Implements a full xFFL-compliant large-scale LLM pre-training pipeline with: + + * deterministic initialization for reproducibility + * distributed setup (FSDP / HSDP / FL+FSDP / FL+HSDP) + * dataset loading and preprocessing + * optional Weights & Biases logging + * multi-node orchestration via xFFL + +## Training Setup + +* **Model:** Llama 3.1-8B +* **Dataset:** clean_mc4_it +* **Precision:** bfloat16 +* **Training type:** full pre-training (not fine-tuning) + +⚠️ Parameter-efficient methods such as LoRA or QLoRA are **not used**. + +## Repository Structure + +``` +EuroPar/ +β”‚ +β”œβ”€β”€ config_FSDP.py # LLaMA 3.1-8B + clean_mc4_it using FSDP +β”œβ”€β”€ config_HSDP.py # LLaMA 3.1-8B + clean_mc4_it using HSDP +β”œβ”€β”€ config_FL+FSDP.py # LLaMA 3.1-8B + clean_mc4_it using FL+FSDP +β”œβ”€β”€ config_FL+HSDP.py # LLaMA 3.1-8B + clean_mc4_it using FL+HSDP +β”œβ”€β”€ training.py # Main distributed training script +``` + +## Requirements + +### Software + +* Python environment compatible with xFFL +* xFFL installed and properly configured +* PyTorch with distributed support +* Access to the Llama 3.1-8B weights +* Access to the clean_mc4_it dataset + +### Hardware + +This code targets large HPC systems. + +At minimum: + +* Multi-node GPU cluster +* High-speed interconnect (e.g., InfiniBand) +* SLURM or compatible scheduler +* Sufficient storage bandwidth for dataset streaming + +## Installation + +Install and configure xFFL according to its official documentation. 
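A minimal installation sketch is shown below; the repository URL, environment name, and editable-install flow are assumptions for illustration, so defer to the official xFFL documentation for the authoritative steps.

```bash
# Hypothetical installation sketch -- the URL and environment name are assumptions,
# not prescribed by this repository; follow the official xFFL documentation.
python -m venv xffl-env                        # create an isolated environment
source xffl-env/bin/activate
git clone https://github.com/alpha-unito/xffl.git
cd xffl
pip install -e .                               # editable install of xFFL and its declared dependencies
```

On clusters whose compute nodes have no outbound internet access (as assumed by the configuration files in this repository), perform the installation from a login node or pre-build the environment into a container image.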
+ +Ensure that: + +* distributed communication works across nodes +* filesystem paths are accessible from all nodes +* required datasets and model checkpoints are available + +## Running the Experiments + +After allocating compute nodes and activating your environment: + +### Basic execution example + +```bash +xffl exec training.py config_FSDP.py +``` + +### What happens during execution + +xFFL will automatically: + +* detect the allocated resources (e.g., via SLURM) +* initialize distributed training +* load the Llama 3.1-8B model from the configured path +* wrap the model with the selected parallelization strategy +* load the dataset +* start synchronized multi-node training + +## Configuration + +All experiment-specific parameters are defined in the configuration files. + +Before running: + +* Update dataset paths +* Update model checkpoint paths +* Adjust output/logging directories +* Verify cluster-specific settings + +⚠️ Paths must be valid on all participating nodes. + + +## Reproducibility Notes + +* Deterministic initialization is enabled where possible +* Performance and scalability results depend on hardware topology +* Exact reproduction requires a system comparable to the Leonardo cluster + +## Limitations + +* Designed for **pre-training workloads**, not fine-tuning +* Assumes availability of large GPU clusters +* Not optimized for single-node execution + +## Adapting the Code + +Although tailored for pre-training, the pipeline can be extended to support: + +* fine-tuning +* alternative datasets +* different model sizes +* custom distributed strategies diff --git a/examples/EuroPar/config_FL+FSDP.py b/examples/EuroPar/config_FL+FSDP.py new file mode 100644 index 0000000..dbd39cc --- /dev/null +++ b/examples/EuroPar/config_FL+FSDP.py @@ -0,0 +1,172 @@ +"""Configuration file for the xFFL-LLM example""" + +import logging +import math +import os +from dataclasses import dataclass, field +from functools import partial +from pathlib import Path +from typing import Callable, Mapping, Sequence, Type + +import torch +from torch import nn +from torch.distributed.fsdp import MixedPrecision +from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy +from torch.optim import AdamW, Optimizer +from torch.optim.lr_scheduler import LambdaLR, LRScheduler +from transformers import AutoModelForCausalLM, default_data_collator +from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaForCausalLM +from transformers.models.mixtral.modeling_mixtral import MixtralDecoderLayer + +from xffl.custom.config import DatasetInfo, ModelInfo, XFFLConfig +from xffl.distributed.distributed_state import DistributedState +from xffl.learning.data import load_datasets_from_disk + +# Constants +LLAMA3_1_8B: str = "llama3.1-8b-init" +CLEAN_MC4_IT: str = "clean_mc4_it" + +BASE_PATH: str = "/leonardo_scratch/fast/uToID_bench/xffl" + + +@dataclass +class llama(ModelInfo): + + # LLM loading from saved model + @staticmethod + def _load_llm_from_checkpoint( + config: XFFLConfig, state: DistributedState + ) -> nn.Module: + return LlamaForCausalLM.from_pretrained( + pretrained_model_name_or_path=str(config.model_info.path), + use_cache=True, + local_files_only=True, # Most HPCs do not have internet access from the nodes + attn_implementation=config.model_info.attention, + dtype=torch.bfloat16, # Model is loaded in torch.bfloat16 (from the JSON file) - also "auto" + device_map=state.init_device, + use_safetensors=True, + low_cpu_mem_usage=True, + tie_word_embeddings=True, + ) + + # Auto wrap 
policy + @staticmethod + def llama_fsdp_wrap_policy(): + return partial( + transformer_auto_wrap_policy, + transformer_layer_cls={ + LlamaDecoderLayer, + }, + ) + + name: str = LLAMA3_1_8B + attention: str = "sdpa" # "flash_attention_2" + model: Callable = _load_llm_from_checkpoint + decoder_layer: Type = LlamaDecoderLayer + wrapping_policy: Callable = llama_fsdp_wrap_policy + mixed_precision: MixedPrecision = field( + default_factory=lambda: MixedPrecision( + param_dtype=torch.bfloat16, + reduce_dtype=torch.bfloat16, + buffer_dtype=torch.bfloat16, + ) + ) + path: str = BASE_PATH + "/models/" + name + + +@dataclass +class cleanmc4it(DatasetInfo): + + @staticmethod + def _get_cleanmc4it_splits(config: XFFLConfig, state: DistributedState): + return load_datasets_from_disk( + splits={"train": "train", "val": "val"}, + base_path=Path(str(config.dataset_info.path)), + ) # Original LLaMA training packs the datasets + + name: str = CLEAN_MC4_IT + splits: Callable = _get_cleanmc4it_splits + batch_sizes: Mapping[str, int] = field( + default_factory=lambda: {"train": 2, "val": 2} + ) + subsampling: Mapping[str, int] = field( + default_factory=lambda: {"train": 65536, "val": 4096} + ) + workers: int = 2 + collate_fn: Callable = default_data_collator + path: str = BASE_PATH + "/data/" + CLEAN_MC4_IT + + +# XFFL configuration +@dataclass +class xffl_config(XFFLConfig): + + # Optimizer + @staticmethod + def _get_optimizer(model: nn.Module, config: XFFLConfig) -> Optimizer: + return AdamW( + params=model.parameters(), + lr=config.learning_rate, # type: ignore + weight_decay=config.weight_decay, # type: ignore + betas=config.betas, # type: ignore + fused=True, # Supported only on torch.float64, torch.float32, torch.float16, and torch.bfloat16 + ) + + # Default + model_info: ModelInfo = field(default_factory=llama) + dataset_info: DatasetInfo = field(default_factory=cleanmc4it) + optimizer: Callable[[nn.Module, XFFLConfig], Optimizer] = _get_optimizer + + # General + loglevel: int = logging.DEBUG + seed: int = 42 + federated: int = 4 + federated_batches: int = 8 + + # Learning + learning_rate: float = 3e-4 + epochs: int = 1 + + # WandB + wandb_entity: str = "alpha-unito" + wandb_project: str = "FL+DP" + wandb_group: str = "FL+FSDP_new" + wandb_notes: str = "EuroPar 2026 experiments" + wandb_tags: Sequence[str] = field(default_factory=lambda: ["xFFL", "EuroPar"]) + wandb_mode: str = "offline" + + # Learning rate scheduler + @staticmethod + def _get_llama31_cosine_schedule( + optimizer: Optimizer, total_steps: int, config: XFFLConfig + ) -> LRScheduler: + """ + Scheduler stile LLaMA3.1 semplificato: warmup -> cosine decay. + + Args: + optimizer: torch.optim.Optimizer + total_steps (int): passi totali (es. 
128) + lr_max (float): learning rate massimo + warmup_frac (float): frazione di warmup (default 5%) + """ + warmup_steps = int(total_steps * config.warmup_frac) # type: ignore + decay_steps = total_steps - warmup_steps + + def lr_lambda(step): + if step < warmup_steps: + # Linear warmup + return step / max(1, warmup_steps) + else: + # Cosine decay + progress = (step - warmup_steps) / max(1, decay_steps) + return 0.5 * (1 + math.cos(math.pi * progress)) + + return LambdaLR(optimizer, lr_lambda=lambda step: lr_lambda(step)) + + # Advanced configuration + lr_scheduler: Callable = _get_llama31_cosine_schedule + + # Custom - optimizer + weight_decay: float = 0.1 + betas: Sequence[float] = (0.9, 0.95) + warmup_frac: float = 0.1 diff --git a/examples/EuroPar/config_FL+HSDP.py b/examples/EuroPar/config_FL+HSDP.py new file mode 100644 index 0000000..b4ccf43 --- /dev/null +++ b/examples/EuroPar/config_FL+HSDP.py @@ -0,0 +1,173 @@ +"""Configuration file for the xFFL-LLM example""" + +import logging +import math +import os +from dataclasses import dataclass, field +from functools import partial +from pathlib import Path +from typing import Callable, Mapping, Sequence, Type + +import torch +from torch import nn +from torch.distributed.fsdp import MixedPrecision +from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy +from torch.optim import AdamW, Optimizer +from torch.optim.lr_scheduler import LambdaLR, LRScheduler +from transformers import AutoModelForCausalLM, default_data_collator +from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaForCausalLM +from transformers.models.mixtral.modeling_mixtral import MixtralDecoderLayer + +from xffl.custom.config import DatasetInfo, ModelInfo, XFFLConfig +from xffl.distributed.distributed_state import DistributedState +from xffl.learning.data import load_datasets_from_disk + +# Constants +LLAMA3_1_8B: str = "llama3.1-8b-init" +CLEAN_MC4_IT: str = "clean_mc4_it" + +BASE_PATH: str = "/leonardo_scratch/fast/uToID_bench/xffl" + + +@dataclass +class llama(ModelInfo): + + # LLM loading from saved model + @staticmethod + def _load_llm_from_checkpoint( + config: XFFLConfig, state: DistributedState + ) -> nn.Module: + return LlamaForCausalLM.from_pretrained( + pretrained_model_name_or_path=str(config.model_info.path), + use_cache=True, + local_files_only=True, # Most HPCs do not have internet access from the nodes + attn_implementation=config.model_info.attention, + dtype=torch.bfloat16, # Model is loaded in torch.bfloat16 (from the JSON file) - also "auto" + device_map=state.init_device, + use_safetensors=True, + low_cpu_mem_usage=True, + tie_word_embeddings=True, + ) + + # Auto wrap policy + @staticmethod + def llama_fsdp_wrap_policy(): + return partial( + transformer_auto_wrap_policy, + transformer_layer_cls={ + LlamaDecoderLayer, + }, + ) + + name: str = LLAMA3_1_8B + attention: str = "sdpa" # "flash_attention_2" + model: Callable = _load_llm_from_checkpoint + decoder_layer: Type = LlamaDecoderLayer + wrapping_policy: Callable = llama_fsdp_wrap_policy + mixed_precision: MixedPrecision = field( + default_factory=lambda: MixedPrecision( + param_dtype=torch.bfloat16, + reduce_dtype=torch.bfloat16, + buffer_dtype=torch.bfloat16, + ) + ) + path: str = BASE_PATH + "/models/" + name + + +@dataclass +class cleanmc4it(DatasetInfo): + + @staticmethod + def _get_cleanmc4it_splits(config: XFFLConfig, state: DistributedState): + return load_datasets_from_disk( + splits={"train": "train", "val": "val"}, + 
base_path=Path(str(config.dataset_info.path)), + ) # Original LLaMA training packs the datasets + + name: str = CLEAN_MC4_IT + splits: Callable = _get_cleanmc4it_splits + batch_sizes: Mapping[str, int] = field( + default_factory=lambda: {"train": 2, "val": 2} + ) + subsampling: Mapping[str, int] = field( + default_factory=lambda: {"train": 65536, "val": 4096} + ) + workers: int = 2 + collate_fn: Callable = default_data_collator + path: str = BASE_PATH + "/data/" + CLEAN_MC4_IT + + +# XFFL configuration +@dataclass +class xffl_config(XFFLConfig): + + # Optimizer + @staticmethod + def _get_optimizer(model: nn.Module, config: XFFLConfig) -> Optimizer: + return AdamW( + params=model.parameters(), + lr=config.learning_rate, # type: ignore + weight_decay=config.weight_decay, # type: ignore + betas=config.betas, # type: ignore + fused=True, # Supported only on torch.float64, torch.float32, torch.float16, and torch.bfloat16 + ) + + # Default + model_info: ModelInfo = field(default_factory=llama) + dataset_info: DatasetInfo = field(default_factory=cleanmc4it) + optimizer: Callable[[nn.Module, XFFLConfig], Optimizer] = _get_optimizer + + # General + loglevel: int = logging.DEBUG + seed: int = 42 + hsdp: int = 4 + federated: int = 32 + federated_batches: int = 8 + + # Learning + learning_rate: float = 3e-4 + epochs: int = 1 + + # WandB + wandb_entity: str = "alpha-unito" + wandb_project: str = "FL+DP" + wandb_group: str = "FL+HSDP_new" + wandb_notes: str = "EuroPar 2026 experiments" + wandb_tags: Sequence[str] = field(default_factory=lambda: ["xFFL", "EuroPar"]) + wandb_mode: str = "offline" + + # Learning rate scheduler + @staticmethod + def _get_llama31_cosine_schedule( + optimizer: Optimizer, total_steps: int, config: XFFLConfig + ) -> LRScheduler: + """ + Scheduler stile LLaMA3.1 semplificato: warmup -> cosine decay. + + Args: + optimizer: torch.optim.Optimizer + total_steps (int): passi totali (es. 
128) + lr_max (float): learning rate massimo + warmup_frac (float): frazione di warmup (default 5%) + """ + warmup_steps = int(total_steps * config.warmup_frac) # type: ignore + decay_steps = total_steps - warmup_steps + + def lr_lambda(step): + if step < warmup_steps: + # Linear warmup + return step / max(1, warmup_steps) + else: + # Cosine decay + progress = (step - warmup_steps) / max(1, decay_steps) + return 0.5 * (1 + math.cos(math.pi * progress)) + + return LambdaLR(optimizer, lr_lambda=lambda step: lr_lambda(step)) + + # Advanced configuration + lr_scheduler: Callable = _get_llama31_cosine_schedule + + # Custom - optimizer + weight_decay: float = 0.1 + betas: Sequence[float] = (0.9, 0.95) + warmup_frac: float = 0.1 diff --git a/examples/EuroPar/config_FSDP.py b/examples/EuroPar/config_FSDP.py new file mode 100644 index 0000000..bfa1047 --- /dev/null +++ b/examples/EuroPar/config_FSDP.py @@ -0,0 +1,170 @@ +"""Configuration file for the xFFL-LLM example""" + +import logging +import math +import os +from dataclasses import dataclass, field +from functools import partial +from pathlib import Path +from typing import Callable, Mapping, Sequence, Type + +import torch +from torch import nn +from torch.distributed.fsdp import MixedPrecision +from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy +from torch.optim import AdamW, Optimizer +from torch.optim.lr_scheduler import LambdaLR, LRScheduler +from transformers import AutoModelForCausalLM, default_data_collator +from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaForCausalLM +from transformers.models.mixtral.modeling_mixtral import MixtralDecoderLayer + +from xffl.custom.config import DatasetInfo, ModelInfo, XFFLConfig +from xffl.distributed.distributed_state import DistributedState +from xffl.learning.data import load_datasets_from_disk + +# Constants +LLAMA3_1_8B: str = "llama3.1-8b-init" +CLEAN_MC4_IT: str = "clean_mc4_it" + +BASE_PATH: str = "/leonardo_scratch/fast/uToID_bench/xffl" + + +@dataclass +class llama(ModelInfo): + + # LLM loading from saved model + @staticmethod + def _load_llm_from_checkpoint( + config: XFFLConfig, state: DistributedState + ) -> nn.Module: + return LlamaForCausalLM.from_pretrained( + pretrained_model_name_or_path=str(config.model_info.path), + use_cache=True, + local_files_only=True, # Most HPCs do not have internet access from the nodes + attn_implementation=config.model_info.attention, + dtype=torch.bfloat16, # Model is loaded in torch.bfloat16 (from the JSON file) - also "auto" + device_map=state.init_device, + use_safetensors=True, + low_cpu_mem_usage=True, + tie_word_embeddings=True, + ) + + # Auto wrap policy + @staticmethod + def llama_fsdp_wrap_policy(): + return partial( + transformer_auto_wrap_policy, + transformer_layer_cls={ + LlamaDecoderLayer, + }, + ) + + name: str = LLAMA3_1_8B + attention: str = "sdpa" # "flash_attention_2" + model: Callable = _load_llm_from_checkpoint + decoder_layer: Type = LlamaDecoderLayer + wrapping_policy: Callable = llama_fsdp_wrap_policy + mixed_precision: MixedPrecision = field( + default_factory=lambda: MixedPrecision( + param_dtype=torch.bfloat16, + reduce_dtype=torch.bfloat16, + buffer_dtype=torch.bfloat16, + ) + ) + path: str = BASE_PATH + "/models/" + name + + +@dataclass +class cleanmc4it(DatasetInfo): + + @staticmethod + def _get_cleanmc4it_splits(config: XFFLConfig, state: DistributedState): + return load_datasets_from_disk( + splits={"train": "train", "val": "val"}, + base_path=Path(str(config.dataset_info.path)), + 
) # Original LLaMA training packs the datasets + + name: str = CLEAN_MC4_IT + splits: Callable = _get_cleanmc4it_splits + batch_sizes: Mapping[str, int] = field( + default_factory=lambda: {"train": 2, "val": 2} + ) + subsampling: Mapping[str, int] = field( + default_factory=lambda: {"train": 65536, "val": 4096} + ) + workers: int = 2 + collate_fn: Callable = default_data_collator + path: str = BASE_PATH + "/data/" + CLEAN_MC4_IT + + +# XFFL configuration +@dataclass +class xffl_config(XFFLConfig): + + # Optimizer + @staticmethod + def _get_optimizer(model: nn.Module, config: XFFLConfig) -> Optimizer: + return AdamW( + params=model.parameters(), + lr=config.learning_rate, # type: ignore + weight_decay=config.weight_decay, # type: ignore + betas=config.betas, # type: ignore + fused=True, # Supported only on torch.float64, torch.float32, torch.float16, and torch.bfloat16 + ) + + # Default + model_info: ModelInfo = field(default_factory=llama) + dataset_info: DatasetInfo = field(default_factory=cleanmc4it) + optimizer: Callable[[nn.Module, XFFLConfig], Optimizer] = _get_optimizer + + # General + loglevel: int = logging.DEBUG + seed: int = 42 + + # Learning + learning_rate: float = 3e-4 + epochs: int = 1 + + # WandB + wandb_entity: str = "alpha-unito" + wandb_project: str = "FL+DP" + wandb_group: str = "FSDP_new" + wandb_notes: str = "EuroPar 2026 experiments" + wandb_tags: Sequence[str] = field(default_factory=lambda: ["xFFL", "EuroPar"]) + wandb_mode: str = "offline" + + # Learning rate scheduler + @staticmethod + def _get_llama31_cosine_schedule( + optimizer: Optimizer, total_steps: int, config: XFFLConfig + ) -> LRScheduler: + """ + Scheduler stile LLaMA3.1 semplificato: warmup -> cosine decay. + + Args: + optimizer: torch.optim.Optimizer + total_steps (int): passi totali (es. 
128) + lr_max (float): learning rate massimo + warmup_frac (float): frazione di warmup (default 5%) + """ + warmup_steps = int(total_steps * config.warmup_frac) # type: ignore + decay_steps = total_steps - warmup_steps + + def lr_lambda(step): + if step < warmup_steps: + # Linear warmup + return step / max(1, warmup_steps) + else: + # Cosine decay + progress = (step - warmup_steps) / max(1, decay_steps) + return 0.5 * (1 + math.cos(math.pi * progress)) + + return LambdaLR(optimizer, lr_lambda=lambda step: lr_lambda(step)) + + # Advanced configuration + lr_scheduler: Callable = _get_llama31_cosine_schedule + + # Custom - optimizer + weight_decay: float = 0.1 + betas: Sequence[float] = (0.9, 0.95) + warmup_frac: float = 0.1 diff --git a/examples/EuroPar/config_HSDP.py b/examples/EuroPar/config_HSDP.py new file mode 100644 index 0000000..5cc4f55 --- /dev/null +++ b/examples/EuroPar/config_HSDP.py @@ -0,0 +1,171 @@ +"""Configuration file for the xFFL-LLM example""" + +import logging +import math +import os +from dataclasses import dataclass, field +from functools import partial +from pathlib import Path +from typing import Callable, Mapping, Sequence, Type + +import torch +from torch import nn +from torch.distributed.fsdp import MixedPrecision +from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy +from torch.optim import AdamW, Optimizer +from torch.optim.lr_scheduler import LambdaLR, LRScheduler +from transformers import AutoModelForCausalLM, default_data_collator +from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaForCausalLM +from transformers.models.mixtral.modeling_mixtral import MixtralDecoderLayer + +from xffl.custom.config import DatasetInfo, ModelInfo, XFFLConfig +from xffl.distributed.distributed_state import DistributedState +from xffl.learning.data import load_datasets_from_disk + +# Constants +LLAMA3_1_8B: str = "llama3.1-8b-init" +CLEAN_MC4_IT: str = "clean_mc4_it" + +BASE_PATH: str = "/leonardo_scratch/fast/uToID_bench/xffl" + + +@dataclass +class llama(ModelInfo): + + # LLM loading from saved model + @staticmethod + def _load_llm_from_checkpoint( + config: XFFLConfig, state: DistributedState + ) -> nn.Module: + return LlamaForCausalLM.from_pretrained( + pretrained_model_name_or_path=str(config.model_info.path), + use_cache=True, + local_files_only=True, # Most HPCs do not have internet access from the nodes + attn_implementation=config.model_info.attention, + dtype=torch.bfloat16, # Model is loaded in torch.bfloat16 (from the JSON file) - also "auto" + device_map=state.init_device, + use_safetensors=True, + low_cpu_mem_usage=True, + tie_word_embeddings=True, + ) + + # Auto wrap policy + @staticmethod + def llama_fsdp_wrap_policy(): + return partial( + transformer_auto_wrap_policy, + transformer_layer_cls={ + LlamaDecoderLayer, + }, + ) + + name: str = LLAMA3_1_8B + attention: str = "sdpa" # "flash_attention_2" + model: Callable = _load_llm_from_checkpoint + decoder_layer: Type = LlamaDecoderLayer + wrapping_policy: Callable = llama_fsdp_wrap_policy + mixed_precision: MixedPrecision = field( + default_factory=lambda: MixedPrecision( + param_dtype=torch.bfloat16, + reduce_dtype=torch.bfloat16, + buffer_dtype=torch.bfloat16, + ) + ) + path: str = BASE_PATH + "/models/" + name + + +@dataclass +class cleanmc4it(DatasetInfo): + + @staticmethod + def _get_cleanmc4it_splits(config: XFFLConfig, state: DistributedState): + return load_datasets_from_disk( + splits={"train": "train", "val": "val"}, + base_path=Path(str(config.dataset_info.path)), + 
) # Original LLaMA training packs the datasets + + name: str = CLEAN_MC4_IT + splits: Callable = _get_cleanmc4it_splits + batch_sizes: Mapping[str, int] = field( + default_factory=lambda: {"train": 2, "val": 2} + ) + subsampling: Mapping[str, int] = field( + default_factory=lambda: {"train": 65536, "val": 4096} + ) + workers: int = 2 + collate_fn: Callable = default_data_collator + path: str = BASE_PATH + "/data/" + CLEAN_MC4_IT + + +# XFFL configuration +@dataclass +class xffl_config(XFFLConfig): + + # Optimizer + @staticmethod + def _get_optimizer(model: nn.Module, config: XFFLConfig) -> Optimizer: + return AdamW( + params=model.parameters(), + lr=config.learning_rate, # type: ignore + weight_decay=config.weight_decay, # type: ignore + betas=config.betas, # type: ignore + fused=True, # Supported only on torch.float64, torch.float32, torch.float16, and torch.bfloat16 + ) + + # Default + model_info: ModelInfo = field(default_factory=llama) + dataset_info: DatasetInfo = field(default_factory=cleanmc4it) + optimizer: Callable[[nn.Module, XFFLConfig], Optimizer] = _get_optimizer + + # General + loglevel: int = logging.DEBUG + seed: int = 42 + hsdp: int = 4 + + # Learning + learning_rate: float = 3e-4 + epochs: int = 1 + + # WandB + wandb_entity: str = "alpha-unito" + wandb_project: str = "FL+DP" + wandb_group: str = "HSDP_new" + wandb_notes: str = "EuroPar 2026 experiments" + wandb_tags: Sequence[str] = field(default_factory=lambda: ["xFFL", "EuroPar"]) + wandb_mode: str = "offline" + + # Learning rate scheduler + @staticmethod + def _get_llama31_cosine_schedule( + optimizer: Optimizer, total_steps: int, config: XFFLConfig + ) -> LRScheduler: + """ + Scheduler stile LLaMA3.1 semplificato: warmup -> cosine decay. + + Args: + optimizer: torch.optim.Optimizer + total_steps (int): passi totali (es. 
128) + lr_max (float): learning rate massimo + warmup_frac (float): frazione di warmup (default 5%) + """ + warmup_steps = int(total_steps * config.warmup_frac) # type: ignore + decay_steps = total_steps - warmup_steps + + def lr_lambda(step): + if step < warmup_steps: + # Linear warmup + return step / max(1, warmup_steps) + else: + # Cosine decay + progress = (step - warmup_steps) / max(1, decay_steps) + return 0.5 * (1 + math.cos(math.pi * progress)) + + return LambdaLR(optimizer, lr_lambda=lambda step: lr_lambda(step)) + + # Advanced configuration + lr_scheduler: Callable = _get_llama31_cosine_schedule + + # Custom - optimizer + weight_decay: float = 0.1 + betas: Sequence[float] = (0.9, 0.95) + warmup_frac: float = 0.1 diff --git a/examples/EuroPar/plots/time_to_perp.py b/examples/EuroPar/plots/time_to_perp.py new file mode 100644 index 0000000..cf8e567 --- /dev/null +++ b/examples/EuroPar/plots/time_to_perp.py @@ -0,0 +1,165 @@ +from pathlib import Path +from typing import List + +import matplotlib.pyplot as plt +import matplotlib.ticker as ticker +import pandas as pd +import seaborn as sns + +# ========================= +# Configurazione globale +# ========================= +sns.set_theme(style="darkgrid") + +# ========================= +# Font piΓΉ grandi (globali) +# ========================= +plt.rcParams.update( + { + "font.size": 20, # base + "axes.titlesize": 20, + "axes.labelsize": 20, + "xtick.labelsize": 20, + "ytick.labelsize": 20, + "legend.fontsize": 20, + "legend.title_fontsize": 20, + } +) + +LOG_DIR = Path("/leonardo_scratch/fast/uToID_bench/xffl/examples/EuroPar/plots") +ENGINE = "pyarrow" +SEP = ";" + +METHODS: List[str] = [ + "FSDP", + "HSDP", + "FL+FSDP", + "FL+HSDP", +] + +TIME_COL = "Relative Time (Process)" +STEP_COL = "Step" + +X_LABEL = "Time (minutes)" +Y_LABEL = "Perplexity" + +X_LIM = (0, 730) +Y_LIM = (1e2, 1e5) + +XTICKS = [i * 60 for i in range(0, (X_LIM[1] // 60) + 1)] +XTICKLABELS = list(range(0, (X_LIM[1] // 60) + 1)) + + +# ========================= +# Funzioni +# ========================= +def load_time_data(method: str) -> pd.DataFrame: + return pd.read_csv( + LOG_DIR / f"time_to_perp_{method}.csv", + sep=SEP, + engine=ENGINE, + ) + + +def load_step_data() -> pd.DataFrame: + return pd.read_csv( + LOG_DIR / "step_to_perp.csv", + sep=SEP, + engine=ENGINE, + ) + + +def preprocess( + method: str, time_df: pd.DataFrame, step_df: pd.DataFrame +) -> pd.DataFrame: + perp = f"Group: {method}_new - train/Step_perplexity" + perp_step = f"Group: {method}_new - _step" + perp_min = f"Group: {method}_new - train/Step_perplexity__MIN" + perp_max = f"Group: {method}_new - train/Step_perplexity__MAX" + + # Tempo medio per step + time_processed = ( + time_df[[TIME_COL, perp_step]] + .dropna() + .rename(columns={perp_step: STEP_COL}) + .astype({STEP_COL: int}) + .groupby(STEP_COL, as_index=False) + .mean() + ) + + # Merge con metriche + step_metrics = step_df[[STEP_COL, perp, perp_min, perp_max]] + merged = pd.merge(time_processed, step_metrics, on=STEP_COL) + + # Normalizza tempo (parte da zero) + merged[TIME_COL] -= merged[TIME_COL].iloc[0] + + return merged + + +def plot_method(ax, data: pd.DataFrame, method: str): + perp = f"Group: {method}_new - train/Step_perplexity" + perp_min = f"Group: {method}_new - train/Step_perplexity__MIN" + perp_max = f"Group: {method}_new - train/Step_perplexity__MAX" + + sns.lineplot( + ax=ax, + data=data, + x=TIME_COL, + y=perp, + label=method, + ) + + ax.fill_between( + data[TIME_COL], + data[perp_min], + data[perp_max], + alpha=0.2, 
+ ) + + +# ========================= +# Main +# ========================= +def main(): + fig, ax = plt.subplots(figsize=(16, 9)) + + step_df = load_step_data() + + for method in METHODS: + time_df = load_time_data(method) + processed = preprocess(method, time_df, step_df) + plot_method(ax, processed, method) + + ax.set( + xlabel=X_LABEL, + ylabel=Y_LABEL, + xlim=X_LIM, + ylim=Y_LIM, + yscale="log", + ) + + # ========================= + # Log-scale grid piΓΉ densa + # ========================= + ax.yaxis.set_major_locator(ticker.LogLocator(base=10.0)) + ax.yaxis.set_minor_locator( + ticker.LogLocator(base=10.0, subs=(0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9)) + ) + ax.yaxis.set_minor_formatter(ticker.NullFormatter()) + + ax.grid(which="major", linestyle="-", linewidth=1, alpha=1) + ax.grid(which="minor", linestyle="--", linewidth=0.9, alpha=1) + + ax.set_xticks(XTICKS) + ax.set_xticklabels(XTICKLABELS) + + ax.legend() + fig.tight_layout() + + fig.savefig("out.png", dpi=300, bbox_inches="tight") + fig.savefig("time_to_perp.pdf", dpi=300, bbox_inches="tight") + + +if __name__ == "__main__": + main() diff --git a/examples/EuroPar/training.py b/examples/EuroPar/training.py new file mode 100644 index 0000000..768ef8e --- /dev/null +++ b/examples/EuroPar/training.py @@ -0,0 +1,134 @@ +"""LLM training example script + +Inspired from llama-recipes' fine-tuning.py script: +https://github.com/meta-llama/llama-cookbook/blob/main/src/llama_recipes/finetuning.py +""" + +import time +from logging import Logger, getLogger +from typing import Any, MutableMapping, Optional + +import torch +import torch.nn as nn +import wandb +from config_FSDP import xffl_config +from torch.optim import Optimizer +from torch.utils.data import DataLoader + +from xffl.custom.config import XFFLConfig +from xffl.distributed import distributed +from xffl.learning import modelling, processing, utils +from xffl.learning.data import create_dataloaders +from xffl.learning.utils import wandb_setup +from xffl.utils.logging import setup_logging + +logger: Logger = getLogger(__name__) +"""Default xFFL logger""" + + +def pretraining(config: XFFLConfig) -> None: + """Simple MLP training script + + :param config: xFFL configuration + :type config: XFFLConfig + """ + setup_time: float = time.perf_counter() + + # Set the requested logging level + setup_logging(log_level=config.loglevel) + + # Sets RNGs seeds and force PyTorch's deterministic execution + generator: Optional[torch.Generator] = utils.set_deterministic_execution( + config=config + ) + + # PyTorch's distributed backend setup + start_time: float = time.perf_counter() + state: distributed.DistributedState = distributed.setup_distributed_process_group( + config=config + ) + if state.rank == 0: + logger.debug( + f"Rendez-vous time: {(time.perf_counter() - start_time):.2f} seconds" + ) + + # Large data preloading in background + start_time: float = time.perf_counter() + if state.node_local_rank == 0: + utils.preload(files=[config.model_info.path, config.dataset_info.path]) + + # Model setup + start_time: float = time.perf_counter() + model: nn.Module = modelling.create_fsdp_model(state=state, config=config) + + if state.rank == 0: + logger.debug( + f"Model loading time: {(time.perf_counter() - start_time):.2f} seconds" + ) + logger.debug( + f"Training {config.model_info.name}: {(utils.get_model_size(model=model, state=state) / 1e6):.2f} million trainable parameters" + ) + + # Dataset loading + start_time: float = time.perf_counter() + dataloaders: Optional[MutableMapping[str, 
DataLoader]] = create_dataloaders( + state=state, + config=config, + generator=generator, + ) + if state.rank == 0: + logger.debug( + f"Dataset loading time: {(time.perf_counter() - start_time):.2f} seconds" + ) + + # Optimizer and lr scheduler creation + optimizer: Optimizer = config.optimizer(model=model, config=config) # type: ignore + + if state.rank == 0: + logger.debug( + f"Total setup time: {(time.perf_counter() - setup_time):.2f} seconds" + ) + logger.debug( + f"GPU RAM allocated before training: {torch.cuda.max_memory_allocated() / 10**9:.2f} GB" + ) + + # WandB setup + wandb_run: Any = wandb_setup(name=f"rank_{state.rank}", config=config) + + # Main training function + results = processing.distributed_training( + model=model, + state=state, + optimizer=optimizer, + train_dataloader=dataloaders["train"], + val_dataloader=dataloaders["val"], + config=config, + wandb_run=wandb_run, + ) + + if state.rank == 0: + [logger.info(f"{k}{v:.2f}") for k, v in results.items()] + if wandb_run is not None: + for k, v in results.items(): + wandb_run.summary[k] = v + + # PyTorch's distributed backend cleanup + wandb.finish() + distributed.cleanup_distributed_process_group( + state=state, del_obj=(model, optimizer) + ) + + +def main(): + """Argument parsing and training launch""" + + try: + pretraining(config=xffl_config()) + except KeyboardInterrupt as e: + logger.exception(e) + except Exception as e: + logger.exception(e) + + +if __name__ == "__main__": + main() diff --git a/xffl/cli/exec.py b/xffl/cli/exec.py index b22e270..cfb8c6f 100644 --- a/xffl/cli/exec.py +++ b/xffl/cli/exec.py @@ -183,6 +183,7 @@ def exec(args: Namespace) -> int: # Federated scaling if args.federated_scaling is not None: if args.federated_scaling == "auto": + logger.debug("Setting automatic federated scaling...") federated_local_size: Tuple[int, ...] 
= get_cells_ids( nodes=args.nodelist, cell_size=180 ) diff --git a/xffl/distributed/aggregation.py b/xffl/distributed/aggregation.py index 1e81d4e..979aeac 100644 --- a/xffl/distributed/aggregation.py +++ b/xffl/distributed/aggregation.py @@ -233,23 +233,10 @@ def get_average_distributed_loss( if state.backend == "nccl": _loss: Tensor = loss.to(device=state.current_device, non_blocking=True) - if state.is_federated_scaling_setup(): - assert state.federated_local_size is not None - assert state.federated_rank is not None - assert state.federated_group is not None + assert state.world_size is not None - scale_factor: int = state.federated_local_size[state.federated_rank] - - group: Optional[ProcessGroup] = ( - state.federated_group[0] - if state.streams is None - else state.federated_group[state.federated_rank] - ) - else: - assert state.world_size is not None - - scale_factor: int = state.world_size - group: Optional[ProcessGroup] = dist.group.WORLD + scale_factor: int = state.world_size + group: Optional[ProcessGroup] = dist.group.WORLD _loss /= tensor(total_length) dist.all_reduce(tensor=_loss, op=dist.ReduceOp.SUM, group=group) diff --git a/xffl/distributed/distributed.py b/xffl/distributed/distributed.py index c9cce36..c887d6e 100644 --- a/xffl/distributed/distributed.py +++ b/xffl/distributed/distributed.py @@ -220,9 +220,10 @@ def setup_distributed_process_group( if "XFFL_FEDERATED_LOCAL_WORLD_SIZE" in os.environ: _federated = tuple( int(item) * state.node_local_size - for item in str( - os.environ.get("XFFL_FEDERATED_LOCAL_WORLD_SIZE") - ).split(",") + for item in str(os.environ.get("XFFL_FEDERATED_LOCAL_WORLD_SIZE")) + .replace("(", "") + .replace(")", "") + .split(",") ) elif len(_federated) == 1: if state.world_size % _federated[0] != 0: diff --git a/xffl/learning/modelling.py b/xffl/learning/modelling.py index d714a24..72ab5ea 100644 --- a/xffl/learning/modelling.py +++ b/xffl/learning/modelling.py @@ -30,7 +30,7 @@ def create_fsdp_model( module: Optional[nn.Module] = None, wrapping_policy: Optional[Callable] = None, mixed_precision: Optional[MixedPrecision] = None, - decoder_layers: Optional[Type] = None, + decoder_layer: Optional[Type] = None, activation_checkpointing: Optional[bool] = None, config: Optional[XFFLConfig] = None, use_orig_params: bool = False, @@ -75,8 +75,8 @@ def create_fsdp_model( _mixed_precision: Optional[MixedPrecision] = resolve_param( value=mixed_precision, config=model_info, attr="mixed_precision" ) - _decoder_layers: Optional[Type] = resolve_param( - value=decoder_layers, config=model_info, attr="decoder_layers" + _decoder_layer: Optional[Type] = resolve_param( + value=decoder_layer, config=model_info, attr="decoder_layer" ) _activation_checkpointing: Optional[bool] = resolve_param( value=activation_checkpointing, @@ -87,7 +87,7 @@ def create_fsdp_model( _module: Optional[nn.Module] = module _wrapping_policy: Optional[Callable] = wrapping_policy _mixed_precision: Optional[MixedPrecision] = mixed_precision - _decoder_layers: Optional[Type] = decoder_layers + _decoder_layer: Optional[Type] = decoder_layer _activation_checkpointing: Optional[bool] = activation_checkpointing # Model and device mashes creation @@ -128,14 +128,14 @@ def create_fsdp_model( # Activation checkpointing # This can also be called before FSDP, will result in applying the HF-specific method, giving warnings during the training if _activation_checkpointing is not None and _activation_checkpointing: - if _decoder_layers is not None: + if _decoder_layer is not None: 
logger.debug("Activating activation checkpointing.") utils.set_activation_checkpointing( model=model, - layer=_decoder_layers, + layer=_decoder_layer, ) logger.info( - f"Activation checkpointing activated on the {_decoder_layers} layers." + f"Activation checkpointing activated on the {_decoder_layer} layer." ) else: logger.warning( diff --git a/xffl/learning/processing.py b/xffl/learning/processing.py index 38db707..78ebf77 100644 --- a/xffl/learning/processing.py +++ b/xffl/learning/processing.py @@ -695,38 +695,31 @@ def distributed_training( if wandb_run: metrics: Mapping[str, Any] = { "train/Step": epoch * total_length + step, - "train/Step loss": train_step_loss[-1], - "train/Step perplexity": train_step_perplexity[-1], - "train/Optimizer step": optimizer_step, - "train/Learning rate": ( - _lr_scheduler.get_lr() - if _lr_scheduler is not None - else optimizer.param_groups[0]["lr"] - ), + "train/Step_loss": train_step_loss[-1], + "train/Step_perplexity": train_step_perplexity[-1], + "opt/Step": optimizer_step, + "opt/lr": optimizer.param_groups[0]["lr"], } if state.is_federated_scaling_setup(): - metrics["train/Aggregation step"] = aggregation + metrics["train/Aggregation_step"] = aggregation if _fedopt: assert _fedopt_lr_scheduler is not None assert _fedopt_optimizer is not None - metrics["train/FedOpt learning rate"] = ( - ( - _fedopt_lr_scheduler.get_lr() - if _lr_scheduler is not None - else _fedopt_optimizer.param_groups[0]["lr"] - ), - ) + metrics["opt/FedOpt_lr"] = _fedopt_optimizer.param_groups[0][ + "lr" + ] + if logging.root.level == logging.DEBUG: metrics.update( { - "train/Forward time": batch_time, - "train/Backward time": back_time, - "train/Optimizer time": optimizer_time, - "train/Aggregation time": comm_time, - "train/Other time": other_step_time, - "train/Overall step time": overall_step_time, + "time/Forward": round(batch_time, 2), + "time/Backward": round(back_time, 2), + "time/Optimizer": round(optimizer_time, 2), + "time/Aggregation": round(comm_time, 2), + "time/Other": round(other_step_time, 2), + "time/Overall_step": round(overall_step_time, 2), } ) wandb_run.log(metrics) @@ -744,9 +737,9 @@ def distributed_training( if wandb_run: metrics: Mapping[str, Any] = { "train/Epoch": epoch + 1, - "train/Epoch loss": _train_epoch_loss, - "train/Epoch perplexity": train_epoch_perplexity, - "train/Epoch time": epoch_times[-1], + "train/Epoch_loss": _train_epoch_loss, + "train/Epoch_perplexity": train_epoch_perplexity, + "train/Epoch_time": epoch_times[-1], } wandb_run.log( metrics, @@ -924,10 +917,10 @@ def validation( if wandb_run: metrics: Mapping[str, Any] = { - "train/Epoch": epoch + 1, - "eval/Epoch loss": _val_epoch_loss, - "eval/Epoch perplexity": val_epoch_perplexity, - "train/Epoch time": epoch_total_time, + "eval/Epoch": epoch + 1, + "eval/Epoch_loss": _val_epoch_loss, + "eval/Epoch_perplexity": val_epoch_perplexity, + "time/Eval_time": epoch_total_time, } if correct is not None: metrics["eval/accuracy"] = val_acc diff --git a/xffl/learning/utils.py b/xffl/learning/utils.py index 719113a..3bcdc8e 100644 --- a/xffl/learning/utils.py +++ b/xffl/learning/utils.py @@ -5,6 +5,7 @@ import random import subprocess import sys +from dataclasses import asdict from logging import Logger, getLogger from pathlib import Path from typing import Any, Literal, Optional, Sequence, Type @@ -108,12 +109,12 @@ def get_model_size(model: nn.Module, state: DistributedState) -> int: params *= state.replica_local_size - elif state.is_fsdp_setup: + elif state.is_fsdp_setup(): assert 
         assert state.world_size is not None

         params *= state.world_size

-    if state.is_federated_scaling_setup:
+    if state.is_federated_scaling_setup():
         assert state.federated_world_size is not None

         params //= state.federated_world_size
@@ -324,4 +325,5 @@ def wandb_setup(
         notes=_notes,
         tags=_tags,
         mode=_mode,
+        config=asdict(config) if config is not None else None,
     )
diff --git a/xffl/utils/utils.py b/xffl/utils/utils.py
index 568b218..e095577 100644
--- a/xffl/utils/utils.py
+++ b/xffl/utils/utils.py
@@ -13,7 +13,7 @@
 def get_timeout(
-    seconds: float = 120.0,
+    seconds: float = 3600.0,
 ) -> timedelta:
     """Maximum allowed timeout for distributed communications

From 7fd25516319e64becc2e75f0db32069b9913962e Mon Sep 17 00:00:00 2001
From: Gianluca Mittone
Date: Mon, 16 Mar 2026 11:08:19 +0100
Subject: [PATCH 3/5] Fix minor linting issues

---
 examples/EuroPar/config_FL+FSDP.py | 4 +---
 examples/EuroPar/config_FL+HSDP.py | 4 +---
 examples/EuroPar/config_FSDP.py    | 4 +---
 examples/EuroPar/config_HSDP.py    | 4 +---
 4 files changed, 4 insertions(+), 12 deletions(-)

diff --git a/examples/EuroPar/config_FL+FSDP.py b/examples/EuroPar/config_FL+FSDP.py
index dbd39cc..a1dfbf1 100644
--- a/examples/EuroPar/config_FL+FSDP.py
+++ b/examples/EuroPar/config_FL+FSDP.py
@@ -2,7 +2,6 @@
 import logging
 import math
-import os
 from dataclasses import dataclass, field
 from functools import partial
 from pathlib import Path
@@ -14,9 +13,8 @@
 from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
 from torch.optim import AdamW, Optimizer
 from torch.optim.lr_scheduler import LambdaLR, LRScheduler
-from transformers import AutoModelForCausalLM, default_data_collator
+from transformers import default_data_collator
 from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaForCausalLM
-from transformers.models.mixtral.modeling_mixtral import MixtralDecoderLayer

 from xffl.custom.config import DatasetInfo, ModelInfo, XFFLConfig
 from xffl.distributed.distributed_state import DistributedState
diff --git a/examples/EuroPar/config_FL+HSDP.py b/examples/EuroPar/config_FL+HSDP.py
index b4ccf43..ce0ec6d 100644
--- a/examples/EuroPar/config_FL+HSDP.py
+++ b/examples/EuroPar/config_FL+HSDP.py
@@ -2,7 +2,6 @@
 import logging
 import math
-import os
 from dataclasses import dataclass, field
 from functools import partial
 from pathlib import Path
@@ -14,9 +13,8 @@
 from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
 from torch.optim import AdamW, Optimizer
 from torch.optim.lr_scheduler import LambdaLR, LRScheduler
-from transformers import AutoModelForCausalLM, default_data_collator
+from transformers import default_data_collator
 from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaForCausalLM
-from transformers.models.mixtral.modeling_mixtral import MixtralDecoderLayer

 from xffl.custom.config import DatasetInfo, ModelInfo, XFFLConfig
 from xffl.distributed.distributed_state import DistributedState
diff --git a/examples/EuroPar/config_FSDP.py b/examples/EuroPar/config_FSDP.py
index bfa1047..a98c784 100644
--- a/examples/EuroPar/config_FSDP.py
+++ b/examples/EuroPar/config_FSDP.py
@@ -2,7 +2,6 @@
 import logging
 import math
-import os
 from dataclasses import dataclass, field
 from functools import partial
 from pathlib import Path
@@ -14,9 +13,8 @@
 from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
 from torch.optim import AdamW, Optimizer
 from torch.optim.lr_scheduler import LambdaLR, LRScheduler
-from transformers import AutoModelForCausalLM, default_data_collator
+from transformers import default_data_collator
 from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaForCausalLM
-from transformers.models.mixtral.modeling_mixtral import MixtralDecoderLayer

 from xffl.custom.config import DatasetInfo, ModelInfo, XFFLConfig
 from xffl.distributed.distributed_state import DistributedState
diff --git a/examples/EuroPar/config_HSDP.py b/examples/EuroPar/config_HSDP.py
index 5cc4f55..7398599 100644
--- a/examples/EuroPar/config_HSDP.py
+++ b/examples/EuroPar/config_HSDP.py
@@ -2,7 +2,6 @@
 import logging
 import math
-import os
 from dataclasses import dataclass, field
 from functools import partial
 from pathlib import Path
@@ -14,9 +13,8 @@
 from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
 from torch.optim import AdamW, Optimizer
 from torch.optim.lr_scheduler import LambdaLR, LRScheduler
-from transformers import AutoModelForCausalLM, default_data_collator
+from transformers import default_data_collator
 from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaForCausalLM
-from transformers.models.mixtral.modeling_mixtral import MixtralDecoderLayer

 from xffl.custom.config import DatasetInfo, ModelInfo, XFFLConfig
 from xffl.distributed.distributed_state import DistributedState

From 0f3251a8f9f38555662f429d215ba1263b3c90ae Mon Sep 17 00:00:00 2001
From: Gianluca Mittone
Date: Mon, 16 Mar 2026 11:18:43 +0100
Subject: [PATCH 4/5] Minor fixes

---
 examples/intra-silo/01_simple_MLP/training.py    | 2 +-
 examples/intra-silo/04_LLM_tokenizer/training.py | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/examples/intra-silo/01_simple_MLP/training.py b/examples/intra-silo/01_simple_MLP/training.py
index 7822fbc..aedf8aa 100644
--- a/examples/intra-silo/01_simple_MLP/training.py
+++ b/examples/intra-silo/01_simple_MLP/training.py
@@ -56,7 +56,7 @@ def pretraining(config: XFFLConfig) -> None:
         f"Model loading time: {(time.perf_counter() - start_time):.2f} seconds"
     )
     logger.debug(
-        f"Training {config.model_info.name}: {(utils.get_model_size(model=model) / 1e6):.2f} million trainable parameters"
+        f"Training {config.model_info.name}: {(utils.get_model_size(model=model, state=state) / 1e6):.2f} million trainable parameters"
     )

     # Dataset loading
diff --git a/examples/intra-silo/04_LLM_tokenizer/training.py b/examples/intra-silo/04_LLM_tokenizer/training.py
index b1b30b1..ae236fb 100644
--- a/examples/intra-silo/04_LLM_tokenizer/training.py
+++ b/examples/intra-silo/04_LLM_tokenizer/training.py
@@ -65,7 +65,7 @@ def pretraining(config: XFFLConfig) -> None:
         f"Model loading time: {(time.perf_counter() - start_time):.2f} seconds"
     )
     logger.debug(
-        f"Training {config.model_info.name}: {(utils.get_model_size(model=model) / 1e6):.2f} million trainable parameters"
+        f"Training {config.model_info.name}: {(utils.get_model_size(model=model, state=state) / 1e6):.2f} million trainable parameters"
     )

     # Dataset loading

From 5e436259b8475440205add68705ddcc2550aedbe Mon Sep 17 00:00:00 2001
From: Gianluca Mittone
Date: Mon, 16 Mar 2026 11:30:08 +0100
Subject: [PATCH 5/5] Minor linting fix

---
 examples/intra-silo/03_LLM/config.py | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/examples/intra-silo/03_LLM/config.py b/examples/intra-silo/03_LLM/config.py
index 2ebb41a..b8f9a7c 100644
--- a/examples/intra-silo/03_LLM/config.py
+++ b/examples/intra-silo/03_LLM/config.py
@@ -2,7 +2,6 @@
 import logging
 import math
-import os
 from dataclasses import dataclass, field
 from functools import partial
 from pathlib import Path
@@ -14,9 +13,8 @@
 from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
 from torch.optim import AdamW, Optimizer
 from torch.optim.lr_scheduler import LambdaLR, LRScheduler
-from transformers import AutoModelForCausalLM, default_data_collator
+from transformers import default_data_collator
 from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaForCausalLM
-from transformers.models.mixtral.modeling_mixtral import MixtralDecoderLayer

 from xffl.custom.config import DatasetInfo, ModelInfo, XFFLConfig
 from xffl.distributed.distributed_state import DistributedState
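
The xffl/distributed/distributed.py hunk in PATCH 2/5 above changes how XFFL_FEDERATED_LOCAL_WORLD_SIZE is parsed: parentheses are stripped before splitting on commas, so values written as "(2, 2)" and as "2,2" are treated the same. A minimal standalone sketch of the resulting behaviour follows; the environment value and the node_local_size of 4 GPUs per node are hypothetical and only serve to illustrate the parsing, this is not the full setup_distributed_process_group logic.

import os

# Hypothetical input: two federated groups of two nodes each, written with parentheses
os.environ["XFFL_FEDERATED_LOCAL_WORLD_SIZE"] = "(2, 2)"
node_local_size = 4  # assumed GPUs per node

# Same parsing expression as the patched code path
federated = tuple(
    int(item) * node_local_size
    for item in str(os.environ.get("XFFL_FEDERATED_LOCAL_WORLD_SIZE"))
    .replace("(", "")
    .replace(")", "")
    .split(",")
)
print(federated)  # (8, 8) -- per-group world sizes in processes, not nodes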