Skip to content

Commit 852996f

Browse files
committed
Add numerics comparison script
stack-info: PR: #259, branch: xmfan/stack/23
1 parent b8546e1 commit 852996f

File tree

1 file changed

+78
-0
lines changed

1 file changed

+78
-0
lines changed

examples/run_ds3_numerics_check.py

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
"""
2+
Script to run DS3 numerics check by comparing outputs from local_map and pipeline parallel.
3+
"""
4+
import shutil
5+
import subprocess
6+
import tempfile
7+
import warnings
8+
from pathlib import Path
9+
10+
11+
def run_command(cmd, cwd):
12+
"""Run a shell command in the specified directory."""
13+
print(f"Running: {cmd}")
14+
print(f"In directory: {cwd}")
15+
result = subprocess.run(cmd, shell=True, cwd=cwd, capture_output=True, text=True)
16+
print(result.stdout)
17+
if result.stderr:
18+
print("STDERR:", result.stderr)
19+
if result.returncode != 0:
20+
warnings.warn(f"Command failed with return code {result.returncode}")
21+
return result
22+
23+
24+
def main():
25+
# Create a temporary directory
26+
temp_dir = tempfile.mkdtemp(prefix="ds3_numerics_check_")
27+
print(f"Created temporary directory: {temp_dir}")
28+
29+
try:
30+
examples_dir = Path(__file__).parent
31+
32+
print("\n" + "=" * 80)
33+
print("Running local_map example with 4 GPUs...")
34+
print("=" * 80)
35+
cmd1 = f"torchrun --standalone --nproc-per-node 4 {examples_dir}/example_ds3_local_map.py --rng-seed 42"
36+
run_command(cmd1, temp_dir)
37+
38+
print("\n" + "=" * 80)
39+
print("Running pipeline parallel example with 8 GPUs...")
40+
print("=" * 80)
41+
cmd2 = f"torchrun --standalone --nproc-per-node 8 {examples_dir}/example_ds3_pp.py --rng-seed 42"
42+
run_command(cmd2, temp_dir)
43+
44+
out_dir = Path(temp_dir) / "out"
45+
if not out_dir.exists():
46+
raise RuntimeError(f"Output directory {out_dir} does not exist")
47+
48+
print("\n" + "=" * 80)
49+
print("Comparing weights.log files...")
50+
print("=" * 80)
51+
run_command("diff out/0/weights.log out/1/pp_weights.log", temp_dir)
52+
53+
print("\n" + "=" * 80)
54+
print("Comparing diff.log files...")
55+
print("=" * 80)
56+
run_command("diff out/0/diff.log out/1/diff.log", temp_dir)
57+
58+
print("\n" + "=" * 80)
59+
print("Numerics check completed successfully!")
60+
print(f"Output directory: {temp_dir}/out")
61+
print("=" * 80)
62+
63+
except Exception as e:
64+
print(f"\nError occurred: {e}")
65+
print(f"Temporary directory preserved at: {temp_dir}")
66+
raise
67+
68+
print(f"\nTemporary directory location: {temp_dir}")
69+
response = input("Do you want to delete the temporary directory? (y/n): ")
70+
if response.lower() == "y":
71+
shutil.rmtree(temp_dir)
72+
print("Temporary directory deleted.")
73+
else:
74+
print(f"Temporary directory preserved at: {temp_dir}")
75+
76+
77+
if __name__ == "__main__":
78+
main()

0 commit comments

Comments
 (0)