Skip to content

Commit 415c8e9

Browse files
HosseinKaviani-H (Hossein Kavianihamedani)
and co-author authored
Revert "Auto-detect NCCL network configuration - SLURM (#565)" (#570)
Co-authored-by: Hossein Kavianihamedani <hosseinkh@fb.com>
1 parent f99d7aa commit 415c8e9

File tree

3 files changed

+6
-28
lines changed

3 files changed

+6
-28
lines changed

src/forge/controller/provisioner.py

Lines changed: 1 addition & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -134,23 +134,6 @@ async def set_environment(proc_mesh: ProcMesh, env_vars: dict[str, str]):
134134
await env_setter.set_env.call(env_vars)
135135

136136

137-
def get_nccl_env_vars() -> dict[str, str]:
138-
"""Get NCCL environment variables by detecting network interfaces."""
139-
if "NCCL_SOCKET_IFNAME" in os.environ and "NCCL_IB_DISABLE" in os.environ:
140-
return {}
141-
142-
try:
143-
interfaces = os.listdir("/sys/class/net/")
144-
ib_interfaces = [i for i in interfaces if i.startswith("ib")]
145-
146-
return {
147-
"NCCL_SOCKET_IFNAME": ",".join(ib_interfaces) if ib_interfaces else "^lo",
148-
"NCCL_IB_DISABLE": "0" if ib_interfaces else "1",
149-
}
150-
except Exception:
151-
return {"NCCL_SOCKET_IFNAME": "^lo", "NCCL_IB_DISABLE": "1"}
152-
153-
154137
class GpuManager:
155138
"""Tracks and assigns GPU devices on a host.
156139
@@ -364,16 +347,11 @@ async def get_proc_mesh(
364347
if with_gpus:
365348
if not addr or not port:
366349
addr, port = await get_remote_info(host_mesh)
367-
gpu_ids: list[str] = gpu_manager.get_gpus(num_procs)
350+
gpu_ids = gpu_manager.get_gpus(num_procs)
368351

369-
# Set PyTorch distributed environment variables
370352
env_vars["MASTER_ADDR"] = addr
371353
env_vars["MASTER_PORT"] = port
372354

373-
# Get NCCL-specific environment variables
374-
nccl_vars = await get_nccl_env_vars()
375-
env_vars.update(nccl_vars)
376-
377355
# Set the PTD world size
378356
world_size = num_procs * (num_hosts or 1)
379357
env_vars["WORLD_SIZE"] = str(world_size)

tests/integration_tests/test_titan_fwd_vs_hf_fwd.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -213,7 +213,7 @@ def compare_logits(
213213
hf_val = hf_logits_cpu[pos].item()
214214
diff_val = abs_diff[pos].item()
215215
print(
216-
f" {i + 1}. Position {pos}: titan={titan_val:.6f}, hf={hf_val:.6f}, diff={diff_val:.6f}"
216+
f" {i+1}. Position {pos}: titan={titan_val:.6f}, hf={hf_val:.6f}, diff={diff_val:.6f}"
217217
)
218218

219219
return metrics
@@ -242,12 +242,12 @@ def compare_probabilities(
242242
zip(titan_top_k.values, titan_top_k.indices)
243243
):
244244
token = tokenizer.decode([token_id.item()])
245-
print(f" {i + 1}. '{token}' (id={token_id.item()}): {prob.item():.6f}")
245+
print(f" {i+1}. '{token}' (id={token_id.item()}): {prob.item():.6f}")
246246

247247
print("\nHugging Face Top-K:")
248248
for i, (prob, token_id) in enumerate(zip(hf_top_k.values, hf_top_k.indices)):
249249
token = tokenizer.decode([token_id.item()])
250-
print(f" {i + 1}. '{token}' (id={token_id.item()}): {prob.item():.6f}")
250+
print(f" {i+1}. '{token}' (id={token_id.item()}): {prob.item():.6f}")
251251

252252
# Calculate overlap in top-k predictions
253253
titan_top_tokens = set(titan_top_k.indices.tolist())

tests/unit_tests/datasets/test_hf.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -231,10 +231,10 @@ def test_shuffling_behavior(self, dataset_factory, small_dataset_file):
231231
# But should contain the same set of IDs
232232
assert set(first_epoch_ids) == set(
233233
range(SMALL_DATASET_SIZE)
234-
), f"First epoch samples should be (0-{SMALL_DATASET_SIZE - 1}), got {first_epoch_ids}"
234+
), f"First epoch samples should be (0-{SMALL_DATASET_SIZE-1}), got {first_epoch_ids}"
235235
assert set(second_epoch_ids) == set(
236236
range(SMALL_DATASET_SIZE)
237-
), f"Second epoch samples should be (0-{SMALL_DATASET_SIZE - 1}), got {second_epoch_ids}"
237+
), f"Second epoch samples should be (0-{SMALL_DATASET_SIZE-1}), got {second_epoch_ids}"
238238

239239
def test_epoch_tracking(self, dataset_factory, small_dataset_file):
240240
"""Test that epoch number is correctly tracked across dataset restarts."""

0 commit comments

Comments (0)