diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..1cdf221 --- /dev/null +++ b/.gitignore @@ -0,0 +1,13 @@ +.DS_Store + +# Python bytecode / caches +__pycache__/ +*.py[cod] + +# Local agent run artifacts +.agent_runs/ + +# Local virtualenvs +.venv/ +venv/ + diff --git a/api/README.md b/api/README.md new file mode 100644 index 0000000..22f882c --- /dev/null +++ b/api/README.md @@ -0,0 +1,112 @@ +# FERB Agentic API + +This API now supports: + +- `POST /chat` - GPT-backed kernel optimization chat +- `POST /optimize` - iterative agent loop for solution improvement +- `POST /optimize/stream` - live event stream of agent iterations + +## Setup + +```bash +cd /Users/rohk/FERB/api +python -m venv .venv +source .venv/bin/activate +pip install -r requirements.txt +export OPENAI_API_KEY="your_key_here" +# If you plan to run evaluator on Modal (recommended for Triton backend): +modal token new +# Optional but recommended: point to Python with torch installed for evaluator runs +export FERB_EVAL_PYTHON="/absolute/path/to/python-with-torch" +uvicorn main:app --reload --host 0.0.0.0 --port 8000 +``` + +## `POST /chat` + +Request: + +```json +{ + "message": "How do I optimize problem 10 for fewer all-to-all rounds?", + "model": "gpt-4o-mini" +} +``` + +## `POST /optimize` + +Runs a generate -> (optional evaluate) -> select-best loop. + +Request: + +```json +{ + "objective": "Improve throughput while keeping correctness for problem 1 all-reduce", + "problem_id": 1, + "iterations": 3, + "model": "gpt-4o-mini", + "topology_json_path": "/Users/rohk/FERB/utils/example_topologies/nccl_topology_parsed.json", + "evaluator_command": "python /path/to/evaluator.py --candidate {candidate_path}", + "evaluator_timeout_s": 240 +} +``` + +Notes: + +- `evaluator_command` is optional. If omitted, the API still iterates and stores candidates. 
+- If used, your evaluator should print either: + - JSON: `{"score": 123.4}`, or + - text line: `score=123.4` +- Higher score is treated as better. + +Outputs are saved in: + +- `/Users/rohk/FERB/.agent_runs//candidate_iter_.py` + +### Real speedup evaluator (recommended) + +Use the provided distributed benchmark as evaluator: + +```bash +torchrun --nproc-per-node 8 /Users/rohk/FERB/scripts/benchmark_candidate.py \ + --problem 1 \ + --candidate {candidate_path} \ + --rows 1024 --cols 1024 --dtype float32 \ + --warmup 3 --iters 10 --score-only +``` + +For API usage, pass that whole string as `evaluator_command`. + +## `POST /optimize/stream` (live thinking/run events) + +This endpoint streams JSON events using Server-Sent Events (SSE), so you can watch: + +- run start +- iteration start +- candidate generated +- evaluation start/completed (if evaluator is configured) +- best update +- run completed + +Example: + +```bash +curl -N -X POST "http://127.0.0.1:8000/optimize/stream" \ + -H "Content-Type: application/json" \ + -d '{ + "objective": "Improve throughput for problem 1 while preserving correctness", + "problem_id": 1, + "iterations": 3, + "model": "gpt-4o-mini" + }' +``` + +You will receive events like: + +```text +data: {"type":"run_started", ...} +data: {"type":"iteration_started","iteration":1} +data: {"type":"candidate_generated","iteration":1,...} +data: {"type":"best_updated","iteration":1,...} +data: {"type":"run_completed","result":{...}} +``` + diff --git a/api/agentic.py b/api/agentic.py new file mode 100644 index 0000000..50272e3 --- /dev/null +++ b/api/agentic.py @@ -0,0 +1,913 @@ +""" +Agentic optimization loop for FERB kernel solutions. 
+ +This module provides: +- GPT chat replies with FERB context +- Iterative candidate generation + optional evaluation +""" + +from __future__ import annotations + +import json +import os +import re +import shlex +import shutil +import subprocess +import selectors +import sys +import time +import uuid +from dataclasses import dataclass +from pathlib import Path +from typing import Iterator +from typing import Any + +from openai import OpenAI + + +ROOT_DIR = Path(__file__).resolve().parent.parent +REFERENCE_DIR = ROOT_DIR / "reference" +TRITON_DIR = ROOT_DIR / "solutions_triton" +RUNS_DIR = ROOT_DIR / ".agent_runs" + + +SYSTEM_CHAT_PROMPT = """You are an expert multi-GPU systems and CUDA engineer working on FERB. +Focus on practical kernel optimization advice for distributed training speedups. +Keep answers concise and action-oriented. +""" + + +SYSTEM_OPTIMIZER_PROMPT = """You are an autonomous kernel optimization agent for FERB. +You are iteratively improving solution quality. +Always return valid Python code for a FERB solution module with a `solution(...)` function. +Do not include markdown fences. +""" + + +@dataclass +class IterationResult: + iteration: int + candidate_path: str + score: float | None + evaluator_stdout: str + evaluator_stderr: str + model_feedback: str + + +def _clip(text: str, limit: int) -> str: + if limit <= 0: + return text or "" + if len(text) <= limit: + return text + return text[:limit] + "\n... [truncated]" + + +def _openai_client() -> OpenAI: + api_key = os.environ.get("OPENAI_API_KEY", "").strip() + if not api_key: + raise RuntimeError("OPENAI_API_KEY is not set") + return OpenAI(api_key=api_key) + + +def _extract_python_code(text: str) -> str: + """ + Accept raw code or fenced markdown and return Python source. 
+ """ + src = (text or "").strip() + fence = re.search(r"```(?:python)?\s*(.*?)```", src, flags=re.DOTALL | re.IGNORECASE) + if fence: + return fence.group(1).strip() + return src + + +def _extract_plan_and_code(text: str) -> tuple[str, str]: + src = (text or "").strip() + marker = "###CODE_START" + if marker in src: + head, tail = src.split(marker, 1) + plan = head.strip() + code = _extract_python_code(tail.strip()) + return plan, code + return "No explicit plan returned.", _extract_python_code(src) + + +def _read_problem_reference(problem_id: int) -> str: + path = REFERENCE_DIR / f"{problem_id}.py" + if not path.exists(): + return f"# Missing reference file for problem {problem_id}: {path}" + return path.read_text(encoding="utf-8") + + +def _read_triton_seed(problem_id: int) -> str: + path = TRITON_DIR / f"{problem_id}_triton.py" + if not path.exists(): + return "" + return path.read_text(encoding="utf-8") + + +def _read_problem_descriptions() -> str: + path = REFERENCE_DIR / "problems.md" + if not path.exists(): + return "" + return path.read_text(encoding="utf-8") + + +def gpt_chat_reply(message: str, model: str = "gpt-4o-mini") -> str: + """ + GPT-backed chat with FERB context. + """ + client = _openai_client() + response = client.responses.create( + model=model, + input=[ + {"role": "system", "content": SYSTEM_CHAT_PROMPT}, + {"role": "user", "content": message}, + ], + ) + return (response.output_text or "").strip() + + +def _generate_candidate( + *, + model: str, + objective: str, + problem_id: int, + iteration_idx: int, + target_backend: str, + previous_best_code: str | None, + previous_feedback: str | None, + previous_eval_feedback: str | None, + quality_feedback: str | None, + topology_json: str | None, +) -> tuple[str, str]: + """ + Returns (candidate_code, model_feedback). 
+ """ + client = _openai_client() + reference_code = _read_problem_reference(problem_id) + triton_seed = _read_triton_seed(problem_id) + problem_docs = _read_problem_descriptions() + best_code_section = previous_best_code if previous_best_code else "# none yet" + feedback_section = previous_feedback if previous_feedback else "# first iteration" + eval_feedback_section = previous_eval_feedback if previous_eval_feedback else "# no evaluator feedback yet" + quality_feedback_section = quality_feedback if quality_feedback else "# no quality issues from last attempt" + topology_section = topology_json if topology_json else "{}" + + backend_requirements = ( + "Target backend is triton+nvshmem. Use Triton kernels and NVSHMEM APIs; avoid plain NCCL-only all_reduce wrappers." + if target_backend == "triton" + else "Target backend is reference/pytorch distributed baseline." + ) + + user_prompt = f""" +Objective: +{objective} + +Problem ID: +{problem_id} + +Iteration: +{iteration_idx} + +Problem notes: +{problem_docs} + +Reference implementation: +{reference_code} + +Current best candidate: +{best_code_section} + +Feedback from previous iteration: +{feedback_section} + +Topology JSON: +{topology_section} + +Task: +1) Write an improved solution module. +2) Keep function signature compatible with reference. +3) Prioritize correctness first, then performance. 
+4) Return with this exact format: +PLAN: <1-3 short bullets of what you'll try> +###CODE_START + + +Backend constraint: +{backend_requirements} + +Triton seed (if available): +{triton_seed if triton_seed else "# none"} + +Quality feedback from previous attempt: +{quality_feedback_section} + +Evaluator feedback from previous attempt (errors, logs, metrics): +{eval_feedback_section} +""" + + response = client.responses.create( + model=model, + input=[ + {"role": "system", "content": SYSTEM_OPTIMIZER_PROMPT}, + {"role": "user", "content": user_prompt}, + ], + ) + plan_text, code = _extract_plan_and_code(response.output_text) + return code, plan_text + + +def _parse_score(stdout: str) -> float | None: + """ + Accept score in either: + - JSON object: {"score": 123.4} + - plain text line: score=123.4 + """ + txt = (stdout or "").strip() + if not txt: + return None + + lines = [line.strip() for line in txt.splitlines() if line.strip()] + if not lines: + return None + + # Try JSON from last line then full text + for candidate in (lines[-1], txt): + try: + obj = json.loads(candidate) + if isinstance(obj, dict) and "score" in obj: + return float(obj["score"]) + except Exception: + pass + + # Try score=... + for line in reversed(lines): + m = re.search(r"score\s*=\s*([0-9]+(?:\.[0-9]+)?)", line, flags=re.IGNORECASE) + if m: + return float(m.group(1)) + + return None + + +def _normalize_evaluator_command(command: str) -> str: + """ + If torchrun is unavailable, replace it with: + python -m torch.distributed.run ... 
+ """ + if shutil.which("torchrun"): + return command + try: + tokens = shlex.split(command) + except Exception: + return command + if not tokens: + return command + if tokens[0] != "torchrun": + return command + py = _find_python_with_torch() or (sys.executable or "python3") + rewritten = [py, "-m", "torch.distributed.run"] + tokens[1:] + return " ".join(shlex.quote(t) for t in rewritten) + + +def _python_has_torch(python_exe: str) -> bool: + try: + proc = subprocess.run( + [python_exe, "-c", "import torch; print(torch.__version__)"], + capture_output=True, + text=True, + timeout=10, + ) + return proc.returncode == 0 + except Exception: + return False + + +def _find_python_with_torch() -> str | None: + """ + Find a Python interpreter with torch installed. + Priority: + 1) FERB_EVAL_PYTHON env var + 2) python3 in PATH + 3) python in PATH + 4) current interpreter + """ + candidates: list[str] = [] + env_py = os.environ.get("FERB_EVAL_PYTHON", "").strip() + if env_py: + candidates.append(env_py) + for name in ("python3", "python"): + p = shutil.which(name) + if p: + candidates.append(p) + if sys.executable: + candidates.append(sys.executable) + + seen: set[str] = set() + for c in candidates: + if c in seen: + continue + seen.add(c) + if _python_has_torch(c): + return c + return None + + +def _candidate_quality_issues(code: str, target_backend: str) -> list[str]: + src = (code or "").lower() + issues: list[str] = [] + if not src.strip(): + issues.append("empty candidate code") + return issues + if "def solution" not in src: + issues.append("missing solution(...) 
function") + if target_backend == "triton": + if "import triton" not in src: + issues.append("missing Triton import") + if "nvshmem" not in src: + issues.append("missing NVSHMEM usage") + if "dist.all_reduce" in src and "triton" not in src: + issues.append("NCCL-only fallback detected; expected triton/nvshmem solution") + return issues + + +def _is_infra_blocker(stderr: str, target_backend: str) -> str | None: + s = (stderr or "").lower() + if "token missing" in s and "modal" in s: + return "Modal authentication token missing (modal CLI not logged in)." + if "could not authenticate client" in s and "modal" in s: + return "Modal authentication failed (check your modal token credentials)." + if target_backend == "triton": + if "no module named 'triton'" in s or "no module named triton" in s: + return "Evaluator environment is missing Triton." + if "no module named 'nvshmem'" in s or "no module named nvshmem" in s: + return "Evaluator environment is missing NVSHMEM bindings." + if "torch.cuda" in s and ("not available" in s or "not compiled" in s): + return "Evaluator environment has no CUDA." + if "torch.cuda.set_device" in s: + return "Evaluator environment cannot bind CUDA devices (likely no GPU)." + return None + + +def _rewrite_python_launcher(command: str, python_exe: str | None) -> str: + """ + If command begins with python/python3, replace with explicit interpreter. + """ + if not python_exe: + return command + try: + tokens = shlex.split(command) + except Exception: + return command + if not tokens: + return command + head = tokens[0] + if head in {"python", "python3"}: + tokens[0] = python_exe + return " ".join(shlex.quote(t) for t in tokens) + return command + + +def _run_evaluator( + command_template: str, + candidate_path: Path, + timeout_s: int, + evaluator_python: str | None = None, +) -> tuple[float, str, str]: + """ + Run evaluator command and parse score from stdout. 
+ """ + command = command_template.format(candidate_path=str(candidate_path)) + py_with_torch = None + if evaluator_python: + # Only honor explicit evaluator_python if it exists and imports torch. + if os.path.exists(evaluator_python) and _python_has_torch(evaluator_python): + py_with_torch = evaluator_python + if py_with_torch is None: + py_with_torch = _find_python_with_torch() + command = _rewrite_python_launcher(command, py_with_torch) + command = _normalize_evaluator_command(command) + proc = subprocess.run( + command, + shell=True, + cwd=str(ROOT_DIR), + capture_output=True, + text=True, + timeout=timeout_s, + ) + stdout = proc.stdout or "" + stderr = proc.stderr or "" + score = _parse_score(stdout) + if score is None: + score = 0.0 + if proc.returncode != 0: + if stderr: + stderr = stderr + f"\n[evaluator_exit_code={proc.returncode}]" + else: + stderr = f"[evaluator_exit_code={proc.returncode}]" + if py_with_torch is None: + extra = ( + "\n[no_python_with_torch_found] Set FERB_EVAL_PYTHON to a Python interpreter " + "that has torch installed." + ) + stderr = (stderr + extra) if stderr else extra.lstrip() + return score, stdout, stderr + + +def _run_evaluator_live( + command_template: str, + candidate_path: Path, + timeout_s: int, + evaluator_python: str | None = None, + *, + heartbeat_every_s: float = 2.0, + tail_chars: int = 1200, +) -> Iterator[dict[str, Any]]: + """ + Run evaluator command while yielding periodic heartbeats. + + This is mainly to keep the SSE UI "alive" while long-running evaluators + (e.g. `modal run ...`) build images / wait for GPUs / execute benchmarks. 
+ """ + command = command_template.format(candidate_path=str(candidate_path)) + py_with_torch = None + if evaluator_python: + if os.path.exists(evaluator_python) and _python_has_torch(evaluator_python): + py_with_torch = evaluator_python + if py_with_torch is None: + py_with_torch = _find_python_with_torch() + command = _rewrite_python_launcher(command, py_with_torch) + command = _normalize_evaluator_command(command) + + start = time.time() + last_hb = 0.0 + stdout_chunks: list[str] = [] + stderr_chunks: list[str] = [] + + proc = subprocess.Popen( + command, + shell=True, + cwd=str(ROOT_DIR), + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + bufsize=1, + ) + + assert proc.stdout is not None + assert proc.stderr is not None + + sel = selectors.DefaultSelector() + sel.register(proc.stdout, selectors.EVENT_READ, data="stdout") + sel.register(proc.stderr, selectors.EVENT_READ, data="stderr") + + def _tail(txt: str) -> str: + if tail_chars <= 0: + return "" + if len(txt) <= tail_chars: + return txt + return txt[-tail_chars:] + + while True: + now = time.time() + elapsed = now - start + if elapsed > timeout_s: + try: + proc.kill() + except Exception: + pass + stderr_chunks.append(f"\n[evaluator_timeout_s={timeout_s}]") + break + + # Drain any available output. 
+ events = sel.select(timeout=0.5) + for key, _mask in events: + stream_name = key.data + try: + line = key.fileobj.readline() + except Exception: + line = "" + if not line: + continue + if stream_name == "stdout": + stdout_chunks.append(line) + else: + stderr_chunks.append(line) + + if (now - last_hb) >= heartbeat_every_s: + last_hb = now + out = "".join(stdout_chunks) + err = "".join(stderr_chunks) + yield { + "elapsed_s": int(elapsed), + "stdout_tail": _tail(out), + "stderr_tail": _tail(err), + } + + rc = proc.poll() + if rc is not None: + # Drain remaining buffered output + for stream, name in ((proc.stdout, "stdout"), (proc.stderr, "stderr")): + try: + rest = stream.read() + except Exception: + rest = "" + if not rest: + continue + if name == "stdout": + stdout_chunks.append(rest) + else: + stderr_chunks.append(rest) + break + + stdout = "".join(stdout_chunks) + stderr = "".join(stderr_chunks) + score = _parse_score(stdout) + if score is None: + score = 0.0 + if proc.returncode not in (0, None): + if stderr: + stderr = stderr + f"\n[evaluator_exit_code={proc.returncode}]" + else: + stderr = f"[evaluator_exit_code={proc.returncode}]" + if py_with_torch is None: + extra = ( + "\n[no_python_with_torch_found] Set FERB_EVAL_PYTHON to a Python interpreter " + "that has torch installed." + ) + stderr = (stderr + extra) if stderr else extra.lstrip() + return (score, stdout, stderr) + + +def run_agentic_optimization( + *, + objective: str, + problem_id: int, + iterations: int = 3, + model: str = "gpt-4o-mini", + target_backend: str = "triton", + topology_json_path: str | None = None, + evaluator_command: str | None = None, + evaluator_timeout_s: int = 240, + evaluator_python: str | None = None, + include_full_code: bool = False, + include_trace_output: bool = True, + trace_text_limit: int = 0, +) -> dict[str, Any]: + """ + Iterative generate/evaluate/select loop. 
+ """ + if iterations < 1: + raise ValueError("iterations must be >= 1") + + topology_json = None + if topology_json_path: + topo_path = Path(topology_json_path) + if topo_path.exists(): + topology_json = topo_path.read_text(encoding="utf-8") + + run_id = uuid.uuid4().hex[:12] + run_dir = RUNS_DIR / run_id + run_dir.mkdir(parents=True, exist_ok=True) + + best_code: str | None = None + best_score: float | None = None + best_candidate_path: str | None = None + previous_feedback: str | None = None + previous_eval_feedback: str | None = None + trace: list[IterationResult] = [] + + for idx in range(1, iterations + 1): + quality_feedback = None + model_feedback = "" + code = "" + for _attempt in range(1, 4): + code, model_feedback = _generate_candidate( + model=model, + objective=objective, + problem_id=problem_id, + iteration_idx=idx, + target_backend=target_backend, + previous_best_code=best_code, + previous_feedback=previous_feedback, + previous_eval_feedback=previous_eval_feedback, + quality_feedback=quality_feedback, + topology_json=topology_json, + ) + issues = _candidate_quality_issues(code, target_backend) + if not issues: + break + quality_feedback = "Quality issues: " + "; ".join(issues) + + candidate_path = run_dir / f"candidate_iter_{idx}.py" + candidate_path.write_text(code, encoding="utf-8") + + score = 0.0 + eval_stdout = "" + eval_stderr = "" + if evaluator_command: + score, eval_stdout, eval_stderr = _run_evaluator( + evaluator_command, + candidate_path, + timeout_s=evaluator_timeout_s, + evaluator_python=evaluator_python, + ) + previous_eval_feedback = ( + "Evaluator stdout:\n" + + (eval_stdout or "") + + "\n\nEvaluator stderr:\n" + + (eval_stderr or "") + + f"\n\nScore: {score}" + ) + blocker = _is_infra_blocker(eval_stderr, target_backend) + if blocker: + # Don't keep iterating when the evaluator environment can't run the target backend. 
+ if best_code is None: + best_code = code + best_score = score + best_candidate_path = str(candidate_path) + previous_feedback = model_feedback + trace.append( + IterationResult( + iteration=idx, + candidate_path=str(candidate_path), + score=score, + evaluator_stdout=eval_stdout, + evaluator_stderr=(eval_stderr or "") + f"\n[blocked] {blocker}", + model_feedback=model_feedback, + ) + ) + break + else: + previous_eval_feedback = None + + choose_new_best = False + if best_code is None: + choose_new_best = True + elif best_score is None or score > best_score: + choose_new_best = True + + if choose_new_best: + best_code = code + best_score = score + best_candidate_path = str(candidate_path) + + previous_feedback = model_feedback + trace.append( + IterationResult( + iteration=idx, + candidate_path=str(candidate_path), + score=score, + evaluator_stdout=eval_stdout, + evaluator_stderr=eval_stderr, + model_feedback=model_feedback, + ) + ) + + result = { + "run_id": run_id, + "run_dir": str(run_dir), + "problem_id": problem_id, + "objective": objective, + "iterations": iterations, + "model": model, + "best_score": best_score, + "best_candidate_path": best_candidate_path, + "trace": [ + { + "iteration": t.iteration, + "candidate_path": t.candidate_path, + "score": t.score, + "model_feedback": _clip(t.model_feedback, trace_text_limit), + "evaluator_stdout": ( + _clip(t.evaluator_stdout, trace_text_limit) if include_trace_output else "" + ), + "evaluator_stderr": ( + _clip(t.evaluator_stderr, trace_text_limit) if include_trace_output else "" + ), + } + for t in trace + ], + } + if include_full_code: + result["best_code"] = best_code + return result + + +def stream_agentic_optimization_events( + *, + objective: str, + problem_id: int, + iterations: int = 3, + model: str = "gpt-4o-mini", + target_backend: str = "triton", + topology_json_path: str | None = None, + evaluator_command: str | None = None, + evaluator_timeout_s: int = 240, + evaluator_python: str | None = None, + 
feedback_preview_chars: int = 1200, +) -> Iterator[dict[str, Any]]: + """ + Stream per-iteration events for real-time visibility. + """ + if iterations < 1: + raise ValueError("iterations must be >= 1") + + topology_json = None + if topology_json_path: + topo_path = Path(topology_json_path) + if topo_path.exists(): + topology_json = topo_path.read_text(encoding="utf-8") + + run_id = uuid.uuid4().hex[:12] + run_dir = RUNS_DIR / run_id + run_dir.mkdir(parents=True, exist_ok=True) + + best_code: str | None = None + best_score: float | None = None + best_candidate_path: str | None = None + previous_feedback: str | None = None + previous_eval_feedback: str | None = None + trace: list[IterationResult] = [] + + yield { + "type": "run_started", + "run_id": run_id, + "run_dir": str(run_dir), + "objective": objective, + "problem_id": problem_id, + "iterations": iterations, + "model": model, + "target_backend": target_backend, + } + + for idx in range(1, iterations + 1): + yield {"type": "iteration_started", "iteration": idx} + + quality_feedback = None + code = "" + model_feedback = "" + for attempt in range(1, 4): + code, model_feedback = _generate_candidate( + model=model, + objective=objective, + problem_id=problem_id, + iteration_idx=idx, + target_backend=target_backend, + previous_best_code=best_code, + previous_feedback=previous_feedback, + previous_eval_feedback=previous_eval_feedback, + quality_feedback=quality_feedback, + topology_json=topology_json, + ) + yield { + "type": "agent_thought", + "iteration": idx, + "attempt": attempt, + "text": model_feedback, + } + issues = _candidate_quality_issues(code, target_backend) + if not issues: + break + quality_feedback = "Quality issues: " + "; ".join(issues) + yield { + "type": "quality_reject", + "iteration": idx, + "attempt": attempt, + "issues": issues, + } + + candidate_path = run_dir / f"candidate_iter_{idx}.py" + candidate_path.write_text(code, encoding="utf-8") + yield { + "type": "candidate_generated", + 
"iteration": idx, + "candidate_path": str(candidate_path), + "feedback_preview": _clip(model_feedback, feedback_preview_chars), + "candidate_code_preview": _clip(code, 20000), + } + + score = 0.0 + eval_stdout = "" + eval_stderr = "" + if evaluator_command: + yield { + "type": "evaluation_started", + "iteration": idx, + "command": evaluator_command, + } + timeout_s = evaluator_timeout_s + if "modal run" in (evaluator_command or "") and timeout_s < 1800: + # Modal runs often include image build/pull and GPU scheduling latency. + timeout_s = 1800 + + live = _run_evaluator_live( + evaluator_command, + candidate_path, + timeout_s=timeout_s, + evaluator_python=evaluator_python, + heartbeat_every_s=2.0, + tail_chars=1500, + ) + while True: + try: + hb = next(live) + yield { + "type": "evaluation_heartbeat", + "iteration": idx, + **hb, + } + except StopIteration as stop: + score, eval_stdout, eval_stderr = stop.value + break + previous_eval_feedback = ( + "Evaluator stdout:\n" + + (eval_stdout or "") + + "\n\nEvaluator stderr:\n" + + (eval_stderr or "") + + f"\n\nScore: {score}" + ) + yield { + "type": "evaluation_completed", + "iteration": idx, + "score": score, + "stdout_preview": _clip(eval_stdout, 0), + "stderr_preview": _clip(eval_stderr, 0), + } + + blocker = _is_infra_blocker(eval_stderr, target_backend) + if blocker: + if best_code is None: + best_code = code + best_score = score + best_candidate_path = str(candidate_path) + yield { + "type": "blocked", + "iteration": idx, + "reason": blocker, + "hint": ( + "If this is a Modal auth issue, run `modal token new` locally (same environment running the API), " + "then rerun. For Triton/NVSHMEM acceleration you must evaluate on a CUDA multi-GPU environment " + "(e.g. Modal H100x8) with triton + nvshmem installed." 
+ ), + } + break + else: + previous_eval_feedback = None + + choose_new_best = False + if best_code is None: + choose_new_best = True + elif best_score is None or score > best_score: + choose_new_best = True + + if choose_new_best: + best_code = code + best_score = score + best_candidate_path = str(candidate_path) + yield { + "type": "best_updated", + "iteration": idx, + "best_score": best_score, + "best_candidate_path": best_candidate_path, + } + else: + yield { + "type": "best_unchanged", + "iteration": idx, + "best_score": best_score, + "best_candidate_path": best_candidate_path, + } + + previous_feedback = model_feedback + trace.append( + IterationResult( + iteration=idx, + candidate_path=str(candidate_path), + score=score, + evaluator_stdout=eval_stdout, + evaluator_stderr=eval_stderr, + model_feedback=model_feedback, + ) + ) + + result = { + "run_id": run_id, + "run_dir": str(run_dir), + "problem_id": problem_id, + "objective": objective, + "iterations": iterations, + "model": model, + "best_score": best_score, + "best_candidate_path": best_candidate_path, + "trace": [ + { + "iteration": t.iteration, + "candidate_path": t.candidate_path, + "score": t.score, + "model_feedback": _clip(t.model_feedback, 0), + } + for t in trace + ], + } + yield {"type": "run_completed", "result": result} diff --git a/api/main.py b/api/main.py new file mode 100644 index 0000000..6140b71 --- /dev/null +++ b/api/main.py @@ -0,0 +1,173 @@ +""" +FERB API: +- Chat endpoint for kernel optimization guidance +- Agentic optimization endpoint for iterative GPT-driven solution improvement +""" + +import os +import json +from contextlib import asynccontextmanager + +from pathlib import Path + +from fastapi import FastAPI +from fastapi import HTTPException +from fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import StreamingResponse +from fastapi.staticfiles import StaticFiles +from pydantic import BaseModel, Field + +from agentic import gpt_chat_reply +from agentic 
import run_agentic_optimization +from agentic import stream_agentic_optimization_events + + +# --------------------------------------------------------------------------- +# API +# --------------------------------------------------------------------------- + +class ChatRequest(BaseModel): + message: str + model: str = "gpt-4o-mini" + + +class ChatResponse(BaseModel): + reply: str + ok: bool = True + + +class OptimizeRequest(BaseModel): + objective: str = Field(..., min_length=8) + problem_id: int = Field(..., ge=1) + iterations: int = Field(default=3, ge=1, le=10) + model: str = "gpt-4o-mini" + target_backend: str = "triton" + topology_json_path: str | None = None + evaluator_command: str | None = None + evaluator_timeout_s: int = Field(default=240, ge=10, le=3600) + evaluator_python: str | None = None + include_full_code: bool = False + include_trace_output: bool = False + + +class OptimizeResponse(BaseModel): + ok: bool = True + result: dict + + +@asynccontextmanager +async def lifespan(app: FastAPI): + yield + + +app = FastAPI(title="FERB Kernel Optimization Chat", lifespan=lifespan) + +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + + +@app.get("/health") +def health() -> dict[str, str]: + return {"status": "ok"} + + +@app.post("/chat", response_model=ChatResponse) +def chat(body: ChatRequest) -> ChatResponse: + # GPT-backed chat; if no API key, return setup guidance. + if not os.environ.get("OPENAI_API_KEY", "").strip(): + return ChatResponse( + reply=( + "OPENAI_API_KEY is not set. Export it first, then retry. 
" + "Example: `export OPENAI_API_KEY=...`" + ), + ok=False, + ) + try: + reply = gpt_chat_reply(body.message, model=body.model) + return ChatResponse(reply=reply, ok=True) + except Exception as exc: + raise HTTPException(status_code=500, detail=f"chat failed: {exc}") from exc + + +@app.post("/optimize", response_model=OptimizeResponse) +def optimize(body: OptimizeRequest) -> OptimizeResponse: + """ + Agentic optimization loop: + - Generate candidate code with GPT + - Optionally evaluate candidate via command + - Keep best candidate + """ + if not os.environ.get("OPENAI_API_KEY", "").strip(): + raise HTTPException( + status_code=400, + detail="OPENAI_API_KEY is required for /optimize", + ) + try: + result = run_agentic_optimization( + objective=body.objective, + problem_id=body.problem_id, + iterations=body.iterations, + model=body.model, + target_backend=body.target_backend, + topology_json_path=body.topology_json_path, + evaluator_command=body.evaluator_command, + evaluator_timeout_s=body.evaluator_timeout_s, + evaluator_python=body.evaluator_python, + include_full_code=body.include_full_code, + include_trace_output=body.include_trace_output, + ) + return OptimizeResponse(ok=True, result=result) + except Exception as exc: + raise HTTPException(status_code=500, detail=f"optimization failed: {exc}") from exc + + +@app.post("/optimize/stream") +def optimize_stream(body: OptimizeRequest) -> StreamingResponse: + """ + Stream live agent iteration events via Server-Sent Events (SSE). 
+ """ + if not os.environ.get("OPENAI_API_KEY", "").strip(): + raise HTTPException( + status_code=400, + detail="OPENAI_API_KEY is required for /optimize/stream", + ) + + def _event_stream(): + try: + for event in stream_agentic_optimization_events( + objective=body.objective, + problem_id=body.problem_id, + iterations=body.iterations, + model=body.model, + target_backend=body.target_backend, + topology_json_path=body.topology_json_path, + evaluator_command=body.evaluator_command, + evaluator_timeout_s=body.evaluator_timeout_s, + evaluator_python=body.evaluator_python, + ): + yield f"data: {json.dumps(event)}\n\n" + except Exception as exc: + err_event = {"type": "error", "detail": str(exc)} + yield f"data: {json.dumps(err_event)}\n\n" + + return StreamingResponse(_event_stream(), media_type="text/event-stream") + + +# Serve frontend (whitespace chatbot) when running from repo root +_frontend = Path(__file__).resolve().parent.parent / "frontend" +if _frontend.is_dir(): + app.mount("/", StaticFiles(directory=str(_frontend), html=True), name="frontend") + + +# --------------------------------------------------------------------------- +# Dev server (optional) +# --------------------------------------------------------------------------- + +if __name__ == "__main__": + import uvicorn + uvicorn.run(app, host="0.0.0.0", port=8000) diff --git a/api/requirements.txt b/api/requirements.txt new file mode 100644 index 0000000..a1bc3af --- /dev/null +++ b/api/requirements.txt @@ -0,0 +1,5 @@ +fastapi>=0.115.0 +uvicorn[standard]>=0.32.0 +pydantic>=2.0.0 +openai>=1.55.0 +modal diff --git a/frontend/README.md b/frontend/README.md new file mode 100644 index 0000000..23475b1 --- /dev/null +++ b/frontend/README.md @@ -0,0 +1,30 @@ +# FERB Chat Frontend + +Whitespace-style AI chatbot for kernel optimization and faster ML training. 
+ +## Run everything (API + frontend) + +From the **project root** (FERB): + +```bash +cd api && pip install -r requirements.txt && uvicorn main:app --reload --host 0.0.0.0 --port 8000 +``` + +Then open **http://127.0.0.1:8000** in your browser. The same server serves the UI and the `/chat` API. + +## Run frontend only (e.g. static server) + +If the API runs elsewhere, open `index.html` or serve `frontend/` with any static server and set the API base: + +```html + + +``` + +Or run from `frontend/`: + +```bash +npx serve -p 3000 +``` + +Then set `FERB_API_BASE` to your API URL (e.g. `http://127.0.0.1:8000`) if the API is on a different port. diff --git a/frontend/app.js b/frontend/app.js new file mode 100644 index 0000000..a32f411 --- /dev/null +++ b/frontend/app.js @@ -0,0 +1,386 @@ +(function () { + const API_BASE = window.FERB_API_BASE || ""; + + const form = document.getElementById("chatForm"); + const input = document.getElementById("input"); + const submitBtn = document.getElementById("submit"); + const messagesEl = document.getElementById("messages"); + const codePanel = document.getElementById("codePanel"); + const ideMeta = document.getElementById("ideMeta"); + let runningOptimize = false; + let typingRunId = 0; + + function escapeHtml(text) { + const div = document.createElement("div"); + div.textContent = text; + return div.innerHTML; + } + + function nl2br(text) { + return escapeHtml(text).replace(/\n/g, "
"); + } + + function renderRichText(text) { + // Tiny safe markdown subset: bold, italic, inline code + newlines. + let html = escapeHtml(text || ""); + html = html.replace(/`([^`]+)`/g, "$1"); + html = html.replace(/\*\*([^*]+)\*\*/g, "$1"); + html = html.replace(/\*([^*]+)\*/g, "$1"); + html = html.replace(/\n/g, "
"); + return html; + } + + function addMessage(role, content, options = {}) { + const div = document.createElement("div"); + div.className = "message message--" + role + (options.loading ? " message--loading" : ""); + div.setAttribute("data-role", role); + + const bubble = document.createElement("div"); + bubble.className = "message__bubble"; + if (options.markdown) { + bubble.innerHTML = renderRichText(content); + } else { + bubble.innerHTML = "

" + renderRichText(content) + "

"; + } + div.appendChild(bubble); + messagesEl.appendChild(div); + messagesEl.scrollTop = messagesEl.scrollHeight; + return div; + } + + function updateIDE(metaText, codeText) { + if (ideMeta && typeof metaText === "string") ideMeta.textContent = metaText; + if (codePanel && typeof codeText === "string") codePanel.textContent = codeText; + } + + function isNearBottom(el, thresholdPx = 48) { + if (!el) return true; + return el.scrollTop + el.clientHeight >= el.scrollHeight - thresholdPx; + } + + function typeIntoCodePanel(fullText, opts = {}) { + if (!codePanel) return; + const text = String(fullText || ""); + const chunkSize = Math.max(1, Number(opts.chunkSize || 12)); + const tickMs = Math.max(8, Number(opts.tickMs || 16)); + + const myRun = ++typingRunId; + codePanel.classList.add("typing"); + codePanel.textContent = ""; + + let i = 0; + const timer = setInterval(() => { + if (myRun !== typingRunId) { + clearInterval(timer); + return; + } + if (i >= text.length) { + clearInterval(timer); + codePanel.classList.remove("typing"); + return; + } + + const keepScroll = isNearBottom(codePanel); + codePanel.textContent += text.slice(i, i + chunkSize); + i += chunkSize; + if (keepScroll) codePanel.scrollTop = codePanel.scrollHeight; + }, tickMs); + } + + function setMessageContent(el, content) { + const bubble = el.querySelector(".message__bubble"); + if (!bubble) return; + el.classList.remove("message--loading"); + bubble.innerHTML = "

" + renderRichText(content) + "

"; + messagesEl.scrollTop = messagesEl.scrollHeight; + } + + async function sendMessage(text) { + if (!text.trim()) return; + + addMessage("user", text.trim()); + input.value = ""; + input.style.height = "auto"; + + const trimmed = text.trim(); + if (trimmed.toLowerCase().startsWith("/optimize ")) { + const parsed = parseOptimizeCommand(trimmed); + await runOptimizeStream(parsed.objective, parsed.options); + return; + } + + const loadingEl = addMessage("assistant", "", { loading: true }); + submitBtn.disabled = true; + + try { + const res = await fetch(API_BASE + "/chat", { + method: "POST", + headers: { "Content-Type": "application/json" }, + body: JSON.stringify({ message: text.trim() }), + }); + const data = await res.json(); + if (data.reply) { + setMessageContent(loadingEl, data.reply); + } else { + setMessageContent(loadingEl, "Sorry, something went wrong. " + (data.detail || "")); + } + } catch (err) { + setMessageContent( + loadingEl, + "Could not reach the API. Is it running? Start with: cd api && uvicorn main:app --reload" + ); + } finally { + submitBtn.disabled = false; + input.focus(); + } + } + + function tokenizeCommand(cmd) { + const re = /"[^"]*"|\S+/g; + const out = []; + const m = cmd.match(re) || []; + for (const t of m) { + if (t.startsWith("\"") && t.endsWith("\"") && t.length >= 2) out.push(t.slice(1, -1)); + else out.push(t); + } + return out; + } + + function parseOptimizeCommand(cmd) { + // Supported: + // /optimize [--backend triton|reference] [--problem N] [--iters N] [--nproc N] [--model NAME] [--no-eval] + const tokens = tokenizeCommand(cmd); + const options = { + problem_id: 1, + iterations: 3, + model: "gpt-4o-mini", + target_backend: "triton", + nproc_per_node: 1, + no_eval: false, + }; + + let i = 1; // skip /optimize + while (i < tokens.length) { + const tok = tokens[i]; + if (!tok.startsWith("--")) break; + const key = tok.slice(2).toLowerCase(); + if (key === "no-eval") { + options.no_eval = true; + i += 1; + continue; + } + 
const val = tokens[i + 1]; + if (val == null) break; + if (key === "backend") options.target_backend = String(val); + else if (key === "problem") options.problem_id = Number(val); + else if (key === "iters") options.iterations = Number(val); + else if (key === "nproc") options.nproc_per_node = Number(val); + else if (key === "model") options.model = String(val); + i += 2; + } + + const objective = tokens.slice(i).join(" ").trim(); + return { objective, options }; + } + + async function runOptimizeStream(objective, options = {}) { + if (!objective) { + addMessage( + "assistant", + "Usage: /optimize [--backend triton|reference] [--problem N] [--iters N] [--nproc N] [--model NAME] [--no-eval] " + ); + return; + } + if (runningOptimize) { + addMessage("assistant", "An optimize run is already in progress."); + return; + } + + runningOptimize = true; + submitBtn.disabled = true; + const statusEl = addMessage("assistant", "Starting agentic optimization run...", { loading: true }); + + try { + const target_backend = options.target_backend || "triton"; + const problem_id = Number.isFinite(options.problem_id) ? options.problem_id : 1; + const iterations = Number.isFinite(options.iterations) ? options.iterations : 3; + const model = options.model || "gpt-4o-mini"; + const nproc = Number.isFinite(options.nproc_per_node) ? options.nproc_per_node : 1; + const noEval = !!options.no_eval; + + const useModalEval = target_backend === "triton"; + + const res = await fetch(API_BASE + "/optimize/stream", { + method: "POST", + headers: { "Content-Type": "application/json" }, + body: JSON.stringify({ + objective, + problem_id, + iterations, + model, + target_backend, + evaluator_python: window.FERB_EVAL_PYTHON || null, + evaluator_command: noEval + ? null + : useModalEval + ? 
`modal run scripts/modal_benchmark.py --problem ${problem_id} --candidate {candidate_path} --rows 1024 --cols 1024 --dtype float32 --warmup 3 --iters 10` + : `python -m torch.distributed.run --nproc-per-node ${nproc} scripts/benchmark_candidate.py --problem ${problem_id} --candidate {candidate_path} --rows 1024 --cols 1024 --dtype float32 --warmup 3 --iters 10`, + }), + }); + + if (!res.ok || !res.body) { + const txt = await res.text(); + setMessageContent(statusEl, "Optimize stream failed: " + txt); + return; + } + + setMessageContent( + statusEl, + `Agent run started. Streaming steps below.\n\nbackend=${target_backend} problem=${problem_id} iters=${iterations} eval=${noEval ? "disabled" : (useModalEval ? "modal H100x8" : `local nproc=${nproc}`)}` + ); + + const reader = res.body.getReader(); + const decoder = new TextDecoder(); + let buffer = ""; + + while (true) { + const { value, done } = await reader.read(); + if (done) break; + buffer += decoder.decode(value, { stream: true }); + + const lines = buffer.split("\n"); + buffer = lines.pop() || ""; + + for (const line of lines) { + if (!line.startsWith("data: ")) continue; + const raw = line.slice(6); + if (!raw) continue; + let event; + try { + event = JSON.parse(raw); + } catch { + continue; + } + renderOptimizeEvent(event); + } + } + } catch (err) { + addMessage( + "assistant", + "Could not stream optimization. Ensure API is running and OPENAI_API_KEY is set." 
+ ); + } finally { + runningOptimize = false; + submitBtn.disabled = false; + input.focus(); + } + } + + function renderOptimizeEvent(event) { + const t = event && event.type; + if (!t) return; + if (t === "run_started") { + addMessage("assistant", `Run ${event.run_id} started.\nModel: ${event.model}\nIterations: ${event.iterations}`); + return; + } + if (t === "iteration_started") { + addMessage("assistant", `Let's try iteration ${event.iteration}...`); + return; + } + if (t === "candidate_generated") { + updateIDE( + `iter ${event.iteration} · ${event.candidate_path || ""}`, + "" + ); + typeIntoCodePanel(event.candidate_code_preview || "", { chunkSize: 14, tickMs: 14 }); + addMessage( + "assistant", + `Iteration ${event.iteration}: candidate generated.\nPath: ${event.candidate_path}\n\n(Full code is in the Agent IDE pane.)\n\nAgent critique:\n${event.feedback_preview || ""}` + ); + return; + } + if (t === "agent_thought") { + addMessage("assistant", `Iteration ${event.iteration} attempt ${event.attempt} plan:\n${event.text || ""}`); + return; + } + if (t === "quality_reject") { + addMessage( + "assistant", + `Iteration ${event.iteration} attempt ${event.attempt} rejected by quality gate:\n- ${(event.issues || []).join("\n- ")}` + ); + return; + } + if (t === "evaluation_started") { + addMessage("assistant", `Iteration ${event.iteration}: running evaluator...`); + updateIDE(`iter ${event.iteration} · evaluating…`, codePanel ? codePanel.textContent : ""); + return; + } + if (t === "evaluation_heartbeat") { + const elapsed = Number.isFinite(event.elapsed_s) ? 
event.elapsed_s : 0; + if (ideMeta) ideMeta.textContent = `iter ${event.iteration} · evaluating… ${elapsed}s`; + return; + } + if (t === "evaluation_completed") { + let metricsLine = ""; + try { + const parsed = JSON.parse(event.stdout_preview || "{}"); + if (parsed && typeof parsed === "object" && "speedup" in parsed) { + metricsLine = + `\nallclose=${parsed.allclose} max_abs_diff=${parsed.max_abs_diff}` + + `\nreference_ms=${parsed.reference_ms} candidate_ms=${parsed.candidate_ms}` + + `\nspeedup=${parsed.speedup} score=${parsed.score}`; + } + } catch (_) { + // stdout may not be JSON; ignore + } + addMessage( + "assistant", + `Iteration ${event.iteration}: score=${event.score}${metricsLine}\nstdout:\n${event.stdout_preview || ""}\nstderr:\n${event.stderr_preview || ""}` + ); + return; + } + if (t === "blocked") { + addMessage( + "assistant", + `Blocked at iteration ${event.iteration}.\nReason: ${event.reason || ""}\n\nHint: ${event.hint || ""}` + ); + return; + } + if (t === "best_updated") { + addMessage("assistant", `New best at iteration ${event.iteration}. score=${event.best_score}`); + return; + } + if (t === "best_unchanged") { + addMessage("assistant", `No improvement at iteration ${event.iteration}. 
Current best score=${event.best_score}`); + return; + } + if (t === "run_completed") { + const r = event.result || {}; + addMessage( + "assistant", + `Run complete.\nBest score: ${r.best_score}\nBest candidate: ${r.best_candidate_path}\nRun dir: ${r.run_dir}` + ); + return; + } + if (t === "error") { + addMessage("assistant", `Run error: ${event.detail || "Unknown error"}`); + } + } + + form.addEventListener("submit", function (e) { + e.preventDefault(); + sendMessage(input.value); + }); + + input.addEventListener("input", function () { + input.style.height = "auto"; + input.style.height = Math.min(input.scrollHeight, 12 * 24) + "px"; + }); + + input.addEventListener("keydown", function (e) { + if (e.key === "Enter" && !e.shiftKey) { + e.preventDefault(); + form.requestSubmit(); + } + }); +})(); diff --git a/frontend/index.html b/frontend/index.html new file mode 100644 index 0000000..72f60f3 --- /dev/null +++ b/frontend/index.html @@ -0,0 +1,62 @@ + + + + + + FERB — Kernel Optimization + + + + + + +
+
+ + Kernel optimization & faster ML training +
+ +
+
+
+
+
+
+

Ask about problems, NVSHMEM, Modal runs, MoE, or how to train models faster.

+

Run live agent loop: /optimize Improve throughput for problem 1 while preserving correctness

+
+
+
+ +
+ + +
+
+ + +
+
+ +
+ Whitespace AI · FERB +
+
+ + + diff --git a/frontend/style.css b/frontend/style.css new file mode 100644 index 0000000..59390b4 --- /dev/null +++ b/frontend/style.css @@ -0,0 +1,320 @@ +/* --- Variables --- */ +:root { + --bg: #fafaf9; + --bg-elevated: #ffffff; + --text: #1c1917; + --text-muted: #78716c; + --border: #e7e5e4; + --accent: #0f766e; + --accent-hover: #0d9488; + --radius: 12px; + --font-sans: "DM Sans", system-ui, sans-serif; + --font-mono: "JetBrains Mono", ui-monospace, monospace; + --space-xs: 0.25rem; + --space-sm: 0.5rem; + --space-md: 1rem; + --space-lg: 1.5rem; + --space-xl: 2rem; + --space-2xl: 3rem; + --space-3xl: 4rem; + --space-4xl: 6rem; + --space-5xl: 8rem; + --max-width: 42rem; + --input-min-height: 52px; +} + +/* --- Reset & base --- */ +*, *::before, *::after { + box-sizing: border-box; +} + +html { + font-size: 18px; + -webkit-font-smoothing: antialiased; +} + +body { + margin: 0; + min-height: 100vh; + font-family: var(--font-sans); + font-weight: 400; + color: var(--text); + background: var(--bg); + line-height: 1.6; +} + +/* --- Layout --- */ +.app { + min-height: 100vh; + display: flex; + flex-direction: column; + max-width: 100vw; +} + +.header { + flex-shrink: 0; + padding: var(--space-3xl) var(--space-xl) var(--space-2xl); + text-align: center; + border-bottom: 1px solid var(--border); +} + +.logo { + display: block; + font-size: 1rem; + font-weight: 500; + letter-spacing: 0.12em; + text-transform: uppercase; + color: var(--text-muted); + margin-bottom: var(--space-sm); +} + +.tagline { + font-size: 0.95rem; + font-weight: 300; + color: var(--text-muted); +} + +.main { + flex: 1; + display: flex; + flex-direction: column; + width: 100%; + max-width: 80rem; + margin: 0 auto; + padding: var(--space-4xl) var(--space-xl) var(--space-2xl); + min-height: 0; +} + +/* --- Workspace (chat + IDE) --- */ +.workspace { + flex: 1; + min-height: 0; + display: grid; + grid-template-columns: minmax(0, 1fr) minmax(0, 1fr); + gap: var(--space-xl); +} + +.chat { + 
min-height: 0; + display: flex; + flex-direction: column; +} + +.ide { + min-height: 0; + background: #0b0f14; + border: 1px solid #111827; + border-radius: var(--radius); + display: flex; + flex-direction: column; + overflow: hidden; +} + +.ide__header { + flex-shrink: 0; + padding: var(--space-md) var(--space-lg); + border-bottom: 1px solid #111827; + display: flex; + align-items: baseline; + justify-content: space-between; + gap: var(--space-md); +} + +.ide__title { + font-size: 0.9rem; + font-weight: 500; + letter-spacing: 0.06em; + text-transform: uppercase; + color: #94a3b8; +} + +.ide__meta { + font-size: 0.85rem; + color: #64748b; + font-family: var(--font-mono); + white-space: nowrap; + overflow: hidden; + text-overflow: ellipsis; + max-width: 70%; +} + +.code-panel { + flex: 1; + margin: 0; + padding: var(--space-lg); + overflow: auto; + font-family: var(--font-mono); + font-size: 0.84rem; + line-height: 1.55; + background: #0b0f14; + color: #e5e7eb; + white-space: pre; +} + +.code-panel.typing::after { + content: ""; + display: inline-block; + width: 8px; + height: 1em; + background: #e5e7eb; + opacity: 0.85; + margin-left: 3px; + vertical-align: -0.15em; + animation: blink 0.8s ease-in-out infinite; +} + +/* --- Messages --- */ +.messages { + flex: 1; + display: flex; + flex-direction: column; + gap: var(--space-3xl); + padding-bottom: var(--space-2xl); + overflow-y: auto; +} + +.message { + display: flex; + width: 100%; +} + +.message--user { + justify-content: flex-end; +} + +.message--user .message__bubble { + background: var(--text); + color: var(--bg); + border-radius: var(--radius) var(--radius) var(--radius) 2px; +} + +.message--assistant .message__bubble { + background: var(--bg-elevated); + color: var(--text); + border: 1px solid var(--border); + border-radius: var(--radius) var(--radius) 2px var(--radius); +} + +.message__bubble { + max-width: 85%; + padding: var(--space-lg) var(--space-xl); + font-size: 0.95rem; + white-space: pre-wrap; + 
word-break: break-word; +} + +.message__bubble p { + margin: 0 0 var(--space-sm); +} + +.message__bubble p:last-child { + margin-bottom: 0; +} + +.message--loading .message__bubble::after { + content: ""; + display: inline-block; + width: 4px; + height: 1em; + background: var(--text-muted); + animation: blink 0.8s ease-in-out infinite; + margin-left: 2px; + vertical-align: -0.2em; +} + +@keyframes blink { + 0%, 100% { opacity: 0.3; } + 50% { opacity: 1; } +} + +/* --- Input --- */ +.input-wrap { + flex-shrink: 0; + display: flex; + align-items: flex-end; + gap: var(--space-md); + padding: var(--space-md) 0; + border-top: 1px solid var(--border); +} + +.input { + flex: 1; + min-height: var(--input-min-height); + max-height: 12rem; + padding: var(--space-md) var(--space-lg); + font-family: var(--font-sans); + font-size: 0.95rem; + color: var(--text); + background: var(--bg-elevated); + border: 1px solid var(--border); + border-radius: var(--radius); + resize: none; + transition: border-color 0.15s ease, box-shadow 0.15s ease; +} + +.input::placeholder { + color: var(--text-muted); +} + +.input:focus { + outline: none; + border-color: var(--accent); + box-shadow: 0 0 0 2px rgba(15, 118, 110, 0.12); +} + +.submit { + flex-shrink: 0; + width: var(--input-min-height); + height: var(--input-min-height); + display: flex; + align-items: center; + justify-content: center; + font-size: 1.25rem; + color: var(--bg); + background: var(--accent); + border: none; + border-radius: var(--radius); + cursor: pointer; + transition: background 0.15s ease, transform 0.1s ease; +} + +.submit:hover { + background: var(--accent-hover); +} + +.submit:active { + transform: scale(0.98); +} + +.submit:disabled { + opacity: 0.5; + cursor: not-allowed; + transform: none; +} + +.submit__arrow { + line-height: 1; +} + +/* --- Footer --- */ +.footer { + flex-shrink: 0; + padding: var(--space-xl); + text-align: center; + border-top: 1px solid var(--border); +} + +.footer__text { + font-size: 0.8rem; 
+ color: var(--text-muted); + font-weight: 300; +} + +@media (max-width: 900px) { + .workspace { + grid-template-columns: 1fr; + } + .ide__meta { + max-width: 60%; + } +} diff --git a/reference/__pycache__/1.cpython-310.pyc b/reference/__pycache__/1.cpython-310.pyc deleted file mode 100644 index 78d0e89..0000000 Binary files a/reference/__pycache__/1.cpython-310.pyc and /dev/null differ diff --git a/scripts/benchmark_candidate.py b/scripts/benchmark_candidate.py new file mode 100644 index 0000000..184bf8f --- /dev/null +++ b/scripts/benchmark_candidate.py @@ -0,0 +1,261 @@ +#!/usr/bin/env python3 +""" +Benchmark reference vs candidate solution under torchrun. + +Outputs: +- JSON summary (default) +- `score=` for evaluator mode (--score-only) +""" + +from __future__ import annotations + +import argparse +import importlib.util +import json +import os +import statistics +import sys +import time +from pathlib import Path +from typing import Any + +# Ensure repo root is importable when launched by torch.distributed.run. 
+REPO_ROOT = Path(__file__).resolve().parent.parent +if str(REPO_ROOT) not in sys.path: + sys.path.insert(0, str(REPO_ROOT)) + +import torch +import torch.distributed as dist + +from utils.init_and_finalize_backends import finalize_reference +from utils.init_and_finalize_backends import init_reference +from utils.input_output_tensors import create_input_tensor + + +def _dtype_from_string(name: str) -> torch.dtype: + lookup = { + "float16": torch.float16, + "float32": torch.float32, + "float64": torch.float64, + "bfloat16": torch.bfloat16, + "int32": torch.int32, + "int64": torch.int64, + } + key = (name or "").strip().lower() + if key not in lookup: + raise ValueError(f"Unsupported dtype: {name}") + return lookup[key] + + +def _load_solution(path_str: str): + path = Path(path_str) + if not path.exists(): + raise FileNotFoundError(f"Solution file not found: {path_str}") + mod_name = f"ferb_bench_{path.stem}_{os.getpid()}" + spec = importlib.util.spec_from_file_location(mod_name, str(path)) + if spec is None or spec.loader is None: + raise RuntimeError(f"Unable to import: {path_str}") + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + if not hasattr(module, "solution"): + raise AttributeError(f"{path_str} has no solution(...)") + return getattr(module, "solution") + + +def _clone_obj(obj: Any) -> Any: + if isinstance(obj, torch.Tensor): + return obj.clone() + if isinstance(obj, tuple): + return tuple(_clone_obj(x) for x in obj) + if isinstance(obj, list): + return [_clone_obj(x) for x in obj] + if isinstance(obj, dict): + return {k: _clone_obj(v) for k, v in obj.items()} + return obj + + +def _compare_outputs(a: Any, b: Any, atol: float, rtol: float) -> tuple[bool, float]: + # Returns (allclose, max_abs_diff) + if isinstance(a, torch.Tensor) and isinstance(b, torch.Tensor): + ac = torch.allclose(a, b, atol=atol, rtol=rtol) + max_diff = float((a - b).abs().max().item()) if a.numel() > 0 else 0.0 + return ac, max_diff + if 
isinstance(a, tuple) and isinstance(b, tuple) and len(a) == len(b): + oks = [] + diffs = [] + for x, y in zip(a, b): + ok, d = _compare_outputs(x, y, atol, rtol) + oks.append(ok) + diffs.append(d) + return all(oks), (max(diffs) if diffs else 0.0) + if isinstance(a, list) and isinstance(b, list) and len(a) == len(b): + oks = [] + diffs = [] + for x, y in zip(a, b): + ok, d = _compare_outputs(x, y, atol, rtol) + oks.append(ok) + diffs.append(d) + return all(oks), (max(diffs) if diffs else 0.0) + if isinstance(a, dict) and isinstance(b, dict) and set(a.keys()) == set(b.keys()): + oks = [] + diffs = [] + for k in sorted(a.keys()): + ok, d = _compare_outputs(a[k], b[k], atol, rtol) + oks.append(ok) + diffs.append(d) + return all(oks), (max(diffs) if diffs else 0.0) + return False, float("inf") + + +def _tensor_sample(obj: Any, n: int = 8): + if isinstance(obj, torch.Tensor): + flat = obj.detach().flatten() + k = min(n, flat.numel()) + return flat[:k].tolist() + if isinstance(obj, tuple): + return [_tensor_sample(x, n) for x in obj] + if isinstance(obj, list): + return [_tensor_sample(x, n) for x in obj] + if isinstance(obj, dict): + return {k: _tensor_sample(v, n) for k, v in obj.items()} + return str(type(obj)) + + +def _run_timed(fn, inputs: tuple[Any, ...], warmup: int, iters: int) -> list[float]: + for _ in range(warmup): + _ = fn(*_clone_obj(inputs)) + if torch.cuda.is_available(): + torch.cuda.synchronize() + if dist.is_initialized(): + dist.barrier() + + times_ms: list[float] = [] + for _ in range(iters): + if dist.is_initialized(): + dist.barrier() + if torch.cuda.is_available(): + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + start.record() + _ = fn(*_clone_obj(inputs)) + end.record() + torch.cuda.synchronize() + times_ms.append(float(start.elapsed_time(end))) + else: + t0 = time.perf_counter() + _ = fn(*_clone_obj(inputs)) + t1 = time.perf_counter() + times_ms.append((t1 - t0) * 1000.0) + if dist.is_initialized(): 
+ dist.barrier() + return times_ms + + +def _global_max_time(local_ms: float, device: torch.device) -> float: + t = torch.tensor([local_ms], device=device, dtype=torch.float32) + dist.all_reduce(t, op=dist.ReduceOp.MAX) + return float(t.item()) + + +def main() -> int: + parser = argparse.ArgumentParser(description="Benchmark FERB candidate speedup") + parser.add_argument("--problem", type=int, required=True) + parser.add_argument("--candidate", type=str, required=True) + parser.add_argument("--rows", type=int, default=1024) + parser.add_argument("--cols", type=int, default=1024) + parser.add_argument("--dtype", type=str, default="float32") + parser.add_argument("--warmup", type=int, default=3) + parser.add_argument("--iters", type=int, default=10) + parser.add_argument("--atol", type=float, default=1e-4) + parser.add_argument("--rtol", type=float, default=1e-4) + parser.add_argument("--score-only", action="store_true") + args = parser.parse_args() + + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + + init_reference(rank, world_size) + try: + device = torch.device("cuda", torch.cuda.current_device()) if torch.cuda.is_available() else torch.device("cpu") + dtype = _dtype_from_string(args.dtype) + base_shape = (args.rows, args.cols) + inputs = create_input_tensor( + rank=rank, + world_size=world_size, + problem_id=args.problem, + base_shape=base_shape, + dtype=dtype, + device=device, + ) + + ref_path = Path(__file__).resolve().parent.parent / "reference" / f"{args.problem}.py" + ref_solution = _load_solution(str(ref_path)) + cand_solution = _load_solution(args.candidate) + + # Correctness pass + with torch.no_grad(): + ref_out = ref_solution(*_clone_obj(inputs)) + cand_out = cand_solution(*_clone_obj(inputs)) + ok_local, max_diff_local = _compare_outputs(ref_out, cand_out, atol=args.atol, rtol=args.rtol) + ok_tensor = torch.tensor([1 if ok_local else 0], device=device, dtype=torch.int32) + dist.all_reduce(ok_tensor, 
op=dist.ReduceOp.MIN) + all_ok = bool(ok_tensor.item() == 1) + + diff_tensor = torch.tensor([max_diff_local], device=device, dtype=torch.float32) + dist.all_reduce(diff_tensor, op=dist.ReduceOp.MAX) + max_abs_diff = float(diff_tensor.item()) + + # Timing + ref_times = _run_timed(ref_solution, inputs, warmup=args.warmup, iters=args.iters) + cand_times = _run_timed(cand_solution, inputs, warmup=args.warmup, iters=args.iters) + ref_local_mean = statistics.mean(ref_times) + cand_local_mean = statistics.mean(cand_times) + ref_global_ms = _global_max_time(ref_local_mean, device) + cand_global_ms = _global_max_time(cand_local_mean, device) + + speedup = (ref_global_ms / cand_global_ms) if cand_global_ms > 0 else 0.0 + score = speedup if all_ok else 0.0 + + if rank == 0: + payload = { + "problem": args.problem, + "candidate": args.candidate, + "shape": [args.rows, args.cols], + "dtype": args.dtype, + "world_size": world_size, + "allclose": all_ok, + "max_abs_diff": max_abs_diff, + "reference_ms": ref_global_ms, + "candidate_ms": cand_global_ms, + "speedup": speedup, + "score": score, + "reference_output_sample": _tensor_sample(ref_out, 8), + "candidate_output_sample": _tensor_sample(cand_out, 8), + } + if args.score_only: + print(f"score={score:.6f}", flush=True) + else: + print(json.dumps(payload), flush=True) + return 0 + except Exception as exc: + if rank == 0: + msg = {"score": 0.0, "error": str(exc)} + if args.score_only: + print("score=0.0", flush=True) + else: + print(json.dumps(msg), flush=True) + return 1 + finally: + finalize_reference() + + +if __name__ == "__main__": + try: + raise SystemExit(main()) + except ModuleNotFoundError as exc: + # Provide a clean machine-parseable error for agent loops. 
        # Surface missing-dependency errors (e.g. no torch locally) as a
        # zero score so the agent loop can parse the failure.
        print(json.dumps({"score": 0.0, "error": str(exc)}), flush=True)
        raise SystemExit(1)
    except Exception as exc:
        print(json.dumps({"score": 0.0, "error": f"{type(exc).__name__}: {exc}"}), flush=True)
        raise SystemExit(1)
diff --git a/scripts/modal_benchmark.py b/scripts/modal_benchmark.py
new file mode 100644
index 0000000..e3c211c
--- /dev/null
+++ b/scripts/modal_benchmark.py
@@ -0,0 +1,149 @@
#!/usr/bin/env python3
"""
Run the candidate benchmark on Modal (8x H100) and print JSON metrics.

This exists so the agent loop can evaluate Triton/NVSHMEM candidates in a real GPU
environment (instead of a local CPU/Mac environment that lacks Triton/CUDA).

Typical usage (from repo root):
    modal run scripts/modal_benchmark.py \
        --problem 1 \
        --candidate /Users/rohk/FERB/.agent_runs/<run_id>/candidate_iter_1.py \
        --rows 1024 --cols 1024 --dtype float32 --warmup 3 --iters 10
"""

from __future__ import annotations

import argparse
# NOTE(review): argparse appears unused here — the Modal local_entrypoint
# parses CLI arguments itself; confirm and drop if so.
import json
import os
from pathlib import Path

import modal


APP_NAME = "ferb-agent-eval"
# The repo is mounted at this path inside the Modal container.
REMOTE_ROOT = "/workspace"

app = modal.App(APP_NAME)

# CUDA devel base image + NVSHMEM toolchain; Python deps layered on top.
image = (
    modal.Image.from_registry("nvidia/cuda:12.8.0-devel-ubuntu22.04", add_python="3.12")
    .apt_install("wget", "xz-utils", "gnupg", "software-properties-common", "git")
    # Install NVSHMEM from tarball (avoids version mismatch issues with apt package)
    .run_commands(
        "wget -q https://developer.download.nvidia.com/compute/nvshmem/redist/libnvshmem/linux-x86_64/libnvshmem-linux-x86_64-3.2.5_cuda12-archive.tar.xz -O /tmp/nvshmem.tar.xz",
        "mkdir -p /opt/nvshmem",
        "tar -xf /tmp/nvshmem.tar.xz -C /opt/nvshmem --strip-components=1",
        "rm /tmp/nvshmem.tar.xz",
    )
    .env(
        {
            "NVSHMEM_HOME": "/opt/nvshmem",
            "LD_LIBRARY_PATH": "/opt/nvshmem/lib:/usr/local/cuda/lib64",
            "CUDA_HOME": "/usr/local/cuda",
            # NOTE(review): "${PATH}" may be stored literally rather than
            # shell-expanded when set as a container env var — verify that
            # torchrun/nvcc still resolve, or write the full PATH explicitly.
            "PATH": "/usr/local/cuda/bin:${PATH}",
        }
    )
    .pip_install(
        "torch",
        "triton",
        "numpy",
        "mpi4py",
        "nvshmem4py-cu12",
        "cuda-python>=12.0",
"numba", + "numba-cuda", + "cffi", + ) +) + +project_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +image_with_code = image.add_local_dir(project_dir, remote_path=REMOTE_ROOT, copy=True) + + +@app.function(image=image_with_code, gpu="H100:8", timeout=60 * 30) +def run_modal_benchmark( + *, + problem: int, + candidate_relpath: str, + rows: int, + cols: int, + dtype: str, + warmup: int, + iters: int, +) -> dict: + import subprocess + + candidate_path = f"{REMOTE_ROOT}/{candidate_relpath.lstrip('/')}" + cmd = [ + "torchrun", + "--nproc-per-node", + "8", + f"{REMOTE_ROOT}/scripts/benchmark_candidate.py", + "--problem", + str(problem), + "--candidate", + candidate_path, + "--rows", + str(rows), + "--cols", + str(cols), + "--dtype", + dtype, + "--warmup", + str(warmup), + "--iters", + str(iters), + ] + + proc = subprocess.run(cmd, capture_output=True, text=True) + stdout = (proc.stdout or "").strip() + stderr = (proc.stderr or "").strip() + + # Try to parse JSON metrics from stdout (benchmark prints JSON). 
+ metrics: dict = {"score": 0.0} + try: + metrics = json.loads(stdout.splitlines()[-1]) if stdout else {"score": 0.0} + if not isinstance(metrics, dict): + metrics = {"score": 0.0} + except Exception: + metrics = {"score": 0.0} + + metrics["modal_returncode"] = proc.returncode + if stderr: + metrics["modal_stderr"] = stderr + return metrics + + +@app.local_entrypoint() +def main( + problem: int = 1, + candidate: str = "", + rows: int = 1024, + cols: int = 1024, + dtype: str = "float32", + warmup: int = 3, + iters: int = 10, +) -> None: + if not candidate: + raise SystemExit("--candidate is required") + + repo_root = Path(project_dir).resolve() + cand_path = Path(candidate).expanduser().resolve() + try: + rel = cand_path.relative_to(repo_root) + except Exception: + raise SystemExit(f"candidate must be within repo: {repo_root} (got {cand_path})") + + metrics = run_modal_benchmark.remote( + problem=problem, + candidate_relpath=str(rel), + rows=rows, + cols=cols, + dtype=dtype, + warmup=warmup, + iters=iters, + ) + print(json.dumps(metrics)) + diff --git a/scripts/worker.py b/scripts/worker.py new file mode 100644 index 0000000..c946a3e --- /dev/null +++ b/scripts/worker.py @@ -0,0 +1,175 @@ +#!/usr/bin/env python3 +""" +Distributed worker for FERB runs (used by run_modal.py). + +Per rank: +1) Initialize distributed backend +2) Load problem solution module +3) Create input tensor(s) +4) Run solution +5) Save outputs +""" + +from __future__ import annotations + +import argparse +import importlib.util +import json +import os +import sys +import traceback +from pathlib import Path +from typing import Any + +# Ensure repo root is importable under torchrun. 
REPO_ROOT = Path(__file__).resolve().parent.parent
if str(REPO_ROOT) not in sys.path:
    sys.path.insert(0, str(REPO_ROOT))

import torch
import torch.distributed as dist

from utils.init_and_finalize_backends import finalize_numba_cuda
from utils.init_and_finalize_backends import finalize_reference
from utils.init_and_finalize_backends import finalize_triton
from utils.init_and_finalize_backends import init_numba_cuda
from utils.init_and_finalize_backends import init_reference
from utils.init_and_finalize_backends import init_triton
from utils.input_output_tensors import create_input_tensor
from utils.input_output_tensors import save_tensor


def _dtype_from_string(name: str) -> torch.dtype:
    """Translate a textual dtype name (e.g. "float32") into torch.dtype.

    Raises ValueError for names outside the supported set.
    """
    normalized = (name or "").strip().lower()
    supported = {
        "float16": torch.float16,
        "float32": torch.float32,
        "float64": torch.float64,
        "bfloat16": torch.bfloat16,
        "int32": torch.int32,
        "int64": torch.int64,
    }
    if normalized not in supported:
        raise ValueError(f"Unsupported dtype: {name}")
    return supported[normalized]


def _load_solution_module(problem_py: str):
    """Import the problem file as a fresh module and verify it exposes solution().

    The module name is suffixed with the pid so repeated loads do not collide
    in sys.modules across processes.
    """
    source = Path(problem_py)
    if not source.exists():
        raise FileNotFoundError(f"Problem file not found: {problem_py}")

    module_name = f"ferb_solution_{source.stem}_{os.getpid()}"
    spec = importlib.util.spec_from_file_location(module_name, str(source))
    if spec is None or spec.loader is None:
        raise RuntimeError(f"Unable to load module from {problem_py}")
    loaded = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(loaded)
    if not hasattr(loaded, "solution"):
        raise AttributeError(f"{problem_py} does not define solution(...)")
    return loaded


def _init_backend(backend: str, rank: int, world_size: int) -> None:
    backend = backend.strip().lower()
    # Always initialize torch.distributed first.
    init_reference(rank, world_size)
    if backend == "reference":
        return
    if backend == "triton":
        init_triton(rank, world_size)
        return
    if backend == "numba_cuda":
        init_numba_cuda(rank, world_size)
        return
    # Unknown backend: keep NCCL-only path.


def _finalize_backend(backend: str) -> None:
    # Tear down whatever _init_backend set up for `backend`.
    # NOTE(review): asymmetric with _init_backend, which calls
    # init_reference() for EVERY backend — here the triton/numba paths
    # return without calling finalize_reference(). Confirm that
    # finalize_triton()/finalize_numba_cuda() also finalize the process
    # group, otherwise it is leaked on those paths.
    backend = backend.strip().lower()
    if backend == "triton":
        finalize_triton()
        return
    if backend == "numba_cuda":
        finalize_numba_cuda()
        return
    finalize_reference()


def _save_rank_metadata(logs_dir: str, rank: int, payload: dict[str, Any]) -> None:
    """Write this rank's metadata as JSON to <logs_dir>/rank_<rank>_meta.json."""
    os.makedirs(logs_dir, exist_ok=True)
    path = os.path.join(logs_dir, f"rank_{rank}_meta.json")
    with open(path, "w", encoding="utf-8") as f:
        json.dump(payload, f, indent=2)


def main() -> int:
    """Per-rank worker entry point: init backend, run the solution, save output.

    Returns 0 on success, 1 on error; rank metadata is written either way.
    """
    parser = argparse.ArgumentParser(description="FERB distributed worker")
    parser.add_argument("--backend", required=True, help="reference|triton|numba_cuda")
    parser.add_argument("--problem_py", required=True, help="Absolute path to problem file")
    parser.add_argument("--logs_dir", required=True, help="Output logs directory")
    parser.add_argument("--rows", required=True, type=int)
    parser.add_argument("--cols", required=True, type=int)
    parser.add_argument("--dtype", default="float32")
    parser.add_argument("--problem_id", required=True, type=int)
    args = parser.parse_args()

    # Rank/world size come from the torchrun launcher environment.
    rank = int(os.environ.get("RANK", "0"))
    world_size = int(os.environ.get("WORLD_SIZE", "1"))

    status = "ok"
    error_text = ""
    output_path = ""
    try:
        _init_backend(args.backend, rank, world_size)

        module = _load_solution_module(args.problem_py)
        solution_fn = getattr(module, "solution")

        dtype = _dtype_from_string(args.dtype)
        base_shape = (args.rows, args.cols)
        inputs = create_input_tensor(
            rank=rank,
            world_size=world_size,
            problem_id=args.problem_id,
            base_shape=base_shape,
            dtype=dtype,
        )

        with torch.no_grad():
            output = solution_fn(*inputs)
        output_path = save_tensor(output,
                                  args.logs_dir, rank)

        if dist.is_initialized():
            dist.barrier()
    except Exception:
        # Record the failure; the metadata below carries the traceback so the
        # launcher can triage per-rank errors.
        status = "error"
        error_text = traceback.format_exc()
        print(error_text, file=sys.stderr, flush=True)
    finally:
        # Metadata is written on success AND failure.
        _save_rank_metadata(
            args.logs_dir,
            rank,
            {
                "status": status,
                "rank": rank,
                "world_size": world_size,
                "backend": args.backend,
                "problem_id": args.problem_id,
                "problem_py": args.problem_py,
                "dtype": args.dtype,
                "shape": [args.rows, args.cols],
                "output_path": output_path,
                "error": error_text,
            },
        )
        try:
            _finalize_backend(args.backend)
        except Exception:
            # Best-effort teardown; never mask the primary status.
            traceback.print_exc()

    return 0 if status == "ok" else 1


if __name__ == "__main__":
    raise SystemExit(main())
diff --git a/utils/__pycache__/construct_prompt.cpython-310.pyc b/utils/__pycache__/construct_prompt.cpython-310.pyc
deleted file mode 100644
index cadf65f..0000000
Binary files a/utils/__pycache__/construct_prompt.cpython-310.pyc and /dev/null differ
diff --git a/utils/__pycache__/construct_prompt.cpython-312.pyc b/utils/__pycache__/construct_prompt.cpython-312.pyc
deleted file mode 100644
index b0c8c42..0000000
Binary files a/utils/__pycache__/construct_prompt.cpython-312.pyc and /dev/null differ
diff --git a/utils/init_and_finalize_backends.py b/utils/init_and_finalize_backends.py
index d84b386..8f90cd4 100644
--- a/utils/init_and_finalize_backends.py
+++ b/utils/init_and_finalize_backends.py
@@ -15,24 +15,37 @@
 # ---------------------------------------------------------------------------
 def init_reference(rank: int, world_size: int) -> None:
-    """Initialize torch.distributed with NCCL for reference backend."""
+    """Initialize torch.distributed for reference backend.
+
+    - If CUDA is available: uses NCCL and binds each rank to a GPU.
+    - If CUDA is NOT available (e.g. local Mac dev): uses GLOO on CPU so scripts can run.
+ """ os.environ["MASTER_ADDR"] = os.environ.get("MASTER_ADDR", "127.0.0.1") os.environ["MASTER_PORT"] = os.environ.get("MASTER_PORT", "29500") os.environ["RANK"] = str(rank) os.environ["LOCAL_RANK"] = str(rank) os.environ["WORLD_SIZE"] = str(world_size) - torch.cuda.set_device(rank) - try: - dist.init_process_group( - backend="nccl", - init_method="env://", - rank=rank, - world_size=world_size, - device_id=torch.device("cuda", rank), - ) - except TypeError: + if torch.cuda.is_available(): + torch.cuda.set_device(rank) + try: + dist.init_process_group( + backend="nccl", + init_method="env://", + rank=rank, + world_size=world_size, + device_id=torch.device("cuda", rank), + ) + except TypeError: + dist.init_process_group( + backend="nccl", + init_method="env://", + rank=rank, + world_size=world_size, + ) + else: + # CPU-only fallback for local development. dist.init_process_group( - backend="nccl", + backend="gloo", init_method="env://", rank=rank, world_size=world_size,