73 changes: 73 additions & 0 deletions training/bf16_master_weight/README.md
@@ -0,0 +1,73 @@
# BF16 Low-Precision Master Weights and Optimizer States

This example demonstrates DeepSpeed's [new low-precision training options](https://github.com/deepspeedai/DeepSpeed/pull/7700) that can significantly reduce memory usage:

- `bf16_master_weights_and_grads`: Keep master parameters and gradients in BF16 instead of FP32
- `bf16_optimizer_states`: Keep optimizer states (e.g., Adam moments) in BF16
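
Below is a minimal sketch of how the two options are enabled through a DeepSpeed config dict and used in a training step. This is not the `train.py` from this example; it only assumes the standard `deepspeed.initialize` API, and the `bf16`/`torch_autocast` settings mirror `configs/bf16_full.json`.

```python
# Minimal sketch (not the actual train.py): enable the new BF16 options via a
# DeepSpeed config dict and run one training step. Launch with the `deepspeed`
# launcher so distributed initialization is handled for you.
import torch
import torch.nn as nn
import deepspeed

ds_config = {
    "train_micro_batch_size_per_gpu": 1,
    "optimizer": {"type": "AdamW", "params": {"lr": 1e-4}},
    "bf16": {
        "enabled": True,
        # New options: keep master weights/grads and optimizer states in BF16.
        "bf16_master_weights_and_grads": True,
        "bf16_optimizer_states": True,
    },
    "zero_optimization": {"stage": 3},
    # Matches the torch_autocast section of configs/bf16_full.json.
    "torch_autocast": {"enabled": True, "dtype": "torch.bfloat16"},
}

model = nn.Linear(4096, 4096)  # stand-in for the transformer used in this example
engine, _, _, _ = deepspeed.initialize(
    model=model, model_parameters=model.parameters(), config=ds_config
)

x = torch.randn(1, 4096, device=engine.device)
loss = engine(x).float().pow(2).mean()  # dummy loss
engine.backward(loss)
engine.step()
```

With ZeRO stage 3, the FP32 master copies and Adam states normally dominate the partitioned model-state memory; the two flags above keep those buffers in BF16 instead.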


## Running an Example

The following commands run training for 1000 steps on the Wikitext-103 dataset with both the baseline and the BF16 low-precision configuration, then generate a loss comparison plot.
The model has approximately 6.86 billion parameters (hidden=4096, layers=32, heads=32) and is trained with batch=1 and seq=512.
For BF16 low-precision training, `torch.autocast` is enabled via the `torch_autocast` section of the DeepSpeed config.


```bash
# Run 1000 steps with wikitext dataset
deepspeed --num_gpus=4 train.py --deepspeed_config configs/baseline.json \
--num_layers 32 --hidden_dim 4096 --num_heads 32 --batch_size 1 \
--num_steps 1000 --activation_checkpointing \
--loss_log_file logs/baseline_loss.csv --use_real_data --seed 42

deepspeed --num_gpus=4 train.py --deepspeed_config configs/bf16_full.json \
--num_layers 32 --hidden_dim 4096 --num_heads 32 --batch_size 1 \
--num_steps 1000 --activation_checkpointing \
--loss_log_file logs/bf16_full_loss.csv --use_real_data --seed 42

# Generate comparison plot
python plot_loss.py --baseline logs/baseline_loss.csv --bf16 logs/bf16_full_loss.csv \
--output loss_comparison.png
```
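
If each run's output (which `gather_memory.py` expects to contain a `SUMMARY:` line) is redirected to `logs/baseline.log` and `logs/bf16_full.log`, the included `gather_memory.py --log_dir logs` script aggregates the results into a memory and step-time comparison table.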

Here is a summary of the memory usage and training time on 4×H100 GPUs.
BF16 low-precision training gives a significant memory reduction: **9.57 GB less allocated (37.2%) and 12.45 GB less at peak (39.7%)**.

| Configuration | Allocated Memory | Peak Memory | Avg Step Time |
|---------------|------------------|-------------|---------------|
| Baseline (fp32 master) | 25.74 GB | 31.38 GB | 0.6016s |
| BF16 low-precision (master + opt states) | **16.17 GB** | **18.93 GB** | 0.6427s |


## Loss Curve Comparison

To verify that BF16 low-precision training maintains numerical stability, we trained for 1000 steps on the Wikitext-103 dataset:

![Loss Comparison](logs/7b_loss_run/loss_comparison.png)

| Configuration | Final Loss | Mean Loss | Loss Std |
|---------------|------------|-----------|----------|
| Baseline (fp32 master) | 3.09 | 2.78 | 1.56 |
| BF16 Low-Precision | 3.12 | 2.90 | 2.37 |

The loss curves show that both configurations converge similarly, demonstrating that the reduced precision does not significantly impact training quality while providing substantial memory savings.
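
For reference, a comparison plot of this kind can be produced from the two loss CSVs with a short script along the following lines. This is only a sketch: it assumes each CSV has `step` and `loss` columns, which is not necessarily what `plot_loss.py` in this example expects.

```python
# Hypothetical loss-comparison plot; assumes each CSV has "step" and "loss"
# columns, which may differ from the actual log format used by train.py.
import argparse
import pandas as pd
import matplotlib.pyplot as plt

parser = argparse.ArgumentParser()
parser.add_argument("--baseline", required=True)
parser.add_argument("--bf16", required=True)
parser.add_argument("--output", default="loss_comparison.png")
args = parser.parse_args()

baseline = pd.read_csv(args.baseline)
bf16 = pd.read_csv(args.bf16)

plt.figure(figsize=(8, 5))
plt.plot(baseline["step"], baseline["loss"], label="Baseline (fp32 master)")
plt.plot(bf16["step"], bf16["loss"], label="BF16 low-precision")
plt.xlabel("Step")
plt.ylabel("Training loss")
plt.legend()
plt.tight_layout()
plt.savefig(args.output, dpi=150)
```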

### Memory Breakdown

For a model with N parameters:

| Component | Baseline | BF16 Low-Precision |
|-----------|----------|-------------------|
| Model params | 2N bytes (BF16) | 2N bytes (BF16) |
| Master weights | 4N bytes (FP32) | 2N bytes (BF16) |
| Gradients | 4N bytes (FP32) | 2N bytes (BF16) |
| Adam momentum | 4N bytes (FP32) | 2N bytes (BF16) |
| Adam variance | 4N bytes (FP32) | 2N bytes (BF16) |
| **Total** | **18N bytes** | **10N bytes** |

This gives a theoretical ~44% reduction in model-state memory (parameters, gradients, and optimizer states). The actual end-to-end savings also depend on activation memory and other overheads, but the measured reductions above (37-40%) come close to the theoretical figure.
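
As a quick check of this arithmetic, the snippet below computes the theoretical totals for the ~6.86B-parameter model used here. The per-GPU figures assume ZeRO-3 partitions model states evenly across the 4 GPUs; measured numbers additionally include activations, buffers, and allocator overhead.

```python
# Back-of-the-envelope model-state memory for N parameters, using the
# bytes-per-parameter counts from the table above.
N = 6.86e9           # ~6.86B parameters, as in this example
GIB = 1024 ** 3
world_size = 4       # ZeRO-3 partitions model states across 4 GPUs

baseline_bytes = 18 * N  # 2 (bf16 params) + 4 * 4 (fp32 master, grads, Adam m/v)
bf16_bytes = 10 * N      # 2 (bf16 params) + 4 * 2 (bf16 master, grads, Adam m/v)

print(f"theoretical reduction: {1 - bf16_bytes / baseline_bytes:.1%}")    # ~44.4%
print(f"baseline per GPU: {baseline_bytes / world_size / GIB:.2f} GiB")   # ~28.8 GiB
print(f"bf16     per GPU: {bf16_bytes / world_size / GIB:.2f} GiB")       # ~16.0 GiB
```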

## Related Resources

- [DeepSpeed BF16 Documentation](https://www.deepspeed.ai/docs/config-json/#bf16-training-options)
- [Low-precision master params PR](https://github.com/deepspeedai/DeepSpeed/pull/7700)
31 changes: 31 additions & 0 deletions training/bf16_master_weight/configs/baseline.json
@@ -0,0 +1,31 @@
{
"train_micro_batch_size_per_gpu": 4,
"gradient_accumulation_steps": 1,
"steps_per_print": 100,

"optimizer": {
"type": "AdamW",
"params": {
"lr": 1e-4,
"betas": [0.9, 0.999],
"eps": 1e-8,
"weight_decay": 0.01
}
},

"bf16": {
"enabled": true
},

"zero_optimization": {
"stage": 3,
"overlap_comm": true,
"contiguous_gradients": true,
"reduce_bucket_size": 5e7,
"stage3_param_persistence_threshold": 0
},

"torch_autocast": {
"enabled": false
}
}
34 changes: 34 additions & 0 deletions training/bf16_master_weight/configs/bf16_full.json
@@ -0,0 +1,34 @@
{
"train_micro_batch_size_per_gpu": 4,
"gradient_accumulation_steps": 1,
"steps_per_print": 100,

"optimizer": {
"type": "AdamW",
"params": {
"lr": 1e-4,
"betas": [0.9, 0.999],
"eps": 1e-8,
"weight_decay": 0.01
}
},

"bf16": {
"enabled": true,
"bf16_master_weights_and_grads": true,
"bf16_optimizer_states": true
},

"zero_optimization": {
"stage": 3,
"overlap_comm": true,
"contiguous_gradients": true,
"reduce_bucket_size": 5e7,
"stage3_param_persistence_threshold": 0
},

"torch_autocast": {
"enabled": true,
"dtype": "torch.bfloat16"
}
}
175 changes: 175 additions & 0 deletions training/bf16_master_weight/gather_memory.py
@@ -0,0 +1,175 @@
#!/usr/bin/env python
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

"""
Script to gather and compare memory usage from training logs.

Usage:
python gather_memory.py --log_dir logs/20231201_120000
"""

import argparse
import re
import sys
from pathlib import Path


def parse_summary_line(line):
"""Parse the SUMMARY line from log output."""
pattern = r"SUMMARY: config=(\S+) params=(\d+) peak_mem_bytes=(\d+) alloc_mem_bytes=(\d+) avg_step_time=(\S+)"
match = re.search(pattern, line)
if match:
return {
"config": match.group(1),
"params": int(match.group(2)),
"peak_mem_bytes": int(match.group(3)),
"alloc_mem_bytes": int(match.group(4)),
"avg_step_time": float(match.group(5)),
}
return None


def format_bytes(bytes_val):
"""Format bytes to human-readable string."""
gb = bytes_val / (1024 ** 3)
return f"{gb:.2f} GB"


def format_bytes_mb(bytes_val):
"""Format bytes to MB."""
mb = bytes_val / (1024 ** 2)
return f"{mb:.1f} MB"


def get_config_name(config_path):
"""Extract clean config name from path."""
name = Path(config_path).stem
if name == "baseline":
return "Baseline (fp32 master)"
elif name == "bf16_master_wg":
return "bf16_master_weights_and_grads"
elif name == "bf16_full":
return "bf16_full (master + opt states)"
return name


def main():
parser = argparse.ArgumentParser(description="Gather memory usage from training logs")
parser.add_argument("--log_dir", type=str, required=True, help="Directory containing log files")
parser.add_argument("--output", type=str, default=None, help="Output file for summary")
args = parser.parse_args()

log_dir = Path(args.log_dir)
if not log_dir.exists():
print(f"Error: Log directory '{log_dir}' does not exist")
return 1

# Find and parse all log files
results = []
log_files = ["baseline.log", "bf16_full.log"]

for log_file in log_files:
log_path = log_dir / log_file
if not log_path.exists():
print(f"Warning: Log file '{log_path}' not found, skipping")
continue

with open(log_path, "r") as f:
for line in f:
summary = parse_summary_line(line)
if summary:
results.append(summary)
break

if not results:
print("No results found in log files")
return 1

# Calculate baseline for comparison
baseline_peak = None
for r in results:
if "baseline" in r["config"]:
baseline_peak = r["peak_mem_bytes"]
break

# Generate summary
output_lines = []
output_lines.append("=" * 80)
output_lines.append("BF16 Low-Precision Master Weights - Memory Usage Comparison")
output_lines.append("=" * 80)
output_lines.append("")

# Table header
output_lines.append(f"{'Configuration':<40} {'Peak Memory':<15} {'Reduction':<15} {'Step Time':<12}")
output_lines.append("-" * 80)

for r in results:
config_name = get_config_name(r["config"])
peak_mem = format_bytes(r["peak_mem_bytes"])
step_time = f"{r['avg_step_time']:.4f}s"

if baseline_peak and baseline_peak > 0:
reduction = ((baseline_peak - r["peak_mem_bytes"]) / baseline_peak) * 100
reduction_str = f"{reduction:+.1f}%" if reduction != 0 else "-"
else:
reduction_str = "-"

output_lines.append(f"{config_name:<40} {peak_mem:<15} {reduction_str:<15} {step_time:<12}")

output_lines.append("-" * 80)
output_lines.append("")

# Detailed breakdown
output_lines.append("Detailed Results:")
output_lines.append("-" * 40)
for r in results:
config_name = get_config_name(r["config"])
output_lines.append(f"\n{config_name}:")
output_lines.append(f" Parameters: {r['params']:,}")
output_lines.append(f" Peak Memory: {format_bytes(r['peak_mem_bytes'])} ({r['peak_mem_bytes']:,} bytes)")
output_lines.append(f" Allocated Memory: {format_bytes(r['alloc_mem_bytes'])} ({r['alloc_mem_bytes']:,} bytes)")
output_lines.append(f" Avg Step Time: {r['avg_step_time']:.4f}s")

output_lines.append("")
output_lines.append("=" * 80)

# Generate markdown table
output_lines.append("")
output_lines.append("Markdown Table (for README):")
output_lines.append("-" * 40)
output_lines.append("")
output_lines.append("| Configuration | Peak Memory | Memory Reduction | Avg Step Time |")
output_lines.append("|---------------|-------------|------------------|---------------|")

for r in results:
config_name = get_config_name(r["config"])
peak_mem = format_bytes(r["peak_mem_bytes"])
step_time = f"{r['avg_step_time']:.4f}s"

if baseline_peak and baseline_peak > 0:
reduction = ((baseline_peak - r["peak_mem_bytes"]) / baseline_peak) * 100
reduction_str = f"{reduction:+.1f}%" if reduction != 0 else "-"
else:
reduction_str = "-"

output_lines.append(f"| {config_name} | {peak_mem} | {reduction_str} | {step_time} |")

output_lines.append("")

# Print to stdout
summary_text = "\n".join(output_lines)
print(summary_text)

# Save to file
output_path = args.output or (log_dir / "summary.txt")
with open(output_path, "w") as f:
f.write(summary_text)

print(f"\nSummary saved to: {output_path}")

return 0


if __name__ == "__main__":
    sys.exit(main())