diff --git a/scripts/train_crypto_lora_sweep.py b/scripts/train_crypto_lora_sweep.py index b51845f2..3c3b3523 100644 --- a/scripts/train_crypto_lora_sweep.py +++ b/scripts/train_crypto_lora_sweep.py @@ -53,6 +53,18 @@ def consistency_score(self) -> float: return self.mae_percent_mean + 0.5 * self.mae_percent_std + 0.3 * (self.mae_percent_max - self.mae_percent_mean) +def resolve_data_path(symbol: str, data_root: Path) -> Path: + """Resolve the CSV data path for *symbol* within *data_root*. + + Checks ``data_root/stocks/symbol.csv`` first (common mixed-hourly layout), + then falls back to ``data_root/symbol.csv``. + """ + stocks_candidate = data_root / "stocks" / f"{symbol}.csv" + if stocks_candidate.exists(): + return stocks_candidate + return data_root / f"{symbol}.csv" + + def load_hourly_frame(csv_path: Path) -> pd.DataFrame: df = pd.read_csv(csv_path) df["timestamp"] = pd.to_datetime(df["timestamp"], utc=True, errors="coerce") diff --git a/tests/conftest.py b/tests/conftest.py index b8d04919..131cec70 100755 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -336,6 +336,9 @@ def pytest_ignore_collect(collection_path, config): "tests/test_chronos2_real_data.py", "tests/test_critical_math.py", "tests/test_forecast_cache_metrics.py", + "tests/test_jax_losses.py", + "tests/test_jax_policy.py", + "tests/test_jax_trainer_wandboard.py", "tests/test_maxdiff_price_cache.py", "tests/test_pctdiff_price_cache.py", "tests/test_rl_trainingbinance.py",