34 changes: 31 additions & 3 deletions Justfile
@@ -70,6 +70,34 @@ install_training:
@command -v uv > /dev/null || (echo "uv not found. Please install from https://docs.astral.sh/uv/" && exit 1)
cd training && uv sync

# Run one iteration of selfplay and training, promoting the model if it improves
train: install_cargo install_training
bun run scripts/train.ts
# Generate selfplay training data (traditional MCTS, high quality for bootstrapping)
selfplay GAMES="100" PLAYOUTS="20000": install_cargo
mkdir -p training/artifacts
cargo run --release -p selfplay -- {{GAMES}} {{PLAYOUTS}} > training/artifacts/training_data.jsonl
@echo "Generated training data with auxiliary targets (ownership, score_diff)"

# Generate selfplay training data using neural network guidance (faster, more games)
selfplay_nn GAMES="200" PLAYOUTS="800": install_cargo
mkdir -p training/artifacts
cargo run --release -p selfplay -- {{GAMES}} {{PLAYOUTS}} --nn >> training/artifacts/training_data.jsonl
@echo "Generated NN-guided training data with auxiliary targets"

# Train the model on existing data
train_only EPOCHS="20": install_training
cd training && uv run train.py --epochs {{EPOCHS}}

# Run one iteration: selfplay + training (traditional MCTS for bootstrapping)
train GAMES="100" PLAYOUTS="20000" EPOCHS="20": install_cargo install_training
@echo "Running selfplay with {{GAMES}} games, {{PLAYOUTS}} playouts..."
just selfplay {{GAMES}} {{PLAYOUTS}}
@echo "Training for {{EPOCHS}} epochs..."
just train_only {{EPOCHS}}
@echo "Training iteration complete!"

# Run iterative training loop (AlphaZero-style with NN-guided selfplay)
iterate ITERATIONS="20" GAMES="200" PLAYOUTS="1000" EPOCHS="20": install_cargo install_training
cd training && uv run iterate.py --iterations {{ITERATIONS}} --games {{GAMES}} --playouts {{PLAYOUTS}} --epochs {{EPOCHS}}

# Create blank models for debugging
blank_models: install_training
cd training && uv run create_blank_models.py
5 changes: 1 addition & 4 deletions bots/src/bin/debug_modes.rs
@@ -146,10 +146,7 @@ fn main() {
);

// Test 2: Try to load NN and compare
match NeuralNet::load(
"training/artifacts/model_drafting.onnx",
"training/artifacts/model_movement.onnx",
) {
match NeuralNet::load("training/artifacts/model.onnx") {
Ok(nn) => {
let nn = Arc::new(nn);

5 changes: 1 addition & 4 deletions bots/src/bin/nn_vs_mcts.rs
@@ -41,10 +41,7 @@ fn main() {
None
} else {
eprintln!("Loading neural network...");
match NeuralNet::load(
"training/artifacts/model_drafting.onnx",
"training/artifacts/model_movement.onnx",
) {
match NeuralNet::load("training/artifacts/model.onnx") {
Ok(model) => {
eprintln!("Neural network loaded successfully");
Some(Arc::new(model))
7 changes: 2 additions & 5 deletions bots/src/mctsbot.rs
@@ -829,11 +829,8 @@ fn test_neural_network_guided_game() {

// Load neural network
let nn = Arc::new(
NeuralNet::load(
"../training/artifacts/model_drafting.onnx",
"../training/artifacts/model_movement.onnx",
)
.expect("Failed to load neural network"),
NeuralNet::load("../training/artifacts/model.onnx")
.expect("Failed to load neural network"),
);

let mut game = GameState::new_two_player::<StdRng>(&mut SeedableRng::seed_from_u64(42));
50 changes: 25 additions & 25 deletions bots/src/neuralnet.rs
@@ -7,8 +7,7 @@ type Model = SimplePlan<TypedFact, Box<dyn TypedOp>, Graph<TypedFact, Box<dyn Ty

/// Neural network wrapper for policy and value prediction
pub struct NeuralNet {
drafting_model: Model,
movement_model: Model,
model: Model,
}

/// Output from neural network inference
@@ -20,24 +19,28 @@ pub struct NeuralNetOutput {
}

impl NeuralNet {
/// Load ONNX models from the given paths
pub fn load(drafting_path: &str, movement_path: &str) -> TractResult<Self> {
let drafting_model = tract_onnx::onnx()
.model_for_path(drafting_path)?
/// Load ONNX model from the given path
///
/// The model should have:
/// - Input: features (1, 480)
/// - Outputs: drafting_policy (1, 60), movement_policy (1, 168), value (1, 1)
pub fn load(model_path: &str) -> TractResult<Self> {
let model = tract_onnx::onnx()
.model_for_path(model_path)?
.with_input_fact(0, f32::fact([1, NUM_FEATURES]).into())?
.into_optimized()?
.into_runnable()?;

let movement_model = tract_onnx::onnx()
.model_for_path(movement_path)?
.with_input_fact(0, f32::fact([1, NUM_FEATURES]).into())?
.into_optimized()?
.into_runnable()?;
Ok(Self { model })
}

Ok(Self {
drafting_model,
movement_model,
})
/// Legacy method for backward compatibility with the old two-file approach.
///
/// This loads the new single-file model but ignores the `movement_path` parameter.
/// Use `load()` instead for new code.
#[deprecated(note = "Use load() instead - only one model file is needed now")]
pub fn load_legacy(drafting_path: &str, _movement_path: &str) -> TractResult<Self> {
Self::load(drafting_path)
}

/// Run inference on the given game state
@@ -52,22 +55,19 @@ impl NeuralNet {
let input: Tensor = tract_ndarray::Array2::from_shape_vec((1, NUM_FEATURES), features)?
.into();

let model = if is_drafting {
&self.drafting_model
} else {
&self.movement_model
};

let outputs = model.run(tvec!(input.into()))?;
let outputs = self.model.run(tvec!(input.into()))?;

// Output 0: policy logits, Output 1: value (tanh output in [-1, 1])
let policy_logits: Vec<f32> = outputs[0]
// Model outputs: [0] drafting_policy, [1] movement_policy, [2] value
// Select the appropriate policy based on game phase
let policy_output_idx = if is_drafting { 0 } else { 1 };
let policy_logits: Vec<f32> = outputs[policy_output_idx]
.to_array_view::<f32>()?
.iter()
.copied()
.collect();

// Convert tanh output [-1, 1] to probability [0, 1]
let raw_value: f32 = outputs[1].to_array_view::<f32>()?[[0, 0]];
let raw_value: f32 = outputs[2].to_array_view::<f32>()?[[0, 0]];
let value = (raw_value + 1.0) / 2.0;

Ok(NeuralNetOutput {
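For reviewers, a minimal sketch of the call-site pattern after this change. The `bots::neuralnet` import path is an assumption for the sketch; `NeuralNet::load` and the output layout come from the diff above:

```rust
use std::sync::Arc;

use bots::neuralnet::NeuralNet; // crate path assumed for this sketch
use tract_onnx::prelude::TractResult;

/// Load the consolidated model once and share it across threads.
fn load_model() -> TractResult<Arc<NeuralNet>> {
    // A single ONNX file replaces model_drafting.onnx / model_movement.onnx.
    // Output 0 = drafting_policy (60 logits), output 1 = movement_policy
    // (168 logits), output 2 = value in [-1, 1], rescaled to [0, 1] as (v + 1) / 2.
    Ok(Arc::new(NeuralNet::load("training/artifacts/model.onnx")?))
}
```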
Binary file added papers/katago.pdf
48 changes: 38 additions & 10 deletions selfplay/src/main.rs
@@ -52,6 +52,16 @@ pub struct TrainingSample {
pub player: usize,
/// Whether this is drafting phase
pub is_drafting: bool,
/// Ownership prediction target: which player owns each cell at game end
/// Array of 60 values, each one of {0, 1, 2}:
/// 0 = player 0 claimed this cell
/// 1 = player 1 claimed this cell
/// 2 = neither player claimed this cell
pub ownership: Vec<u8>,
/// Score difference prediction target (from the current player's perspective)
/// Stored as a bin index in [0, 184], where:
/// bin_index = score_diff + 92, and score_diff = player_score - opponent_score
pub score_diff: u8,
}

fn main() {
@@ -73,10 +83,7 @@
// Load neural network if requested
let nn: Option<Arc<NeuralNet>> = if use_nn {
eprintln!("Loading neural network...");
match NeuralNet::load(
"training/artifacts/model_drafting.onnx",
"training/artifacts/model_movement.onnx",
) {
match NeuralNet::load("training/artifacts/model.onnx") {
Ok(model) => {
eprintln!("Neural network loaded successfully");
Some(Arc::new(model))
@@ -404,15 +411,36 @@ fn play_game(nplayouts: usize, nn: Option<Arc<NeuralNet>>) -> GameResult {
[0.0, 1.0]
};

// Compute ownership targets from final game state
// Each cell is owned by the player who claimed it (or 2 if unclaimed)
let mut ownership = vec![2u8; NUM_CELLS]; // Default: unclaimed
for cell in game.board.claimed[0].into_iter() {
ownership[cell as usize] = 0;
}
for cell in game.board.claimed[1].into_iter() {
ownership[cell as usize] = 1;
}

// Compute score difference targets (per-player perspective)
let score_diffs: [i32; 2] = [
scores[0] as i32 - scores[1] as i32,
scores[1] as i32 - scores[0] as i32,
];

// Convert pending samples to final training samples
let samples: Vec<TrainingSample> = pending_samples
.into_iter()
.map(|s| TrainingSample {
features: s.features,
policy: s.policy,
value: values[s.player],
player: s.player,
is_drafting: s.is_drafting,
.map(|s| {
let score_diff_bin = (score_diffs[s.player] + 92) as u8;
TrainingSample {
features: s.features,
policy: s.policy,
value: values[s.player],
player: s.player,
is_drafting: s.is_drafting,
ownership: ownership.clone(),
score_diff: score_diff_bin,
}
})
.collect();

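The binning above is compact enough to restate on its own. A hedged sketch of the encode/decode pair; the clamp is an extra guard added for the sketch, not present in play_game, which assumes the difference already lies within [-92, 92]:

```rust
/// Encode (player_score - opponent_score) as a bin index in [0, 184].
fn encode_score_diff(player_score: i32, opponent_score: i32) -> u8 {
    let diff = (player_score - opponent_score).clamp(-92, 92);
    (diff + 92) as u8 // fits in u8, since 184 < 256
}

/// Inverse mapping: bin index back to a signed score difference.
fn decode_score_diff(bin: u8) -> i32 {
    bin as i32 - 92
}

#[test]
fn score_diff_roundtrip() {
    // e.g. 30 - 18 = +12 maps to bin 104, and back.
    assert_eq!(encode_score_diff(30, 18), 104);
    assert_eq!(decode_score_diff(104), 12);
}
```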
69 changes: 40 additions & 29 deletions training/create_blank_models.py
@@ -48,28 +48,37 @@ class BlankHTMFNet(nn.Module):
Blank neural network that outputs uniform policy and neutral value.

The model ignores the input and always outputs:
- Policy: all zeros (softmax → uniform distribution)
- Drafting policy: all zeros (softmax → uniform distribution)
- Movement policy: all zeros (softmax → uniform distribution)
- Value: 0 (tanh → neutral, converts to 0.5 in Rust)
- Ownership: uniform distribution over 3 classes
- Score diff: uniform distribution over 185 bins
"""

def __init__(self, policy_size: int):
def __init__(self):
super().__init__()
self.policy_size = policy_size
# This parameter exists to give ONNX export something to work with,
# but forward() ignores it
self.dummy_param = nn.Parameter(torch.zeros(1))

def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
def forward(
self, x: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
batch_size = x.shape[0]
# Always output zeros for policy (softmax will make this uniform)
policy = torch.zeros(batch_size, self.policy_size)
# Always output zeros for both policies (softmax will make these uniform)
drafting_policy = torch.zeros(batch_size, NUM_CELLS)
movement_policy = torch.zeros(batch_size, MOVEMENT_POLICY_SIZE)
# Always output zero for value (tanh=0 → neutral → 0.5 in Rust)
value = torch.zeros(batch_size, 1)
return policy, value
# Ownership: uniform over 3 classes for each cell
ownership = torch.zeros(batch_size, NUM_CELLS, 3)
# Score diff: uniform over 185 bins (range -92 to +92)
score_diff = torch.zeros(batch_size, 185)
return drafting_policy, movement_policy, value, ownership, score_diff


def export_to_onnx(model: nn.Module, path: Path):
"""Export model to ONNX format."""
"""Export model to ONNX format with all heads."""
model.eval()
dummy_input = torch.zeros(1, NUM_FEATURES)

@@ -78,11 +87,20 @@ def export_to_onnx(model: nn.Module, path: Path):
dummy_input,
path,
input_names=["features"],
output_names=["policy", "value"],
output_names=[
"drafting_policy",
"movement_policy",
"value",
"ownership",
"score_diff",
],
dynamic_axes={
"features": {0: "batch_size"},
"policy": {0: "batch_size"},
"drafting_policy": {0: "batch_size"},
"movement_policy": {0: "batch_size"},
"value": {0: "batch_size"},
"ownership": {0: "batch_size"},
"score_diff": {0: "batch_size"},
},
opset_version=17,
dynamo=False,
Expand All @@ -92,32 +110,25 @@ def export_to_onnx(model: nn.Module, path: Path):
def main():
ARTIFACTS_DIR.mkdir(parents=True, exist_ok=True)

# Create blank drafting model
drafting_model = BlankHTMFNet(NUM_CELLS)
drafting_path = ARTIFACTS_DIR / "blank_model_drafting.onnx"
print(f"Creating blank drafting model: {drafting_path}")
export_to_onnx(drafting_model, drafting_path)

# Create blank movement model
movement_model = BlankHTMFNet(MOVEMENT_POLICY_SIZE)
movement_path = ARTIFACTS_DIR / "blank_model_movement.onnx"
print(f"Creating blank movement model: {movement_path}")
export_to_onnx(movement_model, movement_path)
# Create blank model with both policy heads
model = BlankHTMFNet()
model_path = ARTIFACTS_DIR / "blank_model.onnx"
print(f"Creating blank model: {model_path}")
export_to_onnx(model, model_path)

print()
print("Blank models created!")
print("Blank model created!")
print()
print("To use them, copy to the standard names:")
print(" cp artifacts/blank_model_drafting.onnx artifacts/model_drafting.onnx")
print(" cp artifacts/blank_model_movement.onnx artifacts/model_movement.onnx")
print("To use it, copy to the standard name:")
print(" cp artifacts/blank_model.onnx artifacts/model.onnx")
print()
print("Or modify nn_vs_mcts.rs to load them from the blank paths.")
print("Or modify nn_vs_mcts.rs to load it from the blank path.")
print()
print("NOTE: These blank models output uniform policy (all zeros -> 1/n after softmax)")
print("and neutral value (0 tanh -> 0.5 probability).")
print("NOTE: This blank model outputs uniform policy (all zeros -> 1/n after softmax)")
print("and neutral value (0 tanh -> 0.5 probability) for both drafting and movement.")
print()
print("The PUCT mode uses random rollouts for leaf evaluation, so the value output")
print("from these models is NOT used. Only the policy priors are used to guide")
print("from this model is NOT used. Only the policy priors are used to guide")
print("which moves to explore first.")


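One way a consumer might decode the new score_diff head, as a hedged sketch only: nothing in this PR reads the head yet (the Rust engine above consumes only the policy and value outputs). Softmax the 185 logits, then take the expectation over bins, where bin i corresponds to a difference of i - 92:

```rust
/// Expected score difference implied by the score_diff head's logits.
/// Sketch under the assumptions stated above; not part of this PR's code.
fn expected_score_diff(logits: &[f32]) -> f32 {
    assert_eq!(logits.len(), 185, "score_diff head has 185 bins");
    // Numerically stable softmax.
    let max = logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = logits.iter().map(|&l| (l - max).exp()).collect();
    let total: f32 = exps.iter().sum();
    // Expectation over bins; bin i stands for a score difference of i - 92.
    exps.iter()
        .enumerate()
        .map(|(i, e)| (i as f32 - 92.0) * (e / total))
        .sum()
}
```

With the blank model's all-zero logits this evaluates to exactly 0, consistent with the uniform distribution the script's NOTE describes.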
7 changes: 3 additions & 4 deletions training/iterate.py
@@ -24,8 +24,7 @@
ARTIFACTS_DIR = Path("./artifacts")
TRAINING_DATA = ARTIFACTS_DIR / "training_data.jsonl"
MODEL_FINAL = ARTIFACTS_DIR / "model_final.pt"
ONNX_DRAFTING = ARTIFACTS_DIR / "model_drafting.onnx"
ONNX_MOVEMENT = ARTIFACTS_DIR / "model_movement.onnx"
ONNX_MODEL = ARTIFACTS_DIR / "model.onnx"
ITERATIONS_DIR = ARTIFACTS_DIR / "iterations"


@@ -154,7 +153,7 @@ def save_iteration(iteration: int):
iter_dir = ITERATIONS_DIR / f"iter_{iteration:03d}"
iter_dir.mkdir(parents=True, exist_ok=True)

for src in [MODEL_FINAL, ONNX_DRAFTING, ONNX_MOVEMENT]:
for src in [MODEL_FINAL, ONNX_MODEL]:
if src.exists():
shutil.copy(src, iter_dir / src.name)

@@ -189,7 +188,7 @@ def main():
# Fresh start if requested
if args.fresh:
print("Starting fresh - removing existing model and data...")
for f in [MODEL_FINAL, ONNX_DRAFTING, ONNX_MOVEMENT, TRAINING_DATA]:
for f in [MODEL_FINAL, ONNX_MODEL, TRAINING_DATA]:
if f.exists():
f.unlink()
print(f" Removed {f}")