diff --git a/training/bf16_master_weight/README.md b/training/bf16_master_weight/README.md new file mode 100644 index 000000000..2af73a278 --- /dev/null +++ b/training/bf16_master_weight/README.md @@ -0,0 +1,73 @@ +# BF16 Low-Precision Master Weights and Optimizer States + +This example demonstrates DeepSpeed's [new low-precision training options](https://github.com/deepspeedai/DeepSpeed/pull/7700) that can significantly reduce memory usage: + +- `bf16_master_weights_and_grads`: Keep master parameters and gradients in BF16 instead of FP32 +- `bf16_optimizer_states`: Keep optimizer states (e.g., Adam moments) in BF16 + + +### Running an Example + +The following commands run training for 1000 steps on the Wikitext-103 dataset using both the baseline and BF16 low-precision configurations, then generates a loss comparison plot. +The model has approximately 6.86 billion parameters (hidden=4096, layers=32, heads=32, batch=1, seq=512). +For BF16 low-precision training, we use `torch.autocast`. + + +```bash +# Run 1000 steps with wikitext dataset +deepspeed --num_gpus=4 train.py --deepspeed_config configs/baseline.json \ + --num_layers 32 --hidden_dim 4096 --num_heads 32 --batch_size 1 \ + --num_steps 1000 --activation_checkpointing \ + --loss_log_file logs/baseline_loss.csv --use_real_data --seed 42 + +deepspeed --num_gpus=4 train.py --deepspeed_config configs/bf16_full.json \ + --num_layers 32 --hidden_dim 4096 --num_heads 32 --batch_size 1 \ + --num_steps 1000 --activation_checkpointing \ + --loss_log_file logs/bf16_full_loss.csv --use_real_data --seed 42 + +# Generate comparison plot +python plot_loss.py --baseline logs/baseline_loss.csv --bf16 logs/bf16_full_loss.csv \ + --output loss_comparison.png +``` + +Here is a summary of the memory usage and training time results using 4xH100. +This shows a significant memory reduction: Memory reduction: 9.57 GB allocated (37%), 12.45 GB peak (39.7%)**. + +| Configuration | Allocated Memory | Peak Memory | Avg Step Time | +|---------------|------------------|-------------|---------------| +| Baseline (fp32 master) | 25.74 GB | 31.38 GB | 0.6016s | +| BF16 low-precision (master + opt states) | **16.17 GB** | **18.93 GB** | 0.6427s | + + +## Loss Curve Comparison + +To verify that BF16 low-precision training maintains numerical stability, we trained for 1000 steps on the Wikitext-103 dataset: + +![Loss Comparison](logs/7b_loss_run/loss_comparison.png) + +| Configuration | Final Loss | Mean Loss | Loss Std | +|---------------|------------|-----------|----------| +| Baseline (fp32 master) | 3.09 | 2.78 | 1.56 | +| BF16 Low-Precision | 3.12 | 2.90 | 2.37 | + +The loss curves show that both configurations converge similarly, demonstrating that the reduced precision does not significantly impact training quality while providing substantial memory savings. + +### Memory Breakdown + +For a model with N parameters: + +| Component | Baseline | BF16 Low-Precision | +|-----------|----------|-------------------| +| Model params | 2N bytes (BF16) | 2N bytes (BF16) | +| Master weights | 4N bytes (FP32) | 2N bytes (BF16) | +| Gradients | 4N bytes (FP32) | 2N bytes (BF16) | +| Adam momentum | 4N bytes (FP32) | 2N bytes (BF16) | +| Adam variance | 4N bytes (FP32) | 2N bytes (BF16) | +| **Total** | **18N bytes** | **10N bytes** | + +This gives a theoretical ~44% reduction in optimizer-related memory. The actual savings depend on activation memory and other factors, but our results show a very close match to the theoretical savings. + +## Related Resources + +- [DeepSpeed BF16 Documentation](https://www.deepspeed.ai/docs/config-json/#bf16-training-options) +- [Low-precision master params PR](https://github.com/deepspeedai/DeepSpeed/pull/7700) diff --git a/training/bf16_master_weight/configs/baseline.json b/training/bf16_master_weight/configs/baseline.json new file mode 100644 index 000000000..525c4e00a --- /dev/null +++ b/training/bf16_master_weight/configs/baseline.json @@ -0,0 +1,31 @@ +{ + "train_micro_batch_size_per_gpu": 4, + "gradient_accumulation_steps": 1, + "steps_per_print": 100, + + "optimizer": { + "type": "AdamW", + "params": { + "lr": 1e-4, + "betas": [0.9, 0.999], + "eps": 1e-8, + "weight_decay": 0.01 + } + }, + + "bf16": { + "enabled": true + }, + + "zero_optimization": { + "stage": 3, + "overlap_comm": true, + "contiguous_gradients": true, + "reduce_bucket_size": 5e7, + "stage3_param_persistence_threshold": 0 + }, + + "torch_autocast": { + "enabled": false + } +} diff --git a/training/bf16_master_weight/configs/bf16_full.json b/training/bf16_master_weight/configs/bf16_full.json new file mode 100644 index 000000000..cc854baae --- /dev/null +++ b/training/bf16_master_weight/configs/bf16_full.json @@ -0,0 +1,34 @@ +{ + "train_micro_batch_size_per_gpu": 4, + "gradient_accumulation_steps": 1, + "steps_per_print": 100, + + "optimizer": { + "type": "AdamW", + "params": { + "lr": 1e-4, + "betas": [0.9, 0.999], + "eps": 1e-8, + "weight_decay": 0.01 + } + }, + + "bf16": { + "enabled": true, + "bf16_master_weights_and_grads": true, + "bf16_optimizer_states": true + }, + + "zero_optimization": { + "stage": 3, + "overlap_comm": true, + "contiguous_gradients": true, + "reduce_bucket_size": 5e7, + "stage3_param_persistence_threshold": 0 + }, + + "torch_autocast": { + "enabled": true, + "dtype": "torch.bfloat16" + } +} diff --git a/training/bf16_master_weight/gather_memory.py b/training/bf16_master_weight/gather_memory.py new file mode 100644 index 000000000..6eb9ab813 --- /dev/null +++ b/training/bf16_master_weight/gather_memory.py @@ -0,0 +1,175 @@ +#!/usr/bin/env python +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +""" +Script to gather and compare memory usage from training logs. + +Usage: + python gather_memory.py --log_dir logs/20231201_120000 +""" + +import argparse +import os +import re +from pathlib import Path + + +def parse_summary_line(line): + """Parse the SUMMARY line from log output.""" + pattern = r"SUMMARY: config=(\S+) params=(\d+) peak_mem_bytes=(\d+) alloc_mem_bytes=(\d+) avg_step_time=(\S+)" + match = re.search(pattern, line) + if match: + return { + "config": match.group(1), + "params": int(match.group(2)), + "peak_mem_bytes": int(match.group(3)), + "alloc_mem_bytes": int(match.group(4)), + "avg_step_time": float(match.group(5)), + } + return None + + +def format_bytes(bytes_val): + """Format bytes to human-readable string.""" + gb = bytes_val / (1024 ** 3) + return f"{gb:.2f} GB" + + +def format_bytes_mb(bytes_val): + """Format bytes to MB.""" + mb = bytes_val / (1024 ** 2) + return f"{mb:.1f} MB" + + +def get_config_name(config_path): + """Extract clean config name from path.""" + name = Path(config_path).stem + if name == "baseline": + return "Baseline (fp32 master)" + elif name == "bf16_master_wg": + return "bf16_master_weights_and_grads" + elif name == "bf16_full": + return "bf16_full (master + opt states)" + return name + + +def main(): + parser = argparse.ArgumentParser(description="Gather memory usage from training logs") + parser.add_argument("--log_dir", type=str, required=True, help="Directory containing log files") + parser.add_argument("--output", type=str, default=None, help="Output file for summary") + args = parser.parse_args() + + log_dir = Path(args.log_dir) + if not log_dir.exists(): + print(f"Error: Log directory '{log_dir}' does not exist") + return 1 + + # Find and parse all log files + results = [] + log_files = ["baseline.log", "bf16_full.log"] + + for log_file in log_files: + log_path = log_dir / log_file + if not log_path.exists(): + print(f"Warning: Log file '{log_path}' not found, skipping") + continue + + with open(log_path, "r") as f: + for line in f: + summary = parse_summary_line(line) + if summary: + results.append(summary) + break + + if not results: + print("No results found in log files") + return 1 + + # Calculate baseline for comparison + baseline_peak = None + for r in results: + if "baseline" in r["config"]: + baseline_peak = r["peak_mem_bytes"] + break + + # Generate summary + output_lines = [] + output_lines.append("=" * 80) + output_lines.append("BF16 Low-Precision Master Weights - Memory Usage Comparison") + output_lines.append("=" * 80) + output_lines.append("") + + # Table header + output_lines.append(f"{'Configuration':<40} {'Peak Memory':<15} {'Reduction':<15} {'Step Time':<12}") + output_lines.append("-" * 80) + + for r in results: + config_name = get_config_name(r["config"]) + peak_mem = format_bytes(r["peak_mem_bytes"]) + step_time = f"{r['avg_step_time']:.4f}s" + + if baseline_peak and baseline_peak > 0: + reduction = ((baseline_peak - r["peak_mem_bytes"]) / baseline_peak) * 100 + reduction_str = f"{reduction:+.1f}%" if reduction != 0 else "-" + else: + reduction_str = "-" + + output_lines.append(f"{config_name:<40} {peak_mem:<15} {reduction_str:<15} {step_time:<12}") + + output_lines.append("-" * 80) + output_lines.append("") + + # Detailed breakdown + output_lines.append("Detailed Results:") + output_lines.append("-" * 40) + for r in results: + config_name = get_config_name(r["config"]) + output_lines.append(f"\n{config_name}:") + output_lines.append(f" Parameters: {r['params']:,}") + output_lines.append(f" Peak Memory: {format_bytes(r['peak_mem_bytes'])} ({r['peak_mem_bytes']:,} bytes)") + output_lines.append(f" Allocated Memory: {format_bytes(r['alloc_mem_bytes'])} ({r['alloc_mem_bytes']:,} bytes)") + output_lines.append(f" Avg Step Time: {r['avg_step_time']:.4f}s") + + output_lines.append("") + output_lines.append("=" * 80) + + # Generate markdown table + output_lines.append("") + output_lines.append("Markdown Table (for README):") + output_lines.append("-" * 40) + output_lines.append("") + output_lines.append("| Configuration | Peak Memory | Memory Reduction | Avg Step Time |") + output_lines.append("|---------------|-------------|------------------|---------------|") + + for r in results: + config_name = get_config_name(r["config"]) + peak_mem = format_bytes(r["peak_mem_bytes"]) + step_time = f"{r['avg_step_time']:.4f}s" + + if baseline_peak and baseline_peak > 0: + reduction = ((baseline_peak - r["peak_mem_bytes"]) / baseline_peak) * 100 + reduction_str = f"{reduction:+.1f}%" if reduction != 0 else "-" + else: + reduction_str = "-" + + output_lines.append(f"| {config_name} | {peak_mem} | {reduction_str} | {step_time} |") + + output_lines.append("") + + # Print to stdout + summary_text = "\n".join(output_lines) + print(summary_text) + + # Save to file + output_path = args.output or (log_dir / "summary.txt") + with open(output_path, "w") as f: + f.write(summary_text) + + print(f"\nSummary saved to: {output_path}") + + return 0 + + +if __name__ == "__main__": + exit(main()) diff --git a/training/bf16_master_weight/logs/7b_loss_run/baseline_loss.csv b/training/bf16_master_weight/logs/7b_loss_run/baseline_loss.csv new file mode 100644 index 000000000..0915cc541 --- /dev/null +++ b/training/bf16_master_weight/logs/7b_loss_run/baseline_loss.csv @@ -0,0 +1,1001 @@ +step,loss +0,11.125 +1,7.15625 +2,5.875 +3,3.984375 +4,13.4375 +5,12.1875 +6,3.765625 +7,5.0625 +8,4.15625 +9,0.96484375 +10,8.625 +11,5.28125 +12,5.09375 +13,6.625 +14,4.0 +15,9.375 +16,7.0625 +17,5.28125 +18,4.625 +19,1.7890625 +20,1.453125 +21,2.0625 +22,2.984375 +23,2.234375 +24,5.53125 +25,2.859375 +26,4.3125 +27,5.3125 +28,2.546875 +29,5.15625 +30,2.0 +31,4.1875 +32,3.859375 +33,6.0625 +34,1.8671875 +35,4.96875 +36,1.8203125 +37,6.03125 +38,1.3359375 +39,1.515625 +40,1.1953125 +41,1.421875 +42,3.15625 +43,4.96875 +44,5.5 +45,1.53125 +46,2.0625 +47,3.078125 +48,3.375 +49,2.78125 +50,2.34375 +51,3.203125 +52,2.609375 +53,2.578125 +54,4.25 +55,3.578125 +56,6.21875 +57,3.734375 +58,4.8125 +59,2.859375 +60,3.546875 +61,0.8359375 +62,4.375 +63,4.71875 +64,2.203125 +65,2.015625 +66,3.703125 +67,3.234375 +68,1.296875 +69,1.515625 +70,1.5703125 +71,3.984375 +72,3.671875 +73,5.46875 +74,3.03125 +75,6.59375 +76,1.8828125 +77,2.875 +78,1.3984375 +79,0.984375 +80,4.65625 +81,2.78125 +82,2.9375 +83,2.59375 +84,5.5 +85,0.98046875 +86,1.4453125 +87,2.015625 +88,4.53125 +89,1.0234375 +90,1.9296875 +91,1.4453125 +92,2.5625 +93,2.703125 +94,3.484375 +95,6.1875 +96,2.265625 +97,1.8359375 +98,6.1875 +99,5.15625 +100,1.2734375 +101,2.90625 +102,1.59375 +103,0.54296875 +104,1.3046875 +105,2.171875 +106,2.484375 +107,1.5546875 +108,1.625 +109,0.60546875 +110,2.125 +111,2.046875 +112,5.96875 +113,1.9609375 +114,2.390625 +115,3.921875 +116,2.65625 +117,3.96875 +118,2.109375 +119,3.140625 +120,3.421875 +121,2.3125 +122,3.6875 +123,1.4140625 +124,0.63671875 +125,5.5625 +126,3.40625 +127,6.75 +128,2.40625 +129,2.71875 +130,3.34375 +131,4.90625 +132,6.28125 +133,2.203125 +134,5.34375 +135,1.4609375 +136,1.53125 +137,2.515625 +138,3.15625 +139,4.90625 +140,3.671875 +141,1.734375 +142,2.359375 +143,1.6796875 +144,2.15625 +145,0.6171875 +146,1.078125 +147,2.15625 +148,2.25 +149,3.5 +150,3.171875 +151,3.1875 +152,3.328125 +153,0.8984375 +154,0.87890625 +155,2.65625 +156,4.03125 +157,3.0625 +158,3.015625 +159,2.171875 +160,3.28125 +161,1.8046875 +162,4.125 +163,1.4140625 +164,4.25 +165,4.5 +166,3.359375 +167,3.1875 +168,3.5 +169,2.78125 +170,3.53125 +171,2.953125 +172,0.5546875 +173,2.015625 +174,1.1796875 +175,2.046875 +176,1.8125 +177,1.640625 +178,2.875 +179,4.59375 +180,0.65234375 +181,7.25 +182,3.75 +183,1.46875 +184,2.046875 +185,3.0625 +186,0.77734375 +187,1.5859375 +188,2.1875 +189,0.51953125 +190,4.28125 +191,2.890625 +192,4.46875 +193,1.0 +194,2.578125 +195,1.90625 +196,5.0 +197,1.71875 +198,6.59375 +199,0.6015625 +200,1.2578125 +201,4.25 +202,0.43359375 +203,1.2265625 +204,1.875 +205,2.921875 +206,2.859375 +207,1.8828125 +208,1.09375 +209,7.8125 +210,2.390625 +211,1.8125 +212,1.2578125 +213,0.671875 +214,1.9765625 +215,1.59375 +216,2.796875 +217,2.03125 +218,1.875 +219,3.0625 +220,1.0859375 +221,3.140625 +222,1.265625 +223,0.7265625 +224,0.96875 +225,3.6875 +226,4.03125 +227,4.53125 +228,2.59375 +229,1.9921875 +230,0.82421875 +231,0.7890625 +232,3.078125 +233,3.65625 +234,3.578125 +235,2.265625 +236,1.1484375 +237,2.875 +238,2.5625 +239,0.94921875 +240,2.78125 +241,0.64453125 +242,4.03125 +243,1.2421875 +244,3.171875 +245,1.9765625 +246,0.5859375 +247,2.828125 +248,3.9375 +249,1.140625 +250,2.5 +251,3.90625 +252,5.0625 +253,6.15625 +254,2.609375 +255,2.4375 +256,2.453125 +257,1.8359375 +258,1.890625 +259,2.640625 +260,1.609375 +261,4.84375 +262,3.859375 +263,0.86328125 +264,0.81640625 +265,5.28125 +266,1.9921875 +267,5.8125 +268,0.828125 +269,0.984375 +270,0.90234375 +271,2.890625 +272,6.03125 +273,2.40625 +274,3.453125 +275,3.015625 +276,2.3125 +277,2.984375 +278,3.65625 +279,3.3125 +280,5.8125 +281,0.7421875 +282,5.1875 +283,3.109375 +284,2.34375 +285,0.8203125 +286,1.8671875 +287,1.671875 +288,0.6875 +289,4.09375 +290,2.890625 +291,1.171875 +292,2.953125 +293,3.65625 +294,0.7578125 +295,1.59375 +296,2.171875 +297,2.1875 +298,3.609375 +299,3.34375 +300,3.515625 +301,7.75 +302,3.203125 +303,4.84375 +304,3.125 +305,3.6875 +306,3.9375 +307,0.8046875 +308,0.625 +309,1.609375 +310,0.62890625 +311,2.0625 +312,2.296875 +313,1.0859375 +314,2.671875 +315,2.859375 +316,1.703125 +317,2.71875 +318,2.546875 +319,2.234375 +320,3.53125 +321,3.234375 +322,2.84375 +323,1.015625 +324,3.5625 +325,0.60546875 +326,1.5546875 +327,1.328125 +328,4.46875 +329,3.109375 +330,2.171875 +331,1.8984375 +332,4.34375 +333,2.515625 +334,2.75 +335,2.6875 +336,2.953125 +337,3.828125 +338,2.171875 +339,4.5625 +340,4.0 +341,1.6015625 +342,3.28125 +343,4.65625 +344,2.28125 +345,9.375 +346,5.46875 +347,3.5 +348,1.1796875 +349,2.890625 +350,3.03125 +351,2.640625 +352,1.75 +353,4.96875 +354,2.359375 +355,1.7109375 +356,3.8125 +357,4.65625 +358,2.6875 +359,1.5703125 +360,2.53125 +361,2.5625 +362,2.25 +363,1.8515625 +364,1.9453125 +365,2.71875 +366,0.890625 +367,3.875 +368,3.46875 +369,3.0 +370,0.8828125 +371,3.359375 +372,1.6171875 +373,5.25 +374,3.84375 +375,3.140625 +376,0.98046875 +377,4.875 +378,2.671875 +379,1.0546875 +380,1.3828125 +381,3.078125 +382,1.1640625 +383,3.234375 +384,2.140625 +385,3.671875 +386,3.8125 +387,1.0 +388,4.5 +389,4.09375 +390,6.96875 +391,2.9375 +392,2.140625 +393,5.84375 +394,5.875 +395,3.25 +396,3.453125 +397,1.8125 +398,1.3828125 +399,0.9140625 +400,2.765625 +401,1.40625 +402,0.84375 +403,3.171875 +404,2.359375 +405,0.65625 +406,2.296875 +407,2.390625 +408,2.578125 +409,0.75390625 +410,4.1875 +411,1.3828125 +412,3.640625 +413,4.5625 +414,3.671875 +415,5.96875 +416,4.1875 +417,1.90625 +418,2.046875 +419,4.6875 +420,0.66796875 +421,1.9140625 +422,3.375 +423,4.78125 +424,1.8046875 +425,2.8125 +426,2.296875 +427,1.0390625 +428,1.9921875 +429,1.6484375 +430,0.9140625 +431,4.5 +432,2.390625 +433,2.03125 +434,2.546875 +435,2.109375 +436,5.65625 +437,6.53125 +438,2.28125 +439,4.96875 +440,0.90234375 +441,2.25 +442,3.3125 +443,5.0625 +444,2.46875 +445,3.1875 +446,1.953125 +447,6.46875 +448,4.78125 +449,6.5625 +450,3.015625 +451,1.1640625 +452,1.7109375 +453,4.125 +454,3.15625 +455,5.03125 +456,1.3671875 +457,1.7890625 +458,1.734375 +459,4.125 +460,4.53125 +461,3.578125 +462,0.77734375 +463,1.546875 +464,4.5625 +465,4.40625 +466,1.3984375 +467,2.03125 +468,2.109375 +469,3.5 +470,2.21875 +471,1.4140625 +472,5.09375 +473,3.40625 +474,3.34375 +475,4.40625 +476,1.90625 +477,2.484375 +478,1.84375 +479,1.3359375 +480,2.765625 +481,9.125 +482,3.625 +483,3.828125 +484,1.8046875 +485,4.8125 +486,1.84375 +487,1.1484375 +488,1.3828125 +489,1.5078125 +490,1.84375 +491,2.921875 +492,2.859375 +493,2.234375 +494,4.1875 +495,2.796875 +496,1.2109375 +497,2.5625 +498,1.921875 +499,3.5 +500,0.85546875 +501,3.34375 +502,2.5 +503,2.859375 +504,2.96875 +505,0.91796875 +506,2.796875 +507,1.359375 +508,3.796875 +509,2.390625 +510,5.59375 +511,5.65625 +512,1.703125 +513,1.1796875 +514,2.21875 +515,0.61328125 +516,2.703125 +517,1.921875 +518,1.8828125 +519,2.796875 +520,2.859375 +521,3.328125 +522,1.21875 +523,2.234375 +524,1.421875 +525,1.3984375 +526,1.6875 +527,2.296875 +528,4.53125 +529,2.796875 +530,2.28125 +531,4.40625 +532,2.703125 +533,3.21875 +534,1.4609375 +535,2.578125 +536,1.9375 +537,2.765625 +538,4.0 +539,2.453125 +540,3.8125 +541,3.703125 +542,2.46875 +543,3.625 +544,2.734375 +545,2.46875 +546,2.625 +547,3.109375 +548,0.99609375 +549,0.7109375 +550,3.28125 +551,0.6171875 +552,3.265625 +553,2.734375 +554,2.1875 +555,1.734375 +556,7.65625 +557,5.34375 +558,3.03125 +559,3.671875 +560,2.21875 +561,2.5 +562,2.796875 +563,4.78125 +564,3.609375 +565,0.81640625 +566,0.78515625 +567,4.46875 +568,3.921875 +569,3.375 +570,0.9765625 +571,4.90625 +572,4.03125 +573,3.453125 +574,2.1875 +575,0.59375 +576,0.4921875 +577,1.0 +578,4.28125 +579,2.421875 +580,1.984375 +581,3.34375 +582,2.03125 +583,3.84375 +584,1.7265625 +585,1.3359375 +586,3.0625 +587,2.671875 +588,2.59375 +589,3.171875 +590,2.875 +591,0.61328125 +592,1.328125 +593,0.609375 +594,2.875 +595,1.6640625 +596,2.640625 +597,4.28125 +598,0.7890625 +599,3.484375 +600,1.6328125 +601,3.59375 +602,0.84375 +603,0.65625 +604,2.515625 +605,1.453125 +606,1.2734375 +607,3.578125 +608,5.78125 +609,0.67578125 +610,3.453125 +611,1.4296875 +612,2.09375 +613,2.90625 +614,3.3125 +615,0.65625 +616,2.265625 +617,1.8359375 +618,3.578125 +619,5.375 +620,2.9375 +621,5.6875 +622,1.7265625 +623,1.578125 +624,0.75390625 +625,3.875 +626,3.25 +627,1.953125 +628,3.796875 +629,1.6953125 +630,2.328125 +631,2.34375 +632,3.265625 +633,0.79296875 +634,2.953125 +635,3.140625 +636,2.453125 +637,2.53125 +638,3.171875 +639,2.78125 +640,2.421875 +641,5.625 +642,3.015625 +643,3.203125 +644,0.62109375 +645,0.67578125 +646,2.375 +647,3.8125 +648,4.53125 +649,1.34375 +650,4.5 +651,2.84375 +652,1.46875 +653,3.0 +654,3.5625 +655,3.03125 +656,1.6875 +657,0.94140625 +658,1.3671875 +659,1.109375 +660,0.765625 +661,4.03125 +662,2.546875 +663,3.453125 +664,3.328125 +665,5.25 +666,2.0 +667,2.5625 +668,1.859375 +669,1.984375 +670,2.203125 +671,0.64453125 +672,3.171875 +673,1.8828125 +674,5.5 +675,0.546875 +676,2.0 +677,1.4765625 +678,2.28125 +679,3.0 +680,1.9921875 +681,3.640625 +682,2.609375 +683,1.7890625 +684,3.046875 +685,0.6875 +686,3.34375 +687,0.6875 +688,1.265625 +689,1.75 +690,2.984375 +691,1.5546875 +692,1.09375 +693,3.296875 +694,0.9765625 +695,0.7109375 +696,2.53125 +697,5.78125 +698,3.4375 +699,5.0 +700,3.078125 +701,3.5625 +702,4.78125 +703,2.234375 +704,0.62109375 +705,2.21875 +706,3.03125 +707,3.25 +708,2.421875 +709,4.59375 +710,3.0625 +711,2.609375 +712,2.953125 +713,3.015625 +714,1.5390625 +715,1.1328125 +716,0.83984375 +717,1.9140625 +718,3.296875 +719,3.34375 +720,4.03125 +721,3.28125 +722,1.421875 +723,3.59375 +724,5.09375 +725,2.890625 +726,2.65625 +727,3.234375 +728,4.625 +729,1.4296875 +730,2.234375 +731,3.578125 +732,1.5234375 +733,2.375 +734,1.3359375 +735,0.72265625 +736,3.515625 +737,2.09375 +738,6.8125 +739,1.984375 +740,2.296875 +741,4.125 +742,2.203125 +743,3.71875 +744,2.203125 +745,4.1875 +746,1.5546875 +747,5.46875 +748,1.71875 +749,2.703125 +750,2.765625 +751,4.21875 +752,4.59375 +753,1.6640625 +754,5.5625 +755,1.5859375 +756,0.859375 +757,0.73828125 +758,2.71875 +759,3.90625 +760,2.078125 +761,2.1875 +762,2.078125 +763,1.59375 +764,2.609375 +765,2.53125 +766,4.875 +767,0.5078125 +768,4.03125 +769,3.328125 +770,1.703125 +771,2.734375 +772,3.59375 +773,0.82421875 +774,2.734375 +775,0.69921875 +776,2.21875 +777,3.15625 +778,2.421875 +779,1.734375 +780,2.578125 +781,2.578125 +782,1.0859375 +783,1.65625 +784,2.59375 +785,4.1875 +786,2.515625 +787,2.15625 +788,4.1875 +789,0.984375 +790,0.71875 +791,2.734375 +792,1.03125 +793,1.796875 +794,1.9296875 +795,3.296875 +796,5.09375 +797,4.25 +798,2.75 +799,2.015625 +800,1.65625 +801,2.859375 +802,2.3125 +803,0.8515625 +804,4.5625 +805,3.015625 +806,2.734375 +807,1.5234375 +808,1.328125 +809,1.3828125 +810,2.546875 +811,2.25 +812,1.9453125 +813,1.71875 +814,0.921875 +815,2.625 +816,3.53125 +817,4.125 +818,1.53125 +819,4.0 +820,1.359375 +821,1.7578125 +822,2.796875 +823,4.5625 +824,2.40625 +825,0.71875 +826,1.7265625 +827,0.69921875 +828,1.4609375 +829,3.96875 +830,1.078125 +831,3.296875 +832,2.109375 +833,1.390625 +834,4.28125 +835,3.078125 +836,1.671875 +837,1.5625 +838,1.890625 +839,2.109375 +840,1.5546875 +841,2.4375 +842,1.1015625 +843,2.015625 +844,2.15625 +845,1.609375 +846,3.65625 +847,2.625 +848,2.109375 +849,1.546875 +850,4.6875 +851,3.140625 +852,3.328125 +853,0.78125 +854,2.46875 +855,6.0 +856,4.28125 +857,0.8984375 +858,4.40625 +859,1.3828125 +860,1.53125 +861,0.96875 +862,2.3125 +863,1.640625 +864,1.7734375 +865,4.1875 +866,1.78125 +867,2.984375 +868,2.484375 +869,1.203125 +870,3.984375 +871,2.890625 +872,2.171875 +873,3.234375 +874,3.140625 +875,0.625 +876,1.171875 +877,3.5 +878,1.3125 +879,0.58984375 +880,2.21875 +881,1.3359375 +882,2.46875 +883,2.953125 +884,4.21875 +885,2.78125 +886,2.28125 +887,0.734375 +888,2.421875 +889,2.890625 +890,1.6328125 +891,3.28125 +892,3.234375 +893,4.59375 +894,1.390625 +895,2.359375 +896,1.4921875 +897,2.578125 +898,2.65625 +899,3.734375 +900,3.140625 +901,1.96875 +902,1.7578125 +903,3.03125 +904,0.703125 +905,2.921875 +906,1.734375 +907,1.3046875 +908,2.53125 +909,1.7109375 +910,3.828125 +911,2.828125 +912,1.7734375 +913,1.3515625 +914,2.46875 +915,3.125 +916,0.62109375 +917,5.03125 +918,4.125 +919,3.171875 +920,0.7734375 +921,1.5703125 +922,3.109375 +923,1.9609375 +924,2.671875 +925,2.84375 +926,5.59375 +927,0.73046875 +928,3.03125 +929,2.78125 +930,3.375 +931,2.84375 +932,3.03125 +933,5.34375 +934,1.3984375 +935,1.1484375 +936,2.5 +937,1.734375 +938,2.421875 +939,3.90625 +940,2.515625 +941,3.453125 +942,2.1875 +943,3.375 +944,4.34375 +945,4.03125 +946,1.4765625 +947,2.453125 +948,1.921875 +949,4.6875 +950,1.71875 +951,1.9140625 +952,3.265625 +953,1.6640625 +954,6.4375 +955,2.75 +956,2.171875 +957,5.28125 +958,3.328125 +959,1.6484375 +960,2.546875 +961,2.140625 +962,6.09375 +963,2.09375 +964,1.203125 +965,1.328125 +966,2.734375 +967,4.09375 +968,2.203125 +969,2.296875 +970,6.75 +971,1.0234375 +972,5.46875 +973,6.84375 +974,5.875 +975,2.828125 +976,2.4375 +977,3.03125 +978,5.5 +979,2.484375 +980,3.0625 +981,1.375 +982,3.265625 +983,3.25 +984,2.109375 +985,0.80078125 +986,0.470703125 +987,3.234375 +988,1.7578125 +989,2.109375 +990,0.9296875 +991,4.28125 +992,2.15625 +993,2.0625 +994,1.3984375 +995,1.65625 +996,0.66796875 +997,1.6328125 +998,2.875 +999,3.09375 diff --git a/training/bf16_master_weight/logs/7b_loss_run/bf16_full_loss.csv b/training/bf16_master_weight/logs/7b_loss_run/bf16_full_loss.csv new file mode 100644 index 000000000..362c11199 --- /dev/null +++ b/training/bf16_master_weight/logs/7b_loss_run/bf16_full_loss.csv @@ -0,0 +1,1001 @@ +step,loss +0,11.1134033203125 +1,7.132085800170898 +2,4.012870788574219 +3,36.0167236328125 +4,24.2044677734375 +5,26.455078125 +6,20.6685791015625 +7,27.5283203125 +8,21.50885009765625 +9,4.9122314453125 +10,8.48201847076416 +11,3.808837413787842 +12,8.899911880493164 +13,8.652142524719238 +14,3.5635316371917725 +15,6.893402099609375 +16,6.127050399780273 +17,4.923583984375 +18,4.4448394775390625 +19,2.13494873046875 +20,2.0114288330078125 +21,1.949951171875 +22,2.407745361328125 +23,1.5159912109375 +24,6.731402397155762 +25,3.2044296264648438 +26,4.803050994873047 +27,5.274627685546875 +28,2.7762451171875 +29,5.0152130126953125 +30,2.579620361328125 +31,4.108551025390625 +32,3.868255615234375 +33,6.349365234375 +34,1.6423149108886719 +35,5.202238082885742 +36,1.54669189453125 +37,6.46185302734375 +38,1.2036895751953125 +39,1.492156982421875 +40,1.2420654296875 +41,1.4814453125 +42,3.099506378173828 +43,4.7874755859375 +44,5.309574127197266 +45,1.6868438720703125 +46,2.1453781127929688 +47,3.08148193359375 +48,3.3328895568847656 +49,2.7371826171875 +50,2.27471923828125 +51,3.1996421813964844 +52,2.597991943359375 +53,2.57806396484375 +54,4.214580535888672 +55,3.5824012756347656 +56,6.238430023193359 +57,3.6706466674804688 +58,4.710197448730469 +59,2.9067230224609375 +60,3.5498733520507812 +61,0.910888671875 +62,4.35565185546875 +63,4.719722747802734 +64,2.167003631591797 +65,1.9737205505371094 +66,3.7718124389648438 +67,3.2602691650390625 +68,1.2766494750976562 +69,1.5045166015625 +70,1.5656356811523438 +71,3.9177169799804688 +72,3.622711181640625 +73,5.317466735839844 +74,3.068389892578125 +75,6.45367431640625 +76,1.9294891357421875 +77,2.8223037719726562 +78,1.311737060546875 +79,0.8475265502929688 +80,4.85711669921875 +81,2.817279815673828 +82,2.951894760131836 +83,2.5832557678222656 +84,5.453367233276367 +85,1.0596847534179688 +86,1.5508003234863281 +87,2.1005859375 +88,4.404327392578125 +89,1.083099365234375 +90,1.9316482543945312 +91,1.4068145751953125 +92,2.5340423583984375 +93,2.6919097900390625 +94,3.4820175170898438 +95,6.285247802734375 +96,2.244140625 +97,1.8041229248046875 +98,6.160301208496094 +99,5.0840911865234375 +100,1.2824630737304688 +101,2.894184112548828 +102,1.6238822937011719 +103,0.6098289489746094 +104,1.3268051147460938 +105,2.1666526794433594 +106,2.453876495361328 +107,1.551025390625 +108,1.6232891082763672 +109,0.5829753875732422 +110,2.1284561157226562 +111,2.0172958374023438 +112,6.07049560546875 +113,1.935211181640625 +114,2.33990478515625 +115,3.9076614379882812 +116,2.6385498046875 +117,3.9417266845703125 +118,2.087188720703125 +119,3.135406494140625 +120,3.41455078125 +121,2.2973289489746094 +122,3.6464080810546875 +123,1.4052810668945312 +124,0.600006103515625 +125,5.67963981628418 +126,3.4544477462768555 +127,6.842742919921875 +128,2.3925647735595703 +129,2.6814651489257812 +130,3.316925048828125 +131,4.820037841796875 +132,6.195037841796875 +133,2.1667022705078125 +134,5.387062072753906 +135,1.3815345764160156 +136,1.4716072082519531 +137,2.519887924194336 +138,3.1788883209228516 +139,4.956596374511719 +140,3.654022216796875 +141,1.7361297607421875 +142,2.3483810424804688 +143,1.7004280090332031 +144,2.1595535278320312 +145,0.64605712890625 +146,1.0794601440429688 +147,2.1403236389160156 +148,2.229206085205078 +149,3.4716796875 +150,3.1315536499023438 +151,3.1563339233398438 +152,3.2884521484375 +153,0.9025955200195312 +154,0.8674774169921875 +155,2.6904144287109375 +156,4.048305511474609 +157,3.055718421936035 +158,2.996180534362793 +159,2.1611242294311523 +160,3.25347900390625 +161,1.7946586608886719 +162,4.068458557128906 +163,1.45806884765625 +164,4.2074127197265625 +165,4.471534729003906 +166,3.3179588317871094 +167,3.15625 +168,3.4902725219726562 +169,2.7586212158203125 +170,3.4999923706054688 +171,2.938079833984375 +172,0.5755615234375 +173,2.0198135375976562 +174,1.1787986755371094 +175,2.032512664794922 +176,1.7835540771484375 +177,1.6204299926757812 +178,2.8943557739257812 +179,4.579692840576172 +180,0.6011581420898438 +181,7.233245849609375 +182,3.7090682983398438 +183,1.4494476318359375 +184,2.058147430419922 +185,3.034748077392578 +186,0.79302978515625 +187,1.5850868225097656 +188,2.1826953887939453 +189,0.502166748046875 +190,4.2438507080078125 +191,2.8563690185546875 +192,4.449333190917969 +193,1.0401611328125 +194,2.5916748046875 +195,1.8848648071289062 +196,5.068851470947266 +197,1.7050399780273438 +198,6.735166549682617 +199,0.5814323425292969 +200,1.2588768005371094 +201,4.211669921875 +202,0.4931640625 +203,1.26312255859375 +204,1.9409255981445312 +205,2.9423751831054688 +206,2.848846435546875 +207,1.8715667724609375 +208,1.0339775085449219 +209,7.9362335205078125 +210,2.3682479858398438 +211,1.7879524230957031 +212,1.24407958984375 +213,0.6689834594726562 +214,1.9735183715820312 +215,1.6094589233398438 +216,2.7777099609375 +217,2.0335693359375 +218,1.8849143981933594 +219,3.063079833984375 +220,1.0706634521484375 +221,3.130584716796875 +222,1.2679176330566406 +223,0.7518539428710938 +224,0.9654579162597656 +225,3.7058658599853516 +226,4.020160675048828 +227,4.510618209838867 +228,2.578065872192383 +229,1.9994659423828125 +230,0.9467849731445312 +231,0.8559494018554688 +232,3.0728759765625 +233,3.6545753479003906 +234,3.615123748779297 +235,2.252063751220703 +236,1.118560791015625 +237,2.8727340698242188 +238,2.560455322265625 +239,0.967529296875 +240,2.76373291015625 +241,0.689727783203125 +242,3.9878807067871094 +243,1.2500762939453125 +244,3.176746368408203 +245,1.9695472717285156 +246,0.5463180541992188 +247,2.8335800170898438 +248,3.9300308227539062 +249,1.123687744140625 +250,2.5000457763671875 +251,3.8977737426757812 +252,5.036094665527344 +253,6.100738525390625 +254,2.600006103515625 +255,2.4392242431640625 +256,2.445037841796875 +257,1.8540763854980469 +258,1.8763656616210938 +259,2.6569042205810547 +260,1.5930023193359375 +261,4.9168548583984375 +262,3.9106521606445312 +263,0.8489055633544922 +264,0.8064384460449219 +265,5.282951354980469 +266,1.9851341247558594 +267,5.7763519287109375 +268,0.8550567626953125 +269,1.0300445556640625 +270,0.9309310913085938 +271,2.892791748046875 +272,6.195487976074219 +273,2.441011428833008 +274,3.4674148559570312 +275,3.0429115295410156 +276,2.309520721435547 +277,2.98455810546875 +278,3.6409454345703125 +279,3.296466827392578 +280,5.745349884033203 +281,0.7671279907226562 +282,5.163974761962891 +283,3.103363037109375 +284,2.3465194702148438 +285,0.8010101318359375 +286,1.8536376953125 +287,1.6527175903320312 +288,0.6404037475585938 +289,4.074607849121094 +290,2.8751220703125 +291,1.1809844970703125 +292,2.9337730407714844 +293,3.6357154846191406 +294,0.79071044921875 +295,1.6123733520507812 +296,2.164337158203125 +297,2.1856632232666016 +298,3.6139373779296875 +299,3.333059310913086 +300,3.5030879974365234 +301,7.875446319580078 +302,3.210662841796875 +303,4.817474365234375 +304,3.1644973754882812 +305,3.7071151733398438 +306,3.9317092895507812 +307,0.7196159362792969 +308,0.5052833557128906 +309,1.595855712890625 +310,0.5883102416992188 +311,2.079524040222168 +312,2.30682373046875 +313,1.0888118743896484 +314,2.637971878051758 +315,2.823902130126953 +316,1.80908203125 +317,2.781005859375 +318,2.5951385498046875 +319,2.2735366821289062 +320,3.5009918212890625 +321,3.203327178955078 +322,2.8149452209472656 +323,0.9082870483398438 +324,3.644315719604492 +325,0.5691452026367188 +326,1.5577869415283203 +327,1.3331146240234375 +328,4.3954315185546875 +329,3.05267333984375 +330,2.176136016845703 +331,1.9119300842285156 +332,4.3045654296875 +333,2.5377044677734375 +334,2.7671661376953125 +335,2.6728286743164062 +336,2.9331512451171875 +337,3.832061767578125 +338,2.130687713623047 +339,4.605674743652344 +340,4.019569396972656 +341,1.5959663391113281 +342,3.273557662963867 +343,4.640205383300781 +344,2.2807388305664062 +345,9.244529724121094 +346,5.3750152587890625 +347,3.497802734375 +348,1.22119140625 +349,2.889404296875 +350,3.0145111083984375 +351,2.6391067504882812 +352,1.7414531707763672 +353,5.096355438232422 +354,2.3698997497558594 +355,1.70794677734375 +356,3.8087921142578125 +357,4.649024963378906 +358,2.6874847412109375 +359,1.6042404174804688 +360,2.5404624938964844 +361,2.5606842041015625 +362,2.267913818359375 +363,1.85089111328125 +364,1.9569931030273438 +365,2.70562744140625 +366,0.8981552124023438 +367,3.89678955078125 +368,3.4724559783935547 +369,3.019786834716797 +370,0.8579788208007812 +371,3.39013671875 +372,1.5927085876464844 +373,5.293540954589844 +374,3.840545654296875 +375,3.153797149658203 +376,1.0394439697265625 +377,4.864471435546875 +378,2.69451904296875 +379,1.115753173828125 +380,1.3917999267578125 +381,3.06719970703125 +382,1.1477203369140625 +383,3.240375518798828 +384,2.1412124633789062 +385,3.6806678771972656 +386,3.8046875 +387,0.9766197204589844 +388,4.5174560546875 +389,4.07159423828125 +390,6.927947998046875 +391,2.9377593994140625 +392,2.157146453857422 +393,5.8182525634765625 +394,5.853107452392578 +395,3.243896484375 +396,3.44207763671875 +397,1.8158073425292969 +398,1.38507080078125 +399,0.9081001281738281 +400,2.771198272705078 +401,1.39654541015625 +402,0.8424606323242188 +403,3.1727371215820312 +404,2.365640640258789 +405,0.6451492309570312 +406,2.2880172729492188 +407,2.3833656311035156 +408,2.5667648315429688 +409,0.7316741943359375 +410,4.179656982421875 +411,1.382537841796875 +412,3.6302032470703125 +413,4.5171966552734375 +414,3.6554718017578125 +415,5.948081970214844 +416,4.1723175048828125 +417,1.8863525390625 +418,2.03424072265625 +419,4.659786224365234 +420,0.68743896484375 +421,1.9088134765625 +422,3.3668212890625 +423,4.7425079345703125 +424,1.8061065673828125 +425,2.8161659240722656 +426,2.3002777099609375 +427,1.03265380859375 +428,1.9893379211425781 +429,1.6477088928222656 +430,0.9208526611328125 +431,4.5027618408203125 +432,2.3863067626953125 +433,2.038341522216797 +434,2.5567398071289062 +435,2.1070709228515625 +436,5.679634094238281 +437,6.574100494384766 +438,2.2461776733398438 +439,4.967735290527344 +440,0.93743896484375 +441,2.2551002502441406 +442,3.2996063232421875 +443,5.007568359375 +444,2.4714508056640625 +445,3.1668853759765625 +446,1.9462738037109375 +447,6.492755889892578 +448,4.795219421386719 +449,6.624614715576172 +450,3.016498565673828 +451,1.133056640625 +452,1.70758056640625 +453,4.1140289306640625 +454,3.1417903900146484 +455,4.935401916503906 +456,1.3822021484375 +457,1.7961349487304688 +458,1.7515106201171875 +459,4.09033203125 +460,4.515331268310547 +461,3.5706787109375 +462,0.7114334106445312 +463,1.5092239379882812 +464,4.604736328125 +465,4.4217529296875 +466,1.3947181701660156 +467,2.0403671264648438 +468,2.1144866943359375 +469,3.482555389404297 +470,2.2158432006835938 +471,1.4167633056640625 +472,5.098869323730469 +473,3.3986968994140625 +474,3.340484619140625 +475,4.4418487548828125 +476,1.8847389221191406 +477,2.4641494750976562 +478,1.8467636108398438 +479,1.3627281188964844 +480,2.7605056762695312 +481,8.9556884765625 +482,3.627307891845703 +483,3.8278656005859375 +484,1.8056793212890625 +485,4.8571014404296875 +486,1.7985496520996094 +487,1.1248779296875 +488,1.3766860961914062 +489,1.5183334350585938 +490,1.8641853332519531 +491,2.907207489013672 +492,2.8539066314697266 +493,2.224214553833008 +494,4.172332763671875 +495,2.781158447265625 +496,1.2123565673828125 +497,2.5675811767578125 +498,1.9204216003417969 +499,3.5189361572265625 +500,0.8372611999511719 +501,3.3610305786132812 +502,2.4884376525878906 +503,2.8520660400390625 +504,2.9636001586914062 +505,0.9283828735351562 +506,2.7838211059570312 +507,1.3746414184570312 +508,3.7743759155273438 +509,2.39794921875 +510,5.595367431640625 +511,5.6392364501953125 +512,1.721435546875 +513,1.1995849609375 +514,2.213287353515625 +515,0.5738868713378906 +516,2.71051025390625 +517,1.9071578979492188 +518,1.8789024353027344 +519,2.816082000732422 +520,2.8899879455566406 +521,3.3472518920898438 +522,1.2237567901611328 +523,2.2347335815429688 +524,1.4557609558105469 +525,1.4267578125 +526,1.7025146484375 +527,2.2992019653320312 +528,4.513965606689453 +529,2.7763900756835938 +530,2.2656707763671875 +531,4.434913635253906 +532,2.70062255859375 +533,3.222929000854492 +534,1.4761962890625 +535,2.5851573944091797 +536,1.9381179809570312 +537,2.7728805541992188 +538,4.004940032958984 +539,2.4510154724121094 +540,3.8232421875 +541,3.688995361328125 +542,2.46881103515625 +543,3.64715576171875 +544,2.7339134216308594 +545,2.4784927368164062 +546,2.6420669555664062 +547,3.0912399291992188 +548,1.0083999633789062 +549,0.7198677062988281 +550,3.266571044921875 +551,0.6218490600585938 +552,3.2571983337402344 +553,2.7480716705322266 +554,2.1891021728515625 +555,1.7352447509765625 +556,7.654598236083984 +557,5.308502197265625 +558,3.0606536865234375 +559,3.6773757934570312 +560,2.2095947265625 +561,2.490966796875 +562,2.7908401489257812 +563,4.944637298583984 +564,3.6980762481689453 +565,0.799713134765625 +566,0.7726097106933594 +567,4.5063323974609375 +568,3.887918472290039 +569,3.3334732055664062 +570,1.0847396850585938 +571,4.8308258056640625 +572,4.020942687988281 +573,3.4620361328125 +574,2.176300048828125 +575,0.508270263671875 +576,0.40605926513671875 +577,0.9608879089355469 +578,4.364368438720703 +579,2.4404983520507812 +580,1.984304428100586 +581,3.3433990478515625 +582,2.0698928833007812 +583,3.8393402099609375 +584,1.7667045593261719 +585,1.3889541625976562 +586,3.0649337768554688 +587,2.6718292236328125 +588,2.5797119140625 +589,3.1608619689941406 +590,2.879131317138672 +591,0.5877914428710938 +592,1.3176078796386719 +593,0.5999565124511719 +594,2.8893566131591797 +595,1.660898208618164 +596,2.6410675048828125 +597,4.2732391357421875 +598,0.851531982421875 +599,3.4852752685546875 +600,1.6572265625 +601,3.6005859375 +602,0.8543777465820312 +603,0.6684913635253906 +604,2.5104446411132812 +605,1.4377632141113281 +606,1.2602081298828125 +607,3.561370849609375 +608,5.76171875 +609,0.68505859375 +610,3.462726593017578 +611,1.422576904296875 +612,2.0976028442382812 +613,2.9023895263671875 +614,3.3157958984375 +615,0.6492080688476562 +616,2.2681961059570312 +617,1.8462715148925781 +618,3.5692367553710938 +619,5.419681549072266 +620,2.9370269775390625 +621,5.6828460693359375 +622,1.737152099609375 +623,1.602386474609375 +624,0.746856689453125 +625,3.9183425903320312 +626,3.311798095703125 +627,1.9642677307128906 +628,3.844409942626953 +629,1.7015380859375 +630,2.332853317260742 +631,2.3421859741210938 +632,3.2565040588378906 +633,0.8407440185546875 +634,2.966064453125 +635,3.1615371704101562 +636,2.4944839477539062 +637,2.5630569458007812 +638,3.1669464111328125 +639,2.7607421875 +640,2.4143142700195312 +641,5.789052963256836 +642,3.057220458984375 +643,3.2291336059570312 +644,0.5886573791503906 +645,0.6615276336669922 +646,2.36865234375 +647,3.7959442138671875 +648,4.51214599609375 +649,1.4258270263671875 +650,4.4889984130859375 +651,2.864990234375 +652,1.502471923828125 +653,3.0164642333984375 +654,3.56158447265625 +655,3.042694091796875 +656,1.6647262573242188 +657,0.9025287628173828 +658,1.3644561767578125 +659,1.0993881225585938 +660,0.7659111022949219 +661,4.01934814453125 +662,2.541027069091797 +663,3.4581985473632812 +664,3.3175582885742188 +665,5.193931579589844 +666,2.0252456665039062 +667,2.5653457641601562 +668,1.87408447265625 +669,1.9904556274414062 +670,2.2085418701171875 +671,0.61199951171875 +672,3.2067089080810547 +673,1.8816394805908203 +674,5.542570114135742 +675,0.5315284729003906 +676,2.0016250610351562 +677,1.4842510223388672 +678,2.2986984252929688 +679,3.00762939453125 +680,1.9947052001953125 +681,3.6446876525878906 +682,2.6217994689941406 +683,1.7988967895507812 +684,3.031036376953125 +685,0.7153358459472656 +686,3.334339141845703 +687,0.7108306884765625 +688,1.2816200256347656 +689,1.7671051025390625 +690,2.9856719970703125 +691,1.5478477478027344 +692,1.0807342529296875 +693,3.3064193725585938 +694,0.9369049072265625 +695,0.66357421875 +696,2.5416908264160156 +697,5.779876708984375 +698,3.4454498291015625 +699,5.008609771728516 +700,3.0802230834960938 +701,3.5379295349121094 +702,4.722465515136719 +703,2.2318572998046875 +704,0.6578369140625 +705,2.223125457763672 +706,3.0366439819335938 +707,3.2618484497070312 +708,2.4131393432617188 +709,4.5950469970703125 +710,3.08062744140625 +711,2.605419158935547 +712,2.960582733154297 +713,2.996685028076172 +714,1.509735107421875 +715,1.1052627563476562 +716,0.8219146728515625 +717,1.942230224609375 +718,3.301952362060547 +719,3.3530731201171875 +720,4.0272674560546875 +721,3.260211944580078 +722,1.4404525756835938 +723,3.5693702697753906 +724,5.06439208984375 +725,2.8951263427734375 +726,2.647674560546875 +727,3.245880126953125 +728,4.684795379638672 +729,1.3997955322265625 +730,2.2204971313476562 +731,3.5927658081054688 +732,1.5444793701171875 +733,2.3971214294433594 +734,1.34686279296875 +735,0.7172088623046875 +736,3.4994773864746094 +737,2.089111328125 +738,6.775016784667969 +739,1.99200439453125 +740,2.3129653930664062 +741,4.1109619140625 +742,2.208709716796875 +743,3.71685791015625 +744,2.1857223510742188 +745,4.196357727050781 +746,1.5334548950195312 +747,5.4618072509765625 +748,1.7282867431640625 +749,2.693511962890625 +750,2.766704559326172 +751,4.1958160400390625 +752,4.5892486572265625 +753,1.6612091064453125 +754,5.5735626220703125 +755,1.5728797912597656 +756,0.8423728942871094 +757,0.7233734130859375 +758,2.7063446044921875 +759,3.910602569580078 +760,2.0574302673339844 +761,2.21533203125 +762,2.0984954833984375 +763,1.6245613098144531 +764,2.629138946533203 +765,2.523406982421875 +766,4.832759857177734 +767,0.5534858703613281 +768,4.025787353515625 +769,3.3492164611816406 +770,1.6903953552246094 +771,2.7217559814453125 +772,3.6035385131835938 +773,0.7267303466796875 +774,2.7258644104003906 +775,0.664794921875 +776,2.2199172973632812 +777,3.159778594970703 +778,2.4090499877929688 +779,1.7539825439453125 +780,2.5580062866210938 +781,2.5738601684570312 +782,1.1121826171875 +783,1.669921875 +784,2.601806640625 +785,4.244083404541016 +786,2.5022811889648438 +787,2.133392333984375 +788,4.2078857421875 +789,0.935394287109375 +790,0.6654891967773438 +791,2.7285995483398438 +792,1.06890869140625 +793,1.8234825134277344 +794,1.9538841247558594 +795,3.265777587890625 +796,5.0235137939453125 +797,4.219125747680664 +798,2.740842819213867 +799,2.019315719604492 +800,1.6578369140625 +801,2.8594970703125 +802,2.3069114685058594 +803,0.8043365478515625 +804,4.569328308105469 +805,3.020465850830078 +806,2.734027862548828 +807,1.5362548828125 +808,1.3424072265625 +809,1.3950881958007812 +810,2.55596923828125 +811,2.247222900390625 +812,1.9441699981689453 +813,1.72515869140625 +814,0.92816162109375 +815,2.6270751953125 +816,3.5499114990234375 +817,4.100822448730469 +818,1.6193313598632812 +819,4.028602600097656 +820,1.40875244140625 +821,1.700958251953125 +822,2.763885498046875 +823,4.5920257568359375 +824,2.3800392150878906 +825,0.6684722900390625 +826,1.7244644165039062 +827,0.7157058715820312 +828,1.4762954711914062 +829,3.9496231079101562 +830,1.1088714599609375 +831,3.2770729064941406 +832,2.113544464111328 +833,1.4053840637207031 +834,4.242767333984375 +835,3.085693359375 +836,1.6646347045898438 +837,1.5428619384765625 +838,1.8729248046875 +839,2.1006317138671875 +840,1.5426254272460938 +841,2.4358062744140625 +842,1.105804443359375 +843,2.0183258056640625 +844,2.1589813232421875 +845,1.6298904418945312 +846,3.618694305419922 +847,2.62042236328125 +848,2.1177101135253906 +849,1.5579643249511719 +850,4.641153335571289 +851,3.1571388244628906 +852,3.3424758911132812 +853,0.7713623046875 +854,2.4743270874023438 +855,6.051445007324219 +856,4.31634521484375 +857,0.8591461181640625 +858,4.4473419189453125 +859,1.3624038696289062 +860,1.5254631042480469 +861,0.9824981689453125 +862,2.3231887817382812 +863,1.666168212890625 +864,1.8040504455566406 +865,4.148765563964844 +866,1.8070945739746094 +867,2.983673095703125 +868,2.4856948852539062 +869,1.2143898010253906 +870,4.0149078369140625 +871,2.8915557861328125 +872,2.1803131103515625 +873,3.2860031127929688 +874,3.1650390625 +875,0.5963287353515625 +876,1.1540031433105469 +877,3.526172637939453 +878,1.3190460205078125 +879,0.6318016052246094 +880,2.2330856323242188 +881,1.3703231811523438 +882,2.4706382751464844 +883,2.9503707885742188 +884,4.2059326171875 +885,2.784942626953125 +886,2.27197265625 +887,0.7115211486816406 +888,2.415630340576172 +889,2.8975830078125 +890,1.611419677734375 +891,3.3005638122558594 +892,3.244609832763672 +893,4.608329772949219 +894,1.3816642761230469 +895,2.3633384704589844 +896,1.5139007568359375 +897,2.5995101928710938 +898,2.6572418212890625 +899,3.7224998474121094 +900,3.1385421752929688 +901,1.982940673828125 +902,1.7619056701660156 +903,3.04119873046875 +904,0.682769775390625 +905,2.91143798828125 +906,1.73614501953125 +907,1.3114356994628906 +908,2.529857635498047 +909,1.6997261047363281 +910,3.850330352783203 +911,2.8565673828125 +912,1.7781143188476562 +913,1.3589935302734375 +914,2.483814239501953 +915,3.1075439453125 +916,0.6572761535644531 +917,4.9974822998046875 +918,4.10919189453125 +919,3.1724166870117188 +920,0.80682373046875 +921,1.5555534362792969 +922,3.0896377563476562 +923,1.930419921875 +924,2.658203125 +925,2.85784912109375 +926,5.639961242675781 +927,0.7335128784179688 +928,3.0355072021484375 +929,2.7886390686035156 +930,3.3843994140625 +931,2.8586044311523438 +932,3.031139373779297 +933,5.3276824951171875 +934,1.4003448486328125 +935,1.14642333984375 +936,2.52325439453125 +937,1.7171630859375 +938,2.42291259765625 +939,3.911651611328125 +940,2.5111732482910156 +941,3.4689865112304688 +942,2.1542510986328125 +943,3.3801727294921875 +944,4.362659454345703 +945,4.00445556640625 +946,1.4822425842285156 +947,2.4615707397460938 +948,1.9614181518554688 +949,4.649784088134766 +950,1.73931884765625 +951,1.9136409759521484 +952,3.306222915649414 +953,1.6706218719482422 +954,6.530183792114258 +955,2.7719345092773438 +956,2.188152313232422 +957,5.286003112792969 +958,3.335968017578125 +959,1.7009239196777344 +960,2.5676422119140625 +961,2.1608810424804688 +962,6.158561706542969 +963,2.0723800659179688 +964,1.153411865234375 +965,1.3055076599121094 +966,2.72601318359375 +967,4.1406707763671875 +968,2.2024497985839844 +969,2.3154220581054688 +970,6.7754364013671875 +971,1.038095474243164 +972,5.451904296875 +973,6.805778503417969 +974,5.848621368408203 +975,2.8549118041992188 +976,2.46435546875 +977,3.0417861938476562 +978,5.574432373046875 +979,2.4814796447753906 +980,3.07537841796875 +981,1.3266258239746094 +982,3.3192214965820312 +983,3.3089828491210938 +984,2.109272003173828 +985,0.8027496337890625 +986,0.4718017578125 +987,3.2317256927490234 +988,1.7699203491210938 +989,2.112060546875 +990,0.9800643920898438 +991,4.264732360839844 +992,2.1623687744140625 +993,2.0734634399414062 +994,1.4274139404296875 +995,1.6518898010253906 +996,0.6370849609375 +997,1.6137313842773438 +998,2.902008056640625 +999,3.12066650390625 diff --git a/training/bf16_master_weight/logs/7b_loss_run/loss_comparison.png b/training/bf16_master_weight/logs/7b_loss_run/loss_comparison.png new file mode 100644 index 000000000..9394b1241 Binary files /dev/null and b/training/bf16_master_weight/logs/7b_loss_run/loss_comparison.png differ diff --git a/training/bf16_master_weight/plot_loss.py b/training/bf16_master_weight/plot_loss.py new file mode 100644 index 000000000..2ec0ec850 --- /dev/null +++ b/training/bf16_master_weight/plot_loss.py @@ -0,0 +1,93 @@ +#!/usr/bin/env python +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +""" +Plot loss curves from training logs. + +Usage: + python plot_loss.py --baseline logs/baseline_loss.csv --bf16 logs/bf16_loss.csv --output loss_comparison.png +""" + +import argparse +import csv +import matplotlib.pyplot as plt +import numpy as np + + +def load_loss_data(filepath): + """Load loss data from CSV file.""" + steps = [] + losses = [] + with open(filepath, 'r') as f: + reader = csv.DictReader(f) + for row in reader: + steps.append(int(row['step'])) + losses.append(float(row['loss'])) + return np.array(steps), np.array(losses) + + +def smooth_curve(values, window=10): + """Apply moving average smoothing.""" + if len(values) < window: + return values + smoothed = np.convolve(values, np.ones(window)/window, mode='valid') + # Pad the beginning to maintain length + padding = np.full(window-1, smoothed[0]) + return np.concatenate([padding, smoothed]) + + +def main(): + parser = argparse.ArgumentParser(description="Plot loss curves comparison") + parser.add_argument("--baseline", type=str, required=True, help="Baseline loss CSV file") + parser.add_argument("--bf16", type=str, required=True, help="BF16 low-precision loss CSV file") + parser.add_argument("--output", type=str, default="loss_comparison.png", help="Output plot file") + parser.add_argument("--smooth", type=int, default=20, help="Smoothing window size") + parser.add_argument("--title", type=str, default="Loss Comparison: Baseline vs BF16 Low-Precision", help="Plot title") + args = parser.parse_args() + + # Load data + baseline_steps, baseline_losses = load_loss_data(args.baseline) + bf16_steps, bf16_losses = load_loss_data(args.bf16) + + # Apply smoothing + baseline_smooth = smooth_curve(baseline_losses, args.smooth) + bf16_smooth = smooth_curve(bf16_losses, args.smooth) + + # Create plot + fig, ax = plt.subplots(figsize=(12, 6)) + + # Plot raw data with low alpha + ax.plot(baseline_steps, baseline_losses, alpha=0.2, color='blue') + ax.plot(bf16_steps, bf16_losses, alpha=0.2, color='orange') + + # Plot smoothed curves + ax.plot(baseline_steps, baseline_smooth, label='Baseline (fp32 master)', color='blue', linewidth=2) + ax.plot(bf16_steps, bf16_smooth, label='BF16 Low-Precision', color='orange', linewidth=2) + + ax.set_xlabel('Step', fontsize=12) + ax.set_ylabel('Loss', fontsize=12) + ax.set_title(args.title, fontsize=14) + ax.legend(fontsize=11) + ax.grid(True, alpha=0.3) + + # Add final loss values as text + final_baseline = baseline_losses[-1] + final_bf16 = bf16_losses[-1] + textstr = f'Final Loss:\n Baseline: {final_baseline:.4f}\n BF16: {final_bf16:.4f}' + props = dict(boxstyle='round', facecolor='wheat', alpha=0.5) + ax.text(0.98, 0.98, textstr, transform=ax.transAxes, fontsize=10, + verticalalignment='top', horizontalalignment='right', bbox=props) + + plt.tight_layout() + plt.savefig(args.output, dpi=150) + print(f"Plot saved to: {args.output}") + + # Print statistics + print(f"\nStatistics:") + print(f" Baseline - Final loss: {final_baseline:.4f}, Mean: {baseline_losses.mean():.4f}, Std: {baseline_losses.std():.4f}") + print(f" BF16 - Final loss: {final_bf16:.4f}, Mean: {bf16_losses.mean():.4f}, Std: {bf16_losses.std():.4f}") + + +if __name__ == "__main__": + main() diff --git a/training/bf16_master_weight/run_comparison.sh b/training/bf16_master_weight/run_comparison.sh new file mode 100755 index 000000000..760e507d0 --- /dev/null +++ b/training/bf16_master_weight/run_comparison.sh @@ -0,0 +1,112 @@ +#!/bin/bash +# Run memory comparison across different bf16 configurations +# +# Usage: +# ./run_comparison.sh [--num_gpus N] [--num_layers L] [--hidden_dim H] +# +# This script runs training with two configurations: +# 1. baseline - Standard bf16 with fp32 master weights/grads/optimizer states +# 2. bf16_full - bf16 master weights, gradients, and optimizer states + +set -e + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +cd "$SCRIPT_DIR" + +# Default settings +NUM_GPUS=1 +NUM_LAYERS=12 +HIDDEN_DIM=1024 +BATCH_SIZE=4 +SEQ_LENGTH=512 +NUM_STEPS=20 +WARMUP_STEPS=5 +MASTER_PORT=29600 + +# Parse arguments +while [[ $# -gt 0 ]]; do + case $1 in + --num_gpus) + NUM_GPUS="$2" + shift 2 + ;; + --num_layers) + NUM_LAYERS="$2" + shift 2 + ;; + --hidden_dim) + HIDDEN_DIM="$2" + shift 2 + ;; + --batch_size) + BATCH_SIZE="$2" + shift 2 + ;; + --seq_length) + SEQ_LENGTH="$2" + shift 2 + ;; + --num_steps) + NUM_STEPS="$2" + shift 2 + ;; + --master_port) + MASTER_PORT="$2" + shift 2 + ;; + *) + echo "Unknown option: $1" + exit 1 + ;; + esac +done + +# Create logs directory +TIMESTAMP=$(date +%Y%m%d_%H%M%S) +LOG_DIR="logs/${TIMESTAMP}" +mkdir -p "$LOG_DIR" + +echo "==============================================" +echo "BF16 Low-Precision Master Weights Memory Test" +echo "==============================================" +echo "Configuration:" +echo " NUM_GPUS: $NUM_GPUS" +echo " NUM_LAYERS: $NUM_LAYERS" +echo " HIDDEN_DIM: $HIDDEN_DIM" +echo " BATCH_SIZE: $BATCH_SIZE" +echo " SEQ_LENGTH: $SEQ_LENGTH" +echo " NUM_STEPS: $NUM_STEPS" +echo " LOG_DIR: $LOG_DIR" +echo "==============================================" + +COMMON_ARGS="--num_layers $NUM_LAYERS --hidden_dim $HIDDEN_DIM --batch_size $BATCH_SIZE --seq_length $SEQ_LENGTH --num_steps $NUM_STEPS --warmup_steps $WARMUP_STEPS" + +# Run baseline configuration +echo "" +echo "[1/2] Running BASELINE (bf16 with fp32 master weights/grads/optimizer states)..." +echo "----------------------------------------------" +deepspeed --num_gpus=$NUM_GPUS --master_port=$MASTER_PORT train.py \ + --deepspeed_config configs/baseline.json \ + $COMMON_ARGS \ + 2>&1 | tee "$LOG_DIR/baseline.log" + +# Run bf16_full configuration +echo "" +echo "[2/2] Running BF16_FULL (bf16 master weights, gradients, and optimizer states)..." +echo "----------------------------------------------" +deepspeed --num_gpus=$NUM_GPUS --master_port=$MASTER_PORT train.py \ + --deepspeed_config configs/bf16_full.json \ + $COMMON_ARGS \ + 2>&1 | tee "$LOG_DIR/bf16_full.log" + +echo "" +echo "==============================================" +echo "All runs complete. Gathering results..." +echo "==============================================" + +# Run the gather script to create summary +python gather_memory.py --log_dir "$LOG_DIR" + +echo "" +echo "Results saved to: $LOG_DIR/summary.txt" +echo "==============================================" diff --git a/training/bf16_master_weight/train.py b/training/bf16_master_weight/train.py new file mode 100644 index 000000000..230c95367 --- /dev/null +++ b/training/bf16_master_weight/train.py @@ -0,0 +1,358 @@ +#!/usr/bin/env python +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +""" +Training script demonstrating bf16_master_weights_and_grads and bf16_optimizer_states +options in DeepSpeed for reduced memory usage. + +Usage: + deepspeed --num_gpus=1 train.py --deepspeed_config configs/baseline.json + deepspeed --num_gpus=1 train.py --deepspeed_config configs/bf16_master_wg.json + deepspeed --num_gpus=1 train.py --deepspeed_config configs/bf16_full.json +""" + +import argparse +import time +import torch +import torch.nn as nn +from torch.utils.checkpoint import checkpoint +import deepspeed +from deepspeed.accelerator import get_accelerator +from torch.utils.data import DataLoader +from torch.utils.data.distributed import DistributedSampler + + +class SimpleTransformerBlock(nn.Module): + """A simple transformer block for demonstration.""" + + def __init__(self, hidden_dim, num_heads=8, ff_dim=None): + super().__init__() + ff_dim = ff_dim or hidden_dim * 4 + + self.attention = nn.MultiheadAttention(hidden_dim, num_heads, batch_first=True) + self.norm1 = nn.LayerNorm(hidden_dim) + self.norm2 = nn.LayerNorm(hidden_dim) + + self.ff = nn.Sequential( + nn.Linear(hidden_dim, ff_dim), + nn.GELU(), + nn.Linear(ff_dim, hidden_dim), + ) + + def forward(self, x): + # Self-attention with residual + attn_out, _ = self.attention(x, x, x) + x = self.norm1(x + attn_out) + + # Feed-forward with residual + ff_out = self.ff(x) + x = self.norm2(x + ff_out) + + return x + + +class SimpleTransformerModel(nn.Module): + """A simple transformer model for memory benchmarking.""" + + def __init__(self, vocab_size=50000, hidden_dim=1024, num_layers=12, num_heads=16, max_seq_len=512): + super().__init__() + self.hidden_dim = hidden_dim + self.activation_checkpointing = False + + self.embedding = nn.Embedding(vocab_size, hidden_dim) + self.pos_embedding = nn.Embedding(max_seq_len, hidden_dim) + + self.layers = nn.ModuleList([ + SimpleTransformerBlock(hidden_dim, num_heads) for _ in range(num_layers) + ]) + + self.output_proj = nn.Linear(hidden_dim, vocab_size) + + def enable_activation_checkpointing(self): + self.activation_checkpointing = True + + def forward(self, input_ids): + batch_size, seq_len = input_ids.shape + positions = torch.arange(seq_len, device=input_ids.device).unsqueeze(0).expand(batch_size, -1) + + x = self.embedding(input_ids) + self.pos_embedding(positions) + + for layer in self.layers: + if self.activation_checkpointing: + x = checkpoint(layer, x, use_reentrant=False) + else: + x = layer(x) + + logits = self.output_proj(x) + return logits + + +def get_args(): + parser = argparse.ArgumentParser(description="BF16 Low-Precision Master Weights Demo") + + # Model configuration + parser.add_argument("--hidden_dim", type=int, default=1024, help="Hidden dimension size") + parser.add_argument("--num_layers", type=int, default=12, help="Number of transformer layers") + parser.add_argument("--num_heads", type=int, default=16, help="Number of attention heads") + parser.add_argument("--vocab_size", type=int, default=50000, help="Vocabulary size") + + # Training configuration + parser.add_argument("--batch_size", type=int, default=4, help="Batch size per GPU") + parser.add_argument("--seq_length", type=int, default=512, help="Sequence length") + parser.add_argument("--num_steps", type=int, default=20, help="Number of training steps") + parser.add_argument("--warmup_steps", type=int, default=5, help="Warmup steps before measuring") + parser.add_argument("--learning_rate", type=float, default=1e-4, help="Learning rate") + + # DeepSpeed configuration + parser.add_argument("--deepspeed_config", type=str, required=True, help="Path to DeepSpeed config") + parser.add_argument("--local_rank", type=int, default=-1, help="Local rank for distributed training") + + # Logging + parser.add_argument("--log_interval", type=int, default=5, help="Log interval") + + # Activation checkpointing + parser.add_argument("--activation_checkpointing", action="store_true", help="Enable activation checkpointing") + + # Loss logging + parser.add_argument("--loss_log_file", type=str, default=None, help="File to save loss values for plotting") + + # Dataset + parser.add_argument("--use_real_data", action="store_true", help="Use wikitext dataset instead of random data") + parser.add_argument("--seed", type=int, default=42, help="Random seed for reproducibility") + + return parser.parse_args() + + +def load_wikitext_data(tokenizer_name, seq_length, batch_size, world_size, rank): + """Load wikitext dataset for training.""" + from datasets import load_dataset + from transformers import AutoTokenizer + + print(f"[Rank {rank}] Loading wikitext dataset...") + + # Load tokenizer + tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + + # Load dataset + dataset = load_dataset('wikitext', 'wikitext-103-raw-v1', split='train[:5%]') + + # Filter empty texts + dataset = dataset.filter(lambda x: len(x['text'].strip()) > 50) + + # Tokenize + def tokenize_function(examples): + return tokenizer( + examples['text'], + padding='max_length', + max_length=seq_length, + truncation=True, + return_tensors=None + ) + + tokenized_dataset = dataset.map(tokenize_function, batched=True, remove_columns=['text']) + tokenized_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask']) + + # Create distributed sampler and dataloader + sampler = DistributedSampler(tokenized_dataset, num_replicas=world_size, rank=rank, shuffle=True) + dataloader = DataLoader(tokenized_dataset, batch_size=batch_size, sampler=sampler, num_workers=2) + + print(f"[Rank {rank}] Dataset loaded: {len(tokenized_dataset)} examples, vocab_size={tokenizer.vocab_size}") + return dataloader, tokenizer.vocab_size + + +def count_parameters(model): + """Count total and trainable parameters.""" + total = sum(p.numel() for p in model.parameters()) + trainable = sum(p.numel() for p in model.parameters() if p.requires_grad) + return total, trainable + + +def get_memory_stats(): + """Get current GPU memory statistics.""" + if not torch.cuda.is_available(): + return {"allocated": 0, "reserved": 0, "peak": 0} + + return { + "allocated": torch.cuda.memory_allocated(), + "reserved": torch.cuda.memory_reserved(), + "peak": torch.cuda.max_memory_allocated(), + } + + +def format_memory(bytes_val): + """Format bytes to human readable string.""" + gb = bytes_val / (1024 ** 3) + return f"{gb:.2f} GB" + + +def main(): + args = get_args() + + # Initialize distributed + deepspeed.init_distributed() + local_rank = args.local_rank + if local_rank == -1: + local_rank = int(deepspeed.comm.get_rank()) + + world_size = deepspeed.comm.get_world_size() + + device = get_accelerator().device_name(local_rank) + torch.cuda.set_device(local_rank) + + # Set seed for reproducibility + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + # Reset memory stats + torch.cuda.reset_peak_memory_stats() + + # Load real data if requested + dataloader = None + actual_vocab_size = args.vocab_size + if args.use_real_data: + dataloader, actual_vocab_size = load_wikitext_data( + tokenizer_name="gpt2", + seq_length=args.seq_length, + batch_size=args.batch_size, + world_size=world_size, + rank=local_rank + ) + + # Create model + print(f"[Rank {local_rank}] Creating model with hidden_dim={args.hidden_dim}, " + f"num_layers={args.num_layers}, num_heads={args.num_heads}, vocab_size={actual_vocab_size}") + + model = SimpleTransformerModel( + vocab_size=actual_vocab_size, + hidden_dim=args.hidden_dim, + num_layers=args.num_layers, + num_heads=args.num_heads, + max_seq_len=args.seq_length, + ) + + total_params, trainable_params = count_parameters(model) + print(f"[Rank {local_rank}] Model parameters: {total_params:,} total, {trainable_params:,} trainable") + + # Enable activation checkpointing if requested + if args.activation_checkpointing: + model.enable_activation_checkpointing() + print(f"[Rank {local_rank}] Activation checkpointing enabled") + + # Read config to check if torch_autocast is enabled + import json + with open(args.deepspeed_config, 'r') as f: + ds_config = json.load(f) + + use_autocast = ds_config.get("torch_autocast", {}).get("enabled", False) + autocast_dtype_str = ds_config.get("torch_autocast", {}).get("dtype", "torch.bfloat16") + autocast_dtype = torch.bfloat16 if "bfloat16" in autocast_dtype_str else torch.float16 + + # Initialize DeepSpeed - use config file path directly (not via args to avoid conflict) + model_engine, optimizer, _, _ = deepspeed.initialize( + model=model, + model_parameters=model.parameters(), + config=args.deepspeed_config, + ) + + print(f"[Rank {local_rank}] DeepSpeed initialized with config: {args.deepspeed_config}") + + mem_after_init = get_memory_stats() + print(f"[Rank {local_rank}] Memory after init: allocated={format_memory(mem_after_init['allocated'])}, " + f"reserved={format_memory(mem_after_init['reserved'])}") + + # Training loop + model_engine.train() + loss_fn = nn.CrossEntropyLoss() + + total_time = 0 + step_times = [] + loss_history = [] + + # Setup data iterator + if dataloader is not None: + data_iter = iter(dataloader) + + for step in range(args.num_steps): + start_time = time.time() + + # Get input data + if dataloader is not None: + try: + batch = next(data_iter) + except StopIteration: + data_iter = iter(dataloader) + batch = next(data_iter) + input_ids = batch['input_ids'].to(device) + labels = input_ids.clone() # For language modeling, labels = input_ids + else: + # Generate random input data + input_ids = torch.randint(0, actual_vocab_size, (args.batch_size, args.seq_length), device=device) + labels = torch.randint(0, actual_vocab_size, (args.batch_size, args.seq_length), device=device) + + # Forward pass with optional autocast + if use_autocast: + with torch.autocast(device_type="cuda", dtype=autocast_dtype): + logits = model_engine(input_ids) + loss = loss_fn(logits.view(-1, actual_vocab_size), labels.view(-1)) + else: + logits = model_engine(input_ids) + loss = loss_fn(logits.view(-1, actual_vocab_size), labels.view(-1)) + + # Backward pass - use PyTorch-style backward + loss.backward() + + # Optimizer step + model_engine.step() + + step_time = time.time() - start_time + + if step >= args.warmup_steps: + step_times.append(step_time) + + # Record loss for plotting + loss_history.append((step, loss.item())) + + if step % args.log_interval == 0 or step == args.num_steps - 1: + mem_stats = get_memory_stats() + print(f"[Rank {local_rank}] Step {step}: loss={loss.item():.4f}, " + f"time={step_time:.3f}s, " + f"alloc_mem={format_memory(mem_stats['allocated'])}, " + f"peak_mem={format_memory(mem_stats['peak'])}") + + # Final statistics + final_mem = get_memory_stats() + avg_step_time = sum(step_times) / len(step_times) if step_times else 0 + + print("\n" + "=" * 60) + print(f"[Rank {local_rank}] FINAL RESULTS") + print(f" Config: {args.deepspeed_config}") + print(f" Model: hidden_dim={args.hidden_dim}, num_layers={args.num_layers}") + print(f" Parameters: {total_params:,}") + print(f" Batch size: {args.batch_size}, Seq length: {args.seq_length}") + print(f" Average step time: {avg_step_time:.3f}s") + print(f" Peak memory: {format_memory(final_mem['peak'])}") + print(f" Final allocated memory: {format_memory(final_mem['allocated'])}") + print("=" * 60) + + # Output machine-readable summary line for parsing + print(f"SUMMARY: config={args.deepspeed_config} params={total_params} " + f"peak_mem_bytes={final_mem['peak']} alloc_mem_bytes={final_mem['allocated']} " + f"avg_step_time={avg_step_time:.4f}") + + # Save loss history to file if requested (only rank 0) + if args.loss_log_file and local_rank == 0: + import csv + with open(args.loss_log_file, 'w', newline='') as f: + writer = csv.writer(f) + writer.writerow(['step', 'loss']) + writer.writerows(loss_history) + print(f"Loss history saved to: {args.loss_log_file}") + + model_engine.destroy() + + +if __name__ == "__main__": + main()