diff --git a/.gitignore b/.gitignore
index 7cf0788..5b28a28 100644
--- a/.gitignore
+++ b/.gitignore
@@ -156,4 +156,6 @@ files/ # Ignore specific directory if generated
# OS generated files
.DS_Store
-Thumbs.db
\ No newline at end of file
+Thumbs.db
+.trieye_data
+.trieye_data/
\ No newline at end of file
diff --git a/.resumed.txt b/.resumed.txt
new file mode 100644
index 0000000..ad8313b
--- /dev/null
+++ b/.resumed.txt
@@ -0,0 +1,8640 @@
+File: LICENSE
+MIT License
+
+Copyright (c) 2025 Luis Guilherme P. M.
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
+
+File: requirements.txt
+trianglengin>=2.0.7
+trimcts>=1.2.1
+trieye>=0.1.3 # Added Trieye
+numpy>=1.20.0
+torch>=2.0.0
+torchvision>=0.11.0
+cloudpickle>=2.0.0
+numba>=0.55.0
+mlflow>=1.20.0
+ray[default]>=2.8.0 # Keep full ray
+pydantic>=2.0.0
+typing_extensions>=4.0.0
+typer[all]>=0.9.0
+tensorboard>=2.10.0
+rich>=13.0.0
+
+File: pyproject.toml
+# File: pyproject.toml
+[build-system]
+requires = ["setuptools>=61.0"]
+build-backend = "setuptools.build_meta"
+
+[project]
+name = "alphatriangle"
+version = "1.9.0" # Incremented version for Trieye integration
+authors = [{ name="Luis Guilherme P. M.", email="lgpelin92@gmail.com" }]
+description = "AlphaZero implementation for a triangle puzzle game (uses trianglengin v2+, trimcts, and Trieye for stats/persistence)." # Updated description
+readme = "README.md"
+license = { file="LICENSE" }
+requires-python = ">=3.10"
+classifiers = [
+ "Programming Language :: Python :: 3",
+ "Programming Language :: Python :: 3.10",
+ "Programming Language :: Python :: 3.11",
+ "License :: OSI Approved :: MIT License",
+ "Operating System :: OS Independent",
+ "Topic :: Scientific/Engineering :: Artificial Intelligence",
+ "Topic :: Games/Entertainment :: Puzzle Games",
+ "Development Status :: 4 - Beta",
+]
+dependencies = [
+ # --- Core Dependencies ---
+ "trianglengin>=2.0.7",
+ "trimcts>=1.2.1",
+ # --- Stats & Persistence (NEW) ---
+ "trieye>=0.1.3", # Added Trieye
+ # --- RL/ML specific dependencies ---
+ "numpy>=1.20.0",
+ "torch>=2.0.0",
+ "torchvision>=0.11.0",
+ "cloudpickle>=2.0.0", # Still needed by Trieye/Ray
+ "numba>=0.55.0",
+ "mlflow>=1.20.0", # Still needed by Trieye
+ "ray[default]", # Still needed by Trieye and workers
+ "pydantic>=2.0.0", # Still needed by Trieye and configs
+ "typing_extensions>=4.0.0", # Still needed by Trieye and configs
+ "typer[all]>=0.9.0",
+ "tensorboard>=2.10.0", # Still needed by Trieye
+ # --- CLI Enhancement ---
+ "rich>=13.0.0",
+]
+
+[project.urls]
+"Homepage" = "https://github.com/lguibr/alphatriangle"
+"Bug Tracker" = "https://github.com/lguibr/alphatriangle/issues"
+
+[project.scripts]
+alphatriangle = "alphatriangle.cli:app"
+
+[project.optional-dependencies]
+dev = [
+ "pytest>=7.0.0",
+ "pytest-cov>=3.0.0",
+ "pytest-mock>=3.0.0",
+ "ruff",
+ "mypy",
+ "build",
+ "twine",
+ "codecov",
+]
+
+[tool.setuptools.packages.find]
+# No 'where' needed, find searches from the project root by default
+
+[tool.setuptools.package-data]
+"*" = ["*.txt", "*.md", "*.json"]
+
+# --- Tool Configurations ---
+[tool.ruff]
+line-length = 88
+[tool.ruff.lint]
+select = ["E", "W", "F", "I", "UP", "B", "C4", "ARG", "SIM", "TCH", "PTH", "NPY"]
+ignore = ["E501", "B904"] # Ignore line length and raise from None for now
+[tool.ruff.format]
+quote-style = "double"
+[tool.mypy]
+python_version = "3.10"
+warn_return_any = true
+warn_unused_configs = true
+ignore_missing_imports = true
+# Add explicit Any disallowance if desired for stricter checking
+# disallow_any_unimported = true
+# disallow_any_expr = true
+# disallow_any_decorated = true
+# disallow_any_explicit = true
+# disallow_any_generics = true
+# disallow_subclassing_any = true
+
+[tool.pytest.ini_options]
+minversion = "7.0"
+addopts = "-ra -q --cov=alphatriangle --cov-report=term-missing"
+testpaths = ["tests"]
+[tool.coverage.run]
+omit = [
+ "alphatriangle/cli.py",
+ "alphatriangle/logging_config.py",
+ "alphatriangle/config/*", # Keep config omit
+ "alphatriangle/utils/types.py",
+ "alphatriangle/rl/types.py",
+ "*/__init__.py",
+ "*/README.md",
+ "run_*.py",
+ "alphatriangle/rl/self_play/mcts_helpers.py",
+]
+[tool.coverage.report]
+fail_under = 45 # Lowered threshold
+show_missing = true
+
+File: analyze_profiles.py
+# File: analyze_profiles.py
+import logging
+import pstats
+from pathlib import Path
+from pstats import SortKey
+from typing import Annotated
+
+import typer
+
+# Configure logging for the script itself
+logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")
+logger = logging.getLogger(__name__)
+
+app = typer.Typer(
+ help="Analyzes cProfile output files (.prof) generated by AlphaTriangle workers."
+)
+
+# Define annotations for Typer arguments/options
+ProfilePathArg = Annotated[
+ Path,
+ typer.Argument(
+ ..., # Ellipsis makes it a required argument
+ help="Path to the .prof file to analyze.",
+ exists=True,
+ file_okay=True,
+ dir_okay=False,
+ readable=True,
+ resolve_path=True,
+ ),
+]
+
+# --- Corrected Option Definition (Default in Function Signature) ---
+# Remove default from Option(), add it to the function parameter below
+NumLinesOption = Annotated[
+ int,
+ typer.Option("--lines", "-n", help="Number of lines to show."),
+]
+# --- End Correction ---
+
+
+@app.command()
+def analyze(
+ profile_path: ProfilePathArg, # Use the annotation
+ num_lines: NumLinesOption = 30, # Assign default value here
+):
+ """
+ Loads a .prof file and prints statistics sorted by cumulative and total time.
+ """
+ logger.info(f"Analyzing profile: {profile_path}")
+
+ try:
+ # Ensure profile_path is converted to string for pstats
+ p = pstats.Stats(str(profile_path))
+ except Exception as e:
+ logger.error(f"Error loading profile stats from {profile_path}: {e}")
+ # Add 'from e' to preserve original exception context
+ raise typer.Exit(code=1) from e
+
+ # Remove directory paths for cleaner output
+ p.strip_dirs()
+
+ print("\n" + "=" * 30)
+ print(f" Top {num_lines} Functions by Cumulative Time")
+ print("=" * 30)
+ try:
+ p.sort_stats(SortKey.CUMULATIVE).print_stats(num_lines)
+ except Exception as e:
+ logger.error(f"Error sorting/printing by cumulative time: {e}")
+
+ print("\n" + "=" * 30)
+ print(f" Top {num_lines} Functions by Total Internal Time")
+ print("=" * 30)
+ try:
+ p.sort_stats(SortKey.TIME).print_stats(num_lines) # SortKey.TIME is Tottime
+ except Exception as e:
+ logger.error(f"Error sorting/printing by total time: {e}")
+
+ logger.info(f"Finished analyzing {profile_path}")
+
+
+if __name__ == "__main__":
+ app()
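+
+# Usage sketch (hedged): analyze the 50 most expensive entries of a worker
+# profile produced by `alphatriangle train --profile`. The <run_name> and
+# <episode> placeholders stand in for the actual run/episode identifiers.
+#   python analyze_profiles.py \
+#       .trieye_data/alphatriangle/runs/<run_name>/profile_data/worker_0_ep_<episode>.prof -n 50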
+
+
+File: MANIFEST.in
+
+# File: MANIFEST.in
+include README.md
+include LICENSE
+include requirements.txt
+graft alphatriangle
+graft tests
+include .python-version
+include pyproject.toml
+global-exclude __pycache__
+global-exclude *.py[co]
+# Remove pruned directories
+prune alphatriangle/visualization
+prune alphatriangle/interaction
+# REMOVE MCTS pruning
+# prune alphatriangle/mcts
+# Remove Trieye-replaced directories
+prune alphatriangle/stats
+prune alphatriangle/data
+# Remove pruned files
+global-exclude alphatriangle/app.py
+# Remove pruned test directories
+prune tests/visualization
+prune tests/interaction
+# REMOVE MCTS test pruning
+# prune tests/mcts
+# Remove Trieye-replaced test directories
+prune tests/stats
+prune tests/data
+# Remove pruned core files
+global-exclude alphatriangle/rl/core/visual_state_actor.py
+# REMOVE test_save_resume.py
+global-exclude tests/training/test_save_resume.py
+
+File: README.md
+
+[CI/CD](https://github.com/lguibr/alphatriangle/actions/workflows/ci_cd.yml) - [Codecov](https://codecov.io/gh/lguibr/alphatriangle) - [PyPI](https://badge.fury.io/py/alphatriangle) - [License: MIT](https://opensource.org/licenses/MIT) - [Python](https://www.python.org/downloads/)
+
+# AlphaTriangle
+
+## Overview
+
+AlphaTriangle is a project implementing an artificial intelligence agent based on AlphaZero principles to learn and play a custom puzzle game involving placing triangular shapes onto a grid. The agent learns through **headless self-play reinforcement learning**, guided by Monte Carlo Tree Search (MCTS) and a deep neural network (PyTorch).
+
+**Key Features:**
+
+* **Core Game Logic:** Uses the [`trianglengin>=2.0.7`](https://github.com/lguibr/trianglengin) library for the triangle puzzle game rules and state management, featuring a high-performance C++ core.
+* **High-Performance MCTS:** Integrates the [`trimcts>=1.2.1`](https://github.com/lguibr/trimcts) library, providing a **C++ implementation of MCTS** for efficient search, callable from Python. MCTS parameters are configurable via `alphatriangle/config/mcts_config.py`.
+* **Deep Learning Model:** Features a PyTorch neural network with policy and distributional value heads, convolutional layers, and **optional Transformer Encoder layers**.
+* **Parallel Self-Play:** Leverages **Ray** for distributed self-play data generation across multiple CPU cores. **The number of workers automatically adjusts based on detected CPU cores (reserving some for stability), capped by the `NUM_SELF_PLAY_WORKERS` setting in `TrainConfig`.**
+* **Asynchronous Stats & Persistence (NEW):** Uses the [`trieye>=0.1.3`](https://github.com/lguibr/trieye) library, which provides a dedicated **Ray actor (`TrieyeActor`)** (an event-logging sketch follows this feature list) for:
+ * Asynchronous collection of raw metric events from workers and the training loop.
+ * Configurable processing and aggregation of metrics.
+ * Logging processed metrics to **MLflow** and **TensorBoard**.
+ * Saving/loading training state (checkpoints, buffers) to the filesystem and logging artifacts to MLflow.
+ * Handling auto-resumption from previous runs.
+ * All persistent data managed by `trieye` is stored within the `.trieye_data/` directory (default: `.trieye_data/alphatriangle`).
+* **Headless Training:** Focuses on a command-line interface for running the training pipeline without visual output.
+* **Enhanced CLI:** Uses the **`rich`** library for improved visual feedback (colors, panels, emojis) in the terminal.
+* **Centralized Logging:** Uses Python's standard `logging` module configured centrally for consistent log formatting (including `▲` prefix) and level control across the project. Run logs are saved to `.trieye_data/<app_name>/runs/<run_name>/logs/`.
+* **Optional Profiling:** Supports profiling worker 0 using `cProfile` via a command-line flag. Profile data is saved to `.trieye_data/<app_name>/runs/<run_name>/profile_data/`.
+* **Unit Tests:** Includes tests for RL components.
+
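+As referenced in the feature list above, the snippet below is an illustrative sketch (not code from this repo) of how a worker reports a raw metric to the `TrieyeActor`. The `RawMetricEvent` fields and the `log_event.remote` call mirror their usage in `tests/training/test_loop_integration.py`; the `trieye_actor` handle name is an assumption.
+
+```python
+from trieye.schemas import RawMetricEvent
+
+
+def report_score(trieye_actor, score: float, global_step: int) -> None:
+    """Fire-and-forget metric report; the actor aggregates events asynchronously."""
+    trieye_actor.log_event.remote(
+        RawMetricEvent(
+            name="current_score",
+            value=score,
+            global_step=global_step,
+            context={"game_step": 0},
+        )
+    )
+```
+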
+---
+
+## 🎮 The Triangle Puzzle Game Guide 🧩
+
+This project trains an agent to play the game defined by the `trianglengin` library. The rules are detailed in the [`trianglengin` README](https://github.com/lguibr/trianglengin/blob/main/README.md#the-ultimate-triangle-puzzle-guide-).
+
+---
+
+## Core Technologies
+
+* **Python 3.10+**
+* **trianglengin>=2.0.7:** Core game engine (state, actions, rules) with C++ optimizations.
+* **trimcts>=1.2.1:** High-performance C++ MCTS implementation with Python bindings.
+* **trieye>=0.1.3:** Asynchronous statistics collection, processing, logging (MLflow/TensorBoard), and data persistence via a Ray actor.
+* **PyTorch:** For the deep learning model (CNNs, **optional Transformers**, Distributional Value Head) and training, with CUDA/MPS support.
+* **NumPy:** For numerical operations, especially state representation (used by `trianglengin` and features).
+* **Ray:** For parallelizing self-play data generation and hosting the `TrieyeActor`. **Dynamically scales worker count based on available cores** (a minimal sketch of this heuristic follows this list).
+* **Numba:** (Optional, used in `features.grid_features`) For performance optimization of specific grid calculations.
+* **Cloudpickle:** For serializing the experience replay buffer and training checkpoints (used by `trieye`).
+* **MLflow:** For logging parameters, metrics, and artifacts (checkpoints, buffers) during training runs. **Provides the primary web UI dashboard for experiment management. Data stored in `.trieye_data/<app_name>/mlruns/`.** Managed by `trieye`.
+* **TensorBoard:** For visualizing metrics during training (e.g., detailed loss curves). **Provides a secondary web UI dashboard. Data stored in `.trieye_data/<app_name>/runs/<run_name>/tensorboard/`.** Managed by `trieye`.
+* **Pydantic:** For configuration management and data validation (used by `alphatriangle` and `trieye`).
+* **Typer:** For the command-line interface.
+* **Rich:** For enhanced CLI output formatting and styling.
+* **Pytest:** For running unit tests.
+
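+A minimal sketch of the worker-count heuristic described above, where the cap comes from `TrainConfig.NUM_SELF_PLAY_WORKERS`; the function name and the reserve margin are illustrative assumptions, not this repo's code:
+
+```python
+import multiprocessing
+
+
+def effective_worker_count(configured_max: int, reserved_cores: int = 2) -> int:
+    """Cap the configured worker count by detected cores, keeping a safety margin."""
+    available = max(1, multiprocessing.cpu_count() - reserved_cores)
+    return min(configured_max, available)
+```
+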
+## Project Structure
+
+```markdown
+.
+├── .github/workflows/              # GitHub Actions CI/CD
+│   └── ci_cd.yml
+├── .trieye_data/                   # Root directory for ALL persistent data (GITIGNORED) - Managed by Trieye
+│   └── alphatriangle/              # Default app_name
+│       ├── mlruns/                 # MLflow internal tracking data & artifact store (for MLflow UI)
+│       │   └── <experiment_id>/
+│       │       └── <run_id>/
+│       │           ├── artifacts/  # MLflow's copy of logged artifacts (checkpoints, buffers, etc.)
+│       │           ├── metrics/
+│       │           ├── params/
+│       │           └── tags/
+│       └── runs/                   # Local artifacts per run (source for TensorBoard UI & resume)
+│           └── <run_name>/         # e.g., train_YYYYMMDD_HHMMSS
+│               ├── checkpoints/    # Saved model weights & optimizer states (*.pkl)
+│               ├── buffers/        # Saved experience replay buffers (*.pkl)
+│               ├── logs/           # Plain text log files for the run (*.log) - App + Trieye logs
+│               ├── tensorboard/    # TensorBoard log files (event files)
+│               ├── profile_data/   # cProfile output files (*.prof) if profiling enabled
+│               └── configs.json    # Copy of run configuration
+├── alphatriangle/                  # Source code for the AlphaZero agent package
+│   ├── __init__.py
+│   ├── cli.py                      # CLI logic (train, ml, tb, ray commands - headless only, uses Rich)
+│   ├── config/                     # Pydantic configuration models (Model, Train, MCTS) - Stats/Persistence now in Trieye
+│   │   └── README.md
+│   ├── features/                   # Feature extraction logic (operates on trianglengin.GameState)
+│   │   └── README.md
+│   ├── logging_config.py           # Centralized logging setup function (for console)
+│   ├── nn/                         # Neural network definition and wrapper (implements trimcts.AlphaZeroNetworkInterface)
+│   │   └── README.md
+│   ├── rl/                         # RL components (Trainer, Buffer, Worker using trimcts)
+│   │   └── README.md
+│   ├── training/                   # Training orchestration (Loop, Setup, Runner, WorkerManager) - Interacts with TrieyeActor
+│   │   └── README.md
+│   └── utils/                      # Shared utilities and types (specific to AlphaTriangle)
+│       └── README.md
+├── tests/                          # Unit tests (for alphatriangle components, excluding MCTS, Stats, Data)
+│   ├── conftest.py
+│   ├── nn/
+│   ├── rl/
+│   └── training/
+├── .gitignore
+├── .python-version
+├── LICENSE                         # License file (MIT)
+├── MANIFEST.in                     # Specifies files for source distribution
+├── pyproject.toml                  # Build system & package configuration (depends on trianglengin, trimcts, trieye, rich)
+├── README.md                       # This file
+└── requirements.txt                # List of dependencies (includes trianglengin, trimcts, trieye, rich)
+```
+
+## Key Modules (`alphatriangle`)
+
+* **`cli`:** Defines the command-line interface using Typer (**`train`**, **`ml`**, **`tb`**, **`ray`** commands - headless). Uses **`rich`** for styling. ([`alphatriangle/cli.py`](alphatriangle/cli.py))
+* **`config`:** Centralized Pydantic configuration classes (Model, Train, **MCTS**). Imports `EnvConfig` from `trianglengin`. **Uses `TrieyeConfig` from `trieye` for stats/persistence.** ([`alphatriangle/config/README.md`](alphatriangle/config/README.md))
+* **`features`:** Contains logic to convert `trianglengin.GameState` objects into numerical features (`StateType`). ([`alphatriangle/features/README.md`](alphatriangle/features/README.md))
+* **`logging_config`:** Defines the `setup_logging` function for centralized **console** logger configuration. ([`alphatriangle/logging_config.py`](alphatriangle/logging_config.py))
+* **`nn`:** Contains the PyTorch `nn.Module` definition (`AlphaTriangleNet`) and a wrapper class (`NeuralNetwork`). **The `NeuralNetwork` class implicitly conforms to the `trimcts.AlphaZeroNetworkInterface` protocol** (a sketch of the inferred interface follows this list). ([`alphatriangle/nn/README.md`](alphatriangle/nn/README.md))
+* **`rl`:** Contains RL components: `Trainer` (network updates), `ExperienceBuffer` (data storage, **supports PER**), and `SelfPlayWorker` (Ray actor for parallel self-play **using `trimcts.run_mcts`**). **Workers now send `RawMetricEvent`s to the `TrieyeActor`.** ([`alphatriangle/rl/README.md`](alphatriangle/rl/README.md))
+* **`training`:** Orchestrates the **headless** training process using `TrainingLoop`, managing workers, data flow, and interaction with the **`TrieyeActor`** for logging, checkpointing, and state loading. Includes `runner.py` for the callable training function. ([`alphatriangle/training/README.md`](alphatriangle/training/README.md))
+* **`utils`:** Provides common helper functions and shared type definitions specific to the AlphaZero implementation. ([`alphatriangle/utils/README.md`](alphatriangle/utils/README.md))
+
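+The following is a hedged sketch of the evaluation interface the `NeuralNetwork` wrapper exposes, inferred from `tests/nn/test_network.py`; `trimcts`'s actual `AlphaZeroNetworkInterface` protocol may use different names or signatures:
+
+```python
+from typing import Protocol
+
+from trianglengin import GameState
+
+
+class AlphaZeroEvaluator(Protocol):
+    """Policy/value evaluation as exercised by the unit tests (illustrative)."""
+
+    def evaluate_state(self, state: GameState) -> tuple[dict[int, float], float]: ...
+
+    def evaluate_batch(
+        self, states: list[GameState]
+    ) -> list[tuple[dict[int, float], float]]: ...
+```
+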
+## Setup
+
+1. **Clone the repository (for development):**
+ ```bash
+ git clone https://github.com/lguibr/alphatriangle.git
+ cd alphatriangle
+ ```
+2. **Create a virtual environment (recommended):**
+ ```bash
+ python -m venv venv
+ source venv/bin/activate # On Windows use `venv\Scripts\activate`
+ ```
+3. **Install the package (including `trianglengin`, `trimcts`, `trieye`, and `rich`):**
+ * **For users:**
+ ```bash
+ # This will automatically install trianglengin, trimcts, trieye, and rich from PyPI if available
+ pip install alphatriangle
+ # Or install directly from Git (installs dependencies from PyPI)
+ # pip install git+https://github.com/lguibr/alphatriangle.git
+ ```
+ * **For developers (editable install):**
+ * First, ensure `trianglengin`, `trimcts`, and `trieye` are installed (ideally in editable mode from their own directories if developing all):
+ ```bash
+ # From the trianglengin directory (requires C++ build tools):
+ # pip install -e .
+ # From the trimcts directory (requires C++ build tools):
+ # pip install -e .
+ # From the trieye directory:
+ # pip install -e .
+ ```
+ * Then, install `alphatriangle` in editable mode:
+ ```bash
+ # From the alphatriangle directory:
+ pip install -e .
+ # Install dev dependencies (optional, for running tests/linting)
+ pip install -e .[dev] # Installs dev deps from pyproject.toml
+ ```
+ *Note: Ensure you have the correct PyTorch version installed for your system (CPU/CUDA/MPS). See [pytorch.org](https://pytorch.org/). Ray may have specific system requirements. `trianglengin` and `trimcts` require a C++ compiler (like GCC, Clang, or MSVC) and CMake.*
+4. **(Optional but Recommended) Add data directory to `.gitignore`:**
+ Ensure the `.gitignore` file in your project root contains the line:
+ ```
+ .trieye_data/
+ ```
+
+## Running the Code (CLI)
+
+Use the `alphatriangle` command for training and monitoring. The CLI uses `rich` for enhanced output.
+
+* **Show Help:**
+ ```bash
+ alphatriangle --help
+ ```
+* **Run Training (Headless Only):**
+ ```bash
+ alphatriangle train [--seed 42] [--log-level INFO] [--profile] [--run-name my_custom_run]
+ ```
+ * `--log-level`: Set console logging level (DEBUG, INFO, WARNING, ERROR, CRITICAL). Default: INFO. Logs are saved to `.trieye_data/alphatriangle/runs/<run_name>/logs/`.
+ * `--seed`: Set the random seed for reproducibility. Default: 42.
+ * `--profile`: Enable cProfile for worker 0. Generates `.prof` files in `.trieye_data/alphatriangle/runs/<run_name>/profile_data/`.
+ * `--run-name`: Specify a custom name for the run. Default: `train_YYYYMMDD_HHMMSS`.
+* **Launch MLflow UI:**
+ Launches the MLflow web interface, automatically pointing to the `.trieye_data/alphatriangle/mlruns` directory.
+ ```bash
+ alphatriangle ml [--host 127.0.0.1] [--port 5000]
+ ```
+ Access via `http://localhost:5000` (or the specified host/port).
+* **Launch TensorBoard UI:**
+ Launches the TensorBoard web interface, automatically pointing to the `.trieye_data/alphatriangle/runs` directory (which contains the individual run subdirectories with `tensorboard` logs).
+ ```bash
+ alphatriangle tb [--host 127.0.0.1] [--port 6006]
+ ```
+ Access via `http://localhost:6006` (or the specified host/port).
+* **Launch Ray Dashboard UI:**
+ Launches the Ray Dashboard web interface. **Note:** This typically requires a Ray cluster to be running (e.g., started by `alphatriangle train` or manually).
+ ```bash
+ alphatriangle ray [--host 127.0.0.1] [--port 8265]
+ ```
+ Access via `http://localhost:8265` (or the specified host/port).
+* **Interactive Play/Debug (Use `trianglengin` CLI):**
+ *Note: Interactive modes are part of the `trianglengin` library, not this `alphatriangle` package.*
+ ```bash
+ # Ensure trianglengin is installed
+ trianglengin play [--seed 42] [--log-level INFO]
+ trianglengin debug [--seed 42] [--log-level DEBUG]
+ ```
+* **Running Unit Tests (Development):**
+ ```bash
+ pytest tests/
+ ```
+* **Analyzing Profile Data (if `--profile` was used):**
+ Use the provided `analyze_profiles.py` script (requires `typer`).
+ ```bash
+ python analyze_profiles.py .trieye_data/alphatriangle/runs/<run_name>/profile_data/worker_0_ep_<episode>.prof [-n <num_lines>]
+ ```
+
+## Configuration
+
+* **AlphaTriangle Specific:** Parameters for the Model (`ModelConfig`), Training (`TrainConfig`), and MCTS (`AlphaTriangleMCTSConfig`) are defined in the Pydantic classes within the `alphatriangle/config/` directory.
+* **Environment:** Environment configuration (`EnvConfig`) is defined within the `trianglengin` library.
+* **Stats & Persistence:** Statistics logging and data persistence are configured via `TrieyeConfig` (which includes `StatsConfig` and `PersistenceConfig`) from the `trieye` library. These are typically instantiated in `alphatriangle/training/runner.py` or `alphatriangle/cli.py` (a minimal construction sketch follows).
+
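+A minimal construction sketch, mirroring how `TrieyeConfig` is built in the test suite (`TrieyeConfig(app_name=..., run_name=...)`); all other fields fall back to `trieye` defaults:
+
+```python
+from trieye import TrieyeConfig
+
+# Nested StatsConfig/PersistenceConfig use their defaults unless overridden.
+trieye_config = TrieyeConfig(app_name="alphatriangle", run_name="my_custom_run")
+```
+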
+## Data Storage
+
+All persistent data is now managed by the `trieye` library and stored within the `.trieye_data/<app_name>/` directory (default: `.trieye_data/alphatriangle/`) in the project root. This directory should be added to your `.gitignore`.
+
+* **`.trieye_data/<app_name>/mlruns/`**: Managed by **MLflow** (via `trieye`). Contains MLflow's internal tracking data (parameters, metrics) and its own copy of logged artifacts. This is the source for the MLflow UI (`alphatriangle ml`).
+* **`.trieye_data/<app_name>/runs/`**: Managed by **`trieye`**. Contains locally saved artifacts for each run (checkpoints, buffers, TensorBoard logs, configs, **profile data**) before/during logging to MLflow. This directory is used for auto-resuming and is the source for the TensorBoard UI (`alphatriangle tb`).
+ * **Replay Buffer Content:** The saved buffer file (`buffer.pkl`) contains `Experience` tuples: `(StateType, PolicyTargetMapping, n_step_return)`. The `StateType` includes:
+ * `grid`: Numerical features representing grid occupancy.
+ * `other_features`: Numerical features derived from game state and available shapes.
+ * **Visualization:** This stored data allows offline analysis and visualization of grid occupancy and the available shapes for each recorded step. It does **not** contain the full sequence of actions or raw `GameState` objects needed for a complete, interactive game replay (a hedged loading sketch follows this list).
+
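+A hedged sketch for offline inspection of a saved buffer. The exact pickle layout is defined by `trieye`'s serializer; here we assume the file deserializes to an object exposing a `buffer_list` of `Experience` tuples (an attribute name that appears in this repo's test mocks, but is otherwise an assumption):
+
+```python
+import cloudpickle
+
+buffer_path = ".trieye_data/alphatriangle/runs/<run_name>/buffers/buffer.pkl"
+with open(buffer_path, "rb") as f:
+    buffer_data = cloudpickle.load(f)
+
+# Each Experience is (StateType, PolicyTargetMapping, n_step_return).
+state, policy_target, n_step_return = buffer_data.buffer_list[0]
+print(state["grid"].shape, state["other_features"].shape, n_step_return)
+```
+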
+## Maintainability
+
+This project includes README files within each major `alphatriangle` submodule. **Please keep these READMEs updated** when making changes to the code's structure, interfaces, or core logic, especially regarding the interaction with the `trieye` library.
+
+File: .python-version
+3.10.13
+
+
+File: .trieye_data/AlphaTriangle/mlruns/891874414194323618/meta.yaml
+artifact_location: file:///Users/lg/lab/alphatriangle/.trieye_data/AlphaTriangle/mlruns/891874414194323618
+creation_time: 1745378337600
+experiment_id: '891874414194323618'
+last_update_time: 1745378337600
+lifecycle_stage: active
+name: AlphaTriangle
+
+
+File: .trieye_data/AlphaTriangle/mlruns/891874414194323618/37a98b962de74c60b73994a5886679cf/meta.yaml
+artifact_uri: file:///Users/lg/lab/alphatriangle/.trieye_data/AlphaTriangle/mlruns/891874414194323618/37a98b962de74c60b73994a5886679cf/artifacts
+end_time: 1745378340742
+entry_point_name: ''
+experiment_id: '891874414194323618'
+lifecycle_stage: active
+run_id: 37a98b962de74c60b73994a5886679cf
+run_name: run_20250423_001854
+run_uuid: 37a98b962de74c60b73994a5886679cf
+source_name: ''
+source_type: 4
+source_version: ''
+start_time: 1745378337673
+status: 3
+tags: []
+user_id: lg
+
+
+File: .trieye_data/AlphaTriangle/mlruns/891874414194323618/37a98b962de74c60b73994a5886679cf/artifacts/config/configs.json
+{
+ "app_name": "AlphaTriangle",
+ "run_name": "run_20250423_001854",
+ "persistence": {
+ "ROOT_DATA_DIR": ".trieye_data",
+ "RUNS_DIR_NAME": "runs",
+ "MLFLOW_DIR_NAME": "mlruns",
+ "CHECKPOINT_SAVE_DIR_NAME": "checkpoints",
+ "BUFFER_SAVE_DIR_NAME": "buffers",
+ "LOG_DIR_NAME": "logs",
+ "TENSORBOARD_DIR_NAME": "tensorboard",
+ "PROFILE_DIR_NAME": "profile_data",
+ "LATEST_CHECKPOINT_FILENAME": "latest.pkl",
+ "BEST_CHECKPOINT_FILENAME": "best.pkl",
+ "BUFFER_FILENAME": "buffer.pkl",
+ "CONFIG_FILENAME": "configs.json",
+ "SAVE_BUFFER": true,
+ "BUFFER_SAVE_FREQ_STEPS": 1000,
+ "MLFLOW_TRACKING_URI": "file:///Users/lg/lab/alphatriangle/.trieye_data/AlphaTriangle/mlruns"
+ },
+ "stats": {
+ "processing_interval_seconds": 1.0,
+ "metrics": []
+ }
+}
+
+File: .trieye_data/AlphaTriangle/mlruns/891874414194323618/37a98b962de74c60b73994a5886679cf/tags/mlflow.user
+lg
+
+File: .trieye_data/AlphaTriangle/mlruns/891874414194323618/37a98b962de74c60b73994a5886679cf/tags/mlflow.runName
+run_20250423_001854
+
+File: .trieye_data/AlphaTriangle/mlruns/891874414194323618/37a98b962de74c60b73994a5886679cf/tags/mlflow.source.name
+/Users/lg/lab/triii/.venv/lib/python3.10/site-packages/ray/_private/workers/default_worker.py
+
+File: .trieye_data/AlphaTriangle/mlruns/891874414194323618/37a98b962de74c60b73994a5886679cf/tags/mlflow.source.type
+LOCAL
+
+File: .trieye_data/AlphaTriangle/runs/run_20250423_001854/configs.json
+{
+ "app_name": "AlphaTriangle",
+ "run_name": "run_20250423_001854",
+ "persistence": {
+ "ROOT_DATA_DIR": ".trieye_data",
+ "RUNS_DIR_NAME": "runs",
+ "MLFLOW_DIR_NAME": "mlruns",
+ "CHECKPOINT_SAVE_DIR_NAME": "checkpoints",
+ "BUFFER_SAVE_DIR_NAME": "buffers",
+ "LOG_DIR_NAME": "logs",
+ "TENSORBOARD_DIR_NAME": "tensorboard",
+ "PROFILE_DIR_NAME": "profile_data",
+ "LATEST_CHECKPOINT_FILENAME": "latest.pkl",
+ "BEST_CHECKPOINT_FILENAME": "best.pkl",
+ "BUFFER_FILENAME": "buffer.pkl",
+ "CONFIG_FILENAME": "configs.json",
+ "SAVE_BUFFER": true,
+ "BUFFER_SAVE_FREQ_STEPS": 1000,
+ "MLFLOW_TRACKING_URI": "file:///Users/lg/lab/alphatriangle/.trieye_data/AlphaTriangle/mlruns"
+ },
+ "stats": {
+ "processing_interval_seconds": 1.0,
+ "metrics": []
+ }
+}
+
+
+File: tests/conftest.py
+# File: tests/conftest.py
+import random
+from typing import cast
+
+import numpy as np
+import pytest
+import torch
+import torch.optim as optim
+
+# Import from trianglengin's top level
+from trianglengin import EnvConfig
+
+# Import trimcts config
+from trimcts import SearchConfiguration
+
+# Keep alphatriangle imports
+from alphatriangle.config import (
+ ModelConfig,
+ TrainConfig,
+)
+from alphatriangle.nn import NeuralNetwork
+from alphatriangle.rl import ExperienceBuffer, Trainer
+from alphatriangle.utils.types import Experience, StateType
+
+# Removed PersistenceConfig, StatsConfig imports
+
+rng = np.random.default_rng()
+
+
+@pytest.fixture(scope="session")
+def mock_env_config() -> EnvConfig:
+ """Provides a default, *valid* trianglengin.EnvConfig for tests."""
+ rows = 3
+ cols = 3
+ playable_range = [(0, 3), (0, 3), (0, 3)]
+ return EnvConfig(
+ ROWS=rows,
+ COLS=cols,
+ PLAYABLE_RANGE_PER_ROW=playable_range,
+ NUM_SHAPE_SLOTS=1,
+ )
+
+
+@pytest.fixture(scope="session")
+def mock_model_config(mock_env_config: EnvConfig) -> ModelConfig:
+ """Provides a default ModelConfig compatible with mock_env_config (session-scoped)."""
+ action_dim_int = int(
+ mock_env_config.NUM_SHAPE_SLOTS * mock_env_config.ROWS * mock_env_config.COLS
+ )
+ expected_other_dim = 10
+
+ return ModelConfig(
+ GRID_INPUT_CHANNELS=1,
+ CONV_FILTERS=[4],
+ CONV_KERNEL_SIZES=[3],
+ CONV_STRIDES=[1],
+ CONV_PADDING=[1],
+ NUM_RESIDUAL_BLOCKS=0,
+ RESIDUAL_BLOCK_FILTERS=4,
+ USE_TRANSFORMER=False,
+ TRANSFORMER_DIM=16,
+ TRANSFORMER_HEADS=2,
+ TRANSFORMER_LAYERS=0,
+ TRANSFORMER_FC_DIM=32,
+ FC_DIMS_SHARED=[8],
+ POLICY_HEAD_DIMS=[action_dim_int],
+ VALUE_HEAD_DIMS=[1],
+ OTHER_NN_INPUT_FEATURES_DIM=expected_other_dim,
+ ACTIVATION_FUNCTION="ReLU",
+ USE_BATCH_NORM=True,
+ )
+
+
+@pytest.fixture(scope="session")
+def mock_train_config() -> TrainConfig:
+ """Provides a default TrainConfig for tests (session-scoped)."""
+ return TrainConfig(
+ BATCH_SIZE=4,
+ BUFFER_CAPACITY=100,
+ MIN_BUFFER_SIZE_TO_TRAIN=10,
+ USE_PER=False,
+ LOAD_CHECKPOINT_PATH=None, # Keep load paths for potential Trieye usage
+ LOAD_BUFFER_PATH=None,
+ AUTO_RESUME_LATEST=False, # Keep auto-resume for potential Trieye usage
+ DEVICE="cpu",
+ RANDOM_SEED=42,
+ NUM_SELF_PLAY_WORKERS=1,
+ WORKER_DEVICE="cpu",
+ WORKER_UPDATE_FREQ_STEPS=10,
+ OPTIMIZER_TYPE="Adam",
+ LEARNING_RATE=1e-3,
+ WEIGHT_DECAY=1e-4,
+ LR_SCHEDULER_ETA_MIN=1e-6,
+ POLICY_LOSS_WEIGHT=1.0,
+ VALUE_LOSS_WEIGHT=1.0,
+ ENTROPY_BONUS_WEIGHT=0.0,
+ CHECKPOINT_SAVE_FREQ_STEPS=50, # Keep checkpoint freq for loop logic
+ PER_ALPHA=0.6,
+ PER_BETA_INITIAL=0.4,
+ PER_BETA_FINAL=1.0,
+ PER_BETA_ANNEAL_STEPS=100,
+ PER_EPSILON=1e-5,
+ MAX_TRAINING_STEPS=200,
+ N_STEP_RETURNS=3,
+ GAMMA=0.99,
+ PROFILE_WORKERS=False,
+ RUN_NAME="pytest_default_run", # Keep run name
+ )
+
+
+# Removed mock_persistence_config fixture
+
+
+@pytest.fixture(scope="session")
+def mock_mcts_config() -> SearchConfiguration:
+ """Provides a default trimcts.SearchConfiguration for tests."""
+ return SearchConfiguration(
+ max_simulations=8,
+ max_depth=5,
+ cpuct=1.0,
+ dirichlet_alpha=0.3,
+ dirichlet_epsilon=0.25,
+ discount=1.0,
+ mcts_batch_size=4,
+ )
+
+
+@pytest.fixture(scope="session")
+def mock_state_type(
+ mock_model_config: ModelConfig, mock_env_config: EnvConfig
+) -> StateType:
+ """Creates a mock StateType dictionary with correct shapes."""
+ grid_shape = (
+ mock_model_config.GRID_INPUT_CHANNELS,
+ mock_env_config.ROWS,
+ mock_env_config.COLS,
+ )
+ other_shape = (mock_model_config.OTHER_NN_INPUT_FEATURES_DIM,)
+
+ return {
+ "grid": rng.random(grid_shape, dtype=np.float32),
+ "other_features": rng.random(other_shape, dtype=np.float32),
+ }
+
+
+@pytest.fixture(scope="session")
+def mock_experience(
+ mock_state_type: StateType, mock_env_config: EnvConfig
+) -> Experience:
+ """Creates a mock Experience tuple."""
+ action_dim = int(
+ mock_env_config.NUM_SHAPE_SLOTS * mock_env_config.ROWS * mock_env_config.COLS
+ )
+ policy_target = (
+ dict.fromkeys(range(action_dim), 1.0 / action_dim)
+ if action_dim > 0
+ else {0: 1.0}
+ )
+ value_target = random.uniform(-1, 1)
+ return (mock_state_type, policy_target, value_target)
+
+
+@pytest.fixture(scope="session")
+def mock_nn_interface(
+ mock_model_config: ModelConfig,
+ mock_env_config: EnvConfig,
+ mock_train_config: TrainConfig,
+) -> NeuralNetwork:
+ """Provides a NeuralNetwork instance with a mock model for testing."""
+ device = torch.device("cpu")
+ nn_interface = NeuralNetwork(
+ mock_model_config, mock_env_config, mock_train_config, device
+ )
+ return nn_interface
+
+
+@pytest.fixture(scope="session")
+def mock_trainer(
+ mock_nn_interface: NeuralNetwork,
+ mock_train_config: TrainConfig,
+ mock_env_config: EnvConfig,
+) -> Trainer:
+ """Provides a Trainer instance."""
+ return Trainer(mock_nn_interface, mock_train_config, mock_env_config)
+
+
+@pytest.fixture(scope="session")
+def mock_optimizer(mock_trainer: Trainer) -> optim.Optimizer:
+ """Provides the optimizer from the mock_trainer."""
+ return cast("optim.Optimizer", mock_trainer.optimizer)
+
+
+@pytest.fixture
+def mock_experience_buffer(mock_train_config: TrainConfig) -> ExperienceBuffer:
+ """Provides an ExperienceBuffer instance."""
+ return ExperienceBuffer(mock_train_config)
+
+
+@pytest.fixture
+def filled_mock_buffer(
+ mock_experience_buffer: ExperienceBuffer, mock_experience: Experience
+) -> ExperienceBuffer:
+ """Provides a buffer filled with some mock experiences."""
+ for i in range(mock_experience_buffer.min_size_to_train + 5):
+ state_copy: StateType = {
+ "grid": mock_experience[0]["grid"].copy() + i * 0.01,
+ "other_features": mock_experience[0]["other_features"].copy() + i * 0.01,
+ }
+ exp_copy: Experience = (state_copy, mock_experience[1], random.uniform(-1, 1))
+ mock_experience_buffer.add(exp_copy)
+ return mock_experience_buffer
+
+
+File: tests/__init__.py
+
+
+File: tests/nn/test_model.py
+# File: tests/nn/test_model.py
+import pytest
+import torch
+
+# Import EnvConfig from trianglengin's top level
+from trianglengin import EnvConfig # UPDATED IMPORT
+
+# Keep alphatriangle imports
+from alphatriangle.config import ModelConfig
+from alphatriangle.nn import AlphaTriangleNet
+
+
+# Use shared fixtures implicitly via pytest injection
+@pytest.fixture
+def env_config(mock_env_config: EnvConfig) -> EnvConfig: # Uses trianglengin.EnvConfig
+ return mock_env_config
+
+
+@pytest.fixture
+def model_config(mock_model_config: ModelConfig) -> ModelConfig:
+ return mock_model_config
+
+
+@pytest.fixture
+def model(model_config: ModelConfig, env_config: EnvConfig) -> AlphaTriangleNet:
+ """Provides an instance of the AlphaTriangleNet model."""
+ # Pass trianglengin.EnvConfig
+ return AlphaTriangleNet(model_config, env_config)
+
+
+def test_model_initialization(
+ model: AlphaTriangleNet, model_config: ModelConfig, env_config: EnvConfig
+):
+ """Test if the model initializes without errors."""
+ assert model is not None
+ # Calculate action_dim manually for comparison
+ action_dim_int = int(env_config.NUM_SHAPE_SLOTS * env_config.ROWS * env_config.COLS)
+ assert model.action_dim == action_dim_int
+ assert model.model_config.USE_TRANSFORMER == model_config.USE_TRANSFORMER
+ if model_config.USE_TRANSFORMER:
+ assert model.transformer_body is not None
+ else:
+ assert model.transformer_body is None
+
+
+def test_model_forward_pass(
+ model: AlphaTriangleNet, model_config: ModelConfig, env_config: EnvConfig
+):
+ """Test the forward pass with dummy input tensors."""
+ batch_size = 4
+ device = torch.device("cpu")
+ model.to(device)
+ model.eval()
+ # Calculate action_dim manually
+ action_dim_int = int(env_config.NUM_SHAPE_SLOTS * env_config.ROWS * env_config.COLS)
+
+ grid_shape = (
+ batch_size,
+ model_config.GRID_INPUT_CHANNELS,
+ env_config.ROWS,
+ env_config.COLS,
+ )
+ other_shape = (batch_size, model_config.OTHER_NN_INPUT_FEATURES_DIM)
+
+ dummy_grid = torch.randn(grid_shape, device=device)
+ dummy_other = torch.randn(other_shape, device=device)
+
+ with torch.no_grad():
+ policy_logits, value_logits = model(dummy_grid, dummy_other)
+
+ assert policy_logits.shape == (
+ batch_size,
+ action_dim_int,
+ ), f"Policy logits shape mismatch: {policy_logits.shape}"
+ assert value_logits.shape == (
+ batch_size,
+ model_config.NUM_VALUE_ATOMS,
+ ), f"Value logits shape mismatch: {value_logits.shape}"
+
+ assert policy_logits.dtype == torch.float32
+ assert value_logits.dtype == torch.float32
+
+
+@pytest.mark.parametrize(
+ "use_transformer", [False, True], ids=["CNN_Only", "CNN_Transformer"]
+)
+def test_model_forward_transformer_toggle(use_transformer: bool, env_config: EnvConfig):
+ """Test forward pass with transformer enabled/disabled."""
+ # Calculate action_dim manually
+ action_dim_int = int(env_config.NUM_SHAPE_SLOTS * env_config.ROWS * env_config.COLS)
+ model_config_test = ModelConfig(
+ GRID_INPUT_CHANNELS=1,
+ CONV_FILTERS=[4, 8],
+ CONV_KERNEL_SIZES=[3, 3],
+ CONV_STRIDES=[1, 1],
+ CONV_PADDING=[1, 1],
+ NUM_RESIDUAL_BLOCKS=0,
+ RESIDUAL_BLOCK_FILTERS=8,
+ USE_TRANSFORMER=use_transformer,
+ TRANSFORMER_DIM=16,
+ TRANSFORMER_HEADS=2,
+ TRANSFORMER_LAYERS=1,
+ TRANSFORMER_FC_DIM=32,
+ FC_DIMS_SHARED=[16],
+ POLICY_HEAD_DIMS=[action_dim_int], # Use calculated int
+ OTHER_NN_INPUT_FEATURES_DIM=10,
+ ACTIVATION_FUNCTION="ReLU",
+ USE_BATCH_NORM=True,
+ )
+ # Pass trianglengin.EnvConfig
+ model = AlphaTriangleNet(model_config_test, env_config)
+ batch_size = 2
+ device = torch.device("cpu")
+ model.to(device)
+ model.eval()
+
+ grid_shape = (
+ batch_size,
+ model_config_test.GRID_INPUT_CHANNELS,
+ env_config.ROWS,
+ env_config.COLS,
+ )
+ other_shape = (batch_size, model_config_test.OTHER_NN_INPUT_FEATURES_DIM)
+ dummy_grid = torch.randn(grid_shape, device=device)
+ dummy_other = torch.randn(other_shape, device=device)
+
+ with torch.no_grad():
+ policy_logits, value_logits = model(dummy_grid, dummy_other)
+
+ assert policy_logits.shape == (batch_size, action_dim_int)
+ assert value_logits.shape == (batch_size, model_config_test.NUM_VALUE_ATOMS)
+
+
+File: tests/nn/__init__.py
+
+
+File: tests/nn/test_network.py
+# File: tests/nn/test_network.py
+import numpy as np
+import pytest
+import torch
+
+# Import GameState and EnvConfig from trianglengin's top level
+from trianglengin import EnvConfig, GameState
+
+# Keep alphatriangle imports
+from alphatriangle.config import ModelConfig, TrainConfig
+from alphatriangle.features import extract_state_features
+from alphatriangle.nn import NeuralNetwork
+from alphatriangle.nn.network import NetworkEvaluationError
+from alphatriangle.utils.types import StateType
+
+
+# Use shared fixtures implicitly via pytest injection
+@pytest.fixture
+def env_config(mock_env_config: EnvConfig) -> EnvConfig: # Uses trianglengin.EnvConfig
+ return mock_env_config
+
+
+@pytest.fixture
+def model_config(mock_model_config: ModelConfig) -> ModelConfig:
+ return mock_model_config
+
+
+@pytest.fixture
+def train_config(mock_train_config: TrainConfig) -> TrainConfig:
+ # Explicitly disable compilation for tests unless specifically testing compilation
+ cfg = mock_train_config.model_copy(deep=True)
+ cfg.COMPILE_MODEL = False
+ return cfg
+
+
+@pytest.fixture
+def nn_interface(
+ model_config: ModelConfig,
+ env_config: EnvConfig, # Uses trianglengin.EnvConfig
+ train_config: TrainConfig, # Use the modified train_config fixture
+) -> NeuralNetwork:
+ """Provides a NeuralNetwork instance for testing."""
+ device = torch.device("cpu")
+ # Pass trianglengin.EnvConfig and the modified train_config
+ nn_interface_instance = NeuralNetwork(
+ model_config, env_config, train_config, device
+ )
+ # Ensure model is in eval mode for consistency, although set_weights should handle it
+ nn_interface_instance.model.eval()
+ nn_interface_instance._orig_model.eval()
+ return nn_interface_instance
+
+
+@pytest.fixture
+def game_state(env_config: EnvConfig) -> GameState: # Uses trianglengin.GameState
+ """Provides a fresh GameState instance."""
+ # Pass trianglengin.EnvConfig
+ return GameState(config=env_config, initial_seed=123)
+
+
+def test_network_initialization(
+ nn_interface: NeuralNetwork,
+ model_config: ModelConfig,
+ env_config: EnvConfig, # Uses trianglengin.EnvConfig
+):
+ """Test if the NeuralNetwork wrapper initializes correctly."""
+ assert nn_interface.model is not None
+ assert nn_interface.device == torch.device("cpu")
+ assert nn_interface.model_config == model_config
+ assert nn_interface.env_config == env_config # Check env_config storage
+ # Calculate action_dim manually for comparison
+ action_dim_int = int(env_config.NUM_SHAPE_SLOTS * env_config.ROWS * env_config.COLS)
+ assert nn_interface.action_dim == action_dim_int
+
+
+def test_state_to_tensors(
+ nn_interface: NeuralNetwork,
+ game_state: GameState, # Uses trianglengin.GameState
+ model_config: ModelConfig,
+ env_config: EnvConfig, # Uses trianglengin.EnvConfig
+):
+ """Test the conversion of a GameState to tensors."""
+ grid_tensor, other_tensor = nn_interface._state_to_tensors(game_state)
+
+ assert isinstance(grid_tensor, torch.Tensor)
+ assert isinstance(other_tensor, torch.Tensor)
+
+ assert grid_tensor.shape == (
+ 1,
+ model_config.GRID_INPUT_CHANNELS,
+ env_config.ROWS,
+ env_config.COLS,
+ )
+ assert other_tensor.shape == (1, model_config.OTHER_NN_INPUT_FEATURES_DIM)
+
+ assert grid_tensor.device == nn_interface.device
+ assert other_tensor.device == nn_interface.device
+
+
+def test_batch_states_to_tensors(
+ nn_interface: NeuralNetwork,
+ game_state: GameState, # Uses trianglengin.GameState
+ model_config: ModelConfig,
+ env_config: EnvConfig, # Uses trianglengin.EnvConfig
+):
+ """Test the conversion of a batch of GameStates to tensors."""
+ batch_size = 3
+ states = [game_state.copy() for _ in range(batch_size)]
+ grid_tensor, other_tensor = nn_interface._batch_states_to_tensors(states)
+
+ assert isinstance(grid_tensor, torch.Tensor)
+ assert isinstance(other_tensor, torch.Tensor)
+
+ assert grid_tensor.shape == (
+ batch_size,
+ model_config.GRID_INPUT_CHANNELS,
+ env_config.ROWS,
+ env_config.COLS,
+ )
+ assert other_tensor.shape == (batch_size, model_config.OTHER_NN_INPUT_FEATURES_DIM)
+
+ assert grid_tensor.device == nn_interface.device
+ assert other_tensor.device == nn_interface.device
+
+
+def test_evaluate_state(
+ nn_interface: NeuralNetwork,
+ game_state: GameState, # Uses trianglengin.GameState
+ env_config: EnvConfig, # Uses trianglengin.EnvConfig
+):
+ """Test evaluating a single GameState."""
+ policy_dict, value = nn_interface.evaluate_state(game_state)
+
+ assert isinstance(policy_dict, dict)
+ assert isinstance(value, float)
+ # Calculate action_dim manually for comparison
+ action_dim_int = int(env_config.NUM_SHAPE_SLOTS * env_config.ROWS * env_config.COLS)
+ assert len(policy_dict) == action_dim_int
+ assert all(isinstance(k, int) for k in policy_dict)
+ assert all(isinstance(v, float) for v in policy_dict.values())
+ assert abs(sum(policy_dict.values()) - 1.0) < 1e-5
+
+
+def test_evaluate_batch(
+ nn_interface: NeuralNetwork,
+ game_state: GameState, # Uses trianglengin.GameState
+ env_config: EnvConfig, # Uses trianglengin.EnvConfig
+):
+ """Test evaluating a batch of GameStates."""
+ batch_size = 3
+ states = [game_state.copy() for _ in range(batch_size)]
+ results = nn_interface.evaluate_batch(states)
+
+ assert isinstance(results, list)
+ assert len(results) == batch_size
+
+ # Calculate action_dim manually for comparison
+ action_dim_int = int(env_config.NUM_SHAPE_SLOTS * env_config.ROWS * env_config.COLS)
+ for policy_dict, value in results:
+ assert isinstance(policy_dict, dict)
+ assert isinstance(value, float)
+ assert len(policy_dict) == action_dim_int
+ assert all(isinstance(k, int) for k in policy_dict)
+ assert all(isinstance(v, float) for v in policy_dict.values())
+ assert abs(sum(policy_dict.values()) - 1.0) < 1e-5
+
+
+def test_get_set_weights(nn_interface: NeuralNetwork):
+ """Test getting and setting model weights."""
+ # Ensure model is on CPU for this test
+ nn_interface._orig_model.cpu()
+ nn_interface.device = torch.device("cpu")
+
+ initial_weights_state_dict = nn_interface.get_weights()
+ assert isinstance(initial_weights_state_dict, dict)
+ assert all(isinstance(v, torch.Tensor) for v in initial_weights_state_dict.values())
+ assert all(
+ v.device == torch.device("cpu") for v in initial_weights_state_dict.values()
+ )
+
+ # Create a deep copy of initial weights for comparison later
+ initial_weights_copy = {
+ k: v.clone().detach() for k, v in initial_weights_state_dict.items()
+ }
+
+ # Modify weights: Create a new state dict with zeros for float params
+ modified_weights_state_dict = {}
+ params_modified = False
+ float_keys_modified = []
+
+ for k, v in initial_weights_state_dict.items():
+ if v.dtype.is_floating_point:
+ modified_weights_state_dict[k] = torch.zeros_like(v)
+ params_modified = True
+ float_keys_modified.append(k)
+ # Check if modification actually happened (unless initial was already zero)
+ if not torch.all(initial_weights_state_dict[k] == 0):
+ assert not torch.equal(
+ initial_weights_state_dict[k], modified_weights_state_dict[k]
+ ), f"Weight {k} did not change after creating zeros_like!"
+ else:
+ # Keep non-float tensors (buffers) unchanged, ensure they are detached copies
+ modified_weights_state_dict[k] = v.clone().detach()
+
+ assert params_modified, "No floating point parameters found to modify in state_dict"
+
+ # --- Set the modified weights ---
+ nn_interface.set_weights(modified_weights_state_dict)
+
+ # --- Direct Check: Compare model parameters immediately after setting ---
+ direct_check_passed = True
+ mismatched_params = []
+ with torch.no_grad():
+ # Use named_parameters which only includes trainable parameters
+ for name, param in nn_interface._orig_model.named_parameters():
+ if name in modified_weights_state_dict:
+ expected_tensor = modified_weights_state_dict[name]
+ actual_tensor = param.data
+ if not torch.allclose(actual_tensor, expected_tensor, atol=1e-6):
+ direct_check_passed = False
+ mismatched_params.append(name)
+ print(f"Direct Param Check Mismatch for key '{name}':")
+ print(f" Expected (modified): {expected_tensor.flatten()[:5]}...")
+ print(f" Got (actual param): {actual_tensor.flatten()[:5]}...")
+ else:
+ print(
+ f"Warning: Parameter '{name}' not found in modified_weights_state_dict during direct check."
+ )
+
+ assert direct_check_passed, (
+ f"Direct parameter check failed after set_weights for keys: {mismatched_params}"
+ )
+
+ # --- Get weights again using the interface method ---
+ new_weights_state_dict = nn_interface.get_weights()
+
+ # --- Comparison using state dicts ---
+ weights_changed_from_initial = False
+ mismatched_keys_final = []
+
+ for k in float_keys_modified:
+ initial_tensor = initial_weights_copy[k] # Compare against the initial copy
+ new_tensor = new_weights_state_dict[k]
+ modified_tensor = modified_weights_state_dict[k]
+
+ # Check 1: Did the tensor change from the initial state?
+ if not torch.equal(initial_tensor, new_tensor):
+ weights_changed_from_initial = True
+
+ # Check 2: Does the new tensor match the one we tried to set?
+ if not torch.allclose(new_tensor, modified_tensor, atol=1e-6):
+ mismatched_keys_final.append(k)
+ print(f"Final State Dict Mismatch for key '{k}':")
+ print(f" Expected (modified): {modified_tensor.flatten()[:5]}...")
+ print(f" Got (new state dict):{new_tensor.flatten()[:5]}...")
+ print(f" Initial: {initial_tensor.flatten()[:5]}...")
+
+ # Assert that at least one weight changed from the initial state
+ assert weights_changed_from_initial, (
+ "Floating point weights did not change from initial state after set_weights/get_weights cycle"
+ )
+
+ # Assert that all modified weights match the expected values in the final state dict
+ assert not mismatched_keys_final, (
+ f"Weights in final state_dict did not match expected values after set_weights for keys: {mismatched_keys_final}"
+ )
+
+ # Final check: ensure all keys match (including buffers)
+ assert all(
+ (
+ torch.allclose(
+ modified_weights_state_dict[k], new_weights_state_dict[k], atol=1e-6
+ )
+ if modified_weights_state_dict[k].dtype.is_floating_point
+ else torch.equal(modified_weights_state_dict[k], new_weights_state_dict[k])
+ )
+ for k in modified_weights_state_dict
+ ), (
+ "Mismatch between modified_weights and new_weights across all keys in final state dict."
+ )
+
+
+def test_evaluate_state_with_nan_features(
+ nn_interface: NeuralNetwork,
+ game_state: GameState, # Uses trianglengin.GameState
+ mocker,
+):
+ """Test that evaluation raises error if features contain NaN."""
+
+ def mock_extract_nan(*args, **kwargs) -> StateType:
+ state_dict = extract_state_features(*args, **kwargs)
+ state_dict["other_features"][0] = np.nan # Inject NaN
+ return state_dict
+
+ mocker.patch("alphatriangle.nn.network.extract_state_features", mock_extract_nan)
+
+ with pytest.raises(NetworkEvaluationError, match="Non-finite values found"):
+ nn_interface.evaluate_state(game_state)
+
+
+def test_evaluate_batch_with_nan_features(
+ nn_interface: NeuralNetwork,
+ game_state: GameState, # Uses trianglengin.GameState
+ mocker,
+):
+ """Test that batch evaluation raises error if features contain NaN."""
+ batch_size = 2
+ states = [game_state.copy() for _ in range(batch_size)]
+
+ def mock_extract_nan_batch(*args, **kwargs) -> StateType:
+ state_dict = extract_state_features(*args, **kwargs)
+ # Inject NaN into the first element of the batch only
+ if args[0] is states[0]:
+ state_dict["other_features"][0] = np.nan
+ return state_dict
+
+ mocker.patch(
+ "alphatriangle.nn.network.extract_state_features", mock_extract_nan_batch
+ )
+
+ with pytest.raises(NetworkEvaluationError, match="Non-finite values found"):
+ nn_interface.evaluate_batch(states)
+
+
+File: tests/training/test_loop_integration.py
+# File: tests/training/test_loop_integration.py
+import logging
+import time
+from typing import TYPE_CHECKING, Any, cast
+from unittest.mock import MagicMock
+
+import numpy as np
+import pytest
+import torch
+
+# Import Trieye config for components
+from trieye import TrieyeConfig
+from trieye.schemas import RawMetricEvent # Import from trieye
+
+# Import necessary classes for patching targets
+from alphatriangle.config import TrainConfig
+from alphatriangle.rl import ExperienceBuffer, SelfPlayResult, Trainer
+from alphatriangle.training import TrainingComponents, TrainingLoop
+from alphatriangle.training.loop_helpers import LoopHelpers
+from alphatriangle.training.worker_manager import WorkerManager
+
+if TYPE_CHECKING:
+ from alphatriangle.utils.types import Experience, PERBatchSample
+
+logger = logging.getLogger(__name__)
+
+
+# --- Fixtures ---
+
+
+class MockSelfPlayWorker:
+ """Simplified Mock Worker for loop integration tests."""
+
+ def __init__(self, actor_id: int, trieye_actor_mock: MagicMock | None):
+ self.actor_id = actor_id
+ self.trieye_actor_mock = trieye_actor_mock # Store mock handle
+ self.weights_set_count = 0
+ self.last_weights_received: dict[str, Any] | None = None
+ self.step_set_count = 0
+ self.current_trainer_step = 0
+ self.task_running = False
+ self.set_weights = MagicMock(side_effect=self._set_weights_impl)
+ self.set_current_trainer_step = MagicMock(
+ side_effect=self._set_current_trainer_step_impl
+ )
+ self.run_episode = MagicMock(side_effect=self._run_episode_impl)
+
+ # Simulate remote calls
+ self.set_weights.remote = self.set_weights
+ self.set_current_trainer_step.remote = self.set_current_trainer_step
+ self.run_episode.remote = self.run_episode
+
+ def _run_episode_impl(self) -> SelfPlayResult:
+ self.task_running = True
+ step_at_start = self.current_trainer_step
+ time.sleep(0.01)
+ dummy_exp: Experience = (
+ {"grid": np.array([[[0.0]]]), "other_features": np.array([0.0])},
+ {0: 1.0},
+ 0.5,
+ )
+ episode_context = {
+ "score": 1.0,
+ "length": 1,
+ "simulations": 10,
+ "triangles_cleared": 0,
+ "trainer_step": step_at_start,
+ }
+
+ # Simulate worker sending events to Trieye
+ if self.trieye_actor_mock:
+ self.trieye_actor_mock.log_event.remote(
+ RawMetricEvent(
+ name="mcts_step",
+ value=10.0,
+ global_step=step_at_start,
+ context={"game_step": 0},
+ )
+ )
+ self.trieye_actor_mock.log_event.remote(
+ RawMetricEvent(
+ name="step_reward",
+ value=0.1,
+ global_step=step_at_start,
+ context={"game_step": 0},
+ )
+ )
+ self.trieye_actor_mock.log_event.remote(
+ RawMetricEvent(
+ name="current_score",
+ value=1.0,
+ global_step=step_at_start,
+ context={"game_step": 0},
+ )
+ )
+ self.trieye_actor_mock.log_event.remote(
+ RawMetricEvent(
+ name="episode_end",
+ value=1.0,
+ global_step=step_at_start,
+ context=episode_context,
+ )
+ )
+
+ result = SelfPlayResult(
+ episode_experiences=[dummy_exp],
+ final_score=1.0,
+ episode_steps=1,
+ trainer_step_at_episode_start=step_at_start,
+ total_simulations=10,
+ avg_root_visits=10.0,
+ avg_tree_depth=1.0,
+ context=episode_context,
+ )
+ self.task_running = False
+ return result
+
+ def _set_weights_impl(self, weights: dict) -> None:
+ self.weights_set_count += 1
+ self.last_weights_received = weights
+ time.sleep(0.001)
+
+ def _set_current_trainer_step_impl(self, global_step: int) -> None:
+ self.step_set_count += 1
+ self.current_trainer_step = global_step
+ time.sleep(0.001)
+
+
+@pytest.fixture
+def mock_training_config(mock_train_config: TrainConfig) -> TrainConfig:
+ """Fixture for TrainConfig with settings suitable for integration tests."""
+ cfg = mock_train_config.model_copy(deep=True)
+ cfg.NUM_SELF_PLAY_WORKERS = 2
+ cfg.WORKER_UPDATE_FREQ_STEPS = 5
+ cfg.MIN_BUFFER_SIZE_TO_TRAIN = 2
+ cfg.BATCH_SIZE = 1
+ cfg.MAX_TRAINING_STEPS = 20
+ cfg.USE_PER = False
+ cfg.COMPILE_MODEL = False
+ cfg.PROFILE_WORKERS = False
+ cfg.CHECKPOINT_SAVE_FREQ_STEPS = 10 # Save checkpoint during short run
+ return cfg
+
+
+@pytest.fixture
+def mock_trieye_config() -> TrieyeConfig:
+ """Fixture for a basic TrieyeConfig."""
+ # Use default metrics for simplicity in this test
+ return TrieyeConfig(app_name="test_loop_app", run_name="test_loop_run")
+
+
+@pytest.fixture
+def mock_components(
+ mocker,
+ mock_nn_interface,
+ mock_training_config,
+ mock_trieye_config, # Use Trieye config
+ mock_env_config,
+ mock_model_config,
+ mock_mcts_config,
+ mock_experience,
+) -> TrainingComponents:
+ """Fixture to create TrainingComponents with mocks suitable for loop tests."""
+
+ mock_buffer = MagicMock(spec=ExperienceBuffer)
+ mock_buffer.config = mock_training_config
+ mock_buffer.capacity = mock_training_config.BUFFER_CAPACITY
+ mock_buffer.min_size_to_train = mock_training_config.MIN_BUFFER_SIZE_TO_TRAIN
+ mock_buffer.use_per = mock_training_config.USE_PER
+ mock_buffer.is_ready.return_value = True
+ mock_buffer.__len__.return_value = mock_training_config.MIN_BUFFER_SIZE_TO_TRAIN + 5
+ dummy_sample: PERBatchSample = {
+ "batch": [mock_experience],
+ "indices": np.array([0]),
+ "weights": np.array([1.0]),
+ }
+ mock_buffer.sample.return_value = dummy_sample
+ if mock_training_config.USE_PER:
+ mock_buffer._calculate_beta = MagicMock(return_value=0.5)
+ mocker.patch(
+ "alphatriangle.rl.core.buffer.ExperienceBuffer", return_value=mock_buffer
+ )
+
+ mock_trainer = MagicMock(spec=Trainer)
+ mock_trainer.train_config = mock_training_config
+ mock_trainer.env_config = mock_env_config
+ mock_trainer.model_config = mock_model_config
+ mock_trainer.nn = mock_nn_interface
+ mock_trainer.device = mock_nn_interface.device
+ mock_trainer.train_step.return_value = (
+ {"total_loss": 0.1, "policy_loss": 0.05, "value_loss": 0.05},
+ np.array([0.1]),
+ )
+ mock_trainer.get_current_lr.return_value = mock_training_config.LEARNING_RATE
+ # Mock the optimizer attribute needed for saving state
+ mock_trainer.optimizer = MagicMock(spec=torch.optim.Optimizer)
+ mock_trainer.optimizer.state_dict.return_value = {"opt_state": "dummy"}
+ mocker.patch("alphatriangle.rl.core.trainer.Trainer", return_value=mock_trainer)
+
+ # Mock Trieye Actor Handle
+ mock_trieye_actor_handle = MagicMock(name="TrieyeActorMockHandle")
+ mock_trieye_actor_handle.log_event = MagicMock(name="log_event_remote")
+ mock_trieye_actor_handle.log_batch_events = MagicMock(
+ name="log_batch_events_remote"
+ )
+ mock_trieye_actor_handle.save_training_state = MagicMock(
+ name="save_training_state_remote"
+ )
+ # Simulate remote calls
+ mock_trieye_actor_handle.log_event.remote = mock_trieye_actor_handle.log_event
+ mock_trieye_actor_handle.log_batch_events.remote = (
+ mock_trieye_actor_handle.log_batch_events
+ )
+ mock_trieye_actor_handle.save_training_state.remote = (
+ mock_trieye_actor_handle.save_training_state
+ )
+ # Add mock serializer needed by loop._trigger_save_state
+ mock_serializer = MagicMock()
+ mock_serializer.prepare_optimizer_state.return_value = {"opt_state": "dummy"}
+ mock_serializer.prepare_buffer_data.return_value = MagicMock(buffer_list=[])
+ mock_trieye_actor_handle.serializer = mock_serializer
+
+ mock_workers = [
+ MockSelfPlayWorker(i, mock_trieye_actor_handle) # Pass Trieye mock
+ for i in range(mock_training_config.NUM_SELF_PLAY_WORKERS)
+ ]
+ mock_worker_manager_instance = MagicMock(spec=WorkerManager)
+ mock_worker_manager_instance.workers = mock_workers
+ mock_worker_manager_instance.active_worker_indices = set(
+ range(mock_training_config.NUM_SELF_PLAY_WORKERS)
+ )
+
+ mock_worker_tasks: dict[int, MagicMock] = {}
+
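+ # Simulate the WorkerManager task lifecycle: submit_task registers a fake
+ # task ref, and get_completed_tasks pops one and runs the mock worker's
+ # episode synchronously, yielding (worker_idx, result) pairs.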
+ def mock_submit_task(worker_idx):
+ if worker_idx in mock_worker_manager_instance.active_worker_indices:
+ mock_task_ref = MagicMock(name=f"task_ref_w{worker_idx}")
+ mock_worker_tasks[worker_idx] = mock_task_ref
+ logger.debug(f"Mock submit task for worker {worker_idx}")
+
+ def mock_get_completed_tasks(*_args, **_kwargs):
+ results = []
+ if mock_worker_tasks:
+ worker_idx_to_complete = next(iter(mock_worker_tasks))
+ mock_worker_tasks.pop(worker_idx_to_complete)
+ if (
+ worker_idx_to_complete < len(mock_workers)
+ and mock_workers[worker_idx_to_complete] is not None
+ ):
+ result = mock_workers[worker_idx_to_complete].run_episode()
+ results.append((worker_idx_to_complete, result))
+ logger.debug(f"Mock completed task for worker {worker_idx_to_complete}")
+ else:
+ logger.warning(
+ f"Mock worker {worker_idx_to_complete} not found or None."
+ )
+ return results
+
+ mock_worker_manager_instance.submit_task.side_effect = mock_submit_task
+ mock_worker_manager_instance.get_completed_tasks.side_effect = (
+ mock_get_completed_tasks
+ )
+ mock_worker_manager_instance.get_num_pending_tasks.side_effect = lambda: len(
+ mock_worker_tasks
+ )
+
+ def mock_update_worker_networks(global_step: int):
+ logger.debug(f"Mock update_worker_networks called for step {global_step}")
+ weights = mock_nn_interface.get_weights()
+ for worker in mock_workers:
+ if worker:
+ worker.set_weights.remote(weights)
+ worker.set_current_trainer_step.remote(global_step)
+ logger.debug("Mock update_worker_networks finished.")
+
+ mock_worker_manager_instance.update_worker_networks.side_effect = (
+ mock_update_worker_networks
+ )
+ mock_worker_manager_instance.get_num_active_workers.return_value = (
+ mock_training_config.NUM_SELF_PLAY_WORKERS
+ )
+
+ mocker.patch(
+ "alphatriangle.training.loop.WorkerManager",
+ return_value=mock_worker_manager_instance,
+ )
+
+ mock_loop_helpers = MagicMock(spec=LoopHelpers)
+ mock_loop_helpers.log_progress_eta = MagicMock()
+ mock_loop_helpers.validate_experiences.side_effect = lambda exps: (exps, 0)
+ mocker.patch(
+ "alphatriangle.training.loop.LoopHelpers", return_value=mock_loop_helpers
+ )
+
+ # Create TrainingComponents with Trieye actor handle and config
+ components = TrainingComponents(
+ nn=mock_nn_interface,
+ buffer=mock_buffer,
+ trainer=mock_trainer,
+ trieye_actor=mock_trieye_actor_handle, # Pass mock handle
+ trieye_config=mock_trieye_config, # Pass Trieye config
+ train_config=mock_training_config,
+ env_config=mock_env_config,
+ model_config=mock_model_config,
+ mcts_config=mock_mcts_config,
+ profile_workers=mock_training_config.PROFILE_WORKERS,
+ )
+
+ return components
+
+
+@pytest.fixture
+def mock_training_loop(mock_components: TrainingComponents) -> TrainingLoop:
+ """Fixture for an initialized TrainingLoop with mocked components."""
+ loop = TrainingLoop(mock_components)
+ loop.set_initial_state(0, 0, 0)
+ for i in range(loop.train_config.NUM_SELF_PLAY_WORKERS):
+ loop.worker_manager.submit_task(i)
+ return loop
+
+
+# --- Test ---
+
+
+def test_worker_weight_update_and_stats_logging(
+ mock_training_loop: TrainingLoop, mock_components: TrainingComponents
+):
+ """
+ Verify that worker weights/steps are updated and that the weight update
+ event is sent to the Trieye actor.
+ """
+ loop = mock_training_loop
+ components = mock_components
+ worker_manager = loop.worker_manager
+ mock_workers = worker_manager.workers
+ trieye_actor_mock = components.trieye_actor # Get mock handle
+ assert trieye_actor_mock is not None, "Trieye actor mock handle should not be None"
+
+ update_freq = components.train_config.WORKER_UPDATE_FREQ_STEPS
+ max_steps = components.train_config.MAX_TRAINING_STEPS or 20
+
+ loop.run()
+
+ results_processed = loop.episodes_played
+ total_expected_updates = max_steps // update_freq
+ expected_total_weight_updates = total_expected_updates
+
+ # Check calls to Trieye actor's log_event
+ weight_update_event_calls = [
+ c
+ for c in trieye_actor_mock.log_event.call_args_list
+ if isinstance(c.args[0], RawMetricEvent)
+ and c.args[0].name == "Progress/Weight_Updates_Total"
+ ]
+ assert len(weight_update_event_calls) == total_expected_updates, (
+ f"Weight update event calls: expected {total_expected_updates}, got {len(weight_update_event_calls)}"
+ )
+
+ if weight_update_event_calls:
+ last_event_arg = weight_update_event_calls[-1].args[0]
+ assert isinstance(last_event_arg, RawMetricEvent)
+ assert last_event_arg.value == expected_total_weight_updates, (
+ f"Last weight update event value mismatch: expected {expected_total_weight_updates}, got {last_event_arg.value}"
+ )
+ last_expected_step = total_expected_updates * update_freq
+ assert last_event_arg.global_step == last_expected_step, (
+ f"Last weight update event step mismatch: expected {last_expected_step}, got {last_event_arg.global_step}"
+ )
+
+ assert results_processed > 0, "No self-play results were processed"
+
+ for worker_idx, worker in enumerate(mock_workers):
+ if worker is not None:
+ set_step_mock = cast("MagicMock", worker.set_current_trainer_step)
+ set_weights_mock = cast("MagicMock", worker.set_weights)
+
+ assert set_step_mock.call_count == total_expected_updates, (
+ f"Worker {worker_idx} set_step calls"
+ )
+ assert set_weights_mock.call_count == total_expected_updates, (
+ f"Worker {worker_idx} set_weights calls"
+ )
+
+ if total_expected_updates > 0:
+ last_expected_step = total_expected_updates * update_freq
+ set_step_mock.assert_called_with(last_expected_step)
+ assert worker.current_trainer_step == last_expected_step, (
+ f"Worker {worker_idx} final step"
+ )
+ else:
+ pytest.fail(f"Worker object {worker_idx} became None during test")
+
+ logger.info(
+ f"Test finished. Verified {results_processed} results processed. Expected updates: {total_expected_updates}."
+ )
+
+
+def test_checkpoint_save_trigger(
+ mock_training_loop: TrainingLoop, mock_components: TrainingComponents
+):
+ """Verify that save_training_state is called on the Trieye actor."""
+ loop = mock_training_loop
+ trieye_actor_mock = mock_components.trieye_actor
+ assert trieye_actor_mock is not None
+
+ save_freq = loop.train_config.CHECKPOINT_SAVE_FREQ_STEPS
+ loop.train_config.MAX_TRAINING_STEPS = save_freq + 2 # Ensure save is triggered
+
+ loop.run()
+
+ # Check if save_training_state was called at the correct step
+ save_calls = trieye_actor_mock.save_training_state.call_args_list
+ assert len(save_calls) >= 1, "save_training_state was not called"
+
+ # Check the call at the expected frequency step
+ call_found = False
+ for call in save_calls:
+ if call.kwargs.get("global_step") == save_freq:
+ assert call.kwargs.get("is_best") is False
+ assert (
+ call.kwargs.get("save_buffer") is False
+ ) # Default checkpoint save doesn't save buffer
+ call_found = True
+ break
+ assert call_found, f"save_training_state not called at expected step {save_freq}"
+
+
+File: tests/rl/test_buffer.py
+from collections import deque
+
+import numpy as np
+import pytest
+
+from alphatriangle.config import TrainConfig
+from alphatriangle.rl import ExperienceBuffer
+from alphatriangle.utils.sumtree import SumTree
+from alphatriangle.utils.types import Experience, StateType
+
+# Use module-level rng from tests/conftest.py
+from tests.conftest import rng
+
+# --- Fixtures ---
+
+
+@pytest.fixture
+def uniform_train_config() -> TrainConfig:
+ """TrainConfig for uniform buffer."""
+ return TrainConfig(
+ BUFFER_CAPACITY=100,
+ MIN_BUFFER_SIZE_TO_TRAIN=10,
+ BATCH_SIZE=4,
+ USE_PER=False,
+ # Provide defaults for other required fields
+ LOAD_CHECKPOINT_PATH=None,
+ LOAD_BUFFER_PATH=None,
+ AUTO_RESUME_LATEST=False,
+ DEVICE="cpu",
+ RANDOM_SEED=42,
+ NUM_SELF_PLAY_WORKERS=1,
+ WORKER_DEVICE="cpu",
+ WORKER_UPDATE_FREQ_STEPS=10,
+ OPTIMIZER_TYPE="Adam",
+ LEARNING_RATE=1e-3,
+ WEIGHT_DECAY=1e-4,
+ LR_SCHEDULER_ETA_MIN=1e-6,
+ POLICY_LOSS_WEIGHT=1.0,
+ VALUE_LOSS_WEIGHT=1.0,
+ ENTROPY_BONUS_WEIGHT=0.0,
+ CHECKPOINT_SAVE_FREQ_STEPS=50,
+ PER_ALPHA=0.6,
+ PER_BETA_INITIAL=0.4,
+ PER_BETA_FINAL=1.0,
+ PER_BETA_ANNEAL_STEPS=100,
+ PER_EPSILON=1e-5,
+ MAX_TRAINING_STEPS=200, # Set a finite value for tests
+ N_STEP_RETURNS=3,
+ GAMMA=0.99,
+ )
+
+
+@pytest.fixture
+def per_train_config() -> TrainConfig:
+ """TrainConfig for PER buffer."""
+ return TrainConfig(
+ BUFFER_CAPACITY=100,
+ MIN_BUFFER_SIZE_TO_TRAIN=10,
+ BATCH_SIZE=4,
+ USE_PER=True,
+ PER_ALPHA=0.6,
+ PER_BETA_INITIAL=0.4,
+ PER_BETA_FINAL=1.0,
+ PER_BETA_ANNEAL_STEPS=50, # Short anneal for testing
+ PER_EPSILON=1e-5,
+ # Provide defaults for other required fields
+ LOAD_CHECKPOINT_PATH=None,
+ LOAD_BUFFER_PATH=None,
+ AUTO_RESUME_LATEST=False,
+ DEVICE="cpu",
+ RANDOM_SEED=42,
+ NUM_SELF_PLAY_WORKERS=1,
+ WORKER_DEVICE="cpu",
+ WORKER_UPDATE_FREQ_STEPS=10,
+ OPTIMIZER_TYPE="Adam",
+ LEARNING_RATE=1e-3,
+ WEIGHT_DECAY=1e-4,
+ LR_SCHEDULER_ETA_MIN=1e-6,
+ POLICY_LOSS_WEIGHT=1.0,
+ VALUE_LOSS_WEIGHT=1.0,
+ ENTROPY_BONUS_WEIGHT=0.0,
+ CHECKPOINT_SAVE_FREQ_STEPS=50,
+ MAX_TRAINING_STEPS=200, # Set a finite value for tests
+ N_STEP_RETURNS=3,
+ GAMMA=0.99,
+ )
+
+
+@pytest.fixture
+def uniform_buffer(uniform_train_config: TrainConfig) -> ExperienceBuffer:
+ """Provides an empty uniform ExperienceBuffer."""
+ return ExperienceBuffer(uniform_train_config)
+
+
+@pytest.fixture
+def per_buffer(per_train_config: TrainConfig) -> ExperienceBuffer:
+ """Provides an empty PER ExperienceBuffer."""
+ return ExperienceBuffer(per_train_config)
+
+
+# Use shared mock_experience fixture implicitly from tests/conftest.py
+
+
+# --- Uniform Buffer Tests ---
+
+
+def test_uniform_buffer_init(uniform_buffer: ExperienceBuffer):
+ assert not uniform_buffer.use_per
+ assert isinstance(uniform_buffer.buffer, deque)
+ assert uniform_buffer.capacity == 100
+ assert len(uniform_buffer) == 0
+ assert not uniform_buffer.is_ready()
+
+
+# Use mock_experience directly injected by pytest
+def test_uniform_buffer_add(
+ uniform_buffer: ExperienceBuffer, mock_experience: Experience
+):
+ assert len(uniform_buffer) == 0
+ uniform_buffer.add(mock_experience)
+ assert len(uniform_buffer) == 1
+ # Check if the added experience is the same object (or equal)
+ # Note: Deep comparison might be needed if copies are made internally
+ assert uniform_buffer.buffer[0] == mock_experience
+
+
+# Use mock_experience directly injected by pytest
+def test_uniform_buffer_add_batch(
+ uniform_buffer: ExperienceBuffer, mock_experience: Experience
+):
+ # Create copies for the batch to avoid adding the same object reference multiple times
+ batch: list[Experience] = [
+ (
+ {
+ "grid": mock_experience[0]["grid"].copy(),
+ "other_features": mock_experience[0]["other_features"].copy(),
+ # REMOVED: "available_shapes_geometry": [...]
+ },
+ mock_experience[1],
+ mock_experience[2],
+ )
+ for _ in range(5)
+ ]
+ uniform_buffer.add_batch(batch)
+ assert len(uniform_buffer) == 5
+
+
+# Use mock_experience directly injected by pytest
+def test_uniform_buffer_capacity(
+ uniform_buffer: ExperienceBuffer, mock_experience: Experience
+):
+ for i in range(uniform_buffer.capacity + 10):
+ # Create slightly different experiences with deep copies
+ state_copy: StateType = {
+ "grid": mock_experience[0]["grid"].copy() + i,
+ "other_features": mock_experience[0]["other_features"].copy() + i,
+ # REMOVED: "available_shapes_geometry": [...]
+ }
+ exp_copy: Experience = (state_copy, mock_experience[1], mock_experience[2] + i)
+ uniform_buffer.add(exp_copy)
+ assert len(uniform_buffer) == uniform_buffer.capacity
+ # Check if the first added element is gone (check based on value)
+ first_added_val = mock_experience[2] + 0
+ assert not any(exp[2] == first_added_val for exp in uniform_buffer.buffer)
+ # Check if the last added element is present (check based on value)
+ last_added_val = mock_experience[2] + uniform_buffer.capacity + 9
+ assert any(exp[2] == last_added_val for exp in uniform_buffer.buffer)
+
+
+# Use mock_experience directly injected by pytest
+def test_uniform_buffer_is_ready(
+ uniform_buffer: ExperienceBuffer, mock_experience: Experience
+):
+ assert not uniform_buffer.is_ready()
+ for i in range(uniform_buffer.min_size_to_train):
+ # Create copies
+ state_copy: StateType = {
+ "grid": mock_experience[0]["grid"].copy(),
+ "other_features": mock_experience[0]["other_features"].copy(),
+ # REMOVED: "available_shapes_geometry": [...]
+ }
+ exp_copy: Experience = (state_copy, mock_experience[1], mock_experience[2] + i)
+ uniform_buffer.add(exp_copy)
+ assert uniform_buffer.is_ready()
+
+
+# Use mock_experience directly injected by pytest
+def test_uniform_buffer_sample(
+ uniform_buffer: ExperienceBuffer, mock_experience: Experience
+):
+ # Fill buffer until ready
+ for i in range(uniform_buffer.min_size_to_train):
+ # Create copies
+ state_copy: StateType = {
+ "grid": mock_experience[0]["grid"].copy() + i,
+ "other_features": mock_experience[0]["other_features"].copy() + i,
+ # REMOVED: "available_shapes_geometry": [...]
+ }
+ exp_copy: Experience = (state_copy, mock_experience[1], mock_experience[2] + i)
+ uniform_buffer.add(exp_copy)
+
+ sample = uniform_buffer.sample(uniform_buffer.config.BATCH_SIZE)
+ assert sample is not None
+ assert isinstance(sample, dict)
+ assert "batch" in sample
+ assert "indices" in sample
+ assert "weights" in sample
+ assert len(sample["batch"]) == uniform_buffer.config.BATCH_SIZE
+ assert isinstance(sample["batch"][0], tuple) # Check if it's an Experience tuple
+ # Check structure of the first element's StateType
+ assert isinstance(sample["batch"][0][0], dict)
+ assert "grid" in sample["batch"][0][0]
+ assert "other_features" in sample["batch"][0][0]
+ # REMOVED: assert "available_shapes_geometry" in sample["batch"][0][0]
+ assert sample["indices"].shape == (uniform_buffer.config.BATCH_SIZE,)
+ assert sample["weights"].shape == (uniform_buffer.config.BATCH_SIZE,)
+ assert np.allclose(sample["weights"], 1.0) # Uniform weights should be 1.0
+
+
+def test_uniform_buffer_sample_not_ready(uniform_buffer: ExperienceBuffer):
+ sample = uniform_buffer.sample(uniform_buffer.config.BATCH_SIZE)
+ assert sample is None
+
+
+def test_uniform_buffer_update_priorities(uniform_buffer: ExperienceBuffer):
+ # Should be a no-op
+ initial_len = len(uniform_buffer)
+ uniform_buffer.update_priorities(np.array([0, 1]), np.array([0.5, 0.1]))
+ assert len(uniform_buffer) == initial_len # No change expected
+
+
+# --- PER Buffer Tests ---
+
+
+def test_per_buffer_init(per_buffer: ExperienceBuffer):
+ assert per_buffer.use_per
+ assert isinstance(per_buffer.tree, SumTree)
+ assert per_buffer.capacity == 100
+ assert len(per_buffer) == 0
+ assert not per_buffer.is_ready()
+ assert per_buffer.tree.max_priority == 1.0 # Initial max priority
+
+
+# Use mock_experience directly injected by pytest
+def test_per_buffer_add(per_buffer: ExperienceBuffer, mock_experience: Experience):
+ assert len(per_buffer) == 0
+ initial_max_p = per_buffer.tree.max_priority
+ per_buffer.add(mock_experience)
+ assert len(per_buffer) == 1
+ # Check if added with initial max priority
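+ # SumTree keeps priorities in a flat binary-heap array with `capacity`
+ # leaves, so leaf (tree) index = data index + capacity - 1.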
+ data_idx = (
+ per_buffer.tree.data_pointer - 1 + per_buffer.capacity
+ ) % per_buffer.capacity
+ tree_idx = data_idx + per_buffer.capacity - 1
+ assert per_buffer.tree.tree[tree_idx] == initial_max_p
+ assert per_buffer.tree.data[data_idx] == mock_experience
+
+
+# Use mock_experience directly injected by pytest
+def test_per_buffer_add_batch(
+ per_buffer: ExperienceBuffer, mock_experience: Experience
+):
+ # Create copies for the batch
+ batch: list[Experience] = [
+ (
+ {
+ "grid": mock_experience[0]["grid"].copy(),
+ "other_features": mock_experience[0]["other_features"].copy(),
+ # REMOVED: "available_shapes_geometry": [...]
+ },
+ mock_experience[1],
+ mock_experience[2],
+ )
+ for _ in range(5)
+ ]
+ per_buffer.add_batch(batch)
+ assert len(per_buffer) == 5
+
+
+# Use mock_experience directly injected by pytest
+def test_per_buffer_capacity(per_buffer: ExperienceBuffer, mock_experience: Experience):
+ for i in range(per_buffer.capacity + 10):
+ # Create copies
+ state_copy: StateType = {
+ "grid": mock_experience[0]["grid"].copy() + i,
+ "other_features": mock_experience[0]["other_features"].copy() + i,
+ # REMOVED: "available_shapes_geometry": [...]
+ }
+ exp_copy: Experience = (state_copy, mock_experience[1], mock_experience[2] + i)
+ per_buffer.add(exp_copy) # Adds with current max priority
+ assert len(per_buffer) == per_buffer.capacity
+
+
+# Use mock_experience directly injected by pytest
+def test_per_buffer_is_ready(per_buffer: ExperienceBuffer, mock_experience: Experience):
+ assert not per_buffer.is_ready()
+ for i in range(per_buffer.min_size_to_train):
+ # Create copies
+ state_copy: StateType = {
+ "grid": mock_experience[0]["grid"].copy(),
+ "other_features": mock_experience[0]["other_features"].copy(),
+ # REMOVED: "available_shapes_geometry": [...]
+ }
+ exp_copy: Experience = (state_copy, mock_experience[1], mock_experience[2] + i)
+ per_buffer.add(exp_copy)
+ assert per_buffer.is_ready()
+
+
+# Use mock_experience directly injected by pytest
+def test_per_buffer_sample(per_buffer: ExperienceBuffer, mock_experience: Experience):
+ # Fill buffer until ready
+ for i in range(per_buffer.min_size_to_train):
+ # Create copies
+ state_copy: StateType = {
+ "grid": mock_experience[0]["grid"].copy() + i,
+ "other_features": mock_experience[0]["other_features"].copy() + i,
+ # REMOVED: "available_shapes_geometry": [...]
+ }
+ exp_copy: Experience = (state_copy, mock_experience[1], mock_experience[2] + i)
+ per_buffer.add(exp_copy)
+
+ # Need current_step for beta calculation
+ sample = per_buffer.sample(per_buffer.config.BATCH_SIZE, current_train_step=10)
+ assert sample is not None
+ assert isinstance(sample, dict)
+ assert "batch" in sample
+ assert "indices" in sample
+ assert "weights" in sample
+ assert len(sample["batch"]) == per_buffer.config.BATCH_SIZE
+ assert isinstance(sample["batch"][0], tuple)
+ # Check structure of the first element's StateType
+ assert isinstance(sample["batch"][0][0], dict)
+ assert "grid" in sample["batch"][0][0]
+ assert "other_features" in sample["batch"][0][0]
+ # REMOVED: assert "available_shapes_geometry" in sample["batch"][0][0]
+ assert sample["indices"].shape == (per_buffer.config.BATCH_SIZE,)
+ assert sample["weights"].shape == (per_buffer.config.BATCH_SIZE,)
+ assert np.all(sample["weights"] >= 0) and np.all(
+ sample["weights"] <= 1.0
+ ) # Weights are normalized
+
+
+# Use mock_experience directly injected by pytest
+def test_per_buffer_sample_requires_step(
+ per_buffer: ExperienceBuffer, mock_experience: Experience
+):
+ # Fill buffer
+ for i in range(per_buffer.min_size_to_train):
+ # Create copies
+ state_copy: StateType = {
+ "grid": mock_experience[0]["grid"].copy(),
+ "other_features": mock_experience[0]["other_features"].copy(),
+ # REMOVED: "available_shapes_geometry": [...]
+ }
+ exp_copy: Experience = (state_copy, mock_experience[1], mock_experience[2] + i)
+ per_buffer.add(exp_copy)
+ with pytest.raises(ValueError, match="current_train_step is required"):
+ per_buffer.sample(per_buffer.config.BATCH_SIZE)
+
+
+# Use mock_experience directly injected by pytest
+def test_per_buffer_update_priorities(
+ per_buffer: ExperienceBuffer, mock_experience: Experience
+):
+ # Add some items
+ num_items = per_buffer.min_size_to_train
+ for i in range(num_items):
+ # Create copies
+ state_copy: StateType = {
+ "grid": mock_experience[0]["grid"].copy() + i,
+ "other_features": mock_experience[0]["other_features"].copy() + i,
+ # REMOVED: "available_shapes_geometry": [...]
+ }
+ exp_copy: Experience = (state_copy, mock_experience[1], mock_experience[2] + i)
+ per_buffer.add(exp_copy)
+
+ # Sample to get indices
+ sample = per_buffer.sample(per_buffer.config.BATCH_SIZE, current_train_step=1)
+ assert sample is not None
+ indices = sample["indices"] # These are tree indices
+
+ # Update with some errors
+ td_errors = rng.random(per_buffer.config.BATCH_SIZE) * 0.5 # Example errors
+ per_buffer.update_priorities(indices, td_errors)
+
+ # --- Verification Adjustment ---
+ expected_priorities_map = {}
+ calculated_priorities = np.array(
+ [per_buffer._get_priority(err) for err in td_errors]
+ )
+ for tree_idx, expected_p in zip(indices, calculated_priorities, strict=True):
+ expected_priorities_map[tree_idx] = expected_p # Last write wins
+
+ unique_indices = sorted(expected_priorities_map.keys())
+ actual_updated_priorities = [per_buffer.tree.tree[idx] for idx in unique_indices]
+ expected_final_priorities = [expected_priorities_map[idx] for idx in unique_indices]
+
+ assert np.allclose(
+ actual_updated_priorities, expected_final_priorities, rtol=1e-4, atol=1e-4
+ ), (
+ f"Mismatch between actual tree priorities {actual_updated_priorities} and expected {expected_final_priorities} for unique indices {unique_indices}"
+ )
+
+
+def test_per_buffer_beta_annealing(per_buffer: ExperienceBuffer):
+ config = per_buffer.config
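+ # Beta anneals linearly from PER_BETA_INITIAL to PER_BETA_FINAL over
+ # PER_BETA_ANNEAL_STEPS, then stays at PER_BETA_FINAL.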
+ assert per_buffer._calculate_beta(0) == config.PER_BETA_INITIAL
+ anneal_steps = per_buffer.per_beta_anneal_steps
+ assert anneal_steps is not None and anneal_steps > 0
+ mid_step = anneal_steps // 2
+ expected_mid_beta = config.PER_BETA_INITIAL + 0.5 * (
+ config.PER_BETA_FINAL - config.PER_BETA_INITIAL
+ )
+ assert per_buffer._calculate_beta(mid_step) == pytest.approx(expected_mid_beta)
+ assert per_buffer._calculate_beta(anneal_steps) == config.PER_BETA_FINAL
+ assert per_buffer._calculate_beta(anneal_steps * 2) == config.PER_BETA_FINAL
+
+
+File: tests/rl/__init__.py
+
+
+File: tests/rl/test_trainer.py
+# File: tests/rl/test_trainer.py
+
+import numpy as np
+import pytest
+import torch
+
+# Import EnvConfig from trianglengin's top level
+from trianglengin import EnvConfig # UPDATED IMPORT
+
+# Keep alphatriangle imports
+from alphatriangle.config import ModelConfig, TrainConfig
+from alphatriangle.nn import NeuralNetwork
+from alphatriangle.rl import ExperienceBuffer, Trainer
+from alphatriangle.utils.types import (
+ Experience,
+ ExperienceBatch, # Import ExperienceBatch
+ PERBatchSample,
+ StateType,
+)
+
+# Use shared fixtures implicitly via pytest injection
+
+
+@pytest.fixture
+def env_config(mock_env_config: EnvConfig) -> EnvConfig: # Uses trianglengin.EnvConfig
+ return mock_env_config
+
+
+@pytest.fixture
+def model_config(mock_model_config: ModelConfig) -> ModelConfig:
+ # Pin OTHER_NN_INPUT_FEATURES_DIM to the value these tests expect
+ # (10, matching the shared conftest fixtures).
+ mock_model_config.OTHER_NN_INPUT_FEATURES_DIM = 10
+ return mock_model_config
+
+
+@pytest.fixture
+def train_config_uniform(mock_train_config: TrainConfig) -> TrainConfig:
+ cfg = mock_train_config.model_copy(deep=True)
+ cfg.USE_PER = False
+ return cfg
+
+
+@pytest.fixture
+def train_config_per(mock_train_config: TrainConfig) -> TrainConfig:
+ cfg = mock_train_config.model_copy(deep=True)
+ cfg.USE_PER = True
+ cfg.PER_BETA_ANNEAL_STEPS = 100
+ return cfg
+
+
+@pytest.fixture
+def nn_interface(
+ model_config: ModelConfig, # Use updated model_config
+ env_config: EnvConfig, # Uses trianglengin.EnvConfig
+ train_config_uniform: TrainConfig,
+) -> NeuralNetwork:
+ """Provides a NeuralNetwork instance for testing, configured for uniform buffer."""
+ device = torch.device("cpu")
+ # Pass trianglengin.EnvConfig
+ nn_interface_instance = NeuralNetwork(
+ model_config, env_config, train_config_uniform, device
+ )
+ nn_interface_instance.model.to(device)
+ nn_interface_instance.model.eval()
+ return nn_interface_instance
+
+
+@pytest.fixture
+def trainer_uniform(
+ nn_interface: NeuralNetwork,
+ train_config_uniform: TrainConfig,
+ env_config: EnvConfig, # Uses trianglengin.EnvConfig
+) -> Trainer:
+ """Provides a Trainer instance configured for uniform sampling."""
+ # Pass trianglengin.EnvConfig
+ return Trainer(nn_interface, train_config_uniform, env_config)
+
+
+@pytest.fixture
+def trainer_per(
+ nn_interface: NeuralNetwork,
+ train_config_per: TrainConfig,
+ env_config: EnvConfig, # Uses trianglengin.EnvConfig
+) -> Trainer:
+ """Provides a Trainer instance configured for PER."""
+ # Pass trianglengin.EnvConfig
+ return Trainer(nn_interface, train_config_per, env_config)
+
+
+@pytest.fixture
+def buffer_uniform(
+ train_config_uniform: TrainConfig, mock_experience: Experience
+) -> ExperienceBuffer:
+ """Provides a filled uniform buffer."""
+ buffer = ExperienceBuffer(train_config_uniform)
+ for i in range(buffer.min_size_to_train + 5):
+ # Create copies
+ state_copy: StateType = {
+ "grid": mock_experience[0]["grid"].copy() + i,
+ "other_features": mock_experience[0]["other_features"].copy() + i,
+ # REMOVED: "available_shapes_geometry": [...]
+ }
+ exp_copy: Experience = (
+ state_copy,
+ mock_experience[1],
+ mock_experience[2] + i * 0.1,
+ )
+ buffer.add(exp_copy)
+ return buffer
+
+
+@pytest.fixture
+def buffer_per(
+ train_config_per: TrainConfig, mock_experience: Experience
+) -> ExperienceBuffer:
+ """Provides a filled PER buffer."""
+ buffer = ExperienceBuffer(train_config_per)
+ for i in range(buffer.min_size_to_train + 5):
+ # Create copies
+ state_copy: StateType = {
+ "grid": mock_experience[0]["grid"].copy() + i,
+ "other_features": mock_experience[0]["other_features"].copy() + i,
+ # REMOVED: "available_shapes_geometry": [...]
+ }
+ exp_copy: Experience = (
+ state_copy,
+ mock_experience[1],
+ mock_experience[2] + i * 0.1,
+ )
+ buffer.add(exp_copy)
+ return buffer
+
+
+def test_trainer_initialization(trainer_uniform: Trainer):
+ assert trainer_uniform.nn is not None
+ assert trainer_uniform.model is not None
+ assert trainer_uniform.optimizer is not None
+ assert hasattr(trainer_uniform, "scheduler")
+
+
+def test_prepare_batch(trainer_uniform: Trainer, mock_experience: Experience):
+ """Test the internal _prepare_batch method."""
+ batch_size = trainer_uniform.train_config.BATCH_SIZE
+ # Create copies for the batch
+ batch: ExperienceBatch = [
+ (
+ {
+ "grid": mock_experience[0]["grid"].copy(),
+ "other_features": mock_experience[0]["other_features"].copy(),
+ # REMOVED: "available_shapes_geometry": [...]
+ },
+ mock_experience[1],
+ mock_experience[2],
+ )
+ for _ in range(batch_size)
+ ]
+ grid_t, other_t, policy_target_t, n_step_return_t = trainer_uniform._prepare_batch(
+ batch
+ )
+
+ assert grid_t.shape == (
+ batch_size,
+ trainer_uniform.model_config.GRID_INPUT_CHANNELS,
+ trainer_uniform.env_config.ROWS,
+ trainer_uniform.env_config.COLS,
+ )
+ assert other_t.shape == (
+ batch_size,
+ trainer_uniform.model_config.OTHER_NN_INPUT_FEATURES_DIM,
+ )
+ # Calculate action_dim manually for comparison
+ action_dim_int = int(
+ trainer_uniform.env_config.NUM_SHAPE_SLOTS
+ * trainer_uniform.env_config.ROWS
+ * trainer_uniform.env_config.COLS
+ )
+ assert policy_target_t.shape == (batch_size, action_dim_int)
+ assert n_step_return_t.shape == (batch_size,)
+
+ assert grid_t.device == trainer_uniform.device
+ assert other_t.device == trainer_uniform.device
+ assert policy_target_t.device == trainer_uniform.device
+ assert n_step_return_t.device == trainer_uniform.device
+
+
+def test_train_step_uniform(trainer_uniform: Trainer, buffer_uniform: ExperienceBuffer):
+ """Test a single training step with uniform sampling."""
+ initial_params = [p.clone() for p in trainer_uniform.model.parameters()]
+ sample = buffer_uniform.sample(trainer_uniform.train_config.BATCH_SIZE)
+ assert sample is not None
+
+ result = trainer_uniform.train_step(sample)
+
+ assert result is not None
+ loss_info, td_errors = result
+
+ assert isinstance(loss_info, dict)
+ assert "total_loss" in loss_info
+ assert "policy_loss" in loss_info
+ assert "value_loss" in loss_info
+ assert loss_info["total_loss"] > 0
+
+ assert isinstance(td_errors, np.ndarray)
+ assert td_errors.shape == (trainer_uniform.train_config.BATCH_SIZE,)
+
+ params_changed = False
+ for p_initial, p_final in zip(
+ initial_params, trainer_uniform.model.parameters(), strict=True
+ ):
+ if not torch.equal(p_initial, p_final):
+ params_changed = True
+ break
+ assert params_changed, "Model parameters did not change after training step."
+
+
+def test_train_step_per(trainer_per: Trainer, buffer_per: ExperienceBuffer):
+ """Test a single training step with PER."""
+ initial_params = [p.clone() for p in trainer_per.model.parameters()]
+ sample = buffer_per.sample(
+ trainer_per.train_config.BATCH_SIZE, current_train_step=10
+ )
+ assert sample is not None
+
+ result = trainer_per.train_step(sample)
+
+ assert result is not None
+ loss_info, td_errors = result
+
+ assert isinstance(loss_info, dict)
+ assert "total_loss" in loss_info
+ assert "policy_loss" in loss_info
+ assert "value_loss" in loss_info
+ assert loss_info["total_loss"] > 0
+
+ assert isinstance(td_errors, np.ndarray)
+ assert td_errors.shape == (trainer_per.train_config.BATCH_SIZE,)
+ assert np.all(np.isfinite(td_errors))
+
+ params_changed = False
+ for p_initial, p_final in zip(
+ initial_params, trainer_per.model.parameters(), strict=True
+ ):
+ if not torch.equal(p_initial, p_final):
+ params_changed = True
+ break
+ assert params_changed, "Model parameters did not change after training step."
+
+
+def test_train_step_empty_batch(trainer_uniform: Trainer):
+ """Test train_step with an empty batch."""
+ empty_sample: PERBatchSample = {
+ "batch": [],
+ "indices": np.array([]),
+ "weights": np.array([]),
+ }
+ result = trainer_uniform.train_step(empty_sample)
+ assert result is None
+
+
+def test_get_current_lr(trainer_uniform: Trainer):
+ """Test retrieving the current learning rate."""
+ lr = trainer_uniform.get_current_lr()
+ assert isinstance(lr, float)
+ assert lr == trainer_uniform.train_config.LEARNING_RATE
+
+ if trainer_uniform.scheduler:
+ trainer_uniform.scheduler.step()
+ lr_after_step = trainer_uniform.get_current_lr()
+ assert isinstance(lr_after_step, float)
+
+
+File: alphatriangle.egg-info/PKG-INFO
+Metadata-Version: 2.4
+Name: alphatriangle
+Version: 1.9.0
+Summary: AlphaZero implementation for a triangle puzzle game (uses trianglengin v2+, trimcts, and Trieye for stats/persistence).
+Author-email: "Luis Guilherme P. M."
+License: MIT License
+
+ Copyright (c) 2025 Luis Guilherme P. M.
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
+Project-URL: Homepage, https://github.com/lguibr/alphatriangle
+Project-URL: Bug Tracker, https://github.com/lguibr/alphatriangle/issues
+Classifier: Programming Language :: Python :: 3
+Classifier: Programming Language :: Python :: 3.10
+Classifier: Programming Language :: Python :: 3.11
+Classifier: License :: OSI Approved :: MIT License
+Classifier: Operating System :: OS Independent
+Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
+Classifier: Topic :: Games/Entertainment :: Puzzle Games
+Classifier: Development Status :: 4 - Beta
+Requires-Python: >=3.10
+Description-Content-Type: text/markdown
+License-File: LICENSE
+Requires-Dist: trianglengin>=2.0.7
+Requires-Dist: trimcts>=1.2.1
+Requires-Dist: trieye>=0.1.3
+Requires-Dist: numpy>=1.20.0
+Requires-Dist: torch>=2.0.0
+Requires-Dist: torchvision>=0.11.0
+Requires-Dist: cloudpickle>=2.0.0
+Requires-Dist: numba>=0.55.0
+Requires-Dist: mlflow>=1.20.0
+Requires-Dist: ray[default]
+Requires-Dist: pydantic>=2.0.0
+Requires-Dist: typing_extensions>=4.0.0
+Requires-Dist: typer[all]>=0.9.0
+Requires-Dist: tensorboard>=2.10.0
+Requires-Dist: rich>=13.0.0
+Provides-Extra: dev
+Requires-Dist: pytest>=7.0.0; extra == "dev"
+Requires-Dist: pytest-cov>=3.0.0; extra == "dev"
+Requires-Dist: pytest-mock>=3.0.0; extra == "dev"
+Requires-Dist: ruff; extra == "dev"
+Requires-Dist: mypy; extra == "dev"
+Requires-Dist: build; extra == "dev"
+Requires-Dist: twine; extra == "dev"
+Requires-Dist: codecov; extra == "dev"
+Dynamic: license-file
+
+
+[CI/CD](https://github.com/lguibr/alphatriangle/actions/workflows/ci_cd.yml) - [Coverage](https://codecov.io/gh/lguibr/alphatriangle) - [PyPI](https://badge.fury.io/py/alphatriangle) - [License: MIT](https://opensource.org/licenses/MIT) - [Python 3.10+](https://www.python.org/downloads/)
+
+# AlphaTriangle
+
+
+
+## Overview
+
+AlphaTriangle is a project implementing an artificial intelligence agent based on AlphaZero principles to learn and play a custom puzzle game involving placing triangular shapes onto a grid. The agent learns through **headless self-play reinforcement learning**, guided by Monte Carlo Tree Search (MCTS) and a deep neural network (PyTorch).
+
+**Key Features:**
+
+* **Core Game Logic:** Uses the [`trianglengin>=2.0.7`](https://github.com/lguibr/trianglengin) library for the triangle puzzle game rules and state management, featuring a high-performance C++ core.
+* **High-Performance MCTS:** Integrates the [`trimcts>=1.2.1`](https://github.com/lguibr/trimcts) library, providing a **C++ implementation of MCTS** for efficient search, callable from Python. MCTS parameters are configurable via `alphatriangle/config/mcts_config.py`.
+* **Deep Learning Model:** Features a PyTorch neural network with policy and distributional value heads, convolutional layers, and **optional Transformer Encoder layers**.
+* **Parallel Self-Play:** Leverages **Ray** for distributed self-play data generation across multiple CPU cores. **The number of workers automatically adjusts based on detected CPU cores (reserving some for stability), capped by the `NUM_SELF_PLAY_WORKERS` setting in `TrainConfig`.**
+* **Asynchronous Stats & Persistence (NEW):** Uses the [`trieye>=0.1.3`](https://github.com/lguibr/trieye) library, which provides a dedicated **Ray actor (`TrieyeActor`)** for:
+ * Asynchronous collection of raw metric events from workers and the training loop (see the sketch after this list).
+ * Configurable processing and aggregation of metrics.
+ * Logging processed metrics to **MLflow** and **TensorBoard**.
+ * Saving/loading training state (checkpoints, buffers) to the filesystem and logging artifacts to MLflow.
+ * Handling auto-resumption from previous runs.
+ * All persistent data managed by `trieye` is stored within the `.trieye_data/` directory (default: `.trieye_data/alphatriangle`).
+* **Headless Training:** Focuses on a command-line interface for running the training pipeline without visual output.
+* **Enhanced CLI:** Uses the **`rich`** library for improved visual feedback (colors, panels, emojis) in the terminal.
+* **Centralized Logging:** Uses Python's standard `logging` module configured centrally for consistent log formatting (including `▲` prefix) and level control across the project. Run logs are saved to `.trieye_data/<app_name>/runs/<run_name>/logs/`.
+* **Optional Profiling:** Supports profiling worker 0 using `cProfile` via a command-line flag. Profile data is saved to `.trieye_data/<app_name>/runs/<run_name>/profile_data/`.
+* **Unit Tests:** Includes tests for RL components.
+
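+For orientation, the snippet below sketches that event flow the way this repo's integration tests exercise it. It is an illustration, not the authoritative `trieye` API: the `RawMetricEvent` fields (`name`, `value`, `global_step`) match this repo's tests, while the top-level import path and the mocked actor handle are assumptions.
+
+```python
+from unittest.mock import MagicMock
+
+from trieye import RawMetricEvent  # import path assumed; fields match this repo's tests
+
+# Stand-in for the real TrieyeActor handle created during training setup.
+trieye_actor = MagicMock(name="TrieyeActorHandle")
+
+# The training loop emits raw events; the actor aggregates them and logs
+# processed metrics to MLflow/TensorBoard asynchronously.
+event = RawMetricEvent(
+    name="Progress/Weight_Updates_Total",  # metric name used by the training loop
+    value=1.0,
+    global_step=5,
+)
+trieye_actor.log_event.remote(event)  # fire-and-forget from the loop's perspective
+```
+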
+---
+
+## 🎮 The Triangle Puzzle Game Guide 🧩
+
+This project trains an agent to play the game defined by the `trianglengin` library. The rules are detailed in the [`trianglengin` README](https://github.com/lguibr/trianglengin/blob/main/README.md#the-ultimate-triangle-puzzle-guide-).
+
+---
+
+## Core Technologies
+
+* **Python 3.10+**
+* **trianglengin>=2.0.7:** Core game engine (state, actions, rules) with C++ optimizations.
+* **trimcts>=1.2.1:** High-performance C++ MCTS implementation with Python bindings.
+* **trieye>=0.1.3:** Asynchronous statistics collection, processing, logging (MLflow/TensorBoard), and data persistence via a Ray actor.
+* **PyTorch:** For the deep learning model (CNNs, **optional Transformers**, Distributional Value Head) and training, with CUDA/MPS support.
+* **NumPy:** For numerical operations, especially state representation (used by `trianglengin` and features).
+* **Ray:** For parallelizing self-play data generation and hosting the `TrieyeActor`. **Dynamically scales worker count based on available cores.**
+* **Numba:** (Optional, used in `features.grid_features`) For performance optimization of specific grid calculations.
+* **Cloudpickle:** For serializing the experience replay buffer and training checkpoints (used by `trieye`).
+* **MLflow:** For logging parameters, metrics, and artifacts (checkpoints, buffers) during training runs. **Provides the primary web UI dashboard for experiment management. Data stored in `.trieye_data/<app_name>/mlruns/`.** Managed by `trieye`.
+* **TensorBoard:** For visualizing metrics during training (e.g., detailed loss curves). **Provides a secondary web UI dashboard. Data stored in `.trieye_data/<app_name>/runs/<run_name>/tensorboard/`.** Managed by `trieye`.
+* **Pydantic:** For configuration management and data validation (used by `alphatriangle` and `trieye`).
+* **Typer:** For the command-line interface.
+* **Rich:** For enhanced CLI output formatting and styling.
+* **Pytest:** For running unit tests.
+
+## Project Structure
+
+```markdown
+.
+├── .github/workflows/ # GitHub Actions CI/CD
+│   └── ci_cd.yml
+├── .trieye_data/ # Root directory for ALL persistent data (GITIGNORED) - Managed by Trieye
+│   └── alphatriangle/ # Default app_name
+│       ├── mlruns/ # MLflow internal tracking data & artifact store (for MLflow UI)
+│       │   └── <experiment_id>/
+│       │       └── <run_id>/
+│       │           ├── artifacts/ # MLflow's copy of logged artifacts (checkpoints, buffers, etc.)
+│       │           ├── metrics/
+│       │           ├── params/
+│       │           └── tags/
+│       └── runs/ # Local artifacts per run (source for TensorBoard UI & resume)
+│           └── <run_name>/ # e.g., train_YYYYMMDD_HHMMSS
+│               ├── checkpoints/ # Saved model weights & optimizer states (*.pkl)
+│               ├── buffers/ # Saved experience replay buffers (*.pkl)
+│               ├── logs/ # Plain text log files for the run (*.log) - App + Trieye logs
+│               ├── tensorboard/ # TensorBoard log files (event files)
+│               ├── profile_data/ # cProfile output files (*.prof) if profiling enabled
+│               └── configs.json # Copy of run configuration
+├── alphatriangle/ # Source code for the AlphaZero agent package
+│   ├── __init__.py
+│   ├── cli.py # CLI logic (train, ml, tb, ray commands - headless only, uses Rich)
+│   ├── config/ # Pydantic configuration models (Model, Train, MCTS) - Stats/Persistence now in Trieye
+│   │   └── README.md
+│   ├── features/ # Feature extraction logic (operates on trianglengin.GameState)
+│   │   └── README.md
+│   ├── logging_config.py # Centralized logging setup function (for console)
+│   ├── nn/ # Neural network definition and wrapper (implements trimcts.AlphaZeroNetworkInterface)
+│   │   └── README.md
+│   ├── rl/ # RL components (Trainer, Buffer, Worker using trimcts)
+│   │   └── README.md
+│   ├── training/ # Training orchestration (Loop, Setup, Runner, WorkerManager) - Interacts with TrieyeActor
+│   │   └── README.md
+│   └── utils/ # Shared utilities and types (specific to AlphaTriangle)
+│       └── README.md
+├── tests/ # Unit tests (for alphatriangle components, excluding MCTS, Stats, Data)
+│   ├── conftest.py
+│   ├── nn/
+│   ├── rl/
+│   └── training/
+├── .gitignore
+├── .python-version
+├── LICENSE # License file (MIT)
+├── MANIFEST.in # Specifies files for source distribution
+├── pyproject.toml # Build system & package configuration (depends on trianglengin, trimcts, trieye, rich)
+├── README.md # This file
+└── requirements.txt # List of dependencies (includes trianglengin, trimcts, trieye, rich)
+```
+
+## Key Modules (`alphatriangle`)
+
+* **`cli`:** Defines the command-line interface using Typer (**`train`**, **`ml`**, **`tb`**, **`ray`** commands - headless). Uses **`rich`** for styling. ([`alphatriangle/cli.py`](alphatriangle/cli.py))
+* **`config`:** Centralized Pydantic configuration classes (Model, Train, **MCTS**). Imports `EnvConfig` from `trianglengin`. **Uses `TrieyeConfig` from `trieye` for stats/persistence.** ([`alphatriangle/config/README.md`](alphatriangle/config/README.md))
+* **`features`:** Contains logic to convert `trianglengin.GameState` objects into numerical features (`StateType`). ([`alphatriangle/features/README.md`](alphatriangle/features/README.md))
+* **`logging_config`:** Defines the `setup_logging` function for centralized **console** logger configuration. ([`alphatriangle/logging_config.py`](alphatriangle/logging_config.py))
+* **`nn`:** Contains the PyTorch `nn.Module` definition (`AlphaTriangleNet`) and a wrapper class (`NeuralNetwork`). **The `NeuralNetwork` class implicitly conforms to the `trimcts.AlphaZeroNetworkInterface` protocol.** ([`alphatriangle/nn/README.md`](alphatriangle/nn/README.md))
+* **`rl`:** Contains RL components: `Trainer` (network updates), `ExperienceBuffer` (data storage, **supports PER**), and `SelfPlayWorker` (Ray actor for parallel self-play **using `trimcts.run_mcts`**). **Workers now send `RawMetricEvent`s to the `TrieyeActor`.** ([`alphatriangle/rl/README.md`](alphatriangle/rl/README.md))
+* **`training`:** Orchestrates the **headless** training process using `TrainingLoop`, managing workers, data flow, and interaction with the **`TrieyeActor`** for logging, checkpointing, and state loading. Includes `runner.py` for the callable training function. ([`alphatriangle/training/README.md`](alphatriangle/training/README.md))
+* **`utils`:** Provides common helper functions and shared type definitions specific to the AlphaZero implementation. ([`alphatriangle/utils/README.md`](alphatriangle/utils/README.md))
+
+## Setup
+
+1. **Clone the repository (for development):**
+ ```bash
+ git clone https://github.com/lguibr/alphatriangle.git
+ cd alphatriangle
+ ```
+2. **Create a virtual environment (recommended):**
+ ```bash
+ python -m venv venv
+ source venv/bin/activate # On Windows use `venv\Scripts\activate`
+ ```
+3. **Install the package (including `trianglengin`, `trimcts`, `trieye`, and `rich`):**
+ * **For users:**
+ ```bash
+ # This will automatically install trianglengin, trimcts, trieye, and rich from PyPI if available
+ pip install alphatriangle
+ # Or install directly from Git (installs dependencies from PyPI)
+ # pip install git+https://github.com/lguibr/alphatriangle.git
+ ```
+ * **For developers (editable install):**
+ * First, ensure `trianglengin`, `trimcts`, and `trieye` are installed (ideally in editable mode from their own directories if developing all):
+ ```bash
+ # From the trianglengin directory (requires C++ build tools):
+ # pip install -e .
+ # From the trimcts directory (requires C++ build tools):
+ # pip install -e .
+ # From the trieye directory:
+ # pip install -e .
+ ```
+ * Then, install `alphatriangle` in editable mode:
+ ```bash
+ # From the alphatriangle directory:
+ pip install -e .
+ # Install dev dependencies (optional, for running tests/linting)
+ pip install -e .[dev] # Installs dev deps from pyproject.toml
+ ```
+ *Note: Ensure you have the correct PyTorch version installed for your system (CPU/CUDA/MPS). See [pytorch.org](https://pytorch.org/). Ray may have specific system requirements. `trianglengin` and `trimcts` require a C++ compiler (like GCC, Clang, or MSVC) and CMake.*
+4. **(Optional but Recommended) Add data directory to `.gitignore`:**
+ Ensure the `.gitignore` file in your project root contains the line:
+ ```
+ .trieye_data/
+ ```
+
+## Running the Code (CLI)
+
+Use the `alphatriangle` command for training and monitoring. The CLI uses `rich` for enhanced output.
+
+* **Show Help:**
+ ```bash
+ alphatriangle --help
+ ```
+* **Run Training (Headless Only):**
+ ```bash
+ alphatriangle train [--seed 42] [--log-level INFO] [--profile] [--run-name my_custom_run]
+ ```
+ * `--log-level`: Set console logging level (DEBUG, INFO, WARNING, ERROR, CRITICAL). Default: INFO. Logs are saved to `.trieye_data/alphatriangle/runs/<run_name>/logs/`.
+ * `--seed`: Set the random seed for reproducibility. Default: 42.
+ * `--profile`: Enable cProfile for worker 0. Generates `.prof` files in `.trieye_data/alphatriangle/runs/<run_name>/profile_data/`.
+ * `--run-name`: Specify a custom name for the run. Default: `train_YYYYMMDD_HHMMSS`.
+* **Launch MLflow UI:**
+ Launches the MLflow web interface, automatically pointing to the `.trieye_data/alphatriangle/mlruns` directory.
+ ```bash
+ alphatriangle ml [--host 127.0.0.1] [--port 5000]
+ ```
+ Access via `http://localhost:5000` (or the specified host/port).
+* **Launch TensorBoard UI:**
+ Launches the TensorBoard web interface, automatically pointing to the `.trieye_data/alphatriangle/runs` directory (which contains the individual run subdirectories with `tensorboard` logs).
+ ```bash
+ alphatriangle tb [--host 127.0.0.1] [--port 6006]
+ ```
+ Access via `http://localhost:6006` (or the specified host/port).
+* **Launch Ray Dashboard UI:**
+ Launches the Ray Dashboard web interface. **Note:** This typically requires a Ray cluster to be running (e.g., started by `alphatriangle train` or manually).
+ ```bash
+ alphatriangle ray [--host 127.0.0.1] [--port 8265]
+ ```
+ Access via `http://localhost:8265` (or the specified host/port).
+* **Interactive Play/Debug (Use `trianglengin` CLI):**
+ *Note: Interactive modes are part of the `trianglengin` library, not this `alphatriangle` package.*
+ ```bash
+ # Ensure trianglengin is installed
+ trianglengin play [--seed 42] [--log-level INFO]
+ trianglengin debug [--seed 42] [--log-level DEBUG]
+ ```
+* **Running Unit Tests (Development):**
+ ```bash
+ pytest tests/
+ ```
+* **Analyzing Profile Data (if `--profile` was used):**
+ Use the provided `analyze_profiles.py` script (requires `typer`).
+ ```bash
+ python analyze_profiles.py .trieye_data/alphatriangle/runs/<run_name>/profile_data/worker_0_ep_<episode>.prof [-n <num_lines>]
+ ```
+
+## Configuration
+
+* **AlphaTriangle Specific:** Parameters for the Model (`ModelConfig`), Training (`TrainConfig`), and MCTS (`AlphaTriangleMCTSConfig`) are defined in the Pydantic classes within the `alphatriangle/config/` directory.
+* **Environment:** Environment configuration (`EnvConfig`) is defined within the `trianglengin` library.
+* **Stats & Persistence:** Statistics logging and data persistence are configured via `TrieyeConfig` (which includes `StatsConfig` and `PersistenceConfig`) from the `trieye` library. These are typically instantiated in `alphatriangle/training/runner.py` or `alphatriangle/cli.py` (see the sketch below).
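+
+A minimal sketch of that instantiation, mirroring the usage in this repo's tests (nested stats/persistence settings are left at their `trieye` defaults):
+
+```python
+from trieye import TrieyeConfig
+
+# app_name controls the .trieye_data/<app_name>/ data root;
+# run_name becomes the per-run directory under runs/.
+trieye_config = TrieyeConfig(
+    app_name="alphatriangle",
+    run_name="train_20250101_120000",  # hypothetical run name
+)
+```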
+
+## Data Storage
+
+All persistent data is now managed by the `trieye` library and stored within the `.trieye_data/<app_name>/` directory (default: `.trieye_data/alphatriangle/`) in the project root. This directory should be added to your `.gitignore`.
+
+* **`.trieye_data/<app_name>/mlruns/`**: Managed by **MLflow** (via `trieye`). Contains MLflow's internal tracking data (parameters, metrics) and its own copy of logged artifacts. This is the source for the MLflow UI (`alphatriangle ml`).
+* **`.trieye_data/<app_name>/runs/`**: Managed by **`trieye`**. Contains locally saved artifacts for each run (checkpoints, buffers, TensorBoard logs, configs, **profile data**) before/during logging to MLflow. This directory is used for auto-resuming and is the source for the TensorBoard UI (`alphatriangle tb`).
+ * **Replay Buffer Content:** The saved buffer file (`buffer.pkl`) contains `Experience` tuples: `(StateType, PolicyTargetMapping, n_step_return)` (sketched after this list). The `StateType` includes:
+ * `grid`: Numerical features representing grid occupancy.
+ * `other_features`: Numerical features derived from game state and available shapes.
+ * **Visualization:** This stored data allows offline analysis and visualization of grid occupancy and the available shapes for each recorded step. It does **not** contain the full sequence of actions or raw `GameState` objects needed for a complete, interactive game replay.
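+
+A minimal sketch of that stored tuple layout (array shapes here are hypothetical; real shapes come from `ModelConfig` and `EnvConfig`):
+
+```python
+import numpy as np
+
+# StateType: dict of numeric feature arrays (shapes illustrative only).
+state = {
+    "grid": np.zeros((1, 8, 15), dtype=np.float32),       # grid occupancy features
+    "other_features": np.zeros((30,), dtype=np.float32),  # game/shape features
+}
+policy_target = {0: 1.0}  # PolicyTargetMapping: action index -> visit probability
+n_step_return = 0.0       # discounted n-step return from this state
+
+experience = (state, policy_target, n_step_return)  # one Experience tuple
+
+# Consumers (buffer, trainer) unpack it positionally:
+s, pi, ret = experience
+assert "grid" in s and "other_features" in s
+```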
+
+## Maintainability
+
+This project includes README files within each major `alphatriangle` submodule. **Please keep these READMEs updated** when making changes to the code's structure, interfaces, or core logic, especially regarding the interaction with the `trieye` library.
+
+
+File: alphatriangle.egg-info/SOURCES.txt
+.python-version
+LICENSE
+MANIFEST.in
+README.md
+pyproject.toml
+requirements.txt
+alphatriangle/__init__.py
+alphatriangle/cli.py
+alphatriangle/logging_config.py
+alphatriangle.egg-info/PKG-INFO
+alphatriangle.egg-info/SOURCES.txt
+alphatriangle.egg-info/dependency_links.txt
+alphatriangle.egg-info/entry_points.txt
+alphatriangle.egg-info/requires.txt
+alphatriangle.egg-info/top_level.txt
+alphatriangle/config/README.md
+alphatriangle/config/__init__.py
+alphatriangle/config/app_config.py
+alphatriangle/config/mcts_config.py
+alphatriangle/config/model_config.py
+alphatriangle/config/train_config.py
+alphatriangle/config/validation.py
+alphatriangle/features/README.md
+alphatriangle/features/__init__.py
+alphatriangle/features/extractor.py
+alphatriangle/features/grid_features.py
+alphatriangle/features/__pycache__/grid_features.count_holes-20.py310.1.nbc
+alphatriangle/features/__pycache__/grid_features.count_holes-20.py310.nbi
+alphatriangle/features/__pycache__/grid_features.count_holes-21.py310.1.nbc
+alphatriangle/features/__pycache__/grid_features.count_holes-21.py310.nbi
+alphatriangle/features/__pycache__/grid_features.count_holes-22.py310.1.nbc
+alphatriangle/features/__pycache__/grid_features.count_holes-22.py310.nbi
+alphatriangle/features/__pycache__/grid_features.get_bumpiness-34.py310.1.nbc
+alphatriangle/features/__pycache__/grid_features.get_bumpiness-34.py310.nbi
+alphatriangle/features/__pycache__/grid_features.get_bumpiness-35.py310.1.nbc
+alphatriangle/features/__pycache__/grid_features.get_bumpiness-35.py310.nbi
+alphatriangle/features/__pycache__/grid_features.get_bumpiness-36.py310.1.nbc
+alphatriangle/features/__pycache__/grid_features.get_bumpiness-36.py310.nbi
+alphatriangle/features/__pycache__/grid_features.get_column_heights-5.py310.1.nbc
+alphatriangle/features/__pycache__/grid_features.get_column_heights-5.py310.nbi
+alphatriangle/features/__pycache__/grid_features.get_column_heights-6.py310.1.nbc
+alphatriangle/features/__pycache__/grid_features.get_column_heights-6.py310.nbi
+alphatriangle/features/__pycache__/grid_features.get_column_heights-7.py310.1.nbc
+alphatriangle/features/__pycache__/grid_features.get_column_heights-7.py310.nbi
+alphatriangle/nn/README.md
+alphatriangle/nn/__init__.py
+alphatriangle/nn/model.py
+alphatriangle/nn/network.py
+alphatriangle/rl/README.md
+alphatriangle/rl/__init__.py
+alphatriangle/rl/types.py
+alphatriangle/rl/core/README.md
+alphatriangle/rl/core/__init__.py
+alphatriangle/rl/core/buffer.py
+alphatriangle/rl/core/trainer.py
+alphatriangle/rl/self_play/README.md
+alphatriangle/rl/self_play/__init__.py
+alphatriangle/rl/self_play/mcts_helpers.py
+alphatriangle/rl/self_play/worker.py
+alphatriangle/training/README.md
+alphatriangle/training/__init__.py
+alphatriangle/training/components.py
+alphatriangle/training/logging_utils.py
+alphatriangle/training/loop.py
+alphatriangle/training/loop_helpers.py
+alphatriangle/training/runner.py
+alphatriangle/training/runners.py
+alphatriangle/training/setup.py
+alphatriangle/training/worker_manager.py
+alphatriangle/utils/README.md
+alphatriangle/utils/__init__.py
+alphatriangle/utils/geometry.py
+alphatriangle/utils/helpers.py
+alphatriangle/utils/sumtree.py
+alphatriangle/utils/types.py
+tests/__init__.py
+tests/conftest.py
+tests/nn/__init__.py
+tests/nn/test_model.py
+tests/nn/test_network.py
+tests/rl/__init__.py
+tests/rl/test_buffer.py
+tests/rl/test_trainer.py
+tests/training/test_loop_integration.py
+
+File: alphatriangle.egg-info/entry_points.txt
+[console_scripts]
+alphatriangle = alphatriangle.cli:app
+
+
+File: alphatriangle.egg-info/requires.txt
+trianglengin>=2.0.7
+trimcts>=1.2.1
+trieye>=0.1.3
+numpy>=1.20.0
+torch>=2.0.0
+torchvision>=0.11.0
+cloudpickle>=2.0.0
+numba>=0.55.0
+mlflow>=1.20.0
+ray[default]
+pydantic>=2.0.0
+typing_extensions>=4.0.0
+typer[all]>=0.9.0
+tensorboard>=2.10.0
+rich>=13.0.0
+
+[dev]
+pytest>=7.0.0
+pytest-cov>=3.0.0
+pytest-mock>=3.0.0
+ruff
+mypy
+build
+twine
+codecov
+
+
+File: alphatriangle.egg-info/top_level.txt
+alphatriangle
+tests
+
+
+File: alphatriangle.egg-info/dependency_links.txt
+
+
+
+File: mlruns/0/meta.yaml
+artifact_location: file:///Users/lg/lab/alphatriangle/mlruns/0
+creation_time: 1745378338253
+experiment_id: '0'
+last_update_time: 1745378338253
+lifecycle_stage: active
+name: Default
+
+
+File: mlruns/0/f0f522a9c2de48989d65674400a88c89/meta.yaml
+artifact_uri: file:///Users/lg/lab/alphatriangle/mlruns/0/f0f522a9c2de48989d65674400a88c89/artifacts
+end_time: 1745378340460
+entry_point_name: ''
+experiment_id: '0'
+lifecycle_stage: active
+run_id: f0f522a9c2de48989d65674400a88c89
+run_name: capable-auk-759
+run_uuid: f0f522a9c2de48989d65674400a88c89
+source_name: ''
+source_type: 4
+source_version: ''
+start_time: 1745378338322
+status: 3
+tags: []
+user_id: lg
+
+
+File: mlruns/0/f0f522a9c2de48989d65674400a88c89/tags/mlflow.user
+lg
+
+File: mlruns/0/f0f522a9c2de48989d65674400a88c89/tags/mlflow.runName
+capable-auk-759
+
+File: mlruns/0/f0f522a9c2de48989d65674400a88c89/tags/mlflow.source.name
+/Users/lg/lab/triii/.venv/bin/alphatriangle
+
+File: mlruns/0/f0f522a9c2de48989d65674400a88c89/tags/mlflow.source.type
+LOCAL
+
+File: mlruns/0/f0f522a9c2de48989d65674400a88c89/params/GAMMA
+0.99
+
+File: mlruns/0/f0f522a9c2de48989d65674400a88c89/params/COLS
+15
+
+File: mlruns/0/f0f522a9c2de48989d65674400a88c89/params/VALUE_MIN
+-10.0
+
+File: mlruns/0/f0f522a9c2de48989d65674400a88c89/params/CHECKPOINT_SAVE_FREQ_STEPS
+2500
+
+File: mlruns/0/f0f522a9c2de48989d65674400a88c89/params/PENALTY_GAME_OVER
+-10.0
+
+File: mlruns/0/f0f522a9c2de48989d65674400a88c89/params/COMPILE_MODEL
+True
+
+File: mlruns/0/f0f522a9c2de48989d65674400a88c89/params/OTHER_NN_INPUT_FEATURES_DIM
+30
+
+File: mlruns/0/f0f522a9c2de48989d65674400a88c89/params/PER_BETA_ANNEAL_STEPS
+100000
+
+File: mlruns/0/f0f522a9c2de48989d65674400a88c89/params/LOAD_BUFFER_PATH
+None
+
+File: mlruns/0/f0f522a9c2de48989d65674400a88c89/params/REWARD_PER_PLACED_TRIANGLE
+0.01
+
+File: mlruns/0/f0f522a9c2de48989d65674400a88c89/params/PER_BETA_INITIAL
+0.4
+
+File: mlruns/0/f0f522a9c2de48989d65674400a88c89/params/BUFFER_CAPACITY
+250000
+
+File: mlruns/0/f0f522a9c2de48989d65674400a88c89/params/VALUE_LOSS_WEIGHT
+1.0
+
+File: mlruns/0/f0f522a9c2de48989d65674400a88c89/params/LEARNING_RATE
+0.0002
+
+File: mlruns/0/f0f522a9c2de48989d65674400a88c89/params/LOAD_CHECKPOINT_PATH
+None
+
+File: mlruns/0/f0f522a9c2de48989d65674400a88c89/params/mcts_dirichlet_epsilon
+0.25
+
+File: mlruns/0/f0f522a9c2de48989d65674400a88c89/params/USE_PER
+True
+
+File: mlruns/0/f0f522a9c2de48989d65674400a88c89/params/USE_BATCH_NORM
+True
+
+File: mlruns/0/f0f522a9c2de48989d65674400a88c89/params/POLICY_LOSS_WEIGHT
+1.0
+
+File: mlruns/0/f0f522a9c2de48989d65674400a88c89/params/PER_ALPHA
+0.6
+
+File: mlruns/0/f0f522a9c2de48989d65674400a88c89/params/GRADIENT_CLIP_VALUE
+1.0
+
+File: mlruns/0/f0f522a9c2de48989d65674400a88c89/params/VALUE_MAX
+10.0
+
+File: mlruns/0/f0f522a9c2de48989d65674400a88c89/params/WEIGHT_DECAY
+0.0001
+
+File: mlruns/0/f0f522a9c2de48989d65674400a88c89/params/WORKER_UPDATE_FREQ_STEPS
+10
+
+File: mlruns/0/f0f522a9c2de48989d65674400a88c89/params/RANDOM_SEED
+42
+
+File: mlruns/0/f0f522a9c2de48989d65674400a88c89/params/LR_SCHEDULER_TYPE
+CosineAnnealingLR
+
+File: mlruns/0/f0f522a9c2de48989d65674400a88c89/params/mcts_dirichlet_alpha
+0.3
+
+File: mlruns/0/f0f522a9c2de48989d65674400a88c89/params/mcts_discount
+1.0
+
+File: mlruns/0/f0f522a9c2de48989d65674400a88c89/params/mcts_cpuct
+1.5
+
+File: mlruns/0/f0f522a9c2de48989d65674400a88c89/params/TRANSFORMER_FC_DIM
+256
+
+File: mlruns/0/f0f522a9c2de48989d65674400a88c89/params/MAX_TRAINING_STEPS
+100000
+
+File: mlruns/0/f0f522a9c2de48989d65674400a88c89/params/REWARD_PER_STEP_ALIVE
+0.005
+
+File: mlruns/0/f0f522a9c2de48989d65674400a88c89/params/MIN_BUFFER_SIZE_TO_TRAIN
+25000
+
+File: mlruns/0/f0f522a9c2de48989d65674400a88c89/params/PER_BETA_FINAL
+1.0
+
+File: mlruns/0/f0f522a9c2de48989d65674400a88c89/params/WORKER_DEVICE
+cpu
+
+File: mlruns/0/f0f522a9c2de48989d65674400a88c89/params/BATCH_SIZE
+256
+
+File: mlruns/0/f0f522a9c2de48989d65674400a88c89/params/AUTO_RESUME_LATEST
+True
+
+File: mlruns/0/f0f522a9c2de48989d65674400a88c89/params/OPTIMIZER_TYPE
+AdamW
+
+File: mlruns/0/f0f522a9c2de48989d65674400a88c89/params/USE_TRANSFORMER
+True
+
+File: mlruns/0/f0f522a9c2de48989d65674400a88c89/params/DEVICE
+auto
+
+File: mlruns/0/f0f522a9c2de48989d65674400a88c89/params/mcts_mcts_batch_size
+32
+
+File: mlruns/0/f0f522a9c2de48989d65674400a88c89/params/TRANSFORMER_DIM
+128
+
+File: mlruns/0/f0f522a9c2de48989d65674400a88c89/params/RUN_NAME
+train_20250423_001854
+
+File: mlruns/0/f0f522a9c2de48989d65674400a88c89/params/PLAYABLE_RANGE_PER_ROW
+[(3, 12), (2, 13), (1, 14), (0, 15), (0, 15), (1, 14), (2, 13), (3, 12)]
+
+File: mlruns/0/f0f522a9c2de48989d65674400a88c89/params/mcts_max_simulations
+64
+
+File: mlruns/0/f0f522a9c2de48989d65674400a88c89/params/NUM_VALUE_ATOMS
+51
+
+File: mlruns/0/f0f522a9c2de48989d65674400a88c89/params/LR_SCHEDULER_ETA_MIN
+1e-06
+
+File: mlruns/0/f0f522a9c2de48989d65674400a88c89/params/ACTIVATION_FUNCTION
+ReLU
+
+File: mlruns/0/f0f522a9c2de48989d65674400a88c89/params/REWARD_PER_CLEARED_TRIANGLE
+0.5
+
+File: mlruns/0/f0f522a9c2de48989d65674400a88c89/params/LR_SCHEDULER_T_MAX
+100000
+
+File: mlruns/0/f0f522a9c2de48989d65674400a88c89/params/RESIDUAL_BLOCK_FILTERS
+128
+
+File: mlruns/0/f0f522a9c2de48989d65674400a88c89/params/PROFILE_WORKERS
+False
+
+File: mlruns/0/f0f522a9c2de48989d65674400a88c89/params/ENTROPY_BONUS_WEIGHT
+0.001
+
+File: mlruns/0/f0f522a9c2de48989d65674400a88c89/params/PER_EPSILON
+1e-05
+
+File: alphatriangle/__init__.py
+
+
+File: alphatriangle/cli.py
+# File: alphatriangle/cli.py
+import logging
+import shutil
+import subprocess
+import sys
+from typing import Annotated
+
+import typer
+from rich.console import Console
+from rich.panel import Panel
+
+# Import Trieye config
+from trieye import PersistenceConfig, TrieyeConfig
+
+# Import alphatriangle specific configs and runner
+from alphatriangle.config import (
+ APP_NAME, # Use APP_NAME from config
+ TrainConfig,
+)
+from alphatriangle.logging_config import setup_logging # Import centralized setup
+from alphatriangle.training.runners import run_training
+
+# Initialize Rich Console
+console = Console()
+
+# Use a standard string for the main help, rely on rich_markup_mode for styling
+app_help_text = (
+ "▲ AlphaTriangle CLI ▲\nAlphaZero training pipeline for a triangle puzzle game."
+)
+
+app = typer.Typer(
+ name="alphatriangle",
+ help=app_help_text, # Use the standard string here
+ add_completion=False,
+ rich_markup_mode="markdown", # Keep markdown enabled for help text formatting
+ pretty_exceptions_show_locals=False,
+)
+
+# --- CLI Option Annotations ---
+LogLevelOption = Annotated[
+ str,
+ typer.Option(
+ "--log-level",
+ "-l",
+ help="Set the logging level (DEBUG, INFO, WARNING, ERROR, CRITICAL).",
+ case_sensitive=False,
+ ),
+]
+
+SeedOption = Annotated[
+ int,
+ typer.Option(
+ "--seed",
+ "-s",
+ help="Random seed for reproducibility.",
+ ),
+]
+
+ProfileOption = Annotated[
+ bool,
+ typer.Option(
+ "--profile",
+ help="Enable cProfile for worker 0.",
+ is_flag=True,
+ ),
+]
+
+RunNameOption = Annotated[
+ str | None,
+ typer.Option(
+ "--run-name",
+ help="Specify a custom name for the run (overrides default timestamp).",
+ ),
+]
+
+
+HostOption = Annotated[
+ str, typer.Option(help="The network address to listen on (default: 127.0.0.1).")
+]
+
+PortOption = Annotated[int, typer.Option(help="The port to listen on.")]
+
+
+# --- Helper Function ---
+def _run_external_ui(
+ command_name: str,
+ command_args: list[str],
+ ui_name: str,
+ default_url: str,
+):
+ """Runs an external UI command and handles common errors."""
+ logger = logging.getLogger(__name__)
+ executable = shutil.which(command_name)
+ if not executable:
+ console.print(
+ f"[bold red]Error:[/bold red] Could not find '{command_name}' executable in PATH."
+ )
+ console.print(
+ f"Please ensure {ui_name} is installed and accessible in your environment."
+ )
+ raise typer.Exit(code=1)
+
+ full_command = [executable] + command_args
+ console.print(
+ Panel(
+ f"🚀 Launching [bold cyan]{ui_name}[/]...\n"
+ f" ❯ Command: [dim]{' '.join(full_command)}[/]\n"
+ f" ❯ Access URL (approx): [link={default_url}]{default_url}[/]",
+ title="External UI",
+ border_style="blue",
+ expand=False,
+ )
+ )
+
+ try:
+ # Use Popen for non-blocking execution if needed, but run is simpler for now
+ process = subprocess.run(full_command, check=False)
+ if process.returncode != 0:
+ console.print(
+ f"[bold red]Error:[/bold red] {ui_name} command failed with exit code {process.returncode}"
+ )
+ # Don't exit immediately, let the calling function handle it if needed
+ # raise typer.Exit(code=process.returncode)
+ except FileNotFoundError as e:
+ console.print(
+ f"[bold red]Error:[/bold red] '{executable}' command not found. Is {ui_name} installed and in your PATH?"
+ )
+ raise typer.Exit(code=1) from e
+ except KeyboardInterrupt:
+ console.print(f"\n[yellow]🟡 {ui_name} interrupted by user.[/]")
+ raise typer.Exit(code=0) from None
+ except Exception as e:
+ console.print(
+ f"[bold red]❌ An unexpected error occurred launching {ui_name}:[/]"
+ )
+ logger.error(f"Unexpected error launching {ui_name}: {e}", exc_info=True)
+ raise typer.Exit(code=1) from e
+
+
+# --- CLI Commands ---
+@app.command()
+def train(
+ log_level: LogLevelOption = "INFO",
+ seed: SeedOption = 42,
+ profile: ProfileOption = False,
+ run_name: RunNameOption = None, # Add run_name option
+):
+ """
+ 🚀 Run the AlphaTriangle training pipeline (headless).
+
+ Initiates the self-play and learning process. Uses Trieye for stats/persistence.
+ Logs will be saved to the run directory within `.trieye_data/alphatriangle/runs/`.
+ This command also initializes Ray and starts the Ray Dashboard. Check the logs for the dashboard URL.
+ """
+ # Setup logging using the centralized function (file logging handled by Trieye)
+ setup_logging(log_level)
+ logging.getLogger(__name__) # Get logger after setup
+
+ # Use alphatriangle TrainConfig
+ train_config_override = TrainConfig()
+ train_config_override.RANDOM_SEED = seed
+ train_config_override.PROFILE_WORKERS = profile # Set profile config
+
+ # Create TrieyeConfig, overriding run_name if provided
+ trieye_config_override = TrieyeConfig(app_name=APP_NAME)
+ if run_name:
+ trieye_config_override.run_name = run_name
+ # Sync run_name to persistence config within TrieyeConfig
+ trieye_config_override.persistence.RUN_NAME = run_name
+ else:
+ # Use the default factory-generated run_name from TrieyeConfig
+ run_name = trieye_config_override.run_name
+
+ console.print(
+ Panel(
+ f"Starting Training Run: '[bold cyan]{run_name}[/]'\n"
+ f"Seed: {seed}, Log Level: {log_level.upper()}, Profiling: {'✅ Enabled' if profile else '❌ Disabled'}",
+ title="[bold green]Training Setup[/]",
+ border_style="green",
+ expand=False,
+ )
+ )
+
+ # Call the single runner function directly, passing configs
+ exit_code = run_training(
+ log_level_str=log_level,
+ train_config_override=train_config_override,
+ trieye_config_override=trieye_config_override,
+ profile=profile,
+ )
+
+ if exit_code == 0:
+ console.print(
+ Panel(
+ f"✅ Training run '[bold cyan]{run_name}[/]' completed successfully.",
+ title="[bold green]Training Finished[/]",
+ border_style="green",
+ )
+ )
+ else:
+ console.print(
+ Panel(
+ f"❌ Training run '[bold cyan]{run_name}[/]' failed with exit code {exit_code}.",
+ title="[bold red]Training Failed[/]",
+ border_style="red",
+ )
+ )
+ sys.exit(exit_code)
+
+
+@app.command()
+def ml(
+ host: HostOption = "127.0.0.1",
+ port: PortOption = 5000,
+):
+ """
+ 📊 Launch the MLflow UI for experiment tracking.
+
+    Requires MLflow to be installed. Points to the `.trieye_data/<app_name>/mlruns` directory.
+ """
+ setup_logging("INFO") # Basic logging for this command
+ # Use Trieye's PersistenceConfig to find the path
+ persist_config = PersistenceConfig(APP_NAME=APP_NAME)
+ mlflow_uri = persist_config.MLFLOW_TRACKING_URI
+ mlflow_path = persist_config.get_mlflow_abs_path()
+
+ if not mlflow_path.exists() or not any(mlflow_path.iterdir()):
+ console.print(
+ f"[yellow]Warning:[/yellow] MLflow directory not found or empty at expected location: [dim]{mlflow_path}[/]"
+ )
+ console.print("[yellow]Attempting to launch MLflow UI anyway...[/]")
+ else:
+ console.print(f"Found MLflow data at: [dim]{mlflow_path}[/]")
+
+ command_args = [
+ "ui",
+ "--backend-store-uri",
+ mlflow_uri,
+ "--host",
+ host,
+ "--port",
+ str(port),
+ ]
+ try:
+ _run_external_ui("mlflow", command_args, "MLflow UI", f"http://{host}:{port}")
+ except typer.Exit as e:
+ if e.exit_code != 0:
+ console.print(
+ f"[yellow]MLflow UI failed to start (Exit Code: {e.exit_code}). "
+ f"Is port {port} already in use? Try specifying a different port with --port.[/]"
+ )
+ sys.exit(e.exit_code)
+
+
+@app.command()
+def tb(
+ host: HostOption = "127.0.0.1",
+ port: PortOption = 6006,
+):
+ """
+ 📈 Launch TensorBoard UI pointing to the runs directory.
+
+    Requires TensorBoard to be installed. Points to the `.trieye_data/<app_name>/runs` directory.
+ """
+ setup_logging("INFO") # Basic logging for this command
+ # Use Trieye's PersistenceConfig to find the path
+ persist_config = PersistenceConfig(APP_NAME=APP_NAME)
+ runs_root_dir = persist_config.get_runs_root_dir()
+
+ if not runs_root_dir.exists() or not any(runs_root_dir.iterdir()):
+ console.print(
+ f"[yellow]Warning:[/yellow] TensorBoard 'runs' directory not found or empty at: [dim]{runs_root_dir}[/]"
+ )
+ console.print(
+ "[yellow]Attempting to launch TensorBoard UI anyway (it might show no data)...[/]"
+ )
+ else:
+ console.print(f"Found TensorBoard runs data at: [dim]{runs_root_dir}[/]")
+
+ command_args = [
+ "--logdir",
+ str(runs_root_dir), # Pass the absolute path string
+ "--host",
+ host,
+ "--port",
+ str(port),
+ ]
+ try:
+ _run_external_ui(
+ "tensorboard", command_args, "TensorBoard UI", f"http://{host}:{port}"
+ )
+ except typer.Exit as e:
+ if e.exit_code != 0:
+ console.print(
+ f"[yellow]TensorBoard UI failed to start (Exit Code: {e.exit_code}). "
+ f"Is port {port} already in use? Try specifying a different port with --port.[/]"
+ )
+ sys.exit(e.exit_code)
+
+
+@app.command()
+def ray(
+ host: HostOption = "127.0.0.1", # Keep host/port options for reference
+ port: PortOption = 8265,
+):
+ """
+ ☀️ Provides instructions to view the Ray Dashboard.
+
+ The dashboard is automatically started when you run `alphatriangle train`.
+ Check the output logs of the `train` command for the correct URL.
+ """
+ setup_logging("INFO") # Basic logging for this command
+ console.print(
+ Panel(
+ f"💡 To view the Ray Dashboard:\n\n"
+ f"1. The Ray Dashboard is started automatically when you run the `[bold]alphatriangle train[/]` command.\n"
+ f"2. Check the console output or the log file for the `train` command (located in `.trieye_data/{APP_NAME}/runs//logs/`).\n"
+ f"3. Look for a line similar to: '[bold cyan]Ray Dashboard running at: http://:[/]' \n"
+ f"4. Open that specific URL in your web browser.\n\n"
+ f"[dim]Note: The default URL is often http://{host}:{port}, but it might differ. "
+ f"If you cannot access the URL, check firewall settings or if the port is blocked.[/]",
+ title="[bold yellow]Ray Dashboard Instructions[/]",
+ border_style="yellow",
+ expand=False,
+ )
+ )
+
+
+if __name__ == "__main__":
+ app()
+
+
+File: alphatriangle/nn/__init__.py
+"""
+Neural Network module for the AlphaTriangle agent.
+Contains the model definition and a wrapper for inference and training interface.
+"""
+
+from .model import AlphaTriangleNet
+from .network import NeuralNetwork
+
+__all__ = [
+ "AlphaTriangleNet",
+ "NeuralNetwork",
+]
+
+
+File: alphatriangle/nn/model.py
+# File: alphatriangle/nn/model.py
+import math
+from typing import cast
+
+import torch
+import torch.nn as nn
+
+# Import EnvConfig from trianglengin's top level
+from trianglengin import EnvConfig # UPDATED IMPORT
+
+# Keep alphatriangle ModelConfig import
+from ..config import ModelConfig
+
+
+def conv_block(
+ in_channels: int,
+ out_channels: int,
+ kernel_size: int | tuple[int, int],
+ stride: int | tuple[int, int],
+ padding: int | tuple[int, int] | str,
+ use_batch_norm: bool,
+ activation: type[nn.Module],
+) -> nn.Sequential:
+ """Creates a standard convolutional block."""
+ layers: list[nn.Module] = [
+ nn.Conv2d(
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride,
+ padding,
+ bias=not use_batch_norm,
+ )
+ ]
+ if use_batch_norm:
+ layers.append(nn.BatchNorm2d(out_channels))
+ layers.append(activation())
+ return nn.Sequential(*layers)
+
+
+class ResidualBlock(nn.Module):
+ """Standard Residual Block."""
+
+ def __init__(
+ self, channels: int, use_batch_norm: bool, activation: type[nn.Module]
+ ):
+ super().__init__()
+ self.conv1 = conv_block(channels, channels, 3, 1, 1, use_batch_norm, activation)
+ self.conv2 = nn.Conv2d(channels, channels, 3, 1, 1, bias=not use_batch_norm)
+ self.bn2 = nn.BatchNorm2d(channels) if use_batch_norm else nn.Identity()
+ self.activation = activation()
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ residual = x
+ out: torch.Tensor = self.conv1(x)
+ out = self.conv2(out)
+ out = self.bn2(out)
+ out += residual
+ out = self.activation(out)
+ return out
+
+
+class PositionalEncoding(nn.Module):
+ """Injects sinusoidal positional encoding. (Adapted from PyTorch tutorial)"""
+
+ def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
+ super().__init__()
+ if d_model <= 0:
+ raise ValueError("d_model must be positive for PositionalEncoding")
+ self.dropout = nn.Dropout(p=dropout)
+
+ position = torch.arange(max_len).unsqueeze(1) # Shape: [max_len, 1]
+ div_term = torch.exp(
+ torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)
+ ) # Shape: [d_model / 2]
+ pe = torch.zeros(max_len, d_model) # Shape: [max_len, d_model]
+
+ pe[:, 0::2] = torch.sin(position * div_term)
+ if d_model > 1:
+ pe[:, 1::2] = torch.cos(position * div_term)
+
+ pe = pe.unsqueeze(1)
+ self.register_buffer("pe", pe)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """
+ Args:
+ x: Tensor, shape [seq_len, batch_size, embedding_dim]
+ Returns:
+ Tensor with added positional encoding.
+ """
+ pe_buffer = self.pe
+ if not isinstance(pe_buffer, torch.Tensor):
+ raise TypeError("PositionalEncoding buffer 'pe' is not a Tensor.")
+
+ if x.shape[0] > pe_buffer.shape[0]:
+ raise ValueError(
+ f"Input sequence length {x.shape[0]} exceeds max_len {pe_buffer.shape[0]} of PositionalEncoding"
+ )
+ if x.shape[2] != pe_buffer.shape[2]:
+ raise ValueError(
+ f"Input embedding dimension {x.shape[2]} does not match PositionalEncoding dimension {pe_buffer.shape[2]}"
+ )
+
+ x = x + pe_buffer[: x.size(0)]
+ return cast("torch.Tensor", self.dropout(x))
+
+
+class AlphaTriangleNet(nn.Module):
+ """
+ Neural Network architecture for AlphaTriangle.
+ Includes optional Transformer Encoder block after CNN body.
+ Supports Distributional Value Head (C51).
+ """
+
+ def __init__(
+ self, model_config: ModelConfig, env_config: EnvConfig
+ ): # Uses trianglengin.EnvConfig
+ super().__init__()
+ self.model_config = model_config
+ self.env_config = env_config
+ # Calculate action_dim manually
+ self.action_dim = int(
+ env_config.NUM_SHAPE_SLOTS * env_config.ROWS * env_config.COLS
+ )
+
+ activation_cls: type[nn.Module] = getattr(nn, model_config.ACTIVATION_FUNCTION)
+
+ # --- CNN Body ---
+ conv_layers: list[nn.Module] = []
+ in_channels = model_config.GRID_INPUT_CHANNELS
+ for i, out_channels in enumerate(model_config.CONV_FILTERS):
+ conv_layers.append(
+ conv_block(
+ in_channels,
+ out_channels,
+ model_config.CONV_KERNEL_SIZES[i],
+ model_config.CONV_STRIDES[i],
+ model_config.CONV_PADDING[i],
+ model_config.USE_BATCH_NORM,
+ activation_cls,
+ )
+ )
+ in_channels = out_channels
+ self.conv_body = nn.Sequential(*conv_layers)
+
+ # --- Residual Body ---
+ res_layers: list[nn.Module] = []
+ if model_config.NUM_RESIDUAL_BLOCKS > 0:
+ res_channels = model_config.RESIDUAL_BLOCK_FILTERS
+ if in_channels != res_channels:
+ res_layers.append(
+ conv_block(
+ in_channels,
+ res_channels,
+ 1,
+ 1,
+ 0,
+ model_config.USE_BATCH_NORM,
+ activation_cls,
+ )
+ )
+ in_channels = res_channels
+ for _ in range(model_config.NUM_RESIDUAL_BLOCKS):
+ res_layers.append(
+ ResidualBlock(
+ in_channels, model_config.USE_BATCH_NORM, activation_cls
+ )
+ )
+ self.res_body = nn.Sequential(*res_layers)
+ self.cnn_output_channels = in_channels
+
+ # --- Transformer Body (Optional) ---
+ self.transformer_body = None
+ self.pos_encoder = None
+ self.input_proj: nn.Module = nn.Identity()
+ self.transformer_output_size = 0
+
+ if model_config.USE_TRANSFORMER and model_config.TRANSFORMER_LAYERS > 0:
+ transformer_input_dim = model_config.TRANSFORMER_DIM
+ if self.cnn_output_channels != transformer_input_dim:
+ self.input_proj = nn.Conv2d(
+ self.cnn_output_channels, transformer_input_dim, kernel_size=1
+ )
+ else:
+ self.input_proj = nn.Identity()
+
+ self.pos_encoder = PositionalEncoding(transformer_input_dim, dropout=0.1)
+ encoder_layer = nn.TransformerEncoderLayer(
+ d_model=transformer_input_dim,
+ nhead=model_config.TRANSFORMER_HEADS,
+ dim_feedforward=model_config.TRANSFORMER_FC_DIM,
+ activation=model_config.ACTIVATION_FUNCTION.lower(),
+ batch_first=False,
+ norm_first=True,
+ )
+ transformer_norm = nn.LayerNorm(transformer_input_dim)
+ self.transformer_body = nn.TransformerEncoder(
+ encoder_layer,
+ num_layers=model_config.TRANSFORMER_LAYERS,
+ norm=transformer_norm,
+ )
+
+ dummy_input_grid = torch.zeros(
+ 1, model_config.GRID_INPUT_CHANNELS, env_config.ROWS, env_config.COLS
+ )
+ with torch.no_grad():
+ cnn_out = self.conv_body(dummy_input_grid)
+ res_out = self.res_body(cnn_out)
+ proj_out = self.input_proj(res_out)
+ b, d, h, w = proj_out.shape
+ self.transformer_output_size = h * w * d
+ else:
+ dummy_input_grid = torch.zeros(
+ 1, model_config.GRID_INPUT_CHANNELS, env_config.ROWS, env_config.COLS
+ )
+ with torch.no_grad():
+ conv_output = self.conv_body(dummy_input_grid)
+ res_output = self.res_body(conv_output)
+ self.flattened_cnn_size = res_output.numel()
+
+ # --- Shared Fully Connected Layers ---
+ if model_config.USE_TRANSFORMER and model_config.TRANSFORMER_LAYERS > 0:
+ combined_input_size = (
+ self.transformer_output_size + model_config.OTHER_NN_INPUT_FEATURES_DIM
+ )
+ else:
+ combined_input_size = (
+ self.flattened_cnn_size + model_config.OTHER_NN_INPUT_FEATURES_DIM
+ )
+
+ shared_fc_layers: list[nn.Module] = []
+ in_features = combined_input_size
+ for hidden_dim in model_config.FC_DIMS_SHARED:
+ shared_fc_layers.append(nn.Linear(in_features, hidden_dim))
+ if model_config.USE_BATCH_NORM:
+ shared_fc_layers.append(nn.BatchNorm1d(hidden_dim))
+ shared_fc_layers.append(activation_cls())
+ in_features = hidden_dim
+ self.shared_fc = nn.Sequential(*shared_fc_layers)
+
+ # --- Policy Head ---
+ policy_head_layers: list[nn.Module] = []
+ policy_in_features = in_features
+ for hidden_dim in model_config.POLICY_HEAD_DIMS:
+ policy_head_layers.append(nn.Linear(policy_in_features, hidden_dim))
+ if model_config.USE_BATCH_NORM:
+ policy_head_layers.append(nn.BatchNorm1d(hidden_dim))
+ policy_head_layers.append(activation_cls())
+ policy_in_features = hidden_dim
+ policy_head_layers.append(nn.Linear(policy_in_features, self.action_dim))
+ self.policy_head = nn.Sequential(*policy_head_layers)
+
+ # --- Value Head (Distributional) ---
+ value_head_layers: list[nn.Module] = []
+ value_in_features = in_features
+ for hidden_dim in model_config.VALUE_HEAD_DIMS:
+ value_head_layers.append(nn.Linear(value_in_features, hidden_dim))
+ if model_config.USE_BATCH_NORM:
+ value_head_layers.append(nn.BatchNorm1d(hidden_dim))
+ value_head_layers.append(activation_cls())
+ value_in_features = hidden_dim
+ value_head_layers.append(
+ nn.Linear(value_in_features, model_config.NUM_VALUE_ATOMS)
+ )
+ self.value_head = nn.Sequential(*value_head_layers)
+
+ def forward(
+ self, grid_state: torch.Tensor, other_features: torch.Tensor
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ """
+ Forward pass through the network.
+ Returns: (policy_logits, value_distribution_logits)
+ """
+ conv_out = self.conv_body(grid_state)
+ res_out = self.res_body(conv_out)
+
+ if (
+ self.model_config.USE_TRANSFORMER
+ and self.transformer_body is not None
+ and self.pos_encoder is not None
+ ):
+ proj_out = self.input_proj(res_out)
+ b, d, h, w = proj_out.shape
+ transformer_input = proj_out.flatten(2).permute(2, 0, 1)
+ transformer_input = self.pos_encoder(transformer_input)
+ transformer_output = self.transformer_body(transformer_input)
+ flattened_features = transformer_output.permute(1, 0, 2).flatten(1)
+ else:
+ flattened_features = res_out.view(res_out.size(0), -1)
+
+ combined_features = torch.cat([flattened_features, other_features], dim=1)
+
+ shared_out = self.shared_fc(combined_features)
+ policy_logits = self.policy_head(shared_out)
+ value_logits = self.value_head(shared_out)
+ return policy_logits, value_logits
+
+
+File: alphatriangle/nn/README.md
+
+# Neural Network Module (`alphatriangle.nn`)
+
+## Purpose and Architecture
+
+This module defines and manages the neural network used by the AlphaTriangle agent. It follows the AlphaZero paradigm, featuring a shared body and separate heads for policy and value prediction.
+
+- **Model Definition ([`model.py`](model.py)):**
+ - The `AlphaTriangleNet` class (inheriting from `torch.nn.Module`) defines the network architecture.
+  - It includes convolutional layers for processing the grid state, optionally followed by residual blocks.
+ - **Optionally**, it can include a **Transformer Encoder block** after the CNN/ResNet body to apply self-attention over the spatial features before combining them with other input features. This is controlled by `ModelConfig.USE_TRANSFORMER`.
+ - The output from the CNN/Transformer body is combined with other extracted features (e.g., shape info) and passed through shared fully connected layers.
+ - It splits into two heads:
+ - **Policy Head:** Outputs logits representing the probability distribution over all possible actions.
+ - **Value Head:** Outputs logits representing a **distribution** over possible state values (C51 Distributional RL).
+ - The architecture is configurable via [`ModelConfig`](../config/model_config.py). **It uses `trianglengin.EnvConfig` for environment dimensions.**
+- **Network Interface ([`network.py`](network.py)):**
+ - The `NeuralNetwork` class acts as a wrapper around the `AlphaTriangleNet` PyTorch model.
+ - It provides a clean interface for the rest of the system (MCTS, Trainer) to interact with the network, abstracting away PyTorch specifics.
+ - It **internally uses [`alphatriangle.features.extract_state_features`](../features/extractor.py)** to convert input `GameState` objects into tensors before feeding them to the underlying `AlphaTriangleNet` model.
+  - It handles the **distributional value head**, calculating the expected value from the predicted distribution for use by MCTS (see the sketch after this list).
+ - It **optionally compiles** the underlying model using `torch.compile()` based on `TrainConfig.COMPILE_MODEL` for potential performance improvements.
+ - **Crucially, it conforms to the `trimcts.AlphaZeroNetworkInterface` protocol**, providing `evaluate_state` and `evaluate_batch` methods with the expected signatures (returning concrete `dict` for policy).
+ - Key methods:
+ - `evaluate_state(state: GameState) -> tuple[dict[int, float], float]`: Takes a `GameState`, extracts features, performs a forward pass, and returns the policy probabilities (as a dictionary) and the **expected scalar value estimate**.
+ - `evaluate_batch(states: list[GameState]) -> list[tuple[dict[int, float], float]]`: Extracts features from a batch of `GameState` objects and performs batched evaluation for efficiency.
+ - `get_weights()`: Returns the model's state dictionary (on CPU).
+ - `set_weights(weights: Dict)`: Loads weights into the model (handles device placement).
+ - It handles device placement (`torch.device`).
+
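+For intuition, the scalar value handed to MCTS is the probability-weighted mean of the fixed C51 support `z_i = VALUE_MIN + i * delta_z`. A minimal sketch of the calculation performed by `_logits_to_expected_value` in [`network.py`](network.py) (shapes are illustrative):
+
+```python
+import torch
+import torch.nn.functional as F
+
+num_atoms, v_min, v_max = 51, -10.0, 10.0
+support = torch.linspace(v_min, v_max, num_atoms)  # fixed atom locations z_i
+
+value_logits = torch.randn(4, num_atoms)       # raw value-head output, batch of 4
+value_probs = F.softmax(value_logits, dim=1)   # probability mass per atom
+expected_value = (value_probs * support).sum(dim=1)  # one scalar per state
+```
+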
+## Exposed Interfaces
+
+- **Classes:**
+ - `AlphaTriangleNet(model_config: ModelConfig, env_config: EnvConfig)`: The PyTorch `nn.Module` defining the architecture.
+  - `NeuralNetwork(model_config: ModelConfig, env_config: EnvConfig, train_config: TrainConfig, device: torch.device)`: The wrapper class providing the primary interface (see the usage sketch after this list).
+ - `evaluate_state(state: GameState) -> tuple[dict[int, float], float]`
+ - `evaluate_batch(states: list[GameState]) -> list[tuple[dict[int, float], float]]`
+ - `get_weights() -> Dict[str, torch.Tensor]`
+ - `set_weights(weights: Dict[str, torch.Tensor])`
+ - `model`: Public attribute to access the underlying `AlphaTriangleNet` instance.
+ - `device`: Public attribute indicating the `torch.device`.
+ - `model_config`: Public attribute.
+ - `num_atoms`, `v_min`, `v_max`, `delta_z`, `support`: Attributes related to the distributional value head.
+
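+A minimal usage sketch (illustrative; assumes default configs and that a fresh `trianglengin.GameState` can be constructed from an `EnvConfig`):
+
+```python
+import torch
+from trianglengin import EnvConfig, GameState
+
+from alphatriangle.config import ModelConfig, TrainConfig
+from alphatriangle.nn import NeuralNetwork
+
+env_config = EnvConfig()
+nn_eval = NeuralNetwork(
+    ModelConfig(), env_config, TrainConfig(COMPILE_MODEL=False), torch.device("cpu")
+)
+state = GameState(env_config)  # assumed constructor; see trianglengin docs
+policy, value = nn_eval.evaluate_state(state)  # dict[int, float], float
+best_action = max(policy, key=lambda a: policy[a])  # greedy action from raw prior
+```
+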
+## Dependencies
+
+- **[`alphatriangle.config`](../config/README.md)**:
+ - `ModelConfig`: Defines the network architecture parameters.
+ - `TrainConfig`: Used by `NeuralNetwork` init (e.g., for `COMPILE_MODEL`).
+- **`trianglengin`**:
+ - `GameState`: Input type for `evaluate_state` and `evaluate_batch`.
+ - `EnvConfig`: Provides environment dimensions (grid size, action space size) needed by the model.
+- **[`alphatriangle.features`](../features/README.md)**:
+ - `extract_state_features`: Used internally by `NeuralNetwork` to process `GameState` inputs.
+- **[`alphatriangle.utils`](../utils/README.md)**:
+ - `ActionType`, `PolicyValueOutput`, `StateType`: Used in method signatures and return types.
+- **`torch`**:
+ - The core deep learning framework (`torch`, `torch.nn`, `torch.nn.functional`).
+- **`numpy`**:
+ - Used for converting state components to tensors.
+- **Standard Libraries:** `typing`, `os`, `logging`, `math`, `sys`.
+
+---
+
+**Note:** Please keep this README updated when changing the neural network architecture (`AlphaTriangleNet`), the `NeuralNetwork` interface methods, or its interaction with configuration or other modules (especially `alphatriangle.features` and `trianglengin`). Accurate documentation is crucial for maintainability.
+
+
+File: alphatriangle/nn/network.py
+# File: alphatriangle/nn/network.py
+import logging
+import sys
+from typing import TYPE_CHECKING
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+
+# Import GameState and EnvConfig from trianglengin's top level
+from trianglengin import EnvConfig, GameState
+
+# Keep alphatriangle imports
+from ..config import ModelConfig, TrainConfig
+from ..features import extract_state_features
+
+# Import the model definition
+from .model import AlphaTriangleNet
+
+if TYPE_CHECKING:
+ from ..utils.types import ActionType, StateType
+
+logger = logging.getLogger(__name__)
+
+
+class NetworkEvaluationError(Exception):
+ """Custom exception for errors during network evaluation."""
+
+ pass
+
+
+class NeuralNetwork:
+ """
+ Wrapper for the PyTorch model providing evaluation and state management.
+ Handles distributional value head (C51) by calculating expected value for MCTS.
+ Optionally compiles the model using torch.compile().
+ Conforms to trimcts.AlphaZeroNetworkInterface.
+ """
+
+ def __init__(
+ self,
+ model_config: ModelConfig,
+ env_config: EnvConfig, # Uses trianglengin.EnvConfig
+ train_config: TrainConfig,
+ device: torch.device,
+ ):
+ self.model_config = model_config
+ self.env_config = env_config # Store trianglengin.EnvConfig
+ self.train_config = train_config
+ self.device = device
+ # Pass trianglengin.EnvConfig to model
+ # Store the original, uncompiled model reference separately
+ self._orig_model = AlphaTriangleNet(model_config, env_config).to(device)
+ self.model = self._orig_model # Initially, model is the original model
+ # Calculate action_dim manually
+ self.action_dim = int(
+ env_config.NUM_SHAPE_SLOTS * env_config.ROWS * env_config.COLS
+ )
+ self.model.eval()
+
+ self.num_atoms = model_config.NUM_VALUE_ATOMS
+ self.v_min = model_config.VALUE_MIN
+ self.v_max = model_config.VALUE_MAX
+ self.delta_z = (self.v_max - self.v_min) / (self.num_atoms - 1)
+ self.support = torch.linspace(
+ self.v_min, self.v_max, self.num_atoms, device=self.device
+ )
+
+ if self.train_config.COMPILE_MODEL:
+ if sys.platform == "win32":
+ logger.warning(
+ "Model compilation requested but running on Windows. Skipping torch.compile()."
+ )
+ elif self.device.type == "mps":
+ logger.warning(
+ "Model compilation requested but device is 'mps'. Skipping torch.compile()."
+ )
+ elif hasattr(torch, "compile"):
+ try:
+ logger.info(
+ f"Attempting to compile model with torch.compile() on device '{self.device}'..."
+ )
+ # Compile the original model and store the result in self.model
+ self.model = torch.compile(self._orig_model) # type: ignore
+ logger.info(
+ f"Model compiled successfully on device '{self.device}'."
+ )
+ except Exception as e:
+ logger.warning(
+ f"torch.compile() failed on device '{self.device}': {e}. Proceeding without compilation.",
+ exc_info=False,
+ )
+ # Ensure self.model still refers to the original if compilation fails
+ self.model = self._orig_model
+ else:
+ logger.warning(
+ "torch.compile() requested but not available (requires PyTorch 2.0+). Proceeding without compilation."
+ )
+ else:
+ logger.info(
+ "Model compilation skipped (COMPILE_MODEL=False in TrainConfig)."
+ )
+
+ def _state_to_tensors(self, state: GameState) -> tuple[torch.Tensor, torch.Tensor]:
+ """Extracts features from trianglengin.GameState and converts them to tensors."""
+ state_dict: StateType = extract_state_features(state, self.model_config)
+ grid_tensor = torch.from_numpy(state_dict["grid"]).unsqueeze(0).to(self.device)
+ other_features_tensor = (
+ torch.from_numpy(state_dict["other_features"]).unsqueeze(0).to(self.device)
+ )
+ if not torch.all(torch.isfinite(grid_tensor)):
+ raise NetworkEvaluationError(
+ f"Non-finite values found in input grid_tensor for state {state}"
+ )
+ if not torch.all(torch.isfinite(other_features_tensor)):
+ raise NetworkEvaluationError(
+ f"Non-finite values found in input other_features_tensor for state {state}"
+ )
+ return grid_tensor, other_features_tensor
+
+ def _batch_states_to_tensors(
+ self, states: list[GameState]
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ """Extracts features from a batch of trianglengin.GameStates and converts to batched tensors."""
+ if not states:
+ grid_shape = (
+ 0,
+ self.model_config.GRID_INPUT_CHANNELS,
+ self.env_config.ROWS,
+ self.env_config.COLS,
+ )
+ other_shape = (0, self.model_config.OTHER_NN_INPUT_FEATURES_DIM)
+ return torch.empty(grid_shape, device=self.device), torch.empty(
+ other_shape, device=self.device
+ )
+
+ batch_grid = []
+ batch_other = []
+ for state in states:
+ state_dict: StateType = extract_state_features(state, self.model_config)
+ batch_grid.append(state_dict["grid"])
+ batch_other.append(state_dict["other_features"])
+
+ grid_tensor = torch.from_numpy(np.stack(batch_grid)).to(self.device)
+ other_features_tensor = torch.from_numpy(np.stack(batch_other)).to(self.device)
+
+ if not torch.all(torch.isfinite(grid_tensor)):
+ raise NetworkEvaluationError(
+ "Non-finite values found in batched input grid_tensor"
+ )
+ if not torch.all(torch.isfinite(other_features_tensor)):
+ raise NetworkEvaluationError(
+ "Non-finite values found in batched input other_features_tensor"
+ )
+ return grid_tensor, other_features_tensor
+
+ def _logits_to_expected_value(self, value_logits: torch.Tensor) -> torch.Tensor:
+ """Calculates the expected value from the value distribution logits."""
+ value_probs = F.softmax(value_logits, dim=1)
+ support_expanded = self.support.expand_as(value_probs)
+ expected_value = torch.sum(value_probs * support_expanded, dim=1, keepdim=True)
+ return expected_value
+
+ @torch.inference_mode()
+ def evaluate_state(self, state: GameState) -> tuple[dict[int, float], float]:
+ """
+ Evaluates a single trianglengin.GameState.
+ Returns policy mapping (dict) and EXPECTED value from the distribution.
+ Conforms to trimcts.AlphaZeroNetworkInterface.evaluate_state.
+ Raises NetworkEvaluationError on issues.
+ """
+ self.model.eval()
+ try:
+ grid_tensor, other_features_tensor = self._state_to_tensors(state)
+ policy_logits, value_logits = self.model(grid_tensor, other_features_tensor)
+
+ if not torch.all(torch.isfinite(policy_logits)):
+ raise NetworkEvaluationError(
+ f"Non-finite policy_logits detected for state {state}. Logits: {policy_logits}"
+ )
+ if not torch.all(torch.isfinite(value_logits)):
+ raise NetworkEvaluationError(
+ f"Non-finite value_logits detected for state {state}: {value_logits}"
+ )
+
+ policy_probs_tensor = F.softmax(policy_logits, dim=1)
+
+ if not torch.all(torch.isfinite(policy_probs_tensor)):
+ raise NetworkEvaluationError(
+ f"Non-finite policy probabilities AFTER softmax for state {state}. Logits were: {policy_logits}"
+ )
+
+ policy_probs = policy_probs_tensor.squeeze(0).cpu().numpy()
+ policy_probs = np.maximum(policy_probs, 0)
+ prob_sum = np.sum(policy_probs)
+ if abs(prob_sum - 1.0) > 1e-5:
+ logger.warning(
+ f"Evaluate: Policy probabilities sum to {prob_sum:.6f} (not 1.0) for state {state.current_step}. Re-normalizing."
+ )
+ if prob_sum <= 1e-9:
+ valid_actions = state.valid_actions()
+ if valid_actions:
+ num_valid = len(valid_actions)
+ policy_probs = np.zeros_like(policy_probs)
+ uniform_prob = 1.0 / num_valid
+ for action_idx in valid_actions:
+ if 0 <= action_idx < len(policy_probs):
+ policy_probs[action_idx] = uniform_prob
+ logger.warning(
+ "Policy sum was zero, returning uniform over valid actions."
+ )
+ else:
+ raise NetworkEvaluationError(
+ f"Policy probability sum is near zero ({prob_sum}) for state {state.current_step} with no valid actions. Cannot normalize."
+ )
+ else:
+ policy_probs /= prob_sum
+
+ expected_value_tensor = self._logits_to_expected_value(value_logits)
+ expected_value_scalar = expected_value_tensor.squeeze(0).item()
+
+ action_policy: dict[ActionType, float] = {
+ i: float(p) for i, p in enumerate(policy_probs)
+ }
+
+ num_non_zero = sum(1 for p in action_policy.values() if p > 1e-6)
+ logger.debug(
+ f"Evaluate Final Policy Dict (State {state.current_step}): {num_non_zero}/{self.action_dim} non-zero probs. Example: {list(action_policy.items())[:5]}"
+ )
+
+ return action_policy, expected_value_scalar
+
+ except Exception as e:
+ logger.error(
+ f"Exception during single evaluation for state {state}: {e}",
+ exc_info=True,
+ )
+ raise NetworkEvaluationError(
+ f"Evaluation failed for state {state}: {e}"
+ ) from e
+
+ @torch.inference_mode()
+ def evaluate_batch(
+ self, states: list[GameState]
+ ) -> list[tuple[dict[int, float], float]]:
+ """
+ Evaluates a batch of trianglengin.GameStates.
+ Returns a list of (policy mapping (dict), EXPECTED value).
+ Conforms to trimcts.AlphaZeroNetworkInterface.evaluate_batch.
+ Raises NetworkEvaluationError on issues.
+ """
+ if not states:
+ return []
+ self.model.eval()
+ logger.debug(f"Evaluating batch of {len(states)} states...")
+ try:
+ grid_tensor, other_features_tensor = self._batch_states_to_tensors(states)
+ policy_logits, value_logits = self.model(grid_tensor, other_features_tensor)
+
+ if not torch.all(torch.isfinite(policy_logits)):
+ raise NetworkEvaluationError(
+ f"Non-finite policy_logits detected in batch evaluation. Logits shape: {policy_logits.shape}"
+ )
+ if not torch.all(torch.isfinite(value_logits)):
+ raise NetworkEvaluationError(
+ f"Non-finite value_logits detected in batch value output. Value shape: {value_logits.shape}"
+ )
+
+ policy_probs_tensor = F.softmax(policy_logits, dim=1)
+
+ if not torch.all(torch.isfinite(policy_probs_tensor)):
+ raise NetworkEvaluationError(
+ f"Non-finite policy probabilities AFTER softmax in batch. Logits shape: {policy_logits.shape}"
+ )
+
+ policy_probs = policy_probs_tensor.cpu().numpy()
+ expected_values_tensor = self._logits_to_expected_value(value_logits)
+ expected_values = expected_values_tensor.squeeze(1).cpu().numpy()
+
+ results: list[tuple[dict[int, float], float]] = []
+ for batch_idx in range(len(states)):
+ probs_i = np.maximum(policy_probs[batch_idx], 0)
+ prob_sum_i = np.sum(probs_i)
+ if abs(prob_sum_i - 1.0) > 1e-5:
+ logger.warning(
+ f"EvaluateBatch: Policy probabilities sum to {prob_sum_i:.6f} (not 1.0) for sample {batch_idx}. Re-normalizing."
+ )
+ if prob_sum_i <= 1e-9:
+ valid_actions = states[batch_idx].valid_actions()
+ if valid_actions:
+ num_valid = len(valid_actions)
+ probs_i = np.zeros_like(probs_i)
+ uniform_prob = 1.0 / num_valid
+ for action_idx in valid_actions:
+ if 0 <= action_idx < len(probs_i):
+ probs_i[action_idx] = uniform_prob
+ logger.warning(
+ f"Batch sample {batch_idx} policy sum was zero, returning uniform."
+ )
+ else:
+ raise NetworkEvaluationError(
+ f"Policy probability sum is near zero ({prob_sum_i}) for batch sample {batch_idx} with no valid actions. Cannot normalize."
+ )
+ else:
+ probs_i /= prob_sum_i
+
+ policy_i: dict[ActionType, float] = {
+ i: float(p) for i, p in enumerate(probs_i)
+ }
+ value_i = float(expected_values[batch_idx])
+ results.append((policy_i, value_i))
+
+ except Exception as e:
+ logger.error(f"Error during batch evaluation: {e}", exc_info=True)
+ raise NetworkEvaluationError(f"Batch evaluation failed: {e}") from e
+
+ logger.debug(f" Batch evaluation finished. Returning {len(results)} results.")
+ return results
+
+ def get_weights(self) -> dict[str, torch.Tensor]:
+ """Returns the model's state dictionary, moved to CPU."""
+ # Always get weights from the original underlying model
+ return {k: v.cpu() for k, v in self._orig_model.state_dict().items()}
+
+ def set_weights(self, weights: dict[str, torch.Tensor]):
+ """Loads the model's state dictionary into the original underlying model."""
+ try:
+ weights_on_device = {k: v.to(self.device) for k, v in weights.items()}
+ # Always load weights into the original underlying model
+ self._orig_model.load_state_dict(weights_on_device)
+ # Ensure the potentially compiled model (self.model) is also in eval mode
+ self.model.eval()
+ logger.debug("NN weights set successfully into underlying model.")
+ except Exception as e:
+ logger.error(f"Error setting weights on NN instance: {e}", exc_info=True)
+ raise
+
+
+File: alphatriangle/config/model_config.py
+# File: alphatriangle/config/model_config.py
+from typing import Literal
+
+from pydantic import BaseModel, Field, model_validator
+
+
+class ModelConfig(BaseModel):
+ """
+ Configuration for the Neural Network model (Pydantic model).
+ --- TUNED FOR SMALLER CAPACITY (~3M Params Target, Laptop Feasible) ---
+ NOTE: Increasing filters/layers/dims can improve final agent strength
+ but will increase training time due to the NN evaluation bottleneck.
+ Adjust based on hardware capabilities and performance requirements.
+ """
+
+ # Input channels for the grid (e.g., 1 for occupancy, more for history/colors)
+ GRID_INPUT_CHANNELS: int = Field(default=1, gt=0)
+
+ # --- CNN Architecture Parameters ---
+ CONV_FILTERS: list[int] = Field(default=[32, 64, 128]) # Smaller CNN
+ CONV_KERNEL_SIZES: list[int | tuple[int, int]] = Field(default=[3, 3, 3])
+ CONV_STRIDES: list[int | tuple[int, int]] = Field(default=[1, 1, 1])
+ CONV_PADDING: list[int | tuple[int, int] | str] = Field(default=[1, 1, 1])
+
+ # --- Residual Block Parameters ---
+ NUM_RESIDUAL_BLOCKS: int = Field(default=2, ge=0) # Fewer blocks
+ RESIDUAL_BLOCK_FILTERS: int = Field(default=128, gt=0) # Match last conv filter
+
+ # --- Transformer Parameters (Optional) ---
+ USE_TRANSFORMER: bool = Field(default=True) # Keep Transformer enabled
+ TRANSFORMER_DIM: int = Field(default=128, gt=0) # Match Res block filters
+ TRANSFORMER_HEADS: int = Field(default=4, gt=0) # Moderate heads
+ TRANSFORMER_LAYERS: int = Field(default=2, ge=0) # Fewer layers
+ TRANSFORMER_FC_DIM: int = Field(default=256, gt=0) # Moderate feedforward dim
+
+ # --- Fully Connected Layers ---
+ FC_DIMS_SHARED: list[int] = Field(default=[128]) # Single shared layer
+
+ # --- Policy Head ---
+ POLICY_HEAD_DIMS: list[int] = Field(default=[128]) # Single policy layer
+
+ # --- Distributional Value Head Parameters ---
+ NUM_VALUE_ATOMS: int = Field(default=51, gt=1) # Standard C51 atoms
+ VALUE_MIN: float = Field(default=-10.0) # Reasonable expected value range
+ VALUE_MAX: float = Field(default=10.0) # Reasonable expected value range
+
+ # --- Value Head Dims ---
+ VALUE_HEAD_DIMS: list[int] = Field(default=[128]) # Single value layer
+
+ # --- Other Hyperparameters ---
+ ACTIVATION_FUNCTION: Literal["ReLU", "GELU", "SiLU", "Tanh", "Sigmoid"] = Field(
+ default="ReLU"
+ )
+ USE_BATCH_NORM: bool = Field(default=True)
+
+ # --- Input Feature Dimension ---
+ # Dimension of the non-grid feature vector concatenated after CNN/Transformer.
+ # Must match the output of features.extractor.get_combined_other_features.
+ OTHER_NN_INPUT_FEATURES_DIM: int = Field(default=30, gt=0) # Keep default
+
+ @model_validator(mode="after")
+ def check_conv_layers_consistency(self) -> "ModelConfig":
+ # Ensure attributes exist before checking lengths
+ if (
+ hasattr(self, "CONV_FILTERS")
+ and hasattr(self, "CONV_KERNEL_SIZES")
+ and hasattr(self, "CONV_STRIDES")
+ and hasattr(self, "CONV_PADDING")
+ ):
+ n_filters = len(self.CONV_FILTERS)
+ if not (
+ len(self.CONV_KERNEL_SIZES) == n_filters
+ and len(self.CONV_STRIDES) == n_filters
+ and len(self.CONV_PADDING) == n_filters
+ ):
+ raise ValueError(
+ "Lengths of CONV_FILTERS, CONV_KERNEL_SIZES, CONV_STRIDES, and CONV_PADDING must match."
+ )
+ return self
+
+ @model_validator(mode="after")
+ def check_residual_filter_match(self) -> "ModelConfig":
+ # Ensure attributes exist before checking
+ if (
+ hasattr(self, "NUM_RESIDUAL_BLOCKS")
+ and self.NUM_RESIDUAL_BLOCKS > 0
+ and hasattr(self, "CONV_FILTERS")
+ and self.CONV_FILTERS
+ and hasattr(self, "RESIDUAL_BLOCK_FILTERS")
+ and self.CONV_FILTERS[-1] != self.RESIDUAL_BLOCK_FILTERS
+ ):
+ # This warning is now handled by the projection layer in the model if needed
+ pass # Model handles projection if needed
+ return self
+
+ @model_validator(mode="after")
+ def check_transformer_config(self) -> "ModelConfig":
+ # Ensure attributes exist before checking
+ if hasattr(self, "USE_TRANSFORMER") and self.USE_TRANSFORMER:
+ if not hasattr(self, "TRANSFORMER_LAYERS") or self.TRANSFORMER_LAYERS < 0:
+ raise ValueError("TRANSFORMER_LAYERS cannot be negative.")
+ if self.TRANSFORMER_LAYERS > 0:
+ if not hasattr(self, "TRANSFORMER_DIM") or self.TRANSFORMER_DIM <= 0:
+ raise ValueError(
+ "TRANSFORMER_DIM must be positive if TRANSFORMER_LAYERS > 0."
+ )
+ if (
+ not hasattr(self, "TRANSFORMER_HEADS")
+ or self.TRANSFORMER_HEADS <= 0
+ ):
+ raise ValueError(
+ "TRANSFORMER_HEADS must be positive if TRANSFORMER_LAYERS > 0."
+ )
+ if self.TRANSFORMER_DIM % self.TRANSFORMER_HEADS != 0:
+ raise ValueError(
+ f"TRANSFORMER_DIM ({self.TRANSFORMER_DIM}) must be divisible by TRANSFORMER_HEADS ({self.TRANSFORMER_HEADS})."
+ )
+ if (
+ not hasattr(self, "TRANSFORMER_FC_DIM")
+ or self.TRANSFORMER_FC_DIM <= 0
+ ):
+ raise ValueError(
+ "TRANSFORMER_FC_DIM must be positive if TRANSFORMER_LAYERS > 0."
+ )
+ return self
+
+ @model_validator(mode="after")
+ def check_transformer_dim_consistency(self) -> "ModelConfig":
+ # Ensure attributes exist before checking
+ if (
+ hasattr(self, "USE_TRANSFORMER")
+ and self.USE_TRANSFORMER
+ and hasattr(self, "TRANSFORMER_LAYERS")
+ and self.TRANSFORMER_LAYERS > 0
+ and hasattr(self, "CONV_FILTERS")
+ and self.CONV_FILTERS
+ and hasattr(self, "TRANSFORMER_DIM")
+ ):
+ cnn_output_channels = (
+ self.RESIDUAL_BLOCK_FILTERS
+ if hasattr(self, "NUM_RESIDUAL_BLOCKS") and self.NUM_RESIDUAL_BLOCKS > 0
+ else self.CONV_FILTERS[-1]
+ )
+ if cnn_output_channels != self.TRANSFORMER_DIM:
+ # This is handled by an input projection layer in the model now
+ pass # Model handles projection
+ return self
+
+ @model_validator(mode="after")
+ def check_value_distribution_params(self) -> "ModelConfig":
+ if (
+ hasattr(self, "VALUE_MIN")
+ and hasattr(self, "VALUE_MAX")
+ and self.VALUE_MIN >= self.VALUE_MAX
+ ):
+ raise ValueError("VALUE_MIN must be strictly less than VALUE_MAX.")
+ return self
+
+
+ModelConfig.model_rebuild(force=True)
+
+
+File: alphatriangle/config/__init__.py
+# File: alphatriangle/config/__init__.py
+from trianglengin import EnvConfig
+
+from .app_config import APP_NAME
+from .mcts_config import AlphaTriangleMCTSConfig
+from .model_config import ModelConfig
+from .train_config import TrainConfig
+from .validation import print_config_info_and_validate
+
+__all__ = [
+ "APP_NAME",
+ "EnvConfig",
+ "ModelConfig",
+ "TrainConfig",
+ "AlphaTriangleMCTSConfig",
+ "print_config_info_and_validate",
+]
+
+
+File: alphatriangle/config/train_config.py
+# File: alphatriangle/config/train_config.py
+import logging
+import time
+from typing import Literal
+
+from pydantic import BaseModel, Field, field_validator, model_validator
+
+# Get logger instance
+logger = logging.getLogger(__name__)
+
+
+class TrainConfig(BaseModel):
+ """
+ Configuration for the training process (Pydantic model).
+ --- TUNED FOR MORE SUBSTANTIAL LEARNING RUNS ---
+ """
+
+ RUN_NAME: str = Field(
+ # More descriptive default run name
+ default_factory=lambda: f"train_{time.strftime('%Y%m%d_%H%M%S')}"
+ )
+ LOAD_CHECKPOINT_PATH: str | None = Field(default=None)
+ LOAD_BUFFER_PATH: str | None = Field(default=None)
+ AUTO_RESUME_LATEST: bool = Field(default=True) # Resume if possible
+ # --- DEVICE: Defaults to 'auto' for automatic detection (CUDA > MPS > CPU) ---
+ # This controls the device for the main Trainer process.
+ DEVICE: Literal["auto", "cuda", "cpu", "mps"] = Field(default="auto")
+ RANDOM_SEED: int = Field(default=42)
+
+ # --- Training Loop ---
+ MAX_TRAINING_STEPS: int | None = Field(default=100_000, ge=1) # Target steps
+
+ # --- Workers & Batching ---
+ NUM_SELF_PLAY_WORKERS: int = Field(
+ default=8, # Default workers, capped by cores
+ ge=1,
+ description="Suggested number of workers. Actual number may be adjusted based on detected CPU cores.",
+ )
+ # --- WORKER_DEVICE: Defaults to 'cpu' for self-play workers ---
+ # Workers run MCTS and NN eval; CPU is often sufficient and avoids GPU contention.
+ WORKER_DEVICE: Literal["auto", "cuda", "cpu", "mps"] = Field(default="cpu")
+ BATCH_SIZE: int = Field(default=256, ge=1) # Increased training batch size
+ BUFFER_CAPACITY: int = Field(default=250_000, ge=1) # Slightly larger buffer
+ MIN_BUFFER_SIZE_TO_TRAIN: int = Field(
+ default=25_000, # Increased min size (~10% of capacity)
+ ge=1,
+ )
+ WORKER_UPDATE_FREQ_STEPS: int = Field(
+ default=10,
+ ge=1, # Lowered for easier testing
+ )
+
+ # --- N-Step Returns ---
+ N_STEP_RETURNS: int = Field(default=5, ge=1) # 5-step returns
+ GAMMA: float = Field(default=0.99, gt=0, le=1.0) # Discount factor
+
+ # --- Optimizer ---
+ OPTIMIZER_TYPE: Literal["Adam", "AdamW", "SGD"] = Field(default="AdamW")
+ LEARNING_RATE: float = Field(default=2e-4, gt=0) # Keep LR, scheduler will adjust
+ WEIGHT_DECAY: float = Field(default=1e-4, ge=0)
+ GRADIENT_CLIP_VALUE: float | None = Field(default=1.0)
+
+ # --- LR Scheduler ---
+ LR_SCHEDULER_TYPE: Literal["StepLR", "CosineAnnealingLR"] | None = Field(
+ default="CosineAnnealingLR"
+ )
+ # T_MAX will be set automatically based on new MAX_TRAINING_STEPS
+ LR_SCHEDULER_T_MAX: int | None = Field(default=None)
+ LR_SCHEDULER_ETA_MIN: float = Field(default=1e-6, ge=0)
+
+ # --- Loss Weights ---
+ POLICY_LOSS_WEIGHT: float = Field(default=1.0, ge=0)
+ VALUE_LOSS_WEIGHT: float = Field(default=1.0, ge=0)
+ ENTROPY_BONUS_WEIGHT: float = Field(default=0.001, ge=0) # Small entropy bonus
+
+ # --- Checkpointing ---
+ CHECKPOINT_SAVE_FREQ_STEPS: int = Field(default=2500, ge=1) # Save every 2500 steps
+
+ # --- Prioritized Experience Replay (PER) ---
+ USE_PER: bool = Field(default=True) # Enable PER
+ PER_ALPHA: float = Field(default=0.6, ge=0)
+ PER_BETA_INITIAL: float = Field(default=0.4, ge=0, le=1.0)
+ PER_BETA_FINAL: float = Field(default=1.0, ge=0, le=1.0)
+ # Anneal steps will be set automatically based on MAX_TRAINING_STEPS
+ PER_BETA_ANNEAL_STEPS: int | None = Field(default=None)
+ PER_EPSILON: float = Field(default=1e-5, gt=0)
+
+ # --- Model Compilation ---
+ COMPILE_MODEL: bool = Field(
+ default=True, # Keep compilation enabled
+ description=(
+ "Enable torch.compile() for potential speedup (Trainer only). Requires PyTorch 2.0+. "
+ "May have initial overhead or compatibility issues on some setups/GPUs "
+ "(especially non-CUDA backends like MPS). Set to False if encountering problems. "
+ "The application will attempt compilation and fall back gracefully if it fails."
+ ),
+ )
+
+ # --- Profiling ---
+ PROFILE_WORKERS: bool = Field(
+ default=False,
+ description="Enable cProfile for worker 0 to generate .prof files.",
+ )
+
+ @model_validator(mode="after")
+ def check_buffer_sizes(self) -> "TrainConfig":
+ # Ensure attributes exist before comparing
+ if (
+ hasattr(self, "MIN_BUFFER_SIZE_TO_TRAIN")
+ and hasattr(self, "BUFFER_CAPACITY")
+ and self.MIN_BUFFER_SIZE_TO_TRAIN > self.BUFFER_CAPACITY
+ ):
+ raise ValueError(
+ "MIN_BUFFER_SIZE_TO_TRAIN cannot be greater than BUFFER_CAPACITY."
+ )
+ if (
+ hasattr(self, "BATCH_SIZE")
+ and hasattr(self, "BUFFER_CAPACITY")
+ and self.BATCH_SIZE > self.BUFFER_CAPACITY
+ ):
+ raise ValueError("BATCH_SIZE cannot be greater than BUFFER_CAPACITY.")
+ if (
+ hasattr(self, "BATCH_SIZE")
+ and hasattr(self, "MIN_BUFFER_SIZE_TO_TRAIN")
+ and self.BATCH_SIZE > self.MIN_BUFFER_SIZE_TO_TRAIN
+ ):
+ # Allow batch size to be larger than min buffer size (will just wait longer)
+ pass
+ return self
+
+ @model_validator(mode="after")
+ def set_scheduler_t_max(self) -> "TrainConfig":
+ # Ensure attributes exist before checking
+ if (
+ hasattr(self, "LR_SCHEDULER_TYPE")
+ and self.LR_SCHEDULER_TYPE == "CosineAnnealingLR"
+ and hasattr(self, "LR_SCHEDULER_T_MAX")
+ and self.LR_SCHEDULER_T_MAX is None # Only set if not manually specified
+ ):
+ if (
+ hasattr(self, "MAX_TRAINING_STEPS")
+ and self.MAX_TRAINING_STEPS is not None
+ ):
+ # Assign to self.LR_SCHEDULER_T_MAX only if MAX_TRAINING_STEPS is valid
+ if self.MAX_TRAINING_STEPS >= 1:
+ self.LR_SCHEDULER_T_MAX = self.MAX_TRAINING_STEPS
+ logger.info(
+ f"Set LR_SCHEDULER_T_MAX to MAX_TRAINING_STEPS ({self.MAX_TRAINING_STEPS})"
+ )
+ else:
+ # Handle invalid MAX_TRAINING_STEPS case if necessary
+ self.LR_SCHEDULER_T_MAX = 100_000 # Fallback (matches new default)
+ logger.warning(
+ f"Warning: MAX_TRAINING_STEPS is invalid ({self.MAX_TRAINING_STEPS}), setting LR_SCHEDULER_T_MAX to default {self.LR_SCHEDULER_T_MAX}"
+ )
+ else:
+ self.LR_SCHEDULER_T_MAX = 100_000 # Fallback (matches new default)
+ logger.warning(
+ f"Warning: MAX_TRAINING_STEPS is None, setting LR_SCHEDULER_T_MAX to default {self.LR_SCHEDULER_T_MAX}"
+ )
+
+ if (
+ hasattr(self, "LR_SCHEDULER_T_MAX")
+ and self.LR_SCHEDULER_T_MAX is not None
+ and self.LR_SCHEDULER_T_MAX <= 0
+ ):
+ raise ValueError("LR_SCHEDULER_T_MAX must be positive if set.")
+ return self
+
+ @model_validator(mode="after")
+ def set_per_beta_anneal_steps(self) -> "TrainConfig":
+ # Ensure attributes exist before checking
+ if (
+ hasattr(self, "USE_PER")
+ and self.USE_PER
+ and hasattr(self, "PER_BETA_ANNEAL_STEPS")
+ and self.PER_BETA_ANNEAL_STEPS is None # Only set if not manually specified
+ ):
+ if (
+ hasattr(self, "MAX_TRAINING_STEPS")
+ and self.MAX_TRAINING_STEPS is not None
+ ):
+ # Assign to self.PER_BETA_ANNEAL_STEPS only if MAX_TRAINING_STEPS is valid
+ if self.MAX_TRAINING_STEPS >= 1:
+ self.PER_BETA_ANNEAL_STEPS = self.MAX_TRAINING_STEPS
+ logger.info(
+ f"Set PER_BETA_ANNEAL_STEPS to MAX_TRAINING_STEPS ({self.MAX_TRAINING_STEPS})"
+ )
+ else:
+ # Handle invalid MAX_TRAINING_STEPS case if necessary
+ self.PER_BETA_ANNEAL_STEPS = (
+ 100_000 # Fallback (matches new default)
+ )
+ logger.warning(
+ f"Warning: MAX_TRAINING_STEPS is invalid ({self.MAX_TRAINING_STEPS}), setting PER_BETA_ANNEAL_STEPS to default {self.PER_BETA_ANNEAL_STEPS}"
+ )
+ else:
+ self.PER_BETA_ANNEAL_STEPS = 100_000 # Fallback (matches new default)
+ logger.warning(
+ f"Warning: MAX_TRAINING_STEPS is None, setting PER_BETA_ANNEAL_STEPS to default {self.PER_BETA_ANNEAL_STEPS}"
+ )
+
+ if (
+ hasattr(self, "PER_BETA_ANNEAL_STEPS")
+ and self.PER_BETA_ANNEAL_STEPS is not None
+ and self.PER_BETA_ANNEAL_STEPS <= 0
+ ):
+ raise ValueError("PER_BETA_ANNEAL_STEPS must be positive if set.")
+ return self
+
+ @field_validator("GRADIENT_CLIP_VALUE")
+ @classmethod
+ def check_gradient_clip(cls, v: float | None) -> float | None:
+ if v is not None and v <= 0:
+ raise ValueError("GRADIENT_CLIP_VALUE must be positive if set.")
+ return v
+
+ @field_validator("PER_BETA_FINAL")
+ @classmethod
+ def check_per_beta_final(cls, v: float, info) -> float:
+        # In Pydantic v2, previously validated fields are available via info.data
+        # (ValidationInfo has no 'values' attribute, so the old fallback would fail).
+        initial_beta = info.data.get("PER_BETA_INITIAL")
+ if initial_beta is not None and v < initial_beta:
+ raise ValueError("PER_BETA_FINAL cannot be less than PER_BETA_INITIAL")
+ return v
+
+
+# Ensure model is rebuilt after changes
+TrainConfig.model_rebuild(force=True)
+
+
+File: alphatriangle/config/README.md
+# Configuration Module (`alphatriangle.config`)
+
+## Purpose and Architecture
+
+This module centralizes configuration parameters for the AlphaTriangle agent itself, *excluding* statistics logging and data persistence which are now handled by the `trieye` library. It uses separate **Pydantic models** for different aspects of the agent (model, training loop, MCTS) to promote modularity, clarity, and automatic validation.
+
+**Core environment configuration (`EnvConfig`) is imported directly from the `trianglengin` library.**
+**Statistics and Persistence configuration (`StatsConfig`, `PersistenceConfig`) are defined and managed within the `trieye` library via `TrieyeConfig`.**
+
+- **Modularity:** Separating configurations makes it easier to manage parameters for different components.
+- **Type Safety & Validation:** Using Pydantic models (`BaseModel`) provides strong type hinting, automatic parsing, and validation of configuration values based on defined types and constraints (e.g., `Field(gt=0)`).
+- **Validation Script:** The [`validation.py`](validation.py) script instantiates the AlphaTriangle-specific configuration models (including importing and validating `trianglengin.EnvConfig`), triggering Pydantic's validation, and prints a summary. **Note:** It does *not* validate `TrieyeConfig` directly; `trieye` handles its own validation upon actor initialization.
+- **Dynamic Defaults:** Some configurations, like `RUN_NAME` in `TrainConfig`, use `default_factory` for dynamic defaults (e.g., timestamp). This default is often overridden by the `TrieyeConfig` setting.
+- **Tuned Defaults:** The default values in `TrainConfig` and `ModelConfig` are tuned for substantial learning runs. `AlphaTriangleMCTSConfig` defaults to 128 simulations.
+
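+For example, inconsistent overrides are rejected at construction time; a small sketch using the models defined in this module:
+
+```python
+from pydantic import ValidationError
+
+from alphatriangle.config import TrainConfig
+
+# Valid: overrides are type-checked and cross-validated against each other.
+cfg = TrainConfig(BATCH_SIZE=128, BUFFER_CAPACITY=100_000)
+
+# Invalid: BATCH_SIZE exceeds BUFFER_CAPACITY, so check_buffer_sizes raises.
+try:
+    TrainConfig(BATCH_SIZE=512, BUFFER_CAPACITY=256, MIN_BUFFER_SIZE_TO_TRAIN=256)
+except ValidationError as e:
+    print(e)
+```
+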
+## Exposed Interfaces
+
+- **Pydantic Models:**
+ - `EnvConfig` (Imported from `trianglengin`): Environment parameters (grid size, shapes, rewards).
+ - [`ModelConfig`](model_config.py): Neural network architecture parameters.
+ - [`TrainConfig`](train_config.py): Training loop hyperparameters (batch size, learning rate, workers, PER settings, etc.).
+  - [`AlphaTriangleMCTSConfig`](mcts_config.py): MCTS parameters (simulations, PUCT exploration, Dirichlet noise, MCTS batch size); see the conversion sketch after this list.
+- **Constants:**
+ - [`APP_NAME`](app_config.py): The name of the application (used by `trieye` for namespacing).
+- **Functions:**
+ - `print_config_info_and_validate(mcts_config_instance: AlphaTriangleMCTSConfig | None)`: Validates and prints a summary of AlphaTriangle-specific configurations.
+
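+A short sketch of the bridge to `trimcts`: `AlphaTriangleMCTSConfig` mirrors `trimcts.SearchConfiguration` and converts via `to_trimcts_config()` (as defined in [`mcts_config.py`](mcts_config.py)):
+
+```python
+from alphatriangle.config import AlphaTriangleMCTSConfig
+
+mcts_config = AlphaTriangleMCTSConfig(max_simulations=128, cpuct=1.25)
+search_config = mcts_config.to_trimcts_config()  # trimcts.SearchConfiguration
+assert search_config.max_simulations == 128
+```
+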
+## Dependencies
+
+This module primarily defines configurations and relies heavily on **Pydantic**.
+
+- **`pydantic`**: The core library used for defining models and validation.
+- **`trianglengin`**: Imports `EnvConfig`.
+- **`trimcts`**: Used by `AlphaTriangleMCTSConfig` (implicitly, as it mirrors the structure).
+- **Standard Libraries:** `typing`, `time`, `os`, `logging`, `pathlib`.
+
+---
+
+**Note:** Please keep this README updated when adding, removing, or significantly modifying configuration parameters or the structure of the Pydantic models within this module.
+
+File: alphatriangle/config/mcts_config.py
+# File: alphatriangle/config/mcts_config.py
+"""
+Configuration for MCTS parameters specific to AlphaTriangle,
+mirroring trimcts.SearchConfiguration for easy control.
+"""
+
+from pydantic import BaseModel, ConfigDict, Field
+from trimcts import SearchConfiguration # Import base config for reference
+
+# Default simulation count is kept low for faster testing/profiling
+DEFAULT_MAX_SIMULATIONS = 64
+DEFAULT_MAX_DEPTH = 8
+DEFAULT_CPUCT = 1.5
+DEFAULT_MCTS_BATCH_SIZE = 32 # Default batch size for network evals within MCTS
+
+
+class AlphaTriangleMCTSConfig(BaseModel):
+ """MCTS Search Configuration managed within AlphaTriangle."""
+
+ # Core Search Parameters
+ max_simulations: int = Field(
+ default=DEFAULT_MAX_SIMULATIONS,
+ description="Maximum number of MCTS simulations per move.",
+ gt=0,
+ )
+ max_depth: int = Field(
+ default=DEFAULT_MAX_DEPTH,
+ description="Maximum depth for tree traversal during simulation.",
+ gt=0,
+ )
+
+ # UCT Parameters (AlphaZero style)
+ cpuct: float = Field(
+ default=DEFAULT_CPUCT,
+ description="Constant determining the level of exploration (PUCT).",
+ )
+
+ # Dirichlet Noise (for root node exploration)
+ dirichlet_alpha: float = Field(
+ default=0.3, description="Alpha parameter for Dirichlet noise.", ge=0
+ )
+ dirichlet_epsilon: float = Field(
+ default=0.25,
+ description="Weight of Dirichlet noise in root prior probabilities.",
+ ge=0,
+ le=1.0,
+ )
+
+ # Discount Factor (Primarily for MuZero/Value Propagation)
+ discount: float = Field(
+ default=1.0,
+ description="Discount factor (gamma) for future rewards/values.",
+ ge=0.0,
+ le=1.0,
+ )
+
+ # Batching for Network Evaluations within MCTS
+ mcts_batch_size: int = Field(
+ default=DEFAULT_MCTS_BATCH_SIZE,
+ description="Number of leaf nodes to collect in C++ MCTS before calling network evaluate_batch.",
+ gt=0,
+ )
+
+ # Use ConfigDict for Pydantic V2
+ model_config = ConfigDict(validate_assignment=True)
+
+ def to_trimcts_config(self) -> SearchConfiguration:
+ """Converts this config to the trimcts.SearchConfiguration."""
+ return SearchConfiguration(
+ max_simulations=self.max_simulations,
+ max_depth=self.max_depth,
+ cpuct=self.cpuct,
+ dirichlet_alpha=self.dirichlet_alpha,
+ dirichlet_epsilon=self.dirichlet_epsilon,
+ discount=self.discount,
+ mcts_batch_size=self.mcts_batch_size, # Pass batch size
+ )
+
+
+# Ensure model is rebuilt after changes
+AlphaTriangleMCTSConfig.model_rebuild(force=True)
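+
+
+# Illustrative usage sketch (not part of the module's API): convert this
+# config into the trimcts search configuration consumed by the MCTS core.
+if __name__ == "__main__":
+    example_cfg = AlphaTriangleMCTSConfig(max_simulations=128)
+    print(example_cfg.to_trimcts_config())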
+
+
+File: alphatriangle/config/app_config.py
+APP_NAME: str = "AlphaTriangle"
+
+
+File: alphatriangle/config/validation.py
+# File: alphatriangle/config/validation.py
+import logging
+from typing import Any
+
+from pydantic import BaseModel, ValidationError
+
+# Import EnvConfig from trianglengin's top level
+from trianglengin import EnvConfig
+
+# Import AlphaTriangle configs
+from .mcts_config import AlphaTriangleMCTSConfig
+from .model_config import ModelConfig
+from .train_config import TrainConfig
+
+# Removed PersistenceConfig, StatsConfig imports
+
+logger = logging.getLogger(__name__)
+
+
+def print_config_info_and_validate(
+ mcts_config_instance: AlphaTriangleMCTSConfig | None,
+):
+ """
+ Prints configuration summary and performs validation using Pydantic
+ for AlphaTriangle-specific configurations.
+ Note: TrieyeConfig validation happens within the Trieye library.
+ """
+ print("-" * 40)
+ print("AlphaTriangle Configuration Validation & Summary")
+ print("-" * 40)
+ all_valid = True
+ configs_validated: dict[str, Any] = {}
+
+ # Only validate AlphaTriangle-specific configs here
+ config_classes: dict[str, type[BaseModel]] = {
+ "Environment": EnvConfig,
+ "Model": ModelConfig,
+ "Training": TrainConfig,
+ "MCTS": AlphaTriangleMCTSConfig,
+ }
+
+ for name, ConfigClass in config_classes.items():
+ instance: BaseModel | None = None
+ try:
+ if name == "MCTS":
+ if mcts_config_instance is not None:
+ # Validate the provided instance against the class definition
+ instance = AlphaTriangleMCTSConfig.model_validate(
+ mcts_config_instance.model_dump()
+ )
+ print(f"[{name}] - Instance provided & validated OK")
+ else:
+ # Instantiate default if no instance provided
+ instance = ConfigClass()
+ print(f"[{name}] - Validated OK (Instantiated Default)")
+ else:
+ # Instantiate default for other configs
+ instance = ConfigClass()
+ print(f"[{name}] - Validated OK")
+ configs_validated[name] = instance
+ except ValidationError as e:
+ logger.error(f"Validation failed for {name} Config:")
+ logger.error(e)
+ all_valid = False
+ configs_validated[name] = None
+ except Exception as e:
+ logger.error(
+ f"Unexpected error instantiating/validating {name} Config: {e}"
+ )
+ all_valid = False
+ configs_validated[name] = None
+
+ print("-" * 40)
+ print("Configuration Values:")
+ print("-" * 40)
+
+ for name, instance in configs_validated.items():
+ print(f"--- {name} Config ---")
+ if instance:
+ # Use model_dump for Pydantic v2
+ dump_data = instance.model_dump()
+ for field_name, value in dump_data.items():
+ # Simple representation for long lists/dicts
+ if isinstance(value, list) and len(value) > 10:
+ print(f" {field_name}: [List with {len(value)} items]")
+ elif isinstance(value, dict) and len(value) > 10:
+ print(f" {field_name}: {{Dict with {len(value)} keys}}")
+ else:
+ print(f" {field_name}: {value}")
+ else:
+            print("  <validation failed>")
+ print("-" * 20)
+
+ print("-" * 40)
+ if not all_valid:
+ logger.critical("Configuration validation failed. Please check errors above.")
+ raise ValueError("Invalid configuration settings.")
+ else:
+ logger.info("AlphaTriangle configurations validated successfully.")
+ print("-" * 40)
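+
+
+# Illustrative sketch (an assumption, not existing repo behavior): run the
+# validation standalone; passing None makes it instantiate a default
+# AlphaTriangleMCTSConfig.
+if __name__ == "__main__":
+    print_config_info_and_validate(None)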
+
+
+File: alphatriangle/training/loop_helpers.py
+import logging
+import time
+from collections.abc import Callable
+from typing import TYPE_CHECKING, Any
+
+import numpy as np
+
+from ..utils import format_eta
+from ..utils.types import Experience
+
+if TYPE_CHECKING:
+ from .components import TrainingComponents
+
+logger = logging.getLogger(__name__)
+
+# Removed constants related to old logging intervals and keys
+
+
+class LoopHelpers:
+ """Provides helper functions for the TrainingLoop (simplified)."""
+
+ def __init__(
+ self,
+ components: "TrainingComponents",
+ get_loop_state_func: Callable[[], dict[str, Any]],
+ ):
+ self.components = components
+ self.get_loop_state = get_loop_state_func
+ self.train_config = components.train_config
+ # No longer needs direct access to stats_collector or trainer for logging
+
+ # Keep state for ETA calculation if needed, or rely on StatsProcessor rates
+ # self.last_rate_calc_time = time.time()
+ # self.last_rate_calc_step = 0
+
+ logger.info("LoopHelpers initialized (Simplified).")
+
+ # Removed set_tensorboard_writer
+ # Removed reset_rate_counters
+ # Removed _fetch_latest_stats
+ # Removed calculate_and_log_rates
+ # Removed log_training_results_async
+ # Removed log_weight_update_event
+ # Removed log_additional_stats
+
+ def log_progress_eta(self):
+ """Logs progress and ETA to console."""
+ loop_state = self.get_loop_state()
+ global_step = loop_state["global_step"]
+
+ # Log less frequently to console
+ if global_step == 0 or global_step % 100 != 0:
+ return
+
+ elapsed_time = time.time() - loop_state["start_time"]
+ steps_since_start = global_step
+
+ # --- Get rate from StatsProcessor (requires fetching or passing) ---
+ # This part is tricky without direct access. For simplicity,
+ # we might just use the average rate since start for the console log.
+ # A more robust way would involve fetching the latest rate from the stats actor,
+ # but that adds complexity back. Let's use the simple average for now.
+ steps_per_sec = 0.0
+ if elapsed_time > 1 and steps_since_start > 0:
+ steps_per_sec = steps_since_start / elapsed_time
+
+ target_steps = self.train_config.MAX_TRAINING_STEPS
+ target_steps_str = f"{target_steps:,}" if target_steps else "Infinite"
+ progress_str = f"Step {global_step:,}/{target_steps_str}"
+
+ eta_str = "--"
+ if target_steps and steps_per_sec > 1e-6:
+ remaining_steps = target_steps - global_step
+ if remaining_steps > 0:
+ eta_seconds = remaining_steps / steps_per_sec
+ eta_str = format_eta(eta_seconds)
+
+ buffer_fill_perc = (
+ (loop_state["buffer_size"] / loop_state["buffer_capacity"]) * 100
+ if loop_state["buffer_capacity"] > 0
+ else 0.0
+ )
+ total_sims = loop_state["total_simulations_run"]
+ total_sims_str = (
+ f"{total_sims / 1e6:.2f}M"
+ if total_sims >= 1e6
+ else (f"{total_sims / 1e3:.1f}k" if total_sims >= 1000 else str(total_sims))
+ )
+ num_pending_tasks = loop_state["num_pending_tasks"]
+
+ # Simplified console log - detailed rates come from StatsProcessor via MLflow/TB
+ logger.info(
+ f"Progress: {progress_str}, Episodes: {loop_state['episodes_played']:,}, Total Sims: {total_sims_str}, "
+ f"Buffer: {loop_state['buffer_size']:,}/{loop_state['buffer_capacity']:,} ({buffer_fill_perc:.1f}%), "
+ f"Pending Tasks: {num_pending_tasks}, Avg Speed: {steps_per_sec:.2f} steps/sec, ETA: {eta_str}"
+ )
+
+ def validate_experiences(
+ self, experiences: list[Experience]
+ ) -> tuple[list[Experience], int]:
+ """Validates the structure and content of experiences."""
+ # This function remains largely the same
+ valid_experiences = []
+ invalid_count = 0
+ for i, exp in enumerate(experiences):
+ is_valid = False
+ try:
+ if isinstance(exp, tuple) and len(exp) == 3:
+ state_type, policy_map, value = exp
+ if (
+ isinstance(state_type, dict)
+ and "grid" in state_type
+ and "other_features" in state_type
+ and isinstance(state_type["grid"], np.ndarray)
+ and isinstance(state_type["other_features"], np.ndarray)
+ and isinstance(policy_map, dict)
+ and isinstance(value, float | int)
+ ):
+ if np.all(np.isfinite(state_type["grid"])) and np.all(
+ np.isfinite(state_type["other_features"])
+ ):
+ is_valid = True
+ else:
+ logger.warning(
+ f"Experience {i} contains non-finite features (grid_finite={np.all(np.isfinite(state_type['grid']))}, other_finite={np.all(np.isfinite(state_type['other_features']))})."
+ )
+ else:
+ logger.warning(
+ f"Experience {i} has incorrect types: state={type(state_type)}, policy={type(policy_map)}, value={type(value)}"
+ )
+ else:
+ logger.warning(
+ f"Experience {i} is not a tuple of length 3: type={type(exp)}, len={len(exp) if isinstance(exp, tuple) else 'N/A'}"
+ )
+ except Exception as e:
+ logger.error(
+ f"Unexpected error validating experience {i}: {e}", exc_info=True
+ )
+ is_valid = False
+ if is_valid:
+ valid_experiences.append(exp)
+ else:
+ invalid_count += 1
+ if invalid_count > 0:
+ logger.warning(f"Filtered out {invalid_count} invalid experiences.")
+ return valid_experiences, invalid_count
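+
+
+# Illustrative sketch of a valid Experience tuple, matching the checks in
+# validate_experiences above (the array shapes are assumptions):
+#
+#     exp: Experience = (
+#         {
+#             "grid": np.zeros((1, 8, 15), dtype=np.float32),
+#             "other_features": np.zeros(10, dtype=np.float32),
+#         },
+#         {0: 0.6, 1: 0.4},  # policy map: action index -> probability
+#         0.0,  # value target (float)
+#     )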
+
+
+File: alphatriangle/training/runner.py
+# File: alphatriangle/training/runner.py
+import logging
+import sys
+import time
+import traceback
+
+import ray
+import torch
+from trieye import ( # Import from Trieye
+ LoadedTrainingState,
+ TrieyeConfig,
+)
+
+from ..config import TrainConfig
+from ..logging_config import setup_logging # Import centralized setup
+from ..utils.sumtree import SumTree
+from .components import TrainingComponents
+from .logging_utils import log_configs_to_mlflow # Keep MLflow helper for AT configs
+from .loop import TrainingLoop
+from .setup import count_parameters, setup_training_components
+
+logger = logging.getLogger(__name__)
+
+
+# Removed _initialize_mlflow
+
+
+def _load_and_apply_initial_state(
+ components: TrainingComponents,
+) -> tuple[TrainingLoop, LoadedTrainingState]:
+ """
+ Loads initial state using TrieyeActor and applies it to components.
+ Returns an initialized TrainingLoop and the loaded state object.
+ """
+ logger.info("Requesting initial state load from TrieyeActor...")
+ loaded_state: LoadedTrainingState = ray.get(
+ components.trieye_actor.load_initial_state.remote()
+ )
+ training_loop = TrainingLoop(components)
+ initial_step = 0
+
+ if loaded_state.checkpoint_data:
+ cp_data = loaded_state.checkpoint_data
+ logger.info(
+ f"Applying loaded checkpoint data (Run: {cp_data.run_name}, Step: {cp_data.global_step})"
+ )
+
+ if cp_data.model_state_dict:
+ components.nn.set_weights(cp_data.model_state_dict)
+ if cp_data.optimizer_state_dict:
+ try:
+ components.trainer.optimizer.load_state_dict(
+ cp_data.optimizer_state_dict
+ )
+ # Move optimizer state to the correct device
+ for state in components.trainer.optimizer.state.values():
+ for k, v in state.items():
+ if isinstance(v, torch.Tensor):
+ state[k] = v.to(components.nn.device)
+ logger.info("Optimizer state loaded and moved to device.")
+ except Exception as opt_load_err:
+ logger.error(
+ f"Could not load optimizer state: {opt_load_err}. Optimizer might reset."
+ )
+ # Actor state is restored internally by TrieyeActor's load_initial_state
+
+ training_loop.set_initial_state(
+ cp_data.global_step,
+ cp_data.episodes_played,
+ cp_data.total_simulations_run,
+ )
+ initial_step = cp_data.global_step
+ else:
+ logger.info("No checkpoint data loaded. Starting fresh.")
+ training_loop.set_initial_state(0, 0, 0)
+
+ loaded_buffer_size = 0
+ if loaded_state.buffer_data:
+ buffer_list = loaded_state.buffer_data.buffer_list
+ if components.train_config.USE_PER:
+ logger.info("Rebuilding PER SumTree from loaded buffer data...")
+            components.buffer.tree = SumTree(components.buffer.capacity)
+
+ max_p = 1.0
+ for exp in buffer_list:
+ # Basic validation before adding to tree
+ if isinstance(exp, tuple) and len(exp) == 3:
+ components.buffer.tree.add(max_p, exp)
+ else:
+ logger.warning(
+ f"Skipping invalid item from loaded buffer: {type(exp)}"
+ )
+ loaded_buffer_size = len(components.buffer)
+ logger.info(f"PER buffer loaded. Size: {loaded_buffer_size}")
+ else:
+ from collections import deque
+
+ # Basic validation before adding to deque
+ valid_buffer_list = [
+ exp for exp in buffer_list if isinstance(exp, tuple) and len(exp) == 3
+ ]
+ if len(valid_buffer_list) < len(buffer_list):
+ logger.warning(
+ f"Filtered {len(buffer_list) - len(valid_buffer_list)} invalid items from loaded buffer."
+ )
+ components.buffer.buffer = deque(
+ valid_buffer_list, maxlen=components.buffer.capacity
+ )
+ loaded_buffer_size = len(components.buffer)
+ logger.info(f"Uniform buffer loaded. Size: {loaded_buffer_size}")
+ else:
+ logger.info("No buffer data loaded.")
+
+ components.nn.model.train()
+ logger.info(
+ f"Initial state loaded and applied. Starting step: {initial_step}, Buffer size: {loaded_buffer_size}"
+ )
+ return training_loop, loaded_state
+
+
+def _save_final_state(training_loop: TrainingLoop):
+ """Triggers the Trieye actor to save the final training state."""
+ if not training_loop or not training_loop.trieye_actor:
+ logger.warning("Cannot save final state: TrainingLoop or TrieyeActor missing.")
+ return
+ components = training_loop.components
+ logger.info("Requesting final training state save via TrieyeActor...")
+ try:
+ # Prepare data using the local serializer from components
+ nn_state = components.nn.get_weights()
+ opt_state = components.serializer.prepare_optimizer_state(
+ components.trainer.optimizer.state_dict() # Pass state_dict
+ )
+ buffer_data = components.serializer.prepare_buffer_data(
+ list(components.buffer.buffer)
+ if not components.buffer.use_per
+ else [
+ components.buffer.tree.data[i]
+ for i in range(components.buffer.tree.n_entries)
+ if components.buffer.tree.data[i] != 0
+ ] # Pass buffer content list
+ )
+
+ # Fire-and-forget call to the actor
+ components.trieye_actor.save_training_state.remote(
+ nn_state_dict=nn_state,
+ optimizer_state_dict=opt_state,
+ buffer_content=buffer_data.buffer_list if buffer_data else [],
+ global_step=training_loop.global_step,
+ episodes_played=training_loop.episodes_played,
+ total_simulations_run=training_loop.total_simulations_run,
+ is_best=False, # Final save is not 'best'
+ save_buffer=components.trieye_config.persistence.SAVE_BUFFER,
+ model_config_dict=components.model_config.model_dump(),
+ env_config_dict=components.env_config.model_dump(),
+ )
+ # Allow some time for the async save call to potentially start
+ time.sleep(0.5)
+ except Exception as e_save:
+ logger.error(f"Failed to trigger final state save: {e_save}", exc_info=True)
+
+
+def run_training(
+ log_level_str: str,
+ train_config_override: TrainConfig,
+ trieye_config_override: TrieyeConfig, # Use TrieyeConfig
+ profile: bool,
+) -> int:
+ """Runs the training pipeline (headless)."""
+ training_loop: TrainingLoop | None = None
+ components: TrainingComponents | None = None
+ exit_code = 1
+ # log_file_path: Path | None = None # Path determined by Trieye
+ file_handler: logging.FileHandler | None = None
+ ray_initialized_by_setup = False
+ trieye_actor_handle: ray.actor.ActorHandle | None = None
+
+ try:
+ # --- Setup Console Logging ---
+ # File logging is handled by TrieyeActor
+ file_handler = setup_logging(log_level_str, log_file=None) # Pass log_file=None
+ logger.info(f"Console logging level set to: {log_level_str.upper()}")
+
+ # --- Setup Components (includes Ray init, TrieyeActor init) ---
+ components, ray_initialized_by_setup = setup_training_components(
+ train_config_override,
+ trieye_config_override, # Pass TrieyeConfig
+ profile,
+ )
+ if not components or not components.trieye_actor:
+ raise RuntimeError(
+ "Failed to initialize training components or TrieyeActor."
+ )
+
+ trieye_actor_handle = components.trieye_actor # Store handle for cleanup
+
+ # --- Log Configs and Params to MLflow (if run started by Trieye) ---
+ mlflow_run_id = ray.get(trieye_actor_handle.get_mlflow_run_id.remote())
+ if mlflow_run_id:
+ log_configs_to_mlflow(components) # Log AlphaTriangle configs
+ total_params, trainable_params = count_parameters(components.nn.model)
+ logger.info(
+ f"Model Parameters: Total={total_params:,}, Trainable={trainable_params:,}"
+ )
+ # Parameter logging removed as TrieyeActor doesn't expose log_param
+ else:
+ logger.warning(
+ "MLflow run ID not available from TrieyeActor, skipping MLflow param logging."
+ )
+
+ # --- Load State & Initialize Loop ---
+ training_loop, _ = _load_and_apply_initial_state(components)
+
+ # --- Run Training Loop ---
+ training_loop.initialize_workers()
+ training_loop.run()
+
+ # --- Determine Exit Code ---
+ if training_loop.training_complete:
+ exit_code = 0
+ logger.info(
+ f"Training run '[bold cyan]{components.trieye_config.run_name}[/]' completed successfully.",
+ extra={"markup": True},
+ )
+ elif training_loop.training_exception:
+ exit_code = 1
+ logger.error(
+ f"Training run '[bold cyan]{components.trieye_config.run_name}[/]' failed due to exception: {training_loop.training_exception}",
+ extra={"markup": True},
+ )
+ else:
+ exit_code = 0
+ logger.warning(
+ f"Training run '[bold cyan]{components.trieye_config.run_name}[/]' stopped before completion.",
+ extra={"markup": True},
+ )
+
+ except Exception as e:
+ logger.critical(
+ f"An unhandled error occurred during training setup or execution: {e}"
+ )
+ traceback.print_exc()
+ # Attempt to log failure status via actor if possible
+ if trieye_actor_handle:
+ try:
+ # Use log_event to log status if log_param is unavailable
+ from trieye.schemas import RawMetricEvent
+
+ ray.get(
+ trieye_actor_handle.log_event.remote(
+ RawMetricEvent(
+ name="training_status_code", value=-1, global_step=0
+ )
+ )
+ ) # Example event
+ ray.get(
+ trieye_actor_handle.log_event.remote(
+ RawMetricEvent(
+ name="error_message_flag",
+ value=1,
+ global_step=0,
+ context={"error": str(e)},
+ )
+ )
+ )
+ except Exception as log_err:
+ logger.error(
+ f"Failed to log setup error status via TrieyeActor log_event: {log_err}"
+ )
+ exit_code = 1
+
+ finally:
+ # --- Cleanup ---
+ if training_loop and components:
+ _save_final_state(training_loop) # Trigger final save via actor
+ training_loop.cleanup_actors() # Cleans up workers
+
+ # Shutdown TrieyeActor gracefully
+ if trieye_actor_handle:
+ logger.info("Shutting down TrieyeActor...")
+ try:
+ # Force final processing and shutdown
+ final_step = training_loop.global_step if training_loop else 0
+ ray.get(
+ trieye_actor_handle.force_process_and_log.remote(final_step),
+ timeout=15,
+ )
+ ray.get(trieye_actor_handle.shutdown.remote(), timeout=15)
+ logger.info("TrieyeActor shutdown complete.")
+ except Exception as actor_shutdown_err:
+ logger.error(
+ f"Error during TrieyeActor shutdown: {actor_shutdown_err}. Attempting kill."
+ )
+ try:
+ ray.kill(trieye_actor_handle)
+ except Exception as kill_err:
+ logger.error(f"Failed to kill TrieyeActor: {kill_err}")
+
+ if ray_initialized_by_setup and ray.is_initialized():
+ try:
+ ray.shutdown()
+ logger.info("Ray shut down by runner.")
+ except Exception as e:
+ logger.error(f"Error shutting down Ray: {e}", exc_info=True)
+
+ # Close console logger handler if it was created
+ if file_handler:
+ try:
+ file_handler.flush()
+ file_handler.close()
+ logging.getLogger().removeHandler(file_handler)
+ except Exception as e_close:
+ print(f"Error closing log file handler: {e_close}", file=sys.__stderr__)
+
+ logger.info(f"Training finished with exit code {exit_code}.")
+ return exit_code
+
+
+File: alphatriangle/training/__init__.py
+# File: alphatriangle/training/__init__.py
+from .components import TrainingComponents
+
+# Utilities
+from .logging_utils import log_configs_to_mlflow
+from .loop import TrainingLoop
+from .loop_helpers import LoopHelpers
+
+# Re-export runner functions
+from .runner import run_training
+
+# Visual runner import removed (training is headless)
+from .setup import setup_training_components
+from .worker_manager import WorkerManager
+
+# Explicitly define __all__
+__all__ = [
+ # Core Components
+ "TrainingComponents",
+ "TrainingLoop",
+ # Helpers & Managers
+ "WorkerManager",
+ "LoopHelpers",
+ "setup_training_components",
+ # Runners (re-exported)
+ "run_training",
+ # Logging Utilities
+ "log_configs_to_mlflow",
+]
+
+
+File: alphatriangle/training/README.md
+# Training Module (`alphatriangle.training`)
+
+## Purpose and Architecture
+
+This module encapsulates the logic for setting up, running, and managing the **headless** reinforcement learning training pipeline. It aims to provide a cleaner separation of concerns compared to embedding all logic within the run scripts or a single orchestrator class. **It now leverages the `trieye` library for asynchronous statistics collection and data persistence.**
+
+- **`setup.py`:** Contains `setup_training_components` which initializes Ray, detects resources, **adjusts the requested worker count based on available CPU cores (reserving some for stability),** loads configurations (including `trianglengin.EnvConfig`, `AlphaTriangleMCTSConfig`), creates the `trimcts.SearchConfiguration`, **initializes the `TrieyeActor` (passing `TrieyeConfig`)**, and bundles the core components (`TrainingComponents`).
+- **`components.py`:** Defines the `TrainingComponents` dataclass, a simple container to bundle all the necessary initialized objects (NN, Buffer, Trainer, **`TrieyeActor` handle**, Configs including `trimcts.SearchConfiguration`) required by the `TrainingLoop`.
+- **`loop.py`:** Defines the `TrainingLoop` class. This class contains the core asynchronous logic of the training loop itself:
+ - Managing the pool of `SelfPlayWorker` actors via `WorkerManager`.
+ - Submitting and collecting results from self-play tasks.
+ - Adding experiences to the `ExperienceBuffer`.
+ - Triggering training steps on the `Trainer`.
+ - Updating worker network weights periodically, passing the current `global_step` to the workers.
+  - **Sending raw metric events (`RawMetricEvent` from `trieye`) to the `TrieyeActor` for various occurrences (e.g., training step losses, buffer size changes, total weight updates)** (see the event sketch below).
+ - **Processing results from workers and sending `episode_end` events with context (score, length, triangles cleared) to the `TrieyeActor`.**
+ - **Triggering the `TrieyeActor` to save checkpoints and buffers.**
+ - Logging simple progress strings to the console.
+ - Handling stop requests.
+- **`worker_manager.py`:** Defines the `WorkerManager` class, responsible for creating, managing, submitting tasks to, and collecting results from the `SelfPlayWorker` actors. It passes the `trimcts.SearchConfiguration`, the **`TrieyeActor` name**, and the **run base directory path** (within `.trieye_data/<APP_NAME>/runs/`) to workers during initialization.
+- **`loop_helpers.py`:** Contains simplified helper functions used by `TrainingLoop`, primarily for formatting console progress/ETA strings and validating experiences.
+- **`runner.py`:** Contains the top-level logic for running the headless training pipeline. It handles argument parsing (via CLI), sets up console logging, calls `setup_training_components` (which starts the `TrieyeActor`), runs the `TrainingLoop`, and manages overall cleanup (**including `TrieyeActor` shutdown**).
+- **`runners.py`:** Re-exports the main entry point function (`run_training`) from `runner.py`.
+- **`logging_utils.py`:** Contains helper functions for logging configurations/metrics to MLflow (now primarily used for logging AlphaTriangle-specific configs, as `trieye` handles its own config logging).
+
+This structure separates the high-level setup/teardown (`runner`) from the core iterative logic (`loop`), making the system more modular. Statistics collection and persistence are now handled asynchronously by the `TrieyeActor`, configured via `TrieyeConfig`.
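+
+For illustration, a minimal sketch of sending one raw metric event to a running `TrieyeActor` from driver code (the run name is hypothetical; `setup.py` names the actor `trieye_actor_<run_name>`):
+
+```python
+import ray
+from trieye.schemas import RawMetricEvent
+
+# Reconnect to the detached actor started by setup_training_components.
+actor = ray.get_actor("trieye_actor_my_run")  # "my_run" is a placeholder
+
+event = RawMetricEvent(name="Buffer/Size", value=1024, global_step=500)
+actor.log_event.remote(event)  # fire-and-forget; processed asynchronously
+```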
+
+## Exposed Interfaces
+
+- **Classes:**
+ - `TrainingLoop`: Contains the core async loop logic.
+ - `TrainingComponents`: Dataclass holding initialized components.
+ - `WorkerManager`: Manages worker actors.
+ - `LoopHelpers`: Provides simplified helper functions for the loop.
+- **Functions (from `runners.py`):**
+ - `run_training(...) -> int`
+- **Functions (from `setup.py`):**
+ - `setup_training_components(...) -> Tuple[Optional[TrainingComponents], bool]`
+- **Functions (from `logging_utils.py`):**
+ - `log_configs_to_mlflow(...)`
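+
+A hedged invocation sketch (constructing default configs here is an assumption; real runs typically override fields via the CLI):
+
+```python
+from trieye import TrieyeConfig
+
+from alphatriangle.config import TrainConfig
+from alphatriangle.training import run_training
+
+exit_code = run_training(
+    log_level_str="INFO",
+    train_config_override=TrainConfig(),
+    trieye_config_override=TrieyeConfig(),
+    profile=False,
+)
+```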
+
+## Dependencies
+
+- **[`alphatriangle.config`](../config/README.md)**: `AlphaTriangleMCTSConfig`, `ModelConfig`, `TrainConfig`.
+- **`trianglengin`**: `GameState`, `EnvConfig`.
+- **`trimcts`**: `SearchConfiguration`.
+- **`trieye`**: `TrieyeConfig`, `TrieyeActor`, `RawMetricEvent`, `LoadedTrainingState`.
+- **[`alphatriangle.nn`](../nn/README.md)**: `NeuralNetwork`.
+- **[`alphatriangle.rl`](../rl/README.md)**: `ExperienceBuffer`, `Trainer`, `SelfPlayWorker`, `SelfPlayResult`.
+- **[`alphatriangle.utils`](../utils/README.md)**: Helper functions and types.
+- **`ray`**: For parallelism and actors.
+- **`mlflow`**: Used by `trieye`.
+- **`torch`**: For neural network operations.
+- **`torch.utils.tensorboard`**: Used by `trieye`.
+- **Standard Libraries:** `logging`, `time`, `threading`, `queue`, `os`, `json`, `collections.deque`, `dataclasses`, `pathlib`.
+
+---
+
+**Note:** Please keep this README updated when changing the structure of the training pipeline, the responsibilities of the components, or the way components interact, especially regarding the interaction with the `trieye` library. Accurate documentation is crucial for maintainability.
+
+File: alphatriangle/training/worker_manager.py
+# File: alphatriangle/training/worker_manager.py
+import logging
+from typing import TYPE_CHECKING
+
+import ray
+from pydantic import ValidationError
+
+from ..rl import SelfPlayResult, SelfPlayWorker
+
+if TYPE_CHECKING:
+ from .components import TrainingComponents
+
+logger = logging.getLogger(__name__)
+
+
+class WorkerManager:
+ """Manages the pool of SelfPlayWorker actors, task submission, and results."""
+
+ def __init__(self, components: "TrainingComponents"):
+ self.components = components
+ self.train_config = components.train_config
+ self.nn = components.nn
+ # Get Trieye actor name via remote call
+ self.trieye_actor_name = ray.get(
+ components.trieye_actor.get_actor_name.remote()
+ )
+ # Get run base dir string via Trieye actor
+ self.run_base_dir_str = ray.get(
+ components.trieye_actor.get_run_base_dir_str.remote()
+ )
+ self.profile_workers = components.profile_workers
+
+ self.workers: list[ray.actor.ActorHandle | None] = []
+ self.worker_tasks: dict[ray.ObjectRef, int] = {}
+ self.active_worker_indices: set[int] = set()
+
+ def initialize_workers(self):
+ """Creates the pool of SelfPlayWorker Ray actors."""
+ logger.info(
+ f"Initializing {self.train_config.NUM_SELF_PLAY_WORKERS} self-play workers..."
+ )
+ initial_weights = self.nn.get_weights()
+ weights_ref = ray.put(initial_weights)
+ self.workers = [None] * self.train_config.NUM_SELF_PLAY_WORKERS
+
+ run_base_dir_for_worker = self.run_base_dir_str
+
+ for i in range(self.train_config.NUM_SELF_PLAY_WORKERS):
+ try:
+ profile_this_worker = self.profile_workers and i == 0
+
+ worker = SelfPlayWorker.options(num_cpus=1).remote(
+ actor_id=i,
+ env_config=self.components.env_config,
+ mcts_config=self.components.mcts_config,
+ model_config=self.components.model_config,
+ train_config=self.train_config,
+ trieye_actor_name=self.trieye_actor_name, # Pass actor name
+ run_base_dir=run_base_dir_for_worker,
+ initial_weights=weights_ref,
+ seed=self.train_config.RANDOM_SEED + i,
+ worker_device_str=self.train_config.WORKER_DEVICE,
+ profile_this_worker=profile_this_worker,
+ )
+ self.workers[i] = worker
+ self.active_worker_indices.add(i)
+ except Exception as e:
+ logger.error(f"Failed to initialize worker {i}: {e}", exc_info=True)
+
+ logger.info(
+ f"Initialized {len(self.active_worker_indices)} active self-play workers."
+ )
+ del weights_ref
+
+ def submit_initial_tasks(self):
+ """Submits the first task for each active worker."""
+ logger.info("Submitting initial tasks to workers...")
+ for worker_idx in self.active_worker_indices:
+ self.submit_task(worker_idx)
+
+ def submit_task(self, worker_idx: int):
+ """Submits a new run_episode task to a specific worker."""
+ if worker_idx not in self.active_worker_indices:
+ logger.warning(f"Attempted to submit task to inactive worker {worker_idx}.")
+ return
+ worker = self.workers[worker_idx]
+ if worker:
+ try:
+ task_ref = worker.run_episode.remote()
+ self.worker_tasks[task_ref] = worker_idx
+ logger.debug(f"Submitted task to worker {worker_idx}")
+ except Exception as e:
+ logger.error(
+ f"Failed to submit task to worker {worker_idx}: {e}", exc_info=True
+ )
+ self.active_worker_indices.discard(worker_idx)
+ self.workers[worker_idx] = None
+ else:
+ logger.error(
+ f"Worker {worker_idx} is None during task submission despite being in active set."
+ )
+ self.active_worker_indices.discard(worker_idx)
+
+ def get_completed_tasks(
+ self, timeout: float = 0.1
+ ) -> list[tuple[int, SelfPlayResult | Exception]]:
+ """
+ Checks for completed worker tasks using ray.wait.
+ Returns a list of tuples: (worker_id, result_or_exception).
+ """
+ completed_results: list[tuple[int, SelfPlayResult | Exception]] = []
+ if not self.worker_tasks:
+ return completed_results
+
+ ready_refs, _ = ray.wait(
+ list(self.worker_tasks.keys()), num_returns=1, timeout=timeout
+ )
+
+ if not ready_refs:
+ return completed_results
+
+ for ref in ready_refs:
+ worker_idx = self.worker_tasks.pop(ref, -1)
+ if worker_idx == -1 or worker_idx not in self.active_worker_indices:
+ logger.warning(
+ f"Received result for unknown or inactive worker task: {ref}"
+ )
+ continue
+
+ try:
+ logger.debug(f"Attempting ray.get for worker {worker_idx} task {ref}")
+ result_raw = ray.get(ref)
+ logger.debug(f"ray.get succeeded for worker {worker_idx}")
+ try:
+ # Use Pydantic validation from SelfPlayResult
+ result_validated = SelfPlayResult.model_validate(result_raw)
+ completed_results.append((worker_idx, result_validated))
+ logger.debug(
+ f"Pydantic validation passed for worker {worker_idx} result."
+ )
+ except ValidationError as e_val:
+ error_msg = f"Pydantic validation failed for result from worker {worker_idx}: {e_val}"
+ logger.error(error_msg, exc_info=False)
+ logger.debug(f"Invalid data structure received: {result_raw}")
+ completed_results.append((worker_idx, ValueError(error_msg)))
+ except Exception as e_other_val:
+ error_msg = f"Unexpected error during result validation for worker {worker_idx}: {e_other_val}"
+ logger.error(error_msg, exc_info=True)
+ completed_results.append((worker_idx, e_other_val))
+
+ except ray.exceptions.RayActorError as e_actor:
+ logger.error(
+ f"Worker {worker_idx} actor failed: {e_actor}", exc_info=True
+ )
+ completed_results.append((worker_idx, e_actor))
+ self.workers[worker_idx] = None
+ self.active_worker_indices.discard(worker_idx)
+ except Exception as e_get:
+ logger.error(
+ f"Error getting result from worker {worker_idx} task {ref}: {e_get}",
+ exc_info=True,
+ )
+ completed_results.append((worker_idx, e_get))
+
+ return completed_results
+
+ def update_worker_networks(self, global_step: int):
+ """Sends the latest network weights and current global_step to all active workers."""
+ active_workers = [
+ w
+ for i, w in enumerate(self.workers)
+ if i in self.active_worker_indices and w is not None
+ ]
+ if not active_workers:
+ return
+ logger.debug(f"Updating worker networks for step {global_step}...")
+ current_weights = self.nn.get_weights()
+ weights_ref = ray.put(current_weights)
+ set_weights_tasks = [
+ worker.set_weights.remote(weights_ref) for worker in active_workers
+ ]
+ set_step_tasks = [
+ worker.set_current_trainer_step.remote(global_step)
+ for worker in active_workers
+ ]
+
+ all_tasks = set_weights_tasks + set_step_tasks
+ if not all_tasks:
+ del weights_ref
+ return
+ try:
+ ray.get(all_tasks, timeout=120.0)
+ logger.info(
+ f"Worker networks updated successfully for {len(active_workers)} workers to step {global_step}."
+ )
+ except ray.exceptions.RayActorError as e:
+ logger.error(
+ f"A worker actor failed during weight/step update: {e}", exc_info=True
+ )
+ except ray.exceptions.GetTimeoutError:
+ logger.error("Timeout waiting for workers to update weights/step.")
+ except Exception as e:
+ logger.error(
+ f"Unexpected error updating worker networks/step: {e}", exc_info=True
+ )
+ finally:
+ del weights_ref
+
+ def get_num_active_workers(self) -> int:
+ """Returns the number of currently active workers."""
+ return len(self.active_worker_indices)
+
+ def get_num_pending_tasks(self) -> int:
+ """Returns the number of tasks currently running."""
+ return len(self.worker_tasks)
+
+ def cleanup_actors(self):
+ """Kills Ray actors associated with this manager."""
+ logger.info("Cleaning up WorkerManager actors...")
+ for task_ref in list(self.worker_tasks.keys()):
+ try:
+ ray.cancel(task_ref)
+ except Exception as cancel_e:
+ logger.warning(f"Error cancelling task {task_ref}: {cancel_e}")
+ self.worker_tasks = {}
+
+ for i, worker in enumerate(self.workers):
+ if worker:
+ try:
+ ray.kill(worker, no_restart=True)
+ logger.debug(f"Killed worker {i}.")
+ except Exception as kill_e:
+ logger.warning(f"Error killing worker {i}: {kill_e}")
+ self.workers = []
+ self.active_worker_indices = set()
+ logger.info("WorkerManager actors cleaned up.")
+
+
+File: alphatriangle/training/setup.py
+# File: alphatriangle/training/setup.py
+import logging
+from typing import TYPE_CHECKING, cast
+
+import ray
+import torch
+
+# Import EnvConfig from trianglengin's top level
+from trianglengin import EnvConfig
+from trieye import ( # Import from Trieye
+ Serializer,
+ TrieyeActor,
+ TrieyeConfig,
+)
+
+# Keep alphatriangle imports
+from .. import config, utils
+from ..config import AlphaTriangleMCTSConfig, ModelConfig # Import ModelConfig
+from ..nn import NeuralNetwork
+from ..rl import ExperienceBuffer, Trainer
+from .components import TrainingComponents
+
+if TYPE_CHECKING:
+ from ..config import TrainConfig
+
+logger = logging.getLogger(__name__)
+
+
+def setup_training_components(
+ train_config_override: "TrainConfig",
+ trieye_config_override: TrieyeConfig,
+ profile: bool,
+ # Add optional overrides
+ model_config_override: ModelConfig | None = None,
+ mcts_config_override: AlphaTriangleMCTSConfig | None = None,
+) -> tuple[TrainingComponents | None, bool]:
+ """
+ Initializes Ray, detects cores, updates config, initializes TrieyeActor,
+ and returns the TrainingComponents bundle. Accepts optional config overrides.
+ """
+ ray_initialized_here = False
+ detected_cpu_cores: int | None = None
+ dashboard_started_successfully = False
+ trieye_actor_handle: ray.actor.ActorHandle | None = None
+
+ try:
+ # --- Ray Initialization ---
+ # (Ray init logic remains the same)
+ if not ray.is_initialized():
+ try:
+ logger.info("Attempting to initialize Ray with dashboard...")
+ ray.init(
+ logging_level=logging.WARNING,
+ log_to_driver=False,
+ include_dashboard=True,
+ )
+ ray_initialized_here = True
+ dashboard_started_successfully = True
+ logger.info(
+ "Ray initialized by setup_training_components WITH dashboard attempt."
+ )
+ except Exception as e_dash:
+ if "Cannot include dashboard with missing packages" in str(e_dash):
+ logger.warning(
+ "Ray dashboard dependencies missing. Retrying Ray initialization without dashboard. Install 'ray[default]' for dashboard support."
+ )
+ try:
+ ray.init(
+ logging_level=logging.WARNING,
+ log_to_driver=False,
+ include_dashboard=False,
+ )
+ ray_initialized_here = True
+ dashboard_started_successfully = False
+ logger.info(
+ "Ray initialized by setup_training_components WITHOUT dashboard."
+ )
+ except Exception as e_no_dash:
+ logger.critical(
+ f"Failed to initialize Ray even without dashboard: {e_no_dash}",
+ exc_info=True,
+ )
+ raise RuntimeError("Ray initialization failed") from e_no_dash
+ else:
+ logger.critical(
+ f"Failed to initialize Ray (with dashboard attempt): {e_dash}",
+ exc_info=True,
+ )
+ raise RuntimeError("Ray initialization failed") from e_dash
+
+ if dashboard_started_successfully:
+ logger.info(
+ "Ray Dashboard *should* be running. Check Ray startup logs for the exact URL (usually http://127.0.0.1:8265)."
+ )
+ elif ray_initialized_here:
+ logger.info(
+ "Ray Dashboard is NOT running (missing dependencies). Install 'ray[default]' to enable it."
+ )
+ else:
+ logger.info("Ray already initialized.")
+ logger.info(
+ "Ray Dashboard status in existing session unknown. Check Ray logs or http://127.0.0.1:8265."
+ )
+ ray_initialized_here = False
+
+ # --- Resource Detection ---
+ try:
+ resources = ray.cluster_resources()
+ cores_to_reserve = 2
+ available_cores = int(resources.get("CPU", 0))
+ detected_cpu_cores = max(0, available_cores - cores_to_reserve)
+ logger.info(
+ f"Ray detected {available_cores} total CPU cores. Reserving {cores_to_reserve}. Available for workers: {detected_cpu_cores}."
+ )
+ except Exception as e:
+ logger.error(f"Could not get Ray cluster resources: {e}")
+
+ # --- Load Configurations ---
+ train_config = train_config_override
+ trieye_config = trieye_config_override
+ env_config = EnvConfig()
+ # Use overrides if provided, otherwise load defaults
+ model_config = model_config_override or config.ModelConfig()
+ alphatriangle_mcts_config = mcts_config_override or AlphaTriangleMCTSConfig()
+ logger.info(
+ f"Using Model Config: {'Override' if model_config_override else 'Default'}"
+ )
+ logger.info(
+ f"Using MCTS Config: {'Override' if mcts_config_override else 'Default'}"
+ )
+
+ # --- Adjust Worker Count ---
+ requested_workers = train_config.NUM_SELF_PLAY_WORKERS
+ actual_workers = requested_workers
+ if detected_cpu_cores is not None and detected_cpu_cores > 0:
+ actual_workers = min(requested_workers, detected_cpu_cores)
+ if actual_workers != requested_workers:
+ logger.info(
+ f"Adjusting requested workers ({requested_workers}) to available cores ({detected_cpu_cores}). Using {actual_workers} workers."
+ )
+ elif detected_cpu_cores == 0:
+ logger.warning(
+ "Detected 0 available CPU cores after reservation. Setting worker count to 1."
+ )
+ actual_workers = 1
+ else:
+ logger.warning(
+ f"Could not detect valid CPU cores ({detected_cpu_cores}). Using configured NUM_SELF_PLAY_WORKERS: {requested_workers}"
+ )
+ train_config.NUM_SELF_PLAY_WORKERS = actual_workers
+ logger.info(f"Final worker count set to: {train_config.NUM_SELF_PLAY_WORKERS}")
+
+ # --- Validate AlphaTriangle Configurations ---
+ # Pass the potentially overridden MCTS config instance for validation
+ config.print_config_info_and_validate(alphatriangle_mcts_config)
+
+ # --- Create trimcts SearchConfiguration ---
+ trimcts_mcts_config = alphatriangle_mcts_config.to_trimcts_config()
+ logger.info(
+ f"Created trimcts.SearchConfiguration with max_simulations={trimcts_mcts_config.max_simulations}, mcts_batch_size={trimcts_mcts_config.mcts_batch_size}"
+ )
+
+ # --- Setup Devices and Seeds ---
+ utils.set_random_seeds(train_config.RANDOM_SEED)
+ device = utils.get_device(train_config.DEVICE)
+ worker_device = utils.get_device(train_config.WORKER_DEVICE)
+ logger.info(f"Determined Training Device: {device}")
+ logger.info(f"Determined Worker Device: {worker_device}")
+ logger.info(f"Model Compilation Enabled: {train_config.COMPILE_MODEL}")
+ logger.info(f"Worker Profiling Enabled: {profile}")
+
+ # --- Initialize Trieye Actor ---
+ actor_name = f"trieye_actor_{trieye_config.run_name}"
+ try:
+ trieye_actor_handle = ray.get_actor(actor_name)
+ logger.info(f"Reconnected to existing TrieyeActor '{actor_name}'.")
+ except ValueError:
+ logger.info(f"Creating new TrieyeActor '{actor_name}'.")
+ trieye_actor_handle = TrieyeActor.options(
+ name=actor_name, lifetime="detached"
+ ).remote(config=trieye_config)
+ if trieye_actor_handle:
+ ray.get(trieye_actor_handle.get_mlflow_run_id.remote(), timeout=10)
+ logger.info(f"TrieyeActor '{actor_name}' created and ready.")
+ else:
+ logger.error(
+ f"TrieyeActor handle is None immediately after creation for '{actor_name}'."
+ )
+ raise RuntimeError(
+ f"Failed to get handle for TrieyeActor '{actor_name}' after creation."
+ ) from None
+
+ if not trieye_actor_handle:
+ logger.critical(
+ f"Failed to create or connect to TrieyeActor '{actor_name}'. Cannot proceed."
+ )
+ raise RuntimeError(
+ f"TrieyeActor '{actor_name}' handle is invalid."
+ ) from None
+
+ # --- Initialize Core AlphaTriangle Components ---
+ serializer = Serializer() # Instantiate Serializer here
+ # Pass the potentially overridden model_config
+ neural_net = NeuralNetwork(model_config, env_config, train_config, device)
+ buffer = ExperienceBuffer(train_config)
+ trainer = Trainer(neural_net, train_config, env_config)
+
+ # --- Bundle Components ---
+ components = TrainingComponents(
+ nn=neural_net,
+ buffer=buffer,
+ trainer=trainer,
+ trieye_actor=cast("ray.actor.ActorHandle", trieye_actor_handle),
+ trieye_config=trieye_config,
+ serializer=serializer,
+ train_config=train_config,
+ env_config=env_config,
+ model_config=model_config, # Store the used model config
+ mcts_config=trimcts_mcts_config, # Store the used MCTS config
+ profile_workers=profile,
+ )
+
+ return components, ray_initialized_here
+ except Exception as e:
+ logger.critical(f"Error setting up training components: {e}", exc_info=True)
+ if trieye_actor_handle:
+ try:
+ ray.kill(trieye_actor_handle)
+ except Exception as kill_err:
+ logger.error(
+ f"Error killing TrieyeActor during setup cleanup: {kill_err}"
+ )
+        if ray_initialized_here and ray.is_initialized():
+ try:
+ ray.shutdown()
+ logger.info("Ray shut down due to setup error.")
+ except Exception as ray_err:
+ logger.error(f"Error shutting down Ray during setup cleanup: {ray_err}")
+ return None, ray_initialized_here
+
+
+def count_parameters(model: torch.nn.Module) -> tuple[int, int]:
+ """Counts total and trainable parameters."""
+ total_params = sum(p.numel() for p in model.parameters())
+ trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
+ return total_params, trainable_params
+
+
+File: alphatriangle/training/runners.py
+# File: alphatriangle/training/runners.py
+"""
+Entry points for running training modes.
+Imports functions from specific runner modules.
+"""
+
+# Re-export the training entry point from the runner module.
+from .runner import run_training
+
+__all__ = [
+ "run_training", # Export the single runner function
+]
+
+
+File: alphatriangle/training/loop.py
+# File: alphatriangle/training/loop.py
+import logging
+import threading
+import time
+from typing import TYPE_CHECKING, Any
+
+from trieye.schemas import RawMetricEvent # Import from trieye
+
+from ..rl import SelfPlayResult
+from .loop_helpers import LoopHelpers
+from .worker_manager import WorkerManager
+
+if TYPE_CHECKING:
+ import numpy as np
+
+ from ..utils.types import PERBatchSample
+ from .components import TrainingComponents
+
+
+logger = logging.getLogger(__name__)
+
+
+class TrainingLoop:
+ """
+ Manages the core asynchronous training loop logic: coordinating worker tasks,
+ processing results, triggering training steps, and interacting with TrieyeActor.
+ Runs headless.
+ """
+
+ def __init__(
+ self,
+ components: "TrainingComponents",
+ ):
+ self.components = components
+ self.train_config = components.train_config
+ self.trieye_config = components.trieye_config
+ self.buffer = components.buffer
+ self.trainer = components.trainer
+ self.trieye_actor = components.trieye_actor
+ self.serializer = components.serializer # Get serializer from components
+
+ self.global_step = 0
+ self.episodes_played = 0
+ self.total_simulations_run = 0
+ self.weight_update_count = 0
+ self.start_time = time.time()
+ self.stop_requested = threading.Event()
+ self.training_complete = False
+ self.training_exception: Exception | None = None
+ self._buffer_ready_logged = False
+
+ self.worker_manager = WorkerManager(components)
+ self.loop_helpers = LoopHelpers(components, self._get_loop_state)
+
+ logger.info("TrainingLoop initialized (Headless, using Trieye).")
+
+ def _get_loop_state(self) -> dict[str, Any]:
+ """Provides current loop state to helpers."""
+ return {
+ "global_step": self.global_step,
+ "episodes_played": self.episodes_played,
+ "total_simulations_run": self.total_simulations_run,
+ "weight_update_count": self.weight_update_count,
+ "buffer_size": len(self.buffer),
+ "buffer_capacity": self.buffer.capacity,
+ "num_active_workers": self.worker_manager.get_num_active_workers(),
+ "num_pending_tasks": self.worker_manager.get_num_pending_tasks(),
+ "start_time": self.start_time,
+ "num_workers": self.train_config.NUM_SELF_PLAY_WORKERS,
+ }
+
+ def set_initial_state(
+ self, global_step: int, episodes_played: int, total_simulations: int
+ ):
+ """Sets the initial state counters after loading."""
+ self.global_step = global_step
+ self.episodes_played = episodes_played
+ self.total_simulations_run = total_simulations
+ self.weight_update_count = (
+ global_step // self.train_config.WORKER_UPDATE_FREQ_STEPS
+ if self.train_config.WORKER_UPDATE_FREQ_STEPS > 0
+ else 0
+ )
+ logger.info(
+ f"TrainingLoop initial state set: Step={global_step}, Episodes={episodes_played}, Sims={total_simulations}, WeightUpdates={self.weight_update_count}"
+ )
+
+ def initialize_workers(self):
+ """Initializes self-play workers using WorkerManager."""
+ self.worker_manager.initialize_workers()
+
+ def request_stop(self):
+ """Signals the training loop to stop gracefully."""
+ if not self.stop_requested.is_set():
+ logger.info("Stop requested for TrainingLoop.")
+ self.stop_requested.set()
+
+ def _send_event_async(
+ self, name: str, value: float | int, context: dict | None = None
+ ):
+ """Helper to send a raw metric event to the Trieye actor asynchronously."""
+ if self.trieye_actor:
+ event = RawMetricEvent(
+ name=name,
+ value=value,
+ global_step=self.global_step,
+ timestamp=time.time(),
+ context=context or {},
+ )
+ try:
+ self.trieye_actor.log_event.remote(event)
+ except Exception as e:
+ logger.error(f"Failed to send event '{name}' to Trieye actor: {e}")
+
+ def _send_batch_events_async(self, events: list[RawMetricEvent]):
+ """Helper to send a batch of raw metric events asynchronously."""
+ if self.trieye_actor and events:
+ try:
+ self.trieye_actor.log_batch_events.remote(events)
+ except Exception as e:
+ logger.error(f"Failed to send batch events to Trieye actor: {e}")
+
+ def _process_self_play_result(self, result: SelfPlayResult, worker_id: int):
+ """Processes a validated result from a worker."""
+ logger.debug(
+ f"Processing result from worker {worker_id} (Ep Steps: {result.episode_steps}, Score: {result.final_score:.2f})"
+ )
+
+ valid_experiences, invalid_count = self.loop_helpers.validate_experiences(
+ result.episode_experiences
+ )
+ if invalid_count > 0:
+ logger.warning(
+ f"Worker {worker_id}: {invalid_count} invalid experiences were filtered out before adding to buffer."
+ )
+
+ if valid_experiences:
+ try:
+ self.buffer.add_batch(valid_experiences)
+ buffer_size = len(self.buffer)
+ logger.debug(
+ f"Added {len(valid_experiences)} experiences from worker {worker_id} to buffer (Buffer size: {buffer_size})."
+ )
+ if (
+ not self._buffer_ready_logged
+ and buffer_size >= self.buffer.min_size_to_train
+ ):
+ logger.info(
+ f"--- Buffer ready for training (Size: {buffer_size}/{self.buffer.min_size_to_train}) ---"
+ )
+ self._buffer_ready_logged = True
+
+ except Exception as e:
+ logger.error(
+ f"Error adding batch to buffer from worker {worker_id}: {e}",
+ exc_info=True,
+ )
+ return
+
+ self.episodes_played += 1
+ self.total_simulations_run += result.total_simulations
+
+ self._send_event_async("Progress/Episodes_Played", self.episodes_played)
+ self._send_event_async(
+ "Progress/Total_Simulations", self.total_simulations_run
+ )
+
+ else:
+ logger.error(
+ f"Worker {worker_id}: Self-play episode produced NO valid experiences (Steps: {result.episode_steps}, Score: {result.final_score:.2f}). This prevents buffer filling and training."
+ )
+
+ def _trigger_save_state(self, is_best: bool = False, save_buffer: bool = False):
+ """Triggers the Trieye actor to save the current training state."""
+ if not self.trieye_actor:
+ logger.error("Cannot save state: TrieyeActor handle is missing.")
+ return
+
+ logger.info(
+ f"Requesting state save via Trieye at step {self.global_step}. Best={is_best}, SaveBuffer={save_buffer}"
+ )
+ try:
+ # Prepare data using the local serializer
+ nn_state = self.components.nn.get_weights()
+ opt_state = self.serializer.prepare_optimizer_state(self.trainer.optimizer)
+ buffer_data = self.serializer.prepare_buffer_data(self.buffer)
+
+ if buffer_data is None and save_buffer:
+ logger.error("Failed to prepare buffer data, cannot save buffer.")
+ save_buffer = False # Don't attempt to save None
+
+ # Fire-and-forget call to the actor
+ self.trieye_actor.save_training_state.remote(
+ nn_state_dict=nn_state,
+ optimizer_state_dict=opt_state,
+ buffer_content=(
+ buffer_data.buffer_list if buffer_data else []
+ ), # Pass the list
+ global_step=self.global_step,
+ episodes_played=self.episodes_played,
+ total_simulations_run=self.total_simulations_run,
+ is_best=is_best,
+ save_buffer=save_buffer,
+ model_config_dict=self.components.model_config.model_dump(),
+ env_config_dict=self.components.env_config.model_dump(),
+ )
+ except Exception as e_save:
+ logger.error(
+ f"Failed to trigger save state via Trieye at step {self.global_step}: {e_save}",
+ exc_info=True,
+ )
+
+ def _run_training_step(self) -> bool:
+ """Runs one training step."""
+ if not self.buffer.is_ready():
+ return False
+ per_sample: PERBatchSample | None = self.buffer.sample(
+ self.train_config.BATCH_SIZE, current_train_step=self.global_step
+ )
+ if not per_sample:
+ return False
+
+ train_result: tuple[dict[str, float], np.ndarray] | None = (
+ self.trainer.train_step(per_sample)
+ )
+ if train_result:
+ loss_info, td_errors = train_result
+ prev_step = self.global_step
+ self.global_step += 1
+ if prev_step == 0:
+ logger.info(
+ f"--- First training step completed (Global Step: {self.global_step}) ---"
+ )
+ self._send_event_async("step_completed", 1.0)
+
+ if self.train_config.USE_PER:
+ self.buffer.update_priorities(per_sample["indices"], td_errors)
+ per_beta = self.buffer._calculate_beta(self.global_step)
+ self._send_event_async("PER/Beta", per_beta)
+
+ events_batch = []
+ loss_name_map = {
+ "total_loss": "Loss/Total",
+ "policy_loss": "Loss/Policy",
+ "value_loss": "Loss/Value",
+ "entropy": "Loss/Entropy",
+ "mean_td_error": "Loss/Mean_Abs_TD_Error",
+ }
+ for key, value in loss_info.items():
+ metric_name = loss_name_map.get(key)
+ if metric_name:
+ events_batch.append(
+ RawMetricEvent(
+ name=metric_name,
+ value=value,
+ global_step=self.global_step,
+ )
+ )
+ else:
+ logger.warning(f"Unmapped loss key from trainer: {key}")
+
+ current_lr = self.trainer.get_current_lr()
+ events_batch.append(
+ RawMetricEvent(
+ name="LearningRate", value=current_lr, global_step=self.global_step
+ )
+ )
+
+ self._send_batch_events_async(events_batch)
+
+ if (
+ self.train_config.WORKER_UPDATE_FREQ_STEPS > 0
+ and self.global_step % self.train_config.WORKER_UPDATE_FREQ_STEPS == 0
+ ):
+ logger.info(
+ f"Step {self.global_step}: Triggering worker network update (Frequency: {self.train_config.WORKER_UPDATE_FREQ_STEPS})."
+ )
+ try:
+ self.worker_manager.update_worker_networks(self.global_step)
+ self.weight_update_count += 1
+ self._send_event_async(
+ "Progress/Weight_Updates_Total", self.weight_update_count
+ )
+ except Exception as update_err:
+ logger.error(
+ f"Failed to update worker networks at step {self.global_step}: {update_err}"
+ )
+
+ if self.global_step % 50 == 0:
+ logger.info(
+ f"Step {self.global_step}: P Loss={loss_info.get('policy_loss', 0.0):.4f}, V Loss={loss_info.get('value_loss', 0.0):.4f}"
+ )
+ return True
+ else:
+ logger.warning(f"Training step {self.global_step + 1} failed.")
+ return False
+
+ def run(self):
+ """Main training loop."""
+ max_steps_info = (
+ f"Target steps: {self.train_config.MAX_TRAINING_STEPS}"
+ if self.train_config.MAX_TRAINING_STEPS is not None
+ else "Running indefinitely (no MAX_TRAINING_STEPS)"
+ )
+ logger.info(f"Starting TrainingLoop run... {max_steps_info}")
+ self.start_time = time.time()
+
+ try:
+ self.worker_manager.submit_initial_tasks()
+
+ while not self.stop_requested.is_set():
+ if (
+ self.train_config.MAX_TRAINING_STEPS is not None
+ and self.global_step >= self.train_config.MAX_TRAINING_STEPS
+ ):
+ logger.info(
+ f"Reached MAX_TRAINING_STEPS ({self.train_config.MAX_TRAINING_STEPS}). Stopping loop."
+ )
+ self.training_complete = True
+ self.request_stop()
+ break
+
+ trained_this_step = False
+ if self.buffer.is_ready():
+ trained_this_step = self._run_training_step()
+ else:
+ if not self._buffer_ready_logged and self.global_step % 100 == 0:
+ logger.info(
+ f"Waiting for buffer... Size: {len(self.buffer)}/{self.buffer.min_size_to_train}"
+ )
+ time.sleep(0.01)
+
+ if (
+ trained_this_step
+ and self.train_config.CHECKPOINT_SAVE_FREQ_STEPS > 0
+ and self.global_step % self.train_config.CHECKPOINT_SAVE_FREQ_STEPS
+ == 0
+ ):
+ self._trigger_save_state(is_best=False, save_buffer=False)
+
+ if (
+ trained_this_step
+ and self.trieye_config.persistence.SAVE_BUFFER
+ and self.trieye_config.persistence.BUFFER_SAVE_FREQ_STEPS > 0
+ and self.global_step
+ % self.trieye_config.persistence.BUFFER_SAVE_FREQ_STEPS
+ == 0
+ ):
+ self._trigger_save_state(is_best=False, save_buffer=True)
+
+ if self.stop_requested.is_set():
+ break
+
+ wait_timeout = 0.1 if self.buffer.is_ready() else 0.5
+ completed_tasks = self.worker_manager.get_completed_tasks(wait_timeout)
+
+ for worker_id, result_or_error in completed_tasks:
+ if isinstance(result_or_error, SelfPlayResult):
+ try:
+ self._process_self_play_result(result_or_error, worker_id)
+ except Exception as proc_err:
+ logger.error(
+ f"Error processing result from worker {worker_id}: {proc_err}",
+ exc_info=True,
+ )
+ elif isinstance(result_or_error, Exception):
+ logger.error(
+ f"Worker {worker_id} task failed with exception: {result_or_error}"
+ )
+ else:
+ logger.error(
+ f"Received unexpected item from completed tasks for worker {worker_id}: {type(result_or_error)}"
+ )
+
+ self.worker_manager.submit_task(worker_id)
+
+ if self.stop_requested.is_set():
+ break
+
+ self.loop_helpers.log_progress_eta()
+ loop_state = self._get_loop_state()
+ self._send_event_async("Buffer/Size", loop_state["buffer_size"])
+ self._send_event_async(
+ "System/Num_Active_Workers", loop_state["num_active_workers"]
+ )
+ self._send_event_async(
+ "System/Num_Pending_Tasks", loop_state["num_pending_tasks"]
+ )
+
+ if self.trieye_actor:
+ self.trieye_actor.process_and_log.remote(self.global_step)
+
+ if (
+ not completed_tasks
+ and not trained_this_step
+ and not self.buffer.is_ready()
+ ):
+ time.sleep(0.05)
+
+ except KeyboardInterrupt:
+ logger.warning("KeyboardInterrupt received in TrainingLoop. Stopping.")
+ self.request_stop()
+ except Exception as e:
+ logger.critical(f"Unhandled exception in TrainingLoop: {e}", exc_info=True)
+ self.training_exception = e
+ self.request_stop()
+ finally:
+            if self.training_exception or (
+                self.stop_requested.is_set() and not self.training_complete
+            ):
+ self.training_complete = False
+ logger.info(
+ f"TrainingLoop finished. Complete: {self.training_complete}, Exception: {self.training_exception is not None}"
+ )
+
+ def cleanup_actors(self):
+ """Cleans up worker actors. TrieyeActor cleanup handled by runner."""
+ self.worker_manager.cleanup_actors()
+
+
+File: alphatriangle/training/components.py
+# File: alphatriangle/training/components.py
+from dataclasses import dataclass
+from typing import TYPE_CHECKING
+
+import ray
+
+# Import EnvConfig from trianglengin's top level
+from trianglengin import EnvConfig
+from trieye import Serializer # Import Serializer from trieye
+from trimcts import SearchConfiguration
+
+# Keep alphatriangle imports
+
+if TYPE_CHECKING:
+ from trieye import TrieyeConfig # Import TrieyeConfig
+
+ from alphatriangle.config import ModelConfig, TrainConfig
+ from alphatriangle.nn import NeuralNetwork
+ from alphatriangle.rl import ExperienceBuffer, Trainer
+
+
+@dataclass
+class TrainingComponents:
+ """Holds the initialized core components needed for training."""
+
+ nn: "NeuralNetwork"
+ buffer: "ExperienceBuffer"
+ trainer: "Trainer"
+ trieye_actor: ray.actor.ActorHandle
+ trieye_config: "TrieyeConfig"
+ serializer: Serializer # Added Serializer instance
+ train_config: "TrainConfig"
+ env_config: EnvConfig
+ model_config: "ModelConfig"
+ mcts_config: SearchConfiguration
+ profile_workers: bool
+
+
+File: alphatriangle/features/__init__.py
+"""
+Feature extraction module.
+Converts raw GameState objects into numerical representations suitable for NN input.
+"""
+
+from . import grid_features
+from .extractor import GameStateFeatures, extract_state_features
+
+__all__ = [
+ "extract_state_features",
+ "GameStateFeatures",
+ "grid_features",
+]
+
+
+File: alphatriangle/features/extractor.py
+import logging
+from typing import cast
+
+import numpy as np
+
+# Import GameState and Shape from trianglengin's top level
+from trianglengin import EnvConfig, GameState, Shape # UPDATED IMPORT
+
+# Import alphatriangle configs
+from ..config import ModelConfig
+from ..utils.types import StateType # Import StateType
+
+# Import grid_features locally
+from . import grid_features
+
+logger = logging.getLogger(__name__)
+
+
+class GameStateFeatures:
+ """Extracts features from trianglengin.GameState for NN input."""
+
+ def __init__(self, game_state: GameState, model_config: ModelConfig):
+ self.gs = game_state
+ # Access EnvConfig directly from the GameState object
+ self.env_config: EnvConfig = game_state.env_config
+ self.model_config = model_config
+ # Access GridData NumPy arrays directly via the GameState wrapper
+ grid_data_np = game_state.get_grid_data_np()
+ self.occupied_np = grid_data_np["occupied"]
+ self.death_np = grid_data_np["death"]
+ self.color_id_np = grid_data_np["color_id"]
+
+ def _get_grid_state(self) -> np.ndarray:
+ """
+ Returns grid state as a single channel numpy array based on NumPy arrays.
+ Values: 1.0 (occupied playable), 0.0 (empty playable), -1.0 (death cell).
+ Shape: (C, H, W) where C is GRID_INPUT_CHANNELS
+ """
+ rows, cols = self.env_config.ROWS, self.env_config.COLS
+ grid_state: np.ndarray = np.zeros(
+ (self.model_config.GRID_INPUT_CHANNELS, rows, cols), dtype=np.float32
+ )
+ playable_occupied_mask = self.occupied_np & ~self.death_np
+ grid_state[0, playable_occupied_mask] = 1.0
+ grid_state[0, self.death_np] = -1.0
+ return grid_state
+
+ def _get_shape_features(self) -> np.ndarray:
+ """Extracts numerical features for each shape slot using the Shape class."""
+ num_slots = self.env_config.NUM_SHAPE_SLOTS
+ # Ensure this matches OTHER_NN_INPUT_FEATURES_DIM calculation
+ FEATURES_PER_SHAPE_HERE = 7
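+        # Per-slot features (each normalized to roughly [0, 1]): triangle
+        # count, up-triangle fraction, down-triangle fraction, bbox height,
+        # effective width, bbox row-center, bbox col-center.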
+ shape_feature_matrix = np.zeros(
+ (num_slots, FEATURES_PER_SHAPE_HERE), dtype=np.float32
+ )
+
+ # Access shapes via the GameState wrapper method
+ current_shapes: list[Shape | None] = self.gs.get_shapes()
+ for i, shape in enumerate(current_shapes):
+ # Check if shape exists and has triangles
+ if shape and shape.triangles:
+ n_tris = len(shape.triangles)
+ ups = sum(1 for _, _, is_up in shape.triangles if is_up)
+ downs = n_tris - ups
+ # Use the bbox method of the Shape class
+ min_r, min_c, max_r, max_c = shape.bbox()
+ height = max_r - min_r + 1
+ width_eff = (max_c - min_c + 1) * 0.75 + 0.25 if n_tris > 0 else 0
+
+ shape_feature_matrix[i, 0] = np.clip(n_tris / 5.0, 0, 1)
+ shape_feature_matrix[i, 1] = ups / n_tris if n_tris > 0 else 0
+ shape_feature_matrix[i, 2] = downs / n_tris if n_tris > 0 else 0
+ shape_feature_matrix[i, 3] = np.clip(
+ height / self.env_config.ROWS, 0, 1
+ )
+ shape_feature_matrix[i, 4] = np.clip(
+ width_eff / self.env_config.COLS, 0, 1
+ )
+ shape_feature_matrix[i, 5] = np.clip(
+ ((min_r + max_r) / 2.0) / self.env_config.ROWS, 0, 1
+ )
+ shape_feature_matrix[i, 6] = np.clip(
+ ((min_c + max_c) / 2.0) / self.env_config.COLS, 0, 1
+ )
+ return shape_feature_matrix.flatten()
+
+ def _get_shape_availability(self) -> np.ndarray:
+ """Returns a binary vector indicating which shape slots are filled."""
+ current_shapes = self.gs.get_shapes()
+ return np.array([1.0 if s else 0.0 for s in current_shapes], dtype=np.float32)
+
+ def _get_explicit_features(self) -> np.ndarray:
+ """
+ Extracts scalar features like score, heights, holes, etc.
+ Uses GridData NumPy arrays directly.
+ """
+ # Ensure this matches OTHER_NN_INPUT_FEATURES_DIM calculation
+ EXPLICIT_FEATURES_DIM_HERE = 6
+ features = np.zeros(EXPLICIT_FEATURES_DIM_HERE, dtype=np.float32)
+ occupied = self.occupied_np
+ death = self.death_np
+ rows, cols = self.env_config.ROWS, self.env_config.COLS
+
+ heights = grid_features.get_column_heights(occupied, death, rows, cols)
+ holes = grid_features.count_holes(occupied, death, heights, rows, cols)
+ bump = grid_features.get_bumpiness(heights)
+ total_playable_cells = np.sum(~death)
+
+ features[0] = np.clip(self.gs.game_score() / 100.0, -5.0, 5.0)
+ features[1] = np.mean(heights) / rows if rows > 0 else 0
+ features[2] = np.max(heights) / rows if rows > 0 else 0
+ features[3] = holes / total_playable_cells if total_playable_cells > 0 else 0
+ features[4] = (bump / (cols - 1)) / rows if cols > 1 and rows > 0 else 0
+ features[5] = np.clip(self.gs.current_step / 1000.0, 0, 1)
+
+ return cast(
+ "np.ndarray", np.nan_to_num(features, nan=0.0, posinf=0.0, neginf=0.0)
+ )
+
+ def get_combined_other_features(self) -> np.ndarray:
+ """Combines all non-grid numerical features into a single flat vector."""
+ shape_feats = self._get_shape_features()
+ avail_feats = self._get_shape_availability()
+ explicit_feats = self._get_explicit_features()
+ combined = np.concatenate([shape_feats, avail_feats, explicit_feats])
+
+ expected_dim = self.model_config.OTHER_NN_INPUT_FEATURES_DIM
+ if combined.shape[0] != expected_dim:
+ logger.error(
+ f"Combined other_features dimension mismatch! Extracted {combined.shape[0]}, but ModelConfig expects {expected_dim}. Padding/truncating."
+ )
+ # Adjust size if necessary
+ if combined.shape[0] < expected_dim:
+ padding = np.zeros(
+ expected_dim - combined.shape[0], dtype=combined.dtype
+ )
+ combined = np.concatenate([combined, padding])
+ else:
+ combined = combined[:expected_dim]
+
+ if not np.all(np.isfinite(combined)):
+ logger.error(
+ f"Non-finite values detected in combined other_features! Min: {np.nanmin(combined)}, Max: {np.nanmax(combined)}, Mean: {np.nanmean(combined)}"
+ )
+ combined = np.nan_to_num(combined, nan=0.0, posinf=0.0, neginf=0.0)
+
+ return cast("np.ndarray", combined.astype(np.float32))
+
+
+def extract_state_features(
+ game_state: GameState, model_config: ModelConfig
+) -> StateType:
+ """
+ Extracts and returns the state dictionary {grid, other_features}
+ for NN input and buffer storage.
+ Requires ModelConfig to ensure dimensions match the network's expectations.
+ Includes validation for non-finite values. Now reads from trianglengin.GameState wrapper.
+ """
+ extractor = GameStateFeatures(game_state, model_config)
+ state_dict: StateType = {
+ "grid": extractor._get_grid_state(),
+ "other_features": extractor.get_combined_other_features(),
+ }
+ grid_feat = state_dict["grid"]
+ other_feat = state_dict["other_features"]
+ logger.debug(
+ f"Extracted Features (State {game_state.current_step}): "
+ f"Grid(shape={grid_feat.shape}, min={grid_feat.min():.2f}, max={grid_feat.max():.2f}, mean={grid_feat.mean():.2f}), "
+ f"Other(shape={other_feat.shape}, min={other_feat.min():.2f}, max={other_feat.max():.2f}, mean={other_feat.mean():.2f})"
+ )
+ return state_dict
+
+
+File: alphatriangle/features/README.md
+
+# Feature Extraction Module (`alphatriangle.features`)
+
+## Purpose and Architecture
+
+This module is solely responsible for converting raw [`GameState`](../../trianglengin/game_interface.py) objects from the `trianglengin` library into numerical representations (features) suitable for input into the neural network ([`alphatriangle.nn`](../nn/README.md)). It acts as a bridge between the game's internal state and the requirements of the machine learning model.
+
+- **Decoupling:** This module completely decouples feature engineering from the core game engine logic. The `trianglengin` library focuses only on game rules and state transitions (now largely in C++), while this module handles the transformation for the NN.
+- **Feature Engineering:**
+  - [`extractor.py`](extractor.py): Contains the `GameStateFeatures` class and the main `extract_state_features` function. This orchestrates the extraction process, calling helper functions to generate different feature types. It uses the `Shape` class from `trianglengin.game_interface`.
+ - [`grid_features.py`](grid_features.py): Contains low-level, potentially performance-optimized (e.g., using Numba) functions for calculating specific scalar metrics derived from the grid state (like column heights, holes, bumpiness). This module operates directly on NumPy arrays obtained from the `GameState` wrapper.
+- **Output Format:** The `extract_state_features` function returns a `StateType` (a `TypedDict` defined in [`alphatriangle.utils.types`](../utils/types.py) containing the `grid` and `other_features` numpy arrays), which is the standard input format expected by the `NeuralNetwork` interface and stored in the `ExperienceBuffer`.
+- **Configuration Dependency:** The extractor requires [`ModelConfig`](../config/model_config.py) to ensure the dimensions of the extracted numerical features (`other_features`) match the expectations of the neural network architecture.
+
+## Exposed Interfaces
+
+- **Functions:**
+  - `extract_state_features(game_state: GameState, model_config: ModelConfig) -> StateType`: The main function to perform feature extraction (see the usage sketch below).
+ - Low-level grid feature functions from `grid_features` (e.g., `get_column_heights`, `count_holes`, `get_bumpiness`).
+- **Classes:**
+ - `GameStateFeatures`: Class containing the feature extraction logic (primarily used internally by `extract_state_features`).
+
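+A minimal usage sketch (editor's illustration; it assumes `EnvConfig`, `ModelConfig`, and `GameState` are default-constructible, which may not match the real constructors):
+
+```python
+from trianglengin import EnvConfig, GameState
+
+from alphatriangle.config import ModelConfig
+from alphatriangle.features import extract_state_features
+
+env_config = EnvConfig()        # assumed default-constructible
+model_config = ModelConfig()    # assumed default-constructible
+state = GameState(env_config)   # hypothetical constructor signature
+
+features = extract_state_features(state, model_config)
+print(features["grid"].shape)            # (GRID_INPUT_CHANNELS, ROWS, COLS)
+print(features["other_features"].shape)  # (OTHER_NN_INPUT_FEATURES_DIM,)
+```
+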
+## Dependencies
+
+- **`trianglengin`**:
+ - `GameState`: The input object for feature extraction.
+ - `EnvConfig`: Accessed via `GameState` for environment dimensions.
+ - `Shape`: Used for processing shape features.
+- **[`alphatriangle.config`](../config/README.md)**:
+ - `ModelConfig`: Required by `extract_state_features` to ensure output dimensions match the NN input layer.
+- **[`alphatriangle.utils`](../utils/README.md)**:
+ - `StateType`: The return type dictionary format.
+- **`numpy`**:
+ - Used extensively for creating and manipulating the numerical feature arrays.
+- **`numba`**:
+ - Used in `grid_features` for performance optimization.
+- **Standard Libraries:** `typing`, `logging`.
+
+---
+
+**Note:** Please keep this README updated when changing the feature extraction logic, the set of extracted features, or the output format (`StateType`). Accurate documentation is crucial for maintainability.
+
+
+File: alphatriangle/features/grid_features.py
+# File: alphatriangle/features/grid_features.py
+# No changes needed, operates on NumPy arrays.
+import numpy as np
+from numba import njit, prange
+
+
+@njit(cache=True)
+def get_column_heights(
+ occupied: np.ndarray, death: np.ndarray, rows: int, cols: int
+) -> np.ndarray:
+ """Calculates the height of each column (highest occupied non-death cell)."""
+ heights = np.zeros(cols, dtype=np.int32)
+ for c in prange(cols):
+ max_r = -1
+ for r in range(rows):
+ if occupied[r, c] and not death[r, c]:
+ max_r = r
+ heights[c] = max_r + 1
+ return heights
+
+
+@njit(cache=True)
+def count_holes(
+ occupied: np.ndarray, death: np.ndarray, heights: np.ndarray, _rows: int, cols: int
+) -> int:
+ """Counts the number of empty, non-death cells below the column height."""
+ holes = 0
+ for c in prange(cols):
+ col_height = heights[c]
+ for r in range(col_height):
+ if not occupied[r, c] and not death[r, c]:
+ holes += 1
+ return holes
+
+
+@njit(cache=True)
+def get_bumpiness(heights: np.ndarray) -> float:
+ """Calculates the total absolute difference between adjacent column heights."""
+ bumpiness = 0.0
+ for i in range(len(heights) - 1):
+ bumpiness += abs(heights[i] - heights[i + 1])
+ return bumpiness
+
+
+File: alphatriangle/utils/__init__.py
+# File: alphatriangle/utils/__init__.py
+from .geometry import is_point_in_polygon
+from .helpers import (
+ format_eta,
+ get_device,
+ normalize_color_for_matplotlib,
+ set_random_seeds,
+)
+from .sumtree import SumTree
+from .types import (
+ ActionType,
+ Experience,
+ ExperienceBatch,
+ PERBatchSample,
+ PolicyValueOutput,
+ StateType,
+ # Removed StatsCollectorData
+)
+
+__all__ = [
+ # helpers
+ "get_device",
+ "set_random_seeds",
+ "format_eta",
+ "normalize_color_for_matplotlib",
+ # types
+ "StateType",
+ "ActionType",
+ "Experience",
+ "ExperienceBatch",
+ "PolicyValueOutput",
+ # Removed StatsCollectorData
+ "PERBatchSample",
+ # geometry
+ "is_point_in_polygon",
+ # structures
+ "SumTree",
+]
+
+
+File: alphatriangle/utils/types.py
+# File: alphatriangle/utils/types.py
+from collections.abc import Mapping
+
+import numpy as np
+from typing_extensions import TypedDict
+
+
+class StateType(TypedDict):
+ """
+ Represents the processed state features input to the neural network.
+ Contains numerical arrays derived from the raw GameState.
+ """
+
+ grid: np.ndarray # (C, H, W) float32, e.g., occupancy, death cells
+ other_features: np.ndarray # (OtherFeatDim,) float32, e.g., shape info, game stats
+
+
+# Action representation (integer index)
+ActionType = int
+
+# Policy target from MCTS (visit counts distribution)
+# Mapping from action index to its probability (normalized visit count)
+PolicyTargetMapping = Mapping[ActionType, float]
+
+
+# REMOVED StepInfo TypedDict (no longer needed directly here)
+
+
+Experience = tuple[StateType, PolicyTargetMapping, float]
+# Represents one unit of experience stored in the replay buffer.
+# 1. StateType: The processed features (grid, other_features) of the state s_t.
+# NOTE: This is NOT the raw GameState object.
+# 2. PolicyTargetMapping: The MCTS-derived policy target pi(a|s_t) for state s_t.
+# 3. float: The calculated N-step return G_t^n starting from state s_t, used
+# as the target for the value head during training.
+
+
+# Batch of experiences for training
+ExperienceBatch = list[Experience]
+
+
+# Output type from the neural network's evaluate method (for MCTS interaction)
+# 1. dict[ActionType, float]: Policy probabilities P(a|s) for the evaluated state.
+# 2. float: The expected value V(s) calculated from the value distribution logits.
+PolicyValueOutput = tuple[dict[int, float], float]
+
+
+# REMOVED StatsCollectorData type alias
+
+
+class PERBatchSample(TypedDict):
+ """Output of the PER buffer's sample method."""
+
+ batch: ExperienceBatch # The sampled experiences
+ indices: np.ndarray # Tree indices of the sampled experiences (for priority update)
+ weights: np.ndarray # Importance sampling weights for the sampled experiences
+
+
+File: alphatriangle/utils/README.md
+# Utilities Module (alphatriangle.utils)
+
+## Purpose and Architecture
+
+This module provides common utility functions and type definitions used across various parts of the AlphaTriangle project. Its goal is to avoid code duplication and provide central definitions for shared concepts.
+
+- **Helper Functions ([helpers.py](helpers.py)):** Contains miscellaneous helper functions:
+ - get_device: Determines the appropriate PyTorch device (CPU, CUDA, MPS) based on availability and preference.
+ - set_random_seeds: Initializes random number generators for Python, NumPy, and PyTorch for reproducibility.
+ - format_eta: Converts a time duration (in seconds) into a human-readable string (HH:MM:SS).
+ - normalize_color_for_matplotlib: Converts RGB (0-255) to Matplotlib format (0.0-1.0).
+- **Type Definitions ([types.py](types.py)):** Defines common type aliases and TypedDicts used throughout the codebase, particularly for data structures passed between modules (like RL components, NN, and environment). This improves code readability and enables better static analysis. Examples include:
+ - StateType: A TypedDict defining the structure of the state representation passed to the NN and stored in the buffer. Contains grid (occupancy features) and other_features (numerical features).
+ - ActionType: An alias for int, representing encoded actions.
+ - PolicyTargetMapping: A mapping from ActionType to float, representing the policy target from MCTS.
+ - Experience: A tuple representing (StateType, PolicyTargetMapping, float) stored in the replay buffer (the float is the n-step return).
+ - ExperienceBatch: A list of Experience tuples.
+ - PolicyValueOutput: A tuple representing (PolicyTargetMapping, float) returned by the NN's evaluate method (the float is the expected value).
+ - PERBatchSample: A TypedDict defining the output of the PER buffer's sample method, including the batch, indices, and importance sampling weights.
+- **Geometry Utilities ([geometry.py](geometry.py)):** Contains geometric helper functions.
+ - is_point_in_polygon: Checks if a 2D point lies inside a given polygon.
+- **Data Structures ([sumtree.py](sumtree.py)):**
+ - SumTree: A simple SumTree implementation used for Prioritized Experience Replay.
+
+## Exposed Interfaces
+
+- **Functions:**
+ - get_device(device_preference: str = "auto") -> torch.device
+ - set_random_seeds(seed: int = 42)
+ - format_eta(seconds: Optional[float]) -> str
+ - normalize_color_for_matplotlib(color_tuple_0_255: tuple[int, int, int]) -> tuple[float, float, float]
+ - is_point_in_polygon(point: Tuple[float, float], polygon: List[Tuple[float, float]]) -> bool
+- **Classes:**
+ - SumTree: For PER.
+- **Types:**
+ - StateType (TypedDict)
+ - ActionType (TypeAlias for int)
+ - PolicyTargetMapping (TypeAlias for Mapping[ActionType, float])
+ - Experience (TypeAlias for Tuple[StateType, PolicyTargetMapping, float])
+ - ExperienceBatch (TypeAlias for List[Experience])
+ - PolicyValueOutput (TypeAlias for Tuple[Mapping[ActionType, float], float])
+ - PERBatchSample (TypedDict)
+
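+A short sketch (editor's illustration) exercising a few of the helpers listed above; expected outputs are shown as comments.
+
+```python
+from alphatriangle.utils import format_eta, get_device, set_random_seeds
+
+set_random_seeds(123)
+device = get_device("auto")  # picks CUDA, then MPS, then CPU
+print(device.type)
+print(format_eta(3725))      # 1h 02m 05s
+print(format_eta(None))      # N/A
+```
+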
+## Dependencies
+
+- **torch**:
+ - Used by get_device and set_random_seeds.
+- **numpy**:
+ - Used by set_random_seeds and potentially in type definitions (np.ndarray).
+- **Standard Libraries:** typing, random, os, math, logging, collections.deque.
+- **typing_extensions**: For TypedDict.
+
+---
+
+**Note:** Please keep this README updated when adding or modifying utility functions or type definitions, especially those used as interfaces between different modules. Accurate documentation is crucial for maintainability.
+
+
+File: alphatriangle/utils/geometry.py
+def is_point_in_polygon(
+ point: tuple[float, float], polygon: list[tuple[float, float]]
+) -> bool:
+ """
+ Checks if a point is inside a polygon using the ray casting algorithm.
+
+ Args:
+ point: Tuple (x, y) representing the point coordinates.
+ polygon: List of tuples [(x1, y1), (x2, y2), ...] representing polygon vertices in order.
+
+ Returns:
+ True if the point is inside the polygon, False otherwise.
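+
+    Example:
+        >>> is_point_in_polygon((0.5, 0.5), [(0, 0), (1, 0), (1, 1), (0, 1)])
+        True
+        >>> is_point_in_polygon((1.5, 0.5), [(0, 0), (1, 0), (1, 1), (0, 1)])
+        False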
+ """
+ x, y = point
+ n = len(polygon)
+ inside = False
+
+ p1x, p1y = polygon[0]
+    # Walk the n edges (polygon[i-1] -> polygon[i % n]), closing back to polygon[0].
+    for i in range(1, n + 1):
+        p2x, p2y = polygon[i % n]
+ # Combine nested if statements
+ if y > min(p1y, p2y) and y <= max(p1y, p2y) and x <= max(p1x, p2x):
+ # Use ternary operator for xinters calculation
+ xinters = ((y - p1y) * (p2x - p1x) / (p2y - p1y) + p1x) if p1y != p2y else x
+
+ # Check if point is on the segment boundary or crosses the ray
+ if abs(p1x - p2x) < 1e-9: # Vertical line segment
+ if abs(x - p1x) < 1e-9:
+ return True # Point is on the vertical segment
+ elif abs(x - xinters) < 1e-9: # Point is exactly on the intersection
+ return True # Point is on the boundary
+ elif (
+ p1x == p2x or x <= xinters
+ ): # Point is to the left or on a non-horizontal segment
+ inside = not inside
+
+ p1x, p1y = p2x, p2y
+
+ # Check if the point is exactly one of the vertices
+ for px, py in polygon:
+ if abs(x - px) < 1e-9 and abs(y - py) < 1e-9:
+ return True
+
+ return inside
+
+
+File: alphatriangle/utils/sumtree.py
+import numpy as np
+
+from .types import Experience
+
+
+class SumTree:
+ """
+ Simple SumTree implementation for efficient prioritized sampling.
+ Stores priorities and allows sampling proportional to priority.
+ """
+
+ def __init__(self, capacity: int):
+ self.capacity = capacity
+
+ # Tree structure: Stores priorities. Size is 2*capacity - 1.
+ # Leaves are indices capacity-1 to 2*capacity-2.
+ self.tree = np.zeros(2 * capacity - 1)
+
+ # Data storage: Stores the actual experiences. Size is capacity.
+ self.data = np.zeros(capacity, dtype=object)
+ self.data_pointer = 0 # Points to the next available data slot
+ self.n_entries = 0 # Current number of entries in the buffer
+ self._max_priority = 1.0 # Track max priority for new entries
+
+ def add(self, priority: float, data: Experience):
+ """Adds an experience with a given priority."""
+ # Calculate the tree index for the leaf corresponding to the data slot
+ tree_idx = self.data_pointer + self.capacity - 1
+
+ # Store the data
+ self.data[self.data_pointer] = data
+
+ # Update the tree with the new priority
+ self.update(tree_idx, priority)
+
+ # Move data pointer
+ self.data_pointer += 1
+ if self.data_pointer >= self.capacity:
+ self.data_pointer = 0 # Wrap around
+
+ # Update entry count
+ if self.n_entries < self.capacity:
+ self.n_entries += 1
+
+ # Update max priority seen
+ self._max_priority = max(self._max_priority, priority)
+
+ def update(self, tree_idx: int, priority: float):
+ """Updates the priority of an experience at a given tree index."""
+ # Calculate the change in priority
+ change = priority - self.tree[tree_idx]
+
+ # Update the leaf node
+ self.tree[tree_idx] = priority
+
+ # Propagate the change up the tree
+ while tree_idx != 0:
+ tree_idx = (tree_idx - 1) // 2 # Move to parent index
+ self.tree[tree_idx] += change
+
+ def get_leaf(self, value: float) -> tuple[int, float, Experience]:
+ """Finds the leaf node corresponding to a given value (for sampling)."""
+ parent_idx = 0
+ while True:
+ left_child_idx = 2 * parent_idx + 1
+ right_child_idx = left_child_idx + 1
+
+ # If left child index is out of bounds, we've reached a leaf node
+ if left_child_idx >= len(self.tree):
+ leaf_idx = parent_idx
+ break
+ else:
+ # If the value is less than or equal to the left child's priority sum,
+ # go down the left branch.
+ if value <= self.tree[left_child_idx]:
+ parent_idx = left_child_idx
+ # Otherwise, subtract the left child's sum and go down the right branch.
+ else:
+ value -= self.tree[left_child_idx]
+ parent_idx = right_child_idx
+
+ # Calculate the corresponding data index in the self.data array
+ data_idx = leaf_idx - self.capacity + 1
+
+ # Return the tree index, the priority at that leaf, and the data
+ return leaf_idx, self.tree[leaf_idx], self.data[data_idx]
+
+ @property
+ def total_priority(self) -> float:
+ """Returns the total priority (root node value)."""
+ # Ensure return type is float
+ return float(self.tree[0])
+
+ @property
+ def max_priority(self) -> float:
+ """Returns the maximum priority seen so far."""
+ # Return 1.0 if buffer is empty to avoid issues with initial adds
+ return self._max_priority if self.n_entries > 0 else 1.0
+
+ def __len__(self) -> int:
+ return self.n_entries
+
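+
+# --- Illustrative usage (editor's sketch, not part of the library API) ---
+# Drawing uniform values in [0, total_priority) selects leaves with frequency
+# proportional to their priority, which is the property PER relies on.
+if __name__ == "__main__":
+    import random
+
+    demo_tree = SumTree(capacity=4)
+    for prio in (1.0, 2.0, 3.0, 4.0):
+        # A plain tuple stands in for a real Experience here.
+        demo_tree.add(prio, ("dummy", prio))
+    counts: dict[float, int] = {}
+    for _ in range(10_000):
+        sample_value = random.uniform(0.0, demo_tree.total_priority)
+        _, leaf_priority, _data = demo_tree.get_leaf(sample_value)
+        counts[float(leaf_priority)] = counts.get(float(leaf_priority), 0) + 1
+    print(dict(sorted(counts.items())))  # roughly proportional to 1:2:3:4
+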
+
+File: alphatriangle/utils/helpers.py
+# File: alphatriangle/utils/helpers.py
+import logging
+import random
+from typing import cast
+
+import numpy as np
+import torch
+
+logger = logging.getLogger(__name__)
+
+
+def get_device(device_preference: str = "auto") -> torch.device:
+ """
+ Gets the appropriate torch device based on preference and availability.
+ Prioritizes MPS on Mac if 'auto' is selected.
+ """
+ if device_preference == "cuda" and torch.cuda.is_available():
+ logger.info("Using CUDA device.")
+ return torch.device("cuda")
+ # --- CHANGED: Prioritize MPS if available and preferred/auto ---
+ if (
+ device_preference in ["auto", "mps"]
+ and hasattr(torch.backends, "mps")
+ and torch.backends.mps.is_available()
+ and torch.backends.mps.is_built() # Check if MPS is built
+ ):
+ logger.info(f"Using MPS device (Preference: {device_preference}).")
+ return torch.device("mps")
+ # --- END CHANGED ---
+ if device_preference == "cpu":
+ logger.info("Using CPU device.")
+ return torch.device("cpu")
+
+ # Auto-detection fallback (after MPS check)
+ if torch.cuda.is_available():
+ logger.info("Auto-selected CUDA device.")
+ return torch.device("cuda")
+ # Check MPS again in fallback (should have been caught above if available)
+ if (
+ hasattr(torch.backends, "mps")
+ and torch.backends.mps.is_available()
+ and torch.backends.mps.is_built()
+ ):
+ logger.info("Auto-selected MPS device.")
+ return torch.device("mps")
+
+ logger.info("Auto-selected CPU device.")
+ return torch.device("cpu")
+
+
+def set_random_seeds(seed: int = 42):
+ """Sets random seeds for Python, NumPy, and PyTorch."""
+ random.seed(seed)
+ # Use NumPy's recommended way to seed the global RNG state if needed,
+ # or preferably use a Generator instance. For simplicity here, we seed global.
+ np.random.seed(seed) # noqa: NPY002
+ torch.manual_seed(seed)
+ if torch.cuda.is_available():
+ torch.cuda.manual_seed(seed)
+ torch.cuda.manual_seed_all(seed)
+ # Optional: Set deterministic algorithms for CuDNN
+ # torch.backends.cudnn.deterministic = True
+ # torch.backends.cudnn.benchmark = False
+ # --- ADDED: Seed MPS if available ---
+ if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
+ try:
+ # Use torch.mps.manual_seed if available (newer PyTorch versions)
+ if hasattr(torch, "mps") and hasattr(torch.mps, "manual_seed"):
+ torch.mps.manual_seed(seed) # type: ignore
+ else:
+ # Fallback for older versions if needed, though less common
+ # torch.manual_seed(seed) might cover MPS indirectly in some versions
+ pass
+ except Exception as e:
+ logger.warning(f"Could not set MPS seed: {e}")
+ # --- END ADDED ---
+ logger.info(f"Set random seeds to {seed}")
+
+
+def format_eta(seconds: float | None) -> str:
+ """Formats seconds into a human-readable HH:MM:SS or MM:SS string."""
+ if seconds is None or not np.isfinite(seconds) or seconds < 0:
+ return "N/A"
+ if seconds > 3600 * 24 * 30:
+ return ">1 month"
+
+ seconds = int(seconds)
+ h, rem = divmod(seconds, 3600)
+ m, s = divmod(rem, 60)
+
+ if h > 0:
+ return f"{h}h {m:02d}m {s:02d}s"
+ if m > 0:
+ return f"{m}m {s:02d}s"
+ return f"{s}s"
+
+
+def normalize_color_for_matplotlib(
+ color_tuple_0_255: tuple[int, int, int],
+) -> tuple[float, float, float]:
+ """Converts RGB tuple (0-255) to Matplotlib format (0.0-1.0)."""
+ if isinstance(color_tuple_0_255, tuple) and len(color_tuple_0_255) == 3:
+ valid_color = tuple(max(0, min(255, c)) for c in color_tuple_0_255)
+ return cast("tuple[float, float, float]", tuple(c / 255.0 for c in valid_color))
+ logger.warning(
+ f"Invalid color format for normalization: {color_tuple_0_255}, returning black."
+ )
+ return (0.0, 0.0, 0.0)
+
+
+File: alphatriangle/rl/__init__.py
+"""
+Reinforcement Learning (RL) module.
+Contains the core components for training an agent using self-play and MCTS.
+"""
+
+from .core.buffer import ExperienceBuffer
+from .core.trainer import Trainer
+from .self_play.worker import SelfPlayWorker
+from .types import SelfPlayResult
+
+__all__ = [
+ # Core components used by the training pipeline
+ "Trainer",
+ "ExperienceBuffer",
+ # Self-play components
+ "SelfPlayWorker",
+ "SelfPlayResult",
+]
+
+
+File: alphatriangle/rl/types.py
+import logging
+from typing import Any
+
+import numpy as np
+from pydantic import BaseModel, ConfigDict, Field, model_validator
+
+from ..utils.types import Experience, StateType
+
+logger = logging.getLogger(__name__)
+
+arbitrary_types_config = ConfigDict(arbitrary_types_allowed=True)
+
+
+class SelfPlayResult(BaseModel):
+ """Pydantic model for structuring results from a self-play worker."""
+
+ model_config = arbitrary_types_config
+
+ episode_experiences: list[Experience]
+ final_score: float
+ episode_steps: int
+ trainer_step_at_episode_start: int
+
+ total_simulations: int = Field(..., ge=0)
+ avg_root_visits: float = Field(..., ge=0)
+ avg_tree_depth: float = Field(..., ge=0)
+ context: dict[str, Any] = Field(
+ default_factory=dict,
+ description="Additional context from the episode (e.g., triangles_cleared).",
+ )
+
+ @model_validator(mode="after")
+ def check_experience_structure(self) -> "SelfPlayResult":
+ """Basic structural validation for experiences, including StateType."""
+ invalid_count = 0
+ valid_experiences = []
+ for i, exp in enumerate(self.episode_experiences):
+ is_valid = False
+ reason = "Unknown structure"
+ try:
+ if isinstance(exp, tuple) and len(exp) == 3:
+ state_type: StateType = exp[0]
+ policy_map = exp[1]
+ value = exp[2]
+ if (
+ isinstance(state_type, dict)
+ and "grid" in state_type
+ and "other_features" in state_type
+ and isinstance(state_type["grid"], np.ndarray)
+ and isinstance(state_type["other_features"], np.ndarray)
+ and isinstance(policy_map, dict)
+ and isinstance(value, float | int)
+ ):
+ if np.all(np.isfinite(state_type["grid"])) and np.all(
+ np.isfinite(state_type["other_features"])
+ ):
+ is_valid = True
+ else:
+ reason = "Non-finite features"
+ else:
+ reason = f"Incorrect types or missing keys: state_keys={list(state_type.keys()) if isinstance(state_type, dict) else type(state_type)}, policy={type(policy_map)}, value={type(value)}"
+ else:
+ reason = f"Not a tuple of length 3: type={type(exp)}, len={len(exp) if isinstance(exp, tuple) else 'N/A'}"
+ except Exception as e:
+ reason = f"Validation exception: {e}"
+ logger.error(
+ f"SelfPlayResult validation: Exception validating experience {i}: {e}",
+ exc_info=True,
+ )
+
+ if is_valid:
+ valid_experiences.append(exp)
+ else:
+ invalid_count += 1
+ logger.warning(
+ f"SelfPlayResult validation: Invalid experience structure at index {i}. Reason: {reason}. Data: {exp}"
+ )
+
+ if invalid_count > 0:
+ logger.warning(
+ f"SelfPlayResult validation: Found {invalid_count} invalid experience structures out of {len(self.episode_experiences)}. Keeping only valid ones."
+ )
+ # Use object.__setattr__ to modify the field within the validator
+ object.__setattr__(self, "episode_experiences", valid_experiences)
+
+ return self
+
+
+SelfPlayResult.model_rebuild(force=True)
+
+
+File: alphatriangle/rl/README.md
+# Reinforcement Learning Module (alphatriangle.rl)
+
+## Purpose and Architecture
+
+This module contains core components related to the reinforcement learning algorithm itself, specifically the Trainer for network updates, the ExperienceBuffer for storing data, and the SelfPlayWorker actor for generating data using the trimcts library. **The overall orchestration of the training process resides in the [alphatriangle.training](../training/README.md) module, and statistics/persistence are handled by the trieye library.**
+
+- **Core Components ([core/README.md](core/README.md)):**
+ - Trainer: Responsible for performing the neural network update steps. It takes batches of experience from the buffer, calculates losses (policy cross-entropy, **distributional value cross-entropy**, optional entropy bonus), applies importance sampling weights if using PER, updates the network weights, and calculates TD errors for PER priority updates. **It returns raw loss components (e.g., total_loss, policy_loss, value_loss, entropy, mean_td_error) and TD errors. Sending loss metrics as RawMetricEvents to the TrieyeActor is handled by the TrainingLoop.** Uses trianglengin.EnvConfig.
+ - ExperienceBuffer: A replay buffer storing Experience tuples ((StateType, policy_target, n_step_return)). The StateType contains processed numerical features (grid, other_features). Supports both uniform sampling and Prioritized Experience Replay (PER). **Its content is saved/loaded via the TrieyeActor.**
+- **Self-Play Components ([self_play/README.md](self_play/README.md)):**
+ - worker: Defines the SelfPlayWorker Ray actor. Each actor runs game episodes independently using trimcts.run_mcts and its local copy of the neural network (NeuralNetwork which conforms to trimcts.AlphaZeroNetworkInterface). It collects experiences (including calculated n-step returns) and returns results via a SelfPlayResult object. **It sends raw metric events (like step rewards, MCTS simulations, episode completion details including triangles cleared) asynchronously to the TrieyeActor (using its name to get the handle), tagged with the current_trainer_step (global step of its network weights).** Uses trianglengin.GameState and trianglengin.EnvConfig.
+- **Types ([types.py](types.py)):**
+ - Defines Pydantic models like SelfPlayResult for structured data transfer between Ray actors and the training loop. **SelfPlayResult now includes a context field to pass structured data like triangles_cleared.**
+
+## Exposed Interfaces
+
+- **Core:**
+ - Trainer:
+ - __init__(nn_interface: NeuralNetwork, train_config: TrainConfig, env_config: EnvConfig)
+ - train_step(per_sample: PERBatchSample) -> Optional[Tuple[Dict[str, float], np.ndarray]]: Takes PER sample, returns raw loss info and TD errors.
+ - load_optimizer_state(state_dict: dict)
+ - get_current_lr() -> float
+ - ExperienceBuffer:
+ - __init__(config: TrainConfig)
+ - add(experience: Experience)
+ - add_batch(experiences: List[Experience])
+ - sample(batch_size: int, current_train_step: Optional[int] = None) -> Optional[PERBatchSample]: Samples batch, requires step for PER beta.
+ - update_priorities(tree_indices: np.ndarray, td_errors: np.ndarray): Updates priorities for PER.
+ - is_ready() -> bool
+ - __len__() -> int
+- **Self-Play:**
+ - SelfPlayWorker: Ray actor class.
+ - run_episode() -> SelfPlayResult
+ - set_weights(weights: Dict)
+ - set_current_trainer_step(global_step: int)
+- **Types:**
+ - SelfPlayResult: Pydantic model for self-play results.
+
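+A small constructible example (editor's illustration; all values are made up) of the SelfPlayResult payload returned by a worker:
+
+```python
+from alphatriangle.rl import SelfPlayResult
+
+result = SelfPlayResult(
+    episode_experiences=[],  # validator drops malformed experience tuples
+    final_score=12.5,
+    episode_steps=42,
+    trainer_step_at_episode_start=1000,
+    total_simulations=42 * 128,
+    avg_root_visits=128.0,
+    avg_tree_depth=6.3,
+    context={"triangles_cleared": 7},
+)
+print(result.total_simulations)  # 5376
+```
+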
+## Dependencies
+
+- **[alphatriangle.config](../config/README.md)**: TrainConfig, ModelConfig, AlphaTriangleMCTSConfig.
+- **trianglengin**: GameState, EnvConfig.
+- **trimcts**: run_mcts, SearchConfiguration, AlphaZeroNetworkInterface.
+- **trieye**: TrieyeActor, RawMetricEvent.
+- **[alphatriangle.nn](../nn/README.md)**: NeuralNetwork.
+- **[alphatriangle.features](../features/README.md)**: extract_state_features.
+- **[alphatriangle.utils](../utils/README.md)**: Types (Experience, StateType, PERBatchSample) and helpers (SumTree).
+- **torch**: Used by Trainer and NeuralNetwork.
+- **ray**: Used by SelfPlayWorker.
+- **Standard Libraries:** typing, logging, collections.deque, numpy, random, time.
+
+---
+
+**Note:** Please keep this README updated when changing the responsibilities of the Trainer, Buffer, or SelfPlayWorker, especially regarding how statistics are generated and reported via the TrieyeActor.
+
+
+File: alphatriangle/rl/core/__init__.py
+"""
+Core RL components: Trainer, Buffer.
+The Orchestrator logic has been moved to the alphatriangle.training module.
+"""
+
+from .buffer import ExperienceBuffer
+from .trainer import Trainer
+
+__all__ = [
+ "Trainer",
+ "ExperienceBuffer",
+]
+
+
+File: alphatriangle/rl/core/README.md
+# RL Core Submodule (alphatriangle.rl.core)
+
+## Purpose and Architecture
+
+This submodule contains core classes directly involved in the reinforcement learning update process and data storage. **The orchestration logic resides in the [alphatriangle.training](../../training/README.md) module, and persistence/statistics are handled by the trieye library.**
+
+- **[Trainer](trainer.py):** This class encapsulates the logic for updating the neural network's weights.
+ - It holds the main NeuralNetwork interface, optimizer, and scheduler.
+ - Its train_step method takes a batch of experiences (potentially with PER indices and weights), performs forward/backward passes, calculates losses (policy cross-entropy, **distributional value cross-entropy**, optional entropy bonus), applies importance sampling weights if using PER, updates weights, and returns **raw loss components** and calculated TD errors for PER priority updates. **Logging of these losses is handled externally by the TrainingLoop sending events to the TrieyeActor.** Uses trianglengin.EnvConfig.
+- **[ExperienceBuffer](buffer.py):** This class implements a replay buffer storing Experience tuples ((StateType, policy_target, n_step_return)). It supports Prioritized Experience Replay (PER) via a SumTree, including prioritized sampling and priority updates, based on configuration. **Its content is saved and loaded via the TrieyeActor.**
+
+## Exposed Interfaces
+
+- **Classes:**
+ - Trainer:
+ - __init__(nn_interface: NeuralNetwork, train_config: TrainConfig, env_config: EnvConfig)
+ - train_step(per_sample: PERBatchSample) -> Optional[Tuple[Dict[str, float], np.ndarray]]: Takes PER sample, returns raw loss info and TD errors.
+ - load_optimizer_state(state_dict: dict)
+ - get_current_lr() -> float
+  - ExperienceBuffer (see the usage sketch below):
+ - __init__(config: TrainConfig)
+ - add(experience: Experience)
+ - add_batch(experiences: List[Experience])
+ - sample(batch_size: int, current_train_step: Optional[int] = None) -> Optional[PERBatchSample]: Samples batch, requires step for PER beta.
+ - update_priorities(tree_indices: np.ndarray, td_errors: np.ndarray): Updates priorities for PER.
+ - is_ready() -> bool
+ - __len__() -> int
+
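+A minimal usage sketch of the buffer (editor's illustration; it assumes TrainConfig is default-constructible and uses illustrative feature shapes):
+
+```python
+import numpy as np
+
+from alphatriangle.config import TrainConfig
+from alphatriangle.rl import ExperienceBuffer
+
+config = TrainConfig()  # assumed default-constructible
+buffer = ExperienceBuffer(config)
+
+state = {
+    "grid": np.zeros((1, 8, 15), dtype=np.float32),    # illustrative shape
+    "other_features": np.zeros(64, dtype=np.float32),  # illustrative dim
+}
+buffer.add((state, {0: 1.0}, 0.5))  # (StateType, policy target, n-step return)
+
+if buffer.is_ready():
+    sample = buffer.sample(batch_size=32, current_train_step=0)
+    if sample is not None:
+        print(len(sample["batch"]), float(sample["weights"].mean()))
+```
+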
+## Dependencies
+
+- **[alphatriangle.config](../../config/README.md)**: TrainConfig, ModelConfig.
+- **trianglengin**: EnvConfig.
+- **[alphatriangle.nn](../../nn/README.md)**: NeuralNetwork.
+- **[alphatriangle.utils](../../utils/README.md)**: Types (Experience, PERBatchSample, StateType, etc.) and helpers (SumTree).
+- **torch**: Used heavily by Trainer.
+- **Standard Libraries:** typing, logging, collections.deque, numpy, random.
+
+---
+
+**Note:** Please keep this README updated when changing the responsibilities or interfaces of the Trainer or Buffer.
+
+
+File: alphatriangle/rl/core/buffer.py
+import logging
+import random
+from collections import deque
+
+import numpy as np
+
+from ...config import TrainConfig
+from ...utils.sumtree import SumTree
+from ...utils.types import (
+ Experience,
+ ExperienceBatch,
+ PERBatchSample,
+)
+
+logger = logging.getLogger(__name__)
+
+
+class ExperienceBuffer:
+ """
+ Experience Replay Buffer storing (StateType, PolicyTarget, Value).
+ Supports both uniform sampling and Prioritized Experience Replay (PER)
+ based on TrainConfig.
+ """
+
+ def __init__(self, config: TrainConfig):
+ self.config = config
+ self.capacity = config.BUFFER_CAPACITY
+ self.min_size_to_train = config.MIN_BUFFER_SIZE_TO_TRAIN
+ self.use_per = config.USE_PER
+
+ if self.use_per:
+ self.tree = SumTree(self.capacity)
+ self.per_alpha = config.PER_ALPHA
+ self.per_beta_initial = config.PER_BETA_INITIAL
+ self.per_beta_final = config.PER_BETA_FINAL
+ # Ensure anneal steps is at least 1 to avoid division by zero
+ self.per_beta_anneal_steps = max(
+ 1, config.PER_BETA_ANNEAL_STEPS or config.MAX_TRAINING_STEPS or 1
+ )
+ self.per_epsilon = config.PER_EPSILON
+ logger.info(
+ f"Experience buffer initialized with PER (alpha={self.per_alpha}, beta_init={self.per_beta_initial}). Capacity: {self.capacity}"
+ )
+ else:
+ self.buffer: deque[Experience] = deque(maxlen=self.capacity)
+ logger.info(
+ f"Experience buffer initialized with uniform sampling. Capacity: {self.capacity}"
+ )
+
+ def _get_priority(self, error: float) -> float:
+ """Calculates priority from TD error."""
+ # Ensure return type is float
+ return float((np.abs(error) + self.per_epsilon) ** self.per_alpha)
+
+ def add(self, experience: Experience):
+ """Adds a single experience. Uses max priority if PER is enabled."""
+ if self.use_per:
+ max_p = self.tree.max_priority
+ self.tree.add(max_p, experience)
+ else:
+ self.buffer.append(experience)
+
+ def add_batch(self, experiences: list[Experience]):
+ """Adds a batch of experiences. Uses max priority if PER is enabled."""
+ if self.use_per:
+ max_p = self.tree.max_priority
+ for exp in experiences:
+ self.tree.add(max_p, exp)
+ else:
+ self.buffer.extend(experiences)
+
+ def _calculate_beta(self, current_step: int) -> float:
+ """Linearly anneals beta from initial to final value."""
+ fraction = min(1.0, current_step / self.per_beta_anneal_steps)
+ beta = self.per_beta_initial + fraction * (
+ self.per_beta_final - self.per_beta_initial
+ )
+ return beta
+
+ def sample(
+ self, batch_size: int, current_train_step: int | None = None
+ ) -> PERBatchSample | None:
+ """
+ Samples a batch of experiences.
+ Uses prioritized sampling if PER is enabled, otherwise uniform.
+ Requires current_train_step if PER is enabled to calculate beta.
+ """
+ current_size = len(self)
+ if current_size < batch_size or current_size < self.min_size_to_train:
+ return None
+
+ if self.use_per:
+ if current_train_step is None:
+ raise ValueError("current_train_step is required for PER sampling.")
+
+ batch: ExperienceBatch = []
+ idxs = np.empty((batch_size,), dtype=np.int32)
+ is_weights = np.empty((batch_size,), dtype=np.float32)
+ beta = self._calculate_beta(current_train_step)
+
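+            # Stratified sampling: split the total priority mass into
+            # batch_size equal segments and draw one uniform value from each,
+            # so the batch covers the whole priority range.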
+ priority_segment = self.tree.total_priority / batch_size
+ max_weight = 0.0
+
+ for i in range(batch_size):
+ a = priority_segment * i
+ b = priority_segment * (i + 1)
+ value = random.uniform(a, b)
+ idx, p, data = self.tree.get_leaf(value)
+
+ if not isinstance(data, tuple):
+ logger.warning(
+ f"PER sampling encountered non-experience data at index {idx}. Resampling."
+ )
+ # Resample with a random value across the entire range
+ value = random.uniform(0, self.tree.total_priority)
+ idx, p, data = self.tree.get_leaf(value)
+ if not isinstance(data, tuple):
+ logger.error(f"PER resampling failed. Skipping sample {i}.")
+                        # Fallback: read a random valid leaf directly. (Passing
+                        # a leaf's priority value to get_leaf would walk the
+                        # cumulative sums and land on an unrelated leaf.)
+                        if self.tree.n_entries > 0:
+                            rand_data_idx = random.randint(0, self.tree.n_entries - 1)
+                            idx = rand_data_idx + self.capacity - 1
+                            p = float(self.tree.tree[idx])
+                            data = self.tree.data[rand_data_idx]
+ if not isinstance(data, tuple):
+ continue # Give up on this sample if fallback fails
+ else:
+ continue # Cannot sample if tree is empty
+
+ sampling_prob = p / self.tree.total_priority
+ weight = (
+ (current_size * sampling_prob) ** (-beta)
+ if sampling_prob > 1e-9
+ else 0.0
+ )
+ is_weights[i] = weight
+ max_weight = max(max_weight, weight)
+ idxs[i] = idx
+ batch.append(data)
+
+ if max_weight > 1e-9:
+ is_weights /= max_weight
+ else:
+ logger.warning(
+ "Max importance sampling weight is near zero. Weights might be invalid."
+ )
+ is_weights.fill(1.0)
+
+ return {"batch": batch, "indices": idxs, "weights": is_weights}
+
+ else:
+ uniform_batch = random.sample(self.buffer, batch_size)
+ dummy_indices = np.zeros(batch_size, dtype=np.int32)
+ uniform_weights = np.ones(batch_size, dtype=np.float32)
+ return {
+ "batch": uniform_batch,
+ "indices": dummy_indices,
+ "weights": uniform_weights,
+ }
+
+ def update_priorities(self, tree_indices: np.ndarray, td_errors: np.ndarray):
+ """Updates the priorities of sampled experiences based on TD errors."""
+ if not self.use_per:
+ return
+
+ if len(tree_indices) != len(td_errors):
+ logger.error(
+ f"Mismatch between tree_indices ({len(tree_indices)}) and td_errors ({len(td_errors)}) lengths."
+ )
+ return
+
+ # Calculate priorities for each error
+ priorities = np.array([self._get_priority(err) for err in td_errors])
+
+ if not np.all(np.isfinite(priorities)):
+ logger.warning("Non-finite priorities calculated. Clamping.")
+ priorities = np.nan_to_num(
+ priorities,
+ nan=self.per_epsilon,
+ posinf=self.tree.max_priority,
+ neginf=self.per_epsilon,
+ )
+ priorities = np.maximum(priorities, self.per_epsilon)
+
+ # Use strict=False for zip, although lengths should match after check above
+ for idx, p in zip(tree_indices, priorities, strict=False):
+ if not (0 <= idx < len(self.tree.tree)):
+ logger.error(f"Invalid tree index {idx} provided for priority update.")
+ continue
+ self.tree.update(idx, p)
+
+ # Update the overall max priority tracked by the tree
+ if len(priorities) > 0:
+ self.tree._max_priority = max(self.tree.max_priority, np.max(priorities))
+
+ def __len__(self) -> int:
+ """Returns the current number of experiences in the buffer."""
+ return self.tree.n_entries if self.use_per else len(self.buffer)
+
+ def is_ready(self) -> bool:
+ """Checks if the buffer has enough samples to start training."""
+ return len(self) >= self.min_size_to_train
+
+
+File: alphatriangle/rl/core/trainer.py
+import logging
+from typing import cast
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+import torch.optim as optim
+from torch.optim.lr_scheduler import _LRScheduler
+
+# Import EnvConfig from trianglengin's top level
+from trianglengin import EnvConfig
+
+# Keep alphatriangle imports
+from ...config import TrainConfig
+from ...nn import NeuralNetwork
+from ...utils.types import ExperienceBatch, PERBatchSample
+
+logger = logging.getLogger(__name__)
+
+
+class Trainer:
+ """
+ Handles the neural network training process, including loss calculation
+ and optimizer steps. Supports Distributional RL (C51) value loss.
+ """
+
+ def __init__(
+ self,
+ nn_interface: NeuralNetwork,
+ train_config: TrainConfig,
+ env_config: EnvConfig,
+ ):
+ self.nn = nn_interface
+ self.model = nn_interface.model
+ self.train_config = train_config
+ self.env_config = env_config
+ self.model_config = nn_interface.model_config
+ self.device = nn_interface.device
+ self.optimizer = self._create_optimizer()
+ self.scheduler: _LRScheduler | None = self._create_scheduler(self.optimizer)
+
+ self.num_atoms = self.nn.num_atoms
+ self.v_min = self.nn.v_min
+ self.v_max = self.nn.v_max
+ self.delta_z = self.nn.delta_z
+ self.support = self.nn.support.to(self.device)
+
+ def _create_optimizer(self) -> optim.Optimizer:
+ """Creates the optimizer based on TrainConfig."""
+ lr = self.train_config.LEARNING_RATE
+ wd = self.train_config.WEIGHT_DECAY
+ params = self.model.parameters()
+ opt_type = self.train_config.OPTIMIZER_TYPE.lower()
+ logger.info(f"Creating optimizer: {opt_type}, LR: {lr}, WD: {wd}")
+ if opt_type == "adam":
+ return optim.Adam(params, lr=lr, weight_decay=wd)
+ elif opt_type == "adamw":
+ return optim.AdamW(params, lr=lr, weight_decay=wd)
+ elif opt_type == "sgd":
+ return optim.SGD(params, lr=lr, weight_decay=wd, momentum=0.9)
+ else:
+ raise ValueError(
+ f"Unsupported optimizer type: {self.train_config.OPTIMIZER_TYPE}"
+ )
+
+ def _create_scheduler(self, optimizer: optim.Optimizer) -> _LRScheduler | None:
+ """Creates the learning rate scheduler based on TrainConfig."""
+ scheduler_type_config = self.train_config.LR_SCHEDULER_TYPE
+ scheduler_type: str | None = None
+ if scheduler_type_config:
+ scheduler_type = scheduler_type_config.lower()
+
+ if not scheduler_type or scheduler_type == "none":
+ logger.info("No LR scheduler configured.")
+ return None
+
+ logger.info(f"Creating LR scheduler: {scheduler_type}")
+ if scheduler_type == "steplr":
+ step_size = getattr(self.train_config, "LR_SCHEDULER_STEP_SIZE", 100000)
+ gamma = getattr(self.train_config, "LR_SCHEDULER_GAMMA", 0.1)
+ logger.info(f" StepLR params: step_size={step_size}, gamma={gamma}")
+ return cast(
+ "_LRScheduler",
+ optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma),
+ )
+ elif scheduler_type == "cosineannealinglr":
+ t_max = self.train_config.LR_SCHEDULER_T_MAX
+ eta_min = self.train_config.LR_SCHEDULER_ETA_MIN
+ if t_max is None:
+ logger.warning(
+ "LR_SCHEDULER_T_MAX is None for CosineAnnealingLR. Scheduler might not work as expected."
+ )
+ t_max = self.train_config.MAX_TRAINING_STEPS or 1_000_000
+ logger.info(f" CosineAnnealingLR params: T_max={t_max}, eta_min={eta_min}")
+ return cast(
+ "_LRScheduler",
+ optim.lr_scheduler.CosineAnnealingLR(
+ optimizer, T_max=t_max, eta_min=eta_min
+ ),
+ )
+ else:
+ raise ValueError(f"Unsupported scheduler type: {scheduler_type_config}")
+
+ def _prepare_batch(
+ self, batch: ExperienceBatch
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+ """
+ Converts a batch of experiences into tensors.
+ The 4th tensor is now the n-step return G (scalar).
+ """
+ batch_size = len(batch)
+ grids = []
+ other_features = []
+ n_step_returns = []
+ action_dim_int = int(
+ self.env_config.NUM_SHAPE_SLOTS
+ * self.env_config.ROWS
+ * self.env_config.COLS
+ )
+ policy_target_tensor = torch.zeros(
+ (batch_size, action_dim_int),
+ dtype=torch.float32,
+ device=self.device,
+ )
+
+ for i, (state_features, policy_target_map, n_step_return) in enumerate(batch):
+ grids.append(state_features["grid"])
+ other_features.append(state_features["other_features"])
+ n_step_returns.append(n_step_return)
+ for action, prob in policy_target_map.items():
+ if 0 <= action < action_dim_int:
+ policy_target_tensor[i, action] = prob
+ else:
+ logger.warning(
+ f"Action {action} out of bounds in policy target map for sample {i}."
+ )
+
+ grid_tensor = torch.from_numpy(np.stack(grids)).to(self.device)
+ other_features_tensor = torch.from_numpy(np.stack(other_features)).to(
+ self.device
+ )
+ n_step_return_tensor = torch.tensor(
+ n_step_returns, dtype=torch.float32, device=self.device
+ )
+
+ expected_other_dim = self.model_config.OTHER_NN_INPUT_FEATURES_DIM
+ if batch_size > 0 and other_features_tensor.shape[1] != expected_other_dim:
+ raise ValueError(
+ f"Unexpected other_features tensor shape: {other_features_tensor.shape}, expected dim {expected_other_dim}"
+ )
+
+ return (
+ grid_tensor,
+ other_features_tensor,
+ policy_target_tensor,
+ n_step_return_tensor,
+ )
+
+ def _calculate_target_distribution(
+ self, n_step_returns: torch.Tensor
+ ) -> torch.Tensor:
+ """
+ Projects the n-step returns onto the fixed support atoms (z).
+ Args:
+ n_step_returns: Tensor of shape (batch_size,) containing scalar n-step returns (G).
+ Returns:
+ Tensor of shape (batch_size, num_atoms) representing the target distribution.
+ """
+ batch_size = n_step_returns.size(0)
+ m = torch.zeros(
+ (batch_size, self.num_atoms), dtype=torch.float32, device=self.device
+ )
+ target_returns = n_step_returns.clamp(self.v_min, self.v_max)
+ b = (target_returns - self.v_min) / self.delta_z
+ lower_idx = b.floor().long()
+ u = b.ceil().long()
+ # Ensure indices are within bounds [0, num_atoms - 1]
+ lower_idx = torch.clamp(lower_idx, 0, self.num_atoms - 1)
+ u = torch.clamp(u, 0, self.num_atoms - 1)
+
+ # Handle cases where lower and upper indices are the same
+ eq_mask = lower_idx == u
+ ne_mask = ~eq_mask
+
+ # Calculate weights for non-equal indices
+ m_l = u[ne_mask].float() - b[ne_mask]
+ m_u = b[ne_mask] - lower_idx[ne_mask].float()
+
+ # Distribute probability mass
+ batch_indices = torch.arange(batch_size, device=self.device)
+ m[batch_indices[ne_mask], lower_idx[ne_mask]] += m_l
+ m[batch_indices[ne_mask], u[ne_mask]] += m_u
+
+ # Handle equal indices (assign full probability mass)
+ m[batch_indices[eq_mask], lower_idx[eq_mask]] += 1.0
+
+ # Normalize rows just in case of floating point issues near boundaries
+ # This might not be strictly necessary but adds robustness
+ # row_sums = m.sum(dim=1, keepdim=True)
+ # m = m / (row_sums + 1e-9) # Add epsilon for stability
+
+ return m
+
+ def train_step(
+ self, per_sample: PERBatchSample
+ ) -> tuple[dict[str, float], np.ndarray] | None:
+ """
+ Performs a single training step on the given batch from PER buffer.
+ Uses distributional cross-entropy loss for the value head.
+ Returns raw loss info dictionary and TD errors for priority updates.
+ """
+ batch = per_sample["batch"]
+ is_weights = per_sample["weights"]
+
+ if not batch:
+ logger.warning("train_step called with empty batch.")
+ return None
+
+ self.model.train()
+ try:
+ grid_t, other_t, policy_target_t, n_step_return_t = self._prepare_batch(
+ batch
+ )
+ is_weights_t = torch.from_numpy(is_weights).to(self.device)
+ except Exception as e:
+ logger.error(f"Error preparing batch for training: {e}", exc_info=True)
+ return None
+
+ self.optimizer.zero_grad()
+ policy_logits, value_logits = self.model(grid_t, other_t)
+
+ # --- Value Loss (Distributional Cross-Entropy) ---
+ with torch.no_grad():
+ target_distribution = self._calculate_target_distribution(n_step_return_t)
+ log_pred_dist = F.log_softmax(value_logits, dim=1)
+ # Ensure target_distribution sums to 1 for valid cross-entropy calculation
+ # target_distribution = target_distribution / (target_distribution.sum(dim=1, keepdim=True) + 1e-9)
+ value_loss_elementwise = -torch.sum(target_distribution * log_pred_dist, dim=1)
+ value_loss = (value_loss_elementwise * is_weights_t).mean()
+
+ # --- Policy Loss (Cross-Entropy) ---
+ log_probs = F.log_softmax(policy_logits, dim=1)
+ policy_target_t = torch.nan_to_num(policy_target_t, nan=0.0)
+ # Ensure policy target sums to 1
+ # policy_target_t = policy_target_t / (policy_target_t.sum(dim=1, keepdim=True) + 1e-9)
+ policy_loss_elementwise = -torch.sum(policy_target_t * log_probs, dim=1)
+ policy_loss = (policy_loss_elementwise * is_weights_t).mean()
+
+ # --- Entropy Bonus ---
+ entropy_scalar: float = 0.0
+ entropy_loss_term = torch.tensor(0.0, device=self.device)
+ if self.train_config.ENTROPY_BONUS_WEIGHT > 0:
+ policy_probs = F.softmax(policy_logits, dim=1)
+ # Add epsilon for numerical stability in log
+ entropy_term_elementwise: torch.Tensor = -torch.sum(
+ policy_probs * torch.log(policy_probs + 1e-9), dim=1
+ )
+ # Calculate mean entropy across batch BEFORE applying weights
+ entropy_scalar = float(entropy_term_elementwise.mean().item())
+ # Apply entropy bonus loss (weighted mean if needed, but usually mean is fine)
+ entropy_loss_term = (
+ -self.train_config.ENTROPY_BONUS_WEIGHT
+ * entropy_term_elementwise.mean() # Use mean entropy, not weighted
+ )
+
+ # --- Total Loss ---
+ total_loss = (
+ self.train_config.POLICY_LOSS_WEIGHT * policy_loss
+ + self.train_config.VALUE_LOSS_WEIGHT * value_loss
+ + entropy_loss_term
+ )
+
+ # --- Backpropagation and Optimization ---
+ total_loss.backward()
+
+ if (
+ self.train_config.GRADIENT_CLIP_VALUE is not None
+ and self.train_config.GRADIENT_CLIP_VALUE > 0
+ ):
+ torch.nn.utils.clip_grad_norm_(
+ self.model.parameters(), self.train_config.GRADIENT_CLIP_VALUE
+ )
+
+ self.optimizer.step()
+ if self.scheduler:
+ self.scheduler.step()
+
+ # --- Calculate TD Errors for PER ---
+ with torch.no_grad():
+ # Use the element-wise value loss as the TD error proxy for distributional RL
+ # This is a common heuristic. Alternatively, calculate expected value difference.
+ # Using value_loss_elementwise directly aligns priorities with the loss being minimized.
+ td_errors = value_loss_elementwise.detach().cpu().numpy()
+ # Alternative: Calculate expected value difference
+ # expected_value_pred = self.nn._logits_to_expected_value(value_logits)
+ # td_errors = (n_step_return_t - expected_value_pred.squeeze(1)).detach().cpu().numpy()
+
+ # --- Prepare Raw Loss Info ---
+ # Return individual loss components for logging by the stats system
+ loss_info = {
+ "total_loss": total_loss.item(),
+ "policy_loss": policy_loss.item(),
+ "value_loss": value_loss.item(),
+ "entropy": entropy_scalar, # Report the mean entropy scalar
+ "mean_td_error": float(
+ np.mean(np.abs(td_errors))
+ ), # Keep reporting mean TD error
+ }
+
+ return loss_info, td_errors
+
+ def get_current_lr(self) -> float:
+ """Returns the current learning rate from the optimizer."""
+ try:
+ # Ensure optimizer and param_groups exist
+ if self.optimizer and self.optimizer.param_groups:
+ return float(self.optimizer.param_groups[0]["lr"])
+ else:
+ logger.warning("Optimizer or param_groups not available.")
+ return 0.0
+ except (IndexError, KeyError, AttributeError) as e:
+ logger.warning(f"Could not retrieve learning rate from optimizer: {e}")
+ return 0.0
+
+
+File: alphatriangle/rl/self_play/worker.py
+# File: alphatriangle/rl/self_play/worker.py
+import cProfile
+import logging
+import random
+import time
+from collections import deque
+from pathlib import Path
+from typing import TYPE_CHECKING, Any
+
+import ray
+from ray.actor import ActorHandle
+from trianglengin import EnvConfig, GameState
+from trieye.schemas import RawMetricEvent # Import from trieye
+from trimcts import SearchConfiguration, run_mcts
+
+from ...config import ModelConfig, TrainConfig
+from ...features import extract_state_features
+from ...nn import NeuralNetwork
+from ...utils import get_device, set_random_seeds
+from ..types import SelfPlayResult
+from .mcts_helpers import (
+ PolicyGenerationError,
+ get_policy_target_from_visits,
+ select_action_from_visits,
+)
+
+if TYPE_CHECKING:
+ from ...utils.types import Experience, PolicyTargetMapping, StateType
+
+ MctsTreeHandle = Any
+
+logger = logging.getLogger(__name__)
+
+
+@ray.remote
+class SelfPlayWorker:
+ """
+ A Ray actor responsible for running self-play episodes using trimcts and a NN.
+ Uses trianglengin.GameState and supports MCTS tree reuse.
+ Sends raw metric events to the TrieyeActor using its name.
+ """
+
+ def __init__(
+ self,
+ actor_id: int,
+ env_config: EnvConfig,
+ mcts_config: SearchConfiguration,
+ model_config: ModelConfig,
+ train_config: TrainConfig,
+ trieye_actor_name: str, # Changed from handle to name
+ run_base_dir: str,
+ initial_weights: dict | None = None,
+ seed: int | None = None,
+ worker_device_str: str = "cpu",
+ profile_this_worker: bool = False,
+ ):
+ self.actor_id = actor_id
+ self.env_config = env_config
+ self.mcts_config = mcts_config
+ self.model_config = model_config
+ self.train_config = train_config
+ self.trieye_actor_name = trieye_actor_name # Store actor name
+ self.run_base_dir = Path(run_base_dir)
+ self.seed = seed if seed is not None else random.randint(0, 1_000_000)
+ self.worker_device_str = worker_device_str
+ self.profile_this_worker = profile_this_worker
+
+ self.n_step = self.train_config.N_STEP_RETURNS
+ self.gamma = self.train_config.GAMMA
+ self.current_trainer_step = 0
+
+ global logger
+ logger = logging.getLogger(__name__)
+
+ set_random_seeds(self.seed)
+ self.device = get_device(self.worker_device_str)
+
+ self.nn_evaluator = NeuralNetwork(
+ model_config=self.model_config,
+ env_config=self.env_config,
+ train_config=self.train_config,
+ device=self.device,
+ )
+
+ if initial_weights:
+ self.set_weights(initial_weights)
+ else:
+ self.nn_evaluator.model.eval()
+
+ # Cache the actor handle after first lookup
+ self._trieye_actor_handle: ActorHandle | None = None
+
+ logger.debug(f"INIT: MCTS Config: {self.mcts_config.model_dump()}")
+ logger.info(
+ f"Worker {self.actor_id} initialized on device {self.device}. Seed: {self.seed}. LogLevel: {logging.getLevelName(logger.getEffectiveLevel())}. TrieyeActor Name: {self.trieye_actor_name}"
+ )
+ logger.debug(f"Worker {self.actor_id} init complete.")
+
+ self.profiler: cProfile.Profile | None = None
+ if self.profile_this_worker:
+ self.profiler = cProfile.Profile()
+ logger.warning(f"Worker {self.actor_id}: Profiling ENABLED.")
+ else:
+ logger.info(f"Worker {self.actor_id}: Profiling DISABLED.")
+
+ def _get_trieye_actor(self) -> ActorHandle | None:
+ """Gets the TrieyeActor handle, caching it after the first lookup."""
+ if self._trieye_actor_handle is None:
+ try:
+ self._trieye_actor_handle = ray.get_actor(self.trieye_actor_name)
+ logger.info(
+ f"Worker {self.actor_id}: Successfully got TrieyeActor handle."
+ )
+ except ValueError:
+ logger.error(
+ f"Worker {self.actor_id}: Could not find TrieyeActor named '{self.trieye_actor_name}'. Logging disabled."
+ )
+ self._trieye_actor_handle = None
+ except Exception as e:
+ logger.error(
+ f"Worker {self.actor_id}: Error getting TrieyeActor handle: {e}"
+ )
+ self._trieye_actor_handle = None
+ return self._trieye_actor_handle
+
+ def set_weights(self, weights: dict):
+ """Updates the neural network weights."""
+ try:
+ self.nn_evaluator.set_weights(weights)
+ logger.info(f"Worker {self.actor_id}: Weights updated.")
+ except Exception as e:
+ logger.error(
+ f"Worker {self.actor_id}: Failed to set weights: {e}", exc_info=True
+ )
+
+ def set_current_trainer_step(self, global_step: int):
+ """Sets the global step corresponding to the current network weights."""
+ self.current_trainer_step = global_step
+ logger.info(f"Worker {self.actor_id}: Trainer step set to {global_step}")
+
+ def _send_event_async(
+ self, name: str, value: float | int, context: dict | None = None
+ ):
+ """Helper to send a raw metric event to the Trieye actor asynchronously."""
+ actor_handle = self._get_trieye_actor()
+ if actor_handle:
+ event = RawMetricEvent(
+ name=name,
+ value=value,
+ global_step=self.current_trainer_step, # Use worker's current trainer step
+ timestamp=time.time(),
+ context=context or {},
+ )
+ try:
+ # Fire-and-forget remote call
+ actor_handle.log_event.remote(event)
+ except Exception as e:
+ logger.error(
+ f"Worker {self.actor_id}: Failed to send event '{name}' to Trieye actor: {e}"
+ )
+ else:
+ logger.warning(
+ f"Worker {self.actor_id}: Cannot send event '{name}', TrieyeActor handle not available."
+ )
+
+ def run_episode(self) -> SelfPlayResult:
+ """Runs a single episode of self-play with MCTS tree reuse."""
+ step_at_start = self.current_trainer_step
+ logger.info(
+ f"Worker {self.actor_id}: Starting run_episode. Seed: {self.seed}. Using network from trainer step: {step_at_start}"
+ )
+ if self.profiler:
+ self.profiler.enable()
+
+ result: SelfPlayResult | None = None
+ episode_seed = self.seed + random.randint(0, 1000)
+ game: GameState | None = None
+ mcts_tree_handle: MctsTreeHandle | None = None
+ last_action: int = -1
+ final_score = 0.0
+ final_step = 0
+ total_sims_episode = 0
+ total_triangles_cleared_episode = 0
+
+ try:
+ self.nn_evaluator.model.eval()
+ logger.debug(
+ f"Worker {self.actor_id}: Initializing GameState with seed {episode_seed}..."
+ )
+ game = GameState(self.env_config, initial_seed=episode_seed)
+ logger.debug(
+ f"Worker {self.actor_id}: GameState initialized. Initial step: {game.current_step}, Initial score: {game.game_score()}"
+ )
+
+ if game.is_over():
+ reason = game.get_game_over_reason() or "Unknown reason at start"
+ logger.error(
+ f"Worker {self.actor_id}: Game over immediately after reset (Seed: {episode_seed}). Reason: {reason}"
+ )
+ episode_end_context = {
+ "score": game.game_score(),
+ "length": 0,
+ "simulations": 0,
+ "triangles_cleared": 0,
+ "trainer_step": step_at_start,
+ }
+ # Send event even on immediate game over
+ self._send_event_async("episode_end", 1.0, context=episode_end_context)
+ return SelfPlayResult(
+ episode_experiences=[],
+ final_score=game.game_score(),
+ episode_steps=0,
+ trainer_step_at_episode_start=step_at_start,
+ total_simulations=0,
+ avg_root_visits=0.0,
+ avg_tree_depth=0.0,
+ context=episode_end_context,
+ )
+ if not game.valid_actions():
+ logger.error(
+ f"Worker {self.actor_id}: Game not over, but NO valid actions at start (Seed: {episode_seed})."
+ )
+ episode_end_context = {
+ "score": game.game_score(),
+ "length": 0,
+ "simulations": 0,
+ "triangles_cleared": 0,
+ "trainer_step": step_at_start,
+ }
+ self._send_event_async("episode_end", 1.0, context=episode_end_context)
+ return SelfPlayResult(
+ episode_experiences=[],
+ final_score=game.game_score(),
+ episode_steps=0,
+ trainer_step_at_episode_start=step_at_start,
+ total_simulations=0,
+ avg_root_visits=0.0,
+ avg_tree_depth=0.0,
+ context=episode_end_context,
+ )
+
+ n_step_state_policy_buffer: deque[tuple[StateType, PolicyTargetMapping]] = (
+ deque(maxlen=self.n_step)
+ )
+ n_step_reward_buffer: deque[float] = deque(maxlen=self.n_step)
+ episode_experiences: list[Experience] = []
+ last_total_visits = 0
+
+ logger.info(
+ f"Worker {self.actor_id}: Starting episode loop with seed {episode_seed}"
+ )
+
+ while not game.is_over():
+ current_step_in_loop = game.current_step
+ logger.debug(
+ f"Worker {self.actor_id}: --- Starting Step {current_step_in_loop} ---"
+ )
+ step_start_time = time.monotonic()
+
+ if game.is_over():
+ logger.warning(
+ f"Worker {self.actor_id}: State terminal before MCTS (Step {current_step_in_loop}). Exiting loop."
+ )
+ break
+
+ logger.debug(
+ f"Worker {self.actor_id}: Step {current_step_in_loop}: Running MCTS ({self.mcts_config.max_simulations} sims)..."
+ )
+ mcts_start_time = time.monotonic()
+ visit_counts: dict[int, int] = {}
+ new_mcts_tree_handle: MctsTreeHandle | None = None
+ try:
+ visit_counts, new_mcts_tree_handle = run_mcts(
+ root_state=game,
+ network_interface=self.nn_evaluator,
+ config=self.mcts_config,
+ previous_tree_handle=mcts_tree_handle,
+ last_action=last_action,
+ )
+ mcts_tree_handle = new_mcts_tree_handle
+ logger.debug(
+ f"Worker {self.actor_id}: Step {current_step_in_loop}: MCTS returned visit_counts: {len(visit_counts)} actions."
+ )
+ except Exception as mcts_err:
+ logger.error(
+ f"Worker {self.actor_id}: Step {current_step_in_loop}: trimcts failed: {mcts_err}",
+ exc_info=True,
+ )
+ mcts_tree_handle = None
+ break
+ mcts_duration = time.monotonic() - mcts_start_time
+ last_total_visits = sum(visit_counts.values())
+ total_sims_episode += last_total_visits
+ logger.debug(
+ f"Worker {self.actor_id}: Step {current_step_in_loop}: MCTS finished ({mcts_duration:.3f}s). Total visits: {last_total_visits}"
+ )
+
+ # Send MCTS simulation count event
+ self._send_event_async(
+ "mcts_step",
+ float(last_total_visits),
+ context={"game_step": current_step_in_loop},
+ )
+
+ if not visit_counts:
+ logger.error(
+ f"Worker {self.actor_id}: Step {current_step_in_loop}: MCTS returned empty visit counts. Cannot proceed."
+ )
+ break
+
+ action_selection_start_time = time.monotonic()
+ temp = 1.0
+ selection_temp = 1.0
+ if self.train_config.MAX_TRAINING_STEPS is not None:
+ explore_steps = self.train_config.MAX_TRAINING_STEPS * 0.1
+ selection_temp = (
+ 1.0 if current_step_in_loop < explore_steps else 0.1
+ )
+
+ action_dim_int = int(
+ self.env_config.NUM_SHAPE_SLOTS
+ * self.env_config.ROWS
+ * self.env_config.COLS
+ )
+ action: int = -1
+ try:
+ policy_target = get_policy_target_from_visits(
+ visit_counts, action_dim_int, temperature=temp
+ )
+ action = select_action_from_visits(
+ visit_counts, temperature=selection_temp
+ )
+ logger.debug(
+ f"Worker {self.actor_id}: Step {current_step_in_loop}: Policy target generated. Action selected: {action}"
+ )
+ except PolicyGenerationError as policy_err:
+ logger.error(
+ f"Worker {self.actor_id}: Step {current_step_in_loop}: Policy/Action selection failed: {policy_err}",
+ exc_info=False,
+ )
+ break
+ except Exception as policy_err:
+ logger.error(
+ f"Worker {self.actor_id}: Step {current_step_in_loop}: Unexpected policy error: {policy_err}",
+ exc_info=True,
+ )
+ break
+ action_selection_duration = (
+ time.monotonic() - action_selection_start_time
+ )
+ logger.debug(
+ f"Worker {self.actor_id}: Step {current_step_in_loop}: Action selection time: {action_selection_duration:.4f}s"
+ )
+
+ feature_start_time = time.monotonic()
+ try:
+ state_features: StateType = extract_state_features(
+ game, self.model_config
+ )
+ except Exception as e:
+ logger.error(
+ f"Worker {self.actor_id}: Feature extraction error (Step {current_step_in_loop}): {e}",
+ exc_info=True,
+ )
+ break
+ feature_duration = time.monotonic() - feature_start_time
+ logger.debug(
+ f"Worker {self.actor_id}: Step {current_step_in_loop}: Feature extraction time: {feature_duration:.4f}s"
+ )
+
+ n_step_state_policy_buffer.append((state_features, policy_target))
+
+ game_step_start_time = time.monotonic()
+ step_reward, done = 0.0, False
+ cleared_triangles_this_step = 0
+ try:
+ step_reward, done = game.step(action)
+ cleared_triangles_this_step = game.get_last_cleared_triangles() # type: ignore [attr-defined]
+ total_triangles_cleared_episode += cleared_triangles_this_step
+ logger.debug(
+ f"Worker {self.actor_id}: Step {current_step_in_loop}: game.step({action}) -> Reward: {step_reward:.3f}, Done: {done}, Cleared: {cleared_triangles_this_step}"
+ )
+ last_action = action
+ except Exception as step_err:
+ logger.error(
+ f"Worker {self.actor_id}: Game step error (Action {action}, Step {current_step_in_loop}): {step_err}",
+ exc_info=True,
+ )
+ last_action = -1
+ break
+ game_step_duration = time.monotonic() - game_step_start_time
+ logger.debug(
+ f"Worker {self.actor_id}: Step {current_step_in_loop}: Game step time: {game_step_duration:.4f}s"
+ )
+
+ n_step_reward_buffer.append(step_reward)
+ # Send step reward event
+ self._send_event_async(
+ "step_reward",
+ step_reward,
+ context={"game_step": current_step_in_loop},
+ )
+ if cleared_triangles_this_step > 0:
+ self._send_event_async(
+ "triangles_cleared_step",
+ cleared_triangles_this_step,
+ context={"game_step": current_step_in_loop},
+ )
+
+ if len(n_step_reward_buffer) == self.n_step:
+ discounted_reward_sum = sum(
+ (self.gamma**i) * n_step_reward_buffer[i]
+ for i in range(self.n_step)
+ )
+ bootstrap_value = 0.0
+ if not done:
+ try:
+ _, bootstrap_value = self.nn_evaluator.evaluate_state(game)
+ except Exception as eval_err:
+ logger.error(
+ f"Worker {self.actor_id}: Bootstrap eval error (Step {game.current_step}): {eval_err}",
+ exc_info=True,
+ )
+ bootstrap_value = 0.0
+ n_step_return = (
+ discounted_reward_sum
+ + (self.gamma**self.n_step) * bootstrap_value
+ )
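+                    # Worked example (hypothetical numbers; n_step=3, gamma=0.99):
+                    #   rewards = [1.0, 0.5, 2.0], bootstrap_value = 0.8
+                    #   G = 1.0 + 0.99*0.5 + 0.99**2*2.0 + 0.99**3*0.8 ≈ 4.231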
+ state_features_t_minus_n, policy_target_t_minus_n = (
+ n_step_state_policy_buffer[0]
+ )
+ exp: Experience = (
+ state_features_t_minus_n,
+ policy_target_t_minus_n,
+ n_step_return,
+ )
+ episode_experiences.append(exp)
+ logger.debug(
+ f"Worker {self.actor_id}: Step {current_step_in_loop}: Stored experience for step {current_step_in_loop - self.n_step + 1}. N-Step Return: {n_step_return:.4f}"
+ )
+
+ # Send current score event
+ self._send_event_async(
+ "current_score",
+ game.game_score(),
+ context={"game_step": current_step_in_loop},
+ )
+
+ step_duration = time.monotonic() - step_start_time
+ logger.debug(
+ f"Worker {self.actor_id}: --- Finished Step {current_step_in_loop}. Duration: {step_duration:.3f}s ---"
+ )
+
+ if done:
+ logger.info(
+ f"Worker {self.actor_id}: Game ended naturally at step {game.current_step}. Reason: {game.get_game_over_reason()}"
+ )
+ break
+
+ final_score = game.game_score() if game else 0.0
+ final_step = game.current_step if game else 0
+ logger.info(
+ f"Worker {self.actor_id}: Episode loop finished. Final Score: {final_score:.2f}, Final Step: {final_step}, Total Cleared: {total_triangles_cleared_episode}"
+ )
+
+ remaining_steps = len(n_step_reward_buffer)
+ logger.debug(
+ f"Worker {self.actor_id}: Processing {remaining_steps} remaining n-step items."
+ )
+ for k in range(remaining_steps):
+ discounted_reward_sum = sum(
+ (self.gamma**i) * n_step_reward_buffer[k + i]
+ for i in range(remaining_steps - k)
+ )
+ n_step_return = discounted_reward_sum
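+                # No bootstrap term here: the episode loop has exited, so the
+                # tail returns are built from observed rewards only.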
+ state_features_t, policy_target_t = n_step_state_policy_buffer[k]
+ final_exp: Experience = (
+ state_features_t,
+ policy_target_t,
+ n_step_return,
+ )
+ episode_experiences.append(final_exp)
+ logger.debug(
+ f"Worker {self.actor_id}: Stored final experience for step {final_step - remaining_steps + k + 1}. N-Step Return: {n_step_return:.4f}"
+ )
+
+ if not episode_experiences:
+ logger.warning(
+ f"Worker {self.actor_id}: Episode finished with 0 experiences collected. Score: {final_score}, Steps: {final_step}"
+ )
+
+ # Send episode end event
+ episode_end_context = {
+ "score": final_score,
+ "length": final_step,
+ "simulations": total_sims_episode,
+ "triangles_cleared": total_triangles_cleared_episode,
+ "trainer_step": step_at_start, # Log trainer step when episode started
+ }
+ self._send_event_async(
+ name="episode_end", value=1.0, context=episode_end_context
+ )
+
+ result = SelfPlayResult(
+ episode_experiences=episode_experiences,
+ final_score=final_score,
+ episode_steps=final_step,
+ trainer_step_at_episode_start=step_at_start,
+ total_simulations=total_sims_episode,
+                avg_root_visits=float(last_total_visits),  # approximation: total visits of the last MCTS step, not a per-step average
+ avg_tree_depth=0.0, # Placeholder, trimcts doesn't expose this easily
+ context=episode_end_context, # Pass context for potential use in loop
+ )
+ logger.info(
+ f"Worker {self.actor_id}: Episode result created. Experiences: {len(result.episode_experiences)}"
+ )
+
+ except Exception as e:
+ logger.critical(
+ f"Worker {self.actor_id}: Unhandled exception in run_episode: {e}",
+ exc_info=True,
+ )
+ final_score = game.game_score() if game else 0.0
+ final_step = game.current_step if game else 0
+ episode_end_context = {
+ "score": final_score,
+ "length": final_step,
+ "simulations": total_sims_episode,
+ "triangles_cleared": total_triangles_cleared_episode,
+ "trainer_step": step_at_start,
+ "error": True,
+ }
+ self._send_event_async(
+ "episode_end",
+ 1.0,
+ context=episode_end_context,
+ )
+ result = SelfPlayResult(
+ episode_experiences=[],
+ final_score=final_score,
+ episode_steps=final_step,
+ trainer_step_at_episode_start=step_at_start,
+ total_simulations=total_sims_episode,
+ avg_root_visits=0.0,
+ avg_tree_depth=0.0,
+ context=episode_end_context,
+ )
+
+ finally:
+ mcts_tree_handle = None
+ if self.profiler:
+ self.profiler.disable()
+ profile_dir = self.run_base_dir / "profile_data"
+ profile_dir.mkdir(exist_ok=True)
+ profile_filename = (
+ profile_dir / f"worker_{self.actor_id}_ep_{episode_seed}.prof"
+ )
+ try:
+ self.profiler.dump_stats(str(profile_filename))
+ logger.warning(
+ f"Worker {self.actor_id}: Profiling stats saved to {profile_filename}"
+ )
+ except Exception as e:
+ logger.error(
+ f"Worker {self.actor_id}: Failed to save profile stats: {e}"
+ )
+
+ if result is None:
+ logger.error(
+ f"Worker {self.actor_id}: run_episode finished without setting a result. Returning empty."
+ )
+ episode_end_context = {
+ "score": 0.0,
+ "length": 0,
+ "simulations": 0,
+ "triangles_cleared": 0,
+ "trainer_step": step_at_start,
+ "error": True,
+ }
+ self._send_event_async(
+ "episode_end",
+ 1.0,
+ context=episode_end_context,
+ )
+ result = SelfPlayResult(
+ episode_experiences=[],
+ final_score=0.0,
+ episode_steps=0,
+ trainer_step_at_episode_start=step_at_start,
+ total_simulations=0,
+ avg_root_visits=0.0,
+ avg_tree_depth=0.0,
+ context=episode_end_context,
+ )
+
+ logger.info(
+ f"Worker {self.actor_id}: Finished run_episode. Returning result with {len(result.episode_experiences)} experiences."
+ )
+ return result
+
+
+File: alphatriangle/rl/self_play/__init__.py
+from .mcts_helpers import (
+ PolicyGenerationError,
+ get_policy_target_from_visits,
+ select_action_from_visits,
+)
+from .worker import SelfPlayWorker
+
+__all__ = [
+ "SelfPlayWorker",
+ "select_action_from_visits",
+ "get_policy_target_from_visits",
+ "PolicyGenerationError",
+]
+
+
+File: alphatriangle/rl/self_play/README.md
+
+# RL Self-Play Submodule (`alphatriangle.rl.self_play`)
+
+## Purpose and Architecture
+
+This submodule focuses specifically on generating game episodes through self-play, driven by the current neural network and MCTS. It is designed to run in parallel using Ray actors managed by the [`alphatriangle.training.worker_manager`](../../training/worker_manager.py).
+
+- **[`worker.py`](worker.py):** Defines the `SelfPlayWorker` class, decorated with `@ray.remote`.
+ - Each `SelfPlayWorker` actor runs independently, typically on a separate CPU core.
+ - It initializes its own `GameState` environment and `NeuralNetwork` instance (usually on the CPU).
+  - It receives configuration objects (`EnvConfig`, `SearchConfiguration`, `ModelConfig`, `TrainConfig`) and the name of the `TrieyeActor` during initialization.
+ - It has a `set_weights` method allowing the `TrainingLoop` to periodically update its local neural network with the latest trained weights from the central model. It also has `set_current_trainer_step` to store the global step associated with the current weights, called by the `WorkerManager`.
+ - Its main method, `run_episode`, simulates a complete game episode:
+ - Includes detailed logging for debugging.
+ - Uses its local `NeuralNetwork` evaluator and `MCTSConfig` to run MCTS using `trimcts.run_mcts`.
+ - Selects actions based on MCTS results ([`mcts_helpers.select_action_from_visits`](mcts_helpers.py)).
+ - Generates policy targets ([`mcts_helpers.get_policy_target_from_visits`](mcts_helpers.py)).
+ - Stores `(StateType, policy_target, n_step_return)` tuples (using extracted features and calculated n-step returns).
+ - Steps its local game environment (`GameState.step`).
+ - Returns the collected `Experience` list, final score, episode length, and MCTS statistics via a `SelfPlayResult` object.
+  - **Asynchronously sends raw metric events (`RawMetricEvent`) for step rewards, MCTS simulations, current score, and episode completion (including score, length, and simulations) to the `TrieyeActor` (looked up by name), tagged with the `current_trainer_step` (the global step of its network weights). See the sketch after this list.**
+- **[`mcts_helpers.py`](mcts_helpers.py):** Contains helper functions for processing MCTS visit counts into policy targets and selecting actions based on temperature. Includes `PolicyGenerationError` for specific failures.
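+
+The event payload is a `trieye` `RawMetricEvent`. A minimal sketch of what the worker constructs and sends (the field values and actor name below are illustrative, not taken from a real run):
+
+```python
+import time
+
+import ray
+from trieye.schemas import RawMetricEvent
+
+event = RawMetricEvent(
+    name="episode_end",
+    value=1.0,
+    global_step=1234,  # trainer step associated with the current weights
+    timestamp=time.time(),
+    context={"score": 42.0, "length": 87, "simulations": 6960},
+)
+# Fire-and-forget delivery to the named TrieyeActor (the name is an assumption):
+ray.get_actor("trieye_actor_name").log_event.remote(event)
+```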
+
+## Exposed Interfaces
+
+- **Classes:**
+  - `SelfPlayWorker`: Ray actor class (see the usage sketch after this section).
+ - `__init__(...)`
+ - `run_episode() -> SelfPlayResult`: Runs one episode and returns results.
+ - `set_weights(weights: Dict)`: Updates the actor's local network weights.
+ - `set_current_trainer_step(global_step: int)`: Updates the stored trainer step.
+- **Types:**
+ - `SelfPlayResult`: Pydantic model defined in [`alphatriangle.rl.types`](../types.py).
+- **Functions (from `mcts_helpers.py`):**
+ - `select_action_from_visits(...) -> ActionType`
+ - `get_policy_target_from_visits(...) -> PolicyTargetMapping`
+ - `PolicyGenerationError` (Exception)
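+
+A minimal usage sketch (it assumes the config classes are default-constructible and that a `TrieyeActor` is already running under the given name; the path, actor name, and seed are illustrative):
+
+```python
+import ray
+from trianglengin import EnvConfig
+from trimcts import SearchConfiguration
+
+from alphatriangle.config import ModelConfig, TrainConfig
+from alphatriangle.rl.self_play import SelfPlayWorker
+
+ray.init()
+worker = SelfPlayWorker.remote(
+    actor_id=0,
+    env_config=EnvConfig(),
+    mcts_config=SearchConfiguration(),
+    model_config=ModelConfig(),
+    train_config=TrainConfig(),
+    trieye_actor_name="trieye_actor",  # assumed name of the running TrieyeActor
+    run_base_dir="/tmp/alphatriangle_run",
+    seed=42,
+)
+worker.set_current_trainer_step.remote(0)
+result = ray.get(worker.run_episode.remote())  # SelfPlayResult
+```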
+
+## Dependencies
+
+- **[`alphatriangle.config`](../../config/README.md)**:
+ - `EnvConfig`, `AlphaTriangleMCTSConfig`, `ModelConfig`, `TrainConfig`.
+- **`trianglengin`**:
+ - `GameState`, `EnvConfig`.
+- **`trimcts`**:
+ - `run_mcts`, `SearchConfiguration`.
+- **[`alphatriangle.nn`](../../nn/README.md)**:
+ - `NeuralNetwork`: Instantiated locally within the actor.
+- **[`alphatriangle.features`](../../features/README.md)**:
+ - `extract_state_features`: Used to generate `StateType` for experiences.
+- **[`alphatriangle.utils`](../../utils/README.md)**:
+ - `types`: `Experience`, `ActionType`, `PolicyTargetMapping`, `StateType`.
+ - `helpers`: `get_device`, `set_random_seeds`.
+- **[`alphatriangle.rl.types`](../types.py)**:
+ - `SelfPlayResult`: Return type.
+- **`trieye`**:
+  - `RawMetricEvent`: Event schema sent asynchronously to the `TrieyeActor`, which is looked up by name via `ray.get_actor`.
+- **`numpy`**:
+ - Used by MCTS strategies and feature extraction.
+- **`ray`**:
+ - The `@ray.remote` decorator makes this a Ray actor.
+- **`torch`**:
+ - Used by the local `NeuralNetwork`.
+- **Standard Libraries:** `typing`, `logging`, `random`, `time`, `collections.deque`, `cProfile`.
+
+---
+
+**Note:** Please keep this README updated when changing the self-play episode generation logic, the data collected, the interaction with MCTS/environment, or the asynchronous logging behavior, especially regarding the sending of `RawMetricEvent` objects. Accurate documentation is crucial for maintainability.
+
+
+File: alphatriangle/rl/self_play/mcts_helpers.py
+# File: alphatriangle/rl/self_play/mcts_helpers.py
+import logging
+import random
+
+import numpy as np
+
+from ...utils.types import ActionType, PolicyTargetMapping
+
+logger = logging.getLogger(__name__)
+rng = np.random.default_rng()
+
+
+class PolicyGenerationError(Exception):
+ """Custom exception for errors during policy generation or action selection from visit counts."""
+
+ pass
+
+
+def select_action_from_visits(
+ visit_counts: dict[ActionType, int], temperature: float
+) -> ActionType:
+ """
+ Selects an action based on visit counts and temperature.
+ Operates on the dictionary returned by trimcts.run_mcts.
+ Raises PolicyGenerationError if selection is not possible.
+ """
+ if not visit_counts:
+ raise PolicyGenerationError(
+ "Cannot select action: Visit counts dictionary is empty."
+ )
+
+ actions = list(visit_counts.keys())
+ counts = np.array(list(visit_counts.values()), dtype=np.float64)
+
+ total_visits = np.sum(counts)
+ logger.debug(
+ f"[PolicySelect] Selecting action from visits. Total visits: {total_visits}. Num actions: {len(actions)}"
+ )
+
+ if total_visits == 0:
+ logger.warning(
+ "[PolicySelect] Total visit count is zero. Selecting uniformly from available actions."
+ )
+ selected_action = random.choice(actions)
+ logger.debug(
+ f"[PolicySelect] Uniform random action selected: {selected_action}"
+ )
+ return selected_action
+
+ if temperature == 0.0:
+ max_visits = np.max(counts)
+ logger.debug(
+ f"[PolicySelect] Greedy selection (temp=0). Max visits: {max_visits}"
+ )
+ best_action_indices = np.where(counts == max_visits)[0]
+ chosen_index = random.choice(best_action_indices)
+ selected_action = actions[chosen_index]
+ logger.debug(f"[PolicySelect] Greedy action selected: {selected_action}")
+ return selected_action
+ else:
+ logger.debug(f"[PolicySelect] Probabilistic selection: Temp={temperature:.4f}")
+        # Exponentiate visit counts directly (counts ** (1 / temperature)) rather
+        # than via log-space; clamp to a small positive value first so zero
+        # counts cannot produce invalid probabilities.
+ powered_counts = np.maximum(counts, 1e-9) ** (1.0 / temperature)
+ sum_powered_counts = np.sum(powered_counts)
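+        # Worked example: counts = [1, 3]
+        #   temperature=1.0 -> powered=[1.0, 3.0] -> probabilities=[0.25, 0.75]
+        #   temperature=0.5 -> powered=[1.0, 9.0] -> probabilities=[0.10, 0.90]
+        # Lower temperatures sharpen the distribution toward the most-visited action.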
+
+ if sum_powered_counts < 1e-9 or not np.isfinite(sum_powered_counts):
+ raise PolicyGenerationError(
+ f"Could not normalize visit probabilities (sum={sum_powered_counts}). Visits: {counts}"
+ )
+ else:
+ probabilities = powered_counts / sum_powered_counts
+
+ if not np.all(np.isfinite(probabilities)) or np.any(probabilities < 0):
+ raise PolicyGenerationError(
+ f"Invalid probabilities generated after normalization: {probabilities}"
+ )
+ if abs(np.sum(probabilities) - 1.0) > 1e-5:
+ logger.warning(
+ f"[PolicySelect] Probabilities sum to {np.sum(probabilities):.6f} after normalization. Attempting re-normalization."
+ )
+ probabilities /= np.sum(probabilities)
+ if abs(np.sum(probabilities) - 1.0) > 1e-5:
+ raise PolicyGenerationError(
+ f"Probabilities still do not sum to 1 after re-normalization: {probabilities}, Sum: {np.sum(probabilities)}"
+ )
+
+
+ try:
+ selected_action = rng.choice(actions, p=probabilities)
+ logger.debug(
+ f"[PolicySelect] Sampled action (temp={temperature:.2f}): {selected_action}"
+ )
+ return int(selected_action)
+ except ValueError as e:
+ raise PolicyGenerationError(
+ f"Error during np.random.choice: {e}. Probs: {probabilities}, Sum: {np.sum(probabilities)}"
+ ) from e
+
+
+def get_policy_target_from_visits(
+ visit_counts: dict[ActionType, int], action_dim: int, temperature: float = 1.0
+) -> PolicyTargetMapping:
+ """
+ Calculates the policy target distribution based on MCTS visit counts.
+ Operates on the dictionary returned by trimcts.run_mcts.
+ Raises PolicyGenerationError if target cannot be generated.
+ """
+ full_target = dict.fromkeys(range(action_dim), 0.0)
+
+ if not visit_counts:
+ logger.warning(
+ "[PolicyTarget] Cannot compute policy target: Visit counts dictionary is empty."
+ )
+ return full_target
+
+ actions = list(visit_counts.keys())
+ counts = np.array(list(visit_counts.values()), dtype=np.float64)
+ total_visits = np.sum(counts)
+
+ if total_visits == 0:
+ logger.warning(
+ "[PolicyTarget] Cannot compute policy target: Total visits is zero."
+ )
+ return full_target
+
+ if temperature == 0.0:
+ max_visits = np.max(counts)
+ if max_visits == 0:
+ logger.warning(
+ "[PolicyTarget] Temperature is 0 but max visits is 0. Returning zero target."
+ )
+ return full_target
+
+ best_actions = [actions[i] for i, v in enumerate(counts) if v == max_visits]
+ prob = 1.0 / len(best_actions)
+ for a in best_actions:
+ if 0 <= a < action_dim:
+ full_target[a] = prob
+ else:
+ logger.warning(
+ f"[PolicyTarget] Best action {a} is out of bounds ({action_dim}). Skipping."
+ )
+ else:
+        # Exponentiate counts directly, clamping to a small positive value to
+        # sidestep zero-count issues (the equivalent of avoiding log(0)).
+ powered_counts = np.maximum(counts, 1e-9) ** (1.0 / temperature)
+ sum_powered_counts = np.sum(powered_counts)
+
+ if sum_powered_counts < 1e-9 or not np.isfinite(sum_powered_counts):
+ raise PolicyGenerationError(
+ f"Could not normalize policy target probabilities (sum={sum_powered_counts}). Visits: {counts}"
+ )
+
+ probabilities = powered_counts / sum_powered_counts
+ if not np.all(np.isfinite(probabilities)) or np.any(probabilities < 0):
+ raise PolicyGenerationError(
+ f"Invalid probabilities generated for policy target: {probabilities}"
+ )
+ if abs(np.sum(probabilities) - 1.0) > 1e-5:
+ logger.warning(
+ f"[PolicyTarget] Target probabilities sum to {np.sum(probabilities):.6f}. Re-normalizing."
+ )
+ probabilities /= np.sum(probabilities)
+ if abs(np.sum(probabilities) - 1.0) > 1e-5:
+ raise PolicyGenerationError(
+ f"Target probabilities still do not sum to 1 after re-normalization: {probabilities}, Sum: {np.sum(probabilities)}"
+ )
+
+ raw_policy = {action: probabilities[i] for i, action in enumerate(actions)}
+ for action, prob in raw_policy.items():
+ if 0 <= action < action_dim:
+ full_target[action] = prob
+ else:
+ logger.warning(
+ f"[PolicyTarget] Action {action} from MCTS is out of bounds ({action_dim}). Skipping."
+ )
+
+ final_sum = sum(full_target.values())
+ if abs(final_sum - 1.0) > 1e-5:
+ # Keep this error log as it indicates a potential problem
+ logger.error(
+ f"[PolicyTarget] Final policy target does not sum to 1 ({final_sum:.6f}). Target: {full_target}"
+ )
+
+ return full_target
+
+
diff --git a/MANIFEST.in b/MANIFEST.in
index 692d721..5f1a193 100644
--- a/MANIFEST.in
+++ b/MANIFEST.in
@@ -1,5 +1,4 @@
-# File: MANIFEST.in
# File: MANIFEST.in
include README.md
include LICENSE
@@ -15,6 +14,9 @@ prune alphatriangle/visualization
prune alphatriangle/interaction
# REMOVE MCTS pruning
# prune alphatriangle/mcts
+# Remove Trieye-replaced directories
+prune alphatriangle/stats
+prune alphatriangle/data
# Remove pruned files
global-exclude alphatriangle/app.py
# Remove pruned test directories
@@ -22,5 +24,10 @@ prune tests/visualization
prune tests/interaction
# REMOVE MCTS test pruning
# prune tests/mcts
+# Remove Trieye-replaced test directories
+prune tests/stats
+prune tests/data
# Remove pruned core files
-global-exclude alphatriangle/rl/core/visual_state_actor.py
\ No newline at end of file
+global-exclude alphatriangle/rl/core/visual_state_actor.py
+# REMOVE test_save_resume.py
+global-exclude tests/training/test_save_resume.py
\ No newline at end of file
diff --git a/README.md b/README.md
index 2d5398b..961963b 100644
--- a/README.md
+++ b/README.md
@@ -11,15 +11,21 @@ AlphaTriangle is a project implementing an artificial intelligence agent based o
**Key Features:**
-* **Core Game Logic:** Uses the [`trianglengin>=2.0.6`](https://github.com/lguibr/trianglengin) library for the triangle puzzle game rules and state management, featuring a high-performance C++ core.
+* **Core Game Logic:** Uses the [`trianglengin>=2.0.7`](https://github.com/lguibr/trianglengin) library for the triangle puzzle game rules and state management, featuring a high-performance C++ core.
* **High-Performance MCTS:** Integrates the [`trimcts>=1.2.1`](https://github.com/lguibr/trimcts) library, providing a **C++ implementation of MCTS** for efficient search, callable from Python. MCTS parameters are configurable via `alphatriangle/config/mcts_config.py`.
* **Deep Learning Model:** Features a PyTorch neural network with policy and distributional value heads, convolutional layers, and **optional Transformer Encoder layers**.
* **Parallel Self-Play:** Leverages **Ray** for distributed self-play data generation across multiple CPU cores. **The number of workers automatically adjusts based on detected CPU cores (reserving some for stability), capped by the `NUM_SELF_PLAY_WORKERS` setting in `TrainConfig`.**
-* **Experiment Tracking:** Uses **MLflow** and **TensorBoard** for logging parameters, metrics, and artifacts, enabling web-based monitoring. All persistent data is stored within the `.alphatriangle_data` directory.
+* **Asynchronous Stats & Persistence (NEW):** Uses the [`trieye>=0.1.2`](https://github.com/lguibr/trieye) library, which provides a dedicated **Ray actor (`TrieyeActor`)** for:
+ * Asynchronous collection of raw metric events from workers and the training loop.
+ * Configurable processing and aggregation of metrics.
+ * Logging processed metrics to **MLflow** and **TensorBoard**.
+ * Saving/loading training state (checkpoints, buffers) to the filesystem and logging artifacts to MLflow.
+ * Handling auto-resumption from previous runs.
+ * All persistent data managed by `trieye` is stored within the `.trieye_data/` directory (default: `.trieye_data/alphatriangle`).
* **Headless Training:** Focuses on a command-line interface for running the training pipeline without visual output.
* **Enhanced CLI:** Uses the **`rich`** library for improved visual feedback (colors, panels, emojis) in the terminal.
-* **Centralized Logging:** Uses Python's standard `logging` module configured centrally for consistent log formatting (including `▲` prefix) and level control across the project. Run logs are saved to `.alphatriangle_data/runs/<run_name>/logs/`.
-* **Optional Profiling:** Supports profiling worker 0 using `cProfile` via a command-line flag. Profile data is saved to `.alphatriangle_data/runs/<run_name>/profile_data/`.
+* **Centralized Logging:** Uses Python's standard `logging` module configured centrally for consistent log formatting (including `▲` prefix) and level control across the project. Run logs are saved to `.trieye_data/<app_name>/runs/<run_name>/logs/`.
+* **Optional Profiling:** Supports profiling worker 0 using `cProfile` via a command-line flag. Profile data is saved to `.trieye_data/<app_name>/runs/<run_name>/profile_data/`.
* **Unit Tests:** Includes tests for RL components.
---
@@ -33,16 +39,17 @@ This project trains an agent to play the game defined by the `trianglengin` libr
## Core Technologies
* **Python 3.10+**
-* **trianglengin>=2.0.6:** Core game engine (state, actions, rules) with C++ optimizations.
+* **trianglengin>=2.0.7:** Core game engine (state, actions, rules) with C++ optimizations.
* **trimcts>=1.2.1:** High-performance C++ MCTS implementation with Python bindings.
+* **trieye>=0.1.2:** Asynchronous statistics collection, processing, logging (MLflow/TensorBoard), and data persistence via a Ray actor.
* **PyTorch:** For the deep learning model (CNNs, **optional Transformers**, Distributional Value Head) and training, with CUDA/MPS support.
* **NumPy:** For numerical operations, especially state representation (used by `trianglengin` and features).
-* **Ray:** For parallelizing self-play data generation and statistics collection across multiple CPU cores/processes. **Dynamically scales worker count based on available cores.**
+* **Ray:** For parallelizing self-play data generation and hosting the `TrieyeActor`. **Dynamically scales worker count based on available cores.**
* **Numba:** (Optional, used in `features.grid_features`) For performance optimization of specific grid calculations.
-* **Cloudpickle:** For serializing the experience replay buffer and training checkpoints.
-* **MLflow:** For logging parameters, metrics, and artifacts (checkpoints, buffers) during training runs. **Provides the primary web UI dashboard for experiment management. Data stored in `.alphatriangle_data/mlruns/`.**
-* **TensorBoard:** For visualizing metrics during training (e.g., detailed loss curves). **Provides a secondary web UI dashboard. Data stored in `.alphatriangle_data/runs//tensorboard/`.**
-* **Pydantic:** For configuration management and data validation.
+* **Cloudpickle:** For serializing the experience replay buffer and training checkpoints (used by `trieye`).
+* **MLflow:** For logging parameters, metrics, and artifacts (checkpoints, buffers) during training runs. **Provides the primary web UI dashboard for experiment management. Data stored in `.trieye_data/<app_name>/mlruns/`.** Managed by `trieye`.
+* **TensorBoard:** For visualizing metrics during training (e.g., detailed loss curves). **Provides a secondary web UI dashboard. Data stored in `.trieye_data/<app_name>/runs/<run_name>/tensorboard/`.** Managed by `trieye`.
+* **Pydantic:** For configuration management and data validation (used by `alphatriangle` and `trieye`).
* **Typer:** For the command-line interface.
* **Rich:** For enhanced CLI output formatting and styling.
* **Pytest:** For running unit tests.
@@ -53,68 +60,62 @@ This project trains an agent to play the game defined by the `trianglengin` libr
.
├── .github/workflows/ # GitHub Actions CI/CD
│ └── ci_cd.yml
-├── .alphatriangle_data/ # Root directory for ALL persistent data (GITIGNORED)
-│ ├── mlruns/ # MLflow internal tracking data & artifact store (for MLflow UI)
-│ │ └── <experiment_id>/
-│ │ └── <run_id>/
-│ │ ├── artifacts/ # MLflow's copy of logged artifacts (checkpoints, buffers, etc.)
-│ │ ├── metrics/
-│ │ ├── params/
-│ │ └── tags/
-│ └── runs/ # Local artifacts per run (source for TensorBoard UI & resume)
-│ └── <run_name>/ # e.g., train_YYYYMMDD_HHMMSS
-│ ├── checkpoints/ # Saved model weights & optimizer states (*.pkl)
-│ ├── buffers/ # Saved experience replay buffers (*.pkl)
-│ ├── logs/ # Plain text log files for the run (*.log)
-│ ├── tensorboard/ # TensorBoard log files (event files)
-│ ├── profile_data/ # cProfile output files (*.prof) if profiling enabled
-│ └── configs.json # Copy of run configuration
+├── .trieye_data/ # Root directory for ALL persistent data (GITIGNORED) - Managed by Trieye
+│ └── alphatriangle/ # Default app_name
+│ ├── mlruns/ # MLflow internal tracking data & artifact store (for MLflow UI)
+│ │ └── <experiment_id>/
+│ │ └── <run_id>/
+│ │ ├── artifacts/ # MLflow's copy of logged artifacts (checkpoints, buffers, etc.)
+│ │ ├── metrics/
+│ │ ├── params/
+│ │ └── tags/
+│ └── runs/ # Local artifacts per run (source for TensorBoard UI & resume)
+│ └── <run_name>/ # e.g., train_YYYYMMDD_HHMMSS
+│ ├── checkpoints/ # Saved model weights & optimizer states (*.pkl)
+│ ├── buffers/ # Saved experience replay buffers (*.pkl)
+│ ├── logs/ # Plain text log files for the run (*.log) - App + Trieye logs
+│ ├── tensorboard/ # TensorBoard log files (event files)
+│ ├── profile_data/ # cProfile output files (*.prof) if profiling enabled
+│ └── configs.json # Copy of run configuration
├── alphatriangle/ # Source code for the AlphaZero agent package
│ ├── __init__.py
│ ├── cli.py # CLI logic (train, ml, tb, ray commands - headless only, uses Rich)
-│ ├── config/ # Pydantic configuration models (Model, Train, Persistence, MCTS, Stats)
-│ │ └── README.md
-│ ├── data/ # Data saving/loading logic (DataManager, Schemas, PathManager, Serializer)
+│ ├── config/ # Pydantic configuration models (Model, Train, MCTS) - Stats/Persistence now in Trieye
│ │ └── README.md
│ ├── features/ # Feature extraction logic (operates on trianglengin.GameState)
│ │ └── README.md
-│ ├── logging_config.py # Centralized logging setup function
+│ ├── logging_config.py # Centralized logging setup function (for console)
│ ├── nn/ # Neural network definition and wrapper (implements trimcts.AlphaZeroNetworkInterface)
│ │ └── README.md
│ ├── rl/ # RL components (Trainer, Buffer, Worker using trimcts)
│ │ └── README.md
-│ ├── stats/ # Statistics collection actor (StatsCollectorActor, StatsProcessor)
-│ │ └── README.md
-│ ├── training/ # Training orchestration (Loop, Setup, Runner, WorkerManager)
+│ ├── training/ # Training orchestration (Loop, Setup, Runner, WorkerManager) - Interacts with TrieyeActor
│ │ └── README.md
│ └── utils/ # Shared utilities and types (specific to AlphaTriangle)
│ └── README.md
-├── tests/ # Unit tests (for alphatriangle components, excluding MCTS)
+├── tests/ # Unit tests (for alphatriangle components, excluding MCTS, Stats, Data)
│ ├── conftest.py
│ ├── nn/
│ ├── rl/
-│ ├── stats/
│ └── training/
├── .gitignore
├── .python-version
├── LICENSE # License file (MIT)
├── MANIFEST.in # Specifies files for source distribution
-├── pyproject.toml # Build system & package configuration (depends on trianglengin, trimcts, rich)
+├── pyproject.toml # Build system & package configuration (depends on trianglengin, trimcts, trieye, rich)
├── README.md # This file
-└── requirements.txt # List of dependencies (includes trianglengin, trimcts, rich)
+└── requirements.txt # List of dependencies (includes trianglengin, trimcts, trieye, rich)
```
## Key Modules (`alphatriangle`)
* **`cli`:** Defines the command-line interface using Typer (**`train`**, **`ml`**, **`tb`**, **`ray`** commands - headless). Uses **`rich`** for styling. ([`alphatriangle/cli.py`](alphatriangle/cli.py))
-* **`config`:** Centralized Pydantic configuration classes (Model, Train, Persistence, **MCTS**, **Stats**). Imports `EnvConfig` from `trianglengin`. ([`alphatriangle/config/README.md`](alphatriangle/config/README.md))
+* **`config`:** Centralized Pydantic configuration classes (Model, Train, **MCTS**). Imports `EnvConfig` from `trianglengin`. **Uses `TrieyeConfig` from `trieye` for stats/persistence.** ([`alphatriangle/config/README.md`](alphatriangle/config/README.md))
* **`features`:** Contains logic to convert `trianglengin.GameState` objects into numerical features (`StateType`). ([`alphatriangle/features/README.md`](alphatriangle/features/README.md))
-* **`logging_config`:** Defines the `setup_logging` function for centralized logger configuration. ([`alphatriangle/logging_config.py`](alphatriangle/logging_config.py))
+* **`logging_config`:** Defines the `setup_logging` function for centralized **console** logger configuration. ([`alphatriangle/logging_config.py`](alphatriangle/logging_config.py))
* **`nn`:** Contains the PyTorch `nn.Module` definition (`AlphaTriangleNet`) and a wrapper class (`NeuralNetwork`). **The `NeuralNetwork` class implicitly conforms to the `trimcts.AlphaZeroNetworkInterface` protocol.** ([`alphatriangle/nn/README.md`](alphatriangle/nn/README.md))
-* **`rl`:** Contains RL components: `Trainer` (network updates), `ExperienceBuffer` (data storage, **supports PER**), and `SelfPlayWorker` (Ray actor for parallel self-play **using `trimcts.run_mcts`**). ([`alphatriangle/rl/README.md`](alphatriangle/rl/README.md))
-* **`training`:** Orchestrates the **headless** training process using `TrainingLoop`, managing workers, data flow, logging (via centralized setup), and checkpoints. Includes `runner.py` for the callable training function. ([`alphatriangle/training/README.md`](alphatriangle/training/README.md))
-* **`stats`:** Contains the `StatsCollectorActor` (Ray actor) for asynchronous statistics collection and the `StatsProcessor` for aggregation/logging. ([`alphatriangle/stats/README.md`](alphatriangle/stats/README.md))
-* **`data`:** Manages saving and loading of training artifacts (`DataManager`, `PathManager`, `Serializer`) using Pydantic schemas and `cloudpickle`. ([`alphatriangle/data/README.md`](alphatriangle/data/README.md))
+* **`rl`:** Contains RL components: `Trainer` (network updates), `ExperienceBuffer` (data storage, **supports PER**), and `SelfPlayWorker` (Ray actor for parallel self-play **using `trimcts.run_mcts`**). **Workers now send `RawMetricEvent`s to the `TrieyeActor`.** ([`alphatriangle/rl/README.md`](alphatriangle/rl/README.md))
+* **`training`:** Orchestrates the **headless** training process using `TrainingLoop`, managing workers, data flow, and interaction with the **`TrieyeActor`** for logging, checkpointing, and state loading. Includes `runner.py` for the callable training function. ([`alphatriangle/training/README.md`](alphatriangle/training/README.md))
* **`utils`:** Provides common helper functions and shared type definitions specific to the AlphaZero implementation. ([`alphatriangle/utils/README.md`](alphatriangle/utils/README.md))
## Setup
@@ -129,21 +130,23 @@ This project trains an agent to play the game defined by the `trianglengin` libr
python -m venv venv
source venv/bin/activate # On Windows use `venv\Scripts\activate`
```
-3. **Install the package (including `trianglengin`, `trimcts`, and `rich`):**
+3. **Install the package (including `trianglengin`, `trimcts`, `trieye`, and `rich`):**
* **For users:**
```bash
- # This will automatically install trianglengin, trimcts, and rich from PyPI if available
+ # This will automatically install trianglengin, trimcts, trieye, and rich from PyPI if available
pip install alphatriangle
# Or install directly from Git (installs dependencies from PyPI)
# pip install git+https://github.com/lguibr/alphatriangle.git
```
* **For developers (editable install):**
- * First, ensure `trianglengin` and `trimcts` are installed (ideally in editable mode from their own directories if developing all three):
+ * First, ensure `trianglengin`, `trimcts`, and `trieye` are installed (ideally in editable mode from their own directories if developing all):
```bash
# From the trianglengin directory (requires C++ build tools):
# pip install -e .
# From the trimcts directory (requires C++ build tools):
# pip install -e .
+ # From the trieye directory:
+ # pip install -e .
```
* Then, install `alphatriangle` in editable mode:
```bash
@@ -156,7 +159,7 @@ This project trains an agent to play the game defined by the `trianglengin` libr
4. **(Optional but Recommended) Add data directory to `.gitignore`:**
Ensure the `.gitignore` file in your project root contains the line:
```
- .alphatriangle_data/
+ .trieye_data/
```
## Running the Code (CLI)
@@ -169,19 +172,20 @@ Use the `alphatriangle` command for training and monitoring. The CLI uses `rich`
```
* **Run Training (Headless Only):**
```bash
- alphatriangle train [--seed 42] [--log-level INFO] [--profile]
+ alphatriangle train [--seed 42] [--log-level INFO] [--profile] [--run-name my_custom_run]
```
-  * `--log-level`: Set console/file logging level (DEBUG, INFO, WARNING, ERROR, CRITICAL). Default: INFO. Logs are saved to `.alphatriangle_data/runs/<run_name>/logs/`.
+  * `--log-level`: Set console logging level (DEBUG, INFO, WARNING, ERROR, CRITICAL). Default: INFO. Logs are saved to `.trieye_data/alphatriangle/runs/<run_name>/logs/`.
* `--seed`: Set the random seed for reproducibility. Default: 42.
-  * `--profile`: Enable cProfile for worker 0. Generates `.prof` files in `.alphatriangle_data/runs/<run_name>/profile_data/`.
+  * `--profile`: Enable cProfile for worker 0. Generates `.prof` files in `.trieye_data/alphatriangle/runs/<run_name>/profile_data/`.
+ * `--run-name`: Specify a custom name for the run. Default: `train_YYYYMMDD_HHMMSS`.
* **Launch MLflow UI:**
- Launches the MLflow web interface, automatically pointing to the `.alphatriangle_data/mlruns` directory.
+ Launches the MLflow web interface, automatically pointing to the `.trieye_data/alphatriangle/mlruns` directory.
```bash
alphatriangle ml [--host 127.0.0.1] [--port 5000]
```
Access via `http://localhost:5000` (or the specified host/port).
* **Launch TensorBoard UI:**
- Launches the TensorBoard web interface, automatically pointing to the `.alphatriangle_data/runs` directory (which contains the individual run subdirectories with `tensorboard` logs).
+ Launches the TensorBoard web interface, automatically pointing to the `.trieye_data/alphatriangle/runs` directory (which contains the individual run subdirectories with `tensorboard` logs).
```bash
alphatriangle tb [--host 127.0.0.1] [--port 6006]
```
@@ -206,19 +210,21 @@ Use the `alphatriangle` command for training and monitoring. The CLI uses `rich`
* **Analyzing Profile Data (if `--profile` was used):**
Use the provided `analyze_profiles.py` script (requires `typer`).
```bash
-    python analyze_profiles.py .alphatriangle_data/runs/<run_name>/profile_data/worker_0_ep_<episode_seed>.prof [-n <num_lines>]
+    python analyze_profiles.py .trieye_data/alphatriangle/runs/<run_name>/profile_data/worker_0_ep_<episode_seed>.prof [-n <num_lines>]
```
## Configuration
-All major parameters for the AlphaZero agent (Model, Training, Persistence, **MCTS**, **Stats**) are defined in the Pydantic classes within the `alphatriangle/config/` directory. Modify these files to experiment with different settings. Environment configuration (`EnvConfig`) is defined within the `trianglengin` library.
+* **AlphaTriangle Specific:** Parameters for the Model (`ModelConfig`), Training (`TrainConfig`), and MCTS (`AlphaTriangleMCTSConfig`) are defined in the Pydantic classes within the `alphatriangle/config/` directory.
+* **Environment:** Environment configuration (`EnvConfig`) is defined within the `trianglengin` library.
+* **Stats & Persistence:** Statistics logging and data persistence are configured via `TrieyeConfig` (which includes `StatsConfig` and `PersistenceConfig`) from the `trieye` library. These are typically instantiated in `alphatriangle/training/runner.py` or `alphatriangle/cli.py`.
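+
+A minimal sketch of that instantiation (attribute names follow the `TrieyeConfig` usage in `alphatriangle/cli.py`; treat it as illustrative rather than the full `trieye` API):
+
+```python
+from trieye import TrieyeConfig
+
+trieye_config = TrieyeConfig(app_name="alphatriangle")
+trieye_config.run_name = "my_custom_run"
+# Keep the persistence sub-config's run name in sync:
+trieye_config.persistence.RUN_NAME = trieye_config.run_name
+```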
## Data Storage
-All persistent data is stored within the `.alphatriangle_data/` directory in the project root. This directory should be added to your `.gitignore`.
+All persistent data is now managed by the `trieye` library and stored within the `.trieye_data/<app_name>/` directory (default: `.trieye_data/alphatriangle/`) in the project root. This directory should be added to your `.gitignore`.
-* **`.alphatriangle_data/mlruns/`**: Managed by **MLflow**. Contains MLflow's internal tracking data (parameters, metrics) and its own copy of logged artifacts. This is the source for the MLflow UI (`alphatriangle ml`).
-* **`.alphatriangle_data/runs/`**: Managed by **DataManager**. Contains locally saved artifacts for each run (checkpoints, buffers, TensorBoard logs, configs, **profile data**) before/during logging to MLflow. This directory is used for auto-resuming and is the source for the TensorBoard UI (`alphatriangle tb`).
+* **`.trieye_data/<app_name>/mlruns/`**: Managed by **MLflow** (via `trieye`). Contains MLflow's internal tracking data (parameters, metrics) and its own copy of logged artifacts. This is the source for the MLflow UI (`alphatriangle ml`).
+* **`.trieye_data/<app_name>/runs/`**: Managed by **`trieye`**. Contains locally saved artifacts for each run (checkpoints, buffers, TensorBoard logs, configs, **profile data**) before/during logging to MLflow. This directory is used for auto-resuming and is the source for the TensorBoard UI (`alphatriangle tb`).
* **Replay Buffer Content:** The saved buffer file (`buffer.pkl`) contains `Experience` tuples: `(StateType, PolicyTargetMapping, n_step_return)`. The `StateType` includes:
* `grid`: Numerical features representing grid occupancy.
* `other_features`: Numerical features derived from game state and available shapes.
@@ -226,4 +232,4 @@ All persistent data is stored within the `.alphatriangle_data/` directory in the
## Maintainability
-This project includes README files within each major `alphatriangle` submodule. **Please keep these READMEs updated** when making changes to the code's structure, interfaces, or core logic.
+This project includes README files within each major `alphatriangle` submodule. **Please keep these READMEs updated** when making changes to the code's structure, interfaces, or core logic, especially regarding the interaction with the `trieye` library.
\ No newline at end of file
diff --git a/alphatriangle/cli.py b/alphatriangle/cli.py
index db3e38a..f0b7f15 100644
--- a/alphatriangle/cli.py
+++ b/alphatriangle/cli.py
@@ -1,3 +1,4 @@
+# File: alphatriangle/cli.py
import logging
import shutil
import subprocess
@@ -8,9 +9,12 @@
from rich.console import Console
from rich.panel import Panel
+# Import Trieye config
+from trieye import PersistenceConfig, TrieyeConfig
+
# Import alphatriangle specific configs and runner
from alphatriangle.config import (
- PersistenceConfig,
+ APP_NAME, # Use APP_NAME from config
TrainConfig,
)
from alphatriangle.logging_config import setup_logging # Import centralized setup
@@ -61,6 +65,15 @@
),
]
+RunNameOption = Annotated[
+ str | None,
+ typer.Option(
+ "--run-name",
+ help="Specify a custom name for the run (overrides default timestamp).",
+ ),
+]
+
+
HostOption = Annotated[
str, typer.Option(help="The network address to listen on (default: 127.0.0.1).")
]
@@ -106,7 +119,8 @@ def _run_external_ui(
console.print(
f"[bold red]Error:[/bold red] {ui_name} command failed with exit code {process.returncode}"
)
- raise typer.Exit(code=process.returncode)
+ # Don't exit immediately, let the calling function handle it if needed
+ # raise typer.Exit(code=process.returncode)
except FileNotFoundError as e:
console.print(
f"[bold red]Error:[/bold red] '{executable}' command not found. Is {ui_name} installed and in your PATH?"
@@ -129,27 +143,37 @@ def train(
log_level: LogLevelOption = "INFO",
seed: SeedOption = 42,
profile: ProfileOption = False,
+ run_name: RunNameOption = None, # Add run_name option
):
"""
🚀 Run the AlphaTriangle training pipeline (headless).
- Initiates the self-play and learning process. Logs will be saved to the run directory.
+ Initiates the self-play and learning process. Uses Trieye for stats/persistence.
+ Logs will be saved to the run directory within `.trieye_data/alphatriangle/runs/`.
+ This command also initializes Ray and starts the Ray Dashboard. Check the logs for the dashboard URL.
"""
- # Setup logging using the centralized function (file logging handled by runner)
+ # Setup logging using the centralized function (file logging handled by Trieye)
setup_logging(log_level)
logging.getLogger(__name__) # Get logger after setup
- # Use alphatriangle configs here
+ # Use alphatriangle TrainConfig
train_config_override = TrainConfig()
- persist_config_override = PersistenceConfig()
train_config_override.RANDOM_SEED = seed
train_config_override.PROFILE_WORKERS = profile # Set profile config
- # Ensure run name is set for persistence config
- persist_config_override.RUN_NAME = train_config_override.RUN_NAME
+
+ # Create TrieyeConfig, overriding run_name if provided
+ trieye_config_override = TrieyeConfig(app_name=APP_NAME)
+ if run_name:
+ trieye_config_override.run_name = run_name
+ # Sync run_name to persistence config within TrieyeConfig
+ trieye_config_override.persistence.RUN_NAME = run_name
+ else:
+ # Use the default factory-generated run_name from TrieyeConfig
+ run_name = trieye_config_override.run_name
console.print(
Panel(
- f"Starting Training Run: '[bold cyan]{train_config_override.RUN_NAME}[/]'\n"
+ f"Starting Training Run: '[bold cyan]{run_name}[/]'\n"
f"Seed: {seed}, Log Level: {log_level.upper()}, Profiling: {'✅ Enabled' if profile else '❌ Disabled'}",
title="[bold green]Training Setup[/]",
border_style="green",
@@ -157,18 +181,18 @@ def train(
)
)
- # Call the single runner function directly, passing the profile flag
+ # Call the single runner function directly, passing configs
exit_code = run_training(
log_level_str=log_level,
train_config_override=train_config_override,
- persist_config_override=persist_config_override,
+ trieye_config_override=trieye_config_override,
profile=profile,
)
if exit_code == 0:
console.print(
Panel(
- f"✅ Training run '[bold cyan]{train_config_override.RUN_NAME}[/]' completed successfully.",
+ f"✅ Training run '[bold cyan]{run_name}[/]' completed successfully.",
title="[bold green]Training Finished[/]",
border_style="green",
)
@@ -176,7 +200,7 @@ def train(
else:
console.print(
Panel(
- f"❌ Training run '[bold cyan]{train_config_override.RUN_NAME}[/]' failed with exit code {exit_code}.",
+ f"❌ Training run '[bold cyan]{run_name}[/]' failed with exit code {exit_code}.",
title="[bold red]Training Failed[/]",
border_style="red",
)
@@ -192,11 +216,11 @@ def ml(
"""
📊 Launch the MLflow UI for experiment tracking.
- Requires MLflow to be installed. Points to the `.alphatriangle_data/mlruns` directory.
+    Requires MLflow to be installed. Points to the `.trieye_data/<app_name>/mlruns` directory.
"""
setup_logging("INFO") # Basic logging for this command
- persist_config = PersistenceConfig()
- # Use the computed property which resolves the path and creates the dir
+ # Use Trieye's PersistenceConfig to find the path
+ persist_config = PersistenceConfig(APP_NAME=APP_NAME)
mlflow_uri = persist_config.MLFLOW_TRACKING_URI
mlflow_path = persist_config.get_mlflow_abs_path()
@@ -217,7 +241,15 @@ def ml(
"--port",
str(port),
]
- _run_external_ui("mlflow", command_args, "MLflow UI", f"http://{host}:{port}")
+ try:
+ _run_external_ui("mlflow", command_args, "MLflow UI", f"http://{host}:{port}")
+ except typer.Exit as e:
+ if e.exit_code != 0:
+ console.print(
+ f"[yellow]MLflow UI failed to start (Exit Code: {e.exit_code}). "
+ f"Is port {port} already in use? Try specifying a different port with --port.[/]"
+ )
+ sys.exit(e.exit_code)
@app.command()
@@ -228,11 +260,11 @@ def tb(
"""
📈 Launch TensorBoard UI pointing to the runs directory.
- Requires TensorBoard to be installed. Points to the `.alphatriangle_data/runs` directory.
+ Requires TensorBoard to be installed. Points to the `.trieye_data/<app_name>/runs` directory.
"""
setup_logging("INFO") # Basic logging for this command
- persist_config = PersistenceConfig()
- # Point to the parent directory containing all individual run folders
+ # Use Trieye's PersistenceConfig to find the path
+ persist_config = PersistenceConfig(APP_NAME=APP_NAME)
runs_root_dir = persist_config.get_runs_root_dir()
if not runs_root_dir.exists() or not any(runs_root_dir.iterdir()):
@@ -253,38 +285,46 @@ def tb(
"--port",
str(port),
]
- _run_external_ui(
- "tensorboard", command_args, "TensorBoard UI", f"http://{host}:{port}"
- )
+ try:
+ _run_external_ui(
+ "tensorboard", command_args, "TensorBoard UI", f"http://{host}:{port}"
+ )
+ except typer.Exit as e:
+ if e.exit_code != 0:
+ console.print(
+ f"[yellow]TensorBoard UI failed to start (Exit Code: {e.exit_code}). "
+ f"Is port {port} already in use? Try specifying a different port with --port.[/]"
+ )
+ sys.exit(e.exit_code)
@app.command()
def ray(
- host: HostOption = "127.0.0.1",
+ host: HostOption = "127.0.0.1", # Keep host/port options for reference
port: PortOption = 8265,
):
"""
- ☀️ Launch the Ray Dashboard web UI.
+ ☀️ Provides instructions to view the Ray Dashboard.
- Requires Ray to be installed and potentially running (e.g., started by `alphatriangle train`).
+ The dashboard is automatically started when you run `alphatriangle train`.
+ Check the output logs of the `train` command for the correct URL.
"""
setup_logging("INFO") # Basic logging for this command
console.print(
- "[yellow]Note:[/yellow] This command attempts to open the Ray Dashboard for an existing Ray cluster."
- )
- console.print(
- "If Ray is not running, this command might fail. Start Ray first (e.g., via `alphatriangle train`)."
+ Panel(
+ f"💡 To view the Ray Dashboard:\n\n"
+ f"1. The Ray Dashboard is started automatically when you run the `[bold]alphatriangle train[/]` command.\n"
+ f"2. Check the console output or the log file for the `train` command (located in `.trieye_data/{APP_NAME}/runs//logs/`).\n"
+ f"3. Look for a line similar to: '[bold cyan]Ray Dashboard running at: http://:[/]' \n"
+ f"4. Open that specific URL in your web browser.\n\n"
+ f"[dim]Note: The default URL is often http://{host}:{port}, but it might differ. "
+ f"If you cannot access the URL, check firewall settings or if the port is blocked.[/]",
+ title="[bold yellow]Ray Dashboard Instructions[/]",
+ border_style="yellow",
+ expand=False,
+ )
)
- command_args = [
- "dashboard",
- "--host",
- host,
- "--port",
- str(port),
- ]
- _run_external_ui("ray", command_args, "Ray Dashboard", f"http://{host}:{port}")
-
if __name__ == "__main__":
app()
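
For orientation, here is a minimal sketch of the run-name wiring the updated `train` command performs. It assumes `TrieyeConfig` is importable from the top level of the `trieye` package (the import is not shown in this diff); the attribute names are taken directly from the code above, and `"my_experiment"` is a hypothetical CLI value.

```python
# Sketch of the train command's config wiring (assumes top-level export).
from trieye import TrieyeConfig

APP_NAME = "alphatriangle"
requested_run_name: str | None = "my_experiment"  # hypothetical CLI argument

trieye_config = TrieyeConfig(app_name=APP_NAME)
if requested_run_name:
    trieye_config.run_name = requested_run_name
    # Keep the nested persistence config in sync, as the CLI does:
    trieye_config.persistence.RUN_NAME = requested_run_name
else:
    # Fall back to the default factory-generated run name.
    requested_run_name = trieye_config.run_name
```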
diff --git a/alphatriangle/config/README.md b/alphatriangle/config/README.md
index 32d6a60..cd12f41 100644
--- a/alphatriangle/config/README.md
+++ b/alphatriangle/config/README.md
@@ -1,19 +1,17 @@
-
# Configuration Module (`alphatriangle.config`)
## Purpose and Architecture
-This module centralizes all configuration parameters for the AlphaTriangle project *except* for the core environment settings. It uses separate **Pydantic models** for different aspects of the application (model, training, persistence, MCTS, **statistics**) to promote modularity, clarity, and automatic validation.
+This module centralizes configuration parameters for the AlphaTriangle agent itself, *excluding* statistics logging and data persistence, which are now handled by the `trieye` library. It uses separate **Pydantic models** for different aspects of the agent (model, training loop, MCTS) to promote modularity, clarity, and automatic validation.
-**Core environment configuration (`EnvConfig`) is now defined and imported directly from the `trianglengin` library.**
+**Core environment configuration (`EnvConfig`) is imported directly from the `trianglengin` library.**
+**Statistics and persistence configuration (`StatsConfig`, `PersistenceConfig`) is defined and managed within the `trieye` library via `TrieyeConfig`.**
- **Modularity:** Separating configurations makes it easier to manage parameters for different components.
- **Type Safety & Validation:** Using Pydantic models (`BaseModel`) provides strong type hinting, automatic parsing, and validation of configuration values based on defined types and constraints (e.g., `Field(gt=0)`).
-- **Validation Script:** The [`validation.py`](validation.py) script instantiates all configuration models (including importing and validating `trianglengin.EnvConfig`), triggering Pydantic's validation, and prints a summary.
-- **Dynamic Defaults:** Some configurations, like `RUN_NAME` in `TrainConfig`, use `default_factory` for dynamic defaults (e.g., timestamp).
-- **Computed Fields:** Properties like `MLFLOW_TRACKING_URI` in `PersistenceConfig` are defined using `@computed_field` for clarity.
-- **Tuned Defaults:** The default values in `TrainConfig` and `ModelConfig` are tuned for substantial learning runs. `AlphaTriangleMCTSConfig` defaults to 128 simulations. `StatsConfig` defines a default set of metrics to track.
-- **Data Paths:** `PersistenceConfig` defines the structure within the `.alphatriangle_data` directory where all local artifacts (runs, checkpoints, logs, TensorBoard data) and MLflow data (`mlruns`) are stored.
+- **Validation Script:** The [`validation.py`](validation.py) script instantiates the AlphaTriangle-specific configuration models (including the imported `trianglengin.EnvConfig`), which triggers Pydantic's validation, and prints a summary. **Note:** It does *not* validate `TrieyeConfig` directly; `trieye` handles its own validation upon actor initialization.
+- **Dynamic Defaults:** Some configurations, like `RUN_NAME` in `TrainConfig`, use `default_factory` for dynamic defaults (e.g., timestamp). This default is often overridden by the `TrieyeConfig` setting.
+- **Tuned Defaults:** The default values in `TrainConfig` and `ModelConfig` are tuned for substantial learning runs. `AlphaTriangleMCTSConfig` now defaults to 64 simulations (lowered for faster testing/profiling; see [`mcts_config.py`](mcts_config.py)).
## Exposed Interfaces
@@ -21,13 +19,11 @@ This module centralizes all configuration parameters for the AlphaTriangle proje
- `EnvConfig` (Imported from `trianglengin`): Environment parameters (grid size, shapes, rewards).
- [`ModelConfig`](model_config.py): Neural network architecture parameters.
- [`TrainConfig`](train_config.py): Training loop hyperparameters (batch size, learning rate, workers, PER settings, etc.).
- - [`PersistenceConfig`](persistence_config.py): Data saving/loading parameters (directories within `.alphatriangle_data`, filenames).
- [`AlphaTriangleMCTSConfig`](mcts_config.py): MCTS parameters (simulations, exploration constants, temperature).
- - [`StatsConfig`](stats_config.py): Statistics collection and logging parameters (metrics, aggregation, frequency).
- **Constants:**
- - [`APP_NAME`](app_config.py): The name of the application.
+ - [`APP_NAME`](app_config.py): The name of the application (used by `trieye` for namespacing).
- **Functions:**
- - `print_config_info_and_validate(mcts_config_instance: AlphaTriangleMCTSConfig | None)`: Validates and prints a summary of all configurations.
+ - `print_config_info_and_validate(mcts_config_instance: AlphaTriangleMCTSConfig | None)`: Validates and prints a summary of AlphaTriangle-specific configurations.
## Dependencies
@@ -40,4 +36,4 @@ This module primarily defines configurations and relies heavily on **Pydantic**.
---
-**Note:** Please keep this README updated when adding, removing, or significantly modifying configuration parameters or the structure of the Pydantic models. Accurate documentation is crucial for maintainability.
\ No newline at end of file
+**Note:** Please keep this README updated when adding, removing, or significantly modifying configuration parameters or the structure of the Pydantic models within this module.
\ No newline at end of file
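
A short usage sketch of the validation entry point, using only names this module exports per the diff above; the signature also accepts `None` for the MCTS config.

```python
# Validate and print the AlphaTriangle-specific configurations.
from alphatriangle.config import (
    AlphaTriangleMCTSConfig,
    print_config_info_and_validate,
)

print_config_info_and_validate(AlphaTriangleMCTSConfig())
```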
diff --git a/alphatriangle/config/__init__.py b/alphatriangle/config/__init__.py
index 30d304a..df5eee2 100644
--- a/alphatriangle/config/__init__.py
+++ b/alphatriangle/config/__init__.py
@@ -4,8 +4,6 @@
from .app_config import APP_NAME
from .mcts_config import AlphaTriangleMCTSConfig
from .model_config import ModelConfig
-from .persistence_config import PersistenceConfig
-from .stats_config import StatsConfig # ADDED
from .train_config import TrainConfig
from .validation import print_config_info_and_validate
@@ -13,9 +11,7 @@
"APP_NAME",
"EnvConfig",
"ModelConfig",
- "PersistenceConfig",
"TrainConfig",
"AlphaTriangleMCTSConfig",
- "StatsConfig", # ADDED
"print_config_info_and_validate",
]
diff --git a/alphatriangle/config/mcts_config.py b/alphatriangle/config/mcts_config.py
index a055293..cd590db 100644
--- a/alphatriangle/config/mcts_config.py
+++ b/alphatriangle/config/mcts_config.py
@@ -8,8 +8,8 @@
from trimcts import SearchConfiguration # Import base config for reference
# Restore default simulations to a lower value for faster testing/profiling
-DEFAULT_MAX_SIMULATIONS = 128
-DEFAULT_MAX_DEPTH = 16
+DEFAULT_MAX_SIMULATIONS = 64
+DEFAULT_MAX_DEPTH = 8
DEFAULT_CPUCT = 1.5
DEFAULT_MCTS_BATCH_SIZE = 32 # Default batch size for network evals within MCTS
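
For a real training run, the lighter defaults above can be overridden at construction time. This is a hedged sketch: the constructor field names (`max_simulations`, `max_depth`) are an assumption inferred from the `DEFAULT_*` constants and `trimcts.SearchConfiguration`, and are not confirmed by this diff.

```python
# Field names assumed; verify against AlphaTriangleMCTSConfig before use.
from alphatriangle.config import AlphaTriangleMCTSConfig

mcts_cfg = AlphaTriangleMCTSConfig(max_simulations=128, max_depth=16)
```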
diff --git a/alphatriangle/config/persistence_config.py b/alphatriangle/config/persistence_config.py
deleted file mode 100644
index 7125d56..0000000
--- a/alphatriangle/config/persistence_config.py
+++ /dev/null
@@ -1,85 +0,0 @@
-from pathlib import Path
-
-from pydantic import BaseModel, Field, computed_field
-
-
-class PersistenceConfig(BaseModel):
- """Configuration for saving/loading artifacts (Pydantic model)."""
-
- # Root directory for all persistent data, relative to project root.
- ROOT_DATA_DIR: str = Field(default=".alphatriangle_data")
- RUNS_DIR_NAME: str = Field(default="runs")
- MLFLOW_DIR_NAME: str = Field(default="mlruns")
-
- CHECKPOINT_SAVE_DIR_NAME: str = Field(default="checkpoints")
- BUFFER_SAVE_DIR_NAME: str = Field(default="buffers")
- LOG_DIR_NAME: str = Field(default="logs")
- TENSORBOARD_DIR_NAME: str = Field(default="tensorboard")
- PROFILE_DIR_NAME: str = Field(default="profile_data") # Added for consistency
-
- LATEST_CHECKPOINT_FILENAME: str = Field(default="latest.pkl")
- BEST_CHECKPOINT_FILENAME: str = Field(default="best.pkl")
- BUFFER_FILENAME: str = Field(default="buffer.pkl")
- CONFIG_FILENAME: str = Field(default="configs.json")
-
- RUN_NAME: str = Field(default="default_run")
-
- SAVE_BUFFER: bool = Field(default=True)
- BUFFER_SAVE_FREQ_STEPS: int = Field(default=1000, ge=1)
-
- def _get_absolute_root(self) -> Path:
- """Resolves ROOT_DATA_DIR to an absolute path relative to the project root."""
- # Assume the config file is loaded relative to the project root
- # or the script is run from the project root.
- project_root = Path.cwd() # Or determine project root differently if needed
- root_path = project_root / self.ROOT_DATA_DIR
- return root_path.resolve()
-
- @computed_field # type: ignore[misc] # Decorator requires Pydantic v2
- @property
- def MLFLOW_TRACKING_URI(self) -> str:
- """Constructs the absolute file URI for MLflow tracking using pathlib."""
- if hasattr(self, "MLFLOW_DIR_NAME"):
- abs_path = self.get_mlflow_abs_path()
- # Ensure the path exists for the URI to be valid for mlflow ui
- abs_path.mkdir(parents=True, exist_ok=True)
- return abs_path.as_uri()
- return ""
-
- def get_runs_root_dir(self) -> Path:
- """Gets the absolute path to the directory containing all runs."""
- if not hasattr(self, "RUNS_DIR_NAME"):
- return Path() # Return empty path if config is incomplete
- root_path = self._get_absolute_root()
- return root_path / self.RUNS_DIR_NAME
-
- def get_run_base_dir(self, run_name: str | None = None) -> Path:
- """Gets the absolute base directory path for a specific run."""
- runs_root = self.get_runs_root_dir()
- name = run_name if run_name else self.RUN_NAME
- return runs_root / name
-
- def get_mlflow_abs_path(self) -> Path:
- """Gets the absolute OS path to the MLflow directory."""
- if not hasattr(self, "MLFLOW_DIR_NAME"):
- return Path()
- root_path = self._get_absolute_root()
- return root_path / self.MLFLOW_DIR_NAME
-
- def get_tensorboard_log_dir(self, run_name: str | None = None) -> Path:
- """Gets the absolute directory path for TensorBoard logs for a specific run."""
- run_base = self.get_run_base_dir(run_name)
- if not run_base or not hasattr(self, "TENSORBOARD_DIR_NAME"):
- return Path()
- return run_base / self.TENSORBOARD_DIR_NAME
-
- def get_profile_data_dir(self, run_name: str | None = None) -> Path:
- """Gets the absolute directory path for profile data for a specific run."""
- run_base = self.get_run_base_dir(run_name)
- if not run_base or not hasattr(self, "PROFILE_DIR_NAME"):
- return Path()
- return run_base / self.PROFILE_DIR_NAME
-
-
-# Ensure model is rebuilt after changes
-PersistenceConfig.model_rebuild(force=True)
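
With this module deleted, path lookups go through trieye's `PersistenceConfig` instead. A sketch of the calls the updated CLI makes, with the constructor and method names taken from the `cli.py` diff above (the import path is assumed):

```python
from trieye import PersistenceConfig  # import path assumed

pc = PersistenceConfig(APP_NAME="alphatriangle")
print(pc.MLFLOW_TRACKING_URI)    # file:// URI used by `alphatriangle ml`
print(pc.get_mlflow_abs_path())  # absolute path to the mlruns directory
print(pc.get_runs_root_dir())    # parent directory of all runs
```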
diff --git a/alphatriangle/config/stats_config.py b/alphatriangle/config/stats_config.py
deleted file mode 100644
index b0e20e4..0000000
--- a/alphatriangle/config/stats_config.py
+++ /dev/null
@@ -1,277 +0,0 @@
-# File: alphatriangle/config/stats_config.py
-import logging
-from typing import Literal
-
-from pydantic import BaseModel, Field, field_validator
-
-logger = logging.getLogger(__name__)
-
-# --- Enums and Literals ---
-AggregationMethod = Literal[
- "latest", "mean", "sum", "rate", "min", "max", "std", "count"
-]
-LogTarget = Literal["mlflow", "tensorboard", "console", "stats_actor"]
-DataSource = Literal["trainer", "worker", "loop", "buffer", "system"]
-XAxis = Literal["global_step", "wall_time", "episode"]
-
-
-# --- Metric Definition ---
-class MetricConfig(BaseModel):
- """Configuration for a single metric to be tracked and logged."""
-
- name: str = Field(
- ..., description="Unique name for the metric (e.g., 'Loss/Total')"
- )
- source: DataSource = Field(
- ..., description="Origin of the raw metric data (e.g., 'trainer', 'worker')"
- )
- raw_event_name: str | None = Field(
- default=None,
- description="Specific raw event name if different from metric name (e.g., 'episode_end')",
- )
- aggregation: AggregationMethod = Field(
- default="latest",
- description="How to aggregate raw values over the logging interval ('rate' calculates per second)",
- )
- log_frequency_steps: int = Field(
- default=1,
- description="Log metric every N global steps. Set to 0 to disable step-based logging.",
- ge=0,
- )
- log_frequency_seconds: float = Field(
- default=0.0,
- description="Log metric every N seconds. Set to 0 to disable time-based logging.",
- ge=0.0,
- )
- log_to: list[LogTarget] = Field(
- default=["mlflow", "tensorboard"],
- description="Where to log the processed metric.",
- )
- x_axis: XAxis = Field(
- default="global_step", description="The primary x-axis for logging."
- )
- # Optional: For rate calculation, specify the numerator event
- rate_numerator_event: str | None = Field(
- default=None,
- description="Raw event name for the numerator in rate calculation (e.g., 'step_completed')",
- )
-
- @field_validator("log_frequency_steps", "log_frequency_seconds")
- @classmethod
- def check_log_frequency(cls, v): # Removed info argument
- """Ensure at least one frequency is set if logging is enabled."""
- # This validation is tricky as it depends on other fields.
- # We'll rely on the processor logic to handle cases where both are 0.
- return v
-
- @field_validator("rate_numerator_event")
- @classmethod
- def check_rate_config(cls, v, info):
- """Ensure numerator is specified if aggregation is 'rate'."""
- # info.data might not be available in Pydantic v2 validators
- # Access values via info.values if needed, but direct access is preferred
- if info.data.get("aggregation") == "rate" and v is None:
- metric_name = info.data.get("name", "Unknown Metric")
- raise ValueError(
- f"Metric '{metric_name}' has aggregation 'rate' but 'rate_numerator_event' is not set."
- )
- return v
-
- @property
- def event_key(self) -> str:
- """The key used to store/retrieve raw events for this metric."""
- return self.raw_event_name or self.name
-
-
-# --- Main Stats Configuration ---
-class StatsConfig(BaseModel):
- """Overall configuration for statistics collection and logging."""
-
- # Interval (in seconds) for the StatsProcessor to process collected data
- processing_interval_seconds: float = Field(
- default=2.0,
- description="How often the StatsProcessor aggregates and logs metrics.",
- gt=0,
- )
-
- # Default metrics - can be overridden or extended
- metrics: list[MetricConfig] = Field(
- default_factory=lambda: [
- # --- Trainer Metrics ---
- MetricConfig(
- name="Loss/Total",
- source="trainer",
- aggregation="mean",
- log_frequency_steps=10,
- log_to=["mlflow", "tensorboard"],
- ),
- MetricConfig(
- name="Loss/Policy",
- source="trainer",
- aggregation="mean",
- log_frequency_steps=10,
- log_to=["mlflow", "tensorboard"],
- ),
- MetricConfig(
- name="Loss/Value",
- source="trainer",
- aggregation="mean",
- log_frequency_steps=10,
- log_to=["mlflow", "tensorboard"],
- ),
- MetricConfig(
- name="Loss/Entropy",
- source="trainer",
- aggregation="mean",
- log_frequency_steps=10,
- log_to=["mlflow", "tensorboard"],
- ),
- MetricConfig(
- name="Loss/Mean_TD_Error",
- source="trainer",
- aggregation="mean",
- log_frequency_steps=10,
- log_to=["mlflow", "tensorboard"],
- ),
- MetricConfig(
- name="LearningRate",
- source="trainer",
- aggregation="latest",
- log_frequency_steps=10,
- log_to=["mlflow", "tensorboard"],
- ),
- # --- Worker Metrics (Aggregated) ---
- MetricConfig(
- name="Episode/Final_Score",
- source="worker",
- raw_event_name="episode_end", # Use context['score']
- aggregation="mean",
- log_frequency_steps=1, # Log every step where an episode ended
- log_to=["mlflow", "tensorboard"],
- ),
- MetricConfig(
- name="Episode/Length",
- source="worker",
- raw_event_name="episode_end", # Use context['length']
- aggregation="mean",
- log_frequency_steps=1,
- log_to=["mlflow", "tensorboard"],
- ),
- MetricConfig(
- name="MCTS/Avg_Simulations",
- source="worker",
- raw_event_name="mcts_step", # Use value from mcts_step event
- aggregation="mean",
- log_frequency_steps=50,
- log_to=["mlflow", "tensorboard"],
- ),
- MetricConfig(
- name="RL/Step_Reward_Mean",
- source="worker",
- raw_event_name="step_reward", # Use value from step_reward event
- aggregation="mean",
- log_frequency_steps=50,
- log_to=["mlflow", "tensorboard"],
- ),
- # --- Loop/System Metrics ---
- MetricConfig(
- name="Buffer/Size",
- source="loop",
- aggregation="latest",
- log_frequency_steps=10,
- log_to=["mlflow", "tensorboard"],
- ),
- MetricConfig(
- name="Progress/Total_Simulations",
- source="loop",
- aggregation="latest",
- log_frequency_steps=10,
- log_to=["mlflow", "tensorboard"],
- ),
- MetricConfig(
- name="Progress/Episodes_Played",
- source="loop",
- aggregation="latest",
- log_frequency_steps=10,
- log_to=["mlflow", "tensorboard"],
- ),
- MetricConfig(
- name="System/Num_Active_Workers",
- source="loop",
- aggregation="latest",
- log_frequency_steps=50,
- log_to=["mlflow", "tensorboard"],
- ),
- MetricConfig(
- name="System/Num_Pending_Tasks",
- source="loop",
- aggregation="latest",
- log_frequency_steps=50,
- log_to=["mlflow", "tensorboard"],
- ),
- # --- Rate Metrics ---
- MetricConfig(
- name="Rate/Steps_Per_Sec",
- source="loop",
- aggregation="rate",
- rate_numerator_event="step_completed", # Assumes loop logs this event
- log_frequency_seconds=5.0,
- log_frequency_steps=0, # Primarily time-based
- log_to=["mlflow", "tensorboard", "console"],
- ),
- MetricConfig(
- name="Rate/Episodes_Per_Sec",
- source="worker",
- aggregation="rate",
- rate_numerator_event="episode_end", # Count episode_end events
- log_frequency_seconds=5.0,
- log_frequency_steps=0,
- log_to=["mlflow", "tensorboard", "console"],
- ),
- MetricConfig(
- name="Rate/Simulations_Per_Sec",
- source="worker", # Sims happen in worker MCTS steps
- aggregation="rate",
- rate_numerator_event="mcts_step", # Sum the 'value' of mcts_step events
- log_frequency_seconds=5.0,
- log_frequency_steps=0,
- log_to=["mlflow", "tensorboard", "console"],
- ),
- # --- PER Beta ---
- MetricConfig(
- name="PER/Beta",
- source="buffer", # Or loop, depending on where it's calculated
- aggregation="latest",
- log_frequency_steps=10,
- log_to=["mlflow", "tensorboard"],
- ),
- # --- Event Markers ---
- MetricConfig(
- name="Event/Weight_Update", # Matches old key for compatibility
- source="loop",
- aggregation="count", # Count occurrences per interval
- log_frequency_steps=10, # Log counts every 10 steps
- log_to=["mlflow", "tensorboard"], # Log count, not just marker
- ),
- ]
- )
-
- @field_validator("metrics")
- @classmethod
- def check_metric_names_unique(cls, metrics: list[MetricConfig]):
- """Ensure all configured metric names are unique."""
- names = [m.name for m in metrics]
- if len(names) != len(set(names)):
- from collections import Counter
-
- duplicates = [name for name, count in Counter(names).items() if count > 1]
- raise ValueError(f"Duplicate metric names found in config: {duplicates}")
- return metrics
-
-
-# Instantiate default config
-default_stats_config = StatsConfig()
-
-# Rebuild model
-StatsConfig.model_rebuild(force=True)
-MetricConfig.model_rebuild(force=True)
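
The metric declarations above now live in `trieye`. A hedged sketch of declaring one equivalent metric, assuming trieye's `MetricConfig` keeps the field names of the deleted module (neither the import path nor the fields are verified here):

```python
from trieye import MetricConfig  # import path and field names assumed

loss_total = MetricConfig(
    name="Loss/Total",
    source="trainer",
    aggregation="mean",
    log_frequency_steps=10,
    log_to=["mlflow", "tensorboard"],
)
```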
diff --git a/alphatriangle/config/validation.py b/alphatriangle/config/validation.py
index f6bc9a2..8aa9aed 100644
--- a/alphatriangle/config/validation.py
+++ b/alphatriangle/config/validation.py
@@ -7,33 +7,36 @@
# Import EnvConfig from trianglengin's top level
from trianglengin import EnvConfig
-# Import the new config
+# Import AlphaTriangle configs
from .mcts_config import AlphaTriangleMCTSConfig
from .model_config import ModelConfig
-from .persistence_config import PersistenceConfig
-from .stats_config import StatsConfig # ADDED
from .train_config import TrainConfig
+# Removed PersistenceConfig, StatsConfig imports
+
logger = logging.getLogger(__name__)
def print_config_info_and_validate(
mcts_config_instance: AlphaTriangleMCTSConfig | None,
):
- """Prints configuration summary and performs validation using Pydantic."""
+ """
+ Prints configuration summary and performs validation using Pydantic
+ for AlphaTriangle-specific configurations.
+ Note: TrieyeConfig validation happens within the Trieye library.
+ """
print("-" * 40)
- print("Configuration Validation & Summary")
+ print("AlphaTriangle Configuration Validation & Summary")
print("-" * 40)
all_valid = True
configs_validated: dict[str, Any] = {}
+ # Only validate AlphaTriangle-specific configs here
config_classes: dict[str, type[BaseModel]] = {
"Environment": EnvConfig,
"Model": ModelConfig,
"Training": TrainConfig,
- "Persistence": PersistenceConfig,
"MCTS": AlphaTriangleMCTSConfig,
- "Statistics": StatsConfig, # ADDED
}
for name, ConfigClass in config_classes.items():
@@ -93,5 +96,5 @@ def print_config_info_and_validate(
logger.critical("Configuration validation failed. Please check errors above.")
raise ValueError("Invalid configuration settings.")
else:
- logger.info("All configurations validated successfully.")
+ logger.info("AlphaTriangle configurations validated successfully.")
print("-" * 40)
diff --git a/alphatriangle/data/README.md b/alphatriangle/data/README.md
deleted file mode 100644
index 10d4363..0000000
--- a/alphatriangle/data/README.md
+++ /dev/null
@@ -1,56 +0,0 @@
-# File: alphatriangle/data/README.md
-
-# Data Management Module (`alphatriangle.data`)
-
-## Purpose and Architecture
-
-This module is responsible for handling the persistence of training artifacts using structured data schemas defined with Pydantic. It manages:
-
-- Neural network checkpoints (model weights, optimizer state).
-- Experience replay buffers.
-- Statistics collector state.
-- Run configuration files.
-
-All data is stored within the `.alphatriangle_data` directory at the project root. The core component is the [`DataManager`](data_manager.py) class, which centralizes file path management and saving/loading logic based on the [`PersistenceConfig`](../config/persistence_config.py) and [`TrainConfig`](../config/train_config.py). It uses `cloudpickle` for robust serialization of complex Python objects, including Pydantic models containing tensors and deques.
-
-- **Schemas ([`schemas.py`](schemas.py)):** Defines Pydantic models (`CheckpointData`, `BufferData`, `LoadedTrainingState`) to structure the data being saved and loaded, ensuring clarity and enabling validation.
-- **Path Management ([`path_manager.py`](path_manager.py)):** The `PathManager` class handles constructing file paths within `.alphatriangle_data`, creating directories, and finding previous runs.
-- **Serialization ([`serializer.py`](serializer.py)):** The `Serializer` class handles the actual reading/writing of files using `cloudpickle` and JSON, including validation during loading.
-- **Centralization:** `DataManager` provides a single point of control for saving/loading operations.
-- **Configuration-Driven:** Uses `PersistenceConfig` and `TrainConfig` to determine save locations, filenames, and loading behavior (e.g., auto-resume).
-- **Run Management:** Organizes saved artifacts into subdirectories within `.alphatriangle_data/runs/<run_name>/`.
-- **State Loading:** `DataManager.load_initial_state` determines the correct files, deserializes them, validates the structure, and returns a `LoadedTrainingState` object.
-- **State Saving:** `DataManager.save_training_state` assembles data into Pydantic models, serializes them, and saves to files within the run directory.
-- **MLflow Integration:** Logs saved artifacts (checkpoints, buffers, configs) to the corresponding MLflow run (located in `.alphatriangle_data/mlruns/`) after successful local saving. Checkpoints and buffers are logged with relative paths (e.g., `checkpoints/checkpoint_step_1000.pkl`). **The run configuration JSON is logged to the root artifact directory using its filename (e.g., `configs.json`).**
-- **Buffer Content:** The saved buffer file (`buffer.pkl`) contains `Experience` tuples `(StateType, PolicyTargetMapping, n_step_return)`. The `StateType` contains processed numerical features (`grid`, `other_features`). It does **not** contain raw `GameState` objects or action sequences for full interactive replay.
-
-## Exposed Interfaces
-
-- **Classes:**
- - `DataManager`: Orchestrates saving and loading.
- - `__init__(persist_config: PersistenceConfig, train_config: TrainConfig)`
- - `load_initial_state() -> LoadedTrainingState`: Loads state, returns Pydantic model.
- - `save_training_state(...)`: Saves state using Pydantic models and cloudpickle, logs to MLflow.
- - `save_run_config(configs: Dict[str, Any])`: Saves config JSON locally and logs to MLflow.
- - `get_checkpoint_path(...) -> Path`
- - `get_buffer_path(...) -> Path`
- - `PathManager`: Manages file paths within `.alphatriangle_data`.
- - `Serializer`: Handles serialization/deserialization.
- - `CheckpointData` (from `schemas.py`): Pydantic model for checkpoint structure.
- - `BufferData` (from `schemas.py`): Pydantic model for buffer structure.
- - `LoadedTrainingState` (from `schemas.py`): Pydantic model wrapping loaded data.
-
-## Dependencies
-
-- **[`alphatriangle.config`](../config/README.md)**: `PersistenceConfig`, `TrainConfig`.
-- **[`alphatriangle.nn`](../nn/README.md)**: `NeuralNetwork`.
-- **[`alphatriangle.rl`](../rl/README.md)**: `ExperienceBuffer`.
-- **[`alphatriangle.stats`](../stats/README.md)**: `StatsCollectorActor`.
-- **[`alphatriangle.utils`](../utils/README.md)**: `Experience`, `StateType`.
-- **`torch.optim`**: `Optimizer`.
-- **Standard Libraries:** `os`, `shutil`, `logging`, `glob`, `re`, `json`, `collections.deque`, `pathlib`, `datetime`.
-- **Third-Party:** `pydantic`, `cloudpickle`, `torch`, `ray`, `mlflow`, `numpy`.
-
----
-
-**Note:** Please keep this README updated when changing the Pydantic schemas, the types of artifacts managed, the saving/loading mechanisms, or the responsibilities of the `DataManager`, `PathManager`, or `Serializer`.
\ No newline at end of file
diff --git a/alphatriangle/data/__init__.py b/alphatriangle/data/__init__.py
deleted file mode 100644
index 5796a6e..0000000
--- a/alphatriangle/data/__init__.py
+++ /dev/null
@@ -1,19 +0,0 @@
-# File: alphatriangle/data/__init__.py
-"""
-Data management module for handling checkpoints, buffers, and potentially logs.
-Uses Pydantic schemas for data structure definition.
-"""
-
-from .data_manager import DataManager
-from .path_manager import PathManager
-from .schemas import BufferData, CheckpointData, LoadedTrainingState
-from .serializer import Serializer
-
-__all__ = [
- "DataManager",
- "PathManager",
- "Serializer",
- "CheckpointData",
- "BufferData",
- "LoadedTrainingState",
-]
diff --git a/alphatriangle/data/data_manager.py b/alphatriangle/data/data_manager.py
deleted file mode 100644
index 654f57f..0000000
--- a/alphatriangle/data/data_manager.py
+++ /dev/null
@@ -1,275 +0,0 @@
-# File: alphatriangle/data/data_manager.py
-import logging
-from pathlib import Path
-from typing import TYPE_CHECKING, Any
-
-import mlflow
-import ray # Import ray
-from pydantic import ValidationError
-
-from .path_manager import PathManager
-from .schemas import CheckpointData, LoadedTrainingState
-from .serializer import Serializer
-
-if TYPE_CHECKING:
- from torch.optim import Optimizer
-
- from ..config import PersistenceConfig, TrainConfig
- from ..nn import NeuralNetwork
- from ..rl.core.buffer import ExperienceBuffer
-
-logger = logging.getLogger(__name__)
-
-
-class DataManager:
- """
- Orchestrates loading and saving of training artifacts using PathManager and Serializer.
- Handles MLflow artifact logging. Saves step-specific files and updates links.
- """
-
- def __init__(
- self, persist_config: "PersistenceConfig", train_config: "TrainConfig"
- ):
- self.persist_config = persist_config
- self.train_config = train_config
- # Ensure PersistenceConfig reflects the current run name from TrainConfig
- self.persist_config.RUN_NAME = self.train_config.RUN_NAME
-
- self.path_manager = PathManager(self.persist_config)
- self.serializer = Serializer()
-
- self.path_manager.create_run_directories()
- logger.info(
- f"DataManager initialized. Current Run Name: {self.persist_config.RUN_NAME}. Run directory: {self.path_manager.run_base_dir}"
- )
-
- def load_initial_state(self) -> LoadedTrainingState:
- """
- Loads the initial training state using PathManager and Serializer.
- Returns a LoadedTrainingState object containing the deserialized data.
- Handles AUTO_RESUME_LATEST logic.
- """
- loaded_state = LoadedTrainingState()
- checkpoint_path = self.path_manager.determine_checkpoint_to_load(
- self.train_config.LOAD_CHECKPOINT_PATH,
- self.train_config.AUTO_RESUME_LATEST,
- )
- checkpoint_run_name: str | None = None
-
- # --- Load Checkpoint (Model + Optimizer + Stats) ---
- if checkpoint_path:
- logger.info(f"Loading checkpoint: {checkpoint_path}")
- try:
- loaded_checkpoint_model = self.serializer.load_checkpoint(
- checkpoint_path
- )
- if loaded_checkpoint_model:
- loaded_state.checkpoint_data = loaded_checkpoint_model
- checkpoint_run_name = (
- loaded_state.checkpoint_data.run_name
- ) # Store run name
- logger.info(
- f"Checkpoint loaded and validated (Run: {loaded_state.checkpoint_data.run_name}, Step: {loaded_state.checkpoint_data.global_step})"
- )
- else:
- logger.error(
- f"Loading checkpoint from {checkpoint_path} failed or returned None."
- )
- except (ValidationError, Exception) as e:
- logger.error(
- f"Error loading/validating checkpoint from {checkpoint_path}: {e}",
- exc_info=True,
- )
-
- # --- Load Buffer ---
- if self.persist_config.SAVE_BUFFER:
- buffer_path = self.path_manager.determine_buffer_to_load(
- self.train_config.LOAD_BUFFER_PATH,
- self.train_config.AUTO_RESUME_LATEST,
- checkpoint_run_name, # Pass run name from loaded checkpoint
- )
- if buffer_path:
- logger.info(f"Loading buffer: {buffer_path}")
- try:
- loaded_buffer_model = self.serializer.load_buffer(buffer_path)
- if loaded_buffer_model:
- loaded_state.buffer_data = loaded_buffer_model
- logger.info(
- f"Buffer loaded and validated. Size: {len(loaded_state.buffer_data.buffer_list)}"
- )
- else:
- logger.error(
- f"Loading buffer from {buffer_path} failed or returned None."
- )
- except (ValidationError, Exception) as e:
- logger.error(
- f"Failed to load/validate experience buffer from {buffer_path}: {e}",
- exc_info=True,
- )
-
- if not loaded_state.checkpoint_data and not loaded_state.buffer_data:
- logger.info("No checkpoint or buffer loaded. Starting fresh.")
-
- return loaded_state
-
- def save_training_state(
- self,
- nn: "NeuralNetwork",
- optimizer: "Optimizer",
- stats_collector_actor: ray.actor.ActorHandle | None,
- buffer: "ExperienceBuffer",
- global_step: int,
- episodes_played: int,
- total_simulations_run: int,
- is_best: bool = False,
- ):
- """
- Saves the training state using Serializer and PathManager.
- Saves step-specific files and updates 'latest'/'best' links.
- Logs artifacts to MLflow.
- """
- run_name = self.persist_config.RUN_NAME
- logger.info(
- f"Saving training state for run '{run_name}' at step {global_step}. Best={is_best}"
- )
-
- stats_collector_state = {}
- if stats_collector_actor is not None:
- try:
- stats_state_ref = stats_collector_actor.get_state.remote()
- stats_collector_state = ray.get(stats_state_ref, timeout=5.0)
- except Exception as e:
- logger.error(
- f"Error fetching state from StatsCollectorActor for saving: {e}",
- exc_info=True,
- )
-
- # --- Save Checkpoint ---
- saved_checkpoint_path: Path | None = None
- try:
- checkpoint_data = CheckpointData(
- run_name=run_name,
- global_step=global_step,
- episodes_played=episodes_played,
- total_simulations_run=total_simulations_run,
- model_config_dict=nn.model_config.model_dump(),
- env_config_dict=nn.env_config.model_dump(),
- model_state_dict=nn.get_weights(),
- optimizer_state_dict=self.serializer.prepare_optimizer_state(optimizer),
- stats_collector_state=stats_collector_state,
- )
- step_checkpoint_path = self.path_manager.get_checkpoint_path(
- step=global_step
- )
- self.serializer.save_checkpoint(checkpoint_data, step_checkpoint_path)
- saved_checkpoint_path = step_checkpoint_path
- self.path_manager.update_checkpoint_links(
- step_checkpoint_path, is_best=is_best
- )
- except ValidationError as e:
- logger.error(f"Failed to create CheckpointData model: {e}", exc_info=True)
- except Exception as e:
- logger.error(f"Failed to save checkpoint: {e}", exc_info=True)
-
- # --- Save Buffer ---
- saved_buffer_path: Path | None = None
- if self.persist_config.SAVE_BUFFER:
- try:
- buffer_data = self.serializer.prepare_buffer_data(buffer)
- if buffer_data:
- step_buffer_path = self.path_manager.get_buffer_path(
- step=global_step
- )
- self.serializer.save_buffer(buffer_data, step_buffer_path)
- saved_buffer_path = step_buffer_path
- self.path_manager.update_buffer_link(step_buffer_path)
- else:
- logger.warning("Buffer data preparation failed, buffer not saved.")
- except ValidationError as e:
- logger.error(f"Failed to create BufferData model: {e}", exc_info=True)
- except Exception as e:
- logger.error(f"Failed to save buffer: {e}", exc_info=True)
-
- # --- Log Artifacts to MLflow ---
- self._log_artifacts(saved_checkpoint_path, saved_buffer_path, is_best)
-
- def _log_artifacts(
- self,
- checkpoint_path: Path | None,
- buffer_path: Path | None,
- is_best: bool,
- ):
- """Logs saved step-specific files and links to MLflow using relative artifact paths."""
- try:
- # Log Checkpoint Artifacts
- if checkpoint_path and checkpoint_path.exists():
- # Define the artifact path within the MLflow run
- ckpt_artifact_dir = self.persist_config.CHECKPOINT_SAVE_DIR_NAME
- # Log the step-specific file
- mlflow.log_artifact(
- str(checkpoint_path), artifact_path=ckpt_artifact_dir
- )
- # Log the 'latest.pkl' link
- latest_path = self.path_manager.get_checkpoint_path(is_latest=True)
- if latest_path.exists():
- mlflow.log_artifact(
- str(latest_path), artifact_path=ckpt_artifact_dir
- )
- # Log the 'best.pkl' link if applicable
- if is_best:
- best_path = self.path_manager.get_checkpoint_path(is_best=True)
- if best_path.exists():
- mlflow.log_artifact(
- str(best_path), artifact_path=ckpt_artifact_dir
- )
- logger.info(
- f"Logged checkpoint artifacts (step, latest, best={is_best}) to MLflow path: {ckpt_artifact_dir}"
- )
-
- # Log Buffer Artifacts
- if buffer_path and buffer_path.exists():
- buffer_artifact_dir = self.persist_config.BUFFER_SAVE_DIR_NAME
- # Log the step-specific buffer file
- mlflow.log_artifact(str(buffer_path), artifact_path=buffer_artifact_dir)
- # Log the default buffer link ('buffer.pkl')
- default_buffer_path = self.path_manager.get_buffer_path()
- if default_buffer_path.exists():
- mlflow.log_artifact(
- str(default_buffer_path), artifact_path=buffer_artifact_dir
- )
- logger.info(
- f"Logged buffer artifacts (step, default link) to MLflow path: {buffer_artifact_dir}"
- )
- except Exception as e:
- logger.error(f"Failed to log artifacts to MLflow: {e}", exc_info=True)
-
- def save_run_config(self, configs: dict[str, Any]):
- """Saves the combined configuration dictionary as a JSON artifact."""
- try:
- config_path = self.path_manager.get_config_path()
- config_path.parent.mkdir(parents=True, exist_ok=True)
- self.serializer.save_config_json(configs, config_path)
- # Log the config file to the root of the MLflow artifacts using its filename
- mlflow.log_artifact(
- str(config_path), artifact_path=self.persist_config.CONFIG_FILENAME
- )
- logger.info(
- f"Run config saved locally to {config_path} and logged to MLflow as '{self.persist_config.CONFIG_FILENAME}'."
- )
- except Exception as e:
- logger.error(f"Failed to save/log run config JSON: {e}", exc_info=True)
-
- # --- Expose PathManager methods if needed ---
- def get_checkpoint_path(
- self,
- run_name: str | None = None,
- step: int | None = None,
- is_latest: bool = False,
- is_best: bool = False,
- ) -> Path:
- return self.path_manager.get_checkpoint_path(run_name, step, is_latest, is_best)
-
- def get_buffer_path(
- self, run_name: str | None = None, step: int | None = None
- ) -> Path:
- return self.path_manager.get_buffer_path(run_name, step)
diff --git a/alphatriangle/data/path_manager.py b/alphatriangle/data/path_manager.py
deleted file mode 100644
index 2931f68..0000000
--- a/alphatriangle/data/path_manager.py
+++ /dev/null
@@ -1,324 +0,0 @@
-import datetime
-import logging
-import re
-import shutil
-from pathlib import Path
-from typing import TYPE_CHECKING
-
-if TYPE_CHECKING:
- from ..config import PersistenceConfig
-
-logger = logging.getLogger(__name__)
-
-
-class PathManager:
- """Manages file paths, directory creation, and discovery for training runs."""
-
- def __init__(self, persist_config: "PersistenceConfig"):
- self.persist_config = persist_config
- # Resolve root data dir immediately
- self.root_data_dir = self._resolve_root_data_dir()
- self._update_paths() # Initialize paths based on config
-
- def _resolve_root_data_dir(self) -> Path:
- """Resolves ROOT_DATA_DIR to an absolute path relative to the project root."""
- project_root = Path.cwd() # Assume running from project root
- root_path = project_root / self.persist_config.ROOT_DATA_DIR
- return root_path.resolve()
-
- def _update_paths(self):
- """Updates paths based on the current RUN_NAME in persist_config."""
- self.run_base_dir = self.get_run_base_dir() # Use method to get absolute path
- self.checkpoint_dir = (
- self.run_base_dir / self.persist_config.CHECKPOINT_SAVE_DIR_NAME
- )
- self.buffer_dir = self.run_base_dir / self.persist_config.BUFFER_SAVE_DIR_NAME
- self.log_dir = self.run_base_dir / self.persist_config.LOG_DIR_NAME
- self.tb_log_dir = self.run_base_dir / self.persist_config.TENSORBOARD_DIR_NAME
- self.profile_dir = self.run_base_dir / self.persist_config.PROFILE_DIR_NAME
- self.config_path = self.run_base_dir / self.persist_config.CONFIG_FILENAME
-
- def get_runs_root_dir(self) -> Path:
- """Gets the absolute path to the directory containing all runs."""
- return self.root_data_dir / self.persist_config.RUNS_DIR_NAME
-
- def get_run_base_dir(self, run_name: str | None = None) -> Path:
- """Gets the absolute base directory path for a specific run."""
- runs_root = self.get_runs_root_dir()
- name = run_name if run_name else self.persist_config.RUN_NAME
- return runs_root / name
-
- def create_run_directories(self):
- """Creates necessary directories for the current run."""
- self.root_data_dir.mkdir(parents=True, exist_ok=True)
- self.run_base_dir.mkdir(parents=True, exist_ok=True)
- self.checkpoint_dir.mkdir(parents=True, exist_ok=True)
- self.log_dir.mkdir(parents=True, exist_ok=True)
- self.tb_log_dir.mkdir(parents=True, exist_ok=True)
- self.profile_dir.mkdir(parents=True, exist_ok=True) # Create profile dir
- if self.persist_config.SAVE_BUFFER:
- self.buffer_dir.mkdir(parents=True, exist_ok=True)
-
- def get_checkpoint_path(
- self,
- run_name: str | None = None,
- step: int | None = None,
- is_latest: bool = False,
- is_best: bool = False,
- ) -> Path:
- """Constructs the absolute path for a checkpoint file."""
- target_run_base_dir = self.get_run_base_dir(run_name)
- checkpoint_dir = (
- target_run_base_dir / self.persist_config.CHECKPOINT_SAVE_DIR_NAME
- )
-
- if is_latest:
- filename = self.persist_config.LATEST_CHECKPOINT_FILENAME
- elif is_best:
- filename = self.persist_config.BEST_CHECKPOINT_FILENAME
- elif step is not None:
- filename = f"checkpoint_step_{step}.pkl"
- else:
- # Default to latest if no specific type is given
- filename = self.persist_config.LATEST_CHECKPOINT_FILENAME
-
- filename_pkl = Path(filename).with_suffix(".pkl")
- return checkpoint_dir / filename_pkl
-
- def get_buffer_path(
- self, run_name: str | None = None, step: int | None = None
- ) -> Path:
- """Constructs the absolute path for the replay buffer file."""
- target_run_base_dir = self.get_run_base_dir(run_name)
- buffer_dir = target_run_base_dir / self.persist_config.BUFFER_SAVE_DIR_NAME
-
- if step is not None:
- filename = f"buffer_step_{step}.pkl"
- else:
- # Default name for the main buffer link/file
- filename = self.persist_config.BUFFER_FILENAME
-
- return buffer_dir / Path(filename).with_suffix(".pkl")
-
- def get_config_path(self, run_name: str | None = None) -> Path:
- """Constructs the absolute path for the config JSON file."""
- target_run_base_dir = self.get_run_base_dir(run_name)
- return target_run_base_dir / self.persist_config.CONFIG_FILENAME
-
- def get_profile_path(
- self, worker_id: int, episode_seed: int, run_name: str | None = None
- ) -> Path:
- """Constructs the absolute path for a profile data file."""
- target_run_base_dir = self.get_run_base_dir(run_name)
- profile_dir = target_run_base_dir / self.persist_config.PROFILE_DIR_NAME
- filename = f"worker_{worker_id}_ep_{episode_seed}.prof"
- return profile_dir / filename
-
- def find_latest_run_dir(self, current_run_name: str) -> str | None:
- """
- Finds the most recent *previous* run directory based on timestamp parsing.
- Assumes run names contain a 'YYYYMMDD_HHMMSS' pattern, potentially with prefixes/suffixes.
- """
- runs_root_dir = self.get_runs_root_dir()
- potential_runs: list[tuple[datetime.datetime, str]] = []
- # Regex: Find YYYYMMDD_HHMMSS pattern anywhere in the name
- run_name_pattern = re.compile(r"(\d{8}_\d{6})")
- logger.info(f"Searching for previous runs in: {runs_root_dir}")
- logger.info(f"Current run name to exclude: {current_run_name}")
- logger.info(f"Using regex pattern: {run_name_pattern.pattern}")
-
- try:
- if not runs_root_dir.exists():
- logger.info("Runs root directory does not exist.")
- return None
-
- found_dirs = list(runs_root_dir.iterdir())
- logger.debug(
- f"Found {len(found_dirs)} items in runs directory: {[d.name for d in found_dirs]}"
- )
-
- for d in found_dirs:
- logger.debug(f"Checking item: {d.name}")
- if d.is_dir() and d.name != current_run_name:
- logger.debug(
- f" '{d.name}' is a directory and not the current run."
- )
- match = run_name_pattern.search(d.name)
- if match:
- timestamp_str = match.group(1)
- logger.debug(
- f" Regex matched! Found timestamp '{timestamp_str}' in '{d.name}'"
- )
- try:
- run_time = datetime.datetime.strptime(
- timestamp_str, "%Y%m%d_%H%M%S"
- )
- potential_runs.append((run_time, d.name))
- logger.debug(
- f" Successfully parsed timestamp and added '{d.name}' to potential runs."
- )
- except ValueError:
- logger.warning(
- f"Could not parse timestamp '{timestamp_str}' from directory name: {d.name}"
- )
- else:
- logger.debug(
- f" Directory name {d.name} did not match timestamp pattern."
- )
- elif not d.is_dir():
- logger.debug(f" '{d.name}' is not a directory.")
- else:
- logger.debug(f" '{d.name}' is the current run directory.")
-
- if not potential_runs:
- logger.info(
- "No previous run directories found matching the pattern after filtering."
- )
- return None
-
- potential_runs.sort(key=lambda item: item[0], reverse=True)
- logger.debug(
- f"Sorted potential runs (most recent first): {[(dt.strftime('%Y%m%d_%H%M%S'), name) for dt, name in potential_runs]}"
- )
-
- latest_run_name = potential_runs[0][1]
- logger.info(f"Selected latest previous run: {latest_run_name}")
- return latest_run_name
-
- except Exception as e:
- logger.error(f"Error finding latest run directory: {e}", exc_info=True)
- return None
-
- def determine_checkpoint_to_load(
- self, load_path_config: str | None, auto_resume: bool
- ) -> Path | None:
- """Determines the absolute path of the checkpoint file to load."""
- current_run_name = self.persist_config.RUN_NAME
- checkpoint_to_load: Path | None = None
-
- if load_path_config:
- load_path = Path(load_path_config).resolve() # Resolve immediately
- if load_path.exists():
- checkpoint_to_load = load_path
- logger.info(f"Using specified checkpoint path: {checkpoint_to_load}")
- else:
- logger.warning(
- f"Specified checkpoint path not found: {load_path_config}"
- )
-
- if not checkpoint_to_load and auto_resume:
- logger.info(
- f"Attempting to find latest run directory to auto-resume from (current: {current_run_name})."
- )
- latest_run_name = self.find_latest_run_dir(current_run_name)
- if latest_run_name:
- potential_latest_path = self.get_checkpoint_path(
- run_name=latest_run_name, is_latest=True
- )
- if potential_latest_path.exists():
- checkpoint_to_load = potential_latest_path.resolve()
- logger.info(
- f"Auto-resuming from latest checkpoint in previous run '{latest_run_name}': {checkpoint_to_load}"
- )
- else:
- logger.info(
- f"Latest checkpoint file ('{self.persist_config.LATEST_CHECKPOINT_FILENAME}') not found in latest run directory '{latest_run_name}'."
- )
- else:
- logger.info("Auto-resume enabled, but no previous run directory found.")
-
- if not checkpoint_to_load:
- logger.info("No checkpoint found to load. Starting training from scratch.")
-
- return checkpoint_to_load
-
- def determine_buffer_to_load(
- self,
- load_path_config: str | None,
- auto_resume: bool,
- checkpoint_run_name: str | None,
- ) -> Path | None:
- """Determines the buffer file path to load."""
- buffer_to_load: Path | None = None
-
- if load_path_config:
- load_path = Path(load_path_config).resolve() # Resolve immediately
- if load_path.exists():
- logger.info(f"Using specified buffer path: {load_path_config}")
- buffer_to_load = load_path
- else:
- logger.warning(f"Specified buffer path not found: {load_path_config}")
-
- if not buffer_to_load and checkpoint_run_name:
- potential_buffer_path = self.get_buffer_path(run_name=checkpoint_run_name)
- if potential_buffer_path.exists():
- logger.info(
- f"Loading buffer from checkpoint run '{checkpoint_run_name}' (using default link): {potential_buffer_path}"
- )
- buffer_to_load = potential_buffer_path.resolve()
- else:
- logger.info(
- f"Default buffer file ('{self.persist_config.BUFFER_FILENAME}') not found in checkpoint run directory '{checkpoint_run_name}'."
- )
-
- if not buffer_to_load and auto_resume and not checkpoint_run_name:
- latest_previous_run_name = self.find_latest_run_dir(
- self.persist_config.RUN_NAME
- )
- if latest_previous_run_name:
- potential_buffer_path = self.get_buffer_path(
- run_name=latest_previous_run_name
- )
- if potential_buffer_path.exists():
- logger.info(
- f"Auto-resuming buffer from latest previous run '{latest_previous_run_name}' (no checkpoint loaded): {potential_buffer_path}"
- )
- buffer_to_load = potential_buffer_path.resolve()
- else:
- logger.info(
- f"Default buffer file not found in latest run directory '{latest_previous_run_name}'."
- )
-
- if not buffer_to_load:
- logger.info("No suitable buffer file found to load.")
-
- return buffer_to_load
-
- def update_checkpoint_links(self, step_checkpoint_path: Path, is_best: bool):
- """Updates the 'latest' and optionally 'best' checkpoint links."""
- if not step_checkpoint_path.exists():
- logger.error(
- f"Source checkpoint path does not exist: {step_checkpoint_path}"
- )
- return
-
- latest_path = self.get_checkpoint_path(is_latest=True)
- best_path = self.get_checkpoint_path(is_best=True)
- try:
- shutil.copy2(step_checkpoint_path, latest_path)
- logger.debug(f"Updated latest checkpoint link to {step_checkpoint_path}")
- except Exception as e:
- logger.error(f"Failed to update latest checkpoint link: {e}")
- if is_best:
- try:
- shutil.copy2(step_checkpoint_path, best_path)
- logger.info(
- f"Updated best checkpoint link to step {step_checkpoint_path.stem}"
- )
- except Exception as e:
- logger.error(f"Failed to update best checkpoint link: {e}")
-
- def update_buffer_link(self, step_buffer_path: Path):
- """Updates the default buffer link ('buffer.pkl')."""
- if not step_buffer_path.exists():
- logger.error(f"Source buffer path does not exist: {step_buffer_path}")
- return
-
- default_buffer_path = self.get_buffer_path()
- try:
- shutil.copy2(step_buffer_path, default_buffer_path)
- logger.debug(f"Updated default buffer file link: {default_buffer_path}")
- except Exception as e_default:
- logger.error(
- f"Error updating default buffer file {default_buffer_path}: {e_default}"
- )
diff --git a/alphatriangle/data/schemas.py b/alphatriangle/data/schemas.py
deleted file mode 100644
index 6f0ff35..0000000
--- a/alphatriangle/data/schemas.py
+++ /dev/null
@@ -1,45 +0,0 @@
-from typing import Any
-
-from pydantic import BaseModel, ConfigDict, Field
-
-# Use relative import
-from ..utils.types import Experience
-
-arbitrary_types_config = ConfigDict(arbitrary_types_allowed=True)
-
-
-class CheckpointData(BaseModel):
- """Pydantic model defining the structure of saved checkpoint data."""
-
- model_config = arbitrary_types_config
-
- run_name: str
- global_step: int = Field(..., ge=0)
- episodes_played: int = Field(..., ge=0)
- total_simulations_run: int = Field(..., ge=0)
- model_config_dict: dict[str, Any]
- env_config_dict: dict[str, Any]
- model_state_dict: dict[str, Any]
- optimizer_state_dict: dict[str, Any]
- stats_collector_state: dict[str, Any]
-
-
-class BufferData(BaseModel):
- """Pydantic model defining the structure of saved buffer data."""
-
- model_config = arbitrary_types_config
-
- buffer_list: list[Experience]
-
-
-class LoadedTrainingState(BaseModel):
- """Pydantic model representing the fully loaded state."""
-
- model_config = arbitrary_types_config
-
- checkpoint_data: CheckpointData | None = None
- buffer_data: BufferData | None = None
-
-
-BufferData.model_rebuild(force=True)
-LoadedTrainingState.model_rebuild(force=True)
diff --git a/alphatriangle/data/serializer.py b/alphatriangle/data/serializer.py
deleted file mode 100644
index ee0e8d1..0000000
--- a/alphatriangle/data/serializer.py
+++ /dev/null
@@ -1,243 +0,0 @@
-# File: alphatriangle/data/serializer.py
-import json
-import logging
-from collections import deque
-from pathlib import Path
-from typing import TYPE_CHECKING, Any # Import types
-
-import cloudpickle
-import numpy as np
-import torch
-from pydantic import ValidationError
-
-from ..utils.sumtree import SumTree
-from .schemas import BufferData, CheckpointData
-
-if TYPE_CHECKING:
- from torch.optim import Optimizer
-
- from ..rl.core.buffer import ExperienceBuffer
- from ..utils.types import StateType
-
-logger = logging.getLogger(__name__)
-
-
-class Serializer:
- """Handles serialization and deserialization of training data."""
-
- def _validate_experience_structure(self, exp: Any, index: int) -> bool:
- """Validates the structure of a single experience tuple."""
- try:
- if isinstance(exp, tuple) and len(exp) == 3:
- state_type: StateType = exp[0]
- policy_map = exp[1]
- value = exp[2]
- if (
- isinstance(state_type, dict)
- and "grid" in state_type
- and "other_features" in state_type
- # REMOVED: and "available_shapes_geometry" in state_type
- and isinstance(state_type["grid"], np.ndarray)
- and isinstance(state_type["other_features"], np.ndarray)
- # REMOVED: and isinstance(state_type["available_shapes_geometry"], list)
- and isinstance(policy_map, dict)
- and isinstance(value, float | int)
- ):
- return True
- else:
- logger.warning(
- f"Invalid types or missing keys in experience {index}: state_keys={list(state_type.keys()) if isinstance(state_type, dict) else type(state_type)}"
- )
- return False
- else:
- logger.warning(
- f"Experience {index} is not a tuple of length 3: type={type(exp)}"
- )
- return False
- except Exception as e:
- logger.error(
- f"Unexpected error validating experience {index}: {e}", exc_info=True
- )
- return False
-
- def load_checkpoint(self, path: Path) -> CheckpointData | None:
- """Loads and validates checkpoint data from a file."""
- try:
- with path.open("rb") as f:
- loaded_data = cloudpickle.load(f)
- if isinstance(loaded_data, CheckpointData):
- # Pydantic automatically validates on load if type matches
- return loaded_data
- else:
- logger.error(
- f"Loaded checkpoint file {path} did not contain a CheckpointData object (type: {type(loaded_data)})."
- )
- return None
- except ValidationError as e:
- logger.error(
- f"Pydantic validation failed for checkpoint {path}: {e}", exc_info=True
- )
- return None
- except FileNotFoundError:
- logger.warning(f"Checkpoint file not found: {path}")
- return None
- except Exception as e:
- logger.error(
- f"Error loading/deserializing checkpoint from {path}: {e}",
- exc_info=True,
- )
- return None
-
- def save_checkpoint(self, data: CheckpointData, path: Path):
- """Saves checkpoint data to a file using cloudpickle."""
- try:
- path.parent.mkdir(parents=True, exist_ok=True)
- with path.open("wb") as f:
- cloudpickle.dump(data, f)
- logger.info(f"Checkpoint data saved to {path}")
- except Exception as e:
- logger.error(
- f"Failed to save checkpoint file to {path}: {e}", exc_info=True
- )
- raise # Re-raise the exception
-
- def load_buffer(self, path: Path) -> BufferData | None:
- """Loads and validates buffer data from a file."""
- try:
- with path.open("rb") as f:
- loaded_data = cloudpickle.load(f)
- if isinstance(loaded_data, BufferData):
- # Perform validation on loaded experiences
- valid_experiences = []
- invalid_count = 0
- for i, exp in enumerate(loaded_data.buffer_list):
- if self._validate_experience_structure(exp, i):
- valid_experiences.append(exp)
- else:
- invalid_count += 1
- # Warning already logged in _validate_experience_structure
- if invalid_count > 0:
- logger.warning(
- f"Found {invalid_count} invalid experience structures in loaded buffer {path}."
- )
- loaded_data.buffer_list = valid_experiences
- return loaded_data
- else:
- logger.error(
- f"Loaded buffer file {path} did not contain a BufferData object (type: {type(loaded_data)})."
- )
- return None
- except ValidationError as e:
- logger.error(
- f"Pydantic validation failed for buffer {path}: {e}", exc_info=True
- )
- return None
- except FileNotFoundError:
- logger.warning(f"Buffer file not found: {path}")
- return None
- except Exception as e:
- logger.error(
- f"Failed to load/deserialize experience buffer from {path}: {e}",
- exc_info=True,
- )
- return None
-
- def save_buffer(self, data: BufferData, path: Path):
- """Saves buffer data to a file using cloudpickle."""
- try:
- path.parent.mkdir(parents=True, exist_ok=True)
- with path.open("wb") as f:
- cloudpickle.dump(data, f)
- logger.info(f"Buffer data saved to {path}")
- except Exception as e:
- logger.error(
- f"Error saving experience buffer to {path}: {e}", exc_info=True
- )
- raise # Re-raise the exception
-
- def prepare_optimizer_state(self, optimizer: "Optimizer") -> dict[str, Any]:
- """Prepares optimizer state dictionary, moving tensors to CPU."""
- optimizer_state_cpu = {}
- try:
- optimizer_state_dict = optimizer.state_dict()
-
- def move_to_cpu(item):
- if isinstance(item, torch.Tensor):
- return item.cpu()
- elif isinstance(item, dict):
- return {k: move_to_cpu(v) for k, v in item.items()}
- elif isinstance(item, list):
- return [move_to_cpu(elem) for elem in item]
- else:
- return item
-
- optimizer_state_cpu = move_to_cpu(optimizer_state_dict)
- except Exception as e:
- logger.error(f"Could not prepare optimizer state for saving: {e}")
- return optimizer_state_cpu
-
- def prepare_buffer_data(self, buffer: "ExperienceBuffer") -> BufferData | None:
- """Prepares buffer data for saving, extracting experiences."""
- try:
- if buffer.use_per:
- if hasattr(buffer, "tree") and isinstance(buffer.tree, SumTree):
- # Extract data, filtering out potential placeholders (like 0)
- buffer_list = [
- buffer.tree.data[i]
- for i in range(buffer.tree.n_entries)
- if buffer.tree.data[i] != 0
- ]
- else:
- logger.error("PER buffer tree is missing or invalid during save.")
- return None
- else:
- buffer_list = list(buffer.buffer)
-
- # Basic validation before creating BufferData
- valid_experiences = []
- invalid_count = 0
- for i, exp in enumerate(buffer_list):
- if self._validate_experience_structure(exp, i):
- valid_experiences.append(exp)
- else:
- invalid_count += 1
- # Warning logged in validation function
-
- if invalid_count > 0:
- logger.warning(
- f"Found {invalid_count} invalid experience structures before saving buffer."
- )
-
- return BufferData(buffer_list=valid_experiences)
- except Exception as e:
- logger.error(f"Error preparing buffer data for saving: {e}")
- return None
-
- def save_config_json(self, configs: dict[str, Any], path: Path):
- """Saves the configuration dictionary as JSON."""
- try:
- path.parent.mkdir(parents=True, exist_ok=True)
- with path.open("w") as f:
-
- def default_serializer(obj):
- if isinstance(obj, torch.Tensor | np.ndarray):
- return ""
- if isinstance(obj, deque):
- return list(obj)
- # Handle Path objects
- if isinstance(obj, Path):
- return str(obj)
- try:
- # Attempt standard JSON serialization first
- # Fallback to dict or str representation
- return obj.__dict__ if hasattr(obj, "__dict__") else str(obj)
- except TypeError:
- return f"