diff --git a/Justfile b/Justfile
index 831579d..f8f6009 100644
--- a/Justfile
+++ b/Justfile
@@ -70,6 +70,34 @@ install_training:
     @command -v uv > /dev/null || (echo "uv not found. Please install from https://docs.astral.sh/uv/" && exit 1)
     cd training && uv sync
 
-# Run one iteration of selfplay and training, promoting the model if it improves
-train: install_cargo install_training
-    bun run scripts/train.ts
+# Generate selfplay training data (traditional MCTS, high quality for bootstrapping)
+selfplay GAMES="100" PLAYOUTS="20000": install_cargo
+    mkdir -p training/artifacts
+    cargo run --release -p selfplay -- {{GAMES}} {{PLAYOUTS}} > training/artifacts/training_data.jsonl
+    @echo "Generated training data with auxiliary targets (ownership, score_diff)"
+
+# Generate selfplay training data using neural network guidance (faster, more games)
+selfplay_nn GAMES="200" PLAYOUTS="800": install_cargo
+    mkdir -p training/artifacts
+    cargo run --release -p selfplay -- {{GAMES}} {{PLAYOUTS}} --nn >> training/artifacts/training_data.jsonl
+    @echo "Generated NN-guided training data with auxiliary targets"
+
+# Train the model on existing data
+train_only EPOCHS="20": install_training
+    cd training && uv run train.py --epochs {{EPOCHS}}
+
+# Run one iteration: selfplay + training (traditional MCTS for bootstrapping)
+train GAMES="100" PLAYOUTS="20000" EPOCHS="20": install_cargo install_training
+    @echo "Running selfplay with {{GAMES}} games, {{PLAYOUTS}} playouts..."
+    just selfplay {{GAMES}} {{PLAYOUTS}}
+    @echo "Training for {{EPOCHS}} epochs..."
+    just train_only {{EPOCHS}}
+    @echo "Training iteration complete!"
+
+# Run iterative training loop (AlphaZero-style with NN-guided selfplay)
+iterate ITERATIONS="20" GAMES="200" PLAYOUTS="1000" EPOCHS="20": install_cargo install_training
+    cd training && uv run iterate.py --iterations {{ITERATIONS}} --games {{GAMES}} --playouts {{PLAYOUTS}} --epochs {{EPOCHS}}
+
+# Create blank models for debugging
+blank_models: install_training
+    cd training && uv run create_blank_models.py
diff --git a/bots/src/bin/debug_modes.rs b/bots/src/bin/debug_modes.rs
index 35ea8e6..fc18e04 100644
--- a/bots/src/bin/debug_modes.rs
+++ b/bots/src/bin/debug_modes.rs
@@ -146,10 +146,7 @@ fn main() {
     );
 
     // Test 2: Try to load NN and compare
-    match NeuralNet::load(
-        "training/artifacts/model_drafting.onnx",
-        "training/artifacts/model_movement.onnx",
-    ) {
+    match NeuralNet::load("training/artifacts/model.onnx") {
         Ok(nn) => {
             let nn = Arc::new(nn);
diff --git a/bots/src/bin/nn_vs_mcts.rs b/bots/src/bin/nn_vs_mcts.rs
index d134d78..e38d423 100644
--- a/bots/src/bin/nn_vs_mcts.rs
+++ b/bots/src/bin/nn_vs_mcts.rs
@@ -41,10 +41,7 @@ fn main() {
         None
     } else {
         eprintln!("Loading neural network...");
-        match NeuralNet::load(
-            "training/artifacts/model_drafting.onnx",
-            "training/artifacts/model_movement.onnx",
-        ) {
+        match NeuralNet::load("training/artifacts/model.onnx") {
             Ok(model) => {
                 eprintln!("Neural network loaded successfully");
                 Some(Arc::new(model))
diff --git a/bots/src/mctsbot.rs b/bots/src/mctsbot.rs
index 2332aa7..e60b590 100644
--- a/bots/src/mctsbot.rs
+++ b/bots/src/mctsbot.rs
@@ -829,11 +829,8 @@ fn test_neural_network_guided_game() {
     // Load neural network
     let nn = Arc::new(
-        NeuralNet::load(
-            "../training/artifacts/model_drafting.onnx",
-            "../training/artifacts/model_movement.onnx",
-        )
-        .expect("Failed to load neural network"),
+        NeuralNet::load("../training/artifacts/model.onnx")
+            .expect("Failed to load neural network"),
     );
 
     let mut game =
        GameState::new_two_player::(&mut SeedableRng::seed_from_u64(42));
diff --git a/bots/src/neuralnet.rs b/bots/src/neuralnet.rs
index c14e159..170d4ae 100644
--- a/bots/src/neuralnet.rs
+++ b/bots/src/neuralnet.rs
@@ -7,8 +7,7 @@ type Model = SimplePlan<TypedFact, Box<dyn TypedOp>, Graph<TypedFact, Box<dyn TypedOp>>>;
 
 pub struct NeuralNet {
-    drafting_model: Model,
-    movement_model: Model,
+    model: Model,
 }
 
 impl NeuralNet {
-    pub fn load(drafting_path: &str, movement_path: &str) -> TractResult<Self> {
-        let drafting_model = tract_onnx::onnx()
-            .model_for_path(drafting_path)?
+    /// Load ONNX model from the given path
+    ///
+    /// The model should have:
+    /// - Input: features (1, 480)
+    /// - Outputs: drafting_policy (1, 60), movement_policy (1, 168), value (1, 1)
+    pub fn load(model_path: &str) -> TractResult<Self> {
+        let model = tract_onnx::onnx()
+            .model_for_path(model_path)?
             .with_input_fact(0, f32::fact([1, NUM_FEATURES]).into())?
             .into_optimized()?
             .into_runnable()?;
-        let movement_model = tract_onnx::onnx()
-            .model_for_path(movement_path)?
-            .with_input_fact(0, f32::fact([1, NUM_FEATURES]).into())?
-            .into_optimized()?
-            .into_runnable()?;
+        Ok(Self { model })
+    }
 
-        Ok(Self {
-            drafting_model,
-            movement_model,
-        })
+    /// Legacy method for backward compatibility with old two-file approach
+    ///
+    /// This loads the new single-file model but ignores the movement_path parameter.
+    /// Use `load()` instead for new code.
+    #[deprecated(note = "Use load() instead - only one model file is needed now")]
+    pub fn load_legacy(drafting_path: &str, _movement_path: &str) -> TractResult<Self> {
+        Self::load(drafting_path)
     }
 
     /// Run inference on the given game state
@@ -52,22 +55,19 @@ impl NeuralNet {
         let input: Tensor = tract_ndarray::Array2::from_shape_vec((1, NUM_FEATURES), features)?
            .into();
 
-        let model = if is_drafting {
-            &self.drafting_model
-        } else {
-            &self.movement_model
-        };
-
-        let outputs = model.run(tvec!(input.into()))?;
+        let outputs = self.model.run(tvec!(input.into()))?;
 
-        // Output 0: policy logits, Output 1: value (tanh output in [-1, 1])
-        let policy_logits: Vec<f32> = outputs[0]
+        // Model outputs: [0] drafting_policy, [1] movement_policy, [2] value
+        // Select the appropriate policy based on game phase
+        let policy_output_idx = if is_drafting { 0 } else { 1 };
+        let policy_logits: Vec<f32> = outputs[policy_output_idx]
            .to_array_view::<f32>()?
            .iter()
            .copied()
            .collect();
+
         // Convert tanh output [-1, 1] to probability [0, 1]
-        let raw_value: f32 = outputs[1].to_array_view::<f32>()?[[0, 0]];
+        let raw_value: f32 = outputs[2].to_array_view::<f32>()?[[0, 0]];
         let value = (raw_value + 1.0) / 2.0;
 
         Ok(NeuralNetOutput {
diff --git a/papers/katago.pdf b/papers/katago.pdf
new file mode 100644
index 0000000..1e7e1df
Binary files /dev/null and b/papers/katago.pdf differ
diff --git a/selfplay/src/main.rs b/selfplay/src/main.rs
index ecef1f5..b5637c5 100644
--- a/selfplay/src/main.rs
+++ b/selfplay/src/main.rs
@@ -52,6 +52,16 @@ pub struct TrainingSample {
     pub player: usize,
     /// Whether this is drafting phase
     pub is_drafting: bool,
+    /// Ownership prediction target: which player owns each cell at game end
+    /// Array of 60 values, each in [0, 1, 2]:
+    /// 0 = player 0 claimed this cell
+    /// 1 = player 1 claimed this cell
+    /// 2 = neither player claimed this cell
+    pub ownership: Vec<u8>,
+    /// Score difference prediction target (from current player's perspective)
+    /// Stored as bin index in range [0, 184] where:
+    /// bin_index = (player_score - opponent_score) - (-92) = score_diff + 92
+    pub score_diff: u8,
 }
 
 fn main() {
@@ -73,10 +83,7 @@ fn main() {
     // Load neural network if requested
     let nn: Option<Arc<NeuralNet>> = if use_nn {
         eprintln!("Loading neural network...");
-        match NeuralNet::load(
-            "training/artifacts/model_drafting.onnx",
-            "training/artifacts/model_movement.onnx",
-        ) {
+        match NeuralNet::load("training/artifacts/model.onnx") {
             Ok(model) => {
                 eprintln!("Neural network loaded successfully");
                 Some(Arc::new(model))
@@ -404,15 +411,36 @@ fn play_game(nplayouts: usize, nn: Option<Arc<NeuralNet>>) -> GameResult {
         [0.0, 1.0]
     };
 
+    // Compute ownership targets from final game state
+    // Each cell is owned by the player who claimed it (or 2 if unclaimed)
+    let mut ownership = vec![2u8; NUM_CELLS]; // Default: unclaimed
+    for cell in game.board.claimed[0].into_iter() {
+        ownership[cell as usize] = 0;
+    }
+    for cell in game.board.claimed[1].into_iter() {
+        ownership[cell as usize] = 1;
+    }
+
+    // Compute score difference targets (per-player perspective)
+    let score_diffs: [i32; 2] = [
+        scores[0] as i32 - scores[1] as i32,
+        scores[1] as i32 - scores[0] as i32,
+    ];
+
     // Convert pending samples to final training samples
     let samples: Vec<TrainingSample> = pending_samples
         .into_iter()
-        .map(|s| TrainingSample {
-            features: s.features,
-            policy: s.policy,
-            value: values[s.player],
-            player: s.player,
-            is_drafting: s.is_drafting,
+        .map(|s| {
+            let score_diff_bin = (score_diffs[s.player] + 92) as u8;
+            TrainingSample {
+                features: s.features,
+                policy: s.policy,
+                value: values[s.player],
+                player: s.player,
+                is_drafting: s.is_drafting,
+                ownership: ownership.clone(),
+                score_diff: score_diff_bin,
+            }
         })
         .collect();
diff --git a/training/create_blank_models.py b/training/create_blank_models.py
index f921276..b0e1831 100644
--- a/training/create_blank_models.py
+++ b/training/create_blank_models.py
@@ -48,28 +48,37 @@ class BlankHTMFNet(nn.Module):
     Blank neural network that outputs uniform policy and neutral value.
The model ignores the input and always outputs: - - Policy: all zeros (softmax → uniform distribution) + - Drafting policy: all zeros (softmax → uniform distribution) + - Movement policy: all zeros (softmax → uniform distribution) - Value: 0 (tanh → neutral, converts to 0.5 in Rust) + - Ownership: uniform distribution over 3 classes + - Score diff: uniform distribution over 185 bins """ - def __init__(self, policy_size: int): + def __init__(self): super().__init__() - self.policy_size = policy_size # These parameters exist to give ONNX export something to work with, # but we override forward() to ignore them self.dummy_param = nn.Parameter(torch.zeros(1)) - def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + def forward( + self, x: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: batch_size = x.shape[0] - # Always output zeros for policy (softmax will make this uniform) - policy = torch.zeros(batch_size, self.policy_size) + # Always output zeros for policies (softmax will make this uniform) + drafting_policy = torch.zeros(batch_size, NUM_CELLS) + movement_policy = torch.zeros(batch_size, MOVEMENT_POLICY_SIZE) # Always output zero for value (tanh=0 → neutral → 0.5 in Rust) value = torch.zeros(batch_size, 1) - return policy, value + # Ownership: uniform over 3 classes for each cell + ownership = torch.zeros(batch_size, NUM_CELLS, 3) + # Score diff: uniform over 185 bins (range -92 to +92) + score_diff = torch.zeros(batch_size, 185) + return drafting_policy, movement_policy, value, ownership, score_diff def export_to_onnx(model: nn.Module, path: Path): - """Export model to ONNX format.""" + """Export model to ONNX format with all heads.""" model.eval() dummy_input = torch.zeros(1, NUM_FEATURES) @@ -78,11 +87,20 @@ def export_to_onnx(model: nn.Module, path: Path): dummy_input, path, input_names=["features"], - output_names=["policy", "value"], + output_names=[ + "drafting_policy", + "movement_policy", + "value", + "ownership", + "score_diff", + ], dynamic_axes={ "features": {0: "batch_size"}, - "policy": {0: "batch_size"}, + "drafting_policy": {0: "batch_size"}, + "movement_policy": {0: "batch_size"}, "value": {0: "batch_size"}, + "ownership": {0: "batch_size"}, + "score_diff": {0: "batch_size"}, }, opset_version=17, dynamo=False, @@ -92,32 +110,25 @@ def export_to_onnx(model: nn.Module, path: Path): def main(): ARTIFACTS_DIR.mkdir(parents=True, exist_ok=True) - # Create blank drafting model - drafting_model = BlankHTMFNet(NUM_CELLS) - drafting_path = ARTIFACTS_DIR / "blank_model_drafting.onnx" - print(f"Creating blank drafting model: {drafting_path}") - export_to_onnx(drafting_model, drafting_path) - - # Create blank movement model - movement_model = BlankHTMFNet(MOVEMENT_POLICY_SIZE) - movement_path = ARTIFACTS_DIR / "blank_model_movement.onnx" - print(f"Creating blank movement model: {movement_path}") - export_to_onnx(movement_model, movement_path) + # Create blank model with both policy heads + model = BlankHTMFNet() + model_path = ARTIFACTS_DIR / "blank_model.onnx" + print(f"Creating blank model: {model_path}") + export_to_onnx(model, model_path) print() - print("Blank models created!") + print("Blank model created!") print() - print("To use them, copy to the standard names:") - print(" cp artifacts/blank_model_drafting.onnx artifacts/model_drafting.onnx") - print(" cp artifacts/blank_model_movement.onnx artifacts/model_movement.onnx") + print("To use it, copy to the standard name:") + print(" cp 
artifacts/blank_model.onnx artifacts/model.onnx") print() - print("Or modify nn_vs_mcts.rs to load them from the blank paths.") + print("Or modify nn_vs_mcts.rs to load it from the blank path.") print() - print("NOTE: These blank models output uniform policy (all zeros -> 1/n after softmax)") - print("and neutral value (0 tanh -> 0.5 probability).") + print("NOTE: This blank model outputs uniform policy (all zeros -> 1/n after softmax)") + print("and neutral value (0 tanh -> 0.5 probability) for both drafting and movement.") print() print("The PUCT mode uses random rollouts for leaf evaluation, so the value output") - print("from these models is NOT used. Only the policy priors are used to guide") + print("from this model is NOT used. Only the policy priors are used to guide") print("which moves to explore first.") diff --git a/training/iterate.py b/training/iterate.py index 23f50d8..c0f72a7 100644 --- a/training/iterate.py +++ b/training/iterate.py @@ -24,8 +24,7 @@ ARTIFACTS_DIR = Path("./artifacts") TRAINING_DATA = ARTIFACTS_DIR / "training_data.jsonl" MODEL_FINAL = ARTIFACTS_DIR / "model_final.pt" -ONNX_DRAFTING = ARTIFACTS_DIR / "model_drafting.onnx" -ONNX_MOVEMENT = ARTIFACTS_DIR / "model_movement.onnx" +ONNX_MODEL = ARTIFACTS_DIR / "model.onnx" ITERATIONS_DIR = ARTIFACTS_DIR / "iterations" @@ -154,7 +153,7 @@ def save_iteration(iteration: int): iter_dir = ITERATIONS_DIR / f"iter_{iteration:03d}" iter_dir.mkdir(parents=True, exist_ok=True) - for src in [MODEL_FINAL, ONNX_DRAFTING, ONNX_MOVEMENT]: + for src in [MODEL_FINAL, ONNX_MODEL]: if src.exists(): shutil.copy(src, iter_dir / src.name) @@ -189,7 +188,7 @@ def main(): # Fresh start if requested if args.fresh: print("Starting fresh - removing existing model and data...") - for f in [MODEL_FINAL, ONNX_DRAFTING, ONNX_MOVEMENT, TRAINING_DATA]: + for f in [MODEL_FINAL, ONNX_MODEL, TRAINING_DATA]: if f.exists(): f.unlink() print(f" Removed {f}") diff --git a/training/train.py b/training/train.py index 34c2adc..fee13e7 100644 --- a/training/train.py +++ b/training/train.py @@ -5,7 +5,7 @@ The network has two heads: - Policy head: probability distribution over moves - Drafting: 60 values (one per cell) - - Movement: 168 values (4 penguins × 6 directions × 7 distances) + - Movement: 168 values (4 penguins x 6 directions x 7 distances) - Value head: predicted win probability for current player Usage: @@ -68,6 +68,8 @@ def _build_valid_mask(): ARTIFACTS_DIR = Path("./artifacts") TRAINING_DATA = ARTIFACTS_DIR / "training_data.jsonl" MODEL_CHECKPOINT = ARTIFACTS_DIR / "model_final.pt" +ONNX_MODEL = ARTIFACTS_DIR / "model.onnx" +# Legacy paths for backward compatibility ONNX_DRAFTING = ARTIFACTS_DIR / "model_drafting.onnx" ONNX_MOVEMENT = ARTIFACTS_DIR / "model_movement.onnx" @@ -93,13 +95,24 @@ def __init__(self, data_path: Path): def __len__(self): return len(self.drafting_samples) + len(self.movement_samples) - def get_drafting_data(self) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """Return all drafting data as tensors.""" + def get_drafting_data( + self, + ) -> tuple[ + torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None + ]: + """Return all drafting data as tensors. + + Returns: + (features, policies, values, ownerships, score_diffs) + Ownership and score_diff may be None if not present in training data. 
+ """ if not self.drafting_samples: return ( torch.zeros(0, NUM_FEATURES), torch.zeros(0, NUM_CELLS), torch.zeros(0, 1), + None, + None, ) features = torch.tensor( @@ -111,15 +124,43 @@ def get_drafting_data(self) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: values = torch.tensor( [[s["value"]] for s in self.drafting_samples], dtype=torch.float32 ) - return features, policies, values - def get_movement_data(self) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """Return all movement data as tensors.""" + # Check if auxiliary targets are available + has_ownership = "ownership" in self.drafting_samples[0] + has_score_diff = "score_diff" in self.drafting_samples[0] + + ownerships = None + if has_ownership: + ownerships = torch.tensor( + [s["ownership"] for s in self.drafting_samples], dtype=torch.long + ) + + score_diffs = None + if has_score_diff: + score_diffs = torch.tensor( + [s["score_diff"] for s in self.drafting_samples], dtype=torch.long + ) + + return features, policies, values, ownerships, score_diffs + + def get_movement_data( + self, + ) -> tuple[ + torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None + ]: + """Return all movement data as tensors. + + Returns: + (features, policies, values, ownerships, score_diffs) + Ownership and score_diff may be None if not present in training data. + """ if not self.movement_samples: return ( torch.zeros(0, NUM_FEATURES), torch.zeros(0, MOVEMENT_POLICY_SIZE), torch.zeros(0, 1), + None, + None, ) features = torch.tensor( @@ -131,7 +172,24 @@ def get_movement_data(self) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: values = torch.tensor( [[s["value"]] for s in self.movement_samples], dtype=torch.float32 ) - return features, policies, values + + # Check if auxiliary targets are available + has_ownership = "ownership" in self.movement_samples[0] + has_score_diff = "score_diff" in self.movement_samples[0] + + ownerships = None + if has_ownership: + ownerships = torch.tensor( + [s["ownership"] for s in self.movement_samples], dtype=torch.long + ) + + score_diffs = None + if has_score_diff: + score_diffs = torch.tensor( + [s["score_diff"] for s in self.movement_samples], dtype=torch.long + ) + + return features, policies, values, ownerships, score_diffs def features_to_grid(features: torch.Tensor) -> torch.Tensor: @@ -197,20 +255,11 @@ def forward(self, x): return out -class HTMFNet(nn.Module): - """ - Neural network for HTMF with policy and value heads. 
- - Architecture matches OpenSpiel's AlphaZero ResNet: - - Input: 8 channels × 8×8 grid (60 valid cells embedded in 64) - - Shared trunk: Initial conv + residual blocks - - Policy head: 1×1 conv → BN → ReLU → flatten → FC - - Value head: 1×1 conv → BN → ReLU → flatten → FC → ReLU → FC → tanh - """ +class SharedTrunk(nn.Module): + """Shared convolutional trunk for processing board state.""" - def __init__(self, policy_size: int, num_filters: int = 64, num_blocks: int = 4): + def __init__(self, num_filters: int = 64, num_blocks: int = 4): super().__init__() - self.policy_size = policy_size # Initial convolution self.input_conv = nn.Conv2d(NUM_CHANNELS, num_filters, kernel_size=3, padding=1) @@ -221,17 +270,6 @@ def __init__(self, policy_size: int, num_filters: int = 64, num_blocks: int = 4) [ConvResidualBlock(num_filters) for _ in range(num_blocks)] ) - # Policy head: 1×1 conv to reduce channels, then flatten and FC - self.policy_conv = nn.Conv2d(num_filters, 2, kernel_size=1) - self.policy_bn = nn.BatchNorm2d(2) - self.policy_fc = nn.Linear(2 * GRID_SIZE, policy_size) - - # Value head: 1×1 conv, flatten, FC layers - self.value_conv = nn.Conv2d(num_filters, 1, kernel_size=1) - self.value_bn = nn.BatchNorm2d(1) - self.value_fc1 = nn.Linear(GRID_SIZE, num_filters) - self.value_fc2 = nn.Linear(num_filters, 1) - def forward(self, x): # Convert flat features to grid: (batch, 480) -> (batch, 8, 8, 8) x = features_to_grid(x) @@ -241,18 +279,182 @@ def forward(self, x): for block in self.blocks: x = block(x) - # Policy head + return x + + +class PolicyHead(nn.Module): + """Policy head for outputting move probabilities.""" + + def __init__(self, num_filters: int, policy_size: int): + super().__init__() + self.policy_conv = nn.Conv2d(num_filters, 2, kernel_size=1) + self.policy_bn = nn.BatchNorm2d(2) + self.policy_fc = nn.Linear(2 * GRID_SIZE, policy_size) + + def forward(self, x): policy = F.relu(self.policy_bn(self.policy_conv(x))) policy = policy.view(policy.size(0), -1) # Flatten policy = self.policy_fc(policy) + return policy + + +class ValueHead(nn.Module): + """Value head for predicting win probability.""" + + def __init__(self, num_filters: int): + super().__init__() + self.value_conv = nn.Conv2d(num_filters, 1, kernel_size=1) + self.value_bn = nn.BatchNorm2d(1) + self.value_fc1 = nn.Linear(GRID_SIZE, num_filters) + self.value_fc2 = nn.Linear(num_filters, 1) - # Value head + def forward(self, x): value = F.relu(self.value_bn(self.value_conv(x))) value = value.view(value.size(0), -1) # Flatten value = F.relu(self.value_fc1(value)) value = torch.tanh(self.value_fc2(value)) + return value + + +class OwnershipHead(nn.Module): + """Ownership head for predicting which player owns each cell at game end. + + Following KataGo, this predicts per-cell ownership to provide localized + gradient feedback for credit assignment. 
+ + Output: (batch, 60, 3) where the 3 classes are [player_0, player_1, neither] + """ + + def __init__(self, num_filters: int): + super().__init__() + # Use convolutions to preserve spatial structure + self.ownership_conv = nn.Conv2d(num_filters, 3, kernel_size=1) + self.ownership_bn = nn.BatchNorm2d(3) + + def forward(self, x): + # x is (batch, num_filters, 8, 8) + ownership = self.ownership_bn(self.ownership_conv(x)) # (batch, 3, 8, 8) + + # Extract the 60 valid cells for each of the 3 classes + batch_size = ownership.shape[0] + ownership_cells = torch.zeros( + batch_size, 3, NUM_CELLS, device=ownership.device, dtype=ownership.dtype + ) + + for cell_idx, (row, col) in enumerate(CELL_TO_GRID): + ownership_cells[:, :, cell_idx] = ownership[:, :, row, col] + + # Reshape to (batch, 60, 3) for cross-entropy loss + ownership_cells = ownership_cells.transpose(1, 2) # (batch, 60, 3) + + return ownership_cells + + +class ScoreDifferenceHead(nn.Module): + """Score difference head for predicting final score difference. + + Following KataGo, this predicts a distribution over possible score differences + to provide finer-grained learning signal than binary win/loss. + + Score range: [-92, +92] → 185 possible values (inclusive) + Output: (batch, 185) logits over score distribution + """ + + # Score difference range constants + MIN_SCORE_DIFF = -92 + MAX_SCORE_DIFF = 92 + NUM_SCORE_BINS = MAX_SCORE_DIFF - MIN_SCORE_DIFF + 1 # 185 + + def __init__(self, num_filters: int): + super().__init__() + self.score_conv = nn.Conv2d(num_filters, 1, kernel_size=1) + self.score_bn = nn.BatchNorm2d(1) + self.score_fc1 = nn.Linear(GRID_SIZE, num_filters) + self.score_fc2 = nn.Linear(num_filters, self.NUM_SCORE_BINS) + + def forward(self, x): + score = F.relu(self.score_bn(self.score_conv(x))) + score = score.view(score.size(0), -1) # Flatten + score = F.relu(self.score_fc1(score)) + score = self.score_fc2(score) # (batch, 185) logits + return score + + +class HTMFNet(nn.Module): + """ + Neural network for HTMF with shared trunk and multiple heads. + + Architecture: + - Input: 8 channels x 8x8 grid (60 valid cells embedded in 64) + - Shared trunk: Initial conv + residual blocks (shared between all heads) + - Drafting policy head: outputs 60 cell probabilities + - Movement policy head: outputs 168 move probabilities (4 penguins × 6 dirs × 7 dists) + - Value head: predicts win probability + - Ownership head: predicts per-cell ownership at game end (60 cells × 3 classes) + - Score difference head: predicts final score difference distribution (185 bins) + """ + + def __init__(self, num_filters: int = 64, num_blocks: int = 4): + super().__init__() + + # Shared convolutional trunk + self.trunk = SharedTrunk(num_filters, num_blocks) + + # Policy heads for drafting and movement + self.drafting_policy = PolicyHead(num_filters, NUM_CELLS) + self.movement_policy = PolicyHead(num_filters, MOVEMENT_POLICY_SIZE) + + # Main value head + self.value = ValueHead(num_filters) - return policy, value + # Auxiliary heads (KataGo-style) + self.ownership = OwnershipHead(num_filters) + self.score_diff = ScoreDifferenceHead(num_filters) + + def forward(self, x, is_drafting: bool | None = None): + """Forward pass. 
+ + Args: + x: Input features (batch, 480) + is_drafting: If True, return only drafting policy; if False, return only movement policy; + if None, return both policies (for ONNX export) + + Returns: + - If is_drafting is True/False: (policy, value, ownership, score_diff) tuple + - If is_drafting is None: (drafting_policy, movement_policy, value, ownership, score_diff) tuple + """ + # Shared trunk + trunk_out = self.trunk(x) + + # Auxiliary heads (always computed) + value = self.value(trunk_out) + ownership = self.ownership(trunk_out) + score_diff = self.score_diff(trunk_out) + + # Select appropriate policy head(s) + if is_drafting is None: + # Return both policies (for ONNX export) + drafting_policy = self.drafting_policy(trunk_out) + movement_policy = self.movement_policy(trunk_out) + return drafting_policy, movement_policy, value, ownership, score_diff + elif is_drafting: + policy = self.drafting_policy(trunk_out) + else: + policy = self.movement_policy(trunk_out) + + return policy, value, ownership, score_diff + + +def score_diff_to_index(score_diff: int) -> int: + """Convert score difference to bin index. + + Args: + score_diff: Score difference in range [-92, 92] + + Returns: + Bin index in range [0, 184] + """ + return score_diff - ScoreDifferenceHead.MIN_SCORE_DIFF def train_model( @@ -260,15 +462,22 @@ def train_model( features: torch.Tensor, policies: torch.Tensor, values: torch.Tensor, + ownerships: torch.Tensor | None, + score_diffs: torch.Tensor | None, optimizer: torch.optim.Optimizer, device: torch.device, + is_drafting: bool, batch_size: int = 256, -) -> tuple[float, float]: - """Train the model for one epoch and return (policy_loss, value_loss).""" +) -> tuple[float, float, float, float]: + """Train the model for one epoch. + + Returns: + (policy_loss, value_loss, ownership_loss, score_diff_loss) + """ model.train() if len(features) == 0: - return 0.0, 0.0 + return 0.0, 0.0, 0.0, 0.0 # Shuffle data perm = torch.randperm(len(features)) @@ -277,8 +486,16 @@ def train_model( # Convert values from [0, 1] to [-1, 1] for tanh output values = (values[perm] * 2 - 1).to(device) + # Auxiliary targets (may be None if not available) + if ownerships is not None: + ownerships = ownerships[perm].to(device) + if score_diffs is not None: + score_diffs = score_diffs[perm].to(device) + total_policy_loss = 0.0 total_value_loss = 0.0 + total_ownership_loss = 0.0 + total_score_diff_loss = 0.0 num_batches = 0 for i in range(0, len(features), batch_size): @@ -288,7 +505,9 @@ def train_model( optimizer.zero_grad() - pred_policy, pred_value = model(batch_features) + pred_policy, pred_value, pred_ownership, pred_score_diff = model( + batch_features, is_drafting + ) # Policy loss: cross-entropy with target distribution # Target policies are probability distributions from MCTS @@ -299,33 +518,110 @@ def train_model( # Value loss: MSE value_loss = F.mse_loss(pred_value, batch_values) - # Combined loss - loss = policy_loss + value_loss + # Initialize auxiliary losses + ownership_loss = torch.tensor(0.0, device=device) + score_diff_loss = torch.tensor(0.0, device=device) + + # Ownership loss: per-cell cross-entropy + if ownerships is not None: + batch_ownerships = ownerships[i : i + batch_size] + # pred_ownership: (batch, 60, 3) + # batch_ownerships: (batch, 60) with class indices [0, 1, 2] + ownership_loss = F.cross_entropy( + pred_ownership.reshape(-1, 3), # (batch * 60, 3) + batch_ownerships.reshape(-1).long(), # (batch * 60,) + ) + + # Score difference loss: PDF + CDF loss (KataGo-style) + if 
score_diffs is not None: + batch_score_diffs = score_diffs[i : i + batch_size] + # pred_score_diff: (batch, 185) logits + # batch_score_diffs: (batch,) with bin indices [0, 184] + + # PDF loss: standard cross-entropy with one-hot target + pdf_loss = F.cross_entropy(pred_score_diff, batch_score_diffs.long()) + + # CDF loss: penalize cumulative distribution error + # Compute predicted and target CDFs + pred_probs = F.softmax(pred_score_diff, dim=1) + pred_cdf = torch.cumsum(pred_probs, dim=1) + + # Create target CDF (step function at true score) + target_cdf = torch.zeros_like(pred_cdf) + for j, score_idx in enumerate(batch_score_diffs): + target_cdf[j, int(score_idx) :] = 1.0 + + # MSE between CDFs + cdf_loss = F.mse_loss(pred_cdf, target_cdf) + + # Combine PDF and CDF losses (equal weighting as in KataGo) + score_diff_loss = pdf_loss + cdf_loss + + # Combined loss with auxiliary targets + # Weight auxiliary losses lower to avoid overwhelming main objectives + loss = ( + policy_loss + + value_loss + + 0.5 * ownership_loss + + 0.5 * score_diff_loss + ) loss.backward() optimizer.step() total_policy_loss += policy_loss.item() total_value_loss += value_loss.item() + total_ownership_loss += ownership_loss.item() + total_score_diff_loss += score_diff_loss.item() num_batches += 1 - return total_policy_loss / max(1, num_batches), total_value_loss / max( - 1, num_batches + return ( + total_policy_loss / max(1, num_batches), + total_value_loss / max(1, num_batches), + total_ownership_loss / max(1, num_batches), + total_score_diff_loss / max(1, num_batches), ) def export_to_onnx(model: HTMFNet, path: Path): - """Export model to ONNX format.""" + """Export model to ONNX format with all heads. + + The exported model has: + - Input: features (batch, 480) + - Outputs: + - drafting_policy (batch, 60) + - movement_policy (batch, 168) + - value (batch, 1) + - ownership (batch, 60, 3) - per-cell ownership prediction + - score_diff (batch, 185) - score difference distribution + """ # Move model to CPU for ONNX export (required for compatibility) model = model.cpu() model.eval() dummy_input = torch.zeros(1, NUM_FEATURES) + # Create a wrapper that always outputs all heads + class ONNXWrapper(nn.Module): + def __init__(self, model): + super().__init__() + self.model = model + + def forward(self, x): + return self.model(x, is_drafting=None) + + wrapper = ONNXWrapper(model) + torch.onnx.export( - model, + wrapper, dummy_input, path, input_names=["features"], - output_names=["policy", "value"], + output_names=[ + "drafting_policy", + "movement_policy", + "value", + "ownership", + "score_diff", + ], dynamo=False, # Use legacy exporter for Python 3.14 compatibility ) @@ -368,20 +664,27 @@ def main(): print(f"Loading training data from {TRAINING_DATA}...") dataset = HTMFDataset(TRAINING_DATA) - drafting_features, drafting_policies, drafting_values = dataset.get_drafting_data() - movement_features, movement_policies, movement_values = dataset.get_movement_data() + ( + drafting_features, + drafting_policies, + drafting_values, + drafting_ownerships, + drafting_score_diffs, + ) = dataset.get_drafting_data() + ( + movement_features, + movement_policies, + movement_values, + movement_ownerships, + movement_score_diffs, + ) = dataset.get_movement_data() print(f"Total samples: {len(dataset)}") + if drafting_ownerships is not None or movement_ownerships is not None: + print("Auxiliary targets detected: ownership, score_diff") - # Create models - drafting_model = HTMFNet( - policy_size=NUM_CELLS, num_filters=args.num_filters, 
num_blocks=args.num_blocks
-    ).to(device)
-    movement_model = HTMFNet(
-        policy_size=MOVEMENT_POLICY_SIZE,
-        num_filters=args.num_filters,
-        num_blocks=args.num_blocks,
-    ).to(device)
+    # Create single shared model
+    model = HTMFNet(num_filters=args.num_filters, num_blocks=args.num_blocks).to(device)
 
     # Load existing weights if available
     if MODEL_CHECKPOINT.exists():
@@ -389,13 +692,19 @@ def main():
         checkpoint = torch.load(
             MODEL_CHECKPOINT, map_location=device, weights_only=True
         )
-        drafting_model.load_state_dict(checkpoint["drafting_model"])
-        movement_model.load_state_dict(checkpoint["movement_model"])
-        print("Loaded existing model weights")
+        # Handle both old format (direct state_dict) and new format (dict with "model" key)
+        if "model" in checkpoint:
+            state_dict = checkpoint["model"]
+        else:
+            # Old format: checkpoint is the state_dict directly
+            state_dict = checkpoint
 
-    # Optimizers
-    drafting_optimizer = torch.optim.Adam(drafting_model.parameters(), lr=args.lr)
-    movement_optimizer = torch.optim.Adam(movement_model.parameters(), lr=args.lr)
+        # Load weights, allowing for missing keys (e.g., new auxiliary heads)
+        model.load_state_dict(state_dict, strict=False)
+        print("Loaded existing model weights (auxiliary heads will be randomly initialized if not present)")
+
+    # Single optimizer for the entire model
+    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
 
     print(f"\nTraining for {args.epochs} epochs...")
     print(f"Learning rate: {args.lr}")
@@ -403,52 +712,59 @@ def main():
     print()
 
     for epoch in range(1, args.epochs + 1):
-        # Train drafting model
-        d_policy_loss, d_value_loss = train_model(
-            drafting_model,
+        # Train on drafting data
+        d_policy_loss, d_value_loss, d_ownership_loss, d_score_diff_loss = train_model(
+            model,
             drafting_features,
             drafting_policies,
             drafting_values,
-            drafting_optimizer,
+            drafting_ownerships,
+            drafting_score_diffs,
+            optimizer,
             device,
-            args.batch_size,
+            is_drafting=True,
+            batch_size=args.batch_size,
         )
 
-        # Train movement model
-        m_policy_loss, m_value_loss = train_model(
-            movement_model,
+        # Train on movement data
+        m_policy_loss, m_value_loss, m_ownership_loss, m_score_diff_loss = train_model(
+            model,
             movement_features,
             movement_policies,
             movement_values,
-            movement_optimizer,
+            movement_ownerships,
+            movement_score_diffs,
+            optimizer,
             device,
-            args.batch_size,
+            is_drafting=False,
+            batch_size=args.batch_size,
         )
 
-        print(
-            f"Epoch {epoch:3d}/{args.epochs}: "
-            f"Drafting [P: {d_policy_loss:.4f}, V: {d_value_loss:.4f}] | "
-            f"Movement [P: {m_policy_loss:.4f}, V: {m_value_loss:.4f}]"
-        )
+        # Print losses (only show auxiliary losses if they're non-zero)
+        status = f"Epoch {epoch:3d}/{args.epochs}: "
+        status += f"Drafting [P: {d_policy_loss:.4f}, V: {d_value_loss:.4f}"
+        if d_ownership_loss > 0:
+            status += f", O: {d_ownership_loss:.4f}, S: {d_score_diff_loss:.4f}"
+        status += f"] | Movement [P: {m_policy_loss:.4f}, V: {m_value_loss:.4f}"
+        if m_ownership_loss > 0:
+            status += f", O: {m_ownership_loss:.4f}, S: {m_score_diff_loss:.4f}"
+        status += "]"
+        print(status)
 
     # Save PyTorch checkpoint
     print(f"\nSaving model to {MODEL_CHECKPOINT}...")
     torch.save(
         {
-            "drafting_model": drafting_model.state_dict(),
-            "movement_model": movement_model.state_dict(),
+            "model": model.state_dict(),
             "num_filters": args.num_filters,
             "num_blocks": args.num_blocks,
         },
         MODEL_CHECKPOINT,
     )
 
-    # Export to ONNX
-    print(f"Exporting drafting model to {ONNX_DRAFTING}...")
-    export_to_onnx(drafting_model, ONNX_DRAFTING)
-
-    print(f"Exporting movement model to {ONNX_MOVEMENT}...")
-    export_to_onnx(movement_model, ONNX_MOVEMENT)
+    # Export to ONNX (single file with both policy heads)
+    print(f"Exporting model to {ONNX_MODEL}...")
+    export_to_onnx(model, ONNX_MODEL)
 
     print("\nTraining complete!")
     return 0
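
As a sanity check on the new single-file export, the output ordering that bots/src/neuralnet.rs relies on can be read back from the exported file. A minimal sketch, assuming onnxruntime is installed (it is not a dependency of this repo; the snippet is illustrative only):

```python
# Hypothetical check script; onnxruntime is assumed to be available separately.
import onnxruntime as ort

sess = ort.InferenceSession("training/artifacts/model.onnx")
print([o.name for o in sess.get_outputs()])
# Expected: ['drafting_policy', 'movement_policy', 'value', 'ownership', 'score_diff']
# The Rust loader indexes outputs[0] / outputs[1] for the policies and outputs[2] for value.
```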
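The score-difference target written by selfplay and the PDF + CDF loss in train.py share one binning convention: a score difference d in [-92, 92] maps to bin d + 92, so a 37-30 finish from the current player's perspective lands in bin 99. Below is a minimal standalone sketch of that loss with the per-sample CDF loop replaced by a vectorized comparison; the function names and the example batch are illustrative, not part of the patch:

```python
import torch
import torch.nn.functional as F

MIN_SCORE_DIFF, MAX_SCORE_DIFF = -92, 92
NUM_SCORE_BINS = MAX_SCORE_DIFF - MIN_SCORE_DIFF + 1  # 185

def score_diff_to_index(score_diff: int) -> int:
    # e.g. score_diff = +7 -> bin 99, score_diff = -92 -> bin 0
    return score_diff - MIN_SCORE_DIFF

def score_diff_loss(pred_logits: torch.Tensor, target_bins: torch.Tensor) -> torch.Tensor:
    """PDF + CDF loss over the 185-bin score distribution (KataGo-style)."""
    # PDF term: cross-entropy against the one-hot target bin
    pdf_loss = F.cross_entropy(pred_logits, target_bins)
    # CDF term: MSE between cumulative distributions; the target CDF is a
    # step function that jumps from 0 to 1 at the true bin
    pred_cdf = torch.cumsum(F.softmax(pred_logits, dim=1), dim=1)
    bins = torch.arange(pred_logits.shape[1], device=pred_logits.device)
    target_cdf = (bins.unsqueeze(0) >= target_bins.unsqueeze(1)).float()
    return pdf_loss + F.mse_loss(pred_cdf, target_cdf)

# Example: a batch of two positions that ended +7 and -3 for the current player
targets = torch.tensor([score_diff_to_index(7), score_diff_to_index(-3)])
logits = torch.zeros(2, NUM_SCORE_BINS)  # what a blank model would produce
print(score_diff_loss(logits, targets))
```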
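The ownership target is a length-60 vector with one class per cell (0 = player 0, 1 = player 1, 2 = unclaimed), and OwnershipHead emits (batch, 60, 3) logits, so the loss reduces to a cross-entropy over batch x 60 cells. A small sketch of that pairing, with made-up targets and the uniform logits a blank model would produce:

```python
import torch
import torch.nn.functional as F

NUM_CELLS = 60

def ownership_loss(pred_ownership: torch.Tensor, target_ownership: torch.Tensor) -> torch.Tensor:
    """Per-cell cross-entropy between (batch, 60, 3) logits and (batch, 60) class targets."""
    return F.cross_entropy(
        pred_ownership.reshape(-1, 3),        # (batch * 60, 3)
        target_ownership.reshape(-1).long(),  # (batch * 60,)
    )

# One sample: player 0 ended up claiming cells 0-2, player 1 claimed cell 3,
# and every other cell was still unclaimed when the game ended.
target = torch.full((1, NUM_CELLS), 2, dtype=torch.long)
target[0, 0:3] = 0
target[0, 3] = 1
logits = torch.zeros(1, NUM_CELLS, 3)  # uniform logits, as from a blank model
print(ownership_loss(logits, target))  # ln(3) ~= 1.0986
```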