diff --git a/pufferlib/ocean/benchmark/scripts/run_jerk_ablation.py b/pufferlib/ocean/benchmark/scripts/run_jerk_ablation.py
new file mode 100755
index 000000000..55f91c886
--- /dev/null
+++ b/pufferlib/ocean/benchmark/scripts/run_jerk_ablation.py
@@ -0,0 +1,481 @@
+#!/usr/bin/env python3
+"""
+Ablation study for JERK dynamics discretization.
+
+Tests how trajectory approximation quality improves with finer action discretization.
+Compares baseline (12 actions) vs expanded (40 actions) configurations.
+"""
+
+import argparse
+import os
+import json
+from tqdm import tqdm
+import numpy as np
+import matplotlib.pyplot as plt
+
+from pufferlib.ocean.benchmark.trajectory_approximation import (
+    JerkDynamics,
+    load_trajectories_from_directory,
+    find_valid_segments,
+    extract_segment,
+    initialize_jerk_state,
+    approximate_trajectory,
+    compute_all_metrics,
+    aggregate_metrics,
+    format_metrics_summary,
+    compute_kinematic_features,
+    aggregate_features_across_segments,
+    compute_kinematic_realism_score,
+    format_kinematic_summary,
+    plot_jerk_ablation_trajectory,
+)
+
+
+# Define JERK configurations to test
+JERK_CONFIGS = {
+    'baseline': {
+        'jerk_long': [-15.0, -4.0, 0.0, 4.0],
+        'jerk_lat': [-4.0, 0.0, 4.0],
+        'description': 'Baseline (4x3=12 actions)',
+    },
+    'expanded': {
+        'jerk_long': [-15.0, -6.0, -4.0, -2.0, 0.0, 2.0, 4.0, 6.0],
+        'jerk_lat': [-4.0, -2.0, 0.0, 2.0, 4.0],
+        'description': 'Expanded (8x5=40 actions)',
+    },
+}
+
+
+def parse_weights(weight_string):
+    """Parse comma-separated weights string."""
+    try:
+        pos, vel, heading = map(float, weight_string.split(','))
+        return {'pos': pos, 'vel': vel, 'heading': heading}
+    except ValueError as e:
+        raise ValueError(f"Invalid weights format: {weight_string}. Expected: pos,vel,heading") from e
+
+
+def run_configuration(config_name, config, trajectories, args, weights, store_trajectories=False):
+    """
+    Run trajectory approximation for a single JERK configuration.
+ + Args: + store_trajectories: If True, store full trajectory data for visualization + + Returns: + dict with metrics_list, segments_list, kinematic scores, and optionally trajectory data + """ + print(f"\n{'='*60}") + print(f"Running: {config['description']}") + print(f" Longitudinal jerk: {config['jerk_long']}") + print(f" Lateral jerk: {config['jerk_lat']}") + print(f"{'='*60}") + + # Initialize dynamics model + jerk_dynamics = JerkDynamics( + dt=args.dt, + jerk_long=config['jerk_long'], + jerk_lat=config['jerk_lat'] + ) + print(f" Total actions: {jerk_dynamics.num_actions}") + + # Process trajectories + metrics_list = [] + segments_list = [] + trajectory_data = [] if store_trajectories else None + skipped_stationary = 0 + processed_segments = 0 + + for traj_idx, traj in enumerate(tqdm(trajectories, desc=f"{config_name}")): + # Find valid segments + segments = find_valid_segments(traj['valid'], min_length=args.min_segment_length) + + if len(segments) == 0: + continue + + for seg_idx, (start, end) in enumerate(segments): + # Extract segment + segment = extract_segment(traj, start, end) + + # Filter stationary vehicles + position_change = np.sqrt( + (segment['x'][-1] - segment['x'][0])**2 + + (segment['y'][-1] - segment['y'][0])**2 + ) + speed = np.sqrt(segment['vx']**2 + segment['vy']**2) + max_speed = np.max(speed) + mean_speed = np.mean(speed) + + if position_change < 0.5 and max_speed < 0.5: + skipped_stationary += 1 + continue + + if mean_speed < args.min_mean_speed: + skipped_stationary += 1 + continue + + processed_segments += 1 + + # Initialize state + jerk_state = initialize_jerk_state(segment, dt=args.dt) + + # Approximate trajectory + jerk_traj, jerk_actions, jerk_errors = approximate_trajectory( + jerk_dynamics, segment, jerk_state, weights=weights + ) + jerk_traj['valid'] = segment['valid'].copy() + + # Compute metrics + metrics = compute_all_metrics(jerk_traj, segment, segment['valid']) + metrics_list.append(metrics) + + # Store for kinematic realism + segments_list.append({ + 'ground_truth': segment, + 'predicted': jerk_traj, + }) + + # Store for trajectory visualization + if store_trajectories: + trajectory_data.append({ + 'scenario': traj['scenario'], + 'object_id': traj['object_id'], + 'traj_idx': traj_idx, + 'seg_idx': seg_idx, + 'segment_range': (start, end), + 'ground_truth': segment, + 'predicted': jerk_traj, + }) + + print(f"\n Processed: {processed_segments} segments") + print(f" Skipped: {skipped_stationary} stationary/slow segments") + + if len(metrics_list) == 0: + return None + + # Aggregate metrics + aggregated = aggregate_metrics(metrics_list, model_name=config['description']) + + # Compute kinematic realism + gt_features_list = [] + pred_features_list = [] + + for seg in segments_list: + gt_feat = compute_kinematic_features( + seg['ground_truth']['x'], + seg['ground_truth']['y'], + seg['ground_truth']['heading'], + seg['ground_truth']['valid'], + dt=args.dt + ) + gt_features_list.append(gt_feat) + + pred_feat = compute_kinematic_features( + seg['predicted']['x'], + seg['predicted']['y'], + seg['predicted']['heading'], + seg['predicted']['valid'], + dt=args.dt + ) + pred_features_list.append(pred_feat) + + gt_population = aggregate_features_across_segments(gt_features_list) + pred_population = aggregate_features_across_segments(pred_features_list) + kinematic_score = compute_kinematic_realism_score(gt_population, pred_population) + + # Print summary + print(format_metrics_summary(aggregated)) + print(format_kinematic_summary(kinematic_score, 
config['description'])) + + result = { + 'metrics_list': metrics_list, + 'aggregated': aggregated, + 'kinematic_score': kinematic_score, + 'num_segments': len(metrics_list), + } + + if store_trajectories: + result['trajectory_data'] = trajectory_data + + return result + + +def plot_ablation_comparison(results, output_path): + """ + Create histogram comparison for ablation study. + Shows baseline vs expanded error distributions for position, speed, and heading. + + Args: + results: dict mapping config_name -> result dict + output_path: path to save figure + """ + fig, axes = plt.subplots(1, 3, figsize=(18, 5)) + + # Extract error distributions from metrics_list + baseline_metrics = results['baseline']['metrics_list'] + expanded_metrics = results['expanded']['metrics_list'] + + # Extract position errors (ADE) + baseline_ade = [m['position']['ade'] for m in baseline_metrics] + expanded_ade = [m['position']['ade'] for m in expanded_metrics] + + # Extract speed errors (MAE) + baseline_speed = [m['velocity']['speed_mae'] for m in baseline_metrics] + expanded_speed = [m['velocity']['speed_mae'] for m in expanded_metrics] + + # Extract heading errors (MAE) + baseline_heading = [m['heading']['mae'] for m in baseline_metrics] + expanded_heading = [m['heading']['mae'] for m in expanded_metrics] + + # Panel 1: Position error histogram + axes[0].hist(baseline_ade, bins=30, alpha=0.5, color='red', label='Baseline (12 actions)') + axes[0].hist(expanded_ade, bins=30, alpha=0.5, color='blue', label='Expanded (40 actions)') + axes[0].set_xlabel('ADE (m)', fontweight='bold') + axes[0].set_ylabel('Frequency', fontweight='bold') + axes[0].set_title('Position Error Distribution', fontweight='bold') + axes[0].legend() + axes[0].grid(True, alpha=0.3, axis='y') + + # Panel 2: Speed error histogram + axes[1].hist(baseline_speed, bins=30, alpha=0.5, color='red', label='Baseline (12 actions)') + axes[1].hist(expanded_speed, bins=30, alpha=0.5, color='blue', label='Expanded (40 actions)') + axes[1].set_xlabel('Speed MAE (m/s)', fontweight='bold') + axes[1].set_ylabel('Frequency', fontweight='bold') + axes[1].set_title('Speed Error Distribution', fontweight='bold') + axes[1].legend() + axes[1].grid(True, alpha=0.3, axis='y') + + # Panel 3: Heading error histogram + axes[2].hist(baseline_heading, bins=30, alpha=0.5, color='red', label='Baseline (12 actions)') + axes[2].hist(expanded_heading, bins=30, alpha=0.5, color='blue', label='Expanded (40 actions)') + axes[2].set_xlabel('Heading MAE (deg)', fontweight='bold') + axes[2].set_ylabel('Frequency', fontweight='bold') + axes[2].set_title('Heading Error Distribution', fontweight='bold') + axes[2].legend() + axes[2].grid(True, alpha=0.3, axis='y') + + plt.tight_layout() + plt.savefig(output_path, dpi=150, bbox_inches='tight') + print(f"\nAblation comparison plot saved to: {output_path}") + + +def save_ablation_summary(results, output_path): + """Save ablation study summary to JSON.""" + summary = {} + + for config_name, result in results.items(): + summary[config_name] = { + 'description': JERK_CONFIGS[config_name]['description'], + 'jerk_long': JERK_CONFIGS[config_name]['jerk_long'], + 'jerk_lat': JERK_CONFIGS[config_name]['jerk_lat'], + 'num_actions': len(JERK_CONFIGS[config_name]['jerk_long']) * + len(JERK_CONFIGS[config_name]['jerk_lat']), + 'num_segments': result['num_segments'], + 'aggregated_metrics': result['aggregated'], + 'kinematic_realism': result['kinematic_score'], + } + + with open(output_path, 'w') as f: + json.dump(summary, f, indent=2) + + 
print(f"Ablation summary saved to: {output_path}") + + +def print_comparison_table(results): + """Print comparison table of key metrics.""" + print(f"\n{'='*80}") + print("ABLATION STUDY COMPARISON") + print(f"{'='*80}") + + # Header + print(f"\n{'Metric':<30} {'Baseline (12)':<20} {'Expanded (40)':<20} {'Improvement':<15}") + print(f"{'-'*30} {'-'*20} {'-'*20} {'-'*15}") + + # Get values + def get_val(config, path): + val = results[config]['aggregated'] + for key in path.split('.'): + val = val[key] + return val + + metrics = [ + ('Position ADE (m)', 'position.ade_mean'), + ('Position FDE (m)', 'position.fde_mean'), + ('Position P90 (m)', 'position.ade_p90'), + ('Position Max (m)', 'position.max_mean'), + ('Speed MAE (m/s)', 'velocity.speed_mae_mean'), + ('Speed P90 (m/s)', 'velocity.speed_max_p90'), + ('Heading MAE (deg)', 'heading.mae_mean'), + ('Heading P90 (deg)', 'heading.max_p90'), + ] + + for label, path in metrics: + baseline = get_val('baseline', path) + expanded = get_val('expanded', path) + improvement = ((baseline - expanded) / baseline) * 100 + + print(f"{label:<30} {baseline:<20.4f} {expanded:<20.4f} {improvement:>+14.2f}%") + + # Kinematic realism (higher is better) + baseline_kr = results['baseline']['kinematic_score']['kinematic_realism_score'] + expanded_kr = results['expanded']['kinematic_score']['kinematic_realism_score'] + kr_improvement = ((expanded_kr - baseline_kr) / baseline_kr) * 100 + + print(f"\n{'Kinematic Realism Score':<30} {baseline_kr:<20.4f} {expanded_kr:<20.4f} {kr_improvement:>+14.2f}%") + print(f"{'='*80}\n") + + +def main(): + parser = argparse.ArgumentParser( + description='Ablation study for JERK dynamics discretization' + ) + parser.add_argument( + '--data_dir', + type=str, + default='data/processed/training', + help='Directory containing JSON trajectory files' + ) + parser.add_argument( + '--num_scenarios', + type=int, + default=100, + help='Number of scenarios to process (default: 100)' + ) + parser.add_argument( + '--output_dir', + type=str, + default='output/jerk_ablation', + help='Output directory for results' + ) + parser.add_argument( + '--min_segment_length', + type=int, + default=10, + help='Minimum valid segment length (default: 10 timesteps)' + ) + parser.add_argument( + '--error_weights', + type=str, + default='1.0,1.0,10.0', + help='Weights for pos,vel,heading errors (default: 1.0,1.0,10.0)' + ) + parser.add_argument( + '--dt', + type=float, + default=0.1, + help='Time step in seconds (default: 0.1)' + ) + parser.add_argument( + '--min_mean_speed', + type=float, + default=0.5, + help='Minimum mean speed (m/s) to process segment (default: 0.5)' + ) + parser.add_argument( + '--configs', + type=str, + nargs='+', + default=['baseline', 'expanded'], + choices=['baseline', 'expanded'], + help='Which configurations to run (default: both)' + ) + parser.add_argument( + '--max_visualizations', + type=int, + default=10, + help='Maximum number of individual trajectory visualizations (default: 10)' + ) + + args = parser.parse_args() + + # Parse weights + weights = parse_weights(args.error_weights) + print(f"Using error weights: pos={weights['pos']}, vel={weights['vel']}, heading={weights['heading']}") + + # Create output directory + os.makedirs(args.output_dir, exist_ok=True) + viz_dir = os.path.join(args.output_dir, 'visualizations') + os.makedirs(viz_dir, exist_ok=True) + + # Load trajectories (shared across all configs) + print(f"\nLoading trajectories from {args.data_dir}...") + trajectories = load_trajectories_from_directory( + args.data_dir, 
+        num_scenarios=args.num_scenarios
+    )
+    print(f"Loaded {len(trajectories)} vehicle trajectories")
+
+    # Run each configuration
+    # Store trajectories only if we need visualizations and we have both configs
+    store_trajs = args.max_visualizations > 0 and len(args.configs) >= 2
+    results = {}
+    for config_name in args.configs:
+        config = JERK_CONFIGS[config_name]
+        result = run_configuration(config_name, config, trajectories, args, weights,
+                                   store_trajectories=store_trajs)
+
+        if result is None:
+            print(f"WARNING: No valid segments for {config_name}")
+            continue
+
+        results[config_name] = result
+
+    if len(results) == 0:
+        print("ERROR: No configurations produced valid results")
+        return
+
+    # Generate individual trajectory visualizations
+    if store_trajs and 'baseline' in results and 'expanded' in results:
+        print("\nGenerating individual trajectory visualizations...")
+        baseline_trajs = results['baseline']['trajectory_data']
+        expanded_trajs = results['expanded']['trajectory_data']
+
+        # Match trajectories from both configs (they should be in same order)
+        num_viz = min(args.max_visualizations, len(baseline_trajs))
+        for i in range(num_viz):
+            baseline_data = baseline_trajs[i]
+            expanded_data = expanded_trajs[i]
+
+            # Verify they're the same segment
+            assert baseline_data['scenario'] == expanded_data['scenario']
+            assert baseline_data['object_id'] == expanded_data['object_id']
+            assert baseline_data['segment_range'] == expanded_data['segment_range']
+
+            # Create visualization
+            viz_path = os.path.join(
+                viz_dir,
+                f'trajectory_{baseline_data["traj_idx"]:04d}_seg{baseline_data["seg_idx"]:02d}.png'
+            )
+
+            plot_jerk_ablation_trajectory(
+                baseline_data['ground_truth'],
+                baseline_data['predicted'],
+                expanded_data['predicted'],
+                baseline_data['segment_range'],
+                viz_path,
+                config_names=('Baseline (12 actions)', 'Expanded (40 actions)')
+            )
+
+        print(f"  Generated {num_viz} trajectory visualizations in {viz_dir}")
+
+    # Generate comparison visualizations (plot_ablation_comparison indexes both
+    # configs directly, so skip it when only one configuration was run)
+    if 'baseline' in results and 'expanded' in results:
+        print("\nGenerating aggregate comparison plots...")
+        comparison_plot_path = os.path.join(args.output_dir, 'ablation_comparison.png')
+        plot_ablation_comparison(results, comparison_plot_path)
+
+    # Save summary
+    summary_path = os.path.join(args.output_dir, 'ablation_summary.json')
+    save_ablation_summary(results, summary_path)
+
+    # Print comparison table
+    if len(results) >= 2:
+        print_comparison_table(results)
+
+    print(f"\nAblation study complete! Results saved to {args.output_dir}")
+
+
+if __name__ == '__main__':
+    main()
diff --git a/pufferlib/ocean/benchmark/scripts/run_trajectory_approximation.py b/pufferlib/ocean/benchmark/scripts/run_trajectory_approximation.py
new file mode 100755
index 000000000..054e38075
--- /dev/null
+++ b/pufferlib/ocean/benchmark/scripts/run_trajectory_approximation.py
@@ -0,0 +1,398 @@
+#!/usr/bin/env python3
+"""
+Command-line interface for trajectory approximation using CLASSIC and JERK dynamics models.
+
+This script loads Waymo trajectories from JSON files and approximates them using
+greedy search over discrete action spaces for both dynamics models.
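+
+Example invocation (paths are illustrative; all flags are defined below):
+
+    python run_trajectory_approximation.py \
+        --data_dir data/processed/training \
+        --num_scenarios 50 \
+        --error_weights 1.0,1.0,10.0 \
+        --output_dir output/trajectory_approximation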
+""" + +import argparse +import os +import json +import csv +from tqdm import tqdm +import numpy as np + +from pufferlib.ocean.benchmark.trajectory_approximation import ( + ClassicDynamics, + JerkDynamics, + load_trajectories_from_directory, + find_valid_segments, + extract_segment, + initialize_classic_state, + initialize_jerk_state, + approximate_trajectory, + compute_all_metrics, + aggregate_metrics, + format_metrics_summary, + plot_trajectory_comparison, + plot_aggregate_metrics, + plot_error_histogram, + compute_kinematic_features, + aggregate_features_across_segments, + compute_kinematic_realism_score, + compute_leave_one_out_baseline, + format_kinematic_summary, +) + + +def parse_weights(weight_string): + """Parse comma-separated weights string.""" + try: + pos, vel, heading = map(float, weight_string.split(',')) + return {'pos': pos, 'vel': vel, 'heading': heading} + except: + raise ValueError(f"Invalid weights format: {weight_string}. Expected: pos,vel,heading") + + +def save_metrics_to_csv(all_results, output_path): + """Save detailed metrics to CSV file.""" + with open(output_path, 'w', newline='') as f: + writer = csv.writer(f) + + # Header + writer.writerow([ + 'scenario', 'object_id', 'segment_start', 'segment_end', 'segment_length', + 'model', + 'ade', 'fde', 'pos_p90', 'pos_max', + 'speed_mae', 'speed_max', 'speed_p90', + 'heading_mae', 'heading_max', 'heading_p90' + ]) + + # Write rows + for result in all_results: + scenario = result['scenario'] + object_id = result['object_id'] + start, end = result['segment'] + segment_length = end - start + + for model_name in ['classic', 'jerk']: + metrics = result[f'{model_name}_metrics'] + + writer.writerow([ + scenario, object_id, start, end, segment_length, + model_name.upper(), + metrics['position']['ade'], + metrics['position']['fde'], + metrics['position']['p90'], + metrics['position']['max'], + metrics['velocity']['speed_mae'], + metrics['velocity']['speed_max'], + metrics['velocity']['speed_p90'], + metrics['heading']['mae'], + metrics['heading']['max'], + metrics['heading']['p90'], + ]) + + +def save_summary_to_json(aggregated_classic, aggregated_jerk, kinematic_scores, output_path): + """Save aggregated metrics to JSON file.""" + summary = { + 'classic': aggregated_classic, + 'jerk': aggregated_jerk, + 'kinematic_realism': kinematic_scores, + } + + with open(output_path, 'w') as f: + json.dump(summary, f, indent=2) + + +def main(): + parser = argparse.ArgumentParser( + description='Approximate Waymo trajectories using CLASSIC and JERK dynamics models' + ) + parser.add_argument( + '--data_dir', + type=str, + default='data/processed/training', + help='Directory containing JSON trajectory files' + ) + parser.add_argument( + '--num_scenarios', + type=int, + default=100, + help='Number of scenarios to process (default: 100)' + ) + parser.add_argument( + '--output_dir', + type=str, + default='output/trajectory_approximation', + help='Output directory for results' + ) + parser.add_argument( + '--min_segment_length', + type=int, + default=10, + help='Minimum valid segment length to process (default: 10 timesteps)' + ) + parser.add_argument( + '--error_weights', + type=str, + default='1.0,1.0,10.0', + help='Weights for pos,vel,heading errors (default: 1.0,1.0,10.0)' + ) + parser.add_argument( + '--max_visualizations', + type=int, + default=10, + help='Maximum number of individual trajectory visualizations (default: 10)' + ) + parser.add_argument( + '--dt', + type=float, + default=0.1, + help='Time step in seconds (default: 0.1)' + ) 
+ parser.add_argument( + '--min_mean_speed', + type=float, + default=0.5, + help='Minimum mean speed (m/s) to process segment (default: 0.5, filters stationary vehicles)' + ) + + args = parser.parse_args() + + # Parse weights + weights = parse_weights(args.error_weights) + print(f"Using error weights: pos={weights['pos']}, vel={weights['vel']}, heading={weights['heading']}") + + # Create output directory + os.makedirs(args.output_dir, exist_ok=True) + viz_dir = os.path.join(args.output_dir, 'visualizations') + os.makedirs(viz_dir, exist_ok=True) + + # Initialize dynamics models + print(f"\nInitializing dynamics models (dt={args.dt}s)...") + classic_dynamics = ClassicDynamics(dt=args.dt) + jerk_dynamics = JerkDynamics(dt=args.dt) + print(f" CLASSIC: {classic_dynamics.num_actions} actions") + print(f" JERK: {jerk_dynamics.num_actions} actions") + + # Load trajectories + print(f"\nLoading trajectories from {args.data_dir}...") + trajectories = load_trajectories_from_directory( + args.data_dir, + num_scenarios=args.num_scenarios + ) + print(f"Loaded {len(trajectories)} vehicle trajectories") + + # Process each trajectory + print(f"\nProcessing trajectories (min segment length: {args.min_segment_length}, min mean speed: {args.min_mean_speed} m/s)...") + all_results = [] + all_gt_segments = [] # For kinematic realism + all_classic_segments = [] # For kinematic realism + all_jerk_segments = [] # For kinematic realism + viz_count = 0 + skipped_stationary = 0 + processed_segments = 0 + + for traj_idx, traj in enumerate(tqdm(trajectories, desc="Trajectories")): + # Find valid segments + segments = find_valid_segments(traj['valid'], min_length=args.min_segment_length) + + if len(segments) == 0: + continue + + for seg_idx, (start, end) in enumerate(segments): + # Extract segment + segment = extract_segment(traj, start, end) + + # Skip stationary vehicles (parked or stopped in traffic) + # Calculate position change and speed statistics + position_change = np.sqrt( + (segment['x'][-1] - segment['x'][0])**2 + + (segment['y'][-1] - segment['y'][0])**2 + ) + speed = np.sqrt(segment['vx']**2 + segment['vy']**2) + max_speed = np.max(speed) + mean_speed = np.mean(speed) + + # Filter out stationary vehicles (moving less than 0.5m total and max speed < 0.5 m/s) + if position_change < 0.5 and max_speed < 0.5: + skipped_stationary += 1 + continue # Skip this segment - stationary vehicle + + # Also skip very slow moving vehicles that won't test the dynamics well + if mean_speed < args.min_mean_speed: + skipped_stationary += 1 + continue + + processed_segments += 1 + + # Initialize states + classic_state = initialize_classic_state(segment) + jerk_state = initialize_jerk_state(segment, dt=args.dt) + + # Approximate with CLASSIC model + classic_traj, classic_actions, classic_errors = approximate_trajectory( + classic_dynamics, segment, classic_state, weights=weights + ) + # Add valid mask from ground truth + classic_traj['valid'] = segment['valid'].copy() + + # Approximate with JERK model + jerk_traj, jerk_actions, jerk_errors = approximate_trajectory( + jerk_dynamics, segment, jerk_state, weights=weights + ) + # Add valid mask from ground truth + jerk_traj['valid'] = segment['valid'].copy() + + # Compute metrics + classic_metrics = compute_all_metrics(classic_traj, segment, segment['valid']) + jerk_metrics = compute_all_metrics(jerk_traj, segment, segment['valid']) + + # Store results + result = { + 'scenario': traj['scenario'], + 'object_id': traj['object_id'], + 'segment': (start, end), + 'classic_metrics': 
classic_metrics, + 'jerk_metrics': jerk_metrics, + } + all_results.append(result) + + # Store trajectories for kinematic realism computation + all_gt_segments.append(segment) + all_classic_segments.append(classic_traj) + all_jerk_segments.append(jerk_traj) + + # Create visualization for first few segments + if viz_count < args.max_visualizations: + viz_path = os.path.join( + viz_dir, + f'comparison_{traj_idx:04d}_seg{seg_idx:02d}.png' + ) + plot_trajectory_comparison( + segment, classic_traj, jerk_traj, + (start, end), viz_path + ) + viz_count += 1 + + print(f"\nProcessing complete:") + print(f" Processed: {processed_segments} moving trajectory segments") + print(f" Skipped: {skipped_stationary} stationary/slow segments") + print(f" Total results: {len(all_results)} segments") + + if len(all_results) == 0: + print("No valid trajectory segments found. Exiting.") + return + + # Aggregate metrics + print("\nAggregating metrics...") + classic_metrics_list = [r['classic_metrics'] for r in all_results] + jerk_metrics_list = [r['jerk_metrics'] for r in all_results] + + aggregated_classic = aggregate_metrics(classic_metrics_list, model_name='CLASSIC') + aggregated_jerk = aggregate_metrics(jerk_metrics_list, model_name='JERK') + + # Print summary + print(format_metrics_summary(aggregated_classic)) + print(format_metrics_summary(aggregated_jerk)) + + # Compute kinematic realism scores + print("\nComputing kinematic realism scores...") + + # Compute features for all segments + print(" Computing kinematic features...") + gt_features_list = [] + classic_features_list = [] + jerk_features_list = [] + + for i in range(len(all_gt_segments)): + gt_feat = compute_kinematic_features( + all_gt_segments[i]['x'], + all_gt_segments[i]['y'], + all_gt_segments[i]['heading'], + all_gt_segments[i]['valid'], + dt=args.dt + ) + gt_features_list.append(gt_feat) + + classic_feat = compute_kinematic_features( + all_classic_segments[i]['x'], + all_classic_segments[i]['y'], + all_classic_segments[i]['heading'], + all_classic_segments[i]['valid'], + dt=args.dt + ) + classic_features_list.append(classic_feat) + + jerk_feat = compute_kinematic_features( + all_jerk_segments[i]['x'], + all_jerk_segments[i]['y'], + all_jerk_segments[i]['heading'], + all_jerk_segments[i]['valid'], + dt=args.dt + ) + jerk_features_list.append(jerk_feat) + + # Aggregate features across population + print(" Aggregating population-level features...") + gt_population = aggregate_features_across_segments(gt_features_list) + classic_population = aggregate_features_across_segments(classic_features_list) + jerk_population = aggregate_features_across_segments(jerk_features_list) + + # Compute leave-one-out baseline for ground truth + print(" Computing ground truth baseline (leave-one-out)...") + gt_baseline_metrics = compute_leave_one_out_baseline(all_gt_segments, dt=args.dt) + + # Compute kinematic realism scores for inverse dynamics + print(" Evaluating CLASSIC kinematic realism...") + classic_kinematic = compute_kinematic_realism_score(gt_population, classic_population) + + print(" Evaluating JERK kinematic realism...") + jerk_kinematic = compute_kinematic_realism_score(gt_population, jerk_population) + + # Print kinematic realism summaries + print(format_kinematic_summary(gt_baseline_metrics, 'Ground Truth (Leave-One-Out)')) + print(format_kinematic_summary(classic_kinematic, 'CLASSIC')) + print(format_kinematic_summary(jerk_kinematic, 'JERK')) + + # Compute realism loss + print(f"\n{'='*60}") + print("Kinematic Realism Loss (GT Baseline - Inverse 
Dynamics)") + print(f"{'='*60}") + classic_loss = gt_baseline_metrics['kinematic_realism_score'] - classic_kinematic['kinematic_realism_score'] + jerk_loss = gt_baseline_metrics['kinematic_realism_score'] - jerk_kinematic['kinematic_realism_score'] + print(f" CLASSIC loss: {classic_loss:.4f}") + print(f" JERK loss: {jerk_loss:.4f}") + print(f"{'='*60}") + + # Save results + print("\nSaving results...") + + # Save detailed metrics to CSV + csv_path = os.path.join(args.output_dir, 'metrics_detailed.csv') + save_metrics_to_csv(all_results, csv_path) + print(f" Detailed metrics saved to: {csv_path}") + + # Save aggregated metrics to JSON + json_path = os.path.join(args.output_dir, 'metrics_summary.json') + kinematic_scores = { + 'ground_truth_baseline': gt_baseline_metrics, + 'classic': classic_kinematic, + 'jerk': jerk_kinematic, + 'classic_realism_loss': classic_loss, + 'jerk_realism_loss': jerk_loss, + } + save_summary_to_json(aggregated_classic, aggregated_jerk, kinematic_scores, json_path) + print(f" Summary metrics saved to: {json_path}") + + # Generate aggregate visualizations + print("\nGenerating visualizations...") + + # Box plots + boxplot_path = os.path.join(args.output_dir, 'comparison_boxplots.png') + plot_aggregate_metrics(classic_metrics_list, jerk_metrics_list, boxplot_path) + print(f" Box plots saved to: {boxplot_path}") + + # Histograms + histogram_path = os.path.join(args.output_dir, 'error_histograms.png') + plot_error_histogram(classic_metrics_list, jerk_metrics_list, histogram_path) + print(f" Histograms saved to: {histogram_path}") + + print(f"\nDone! Results saved to {args.output_dir}") + + +if __name__ == '__main__': + main() diff --git a/pufferlib/ocean/benchmark/trajectory_approximation/__init__.py b/pufferlib/ocean/benchmark/trajectory_approximation/__init__.py new file mode 100644 index 000000000..353fa5dbc --- /dev/null +++ b/pufferlib/ocean/benchmark/trajectory_approximation/__init__.py @@ -0,0 +1,79 @@ +"""Trajectory approximation using CLASSIC and JERK dynamics models.""" + +from .dynamics import ClassicDynamics, JerkDynamics +from .trajectory_loader import ( + load_trajectory_from_json, + load_trajectories_from_directory, + get_trajectory_statistics, +) +from .trajectory_utils import ( + wrap_angle, + find_valid_segments, + extract_segment, + initialize_classic_state, + initialize_jerk_state, +) +from .greedy_search import ( + GreedyActionSelector, + approximate_trajectory, + decode_action, +) +from .error_metrics import ( + compute_position_error, + compute_velocity_error, + compute_heading_error, + compute_all_metrics, + aggregate_metrics, + format_metrics_summary, +) +from .visualizer import ( + plot_trajectory_comparison, + plot_jerk_ablation_trajectory, + plot_aggregate_metrics, + plot_error_histogram, +) +from .kinematic_realism import ( + compute_kinematic_features, + aggregate_features_across_segments, + compute_kinematic_realism_score, + compute_leave_one_out_baseline, + format_kinematic_summary, +) + +__all__ = [ + # Dynamics models + 'ClassicDynamics', + 'JerkDynamics', + # Trajectory loading + 'load_trajectory_from_json', + 'load_trajectories_from_directory', + 'get_trajectory_statistics', + # Utilities + 'wrap_angle', + 'find_valid_segments', + 'extract_segment', + 'initialize_classic_state', + 'initialize_jerk_state', + # Greedy search + 'GreedyActionSelector', + 'approximate_trajectory', + 'decode_action', + # Metrics + 'compute_position_error', + 'compute_velocity_error', + 'compute_heading_error', + 'compute_all_metrics', + 
'aggregate_metrics', + 'format_metrics_summary', + # Visualization + 'plot_trajectory_comparison', + 'plot_jerk_ablation_trajectory', + 'plot_aggregate_metrics', + 'plot_error_histogram', + # Kinematic realism + 'compute_kinematic_features', + 'aggregate_features_across_segments', + 'compute_kinematic_realism_score', + 'compute_leave_one_out_baseline', + 'format_kinematic_summary', +] diff --git a/pufferlib/ocean/benchmark/trajectory_approximation/dynamics.py b/pufferlib/ocean/benchmark/trajectory_approximation/dynamics.py new file mode 100644 index 000000000..00920accc --- /dev/null +++ b/pufferlib/ocean/benchmark/trajectory_approximation/dynamics.py @@ -0,0 +1,253 @@ +"""Pure Python/NumPy implementation of CLASSIC and JERK dynamics models. + +These implementations follow the C++ code in pufferlib/ocean/drive/drive.h +to ensure consistency with the simulation environment. +""" + +import numpy as np +from .trajectory_utils import wrap_angle + + +class ClassicDynamics: + """ + Kinematic bicycle model with direct acceleration and steering control. + + Action space: 91 discrete actions (7 acceleration × 13 steering) + State: {x, y, heading, vx, vy, length} + + Based on drive.h lines 1558-1618 + """ + + # Action space constants (from drive.h lines 99-100) + ACCELERATION_VALUES = np.array([-4.0, -2.667, -1.333, 0.0, 1.333, 2.667, 4.0]) + STEERING_VALUES = np.array([-1.0, -0.833, -0.667, -0.5, -0.333, -0.167, 0.0, + 0.167, 0.333, 0.5, 0.667, 0.833, 1.0]) + + def __init__(self, dt=0.1): + """ + Initialize CLASSIC dynamics model. + + Args: + dt: Time step in seconds (default 0.1s = 10 Hz) + """ + self.dt = dt + self.num_accel = len(self.ACCELERATION_VALUES) + self.num_steer = len(self.STEERING_VALUES) + self.num_actions = self.num_accel * self.num_steer # 7 * 13 = 91 + + def step(self, state, action_idx): + """ + Step the dynamics forward by one timestep. + + Args: + state: dict with {x, y, heading, vx, vy, length} + action_idx: int in [0, 90] (discrete action index) + Returns: + next_state: dict with updated {x, y, heading, vx, vy, length} + """ + # Decode action (drive.h lines 1576-1579) + acceleration_index = action_idx // self.num_steer + steering_index = action_idx % self.num_steer + acceleration = self.ACCELERATION_VALUES[acceleration_index] + steering = self.STEERING_VALUES[steering_index] + + # Current state (drive.h lines 1583-1587) + x = state['x'] + y = state['y'] + heading = state['heading'] + vx = state['vx'] + vy = state['vy'] + length = state['length'] + + # Calculate current speed (drive.h line 1590) + speed = np.sqrt(vx**2 + vy**2) + + # Update speed with acceleration (drive.h lines 1593-1594) + speed_new = speed + acceleration * self.dt + speed_new = np.clip(speed_new, -100.0, 100.0) # clipSpeed + + # Compute beta and yaw rate (drive.h lines 1597-1600) + beta = np.tanh(0.5 * np.tan(steering)) + yaw_rate = (speed * np.cos(beta) * np.tan(steering)) / length + + # New velocity components (drive.h lines 1603-1604) + vx_new = speed_new * np.cos(heading + beta) + vy_new = speed_new * np.sin(heading + beta) + + # Update position and heading (drive.h lines 1607-1609) + x_new = x + vx_new * self.dt + y_new = y + vy_new * self.dt + heading_new = heading + yaw_rate * self.dt + + # Normalize heading to [-pi, pi] + heading_new = wrap_angle(heading_new) + + return { + 'x': x_new, + 'y': y_new, + 'heading': heading_new, + 'vx': vx_new, + 'vy': vy_new, + 'length': length, + } + + +class JerkDynamics: + """ + Stateful jerk bicycle model with persistent acceleration state. 
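+
+    A rough sketch of the per-step update implemented in step() below (sign
+    changes in acceleration or speed snap to zero to make stopping easy):
+
+        a_long' = clip(a_long + j_long * dt, -5.0, 2.5)
+        v'      = clip(v + 0.5 * (a_long + a_long') * dt, -2.0, 20.0)  # trapezoidal
+        kappa   = a_lat' / max(v'**2, 1e-5)
+        delta*  = arctan(kappa * wheelbase)  # then rate-limited to 0.6 rad/s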
+
+    Action space: Configurable discrete actions (jerk_long × jerk_lat)
+    State: {x, y, heading, vx, vy, a_long, a_lat, steering_angle, wheelbase}
+
+    Based on drive.h lines 1620-1714
+    """
+
+    # Default action space constants (from drive.h lines 95-96)
+    DEFAULT_JERK_LONG = np.array([-15.0, -4.0, 0.0, 4.0])
+    DEFAULT_JERK_LAT = np.array([-4.0, 0.0, 4.0])
+
+    def __init__(self, dt=0.1, jerk_long=None, jerk_lat=None):
+        """
+        Initialize JERK dynamics model.
+
+        Args:
+            dt: Time step in seconds (default 0.1s = 10 Hz)
+            jerk_long: Longitudinal jerk discretization values (default: [-15, -4, 0, 4])
+            jerk_lat: Lateral jerk discretization values (default: [-4, 0, 4])
+        """
+        self.dt = dt
+        self.JERK_LONG = np.array(jerk_long) if jerk_long is not None else self.DEFAULT_JERK_LONG
+        self.JERK_LAT = np.array(jerk_lat) if jerk_lat is not None else self.DEFAULT_JERK_LAT
+        self.num_jerk_long = len(self.JERK_LONG)
+        self.num_jerk_lat = len(self.JERK_LAT)
+        self.num_actions = self.num_jerk_long * self.num_jerk_lat
+
+    def step(self, state, action_idx):
+        """
+        Step the dynamics forward by one timestep.
+
+        Args:
+            state: dict with {x, y, heading, vx, vy, a_long, a_lat,
+                   steering_angle, wheelbase}
+            action_idx: int in [0, num_actions - 1] (discrete action index;
+                        12 for the default discretization)
+        Returns:
+            next_state: dict with updated state including accelerations
+        """
+        # Decode action (drive.h lines 1638-1642)
+        # Note: the C++ code stores actions as [long_idx, lat_idx] pairs;
+        # here a single int encodes both indices:
+        jerk_long_idx = action_idx // self.num_jerk_lat
+        jerk_lat_idx = action_idx % self.num_jerk_lat
+        a_long_jerk = self.JERK_LONG[jerk_long_idx]
+        a_lat_jerk = self.JERK_LAT[jerk_lat_idx]
+
+        # Current state
+        x = state['x']
+        y = state['y']
+        heading = state['heading']
+        vx = state['vx']
+        vy = state['vy']
+        a_long = state['a_long']
+        a_lat = state['a_lat']
+        steering_angle = state['steering_angle']
+        wheelbase = state['wheelbase']
+
+        # Heading unit vector
+        heading_x = np.cos(heading)
+        heading_y = np.sin(heading)
+
+        # Calculate new acceleration (drive.h lines 1646-1660)
+        a_long_new = a_long + a_long_jerk * self.dt
+        a_lat_new = a_lat + a_lat_jerk * self.dt
+
+        # Make it easy to stop with 0 accel (sign-change detection)
+        if a_long * a_long_new < 0:
+            a_long_new = 0.0
+        else:
+            a_long_new = np.clip(a_long_new, -5.0, 2.5)
+
+        if a_lat * a_lat_new < 0:
+            a_lat_new = 0.0
+        else:
+            a_lat_new = np.clip(a_lat_new, -4.0, 4.0)
+
+        # Calculate new velocity (drive.h lines 1663-1672)
+        v_dot_heading = vx * heading_x + vy * heading_y
+        speed = np.sqrt(vx**2 + vy**2)
+        signed_v = np.copysign(speed, v_dot_heading)
+
+        # Trapezoidal integration
+        v_new = signed_v + 0.5 * (a_long_new + a_long) * self.dt
+
+        # Make it easy to stop with 0 vel (sign-change detection)
+        if signed_v * v_new < 0:
+            v_new = 0.0
+        else:
+            v_new = np.clip(v_new, -2.0, 20.0)
+
+        # Calculate new steering angle (drive.h lines 1675-1683)
+        signed_curvature = a_lat_new / max(v_new * v_new, 1e-5)
+        # Ensure minimum curvature magnitude (drive.h line 1676)
+        signed_curvature = np.copysign(max(abs(signed_curvature), 1e-5), signed_curvature)
+
+        steering_angle_target = np.arctan(signed_curvature * wheelbase)
+
+        # Rate-limited steering (max 0.6 rad/s)
+        delta_steer = np.clip(
+            steering_angle_target - steering_angle,
+            -0.6 * self.dt,
+            0.6 * self.dt
+        )
+        steering_angle_new = np.clip(
+            steering_angle + delta_steer,
+            -0.55,
+            0.55
+        )
+
+        # Update curvature and accel to account for limited steering
+        signed_curvature = np.tan(steering_angle_new) / wheelbase
+        
a_lat_new = v_new * v_new * signed_curvature + + # Calculate resulting movement using bicycle dynamics (drive.h lines 1686-1700) + d = 0.5 * (v_new + signed_v) * self.dt # Average displacement + theta = d * signed_curvature # Rotation angle + + if abs(signed_curvature) < 1e-5 or abs(theta) < 1e-5: + # Straight line motion + dx_local = d + dy_local = 0.0 + else: + # Curved motion + dx_local = np.sin(theta) / signed_curvature + dy_local = (1.0 - np.cos(theta)) / signed_curvature + + # Transform to global coordinates + dx = dx_local * heading_x - dy_local * heading_y + dy = dx_local * heading_y + dy_local * heading_x + + # Update everything (drive.h lines 1702-1713) + x_new = x + dx + y_new = y + dy + heading_new = wrap_angle(heading + theta) + + # Update velocity components + vx_new = v_new * np.cos(heading_new) + vy_new = v_new * np.sin(heading_new) + + # Compute jerk (for diagnostics) + jerk_long = (a_long_new - a_long) / self.dt + jerk_lat = (a_lat_new - a_lat) / self.dt + + return { + 'x': x_new, + 'y': y_new, + 'heading': heading_new, + 'vx': vx_new, + 'vy': vy_new, + 'a_long': a_long_new, + 'a_lat': a_lat_new, + 'steering_angle': steering_angle_new, + 'wheelbase': wheelbase, + 'jerk_long': jerk_long, + 'jerk_lat': jerk_lat, + } diff --git a/pufferlib/ocean/benchmark/trajectory_approximation/error_metrics.py b/pufferlib/ocean/benchmark/trajectory_approximation/error_metrics.py new file mode 100644 index 000000000..cc3982c11 --- /dev/null +++ b/pufferlib/ocean/benchmark/trajectory_approximation/error_metrics.py @@ -0,0 +1,230 @@ +"""Error metrics for trajectory approximation evaluation.""" + +import numpy as np +from .trajectory_utils import wrap_angle + + +def compute_position_error(predicted_traj, ground_truth_traj, valid_mask): + """ + Compute position errors (ADE, FDE, percentiles). + + Args: + predicted_traj: Dict with {x, y, ...} + ground_truth_traj: Dict with {x, y, ...} + valid_mask: Boolean array indicating valid timesteps + Returns: + dict with {'ade', 'fde', 'p50', 'p90', 'p95', 'max'} + """ + dx = predicted_traj['x'] - ground_truth_traj['x'] + dy = predicted_traj['y'] - ground_truth_traj['y'] + errors = np.sqrt(dx**2 + dy**2)[valid_mask] + + if len(errors) == 0: + return { + 'ade': 0.0, + 'fde': 0.0, + 'p50': 0.0, + 'p90': 0.0, + 'p95': 0.0, + 'max': 0.0, + } + + return { + 'ade': float(np.mean(errors)), # Average Displacement Error + 'fde': float(errors[-1]), # Final Displacement Error + 'p50': float(np.percentile(errors, 50)), + 'p90': float(np.percentile(errors, 90)), + 'p95': float(np.percentile(errors, 95)), + 'max': float(np.max(errors)), + } + + +def compute_velocity_error(predicted_traj, ground_truth_traj, valid_mask): + """ + Compute velocity magnitude and direction errors. 
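+
+    Per valid step, speed error is | ||v_pred|| - ||v_gt|| | and direction
+    error is |wrap(atan2(vy_pred, vx_pred) - atan2(vy_gt, vx_gt))|, with
+    angular quantities reported in degrees.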
+ + Args: + predicted_traj: Dict with {vx, vy, ...} + ground_truth_traj: Dict with {vx, vy, ...} + valid_mask: Boolean array indicating valid timesteps + Returns: + dict with speed and direction error statistics + """ + # Speed (magnitude) error + pred_speed = np.sqrt(predicted_traj['vx']**2 + predicted_traj['vy']**2) + gt_speed = np.sqrt(ground_truth_traj['vx']**2 + ground_truth_traj['vy']**2) + speed_error = np.abs(pred_speed - gt_speed)[valid_mask] + + # Velocity direction error (angle between velocity vectors) + pred_vel_angle = np.arctan2(predicted_traj['vy'], predicted_traj['vx']) + gt_vel_angle = np.arctan2(ground_truth_traj['vy'], ground_truth_traj['vx']) + direction_error = np.abs(wrap_angle(pred_vel_angle - gt_vel_angle))[valid_mask] + + if len(speed_error) == 0: + return { + 'speed_mae': 0.0, + 'speed_max': 0.0, + 'speed_p90': 0.0, + 'direction_mae': 0.0, + 'direction_max': 0.0, + 'direction_p90': 0.0, + } + + return { + 'speed_mae': float(np.mean(speed_error)), + 'speed_max': float(np.max(speed_error)), + 'speed_p90': float(np.percentile(speed_error, 90)), + 'direction_mae': float(np.mean(direction_error) * 180 / np.pi), # degrees + 'direction_max': float(np.max(direction_error) * 180 / np.pi), + 'direction_p90': float(np.percentile(direction_error, 90) * 180 / np.pi), + } + + +def compute_heading_error(predicted_traj, ground_truth_traj, valid_mask): + """ + Compute heading errors in degrees. + + Args: + predicted_traj: Dict with {heading, ...} + ground_truth_traj: Dict with {heading, ...} + valid_mask: Boolean array indicating valid timesteps + Returns: + dict with heading error statistics + """ + heading_diff = wrap_angle( + predicted_traj['heading'] - ground_truth_traj['heading'] + ) + errors = np.abs(heading_diff)[valid_mask] * 180 / np.pi # Convert to degrees + + if len(errors) == 0: + return { + 'mae': 0.0, + 'p50': 0.0, + 'p90': 0.0, + 'max': 0.0, + } + + return { + 'mae': float(np.mean(errors)), + 'p50': float(np.percentile(errors, 50)), + 'p90': float(np.percentile(errors, 90)), + 'max': float(np.max(errors)), + } + + +def compute_all_metrics(predicted_traj, ground_truth_traj, valid_mask): + """ + Compute all metrics and return unified dict. + + Args: + predicted_traj: Dict with {x, y, heading, vx, vy, ...} + ground_truth_traj: Dict with {x, y, heading, vx, vy, valid, ...} + valid_mask: Boolean array indicating valid timesteps + Returns: + dict with {'position': {...}, 'velocity': {...}, 'heading': {...}} + """ + return { + 'position': compute_position_error(predicted_traj, ground_truth_traj, valid_mask), + 'velocity': compute_velocity_error(predicted_traj, ground_truth_traj, valid_mask), + 'heading': compute_heading_error(predicted_traj, ground_truth_traj, valid_mask), + } + + +def aggregate_metrics(all_metrics, model_name=''): + """ + Aggregate metrics across multiple trajectory segments. 
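+
+    Note: percentile statistics here are taken across segments (e.g. ade_p90
+    is the 90th percentile of per-segment ADE values), not across timesteps.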
+ + Args: + all_metrics: List of metric dicts from compute_all_metrics + model_name: Optional name for the model (for display) + Returns: + dict with aggregated statistics + """ + if len(all_metrics) == 0: + return {} + + # Extract position metrics + position_ade = [m['position']['ade'] for m in all_metrics] + position_fde = [m['position']['fde'] for m in all_metrics] + position_max = [m['position']['max'] for m in all_metrics] + + # Extract velocity metrics + speed_mae = [m['velocity']['speed_mae'] for m in all_metrics] + speed_max = [m['velocity']['speed_max'] for m in all_metrics] + + # Extract heading metrics + heading_mae = [m['heading']['mae'] for m in all_metrics] + heading_max = [m['heading']['max'] for m in all_metrics] + + aggregated = { + 'model': model_name, + 'num_trajectories': len(all_metrics), + 'position': { + 'ade_mean': float(np.mean(position_ade)), + 'ade_std': float(np.std(position_ade)), + 'ade_median': float(np.median(position_ade)), + 'ade_p90': float(np.percentile(position_ade, 90)), + 'fde_mean': float(np.mean(position_fde)), + 'fde_std': float(np.std(position_fde)), + 'max_mean': float(np.mean(position_max)), + 'max_p90': float(np.percentile(position_max, 90)), + }, + 'velocity': { + 'speed_mae_mean': float(np.mean(speed_mae)), + 'speed_mae_std': float(np.std(speed_mae)), + 'speed_mae_median': float(np.median(speed_mae)), + 'speed_max_mean': float(np.mean(speed_max)), + 'speed_max_p90': float(np.percentile(speed_max, 90)), + }, + 'heading': { + 'mae_mean': float(np.mean(heading_mae)), + 'mae_std': float(np.std(heading_mae)), + 'mae_median': float(np.median(heading_mae)), + 'max_mean': float(np.mean(heading_max)), + 'max_p90': float(np.percentile(heading_max, 90)), + }, + } + + return aggregated + + +def format_metrics_summary(aggregated_metrics): + """ + Format aggregated metrics as a readable string. 
+ + Args: + aggregated_metrics: Dict from aggregate_metrics + Returns: + Formatted string + """ + if not aggregated_metrics: + return "No metrics available" + + lines = [] + model_name = aggregated_metrics.get('model', 'Unknown') + num_traj = aggregated_metrics.get('num_trajectories', 0) + + lines.append(f"\n{'='*60}") + lines.append(f"Model: {model_name} ({num_traj} trajectory segments)") + lines.append(f"{'='*60}") + + # Position errors + pos = aggregated_metrics['position'] + lines.append(f"\nPosition Errors (meters):") + lines.append(f" ADE: {pos['ade_mean']:.3f} ± {pos['ade_std']:.3f} (median: {pos['ade_median']:.3f}, p90: {pos['ade_p90']:.3f})") + lines.append(f" FDE: {pos['fde_mean']:.3f} ± {pos['fde_std']:.3f}") + lines.append(f" Max: {pos['max_mean']:.3f} (p90: {pos['max_p90']:.3f})") + + # Velocity errors + vel = aggregated_metrics['velocity'] + lines.append(f"\nVelocity Errors:") + lines.append(f" Speed MAE: {vel['speed_mae_mean']:.3f} ± {vel['speed_mae_std']:.3f} m/s (median: {vel['speed_mae_median']:.3f})") + lines.append(f" Speed Max: {vel['speed_max_mean']:.3f} m/s (p90: {vel['speed_max_p90']:.3f})") + + # Heading errors + hdg = aggregated_metrics['heading'] + lines.append(f"\nHeading Errors (degrees):") + lines.append(f" MAE: {hdg['mae_mean']:.2f} ± {hdg['mae_std']:.2f}° (median: {hdg['mae_median']:.2f}°)") + lines.append(f" Max: {hdg['max_mean']:.2f}° (p90: {hdg['max_p90']:.2f}°)") + + return '\n'.join(lines) diff --git a/pufferlib/ocean/benchmark/trajectory_approximation/greedy_search.py b/pufferlib/ocean/benchmark/trajectory_approximation/greedy_search.py new file mode 100644 index 000000000..3ccbcd364 --- /dev/null +++ b/pufferlib/ocean/benchmark/trajectory_approximation/greedy_search.py @@ -0,0 +1,215 @@ +"""Greedy action selection for trajectory approximation.""" + +import numpy as np +from .trajectory_utils import wrap_angle +from .dynamics import ClassicDynamics, JerkDynamics + + +class GreedyActionSelector: + """ + Greedy action selection for trajectory approximation. + + At each timestep, evaluates all possible actions and selects the one + that minimizes the weighted error to the next ground truth state. + """ + + def __init__(self, dynamics_model, weights=None): + """ + Initialize greedy action selector. + + Args: + dynamics_model: ClassicDynamics or JerkDynamics instance + weights: Dict with weights for pos, vel, heading errors + Default: {'pos': 1.0, 'vel': 1.0, 'heading': 10.0} + """ + self.dynamics = dynamics_model + self.num_actions = dynamics_model.num_actions + + if weights is None: + # Default weights - heading gets higher weight because it's in radians + self.weights = {'pos': 1.0, 'vel': 1.0, 'heading': 10.0} + else: + self.weights = weights + + def compute_error(self, predicted_state, target_state): + """ + Compute weighted error between predicted and target states. 
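+
+        With the default weights (pos=1, vel=1, heading=10; heading is in
+        radians, hence the larger weight), the cost is
+
+            E = w_pos * ||p_pred - p_gt|| + w_vel * ||v_pred - v_gt||
+                + w_heading * |wrap(theta_pred - theta_gt)|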
+ + Args: + predicted_state: Dict with {x, y, heading, vx, vy, ...} + target_state: Dict with {x, y, heading, vx, vy, ...} + Returns: + float: Weighted total error + """ + # Position error (Euclidean distance) + pos_error = np.sqrt( + (predicted_state['x'] - target_state['x'])**2 + + (predicted_state['y'] - target_state['y'])**2 + ) + + # Velocity error (Euclidean distance in velocity space) + vel_error = np.sqrt( + (predicted_state['vx'] - target_state['vx'])**2 + + (predicted_state['vy'] - target_state['vy'])**2 + ) + + # Heading error (angular distance, wrapped to [-pi, pi]) + heading_error = abs(wrap_angle( + predicted_state['heading'] - target_state['heading'] + )) + + # Weighted sum + total_error = ( + self.weights['pos'] * pos_error + + self.weights['vel'] * vel_error + + self.weights['heading'] * heading_error + ) + + return total_error + + def select_best_action(self, current_state, target_state): + """ + Test all possible actions and return the one with minimum error. + + Args: + current_state: Current state dict + target_state: Desired next state from ground truth + Returns: + tuple: (best_action_idx, best_state, min_error) + """ + min_error = float('inf') + best_action_idx = 0 + best_state = None + + for action_idx in range(self.num_actions): + # Simulate this action + predicted_state = self.dynamics.step(current_state, action_idx) + + # Compute error + error = self.compute_error(predicted_state, target_state) + + # Update best if this is better + if error < min_error: + min_error = error + best_action_idx = action_idx + best_state = predicted_state + + return best_action_idx, best_state, min_error + + +def approximate_trajectory(dynamics_model, ground_truth, initial_state, weights=None): + """ + Approximate a trajectory using greedy search. 
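+
+    Each step targets the next ground-truth state, but the rollout continues
+    from the model's own prediction, so position error can compound over a
+    segment rather than being reset at every step.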
+ + Args: + dynamics_model: ClassicDynamics or JerkDynamics instance + ground_truth: Dict with arrays {x, y, heading, vx, vy, valid, ...} + initial_state: Initial state dict + weights: Optional weights for error computation + Returns: + tuple: (approximated_trajectory, action_sequence, per_step_errors) + """ + selector = GreedyActionSelector(dynamics_model, weights=weights) + num_steps = len(ground_truth['x']) + + # Initialize storage + traj = { + 'x': np.zeros(num_steps, dtype=np.float32), + 'y': np.zeros(num_steps, dtype=np.float32), + 'heading': np.zeros(num_steps, dtype=np.float32), + 'vx': np.zeros(num_steps, dtype=np.float32), + 'vy': np.zeros(num_steps, dtype=np.float32), + } + + # Additional fields for JERK model + if isinstance(dynamics_model, JerkDynamics): + traj['a_long'] = np.zeros(num_steps, dtype=np.float32) + traj['a_lat'] = np.zeros(num_steps, dtype=np.float32) + traj['steering_angle'] = np.zeros(num_steps, dtype=np.float32) + + actions = np.zeros(num_steps - 1, dtype=np.int32) + errors = np.zeros(num_steps - 1, dtype=np.float32) + + # Set initial state + current_state = initial_state.copy() + traj['x'][0] = current_state['x'] + traj['y'][0] = current_state['y'] + traj['heading'][0] = current_state['heading'] + traj['vx'][0] = current_state['vx'] + traj['vy'][0] = current_state['vy'] + + if isinstance(dynamics_model, JerkDynamics): + traj['a_long'][0] = current_state['a_long'] + traj['a_lat'][0] = current_state['a_lat'] + traj['steering_angle'][0] = current_state['steering_angle'] + + # Greedy search at each timestep + for t in range(num_steps - 1): + # Skip if next timestep is invalid + if not ground_truth['valid'][t + 1]: + break + + # Target state from ground truth + target_state = { + 'x': ground_truth['x'][t + 1], + 'y': ground_truth['y'][t + 1], + 'heading': ground_truth['heading'][t + 1], + 'vx': ground_truth['vx'][t + 1], + 'vy': ground_truth['vy'][t + 1], + } + + # Select best action + action_idx, next_state, error = selector.select_best_action( + current_state, target_state + ) + + # Record action and error + actions[t] = action_idx + errors[t] = error + + # Update trajectory + traj['x'][t + 1] = next_state['x'] + traj['y'][t + 1] = next_state['y'] + traj['heading'][t + 1] = next_state['heading'] + traj['vx'][t + 1] = next_state['vx'] + traj['vy'][t + 1] = next_state['vy'] + + if isinstance(dynamics_model, JerkDynamics): + traj['a_long'][t + 1] = next_state['a_long'] + traj['a_lat'][t + 1] = next_state['a_lat'] + traj['steering_angle'][t + 1] = next_state['steering_angle'] + + # Update current state + current_state = next_state + + return traj, actions, errors + + +def decode_action(dynamics_model, action_idx): + """ + Decode action index into human-readable components. 
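+
+    Example (default CLASSIC discretization, 7 x 13 = 91 actions): action 45
+    decodes to acceleration index 45 // 13 = 3 (0.0 m/s^2) and steering index
+    45 % 13 = 6 (0.0 rad), i.e. coast straight.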
+ + Args: + dynamics_model: ClassicDynamics or JerkDynamics instance + action_idx: Action index to decode + Returns: + Dict with action components + """ + if isinstance(dynamics_model, ClassicDynamics): + accel_idx = action_idx // dynamics_model.num_steer + steer_idx = action_idx % dynamics_model.num_steer + return { + 'type': 'classic', + 'acceleration': dynamics_model.ACCELERATION_VALUES[accel_idx], + 'steering': dynamics_model.STEERING_VALUES[steer_idx], + } + elif isinstance(dynamics_model, JerkDynamics): + jerk_long_idx = action_idx // dynamics_model.num_jerk_lat + jerk_lat_idx = action_idx % dynamics_model.num_jerk_lat + return { + 'type': 'jerk', + 'jerk_long': dynamics_model.JERK_LONG[jerk_long_idx], + 'jerk_lat': dynamics_model.JERK_LAT[jerk_lat_idx], + } + else: + return {'type': 'unknown'} diff --git a/pufferlib/ocean/benchmark/trajectory_approximation/kinematic_realism.py b/pufferlib/ocean/benchmark/trajectory_approximation/kinematic_realism.py new file mode 100644 index 000000000..0c43dc683 --- /dev/null +++ b/pufferlib/ocean/benchmark/trajectory_approximation/kinematic_realism.py @@ -0,0 +1,371 @@ +"""Population-level kinematic realism metrics for trajectory approximation. + +This module computes kinematic distribution similarity between ground truth trajectories +and inverse dynamics approximations, adapted from WOSAC methodology for deterministic systems. +""" + +import numpy as np +import configparser +import os +from typing import Dict, List, Tuple + + +def compute_kinematic_features( + x: np.ndarray, + y: np.ndarray, + heading: np.ndarray, + valid: np.ndarray, + dt: float = 0.1, +) -> Dict[str, np.ndarray]: + """Compute kinematic features (speeds and accelerations) from trajectory. + + Uses central difference for derivatives, matching WOSAC implementation. 
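+
+    For interior steps, the linear quantities reduce to
+
+        speed[t] = ||p[t+1] - p[t-1]|| / (2 * dt)
+        accel[t] = (speed[t+1] - speed[t-1]) / (2 * dt)
+
+    Endpoints are NaN and are masked out by the returned validity arrays.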
+ + Args: + x: X positions, shape (n_steps,) + y: Y positions, shape (n_steps,) + heading: Heading angles, shape (n_steps,) + valid: Valid mask, shape (n_steps,) + dt: Time step in seconds (default 0.1s) + + Returns: + Dict with: + - linear_speed: Shape (n_steps,) + - linear_acceleration: Shape (n_steps,) + - angular_speed: Shape (n_steps,) + - angular_acceleration: Shape (n_steps,) + - speed_valid: Validity mask for speeds + - accel_valid: Validity mask for accelerations + """ + n_steps = len(x) + + # Compute position changes using central difference + # Central diff: f'(x) = [f(x+h) - f(x-h)] / 2h + dx = np.full(n_steps, np.nan) + dy = np.full(n_steps, np.nan) + dx[1:-1] = (x[2:] - x[:-2]) / 2 + dy[1:-1] = (y[2:] - y[:-2]) / 2 + + # Linear speed + linear_speed = np.sqrt(dx**2 + dy**2) / dt + + # Linear acceleration (central diff of speed) + linear_accel = np.full(n_steps, np.nan) + linear_accel[1:-1] = (linear_speed[2:] - linear_speed[:-2]) / (2 * dt) + + # Angular speed (central diff of heading, with wrapping) + dheading = np.full(n_steps, np.nan) + dheading[1:-1] = _wrap_angle((heading[2:] - heading[:-2]) / 2) * 2 # Wrap half-differences + angular_speed = dheading / dt + + # Angular acceleration (central diff of angular speed) + angular_accel = np.full(n_steps, np.nan) + dheading_step = dheading.copy() + angular_accel[1:-1] = _wrap_angle((dheading_step[2:] - dheading_step[:-2]) / 2) * 2 / (dt**2) + + # Compute validity masks using central logical_and + # Speed is valid if neighbors are valid + speed_valid = np.zeros(n_steps, dtype=bool) + speed_valid[1:-1] = valid[2:] & valid[:-2] + + # Acceleration is valid if speed neighbors are valid + accel_valid = np.zeros(n_steps, dtype=bool) + accel_valid[1:-1] = speed_valid[2:] & speed_valid[:-2] + + return { + 'linear_speed': linear_speed, + 'linear_acceleration': linear_accel, + 'angular_speed': angular_speed, + 'angular_acceleration': angular_accel, + 'speed_valid': speed_valid, + 'accel_valid': accel_valid, + } + + +def _wrap_angle(angle: np.ndarray) -> np.ndarray: + """Wrap angles to [-pi, pi] range.""" + return (angle + np.pi) % (2 * np.pi) - np.pi + + +def aggregate_features_across_segments( + all_segment_features: List[Dict[str, np.ndarray]] +) -> Dict[str, np.ndarray]: + """Flatten and concatenate features from multiple segments into population arrays. 
+ + Args: + all_segment_features: List of feature dicts from compute_kinematic_features() + + Returns: + Dict with flattened population-level features: + - linear_speed: (N_total_valid_steps,) + - linear_acceleration: (N_total_valid_steps,) + - angular_speed: (N_total_valid_steps,) + - angular_acceleration: (N_total_valid_steps,) + """ + # Collect all valid samples across segments + linear_speeds = [] + linear_accels = [] + angular_speeds = [] + angular_accels = [] + + for features in all_segment_features: + # Filter by validity and exclude NaNs + speed_mask = features['speed_valid'] & ~np.isnan(features['linear_speed']) & ~np.isnan(features['angular_speed']) + accel_mask = features['accel_valid'] & ~np.isnan(features['linear_acceleration']) & ~np.isnan(features['angular_acceleration']) + + linear_speeds.append(features['linear_speed'][speed_mask]) + angular_speeds.append(features['angular_speed'][speed_mask]) + linear_accels.append(features['linear_acceleration'][accel_mask]) + angular_accels.append(features['angular_acceleration'][accel_mask]) + + return { + 'linear_speed': np.concatenate(linear_speeds) if linear_speeds else np.array([]), + 'linear_acceleration': np.concatenate(linear_accels) if linear_accels else np.array([]), + 'angular_speed': np.concatenate(angular_speeds) if angular_speeds else np.array([]), + 'angular_acceleration': np.concatenate(angular_accels) if angular_accels else np.array([]), + } + + +def compute_histogram_log_likelihood( + log_samples: np.ndarray, + sim_samples: np.ndarray, + min_val: float, + max_val: float, + num_bins: int, + additive_smoothing: float = 0.1, +) -> float: + """Compute mean log-likelihood of log samples under sim distribution. + + Adapted from WOSAC's histogram_estimate function for population-level evaluation. 
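+
+    With additive (Laplace) smoothing pseudocount alpha, each bin probability
+    is
+
+        P(bin_i) = (count_i + alpha) / sum_j (count_j + alpha)
+
+    and the returned value is mean_k log P(bin(x_k)) over the log samples.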
+ + Args: + log_samples: Samples to evaluate (e.g., ground truth features) + sim_samples: Samples to build distribution from (e.g., simulated features) + min_val: Minimum value for histogram bins + max_val: Maximum value for histogram bins + num_bins: Number of histogram bins + additive_smoothing: Laplace smoothing pseudocount + + Returns: + Mean log-likelihood (scalar) + """ + if len(log_samples) == 0 or len(sim_samples) == 0: + return np.nan + + # Clip samples to valid range + log_samples_clipped = np.clip(log_samples, min_val, max_val) + sim_samples_clipped = np.clip(sim_samples, min_val, max_val) + + # Create bin edges + edges = np.linspace(min_val, max_val, num_bins + 1) + + # Build histogram from sim samples + sim_counts, _ = np.histogram(sim_samples_clipped, bins=edges) + + # Apply smoothing and normalize to probabilities + sim_counts = sim_counts.astype(float) + additive_smoothing + sim_probs = sim_counts / sim_counts.sum() + + # Find which bin each log sample belongs to + log_bins = np.digitize(log_samples_clipped, edges, right=False) - 1 + log_bins = np.clip(log_bins, 0, num_bins - 1) + + # Get log probabilities for each sample + log_probs = np.log(sim_probs[log_bins]) + + # Return mean log-likelihood + return float(np.mean(log_probs)) + + +def load_wosac_config() -> configparser.ConfigParser: + """Load WOSAC metric configuration.""" + config_path = os.path.join( + os.path.dirname(os.path.dirname(__file__)), + 'wosac.ini' + ) + config = configparser.ConfigParser() + config.read(config_path) + return config + + +def compute_kinematic_realism_score( + gt_features: Dict[str, np.ndarray], + sim_features: Dict[str, np.ndarray], + config: configparser.ConfigParser = None, +) -> Dict[str, float]: + """Compute population-level kinematic realism scores. + + Evaluates how likely ground truth features are under simulated distribution, + using WOSAC histogram-based log-likelihood estimation. 
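+
+    Usage sketch (gt_feature_dicts / sim_feature_dicts are hypothetical lists
+    of per-segment outputs from compute_kinematic_features; assumes the
+    packaged wosac.ini defines the four feature sections read below):
+
+        gt_pop = aggregate_features_across_segments(gt_feature_dicts)
+        sim_pop = aggregate_features_across_segments(sim_feature_dicts)
+        scores = compute_kinematic_realism_score(gt_pop, sim_pop)
+        print(scores['kinematic_realism_score'])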
+
+    Args:
+        gt_features: Ground truth population features from aggregate_features_across_segments()
+        sim_features: Simulated population features from aggregate_features_across_segments()
+        config: WOSAC config (if None, loads default wosac.ini)
+
+    Returns:
+        Dict with:
+        - linear_speed_likelihood: exp(mean_log_likelihood)
+        - linear_acceleration_likelihood: exp(mean_log_likelihood)
+        - angular_speed_likelihood: exp(mean_log_likelihood)
+        - angular_acceleration_likelihood: exp(mean_log_likelihood)
+        - kinematic_realism_score: Weighted average of above
+    """
+    if config is None:
+        config = load_wosac_config()
+
+    # Compute log-likelihoods for each feature; the four computations differ
+    # only in the config section name, so loop over the feature names
+    metrics = {}
+    feature_names = ('linear_speed', 'linear_acceleration',
+                     'angular_speed', 'angular_acceleration')
+
+    for feature in feature_names:
+        min_val = config.getfloat(feature, 'histogram.min_val')
+        max_val = config.getfloat(feature, 'histogram.max_val')
+        num_bins = config.getint(feature, 'histogram.num_bins')
+        smoothing = config.getfloat(feature, 'histogram.additive_smoothing_pseudocount')
+
+        log_likelihood = compute_histogram_log_likelihood(
+            gt_features[feature],
+            sim_features[feature],
+            min_val, max_val, num_bins, smoothing
+        )
+        metrics[f'{feature}_log_likelihood'] = log_likelihood
+        metrics[f'{feature}_likelihood'] = np.exp(log_likelihood)
+
+    # Compute weighted kinematic realism score (metametric)
+    weights = {
+        'linear_speed': config.getfloat('linear_speed', 'metametric_weight'),
+        'linear_acceleration': config.getfloat('linear_acceleration', 'metametric_weight'),
+        'angular_speed': config.getfloat('angular_speed', 'metametric_weight'),
+        'angular_acceleration': config.getfloat('angular_acceleration',
'metametric_weight'), + } + + kinematic_score = ( + weights['linear_speed'] * metrics['linear_speed_likelihood'] + + weights['linear_acceleration'] * metrics['linear_acceleration_likelihood'] + + weights['angular_speed'] * metrics['angular_speed_likelihood'] + + weights['angular_acceleration'] * metrics['angular_acceleration_likelihood'] + ) / sum(weights.values()) + + metrics['kinematic_realism_score'] = kinematic_score + + return metrics + + +def compute_leave_one_out_baseline( + all_segment_trajectories: List[Dict[str, np.ndarray]], + dt: float = 0.1, +) -> Dict[str, float]: + """Compute leave-one-out kinematic realism baseline for ground truth. + + Each segment is evaluated against the distribution built from all other segments. + + Args: + all_segment_trajectories: List of trajectory dicts with {x, y, heading, valid, ...} + dt: Time step in seconds + + Returns: + Dict with baseline kinematic realism metrics + """ + config = load_wosac_config() + n_segments = len(all_segment_trajectories) + + # Compute features for all segments + all_features = [] + for traj in all_segment_trajectories: + features = compute_kinematic_features( + traj['x'], traj['y'], traj['heading'], traj['valid'], dt + ) + all_features.append(features) + + # Leave-one-out evaluation + segment_scores = [] + + for i in range(n_segments): + # Build distribution from all segments except i + other_features = [all_features[j] for j in range(n_segments) if j != i] + population_features = aggregate_features_across_segments(other_features) + + # Evaluate segment i against population + test_features = aggregate_features_across_segments([all_features[i]]) + + score = compute_kinematic_realism_score( + test_features, + population_features, + config + ) + segment_scores.append(score) + + # Average across all segments + avg_metrics = {} + for key in segment_scores[0].keys(): + values = [s[key] for s in segment_scores if not np.isnan(s[key])] + avg_metrics[key] = float(np.mean(values)) if values else np.nan + + return avg_metrics + + +def format_kinematic_summary(metrics: Dict[str, float], model_name: str = '') -> str: + """Format kinematic realism metrics as readable string. + + Args: + metrics: Dict from compute_kinematic_realism_score() + model_name: Optional model name for display + + Returns: + Formatted string + """ + lines = [] + lines.append(f"\n{'='*60}") + lines.append(f"Kinematic Realism Metrics: {model_name}") + lines.append(f"{'='*60}") + lines.append(f" Linear speed: {metrics['linear_speed_likelihood']:.4f}") + lines.append(f" Linear acceleration: {metrics['linear_acceleration_likelihood']:.4f}") + lines.append(f" Angular speed: {metrics['angular_speed_likelihood']:.4f}") + lines.append(f" Angular acceleration: {metrics['angular_acceleration_likelihood']:.4f}") + lines.append(f"{'─'*60}") + lines.append(f" Kinematic Realism Score: {metrics['kinematic_realism_score']:.4f}") + lines.append(f"{'='*60}") + + return '\n'.join(lines) diff --git a/pufferlib/ocean/benchmark/trajectory_approximation/trajectory_loader.py b/pufferlib/ocean/benchmark/trajectory_approximation/trajectory_loader.py new file mode 100644 index 000000000..7325486db --- /dev/null +++ b/pufferlib/ocean/benchmark/trajectory_approximation/trajectory_loader.py @@ -0,0 +1,119 @@ +"""Load Waymo trajectory data from JSON files.""" + +import json +import numpy as np +import glob +import os + + +def load_trajectory_from_json(json_path): + """ + Load Waymo trajectory data from JSON file. 
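+
+    Expected layout, inferred from the field accesses below (sketch; values
+    fabricated):
+
+        {"objects": [{"type": "vehicle", "id": 17,
+                      "position": [{"x": 0.0, "y": 0.0, "z": 0.0}, ...],
+                      "velocity": [{"x": 5.0, "y": 0.0}, ...],
+                      "heading": [0.0, ...], "valid": [true, ...],
+                      "width": 2.0, "length": 4.5, "height": 1.6}]}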
+ + Args: + json_path: Path to JSON file + Returns: + List of trajectory dicts, one per vehicle object + """ + with open(json_path, 'r') as f: + scenario = json.load(f) + + trajectories = [] + scenario_name = os.path.basename(json_path) + + for obj in scenario['objects']: + # Only process vehicles (skip pedestrians) + if obj['type'] != 'vehicle': + continue + + # Extract trajectory data + traj = { + 'x': np.array([p['x'] for p in obj['position']], dtype=np.float32), + 'y': np.array([p['y'] for p in obj['position']], dtype=np.float32), + 'z': np.array([p['z'] for p in obj['position']], dtype=np.float32), + 'heading': np.array(obj['heading'], dtype=np.float32), + 'vx': np.array([v['x'] for v in obj['velocity']], dtype=np.float32), + 'vy': np.array([v['y'] for v in obj['velocity']], dtype=np.float32), + 'valid': np.array(obj['valid'], dtype=bool), + 'width': obj['width'], + 'length': obj['length'], + 'height': obj['height'], + 'object_id': obj['id'], + 'scenario': scenario_name, + } + + # Filter out invalid velocity sentinels (-10000 indicates missing data) + invalid_mask = (traj['vx'] == -10000) | (traj['vy'] == -10000) + traj['valid'] = traj['valid'] & ~invalid_mask + + trajectories.append(traj) + + return trajectories + + +def load_trajectories_from_directory(data_dir, num_scenarios=None, pattern='*.json'): + """ + Load trajectories from all JSON files in a directory. + + Args: + data_dir: Directory containing JSON files + num_scenarios: Maximum number of scenarios to load (None = all) + pattern: Glob pattern for files (default '*.json') + Returns: + List of all trajectory dicts from all scenarios + """ + json_files = sorted(glob.glob(os.path.join(data_dir, pattern))) + + if num_scenarios is not None: + json_files = json_files[:num_scenarios] + + all_trajectories = [] + for json_path in json_files: + try: + trajectories = load_trajectory_from_json(json_path) + all_trajectories.extend(trajectories) + except Exception as e: + print(f"Warning: Failed to load {json_path}: {e}") + continue + + return all_trajectories + + +def get_trajectory_statistics(trajectories): + """ + Compute statistics across a list of trajectories. 
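+
+    Usage sketch (hypothetical data directory):
+
+        trajs = load_trajectories_from_directory('data/waymo_json', num_scenarios=10)
+        stats = get_trajectory_statistics(trajs)
+        print(f"{stats['num_trajectories']} tracks, "
+              f"mean speed {stats['speed_mean']:.2f} m/s")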
+ + Args: + trajectories: List of trajectory dicts + Returns: + Dict with statistics + """ + num_trajectories = len(trajectories) + + # Count valid timesteps + total_valid_timesteps = sum(np.sum(traj['valid']) for traj in trajectories) + avg_valid_timesteps = total_valid_timesteps / num_trajectories if num_trajectories > 0 else 0 + + # Speed statistics + speeds = [] + for traj in trajectories: + speed = np.sqrt(traj['vx'][traj['valid']]**2 + traj['vy'][traj['valid']]**2) + speeds.extend(speed.tolist()) + speeds = np.array(speeds) + + # Length statistics + lengths = np.array([traj['length'] for traj in trajectories]) + widths = np.array([traj['width'] for traj in trajectories]) + + return { + 'num_trajectories': num_trajectories, + 'avg_valid_timesteps': avg_valid_timesteps, + 'total_valid_timesteps': total_valid_timesteps, + 'speed_mean': np.mean(speeds) if len(speeds) > 0 else 0, + 'speed_std': np.std(speeds) if len(speeds) > 0 else 0, + 'speed_max': np.max(speeds) if len(speeds) > 0 else 0, + 'length_mean': np.mean(lengths), + 'length_std': np.std(lengths), + 'width_mean': np.mean(widths), + 'width_std': np.std(widths), + } diff --git a/pufferlib/ocean/benchmark/trajectory_approximation/trajectory_utils.py b/pufferlib/ocean/benchmark/trajectory_approximation/trajectory_utils.py new file mode 100644 index 000000000..1d0ddb745 --- /dev/null +++ b/pufferlib/ocean/benchmark/trajectory_approximation/trajectory_utils.py @@ -0,0 +1,145 @@ +"""Utility functions for trajectory processing and state initialization.""" + +import numpy as np + + +def wrap_angle(angle): + """ + Wrap angle to [-pi, pi] range. + + Args: + angle: Angle or array of angles in radians + Returns: + Wrapped angle in [-pi, pi] + """ + return (angle + np.pi) % (2 * np.pi) - np.pi + + +def find_valid_segments(valid_mask, min_length=10): + """ + Extract continuous valid segments from trajectory. + + Args: + valid_mask: Boolean array indicating valid timesteps + min_length: Minimum segment length to consider + Returns: + List of (start_idx, end_idx) tuples + """ + segments = [] + in_segment = False + start = 0 + + for i, is_valid in enumerate(valid_mask): + if is_valid and not in_segment: + start = i + in_segment = True + elif not is_valid and in_segment: + if i - start >= min_length: + segments.append((start, i)) + in_segment = False + + # Handle segment at end + if in_segment and len(valid_mask) - start >= min_length: + segments.append((start, len(valid_mask))) + + return segments + + +def extract_segment(trajectory, start, end): + """ + Extract a segment from a trajectory dict. + + Args: + trajectory: Dict with arrays {x, y, z, heading, vx, vy, valid, ...} + start: Start index (inclusive) + end: End index (exclusive) + Returns: + Trajectory dict with sliced arrays + """ + segment = {} + for key in ['x', 'y', 'z', 'heading', 'vx', 'vy', 'valid']: + if key in trajectory: + segment[key] = trajectory[key][start:end].copy() + + # Copy scalar values + for key in ['width', 'length', 'height', 'object_id']: + if key in trajectory: + segment[key] = trajectory[key] + + return segment + + +def initialize_classic_state(trajectory): + """ + Initialize CLASSIC dynamics state from trajectory. 
+
+    Args:
+        trajectory: Dict with {x, y, heading, vx, vy, length, ...}
+    Returns:
+        Initial state dict for CLASSIC model
+    """
+    return {
+        'x': trajectory['x'][0],
+        'y': trajectory['y'][0],
+        'heading': trajectory['heading'][0],
+        'vx': trajectory['vx'][0],
+        'vy': trajectory['vy'][0],
+        'length': trajectory['length'],
+    }
+
+
+def initialize_jerk_state(trajectory, dt=0.1):
+    """
+    Initialize JERK dynamics state from trajectory.
+    Estimates initial acceleration and steering angle from the first two timesteps.
+
+    Args:
+        trajectory: Dict with {x, y, heading, vx, vy, length, ...}
+        dt: Time delta between steps (default 0.1s)
+    Returns:
+        Initial state dict for JERK model
+    """
+    # Compute initial speed
+    v0 = np.sqrt(trajectory['vx'][0]**2 + trajectory['vy'][0]**2)
+
+    # Estimate longitudinal acceleration from speed change
+    if len(trajectory['vx']) > 1:
+        v1 = np.sqrt(trajectory['vx'][1]**2 + trajectory['vy'][1]**2)
+        a_long_init = (v1 - v0) / dt
+        a_long_init = np.clip(a_long_init, -5.0, 2.5)
+    else:
+        a_long_init = 0.0
+
+    # Estimate curvature from heading change
+    if len(trajectory['heading']) > 1:
+        delta_heading = wrap_angle(trajectory['heading'][1] - trajectory['heading'][0])
+        distance = np.sqrt(
+            (trajectory['x'][1] - trajectory['x'][0])**2 +
+            (trajectory['y'][1] - trajectory['y'][0])**2
+        )
+        # Guard against division by a near-zero displacement
+        curvature = delta_heading / distance if distance > 1e-5 else 0.0
+    else:
+        curvature = 0.0
+
+    # Estimate steering angle from curvature
+    wheelbase = trajectory['length']  # Approximate wheelbase as vehicle length
+    steering_angle_init = np.arctan(curvature * wheelbase)
+    steering_angle_init = np.clip(steering_angle_init, -0.55, 0.55)
+
+    # Estimate lateral acceleration from curvature
+    # a_lat = v^2 * curvature
+    a_lat_init = v0**2 * curvature
+    a_lat_init = np.clip(a_lat_init, -4.0, 4.0)
+
+    return {
+        'x': trajectory['x'][0],
+        'y': trajectory['y'][0],
+        'heading': trajectory['heading'][0],
+        'vx': trajectory['vx'][0],
+        'vy': trajectory['vy'][0],
+        'a_long': a_long_init,
+        'a_lat': a_lat_init,
+        'steering_angle': steering_angle_init,
+        'wheelbase': wheelbase,
+    }
diff --git a/pufferlib/ocean/benchmark/trajectory_approximation/visualizer.py b/pufferlib/ocean/benchmark/trajectory_approximation/visualizer.py
new file mode 100644
index 000000000..4019e247e
--- /dev/null
+++ b/pufferlib/ocean/benchmark/trajectory_approximation/visualizer.py
@@ -0,0 +1,391 @@
+"""Visualization functions for trajectory approximation results."""
+
+import matplotlib
+matplotlib.use('Agg')  # Select the non-interactive backend before pyplot is imported
+import matplotlib.pyplot as plt
+import numpy as np
+
+
+def plot_trajectory_comparison(ground_truth, classic_approx, jerk_approx,
+                               segment_range, output_path):
+    """
+    Create 6-panel comparison plot.
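+
+    Typical call (hypothetical inputs; see the Panels list below for layout):
+
+        plot_trajectory_comparison(gt_segment, classic_traj, jerk_traj,
+                                   segment_range=(start, end),
+                                   output_path='plots/segment_000.png')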
+ + Panels: + - Row 1: XY trajectory, Speed over time, Heading over time + - Row 2: Position error, Velocity error, Heading error + + Args: + ground_truth: Ground truth trajectory dict + classic_approx: CLASSIC model approximation dict + jerk_approx: JERK model approximation dict + segment_range: Tuple (start, end) for segment indices + output_path: Path to save the plot + """ + fig, axes = plt.subplots(2, 3, figsize=(18, 10)) + start, end = segment_range + + # Time axis + t = np.arange(end - start) * 0.1 # 0.1s timestep + + # Panel 1: XY trajectory + axes[0, 0].plot(ground_truth['x'], ground_truth['y'], + 'k-', label='Ground Truth', linewidth=2) + axes[0, 0].plot(classic_approx['x'], classic_approx['y'], + 'r--', label='CLASSIC', alpha=0.7, linewidth=1.5) + axes[0, 0].plot(jerk_approx['x'], jerk_approx['y'], + 'b--', label='JERK', alpha=0.7, linewidth=1.5) + axes[0, 0].set_xlabel('X (m)') + axes[0, 0].set_ylabel('Y (m)') + axes[0, 0].legend() + axes[0, 0].set_title('Trajectory Comparison') + axes[0, 0].axis('equal') + axes[0, 0].grid(True, alpha=0.3) + + # Panel 2: Speed over time + gt_speed = np.sqrt(ground_truth['vx']**2 + ground_truth['vy']**2) + classic_speed = np.sqrt(classic_approx['vx']**2 + classic_approx['vy']**2) + jerk_speed = np.sqrt(jerk_approx['vx']**2 + jerk_approx['vy']**2) + + axes[0, 1].plot(t, gt_speed, 'k-', label='Ground Truth', linewidth=2) + axes[0, 1].plot(t, classic_speed, 'r--', label='CLASSIC', alpha=0.7, linewidth=1.5) + axes[0, 1].plot(t, jerk_speed, 'b--', label='JERK', alpha=0.7, linewidth=1.5) + axes[0, 1].set_xlabel('Time (s)') + axes[0, 1].set_ylabel('Speed (m/s)') + axes[0, 1].legend() + axes[0, 1].set_title('Speed Profile') + axes[0, 1].grid(True, alpha=0.3) + + # Panel 3: Heading over time + axes[0, 2].plot(t, ground_truth['heading'] * 180/np.pi, + 'k-', label='Ground Truth', linewidth=2) + axes[0, 2].plot(t, classic_approx['heading'] * 180/np.pi, + 'r--', label='CLASSIC', alpha=0.7, linewidth=1.5) + axes[0, 2].plot(t, jerk_approx['heading'] * 180/np.pi, + 'b--', label='JERK', alpha=0.7, linewidth=1.5) + axes[0, 2].set_xlabel('Time (s)') + axes[0, 2].set_ylabel('Heading (deg)') + axes[0, 2].legend() + axes[0, 2].set_title('Heading Profile') + axes[0, 2].grid(True, alpha=0.3) + + # Panel 4: Position error over time + classic_pos_error = np.sqrt( + (classic_approx['x'] - ground_truth['x'])**2 + + (classic_approx['y'] - ground_truth['y'])**2 + ) + jerk_pos_error = np.sqrt( + (jerk_approx['x'] - ground_truth['x'])**2 + + (jerk_approx['y'] - ground_truth['y'])**2 + ) + + axes[1, 0].plot(t, classic_pos_error, 'r-', label='CLASSIC', linewidth=1.5) + axes[1, 0].plot(t, jerk_pos_error, 'b-', label='JERK', linewidth=1.5) + axes[1, 0].set_xlabel('Time (s)') + axes[1, 0].set_ylabel('Position Error (m)') + axes[1, 0].legend() + axes[1, 0].set_title('Position Error Over Time') + axes[1, 0].grid(True, alpha=0.3) + + # Panel 5: Speed error over time + classic_speed_error = np.abs(classic_speed - gt_speed) + jerk_speed_error = np.abs(jerk_speed - gt_speed) + + axes[1, 1].plot(t, classic_speed_error, 'r-', label='CLASSIC', linewidth=1.5) + axes[1, 1].plot(t, jerk_speed_error, 'b-', label='JERK', linewidth=1.5) + axes[1, 1].set_xlabel('Time (s)') + axes[1, 1].set_ylabel('Speed Error (m/s)') + axes[1, 1].legend() + axes[1, 1].set_title('Speed Error Over Time') + axes[1, 1].grid(True, alpha=0.3) + + # Panel 6: Heading error over time + from .trajectory_utils import wrap_angle + classic_heading_error = np.abs(wrap_angle( + classic_approx['heading'] - ground_truth['heading'] 
+ )) * 180/np.pi + jerk_heading_error = np.abs(wrap_angle( + jerk_approx['heading'] - ground_truth['heading'] + )) * 180/np.pi + + axes[1, 2].plot(t, classic_heading_error, 'r-', label='CLASSIC', linewidth=1.5) + axes[1, 2].plot(t, jerk_heading_error, 'b-', label='JERK', linewidth=1.5) + axes[1, 2].set_xlabel('Time (s)') + axes[1, 2].set_ylabel('Heading Error (deg)') + axes[1, 2].legend() + axes[1, 2].set_title('Heading Error Over Time') + axes[1, 2].grid(True, alpha=0.3) + + plt.tight_layout() + plt.savefig(output_path, dpi=150, bbox_inches='tight') + plt.close() + + +def plot_jerk_ablation_trajectory(ground_truth, baseline_approx, expanded_approx, + segment_range, output_path, config_names=None): + """ + Create 6-panel comparison plot for JERK ablation study. + + Panels: + - Row 1: XY trajectory, Speed over time, Heading over time + - Row 2: Position error, Velocity error, Heading error + + Args: + ground_truth: Ground truth trajectory dict + baseline_approx: Baseline JERK approximation dict + expanded_approx: Expanded JERK approximation dict + segment_range: Tuple (start, end) for segment indices + output_path: Path to save the plot + config_names: Optional tuple of (baseline_name, expanded_name) + """ + if config_names is None: + config_names = ('JERK Baseline', 'JERK Expanded') + + baseline_name, expanded_name = config_names + + fig, axes = plt.subplots(2, 3, figsize=(18, 10)) + start, end = segment_range + + # Time axis + t = np.arange(end - start) * 0.1 # 0.1s timestep + + # Panel 1: XY trajectory + axes[0, 0].plot(ground_truth['x'], ground_truth['y'], + 'k-', label='Ground Truth', linewidth=2) + axes[0, 0].plot(baseline_approx['x'], baseline_approx['y'], + 'r--', label=baseline_name, alpha=0.7, linewidth=1.5) + axes[0, 0].plot(expanded_approx['x'], expanded_approx['y'], + 'b--', label=expanded_name, alpha=0.7, linewidth=1.5) + axes[0, 0].set_xlabel('X (m)') + axes[0, 0].set_ylabel('Y (m)') + axes[0, 0].legend() + axes[0, 0].set_title('Trajectory Comparison') + axes[0, 0].axis('equal') + axes[0, 0].grid(True, alpha=0.3) + + # Panel 2: Speed over time + gt_speed = np.sqrt(ground_truth['vx']**2 + ground_truth['vy']**2) + baseline_speed = np.sqrt(baseline_approx['vx']**2 + baseline_approx['vy']**2) + expanded_speed = np.sqrt(expanded_approx['vx']**2 + expanded_approx['vy']**2) + + axes[0, 1].plot(t, gt_speed, 'k-', label='Ground Truth', linewidth=2) + axes[0, 1].plot(t, baseline_speed, 'r--', label=baseline_name, alpha=0.7, linewidth=1.5) + axes[0, 1].plot(t, expanded_speed, 'b--', label=expanded_name, alpha=0.7, linewidth=1.5) + axes[0, 1].set_xlabel('Time (s)') + axes[0, 1].set_ylabel('Speed (m/s)') + axes[0, 1].legend() + axes[0, 1].set_title('Speed Profile') + axes[0, 1].grid(True, alpha=0.3) + + # Panel 3: Heading over time + axes[0, 2].plot(t, ground_truth['heading'] * 180/np.pi, + 'k-', label='Ground Truth', linewidth=2) + axes[0, 2].plot(t, baseline_approx['heading'] * 180/np.pi, + 'r--', label=baseline_name, alpha=0.7, linewidth=1.5) + axes[0, 2].plot(t, expanded_approx['heading'] * 180/np.pi, + 'b--', label=expanded_name, alpha=0.7, linewidth=1.5) + axes[0, 2].set_xlabel('Time (s)') + axes[0, 2].set_ylabel('Heading (deg)') + axes[0, 2].legend() + axes[0, 2].set_title('Heading Profile') + axes[0, 2].grid(True, alpha=0.3) + + # Panel 4: Position error over time + baseline_pos_error = np.sqrt( + (baseline_approx['x'] - ground_truth['x'])**2 + + (baseline_approx['y'] - ground_truth['y'])**2 + ) + expanded_pos_error = np.sqrt( + (expanded_approx['x'] - ground_truth['x'])**2 + + 
(expanded_approx['y'] - ground_truth['y'])**2 + ) + + axes[1, 0].plot(t, baseline_pos_error, 'r-', label=baseline_name, linewidth=1.5) + axes[1, 0].plot(t, expanded_pos_error, 'b-', label=expanded_name, linewidth=1.5) + axes[1, 0].set_xlabel('Time (s)') + axes[1, 0].set_ylabel('Position Error (m)') + axes[1, 0].legend() + axes[1, 0].set_title('Position Error Over Time') + axes[1, 0].grid(True, alpha=0.3) + + # Panel 5: Speed error over time + baseline_speed_error = np.abs(baseline_speed - gt_speed) + expanded_speed_error = np.abs(expanded_speed - gt_speed) + + axes[1, 1].plot(t, baseline_speed_error, 'r-', label=baseline_name, linewidth=1.5) + axes[1, 1].plot(t, expanded_speed_error, 'b-', label=expanded_name, linewidth=1.5) + axes[1, 1].set_xlabel('Time (s)') + axes[1, 1].set_ylabel('Speed Error (m/s)') + axes[1, 1].legend() + axes[1, 1].set_title('Speed Error Over Time') + axes[1, 1].grid(True, alpha=0.3) + + # Panel 6: Heading error over time + from .trajectory_utils import wrap_angle + baseline_heading_error = np.abs(wrap_angle( + baseline_approx['heading'] - ground_truth['heading'] + )) * 180/np.pi + expanded_heading_error = np.abs(wrap_angle( + expanded_approx['heading'] - ground_truth['heading'] + )) * 180/np.pi + + axes[1, 2].plot(t, baseline_heading_error, 'r-', label=baseline_name, linewidth=1.5) + axes[1, 2].plot(t, expanded_heading_error, 'b-', label=expanded_name, linewidth=1.5) + axes[1, 2].set_xlabel('Time (s)') + axes[1, 2].set_ylabel('Heading Error (deg)') + axes[1, 2].legend() + axes[1, 2].set_title('Heading Error Over Time') + axes[1, 2].grid(True, alpha=0.3) + + plt.tight_layout() + plt.savefig(output_path, dpi=150, bbox_inches='tight') + plt.close() + + +def plot_aggregate_metrics(all_metrics_classic, all_metrics_jerk, output_path): + """ + Create box plots comparing CLASSIC vs JERK across all trajectories. 
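+
+    Usage sketch (metric dicts as produced by compute_all_metrics(), carrying
+    nested keys such as m['position']['ade'] and m['velocity']['speed_mae']):
+
+        plot_aggregate_metrics(classic_metrics, jerk_metrics,
+                               'plots/aggregate_boxplots.png')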
+ + Args: + all_metrics_classic: List of metric dicts for CLASSIC model + all_metrics_jerk: List of metric dicts for JERK model + output_path: Path to save the plot + """ + fig, axes = plt.subplots(2, 3, figsize=(18, 10)) + + # Extract metrics for CLASSIC + classic_ade = [m['position']['ade'] for m in all_metrics_classic] + classic_speed = [m['velocity']['speed_mae'] for m in all_metrics_classic] + classic_heading = [m['heading']['mae'] for m in all_metrics_classic] + classic_fde = [m['position']['fde'] for m in all_metrics_classic] + classic_speed_max = [m['velocity']['speed_max'] for m in all_metrics_classic] + classic_heading_max = [m['heading']['max'] for m in all_metrics_classic] + + # Extract metrics for JERK + jerk_ade = [m['position']['ade'] for m in all_metrics_jerk] + jerk_speed = [m['velocity']['speed_mae'] for m in all_metrics_jerk] + jerk_heading = [m['heading']['mae'] for m in all_metrics_jerk] + jerk_fde = [m['position']['fde'] for m in all_metrics_jerk] + jerk_speed_max = [m['velocity']['speed_max'] for m in all_metrics_jerk] + jerk_heading_max = [m['heading']['max'] for m in all_metrics_jerk] + + # Panel 1: Average Displacement Error + bp1 = axes[0, 0].boxplot([classic_ade, jerk_ade], + labels=['CLASSIC', 'JERK'], + patch_artist=True) + bp1['boxes'][0].set_facecolor('red') + bp1['boxes'][1].set_facecolor('blue') + axes[0, 0].set_ylabel('ADE (m)') + axes[0, 0].set_title('Average Displacement Error') + axes[0, 0].grid(True, alpha=0.3, axis='y') + + # Panel 2: Speed MAE + bp2 = axes[0, 1].boxplot([classic_speed, jerk_speed], + labels=['CLASSIC', 'JERK'], + patch_artist=True) + bp2['boxes'][0].set_facecolor('red') + bp2['boxes'][1].set_facecolor('blue') + axes[0, 1].set_ylabel('Speed MAE (m/s)') + axes[0, 1].set_title('Speed Error') + axes[0, 1].grid(True, alpha=0.3, axis='y') + + # Panel 3: Heading MAE + bp3 = axes[0, 2].boxplot([classic_heading, jerk_heading], + labels=['CLASSIC', 'JERK'], + patch_artist=True) + bp3['boxes'][0].set_facecolor('red') + bp3['boxes'][1].set_facecolor('blue') + axes[0, 2].set_ylabel('Heading MAE (deg)') + axes[0, 2].set_title('Heading Error') + axes[0, 2].grid(True, alpha=0.3, axis='y') + + # Panel 4: Final Displacement Error + bp4 = axes[1, 0].boxplot([classic_fde, jerk_fde], + labels=['CLASSIC', 'JERK'], + patch_artist=True) + bp4['boxes'][0].set_facecolor('red') + bp4['boxes'][1].set_facecolor('blue') + axes[1, 0].set_ylabel('FDE (m)') + axes[1, 0].set_title('Final Displacement Error') + axes[1, 0].grid(True, alpha=0.3, axis='y') + + # Panel 5: Max Speed Error + bp5 = axes[1, 1].boxplot([classic_speed_max, jerk_speed_max], + labels=['CLASSIC', 'JERK'], + patch_artist=True) + bp5['boxes'][0].set_facecolor('red') + bp5['boxes'][1].set_facecolor('blue') + axes[1, 1].set_ylabel('Max Speed Error (m/s)') + axes[1, 1].set_title('Maximum Speed Error') + axes[1, 1].grid(True, alpha=0.3, axis='y') + + # Panel 6: Max Heading Error + bp6 = axes[1, 2].boxplot([classic_heading_max, jerk_heading_max], + labels=['CLASSIC', 'JERK'], + patch_artist=True) + bp6['boxes'][0].set_facecolor('red') + bp6['boxes'][1].set_facecolor('blue') + axes[1, 2].set_ylabel('Max Heading Error (deg)') + axes[1, 2].set_title('Maximum Heading Error') + axes[1, 2].grid(True, alpha=0.3, axis='y') + + # Add overall title + fig.suptitle(f'Trajectory Approximation Comparison ({len(all_metrics_classic)} segments)', + fontsize=16, fontweight='bold') + + plt.tight_layout() + plt.savefig(output_path, dpi=150, bbox_inches='tight') + plt.close() + + +def 
plot_error_histogram(all_metrics_classic, all_metrics_jerk, output_path): + """ + Create histograms of error distributions. + + Args: + all_metrics_classic: List of metric dicts for CLASSIC model + all_metrics_jerk: List of metric dicts for JERK model + output_path: Path to save the plot + """ + fig, axes = plt.subplots(1, 3, figsize=(18, 5)) + + # Extract metrics + classic_ade = [m['position']['ade'] for m in all_metrics_classic] + jerk_ade = [m['position']['ade'] for m in all_metrics_jerk] + + classic_speed = [m['velocity']['speed_mae'] for m in all_metrics_classic] + jerk_speed = [m['velocity']['speed_mae'] for m in all_metrics_jerk] + + classic_heading = [m['heading']['mae'] for m in all_metrics_classic] + jerk_heading = [m['heading']['mae'] for m in all_metrics_jerk] + + # Panel 1: Position error histogram + axes[0].hist(classic_ade, bins=30, alpha=0.5, color='red', label='CLASSIC') + axes[0].hist(jerk_ade, bins=30, alpha=0.5, color='blue', label='JERK') + axes[0].set_xlabel('ADE (m)') + axes[0].set_ylabel('Frequency') + axes[0].set_title('Position Error Distribution') + axes[0].legend() + axes[0].grid(True, alpha=0.3, axis='y') + + # Panel 2: Speed error histogram + axes[1].hist(classic_speed, bins=30, alpha=0.5, color='red', label='CLASSIC') + axes[1].hist(jerk_speed, bins=30, alpha=0.5, color='blue', label='JERK') + axes[1].set_xlabel('Speed MAE (m/s)') + axes[1].set_ylabel('Frequency') + axes[1].set_title('Speed Error Distribution') + axes[1].legend() + axes[1].grid(True, alpha=0.3, axis='y') + + # Panel 3: Heading error histogram + axes[2].hist(classic_heading, bins=30, alpha=0.5, color='red', label='CLASSIC') + axes[2].hist(jerk_heading, bins=30, alpha=0.5, color='blue', label='JERK') + axes[2].set_xlabel('Heading MAE (deg)') + axes[2].set_ylabel('Frequency') + axes[2].set_title('Heading Error Distribution') + axes[2].legend() + axes[2].grid(True, alpha=0.3, axis='y') + + plt.tight_layout() + plt.savefig(output_path, dpi=150, bbox_inches='tight') + plt.close()
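+
+
+if __name__ == '__main__':
+    # Hedged smoke test on synthetic metrics (not part of the benchmark
+    # pipeline): checks that the Agg backend writes a histogram PNG.
+    rng = np.random.default_rng(0)
+
+    def _fake_metrics(n=50):
+        # Fabricated metric dicts matching the nested keys read above
+        return [{
+            'position': {'ade': float(rng.random()), 'fde': float(rng.random())},
+            'velocity': {'speed_mae': float(rng.random()), 'speed_max': float(rng.random())},
+            'heading': {'mae': float(rng.random()), 'max': float(rng.random())},
+        } for _ in range(n)]
+
+    plot_error_histogram(_fake_metrics(), _fake_metrics(), '/tmp/error_histograms.png')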