7 changes: 6 additions & 1 deletion miles/utils/external_utils/command_utils.py
@@ -8,6 +8,7 @@
import random
import time
from dataclasses import dataclass
from functools import partial
from pathlib import Path

from miles.utils.misc import exec_command, exec_command_all_ray_node
@@ -23,6 +24,7 @@ def convert_checkpoint(
megatron_model_type,
num_gpus_per_node: int,
multinode: bool = False,
num_nodes: int | None = None,
extra_args: str = "",
dir_dst: str = "/root",
hf_checkpoint: str | None = None,
@@ -43,7 +45,10 @@ def convert_checkpoint(
"--master-addr {{master_addr}} " "--master-port 23456 " "--nnodes={{nnodes}} " "--node-rank {{node_rank}} "
)

fn = exec_command_all_ray_node if multinode else exec_command
if multinode:
fn = partial(exec_command_all_ray_node, num_nodes=num_nodes)
else:
fn = exec_command
fn(
f"source {repo_base_dir}/scripts/models/{megatron_model_type}.sh && "
f"PYTHONPATH={megatron_path} "
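
Note: as a quick sanity check on the new dispatch, here is a minimal, self-contained sketch of what the partial binding does at the call site. The two helpers below are stand-ins for illustration only, not the real exec_command / exec_command_all_ray_node:

    from functools import partial

    def exec_command_all_ray_node(cmd, capture_output=False, num_nodes=None):
        # stand-in: only echoes its arguments
        return (cmd, capture_output, num_nodes)

    def exec_command(cmd, capture_output=False):
        # stand-in: only echoes its arguments
        return (cmd, capture_output)

    # multinode path: num_nodes is pre-bound, so the shared call site stays unchanged
    fn = partial(exec_command_all_ray_node, num_nodes=2)
    assert fn("hostname") == ("hostname", False, 2)

    # single-node path keeps the plain helper
    fn = exec_command
    assert fn("hostname") == ("hostname", False)
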
13 changes: 11 additions & 2 deletions miles/utils/misc.py
@@ -101,14 +101,19 @@ def _exec_command_on_node(cmd: str, capture_output: bool) -> str | None:
return exec_command(f"unset CUDA_VISIBLE_DEVICES; {cmd}", capture_output=capture_output)


def exec_command_all_ray_node(cmd: str, capture_output: bool = False) -> list[str | None]:
def exec_command_all_ray_node(
cmd: str, capture_output: bool = False, num_nodes: int | None = None
) -> list[str | None]:
"""Execute a shell command on every alive Ray node in parallel.

Supported placeholders in `cmd` (replaced per-node before execution):
{{node_rank}} - 0-based index of the node
{{nnodes}} - total number of alive nodes
{{nnodes}} - total number of alive nodes (or num_nodes if specified)
{{master_addr}} - NodeManagerAddress of the first node
{{node_ip}} - NodeManagerAddress of the current node

Args:
num_nodes: If set, only use the first `num_nodes` nodes instead of all alive nodes.
"""
ray.init(address="auto")
try:
@@ -119,6 +124,10 @@ def exec_command_all_ray_node(cmd: str, capture_output: bool = False) -> list[str | None]:
)
assert len(nodes) > 0

if num_nodes is not None:
assert num_nodes <= len(nodes), f"Requested {num_nodes} nodes but only {len(nodes)} alive nodes available."
nodes = nodes[:num_nodes]

master_addr = nodes[0]["NodeManagerAddress"]
nnodes = str(len(nodes))

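
Note: to make the docstring's placeholder contract concrete, here is a minimal sketch of the per-node substitution with the new num_nodes limit applied first. render_per_node is a hypothetical helper and the node records are fabricated; it mirrors the documented behaviour without calling Ray or the real function:

    def render_per_node(cmd, nodes, num_nodes=None):
        # assumption: mirrors the substitution described in the docstring above
        if num_nodes is not None:
            assert num_nodes <= len(nodes), "cannot request more nodes than are alive"
            nodes = nodes[:num_nodes]
        master_addr = nodes[0]["NodeManagerAddress"]
        nnodes = str(len(nodes))
        commands = []
        for node_rank, node in enumerate(nodes):
            commands.append(
                cmd.replace("{{node_rank}}", str(node_rank))
                .replace("{{nnodes}}", nnodes)
                .replace("{{master_addr}}", master_addr)
                .replace("{{node_ip}}", node["NodeManagerAddress"])
            )
        return commands

    nodes = [{"NodeManagerAddress": ip} for ip in ("10.0.0.1", "10.0.0.2", "10.0.0.3")]
    cmd = "torchrun --master-addr {{master_addr}} --nnodes={{nnodes}} --node-rank {{node_rank}}"
    for rendered in render_per_node(cmd, nodes, num_nodes=2):
        print(rendered)
    # only the first two nodes are used; both commands point at 10.0.0.1 with --nnodes=2
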
13 changes: 10 additions & 3 deletions scripts/run_deepseek.py
@@ -58,13 +58,19 @@ def _prepare_bf16_ckpt(args: ScriptArgs):
def _prepare_megatron_ckpt(args: ScriptArgs):
# TODO unify 5layer w/ 20layer, also maybe unify the whole script
extra_args = "--tensor-model-parallel-size 1 " "--expert-tensor-parallel-size 1 "
if args.num_nodes == 1 and args.model_name == "DeepSeek-V3-0324-5layer":
num_gpus_per_node = args.num_gpus_per_node
multinode = True
num_nodes = None
if args.model_name == "DeepSeek-V3-0324-5layer":
extra_args += "--pipeline-model-parallel-size 1 " "--expert-model-parallel-size 1 "
num_gpus_per_node = min(4, num_gpus_per_node)
multinode = False
elif args.model_name == "DeepSeek-V3-0324-20layer":
extra_args += (
"--expert-model-parallel-size 4 "
# PP info will be auto determined by converter script
)
num_nodes = 2
else:
extra_args += (
"--pipeline-model-parallel-size 8 "
@@ -77,8 +83,9 @@ def _prepare_megatron_ckpt(args: ScriptArgs):
model_name=args.model_name,
hf_checkpoint=f"{args.model_dir}/{args.model_name}-bf16",
megatron_model_type=args.megatron_model_type,
num_gpus_per_node=args.num_gpus_per_node,
multinode=True,
num_gpus_per_node=num_gpus_per_node,
multinode=multinode,
num_nodes=num_nodes,
extra_args=extra_args,
dir_dst=args.model_dir,
megatron_path=args.megatron_path,
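
Note: the branches above boil down to the following effective settings per model; this is an illustrative summary of the diff, not output of the script:

    settings = {
        "DeepSeek-V3-0324-5layer": dict(multinode=False, num_nodes=None),   # single node, GPUs capped at 4
        "DeepSeek-V3-0324-20layer": dict(multinode=True, num_nodes=2),      # first two alive Ray nodes
    }
    default = dict(multinode=True, num_nodes=None)                          # all alive Ray nodes
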