diff --git a/.gitignore b/.gitignore index f9082380e..5e1dbc6eb 100644 --- a/.gitignore +++ b/.gitignore @@ -162,3 +162,4 @@ pufferlib/ocean/impulse_wars/*-release/ pufferlib/ocean/impulse_wars/debug-*/ pufferlib/ocean/impulse_wars/release-*/ pufferlib/ocean/impulse_wars/benchmark/ +pufferlib/ocean/dogfight/dogfight_test diff --git a/pufferlib/checkpoint_queue.py b/pufferlib/checkpoint_queue.py new file mode 100644 index 000000000..e910c1aa8 --- /dev/null +++ b/pufferlib/checkpoint_queue.py @@ -0,0 +1,191 @@ +"""Checkpoint Queue for Self-Play Training. + +Manages a queue of policy checkpoints where the opponent is always N checkpoints +behind the learner. This creates a stable skill gap and natural curriculum. + +Training Flow: + Stages 0-9: Autopilot opponent (curriculum) + Stage 10: Save checkpoint A (milestone) + Stages 10-19: Continue curriculum with autopilot + Stage 20: Save checkpoint B (milestone), START SELF-PLAY vs A + Dominate: Save checkpoint C, upgrade opponent to B + Dominate: Save checkpoint D, upgrade opponent to C + ...and so on (opponent always `lag` checkpoints behind) + +Lag Semantics: + lag=1 means "2nd newest" (skip 1 checkpoint): + - Queue: [A, B, C] with lag=1 -> opponent uses B (index -2) + - Queue: [A, B, C, D] with lag=1 -> opponent uses C (index -2) +""" +import os +import shutil +from dataclasses import dataclass, field +from typing import List, Optional +import torch + + +@dataclass +class QueueEntry: + """A checkpoint in the queue.""" + path: str # Checkpoint file path + step: int # Global step when saved + stage: float # Curriculum stage when saved + tag: str # "stage10", "stage20", or "selfplay_N" + + def is_milestone(self) -> bool: + """Return True if this is a milestone checkpoint (stage10/stage20).""" + return self.tag in ("stage10", "stage20") + + +class CheckpointQueue: + """Manages checkpoint queue for self-play training. + + Checkpoints are saved when the learner dominates the opponent (exceeds + perf_threshold). The opponent is loaded from an older checkpoint in the + queue (determined by lag parameter). + + Milestone checkpoints (stage10, stage20) are never pruned. + """ + + def __init__(self, save_dir: str, max_checkpoints: int = 20): + """Initialize checkpoint queue. + + Args: + save_dir: Directory to store checkpoint files + max_checkpoints: Maximum selfplay checkpoints to keep (milestones always kept) + """ + self.save_dir = save_dir + self.max_checkpoints = max_checkpoints + self.checkpoints: List[QueueEntry] = [] + + # Create save directory if needed + os.makedirs(save_dir, exist_ok=True) + + print(f'[CHECKPOINT-QUEUE] Initialized: save_dir={save_dir}, max={max_checkpoints}') + + def save(self, policy, step: int, stage: float, tag: str) -> str: + """Save checkpoint and add to queue. 
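+
+        Illustrative usage (a sketch, not part of the original docstring;
+        assumes ``policy`` is a torch.nn.Module):
+
+            queue = CheckpointQueue('experiments/dogfight_pool')
+            queue.save(policy, step=20_000_000, stage=20.0, tag='stage20')
+            opponent = queue.get_opponent(lag=1)  # None until 2+ checkpoints exist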
+ + Args: + policy: PyTorch policy module to save + step: Current global step + stage: Current curriculum stage + tag: Checkpoint tag ("stage10", "stage20", or "selfplay_N") + + Returns: + Path to saved checkpoint file + """ + # Generate filename + filename = f"checkpoint_{tag}_step{step}.pt" + path = os.path.join(self.save_dir, filename) + + # Save checkpoint + torch.save({ + 'policy_state_dict': policy.state_dict(), + 'step': step, + 'stage': stage, + 'tag': tag, + }, path) + + # Add to queue + entry = QueueEntry(path=path, step=step, stage=stage, tag=tag) + self.checkpoints.append(entry) + + print(f'[CHECKPOINT-QUEUE] Saved {tag} at step {step}: {path}') + + # Prune old checkpoints if needed + self._prune_old_checkpoints() + + return path + + def get_opponent(self, lag: int = 1) -> Optional[str]: + """Get checkpoint path for opponent. + + Args: + lag: How many positions behind the latest (1=2nd newest, index -2) + + Returns: + Path to opponent checkpoint, or None if queue too small + """ + if len(self.checkpoints) < lag + 1: + return None + + # lag=1 means index -2 (2nd newest) + index = -(lag + 1) + return self.checkpoints[index].path + + def get_opponent_entry(self, lag: int = 1) -> Optional[QueueEntry]: + """Get full QueueEntry for opponent. + + Args: + lag: How many positions behind the latest (1=2nd newest, index -2) + + Returns: + QueueEntry for opponent, or None if queue too small + """ + if len(self.checkpoints) < lag + 1: + return None + + index = -(lag + 1) + return self.checkpoints[index] + + def should_upgrade(self, current_opponent_path: Optional[str], lag: int) -> Optional[str]: + """Check if opponent should be upgraded to newer checkpoint. + + Args: + current_opponent_path: Path to current opponent checkpoint + lag: Desired lag positions behind latest + + Returns: + New opponent path if upgrade needed, None otherwise + """ + new_path = self.get_opponent(lag) + + if new_path is None: + return None + + if new_path != current_opponent_path: + return new_path + + return None + + def _prune_old_checkpoints(self): + """Remove oldest selfplay checkpoints, keeping milestones forever.""" + # Count selfplay checkpoints (not milestones) + selfplay_checkpoints = [c for c in self.checkpoints if not c.is_milestone()] + + # Reserve 2 slots for milestones (stage10, stage20) + max_selfplay = self.max_checkpoints - 2 + + while len(selfplay_checkpoints) > max_selfplay: + # Find oldest selfplay checkpoint + oldest = selfplay_checkpoints.pop(0) + + # Remove file + if os.path.exists(oldest.path): + try: + os.remove(oldest.path) + print(f'[CHECKPOINT-QUEUE] Pruned old checkpoint: {oldest.path}') + except OSError as e: + print(f'[CHECKPOINT-QUEUE] Warning: Could not remove {oldest.path}: {e}') + + # Remove from main list + self.checkpoints.remove(oldest) + + def __len__(self) -> int: + """Return number of checkpoints in queue.""" + return len(self.checkpoints) + + def __repr__(self) -> str: + """Return string representation of queue.""" + tags = [c.tag for c in self.checkpoints] + return f"CheckpointQueue({tags})" + + def get_queue_state(self) -> dict: + """Get serializable state of the queue for logging/debugging.""" + return { + 'num_checkpoints': len(self.checkpoints), + 'tags': [c.tag for c in self.checkpoints], + 'steps': [c.step for c in self.checkpoints], + 'milestones': [c.tag for c in self.checkpoints if c.is_milestone()], + } diff --git a/pufferlib/config/default.ini b/pufferlib/config/default.ini index 6073c651e..ec75c61b4 100644 --- a/pufferlib/config/default.ini +++ 
b/pufferlib/config/default.ini @@ -18,7 +18,7 @@ seed = 42 [rnn] [train] -name = pufferai +name = pufferai project = ablations seed = 42 @@ -28,40 +28,40 @@ device = cuda optimizer = muon anneal_lr = True precision = float32 -total_timesteps = 10_000_000 -learning_rate = 0.015 -gamma = 0.995 -gae_lambda = 0.90 -update_epochs = 1 -clip_coef = 0.2 -vf_coef = 2.0 -vf_clip_coef = 0.2 -max_grad_norm = 1.5 -ent_coef = 0.001 -adam_beta1 = 0.95 -adam_beta2 = 0.999 -adam_eps = 1e-12 +total_timesteps = 400_000_000 +learning_rate = 0.0003812 +gamma = 0.9903 +gae_lambda = 0.9934 +update_epochs = 4 +clip_coef = 0.2576 +vf_coef = 4.034 +vf_clip_coef = 4.663 +max_grad_norm = 1.501 +ent_coef = 0.008355 +adam_beta1 = 0.8453 +adam_beta2 = 1 +adam_eps = 2.72e-05 data_dir = experiments checkpoint_interval = 200 batch_size = auto -minibatch_size = 8192 +minibatch_size = 32768 # Accumulate gradients above this size -max_minibatch_size = 32768 +max_minibatch_size = 65536 bptt_horizon = 64 compile = False compile_mode = max-autotune-no-cudagraphs compile_fullgraph = True -vtrace_rho_clip = 1.0 -vtrace_c_clip = 1.0 +vtrace_rho_clip = 2.91 +vtrace_c_clip = 3.085 -prio_alpha = 0.8 -prio_beta0 = 0.2 +prio_alpha = 0.9724 +prio_beta0 = 0.6139 [sweep] -method = Protein +method = Protein metric = score goal = maximize downsample = 5 @@ -75,26 +75,11 @@ prune_pareto = True #mean = 8 #scale = auto -# TODO: Elim from base -[sweep.train.total_timesteps] -distribution = log_normal -min = 3e7 -max = 1e10 -mean = 2e8 -scale = time - -[sweep.train.bptt_horizon] -distribution = uniform_pow2 -min = 16 -max = 64 -mean = 64 -scale = auto - [sweep.train.minibatch_size] distribution = uniform_pow2 -min = 8192 +min = 32768 max = 65536 -mean = 32768 +mean = 65536 scale = auto [sweep.train.learning_rate] @@ -115,7 +100,7 @@ scale = auto distribution = logit_normal min = 0.8 mean = 0.98 -max = 0.9999 +max = 0.995 scale = auto [sweep.train.gae_lambda] @@ -192,8 +177,8 @@ scale = auto [sweep.train.adam_eps] distribution = log_normal -min = 1e-14 -mean = 1e-8 +min = 1e-8 +mean = 1e-6 max = 1e-4 scale = auto diff --git a/pufferlib/config/ocean/dogfight.ini b/pufferlib/config/ocean/dogfight.ini new file mode 100644 index 000000000..dcc6c5ebe --- /dev/null +++ b/pufferlib/config/ocean/dogfight.ini @@ -0,0 +1,201 @@ +[base] +env_name = puffer_dogfight +package = ocean +policy_name = Policy +rnn_name = Recurrent + +[vec] +num_envs = 8 + +[env] +reward_aim_scale = 0.004452 +reward_closing_scale = 0.0001633 +penalty_neg_g = 0.02624 +speed_min = 50 +control_rate_penalty = 0.001 + +max_steps = 1429 +num_envs = 1024 +obs_scheme = 1 + +curriculum_enabled = 1 +curriculum_randomize = 0 +eval_spawn_mode = 0 +fixed_stage = -1 +max_stage = 19 +eval_interval = 415340 +warmup_steps = 1423645 +min_eval_episodes = 50 +total_timesteps = 200_000_000 +num_workers = 8 +finalize_margin = 20_000_000 + +# Self-play pool settings (Phase 2 - pool-based) +# policy_pool_dir = experiments/dogfight_pool # Set via training script +opponent_selection = skill_match +opponent_swap_interval = 500000 + +# Dual self-play settings (Phase 3 - train_dual_selfplay.py) +# selfplay_min_stage: Only enable dual self-play after reaching this stage +# checkpoint_lag: Opponent is N checkpoints behind latest (1=2nd newest) +# perf_threshold: Kill rate to trigger checkpoint save + opponent upgrade +# min_checkpoint_gap: Minimum steps before saving new checkpoint +# max_checkpoints: Max selfplay checkpoints to keep (milestones always kept) +# These are passed via command line to 
train_dual_selfplay.py +# Default: selfplay_min_stage = 20 +# Default: checkpoint_lag = 1 +# Default: perf_threshold = 0.65 +# Default: min_checkpoint_gap = 2_000_000 +# Default: max_checkpoints = 20 + +[train] +adam_beta1 = 0.8453 +adam_beta2 = 1 +adam_eps = 2.72e-05 +batch_size = auto +bptt_horizon = 64 +checkpoint_interval = 200 +clip_coef = 0.2576 +ent_coef = 0.008355 +gae_lambda = 0.9934 +gamma = 0.9903 +learning_rate = 0.0003812 +max_grad_norm = 1.501 +max_minibatch_size = 65536 +minibatch_size = 32768 +prio_alpha = 0.9724 +prio_beta0 = 0.6139 +seed = 42 +total_timesteps = 200_000_000 +update_epochs = 4 +vf_clip_coef = 4.663 +vf_coef = 4.034 +vtrace_c_clip = 3.085 +vtrace_rho_clip = 2.91 + +[sweep] +downsample = 1 +goal = maximize +method = Protein +metric = ultimate +prune_pareto = True +use_gpu = True + +[sweep.train.total_timesteps] +distribution = uniform +min = 200_000_000 +mean = 200_000_001 +max = 200_000_002 +scale = auto + +[sweep.env.reward_aim_scale] +distribution = uniform +min = 0.0001 +max = 0.01 +mean = 0.005 +scale = auto + +[sweep.env.reward_closing_scale] +distribution = uniform +min = 0.0001 +max = 0.003 +mean = 0.001 +scale = auto + +[sweep.env.penalty_neg_g] +distribution = uniform +min = 0.01 +max = 0.05 +mean = 0.02 +scale = auto + +[sweep.env.control_rate_penalty] +distribution = log_normal +min = 0.0001 +max = 0.01 +mean = 0.001 +scale = auto + +[sweep.env.obs_scheme] +distribution = int_uniform +min = 0 +max = 8 +mean = 8 +scale = 1.0 + +[sweep.env.max_steps] +distribution = int_uniform +min = 300 +max = 1500 +mean = 1200 +scale = 1.0 + +[sweep.env.eval_interval] +distribution = int_uniform +min = 75_000 +max = 1_000_000 +mean = 100_000 +scale = 1.0 + +[sweep.env.warmup_steps] +distribution = int_uniform +min = 1_000_000 +max = 3_000_000 +mean = 2_000_000 +scale = 1.0 + +[sweep.train.adam_eps] +distribution = log_normal +min = 1e-8 +mean = 1e-6 +max = 1e-4 +scale = auto + +[sweep.train.learning_rate] +distribution = log_normal +max = 0.0005 +mean = 0.00025 +min = 0.0001 +scale = 0.5 + +[sweep.train.vf_coef] +distribution = uniform +min = 1.0 +max = 5.0 +mean = 3.0 +scale = auto + +[sweep.train.clip_coef] +distribution = uniform +min = 0.15 +max = 1.0 +mean = 0.35 +scale = auto + +[sweep.train.ent_coef] +distribution = log_normal +min = 0.002 +max = 0.02 +mean = 0.008 +scale = 0.5 + +[sweep.train.max_grad_norm] +distribution = uniform +min = 0.5 +max = 2.0 +mean = 1.0 +scale = auto + +[sweep.train.gae_lambda] +distribution = logit_normal +min = 0.9 +max = 0.999 +mean = 0.97 +scale = auto + +[sweep.train.gamma] +distribution = logit_normal +min = 0.88 +max = 0.995 +mean = 0.94 +scale = auto diff --git a/pufferlib/environments/mani_skill/torch.py b/pufferlib/environments/mani_skill/torch.py index abb8eaa18..c2e5a795d 100644 --- a/pufferlib/environments/mani_skill/torch.py +++ b/pufferlib/environments/mani_skill/torch.py @@ -64,7 +64,7 @@ def decode_actions(self, hidden): '''Decodes a batch of hidden states into (multi)discrete actions. 
Assumes no time dimension (handled by LSTM wrappers).''' mean = self.decoder_mean(hidden) - logstd = self.decoder_logstd.expand_as(mean) + logstd = self.decoder_logstd.expand_as(mean).clamp(min=-20, max=2) std = torch.exp(logstd) logits = torch.distributions.Normal(mean, std) values = self.value(hidden) diff --git a/pufferlib/models.py b/pufferlib/models.py index fa43d7071..d81198343 100644 --- a/pufferlib/models.py +++ b/pufferlib/models.py @@ -88,7 +88,7 @@ def decode_actions(self, hidden): logits = self.decoder(hidden).split(self.action_nvec, dim=1) elif self.is_continuous: mean = self.decoder_mean(hidden) - logstd = self.decoder_logstd.expand_as(mean) + logstd = self.decoder_logstd.expand_as(mean).clamp(min=-20, max=2) std = torch.exp(logstd) logits = torch.distributions.Normal(mean, std) else: diff --git a/pufferlib/ocean/dogfight/arena.py b/pufferlib/ocean/dogfight/arena.py new file mode 100644 index 000000000..bcb740a05 --- /dev/null +++ b/pufferlib/ocean/dogfight/arena.py @@ -0,0 +1,242 @@ +#!/usr/bin/env python +"""Arena: Head-to-head policy comparison for Dogfight. + +Pits two policies against each other over many episodes to evaluate relative strength. + +Usage: + # Compare two checkpoints + python pufferlib/ocean/dogfight/arena.py \\ + --policy-a experiments/model_a.pt \\ + --policy-b experiments/model_b.pt \\ + --episodes 100 + + # Compare policy against autopilot (no policy-b = use autopilot) + python pufferlib/ocean/dogfight/arena.py \\ + --policy-a experiments/model.pt \\ + --episodes 100 + + # With rendering (slow) + python pufferlib/ocean/dogfight/arena.py \\ + --policy-a model_a.pt \\ + --policy-b model_b.pt \\ + --render --fps 30 + + # At specific curriculum stage + python pufferlib/ocean/dogfight/arena.py \\ + --policy-a model_a.pt \\ + --policy-b model_b.pt \\ + --stage 15 +""" +import argparse +import time +from collections import defaultdict + +import numpy as np +import torch + +from pufferlib.ocean.dogfight.dogfight import Dogfight +from pufferlib.models import Default as Policy + + +def load_policy(path, env, device='cuda'): + """Load a policy checkpoint.""" + policy = Policy(env, hidden_size=128) + policy = policy.to(device) + + state_dict = torch.load(path, map_location=device, weights_only=True) + + # Handle different checkpoint formats + cleaned_state_dict = {} + for k, v in state_dict.items(): + # Skip LSTM keys + if k.startswith('lstm.') or k.startswith('cell.'): + continue + # Strip prefixes + new_k = k.replace('module.', '').replace('policy.', '') + cleaned_state_dict[new_k] = v + + policy.load_state_dict(cleaned_state_dict) + policy.eval() + + return policy + + +def run_episode(env, policy_a, policy_b, device='cuda', render=False, fps=30): + """Run a single episode, return winner ('a', 'b', or 'draw').""" + obs, _ = env.reset() + + from pufferlib.ocean.dogfight import binding + + done = False + tick = 0 + max_ticks = 6000 # 2 minutes at 50Hz + + while not done and tick < max_ticks: + # Get observations for both perspectives + obs_a = torch.as_tensor(obs, device=device).unsqueeze(0) + obs_b = binding.vec_get_opponent_observations(env.c_envs) + obs_b = torch.as_tensor(obs_b, device=device) + + # Policy A action (player) + with torch.no_grad(): + logits_a, _ = policy_a.forward_eval(obs_a, state=None) + action_a = logits_a.sample() + action_a = action_a.cpu().numpy().astype(np.float32) + action_a = np.clip(action_a, -1, 1) + + # Policy B action (opponent) + if policy_b is not None: + with torch.no_grad(): + logits_b, _ = policy_b.forward_eval(obs_b, 
state=None) + action_b = logits_b.sample() + action_b = action_b.cpu().numpy().astype(np.float32) + action_b = np.clip(action_b, -1, 1) + + # Set opponent actions in C code + binding.vec_set_opponent_actions(env.c_envs, action_b) + + # Step environment + obs, reward, terminal, truncation, info = env.step(action_a.reshape(1, -1)) + + if render: + env.render() + time.sleep(1.0 / fps) + + done = terminal[0] or truncation[0] + tick += 1 + + # Determine winner based on reward + final_reward = reward[0] + if final_reward > 0.5: + return 'a' # Policy A killed opponent + elif final_reward < -0.5: + return 'b' # Policy A was killed (opponent wins) + else: + return 'draw' # Timeout or other + + +def main(): + parser = argparse.ArgumentParser(description='Dogfight Arena: Head-to-head policy comparison') + parser.add_argument('--policy-a', type=str, required=True, + help='Path to first policy checkpoint') + parser.add_argument('--policy-b', type=str, default=None, + help='Path to second policy checkpoint (omit for autopilot)') + parser.add_argument('--episodes', type=int, default=100, + help='Number of episodes to run') + parser.add_argument('--stage', type=int, default=-1, + help='Curriculum stage (-1 for stage 20 AutoAce)') + parser.add_argument('--obs-scheme', type=int, default=1, + help='Observation scheme (must match training)') + parser.add_argument('--render', action='store_true', + help='Render episodes') + parser.add_argument('--fps', type=int, default=30, + help='Render FPS') + parser.add_argument('--device', type=str, default='cuda', + help='Device for policy inference') + args = parser.parse_args() + + device = args.device + if device == 'cuda' and not torch.cuda.is_available(): + device = 'cpu' + print('CUDA not available, using CPU') + + # Create environment + render_mode = 'human' if args.render else None + env = Dogfight( + num_envs=1, + render_mode=render_mode, + render_fps=args.fps if args.render else None, + obs_scheme=args.obs_scheme, + curriculum_enabled=1, + fixed_stage=args.stage if args.stage >= 0 else 20, # Default to AutoAce stage + max_steps=6000, # 2 minutes per episode + ) + + # Load policies + print(f'Loading policy A: {args.policy_a}') + policy_a = load_policy(args.policy_a, env, device) + + if args.policy_b: + print(f'Loading policy B: {args.policy_b}') + policy_b = load_policy(args.policy_b, env, device) + + # Enable opponent override for policy-controlled opponent + from pufferlib.ocean.dogfight import binding + binding.vec_enable_opponent_override(env.c_envs, 1) + else: + print('Policy B: Autopilot (C code)') + policy_b = None + + # Run matches + results = {'a': 0, 'b': 0, 'draw': 0} + episode_times = [] + + print(f'\nRunning {args.episodes} episodes...\n') + + for ep in range(args.episodes): + start_time = time.time() + winner = run_episode(env, policy_a, policy_b, device, args.render, args.fps) + elapsed = time.time() - start_time + episode_times.append(elapsed) + + results[winner] += 1 + + # Progress update + if (ep + 1) % 10 == 0 or (ep + 1) == args.episodes: + a_wins = results['a'] + b_wins = results['b'] + draws = results['draw'] + total = a_wins + b_wins + draws + a_pct = 100 * a_wins / total if total > 0 else 0 + b_pct = 100 * b_wins / total if total > 0 else 0 + draw_pct = 100 * draws / total if total > 0 else 0 + avg_time = np.mean(episode_times) + + print(f'Episode {ep+1}/{args.episodes}: ' + f'A={a_wins} ({a_pct:.1f}%) | ' + f'B={b_wins} ({b_pct:.1f}%) | ' + f'Draw={draws} ({draw_pct:.1f}%) | ' + f'Avg time: {avg_time:.2f}s') + + # Final summary + 
print('\n' + '='*60) + print('FINAL RESULTS') + print('='*60) + + total = args.episodes + a_wins = results['a'] + b_wins = results['b'] + draws = results['draw'] + + print(f'Policy A ({args.policy_a}):') + print(f' Wins: {a_wins} ({100*a_wins/total:.1f}%)') + + if args.policy_b: + print(f'Policy B ({args.policy_b}):') + else: + print('Autopilot:') + print(f' Wins: {b_wins} ({100*b_wins/total:.1f}%)') + + print(f'Draws: {draws} ({100*draws/total:.1f}%)') + + # Win rate comparison + if a_wins + b_wins > 0: + a_vs_b = a_wins / (a_wins + b_wins) + print(f'\nA vs B win rate: {100*a_vs_b:.1f}%') + + # Confidence interval (Wilson score) + n = a_wins + b_wins + p = a_vs_b + z = 1.96 # 95% confidence + denom = 1 + z*z/n + center = (p + z*z/(2*n)) / denom + spread = z * np.sqrt(p*(1-p)/n + z*z/(4*n*n)) / denom + ci_low = max(0, center - spread) + ci_high = min(1, center + spread) + print(f'95% CI: [{100*ci_low:.1f}%, {100*ci_high:.1f}%]') + + env.close() + + +if __name__ == '__main__': + main() diff --git a/pufferlib/ocean/dogfight/autoace.h b/pufferlib/ocean/dogfight/autoace.h new file mode 100644 index 000000000..1b5feb8c6 --- /dev/null +++ b/pufferlib/ocean/dogfight/autoace.h @@ -0,0 +1,664 @@ +#ifndef AUTOACE_H +#define AUTOACE_H + +#include +#include +#include + +typedef struct TacticalState { + float aspect_angle; // 0 = behind target, 180 = head-on (degrees) + float angle_off; // Track crossing angle - velocity alignment (degrees) + float antenna_train; // Target bearing from our nose, 0 = dead ahead (degrees) + float range; // Distance in meters + float closure_rate; // Positive = closing (m/s) + + float specific_energy; // Own Es = 0.5*v^2 + g*h (m^2/s^2) + float target_energy; // Target Es + float energy_delta; // Own Es - Target Es (positive = advantage) + float own_speed; // Current airspeed (m/s) + float target_speed; // Target airspeed (m/s) + + float time_to_intercept; // range / closure_rate (seconds, 999 if not closing) + bool in_gun_envelope; // range < 500m && antenna_train < 5 deg + bool target_in_front; // antenna_train < 90 deg + bool we_are_faster; // own_speed > target_speed + 5 m/s + bool closing; // closure_rate > 0 + + Vec3 lead_pos; // Predicted target position at bullet TOF +} TacticalState; + +typedef enum { + ENGAGE_OFFENSIVE, // Behind target, closing, have energy - ATTACK + ENGAGE_NEUTRAL, // Neither has clear advantage - MANEUVER FOR POSITION + ENGAGE_DEFENSIVE, // Target behind us, closing - SURVIVE + ENGAGE_WEAPONS, // In firing solution - TRACK AND SHOOT + ENGAGE_EXTEND, // Low energy, need to disengage and rebuild +} EngagementState; + +typedef struct AutoAceState { + // Current engagement assessment + TacticalState tactical; + EngagementState engagement; + + // Maneuver state machine + int mode_timer; // Ticks remaining in current mode (for persistence) + int maneuver_phase; // Phase within multi-phase maneuvers (0, 1, 2...) 
+ float yoyo_apex_alt; // Target altitude for high yo-yo apex + + // Scissors maneuver state + int scissors_timer; // Ticks until next reversal + int scissors_direction; // +1 or -1 (current turn direction) + + // PID state for tracking + float prev_heading_error; + float prev_pitch_error; + float prev_bank_error; + + // Statistics + int shots_fired; + int hits; +} AutoAceState; + +#define AUTOACE_GUN_RANGE 500.0f // Gun effective range (m) +#define AUTOACE_BULLET_SPEED 850.0f // ~WW2 .50 cal muzzle velocity (m/s) +#define AUTOACE_GUN_CONE 5.0f // Firing cone half-angle (degrees) +#define AUTOACE_MIN_MODE_TIME 25 // Minimum ticks per mode (~0.5s at 50Hz) +#define AUTOACE_FIRE_COOLDOWN 10 // Ticks between shots + +// Energy thresholds (specific energy in m^2/s^2) +#define AUTOACE_ENERGY_LOW -5000.0f // Significant energy deficit +#define AUTOACE_ENERGY_ADVANTAGE 3000.0f // Clear energy advantage + +// Speed thresholds (m/s) +#define AUTOACE_SPEED_LOW 60.0f // Approaching stall +#define AUTOACE_SPEED_FAST_DIFF 5.0f // Speed difference considered significant + +// Closure rate thresholds (m/s) +#define AUTOACE_CLOSURE_FAST 50.0f // Closing too fast (overshoot risk) +#define AUTOACE_CLOSURE_SLOW -10.0f // Falling behind + +static inline void compute_tactical_state(Plane* self, Plane* target, TacticalState* ts) { + // === Geometry === + Vec3 to_target = sub3(target->pos, self->pos); + ts->range = norm3(to_target); + + if (ts->range < 1.0f) { + ts->range = 1.0f; + } + + Vec3 los = normalize3(to_target); // Line of sight + + // Aspect angle: angle between LOS (from us to target) and target's forward + // 0 = directly behind target (LOS aligns with target's fwd), 180 = head-on + Vec3 tgt_fwd = quat_rotate(target->ori, vec3(1, 0, 0)); + float aspect_cos = dot3(los, tgt_fwd); + aspect_cos = clampf(aspect_cos, -1.0f, 1.0f); + ts->aspect_angle = acosf(aspect_cos) * RAD_TO_DEG; + + // Antenna train angle: target bearing from our nose + // 0 = dead ahead, 90 = to our side, 180 = behind us + Vec3 self_fwd = quat_rotate(self->ori, vec3(1, 0, 0)); + float train_cos = dot3(los, self_fwd); + train_cos = clampf(train_cos, -1.0f, 1.0f); + ts->antenna_train = acosf(train_cos) * RAD_TO_DEG; + + // Angle-off: track crossing angle (how parallel are our velocities?) + // 0 = same direction, 180 = opposite directions + float self_speed = norm3(self->vel); + float tgt_speed = norm3(target->vel); + + if (self_speed > 1.0f && tgt_speed > 1.0f) { + Vec3 self_vel_n = normalize3(self->vel); + Vec3 tgt_vel_n = normalize3(target->vel); + float angle_off_cos = dot3(self_vel_n, tgt_vel_n); + angle_off_cos = clampf(angle_off_cos, -1.0f, 1.0f); + ts->angle_off = acosf(angle_off_cos) * RAD_TO_DEG; + } else { + ts->angle_off = 0.0f; + } + + // Closure rate: positive = closing + Vec3 rel_vel = sub3(self->vel, target->vel); + ts->closure_rate = dot3(rel_vel, los); + + ts->own_speed = self_speed; + ts->target_speed = tgt_speed; + ts->specific_energy = 0.5f * ts->own_speed * ts->own_speed + GRAVITY * self->pos.z; + ts->target_energy = 0.5f * ts->target_speed * ts->target_speed + GRAVITY * target->pos.z; + ts->energy_delta = ts->specific_energy - ts->target_energy; + + ts->time_to_intercept = (ts->closure_rate > 1.0f) ? 
                                   ts->range / ts->closure_rate : 999.0f;
+    ts->in_gun_envelope = (ts->range < AUTOACE_GUN_RANGE &&
+                           ts->antenna_train < AUTOACE_GUN_CONE);
+    ts->target_in_front = (ts->antenna_train < 90.0f);
+    ts->we_are_faster = (ts->own_speed > ts->target_speed + AUTOACE_SPEED_FAST_DIFF);
+    ts->closing = (ts->closure_rate > 0.0f);
+
+    float bullet_tof = ts->range / AUTOACE_BULLET_SPEED;
+    ts->lead_pos = add3(target->pos, mul3(target->vel, bullet_tof));
+}
+
+static inline EngagementState classify_engagement(TacticalState* ts) {
+    if (ts->in_gun_envelope) {
+        return ENGAGE_WEAPONS;
+    }
+
+    // DEFENSIVE: antenna_train is the target's bearing off our nose, so a
+    // value above 135 degrees means the target is well behind us. Behind us
+    // and closing fast is an immediate threat.
+    if (ts->antenna_train > 135.0f && ts->closure_rate > 10.0f) {
+        return ENGAGE_DEFENSIVE;
+    }
+
+    // Secondary defensive check: target in our rear hemisphere and still closing
+    if (ts->antenna_train > 120.0f && ts->closing) {
+        return ENGAGE_DEFENSIVE;
+    }
+
+    // EXTEND: We're slow and/or low on energy
+    if (ts->own_speed < AUTOACE_SPEED_LOW ||
+        ts->energy_delta < AUTOACE_ENERGY_LOW) {
+        return ENGAGE_EXTEND;
+    }
+
+    // OFFENSIVE: Behind target (low aspect angle) with reasonable energy
+    // aspect_angle < 60 means we're in the rear quarter
+    if (ts->aspect_angle < 60.0f &&
+        ts->energy_delta > AUTOACE_ENERGY_LOW &&
+        ts->target_in_front) {
+        return ENGAGE_OFFENSIVE;
+    }
+
+    // Default: NEUTRAL - need to maneuver for advantage
+    return ENGAGE_NEUTRAL;
+}
+
+static inline float get_heading_to_point(Plane* self, Vec3 point) {
+    Vec3 to_point = sub3(point, self->pos);
+    return atan2f(to_point.y, to_point.x);
+}
+
+static inline float get_pitch_to_point(Plane* self, Vec3 point) {
+    Vec3 to_point = sub3(point, self->pos);
+    float horiz_dist = sqrtf(to_point.x * to_point.x + to_point.y * to_point.y);
+    return atan2f(to_point.z, horiz_dist);
+}
+
+static inline float get_current_heading(Plane* p) {
+    Vec3 fwd = quat_rotate(p->ori, vec3(1, 0, 0));
+    return atan2f(fwd.y, fwd.x);
+}
+
+static inline float get_current_pitch(Plane* p) {
+    Vec3 fwd = quat_rotate(p->ori, vec3(1, 0, 0));
+    return asinf(clampf(fwd.z, -1.0f, 1.0f));
+}
+
+static inline float get_current_bank(Plane* p) {
+    Vec3 up = quat_rotate(p->ori, vec3(0, 0, 1));
+    float bank = acosf(clampf(up.z, -1.0f, 1.0f));
+    // Sign: positive when right wing down (up.y < 0)
+    return (up.y < 0) ?
bank : -bank; +} + +static inline float normalize_angle(float angle) { + while (angle > PI) angle -= 2.0f * PI; + while (angle < -PI) angle += 2.0f * PI; + return angle; +} + +static inline float compute_bank_for_heading(Plane* self, float target_heading, float max_bank) { + float current_heading = get_current_heading(self); + float heading_error = normalize_angle(target_heading - current_heading); + + // Proportional bank: more error = more bank + float bank_command = heading_error * 1.5f; // Gain of 1.5 + return clampf(bank_command, -max_bank, max_bank); +} + +static inline void execute_gun_track(AutopilotState* ap, AutoAceState* ace, + Plane* self, Plane* target, float* actions) { + // Aim at lead point + Vec3 to_lead = sub3(ace->tactical.lead_pos, self->pos); + float desired_heading = atan2f(to_lead.y, to_lead.x); + float horiz_dist = sqrtf(to_lead.x * to_lead.x + to_lead.y * to_lead.y); + float desired_pitch = atan2f(to_lead.z, horiz_dist); + + float current_heading = get_current_heading(self); + float current_pitch = get_current_pitch(self); + + float heading_error = normalize_angle(desired_heading - current_heading); + float pitch_error = desired_pitch - current_pitch; + + // Bank to turn toward target + // Positive heading_error (target left) → negative target_bank (bank left) → turn left + float target_bank = clampf(heading_error * -2.0f, -1.2f, 1.2f); // ~70 deg max + float current_bank = get_current_bank(self); + float bank_error = target_bank - current_bank; + + // Aileron: roll to target bank (positive gain) + actions[2] = clampf(bank_error * 5.0f, -1.0f, 1.0f); + + // Elevator: pitch to track + // In a bank, we need to pull to change heading, not just pitch + float load_factor = 1.0f / fmaxf(cosf(fabsf(current_bank)), 0.3f); + float base_pull = -0.2f * load_factor; // Base pull to maintain altitude in turn + float pitch_correction = -pitch_error * 3.0f; + actions[1] = clampf(base_pull + pitch_correction, -1.0f, 1.0f); + + // Throttle: maintain energy + actions[0] = 0.8f * 2.0f - 1.0f; // 80% throttle -> action space + + // Rudder: coordinate turn + actions[3] = clampf(heading_error * 0.5f, -0.3f, 0.3f); + + // Fire when on target + if (ace->tactical.antenna_train < 3.0f && ace->tactical.range < AUTOACE_GUN_RANGE) { + if (self->fire_cooldown == 0) { + actions[4] = 1.0f; // FIRE! 
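+            // Note: shots_fired is a statistics counter only; the trigger
+            // cooldown (fire_cooldown) is decremented and re-armed in
+            // autoace_step() after the action is latched.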
+ ace->shots_fired++; + } + } else { + actions[4] = -1.0f; + } +} + +static inline void execute_pursuit_lag(AutopilotState* ap, AutoAceState* ace, + Plane* self, Plane* target, float* actions) { + // Aim at where target WAS (lag behind) + // Effectively aim at target position but don't lead + Vec3 to_target = sub3(target->pos, self->pos); + float desired_heading = atan2f(to_target.y, to_target.x); + float horiz_dist = sqrtf(to_target.x * to_target.x + to_target.y * to_target.y); + float desired_pitch = atan2f(to_target.z, horiz_dist); + + float current_heading = get_current_heading(self); + float current_pitch = get_current_pitch(self); + + float heading_error = normalize_angle(desired_heading - current_heading); + float pitch_error = desired_pitch - current_pitch; + + // Bank to turn toward target + // Positive heading_error (target left) → negative target_bank (bank left) → turn left + // Sign convention matches autopilot.h: negative bank = left wing down = turn left + float target_bank = clampf(heading_error * -1.5f, -1.0f, 1.0f); // ~60 deg max + float current_bank = get_current_bank(self); + float bank_error = target_bank - current_bank; + + // Positive gain: positive bank_error → positive aileron → roll right + // This matches autopilot.h roll_kp = -5.0 (but we changed target_bank sign) + actions[2] = clampf(bank_error * 5.0f, -1.0f, 1.0f); // Aileron + actions[1] = clampf(-pitch_error * 2.0f, -0.5f, 0.5f); // Elevator (gentle) + actions[0] = 0.9f * 2.0f - 1.0f; // High throttle to close + actions[3] = clampf(heading_error * 0.3f, -0.2f, 0.2f); // Rudder + actions[4] = -1.0f; // Don't fire in lag pursuit +} + +static inline void execute_pursuit_lead(AutopilotState* ap, AutoAceState* ace, + Plane* self, Plane* target, float* actions) { + // Aim at lead point + Vec3 to_lead = sub3(ace->tactical.lead_pos, self->pos); + float desired_heading = atan2f(to_lead.y, to_lead.x); + float horiz_dist = sqrtf(to_lead.x * to_lead.x + to_lead.y * to_lead.y); + float desired_pitch = atan2f(to_lead.z, horiz_dist); + + float current_heading = get_current_heading(self); + float current_pitch = get_current_pitch(self); + + float heading_error = normalize_angle(desired_heading - current_heading); + float pitch_error = desired_pitch - current_pitch; + + // Aggressive turn toward lead point + // Positive heading_error → negative target_bank → turn left + float target_bank = clampf(heading_error * -2.0f, -1.2f, 1.2f); + float current_bank = get_current_bank(self); + float bank_error = target_bank - current_bank; + + actions[2] = clampf(bank_error * 5.0f, -1.0f, 1.0f); // Aileron + actions[1] = clampf(-pitch_error * 3.0f, -0.7f, 0.7f); // Elevator (aggressive) + actions[0] = 0.7f * 2.0f - 1.0f; // Moderate throttle (manage closure) + actions[3] = clampf(heading_error * 0.4f, -0.3f, 0.3f); // Rudder + actions[4] = -1.0f; +} + +static inline void execute_break_turn(AutopilotState* ap, AutoAceState* ace, + Plane* self, Plane* target, float* actions) { + // Turn AWAY from target - determine which side target is on + Vec3 to_target = sub3(target->pos, self->pos); + Vec3 right = quat_rotate(self->ori, vec3(0, 1, 0)); + float dot_right = dot3(normalize3(to_target), right); + + // Turn away (opposite side from target) + // If target is to our right (dot_right > 0), bank left (negative) to turn away + // If target is to our left (dot_right < 0), bank right (positive) to turn away + float target_bank = (dot_right > 0) ? 
-1.3f : 1.3f; // Max bank ~75 deg + + float current_bank = get_current_bank(self); + float bank_error = target_bank - current_bank; + + actions[2] = clampf(bank_error * 6.0f, -1.0f, 1.0f); // Aggressive aileron + actions[1] = -0.7f; // Pull hard! + actions[0] = 1.0f; // Full throttle (max action = 1.0) + actions[3] = 0.0f; // No rudder in break + actions[4] = -1.0f; +} + +static inline void execute_high_yoyo(AutopilotState* ap, AutoAceState* ace, + Plane* self, Plane* target, float* actions) { + float current_bank = get_current_bank(self); + + if (ace->maneuver_phase == 0) { + // Phase 1: Reduce bank, pull up to climb + if (ace->yoyo_apex_alt == 0.0f) { + // Set apex altitude 150-200m above current + ace->yoyo_apex_alt = self->pos.z + 150.0f + rndf(0, 50); + } + + // Shallow bank, climb - reduce current bank toward zero + float target_bank = current_bank * 0.3f; + float bank_error = target_bank - current_bank; + + actions[2] = clampf(bank_error * 3.0f, -1.0f, 1.0f); + actions[1] = -0.5f; // Pull up moderately + actions[0] = 0.8f * 2.0f - 1.0f; + actions[3] = 0.0f; + actions[4] = -1.0f; + + // Transition when reaching apex + if (self->pos.z > ace->yoyo_apex_alt) { + ace->maneuver_phase = 1; + } + } else { + // Phase 2: Roll back in, dive toward target + Vec3 to_target = sub3(target->pos, self->pos); + float desired_heading = atan2f(to_target.y, to_target.x); + float current_heading = get_current_heading(self); + float heading_error = normalize_angle(desired_heading - current_heading); + + // Positive heading_error → negative target_bank → turn left + float target_bank = clampf(heading_error * -2.0f, -1.0f, 1.0f); + float bank_error = target_bank - current_bank; + + actions[2] = clampf(bank_error * 5.0f, -1.0f, 1.0f); + actions[1] = 0.3f; // Push over slightly to dive + actions[0] = 0.5f * 2.0f - 1.0f; // Reduced throttle + actions[3] = clampf(heading_error * 0.3f, -0.2f, 0.2f); + actions[4] = -1.0f; + } +} + +static inline void execute_scissors(AutopilotState* ap, AutoAceState* ace, + Plane* self, Plane* target, float* actions) { + // Initialize direction if needed + if (ace->scissors_direction == 0) { + ace->scissors_direction = (rndf(0, 1) > 0.5f) ? 1 : -1; + ace->scissors_timer = 40; // ~0.8 seconds per reversal + } + + // Check for reversal + ace->scissors_timer--; + if (ace->scissors_timer <= 0) { + ace->scissors_direction *= -1; // Reverse! 
+ ace->scissors_timer = 35 + (int)rndf(0, 10); // Vary timing + } + + // target_bank: +1.4 = bank right = turn right, -1.4 = bank left = turn left + float target_bank = ace->scissors_direction * 1.4f; // ~80 deg + float current_bank = get_current_bank(self); + float bank_error = target_bank - current_bank; + + actions[2] = clampf(bank_error * 6.0f, -1.0f, 1.0f); // Aggressive roll + actions[1] = -0.5f; // Pull through each reversal + actions[0] = 0.3f * 2.0f - 1.0f; // Low throttle to slow down + actions[3] = 0.0f; + actions[4] = -1.0f; +} + +static inline void execute_extend(AutopilotState* ap, AutoAceState* ace, + Plane* self, Plane* target, float* actions) { + // Fly straight away from target + Vec3 from_target = sub3(self->pos, target->pos); + float away_heading = atan2f(from_target.y, from_target.x); + + float current_heading = get_current_heading(self); + float heading_error = normalize_angle(away_heading - current_heading); + + // Gentle turn to face away + // Positive heading_error → negative target_bank → turn left + float target_bank = clampf(heading_error * -1.0f, -0.5f, 0.5f); + float current_bank = get_current_bank(self); + float bank_error = target_bank - current_bank; + + actions[2] = clampf(bank_error * 4.0f, -1.0f, 1.0f); + actions[1] = -0.1f; // Slight climb to gain energy + actions[0] = 1.0f; // Full throttle! + actions[3] = 0.0f; + actions[4] = -1.0f; +} + +static inline void execute_pursuit_pure(AutopilotState* ap, AutoAceState* ace, + Plane* self, Plane* target, float* actions) { + Vec3 to_target = sub3(target->pos, self->pos); + float desired_heading = atan2f(to_target.y, to_target.x); + float horiz_dist = sqrtf(to_target.x * to_target.x + to_target.y * to_target.y); + float desired_pitch = atan2f(to_target.z, horiz_dist); + + float current_heading = get_current_heading(self); + float current_pitch = get_current_pitch(self); + + float heading_error = normalize_angle(desired_heading - current_heading); + float pitch_error = desired_pitch - current_pitch; + + // Positive heading_error → negative target_bank → turn left + float target_bank = clampf(heading_error * -2.0f, -1.0f, 1.0f); + float current_bank = get_current_bank(self); + float bank_error = target_bank - current_bank; + + actions[2] = clampf(bank_error * 5.0f, -1.0f, 1.0f); + actions[1] = clampf(-pitch_error * 2.5f, -0.6f, 0.6f); + actions[0] = 0.8f * 2.0f - 1.0f; + actions[3] = clampf(heading_error * 0.3f, -0.2f, 0.2f); + actions[4] = -1.0f; +} + +static inline void execute_hard_turn(AutopilotState* ap, AutoAceState* ace, + Plane* self, int direction, float* actions) { + // direction: +1 = right (positive bank), -1 = left (negative bank) + float target_bank = direction * 1.2f; // ~70 deg + float current_bank = get_current_bank(self); + float bank_error = target_bank - current_bank; + + actions[2] = clampf(bank_error * 5.0f, -1.0f, 1.0f); + actions[1] = -0.5f; // Pull to turn + actions[0] = 0.9f * 2.0f - 1.0f; + actions[3] = 0.0f; + actions[4] = -1.0f; +} + +static inline AutopilotMode select_tactical_mode(TacticalState* ts, AutoAceState* ace, Plane* self) { + EngagementState engage = classify_engagement(ts); + ace->engagement = engage; + + switch (engage) { + case ENGAGE_WEAPONS: + return AP_GUN_TRACK; + + case ENGAGE_OFFENSIVE: + // Behind target, closing + if (ts->closure_rate > AUTOACE_CLOSURE_FAST && ts->range < 400.0f) { + return AP_HIGH_YOYO; // Too fast, will overshoot + } + if (ts->closure_rate < AUTOACE_CLOSURE_SLOW) { + return AP_PURSUIT_LEAD; // Falling behind, cut inside + } + return 
AP_PURSUIT_LAG; // Default: controlled pursuit + + case ENGAGE_NEUTRAL: + // Turn fight for position + if (ts->energy_delta > AUTOACE_ENERGY_ADVANTAGE) { + return AP_HIGH_YOYO; // Convert energy to position + } + // Turn toward target + if (ts->antenna_train > 90.0f) { + // Target behind us, turn to face + return (rndf(0, 1) > 0.5f) ? AP_HARD_TURN_LEFT : AP_HARD_TURN_RIGHT; + } + return AP_PURSUIT_PURE; + + case ENGAGE_DEFENSIVE: + // Threat behind, need to survive + if (self->pos.z > 1500.0f && ts->closure_rate > 30.0f) { + // High altitude and fast closure - could split-s but we don't have that + return AP_SCISSORS; // Force overshoot + } + if (ts->range < 300.0f) { + return AP_SCISSORS; // Force overshoot when close + } + return AP_BREAK_TURN; // Hard turn away + + case ENGAGE_EXTEND: + return AP_EXTEND; // Run away, rebuild energy + } + + return AP_LEVEL; // Fallback +} + +static inline void autoace_init(AutoAceState* ace) { + memset(ace, 0, sizeof(AutoAceState)); + ace->scissors_direction = 0; + ace->mode_timer = 0; + ace->maneuver_phase = 0; + ace->yoyo_apex_alt = 0.0f; +} + +static inline void autoace_step(AutopilotState* ap, AutoAceState* ace, + Plane* self, Plane* target, float* actions, float dt) { + actions[0] = 0.0f; // throttle + actions[1] = 0.0f; // elevator + actions[2] = 0.0f; // ailerons + actions[3] = 0.0f; // rudder + actions[4] = -1.0f; // trigger (default: don't fire) + + compute_tactical_state(self, target, &ace->tactical); + + if (ace->mode_timer > 0) { + ace->mode_timer--; + } + + bool maneuver_done = false; + + switch (ap->mode) { + case AP_HIGH_YOYO: + if (ace->maneuver_phase == 1 && + ace->tactical.closure_rate < 30.0f && + ace->tactical.closure_rate > -10.0f) { + maneuver_done = true; + } + break; + case AP_BREAK_TURN: + if (ace->tactical.antenna_train < 100.0f) { + maneuver_done = true; + } + break; + case AP_EXTEND: + if (ace->tactical.energy_delta > 0.0f || + ace->tactical.range > 800.0f) { + maneuver_done = true; + } + break; + default: + break; + } + + if (ace->mode_timer <= 0 || maneuver_done) { + AutopilotMode new_mode = select_tactical_mode(&ace->tactical, ace, self); + if (new_mode != ap->mode) { + ap->mode = new_mode; + ace->mode_timer = AUTOACE_MIN_MODE_TIME; + ace->maneuver_phase = 0; + ace->yoyo_apex_alt = 0.0f; + ace->scissors_direction = 0; + } + } + + // Execute current maneuver + switch (ap->mode) { + case AP_GUN_TRACK: + execute_gun_track(ap, ace, self, target, actions); + break; + case AP_PURSUIT_LAG: + execute_pursuit_lag(ap, ace, self, target, actions); + break; + case AP_PURSUIT_LEAD: + execute_pursuit_lead(ap, ace, self, target, actions); + break; + case AP_PURSUIT_PURE: + execute_pursuit_pure(ap, ace, self, target, actions); + break; + case AP_HIGH_YOYO: + execute_high_yoyo(ap, ace, self, target, actions); + break; + case AP_SCISSORS: + execute_scissors(ap, ace, self, target, actions); + break; + case AP_BREAK_TURN: + execute_break_turn(ap, ace, self, target, actions); + break; + case AP_EXTEND: + execute_extend(ap, ace, self, target, actions); + break; + case AP_HARD_TURN_LEFT: + execute_hard_turn(ap, ace, self, -1, actions); + break; + case AP_HARD_TURN_RIGHT: + execute_hard_turn(ap, ace, self, +1, actions); + break; + + default: + autopilot_step(ap, self, actions, dt); + break; + } + + if (self->fire_cooldown > 0) { + self->fire_cooldown--; + } + if (actions[4] > 0.5f && self->fire_cooldown == 0) { + self->fire_cooldown = AUTOACE_FIRE_COOLDOWN; + } + + #if DEBUG >= 3 + static int debug_counter = 0; + if (debug_counter++ % 50 == 0) { 
// Every second + const char* mode_names[] = { + "STRAIGHT", "LEVEL", "TURN_L", "TURN_R", + "CLIMB", "DESCEND", "HARD_L", "HARD_R", + "WEAVE", "EVASIVE", "RANDOM", + "PURSUIT_LEAD", "PURSUIT_LAG", "PURSUIT_PURE", + "HIGH_YOYO", "LOW_YOYO", "SCISSORS", "BREAK", + "SPLIT_S", "EXTEND", "BARREL_ATK", "GUN_TRACK" + }; + const char* engage_names[] = { + "OFFENSIVE", "NEUTRAL", "DEFENSIVE", "WEAPONS", "EXTEND" + }; + printf("[AUTOACE] mode=%s engage=%s range=%.0f aspect=%.0f train=%.0f closure=%.0f\n", + mode_names[ap->mode], engage_names[ace->engagement], + ace->tactical.range, ace->tactical.aspect_angle, + ace->tactical.antenna_train, ace->tactical.closure_rate); + } + #endif +} + +#endif // AUTOACE_H diff --git a/pufferlib/ocean/dogfight/autopilot.h b/pufferlib/ocean/dogfight/autopilot.h new file mode 100644 index 000000000..9896d6435 --- /dev/null +++ b/pufferlib/ocean/dogfight/autopilot.h @@ -0,0 +1,428 @@ +/** + * autopilot.h - Target aircraft flight maneuvers + * + * Provides autopilot modes for opponent aircraft during training. + * Can be set randomly at reset or forced via API for curriculum learning. + */ + +#ifndef AUTOPILOT_H +#define AUTOPILOT_H + +// Note: autopilot.h requires flightlib.h to be included BEFORE this file, +// providing Vec3, Quat, Plane, and other physics types. +#include + +// Autopilot mode enumeration +typedef enum { + AP_STRAIGHT = 0, // Fly straight (current/default behavior) + AP_LEVEL, // Level flight with PD on vz + AP_TURN_LEFT, // Coordinated left turn + AP_TURN_RIGHT, // Coordinated right turn + AP_CLIMB, // Constant climb rate + AP_DESCEND, // Constant descent rate + AP_HARD_TURN_LEFT, // Aggressive 70° left turn + AP_HARD_TURN_RIGHT, // Aggressive 70° right turn + AP_WEAVE, // Sine wave jinking (S-turns) + AP_EVASIVE, // Break turn when threat behind + AP_RANDOM, // Random mode selection at reset + + // AutoAce tactical modes (used by autoace.h) + AP_PURSUIT_LEAD, // Nose ahead of target (gun attack) + AP_PURSUIT_LAG, // Nose behind target (position/close) + AP_PURSUIT_PURE, // Nose at target (missile/intercept) + AP_HIGH_YOYO, // Climb to bleed closure, dive back + AP_LOW_YOYO, // Dive to gain closure, pull up + AP_SCISSORS, // Reversing breaks to force overshoot + AP_BREAK_TURN, // Maximum rate defensive turn + AP_SPLIT_S, // Disengage downward (altitude permitting) + AP_EXTEND, // Straight away, full throttle, rebuild energy + AP_BARREL_ROLL_ATK, // Roll around target's flight path + AP_GUN_TRACK, // Lead pursuit with firing solution + + // Flight test modes + AP_MIN_RADIUS_TURN, // Full elevator, aileron keeps nose on horizon (tightest turn) + + AP_COUNT +} AutopilotMode; + +// ============================================================================ +// PID GAINS - Tuned for realistic 6DOF physics (RK4 integration) +// ============================================================================ + +// Level flight: vz tracking +// Tuned via pid_sweep.py: max_dev=7.95m over 8s +#define AP_LEVEL_KP 0.0005f +#define AP_LEVEL_KD 0.2f + +// Turn pitch-tracking: keeps nose level (pitch=0) during banked turns +// Tuned via pid_sweep.py: pitch_mean=-0.38°, pitch_std=0.36°, bank_error=0.03° +#define AP_TURN_PITCH_KP 8.0f +#define AP_TURN_PITCH_KD 0.5f +#define AP_TURN_ROLL_KP -5.0f +#define AP_TURN_ROLL_KD -0.2f + +// Default parameters +#define AP_DEFAULT_THROTTLE 1.0f +#define AP_DEFAULT_BANK_DEG 30.0f // Base gentle turns +#define AP_DEFAULT_CLIMB_RATE 5.0f + +// Stage-specific bank angles (curriculum progression) +#define AP_STAGE4_BANK_DEG 30.0f // 
MANEUVERING - gentle 30° turns +#define AP_STAGE5_BANK_DEG 45.0f // FULL_RANDOM - medium 45° turns +#define AP_STAGE6_BANK_DEG 60.0f // HARD_MANEUVERING - steep 60° turns +#define AP_HARD_BANK_DEG 70.0f // EVASIVE - aggressive 70° turns +#define AP_WEAVE_AMPLITUDE 0.6f // ~35° bank amplitude (radians) +#define AP_WEAVE_PERIOD 3.0f // 3 second full cycle + +// Autopilot state for a plane +typedef struct { + AutopilotMode mode; + int randomize_on_reset; // If true, pick random mode each reset + float throttle; // Target throttle [0,1] + float target_bank; // Target bank angle (radians) + float target_vz; // Target vertical velocity (m/s) + + // Curriculum: mode selection weights (sum to 1.0) + float mode_weights[AP_COUNT]; + + // Own RNG state (not affected by srand() calls) + unsigned int rng_state; + + // PID gains + float pitch_kp, pitch_kd; // Level flight: vz tracking + float turn_pitch_kp, turn_pitch_kd; // Turns: pitch tracking (keeps nose level) + float roll_kp, roll_kd; + + // PID state (for derivative terms) + float prev_vz; + float prev_pitch; + float prev_bank_error; + + // AP_WEAVE state + float phase; // Sine wave phase for weave oscillation + + // AP_EVASIVE state (set by caller each step) + Vec3 threat_pos; // Position of threat to evade +} AutopilotState; + +// Simple LCG random for autopilot (not affected by srand) +static inline float ap_rand(AutopilotState* ap) { + ap->rng_state = ap->rng_state * 1103515245 + 12345; + return (float)((ap->rng_state >> 16) & 0x7FFF) / 32767.0f; +} + +// Initialize autopilot with defaults +static inline void autopilot_init(AutopilotState* ap) { + ap->mode = AP_STRAIGHT; + ap->randomize_on_reset = 0; + ap->throttle = AP_DEFAULT_THROTTLE; + ap->target_bank = AP_DEFAULT_BANK_DEG * (PI / 180.0f); + ap->target_vz = AP_DEFAULT_CLIMB_RATE; + + // Default: uniform weights for modes 1-5 (skip STRAIGHT and RANDOM) + for (int i = 0; i < AP_COUNT; i++) { + ap->mode_weights[i] = 0.0f; + } + float uniform = 1.0f / 5.0f; // 5 modes: LEVEL, TURN_L, TURN_R, CLIMB, DESCEND + ap->mode_weights[AP_LEVEL] = uniform; + ap->mode_weights[AP_TURN_LEFT] = uniform; + ap->mode_weights[AP_TURN_RIGHT] = uniform; + ap->mode_weights[AP_CLIMB] = uniform; + ap->mode_weights[AP_DESCEND] = uniform; + + // Seed autopilot RNG from system rand (called once at init, not affected by later srand) + ap->rng_state = (unsigned int)rand(); + + ap->pitch_kp = AP_LEVEL_KP; + ap->pitch_kd = AP_LEVEL_KD; + ap->turn_pitch_kp = AP_TURN_PITCH_KP; + ap->turn_pitch_kd = AP_TURN_PITCH_KD; + ap->roll_kp = AP_TURN_ROLL_KP; + ap->roll_kd = AP_TURN_ROLL_KD; + + ap->prev_vz = 0.0f; + ap->prev_pitch = 0.0f; + ap->prev_bank_error = 0.0f; + + // New mode state + ap->phase = 0.0f; + ap->threat_pos = vec3(0, 0, 0); +} + +// Set autopilot mode with parameters +static inline void autopilot_set_mode(AutopilotState* ap, AutopilotMode mode, + float throttle, float bank_deg, float climb_rate) { + ap->mode = mode; + ap->randomize_on_reset = (mode == AP_RANDOM) ? 
1 : 0; + ap->throttle = throttle; + ap->target_bank = bank_deg * (PI / 180.0f); + ap->target_vz = climb_rate; + + // Reset PID state on mode change + ap->prev_vz = 0.0f; + ap->prev_pitch = 0.0f; + ap->prev_bank_error = 0.0f; + + if (mode == AP_LEVEL || mode == AP_CLIMB || mode == AP_DESCEND) { + ap->pitch_kp = AP_LEVEL_KP; + ap->pitch_kd = AP_LEVEL_KD; + } else if (mode == AP_TURN_LEFT || mode == AP_TURN_RIGHT || + mode == AP_HARD_TURN_LEFT || mode == AP_HARD_TURN_RIGHT || + mode == AP_WEAVE || mode == AP_EVASIVE) { + ap->turn_pitch_kp = AP_TURN_PITCH_KP; + ap->turn_pitch_kd = AP_TURN_PITCH_KD; + ap->roll_kp = AP_TURN_ROLL_KP; + ap->roll_kd = AP_TURN_ROLL_KD; + } +} + +// Randomize autopilot mode using weighted selection (for AP_RANDOM at reset) +static inline void autopilot_randomize(AutopilotState* ap) { + float r = ap_rand(ap); // Use own RNG, not affected by srand() + float cumsum = 0.0f; + AutopilotMode selected = AP_LEVEL; // Default fallback + + for (int i = 1; i < AP_COUNT - 1; i++) { // Skip STRAIGHT(0) and RANDOM(10) + cumsum += ap->mode_weights[i]; + if (r <= cumsum) { + selected = (AutopilotMode)i; + break; + } + } + + // Save randomize flag (autopilot_set_mode would clear it) + int save_randomize = ap->randomize_on_reset; + autopilot_set_mode(ap, selected, AP_DEFAULT_THROTTLE, + AP_DEFAULT_BANK_DEG, AP_DEFAULT_CLIMB_RATE); + ap->randomize_on_reset = save_randomize; +} + +// Get bank angle from plane orientation +// Returns positive for right bank, negative for left bank +static inline float ap_get_bank_angle(Plane* p) { + Vec3 up = quat_rotate(p->ori, vec3(0, 0, 1)); + float bank = acosf(fminf(fmaxf(up.z, -1.0f), 1.0f)); + if (up.y < 0) bank = -bank; + return bank; +} + +// Get pitch angle from plane orientation +// Returns positive for nose up, negative for nose down +static inline float ap_get_pitch_angle(Plane* p) { + Vec3 fwd = quat_rotate(p->ori, vec3(1, 0, 0)); + return asinf(fminf(fmaxf(fwd.z, -1.0f), 1.0f)); +} + +// Get vertical velocity from plane +static inline float ap_get_vz(Plane* p) { + return p->vel.z; +} + +// Clamp value to range +static inline float ap_clamp(float v, float lo, float hi) { + return fminf(fmaxf(v, lo), hi); +} + +// Main autopilot step function +// Computes actions[5] = [throttle, elevator, ailerons, rudder, trigger] +static inline void autopilot_step(AutopilotState* ap, Plane* p, float* actions, float dt) { + // Initialize all actions to zero + actions[0] = 0.0f; // throttle (will be set below) + actions[1] = 0.0f; // elevator + actions[2] = 0.0f; // ailerons + actions[3] = 0.0f; // rudder + actions[4] = -1.0f; // trigger (never fire) + + // Set throttle (convert from [0,1] to [-1,1] action space) + actions[0] = ap->throttle * 2.0f - 1.0f; + + float vz = ap_get_vz(p); + float bank = ap_get_bank_angle(p); + + switch (ap->mode) { + case AP_STRAIGHT: + // Do nothing - just fly straight with throttle + break; + + case AP_LEVEL: { + // PD control on vz to maintain level flight + float vz_error = -vz; // Target vz = 0 + float vz_deriv = (vz - ap->prev_vz) / dt; + float elevator = ap->pitch_kp * vz_error + ap->pitch_kd * vz_deriv; + actions[1] = ap_clamp(elevator, -1.0f, 1.0f); + ap->prev_vz = vz; + break; + } + + case AP_TURN_LEFT: + case AP_TURN_RIGHT: { + // Dual PID: roll to target bank, pitch to keep nose level + float target_bank = ap->target_bank; + if (ap->mode == AP_TURN_LEFT) target_bank = -target_bank; + + // Elevator PID: track pitch=0 (level nose) instead of vz=0 + // This keeps the aircraft's nose on the horizon during turns + float 
pitch = ap_get_pitch_angle(p); + float pitch_error = 0.0f - pitch; // Target pitch = 0 (level) + float pitch_deriv = (pitch - ap->prev_pitch) / dt; + // Negative sign: positive error → negative elevator (pull back → nose up) + float elevator = -ap->turn_pitch_kp * pitch_error + ap->turn_pitch_kd * pitch_deriv; + actions[1] = ap_clamp(elevator, -1.0f, 1.0f); + ap->prev_pitch = pitch; + + // Aileron PID (achieve target bank) + float bank_error = target_bank - bank; + float bank_deriv = (bank_error - ap->prev_bank_error) / dt; + float aileron = ap->roll_kp * bank_error + ap->roll_kd * bank_deriv; + actions[2] = ap_clamp(aileron, -1.0f, 1.0f); + ap->prev_bank_error = bank_error; + break; + } + + case AP_CLIMB: { + // PD control to maintain target climb rate + float vz_error = ap->target_vz - vz; + float vz_deriv = (vz - ap->prev_vz) / dt; + // Negative because nose-up pitch (negative elevator) increases climb + float elevator = -ap->pitch_kp * vz_error + ap->pitch_kd * vz_deriv; + actions[1] = ap_clamp(elevator, -1.0f, 1.0f); + ap->prev_vz = vz; + break; + } + + case AP_DESCEND: { + // PD control to maintain target descent rate + float vz_error = -ap->target_vz - vz; // Target is negative vz + float vz_deriv = (vz - ap->prev_vz) / dt; + float elevator = -ap->pitch_kp * vz_error + ap->pitch_kd * vz_deriv; + actions[1] = ap_clamp(elevator, -1.0f, 1.0f); + ap->prev_vz = vz; + break; + } + + case AP_HARD_TURN_LEFT: + case AP_HARD_TURN_RIGHT: { + // Aggressive turn with high bank angle (70°) + float target_bank = AP_HARD_BANK_DEG * (PI / 180.0f); + if (ap->mode == AP_HARD_TURN_LEFT) target_bank = -target_bank; + + // Hard pull to maintain altitude in steep bank + float vz_error = -vz; + float elevator = -0.5f + ap->pitch_kp * vz_error; // Base pull + PD + actions[1] = ap_clamp(elevator, -1.0f, 1.0f); + ap->prev_vz = vz; + + // Aggressive aileron to achieve bank (50% more aggressive) + float bank_error = target_bank - bank; + float aileron = ap->roll_kp * bank_error * 1.5f; + actions[2] = ap_clamp(aileron, -1.0f, 1.0f); + break; + } + + case AP_WEAVE: { + // Sine wave banking - oscillates left/right, hard to lead + ap->phase += dt * (2.0f * PI / AP_WEAVE_PERIOD); + if (ap->phase > 2.0f * PI) ap->phase -= 2.0f * PI; + + float target_bank = AP_WEAVE_AMPLITUDE * sinf(ap->phase); + + // Elevator PID: track pitch=0 (level nose) + float pitch = ap_get_pitch_angle(p); + float pitch_error = 0.0f - pitch; + float pitch_deriv = (pitch - ap->prev_pitch) / dt; + float elevator = -ap->turn_pitch_kp * pitch_error + ap->turn_pitch_kd * pitch_deriv; + actions[1] = ap_clamp(elevator, -1.0f, 1.0f); + ap->prev_pitch = pitch; + + // Aileron PID to track oscillating bank + float bank_error = target_bank - bank; + float bank_deriv = (bank_error - ap->prev_bank_error) / dt; + float aileron = ap->roll_kp * bank_error + ap->roll_kd * bank_deriv; + actions[2] = ap_clamp(aileron, -1.0f, 1.0f); + ap->prev_bank_error = bank_error; + break; + } + + case AP_EVASIVE: { + // Break turn away from threat when close and behind + Vec3 to_threat = sub3(ap->threat_pos, p->pos); + float dist = norm3(to_threat); + Vec3 fwd = quat_rotate(p->ori, vec3(1, 0, 0)); + float dot_fwd = dot3(normalize3(to_threat), fwd); + + float target_bank = 0.0f; + float base_elevator = 0.0f; + + // Check if threat is close (<600m) and not in front (behind or side) + if (dist < 600.0f && dot_fwd < 0.3f) { + // Threat close and behind - BREAK TURN! 
+ // Determine which side threat is on + Vec3 right = quat_rotate(p->ori, vec3(0, -1, 0)); + float dot_right = dot3(normalize3(to_threat), right); + + // Turn INTO threat (break turn toward attacker to force overshoot) + target_bank = (dot_right > 0) ? 1.2f : -1.2f; // ~70° break INTO threat + base_elevator = -0.6f; // Pull hard + } + + // Elevator: base pull + PD for altitude + float vz_error = -vz; + float elevator = base_elevator + ap->pitch_kp * vz_error; + actions[1] = ap_clamp(elevator, -1.0f, 1.0f); + ap->prev_vz = vz; + + // Aileron to achieve break bank (aggressive) + float bank_error = target_bank - bank; + float aileron = ap->roll_kp * bank_error * 1.5f; + actions[2] = ap_clamp(aileron, -1.0f, 1.0f); + break; + } + + case AP_MIN_RADIUS_TURN: { + // Bank-tracking turn test mode: + // - Moderate elevator pull (configurable via ap->target_vz as input, default -0.5) + // - Rudder locked at 0 + // - Aileron tracks target bank angle (set via ap->target_bank, default 60°) + // + // PID gains tuned via sweep (test_min_radius_turn.c --bank-sweep): + // kp=10.0, kd=3.0 gives tight bank tracking with low pitch rate + // oscillation across 80-160 m/s speed range. + + // Elevator: use target_vz as elevator input (repurposed, range -1 to 0) + float elev_input = (ap->target_vz < 0) ? ap->target_vz : -0.5f; + actions[1] = ap_clamp(elev_input, -1.0f, 0.0f); + + // Rudder locked + actions[3] = 0.0f; + + // Get current bank angle + float bank = ap_get_bank_angle(p); + + // Target bank (negative for right turn) + float target_bank = -fabsf(ap->target_bank); + + // Bank error: positive means we're too shallow, need more right bank + float bank_error = bank - target_bank; + float bank_deriv = (bank_error - ap->prev_bank_error) / dt; + + // PID gains (tuned via sweep) + float kp = 10.0f; + float kd = 3.0f; + + // Aileron: positive error -> positive aileron -> roll right + float aileron = kp * bank_error + kd * bank_deriv; + actions[2] = ap_clamp(aileron, -1.0f, 1.0f); + ap->prev_bank_error = bank_error; + break; + } + + case AP_RANDOM: + // Should have been randomized at reset, fall through to straight + break; + + default: + break; + } +} + +#endif // AUTOPILOT_H diff --git a/pufferlib/ocean/dogfight/autopilot.py b/pufferlib/ocean/dogfight/autopilot.py new file mode 100644 index 000000000..e975bd0f4 --- /dev/null +++ b/pufferlib/ocean/dogfight/autopilot.py @@ -0,0 +1,428 @@ +""" +Mode 1 Autopilot Helpers for Flight Tests + +Mode 1 (realistic 6DOF physics) has stability derivatives that create +nose-down moments at positive AOA. Tests need active control to hold +attitudes that Mode 0 (simplified) holds passively. + +PID Gains from pid_tune.py sweep (straight_level_mode1 scenario): + pitch_kp: 0.2, pitch_kd: 0.1 - controls vz/pitch via elevator + roll_kp: 1.0, roll_kd: 0.1 - controls bank via aileron + yaw_kp: 0.1, yaw_kd: 0.02 - damps yaw rate via rudder + +Key insight: Mode 1 physics uses angular velocities (omega) directly, +so we read omega_x/y/z from state for D terms instead of finite differences. 
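+
+Typical usage (a sketch; assumes `state` is the dict returned by the test
+env's get_state(), with the keys these helpers read, and the
+[throttle, elevator, aileron, rudder, trigger] action layout used elsewhere):
+
+    elevator = hold_vz(state, target_vz=0.0)          # hold level flight
+    aileron = hold_bank(state, target_bank_deg=30.0)  # hold 30 deg right bank
+    rudder = damp_yaw(state)                          # damp yaw oscillations
+    actions[1], actions[2], actions[3] = elevator, aileron, rudder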
+""" + +import numpy as np + + +# Default gains from pid_tune.py sweep +DEFAULT_GAINS = { + # Elevator (pitch/vz control) + 'pitch_kp': 0.2, + 'pitch_kd': 0.1, + # Aileron (bank control) + 'roll_kp': 1.0, + 'roll_kd': 0.1, + # Rudder (yaw damping) + 'yaw_kp': 0.1, + 'yaw_kd': 0.02, +} + + +def get_pitch_deg(state): + """Get pitch angle in degrees from state's forward vector.""" + return np.degrees(np.arcsin(np.clip(state['fwd_z'], -1.0, 1.0))) + + +def get_bank_deg(state): + """ + Get bank angle in degrees. + Positive = right bank, Negative = left bank. + """ + up_z, up_y = state['up_z'], state['up_y'] + bank = np.arccos(np.clip(up_z, -1.0, 1.0)) + # up_y < 0 means canopy tilted right = right bank (positive) + return np.degrees(bank if up_y < 0 else -bank) + + +def get_heading_deg(state): + """Get heading in degrees (0=+X, 90=+Y).""" + return np.degrees(np.arctan2(state['fwd_y'], state['fwd_x'])) + + +def hold_pitch(state, target_pitch_deg, gains=None): + """ + Hold a specific pitch angle using PD control. + + Args: + state: Dict from env.get_state() + target_pitch_deg: Desired pitch angle in degrees (positive = nose up) + gains: Dict with 'pitch_kp', 'pitch_kd' (uses defaults if None) + + Returns: + elevator: Control input [-1, 1] + """ + if gains is None: + gains = DEFAULT_GAINS + + pitch = get_pitch_deg(state) + omega_pitch = np.degrees(state['omega_y']) # Pitch rate from physics + + error = target_pitch_deg - pitch + + # Negative elevator = pull = nose UP + # So if pitch is below target (error > 0), we need negative elevator + # D term opposes pitch rate + elevator = -gains['pitch_kp'] * error - gains['pitch_kd'] * omega_pitch + + return np.clip(elevator, -1.0, 1.0) + + +def hold_vz(state, target_vz, gains=None): + """ + Hold a target vertical speed (vz) using PD control. + + Good for level flight (target_vz=0) or constant rate climb/descent. + + Args: + state: Dict from env.get_state() + target_vz: Desired vertical speed in m/s (positive = climbing) + gains: Dict with 'pitch_kp', 'pitch_kd' (uses defaults if None) + + Returns: + elevator: Control input [-1, 1] + """ + if gains is None: + gains = DEFAULT_GAINS + + vz = state['vz'] + omega_pitch = np.degrees(state['omega_y']) + + error = target_vz - vz + + # If descending (vz < target), error > 0, need nose UP (negative elevator) + # Scale error to match pitch-based control (rough conversion: 5 m/s ~ 3 deg pitch) + elevator = -gains['pitch_kp'] * 0.6 * error - gains['pitch_kd'] * omega_pitch + + return np.clip(elevator, -1.0, 1.0) + + +def hold_bank(state, target_bank_deg, gains=None): + """ + Hold a specific bank angle using PD control. + + Args: + state: Dict from env.get_state() + target_bank_deg: Desired bank angle (positive = right bank) + gains: Dict with 'roll_kp', 'roll_kd' (uses defaults if None) + + Returns: + aileron: Control input [-1, 1] + """ + if gains is None: + gains = DEFAULT_GAINS + + bank = get_bank_deg(state) + omega_roll = np.degrees(state['omega_x']) # Roll rate from physics + + error = target_bank_deg - bank + + # Positive aileron = roll right + # If bank is below target (error > 0), need positive aileron + # D term opposes roll rate + aileron = gains['roll_kp'] * error - gains['roll_kd'] * omega_roll + + return np.clip(aileron, -1.0, 1.0) + + +def damp_yaw(state, gains=None): + """ + Damp yaw rate to zero (straight flight). 
+ + Args: + state: Dict from env.get_state() + gains: Dict with 'yaw_kp', 'yaw_kd' (uses defaults if None) + + Returns: + rudder: Control input [-1, 1] + """ + if gains is None: + gains = DEFAULT_GAINS + + omega_yaw = np.degrees(state['omega_z']) # Yaw rate from physics + + # Target yaw rate = 0, so error = -omega_yaw + # D term is just omega_yaw itself + rudder = -gains['yaw_kp'] * omega_yaw - gains['yaw_kd'] * omega_yaw + + return np.clip(rudder, -1.0, 1.0) + + +def hold_bank_and_level(state, target_bank_deg, gains=None): + """ + Coordinated turn: hold bank angle, keep nose level (vz ~ 0). + + In a banked turn, the lift vector is tilted, so some extra back pressure + is needed to maintain altitude. This function combines bank hold with + vz-based pitch control. + + Args: + state: Dict from env.get_state() + target_bank_deg: Desired bank angle (positive = right bank) + gains: Dict with all gains (uses defaults if None) + + Returns: + (elevator, aileron): Tuple of control inputs [-1, 1] + """ + if gains is None: + gains = DEFAULT_GAINS + + aileron = hold_bank(state, target_bank_deg, gains) + + # In a banked turn, need extra back pressure proportional to bank angle + # Load factor n = 1/cos(bank), so for 30 deg bank need ~1.15x lift + bank_rad = np.radians(abs(target_bank_deg)) + if bank_rad < np.radians(80): + # Extra pitch needed increases with bank angle + extra_pitch_bias = -0.05 * (1/np.cos(bank_rad) - 1) * 10 # Scaled pull + else: + extra_pitch_bias = -0.3 # Near knife-edge, just add pull + + # Base level flight + extra pull for turn + elevator = hold_vz(state, 0.0, gains) + extra_pitch_bias + elevator = np.clip(elevator, -1.0, 1.0) + + return elevator, aileron + + +def hold_pitch_and_bank(state, target_pitch_deg, target_bank_deg, gains=None): + """ + Hold both pitch angle and bank angle. + + Useful for setting up specific flight conditions (climb + turn, etc). + + Args: + state: Dict from env.get_state() + target_pitch_deg: Desired pitch angle (positive = nose up) + target_bank_deg: Desired bank angle (positive = right bank) + gains: Dict with all gains (uses defaults if None) + + Returns: + (elevator, aileron): Tuple of control inputs [-1, 1] + """ + if gains is None: + gains = DEFAULT_GAINS + + elevator = hold_pitch(state, target_pitch_deg, gains) + aileron = hold_bank(state, target_bank_deg, gains) + + return elevator, aileron + + +def full_autopilot(state, target_pitch_deg=0.0, target_bank_deg=0.0, + target_vz=None, damp_yaw_rate=True, gains=None): + """ + Full 3-axis autopilot for stable flight. + + Can operate in pitch-hold or vz-hold mode for elevator. 
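+
+    Example (sketch): a level 30 degree banked turn with yaw damping:
+
+        elevator, aileron, rudder = full_autopilot(
+            state, target_bank_deg=30.0, target_vz=0.0)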
+ + Args: + state: Dict from env.get_state() + target_pitch_deg: Desired pitch angle (used if target_vz is None) + target_bank_deg: Desired bank angle (positive = right bank) + target_vz: If provided, holds vz instead of pitch + damp_yaw_rate: Whether to damp yaw oscillations + gains: Dict with all gains (uses defaults if None) + + Returns: + (elevator, aileron, rudder): Tuple of control inputs [-1, 1] + """ + if gains is None: + gains = DEFAULT_GAINS + + # Elevator: vz-hold or pitch-hold + if target_vz is not None: + elevator = hold_vz(state, target_vz, gains) + else: + elevator = hold_pitch(state, target_pitch_deg, gains) + + # Aileron: bank hold + aileron = hold_bank(state, target_bank_deg, gains) + + # Rudder: yaw damping + rudder = damp_yaw(state, gains) if damp_yaw_rate else 0.0 + + return elevator, aileron, rudder + + +# ============================================================================= +# VELOCITY-BASED AUTOPILOTS FOR DELTA CONTROL +# ============================================================================= +# These functions output velocity commands that work with delta action space. +# With delta control, actions are velocity commands that accumulate: +# ctrl_elevator += control_rate_coeff * action[1] +# +# The velocity autopilots: +# 1. Compute target position using existing position-based autopilot +# 2. Read current control position from state (ctrl_elevator, etc.) +# 3. Output velocity to drive toward target position + +# Default coefficient for delta control (matches dogfight.ini) +DEFAULT_CONTROL_RATE_COEFF = 0.25 + + +def hold_pitch_velocity(state, target_pitch_deg, gains=None, coeff=DEFAULT_CONTROL_RATE_COEFF): + """ + Velocity-based pitch hold for delta control. + + Args: + state: Dict from env.get_state() - must include ctrl_elevator + target_pitch_deg: Desired pitch angle in degrees + gains: PD gains (uses defaults if None) + coeff: Control rate coefficient (default 0.25) + + Returns: + velocity_cmd: Velocity command [-1, 1] for delta control + """ + current_elevator = state.get('ctrl_elevator', 0.0) + target_elevator = hold_pitch(state, target_pitch_deg, gains) + + # P-controller on position error, normalized for coeff + error = target_elevator - current_elevator + velocity_cmd = 2.0 * error / coeff + + return np.clip(velocity_cmd, -1.0, 1.0) + + +def hold_vz_velocity(state, target_vz, gains=None, coeff=DEFAULT_CONTROL_RATE_COEFF): + """ + Velocity-based vertical speed hold for delta control. + + Args: + state: Dict from env.get_state() - must include ctrl_elevator + target_vz: Desired vertical speed in m/s + gains: PD gains (uses defaults if None) + coeff: Control rate coefficient (default 0.25) + + Returns: + velocity_cmd: Velocity command [-1, 1] for delta control + """ + current_elevator = state.get('ctrl_elevator', 0.0) + target_elevator = hold_vz(state, target_vz, gains) + + error = target_elevator - current_elevator + velocity_cmd = 2.0 * error / coeff + + return np.clip(velocity_cmd, -1.0, 1.0) + + +def hold_bank_velocity(state, target_bank_deg, gains=None, coeff=DEFAULT_CONTROL_RATE_COEFF): + """ + Velocity-based bank hold for delta control. 
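+
+    Computes the target aileron position via hold_bank(), then returns a rate
+    command proportional to (target - ctrl_aileron), scaled by 2/coeff and
+    clipped to [-1, 1].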
+ + Args: + state: Dict from env.get_state() - must include ctrl_aileron + target_bank_deg: Desired bank angle in degrees + gains: PD gains (uses defaults if None) + coeff: Control rate coefficient (default 0.25) + + Returns: + velocity_cmd: Velocity command [-1, 1] for delta control + """ + current_aileron = state.get('ctrl_aileron', 0.0) + target_aileron = hold_bank(state, target_bank_deg, gains) + + error = target_aileron - current_aileron + velocity_cmd = 2.0 * error / coeff + + return np.clip(velocity_cmd, -1.0, 1.0) + + +def damp_yaw_velocity(state, gains=None, coeff=DEFAULT_CONTROL_RATE_COEFF): + """ + Velocity-based yaw damping for delta control. + + Args: + state: Dict from env.get_state() - must include ctrl_rudder + gains: PD gains (uses defaults if None) + coeff: Control rate coefficient (default 0.25) + + Returns: + velocity_cmd: Velocity command [-1, 1] for delta control + """ + current_rudder = state.get('ctrl_rudder', 0.0) + target_rudder = damp_yaw(state, gains) + + error = target_rudder - current_rudder + velocity_cmd = 2.0 * error / coeff + + return np.clip(velocity_cmd, -1.0, 1.0) + + +def hold_bank_and_level_velocity(state, target_bank_deg, gains=None, coeff=DEFAULT_CONTROL_RATE_COEFF): + """ + Velocity-based coordinated turn for delta control. + + Args: + state: Dict from env.get_state() + target_bank_deg: Desired bank angle in degrees + gains: PD gains (uses defaults if None) + coeff: Control rate coefficient (default 0.25) + + Returns: + (elevator_vel, aileron_vel): Tuple of velocity commands [-1, 1] + """ + current_elevator = state.get('ctrl_elevator', 0.0) + current_aileron = state.get('ctrl_aileron', 0.0) + + target_elevator, target_aileron = hold_bank_and_level(state, target_bank_deg, gains) + + elev_error = target_elevator - current_elevator + elev_vel = 2.0 * elev_error / coeff + + ail_error = target_aileron - current_aileron + ail_vel = 2.0 * ail_error / coeff + + return np.clip(elev_vel, -1.0, 1.0), np.clip(ail_vel, -1.0, 1.0) + + +def full_autopilot_velocity(state, target_pitch_deg=0.0, target_bank_deg=0.0, + target_vz=None, damp_yaw_rate=True, gains=None, + coeff=DEFAULT_CONTROL_RATE_COEFF): + """ + Full 3-axis velocity-based autopilot for delta control. 
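+
+    Wraps full_autopilot() to get target surface positions, then converts each
+    axis to a rate command against the current ctrl_* position (the same
+    error * 2 / coeff scaling as the single-axis velocity helpers above).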
+ + Args: + state: Dict from env.get_state() + target_pitch_deg: Desired pitch angle (used if target_vz is None) + target_bank_deg: Desired bank angle (positive = right bank) + target_vz: If provided, holds vz instead of pitch + damp_yaw_rate: Whether to damp yaw oscillations + gains: PD gains (uses defaults if None) + coeff: Control rate coefficient (default 0.25) + + Returns: + (elevator_vel, aileron_vel, rudder_vel): Tuple of velocity commands [-1, 1] + """ + current_elevator = state.get('ctrl_elevator', 0.0) + current_aileron = state.get('ctrl_aileron', 0.0) + current_rudder = state.get('ctrl_rudder', 0.0) + + target_elevator, target_aileron, target_rudder = full_autopilot( + state, target_pitch_deg, target_bank_deg, target_vz, damp_yaw_rate, gains + ) + + elev_error = target_elevator - current_elevator + elev_vel = 2.0 * elev_error / coeff + + ail_error = target_aileron - current_aileron + ail_vel = 2.0 * ail_error / coeff + + rud_error = target_rudder - current_rudder + rud_vel = 2.0 * rud_error / coeff + + return (np.clip(elev_vel, -1.0, 1.0), + np.clip(ail_vel, -1.0, 1.0), + np.clip(rud_vel, -1.0, 1.0)) diff --git a/pufferlib/ocean/dogfight/binding.c b/pufferlib/ocean/dogfight/binding.c new file mode 100644 index 000000000..af48f078f --- /dev/null +++ b/pufferlib/ocean/dogfight/binding.c @@ -0,0 +1,653 @@ +#include "dogfight.h" + +#define Env Dogfight + +#include + +static PyObject* env_force_state(PyObject* self, PyObject* args, PyObject* kwargs); +static PyObject* env_set_autopilot(PyObject* self, PyObject* args, PyObject* kwargs); +static PyObject* vec_set_autopilot(PyObject* self, PyObject* args, PyObject* kwargs); +static PyObject* vec_set_mode_weights(PyObject* self, PyObject* args, PyObject* kwargs); +static PyObject* vec_set_curriculum_stage(PyObject* self, PyObject* args); +static PyObject* vec_set_curriculum_target(PyObject* self, PyObject* args); +static PyObject* env_get_autopilot_mode(PyObject* self, PyObject* args); +static PyObject* env_get_state(PyObject* self, PyObject* args); +static PyObject* env_set_obs_highlight(PyObject* self, PyObject* args); +static PyObject* env_get_autoace_state(PyObject* self, PyObject* args); +static PyObject* env_set_camera_follow(PyObject* self, PyObject* args); +static PyObject* vec_get_opponent_observations(PyObject* self, PyObject* args); +static PyObject* vec_set_opponent_actions(PyObject* self, PyObject* args); +static PyObject* vec_enable_opponent_override(PyObject* self, PyObject* args); +static PyObject* vec_set_opponent_buffers(PyObject* self, PyObject* args); +static PyObject* vec_set_eval_spawn_mode(PyObject* self, PyObject* args); + +#define MY_METHODS \ + {"env_force_state", (PyCFunction)env_force_state, METH_VARARGS | METH_KEYWORDS, "Force environment state"}, \ + {"env_set_autopilot", (PyCFunction)env_set_autopilot, METH_VARARGS | METH_KEYWORDS, "Set opponent autopilot mode"}, \ + {"vec_set_autopilot", (PyCFunction)vec_set_autopilot, METH_VARARGS | METH_KEYWORDS, "Set autopilot for all envs"}, \ + {"vec_set_mode_weights", (PyCFunction)vec_set_mode_weights, METH_VARARGS | METH_KEYWORDS, "Set mode weights for all envs"}, \ + {"vec_set_curriculum_stage", (PyCFunction)vec_set_curriculum_stage, METH_VARARGS, "Set curriculum stage for all envs"}, \ + {"vec_set_curriculum_target", (PyCFunction)vec_set_curriculum_target, METH_VARARGS, "Set curriculum target (float) for all envs"}, \ + {"env_get_autopilot_mode", (PyCFunction)env_get_autopilot_mode, METH_VARARGS, "Get current autopilot mode"}, \ + {"env_get_state", 
(PyCFunction)env_get_state, METH_VARARGS, "Get raw player state"}, \ + {"env_set_obs_highlight", (PyCFunction)env_set_obs_highlight, METH_VARARGS, "Set observation indices to highlight with red arrows"}, \ + {"env_get_autoace_state", (PyCFunction)env_get_autoace_state, METH_VARARGS, "Get AutoAce opponent state and tactical info"}, \ + {"env_set_camera_follow", (PyCFunction)env_set_camera_follow, METH_VARARGS, "Set camera to follow player (0) or opponent (1)"}, \ + {"vec_get_opponent_observations", (PyCFunction)vec_get_opponent_observations, METH_VARARGS, "Get observations from opponent perspective for self-play"}, \ + {"vec_set_opponent_actions", (PyCFunction)vec_set_opponent_actions, METH_VARARGS, "Set opponent actions from external policy (self-play)"}, \ + {"vec_enable_opponent_override", (PyCFunction)vec_enable_opponent_override, METH_VARARGS, "Enable/disable opponent action override (0=autopilot, 1=external)"}, \ + {"vec_set_opponent_buffers", (PyCFunction)vec_set_opponent_buffers, METH_VARARGS, "Set opponent observation/reward buffers for dual self-play"}, \ + {"vec_set_eval_spawn_mode", (PyCFunction)vec_set_eval_spawn_mode, METH_VARARGS, "Set eval spawn mode (0=random, 1=opponent_advantage)"} + +static float get_float(PyObject *kwargs, const char *key, float default_val) { + if (!kwargs) return default_val; + PyObject *val = PyDict_GetItemString(kwargs, key); + if (!val) return default_val; + if (PyFloat_Check(val)) return (float)PyFloat_AsDouble(val); + if (PyLong_Check(val)) return (float)PyLong_AsLong(val); + return default_val; +} + +static int get_int(PyObject *kwargs, const char *key, int default_val) { + if (!kwargs) return default_val; + PyObject *val = PyDict_GetItemString(kwargs, key); + if (!val) return default_val; + if (PyLong_Check(val)) return (int)PyLong_AsLong(val); + if (PyFloat_Check(val)) return (int)PyFloat_AsDouble(val); + return default_val; +} + +#include "../env_binding.h" + +static int my_init(Env *env, PyObject *args, PyObject *kwargs) { + env->max_steps = unpack(kwargs, "max_steps"); + int obs_scheme = get_int(kwargs, "obs_scheme", 0); + + RewardConfig rcfg = { + .aim_scale = get_float(kwargs, "reward_aim_scale", 0.05f), + .closing_scale = get_float(kwargs, "reward_closing_scale", 0.003f), + .neg_g = get_float(kwargs, "penalty_neg_g", 0.02f), + .control_rate_penalty = get_float(kwargs, "control_rate_penalty", 0.0f), + .speed_min = get_float(kwargs, "speed_min", 50.0f), + }; + + int curriculum_enabled = get_int(kwargs, "curriculum_enabled", 0); + int curriculum_randomize = get_int(kwargs, "curriculum_randomize", 0); + int eval_spawn_mode = get_int(kwargs, "eval_spawn_mode", 0); + + int env_num = get_int(kwargs, "env_num", 0); + + init(env, obs_scheme, &rcfg, curriculum_enabled, curriculum_randomize, env_num); + env->eval_spawn_mode = eval_spawn_mode; // Set after init (overrides default 0) + return 0; +} + +static int my_log(PyObject *dict, Log *log) { + assign_to_dict(dict, "episode_return", log->episode_return); + assign_to_dict(dict, "episode_length", log->episode_length); + assign_to_dict(dict, "score", log->score); + assign_to_dict(dict, "perf", log->perf); + assign_to_dict(dict, "sp_player_kills", log->sp_player_kills); + assign_to_dict(dict, "sp_opp_kills", log->sp_opp_kills); + assign_to_dict(dict, "shots_fired", log->shots_fired); + assign_to_dict(dict, "accuracy", log->accuracy); + assign_to_dict(dict, "stage", log->stage); + + assign_to_dict(dict, "avg_stage_weight", log->total_stage_weight); // Raw sum → correct avg + assign_to_dict(dict, 
"avg_abs_bias", log->total_abs_bias); // Raw sum → correct avg + assign_to_dict(dict, "avg_stage", log->stage_sum); // Raw sum → correct avg + assign_to_dict(dict, "avg_control_rate", log->total_control_rate); // Raw sum → correct avg + assign_to_dict(dict, "base_stage_kills", log->base_stage_kills); // Raw sum (not averaged) + assign_to_dict(dict, "base_stage_eps", log->base_stage_eps); // Raw sum (not averaged) + assign_to_dict(dict, "ultimate", log->ultimate); + assign_to_dict(dict, "n", log->n); + return 0; +} + +static PyObject* env_force_state(PyObject* self, PyObject* args, PyObject* kwargs) { + if (PyTuple_Size(args) != 1) { + PyErr_SetString(PyExc_TypeError, "env_force_state requires 1 positional arg (env handle)"); + return NULL; + } + + Env* env = unpack_env(args); + if (!env) return NULL; + + float p_px = get_float(kwargs, "p_px", 0.0f); + float p_py = get_float(kwargs, "p_py", 0.0f); + float p_pz = get_float(kwargs, "p_pz", 1000.0f); + + float p_vx = get_float(kwargs, "p_vx", 150.0f); + float p_vy = get_float(kwargs, "p_vy", 0.0f); + float p_vz = get_float(kwargs, "p_vz", 0.0f); + + float p_ow = get_float(kwargs, "p_ow", 1.0f); + float p_ox = get_float(kwargs, "p_ox", 0.0f); + float p_oy = get_float(kwargs, "p_oy", 0.0f); + float p_oz = get_float(kwargs, "p_oz", 0.0f); + + float p_throttle = get_float(kwargs, "p_throttle", 1.0f); + + float o_px = get_float(kwargs, "o_px", -9999.0f); + float o_py = get_float(kwargs, "o_py", -9999.0f); + float o_pz = get_float(kwargs, "o_pz", -9999.0f); + + float o_vx = get_float(kwargs, "o_vx", -9999.0f); + float o_vy = get_float(kwargs, "o_vy", -9999.0f); + float o_vz = get_float(kwargs, "o_vz", -9999.0f); + + float o_ow = get_float(kwargs, "o_ow", -9999.0f); + float o_ox = get_float(kwargs, "o_ox", -9999.0f); + float o_oy = get_float(kwargs, "o_oy", -9999.0f); + float o_oz = get_float(kwargs, "o_oz", -9999.0f); + + int tick = get_int(kwargs, "tick", 0); + + int p_cooldown = get_int(kwargs, "p_cooldown", -1); + int o_cooldown = get_int(kwargs, "o_cooldown", -1); + + force_state(env, + p_px, p_py, p_pz, + p_vx, p_vy, p_vz, + p_ow, p_ox, p_oy, p_oz, + p_throttle, + o_px, o_py, o_pz, + o_vx, o_vy, o_vz, + o_ow, o_ox, o_oy, o_oz, + tick, + p_cooldown, + o_cooldown + ); + + Py_RETURN_NONE; +} + +static PyObject* env_set_autopilot(PyObject* self, PyObject* args, PyObject* kwargs) { + if (PyTuple_Size(args) != 1) { + PyErr_SetString(PyExc_TypeError, "env_set_autopilot requires 1 positional arg (env handle)"); + return NULL; + } + + Env* env = unpack_env(args); + if (!env) return NULL; + + int mode = get_int(kwargs, "mode", AP_STRAIGHT); + if (mode < 0 || mode >= AP_COUNT) mode = AP_STRAIGHT; + float throttle = get_float(kwargs, "throttle", AP_DEFAULT_THROTTLE); + float bank_deg = get_float(kwargs, "bank_deg", AP_DEFAULT_BANK_DEG); + float climb_rate = get_float(kwargs, "climb_rate", AP_DEFAULT_CLIMB_RATE); + + autopilot_set_mode(&env->opponent_ap, (AutopilotMode)mode, throttle, bank_deg, climb_rate); + + Py_RETURN_NONE; +} + +static PyObject* vec_set_autopilot(PyObject* self, PyObject* args, PyObject* kwargs) { + if (PyTuple_Size(args) != 1) { + PyErr_SetString(PyExc_TypeError, "vec_set_autopilot requires 1 positional arg (vec handle)"); + return NULL; + } + + VecEnv* vec = unpack_vecenv(args); + if (!vec) return NULL; + + int mode = get_int(kwargs, "mode", AP_STRAIGHT); + if (mode < 0 || mode >= AP_COUNT) mode = AP_STRAIGHT; + float throttle = get_float(kwargs, "throttle", AP_DEFAULT_THROTTLE); + float bank_deg = get_float(kwargs, "bank_deg", 
AP_DEFAULT_BANK_DEG); + float climb_rate = get_float(kwargs, "climb_rate", AP_DEFAULT_CLIMB_RATE); + + for (int i = 0; i < vec->num_envs; i++) { + autopilot_set_mode(&vec->envs[i]->opponent_ap, (AutopilotMode)mode, + throttle, bank_deg, climb_rate); + } + + Py_RETURN_NONE; +} + +// Set mode weights for curriculum learning (vectorized) +static PyObject* vec_set_mode_weights(PyObject* self, PyObject* args, PyObject* kwargs) { + if (PyTuple_Size(args) != 1) { + PyErr_SetString(PyExc_TypeError, "vec_set_mode_weights requires 1 positional arg (vec handle)"); + return NULL; + } + + VecEnv* vec = unpack_vecenv(args); + if (!vec) return NULL; + + // Get weights for each mode (default 0.2 each for modes 1-5) + float w_level = get_float(kwargs, "level", 0.2f); + float w_turn_left = get_float(kwargs, "turn_left", 0.2f); + float w_turn_right = get_float(kwargs, "turn_right", 0.2f); + float w_climb = get_float(kwargs, "climb", 0.2f); + float w_descend = get_float(kwargs, "descend", 0.2f); + + // Set weights for all environments + for (int i = 0; i < vec->num_envs; i++) { + AutopilotState* ap = &vec->envs[i]->opponent_ap; + ap->mode_weights[AP_LEVEL] = w_level; + ap->mode_weights[AP_TURN_LEFT] = w_turn_left; + ap->mode_weights[AP_TURN_RIGHT] = w_turn_right; + ap->mode_weights[AP_CLIMB] = w_climb; + ap->mode_weights[AP_DESCEND] = w_descend; + } + + Py_RETURN_NONE; +} + +// Set curriculum stage for all environments (global curriculum) +static PyObject* vec_set_curriculum_stage(PyObject* self, PyObject* args) { + PyObject* vec_arg; + int stage; + + if (!PyArg_ParseTuple(args, "Oi", &vec_arg, &stage)) { + return NULL; + } + + VecEnv* vec = (VecEnv*)PyLong_AsVoidPtr(vec_arg); + if (!vec) { + PyErr_SetString(PyExc_TypeError, "Invalid vec handle"); + return NULL; + } + + // Set stage for all environments + for (int i = 0; i < vec->num_envs; i++) { + set_curriculum_stage(vec->envs[i], stage); + } + + Py_RETURN_NONE; +} + +// Set curriculum target (float 0.0-15.0) for all environments +static PyObject* vec_set_curriculum_target(PyObject* self, PyObject* args) { + PyObject* vec_arg; + float target; + + if (!PyArg_ParseTuple(args, "Of", &vec_arg, &target)) { + return NULL; + } + + VecEnv* vec = (VecEnv*)PyLong_AsVoidPtr(vec_arg); + if (!vec) { + PyErr_SetString(PyExc_TypeError, "Invalid vec handle"); + return NULL; + } + + for (int i = 0; i < vec->num_envs; i++) { + set_curriculum_target(vec->envs[i], target); + } + + Py_RETURN_NONE; +} + +static PyObject* env_get_autopilot_mode(PyObject* self, PyObject* args) { + Env* env = unpack_env(args); + if (!env) return NULL; + + return PyLong_FromLong((long)env->opponent_ap.mode); +} + +static PyObject* env_get_state(PyObject* self, PyObject* args) { + Env* env = unpack_env(args); + if (!env) return NULL; + + Plane* p = &env->player; + Vec3 up = quat_rotate(p->ori, vec3(0, 0, 1)); + Vec3 fwd = quat_rotate(p->ori, vec3(1, 0, 0)); + + PyObject* dict = PyDict_New(); + if (!dict) return NULL; + + PyDict_SetItemString(dict, "px", PyFloat_FromDouble(p->pos.x)); + PyDict_SetItemString(dict, "py", PyFloat_FromDouble(p->pos.y)); + PyDict_SetItemString(dict, "pz", PyFloat_FromDouble(p->pos.z)); + + PyDict_SetItemString(dict, "vx", PyFloat_FromDouble(p->vel.x)); + PyDict_SetItemString(dict, "vy", PyFloat_FromDouble(p->vel.y)); + PyDict_SetItemString(dict, "vz", PyFloat_FromDouble(p->vel.z)); + + PyDict_SetItemString(dict, "ow", PyFloat_FromDouble(p->ori.w)); + PyDict_SetItemString(dict, "ox", PyFloat_FromDouble(p->ori.x)); + PyDict_SetItemString(dict, "oy", 
PyFloat_FromDouble(p->ori.y)); + PyDict_SetItemString(dict, "oz", PyFloat_FromDouble(p->ori.z)); + + PyDict_SetItemString(dict, "up_x", PyFloat_FromDouble(up.x)); + PyDict_SetItemString(dict, "up_y", PyFloat_FromDouble(up.y)); + PyDict_SetItemString(dict, "up_z", PyFloat_FromDouble(up.z)); + + PyDict_SetItemString(dict, "fwd_x", PyFloat_FromDouble(fwd.x)); + PyDict_SetItemString(dict, "fwd_y", PyFloat_FromDouble(fwd.y)); + PyDict_SetItemString(dict, "fwd_z", PyFloat_FromDouble(fwd.z)); + + PyDict_SetItemString(dict, "throttle", PyFloat_FromDouble(p->throttle)); + + PyDict_SetItemString(dict, "g_force", PyFloat_FromDouble(p->g_force)); + + PyDict_SetItemString(dict, "omega_x", PyFloat_FromDouble(p->omega.x)); + PyDict_SetItemString(dict, "omega_y", PyFloat_FromDouble(p->omega.y)); + PyDict_SetItemString(dict, "omega_z", PyFloat_FromDouble(p->omega.z)); + + return dict; +} + +// Set which observation indices to highlight with red arrows +// Args: env_handle, list of indices (e.g., [4, 5, 6] for pitch, roll, yaw in scheme 0) +static PyObject* env_set_obs_highlight(PyObject* self, PyObject* args) { + PyObject* env_arg; + PyObject* indices_list; + + if (!PyArg_ParseTuple(args, "OO", &env_arg, &indices_list)) { + return NULL; + } + + // Get env from handle + Env* env = (Env*)PyLong_AsVoidPtr(env_arg); + if (!env) { + PyErr_SetString(PyExc_TypeError, "Invalid env handle"); + return NULL; + } + + // Clear existing highlights + memset(env->obs_highlight, 0, sizeof(env->obs_highlight)); + + // Parse list of indices + if (!PyList_Check(indices_list)) { + PyErr_SetString(PyExc_TypeError, "Second argument must be a list of indices"); + return NULL; + } + + Py_ssize_t n = PyList_Size(indices_list); + for (Py_ssize_t i = 0; i < n; i++) { + PyObject* item = PyList_GetItem(indices_list, i); + if (!PyLong_Check(item)) { + PyErr_SetString(PyExc_TypeError, "Indices must be integers"); + return NULL; + } + int idx = (int)PyLong_AsLong(item); + if (idx >= 0 && idx < 16) { + env->obs_highlight[idx] = 1; + } + } + + Py_RETURN_NONE; +} + +// Get AutoAce opponent state and tactical info for behavioral tests +static PyObject* env_get_autoace_state(PyObject* self, PyObject* args) { + Env* env = unpack_env(args); + if (!env) return NULL; + + Plane* opp = &env->opponent; + Vec3 opp_fwd = quat_rotate(opp->ori, vec3(1, 0, 0)); + Vec3 opp_up = quat_rotate(opp->ori, vec3(0, 0, 1)); + + // Compute bank angle (positive = right wing down) + // Bank = angle between plane's up and world up, signed by up.y + // Match sign convention from get_current_bank() in autoace.h: + // positive when right wing down (up.y < 0) + float opp_bank = acosf(fminf(fmaxf(opp_up.z, -1.0f), 1.0f)); + if (opp_up.y >= 0) opp_bank = -opp_bank; // Negative when left wing down + + PyObject* dict = PyDict_New(); + if (!dict) return NULL; + + // Opponent plane state + PyDict_SetItemString(dict, "opp_px", PyFloat_FromDouble(opp->pos.x)); + PyDict_SetItemString(dict, "opp_py", PyFloat_FromDouble(opp->pos.y)); + PyDict_SetItemString(dict, "opp_pz", PyFloat_FromDouble(opp->pos.z)); + PyDict_SetItemString(dict, "opp_vx", PyFloat_FromDouble(opp->vel.x)); + PyDict_SetItemString(dict, "opp_vy", PyFloat_FromDouble(opp->vel.y)); + PyDict_SetItemString(dict, "opp_vz", PyFloat_FromDouble(opp->vel.z)); + PyDict_SetItemString(dict, "opp_fwd_x", PyFloat_FromDouble(opp_fwd.x)); + PyDict_SetItemString(dict, "opp_fwd_y", PyFloat_FromDouble(opp_fwd.y)); + PyDict_SetItemString(dict, "opp_fwd_z", PyFloat_FromDouble(opp_fwd.z)); + PyDict_SetItemString(dict, "opp_bank", 
PyFloat_FromDouble(opp_bank)); + + // Opponent orientation quaternion + PyDict_SetItemString(dict, "opp_ow", PyFloat_FromDouble(opp->ori.w)); + PyDict_SetItemString(dict, "opp_ox", PyFloat_FromDouble(opp->ori.x)); + PyDict_SetItemString(dict, "opp_oy", PyFloat_FromDouble(opp->ori.y)); + PyDict_SetItemString(dict, "opp_oz", PyFloat_FromDouble(opp->ori.z)); + + // Last AutoAce actions (from most recent step) + PyDict_SetItemString(dict, "opp_throttle", PyFloat_FromDouble(env->last_opp_actions[0])); + PyDict_SetItemString(dict, "opp_elevator", PyFloat_FromDouble(env->last_opp_actions[1])); + PyDict_SetItemString(dict, "opp_aileron", PyFloat_FromDouble(env->last_opp_actions[2])); + PyDict_SetItemString(dict, "opp_rudder", PyFloat_FromDouble(env->last_opp_actions[3])); + PyDict_SetItemString(dict, "opp_trigger", PyFloat_FromDouble(env->last_opp_actions[4])); + + // Tactical state (from AutoAce) + TacticalState* ts = &env->opponent_ace.tactical; + PyDict_SetItemString(dict, "engagement", PyLong_FromLong(env->opponent_ace.engagement)); + PyDict_SetItemString(dict, "mode", PyLong_FromLong(env->opponent_ap.mode)); + PyDict_SetItemString(dict, "aspect_angle", PyFloat_FromDouble(ts->aspect_angle)); + PyDict_SetItemString(dict, "antenna_train", PyFloat_FromDouble(ts->antenna_train)); + PyDict_SetItemString(dict, "range", PyFloat_FromDouble(ts->range)); + PyDict_SetItemString(dict, "closure_rate", PyFloat_FromDouble(ts->closure_rate)); + PyDict_SetItemString(dict, "in_gun_envelope", PyBool_FromLong(ts->in_gun_envelope)); + + return dict; +} + +// Set camera to follow player (0) or opponent (1) +static PyObject* env_set_camera_follow(PyObject* self, PyObject* args) { + PyObject* env_arg; + int follow_opponent; + + if (!PyArg_ParseTuple(args, "Oi", &env_arg, &follow_opponent)) { + return NULL; + } + + Env* env = (Env*)PyLong_AsVoidPtr(env_arg); + if (!env) { + PyErr_SetString(PyExc_TypeError, "Invalid env handle"); + return NULL; + } + + env->camera_follow_opponent = follow_opponent; + Py_RETURN_NONE; +} + +// Get opponent observations for all environments (for self-play) +// Returns: numpy array of shape (num_envs, obs_size) with opponent's view of the world +// Currently only supports scheme 0 (OBS_MOMENTUM) - returns 16 obs per env +static PyObject* vec_get_opponent_observations(PyObject* self, PyObject* args) { + PyObject* vec_arg; + + if (!PyArg_ParseTuple(args, "O", &vec_arg)) { + return NULL; + } + + VecEnv* vec = (VecEnv*)PyLong_AsVoidPtr(vec_arg); + if (!vec) { + PyErr_SetString(PyExc_TypeError, "Invalid vec handle"); + return NULL; + } + + // Get obs_size from first environment (all envs have same scheme) + int obs_size = vec->envs[0]->obs_size; + + // Create numpy array of shape (num_envs, obs_size) + npy_intp dims[2] = {vec->num_envs, obs_size}; + PyObject* arr = PyArray_SimpleNew(2, dims, NPY_FLOAT32); + if (!arr) { + PyErr_SetString(PyExc_MemoryError, "Failed to allocate opponent observations array"); + return NULL; + } + + // Compute opponent observations for each environment + float* data = (float*)PyArray_DATA((PyArrayObject*)arr); + for (int i = 0; i < vec->num_envs; i++) { + compute_opponent_observations(vec->envs[i], data + i * obs_size); + } + + return arr; +} + +// Set opponent actions for all environments (for self-play) +// Args: vec_handle, actions_array (numpy float32 shape [num_envs, 5]) +// Sets opponent_actions_override for each env (used when use_opponent_override=1) +static PyObject* vec_set_opponent_actions(PyObject* self, PyObject* args) { + PyObject* vec_arg; + PyObject* 
actions_arr; + + if (!PyArg_ParseTuple(args, "OO", &vec_arg, &actions_arr)) { + return NULL; + } + + VecEnv* vec = (VecEnv*)PyLong_AsVoidPtr(vec_arg); + if (!vec) { + PyErr_SetString(PyExc_TypeError, "Invalid vec handle"); + return NULL; + } + + // Verify array shape and type + if (!PyArray_Check(actions_arr)) { + PyErr_SetString(PyExc_TypeError, "actions must be a numpy array"); + return NULL; + } + + PyArrayObject* arr = (PyArrayObject*)actions_arr; + if (PyArray_NDIM(arr) != 2) { + PyErr_SetString(PyExc_ValueError, "actions must be 2D array (num_envs, 5)"); + return NULL; + } + + npy_intp* dims = PyArray_DIMS(arr); + if (dims[0] != vec->num_envs || dims[1] != 5) { + PyErr_Format(PyExc_ValueError, + "actions shape must be (%d, 5), got (%ld, %ld)", + vec->num_envs, (long)dims[0], (long)dims[1]); + return NULL; + } + + if (PyArray_TYPE(arr) != NPY_FLOAT32) { + PyErr_SetString(PyExc_TypeError, "actions must be float32 dtype"); + return NULL; + } + + // Copy actions to each environment's override buffer + float* data = (float*)PyArray_DATA(arr); + for (int i = 0; i < vec->num_envs; i++) { + for (int j = 0; j < 5; j++) { + vec->envs[i]->opponent_actions_override[j] = data[i * 5 + j]; + } + } + + Py_RETURN_NONE; +} + +// Enable or disable opponent action override for all environments +// Args: vec_handle, enable (0=use autopilot, 1=use external actions) +static PyObject* vec_enable_opponent_override(PyObject* self, PyObject* args) { + PyObject* vec_arg; + int enable; + + if (!PyArg_ParseTuple(args, "Oi", &vec_arg, &enable)) { + return NULL; + } + + VecEnv* vec = (VecEnv*)PyLong_AsVoidPtr(vec_arg); + if (!vec) { + PyErr_SetString(PyExc_TypeError, "Invalid vec handle"); + return NULL; + } + + for (int i = 0; i < vec->num_envs; i++) { + vec->envs[i]->use_opponent_override = enable ? 
1 : 0; + } + + Py_RETURN_NONE; +} + +// Set opponent observation/reward buffers for all environments (for dual self-play) +// Args: vec_handle, opponent_obs_array (numpy float32), opponent_rewards_array (numpy float32) +// These buffers will be written during c_step() enabling Multiprocessing backend +static PyObject* vec_set_opponent_buffers(PyObject* self, PyObject* args) { + PyObject* vec_arg; + PyObject* opp_obs_arr; + PyObject* opp_rew_arr; + + if (!PyArg_ParseTuple(args, "OOO", &vec_arg, &opp_obs_arr, &opp_rew_arr)) { + return NULL; + } + + VecEnv* vec = (VecEnv*)PyLong_AsVoidPtr(vec_arg); + if (!vec) { + PyErr_SetString(PyExc_TypeError, "Invalid vec handle"); + return NULL; + } + + // Get obs_size from first environment + int obs_size = vec->envs[0]->obs_size; + + // Set opponent observations buffer + float* opp_obs_data = NULL; + if (opp_obs_arr != Py_None) { + if (!PyArray_Check(opp_obs_arr)) { + PyErr_SetString(PyExc_TypeError, "opponent_obs must be a numpy array or None"); + return NULL; + } + PyArrayObject* arr = (PyArrayObject*)opp_obs_arr; + if (PyArray_TYPE(arr) != NPY_FLOAT32) { + PyErr_SetString(PyExc_TypeError, "opponent_obs must be float32 dtype"); + return NULL; + } + opp_obs_data = (float*)PyArray_DATA(arr); + } + + // Set opponent rewards buffer + float* opp_rew_data = NULL; + if (opp_rew_arr != Py_None) { + if (!PyArray_Check(opp_rew_arr)) { + PyErr_SetString(PyExc_TypeError, "opponent_rewards must be a numpy array or None"); + return NULL; + } + PyArrayObject* arr = (PyArrayObject*)opp_rew_arr; + if (PyArray_TYPE(arr) != NPY_FLOAT32) { + PyErr_SetString(PyExc_TypeError, "opponent_rewards must be float32 dtype"); + return NULL; + } + opp_rew_data = (float*)PyArray_DATA(arr); + } + + // Set buffers for each environment + // Each env gets a slice: env[i] -> opp_obs_data + i*obs_size, opp_rew_data + i + for (int i = 0; i < vec->num_envs; i++) { + if (opp_obs_data != NULL) { + vec->envs[i]->opponent_observations = opp_obs_data + i * obs_size; + } else { + vec->envs[i]->opponent_observations = NULL; + } + if (opp_rew_data != NULL) { + vec->envs[i]->opponent_rewards = opp_rew_data + i; + } else { + vec->envs[i]->opponent_rewards = NULL; + } + } + + Py_RETURN_NONE; +} + +// Set eval spawn mode for all environments +// Args: vec_handle, mode (0=random, 1=opponent_advantage) +static PyObject* vec_set_eval_spawn_mode(PyObject* self, PyObject* args) { + PyObject* vec_arg; + int mode; + + if (!PyArg_ParseTuple(args, "Oi", &vec_arg, &mode)) { + return NULL; + } + + VecEnv* vec = (VecEnv*)PyLong_AsVoidPtr(vec_arg); + if (!vec) { + PyErr_SetString(PyExc_TypeError, "Invalid vec handle"); + return NULL; + } + + for (int i = 0; i < vec->num_envs; i++) { + vec->envs[i]->eval_spawn_mode = mode; + } + + Py_RETURN_NONE; +} diff --git a/pufferlib/ocean/dogfight/dogfight.c b/pufferlib/ocean/dogfight/dogfight.c new file mode 100644 index 000000000..978e35432 --- /dev/null +++ b/pufferlib/ocean/dogfight/dogfight.c @@ -0,0 +1,312 @@ +// Standalone C demo for Dogfight environment +// Build: ./scripts/build_ocean.sh dogfight local +// Run: ./dogfight +// +// Hold LEFT_SHIFT for human control, release for AI autopilot +// +// Flight Stick (Logitech Extreme 3D or similar): +// Stick X - Roll (push right = roll right) +// Stick Y - Pitch (push forward = nose down) +// Twist - Rudder (twist right = yaw right) +// Throttle - Throttle (forward = more power) +// Trigger - Fire +// +// Keyboard (while holding SHIFT): +// W/S - Pitch down/up +// A/D - Roll left/right +// Q/E - Yaw left/right +// Up/Down 
- Throttle up/down +// Space - Fire +// +// Global Keys: +// R - Restart episode +// ESC - Quit + +#include +#include +#include "dogfight.h" +#include "puffernet.h" + +// Linux joystick API for flight sticks (bypasses GLFW gamepad abstraction) +#ifdef __linux__ +#include +#include +#include +#include +#include + +typedef struct { + int fd; + char name[80]; + int num_axes; + int num_buttons; + float axes[8]; // Up to 8 axes + int buttons[16]; // Up to 16 buttons +} LinuxJoystick; + +static LinuxJoystick* open_linux_joystick(const char* device) { + int fd = open(device, O_RDONLY | O_NONBLOCK); + if (fd < 0) return NULL; + + LinuxJoystick* js = calloc(1, sizeof(LinuxJoystick)); + js->fd = fd; + + ioctl(fd, JSIOCGNAME(80), js->name); + ioctl(fd, JSIOCGAXES, &js->num_axes); + ioctl(fd, JSIOCGBUTTONS, &js->num_buttons); + + printf("Joystick found: %s\n", js->name); + printf(" Axes: %d, Buttons: %d\n", js->num_axes, js->num_buttons); + + return js; +} + +static void poll_linux_joystick(LinuxJoystick* js) { + if (!js) return; + + struct js_event event; + while (read(js->fd, &event, sizeof(event)) > 0) { + // Mask off init flag + event.type &= ~JS_EVENT_INIT; + + if (event.type == JS_EVENT_AXIS) { + if (event.number < 8) { + js->axes[event.number] = event.value / 32767.0f; + } + } else if (event.type == JS_EVENT_BUTTON) { + if (event.number < 16) { + js->buttons[event.number] = event.value; + } + } + } +} + +static void close_linux_joystick(LinuxJoystick* js) { + if (js) { + close(js->fd); + free(js); + } +} +#endif // __linux__ + +#define DOGFIGHT_OBS_SIZE 17 +#define DOGFIGHT_ACTION_SIZE 5 +#define DOGFIGHT_HIDDEN_SIZE 128 +#define DOGFIGHT_NUM_WEIGHTS 135179 + +static float apply_deadzone(float value, float deadzone) { + if (fabsf(value) < deadzone) return 0.0f; + float sign = value > 0.0f ? 
1.0f : -1.0f; + return sign * (fabsf(value) - deadzone) / (1.0f - deadzone); +} + +// Box-Muller transform for sampling from normal distribution +static double randn(double mean, double std) { + static int has_spare = 0; + static double spare; + + if (has_spare) { + has_spare = 0; + return mean + std * spare; + } + + has_spare = 1; + double u, v, s; + do { + u = 2.0 * rand() / RAND_MAX - 1.0; + v = 2.0 * rand() / RAND_MAX - 1.0; + s = u * u + v * v; + } while (s >= 1.0 || s == 0.0); + + s = sqrt(-2.0 * log(s) / s); + spare = v * s; + return mean + std * (u * s); +} + +typedef struct LinearContLSTM LinearContLSTM; +struct LinearContLSTM { + int num_agents; + float *obs; + float *log_std; + Linear *encoder; + GELU *gelu1; + LSTM *lstm; + Linear *actor; + Linear *value_fn; + int num_actions; +}; + +LinearContLSTM *make_linearcontlstm(Weights *weights, int num_agents, int input_dim, + int logit_sizes[], int num_actions) { + LinearContLSTM *net = calloc(1, sizeof(LinearContLSTM)); + net->num_agents = num_agents; + net->obs = calloc(num_agents * input_dim, sizeof(float)); + net->num_actions = logit_sizes[0]; + net->log_std = weights->data; + weights->idx += net->num_actions; + net->encoder = make_linear(weights, num_agents, input_dim, DOGFIGHT_HIDDEN_SIZE); + net->gelu1 = make_gelu(num_agents, DOGFIGHT_HIDDEN_SIZE); + int atn_sum = 0; + for (int i = 0; i < num_actions; i++) { + atn_sum += logit_sizes[i]; + } + net->actor = make_linear(weights, num_agents, DOGFIGHT_HIDDEN_SIZE, atn_sum); + net->value_fn = make_linear(weights, num_agents, DOGFIGHT_HIDDEN_SIZE, 1); + net->lstm = make_lstm(weights, num_agents, DOGFIGHT_HIDDEN_SIZE, DOGFIGHT_HIDDEN_SIZE); + return net; +} + +void free_linearcontlstm(LinearContLSTM *net) { + free(net->obs); + free(net->encoder); + free(net->gelu1); + free(net->actor); + free(net->value_fn); + free(net->lstm); + free(net); +} + +void forward_linearcontlstm(LinearContLSTM *net, float *observations, float *actions) { + linear(net->encoder, observations); + gelu(net->gelu1, net->encoder->output); + lstm(net->lstm, net->gelu1->output); + linear(net->actor, net->lstm->state_h); + linear(net->value_fn, net->lstm->state_h); + for (int i = 0; i < net->num_actions; i++) { + float std = expf(net->log_std[i]); + float mean = net->actor->output[i]; + actions[i] = randn(mean, std); + } +} + +void demo() { + srand(time(NULL)); + + Weights *weights = load_weights("resources/dogfight/puffer_dogfight_weights.bin", DOGFIGHT_NUM_WEIGHTS); + int logit_sizes[1] = {DOGFIGHT_ACTION_SIZE}; + LinearContLSTM *net = make_linearcontlstm(weights, 1, DOGFIGHT_OBS_SIZE, logit_sizes, 1); + + int obs_scheme = OBS_MOMENTUM_BETA; + int obs_size = OBS_SIZES[obs_scheme]; + + Dogfight env = { + .max_steps = 3000, + }; + + // Allocate buffers + env.observations = (float*)calloc(obs_size, sizeof(float)); + env.actions = (float*)calloc(5, sizeof(float)); // throttle, elevator, aileron, rudder, trigger + env.rewards = (float*)calloc(1, sizeof(float)); + env.terminals = (unsigned char*)calloc(1, sizeof(unsigned char)); + + RewardConfig rcfg = { + .aim_scale = 0.05f, + .closing_scale = 0.003f, + .neg_g = 0.02f, + .speed_min = 50.0f, + }; + + // curriculum_enabled=1, curriculum_randomize=1 for variety + init(&env, obs_scheme, &rcfg, 1, 1, 0); + c_reset(&env); + c_render(&env); + + SetTargetFPS(60); + +#ifdef __linux__ + LinuxJoystick* linux_js = open_linux_joystick("/dev/input/js0"); + if (!linux_js) linux_js = open_linux_joystick("/dev/input/js1"); + if (!linux_js) { + printf("No joystick found. 
Hold SHIFT for keyboard control.\n"); + } +#else + void* linux_js = NULL; +#endif + printf("Hold LEFT_SHIFT for human control, release for AI autopilot.\n"); + printf("Press R to restart, ESC to quit.\n"); + + while (!WindowShouldClose()) { + // Restart on R key + int key = GetKeyPressed(); + if (key == KEY_R || key == 'r' || key == 'R') { + c_reset(&env); + } + + // SHIFT = human control, otherwise AI flies + if (IsKeyDown(KEY_LEFT_SHIFT)) { + // ============================================ + // HUMAN CONTROL (hold SHIFT) + // ============================================ +#ifdef __linux__ + poll_linux_joystick(linux_js); +#endif + + env.actions[0] = 0.0f; // throttle (0 = 50% cruise) + env.actions[1] = 0.0f; // elevator + env.actions[2] = 0.0f; // ailerons + env.actions[3] = 0.0f; // rudder + env.actions[4] = -1.0f; // trigger (not firing) + +#ifdef __linux__ + if (linux_js) { + // Logitech Extreme 3D Pro mapping: + // Axis 0 = Stick X (roll) + // Axis 1 = Stick Y (pitch) + // Axis 2 = Twist (rudder) + // Axis 3 = Throttle slider (forward = -1, back = +1) + // Button 0 = Trigger + + LinuxJoystick* js = linux_js; + + // Pitch: push forward = nose down = positive (stick Y inverted) + env.actions[1] = -apply_deadzone(js->axes[1], 0.1f); + + // Roll: push right = roll right = positive + env.actions[2] = apply_deadzone(js->axes[0], 0.1f); + + // Rudder: twist right = yaw right = negative (action convention) + env.actions[3] = -apply_deadzone(js->axes[2], 0.1f); + + // Throttle: slider forward = more power = positive action + // Slider reports -1 at forward, +1 at back, so invert + env.actions[0] = -js->axes[3]; + + // Trigger (button 0) + if (js->buttons[0]) env.actions[4] = 1.0f; + } +#endif + + // Keyboard controls (always available when SHIFT held) + if (IsKeyDown(KEY_W)) env.actions[1] = 1.0f; // Nose down + if (IsKeyDown(KEY_S)) env.actions[1] = -1.0f; // Nose up + if (IsKeyDown(KEY_A)) env.actions[2] = -1.0f; // Roll left + if (IsKeyDown(KEY_D)) env.actions[2] = 1.0f; // Roll right + if (IsKeyDown(KEY_Q)) env.actions[3] = 1.0f; // Yaw left + if (IsKeyDown(KEY_E)) env.actions[3] = -1.0f; // Yaw right + if (IsKeyDown(KEY_UP)) env.actions[0] = 1.0f; // Full throttle + if (IsKeyDown(KEY_DOWN)) env.actions[0] = -1.0f; // Idle + if (IsKeyDown(KEY_SPACE)) env.actions[4] = 1.0f; // Fire + } else { + forward_linearcontlstm(net, env.observations, env.actions); + } + + c_step(&env); + c_render(&env); + } + +#ifdef __linux__ + close_linux_joystick(linux_js); +#endif + c_close(&env); + free_linearcontlstm(net); + free(weights); + free(env.observations); + free(env.actions); + free(env.rewards); + free(env.terminals); +} + +int main() { + demo(); + return 0; +} diff --git a/pufferlib/ocean/dogfight/dogfight.h b/pufferlib/ocean/dogfight/dogfight.h new file mode 100644 index 000000000..17480f225 --- /dev/null +++ b/pufferlib/ocean/dogfight/dogfight.h @@ -0,0 +1,1888 @@ +// dogfight.h - WW2 aerial combat environment +// Uses flightlib.h for flight physics + +#include +#include +#include +#include +#include + +#include "raylib.h" +#include "rlgl.h" // For rlSetClipPlanes() + +#define DEBUG 0 +#define EVAL_WINDOW 50 +#define PENALTY_STALL 0.002f +#define PENALTY_RUDDER 0.001f + +#include "flightlib.h" +#include "autopilot.h" +#include "autoace.h" + +typedef enum { + OBS_MOMENTUM = 0, // BASELINE: body-frame vel + omega + AoA + energy (15 obs) + OBS_MOMENTUM_BETA = 1, // + sideslip angle (16 obs) + OBS_MOMENTUM_GFORCE = 2, // + G-force (16 obs) + OBS_MOMENTUM_FULL = 3, // + sideslip + G + throttle + tgt 
rates (19 obs) + OBS_MINIMAL = 4, // stripped down essentials (11 obs) + OBS_CARTESIAN = 5, // cartesian target position (15 obs) + OBS_DRONE_STYLE = 6, // + quaternion + up vector (22 obs) + OBS_QBAR = 7, // + dynamic pressure (16 obs) + OBS_KITCHEN_SINK = 8, // everything (25 obs) + OBS_SCHEME_COUNT +} ObsScheme; + +static const int OBS_SIZES[OBS_SCHEME_COUNT] = {16, 17, 17, 20, 12, 16, 23, 17, 26}; + +typedef enum { + CURRICULUM_TAIL_CHASE = 0, // Stage 0: Easiest - opponent ahead, same heading + CURRICULUM_HEAD_ON, // Stage 1: Opponent coming toward us + CURRICULUM_VERTICAL, // Stage 2: Above or below player + CURRICULUM_GENTLE_TURNS, // Stage 3: Opponent does gentle 30° turns + CURRICULUM_OFFSET, // Stage 4: Large lateral/vertical offset, same heading + CURRICULUM_ANGLED, // Stage 5: Offset + different heading (±22°) + CURRICULUM_SIDE_NEAR, // Stage 6: 15-45° off axis (NEW - small side turn) + CURRICULUM_SIDE_MID, // Stage 7: 30-60° off axis (NEW - medium side turn) + CURRICULUM_SIDE_FAR, // Stage 8: 45-90° off axis (was SIDE_CHASE) + CURRICULUM_SIDE_MANEUVERING, // Stage 9: Side chase + 30° turns + CURRICULUM_DIVE_ATTACK, // Stage 10: 500m altitude advantage, 75° nose-down dive + CURRICULUM_ZOOM_ATTACK, // Stage 11: 500m below, 75° nose-up, near max speed + CURRICULUM_REAR_CHASE, // Stage 12: Target 90-150° off axis (rear quarters) + CURRICULUM_REAR_MANEUVERING, // Stage 13: Rear chase + 30° turns + CURRICULUM_FULL_PREDICTABLE, // Stage 14: 360° spawn, heading correlated (flying away) + CURRICULUM_FULL_RANDOM, // Stage 15: 360° spawn, random heading, 30° turns + CURRICULUM_MEDIUM_TURNS, // Stage 16: 360° spawn, random heading, 45° turns + CURRICULUM_HARD_MANEUVERING, // Stage 17: 60° turns + weave patterns + CURRICULUM_CROSSING, // Stage 18: 45 degree deflection shots + CURRICULUM_EVASIVE, // Stage 19: Reactive evasion (hardest) + CURRICULUM_AUTOACE, // Stage 20: Full AutoAce opponent (two-way combat) + CURRICULUM_COUNT // = 21 +} CurriculumStage; + +// Forward declarations for stage spawn functions +struct Dogfight; // Forward declare Dogfight struct +typedef void (*SpawnFn)(struct Dogfight*, Vec3, Vec3); + +// Stage configuration: consolidates all stage metadata in one place +typedef struct StageConfig { + int n; // Stage number (for easy lookup) + SpawnFn spawn; // Function pointer to spawn function + const char* description; // Human-readable description + float weight; // Difficulty weight (0.0-1.0) + int max_steps; // Episode length - fail to kill = terminal + score -1 + float angle_min_deg; // Min angle off axis (for documentation) + float angle_max_deg; // Max angle off axis + int bank; // Target bank angle in degrees (0=straight, 30, 45, 60) +} StageConfig; + +// Forward declarations of spawn functions (defined below, after Dogfight struct) +static void spawn_tail_chase(struct Dogfight *env, Vec3 player_pos, Vec3 player_vel); +static void spawn_head_on(struct Dogfight *env, Vec3 player_pos, Vec3 player_vel); +static void spawn_vertical(struct Dogfight *env, Vec3 player_pos, Vec3 player_vel); +static void spawn_gentle_turns(struct Dogfight *env, Vec3 player_pos, Vec3 player_vel); +static void spawn_offset(struct Dogfight *env, Vec3 player_pos, Vec3 player_vel); +static void spawn_angled(struct Dogfight *env, Vec3 player_pos, Vec3 player_vel); +static void spawn_side(struct Dogfight *env, Vec3 player_pos, Vec3 player_vel); +static void spawn_dive_attack(struct Dogfight *env, Vec3 player_pos, Vec3 player_vel); +static void spawn_zoom_attack(struct Dogfight *env, Vec3 
player_pos, Vec3 player_vel); +static void spawn_rear(struct Dogfight *env, Vec3 player_pos, Vec3 player_vel); +static void spawn_full_predictable(struct Dogfight *env, Vec3 player_pos, Vec3 player_vel); +static void spawn_full_random(struct Dogfight *env, Vec3 player_pos, Vec3 player_vel); +static void spawn_medium_turns(struct Dogfight *env, Vec3 player_pos, Vec3 player_vel); +static void spawn_hard_maneuvering(struct Dogfight *env, Vec3 player_pos, Vec3 player_vel); +static void spawn_crossing(struct Dogfight *env, Vec3 player_pos, Vec3 player_vel); +static void spawn_evasive(struct Dogfight *env, Vec3 player_pos, Vec3 player_vel); +static void spawn_autoace(struct Dogfight *env, Vec3 player_pos, Vec3 player_vel); + +// Stage configuration table - single source of truth for all stage metadata +// Updated 2026-01-25 to split SIDE_CHASE into 3 stages (SIDE_NEAR, SIDE_MID, SIDE_FAR) +// Updated 2026-01-26: Consolidated spawn_side/spawn_rear functions use angle fields +// Updated 2026-01-27: Added DIVE_ATTACK (10) and ZOOM_ATTACK (11) stages +// max_steps field is now for documentation only; episode length comes from Python config +static const StageConfig STAGES[CURRICULUM_COUNT] = { + // n spawn_fn description weight max_steps ang_min ang_max bank + {0, spawn_tail_chase, "Target ahead, same heading", 0.01f, 300, 0, 10, 0}, + {1, spawn_head_on, "Target coming toward us", 0.02f, 300, 170, 180, 0}, + {2, spawn_vertical, "Target above/below", 0.05f, 500, 0, 20, 0}, + {3, spawn_gentle_turns, "Target ahead, 30 deg turns", 0.10f, 1000, 0, 30, 30}, + {4, spawn_offset, "Large lateral offset, 30 deg turns", 0.15f, 1000, 0, 45, 30}, + {5, spawn_angled, "Offset + heading variance, 30 deg", 0.20f, 1200, 0, 22, 30}, + {6, spawn_side, "15-45 deg off axis", 0.25f, 1500, 15, 45, 0}, + {7, spawn_side, "30-60 deg off axis", 0.30f, 1800, 30, 60, 0}, + {8, spawn_side, "45-90 deg off axis", 0.35f, 2000, 45, 90, 0}, + {9, spawn_side, "45-90 deg + 30 deg turns", 0.40f, 3000, 45, 90, 30}, + {10, spawn_dive_attack, "Dive attack, 500m altitude adv", 0.45f, 2500, 120, 175, 0}, + {11, spawn_zoom_attack, "Zoom attack, 75 deg nose-up", 0.50f, 3000, 120, 175, 0}, + {12, spawn_rear, "90-150 deg off axis", 0.58f, 3500, 90, 150, 0}, + {13, spawn_rear, "90-150 deg + 30 deg turns", 0.62f, 3500, 90, 150, 30}, + {14, spawn_full_predictable, "360 deg, heading correlated", 0.68f, 4000, 0, 360, 0}, + {15, spawn_full_random, "360 deg random heading, 30 deg", 0.74f, 4000, 0, 360, 30}, + {16, spawn_medium_turns, "360 deg, 45 deg bank turns", 0.82f, 4000, 0, 360, 45}, + {17, spawn_hard_maneuvering, "360 deg, 60 deg banks + weave", 0.90f, 4000, 0, 360, 60}, + {18, spawn_crossing, "45 deg deflection shots", 0.95f, 4000, 45, 45, 0}, + {19, spawn_evasive, "Reactive break turns", 1.00f, 4000, 0, 360, 60}, + {20, spawn_autoace, "AutoAce intelligent opponent", 1.00f, 6000, 0, 360, 0}, +}; + +// Spawn randomization parameters - stage-dependent ranges for variety +typedef struct SpawnRandomization { + float speed_min, speed_max; // Initial airspeed range (m/s) + float pitch_max_deg; // Max pitch deviation (±degrees) + float bank_max_deg; // Max bank deviation (±degrees) + float throttle_min, throttle_max; // Initial throttle range +} SpawnRandomization; + +// Get spawn randomization parameters for a given stage +// Earlier stages = tighter ranges (easier), later stages = wider ranges (harder) +// Updated 2026-01-27: Stage boundaries adjusted for 20-stage curriculum (added DIVE_ATTACK, ZOOM_ATTACK) +static inline SpawnRandomization 
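+// Example: stage 5 falls in the "stage <= 7" bucket below: spawn speed 70-95 m/s,
+// up to +/-10 deg pitch and +/-20 deg bank, throttle in [0.35, 0.65].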
get_spawn_randomization(int stage) { + if (stage <= 3) return (SpawnRandomization){75, 85, 5, 10, 0.45f, 0.55f}; + if (stage <= 7) return (SpawnRandomization){70, 95, 10, 20, 0.35f, 0.65f}; + if (stage <= 13) return (SpawnRandomization){65, 105, 15, 30, 0.30f, 0.70f}; + return (SpawnRandomization){60, 110, 15, 45, 0.25f, 0.80f}; +} + +#define DT 0.02f + +#define WORLD_HALF_X 4000.0f +#define WORLD_HALF_Y 4000.0f +#define WORLD_MAX_Z 5000.0f +#define MAX_SPEED 250.0f + +#define INV_WORLD_HALF_X 0.00025f // 1/4000 +#define INV_WORLD_HALF_Y 0.00025f // 1/4000 +#define INV_WORLD_MAX_Z 0.0002f // 1/5000 +#define INV_MAX_SPEED 0.004f // 1/250 +#define INV_PI 0.31830988618f // 1/PI +#define INV_HALF_PI 0.63661977236f // 2/PI (i.e., 1/(PI*0.5)) +#define DEG_TO_RAD 0.01745329252f // PI/180 + +#define GUN_RANGE 500.0f // meters +#define INV_GUN_RANGE 0.002f // 1/500 +#define GUN_CONE_ANGLE 0.087f // ~5 degrees in radians +#define FIRE_COOLDOWN 10 // ticks (0.2 seconds at 50Hz) + +typedef struct Log { + float episode_return; + float episode_length; + float score; // 1.0 on kill, 0.0 on failure + float perf; // Raw kills (becomes kill_rate after vec_log divides by n) + float sp_player_kills; // Self-play only: player kills (TUI shows P:## O:##) + float sp_opp_kills; // Self-play only: opponent kills + float shots_fired; + float accuracy; + float stage; + + // RAW SUMS - exported to Python, become correct averages after vec_log divides by n + float total_stage_weight; // Sum of stage weights (exported as avg_stage_weight) + float total_abs_bias; // Sum of |aileron_bias| (exported as avg_abs_bias) + float stage_sum; // Sum of stages (exported as avg_stage) + float total_control_rate; // Sum of per-episode mean squared deltas (exported as avg_control_rate) + float base_stage_kills; // Kills at int(curriculum_target) - for per-stage gating + float base_stage_eps; // Episodes at int(curriculum_target) - for per-stage gating + + // PER-ENV RATIOS - for C debugging only, NOT exported (garbage after vec_log aggregation) + float avg_stage_weight; // = total_stage_weight / n (per-env only) + float avg_abs_bias; // = total_abs_bias / n (per-env only) + float avg_stage; // = stage_sum / n (per-env only) + float kill_rate; // = perf / n (per-env only - Python uses 'perf' instead) + float ultimate; // = kill_rate * avg_stage_weight (per-env only) + float n; +} Log; + +typedef enum DeathReason { + DEATH_NONE = 0, // Episode still running + DEATH_KILL = 1, // Player scored a kill (success) + DEATH_OOB = 2, // Out of bounds + DEATH_TIMEOUT = 3, // Max steps reached + DEATH_SUPERSONIC = 4 // Physics blowup +} DeathReason; + +typedef struct RewardConfig { + // Positive shaping + float aim_scale; // Continuous aiming reward (default 0.05) + float closing_scale; // +N per m/s closing (default 0.003) + // Penalties + float neg_g; // -N per unit G below 0.5 (default 0.02) - enforces "pull to turn" + float control_rate_penalty; // Penalty for (action - prev_action)^2 (default 0, sweepable) + // Thresholds + float speed_min; // Stall threshold (default 50.0) +} RewardConfig; + +typedef struct Client { + Camera3D camera; + float width; + float height; + + float cam_distance; + float cam_azimuth; + float cam_elevation; + int camera_mode; // 0 = follow target, 1 = midpoint view + bool is_dragging; + float last_mouse_x; + float last_mouse_y; + + Model plane_model; + Texture2D plane_texture; + bool model_loaded; + + float propeller_angle; // Current propeller rotation (radians) +} Client; + +typedef struct Dogfight { + float 
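+// For reference, the defaults quoted in the RewardConfig field comments above,
+// collected into one initializer (illustrative only; the values actually used
+// are whatever init() receives and copies into env->rcfg):
+//     RewardConfig rcfg = {
+//         .aim_scale = 0.05f, .closing_scale = 0.003f, .neg_g = 0.02f,
+//         .control_rate_penalty = 0.0f, .speed_min = 50.0f,
+//     };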
*observations; + float *actions; + float *rewards; + unsigned char *terminals; + + // Opponent perspective buffers (for dual self-play with Multiprocessing) + // Written during c_step() if non-NULL, same size as observations/rewards + float *opponent_observations; // Opponent's view of the world + float *opponent_rewards; // = -player_reward (zero-sum) + + Log log; + Client *client; + int tick; + int max_steps; + float episode_return; + Plane player; + Plane opponent; + // Per-episode precomputed values (for curriculum learning) + float gun_cone_angle; // Hit detection cone (radians) - FIXED at 5° + float cos_gun_cone; // cosf(gun_cone_angle) - for hit detection + // Opponent autopilot + AutopilotState opponent_ap; + // AutoAce intelligent opponent (stage 20+) + AutoAceState opponent_ace; + // Observation scheme + int obs_scheme; + int obs_size; + // Reward configuration (sweepable) + RewardConfig rcfg; + // Episode-level tracking (reset each episode) + int kill; // 1 if killed this episode, 0 otherwise + int opp_kill; // 1 if opponent killed player this episode (self-play) + float episode_shots_fired; // For accuracy tracking + // Curriculum learning + int curriculum_enabled; // 0 = off (legacy spawning), 1 = on + int curriculum_randomize; // 0 = progressive (training), 1 = random stage each episode (eval) + int total_episodes; // Cumulative episodes (persists across resets) + CurriculumStage stage; // Current difficulty stage (set globally by Python) + float curriculum_target; // Float 0.0-15.0 for probabilistic stage assignment + int is_initialized; // Flag to preserve curriculum state across re-init (for Multiprocessing) + // Anti-spinning + float total_aileron_usage; // Accumulated |aileron| input (for spin death) + float aileron_bias; // Cumulative signed aileron (for directional penalty) + float episode_control_rate; // Sum of squared control deltas this episode + // Episode reward accumulators (for DEBUG summaries) + float sum_r_closing; + float sum_r_speed; // Stall penalty + float sum_r_neg_g; + float sum_r_rudder; + float sum_r_aim; + float sum_r_rate; // Control rate penalty + // Aiming diagnostics (reset each episode, for DEBUG output) + float best_aim_angle; // Best (smallest) aim angle achieved (radians) + int ticks_in_cone; // Ticks where aim_dot > cos_gun_cone + float closest_dist; // Closest approach to target (meters) + // Flight envelope diagnostics (reset each episode, for DEBUG output) + float max_g, min_g; // Peak G-forces experienced + float max_bank; // Peak bank angle (abs, radians) + float max_pitch; // Peak pitch angle (abs, radians) + float min_speed, max_speed; // Speed envelope (m/s) + float min_alt, max_alt; // Altitude envelope (m) + float sum_throttle; // For computing mean throttle + int trigger_pulls; // Times trigger was pulled (>0.5) + int prev_trigger; // For edge detection + DeathReason death_reason; + DeathReason last_death_reason; // For rendering: what ended the previous episode + int last_winner; // For rendering: 1=player won, -1=opponent won, 0=draw/timeout + // Debug + int env_num; // Environment index (for filtering debug output) + // Observation highlighting (for visual debugging) + unsigned char obs_highlight[25]; // 1 = highlight this observation with red arrow (max scheme is 25 obs) + // Last opponent actions (for Python access in tests) + float last_opp_actions[5]; // throttle, elevator, aileron, rudder, trigger + // Camera control + int camera_follow_opponent; // 0 = follow player (default), 1 = follow opponent + // Self-play: external 
opponent actions override (Phase 1) + float opponent_actions_override[5]; // [throttle, elevator, aileron, rudder, trigger] + int use_opponent_override; // 0 = use autopilot, 1 = use override + // Head-on lockout: disable guns until planes pass each other (only for head-on spawns) + int head_on_lockout; // 1 = guns locked until pass-through detected + float prev_rel_dot; // Previous dot(rel_pos, rel_vel) for detecting pass + // Eval spawn mode: 0 = random (default), 1 = opponent_advantage (for testing opponent kill) + int eval_spawn_mode; + // Previous actions for control rate penalty + float prev_elevator; // Previous elevator for rate penalty + float prev_aileron; // Previous aileron for rate penalty + float prev_rudder; // Previous rudder for rate penalty +} Dogfight; + +#include "dogfight_observations.h" + +void init(Dogfight *env, int obs_scheme, RewardConfig *rcfg, int curriculum_enabled, int curriculum_randomize, int env_num) { + env->log = (Log){0}; + env->tick = 0; + env->env_num = env_num; + env->episode_return = 0.0f; + env->client = NULL; + // Observation scheme + env->obs_scheme = (obs_scheme >= 0 && obs_scheme < OBS_SCHEME_COUNT) ? obs_scheme : 0; + env->obs_size = OBS_SIZES[env->obs_scheme]; + // Gun cone for HIT DETECTION - fixed at 5° + env->gun_cone_angle = GUN_CONE_ANGLE; + env->cos_gun_cone = cosf(env->gun_cone_angle); + autopilot_init(&env->opponent_ap); + autoace_init(&env->opponent_ace); + // Reward configuration (copy from provided config) + env->rcfg = *rcfg; + // Episode tracking + env->kill = 0; + env->episode_shots_fired = 0.0f; + + env->curriculum_enabled = curriculum_enabled; + env->curriculum_randomize = curriculum_randomize; + if (!env->is_initialized) { + env->total_episodes = 0; + env->stage = CURRICULUM_TAIL_CHASE; // Stage managed globally by Python + env->curriculum_target = 0.0f; // Start at stage 0 + if (DEBUG >= 1) { + fprintf(stderr, "[INIT] FIRST init ptr=%p env_num=%d - setting total_episodes=0, stage=0\n", (void*)env, env_num); + } + } else { + if (DEBUG >= 1) { + fprintf(stderr, "[INIT] RE-init ptr=%p env_num=%d - preserving total_episodes=%d, stage=%d\n", + (void*)env, env_num, env->total_episodes, env->stage); + } + } + env->is_initialized = 1; + env->total_aileron_usage = 0.0f; + + // Initialize previous actions for control rate penalty + env->prev_elevator = 0.0f; + env->prev_aileron = 0.0f; + env->prev_rudder = 0.0f; + + memset(env->obs_highlight, 0, sizeof(env->obs_highlight)); + + // Self-play: default to autopilot-controlled opponent + env->use_opponent_override = 0; + memset(env->opponent_actions_override, 0, sizeof(env->opponent_actions_override)); + + // Opponent buffers: NULL by default, set by Python if dual self-play is enabled + env->opponent_observations = NULL; + env->opponent_rewards = NULL; + + // Eval spawn mode: 0 = random (default) + env->eval_spawn_mode = 0; +} + +void set_obs_highlight(Dogfight *env, int *indices, int count) { + memset(env->obs_highlight, 0, sizeof(env->obs_highlight)); + for (int i = 0; i < count && i < 25; i++) { + if (indices[i] >= 0 && indices[i] < 25) { + env->obs_highlight[indices[i]] = 1; + } + } +} + +// Helper: set opponent reward (only if buffer exists, for dual self-play) +static inline void set_opponent_reward(Dogfight *env, float reward) { + if (env->opponent_rewards != NULL) { + env->opponent_rewards[0] = reward; + } +} + +void add_log(Dogfight *env) { + // Level 1: Episode summary (one line, easy to grep) + if (DEBUG >= 1 && env->env_num == 0) { + const char* death_names[] = {"NONE", 
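+// Minimal usage sketch (hypothetical standalone driver; the buffer wiring is an
+// assumption based on how observations/actions/rewards/terminals are indexed in
+// this file, and the scheme/curriculum arguments are illustrative):
+//     Dogfight env = {0};
+//     float obs[26] = {0}, actions[5] = {0}, rewards[1] = {0};
+//     unsigned char terminals[1] = {0};
+//     RewardConfig rcfg = {.aim_scale = 0.05f, .closing_scale = 0.003f,
+//                          .neg_g = 0.02f, .speed_min = 50.0f};
+//     env.observations = obs; env.actions = actions;
+//     env.rewards = rewards; env.terminals = terminals;
+//     init(&env, /*obs_scheme=*/0, &rcfg, /*curriculum_enabled=*/1,
+//          /*curriculum_randomize=*/0, /*env_num=*/0);
+//     c_reset(&env);
+//     c_step(&env);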
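+// Aggregation note: everything accumulated in add_log() is a raw sum; the
+// Python-side vec_log divides each field by log.n when aggregating across
+// environments, so perf becomes the kill rate and stage_sum becomes the mean
+// stage, while the per-env ratio fields (kill_rate, avg_stage, avg_abs_bias,
+// avg_stage_weight, ultimate) are only meaningful for C-side debug output.
+// Worked example of the mastery gating below: with curriculum_target = 7.3,
+// get_curriculum_stage() spawns stage 7 for ~70% of episodes and stage 8 for
+// ~30%, and mastery_stage rounds 7.3 to 7, so base_stage_kills/base_stage_eps
+// track performance on the 70% majority stage.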
"KILL", "OOB", "TIMEOUT", "SUPERSONIC"}; + float mean_ail = env->total_aileron_usage / fmaxf((float)env->tick, 1.0f); + printf("EP tick=%d ret=%.2f death=%s kill=%d stage=%d total_eps=%d mean_ail=%.2f bias=%.1f\n", + env->tick, env->episode_return, death_names[env->death_reason], + env->kill, env->stage, env->total_episodes, mean_ail, env->aileron_bias); + } + + // Level 2: Reward breakdown (which components dominated?) + if (DEBUG >= 2 && env->env_num == 0) { + printf(" SHAPING: closing=%+.2f aim=%+.2f\n", env->sum_r_closing, env->sum_r_aim); + printf(" PENALTY: stall=%.2f neg_g=%.2f rudder=%.2f rate=%.2f\n", + env->sum_r_speed, env->sum_r_neg_g, env->sum_r_rudder, env->sum_r_rate); + printf(" AIM: best=%.1f° in_cone=%d/%d (%.0f%%) closest=%.0fm\n", + env->best_aim_angle * RAD_TO_DEG, + env->ticks_in_cone, env->tick, + 100.0f * env->ticks_in_cone / fmaxf((float)env->tick, 1.0f), + env->closest_dist); + } + + // Level 3: Flight envelope and control statistics + if (DEBUG >= 3 && env->env_num == 0) { + float mean_throttle = env->sum_throttle / fmaxf((float)env->tick, 1.0f); + printf(" FLIGHT: G=[%+.1f,%+.1f] bank=%.0f° pitch=%.0f° speed=[%.0f,%.0f] alt=[%.0f,%.0f]\n", + env->min_g, env->max_g, + env->max_bank * RAD_TO_DEG, env->max_pitch * RAD_TO_DEG, + env->min_speed, env->max_speed, + env->min_alt, env->max_alt); + printf(" CONTROL: mean_throttle=%.0f%% trigger_pulls=%d shots=%d\n", + mean_throttle * 100.0f, env->trigger_pulls, (int)env->episode_shots_fired); + } + + if (DEBUG >= 10) printf("=== ADD_LOG ===\n"); + if (DEBUG >= 10) printf(" kill=%d, episode_return=%.2f, tick=%d\n", env->kill, env->episode_return, env->tick); + if (DEBUG >= 10) printf(" episode_shots_fired=%.0f, reward=%.2f\n", env->episode_shots_fired, env->rewards[0]); + env->log.episode_return += env->episode_return; + env->log.episode_length += (float)env->tick; + env->log.perf += env->kill ? 1.0f : 0.0f; + // Self-play kill tracking: log at stage 20+ (AutoAce or self-play both have bidirectional combat) + if (env->stage >= CURRICULUM_AUTOACE) { + env->log.sp_player_kills += env->kill ? 1.0f : 0.0f; + env->log.sp_opp_kills += env->opp_kill ? 1.0f : 0.0f; + } + env->log.score += env->rewards[0]; + env->log.shots_fired += env->episode_shots_fired; + env->log.accuracy = (env->log.shots_fired > 0.0f) ? (env->log.perf / env->log.shots_fired * 100.0f) : 0.0f; + env->log.stage = (float)env->stage; + + env->log.total_stage_weight += STAGES[env->stage].weight; // coeffs to scale metrics based on difficulty + env->log.total_abs_bias += fabsf(env->aileron_bias); + env->log.stage_sum += (float)env->stage; // Accumulate for avg_stage + // Mean squared control delta per step this episode (lower = smoother control) + env->log.total_control_rate += env->episode_control_rate / fmaxf((float)env->tick, 1.0f); + + // Track performance at MAJORITY stage (the one we're trying to master) + // At target 0.9, majority is stage 1 (90% of episodes), not stage 0 + int mastery_stage = (int)(env->curriculum_target + 0.5f); // round, not floor + if (env->stage == mastery_stage) { + env->log.base_stage_kills += env->kill ? 
1.0f : 0.0f; + env->log.base_stage_eps += 1.0f; + } + + env->log.n += 1.0f; + env->log.kill_rate = env->log.perf / fmaxf(env->log.n, 1.0f); + env->log.avg_stage = env->log.stage_sum / env->log.n; + env->log.avg_abs_bias = env->log.total_abs_bias / env->log.n; + env->log.avg_stage_weight = env->log.total_stage_weight / env->log.n; + + // Ultimate = kill_rate * difficulty (no bias penalty) + env->log.ultimate = env->log.kill_rate * env->log.avg_stage_weight; + + if (DEBUG >= 10) printf(" log.perf=%.2f, log.shots_fired=%.0f, log.n=%.0f\n", env->log.perf, env->log.shots_fired, env->log.n); +} + +// ============================================================================ +// Curriculum Learning: Stage-specific spawn functions +// ============================================================================ + +// Stage advancement handled in add_log() based on recent kill rate +CurriculumStage get_curriculum_stage(Dogfight *env) { + if (!env->curriculum_enabled) return CURRICULUM_FULL_RANDOM; + if (env->curriculum_randomize) { + // Random stage for eval mode - tests all difficulties + return (CurriculumStage)(rand() % CURRICULUM_COUNT); + } + + // Probabilistic selection based on curriculum_target + float target = env->curriculum_target; + int base = (int)target; + float frac = target - (float)base; + + if (base >= CURRICULUM_COUNT - 1) { + return (CurriculumStage)(CURRICULUM_COUNT - 1); + } + + // Probabilistic: if rand < frac, use base+1, else base + if (rndf(0, 1) < frac) { + return (CurriculumStage)(base + 1); + } + return (CurriculumStage)base; +} + +// Stage 0: TAIL_CHASE - Opponent ahead, same heading (easiest) +static void spawn_tail_chase(Dogfight *env, Vec3 player_pos, Vec3 player_vel) { + // Opponent 200-400m ahead with guaranteed minimum offset + // At 300m, 5° gun cone = ~26m radius for hits + // Minimum 26m y-offset guarantees ~5° at 300m (more at closer range) + // Signed offset with minimum magnitude: either [-50, -26] or [26, 50] + float y_sign = rndf(0, 1) > 0.5f ? 1.0f : -1.0f; + float y_offset = y_sign * rndf(26, 50); + + Vec3 opp_pos = vec3( + player_pos.x + rndf(200, 400), + player_pos.y + y_offset, // Min 26m = ~5° at 300m + player_pos.z + rndf(-38, 38) // z can still vary + ); + reset_plane(&env->opponent, opp_pos, player_vel); + env->opponent_ap.mode = AP_STRAIGHT; +} + +// Stage 1: HEAD_ON - Opponent coming toward us +static void spawn_head_on(Dogfight *env, Vec3 player_pos, Vec3 player_vel) { + // Opponent 400-600m ahead, facing us (opposite velocity) + Vec3 opp_pos = vec3( + player_pos.x + rndf(400, 600), + player_pos.y + rndf(-50, 50), + player_pos.z + rndf(-30, 30) + ); + Vec3 opp_vel = vec3(-player_vel.x, -player_vel.y, player_vel.z); + reset_plane(&env->opponent, opp_pos, opp_vel); + env->opponent_ap.mode = AP_STRAIGHT; +} + +// Stage 18: CROSSING - 45 degree deflection shots (reduced from 90° - see CURRICULUM_PLANS.md) +// 90° deflection is historically nearly impossible; 45° is achievable with proper lead +static void spawn_crossing(Dogfight *env, Vec3 player_pos, Vec3 player_vel) { + // Opponent 300-500m to the side, flying at 45° angle (not perpendicular) + float side = rndf(0, 1) > 0.5f ? 
1.0f : -1.0f; + Vec3 opp_pos = vec3( + player_pos.x + rndf(100, 200), + player_pos.y + side * rndf(300, 500), + player_pos.z + rndf(-50, 50) + ); + // 45° crossing velocity: opponent flies at 45° angle across player's path + // cos(45°) ≈ 0.707, sin(45°) ≈ 0.707 + float speed = norm3(player_vel); + float cos45 = 0.7071f; + float sin45 = 0.7071f; + // side=+1 (right): fly toward (-45°) = (cos, -sin) to cross leftward + // side=-1 (left): fly toward (+45°) = (cos, +sin) to cross rightward + Vec3 opp_vel = vec3(speed * cos45, -side * speed * sin45, 0); + reset_plane(&env->opponent, opp_pos, opp_vel); + env->opponent_ap.mode = AP_STRAIGHT; +} + +// Stage 2: VERTICAL - Above or below player +static void spawn_vertical(Dogfight *env, Vec3 player_pos, Vec3 player_vel) { + // Opponent 200-400m ahead, 200-400m above OR below + float vert = rndf(0, 1) > 0.5f ? 1.0f : -1.0f; + float alt_offset = vert * rndf(200, 400); + Vec3 opp_pos = vec3( + player_pos.x + rndf(200, 400), + player_pos.y + rndf(-50, 50), + clampf(player_pos.z + alt_offset, 300, 4700) + ); + reset_plane(&env->opponent, opp_pos, player_vel); + env->opponent_ap.mode = AP_LEVEL; // Maintain altitude + + // Speed boost only when opponent is ABOVE us (climbing needs energy, diving doesn't) + if (opp_pos.z > player_pos.z) { + env->player.vel = mul3(env->player.vel, 1.15f); + } +} + +// Stage 3: GENTLE_TURNS - Opponent does gentle turns (30°) +static void spawn_gentle_turns(Dogfight *env, Vec3 player_pos, Vec3 player_vel) { + // Random spawn position (similar to original) + Vec3 opp_pos = vec3( + player_pos.x + rndf(200, 500), + player_pos.y + rndf(-100, 100), + player_pos.z + rndf(-50, 50) + ); + reset_plane(&env->opponent, opp_pos, player_vel); + // Randomly choose turn direction - gentle 30° bank + env->opponent_ap.mode = rndf(0, 1) > 0.5f ? AP_TURN_LEFT : AP_TURN_RIGHT; + env->opponent_ap.target_bank = (float)STAGES[env->stage].bank * DEG_TO_RAD; +} + +// Stage 4: OFFSET - Large lateral/vertical offset, same heading +// Teaches: Finding and tracking targets not directly in front +static void spawn_offset(Dogfight *env, Vec3 player_pos, Vec3 player_vel) { + // Opponent 150-300m ahead with LARGE lateral/vertical offset + Vec3 opp_pos = vec3( + player_pos.x + rndf(150, 300), + player_pos.y + rndf(-200, 200), // Large lateral - can be way to the side + clampf(player_pos.z + rndf(-150, 150), 300, 4700) // Large vertical + ); + reset_plane(&env->opponent, opp_pos, player_vel); + env->opponent_ap.mode = rndf(0, 1) > 0.5f ? AP_TURN_LEFT : AP_TURN_RIGHT; + env->opponent_ap.target_bank = (float)STAGES[env->stage].bank * DEG_TO_RAD; +} + +// Stage 5: ANGLED - Offset + different heading (±22°) +// Teaches: Pursuit geometry when target isn't flying your direction (small angle) +static void spawn_angled(Dogfight *env, Vec3 player_pos, Vec3 player_vel) { + Vec3 opp_pos = vec3( + player_pos.x + rndf(200, 400), + player_pos.y + rndf(-150, 150), + clampf(player_pos.z + rndf(-100, 100), 300, 4700) + ); + + // Heading offset: ±22° from player (reduced from ±45° for smoother progression) + float heading_offset = rndf(-0.385f, 0.385f); // ~22° in radians + float player_heading = atan2f(player_vel.y, player_vel.x); + float opp_heading = player_heading + heading_offset; + + float speed = norm3(player_vel); + Vec3 opp_vel = vec3(speed * cosf(opp_heading), speed * sinf(opp_heading), 0); + + reset_plane(&env->opponent, opp_pos, opp_vel); + env->opponent.ori = quat_from_axis_angle(vec3(0, 0, 1), opp_heading); + + env->opponent_ap.mode = rndf(0, 1) > 0.5f ? 
AP_TURN_LEFT : AP_TURN_RIGHT; + env->opponent_ap.target_bank = (float)STAGES[env->stage].bank * DEG_TO_RAD; +} + +// Stages 6-9: Unified side spawn - uses angle_min_deg, angle_max_deg, bank from STAGES +// Stages 6-8: Target off axis, flying away (no turns) +// Stage 9: Same geometry + 30° turns +static void spawn_side(Dogfight *env, Vec3 player_pos, Vec3 player_vel) { + const StageConfig* cfg = &STAGES[env->stage]; + + float side = rndf(0, 1) > 0.5f ? 1.0f : -1.0f; + float az_min = cfg->angle_min_deg * DEG_TO_RAD; + float az_max = cfg->angle_max_deg * DEG_TO_RAD; + float azimuth = side * rndf(az_min, az_max); + + float dist = rndf(300, 500); + float phi = rndf(-0.2f, 0.2f); // ±11° elevation + + Vec3 opp_pos = vec3( + player_pos.x + dist * cosf(azimuth), + player_pos.y + dist * sinf(azimuth), + clampf(player_pos.z + dist * sinf(phi), 300, 4700) + ); + + float away_heading = azimuth; + float opp_heading = away_heading + rndf(-0.35f, 0.35f); // ±20° variance + + float speed = norm3(player_vel); + Vec3 opp_vel = vec3(speed * cosf(opp_heading), speed * sinf(opp_heading), 0); + + reset_plane(&env->opponent, opp_pos, opp_vel); + env->opponent.ori = quat_from_axis_angle(vec3(0, 0, 1), opp_heading); + + // AP mode based on bank field: 0 = straight, >0 = turning + if (cfg->bank > 0) { + env->opponent_ap.mode = rndf(0, 1) > 0.5f ? AP_TURN_LEFT : AP_TURN_RIGHT; + env->opponent_ap.target_bank = (float)cfg->bank * DEG_TO_RAD; + } else { + env->opponent_ap.mode = AP_STRAIGHT; + } + + // Stages 8-9: Boost player speed 15% for pursuit advantage (wide angle chase) + if (env->stage >= CURRICULUM_SIDE_FAR) { + env->player.vel = mul3(env->player.vel, 1.15f); + } + + // Speed boost when opponent is above (climbing needs energy) + if (opp_pos.z > player_pos.z) { + env->player.vel = mul3(env->player.vel, 1.15f); + } +} + +// Stage 10: DIVE_ATTACK - Player starts 500m above, 75° nose-down for fast catch-up +// Same spawn geometry as spawn_rear (90-150° off axis), but player has massive altitude/energy advantage +static void spawn_dive_attack(Dogfight *env, Vec3 player_pos, Vec3 player_vel) { + const StageConfig* cfg = &STAGES[env->stage]; + + // Same azimuth geometry as spawn_rear (90-150° off axis) + float side = rndf(0, 1) > 0.5f ? 1.0f : -1.0f; + float az_min = cfg->angle_min_deg * DEG_TO_RAD; + float az_max = cfg->angle_max_deg * DEG_TO_RAD; + float azimuth = side * rndf(az_min, az_max); + + float dist = rndf(300, 500); + // Opponent spawns 500m BELOW player (big altitude advantage) + Vec3 opp_pos = vec3( + player_pos.x + dist * cosf(azimuth), + player_pos.y + dist * sinf(azimuth), + clampf(player_pos.z - 500 + rndf(-50, 50), 300, 4700) + ); + + float opp_heading = azimuth + rndf(-0.35f, 0.35f); // ±20° variance + float speed = norm3(player_vel); + Vec3 opp_vel = vec3(speed * cosf(opp_heading), speed * sinf(opp_heading), 0); + + reset_plane(&env->opponent, opp_pos, opp_vel); + env->opponent.ori = quat_from_axis_angle(vec3(0, 0, 1), opp_heading); + env->opponent_ap.mode = rndf(0, 1) > 0.5f ? 
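+// Sanity check on the pitch sign convention used here and in spawn_zoom_attack
+// (assuming the usual right-handed rotation convention for quat_from_axis_angle):
+// rotating the body-forward axis (1,0,0) about +Y by +75° gives roughly
+// (cos75°, 0, -sin75°) ≈ (0.26, 0, -0.97), i.e. the nose points steeply down,
+// matching the "positive pitch = nose down" comment; -75° is the mirrored
+// nose-up case.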
AP_STRAIGHT : AP_LEVEL; + + // Player starts 75° nose down (same heading, just pitched) + // Pitch rotation is around body Y-axis (right wing) + // Positive pitch around Y = nose down in this coordinate system + float pitch = 75.0f * DEG_TO_RAD; + Quat pitch_quat = quat_from_axis_angle(vec3(0, 1, 0), pitch); + env->player.ori = pitch_quat; + // Velocity matches pitch direction (diving toward target area) + env->player.vel = quat_rotate(pitch_quat, player_vel); + env->player.prev_vel = env->player.vel; +} + +// Stage 11: ZOOM_ATTACK - Player starts ~300m below, 75° nose-up, high speed + // Opposite of dive_attack: player zooms up toward target with high energy +static void spawn_zoom_attack(Dogfight *env, Vec3 player_pos, Vec3 player_vel) { + const StageConfig* cfg = &STAGES[env->stage]; + + // Same azimuth code as spawn_rear, but stage 11 pulls 120-175° off axis from STAGES + float side = rndf(0, 1) > 0.5f ? 1.0f : -1.0f; + float az_min = cfg->angle_min_deg * DEG_TO_RAD; + float az_max = cfg->angle_max_deg * DEG_TO_RAD; + float azimuth = side * rndf(az_min, az_max); + + float dist = rndf(300, 500); + // Opponent spawns ~300m ABOVE player (player zooms up) + Vec3 opp_pos = vec3( + player_pos.x + dist * cosf(azimuth), + player_pos.y + dist * sinf(azimuth), + clampf(player_pos.z + 300 + rndf(-50, 50), 300, 4700) + ); + + float opp_heading = azimuth + rndf(-0.35f, 0.35f); // ±20° variance + float opp_speed = norm3(player_vel); + Vec3 opp_vel = vec3(opp_speed * cosf(opp_heading), opp_speed * sinf(opp_heading), 0); + + reset_plane(&env->opponent, opp_pos, opp_vel); + env->opponent.ori = quat_from_axis_angle(vec3(0, 0, 1), opp_heading); + env->opponent_ap.mode = rndf(0, 1) > 0.5f ? AP_STRAIGHT : AP_LEVEL; + + // Player starts 75° nose UP at high speed (~110-120 m/s, set below) + // Pitch rotation is around body Y-axis (right wing) + // Negative pitch around Y = nose up in this coordinate system + float pitch = -75.0f * DEG_TO_RAD; + Quat pitch_quat = quat_from_axis_angle(vec3(0, 1, 0), pitch); + env->player.ori = pitch_quat; + + // Set player to high speed (reduced from 140-150 due to instability at extreme pitch) + float zoom_speed = rndf(110, 120); + Vec3 base_vel = vec3(zoom_speed, 0, 0); + env->player.vel = quat_rotate(pitch_quat, base_vel); + env->player.prev_vel = env->player.vel; +} + +// Stages 12-13: Unified rear spawn - uses angle_min_deg, angle_max_deg, bank from STAGES + // Stage 12: Target 90-150° off axis (rear quarters), 50/50 straight/level + // Stage 13: Same geometry + 30° turns (unchanged, zoom_attack inserted before these) +static void spawn_rear(Dogfight *env, Vec3 player_pos, Vec3 player_vel) { + const StageConfig* cfg = &STAGES[env->stage]; + + float side = rndf(0, 1) > 0.5f ?
1.0f : -1.0f; + float az_min = cfg->angle_min_deg * DEG_TO_RAD; + float az_max = cfg->angle_max_deg * DEG_TO_RAD; + float azimuth = side * rndf(az_min, az_max); + + float dist = rndf(300, 500); + // Opponent spawns ~500m below player (large altitude advantage for rear chase) + Vec3 opp_pos = vec3( + player_pos.x + dist * cosf(azimuth), + player_pos.y + dist * sinf(azimuth), + clampf(player_pos.z - 500 + rndf(-50, 50), 300, 4700) + ); + + float opp_heading = azimuth + rndf(-0.35f, 0.35f); // ±20° variance + float speed = norm3(player_vel); + Vec3 opp_vel = vec3(speed * cosf(opp_heading), speed * sinf(opp_heading), 0); + + reset_plane(&env->opponent, opp_pos, opp_vel); + env->opponent.ori = quat_from_axis_angle(vec3(0, 0, 1), opp_heading); + + // AP mode based on bank field: 0 = 50/50 straight/level, >0 = turning + if (cfg->bank > 0) { + env->opponent_ap.mode = rndf(0, 1) > 0.5f ? AP_TURN_LEFT : AP_TURN_RIGHT; + env->opponent_ap.target_bank = (float)cfg->bank * DEG_TO_RAD; + } else { + env->opponent_ap.mode = rndf(0, 1) > 0.5f ? AP_STRAIGHT : AP_LEVEL; + } + + // Speed boost for rear chase - player starts faster to close the gap + env->player.vel = mul3(env->player.vel, 1.25f); + env->player.prev_vel = env->player.vel; +} + +// Stage 14: FULL_PREDICTABLE - 360° spawn, heading correlated (flying away) +// Teaches: Full sphere awareness with predictable heading +static void spawn_full_predictable(Dogfight *env, Vec3 player_pos, Vec3 player_vel) { + // Full 360° spawn + float azimuth = rndf(-M_PI, M_PI); + float dist = rndf(300, 600); + float phi = rndf(-0.3f, 0.3f); // ±17° elevation + + Vec3 opp_pos = vec3( + player_pos.x + dist * cosf(azimuth) * cosf(phi), + player_pos.y + dist * sinf(azimuth) * cosf(phi), + clampf(player_pos.z + dist * sinf(phi), 300, 4700) + ); + + // KEY: Heading is CORRELATED - flying away from player + float away_heading = azimuth; // Same direction as spawn angle = flying away + float opp_heading = away_heading + rndf(-0.52f, 0.52f); // ±30° variance + + float speed = norm3(player_vel); + Vec3 opp_vel = vec3(speed * cosf(opp_heading), speed * sinf(opp_heading), 0); + + reset_plane(&env->opponent, opp_pos, opp_vel); + env->opponent.ori = quat_from_axis_angle(vec3(0, 0, 1), opp_heading); + env->opponent_ap.mode = rndf(0, 1) > 0.5f ? AP_TURN_LEFT : AP_TURN_RIGHT; + env->opponent_ap.target_bank = (float)STAGES[env->stage].bank * DEG_TO_RAD; +} + +// Stage 15: FULL_RANDOM - 360° spawn, random heading, 30° turns +// Teaches: Random heading (key difficulty!) 
- must read observation to determine velocity +static void spawn_full_random(Dogfight *env, Vec3 player_pos, Vec3 player_vel) { + // Random direction in 3D sphere (300-600m from player) + float dist = rndf(300, 600); + float theta = rndf(0, 2.0f * M_PI); // Azimuth: 0-360° + float phi = rndf(-0.3f, 0.3f); // Elevation: ±17° (keep near level) + + Vec3 opp_pos = vec3( + player_pos.x + dist * cosf(theta) * cosf(phi), + player_pos.y + dist * sinf(theta) * cosf(phi), + clampf(player_pos.z + dist * sinf(phi), 300, 4700) + ); + + // Random velocity direction (not necessarily toward/away from player) + float vel_theta = rndf(0, 2.0f * M_PI); + float speed = norm3(player_vel); + Vec3 opp_vel = vec3(speed * cosf(vel_theta), speed * sinf(vel_theta), 0); + + reset_plane(&env->opponent, opp_pos, opp_vel); + + // Set orientation to match velocity direction (yaw rotation around Z) + env->opponent.ori = quat_from_axis_angle(vec3(0, 0, 1), vel_theta); + + // 3 modes: straight, level, turns (still 30° - steeper turns come in stage 16) + float r = rndf(0, 1); + if (r < 0.2f) env->opponent_ap.mode = AP_STRAIGHT; + else if (r < 0.4f) env->opponent_ap.mode = AP_LEVEL; + else env->opponent_ap.mode = rndf(0, 1) > 0.5f ? AP_TURN_LEFT : AP_TURN_RIGHT; + + env->opponent_ap.target_bank = (float)STAGES[env->stage].bank * DEG_TO_RAD; +} + +// Stage 16: MEDIUM_TURNS - 360° spawn, random heading, 45° turns +// Teaches: Steeper 45° turns (first introduction of harder turns) +static void spawn_medium_turns(Dogfight *env, Vec3 player_pos, Vec3 player_vel) { + // Same geometry as FULL_RANDOM + float dist = rndf(300, 600); + float theta = rndf(0, 2.0f * M_PI); // Azimuth: 0-360° + float phi = rndf(-0.3f, 0.3f); // Elevation: ±17° (keep near level) + + Vec3 opp_pos = vec3( + player_pos.x + dist * cosf(theta) * cosf(phi), + player_pos.y + dist * sinf(theta) * cosf(phi), + clampf(player_pos.z + dist * sinf(phi), 300, 4700) + ); + + // Random velocity direction (uncorrelated with position) + float vel_theta = rndf(0, 2.0f * M_PI); + float speed = norm3(player_vel); + Vec3 opp_vel = vec3(speed * cosf(vel_theta), speed * sinf(vel_theta), 0); + + reset_plane(&env->opponent, opp_pos, opp_vel); + env->opponent.ori = quat_from_axis_angle(vec3(0, 0, 1), vel_theta); + + // 5 modes with 45° turns + float r = rndf(0, 1); + if (r < 0.2f) env->opponent_ap.mode = AP_STRAIGHT; + else if (r < 0.4f) env->opponent_ap.mode = AP_LEVEL; + else if (r < 0.6f) env->opponent_ap.mode = AP_TURN_LEFT; + else if (r < 0.8f) env->opponent_ap.mode = AP_TURN_RIGHT; + else env->opponent_ap.mode = AP_CLIMB; + + env->opponent_ap.target_bank = (float)STAGES[env->stage].bank * DEG_TO_RAD; +} + +// Stage 17: HARD_MANEUVERING - Hard turns (60°) and weave patterns +static void spawn_hard_maneuvering(Dogfight *env, Vec3 player_pos, Vec3 player_vel) { + Vec3 opp_pos = vec3( + player_pos.x + rndf(200, 400), + player_pos.y + rndf(-100, 100), + player_pos.z + rndf(-50, 50) + ); + reset_plane(&env->opponent, opp_pos, player_vel); + + // Pick from hard maneuver modes + float r = rndf(0, 1); + if (r < 0.3f) { + env->opponent_ap.mode = AP_HARD_TURN_LEFT; + } else if (r < 0.6f) { + env->opponent_ap.mode = AP_HARD_TURN_RIGHT; + } else { + env->opponent_ap.mode = AP_WEAVE; + env->opponent_ap.phase = rndf(0, 2.0f * M_PI); // Random start phase + } +} + +// Stage 19: EVASIVE - Opponent reacts to player position (hardest) +static void spawn_evasive(Dogfight *env, Vec3 player_pos, Vec3 player_vel) { + // Override player altitude to near max (3500-4500m) for high-altitude combat + 
env->player.pos.z = rndf(3500, 4500); + player_pos.z = env->player.pos.z; // Update local copy for opponent spawn + + // Spawn in various positions (like FULL_RANDOM) + float dist = rndf(300, 500); + float theta = rndf(0, 2.0f * M_PI); + float phi = rndf(-0.3f, 0.3f); + + Vec3 opp_pos = vec3( + player_pos.x + dist * cosf(theta) * cosf(phi), + player_pos.y + dist * sinf(theta) * cosf(phi), + clampf(player_pos.z + dist * sinf(phi), 2500, 4800) + ); + + float vel_theta = rndf(0, 2.0f * M_PI); + float speed = norm3(player_vel); + Vec3 opp_vel = vec3(speed * cosf(vel_theta), speed * sinf(vel_theta), 0); + + reset_plane(&env->opponent, opp_pos, opp_vel); + env->opponent.ori = quat_from_axis_angle(vec3(0, 0, 1), vel_theta); + + // Mix of hard modes with AP_EVASIVE dominant + float r = rndf(0, 1); + if (r < 0.4f) { + env->opponent_ap.mode = AP_EVASIVE; + } else if (r < 0.55f) { + env->opponent_ap.mode = AP_HARD_TURN_LEFT; + } else if (r < 0.7f) { + env->opponent_ap.mode = AP_HARD_TURN_RIGHT; + } else if (r < 0.85f) { + env->opponent_ap.mode = AP_WEAVE; + env->opponent_ap.phase = rndf(0, 2.0f * M_PI); + } else { + // 15% chance of regular turn modes (still steep 60°) + env->opponent_ap.mode = rndf(0,1) > 0.5f ? AP_TURN_LEFT : AP_TURN_RIGHT; + env->opponent_ap.target_bank = (float)STAGES[env->stage].bank * DEG_TO_RAD; + } +} + +// Stage 20: AUTOACE - Intelligent adversarial opponent (two-way combat) +static void spawn_autoace(Dogfight *env, Vec3 player_pos, Vec3 player_vel) { + // Similar to EVASIVE but with higher altitude range for energy maneuvers + // Override player altitude to mid-high (2500-4000m) + env->player.pos.z = rndf(2500, 4000); + player_pos.z = env->player.pos.z; + + // Spawn opponent in various positions (360 degree, varied distance) + float dist = rndf(400, 700); // Slightly further for AutoAce + float theta = rndf(0, 2.0f * M_PI); + float phi = rndf(-0.25f, 0.25f); // ±14 deg elevation + + Vec3 opp_pos = vec3( + player_pos.x + dist * cosf(theta) * cosf(phi), + player_pos.y + dist * sinf(theta) * cosf(phi), + clampf(player_pos.z + dist * sinf(phi), 2000, 4500) + ); + + float vel_theta = rndf(0, 2.0f * M_PI); + float speed = norm3(player_vel); + Vec3 opp_vel = vec3(speed * cosf(vel_theta), speed * sinf(vel_theta), 0); + + reset_plane(&env->opponent, opp_pos, opp_vel); + env->opponent.ori = quat_from_axis_angle(vec3(0, 0, 1), vel_theta); + + // AutoAce starts with lag pursuit (will adapt based on situation) + env->opponent_ap.mode = AP_PURSUIT_LAG; + + // Initialize AutoAce state + autoace_init(&env->opponent_ace); +} + +// EVAL spawn: True randomization with alternating advantages +// Used when curriculum_randomize=1 - creates varied, fair combat scenarios +static void spawn_eval_random(Dogfight *env, Vec3 player_pos, Vec3 player_vel) { + // Always clear head-on lockout first (only set if we choose head-on spawn) + env->head_on_lockout = 0; + env->prev_rel_dot = 0.0f; + + // Alternate who gets advantage based on episode count + int player_advantage = (env->total_episodes % 2 == 0); + + // Random spawn type distribution: + // 40% - tactical (one behind/side of other) + // 30% - neutral (both at angles, neither clearly advantaged) + // 20% - energy (altitude/speed difference) + // 10% - head-on (with gun lockout until pass) + float spawn_roll = rndf(0, 1); + + // Base altitude for combat (mid-altitude) + float base_alt = rndf(2000, 3500); + env->player.pos.z = base_alt; + player_pos.z = base_alt; + float speed = norm3(player_vel); + + if (spawn_roll < 0.40f) { + // TACTICAL: One plane 
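+// Note: player_advantage flips with total_episodes, so even-numbered eval
+// episodes hand the favorable geometry/energy to the player and odd-numbered
+// ones to the opponent; over many episodes the eval spawns are balanced by
+// construction.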
behind/side of other (clear advantage) + float dist = rndf(300, 600); + float angle_off = rndf(120, 180) * DEG_TO_RAD; // Behind (120-180° off nose) + float side = rndf(0, 1) > 0.5f ? 1.0f : -1.0f; + + if (player_advantage) { + // Player behind opponent - player has advantage + float opp_heading = rndf(0, 2.0f * M_PI); + Vec3 opp_pos = vec3( + player_pos.x + rndf(300, 500), + player_pos.y + side * rndf(50, 150), + clampf(player_pos.z + rndf(-100, 100), 500, 4500) + ); + Vec3 opp_vel = vec3(speed * cosf(opp_heading), speed * sinf(opp_heading), 0); + reset_plane(&env->opponent, opp_pos, opp_vel); + env->opponent.ori = quat_from_axis_angle(vec3(0, 0, 1), opp_heading); + // Player heading toward opponent + Vec3 to_opp = sub3(opp_pos, player_pos); + float player_heading = atan2f(to_opp.y, to_opp.x); + env->player.vel = vec3(speed * cosf(player_heading), speed * sinf(player_heading), 0); + env->player.ori = quat_from_axis_angle(vec3(0, 0, 1), player_heading); + } else { + // Opponent behind player - opponent has advantage + float player_heading = rndf(0, 2.0f * M_PI); + env->player.vel = vec3(speed * cosf(player_heading), speed * sinf(player_heading), 0); + env->player.ori = quat_from_axis_angle(vec3(0, 0, 1), player_heading); + // Opponent behind + Vec3 opp_pos = vec3( + player_pos.x - cosf(player_heading) * dist + side * sinf(player_heading) * rndf(50, 150), + player_pos.y - sinf(player_heading) * dist - side * cosf(player_heading) * rndf(50, 150), + clampf(player_pos.z + rndf(-100, 100), 500, 4500) + ); + Vec3 to_player = sub3(player_pos, opp_pos); + float opp_heading = atan2f(to_player.y, to_player.x); + Vec3 opp_vel = vec3(speed * cosf(opp_heading), speed * sinf(opp_heading), 0); + reset_plane(&env->opponent, opp_pos, opp_vel); + env->opponent.ori = quat_from_axis_angle(vec3(0, 0, 1), opp_heading); + } + env->opponent_ap.mode = rndf(0, 1) > 0.5f ? AP_TURN_LEFT : AP_TURN_RIGHT; + env->opponent_ap.target_bank = rndf(30, 60) * DEG_TO_RAD; + + } else if (spawn_roll < 0.70f) { + // NEUTRAL: Both at angles, converging - fair fight + float dist = rndf(400, 700); + float theta = rndf(0, 2.0f * M_PI); + Vec3 opp_pos = vec3( + player_pos.x + dist * cosf(theta), + player_pos.y + dist * sinf(theta), + clampf(player_pos.z + rndf(-200, 200), 500, 4500) + ); + // Both heading toward a point between them (converging) + Vec3 midpoint = mul3(add3(player_pos, opp_pos), 0.5f); + Vec3 player_to_mid = sub3(midpoint, player_pos); + Vec3 opp_to_mid = sub3(midpoint, opp_pos); + // Add some angle offset so they're not perfectly converging + float player_heading = atan2f(player_to_mid.y, player_to_mid.x) + rndf(-0.5f, 0.5f); + float opp_heading = atan2f(opp_to_mid.y, opp_to_mid.x) + rndf(-0.5f, 0.5f); + + env->player.vel = vec3(speed * cosf(player_heading), speed * sinf(player_heading), 0); + env->player.ori = quat_from_axis_angle(vec3(0, 0, 1), player_heading); + Vec3 opp_vel = vec3(speed * cosf(opp_heading), speed * sinf(opp_heading), 0); + reset_plane(&env->opponent, opp_pos, opp_vel); + env->opponent.ori = quat_from_axis_angle(vec3(0, 0, 1), opp_heading); + env->opponent_ap.mode = rndf(0, 1) > 0.5f ? 
AP_TURN_LEFT : AP_TURN_RIGHT; + env->opponent_ap.target_bank = rndf(30, 45) * DEG_TO_RAD; + + } else if (spawn_roll < 0.90f) { + // ENERGY: Altitude or speed advantage + float dist = rndf(400, 600); + float theta = rndf(0, 2.0f * M_PI); + float alt_diff = rndf(300, 600); // Significant altitude difference + + Vec3 opp_pos; + if (player_advantage) { + // Player higher (energy advantage) + env->player.pos.z = base_alt + alt_diff; + player_pos.z = env->player.pos.z; + opp_pos = vec3( + player_pos.x + dist * cosf(theta), + player_pos.y + dist * sinf(theta), + base_alt + ); + } else { + // Opponent higher (energy advantage) + opp_pos = vec3( + player_pos.x + dist * cosf(theta), + player_pos.y + dist * sinf(theta), + base_alt + alt_diff + ); + } + // Random headings + float player_heading = rndf(0, 2.0f * M_PI); + float opp_heading = rndf(0, 2.0f * M_PI); + env->player.vel = vec3(speed * cosf(player_heading), speed * sinf(player_heading), 0); + env->player.ori = quat_from_axis_angle(vec3(0, 0, 1), player_heading); + Vec3 opp_vel = vec3(speed * cosf(opp_heading), speed * sinf(opp_heading), 0); + reset_plane(&env->opponent, opp_pos, opp_vel); + env->opponent.ori = quat_from_axis_angle(vec3(0, 0, 1), opp_heading); + env->opponent_ap.mode = AP_PURSUIT_LEAD; // Aggressive pursuit for energy fights + + } else { + // HEAD-ON: Facing each other (rare, 10%) - guns locked until they pass + float dist = rndf(600, 900); // Start further apart + float theta = rndf(0, 2.0f * M_PI); + + Vec3 opp_pos = vec3( + player_pos.x + dist * cosf(theta), + player_pos.y + dist * sinf(theta), + clampf(player_pos.z + rndf(-100, 100), 500, 4500) + ); + // Player faces opponent + Vec3 to_opp = sub3(opp_pos, player_pos); + float player_heading = atan2f(to_opp.y, to_opp.x); + // Opponent faces player (opposite direction) + float opp_heading = player_heading + M_PI; + + env->player.vel = vec3(speed * cosf(player_heading), speed * sinf(player_heading), 0); + env->player.ori = quat_from_axis_angle(vec3(0, 0, 1), player_heading); + Vec3 opp_vel = vec3(speed * cosf(opp_heading), speed * sinf(opp_heading), 0); + reset_plane(&env->opponent, opp_pos, opp_vel); + env->opponent.ori = quat_from_axis_angle(vec3(0, 0, 1), opp_heading); + + // HEAD-ON LOCKOUT: Disable guns until they pass each other + env->head_on_lockout = 1; + // Initialize tracking for pass detection + Vec3 rel_pos = sub3(opp_pos, player_pos); + Vec3 rel_vel = sub3(opp_vel, env->player.vel); + env->prev_rel_dot = dot3(rel_pos, rel_vel); + + env->opponent_ap.mode = AP_STRAIGHT; // Fly straight initially + if (DEBUG >= 1) { + fprintf(stderr, "[EVAL-SPAWN] Head-on spawn - guns locked until pass\n"); + } + } + + // Reset autopilot PID state + env->opponent_ap.prev_vz = 0.0f; + env->opponent_ap.prev_bank_error = 0.0f; + + if (DEBUG >= 1) { + fprintf(stderr, "[EVAL-SPAWN] ep=%d advantage=%s spawn_type=%.0f%% dist=%.0fm\n", + env->total_episodes, player_advantage ? 
"PLAYER" : "OPPONENT", + spawn_roll * 100, norm3(sub3(env->opponent.pos, env->player.pos))); + } +} + +// Test spawn: Opponent behind player with advantage but not instant kill +// Player is 30° off opponent's nose - opponent must maneuver to get the shot +// Opponent is 400m behind, clear positional advantage +static void spawn_eval_opponent_advantage(Dogfight *env, Vec3 player_pos, Vec3 player_vel) { + env->head_on_lockout = 0; + env->prev_rel_dot = 0.0f; + + // Player at base altitude, flying straight along +X + float base_alt = 2500.0f; + float speed = norm3(player_vel); + if (speed < 70.0f) speed = 80.0f; + + env->player.pos = vec3(0, 0, base_alt); + env->player.vel = vec3(speed, 0, 0); + env->player.ori = quat_from_axis_angle(vec3(0, 0, 1), 0.0f); // Flying +X + env->player.throttle = 0.5f; + + // Opponent 400m behind player + float dist = 400.0f; + Vec3 opp_pos = vec3(-dist, 0, base_alt); // Directly behind player + + // Opponent heading: 30° off from pointing at player + // Player is at (0,0), opponent at (-400,0) + // Direct heading to player would be 0° (pointing +X) + // We offset 30° so player is 30° off opponent's nose + float angle_off_nose = 30.0f * DEG_TO_RAD; + float opp_heading = angle_off_nose; // Pointing 30° left of player + + Vec3 opp_vel = vec3(speed * cosf(opp_heading), speed * sinf(opp_heading), 0); + reset_plane(&env->opponent, opp_pos, opp_vel); + env->opponent.ori = quat_from_axis_angle(vec3(0, 0, 1), opp_heading); + env->opponent.throttle = 0.6f; + + // Autopilot: pursuit mode to track player + env->opponent_ap.mode = AP_PURSUIT_LEAD; + env->opponent_ap.prev_vz = 0.0f; + env->opponent_ap.prev_bank_error = 0.0f; + + if (DEBUG >= 1) { + Vec3 to_player = sub3(env->player.pos, opp_pos); + float actual_dist = norm3(to_player); + Vec3 opp_fwd = quat_rotate(env->opponent.ori, vec3(1, 0, 0)); + Vec3 to_player_norm = normalize3(to_player); + float aim_dot = dot3(opp_fwd, to_player_norm); + float aim_angle = acosf(clampf(aim_dot, -1.0f, 1.0f)) * RAD_TO_DEG; + fprintf(stderr, "[EVAL-OPP-ADV] dist=%.0fm aim_angle=%.1f° (cone=5°)\n", + actual_dist, aim_angle); + } +} + +// Master spawn function: dispatches to stage-specific spawner +void spawn_by_curriculum(Dogfight *env, Vec3 player_pos, Vec3 player_vel) { + // For eval mode (curriculum_randomize=1), use spawn based on eval_spawn_mode + if (env->curriculum_randomize) { + if (env->eval_spawn_mode == 1) { + // Mode 1: opponent advantage - for testing if opponent can kill player + spawn_eval_opponent_advantage(env, player_pos, player_vel); + } else { + // Mode 0 (default): random spawn + spawn_eval_random(env, player_pos, player_vel); + } + // Eval mode uses stage 20 (AutoAce) max_steps for fair combat duration + env->max_steps = STAGES[CURRICULUM_AUTOACE].max_steps; // 6000 + return; + } + + CurriculumStage new_stage = get_curriculum_stage(env); + + // Log stage transitions + if (new_stage != env->stage) { + if (DEBUG >= 1) { + fprintf(stderr, "[STAGE_CHANGE] ptr=%p env=%d eps=%d: stage %d -> %d\n", + (void*)env, env->env_num, env->total_episodes, env->stage, new_stage); + fflush(stderr); + } + env->stage = new_stage; + } + + // Use function pointer from STAGES table (replaces 18-case switch) + if (env->stage < CURRICULUM_COUNT) { + STAGES[env->stage].spawn(env, player_pos, player_vel); + + // Use per-stage max_steps for advanced stages (8+) where episode length matters + // Earlier stages use global max_steps from Python config for fast iteration + // The original "training regression" was from variable episode lengths 
during early training + // By stage 8+, agents are stable enough to handle longer episodes + if (env->stage >= CURRICULUM_SIDE_FAR) { // Stage 8+ + env->max_steps = STAGES[env->stage].max_steps; + } + // else: keep env->max_steps from Python init (already set) + } else { + spawn_evasive(env, player_pos, player_vel); // Fallback for invalid stage + } + + // Reset autopilot PID state after spawning + env->opponent_ap.prev_vz = 0.0f; + env->opponent_ap.prev_bank_error = 0.0f; +} + +// Legacy spawn (for curriculum_enabled=0) +void spawn_legacy(Dogfight *env, Vec3 player_pos, Vec3 player_vel) { + Vec3 opp_pos = vec3( + player_pos.x + rndf(200, 500), + player_pos.y + rndf(-100, 100), + player_pos.z + rndf(-50, 50) + ); + reset_plane(&env->opponent, opp_pos, player_vel); + + // Handle autopilot: randomize if configured, reset PID state + if (env->opponent_ap.randomize_on_reset) { + autopilot_randomize(&env->opponent_ap); + } + env->opponent_ap.prev_vz = 0.0f; + env->opponent_ap.prev_bank_error = 0.0f; +} + +// ============================================================================ +// Global curriculum control (called from Python based on aggregate kill_rate) +// ============================================================================ + +// Set curriculum stage for a single environment (used by vec version) +void set_curriculum_stage(Dogfight *env, int stage) { + if (stage >= 0 && stage < CURRICULUM_COUNT) { + env->stage = (CurriculumStage)stage; + env->curriculum_target = (float)stage; // Sync target for probabilistic selection + } +} + +// Set curriculum target (float 0.0-15.0) for probabilistic stage assignment +void set_curriculum_target(Dogfight *env, float target) { + env->curriculum_target = fminf(fmaxf(target, 0.0f), (float)(CURRICULUM_COUNT - 1)); +} + +// ============================================================================ + +void c_reset(Dogfight *env) { + // Save last episode result for rendering before reset + env->last_death_reason = env->death_reason; + if (env->death_reason == DEATH_KILL && env->kill) { + env->last_winner = 1; // Player won (got the kill) + } else if (env->death_reason == DEATH_KILL) { + env->last_winner = -1; // Opponent won (player was killed) + } else { + env->last_winner = 0; // Draw/timeout/OOB + } + + // Curriculum stage is now managed globally by Python based on aggregate kill_rate + // (see set_curriculum_stage() called from training loop) + + env->total_episodes++; + + env->tick = 0; + env->episode_return = 0.0f; + + // Clear episode tracking (safe to clear kill after curriculum used it) + env->kill = 0; + env->opp_kill = 0; + env->episode_shots_fired = 0.0f; + env->total_aileron_usage = 0.0f; + env->aileron_bias = 0.0f; + env->episode_control_rate = 0.0f; + + // Reset reward accumulators + env->sum_r_closing = 0.0f; + env->sum_r_speed = 0.0f; + env->sum_r_neg_g = 0.0f; + env->sum_r_rudder = 0.0f; + env->sum_r_aim = 0.0f; + env->sum_r_rate = 0.0f; + env->death_reason = DEATH_NONE; + + // Reset aiming diagnostics + env->best_aim_angle = M_PI; // Start at worst (180°) + env->ticks_in_cone = 0; + env->closest_dist = 10000.0f; // Start at max + + // Reset flight envelope diagnostics + env->max_g = 1.0f; // Start at 1G (level flight) + env->min_g = 1.0f; + env->max_bank = 0.0f; + env->max_pitch = 0.0f; + env->min_speed = 10000.0f; // Start at max + env->max_speed = 0.0f; + env->min_alt = 10000.0f; // Start at max + env->max_alt = 0.0f; + env->sum_throttle = 0.0f; + env->trigger_pulls = 0; + env->prev_trigger = 0; + + // Head-on lockout (only 
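+// How the gating fields and set_curriculum_target() above are intended to fit
+// together (sketch; the actual rule lives in the Python training loop, and the
+// threshold/step here are hypothetical): the trainer reads the aggregated
+// base_stage_kills / base_stage_eps for the current mastery stage and, when
+// that kill rate clears some bar, nudges the target, e.g.
+//     if (base_kills / fmaxf(base_eps, 1.0f) > 0.80f)
+//         set_curriculum_target(env, env->curriculum_target + 0.1f);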
set by spawn_eval_random for head-on spawns) + env->head_on_lockout = 0; + env->prev_rel_dot = 0.0f; + + // Reset previous actions for control rate penalty + env->prev_elevator = 0.0f; + env->prev_aileron = 0.0f; + env->prev_rudder = 0.0f; + + // Gun cone for hit detection - stays fixed at 5° + env->cos_gun_cone = cosf(env->gun_cone_angle); + + // Spawn player at random position with base velocity + // Use most of the sky (800-4200m) but avoid very low altitudes + Vec3 pos = vec3(rndf(-500, 500), rndf(-500, 500), rndf(800, 4200)); + Vec3 vel = vec3(80, 0, 0); // Base speed, will be randomized below + reset_plane(&env->player, pos, vel); + + // Spawn opponent based on curriculum stage (or legacy if disabled) + if (env->curriculum_enabled) { + spawn_by_curriculum(env, pos, vel); + + // Phase 1: Apply stage-dependent speed randomization to both planes + SpawnRandomization r = get_spawn_randomization(env->stage); + float target_speed = rndf(r.speed_min, r.speed_max); + float speed_ratio = target_speed / 80.0f; // Scale from base speed + env->player.vel = mul3(env->player.vel, speed_ratio); + env->player.prev_vel = env->player.vel; // Keep in sync + env->opponent.vel = mul3(env->opponent.vel, speed_ratio); + env->opponent.prev_vel = env->opponent.vel; + + // Phase 2: Apply stage-dependent throttle randomization + env->player.throttle = rndf(r.throttle_min, r.throttle_max); + env->opponent_ap.throttle = rndf(r.throttle_min, r.throttle_max); // Autopilot throttle + } else { + spawn_legacy(env, pos, vel); + } + + if (DEBUG >= 10) printf("=== RESET ===\n"); + if (DEBUG >= 10) printf("kill=%d, episode_shots_fired=%.0f (now cleared)\n", env->kill, env->episode_shots_fired); + if (DEBUG >= 10) printf("player_pos=(%.1f, %.1f, %.1f)\n", pos.x, pos.y, pos.z); + if (DEBUG >= 10) printf("player_vel=(%.1f, %.1f, %.1f) speed=%.1f\n", vel.x, vel.y, vel.z, norm3(vel)); + if (DEBUG >= 10) printf("opponent_pos=(%.1f, %.1f, %.1f)\n", env->opponent.pos.x, env->opponent.pos.y, env->opponent.pos.z); + if (DEBUG >= 10) printf("initial_dist=%.1f m, stage=%d\n", norm3(sub3(env->opponent.pos, pos)), env->stage); + + compute_observations(env); +#if DEBUG >= 5 + print_observations(env); +#endif +} + +// Check if shooter hits target (cone-based hit detection) +bool check_hit(Plane *shooter, Plane *target, float cos_gun_cone) { + Vec3 to_target = sub3(target->pos, shooter->pos); + float dist = norm3(to_target); + if (dist > GUN_RANGE) return false; + if (dist < 1.0f) return false; // Too close (avoid division issues) + + Vec3 forward = quat_rotate(shooter->ori, vec3(1, 0, 0)); + Vec3 to_target_norm = normalize3(to_target); + float cos_angle = dot3(to_target_norm, forward); + return cos_angle > cos_gun_cone; +} +void c_step(Dogfight *env) { + env->tick++; + env->rewards[0] = 0.0f; + env->terminals[0] = 0; + + if (DEBUG >= 10) printf("\n========== TICK %d ==========\n", env->tick); + if (DEBUG >= 10) printf("=== ACTIONS ===\n"); + if (DEBUG >= 10) printf("throttle_raw=%.3f -> throttle=%.3f\n", env->actions[0], (env->actions[0] + 1.0f) * 0.5f); + if (DEBUG >= 10) printf("elevator=%.3f -> pitch_rate=%.3f rad/s\n", env->actions[1], env->actions[1] * MAX_PITCH_RATE); + if (DEBUG >= 10) printf("ailerons=%.3f -> roll_rate=%.3f rad/s\n", env->actions[2], env->actions[2] * MAX_ROLL_RATE); + if (DEBUG >= 10) printf("rudder=%.3f -> yaw_rate=%.3f rad/s\n", env->actions[3], -env->actions[3] * MAX_YAW_RATE); + if (DEBUG >= 10) printf("trigger=%.3f (fires if >0.5)\n", env->actions[4]); + + // Player uses full physics with actions + 
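+// Control-flow note for the opponent update below: each tick the opponent is
+// advanced through exactly one of three paths - (1) the external self-play
+// override supplied by Python, (2) AutoAce at stage 20+, or (3) the scripted
+// autopilot for earlier curriculum stages - and an AP_STRAIGHT opponent with no
+// override falls through to the passive step_plane() integrator.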
step_plane_with_physics(&env->player, env->actions, DT); + + // Opponent uses either external override (self-play) or autopilot + if (env->use_opponent_override) { + // Self-play mode: use externally provided actions from Python + float opp_actions[5]; + for (int i = 0; i < 5; i++) { + opp_actions[i] = env->opponent_actions_override[i]; + env->last_opp_actions[i] = opp_actions[i]; + } + + step_plane_with_physics(&env->opponent, opp_actions, DT); + + // Check if self-play opponent shot the player (two-way combat) + // Skip if in head-on lockout (guns disabled until pass) + if (opp_actions[4] > 0.5f && !env->head_on_lockout) { + // Set fire cooldown for visual tracer effect + if (env->opponent.fire_cooldown == 0) { + env->opponent.fire_cooldown = FIRE_COOLDOWN; + } + if (check_hit(&env->opponent, &env->player, env->cos_gun_cone)) { + // Player was shot down by self-play opponent! + if (DEBUG >= 1) { + printf("[SELF-PLAY] Player shot down by opponent policy!\n"); + } + env->opp_kill = 1; // Track opponent kill for self-play stats + env->death_reason = DEATH_KILL; + env->rewards[0] = -1.0f; + set_opponent_reward(env, 1.0f); // Opponent wins (zero-sum) + env->terminals[0] = 1; + add_log(env); + c_reset(env); + return; + } + } + } else if (env->opponent_ap.mode != AP_STRAIGHT) { + // Standard autopilot mode (curriculum stages) + float opp_actions[5]; + + // Use AutoAce for stage 20+ (intelligent adversarial opponent) + if (env->stage >= CURRICULUM_AUTOACE) { + autoace_step(&env->opponent_ap, &env->opponent_ace, + &env->opponent, &env->player, opp_actions, DT); + } else { + // Legacy autopilot for curriculum stages 0-19 + env->opponent_ap.threat_pos = env->player.pos; // For AP_EVASIVE mode + autopilot_step(&env->opponent_ap, &env->opponent, opp_actions, DT); + } + + // Store opponent actions for Python access (testing) + for (int i = 0; i < 5; i++) { + env->last_opp_actions[i] = opp_actions[i]; + } + + step_plane_with_physics(&env->opponent, opp_actions, DT); + + // Check if AutoAce shot the player (two-way combat at stage 20+) + if (env->stage >= CURRICULUM_AUTOACE && opp_actions[4] > 0.5f) { + if (check_hit(&env->opponent, &env->player, env->cos_gun_cone)) { + // Player was shot down by AutoAce! 
+ if (DEBUG >= 1) { + printf("[AUTOACE] Player shot down by AutoAce!\n"); + } + env->opp_kill = 1; // Track opponent kill for self-play stats + env->death_reason = DEATH_KILL; // Reuse KILL (opponent's kill) + env->rewards[0] = -1.0f; // Penalty for dying + set_opponent_reward(env, 1.0f); // Opponent wins (zero-sum) + env->terminals[0] = 1; + add_log(env); + c_reset(env); + return; + } + } + } else { + step_plane(&env->opponent, DT); + } + + // Track aileron usage for monitoring (no death penalty - see BISECTION.md) + env->total_aileron_usage += fabsf(env->actions[2]); + env->aileron_bias += env->actions[2]; + +#if DEBUG >= 3 + // Track flight envelope diagnostics (only when debugging - expensive) + { + Plane *dbg_p = &env->player; + if (dbg_p->g_force > env->max_g) env->max_g = dbg_p->g_force; + if (dbg_p->g_force < env->min_g) env->min_g = dbg_p->g_force; + float speed = norm3(dbg_p->vel); + if (speed < env->min_speed) env->min_speed = speed; + if (speed > env->max_speed) env->max_speed = speed; + if (dbg_p->pos.z < env->min_alt) env->min_alt = dbg_p->pos.z; + if (dbg_p->pos.z > env->max_alt) env->max_alt = dbg_p->pos.z; + // Bank angle from quaternion + float bank = atan2f(2.0f * (dbg_p->ori.w * dbg_p->ori.x + dbg_p->ori.y * dbg_p->ori.z), + 1.0f - 2.0f * (dbg_p->ori.x * dbg_p->ori.x + dbg_p->ori.y * dbg_p->ori.y)); + if (fabsf(bank) > env->max_bank) env->max_bank = fabsf(bank); + // Pitch angle from quaternion + float pitch = asinf(clampf(2.0f * (dbg_p->ori.w * dbg_p->ori.y - dbg_p->ori.z * dbg_p->ori.x), -1.0f, 1.0f)); + if (fabsf(pitch) > env->max_pitch) env->max_pitch = fabsf(pitch); + // Throttle accumulator + env->sum_throttle += dbg_p->throttle; + // Trigger pull edge detection + int trigger_now = (env->actions[4] > 0.5f) ? 1 : 0; + if (trigger_now && !env->prev_trigger) env->trigger_pulls++; + env->prev_trigger = trigger_now; + } +#endif + + // === Head-on pass detection (for eval mode gun lockout) === + if (env->head_on_lockout) { + // Detect when planes pass each other: dot(rel_pos, rel_vel) flips sign + Vec3 rel_pos = sub3(env->opponent.pos, env->player.pos); + Vec3 rel_vel = sub3(env->opponent.vel, env->player.vel); + float rel_dot = dot3(rel_pos, rel_vel); + + // Sign flip from negative (approaching) to positive (separating) = passed + if (env->prev_rel_dot < 0 && rel_dot >= 0) { + env->head_on_lockout = 0; + if (DEBUG >= 1) { + fprintf(stderr, "[HEAD-ON] Planes passed - guns unlocked at tick %d\n", env->tick); + } + } + env->prev_rel_dot = rel_dot; + } + + // === Combat (Phase 5) === + Plane *p = &env->player; + Plane *o = &env->opponent; + float reward = 0.0f; + + // Decrement fire cooldowns + // Note: AutoAce (stage 20+) handles opponent cooldown internally in autoace.h + // Self-play mode also uses opponent cooldown for visual tracer + if (p->fire_cooldown > 0) p->fire_cooldown--; + if ((env->use_opponent_override || env->stage < CURRICULUM_AUTOACE) && o->fire_cooldown > 0) o->fire_cooldown--; + + // Player fires: action[4] > 0.5 and cooldown ready and not in head-on lockout + if (DEBUG >= 10) printf("trigger=%.3f, cooldown=%d, lockout=%d\n", env->actions[4], p->fire_cooldown, env->head_on_lockout); + if (env->actions[4] > 0.5f && p->fire_cooldown == 0 && !env->head_on_lockout) { + p->fire_cooldown = FIRE_COOLDOWN; + env->episode_shots_fired += 1.0f; + if (DEBUG >= 10) printf("=== FIRED! 
episode_shots_fired=%.0f ===\n", env->episode_shots_fired);
+
+        // Check if hit = kill = SUCCESS = terminal
+        if (check_hit(p, o, env->cos_gun_cone)) {
+            if (DEBUG >= 10) printf("*** KILL! ***\n");
+            env->kill = 1;
+            env->death_reason = DEATH_KILL;
+            env->rewards[0] = 1.0f;
+            set_opponent_reward(env, -1.0f); // Opponent loses (zero-sum)
+            env->episode_return += 1.0f;
+            env->terminals[0] = 1;
+            add_log(env);
+            c_reset(env);
+            return;
+        } else {
+            if (DEBUG >= 10) printf("MISS (dist=%.1f, in_cone=%d)\n", norm3(sub3(o->pos, p->pos)),
+                check_hit(p, o, env->cos_gun_cone));
+        }
+    }
+
+    // === Reward Shaping (key coefficients come from rcfg and are sweepable) ===
+    Vec3 rel_pos = sub3(o->pos, p->pos);
+    float dist = norm3(rel_pos);
+
+    // === df11 Simplified Rewards (terms 1-8 below) ===
+
+    // 1. Closing velocity: approaching = good
+    Vec3 rel_vel = sub3(p->vel, o->vel);
+    Vec3 rel_pos_norm = normalize3(rel_pos);
+    float closing_rate = dot3(rel_vel, rel_pos_norm);
+    float r_closing = clampf(closing_rate * env->rcfg.closing_scale, -0.05f, 0.05f);
+    reward += r_closing;
+
+    // 2. Aim quality: continuous feedback for gun alignment
+    Vec3 player_fwd = quat_rotate(p->ori, vec3(1, 0, 0));
+    float aim_dot = dot3(rel_pos_norm, player_fwd); // -1 to +1
+    float aim_angle_deg = acosf(clampf(aim_dot, -1.0f, 1.0f)) * RAD_TO_DEG;
+    float r_aim = 0.0f;
+    if (dist < GUN_RANGE * 2.0f) { // Only in engagement envelope (~1000m)
+        float aim_quality = (aim_dot + 1.0f) * 0.5f; // Remap [-1,1] to [0,1]
+        r_aim = aim_quality * env->rcfg.aim_scale;
+    }
+    reward += r_aim;
+
+    // 3. Negative G penalty: enforce "pull to turn" (realistic)
+    float g_threshold = 0.5f;
+    float g_deficit = fmaxf(0.0f, g_threshold - p->g_force);
+    float r_neg_g = -g_deficit * env->rcfg.neg_g;
+    reward += r_neg_g;
+
+    // 4. Stall penalty: speed safety
+    float speed = norm3(p->vel);
+    float r_stall = 0.0f;
+    if (speed < env->rcfg.speed_min) {
+        r_stall = -(env->rcfg.speed_min - speed) * PENALTY_STALL;
+    }
+    reward += r_stall;
+
+    // 5. Rudder penalty: prevent knife-edge climbing (small)
+    float r_rudder = -fabsf(env->actions[3]) * PENALTY_RUDDER;
+    reward += r_rudder;
+
+    // 5b. Control rate penalty: penalize rapid control changes
+    // Sweepable coefficient - find max value that still allows good training
+    float d_e = env->actions[1] - env->prev_elevator;
+    float d_a = env->actions[2] - env->prev_aileron;
+    float d_r = env->actions[3] - env->prev_rudder;
+    float delta_sq = d_e*d_e + d_a*d_a + d_r*d_r;
+    env->episode_control_rate += delta_sq; // Always accumulate for logging
+
+    float r_rate = 0.0f;
+    if (env->rcfg.control_rate_penalty > 0.0f) {
+        r_rate = -delta_sq * env->rcfg.control_rate_penalty;
+        reward += r_rate;
+    }
+
+    // Update prev actions for next step
+    env->prev_elevator = env->actions[1];
+    env->prev_aileron = env->actions[2];
+    env->prev_rudder = env->actions[3];
+
+    // 6. Low altitude descent penalty: discourage descending rolling scissors
+    // If below 500m AND descending, penalty each tick
+    // Reduced from -0.25f to -0.025f to prevent gradient explosion
+    // (Opponent gets same penalty applied in opponent_rewards section at end)
+    float r_low_descent = 0.0f;
+    if (p->pos.z < 500.0f && p->vel.z < 0.0f) {
+        r_low_descent = -0.025f;
+        reward += r_low_descent;
+    }
+
+    // 7. Tiny tick penalty: time preference for faster kills
+    float r_time = -0.00001f;
+    reward += r_time;
+
+    // 8.
Energy management reward: encourage maintaining/gaining energy + // Asymmetric: +0.001 for gaining energy, -0.0005 for losing (incentivize climbing) + float player_energy = calc_specific_energy(p); + float r_player_energy = (player_energy > p->prev_energy) ? 0.001f : -0.0005f; + reward += r_player_energy; + p->prev_energy = player_energy; + + // Opponent energy reward (applied to opponent_rewards at end) + float opp_energy = calc_specific_energy(o); + float r_opp_energy = (opp_energy > o->prev_energy) ? 0.001f : -0.0005f; + o->prev_energy = opp_energy; + +#if DEBUG >= 2 + // Track aiming diagnostics + { + float aim_angle_rad = acosf(clampf(aim_dot, -1.0f, 1.0f)); + if (aim_angle_rad < env->best_aim_angle) env->best_aim_angle = aim_angle_rad; + if (aim_dot > env->cos_gun_cone) env->ticks_in_cone++; + if (dist < env->closest_dist) env->closest_dist = dist; + } +#endif + + // Accumulate for episode summary + env->sum_r_closing += r_closing; + env->sum_r_aim += r_aim; + env->sum_r_neg_g += r_neg_g; + env->sum_r_speed += r_stall; + env->sum_r_rudder += r_rudder; + env->sum_r_rate += r_rate; + + if (DEBUG >= 4 && env->env_num == 0) printf("=== REWARD (df11) ===\n"); + if (DEBUG >= 4 && env->env_num == 0) printf("r_closing=%.4f (rate=%.1f m/s)\n", r_closing, closing_rate); + if (DEBUG >= 4 && env->env_num == 0) printf("r_aim=%.4f (aim_angle=%.1f deg, dist=%.1f)\n", r_aim, aim_angle_deg, dist); + if (DEBUG >= 4 && env->env_num == 0) printf("r_neg_g=%.5f (g=%.2f)\n", r_neg_g, p->g_force); + if (DEBUG >= 4 && env->env_num == 0) printf("r_stall=%.4f (speed=%.1f)\n", r_stall, speed); + if (DEBUG >= 4 && env->env_num == 0) printf("r_rudder=%.5f (rud=%.2f)\n", r_rudder, env->actions[3]); + if (DEBUG >= 4 && env->env_num == 0) printf("r_rate=%.5f (delta_sq=%.3f)\n", r_rate, delta_sq); + if (DEBUG >= 4 && env->env_num == 0) printf("reward_total=%.4f\n", reward); + + if (DEBUG >= 10) printf("=== COMBAT ===\n"); + if (DEBUG >= 10) printf("aim_angle=%.1f deg (cone=5 deg)\n", aim_angle_deg); + if (DEBUG >= 10) printf("dist_to_target=%.1f m (gun_range=500)\n", dist); + if (DEBUG >= 10) printf("in_cone=%d, in_range=%d\n", aim_dot > env->cos_gun_cone, dist < GUN_RANGE); + + // Global reward clamping to prevent gradient explosion (restored for df8) + reward = fmaxf(-1.0f, fminf(1.0f, reward)); + + env->rewards[0] = reward; + env->episode_return += reward; + + // Check opponent bounds FIRST (opponent crash/OOB = player wins) + // This handles the "both spiral to ground, one hits first" scenario + bool opp_oob = fabsf(o->pos.x) > WORLD_HALF_X || + fabsf(o->pos.y) > WORLD_HALF_Y || + o->pos.z < 0 || o->pos.z > WORLD_MAX_Z; + + if (opp_oob) { + if (DEBUG >= 1) { + printf("[TERMINAL] Opponent OOB/crashed: pos=(%.1f,%.1f,%.1f)\n", + o->pos.x, o->pos.y, o->pos.z); + } + env->death_reason = DEATH_OOB; // Opponent went OOB (not a player kill) + + // Descending rolling scissors fix: if player is below 200m when opponent crashes, + // both were in a death spiral - punish both equally + if (p->pos.z < 200.0f) { + // Both in death spiral - full punishment for both + env->rewards[0] = -1.0f; + set_opponent_reward(env, -1.0f); + } else { + // Player was at safe altitude - only partial penalty for not getting gun kill + env->rewards[0] = -0.5f; + set_opponent_reward(env, -1.0f); + } + env->terminals[0] = 1; + add_log(env); + c_reset(env); + return; + } + + // Check player bounds + bool oob = fabsf(p->pos.x) > WORLD_HALF_X || + fabsf(p->pos.y) > WORLD_HALF_Y || + p->pos.z < 0 || p->pos.z > WORLD_MAX_Z; + + // Check for supersonic 
(physics blowup) - 340 m/s = Mach 1 + float player_speed = norm3(p->vel); + float opp_speed = norm3(o->vel); + bool supersonic = player_speed > 340.0f || opp_speed > 340.0f; + if (DEBUG && supersonic) { + printf("=== SUPERSONIC BLOWUP ===\n"); + printf("player_speed=%.1f, opp_speed=%.1f\n", player_speed, opp_speed); + printf("player_vel=(%.1f, %.1f, %.1f)\n", p->vel.x, p->vel.y, p->vel.z); + printf("opp_vel=(%.1f, %.1f, %.1f)\n", o->vel.x, o->vel.y, o->vel.z); + printf("opp_ap_mode=%d\n", env->opponent_ap.mode); + } + + if (oob || env->tick >= env->max_steps || supersonic) { + if (DEBUG >= 10) printf("=== TERMINAL (FAILURE) ===\n"); + if (DEBUG >= 10) printf("oob=%d, supersonic=%d, tick=%d/%d\n", oob, supersonic, env->tick, env->max_steps); + // Track death reason (priority: supersonic > oob > timeout) + if (supersonic) { + env->death_reason = DEATH_SUPERSONIC; + // Physics blowup - both policies penalized equally + env->rewards[0] = -1.0f; + set_opponent_reward(env, -1.0f); + } else if (oob) { + env->death_reason = DEATH_OOB; + // Player crashed/OOB + env->rewards[0] = -1.0f; + + // Descending rolling scissors fix: if opponent is below 200m when player crashes, + // both were in a death spiral - punish both equally + if (o->pos.z < 200.0f) { + // Both in death spiral - full punishment for both + set_opponent_reward(env, -1.0f); + } else { + // Opponent was at safe altitude - only partial penalty for not getting gun kill + set_opponent_reward(env, -0.5f); + } + } else { + // Timeout - both failed to achieve kill + env->death_reason = DEATH_TIMEOUT; + env->rewards[0] = -0.5f; + set_opponent_reward(env, -0.5f); + } + env->terminals[0] = 1; + add_log(env); + c_reset(env); + return; + } + + compute_observations(env); +#if DEBUG >= 5 + print_observations(env); +#endif + + // Compute opponent observations and rewards (for dual self-play with Multiprocessing) + // Only if buffers are provided by Python (non-NULL) + if (env->opponent_observations != NULL) { + compute_opponent_observations(env, env->opponent_observations); + } + if (env->opponent_rewards != NULL) { + // Zero-sum game: opponent reward = negative of player reward + // PLUS independent penalties/rewards (not zero-sum) + float opp_reward = -env->rewards[0]; + // Low descent penalty + if (o->pos.z < 500.0f && o->vel.z < 0.0f) { + opp_reward += -0.025f; // Same penalty as player for low descent + } + // Energy management reward (calculated above in section 8) + opp_reward += r_opp_energy; + env->opponent_rewards[0] = opp_reward; + } +} + +void c_close(Dogfight *env); + +#include "dogfight_render.h" + +// Force exact game state for testing. Defaults shown in comments are applied in Python. 
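+// Illustrative call (a sketch, not an additional API): pin the player at (0, 0, 1000)
+// flying 150 m/s straight and level, and leave every opponent field at its -9999
+// sentinel so the auto defaults apply (opponent spawns 400 m ahead, matching the
+// player's velocity and orientation). The Python wrapper below passes the same
+// defaults by keyword.
+//
+//   force_state(env,
+//       /* player pos */ 0.0f, 0.0f, 1000.0f,
+//       /* player vel */ 150.0f, 0.0f, 0.0f,
+//       /* player ori */ 1.0f, 0.0f, 0.0f, 0.0f,
+//       /* throttle   */ 1.0f,
+//       /* opponent pos/vel/ori: auto */
+//       -9999.0f, -9999.0f, -9999.0f, -9999.0f, -9999.0f,
+//       -9999.0f, -9999.0f, -9999.0f, -9999.0f, -9999.0f,
+//       /* tick, cooldowns */ 0, -1, -1);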
+void force_state( + Dogfight *env, + float p_px, // = 0.0f, player pos X + float p_py, // = 0.0f, player pos Y + float p_pz, // = 1000.0f, player pos Z + float p_vx, // = 150.0f, player vel X (m/s) + float p_vy, // = 0.0f, player vel Y + float p_vz, // = 0.0f, player vel Z + float p_ow, // = 1.0f, player orientation quat W + float p_ox, // = 0.0f, player orientation quat X + float p_oy, // = 0.0f, player orientation quat Y + float p_oz, // = 0.0f, player orientation quat Z + float p_throttle, // = 1.0f, player throttle [0,1] + float o_px, // = -9999.0f (auto: 400m ahead), opponent pos X + float o_py, // = -9999.0f (auto), opponent pos Y + float o_pz, // = -9999.0f (auto), opponent pos Z + float o_vx, // = -9999.0f (auto: match player), opponent vel X + float o_vy, // = -9999.0f (auto), opponent vel Y + float o_vz, // = -9999.0f (auto), opponent vel Z + float o_ow, // = -9999.0f (auto: match player), opponent ori W + float o_ox, // = -9999.0f (auto), opponent ori X + float o_oy, // = -9999.0f (auto), opponent ori Y + float o_oz, // = -9999.0f (auto), opponent ori Z + int tick, // = 0, environment tick + int p_cooldown, // = -1 (no change), player fire cooldown ticks + int o_cooldown // = -1 (no change), opponent fire cooldown ticks +) { + env->player.pos = vec3(p_px, p_py, p_pz); + env->player.vel = vec3(p_vx, p_vy, p_vz); + env->player.prev_vel = vec3(p_vx, p_vy, p_vz); // Initialize to current (no accel) + env->player.omega = vec3(0, 0, 0); // No angular velocity + env->player.ori = quat(p_ow, p_ox, p_oy, p_oz); + quat_normalize(&env->player.ori); + env->player.throttle = p_throttle; + env->player.fire_cooldown = (p_cooldown >= 0) ? p_cooldown : 0; + env->player.yaw_from_rudder = 0.0f; + + // Opponent position: auto = 400m ahead of player + if (o_px < -9000.0f) { + Vec3 fwd = quat_rotate(env->player.ori, vec3(1, 0, 0)); + env->opponent.pos = add3(env->player.pos, mul3(fwd, 400.0f)); + } else { + env->opponent.pos = vec3(o_px, o_py, o_pz); + } + + // Opponent velocity: auto = match player + if (o_vx < -9000.0f) { + env->opponent.vel = env->player.vel; + } else { + env->opponent.vel = vec3(o_vx, o_vy, o_vz); + } + + // Opponent orientation: auto = match player + if (o_ow < -9000.0f) { + env->opponent.ori = env->player.ori; + } else { + env->opponent.ori = quat(o_ow, o_ox, o_oy, o_oz); + quat_normalize(&env->opponent.ori); + } + env->opponent.fire_cooldown = (o_cooldown >= 0) ? 
o_cooldown : 0; + env->opponent.yaw_from_rudder = 0.0f; + env->opponent.prev_vel = env->opponent.vel; // Initialize to current (no accel) + env->opponent.omega = vec3(0, 0, 0); // No angular velocity + + // Reset autopilot PID state to avoid derivative spikes + env->opponent_ap.prev_vz = env->opponent.vel.z; + env->opponent_ap.prev_bank_error = 0.0f; + + // Environment state + env->tick = tick; + env->episode_return = 0.0f; + + compute_observations(env); +#if DEBUG >= 5 + print_observations(env); +#endif +} diff --git a/pufferlib/ocean/dogfight/dogfight.py b/pufferlib/ocean/dogfight/dogfight.py new file mode 100644 index 000000000..9ee19f3f9 --- /dev/null +++ b/pufferlib/ocean/dogfight/dogfight.py @@ -0,0 +1,636 @@ +import time +import numpy as np +import gymnasium +import torch + +import pufferlib +from pufferlib.ocean.dogfight import binding +from pufferlib.models import Default as Policy + + +# Autopilot mode constants (must match autopilot.h enum) +class AutopilotMode: + STRAIGHT = 0 + LEVEL = 1 + TURN_LEFT = 2 + TURN_RIGHT = 3 + CLIMB = 4 + DESCEND = 5 + HARD_TURN_LEFT = 6 + HARD_TURN_RIGHT = 7 + WEAVE = 8 + EVASIVE = 9 + RANDOM = 10 + + +# Observation sizes by scheme (must match C OBS_SIZES in dogfight.h) +# All schemes include timer observation (tick/max_steps) at the end +OBS_SIZES = { + 0: 16, # MOMENTUM (baseline): body-frame vel + omega + AoA + energy + target + tactical + 1: 17, # MOMENTUM_BETA: + sideslip angle + 2: 17, # MOMENTUM_GFORCE: + G-force + 3: 20, # MOMENTUM_FULL: + sideslip + G + throttle + target rates + 4: 12, # MINIMAL: stripped down essentials + 5: 16, # CARTESIAN: cartesian target position + 6: 23, # DRONE_STYLE: + quaternion + up vector + 7: 17, # QBAR: + dynamic pressure + 8: 26, # KITCHEN_SINK: everything +} + + +class Dogfight(pufferlib.PufferEnv): + def __init__( + self, + num_envs=16, + render_mode=None, + render_fps=None, + report_interval=1, + buf=None, + seed=42, + max_steps=3000, + obs_scheme=0, + + curriculum_enabled=0, + curriculum_randomize=0, + eval_spawn_mode=0, # 0=random, 1=opponent_advantage (opponent behind player) + fixed_stage=-1, + max_stage=19, # Cap curriculum at this stage (19 = EVASIVE, no AutoAce/self-play) + eval_interval=2_500_000, # Steps between curriculum evaluations (2.5M = ~1s at 2.5M SPS) + warmup_steps=3_000_000, # Steps before curriculum starts evaluating (3M = ~1.2s at 2.5M SPS) + min_eval_episodes=50, # Minimum episodes in window before evaluating mastery + # Finalization: snap to mastered stage at end of training + total_timesteps=200_000_000, # Total training steps (for finalization calculation) + num_workers=8, # Number of parallel workers (for finalization calculation) + finalize_margin=50_000_000, # Start finalization this many steps before end + # df11: Simplified rewards (6 terms) + reward_aim_scale=0.05, # Continuous aiming reward + reward_closing_scale=0.003, # Per m/s closing + penalty_neg_g=0.02, # Enforce "pull to turn" + speed_min=50.0, # Stall threshold + control_rate_penalty=0.0, # Penalty for action rate changes (sweep to find optimal) + # Self-play: load frozen checkpoint as opponent + opponent_checkpoint=None, # Path to .pt checkpoint file + opponent_device='cpu', # Device for opponent policy inference + # Self-play: policy pool for skill-based opponent selection + policy_pool=None, # PolicyPool instance (optional) + opponent_selection='skill_match', # Selection strategy: skill_match, prioritized, random, latest + opponent_swap_interval=500_000, # Steps between opponent swaps + ): + # Observation size 
depends on scheme + obs_size = OBS_SIZES.get(obs_scheme, 19) + self.obs_scheme = obs_scheme + self.single_observation_space = gymnasium.spaces.Box( + low=-1, + high=1, + shape=(obs_size,), + dtype=np.float32, + ) + + # Action: Box(5) continuous [-1, 1] + # [0] throttle, [1] elevator, [2] ailerons, [3] rudder, [4] trigger + self.single_action_space = gymnasium.spaces.Box( + low=-1, high=1, shape=(5,), dtype=np.float32 + ) + + self.num_agents = num_envs + self.agents_per_batch = num_envs # For pufferl LSTM compatibility + self.render_mode = render_mode + self.render_fps = render_fps + self.report_interval = report_interval + self.tick = 0 + + # Global curriculum state (step-based window evaluation) + self._current_stage = 0 + self._target_stage = 0.9 # Start at 0.9 (90% stage 1, 10% stage 0) + self._warmup_steps = warmup_steps # Steps before curriculum starts evaluating + self._eval_interval = eval_interval # Steps between curriculum evaluations + self._last_eval_step = warmup_steps # First eval at warmup + eval_interval + self.curriculum_enabled = curriculum_enabled + self.fixed_stage = fixed_stage + self.max_stage = max_stage + self.min_eval_episodes = min_eval_episodes + + # Mastered stage tracking (pure mastery-gated progression) + self._mastered_stage = -1 # Highest stage with perf >= 0.90 AND >= 250 episodes (-1 = none) + + # Finalization: snap to mastered stage near end of training + # total_timesteps is global total, finalize_margin is global margin + # Per-worker finalize step = (global_total - global_margin) / num_workers + self._finalize_at_steps = (total_timesteps - finalize_margin) // num_workers + #print(f'[CURRICULUM] Initialized: finalize_at={self._finalize_at_steps}, mastered_stage={self._mastered_stage}') + + # Base stage tracking: performance at int(curriculum_target) only + self._base_stage_kills = 0.0 + self._base_stage_eps = 0.0 + + # If fixed_stage is set, lock to that stage + if fixed_stage >= 0: + self._target_stage = float(fixed_stage) + self._current_stage = fixed_stage + + super().__init__(buf) + self.actions = self.actions.astype(np.float32) # REQUIRED for continuous + + self._env_handles = [] + for env_num in range(num_envs): + handle = binding.env_init( + self.observations[env_num:(env_num+1)], + self.actions[env_num:(env_num+1)], + self.rewards[env_num:(env_num+1)], + self.terminals[env_num:(env_num+1)], + self.truncations[env_num:(env_num+1)], + env_num, + env_num=env_num, + report_interval=self.report_interval, + max_steps=max_steps, + obs_scheme=obs_scheme, + + curriculum_enabled=curriculum_enabled, + curriculum_randomize=curriculum_randomize, + eval_spawn_mode=eval_spawn_mode, + + reward_aim_scale=reward_aim_scale, + reward_closing_scale=reward_closing_scale, + penalty_neg_g=penalty_neg_g, + speed_min=speed_min, + control_rate_penalty=control_rate_penalty, + ) + self._env_handles.append(handle) + + self.c_envs = binding.vectorize(*self._env_handles) + + # Set opponent observation/reward/action buffers if provided (for dual self-play with Multiprocessing) + # These buffers come from shared memory in Multiprocessing backend + self._opponent_observations = None + self._opponent_rewards = None + self._opponent_actions = None + if buf is not None and 'opponent_observations' in buf: + self._opponent_observations = buf['opponent_observations'] + self._opponent_rewards = buf['opponent_rewards'] + # Flatten to match C expectations: shape (num_envs * obs_size,) and (num_envs,) + opp_obs_flat = self._opponent_observations.reshape(-1) + opp_rew_flat = 
self._opponent_rewards.reshape(-1) + binding.vec_set_opponent_buffers(self.c_envs, opp_obs_flat, opp_rew_flat) + # NOTE: Don't enable opponent override here - let DualPerspectiveTrainer + # control when to activate self-play mode. Otherwise sp_* stats get + # logged during curriculum training which is confusing. + if buf is not None and 'opponent_actions' in buf: + self._opponent_actions = buf['opponent_actions'] + # Shared flag indicating self-play mode is active (set by main process) + self._selfplay_active = None + self._opponent_override_enabled = False # Track if we've enabled C-side override + if buf is not None and 'selfplay_active' in buf: + self._selfplay_active = buf['selfplay_active'] + + # Self-play: opponent policy (loaded after c_envs created) + self.opponent_policy = None + self.opponent_device = opponent_device + self.opponent_lstm_state = None # For recurrent policies (future) + self._current_opponent_path = None # Track current opponent for swap detection + + # Policy pool: skill-based opponent selection + self.policy_pool = policy_pool + self.opponent_selection = opponent_selection + self.opponent_swap_interval = opponent_swap_interval + self._last_opponent_swap = 0 # Steps since last swap + self._save_to_pool_callback = None # Set by training script to save checkpoints + + if opponent_checkpoint: + self._load_opponent_policy(opponent_checkpoint) + self._current_opponent_path = opponent_checkpoint + + # Set fixed stage on C side if specified + if fixed_stage >= 0: + binding.vec_set_curriculum_target(self.c_envs, float(fixed_stage)) + + def _load_opponent_policy(self, path): + """Load a frozen checkpoint as the opponent policy.""" + # Create policy with same architecture as training + self.opponent_policy = Policy(self, hidden_size=128) + self.opponent_policy = self.opponent_policy.to(self.opponent_device) + + # Load checkpoint weights + state_dict = torch.load(path, map_location=self.opponent_device, weights_only=True) + + # Handle different checkpoint formats: + # 1. Strip 'module.' prefix from distributed training + # 2. Strip 'policy.' prefix from LSTMWrapper + # 3. 
Skip LSTM-specific keys (lstm.*, cell.*) + cleaned_state_dict = {} + for k, v in state_dict.items(): + # Skip LSTM keys - we only load the base policy + if k.startswith('lstm.') or k.startswith('cell.'): + continue + # Strip prefixes + new_k = k.replace('module.', '').replace('policy.', '') + cleaned_state_dict[new_k] = v + + self.opponent_policy.load_state_dict(cleaned_state_dict) + + # Freeze for inference only + self.opponent_policy.eval() + for p in self.opponent_policy.parameters(): + p.requires_grad = False + + # Enable C-side override mode (use external actions instead of autopilot) + binding.vec_enable_opponent_override(self.c_envs, 1) + + def reset(self, seed=None): + self.tick = 0 + binding.vec_reset(self.c_envs, seed if seed else 0) + # Reset opponent LSTM state on episode reset (for recurrent policies) + self.opponent_lstm_state = None + return self.observations, [] + + def step(self, actions): + self.actions[:] = actions + + # Check if main process has signaled self-play mode via shared memory flag + # This enables opponent override in workers (Multiprocessing) on first detection + if self._selfplay_active is not None and self._selfplay_active[0] == 1: + if not self._opponent_override_enabled: + binding.vec_enable_opponent_override(self.c_envs, 1) + self._opponent_override_enabled = True + + # Self-play: read opponent actions from shared memory buffer (dual self-play with Multiprocessing) + # or compute from local frozen policy (standard self-play with Serial) + if self._opponent_actions is not None and self._opponent_override_enabled: + # Multiprocessing dual self-play: read opponent actions from shared memory + # Main process writes actions to buf, workers read them here + opp_actions_flat = self._opponent_actions.reshape(-1, 5) # Shape: (num_envs, 5) + binding.vec_set_opponent_actions(self.c_envs, opp_actions_flat) + elif self.opponent_policy is not None: + # Serial self-play: compute opponent actions from frozen policy in-process + opp_obs = binding.vec_get_opponent_observations(self.c_envs) + opp_obs_t = torch.as_tensor(opp_obs, device=self.opponent_device) + + with torch.no_grad(): + # Policy returns Normal distribution for continuous actions + logits, _ = self.opponent_policy.forward_eval( + opp_obs_t, state=self.opponent_lstm_state + ) + # Sample actions from the distribution + opp_actions = logits.sample() + opp_actions = opp_actions.cpu().numpy().astype(np.float32) + + binding.vec_set_opponent_actions(self.c_envs, opp_actions) + + self.tick += 1 + binding.vec_step(self.c_envs) + + # Auto-render if render_mode is 'human' (Gymnasium convention) + if self.render_mode == 'human': + self.render() + if self.render_fps: + time.sleep(1.0 / self.render_fps) + + info = [] + if self.tick % self.report_interval == 0: + log_data = binding.vec_log(self.c_envs) + if log_data: + info.append(log_data) + + # Curriculum v4: Pure Mastery-Gated Progression + # Target is ALWAYS mastered_stage + 0.9 (except during finalization) + # No advancement logic - target changes ONLY when mastery is achieved + # Skip progression if fixed_stage is set (testing mode) + if self.curriculum_enabled and self.fixed_stage < 0: + n = log_data.get('n', 0) # episodes completed this tick + total_steps = self.tick * self.num_agents + + # Only accumulate AFTER warmup (avoid early kill bias) + if total_steps >= self._warmup_steps: + # Track base stage performance (for mastery gating) + if n > 0: + base_kills = log_data.get('base_stage_kills', 0) + base_eps = log_data.get('base_stage_eps', 0) + 
self._base_stage_kills += base_kills + self._base_stage_eps += base_eps + + # Evaluate at intervals + if total_steps - self._last_eval_step >= self._eval_interval: + self._last_eval_step = total_steps + + # Compute base stage performance (current stage mastery) + base_stage_perf = (self._base_stage_kills / self._base_stage_eps) if self._base_stage_eps > 0 else 0.0 + + # Check mastery at MAJORITY stage (round, not floor) + # At target 0.9, majority is stage 1 (90% of episodes) + mastery_stage = round(self._target_stage) + if base_stage_perf >= 0.90 and self._base_stage_eps >= self.min_eval_episodes: + if mastery_stage > self._mastered_stage: + #print(f'[CURRICULUM] MASTERED: stage {mastery_stage} (perf={base_stage_perf:.3f}, eps={self._base_stage_eps:.0f})') + self._mastered_stage = mastery_stage + # Reset base stage tracking for new level + self._base_stage_kills = 0.0 + self._base_stage_eps = 0.0 + + # Save milestone checkpoint to pool if callback is set + if self._save_to_pool_callback is not None: + self._save_to_pool_callback(mastery_stage, base_stage_perf) + + # Target is ALWAYS mastered + 0.9 (except finalization) + # Cap at max_stage to avoid AutoAce/self-play if desired + in_finalization = total_steps >= self._finalize_at_steps + if in_finalization: + new_target = float(self._mastered_stage) + 0.01 + else: + new_target = float(self._mastered_stage) + 0.9 + new_target = min(new_target, float(self.max_stage)) + + if abs(self._target_stage - new_target) > 0.01: + #print(f'[CURRICULUM] TARGET: {self._target_stage:.2f} → {new_target:.2f} (mastered={self._mastered_stage})') + self._target_stage = new_target + self._current_stage = int(self._target_stage) + binding.vec_set_curriculum_target(self.c_envs, self._target_stage) + + # Simple diagnostic print + #print(f'[CURRICULUM] step={total_steps} stage={self._target_stage:.2f} ' + # f'base={base_stage_perf:.3f}({self._base_stage_eps:.0f}eps) ' + # f'mastered={self._mastered_stage}') + + # Base stage: decay by 10% each interval (so recent perf matters more) + self._base_stage_kills *= 0.9 + self._base_stage_eps *= 0.9 + + # Policy pool: periodic opponent swapping + if self.policy_pool is not None and len(self.policy_pool) > 0: + total_steps = self.tick * self.num_agents + if total_steps - self._last_opponent_swap >= self.opponent_swap_interval: + self._last_opponent_swap = total_steps + new_opponent = self.policy_pool.select( + self._target_stage, + mode=self.opponent_selection + ) + if new_opponent and new_opponent != self._current_opponent_path: + self._load_opponent_policy(new_opponent) + self._current_opponent_path = new_opponent + + return (self.observations, self.rewards, self.terminals, self.truncations, info) + + def render(self): + binding.vec_render(self.c_envs, 0) + + def close(self): + binding.vec_close(self.c_envs) + + def force_state( + self, + env_idx=0, + player_pos=None, # (x, y, z) tuple, default (0, 0, 1000) + player_vel=None, # (vx, vy, vz) tuple, default (150, 0, 0) + player_ori=None, # (w, x, y, z) quaternion, default (1, 0, 0, 0) = wings level + player_throttle=1.0, # [0, 1], default full throttle + opponent_pos=None, # (x, y, z) or None for auto (400m ahead) + opponent_vel=None, # (vx, vy, vz) or None for auto (match player) + opponent_ori=None, # (w, x, y, z) or None for auto (match player) + tick=0, + player_cooldown=None, # Fire cooldown ticks for player (None = 0) + opponent_cooldown=None, # Fire cooldown ticks for opponent (None = 0) + ): + """ + Force exact game state for testing/debugging. 
+ + Usage: + env.force_state(player_pos=(-1500, 0, 1000), player_vel=(150, 0, 0)) + env.force_state(player_vel=(80, 0, 0)) # Just change velocity + env.force_state(player_cooldown=100, opponent_cooldown=100) # Disable guns for 2 sec + """ + # Build kwargs for C binding + kwargs = {'tick': tick, 'p_throttle': player_throttle} + + # Player position + if player_pos is not None: + kwargs['p_px'], kwargs['p_py'], kwargs['p_pz'] = player_pos + + # Player velocity + if player_vel is not None: + kwargs['p_vx'], kwargs['p_vy'], kwargs['p_vz'] = player_vel + + # Player orientation + if player_ori is not None: + kwargs['p_ow'], kwargs['p_ox'], kwargs['p_oy'], kwargs['p_oz'] = player_ori + + # Opponent position (None = auto) + if opponent_pos is not None: + kwargs['o_px'], kwargs['o_py'], kwargs['o_pz'] = opponent_pos + + # Opponent velocity (None = auto) + if opponent_vel is not None: + kwargs['o_vx'], kwargs['o_vy'], kwargs['o_vz'] = opponent_vel + + # Opponent orientation (None = auto) + if opponent_ori is not None: + kwargs['o_ow'], kwargs['o_ox'], kwargs['o_oy'], kwargs['o_oz'] = opponent_ori + + # Fire cooldowns (None = 0, i.e., guns ready) + if player_cooldown is not None: + kwargs['p_cooldown'] = player_cooldown + if opponent_cooldown is not None: + kwargs['o_cooldown'] = opponent_cooldown + + # Call C binding with the specific env handle + binding.env_force_state(self._env_handles[env_idx], **kwargs) + + def get_state(self, env_idx=0): + """ + Get raw player state (independent of observation scheme). + + Returns dict with keys: + px, py, pz: Position + vx, vy, vz: Velocity + ow, ox, oy, oz: Orientation quaternion + up_x, up_y, up_z: Up vector (derived from quaternion) + fwd_x, fwd_y, fwd_z: Forward vector (derived from quaternion) + throttle: Current throttle + + Useful for physics tests that need exact state regardless of obs_scheme. + """ + return binding.env_get_state(self._env_handles[env_idx]) + + def set_autopilot( + self, + env_idx=0, + mode=AutopilotMode.STRAIGHT, + throttle=1.0, + bank_deg=30.0, + climb_rate=5.0, + ): + """ + Set autopilot mode for opponent aircraft. + + Args: + env_idx: Environment index, or None for all environments + mode: AutopilotMode constant (STRAIGHT, LEVEL, TURN_LEFT, etc.) + throttle: Target throttle [0, 1] + bank_deg: Bank angle for turn modes (degrees) + climb_rate: Target vertical velocity for climb/descend (m/s) + + Usage: + env.set_autopilot(mode=AutopilotMode.LEVEL) # Level flight, env 0 + env.set_autopilot(mode=AutopilotMode.TURN_RIGHT, bank_deg=45) # 45° right turn + env.set_autopilot(mode=AutopilotMode.RANDOM) # Randomize each episode + env.set_autopilot(env_idx=None, mode=AutopilotMode.RANDOM) # All envs + """ + if env_idx is None: + # Vectorized: set all envs at once + binding.vec_set_autopilot( + self.c_envs, + mode=mode, + throttle=throttle, + bank_deg=bank_deg, + climb_rate=climb_rate, + ) + else: + # Single env + binding.env_set_autopilot( + self._env_handles[env_idx], + mode=mode, + throttle=throttle, + bank_deg=bank_deg, + climb_rate=climb_rate, + ) + + def set_mode_weights(self, level=0.2, turn_left=0.2, turn_right=0.2, + climb=0.2, descend=0.2): + """ + Set probability weights for AP_RANDOM mode selection. + + Weights should sum to 1.0. Used for curriculum learning to bias + toward easier modes (e.g., LEVEL, STRAIGHT turns) early in training. 
+ + Args: + level: Weight for AP_LEVEL (maintain altitude) + turn_left: Weight for AP_TURN_LEFT + turn_right: Weight for AP_TURN_RIGHT + climb: Weight for AP_CLIMB + descend: Weight for AP_DESCEND + """ + binding.vec_set_mode_weights( + self.c_envs, + level=level, turn_left=turn_left, turn_right=turn_right, + climb=climb, descend=descend, + ) + + def get_autopilot_mode(self, env_idx=0): + """Get current autopilot mode for an environment (for testing/debugging).""" + return binding.env_get_autopilot_mode(self._env_handles[env_idx]) + + def set_obs_highlight(self, indices, env_idx=0): + """ + Set which observations to highlight with red arrows in the visual display. + + Args: + indices: List of observation indices to highlight (e.g., [4, 5, 6] for pitch, roll, yaw) + env_idx: Environment index + + Usage: + env.set_obs_highlight([4, 5, 6]) # Highlight pitch, roll, yaw in scheme 0 + env.set_obs_highlight([]) # Clear highlights + """ + binding.env_set_obs_highlight(self._env_handles[env_idx], list(indices)) + + def set_curriculum_stage(self, stage: int): + """ + Set curriculum stage for all environments (global curriculum). + + Called by training loop based on aggregate kill_rate from log data. + All envs share the same stage for coherent metrics. + + Args: + stage: Curriculum stage (0=TAIL_CHASE, 1=HEAD_ON, 2=VERTICAL, + 3=MANEUVERING, 4=OFFSET_MANEUVERING, 5=ANGLED_MANEUVERING, + 6=FULL_RANDOM, 7=HARD_MANEUVERING, 8=CROSSING, 9=EVASIVE) + """ + binding.vec_set_curriculum_stage(self.c_envs, stage) + self._current_stage = stage + + def get_curriculum_stage(self) -> int: + """Get current global curriculum stage.""" + return self._current_stage + + def set_curriculum_target(self, target: float): + """ + Set curriculum target (0.0-15.0) for probabilistic stage assignment. + + At each episode reset, stage is assigned probabilistically: + - target=1.3 → 70% stage 1, 30% stage 2 + + Args: + target: Float target from 0.0 to 17.0 (18 curriculum stages) + """ + self._target_stage = max(0.0, min(target, 17.0)) + binding.vec_set_curriculum_target(self.c_envs, self._target_stage) + self._current_stage = int(self._target_stage) + + def get_curriculum_target(self) -> float: + """Get current curriculum target (float 0.0-9.0).""" + return self._target_stage + + def get_autoace_state(self, env_idx=0): + """ + Get AutoAce opponent state and tactical info for behavioral tests. + + Returns dict with keys: + # Opponent plane state + opp_px, opp_py, opp_pz: Position + opp_vx, opp_vy, opp_vz: Velocity + opp_fwd_x, opp_fwd_y, opp_fwd_z: Forward vector + opp_ow, opp_ox, opp_oy, opp_oz: Orientation quaternion + + # Last AutoAce actions + opp_throttle, opp_elevator, opp_aileron, opp_rudder, opp_trigger + + # Tactical state + engagement: 0=OFFENSIVE, 1=NEUTRAL, 2=DEFENSIVE, 3=WEAPONS, 4=EXTEND + mode: Autopilot mode enum value + aspect_angle: Degrees (0=behind target, 180=head-on) + antenna_train: Target bearing from nose (0=dead ahead) + range: Distance in meters + closure_rate: Positive = closing (m/s) + in_gun_envelope: Boolean + """ + return binding.env_get_autoace_state(self._env_handles[env_idx]) + + def set_camera_follow(self, follow_opponent=False, env_idx=0): + """ + Set which plane the camera follows during rendering. 
+ + Args: + follow_opponent: True to follow opponent (AutoAce), False to follow player + env_idx: Environment index + """ + binding.env_set_camera_follow(self._env_handles[env_idx], 1 if follow_opponent else 0) + + def set_eval_spawn_mode(self, mode: int): + """ + Set eval spawn mode for all environments. + + Args: + mode: 0 = random (default), 1 = opponent_advantage (opponent behind player) + + Mode 1 (opponent_advantage) places opponent 400m behind player at 15° off tail, + giving opponent an easy kill opportunity. Useful for testing if opponent can kill. + """ + binding.vec_set_eval_spawn_mode(self.c_envs, mode) + + +def test_performance(timeout=10, atn_cache=1024): + env = Dogfight(num_envs=1000) + env.reset() + tick = 0 + + actions = [env.action_space.sample() for _ in range(atn_cache)] + + import time + start = time.time() + while time.time() - start < timeout: + atn = actions[tick % atn_cache] + env.step(atn) + tick += 1 + + print(f"SPS: {env.num_agents * tick / (time.time() - start)}") + + +if __name__ == "__main__": + test_performance() diff --git a/pufferlib/ocean/dogfight/dogfight_observations.h b/pufferlib/ocean/dogfight/dogfight_observations.h new file mode 100644 index 000000000..b4f226717 --- /dev/null +++ b/pufferlib/ocean/dogfight/dogfight_observations.h @@ -0,0 +1,1117 @@ +// dogfight_observations.h - Observation computation for dogfight environment +// Extracted from dogfight.h to reduce file size +// +// Observation Schemes (for realistic physics - physics mode 1): +// All schemes include timer observation at the end: tick/(max_steps+1) [0,~1) +// Scheme 0: OBS_MOMENTUM - Baseline (16 obs) +// Scheme 1: OBS_MOMENTUM_BETA - + sideslip angle (17 obs) +// Scheme 2: OBS_MOMENTUM_GFORCE - + G-force (17 obs) +// Scheme 3: OBS_MOMENTUM_FULL - + sideslip + G + throttle + tgt rates (20 obs) +// Scheme 4: OBS_MINIMAL - stripped down essentials (12 obs) +// Scheme 5: OBS_CARTESIAN - cartesian target position (16 obs) +// Scheme 6: OBS_DRONE_STYLE - + quaternion + up vector (23 obs) +// Scheme 7: OBS_QBAR - + dynamic pressure (17 obs) +// Scheme 8: OBS_KITCHEN_SINK - everything (26 obs) + +#ifndef DOGFIGHT_OBSERVATIONS_H +#define DOGFIGHT_OBSERVATIONS_H + +// Requires: flightlib.h (Vec3, Quat, math), Dogfight struct defined before include + +// Normalization constants +#define MAX_OMEGA 3.0f // ~172 deg/s, reasonable for aggressive maneuvering +#define INV_MAX_OMEGA (1.0f / MAX_OMEGA) +#define MAX_AOA 0.5f // ~28 deg, beyond this is deep stall +#define INV_MAX_AOA (1.0f / MAX_AOA) +#define MAX_SIDESLIP 0.5f // ~28 degrees +#define INV_MAX_SIDESLIP (1.0f / MAX_SIDESLIP) +#define MAX_QBAR 38281.0f // 0.5 * 1.225 * 250^2 at sea level, max speed +#define INV_MAX_QBAR (1.0f / MAX_QBAR) +#define MAX_RANGE 2000.0f // Normalization range for target distance +#define INV_MAX_RANGE (1.0f / MAX_RANGE) + +// ============================================================================ +// Generalized observation computation for self-play +// ============================================================================ +// Computes observations from 'self' plane's perspective looking at 'other' plane. +// Used for both player and opponent observations. +// +// Note: Timer observation uses env->tick/max_steps which is shared between both +// planes. This is correct for self-play where both see the same episode timer. 
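+//
+// Sketch of the intended call pattern. The player-side call is the one used by
+// compute_obs_momentum() below; the opponent-side call is an assumption about how
+// compute_opponent_observations() would reuse this helper, and `opp_obs_buf` is a
+// hypothetical buffer name (in practice the Python-provided opponent buffer):
+//
+//   compute_obs_momentum_for_plane(env, &env->player, &env->opponent, env->observations);
+//   compute_obs_momentum_for_plane(env, &env->opponent, &env->player, opp_obs_buf);
+//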
+ +void compute_obs_momentum_for_plane(Dogfight *env, Plane *self, Plane *other, float *obs_buffer) { + Quat q_inv = {self->ori.w, -self->ori.x, -self->ori.y, -self->ori.z}; + + // === OWN FLIGHT STATE === + // Body-frame velocity + Vec3 vel_body = quat_rotate(q_inv, self->vel); + float speed = norm3(self->vel); + + // Angle of attack + Vec3 forward = quat_rotate(self->ori, vec3(1, 0, 0)); + Vec3 up = quat_rotate(self->ori, vec3(0, 0, 1)); + float aoa = 0.0f; + if (speed > 1.0f) { + Vec3 vel_norm = normalize3(self->vel); + float cos_alpha = clampf(dot3(vel_norm, forward), -1.0f, 1.0f); + float alpha = acosf(cos_alpha); + float sign = (dot3(self->vel, up) < 0) ? 1.0f : -1.0f; + aoa = alpha * sign; + } + + // Energy state + float potential = self->pos.z * INV_WORLD_MAX_Z; + float kinetic = (speed * speed) / (MAX_SPEED * MAX_SPEED); + float own_energy = (potential + kinetic) * 0.5f; + + // === TARGET STATE === + Vec3 rel_pos = sub3(other->pos, self->pos); + Vec3 rel_pos_body = quat_rotate(q_inv, rel_pos); + float dist = norm3(rel_pos); + + float target_az = atan2f(rel_pos_body.y, rel_pos_body.x); + float r_horiz = sqrtf(rel_pos_body.x * rel_pos_body.x + rel_pos_body.y * rel_pos_body.y); + float target_el = atan2f(rel_pos_body.z, fmaxf(r_horiz, 1e-6f)); + + // Closure rate (positive = closing) + Vec3 rel_vel = sub3(self->vel, other->vel); + float closure = dot3(rel_vel, normalize3(rel_pos)); + + // === TACTICAL === + Vec3 other_fwd = quat_rotate(other->ori, vec3(1, 0, 0)); + Vec3 to_self = normalize3(sub3(self->pos, other->pos)); + float target_aspect = dot3(other_fwd, to_self); + + float other_speed = norm3(other->vel); + float other_potential = other->pos.z * INV_WORLD_MAX_Z; + float other_kinetic = (other_speed * other_speed) / (MAX_SPEED * MAX_SPEED); + float other_energy = (other_potential + other_kinetic) * 0.5f; + float energy_advantage = clampf(own_energy - other_energy, -1.0f, 1.0f); + + int i = 0; + // Own flight state (9 obs) + obs_buffer[i++] = clampf(vel_body.x * INV_MAX_SPEED, 0.0f, 1.0f); // Forward speed [0,1] + obs_buffer[i++] = clampf(vel_body.y * INV_MAX_SPEED, -1.0f, 1.0f); // Sideslip [-1,1] + obs_buffer[i++] = clampf(vel_body.z * INV_MAX_SPEED, -1.0f, 1.0f); // Climb rate [-1,1] + obs_buffer[i++] = clampf(self->omega.x * INV_MAX_OMEGA, -1.0f, 1.0f); // Roll rate [-1,1] + obs_buffer[i++] = clampf(self->omega.y * INV_MAX_OMEGA, -1.0f, 1.0f); // Pitch rate [-1,1] + obs_buffer[i++] = clampf(self->omega.z * INV_MAX_OMEGA, -1.0f, 1.0f); // Yaw rate [-1,1] + obs_buffer[i++] = clampf(aoa * INV_MAX_AOA, -1.0f, 1.0f); // AoA [-1,1] + obs_buffer[i++] = potential; // Altitude [0,1] + obs_buffer[i++] = own_energy; // Own energy [0,1] + + // Target state - spherical (4 obs) + obs_buffer[i++] = target_az * INV_PI; // Azimuth [-1,1] + obs_buffer[i++] = target_el * INV_HALF_PI; // Elevation [-1,1] + obs_buffer[i++] = clampf(dist * INV_MAX_RANGE, 0.0f, 1.0f); // Range [0,1] + obs_buffer[i++] = clampf(closure * INV_MAX_SPEED, -1.0f, 1.0f); // Closure [-1,1] + + // Tactical (2 obs) + obs_buffer[i++] = energy_advantage; // Energy advantage [-1,1] + obs_buffer[i++] = target_aspect; // Aspect [-1,1] + + // Timer (1 obs) - how much time left before episode ends + obs_buffer[i++] = (float)env->tick / (float)(env->max_steps + 1); // Timer [0,~1) + // OBS_SIZE = 16 +} + +// ============================================================================ +// Scheme 0: OBS_MOMENTUM - Baseline (15 obs) +// ============================================================================ +// Body-frame velocity 
+ omega + AoA + energy + target spherical + tactical +// [0-2] Body-frame velocity (forward speed, sideslip, climb rate) +// [3-5] Angular velocity (roll rate, pitch rate, yaw rate) +// [6] Angle of attack +// [7-8] Altitude + own energy +// [9-12] Target spherical (azimuth, elevation, range, closure) +// [13-14] Tactical (energy advantage, target aspect) +void compute_obs_momentum(Dogfight *env) { + compute_obs_momentum_for_plane(env, &env->player, &env->opponent, env->observations); +} + +// ============================================================================ +// Scheme 1: OBS_MOMENTUM_BETA - + sideslip angle (16 obs) +// ============================================================================ +// Hypothesis: Explicit sideslip angle helps coordinated flight +void compute_obs_momentum_beta(Dogfight *env) { + Plane *p = &env->player; + Plane *o = &env->opponent; + + Quat q_inv = {p->ori.w, -p->ori.x, -p->ori.y, -p->ori.z}; + + // Body-frame velocity + Vec3 vel_body = quat_rotate(q_inv, p->vel); + float speed = norm3(p->vel); + + // Angle of attack + Vec3 forward = quat_rotate(p->ori, vec3(1, 0, 0)); + Vec3 up = quat_rotate(p->ori, vec3(0, 0, 1)); + float aoa = 0.0f; + if (speed > 1.0f) { + Vec3 vel_norm = normalize3(p->vel); + float cos_alpha = clampf(dot3(vel_norm, forward), -1.0f, 1.0f); + float alpha = acosf(cos_alpha); + float sign = (dot3(p->vel, up) < 0) ? 1.0f : -1.0f; + aoa = alpha * sign; + } + + // Sideslip angle (beta) + // beta = asin(vy / speed), positive = nose left of velocity + float beta = 0.0f; + if (speed > 1.0f) { + beta = asinf(clampf(vel_body.y / speed, -1.0f, 1.0f)); + } + + // Energy state + float potential = p->pos.z * INV_WORLD_MAX_Z; + float kinetic = (speed * speed) / (MAX_SPEED * MAX_SPEED); + float own_energy = (potential + kinetic) * 0.5f; + + // Target state + Vec3 rel_pos = sub3(o->pos, p->pos); + Vec3 rel_pos_body = quat_rotate(q_inv, rel_pos); + float dist = norm3(rel_pos); + + float target_az = atan2f(rel_pos_body.y, rel_pos_body.x); + float r_horiz = sqrtf(rel_pos_body.x * rel_pos_body.x + rel_pos_body.y * rel_pos_body.y); + float target_el = atan2f(rel_pos_body.z, fmaxf(r_horiz, 1e-6f)); + + Vec3 rel_vel = sub3(p->vel, o->vel); + float closure = dot3(rel_vel, normalize3(rel_pos)); + + // Tactical + Vec3 opp_fwd = quat_rotate(o->ori, vec3(1, 0, 0)); + Vec3 to_player = normalize3(sub3(p->pos, o->pos)); + float target_aspect = dot3(opp_fwd, to_player); + + float opp_speed = norm3(o->vel); + float opp_potential = o->pos.z * INV_WORLD_MAX_Z; + float opp_kinetic = (opp_speed * opp_speed) / (MAX_SPEED * MAX_SPEED); + float opp_energy = (opp_potential + opp_kinetic) * 0.5f; + float energy_advantage = clampf(own_energy - opp_energy, -1.0f, 1.0f); + + int i = 0; + // Own flight state (9 obs - same as MOMENTUM) + env->observations[i++] = clampf(vel_body.x * INV_MAX_SPEED, 0.0f, 1.0f); + env->observations[i++] = clampf(vel_body.y * INV_MAX_SPEED, -1.0f, 1.0f); + env->observations[i++] = clampf(vel_body.z * INV_MAX_SPEED, -1.0f, 1.0f); + env->observations[i++] = clampf(p->omega.x * INV_MAX_OMEGA, -1.0f, 1.0f); + env->observations[i++] = clampf(p->omega.y * INV_MAX_OMEGA, -1.0f, 1.0f); + env->observations[i++] = clampf(p->omega.z * INV_MAX_OMEGA, -1.0f, 1.0f); + env->observations[i++] = clampf(aoa * INV_MAX_AOA, -1.0f, 1.0f); + env->observations[i++] = potential; + env->observations[i++] = own_energy; + + // NEW: Sideslip angle + env->observations[i++] = clampf(beta * INV_MAX_SIDESLIP, -1.0f, 1.0f); // Beta [-1,1] + + // Target state (4 obs) + 
env->observations[i++] = target_az * INV_PI; + env->observations[i++] = target_el * INV_HALF_PI; + env->observations[i++] = clampf(dist * INV_MAX_RANGE, 0.0f, 1.0f); + env->observations[i++] = clampf(closure * INV_MAX_SPEED, -1.0f, 1.0f); + + // Tactical (2 obs) + env->observations[i++] = energy_advantage; + env->observations[i++] = target_aspect; + + // Timer (1 obs) + env->observations[i++] = (float)env->tick / (float)(env->max_steps + 1); + // OBS_SIZE = 17 +} + +// ============================================================================ +// Scheme 2: OBS_MOMENTUM_GFORCE - + G-force (16 obs) +// ============================================================================ +// Hypothesis: G-force awareness enables better high-G maneuvering +void compute_obs_momentum_gforce(Dogfight *env) { + Plane *p = &env->player; + Plane *o = &env->opponent; + + Quat q_inv = {p->ori.w, -p->ori.x, -p->ori.y, -p->ori.z}; + + // Body-frame velocity + Vec3 vel_body = quat_rotate(q_inv, p->vel); + float speed = norm3(p->vel); + + // Angle of attack + Vec3 forward = quat_rotate(p->ori, vec3(1, 0, 0)); + Vec3 up = quat_rotate(p->ori, vec3(0, 0, 1)); + float aoa = 0.0f; + if (speed > 1.0f) { + Vec3 vel_norm = normalize3(p->vel); + float cos_alpha = clampf(dot3(vel_norm, forward), -1.0f, 1.0f); + float alpha = acosf(cos_alpha); + float sign = (dot3(p->vel, up) < 0) ? 1.0f : -1.0f; + aoa = alpha * sign; + } + + // Energy state + float potential = p->pos.z * INV_WORLD_MAX_Z; + float kinetic = (speed * speed) / (MAX_SPEED * MAX_SPEED); + float own_energy = (potential + kinetic) * 0.5f; + + // G-force normalization: 0G=0, 1G=0.2, 5G=1.0, -2.5G=-0.5 + float g_norm = clampf(p->g_force / 5.0f, -0.5f, 1.0f); + + // Target state + Vec3 rel_pos = sub3(o->pos, p->pos); + Vec3 rel_pos_body = quat_rotate(q_inv, rel_pos); + float dist = norm3(rel_pos); + + float target_az = atan2f(rel_pos_body.y, rel_pos_body.x); + float r_horiz = sqrtf(rel_pos_body.x * rel_pos_body.x + rel_pos_body.y * rel_pos_body.y); + float target_el = atan2f(rel_pos_body.z, fmaxf(r_horiz, 1e-6f)); + + Vec3 rel_vel = sub3(p->vel, o->vel); + float closure = dot3(rel_vel, normalize3(rel_pos)); + + // Tactical + Vec3 opp_fwd = quat_rotate(o->ori, vec3(1, 0, 0)); + Vec3 to_player = normalize3(sub3(p->pos, o->pos)); + float target_aspect = dot3(opp_fwd, to_player); + + float opp_speed = norm3(o->vel); + float opp_potential = o->pos.z * INV_WORLD_MAX_Z; + float opp_kinetic = (opp_speed * opp_speed) / (MAX_SPEED * MAX_SPEED); + float opp_energy = (opp_potential + opp_kinetic) * 0.5f; + float energy_advantage = clampf(own_energy - opp_energy, -1.0f, 1.0f); + + int i = 0; + // Own flight state (9 obs) + env->observations[i++] = clampf(vel_body.x * INV_MAX_SPEED, 0.0f, 1.0f); + env->observations[i++] = clampf(vel_body.y * INV_MAX_SPEED, -1.0f, 1.0f); + env->observations[i++] = clampf(vel_body.z * INV_MAX_SPEED, -1.0f, 1.0f); + env->observations[i++] = clampf(p->omega.x * INV_MAX_OMEGA, -1.0f, 1.0f); + env->observations[i++] = clampf(p->omega.y * INV_MAX_OMEGA, -1.0f, 1.0f); + env->observations[i++] = clampf(p->omega.z * INV_MAX_OMEGA, -1.0f, 1.0f); + env->observations[i++] = clampf(aoa * INV_MAX_AOA, -1.0f, 1.0f); + env->observations[i++] = potential; + env->observations[i++] = own_energy; + + // NEW: G-force + env->observations[i++] = g_norm; // G-force [-0.5,1] + + // Target state (4 obs) + env->observations[i++] = target_az * INV_PI; + env->observations[i++] = target_el * INV_HALF_PI; + env->observations[i++] = clampf(dist * INV_MAX_RANGE, 0.0f, 1.0f); + 
env->observations[i++] = clampf(closure * INV_MAX_SPEED, -1.0f, 1.0f); + + // Tactical (2 obs) + env->observations[i++] = energy_advantage; + env->observations[i++] = target_aspect; + + // Timer (1 obs) + env->observations[i++] = (float)env->tick / (float)(env->max_steps + 1); + // OBS_SIZE = 17 +} + +// ============================================================================ +// Scheme 3: OBS_MOMENTUM_FULL - + sideslip + G + throttle + target rates (19 obs) +// ============================================================================ +// Hypothesis: Maximum relevant information is optimal +void compute_obs_momentum_full(Dogfight *env) { + Plane *p = &env->player; + Plane *o = &env->opponent; + + Quat q_inv = {p->ori.w, -p->ori.x, -p->ori.y, -p->ori.z}; + + // Body-frame velocity + Vec3 vel_body = quat_rotate(q_inv, p->vel); + float speed = norm3(p->vel); + + // Angle of attack + Vec3 forward = quat_rotate(p->ori, vec3(1, 0, 0)); + Vec3 up = quat_rotate(p->ori, vec3(0, 0, 1)); + float aoa = 0.0f; + if (speed > 1.0f) { + Vec3 vel_norm = normalize3(p->vel); + float cos_alpha = clampf(dot3(vel_norm, forward), -1.0f, 1.0f); + float alpha = acosf(cos_alpha); + float sign = (dot3(p->vel, up) < 0) ? 1.0f : -1.0f; + aoa = alpha * sign; + } + + // Sideslip angle + float beta = 0.0f; + if (speed > 1.0f) { + beta = asinf(clampf(vel_body.y / speed, -1.0f, 1.0f)); + } + + // Energy state + float potential = p->pos.z * INV_WORLD_MAX_Z; + float kinetic = (speed * speed) / (MAX_SPEED * MAX_SPEED); + float own_energy = (potential + kinetic) * 0.5f; + + // G-force + float g_norm = clampf(p->g_force / 5.0f, -0.5f, 1.0f); + + // Target state + Vec3 rel_pos = sub3(o->pos, p->pos); + Vec3 rel_pos_body = quat_rotate(q_inv, rel_pos); + float dist = norm3(rel_pos); + + float target_az = atan2f(rel_pos_body.y, rel_pos_body.x); + float r_horiz = sqrtf(rel_pos_body.x * rel_pos_body.x + rel_pos_body.y * rel_pos_body.y); + float target_el = atan2f(rel_pos_body.z, fmaxf(r_horiz, 1e-6f)); + + Vec3 rel_vel = sub3(p->vel, o->vel); + float closure = dot3(rel_vel, normalize3(rel_pos)); + + // Tactical + float opp_speed = norm3(o->vel); + float opp_potential = o->pos.z * INV_WORLD_MAX_Z; + float opp_kinetic = (opp_speed * opp_speed) / (MAX_SPEED * MAX_SPEED); + float opp_energy = (opp_potential + opp_kinetic) * 0.5f; + float energy_advantage = clampf(own_energy - opp_energy, -1.0f, 1.0f); + + int i = 0; + // Own flight state (9 obs) + env->observations[i++] = clampf(vel_body.x * INV_MAX_SPEED, 0.0f, 1.0f); + env->observations[i++] = clampf(vel_body.y * INV_MAX_SPEED, -1.0f, 1.0f); + env->observations[i++] = clampf(vel_body.z * INV_MAX_SPEED, -1.0f, 1.0f); + env->observations[i++] = clampf(p->omega.x * INV_MAX_OMEGA, -1.0f, 1.0f); + env->observations[i++] = clampf(p->omega.y * INV_MAX_OMEGA, -1.0f, 1.0f); + env->observations[i++] = clampf(p->omega.z * INV_MAX_OMEGA, -1.0f, 1.0f); + env->observations[i++] = clampf(aoa * INV_MAX_AOA, -1.0f, 1.0f); + env->observations[i++] = potential; + env->observations[i++] = own_energy; + + // Extended own state (3 obs) + env->observations[i++] = clampf(beta * INV_MAX_SIDESLIP, -1.0f, 1.0f); // Beta + env->observations[i++] = g_norm; // G-force + env->observations[i++] = p->throttle; // Throttle [0,1] + + // Target state (4 obs) + env->observations[i++] = target_az * INV_PI; + env->observations[i++] = target_el * INV_HALF_PI; + env->observations[i++] = clampf(dist * INV_MAX_RANGE, 0.0f, 1.0f); + env->observations[i++] = clampf(closure * INV_MAX_SPEED, -1.0f, 1.0f); + + // Target 
angular rates (2 obs) - for predicting opponent maneuvers + env->observations[i++] = clampf(o->omega.y * INV_MAX_OMEGA, -1.0f, 1.0f); // Target pitch rate + env->observations[i++] = clampf(o->omega.x * INV_MAX_OMEGA, -1.0f, 1.0f); // Target roll rate + + // Energy advantage (1 obs) + env->observations[i++] = energy_advantage; + + // Timer (1 obs) + env->observations[i++] = (float)env->tick / (float)(env->max_steps + 1); + // OBS_SIZE = 20 +} + +// ============================================================================ +// Scheme 4: OBS_MINIMAL - stripped down essentials (11 obs) +// ============================================================================ +// Hypothesis: Simpler observations learn faster and generalize better +void compute_obs_minimal(Dogfight *env) { + Plane *p = &env->player; + Plane *o = &env->opponent; + + Quat q_inv = {p->ori.w, -p->ori.x, -p->ori.y, -p->ori.z}; + + // Body-frame velocity + Vec3 vel_body = quat_rotate(q_inv, p->vel); + float speed = norm3(p->vel); + + // Angle of attack + Vec3 forward = quat_rotate(p->ori, vec3(1, 0, 0)); + Vec3 up = quat_rotate(p->ori, vec3(0, 0, 1)); + float aoa = 0.0f; + if (speed > 1.0f) { + Vec3 vel_norm = normalize3(p->vel); + float cos_alpha = clampf(dot3(vel_norm, forward), -1.0f, 1.0f); + float alpha = acosf(cos_alpha); + float sign = (dot3(p->vel, up) < 0) ? 1.0f : -1.0f; + aoa = alpha * sign; + } + + // Altitude + float potential = p->pos.z * INV_WORLD_MAX_Z; + + // Target state + Vec3 rel_pos = sub3(o->pos, p->pos); + Vec3 rel_pos_body = quat_rotate(q_inv, rel_pos); + float dist = norm3(rel_pos); + + float target_az = atan2f(rel_pos_body.y, rel_pos_body.x); + float r_horiz = sqrtf(rel_pos_body.x * rel_pos_body.x + rel_pos_body.y * rel_pos_body.y); + float target_el = atan2f(rel_pos_body.z, fmaxf(r_horiz, 1e-6f)); + + Vec3 rel_vel = sub3(p->vel, o->vel); + float closure = dot3(rel_vel, normalize3(rel_pos)); + + // Energy advantage + float kinetic = (speed * speed) / (MAX_SPEED * MAX_SPEED); + float own_energy = (potential + kinetic) * 0.5f; + + float opp_speed = norm3(o->vel); + float opp_potential = o->pos.z * INV_WORLD_MAX_Z; + float opp_kinetic = (opp_speed * opp_speed) / (MAX_SPEED * MAX_SPEED); + float opp_energy = (opp_potential + opp_kinetic) * 0.5f; + float energy_advantage = clampf(own_energy - opp_energy, -1.0f, 1.0f); + + int i = 0; + // Minimal own state (6 obs) + env->observations[i++] = clampf(vel_body.x * INV_MAX_SPEED, 0.0f, 1.0f); // Forward speed + env->observations[i++] = clampf(aoa * INV_MAX_AOA, -1.0f, 1.0f); // AoA + env->observations[i++] = clampf(p->omega.x * INV_MAX_OMEGA, -1.0f, 1.0f); // Roll rate + env->observations[i++] = clampf(p->omega.y * INV_MAX_OMEGA, -1.0f, 1.0f); // Pitch rate + env->observations[i++] = clampf(p->omega.z * INV_MAX_OMEGA, -1.0f, 1.0f); // Yaw rate + env->observations[i++] = potential; // Altitude + + // Target (4 obs) + env->observations[i++] = target_az * INV_PI; // Azimuth + env->observations[i++] = target_el * INV_HALF_PI; // Elevation + env->observations[i++] = clampf(dist * INV_MAX_RANGE, 0.0f, 1.0f); // Range + env->observations[i++] = clampf(closure * INV_MAX_SPEED, -1.0f, 1.0f); // Closure + + // Tactical (1 obs) + env->observations[i++] = energy_advantage; + + // Timer (1 obs) + env->observations[i++] = (float)env->tick / (float)(env->max_steps + 1); + // OBS_SIZE = 12 +} + +// ============================================================================ +// Scheme 5: OBS_CARTESIAN - cartesian target position (15 obs) +// 
============================================================================ +// Hypothesis: Cartesian target coords better for lead computing +void compute_obs_cartesian(Dogfight *env) { + Plane *p = &env->player; + Plane *o = &env->opponent; + + Quat q_inv = {p->ori.w, -p->ori.x, -p->ori.y, -p->ori.z}; + + // Body-frame velocity + Vec3 vel_body = quat_rotate(q_inv, p->vel); + float speed = norm3(p->vel); + + // Angle of attack + Vec3 forward = quat_rotate(p->ori, vec3(1, 0, 0)); + Vec3 up = quat_rotate(p->ori, vec3(0, 0, 1)); + float aoa = 0.0f; + if (speed > 1.0f) { + Vec3 vel_norm = normalize3(p->vel); + float cos_alpha = clampf(dot3(vel_norm, forward), -1.0f, 1.0f); + float alpha = acosf(cos_alpha); + float sign = (dot3(p->vel, up) < 0) ? 1.0f : -1.0f; + aoa = alpha * sign; + } + + // Energy state + float potential = p->pos.z * INV_WORLD_MAX_Z; + float kinetic = (speed * speed) / (MAX_SPEED * MAX_SPEED); + float own_energy = (potential + kinetic) * 0.5f; + + // Target in body frame - CARTESIAN instead of spherical + Vec3 rel_pos = sub3(o->pos, p->pos); + Vec3 rel_pos_body = quat_rotate(q_inv, rel_pos); + + Vec3 rel_vel = sub3(p->vel, o->vel); + float closure = dot3(rel_vel, normalize3(rel_pos)); + + // Tactical + Vec3 opp_fwd = quat_rotate(o->ori, vec3(1, 0, 0)); + Vec3 to_player = normalize3(sub3(p->pos, o->pos)); + float target_aspect = dot3(opp_fwd, to_player); + + float opp_speed = norm3(o->vel); + float opp_potential = o->pos.z * INV_WORLD_MAX_Z; + float opp_kinetic = (opp_speed * opp_speed) / (MAX_SPEED * MAX_SPEED); + float opp_energy = (opp_potential + opp_kinetic) * 0.5f; + float energy_advantage = clampf(own_energy - opp_energy, -1.0f, 1.0f); + + int i = 0; + // Own flight state (9 obs) + env->observations[i++] = clampf(vel_body.x * INV_MAX_SPEED, 0.0f, 1.0f); + env->observations[i++] = clampf(vel_body.y * INV_MAX_SPEED, -1.0f, 1.0f); + env->observations[i++] = clampf(vel_body.z * INV_MAX_SPEED, -1.0f, 1.0f); + env->observations[i++] = clampf(p->omega.x * INV_MAX_OMEGA, -1.0f, 1.0f); + env->observations[i++] = clampf(p->omega.y * INV_MAX_OMEGA, -1.0f, 1.0f); + env->observations[i++] = clampf(p->omega.z * INV_MAX_OMEGA, -1.0f, 1.0f); + env->observations[i++] = clampf(aoa * INV_MAX_AOA, -1.0f, 1.0f); + env->observations[i++] = potential; + env->observations[i++] = own_energy; + + // Target state - CARTESIAN (4 obs) + env->observations[i++] = clampf(rel_pos_body.x * INV_MAX_RANGE, -1.0f, 1.0f); // Target X (forward) + env->observations[i++] = clampf(rel_pos_body.y * INV_MAX_RANGE, -1.0f, 1.0f); // Target Y (right) + env->observations[i++] = clampf(rel_pos_body.z * INV_MAX_RANGE, -1.0f, 1.0f); // Target Z (up) + env->observations[i++] = clampf(closure * INV_MAX_SPEED, -1.0f, 1.0f); + + // Tactical (2 obs) + env->observations[i++] = energy_advantage; + env->observations[i++] = target_aspect; + + // Timer (1 obs) + env->observations[i++] = (float)env->tick / (float)(env->max_steps + 1); + // OBS_SIZE = 16 +} + +// ============================================================================ +// Scheme 6: OBS_DRONE_STYLE - + quaternion + up vector (22 obs) +// ============================================================================ +// Hypothesis: Quaternion + up vector (drone_race style) helps 3D maneuvers +void compute_obs_drone_style(Dogfight *env) { + Plane *p = &env->player; + Plane *o = &env->opponent; + + Quat q_inv = {p->ori.w, -p->ori.x, -p->ori.y, -p->ori.z}; + + // Body-frame velocity + Vec3 vel_body = quat_rotate(q_inv, p->vel); + float speed = norm3(p->vel); + + // 
Angle of attack + Vec3 forward = quat_rotate(p->ori, vec3(1, 0, 0)); + Vec3 up = quat_rotate(p->ori, vec3(0, 0, 1)); + float aoa = 0.0f; + if (speed > 1.0f) { + Vec3 vel_norm = normalize3(p->vel); + float cos_alpha = clampf(dot3(vel_norm, forward), -1.0f, 1.0f); + float alpha = acosf(cos_alpha); + float sign = (dot3(p->vel, up) < 0) ? 1.0f : -1.0f; + aoa = alpha * sign; + } + + // Energy state + float potential = p->pos.z * INV_WORLD_MAX_Z; + float kinetic = (speed * speed) / (MAX_SPEED * MAX_SPEED); + float own_energy = (potential + kinetic) * 0.5f; + + // Up vector in world frame (derived from quaternion) + Vec3 world_up = quat_rotate(p->ori, vec3(0, 0, 1)); + + // Target state + Vec3 rel_pos = sub3(o->pos, p->pos); + Vec3 rel_pos_body = quat_rotate(q_inv, rel_pos); + float dist = norm3(rel_pos); + + float target_az = atan2f(rel_pos_body.y, rel_pos_body.x); + float r_horiz = sqrtf(rel_pos_body.x * rel_pos_body.x + rel_pos_body.y * rel_pos_body.y); + float target_el = atan2f(rel_pos_body.z, fmaxf(r_horiz, 1e-6f)); + + Vec3 rel_vel = sub3(p->vel, o->vel); + float closure = dot3(rel_vel, normalize3(rel_pos)); + + // Tactical + Vec3 opp_fwd = quat_rotate(o->ori, vec3(1, 0, 0)); + Vec3 to_player = normalize3(sub3(p->pos, o->pos)); + float target_aspect = dot3(opp_fwd, to_player); + + float opp_speed = norm3(o->vel); + float opp_potential = o->pos.z * INV_WORLD_MAX_Z; + float opp_kinetic = (opp_speed * opp_speed) / (MAX_SPEED * MAX_SPEED); + float opp_energy = (opp_potential + opp_kinetic) * 0.5f; + float energy_advantage = clampf(own_energy - opp_energy, -1.0f, 1.0f); + + int i = 0; + // Own flight state (9 obs) + env->observations[i++] = clampf(vel_body.x * INV_MAX_SPEED, 0.0f, 1.0f); + env->observations[i++] = clampf(vel_body.y * INV_MAX_SPEED, -1.0f, 1.0f); + env->observations[i++] = clampf(vel_body.z * INV_MAX_SPEED, -1.0f, 1.0f); + env->observations[i++] = clampf(p->omega.x * INV_MAX_OMEGA, -1.0f, 1.0f); + env->observations[i++] = clampf(p->omega.y * INV_MAX_OMEGA, -1.0f, 1.0f); + env->observations[i++] = clampf(p->omega.z * INV_MAX_OMEGA, -1.0f, 1.0f); + env->observations[i++] = clampf(aoa * INV_MAX_AOA, -1.0f, 1.0f); + env->observations[i++] = potential; + env->observations[i++] = own_energy; + + // Quaternion (4 obs) - raw orientation for NN to reason about 3D + env->observations[i++] = p->ori.w; + env->observations[i++] = p->ori.x; + env->observations[i++] = p->ori.y; + env->observations[i++] = p->ori.z; + + // Up vector in world frame (3 obs) - gravity-relative maneuvers + env->observations[i++] = world_up.x; + env->observations[i++] = world_up.y; + env->observations[i++] = world_up.z; + + // Target state (4 obs) + env->observations[i++] = target_az * INV_PI; + env->observations[i++] = target_el * INV_HALF_PI; + env->observations[i++] = clampf(dist * INV_MAX_RANGE, 0.0f, 1.0f); + env->observations[i++] = clampf(closure * INV_MAX_SPEED, -1.0f, 1.0f); + + // Tactical (2 obs) + env->observations[i++] = energy_advantage; + env->observations[i++] = target_aspect; + + // Timer (1 obs) + env->observations[i++] = (float)env->tick / (float)(env->max_steps + 1); + // OBS_SIZE = 23 +} + +// ============================================================================ +// Scheme 7: OBS_QBAR - + dynamic pressure (16 obs) +// ============================================================================ +// Hypothesis: Dynamic pressure helps understand control authority +void compute_obs_qbar(Dogfight *env) { + Plane *p = &env->player; + Plane *o = &env->opponent; + + Quat q_inv = {p->ori.w, 
-p->ori.x, -p->ori.y, -p->ori.z}; + + // Body-frame velocity + Vec3 vel_body = quat_rotate(q_inv, p->vel); + float speed = norm3(p->vel); + + // Angle of attack + Vec3 forward = quat_rotate(p->ori, vec3(1, 0, 0)); + Vec3 up = quat_rotate(p->ori, vec3(0, 0, 1)); + float aoa = 0.0f; + if (speed > 1.0f) { + Vec3 vel_norm = normalize3(p->vel); + float cos_alpha = clampf(dot3(vel_norm, forward), -1.0f, 1.0f); + float alpha = acosf(cos_alpha); + float sign = (dot3(p->vel, up) < 0) ? 1.0f : -1.0f; + aoa = alpha * sign; + } + + // Energy state + float potential = p->pos.z * INV_WORLD_MAX_Z; + float kinetic = (speed * speed) / (MAX_SPEED * MAX_SPEED); + float own_energy = (potential + kinetic) * 0.5f; + + // Dynamic pressure q_bar = 0.5 * rho * V^2 + // At sea level rho ≈ 1.225 kg/m³ + float rho = 1.225f; + float q_bar = 0.5f * rho * speed * speed; + float q_bar_norm = clampf(q_bar * INV_MAX_QBAR, 0.0f, 1.0f); + + // Target state + Vec3 rel_pos = sub3(o->pos, p->pos); + Vec3 rel_pos_body = quat_rotate(q_inv, rel_pos); + float dist = norm3(rel_pos); + + float target_az = atan2f(rel_pos_body.y, rel_pos_body.x); + float r_horiz = sqrtf(rel_pos_body.x * rel_pos_body.x + rel_pos_body.y * rel_pos_body.y); + float target_el = atan2f(rel_pos_body.z, fmaxf(r_horiz, 1e-6f)); + + Vec3 rel_vel = sub3(p->vel, o->vel); + float closure = dot3(rel_vel, normalize3(rel_pos)); + + // Tactical + Vec3 opp_fwd = quat_rotate(o->ori, vec3(1, 0, 0)); + Vec3 to_player = normalize3(sub3(p->pos, o->pos)); + float target_aspect = dot3(opp_fwd, to_player); + + float opp_speed = norm3(o->vel); + float opp_potential = o->pos.z * INV_WORLD_MAX_Z; + float opp_kinetic = (opp_speed * opp_speed) / (MAX_SPEED * MAX_SPEED); + float opp_energy = (opp_potential + opp_kinetic) * 0.5f; + float energy_advantage = clampf(own_energy - opp_energy, -1.0f, 1.0f); + + int i = 0; + // Own flight state (9 obs) + env->observations[i++] = clampf(vel_body.x * INV_MAX_SPEED, 0.0f, 1.0f); + env->observations[i++] = clampf(vel_body.y * INV_MAX_SPEED, -1.0f, 1.0f); + env->observations[i++] = clampf(vel_body.z * INV_MAX_SPEED, -1.0f, 1.0f); + env->observations[i++] = clampf(p->omega.x * INV_MAX_OMEGA, -1.0f, 1.0f); + env->observations[i++] = clampf(p->omega.y * INV_MAX_OMEGA, -1.0f, 1.0f); + env->observations[i++] = clampf(p->omega.z * INV_MAX_OMEGA, -1.0f, 1.0f); + env->observations[i++] = clampf(aoa * INV_MAX_AOA, -1.0f, 1.0f); + env->observations[i++] = potential; + env->observations[i++] = own_energy; + + // Dynamic pressure (1 obs) + env->observations[i++] = q_bar_norm; // q_bar [0,1] + + // Target state (4 obs) + env->observations[i++] = target_az * INV_PI; + env->observations[i++] = target_el * INV_HALF_PI; + env->observations[i++] = clampf(dist * INV_MAX_RANGE, 0.0f, 1.0f); + env->observations[i++] = clampf(closure * INV_MAX_SPEED, -1.0f, 1.0f); + + // Tactical (2 obs) + env->observations[i++] = energy_advantage; + env->observations[i++] = target_aspect; + + // Timer (1 obs) + env->observations[i++] = (float)env->tick / (float)(env->max_steps + 1); + // OBS_SIZE = 17 +} + +// ============================================================================ +// Scheme 8: OBS_KITCHEN_SINK - everything (25 obs) +// ============================================================================ +// Hypothesis: Maximum information with everything is optimal +void compute_obs_kitchen_sink(Dogfight *env) { + Plane *p = &env->player; + Plane *o = &env->opponent; + + Quat q_inv = {p->ori.w, -p->ori.x, -p->ori.y, -p->ori.z}; + + // Body-frame velocity + Vec3 vel_body 
= quat_rotate(q_inv, p->vel); + float speed = norm3(p->vel); + + // Angle of attack + Vec3 forward = quat_rotate(p->ori, vec3(1, 0, 0)); + Vec3 up_body = quat_rotate(p->ori, vec3(0, 0, 1)); + float aoa = 0.0f; + if (speed > 1.0f) { + Vec3 vel_norm = normalize3(p->vel); + float cos_alpha = clampf(dot3(vel_norm, forward), -1.0f, 1.0f); + float alpha = acosf(cos_alpha); + float sign = (dot3(p->vel, up_body) < 0) ? 1.0f : -1.0f; + aoa = alpha * sign; + } + + // Sideslip angle + float beta = 0.0f; + if (speed > 1.0f) { + beta = asinf(clampf(vel_body.y / speed, -1.0f, 1.0f)); + } + + // G-force + float g_norm = clampf(p->g_force / 5.0f, -0.5f, 1.0f); + + // Dynamic pressure + float rho = 1.225f; + float q_bar = 0.5f * rho * speed * speed; + float q_bar_norm = clampf(q_bar * INV_MAX_QBAR, 0.0f, 1.0f); + + // Energy state + float potential = p->pos.z * INV_WORLD_MAX_Z; + float kinetic = (speed * speed) / (MAX_SPEED * MAX_SPEED); + float own_energy = (potential + kinetic) * 0.5f; + + // Up vector in world frame + Vec3 world_up = quat_rotate(p->ori, vec3(0, 0, 1)); + + // Target state + Vec3 rel_pos = sub3(o->pos, p->pos); + Vec3 rel_pos_body = quat_rotate(q_inv, rel_pos); + float dist = norm3(rel_pos); + + float target_az = atan2f(rel_pos_body.y, rel_pos_body.x); + float r_horiz = sqrtf(rel_pos_body.x * rel_pos_body.x + rel_pos_body.y * rel_pos_body.y); + float target_el = atan2f(rel_pos_body.z, fmaxf(r_horiz, 1e-6f)); + + Vec3 rel_vel = sub3(p->vel, o->vel); + float closure = dot3(rel_vel, normalize3(rel_pos)); + + // Energy advantage + float opp_speed = norm3(o->vel); + float opp_potential = o->pos.z * INV_WORLD_MAX_Z; + float opp_kinetic = (opp_speed * opp_speed) / (MAX_SPEED * MAX_SPEED); + float opp_energy = (opp_potential + opp_kinetic) * 0.5f; + float energy_advantage = clampf(own_energy - opp_energy, -1.0f, 1.0f); + + int i = 0; + // Body-frame velocity (3 obs) + env->observations[i++] = clampf(vel_body.x * INV_MAX_SPEED, 0.0f, 1.0f); + env->observations[i++] = clampf(vel_body.y * INV_MAX_SPEED, -1.0f, 1.0f); + env->observations[i++] = clampf(vel_body.z * INV_MAX_SPEED, -1.0f, 1.0f); + + // Angular velocity (3 obs) + env->observations[i++] = clampf(p->omega.x * INV_MAX_OMEGA, -1.0f, 1.0f); + env->observations[i++] = clampf(p->omega.y * INV_MAX_OMEGA, -1.0f, 1.0f); + env->observations[i++] = clampf(p->omega.z * INV_MAX_OMEGA, -1.0f, 1.0f); + + // Flight angles (2 obs) + env->observations[i++] = clampf(aoa * INV_MAX_AOA, -1.0f, 1.0f); + env->observations[i++] = clampf(beta * INV_MAX_SIDESLIP, -1.0f, 1.0f); + + // Flight state (4 obs) + env->observations[i++] = g_norm; + env->observations[i++] = q_bar_norm; + env->observations[i++] = potential; + env->observations[i++] = own_energy; + + // Controls (1 obs) + env->observations[i++] = p->throttle; + + // Quaternion (4 obs) + env->observations[i++] = p->ori.w; + env->observations[i++] = p->ori.x; + env->observations[i++] = p->ori.y; + env->observations[i++] = p->ori.z; + + // Up vector in world frame (3 obs) + env->observations[i++] = world_up.x; + env->observations[i++] = world_up.y; + env->observations[i++] = world_up.z; + + // Target spherical (4 obs) + env->observations[i++] = target_az * INV_PI; + env->observations[i++] = target_el * INV_HALF_PI; + env->observations[i++] = clampf(dist * INV_MAX_RANGE, 0.0f, 1.0f); + env->observations[i++] = clampf(closure * INV_MAX_SPEED, -1.0f, 1.0f); + + // Energy advantage (1 obs) + env->observations[i++] = energy_advantage; + + // Timer (1 obs) + env->observations[i++] = (float)env->tick / 
(float)(env->max_steps + 1); + // OBS_SIZE = 26 +} + +// ============================================================================ +// Dispatcher function +// ============================================================================ +void compute_observations(Dogfight *env) { + switch (env->obs_scheme) { + case OBS_MOMENTUM: compute_obs_momentum(env); break; + case OBS_MOMENTUM_BETA: compute_obs_momentum_beta(env); break; + case OBS_MOMENTUM_GFORCE: compute_obs_momentum_gforce(env); break; + case OBS_MOMENTUM_FULL: compute_obs_momentum_full(env); break; + case OBS_MINIMAL: compute_obs_minimal(env); break; + case OBS_CARTESIAN: compute_obs_cartesian(env); break; + case OBS_DRONE_STYLE: compute_obs_drone_style(env); break; + case OBS_QBAR: compute_obs_qbar(env); break; + case OBS_KITCHEN_SINK: compute_obs_kitchen_sink(env); break; + default: compute_obs_momentum(env); break; + } +} + +// ============================================================================ +// Opponent observations (for self-play) +// ============================================================================ +// Scheme 1 generalized for self-play (opponent perspective) +// ============================================================================ +void compute_obs_momentum_beta_for_plane(Dogfight *env, Plane *self, Plane *other, float *obs_buffer) { + Quat q_inv = {self->ori.w, -self->ori.x, -self->ori.y, -self->ori.z}; + + // Body-frame velocity + Vec3 vel_body = quat_rotate(q_inv, self->vel); + float speed = norm3(self->vel); + + // Angle of attack + Vec3 forward = quat_rotate(self->ori, vec3(1, 0, 0)); + Vec3 up = quat_rotate(self->ori, vec3(0, 0, 1)); + float aoa = 0.0f; + if (speed > 1.0f) { + Vec3 vel_norm = normalize3(self->vel); + float cos_alpha = clampf(dot3(vel_norm, forward), -1.0f, 1.0f); + float alpha = acosf(cos_alpha); + float sign = (dot3(self->vel, up) < 0) ? 
1.0f : -1.0f; + aoa = alpha * sign; + } + + // Sideslip angle (beta) + float beta = 0.0f; + if (speed > 1.0f) { + beta = asinf(clampf(vel_body.y / speed, -1.0f, 1.0f)); + } + + // Energy state + float potential = self->pos.z * INV_WORLD_MAX_Z; + float kinetic = (speed * speed) / (MAX_SPEED * MAX_SPEED); + float own_energy = (potential + kinetic) * 0.5f; + + // Target state + Vec3 rel_pos = sub3(other->pos, self->pos); + Vec3 rel_pos_body = quat_rotate(q_inv, rel_pos); + float dist = norm3(rel_pos); + + float target_az = atan2f(rel_pos_body.y, rel_pos_body.x); + float r_horiz = sqrtf(rel_pos_body.x * rel_pos_body.x + rel_pos_body.y * rel_pos_body.y); + float target_el = atan2f(rel_pos_body.z, fmaxf(r_horiz, 1e-6f)); + + Vec3 rel_vel = sub3(self->vel, other->vel); + float closure = dot3(rel_vel, normalize3(rel_pos)); + + // Tactical + Vec3 other_fwd = quat_rotate(other->ori, vec3(1, 0, 0)); + Vec3 to_self = normalize3(sub3(self->pos, other->pos)); + float target_aspect = dot3(other_fwd, to_self); + + float other_speed = norm3(other->vel); + float other_potential = other->pos.z * INV_WORLD_MAX_Z; + float other_kinetic = (other_speed * other_speed) / (MAX_SPEED * MAX_SPEED); + float other_energy = (other_potential + other_kinetic) * 0.5f; + float energy_advantage = clampf(own_energy - other_energy, -1.0f, 1.0f); + + int i = 0; + // Own flight state (9 obs) + obs_buffer[i++] = clampf(vel_body.x * INV_MAX_SPEED, 0.0f, 1.0f); + obs_buffer[i++] = clampf(vel_body.y * INV_MAX_SPEED, -1.0f, 1.0f); + obs_buffer[i++] = clampf(vel_body.z * INV_MAX_SPEED, -1.0f, 1.0f); + obs_buffer[i++] = clampf(self->omega.x * INV_MAX_OMEGA, -1.0f, 1.0f); + obs_buffer[i++] = clampf(self->omega.y * INV_MAX_OMEGA, -1.0f, 1.0f); + obs_buffer[i++] = clampf(self->omega.z * INV_MAX_OMEGA, -1.0f, 1.0f); + obs_buffer[i++] = clampf(aoa * INV_MAX_AOA, -1.0f, 1.0f); + obs_buffer[i++] = potential; + obs_buffer[i++] = own_energy; + + // Sideslip angle (scheme 1 addition) + obs_buffer[i++] = clampf(beta * INV_MAX_SIDESLIP, -1.0f, 1.0f); + + // Target state (4 obs) + obs_buffer[i++] = target_az * INV_PI; + obs_buffer[i++] = target_el * INV_HALF_PI; + obs_buffer[i++] = clampf(dist * INV_MAX_RANGE, 0.0f, 1.0f); + obs_buffer[i++] = clampf(closure * INV_MAX_SPEED, -1.0f, 1.0f); + + // Tactical (2 obs) + obs_buffer[i++] = energy_advantage; + obs_buffer[i++] = target_aspect; + + // Timer (1 obs) + obs_buffer[i++] = (float)env->tick / (float)(env->max_steps + 1); + // OBS_SIZE = 17 +} + +// ============================================================================ +// Computes observations from opponent's perspective looking at player. +// Supports scheme 0 (MOMENTUM) and scheme 1 (MOMENTUM_BETA). 
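+// Usage sketch (illustrative only; the buffer name and size shown here are
+// placeholders, not part of this header). The opponent buffer must hold at
+// least env->obs_size floats and uses the same layout as the player's
+// observations, so the same policy network can consume it during self-play:
+//
+//   float opp_obs[32];                            // >= env->obs_size
+//   compute_opponent_observations(env, opp_obs);  // opponent's view of the player
+//   // feed opp_obs to the frozen opponent policy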
+void compute_opponent_observations(Dogfight *env, float *opp_obs_buffer) { + if (env->obs_scheme == 1) { + compute_obs_momentum_beta_for_plane(env, &env->opponent, &env->player, opp_obs_buffer); + } else { + // Default to scheme 0 for other schemes (may need to add more _for_plane variants) + compute_obs_momentum_for_plane(env, &env->opponent, &env->player, opp_obs_buffer); + } +} + +// ============================================================================ +// Debug labels for print_observations +// ============================================================================ +#if DEBUG >= 5 + +// Scheme 0: OBS_MOMENTUM (16 obs) +static const char* DEBUG_OBS_LABELS_MOMENTUM[16] = { + "fwd_spd", "sideslip", "climb", "roll_r", "pitch_r", "yaw_r", + "aoa", "altitude", "energy", + "tgt_az", "tgt_el", "range", "closure", + "E_adv", "aspect", "timer" +}; + +// Scheme 1: OBS_MOMENTUM_BETA (17 obs) +static const char* DEBUG_OBS_LABELS_MOMENTUM_BETA[17] = { + "fwd_spd", "sideslip", "climb", "roll_r", "pitch_r", "yaw_r", + "aoa", "altitude", "energy", "beta", + "tgt_az", "tgt_el", "range", "closure", + "E_adv", "aspect", "timer" +}; + +// Scheme 2: OBS_MOMENTUM_GFORCE (17 obs) +static const char* DEBUG_OBS_LABELS_MOMENTUM_GFORCE[17] = { + "fwd_spd", "sideslip", "climb", "roll_r", "pitch_r", "yaw_r", + "aoa", "altitude", "energy", "g_force", + "tgt_az", "tgt_el", "range", "closure", + "E_adv", "aspect", "timer" +}; + +// Scheme 3: OBS_MOMENTUM_FULL (20 obs) +static const char* DEBUG_OBS_LABELS_MOMENTUM_FULL[20] = { + "fwd_spd", "sideslip", "climb", "roll_r", "pitch_r", "yaw_r", + "aoa", "altitude", "energy", "beta", "g_force", "throttle", + "tgt_az", "tgt_el", "range", "closure", + "tgt_pitch_r", "tgt_roll_r", "E_adv", "timer" +}; + +// Scheme 4: OBS_MINIMAL (12 obs) +static const char* DEBUG_OBS_LABELS_MINIMAL[12] = { + "fwd_spd", "aoa", "roll_r", "pitch_r", "yaw_r", "altitude", + "tgt_az", "tgt_el", "range", "closure", "E_adv", "timer" +}; + +// Scheme 5: OBS_CARTESIAN (16 obs) +static const char* DEBUG_OBS_LABELS_CARTESIAN[16] = { + "fwd_spd", "sideslip", "climb", "roll_r", "pitch_r", "yaw_r", + "aoa", "altitude", "energy", + "tgt_x", "tgt_y", "tgt_z", "closure", + "E_adv", "aspect", "timer" +}; + +// Scheme 6: OBS_DRONE_STYLE (23 obs) +static const char* DEBUG_OBS_LABELS_DRONE_STYLE[23] = { + "fwd_spd", "sideslip", "climb", "roll_r", "pitch_r", "yaw_r", + "aoa", "altitude", "energy", + "quat_w", "quat_x", "quat_y", "quat_z", + "up_x", "up_y", "up_z", + "tgt_az", "tgt_el", "range", "closure", + "E_adv", "aspect", "timer" +}; + +// Scheme 7: OBS_QBAR (17 obs) +static const char* DEBUG_OBS_LABELS_QBAR[17] = { + "fwd_spd", "sideslip", "climb", "roll_r", "pitch_r", "yaw_r", + "aoa", "altitude", "energy", "q_bar", + "tgt_az", "tgt_el", "range", "closure", + "E_adv", "aspect", "timer" +}; + +// Scheme 8: OBS_KITCHEN_SINK (26 obs) +static const char* DEBUG_OBS_LABELS_KITCHEN_SINK[26] = { + "fwd_spd", "sideslip", "climb", "roll_r", "pitch_r", "yaw_r", + "aoa", "beta", "g_force", "q_bar", "altitude", "energy", "throttle", + "quat_w", "quat_x", "quat_y", "quat_z", + "up_x", "up_y", "up_z", + "tgt_az", "tgt_el", "range", "closure", "E_adv", "timer" +}; + +void print_observations(Dogfight *env) { + const char** labels = NULL; + int num_obs = env->obs_size; + + // Select labels based on scheme + switch (env->obs_scheme) { + case OBS_MOMENTUM: labels = DEBUG_OBS_LABELS_MOMENTUM; break; + case OBS_MOMENTUM_BETA: labels = DEBUG_OBS_LABELS_MOMENTUM_BETA; break; + case OBS_MOMENTUM_GFORCE: labels = 
DEBUG_OBS_LABELS_MOMENTUM_GFORCE; break; + case OBS_MOMENTUM_FULL: labels = DEBUG_OBS_LABELS_MOMENTUM_FULL; break; + case OBS_MINIMAL: labels = DEBUG_OBS_LABELS_MINIMAL; break; + case OBS_CARTESIAN: labels = DEBUG_OBS_LABELS_CARTESIAN; break; + case OBS_DRONE_STYLE: labels = DEBUG_OBS_LABELS_DRONE_STYLE; break; + case OBS_QBAR: labels = DEBUG_OBS_LABELS_QBAR; break; + case OBS_KITCHEN_SINK: labels = DEBUG_OBS_LABELS_KITCHEN_SINK; break; + default: labels = DEBUG_OBS_LABELS_MOMENTUM; break; + } + + printf("=== OBS (scheme %d, %d obs) ===\n", env->obs_scheme, num_obs); + + for (int i = 0; i < num_obs; i++) { + float val = env->observations[i]; + + // Determine range based on scheme and index + bool is_01 = false; + switch (env->obs_scheme) { + case OBS_MOMENTUM: + // fwd_spd(0), altitude(7), energy(8), range(11), timer(15) are [0,1] + is_01 = (i == 0 || i == 7 || i == 8 || i == 11 || i == 15); + break; + case OBS_MOMENTUM_BETA: + case OBS_MOMENTUM_GFORCE: + case OBS_QBAR: + // fwd_spd(0), altitude(7), energy(8), range(12), timer(16) are [0,1] + is_01 = (i == 0 || i == 7 || i == 8 || i == 12 || i == 16); + break; + case OBS_MOMENTUM_FULL: + // fwd_spd(0), altitude(7), energy(8), throttle(11), range(14), timer(19) are [0,1] + is_01 = (i == 0 || i == 7 || i == 8 || i == 11 || i == 14 || i == 19); + break; + case OBS_MINIMAL: + // fwd_spd(0), altitude(5), range(8), timer(11) are [0,1] + is_01 = (i == 0 || i == 5 || i == 8 || i == 11); + break; + case OBS_CARTESIAN: + // fwd_spd(0), altitude(7), energy(8), timer(15) are [0,1] + is_01 = (i == 0 || i == 7 || i == 8 || i == 15); + break; + case OBS_DRONE_STYLE: + // fwd_spd(0), altitude(7), energy(8), range(18), timer(22) are [0,1] + is_01 = (i == 0 || i == 7 || i == 8 || i == 18 || i == 22); + break; + case OBS_KITCHEN_SINK: + // fwd_spd(0), q_bar(9), altitude(10), energy(11), throttle(12), range(22), timer(25) are [0,1] + is_01 = (i == 0 || i == 9 || i == 10 || i == 11 || i == 12 || i == 22 || i == 25); + break; + default: + break; + } + + const char* range_str = is_01 ? 
"[0,1]" : "[-1,1]"; + printf("[%2d] %-12s = %+.3f %s\n", i, labels[i], val, range_str); + } +} +#endif // DEBUG >= 5 + +#endif // DOGFIGHT_OBSERVATIONS_H diff --git a/pufferlib/ocean/dogfight/dogfight_render.h b/pufferlib/ocean/dogfight/dogfight_render.h new file mode 100644 index 000000000..0b80b24ba --- /dev/null +++ b/pufferlib/ocean/dogfight/dogfight_render.h @@ -0,0 +1,509 @@ +#ifndef DOGFIGHT_RENDER_H +#define DOGFIGHT_RENDER_H + + +#include "raymath.h" + +// Convert our Quat (w,x,y,z) to Raylib Quaternion (x,y,z,w) +static inline Quaternion quat_to_raylib(Quat q) { + return (Quaternion){q.x, q.y, q.z, q.w}; +} + +// Scheme 0: OBS_MOMENTUM (15 obs) - baseline +static const char* OBS_LABELS_MOMENTUM[15] = { + "fwd_spd", "sideslip", "climb", "roll_r", "pitch_r", "yaw_r", + "aoa", "altitude", "energy", + "tgt_az", "tgt_el", "range", "closure", + "E_adv", "aspect" +}; + +// Scheme 1: OBS_MOMENTUM_BETA (17 obs) +static const char* OBS_LABELS_MOMENTUM_BETA[17] = { + "fwd_spd", "sideslip", "climb", "roll_r", "pitch_r", "yaw_r", + "aoa", "altitude", "energy", "beta", + "tgt_az", "tgt_el", "range", "closure", + "E_adv", "aspect", "timer" +}; + +// Scheme 2: OBS_MOMENTUM_GFORCE (17 obs) +static const char* OBS_LABELS_MOMENTUM_GFORCE[17] = { + "fwd_spd", "sideslip", "climb", "roll_r", "pitch_r", "yaw_r", + "aoa", "altitude", "energy", "g_force", + "tgt_az", "tgt_el", "range", "closure", + "E_adv", "aspect", "timer" +}; + +// Scheme 3: OBS_MOMENTUM_FULL (19 obs) +static const char* OBS_LABELS_MOMENTUM_FULL[19] = { + "fwd_spd", "sideslip", "climb", "roll_r", "pitch_r", "yaw_r", + "aoa", "altitude", "energy", "beta", "g_force", "throttle", + "tgt_az", "tgt_el", "range", "closure", + "tgt_pitch_r", "tgt_roll_r", "E_adv" +}; + +// Scheme 4: OBS_MINIMAL (11 obs) +static const char* OBS_LABELS_MINIMAL[11] = { + "fwd_spd", "aoa", "roll_r", "pitch_r", "yaw_r", "altitude", + "tgt_az", "tgt_el", "range", "closure", "E_adv" +}; + +// Scheme 5: OBS_CARTESIAN (15 obs) +static const char* OBS_LABELS_CARTESIAN[15] = { + "fwd_spd", "sideslip", "climb", "roll_r", "pitch_r", "yaw_r", + "aoa", "altitude", "energy", + "tgt_x", "tgt_y", "tgt_z", "closure", + "E_adv", "aspect" +}; + +// Scheme 6: OBS_DRONE_STYLE (22 obs) +static const char* OBS_LABELS_DRONE_STYLE[22] = { + "fwd_spd", "sideslip", "climb", "roll_r", "pitch_r", "yaw_r", + "aoa", "altitude", "energy", + "quat_w", "quat_x", "quat_y", "quat_z", + "up_x", "up_y", "up_z", + "tgt_az", "tgt_el", "range", "closure", + "E_adv", "aspect" +}; + +// Scheme 7: OBS_QBAR (16 obs) +static const char* OBS_LABELS_QBAR[16] = { + "fwd_spd", "sideslip", "climb", "roll_r", "pitch_r", "yaw_r", + "aoa", "altitude", "energy", "q_bar", + "tgt_az", "tgt_el", "range", "closure", + "E_adv", "aspect" +}; + +// Scheme 8: OBS_KITCHEN_SINK (25 obs) +static const char* OBS_LABELS_KITCHEN_SINK[25] = { + "fwd_spd", "sideslip", "climb", "roll_r", "pitch_r", "yaw_r", + "aoa", "beta", "g_force", "q_bar", "altitude", "energy", "throttle", + "quat_w", "quat_x", "quat_y", "quat_z", + "up_x", "up_y", "up_z", + "tgt_az", "tgt_el", "range", "closure", "E_adv" +}; + +void draw_plane_model(Client *client, Vec3 pos, Quat ori, Color tint, float scale_factor, float prop_angle) { + Vector3 position = {pos.x, pos.y, pos.z}; + + // Convert our quaternion (w,x,y,z) to Raylib (x,y,z,w) + Quaternion model_rot = quat_to_raylib(ori); + + // GLB model is Y-up, we use Z-up + // Rotate 90 deg around X to convert Y-up to Z-up + // Then rotate to align nose with +X (model nose might point +Z or -Z) + Vector3 x_axis 
= {1, 0, 0}; + Vector3 z_axis = {0, 0, 1}; + Quaternion coord_fix, nose_fix, full_fix, final_rot; + + coord_fix = QuaternionFromAxisAngle(x_axis, PI / 2); // Y-up to Z-up + nose_fix = QuaternionFromAxisAngle(z_axis, PI / 2); // Rotate nose to +X + full_fix = QuaternionMultiply(nose_fix, coord_fix); + + final_rot = QuaternionMultiply(model_rot, full_fix); + + Matrix scale_mat = MatrixScale(scale_factor, scale_factor, scale_factor); + Matrix rot_mat = QuaternionToMatrix(final_rot); + Matrix trans_mat = MatrixTranslate(position.x, position.y, position.z); + Matrix base_transform = MatrixMultiply(MatrixMultiply(scale_mat, rot_mat), trans_mat); + + // Propeller rotation around model's forward axis (Y in model space before coord fix) + // After coord_fix: model Y becomes world Z, but we want rotation around nose axis + // Propeller spins around model's local Z axis (forward in GLB space) + Quaternion prop_rot = QuaternionFromAxisAngle((Vector3){0, 0, 1}, prop_angle); + Quaternion prop_final = QuaternionMultiply(final_rot, prop_rot); + Matrix prop_rot_mat = QuaternionToMatrix(prop_final); + Matrix prop_transform = MatrixMultiply(MatrixMultiply(scale_mat, prop_rot_mat), trans_mat); + + rlDisableColorBlend(); + Model model = client->plane_model; + for (int i = 0; i < model.meshCount; i++) { + if (i >= 3) continue; + if (i == 2 && client->camera_mode == 3) continue; // Skip blur prop in cockpit view + Matrix transform = (i == 2) ? prop_transform : base_transform; + int mat_idx = model.meshMaterial[i]; + Color original = model.materials[mat_idx].maps[MATERIAL_MAP_DIFFUSE].color; + model.materials[mat_idx].maps[MATERIAL_MAP_DIFFUSE].color = (Color){ + (unsigned char)(original.r * tint.r / 255), + (unsigned char)(original.g * tint.g / 255), + (unsigned char)(original.b * tint.b / 255), + original.a + }; + DrawMesh(model.meshes[i], model.materials[mat_idx], transform); + model.materials[mat_idx].maps[MATERIAL_MAP_DIFFUSE].color = original; + } + rlEnableColorBlend(); +} + +void handle_camera_controls(Client *c) { + Vector2 mouse = GetMousePosition(); + + if (IsMouseButtonPressed(MOUSE_BUTTON_LEFT)) { + c->is_dragging = true; + c->last_mouse_x = mouse.x; + c->last_mouse_y = mouse.y; + } + if (IsMouseButtonReleased(MOUSE_BUTTON_LEFT)) { + c->is_dragging = false; + } + + if (c->is_dragging) { + float sensitivity = 0.005f; + c->cam_azimuth -= (mouse.x - c->last_mouse_x) * sensitivity; + c->cam_elevation += (mouse.y - c->last_mouse_y) * sensitivity; + c->cam_elevation = clampf(c->cam_elevation, -1.4f, 1.4f); // prevent gimbal lock + c->last_mouse_x = mouse.x; + c->last_mouse_y = mouse.y; + } + + float wheel = GetMouseWheelMove(); + if (wheel != 0) { + c->cam_distance = clampf(c->cam_distance - wheel * 10.0f, 30.0f, 300.0f); + } +} + +// Draw a single observation bar +// x, y: top-left position +// label: observation name +// value: the observation value +// is_01_range: true for [0,1] range, false for [-1,1] range +void draw_obs_bar(int x, int y, const char* label, float value, bool is_01_range) { + DrawText(label, x, y, 14, WHITE); + + int bar_x = x + 80; + int bar_w = 150; + int bar_h = 14; + + DrawRectangle(bar_x, y, bar_w, bar_h, DARKGRAY); + + float norm_val; + int fill_x, fill_w; + + if (is_01_range) { + norm_val = clampf(value, 0.0f, 1.0f); + fill_x = bar_x; + fill_w = (int)(norm_val * bar_w); + } else { + norm_val = clampf(value, -1.0f, 1.0f); + int center = bar_x + bar_w / 2; + if (norm_val >= 0) { + fill_x = center; + fill_w = (int)(norm_val * bar_w / 2); + } else { + fill_w = (int)(-norm_val * 
bar_w / 2); + fill_x = center - fill_w; + } + } + + Color fill_color = GREEN; + if (fabsf(value) > 0.9f) fill_color = YELLOW; + if (fabsf(value) > 1.0f) fill_color = RED; + + DrawRectangle(fill_x, y, fill_w, bar_h, fill_color); + + if (!is_01_range) { + int center = bar_x + bar_w / 2; + DrawLine(center, y, center, y + bar_h, WHITE); + } + + DrawText(TextFormat("%+.2f", value), bar_x + bar_w + 5, y, 14, WHITE); +} + +void draw_obs_monitor(Dogfight *env) { + int start_x = 1540; + int start_y = 10; + int row_height = 18; + + const char** labels = NULL; + int num_obs = env->obs_size; + + switch (env->obs_scheme) { + case OBS_MOMENTUM: + labels = OBS_LABELS_MOMENTUM; + break; + case OBS_MOMENTUM_BETA: + labels = OBS_LABELS_MOMENTUM_BETA; + break; + case OBS_MOMENTUM_GFORCE: + labels = OBS_LABELS_MOMENTUM_GFORCE; + break; + case OBS_MOMENTUM_FULL: + labels = OBS_LABELS_MOMENTUM_FULL; + break; + case OBS_MINIMAL: + labels = OBS_LABELS_MINIMAL; + break; + case OBS_CARTESIAN: + labels = OBS_LABELS_CARTESIAN; + break; + case OBS_DRONE_STYLE: + labels = OBS_LABELS_DRONE_STYLE; + break; + case OBS_QBAR: + labels = OBS_LABELS_QBAR; + break; + case OBS_KITCHEN_SINK: + labels = OBS_LABELS_KITCHEN_SINK; + break; + default: + labels = OBS_LABELS_MOMENTUM; + break; + } + + DrawText(TextFormat("OBS (scheme %d)", env->obs_scheme), + start_x, start_y, 16, YELLOW); + start_y += 22; + + for (int i = 0; i < num_obs; i++) { + float val = env->observations[i]; + bool is_01 = false; + switch (env->obs_scheme) { + case OBS_MOMENTUM: + // fwd_spd(0), altitude(7), energy(8), range(11) are [0,1] + is_01 = (i == 0 || i == 7 || i == 8 || i == 11); + break; + case OBS_MOMENTUM_BETA: + case OBS_MOMENTUM_GFORCE: + case OBS_QBAR: + // fwd_spd(0), altitude(7), energy(8), range(12) are [0,1] + is_01 = (i == 0 || i == 7 || i == 8 || i == 12); + break; + case OBS_MOMENTUM_FULL: + // fwd_spd(0), altitude(7), energy(8), throttle(11), range(14) are [0,1] + is_01 = (i == 0 || i == 7 || i == 8 || i == 11 || i == 14); + break; + case OBS_MINIMAL: + // fwd_spd(0), altitude(5), range(8) are [0,1] + is_01 = (i == 0 || i == 5 || i == 8); + break; + case OBS_CARTESIAN: + // fwd_spd(0), altitude(7), energy(8) are [0,1] + is_01 = (i == 0 || i == 7 || i == 8); + break; + case OBS_DRONE_STYLE: + // fwd_spd(0), altitude(7), energy(8), range(18) are [0,1] + is_01 = (i == 0 || i == 7 || i == 8 || i == 18); + break; + case OBS_KITCHEN_SINK: + // fwd_spd(0), q_bar(9), altitude(10), energy(11), throttle(12), range(22) are [0,1] + is_01 = (i == 0 || i == 9 || i == 10 || i == 11 || i == 12 || i == 22); + break; + default: + break; + } + int y = start_y + i * row_height; + draw_obs_bar(start_x, y, labels[i], val, is_01); + + if (env->obs_highlight[i]) { + int arrow_x = start_x - 20; + int arrow_y = y + 7; + DrawTriangle( + (Vector2){arrow_x, arrow_y - 5}, + (Vector2){arrow_x, arrow_y + 5}, + (Vector2){arrow_x + 12, arrow_y}, + RED + ); + } + } +} + +void c_render(Dogfight *env) { + if (env->client == NULL) { + env->client = (Client *)calloc(1, sizeof(Client)); + env->client->width = 1280; + env->client->height = 720; + env->client->cam_distance = 80.0f; + env->client->cam_azimuth = 0.0f; + env->client->cam_elevation = 0.3f; + env->client->camera_mode = 0; // 0 = follow target, 1 = midpoint view, 2 = chase, 3 = cockpit + env->client->is_dragging = false; + + InitWindow(1920, 1080, "Dogfight"); + SetTargetFPS(60); + + env->client->camera.up = (Vector3){0.0f, 0.0f, 1.0f}; + env->client->camera.fovy = 45.0f; + env->client->camera.projection = 
CAMERA_PERSPECTIVE; + + env->client->plane_model = LoadModel("pufferlib/ocean/dogfight/p40.glb"); + env->client->model_loaded = (env->client->plane_model.meshCount > 0); + } + + if (WindowShouldClose() || IsKeyDown(KEY_ESCAPE)) { + c_close(env); + exit(0); + } + + if (IsKeyPressed(KEY_C)) { + env->client->camera_mode = (env->client->camera_mode + 1) % 4; + } + + // R key: force reset to new episode + if (IsKeyPressed(KEY_R)) { + c_reset(env); + } + + handle_camera_controls(env->client); + + Plane *p = &env->player; + Plane *o = &env->opponent; + Plane *cam_target = env->camera_follow_opponent ? o : p; + Vec3 fwd = quat_rotate(cam_target->ori, vec3(1, 0, 0)); + float dist = env->client->cam_distance; + + float az = env->client->cam_azimuth; + float el = env->client->cam_elevation; + + if (env->client->camera_mode == 2) { + // Mode 2: Direct chase - fully body-aligned (roll + pitch + yaw) + Vec3 up = quat_rotate(cam_target->ori, vec3(0, 0, 1)); + // Position: behind and slightly above in body frame (0.25 up, was 0.5) + float chase_dist = dist * 0.7f; + float cam_x = cam_target->pos.x - fwd.x * chase_dist + up.x * chase_dist * 0.25f; + float cam_y = cam_target->pos.y - fwd.y * chase_dist + up.y * chase_dist * 0.25f; + float cam_z = cam_target->pos.z - fwd.z * chase_dist + up.z * chase_dist * 0.25f; + env->client->camera.position = (Vector3){cam_x, cam_y, cam_z}; + // Target ahead and above plane so plane appears lower on screen + float tgt_x = cam_target->pos.x + fwd.x * 20.0f + up.x * 20.0f; + float tgt_y = cam_target->pos.y + fwd.y * 20.0f + up.y * 20.0f; + float tgt_z = cam_target->pos.z + fwd.z * 20.0f + up.z * 20.0f; + env->client->camera.target = (Vector3){tgt_x, tgt_y, tgt_z}; + env->client->camera.up = (Vector3){up.x, up.y, up.z}; + } else if (env->client->camera_mode == 3) { + // Mode 3: Cockpit POV - at plane center, 1.25m up in body frame + Vec3 up = quat_rotate(cam_target->ori, vec3(0, 0, 1)); + float cam_x = cam_target->pos.x + up.x * 1.11f; + float cam_y = cam_target->pos.y + up.y * 1.11f; + float cam_z = cam_target->pos.z + up.z * 1.11f; + env->client->camera.position = (Vector3){cam_x, cam_y, cam_z}; + // Look forward and ~15 degrees up (tan(15°) ≈ 0.268, so 27 up per 100 forward) + float tgt_x = cam_target->pos.x + fwd.x * 100.0f + up.x * 27.0f; + float tgt_y = cam_target->pos.y + fwd.y * 100.0f + up.y * 27.0f; + float tgt_z = cam_target->pos.z + fwd.z * 100.0f + up.z * 27.0f; + env->client->camera.target = (Vector3){tgt_x, tgt_y, tgt_z}; + env->client->camera.up = (Vector3){up.x, up.y, up.z}; + } else { + // Reset to world up for other modes + env->client->camera.up = (Vector3){0.0f, 0.0f, 1.0f}; + // Modes 0 and 1: Orbit camera position + float cam_x = cam_target->pos.x - fwd.x * dist * cosf(el) * cosf(az) + fwd.y * dist * sinf(az); + float cam_y = cam_target->pos.y - fwd.y * dist * cosf(el) * cosf(az) - fwd.x * dist * sinf(az); + float cam_z = cam_target->pos.z + dist * sinf(el) + 20.0f; + env->client->camera.position = (Vector3){cam_x, cam_y, cam_z}; + + if (env->client->camera_mode == 0) { + // Mode 0: Orbit - look at cam_target + env->client->camera.target = (Vector3){cam_target->pos.x, cam_target->pos.y, cam_target->pos.z}; + } else { + // Mode 1: Midpoint - look at midpoint between both planes + float mid_x = (p->pos.x + o->pos.x) / 2.0f; + float mid_y = (p->pos.y + o->pos.y) / 2.0f; + float mid_z = (p->pos.z + o->pos.z) / 2.0f; + env->client->camera.target = (Vector3){mid_x, mid_y, mid_z}; + } + } + + WaitTime(0.02); // Sync with sim DT (50 FPS) for proper GIF frame 
timing + BeginDrawing(); + ClearBackground((Color){0, 100, 120, 255}); + + rlSetClipPlanes(1.0, 15000.0); + BeginMode3D(env->client->camera); + + // DrawPlane uses raylib's Y-up convention (XZ plane), so we draw triangles instead + Vector3 g1 = {-4000, -4000, 0}; + Vector3 g2 = {4000, -4000, 0}; + Vector3 g3 = {4000, 4000, 0}; + Vector3 g4 = {-4000, 4000, 0}; + Color ground_color = (Color){20, 60, 20, 255}; + DrawTriangle3D(g1, g2, g3, ground_color); + DrawTriangle3D(g1, g3, g4, ground_color); + + DrawCubeWires((Vector3){0, 0, 2500}, 8000, 8000, 5000, (Color){100, 100, 100, 255}); + + float player_scale = env->camera_follow_opponent ? 4.0f : 1.0f; + float opponent_scale = env->camera_follow_opponent ? 1.0f : 4.0f; + + float prop_speed = 150.0f + 130.0f * p->throttle; + env->client->propeller_angle += prop_speed * 0.0167f; + if (env->client->propeller_angle > 2.0f * PI) { + env->client->propeller_angle -= 2.0f * PI; + } + + draw_plane_model(env->client, p->pos, p->ori, WHITE, player_scale, env->client->propeller_angle); + + draw_plane_model(env->client, o->pos, o->ori, RED, opponent_scale, env->client->propeller_angle); + + // Player tracer (yellow) + if (p->fire_cooldown >= FIRE_COOLDOWN - 2) { + Vec3 nose = add3(p->pos, quat_rotate(p->ori, vec3(15, 0, 0))); + Vec3 tracer_end = add3(p->pos, quat_rotate(p->ori, vec3(GUN_RANGE, 0, 0))); + Vector3 nose_r = {nose.x, nose.y, nose.z}; + Vector3 end_r = {tracer_end.x, tracer_end.y, tracer_end.z}; + DrawLine3D(nose_r, end_r, YELLOW); + } + + // Opponent tracer (orange) - for self-play or AutoAce + if (o->fire_cooldown >= FIRE_COOLDOWN - 2) { + Vec3 o_nose = add3(o->pos, quat_rotate(o->ori, vec3(15, 0, 0))); + Vec3 o_tracer_end = add3(o->pos, quat_rotate(o->ori, vec3(GUN_RANGE, 0, 0))); + Vector3 o_nose_r = {o_nose.x, o_nose.y, o_nose.z}; + Vector3 o_end_r = {o_tracer_end.x, o_tracer_end.y, o_tracer_end.z}; + DrawLine3D(o_nose_r, o_end_r, ORANGE); + } + + EndMode3D(); + + float speed = norm3(p->vel); + float dist_to_opp = norm3(sub3(o->pos, p->pos)); + + DrawText(TextFormat("Speed: %.0f m/s", speed), 10, 10, 20, WHITE); + DrawText(TextFormat("Altitude: %.0f m", p->pos.z), 10, 40, 20, WHITE); + DrawText(TextFormat("Throttle: %.0f%%", p->throttle * 100.0f), 10, 70, 20, WHITE); + DrawText(TextFormat("Distance: %.0f m", dist_to_opp), 10, 100, 20, WHITE); + DrawText(TextFormat("Tick: %d / %d", env->tick, env->max_steps), 10, 130, 20, WHITE); + DrawText(TextFormat("Return: %.2f", env->episode_return), 10, 160, 20, WHITE); + DrawText(TextFormat("Perf: %.1f%% | Shots: %.0f", env->log.perf / fmaxf(env->log.n, 1.0f) * 100.0f, env->log.shots_fired), 10, 190, 20, YELLOW); + DrawText(TextFormat("Stage: %d", env->stage), 10, 220, 20, LIME); + + // Show last round result prominently for first 100 ticks + if (env->tick < 100 && env->last_death_reason != DEATH_NONE) { + const char* result_text = NULL; + Color result_color = WHITE; + if (env->last_winner == 1) { + result_text = "PLAYER WINS"; + result_color = GREEN; + } else if (env->last_winner == -1) { + result_text = "OPPONENT WINS"; + result_color = RED; + } else if (env->last_death_reason == DEATH_OOB) { + result_text = "OUT OF BOUNDS"; + result_color = ORANGE; + } else if (env->last_death_reason == DEATH_TIMEOUT) { + result_text = "TIMEOUT"; + result_color = YELLOW; + } + if (result_text) { + int text_width = MeasureText(result_text, 50); + DrawText(result_text, (1920 - text_width) / 2, 450, 50, result_color); + } + } + + draw_obs_monitor(env); + + DrawText("Mouse drag: Orbit | Scroll: Zoom | R: Reset | ESC: 
Exit", 10, (int)env->client->height - 30, 16, GRAY); + + EndDrawing(); +} + +void c_close(Dogfight *env) { + if (env->client != NULL) { + if (env->client->model_loaded) { + UnloadModel(env->client->plane_model); + } + CloseWindow(); + free(env->client); + env->client = NULL; + } +} + +#endif // DOGFIGHT_RENDER_H diff --git a/pufferlib/ocean/dogfight/flightlib.h b/pufferlib/ocean/dogfight/flightlib.h new file mode 100644 index 000000000..f24c8f926 --- /dev/null +++ b/pufferlib/ocean/dogfight/flightlib.h @@ -0,0 +1,1139 @@ +// flightlib.h - Realistic RK4 flight physics for dogfight environment +// +// Full 6-DOF flight model with: +// - Angular momentum as state variable (omega integrated, not commanded) +// - RK4 integration (4th-order Runge-Kutta) +// - Aerodynamic moments from stability derivatives +// - Control surface effectiveness (elevator, aileron, rudder) +// - Euler's equations for rotational dynamics + +#ifndef FLIGHTLIB_H +#define FLIGHTLIB_H + +#include +#include +#include +#include + +#ifndef DEBUG +#define DEBUG 0 +#endif + +#ifndef PI +#define PI 3.14159265358979f +#endif + +// Debug control (0=off, 1+=increasingly verbose) +#ifndef DEBUG_REALISTIC +#define DEBUG_REALISTIC 0 +#endif + +static int _realistic_step_count = 0; +static int _realistic_rk4_stage = 0; + +typedef struct { float x, y, z; } Vec3; +typedef struct { float w, x, y, z; } Quat; + +static inline float clampf(float v, float lo, float hi) { + return v < lo ? lo : (v > hi ? hi : v); +} + +static inline float rndf(float a, float b) { + return a + ((float)rand() / (float)RAND_MAX) * (b - a); +} + +static inline Vec3 vec3(float x, float y, float z) { return (Vec3){x, y, z}; } +static inline Vec3 add3(Vec3 a, Vec3 b) { return (Vec3){a.x + b.x, a.y + b.y, a.z + b.z}; } +static inline Vec3 sub3(Vec3 a, Vec3 b) { return (Vec3){a.x - b.x, a.y - b.y, a.z - b.z}; } +static inline Vec3 mul3(Vec3 a, float s) { return (Vec3){a.x * s, a.y * s, a.z * s}; } +static inline float dot3(Vec3 a, Vec3 b) { return a.x * b.x + a.y * b.y + a.z * b.z; } +static inline float norm3(Vec3 a) { return sqrtf(dot3(a, a)); } + +static inline Vec3 normalize3(Vec3 v) { + float n = norm3(v); + if (n < 1e-8f) return vec3(0, 0, 0); + return mul3(v, 1.0f / n); +} + +static inline Vec3 cross3(Vec3 a, Vec3 b) { + return vec3( + a.y * b.z - a.z * b.y, + a.z * b.x - a.x * b.z, + a.x * b.y - a.y * b.x + ); +} + +static inline Quat quat(float w, float x, float y, float z) { return (Quat){w, x, y, z}; } + +static inline Quat quat_mul(Quat a, Quat b) { + return (Quat){ + a.w*b.w - a.x*b.x - a.y*b.y - a.z*b.z, + a.w*b.x + a.x*b.w + a.y*b.z - a.z*b.y, + a.w*b.y - a.x*b.z + a.y*b.w + a.z*b.x, + a.w*b.z + a.x*b.y - a.y*b.x + a.z*b.w + }; +} + +static inline Quat quat_add(Quat a, Quat b) { + return (Quat){a.w + b.w, a.x + b.x, a.y + b.y, a.z + b.z}; +} + +static inline Quat quat_scale(Quat q, float s) { + return (Quat){q.w * s, q.x * s, q.y * s, q.z * s}; +} + +static inline void quat_normalize(Quat* q) { + float n = sqrtf(q->w*q->w + q->x*q->x + q->y*q->y + q->z*q->z); + if (n > 1e-8f) { + float inv = 1.0f / n; + q->w *= inv; q->x *= inv; q->y *= inv; q->z *= inv; + } +} + +static inline Vec3 quat_rotate(Quat q, Vec3 v) { + Quat qv = {0.0f, v.x, v.y, v.z}; + Quat q_conj = {q.w, -q.x, -q.y, -q.z}; + Quat tmp = quat_mul(q, qv); + Quat res = quat_mul(tmp, q_conj); + return (Vec3){res.x, res.y, res.z}; +} + +static inline Quat quat_from_axis_angle(Vec3 axis, float angle) { + float half = angle * 0.5f; + float s = sinf(half); + return (Quat){cosf(half), axis.x * s, 
axis.y * s, axis.z * s}; +} + +// Aircraft parameters - P-51D Mustang (see P51d_REFERENCE_DATA.md) + +#define MASS 4082.0f // kg +#define WING_AREA 21.65f // m^2 +#define WINGSPAN 11.28f // m +#define CHORD 2.02f // m + +#define IXX 6500.0f // Roll inertia +#define IYY 22000.0f // Pitch inertia +#define IZZ 27000.0f // Yaw inertia + +#define C_D0 0.0163f +#define K 0.072f +#define K_SIDESLIP 0.7f +#define C_L_MAX 1.48f +#define C_L_ALPHA 5.56f +#define ALPHA_ZERO -0.021f +#define WING_INCIDENCE 0.026f + +#define ENGINE_POWER 1112000.0f // watts +#define ETA_PROP 0.80f + +#define GRAVITY 9.81f // m/s^2 +#define RHO 1.225f // kg/m^3 + +#define G_LIMIT_POS 6.0f +#define G_LIMIT_NEG 1.5f + +#define INV_MASS 0.000245f // 1/4082 +#define INV_GRAVITY 0.10197f // 1/9.81 +#define RAD_TO_DEG 57.2957795f // 180/PI + +#define MAX_PITCH_RATE 2.5f // rad/s +#define MAX_ROLL_RATE 3.0f // rad/s +#define MAX_YAW_RATE 0.50f // rad/s + +typedef struct { + Vec3 pos; + Vec3 vel; + Vec3 prev_vel; + Vec3 omega; + Quat ori; + float throttle; + float g_force; + float yaw_from_rudder; + int fire_cooldown; + float prev_energy; // Previous specific energy for energy management reward +} Plane; + +static inline void step_plane(Plane *p, float dt) { + p->prev_vel = p->vel; + + Vec3 forward = quat_rotate(p->ori, vec3(1, 0, 0)); + float speed = norm3(p->vel); + if (speed < 1.0f) speed = 80.0f; + p->vel = mul3(forward, speed); + p->pos = add3(p->pos, mul3(p->vel, dt)); + + if (DEBUG >= 10) printf("=== TARGET ===\n"); + if (DEBUG >= 10) printf("target_speed=%.1f m/s (expected=80)\n", speed); + if (DEBUG >= 10) printf("target_pos=(%.1f, %.1f, %.1f)\n", p->pos.x, p->pos.y, p->pos.z); + if (DEBUG >= 10) printf("target_fwd=(%.2f, %.2f, %.2f)\n", forward.x, forward.y, forward.z); +} + +#define CM_0 -0.005f // Pitch trim offset (fine-tuned for ~1.0G level flight) +#define CM_ALPHA -1.2f // Pitch stability (negative = stable, nose-up creates nose-down moment) +#define CL_BETA -0.08f // Dihedral effect (negative = stable, sideslip creates restoring roll) +#define CN_BETA 0.12f // Weathervane stability (positive = stable, sideslip creates restoring yaw) + +// Damping derivatives (dimensionless, multiplied by q*c/2V or p*b/2V) +#define CM_Q -10.0f // Pitch damping (matches JSBSim P-51D) +#define CL_P -0.4f // Roll damping (opposes roll rate) +#define CN_R -0.15f // Yaw damping (opposes yaw rate) + +// Control derivatives (per radian deflection) +#define CM_DELTA_E -0.5f // Elevator: negative = nose UP with positive (back stick) deflection +#define CL_DELTA_A 0.20f // Aileron: positive = roll RIGHT with positive deflection + // Tuning: 0.04f->19°, 0.15f->70°, need 90°, try 0.20f +#define CN_DELTA_R 0.015f // Rudder: positive = nose RIGHT with positive (right pedal) deflection + // Tuning: 0.015f should give 2-20° heading change with full rudder + +// Cross-coupling derivatives +#define CN_DELTA_A -0.007f // Adverse yaw from aileron (negative = right aileron causes left yaw) +#define CL_DELTA_R -0.003f // Roll from rudder (negative = right rudder causes left roll, rudder is above roll axis) + +// Control surface deflection limits (radians) +#define MAX_ELEVATOR_DEFLECTION 0.35f // ±20° +#define MAX_AILERON_DEFLECTION 0.35f // ±20° +#define MAX_RUDDER_DEFLECTION 0.35f // ±20° + +// High-speed control authority scaling (prevents oscillations at high speed) +// At high speeds, control moments scale with V² while damping scales with V, +// causing under-damped behavior. Scale down control authority to compensate. 
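+// Worked example with the constants defined below (values follow directly from
+// control_scale = 1 - max(0, V - CONTROL_V_REF) * CONTROL_SCALE_SLOPE, floored
+// at CONTROL_SCALE_MIN): at V = 100 m/s the scale is 1.0 (full authority),
+// at 200 m/s it is ~0.92, at 300 m/s ~0.83; the 0.05 floor would only be
+// reached above roughly 1240 m/s, far outside this aircraft's envelope.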
+// Values derived from per-speed optimal scale discovery (--find-optimal mode) +#define CONTROL_V_REF 100.0f // Only scale authority above 100 m/s (cruise speed) +#define CONTROL_SCALE_SLOPE 0.000833f // Authority reduction per m/s above ref +#define CONTROL_SCALE_MIN 0.05f // Minimum authority (never below 5%) + +// Runtime-configurable physics parameters for parameter sweeps +typedef struct { + float control_v_ref; // Reference speed for full authority + float control_scale_slope; // How fast authority drops with speed + float control_scale_min; // Floor for control authority + float damping_scale_slope; // Extra damping scale per m/s above ref (0 = off) +} FlightParams; + +// Default parameters (matches current compile-time #defines) +static inline FlightParams default_flight_params(void) { + return (FlightParams){ + .control_v_ref = CONTROL_V_REF, + .control_scale_slope = CONTROL_SCALE_SLOPE, + .control_scale_min = CONTROL_SCALE_MIN, + .damping_scale_slope = 0.0f + }; +} + +typedef struct { + Vec3 vel; + Vec3 v_dot; + Quat q_dot; + Vec3 w_dot; +} StateDerivative; + +static inline float compute_aoa(Plane* p) { + Vec3 forward = quat_rotate(p->ori, vec3(1, 0, 0)); + Vec3 up = quat_rotate(p->ori, vec3(0, 0, 1)); + + float V = norm3(p->vel); + if (V < 1.0f) return 0.0f; + + Vec3 vel_norm = normalize3(p->vel); + float cos_alpha = dot3(vel_norm, forward); + cos_alpha = clampf(cos_alpha, -1.0f, 1.0f); + float alpha = acosf(cos_alpha); // Always positive [0, pi] + + // Sign: positive when nose is ABOVE velocity vector + // If vel dot up < 0, velocity is "below" the body frame -> nose above -> alpha > 0 + float vel_dot_up = dot3(p->vel, up); + float sign = (vel_dot_up < 0) ? 1.0f : -1.0f; + + if (DEBUG_REALISTIC >= 3 && _realistic_rk4_stage == 0) { + printf(" [AOA] forward=(%.3f,%.3f,%.3f) up=(%.3f,%.3f,%.3f)\n", + forward.x, forward.y, forward.z, up.x, up.y, up.z); + printf(" [AOA] vel=(%.1f,%.1f,%.1f) |vel|=%.1f\n", + p->vel.x, p->vel.y, p->vel.z, V); + printf(" [AOA] vel_norm=(%.4f,%.4f,%.4f)\n", + vel_norm.x, vel_norm.y, vel_norm.z); + printf(" [AOA] cos_alpha=%.4f (vel_norm·forward)\n", cos_alpha); + printf(" [AOA] acos(cos_alpha)=%.4f rad = %.2f deg\n", alpha, alpha * RAD_TO_DEG); + printf(" [AOA] vel·up=%.4f -> sign=%.0f\n", vel_dot_up, sign); + printf(" [AOA] FINAL alpha=%.4f rad = %.2f deg\n", alpha * sign, alpha * sign * RAD_TO_DEG); + } + + return alpha * sign; +} + +static inline float compute_sideslip(Plane* p) { + Vec3 right = quat_rotate(p->ori, vec3(0, 1, 0)); + + float V = norm3(p->vel); + if (V < 1.0f) return 0.0f; + + Vec3 vel_norm = normalize3(p->vel); + + // beta = arcsin(v · right / |v|) - positive when velocity has component to the right + float sin_beta = dot3(vel_norm, right); + float beta = asinf(clampf(sin_beta, -1.0f, 1.0f)); + + if (DEBUG_REALISTIC >= 3 && _realistic_rk4_stage == 0) { + printf(" [BETA] right=(%.3f,%.3f,%.3f)\n", right.x, right.y, right.z); + printf(" [BETA] sin_beta=%.4f (vel_norm·right)\n", sin_beta); + printf(" [BETA] FINAL beta=%.4f rad = %.2f deg\n", beta, beta * RAD_TO_DEG); + } + + return beta; +} + +static inline Vec3 compute_lift_direction(Vec3 vel_norm, Vec3 right, Vec3 body_up) { + Vec3 lift_dir = cross3(vel_norm, right); + float mag = norm3(lift_dir); + + if (DEBUG_REALISTIC >= 3 && _realistic_rk4_stage == 0) { + printf(" [LIFT_DIR] vel_norm×right=(%.3f,%.3f,%.3f) |mag|=%.4f\n", + lift_dir.x, lift_dir.y, lift_dir.z, mag); + } + + if (mag > 0.01f) { + Vec3 result = mul3(lift_dir, 1.0f / mag); + if (DEBUG_REALISTIC >= 3 && 
_realistic_rk4_stage == 0) { + printf(" [LIFT_DIR] normalized=(%.3f,%.3f,%.3f)\n", result.x, result.y, result.z); + } + return result; + } + if (DEBUG_REALISTIC >= 3 && _realistic_rk4_stage == 0) { + printf(" [LIFT_DIR] FALLBACK to world_up=(0,0,1)\n"); + } + return (Vec3){0, 0, 1}; // Fallback to world-frame up (lift perpendicular to ground) +} + +static inline float compute_thrust(float throttle, float V) { + float P_avail = ENGINE_POWER * throttle; + float T_dynamic = (P_avail * ETA_PROP) / V; // Thrust from power equation + float T_static = 0.3f * P_avail; // Static thrust limit + float T = fminf(T_static, T_dynamic); // Can't exceed either limit + + if (DEBUG_REALISTIC >= 3 && _realistic_rk4_stage == 0) { + printf(" [THRUST] throttle=%.2f P_avail=%.0f W\n", throttle, P_avail); + printf(" [THRUST] T_dynamic=%.0f N, T_static=%.0f N -> T=%.0f N\n", + T_dynamic, T_static, T); + } + + return T; +} + +// Helper: apply derivative to state (for RK4 intermediate stages) +static inline void step_temp(Plane* state, StateDerivative* d, float dt, Plane* out) { + out->pos = add3(state->pos, mul3(d->vel, dt)); + out->vel = add3(state->vel, mul3(d->v_dot, dt)); + out->ori = quat_add(state->ori, quat_scale(d->q_dot, dt)); + quat_normalize(&out->ori); + out->omega = add3(state->omega, mul3(d->w_dot, dt)); + out->throttle = state->throttle; + out->g_force = state->g_force; + out->yaw_from_rudder = state->yaw_from_rudder; + out->fire_cooldown = state->fire_cooldown; + out->prev_vel = state->prev_vel; + + if (DEBUG_REALISTIC >= 5) { + printf(" [STEP_TEMP] dt=%.4f\n", dt); + printf(" [STEP_TEMP] d->vel=(%.2f,%.2f,%.2f) d->v_dot=(%.2f,%.2f,%.2f)\n", + d->vel.x, d->vel.y, d->vel.z, d->v_dot.x, d->v_dot.y, d->v_dot.z); + printf(" [STEP_TEMP] d->w_dot=(%.4f,%.4f,%.4f)\n", + d->w_dot.x, d->w_dot.y, d->w_dot.z); + printf(" [STEP_TEMP] out->vel=(%.2f,%.2f,%.2f)\n", out->vel.x, out->vel.y, out->vel.z); + printf(" [STEP_TEMP] out->omega=(%.4f,%.4f,%.4f)\n", + out->omega.x, out->omega.y, out->omega.z); + printf(" [STEP_TEMP] out->ori=(%.4f,%.4f,%.4f,%.4f)\n", + out->ori.w, out->ori.x, out->ori.y, out->ori.z); + } +} + +static inline void compute_derivatives(Plane* state, float* actions, float dt, StateDerivative* deriv) { + + if (DEBUG_REALISTIC >= 5) { + const char* stage_names[] = {"k1", "k2", "k3", "k4"}; + printf("\n === COMPUTE_DERIVATIVES (RK4 stage %s) ===\n", stage_names[_realistic_rk4_stage]); + } + + float V = norm3(state->vel); + if (V < 1.0f) V = 1.0f; // Prevent div-by-zero + + Vec3 vel_norm = normalize3(state->vel); + Vec3 forward = quat_rotate(state->ori, vec3(1, 0, 0)); // Body X-axis + Vec3 right = quat_rotate(state->ori, vec3(0, 1, 0)); // Body Y-axis + Vec3 body_up = quat_rotate(state->ori, vec3(0, 0, 1)); // Body Z-axis + + if (DEBUG_REALISTIC >= 2 && _realistic_rk4_stage == 0) { + printf("\n --- STATE ---\n"); + printf(" pos=(%.1f, %.1f, %.1f)\n", state->pos.x, state->pos.y, state->pos.z); + printf(" vel=(%.2f, %.2f, %.2f) |V|=%.2f m/s\n", + state->vel.x, state->vel.y, state->vel.z, V); + printf(" vel_norm=(%.4f, %.4f, %.4f)\n", vel_norm.x, vel_norm.y, vel_norm.z); + printf(" ori=(w=%.4f, x=%.4f, y=%.4f, z=%.4f) |ori|=%.6f\n", + state->ori.w, state->ori.x, state->ori.y, state->ori.z, + sqrtf(state->ori.w*state->ori.w + state->ori.x*state->ori.x + + state->ori.y*state->ori.y + state->ori.z*state->ori.z)); + printf(" omega=(%.4f, %.4f, %.4f) rad/s = (%.2f, %.2f, %.2f) deg/s\n", + state->omega.x, state->omega.y, state->omega.z, + state->omega.x * RAD_TO_DEG, state->omega.y * RAD_TO_DEG, 
state->omega.z * RAD_TO_DEG); + printf(" forward=(%.4f, %.4f, %.4f)\n", forward.x, forward.y, forward.z); + printf(" right=(%.4f, %.4f, %.4f)\n", right.x, right.y, right.z); + printf(" body_up=(%.4f, %.4f, %.4f)\n", body_up.x, body_up.y, body_up.z); + + // Compute pitch angle from forward vector + float pitch_from_forward = asinf(-forward.z) * RAD_TO_DEG; // nose up = positive + printf(" pitch_from_forward=%.2f deg (nose %s)\n", + pitch_from_forward, pitch_from_forward > 0 ? "UP" : "DOWN"); + + // Velocity direction + float vel_pitch = asinf(vel_norm.z) * RAD_TO_DEG; // climbing = positive + printf(" vel_pitch=%.2f deg (%s)\n", vel_pitch, vel_pitch > 0 ? "CLIMBING" : "DESCENDING"); + } + + if (DEBUG_REALISTIC >= 3 && _realistic_rk4_stage == 0) { + printf("\n --- AERODYNAMIC ANGLES ---\n"); + } + float alpha = compute_aoa(state); + float beta = compute_sideslip(state); + + if (DEBUG_REALISTIC >= 2 && _realistic_rk4_stage == 0) { + printf(" alpha=%.4f rad = %.2f deg (%s)\n", alpha, alpha * RAD_TO_DEG, + alpha > 0 ? "nose ABOVE vel" : "nose BELOW vel"); + printf(" beta=%.4f rad = %.2f deg\n", beta, beta * RAD_TO_DEG); + } + + float q_bar = 0.5f * RHO * V * V; + + if (DEBUG_REALISTIC >= 2 && _realistic_rk4_stage == 0) { + printf("\n --- DYNAMIC PRESSURE ---\n"); + printf(" q_bar = 0.5 * %.4f * %.1f^2 = %.1f Pa\n", RHO, V, q_bar); + } + + // ======================================================================== + // 4. Map actions to control surface deflections + // ======================================================================== + // Actions are [-1, 1], mapped to deflection in radians + // Sign conventions (M_moment is negated later for Z-up frame): + // - Elevator: actions[1] > 0 (push forward) → nose DOWN + // - Aileron: actions[2] > 0 → roll RIGHT + // - Rudder: actions[3] > 0 → yaw LEFT + float throttle = clampf((actions[0] + 1.0f) * 0.5f, 0.0f, 1.0f); // [0, 1] + + // Scale control authority at high speed to prevent over-controlling + // At high speed, control moments scale with V² while damping scales with V, + // causing under-damped oscillations. Reduce authority to compensate. + float control_scale = 1.0f - fmaxf(0.0f, V - CONTROL_V_REF) * CONTROL_SCALE_SLOPE; + control_scale = fmaxf(control_scale, CONTROL_SCALE_MIN); + + float delta_e = clampf(actions[1], -1.0f, 1.0f) * MAX_ELEVATOR_DEFLECTION * control_scale; + float delta_a = clampf(actions[2], -1.0f, 1.0f) * MAX_AILERON_DEFLECTION * control_scale; + float delta_r = clampf(actions[3], -1.0f, 1.0f) * MAX_RUDDER_DEFLECTION * control_scale; + + if (DEBUG_REALISTIC >= 2 && _realistic_rk4_stage == 0) { + printf("\n --- CONTROLS ---\n"); + printf(" actions=[%.3f, %.3f, %.3f, %.3f]\n", + actions[0], actions[1], actions[2], actions[3]); + printf(" throttle=%.3f (%.0f%%)\n", throttle, throttle * 100); + printf(" control_scale=%.3f (V=%.1f, ref=%.1f)\n", control_scale, V, CONTROL_V_REF); + printf(" delta_e=%.4f rad = %.2f deg (elevator, %s)\n", + delta_e, delta_e * RAD_TO_DEG, + delta_e > 0 ? "push=nose DOWN" : delta_e < 0 ? 
"pull=nose UP" : "neutral"); + printf(" delta_a=%.4f rad = %.2f deg (aileron)\n", delta_a, delta_a * RAD_TO_DEG); + printf(" delta_r=%.4f rad = %.2f deg (rudder)\n", delta_r, delta_r * RAD_TO_DEG); + } + + float alpha_effective = alpha + WING_INCIDENCE - ALPHA_ZERO; + float C_L_raw = C_L_ALPHA * alpha_effective; + float C_L = clampf(C_L_raw, -C_L_MAX, C_L_MAX); // Stall limiting + + if (DEBUG_REALISTIC >= 2 && _realistic_rk4_stage == 0) { + printf("\n --- LIFT COEFFICIENT ---\n"); + printf(" alpha=%.4f + WING_INCIDENCE=%.4f - ALPHA_ZERO=%.4f = alpha_eff=%.4f rad\n", + alpha, WING_INCIDENCE, ALPHA_ZERO, alpha_effective); + printf(" C_L_raw = C_L_ALPHA(%.2f) * alpha_eff(%.4f) = %.4f\n", + C_L_ALPHA, alpha_effective, C_L_raw); + printf(" C_L = clamp(%.4f, -%.2f, %.2f) = %.4f%s\n", + C_L_raw, C_L_MAX, C_L_MAX, C_L, + (C_L != C_L_raw) ? " (STALL CLAMPED!)" : ""); + } + + float C_D0_term = C_D0; + float induced_term = K * C_L * C_L; + float sideslip_term = K_SIDESLIP * beta * beta; + float C_D = C_D0_term + induced_term + sideslip_term; + + if (DEBUG_REALISTIC >= 2 && _realistic_rk4_stage == 0) { + printf("\n --- DRAG COEFFICIENT ---\n"); + printf(" C_D0=%.4f + K*C_L^2=%.4f + K_sideslip*beta^2=%.4f = C_D=%.4f\n", + C_D0_term, induced_term, sideslip_term, C_D); + printf(" L/D ratio = %.2f\n", (C_D > 0.0001f) ? C_L / C_D : 0.0f); + } + + float L_mag = C_L * q_bar * WING_AREA; + float D_mag = C_D * q_bar * WING_AREA; + + if (DEBUG_REALISTIC >= 3 && _realistic_rk4_stage == 0) { + printf("\n --- LIFT DIRECTION ---\n"); + } + Vec3 lift_dir = compute_lift_direction(vel_norm, right, body_up); + Vec3 F_lift = mul3(lift_dir, L_mag); + + Vec3 F_drag = mul3(vel_norm, -D_mag); + + if (DEBUG_REALISTIC >= 2 && _realistic_rk4_stage == 0) { + printf("\n --- AERODYNAMIC FORCES ---\n"); + printf(" L_mag = C_L(%.4f) * q_bar(%.1f) * S(%.1f) = %.1f N\n", + C_L, q_bar, WING_AREA, L_mag); + printf(" D_mag = C_D(%.4f) * q_bar(%.1f) * S(%.1f) = %.1f N\n", + C_D, q_bar, WING_AREA, D_mag); + printf(" lift_dir=(%.4f, %.4f, %.4f)\n", lift_dir.x, lift_dir.y, lift_dir.z); + printf(" F_lift=(%.1f, %.1f, %.1f) N\n", F_lift.x, F_lift.y, F_lift.z); + printf(" F_drag=(%.1f, %.1f, %.1f) N (opposite to vel)\n", F_drag.x, F_drag.y, F_drag.z); + } + + if (DEBUG_REALISTIC >= 3 && _realistic_rk4_stage == 0) { + printf("\n --- THRUST ---\n"); + } + float T_mag = compute_thrust(throttle, V); + Vec3 F_thrust = mul3(forward, T_mag); + + if (DEBUG_REALISTIC >= 2 && _realistic_rk4_stage == 0) { + printf(" F_thrust=(%.1f, %.1f, %.1f) N (along forward)\n", + F_thrust.x, F_thrust.y, F_thrust.z); + } + + Vec3 F_gravity = vec3(0, 0, -MASS * GRAVITY); + + if (DEBUG_REALISTIC >= 2 && _realistic_rk4_stage == 0) { + printf("\n --- GRAVITY ---\n"); + printf(" F_gravity=(%.1f, %.1f, %.1f) N\n", F_gravity.x, F_gravity.y, F_gravity.z); + } + + Vec3 F_aero = add3(F_lift, F_drag); + Vec3 F_aero_thrust = add3(F_aero, F_thrust); + Vec3 F_total = add3(F_aero_thrust, F_gravity); + deriv->v_dot = mul3(F_total, INV_MASS); + + if (DEBUG_REALISTIC >= 2 && _realistic_rk4_stage == 0) { + printf("\n --- TOTAL FORCE & ACCELERATION ---\n"); + printf(" F_aero (lift+drag)=(%.1f, %.1f, %.1f) N\n", F_aero.x, F_aero.y, F_aero.z); + printf(" F_aero+thrust=(%.1f, %.1f, %.1f) N\n", F_aero_thrust.x, F_aero_thrust.y, F_aero_thrust.z); + printf(" F_total=(%.1f, %.1f, %.1f) N\n", F_total.x, F_total.y, F_total.z); + printf(" |F_total|=%.1f N\n", norm3(F_total)); + printf(" v_dot = F/m = (%.3f, %.3f, %.3f) m/s^2\n", deriv->v_dot.x, deriv->v_dot.y, deriv->v_dot.z); + printf(" 
|v_dot|=%.3f m/s^2 = %.3f g\n", norm3(deriv->v_dot), norm3(deriv->v_dot) / GRAVITY); + + // Break down vertical component + printf(" v_dot.z=%.3f m/s^2 (%s)\n", deriv->v_dot.z, + deriv->v_dot.z > 0 ? "accelerating UP" : "accelerating DOWN"); + + // What's contributing to vertical acceleration? + printf(" Vertical breakdown: lift_z=%.1f + drag_z=%.1f + thrust_z=%.1f + grav_z=%.1f = %.1f N\n", + F_lift.z, F_drag.z, F_thrust.z, F_gravity.z, F_total.z); + } + + float p = state->omega.x; // roll rate + float q = state->omega.y; // pitch rate + float r = state->omega.z; // yaw rate + + // Non-dimensional rates for damping derivatives + float p_hat = p * WINGSPAN / (2.0f * V); + float q_hat = q * CHORD / (2.0f * V); + float r_hat = r * WINGSPAN / (2.0f * V); + + if (DEBUG_REALISTIC >= 2 && _realistic_rk4_stage == 0) { + printf("\n --- ANGULAR RATES ---\n"); + printf(" p=%.4f, q=%.4f, r=%.4f rad/s (body: roll, pitch, yaw)\n", p, q, r); + printf(" p_hat=%.6f, q_hat=%.6f, r_hat=%.6f (non-dimensional)\n", p_hat, q_hat, r_hat); + } + + // Rolling moment coefficient (Cl) + // Components: dihedral effect + roll damping + aileron control + rudder coupling + float Cl_beta = CL_BETA * beta; + float Cl_p = CL_P * p_hat; + float Cl_da = CL_DELTA_A * delta_a; + float Cl_dr = CL_DELTA_R * delta_r; + float Cl = Cl_beta + Cl_p + Cl_da + Cl_dr; + + // Pitching moment coefficient (Cm) + // Components: static stability + pitch damping + elevator control + float Cm_0 = CM_0; // Trim offset + float Cm_alpha = CM_ALPHA * alpha; + float Cm_q = CM_Q * q_hat; + float Cm_de = CM_DELTA_E * delta_e; + float Cm = Cm_0 + Cm_alpha + Cm_q + Cm_de; + + // Yawing moment coefficient (Cn) + // Components: weathervane stability + yaw damping + rudder control + adverse yaw + float Cn_beta = CN_BETA * beta; + float Cn_r = CN_R * r_hat; + float Cn_dr = CN_DELTA_R * delta_r; + float Cn_da = CN_DELTA_A * delta_a; + float Cn = Cn_beta + Cn_r + Cn_dr + Cn_da; + + if (DEBUG_REALISTIC >= 2 && _realistic_rk4_stage == 0) { + printf("\n --- MOMENT COEFFICIENTS ---\n"); + printf(" Cl = CL_BETA*beta(%.6f) + CL_P*p_hat(%.6f) + CL_DELTA_A*da(%.6f) + CL_DELTA_R*dr(%.6f) = %.6f\n", + Cl_beta, Cl_p, Cl_da, Cl_dr, Cl); + printf(" Cm = CM_0(%.6f) + CM_ALPHA*alpha(%.6f) + CM_Q*q_hat(%.6f) + CM_DELTA_E*de(%.6f) = %.6f\n", + Cm_0, Cm_alpha, Cm_q, Cm_de, Cm); + printf(" CM_0=%.4f (trim), CM_ALPHA=%.2f, alpha=%.4f rad -> Cm_alpha=%.6f\n", CM_0, CM_ALPHA, alpha, Cm_alpha); + printf(" (alpha>0 means nose ABOVE vel, CM_ALPHA<0 means nose-down restoring moment)\n"); + printf(" (Cm_alpha %.6f is %s)\n", Cm_alpha, + Cm_alpha > 0 ? "nose-UP moment" : Cm_alpha < 0 ? 
"nose-DOWN moment" : "zero"); + printf(" Cn = CN_BETA*beta(%.6f) + CN_R*r_hat(%.6f) + CN_DELTA_R*dr(%.6f) + CN_DELTA_A*da(%.6f) = %.6f\n", + Cn_beta, Cn_r, Cn_dr, Cn_da, Cn); + } + + // Convert to dimensional moments (N⋅m) + // Note: Cm sign convention is for aircraft Z-down frame (positive Cm = nose up) + // In our Z-up frame, positive omega.y = nose DOWN, so we negate Cm + float L_moment = Cl * q_bar * WING_AREA * WINGSPAN; // Roll moment + float M_moment = -Cm * q_bar * WING_AREA * CHORD; // Pitch moment (negated for Z-up frame) + float N_moment = Cn * q_bar * WING_AREA * WINGSPAN; // Yaw moment + + if (DEBUG_REALISTIC >= 2 && _realistic_rk4_stage == 0) { + printf("\n --- DIMENSIONAL MOMENTS ---\n"); + printf(" L_moment (roll) = Cl(%.6f) * q_bar(%.1f) * S(%.1f) * b(%.1f) = %.1f N⋅m\n", + Cl, q_bar, WING_AREA, WINGSPAN, L_moment); + printf(" M_moment (pitch) = -Cm(%.6f) * q_bar(%.1f) * S(%.1f) * c(%.2f) = %.1f N⋅m\n", + Cm, q_bar, WING_AREA, CHORD, M_moment); + printf(" Note: M_moment negated because our Z is up (positive omega.y = nose DOWN)\n"); + printf(" Cm=%.6f -> -Cm=%.6f -> M_moment=%.1f (will cause omega.y to %s)\n", + Cm, -Cm, M_moment, M_moment > 0 ? "INCREASE (nose DOWN)" : "DECREASE (nose UP)"); + printf(" N_moment (yaw) = Cn(%.6f) * q_bar(%.1f) * S(%.1f) * b(%.1f) = %.1f N⋅m\n", + Cn, q_bar, WING_AREA, WINGSPAN, N_moment); + } + + // ======================================================================== + // Angular acceleration (Euler's equations) + // ======================================================================== + // τ = I⋅α + ω × (I⋅ω) → α = I⁻¹(τ - ω × (I⋅ω)) + // For diagonal inertia tensor, the gyroscopic coupling terms are: + // (I_yy - I_zz) * q * r for roll + // (I_zz - I_xx) * r * p for pitch + // (I_xx - I_yy) * p * q for yaw + + float gyro_roll = (IYY - IZZ) * q * r; + float gyro_pitch = (IZZ - IXX) * r * p; + float gyro_yaw = (IXX - IYY) * p * q; + + deriv->w_dot.x = (L_moment + gyro_roll) / IXX; + deriv->w_dot.y = (M_moment + gyro_pitch) / IYY; + deriv->w_dot.z = (N_moment + gyro_yaw) / IZZ; + + if (DEBUG_REALISTIC >= 2 && _realistic_rk4_stage == 0) { + printf("\n --- ANGULAR ACCELERATION (Euler's equations) ---\n"); + printf(" Gyroscopic: roll=%.3f, pitch=%.3f, yaw=%.3f N⋅m\n", gyro_roll, gyro_pitch, gyro_yaw); + printf(" I = (Ixx=%.0f, Iyy=%.0f, Izz=%.0f) kg⋅m^2\n", IXX, IYY, IZZ); + printf(" w_dot.x (roll) = (L=%.1f + gyro=%.3f) / Ixx = %.6f rad/s^2 = %.3f deg/s^2\n", + L_moment, gyro_roll, deriv->w_dot.x, deriv->w_dot.x * RAD_TO_DEG); + printf(" w_dot.y (pitch) = (M=%.1f + gyro=%.3f) / Iyy = %.6f rad/s^2 = %.3f deg/s^2\n", + M_moment, gyro_pitch, deriv->w_dot.y, deriv->w_dot.y * RAD_TO_DEG); + printf(" w_dot.z (yaw) = (N=%.1f + gyro=%.3f) / Izz = %.6f rad/s^2 = %.3f deg/s^2\n", + N_moment, gyro_yaw, deriv->w_dot.z, deriv->w_dot.z * RAD_TO_DEG); + printf(" w_dot.y=%.6f means omega.y will %s -> nose will pitch %s\n", + deriv->w_dot.y, + deriv->w_dot.y > 0 ? "INCREASE" : "DECREASE", + deriv->w_dot.y > 0 ? 
"DOWN" : "UP"); + } + + // q_dot = 0.5 * q * [0, ω] where ω is angular velocity in body frame + Quat omega_q = {0.0f, state->omega.x, state->omega.y, state->omega.z}; + Quat q_dot = quat_mul(state->ori, omega_q); + deriv->q_dot.w = 0.5f * q_dot.w; + deriv->q_dot.x = 0.5f * q_dot.x; + deriv->q_dot.y = 0.5f * q_dot.y; + deriv->q_dot.z = 0.5f * q_dot.z; + + if (DEBUG_REALISTIC >= 3 && _realistic_rk4_stage == 0) { + printf("\n --- QUATERNION KINEMATICS ---\n"); + printf(" omega_q=(%.4f, %.4f, %.4f, %.4f)\n", omega_q.w, omega_q.x, omega_q.y, omega_q.z); + printf(" q_dot (before 0.5)=(%.6f, %.6f, %.6f, %.6f)\n", q_dot.w, q_dot.x, q_dot.y, q_dot.z); + printf(" q_dot (final)=(%.6f, %.6f, %.6f, %.6f)\n", + deriv->q_dot.w, deriv->q_dot.x, deriv->q_dot.y, deriv->q_dot.z); + } + + deriv->vel = state->vel; + + if (DEBUG_REALISTIC >= 2 && _realistic_rk4_stage == 0) { + printf("\n --- DERIVATIVE SUMMARY ---\n"); + printf(" vel = (%.2f, %.2f, %.2f) m/s\n", deriv->vel.x, deriv->vel.y, deriv->vel.z); + printf(" v_dot = (%.3f, %.3f, %.3f) m/s^2\n", deriv->v_dot.x, deriv->v_dot.y, deriv->v_dot.z); + printf(" q_dot = (%.6f, %.6f, %.6f, %.6f)\n", + deriv->q_dot.w, deriv->q_dot.x, deriv->q_dot.y, deriv->q_dot.z); + printf(" w_dot = (%.6f, %.6f, %.6f) rad/s^2\n", deriv->w_dot.x, deriv->w_dot.y, deriv->w_dot.z); + } +} + +// Version with runtime-configurable parameters for sweeps +static inline void compute_derivatives_with_params( + Plane* state, float* actions, float dt, + StateDerivative* deriv, FlightParams* params) +{ + float V = norm3(state->vel); + if (V < 1.0f) V = 1.0f; + + Vec3 vel_norm = normalize3(state->vel); + Vec3 forward = quat_rotate(state->ori, vec3(1, 0, 0)); + Vec3 right = quat_rotate(state->ori, vec3(0, 1, 0)); + Vec3 body_up = quat_rotate(state->ori, vec3(0, 0, 1)); + + float alpha = compute_aoa(state); + float beta = compute_sideslip(state); + float q_bar = 0.5f * RHO * V * V; + + // Controls with runtime parameters + float throttle = clampf((actions[0] + 1.0f) * 0.5f, 0.0f, 1.0f); + + float control_scale = 1.0f - fmaxf(0.0f, V - params->control_v_ref) * params->control_scale_slope; + control_scale = fmaxf(control_scale, params->control_scale_min); + + float delta_e = clampf(actions[1], -1.0f, 1.0f) * MAX_ELEVATOR_DEFLECTION * control_scale; + float delta_a = clampf(actions[2], -1.0f, 1.0f) * MAX_AILERON_DEFLECTION * control_scale; + float delta_r = clampf(actions[3], -1.0f, 1.0f) * MAX_RUDDER_DEFLECTION * control_scale; + + // Lift and drag + float alpha_effective = alpha + WING_INCIDENCE - ALPHA_ZERO; + float C_L_raw = C_L_ALPHA * alpha_effective; + float C_L = clampf(C_L_raw, -C_L_MAX, C_L_MAX); + + float C_D = C_D0 + K * C_L * C_L + K_SIDESLIP * beta * beta; + + float L_mag = C_L * q_bar * WING_AREA; + float D_mag = C_D * q_bar * WING_AREA; + + Vec3 lift_dir = compute_lift_direction(vel_norm, right, body_up); + Vec3 F_lift = mul3(lift_dir, L_mag); + Vec3 F_drag = mul3(vel_norm, -D_mag); + + float T_mag = compute_thrust(throttle, V); + Vec3 F_thrust = mul3(forward, T_mag); + Vec3 F_gravity = vec3(0, 0, -MASS * GRAVITY); + + Vec3 F_total = add3(add3(add3(F_lift, F_drag), F_thrust), F_gravity); + deriv->v_dot = mul3(F_total, INV_MASS); + + // Angular rates and damping + float p = state->omega.x; + float q = state->omega.y; + float r = state->omega.z; + + float p_hat = p * WINGSPAN / (2.0f * V); + float q_hat = q * CHORD / (2.0f * V); + float r_hat = r * WINGSPAN / (2.0f * V); + + // Damping scaling - can boost damping at high speed + float damping_scale = 1.0f + fmaxf(0.0f, V - 
params->control_v_ref) * params->damping_scale_slope; + + // Moment coefficients with scaled damping + float Cl = CL_BETA * beta + (CL_P * p_hat * damping_scale) + CL_DELTA_A * delta_a + CL_DELTA_R * delta_r; + float Cm = CM_0 + CM_ALPHA * alpha + (CM_Q * q_hat * damping_scale) + CM_DELTA_E * delta_e; + float Cn = CN_BETA * beta + (CN_R * r_hat * damping_scale) + CN_DELTA_R * delta_r + CN_DELTA_A * delta_a; + + // Dimensional moments + float L_moment = Cl * q_bar * WING_AREA * WINGSPAN; + float M_moment = -Cm * q_bar * WING_AREA * CHORD; + float N_moment = Cn * q_bar * WING_AREA * WINGSPAN; + + // Angular acceleration (Euler's equations) + float gyro_roll = (IYY - IZZ) * q * r; + float gyro_pitch = (IZZ - IXX) * r * p; + float gyro_yaw = (IXX - IYY) * p * q; + + deriv->w_dot.x = (L_moment + gyro_roll) / IXX; + deriv->w_dot.y = (M_moment + gyro_pitch) / IYY; + deriv->w_dot.z = (N_moment + gyro_yaw) / IZZ; + + // Quaternion kinematics + Quat omega_q = {0.0f, state->omega.x, state->omega.y, state->omega.z}; + Quat q_dot = quat_mul(state->ori, omega_q); + deriv->q_dot.w = 0.5f * q_dot.w; + deriv->q_dot.x = 0.5f * q_dot.x; + deriv->q_dot.y = 0.5f * q_dot.y; + deriv->q_dot.z = 0.5f * q_dot.z; + + deriv->vel = state->vel; +} + +// RK4 step with runtime parameters +static inline void rk4_step_with_params(Plane* state, float* actions, float dt, FlightParams* params) { + StateDerivative k1, k2, k3, k4; + Plane temp; + + _realistic_rk4_stage = 0; + compute_derivatives_with_params(state, actions, dt, &k1, params); + + _realistic_rk4_stage = 1; + step_temp(state, &k1, dt * 0.5f, &temp); + compute_derivatives_with_params(&temp, actions, dt, &k2, params); + + _realistic_rk4_stage = 2; + step_temp(state, &k2, dt * 0.5f, &temp); + compute_derivatives_with_params(&temp, actions, dt, &k3, params); + + _realistic_rk4_stage = 3; + step_temp(state, &k3, dt, &temp); + compute_derivatives_with_params(&temp, actions, dt, &k4, params); + + _realistic_rk4_stage = 0; + + float dt_6 = dt / 6.0f; + + state->pos.x += (k1.vel.x + 2.0f * k2.vel.x + 2.0f * k3.vel.x + k4.vel.x) * dt_6; + state->pos.y += (k1.vel.y + 2.0f * k2.vel.y + 2.0f * k3.vel.y + k4.vel.y) * dt_6; + state->pos.z += (k1.vel.z + 2.0f * k2.vel.z + 2.0f * k3.vel.z + k4.vel.z) * dt_6; + + state->vel.x += (k1.v_dot.x + 2.0f * k2.v_dot.x + 2.0f * k3.v_dot.x + k4.v_dot.x) * dt_6; + state->vel.y += (k1.v_dot.y + 2.0f * k2.v_dot.y + 2.0f * k3.v_dot.y + k4.v_dot.y) * dt_6; + state->vel.z += (k1.v_dot.z + 2.0f * k2.v_dot.z + 2.0f * k3.v_dot.z + k4.v_dot.z) * dt_6; + + state->ori.w += (k1.q_dot.w + 2.0f * k2.q_dot.w + 2.0f * k3.q_dot.w + k4.q_dot.w) * dt_6; + state->ori.x += (k1.q_dot.x + 2.0f * k2.q_dot.x + 2.0f * k3.q_dot.x + k4.q_dot.x) * dt_6; + state->ori.y += (k1.q_dot.y + 2.0f * k2.q_dot.y + 2.0f * k3.q_dot.y + k4.q_dot.y) * dt_6; + state->ori.z += (k1.q_dot.z + 2.0f * k2.q_dot.z + 2.0f * k3.q_dot.z + k4.q_dot.z) * dt_6; + + state->omega.x += (k1.w_dot.x + 2.0f * k2.w_dot.x + 2.0f * k3.w_dot.x + k4.w_dot.x) * dt_6; + state->omega.y += (k1.w_dot.y + 2.0f * k2.w_dot.y + 2.0f * k3.w_dot.y + k4.w_dot.y) * dt_6; + state->omega.z += (k1.w_dot.z + 2.0f * k2.w_dot.z + 2.0f * k3.w_dot.z + k4.w_dot.z) * dt_6; + + quat_normalize(&state->ori); +} + +// Step plane with runtime parameters +static inline void step_plane_with_params(Plane *p, float *actions, float dt, FlightParams* params) { + p->prev_vel = p->vel; + + float clamped_actions[4]; + for (int i = 0; i < 4; i++) { + clamped_actions[i] = clampf(actions[i], -1.0f, 1.0f); + } + + rk4_step_with_params(p, 
clamped_actions, dt, params); + + p->throttle = (clamped_actions[0] + 1.0f) * 0.5f; + + p->omega.x = clampf(p->omega.x, -5.0f, 5.0f); + p->omega.y = clampf(p->omega.y, -5.0f, 5.0f); + p->omega.z = clampf(p->omega.z, -2.0f, 2.0f); + + // G-force calculation + Vec3 dv = sub3(p->vel, p->prev_vel); + Vec3 accel = mul3(dv, 1.0f / dt); + Vec3 body_up = quat_rotate(p->ori, vec3(0, 0, 1)); + float accel_up = dot3(accel, body_up); + p->g_force = accel_up * INV_GRAVITY + 1.0f; + + // G-limit enforcement (same as step_plane_with_physics) + float speed_before = norm3(p->vel); + if (p->g_force > G_LIMIT_POS) { + float excess_g = p->g_force - G_LIMIT_POS; + float excess_accel = excess_g * GRAVITY; + Vec3 correction = mul3(body_up, excess_accel * dt); + Vec3 vel_norm = normalize3(p->vel); + float correction_along_vel = dot3(correction, vel_norm); + Vec3 correction_perp = sub3(correction, mul3(vel_norm, correction_along_vel)); + p->vel = sub3(p->vel, correction_perp); + p->g_force = G_LIMIT_POS; + } else if (p->g_force < -G_LIMIT_NEG) { + float deficit_g = -G_LIMIT_NEG - p->g_force; + float deficit_accel = deficit_g * GRAVITY; + Vec3 correction = mul3(body_up, deficit_accel * dt); + Vec3 vel_norm = normalize3(p->vel); + float correction_along_vel = dot3(correction, vel_norm); + Vec3 correction_perp = sub3(correction, mul3(vel_norm, correction_along_vel)); + p->vel = add3(p->vel, correction_perp); + p->g_force = -G_LIMIT_NEG; + } + + p->yaw_from_rudder = compute_sideslip(p); +} + +static inline void rk4_step(Plane* state, float* actions, float dt) { + StateDerivative k1, k2, k3, k4; + Plane temp; + + if (DEBUG_REALISTIC >= 5) { + printf("\n========== RK4 STEP (dt=%.4f) ==========\n", dt); + } + + // k1: derivative at current state + _realistic_rk4_stage = 0; + compute_derivatives(state, actions, dt, &k1); + + if (DEBUG_REALISTIC >= 5) { + printf("\n k1: v_dot=(%.3f,%.3f,%.3f) w_dot=(%.6f,%.6f,%.6f)\n", + k1.v_dot.x, k1.v_dot.y, k1.v_dot.z, k1.w_dot.x, k1.w_dot.y, k1.w_dot.z); + } + + // k2: derivative at state + k1*dt/2 + _realistic_rk4_stage = 1; + step_temp(state, &k1, dt * 0.5f, &temp); + compute_derivatives(&temp, actions, dt, &k2); + + if (DEBUG_REALISTIC >= 5) { + printf(" k2: v_dot=(%.3f,%.3f,%.3f) w_dot=(%.6f,%.6f,%.6f)\n", + k2.v_dot.x, k2.v_dot.y, k2.v_dot.z, k2.w_dot.x, k2.w_dot.y, k2.w_dot.z); + } + + // k3: derivative at state + k2*dt/2 + _realistic_rk4_stage = 2; + step_temp(state, &k2, dt * 0.5f, &temp); + compute_derivatives(&temp, actions, dt, &k3); + + if (DEBUG_REALISTIC >= 5) { + printf(" k3: v_dot=(%.3f,%.3f,%.3f) w_dot=(%.6f,%.6f,%.6f)\n", + k3.v_dot.x, k3.v_dot.y, k3.v_dot.z, k3.w_dot.x, k3.w_dot.y, k3.w_dot.z); + } + + // k4: derivative at state + k3*dt + _realistic_rk4_stage = 3; + step_temp(state, &k3, dt, &temp); + compute_derivatives(&temp, actions, dt, &k4); + + if (DEBUG_REALISTIC >= 5) { + printf(" k4: v_dot=(%.3f,%.3f,%.3f) w_dot=(%.6f,%.6f,%.6f)\n", + k4.v_dot.x, k4.v_dot.y, k4.v_dot.z, k4.w_dot.x, k4.w_dot.y, k4.w_dot.z); + } + + _realistic_rk4_stage = 0; // Reset for next step + + float dt_6 = dt / 6.0f; + + Vec3 old_vel = state->vel; + Vec3 old_omega = state->omega; + Quat old_ori = state->ori; + + state->pos.x += (k1.vel.x + 2.0f * k2.vel.x + 2.0f * k3.vel.x + k4.vel.x) * dt_6; + state->pos.y += (k1.vel.y + 2.0f * k2.vel.y + 2.0f * k3.vel.y + k4.vel.y) * dt_6; + state->pos.z += (k1.vel.z + 2.0f * k2.vel.z + 2.0f * k3.vel.z + k4.vel.z) * dt_6; + + state->vel.x += (k1.v_dot.x + 2.0f * k2.v_dot.x + 2.0f * k3.v_dot.x + k4.v_dot.x) * dt_6; + state->vel.y += (k1.v_dot.y + 
2.0f * k2.v_dot.y + 2.0f * k3.v_dot.y + k4.v_dot.y) * dt_6; + state->vel.z += (k1.v_dot.z + 2.0f * k2.v_dot.z + 2.0f * k3.v_dot.z + k4.v_dot.z) * dt_6; + + state->ori.w += (k1.q_dot.w + 2.0f * k2.q_dot.w + 2.0f * k3.q_dot.w + k4.q_dot.w) * dt_6; + state->ori.x += (k1.q_dot.x + 2.0f * k2.q_dot.x + 2.0f * k3.q_dot.x + k4.q_dot.x) * dt_6; + state->ori.y += (k1.q_dot.y + 2.0f * k2.q_dot.y + 2.0f * k3.q_dot.y + k4.q_dot.y) * dt_6; + state->ori.z += (k1.q_dot.z + 2.0f * k2.q_dot.z + 2.0f * k3.q_dot.z + k4.q_dot.z) * dt_6; + + state->omega.x += (k1.w_dot.x + 2.0f * k2.w_dot.x + 2.0f * k3.w_dot.x + k4.w_dot.x) * dt_6; + state->omega.y += (k1.w_dot.y + 2.0f * k2.w_dot.y + 2.0f * k3.w_dot.y + k4.w_dot.y) * dt_6; + state->omega.z += (k1.w_dot.z + 2.0f * k2.w_dot.z + 2.0f * k3.w_dot.z + k4.w_dot.z) * dt_6; + + quat_normalize(&state->ori); + + if (DEBUG_REALISTIC >= 5) { + printf("\n --- RK4 WEIGHTED AVERAGE ---\n"); + printf(" vel: (%.2f,%.2f,%.2f) -> (%.2f,%.2f,%.2f) delta=(%.3f,%.3f,%.3f)\n", + old_vel.x, old_vel.y, old_vel.z, + state->vel.x, state->vel.y, state->vel.z, + state->vel.x - old_vel.x, state->vel.y - old_vel.y, state->vel.z - old_vel.z); + printf(" omega: (%.4f,%.4f,%.4f) -> (%.4f,%.4f,%.4f) delta=(%.6f,%.6f,%.6f)\n", + old_omega.x, old_omega.y, old_omega.z, + state->omega.x, state->omega.y, state->omega.z, + state->omega.x - old_omega.x, state->omega.y - old_omega.y, state->omega.z - old_omega.z); + printf(" ori: (%.4f,%.4f,%.4f,%.4f) -> (%.4f,%.4f,%.4f,%.4f)\n", + old_ori.w, old_ori.x, old_ori.y, old_ori.z, + state->ori.w, state->ori.x, state->ori.y, state->ori.z); + } +} + + +static inline void step_plane_with_physics(Plane *p, float *actions, float dt) { + _realistic_step_count++; + + if (DEBUG_REALISTIC >= 1) { + printf("\n"); + printf("╔══════════════════════════════════════════════════════════════════════════════╗\n"); + printf("║ REALISTIC PHYSICS STEP %d (dt=%.4f) \n", _realistic_step_count, dt); + printf("╚══════════════════════════════════════════════════════════════════════════════╝\n"); + } + + p->prev_vel = p->vel; + + if (DEBUG_REALISTIC >= 1) { + printf("\n=== BEFORE RK4 ===\n"); + printf("pos=(%.1f, %.1f, %.1f) alt=%.1f m\n", p->pos.x, p->pos.y, p->pos.z, p->pos.z); + printf("vel=(%.2f, %.2f, %.2f) |V|=%.2f m/s\n", p->vel.x, p->vel.y, p->vel.z, norm3(p->vel)); + printf("ori=(w=%.4f, x=%.4f, y=%.4f, z=%.4f)\n", p->ori.w, p->ori.x, p->ori.y, p->ori.z); + printf("omega=(%.4f, %.4f, %.4f) rad/s\n", p->omega.x, p->omega.y, p->omega.z); + + // Compute pitch angle + Vec3 forward = quat_rotate(p->ori, vec3(1, 0, 0)); + float pitch = asinf(-forward.z) * RAD_TO_DEG; + Vec3 vel_norm = normalize3(p->vel); + float vel_pitch = asinf(vel_norm.z) * RAD_TO_DEG; + float alpha = compute_aoa(p) * RAD_TO_DEG; + + printf("pitch=%.2f deg (nose %s), vel_pitch=%.2f deg (%s), alpha=%.2f deg\n", + pitch, pitch > 0 ? "UP" : "DOWN", + vel_pitch, vel_pitch > 0 ? 
"CLIMBING" : "DESCENDING", + alpha); + printf("actions=[thr=%.2f, elev=%.2f, ail=%.2f, rud=%.2f]\n", + actions[0], actions[1], actions[2], actions[3]); + } + + float clamped_actions[4]; + for (int i = 0; i < 4; i++) { + clamped_actions[i] = clampf(actions[i], -1.0f, 1.0f); + } + + rk4_step(p, clamped_actions, dt); + + p->throttle = (clamped_actions[0] + 1.0f) * 0.5f; + + float old_omega_y = p->omega.y; + p->omega.x = clampf(p->omega.x, -5.0f, 5.0f); // ~286 deg/s max roll + p->omega.y = clampf(p->omega.y, -5.0f, 5.0f); // ~286 deg/s max pitch + p->omega.z = clampf(p->omega.z, -2.0f, 2.0f); // ~115 deg/s max yaw (less authority) + + if (DEBUG_REALISTIC >= 1 && old_omega_y != p->omega.y) { + printf(" WARNING: omega.y clamped from %.4f to %.4f\n", old_omega_y, p->omega.y); + } + + Vec3 dv = sub3(p->vel, p->prev_vel); + Vec3 accel = mul3(dv, 1.0f / dt); + Vec3 body_up = quat_rotate(p->ori, vec3(0, 0, 1)); + + // Total acceleration in body-up direction, converted to G + // Add 1G because we're measuring from inertial frame (gravity already in accel) + float accel_up = dot3(accel, body_up); + p->g_force = accel_up * INV_GRAVITY + 1.0f; + + if (DEBUG_REALISTIC >= 1) { + printf("\n=== G-FORCE CALCULATION ===\n"); + printf("dv=(%.3f, %.3f, %.3f) over dt=%.4f\n", dv.x, dv.y, dv.z, dt); + printf("accel=(%.3f, %.3f, %.3f) m/s^2\n", accel.x, accel.y, accel.z); + printf("body_up=(%.4f, %.4f, %.4f)\n", body_up.x, body_up.y, body_up.z); + printf("accel·body_up=%.3f m/s^2 / g=%.3f + 1.0 = %.3f G\n", + accel_up, accel_up * INV_GRAVITY, p->g_force); + } + + float speed_before_glimit = norm3(p->vel); + + if (p->g_force > G_LIMIT_POS) { + // Positive G exceeded - reduce upward acceleration + float excess_g = p->g_force - G_LIMIT_POS; + float excess_accel = excess_g * GRAVITY; + + if (DEBUG_REALISTIC >= 1) { + printf("G-LIMIT: +%.2f G exceeded limit +%.1f by %.2f G, reducing vel\n", + p->g_force, G_LIMIT_POS, excess_g); + } + + Vec3 correction = mul3(body_up, excess_accel * dt); + + // Project out the component along velocity to preserve speed (energy) + Vec3 vel_norm = normalize3(p->vel); + float correction_along_vel = dot3(correction, vel_norm); + Vec3 correction_perp = sub3(correction, mul3(vel_norm, correction_along_vel)); + + p->vel = sub3(p->vel, correction_perp); + p->g_force = G_LIMIT_POS; + + } else if (p->g_force < -G_LIMIT_NEG) { + // Negative G exceeded - reduce downward acceleration + float deficit_g = -G_LIMIT_NEG - p->g_force; + float deficit_accel = deficit_g * GRAVITY; + + if (DEBUG_REALISTIC >= 1) { + printf("G-LIMIT: %.2f G exceeded limit -%.1f by %.2f G, reducing vel\n", + p->g_force, G_LIMIT_NEG, -deficit_g); + } + + Vec3 correction = mul3(body_up, deficit_accel * dt); + + // Project out the component along velocity to preserve speed (energy) + Vec3 vel_norm = normalize3(p->vel); + float correction_along_vel = dot3(correction, vel_norm); + Vec3 correction_perp = sub3(correction, mul3(vel_norm, correction_along_vel)); + + p->vel = add3(p->vel, correction_perp); + p->g_force = -G_LIMIT_NEG; + } + + // Verify energy was preserved (speed should not have changed) + if (DEBUG_REALISTIC >= 1) { + float speed_after_glimit = norm3(p->vel); + if (fabsf(speed_after_glimit - speed_before_glimit) > 0.01f) { + printf("WARNING: G-limit changed speed from %.2f to %.2f!\n", + speed_before_glimit, speed_after_glimit); + } + } + + p->yaw_from_rudder = compute_sideslip(p); + + if (DEBUG_REALISTIC >= 1) { + printf("\n=== AFTER RK4 ===\n"); + printf("pos=(%.1f, %.1f, %.1f) alt=%.1f m (Δalt=%.2f m)\n", + 
p->pos.x, p->pos.y, p->pos.z, p->pos.z, p->pos.z - (p->pos.z - p->vel.z * dt)); + printf("vel=(%.2f, %.2f, %.2f) |V|=%.2f m/s\n", p->vel.x, p->vel.y, p->vel.z, norm3(p->vel)); + printf("ori=(w=%.4f, x=%.4f, y=%.4f, z=%.4f)\n", p->ori.w, p->ori.x, p->ori.y, p->ori.z); + printf("omega=(%.4f, %.4f, %.4f) rad/s = (%.2f, %.2f, %.2f) deg/s\n", + p->omega.x, p->omega.y, p->omega.z, + p->omega.x * RAD_TO_DEG, p->omega.y * RAD_TO_DEG, p->omega.z * RAD_TO_DEG); + printf("g_force=%.2f G (limits: +%.1f/-%.1f)\n", p->g_force, G_LIMIT_POS, G_LIMIT_NEG); + + // Compute final pitch and alpha + Vec3 forward = quat_rotate(p->ori, vec3(1, 0, 0)); + float pitch = asinf(-forward.z) * RAD_TO_DEG; + float alpha = compute_aoa(p) * RAD_TO_DEG; + Vec3 vel_norm = normalize3(p->vel); + float vel_pitch = asinf(vel_norm.z) * RAD_TO_DEG; + + printf("final: pitch=%.2f deg, vel_pitch=%.2f deg, alpha=%.2f deg\n", + pitch, vel_pitch, alpha); + + // Key insight: what's happening to orientation vs velocity? + printf("\n=== STEP SUMMARY ===\n"); + printf("vel.z changed: %.3f -> %.3f (Δ=%.3f m/s, %s)\n", + p->prev_vel.z, p->vel.z, p->vel.z - p->prev_vel.z, + p->vel.z > p->prev_vel.z ? "CLIMBING MORE" : "DIVING MORE"); + printf("omega.y = %.4f rad/s = %.2f deg/s (nose pitching %s)\n", + p->omega.y, p->omega.y * RAD_TO_DEG, + p->omega.y > 0 ? "DOWN" : "UP"); + } + + if (DEBUG >= 10) { + float V = norm3(p->vel); + float alpha = compute_aoa(p) * RAD_TO_DEG; + float beta = compute_sideslip(p) * RAD_TO_DEG; + printf("=== REALISTIC PHYSICS ===\n"); + printf("speed=%.1f m/s\n", V); + printf("throttle=%.2f\n", p->throttle); + printf("alpha=%.2f deg, beta=%.2f deg\n", alpha, beta); + printf("omega=(%.3f, %.3f, %.3f) rad/s\n", p->omega.x, p->omega.y, p->omega.z); + printf("g_force=%.2f g (limit=+%.1f/-%.1f)\n", p->g_force, G_LIMIT_POS, G_LIMIT_NEG); + } +} + +// Calculate specific energy: Es = altitude + speed²/(2*g) +static inline float calc_specific_energy(Plane *p) { + float speed = norm3(p->vel); + return p->pos.z + (speed * speed) / (2.0f * GRAVITY); +} + +static inline void reset_plane(Plane *p, Vec3 pos, Vec3 vel) { + p->pos = pos; + p->vel = vel; + p->prev_vel = vel; + p->omega = vec3(0, 0, 0); + p->ori = quat(1, 0, 0, 0); + p->throttle = 0.5f; + p->g_force = 1.0f; + p->yaw_from_rudder = 0.0f; + p->fire_cooldown = 0; + // Initialize specific energy for energy management reward + float speed = norm3(vel); + p->prev_energy = pos.z + (speed * speed) / (2.0f * GRAVITY); + + _realistic_step_count = 0; + + if (DEBUG_REALISTIC >= 1) { + printf("\n=== RESET_PLANE ===\n"); + printf("pos=(%.1f, %.1f, %.1f)\n", pos.x, pos.y, pos.z); + printf("vel=(%.2f, %.2f, %.2f) |V|=%.2f m/s\n", vel.x, vel.y, vel.z, norm3(vel)); + printf("ori=(1, 0, 0, 0) (identity)\n"); + printf("omega=(0, 0, 0)\n"); + } +} + +#endif // FLIGHTLIB_H diff --git a/pufferlib/ocean/dogfight/p40.glb b/pufferlib/ocean/dogfight/p40.glb new file mode 100644 index 000000000..c21c170a3 Binary files /dev/null and b/pufferlib/ocean/dogfight/p40.glb differ diff --git a/pufferlib/ocean/dogfight/train_dual_selfplay.py b/pufferlib/ocean/dogfight/train_dual_selfplay.py new file mode 100644 index 000000000..e32072c02 --- /dev/null +++ b/pufferlib/ocean/dogfight/train_dual_selfplay.py @@ -0,0 +1,1314 @@ +#!/usr/bin/env python +"""Dual-Perspective Self-Play Training for Dogfight with Checkpoint Queue. 
+ +This script trains a single policy on BOTH perspectives of each dogfight episode: +- Player perspective: (obs, actions, rewards) +- Opponent perspective: (obs, actions, -rewards) + +The opponent is loaded from a checkpoint queue, always N checkpoints behind the learner. +This creates a stable skill gap and natural curriculum within self-play. + +Architecture: +1. Phase 1 (Curriculum): Stages 0-9 with autopilot opponent, single-perspective training +2. Milestone (Stage 10): Save first checkpoint +3. Phase 2 (Curriculum): Stages 10-19 with autopilot opponent +4. Milestone (Stage 20): Save second checkpoint, START SELF-PLAY vs stage 10 checkpoint +5. Domination: When perf >= threshold, save new checkpoint and upgrade opponent + +Key insight: We train the SAME policy from BOTH perspectives of the zero-sum game. +Each episode provides: +- "Here's what the winner did" → positive reward signal +- "Here's what the loser did" → negative reward signal + +This doubles the learning signal and teaches the agent to both attack AND defend. + +Usage: + # Basic dual self-play training (starts self-play at stage 20) + python pufferlib/ocean/dogfight/train_dual_selfplay.py + + # With wandb logging + python pufferlib/ocean/dogfight/train_dual_selfplay.py --wandb --wandb-project df-dual + + # Start self-play immediately at stage 0 (for testing) + python pufferlib/ocean/dogfight/train_dual_selfplay.py --selfplay-min-stage 0 + + # Custom checkpoint queue settings + python pufferlib/ocean/dogfight/train_dual_selfplay.py --checkpoint-lag 2 --perf-threshold 0.70 + + # Verbose debug output + python pufferlib/ocean/dogfight/train_dual_selfplay.py --debug + + # Evaluate player policy against opponent checkpoint with rendering + python pufferlib/ocean/dogfight/train_dual_selfplay.py eval \ + --load-model-path experiments/model.pt \ + --opponent-checkpoint checkpoints/selfplay_xxx/checkpoint_stage20_step20000000.pt \ + --render-mode raylib + + # Or with wandb run ID for player model + python pufferlib/ocean/dogfight/train_dual_selfplay.py eval \ + --load-id abc123 \ + --opponent-checkpoint path/to/opponent.pt +""" +import os +import sys +import copy +import time +import argparse +import uuid +from collections import defaultdict + +import numpy as np +import torch +import torch.nn as nn + +import pufferlib +import pufferlib.vector +import pufferlib.pytorch +from pufferlib import pufferl +from pufferlib.checkpoint_queue import CheckpointQueue + +# Debug level: 0=off, 1=key events, 2=loop progress, 3=detailed +DEBUG_LEVEL = 0 + +def debug(level, msg): + """Print debug message if level is enabled.""" + if DEBUG_LEVEL >= level: + print(f'[DUAL-DEBUG-{level}] {msg}', flush=True) + + +# Configuration defaults +DEFAULT_OPPONENT_UPDATE_INTERVAL = 1_000_000 # Update opponent every 1M steps (legacy, unused with queue) +DEFAULT_SELFPLAY_MIN_STAGE = 20 # Only enable self-play after stage 20 +DEFAULT_CHECKPOINT_LAG = 1 # Opponent is N checkpoints behind (1=2nd newest) +DEFAULT_PERF_THRESHOLD = 0.65 # Kill rate to trigger checkpoint save + opponent upgrade +DEFAULT_MIN_STEPS_BETWEEN_CHECKPOINTS = 2_000_000 # Minimum steps before saving new checkpoint +DEFAULT_MAX_CHECKPOINTS = 20 # Max selfplay checkpoints (milestones always kept) + + +class DualPerspectiveTrainer: + """Trainer that collects experience from both player and opponent perspectives. 
+ + During curriculum (stages 0-19): + - Standard single-perspective PPO + - Opponent is autopilot (handled by C code) + + During dual self-play (stage 20+): + - Opponent is frozen copy of learner + - Collect experience from BOTH perspectives + - Train on combined experience (2x training signal) + """ + + def __init__(self, config, vecenv, learner_policy, logger=None, + opponent_update_interval=DEFAULT_OPPONENT_UPDATE_INTERVAL, + selfplay_min_stage=DEFAULT_SELFPLAY_MIN_STAGE, + checkpoint_lag=DEFAULT_CHECKPOINT_LAG, + perf_threshold=DEFAULT_PERF_THRESHOLD, + min_steps_between_checkpoints=DEFAULT_MIN_STEPS_BETWEEN_CHECKPOINTS, + max_checkpoints=DEFAULT_MAX_CHECKPOINTS, + checkpoint_dir=None, + run_id=None): + # Store custom config + self.opponent_update_interval = opponent_update_interval + self.selfplay_min_stage = selfplay_min_stage + self.checkpoint_lag = checkpoint_lag + self.perf_threshold = perf_threshold + self.min_steps_between_checkpoints = min_steps_between_checkpoints + self.use_dual_selfplay = False + self.last_opponent_update = 0 + + # Create the standard PuffeRL trainer for the learner + self.trainer = pufferl.PuffeRL(config, vecenv, learner_policy, logger) + self.config = config + self.vecenv = vecenv + self.learner_policy = learner_policy + + # Generate run ID if not provided + if run_id is None: + run_id = str(uuid.uuid4())[:8] + + # Initialize checkpoint queue + if checkpoint_dir is None: + checkpoint_dir = f'checkpoints/selfplay_{run_id}' + self.checkpoint_queue = CheckpointQueue( + save_dir=checkpoint_dir, + max_checkpoints=max_checkpoints + ) + + # Track milestone saves and domination state + self._saved_stage10 = False + self._saved_stage20 = False + self._current_opponent_path = None + self.last_checkpoint_step = 0 + self._current_stage = 0 + + # Create frozen opponent policy (copy of learner) + self.opponent_policy = None + self._init_opponent_policy() + + # Get driver env for direct C access + self.driver_env = vecenv.driver_env + + # Dual experience buffers (allocated lazily) + self.opponent_obs = None + self.opponent_actions = None + self.opponent_logprobs = None + self.opponent_values = None + self.opponent_rewards = None + self.opponent_terminals = None + + # Track opponent LSTM state if using RNN + self.opponent_lstm_h = None + self.opponent_lstm_c = None + + print(f'[DUAL-SELFPLAY] Initialized: min_stage={selfplay_min_stage}, ' + f'checkpoint_lag={checkpoint_lag}, perf_threshold={perf_threshold}') + print(f'[DUAL-SELFPLAY] Checkpoint dir: {checkpoint_dir}') + + def _init_opponent_policy(self): + """Create frozen opponent policy as copy of learner.""" + device = self.config['device'] + + # Deep copy the learner policy architecture + self.opponent_policy = copy.deepcopy(self.learner_policy) + self.opponent_policy = self.opponent_policy.to(device) + + # Copy current learner weights + self.opponent_policy.load_state_dict(self.learner_policy.state_dict()) + + # Freeze for inference only + self.opponent_policy.eval() + for p in self.opponent_policy.parameters(): + p.requires_grad = False + + print(f'[DUAL-SELFPLAY] Opponent policy initialized from learner') + + def _allocate_opponent_buffers(self): + """Allocate experience buffers for opponent perspective.""" + if self.opponent_obs is not None: + return # Already allocated + + # Match learner buffer shapes + device = self.config['device'] + segments = self.trainer.segments + horizon = self.config['bptt_horizon'] + obs_space = self.vecenv.single_observation_space + atn_space = self.vecenv.single_action_space + + 
self.opponent_obs = torch.zeros(segments, horizon, *obs_space.shape, + dtype=pufferlib.pytorch.numpy_to_torch_dtype_dict[obs_space.dtype], + pin_memory=device == 'cuda' and self.config.get('cpu_offload', False), + device='cpu' if self.config.get('cpu_offload', False) else device) + self.opponent_actions = torch.zeros(segments, horizon, *atn_space.shape, device=device, + dtype=pufferlib.pytorch.numpy_to_torch_dtype_dict[atn_space.dtype]) + self.opponent_values = torch.zeros(segments, horizon, device=device) + self.opponent_logprobs = torch.zeros(segments, horizon, device=device) + self.opponent_rewards = torch.zeros(segments, horizon, device=device) + self.opponent_terminals = torch.zeros(segments, horizon, device=device) + + print(f'[DUAL-SELFPLAY] Allocated opponent buffers: ' + f'obs={self.opponent_obs.shape}, actions={self.opponent_actions.shape}') + # Verify allocation is zeros + obs_range = (self.opponent_obs.min().item(), self.opponent_obs.max().item()) + debug(1, f'Opponent obs buffer after allocation: range={obs_range}') + + def _update_opponent(self): + """Load opponent from checkpoint queue instead of copying weights.""" + opponent_path = self.checkpoint_queue.get_opponent(lag=self.checkpoint_lag) + + if opponent_path is None: + # Queue too small, fall back to copying learner weights + self.opponent_policy.load_state_dict(self.learner_policy.state_dict()) + self.last_opponent_update = self.trainer.global_step + print(f'[DUAL-SELFPLAY] Queue too small, copied learner weights at step {self.trainer.global_step}') + return + + if opponent_path != self._current_opponent_path: + self._load_opponent_from_checkpoint(opponent_path) + self._current_opponent_path = opponent_path + self.last_opponent_update = self.trainer.global_step + + def _load_opponent_from_checkpoint(self, checkpoint_path: str): + """Load opponent policy from a checkpoint file.""" + checkpoint = torch.load(checkpoint_path, map_location=self.config['device']) + self.opponent_policy.load_state_dict(checkpoint['policy_state_dict']) + + # Get checkpoint info for logging + tag = checkpoint.get('tag', 'unknown') + step = checkpoint.get('step', 0) + print(f'[CHECKPOINT-QUEUE] Loaded opponent from {tag} (step {step}): {checkpoint_path}') + + def _check_milestone_save(self, current_stage: int): + """Save checkpoints at stage 10 and stage 20 milestones.""" + if current_stage >= 10 and not self._saved_stage10: + self.checkpoint_queue.save( + self.learner_policy, + self.trainer.global_step, + current_stage, + "stage10" + ) + self._saved_stage10 = True + self.last_checkpoint_step = self.trainer.global_step + print(f'[CHECKPOINT-QUEUE] Saved milestone: stage10 at step {self.trainer.global_step}') + + if current_stage >= 20 and not self._saved_stage20: + self.checkpoint_queue.save( + self.learner_policy, + self.trainer.global_step, + current_stage, + "stage20" + ) + self._saved_stage20 = True + self.last_checkpoint_step = self.trainer.global_step + print(f'[CHECKPOINT-QUEUE] Saved milestone: stage20 at step {self.trainer.global_step}') + + def _check_domination(self, logs): + """Check if learner dominates opponent using perf metric. + + When the learner's kill rate (perf) exceeds the threshold: + 1. Save a new checkpoint + 2. 
Upgrade opponent to older checkpoint (lag positions behind) + """ + if not self.use_dual_selfplay: + return + + # Check minimum steps since last checkpoint + steps_since_last = self.trainer.global_step - self.last_checkpoint_step + if steps_since_last < self.min_steps_between_checkpoints: + return + + # Get perf from logs (already computed kill rate) + # Note: stats are prefixed with 'environment/' in mean_and_log() + perf = logs.get('environment/perf', 0) if logs else 0 + if perf >= self.perf_threshold: + print(f'[CHECKPOINT-QUEUE] Learner dominating (perf={perf:.2f} >= {self.perf_threshold}), saving checkpoint') + + # Save new checkpoint + checkpoint_num = len([c for c in self.checkpoint_queue.checkpoints if not c.is_milestone()]) + tag = f"selfplay_{checkpoint_num}" + self.checkpoint_queue.save( + self.learner_policy, + self.trainer.global_step, + self._current_stage, + tag + ) + self.last_checkpoint_step = self.trainer.global_step + + # Log queue state + queue_state = self.checkpoint_queue.get_queue_state() + print(f'[CHECKPOINT-QUEUE] Queue: {queue_state["tags"]}') + + # Upgrade opponent to older checkpoint (lag positions behind) + self._update_opponent() + + def _check_selfplay_transition(self, stats=None): + """Check if we should transition to dual self-play mode and save milestones. + + Args: + stats: Stats dict from trainer.stats (contains 'stage' from C logs) + """ + # Get current stage from stats (populated by C code during evaluate) + # Use avg_stage which is more reliable than individual episode stages + # (when target=20.9, 90% episodes are stage 20, 10% are stage 19) + if stats and 'avg_stage' in stats and len(stats['avg_stage']) > 0: + # avg_stage is the mean stage across episodes - use max to catch when we hit 20 + current_stage = max(stats['avg_stage']) + elif stats and 'stage' in stats and len(stats['stage']) > 0: + # Fallback to raw stage values + current_stage = max(stats['stage']) + else: + # Last resort fallback + current_stage = getattr(self.driver_env, '_current_stage', 0) + if self.trainer.epoch % 100 == 0: + print(f'[DUAL-SELFPLAY] WARNING: No stage in stats, using fallback stage={current_stage}') + + self._current_stage = current_stage + + # Check for milestone saves (stage 10, stage 20) + self._check_milestone_save(int(current_stage)) + + debug(1, f'_check_selfplay_transition: use_dual_selfplay={self.use_dual_selfplay}') + if self.use_dual_selfplay: + return # Already in self-play mode + + # Trigger at 19.9+ to catch stage 20 reliably (avg_stage=20.0 when all episodes are stage 20) + trigger_threshold = self.selfplay_min_stage - 0.1 # 20 - 0.1 = 19.9 + + if current_stage >= trigger_threshold: + print(f'[DUAL-SELFPLAY] Transitioning to dual self-play at stage {current_stage}', flush=True) + self.use_dual_selfplay = True + + # Allocate opponent buffers + self._allocate_opponent_buffers() + + # Enable opponent override in C code (activates self-play mode) + # This must be done when transitioning to self-play, not at env init + from pufferlib.ocean.dogfight import binding + binding.vec_enable_opponent_override(self.driver_env.c_envs, 1) + + # Signal workers to enable opponent override via shared memory flag + # (workers check this flag in step() before using opponent actions) + if hasattr(self.vecenv, 'buf') and 'selfplay_active' in self.vecenv.buf: + self.vecenv.buf['selfplay_active'][0] = 1 + print(f'[DUAL-SELFPLAY] Set selfplay_active flag in shared memory') + + # Initialize opponent from checkpoint queue + self._update_opponent() + + def evaluate(self): + 
"""Evaluate with dual experience collection in self-play mode.""" + # Note: selfplay transition check is done in train() BEFORE stats are cleared + + if not self.use_dual_selfplay: + # Standard single-perspective evaluation + return self.trainer.evaluate() + + # Dual self-play evaluation + return self._evaluate_dual() + + def _evaluate_dual(self): + """Collect experience from both player AND opponent perspectives.""" + debug(1, f'_evaluate_dual starting: segments={self.trainer.segments}') + + profile = self.trainer.profile + epoch = self.trainer.epoch + profile('eval', epoch) + profile('eval_misc', epoch, nest=True) + + config = self.config + device = config['device'] + + # Import binding for C-level access + from pufferlib.ocean.dogfight import binding + + # Reset LSTM states if using RNN + if config['use_rnn']: + for k in self.trainer.lstm_h: + self.trainer.lstm_h[k].zero_() + self.trainer.lstm_c[k].zero_() + # Initialize opponent LSTM if needed + if self.opponent_lstm_h is None: + n = self.vecenv.agents_per_batch + h = self.learner_policy.hidden_size + total_agents = self.trainer.total_agents + self.opponent_lstm_h = {i*n: torch.zeros(n, h, device=device) for i in range(total_agents//n)} + self.opponent_lstm_c = {i*n: torch.zeros(n, h, device=device) for i in range(total_agents//n)} + for k in self.opponent_lstm_h: + self.opponent_lstm_h[k].zero_() + self.opponent_lstm_c[k].zero_() + + self.trainer.full_rows = 0 + loop_count = 0 + while self.trainer.full_rows < self.trainer.segments: + loop_count += 1 + if loop_count % 100 == 1: + debug(2, f'eval loop {loop_count}: full_rows={self.trainer.full_rows}/{self.trainer.segments}') + + profile('env', epoch) + o, r, d, t, info, env_id, mask = self.vecenv.recv() + + profile('eval_misc', epoch) + env_id = slice(env_id[0], env_id[-1] + 1) + debug(3, f'recv: o.shape={o.shape}, env_id={env_id}') + + done_mask = d + t + self.trainer.global_step += int(mask.sum()) + + profile('eval_copy', epoch) + o = torch.as_tensor(o) + o_device = o.to(device) + r = torch.as_tensor(r).to(device) + d = torch.as_tensor(d).to(device) + + # Get opponent observations from shared memory buffers (Multiprocessing) + # or via C binding (Serial). C code writes to buffers during c_step(). 
+ if hasattr(self.vecenv, 'buf') and 'opponent_observations' in self.vecenv.buf: + # Multiprocessing: read from shared memory buffer + o_opponent_all = self.vecenv.buf['opponent_observations'] + debug(3, f'opponent obs from buf: shape={o_opponent_all.shape}') + # buf shape is (num_workers, agents_per_worker, *obs_shape) + # w_slice from recv() gives us the right worker indices + o_opponent = torch.as_tensor(o_opponent_all[self.vecenv.w_slice].reshape(-1, *self.vecenv.single_observation_space.shape)).to(device) + else: + # Serial: use C binding directly + o_opponent_all = binding.vec_get_opponent_observations(self.driver_env.c_envs) + debug(3, f'opponent obs from binding: all.shape={o_opponent_all.shape}, slicing with env_id={env_id}') + o_opponent = torch.as_tensor(o_opponent_all[env_id]).to(device) + + # Handle NaN observations (can occur at episode boundaries) + # Replace NaN with zeros - these will get masked out anyway + nan_count = np.isnan(o_opponent_all[env_id]).sum() if isinstance(o_opponent_all, np.ndarray) else 0 + if nan_count > 0: + debug(2, f'NaN in opponent obs: {nan_count} values') + if torch.isnan(o_opponent).any(): + o_opponent = torch.nan_to_num(o_opponent, nan=0.0) + # Verify cleaning worked + if torch.isnan(o_opponent).any(): + debug(1, f'ERROR: NaN still in o_opponent after nan_to_num!') + else: + debug(3, f'NaN cleaned successfully') + + profile('eval_forward', epoch) + with torch.no_grad(), self.trainer.amp_context: + # Learner forward pass + state_p = dict( + reward=r, + done=d, + env_id=env_id, + mask=mask, + ) + + if config['use_rnn']: + state_p['lstm_h'] = self.trainer.lstm_h[env_id.start] + state_p['lstm_c'] = self.trainer.lstm_c[env_id.start] + + logits_p, value_p = self.trainer.policy.forward_eval(o_device, state_p) + action_p, logprob_p, _ = pufferlib.pytorch.sample_logits(logits_p) + r_clamped = torch.clamp(r, -1, 1) + + # Opponent forward pass (no gradients, frozen) + state_o = dict( + reward=-r, # Opponent gets negative reward + done=d, + env_id=env_id, + mask=mask, + lstm_h=None, + lstm_c=None, + ) + + if config['use_rnn']: + state_o['lstm_h'] = self.opponent_lstm_h[env_id.start] + state_o['lstm_c'] = self.opponent_lstm_c[env_id.start] + + logits_o, value_o = self.opponent_policy.forward_eval(o_opponent, state_o) + action_o, logprob_o, _ = pufferlib.pytorch.sample_logits(logits_o) + + debug(3, f'actions: player={action_p.shape}, opponent={action_o.shape}') + + profile('eval_copy', epoch) + with torch.no_grad(): + # Update LSTM states + if config['use_rnn']: + self.trainer.lstm_h[env_id.start] = state_p['lstm_h'] + self.trainer.lstm_c[env_id.start] = state_p['lstm_c'] + self.opponent_lstm_h[env_id.start] = state_o['lstm_h'] + self.opponent_lstm_c[env_id.start] = state_o['lstm_c'] + + # Fast path for fully vectorized envs + l = self.trainer.ep_lengths[env_id.start].item() + batch_rows = slice(self.trainer.ep_indices[env_id.start].item(), + 1+self.trainer.ep_indices[env_id.stop - 1].item()) + + # Store PLAYER experience + if config.get('cpu_offload', False): + self.trainer.observations[batch_rows, l] = o + else: + self.trainer.observations[batch_rows, l] = o_device + + self.trainer.actions[batch_rows, l] = action_p + self.trainer.logprobs[batch_rows, l] = logprob_p + self.trainer.rewards[batch_rows, l] = r_clamped + self.trainer.terminals[batch_rows, l] = d.float() + self.trainer.values[batch_rows, l] = value_p.flatten() + + # Store OPPONENT experience (rewards are NEGATIVE of player) + if config.get('cpu_offload', False): + self.opponent_obs[batch_rows, l] 
= o_opponent.cpu() + else: + self.opponent_obs[batch_rows, l] = o_opponent + + self.opponent_actions[batch_rows, l] = action_o + self.opponent_logprobs[batch_rows, l] = logprob_o + self.opponent_rewards[batch_rows, l] = -r_clamped # ZERO-SUM + self.opponent_terminals[batch_rows, l] = d.float() + self.opponent_values[batch_rows, l] = value_o.flatten() + + # Handle episode boundaries + self.trainer.ep_lengths[env_id] += 1 + if l+1 >= config['bptt_horizon']: + num_full = env_id.stop - env_id.start + self.trainer.ep_indices[env_id] = self.trainer.free_idx + torch.arange( + num_full, device=config['device']).int() + self.trainer.ep_lengths[env_id] = 0 + self.trainer.free_idx += num_full + self.trainer.full_rows += num_full + + # Prepare actions for env + action_p_np = action_p.cpu().numpy() + action_o_np = action_o.cpu().numpy() + + if isinstance(logits_p, torch.distributions.Normal): + action_p_np = np.clip(action_p_np, + self.vecenv.action_space.low, + self.vecenv.action_space.high) + action_o_np = np.clip(action_o_np, + self.vecenv.action_space.low, + self.vecenv.action_space.high) + + profile('eval_misc', epoch) + # Process info + for i in info: + for k, v in pufferlib.unroll_nested_dict(i): + if isinstance(v, np.ndarray): + v = v.tolist() # convert, then fall through so array-valued stats are still recorded + if isinstance(v, (list, tuple)): + self.trainer.stats[k].extend(v) + else: + self.trainer.stats[k].append(v) + + # Set opponent actions: write to shared memory (Multiprocessing) or C binding (Serial) + profile('env', epoch) + if hasattr(self.vecenv, 'buf') and 'opponent_actions' in self.vecenv.buf: + # Multiprocessing: write to shared memory buffer + # Workers will read this during their step() call + opp_act_buf = self.vecenv.buf['opponent_actions'] + opp_act_buf[self.vecenv.w_slice] = action_o_np.reshape(opp_act_buf[self.vecenv.w_slice].shape) + else: + # Serial: set directly via C binding + binding.vec_set_opponent_actions(self.driver_env.c_envs, action_o_np) + + # Send player actions + self.vecenv.send(action_p_np) + + profile('eval_misc', epoch) + self.trainer.free_idx = self.trainer.total_agents + self.trainer.ep_indices = torch.arange(self.trainer.total_agents, + device=device, dtype=torch.int32) + self.trainer.ep_lengths.zero_() + profile.end() + # Verify opponent obs after evaluate + obs_range = (self.opponent_obs.min().item(), self.opponent_obs.max().item()) + debug(1, f'Opponent obs after _evaluate_dual: range={obs_range}') + + debug(1, f'_evaluate_dual done: loops={loop_count}, global_step={self.trainer.global_step}') + return self.trainer.stats + + def train(self): + """Train on combined experience in self-play mode.""" + if not self.use_dual_selfplay: + # Standard single-perspective training + # Check for selfplay transition BEFORE train() clears stats + self._check_selfplay_transition(self.trainer.stats) + logs = self.trainer.train() + return logs + + # Dual self-play training + logs = self._train_dual() + + # Check if learner dominates opponent -> save checkpoint and upgrade + self._check_domination(logs) + + return logs + + def _train_dual(self): + """Train on combined player + opponent experience.""" + debug(1, f'_train_dual starting: epoch={self.trainer.epoch}') + + profile = self.trainer.profile + epoch = self.trainer.epoch + profile('train', epoch) + profile('train_misc', epoch, nest=True) + losses = defaultdict(float) + config = self.config + device = config['device'] + + # Combine player and opponent experience + # Shape: [2*segments, horizon, ...] 
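+        # How the doubling works (restating the module docstring): the opponent buffers were
+        # filled with -reward, so concatenating along the segment axis yields one batch in
+        # which each collected transition appears twice, once from the player's side and once
+        # from the opponent's side with the return sign flipped. A single PPO update on this
+        # batch trains the same policy to both attack and defend.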
+ combined_obs = torch.cat([self.trainer.observations, self.opponent_obs], dim=0) + combined_actions = torch.cat([self.trainer.actions, self.opponent_actions], dim=0) + combined_logprobs = torch.cat([self.trainer.logprobs, self.opponent_logprobs], dim=0) + combined_values = torch.cat([self.trainer.values, self.opponent_values], dim=0) + combined_rewards = torch.cat([self.trainer.rewards, self.opponent_rewards], dim=0) + combined_terminals = torch.cat([self.trainer.terminals, self.opponent_terminals], dim=0) + combined_ratio = torch.cat([self.trainer.ratio, torch.ones_like(self.trainer.ratio)], dim=0) + + # Handle NaN and extreme values in observations + if torch.isnan(combined_obs).any(): + nan_before = torch.isnan(combined_obs).sum().item() + combined_obs = torch.nan_to_num(combined_obs, nan=0.0) + debug(1, f'Cleaned NaN in combined_obs: {nan_before} values') + + # Check for extreme values (observations should be in [-1, 1] range) + obs_min, obs_max = combined_obs.min().item(), combined_obs.max().item() + if obs_min < -100 or obs_max > 100: + debug(1, f'WARNING: Extreme obs values! range=[{obs_min:.2e}, {obs_max:.2e}]') + # Check player vs opponent buffers + p_min, p_max = self.trainer.observations.min().item(), self.trainer.observations.max().item() + o_min, o_max = self.opponent_obs.min().item(), self.opponent_obs.max().item() + debug(1, f' player_obs range: [{p_min:.2e}, {p_max:.2e}]') + debug(1, f' opponent_obs range: [{o_min:.2e}, {o_max:.2e}]') + # Clamp to sane range + combined_obs = torch.clamp(combined_obs, -10, 10) + debug(1, f' Clamped to [-10, 10]') + + total_segments = combined_values.shape[0] + + b0 = config['prio_beta0'] + a = config['prio_alpha'] + clip_coef = config['clip_coef'] + vf_clip = config['vf_clip_coef'] + anneal_beta = b0 + (1 - b0)*a*epoch/self.trainer.total_epochs + + # Reset ratio for combined experience + self.trainer.ratio[:] = 1 + + debug(2, f'train: total_minibatches={self.trainer.total_minibatches}, combined shape={combined_values.shape}') + + # Check for NaN in combined buffers + obs_nan = torch.isnan(combined_obs).sum().item() + if obs_nan > 0: + debug(1, f'WARNING: NaN in combined_obs: {obs_nan} values') + # Find which buffer has NaN + p_nan = torch.isnan(self.trainer.observations).sum().item() + o_nan = torch.isnan(self.opponent_obs).sum().item() + debug(1, f' player_obs NaN: {p_nan}, opponent_obs NaN: {o_nan}') + + # Check for NaN in rewards (opposite signs could cause issues) + reward_nan = torch.isnan(combined_rewards).sum().item() + if reward_nan > 0: + debug(1, f'WARNING: NaN in combined_rewards: {reward_nan} values') + + for mb in range(self.trainer.total_minibatches): + if mb % 10 == 0: + debug(3, f'train minibatch {mb}/{self.trainer.total_minibatches}') + profile('train_misc', epoch) + self.trainer.amp_context.__enter__() + + shape = combined_values.shape + advantages = torch.zeros(shape, device=device) + advantages = pufferl.compute_puff_advantage(combined_values, combined_rewards, + combined_terminals, combined_ratio, advantages, config['gamma'], + config['gae_lambda'], config['vtrace_rho_clip'], config['vtrace_c_clip']) + + # Prioritize experience by advantage magnitude + adv = advantages.abs().sum(axis=1) + prio_weights = torch.nan_to_num(adv**a, 0, 0, 0) + prio_probs = (prio_weights + 1e-6)/(prio_weights.sum() + 1e-6) + + # Sample from combined segments (2x as many) + idx = torch.multinomial(prio_probs, self.trainer.minibatch_segments) + mb_prio = (total_segments*prio_probs[idx, None])**-anneal_beta + + profile('train_copy', epoch) + 
mb_obs = combined_obs[idx] + mb_actions = combined_actions[idx] + mb_logprobs = combined_logprobs[idx] + mb_rewards = combined_rewards[idx] + mb_terminals = combined_terminals[idx] + mb_ratio = combined_ratio[idx] + mb_values = combined_values[idx] + mb_returns = advantages[idx] + mb_values + mb_advantages = advantages[idx] + + profile('train_forward', epoch) + if not config['use_rnn']: + mb_obs = mb_obs.reshape(-1, *self.vecenv.single_observation_space.shape) + + state = dict( + action=mb_actions, + lstm_h=None, + lstm_c=None, + ) + + # Forward pass through LEARNER policy (gets gradients from BOTH perspectives) + if torch.isnan(mb_obs).any(): + debug(1, f'ERROR: NaN in mb_obs before forward pass! mb={mb}') + debug(1, f' mb_obs shape: {mb_obs.shape}, NaN count: {torch.isnan(mb_obs).sum().item()}') + + # Check policy weights for NaN before forward + for name, p in self.trainer.policy.named_parameters(): + if torch.isnan(p).any(): + debug(1, f'ERROR: NaN in policy param {name} BEFORE forward!') + break + + try: + logits, newvalue = self.trainer.policy(mb_obs, state) + except ValueError as e: + debug(1, f'ERROR in forward pass: {e}') + # Check all the inputs + debug(1, f' mb_obs: shape={mb_obs.shape}, nan={torch.isnan(mb_obs).sum().item()}, inf={torch.isinf(mb_obs).sum().item()}') + debug(1, f' mb_obs range: [{mb_obs.min().item():.4f}, {mb_obs.max().item():.4f}]') + # Check hidden state from LSTM + if 'hidden' in state: + h = state['hidden'] + debug(1, f' hidden: nan={torch.isnan(h).sum().item()}, inf={torch.isinf(h).sum().item()}') + # Re-raise to stop training + raise + + # Check if logits contain NaN + if isinstance(logits, torch.distributions.Normal): + if torch.isnan(logits.loc).any() or torch.isnan(logits.scale).any(): + debug(1, f'ERROR: NaN in logits! 
loc_nan={torch.isnan(logits.loc).sum().item()}, scale_nan={torch.isnan(logits.scale).sum().item()}') + elif torch.isnan(logits).any(): + debug(1, f'ERROR: NaN in logits tensor!') + actions, newlogprob, entropy = pufferlib.pytorch.sample_logits(logits, action=mb_actions) + + profile('train_misc', epoch) + newlogprob = newlogprob.reshape(mb_logprobs.shape) + logratio = newlogprob - mb_logprobs + ratio = logratio.exp() + combined_ratio[idx] = ratio.detach() + + with torch.no_grad(): + old_approx_kl = (-logratio).mean() + approx_kl = ((ratio - 1) - logratio).mean() + clipfrac = ((ratio - 1.0).abs() > config['clip_coef']).float().mean() + + # Weight advantages by priority and normalize + adv = mb_advantages + adv = mb_prio * (adv - adv.mean()) / (adv.std() + 1e-8) + + # PPO losses + pg_loss1 = -adv * ratio + pg_loss2 = -adv * torch.clamp(ratio, 1 - clip_coef, 1 + clip_coef) + pg_loss = torch.max(pg_loss1, pg_loss2).mean() + + newvalue = newvalue.view(mb_returns.shape) + v_clipped = mb_values + torch.clamp(newvalue - mb_values, -vf_clip, vf_clip) + v_loss_unclipped = (newvalue - mb_returns) ** 2 + v_loss_clipped = (v_clipped - mb_returns) ** 2 + v_loss = 0.5*torch.max(v_loss_unclipped, v_loss_clipped).mean() + + entropy_loss = entropy.mean() + + loss = pg_loss + config['vf_coef']*v_loss - config['ent_coef']*entropy_loss + self.trainer.amp_context.__enter__() # TODO: AMP needs debugging + + # Update values for combined buffer (only player portion used for priority) + combined_values[idx] = newvalue.detach().float() + + # Logging + profile('train_misc', epoch) + losses['policy_loss'] += pg_loss.item() / self.trainer.total_minibatches + losses['value_loss'] += v_loss.item() / self.trainer.total_minibatches + losses['entropy'] += entropy_loss.item() / self.trainer.total_minibatches + losses['old_approx_kl'] += old_approx_kl.item() / self.trainer.total_minibatches + losses['approx_kl'] += approx_kl.item() / self.trainer.total_minibatches + losses['clipfrac'] += clipfrac.item() / self.trainer.total_minibatches + losses['importance'] += ratio.mean().item() / self.trainer.total_minibatches + + # Learn on accumulated minibatches + profile('learn', epoch) + loss.backward() + if (mb + 1) % self.trainer.accumulate_minibatches == 0: + # Check for NaN in gradients before stepping + grad_nan = False + for name, p in self.trainer.policy.named_parameters(): + if p.grad is not None and torch.isnan(p.grad).any(): + debug(1, f'WARNING: NaN grad in {name}') + grad_nan = True + if grad_nan: + debug(1, f'NaN gradient detected at mb {mb}, skipping optimizer step') + self.trainer.optimizer.zero_grad() + continue + + torch.nn.utils.clip_grad_norm_(self.trainer.policy.parameters(), config['max_grad_norm']) + self.trainer.optimizer.step() + self.trainer.optimizer.zero_grad() + + # Check for NaN in weights after stepping + for name, p in self.trainer.policy.named_parameters(): + if torch.isnan(p).any(): + debug(1, f'WARNING: NaN weight in {name} after step') + break + + # Update learning rate scheduler + profile('train_misc', epoch) + if config['anneal_lr']: + self.trainer.scheduler.step() + + y_pred = combined_values.flatten() + y_true = advantages.flatten() + combined_values.flatten() + var_y = y_true.var() + explained_var = torch.nan if var_y == 0 else (1 - (y_true - y_pred).var() / var_y).item() + losses['explained_variance'] = explained_var + + # Add dual self-play specific metrics + losses['dual_selfplay'] = 1.0 # Flag for logging + + profile.end() + logs = None + self.trainer.epoch += 1 + done_training = 
self.trainer.global_step >= config['total_timesteps'] + if done_training or self.trainer.global_step == 0 or time.time() > self.trainer.last_log_time + 0.25: + logs = self.trainer.mean_and_log() + self.trainer.losses = losses + self.trainer.print_dashboard() + self.trainer.stats = defaultdict(list) + self.trainer.last_log_time = time.time() + self.trainer.last_log_step = self.trainer.global_step + profile.clear() + + if self.trainer.epoch % config['checkpoint_interval'] == 0 or done_training: + self.trainer.save_checkpoint() + self.trainer.msg = f'Checkpoint saved at update {self.trainer.epoch}' + + return logs + + @property + def global_step(self): + return self.trainer.global_step + + @property + def epoch(self): + return self.trainer.epoch + + def close(self): + return self.trainer.close() + + +def eval_selfplay(env_name, args, player_path, opponent_path, load_id=None): + """Evaluate player policy against opponent checkpoint with rendering. + + This is like pufferl.eval() but with dual policy inference: + - Player uses the loaded model (from player_path or load_id) + - Opponent uses a checkpoint from opponent_path + + Args: + env_name: Environment name ('puffer_dogfight') + args: Config args dict + player_path: Path to player model weights (or None if using load_id) + opponent_path: Path to opponent checkpoint file + load_id: Optional wandb/neptune run ID to load player model from + """ + from pufferlib.ocean.dogfight import binding + + # Force Serial backend with single env for eval + backend = args['vec'].get('backend', 'Serial') + if backend != 'PufferEnv': + backend = 'Serial' + args['vec'] = dict(backend=backend, num_envs=1) + args['env']['num_envs'] = 1 # Also set internal agent count to 1 + + # Enable eval spawn mode: truly random positions, angles, alternating advantages + args['env']['curriculum_randomize'] = 1 + print(f'[EVAL-SELFPLAY] Enabled curriculum_randomize for varied spawn positions') + + # Create environment + vecenv = pufferl.load_env(env_name, args) + + # Load player policy + # Set load_model_path/load_id so load_policy picks it up + if player_path: + args['load_model_path'] = player_path + if load_id: + args['load_id'] = load_id + + player_policy = pufferl.load_policy(args, vecenv, env_name) + player_policy.eval() + + # Create opponent policy (same architecture, different weights) + device = args['train']['device'] + opponent_policy = copy.deepcopy(player_policy) + + # Load opponent weights from checkpoint + checkpoint = torch.load(opponent_path, map_location=device) + if 'policy_state_dict' in checkpoint: + # Our checkpoint format + opponent_policy.load_state_dict(checkpoint['policy_state_dict']) + tag = checkpoint.get('tag', 'unknown') + step = checkpoint.get('step', 0) + print(f'[EVAL-SELFPLAY] Loaded opponent from {tag} (step {step}): {opponent_path}') + else: + # Raw state dict format + state_dict = {k.replace('module.', ''): v for k, v in checkpoint.items()} + opponent_policy.load_state_dict(state_dict) + print(f'[EVAL-SELFPLAY] Loaded opponent from: {opponent_path}') + + opponent_policy.eval() + for p in opponent_policy.parameters(): + p.requires_grad = False + + # Get driver env for rendering and C-level access + driver = vecenv.driver_env + num_agents = vecenv.num_envs # Actual batch size, not observation dimension + + # Enable opponent override in C code (so autopilot doesn't control opponent) + binding.vec_enable_opponent_override(driver.c_envs, 1) + print(f'[EVAL-SELFPLAY] Enabled opponent override (neural network opponent)') + + # Initialize LSTM 
states if using RNN + state_p = {} + state_o = {} + if args['train']['use_rnn']: + state_p = dict( + lstm_h=torch.zeros(num_agents, player_policy.hidden_size, device=device), + lstm_c=torch.zeros(num_agents, player_policy.hidden_size, device=device), + ) + state_o = dict( + lstm_h=torch.zeros(num_agents, opponent_policy.hidden_size, device=device), + lstm_c=torch.zeros(num_agents, opponent_policy.hidden_size, device=device), + ) + + # Reset environment with time-based seed for variety + seed = int(time.time_ns() % 2**31) + ob, info = vecenv.reset(seed=seed) + + frames = [] + episode_count = 0 + step_count = 0 + + # Get config values with sensible defaults + save_frames = args.get('save_frames', 0) + fps = args.get('fps', 15) + gif_path = args.get('gif_path', 'selfplay_eval.gif') + + print(f'[EVAL-SELFPLAY] Starting evaluation loop (press ESC to exit)') + print(f'[EVAL-SELFPLAY] Player: {player_path or load_id}') + print(f'[EVAL-SELFPLAY] Opponent: {opponent_path}') + + while True: + # Render + render = driver.render() + if len(frames) < save_frames: + frames.append(render) + + # Handle different render modes + if driver.render_mode == 'ansi': + print('\033[0;0H' + render + '\n') + time.sleep(1 / fps) + elif driver.render_mode == 'rgb_array': + # raylib handles its own display, but we need to throttle + time.sleep(1 / fps) + + # Get player observation + ob_tensor = torch.as_tensor(ob).to(device) + + # Get opponent observation from C - slice to match actual num_agents + # (C returns observations for all env slots, but we only use first num_agents) + ob_opponent_np = binding.vec_get_opponent_observations(driver.c_envs) + ob_opponent_np = ob_opponent_np[:num_agents] # Only take first num_agents rows + ob_opponent = torch.as_tensor(ob_opponent_np).to(device) + + # Handle NaN in opponent observations (can occur at boundaries) + if torch.isnan(ob_opponent).any(): + ob_opponent = torch.nan_to_num(ob_opponent, nan=0.0) + + with torch.no_grad(): + # Player forward pass + logits_p, value_p = player_policy.forward_eval(ob_tensor, state_p) + action_p, logprob_p, _ = pufferlib.pytorch.sample_logits(logits_p) + action_p_np = action_p.cpu().numpy().reshape(vecenv.action_space.shape) + + # Opponent forward pass + logits_o, value_o = opponent_policy.forward_eval(ob_opponent, state_o) + action_o, logprob_o, _ = pufferlib.pytorch.sample_logits(logits_o) + # Keep 2D shape (num_agents, action_dim) - don't reshape to lose batch dimension! 
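+            # The player action above is reshaped to vecenv.action_space.shape for
+            # vecenv.step(); the opponent action instead goes straight to the C binding,
+            # which takes one row of controls per agent, so the (num_agents, action_dim)
+            # batch from sample_logits is kept as-is (only cast to float32 and clipped below).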
+ action_o_np = action_o.cpu().numpy().astype(np.float32) + + # Clip actions for continuous action space + if isinstance(logits_p, torch.distributions.Normal): + action_p_np = np.clip(action_p_np, vecenv.action_space.low, vecenv.action_space.high) + action_o_np = np.clip(action_o_np, vecenv.action_space.low, vecenv.action_space.high) + + # Debug: print BOTH observations and actions every 60 steps + if step_count % 60 == 0: + obs_p = ob.flatten() if hasattr(ob, 'flatten') else ob[0] + obs_o = ob_opponent_np[0] if len(ob_opponent_np.shape) > 1 else ob_opponent_np + act_p = action_p_np.flatten() + act_o = action_o_np.flatten() + # Obs scheme 0: [fwd_spd, sideslip, climb, roll_r, pitch_r, yaw_r, + # aoa, altitude, energy, tgt_az, tgt_el, range, closure, + # E_adv, aspect, timer] + print(f'[DEBUG] step={step_count}') + print(f' PLAYER OBS: tgt_az={obs_p[9]:.2f} tgt_el={obs_p[10]:.2f} range={obs_p[11]:.2f} closure={obs_p[12]:.2f} aspect={obs_p[14]:.2f}') + print(f' PLAYER ACT: throttle={act_p[0]:.2f} elev={act_p[1]:.2f} ail={act_p[2]:.2f} rud={act_p[3]:.2f} trig={act_p[4]:.2f}') + print(f' OPPON OBS: tgt_az={obs_o[9]:.2f} tgt_el={obs_o[10]:.2f} range={obs_o[11]:.2f} closure={obs_o[12]:.2f} aspect={obs_o[14]:.2f}') + print(f' OPPON ACT: throttle={act_o[0]:.2f} elev={act_o[1]:.2f} ail={act_o[2]:.2f} rud={act_o[3]:.2f} trig={act_o[4]:.2f}') + + # Set opponent actions in C (before stepping) - already correct 2D shape (num_agents, 5) + binding.vec_set_opponent_actions(driver.c_envs, action_o_np) + + # Step environment with player action + ob, reward, terminated, truncated, info = vecenv.step(action_p_np) + step_count += 1 + + # Check for episode end + if terminated.any() or truncated.any(): + episode_count += 1 + if episode_count % 10 == 0: + print(f'[EVAL-SELFPLAY] Episode {episode_count} completed (step {step_count})') + + # Reset with time-based seed for variety in next episode + seed = int(time.time_ns() % 2**31) + ob, info = vecenv.reset(seed=seed) + + # Reset LSTM states for fresh episode + if args['train']['use_rnn']: + state_p = dict( + lstm_h=torch.zeros(num_agents, player_policy.hidden_size, device=device), + lstm_c=torch.zeros(num_agents, player_policy.hidden_size, device=device), + ) + state_o = dict( + lstm_h=torch.zeros(num_agents, opponent_policy.hidden_size, device=device), + lstm_c=torch.zeros(num_agents, opponent_policy.hidden_size, device=device), + ) + + # Save frames to gif if requested + if len(frames) > 0 and len(frames) == save_frames: + import imageio + imageio.mimsave(gif_path, frames, fps=fps, loop=0) + print(f'[EVAL-SELFPLAY] Saved {len(frames)} frames to {gif_path}') + frames = [] # Reset to allow more recording + + +def main(): + global DEBUG_LEVEL + env_name = 'puffer_dogfight' + + # Check for 'eval' subcommand + if len(sys.argv) > 1 and sys.argv[1] == 'eval': + sys.argv.pop(1) # Remove 'eval' from args + + # Parse eval-specific args + opponent_checkpoint = None + player_path = None + load_id = None + no_rnn = False + + new_argv = [sys.argv[0]] + i = 1 + while i < len(sys.argv): + arg = sys.argv[i] + + # --no-rnn flag + if arg == '--no-rnn': + no_rnn = True + i += 1 + continue + + # --opponent-checkpoint + if arg == '--opponent-checkpoint': + if i + 1 < len(sys.argv): + opponent_checkpoint = sys.argv[i + 1] + i += 2 + continue + elif arg.startswith('--opponent-checkpoint='): + opponent_checkpoint = arg.split('=', 1)[1] + i += 1 + continue + + # --load-model-path (for player) + if arg == '--load-model-path': + if i + 1 < len(sys.argv): + player_path = sys.argv[i + 1] + i 
+= 2 + continue + elif arg.startswith('--load-model-path='): + player_path = arg.split('=', 1)[1] + i += 1 + continue + + # --load-id (for player, from wandb/neptune) + if arg == '--load-id': + if i + 1 < len(sys.argv): + load_id = sys.argv[i + 1] + i += 2 + continue + elif arg.startswith('--load-id='): + load_id = arg.split('=', 1)[1] + i += 1 + continue + + new_argv.append(arg) + i += 1 + + sys.argv = new_argv + + # Validate required args + if opponent_checkpoint is None: + print('Error: --opponent-checkpoint is required for eval mode') + print('Usage: python train_dual_selfplay.py eval --opponent-checkpoint [--load-model-path | --load-id ]') + sys.exit(1) + + if player_path is None and load_id is None: + print('Error: Either --load-model-path or --load-id is required for eval mode') + print('Usage: python train_dual_selfplay.py eval --opponent-checkpoint [--load-model-path | --load-id ]') + sys.exit(1) + + # Load config + args = pufferl.load_config(env_name) + + # Override RNN if requested + if no_rnn: + args['rnn_name'] = None + args['train']['use_rnn'] = False + print(f'[EVAL-SELFPLAY] RNN disabled (--no-rnn flag)') + + print(f'[EVAL-SELFPLAY] Starting eval mode') + print(f'[EVAL-SELFPLAY] Player model: {player_path or load_id}') + print(f'[EVAL-SELFPLAY] Opponent checkpoint: {opponent_checkpoint}') + + # Run eval + eval_selfplay(env_name, args, player_path, opponent_checkpoint, load_id) + return + + # Extract custom args before pufferl parses + opponent_update_interval = DEFAULT_OPPONENT_UPDATE_INTERVAL + selfplay_min_stage = DEFAULT_SELFPLAY_MIN_STAGE + checkpoint_lag = DEFAULT_CHECKPOINT_LAG + perf_threshold = DEFAULT_PERF_THRESHOLD + min_steps_between_checkpoints = DEFAULT_MIN_STEPS_BETWEEN_CHECKPOINTS + max_checkpoints = DEFAULT_MAX_CHECKPOINTS + checkpoint_dir = None + + def parse_int_arg(args, i, arg_name): + """Parse --arg value or --arg=value format, return (value, new_i) or None.""" + if args[i] == f'--{arg_name}': + if i + 1 < len(args): + return int(args[i + 1]), i + 2 + elif args[i].startswith(f'--{arg_name}='): + return int(args[i].split('=', 1)[1]), i + 1 + return None + + def parse_float_arg(args, i, arg_name): + """Parse --arg value or --arg=value format, return (value, new_i) or None.""" + if args[i] == f'--{arg_name}': + if i + 1 < len(args): + return float(args[i + 1]), i + 2 + elif args[i].startswith(f'--{arg_name}='): + return float(args[i].split('=', 1)[1]), i + 1 + return None + + def parse_str_arg(args, i, arg_name): + """Parse --arg value or --arg=value format, return (value, new_i) or None.""" + if args[i] == f'--{arg_name}': + if i + 1 < len(args): + return args[i + 1], i + 2 + elif args[i].startswith(f'--{arg_name}='): + return args[i].split('=', 1)[1], i + 1 + return None + + new_argv = [] + i = 0 + while i < len(sys.argv): + # Legacy: opponent-update-interval (kept for compatibility) + result = parse_int_arg(sys.argv, i, 'opponent-update-interval') + if result: + opponent_update_interval = result[0] + i = result[1] + continue + + result = parse_int_arg(sys.argv, i, 'selfplay-min-stage') + if result: + selfplay_min_stage = result[0] + i = result[1] + continue + + # New checkpoint queue args + result = parse_int_arg(sys.argv, i, 'checkpoint-lag') + if result: + checkpoint_lag = result[0] + i = result[1] + continue + + result = parse_float_arg(sys.argv, i, 'perf-threshold') + if result: + perf_threshold = result[0] + i = result[1] + continue + + result = parse_int_arg(sys.argv, i, 'min-checkpoint-gap') + if result: + min_steps_between_checkpoints = 
result[0] + i = result[1] + continue + + result = parse_int_arg(sys.argv, i, 'max-checkpoints') + if result: + max_checkpoints = result[0] + i = result[1] + continue + + result = parse_str_arg(sys.argv, i, 'checkpoint-dir') + if result: + checkpoint_dir = result[0] + i = result[1] + continue + + if sys.argv[i] == '--debug': + DEBUG_LEVEL = 2 + i += 1 + continue + elif sys.argv[i] == '--debug-verbose': + DEBUG_LEVEL = 3 + i += 1 + continue + + new_argv.append(sys.argv[i]) + i += 1 + sys.argv = new_argv + + # Load standard dogfight config + args = pufferl.load_config(env_name) + + # NOTE: Dual self-play now works with Multiprocessing backend! + # Opponent observations and rewards are computed in C during c_step() + # and written to shared memory buffers, enabling parallel workers. + backend = args['vec'].get('backend', 'Multiprocessing') + print(f'[DUAL-SELFPLAY] Using {backend} backend with C-level opponent buffers') + + # Create environment using standard pufferl flow + vecenv = pufferl.load_env(env_name, args) + + # Create policy + policy = pufferl.load_policy(args, vecenv, env_name) + + # Create logger if requested + logger = None + run_id = None + if args['neptune']: + logger = pufferl.NeptuneLogger(args) + elif args['wandb']: + logger = pufferl.WandbLogger(args) + # Use wandb run ID for checkpoint directory + if hasattr(logger, 'run') and logger.run: + run_id = logger.run.id + + # Create dual-perspective trainer with checkpoint queue + train_config = {**args['train'], 'env': env_name} + trainer = DualPerspectiveTrainer( + train_config, vecenv, policy, logger, + opponent_update_interval=opponent_update_interval, + selfplay_min_stage=selfplay_min_stage, + checkpoint_lag=checkpoint_lag, + perf_threshold=perf_threshold, + min_steps_between_checkpoints=min_steps_between_checkpoints, + max_checkpoints=max_checkpoints, + checkpoint_dir=checkpoint_dir, + run_id=run_id + ) + + print(f'[DUAL-SELFPLAY] Starting training with checkpoint queue') + print(f'[DUAL-SELFPLAY] Min stage for self-play: {selfplay_min_stage}') + print(f'[DUAL-SELFPLAY] Checkpoint lag: {checkpoint_lag} (opponent is {checkpoint_lag} checkpoint(s) behind)') + print(f'[DUAL-SELFPLAY] Perf threshold: {perf_threshold} (save checkpoint when perf >= this)') + print(f'[DUAL-SELFPLAY] Min steps between checkpoints: {min_steps_between_checkpoints}') + + # Training loop + while trainer.global_step < train_config['total_timesteps']: + if train_config['device'] == 'cuda': + torch.compiler.cudagraph_mark_step_begin() + trainer.evaluate() + if train_config['device'] == 'cuda': + torch.compiler.cudagraph_mark_step_begin() + logs = trainer.train() + + # Log dual self-play status periodically + if trainer.epoch % 100 == 0 and trainer.epoch > 0: + mode = "DUAL" if trainer.use_dual_selfplay else "CURRICULUM" + queue_len = len(trainer.checkpoint_queue) + opponent_entry = trainer.checkpoint_queue.get_opponent_entry(trainer.checkpoint_lag) + opponent_tag = opponent_entry.tag if opponent_entry else "none" + print(f'[DUAL-SELFPLAY] Mode: {mode}, Steps: {trainer.global_step}, ' + f'Queue: {queue_len} checkpoints, Opponent: {opponent_tag}') + + # Cleanup + model_path = trainer.close() + if logger: + logger.close(model_path) + + print(f'[DUAL-SELFPLAY] Training complete') + + +if __name__ == '__main__': + main() diff --git a/pufferlib/ocean/dogfight/train_selfplay.py b/pufferlib/ocean/dogfight/train_selfplay.py new file mode 100644 index 000000000..8f7d44f4f --- /dev/null +++ b/pufferlib/ocean/dogfight/train_selfplay.py @@ -0,0 +1,122 @@ 
+#!/usr/bin/env python +"""Self-play training with PolicyPool for Dogfight. + +This script wires up the PolicyPool for self-play opponent selection. +When the curriculum system achieves mastery of a stage, checkpoints are +automatically saved to the pool for future opponent selection. + +Usage: + # Basic self-play training + python pufferlib/ocean/dogfight/train_selfplay.py + + # With wandb logging + python pufferlib/ocean/dogfight/train_selfplay.py --wandb --wandb-project dogfight-selfplay + + # Resume from pool with existing checkpoints + python pufferlib/ocean/dogfight/train_selfplay.py --pool-dir experiments/dogfight_pool + + # Start with empty pool (default) + python pufferlib/ocean/dogfight/train_selfplay.py +""" +import os +import sys +import argparse + +import pufferlib +import pufferlib.vector +from pufferlib import pufferl +from pufferlib.policy_pool import PolicyPool, setup_pool_callback +from pufferlib.ocean.dogfight.dogfight import Dogfight + + +# Default pool directory +DEFAULT_POOL_DIR = 'experiments/dogfight_pool' + + +def main(): + env_name = 'puffer_dogfight' + + # Extract --pool-dir from sys.argv before pufferl parses + pool_dir = DEFAULT_POOL_DIR + new_argv = [] + i = 0 + while i < len(sys.argv): + if sys.argv[i] == '--pool-dir': + if i + 1 < len(sys.argv): + pool_dir = sys.argv[i + 1] + i += 2 + continue + elif sys.argv[i].startswith('--pool-dir='): + pool_dir = sys.argv[i].split('=', 1)[1] + i += 1 + continue + new_argv.append(sys.argv[i]) + i += 1 + sys.argv = new_argv + + # Load standard dogfight config (now without --pool-dir) + args = pufferl.load_config(env_name) + + # Extract obs_scheme from env config + obs_scheme = args['env'].get('obs_scheme', 1) + + # Create pool directory if needed + os.makedirs(pool_dir, exist_ok=True) + + # Create PolicyPool with matching obs_scheme + pool = PolicyPool(pool_dir, obs_scheme=obs_scheme) + print(f'[SELFPLAY] PolicyPool created: {pool_dir} (obs_scheme={obs_scheme}, {len(pool)} entries)') + + # Add pool to env kwargs + args['env']['policy_pool'] = pool + + # Create environment using standard pufferl flow + vecenv = pufferl.load_env(env_name, args) + + # Create policy + policy = pufferl.load_policy(args, vecenv, env_name) + + # Create logger if requested + logger = None + if args['neptune']: + logger = pufferl.NeptuneLogger(args) + elif args['wandb']: + logger = pufferl.WandbLogger(args) + + # Create trainer + train_config = {**args['train'], 'env': env_name} + trainer = pufferl.PuffeRL(train_config, vecenv, policy, logger) + + # Wire up pool callback (checkpoints saved on mastery) + # vecenv.driver_env is the unwrapped Dogfight instance + setup_pool_callback(vecenv.driver_env, trainer, pool) + print(f'[SELFPLAY] Pool callback wired to trainer') + + # Standard training loop + while trainer.global_step < train_config['total_timesteps']: + if train_config['device'] == 'cuda': + import torch + torch.compiler.cudagraph_mark_step_begin() + trainer.evaluate() + if train_config['device'] == 'cuda': + import torch + torch.compiler.cudagraph_mark_step_begin() + trainer.train() + + # Log pool status periodically + if trainer.epoch % 100 == 0 and trainer.epoch > 0: + pool_size = len(pool) + if pool_size > 0: + print(f'[SELFPLAY] Pool size: {pool_size} checkpoints') + + # Cleanup + model_path = trainer.close() + if logger: + logger.close(model_path) + + print(f'[SELFPLAY] Training complete. 
Final pool size: {len(pool)}') + print(f'[SELFPLAY] Pool directory: {pool_dir}') + + +if __name__ == '__main__': + main() diff --git a/pufferlib/ocean/environment.py b/pufferlib/ocean/environment.py index 93df76506..70cd2d988 100644 --- a/pufferlib/ocean/environment.py +++ b/pufferlib/ocean/environment.py @@ -122,6 +122,7 @@ def make_multiagent(buf=None, **kwargs): 'blastar': 'Blastar', 'convert': 'Convert', 'convert_circle': 'ConvertCircle', + 'dogfight': 'Dogfight', 'pong': 'Pong', 'freeway': 'Freeway', 'enduro': 'Enduro', diff --git a/pufferlib/ocean/torch.py b/pufferlib/ocean/torch.py index 8cf4ffe7d..5d9d0e4d9 100644 --- a/pufferlib/ocean/torch.py +++ b/pufferlib/ocean/torch.py @@ -299,7 +299,7 @@ def decode_actions(self, flat_hidden, state=None): value = self.value_fn(flat_hidden) if self.is_continuous: mean = self.decoder_mean(flat_hidden) - logstd = self.decoder_logstd.expand_as(mean) + logstd = self.decoder_logstd.expand_as(mean).clamp(min=-20, max=2) std = torch.exp(logstd) probs = torch.distributions.Normal(mean, std) batch = flat_hidden.shape[0] @@ -433,7 +433,7 @@ def decode_actions(self, flat_hidden): value = self.value_fn(flat_hidden) if self.is_continuous: mean = self.decoder_mean(flat_hidden) - logstd = self.decoder_logstd.expand_as(mean) + logstd = self.decoder_logstd.expand_as(mean).clamp(min=-20, max=2) std = torch.exp(logstd) probs = torch.distributions.Normal(mean, std) batch = flat_hidden.shape[0] @@ -893,7 +893,7 @@ def decode_actions(self, hidden): logits = self.decoder(hidden).split(self.action_nvec, dim=1) elif self.is_continuous: mean = self.decoder_mean(hidden) - logstd = self.decoder_logstd.expand_as(mean) + logstd = self.decoder_logstd.expand_as(mean).clamp(min=-20, max=2) std = torch.exp(logstd) logits = torch.distributions.Normal(mean, std) else: diff --git a/pufferlib/pufferl.py b/pufferlib/pufferl.py index 3972e722f..f1f0bf536 100644 --- a/pufferlib/pufferl.py +++ b/pufferlib/pufferl.py @@ -939,6 +939,7 @@ def train(env_name, args=None, vecenv=None, policy=None, logger=None, should_sto logger = WandbLogger(args) train_config = { **args['train'], 'env': env_name } + pufferl = PuffeRL(train_config, vecenv, policy, logger) all_logs = [] @@ -1041,6 +1042,11 @@ def sweep(args=None, env_name=None): raise pufferlib.APIUsageError('Sweeps require either wandb or neptune') method = args['sweep'].pop('method') + + project = args.get('wandb_project', args.get('neptune_project', 'sweep')) + args['sweep'].setdefault('state_file', f'{project}_sweep.json') + args['sweep'].setdefault('override_file', f'{project}_override.json') + try: sweep_cls = getattr(pufferlib.sweep, method) except: diff --git a/pufferlib/sweep.py b/pufferlib/sweep.py index 41401b6f6..5767e7718 100644 --- a/pufferlib/sweep.py +++ b/pufferlib/sweep.py @@ -1,6 +1,8 @@ import random import math import warnings +import os +import json from copy import deepcopy from contextlib import contextmanager @@ -129,7 +131,8 @@ def unnormalize(self, value): def _params_from_puffer_sweep(sweep_config): param_spaces = {} for name, param in sweep_config.items(): - if name in ('method', 'metric', 'goal', 'downsample', 'use_gpu', 'prune_pareto'): + if name in ('method', 'metric', 'goal', 'downsample', 'use_gpu', 'prune_pareto', + 'state_file', 'override_file'): continue assert isinstance(param, dict) @@ -467,6 +470,9 @@ def __init__(self, self.gp_max_obs = gp_max_obs # train time bumps after 800? 
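+        # Paths for the crash-recovery and override machinery below: state_file is
+        # rewritten whenever observe() records a result, and override_file is polled
+        # by _check_override() at the start of each suggest() call.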
self.infer_batch_size = infer_batch_size + self.state_file = sweep_config.get('state_file', 'sweep_state.json') + self.override_file = sweep_config.get('override_file', 'override.json') + # Use 64 bit for GP regression with default_tensor_dtype(torch.float64): # Params taken from HEBO: https://arxiv.org/abs/2012.03826 @@ -493,6 +499,117 @@ def __init__(self, self.gp_cost_buffer = torch.empty(self.gp_max_obs, device=self.device) self.infer_batch_buffer = torch.empty(self.infer_batch_size, self.hyperparameters.num, device=self.device) + self._load_state_if_exists() + + @staticmethod + def _json_default(obj): + """JSON serializer for numpy types.""" + if isinstance(obj, np.ndarray): + return obj.tolist() + if isinstance(obj, (np.floating, np.integer)): + return obj.item() + raise TypeError(f'Not JSON serializable: {type(obj)}') + + def _save_state(self): + """Save sweep state to JSON for crash recovery.""" + state = { + 'suggestion_idx': self.suggestion_idx, + 'success_observations': self.success_observations, + 'failure_observations': self.failure_observations, + 'min_score': self.min_score if self.min_score != math.inf else None, + 'max_score': self.max_score if self.max_score != -math.inf else None, + 'log_c_min': self.log_c_min if self.log_c_min != math.inf else None, + 'log_c_max': self.log_c_max if self.log_c_max != -math.inf else None, + } + tmp = f'{self.state_file}.tmp' + try: + with open(tmp, 'w') as f: + json.dump(state, f, indent=2, default=self._json_default) + os.replace(tmp, self.state_file) + except OSError as e: + print(f'[Protein] Failed to save state: {e}') + if os.path.exists(tmp): + os.remove(tmp) + + def _load_state_if_exists(self): + """Load state from previous run if exists (crash recovery).""" + tmp = f'{self.state_file}.tmp' + if os.path.exists(tmp): + os.remove(tmp) + if not os.path.exists(self.state_file): + return + try: + with open(self.state_file) as f: + state = json.load(f) + self.suggestion_idx = state.get('suggestion_idx', 0) + self.success_observations = state.get('success_observations', []) + self.failure_observations = state.get('failure_observations', []) + if state.get('min_score') is not None: + self.min_score = state['min_score'] + if state.get('max_score') is not None: + self.max_score = state['max_score'] + if state.get('log_c_min') is not None: + self.log_c_min = state['log_c_min'] + if state.get('log_c_max') is not None: + self.log_c_max = state['log_c_max'] + for obs in self.success_observations + self.failure_observations: + if isinstance(obs['input'], list): + obs['input'] = np.array(obs['input']) + print(f'[Protein] Resumed from {self.state_file}: {len(self.success_observations)} obs, idx={self.suggestion_idx}') + except (json.JSONDecodeError, KeyError, FileNotFoundError, OSError) as e: + print(f'[Protein] Failed to load state: {e}') + + def _check_override(self): + """Check for override. 
Returns None, 'skip', or params dict.""" + if not os.path.exists(self.override_file): + return None + tmp = f'{self.override_file}.tmp' + try: + with open(self.override_file) as f: + data = json.load(f) + if 'suggestions' not in data or not data['suggestions']: + os.remove(self.override_file) + return None + suggestion = data['suggestions'].pop(0) + if data['suggestions']: + with open(tmp, 'w') as f: + json.dump(data, f, indent=2) + os.replace(tmp, self.override_file) + else: + os.remove(self.override_file) + + # Check for skip flag (spacer runs) + if suggestion.get('skip', False): + reason = suggestion.get('reason', 'Spacer - skip override') + print(f'[Protein] SKIP: {reason}') + return 'skip' + + reason = suggestion.get('reason', 'No reason provided') + print(f'[Protein] OVERRIDE: {reason}') + return suggestion.get('params', suggestion) + except (json.JSONDecodeError, KeyError) as e: + print(f'[Protein] Invalid override file: {e}') + if os.path.exists(self.override_file): + os.remove(self.override_file) + return None + except OSError as e: + print(f'[Protein] Failed to update override file: {e}') + if os.path.exists(tmp): + os.remove(tmp) + return None + + def _apply_params_to_fill(self, fill, params): + """Apply param dict to fill in place. Modifies fill directly.""" + for key, value in pufferlib.unroll_nested_dict(params): + parts = key.split('/') + try: + target = fill + for part in parts[:-1]: + target = target[part] + target[parts[-1]] = value + except KeyError: + print(f'[Protein] Override key not found: {key}') + def _filter_near_duplicates(self, inputs, duplicate_threshold=EPSILON): if len(inputs) < 2: return np.arange(len(inputs)) @@ -571,27 +688,9 @@ def _train_gp_models(self): return score_loss, cost_loss - def suggest(self, fill): + def _gp_suggest(self, fill): + """Generate suggestion using Gaussian Process optimization.""" info = {} - self.suggestion_idx += 1 - if len(self.success_observations) == 0 and self.seed_with_search_center: - suggestion = self.hyperparameters.search_centers - return self.hyperparameters.to_dict(suggestion, fill), info - - elif len(self.success_observations) < self.num_random_samples: - # Suggest the next point in the Sobol sequence - zero_one = self.sobol.random(1)[0] - suggestion = 2*zero_one - 1 # Scale from [0, 1) to [-1, 1) - cost_suggestion = self.cost_random_suggestion + 0.1 * np.random.randn() - suggestion[self.cost_param_idx] = np.clip(cost_suggestion, -1, 1) # limit the cost - return self.hyperparameters.to_dict(suggestion, fill), info - - elif self.resample_frequency and self.suggestion_idx % self.resample_frequency == 0: - candidates, _ = pareto_points(self.success_observations) - suggestions = np.stack([e['input'] for e in candidates]) - best_idx = np.random.randint(0, len(candidates)) - best = suggestions[best_idx] - return self.hyperparameters.to_dict(best, fill), info score_loss, cost_loss = self._train_gp_models() @@ -599,7 +698,7 @@ def suggest(self, fill): print(f'Resetting GP optimizers at suggestion {self.suggestion_idx}') self.score_opt = torch.optim.Adam(self.gp_score.parameters(), lr=self.gp_learning_rate, amsgrad=True) self.cost_opt = torch.optim.Adam(self.gp_cost.parameters(), lr=self.gp_learning_rate, amsgrad=True) - + candidates, pareto_idxs = pareto_points(self.success_observations) if self.prune_pareto: @@ -614,13 +713,15 @@ def suggest(self, fill): suggestions = suggestions[dedup_indices] if len(suggestions) == 0: - return self.suggest(fill) # Fallback to random if all suggestions are filtered + # Fallback to search 
center if all suggestions are filtered + suggestion = self.hyperparameters.search_centers + return self.hyperparameters.to_dict(suggestion, fill), info ### Predict scores and costs # Batch predictions to avoid GPU OOM for large number of suggestions gp_y_norm_list, gp_log_c_norm_list = [], [] - with torch.no_grad(), gpytorch.settings.fast_pred_var(), warnings.catch_warnings(): + with torch.no_grad(), gpytorch.settings.fast_pred_var(), gpytorch.settings.cholesky_jitter(1e-4), warnings.catch_warnings(): warnings.simplefilter("ignore", gpytorch.utils.warnings.NumericalWarning) # Create a reusable buffer on the device to avoid allocating a huge tensor @@ -639,7 +740,8 @@ def suggest(self, fill): except RuntimeError: # Handle numerical errors during GP prediction - pred_y_mean, pred_c_mean = torch.zeros(current_batch_size) + pred_y_mean = torch.zeros(current_batch_size) + pred_c_mean = torch.zeros(current_batch_size) gp_y_norm_list.append(pred_y_mean.cpu()) gp_log_c_norm_list.append(pred_c_mean.cpu()) @@ -685,6 +787,45 @@ def suggest(self, fill): best = suggestions[best_idx] return self.hyperparameters.to_dict(best, fill), info + def suggest(self, fill): + info = {} + self.suggestion_idx += 1 + override = self._check_override() + + # Always generate a suggestion (GP, Sobol, or search center) + if len(self.success_observations) == 0 and self.seed_with_search_center: + suggestion = self.hyperparameters.search_centers + result = self.hyperparameters.to_dict(suggestion, fill) + + elif len(self.success_observations) < self.num_random_samples: + # Sobol sequence for early exploration + zero_one = self.sobol.random(1)[0] + suggestion = 2*zero_one - 1 # Scale from [0, 1) to [-1, 1) + cost_suggestion = self.cost_random_suggestion + 0.1 * np.random.randn() + suggestion[self.cost_param_idx] = np.clip(cost_suggestion, -1, 1) # limit the cost + result = self.hyperparameters.to_dict(suggestion, fill) + + elif self.resample_frequency and self.suggestion_idx % self.resample_frequency == 0: + # Resample from pareto front + candidates, _ = pareto_points(self.success_observations) + suggestions = np.stack([e['input'] for e in candidates]) + best_idx = np.random.randint(0, len(candidates)) + best = suggestions[best_idx] + result = self.hyperparameters.to_dict(best, fill) + + else: + # Full GP suggestion + result, info = self._gp_suggest(fill) + + # Apply override ON TOP of generated suggestion + if override == 'skip': + info['skip'] = True + elif override: + self._apply_params_to_fill(result, override) + info['override'] = True + + return result, info + def observe(self, hypers, score, cost, is_failure=False): params = self.hyperparameters.from_dict(hypers) new_observation = dict( @@ -697,6 +838,7 @@ def observe(self, hypers, score, cost, is_failure=False): if is_failure or not np.isfinite(score) or np.isnan(score): new_observation['is_failure'] = True self.failure_observations.append(new_observation) + self._save_state() return if self.success_observations: @@ -705,6 +847,7 @@ def observe(self, hypers, score, cost, is_failure=False): same = np.where(dist < EPSILON)[0] if len(same) > 0: self.success_observations[same[0]] = new_observation + self._save_state() return # Ignore obs that are below the minimum cost @@ -712,3 +855,83 @@ def observe(self, hypers, score, cost, is_failure=False): return self.success_observations.append(new_observation) + self._save_state() + + +def read_sweep_results(state_file, sweep_config, sort_by='score'): + """ + Load sweep results as user/agent-readable dicts. 
+ + Args: + state_file: Path to {project}_sweep.json + sweep_config: The 'sweep' section from load_config() + sort_by: 'score' (descending), 'cost' (ascending), or None + + Returns: + List of dicts: [{'params': {...}, 'score': float, 'cost': float}, ...] + """ + with open(state_file) as f: + state = json.load(f) + + hyperparams = Hyperparameters(sweep_config, verbose=False) + + results = [] + for obs in state.get('success_observations', []): + input_vec = np.array(obs['input']) + + if len(input_vec) != hyperparams.num: + raise ValueError( + f"State file has {len(input_vec)} dimensions but config has {hyperparams.num}. " + f"Config may have changed since sweep started." + ) + + params = hyperparams.to_dict(input_vec) + flat_params = dict(pufferlib.unroll_nested_dict(params)) + + results.append({ + 'params': flat_params, + 'score': obs['output'], + 'cost': obs['cost'], + }) + + if sort_by == 'score': + results.sort(key=lambda x: x['score'], reverse=True) + elif sort_by == 'cost': + results.sort(key=lambda x: x['cost']) + + return results + + +def create_override(override_file, suggestions, reason=None): + """ + Inject hyperparams into next sweep run. Use real values, not normalized. + + Args: + override_file: Path to write + suggestions: List of dicts, e.g. [{'train/learning_rate': 0.001}] + reason: List of strings (same length as suggestions), or None + """ + if reason is None: + reason = [None] * len(suggestions) + if len(reason) != len(suggestions): + raise ValueError(f"Got {len(suggestions)} suggestions but {len(reason)} reasons") + + data = { + 'suggestions': [ + { + 'params': s, + 'reason': r or 'Programmatic override' + } + for s, r in zip(suggestions, reason) + ] + } + + tmp = f'{override_file}.tmp' + try: + with open(tmp, 'w') as f: + json.dump(data, f, indent=2) + os.replace(tmp, override_file) + except OSError: + if os.path.exists(tmp): + os.remove(tmp) + raise diff --git a/pufferlib/vector.py b/pufferlib/vector.py index 78614f4d6..f545dbc30 100644 --- a/pufferlib/vector.py +++ b/pufferlib/vector.py @@ -82,6 +82,11 @@ def __init__(self, env_creators, env_args, env_kwargs, num_envs, buf=None, seed= masks=self.masks[ptr:end], actions=self.actions[ptr:end] ) + # Include opponent buffers if they exist (for dual self-play) + if buf is not None and 'opponent_observations' in buf: + buf_i['opponent_observations'] = buf['opponent_observations'][ptr:end] + buf_i['opponent_rewards'] = buf['opponent_rewards'][ptr:end] + buf_i['opponent_actions'] = buf['opponent_actions'][ptr:end] ptr = end seed_i = seed + i if seed is not None else None env = env_creators[i](*env_args[i], buf=buf_i, seed=seed_i, **env_kwargs[i]) @@ -187,6 +192,18 @@ def _worker_process(env_creators, env_args, env_kwargs, obs_shape, obs_dtype, at ) buf['masks'][:] = True + # Opponent perspective buffers (for dual self-play) + # These are optional - envs that support opponent buffers will use them + if 'opponent_observations' in shm: + buf['opponent_observations'] = np.ndarray((*shape, *obs_shape), + dtype=obs_dtype, buffer=shm['opponent_observations'])[worker_idx] + buf['opponent_rewards'] = np.ndarray(shape, dtype=np.float32, + buffer=shm['opponent_rewards'])[worker_idx] + buf['opponent_actions'] = np.ndarray((*shape, *atn_shape), + dtype=atn_dtype, buffer=shm['opponent_actions'])[worker_idx] + buf['selfplay_active'] = np.ndarray(1, dtype=np.int8, + buffer=shm['selfplay_active']) # Shared flag, not per-worker + if is_native and num_envs == 1: envs = env_creators[0](*env_args[0], **env_kwargs[0], buf=buf, seed=seed) else: @@ 
-306,6 +323,14 @@ def __init__(self, env_creators, env_args, env_kwargs, masks=RawArray('b', num_agents), semaphores=RawArray('c', num_workers), notify=RawArray('b', num_workers), + # Opponent perspective buffers (for dual self-play) + # opponent_observations/rewards: Written by C code during step if env supports opponent buffers + # opponent_actions: Written by main process, read by workers before step + # selfplay_active: Flag set by main process when self-play mode is enabled + opponent_observations=RawArray(obs_ctype, num_agents * int(np.prod(obs_shape))), + opponent_rewards=RawArray('f', num_agents), + opponent_actions=RawArray(atn_ctype, num_agents * int(np.prod(atn_shape))), + selfplay_active=RawArray('b', 1), # Single flag: 0=curriculum, 1=selfplay ) shape = (num_workers, agents_per_worker) self.obs_batch_shape = (self.agents_per_batch, *obs_shape) @@ -321,6 +346,13 @@ def __init__(self, env_creators, env_args, env_kwargs, masks=np.ndarray(shape, dtype=bool, buffer=self.shm['masks']), semaphores=np.ndarray(num_workers, dtype=np.uint8, buffer=self.shm['semaphores']), notify=np.ndarray(num_workers, dtype=bool, buffer=self.shm['notify']), + # Opponent perspective buffers (for dual self-play) + opponent_observations=np.ndarray((*shape, *obs_shape), + dtype=obs_dtype, buffer=self.shm['opponent_observations']), + opponent_rewards=np.ndarray(shape, dtype=np.float32, buffer=self.shm['opponent_rewards']), + opponent_actions=np.ndarray((*shape, *atn_shape), + dtype=atn_dtype, buffer=self.shm['opponent_actions']), + selfplay_active=np.ndarray(1, dtype=np.int8, buffer=self.shm['selfplay_active']), ) self.buf['semaphores'][:] = MAIN diff --git a/setup.py b/setup.py index 552cb00e8..9fd3bf9fc 100644 --- a/setup.py +++ b/setup.py @@ -189,14 +189,15 @@ def run(self): # Find C extensions c_extensions = [] if not NO_OCEAN: - c_extension_paths = glob.glob('pufferlib/ocean/**/binding.c', recursive=True) + #c_extension_paths = glob.glob('pufferlib/ocean/**/binding.c', recursive=True) + c_extension_paths = ['pufferlib/ocean/dogfight/binding.c'] c_extensions = [ Extension( path.rstrip('.c').replace('/', '.'), sources=[path], **extension_kwargs, ) - for path in c_extension_paths if 'matsci' not in path + for path in c_extension_paths# if 'matsci' not in path ] c_extension_paths = [os.path.join(*path.split('/')[:-1]) for path in c_extension_paths]
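The sweep-state and override helpers added to pufferlib/sweep.py are meant to be driven from outside a running sweep. Below is a minimal, illustrative sketch of that workflow; it assumes the {project}_sweep.json / {project}_override.json defaults wired up in pufferl.sweep(), and the env name and hyperparameter key are placeholders taken from the helpers' docstrings rather than a tested configuration.

# Queue one hand-picked suggestion for the next sweep run, then read results back.
import pufferlib.sweep as sweep
from pufferlib import pufferl

args = pufferl.load_config('puffer_dogfight')
project = args.get('wandb_project', args.get('neptune_project', 'sweep'))

# Values are real (unnormalized); keys use the 'section/param' form from the docstring.
sweep.create_override(
    f'{project}_override.json',
    suggestions=[{'train/learning_rate': 1e-3}],
    reason=['Manual probe of a promising learning rate'],
)

# After the sweep has produced observations, load them back as plain dicts.
results = sweep.read_sweep_results(f'{project}_sweep.json', args['sweep'], sort_by='score')
for entry in results[:5]:
    print(entry['score'], entry['cost'], entry['params'])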