Skip to content

Commit 6e2085e

Browse files
authored
Add numerics comparison script (#259)
stack-info: PR: #259, branch: xmfan/stack/23
1 parent 10d8208 commit 6e2085e

File tree

1 file changed

+97
-0
lines changed

1 file changed

+97
-0
lines changed

examples/run_ds3_numerics_check.py

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
2+
#
3+
# This source code is licensed under the BSD license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
"""
7+
Script to run DS3 numerics check by comparing outputs from local_map and pipeline parallel.
8+
"""
9+
import shutil
10+
import subprocess
11+
import tempfile
12+
import warnings
13+
from pathlib import Path
14+
15+
16+
def run_command(cmd, cwd):
17+
"""Run a shell command in the specified directory."""
18+
print(f"Running: {cmd}")
19+
print(f"In directory: {cwd}")
20+
result = subprocess.run(cmd, shell=True, cwd=cwd, capture_output=True, text=True)
21+
print(result.stdout)
22+
if result.stderr:
23+
print("STDERR:", result.stderr)
24+
if result.returncode != 0:
25+
warnings.warn(f"Command failed with return code {result.returncode}")
26+
return result
27+
28+
29+
def main(args):
30+
schedule_name = args.schedule_name
31+
32+
# Create a temporary directory
33+
temp_dir = tempfile.mkdtemp(prefix="ds3_numerics_check_")
34+
print(f"Created temporary directory: {temp_dir}")
35+
36+
try:
37+
examples_dir = Path(__file__).parent
38+
39+
print("\n" + "=" * 80)
40+
print("Running non-PP example with 4 GPUs...")
41+
print("=" * 80)
42+
cmd1 = f"torchrun --standalone --nproc-per-node 4 {examples_dir}/example_ds3_local_map.py --rng-seed 42"
43+
run_command(cmd1, temp_dir)
44+
45+
print("\n" + "=" * 80)
46+
print("Running PP example with 8 GPUs...")
47+
print("=" * 80)
48+
cmd2 = f"torchrun --standalone --nproc-per-node 8 {examples_dir}/example_ds3_pp.py --rng-seed 42 --schedule-name={schedule_name}"
49+
run_command(cmd2, temp_dir)
50+
51+
out_dir = Path(temp_dir) / "out"
52+
if not out_dir.exists():
53+
raise RuntimeError(f"Output directory {out_dir} does not exist")
54+
55+
print("\n" + "=" * 80)
56+
print("Comparing weights.log files...")
57+
print("=" * 80)
58+
run_command("diff out/0/weights.log out/1/pp_weights.log", temp_dir)
59+
60+
print("\n" + "=" * 80)
61+
print("Comparing diff.log files...")
62+
print("=" * 80)
63+
run_command("diff out/0/diff.log out/1/diff.log", temp_dir)
64+
65+
print("\n" + "=" * 80)
66+
print("Numerics check completed successfully!")
67+
print(f"Output directory: {temp_dir}/out")
68+
print("=" * 80)
69+
70+
except Exception as e:
71+
print(f"\nError occurred: {e}")
72+
print(f"Temporary directory preserved at: {temp_dir}")
73+
raise
74+
75+
print(f"\nTemporary directory location: {temp_dir}")
76+
response = input("Do you want to delete the temporary directory? (y/n): ")
77+
if response.lower() == "y":
78+
shutil.rmtree(temp_dir)
79+
print("Temporary directory deleted.")
80+
else:
81+
print(f"Temporary directory preserved at: {temp_dir}")
82+
83+
84+
if __name__ == "__main__":
85+
import argparse
86+
87+
parser = argparse.ArgumentParser(
88+
description="Run DeepSeek V3 pipeline parallel example"
89+
)
90+
parser.add_argument(
91+
"--schedule-name",
92+
type=str,
93+
default="ZBVZeroBubble",
94+
help="Schedule to use for PP",
95+
)
96+
args = parser.parse_args()
97+
main(args)

0 commit comments

Comments
 (0)