diff --git a/agent/main.py b/agent/main.py index 25d0859b..46ee3f2e 100644 --- a/agent/main.py +++ b/agent/main.py @@ -30,7 +30,7 @@ from agent.core.session import OpType from agent.core.tools import ToolRouter from agent.messaging.gateway import NotificationGateway -from agent.utils.reliability_checks import check_training_script_save_pattern +from agent.utils.reliability_checks import format_finding, run_preflight_checks from agent.utils.terminal_display import ( get_console, print_approval_header, @@ -513,10 +513,6 @@ def _cancel_event(): if script_args: print(f"Script args: {' '.join(script_args)}") - # Run reliability checks on the full script (not truncated) - check_message = check_training_script_save_pattern(script) - if check_message: - print(check_message) elif command: # Docker mode image = arguments.get("image", "python:3.12") @@ -544,6 +540,9 @@ def _cancel_event(): if schedule: print(f"Schedule: {schedule}") + for finding in run_preflight_checks(arguments): + print(format_finding(finding)) + elif tool_name == "hf_private_repos": # Handle private repo operations args = _safe_get_args(arguments) diff --git a/agent/utils/reliability_checks.py b/agent/utils/reliability_checks.py index 3ed76d72..522e5631 100644 --- a/agent/utils/reliability_checks.py +++ b/agent/utils/reliability_checks.py @@ -1,14 +1,169 @@ -"""Reliability checks for job submissions and other operations""" +"""Static pre-flight checks for hf_jobs submissions. +Each check is pure substring inspection on the arguments dict the agent is +about to send — no network calls, no imports of training libraries. Findings +are advisory: the CLI prints them at the approval prompt; nothing is blocked. -def check_training_script_save_pattern(script: str) -> str | None: - """Check if a training script properly saves models.""" +The five failure modes covered are documented in +``agent/prompts/system_prompt_v3.yaml`` (see lines 29-47, 65-70). +""" + +from dataclasses import dataclass +from typing import Literal + +_RED = "\033[91m" +_GREEN = "\033[92m" +_RESET = "\033[0m" + +# Substrings that strongly indicate a training entry point. Conservative on +# purpose: a script that only does ``model.generate(...)`` should not trip +# the timeout or trackio checks. +_TRAINER_PATTERNS = ( + "Trainer(", + "SFTTrainer(", + "GRPOTrainer(", + "DPOTrainer(", + "trainer.train(", +) + + +@dataclass(frozen=True) +class Finding: + severity: Literal["warn", "info"] + message: str + + +def format_finding(finding: Finding) -> str: + """Render a finding as a single colored line for terminal output.""" + color = _RED if finding.severity == "warn" else _GREEN + return f"\n{color}{finding.message}{_RESET}" + + +def run_preflight_checks(arguments: dict) -> list[Finding]: + """Run every static check against an hf_jobs ``arguments`` dict. + + ``arguments`` is the same dict already in scope at the CLI approval + prompt: keys include ``script``, ``command``, ``dependencies``, + ``hardware_flavor``, ``timeout``, ``env``, ``schedule``. Script-parsing + checks self-skip when the job is in Docker mode (no ``script``). + """ + findings: list[Finding] = [] + script = arguments.get("script") or "" + + if script: + if (f := _check_save_pattern(script)) is not None: + findings.append(f) + if (f := _check_hub_model_id(script)) is not None: + findings.append(f) + if (f := _check_flash_attn(arguments)) is not None: + findings.append(f) + if (f := _check_trackio(arguments)) is not None: + findings.append(f) + + if (f := _check_timeout(arguments)) is not None: + findings.append(f) + + return findings + + +def _check_save_pattern(script: str) -> Finding | None: has_from_pretrained = "from_pretrained" in script has_push_to_hub = "push_to_hub" in script + has_local_save = "trainer.save_model" in script or "save_pretrained" in script + + if not has_from_pretrained: + return None + if has_push_to_hub: + return Finding("info", "Model will be pushed to hub after training.") + if has_local_save: + return Finding( + "warn", + "Model is saved locally but not pushed to hub. hf_jobs storage is " + "ephemeral — add push_to_hub=True to keep the model.", + ) + return Finding( + "warn", + "No model save detected in this script. Ensure this is intentional.", + ) + + +def _check_timeout(arguments: dict) -> Finding | None: + # The hf_jobs default is 30m (see agent/main.py: arguments.get("timeout", "30m")). + # Treat both an explicit "30m" and a missing timeout the same way. + timeout = arguments.get("timeout") or "30m" + if timeout != "30m": + return None + + script = arguments.get("script", "") or "" + command = arguments.get("command") or "" + command_text = " ".join(command) if isinstance(command, list) else str(command) + text = f"{script}\n{command_text}" + if not any(pat in text for pat in _TRAINER_PATTERNS): + return None + + return Finding( + "warn", + "Default 30m timeout with a training call — training takes hours and " + "the job will be killed mid-run. Set timeout explicitly (e.g. '6h').", + ) + + +def _check_hub_model_id(script: str) -> Finding | None: + # Only the TrainingArguments config form (push_to_hub=True) requires a + # matching hub_model_id keyword. The method-call form + # ``trainer.push_to_hub("me/foo")`` carries the destination inline and + # must not trip this warning. + if "push_to_hub=True" not in script.replace(" ", ""): + return None + if "hub_model_id" in script: + return None + return Finding( + "warn", + "push_to_hub=True is set without hub_model_id — the model will land " + "at a default repo path. Set hub_model_id explicitly.", + ) - if has_from_pretrained and not has_push_to_hub: - return "\n\033[91mWARNING: No model save detected in this script. Ensure this is intentional.\033[0m" - elif has_from_pretrained and has_push_to_hub: - return "\n\033[92mModel will be pushed to hub after training.\033[0m" - return None +def _check_flash_attn(arguments: dict) -> Finding | None: + # system_prompt_v3.yaml:45 now steers users away from compiling + # flash-attn from source: "Do NOT pip install 'flash-attn'… Instead, + # use the HF kernels library and attn_implementation= + # 'kernels-community/flash-attn2'". Fire on the legacy literal + # regardless of deps — building from source is slow and fragile. + script = arguments.get("script", "") or "" + if "flash_attention_2" not in script: + return None + return Finding( + "warn", + 'Script uses attn_implementation="flash_attention_2" — building ' + "flash-attn from source is slow and often fails on the job's CUDA " + "build. Prefer attn_implementation=\"kernels-community/flash-attn2\" " + "which loads a prebuilt kernel from the Hub.", + ) + + +def _check_trackio(arguments: dict) -> Finding | None: + script = arguments.get("script", "") or "" + if not any(pat in script for pat in _TRAINER_PATTERNS): + return None + if "trackio" in script.lower(): + return None + return Finding( + "info", + 'Training script does not configure report_to="trackio" — ' + "you will have no live training metrics.", + ) + + +# --------------------------------------------------------------------------- +# Backward-compatible legacy entry point. +# --------------------------------------------------------------------------- + +def check_training_script_save_pattern(script: str) -> str | None: + """Legacy single-string API. Kept so older imports keep working. + + Prefer ``run_preflight_checks(arguments)`` in new code — it returns + structured findings for every check, not just the save-pattern one. + """ + f = _check_save_pattern(script) + return format_finding(f) if f is not None else None diff --git a/tests/unit/test_reliability_checks.py b/tests/unit/test_reliability_checks.py new file mode 100644 index 00000000..61a81a5c --- /dev/null +++ b/tests/unit/test_reliability_checks.py @@ -0,0 +1,247 @@ +"""Tests for the static pre-flight checks run at the hf_jobs approval prompt.""" + +import pytest + +from agent.utils.reliability_checks import ( + Finding, + check_training_script_save_pattern, + format_finding, + run_preflight_checks, +) + + +# ── Finding / format_finding ──────────────────────────────────────────── + + +def test_finding_is_frozen(): + f = Finding("warn", "msg") + with pytest.raises(Exception): + f.message = "other" # type: ignore[misc] + + +def test_format_finding_uses_red_for_warn(): + out = format_finding(Finding("warn", "boom")) + assert "\033[91m" in out and "boom" in out and out.endswith("\033[0m") + + +def test_format_finding_uses_green_for_info(): + out = format_finding(Finding("info", "ok")) + assert "\033[92m" in out and "ok" in out and out.endswith("\033[0m") + + +# ── save-pattern check (system_prompt_v3.yaml:39) ────────────────────── + + +def test_save_pattern_warns_when_from_pretrained_without_push(): + findings = run_preflight_checks({"script": "model = AutoModel.from_pretrained('x')"}) + assert any(f.severity == "warn" and "No model save" in f.message for f in findings) + + +def test_save_pattern_info_when_push_to_hub_present(): + script = "AutoModel.from_pretrained('x'); trainer.push_to_hub()" + findings = run_preflight_checks({"script": script}) + assert any(f.severity == "info" and "pushed to hub" in f.message for f in findings) + + +def test_save_pattern_warns_on_local_save_without_push(): + script = "AutoModel.from_pretrained('x'); trainer.save_model('out')" + findings = run_preflight_checks({"script": script}) + assert any( + f.severity == "warn" and "ephemeral" in f.message for f in findings + ) + + +def test_save_pattern_silent_when_no_from_pretrained(): + findings = run_preflight_checks({"script": "print('hello')"}) + assert all("save" not in f.message.lower() for f in findings) + + +# ── timeout check (system_prompt_v3.yaml:37) ─────────────────────────── + + +@pytest.mark.parametrize("trainer_pattern", [ + "Trainer(model=m)", + "SFTTrainer(model=m)", + "GRPOTrainer(model=m)", + "DPOTrainer(model=m)", + "trainer.train(", +]) +def test_timeout_warns_on_default_with_training_call(trainer_pattern): + findings = run_preflight_checks({"script": trainer_pattern, "timeout": "30m"}) + assert any( + f.severity == "warn" and "30m timeout" in f.message for f in findings + ) + + +def test_timeout_warns_when_timeout_missing_entirely(): + # Missing timeout is treated as the default 30m. + findings = run_preflight_checks({"script": "Trainer(model=m)"}) + assert any("30m timeout" in f.message for f in findings) + + +def test_timeout_silent_when_explicitly_set_long(): + findings = run_preflight_checks({"script": "Trainer(model=m)", "timeout": "6h"}) + assert all("30m timeout" not in f.message for f in findings) + + +def test_timeout_silent_when_no_training_call(): + findings = run_preflight_checks({"script": "model.generate(x)", "timeout": "30m"}) + assert all("30m timeout" not in f.message for f in findings) + + +def test_timeout_check_runs_for_docker_mode(): + findings = run_preflight_checks({ + "command": ["python", "-c", "from trl import SFTTrainer; SFTTrainer(...)"], + "timeout": "30m", + }) + assert any("30m timeout" in f.message for f in findings) + + +# ── hub_model_id check (system_prompt_v3.yaml:39) ────────────────────── + + +def test_hub_model_id_warns_when_pushing_without_id(): + script = ( + "AutoModel.from_pretrained('x')\n" + "args = TrainingArguments(push_to_hub=True)" + ) + findings = run_preflight_checks({"script": script}) + assert any( + f.severity == "warn" and "hub_model_id" in f.message for f in findings + ) + + +def test_hub_model_id_silent_for_method_call_with_inline_repo(): + # ``trainer.push_to_hub("me/foo")`` carries the destination inline; the + # check must not fire on this form. + script = "AutoModel.from_pretrained('x'); trainer.push_to_hub('me/foo')" + findings = run_preflight_checks({"script": script}) + assert all("hub_model_id" not in f.message for f in findings) + + +def test_hub_model_id_silent_when_id_present(): + script = ( + "AutoModel.from_pretrained('x')\n" + "args = TrainingArguments(push_to_hub=True, hub_model_id='me/foo')" + ) + findings = run_preflight_checks({"script": script}) + assert all("hub_model_id" not in f.message for f in findings) + + +def test_hub_model_id_silent_when_push_explicitly_disabled(): + script = "AutoModel.from_pretrained('x')\nargs = TrainingArguments(push_to_hub=False)" + findings = run_preflight_checks({"script": script}) + assert all("hub_model_id" not in f.message for f in findings) + + +# ── flash-attn check (system_prompt_v3.yaml:45) ──────────────────────── + + +def test_flash_attn_warns_on_legacy_literal_even_with_dep(): + # Per system_prompt_v3.yaml:45 the guidance is to avoid building + # flash-attn from source entirely. The check fires on the legacy + # ``attn_implementation="flash_attention_2"`` literal regardless of + # whether flash-attn is in deps. + script = 'AutoModel.from_pretrained("x", attn_implementation="flash_attention_2")' + findings = run_preflight_checks({ + "script": script, + "dependencies": ["transformers", "flash-attn"], + }) + assert any( + f.severity == "warn" and "kernels-community/flash-attn2" in f.message + for f in findings + ) + + +def test_flash_attn_warns_when_dep_missing(): + script = 'model = AutoModel.from_pretrained("x", attn_implementation="flash_attention_2")' + findings = run_preflight_checks({"script": script, "dependencies": ["transformers"]}) + assert any( + f.severity == "warn" and "kernels-community/flash-attn2" in f.message + for f in findings + ) + + +def test_flash_attn_silent_for_kernels_community_form(): + # The recommended form must not trip the warning. Note the dash in + # "flash-attn2" vs the underscore in the legacy "flash_attention_2". + script = ( + 'AutoModel.from_pretrained("x", ' + 'attn_implementation="kernels-community/flash-attn2")' + ) + findings = run_preflight_checks({"script": script, "dependencies": []}) + assert all("flash_attention_2" not in f.message for f in findings) + + +def test_flash_attn_silent_when_not_used(): + findings = run_preflight_checks({ + "script": "AutoModel.from_pretrained('x')", + "dependencies": [], + }) + assert all("flash_attention_2" not in f.message for f in findings) + + +# ── trackio check (system_prompt_v3.yaml:65-70) ──────────────────────── + + +def test_trackio_info_when_training_without_trackio(): + findings = run_preflight_checks({"script": "Trainer(model=m).train()", "timeout": "6h"}) + assert any( + f.severity == "info" and "trackio" in f.message for f in findings + ) + + +def test_trackio_silent_when_trackio_configured(): + script = 'args = TrainingArguments(report_to="trackio")\nTrainer(model=m).train()' + findings = run_preflight_checks({"script": script, "timeout": "6h"}) + assert all("trackio" not in f.message for f in findings) + + +def test_trackio_silent_for_inference_only(): + findings = run_preflight_checks({"script": "model.generate(x)", "timeout": "6h"}) + assert all("trackio" not in f.message for f in findings) + + +# ── Docker mode / overall integration ────────────────────────────────── + + +def test_docker_mode_skips_script_parsing_checks(): + # No `script` key. Only the timeout check applies; the others must self-skip. + findings = run_preflight_checks({"command": ["python", "infer.py"], "timeout": "6h"}) + assert findings == [] + + +def test_empty_arguments_returns_no_findings(): + assert run_preflight_checks({}) == [] + + +def test_findings_are_emitted_in_documented_order(): + # When several checks fire on one script, the save-pattern finding comes + # before the timeout finding. The CLI relies on this for stable output. + script = "AutoModel.from_pretrained('x')\nTrainer(model=m)" + findings = run_preflight_checks({"script": script, "timeout": "30m"}) + severities = [f.message for f in findings] + save_idx = next(i for i, m in enumerate(severities) if "save" in m.lower()) + timeout_idx = next(i for i, m in enumerate(severities) if "30m timeout" in m) + assert save_idx < timeout_idx + + +# ── Legacy wrapper (back-compat) ─────────────────────────────────────── + + +def test_legacy_wrapper_returns_warning_string_when_no_save(): + out = check_training_script_save_pattern("AutoModel.from_pretrained('x')") + assert out is not None + assert "\033[91m" in out and "No model save" in out + + +def test_legacy_wrapper_returns_info_string_when_pushing(): + out = check_training_script_save_pattern( + "AutoModel.from_pretrained('x'); push_to_hub()" + ) + assert out is not None + assert "\033[92m" in out and "pushed to hub" in out + + +def test_legacy_wrapper_returns_none_for_plain_script(): + assert check_training_script_save_pattern("print('hi')") is None