diff --git a/.coverage b/.coverage deleted file mode 100644 index f939db0..0000000 Binary files a/.coverage and /dev/null differ diff --git a/MANIFEST.in b/MANIFEST.in index 606a558..cf9999f 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,4 +1,6 @@ # File: MANIFEST.in +# File: alphatriangle/MANIFEST.in +# File: MANIFEST.in include README.md include LICENSE include requirements.txt @@ -7,4 +9,20 @@ graft tests include .python-version include pyproject.toml global-exclude __pycache__ -global-exclude *.py[co] \ No newline at end of file +global-exclude *.py[co] +# Remove pruned directories +prune alphatriangle/visualization +prune alphatriangle/interaction +# REMOVE environment and structs pruning (they are part of trianglengin now) +# prune alphatriangle/environment +# prune alphatriangle/structs +# Remove pruned files +global-exclude alphatriangle/app.py +# Remove pruned test directories +prune tests/visualization +prune tests/interaction +# REMOVE environment and structs test pruning +# prune tests/environment +# prune tests/structs +# Remove pruned core files +global-exclude alphatriangle/rl/core/visual_state_actor.py \ No newline at end of file diff --git a/README.md b/README.md index dc4c0fd..a113fe5 100644 --- a/README.md +++ b/README.md @@ -1,34 +1,160 @@ -[![CI/CD Status](https://github.com/lguibr/alphatriangle/actions/workflows/ci_cd.yml/badge.svg)](https://github.com/lguibr/alphatriangle/actions/workflows/ci_cd.yml) - [![codecov](https://codecov.io/gh/lguibr/alphatriangle/graph/badge.svg?token=YOUR_CODECOV_TOKEN_HERE)](https://codecov.io/gh/lguibr/alphatriangle) - [![PyPI version](https://badge.fury.io/py/alphatriangle.svg)](https://badge.fury.io/py/alphatriangle)[![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT) - [![Python Version](https://img.shields.io/badge/python-3.10+-blue.svg)](https://www.python.org/downloads/) +[![CI/CD 
Status](https://github.com/lguibr/alphatriangle/actions/workflows/ci_cd.yml/badge.svg)](https://github.com/lguibr/alphatriangle/actions/workflows/ci_cd.yml) - [![codecov](https://codecov.io/gh/lguibr/alphatriangle/graph/badge.svg?token=YOUR_CODECOV_TOKEN_HERE)](https://codecov.io/gh/lguibr/alphatriangle) - [![PyPI version](https://badge.fury.io/py/alphatriangle.svg)](https://badge.fury.io/py/alphatriangle)[![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT) - [![Python Version](https://img.shields.io/badge/python-3.10+-blue.svg)](https://www.python.org/downloads/) +# AlphaTriangle -# AlphaTriangle AlphaTriangle Logo - ## Overview -AlphaTriangle is a project implementing an artificial intelligence agent based on AlphaZero principles to learn and play a custom puzzle game involving placing triangular shapes onto a grid. The agent learns through self-play reinforcement learning, guided by Monte Carlo Tree Search (MCTS) and a deep neural network (PyTorch). + +AlphaTriangle is a project implementing an artificial intelligence agent based on AlphaZero principles to learn and play a custom puzzle game involving placing triangular shapes onto a grid. The agent learns through **headless self-play reinforcement learning**, guided by Monte Carlo Tree Search (MCTS) and a deep neural network (PyTorch). **It uses the `trianglengin` library for core game logic.** The project includes: -* A playable version of the triangle puzzle game using Pygame. + * An implementation of the MCTS algorithm tailored for the game. * A deep neural network (policy and value heads) implemented in PyTorch, featuring convolutional layers and **optional Transformer Encoder layers**. * A reinforcement learning pipeline coordinating **parallel self-play (using Ray)**, data storage, and network training, managed by the `alphatriangle.training` module. 
-* Visualization tools for interactive play, debugging, and monitoring training progress (**with near real-time plot updates**). -* Experiment tracking using MLflow. -* Unit tests for core components. -* A command-line interface for easy execution. +* Experiment tracking and visualization using **MLflow** and **TensorBoard**. +* Unit tests for RL components. +* A command-line interface for running the **headless** training pipeline. + +--- + +## 🎮 The Triangle Puzzle Game Guide 🧩 + +This project trains an agent to play the game defined by the `trianglengin` library. Here's a detailed explanation of the game rules: + +### 1. Introduction: Your Mission! 🎯 + +The goal is to place colorful shapes onto a special triangular grid. By filling up lines of triangles, you make them disappear and score points! Keep placing shapes and clearing lines for as long as possible to get the highest score before the grid fills up and you run out of moves. + +### 2. The Playing Field: The Grid 🗺️ + +- **Triangle Cells:** The game board is a grid made of many small triangles. Some point UP (🔺) and some point DOWN (🔻). They alternate like a checkerboard pattern based on their row and column index (specifically, `(row + col) % 2 != 0` means UP). +- **Shape:** The grid itself is rectangular overall, but the playable area within it is typically shaped like a triangle or hexagon, wider in the middle and narrower at the top and bottom. +- **Playable Area:** You can only place shapes within the designated playable area. +- **Death Zones 💀:** Around the edges of the playable area (often at the start and end of rows), some triangles are marked as "Death Zones". You **cannot** place any part of a shape onto these triangles. They are off-limits! Think of them as the boundaries within the rectangular grid. + +### 3. Your Tools: The Shapes 🟦🟥🟩 + +- **Shape Formation:** Each shape is a collection of connected small triangles (🔺 and 🔻). They come in different colors and arrangements. 
Some might be a single triangle, others might be long lines, L-shapes, or more complex patterns. +- **Relative Positions:** The triangles within a shape have fixed positions _relative to each other_. When you move the shape, all its triangles move together as one block. +- **Preview Area:** You will always have **three** shapes available to choose from at any time. These are shown in a special "preview area". + +### 4. Making Your Move: Placing Shapes 🖱️➡️▦ + +This is the core action! Here's exactly how to place a shape: + +- **Step 4a: Select a Shape:** Choose one of the three shapes available in the preview area. +- **Step 4b: Aim on the Grid:** Select a target coordinate `(row, col)` on the main grid. This coordinate represents the *anchor* point for placing the shape. +- **Step 4c: The Placement Rules (MUST Follow!)** + - 📏 **Rule 1: Fit Inside Playable Area:** ALL triangles of your chosen shape must land within the playable grid area. No part of the shape can land in a Death Zone 💀. + - 🧱 **Rule 2: No Overlap:** ALL triangles of your chosen shape must land on currently _empty_ spaces on the grid. You cannot place a shape on top of triangles that are already filled with color from previous shapes. + - 📐 **Rule 3: Orientation Match!** This is crucial! + - If a part of your shape is an UP triangle (🔺), it MUST land on an UP space (🔺) on the grid. + - If a part of your shape is a DOWN triangle (🔻), it MUST land on a DOWN space (🔻) on the grid. + - 🔺➡️🔺 (OK!) + - 🔻➡️🔻 (OK!) + - 🔺➡️🔻 (INVALID! ❌) + - 🔻➡️🔺 (INVALID! ❌) +- **Step 4d: Confirm Placement:** If the chosen shape can be placed at the target coordinate according to ALL three rules, the placement is valid. The shape is now placed permanently on the grid! ✨ + +### 5. Scoring Points: How You Win! 🏆 + +You score points in two main ways: + +- **Placing Triangles:** You get a small number of points for _every single small triangle_ that makes up the shape you just placed. 
(e.g., placing a 3-triangle shape might give you 3 \* tiny_score points). +- **Clearing Lines:** This is where the BIG points come from! You get a much larger number of points for _every single small triangle_ that disappears when you clear a line (or multiple lines at once!). See the next section for details! + +### 6. Line Clearing Magic! ✨ (The Key to High Scores!) + +This is the most exciting part! When you place a shape, the game immediately checks if you've completed any lines. This section explains how the game _finds_ and _clears_ these lines. + +- **What Lines Can Be Cleared?** There are **three** types of lines the game looks for: + + - **Horizontal Lines ↔️:** A straight, unbroken line of filled triangles going across a single row. + - **Diagonal Lines (Top-Left to Bottom-Right) ↘️:** An unbroken diagonal line of filled triangles stepping down and to the right. + - **Diagonal Lines (Bottom-Left to Top-Right) ↗️:** An unbroken diagonal line of filled triangles stepping up and to the right. + +- **How Lines are Found: Pre-calculation of Maximal Lines** + + - **The Idea:** Instead of checking every possible line combination all the time, the game pre-calculates all *maximal* continuous lines of playable triangles when it starts. A **maximal line** is the longest possible straight segment of *playable* triangles (not in a Death Zone) in one of the three directions (Horizontal, Diagonal ↘️, Diagonal ↗️). + - **Tracing:** For every playable triangle on the grid, the game traces outwards in each of the three directions to find the full extent of the continuous playable line passing through that triangle in that direction. + - **Storing Maximal Lines:** Only the complete maximal lines found are stored. For example, if tracing finds a playable sequence `A-B-C-D`, only the line `(A,B,C,D)` is stored, not the sub-segments like `(A,B,C)` or `(B,C,D)`. These maximal lines represent the *potential* lines that can be cleared. 
+ - **Coordinate Map:** The game also builds a map linking each playable triangle coordinate `(r, c)` to the set of maximal lines it belongs to. This allows for quick lookup. + +- **Defining the Paths (Neighbor Logic):** How does the game know which triangle is "next" when tracing? It depends on the current triangle's orientation (🔺 or 🔻) and the direction being traced: + + - **Horizontal ↔️:** + - Left Neighbor: `(r, c-1)` (Always in the same row) + - Right Neighbor: `(r, c+1)` (Always in the same row) + - **Diagonal ↘️ (TL-BR):** + - If current is 🔺 (Up): Next is `(r+1, c)` (Down triangle directly below) + - If current is 🔻 (Down): Next is `(r, c+1)` (Up triangle to the right) + - **Diagonal ↗️ (BL-TR):** + - If current is 🔻 (Down): Next is `(r-1, c)` (Up triangle directly above) + - If current is 🔺 (Up): Next is `(r, c+1)` (Down triangle to the right) + +- **Visualizing the Paths:** + + - **Horizontal ↔️:** + ``` + ... [🔻][🔺][🔻][🔺][🔻][🔺] ... (Moves left/right in the same row) + ``` + - **Diagonal ↘️ (TL-BR):** (Connects via shared horizontal edges) + ``` + ...[🔺]... + ...[🔻][🔺] ... + ... [🔻][🔺] ... + ... [🔻] ... + (Path alternates row/col increments depending on orientation) + ``` + - **Diagonal ↗️ (BL-TR):** (Connects via shared horizontal edges) + ``` + ... [🔺] ... + ... [🔺][🔻] ... + ... [🔺][🔻] ... + ... [🔻] ... + (Path alternates row/col increments depending on orientation) + ``` + +- **The "Full Line" Rule:** After you place a piece, the game looks at the coordinates `(r, c)` of the triangles you just placed. Using the pre-calculated map, it finds all the *maximal* lines that contain _any_ of those coordinates. For each of those maximal lines (that have at least 2 triangles), it checks: "Is _every single triangle coordinate_ in this maximal line now occupied?" If yes, that line is complete! (Note: Single isolated triangles don't count as clearable lines). + +- **The _Poof_! 
💨:** + - If placing your shape completes one or MORE maximal lines (of any type, length >= 2) simultaneously, all the triangles in ALL completed lines vanish instantly! + - The spaces become empty again. + - You score points for _every single triangle_ that vanished. Clearing multiple lines at once is the best way to rack up points! 🥳 + +### 7. Getting New Shapes: The Refill 🪄 + +- **The Trigger:** The game only gives you new shapes when a specific condition is met. +- **The Condition:** New shapes appear **only when all three of your preview slots become empty at the exact same time.** +- **How it Happens:** This usually occurs right after you place your _last_ available shape (the third one). +- **The Refill:** As soon as the third slot becomes empty, _BAM!_ 🪄 Three brand new, randomly generated shapes instantly appear in the preview slots. +- **Important:** If you place a shape and only one or two slots are empty, you **do not** get new shapes yet. You must use up all three before the refill happens. + +### 8. The End of the Road: Game Over 😭 + +So, how does the game end? + +- **The Condition:** The game is over when you **cannot legally place _any_ of the three shapes currently available in your preview slots anywhere on the grid.** +- **The Check:** After every move (placing a shape and any resulting line clears), and after any potential shape refill, the game checks: "Is there at least one valid spot on the grid for Shape 1? OR for Shape 2? OR for Shape 3?" +- **No More Moves:** If the answer is "NO" for all three shapes (meaning none of them can be placed anywhere according to the Placement Rules), then the game immediately ends. +- **Strategy:** This means you need to be careful! Don't fill up the grid in a way that leaves no room for the types of shapes you might get later. Always try to keep options open! 🤔 + +--- ## Core Technologies * **Python 3.10+** -* **Pygame:** For game visualization and interactive modes. 
+* **trianglengin:** Core game engine (state, actions, rules). * **PyTorch:** For the deep learning model (CNNs, **optional Transformers**, Distributional Value Head) and training, with CUDA/MPS support. -* **NumPy:** For numerical operations, especially state representation. +* **NumPy:** For numerical operations, especially state representation (used by `trianglengin` and features). * **Ray:** For parallelizing self-play data generation and statistics collection across multiple CPU cores/processes. * **Numba:** (Optional, used in `features.grid_features`) For performance optimization of specific grid calculations. * **Cloudpickle:** For serializing the experience replay buffer and training checkpoints. -* **MLflow:** For logging parameters, metrics, and artifacts (checkpoints, buffers) during training runs. +* **MLflow:** For logging parameters, metrics, and artifacts (checkpoints, buffers) during training runs. **Provides the primary web UI dashboard for experiment management.** +* **TensorBoard:** For visualizing metrics during training (e.g., detailed loss curves). **Provides a secondary web UI dashboard, often with faster graph updates.** * **Pydantic:** For configuration management and data validation. * **Typer:** For the command-line interface. * **Pytest:** For running unit tests. 
@@ -40,75 +166,63 @@ The project includes: ├── .github/workflows/ # GitHub Actions CI/CD │ └── ci_cd.yml ├── .alphatriangle_data/ # Root directory for ALL persistent data (GITIGNORED) -│ ├── mlruns/ # MLflow tracking data -│ └── runs/ # Stores temporary/local artifacts per run +│ ├── mlruns/ # MLflow internal tracking data & artifact store (for UI) +│ └── runs/ # Local artifacts per run (checkpoints, buffers, TB logs, configs) │ └── / -│ ├── checkpoints/ -│ ├── buffers/ -│ ├── logs/ -│ └── configs.json -├── alphatriangle/ # Source code for the project package +│ ├── checkpoints/ # Saved model weights & optimizer states +│ ├── buffers/ # Saved experience replay buffers +│ ├── logs/ # Plain text log files for the run +│ ├── tensorboard/ # TensorBoard log files (scalars, etc.) +│ └── configs.json # Copy of run configuration +├── alphatriangle/ # Source code for the AlphaZero agent package │ ├── __init__.py -│ ├── app.py -│ ├── cli.py # CLI logic -│ ├── config/ # Pydantic configuration models +│ ├── cli.py # CLI logic (train command - headless only) +│ ├── config/ # Pydantic configuration models (MCTS, Model, Train, Persistence) │ │ └── README.md -│ ├── data/ # Data saving/loading logic +│ ├── data/ # Data saving/loading logic (DataManager, Schemas) │ │ └── README.md -│ ├── environment/ # Game rules, state, actions +│ ├── features/ # Feature extraction logic (operates on trianglengin.GameState) │ │ └── README.md -│ ├── features/ # Feature extraction logic -│ │ └── README.md -│ ├── interaction/ # User input handling -│ │ └── README.md -│ ├── mcts/ # Monte Carlo Tree Search +│ ├── mcts/ # Monte Carlo Tree Search (operates on trianglengin.GameState) │ │ └── README.md │ ├── nn/ # Neural network definition and wrapper │ │ └── README.md │ ├── rl/ # RL components (Trainer, Buffer, Worker) │ │ └── README.md -│ ├── stats/ # Statistics collection and plotting -│ │ └── README.md -│ ├── structs/ # Core data structures (Triangle, Shape) +│ ├── stats/ # Statistics collection actor 
(StatsCollectorActor) │ │ └── README.md -│ ├── training/ # Training orchestration (Loop, Setup, Runners) +│ ├── training/ # Training orchestration (Loop, Setup, Runner) │ │ └── README.md -│ ├── utils/ # Shared utilities and types -│ │ └── README.md -│ └── visualization/ # Pygame rendering components +│ └── utils/ # Shared utilities and types (specific to AlphaTriangle) │ └── README.md -├── tests/ # Unit tests -│ ├── ... +├── tests/ # Unit tests (for alphatriangle components) +│ ├── conftest.py +│ ├── mcts/ +│ ├── nn/ +│ ├── rl/ +│ ├── stats/ +│ └── training/ ├── .gitignore ├── .python-version ├── LICENSE # License file (MIT) ├── MANIFEST.in # Specifies files for source distribution -├── pyproject.toml # Build system & package configuration +├── pyproject.toml # Build system & package configuration (depends on trianglengin) ├── README.md # This file -├── requirements.txt # List of dependencies (also in pyproject.toml) -├── run_interactive.py # Legacy script to run interactive modes -├── run_shape_editor.py # Script to run the interactive shape definition tool -├── run_training_headless.py # Legacy script for headless training -└── run_training_visual.py # Legacy script for visual training +└── requirements.txt # List of dependencies (includes trianglengin) ``` ## Key Modules (`alphatriangle`) -* **`cli`:** Defines the command-line interface using Typer. ([`alphatriangle/cli.py`](alphatriangle/cli.py)) -* **`config`:** Centralized Pydantic configuration classes. ([`alphatriangle/config/README.md`](alphatriangle/config/README.md)) -* **`structs`:** Defines core, low-level data structures (`Triangle`, `Shape`) and constants. ([`alphatriangle/structs/README.md`](alphatriangle/structs/README.md)) -* **`environment`:** Defines the game rules, `GameState`, action encoding/decoding, and grid/shape *logic*. 
([`alphatriangle/environment/README.md`](alphatriangle/environment/README.md)) -* **`features`:** Contains logic to convert `GameState` objects into numerical features (`StateType`). ([`alphatriangle/features/README.md`](alphatriangle/features/README.md)) +* **`cli`:** Defines the command-line interface using Typer (**only `train` command, headless**). ([`alphatriangle/cli.py`](alphatriangle/cli.py)) +* **`config`:** Centralized Pydantic configuration classes (excluding `EnvConfig` and `DisplayConfig`). ([`alphatriangle/config/README.md`](alphatriangle/config/README.md)) +* **`features`:** Contains logic to convert `trianglengin.GameState` objects into numerical features (`StateType`). ([`alphatriangle/features/README.md`](alphatriangle/features/README.md)) * **`nn`:** Contains the PyTorch `nn.Module` definition (`AlphaTriangleNet`) and a wrapper class (`NeuralNetwork`). ([`alphatriangle/nn/README.md`](alphatriangle/nn/README.md)) -* **`mcts`:** Implements the Monte Carlo Tree Search algorithm (`Node`, `run_mcts_simulations`). ([`alphatriangle/mcts/README.md`](alphatriangle/mcts/README.md)) -* **`rl`:** Contains RL components: `Trainer` (network updates), `ExperienceBuffer` (data storage, **supports PER**), and `SelfPlayWorker` (Ray actor for parallel self-play). ([`alphatriangle/rl/README.md`](alphatriangle/rl/README.md)) -* **`training`:** Orchestrates the training process using `TrainingLoop`, managing workers, data flow, logging, and checkpoints. Includes `runners.py` for callable training functions. ([`alphatriangle/training/README.md`](alphatriangle/training/README.md)) -* **`stats`:** Contains the `StatsCollectorActor` (Ray actor) for asynchronous statistics collection and the `Plotter` class for rendering plots. ([`alphatriangle/stats/README.md`](alphatriangle/stats/README.md)) -* **`visualization`:** Uses Pygame to render the game state, previews, HUD, plots, etc. `DashboardRenderer` handles the training visualization layout. 
([`alphatriangle/visualization/README.md`](alphatriangle/visualization/README.md)) -* **`interaction`:** Handles keyboard/mouse input for interactive modes via `InputHandler`. ([`alphatriangle/interaction/README.md`](alphatriangle/interaction/README.md)) +* **`mcts`:** Implements the Monte Carlo Tree Search algorithm (`Node`, `run_mcts_simulations`), operating on `trianglengin.GameState`. ([`alphatriangle/mcts/README.md`](alphatriangle/mcts/README.md)) +* **`rl`:** Contains RL components: `Trainer` (network updates), `ExperienceBuffer` (data storage, **supports PER**), and `SelfPlayWorker` (Ray actor for parallel self-play using `trianglengin.GameState`). ([`alphatriangle/rl/README.md`](alphatriangle/rl/README.md)) +* **`training`:** Orchestrates the **headless** training process using `TrainingLoop`, managing workers, data flow, logging (to console, file, MLflow, TensorBoard), and checkpoints. Includes `runner.py` for the callable training function. ([`alphatriangle/training/README.md`](alphatriangle/training/README.md)) +* **`stats`:** Contains the `StatsCollectorActor` (Ray actor) for asynchronous statistics collection. ([`alphatriangle/stats/README.md`](alphatriangle/stats/README.md)) * **`data`:** Manages saving and loading of training artifacts (`DataManager`) using Pydantic schemas and `cloudpickle`. ([`alphatriangle/data/README.md`](alphatriangle/data/README.md)) -* **`utils`:** Provides common helper functions, shared type definitions, and geometry helpers. ([`alphatriangle/utils/README.md`](alphatriangle/utils/README.md)) -* **`app`:** Integrates components for interactive modes (`run_interactive.py`). ([`alphatriangle/app.py`](alphatriangle/app.py)) +* **`utils`:** Provides common helper functions and shared type definitions specific to the AlphaZero implementation. 
([`alphatriangle/utils/README.md`](alphatriangle/utils/README.md)) ## Setup @@ -122,17 +236,27 @@ The project includes: python -m venv venv source venv/bin/activate # On Windows use `venv\Scripts\activate` ``` -3. **Install the package:** +3. **Install the package (including `trianglengin`):** * **For users:** ```bash - pip install alphatriangle # Or pip install git+https://github.com/lguibr/alphatriangle.git + # This will automatically install trianglengin from PyPI if available + pip install alphatriangle + # Or install directly from Git (installs trianglengin from PyPI) + # pip install git+https://github.com/lguibr/alphatriangle.git ``` * **For developers (editable install):** - ```bash - pip install -e . - # Install dev dependencies (optional, for running tests/linting) - pip install pytest pytest-cov pytest-mock ruff mypy codecov twine build - ``` + * First, ensure `trianglengin` is installed (ideally in editable mode from its own directory if developing both): + ```bash + # From the trianglengin directory: + # pip install -e . + ``` + * Then, install `alphatriangle` in editable mode: + ```bash + # From the alphatriangle directory: + pip install -e . + # Install dev dependencies (optional, for running tests/linting) + pip install -e .[dev] # Installs dev deps from pyproject.toml + ``` *Note: Ensure you have the correct PyTorch version installed for your system (CPU/CUDA/MPS). See [pytorch.org](https://pytorch.org/). Ray may have specific system requirements.* 4. 
**(Optional) Add data directory to `.gitignore`:** Create or edit the `.gitignore` file in your project root and add the line: @@ -142,40 +266,42 @@ The project includes: ## Running the Code (CLI) -Use the `alphatriangle` command: +Use the `alphatriangle` command for training: * **Show Help:** ```bash alphatriangle --help ``` -* **Interactive Play Mode:** - ```bash - alphatriangle play [--seed 42] [--log-level INFO] - ``` -* **Interactive Debug Mode:** - ```bash - alphatriangle debug [--seed 42] [--log-level DEBUG] - ``` -* **Run Training (Visual Mode):** +* **Run Training (Headless Only):** ```bash alphatriangle train [--seed 42] [--log-level INFO] ``` -* **Run Training (Headless Mode):** +* **Interactive Play/Debug (Use `trianglengin` CLI):** + *Note: Interactive modes are part of the `trianglengin` library, not this `alphatriangle` package.* ```bash - alphatriangle train --headless [--seed 42] [--log-level INFO] - # or - alphatriangle train -H [--seed 42] [--log-level INFO] + # Ensure trianglengin is installed + trianglengin play [--seed 42] [--log-level INFO] + trianglengin debug [--seed 42] [--log-level DEBUG] ``` -* **Shape Editor (Run directly):** - ```bash - python run_shape_editor.py - ``` -* **Monitoring Training (MLflow UI):** - While training (headless or visual), or after runs have completed, open a separate terminal in the project root and run: - ```bash - mlflow ui --backend-store-uri file:./.alphatriangle_data/mlruns - ``` - Then navigate to `http://localhost:5000` (or the specified port) in your browser. +* **Monitoring Training (Web Dashboards):** + This project uses **MLflow** and **TensorBoard** to provide web-based dashboards for monitoring. It's recommended to run both concurrently for the best experience: + * **MLflow UI (Experiment Overview & Artifacts):** + Provides the main dashboard for comparing runs, viewing parameters, high-level metrics, and accessing saved artifacts (checkpoints, buffers). 
Updates occur as data is logged, but may require a browser refresh for the latest points. + ```bash + # Run from the project root directory + mlflow ui --backend-store-uri file:./.alphatriangle_data/mlruns + ``` + Access via `http://localhost:5000`. + * **TensorBoard (Near Real-Time Graphs):** + Offers more frequently updated graphs of scalar metrics (losses, rates, etc.) during a run, making it ideal for closely monitoring training progress. + ```bash + # Run from the project root directory, pointing to the *specific run's* TB logs + tensorboard --logdir .alphatriangle_data/runs//tensorboard + # Replace with the actual name (e.g., train_20240101_120000) + # You can also point to the parent 'runs' directory to see all runs: + # tensorboard --logdir .alphatriangle_data/runs + ``` + Access via `http://localhost:6006`. * **Running Unit Tests (Development):** ```bash pytest tests/ @@ -183,11 +309,13 @@ Use the `alphatriangle` command: ## Configuration -All major parameters are defined in the Pydantic classes within the `alphatriangle/config/` directory. Modify these files to experiment with different settings. The `alphatriangle/config/validation.py` script performs basic checks on startup. +All major parameters for the AlphaZero agent (MCTS, Model, Training, Persistence) are defined in the Pydantic classes within the `alphatriangle/config/` directory. Modify these files to experiment with different settings. Environment configuration (`EnvConfig`) is defined within the `trianglengin` library. ## Data Storage -All persistent data, including MLflow tracking data and run-specific artifacts, is stored within the `.alphatriangle_data/` directory in the project root, managed by the `DataManager` and MLflow. +All persistent data is stored within the `.alphatriangle_data/` directory in the project root. +* **`.alphatriangle_data/mlruns/`**: Managed by **MLflow**. Contains MLflow's internal tracking data (parameters, metrics) and its own copy of logged artifacts. 
This is the source for the MLflow UI. +* **`.alphatriangle_data/runs/`**: Managed by **DataManager**. Contains locally saved artifacts for each run (checkpoints, buffers, TensorBoard logs, configs) before/during logging to MLflow. This directory is used for auto-resuming and direct access to TensorBoard logs during a run. ## Maintainability diff --git a/alphatriangle/app.py b/alphatriangle/app.py deleted file mode 100644 index 331f8e6..0000000 --- a/alphatriangle/app.py +++ /dev/null @@ -1,113 +0,0 @@ -import logging - -import pygame - -from . import ( - config, - environment, - interaction, - visualization, -) - -logger = logging.getLogger(__name__) - - -class Application: - """Main application integrating visualization and interaction.""" - - def __init__(self, mode: str = "play"): - self.vis_config = config.VisConfig() - self.env_config = config.EnvConfig() - self.mode = mode - - pygame.init() - pygame.font.init() - self.screen = self._setup_screen() - self.clock = pygame.time.Clock() - self.fonts = visualization.load_fonts() - - if self.mode in ["play", "debug"]: - # Create GameState first - self.game_state = environment.GameState(self.env_config) - # Create Visualizer - self.visualizer = visualization.Visualizer( - self.screen, self.vis_config, self.env_config, self.fonts - ) - # Create InputHandler, passing GameState and Visualizer - self.input_handler = interaction.InputHandler( - self.game_state, self.visualizer, self.mode, self.env_config - ) - else: - # Handle other modes or raise error if necessary - logger.error(f"Unsupported application mode: {self.mode}") - raise ValueError(f"Unsupported application mode: {self.mode}") - - self.running = True - - def _setup_screen(self) -> pygame.Surface: - """Initializes the Pygame screen.""" - screen = pygame.display.set_mode( - (self.vis_config.SCREEN_WIDTH, self.vis_config.SCREEN_HEIGHT), - pygame.RESIZABLE, - ) - pygame.display.set_caption(f"{config.APP_NAME} - {self.mode.capitalize()} Mode") - return screen - - 
def run(self): - """Main application loop.""" - logger.info(f"Starting application in {self.mode} mode.") - while self.running: - # dt = ( # Unused variable - # self.clock.tick(self.vis_config.FPS) / 1000.0 - # ) # Delta time (unused currently) - self.clock.tick(self.vis_config.FPS) # Still tick the clock - - # Handle Input using InputHandler - if self.input_handler: - self.running = self.input_handler.handle_input() - if not self.running: - break # Exit loop if handle_input returns False - else: - # Fallback event handling if input_handler is not initialized (should not happen in play/debug) - for event in pygame.event.get(): - if event.type == pygame.QUIT: - self.running = False - if event.type == pygame.KEYDOWN and event.key == pygame.K_ESCAPE: - self.running = False - # Basic resize handling needed even without input handler - # Combine nested if statements - if event.type == pygame.VIDEORESIZE and self.visualizer: - try: - w, h = max(320, event.w), max(240, event.h) - # Update visualizer's screen reference - self.visualizer.screen = pygame.display.set_mode( - (w, h), pygame.RESIZABLE - ) - # Invalidate visualizer's layout cache - self.visualizer.layout_rects = None - except pygame.error as e: - logger.error(f"Error resizing window: {e}") - if not self.running: - break - - # Render using Visualizer - if ( - self.mode in ["play", "debug"] - and self.visualizer - and self.game_state - and self.input_handler - ): - # Get interaction state needed for rendering from InputHandler - interaction_render_state = ( - self.input_handler.get_render_interaction_state() - ) - # Pass game state, mode, and interaction state to visualizer - self.visualizer.render( - self.game_state, - self.mode, - **interaction_render_state, # Unpack the dict as keyword arguments - ) - pygame.display.flip() # Update the full display - - logger.info("Application loop finished.") - pygame.quit() diff --git a/alphatriangle/cli.py b/alphatriangle/cli.py index 98567d1..db38316 100644 --- 
a/alphatriangle/cli.py +++ b/alphatriangle/cli.py @@ -1,29 +1,22 @@ -# File: alphatriangle/cli.py import logging import sys from typing import Annotated import typer -from alphatriangle import config, utils -from alphatriangle.app import Application +# Import EnvConfig from trianglengin +# Import alphatriangle specific configs and runner from alphatriangle.config import ( - MCTSConfig, PersistenceConfig, TrainConfig, ) -# --- REVERTED: Import from the re-exporting runners.py --- -from alphatriangle.training.runners import ( - run_training_headless_mode, - run_training_visual_mode, -) - -# --- END REVERTED --- +# Import the single runner function +from alphatriangle.training.runners import run_training app = typer.Typer( name="alphatriangle", - help="AlphaZero implementation for a triangle puzzle game.", + help="AlphaZero training pipeline for a triangle puzzle game (uses trianglengin).", add_completion=False, ) @@ -64,87 +57,39 @@ def setup_logging(log_level_str: str): handlers=[logging.StreamHandler(sys.stdout)], force=True, # Override existing config ) - logging.getLogger("ray").setLevel(logging.WARNING) # Keep Ray less verbose - logging.getLogger("matplotlib").setLevel( - logging.WARNING - ) # Keep Matplotlib less verbose + logging.getLogger("ray").setLevel(logging.WARNING) + # Add trianglengin logger control if needed + logging.getLogger("trianglengin").setLevel( + logging.INFO + ) # Example: Set engine log level logging.info(f"Root logger level set to {logging.getLevelName(log_level)}") -def run_interactive_mode(mode: str, seed: int, log_level: str): - """Runs the interactive application.""" - setup_logging(log_level) - logger = logging.getLogger(__name__) # Get logger after setup - logger.info(f"Running in {mode.capitalize()} mode...") - utils.set_random_seeds(seed) - - mcts_config = MCTSConfig() - config.print_config_info_and_validate(mcts_config) - - try: - app_instance = Application(mode=mode) - app_instance.run() - except ImportError as e: - 
logger.error(f"Runtime ImportError: {e}") - logger.error("Please ensure all dependencies are installed.") - sys.exit(1) - except Exception as e: - logger.critical(f"An unhandled error occurred: {e}", exc_info=True) - sys.exit(1) - - logger.info("Exiting.") - sys.exit(0) - - -@app.command() -def play( - log_level: LogLevelOption = "INFO", - seed: SeedOption = 42, -): - """Run the game in interactive Play mode.""" - run_interactive_mode(mode="play", seed=seed, log_level=log_level) - - -@app.command() -def debug( - log_level: LogLevelOption = "INFO", - seed: SeedOption = 42, -): - """Run the game in interactive Debug mode.""" - run_interactive_mode(mode="debug", seed=seed, log_level=log_level) +# --- REMOVED run_interactive_mode, play, debug commands --- @app.command() def train( - headless: Annotated[ - bool, - typer.Option("--headless", "-H", help="Run training without visualization."), - ] = False, + # REMOVE headless option - it's always headless now log_level: LogLevelOption = "INFO", seed: SeedOption = 42, ): - """Run the AlphaTriangle training pipeline.""" + """Run the AlphaTriangle training pipeline (headless).""" setup_logging(log_level) logger = logging.getLogger(__name__) + # Use alphatriangle configs here train_config_override = TrainConfig() persist_config_override = PersistenceConfig() train_config_override.RANDOM_SEED = seed - if headless: - logger.info("Starting training in Headless mode...") - exit_code = run_training_headless_mode( - log_level_str=log_level, - train_config_override=train_config_override, - persist_config_override=persist_config_override, - ) - else: - logger.info("Starting training in Visual mode...") - exit_code = run_training_visual_mode( - log_level_str=log_level, - train_config_override=train_config_override, - persist_config_override=persist_config_override, - ) + logger.info("Starting training...") + # Call the single runner function directly + exit_code = run_training( + log_level_str=log_level, + 
train_config_override=train_config_override, + persist_config_override=persist_config_override, + ) logger.info(f"Training finished with exit code {exit_code}.") sys.exit(exit_code) diff --git a/alphatriangle/config/README.md b/alphatriangle/config/README.md index ed474ca..feed61d 100644 --- a/alphatriangle/config/README.md +++ b/alphatriangle/config/README.md @@ -1,9 +1,9 @@ -# File: alphatriangle/config/README.md + # Configuration Module (`alphatriangle.config`) ## Purpose and Architecture -This module centralizes all configuration parameters for the AlphaTriangle project. It uses separate **Pydantic models** for different aspects of the application (environment, model, training, visualization, persistence) to promote modularity, clarity, and automatic validation. +This module centralizes all configuration parameters for the AlphaTriangle project. It uses separate **Pydantic models** for different aspects of the application (environment, model, training, persistence, MCTS) to promote modularity, clarity, and automatic validation. - **Modularity:** Separating configurations makes it easier to manage parameters for different components. - **Type Safety & Validation:** Using Pydantic models (`BaseModel`) provides strong type hinting, automatic parsing, and validation of configuration values based on defined types and constraints (e.g., `Field(gt=0)`). @@ -15,10 +15,9 @@ This module centralizes all configuration parameters for the AlphaTriangle proje ## Exposed Interfaces - **Pydantic Models:** - - [`EnvConfig`](env_config.py): Environment parameters (grid size, shapes). + - [`EnvConfig`](env_config.py): Environment parameters (grid size, shapes) - **Imported from `trianglengin`**. - [`ModelConfig`](model_config.py): Neural network architecture parameters. **Defaults tuned for larger capacity.** - [`TrainConfig`](train_config.py): Training loop hyperparameters (batch size, learning rate, workers, **PER settings**, etc.). 
**Defaults tuned for longer runs.** - - [`VisConfig`](vis_config.py): Visualization parameters (screen size, FPS, layout). - [`PersistenceConfig`](persistence_config.py): Data saving/loading parameters (directories, filenames). - [`MCTSConfig`](mcts_config.py): MCTS parameters (simulations, exploration constants, temperature). - **Constants:** @@ -31,6 +30,7 @@ This module centralizes all configuration parameters for the AlphaTriangle proje This module primarily defines configurations and relies heavily on **Pydantic**. - **`pydantic`**: The core library used for defining models and validation. +- **`trianglengin.config`**: Imports `EnvConfig`. - **Standard Libraries:** `typing`, `time`, `os`, `logging`, `pathlib`. --- diff --git a/alphatriangle/config/__init__.py b/alphatriangle/config/__init__.py index 07131f6..063e586 100644 --- a/alphatriangle/config/__init__.py +++ b/alphatriangle/config/__init__.py @@ -1,19 +1,21 @@ +from trianglengin.config import EnvConfig + from .app_config import APP_NAME -from .env_config import EnvConfig from .mcts_config import MCTSConfig from .model_config import ModelConfig from .persistence_config import PersistenceConfig from .train_config import TrainConfig from .validation import print_config_info_and_validate -from .vis_config import VisConfig + +# REMOVE DisplayConfig import __all__ = [ "APP_NAME", - "EnvConfig", + "EnvConfig", # Now imported from trianglengin "ModelConfig", "PersistenceConfig", "TrainConfig", - "VisConfig", + # "DisplayConfig", # REMOVED "MCTSConfig", "print_config_info_and_validate", ] diff --git a/alphatriangle/config/persistence_config.py b/alphatriangle/config/persistence_config.py index 8bd6ec4..8fe4207 100644 --- a/alphatriangle/config/persistence_config.py +++ b/alphatriangle/config/persistence_config.py @@ -1,3 +1,4 @@ +# File: alphatriangle/config/persistence_config.py from pathlib import Path from pydantic import BaseModel, Field, computed_field @@ -12,8 +13,9 @@ class PersistenceConfig(BaseModel): 
CHECKPOINT_SAVE_DIR_NAME: str = Field(default="checkpoints") BUFFER_SAVE_DIR_NAME: str = Field(default="buffers") - GAME_STATE_SAVE_DIR_NAME: str = Field(default="game_states") + # REMOVED GAME_STATE_SAVE_DIR_NAME (handled externally now) LOG_DIR_NAME: str = Field(default="logs") + TENSORBOARD_DIR_NAME: str = Field(default="tensorboard") # ADDED LATEST_CHECKPOINT_FILENAME: str = Field(default="latest.pkl") BEST_CHECKPOINT_FILENAME: str = Field(default="best.pkl") @@ -22,11 +24,10 @@ class PersistenceConfig(BaseModel): RUN_NAME: str = Field(default="default_run") - SAVE_GAME_STATES: bool = Field(default=False) - GAME_STATE_SAVE_FREQ_EPISODES: int = Field(default=5, ge=1) + # REMOVED SAVE_GAME_STATES and related freq SAVE_BUFFER: bool = Field(default=True) - BUFFER_SAVE_FREQ_STEPS: int = Field(default=10, ge=1) + BUFFER_SAVE_FREQ_STEPS: int = Field(default=1000, ge=1) # Increased default freq @computed_field # type: ignore[misc] # Decorator requires Pydantic v2 @property @@ -54,6 +55,13 @@ def get_mlflow_abs_path(self) -> str: abs_path = Path(self.ROOT_DATA_DIR).joinpath(self.MLFLOW_DIR_NAME).resolve() return str(abs_path) + def get_tensorboard_log_dir(self, run_name: str | None = None) -> str: + """Gets the directory for TensorBoard logs for a specific run.""" + run_base = self.get_run_base_dir(run_name) + if not run_base or not hasattr(self, "TENSORBOARD_DIR_NAME"): + return "" # Fallback + return str(Path(run_base) / self.TENSORBOARD_DIR_NAME) + # Ensure model is rebuilt after changes PersistenceConfig.model_rebuild(force=True) diff --git a/alphatriangle/config/train_config.py b/alphatriangle/config/train_config.py index fdc1df7..1256c6c 100644 --- a/alphatriangle/config/train_config.py +++ b/alphatriangle/config/train_config.py @@ -1,4 +1,5 @@ # File: alphatriangle/config/train_config.py +# File: alphatriangle/config/train_config.py import logging import time from typing import Literal diff --git a/alphatriangle/config/validation.py 
b/alphatriangle/config/validation.py index 0721fd6..4c638a1 100644 --- a/alphatriangle/config/validation.py +++ b/alphatriangle/config/validation.py @@ -3,12 +3,14 @@ from pydantic import BaseModel, ValidationError -from .env_config import EnvConfig +# Import EnvConfig from trianglengin +# REMOVE DisplayConfig import +from trianglengin.config import EnvConfig + from .mcts_config import MCTSConfig from .model_config import ModelConfig from .persistence_config import PersistenceConfig from .train_config import TrainConfig -from .vis_config import VisConfig logger = logging.getLogger(__name__) @@ -22,10 +24,11 @@ def print_config_info_and_validate(mcts_config_instance: MCTSConfig | None): configs_validated: dict[str, Any] = {} config_classes: dict[str, type[BaseModel]] = { - "Environment": EnvConfig, + "Environment": EnvConfig, # Uses trianglengin.EnvConfig "Model": ModelConfig, "Training": TrainConfig, - "Visualization": VisConfig, + # REMOVE DisplayConfig + # "Display": DisplayConfig, "Persistence": PersistenceConfig, "MCTS": MCTSConfig, } @@ -35,14 +38,17 @@ def print_config_info_and_validate(mcts_config_instance: MCTSConfig | None): try: if name == "MCTS": if mcts_config_instance is not None: + # Validate the provided instance against the class definition instance = MCTSConfig.model_validate( mcts_config_instance.model_dump() ) print(f"[{name}] - Instance provided & validated OK") else: + # Instantiate default if no instance provided instance = ConfigClass() print(f"[{name}] - Validated OK (Instantiated Default)") else: + # Instantiate default for other configs instance = ConfigClass() print(f"[{name}] - Validated OK") configs_validated[name] = instance @@ -65,8 +71,10 @@ def print_config_info_and_validate(mcts_config_instance: MCTSConfig | None): for name, instance in configs_validated.items(): print(f"--- {name} Config ---") if instance: + # Use model_dump for Pydantic v2 dump_data = instance.model_dump() for field_name, value in dump_data.items(): + # Simple 
representation for long lists/dicts if isinstance(value, list) and len(value) > 5: print(f" {field_name}: [List with {len(value)} items]") elif isinstance(value, dict) and len(value) > 5: diff --git a/alphatriangle/config/vis_config.py b/alphatriangle/config/vis_config.py deleted file mode 100644 index e029b08..0000000 --- a/alphatriangle/config/vis_config.py +++ /dev/null @@ -1,29 +0,0 @@ -from pydantic import BaseModel, Field - - -class VisConfig(BaseModel): - """Configuration for visualization (Pydantic model).""" - - FPS: int = Field(default=30, gt=0) - SCREEN_WIDTH: int = Field(default=1000, gt=0) - SCREEN_HEIGHT: int = Field(default=800, gt=0) - - # Layout - GRID_AREA_RATIO: float = Field(default=0.7, gt=0, le=1.0) - PREVIEW_AREA_WIDTH: int = Field(default=150, gt=0) - PADDING: int = Field(default=10, ge=0) - HUD_HEIGHT: int = Field(default=40, ge=0) - - # Fonts (sizes) - FONT_UI_SIZE: int = Field(default=24, gt=0) - FONT_SCORE_SIZE: int = Field(default=30, gt=0) - FONT_HELP_SIZE: int = Field(default=18, gt=0) - - # Preview Area - PREVIEW_PADDING: int = Field(default=5, ge=0) - PREVIEW_BORDER_WIDTH: int = Field(default=1, ge=0) - PREVIEW_SELECTED_BORDER_WIDTH: int = Field(default=3, ge=0) - PREVIEW_INNER_PADDING: int = Field(default=2, ge=0) - - -VisConfig.model_rebuild(force=True) diff --git a/alphatriangle/data/data_manager.py b/alphatriangle/data/data_manager.py index b7ab50b..a74d3cc 100644 --- a/alphatriangle/data/data_manager.py +++ b/alphatriangle/data/data_manager.py @@ -1,5 +1,3 @@ -# File: alphatriangle/data/data_manager.py -# No changes needed, already expects ActorHandle | None import logging from pathlib import Path from typing import TYPE_CHECKING, Any diff --git a/alphatriangle/environment/README.md b/alphatriangle/environment/README.md deleted file mode 100644 index 68e3002..0000000 --- a/alphatriangle/environment/README.md +++ /dev/null @@ -1,54 +0,0 @@ -# File: alphatriangle/environment/README.md -# Environment Module 
(`alphatriangle.environment`) - -## Purpose and Architecture - -This module defines the game world for AlphaTriangle. It encapsulates the rules, state representation, actions, and core game logic. **Crucially, this module is now independent of any feature extraction logic specific to the neural network.** Its sole focus is the simulation of the game itself. - -- **State Representation:** [`GameState`](core/game_state.py) holds the current board ([`GridData`](grid/grid_data.py)), available shapes (`List[Shape]`), score, and game status. It represents the canonical state of the game. It uses core structures like `Shape` and `Triangle` defined in [`alphatriangle.structs`](../structs/README.md). -- **Core Logic:** Submodules ([`grid`](grid/README.md), [`shapes`](shapes/README.md), [`logic`](logic/README.md)) handle specific aspects like checking valid placements, clearing lines, managing shape generation, and calculating rewards. These logic modules operate on `GridData`, `Shape`, and `Triangle`. **Shape refilling now happens in batches: all slots are refilled only when all slots become empty.** -- **Action Handling:** [`action_codec`](core/action_codec.py) provides functions to convert between a structured action (shape index, row, column) and a single integer representation used by the RL agent and MCTS. -- **Modularity:** Separating grid logic, shape logic, and core state makes the code easier to understand and modify. - -**Note:** Feature extraction (converting `GameState` to NN input tensors) is handled by the separate [`alphatriangle.features`](../features/README.md) module. Core data structures (`Triangle`, `Shape`) are defined in [`alphatriangle.structs`](../structs/README.md). - -## Exposed Interfaces - -- **Core ([`core/README.md`](core/README.md)):** - - `GameState`: The main class representing the environment state. 
- - `reset()` - - `step(action_index: ActionType) -> Tuple[float, bool]` - - `valid_actions() -> List[ActionType]` - - `is_over() -> bool` - - `get_outcome() -> float` - - `copy() -> GameState` - - Public attributes like `grid_data`, `shapes`, `game_score`, `current_step`, etc. - - `encode_action(shape_idx: int, r: int, c: int, config: EnvConfig) -> ActionType` - - `decode_action(action_index: ActionType, config: EnvConfig) -> Tuple[int, int, int]` -- **Grid ([`grid/README.md`](grid/README.md)):** - - `GridData`: Class holding grid triangle data and line information. - - `GridLogic`: Namespace containing functions like `link_neighbors`, `initialize_lines_and_index`, `can_place`, `check_and_clear_lines`. -- **Shapes ([`shapes/README.md`](shapes/README.md)):** - - `ShapeLogic`: Namespace containing functions like `refill_shape_slots`, `generate_random_shape`. **Includes `PREDEFINED_SHAPE_TEMPLATES` constant.** -- **Logic ([`logic/README.md`](logic/README.md)):** - - `get_valid_actions(game_state: GameState) -> List[ActionType]` - - `execute_placement(game_state: GameState, shape_idx: int, r: int, c: int, rng: random.Random) -> float` **(Triggers batch refill)** - - `calculate_reward(...) -> float` (Used internally by `execute_placement`) -- **Config:** - - `EnvConfig`: Configuration class (re-exported for convenience). - -## Dependencies - -- **[`alphatriangle.config`](../config/README.md)**: - - Uses `EnvConfig` extensively to define grid dimensions, shape slots, etc. -- **[`alphatriangle.structs`](../structs/README.md)**: - - Uses `Triangle`, `Shape`, `SHAPE_COLORS`, `NO_COLOR_ID`, `DEBUG_COLOR_ID`, `COLOR_TO_ID_MAP`. -- **[`alphatriangle.utils`](../utils/README.md)**: - - Uses `ActionType`. -- **`numpy`**: - - Used for grid representation (`GridData`). -- **Standard Libraries:** `typing`, `numpy`, `logging`, `random`, `copy`. 
- ---- - -**Note:** Please keep this README updated when changing game rules, state representation, action space, or the module's internal structure (especially the shape refill logic). Accurate documentation is crucial for maintainability. \ No newline at end of file diff --git a/alphatriangle/environment/__init__.py b/alphatriangle/environment/__init__.py deleted file mode 100644 index 057c86a..0000000 --- a/alphatriangle/environment/__init__.py +++ /dev/null @@ -1,32 +0,0 @@ -""" -Environment module defining the game rules, state, actions, and logic. -This module is now independent of feature extraction for the NN. -""" - -from alphatriangle.config import EnvConfig - -from .core.action_codec import decode_action, encode_action -from .core.game_state import GameState -from .grid import logic as GridLogic -from .grid.grid_data import GridData -from .logic.actions import get_valid_actions -from .logic.step import calculate_reward, execute_placement -from .shapes import logic as ShapeLogic - -__all__ = [ - # Core - "GameState", - "encode_action", - "decode_action", - # Grid - "GridData", - "GridLogic", - # Shapes - "ShapeLogic", - # Logic - "get_valid_actions", - "execute_placement", - "calculate_reward", - # Config - "EnvConfig", -] diff --git a/alphatriangle/environment/core/README.md b/alphatriangle/environment/core/README.md deleted file mode 100644 index a651c19..0000000 --- a/alphatriangle/environment/core/README.md +++ /dev/null @@ -1,35 +0,0 @@ -# File: alphatriangle/environment/core/README.md -# Environment Core Submodule (`alphatriangle.environment.core`) - -## Purpose and Architecture - -This submodule contains the most fundamental components of the game environment: the [`GameState`](game_state.py) class and the [`action_codec`](action_codec.py). - -- **`GameState`:** This class acts as the central hub for the environment's state. 
It holds references to the [`GridData`](../grid/grid_data.py), the current shapes, score, game status, and other relevant information. It provides the primary interface (`reset`, `step`, `valid_actions`, `is_over`, `get_outcome`, `copy`) for agents (like MCTS or self-play workers) to interact with the game. It delegates specific logic (like placement validation, line clearing, shape generation) to other submodules ([`grid`](../grid/README.md), [`shapes`](../shapes/README.md), [`logic`](../logic/README.md)). -- **`action_codec`:** Provides simple, stateless functions (`encode_action`, `decode_action`) to translate between the agent's integer action representation and the game's internal representation (shape index, row, column). This decouples the agent's action space from the internal game logic. - -## Exposed Interfaces - -- **Classes:** - - `GameState`: The main state class (see [`alphatriangle/environment/README.md`](../README.md) for methods). -- **Functions:** - - `encode_action(shape_idx: int, r: int, c: int, config: EnvConfig) -> ActionType` - - `decode_action(action_index: ActionType, config: EnvConfig) -> Tuple[int, int, int]` - -## Dependencies - -- **[`alphatriangle.config`](../../config/README.md)**: - - `EnvConfig`: Used by `GameState` and `action_codec`. -- **[`alphatriangle.utils`](../../utils/README.md)**: - - `ActionType`: Used for method signatures and return types. -- **[`alphatriangle.environment.grid`](../grid/README.md)**: - - `GridData`, `GridLogic`: Used internally by `GameState`. -- **[`alphatriangle.environment.shapes`](../shapes/README.md)**: - - `Shape`, `ShapeLogic`: Used internally by `GameState`. -- **[`alphatriangle.environment.logic`](../logic/README.md)**: - - `get_valid_actions`, `execute_placement`: Used internally by `GameState`. -- **Standard Libraries:** `typing`, `numpy`, `logging`, `random`. 
- ---- - -**Note:** Please keep this README updated when modifying the core `GameState` interface or the action encoding/decoding scheme. Accurate documentation is crucial for maintainability. \ No newline at end of file diff --git a/alphatriangle/environment/core/__init__.py b/alphatriangle/environment/core/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/alphatriangle/environment/core/action_codec.py b/alphatriangle/environment/core/action_codec.py deleted file mode 100644 index 10bf1d7..0000000 --- a/alphatriangle/environment/core/action_codec.py +++ /dev/null @@ -1,35 +0,0 @@ -from ...config import EnvConfig -from ...utils.types import ActionType - - -def encode_action(shape_idx: int, r: int, c: int, config: EnvConfig) -> ActionType: - """Encodes a (shape_idx, r, c) action into a single integer.""" - if not (0 <= shape_idx < config.NUM_SHAPE_SLOTS): - raise ValueError( - f"Invalid shape index: {shape_idx}, must be < {config.NUM_SHAPE_SLOTS}" - ) - if not (0 <= r < config.ROWS): - raise ValueError(f"Invalid row index: {r}, must be < {config.ROWS}") - if not (0 <= c < config.COLS): - raise ValueError(f"Invalid column index: {c}, must be < {config.COLS}") - - action_index = shape_idx * (config.ROWS * config.COLS) + r * config.COLS + c - return action_index - - -def decode_action(action_index: ActionType, config: EnvConfig) -> tuple[int, int, int]: - """Decodes an integer action into (shape_idx, r, c).""" - # Cast ACTION_DIM to int for comparison - action_dim_int = int(config.ACTION_DIM) # type: ignore[call-overload] - if not (0 <= action_index < action_dim_int): - raise ValueError( - f"Invalid action index: {action_index}, must be < {action_dim_int}" - ) - - grid_size = config.ROWS * config.COLS - shape_idx = action_index // grid_size - remainder = action_index % grid_size - r = remainder // config.COLS - c = remainder % config.COLS - - return shape_idx, r, c diff --git a/alphatriangle/environment/core/game_state.py 
b/alphatriangle/environment/core/game_state.py deleted file mode 100644 index b06a327..0000000 --- a/alphatriangle/environment/core/game_state.py +++ /dev/null @@ -1,118 +0,0 @@ -import logging -import random -from typing import TYPE_CHECKING - -from ...config import EnvConfig -from ...utils.types import ActionType -from .. import shapes -from ..grid.grid_data import GridData -from ..logic.actions import get_valid_actions -from ..logic.step import execute_placement -from .action_codec import decode_action - -if TYPE_CHECKING: - from ...structs import Shape - - -logger = logging.getLogger(__name__) - - -class GameState: - """ - Represents the mutable state of the game. Does not handle NN feature extraction - or visualization/interaction-specific state. - """ - - def __init__( - self, config: EnvConfig | None = None, initial_seed: int | None = None - ): - self.env_config = config if config else EnvConfig() # type: ignore[call-arg] - self._rng = ( - random.Random(initial_seed) if initial_seed is not None else random.Random() - ) - - self.grid_data: GridData = None # type: ignore - self.shapes: list[Shape | None] = [] - self.game_score: float = 0.0 - self.game_over: bool = False - self.triangles_cleared_this_episode: int = 0 - self.pieces_placed_this_episode: int = 0 - self.current_step: int = 0 - - self.reset() - - def reset(self): - """Resets the game to the initial state.""" - self.grid_data = GridData(self.env_config) - self.shapes = [None] * self.env_config.NUM_SHAPE_SLOTS - self.game_score = 0.0 - self.triangles_cleared_this_episode = 0 - self.pieces_placed_this_episode = 0 - self.game_over = False - self.current_step = 0 - - # Call refill_shape_slots with the updated signature (no index) - shapes.refill_shape_slots(self, self._rng) - - if not self.valid_actions(): - logger.warning( - "Game is over immediately after reset (no valid initial moves)." 
- ) - self.game_over = True - - def step(self, action_index: ActionType) -> tuple[float, bool]: - """ - Performs one game step. - Returns: - Tuple[float, bool]: (reward, done) - """ - if self.is_over(): - logger.warning("Attempted to step in a game that is already over.") - return 0.0, True - - shape_idx, r, c = decode_action(action_index, self.env_config) - reward = execute_placement(self, shape_idx, r, c, self._rng) - self.current_step += 1 - - if not self.game_over and not self.valid_actions(): - self.game_over = True - logger.info(f"Game over detected after step {self.current_step}.") - - return reward, self.game_over - - def valid_actions(self) -> list[ActionType]: - """Returns a list of valid encoded action indices.""" - return get_valid_actions(self) - - def is_over(self) -> bool: - """Checks if the game is over.""" - return self.game_over - - def get_outcome(self) -> float: - """Returns the terminal outcome value (e.g., final score). Used by MCTS.""" - if not self.is_over(): - logger.warning("get_outcome() called on a non-terminal state.") - # Consider returning a default value or raising an error? - # Returning current score might be misleading for MCTS if not terminal. - # Let's return 0.0 as a neutral value if not over. 
- return 0.0 - return self.game_score - - def copy(self) -> "GameState": - """Creates a deep copy for simulations (e.g., MCTS).""" - new_state = GameState.__new__(GameState) - new_state.env_config = self.env_config - new_state._rng = random.Random() - new_state._rng.setstate(self._rng.getstate()) - new_state.grid_data = self.grid_data.deepcopy() - new_state.shapes = [s.copy() if s else None for s in self.shapes] - new_state.game_score = self.game_score - new_state.game_over = self.game_over - new_state.triangles_cleared_this_episode = self.triangles_cleared_this_episode - new_state.pieces_placed_this_episode = self.pieces_placed_this_episode - new_state.current_step = self.current_step - return new_state - - def __str__(self) -> str: - shape_strs = [str(s) if s else "None" for s in self.shapes] - return f"GameState(Step:{self.current_step}, Score:{self.game_score:.1f}, Over:{self.is_over()}, Shapes:[{', '.join(shape_strs)}])" diff --git a/alphatriangle/environment/grid/README.md b/alphatriangle/environment/grid/README.md deleted file mode 100644 index fa12a1f..0000000 --- a/alphatriangle/environment/grid/README.md +++ /dev/null @@ -1,48 +0,0 @@ -# File: alphatriangle/environment/grid/README.md -# Environment Grid Submodule (`alphatriangle.environment.grid`) - -## Purpose and Architecture - -This submodule manages the game's grid structure and related logic. It defines the triangular cells, their properties, relationships, and operations like placement validation and line clearing. - -- **Cell Representation:** The `Triangle` class (defined in [`alphatriangle.structs`](../../structs/README.md)) represents a single cell, storing its position and orientation (`is_up`). The actual state (occupied, death, color) is managed within `GridData`. -- **Grid Data Structure:** The [`GridData`](grid_data.py) class holds the grid state using efficient `numpy` arrays (`_occupied_np`, `_death_np`, `_color_id_np`). 
It also manages precomputed information about potential lines (sets of coordinates) for efficient clearing checks. -- **Grid Logic:** The [`logic.py`](logic.py) module (exposed as `GridLogic`) contains functions operating on `GridData`. This includes: - - Initializing the grid based on `EnvConfig` (defining death zones). - - Precomputing potential lines (`_precompute_lines`) and indexing them (`initialize_lines_and_index`) for efficient checking. - - Checking if a shape can be placed (`can_place`), **including matching triangle orientations**. - - Checking for and clearing completed lines (`check_and_clear_lines`). **This function does NOT implement gravity.** -- **Grid Features:** Note: The `grid_features.py` module, which previously provided functions to calculate scalar metrics (heights, holes, bumpiness), has been **moved** to the top-level [`alphatriangle.features`](../../features/README.md) module (`alphatriangle/features/grid_features.py`) as part of decoupling feature extraction from the core environment. - -## Exposed Interfaces - -- **Classes:** - - `GridData`: Holds the grid state using NumPy arrays. 
- - `__init__(config: EnvConfig)` - - `valid(r: int, c: int) -> bool` - - `is_death(r: int, c: int) -> bool` - - `is_occupied(r: int, c: int) -> bool` - - `get_color_id(r: int, c: int) -> int` - - `get_occupied_state() -> np.ndarray` - - `get_death_state() -> np.ndarray` - - `get_color_id_state() -> np.ndarray` - - `deepcopy() -> GridData` -- **Modules/Namespaces:** - - `logic` (often imported as `GridLogic`): - - `initialize_lines_and_index(grid_data: GridData)` - - `can_place(grid_data: GridData, shape: Shape, r: int, c: int) -> bool` - - `check_and_clear_lines(grid_data: GridData, newly_occupied_coords: Set[Tuple[int, int]]) -> Tuple[int, Set[Tuple[int, int]], Set[frozenset[Tuple[int, int]]]]` **(Returns: lines_cleared_count, unique_coords_cleared_set, set_of_cleared_lines_coord_sets)** - -## Dependencies - -- **[`alphatriangle.config`](../../config/README.md)**: - - `EnvConfig`: Used by `GridData` initialization and logic functions. -- **[`alphatriangle.structs`](../../structs/README.md)**: - - Uses `Triangle`, `Shape`, `NO_COLOR_ID`. -- **`numpy`**: - - Used extensively in `GridData`. -- **Standard Libraries:** `typing`, `logging`, `numpy`, `copy`. - ---- - -**Note:** Please keep this README updated when changing the grid structure, cell properties, placement rules, or line clearing logic. Accurate documentation is crucial for maintainability. \ No newline at end of file diff --git a/alphatriangle/environment/grid/__init__.py b/alphatriangle/environment/grid/__init__.py deleted file mode 100644 index d6f728e..0000000 --- a/alphatriangle/environment/grid/__init__.py +++ /dev/null @@ -1,16 +0,0 @@ -# File: alphatriangle/environment/grid/__init__.py -""" -Grid submodule handling the triangular grid structure, data, and logic. -""" - -# Removed: from .triangle import Triangle -from . import logic -from .grid_data import GridData - -# DO NOT import grid_features here. 
It has been moved up one level -# to alphatriangle/environment/grid_features.py to break circular dependencies. - -__all__ = [ - "GridData", - "logic", -] diff --git a/alphatriangle/environment/grid/grid_data.py b/alphatriangle/environment/grid/grid_data.py deleted file mode 100644 index 23f0366..0000000 --- a/alphatriangle/environment/grid/grid_data.py +++ /dev/null @@ -1,285 +0,0 @@ -# File: alphatriangle/environment/grid/grid_data.py -import copy -import logging - -import numpy as np - -from ...config import EnvConfig -from ...structs import NO_COLOR_ID - -logger = logging.getLogger(__name__) - - -def _precompute_lines(config: EnvConfig) -> list[list[tuple[int, int]]]: - """ - Generates all potential horizontal and diagonal lines based on grid geometry. - Returns a list of lines, where each line is a list of (row, col) tuples. - This function no longer needs actual Triangle objects. - """ - lines = [] - rows, cols = config.ROWS, config.COLS - min_len = config.MIN_LINE_LENGTH - - # --- Determine playable cells based on config --- - playable_mask = np.zeros((rows, cols), dtype=bool) - for r in range(rows): - playable_width = config.COLS_PER_ROW[r] - padding = cols - playable_width - pad_left = padding // 2 - playable_start_col = pad_left - playable_end_col = pad_left + playable_width - for c in range(cols): - if playable_start_col <= c < playable_end_col: - playable_mask[r, c] = True - # --- End Playable Mask --- - - # Helper to check validity and playability - def is_valid_playable(r, c): - return 0 <= r < rows and 0 <= c < cols and playable_mask[r, c] - - # --- Trace Lines using Coordinates --- - visited_in_line: set[tuple[int, int, str]] = set() # (r, c, direction) - - for r_start in range(rows): - for c_start in range(cols): - if not is_valid_playable(r_start, c_start): - continue - - # --- Trace Horizontal --- - if (r_start, c_start, "h") not in visited_in_line: - current_line_h = [] - # Trace left - cr, cc = r_start, c_start - while is_valid_playable(cr, cc 
- 1): - cc -= 1 - # Trace right from the start - while is_valid_playable(cr, cc): - if (cr, cc, "h") not in visited_in_line: - current_line_h.append((cr, cc)) - visited_in_line.add((cr, cc, "h")) - else: - # If we hit a visited cell, the rest of the line was already processed - break - cc += 1 - if len(current_line_h) >= min_len: - lines.append(current_line_h) - - # --- Trace Diagonal TL-BR (Down-Right) --- - if (r_start, c_start, "d1") not in visited_in_line: - current_line_d1 = [] - # Trace backwards (Up-Left) - cr, cc = r_start, c_start - while True: - is_up = (cr + cc) % 2 != 0 - prev_r, prev_c = (cr, cc - 1) if is_up else (cr - 1, cc) - if is_valid_playable(prev_r, prev_c): - cr, cc = prev_r, prev_c - else: - break - # Trace forwards - while is_valid_playable(cr, cc): - if (cr, cc, "d1") not in visited_in_line: - current_line_d1.append((cr, cc)) - visited_in_line.add((cr, cc, "d1")) - else: - break - is_up = (cr + cc) % 2 != 0 - next_r, next_c = (cr + 1, cc) if is_up else (cr, cc + 1) - cr, cc = next_r, next_c - if len(current_line_d1) >= min_len: - lines.append(current_line_d1) - - # --- Trace Diagonal BL-TR (Up-Right) --- - if (r_start, c_start, "d2") not in visited_in_line: - current_line_d2 = [] - # Trace backwards (Down-Left) - cr, cc = r_start, c_start - while True: - is_up = (cr + cc) % 2 != 0 - prev_r, prev_c = (cr + 1, cc) if is_up else (cr, cc - 1) - if is_valid_playable(prev_r, prev_c): - cr, cc = prev_r, prev_c - else: - break - # Trace forwards - while is_valid_playable(cr, cc): - if (cr, cc, "d2") not in visited_in_line: - current_line_d2.append((cr, cc)) - visited_in_line.add((cr, cc, "d2")) - else: - break - is_up = (cr + cc) % 2 != 0 - next_r, next_c = (cr, cc + 1) if is_up else (cr - 1, cc) - cr, cc = next_r, next_c - if len(current_line_d2) >= min_len: - lines.append(current_line_d2) - # --- End Line Tracing --- - - # Remove duplicates (lines traced from different start points) - unique_lines_tuples = {tuple(sorted(line)) for line in lines} 
- unique_lines = [list(line_tuple) for line_tuple in unique_lines_tuples] - - # Final filter by length (should be redundant but safe) - final_lines = [line for line in unique_lines if len(line) >= min_len] - - return final_lines - - -class GridData: - """ - Holds the grid state using NumPy arrays for occupancy, death zones, and color IDs. - Manages precomputed line information based on coordinates. - """ - - def __init__(self, config: EnvConfig): - self.rows = config.ROWS - self.cols = config.COLS - self.config = config - - # --- NumPy Array State --- - self._occupied_np: np.ndarray = np.zeros((self.rows, self.cols), dtype=bool) - self._death_np: np.ndarray = np.zeros((self.rows, self.cols), dtype=bool) - # Stores color ID, NO_COLOR_ID (-1) means empty/no color - self._color_id_np: np.ndarray = np.full( - (self.rows, self.cols), NO_COLOR_ID, dtype=np.int8 - ) - # --- End NumPy Array State --- - - self._initialize_death_zone(config) - self._occupied_np[self._death_np] = True # Death cells are considered occupied - - # --- Line Information (Coordinate Based) --- - # Stores frozensets of (r, c) tuples - self.potential_lines: set[frozenset[tuple[int, int]]] = set() - # Maps (r, c) tuple to a set of line frozensets it belongs to - self._coord_to_lines_map: dict[ - tuple[int, int], set[frozenset[tuple[int, int]]] - ] = {} - # --- End Line Information --- - - self._initialize_lines_and_index() - logger.debug( - f"GridData initialized ({self.rows}x{self.cols}) using NumPy arrays. Found {len(self.potential_lines)} potential lines." 
- ) - - def _initialize_death_zone(self, config: EnvConfig): - """Initializes the death zone numpy array.""" - cols_per_row = config.COLS_PER_ROW - if len(cols_per_row) != self.rows: - raise ValueError( - f"COLS_PER_ROW length mismatch: {len(cols_per_row)} vs {self.rows}" - ) - - for r in range(self.rows): - playable_width = cols_per_row[r] - padding = self.cols - playable_width - pad_left = padding // 2 - playable_start_col = pad_left - playable_end_col = pad_left + playable_width - for c in range(self.cols): - if not (playable_start_col <= c < playable_end_col): - self._death_np[r, c] = True - - def _initialize_lines_and_index(self) -> None: - """ - Precomputes potential lines (as coordinate sets) and creates a map - from coordinates to the lines they belong to. - """ - self.potential_lines = set() - self._coord_to_lines_map = {} - - potential_lines_coords = _precompute_lines(self.config) - - for line_coords in potential_lines_coords: - # Filter out lines containing death cells - valid_line = True - line_coord_set: set[tuple[int, int]] = set() - for r, c in line_coords: - # Use self.valid() and self._death_np directly - if self.valid(r, c) and not self._death_np[r, c]: - line_coord_set.add((r, c)) - else: - valid_line = False - break # Skip this line if any part is invalid/death - - if valid_line and len(line_coord_set) >= self.config.MIN_LINE_LENGTH: - frozen_line = frozenset(line_coord_set) - self.potential_lines.add(frozen_line) - # Add to the reverse map - for coord in line_coord_set: - if coord not in self._coord_to_lines_map: - self._coord_to_lines_map[coord] = set() - self._coord_to_lines_map[coord].add(frozen_line) - - logger.debug( - f"Initialized {len(self.potential_lines)} potential lines and mapping for {len(self._coord_to_lines_map)} coordinates." 
- ) - - def valid(self, r: int, c: int) -> bool: - """Checks if coordinates are within grid bounds.""" - return 0 <= r < self.rows and 0 <= c < self.cols - - def is_death(self, r: int, c: int) -> bool: - """Checks if a cell is a death cell.""" - if not self.valid(r, c): - return True # Out of bounds is considered death - # Cast NumPy bool_ to Python bool for type consistency - return bool(self._death_np[r, c]) - - def is_occupied(self, r: int, c: int) -> bool: - """Checks if a cell is occupied (includes death cells).""" - if not self.valid(r, c): - return True # Out of bounds is considered occupied - # Cast NumPy bool_ to Python bool for type consistency - return bool(self._occupied_np[r, c]) - - def get_color_id(self, r: int, c: int) -> int: - """Gets the color ID of a cell.""" - if not self.valid(r, c): - return NO_COLOR_ID - # Cast NumPy int8 to Python int for type consistency - return int(self._color_id_np[r, c]) - - def get_occupied_state(self) -> np.ndarray: - """Returns a copy of the occupancy numpy array.""" - return self._occupied_np.copy() - - def get_death_state(self) -> np.ndarray: - """Returns a copy of the death zone numpy array.""" - return self._death_np.copy() - - def get_color_id_state(self) -> np.ndarray: - """Returns a copy of the color ID numpy array.""" - return self._color_id_np.copy() - - def deepcopy(self) -> "GridData": - """ - Creates a deep copy of the grid data using NumPy array copying - and standard dictionary/set copying for line data. - """ - new_grid = GridData.__new__( - GridData - ) # Create new instance without calling __init__ - new_grid.rows = self.rows - new_grid.cols = self.cols - new_grid.config = self.config # Config is likely immutable, shallow copy ok - - # 1. Copy NumPy arrays - new_grid._occupied_np = self._occupied_np.copy() - new_grid._death_np = self._death_np.copy() - new_grid._color_id_np = self._color_id_np.copy() - - # 2. 
Copy Line Data (Set of frozensets and Dict[Tuple, Set[frozenset]]) - # potential_lines contains immutable frozensets, shallow copy is fine - new_grid.potential_lines = self.potential_lines.copy() - # _coord_to_lines_map values are sets, need deepcopy - new_grid._coord_to_lines_map = copy.deepcopy(self._coord_to_lines_map) - - # No Triangle objects or neighbors to handle anymore - - return new_grid - - def __str__(self) -> str: - # Basic representation, could be enhanced to show grid visually - occupied_count = np.sum(self._occupied_np & ~self._death_np) - return f"GridData({self.rows}x{self.cols}, Occupied: {occupied_count})" diff --git a/alphatriangle/environment/grid/logic.py b/alphatriangle/environment/grid/logic.py deleted file mode 100644 index e8649da..0000000 --- a/alphatriangle/environment/grid/logic.py +++ /dev/null @@ -1,106 +0,0 @@ -# File: alphatriangle/environment/grid/logic.py -import logging -from typing import TYPE_CHECKING - -# Import NO_COLOR_ID from the structs package directly -from ...structs import NO_COLOR_ID - -if TYPE_CHECKING: - from ...structs import Shape - from .grid_data import GridData - -logger = logging.getLogger(__name__) - - -# Removed link_neighbors function as it's no longer needed - - -def can_place(grid_data: "GridData", shape: "Shape", r: int, c: int) -> bool: - """ - Checks if a shape can be placed at the specified (r, c) top-left position - on the grid, considering occupancy, death zones, and triangle orientation. - Reads state from GridData's NumPy arrays. 
- """ - if not shape or not shape.triangles: - return False - - for dr, dc, is_up_shape in shape.triangles: - tri_r, tri_c = r + dr, c + dc - - # Check bounds and death zone first - if not grid_data.valid(tri_r, tri_c) or grid_data._death_np[tri_r, tri_c]: - return False - - # Check occupancy - if grid_data._occupied_np[tri_r, tri_c]: - return False - - # Check orientation match - is_up_grid = (tri_r + tri_c) % 2 != 0 - if is_up_grid != is_up_shape: - # Log the mismatch for debugging the test failure - logger.debug( - f"Orientation mismatch at ({tri_r},{tri_c}): Grid is {'Up' if is_up_grid else 'Down'}, Shape requires {'Up' if is_up_shape else 'Down'}" - ) - return False - - return True - - -def check_and_clear_lines( - grid_data: "GridData", newly_occupied_coords: set[tuple[int, int]] -) -> tuple[int, set[tuple[int, int]], set[frozenset[tuple[int, int]]]]: - """ - Checks for completed lines involving the newly occupied coordinates and clears them. - Operates on GridData's NumPy arrays. - - Args: - grid_data: The GridData object (will be modified). - newly_occupied_coords: A set of (r, c) tuples that were just occupied. - - Returns: - Tuple containing: - - int: Number of lines cleared. - - set[tuple[int, int]]: Set of unique (r, c) coordinates cleared. - - set[frozenset[tuple[int, int]]]: Set containing the frozenset representations - of the actual lines that were cleared. 
- """ - lines_to_check: set[frozenset[tuple[int, int]]] = set() - for coord in newly_occupied_coords: - if coord in grid_data._coord_to_lines_map: - lines_to_check.update(grid_data._coord_to_lines_map[coord]) - - cleared_lines_set: set[frozenset[tuple[int, int]]] = set() - unique_coords_cleared: set[tuple[int, int]] = set() - - if not lines_to_check: - return 0, unique_coords_cleared, cleared_lines_set - - logger.debug(f"Checking {len(lines_to_check)} potential lines for completion.") - - for line_coords_fs in lines_to_check: - is_complete = True - for r_line, c_line in line_coords_fs: - # Check occupancy directly from the NumPy array - if not grid_data._occupied_np[r_line, c_line]: - is_complete = False - break - - if is_complete: - logger.debug(f"Line completed: {line_coords_fs}") - cleared_lines_set.add(line_coords_fs) - # Add coordinates from this cleared line to the set of unique cleared coordinates - unique_coords_cleared.update(line_coords_fs) - - if unique_coords_cleared: - logger.info( - f"Clearing {len(cleared_lines_set)} lines involving {len(unique_coords_cleared)} unique coordinates." 
- ) - # Update NumPy arrays for cleared coordinates - # Convert set to tuple of arrays for advanced indexing - if unique_coords_cleared: # Ensure set is not empty - rows_idx, cols_idx = zip(*unique_coords_cleared, strict=False) - grid_data._occupied_np[rows_idx, cols_idx] = False - grid_data._color_id_np[rows_idx, cols_idx] = NO_COLOR_ID - - return len(cleared_lines_set), unique_coords_cleared, cleared_lines_set diff --git a/alphatriangle/environment/grid/triangle.py b/alphatriangle/environment/grid/triangle.py deleted file mode 100644 index 0dd008e..0000000 --- a/alphatriangle/environment/grid/triangle.py +++ /dev/null @@ -1,45 +0,0 @@ -class Triangle: - """Represents a single triangular cell on the grid.""" - - def __init__(self, row: int, col: int, is_up: bool, is_death: bool = False): - self.row = row - self.col = col - self.is_up = is_up - self.is_death = is_death - self.is_occupied = is_death - self.color: tuple[int, int, int] | None = None - - self.neighbor_left: Triangle | None = None - self.neighbor_right: Triangle | None = None - self.neighbor_vert: Triangle | None = None - - def get_points( - self, ox: float, oy: float, cw: float, ch: float - ) -> list[tuple[float, float]]: - """Calculates vertex points for drawing, relative to origin (ox, oy).""" - x = ox + self.col * (cw * 0.75) - y = oy + self.row * ch - if self.is_up: - return [(x, y + ch), (x + cw, y + ch), (x + cw / 2, y)] - else: - return [(x, y), (x + cw, y), (x + cw / 2, y + ch)] - - def copy(self) -> "Triangle": - """Creates a copy of the Triangle object's state (neighbors are not copied).""" - new_tri = Triangle(self.row, self.col, self.is_up, self.is_death) - new_tri.is_occupied = self.is_occupied - new_tri.color = self.color - return new_tri - - def __repr__(self) -> str: - state = "D" if self.is_death else ("O" if self.is_occupied else ".") - orient = "^" if self.is_up else "v" - return f"T({self.row},{self.col} {orient}{state})" - - def __hash__(self): - return hash((self.row, self.col)) - 
- def __eq__(self, other): - if not isinstance(other, Triangle): - return NotImplemented - return self.row == other.row and self.col == other.col diff --git a/alphatriangle/environment/logic/README.md b/alphatriangle/environment/logic/README.md deleted file mode 100644 index 4e2bf8f..0000000 --- a/alphatriangle/environment/logic/README.md +++ /dev/null @@ -1,40 +0,0 @@ -# File: alphatriangle/environment/logic/README.md -# Environment Logic Submodule (`alphatriangle.environment.logic`) - -## Purpose and Architecture - -This submodule contains higher-level game logic that operates on the `GameState` and its components (`GridData`, `Shape`). It bridges the gap between basic actions/rules and the overall game flow. - -- **`actions.py`:** - - `get_valid_actions`: Determines all possible valid moves (shape placements) from the current `GameState` by iterating through available shapes and grid positions, checking placement validity using [`GridLogic.can_place`](../grid/logic.py). Returns a list of encoded `ActionType` integers. -- **`step.py`:** - - `execute_placement`: Performs the core logic when a shape is placed. It updates the `GridData` (occupancy and color), checks for and clears completed lines using [`GridLogic.check_and_clear_lines`](../grid/logic.py), calculates the reward for the step using `calculate_reward`, updates the game score and step counters, and **triggers a batch refill of shape slots using [`ShapeLogic.refill_shape_slots`](../shapes/logic.py) only if all slots become empty after the placement.** - - `calculate_reward`: Calculates the reward based on the number of triangles placed, triangles cleared, and whether the game ended. 
- -## Exposed Interfaces - -- **Functions:** - - `get_valid_actions(game_state: GameState) -> List[ActionType]` - - `execute_placement(game_state: GameState, shape_idx: int, r: int, c: int, rng: random.Random) -> float` - - `calculate_reward(placed_count: int, unique_coords_cleared: Set[Tuple[int, int]], is_game_over: bool, config: EnvConfig) -> float` - -## Dependencies - -- **[`alphatriangle.environment.core`](../core/README.md)**: - - `GameState`: The primary object operated upon. - - `ActionCodec`: Used by `get_valid_actions`. -- **[`alphatriangle.environment.grid`](../grid/README.md)**: - - `GridData`, `GridLogic`: Used for placement checks and line clearing. -- **[`alphatriangle.environment.shapes`](../shapes/README.md)**: - - `Shape`, `ShapeLogic`: Used for shape refilling. -- **[`alphatriangle.config`](../../config/README.md)**: - - `EnvConfig`: Used for reward calculation and action encoding. -- **[`alphatriangle.structs`](../../structs/README.md)**: - - `Shape`, `Triangle`, `COLOR_TO_ID_MAP`, `NO_COLOR_ID`. -- **[`alphatriangle.utils`](../../utils/README.md)**: - - `ActionType`. -- **Standard Libraries:** `typing`, `random`, `logging`. - ---- - -**Note:** Please keep this README updated when changing the logic for determining valid actions, executing placements (including reward calculation and shape refilling), or modifying dependencies. 
\ No newline at end of file diff --git a/alphatriangle/environment/logic/__init__.py b/alphatriangle/environment/logic/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/alphatriangle/environment/logic/actions.py b/alphatriangle/environment/logic/actions.py deleted file mode 100644 index 376d436..0000000 --- a/alphatriangle/environment/logic/actions.py +++ /dev/null @@ -1,30 +0,0 @@ -import logging -from typing import TYPE_CHECKING - -from ..core.action_codec import encode_action -from ..grid import logic as GridLogic - -if TYPE_CHECKING: - from ...utils.types import ActionType - from ..core.game_state import GameState - -logger = logging.getLogger(__name__) - - -def get_valid_actions(state: "GameState") -> list["ActionType"]: - """ - Calculates and returns a list of all valid encoded action indices - for the current game state. - """ - valid_actions: list[ActionType] = [] - for shape_idx, shape in enumerate(state.shapes): - if shape is None: - continue - - for r in range(state.env_config.ROWS): - for c in range(state.env_config.COLS): - if GridLogic.can_place(state.grid_data, shape, r, c): - action_index = encode_action(shape_idx, r, c, state.env_config) - valid_actions.append(action_index) - - return valid_actions diff --git a/alphatriangle/environment/logic/step.py b/alphatriangle/environment/logic/step.py deleted file mode 100644 index 0c2e438..0000000 --- a/alphatriangle/environment/logic/step.py +++ /dev/null @@ -1,160 +0,0 @@ -# File: alphatriangle/environment/logic/step.py -import logging -import random -from typing import TYPE_CHECKING - -# Correct import path for constants -from ...structs.constants import COLOR_TO_ID_MAP, NO_COLOR_ID -from .. 
import shapes as ShapeLogic - -# Import the logic submodule correctly -from ..grid import logic as GridLogic - -if TYPE_CHECKING: - from ...config import EnvConfig - from ..core.game_state import GameState - -logger = logging.getLogger(__name__) - - -def calculate_reward( - placed_count: int, - unique_coords_cleared: set[tuple[int, int]], - is_game_over: bool, - config: "EnvConfig", -) -> float: - """ - Calculates the step reward based on the new specification (v3). - - Args: - placed_count: Number of triangles successfully placed. - unique_coords_cleared: Set of unique (r, c) coordinates cleared this step. - is_game_over: Boolean indicating if the game ended *after* this step. - config: Environment configuration containing reward constants. - - Returns: - The calculated step reward. - """ - reward = 0.0 - - # 1. Placement Reward - reward += placed_count * config.REWARD_PER_PLACED_TRIANGLE - - # 2. Line Clear Reward - reward += len(unique_coords_cleared) * config.REWARD_PER_CLEARED_TRIANGLE - - # 3. Survival Reward OR Game Over Penalty - if is_game_over: - reward += config.PENALTY_GAME_OVER - else: - reward += config.REWARD_PER_STEP_ALIVE - - logger.debug( - f"Calculated Reward: Placement({placed_count * config.REWARD_PER_PLACED_TRIANGLE:.3f}) " - f"+ LineClear({len(unique_coords_cleared) * config.REWARD_PER_CLEARED_TRIANGLE:.3f}) " - f"+ {'GameOver' if is_game_over else 'Survival'}({config.PENALTY_GAME_OVER if is_game_over else config.REWARD_PER_STEP_ALIVE:.3f}) " - f"= {reward:.3f}" - ) - return reward - - -def execute_placement( - game_state: "GameState", shape_idx: int, r: int, c: int, rng: random.Random -) -> float: - """ - Places a shape, clears lines, updates game state (NumPy arrays), and calculates reward. - Handles batch refilling of shapes. - - Args: - game_state: The current game state (will be modified). - shape_idx: Index of the shape to place. - r: Target row for placement. - c: Target column for placement. 
- rng: Random number generator for shape refilling. - - Returns: - The reward obtained for this step. - """ - shape = game_state.shapes[shape_idx] - if not shape: - logger.error(f"Attempted to place an empty shape slot: {shape_idx}") - return 0.0 - - # Use the NumPy-based can_place from GridLogic - if not GridLogic.can_place(game_state.grid_data, shape, r, c): - logger.error(f"Invalid placement attempted: Shape {shape_idx} at ({r},{c})") - # It's possible this check fails even if valid_actions included it, - # especially if the state changed unexpectedly (e.g., in multi-threaded envs, though not the case here). - # Returning 0 reward is reasonable. - return 0.0 - - # --- Place the shape --- - placed_coords: set[tuple[int, int]] = set() - placed_count = 0 - # Get color ID from the shape's color - color_id = COLOR_TO_ID_MAP.get(shape.color, NO_COLOR_ID) - if color_id == NO_COLOR_ID: - logger.warning(f"Shape color {shape.color} not found in COLOR_TO_ID_MAP!") - # Assign a default color ID? Or handle as error? Let's use 0 for now. - color_id = 0 - - for dr, dc, _ in shape.triangles: - tri_r, tri_c = r + dr, c + dc - # Check validity using GridData method (which checks bounds) - if game_state.grid_data.valid(tri_r, tri_c): - # Check death and occupancy using NumPy arrays - if ( - not game_state.grid_data._death_np[tri_r, tri_c] - and not game_state.grid_data._occupied_np[tri_r, tri_c] - ): - # Update NumPy arrays - game_state.grid_data._occupied_np[tri_r, tri_c] = True - game_state.grid_data._color_id_np[tri_r, tri_c] = color_id - placed_coords.add((tri_r, tri_c)) - placed_count += 1 - else: - # This case should ideally not be reached if can_place passed. Log if it does. - logger.error( - f"Placement conflict at ({tri_r},{tri_c}) during execution, though can_place was true." - ) - else: - # This case should ideally not be reached if can_place passed. Log if it does. - logger.error( - f"Invalid coordinates ({tri_r},{tri_c}) encountered during placement execution." 
- ) - - game_state.shapes[shape_idx] = None # Remove shape from slot - game_state.pieces_placed_this_episode += 1 - - # --- Check and clear lines --- - # Use check_and_clear_lines from GridLogic - lines_cleared_count, unique_coords_cleared, _ = GridLogic.check_and_clear_lines( - game_state.grid_data, placed_coords - ) - game_state.triangles_cleared_this_episode += len(unique_coords_cleared) - - # --- Update Score (Optional tracking) --- - game_state.game_score += placed_count + len(unique_coords_cleared) * 2 - - # --- Refill shapes if all slots are empty --- - if all(s is None for s in game_state.shapes): - logger.debug("All shape slots empty, triggering batch refill.") - ShapeLogic.refill_shape_slots(game_state, rng) - - # --- Check for game over AFTER placement and refill --- - # Game is over if no valid moves remain for the *new* state - if not game_state.valid_actions(): - game_state.game_over = True - logger.info( - f"Game over detected after placing shape {shape_idx} and potential refill." - ) - - # --- Calculate Reward based on the outcome of this step --- - step_reward = calculate_reward( - placed_count=placed_count, - unique_coords_cleared=unique_coords_cleared, # Pass the set of cleared coords - is_game_over=game_state.game_over, - config=game_state.env_config, - ) - - return step_reward diff --git a/alphatriangle/environment/shapes/README.md b/alphatriangle/environment/shapes/README.md deleted file mode 100644 index 092a118..0000000 --- a/alphatriangle/environment/shapes/README.md +++ /dev/null @@ -1,35 +0,0 @@ -# File: alphatriangle/environment/shapes/README.md -# Environment Shapes Submodule (`alphatriangle.environment.shapes`) - -## Purpose and Architecture - -This submodule defines the logic for managing placeable shapes within the game environment. 
- -- **Shape Representation:** The `Shape` class (defined in [`alphatriangle.structs`](../../structs/README.md)) stores the geometry of a shape as a list of relative triangle coordinates (`(dr, dc, is_up)`) and its color. -- **Shape Templates:** The [`templates.py`](templates.py) file contains the `PREDEFINED_SHAPE_TEMPLATES` list, which defines the geometry of all possible shapes used in the game. **This list should not be modified.** -- **Shape Logic:** The [`logic.py`](logic.py) module (exposed as `ShapeLogic`) contains functions related to shapes: - - `generate_random_shape`: Creates a new `Shape` instance by randomly selecting a template from `PREDEFINED_SHAPE_TEMPLATES` and assigning a random color (using `SHAPE_COLORS` from [`alphatriangle.structs`](../../structs/README.md)). - - `refill_shape_slots`: **Refills ALL empty shape slots** in the `GameState`, but **only if ALL slots are currently empty**. This implements batch refilling. - -## Exposed Interfaces - -- **Modules/Namespaces:** - - `logic` (often imported as `ShapeLogic`): - - `generate_random_shape(rng: random.Random) -> Shape` - - `refill_shape_slots(game_state: GameState, rng: random.Random)` **(Refills all slots only if all are empty)** -- **Constants:** - - `PREDEFINED_SHAPE_TEMPLATES` (from `templates.py`): The list of shape geometries. - -## Dependencies - -- **[`alphatriangle.environment.core`](../core/README.md)**: - - `GameState`: Used by `ShapeLogic.refill_shape_slots` to access and modify the list of available shapes. -- **[`alphatriangle.config`](../../config/README.md)**: - - `EnvConfig`: Accessed via `GameState` (e.g., for `NUM_SHAPE_SLOTS`). -- **[`alphatriangle.structs`](../../structs/README.md)**: - - Uses `Shape`, `Triangle`, `SHAPE_COLORS`. -- **Standard Libraries:** `typing`, `random`, `logging`. 
- ---- - -**Note:** Please keep this README updated when changing the shape generation algorithm or the logic for managing shape slots in the game state (especially the batch refill mechanism). Accurate documentation is crucial for maintainability. **Do not modify `templates.py`.** \ No newline at end of file diff --git a/alphatriangle/environment/shapes/__init__.py b/alphatriangle/environment/shapes/__init__.py deleted file mode 100644 index 54352e4..0000000 --- a/alphatriangle/environment/shapes/__init__.py +++ /dev/null @@ -1,19 +0,0 @@ -""" -Shapes submodule handling shape generation and management. -""" - -from .logic import ( - generate_random_shape, - get_neighbors, - is_shape_connected, - refill_shape_slots, -) -from .templates import PREDEFINED_SHAPE_TEMPLATES - -__all__ = [ - "generate_random_shape", - "refill_shape_slots", - "is_shape_connected", - "get_neighbors", - "PREDEFINED_SHAPE_TEMPLATES", -] diff --git a/alphatriangle/environment/shapes/logic.py b/alphatriangle/environment/shapes/logic.py deleted file mode 100644 index eb6be57..0000000 --- a/alphatriangle/environment/shapes/logic.py +++ /dev/null @@ -1,74 +0,0 @@ -# File: alphatriangle/environment/shapes/logic.py -import logging -import random -from typing import TYPE_CHECKING - -from ...structs import SHAPE_COLORS, Shape -from .templates import PREDEFINED_SHAPE_TEMPLATES - -if TYPE_CHECKING: - from ..core.game_state import GameState - -logger = logging.getLogger(__name__) - - -def generate_random_shape(rng: random.Random) -> Shape: - """Generates a random shape from predefined templates and colors.""" - template = rng.choice(PREDEFINED_SHAPE_TEMPLATES) - color = rng.choice(SHAPE_COLORS) - return Shape(template, color) - - -def refill_shape_slots(game_state: "GameState", rng: random.Random) -> None: - """ - Refills ALL empty shape slots in the GameState, but ONLY if ALL slots are currently empty. - This implements batch refilling. 
- """ - # --- CHANGED: Check if ALL slots are None --- - if all(shape is None for shape in game_state.shapes): - logger.debug("All shape slots are empty. Refilling all slots.") - for i in range(game_state.env_config.NUM_SHAPE_SLOTS): - game_state.shapes[i] = generate_random_shape(rng) - logger.debug(f"Refilled slot {i} with {game_state.shapes[i]}") - else: - logger.debug("Not all shape slots are empty. Skipping refill.") - # --- END CHANGED --- - - -def get_neighbors(r: int, c: int, is_up: bool) -> list[tuple[int, int]]: - """Gets potential neighbor coordinates for connectivity check.""" - if is_up: - # Up triangle neighbors: (r, c-1), (r, c+1), (r+1, c) - return [(r, c - 1), (r, c + 1), (r + 1, c)] - else: - # Down triangle neighbors: (r, c-1), (r, c+1), (r-1, c) - return [(r, c - 1), (r, c + 1), (r - 1, c)] - - -def is_shape_connected(shape: Shape) -> bool: - """Checks if all triangles in a shape are connected.""" - if not shape.triangles or len(shape.triangles) <= 1: - return True - - coords_set = {(r, c) for r, c, _ in shape.triangles} - start_node = shape.triangles[0][:2] # (r, c) of the first triangle - visited: set[tuple[int, int]] = set() - queue = [start_node] - visited.add(start_node) - - while queue: - current_r, current_c = queue.pop(0) - # Find the orientation of the current triangle in the shape list - current_is_up = False - for r, c, is_up in shape.triangles: - if r == current_r and c == current_c: - current_is_up = is_up - break - - for nr, nc in get_neighbors(current_r, current_c, current_is_up): - neighbor_coord = (nr, nc) - if neighbor_coord in coords_set and neighbor_coord not in visited: - visited.add(neighbor_coord) - queue.append(neighbor_coord) - - return len(visited) == len(coords_set) diff --git a/alphatriangle/environment/shapes/shape.py b/alphatriangle/environment/shapes/shape.py deleted file mode 100644 index 97b18cd..0000000 --- a/alphatriangle/environment/shapes/shape.py +++ /dev/null @@ -1,26 +0,0 @@ -class Shape: - """Represents a 
polyomino-like shape made of triangles.""" - - def __init__( - self, triangles: list[tuple[int, int, bool]], color: tuple[int, int, int] - ): - self.triangles: list[tuple[int, int, bool]] = triangles - self.color: tuple[int, int, int] = color - - def bbox(self) -> tuple[int, int, int, int]: - """Calculates bounding box (min_r, min_c, max_r, max_c) in relative coords.""" - if not self.triangles: - return (0, 0, 0, 0) - rows = [t[0] for t in self.triangles] - cols = [t[1] for t in self.triangles] - return (min(rows), min(cols), max(rows), max(cols)) - - def copy(self) -> "Shape": - """Creates a shallow copy (triangle list is copied, color is shared).""" - new_shape = Shape.__new__(Shape) - new_shape.triangles = list(self.triangles) - new_shape.color = self.color - return new_shape - - def __str__(self) -> str: - return f"Shape(Color:{self.color}, Tris:{len(self.triangles)})" diff --git a/alphatriangle/environment/shapes/templates.py b/alphatriangle/environment/shapes/templates.py deleted file mode 100644 index bf5b015..0000000 --- a/alphatriangle/environment/shapes/templates.py +++ /dev/null @@ -1,586 +0,0 @@ -# ============================================================================== -# == PREDEFINED SHAPE TEMPLATES == -# == == -# == DO NOT MODIFY THIS LIST MANUALLY unless you are absolutely sure! == -# == These shapes are fundamental to the game's design and balance. == -# == Modifying them can have unintended consequences on gameplay and agent == -# == training. == -# ============================================================================== - -# List of predefined shape templates. Each template is a list of relative triangle coordinates (dr, dc, is_up). -# Coordinates are relative to the shape's origin (typically the top-leftmost triangle). -# is_up = True for upward-pointing triangle, False for downward-pointing. 
-PREDEFINED_SHAPE_TEMPLATES: list[list[tuple[int, int, bool]]] = [ - [ # Shape 1 - ( - 0, - 0, - True, - ) - ], - [ # Shape 1 - ( - 0, - 0, - True, - ) - ], - [ # Shape 2 - ( - 0, - 0, - True, - ), - ( - 1, - 0, - False, - ), - ], - [ # Shape 2 - ( - 0, - 0, - True, - ), - ( - 1, - 0, - False, - ), - ], - [ # Shape 3 - ( - 0, - 0, - False, - ) - ], - [ # Shape 4 - ( - 0, - 0, - True, - ), - ( - 0, - 1, - False, - ), - ], - [ # Shape 4 - ( - 0, - 0, - True, - ), - ( - 0, - 1, - False, - ), - ], - [ # Shape 5 - ( - 0, - 0, - False, - ), - ( - 0, - 1, - True, - ), - ], - [ # Shape 5 - ( - 0, - 0, - False, - ), - ( - 0, - 1, - True, - ), - ], - [ # Shape 6 - ( - 0, - 0, - True, - ), - ( - 0, - 1, - False, - ), - ( - 0, - 2, - True, - ), - ], - [ # Shape 7 - ( - 0, - 0, - False, - ), - ( - 0, - 1, - True, - ), - ( - 0, - 2, - False, - ), - ], - [ # Shape 8 - ( - 0, - 0, - True, - ), - ( - 0, - 1, - False, - ), - ( - 0, - 2, - True, - ), - ( - 1, - 0, - False, - ), - ], - [ # Shape 9 - ( - 0, - 0, - True, - ), - ( - 0, - 1, - False, - ), - ( - 0, - 2, - True, - ), - ( - 1, - 2, - False, - ), - ], - [ # Shape 10 - ( - 0, - 0, - False, - ), - ( - 0, - 1, - True, - ), - ( - 1, - 0, - True, - ), - ( - 1, - 1, - False, - ), - ], - [ # Shape 11 - ( - 0, - 0, - True, - ), - ( - 0, - 2, - True, - ), - ( - 1, - 0, - False, - ), - ( - 1, - 1, - True, - ), - ( - 1, - 2, - False, - ), - ], - [ # Shape 12 - ( - 0, - 0, - True, - ), - ( - 1, - -2, - False, - ), - ( - 1, - -1, - True, - ), - ( - 1, - 0, - False, - ), - ], - [ # Shape 13 - ( - 0, - 0, - True, - ), - ( - 0, - 1, - False, - ), - ( - 1, - 0, - False, - ), - ( - 1, - 1, - True, - ), - ], - [ # Shape 14 - ( - 0, - 0, - True, - ), - ( - 0, - 1, - False, - ), - ( - 1, - 0, - False, - ), - ( - 1, - 1, - True, - ), - ( - 1, - 2, - False, - ), - ], - [ # Shape 15 - ( - 0, - 0, - True, - ), - ( - 0, - 1, - False, - ), - ( - 0, - 2, - True, - ), - ( - 1, - 0, - False, - ), - ( - 1, - 1, - True, - ), - ], - [ # Shape 16 - ( - 0, - 
0, - True, - ), - ( - 0, - 1, - False, - ), - ( - 0, - 2, - True, - ), - ( - 1, - 0, - False, - ), - ( - 1, - 2, - False, - ), - ], - [ # Shape 17 - ( - 0, - 0, - True, - ), - ( - 0, - 1, - False, - ), - ( - 0, - 2, - True, - ), - ( - 1, - 1, - True, - ), - ( - 1, - 2, - False, - ), - ], - [ # Shape 18 - ( - 0, - 0, - True, - ), - ( - 0, - 2, - True, - ), - ( - 1, - 0, - False, - ), - ( - 1, - 1, - True, - ), - ( - 1, - 2, - False, - ), - ], - [ # Shape 19 - ( - 0, - 0, - True, - ), - ( - 0, - 1, - False, - ), - ( - 1, - 0, - False, - ), - ( - 1, - 1, - True, - ), - ( - 1, - 2, - False, - ), - ], - [ # Shape 20 - ( - 0, - 0, - False, - ), - ( - 0, - 1, - True, - ), - ( - 1, - 1, - False, - ), - ], - [ # Shape 21 - ( - 0, - 0, - True, - ), - ( - 1, - -1, - True, - ), - ( - 1, - 0, - False, - ), - ], - [ # Shape 22 - ( - 0, - 0, - True, - ), - ( - 1, - 0, - False, - ), - ( - 1, - 1, - True, - ), - ], - [ # Shape 23 - ( - 0, - 0, - True, - ), - ( - 1, - -1, - True, - ), - ( - 1, - 0, - False, - ), - ( - 1, - 1, - True, - ), - ], - [ # Shape 24 - ( - 0, - 0, - True, - ), - ( - 1, - -1, - True, - ), - ( - 1, - 0, - False, - ), - ], - [ # Shape 25 - ( - 0, - 0, - False, - ), - ( - 0, - 1, - True, - ), - ( - 0, - 2, - False, - ), - ( - 1, - 1, - False, - ), - ], - [ # Shape 26 - ( - 0, - 0, - False, - ), - ( - 0, - 1, - True, - ), - ( - 1, - 1, - False, - ), - ], - [ # Shape 27 - ( - 0, - 0, - True, - ), - ( - 0, - 1, - False, - ), - ( - 1, - 0, - False, - ), - ], -] diff --git a/alphatriangle/features/extractor.py b/alphatriangle/features/extractor.py index 285f182..74c8a11 100644 --- a/alphatriangle/features/extractor.py +++ b/alphatriangle/features/extractor.py @@ -1,31 +1,33 @@ # File: alphatriangle/features/extractor.py +# File: alphatriangle/features/extractor.py import logging -from typing import TYPE_CHECKING, cast +from typing import cast import numpy as np -from ..config import ModelConfig -from ..utils.types import StateType -from . 
import grid_features # Keep this import - -if TYPE_CHECKING: - from ..environment import GameState +# Import GameState from trianglengin +from trianglengin.core.environment import GameState +# Import alphatriangle configs +from ..config import ModelConfig +from ..utils.types import StateType # Keep alphatriangle StateType for now +from . import grid_features logger = logging.getLogger(__name__) class GameStateFeatures: - """Extracts features from GameState for NN input. Reads from GridData NumPy arrays.""" + """Extracts features from trianglengin.GameState for NN input.""" - def __init__(self, game_state: "GameState", model_config: ModelConfig): + def __init__(self, game_state: GameState, model_config: ModelConfig): self.gs = game_state + # Access EnvConfig from trianglengin.GameState self.env_config = game_state.env_config self.model_config = model_config - # Get direct references to NumPy arrays for efficiency + # Access GridData NumPy arrays directly (assuming GridData structure is stable) self.occupied_np = game_state.grid_data._occupied_np self.death_np = game_state.grid_data._death_np - # self.color_id_np = game_state.grid_data._color_id_np # Not used in current features + # self.color_id_np = game_state.grid_data._color_id_np # Still not used def _get_grid_state(self) -> np.ndarray: """ @@ -34,32 +36,25 @@ def _get_grid_state(self) -> np.ndarray: Shape: (C, H, W) where C is GRID_INPUT_CHANNELS """ rows, cols = self.env_config.ROWS, self.env_config.COLS - # Initialize with 0.0 (empty playable) grid_state: np.ndarray = np.zeros( (self.model_config.GRID_INPUT_CHANNELS, rows, cols), dtype=np.float32 ) - - # Mark occupied playable cells as 1.0 playable_occupied_mask = self.occupied_np & ~self.death_np grid_state[0, playable_occupied_mask] = 1.0 - - # Mark death cells as -1.0 grid_state[0, self.death_np] = -1.0 - - # No need for the loop or isfinite check here if input arrays are guaranteed finite - return grid_state def _get_shape_features(self) -> np.ndarray: - 
"""Extracts features for each shape slot. (No change needed here)""" + """Extracts features for each shape slot.""" num_slots = self.env_config.NUM_SHAPE_SLOTS - - FEATURES_PER_SHAPE_HERE = 7 + FEATURES_PER_SHAPE_HERE = 7 # Keep this consistent with ModelConfig shape_feature_matrix = np.zeros( (num_slots, FEATURES_PER_SHAPE_HERE), dtype=np.float32 ) - for i, shape in enumerate(self.gs.shapes): + for i, shape in enumerate( + self.gs.shapes + ): # Access shapes from trianglengin.GameState if shape and shape.triangles: n_tris = len(shape.triangles) ups = sum(1 for _, _, is_up in shape.triangles if is_up) @@ -68,7 +63,6 @@ def _get_shape_features(self) -> np.ndarray: height = max_r - min_r + 1 width_eff = (max_c - min_c + 1) * 0.75 + 0.25 if n_tris > 0 else 0 - # Populate features shape_feature_matrix[i, 0] = np.clip(n_tris / 5.0, 0, 1) shape_feature_matrix[i, 1] = ups / n_tris if n_tris > 0 else 0 shape_feature_matrix[i, 2] = downs / n_tris if n_tris > 0 else 0 @@ -84,11 +78,10 @@ def _get_shape_features(self) -> np.ndarray: shape_feature_matrix[i, 6] = np.clip( ((min_c + max_c) / 2.0) / self.env_config.COLS, 0, 1 ) - # Flatten the matrix to get a 1D array return shape_feature_matrix.flatten() def _get_shape_availability(self) -> np.ndarray: - """Returns a binary vector indicating which shape slots are filled. (No change needed)""" + """Returns a binary vector indicating which shape slots are filled.""" return np.array([1.0 if s else 0.0 for s in self.gs.shapes], dtype=np.float32) def _get_explicit_features(self) -> np.ndarray: @@ -96,9 +89,8 @@ def _get_explicit_features(self) -> np.ndarray: Extracts scalar features like score, heights, holes, etc. Uses GridData NumPy arrays directly. 
""" - EXPLICIT_FEATURES_DIM_HERE = 6 + EXPLICIT_FEATURES_DIM_HERE = 6 # Keep consistent with ModelConfig features = np.zeros(EXPLICIT_FEATURES_DIM_HERE, dtype=np.float32) - # Use the direct references stored in self occupied = self.occupied_np death = self.death_np rows, cols = self.env_config.ROWS, self.env_config.COLS @@ -109,13 +101,16 @@ def _get_explicit_features(self) -> np.ndarray: bump = grid_features.get_bumpiness(heights) total_playable_cells = np.sum(~death) - # Populate features - features[0] = np.clip(self.gs.game_score / 100.0, -5.0, 5.0) + # Use game_score() method from trianglengin.GameState + features[0] = np.clip(self.gs.game_score() / 100.0, -5.0, 5.0) features[1] = np.mean(heights) / rows if rows > 0 else 0 features[2] = np.max(heights) / rows if rows > 0 else 0 features[3] = holes / total_playable_cells if total_playable_cells > 0 else 0 features[4] = (bump / (cols - 1)) / rows if cols > 1 and rows > 0 else 0 - features[5] = np.clip(self.gs.pieces_placed_this_episode / 100.0, 0, 1) + # Access current_step directly (assuming it exists) + features[5] = np.clip( + getattr(self.gs, "current_step", 0) / 1000.0, 0, 1 + ) # Use current_step as proxy for pieces placed return cast( "np.ndarray", np.nan_to_num(features, nan=0.0, posinf=0.0, neginf=0.0) @@ -130,11 +125,9 @@ def get_combined_other_features(self) -> np.ndarray: expected_dim = self.model_config.OTHER_NN_INPUT_FEATURES_DIM if combined.shape[0] != expected_dim: - # Log error instead of raising ValueError immediately during feature extraction logger.error( f"Combined other_features dimension mismatch! Extracted {combined.shape[0]}, but ModelConfig expects {expected_dim}. Padding/truncating." 
) - # Pad or truncate to match expected dimension if combined.shape[0] < expected_dim: padding = np.zeros( expected_dim - combined.shape[0], dtype=combined.dtype @@ -153,12 +146,12 @@ def get_combined_other_features(self) -> np.ndarray: def extract_state_features( - game_state: "GameState", model_config: ModelConfig + game_state: GameState, model_config: ModelConfig ) -> StateType: """ Extracts and returns the state dictionary {grid, other_features} for NN input. Requires ModelConfig to ensure dimensions match the network's expectations. - Includes validation for non-finite values. Now reads from GridData NumPy arrays. + Includes validation for non-finite values. Now reads from trianglengin.GridData NumPy arrays. """ extractor = GameStateFeatures(game_state, model_config) state_dict: StateType = { diff --git a/alphatriangle/features/grid_features.py b/alphatriangle/features/grid_features.py index 8cb8711..bde7e05 100644 --- a/alphatriangle/features/grid_features.py +++ b/alphatriangle/features/grid_features.py @@ -1,3 +1,5 @@ +# File: alphatriangle/features/grid_features.py +# No changes needed, operates on NumPy arrays. import numpy as np from numba import njit, prange diff --git a/alphatriangle/interaction/README.md b/alphatriangle/interaction/README.md deleted file mode 100644 index 0349f6c..0000000 --- a/alphatriangle/interaction/README.md +++ /dev/null @@ -1,50 +0,0 @@ -# File: alphatriangle/interaction/README.md -# Interaction Module (`alphatriangle.interaction`) - -## Purpose and Architecture - -This module handles user input (keyboard and mouse) for interactive modes of the application, such as "play" and "debug". It bridges the gap between raw Pygame events and actions within the game simulation ([`GameState`](../environment/core/game_state.py)). - -- **Event Processing:** [`event_processor.py`](event_processor.py) handles common Pygame events like quitting (QUIT, ESC) and window resizing. 
It acts as a generator, yielding other events for mode-specific processing. -- **Input Handler:** The [`InputHandler`](input_handler.py) class is the main entry point. - - It receives Pygame events (via the `event_processor`). - - It **manages interaction-specific state** internally (e.g., `selected_shape_idx`, `hover_grid_coord`, `debug_highlight_coord`). - - It determines the current interaction mode ("play" or "debug") and delegates event handling and hover updates to specific handler functions ([`play_mode_handler`](play_mode_handler.py), [`debug_mode_handler`](debug_mode_handler.py)). - - It provides the necessary interaction state to the [`Visualizer`](../visualization/core/visualizer.py) for rendering feedback (hover previews, selection highlights). -- **Mode-Specific Handlers:** `play_mode_handler.py` and `debug_mode_handler.py` contain the logic specific to each mode, operating on the `InputHandler`'s state and the `GameState`. - - `play`: Handles selecting shapes, checking placement validity, and triggering `GameState.step` on valid clicks. Updates hover state in the `InputHandler`. - - `debug`: Handles toggling the state of individual triangles directly on the `GameState.grid_data`. Updates hover state in the `InputHandler`. -- **Decoupling:** It separates input handling logic from the core game simulation ([`environment`](../environment/README.md)) and rendering ([`visualization`](../visualization/README.md)), although it needs references to both to function. The `Visualizer` is now only responsible for drawing based on the state provided by the `GameState` and the `InputHandler`. - -## Exposed Interfaces - -- **Classes:** - - `InputHandler`: - - `__init__(game_state: GameState, visualizer: Visualizer, mode: str, env_config: EnvConfig)` - - `handle_input() -> bool`: Processes events for one frame, returns `False` if quitting. - - `get_render_interaction_state() -> dict`: Returns interaction state needed by `Visualizer.render`. 
-- **Functions:** - - `process_pygame_events(visualizer: Visualizer) -> Generator[pygame.event.Event, Any, bool]`: Processes common events, yields others. - - `handle_play_click(event: pygame.event.Event, handler: InputHandler)`: Handles clicks in play mode. - - `update_play_hover(handler: InputHandler)`: Updates hover state in play mode. - - `handle_debug_click(event: pygame.event.Event, handler: InputHandler)`: Handles clicks in debug mode. - - `update_debug_hover(handler: InputHandler)`: Updates hover state in debug mode. - -## Dependencies - -- **[`alphatriangle.environment`](../environment/README.md)**: - - `GameState`: Modifies the game state based on user actions (placing shapes, toggling debug cells). - - `EnvConfig`: Used for coordinate mapping and action encoding. - - `GridLogic`, `ActionCodec`: Used by mode-specific handlers. -- **[`alphatriangle.visualization`](../visualization/README.md)**: - - `Visualizer`: Used to get layout information (`grid_rect`, `preview_rects`) and for coordinate mapping (`get_grid_coords_from_screen`, `get_preview_index_from_screen`). Also updated directly during resize events. - - `VisConfig`: Accessed via `Visualizer`. -- **[`alphatriangle.structs`](../structs/README.md)**: - - `Shape`, `Triangle`, `DEBUG_COLOR_ID`, `NO_COLOR_ID`. -- **`pygame`**: - - Relies heavily on Pygame for event handling (`pygame.event`, `pygame.mouse`) and constants (`MOUSEBUTTONDOWN`, `KEYDOWN`, etc.). -- **Standard Libraries:** `typing`, `logging`. - ---- - -**Note:** Please keep this README updated when adding new interaction modes, changing input handling logic, or modifying the interfaces between interaction, environment, and visualization. Accurate documentation is crucial for maintainability. 
\ No newline at end of file diff --git a/alphatriangle/interaction/__init__.py b/alphatriangle/interaction/__init__.py deleted file mode 100644 index b311553..0000000 --- a/alphatriangle/interaction/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -from .debug_mode_handler import handle_debug_click, update_debug_hover -from .event_processor import process_pygame_events -from .input_handler import InputHandler -from .play_mode_handler import handle_play_click, update_play_hover - -__all__ = [ - "InputHandler", - "process_pygame_events", - "handle_play_click", - "update_play_hover", - "handle_debug_click", - "update_debug_hover", -] diff --git a/alphatriangle/interaction/debug_mode_handler.py b/alphatriangle/interaction/debug_mode_handler.py deleted file mode 100644 index 658d421..0000000 --- a/alphatriangle/interaction/debug_mode_handler.py +++ /dev/null @@ -1,95 +0,0 @@ -# File: alphatriangle/interaction/debug_mode_handler.py -import logging -from typing import TYPE_CHECKING - -import pygame - -from ..environment import grid as env_grid - -# Import constants from the structs package directly -from ..structs import DEBUG_COLOR_ID, NO_COLOR_ID -from ..visualization import core as vis_core - -if TYPE_CHECKING: - # Keep Triangle for type hinting if GridLogic still uses it temporarily - from .input_handler import InputHandler - -logger = logging.getLogger(__name__) - - -def handle_debug_click(event: pygame.event.Event, handler: "InputHandler") -> None: - """Handles mouse clicks in debug mode (toggle triangle state using NumPy arrays).""" - if not (event.type == pygame.MOUSEBUTTONDOWN and event.button == 1): - return - - game_state = handler.game_state - visualizer = handler.visualizer - mouse_pos = handler.mouse_pos - - layout_rects = visualizer.ensure_layout() - grid_rect = layout_rects.get("grid") - if not grid_rect: - logger.error("Grid layout rectangle not available for debug click.") - return - - grid_coords = vis_core.coord_mapper.get_grid_coords_from_screen( - 
mouse_pos, grid_rect, game_state.env_config - ) - if not grid_coords: - return - - r, c = grid_coords - if game_state.grid_data.valid(r, c): - # Check death zone first - if not game_state.grid_data._death_np[r, c]: - # Toggle occupancy state in NumPy array - current_occupied_state = game_state.grid_data._occupied_np[r, c] - new_occupied_state = not current_occupied_state - game_state.grid_data._occupied_np[r, c] = new_occupied_state - - # Update color ID based on new state - new_color_id = DEBUG_COLOR_ID if new_occupied_state else NO_COLOR_ID - game_state.grid_data._color_id_np[r, c] = new_color_id - - logger.debug( - f": Toggled triangle ({r},{c}) -> {'Occupied' if new_occupied_state else 'Empty'}" - ) - - # Check for line clears if the cell became occupied - if new_occupied_state: - # Pass the coordinate tuple in a set - lines_cleared, unique_tris_coords, _ = ( - env_grid.logic.check_and_clear_lines( - game_state.grid_data, newly_occupied_coords={(r, c)} - ) - ) - if lines_cleared > 0: - logger.debug( - f"Cleared {lines_cleared} lines ({len(unique_tris_coords)} coords) after toggle." - ) - else: - logger.info(f"Clicked on death cell ({r},{c}). 
No action.") - - -def update_debug_hover(handler: "InputHandler") -> None: - """Updates the debug highlight position within the InputHandler.""" - handler.debug_highlight_coord = None # Reset hover state - - game_state = handler.game_state - visualizer = handler.visualizer - mouse_pos = handler.mouse_pos - - layout_rects = visualizer.ensure_layout() - grid_rect = layout_rects.get("grid") - if not grid_rect or not grid_rect.collidepoint(mouse_pos): - return # Not hovering over grid - - grid_coords = vis_core.coord_mapper.get_grid_coords_from_screen( - mouse_pos, grid_rect, game_state.env_config - ) - - if grid_coords: - r, c = grid_coords - # Highlight only valid, non-death cells - if game_state.grid_data.valid(r, c) and not game_state.grid_data.is_death(r, c): - handler.debug_highlight_coord = grid_coords diff --git a/alphatriangle/interaction/event_processor.py b/alphatriangle/interaction/event_processor.py deleted file mode 100644 index a7a243d..0000000 --- a/alphatriangle/interaction/event_processor.py +++ /dev/null @@ -1,42 +0,0 @@ -import logging -from collections.abc import Generator -from typing import TYPE_CHECKING, Any - -import pygame - -if TYPE_CHECKING: - from ..visualization.core.visualizer import Visualizer - -logger = logging.getLogger(__name__) - - -def process_pygame_events( - visualizer: "Visualizer", -) -> Generator[pygame.event.Event, Any, bool]: - """ - Processes basic Pygame events like QUIT, ESCAPE, VIDEORESIZE. - Yields other events for mode-specific handlers. - Returns False via StopIteration value if the application should quit, True otherwise. 
- """ - should_quit = False - for event in pygame.event.get(): - if event.type == pygame.QUIT: - logger.info("Received QUIT event.") - should_quit = True - break - if event.type == pygame.KEYDOWN and event.key == pygame.K_ESCAPE: - logger.info("Received ESCAPE key press.") - should_quit = True - break - if event.type == pygame.VIDEORESIZE: - try: - w, h = max(320, event.w), max(240, event.h) - visualizer.screen = pygame.display.set_mode((w, h), pygame.RESIZABLE) - visualizer.layout_rects = None - logger.info(f"Window resized to {w}x{h}") - except pygame.error as e: - logger.error(f"Error resizing window: {e}") - yield event - else: - yield event - return not should_quit diff --git a/alphatriangle/interaction/input_handler.py b/alphatriangle/interaction/input_handler.py deleted file mode 100644 index a77ba2b..0000000 --- a/alphatriangle/interaction/input_handler.py +++ /dev/null @@ -1,83 +0,0 @@ -import logging -from typing import TYPE_CHECKING - -import pygame - -from .. import environment, visualization -from . import debug_mode_handler, event_processor, play_mode_handler - -if TYPE_CHECKING: - from ..structs import Shape - - -logger = logging.getLogger(__name__) - - -class InputHandler: - """ - Handles user input, manages interaction state (selection, hover), - and delegates actions to mode-specific handlers. 
- """ - - def __init__( - self, - game_state: environment.GameState, - visualizer: visualization.Visualizer, - mode: str, - env_config: environment.EnvConfig, - ): - self.game_state = game_state - self.visualizer = visualizer - self.mode = mode - self.env_config = env_config - - # Interaction state managed here - self.selected_shape_idx: int = -1 - self.hover_grid_coord: tuple[int, int] | None = None - self.hover_is_valid: bool = False - self.hover_shape: Shape | None = None - self.debug_highlight_coord: tuple[int, int] | None = None - self.mouse_pos: tuple[int, int] = (0, 0) - - def handle_input(self) -> bool: - """Processes Pygame events and updates state based on mode. Returns False to quit.""" - self.mouse_pos = pygame.mouse.get_pos() - - # Reset hover/highlight state each frame before processing events/updates - self.hover_grid_coord = None - self.hover_is_valid = False - self.hover_shape = None - self.debug_highlight_coord = None - - running = True - event_generator = event_processor.process_pygame_events(self.visualizer) - try: - while True: - event = next(event_generator) - # Pass self to handlers so they can modify interaction state - if self.mode == "play": - play_mode_handler.handle_play_click(event, self) - elif self.mode == "debug": - debug_mode_handler.handle_debug_click(event, self) - except StopIteration as e: - running = e.value # False if quit requested - - # Update hover state after processing events - if running: - if self.mode == "play": - play_mode_handler.update_play_hover(self) - elif self.mode == "debug": - debug_mode_handler.update_debug_hover(self) - - return running - - def get_render_interaction_state(self) -> dict: - """Returns interaction state needed by Visualizer.render""" - return { - "selected_shape_idx": self.selected_shape_idx, - "hover_shape": self.hover_shape, - "hover_grid_coord": self.hover_grid_coord, - "hover_is_valid": self.hover_is_valid, - "hover_screen_pos": self.mouse_pos, # Pass current mouse pos - 
"debug_highlight_coord": self.debug_highlight_coord, - } diff --git a/alphatriangle/interaction/play_mode_handler.py b/alphatriangle/interaction/play_mode_handler.py deleted file mode 100644 index 21946a3..0000000 --- a/alphatriangle/interaction/play_mode_handler.py +++ /dev/null @@ -1,142 +0,0 @@ -import logging -from typing import TYPE_CHECKING - -import pygame - -from ..environment import core as env_core -from ..environment import grid as env_grid -from ..visualization import core as vis_core - -if TYPE_CHECKING: - from ..structs import Shape - from .input_handler import InputHandler - -logger = logging.getLogger(__name__) - - -def handle_play_click(event: pygame.event.Event, handler: "InputHandler") -> None: - """Handles mouse clicks in play mode (select preview, place shape). Modifies handler state.""" - if not (event.type == pygame.MOUSEBUTTONDOWN and event.button == 1): - return - - game_state = handler.game_state - visualizer = handler.visualizer - mouse_pos = handler.mouse_pos - - if game_state.is_over(): - logger.info("Game is over, ignoring click.") - return - - layout_rects = visualizer.ensure_layout() - grid_rect = layout_rects.get("grid") - # Get preview rects from visualizer cache - preview_rects = visualizer.preview_rects - - # 1. 
Check for clicks on shape previews - preview_idx = vis_core.coord_mapper.get_preview_index_from_screen( - mouse_pos, preview_rects - ) - if preview_idx is not None: - if handler.selected_shape_idx == preview_idx: - # Clicked selected shape again: deselect - handler.selected_shape_idx = -1 - handler.hover_grid_coord = None # Clear hover state on deselect - handler.hover_shape = None - logger.info("Deselected shape.") - elif ( - 0 <= preview_idx < len(game_state.shapes) and game_state.shapes[preview_idx] - ): - # Clicked a valid, available shape: select it - handler.selected_shape_idx = preview_idx - logger.info(f"Selected shape index: {preview_idx}") - # Immediately update hover based on current mouse pos after selection - update_play_hover(handler) # Update hover state within handler - else: - # Clicked an empty or invalid slot - logger.info(f"Clicked empty/invalid preview slot: {preview_idx}") - # Deselect if clicking an empty slot while another is selected - if handler.selected_shape_idx != -1: - handler.selected_shape_idx = -1 - handler.hover_grid_coord = None - handler.hover_shape = None - return # Handled preview click - - # 2. Check for clicks on the grid (if a shape is selected) - selected_idx = handler.selected_shape_idx - if selected_idx != -1 and grid_rect and grid_rect.collidepoint(mouse_pos): - # A shape is selected, and the click is within the grid area. - grid_coords = vis_core.coord_mapper.get_grid_coords_from_screen( - mouse_pos, grid_rect, game_state.env_config - ) - # Use TYPE_CHECKING import for Shape type hint - shape_to_place: Shape | None = game_state.shapes[selected_idx] - - # Check if the placement is valid *at the clicked location* - if ( - grid_coords - and shape_to_place - and env_grid.logic.can_place( - game_state.grid_data, shape_to_place, grid_coords[0], grid_coords[1] - ) - ): - # Valid placement click! 
- r, c = grid_coords - action = env_core.action_codec.encode_action( - selected_idx, r, c, game_state.env_config - ) - # Execute the step using the game state's method - reward, done = game_state.step(action) # Now returns (reward, done) - logger.info( - f"Placed shape {selected_idx} at {grid_coords}. R={reward:.1f}, Done={done}" - ) - # Deselect shape after successful placement - handler.selected_shape_idx = -1 - handler.hover_grid_coord = None # Clear hover state - handler.hover_shape = None - else: - # Clicked grid, shape selected, but not a valid placement spot for the click - logger.info(f"Clicked grid at {grid_coords}, but placement invalid.") - - -def update_play_hover(handler: "InputHandler") -> None: - """Updates the hover state within the InputHandler.""" - # Reset hover state first - handler.hover_grid_coord = None - handler.hover_is_valid = False - handler.hover_shape = None - - game_state = handler.game_state - visualizer = handler.visualizer - mouse_pos = handler.mouse_pos - - if game_state.is_over() or handler.selected_shape_idx == -1: - return # No hover if game over or no shape selected - - layout_rects = visualizer.ensure_layout() - grid_rect = layout_rects.get("grid") - if not grid_rect or not grid_rect.collidepoint(mouse_pos): - return # Not hovering over grid - - shape_idx = handler.selected_shape_idx - if not (0 <= shape_idx < len(game_state.shapes)): - return - shape: Shape | None = game_state.shapes[shape_idx] - if not shape: - return - - # Get grid coordinates under mouse - grid_coords = vis_core.coord_mapper.get_grid_coords_from_screen( - mouse_pos, grid_rect, game_state.env_config - ) - - if grid_coords: - # Check if placement is valid at these coordinates - is_valid = env_grid.logic.can_place( - game_state.grid_data, shape, grid_coords[0], grid_coords[1] - ) - # Update handler's hover state - handler.hover_grid_coord = grid_coords - handler.hover_is_valid = is_valid - handler.hover_shape = shape # Store the shape being hovered - else: - 
handler.hover_shape = shape # Store shape for floating preview diff --git a/alphatriangle/mcts/core/node.py b/alphatriangle/mcts/core/node.py index 1cea150..27f1ec9 100644 --- a/alphatriangle/mcts/core/node.py +++ b/alphatriangle/mcts/core/node.py @@ -1,10 +1,16 @@ +# File: alphatriangle/mcts/core/node.py from __future__ import annotations import logging from typing import TYPE_CHECKING +# Import GameState from trianglengin + +# Keep ActionType from alphatriangle utils for now + if TYPE_CHECKING: - from alphatriangle.environment import GameState + from trianglengin.core.environment import GameState + from alphatriangle.utils.types import ActionType logger = logging.getLogger(__name__) @@ -15,7 +21,7 @@ class Node: def __init__( self, - state: GameState, + state: GameState, # Use trianglengin.GameState parent: Node | None = None, action_taken: ActionType | None = None, prior_probability: float = 0.0, @@ -23,9 +29,7 @@ def __init__( self.state = state self.parent = parent self.action_taken = action_taken - self.children: dict[ActionType, Node] = {} - self.visit_count: int = 0 self.total_action_value: float = 0.0 self.prior_probability: float = prior_probability diff --git a/alphatriangle/mcts/core/search.py b/alphatriangle/mcts/core/search.py index 356d89c..680dba1 100644 --- a/alphatriangle/mcts/core/search.py +++ b/alphatriangle/mcts/core/search.py @@ -1,10 +1,11 @@ -# File: alphatriangle/mcts/core/search.py import concurrent.futures import logging import time import numpy as np +# Import GameState from trianglengin +# Keep alphatriangle imports from ...config import MCTSConfig from ..strategy import backpropagation, expansion, selection from .node import Node @@ -12,9 +13,7 @@ logger = logging.getLogger(__name__) -# --- CHANGED: Default batch size, can be adjusted --- -MCTS_PARALLEL_TRAVERSALS = 16 # Number of traversals to run in parallel -# --- END CHANGED --- +MCTS_PARALLEL_TRAVERSALS = 16 class MCTSExecutionError(Exception): @@ -25,17 +24,13 @@ class 
MCTSExecutionError(Exception): def _run_single_traversal(root_node: Node, config: MCTSConfig) -> tuple[Node, int]: """Helper function to run a single MCTS traversal (selection phase).""" - # This function is designed to be thread-safe as selection reads node stats - # but doesn't modify them until backpropagation. try: leaf_node, selection_depth = selection.traverse_to_leaf(root_node, config) return leaf_node, selection_depth except Exception as e: - # Log error within the thread/task for better context logger.error( f"[MCTS Traversal Task] Error during traversal: {e}", exc_info=True ) - # Re-raise or return an indicator? Re-raising is cleaner for future handling. raise MCTSExecutionError(f"Traversal failed: {e}") from e @@ -52,6 +47,7 @@ def run_mcts_simulations( Returns: The maximum tree depth reached during the simulations. """ + # Use is_over() method from trianglengin.GameState if root_node.state.is_over(): logger.warning("[MCTS] MCTS started on a terminal state. No simulations run.") return 0 @@ -62,12 +58,11 @@ def run_mcts_simulations( eval_error_count = 0 total_sims_run = 0 - # --- Initial Root Expansion (if needed) --- if not root_node.is_expanded: logger.debug("[MCTS] Root node not expanded, performing initial evaluation...") try: + # Pass trianglengin.GameState to evaluator action_policy, root_value = network_evaluator.evaluate(root_node.state) - # Basic validation (can be enhanced) if not isinstance(action_policy, dict) or not isinstance(root_value, float): raise MCTSExecutionError("Initial evaluation returned invalid type.") if not np.all(np.isfinite(list(action_policy.values()))): @@ -80,12 +75,11 @@ def run_mcts_simulations( ) expansion.expand_node_with_policy(root_node, action_policy) + # Use is_over() method from trianglengin.GameState if root_node.is_expanded or root_node.state.is_over(): depth_bp = backpropagation.backpropagate_value(root_node, root_value) max_depth_overall = max(max_depth_overall, depth_bp) - selection.add_dirichlet_noise( - 
root_node, config - ) # Apply noise after first expansion/backprop + selection.add_dirichlet_noise(root_node, config) else: logger.warning("[MCTS] Initial root expansion failed.") except Exception as e: @@ -95,16 +89,14 @@ def run_mcts_simulations( raise MCTSExecutionError( f"Initial root evaluation/expansion failed: {e}" ) from e - elif root_node.visit_count == 0: # Apply noise if root is expanded but unvisited + elif root_node.visit_count == 0: selection.add_dirichlet_noise(root_node, config) - # --- End Initial Root Expansion --- logger.info( f"[MCTS] Starting MCTS loop for {config.num_simulations} simulations " f"(Parallel Traversals: {MCTS_PARALLEL_TRAVERSALS}). Root state step: {root_node.state.current_step}" ) - # Use ThreadPoolExecutor for parallel traversals with concurrent.futures.ThreadPoolExecutor( max_workers=MCTS_PARALLEL_TRAVERSALS ) as executor: @@ -117,17 +109,15 @@ def run_mcts_simulations( f"[MCTS Batch] Launching {num_to_launch} parallel traversals..." ) - # --- Submit Traversal Tasks --- futures_to_leaf: dict[concurrent.futures.Future, int] = {} for i in range(num_to_launch): future = executor.submit(_run_single_traversal, root_node, config) - futures_to_leaf[future] = processed_simulations + i # Store sim index + futures_to_leaf[future] = processed_simulations + i leaves_to_evaluate: list[Node] = [] paths_to_backprop: list[tuple[Node, float]] = [] traversal_results: list[tuple[Node | None, int, Exception | None]] = [] - # --- Collect Traversal Results --- for future in concurrent.futures.as_completed(futures_to_leaf): sim_idx = futures_to_leaf[future] try: @@ -141,19 +131,13 @@ def run_mcts_simulations( traversal_results.append((None, 0, exc)) logger.error(f" [MCTS Traversal] Sim {sim_idx + 1} failed: {exc}") - # --- Process Traversal Results --- for leaf_node_optional, selection_depth, error in traversal_results: - # --- CHANGED: Explicit check and assignment --- if error or leaf_node_optional is None: - continue # Skip failed traversals - - 
# Now we know leaf_node_optional is not None, assign to typed variable + continue valid_leaf_node: Node = leaf_node_optional - # --- END CHANGED --- - max_depth_overall = max(max_depth_overall, selection_depth) - # --- Use valid_leaf_node --- + # Use is_over() and get_outcome() from trianglengin.GameState if valid_leaf_node.state.is_over(): outcome = valid_leaf_node.state.get_outcome() logger.debug( @@ -165,22 +149,21 @@ def run_mcts_simulations( " [Process] Sim result: Leaf needs EVALUATION. Adding to batch." ) leaves_to_evaluate.append(valid_leaf_node) - else: # Hit max depth or encountered selection error resulting in expanded node + else: logger.debug( f" [Process] Sim result: EXPANDED leaf (likely max depth). Value: {valid_leaf_node.value_estimate:.3f}. Adding to backprop." ) paths_to_backprop.append( (valid_leaf_node, valid_leaf_node.value_estimate) ) - # --- END Use valid_leaf_node --- - # --- Batch Evaluate Leaves --- evaluation_start_time = time.monotonic() if leaves_to_evaluate: logger.debug( f" [MCTS Eval] Evaluating batch of {len(leaves_to_evaluate)} leaves..." ) try: + # Pass list of trianglengin.GameState to evaluator leaf_states = [node.state for node in leaves_to_evaluate] batch_results = network_evaluator.evaluate_batch(leaf_states) @@ -193,7 +176,6 @@ def run_mcts_simulations( for i, node in enumerate(leaves_to_evaluate): action_policy, value = batch_results[i] - # Basic validation if ( not isinstance(action_policy, dict) or not isinstance(value, float) @@ -202,18 +184,16 @@ def run_mcts_simulations( logger.error( f" [MCTS Eval] Invalid policy/value received for leaf {i}. Policy: {type(action_policy)}, Value: {value}. Using 0 value." 
) - value = 0.0 # Use neutral value on error - action_policy = {} # Use empty policy on error + value = 0.0 + action_policy = {} + # Use is_over() from trianglengin.GameState if not node.is_expanded and not node.state.is_over(): expansion.expand_node_with_policy(node, action_policy) logger.debug( f" [MCTS Eval/Expand] Expanded evaluated leaf node {i}: {node}" ) - - paths_to_backprop.append( - (node, value) - ) # Add evaluated node for backprop + paths_to_backprop.append((node, value)) except Exception as e: eval_error_count += len(leaves_to_evaluate) @@ -221,7 +201,6 @@ def run_mcts_simulations( f" [MCTS Eval] Error during batch evaluation/expansion: {e}", exc_info=True, ) - # Skip backprop for these leaves if eval failed evaluation_duration = time.monotonic() - evaluation_start_time if leaves_to_evaluate: @@ -229,7 +208,6 @@ def run_mcts_simulations( f" [MCTS Eval] Evaluation/Expansion phase finished. Duration: {evaluation_duration:.4f}s" ) - # --- Sequential Backpropagation --- backprop_start_time = time.monotonic() logger.debug( f" [MCTS Backprop] Backpropagating {len(paths_to_backprop)} paths..." @@ -249,25 +227,21 @@ def run_mcts_simulations( f" [Backprop] Error backpropagating path {i} (Value={value_to_prop:.4f}, Node={leaf_node_bp}): {bp_err}", exc_info=True, ) - sim_error_count += 1 # Count backprop errors separately + sim_error_count += 1 backprop_duration = time.monotonic() - backprop_start_time logger.debug( f" [MCTS Backprop] Backpropagation phase finished. Duration: {backprop_duration:.4f}s" ) - # --- Update Loop Control --- processed_simulations += num_to_launch pending_simulations -= num_to_launch - total_sims_run = ( - processed_simulations # Track total sims attempted in this run - ) + total_sims_run = processed_simulations logger.debug( f"[MCTS Batch] Finished batch. Processed: {processed_simulations}/{config.num_simulations}. 
Pending: {pending_simulations}" ) - # --- Final Logging --- final_log_level = logging.INFO logger.log( final_log_level, @@ -278,12 +252,12 @@ def run_mcts_simulations( if root_node.children: child_visits_log = {a: c.visit_count for a, c in root_node.children.items()} logger.info(f"[MCTS] Root children visit counts: {child_visits_log}") + # Use is_over() from trianglengin.GameState elif not root_node.state.is_over(): logger.warning("[MCTS] MCTS finished but root node still has no children.") - # --- Error Check --- total_errors = sim_error_count + eval_error_count - if total_errors > config.num_simulations * 0.01: # Example threshold: 50% errors + if total_errors > config.num_simulations * 0.5: # Increased threshold raise MCTSExecutionError( f"MCTS failed: High error rate ({total_errors} errors in {total_sims_run} simulations)." ) diff --git a/alphatriangle/mcts/core/types.py b/alphatriangle/mcts/core/types.py index a4490fc..c3aee72 100644 --- a/alphatriangle/mcts/core/types.py +++ b/alphatriangle/mcts/core/types.py @@ -1,19 +1,21 @@ from collections.abc import Mapping -from typing import TYPE_CHECKING, Protocol +from typing import Protocol -from ...utils.types import PolicyValueOutput +# Import GameState from trianglengin +from trianglengin.core.environment import GameState -if TYPE_CHECKING: - from ...environment import GameState - from ...utils.types import ActionType +# Keep alphatriangle utils types for now +from ...utils.types import ActionType, PolicyValueOutput -ActionPolicyMapping = Mapping["ActionType", float] +ActionPolicyMapping = Mapping[ActionType, float] class ActionPolicyValueEvaluator(Protocol): """Defines the interface for evaluating a game state using a neural network.""" - def evaluate(self, state: "GameState") -> PolicyValueOutput: + def evaluate( + self, state: GameState + ) -> PolicyValueOutput: # Uses trianglengin.GameState """ Evaluates a single game state using the neural network. 
@@ -28,7 +30,9 @@ def evaluate(self, state: "GameState") -> PolicyValueOutput: """ ... - def evaluate_batch(self, states: list["GameState"]) -> list[PolicyValueOutput]: + def evaluate_batch( + self, states: list[GameState] + ) -> list[PolicyValueOutput]: # Uses trianglengin.GameState """ Evaluates a batch of game states using the neural network. (Optional but recommended for performance if MCTS supports batch evaluation). diff --git a/alphatriangle/mcts/strategy/expansion.py b/alphatriangle/mcts/strategy/expansion.py index d4a9a21..eb11f5e 100644 --- a/alphatriangle/mcts/strategy/expansion.py +++ b/alphatriangle/mcts/strategy/expansion.py @@ -1,14 +1,14 @@ import logging from typing import TYPE_CHECKING +# Import GameState from trianglengin +# Keep alphatriangle utils types for now +from ..core.node import Node +from ..core.types import ActionPolicyMapping + if TYPE_CHECKING: from ...utils.types import ActionType -from ..core.node import Node -from ..core.types import ( - ActionPolicyMapping, -) - logger = logging.getLogger(__name__) @@ -22,14 +22,15 @@ def expand_node_with_policy(node: Node, action_policy: ActionPolicyMapping): if node.is_expanded: logger.debug(f"[Expand] Attempted to expand an already expanded node: {node}") return + # Use is_over() method from trianglengin.GameState if node.state.is_over(): logger.warning(f"[Expand] Attempted to expand a terminal node: {node}") return logger.debug(f"[Expand] Expanding Node: {node}") - # Use TYPE_CHECKING import for ActionType type hint - valid_actions: list[ActionType] = node.state.valid_actions() + # Use valid_actions() method from trianglengin.GameState + valid_actions: set[ActionType] = node.state.valid_actions() logger.debug( f"[Expand] Found {len(valid_actions)} valid actions for state step {node.state.current_step}." 
) @@ -41,12 +42,8 @@ def expand_node_with_policy(node: Node, action_policy: ActionPolicyMapping): logger.warning( f"[Expand] Expanding node at step {node.state.current_step} with no valid actions but not terminal? Marking state as game over." ) - if hasattr(node.state, "game_over"): - node.state.game_over = True - elif hasattr(node.state, "_is_over"): - node.state._is_over = True - else: - logger.error("[Expand] Cannot mark state as game over - attribute missing.") + # Use force_game_over method from trianglengin.GameState + node.state.force_game_over("Expansion found no valid actions") return children_created = 0 @@ -62,19 +59,20 @@ def expand_node_with_policy(node: Node, action_policy: ActionPolicyMapping): f"[Expand] Valid action {action} received prior=0 from network." ) + # Use copy() method from trianglengin.GameState next_state_copy = node.state.copy() try: - # Correctly unpack the (reward, done) tuple returned by step + # Use step() method from trianglengin.GameState _, done = next_state_copy.step(action) except Exception as e: logger.error( f"[Expand] Error stepping state for child node expansion (action {action}): {e}", exc_info=True, ) - continue # Skip creating this child if stepping fails + continue child = Node( - state=next_state_copy, + state=next_state_copy, # Already a trianglengin.GameState parent=node, action_taken=action, prior_probability=prior, diff --git a/alphatriangle/mcts/strategy/policy.py b/alphatriangle/mcts/strategy/policy.py index d4fd80a..ba69e3d 100644 --- a/alphatriangle/mcts/strategy/policy.py +++ b/alphatriangle/mcts/strategy/policy.py @@ -1,12 +1,18 @@ import logging import random +from typing import TYPE_CHECKING import numpy as np +# Import EnvConfig from trianglengin +# Keep alphatriangle imports from ...utils.types import ActionType from ..core.node import Node from ..core.types import ActionPolicyMapping +if TYPE_CHECKING: + from trianglengin.config import EnvConfig + logger = logging.getLogger(__name__) rng = 
np.random.default_rng() @@ -62,7 +68,6 @@ def select_action_based_on_visits(root_node: Node, temperature: float) -> Action logger.debug( f"[PolicySelect] Greedy selection. Best action indices: {best_action_indices}" ) - # Use standard library random for tie-breaking chosen_index = random.choice(best_action_indices) selected_action = actions[chosen_index] logger.debug(f"[PolicySelect] Greedy action selected: {selected_action}") @@ -102,12 +107,10 @@ def select_action_based_on_visits(root_node: Node, temperature: float) -> Action logger.debug(f" Final Probabilities Sum: {np.sum(probabilities):.6f}") try: - # Use NumPy's default_rng for weighted choice selected_action = rng.choice(actions, p=probabilities) logger.debug( f"[PolicySelect] Sampled action (temp={temperature:.2f}): {selected_action}" ) - # Ensure return type is ActionType (int) return int(selected_action) except ValueError as e: raise PolicyGenerationError( @@ -120,7 +123,9 @@ def get_policy_target(root_node: Node, temperature: float = 1.0) -> ActionPolicy Calculates the policy target distribution based on MCTS visit counts. Raises PolicyGenerationError if target cannot be generated. 
""" - action_dim = int(root_node.state.env_config.ACTION_DIM) # type: ignore[call-overload] + # Access EnvConfig from the node's state + env_config: EnvConfig = root_node.state.env_config + action_dim = int(env_config.ACTION_DIM) # type: ignore[call-overload] full_target = dict.fromkeys(range(action_dim), 0.0) if not root_node.children or root_node.visit_count == 0: diff --git a/alphatriangle/mcts/strategy/selection.py b/alphatriangle/mcts/strategy/selection.py index 7c8fdaf..856e7eb 100644 --- a/alphatriangle/mcts/strategy/selection.py +++ b/alphatriangle/mcts/strategy/selection.py @@ -1,8 +1,12 @@ +# File: alphatriangle/mcts/strategy/selection.py import logging import math import numpy as np +# Import GameState from trianglengin (only needed for type hint if used) +# from trianglengin.core.environment import GameState +# Keep alphatriangle imports from ...config import MCTSConfig from ..core.node import Node @@ -25,8 +29,6 @@ def calculate_puct_score( q_value = child_node.value_estimate prior = child_node.prior_probability child_visits = child_node.visit_count - # Use parent_visit_count directly; sqrt comes later if needed (original AlphaGo used N(s), not sqrt(N(s))) - # Let's use sqrt(parent_visit_count) for UCB1-like exploration bonus scaling parent_visits_sqrt = math.sqrt(max(1, parent_visit_count)) exploration_term = ( @@ -34,7 +36,6 @@ def calculate_puct_score( ) score = q_value + exploration_term - # Ensure score is finite, default to Q-value if exploration term explodes if not np.isfinite(score): logger.warning( f"Non-finite PUCT score calculated (Q={q_value}, P={prior}, ChildN={child_visits}, ParentN={parent_visit_count}, Exp={exploration_term}). Defaulting to Q-value." 
@@ -56,7 +57,6 @@ def add_dirichlet_noise(node: Node, config: MCTSConfig): return actions = list(node.children.keys()) - # Use the module-level rng generator noise = rng.dirichlet([config.dirichlet_alpha] * len(actions)) eps = config.dirichlet_epsilon @@ -70,7 +70,6 @@ def add_dirichlet_noise(node: Node, config: MCTSConfig): f" [Noise] Action {action}: OrigP={original_prior:.4f}, Noise={noise[i]:.4f} -> NewP={child.prior_probability:.4f}" ) - # Re-normalize priors after adding noise to ensure they sum to 1 if abs(noisy_priors_sum - 1.0) > 1e-6: logger.debug( f"Re-normalizing priors after Dirichlet noise (Sum={noisy_priors_sum:.6f})" @@ -79,14 +78,10 @@ def add_dirichlet_noise(node: Node, config: MCTSConfig): if noisy_priors_sum > 1e-9: node.children[action].prior_probability /= noisy_priors_sum else: - # Handle case where sum is zero - distribute equally? Or leave as 0? - # Leaving as 0 might be safer if original priors were also 0. - # Distributing equally might introduce unintended exploration. - # Let's log a warning and leave them as potentially 0. logger.warning( "Sum of priors after noise is near zero. Cannot normalize." ) - node.children[action].prior_probability = 0.0 # Or 1.0 / len(actions) ? + node.children[action].prior_probability = 0.0 logger.debug( f"[Noise] Added Dirichlet noise (alpha={config.dirichlet_alpha}, eps={eps}) to {len(actions)} root node priors." 
@@ -111,11 +106,9 @@ def select_child_node(node: Node, config: MCTSConfig) -> Node: f" [Select] Selecting child for Node (Visits={node.visit_count}, Children={len(node.children)}, StateStep={node.state.current_step}):" ) - # Use parent_visit_count from the node being considered for selection parent_visit_count = node.visit_count for action, child in node.children.items(): - # Pass the correct parent_visit_count for PUCT calculation score, q, exp_term = calculate_puct_score(child, parent_visit_count, config) if logger.isEnabledFor(logging.DEBUG): @@ -124,7 +117,6 @@ def select_child_node(node: Node, config: MCTSConfig) -> Node: f"(Q={q:.3f}, P={child.prior_probability:.4f}, N={child.visit_count}, Exp={exp_term:.4f})" ) child_scores_log.append(log_entry) - # Removed per-child log line here to reduce verbosity, summary below is sufficient if not np.isfinite(score): logger.warning( @@ -132,7 +124,6 @@ def select_child_node(node: Node, config: MCTSConfig) -> Node: ) continue - # Tie-breaking: add small random value? Or just take first max? Taking first max is simpler. 
if score > best_score: best_score = score best_child = child @@ -151,11 +142,10 @@ def get_score_from_log(log_str): except Exception as sort_err: logger.warning(f"Could not sort child score logs: {sort_err}") logger.debug(" [Select] All Child Scores Considered (Top 5):") - for log_line in child_scores_log[:5]: # Log only top 5 scores + for log_line in child_scores_log[:5]: logger.debug(f" {log_line}") if best_child is None: - # Log available children details for debugging child_details = [ f"Act={a}, N={c.visit_count}, P={c.prior_probability:.4f}, Q={c.value_estimate:.3f}" for a, c in node.children.items() @@ -191,26 +181,26 @@ def traverse_to_leaf(root_node: Node, config: MCTSConfig) -> tuple[Node, int]: f" [Traverse] Depth {depth}: Considering Node: {current_node} (Expanded={current_node.is_expanded}, Terminal={current_node.state.is_over()})" ) + # Use is_over() method from trianglengin.GameState if current_node.state.is_over(): stop_reason = "Terminal State" - logger.debug( # Changed level from INFO to DEBUG + logger.debug( f" [Traverse] Depth {depth}: Node is TERMINAL. Stopping traverse." ) break if not current_node.is_expanded: stop_reason = "Unexpanded Leaf" - logger.debug( # Changed level from INFO to DEBUG + logger.debug( f" [Traverse] Depth {depth}: Node is LEAF (not expanded). Stopping traverse." ) break if config.max_search_depth is not None and depth >= config.max_search_depth: stop_reason = "Max Depth Reached" - logger.debug( # Changed level from INFO to DEBUG + logger.debug( f" [Traverse] Depth {depth}: Hit MAX DEPTH ({config.max_search_depth}). Stopping traverse." ) break - # Node is expanded, non-terminal, and below max depth - select child try: selected_child = select_child_node(current_node, config) logger.debug( @@ -222,10 +212,8 @@ def traverse_to_leaf(root_node: Node, config: MCTSConfig) -> tuple[Node, int]: stop_reason = f"Child Selection Error: {e}" logger.error( f" [Traverse] Depth {depth}: Error during child selection: {e}. 
Breaking traverse.", - exc_info=False, # Avoid full traceback for selection errors unless needed + exc_info=False, ) - # It's better to return the current node where selection failed than raise an exception - # The MCTS search loop can then handle this (e.g., backpropagate current value) logger.warning( f" [Traverse] Returning node {current_node} due to SelectionError." ) @@ -236,13 +224,12 @@ def traverse_to_leaf(root_node: Node, config: MCTSConfig) -> tuple[Node, int]: f" [Traverse] Depth {depth}: Unexpected error during child selection: {e}. Breaking traverse.", exc_info=True, ) - # Also return current node here instead of raising logger.warning( f" [Traverse] Returning node {current_node} due to Unexpected Error." ) break - logger.debug( # Changed level from INFO to DEBUG + logger.debug( f"[Traverse] --- End Traverse: Reached Node at Depth {depth}. Reason: {stop_reason}. Final Node: {current_node} ---" ) return current_node, depth diff --git a/alphatriangle/nn/model.py b/alphatriangle/nn/model.py index c59153d..6e96a62 100644 --- a/alphatriangle/nn/model.py +++ b/alphatriangle/nn/model.py @@ -1,11 +1,14 @@ -# File: alphatriangle/nn/model.py import math from typing import cast import torch import torch.nn as nn -from ..config import EnvConfig, ModelConfig +# Import EnvConfig from trianglengin +from trianglengin.config import EnvConfig + +# Keep alphatriangle ModelConfig import +from ..config import ModelConfig # --- REMOVED: Incorrect self-import --- # from .model import AlphaTriangleNet @@ -130,7 +133,9 @@ class AlphaTriangleNet(nn.Module): Supports Distributional Value Head (C51). 
""" - def __init__(self, model_config: ModelConfig, env_config: EnvConfig): + def __init__( + self, model_config: ModelConfig, env_config: EnvConfig + ): # Uses trianglengin.EnvConfig super().__init__() self.model_config = model_config self.env_config = env_config diff --git a/alphatriangle/nn/network.py b/alphatriangle/nn/network.py index 8f40dad..72e3182 100644 --- a/alphatriangle/nn/network.py +++ b/alphatriangle/nn/network.py @@ -1,6 +1,5 @@ -# File: alphatriangle/nn/network.py import logging -import sys # Import sys for platform check +import sys from collections.abc import Mapping from typing import TYPE_CHECKING @@ -8,8 +7,12 @@ import torch import torch.nn.functional as F -from ..config import EnvConfig, ModelConfig, TrainConfig -from ..environment import GameState +# Import GameState and EnvConfig from trianglengin +from trianglengin.config import EnvConfig +from trianglengin.core.environment import GameState + +# Keep alphatriangle imports +from ..config import ModelConfig, TrainConfig from ..features import extract_state_features from ..utils.types import ActionType, PolicyValueOutput, StateType from .model import AlphaTriangleNet @@ -36,14 +39,15 @@ class NeuralNetwork: def __init__( self, model_config: ModelConfig, - env_config: EnvConfig, + env_config: EnvConfig, # Uses trianglengin.EnvConfig train_config: TrainConfig, device: torch.device, ): self.model_config = model_config - self.env_config = env_config + self.env_config = env_config # Store trianglengin.EnvConfig self.train_config = train_config self.device = device + # Pass trianglengin.EnvConfig to model self.model = AlphaTriangleNet(model_config, env_config).to(device) self.action_dim = env_config.ACTION_DIM self.model.eval() @@ -56,21 +60,14 @@ def __init__( self.v_min, self.v_max, self.num_atoms, device=self.device ) - # --- ADDED: Check for Windows/MPS before attempting compile --- if self.train_config.COMPILE_MODEL: - # --- ADDED: Skip compilation entirely on Windows due to Triton 
dependency --- if sys.platform == "win32": logger.warning( - "Model compilation requested but running on Windows. " - "Skipping torch.compile() as the default CUDA backend (Inductor) requires Triton, " - "which is not officially supported on Windows. Proceeding with eager execution." + "Model compilation requested but running on Windows. Skipping torch.compile()." ) - # --- END ADDED --- elif self.device.type == "mps": logger.warning( - "Model compilation requested but device is 'mps'. " - "Skipping torch.compile() due to known compatibility issues with this backend. " - "Proceeding with eager execution." + "Model compilation requested but device is 'mps'. Skipping torch.compile()." ) elif hasattr(torch, "compile"): try: @@ -83,9 +80,7 @@ def __init__( ) except Exception as e: logger.warning( - f"torch.compile() failed on device '{self.device}': {e}. " - f"Proceeding without compilation (using eager mode). " - f"Compilation might not be supported for this model/backend combination.", + f"torch.compile() failed on device '{self.device}': {e}. Proceeding without compilation.", exc_info=False, ) else: @@ -96,10 +91,9 @@ def __init__( logger.info( "Model compilation skipped (COMPILE_MODEL=False in TrainConfig)." 
) - # --- END ADDED --- def _state_to_tensors(self, state: GameState) -> tuple[torch.Tensor, torch.Tensor]: - """Extracts features from GameState and converts them to tensors.""" + """Extracts features from trianglengin.GameState and converts them to tensors.""" state_dict: StateType = extract_state_features(state, self.model_config) grid_tensor = torch.from_numpy(state_dict["grid"]).unsqueeze(0).to(self.device) other_features_tensor = ( @@ -118,7 +112,7 @@ def _state_to_tensors(self, state: GameState) -> tuple[torch.Tensor, torch.Tenso def _batch_states_to_tensors( self, states: list[GameState] ) -> tuple[torch.Tensor, torch.Tensor]: - """Extracts features from a batch of GameStates and converts to batched tensors.""" + """Extracts features from a batch of trianglengin.GameStates and converts to batched tensors.""" if not states: grid_shape = ( 0, @@ -154,7 +148,6 @@ def _batch_states_to_tensors( def _logits_to_expected_value(self, value_logits: torch.Tensor) -> torch.Tensor: """Calculates the expected value from the value distribution logits.""" value_probs = F.softmax(value_logits, dim=1) - # Expand support to match batch size for broadcasting support_expanded = self.support.expand_as(value_probs) expected_value = torch.sum(value_probs * support_expanded, dim=1, keepdim=True) return expected_value @@ -162,7 +155,7 @@ def _logits_to_expected_value(self, value_logits: torch.Tensor) -> torch.Tensor: @torch.inference_mode() def evaluate(self, state: GameState) -> PolicyValueOutput: """ - Evaluates a single state. + Evaluates a single trianglengin.GameState. Returns policy mapping and EXPECTED value from the distribution. Raises NetworkEvaluationError on issues. 
""" @@ -201,9 +194,7 @@ def evaluate(self, state: GameState) -> PolicyValueOutput: policy_probs /= prob_sum expected_value_tensor = self._logits_to_expected_value(value_logits) - expected_value_scalar = expected_value_tensor.squeeze( - 0 - ).item() # Squeeze batch and atom dim, get scalar + expected_value_scalar = expected_value_tensor.squeeze(0).item() action_policy: Mapping[ActionType, float] = { i: float(p) for i, p in enumerate(policy_probs) @@ -228,7 +219,7 @@ def evaluate(self, state: GameState) -> PolicyValueOutput: @torch.inference_mode() def evaluate_batch(self, states: list[GameState]) -> list[PolicyValueOutput]: """ - Evaluates a batch of states. + Evaluates a batch of trianglengin.GameStates. Returns a list of (policy mapping, EXPECTED value). Raises NetworkEvaluationError on issues. """ @@ -258,9 +249,7 @@ def evaluate_batch(self, states: list[GameState]) -> list[PolicyValueOutput]: policy_probs = policy_probs_tensor.cpu().numpy() expected_values_tensor = self._logits_to_expected_value(value_logits) - expected_values = ( - expected_values_tensor.squeeze(1).cpu().numpy() - ) # Squeeze the atom dim + expected_values = expected_values_tensor.squeeze(1).cpu().numpy() results: list[PolicyValueOutput] = [] for batch_idx in range(len(states)): @@ -279,7 +268,7 @@ def evaluate_batch(self, states: list[GameState]) -> list[PolicyValueOutput]: policy_i: Mapping[ActionType, float] = { i: float(p) for i, p in enumerate(probs_i) } - value_i = float(expected_values[batch_idx]) # This is now a scalar + value_i = float(expected_values[batch_idx]) results.append((policy_i, value_i)) except Exception as e: @@ -291,7 +280,6 @@ def evaluate_batch(self, states: list[GameState]) -> list[PolicyValueOutput]: def get_weights(self) -> dict[str, torch.Tensor]: """Returns the model's state dictionary, moved to CPU.""" - # If model is compiled, access the original model for state_dict model_to_save = getattr(self.model, "_orig_mod", self.model) return {k: v.cpu() for k, v in 
model_to_save.state_dict().items()} @@ -299,10 +287,9 @@ def set_weights(self, weights: dict[str, torch.Tensor]): """Loads the model's state dictionary from the provided weights.""" try: weights_on_device = {k: v.to(self.device) for k, v in weights.items()} - # If model is compiled, load into the original model model_to_load = getattr(self.model, "_orig_mod", self.model) model_to_load.load_state_dict(weights_on_device) - self.model.eval() # Ensure the main (potentially compiled) model is in eval mode + self.model.eval() logger.debug("NN weights set successfully.") except Exception as e: logger.error(f"Error setting weights on NN instance: {e}", exc_info=True) diff --git a/alphatriangle/rl/core/README.md b/alphatriangle/rl/core/README.md index 570050b..23813e2 100644 --- a/alphatriangle/rl/core/README.md +++ b/alphatriangle/rl/core/README.md @@ -1,4 +1,4 @@ -# File: alphatriangle/rl/core/README.md + # RL Core Submodule (`alphatriangle.rl.core`) ## Purpose and Architecture @@ -29,7 +29,8 @@ This submodule contains core classes directly involved in the reinforcement lear ## Dependencies -- **[`alphatriangle.config`](../../config/README.md)**: `TrainConfig`, `EnvConfig`, `ModelConfig`. +- **[`alphatriangle.config`](../../config/README.md)**: `TrainConfig`, `ModelConfig`. +- **[`trianglengin.config`](../../config/README.md)**: `EnvConfig`. - **[`alphatriangle.nn`](../../nn/README.md)**: `NeuralNetwork`. - **[`alphatriangle.utils`](../../utils/README.md)**: Types (`Experience`, `PERBatchSample`, `StateType`, etc.) and helpers (`SumTree`). - **`torch`**: Used heavily by `Trainer`. @@ -37,4 +38,4 @@ This submodule contains core classes directly involved in the reinforcement lear --- -**Note:** Please keep this README updated when changing the responsibilities or interfaces of the Trainer or Buffer. \ No newline at end of file +**Note:** Please keep this README updated when changing the responsibilities or interfaces of the Trainer or Buffer. 
diff --git a/alphatriangle/rl/core/trainer.py b/alphatriangle/rl/core/trainer.py index 3d394d2..c28ae4f 100644 --- a/alphatriangle/rl/core/trainer.py +++ b/alphatriangle/rl/core/trainer.py @@ -1,4 +1,3 @@ -# File: alphatriangle/rl/core/trainer.py import logging from typing import cast @@ -8,12 +7,13 @@ import torch.optim as optim from torch.optim.lr_scheduler import _LRScheduler -from ...config import EnvConfig, TrainConfig +# Import EnvConfig from trianglengin +from trianglengin.config import EnvConfig + +# Keep alphatriangle imports +from ...config import TrainConfig from ...nn import NeuralNetwork -from ...utils.types import ( - ExperienceBatch, - PERBatchSample, -) +from ...utils.types import ExperienceBatch, PERBatchSample logger = logging.getLogger(__name__) @@ -28,24 +28,22 @@ def __init__( self, nn_interface: NeuralNetwork, train_config: TrainConfig, - env_config: EnvConfig, + env_config: EnvConfig, # Uses trianglengin.EnvConfig ): self.nn = nn_interface self.model = nn_interface.model self.train_config = train_config - self.env_config = env_config + self.env_config = env_config # Store trianglengin.EnvConfig self.model_config = nn_interface.model_config self.device = nn_interface.device self.optimizer = self._create_optimizer() self.scheduler: _LRScheduler | None = self._create_scheduler(self.optimizer) - # --- ADDED: Distributional Value Attributes (from NN interface) --- self.num_atoms = self.nn.num_atoms self.v_min = self.nn.v_min self.v_max = self.nn.v_max self.delta_z = self.nn.delta_z - self.support = self.nn.support.to(self.device) # Ensure support is on device - # --- END ADDED --- + self.support = self.nn.support.to(self.device) def _create_optimizer(self) -> optim.Optimizer: """Creates the optimizer based on TrainConfig.""" @@ -81,7 +79,6 @@ def _create_scheduler(self, optimizer: optim.Optimizer) -> _LRScheduler | None: step_size = getattr(self.train_config, "LR_SCHEDULER_STEP_SIZE", 100000) gamma = getattr(self.train_config, 
"LR_SCHEDULER_GAMMA", 0.1) logger.info(f" StepLR params: step_size={step_size}, gamma={gamma}") - # Cast return type return cast( "_LRScheduler", optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma), @@ -114,8 +111,8 @@ def _prepare_batch( batch_size = len(batch) grids = [] other_features = [] - # --- Store n-step returns --- n_step_returns = [] + # Use env_config from trianglengin action_dim_int = int(self.env_config.ACTION_DIM) # type: ignore[call-overload] policy_target_tensor = torch.zeros( (batch_size, action_dim_int), @@ -123,11 +120,10 @@ def _prepare_batch( device=self.device, ) - # --- Unpack n_step_return --- for i, (state_features, policy_target_map, n_step_return) in enumerate(batch): grids.append(state_features["grid"]) other_features.append(state_features["other_features"]) - n_step_returns.append(n_step_return) # Store the scalar return G + n_step_returns.append(n_step_return) for action, prob in policy_target_map.items(): if 0 <= action < action_dim_int: policy_target_tensor[i, action] = prob @@ -140,7 +136,6 @@ def _prepare_batch( other_features_tensor = torch.from_numpy(np.stack(other_features)).to( self.device ) - # --- Create tensor for n-step returns --- n_step_return_tensor = torch.tensor( n_step_returns, dtype=torch.float32, device=self.device ) @@ -151,7 +146,6 @@ def _prepare_batch( f"Unexpected other_features tensor shape: {other_features_tensor.shape}, expected dim {expected_other_dim}" ) - # --- Return n_step_return_tensor --- return ( grid_tensor, other_features_tensor, @@ -159,7 +153,6 @@ def _prepare_batch( n_step_return_tensor, ) - # --- REWRITTEN: Helper for calculating target distribution --- def _calculate_target_distribution( self, n_step_returns: torch.Tensor ) -> torch.Tensor: @@ -171,52 +164,22 @@ def _calculate_target_distribution( Tensor of shape (batch_size, num_atoms) representing the target distribution. 
""" batch_size = n_step_returns.size(0) - # Initialize target distribution tensor m = torch.zeros( (batch_size, self.num_atoms), dtype=torch.float32, device=self.device ) - - # Clamp returns to the support range [V_min, V_max] target_returns = n_step_returns.clamp(self.v_min, self.v_max) - - # Calculate the fractional index b and lower/upper atom indices l, u b = (target_returns - self.v_min) / self.delta_z - # --- CHANGED: Rename l to lower_idx --- lower_idx = b.floor().long() - # --- END CHANGED --- u = b.ceil().long() - - # Handle cases where b is an integer (l == u) - # Ensure indices stay within bounds [0, num_atoms - 1] - # --- CHANGED: Use lower_idx --- lower_idx = torch.max(torch.tensor(0, device=self.device), lower_idx) - # --- END CHANGED --- u = torch.min(torch.tensor(self.num_atoms - 1, device=self.device), u) - # If l==u after clamping, it means the target hit an atom exactly. - # We can assign full probability to that atom. - # However, the logic below handles this implicitly. 
- - # Calculate probabilities for lower and upper atoms based on distance - # --- CHANGED: Use lower_idx --- - m_l = u.float() - b # Weight for lower atom - m_u = b - lower_idx.float() # Weight for upper atom - # --- END CHANGED --- - - # Distribute probability mass using direct indexing - # Create batch indices for advanced indexing + m_l = u.float() - b + m_u = b - lower_idx.float() batch_indices = torch.arange(batch_size, device=self.device) - - # Add probabilities to the lower atoms - # --- CHANGED: Use lower_idx --- m[batch_indices, lower_idx] += m_l - # --- END CHANGED --- - # Add probabilities to the upper atoms m[batch_indices, u] += m_u - return m - # --- END REWRITTEN --- - def train_step( self, per_sample: PERBatchSample ) -> tuple[dict[str, float], np.ndarray] | None: @@ -234,7 +197,6 @@ def train_step( self.model.train() try: - # --- Get n_step_return_t --- grid_t, other_t, policy_target_t, n_step_return_t = self._prepare_batch( batch ) @@ -244,44 +206,26 @@ def train_step( return None self.optimizer.zero_grad() - # --- Get value_logits --- policy_logits, value_logits = self.model(grid_t, other_t) - # --- Value Loss (Distributional Cross-Entropy) --- - # Calculate target distribution target_distribution = self._calculate_target_distribution(n_step_return_t) - # Calculate cross-entropy loss - # F.cross_entropy expects logits (N, C) and targets (N,) with class indices - # OR targets (N, C) with probabilities if soft labels are used. - # We have target probabilities, so use KLDivLoss or manual cross-entropy. 
- # Manual Cross-Entropy: - sum(target_prob * log_softmax(pred_logits)) log_pred_dist = F.log_softmax(value_logits, dim=1) value_loss_elementwise = -torch.sum(target_distribution * log_pred_dist, dim=1) - # Apply importance sampling weights value_loss = (value_loss_elementwise * is_weights_t).mean() - # --- Policy Loss (Cross-Entropy) --- (No change needed here) log_probs = F.log_softmax(policy_logits, dim=1) policy_target_t = torch.nan_to_num(policy_target_t, nan=0.0) policy_loss_elementwise = -torch.sum(policy_target_t * log_probs, dim=1) policy_loss = (policy_loss_elementwise * is_weights_t).mean() - # --- Entropy Bonus --- (No change needed here) - entropy_scalar: float = 0.0 # Initialize as float - entropy_loss_term = torch.tensor( - 0.0, device=self.device - ) # Initialize as tensor + entropy_scalar: float = 0.0 + entropy_loss_term = torch.tensor(0.0, device=self.device) if self.train_config.ENTROPY_BONUS_WEIGHT > 0: policy_probs = F.softmax(policy_logits, dim=1) - # Calculate entropy term: -Sum(p * log(p)) entropy_term_elementwise: torch.Tensor = -torch.sum( policy_probs * torch.log(policy_probs + 1e-9), dim=1 ) - # Calculate mean entropy across batch for logging - entropy_scalar = float( - entropy_term_elementwise.mean().item() - ) # Cast result to float - # Calculate the loss term (negative entropy bonus) + entropy_scalar = float(entropy_term_elementwise.mean().item()) entropy_loss_term = ( -self.train_config.ENTROPY_BONUS_WEIGHT * entropy_term_elementwise.mean() @@ -290,7 +234,7 @@ def train_step( total_loss = ( self.train_config.POLICY_LOSS_WEIGHT * policy_loss + self.train_config.VALUE_LOSS_WEIGHT * value_loss - + entropy_loss_term # Use the calculated term + + entropy_loss_term ) total_loss.backward() @@ -307,11 +251,8 @@ def train_step( if self.scheduler: self.scheduler.step() - # --- TD Error Calculation for PER --- - # Use the difference between the n-step return G and the expected value E[V(s)] with torch.no_grad(): expected_value_pred = 
self.nn._logits_to_expected_value(value_logits) - # Ensure n_step_return_t has shape (batch_size,) td_errors = ( (n_step_return_t - expected_value_pred.squeeze(1)).detach().cpu().numpy() ) @@ -329,7 +270,6 @@ def train_step( def get_current_lr(self) -> float: """Returns the current learning rate from the optimizer.""" try: - # Ensure return type is float return float(self.optimizer.param_groups[0]["lr"]) except (IndexError, KeyError): logger.warning("Could not retrieve learning rate from optimizer.") diff --git a/alphatriangle/rl/core/visual_state_actor.py b/alphatriangle/rl/core/visual_state_actor.py deleted file mode 100644 index 9c4ba79..0000000 --- a/alphatriangle/rl/core/visual_state_actor.py +++ /dev/null @@ -1,55 +0,0 @@ -# File: alphatriangle/rl/core/visual_state_actor.py -import logging -import time -from typing import Any - -import ray - -from ...environment import GameState - -logger = logging.getLogger(__name__) - - -@ray.remote -class VisualStateActor: - """A simple Ray actor to hold the latest game states from workers for visualization.""" - - def __init__(self) -> None: - self.worker_states: dict[int, GameState] = {} - self.global_stats: dict[str, Any] = {} - self.last_update_times: dict[int, float] = {} - - def update_state(self, worker_id: int, game_state: GameState): - """Workers call this to update their latest state.""" - self.worker_states[worker_id] = game_state - self.last_update_times[worker_id] = time.time() - - def update_global_stats(self, stats: dict[str, Any]): - """Orchestrator calls this to update global stats.""" - # Ensure stats is a dictionary - if isinstance(stats, dict): - # Use update to merge instead of direct assignment - self.global_stats.update(stats) - else: - # Handle error or log warning if stats is not a dict - logger.error( - f"VisualStateActor received non-dict type for global stats: {type(stats)}" - ) - # Don't reset, just ignore the update - # self.global_stats = {} - - def get_all_states(self) -> dict[int, Any]: - 
""" - Called by the orchestrator to get states for the visual queue. - Key -1 holds the global_stats dictionary. - Other keys hold GameState objects. - """ - # Use dict() constructor instead of comprehension for ruff C416 - # Cast worker_states to dict[int, Any] before combining - combined_states: dict[int, Any] = dict(self.worker_states) - combined_states[-1] = self.global_stats.copy() - return combined_states - - def get_state(self, worker_id: int) -> GameState | None: - """Get state for a specific worker (unused currently).""" - return self.worker_states.get(worker_id) diff --git a/alphatriangle/rl/self_play/worker.py b/alphatriangle/rl/self_play/worker.py index 3922e9d..0324dd2 100644 --- a/alphatriangle/rl/self_play/worker.py +++ b/alphatriangle/rl/self_play/worker.py @@ -1,4 +1,3 @@ -# File: alphatriangle/rl/self_play/worker.py import logging import random import time @@ -7,10 +6,13 @@ import numpy as np import ray -import torch # Import torch +# Import trianglengin components +from trianglengin.config import EnvConfig +from trianglengin.core.environment import GameState + +# Import alphatriangle components from ...config import MCTSConfig, ModelConfig, TrainConfig -from ...environment import EnvConfig, GameState from ...features import extract_state_features from ...mcts import ( MCTSExecutionError, @@ -21,22 +23,12 @@ ) from ...nn import NeuralNetwork from ...utils import get_device, set_random_seeds - -# --- REMOVED: Type imports moved below --- -# from ...utils.types import Experience, PolicyTargetMapping, StateType, StepInfo -# --- END REMOVED --- +from ..types import SelfPlayResult if TYPE_CHECKING: from ...stats import StatsCollectorActor - - # --- ADDED: Type imports moved here --- from ...utils.types import Experience, PolicyTargetMapping, StateType, StepInfo - # --- END ADDED --- - - -from ..types import SelfPlayResult - logger = logging.getLogger(__name__) @@ -44,16 +36,13 @@ class SelfPlayWorker: """ A Ray actor responsible for running self-play 
episodes using MCTS and a NN. - Implements MCTS tree reuse between steps. - Stores extracted features (StateType) and the N-STEP RETURN in the experience buffer. - Returns a SelfPlayResult Pydantic model including aggregated stats. - Reports current state and step stats asynchronously using StepInfo including game_step and trainer_step. + Uses trianglengin.GameState. """ def __init__( self, actor_id: int, - env_config: EnvConfig, + env_config: EnvConfig, # Use trianglengin.EnvConfig mcts_config: MCTSConfig, model_config: ModelConfig, train_config: TrainConfig, @@ -63,7 +52,7 @@ def __init__( worker_device_str: str = "cpu", ): self.actor_id = actor_id - self.env_config = env_config + self.env_config = env_config # Store trianglengin.EnvConfig self.mcts_config = mcts_config self.model_config = model_config self.train_config = train_config @@ -71,14 +60,10 @@ def __init__( self.seed = seed if seed is not None else random.randint(0, 1_000_000) self.worker_device_str = worker_device_str - # --- N-Step Config --- self.n_step = self.train_config.N_STEP_RETURNS self.gamma = self.train_config.GAMMA - - # Store the global step of the current weights self.current_trainer_step = 0 - # Configure logging for the worker process worker_log_level = logging.INFO log_format = ( f"%(asctime)s [%(levelname)s] [W{self.actor_id}] %(name)s: %(message)s" @@ -93,32 +78,9 @@ def __init__( logging.getLogger("alphatriangle.nn").setLevel(nn_log_level) set_random_seeds(self.seed) - self.device = get_device(self.worker_device_str) - if self.device.type == "cuda": - try: - torch.cuda.set_device(self.device) - logger.info( - f"Successfully set default CUDA device for worker {self.actor_id} to {self.device} (Index: {torch.cuda.current_device()})." - ) - count = torch.cuda.device_count() - if count != 1: - logger.warning( - f"Worker {self.actor_id} sees {count} CUDA devices, expected 1 after Ray assignment. This might indicate an issue." 
- ) - else: - logger.info( - f"Worker {self.actor_id} sees 1 CUDA device as expected." - ) - - except Exception as cuda_set_err: - logger.error( - f"Failed to set default CUDA device for worker {self.actor_id} to {self.device}: {cuda_set_err}. " - f"Compilation or CUDA operations might fail.", - exc_info=True, - ) - + # NN uses trianglengin.EnvConfig self.nn_evaluator = NeuralNetwork( model_config=self.model_config, env_config=self.env_config, @@ -140,7 +102,6 @@ def __init__( def set_weights(self, weights: dict): """Updates the neural network weights.""" try: - # Removed attempt to get step from weights dict self.nn_evaluator.set_weights(weights) logger.debug("Weights updated.") except Exception as e: @@ -153,17 +114,19 @@ def set_current_trainer_step(self, global_step: int): def _report_current_state(self, game_state: GameState): """Asynchronously sends the current game state to the collector.""" - if self.stats_collector_actor: - try: - state_copy = game_state.copy() - self.stats_collector_actor.update_worker_game_state.remote( # type: ignore - self.actor_id, state_copy - ) - logger.debug( - f"Reported state step {state_copy.current_step} to collector." - ) - except Exception as e: - logger.error(f"Failed to report game state to collector: {e}") + # REMOVED - No more visual state reporting + pass + # if self.stats_collector_actor: + # try: + # state_copy = game_state.copy() + # self.stats_collector_actor.update_worker_game_state.remote( # type: ignore + # self.actor_id, state_copy + # ) + # logger.debug( + # f"Reported state step {state_copy.current_step} to collector." 
+ # ) + # except Exception as e: + # logger.error(f"Failed to report game state to collector: {e}") def _log_step_stats_async( self, @@ -178,13 +141,13 @@ def _log_step_stats_async( """ if self.stats_collector_actor: try: - # Include current_trainer_step step_info: StepInfo = { "game_step_index": game_state.current_step, - "global_step": self.current_trainer_step, # Add trainer step context + "global_step": self.current_trainer_step, } step_stats: dict[str, tuple[float, StepInfo]] = { - "RL/Current_Score": (game_state.game_score, step_info), + # Use game_score() method from trianglengin.GameState + "RL/Current_Score": (game_state.game_score(), step_info), "MCTS/Step_Visits": (float(mcts_visits), step_info), "MCTS/Step_Depth": (float(mcts_depth), step_info), "RL/Step_Reward": (step_reward, step_info), @@ -197,15 +160,14 @@ def _log_step_stats_async( def run_episode(self) -> SelfPlayResult: """ Runs a single episode of self-play using MCTS and the internal neural network. - Implements MCTS tree reuse. - Stores extracted features (StateType) and the N-STEP RETURN in the experience buffer. - Returns a SelfPlayResult Pydantic model including aggregated stats. - Reports current state and step stats asynchronously. + Uses trianglengin.GameState. """ self.nn_evaluator.model.eval() episode_seed = self.seed + random.randint(0, 1000) + # Instantiate trianglengin.GameState game = GameState(self.env_config, initial_seed=episode_seed) + # Use is_over() from trianglengin.GameState if game.is_over(): logger.error( f"Game is over immediately after reset with seed {episode_seed}. Returning empty result." 
@@ -230,10 +192,12 @@ def run_episode(self) -> SelfPlayResult: step_simulations: list[int] = [] logger.info(f"Starting episode with seed {episode_seed}") - self._report_current_state(game) + # REMOVED self._report_current_state(game) + # Use copy() from trianglengin.GameState root_node: Node | None = Node(state=game.copy()) + # Use is_over() from trianglengin.GameState while not game.is_over(): step_start_time = time.monotonic() if root_node is None: @@ -242,6 +206,7 @@ def run_episode(self) -> SelfPlayResult: ) break + # Use is_over() from trianglengin.GameState if root_node.state.is_over(): logger.warning( f"MCTS root node state (Step {root_node.state.current_step}) is already terminal before running simulations. Ending episode." @@ -275,7 +240,7 @@ def run_episode(self) -> SelfPlayResult: f"Step {game.current_step}: MCTS finished ({mcts_duration:.3f}s). Max Depth: {mcts_max_depth}, Root Visits: {root_node.visit_count}" ) - # Log stats *before* taking the step + # Log stats *before* taking the step (reward is 0 here) self._log_step_stats_async( game, root_node.visit_count, mcts_max_depth, step_reward=0.0 ) @@ -304,6 +269,7 @@ def run_episode(self) -> SelfPlayResult: feature_start_time = time.monotonic() try: + # Pass trianglengin.GameState to feature extractor state_features: StateType = extract_state_features( game, self.model_config ) @@ -328,6 +294,7 @@ def run_episode(self) -> SelfPlayResult: game_step_start_time = time.monotonic() step_reward = 0.0 try: + # Use step() from trianglengin.GameState step_reward, done = game.step(action) except Exception as step_err: logger.error( @@ -351,6 +318,7 @@ def run_episode(self) -> SelfPlayResult: bootstrap_value = 0.0 if not done: try: + # Pass trianglengin.GameState to evaluator _, bootstrap_value = self.nn_evaluator.evaluate(game) except Exception as eval_err: logger.error( @@ -375,8 +343,8 @@ def run_episode(self) -> SelfPlayResult: ) ) - self._report_current_state(game) - # Log stats *after* taking the step + # 
REMOVED self._report_current_state(game) + # Log stats *after* taking the step, using the actual step_reward self._log_step_stats_async( game, root_node.visit_count if root_node else 0, @@ -386,7 +354,7 @@ def run_episode(self) -> SelfPlayResult: tree_reuse_start_time = time.monotonic() if not done: - if root_node and action in root_node.children: # Check root_node exists + if root_node and action in root_node.children: root_node = root_node.children[action] root_node.parent = None logger.debug( @@ -396,6 +364,7 @@ def run_episode(self) -> SelfPlayResult: logger.error( f"Child node for selected action {action} not found in MCTS tree children: {list(root_node.children.keys()) if root_node else 'No Root'}. Resetting MCTS root to current game state." ) + # Use copy() from trianglengin.GameState root_node = Node(state=game.copy()) else: root_node = None @@ -413,7 +382,8 @@ def run_episode(self) -> SelfPlayResult: if done: break - final_score = game.game_score + # Use game_score() from trianglengin.GameState + final_score = game.game_score() logger.info( f"Episode finished. Final Score: {final_score:.2f}, Steps: {game.current_step}" ) diff --git a/alphatriangle/stats/__init__.py b/alphatriangle/stats/__init__.py index 9cd14a8..e7160ec 100644 --- a/alphatriangle/stats/__init__.py +++ b/alphatriangle/stats/__init__.py @@ -1,26 +1,16 @@ -# File: alphatriangle/stats/__init__.py """ -Statistics collection and plotting module. +Statistics collection module. """ +# Import StatsCollectorData from utils where it should reside from alphatriangle.utils.types import StatsCollectorData -from . 
import plot_utils from .collector import StatsCollectorActor -from .plot_definitions import PlotDefinitions, PlotType # Import new definitions -from .plot_rendering import render_subplot # Import new rendering function -from .plotter import Plotter + +# REMOVE Plotter, PlotDefinitions, PlotType, render_subplot, plot_utils __all__ = [ # Core Collector "StatsCollectorActor", - "StatsCollectorData", - # Plotting Orchestrator - "Plotter", - # Plotting Definitions & Rendering Logic - "PlotDefinitions", - "PlotType", - "render_subplot", - # Plotting Utilities - "plot_utils", + "StatsCollectorData", # Export type alias ] diff --git a/alphatriangle/stats/collector.py b/alphatriangle/stats/collector.py index bb02344..217309e 100644 --- a/alphatriangle/stats/collector.py +++ b/alphatriangle/stats/collector.py @@ -7,10 +7,11 @@ import numpy as np import ray -from ..utils.types import StatsCollectorData, StepInfo - +# Correct import path for GameState if TYPE_CHECKING: - from ..environment import GameState + from trianglengin.core.environment import GameState + +from ..utils.types import StatsCollectorData, StepInfo logger = logging.getLogger(__name__) diff --git a/alphatriangle/stats/plot_definitions.py b/alphatriangle/stats/plot_definitions.py deleted file mode 100644 index 3825bd6..0000000 --- a/alphatriangle/stats/plot_definitions.py +++ /dev/null @@ -1,68 +0,0 @@ -# File: alphatriangle/stats/plot_definitions.py -from typing import Literal, NamedTuple - -# Define type for x-axis data source -PlotXAxisType = Literal["index", "global_step", "buffer_size"] - -# Define metric key constant for weight updates -WEIGHT_UPDATE_METRIC_KEY = "Internal/Weight_Update_Step" - - -class PlotDefinition(NamedTuple): - """Configuration for a single subplot.""" - - metric_key: str # Key in the StatsCollectorData dictionary - label: str # Title displayed on the plot - y_log_scale: bool # Use logarithmic scale for y-axis - x_axis_type: PlotXAxisType # What the x-axis represents - - -class 
PlotDefinitions: - """Holds the definitions for all plots in the dashboard.""" - - def __init__(self, colors: dict[str, tuple[float, float, float]]): - self.colors = colors # Store colors if needed for default lookups - self.nrows: int = 4 - self.ncols: int = 3 - # Key used to get weight update steps for vertical lines - self.weight_update_key = WEIGHT_UPDATE_METRIC_KEY # Use the constant - - # Define the layout and properties of each plot - self._definitions: list[PlotDefinition] = [ - # Row 1 - # --- CHANGED: x_axis_type to "index" --- - PlotDefinition("RL/Current_Score", "Score", False, "index"), - PlotDefinition( - "Rate/Episodes_Per_Sec", "Episodes/sec", False, "buffer_size" - ), - PlotDefinition("Loss/Total", "Total Loss", True, "global_step"), - # Row 2 - PlotDefinition("RL/Step_Reward", "Step Reward", False, "index"), - PlotDefinition( - "Rate/Simulations_Per_Sec", "Sims/sec", False, "buffer_size" - ), - PlotDefinition("Loss/Policy", "Policy Loss", True, "global_step"), - # Row 3 - PlotDefinition("MCTS/Step_Visits", "MCTS Visits", False, "index"), - PlotDefinition("Buffer/Size", "Buffer Size", False, "buffer_size"), - PlotDefinition("Loss/Value", "Value Loss", True, "global_step"), - # Row 4 - PlotDefinition("MCTS/Step_Depth", "MCTS Depth", False, "index"), - # --- END CHANGED --- - PlotDefinition("Rate/Steps_Per_Sec", "Steps/sec", False, "global_step"), - PlotDefinition("LearningRate", "Learn Rate", True, "global_step"), - ] - - # Validate grid size - if len(self._definitions) > self.nrows * self.ncols: - raise ValueError( - f"Number of plot definitions ({len(self._definitions)}) exceeds grid size ({self.nrows}x{self.ncols})" - ) - - def get_definitions(self) -> list[PlotDefinition]: - """Returns the list of plot definitions.""" - return self._definitions - - -# Define PlotType for potential external use, though PlotDefinition is more specific -PlotType = PlotDefinition diff --git a/alphatriangle/stats/plot_rendering.py 
b/alphatriangle/stats/plot_rendering.py deleted file mode 100644 index bc2bdfd..0000000 --- a/alphatriangle/stats/plot_rendering.py +++ /dev/null @@ -1,317 +0,0 @@ -# File: alphatriangle/stats/plot_rendering.py -import logging -from collections import deque -from typing import TYPE_CHECKING # Added cast - -import matplotlib.pyplot as plt -import numpy as np -from matplotlib.ticker import FuncFormatter, MaxNLocator - -from ..utils.types import StepInfo -from .plot_definitions import PlotDefinition -from .plot_utils import calculate_rolling_average, format_value - -if TYPE_CHECKING: - from .collector import StatsCollectorData - -logger = logging.getLogger(__name__) - - -def _find_closest_index_for_global_step( - target_global_step: int, - step_info_list: list[StepInfo], -) -> int | None: - """ - Finds the index in the step_info_list where the stored 'global_step' - is closest to the target_global_step. - Returns None if no suitable point is found or list is empty. - """ - if not step_info_list: - return None - - best_match_idx = None - min_step_diff = float("inf") - - for i, step_info in enumerate(step_info_list): - global_step_in_info = step_info.get("global_step") - - if global_step_in_info is not None: - step_diff = abs(global_step_in_info - target_global_step) - if step_diff < min_step_diff: - min_step_diff = step_diff - best_match_idx = i - # Optimization: If we found an exact match, we can stop early - # Also, if the steps start increasing again, we passed the best point - if step_diff == 0 or ( - best_match_idx is not None and global_step_in_info > target_global_step - ): - break - - # Optional: Add a tolerance? If min_step_diff is too large, maybe don't return a match? - # For now, return the index of the closest found value. 
- return best_match_idx - - -def render_subplot( - ax: plt.Axes, - plot_data: "StatsCollectorData", - plot_def: PlotDefinition, - colors: dict[str, tuple[float, float, float]], - rolling_window_sizes: list[int], - weight_update_steps: list[int] | None = None, # Global steps where updates happened -) -> bool: - """ - Renders data for a single metric onto the given Matplotlib Axes object. - Scatter point size/alpha decrease linearly as more data/longer averages are shown. - Draws semi-transparent black, solid vertical lines for weight updates on all plots. - Returns True if data was plotted, False otherwise. - """ - ax.clear() - ax.set_facecolor((0.15, 0.15, 0.15)) # Dark background for axes - - metric_key = plot_def.metric_key - label = plot_def.label - log_scale = plot_def.y_log_scale - x_axis_type = plot_def.x_axis_type # e.g., "global_step", "buffer_size", "index" - - x_data: list[int] = [] - y_data: list[float] = [] - x_label_text = "Index" # Default label - step_info_list: list[StepInfo] = [] # Store step info for mapping - - dq = plot_data.get(metric_key, deque()) - if dq: - # Extract x-axis value and store StepInfo - temp_x = [] - temp_y = [] - for i, (step_info, value) in enumerate(dq): - x_val: int | None = None - if x_axis_type == "global_step": - x_val = step_info.get("global_step") - x_label_text = "Train Step" - elif x_axis_type == "buffer_size": - x_val = step_info.get("buffer_size") - x_label_text = "Buffer Size" - else: # index - x_val = i # Use the simple index 'i' as the x-value - if ( - "Score" in metric_key - or "Reward" in metric_key - or "MCTS" in metric_key - ): - x_label_text = "Game Step Index" # Label remains descriptive - else: - x_label_text = "Data Point Index" - - if x_val is not None: - temp_x.append(x_val) - temp_y.append(value) - step_info_list.append( - step_info - ) # Keep StepInfo aligned with data points - else: - logger.warning( - f"Missing x-axis key '{x_axis_type}' in step_info for metric '{metric_key}'. Skipping point." 
- ) - - x_data = temp_x - y_data = temp_y - - color_mpl = colors.get(metric_key, (0.5, 0.5, 0.5)) - placeholder_color_mpl = colors.get("placeholder", (0.5, 0.5, 0.5)) - - data_plotted = False - if not x_data or not y_data: - ax.text( - 0.5, - 0.5, - f"{label}\n(No Data)", - ha="center", - va="center", - transform=ax.transAxes, - color=placeholder_color_mpl, - fontsize=9, - ) - ax.set_title(label, loc="left", fontsize=9, color="white", pad=2) - ax.set_xticks([]) - ax.set_yticks([]) - else: - data_plotted = True - - # Determine best rolling average window - num_points = len(y_data) - best_window = 0 - for window in sorted(rolling_window_sizes, reverse=True): - if num_points >= window: - best_window = window - break - - # Determine scatter size/alpha based on best_window - # Initialize as float - scatter_size: float = 0.0 - scatter_alpha: float = 0.0 - max_scatter_size = 10.0 # Use float - min_scatter_size = 1.0 # Use float - max_scatter_alpha = 0.3 - min_scatter_alpha = 0.03 - max_window_for_interp = float(max(rolling_window_sizes)) - - if best_window == 0: - scatter_size = max_scatter_size - scatter_alpha = max_scatter_alpha - elif best_window >= max_window_for_interp: - scatter_size = min_scatter_size - scatter_alpha = min_scatter_alpha - else: - interp_fraction = best_window / max_window_for_interp - # Cast result of np.interp to float - scatter_size = float( - np.interp(interp_fraction, [0, 1], [max_scatter_size, min_scatter_size]) - ) - scatter_alpha = float( - np.interp( - interp_fraction, [0, 1], [max_scatter_alpha, min_scatter_alpha] - ) - ) - - # Plot raw data with dynamic size/alpha - ax.scatter( - x_data, - y_data, - color=color_mpl, - alpha=scatter_alpha, - s=scatter_size, # Pass float size - label="_nolegend_", - zorder=2, - ) - - # Plot best rolling average - if best_window > 0: - rolling_avg = calculate_rolling_average(y_data, best_window) - if len(rolling_avg) == len(x_data): - ax.plot( - x_data, - rolling_avg, - color=color_mpl, - alpha=0.9, - 
linewidth=1.5, - label=f"Avg {best_window}", - zorder=3, - ) - ax.legend( - fontsize=6, loc="upper right", frameon=False, labelcolor="lightgray" - ) - else: - logger.warning( - f"Length mismatch for rolling avg ({len(rolling_avg)}) vs x_data ({len(x_data)}) for {label}. Skipping avg plot." - ) - - # Draw vertical lines by mapping global_step to current x-axis value - if weight_update_steps and step_info_list: - plotted_line_x_values: set[float] = set() # Store plotted x-values as float - for update_global_step in weight_update_steps: - x_index_for_line = _find_closest_index_for_global_step( - update_global_step, step_info_list - ) - - if x_index_for_line is not None and x_index_for_line < len(x_data): - actual_x_value: int | float - if x_axis_type == "index": - actual_x_value = x_index_for_line # int - else: - # Explicitly cast list access to int to satisfy MyPy - actual_x_value = int(x_data[x_index_for_line]) # int - - # Cast to float for axvline and check if already plotted - actual_x_float = float(actual_x_value) - if actual_x_float not in plotted_line_x_values: - ax.axvline( - x=actual_x_float, # Pass float - color="black", - linestyle="-", - linewidth=0.7, - alpha=0.5, - zorder=10, - ) - plotted_line_x_values.add(actual_x_float) - else: - logger.debug( - f"Could not map global_step {update_global_step} to an index for plot '{label}'" - ) - - # Formatting - ax.set_title(label, loc="left", fontsize=9, color="white", pad=2) - ax.tick_params(axis="both", which="major", labelsize=7, colors="lightgray") - ax.grid( - True, linestyle=":", linewidth=0.5, color=(0.4, 0.4, 0.4), zorder=1 - ) # Ensure grid is behind lines - - # Set y-axis scale - if log_scale: - ax.set_yscale("log") - min_val = min((v for v in y_data if v > 0), default=1e-6) - max_val = max(y_data, default=1.0) - ylim_bottom = max(1e-9, min_val * 0.1) - ylim_top = max_val * 10 - if ylim_bottom < ylim_top: - ax.set_ylim(bottom=ylim_bottom, top=ylim_top) - else: - ax.set_ylim(bottom=1e-9, top=1.0) - else: - 
ax.set_yscale("linear") - min_val = min(y_data) if y_data else 0.0 - max_val = max(y_data) if y_data else 0.0 - val_range = max_val - min_val - if abs(val_range) < 1e-6: - center_val = (min_val + max_val) / 2.0 - buffer = max(abs(center_val * 0.1), 0.5) - ylim_bottom, ylim_top = center_val - buffer, center_val + buffer - else: - buffer = val_range * 0.1 - ylim_bottom, ylim_top = min_val - buffer, max_val + buffer - if all(v >= 0 for v in y_data) and ylim_bottom < 0: - ylim_bottom = 0.0 - if ylim_bottom >= ylim_top: - ylim_bottom, ylim_top = min_val - 0.5, max_val + 0.5 - if ylim_bottom >= ylim_top: - ylim_bottom, ylim_top = 0.0, max(1.0, max_val) - ax.set_ylim(bottom=ylim_bottom, top=ylim_top) - - # Format x-axis - ax.xaxis.set_major_locator(MaxNLocator(nbins=4, integer=True)) - ax.xaxis.set_major_formatter( - FuncFormatter( - lambda x, _: f"{int(x / 1000)}k" if x >= 1000 else f"{int(x)}" - ) - ) - ax.set_xlabel(x_label_text, fontsize=8, color="gray") - - # Format y-axis - ax.yaxis.set_major_locator(MaxNLocator(nbins=5)) - ax.yaxis.set_major_formatter(FuncFormatter(lambda y, _: format_value(y))) - - # Add info text (min/max/current) - current_val_str = format_value(y_data[-1]) - min_val_overall = min(y_data) - max_val_overall = max(y_data) - min_str = format_value(min_val_overall) - max_str = format_value(max_val_overall) - info_text = f"Min:{min_str} | Max:{max_str} | Cur:{current_val_str}" - ax.text( - 1.0, - 1.01, - info_text, - ha="right", - va="bottom", - transform=ax.transAxes, - fontsize=6, - color="white", - ) - - # Common Axis Styling (applied regardless of data presence) - ax.spines["top"].set_visible(False) - ax.spines["right"].set_visible(False) - ax.spines["bottom"].set_color("gray") - ax.spines["left"].set_color("gray") - - return data_plotted diff --git a/alphatriangle/stats/plot_utils.py b/alphatriangle/stats/plot_utils.py deleted file mode 100644 index ebc8733..0000000 --- a/alphatriangle/stats/plot_utils.py +++ /dev/null @@ -1,178 +0,0 @@ -# File: 
# File: alphatriangle/stats/plot_utils.py
import logging

import matplotlib.pyplot as plt
import numpy as np
from matplotlib.ticker import FuncFormatter, MaxNLocator

logger = logging.getLogger(__name__)


def calculate_rolling_average(data: list[float], window: int) -> list[float]:
    """Calculates the rolling average with handling for edges.

    Returns a list the same length as ``data``: the first ``window - 1``
    entries are expanding means, the rest are true ``window``-wide means.
    Returns [] for empty data or a non-positive window.
    """
    if not data or window <= 0:
        return []
    if window > len(data):
        # Window larger than the series: use the overall mean for every point.
        avg = np.mean(data)
        # Cast to float explicitly (np.mean returns a numpy scalar).
        return [float(avg)] * len(data)
    # Use convolution for an efficient fixed-width rolling average.
    weights = np.ones(window) / window
    rolling_avg = np.convolve(data, weights, mode="valid")
    # Pad the beginning with expanding means to match the original length.
    padding = [float(np.mean(data[:i])) for i in range(1, window)]
    # Cast result to list of floats.
    return padding + [float(x) for x in rolling_avg]


def calculate_trend_line(
    steps: list[int], values: list[float]
) -> tuple[list[int], list[float]]:
    """Calculates a simple linear (degree-1 least-squares) trend line.

    Returns ([], []) when fewer than two points are given or the fit fails.
    """
    if len(steps) < 2:
        return [], []
    try:
        coeffs = np.polyfit(steps, values, 1)
        poly = np.poly1d(coeffs)
        trend_values = poly(steps)
        return steps, list(trend_values)
    except Exception as e:
        # Best-effort: a failed fit only loses the overlay, never the plot.
        logger.warning(f"Could not calculate trend line: {e}")
        return [], []


def format_value(value: float) -> str:
    """Formats a float value for display on the plot.

    Uses "0" for near-zero, scientific notation for tiny magnitudes,
    k/M suffixes for large magnitudes, and fixed point in between.
    """
    if abs(value) < 1e-6:
        return "0"
    # FIX: threshold was 1e-3, which let values in [1e-3, 5e-3) fall through
    # to the final f"{value:.2f}" branch and render as a misleading "0.00".
    # 5e-3 is the smallest magnitude that still rounds to "0.01" at two
    # decimal places, so everything below it uses scientific notation.
    if abs(value) < 5e-3:
        return f"{value:.2e}"
    if abs(value) >= 1e6:
        return f"{value / 1e6:.1f}M"
    if abs(value) >= 1e3:
        return f"{value / 1e3:.1f}k"
    if abs(value) >= 10:
        return f"{value:.1f}"
    return f"{value:.2f}"


def render_single_plot(
    ax: plt.Axes,
    steps: list[int],
    values: list[float],
    label: str,
    color: tuple[float, float, float],
    placeholder_color: tuple[float, float, float],
    rolling_window_sizes: list[int],
    show_placeholder: bool = False,
    placeholder_text: str | None = None,
    y_log_scale: bool = False,
):
    """
    Renders a single metric plot onto a Matplotlib Axes object.
    Plots raw data and the single best rolling average line.
    """
    ax.clear()
    ax.set_facecolor((0.15, 0.15, 0.15))  # Dark background for axes

    if show_placeholder or not steps or not values:
        text_to_display = placeholder_text if placeholder_text else "(No Data)"
        ax.text(
            0.5,
            0.5,
            text_to_display,
            ha="center",
            va="center",
            transform=ax.transAxes,
            color=placeholder_color,
            fontsize=9,
        )
        ax.set_title(label, loc="left", fontsize=9, color="white", pad=2)
        ax.set_xticks([])
        ax.set_yticks([])
        ax.spines["top"].set_visible(False)
        ax.spines["right"].set_visible(False)
        ax.spines["bottom"].set_color("gray")
        ax.spines["left"].set_color("gray")
        return

    # Plot raw data (thin, semi-transparent).
    ax.plot(steps, values, color=color, alpha=0.3, linewidth=0.7, label="_nolegend_")

    # Plot only the single best rolling average: the largest configured
    # window that still fits inside the available number of points.
    num_points = len(steps)
    best_window = 0
    for window in sorted(rolling_window_sizes, reverse=True):
        if num_points >= window:
            best_window = window
            break  # Found the largest applicable window

    if best_window > 0:
        rolling_avg = calculate_rolling_average(values, best_window)
        ax.plot(
            steps,
            rolling_avg,
            color=color,
            alpha=0.9,  # Make it more prominent
            linewidth=1.5,
            label=f"Avg {best_window}",  # Label this single line
        )
        # Add legend only if a rolling average was plotted.
        ax.legend(fontsize=6, loc="upper right", frameon=False, labelcolor="lightgray")

    # Formatting
    ax.set_title(label, loc="left", fontsize=9, color="white", pad=2)
    ax.tick_params(axis="both", which="major", labelsize=7, colors="lightgray")
    ax.grid(True, linestyle=":", linewidth=0.5, color=(0.4, 0.4, 0.4))
    ax.spines["top"].set_visible(False)
    ax.spines["right"].set_visible(False)
    ax.spines["bottom"].set_color("gray")
    ax.spines["left"].set_color("gray")

    # Set y-axis scale
    if y_log_scale:
        ax.set_yscale("log")
        # Ensure positive limits for log scale; fall back when all values <= 0.
        min_val = (
            min(v for v in values if v > 0) if any(v > 0 for v in values) else 1e-6
        )
        max_val = max(values) if values else 1.0
        # Add a decade of buffer on each side for log scale.
        ylim_bottom = max(1e-9, min_val * 0.1)
        ylim_top = max_val * 10
        # Prevent potential errors if limits are invalid.
        if ylim_bottom < ylim_top:
            ax.set_ylim(bottom=ylim_bottom, top=ylim_top)
        else:
            ax.set_ylim(bottom=1e-9, top=1.0)  # Fallback limits
    else:
        ax.set_yscale("linear")

    # Format x-axis (steps)
    ax.xaxis.set_major_locator(MaxNLocator(nbins=4, integer=True))
    ax.xaxis.set_major_formatter(
        FuncFormatter(lambda x, _: f"{int(x / 1000)}k" if x >= 1000 else f"{int(x)}")
    )
    ax.set_xlabel("Step", fontsize=8, color="gray")

    # Format y-axis
    ax.yaxis.set_major_locator(MaxNLocator(nbins=5))
    ax.yaxis.set_major_formatter(FuncFormatter(lambda y, _: format_value(y)))

    # Add current value text above the top-right corner of the axes.
    current_val_str = format_value(values[-1])
    ax.text(
        1.0,
        1.01,
        f"Cur: {current_val_str}",
        ha="right",
        va="bottom",
        transform=ax.transAxes,
        fontsize=7,
        color="white",
    )


# File: alphatriangle/stats/plotter.py
import contextlib
import logging
import time
from collections import deque
from io import BytesIO
from typing import TYPE_CHECKING, Any

import matplotlib

if TYPE_CHECKING:
    import numpy as np

    # --- MOVED: Import vis_colors only for type checking ---

import pygame

# Use Agg backend before importing pyplot
matplotlib.use("Agg")
import matplotlib.pyplot as plt  # noqa: E402

# --- MOVED: Import normalize_color_for_matplotlib here ---
from ..utils.helpers import normalize_color_for_matplotlib  # noqa: E402

# --- CHANGED: Import StepInfo ---
from ..utils.types import StatsCollectorData  # noqa: E402

# --- END CHANGED ---
from .plot_definitions import (  # noqa: E402
    WEIGHT_UPDATE_METRIC_KEY,  # Import key
    PlotDefinitions,
)
from .plot_rendering import render_subplot  # Import subplot rendering logic

logger = logging.getLogger(__name__)


class Plotter:
    """
    Handles creation and caching of the multi-plot Matplotlib surface.
    Uses PlotDefinitions for layout and plot_rendering for drawing subplots.

    Rendering is throttled: the cached pygame surface is reused until either
    the data hash changes or `plot_update_interval` seconds have elapsed.
    """

    def __init__(self, plot_update_interval: float = 0.75):  # Increased interval
        # Last rendered pygame surface; returned as-is while still fresh.
        self.plot_surface_cache: pygame.Surface | None = None
        self.last_plot_update_time: float = 0.0
        self.plot_update_interval: float = plot_update_interval
        self.colors = self._init_colors()
        self.plot_definitions = PlotDefinitions(self.colors)  # Instantiate definitions

        # Candidate rolling-average windows; the renderer picks the largest
        # window that fits the available number of points.
        self.rolling_window_sizes: list[int] = [
            10,
            50,
            100,
            500,
            1000,
            5000,
        ]

        self.fig: plt.Figure | None = None
        self.axes: np.ndarray | None = None  # type: ignore  # numpy is type-checked only
        self.last_target_size: tuple[int, int] = (0, 0)
        self.last_data_hash: int | None = None

        logger.info(
            f"[Plotter] Initialized with update interval: {self.plot_update_interval}s"
        )

    def _init_colors(self) -> dict[str, tuple[float, float, float]]:
        """Initializes plot colors using hardcoded values to avoid vis import."""
        # This breaks the circular import. Ensure these match vis_colors.py.
        colors_rgb = {
            "YELLOW": (230, 230, 40),
            "WHITE": (255, 255, 255),
            "LIGHT_GRAY": (180, 180, 180),
            "LIGHTG": (144, 238, 144),
            "RED": (220, 40, 40),
            "BLUE": (60, 60, 220),
            "GREEN": (40, 200, 40),
            "CYAN": (40, 200, 200),
            "PURPLE": (140, 40, 140),
            "BLACK": (0, 0, 0),
            "GRAY": (100, 100, 100),
            "ORANGE": (240, 150, 20),
            "HOTPINK": (255, 105, 180),
        }

        # Metric key -> matplotlib color (0..1 floats).
        return {
            "RL/Current_Score": normalize_color_for_matplotlib(colors_rgb["YELLOW"]),
            "RL/Step_Reward": normalize_color_for_matplotlib(colors_rgb["WHITE"]),
            "MCTS/Step_Visits": normalize_color_for_matplotlib(
                colors_rgb["LIGHT_GRAY"]
            ),
            "MCTS/Step_Depth": normalize_color_for_matplotlib(colors_rgb["LIGHTG"]),
            "Loss/Total": normalize_color_for_matplotlib(colors_rgb["RED"]),
            "Loss/Value": normalize_color_for_matplotlib(colors_rgb["BLUE"]),
            "Loss/Policy": normalize_color_for_matplotlib(colors_rgb["GREEN"]),
            "LearningRate": normalize_color_for_matplotlib(colors_rgb["CYAN"]),
            "Buffer/Size": normalize_color_for_matplotlib(colors_rgb["PURPLE"]),
            WEIGHT_UPDATE_METRIC_KEY: normalize_color_for_matplotlib(
                colors_rgb["BLACK"]
            ),
            "placeholder": normalize_color_for_matplotlib(colors_rgb["GRAY"]),
            "Rate/Steps_Per_Sec": normalize_color_for_matplotlib(colors_rgb["ORANGE"]),
            "Rate/Episodes_Per_Sec": normalize_color_for_matplotlib(
                colors_rgb["HOTPINK"]
            ),
            "Rate/Simulations_Per_Sec": normalize_color_for_matplotlib(
                colors_rgb["LIGHTG"]
            ),
            "PER/Beta": normalize_color_for_matplotlib(colors_rgb["ORANGE"]),
            "Loss/Entropy": normalize_color_for_matplotlib(colors_rgb["PURPLE"]),
            "Loss/Mean_TD_Error": normalize_color_for_matplotlib(colors_rgb["RED"]),
            "Progress/Train_Step_Percent": normalize_color_for_matplotlib(
                colors_rgb["GREEN"]
            ),
            "Progress/Total_Simulations": normalize_color_for_matplotlib(
                colors_rgb["CYAN"]
            ),
        }

    def _init_figure(self, target_width: int, target_height: int):
        """Initializes the Matplotlib figure and axes based on plot definitions.

        Closes any previous figure first; on failure resets fig/axes/size so
        the next call retries from scratch.
        """
        logger.info(
            f"[Plotter] Initializing Matplotlib figure for size {target_width}x{target_height}"
        )
        if self.fig:
            try:
                plt.close(self.fig)
            except Exception as e:
                logger.warning(f"Error closing previous figure: {e}")

        dpi = 96
        fig_width_in = max(1, target_width / dpi)
        fig_height_in = max(1, target_height / dpi)

        try:
            nrows = self.plot_definitions.nrows
            ncols = self.plot_definitions.ncols
            self.fig, self.axes = plt.subplots(
                nrows,
                ncols,
                figsize=(fig_width_in, fig_height_in),
                dpi=dpi,
                sharex=False,
            )
            if self.axes is None:
                raise RuntimeError("Failed to create Matplotlib subplots.")

            self.fig.patch.set_facecolor((0.1, 0.1, 0.1))
            self.fig.subplots_adjust(
                hspace=0.40,
                wspace=0.08,
                left=0.03,
                right=0.99,
                bottom=0.05,
                top=0.98,
            )
            self.last_target_size = (target_width, target_height)
            logger.info(
                f"[Plotter] Matplotlib figure initialized ({nrows}x{ncols} grid)."
            )
        except Exception as e:
            logger.error(f"Error creating Matplotlib figure: {e}", exc_info=True)
            self.fig, self.axes, self.last_target_size = None, None, (0, 0)

    def _get_data_hash(self, plot_data: StatsCollectorData) -> int:
        """Generates a hash based on data lengths and recent values.

        Cheap change-detection: only the last few samples of each series are
        hashed, so a change deep in history may go unnoticed (acceptable for
        append-only metric deques).
        """
        hash_val = 0
        sample_size = 5
        for key in sorted(plot_data.keys()):
            dq = plot_data[key]
            hash_val ^= hash(key) ^ len(dq)
            if not dq:
                continue
            try:
                num_to_sample = min(len(dq), sample_size)
                for i in range(-1, -num_to_sample - 1, -1):
                    # Hash StepInfo dict and value.
                    step_info, val = dq[i]
                    # Simple hash for dict: hash tuple of sorted items.
                    step_info_hash = hash(tuple(sorted(step_info.items())))
                    # Format the float first so tiny representation noise
                    # beyond 6 decimals does not invalidate the cache.
                    hash_val ^= step_info_hash ^ hash(f"{val:.6f}")
            except IndexError:
                pass
        return hash_val

    def _update_plot_data(self, plot_data: StatsCollectorData) -> bool:
        """Updates the data on the existing Matplotlib axes using render_subplot.

        Returns True on success; on failure clears the axes and returns False
        so the caller can keep serving the stale cached surface.
        """
        if self.fig is None or self.axes is None:
            logger.warning("[Plotter] Cannot update plot data, figure not initialized.")
            return False

        plot_update_start = time.monotonic()
        try:
            axes_flat = self.axes.flatten()
            plot_defs = self.plot_definitions.get_definitions()
            num_plots = len(plot_defs)

            # Extract weight update steps (global_step values).
            weight_update_steps: list[int] = []
            if WEIGHT_UPDATE_METRIC_KEY in plot_data:
                dq = plot_data[WEIGHT_UPDATE_METRIC_KEY]
                if dq:
                    # Extract global_step from StepInfo.
                    weight_update_steps = [
                        step_info["global_step"]
                        for step_info, _ in dq
                        if "global_step" in step_info
                    ]

            for i, plot_def in enumerate(plot_defs):
                if i >= len(axes_flat):
                    break
                ax = axes_flat[i]
                # Pass weight_update_steps so subplots can mark update events.
                render_subplot(
                    ax=ax,
                    plot_data=plot_data,
                    plot_def=plot_def,
                    colors=self.colors,
                    rolling_window_sizes=self.rolling_window_sizes,
                    weight_update_steps=weight_update_steps,  # Pass the list
                )

            # Blank out any grid cells beyond the defined plots.
            for i in range(num_plots, len(axes_flat)):
                ax = axes_flat[i]
                ax.clear()
                ax.set_xticks([])
                ax.set_yticks([])
                ax.set_facecolor((0.15, 0.15, 0.15))
                ax.spines["top"].set_visible(False)
                ax.spines["right"].set_visible(False)
                ax.spines["bottom"].set_color("gray")
                ax.spines["left"].set_color("gray")

            self._apply_final_axis_formatting(axes_flat)

            plot_update_duration = time.monotonic() - plot_update_start
            logger.debug(f"[Plotter] Plot data updated in {plot_update_duration:.4f}s")
            return True

        except Exception as e:
            logger.error(f"Error updating plot data: {e}", exc_info=True)
            try:
                if self.axes is not None:
                    for ax in self.axes.flatten():
                        ax.clear()
            except Exception:
                pass
            return False

    def _apply_final_axis_formatting(self, axes_flat: Any):
        """Hides x-axis labels for plots not in the bottom row."""
        if not hasattr(axes_flat, "__iter__"):
            logger.error("axes_flat is not iterable in _apply_final_axis_formatting")
            return

        nrows = self.plot_definitions.nrows
        ncols = self.plot_definitions.ncols
        num_plots = len(self.plot_definitions.get_definitions())

        for i, ax in enumerate(axes_flat):
            if i >= num_plots:
                continue

            # Indices below the last row keep ticks but drop labels.
            if i < (nrows - 1) * ncols:
                ax.set_xticklabels([])
                ax.set_xlabel("")
                ax.tick_params(axis="x", rotation=0)

    def _render_figure_to_surface(
        self, target_width: int, target_height: int
    ) -> pygame.Surface | None:
        """Renders the current Matplotlib figure to a Pygame surface.

        Goes through an in-memory PNG, then rescales if the rendered size
        differs from the requested target. Returns None on failure.
        """
        if self.fig is None:
            logger.warning("[Plotter] Cannot render figure, not initialized.")
            return None

        render_start = time.monotonic()
        try:
            self.fig.canvas.draw_idle()
            buf = BytesIO()
            self.fig.savefig(
                buf, format="png", transparent=False, facecolor=self.fig.get_facecolor()
            )
            buf.seek(0)
            plot_img_surface = pygame.image.load(buf, "png").convert_alpha()
            buf.close()

            current_size = plot_img_surface.get_size()
            if current_size != (target_width, target_height):
                plot_img_surface = pygame.transform.smoothscale(
                    plot_img_surface, (target_width, target_height)
                )
            render_duration = time.monotonic() - render_start
            logger.debug(
                f"[Plotter] Figure rendered to surface in {render_duration:.4f}s"
            )
            return plot_img_surface

        except Exception as e:
            logger.error(f"Error rendering Matplotlib figure: {e}", exc_info=True)
            return None

    def get_plot_surface(
        self, plot_data: StatsCollectorData, target_width: int, target_height: int
    ) -> pygame.Surface | None:
        """Returns the cached plot surface or creates/updates one if needed.

        Re-renders when the figure needs (re)initialization, the data hash
        changed, or the update interval elapsed. May return a stale surface
        when an update fails, or None when the target area is too small.
        """
        current_time = time.time()
        # Ignore internal bookkeeping series when deciding if there is data.
        has_data = any(
            isinstance(dq, deque) and dq
            for key, dq in plot_data.items()
            if not key.startswith("Internal/")
        )
        target_size = (target_width, target_height)

        needs_reinit = (
            self.fig is None
            or self.axes is None
            or self.last_target_size != target_size
        )
        current_data_hash = self._get_data_hash(plot_data)
        data_changed = self.last_data_hash != current_data_hash
        time_elapsed = (
            current_time - self.last_plot_update_time
        ) > self.plot_update_interval
        needs_update = data_changed or time_elapsed
        can_create_plot = target_width > 50 and target_height > 50

        if not can_create_plot:
            # Target area too small to be useful: drop cache and figure.
            if self.plot_surface_cache is not None:
                logger.info("[Plotter] Target size too small, clearing cache/figure.")
                self.plot_surface_cache = None
                if self.fig:
                    plt.close(self.fig)
                self.fig, self.axes, self.last_target_size = None, None, (0, 0)
            return None

        if not has_data:
            logger.debug("[Plotter] No plot data available, returning cache (if any).")
            return self.plot_surface_cache

        try:
            if needs_reinit:
                self._init_figure(target_width, target_height)
                needs_update = True

            if needs_update and self.fig:
                if self._update_plot_data(plot_data):
                    self.plot_surface_cache = self._render_figure_to_surface(
                        target_width, target_height
                    )
                    self.last_plot_update_time = current_time
                    self.last_data_hash = current_data_hash
                else:
                    logger.warning(
                        "[Plotter] Plot update failed, returning stale cache if available."
                    )
            elif (
                self.plot_surface_cache is None
                and self.fig
                and self._update_plot_data(plot_data)
            ):
                # No update was due, but there is no cached surface yet:
                # render one so callers do not see None indefinitely.
                self.plot_surface_cache = self._render_figure_to_surface(
                    target_width, target_height
                )
                self.last_plot_update_time = current_time
                self.last_data_hash = current_data_hash

        except Exception as e:
            logger.error(f"[Plotter] Error in get_plot_surface: {e}", exc_info=True)
            self.plot_surface_cache = None
            if self.fig:
                with contextlib.suppress(Exception):
                    plt.close(self.fig)
            self.fig, self.axes, self.last_target_size = None, None, (0, 0)

        return self.plot_surface_cache

    def __del__(self):
        """Ensure Matplotlib figure is closed."""
        if self.fig:
            try:
                plt.close(self.fig)
            except Exception as e:
                # Logging may already be torn down during interpreter exit.
                print(f"[Plotter] Error closing figure in destructor: {e}")
- -By placing these core definitions in a low-level module with minimal dependencies, other modules can import them without creating import cycles. - -## Exposed Interfaces - -- **Classes:** - - `Triangle`: Represents a grid cell. - - `Shape`: Represents a placeable piece. -- **Constants:** - - `SHAPE_COLORS`: A list of RGB tuples for shape generation. - - `NO_COLOR_ID`: Integer ID for empty cells in `GridData`. - - `DEBUG_COLOR_ID`: Integer ID for debug-toggled cells in `GridData`. - - `COLOR_ID_MAP`: List mapping color ID index to RGB tuple. - - `COLOR_TO_ID_MAP`: Dictionary mapping RGB tuple to color ID index. - -## Dependencies - -This module has minimal dependencies, primarily relying on standard Python libraries (`typing`). It should **not** import from higher-level modules like `environment`, `visualization`, `nn`, `rl`, etc. - ---- - -**Note:** This module should only contain widely shared, fundamental data structures and constants. More complex logic or structures specific to a particular domain (like game rules or rendering details) should remain in their respective modules. -``` - -**22. File:** `alphatriangle/training/README.md` -**Explanation:** Review content and add relative links. - -```markdown -# File: alphatriangle/training/README.md -# Training Module (`alphatriangle.training`) - -## Purpose and Architecture - -This module encapsulates the logic for setting up, running, and managing the reinforcement learning training pipeline. It aims to provide a cleaner separation of concerns compared to embedding all logic within the run scripts or a single orchestrator class. - -- **[`setup.py`](setup.py):** Contains `setup_training_components` which initializes Ray, detects resources, adjusts worker counts, loads configurations, and creates the core components bundle (`TrainingComponents`). 
-- **[`components.py`](components.py):** Defines the `TrainingComponents` dataclass, a simple container to bundle all the necessary initialized objects (NN, Buffer, Trainer, DataManager, StatsCollector, Configs) required by the `TrainingLoop`. -- **[`loop.py`](loop.py):** Defines the `TrainingLoop` class. This class contains the core asynchronous logic of the training loop itself: - - Managing the pool of `SelfPlayWorker` actors via `WorkerManager`. - - Submitting and collecting results from self-play tasks. - - Adding experiences to the `ExperienceBuffer`. - - Triggering training steps on the `Trainer`. - - Updating worker network weights periodically, **passing the current `global_step` to the workers**, and logging a special event (`Internal/Weight_Update_Step`) with the `global_step` to the `StatsCollectorActor` when updates occur. - - Updating progress bars. - - Pushing state updates to the visualizer queue (if provided). - - Handling stop requests. -- **[`worker_manager.py`](worker_manager.py):** Defines the `WorkerManager` class, responsible for creating, managing, submitting tasks to, and collecting results from the `SelfPlayWorker` actors. **It now passes the `global_step` to workers when updating weights.** -- **[`loop_helpers.py`](loop_helpers.py):** Contains helper functions used by `TrainingLoop` for tasks like logging rates, updating the visual queue, validating experiences, and logging results. **It constructs the `StepInfo` dictionary containing relevant step counters (`global_step`, `buffer_size`) for logging.** It also includes logic to log the weight update event. -- **[`runners.py`](runners.py):** Re-exports the main entry point functions (`run_training_headless_mode`, `run_training_visual_mode`) from their respective modules. -- **[`headless_runner.py`](headless_runner.py) / [`visual_runner.py`](visual_runner.py):** Contain the top-level logic for running training in either headless or visual mode. 
They handle argument parsing (via CLI), set up logging, call `setup_training_components`, load initial state, run the `TrainingLoop` (potentially in a separate thread for visual mode), handle visualization setup (visual mode), and manage overall cleanup (MLflow, Ray shutdown).
-- **[`alphatriangle.visualization`](../visualization/README.md)**: `ProgressBar`, `DashboardRenderer`. -- **`ray`**: For parallelism. -- **`mlflow`**: For experiment tracking. -- **`torch`**: For neural network operations. -- **Standard Libraries:** `logging`, `time`, `threading`, `queue`, `os`, `json`, `collections.deque`, `dataclasses`, `sys`, `traceback`, `pathlib`. - ---- - -**Note:** Please keep this README updated when changing the structure of the training pipeline, the responsibilities of the components, or the way components interact, especially regarding the logging of step context information (`StepInfo`) and worker weight updates. \ No newline at end of file diff --git a/alphatriangle/structs/__init__.py b/alphatriangle/structs/__init__.py deleted file mode 100644 index 3ddbbcc..0000000 --- a/alphatriangle/structs/__init__.py +++ /dev/null @@ -1,27 +0,0 @@ -# File: alphatriangle/structs/__init__.py -""" -Module for core data structures used across different parts of the application, -like environment, visualization, and features. Helps avoid circular dependencies. 
-""" - -# Correctly export constants from the constants submodule -from .constants import ( - COLOR_ID_MAP, - COLOR_TO_ID_MAP, - DEBUG_COLOR_ID, - NO_COLOR_ID, - SHAPE_COLORS, -) -from .shape import Shape -from .triangle import Triangle - -__all__ = [ - "Triangle", - "Shape", - # Exported Constants - "SHAPE_COLORS", - "NO_COLOR_ID", - "DEBUG_COLOR_ID", - "COLOR_ID_MAP", - "COLOR_TO_ID_MAP", -] diff --git a/alphatriangle/structs/constants.py b/alphatriangle/structs/constants.py deleted file mode 100644 index 03464be..0000000 --- a/alphatriangle/structs/constants.py +++ /dev/null @@ -1,41 +0,0 @@ -# File: alphatriangle/structs/constants.py - -# Define standard colors used for shapes -# Ensure these colors are distinct and visually clear -# Also ensure BLACK (0,0,0) is NOT used here if it represents empty in color_np -SHAPE_COLORS: list[tuple[int, int, int]] = [ - (220, 40, 40), # 0: Red - (60, 60, 220), # 1: Blue - (40, 200, 40), # 2: Green - (230, 230, 40), # 3: Yellow - (240, 150, 20), # 4: Orange - (140, 40, 140), # 5: Purple - (40, 200, 200), # 6: Cyan - (200, 100, 180), # 7: Pink (Example addition) - (100, 180, 200), # 8: Light Blue (Example addition) -] - -# --- NumPy GridData Color Representation --- -# ID for empty cells in the _color_id_np array -NO_COLOR_ID: int = -1 -# ID for debug-toggled cells -DEBUG_COLOR_ID: int = -2 - -# Mapping from Color ID (int >= 0) to RGB tuple. -# Index 0 corresponds to SHAPE_COLORS[0], etc. -# This list is used by visualization to get the RGB from the ID. -COLOR_ID_MAP: list[tuple[int, int, int]] = SHAPE_COLORS - -# Reverse mapping for efficient lookup during placement (Color Tuple -> ID) -# Note: Ensure SHAPE_COLORS have unique tuples. -COLOR_TO_ID_MAP: dict[tuple[int, int, int], int] = { - color: i for i, color in enumerate(COLOR_ID_MAP) -} - -# Add special colors to the map if needed for rendering debug/other states -# These IDs won't be stored during normal shape placement. 
-# Example: If you want to render the debug color: -# DEBUG_RGB_COLOR = (255, 255, 0) # Example Yellow -# COLOR_ID_MAP.append(DEBUG_RGB_COLOR) # Append if needed elsewhere, but generally lookup handled separately. - -# --- End NumPy GridData Color Representation --- diff --git a/alphatriangle/structs/shape.py b/alphatriangle/structs/shape.py deleted file mode 100644 index e1eb69f..0000000 --- a/alphatriangle/structs/shape.py +++ /dev/null @@ -1,39 +0,0 @@ -from __future__ import annotations - - -class Shape: - """Represents a polyomino-like shape made of triangles.""" - - def __init__( - self, triangles: list[tuple[int, int, bool]], color: tuple[int, int, int] - ): - self.triangles: list[tuple[int, int, bool]] = sorted(triangles) - self.color: tuple[int, int, int] = color - - def bbox(self) -> tuple[int, int, int, int]: - """Calculates bounding box (min_r, min_c, max_r, max_c) in relative coords.""" - if not self.triangles: - return (0, 0, 0, 0) - rows = [t[0] for t in self.triangles] - cols = [t[1] for t in self.triangles] - return (min(rows), min(cols), max(rows), max(cols)) - - def copy(self) -> Shape: - """Creates a shallow copy (triangle list is copied, color is shared).""" - new_shape = Shape.__new__(Shape) - new_shape.triangles = list(self.triangles) - new_shape.color = self.color - return new_shape - - def __str__(self) -> str: - return f"Shape(Color:{self.color}, Tris:{len(self.triangles)})" - - def __eq__(self, other: object) -> bool: - """Checks for equality based on triangles and color.""" - if not isinstance(other, Shape): - return NotImplemented - return self.triangles == other.triangles and self.color == other.color - - def __hash__(self) -> int: - """Allows shapes to be used in sets/dicts if needed.""" - return hash((tuple(self.triangles), self.color)) diff --git a/alphatriangle/structs/triangle.py b/alphatriangle/structs/triangle.py deleted file mode 100644 index 7f0ad32..0000000 --- a/alphatriangle/structs/triangle.py +++ /dev/null @@ -1,48 +0,0 
@@ -from __future__ import annotations - - -class Triangle: - """Represents a single triangular cell on the grid.""" - - def __init__(self, row: int, col: int, is_up: bool, is_death: bool = False): - self.row = row - self.col = col - self.is_up = is_up - self.is_death = is_death - self.is_occupied = is_death - self.color: tuple[int, int, int] | None = None - - self.neighbor_left: Triangle | None = None - self.neighbor_right: Triangle | None = None - self.neighbor_vert: Triangle | None = None - - def get_points( - self, ox: float, oy: float, cw: float, ch: float - ) -> list[tuple[float, float]]: - """Calculates vertex points for drawing, relative to origin (ox, oy).""" - x = ox + self.col * (cw * 0.75) - y = oy + self.row * ch - if self.is_up: - return [(x, y + ch), (x + cw, y + ch), (x + cw / 2, y)] - else: - return [(x, y), (x + cw, y), (x + cw / 2, y + ch)] - - def copy(self) -> Triangle: - """Creates a copy of the Triangle object's state (neighbors are not copied).""" - new_tri = Triangle(self.row, self.col, self.is_up, self.is_death) - new_tri.is_occupied = self.is_occupied - new_tri.color = self.color - return new_tri - - def __repr__(self) -> str: - state = "D" if self.is_death else ("O" if self.is_occupied else ".") - orient = "^" if self.is_up else "v" - return f"T({self.row},{self.col} {orient}{state})" - - def __hash__(self): - return hash((self.row, self.col)) - - def __eq__(self, other): - if not isinstance(other, Triangle): - return NotImplemented - return self.row == other.row and self.col == other.col diff --git a/alphatriangle/training/README.md b/alphatriangle/training/README.md index 255ade0..6a874cb 100644 --- a/alphatriangle/training/README.md +++ b/alphatriangle/training/README.md @@ -1,9 +1,9 @@ -# File: alphatriangle/training/README.md + # Training Module (`alphatriangle.training`) ## Purpose and Architecture -This module encapsulates the logic for setting up, running, and managing the reinforcement learning training pipeline. 
It aims to provide a cleaner separation of concerns compared to embedding all logic within the run scripts or a single orchestrator class. +This module encapsulates the logic for setting up, running, and managing the **headless** reinforcement learning training pipeline. It aims to provide a cleaner separation of concerns compared to embedding all logic within the run scripts or a single orchestrator class. - **`setup.py`:** Contains `setup_training_components` which initializes Ray, detects resources, adjusts worker counts, loads configurations, and creates the core components bundle (`TrainingComponents`). - **`components.py`:** Defines the `TrainingComponents` dataclass, a simple container to bundle all the necessary initialized objects (NN, Buffer, Trainer, DataManager, StatsCollector, Configs) required by the `TrainingLoop`. @@ -12,17 +12,16 @@ This module encapsulates the logic for setting up, running, and managing the rei - Submitting and collecting results from self-play tasks. - Adding experiences to the `ExperienceBuffer`. - Triggering training steps on the `Trainer`. - - Updating worker network weights periodically, **passing the current `global_step` to the workers**, and logging a special event (`Internal/Weight_Update_Step`) with the `global_step` to the `StatsCollectorActor` when updates occur. - - Updating progress bars. - - Pushing state updates to the visualizer queue (if provided). + - Updating worker network weights periodically, **passing the current `global_step` to the workers**, and logging a special event (`Events/Weight_Update`) with the `global_step` to the `StatsCollectorActor` when updates occur. + - Logging progress and rates. - Handling stop requests. - **`worker_manager.py`:** Defines the `WorkerManager` class, responsible for creating, managing, submitting tasks to, and collecting results from the `SelfPlayWorker` actors. 
**It now passes the `global_step` to workers when updating weights.** -- **`loop_helpers.py`:** Contains helper functions used by `TrainingLoop` for tasks like logging rates, updating the visual queue, validating experiences, and logging results. **It constructs the `StepInfo` dictionary containing relevant step counters (`global_step`, `buffer_size`) for logging.** It also includes logic to log the weight update event. -- **`runners.py`:** Re-exports the main entry point functions (`run_training_headless_mode`, `run_training_visual_mode`) from their respective modules. -- **`headless_runner.py` / `visual_runner.py`:** Contain the top-level logic for running training in either headless or visual mode. They handle argument parsing (via CLI), setup logging, call `setup_training_components`, load initial state, run the `TrainingLoop` (potentially in a separate thread for visual mode), handle visualization setup (visual mode), and manage overall cleanup (MLflow, Ray shutdown). +- **`loop_helpers.py`:** Contains helper functions used by `TrainingLoop` for tasks like logging rates, validating experiences, and logging results to the `StatsCollectorActor` and TensorBoard. **It constructs the `StepInfo` dictionary containing relevant step counters (`global_step`, `buffer_size`) for logging.** It also includes logic to log the weight update event. +- **`runner.py`:** Contains the top-level logic for running the headless training pipeline. It handles argument parsing (via CLI), setup logging, calls `setup_training_components`, loads initial state, runs the `TrainingLoop`, and manages overall cleanup (MLflow, TensorBoard, Ray shutdown). +- **`runners.py`:** Re-exports the main entry point function (`run_training`) from `runner.py`. - **`logging_utils.py`:** Contains helper functions for setting up file logging, redirecting output (`Tee` class), and logging configurations/metrics to MLflow. 
-This structure separates the high-level setup/teardown (`headless_runner`/`visual_runner`) from the core iterative logic (`loop`), making the system more modular and potentially easier to test or modify. +This structure separates the high-level setup/teardown (`runner`) from the core iterative logic (`loop`), making the system more modular and potentially easier to test or modify. ## Exposed Interfaces @@ -32,8 +31,7 @@ This structure separates the high-level setup/teardown (`headless_runner`/`visua - `WorkerManager`: Manages worker actors. - `LoopHelpers`: Provides helper functions for the loop. - **Functions (from `runners.py`):** - - `run_training_headless_mode(...) -> int` - - `run_training_visual_mode(...) -> int` + - `run_training(...) -> int` - **Functions (from `setup.py`):** - `setup_training_components(...) -> Tuple[Optional[TrainingComponents], bool]` - **Functions (from `logging_utils.py`):** @@ -48,15 +46,15 @@ This structure separates the high-level setup/teardown (`headless_runner`/`visua - **`alphatriangle.nn`**: `NeuralNetwork`. - **`alphatriangle.rl`**: `ExperienceBuffer`, `Trainer`, `SelfPlayWorker`, `SelfPlayResult`. - **`alphatriangle.data`**: `DataManager`, `LoadedTrainingState`. -- **`alphatriangle.stats`**: `StatsCollectorActor`, `PlotDefinitions`. -- **`alphatriangle.environment`**: `GameState`. +- **`alphatriangle.stats`**: `StatsCollectorActor`. +- **`trianglengin`**: `GameState`, `EnvConfig`. - **`alphatriangle.utils`**: Helper functions and types (including `StepInfo`). -- **`alphatriangle.visualization`**: `ProgressBar`, `DashboardRenderer`. - **`ray`**: For parallelism. - **`mlflow`**: For experiment tracking. - **`torch`**: For neural network operations. +- **`torch.utils.tensorboard`**: For TensorBoard logging. - **Standard Libraries:** `logging`, `time`, `threading`, `queue`, `os`, `json`, `collections.deque`, `dataclasses`. 
--- -**Note:** Please keep this README updated when changing the structure of the training pipeline, the responsibilities of the components, or the way components interact, especially regarding the logging of step context information (`StepInfo`) and worker weight updates. \ No newline at end of file +**Note:** Please keep this README updated when changing the structure of the training pipeline, the responsibilities of the components, or the way components interact, especially regarding the logging of step context information (`StepInfo`) and worker weight updates. Accurate documentation is crucial for maintainability. \ No newline at end of file diff --git a/alphatriangle/training/__init__.py b/alphatriangle/training/__init__.py index 60dc2d8..4d56302 100644 --- a/alphatriangle/training/__init__.py +++ b/alphatriangle/training/__init__.py @@ -1,10 +1,3 @@ -# File: alphatriangle/training/__init__.py -""" -Training module containing the pipeline, loop, components, and utilities -for orchestrating the reinforcement learning training process. 
-""" - -# Core components & classes from .components import TrainingComponents # Utilities @@ -13,13 +6,10 @@ from .loop_helpers import LoopHelpers # Re-export runner functions -from .runners import ( - run_training_headless_mode, - run_training_visual_mode, -) -from .setup import setup_training_components +from .runner import run_training # Import the single runner -# from .pipeline import TrainingPipeline # REMOVED +# REMOVE visual runner import +from .setup import setup_training_components from .worker_manager import WorkerManager # Explicitly define __all__ @@ -27,14 +17,12 @@ # Core Components "TrainingComponents", "TrainingLoop", - # "TrainingPipeline", # REMOVED # Helpers & Managers "WorkerManager", "LoopHelpers", "setup_training_components", # Runners (re-exported) - "run_training_headless_mode", - "run_training_visual_mode", + "run_training", # Export single runner # Logging Utilities "setup_file_logging", "get_root_logger", diff --git a/alphatriangle/training/components.py b/alphatriangle/training/components.py index 385b81b..c86b063 100644 --- a/alphatriangle/training/components.py +++ b/alphatriangle/training/components.py @@ -1,18 +1,15 @@ -# File: alphatriangle/training/components.py from dataclasses import dataclass from typing import TYPE_CHECKING -# --- ADDED: Import ActorHandle --- import ray -# --- END ADDED --- +# Import EnvConfig from trianglengin +from trianglengin.config import EnvConfig -if TYPE_CHECKING: - # Keep ray import here as well for consistency if needed elsewhere - # import ray +# Keep alphatriangle imports +if TYPE_CHECKING: from alphatriangle.config import ( - EnvConfig, MCTSConfig, ModelConfig, PersistenceConfig, @@ -22,9 +19,7 @@ from alphatriangle.nn import NeuralNetwork from alphatriangle.rl import ExperienceBuffer, Trainer - # --- REMOVED: Import StatsCollectorActor class --- - # from alphatriangle.stats import StatsCollectorActor - # --- END REMOVED --- + pass # No changes needed here @dataclass @@ -35,11 +30,10 @@ class 
TrainingComponents: buffer: "ExperienceBuffer" trainer: "Trainer" data_manager: "DataManager" - # --- CORRECTED: Use ActorHandle type hint --- - stats_collector_actor: ray.actor.ActorHandle | None - # --- END CORRECTED --- + stats_collector_actor: ray.actor.ActorHandle | None # Keep type hint train_config: "TrainConfig" - env_config: "EnvConfig" + env_config: EnvConfig # Use trianglengin.EnvConfig model_config: "ModelConfig" mcts_config: "MCTSConfig" persist_config: "PersistenceConfig" + # REMOVE visual_state_actor field diff --git a/alphatriangle/training/loop.py b/alphatriangle/training/loop.py index b98fd88..830c804 100644 --- a/alphatriangle/training/loop.py +++ b/alphatriangle/training/loop.py @@ -1,34 +1,21 @@ -# File: alphatriangle/training/loop.py import logging -import queue import threading import time from typing import TYPE_CHECKING, Any -# --- MOVED: numpy import --- -# import numpy as np -# --- END MOVED --- from ..rl import SelfPlayResult -# --- MOVED: ProgressBar import --- -# from ..visualization.ui import ProgressBar -# --- END MOVED --- -# --- MOVED: TrainingComponents import --- -# from .components import TrainingComponents -# --- END MOVED --- +# REMOVE ProgressBar import from .loop_helpers import LoopHelpers from .worker_manager import WorkerManager if TYPE_CHECKING: - # --- ADDED: Imports under TYPE_CHECKING --- import numpy as np from ..utils.types import PERBatchSample - from ..visualization.ui import ProgressBar - from .components import TrainingComponents - - # --- END ADDED --- + # REMOVE ProgressBar import + from .components import TrainingComponents logger = logging.getLogger(__name__) @@ -36,47 +23,42 @@ class TrainingLoop: """ Manages the core asynchronous training loop logic: coordinating worker tasks, - processing results, triggering training steps, and updating visual queue. - Receives initialized components via TrainingComponents. Runs indefinitely - until stop_requested is set. Uses WorkerManager and LoopHelpers. 
+ processing results, triggering training steps. Runs headless. """ def __init__( self, components: "TrainingComponents", - visual_state_queue: queue.Queue[dict[int, Any] | None] | None = None, + # REMOVE visual_state_queue parameter ): self.components = components - self.visual_state_queue = visual_state_queue + # REMOVE self.visual_state_queue = visual_state_queue self.train_config = components.train_config - - # Core components self.buffer = components.buffer self.trainer = components.trainer - # State variables self.global_step = 0 self.episodes_played = 0 self.total_simulations_run = 0 - self.worker_weight_updates_count = 0 # Counter for worker updates + self.worker_weight_updates_count = 0 self.start_time = time.time() self.stop_requested = threading.Event() self.training_complete = False self.training_exception: Exception | None = None - # Progress Bars (initialized later) - self.train_step_progress: ProgressBar | None = None - self.buffer_fill_progress: ProgressBar | None = None + # REMOVE Progress Bars + # self.train_step_progress: ProgressBar | None = None + # self.buffer_fill_progress: ProgressBar | None = None - # Instantiate helpers self.worker_manager = WorkerManager(components) + # Pass None for visual_state_queue self.loop_helpers = LoopHelpers( components, - self.visual_state_queue, - self._get_loop_state, # Pass method to get current state + None, # Pass None for visual_state_queue + self._get_loop_state, ) - logger.info("TrainingLoop initialized.") + logger.info("TrainingLoop initialized (Headless).") def _get_loop_state(self) -> dict[str, Any]: """Provides current loop state to helpers.""" @@ -89,8 +71,9 @@ def _get_loop_state(self) -> dict[str, Any]: "buffer_capacity": self.buffer.capacity, "num_active_workers": self.worker_manager.get_num_active_workers(), "num_pending_tasks": self.worker_manager.get_num_pending_tasks(), - "train_progress": self.train_step_progress, - "buffer_progress": self.buffer_fill_progress, + # REMOVE progress bars from 
state + # "train_progress": self.train_step_progress, + # "buffer_progress": self.buffer_fill_progress, "start_time": self.start_time, "num_workers": self.train_config.NUM_SELF_PLAY_WORKERS, } @@ -102,15 +85,15 @@ def set_initial_state( self.global_step = global_step self.episodes_played = episodes_played self.total_simulations_run = total_simulations - # Estimate initial weight updates based on loaded step and frequency self.worker_weight_updates_count = ( global_step // self.train_config.WORKER_UPDATE_FREQ_STEPS ) - self.train_step_progress, self.buffer_fill_progress = ( - self.loop_helpers.initialize_progress_bars( - global_step, len(self.buffer), self.start_time - ) - ) + # REMOVE progress bar initialization + # self.train_step_progress, self.buffer_fill_progress = ( + # self.loop_helpers.initialize_progress_bars( + # global_step, len(self.buffer), self.start_time + # ) + # ) self.loop_helpers.reset_rate_counters( global_step, episodes_played, total_simulations ) @@ -153,10 +136,11 @@ def _process_self_play_result(self, result: SelfPlayResult, worker_id: int): f"Error adding batch to buffer from worker {worker_id}: {e}", exc_info=True, ) - return # Don't update counters if add failed + return - if self.buffer_fill_progress: - self.buffer_fill_progress.set_current_steps(len(self.buffer)) + # REMOVE progress bar update + # if self.buffer_fill_progress: + # self.buffer_fill_progress.set_current_steps(len(self.buffer)) self.episodes_played += 1 self.total_simulations_run += result.total_simulations else: @@ -180,29 +164,26 @@ def _run_training_step(self) -> bool: if train_result: loss_info, td_errors = train_result self.global_step += 1 - if self.train_step_progress: - self.train_step_progress.set_current_steps(self.global_step) + # REMOVE progress bar update + # if self.train_step_progress: + # self.train_step_progress.set_current_steps(self.global_step) if self.train_config.USE_PER: self.buffer.update_priorities(per_sample["indices"], td_errors) 
self.loop_helpers.log_training_results_async( loss_info, self.global_step, self.total_simulations_run ) - # Check if it's time to update worker networks if self.global_step % self.train_config.WORKER_UPDATE_FREQ_STEPS == 0: try: - # --- CHANGED: Pass global_step --- self.worker_manager.update_worker_networks(self.global_step) - # --- END CHANGED --- - self.worker_weight_updates_count += 1 # Increment counter - # Log the update event using the helper + self.worker_weight_updates_count += 1 self.loop_helpers.log_weight_update_event(self.global_step) except Exception as update_err: logger.error( f"Failed to update worker networks at step {self.global_step}: {update_err}" ) - if self.global_step % 50 == 0: + if self.global_step % 50 == 0: # Keep periodic logging logger.info( f"Step {self.global_step}: P Loss={loss_info['policy_loss']:.4f}, V Loss={loss_info['value_loss']:.4f}, Ent={loss_info['entropy']:.4f}, TD Err={loss_info['mean_td_error']:.4f}" ) @@ -222,11 +203,9 @@ def run(self): self.start_time = time.time() try: - # Initial task submission self.worker_manager.submit_initial_tasks() while not self.stop_requested.is_set(): - # Check if max steps reached if ( self.train_config.MAX_TRAINING_STEPS is not None and self.global_step >= self.train_config.MAX_TRAINING_STEPS @@ -238,16 +217,14 @@ def run(self): self.request_stop() break - # Training Step if self.buffer.is_ready(): - _ = self._run_training_step() # Call training step + _ = self._run_training_step() else: - time.sleep(0.01) + time.sleep(0.01) # Short sleep if not training if self.stop_requested.is_set(): break - # Handle Completed Worker Tasks wait_timeout = 0.1 if self.buffer.is_ready() else 0.5 completed_tasks = self.worker_manager.get_completed_tasks(wait_timeout) @@ -274,9 +251,9 @@ def run(self): if self.stop_requested.is_set(): break - # Periodic Tasks (using LoopHelpers) - self.loop_helpers.update_visual_queue() - self.loop_helpers.log_progress_eta() + # REMOVE visual queue update + # 
self.loop_helpers.update_visual_queue() + self.loop_helpers.log_progress_eta() # Keep ETA logging self.loop_helpers.calculate_and_log_rates() if not completed_tasks and not self.buffer.is_ready(): diff --git a/alphatriangle/training/loop_helpers.py b/alphatriangle/training/loop_helpers.py index 75f3d94..deb89ab 100644 --- a/alphatriangle/training/loop_helpers.py +++ b/alphatriangle/training/loop_helpers.py @@ -1,6 +1,5 @@ -# File: alphatriangle/training/loop_helpers.py import logging -import queue +import queue # Keep for type hint check import time from collections.abc import Callable from typing import TYPE_CHECKING, Any @@ -8,22 +7,29 @@ import numpy as np import ray -from ..environment import GameState -from ..stats.plot_definitions import WEIGHT_UPDATE_METRIC_KEY +# --- ADD TensorBoard --- +from torch.utils.tensorboard import SummaryWriter + +# --- END ADD --- +# Import GameState from trianglengin +# Keep alphatriangle imports +# REMOVED WEIGHT_UPDATE_METRIC_KEY import (no longer plotting) from ..utils import format_eta from ..utils.types import Experience, StatsCollectorData, StepInfo -from ..visualization.core import colors -from ..visualization.ui import ProgressBar + +# REMOVE ProgressBar import +# REMOVE colors import if TYPE_CHECKING: from .components import TrainingComponents logger = logging.getLogger(__name__) -VISUAL_UPDATE_INTERVAL = 0.2 +# REMOVE VISUAL_UPDATE_INTERVAL STATS_FETCH_INTERVAL = 0.5 -VIS_STATE_FETCH_TIMEOUT = 0.1 -RATE_CALCULATION_INTERVAL = 5.0 # Check rates every 5 seconds +# REMOVE VIS_STATE_FETCH_TIMEOUT +RATE_CALCULATION_INTERVAL = 5.0 +WEIGHT_UPDATE_EVENT_KEY = "Events/Weight_Update" # Define key for logging class LoopHelpers: @@ -32,18 +38,21 @@ class LoopHelpers: def __init__( self, components: "TrainingComponents", - visual_state_queue: queue.Queue[dict[int, Any] | None] | None, + # REMOVE visual_state_queue parameter + _visual_state_queue: ( + queue.Queue[dict[int, Any] | None] | None + ), # Keep param name but mark unused 
get_loop_state_func: Callable[[], dict[str, Any]], ): self.components = components - self.visual_state_queue = visual_state_queue - self.get_loop_state = get_loop_state_func # Function to get current loop state + # REMOVE self.visual_state_queue = visual_state_queue + self.get_loop_state = get_loop_state_func self.stats_collector_actor = components.stats_collector_actor self.train_config = components.train_config - self.trainer = components.trainer # Needed for LR + self.trainer = components.trainer - self.last_visual_update_time = 0.0 + # REMOVE self.last_visual_update_time = 0.0 self.last_stats_fetch_time = 0.0 self.latest_stats_data: StatsCollectorData = {} @@ -52,6 +61,16 @@ def __init__( self.last_rate_calc_episodes = 0 self.last_rate_calc_sims = 0 + # --- ADD TensorBoard Writer --- + self.tb_writer: SummaryWriter | None = None + # --- END ADD --- + + # --- ADD Method to set writer --- + def set_tensorboard_writer(self, writer: SummaryWriter): + self.tb_writer = writer + + # --- END ADD --- + def reset_rate_counters( self, global_step: int, episodes_played: int, total_simulations: int ): @@ -61,30 +80,7 @@ def reset_rate_counters( self.last_rate_calc_episodes = episodes_played self.last_rate_calc_sims = total_simulations - def initialize_progress_bars( - self, global_step: int, buffer_size: int, start_time: float - ) -> tuple[ProgressBar, ProgressBar]: - """Initializes and returns progress bars.""" - train_total_steps = self.train_config.MAX_TRAINING_STEPS - train_total_steps_for_bar = ( - train_total_steps if train_total_steps is not None else 1 - ) - - train_step_progress = ProgressBar( - "Training Steps", - total_steps=train_total_steps_for_bar, - start_time=start_time, - initial_steps=global_step, - initial_color=colors.GREEN, - ) - buffer_fill_progress = ProgressBar( - "Buffer Fill", - self.train_config.BUFFER_CAPACITY, - start_time=start_time, - initial_steps=buffer_size, - initial_color=colors.ORANGE, - ) - return train_step_progress, 
buffer_fill_progress + # REMOVE initialize_progress_bars method def _fetch_latest_stats(self): """Fetches the latest stats data from the actor.""" @@ -102,8 +98,7 @@ def _fetch_latest_stats(self): def calculate_and_log_rates(self): """ Calculates and logs steps/sec, episodes/sec, sims/sec, and buffer size. - Ensures buffer-related rates are logged against buffer size. - Logs metrics with StepInfo containing global_step and buffer_size. + Logs to StatsCollectorActor and TensorBoard. """ current_time = time.time() time_delta = current_time - self.last_rate_calc_time @@ -114,7 +109,7 @@ def calculate_and_log_rates(self): global_step = loop_state["global_step"] episodes_played = loop_state["episodes_played"] total_simulations = loop_state["total_simulations_run"] - current_buffer_size = int(loop_state["buffer_size"]) # Use int for step info + current_buffer_size = int(loop_state["buffer_size"]) steps_delta = global_step - self.last_rate_calc_step episodes_delta = episodes_played - self.last_rate_calc_episodes @@ -124,6 +119,7 @@ def calculate_and_log_rates(self): episodes_per_sec = episodes_delta / time_delta if time_delta > 0 else 0.0 sims_per_sec = sims_delta / time_delta if time_delta > 0 else 0.0 + # Log to StatsCollectorActor if self.stats_collector_actor: step_info_buffer: StepInfo = { "global_step": global_step, @@ -136,49 +132,76 @@ def calculate_and_log_rates(self): "Rate/Simulations_Per_Sec": (sims_per_sec, step_info_buffer), "Buffer/Size": (float(current_buffer_size), step_info_buffer), } - log_msg_steps = "Steps/s=N/A" if steps_delta > 0: rate_stats["Rate/Steps_Per_Sec"] = (steps_per_sec, step_info_global) - log_msg_steps = f"Steps/s={steps_per_sec:.2f}" try: self.stats_collector_actor.log_batch.remote(rate_stats) # type: ignore - logger.debug( - f"Logged rates/buffer at step {global_step} / buffer {current_buffer_size}: " - f"{log_msg_steps}, Eps/s={episodes_per_sec:.2f}, Sims/s={sims_per_sec:.1f}, " - f"Buffer={current_buffer_size}" - ) except Exception 
as e: logger.error(f"Failed to log rate/buffer stats to collector: {e}") + # --- ADD Log to TensorBoard --- + if self.tb_writer: + try: + self.tb_writer.add_scalar( + "Rates/Episodes_Per_Sec", episodes_per_sec, global_step + ) + self.tb_writer.add_scalar( + "Rates/Simulations_Per_Sec", sims_per_sec, global_step + ) + self.tb_writer.add_scalar( + "Buffer/Size", float(current_buffer_size), global_step + ) + if steps_delta > 0: + self.tb_writer.add_scalar( + "Rates/Steps_Per_Sec", steps_per_sec, global_step + ) + except Exception as tb_err: + logger.error(f"Failed to log rates to TensorBoard: {tb_err}") + # --- END ADD --- + + log_msg_steps = ( + f"Steps/s={steps_per_sec:.2f}" if steps_delta > 0 else "Steps/s=N/A" + ) + logger.debug( + f"Logged rates/buffer at step {global_step} / buffer {current_buffer_size}: " + f"{log_msg_steps}, Eps/s={episodes_per_sec:.2f}, Sims/s={sims_per_sec:.1f}, " + f"Buffer={current_buffer_size}" + ) + self.reset_rate_counters(global_step, episodes_played, total_simulations) def log_progress_eta(self): """Logs progress and ETA.""" loop_state = self.get_loop_state() global_step = loop_state["global_step"] - train_progress = loop_state["train_progress"] + # REMOVE train_progress bar usage - if global_step == 0 or global_step % 100 != 0 or not train_progress: + if global_step == 0 or global_step % 100 != 0: return elapsed_time = time.time() - loop_state["start_time"] - steps_since_load = global_step - train_progress.initial_steps - steps_per_sec = 0.0 - self._fetch_latest_stats() # Fetch stats to get latest rate + # Calculate steps_since_load based on last_rate_calc_step if needed, or just use global_step + steps_since_start = global_step # Assuming we want overall ETA + steps_per_sec = 0.0 + self._fetch_latest_stats() rate_dq = self.latest_stats_data.get("Rate/Steps_Per_Sec") if rate_dq: - # Get the value from the last tuple (step_info, value) steps_per_sec = rate_dq[-1][1] - elif elapsed_time > 1 and steps_since_load > 0: - # Fallback 
calculation if rate not in stats yet - steps_per_sec = steps_since_load / elapsed_time + elif elapsed_time > 1 and steps_since_start > 0: + steps_per_sec = steps_since_start / elapsed_time target_steps = self.train_config.MAX_TRAINING_STEPS target_steps_str = f"{target_steps:,}" if target_steps else "Infinite" progress_str = f"Step {global_step:,}/{target_steps_str}" - eta_str = format_eta(train_progress.get_eta_seconds()) + + eta_str = "--" + if target_steps and steps_per_sec > 1e-6: + remaining_steps = target_steps - global_step + if remaining_steps > 0: + eta_seconds = remaining_steps / steps_per_sec + eta_str = format_eta(eta_seconds) buffer_fill_perc = ( (loop_state["buffer_size"] / loop_state["buffer_capacity"]) * 100 @@ -198,68 +221,7 @@ def log_progress_eta(self): f"Pending Tasks: {num_pending_tasks}, Speed: {steps_per_sec:.2f} steps/sec, ETA: {eta_str}" ) - def update_visual_queue(self): - """Fetches latest states/stats and puts them onto the visual queue.""" - if not self.visual_state_queue or not self.stats_collector_actor: - return - current_time = time.time() - if current_time - self.last_visual_update_time < VISUAL_UPDATE_INTERVAL: - return - self.last_visual_update_time = current_time - - latest_worker_states: dict[int, GameState] = {} - try: - states_ref = self.stats_collector_actor.get_latest_worker_states.remote() # type: ignore - latest_worker_states = ray.get(states_ref, timeout=VIS_STATE_FETCH_TIMEOUT) - if not isinstance(latest_worker_states, dict): - logger.warning( - f"StatsCollectorActor returned invalid type for states: {type(latest_worker_states)}" - ) - latest_worker_states = {} - except Exception as e: - logger.warning( - f"Failed to fetch latest worker states for visualization: {e}" - ) - latest_worker_states = {} - - self._fetch_latest_stats() # Fetch latest metric data - - visual_data: dict[int, Any] = {} - for worker_id, state in latest_worker_states.items(): - if isinstance(state, GameState): - visual_data[worker_id] = state - 
else: - logger.warning( - f"Received invalid state type for worker {worker_id} from collector: {type(state)}" - ) - - visual_data[-1] = { - **self.get_loop_state(), - "stats_data": self.latest_stats_data, - } - - if not visual_data or len(visual_data) == 1: - logger.debug( - "No worker states available from collector to send to visual queue." - ) - return - - worker_keys = [k for k in visual_data if k != -1] - logger.debug( - f"Putting visual data on queue. Worker IDs with states: {worker_keys}" - ) - - try: - while self.visual_state_queue.qsize() > 2: - try: - self.visual_state_queue.get_nowait() - except queue.Empty: - break - self.visual_state_queue.put_nowait(visual_data) - except queue.Full: - logger.warning("Visual state queue full, dropping state dictionary.") - except Exception as qe: - logger.error(f"Error putting state dict in visual queue: {qe}") + # REMOVE update_visual_queue method def validate_experiences( self, experiences: list[Experience] @@ -314,21 +276,19 @@ def log_training_results_async( self, loss_info: dict[str, float], global_step: int, total_simulations: int ) -> None: """ - Logs training results asynchronously to the StatsCollectorActor. - Logs metrics with StepInfo containing global_step. + Logs training results asynchronously to StatsCollectorActor and TensorBoard. 
""" current_lr = self.trainer.get_current_lr() - loop_state = self.get_loop_state() - train_progress = loop_state.get("train_progress") + self.get_loop_state() + # REMOVE train_progress usage buffer = self.components.buffer - train_step_perc = ( - (train_progress.get_progress_fraction() * 100) if train_progress else 0.0 - ) + # REMOVE train_step_perc calculation based on progress bar per_beta = ( buffer._calculate_beta(global_step) if self.train_config.USE_PER else None ) + # Log to StatsCollectorActor if self.stats_collector_actor: step_info: StepInfo = {"global_step": global_step} stats_batch: dict[str, tuple[float, StepInfo]] = { @@ -338,7 +298,7 @@ def log_training_results_async( "Loss/Entropy": (loss_info["entropy"], step_info), "Loss/Mean_TD_Error": (loss_info["mean_td_error"], step_info), "LearningRate": (current_lr, step_info), - "Progress/Train_Step_Percent": (train_step_perc, step_info), + # REMOVE Progress/Train_Step_Percent "Progress/Total_Simulations": (float(total_simulations), step_info), } if per_beta is not None: @@ -351,14 +311,51 @@ def log_training_results_async( except Exception as e: logger.error(f"Failed to log batch to StatsCollectorActor: {e}") + # --- ADD Log to TensorBoard --- + if self.tb_writer: + try: + self.tb_writer.add_scalar( + "Loss/Total", loss_info["total_loss"], global_step + ) + self.tb_writer.add_scalar( + "Loss/Policy", loss_info["policy_loss"], global_step + ) + self.tb_writer.add_scalar( + "Loss/Value", loss_info["value_loss"], global_step + ) + self.tb_writer.add_scalar( + "Loss/Entropy", loss_info["entropy"], global_step + ) + self.tb_writer.add_scalar( + "Loss/Mean_TD_Error", loss_info["mean_td_error"], global_step + ) + self.tb_writer.add_scalar("LearningRate", current_lr, global_step) + self.tb_writer.add_scalar( + "Progress/Total_Simulations", float(total_simulations), global_step + ) + if per_beta is not None: + self.tb_writer.add_scalar("PER/Beta", per_beta, global_step) + except Exception as tb_err: + 
logger.error(f"Failed to log training results to TensorBoard: {tb_err}") + # --- END ADD --- + def log_weight_update_event(self, global_step: int) -> None: """Logs the event of a worker weight update with StepInfo.""" if self.stats_collector_actor: try: - # Log with value 1.0 at the current global step step_info: StepInfo = {"global_step": global_step} - update_metric = {WEIGHT_UPDATE_METRIC_KEY: (1.0, step_info)} + # Use the defined key for the event + update_metric = {WEIGHT_UPDATE_EVENT_KEY: (1.0, step_info)} self.stats_collector_actor.log_batch.remote(update_metric) # type: ignore logger.info(f"Logged worker weight update event at step {global_step}.") except Exception as e: logger.error(f"Failed to log weight update event: {e}") + # --- ADD Log to TensorBoard --- + if self.tb_writer: + try: + # Log as a scalar event (value 1 indicates update occurred) + self.tb_writer.add_scalar(WEIGHT_UPDATE_EVENT_KEY, 1.0, global_step) + except Exception as tb_err: + logger.error( + f"Failed to log weight update event to TensorBoard: {tb_err}" + ) diff --git a/alphatriangle/training/headless_runner.py b/alphatriangle/training/runner.py similarity index 72% rename from alphatriangle/training/headless_runner.py rename to alphatriangle/training/runner.py index 8d86798..e387f46 100644 --- a/alphatriangle/training/headless_runner.py +++ b/alphatriangle/training/runner.py @@ -1,14 +1,17 @@ -# File: alphatriangle/training/headless_runner.py +# File: alphatriangle/training/runner.py import logging import sys import traceback -from collections import deque from pathlib import Path import mlflow import ray import torch +# --- ADD TensorBoard --- +from torch.utils.tensorboard import SummaryWriter + +# --- END ADD --- from ..config import APP_NAME, PersistenceConfig, TrainConfig from ..utils.sumtree import SumTree from .components import TrainingComponents @@ -17,7 +20,7 @@ log_configs_to_mlflow, setup_file_logging, ) -from .loop import TrainingLoop # Import TrainingLoop here +from 
.loop import TrainingLoop from .setup import count_parameters, setup_training_components logger = logging.getLogger(__name__) @@ -50,6 +53,7 @@ def _initialize_mlflow(persist_config: PersistenceConfig, run_name: str) -> bool def _load_and_apply_initial_state(components: TrainingComponents) -> TrainingLoop: """Loads initial state using DataManager and applies it to components, returning an initialized TrainingLoop.""" loaded_state = components.data_manager.load_initial_state() + # Pass None for visual queue training_loop = TrainingLoop(components) # Instantiate loop first if loaded_state.checkpoint_data: @@ -113,13 +117,16 @@ def _load_and_apply_initial_state(components: TrainingComponents) -> TrainingLoo components.buffer.tree.add(max_p, exp) logger.info(f"PER buffer loaded. Size: {len(components.buffer)}") else: + from collections import deque # Import locally if needed + components.buffer.buffer = deque( loaded_state.buffer_data.buffer_list, maxlen=components.buffer.capacity, ) logger.info(f"Uniform buffer loaded. 
Size: {len(components.buffer)}") - if training_loop.buffer_fill_progress: - training_loop.buffer_fill_progress.set_current_steps(len(components.buffer)) + # REMOVED progress bar update + # if training_loop.buffer_fill_progress: + # training_loop.buffer_fill_progress.set_current_steps(len(components.buffer)) else: logger.info("No buffer data loaded.") @@ -150,12 +157,12 @@ def _save_final_state(training_loop: TrainingLoop): logger.error(f"Failed to save final training state: {e_save}", exc_info=True) -def run_training_headless_mode( +def run_training( log_level_str: str, train_config_override: TrainConfig, persist_config_override: PersistenceConfig, ) -> int: - """Runs the training pipeline in headless mode.""" + """Runs the training pipeline (headless).""" training_loop: TrainingLoop | None = None components: TrainingComponents | None = None exit_code = 1 @@ -163,11 +170,15 @@ def run_training_headless_mode( file_handler = None ray_initialized_by_setup = False mlflow_run_active = False + tb_writer: SummaryWriter | None = None # ADDED + tb_log_dir: str | None = None # ADDED try: # --- Setup File Logging --- log_file_path = setup_file_logging( - persist_config_override, train_config_override.RUN_NAME, "headless" + persist_config_override, + train_config_override.RUN_NAME, + "train", # Changed suffix ) log_level = logging.getLevelName(log_level_str.upper()) logger.info( @@ -194,11 +205,46 @@ def run_training_headless_mode( ) mlflow.log_param("model_total_params", total_params) mlflow.log_param("model_trainable_params", trainable_params) + # Log log file path + if log_file_path: + try: + mlflow.log_artifact(log_file_path, artifact_path="logs") + except Exception as log_artifact_err: + logger.error( + f"Failed to log log file to MLflow: {log_artifact_err}" + ) else: logger.warning("MLflow initialization failed, proceeding without MLflow.") + # --- Initialize TensorBoard --- ADDED --- + tb_log_dir = components.persist_config.get_tensorboard_log_dir() + if tb_log_dir: + 
Path(tb_log_dir).mkdir(parents=True, exist_ok=True) + tb_writer = SummaryWriter(log_dir=tb_log_dir) + logger.info(f"TensorBoard logging initialized. Log directory: {tb_log_dir}") + if mlflow_run_active: + try: + # Log TB dir relative path within MLflow run + relative_tb_path = Path(tb_log_dir).relative_to( + Path(components.persist_config.get_run_base_dir()).parent + ) + mlflow.log_param("tensorboard_log_dir", str(relative_tb_path)) + except Exception as tb_param_err: + logger.error( + f"Failed to log TensorBoard path to MLflow: {tb_param_err}" + ) + else: + logger.warning( + "Could not determine TensorBoard log directory. Skipping TensorBoard." + ) + # --- END ADDED --- + # --- Load State & Initialize Loop --- training_loop = _load_and_apply_initial_state(components) + # --- Pass TensorBoard writer to loop helpers --- ADDED --- + if tb_writer: + training_loop.loop_helpers.set_tensorboard_writer(tb_writer) + # --- END ADDED --- # --- Run Training Loop --- training_loop.initialize_workers() @@ -214,7 +260,7 @@ def run_training_headless_mode( except Exception as e: logger.critical( - f"An unhandled error occurred during headless training setup or execution: {e}" + f"An unhandled error occurred during training setup or execution: {e}" ) traceback.print_exc() # Attempt to log failure status if MLflow run was started @@ -246,6 +292,36 @@ def run_training_headless_mode( else: final_status = "SETUP_FAILED" + # --- Close TensorBoard Writer --- ADDED --- + if tb_writer: + try: + tb_writer.flush() + tb_writer.close() + logger.info("TensorBoard writer closed.") + # Log TB dir as artifact AFTER closing + # Check if components is not None before accessing persist_config + if ( + mlflow_run_active + and tb_log_dir + and Path(tb_log_dir).exists() + and components is not None + ): + try: + mlflow.log_artifacts( + tb_log_dir, + artifact_path=components.persist_config.TENSORBOARD_DIR_NAME, + ) + logger.info( + f"Logged TensorBoard directory to MLflow artifacts: 
{components.persist_config.TENSORBOARD_DIR_NAME}" + ) + except Exception as tb_artifact_err: + logger.error( + f"Failed to log TensorBoard directory to MLflow: {tb_artifact_err}" + ) + except Exception as tb_close_err: + logger.error(f"Error closing TensorBoard writer: {tb_close_err}") + # --- END ADDED --- + # End MLflow run if mlflow_run_active: try: @@ -261,7 +337,7 @@ def run_training_headless_mode( if ray_initialized_by_setup and ray.is_initialized(): try: ray.shutdown() - logger.info("Ray shut down by headless runner.") + logger.info("Ray shut down by runner.") except Exception as e: logger.error(f"Error shutting down Ray: {e}", exc_info=True) @@ -279,5 +355,5 @@ def run_training_headless_mode( except Exception as e_close: print(f"Error closing log file handler: {e_close}", file=sys.__stderr__) - logger.info(f"Headless training finished with exit code {exit_code}.") + logger.info(f"Training finished with exit code {exit_code}.") return exit_code diff --git a/alphatriangle/training/runners.py b/alphatriangle/training/runners.py index 250674b..b9dc65e 100644 --- a/alphatriangle/training/runners.py +++ b/alphatriangle/training/runners.py @@ -4,10 +4,9 @@ Imports functions from specific runner modules. """ -from .headless_runner import run_training_headless_mode -from .visual_runner import run_training_visual_mode +# Import from the renamed runner +from .runner import run_training # Rename export __all__ = [ - "run_training_headless_mode", - "run_training_visual_mode", + "run_training", # Export the single runner function ] diff --git a/alphatriangle/training/setup.py b/alphatriangle/training/setup.py index 8db991b..72b8626 100644 --- a/alphatriangle/training/setup.py +++ b/alphatriangle/training/setup.py @@ -1,10 +1,13 @@ -# File: alphatriangle/training/setup.py import logging from typing import TYPE_CHECKING import ray import torch +# Import EnvConfig from trianglengin +from trianglengin.config import EnvConfig + +# Keep alphatriangle imports from .. 
import config, utils from ..data import DataManager from ..nn import NeuralNetwork @@ -12,6 +15,8 @@ from ..stats import StatsCollectorActor from .components import TrainingComponents +# REMOVE VisualStateActor import + if TYPE_CHECKING: from ..config import PersistenceConfig, TrainConfig @@ -31,68 +36,61 @@ def setup_training_components( detected_cpu_cores: int | None = None try: - # --- Initialize Ray (if needed) and Detect Cores --- if not ray.is_initialized(): try: - # Attempt initialization ray.init(logging_level=logging.WARNING, log_to_driver=True) ray_initialized_here = True logger.info("Ray initialized by setup_training_components.") except Exception as e: - # Log critical error and re-raise to stop setup logger.critical(f"Failed to initialize Ray: {e}", exc_info=True) raise RuntimeError("Ray initialization failed") from e else: logger.info("Ray already initialized.") - # Ensure flag is False if Ray was already running ray_initialized_here = False - # --- Detect Cores (proceed even if Ray was already initialized) --- try: resources = ray.cluster_resources() + # Reserve 1 core for main process, 1 for stats/other? Be conservative. detected_cpu_cores = int(resources.get("CPU", 0)) - 2 - logger.info(f"Ray detected {detected_cpu_cores} CPU cores.") + logger.info( + f"Ray detected {detected_cpu_cores} available CPU cores for workers." 
+ ) except Exception as e: logger.error(f"Could not get Ray cluster resources: {e}") - # Continue without detected cores, will use configured value - # --- Initialize Configurations --- train_config = train_config_override persist_config = persist_config_override - env_config = config.EnvConfig() + # Instantiate EnvConfig from trianglengin + env_config = EnvConfig() model_config = config.ModelConfig() mcts_config = config.MCTSConfig() - # --- Adjust Worker Count based on Detected Cores --- requested_workers = train_config.NUM_SELF_PLAY_WORKERS - actual_workers = requested_workers # Start with configured value + actual_workers = requested_workers if detected_cpu_cores is not None and detected_cpu_cores > 0: - # --- CHANGED: Prioritize detected cores --- - actual_workers = detected_cpu_cores # Use detected cores + # Cap requested workers by detected cores + actual_workers = min(requested_workers, detected_cpu_cores) if actual_workers != requested_workers: logger.info( - f"Overriding configured workers ({requested_workers}) with detected CPU cores ({actual_workers})." + f"Adjusting requested workers ({requested_workers}) to available cores ({detected_cpu_cores}). Using {actual_workers} workers." ) else: logger.info( - f"Using {actual_workers} self-play workers (matches detected cores)." + f"Using {actual_workers} self-play workers (fits within detected cores)." ) - # --- END CHANGED --- else: logger.warning( f"Could not detect valid CPU cores ({detected_cpu_cores}). 
Using configured NUM_SELF_PLAY_WORKERS: {requested_workers}" ) - actual_workers = requested_workers # Fallback to configured value + actual_workers = requested_workers - # Update the config object with the final determined number train_config.NUM_SELF_PLAY_WORKERS = actual_workers logger.info(f"Final worker count set to: {train_config.NUM_SELF_PLAY_WORKERS}") - # --- Validate Configs --- + # Pass trianglengin.EnvConfig to validation config.print_config_info_and_validate(mcts_config) - # --- Setup --- utils.set_random_seeds(train_config.RANDOM_SEED) device = utils.get_device(train_config.DEVICE) worker_device = utils.get_device(train_config.WORKER_DEVICE) @@ -100,35 +98,38 @@ def setup_training_components( logger.info(f"Determined Worker Device: {worker_device}") logger.info(f"Model Compilation Enabled: {train_config.COMPILE_MODEL}") - # --- Initialize Core Components --- stats_collector_actor = StatsCollectorActor.remote(max_history=500_000) # type: ignore logger.info("Initialized StatsCollectorActor with max_history=500k.") + # REMOVE VisualStateActor instantiation + # visual_state_actor = VisualStateActor.remote() # type: ignore + # logger.info("Initialized VisualStateActor.") + + # Pass trianglengin.EnvConfig to NN and Trainer neural_net = NeuralNetwork(model_config, env_config, train_config, device) buffer = ExperienceBuffer(train_config) trainer = Trainer(neural_net, train_config, env_config) data_manager = DataManager(persist_config, train_config) - # --- Create Components Bundle --- components = TrainingComponents( nn=neural_net, buffer=buffer, trainer=trainer, data_manager=data_manager, stats_collector_actor=stats_collector_actor, + # REMOVE visual_state_actor train_config=train_config, - env_config=env_config, + env_config=env_config, # Store trianglengin.EnvConfig model_config=model_config, mcts_config=mcts_config, persist_config=persist_config, ) - # Return components and the flag indicating if Ray was initialized *by this function* return components, 
ray_initialized_here except Exception as e: logger.critical(f"Error setting up training components: {e}", exc_info=True) - # Return None and the Ray init flag (which might be True if init succeeded before error) return None, ray_initialized_here +# ... (count_parameters remains the same) ... def count_parameters(model: torch.nn.Module) -> tuple[int, int]: """Counts total and trainable parameters.""" total_params = sum(p.numel() for p in model.parameters()) diff --git a/alphatriangle/training/visual_runner.py b/alphatriangle/training/visual_runner.py deleted file mode 100644 index 553512a..0000000 --- a/alphatriangle/training/visual_runner.py +++ /dev/null @@ -1,498 +0,0 @@ -# File: alphatriangle/training/visual_runner.py -import logging -import queue -import sys -import threading -import time -import traceback -from collections import deque -from pathlib import Path -from typing import Any - -import mlflow -import pygame -import ray -import torch - -from .. import config, environment, visualization -from ..config import APP_NAME, PersistenceConfig, TrainConfig -from ..utils.sumtree import SumTree -from .components import TrainingComponents -from .logging_utils import ( - Tee, - get_root_logger, - log_configs_to_mlflow, - setup_file_logging, -) -from .loop import TrainingLoop # Import TrainingLoop here -from .setup import count_parameters, setup_training_components - -logger = logging.getLogger(__name__) - -# Queue for loop to send combined state dict {worker_id: state, -1: global_stats} -visual_state_queue: queue.Queue[dict[int, Any] | None] = queue.Queue(maxsize=5) - - -def _initialize_mlflow(persist_config: PersistenceConfig, run_name: str) -> bool: - """Sets up MLflow tracking and starts a run.""" - try: - mlflow_abs_path = persist_config.get_mlflow_abs_path() - Path(mlflow_abs_path).mkdir(parents=True, exist_ok=True) - mlflow_tracking_uri = persist_config.MLFLOW_TRACKING_URI - mlflow.set_tracking_uri(mlflow_tracking_uri) - mlflow.set_experiment(APP_NAME) - 
logger.info(f"Set MLflow tracking URI to: {mlflow_tracking_uri}") - logger.info(f"Set MLflow experiment to: {APP_NAME}") - - mlflow.start_run(run_name=run_name) - active_run = mlflow.active_run() - if active_run: - logger.info(f"MLflow Run started (ID: {active_run.info.run_id}).") - return True - else: - logger.error("MLflow run failed to start.") - return False - except Exception as e: - logger.error(f"Failed to initialize MLflow: {e}", exc_info=True) - return False - - -def _load_and_apply_initial_state(components: TrainingComponents) -> TrainingLoop: - """Loads initial state using DataManager and applies it to components, returning an initialized TrainingLoop.""" - loaded_state = components.data_manager.load_initial_state() - # Pass visual queue to TrainingLoop constructor - training_loop = TrainingLoop(components, visual_state_queue=visual_state_queue) - - if loaded_state.checkpoint_data: - cp_data = loaded_state.checkpoint_data - logger.info( - f"Applying loaded checkpoint data (Run: {cp_data.run_name}, Step: {cp_data.global_step})" - ) - - if cp_data.model_state_dict: - components.nn.set_weights(cp_data.model_state_dict) - if cp_data.optimizer_state_dict: - try: - components.trainer.optimizer.load_state_dict( - cp_data.optimizer_state_dict - ) - for state in components.trainer.optimizer.state.values(): - for k, v in state.items(): - if isinstance(v, torch.Tensor): - state[k] = v.to(components.nn.device) - logger.info("Optimizer state loaded and moved to device.") - except Exception as opt_load_err: - logger.error( - f"Could not load optimizer state: {opt_load_err}. Optimizer might reset." 
- ) - # --- CHANGED: Removed isinstance check, rely on ActorHandle type hint --- - if ( - cp_data.stats_collector_state - and components.stats_collector_actor is not None - ): - # --- END CHANGED --- - try: - # MyPy should now know this is an ActorHandle - set_state_ref = components.stats_collector_actor.set_state.remote( - cp_data.stats_collector_state - ) - ray.get(set_state_ref, timeout=5.0) - logger.info("StatsCollectorActor state restored.") - except Exception as e: - logger.error( - f"Error restoring StatsCollectorActor state: {e}", exc_info=True - ) - - training_loop.set_initial_state( - cp_data.global_step, - cp_data.episodes_played, - cp_data.total_simulations_run, - ) - else: - logger.info("No checkpoint data loaded. Starting fresh.") - training_loop.set_initial_state(0, 0, 0) - - if loaded_state.buffer_data: - if components.train_config.USE_PER: - logger.info("Rebuilding PER SumTree from loaded buffer data...") - if not hasattr(components.buffer, "tree") or components.buffer.tree is None: - components.buffer.tree = SumTree(components.buffer.capacity) - else: - components.buffer.tree = SumTree(components.buffer.capacity) - max_p = 1.0 - for exp in loaded_state.buffer_data.buffer_list: - components.buffer.tree.add(max_p, exp) - logger.info(f"PER buffer loaded. Size: {len(components.buffer)}") - else: - components.buffer.buffer = deque( - loaded_state.buffer_data.buffer_list, - maxlen=components.buffer.capacity, - ) - logger.info(f"Uniform buffer loaded. 
Size: {len(components.buffer)}") - if training_loop.buffer_fill_progress: - training_loop.buffer_fill_progress.set_current_steps(len(components.buffer)) - else: - logger.info("No buffer data loaded.") - - components.nn.model.train() - return training_loop - - -def _save_final_state(training_loop: TrainingLoop): - """Saves the final training state.""" - if not training_loop: - logger.warning("Cannot save final state: TrainingLoop not available.") - return - components = training_loop.components - logger.info("Saving final training state...") - try: - # Pass the actor handle directly - components.data_manager.save_training_state( - nn=components.nn, - optimizer=components.trainer.optimizer, - stats_collector_actor=components.stats_collector_actor, - buffer=components.buffer, - global_step=training_loop.global_step, - episodes_played=training_loop.episodes_played, - total_simulations_run=training_loop.total_simulations_run, - is_final=True, - ) - except Exception as e_save: - logger.error(f"Failed to save final training state: {e_save}", exc_info=True) - - -def _training_loop_thread_func(training_loop: TrainingLoop): - """Function to run the training loop in a separate thread.""" - logger = logging.getLogger(__name__) # Get logger within thread - try: - logger.info("Training loop thread started.") - training_loop.initialize_workers() - training_loop.run() - logger.info("Training loop thread finished.") - except Exception as e: - logger.critical(f"Error in training loop thread: {e}", exc_info=True) - training_loop.training_exception = e - finally: - # Signal the main visualization loop to exit - try: - while not visual_state_queue.empty(): - try: - visual_state_queue.get_nowait() - except queue.Empty: - break - visual_state_queue.put(None, timeout=1.0) - except queue.Full: - logger.error("Visual queue still full during shutdown.") - except Exception as e_q: - logger.error(f"Error putting None signal into visual queue: {e_q}") - - -def run_training_visual_mode( - 
log_level_str: str, - train_config_override: TrainConfig, - persist_config_override: PersistenceConfig, -) -> int: - """Runs the training pipeline in visual mode.""" - main_thread_exception = None - train_thread = None - training_loop: TrainingLoop | None = None - components: TrainingComponents | None = None - exit_code = 1 - original_stdout = sys.stdout - original_stderr = sys.stderr - file_handler = None - tee_stdout = None - tee_stderr = None - ray_initialized_by_setup = False - mlflow_run_active = False - total_params: int | None = None - trainable_params: int | None = None - - try: - # --- Setup File Logging & Redirection --- - log_file_path = setup_file_logging( - persist_config_override, train_config_override.RUN_NAME, "visual" - ) - log_level = logging.getLevelName(log_level_str.upper()) - logger.info( - f"Logging {logging.getLevelName(log_level)} and higher messages to: {log_file_path}" - ) - root_logger = get_root_logger() - file_handler = next( - (h for h in root_logger.handlers if isinstance(h, logging.FileHandler)), - None, - ) - - if file_handler and hasattr(file_handler, "stream") and file_handler.stream: - tee_stdout = Tee( - original_stdout, - file_handler.stream, - main_stream_for_fileno=original_stdout, - ) - tee_stderr = Tee( - original_stderr, - file_handler.stream, - main_stream_for_fileno=original_stderr, - ) - sys.stdout = tee_stdout - sys.stderr = tee_stderr - print("--- Stdout/Stderr redirected to console and log file ---") - logger.info("Stdout/Stderr redirected to console and log file.") - else: - logger.error( - "Could not redirect stdout/stderr: File handler stream not available." 
- ) - - # --- Setup Components (includes Ray init) --- - components, ray_initialized_by_setup = setup_training_components( - train_config_override, persist_config_override - ) - if not components: - raise RuntimeError("Failed to initialize training components.") - - # --- Initialize MLflow --- - mlflow_run_active = _initialize_mlflow( - components.persist_config, components.train_config.RUN_NAME - ) - if mlflow_run_active: - log_configs_to_mlflow(components) # Log configs after run starts - # Log parameter counts after MLflow run starts - total_params, trainable_params = count_parameters(components.nn.model) - logger.info( - f"Model Parameters: Total={total_params:,}, Trainable={trainable_params:,}" - ) - mlflow.log_param("model_total_params", total_params) - mlflow.log_param("model_trainable_params", trainable_params) - else: - logger.warning("MLflow initialization failed, proceeding without MLflow.") - - # --- Load State & Initialize Loop --- - training_loop = _load_and_apply_initial_state(components) - - # --- Start Training Thread --- - train_thread = threading.Thread( - target=_training_loop_thread_func, args=(training_loop,), daemon=True - ) - train_thread.start() - logger.info("Training loop thread launched.") - - # --- Initialize Visualization --- - vis_config = config.VisConfig() - pygame.init() - pygame.font.init() - screen = pygame.display.set_mode( - (vis_config.SCREEN_WIDTH, vis_config.SCREEN_HEIGHT), pygame.RESIZABLE - ) - pygame.display.set_caption( - f"{config.APP_NAME} - Training Visual Mode ({components.train_config.RUN_NAME})" - ) - clock = pygame.time.Clock() - fonts = visualization.load_fonts() - # Pass the actor handle directly - dashboard_renderer = visualization.DashboardRenderer( - screen, - vis_config, - components.env_config, - fonts, - components.stats_collector_actor, - components.model_config, - total_params=total_params, # Pass param counts - trainable_params=trainable_params, - ) - - current_worker_states: dict[int, 
environment.GameState] = {} - current_global_stats: dict[str, Any] = {} - has_received_data = False - - # --- Visualization Loop (Main Thread) --- - running = True - while running: - for event in pygame.event.get(): - if event.type == pygame.QUIT: - running = False - if event.type == pygame.KEYDOWN and event.key == pygame.K_ESCAPE: - running = False - if event.type == pygame.VIDEORESIZE: - try: - w, h = max(640, event.w), max(480, event.h) - screen = pygame.display.set_mode((w, h), pygame.RESIZABLE) - dashboard_renderer.screen = screen - dashboard_renderer.layout_rects = None - except pygame.error as e: - logger.error(f"Error resizing window: {e}") - - # Process Visual Queue - try: - visual_data = visual_state_queue.get(timeout=0.05) - if visual_data is None: - if train_thread and not train_thread.is_alive(): - running = False - logger.info("Received exit signal from training thread.") - elif isinstance(visual_data, dict): - has_received_data = True - global_stats_update = visual_data.pop(-1, {}) - if isinstance(global_stats_update, dict): - if not isinstance(current_global_stats, dict): - current_global_stats = {} - current_global_stats.update(global_stats_update) - else: - logger.warning( - f"Received non-dict global stats update: {type(global_stats_update)}" - ) - - current_worker_states = { - k: v - for k, v in visual_data.items() - if isinstance(k, int) - and k >= 0 - and isinstance(v, environment.GameState) - } - remaining_items = { - k: v - for k, v in visual_data.items() - if k != -1 and k not in current_worker_states - } - if remaining_items: - logger.warning( - f"Unexpected items remaining in visual_data after processing: {remaining_items.keys()}" - ) - else: - logger.warning( - f"Received unexpected item from visual queue: {type(visual_data)}" - ) - except queue.Empty: - pass - except Exception as q_get_err: - logger.error(f"Error getting from visual queue: {q_get_err}") - time.sleep(0.1) - - # Rendering Logic - 
screen.fill(visualization.colors.DARK_GRAY) - if has_received_data: - try: - dashboard_renderer.render( - current_worker_states, current_global_stats - ) - except Exception as render_err: - logger.error(f"Error during rendering: {render_err}", exc_info=True) - err_font = fonts.get("help") - if err_font: - err_surf = err_font.render( - f"Render Error: {render_err}", - True, - visualization.colors.RED, - ) - screen.blit(err_surf, (10, screen.get_height() // 2)) - else: - help_font = fonts.get("help") - if help_font: - wait_surf = help_font.render( - "Waiting for first data from training...", - True, - visualization.colors.LIGHT_GRAY, - ) - wait_rect = wait_surf.get_rect( - center=(screen.get_width() // 2, screen.get_height() // 2) - ) - screen.blit(wait_surf, wait_rect) - - pygame.display.flip() - - # Check Training Thread Status - if train_thread and not train_thread.is_alive() and running: - logger.warning("Training loop thread terminated unexpectedly.") - if training_loop and training_loop.training_exception: - logger.error( - f"Training thread terminated due to exception: {training_loop.training_exception}" - ) - main_thread_exception = training_loop.training_exception - running = False - - clock.tick(vis_config.FPS) - - except Exception as e: - logger.critical( - f"An unhandled error occurred in visual training script (main thread): {e}" - ) - traceback.print_exc() - main_thread_exception = e - if mlflow_run_active: - try: - mlflow.log_param("training_status", "VIS_FAILED") - mlflow.log_param("error_message", f"MainThread: {str(e)}") - except Exception as mlf_err: - logger.error(f"Failed to log main thread error to MLflow: {mlf_err}") - - finally: - # Restore stdout/stderr - if tee_stdout: - sys.stdout = original_stdout - if tee_stderr: - sys.stderr = original_stderr - print("--- Restored stdout/stderr ---") - - logger.info("Initiating shutdown sequence...") - if training_loop and not training_loop.stop_requested.is_set(): - training_loop.request_stop() - - if 
train_thread and train_thread.is_alive(): - logger.info("Waiting for training loop thread to join...") - train_thread.join(timeout=15.0) - if train_thread.is_alive(): - logger.error("Training loop thread did not exit gracefully.") - - # --- Cleanup --- - final_status = "UNKNOWN" - error_msg = "" - if training_loop: - # Save final state - _save_final_state(training_loop) - # Cleanup loop actors - training_loop.cleanup_actors() - # Determine final status - if main_thread_exception or training_loop.training_exception: - final_status = "FAILED" - error_msg = str( - main_thread_exception or training_loop.training_exception - ) - elif training_loop.training_complete: - final_status = "COMPLETED" - else: - final_status = "INTERRUPTED" - else: - final_status = "SETUP_FAILED" - - # End MLflow run - if mlflow_run_active: - try: - mlflow.log_param("training_status", final_status) - if error_msg: - mlflow.log_param("error_message", error_msg) - mlflow.end_run() - logger.info(f"MLflow Run ended. Final Status: {final_status}") - except Exception as mlf_end_err: - logger.error(f"Error ending MLflow run: {mlf_end_err}") - - pygame.quit() - logger.info("Pygame quit.") - - # Shutdown Ray ONLY if initialized by the setup function in this process - if ray_initialized_by_setup and ray.is_initialized(): - try: - ray.shutdown() - logger.info("Ray shut down by visual runner.") - except Exception as e: - logger.error(f"Error shutting down Ray: {e}", exc_info=True) - - # Close file handler - if file_handler: - try: - file_handler.flush() - file_handler.close() - root_logger = get_root_logger() - root_logger.removeHandler(file_handler) - except Exception as e_close: - print(f"Error closing log file handler: {e_close}", file=sys.__stderr__) - - logger.info(f"Visual training finished with exit code {exit_code}.") - return exit_code diff --git a/alphatriangle/training/worker_manager.py b/alphatriangle/training/worker_manager.py index ed511b8..957af1b 100644 --- 
a/alphatriangle/training/worker_manager.py +++ b/alphatriangle/training/worker_manager.py @@ -1,10 +1,11 @@ -# File: alphatriangle/training/worker_manager.py import logging from typing import TYPE_CHECKING import ray from pydantic import ValidationError +# Import EnvConfig from trianglengin +# Keep alphatriangle imports from ..rl import SelfPlayResult, SelfPlayWorker if TYPE_CHECKING: @@ -37,9 +38,10 @@ def initialize_workers(self): for i in range(self.train_config.NUM_SELF_PLAY_WORKERS): try: + # Pass trianglengin.EnvConfig to worker worker = SelfPlayWorker.options(num_cpus=1).remote( actor_id=i, - env_config=self.components.env_config, + env_config=self.components.env_config, # Pass trianglengin.EnvConfig mcts_config=self.components.mcts_config, model_config=self.components.model_config, train_config=self.train_config, @@ -58,6 +60,7 @@ def initialize_workers(self): ) del weights_ref + # ... (submit_initial_tasks, submit_task, get_completed_tasks remain the same) ... def submit_initial_tasks(self): """Submits the first task for each active worker.""" logger.info("Submitting initial tasks to workers...") @@ -149,10 +152,9 @@ def get_completed_tasks( return completed_results - # --- CHANGED: Accept global_step --- + # ... (update_worker_networks, get_num_active_workers, get_num_pending_tasks, cleanup_actors remain the same) ... 
def update_worker_networks(self, global_step: int): """Sends the latest network weights and current global_step to all active workers.""" - # --- END CHANGED --- active_workers = [ w for i, w in enumerate(self.workers) @@ -163,7 +165,6 @@ def update_worker_networks(self, global_step: int): logger.debug(f"Updating worker networks for step {global_step}...") current_weights = self.nn.get_weights() weights_ref = ray.put(current_weights) - # --- CHANGED: Create separate task lists --- set_weights_tasks = [ worker.set_weights.remote(weights_ref) for worker in active_workers ] @@ -171,24 +172,20 @@ def update_worker_networks(self, global_step: int): worker.set_current_trainer_step.remote(global_step) for worker in active_workers ] - # --- END CHANGED --- all_tasks = set_weights_tasks + set_step_tasks if not all_tasks: del weights_ref return try: - # Wait for all tasks to complete ray.get(all_tasks, timeout=120.0) logger.debug( f"Worker networks updated for {len(active_workers)} workers to step {global_step}." 
) - # Logging the update event is now handled in TrainingLoop after this call succeeds except ray.exceptions.RayActorError as e: logger.error( f"A worker actor failed during weight/step update: {e}", exc_info=True ) - # Consider attempting to identify and remove the failed worker except ray.exceptions.GetTimeoutError: logger.error("Timeout waiting for workers to update weights/step.") except Exception as e: @@ -196,7 +193,7 @@ def update_worker_networks(self, global_step: int): f"Unexpected error updating worker networks/step: {e}", exc_info=True ) finally: - del weights_ref # Ensure ref is deleted + del weights_ref def get_num_active_workers(self) -> int: """Returns the number of currently active workers.""" diff --git a/alphatriangle/visualization/README.md b/alphatriangle/visualization/README.md deleted file mode 100644 index 88c8452..0000000 --- a/alphatriangle/visualization/README.md +++ /dev/null @@ -1,69 +0,0 @@ -# File: alphatriangle/visualization/README.md -# Visualization Module (`alphatriangle.visualization`) - -## Purpose and Architecture - -This module is responsible for rendering the game state visually using the Pygame library. It provides components for drawing the grid, shapes, previews, HUD elements, and statistics plots. **In training visualization mode, it now renders the states of multiple self-play workers in a grid layout alongside plots and progress bars (with specific information displayed on each bar).** - -- **Core Components ([`core/README.md`](core/README.md)):** - - `Visualizer`: Orchestrates the rendering process for interactive modes ("play", "debug"). It manages the layout, calls drawing functions, and handles hover/selection states specific to visualization. - - `GameRenderer`: **Adapted renderer** for displaying **multiple** game states and statistics during training visualization (`run_training_visual.py`). It uses `layout.py` to divide the screen. 
It renders worker game states in one area and statistics plots/progress bars in another. It re-instantiates [`alphatriangle.stats.Plotter`](../stats/plotter.py). - - `DashboardRenderer`: Renderer specifically for the **training visualization mode**. It uses `layout.py` to divide the screen into a worker game grid area and a statistics area. It renders multiple worker `GameState` objects (using `GameRenderer` instances) in the top grid and displays statistics plots (using `alphatriangle.stats.Plotter`) and progress bars in the bottom area. **The training progress bar shows model/parameter info, while the buffer progress bar shows global training stats (updates, episodes, sims, workers).** It takes a dictionary mapping worker IDs to `GameState` objects and a dictionary of global statistics. - - `layout`: Calculates the screen positions and sizes for different UI areas (worker grid, stats area, plots). - - `fonts`: Loads necessary font files. - - `colors`: Defines a centralized palette of RGB color tuples. - - `coord_mapper`: Provides functions to map screen coordinates to grid coordinates (`get_grid_coords_from_screen`) and preview indices (`get_preview_index_from_screen`). -- **Drawing Components ([`drawing/README.md`](drawing/README.md)):** - - Contains specific functions for drawing different elements onto Pygame surfaces: - - `grid`: Draws the grid background and occupied/empty triangles. - - `shapes`: Draws individual shapes (used by previews). - - `previews`: Renders the shape preview area. - - `hud`: Renders text information like global training stats and help text at the bottom. - - `highlight`: Draws debug highlights. -- **UI Components ([`ui/README.md`](ui/README.md)):** - - Contains reusable UI elements like `ProgressBar`. - -## Exposed Interfaces - -- **Core Classes & Functions:** - - `Visualizer`: Main renderer for interactive modes. - - `GameRenderer`: Renderer for a single worker's game state. 
- - `DashboardRenderer`: Renderer for combined multi-game/stats training visualization. - - `calculate_interactive_layout`, `calculate_training_layout`: Calculates UI layout rectangles. - - `load_fonts`: Loads Pygame fonts. - - `colors`: Module containing color constants (e.g., `colors.WHITE`). - - `get_grid_coords_from_screen`: Maps screen to grid coordinates. - - `get_preview_index_from_screen`: Maps screen to preview index. -- **Drawing Functions (primarily used internally by Visualizer/GameRenderer but exposed):** - - `draw_grid_background`, `draw_grid_triangles`, `draw_grid_indices` - - `draw_shape` - - `render_previews`, `draw_placement_preview`, `draw_floating_preview` - - `render_hud` - - `draw_debug_highlight` -- **UI Components:** - - `ProgressBar`: Class for rendering progress bars. -- **Config:** - - `VisConfig`: Configuration class (re-exported from [`alphatriangle.config`](../config/README.md)). - -## Dependencies - -- **[`alphatriangle.config`](../config/README.md)**: - - `VisConfig`, `EnvConfig`, `ModelConfig`: Used extensively for layout, sizing, and coordinate mapping. -- **[`alphatriangle.environment`](../environment/README.md)**: - - `GameState`: The primary object whose state is visualized. - - `GridData`: Accessed via `GameState` or passed directly to drawing functions. -- **[`alphatriangle.structs`](../structs/README.md)**: - - Uses `Triangle`, `Shape`, `COLOR_ID_MAP`, `DEBUG_COLOR_ID`, `NO_COLOR_ID`. -- **[`alphatriangle.stats`](../stats/README.md)**: - - Uses `Plotter` within `DashboardRenderer`. -- **[`alphatriangle.utils`](../utils/README.md)**: - - Uses `geometry.is_point_in_polygon`, `helpers.format_eta`, `types.StatsCollectorData`. -- **`pygame`**: - - The core library used for all drawing, surface manipulation, event handling (via `interaction`), and font rendering. -- **`matplotlib`**: - - Used by `alphatriangle.stats.Plotter`. -- **Standard Libraries:** `typing`, `logging`, `math`, `time`. 
- ---- - -**Note:** Please keep this README updated when changing rendering logic, adding new visual elements, modifying layout calculations, or altering the interfaces exposed to other modules (like `interaction` or the main application scripts). Accurate documentation is crucial for maintainability. \ No newline at end of file diff --git a/alphatriangle/visualization/__init__.py b/alphatriangle/visualization/__init__.py deleted file mode 100644 index de8e520..0000000 --- a/alphatriangle/visualization/__init__.py +++ /dev/null @@ -1,59 +0,0 @@ -""" -Visualization module for rendering the game state using Pygame. -""" - -from ..config import VisConfig -from .core import colors -from .core.coord_mapper import ( - get_grid_coords_from_screen, - get_preview_index_from_screen, -) -from .core.dashboard_renderer import DashboardRenderer -from .core.fonts import load_fonts -from .core.game_renderer import GameRenderer -from .core.layout import ( - calculate_interactive_layout, - calculate_training_layout, -) -from .core.visualizer import Visualizer -from .drawing.grid import ( - draw_grid_background, - draw_grid_indices, - draw_grid_triangles, -) -from .drawing.highlight import draw_debug_highlight -from .drawing.hud import render_hud -from .drawing.previews import ( - draw_floating_preview, - draw_placement_preview, - render_previews, -) -from .drawing.shapes import draw_shape -from .ui.progress_bar import ProgressBar - -__all__ = [ - # Core Renderers & Layout - "Visualizer", - "GameRenderer", - "DashboardRenderer", - "calculate_interactive_layout", - "calculate_training_layout", - "load_fonts", - "colors", # Export colors module - "get_grid_coords_from_screen", - "get_preview_index_from_screen", - # Drawing Functions - "draw_grid_background", - "draw_grid_triangles", - "draw_grid_indices", - "draw_shape", - "render_previews", - "draw_placement_preview", - "draw_floating_preview", - "render_hud", - "draw_debug_highlight", - # UI Components - "ProgressBar", - # Config - 
"VisConfig", -] diff --git a/alphatriangle/visualization/core/README.md b/alphatriangle/visualization/core/README.md deleted file mode 100644 index ed6290b..0000000 --- a/alphatriangle/visualization/core/README.md +++ /dev/null @@ -1,62 +0,0 @@ -# File: alphatriangle/visualization/core/README.md -# Visualization Core Submodule (`alphatriangle.visualization.core`) - -## Purpose and Architecture - -This submodule contains the central classes and foundational elements for the visualization system. It orchestrates rendering, manages layout and coordinate systems, and defines core visual properties like colors and fonts. - -- **Render Orchestration:** - - [`Visualizer`](visualizer.py): The main class for rendering in **interactive modes** ("play", "debug"). It maintains the Pygame screen, calculates layout using `layout.py`, manages cached preview area rectangles, and calls appropriate drawing functions from [`alphatriangle.visualization.drawing`](../drawing/README.md). **It receives interaction state (hover position, selected index) via its `render` method to display visual feedback.** - - [`GameRenderer`](game_renderer.py): **Simplified renderer** responsible for drawing a **single** worker's `GameState` (grid and previews) within a specified sub-rectangle. Used by the `DashboardRenderer`. - - [`DashboardRenderer`](dashboard_renderer.py): Renderer specifically for the **training visualization mode**. It uses `layout.py` to divide the screen into a worker game grid area and a statistics area. It renders multiple worker `GameState` objects (using `GameRenderer` instances) in the top grid and displays statistics plots (using [`alphatriangle.stats.Plotter`](../../stats/plotter.py)) and progress bars in the bottom area. **The training progress bar shows model/parameter info, while the buffer progress bar shows global training stats (worker weight updates, episodes, sims, worker status). 
Plots now include black, solid vertical lines (drawn on top) indicating the `global_step` when worker weights were updated, mapped to the appropriate position on each plot's x-axis.** It takes a dictionary mapping worker IDs to `GameState` objects and a dictionary of global statistics. -- **Layout Management:** - - [`layout.py`](layout.py): Contains functions (`calculate_interactive_layout`, `calculate_training_layout`) to determine the size and position of the main UI areas based on the screen dimensions, mode, and `VisConfig`. -- **Coordinate System:** - - [`coord_mapper.py`](coord_mapper.py): Provides essential mapping functions: - - `_calculate_render_params`: Internal helper to get scaling and offset for grid rendering. - - `get_grid_coords_from_screen`: Converts mouse/screen coordinates into logical grid (row, column) coordinates. - - `get_preview_index_from_screen`: Converts mouse/screen coordinates into the index of the shape preview slot being pointed at. -- **Visual Properties:** - - [`colors.py`](colors.py): Defines a centralized palette of named color constants (RGB tuples). - - [`fonts.py`](fonts.py): Contains the `load_fonts` function to load and manage Pygame font objects. - -## Exposed Interfaces - -- **Classes:** - - `Visualizer`: Renderer for interactive modes. - - `__init__(...)` - - `render(game_state: GameState, mode: str, **interaction_state)`: Renders based on game state and interaction hints. - - `ensure_layout() -> Dict[str, pygame.Rect]` - - `screen`: Public attribute (Pygame Surface). - - `preview_rects`: Public attribute (cached preview area rects). - - `GameRenderer`: Renderer for a single worker's game state. - - `__init__(...)` - - `render_worker_state(target_surface: pygame.Surface, area_rect: pygame.Rect, worker_id: int, game_state: Optional[GameState], worker_step_stats: Optional[Dict[str, Any]])` - - `DashboardRenderer`: Renderer for combined multi-game/stats training visualization. 
- - `__init__(...)` - - `render(worker_states: Dict[int, GameState], global_stats: Optional[Dict[str, Any]])` - - `screen`: Public attribute (Pygame Surface). -- **Functions:** - - `calculate_interactive_layout(...) -> Dict[str, pygame.Rect]` - - `calculate_training_layout(...) -> Dict[str, pygame.Rect]` - - `load_fonts() -> Dict[str, Optional[pygame.font.Font]]` - - `get_grid_coords_from_screen(...) -> Optional[Tuple[int, int]]` - - `get_preview_index_from_screen(...) -> Optional[int]` -- **Modules:** - - `colors`: Provides color constants (e.g., `colors.RED`). - -## Dependencies - -- **[`alphatriangle.config`](../../config/README.md)**: `VisConfig`, `EnvConfig`, `ModelConfig`. -- **[`alphatriangle.environment`](../../environment/README.md)**: `GameState`, `GridData`. -- **[`alphatriangle.stats`](../../stats/README.md)**: `Plotter`, `StatsCollectorActor`. -- **[`alphatriangle.utils`](../../utils/README.md)**: `types`, `helpers`. -- **[`alphatriangle.visualization.drawing`](../drawing/README.md)**: Drawing functions are called by renderers. -- **[`alphatriangle.visualization.ui`](../ui/README.md)**: `ProgressBar` (used by `DashboardRenderer`). -- **`pygame`**: Used for surfaces, rectangles, fonts, display management. -- **`ray`**: Used by `DashboardRenderer` (for actor handle type hint). -- **Standard Libraries:** `typing`, `logging`, `math`, `collections.deque`. - ---- - -**Note:** Please keep this README updated when changing the core rendering logic, layout calculations, coordinate mapping, or the interfaces of the renderers. Accurate documentation is crucial for maintainability. 
\ No newline at end of file diff --git a/alphatriangle/visualization/core/__init__.py b/alphatriangle/visualization/core/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/alphatriangle/visualization/core/colors.py b/alphatriangle/visualization/core/colors.py deleted file mode 100644 index 0e819b7..0000000 --- a/alphatriangle/visualization/core/colors.py +++ /dev/null @@ -1,51 +0,0 @@ -# File: alphatriangle/visualization/core/colors.py -"""Centralized color definitions (RGB tuples 0-255).""" - -WHITE: tuple[int, int, int] = (255, 255, 255) -BLACK: tuple[int, int, int] = (0, 0, 0) -LIGHT_GRAY: tuple[int, int, int] = (180, 180, 180) -GRAY: tuple[int, int, int] = (100, 100, 100) -DARK_GRAY: tuple[int, int, int] = (40, 40, 40) -RED: tuple[int, int, int] = (220, 40, 40) -DARK_RED: tuple[int, int, int] = (100, 10, 10) -BLUE: tuple[int, int, int] = (60, 60, 220) -YELLOW: tuple[int, int, int] = (230, 230, 40) -GREEN: tuple[int, int, int] = (40, 200, 40) -DARK_GREEN: tuple[int, int, int] = (10, 80, 10) -ORANGE: tuple[int, int, int] = (240, 150, 20) -PURPLE: tuple[int, int, int] = (140, 40, 140) -CYAN: tuple[int, int, int] = (40, 200, 200) -LIGHTG: tuple[int, int, int] = (144, 238, 144) -HOTPINK: tuple[int, int, int] = (255, 105, 180) # Added for plots - -GOOGLE_COLORS: list[tuple[int, int, int]] = [ - (15, 157, 88), # Green - (244, 180, 0), # Yellow - (66, 133, 244), # Blue - (219, 68, 55), # Red -] - -# Game Specific Visuals -GRID_BG_DEFAULT: tuple[int, int, int] = (20, 20, 30) -GRID_BG_GAME_OVER: tuple[int, int, int] = DARK_RED -GRID_LINE_COLOR: tuple[int, int, int] = GRAY -TRIANGLE_EMPTY_COLOR: tuple[int, int, int] = (60, 60, 70) -PREVIEW_BG: tuple[int, int, int] = (30, 30, 40) -PREVIEW_BORDER: tuple[int, int, int] = GRAY -PREVIEW_SELECTED_BORDER: tuple[int, int, int] = BLUE -PLACEMENT_VALID_COLOR: tuple[int, int, int, int] = (*GREEN, 150) # RGBA -PLACEMENT_INVALID_COLOR: tuple[int, int, int, int] = (*RED, 100) # RGBA -DEBUG_TOGGLE_COLOR: tuple[int, 
int, int] = YELLOW - -# --- ADDED: Colors for Progress Bar Cycling --- -PROGRESS_BAR_CYCLE_COLORS: list[tuple[int, int, int]] = [ - GREEN, - BLUE, - YELLOW, - ORANGE, - PURPLE, - CYAN, - HOTPINK, - RED, # Add red towards the end -] -# --- END ADDED --- diff --git a/alphatriangle/visualization/core/coord_mapper.py b/alphatriangle/visualization/core/coord_mapper.py deleted file mode 100644 index 2d24698..0000000 --- a/alphatriangle/visualization/core/coord_mapper.py +++ /dev/null @@ -1,71 +0,0 @@ -import pygame - -from ...config import EnvConfig -from ...structs import Triangle -from ...utils.geometry import is_point_in_polygon - - -def _calculate_render_params( - width: int, height: int, config: EnvConfig -) -> tuple[float, float, float, float]: - """Calculates scale (cw, ch) and offset (ox, oy) for rendering the grid.""" - rows, cols = config.ROWS, config.COLS - cols_eff = cols * 0.75 + 0.25 if cols > 0 else 1 - scale_w = width / cols_eff if cols_eff > 0 else 1 - scale_h = height / rows if rows > 0 else 1 - scale = max(1.0, min(scale_w, scale_h)) - cell_size = scale - grid_w_px = cols_eff * cell_size - grid_h_px = rows * cell_size - offset_x = (width - grid_w_px) / 2 - offset_y = (height - grid_h_px) / 2 - return cell_size, cell_size, offset_x, offset_y - - -def get_grid_coords_from_screen( - screen_pos: tuple[int, int], grid_area_rect: pygame.Rect, config: EnvConfig -) -> tuple[int, int] | None: - """Maps screen coordinates (relative to screen) to grid row/column.""" - if not grid_area_rect or not grid_area_rect.collidepoint(screen_pos): - return None - - local_x = screen_pos[0] - grid_area_rect.left - local_y = screen_pos[1] - grid_area_rect.top - cw, ch, ox, oy = _calculate_render_params( - grid_area_rect.width, grid_area_rect.height, config - ) - if cw <= 0 or ch <= 0: - return None - - row = int((local_y - oy) / ch) if ch > 0 else -1 - approx_col_center_index = (local_x - ox - cw / 4) / (cw * 0.75) if cw > 0 else -1 - col = int(round(approx_col_center_index)) 
- - for r_check in [row, row - 1, row + 1]: - if not (0 <= r_check < config.ROWS): - continue - for c_check in [col, col - 1, col + 1]: - if not (0 <= c_check < config.COLS): - continue - # Use corrected orientation check - is_up = (r_check + c_check) % 2 != 0 - temp_tri = Triangle(r_check, c_check, is_up) - pts = temp_tri.get_points(ox, oy, cw, ch) - if is_point_in_polygon((local_x, local_y), pts): - return r_check, c_check - - if 0 <= row < config.ROWS and 0 <= col < config.COLS: - return row, col - return None - - -def get_preview_index_from_screen( - screen_pos: tuple[int, int], preview_rects: dict[int, pygame.Rect] -) -> int | None: - """Maps screen coordinates to a shape preview index.""" - if not preview_rects: - return None - for idx, rect in preview_rects.items(): - if rect and rect.collidepoint(screen_pos): - return idx - return None diff --git a/alphatriangle/visualization/core/dashboard_renderer.py b/alphatriangle/visualization/core/dashboard_renderer.py deleted file mode 100644 index e04a755..0000000 --- a/alphatriangle/visualization/core/dashboard_renderer.py +++ /dev/null @@ -1,362 +0,0 @@ -# File: alphatriangle/visualization/core/dashboard_renderer.py -import logging -import math -from collections import deque -from typing import TYPE_CHECKING, Any, Optional - -import pygame -import ray # Import ray - -from ...environment import GameState -from ...stats.plotter import Plotter -from ..drawing import hud as hud_drawing -from ..ui import ProgressBar -from . import colors, layout -from .game_renderer import GameRenderer - -if TYPE_CHECKING: - from ...config import EnvConfig, ModelConfig, VisConfig - from ...stats import StatsCollectorData - - -logger = logging.getLogger(__name__) - - -class DashboardRenderer: - """ - Renders the training dashboard with minimal spacing. - Displays worker states, plots, and progress bars with specific info lines. 
- """ - - def __init__( - self, - screen: pygame.Surface, - vis_config: "VisConfig", - env_config: "EnvConfig", - fonts: dict[str, pygame.font.Font | None], - stats_collector_actor: ray.actor.ActorHandle | None = None, - model_config: Optional["ModelConfig"] = None, - total_params: int | None = None, - trainable_params: int | None = None, - ): - self.screen = screen - self.vis_config = vis_config - self.env_config = env_config - self.fonts = fonts - self.stats_collector_actor = stats_collector_actor - self.model_config = model_config - self.total_params = total_params - self.trainable_params = trainable_params - self.layout_rects: dict[str, pygame.Rect] | None = None - self.worker_sub_rects: dict[int, pygame.Rect] = {} - self.last_worker_grid_size = (0, 0) - self.last_num_workers = 0 - - self.single_game_renderer = GameRenderer(vis_config, env_config, fonts) - self.plotter = Plotter(plot_update_interval=0.2) - - self.progress_bar_height_per_bar = 25 - self.num_progress_bars = 2 - self.progress_bar_spacing = 2 - self.progress_bars_total_height = ( - ( - (self.progress_bar_height_per_bar * self.num_progress_bars) - + (self.progress_bar_spacing * (self.num_progress_bars - 1)) - ) - if self.num_progress_bars > 0 - else 0 - ) - - self._layout_calculated_for_size: tuple[int, int] = (0, 0) - # Don't call ensure_layout here, wait for first render - - def ensure_layout(self): - """Calculates or retrieves the main layout areas.""" - current_w, current_h = self.screen.get_size() - current_size = (current_w, current_h) - - if ( - self.layout_rects is None - or self._layout_calculated_for_size != current_size - ): - # Pass the calculated total height needed for progress bars - self.layout_rects = layout.calculate_training_layout( - current_w, - current_h, - self.vis_config, - progress_bars_total_height=self.progress_bars_total_height, - ) - self._layout_calculated_for_size = current_size - logger.info( - f"Recalculated dashboard layout for size {current_size}: 
{self.layout_rects}" - ) - self.last_worker_grid_size = (0, 0) - self.worker_sub_rects = {} - return self.layout_rects if self.layout_rects is not None else {} - - def _calculate_worker_sub_layout( - self, worker_grid_area: pygame.Rect, worker_ids: list[int] - ): - """Calculates the grid layout within the worker_grid_area with NO padding.""" - area_w, area_h = worker_grid_area.size - num_workers = len(worker_ids) - - if ( - area_w, - area_h, - ) == self.last_worker_grid_size and num_workers == self.last_num_workers: - return - - logger.debug( - f"Recalculating worker sub-layout for {num_workers} workers in area {area_w}x{area_h}" - ) - self.last_worker_grid_size = (area_w, area_h) - self.last_num_workers = num_workers - self.worker_sub_rects = {} - - if area_h <= 10 or area_w <= 10 or num_workers <= 0: - if num_workers > 0: - logger.warning( - f"Worker grid area too small ({area_w}x{area_h}). Cannot calculate sub-layout." - ) - return - - aspect_ratio = area_w / max(1, area_h) - cols = math.ceil(math.sqrt(num_workers * aspect_ratio)) - rows = math.ceil(num_workers / cols) - - cols = max(1, cols) - rows = max(1, rows) - - cell_w = max(1, area_w / cols) - cell_h = max(1, area_h / rows) - - min_cell_w, min_cell_h = 60, 40 - if cell_w < min_cell_w or cell_h < min_cell_h: - logger.warning( - f"Worker grid cells potentially too small ({cell_w:.1f}x{cell_h:.1f})." - ) - - logger.info( - f"Calculated worker sub-layout (no pad): {rows}x{cols} for {num_workers} workers. 
Cell: {cell_w:.1f}x{cell_h:.1f}" - ) - - sorted_worker_ids = sorted(worker_ids) - for i, worker_id in enumerate(sorted_worker_ids): - row = i // cols - col = i % cols - worker_area_x = int(worker_grid_area.left + col * cell_w) - worker_area_y = int(worker_grid_area.top + row * cell_h) - worker_w = int((col + 1) * cell_w) - int(col * cell_w) - worker_h = int((row + 1) * cell_h) - int(row * cell_h) - - worker_rect = pygame.Rect(worker_area_x, worker_area_y, worker_w, worker_h) - self.worker_sub_rects[worker_id] = worker_rect.clip(worker_grid_area) - - def render( - self, - worker_states: dict[int, GameState], - global_stats: dict[str, Any] | None = None, - ): - """Renders the entire training dashboard.""" - self.screen.fill(colors.DARK_GRAY) - layout_rects = self.ensure_layout() - if not layout_rects: - return - - worker_grid_area = layout_rects.get("worker_grid") - plots_rect = layout_rects.get("plots") - progress_bar_area_rect = layout_rects.get("progress_bar_area") - - worker_step_stats = ( - global_stats.get("worker_step_stats", {}) if global_stats else {} - ) - - # --- Render Worker Grid Area --- - if ( - worker_grid_area - and worker_grid_area.width > 0 - and worker_grid_area.height > 0 - ): - worker_ids = list(worker_states.keys()) - if not worker_ids and global_stats and "num_workers" in global_stats: - worker_ids = list(range(global_stats["num_workers"])) - - self._calculate_worker_sub_layout(worker_grid_area, worker_ids) - - for worker_id in self.worker_sub_rects: - worker_area_rect = self.worker_sub_rects[worker_id] - game_state = worker_states.get(worker_id) - step_stats = worker_step_stats.get(worker_id) - self.single_game_renderer.render_worker_state( - self.screen, - worker_area_rect, - worker_id, - game_state, - worker_step_stats=step_stats, - ) - pygame.draw.rect(self.screen, colors.GRAY, worker_area_rect, 1) - else: - logger.warning("Worker grid area not available or too small.") - - # --- Render Plot Area --- - if global_stats: - plot_surface = 
None - if plots_rect and plots_rect.width > 0 and plots_rect.height > 0: - stats_data_for_plot: StatsCollectorData | None = global_stats.get( - "stats_data" - ) - if stats_data_for_plot is not None: - has_any_metric_data = any( - isinstance(dq, deque) and dq - for dq in stats_data_for_plot.values() - ) - if has_any_metric_data: - plot_surface = self.plotter.get_plot_surface( - stats_data_for_plot, - int(plots_rect.width), - int(plots_rect.height), - ) - else: - logger.debug( - "Plot data received but all metric deques are empty." - ) - else: - logger.debug( - "No 'stats_data' key found in global_stats for plotting." - ) - - if plot_surface: - self.screen.blit(plot_surface, plots_rect.topleft) - else: - pygame.draw.rect(self.screen, colors.DARK_GRAY, plots_rect) - plot_font = self.fonts.get("help") - if plot_font: - wait_text = ( - "Plot Area (Waiting for data...)" - if stats_data_for_plot is None - else "Plot Area (No data yet)" - ) - wait_surf = plot_font.render(wait_text, True, colors.LIGHT_GRAY) - wait_rect = wait_surf.get_rect(center=plots_rect.center) - self.screen.blit(wait_surf, wait_rect) - pygame.draw.rect(self.screen, colors.GRAY, plots_rect, 1) - - # --- Render Progress Bars (in their dedicated area) --- - if progress_bar_area_rect: - current_y = ( - progress_bar_area_rect.top - ) # Start at the top of the PB area - progress_bar_font = self.fonts.get("help") - - if progress_bar_font: - bar_width = progress_bar_area_rect.width # Use the area width - bar_x = progress_bar_area_rect.left - bar_height = self.progress_bar_height_per_bar - - # --- Generate Info Strings for Each Bar --- - train_bar_info_str = "" - buffer_bar_info_str = "" - - # Info for Training Bar (Model/Params) - train_info_parts = [] - if self.model_config: - model_str = f"CNN:{len(self.model_config.CONV_FILTERS)}L" - if self.model_config.NUM_RESIDUAL_BLOCKS > 0: - model_str += ( - f"/Res:{self.model_config.NUM_RESIDUAL_BLOCKS}L" - ) - if ( - self.model_config.USE_TRANSFORMER - and 
self.model_config.TRANSFORMER_LAYERS > 0 - ): - model_str += f"/TF:{self.model_config.TRANSFORMER_LAYERS}L" - train_info_parts.append(model_str) - if self.total_params is not None: - train_info_parts.append( - f"Params:{self.total_params / 1e6:.1f}M" - ) - train_bar_info_str = " | ".join(train_info_parts) - - # Info for Buffer Bar (Weight Updates, Episodes, Sims, Workers) - buffer_info_parts = [] - # Use .get with default '?' for robustness - updates = global_stats.get("worker_weight_updates", "?") - episodes = global_stats.get("total_episodes", "?") - sims = global_stats.get("total_simulations_run", "?") # Correct key - num_workers = global_stats.get("num_workers", "?") - pending_tasks = global_stats.get("pending_tasks", "?") - - buffer_info_parts.append(f"Weight Updates:{updates}") - buffer_info_parts.append(f"Episodes:{episodes}") - if isinstance(sims, int | float): - sims_str = ( - f"{sims / 1e6:.1f}M" - if sims >= 1e6 - else ( - f"{sims / 1e3:.0f}k" if sims >= 1000 else str(int(sims)) - ) - ) - buffer_info_parts.append(f"Simulations:{sims_str}") - else: - buffer_info_parts.append(f"Simulations:{sims}") - - # --- CHANGED: Robust worker status formatting --- - if isinstance(pending_tasks, int) and isinstance(num_workers, int): - buffer_info_parts.append( - f"Workers:{pending_tasks}/{num_workers}" - ) - else: - buffer_info_parts.append( - f"Workers:{pending_tasks or '?'}/{num_workers or '?'}" - ) - # --- END CHANGED --- - - buffer_bar_info_str = " | ".join(buffer_info_parts) - # --- End Generate Info Strings --- - - # Training Progress Bar (with model/param info) - train_progress = global_stats.get("train_progress") - if isinstance(train_progress, ProgressBar): - train_progress.render( - self.screen, - (bar_x, current_y), - int(bar_width), - bar_height, - progress_bar_font, - border_radius=3, - info_line=train_bar_info_str, # Pass specific info - ) - current_y += bar_height + self.progress_bar_spacing - else: - logger.debug( - "Train progress bar data not 
available or invalid type." - ) - - # Buffer Progress Bar (with global stats info) - buffer_progress = global_stats.get("buffer_progress") - if isinstance(buffer_progress, ProgressBar): - buffer_progress.render( - self.screen, - (bar_x, current_y), - int(bar_width), - bar_height, - progress_bar_font, - border_radius=3, - info_line=buffer_bar_info_str, # Pass specific info - ) - else: - logger.debug( - "Buffer progress bar data not available or invalid type." - ) - - elif not global_stats: - logger.debug("No global_stats provided to DashboardRenderer.") - - # --- Render HUD (Help Text Only) --- - hud_drawing.render_hud( - self.screen, - mode="training_visual", - fonts=self.fonts, - display_stats=None, - ) diff --git a/alphatriangle/visualization/core/fonts.py b/alphatriangle/visualization/core/fonts.py deleted file mode 100644 index b6addd1..0000000 --- a/alphatriangle/visualization/core/fonts.py +++ /dev/null @@ -1,55 +0,0 @@ -import logging - -import pygame - -logger = logging.getLogger(__name__) - -DEFAULT_FONT_NAME = None -FALLBACK_FONT_NAME = "arial,freesans" - - -def load_single_font(name: str | None, size: int) -> pygame.font.Font | None: - """Loads a single font, handling potential errors.""" - try: - font = pygame.font.SysFont(name, size) - return font - except Exception as e: - logger.error(f"Error loading font '{name}' size {size}: {e}") - if name != FALLBACK_FONT_NAME: - logger.warning(f"Attempting fallback font: {FALLBACK_FONT_NAME}") - try: - font = pygame.font.SysFont(FALLBACK_FONT_NAME, size) - logger.info(f"Loaded fallback font: {FALLBACK_FONT_NAME} size {size}") - return font - except Exception as e_fallback: - logger.error(f"Fallback font failed: {e_fallback}") - return None - return None - - -def load_fonts( - font_sizes: dict[str, int] | None = None, -) -> dict[str, pygame.font.Font | None]: - """Loads standard game fonts.""" - if font_sizes is None: - font_sizes = { - "ui": 24, - "score": 30, - "help": 18, - "title": 48, - } - - fonts: 
dict[str, pygame.font.Font | None] = {} - required_fonts = ["score", "help"] - - logger.info("Loading fonts...") - for name, size in font_sizes.items(): - fonts[name] = load_single_font(DEFAULT_FONT_NAME, size) - - for name in required_fonts: - if fonts.get(name) is None: - logger.critical( - f"Essential font '{name}' failed to load. Text rendering will be affected." - ) - - return fonts diff --git a/alphatriangle/visualization/core/game_renderer.py b/alphatriangle/visualization/core/game_renderer.py deleted file mode 100644 index 41473a8..0000000 --- a/alphatriangle/visualization/core/game_renderer.py +++ /dev/null @@ -1,150 +0,0 @@ -import logging -from typing import TYPE_CHECKING, Any - -import pygame - -from ...environment import GameState -from ..drawing import grid as grid_drawing -from ..drawing import previews as preview_drawing -from . import colors - -if TYPE_CHECKING: - from ...config import EnvConfig, VisConfig - -logger = logging.getLogger(__name__) - - -class GameRenderer: - """ - Renders a single GameState (grid and previews) within a specified area. - Used by DashboardRenderer for displaying worker states. - Also displays latest per-step stats for the worker. - """ - - def __init__( - self, - vis_config: "VisConfig", - env_config: "EnvConfig", - fonts: dict[str, pygame.font.Font | None], - ): - self.vis_config = vis_config - self.env_config = env_config - self.fonts = fonts - self.preview_width_ratio = 0.2 - self.min_preview_width = 30 - self.max_preview_width = 60 - self.padding = 5 - - def render_worker_state( - self, - target_surface: pygame.Surface, - area_rect: pygame.Rect, - worker_id: int, - game_state: GameState | None, - # Add worker_step_stats parameter - worker_step_stats: dict[str, Any] | None = None, - ): - """ - Renders the game state of a single worker into the specified area_rect - on the target_surface. Includes per-step stats display. 
- """ - if not game_state: - # Optionally draw a placeholder if state is None - pygame.draw.rect(target_surface, colors.DARK_GRAY, area_rect) - pygame.draw.rect(target_surface, colors.GRAY, area_rect, 1) - id_font = self.fonts.get("help") - if id_font: - id_surf = id_font.render( - f"W{worker_id} (No State)", True, colors.LIGHT_GRAY - ) - id_rect = id_surf.get_rect(center=area_rect.center) - target_surface.blit(id_surf, id_rect) - return - - # Calculate layout within the worker's area_rect - preview_w = max( - self.min_preview_width, - min(area_rect.width * self.preview_width_ratio, self.max_preview_width), - ) - grid_w = max(0, area_rect.width - preview_w - self.padding) - grid_h = area_rect.height - preview_h = area_rect.height - - grid_rect_local = pygame.Rect(0, 0, grid_w, grid_h) - preview_rect_local = pygame.Rect(grid_w + self.padding, 0, preview_w, preview_h) - - # Create subsurfaces relative to the target_surface - try: - worker_surface = target_surface.subsurface(area_rect) - worker_surface.fill(colors.DARK_GRAY) # Background for the worker area - pygame.draw.rect( - target_surface, colors.GRAY, area_rect, 1 - ) # Border around worker area - - # Render Grid - if grid_rect_local.width > 0 and grid_rect_local.height > 0: - grid_surf = worker_surface.subsurface(grid_rect_local) - bg_color = ( - colors.GRID_BG_GAME_OVER - if game_state.is_over() - else colors.GRID_BG_DEFAULT - ) - grid_drawing.draw_grid_background(grid_surf, bg_color) - grid_drawing.draw_grid_triangles( - grid_surf, game_state.grid_data, self.env_config - ) - - # --- Render Worker Info Text --- - id_font = self.fonts.get("help") - if id_font: - line_y = 3 - line_height = id_font.get_height() + 1 - - # Worker ID and Game Step - id_text = f"W{worker_id} (Step {game_state.current_step})" - id_surf = id_font.render(id_text, True, colors.LIGHT_GRAY) - grid_surf.blit(id_surf, (3, line_y)) - line_y += line_height - - # Current Score - score_text = f"Score: {game_state.game_score:.0f}" - score_surf = 
id_font.render(score_text, True, colors.YELLOW) - grid_surf.blit(score_surf, (3, line_y)) - line_y += line_height - - # MCTS Stats (if available) - if worker_step_stats: - visits = worker_step_stats.get("mcts_visits", "?") - depth = worker_step_stats.get("mcts_depth", "?") - mcts_text = f"MCTS: V={visits} D={depth}" - mcts_surf = id_font.render(mcts_text, True, colors.CYAN) - grid_surf.blit(mcts_surf, (3, line_y)) - line_y += line_height - - # Render Previews - if preview_rect_local.width > 0 and preview_rect_local.height > 0: - preview_surf = worker_surface.subsurface(preview_rect_local) - # Pass 0,0 as topleft because preview_surf is already positioned - _ = preview_drawing.render_previews( - preview_surf, - game_state, - (0, 0), # Relative to preview_surf - "training_visual", # Mode context - self.env_config, - self.vis_config, - selected_shape_idx=-1, # No selection in training view - ) - - except ValueError as e: - # Handle cases where subsurface creation fails (e.g., zero dimensions) - if "subsurface rectangle is invalid" not in str(e): - logger.error( - f"Error creating subsurface for W{worker_id} in area {area_rect}: {e}" - ) - # Draw error indicator if subsurface fails - pygame.draw.rect(target_surface, colors.RED, area_rect, 2) - id_font = self.fonts.get("help") - if id_font: - id_surf = id_font.render(f"W{worker_id} (Render Err)", True, colors.RED) - id_rect = id_surf.get_rect(center=area_rect.center) - target_surface.blit(id_surf, id_rect) diff --git a/alphatriangle/visualization/core/layout.py b/alphatriangle/visualization/core/layout.py deleted file mode 100644 index b7561d7..0000000 --- a/alphatriangle/visualization/core/layout.py +++ /dev/null @@ -1,107 +0,0 @@ -# File: alphatriangle/visualization/core/layout.py -# Changes: -# - Position progress_bar_area_rect precisely above the HUD. -# - Calculate plot_rect height to fill the gap between worker grid and progress bars. 
- -import logging - -import pygame - -from ...config import VisConfig - -logger = logging.getLogger(__name__) - - -def calculate_interactive_layout( - screen_width: int, screen_height: int, vis_config: VisConfig -) -> dict[str, pygame.Rect]: - """ - Calculates layout rectangles for interactive modes (play/debug). - Places grid on the left and preview on the right. - """ - sw, sh = screen_width, screen_height - pad = vis_config.PADDING - hud_h = vis_config.HUD_HEIGHT - preview_w = vis_config.PREVIEW_AREA_WIDTH - - available_h = max(0, sh - hud_h - 2 * pad) - available_w = max(0, sw - 3 * pad) - - grid_w = max(0, available_w - preview_w) - grid_h = available_h - - grid_rect = pygame.Rect(pad, pad, grid_w, grid_h) - preview_rect = pygame.Rect(grid_rect.right + pad, pad, preview_w, grid_h) - - screen_rect = pygame.Rect(0, 0, sw, sh) - grid_rect = grid_rect.clip(screen_rect) - preview_rect = preview_rect.clip(screen_rect) - - logger.debug( - f"Interactive Layout calculated: Grid={grid_rect}, Preview={preview_rect}" - ) - - return { - "grid": grid_rect, - "preview": preview_rect, - } - - -def calculate_training_layout( - screen_width: int, - screen_height: int, - vis_config: VisConfig, - progress_bars_total_height: int, # Height needed for progress bars -) -> dict[str, pygame.Rect]: - """ - Calculates layout rectangles for training visualization mode. MINIMAL SPACING. - Worker grid top, progress bars bottom (above HUD), plots fill middle. 
- """ - sw, sh = screen_width, screen_height - pad = 2 # Minimal padding - hud_h = vis_config.HUD_HEIGHT - - # --- Worker Grid Area (Top) --- - # Calculate available height excluding HUD and minimal padding - total_available_h_for_grid_plots_bars = max(0, sh - hud_h - 2 * pad) - top_area_h = min( - int(total_available_h_for_grid_plots_bars * 0.10), 80 - ) # 10% or 80px max - top_area_w = sw - 2 * pad - worker_grid_rect = pygame.Rect(pad, pad, top_area_w, top_area_h) - - # --- Progress Bar Area (Bottom, above HUD) --- - # Position it precisely based on its required height - pb_area_y = sh - hud_h - pad - progress_bars_total_height - pb_area_w = sw - 2 * pad - progress_bar_area_rect = pygame.Rect( - pad, pb_area_y, pb_area_w, progress_bars_total_height - ) - - # --- Plot Area (Middle) --- - # Calculate height to fill the gap precisely - plot_area_y = worker_grid_rect.bottom + pad - plot_area_w = sw - 2 * pad - plot_area_h = max( - 0, progress_bar_area_rect.top - plot_area_y - pad - ) # Fill space between worker grid and progress bars - plot_rect = pygame.Rect(pad, plot_area_y, plot_area_w, plot_area_h) - - # Clip all rects to screen bounds - screen_rect = pygame.Rect(0, 0, sw, sh) - worker_grid_rect = worker_grid_rect.clip(screen_rect) - plot_rect = plot_rect.clip(screen_rect) - progress_bar_area_rect = progress_bar_area_rect.clip(screen_rect) - - logger.debug( - f"Training Layout calculated (Compact V3): WorkerGrid={worker_grid_rect}, PlotRect={plot_rect}, ProgressBarArea={progress_bar_area_rect}" - ) - - return { - "worker_grid": worker_grid_rect, - "plots": plot_rect, - "progress_bar_area": progress_bar_area_rect, # Use this rect for drawing PBs - } - - -calculate_layout = calculate_training_layout diff --git a/alphatriangle/visualization/core/visualizer.py b/alphatriangle/visualization/core/visualizer.py deleted file mode 100644 index e1b0960..0000000 --- a/alphatriangle/visualization/core/visualizer.py +++ /dev/null @@ -1,221 +0,0 @@ -import logging -from typing 
import TYPE_CHECKING - -import pygame - -from ...structs import Shape -from ..drawing import grid as grid_drawing -from ..drawing import highlight as highlight_drawing -from ..drawing import hud as hud_drawing -from ..drawing import previews as preview_drawing -from ..drawing.previews import ( - draw_floating_preview, - draw_placement_preview, -) -from . import colors, layout - -if TYPE_CHECKING: - from ...config import EnvConfig, VisConfig - from ...environment.core.game_state import GameState - -logger = logging.getLogger(__name__) - - -class Visualizer: - """ - Orchestrates rendering of a single game state for interactive modes. - Receives interaction state (hover, selection) via render parameters. - """ - - def __init__( - self, - screen: pygame.Surface, - vis_config: "VisConfig", - env_config: "EnvConfig", - fonts: dict[str, pygame.font.Font | None], - ): - self.screen = screen - self.vis_config = vis_config - self.env_config = env_config - self.fonts = fonts - self.layout_rects: dict[str, pygame.Rect] | None = None - self.preview_rects: dict[int, pygame.Rect] = {} # Cache preview rects - self._layout_calculated_for_size: tuple[int, int] = (0, 0) - self.ensure_layout() # Initial layout calculation - - def ensure_layout(self) -> dict[str, pygame.Rect]: - """Returns cached layout or calculates it if needed.""" - current_w, current_h = self.screen.get_size() - current_size = (current_w, current_h) - - if ( - self.layout_rects is None - or self._layout_calculated_for_size != current_size - ): - # Use the interactive layout calculation - self.layout_rects = layout.calculate_interactive_layout( - current_w, current_h, self.vis_config - ) - self._layout_calculated_for_size = current_size - logger.info( - f"Recalculated interactive layout for size {current_size}: {self.layout_rects}" - ) - # Clear preview rect cache when layout changes - self.preview_rects = {} - - return self.layout_rects if self.layout_rects is not None else {} - - def render( - self, - game_state: 
"GameState", - mode: str, - # Interaction state passed in: - selected_shape_idx: int = -1, - hover_shape: Shape | None = None, - hover_grid_coord: tuple[int, int] | None = None, - hover_is_valid: bool = False, - hover_screen_pos: tuple[int, int] | None = None, - debug_highlight_coord: tuple[int, int] | None = None, - ): - """ - Renders the entire game visualization for interactive modes. - Uses interaction state passed as parameters for visual feedback. - """ - self.screen.fill(colors.GRID_BG_DEFAULT) # Clear screen - layout_rects = self.ensure_layout() - grid_rect = layout_rects.get("grid") - preview_rect = layout_rects.get("preview") - - # Render Grid Area - if grid_rect and grid_rect.width > 0 and grid_rect.height > 0: - try: - grid_surf = self.screen.subsurface(grid_rect) - self._render_grid_area( - grid_surf, - game_state, - mode, - grid_rect, - hover_shape, - hover_grid_coord, - hover_is_valid, - hover_screen_pos, - debug_highlight_coord, - ) - except ValueError as e: - logger.error(f"Error creating grid subsurface ({grid_rect}): {e}") - pygame.draw.rect(self.screen, colors.RED, grid_rect, 1) - - # Render Preview Area - if preview_rect and preview_rect.width > 0 and preview_rect.height > 0: - try: - preview_surf = self.screen.subsurface(preview_rect) - # Pass selected_shape_idx for highlighting - self._render_preview_area( - preview_surf, game_state, mode, preview_rect, selected_shape_idx - ) - except ValueError as e: - logger.error(f"Error creating preview subsurface ({preview_rect}): {e}") - pygame.draw.rect(self.screen, colors.RED, preview_rect, 1) - - # Render HUD - # --- CORRECTED CALL --- - hud_drawing.render_hud( - surface=self.screen, - mode=mode, - fonts=self.fonts, - display_stats=None, # Pass None for display_stats in interactive modes - ) - # --- END CORRECTION --- - - def _render_grid_area( - self, - grid_surf: pygame.Surface, - game_state: "GameState", - mode: str, - grid_rect: pygame.Rect, # Pass grid_rect for hover calculations - hover_shape: 
Shape | None, - hover_grid_coord: tuple[int, int] | None, - hover_is_valid: bool, - hover_screen_pos: tuple[int, int] | None, - debug_highlight_coord: tuple[int, int] | None, - ): - """Renders the main game grid and overlays onto the provided grid_surf.""" - # Background - bg_color = ( - colors.GRID_BG_GAME_OVER if game_state.is_over() else colors.GRID_BG_DEFAULT - ) - grid_drawing.draw_grid_background(grid_surf, bg_color) - - # Grid Triangles - grid_drawing.draw_grid_triangles( - grid_surf, game_state.grid_data, self.env_config - ) - - # Debug Indices - if mode == "debug": - grid_drawing.draw_grid_indices( - grid_surf, game_state.grid_data, self.env_config, self.fonts - ) - - # Play Mode Hover Previews - if mode == "play" and hover_shape: - if hover_grid_coord: # Snapped preview - draw_placement_preview( - grid_surf, - hover_shape, - hover_grid_coord[0], - hover_grid_coord[1], - is_valid=hover_is_valid, # Use validity passed in - config=self.env_config, - ) - elif hover_screen_pos: # Floating preview (relative to grid_surf) - # Adjust screen pos to be relative to grid_surf - local_hover_pos = ( - hover_screen_pos[0] - grid_rect.left, - hover_screen_pos[1] - grid_rect.top, - ) - if grid_surf.get_rect().collidepoint(local_hover_pos): - draw_floating_preview( - grid_surf, - hover_shape, - local_hover_pos, - self.env_config, - ) - - # Debug Mode Highlight - if mode == "debug" and debug_highlight_coord: - r, c = debug_highlight_coord - highlight_drawing.draw_debug_highlight(grid_surf, r, c, self.env_config) - - # --- ADDED: Display Score in Grid Area for Interactive Modes --- - score_font = self.fonts.get("score") - if score_font: - score_text = f"Score: {game_state.game_score:.0f}" - score_surf = score_font.render(score_text, True, colors.YELLOW) - # Position score at top-left of grid area - score_rect = score_surf.get_rect(topleft=(5, 5)) - grid_surf.blit(score_surf, score_rect) - # --- END ADDED --- - - def _render_preview_area( - self, - preview_surf: 
pygame.Surface, - game_state: "GameState", - mode: str, - preview_rect: pygame.Rect, - selected_shape_idx: int, # Pass selected index - ): - """Renders the shape preview slots onto preview_surf and caches rects.""" - # Pass selected_shape_idx to render_previews for highlighting - current_preview_rects = preview_drawing.render_previews( - preview_surf, - game_state, - preview_rect.topleft, # Pass absolute top-left - mode, - self.env_config, - self.vis_config, - selected_shape_idx=selected_shape_idx, # Pass selection state - ) - # Update cache only if it changed (or first time) - if not self.preview_rects or self.preview_rects != current_preview_rects: - self.preview_rects = current_preview_rects diff --git a/alphatriangle/visualization/drawing/README.md b/alphatriangle/visualization/drawing/README.md deleted file mode 100644 index ff78b33..0000000 --- a/alphatriangle/visualization/drawing/README.md +++ /dev/null @@ -1,56 +0,0 @@ -# File: alphatriangle/visualization/drawing/README.md -# Visualization Drawing Submodule (`alphatriangle.visualization.drawing`) - -## Purpose and Architecture - -This submodule contains specialized functions responsible for drawing specific visual elements of the game onto Pygame surfaces. These functions are typically called by the core renderers (`Visualizer`, `GameRenderer`) in [`alphatriangle.visualization.core`](../core/README.md). Separating drawing logic makes the core renderers cleaner and promotes reusability of drawing code. - -- **[`grid.py`](grid.py):** Functions for drawing the grid background (`draw_grid_background`), the individual triangles within it colored based on occupancy/emptiness (`draw_grid_triangles`), and optional indices (`draw_grid_indices`). Uses `Triangle` from [`alphatriangle.structs`](../../structs/README.md) and `GridData` from [`alphatriangle.environment`](../../environment/README.md). 
-- **[`shapes.py`](shapes.py):** Contains `draw_shape`, a function to render a given `Shape` object at a specific location on a surface (used primarily for previews). Uses `Shape` and `Triangle` from [`alphatriangle.structs`](../../structs/README.md). -- **[`previews.py`](previews.py):** Handles rendering related to shape previews: - - `render_previews`: Draws the dedicated preview area, including borders and the shapes within their slots, handling selection highlights. Uses `Shape` from [`alphatriangle.structs`](../../structs/README.md). - - `draw_placement_preview`: Draws a semi-transparent version of a shape snapped to the grid, indicating a potential placement location (used in play mode hover). Uses `Shape` and `Triangle` from [`alphatriangle.structs`](../../structs/README.md). - - `draw_floating_preview`: Draws a semi-transparent shape directly under the mouse cursor when hovering over the grid but not snapped (used in play mode hover). Uses `Shape` and `Triangle` from [`alphatriangle.structs`](../../structs/README.md). -- **[`hud.py`](hud.py):** `render_hud` draws Heads-Up Display elements like the game score, help text, and optional training statistics onto the main screen surface. -- **[`highlight.py`](highlight.py):** `draw_debug_highlight` draws a distinct border around a specific triangle, used for visual feedback in debug mode. Uses `Triangle` from [`alphatriangle.structs`](../../structs/README.md). -- **[`utils.py`](utils.py):** Contains general drawing utility functions (currently empty). 
- -## Exposed Interfaces - -- **Grid Drawing:** - - `draw_grid_background(surface: pygame.Surface, bg_color: tuple)` - - `draw_grid_triangles(surface: pygame.Surface, grid_data: GridData, config: EnvConfig)` - - `draw_grid_indices(surface: pygame.Surface, grid_data: GridData, config: EnvConfig, fonts: Dict[str, Optional[pygame.font.Font]])` -- **Shape Drawing:** - - `draw_shape(surface: pygame.Surface, shape: Shape, topleft: Tuple[int, int], cell_size: float, is_selected: bool = False, origin_offset: Tuple[int, int] = (0, 0))` -- **Preview Drawing:** - - `render_previews(surface: pygame.Surface, game_state: GameState, area_topleft: Tuple[int, int], mode: str, env_config: EnvConfig, vis_config: VisConfig, selected_shape_idx: int = -1) -> Dict[int, pygame.Rect]` - - `draw_placement_preview(surface: pygame.Surface, shape: Shape, r: int, c: int, is_valid: bool, config: EnvConfig)` - - `draw_floating_preview(surface: pygame.Surface, shape: Shape, screen_pos: Tuple[int, int], config: EnvConfig)` -- **HUD Drawing:** - - `render_hud(surface: pygame.Surface, mode: str, fonts: Dict[str, Optional[pygame.font.Font]], display_stats: Optional[Dict[str, Any]] = None)` -- **Highlight Drawing:** - - `draw_debug_highlight(surface: pygame.Surface, r: int, c: int, config: EnvConfig)` -- **Utility Functions:** - - (Currently empty or contains other drawing-specific utils) - -## Dependencies - -- **[`alphatriangle.visualization.core`](../core/README.md)**: - - `colors`: Used extensively for drawing colors. - - `coord_mapper`: Used internally (e.g., by `draw_placement_preview`) or relies on its calculations passed in. -- **[`alphatriangle.config`](../../config/README.md)**: - - `EnvConfig`, `VisConfig`: Provide dimensions, padding, etc., needed for drawing calculations. -- **[`alphatriangle.environment`](../../environment/README.md)**: - - `GameState`, `GridData`: Provide the data to be drawn. 
-- **[`alphatriangle.structs`](../../structs/README.md)**: - - Uses `Triangle`, `Shape`, `COLOR_ID_MAP`, `DEBUG_COLOR_ID`, `NO_COLOR_ID`. -- **[`alphatriangle.visualization.ui`](../ui/README.md)**: - - `ProgressBar` (used by `hud.py`). -- **`pygame`**: - - The core library used for all drawing operations (`pygame.draw.polygon`, `surface.fill`, `surface.blit`, etc.) and font rendering. -- **Standard Libraries:** `typing`, `logging`, `math`. - ---- - -**Note:** Please keep this README updated when adding new drawing functions, modifying existing ones, or changing their dependencies on configuration or environment data structures. Accurate documentation is crucial for maintainability. \ No newline at end of file diff --git a/alphatriangle/visualization/drawing/__init__.py b/alphatriangle/visualization/drawing/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/alphatriangle/visualization/drawing/grid.py b/alphatriangle/visualization/drawing/grid.py deleted file mode 100644 index 7e09b4b..0000000 --- a/alphatriangle/visualization/drawing/grid.py +++ /dev/null @@ -1,134 +0,0 @@ -# File: alphatriangle/visualization/drawing/grid.py -from typing import TYPE_CHECKING - -import pygame - -# Import constants from the structs package directly -from ...structs import COLOR_ID_MAP, DEBUG_COLOR_ID, NO_COLOR_ID, Triangle -from ..core import colors, coord_mapper - -if TYPE_CHECKING: - from ...config import EnvConfig - from ...environment import GridData - - -def draw_grid_background(surface: pygame.Surface, bg_color: tuple) -> None: - """Fills the grid area surface with a background color.""" - surface.fill(bg_color) - - -def draw_grid_triangles( - surface: pygame.Surface, grid_data: "GridData", config: "EnvConfig" -) -> None: - """Draws all triangles (empty, occupied, death) on the grid surface using NumPy state.""" - if surface.get_width() <= 0 or surface.get_height() <= 0: - return - - cw, ch, ox, oy = coord_mapper._calculate_render_params( - 
surface.get_width(), surface.get_height(), config - ) - if cw <= 0 or ch <= 0: - return - - # Get direct references to NumPy arrays - occupied_np = grid_data._occupied_np - death_np = grid_data._death_np - color_id_np = grid_data._color_id_np - - for r in range(grid_data.rows): - for c in range(grid_data.cols): - is_death = death_np[r, c] - is_occupied = occupied_np[r, c] - color_id = color_id_np[r, c] - is_up = (r + c) % 2 != 0 # Calculate orientation - - color: tuple[int, int, int] | None = None - border_color = colors.GRID_LINE_COLOR - border_width = 1 - - if is_death: - color = colors.DARK_GRAY - border_color = colors.RED - elif is_occupied: - if color_id == DEBUG_COLOR_ID: - color = colors.DEBUG_TOGGLE_COLOR # Special debug color - elif color_id != NO_COLOR_ID and 0 <= color_id < len(COLOR_ID_MAP): - color = COLOR_ID_MAP[color_id] - else: - # Fallback if occupied but no valid color ID (shouldn't happen) - color = colors.PURPLE # Error color - else: # Empty playable cell - color = colors.TRIANGLE_EMPTY_COLOR - - # Create temporary Triangle only for geometry calculation - temp_tri = Triangle(r, c, is_up) - pts = temp_tri.get_points(ox, oy, cw, ch) - - if color: # Should always be true unless error - pygame.draw.polygon(surface, color, pts) - pygame.draw.polygon(surface, border_color, pts, border_width) - - -def draw_grid_indices( - surface: pygame.Surface, - grid_data: "GridData", - config: "EnvConfig", - fonts: dict[str, pygame.font.Font | None], -) -> None: - """Draws the index number inside each triangle, including death cells.""" - if surface.get_width() <= 0 or surface.get_height() <= 0: - return - - font = fonts.get("help") - if not font: - return - - cw, ch, ox, oy = coord_mapper._calculate_render_params( - surface.get_width(), surface.get_height(), config - ) - if cw <= 0 or ch <= 0: - return - - # Get direct references to NumPy arrays - occupied_np = grid_data._occupied_np - death_np = grid_data._death_np - color_id_np = grid_data._color_id_np - - for r 
in range(grid_data.rows): - for c in range(grid_data.cols): - is_death = death_np[r, c] - is_occupied = occupied_np[r, c] - color_id = color_id_np[r, c] - is_up = (r + c) % 2 != 0 # Calculate orientation - - # Create temporary Triangle only for geometry calculation - temp_tri = Triangle(r, c, is_up) - pts = temp_tri.get_points(ox, oy, cw, ch) - center_x = sum(p[0] for p in pts) / 3 - center_y = sum(p[1] for p in pts) / 3 - - text_color = colors.WHITE # Default - - if is_death: - text_color = colors.LIGHT_GRAY - elif is_occupied: - bg_color: tuple[int, int, int] | None = None - if color_id == DEBUG_COLOR_ID: - bg_color = colors.DEBUG_TOGGLE_COLOR - elif color_id != NO_COLOR_ID and 0 <= color_id < len(COLOR_ID_MAP): - bg_color = COLOR_ID_MAP[color_id] - - if bg_color: - brightness = sum(bg_color) / 3 - text_color = colors.WHITE if brightness < 128 else colors.BLACK - else: # Fallback if color missing - text_color = colors.RED - else: # Empty playable - bg_color = colors.TRIANGLE_EMPTY_COLOR - brightness = sum(bg_color) / 3 - text_color = colors.WHITE if brightness < 128 else colors.BLACK - - index = r * config.COLS + c - text_surf = font.render(str(index), True, text_color) - text_rect = text_surf.get_rect(center=(center_x, center_y)) - surface.blit(text_surf, text_rect) diff --git a/alphatriangle/visualization/drawing/highlight.py b/alphatriangle/visualization/drawing/highlight.py deleted file mode 100644 index a77e09f..0000000 --- a/alphatriangle/visualization/drawing/highlight.py +++ /dev/null @@ -1,29 +0,0 @@ -from typing import TYPE_CHECKING - -import pygame - -from ...structs import Triangle -from ..core import colors, coord_mapper - -if TYPE_CHECKING: - from ...config import EnvConfig - - -def draw_debug_highlight( - surface: pygame.Surface, r: int, c: int, config: "EnvConfig" -) -> None: - """Highlights a specific triangle border for debugging.""" - if surface.get_width() <= 0 or surface.get_height() <= 0: - return - - cw, ch, ox, oy = 
coord_mapper._calculate_render_params( - surface.get_width(), surface.get_height(), config - ) - if cw <= 0 or ch <= 0: - return - - is_up = (r + c) % 2 != 0 - temp_tri = Triangle(r, c, is_up) - pts = temp_tri.get_points(ox, oy, cw, ch) - - pygame.draw.polygon(surface, colors.DEBUG_TOGGLE_COLOR, pts, 3) diff --git a/alphatriangle/visualization/drawing/hud.py b/alphatriangle/visualization/drawing/hud.py deleted file mode 100644 index d09ff0b..0000000 --- a/alphatriangle/visualization/drawing/hud.py +++ /dev/null @@ -1,94 +0,0 @@ -# File: alphatriangle/visualization/drawing/hud.py -from typing import Any - -import pygame - -from ..core import colors -from ..ui import ProgressBar - - -def render_hud( - surface: pygame.Surface, - mode: str, - fonts: dict[str, pygame.font.Font | None], - display_stats: dict[str, Any] | None = None, -) -> None: - """ - Renders global information (like step count, worker status) at the bottom. - Individual game scores are not shown here anymore. - Help text is skipped in training_visual mode. 
- """ - screen_w, screen_h = surface.get_size() - help_font = fonts.get("help") - stats_font = fonts.get("help") # Use same font for stats line - step_font = fonts.get("ui") or help_font # Use UI font for step, fallback to help - - bottom_y = screen_h - 10 # Position from bottom - - stats_rect = None - step_rect = None - - # --- Render Global Step as "Weight Updates" --- - if step_font and display_stats: - train_progress = display_stats.get("train_progress") - global_step = ( - train_progress.current_steps - if isinstance(train_progress, ProgressBar) # Check type - else display_stats.get("global_step", "?") - ) - step_text = f"Weight Updates: {global_step}" - step_surf = step_font.render(step_text, True, colors.YELLOW) - step_rect = step_surf.get_rect(bottomleft=(15, bottom_y)) - surface.blit(step_surf, step_rect) - - # Render other global training stats if available, positioned after the step count - if stats_font and display_stats and step_rect: - stats_items = [] - episodes = display_stats.get("total_episodes", "?") - sims = display_stats.get("total_simulations", "?") - num_workers = display_stats.get("num_workers", "?") - pending_tasks = display_stats.get("pending_tasks", "?") - - stats_items.append(f"Episodes: {episodes}") - if isinstance(sims, int | float): - sims_str = ( - f"{sims / 1e6:.2f}M" - if sims >= 1e6 - else (f"{sims / 1e3:.1f}k" if sims >= 1000 else str(int(sims))) - ) - stats_items.append(f"Sims: {sims_str}") - else: - stats_items.append(f"Sims: {sims}") - - stats_items.append(f"Workers: {pending_tasks}/{num_workers} busy") - - stats_text = " | ".join(stats_items) - stats_surf = stats_font.render(stats_text, True, colors.CYAN) - stats_rect = stats_surf.get_rect(bottomleft=(step_rect.right + 20, bottom_y)) - surface.blit(stats_surf, stats_rect) - - # --- CHANGED: Skip help text in training visual mode --- - if help_font and mode != "training_visual": - help_text = "[ESC] Quit" - if mode == "play": - help_text += " | [Click] Select/Place Shape" - 
elif mode == "debug": - help_text += " | [Click] Toggle Cell" - # No need for training_visual case here anymore - - help_surf = help_font.render(help_text, True, colors.LIGHT_GRAY) - help_rect = help_surf.get_rect(bottomright=(screen_w - 15, bottom_y)) - - combined_left_width = ( - stats_rect.right if stats_rect else (step_rect.right if step_rect else 0) - ) - if combined_left_width > help_rect.left - 20: - help_rect.bottom = ( - stats_rect.top - if stats_rect - else (step_rect.top if step_rect else bottom_y) - ) - 5 - help_rect.right = screen_w - 15 - - surface.blit(help_surf, help_rect) - # --- END CHANGED --- diff --git a/alphatriangle/visualization/drawing/previews.py b/alphatriangle/visualization/drawing/previews.py deleted file mode 100644 index de6881c..0000000 --- a/alphatriangle/visualization/drawing/previews.py +++ /dev/null @@ -1,208 +0,0 @@ -import logging -from typing import TYPE_CHECKING - -import pygame - -from ...structs import Shape, Triangle -from ..core import colors, coord_mapper -from .shapes import draw_shape - -if TYPE_CHECKING: - from ...config import EnvConfig, VisConfig - from ...environment import GameState - -logger = logging.getLogger(__name__) - - -def render_previews( - surface: pygame.Surface, - game_state: "GameState", - area_topleft: tuple[int, int], - _mode: str, - env_config: "EnvConfig", - vis_config: "VisConfig", - selected_shape_idx: int = -1, -) -> dict[int, pygame.Rect]: - """Renders shape previews in their area. 
Returns dict {index: screen_rect}.""" - surface.fill(colors.PREVIEW_BG) - preview_rects_screen: dict[int, pygame.Rect] = {} - num_slots = env_config.NUM_SHAPE_SLOTS - pad = vis_config.PREVIEW_PADDING - inner_pad = vis_config.PREVIEW_INNER_PADDING - border = vis_config.PREVIEW_BORDER_WIDTH - selected_border = vis_config.PREVIEW_SELECTED_BORDER_WIDTH - - if num_slots <= 0: - return {} - - # Calculate dimensions for each slot - total_pad_h = (num_slots + 1) * pad - available_h = surface.get_height() - total_pad_h - slot_h = available_h / num_slots if num_slots > 0 else 0 - slot_w = surface.get_width() - 2 * pad - - current_y = float(pad) # Start y position as float - - for i in range(num_slots): - # Calculate local rectangle for the slot within the preview surface - slot_rect_local = pygame.Rect(pad, int(current_y), int(slot_w), int(slot_h)) - # Calculate screen rectangle by offsetting local rect - slot_rect_screen = slot_rect_local.move(area_topleft) - preview_rects_screen[i] = ( - slot_rect_screen # Store screen rect for interaction mapping - ) - - shape: Shape | None = game_state.shapes[i] - # Use the passed selected_shape_idx for highlighting - is_selected = selected_shape_idx == i - - # Determine border style based on selection - border_width = selected_border if is_selected else border - border_color = ( - colors.PREVIEW_SELECTED_BORDER if is_selected else colors.PREVIEW_BORDER - ) - # Draw the border rectangle onto the local preview surface - pygame.draw.rect(surface, border_color, slot_rect_local, border_width) - - # Draw the shape if it exists - if shape: - # Calculate drawing area inside the border and padding - draw_area_w = slot_w - 2 * (border_width + inner_pad) - draw_area_h = slot_h - 2 * (border_width + inner_pad) - - if draw_area_w > 0 and draw_area_h > 0: - # Calculate shape bounding box and required cell size - min_r, min_c, max_r, max_c = shape.bbox() - shape_rows = max_r - min_r + 1 - # Effective width considering triangle geometry (0.75 factor) - 
shape_cols_eff = ( - (max_c - min_c + 1) * 0.75 + 0.25 if shape.triangles else 1 - ) - - # Determine cell size based on available space and shape dimensions - scale_w = ( - draw_area_w / shape_cols_eff if shape_cols_eff > 0 else draw_area_w - ) - scale_h = draw_area_h / shape_rows if shape_rows > 0 else draw_area_h - cell_size = max(1.0, min(scale_w, scale_h)) # Use the smaller scale - - # Calculate centered top-left position for drawing the shape - shape_render_w = shape_cols_eff * cell_size - shape_render_h = shape_rows * cell_size - draw_topleft_x = ( - slot_rect_local.left - + border_width - + inner_pad - + (draw_area_w - shape_render_w) / 2 - ) - draw_topleft_y = ( - slot_rect_local.top - + border_width - + inner_pad - + (draw_area_h - shape_render_h) / 2 - ) - - # Draw the shape onto the local preview surface - # Cast float coordinates to int for draw_shape - # Use _is_selected to match the function signature - draw_shape( - surface, - shape, - (int(draw_topleft_x), int(draw_topleft_y)), - cell_size, - _is_selected=is_selected, - origin_offset=( - -min_r, - -min_c, - ), # Adjust drawing origin based on bbox - ) - - # Move to the next slot position - current_y += slot_h + pad - - return preview_rects_screen - - -def draw_placement_preview( - surface: pygame.Surface, - shape: "Shape", - r: int, - c: int, - is_valid: bool, - config: "EnvConfig", -) -> None: - """Draws a semi-transparent shape snapped to the grid.""" - if not shape or not shape.triangles: - return - - cw, ch, ox, oy = coord_mapper._calculate_render_params( - surface.get_width(), surface.get_height(), config - ) - if cw <= 0 or ch <= 0: - return - - # Use valid/invalid colors (could be passed in or defined here) - base_color = ( - colors.PLACEMENT_VALID_COLOR[:3] - if is_valid - else colors.PLACEMENT_INVALID_COLOR[:3] - ) - alpha = ( - colors.PLACEMENT_VALID_COLOR[3] - if is_valid - else colors.PLACEMENT_INVALID_COLOR[3] - ) - color = list(base_color) + [alpha] # Combine RGB and Alpha - - # Use a 
temporary surface for transparency - temp_surface = pygame.Surface(surface.get_size(), pygame.SRCALPHA) - temp_surface.fill((0, 0, 0, 0)) # Fully transparent background - - for dr, dc, is_up in shape.triangles: - tri_r, tri_c = r + dr, c + dc - # Create a temporary Triangle to get points easily - temp_tri = Triangle(tri_r, tri_c, is_up) - pts = temp_tri.get_points(ox, oy, cw, ch) - pygame.draw.polygon(temp_surface, color, pts) - - # Blit the transparent preview onto the main grid surface - surface.blit(temp_surface, (0, 0)) - - -def draw_floating_preview( - surface: pygame.Surface, - shape: "Shape", - screen_pos: tuple[int, int], # Position relative to the surface being drawn on - _config: "EnvConfig", # Mark config as unused -) -> None: - """Draws a semi-transparent shape floating at the screen position.""" - if not shape or not shape.triangles: - return - - cell_size = 20.0 # Fixed size for floating preview? Or scale based on config? - color = list(shape.color) + [100] # Base color with fixed alpha - - # Use a temporary surface for transparency - temp_surface = pygame.Surface(surface.get_size(), pygame.SRCALPHA) - temp_surface.fill((0, 0, 0, 0)) - - # Center the shape around the screen_pos - min_r, min_c, max_r, max_c = shape.bbox() - center_r = (min_r + max_r) / 2.0 - center_c = (min_c + max_c) / 2.0 - - for dr, dc, is_up in shape.triangles: - # Calculate position relative to shape center and screen_pos - pt_x = screen_pos[0] + (dc - center_c) * (cell_size * 0.75) - pt_y = screen_pos[1] + (dr - center_r) * cell_size - - # Create a temporary Triangle at origin to get relative points - temp_tri = Triangle(0, 0, is_up) - # Get points relative to 0,0 and scale - rel_pts = temp_tri.get_points(0, 0, cell_size, cell_size) - # Translate points to the calculated screen position - pts = [(px + pt_x, py + pt_y) for px, py in rel_pts] - pygame.draw.polygon(temp_surface, color, pts) - - # Blit the transparent preview onto the target surface - surface.blit(temp_surface, (0, 
0)) diff --git a/alphatriangle/visualization/drawing/shapes.py b/alphatriangle/visualization/drawing/shapes.py deleted file mode 100644 index 2318deb..0000000 --- a/alphatriangle/visualization/drawing/shapes.py +++ /dev/null @@ -1,35 +0,0 @@ -import pygame - -from ...structs import Shape, Triangle -from ..core import colors - - -def draw_shape( - surface: pygame.Surface, - shape: Shape, - topleft: tuple[int, int], - cell_size: float, - _is_selected: bool = False, - origin_offset: tuple[int, int] = (0, 0), -) -> None: - """Draws a single shape onto a surface.""" - if not shape or not shape.triangles or cell_size <= 0: - return - - shape_color = shape.color - border_color = colors.GRAY - - cw = cell_size - ch = cell_size - - for dr, dc, is_up in shape.triangles: - adj_r, adj_c = dr + origin_offset[0], dc + origin_offset[1] - - tri_x = topleft[0] + adj_c * (cw * 0.75) - tri_y = topleft[1] + adj_r * ch - - temp_tri = Triangle(0, 0, is_up) - pts = [(px + tri_x, py + tri_y) for px, py in temp_tri.get_points(0, 0, cw, ch)] - - pygame.draw.polygon(surface, shape_color, pts) - pygame.draw.polygon(surface, border_color, pts, 1) diff --git a/alphatriangle/visualization/drawing/utils.py b/alphatriangle/visualization/drawing/utils.py deleted file mode 100644 index e69de29..0000000 diff --git a/alphatriangle/visualization/ui/__init__.py b/alphatriangle/visualization/ui/__init__.py deleted file mode 100644 index ef6262a..0000000 --- a/alphatriangle/visualization/ui/__init__.py +++ /dev/null @@ -1,10 +0,0 @@ -""" -UI Components subpackage for visualization. -Contains reusable UI elements like progress bars, buttons, etc. 
-""" - -from .progress_bar import ProgressBar - -__all__ = [ - "ProgressBar", -] diff --git a/alphatriangle/visualization/ui/progress_bar.py b/alphatriangle/visualization/ui/progress_bar.py deleted file mode 100644 index 0ac3473..0000000 --- a/alphatriangle/visualization/ui/progress_bar.py +++ /dev/null @@ -1,260 +0,0 @@ -# File: alphatriangle/visualization/ui/progress_bar.py -# Changes: -# - Modify render text logic: If info_line is provided, prepend default progress info. -# - Cast return type of _get_pulsing_color to satisfy mypy. - -import math -import random -import time -from typing import Any, cast # Added cast - -import pygame - -from ...utils import format_eta -from ..core import colors - - -class ProgressBar: - """ - A reusable progress bar component for visualization. - Handles overflow by cycling colors and displaying actual progress count. - Includes rounded corners and subtle pulsing effect. - Can display a custom info line alongside default progress text. - """ - - def __init__( - self, - entity_title: str, - total_steps: int, - start_time: float | None = None, - initial_steps: int = 0, - initial_color: tuple[int, int, int] = colors.BLUE, - ): - self.entity_title = entity_title - self._original_total_steps = max( - 1, total_steps if total_steps is not None else 1 - ) - self.initial_steps = max(0, initial_steps) - self.current_steps = self.initial_steps - self.start_time = start_time if start_time is not None else time.time() - self._last_step_time = self.start_time - self._step_times: list[float] = [] - self.extra_data: dict[str, Any] = {} - self._current_bar_color = initial_color - self._last_cycle = -1 - self._rng = random.Random() - self._pulse_phase = random.uniform(0, 2 * math.pi) - - def add_steps(self, steps_added: int): - """Adds steps to the progress bar's current count.""" - if steps_added <= 0: - return - self.current_steps += steps_added - self._check_color_cycle() - - def set_current_steps(self, steps: int): - """Directly sets the 
current step count.""" - self.current_steps = max(0, steps) - self._check_color_cycle() - - def _check_color_cycle(self): - """Updates the bar color if a new cycle is reached.""" - current_cycle = self.current_steps // self._original_total_steps - if current_cycle > self._last_cycle: - self._last_cycle = current_cycle - if current_cycle > 0: - available_colors = [ - c - for c in colors.PROGRESS_BAR_CYCLE_COLORS - if c != self._current_bar_color - ] - if not available_colors: - available_colors = colors.PROGRESS_BAR_CYCLE_COLORS - if available_colors: - self._current_bar_color = self._rng.choice(available_colors) - - def update_extra_data(self, data: dict[str, Any]): - """Updates or adds key-value pairs to display.""" - self.extra_data.update(data) - - def reset_time(self): - """Resets the start time to now, keeping current steps.""" - self.start_time = time.time() - self._last_step_time = self.start_time - self._step_times = [] - self.initial_steps = self.current_steps - - def reset_all(self, new_total_steps: int | None = None): - """Resets steps to 0 and start time to now. 
Optionally updates total steps.""" - self.current_steps = 0 - self.initial_steps = 0 - if new_total_steps is not None: - self._original_total_steps = max(1, new_total_steps) - self.start_time = time.time() - self._last_step_time = self.start_time - self._step_times = [] - self.extra_data = {} - self._last_cycle = -1 - self._current_bar_color = ( - colors.PROGRESS_BAR_CYCLE_COLORS[0] - if colors.PROGRESS_BAR_CYCLE_COLORS - else colors.BLUE - ) - - def get_progress_fraction(self) -> float: - """Returns progress within the current cycle as a fraction (0.0 to 1.0).""" - if self._original_total_steps <= 1: - return 1.0 - if self.current_steps == 0: - return 0.0 - progress_in_cycle = self.current_steps % self._original_total_steps - if progress_in_cycle == 0 and self.current_steps > 0: - return 1.0 - else: - return progress_in_cycle / self._original_total_steps - - def get_elapsed_time(self) -> float: - """Returns the time elapsed since the start time.""" - return time.time() - self.start_time - - def get_eta_seconds(self) -> float | None: - """Calculates the estimated time remaining in seconds.""" - if ( - self._original_total_steps <= 1 - or self.current_steps >= self._original_total_steps - ): - return None - steps_processed = self.current_steps - self.initial_steps - if steps_processed <= 0: - return None - elapsed = self.get_elapsed_time() - if elapsed < 1.0: - return None - speed = steps_processed / elapsed - if speed < 1e-6: - return None - remaining_steps = self._original_total_steps - self.current_steps - if remaining_steps <= 0: - return 0.0 - eta = remaining_steps / speed - return eta - - def _get_pulsing_color(self) -> tuple[int, int, int]: - """Applies a subtle brightness pulse to the current bar color.""" - base_color = self._current_bar_color - pulse_amplitude = 15 - brightness_offset = int( - pulse_amplitude * math.sin(time.time() * 4 + self._pulse_phase) - ) - # --- CHANGED: Explicitly cast to tuple[int, int, int] --- - pulsed_color = cast( - "tuple[int, 
int, int]", - tuple(max(0, min(255, c + brightness_offset)) for c in base_color), - ) - # --- END CHANGED --- - return pulsed_color - - def render( - self, - surface: pygame.Surface, - position: tuple[int, int], - width: int, - height: int, - font: pygame.font.Font, - bg_color: tuple[int, int, int] = colors.DARK_GRAY, - text_color: tuple[int, int, int] = colors.WHITE, - border_width: int = 1, - border_color: tuple[int, int, int] = colors.GRAY, - border_radius: int = 3, - info_line: str | None = None, # Keep optional info_line - ): - """Draws the progress bar onto the given surface.""" - x, y = position - progress_fraction = self.get_progress_fraction() - - # Background - bg_rect = pygame.Rect(x, y, width, height) - pygame.draw.rect(surface, bg_color, bg_rect, border_radius=border_radius) - - # Pulsing Bar Fill - fill_width = int(width * progress_fraction) - if fill_width > 0: - fill_width = min(width, fill_width) - fill_rect = pygame.Rect(x, y, fill_width, height) - pulsing_bar_color = self._get_pulsing_color() - pygame.draw.rect( - surface, pulsing_bar_color, fill_rect, border_radius=border_radius - ) - - # Border - if border_width > 0: - pygame.draw.rect( - surface, - border_color, - bg_rect, - border_width, - border_radius=border_radius, - ) - - # --- Text Rendering (Modified) --- - line_height = font.get_height() - if height >= line_height + 4: - # Always generate default progress text parts - elapsed_time_str = format_eta(self.get_elapsed_time()) - eta_seconds = self.get_eta_seconds() - eta_str = format_eta(eta_seconds) if eta_seconds is not None else "--" - processed_steps = self.current_steps - expected_steps = self._original_total_steps - processed_str = ( - f"{processed_steps / 1e6:.1f}M" - if processed_steps >= 1e6 - else ( - f"{processed_steps / 1e3:.0f}k" - if processed_steps >= 1000 - else f"{processed_steps:,}" - ) - ) - expected_str = ( - f"{expected_steps / 1e6:.1f}M" - if expected_steps >= 1e6 - else ( - f"{expected_steps / 1e3:.0f}k" - if 
expected_steps >= 1000 - else f"{expected_steps:,}" - ) - ) - progress_text = f"{processed_str}/{expected_str}" - if self._original_total_steps <= 1: - progress_text = f"{processed_str}" - extra_text = "" - if self.extra_data: - extra_items = [f"{k}:{v}" for k, v in self.extra_data.items()] - extra_text = " | " + " ".join(extra_items) - - # Construct the display text - default_progress_info = f"{self.entity_title}: {progress_text} (T:{elapsed_time_str} ETA:{eta_str}){extra_text}" - - # --- CHANGED: Prepend default info if info_line is given --- - if info_line is not None: - display_text = ( - f"{default_progress_info} || {info_line}" # Combine with separator - ) - else: - display_text = default_progress_info # Use only default - # --- END CHANGED --- - - # Truncate if necessary - max_text_width = width - 10 - if font.size(display_text)[0] > max_text_width: - while ( - font.size(display_text + "...")[0] > max_text_width - and len(display_text) > 20 - ): - display_text = display_text[:-1] - display_text += "..." 
- - # Render and blit centered vertically - text_surf = font.render(display_text, True, text_color) - text_rect = text_surf.get_rect(center=(x + width // 2, y + height // 2)) - surface.blit(text_surf, text_rect) - # --- End Text Rendering --- diff --git a/alphatriangle/visualization/utils.py b/alphatriangle/visualization/utils.py deleted file mode 100644 index 8a96b1f..0000000 --- a/alphatriangle/visualization/utils.py +++ /dev/null @@ -1,19 +0,0 @@ -import logging -from typing import cast - -logger = logging.getLogger(__name__) - - -def normalize_color_for_matplotlib( - color_tuple_0_255: tuple[int, int, int], -) -> tuple[float, float, float]: - """Converts RGB tuple (0-255) to Matplotlib format (0.0-1.0).""" - if isinstance(color_tuple_0_255, tuple) and len(color_tuple_0_255) == 3: - # Ensure values are within 0-255 before dividing - valid_color = tuple(max(0, min(255, c)) for c in color_tuple_0_255) - # Cast the result to the expected precise tuple type - return cast("tuple[float, float, float]", tuple(c / 255.0 for c in valid_color)) - logger.warning( - f"Invalid color format for normalization: {color_tuple_0_255}, returning black." - ) - return (0.0, 0.0, 0.0) diff --git a/pyproject.toml b/pyproject.toml index c4bf307..2549a41 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,12 +1,14 @@ +# File: pyproject.toml +# File: alphatriangle/pyproject.toml [build-system] requires = ["setuptools>=61.0"] build-backend = "setuptools.build_meta" [project] name = "alphatriangle" -version = "0.4.0" +version = "0.7.0" # Incremented version for removing visualization authors = [{ name="Luis Guilherme P. M.", email="lgpelin92@gmail.com" }] -description = "AlphaZero implementation for a triangle puzzle game." +description = "AlphaZero implementation for a triangle puzzle game (uses trianglengin)." 
# Kept description readme = "README.md" license = { file="LICENSE" } requires-python = ">=3.10" @@ -18,21 +20,23 @@ classifiers = [ "Operating System :: OS Independent", "Topic :: Scientific/Engineering :: Artificial Intelligence", "Topic :: Games/Entertainment :: Puzzle Games", - "Development Status :: 3 - Alpha", + "Development Status :: 4 - Beta", # Status remains Beta ] dependencies = [ - "pygame>=2.1.0", + # --- Core Dependency --- + "trianglengin>=1.0.6", # Depend on the engine library + # --- RL/ML specific dependencies --- "numpy>=1.20.0", "torch>=2.0.0", "torchvision>=0.11.0", "cloudpickle>=2.0.0", - "numba>=0.55.0", + "numba>=0.55.0", # Used in features "mlflow>=1.20.0", - "matplotlib>=3.5.0", "ray>=2.8.0", "pydantic>=2.0.0", "typing_extensions>=4.0.0", - "typer[all]>=0.9.0", # Added typer for CLI + "typer[all]>=0.9.0", # Still needed for the train CLI + "tensorboard>=2.10.0", ] [project.urls] @@ -40,24 +44,40 @@ dependencies = [ "Bug Tracker" = "https://github.com/lguibr/alphatriangle/issues" [project.scripts] +# Keep only the training script entry point alphatriangle = "alphatriangle.cli:app" +[project.optional-dependencies] +dev = [ + "pytest>=7.0.0", + "pytest-cov>=3.0.0", + "pytest-mock>=3.0.0", + "ruff", + "mypy", + "build", + "twine", + "codecov", + # REMOVE pygame and matplotlib from dev dependencies + # "pygame>=2.1.0", + # "matplotlib>=3.5.0", +] + + [tool.setuptools.packages.find] # No 'where' needed, find searches from the project root by default # It will find the 'alphatriangle' directory now. 
- [tool.setuptools.package-data] "*" = ["*.txt", "*.md", "*.json"] # Include non-code files -# --- Tool Configurations (Optional but Recommended) --- +# --- Tool Configurations (ruff, mypy, pytest, coverage) --- [tool.ruff] line-length = 88 [tool.ruff.lint] select = ["E", "W", "F", "I", "UP", "B", "C4", "ARG", "SIM", "TCH", "PTH", "NPY"] -ignore = ["E501"] # Ignore line length errors if needed selectively +ignore = ["E501"] [tool.ruff.format] quote-style = "double" @@ -66,36 +86,33 @@ quote-style = "double" python_version = "3.10" warn_return_any = true warn_unused_configs = true -ignore_missing_imports = true # Start with true, gradually reduce -# Add specific module ignores if necessary -# [[tool.mypy.overrides]] -# module = "some_missing_types_module.*" -# ignore_missing_imports = true +ignore_missing_imports = true # Keep true, especially with the new dependency [tool.pytest.ini_options] minversion = "7.0" -addopts = "-ra -q --cov=alphatriangle --cov-report=term-missing" # Point coverage to the new package dir +addopts = "-ra -q --cov=alphatriangle --cov-report=term-missing" testpaths = [ - "tests", + "tests", # Keep testing alphatriangle specific code ] [tool.coverage.run] omit = [ - "alphatriangle/cli.py", # Exclude CLI from coverage for now - "alphatriangle/visualization/*", # Exclude visualization for now - "alphatriangle/app.py", - "run_*.py", - "alphatriangle/training/logging_utils.py", # Logging utils can be hard to cover fully - "alphatriangle/config/*", # Config models are mostly declarative - "alphatriangle/data/schemas.py", - "alphatriangle/structs/*", - "alphatriangle/utils/types.py", - "alphatriangle/mcts/core/types.py", - "alphatriangle/rl/types.py", + # Keep omissions for CLI, logging, config, types, etc. 
+ "alphatriangle/cli.py", + # REMOVE visualization, interaction, app, structs, environment omissions + "alphatriangle/training/logging_utils.py", + "alphatriangle/config/*", # Keep omitting config + "alphatriangle/data/schemas.py", # Keep omitting schemas + "alphatriangle/utils/types.py", # Keep omitting utils/types + "alphatriangle/mcts/core/types.py", # Keep omitting mcts/types + "alphatriangle/rl/types.py", # Keep omitting rl/types + "alphatriangle/stats/collector.py", # StatsCollectorActor is hard to test fully without Ray cluster + # REMOVE visual_state_actor omission "*/__init__.py", "*/README.md", + "run_*.py", # Keep omitting legacy run scripts (these should be deleted) ] [tool.coverage.report] -fail_under = 28 # Set a reasonable initial coverage target +fail_under = 25 # Adjust target based on remaining code show_missing = true \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index f716fe4..6c27163 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,15 +1,13 @@ -pygame>=2.1.0 +# File: requirements.txt +trianglengin>=1.0.6 numpy>=1.20.0 torch>=2.0.0 torchvision>=0.11.0 cloudpickle>=2.0.0 numba>=0.55.0 mlflow>=1.20.0 -matplotlib>=3.5.0 ray>=2.8.0 pydantic>=2.0.0 typing_extensions>=4.0.0 -pytest>=7.0.0 -pytest-mock>=3.0.0 typer[all]>=0.9.0 -pytest-cov>=3.0.0 \ No newline at end of file +tensorboard>=2.10.0 \ No newline at end of file diff --git a/tests/conftest.py b/tests/conftest.py index 6112a7c..aefdb6b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -6,52 +6,39 @@ import torch import torch.optim as optim -# Use relative imports for alphatriangle components if running tests from project root -# or absolute imports if package is installed -try: - # Try absolute imports first (for installed package) - from alphatriangle.config import EnvConfig, MCTSConfig, ModelConfig, TrainConfig - from alphatriangle.nn import NeuralNetwork - from alphatriangle.rl import ExperienceBuffer, Trainer - from alphatriangle.utils.types 
import Experience, StateType -except ImportError: - # Fallback to relative imports (for running tests directly) - from alphatriangle.config import EnvConfig, MCTSConfig, ModelConfig, TrainConfig - from alphatriangle.nn import NeuralNetwork - from alphatriangle.rl import ExperienceBuffer, Trainer - from alphatriangle.utils.types import Experience, StateType - - -# Use default NumPy random number generator +# Import EnvConfig from trianglengin +from trianglengin.config import EnvConfig + +# Keep alphatriangle imports +from alphatriangle.config import MCTSConfig, ModelConfig, TrainConfig +from alphatriangle.nn import NeuralNetwork +from alphatriangle.rl import ExperienceBuffer, Trainer +from alphatriangle.utils.types import Experience, StateType + rng = np.random.default_rng() @pytest.fixture(scope="session") def mock_env_config() -> EnvConfig: - """Provides a default, *valid* EnvConfig for tests (session-scoped).""" - # Use a smaller, fully playable grid for easier testing of placement logic + """Provides a default, *valid* trianglengin.EnvConfig for tests.""" + # Use a smaller, fully playable grid for easier testing rows = 3 cols = 3 - cols_per_row = [cols] * rows - # Pydantic models with defaults can be instantiated without args - # if all required fields have defaults or are computed. - # Let's provide them explicitly for clarity in tests. + # Define playable range for the 3x3 grid + playable_range = [(0, 3), (0, 3), (0, 3)] + # Instantiate trianglengin.EnvConfig return EnvConfig( ROWS=rows, COLS=cols, - COLS_PER_ROW=cols_per_row, + PLAYABLE_RANGE_PER_ROW=playable_range, NUM_SHAPE_SLOTS=1, - MIN_LINE_LENGTH=3, ) +# ... (mock_model_config, mock_train_config, mock_mcts_config remain the same) ... 
@pytest.fixture(scope="session") def mock_model_config(mock_env_config: EnvConfig) -> ModelConfig: """Provides a default ModelConfig compatible with mock_env_config (session-scoped).""" - # Simple CNN config for testing - # Pydantic models with defaults can be instantiated without args - # if all required fields have defaults or are computed. - # Let's provide them explicitly for clarity in tests. action_dim_int = int(mock_env_config.ACTION_DIM) return ModelConfig( GRID_INPUT_CHANNELS=1, @@ -67,9 +54,9 @@ def mock_model_config(mock_env_config: EnvConfig) -> ModelConfig: TRANSFORMER_LAYERS=0, TRANSFORMER_FC_DIM=32, FC_DIMS_SHARED=[8], - POLICY_HEAD_DIMS=[action_dim_int], # Use casted int + POLICY_HEAD_DIMS=[action_dim_int], VALUE_HEAD_DIMS=[1], - OTHER_NN_INPUT_FEATURES_DIM=10, + OTHER_NN_INPUT_FEATURES_DIM=10, # Keep this consistent with mock_state_type ACTIVATION_FUNCTION="ReLU", USE_BATCH_NORM=True, ) @@ -78,8 +65,6 @@ def mock_model_config(mock_env_config: EnvConfig) -> ModelConfig: @pytest.fixture(scope="session") def mock_train_config() -> TrainConfig: """Provides a default TrainConfig for tests (session-scoped).""" - # Pydantic models with defaults can be instantiated without args - # if all required fields have defaults or are computed. 
return TrainConfig( BATCH_SIZE=4, BUFFER_CAPACITY=100, @@ -113,7 +98,6 @@ def mock_train_config() -> TrainConfig: @pytest.fixture(scope="session") def mock_mcts_config() -> MCTSConfig: """Provides a default MCTSConfig for tests (session-scoped).""" - # Pydantic models with defaults can be instantiated without args return MCTSConfig( num_simulations=10, puct_coefficient=1.5, @@ -126,10 +110,7 @@ def mock_mcts_config() -> MCTSConfig: ) -# --- Fixtures Moved from tests/mcts/conftest.py --- - - -@pytest.fixture(scope="session") # Make session-scoped if appropriate +@pytest.fixture(scope="session") def mock_state_type( mock_model_config: ModelConfig, mock_env_config: EnvConfig ) -> StateType: @@ -139,6 +120,7 @@ def mock_state_type( mock_env_config.ROWS, mock_env_config.COLS, ) + # Ensure OTHER_NN_INPUT_FEATURES_DIM matches mock_model_config other_shape = (mock_model_config.OTHER_NN_INPUT_FEATURES_DIM,) return { "grid": rng.random(grid_shape, dtype=np.float32), @@ -146,12 +128,11 @@ def mock_state_type( } -@pytest.fixture(scope="session") # Make session-scoped if appropriate +@pytest.fixture(scope="session") def mock_experience( mock_state_type: StateType, mock_env_config: EnvConfig ) -> Experience: """Creates a mock Experience tuple.""" - # Cast ACTION_DIM to int action_dim = int(mock_env_config.ACTION_DIM) policy_target = ( dict.fromkeys(range(action_dim), 1.0 / action_dim) @@ -159,68 +140,60 @@ def mock_experience( else {0: 1.0} ) value_target = random.uniform(-1, 1) - # Cast StateType to Any temporarily to satisfy Experience type hint if needed - # (though StateType should match the TypedDict definition) return (mock_state_type, policy_target, value_target) -@pytest.fixture(scope="session") # Make session-scoped if appropriate +@pytest.fixture(scope="session") def mock_nn_interface( mock_model_config: ModelConfig, - mock_env_config: EnvConfig, + mock_env_config: EnvConfig, # Uses trianglengin.EnvConfig mock_train_config: TrainConfig, ) -> NeuralNetwork: """Provides 
a NeuralNetwork instance with a mock model for testing.""" - device = torch.device("cpu") # Use CPU for testing + device = torch.device("cpu") + # Pass trianglengin.EnvConfig here nn_interface = NeuralNetwork( mock_model_config, mock_env_config, mock_train_config, device ) - # Optionally replace internal model with a simpler mock if needed, - # but using the actual AlphaTriangleNet with simple config is often better. return nn_interface -@pytest.fixture(scope="session") # Make session-scoped if appropriate +@pytest.fixture(scope="session") def mock_trainer( mock_nn_interface: NeuralNetwork, mock_train_config: TrainConfig, - mock_env_config: EnvConfig, + mock_env_config: EnvConfig, # Uses trianglengin.EnvConfig ) -> Trainer: """Provides a Trainer instance.""" + # Pass trianglengin.EnvConfig here return Trainer(mock_nn_interface, mock_train_config, mock_env_config) -@pytest.fixture(scope="session") # Make session-scoped if appropriate +@pytest.fixture(scope="session") def mock_optimizer(mock_trainer: Trainer) -> optim.Optimizer: """Provides the optimizer from the mock_trainer.""" - # --- CHANGE: Explicitly cast return type --- return cast("optim.Optimizer", mock_trainer.optimizer) - # --- END CHANGE --- -@pytest.fixture # Buffer should likely be function-scoped unless state doesn't matter +@pytest.fixture def mock_experience_buffer(mock_train_config: TrainConfig) -> ExperienceBuffer: """Provides an ExperienceBuffer instance.""" return ExperienceBuffer(mock_train_config) -@pytest.fixture # Buffer should likely be function-scoped unless state doesn't matter +@pytest.fixture def filled_mock_buffer( mock_experience_buffer: ExperienceBuffer, mock_experience: Experience ) -> ExperienceBuffer: """Provides a buffer filled with some mock experiences.""" for _ in range(mock_experience_buffer.min_size_to_train + 5): - # Create slightly different experiences - # Correctly copy StateType dict and its NumPy arrays state_copy: StateType = { "grid": 
mock_experience[0]["grid"].copy(), "other_features": mock_experience[0]["other_features"].copy(), } - # Ensure grid is writeable before modifying (copy() already does this) state_copy["grid"] += ( rng.standard_normal(state_copy["grid"].shape, dtype=np.float32) * 0.1 ) - # Create the new experience tuple exp_copy: Experience = (state_copy, mock_experience[1], random.uniform(-1, 1)) mock_experience_buffer.add(exp_copy) return mock_experience_buffer diff --git a/tests/environment/__init__.py b/tests/environment/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/tests/environment/conftest.py b/tests/environment/conftest.py deleted file mode 100644 index fa55f5b..0000000 --- a/tests/environment/conftest.py +++ /dev/null @@ -1,66 +0,0 @@ -import pytest - -# Use relative imports for alphatriangle components if running tests from project root -# or absolute imports if package is installed -try: - # Try absolute imports first (for installed package) - from alphatriangle.config import EnvConfig - from alphatriangle.environment import GameState - from alphatriangle.structs import Shape -except ImportError: - # Fallback to relative imports (for running tests directly) - from alphatriangle.config import EnvConfig - from alphatriangle.environment import GameState - from alphatriangle.structs import Shape - - -# Use session-scoped config from top-level conftest -@pytest.fixture(scope="session") -def default_env_config() -> EnvConfig: - """Provides the default EnvConfig used in the specification (session-scoped).""" - # Pydantic models with defaults can be instantiated without args - return EnvConfig() - - -@pytest.fixture -def game_state(default_env_config: EnvConfig) -> GameState: - """Provides a fresh GameState instance for testing.""" - # Use a fixed seed for reproducibility within tests if needed - return GameState(config=default_env_config, initial_seed=123) - - -@pytest.fixture -def game_state_with_fixed_shapes(default_env_config: EnvConfig) -> 
GameState: - """ - Provides a game state with predictable initial shapes. - Uses a modified EnvConfig with NUM_SHAPE_SLOTS=3 for this specific fixture. - """ - # Create a specific config for this fixture - config_3_slots = default_env_config.model_copy(update={"NUM_SHAPE_SLOTS": 3}) - gs = GameState(config=config_3_slots, initial_seed=456) - - # Override the random shapes with fixed ones for testing placement/refill - fixed_shapes = [ - Shape([(0, 0, False)], (255, 0, 0)), # Single down (matches grid at 0,0) - Shape([(0, 0, True)], (0, 255, 0)), # Single up (matches grid at 0,1) - Shape( - [(0, 0, False), (0, 1, True)], (0, 0, 255) - ), # Domino (matches grid at 0,0 and 0,1) - ] - # This fixture now guarantees NUM_SHAPE_SLOTS is 3 - assert len(fixed_shapes) == gs.env_config.NUM_SHAPE_SLOTS - - for i in range(len(fixed_shapes)): - gs.shapes[i] = fixed_shapes[i] - return gs - - -@pytest.fixture -def simple_shape() -> Shape: - """Provides a simple 3-triangle shape (Down, Up, Down).""" - # Example: L-shape (Down at 0,0; Up at 1,0; Down at 1,1 relative) - # Grid at (r,c) is Down if r+c is even, Up if odd. - # (0,0) is Down. (1,0) is Up. (1,1) is Down. This shape matches grid orientation. 
- triangles = [(0, 0, False), (1, 0, True), (1, 1, False)] - color = (255, 0, 0) - return Shape(triangles, color) diff --git a/tests/environment/test_actions.py b/tests/environment/test_actions.py deleted file mode 100644 index ff4f1c1..0000000 --- a/tests/environment/test_actions.py +++ /dev/null @@ -1,121 +0,0 @@ -# File: tests/environment/test_actions.py -import pytest - -from alphatriangle.config import EnvConfig -from alphatriangle.environment.core.action_codec import decode_action -from alphatriangle.environment.core.game_state import GameState - -# Import GridLogic correctly -from alphatriangle.environment.grid import logic as GridLogic -from alphatriangle.environment.logic import actions as ActionLogic -from alphatriangle.structs import Shape - -# Fixtures are now implicitly injected from tests/environment/conftest.py - - -@pytest.fixture -def game_state_almost_full(default_env_config: EnvConfig) -> GameState: - """ - Provides a game state where only a few placements are possible. - Grid is filled completely, then specific spots are made empty. 
- """ - # Use a fresh GameState to avoid side effects from other tests using the shared 'game_state' fixture - gs = GameState(config=default_env_config, initial_seed=987) - # Fill the entire playable grid first using NumPy arrays - playable_mask = ~gs.grid_data._death_np - gs.grid_data._occupied_np[playable_mask] = True - - # Explicitly make specific spots empty: (0,4) [Down] and (0,5) [Up] - empty_spots = [(0, 4), (0, 5)] - for r_empty, c_empty in empty_spots: - if gs.grid_data.valid(r_empty, c_empty): - gs.grid_data._occupied_np[r_empty, c_empty] = False - # Reset color ID as well - gs.grid_data._color_id_np[ - r_empty, c_empty - ] = -1 # Assuming -1 is NO_COLOR_ID - return gs - - -# --- Test Action Logic --- -def test_get_valid_actions_initial(game_state: GameState): - """Test valid actions in an initial empty state.""" - valid_actions = ActionLogic.get_valid_actions(game_state) - assert isinstance(valid_actions, list) - assert len(valid_actions) > 0 # Should be many valid actions initially - - # Check if decoded actions are valid placements - for action_index in valid_actions[:10]: # Check a sample - shape_idx, r, c = decode_action(action_index, game_state.env_config) - shape = game_state.shapes[shape_idx] - assert shape is not None - assert GridLogic.can_place(game_state.grid_data, shape, r, c) - - -def test_get_valid_actions_almost_full(game_state_almost_full: GameState): - """Test valid actions in a nearly full state with only (0,4) and (0,5) free.""" - gs = game_state_almost_full - # Ensure cells (0,4) and (0,5) are indeed empty using NumPy array - assert not gs.grid_data._occupied_np[0, 4], "Cell (0,4) should be empty" - assert not gs.grid_data._occupied_np[0, 5], "Cell (0,5) should be empty" - # Check orientation (calculated) - Apply Ruff suggestion - assert (0 + 4) % 2 == 0, "Cell (0,4) should be Down" # Changed from not (... 
!= 0) - assert (0 + 5) % 2 != 0, "Cell (0,5) should be Up" - - # Single down triangle fits at (0,4) [which is Down] - gs.shapes[0] = Shape([(0, 0, False)], (255, 0, 0)) - # Single up triangle fits at (0,5) [which is Up] - gs.shapes[1] = Shape([(0, 0, True)], (0, 255, 0)) - # Make other slots empty or contain unfittable shapes - if len(gs.shapes) > 2: - gs.shapes[2] = Shape([(0, 0, False), (1, 0, False)], (0, 0, 255)) # Unfittable - if len(gs.shapes) > 3: - gs.shapes[3] = None - - valid_actions = ActionLogic.get_valid_actions(gs) - - # Expect fewer valid actions - assert isinstance(valid_actions, list) - # The setup should allow placing shape 0 at (0,4) and shape 1 at (0,5) - assert len(valid_actions) == 2, ( - f"Expected 2 valid actions, found {len(valid_actions)}. Actions: {valid_actions}" - ) - - expected_placements = {(0, 0, 4), (1, 0, 5)} # (shape_idx, r, c) - found_placements = set() - - # Check if decoded actions are valid placements in the few remaining spots - for action_index in valid_actions: - shape_idx, r, c = decode_action(action_index, gs.env_config) - shape = gs.shapes[shape_idx] - assert shape is not None, f"Shape at index {shape_idx} is None" - assert GridLogic.can_place(gs.grid_data, shape, r, c), ( - f"can_place returned False for action {action_index} -> shape_idx={shape_idx}, r={r}, c={c}" - ) - # Check if placement is in the expected empty area - is_expected_placement = (r == 0 and c == 4 and shape_idx == 0) or ( - r == 0 and c == 5 and shape_idx == 1 - ) - assert is_expected_placement, ( - f"Action {action_index} -> {(shape_idx, r, c)} is not one of the expected placements (0,0,4) or (1,0,5)" - ) - found_placements.add((shape_idx, r, c)) - - assert found_placements == expected_placements - - -def test_get_valid_actions_no_shapes(game_state: GameState): - """Test valid actions when no shapes are available.""" - game_state.shapes = [None] * game_state.env_config.NUM_SHAPE_SLOTS - valid_actions = ActionLogic.get_valid_actions(game_state) - 
assert valid_actions == [] - - -def test_get_valid_actions_no_space(game_state: GameState): - """Test valid actions when the grid is completely full (or no space for any shape).""" - # Fill the entire playable grid using NumPy arrays - playable_mask = ~game_state.grid_data._death_np - game_state.grid_data._occupied_np[playable_mask] = True - - valid_actions = ActionLogic.get_valid_actions(game_state) - assert valid_actions == [] diff --git a/tests/environment/test_grid_logic.py b/tests/environment/test_grid_logic.py deleted file mode 100644 index 4e8b05e..0000000 --- a/tests/environment/test_grid_logic.py +++ /dev/null @@ -1,147 +0,0 @@ -# File: tests/environment/test_grid_logic.py -import pytest - -# Use relative imports for alphatriangle components if running tests from project root -# or absolute imports if package is installed -try: - # Try absolute imports first (for installed package) - from alphatriangle.config import EnvConfig - from alphatriangle.environment.grid import GridData - from alphatriangle.environment.grid import logic as GridLogic - from alphatriangle.structs import Shape -except ImportError: - # Fallback to relative imports (for running tests directly) - from alphatriangle.config import EnvConfig - from alphatriangle.environment.grid import GridData - from alphatriangle.environment.grid import logic as GridLogic - from alphatriangle.structs import Shape - -# Use shared fixtures implicitly via pytest injection -# from .conftest import game_state, simple_shape # Import fixtures if needed - - -@pytest.fixture -def grid_data(default_env_config: EnvConfig) -> GridData: - """Provides a fresh GridData instance.""" - return GridData(config=default_env_config) - - -# --- Test can_place with NumPy GridData --- -def test_can_place_empty_grid(grid_data: GridData, simple_shape: Shape): - """Test placement on an empty grid.""" - # Place at (2,2). Grid(2,2) is Down (2+2=4, even). Shape(0,0) is Down. OK. - # Grid(3,2) is Up (3+2=5, odd). Shape(1,0) is Up. OK. 
- # Grid(3,3) is Down (3+3=6, even). Shape(1,1) is Down. OK. - assert GridLogic.can_place(grid_data, simple_shape, 2, 2) - - -def test_can_place_occupied(grid_data: GridData, simple_shape: Shape): - """Test placement fails if any target cell is occupied.""" - # Occupy one cell where the shape would go - target_r, target_c = 3, 2 - grid_data._occupied_np[target_r, target_c] = True - assert not GridLogic.can_place(grid_data, simple_shape, 2, 2) - - -# Remove unused simple_shape argument -def test_can_place_death_zone(grid_data: GridData): - """Test placement fails if any target cell is in a death zone.""" - # Find a death zone cell (e.g., top-left corner in default config) - death_r, death_c = 0, 0 - assert grid_data._death_np[death_r, death_c] - # Try placing a single triangle shape there - single_down_shape = Shape([(0, 0, False)], (255, 0, 0)) - assert not GridLogic.can_place(grid_data, single_down_shape, death_r, death_c) - - -def test_can_place_orientation_mismatch(grid_data: GridData): - """Test placement fails if triangle orientations don't match.""" - # Shape: Single UP triangle at its origin (0,0) - shape_up = Shape([(0, 0, True)], (0, 255, 0)) - # Target grid cell: (0,4), which is DOWN in default config (0+4=4, even) - target_r_down, target_c_down = 0, 4 - assert grid_data.valid(target_r_down, target_c_down) and not grid_data.is_death( - target_r_down, target_c_down - ) - assert not GridLogic.can_place(grid_data, shape_up, target_r_down, target_c_down) - - # Shape: Single DOWN triangle at its origin (0,0) - shape_down = Shape([(0, 0, False)], (255, 0, 0)) - # Target grid cell: (0,3), which is UP in default config (0+3=3, odd) - target_r_up, target_c_up = 0, 3 - assert grid_data.valid(target_r_up, target_c_up) and not grid_data.is_death( - target_r_up, target_c_up - ) - assert not GridLogic.can_place(grid_data, shape_down, target_r_up, target_c_up) - - # Test valid placement using playable coordinates - assert GridLogic.can_place(grid_data, shape_down, 0, 4) 
# Down on Down at (0,4) - assert GridLogic.can_place(grid_data, shape_up, 0, 3) # Up on Up at (0,3) - - -# --- Test check_and_clear_lines with NumPy GridData --- - - -def occupy_coords(grid_data: GridData, coords: set[tuple[int, int]]): - """Helper to occupy specific coordinates.""" - for r, c in coords: - if grid_data.valid(r, c) and not grid_data._death_np[r, c]: - grid_data._occupied_np[r, c] = True - - -def test_check_and_clear_lines_no_clear(grid_data: GridData): - """Test when newly occupied cells don't complete any lines.""" - newly_occupied = {(2, 2), (3, 2), (3, 3)} # Coords from simple_shape placement - occupy_coords(grid_data, newly_occupied) - lines_cleared, unique_cleared, cleared_lines_set = GridLogic.check_and_clear_lines( - grid_data, newly_occupied - ) - assert lines_cleared == 0 - assert not unique_cleared - assert not cleared_lines_set - # Check grid state unchanged (except for initial occupation) - assert grid_data._occupied_np[2, 2] - assert grid_data._occupied_np[3, 2] - assert grid_data._occupied_np[3, 3] - - -def test_check_and_clear_lines_single_line(grid_data: GridData): - """Test clearing a single horizontal line.""" - # Find a valid horizontal line from the precomputed set - # Example: Look for a line in row 1 (often has long lines) - expected_line_coords = None - for line_fs in grid_data.potential_lines: - coords = list(line_fs) - # Check if it's horizontal and in row 1 - if len(coords) >= grid_data.config.MIN_LINE_LENGTH and all( - r == 1 for r, c in coords - ): - expected_line_coords = frozenset(coords) - break - - assert expected_line_coords is not None, ( - "Could not find a suitable horizontal line in row 1 for testing" - ) - # line_len = len(expected_line_coords) # Removed unused variable - coords_list = list(expected_line_coords) - - # Occupy all but one cell in the line - occupy_coords(grid_data, set(coords_list[:-1])) - # Occupy the last cell - last_coord = coords_list[-1] - newly_occupied = {last_coord} - 
occupy_coords(grid_data, newly_occupied) - - lines_cleared, unique_cleared, cleared_lines_set = GridLogic.check_and_clear_lines( - grid_data, newly_occupied - ) - - assert lines_cleared == 1 - assert unique_cleared == set(expected_line_coords) # Expect set of coords - assert cleared_lines_set == { - expected_line_coords - } # Expect set of frozensets of coords - - # Verify the line is now empty in the NumPy array - for r, c in expected_line_coords: - assert not grid_data._occupied_np[r, c] diff --git a/tests/environment/test_shape_logic.py b/tests/environment/test_shape_logic.py deleted file mode 100644 index c640e2e..0000000 --- a/tests/environment/test_shape_logic.py +++ /dev/null @@ -1,123 +0,0 @@ -# File: tests/environment/test_shape_logic.py -import random - -import pytest - -from alphatriangle.environment import GameState -from alphatriangle.environment.shapes import logic as ShapeLogic -from alphatriangle.structs import Shape - -# Fixtures are now implicitly injected from tests/environment/conftest.py - - -@pytest.fixture -def fixed_rng() -> random.Random: - """Provides a Random instance with a fixed seed.""" - return random.Random(12345) - - -def test_generate_random_shape(fixed_rng: random.Random): - """Test generating a single random shape.""" - shape = ShapeLogic.generate_random_shape(fixed_rng) - assert isinstance(shape, Shape) - assert shape.triangles is not None - assert shape.color is not None - assert len(shape.triangles) > 0 - # Check connectivity (optional but good) - assert ShapeLogic.is_shape_connected(shape) - - -def test_generate_multiple_shapes(fixed_rng: random.Random): - """Test generating multiple shapes to ensure variety (or lack thereof with fixed seed).""" - shape1 = ShapeLogic.generate_random_shape(fixed_rng) - # Re-seed or use different rng instance if true randomness is needed per call - # For this test, using the same fixed_rng will likely produce the same shape again - shape2 = ShapeLogic.generate_random_shape(fixed_rng) - # --- 
REMOVED INCORRECT ASSERTION --- - # assert shape1 == shape2 # Expect same shape due to fixed seed - THIS IS INCORRECT - # --- END REMOVED --- - # Check that subsequent calls produce different results with the same RNG instance - assert shape1 != shape2, ( - "Two consecutive calls with the same RNG produced the exact same shape (template and color), which is highly unlikely." - ) - - # Use a different seed for variation - rng2 = random.Random(54321) - shape3 = ShapeLogic.generate_random_shape(rng2) - # Check that different RNGs produce different results (highly likely) - assert shape1 != shape3 or shape1.color != shape3.color - - -def test_refill_shape_slots_empty(game_state: GameState, fixed_rng: random.Random): - """Test refilling when all slots are initially empty.""" - game_state.shapes = [None] * game_state.env_config.NUM_SHAPE_SLOTS - ShapeLogic.refill_shape_slots(game_state, fixed_rng) - assert all(s is not None for s in game_state.shapes) - assert len(game_state.shapes) == game_state.env_config.NUM_SHAPE_SLOTS - - -def test_refill_shape_slots_partial(game_state: GameState, fixed_rng: random.Random): - """Test refilling when some slots are empty - SHOULD NOT REFILL.""" - num_slots = game_state.env_config.NUM_SHAPE_SLOTS - if num_slots < 2: - pytest.skip("Test requires at least 2 shape slots") - - # Start with full slots - ShapeLogic.refill_shape_slots(game_state, fixed_rng) - assert all(s is not None for s in game_state.shapes) - - # Empty one slot - game_state.shapes[0] = None - # Store original state (important: copy shapes if they are mutable) - original_shapes = [s.copy() if s else None for s in game_state.shapes] - - # Attempt refill - it should do nothing - ShapeLogic.refill_shape_slots(game_state, fixed_rng) - - # Check that shapes remain unchanged - assert game_state.shapes == original_shapes, "Refill happened unexpectedly" - - -def test_refill_shape_slots_full(game_state: GameState, fixed_rng: random.Random): - """Test refilling when all slots are 
already full - SHOULD NOT REFILL.""" - # Start with full slots - ShapeLogic.refill_shape_slots(game_state, fixed_rng) - assert all(s is not None for s in game_state.shapes) - original_shapes = [s.copy() if s else None for s in game_state.shapes] - - # Attempt refill - should do nothing - ShapeLogic.refill_shape_slots(game_state, fixed_rng) - - # Check shapes are unchanged - assert game_state.shapes == original_shapes, "Refill happened when slots were full" - - -def test_refill_shape_slots_batch_trigger(game_state: GameState) -> None: - """Test that refill only happens when ALL slots are empty.""" - num_slots = game_state.env_config.NUM_SHAPE_SLOTS - if num_slots < 2: - pytest.skip("Test requires at least 2 shape slots") - - # Fill all slots initially - ShapeLogic.refill_shape_slots(game_state, game_state._rng) - initial_shapes = [s.copy() if s else None for s in game_state.shapes] - assert all(s is not None for s in initial_shapes) - - # Empty one slot - refill should NOT happen - game_state.shapes[0] = None - shapes_after_one_empty = [s.copy() if s else None for s in game_state.shapes] - ShapeLogic.refill_shape_slots(game_state, game_state._rng) - assert game_state.shapes == shapes_after_one_empty, ( - "Refill happened when only one slot was empty" - ) - - # Empty all slots - refill SHOULD happen - game_state.shapes = [None] * num_slots - ShapeLogic.refill_shape_slots(game_state, game_state._rng) - assert all(s is not None for s in game_state.shapes), ( - "Refill did not happen when all slots were empty" - ) - # Check that the shapes are different from the initial ones (probabilistically) - assert game_state.shapes != initial_shapes, ( - "Shapes after refill are identical to initial shapes (unlikely)" - ) diff --git a/tests/environment/test_step.py b/tests/environment/test_step.py deleted file mode 100644 index d2bda7f..0000000 --- a/tests/environment/test_step.py +++ /dev/null @@ -1,353 +0,0 @@ -# File: tests/environment/test_step.py -import random -from time 
import sleep - -import pytest - -# Import mocker fixture from pytest-mock -from pytest_mock import MockerFixture - -from alphatriangle.config import EnvConfig -from alphatriangle.environment.core.game_state import GameState -from alphatriangle.environment.grid import ( - logic as GridLogic, -) -from alphatriangle.environment.grid.grid_data import GridData -from alphatriangle.environment.logic.step import calculate_reward, execute_placement -from alphatriangle.structs import Shape, Triangle - -# Fixtures are now implicitly injected from tests/environment/conftest.py - - -def occupy_line( - grid_data: GridData, line_indices: list[int], config: EnvConfig -) -> set[Triangle]: - """Helper to occupy triangles for a given line index list.""" - # occupied_tris: set[Triangle] = set() # Removed unused variable - for idx in line_indices: - r, c = divmod(idx, config.COLS) - # Combine nested if using 'and' - if grid_data.valid(r, c) and not grid_data.is_death(r, c): - grid_data._occupied_np[r, c] = True - # Cannot easily return Triangle objects anymore - # Return empty set as Triangle objects are not the primary state - return set() - - -def occupy_coords(grid_data: GridData, coords: set[tuple[int, int]]): - """Helper to occupy specific coordinates.""" - for r, c in coords: - if grid_data.valid(r, c) and not grid_data.is_death(r, c): - grid_data._occupied_np[r, c] = True - - -# --- New Reward Calculation Tests (v3) --- - - -def test_calculate_reward_v3_placement_only( - simple_shape: Shape, default_env_config: EnvConfig -): - """Test reward: only placement, game not over.""" - placed_count = len(simple_shape.triangles) - unique_coords_cleared: set[tuple[int, int]] = set() - is_game_over = False - reward = calculate_reward( - placed_count, unique_coords_cleared, is_game_over, default_env_config - ) - expected_reward = ( - placed_count * default_env_config.REWARD_PER_PLACED_TRIANGLE - + default_env_config.REWARD_PER_STEP_ALIVE - ) - assert reward == pytest.approx(expected_reward) 
- - -def test_calculate_reward_v3_single_line_clear( - simple_shape: Shape, default_env_config: EnvConfig -): - """Test reward: placement + line clear, game not over.""" - placed_count = len(simple_shape.triangles) - # Simulate a cleared line of 9 unique coordinates - unique_coords_cleared: set[tuple[int, int]] = {(0, i) for i in range(9)} - is_game_over = False - reward = calculate_reward( - placed_count, unique_coords_cleared, is_game_over, default_env_config - ) - expected_reward = ( - placed_count * default_env_config.REWARD_PER_PLACED_TRIANGLE - + len(unique_coords_cleared) * default_env_config.REWARD_PER_CLEARED_TRIANGLE - + default_env_config.REWARD_PER_STEP_ALIVE - ) - assert reward == pytest.approx(expected_reward) - - -def test_calculate_reward_v3_multi_line_clear( - simple_shape: Shape, default_env_config: EnvConfig -): - """Test reward: placement + multi-line clear (overlapping coords), game not over.""" - placed_count = len(simple_shape.triangles) - # Simulate two lines sharing coordinate (0,0) - line1_coords = {(0, i) for i in range(9)} - line2_coords = {(i, 0) for i in range(5)} - unique_coords_cleared = line1_coords.union(line2_coords) # Union handles uniqueness - is_game_over = False - reward = calculate_reward( - placed_count, unique_coords_cleared, is_game_over, default_env_config - ) - expected_reward = ( - placed_count * default_env_config.REWARD_PER_PLACED_TRIANGLE - + len(unique_coords_cleared) * default_env_config.REWARD_PER_CLEARED_TRIANGLE - + default_env_config.REWARD_PER_STEP_ALIVE - ) - assert reward == pytest.approx(expected_reward) - - -def test_calculate_reward_v3_game_over( - simple_shape: Shape, default_env_config: EnvConfig -): - """Test reward: placement, no line clear, game IS over.""" - placed_count = len(simple_shape.triangles) - unique_coords_cleared: set[tuple[int, int]] = set() - is_game_over = True - reward = calculate_reward( - placed_count, unique_coords_cleared, is_game_over, default_env_config - ) - expected_reward = ( 
- placed_count * default_env_config.REWARD_PER_PLACED_TRIANGLE - + default_env_config.PENALTY_GAME_OVER - ) - assert reward == pytest.approx(expected_reward) - - -def test_calculate_reward_v3_game_over_with_clear( - simple_shape: Shape, default_env_config: EnvConfig -): - """Test reward: placement + line clear, game IS over.""" - placed_count = len(simple_shape.triangles) - unique_coords_cleared: set[tuple[int, int]] = {(0, i) for i in range(9)} - is_game_over = True - reward = calculate_reward( - placed_count, unique_coords_cleared, is_game_over, default_env_config - ) - expected_reward = ( - placed_count * default_env_config.REWARD_PER_PLACED_TRIANGLE - + len(unique_coords_cleared) * default_env_config.REWARD_PER_CLEARED_TRIANGLE - + default_env_config.PENALTY_GAME_OVER - ) - assert reward == pytest.approx(expected_reward) - - -# --- Test execute_placement with new reward --- - - -def test_execute_placement_simple_no_refill_v3( - game_state_with_fixed_shapes: GameState, -): - """Test placing a shape without clearing lines, verify reward and NO immediate refill.""" - gs = game_state_with_fixed_shapes # Uses 3 slots, initially filled - config = gs.env_config - shape_idx = 0 - original_shape_in_slot_1 = gs.shapes[1] - original_shape_in_slot_2 = gs.shapes[2] - shape_to_place = gs.shapes[shape_idx] - assert shape_to_place is not None - placed_count = len(shape_to_place.triangles) - - r, c = 2, 2 - assert GridLogic.can_place(gs.grid_data, shape_to_place, r, c) - mock_rng = random.Random(42) - - reward = execute_placement(gs, shape_idx, r, c, mock_rng) - - # Verify reward (placement + survival) - expected_reward = ( - placed_count * config.REWARD_PER_PLACED_TRIANGLE + config.REWARD_PER_STEP_ALIVE - ) - assert reward == pytest.approx(expected_reward) - # Score is still tracked separately - assert gs.game_score == placed_count - - # Verify grid state using NumPy arrays - for dr, dc, _ in shape_to_place.triangles: - tri_r, tri_c = r + dr, c + dc - assert 
gs.grid_data._occupied_np[tri_r, tri_c] - # Cannot easily check color ID without map here, trust placement logic - - # Verify shape slot is now EMPTY - assert gs.shapes[shape_idx] is None - - # --- Verify NO REFILL --- - assert gs.shapes[1] is original_shape_in_slot_1 - assert gs.shapes[2] is original_shape_in_slot_2 - - assert gs.pieces_placed_this_episode == 1 - assert gs.triangles_cleared_this_episode == 0 - assert not gs.is_over() - - -def test_execute_placement_clear_line_no_refill_v3( - game_state_with_fixed_shapes: GameState, -): - """Test placing a shape that clears a line, verify reward and NO immediate refill.""" - gs = game_state_with_fixed_shapes - config = gs.env_config - shape_idx = 0 - shape_single_down = gs.shapes[shape_idx] - assert ( - shape_single_down is not None - and len(shape_single_down.triangles) == 1 - and not shape_single_down.triangles[0][2] - ) - placed_count = len(shape_single_down.triangles) - original_shape_in_slot_1 = gs.shapes[1] - original_shape_in_slot_2 = gs.shapes[2] - - # Pre-occupy line using coordinates - # Line indices [3..11] correspond to r=0, c=3 to c=11 - line_coords_to_occupy = {(0, i) for i in range(3, 12) if i != 4} - occupy_coords(gs.grid_data, line_coords_to_occupy) - cleared_line_coords = {(0, i) for i in range(3, 12)} # Coords (0,3) to (0,11) - - r, c = 0, 4 # Placement position - assert GridLogic.can_place(gs.grid_data, shape_single_down, r, c) - mock_rng = random.Random(42) - - reward = execute_placement(gs, shape_idx, r, c, mock_rng) - - # Verify reward (placement + line clear + survival) - expected_reward = ( - placed_count * config.REWARD_PER_PLACED_TRIANGLE - + len(cleared_line_coords) * config.REWARD_PER_CLEARED_TRIANGLE - + config.REWARD_PER_STEP_ALIVE - ) - assert reward == pytest.approx(expected_reward) - # Score is still tracked separately - assert gs.game_score == placed_count + len(cleared_line_coords) * 2 - - # Verify line is cleared using NumPy array - for row, col in cleared_line_coords: - assert 
not gs.grid_data._occupied_np[row, col] - - # Verify shape slot is now EMPTY - assert gs.shapes[shape_idx] is None - - # --- Verify NO REFILL --- - assert gs.shapes[1] is original_shape_in_slot_1 - assert gs.shapes[2] is original_shape_in_slot_2 - - assert gs.pieces_placed_this_episode == 1 - assert gs.triangles_cleared_this_episode == len(cleared_line_coords) - assert not gs.is_over() - - -def test_execute_placement_batch_refill_v3(game_state_with_fixed_shapes: GameState): - """Test that placing the last shape triggers a refill and correct reward.""" - gs = game_state_with_fixed_shapes - config = gs.env_config - mock_rng = random.Random(123) - - # Place first shape - shape_1_coords = (0, 4) - assert gs.shapes[0] is not None - placed_count_1 = len(gs.shapes[0].triangles) - assert GridLogic.can_place(gs.grid_data, gs.shapes[0], *shape_1_coords) - reward1 = execute_placement(gs, 0, 0, 4, mock_rng) - expected_reward1 = ( - placed_count_1 * config.REWARD_PER_PLACED_TRIANGLE - + config.REWARD_PER_STEP_ALIVE - ) - assert reward1 == pytest.approx(expected_reward1) - assert gs.shapes[0] is None - assert gs.shapes[1] is not None - assert gs.shapes[2] is not None - - # Place second shape - shape_2_coords = (0, 3) - assert gs.shapes[1] is not None - placed_count_2 = len(gs.shapes[1].triangles) - assert GridLogic.can_place(gs.grid_data, gs.shapes[1], *shape_2_coords) - reward2 = execute_placement(gs, 1, 0, 3, mock_rng) - expected_reward2 = ( - placed_count_2 * config.REWARD_PER_PLACED_TRIANGLE - + config.REWARD_PER_STEP_ALIVE - ) - assert reward2 == pytest.approx(expected_reward2) - assert gs.shapes[0] is None - assert gs.shapes[1] is None - assert gs.shapes[2] is not None - - # Place third shape (triggers refill) - shape_3_coords = (2, 2) - assert gs.shapes[2] is not None - placed_count_3 = len(gs.shapes[2].triangles) - assert GridLogic.can_place(gs.grid_data, gs.shapes[2], *shape_3_coords) - reward3 = execute_placement(gs, 2, 2, 2, mock_rng) - expected_reward3 = ( - 
placed_count_3 * config.REWARD_PER_PLACED_TRIANGLE - + config.REWARD_PER_STEP_ALIVE - ) # Game not over yet - assert reward3 == pytest.approx(expected_reward3) - sleep(0.01) # Allow time for refill to happen (though it should be synchronous) - - # --- Verify REFILL happened --- - assert all(s is not None for s in gs.shapes), "Not all slots were refilled" - assert gs.shapes[0] != Shape([(0, 0, False)], (255, 0, 0)) - assert gs.shapes[1] != Shape([(0, 0, True)], (0, 255, 0)) - assert gs.shapes[2] != Shape([(0, 0, False), (0, 1, True)], (0, 0, 255)) - - assert gs.pieces_placed_this_episode == 3 - assert not gs.is_over() - - -# Add mocker fixture to the test signature -def test_execute_placement_game_over_v3(game_state: GameState, mocker: MockerFixture): - """Test reward when placement leads to game over, mocking line clears.""" - config = game_state.env_config - # Fill grid almost completely using NumPy arrays - playable_mask = ~game_state.grid_data._death_np - game_state.grid_data._occupied_np[playable_mask] = True - - # Make one spot empty - empty_r, empty_c = 0, 4 - if not game_state.grid_data.is_death(empty_r, empty_c): # Ensure it's playable - game_state.grid_data._occupied_np[empty_r, empty_c] = False - - # Provide a shape that fits the empty spot - shape_to_place = Shape([(0, 0, False)], (255, 0, 0)) # Single down triangle - placed_count = len(shape_to_place.triangles) - - # --- Modify setup to prevent refill --- - unplaceable_shape = Shape([(0, 0, False), (1, 0, False), (2, 0, False)], (1, 1, 1)) - game_state.shapes = [None] * config.NUM_SHAPE_SLOTS - game_state.shapes[0] = shape_to_place - if config.NUM_SHAPE_SLOTS > 1: - game_state.shapes[1] = unplaceable_shape - # --- End modification --- - - assert GridLogic.can_place(game_state.grid_data, shape_to_place, empty_r, empty_c) - mock_rng = random.Random(999) - - # --- Mock check_and_clear_lines --- - # Patch the function within the logic module where execute_placement imports it from - mock_clear = 
mocker.patch( - "alphatriangle.environment.grid.logic.check_and_clear_lines", - return_value=(0, set(), set()), # Simulate no lines cleared - ) - # --- End Mock --- - - # Execute placement - this should fill the last spot and trigger game over - reward = execute_placement(game_state, 0, empty_r, empty_c, mock_rng) - - # Verify the mock was called (optional but good practice) - mock_clear.assert_called_once() - - # Verify game is over - assert game_state.is_over(), ( - "Game should be over after placing the final piece with no other valid moves" - ) - - # Verify reward (placement + game over penalty) - expected_reward = ( - placed_count * config.REWARD_PER_PLACED_TRIANGLE + config.PENALTY_GAME_OVER - ) - # Use a slightly larger tolerance if needed, but approx should work - assert reward == pytest.approx(expected_reward) diff --git a/tests/mcts/conftest.py b/tests/mcts/conftest.py index d22ce0c..0507d0d 100644 --- a/tests/mcts/conftest.py +++ b/tests/mcts/conftest.py @@ -1,28 +1,21 @@ +# File: tests/mcts/conftest.py import random from collections.abc import Mapping import numpy as np import pytest -# Use relative imports for alphatriangle components if running tests from project root -# or absolute imports if package is installed -try: - # Try absolute imports first (for installed package) - from alphatriangle.config import EnvConfig - from alphatriangle.mcts.core.node import Node - from alphatriangle.utils.types import ActionType, PolicyValueOutput -except ImportError: - # Fallback to relative imports (for running tests directly) - from alphatriangle.config import EnvConfig - from alphatriangle.mcts.core.node import Node - from alphatriangle.utils.types import ActionType, PolicyValueOutput - - -# Use default NumPy random number generator +# Import EnvConfig from trianglengin +from trianglengin.config import EnvConfig + +# Keep alphatriangle imports +from alphatriangle.mcts.core.node import Node +from alphatriangle.utils.types import ActionType, PolicyValueOutput + 
rng = np.random.default_rng() -# --- Mock GameState --- +# --- Mock GameState (using trianglengin.EnvConfig) --- class MockGameState: """A simplified mock GameState for testing MCTS logic.""" @@ -32,62 +25,58 @@ def __init__( is_terminal: bool = False, outcome: float = 0.0, valid_actions: list[ActionType] | None = None, - env_config: EnvConfig | None = None, + env_config: EnvConfig | None = None, # Expects trianglengin.EnvConfig ): self.current_step = current_step self._is_over = is_terminal self._outcome = outcome - # Use a default EnvConfig if none provided, needed for action dim + # Use trianglengin.EnvConfig self.env_config = env_config if env_config else EnvConfig() - # Cast ACTION_DIM to int action_dim_int = int(self.env_config.ACTION_DIM) self._valid_actions = ( valid_actions if valid_actions is not None else list(range(action_dim_int)) ) + self._game_over_reason: str | None = None # Add reason attribute def is_over(self) -> bool: return self._is_over def get_outcome(self) -> float: if not self._is_over: - raise ValueError("Cannot get outcome of non-terminal state.") + # MCTS expects 0 for non-terminal, not an error + return 0.0 return self._outcome - def valid_actions(self) -> list[ActionType]: - return self._valid_actions + def valid_actions(self) -> set[ActionType]: # Return set to match trianglengin + return set(self._valid_actions) def copy(self) -> "MockGameState": - # Simple copy for testing, doesn't need full state copy return MockGameState( self.current_step, self._is_over, self._outcome, list(self._valid_actions), - self.env_config, + self.env_config, # Pass trianglengin.EnvConfig ) def step(self, action: ActionType) -> tuple[float, bool]: - """ - Simulates taking a step. Returns (reward, done). - Matches the real GameState.step signature. - """ if action not in self.valid_actions(): raise ValueError( f"Invalid action {action} for mock state. 
Valid: {self.valid_actions()}" ) self.current_step += 1 - # Make terminal condition slightly more complex for testing self._is_over = self.current_step >= 10 or len(self._valid_actions) == 0 - self._outcome = 1.0 if self._is_over else 0.0 - # Simulate removing the taken action from valid actions + self._outcome = -1.0 if self._is_over else 0.0 # Match trianglengin outcome if action in self._valid_actions: self._valid_actions.remove(action) - # Simulate removing another random action sometimes elif self._valid_actions and random.random() < 0.5: self._valid_actions.pop(random.randrange(len(self._valid_actions))) + return 0.0, self._is_over # Return dummy reward - # Return dummy reward and the 'done' status - return 0.0, self._is_over + def force_game_over(self, reason: str): # Add method + self._is_over = True + self._game_over_reason = reason + self._valid_actions = [] def __hash__(self): return hash( @@ -105,7 +94,7 @@ def __eq__(self, other): ) -# --- Mock Network Evaluator --- +# ... (MockNetworkEvaluator remains the same, uses MockGameState) ... 
class MockNetworkEvaluator: """A mock network evaluator for testing MCTS.""" @@ -117,7 +106,7 @@ def __init__( ): self._default_policy = default_policy self._default_value = default_value - self._action_dim = action_dim # Store as int + self._action_dim = action_dim self.evaluation_history: list[MockGameState] = [] self.batch_evaluation_history: list[list[MockGameState]] = [] @@ -127,16 +116,14 @@ def _get_policy(self, state: MockGameState) -> Mapping[ActionType, float]: policy = { a: p for a, p in self._default_policy.items() if a in valid_actions } - # Normalize if specified policy doesn't sum to 1 over valid actions policy_sum = sum(policy.values()) if policy_sum > 1e-9 and abs(policy_sum - 1.0) > 1e-6: policy = {a: p / policy_sum for a, p in policy.items()} - elif not policy and valid_actions: # Handle empty policy for valid actions + elif not policy and valid_actions: prob = 1.0 / len(valid_actions) policy = dict.fromkeys(valid_actions, prob) return policy - # Default uniform policy valid_actions = state.valid_actions() if not valid_actions: return {} @@ -145,10 +132,8 @@ def _get_policy(self, state: MockGameState) -> Mapping[ActionType, float]: def evaluate(self, state: MockGameState) -> PolicyValueOutput: self.evaluation_history.append(state) - # Cast ACTION_DIM to int self._action_dim = int(state.env_config.ACTION_DIM) policy = self._get_policy(state) - # Create full policy map respecting action_dim full_policy = dict.fromkeys(range(self._action_dim), 0.0) full_policy.update(policy) return full_policy, self._default_value @@ -157,19 +142,13 @@ def evaluate_batch(self, states: list[MockGameState]) -> list[PolicyValueOutput] self.batch_evaluation_history.append(states) results = [] for state in states: - # Use single evaluate logic for consistency results.append(self.evaluate(state)) return results -# --- Pytest Fixtures --- -# Session-scoped fixtures moved to top-level tests/conftest.py - - @pytest.fixture def mock_evaluator(mock_env_config: EnvConfig) -> 
MockNetworkEvaluator: """Provides a MockNetworkEvaluator instance configured with the mock EnvConfig.""" - # Cast ACTION_DIM to int action_dim_int = int(mock_env_config.ACTION_DIM) return MockNetworkEvaluator(action_dim=action_dim_int) @@ -177,29 +156,25 @@ def mock_evaluator(mock_env_config: EnvConfig) -> MockNetworkEvaluator: @pytest.fixture def root_node_mock_state(mock_env_config: EnvConfig) -> Node: """Provides a root Node with a MockGameState using the mock EnvConfig.""" - # Cast ACTION_DIM to int action_dim_int = int(mock_env_config.ACTION_DIM) + # Pass trianglengin.EnvConfig to MockGameState state = MockGameState( valid_actions=list(range(action_dim_int)), env_config=mock_env_config, ) - # Cast MockGameState to Any temporarily to satisfy Node's type hint return Node(state=state) # type: ignore [arg-type] +# ... (expanded_node_mock_state remains the same, using MockGameState) ... @pytest.fixture def expanded_node_mock_state( root_node_mock_state: Node, mock_evaluator: MockNetworkEvaluator ) -> Node: """Provides an expanded root node with mock children using mock EnvConfig.""" root = root_node_mock_state - # Cast root.state back to MockGameState for the evaluator mock_state: MockGameState = root.state # type: ignore [assignment] - # Ensure evaluator action_dim is int - # Cast ACTION_DIM to int mock_evaluator._action_dim = int(mock_state.env_config.ACTION_DIM) policy, value = mock_evaluator.evaluate(mock_state) - # Ensure policy is not empty before expanding if not policy: policy = ( dict.fromkeys( @@ -221,7 +196,7 @@ def expanded_node_mock_state( prior_probability=prior, ) root.children[action] = child - root.visit_count = 1 # Simulate one visit to root after expansion + root.visit_count = 1 root.total_action_value = value return root @@ -237,38 +212,27 @@ def deep_expanded_node_mock_state( to encourage traversal down the path leading to action 0, then action 1. 
""" root = expanded_node_mock_state - # Ensure evaluator has correct action dim (as int) - # Cast ACTION_DIM to int mock_evaluator._action_dim = int(mock_env_config.ACTION_DIM) - # Ensure children exist if 0 not in root.children or 1 not in root.children: pytest.skip("Actions 0 or 1 not available in expanded node children") - # --- Configure Root Node to strongly prefer Action 0 --- - root.visit_count = 100 # Give root significant visits + root.visit_count = 100 child0 = root.children[0] - # child1 = root.children[1] # Unused variable - # Child 0: High visit count, good value, high prior (after potential noise) child0.visit_count = 80 - child0.total_action_value = 40 # Q = 0.5 + child0.total_action_value = 40 child0.prior_probability = 0.8 - # Other children: Low visits, low value, low prior for action, child in root.children.items(): if action != 0: child.visit_count = 2 - child.total_action_value = 0 # Q = 0.0 + child.total_action_value = 0 child.prior_probability = 0.01 - # --- Configure Child 0 to strongly prefer Action 1 --- - # Ensure Child 0 has children (expand it manually) - # Use evaluator to get a policy, then manually create children - # Cast child0.state back to MockGameState for the evaluator mock_child0_state: MockGameState = child0.state # type: ignore [assignment] policy_gc, value_gc = mock_evaluator.evaluate(mock_child0_state) - if not policy_gc: # Handle case where mock state has no valid actions + if not policy_gc: policy_gc = ( dict.fromkeys( mock_child0_state.valid_actions(), @@ -278,18 +242,17 @@ def deep_expanded_node_mock_state( else {} ) - valid_gc_actions = mock_child0_state.valid_actions() - if ( - 1 not in valid_gc_actions and valid_gc_actions - ): # If action 1 not valid, pick first valid one - preferred_gc_action = valid_gc_actions[0] - elif not valid_gc_actions: - pytest.skip("Child 0 has no valid actions to create grandchildren") - else: + # Convert set to list before checking/indexing + valid_gc_actions_list = 
list(mock_child0_state.valid_actions()) + if 1 in valid_gc_actions_list: preferred_gc_action = 1 + elif valid_gc_actions_list: + # If action 1 is not available, pick the first available action + preferred_gc_action = valid_gc_actions_list[0] + else: + pytest.skip("Child 0 has no valid actions to create grandchildren") - # Create grandchild nodes - for action_gc in valid_gc_actions: + for action_gc in valid_gc_actions_list: prior_gc = policy_gc.get(action_gc, 0.0) grandchild_state = MockGameState( current_step=2, valid_actions=[0], env_config=mock_child0_state.env_config @@ -302,19 +265,16 @@ def deep_expanded_node_mock_state( ) child0.children[action_gc] = grandchild - # Now configure grandchild stats preferred_grandchild = child0.children.get(preferred_gc_action) if preferred_grandchild: - # Preferred Grandchild: High visits, good value, high prior preferred_grandchild.visit_count = 60 - preferred_grandchild.total_action_value = 30 # Q = 0.5 + preferred_grandchild.total_action_value = 30 preferred_grandchild.prior_probability = 0.7 - # Other grandchildren: Low visits, low value, low prior for action_gc, grandchild in child0.children.items(): if action_gc != preferred_gc_action: grandchild.visit_count = 1 - grandchild.total_action_value = 0 # Q = 0.0 + grandchild.total_action_value = 0 grandchild.prior_probability = 0.05 return root diff --git a/tests/mcts/test_expansion.py b/tests/mcts/test_expansion.py index a6854b1..7bad074 100644 --- a/tests/mcts/test_expansion.py +++ b/tests/mcts/test_expansion.py @@ -1,27 +1,20 @@ -from typing import Any - import pytest +# Import EnvConfig from trianglengin +# Keep alphatriangle imports from alphatriangle.mcts.core.node import Node - -# Import necessary components and fixtures from alphatriangle.mcts.strategy import expansion -# Import session-scoped fixtures implicitly via pytest injection -# from alphatriangle.config import MCTSConfig # REMOVED - Provided by top-level conftest -from .conftest import ( # Import from conftest 
(local fixtures) - # mock_env_config, # REMOVED - Provided by top-level conftest - MockGameState, -) +# Import fixtures from local conftest +from .conftest import MockGameState +# ... (tests remain the same, using MockGameState which now uses trianglengin.EnvConfig) ... def test_expand_node_with_policy_basic(root_node_mock_state: Node): """Test basic node expansion with a valid policy.""" node = root_node_mock_state - # Cast node.state to Any temporarily to access mock method - mock_state: Any = node.state + mock_state: MockGameState = node.state # type: ignore [assignment] valid_actions = mock_state.valid_actions() - # Simple policy: uniform over valid actions policy = {action: 1.0 / len(valid_actions) for action in valid_actions} assert not node.is_expanded @@ -35,9 +28,7 @@ def test_expand_node_with_policy_basic(root_node_mock_state: Node): assert child.parent is node assert child.action_taken == action assert child.prior_probability == pytest.approx(1.0 / len(valid_actions)) - assert ( - child.state.current_step == node.state.current_step + 1 - ) # Check state stepped + assert child.state.current_step == node.state.current_step + 1 assert not child.is_expanded assert child.visit_count == 0 assert child.total_action_value == 0.0 @@ -46,52 +37,44 @@ def test_expand_node_with_policy_basic(root_node_mock_state: Node): def test_expand_node_with_policy_partial(root_node_mock_state: Node): """Test expansion when policy doesn't cover all valid actions (should assign 0 prior).""" node = root_node_mock_state - # Cast node.state to Any temporarily to access mock method - mock_state: Any = node.state - valid_actions = mock_state.valid_actions() # e.g., [0, 1, ..., 8] for 3x3 - # Policy only covers action 0 and 1 + mock_state: MockGameState = node.state # type: ignore [assignment] + valid_actions = mock_state.valid_actions() policy = {0: 0.6, 1: 0.4} expansion.expand_node_with_policy(node, policy) assert node.is_expanded - assert len(node.children) == len( - valid_actions - 
) # Should still create nodes for all valid actions + assert len(node.children) == len(valid_actions) assert 0 in node.children assert node.children[0].prior_probability == pytest.approx(0.6) assert 1 in node.children assert node.children[1].prior_probability == pytest.approx(0.4) - # Check an action not in the policy but valid if 2 in valid_actions: assert 2 in node.children - assert node.children[2].prior_probability == 0.0 # Prior should default to 0 + assert node.children[2].prior_probability == 0.0 def test_expand_node_with_policy_empty_valid_actions(root_node_mock_state: Node): """Test expansion when the node's state has no valid actions (but isn't terminal yet).""" node = root_node_mock_state - # Cast node.state to Any temporarily to access mock attribute - mock_state: Any = node.state - mock_state._valid_actions = [] # No valid actions - policy = {0: 1.0} # Policy doesn't matter here + mock_state: MockGameState = node.state # type: ignore [assignment] + mock_state._valid_actions = [] + policy = {0: 1.0} expansion.expand_node_with_policy(node, policy) - assert not node.is_expanded # Should not expand + assert not node.is_expanded assert not node.children - # The function should log a warning in this case - # The node's state should be marked as terminal by the expansion function + # Check if the state was forced to game over assert node.state.is_over() + assert "Expansion found no valid actions" in node.state._game_over_reason # type: ignore def test_expand_node_with_policy_already_expanded(root_node_mock_state: Node): """Test that expanding an already expanded node does nothing.""" node = root_node_mock_state policy = {0: 1.0} - # Manually add a child to simulate expansion - # Pass the env_config from the root node's state node.children[0] = Node( state=MockGameState(current_step=1, env_config=node.state.env_config), # type: ignore [arg-type] parent=node, @@ -100,42 +83,34 @@ def test_expand_node_with_policy_already_expanded(root_node_mock_state: Node): 
assert node.is_expanded original_children = node.children.copy() - expansion.expand_node_with_policy(node, policy) - - assert node.children == original_children # Children should not change + assert node.children == original_children def test_expand_node_with_policy_terminal_node(root_node_mock_state: Node): """Test that expanding a terminal node does nothing.""" node = root_node_mock_state - # Cast node.state to Any temporarily to access mock attribute - mock_state: Any = node.state - mock_state._is_over = True # Mark as terminal + mock_state: MockGameState = node.state # type: ignore [assignment] + mock_state._is_over = True policy = {0: 1.0} assert not node.is_expanded expansion.expand_node_with_policy(node, policy) - assert not node.is_expanded # Should not expand + assert not node.is_expanded def test_expand_node_with_invalid_policy_content(root_node_mock_state: Node): """Test expansion handles policy with invalid content (e.g., negative priors).""" - # Note: The main search loop should validate policy *before* calling expand. - # This test checks if expand handles it defensively (it currently clamps). 
node = root_node_mock_state - # Cast node.state to Any temporarily to access mock method - mock_state: Any = node.state + mock_state: MockGameState = node.state # type: ignore [assignment] valid_actions = mock_state.valid_actions() - policy = {0: 1.5, 1: -0.5} # Invalid priors + policy = {0: 1.5, 1: -0.5} expansion.expand_node_with_policy(node, policy) assert node.is_expanded assert len(node.children) == len(valid_actions) - assert node.children[0].prior_probability == pytest.approx( - 1.5 - ) # Currently doesn't clamp > 1 - assert node.children[1].prior_probability == 0.0 # Clamps negative to 0 + assert node.children[0].prior_probability == pytest.approx(1.5) + assert node.children[1].prior_probability == 0.0 if 2 in valid_actions: assert node.children[2].prior_probability == 0.0 diff --git a/tests/mcts/test_selection.py b/tests/mcts/test_selection.py index c3ecdce..cf60c82 100644 --- a/tests/mcts/test_selection.py +++ b/tests/mcts/test_selection.py @@ -1,22 +1,21 @@ +# File: tests/mcts/test_selection.py import math -from typing import Any import pytest -# Import session-scoped fixtures implicitly via pytest injection -from alphatriangle.config import MCTSConfig # Keep MCTSConfig type hint if needed -from alphatriangle.mcts.core.node import Node +# Import EnvConfig from trianglengin +from trianglengin.config import EnvConfig -# Import necessary components and fixtures +# Keep alphatriangle imports +from alphatriangle.config import MCTSConfig +from alphatriangle.mcts.core.node import Node from alphatriangle.mcts.strategy import selection -from .conftest import ( # Import from conftest (local fixtures) - EnvConfig, # Keep EnvConfig type hint if needed - MockGameState, -) +# Import fixtures from local conftest +from .conftest import MockGameState -# --- Test PUCT Calculation --- +# ... (tests remain the same, using MockGameState which now uses trianglengin.EnvConfig) ... 
def test_puct_calculation_unvisited_child( mock_mcts_config: MCTSConfig, mock_env_config: EnvConfig ): @@ -36,7 +35,7 @@ def test_puct_calculation_unvisited_child( child, parent.visit_count, mock_mcts_config ) - assert q_value == 0.0, "Q-value should be 0 for unvisited node" + assert q_value == 0.0 expected_exploration = ( mock_mcts_config.puct_coefficient * 0.5 * (math.sqrt(10) / (1 + 0)) ) @@ -86,13 +85,11 @@ def test_puct_calculation_zero_parent_visits( child.visit_count = 0 child.total_action_value = 0.0 - # Calculate PUCT with parent_visit_count=0 score, q_value, exploration = selection.calculate_puct_score( child, 0, mock_mcts_config ) assert q_value == 0.0 - # The formula uses max(1, parent_visit_count) for the sqrt term numerator expected_exploration = ( mock_mcts_config.puct_coefficient * 0.6 * (math.sqrt(1) / (1 + 0)) ) @@ -100,7 +97,6 @@ def test_puct_calculation_zero_parent_visits( assert score == pytest.approx(q_value + exploration) -# --- Test Child Selection --- def test_select_child_node_basic( expanded_node_mock_state: Node, mock_mcts_config: MCTSConfig ): @@ -108,30 +104,23 @@ def test_select_child_node_basic( parent = expanded_node_mock_state parent.visit_count = 10 - # Ensure children 0, 1, 2 exist before accessing them if 0 not in parent.children or 1 not in parent.children or 2 not in parent.children: pytest.skip("Required children (0, 1, 2) not present in fixture") child0 = parent.children[0] child0.visit_count = 1 - child0.total_action_value = 0.8 # Q = 0.8 - child0.prior_probability = 0.1 # Lower prior, higher Q + child0.total_action_value = 0.8 + child0.prior_probability = 0.1 child1 = parent.children[1] child1.visit_count = 5 - child1.total_action_value = 0.5 # Low Q (0.1), higher visits - child1.prior_probability = 0.6 # High prior + child1.total_action_value = 0.5 + child1.prior_probability = 0.6 child2 = parent.children[2] child2.visit_count = 3 - child2.total_action_value = 1.5 # Mid Q (0.5), mid visits - child2.prior_probability = 
0.3 # Mid prior - - # Calculate scores with C=1.5 (from config fixture now) - # Score0 = 0.8/1 + 1.5 * 0.1 * (sqrt(10) / (1 + 1)) ~ 0.8 + 0.237 = 1.037 - # Score1 = 0.5/5 + 1.5 * 0.6 * (sqrt(10) / (1 + 5)) ~ 0.1 + 0.474 = 0.574 - # Score2 = 1.5/3 + 1.5 * 0.3 * (sqrt(10) / (1 + 3)) ~ 0.5 + 0.355 = 0.855 - # Child 0 should have the highest score + child2.total_action_value = 1.5 + child2.prior_probability = 0.3 selected_child = selection.select_child_node(parent, mock_mcts_config) assert selected_child is child0 @@ -142,7 +131,7 @@ def test_select_child_node_no_children( ): """Test selection raises error if node has no children.""" parent = root_node_mock_state - assert not parent.children # Ensure it has no children + assert not parent.children with pytest.raises(selection.SelectionError): selection.select_child_node(parent, mock_mcts_config) @@ -154,42 +143,33 @@ def test_select_child_node_tie_breaking( parent = expanded_node_mock_state parent.visit_count = 10 - # Ensure children 0, 1, 2 exist if 0 not in parent.children or 1 not in parent.children or 2 not in parent.children: pytest.skip("Required children (0, 1, 2) not present in fixture") child0 = parent.children[0] child0.visit_count = 1 - child0.total_action_value = 0.9 # Q = 0.9 + child0.total_action_value = 0.9 child0.prior_probability = 0.4 child1 = parent.children[1] child1.visit_count = 1 - child1.total_action_value = 0.9 # Q = 0.9 + child1.total_action_value = 0.9 child1.prior_probability = 0.4 child2 = parent.children[2] child2.visit_count = 5 - child2.total_action_value = 0.1 # Q = 0.02 + child2.total_action_value = 0.1 child2.prior_probability = 0.1 - # Score0 = 0.9 + C * 0.4 * (sqrt(10)/2) - # Score1 = 0.9 + C * 0.4 * (sqrt(10)/2) - # Score2 = 0.02 + C * 0.1 * (sqrt(10)/6) - # Child 0 and 1 have equal highest score selected_child = selection.select_child_node(parent, mock_mcts_config) - assert ( - selected_child is child0 or selected_child is child1 - ) # Should select one of the tied best + assert 
selected_child is child0 or selected_child is child1 -# --- Test Dirichlet Noise --- def test_add_dirichlet_noise( expanded_node_mock_state: Node, mock_mcts_config: MCTSConfig ): """Test that Dirichlet noise modifies prior probabilities correctly.""" node = expanded_node_mock_state - # Create a copy of the config to modify locally for this test config_copy = mock_mcts_config.model_copy(deep=True) config_copy.dirichlet_alpha = 0.5 config_copy.dirichlet_epsilon = 0.25 @@ -198,12 +178,8 @@ def test_add_dirichlet_noise( if n_children == 0: pytest.skip("Node has no children to add noise to.") original_priors = {a: c.prior_probability for a, c in node.children.items()} - # original_sum = sum(original_priors.values()) # Unused variable - # Use default_rng for modern NumPy random generation - # rng = np.random.default_rng(42) # Removed unused variable selection.add_dirichlet_noise(node, config_copy) - # Resetting global seed is less ideal, rely on instance if needed elsewhere new_priors = {a: c.prior_probability for a, c in node.children.items()} mixed_sum = sum(new_priors.values()) @@ -212,14 +188,12 @@ def test_add_dirichlet_noise( priors_changed = False for action, new_p in new_priors.items(): assert action in original_priors - assert 0.0 <= new_p <= 1.0 # Check bounds + assert 0.0 <= new_p <= 1.0 if abs(new_p - original_priors[action]) > 1e-9: priors_changed = True - assert priors_changed, "Priors did not change after adding noise" - assert mixed_sum == pytest.approx(1.0, abs=1e-6), ( - f"Noisy priors sum to {mixed_sum}, not 1.0" - ) + assert priors_changed + assert mixed_sum == pytest.approx(1.0, abs=1e-6) def test_add_dirichlet_noise_disabled( @@ -231,7 +205,6 @@ def test_add_dirichlet_noise_disabled( pytest.skip("Node has no children.") original_priors = {a: c.prior_probability for a, c in node.children.items()} - # Create copies of the config to modify locally config_alpha_zero = mock_mcts_config.model_copy(deep=True) config_alpha_zero.dirichlet_alpha = 0.0 
config_alpha_zero.dirichlet_epsilon = 0.25 @@ -242,22 +215,16 @@ def test_add_dirichlet_noise_disabled( selection.add_dirichlet_noise(node, config_alpha_zero) priors_after_alpha_zero = {a: c.prior_probability for a, c in node.children.items()} - assert priors_after_alpha_zero == original_priors, ( - "Priors changed when alpha was zero" - ) + assert priors_after_alpha_zero == original_priors - # Reset priors before next test for a, p in original_priors.items(): node.children[a].prior_probability = p selection.add_dirichlet_noise(node, config_eps_zero) priors_after_eps_zero = {a: c.prior_probability for a, c in node.children.items()} - assert priors_after_eps_zero == original_priors, ( - "Priors changed when epsilon was zero" - ) + assert priors_after_eps_zero == original_priors -# --- Test Traversal --- def test_traverse_to_leaf_unexpanded( root_node_mock_state: Node, mock_mcts_config: MCTSConfig ): @@ -273,9 +240,7 @@ def test_traverse_to_leaf_expanded( """Test traversal selects a child from an expanded node and stops (depth 1).""" root = expanded_node_mock_state for child in root.children.values(): - assert not child.is_expanded, ( - f"Child {child} should not be expanded in this fixture setup" - ) + assert not child.is_expanded leaf, depth = selection.traverse_to_leaf(root, mock_mcts_config) @@ -288,7 +253,6 @@ def test_traverse_to_leaf_max_depth( ): """Test traversal stops at max depth.""" root = expanded_node_mock_state - # Create a copy of the config to modify locally config_copy = mock_mcts_config.model_copy(deep=True) config_copy.max_search_depth = 0 @@ -297,13 +261,10 @@ def test_traverse_to_leaf_max_depth( assert leaf is root assert depth == 0 - # --- Test max depth = 1 --- config_copy.max_search_depth = 1 - # Ensure root has children if not root.children: pytest.skip("Root node has no children for max depth 1 test") - # Manually expand one child to test if traversal stops *before* selecting grandchild child0 = next(iter(root.children.values())) 
child0.children[0] = Node( state=MockGameState(current_step=2, env_config=root.state.env_config), # type: ignore [arg-type] @@ -313,10 +274,8 @@ def test_traverse_to_leaf_max_depth( leaf, depth = selection.traverse_to_leaf(root, config_copy) - assert leaf in root.children.values(), ( - "Leaf node should be a direct child of the root" - ) - assert depth == 1, "Depth should be 1 when max_search_depth is 1" + assert leaf in root.children.values() + assert depth == 1 def test_traverse_to_terminal_leaf( @@ -324,70 +283,58 @@ def test_traverse_to_terminal_leaf( ): """Test traversal stops at a terminal node.""" root = expanded_node_mock_state - # Ensure child 1 exists if 1 not in root.children: pytest.skip("Child 1 not present in fixture") child1 = root.children[1] - # Cast child1.state to Any temporarily to access mock attribute - mock_child1_state: Any = child1.state - mock_child1_state._is_over = True # Mark child 1 as terminal + mock_child1_state: MockGameState = child1.state # type: ignore [assignment] + mock_child1_state._is_over = True - # Make child 1 highly attractive to ensure it's selected root.visit_count = 10 for action, child in root.children.items(): if action == 1: child.visit_count = 5 - child.total_action_value = 4 # High Q = 0.8 - child.prior_probability = 0.8 # High P + child.total_action_value = 4 + child.prior_probability = 0.8 else: child.visit_count = 1 - child.total_action_value = 0 # Low Q = 0.0 - child.prior_probability = 0.1 # Low P + child.total_action_value = 0 + child.prior_probability = 0.1 leaf, depth = selection.traverse_to_leaf(root, mock_mcts_config) - assert leaf is child1, "Traversal should stop at the terminal child node" - assert depth == 1, "Depth should be 1 as traversal stops at the terminal child" + assert leaf is child1 + assert depth == 1 -# --- Added Test for Deeper Traversal --- def test_traverse_to_leaf_deeper( deep_expanded_node_mock_state: Node, mock_mcts_config: MCTSConfig ): """Test traversal goes deeper than 1 level 
using the specifically configured fixture.""" - root = deep_expanded_node_mock_state # This fixture is configured to prefer path 0 -> 1 (or first valid) - # Create a copy of the config to modify locally + root = deep_expanded_node_mock_state config_copy = mock_mcts_config.model_copy(deep=True) - config_copy.max_search_depth = 10 # Ensure max depth doesn't interfere + config_copy.max_search_depth = 10 - # --- Assert fixture setup is correct --- - assert 0 in root.children, "Fixture should have child 0" + assert 0 in root.children child0 = root.children[0] - assert child0.is_expanded, "Child 0 should be expanded in the fixture" - assert child0.children, "Child 0 should have grandchildren" + assert child0.is_expanded + assert child0.children - # Determine the action preferred by the fixture logic for child0 - # Cast child0.state to Any temporarily to access mock method - mock_child0_state: Any = child0.state + mock_child0_state: MockGameState = child0.state # type: ignore [assignment] valid_gc_actions = mock_child0_state.valid_actions() - if 1 in valid_gc_actions: + # Convert set to list before checking/indexing + valid_gc_actions_list = list(valid_gc_actions) + if 1 in valid_gc_actions_list: preferred_gc_action = 1 - elif valid_gc_actions: - preferred_gc_action = valid_gc_actions[0] + elif valid_gc_actions_list: + # Index the list, not the set + preferred_gc_action = valid_gc_actions_list[0] else: pytest.fail("Fixture error: Child 0 has no valid actions for grandchildren") expected_grandchild = child0.children.get(preferred_gc_action) - assert expected_grandchild is not None, ( - f"Expected grandchild for action {preferred_gc_action} not found" - ) + assert expected_grandchild is not None - # --- Run traversal --- leaf, depth = selection.traverse_to_leaf(root, config_copy) - # --- Assertions --- - # It should select child0, then the expected grandchild (which is a leaf in the fixture setup) - assert leaf is expected_grandchild, ( - f"Expected leaf 
{expected_grandchild}, but got {leaf}" - ) - assert depth == 2, f"Expected depth 2, but got {depth}" + assert leaf is expected_grandchild + assert depth == 2 diff --git a/tests/nn/test_network.py b/tests/nn/test_network.py index d99e4a4..db9d841 100644 --- a/tests/nn/test_network.py +++ b/tests/nn/test_network.py @@ -1,24 +1,19 @@ -# File: tests/nn/test_network.py from unittest.mock import MagicMock, patch import numpy as np import pytest import torch -# Use relative imports for alphatriangle components if running tests from project root -# or absolute imports if package is installed -try: - # Try absolute imports first (for installed package) - from alphatriangle.config import EnvConfig, ModelConfig, TrainConfig - from alphatriangle.environment import GameState - from alphatriangle.nn import AlphaTriangleNet, NeuralNetwork - from alphatriangle.utils.types import StateType -except ImportError: - # Fallback to relative imports (for running tests directly) - from alphatriangle.config import EnvConfig, ModelConfig, TrainConfig - from alphatriangle.environment import GameState - from alphatriangle.nn import AlphaTriangleNet, NeuralNetwork - from alphatriangle.utils.types import StateType +# Import EnvConfig from trianglengin +from trianglengin.config import EnvConfig + +# Import GameState from trianglengin +from trianglengin.core.environment import GameState + +# Keep alphatriangle imports +from alphatriangle.config import ModelConfig, TrainConfig +from alphatriangle.nn import AlphaTriangleNet, NeuralNetwork +from alphatriangle.utils.types import StateType # Use module-level rng from tests/conftest.py from tests.conftest import rng @@ -69,7 +64,7 @@ def nn_interface( @pytest.fixture def mock_game_state(env_config: EnvConfig) -> GameState: """Provides a real GameState object for testing NN interface.""" - # Use a real GameState instance + # Use a real GameState instance from trianglengin return GameState(config=env_config, initial_seed=123) diff --git 
a/tests/rl/test_trainer.py b/tests/rl/test_trainer.py index fc647b6..14537bf 100644 --- a/tests/rl/test_trainer.py +++ b/tests/rl/test_trainer.py @@ -1,24 +1,25 @@ -# File: tests/rl/test_trainer.py import numpy as np import pytest import torch -from alphatriangle.config import EnvConfig, ModelConfig, TrainConfig +# Import EnvConfig from trianglengin +from trianglengin.config import EnvConfig + +# Keep alphatriangle imports +from alphatriangle.config import ModelConfig, TrainConfig from alphatriangle.nn import NeuralNetwork from alphatriangle.rl import ExperienceBuffer, Trainer from alphatriangle.utils.types import Experience, PERBatchSample, StateType -# --- Fixtures --- - +# ... (fixtures remain mostly the same, but ensure EnvConfig is from trianglengin) ... @pytest.fixture -def env_config(mock_env_config: EnvConfig) -> EnvConfig: +def env_config(mock_env_config: EnvConfig) -> EnvConfig: # Uses trianglengin.EnvConfig return mock_env_config @pytest.fixture def model_config(mock_model_config: ModelConfig) -> ModelConfig: - # Ensure feature dim matches mock_state_type mock_model_config.OTHER_NN_INPUT_FEATURES_DIM = 10 return mock_model_config @@ -34,25 +35,24 @@ def train_config_uniform(mock_train_config: TrainConfig) -> TrainConfig: def train_config_per(mock_train_config: TrainConfig) -> TrainConfig: cfg = mock_train_config.model_copy(deep=True) cfg.USE_PER = True - cfg.PER_BETA_ANNEAL_STEPS = 100 # Set anneal steps + cfg.PER_BETA_ANNEAL_STEPS = 100 return cfg @pytest.fixture def nn_interface( mock_model_config: ModelConfig, - env_config: EnvConfig, + env_config: EnvConfig, # Uses trianglengin.EnvConfig train_config_uniform: TrainConfig, ) -> NeuralNetwork: """Provides a NeuralNetwork instance for testing, configured for uniform buffer.""" - # Use train_config_uniform here, or make it parameterizable if needed - device = torch.device("cpu") # Use CPU for testing + device = torch.device("cpu") + # Pass trianglengin.EnvConfig nn_interface_instance = NeuralNetwork( 
mock_model_config, env_config, train_config_uniform, device ) - # Ensure model is on CPU for testing consistency nn_interface_instance.model.to(device) - nn_interface_instance.model.eval() # Ensure it starts in eval mode if needed by tests + nn_interface_instance.model.eval() return nn_interface_instance @@ -60,23 +60,24 @@ def nn_interface( def trainer_uniform( nn_interface: NeuralNetwork, train_config_uniform: TrainConfig, - env_config: EnvConfig, + env_config: EnvConfig, # Uses trianglengin.EnvConfig ) -> Trainer: """Provides a Trainer instance configured for uniform sampling.""" + # Pass trianglengin.EnvConfig return Trainer(nn_interface, train_config_uniform, env_config) @pytest.fixture def trainer_per( - nn_interface: NeuralNetwork, train_config_per: TrainConfig, env_config: EnvConfig + nn_interface: NeuralNetwork, + train_config_per: TrainConfig, + env_config: EnvConfig, # Uses trianglengin.EnvConfig ) -> Trainer: """Provides a Trainer instance configured for PER.""" - # Need to re-create NN interface if its config depends on train_config - # For now, assume nn_interface created with uniform config is okay for PER trainer too + # Pass trianglengin.EnvConfig return Trainer(nn_interface, train_config_per, env_config) -# Use mock_experience implicitly from tests/conftest.py @pytest.fixture def buffer_uniform( train_config_uniform: TrainConfig, mock_experience: Experience @@ -84,7 +85,6 @@ def buffer_uniform( """Provides a filled uniform buffer.""" buffer = ExperienceBuffer(train_config_uniform) for i in range(buffer.min_size_to_train + 5): - # Correctly copy StateType dict and its NumPy arrays state_copy: StateType = { "grid": mock_experience[0]["grid"].copy() + i, "other_features": mock_experience[0]["other_features"].copy() + i, @@ -98,7 +98,6 @@ def buffer_uniform( return buffer -# Use mock_experience implicitly from tests/conftest.py @pytest.fixture def buffer_per( train_config_per: TrainConfig, mock_experience: Experience @@ -106,7 +105,6 @@ def 
buffer_per( """Provides a filled PER buffer.""" buffer = ExperienceBuffer(train_config_per) for i in range(buffer.min_size_to_train + 5): - # Correctly copy StateType dict and its NumPy arrays state_copy: StateType = { "grid": mock_experience[0]["grid"].copy() + i, "other_features": mock_experience[0]["other_features"].copy() + i, @@ -116,31 +114,25 @@ def buffer_per( mock_experience[1], mock_experience[2] + i * 0.1, ) - buffer.add(exp_copy) # Adds with max priority + buffer.add(exp_copy) return buffer -# --- Tests --- - - +# ... (tests remain the same, using updated fixtures) ... def test_trainer_initialization(trainer_uniform: Trainer): assert trainer_uniform.nn is not None assert trainer_uniform.model is not None assert trainer_uniform.optimizer is not None - # Scheduler might be None depending on config assert hasattr(trainer_uniform, "scheduler") -# Use mock_experience implicitly from tests/conftest.py def test_prepare_batch(trainer_uniform: Trainer, mock_experience: Experience): """Test the internal _prepare_batch method.""" batch_size = trainer_uniform.train_config.BATCH_SIZE batch = [mock_experience] * batch_size - # --- CHANGED: Variable name for clarity --- grid_t, other_t, policy_target_t, n_step_return_t = trainer_uniform._prepare_batch( batch ) - # --- END CHANGED --- assert grid_t.shape == ( batch_size, @@ -153,16 +145,12 @@ def test_prepare_batch(trainer_uniform: Trainer, mock_experience: Experience): trainer_uniform.model_config.OTHER_NN_INPUT_FEATURES_DIM, ) assert policy_target_t.shape == (batch_size, trainer_uniform.env_config.ACTION_DIM) - # --- CHANGED: Assert shape is (batch_size,) --- assert n_step_return_t.shape == (batch_size,) - # --- END CHANGED --- assert grid_t.device == trainer_uniform.device assert other_t.device == trainer_uniform.device assert policy_target_t.device == trainer_uniform.device - # --- CHANGED: Check n_step_return_t device --- assert n_step_return_t.device == trainer_uniform.device - # --- END CHANGED --- def 
test_train_step_uniform(trainer_uniform: Trainer, buffer_uniform: ExperienceBuffer): @@ -180,12 +168,11 @@ def test_train_step_uniform(trainer_uniform: Trainer, buffer_uniform: Experience assert "total_loss" in loss_info assert "policy_loss" in loss_info assert "value_loss" in loss_info - assert loss_info["total_loss"] > 0 # Loss should generally be positive + assert loss_info["total_loss"] > 0 assert isinstance(td_errors, np.ndarray) assert td_errors.shape == (trainer_uniform.train_config.BATCH_SIZE,) - # Check if model parameters changed params_changed = False for p_initial, p_final in zip( initial_params, trainer_uniform.model.parameters(), strict=True @@ -199,7 +186,6 @@ def test_train_step_uniform(trainer_uniform: Trainer, buffer_uniform: Experience def test_train_step_per(trainer_per: Trainer, buffer_per: ExperienceBuffer): """Test a single training step with PER.""" initial_params = [p.clone() for p in trainer_per.model.parameters()] - # Need current_step for PER beta calculation sample = buffer_per.sample( trainer_per.train_config.BATCH_SIZE, current_train_step=10 ) @@ -218,9 +204,8 @@ def test_train_step_per(trainer_per: Trainer, buffer_per: ExperienceBuffer): assert isinstance(td_errors, np.ndarray) assert td_errors.shape == (trainer_per.train_config.BATCH_SIZE,) - assert np.all(np.isfinite(td_errors)) # TD errors should be finite + assert np.all(np.isfinite(td_errors)) - # Check if model parameters changed params_changed = False for p_initial, p_final in zip( initial_params, trainer_per.model.parameters(), strict=True @@ -246,14 +231,9 @@ def test_get_current_lr(trainer_uniform: Trainer): """Test retrieving the current learning rate.""" lr = trainer_uniform.get_current_lr() assert isinstance(lr, float) - assert ( - lr == trainer_uniform.train_config.LEARNING_RATE - ) # Initially should be the base LR + assert lr == trainer_uniform.train_config.LEARNING_RATE - # Simulate scheduler step if scheduler exists if trainer_uniform.scheduler: 
trainer_uniform.scheduler.step() lr_after_step = trainer_uniform.get_current_lr() assert isinstance(lr_after_step, float) - # Cannot assert exact value without knowing scheduler type/params easily - # Just check it's still a float diff --git a/tests/stats/__init__.py b/tests/stats/__init__.py index e69de29..903ac41 100644 --- a/tests/stats/__init__.py +++ b/tests/stats/__init__.py @@ -0,0 +1,2 @@ +# File: tests/stats/__init__.py +# This file can be empty diff --git a/tests/stats/test_collector.py b/tests/stats/test_collector.py index 3afd7f5..e988917 100644 --- a/tests/stats/test_collector.py +++ b/tests/stats/test_collector.py @@ -1,4 +1,3 @@ -# File: tests/stats/test_collector.py import logging from collections import deque @@ -6,6 +5,8 @@ import pytest import ray +# Import GameState from trianglengin +# Keep alphatriangle imports from alphatriangle.stats import StatsCollectorActor from alphatriangle.utils.types import StepInfo # Import StepInfo @@ -253,11 +254,16 @@ def test_get_set_state(stats_actor): class MockGameStateForStats: def __init__(self, step: int, score: float): self.current_step = step - self.game_score = score + self._game_score = ( + score # Use internal attribute name if game_score is a method + ) # Add dummy attributes expected by the check in update_worker_game_state self.grid_data = True self.shapes = True + def game_score(self) -> float: # Add method if needed by tests + return self._game_score + def test_update_and_get_worker_state(stats_actor): """Test updating and retrieving worker game states.""" @@ -273,14 +279,14 @@ def test_update_and_get_worker_state(stats_actor): latest_states = ray.get(stats_actor.get_latest_worker_states.remote()) assert worker_id in latest_states assert latest_states[worker_id].current_step == 10 - assert latest_states[worker_id].game_score == 5.0 + assert latest_states[worker_id].game_score() == 5.0 # Call method # Update state again for worker 1 ray.get(stats_actor.update_worker_game_state.remote(worker_id, 
state2)) latest_states = ray.get(stats_actor.get_latest_worker_states.remote()) assert worker_id in latest_states assert latest_states[worker_id].current_step == 11 - assert latest_states[worker_id].game_score == 6.0 + assert latest_states[worker_id].game_score() == 6.0 # Call method # Update state for worker 2 worker_id_2 = 2