From 8fb9c20cb4acc00391a35e43a4e564c0e37a4424 Mon Sep 17 00:00:00 2001 From: Andrew Rosemberg Date: Mon, 16 Feb 2026 15:03:56 -0500 Subject: [PATCH 1/5] working atlas eval --- examples/Atlas/build_atlas_problem.jl | 19 +- examples/Atlas/mwe_get_next_state_failure.jl | 394 +++++++++++++++++++ examples/Atlas/train_dr_atlas_det_eq.jl | 137 ++++++- examples/Atlas/visualize_atlas_policy.jl | 8 +- 4 files changed, 529 insertions(+), 29 deletions(-) create mode 100644 examples/Atlas/mwe_get_next_state_failure.jl diff --git a/examples/Atlas/build_atlas_problem.jl b/examples/Atlas/build_atlas_problem.jl index 82e8e8d..748942e 100644 --- a/examples/Atlas/build_atlas_problem.jl +++ b/examples/Atlas/build_atlas_problem.jl @@ -73,11 +73,20 @@ function build_atlas_subproblems(; # Default optimizer if isnothing(optimizer) - optimizer = () -> DiffOpt.diff_optimizer(optimizer_with_attributes(Ipopt.Optimizer, - "print_level" => 0, - "hsllib" => HSL_jll.libhsl_path, - "linear_solver" => "ma27" - )) + optimizer = () -> begin + m = DiffOpt.diff_optimizer( + optimizer_with_attributes( + Ipopt.Optimizer, + "print_level" => 0, + "hsllib" => HSL_jll.libhsl_path, + "linear_solver" => "ma27", + ), + ) + # Atlas dynamics are encoded with a VectorNonlinearOracle constraint. + # Force the nonlinear DiffOpt backend so reverse differentiation works. + MOI.set(m, DiffOpt.ModelConstructor(), DiffOpt.NonLinearProgram.Model) + return m + end end if perturbation_frequency < 1 diff --git a/examples/Atlas/mwe_get_next_state_failure.jl b/examples/Atlas/mwe_get_next_state_failure.jl new file mode 100644 index 0000000..72245da --- /dev/null +++ b/examples/Atlas/mwe_get_next_state_failure.jl @@ -0,0 +1,394 @@ +using JuMP +using DiffOpt +import Ipopt, HSL_jll +import MathOptInterface as MOI +using Statistics +using LinearAlgebra +using Random + +include(joinpath(@__DIR__, "build_atlas_problem.jl")) + +function sample_uncertainty(uncertainty_stage::Vector{Tuple{VariableRef, Vector{Float64}}}) + return [(p, vals[1]) for (p, vals) in uncertainty_stage] +end + +function set_stage_data!( + subproblem::JuMP.Model, + state_param_in::Vector{Any}, + state_param_out::Vector{Tuple{Any, VariableRef}}, + uncertainty::Vector{Tuple{VariableRef, Float64}}, + state_in::Vector{Float64}, + state_out_target::Vector{Float64}, +) + for (i, state_var) in enumerate(state_param_in) + set_parameter_value(state_var, state_in[i]) + end + for (uncertainty_param, uncertainty_value) in uncertainty + set_parameter_value(uncertainty_param, uncertainty_value) + end + for i in eachindex(state_param_out) + set_parameter_value(state_param_out[i][1], state_out_target[i]) + end + return +end + +function solve_and_get_next_state!( + subproblem::JuMP.Model, + state_param_in::Vector{Any}, + state_param_out::Vector{Tuple{Any, VariableRef}}, + uncertainty::Vector{Tuple{VariableRef, Float64}}, + state_in::Vector{Float64}, + state_out_target::Vector{Float64}, +) + set_stage_data!( + subproblem, + state_param_in, + state_param_out, + uncertainty, + state_in, + state_out_target, + ) + optimize!(subproblem) + term = termination_status(subproblem) + if !(term in (MOI.OPTIMAL, MOI.LOCALLY_SOLVED)) + error("Unexpected solver status: $term") + end + return [value(state_param_out[i][2]) for i in eachindex(state_param_out)] +end + +function scalar_output( + subproblem::JuMP.Model, + state_param_in::Vector{Any}, + state_param_out::Vector{Tuple{Any, VariableRef}}, + uncertainty::Vector{Tuple{VariableRef, Float64}}, + state_in::Vector{Float64}, + state_out_target::Vector{Float64}, + Δy::Vector{Float64}, +) + y = solve_and_get_next_state!( + subproblem, + state_param_in, + state_param_out, + uncertainty, + state_in, + state_out_target, + ) + return dot(Δy, y), y +end + +function reverse_vjp( + subproblem::JuMP.Model, + state_param_in::Vector{Any}, + state_param_out::Vector{Tuple{Any, VariableRef}}, + uncertainty::Vector{Tuple{VariableRef, Float64}}, + state_in::Vector{Float64}, + state_out_target::Vector{Float64}, + Δy::Vector{Float64}, +) + ϕ, y = scalar_output( + subproblem, + state_param_in, + state_param_out, + uncertainty, + state_in, + state_out_target, + Δy, + ) + + DiffOpt.empty_input_sensitivities!(subproblem) + for i in eachindex(state_param_out) + DiffOpt.set_reverse_variable(subproblem, state_param_out[i][2], Δy[i]) + end + DiffOpt.reverse_differentiate!(subproblem) + + g_state_in = [DiffOpt.get_reverse_parameter(subproblem, p) for p in state_param_in] + g_state_target = [DiffOpt.get_reverse_parameter(subproblem, p) for (p, _) in state_param_out] + return ϕ, y, g_state_in, g_state_target +end + +function fd_direction_state_in( + subproblem::JuMP.Model, + state_param_in::Vector{Any}, + state_param_out::Vector{Tuple{Any, VariableRef}}, + uncertainty::Vector{Tuple{VariableRef, Float64}}, + state_in::Vector{Float64}, + state_out_target::Vector{Float64}, + Δy::Vector{Float64}, + direction::Vector{Float64}; + eps::Float64 = 1e-5, +) + state_plus = state_in .+ eps .* direction + state_minus = state_in .- eps .* direction + ϕ_plus, _ = scalar_output( + subproblem, + state_param_in, + state_param_out, + uncertainty, + state_plus, + state_out_target, + Δy, + ) + ϕ_minus, _ = scalar_output( + subproblem, + state_param_in, + state_param_out, + uncertainty, + state_minus, + state_out_target, + Δy, + ) + return (ϕ_plus - ϕ_minus) / (2eps) +end + +function fd_direction_state_target( + subproblem::JuMP.Model, + state_param_in::Vector{Any}, + state_param_out::Vector{Tuple{Any, VariableRef}}, + uncertainty::Vector{Tuple{VariableRef, Float64}}, + state_in::Vector{Float64}, + state_out_target::Vector{Float64}, + Δy::Vector{Float64}, + direction::Vector{Float64}; + eps::Float64 = 1e-5, +) + target_plus = state_out_target .+ eps .* direction + target_minus = state_out_target .- eps .* direction + ϕ_plus, _ = scalar_output( + subproblem, + state_param_in, + state_param_out, + uncertainty, + state_in, + target_plus, + Δy, + ) + ϕ_minus, _ = scalar_output( + subproblem, + state_param_in, + state_param_out, + uncertainty, + state_in, + target_minus, + Δy, + ) + return (ϕ_plus - ϕ_minus) / (2eps) +end + +function fd_component_state_in( + subproblem::JuMP.Model, + state_param_in::Vector{Any}, + state_param_out::Vector{Tuple{Any, VariableRef}}, + uncertainty::Vector{Tuple{VariableRef, Float64}}, + state_in::Vector{Float64}, + state_out_target::Vector{Float64}, + Δy::Vector{Float64}, + idx::Int; + eps::Float64 = 1e-5, +) + e = zeros(length(state_in)) + e[idx] = 1.0 + return fd_direction_state_in( + subproblem, + state_param_in, + state_param_out, + uncertainty, + state_in, + state_out_target, + Δy, + e; + eps = eps, + ) +end + +function fd_component_state_target( + subproblem::JuMP.Model, + state_param_in::Vector{Any}, + state_param_out::Vector{Tuple{Any, VariableRef}}, + uncertainty::Vector{Tuple{VariableRef, Float64}}, + state_in::Vector{Float64}, + state_out_target::Vector{Float64}, + Δy::Vector{Float64}, + idx::Int; + eps::Float64 = 1e-5, +) + e = zeros(length(state_out_target)) + e[idx] = 1.0 + return fd_direction_state_target( + subproblem, + state_param_in, + state_param_out, + uncertainty, + state_in, + state_out_target, + Δy, + e; + eps = eps, + ) +end + +function spread_indices(n::Int) + return unique(sort([1, cld(n, 4), cld(n, 2), cld(3n, 4), n])) +end + +function report_subset_quality(label::String, g_rev::Vector{Float64}, g_fd::Vector{Float64}, idx::Vector{Int}) + abs_err = abs.(g_rev .- g_fd) + rel_err = abs_err ./ max.(abs.(g_fd), 1e-8) + i_max = argmax(abs_err) + println("$label coordinate-check quality:") + println(" checked_indices = $idx") + println(" max_abs_error = $(maximum(abs_err))") + println(" mean_abs_error = $(mean(abs_err))") + println(" max_rel_error = $(maximum(rel_err))") + println(" mean_rel_error = $(mean(rel_err))") + println(" worst_pair(rev,fd)= ($(g_rev[i_max]), $(g_fd[i_max])) at checked offset $i_max") +end + +function main() + println("Building Atlas with N=2 (one subproblem)...") + quick_optimizer = () -> begin + m = DiffOpt.diff_optimizer( + optimizer_with_attributes( + Ipopt.Optimizer, + "print_level" => 0, + "hsllib" => HSL_jll.libhsl_path, + "linear_solver" => "ma27", + "hessian_approximation" => "limited-memory", + "max_iter" => 50, + "tol" => 1e-3, + ), + ) + MOI.set(m, DiffOpt.ModelConstructor(), DiffOpt.NonLinearProgram.Model) + return m + end + + subproblems, state_params_in, state_params_out, initial_state, uncertainty_samples, + _, _, _, _, _ = build_atlas_subproblems(; + N = 2, + h = 0.01, + perturbation_scale = 0.5, + num_scenarios = 2, + penalty = 10.0, + perturbation_frequency = 51, + optimizer = quick_optimizer, + ) + + subproblem = subproblems[1] + state_param_in = state_params_in[1] + state_param_out = state_params_out[1] + uncertainty = sample_uncertainty(uncertainty_samples[1]) + state_in = copy(initial_state) + state_out_target = copy(initial_state) + + println("1) Solving baseline one-subproblem forward pass...") + ϕ0, y0 = scalar_output( + subproblem, + state_param_in, + state_param_out, + uncertainty, + state_in, + state_out_target, + ones(length(state_param_out)), + ) + println(" phi(ones' * y) baseline = $ϕ0") + println(" output dimension = $(length(y0))") + + # Use a sparse cotangent so the reverse pass corresponds to one output component. + Δy = zeros(length(y0)) + Δy[1] = 1.0 + + println("2) Reverse-mode VJP with DiffOpt...") + _, _, g_rev_state_in, g_rev_state_target = reverse_vjp( + subproblem, + state_param_in, + state_param_out, + uncertainty, + state_in, + state_out_target, + Δy, + ) + println(" reverse pass done") + + println("3) Finite-difference gradient checks (directional + coordinate subset)...") + eps = 1e-5 + + Random.seed!(42) + d_state_in = randn(length(g_rev_state_in)) + d_state_in ./= max(norm(d_state_in), 1e-16) + d_state_target = randn(length(g_rev_state_target)) + d_state_target ./= max(norm(d_state_target), 1e-16) + + fd_dir_state_in = fd_direction_state_in( + subproblem, + state_param_in, + state_param_out, + uncertainty, + state_in, + state_out_target, + Δy, + d_state_in; + eps = eps, + ) + fd_dir_state_target = fd_direction_state_target( + subproblem, + state_param_in, + state_param_out, + uncertainty, + state_in, + state_out_target, + Δy, + d_state_target; + eps = eps, + ) + + println("4) Gradient quality report") + rev_dir_state_in = dot(g_rev_state_in, d_state_in) + rev_dir_state_target = dot(g_rev_state_target, d_state_target) + println("state_in directional check:") + println(" reverse directional derivative = $rev_dir_state_in") + println(" finite-diff directional deriv = $fd_dir_state_in") + println(" abs_error = $(abs(rev_dir_state_in - fd_dir_state_in))") + println(" rel_error = $(abs(rev_dir_state_in - fd_dir_state_in) / max(abs(fd_dir_state_in), 1e-8))") + println("state_out_target directional check:") + println(" reverse directional derivative = $rev_dir_state_target") + println(" finite-diff directional deriv = $fd_dir_state_target") + println(" abs_error = $(abs(rev_dir_state_target - fd_dir_state_target))") + println(" rel_error = $(abs(rev_dir_state_target - fd_dir_state_target) / max(abs(fd_dir_state_target), 1e-8))") + + idx_in = spread_indices(length(g_rev_state_in)) + fd_subset_in = [ + fd_component_state_in( + subproblem, + state_param_in, + state_param_out, + uncertainty, + state_in, + state_out_target, + Δy, + i; + eps = eps, + ) + for i in idx_in + ] + rev_subset_in = [g_rev_state_in[i] for i in idx_in] + report_subset_quality("state_in", rev_subset_in, fd_subset_in, idx_in) + + idx_target = spread_indices(length(g_rev_state_target)) + fd_subset_target = [ + fd_component_state_target( + subproblem, + state_param_in, + state_param_out, + uncertainty, + state_in, + state_out_target, + Δy, + i; + eps = eps, + ) + for i in idx_target + ] + rev_subset_target = [g_rev_state_target[i] for i in idx_target] + report_subset_quality("state_out_target", rev_subset_target, fd_subset_target, idx_target) +end + +main() diff --git a/examples/Atlas/train_dr_atlas_det_eq.jl b/examples/Atlas/train_dr_atlas_det_eq.jl index 23bbc51..0bfcc6e 100644 --- a/examples/Atlas/train_dr_atlas_det_eq.jl +++ b/examples/Atlas/train_dr_atlas_det_eq.jl @@ -8,6 +8,7 @@ using Flux using DecisionRules using Random using Statistics +using LinearAlgebra using JuMP import Ipopt, HSL_jll using Wandb, Dates, Logging @@ -22,21 +23,35 @@ include(joinpath(Atlas_dir, "build_atlas_problem.jl")) # ============================================================================ # Problem parameters -N = 50 # Number of time steps +N = 10 # Number of time steps h = 0.01 # Time step -perturbation_scale = 1.5 # Scale of random perturbations +perturbation_scale = 0.5 # Scale of random perturbations num_scenarios = 10 # Number of uncertainty samples per stage penalty = 10.0 # Penalty for state deviation -perturbation_frequency = 5 # Frequency of perturbations (every k stages) +perturbation_frequency = 301 # Frequency of perturbations (every k stages) # Training parameters -num_epochs = 1 -num_batches = 100 +num_epochs = 10 +num_batches = 10 num_train_per_batch = 1 layers = Int64[64, 64] activation = sigmoid optimizers = [Flux.Adam(0.001)] +# Initial-state augmentation via short policy rollouts +# If enabled, each eligible epoch starts training from a reachable state obtained +# by rolling out the current policy for a small random number of stages. +enable_rollout_initial_state_augmentation = true +rollout_start_epoch = 1 +rollout_every_epochs = 1 +rollout_max_horizon_fraction = 10.0 + +if enable_rollout_initial_state_augmentation + rollout_every_epochs < 1 && error("rollout_every_epochs must be >= 1") + (rollout_max_horizon_fraction <= 0) && + error("rollout_max_horizon_fraction must be in (0, 1]") +end + # Save paths model_dir = joinpath(Atlas_dir, "models") mkpath(model_dir) @@ -46,7 +61,7 @@ save_file = "atlas-balancing-deteq-N$(N)-$(now())" # - set to "latest" to use the most recent deterministic-equivalent model in `model_dir` # - set to a run name (with or without `.jld2`) or a full path # e.g. "atlas-balancing-deteq-N50-2026-02-02T21:16:37.554" -warmstart_model = "atlas-balancing-deteq-N50-2026-02-08T14:25:13.534" +warmstart_model = "atlas-balancing-deteq-N5-2026-02-14T23:37:16.303" # CLI override: # julia --project=. examples/Atlas/train_dr_atlas_det_eq.jl if !isempty(ARGS) @@ -59,8 +74,9 @@ end println("Building Atlas deterministic equivalent problem...") -# First build subproblems to get the structure -@time subproblems, state_params_in_sub, state_params_out_sub, initial_state, uncertainty_samples, +# Build one subproblem set dedicated to deterministic-equivalent construction. +# `deterministic_equivalent!` mutates this data. +@time subproblems_det, state_params_in_det, state_params_out_det, initial_state, uncertainty_samples_det_builder, _, _, x_ref, u_ref, atlas = build_atlas_subproblems(; N = N, h = h, @@ -70,6 +86,17 @@ println("Building Atlas deterministic equivalent problem...") perturbation_frequency = perturbation_frequency, ) +# Build a second, independent subproblem set used only for rollout-based initial-state generation. +@time rollout_subproblems, rollout_state_params_in, rollout_state_params_out, rollout_initial_state, rollout_uncertainty_samples, + _, _, _, _, _ = build_atlas_subproblems(; + N = N, + h = h, + perturbation_scale = perturbation_scale, + num_scenarios = num_scenarios, + penalty = penalty, + perturbation_frequency = perturbation_frequency, +) + # Build deterministic equivalent det_equivalent = DiffOpt.nonlinear_diff_model(optimizer_with_attributes(Ipopt.Optimizer, "print_level" => 0, @@ -81,16 +108,16 @@ det_equivalent = DiffOpt.nonlinear_diff_model(optimizer_with_attributes(Ipopt.Op # Convert subproblems to deterministic equivalent using DecisionRules det_equivalent, uncertainty_samples_det = DecisionRules.deterministic_equivalent!( det_equivalent, - subproblems, - state_params_in_sub, - state_params_out_sub, + subproblems_det, + state_params_in_det, + state_params_out_det, initial_state, - uncertainty_samples + uncertainty_samples_det_builder ) nx = atlas.nx nu = atlas.nu -n_perturb = length(uncertainty_samples[1]) # Number of perturbation parameters +n_perturb = length(rollout_uncertainty_samples[1]) # Number of perturbation parameters println("Atlas state dimension: $nx") println("Atlas control dimension: $nu") @@ -117,6 +144,10 @@ lg = WandbLogger( "nu" => nu, "formulation" => "deterministic_equivalent", "warmstart_model" => isnothing(warmstart_model) ? "none" : string(warmstart_model), + "enable_rollout_initial_state_augmentation" => enable_rollout_initial_state_augmentation, + "rollout_start_epoch" => rollout_start_epoch, + "rollout_every_epochs" => rollout_every_epochs, + "rollout_max_horizon_fraction" => rollout_max_horizon_fraction, ) ) @@ -131,7 +162,7 @@ end # Policy architecture: LSTM processes perturbations, Dense combines with previous state # This design is memory-efficient and allows the LSTM to focus on temporal patterns -n_uncertainties = length(uncertainty_samples[1]) +n_uncertainties = length(rollout_uncertainty_samples[1]) models = state_conditioned_policy(n_uncertainties, nx, nx, layers; activation=activation, encoder_type=Flux.LSTM) @@ -187,7 +218,7 @@ println(" Combiner (Dense): $(layers[end]) + $nx -> $nx") println("\nEvaluating initial policy...") Random.seed!(8788) objective_values = [simulate_multistage( - det_equivalent, state_params_in_sub, state_params_out_sub, + det_equivalent, state_params_in_det, state_params_out_det, initial_state, DecisionRules.sample(uncertainty_samples_det), models; ) for _ in 1:2] @@ -195,7 +226,7 @@ best_obj = mean(objective_values) println("Initial objective: $best_obj") # for testing visualization. fill x with visited states -# X[2:end] = [value.([var[2] for var in stage]) for stage in state_params_out_sub] +# X[2:end] = [value.([var[2] for var in stage]) for stage in state_params_out_det] # calculate distance from reference # dist = sum((X[t][i] - x_ref[i])^2 for i in 1:length(x_ref) for t in 1:length(X)) @@ -213,6 +244,44 @@ adjust_hyperparameters = (iter, opt_state, num_train_per_batch) -> begin return num_train_per_batch end +function rollout_reachable_initial_state( + model, + nominal_initial_state::Vector{Float64}, + subproblems::Vector{JuMP.Model}, + state_params_in, + state_params_out, + uncertainty_samples, + rollout_steps::Int, +) + uncertainty_sample = DecisionRules.sample(uncertainty_samples) + max_stages = min(rollout_steps, length(subproblems), length(uncertainty_sample)) + state_in = copy(nominal_initial_state) + Flux.reset!(model) + + for stage in 1:max_stages + uncertainty_stage = uncertainty_sample[stage] + uncertainty_vec = [u[2] for u in uncertainty_stage] + state_out_target = Float64.(model(vcat(uncertainty_vec, state_in))) + DecisionRules.simulate_stage( + subproblems[stage], + state_params_in[stage], + state_params_out[stage], + uncertainty_stage, + state_in, + state_out_target, + ) + state_in = Float64.(DecisionRules.get_next_state( + subproblems[stage], + state_params_in[stage], + state_params_out[stage], + state_in, + state_out_target, + )) + end + + return state_in +end + # ============================================================================ # Training # ============================================================================ @@ -223,13 +292,41 @@ println("Epochs: $num_epochs, Batches per epoch: $num_batches") for epoch in 1:num_epochs println("\n=== Epoch $epoch ===") _num_train_per_batch = num_train_per_batch + epoch_initial_state = copy(initial_state) + + if enable_rollout_initial_state_augmentation && + epoch >= rollout_start_epoch && + ((epoch - rollout_start_epoch) % rollout_every_epochs == 0) + max_rollout_steps = max(1, floor(Int, rollout_max_horizon_fraction * (N - 1))) + rollout_steps = rand(1:max_rollout_steps) + try + epoch_initial_state = rollout_reachable_initial_state( + models, + rollout_initial_state, + rollout_subproblems, + rollout_state_params_in, + rollout_state_params_out, + rollout_uncertainty_samples, + rollout_steps, + ) + rollout_state_shift = norm(epoch_initial_state .- initial_state) + println(" Rollout augmentation active: steps=$rollout_steps, ||x0_epoch - x0_nominal||₂=$(round(rollout_state_shift, digits=4))") + Wandb.log(lg, Dict( + "metrics/rollout_steps" => rollout_steps, + "metrics/rollout_state_shift_l2" => rollout_state_shift, + )) + catch err + @warn "Rollout initial-state update failed; falling back to nominal initial_state." exception=(err, catch_backtrace()) + epoch_initial_state = copy(initial_state) + end + end train_multistage( models, - initial_state, + epoch_initial_state, det_equivalent, - state_params_in_sub, - state_params_out_sub, + state_params_in_det, + state_params_out_det, uncertainty_samples_det; num_batches = num_batches, num_train_per_batch = _num_train_per_batch, @@ -251,7 +348,7 @@ end println("\n=== Final Evaluation ===") Random.seed!(8788) objective_values = [simulate_multistage( - det_equivalent, state_params_in_sub, state_params_out_sub, + det_equivalent, state_params_in_det, state_params_out_det, initial_state, DecisionRules.sample(uncertainty_samples_det), models; ) for _ in 1:10] diff --git a/examples/Atlas/visualize_atlas_policy.jl b/examples/Atlas/visualize_atlas_policy.jl index 7c5532b..ccdb41c 100644 --- a/examples/Atlas/visualize_atlas_policy.jl +++ b/examples/Atlas/visualize_atlas_policy.jl @@ -23,14 +23,14 @@ include(joinpath(Atlas_dir, "atlas_visualization.jl")) # ============================================================================ # Model to load (modify this path to your trained model) -model_path = "./models/atlas-balancing-deteq-N50-2026-02-09T18:53:00.104.jld2" # Set to path of trained model, or nothing to use latest +model_path = "./models/atlas-balancing-deteq-N10-2026-02-15T19:49:47.739.jld2" # Set to path of trained model, or nothing to use latest # Problem parameters (should match training) -N = 50 # Number of time steps +N = 300 # Number of time steps h = 0.01 # Time step -perturbation_scale = 1.5 # Scale of random perturbations +perturbation_scale = 0.5 # Scale of random perturbations num_scenarios = 1 # Number of scenarios to simulate -perturbation_frequency = 5 # Frequency of perturbations (every k stages) +perturbation_frequency = 1000 # Frequency of perturbations (every k stages) # Visualization options animate_robot = true # Whether to animate in MeshCat From d966c24cdfcc863c1a4fc86c32f73670ad39235d Mon Sep 17 00:00:00 2001 From: Andrew Rosemberg Date: Mon, 16 Feb 2026 16:19:00 -0500 Subject: [PATCH 2/5] showing pertubation --- Project.toml | 2 +- examples/Atlas/Project.toml | 3 + examples/Atlas/atlas_visualization.jl | 187 +++++++++++++++++++++++ examples/Atlas/train_dr_atlas_det_eq.jl | 4 +- examples/Atlas/visualize_atlas_policy.jl | 35 ++++- 5 files changed, 224 insertions(+), 7 deletions(-) diff --git a/Project.toml b/Project.toml index 0fb43b8..e843595 100644 --- a/Project.toml +++ b/Project.toml @@ -29,7 +29,7 @@ JuMP = "1.29.4" MathOptInterface = "1.48.0" ParametricOptInterface = "0.14.1" Zygote = "0.6.77" -julia = "~1.9, ~1.10" +julia = "~1.9, 1.10" [extras] CUDA_Runtime_jll = "76a88914-d11a-5bdc-97e0-2f5a05c973a2" diff --git a/examples/Atlas/Project.toml b/examples/Atlas/Project.toml index 610b8a4..a1e0bd5 100644 --- a/examples/Atlas/Project.toml +++ b/examples/Atlas/Project.toml @@ -1,8 +1,11 @@ [deps] +Colors = "5ae59095-9a9b-59fe-a467-6f913c188581" +CoordinateTransformations = "150eb455-5306-5404-9cee-2592286d6298" DecisionRules = "47937410-f832-486f-8300-12c95b225dfc" DiffOpt = "930fe3bc-9c6b-11ea-2d94-6184641e85e7" Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" +GeometryBasics = "5c1252a2-5f33-56bf-86c9-59e7332b4326" HSL_jll = "017b0a0e-03f4-516a-9b91-836bbd1904dd" Ipopt = "b6b21f68-93f8-5de0-b562-5493be1d77c9" JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819" diff --git a/examples/Atlas/atlas_visualization.jl b/examples/Atlas/atlas_visualization.jl index e216c57..ef4022f 100644 --- a/examples/Atlas/atlas_visualization.jl +++ b/examples/Atlas/atlas_visualization.jl @@ -4,6 +4,11 @@ using MeshCat using MeshCatMechanisms +using GeometryBasics: Point, Vec, HyperSphere +using RigidBodyDynamics: findbody, findjoint, successor, frame_after, default_frame, MechanismState, relative_transform, translation +using CoordinateTransformations: Translation +using Colors +using LinearAlgebra: norm, cross const URDFPATH = joinpath(@__DIR__, "urdf", "atlas_all.urdf") @@ -29,3 +34,185 @@ function animate!(model::Atlas, mvis::MechanismVisualizer, qs; Δt=0.001) return anim end + +""" + animate_with_perturbation_cause!(model, mvis, qs, perturbations; kwargs...) + +Animate Atlas and overlay an illustrative perturbation-cause arrow in MeshCat. +The arrow points in perturbation direction and its length scales with magnitude. + +Arguments: +- `qs`: state trajectory (length `T`) +- `perturbations`: per-stage perturbations (typically length `T-1`) + +Keyword arguments: +- `Δt`: frame step in seconds +- `arrow_scale`: converts perturbation magnitude to arrow length +- `min_arrow_length`: minimum visible arrow length when active +- `show_threshold`: hide arrow when `abs(perturbation) <= show_threshold` +- `linger_seconds`: how long the cause arrow remains visible after each perturbation +- `perturbation_state_index`: state index (in `x`) where perturbation is injected; + if it maps to a velocity state, the arrow is attached to that joint/body. +- `arrow_base`: local-frame base offset from the selected anchor frame +- `impact_distance`: local x-distance from anchor to initial contact point (keeps marker outside robot) +- `retreat_distance`: local x-distance the marker retreats after impact +- `shaft_radius`: thickness of the arrow shaft +""" +function animate_with_perturbation_cause!( + model::Atlas, + mvis::MechanismVisualizer, + qs, + perturbations; + Δt=0.001, + arrow_scale=1.0, + min_arrow_length=0.12, + show_threshold=1e-6, + linger_seconds=0.35, + perturbation_state_index=nothing, + arrow_base=Point(0.0, 0.0, 0.12), + impact_distance=0.18, + retreat_distance=0.35, + shaft_radius=0.03, +) + vis = mvis.visualizer + if isnothing(perturbation_state_index) + perturbation_state_index = model.nq + 5 + end + + velocity_idx = perturbation_state_index - model.nq + anchor_body = nothing + anchor_origin = Point(0.0, 0.0, 0.0) + perturbation_dir_local = Vec(0.0, 1.0, 0.0) + anchor_description = "" + + if 1 <= velocity_idx <= length(model.joint_names) + joint_name = model.joint_names[velocity_idx] + joint = findjoint(model.mech, joint_name) + anchor_body = successor(joint, model.mech) + state0 = MechanismState(model.mech) + joint_in_body = translation(relative_transform(state0, default_frame(anchor_body), frame_after(joint))) + anchor_origin = Point(joint_in_body[1], joint_in_body[2], joint_in_body[3]) + joint_type = getfield(joint, :joint_type) + if hasfield(typeof(joint_type), :axis) + axis = collect(getfield(joint_type, :axis)) + axis_norm = norm(axis) + if axis_norm > 1e-8 + axis ./= axis_norm + # Build a direction orthogonal to the joint axis so the effect reads as a lateral collision. + dir = cross(axis, [0.0, 0.0, 1.0]) + if norm(dir) < 1e-8 + dir = cross(axis, [1.0, 0.0, 0.0]) + end + if norm(dir) > 1e-8 + dir ./= norm(dir) + perturbation_dir_local = Vec(dir[1], dir[2], dir[3]) + end + end + end + anchor_description = + "joint=$(joint_name), body=$(getfield(anchor_body, :name)), dir_local=$(collect(perturbation_dir_local))" + else + anchor_body = findbody(model.mech, "pelvis") + anchor_description = + "fallback=pelvis (state index $perturbation_state_index not mapped to velocity DOF), dir_local=$(collect(perturbation_dir_local))" + end + + cause_arrow_parent = mvis[anchor_body] + cause_arrow = ArrowVisualizer(cause_arrow_parent[:perturbation_cause_arrow]) + setobject!( + cause_arrow; + shaft_material=MeshLambertMaterial(color=colorant"red"), + head_material=MeshLambertMaterial(color=colorant"yellow"), + ) + cause_impactor = cause_arrow_parent[:perturbation_cause_impactor] + setobject!( + cause_impactor, + HyperSphere(Point(0.0, 0.0, 0.0), 0.055), + MeshLambertMaterial(color=colorant"orange"), + ) + linger_frames = max(1, round(Int, linger_seconds / Δt)) + head_radius = 2.2 * shaft_radius + head_length = 2.8 * shaft_radius + + anim = MeshCat.Animation(vis; fps=convert(Int, floor(1.0 / Δt))) + last_event_frame = 0 + last_event_value = 0.0 + last_event_sign = 1.0 + event_count = count(p -> abs(p) > show_threshold, perturbations) + max_abs_pert = isempty(perturbations) ? 0.0 : maximum(abs.(perturbations)) + println( + "Perturbation-cause overlay: events=$event_count, max_abs=$(round(max_abs_pert, digits=6)), " * + "perturb_state_idx=$perturbation_state_index, anchor={$anchor_description}, " * + "impact_distance=$impact_distance, retreat_distance=$retreat_distance" + ) + for (frame, q) in enumerate(qs) + MeshCat.atframe(anim, frame) do + set_configuration!(mvis, q[1:model.nq]) + + p = frame <= length(perturbations) ? perturbations[frame] : 0.0 + if abs(p) > show_threshold + last_event_frame = frame + last_event_value = p + last_event_sign = sign(p) == 0 ? 1.0 : sign(p) + end + + frames_since_event = frame - last_event_frame + if last_event_frame > 0 && frames_since_event <= linger_frames + progress = frames_since_event / linger_frames + decay = 1.0 - progress + outward_dir = Vec( + last_event_sign * perturbation_dir_local[1], + last_event_sign * perturbation_dir_local[2], + last_event_sign * perturbation_dir_local[3], + ) + + # Contact happens just outside the body, then marker backs away. + contact_point = Point( + anchor_origin[1] + arrow_base[1] + outward_dir[1] * impact_distance, + anchor_origin[2] + arrow_base[2] + outward_dir[2] * impact_distance, + anchor_origin[3] + arrow_base[3] + outward_dir[3] * impact_distance, + ) + impactor_point = Point( + contact_point[1] + outward_dir[1] * retreat_distance * progress, + contact_point[2] + outward_dir[2] * retreat_distance * progress, + contact_point[3] + outward_dir[3] * retreat_distance * progress, + ) + settransform!( + cause_impactor, + Translation(impactor_point[1], impactor_point[2], impactor_point[3]), + ) + + effective_p = last_event_value * decay + arrow_length = max(min_arrow_length * decay, abs(effective_p) * arrow_scale) + # Arrow points from impactor toward robot (collision cause direction). + direction = Vec( + -outward_dir[1] * arrow_length, + -outward_dir[2] * arrow_length, + -outward_dir[3] * arrow_length, + ) + settransform!( + cause_arrow, + impactor_point, + direction; + shaft_radius=shaft_radius, + max_head_radius=head_radius, + max_head_length=head_length, + ) + else + # "Hide" by shrinking to zero length (more robust than animating visibility). + settransform!( + cause_arrow, + anchor_origin, + Vec(0.0, 0.0, 0.0); + shaft_radius=shaft_radius, + max_head_radius=head_radius, + max_head_length=head_length, + ) + # Keep impactor out of view when there is no active perturbation event. + settransform!(cause_impactor, Translation(1000.0, 1000.0, 1000.0)) + end + end + end + MeshCat.setanimation!(mvis, anim) + return anim +end diff --git a/examples/Atlas/train_dr_atlas_det_eq.jl b/examples/Atlas/train_dr_atlas_det_eq.jl index 0bfcc6e..a640b5d 100644 --- a/examples/Atlas/train_dr_atlas_det_eq.jl +++ b/examples/Atlas/train_dr_atlas_det_eq.jl @@ -28,7 +28,7 @@ h = 0.01 # Time step perturbation_scale = 0.5 # Scale of random perturbations num_scenarios = 10 # Number of uncertainty samples per stage penalty = 10.0 # Penalty for state deviation -perturbation_frequency = 301 # Frequency of perturbations (every k stages) +perturbation_frequency = 5 # Frequency of perturbations (every k stages) # Training parameters num_epochs = 10 @@ -61,7 +61,7 @@ save_file = "atlas-balancing-deteq-N$(N)-$(now())" # - set to "latest" to use the most recent deterministic-equivalent model in `model_dir` # - set to a run name (with or without `.jld2`) or a full path # e.g. "atlas-balancing-deteq-N50-2026-02-02T21:16:37.554" -warmstart_model = "atlas-balancing-deteq-N5-2026-02-14T23:37:16.303" +warmstart_model = "atlas-balancing-deteq-N10-2026-02-15T19:49:47.739" # CLI override: # julia --project=. examples/Atlas/train_dr_atlas_det_eq.jl if !isempty(ARGS) diff --git a/examples/Atlas/visualize_atlas_policy.jl b/examples/Atlas/visualize_atlas_policy.jl index ccdb41c..859a0bb 100644 --- a/examples/Atlas/visualize_atlas_policy.jl +++ b/examples/Atlas/visualize_atlas_policy.jl @@ -35,6 +35,14 @@ perturbation_frequency = 1000 # Frequency of perturbations (every k stages) # Visualization options animate_robot = true # Whether to animate in MeshCat save_plots = true # Whether to save plots to file +show_perturbation_cause_in_meshcat = true +meshcat_cause_arrow_scale = 2.0 +meshcat_cause_show_threshold = 1e-6 +meshcat_cause_linger_seconds = 2.0 +meshcat_cause_min_arrow_length = 0.40 +meshcat_cause_shaft_radius = 0.08 +meshcat_cause_impact_distance = 0.18 +meshcat_cause_retreat_distance = 0.35 # ============================================================================ # Load Model @@ -74,6 +82,7 @@ atlas = Atlas() nx = atlas.nx nu = atlas.nu +perturbation_idx = atlas.nq + 5 println("State dimension: $nx") println("Control dimension: $nu") @@ -87,6 +96,7 @@ println("Control dimension: $nu") N = N, h = h, perturbation_scale = perturbation_scale, + perturbation_indices = [perturbation_idx], num_scenarios = num_scenarios, perturbation_frequency = perturbation_frequency, ) @@ -301,8 +311,27 @@ if animate_robot # Convert to format expected by animate! X_animate = all_states[best_scenario] - - animate!(atlas, mvis, X_animate, Δt=h) + + if show_perturbation_cause_in_meshcat + animate_with_perturbation_cause!( + atlas, + mvis, + X_animate, + all_perturbations[best_scenario]; + Δt = h, + arrow_scale = meshcat_cause_arrow_scale, + show_threshold = meshcat_cause_show_threshold, + linger_seconds = meshcat_cause_linger_seconds, + min_arrow_length = meshcat_cause_min_arrow_length, + shaft_radius = meshcat_cause_shaft_radius, + perturbation_state_index = perturbation_idx, + impact_distance = meshcat_cause_impact_distance, + retreat_distance = meshcat_cause_retreat_distance, + ) + println("MeshCat overlay: collision-style perturbation cause enabled (impactor appears at contact and retreats over linger window).") + else + animate!(atlas, mvis, X_animate, Δt=h) + end println("\nAnimation ready! Open MeshCat visualizer to view.") println("Best scenario objective: $(all_objectives[best_scenario])") @@ -321,8 +350,6 @@ println("="^60) openloop_objectives = Float64[] openloop_final_deviations = Float64[] -perturbation_idx = atlas.nq + 5 - for s in 1:num_scenarios Random.seed!(s * 100 + 42) From e3dab90a1313df1f856005996729963a7ec4dddf Mon Sep 17 00:00:00 2001 From: Andrew Rosemberg Date: Mon, 16 Feb 2026 18:15:13 -0500 Subject: [PATCH 3/5] update --- examples/Atlas/Project.toml | 1 + examples/Atlas/visualize_atlas_policy.jl | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/examples/Atlas/Project.toml b/examples/Atlas/Project.toml index a1e0bd5..6405884 100644 --- a/examples/Atlas/Project.toml +++ b/examples/Atlas/Project.toml @@ -20,4 +20,5 @@ RigidBodyDynamics = "366cf18f-59d5-5db9-a4de-86a9f6786172" Rotations = "6038ab10-8711-5258-84ad-4b1120ba62dc" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" +Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" Wandb = "ad70616a-06c9-5745-b1f1-6a5f42545108" diff --git a/examples/Atlas/visualize_atlas_policy.jl b/examples/Atlas/visualize_atlas_policy.jl index 859a0bb..07cba0d 100644 --- a/examples/Atlas/visualize_atlas_policy.jl +++ b/examples/Atlas/visualize_atlas_policy.jl @@ -30,7 +30,7 @@ N = 300 # Number of time steps h = 0.01 # Time step perturbation_scale = 0.5 # Scale of random perturbations num_scenarios = 1 # Number of scenarios to simulate -perturbation_frequency = 1000 # Frequency of perturbations (every k stages) +perturbation_frequency = 50 # Frequency of perturbations (every k stages) # Visualization options animate_robot = true # Whether to animate in MeshCat From 9696086503a883b8bb74e2c7bd126b75cb06a630 Mon Sep 17 00:00:00 2001 From: Andrew Rosemberg Date: Mon, 16 Feb 2026 21:36:36 -0500 Subject: [PATCH 4/5] add first try transient plots --- .../Atlas/evaluate_transient_atlas_policy.jl | 224 +++++++++++++++++ examples/Atlas/plot_transient_eval_results.jl | 237 ++++++++++++++++++ 2 files changed, 461 insertions(+) create mode 100644 examples/Atlas/evaluate_transient_atlas_policy.jl create mode 100644 examples/Atlas/plot_transient_eval_results.jl diff --git a/examples/Atlas/evaluate_transient_atlas_policy.jl b/examples/Atlas/evaluate_transient_atlas_policy.jl new file mode 100644 index 0000000..ae254e2 --- /dev/null +++ b/examples/Atlas/evaluate_transient_atlas_policy.jl @@ -0,0 +1,224 @@ +# Evaluate a trained Atlas policy under single-shot transient perturbations. +# +# Designed for SLURM array execution: +# - each task runs one rollout of length N +# - perturbation is applied only at stage 1 +# - perturbation sign/magnitude are mapped from array task id +# +# Results are saved to: +# transient_eval_results//sim_XXX__mag_.jld2 +# +# Usage examples: +# julia --project=. evaluate_transient_atlas_policy.jl +# ATLAS_POLICY_PATH=./models/atlas-balancing-deteq-...jld2 \ +# SLURM_ARRAY_TASK_ID=4 julia --project=. evaluate_transient_atlas_policy.jl + +using Flux +using DecisionRules +using LinearAlgebra +using JuMP +import Ipopt, HSL_jll +using JLD2 +using Dates +using Printf + +Atlas_dir = dirname(@__FILE__) +include(joinpath(Atlas_dir, "build_atlas_problem.jl")) + +function parse_env(T::Type, key::String, default) + if haskey(ENV, key) && !isempty(ENV[key]) + return parse(T, ENV[key]) + end + return default +end + +function resolve_policy_path(policy_hint::Union{Nothing, String}, model_dir::String) + if !isnothing(policy_hint) && !isempty(policy_hint) + candidates = String[policy_hint] + push!(candidates, joinpath(model_dir, policy_hint)) + if !endswith(policy_hint, ".jld2") + push!(candidates, policy_hint * ".jld2") + push!(candidates, joinpath(model_dir, policy_hint * ".jld2")) + end + for path in candidates + isfile(path) && return path + end + error("Could not resolve ATLAS_POLICY_PATH: $policy_hint") + end + + model_files = filter(f -> endswith(f, ".jld2") && startswith(f, "atlas-balancing"), readdir(model_dir)) + isempty(model_files) && error("No .jld2 models found in $model_dir and no ATLAS_POLICY_PATH provided.") + model_files_full = [joinpath(model_dir, f) for f in model_files] + return model_files_full[argmax([mtime(f) for f in model_files_full])] +end + +function perturbation_from_task(task_id::Int, n_levels::Int, min_mag::Float64, max_mag::Float64) + n_levels < 1 && error("ATLAS_TRANSIENT_N_LEVELS must be >= 1.") + (min_mag <= 0 || max_mag < min_mag) && + error("Expected 0 < min_mag <= max_mag, got min_mag=$min_mag, max_mag=$max_mag.") + + magnitudes = collect(range(min_mag, max_mag, length=n_levels)) + n_tasks = 2 * n_levels + (task_id < 1 || task_id > n_tasks) && + error("Task id $task_id is out of range 1:$n_tasks for n_levels=$n_levels.") + + mag_idx = fld(task_id - 1, 2) + 1 + sign = isodd(task_id) ? -1.0 : 1.0 + mag = magnitudes[mag_idx] + return sign * mag, mag, sign, n_tasks +end + +function build_fixed_perturbation_sample(uncertainty_samples, perturbation_value::Float64) + # Start from DecisionRules.sample so tuple types match simulate_multistage expectations. + perturbation_sample = DecisionRules.sample(uncertainty_samples) + for t in eachindex(perturbation_sample) + for i in eachindex(perturbation_sample[t]) + var = perturbation_sample[t][i][1] + perturbation_sample[t][i] = (var, 0.0) + end + end + + if !isempty(perturbation_sample) && !isempty(perturbation_sample[1]) + first_var = perturbation_sample[1][1][1] + perturbation_sample[1][1] = (first_var, perturbation_value) + end + + return perturbation_sample +end + +function main() + model_dir = joinpath(Atlas_dir, "models") + output_root = joinpath(Atlas_dir, "transient_eval_results") + mkpath(output_root) + + # CLI arg takes precedence; otherwise use env var. + policy_hint = !isempty(ARGS) ? ARGS[1] : get(ENV, "ATLAS_POLICY_PATH", "") + policy_path = resolve_policy_path(policy_hint, model_dir) + policy_name = splitext(basename(policy_path))[1] + result_dir = joinpath(output_root, policy_name) + mkpath(result_dir) + + N = parse_env(Int, "ATLAS_TRANSIENT_HORIZON", 300) + h = parse_env(Float64, "ATLAS_TRANSIENT_TIMESTEP", 0.01) + n_levels = parse_env(Int, "ATLAS_TRANSIENT_N_LEVELS", 10) + max_mag = parse_env(Float64, "ATLAS_TRANSIENT_MAX_MAG", 1.0) + min_mag_default = max_mag / max(n_levels, 1) + min_mag = parse_env(Float64, "ATLAS_TRANSIENT_MIN_MAG", min_mag_default) + task_id = parse_env(Int, "SLURM_ARRAY_TASK_ID", parse_env(Int, "TASK_ID", 1)) + + atlas = Atlas() + perturbation_idx_default = atlas.nq + 5 + perturbation_idx = parse_env(Int, "ATLAS_TRANSIENT_PERTURBATION_INDEX", perturbation_idx_default) + (perturbation_idx < 1 || perturbation_idx > atlas.nx) && + error("ATLAS_TRANSIENT_PERTURBATION_INDEX=$perturbation_idx is outside 1:$(atlas.nx).") + + perturbation_value, perturbation_mag, perturbation_sign, n_tasks = + perturbation_from_task(task_id, n_levels, min_mag, max_mag) + + println("Transient eval task:") + println(" policy: $policy_name") + println(" task: $task_id / $n_tasks") + println(" perturbation value: $perturbation_value") + println(" perturbation index: $perturbation_idx") + println(" horizon: N=$N, h=$h") + + subproblems, state_params_in, state_params_out, initial_state, uncertainty_samples, + X_vars, U_vars, x_ref, u_ref, _ = build_atlas_subproblems(; + atlas = atlas, + N = N, + h = h, + perturbation_scale = 0.0, + perturbation_frequency = N, # stage 1 only for N-step rollout + perturbation_indices = [perturbation_idx], + num_scenarios = 1, + ) + + nx = atlas.nx + nu = atlas.nu + n_uncertainties = length(uncertainty_samples[1]) + layers = Int64[64, 64] + activation = sigmoid + models = state_conditioned_policy(n_uncertainties, nx, nx, layers; + activation = activation, encoder_type = Flux.LSTM) + + model_data = JLD2.load(policy_path) + haskey(model_data, "model_state") || error("Model file does not contain `model_state`: $policy_path") + Flux.loadmodel!(models, normalize_recur_state(model_data["model_state"])) + Flux.reset!(models) + + perturbation_sample = build_fixed_perturbation_sample(uncertainty_samples, perturbation_value) + perturbation_series = zeros(Float64, N - 1) + perturbation_series[1] = perturbation_value + + status = "success" + error_message = "" + objective_value = NaN + states = zeros(Float64, nx, N) + actions = zeros(Float64, nu, N - 1) + rollout_state_shift_l2 = fill(NaN, N) + state_change_l2 = fill(NaN, N - 1) + time = collect(0:N-1) .* h + + try + objective_value = simulate_multistage( + subproblems, + state_params_in, + state_params_out, + initial_state, + perturbation_sample, + models, + ) + + states[:, 1] .= initial_state + for t in 1:N-1 + states[:, t + 1] .= value.(X_vars[t]) + actions[:, t] .= value.(U_vars[t]) + end + rollout_state_shift_l2 .= [norm(states[:, t] .- x_ref) for t in 1:N] + state_change_l2 .= [norm(states[:, t + 1] .- states[:, t]) for t in 1:N-1] + catch err + status = "failure" + error_message = sprint(showerror, err, catch_backtrace()) + @warn "Transient evaluation failed." exception=(err, catch_backtrace()) + end + + mag_token = replace(@sprintf("%.4f", perturbation_mag), "." => "p") + sign_token = perturbation_sign > 0 ? "pos" : "neg" + result_file = joinpath( + result_dir, + @sprintf("sim_%03d_%s_mag_%s.jld2", task_id, sign_token, mag_token), + ) + + jldsave( + result_file; + policy_name, + policy_path, + status, + error_message, + task_id, + n_tasks, + N, + h, + perturbation_idx, + perturbation_value, + perturbation_mag, + perturbation_sign, + objective_value, + states, + actions, + rollout_state_shift_l2, + state_change_l2, + perturbation_series, + time, + x_ref, + u_ref, + ) + + println("Saved transient result: $result_file") + println("Status: $status") + if status != "success" + error(error_message) + end +end + +main() diff --git a/examples/Atlas/plot_transient_eval_results.jl b/examples/Atlas/plot_transient_eval_results.jl new file mode 100644 index 0000000..14a9368 --- /dev/null +++ b/examples/Atlas/plot_transient_eval_results.jl @@ -0,0 +1,237 @@ +# Plot transient evaluation study for Atlas policy rollouts. +# +# Expected input directory structure: +# transient_eval_results//sim_*.jld2 +# +# Produces: +# 1) perturbation vs time-to-equilibrium scatter +# 2) rollout state-shift L2 vs time for all simulations, color-coded by perturbation +# +# Usage: +# julia --project=. plot_transient_eval_results.jl +# julia --project=. plot_transient_eval_results.jl + +using JLD2 +using Statistics +using LinearAlgebra +using Plots +using Printf + +Atlas_dir = dirname(@__FILE__) +results_root = joinpath(Atlas_dir, "transient_eval_results") + +function parse_env(T::Type, key::String, default) + if haskey(ENV, key) && !isempty(ENV[key]) + return parse(T, ENV[key]) + end + return default +end + +function resolve_policy_name(policy_hint::Union{Nothing, String}) + if !isnothing(policy_hint) && !isempty(policy_hint) + # If a full model path is given, use stem as policy folder name. + if endswith(policy_hint, ".jld2") + return splitext(basename(policy_hint))[1] + end + return basename(policy_hint) + end + + # Fallback to most recently modified policy result folder. + isdir(results_root) || error("No transient_eval_results directory found at $results_root") + dirs = filter(d -> isdir(joinpath(results_root, d)), readdir(results_root)) + isempty(dirs) && error("No policy subfolders found in $results_root") + full = [joinpath(results_root, d) for d in dirs] + return dirs[argmax([mtime(d) for d in full])] +end + +function time_to_equilibrium(state_change_l2::Vector{Float64}, h::Float64; stable_steps::Int=50, delta_tol::Float64=1e-3) + if length(state_change_l2) < stable_steps + return NaN, nothing + end + for start_step in 1:(length(state_change_l2) - stable_steps + 1) + if all(@view(state_change_l2[start_step:start_step + stable_steps - 1]) .<= delta_tol) + # Declare equilibrium at the end of the stable window. + eq_step = start_step + stable_steps + return (eq_step - 1) * h, eq_step + end + end + return NaN, nothing +end + +function color_for_perturbation(value::Float64, max_abs::Float64, gradient) + alpha = max_abs > 0 ? clamp((value + max_abs) / (2 * max_abs), 0.0, 1.0) : 0.5 + return get(gradient, alpha) +end + +function main() + policy_hint = !isempty(ARGS) ? ARGS[1] : get(ENV, "ATLAS_POLICY_PATH", "") + policy_name = resolve_policy_name(policy_hint) + result_dir = joinpath(results_root, policy_name) + isdir(result_dir) || error("Result directory does not exist: $result_dir") + + sim_files = sort(filter( + f -> endswith(f, ".jld2") && startswith(basename(f), "sim_"), + readdir(result_dir; join=true), + )) + isempty(sim_files) && error("No sim_*.jld2 files found in $result_dir") + + stable_steps = parse_env(Int, "ATLAS_EQUILIBRIUM_STEPS", 50) + delta_tol = parse_env(Float64, "ATLAS_EQUILIBRIUM_DELTA_TOL", 1e-3) + + perturbations = Float64[] + equilibrium_times = Float64[] + equilibrium_steps = Union{Nothing, Int}[] + rollout_series = Vector{Vector{Float64}}() + time_series = Vector{Vector{Float64}}() + task_ids = Int[] + failed_files = String[] + h_ref = NaN + N_ref = 0 + + for file in sim_files + data = JLD2.load(file) + status = get(data, "status", "success") + if status != "success" + push!(failed_files, file) + continue + end + + perturbation_value = Float64(data["perturbation_value"]) + h = Float64(data["h"]) + N = Int(data["N"]) + state_change_l2 = Vector{Float64}(data["state_change_l2"]) + rollout_state_shift_l2 = Vector{Float64}(data["rollout_state_shift_l2"]) + time = haskey(data, "time") ? Vector{Float64}(data["time"]) : collect(0:N-1) .* h + task_id = Int(data["task_id"]) + + eq_time, eq_step = time_to_equilibrium(state_change_l2, h; stable_steps=stable_steps, delta_tol=delta_tol) + + push!(perturbations, perturbation_value) + push!(equilibrium_times, eq_time) + push!(equilibrium_steps, eq_step) + push!(rollout_series, rollout_state_shift_l2) + push!(time_series, time) + push!(task_ids, task_id) + h_ref = h + N_ref = N + end + + isempty(perturbations) && error("No successful simulation files found in $result_dir") + + # Sort by perturbation value for coherent plotting and summary. + order = sortperm(perturbations) + perturbations = perturbations[order] + equilibrium_times = equilibrium_times[order] + equilibrium_steps = equilibrium_steps[order] + rollout_series = rollout_series[order] + time_series = time_series[order] + task_ids = task_ids[order] + + max_abs_pert = maximum(abs.(perturbations)) + color_grad = cgrad([:blue, :white, :red]) + + reached = findall(!isnan, equilibrium_times) + unreached = findall(isnan, equilibrium_times) + fallback_time = N_ref > 0 ? (N_ref - 1) * h_ref : maximum(vcat(time_series...)) + + p1 = plot( + title = "Perturbation vs Time To Equilibrium", + xlabel = "Signed Perturbation Magnitude", + ylabel = "Time To Equilibrium (s)", + legend = :top, + grid = true, + ) + if !isempty(reached) + scatter!( + p1, + perturbations[reached], + equilibrium_times[reached]; + marker_z = perturbations[reached], + color = color_grad, + clims = (-max_abs_pert, max_abs_pert), + ms = 8, + label = "Reached equilibrium", + colorbar_title = "Perturbation", + ) + end + if !isempty(unreached) + scatter!( + p1, + perturbations[unreached], + fill(fallback_time, length(unreached)); + marker = :x, + ms = 8, + color = :black, + label = "Not reached by horizon", + ) + end + + p2 = plot( + title = "Rollout State Shift L2 vs Time", + xlabel = "Time (s)", + ylabel = "||x_t - x_ref||\u2082", + legend = false, + grid = true, + ) + for i in eachindex(perturbations) + perturb = perturbations[i] + t = time_series[i] + shift = rollout_series[i] + c = color_for_perturbation(perturb, max_abs_pert, color_grad) + plot!( + p2, + t, + shift; + color = c, + lw = 2, + alpha = 0.9, + ) + end + + # Add a compact color scale legend. + p_scale = scatter( + [0.0, 1.0], + [0.0, 0.0]; + marker_z = [-max_abs_pert, max_abs_pert], + color = color_grad, + clims = (-max_abs_pert, max_abs_pert), + markersize = 0.001, + markerstrokewidth = 0, + legend = false, + colorbar_title = "Perturbation", + framestyle = :none, + xshowaxis = false, + yshowaxis = false, + ) + + combined = plot(p1, p2, p_scale; layout = @layout([a; b; c{0.06h}]), size = (1200, 1100)) + + plots_dir = joinpath(result_dir, "plots") + mkpath(plots_dir) + plot_path = joinpath(plots_dir, "transient_student_study_plots.png") + savefig(combined, plot_path) + + summary_path = joinpath(result_dir, "transient_equilibrium_summary.tsv") + open(summary_path, "w") do io + println(io, "task_id\tperturbation\tequilibrium_time_s\tequilibrium_step") + for i in eachindex(perturbations) + eq_step = isnothing(equilibrium_steps[i]) ? "NA" : string(equilibrium_steps[i]) + eq_time = isnan(equilibrium_times[i]) ? "NA" : @sprintf("%.6f", equilibrium_times[i]) + println(io, "$(task_ids[i])\t$(perturbations[i])\t$(eq_time)\t$(eq_step)") + end + end + + println("Saved plots to: $plot_path") + println("Saved equilibrium summary to: $summary_path") + println("Successful simulations: $(length(perturbations))") + println("Failed simulations skipped: $(length(failed_files))") + if !isempty(failed_files) + println("Failed files:") + for f in failed_files + println(" - $f") + end + end + println("Equilibrium definition: ||x_{t+1} - x_t||₂ <= $delta_tol for $stable_steps consecutive steps.") +end + +main() From e6c06ce1590a0b62f857752ee213b33eb491ea73 Mon Sep 17 00:00:00 2001 From: Andrew Rosemberg Date: Tue, 17 Feb 2026 21:53:32 -0500 Subject: [PATCH 5/5] update --- .gitignore | 1 + examples/Atlas/train_dr_atlas_det_eq.jl | 4 ++-- examples/Atlas/visualize_atlas_policy.jl | 4 ++-- 3 files changed, 5 insertions(+), 4 deletions(-) diff --git a/.gitignore b/.gitignore index 2f17921..f345598 100644 --- a/.gitignore +++ b/.gitignore @@ -49,3 +49,4 @@ Manifest.toml examples/**/.CondaPkg/* *.bson *.err +*.tsv diff --git a/examples/Atlas/train_dr_atlas_det_eq.jl b/examples/Atlas/train_dr_atlas_det_eq.jl index a640b5d..16c97af 100644 --- a/examples/Atlas/train_dr_atlas_det_eq.jl +++ b/examples/Atlas/train_dr_atlas_det_eq.jl @@ -28,7 +28,7 @@ h = 0.01 # Time step perturbation_scale = 0.5 # Scale of random perturbations num_scenarios = 10 # Number of uncertainty samples per stage penalty = 10.0 # Penalty for state deviation -perturbation_frequency = 5 # Frequency of perturbations (every k stages) +perturbation_frequency = 1000 # Frequency of perturbations (every k stages) # Training parameters num_epochs = 10 @@ -44,7 +44,7 @@ optimizers = [Flux.Adam(0.001)] enable_rollout_initial_state_augmentation = true rollout_start_epoch = 1 rollout_every_epochs = 1 -rollout_max_horizon_fraction = 10.0 +rollout_max_horizon_fraction = 20.0 if enable_rollout_initial_state_augmentation rollout_every_epochs < 1 && error("rollout_every_epochs must be >= 1") diff --git a/examples/Atlas/visualize_atlas_policy.jl b/examples/Atlas/visualize_atlas_policy.jl index 07cba0d..e987351 100644 --- a/examples/Atlas/visualize_atlas_policy.jl +++ b/examples/Atlas/visualize_atlas_policy.jl @@ -28,9 +28,9 @@ model_path = "./models/atlas-balancing-deteq-N10-2026-02-15T19:49:47.739.jld2" # Problem parameters (should match training) N = 300 # Number of time steps h = 0.01 # Time step -perturbation_scale = 0.5 # Scale of random perturbations +perturbation_scale = 1.5 # Scale of random perturbations num_scenarios = 1 # Number of scenarios to simulate -perturbation_frequency = 50 # Frequency of perturbations (every k stages) +perturbation_frequency = 1000 # Frequency of perturbations (every k stages) # Visualization options animate_robot = true # Whether to animate in MeshCat