Skip to content

Feature/hyperparam search#22

Open
rukubrakov wants to merge 4 commits intodevfrom
feature/hyperparam_search
Open

Feature/hyperparam search#22
rukubrakov wants to merge 4 commits intodevfrom
feature/hyperparam_search

Conversation

@rukubrakov
Copy link
Copy Markdown
Collaborator

@rukubrakov rukubrakov commented Apr 3, 2026

Summary by CodeRabbit

  • New Features

    • Added simba-sweep CLI for Optuna hyperparameter search with trial resumption and starting-parameter seeding.
    • Early stopping via training.early_stopping_patience.
    • Added sweep plotting tool and a SLURM script template for running searches.
  • Documentation

    • New "Hyperparameter Search" guide: config, running sweeps, outputs, seeding, and resume behavior.
  • Tests

    • Added integration and unit tests for sweep orchestration and sampler utilities.
  • Chores

    • Updated .gitignore to exclude sweep/runtime outputs.

…earch space, per-trial checkpoint persistence, and automatic best-model inference
…Add hyperparam search SLURM script to tools/slurm/. Add tools/plot_sweep.py to visualise sweep results and best-model inference quality.
@coderabbitai
Copy link
Copy Markdown
Contributor

coderabbitai bot commented Apr 3, 2026

No actionable comments were generated in the recent review. 🎉

ℹ️ Recent review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: b10a8e43-ad3e-4809-a6f6-d094ceeb0dac

📥 Commits

Reviewing files that changed from the base of the PR and between 3231b4d and d2c30ed.

📒 Files selected for processing (2)
  • README.md
  • simba/commands/sweep_train.py
✅ Files skipped from review due to trivial changes (1)
  • README.md
🚧 Files skipped from review as they are similar to previous changes (1)
  • simba/commands/sweep_train.py

📝 Walkthrough

Walkthrough

Adds an Optuna-driven hyperparameter sweep: new simba-sweep CLI entrypoint, sweep configs and default params, in-memory Optuna orchestration with trial sampling/resume/persistence, early-stopping integration into training, plotting and SLURM tooling, and unit/integration tests for sweep behavior.

Changes

Cohort / File(s) Summary
Manifest & ignores
\.gitignore, pyproject.toml, simba/configs/sweep/default.yaml, simba/configs/training/default.yaml
Appended ignore patterns for sweep outputs; added runtime deps (optuna, hydra-optuna-sweeper, hydra-submitit-launcher, sqlalchemy) and new CLI script simba-sweep; added default sweep config and training.early_stopping_patience.
Sweep orchestration
simba/commands/sweep_train.py
New Hydra entrypoint sweep_train implementing Optuna distributions, param sampling, per-trial invocation, JSON persistence (trials.json), resume logic, starting_params handling, per-trial params.json writes, and optional best-trial inference execution.
Training workflow updates
simba/commands/train.py, simba/workflows/training.py
Added optional EarlyStopping callback creation, setup_callbacks now returns four items, train accepts early_stopping_callback and now returns the fitted pl.Trainer; updated callers to pass through the new callback.
Tests & CI script
tests/integration/test_sweep_pipeline.py, tests/unit/test_sweep_utils.py, test_all_commands.sh
Added integration/unit tests for sweep sampling, persistence, resume, and starting_params; extended end-to-end test script to run sweep steps and updated step numbering/messages.
Plotting & SLURM tooling
tools/plot_sweep.py, tools/slurm/hyperparam_search.slurm.sh
New plotting utility to visualize trials and best-inference results; added SLURM batch script to run simba-sweep with configurable resources and resume semantics. (Review plotting assumptions about JSON layout and SLURM venv usage.)

Sequence Diagram(s)

sequenceDiagram
    actor User as User/CLI
    participant Hydra as Hydra Config
    participant Optuna as Optuna Study
    participant Runner as Trial Executor
    participant Checkpoint as Checkpoint/Inference

    User->>Hydra: invoke sweep_train(cfg)
    Hydra->>Optuna: create or resume Study (in-memory)
    Optuna-->>Hydra: Study instance

    loop for each trial
        Hydra->>Optuna: study.ask() / sample params
        Optuna-->>Hydra: sampled params
        Hydra->>Runner: _run_trial(cfg + params)
        Runner->>Runner: prepare cfg, dataloaders, callbacks, train
        Runner-->>Hydra: validation_loss + per-trial checkpoint
        Hydra->>Optuna: study.tell(trial, loss)
        Hydra->>Hydra: append/update `trials.json`
    end

    Hydra->>Optuna: select best trial (min loss)
    Optuna-->>Hydra: best trial id & params
    Hydra->>Checkpoint: run inference on best checkpoint (optionally)
    Checkpoint-->>Hydra: best_inference/metrics.json
    Hydra-->>User: print summary + best metrics
Loading

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Possibly related PRs

Suggested reviewers

  • chevi1989
  • Janne98
  • bittremieux

"🐰
I hop through trials with Optuna's cheer,
I seed and resume, checkpoints near,
Early stops and plots in tow,
SLURM sends jobs, results to show —
A thump for best params found, hip-hop hooray!"

🚥 Pre-merge checks | ✅ 1 | ❌ 2

❌ Failed checks (1 warning, 1 inconclusive)

Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 42.86% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
Title check ❓ Inconclusive The title 'Feature/hyperparam search' is vague and generic, using placeholder-like naming conventions without clearly describing the main changes introduced in the pull request. Replace with a more descriptive title that clearly summarizes the primary change, such as 'Add Optuna-based hyperparameter sweep CLI tool' or 'Implement hyperparameter sweep orchestration with Optuna integration'.
✅ Passed checks (1 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing Touches
📝 Generate docstrings
  • Create stacked PR
  • Commit on current branch
🧪 Generate unit tests (beta)
  • Create PR with unit tests
  • Commit unit tests in branch feature/hyperparam_search

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 9

🧹 Nitpick comments (8)
test_all_commands.sh (1)

96-96: Consistent $DEVICE quoting needed throughout.

The same SC2086 concern applies to all other uses of $DEVICE in this script. For consistency and safety, quote all variable expansions passed to command arguments.

Also applies to: 106-106, 118-118, 129-129, 141-141, 146-146, 162-162, 176-176, 192-192

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@test_all_commands.sh` at line 96, Every occurrence of the DEVICE variable is
unquoted when used as a command argument (e.g., the assignment/hint
"hardware.accelerator=$DEVICE" and other uses at the mentioned locations), which
can break when DEVICE contains spaces or glob characters; update every expansion
of DEVICE to use quoted form ("$DEVICE") so arguments are passed safely and
consistently across the script—search for the symbol DEVICE and replace unquoted
$DEVICE usages with "$DEVICE" in each command/assignment.
pyproject.toml (1)

71-71: Consider updating optuna minimum version.

optuna>=2.10.1 is from 2022. Optuna 3.x (current stable) brings performance improvements and better TPE sampling. If no compatibility issues exist, consider optuna>=3.0.0.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@pyproject.toml` at line 71, Update the optuna dependency in pyproject.toml
from "optuna>=2.10.1" to "optuna>=3.0.0"; then run tests and CI and search for
usages of the optuna API (imports, Study/TPESampler/optimize calls) to adjust
any breaking changes introduced in Optuna 3.x (update argument names or sampler
constructors if needed) so the project remains compatible.
tools/slurm/hyperparam_search.slurm.sh (2)

79-82: Completion message may be inaccurate if trials fail.

The message claims $N_TRIALS trials completed, but some trials may have failed. Consider checking actual completed count from trials.json:

 echo "======================================================"
-echo "Sweep done! $N_TRIALS trials completed."
+COMPLETED=$(python -c "import json; t=json.load(open('${DB}')); print(sum(1 for x in t if x.get('status')=='completed'))" 2>/dev/null || echo "?")
+echo "Sweep done! ${COMPLETED} of $N_TRIALS trials completed successfully."
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tools/slurm/hyperparam_search.slurm.sh` around lines 79 - 82, The completion
message currently echoes "Sweep done! $N_TRIALS trials completed" which can be
inaccurate if some trials failed; update the script around the echo lines to
compute the actual completed count from the results file referenced by ${DB}
(e.g., parse trials.json) and then print a corrected summary like "Sweep done! X
of $N_TRIALS trials completed (Y failed)"; specifically, replace or augment the
echo that uses $N_TRIALS with a command that counts entries where
status=='completed' and optionally status!='completed' from ${DB} and include
those counts in the final message and the "Best" lookup flow so that the
selection (the python -c line) still filters completed trials only.

25-27: Hardcoded path reduces portability.

The absolute path /scratch/gent/vo/000/gvo00017/vsc21162/simba is specific to one cluster/user. Consider parameterizing this or using $SLURM_SUBMIT_DIR for portability.

-cd /scratch/gent/vo/000/gvo00017/vsc21162/simba
+# Adjust this path or use: cd "${SLURM_SUBMIT_DIR:-/path/to/simba}"
+cd "${SIMBA_ROOT:-/scratch/gent/vo/000/gvo00017/vsc21162/simba}"
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tools/slurm/hyperparam_search.slurm.sh` around lines 25 - 27, Replace the
hardcoded absolute path used in the cd command with a parameterized directory:
define a variable (e.g., BASE_DIR) that defaults to $SLURM_SUBMIT_DIR if not
provided and use that variable for the cd step; keep the mkdir -p logs and
source .venv/bin/activate lines unchanged but ensure paths are relative to
BASE_DIR (reference the cd invocation and the subsequent .venv activation to
locate where to apply the variable).
tools/plot_sweep.py (2)

18-29: Add error handling for empty or invalid sweep results.

If trials.json is empty or contains no trials, np.argmin(values) will fail. Consider adding validation:

 with open(ROOT / "trials.json") as f:
     trials = json.load(f)
+
+if not trials:
+    raise SystemExit("No trials found in trials.json")
+
 with open(ROOT / "best_inference" / "metrics.json") as f:
     inf = json.load(f)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tools/plot_sweep.py` around lines 18 - 29, The code assumes trials.json
contains trials and uses np.argmin on values which will raise on empty/invalid
data; after loading trials (variable trials) validate it's a non-empty list and
that values (derived from t["value"] for t in trials) contains at least one
numeric entry, otherwise handle the error path (e.g., log/raise a clear error or
set BEST_IDX to None and skip downstream computations). Update the block that
computes ids, values, params, BASELINE_IDX, BEST_IDX, and running_best to
perform these checks and fail fast with a descriptive message or safe defaults
if trials or values are empty or malformed.

14-16: Parameterize the sweep results path.

The hardcoded path requires manual editing for each sweep. Consider accepting the path as a CLI argument:

+import sys
+
-ROOT    = pathlib.Path("/scratch/gent/vo/000/gvo00017/vsc21162/simba/sweeps/run_v1")
+if len(sys.argv) < 2:
+    print("Usage: python plot_sweep.py <sweep_output_dir>")
+    sys.exit(1)
+ROOT = pathlib.Path(sys.argv[1])
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tools/plot_sweep.py` around lines 14 - 16, The script currently hardcodes
ROOT and PLOTDIR (ROOT = pathlib.Path(...); PLOTDIR = ROOT / "plots") which
forces manual edits per sweep; change the script to accept the sweep results
path as a CLI argument (e.g., with argparse) and set ROOT =
pathlib.Path(args.root) with a sensible default matching the current hardcoded
path, then compute PLOTDIR = ROOT / "plots" and create it with
PLOTDIR.mkdir(parents=True, exist_ok=True); update any references to
ROOT/PLOTDIR accordingly so callers can override via CLI.
tests/integration/test_sweep_pipeline.py (1)

1-6: Consider adding a test for failed trial handling.

The tests cover success paths well, but there's no test verifying that a failed trial is recorded with status: "failed" and the sweep continues. This is an important resilience behavior.

Would you like me to draft a test case for failed trial handling?

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/integration/test_sweep_pipeline.py` around lines 1 - 6, Add an
integration test (e.g., test_failed_trial_handling) that mocks the training step
to raise an exception for one trial, then runs the sweep pipeline and asserts
that the failed trial is recorded with status "failed" in the persisted
trials.json and that the sweep continues to execute subsequent trials (verify
remaining trials complete or are enqueued as expected); specifically mock the
training/execute function used by the sweep orchestration, trigger a failure for
a named trial, run the same pipeline entrypoint used by other tests, then read
the trials persistence and assert presence of a trial record with status:
"failed" and that resume semantics/enqueueing of starting_params behave the same
as in successful runs.
simba/commands/sweep_train.py (1)

298-298: Consider using initialize_config_dir for consistency with other commands.

The @hydra.main decorator on sweep_train initializes Hydra's GlobalHydra, so calling compose() at line 298 should technically work. However, all other commands in the codebase (train.py, preprocess.py, inference.py, analog_discovery.py) use an explicit initialize_config_dir context manager with compose() for consistency and robustness. The test suite skips testing this inference block entirely, which suggests the pattern may benefit from alignment with established practices. Consider wrapping the compose() call in:

from hydra import initialize_config_dir
import simba

config_dir = str(Path(simba.__file__).parent / "configs")
with initialize_config_dir(config_dir=config_dir, version_base=None):
    inf_cfg = compose(config_name="config", overrides=inf_overrides)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@simba/commands/sweep_train.py` at line 298, The compose(config_name="config",
overrides=inf_overrides) call inside sweep_train should be wrapped in a
hydra.initialize_config_dir context for consistency with other commands; import
initialize_config_dir from hydra and compute config_dir from the simba package
path (Path(simba.__file__).parent / "configs"), then use with
initialize_config_dir(config_dir=config_dir, version_base=None): to call
compose() so inf_cfg is created inside that context; update the sweep_train
function to perform this change around the existing compose() invocation.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@pyproject.toml`:
- Around line 68-71: Move the sweep-specific packages out of the core
dependencies and define a new optional extras group named "sweep" under
[project.optional-dependencies]; specifically remove
"hydra-optuna-sweeper>=1.2.0", "hydra-submitit-launcher>=1.2.0", and
"optuna>=2.10.1" from the main dependencies list and add them to the "sweep"
extra (you may include sqlalchemy there only if it is exclusively needed for
sweep workflows), ensuring the new extras name matches the README instruction
`uv sync --extra sweep`.

In `@README.md`:
- Around line 580-595: The fenced code block showing the directory tree in
README.md lacks a language specifier; update that code fence to include a
language (e.g., "text") so Markdown renderers treat it as plain text (locate the
triple-backtick block that contains the sweeps/run1/ directory tree and change
``` to ```text).

In `@simba/commands/sweep_train.py`:
- Around line 246-252: Guard against ValueError from study.best_trial when all
trials failed: wrap the access to study.best_trial (used to set best, then
display best.value, best.number, best.params) in a try/except catching
ValueError and print a clear fallback message like "No completed trials" (still
print n_trials and trials_file), or alternatively check for completed trials via
study.get_trials(states=[optuna.trial.TrialState.COMPLETE]) before accessing
study.best_trial; update the block that prints best and trials_file accordingly
so the script does not raise when all trials failed.

In `@test_all_commands.sh`:
- Around line 64-86: In test_all_commands.sh, the $DEVICE variable in both "uv
run simba-sweep" invocations should be double-quoted to prevent word-splitting;
update the two occurrences in the script (the uv run simba-sweep commands that
set hardware.accelerator=$DEVICE) to use hardware.accelerator="$DEVICE" so the
value is passed as a single argument even if it contains spaces or special
characters.

In `@tests/unit/test_sweep_utils.py`:
- Line 8: Remove the unused top-level import symbol "math" from the
tests/unit/test_sweep_utils.py import list; locate the import statement that
says "import math" and delete it so the test module no longer references the
unused name and the CI unused-import failure is resolved.
- Around line 36-63: The production code uses removed Optuna classes; update
_build_distributions (in simba/commands/sweep_train.py) to return
optuna.distributions.FloatDistribution for continuous params (use log=True for
previous LogUniformDistribution and log=False for previous UniformDistribution)
and optuna.distributions.IntDistribution for integer params (replacing
IntUniformDistribution); also update any construction parameters (low/high/step)
to match FloatDistribution/IntDistribution signatures, and then adjust the tests
in tests/unit/test_sweep_utils.py to assert instances of
optuna.distributions.FloatDistribution (and check .log flag for log vs. linear),
optuna.distributions.IntDistribution, and
optuna.distributions.CategoricalDistribution as appropriate.

In `@tools/plot_sweep.py`:
- Line 199: The print statement in tools/plot_sweep.py uses an unnecessary
f-string (print(f"\nBest model validation metrics:")) even though there are no
placeholders; change the statement to a plain string by removing the leading "f"
so it becomes print("\nBest model validation metrics:") to avoid misleading use
of formatted strings.
- Around line 6-12: Split combined imports and sort them to satisfy linters:
replace the single-line "import json, pathlib" with two separate imports and
order all imports alphabetically by module group; ensure standard library
imports (json, pathlib) come first, then third-party imports (matplotlib, numpy,
scipy), and within those groups import submodules separately (e.g., "import
numpy as np", "import matplotlib.pyplot as plt", "import matplotlib.patches as
mpatches", "from matplotlib.cm import ScalarMappable", "from matplotlib.colors
import Normalize", "from scipy import stats") so that tools/flake8 and
isort/I001 no longer complain.
- Around line 52-53: The semicolon-separated multiple statements in
tools/plot_sweep.py (e.g., the ax configuration line containing ax.set_xlabel,
ax.set_ylabel, ax.set_title, ax.set_xticks and ax.grid) violate E702; split each
semicolon-separated call into its own line so each statement stands alone (for
example separate ax.set_xlabel("Trial #"), ax.set_ylabel(...),
ax.set_title(...), ax.set_xticks(ids), and ax.grid(...) into five lines). Apply
the same change to all other flagged locations (the grouped statements around
lines with multiple semicolon-separated calls at the noted spots) so every
function call or assignment is on its own line.

---

Nitpick comments:
In `@pyproject.toml`:
- Line 71: Update the optuna dependency in pyproject.toml from "optuna>=2.10.1"
to "optuna>=3.0.0"; then run tests and CI and search for usages of the optuna
API (imports, Study/TPESampler/optimize calls) to adjust any breaking changes
introduced in Optuna 3.x (update argument names or sampler constructors if
needed) so the project remains compatible.

In `@simba/commands/sweep_train.py`:
- Line 298: The compose(config_name="config", overrides=inf_overrides) call
inside sweep_train should be wrapped in a hydra.initialize_config_dir context
for consistency with other commands; import initialize_config_dir from hydra and
compute config_dir from the simba package path (Path(simba.__file__).parent /
"configs"), then use with initialize_config_dir(config_dir=config_dir,
version_base=None): to call compose() so inf_cfg is created inside that context;
update the sweep_train function to perform this change around the existing
compose() invocation.

In `@test_all_commands.sh`:
- Line 96: Every occurrence of the DEVICE variable is unquoted when used as a
command argument (e.g., the assignment/hint "hardware.accelerator=$DEVICE" and
other uses at the mentioned locations), which can break when DEVICE contains
spaces or glob characters; update every expansion of DEVICE to use quoted form
("$DEVICE") so arguments are passed safely and consistently across the
script—search for the symbol DEVICE and replace unquoted $DEVICE usages with
"$DEVICE" in each command/assignment.

In `@tests/integration/test_sweep_pipeline.py`:
- Around line 1-6: Add an integration test (e.g., test_failed_trial_handling)
that mocks the training step to raise an exception for one trial, then runs the
sweep pipeline and asserts that the failed trial is recorded with status
"failed" in the persisted trials.json and that the sweep continues to execute
subsequent trials (verify remaining trials complete or are enqueued as
expected); specifically mock the training/execute function used by the sweep
orchestration, trigger a failure for a named trial, run the same pipeline
entrypoint used by other tests, then read the trials persistence and assert
presence of a trial record with status: "failed" and that resume
semantics/enqueueing of starting_params behave the same as in successful runs.

In `@tools/plot_sweep.py`:
- Around line 18-29: The code assumes trials.json contains trials and uses
np.argmin on values which will raise on empty/invalid data; after loading trials
(variable trials) validate it's a non-empty list and that values (derived from
t["value"] for t in trials) contains at least one numeric entry, otherwise
handle the error path (e.g., log/raise a clear error or set BEST_IDX to None and
skip downstream computations). Update the block that computes ids, values,
params, BASELINE_IDX, BEST_IDX, and running_best to perform these checks and
fail fast with a descriptive message or safe defaults if trials or values are
empty or malformed.
- Around line 14-16: The script currently hardcodes ROOT and PLOTDIR (ROOT =
pathlib.Path(...); PLOTDIR = ROOT / "plots") which forces manual edits per
sweep; change the script to accept the sweep results path as a CLI argument
(e.g., with argparse) and set ROOT = pathlib.Path(args.root) with a sensible
default matching the current hardcoded path, then compute PLOTDIR = ROOT /
"plots" and create it with PLOTDIR.mkdir(parents=True, exist_ok=True); update
any references to ROOT/PLOTDIR accordingly so callers can override via CLI.

In `@tools/slurm/hyperparam_search.slurm.sh`:
- Around line 79-82: The completion message currently echoes "Sweep done!
$N_TRIALS trials completed" which can be inaccurate if some trials failed;
update the script around the echo lines to compute the actual completed count
from the results file referenced by ${DB} (e.g., parse trials.json) and then
print a corrected summary like "Sweep done! X of $N_TRIALS trials completed (Y
failed)"; specifically, replace or augment the echo that uses $N_TRIALS with a
command that counts entries where status=='completed' and optionally
status!='completed' from ${DB} and include those counts in the final message and
the "Best" lookup flow so that the selection (the python -c line) still filters
completed trials only.
- Around line 25-27: Replace the hardcoded absolute path used in the cd command
with a parameterized directory: define a variable (e.g., BASE_DIR) that defaults
to $SLURM_SUBMIT_DIR if not provided and use that variable for the cd step; keep
the mkdir -p logs and source .venv/bin/activate lines unchanged but ensure paths
are relative to BASE_DIR (reference the cd invocation and the subsequent .venv
activation to locate where to apply the variable).
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: ee1e65da-e30c-49fc-aabc-c26a2621fca3

📥 Commits

Reviewing files that changed from the base of the PR and between b400dba and c1716c6.

📒 Files selected for processing (13)
  • .gitignore
  • README.md
  • pyproject.toml
  • simba/commands/sweep_train.py
  • simba/commands/train.py
  • simba/configs/sweep/default.yaml
  • simba/configs/training/default.yaml
  • simba/workflows/training.py
  • test_all_commands.sh
  • tests/integration/test_sweep_pipeline.py
  • tests/unit/test_sweep_utils.py
  • tools/plot_sweep.py
  • tools/slurm/hyperparam_search.slurm.sh

Comment thread pyproject.toml
Comment thread README.md Outdated
Comment thread simba/commands/sweep_train.py Outdated
Comment thread test_all_commands.sh
Comment thread tests/unit/test_sweep_utils.py Outdated
Comment thread tests/unit/test_sweep_utils.py
Comment thread tools/plot_sweep.py Outdated
Comment thread tools/plot_sweep.py Outdated
Comment thread tools/plot_sweep.py Outdated
Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 5

♻️ Duplicate comments (1)
simba/commands/sweep_train.py (1)

282-288: ⚠️ Potential issue | 🟠 Major

Guard the best-trial summary when nothing completed.

A full-failure sweep is already represented in trials.json, but study.best_trial still raises ValueError in that case. Check for at least one COMPLETE trial before reading best_trial, and skip the summary / inference step gracefully when there are none.

#!/bin/bash
set -euo pipefail

python - <<'PY'
import importlib.util
import subprocess
import sys

if importlib.util.find_spec("optuna") is None:
    subprocess.check_call([sys.executable, "-m", "pip", "install", "--quiet", "optuna"])

import optuna

study = optuna.create_study(direction="minimize")
trial = study.ask()
study.tell(trial, state=optuna.trial.TrialState.FAIL)

try:
    study.best_trial
except Exception as exc:
    print(type(exc).__name__, str(exc))
PY
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@simba/commands/sweep_train.py` around lines 282 - 288, The summary block
currently accesses study.best_trial which raises ValueError if no trials
completed; before reading study.best_trial, check study.trials for any trial
with state == optuna.trial.TrialState.COMPLETE (e.g., any(t.state ==
optuna.trial.TrialState.COMPLETE for t in study.trials)); if none are COMPLETE,
skip printing the best-trial summary and instead print a brief message that the
sweep completed with no successful trials (still print n_trials and
trials_file), otherwise proceed to read study.best_trial and print the existing
best params/value info. Ensure you reference study.trials,
optuna.trial.TrialState.COMPLETE, study.best_trial, n_trials and trials_file in
the change.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@simba/commands/sweep_train.py`:
- Around line 291-296: The current guard uses save_checkpoints (from
OmegaConf.select) to skip post-sweep best-model inference, but setup_callbacks()
always creates the best-model checkpoint callback so periodic checkpoints flag
is the wrong condition; remove or replace the if-block that checks
save_checkpoints and instead gate post-sweep inference on the actual
best-checkpoint presence or a dedicated config flag: either always run the
inference unless there is no best checkpoint file (check the callback/best model
path produced by setup_callbacks() or existence of the best checkpoint file) or
introduce a new config like checkpoints.run_best_inference and use that. Ensure
references to save_checkpoints, OmegaConf.select(...
"checkpoints.save_checkpoints"), setup_callbacks(), and the best-model
callback/best checkpoint path are updated accordingly so
best_inference/metrics.json is produced when a best checkpoint exists.
- Around line 222-230: The current resume logic re-adds only completed trials
via the previous_trials loop and relies on Optuna-assigned trial.number, causing
ID collisions with checkpoints and wrong best-trial inference; to fix it,
persist the Study with Optuna storage (use persistent storage when creating the
study instead of in-memory) or explicitly preserve a stable artifact id per
persisted trial (e.g., store original prev["number"] or set a user_attr like
"artifact_id" when reconstructing trials instead of relying on trial.number),
stop depending on trial.number for checkpoint paths (update best_trial_dir code
that uses best.number to use the preserved artifact_id), guard accesses to
study.best_trial and checkpoint inference (wrap study.best_trial lookups with a
check for any completed trials from the persisted JSON or check for existence of
checkpoint files) and change the inference skip condition to check actual
checkpoint existence or completed-trials presence rather than
checkpoints.save_checkpoints; update places referenced by previous_trials loop,
study.add_trial, trial_number persistence, best_trial_dir creation, and the
checkpoint/inference gating logic (and the training.py best-model callback
behavior) accordingly.

In `@tools/plot_sweep.py`:
- Around line 27-39: The code currently uses all rows (including failed trials
with value=inf) and treats BASELINE_IDX/BEST_IDX as x positions, which breaks
plotting when trial_number != row index; filter the trials list to only include
completed trials (e.g., status != "failed" or value != math.inf) before building
ids, values, params, running_best and PARAM_KEYS, compute BEST_IDX as the argmin
over the filtered values (or keep BEST_IDX as the filtered index but always map
plot x positions via ids[idx]), and use the filtered min/max values for
Normalize(vmin=..., vmax=...); in short, derive plotting series from
completed_trials and always use ids[idx] whenever an x-coordinate or highlight
position is needed (and update BASELINE_IDX/BEST_IDX usage to refer to indices
in that filtered series).
- Around line 18-25: ROOT is hard-coded and file I/O happens at import; change
the script to accept the run directory as a CLI argument (e.g., via argparse)
and derive PLOTDIR = run_dir / "plots" instead of the fixed path, create PLOTDIR
and load ROOT / "trials.json" and ROOT / "best_inference" / "metrics.json" only
inside a main function (or under if __name__ == "__main__") so that variables
like ROOT, PLOTDIR, trials, and inf are initialized at runtime rather than
import time.
- Around line 35-36: The code assumes a fixed five-parameter numeric schema by
building PARAM_KEYS from params[0].keys() and hard-coding PARAM_SHORT, then
forcing columns through dtype=float and log10; instead, derive parameter labels
and types from the sweep metadata (sweep.params) or validate the schema before
numeric conversion: read explicit param names/types from sweep.params to build
PARAM_KEYS/PARAM_SHORT dynamically, detect categorical vs numeric and
skip/encode categoricals (do not cast them to float or apply log10), and if you
truly require exactly five numeric params, validate that sweep.params contains
exactly those numeric keys and abort with a clear error message; update every
place that currently casts to float/log10 to use the validated/detected types.

---

Duplicate comments:
In `@simba/commands/sweep_train.py`:
- Around line 282-288: The summary block currently accesses study.best_trial
which raises ValueError if no trials completed; before reading study.best_trial,
check study.trials for any trial with state == optuna.trial.TrialState.COMPLETE
(e.g., any(t.state == optuna.trial.TrialState.COMPLETE for t in study.trials));
if none are COMPLETE, skip printing the best-trial summary and instead print a
brief message that the sweep completed with no successful trials (still print
n_trials and trials_file), otherwise proceed to read study.best_trial and print
the existing best params/value info. Ensure you reference study.trials,
optuna.trial.TrialState.COMPLETE, study.best_trial, n_trials and trials_file in
the change.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: e7549701-d876-451d-b8e9-d77a96bc3778

📥 Commits

Reviewing files that changed from the base of the PR and between c1716c6 and 3231b4d.

📒 Files selected for processing (6)
  • simba/commands/sweep_train.py
  • simba/commands/train.py
  • simba/workflows/training.py
  • tests/integration/test_sweep_pipeline.py
  • tests/unit/test_sweep_utils.py
  • tools/plot_sweep.py
🚧 Files skipped from review as they are similar to previous changes (3)
  • tests/integration/test_sweep_pipeline.py
  • simba/workflows/training.py
  • tests/unit/test_sweep_utils.py

Comment thread simba/commands/sweep_train.py
Comment thread simba/commands/sweep_train.py
Comment thread tools/plot_sweep.py
Comment thread tools/plot_sweep.py
Comment thread tools/plot_sweep.py
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant