Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions examples/ssl_simulator_tutorial.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@
"metadata": {},
"outputs": [],
"source": [
"simulation_data = load_sim(SIMDATA_FILE1, debug=True)\n",
"simulation_data, simulation_settings = load_sim(SIMDATA_FILE1, debug=True)\n",
"plotter = PlotBasic(simulation_data)\n",
"plotter.plot()"
]
Expand Down Expand Up @@ -214,7 +214,7 @@
"metadata": {},
"outputs": [],
"source": [
"simulation_data = load_sim(SIMDATA_FILE2, debug=True)\n",
"simulation_data, simulation_settings = load_sim(SIMDATA_FILE2, debug=True)\n",
"plotter = PlotBasic(simulation_data)\n",
"plotter.plot()"
]
Expand Down Expand Up @@ -322,7 +322,7 @@
"metadata": {},
"outputs": [],
"source": [
"simulation_data = load_sim(SIMDATA_FILE3, debug=True)\n",
"simulation_data, simulation_settings = load_sim(SIMDATA_FILE3, debug=True)\n",
"plotter = PlotBasic(simulation_data)\n",
"plotter.plot()"
]
Expand Down Expand Up @@ -385,7 +385,7 @@
"metadata": {},
"outputs": [],
"source": [
"simulation_data = load_sim(SIMDATA_FILE4, debug=True)\n",
"simulation_data, simulation_settings = load_sim(SIMDATA_FILE4, debug=True)\n",
"plotter = PlotBasic(simulation_data)\n",
"plotter.plot()"
]
Expand Down
6 changes: 4 additions & 2 deletions ssl_simulator/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""
"""
# ssl_simulator/__init__.py

# Configuration
from ssl_simulator.config import CONFIG

# Utils
from ssl_simulator.utils.debug import *
Expand Down
18 changes: 18 additions & 0 deletions ssl_simulator/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# ssl_simulator/config.py

import os

class Config(dict):
def __setitem__(self, key, value):
print(f"SSL simulator configuration updated: {key} = {value}")
super().__setitem__(key, value)

def update(self, *args, **kwargs):
for key, value in dict(*args, **kwargs).items():
print(f"SSL simulator configuration updated: {key} = {value}")
super().update(*args, **kwargs)

# Initialize the configuration dictionary
CONFIG = Config({
"DEBUG": os.getenv("SSL_SIMULATOR_DEBUG", "False").lower() in ("true", "1", "yes"),
})
8 changes: 4 additions & 4 deletions ssl_simulator/core/integrators.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
#######################################################################################

class EulerIntegrator:
def integrate(self, context, dt, debug=False):
def integrate(self, context, dt, test=False):
"""
Perform one step of Euler integration.

Expand All @@ -18,16 +18,16 @@ def integrate(self, context, dt, debug=False):
state (dict): Current state of the system.
dynamics_input (dict): Input to the dynamics function.
dt (float): Time step for integration.
debug (bool): If True, perform dimension checks during integration.
test (bool): If True, perform dimension checks during integration.

Returns:
dict: New state after integration.
"""
state = context.get_robot_state()
state_dot = context.get_robot_state_dot()

# Perform dimension checks if debug mode is enabled
if debug:
# Perform dimension checks if test mode is enabled
if test:
for key in state.keys():
if key + "_dot" in state_dot:
check_and_parse_dimensions(
Expand Down
2 changes: 1 addition & 1 deletion ssl_simulator/core/simulation_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def _set_time_step(self, time_step):
def _step_test(self):
self.context.compute_controls(self.time, self.time_step)
self.context.compute_robot_dynamics(self.time)
self.integrator.integrate(self.context, self.time_step, debug=True)
self.integrator.integrate(self.context, self.time_step, test=True)

def _log_data(self):
data = self.context.get_data()
Expand Down
16 changes: 16 additions & 0 deletions ssl_simulator/math/basics.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
]

import numpy as np
from ssl_simulator.config import CONFIG

#######################################################################################

Expand Down Expand Up @@ -81,10 +82,25 @@ def check_and_parse_dimensions(array, expected_shape, name=None, fill_values=Non
array = np.asarray(array) # Ensure the input is a NumPy array

# Handle special cases for expected shapes (auto-add batch dimension)
# TODO: generalize this logic
orig_shape = array.shape
changed = False

if len(expected_shape) == 2 and expected_shape[0] is None and array.ndim == 1:
array = array[np.newaxis, :]
changed = True
elif len(expected_shape) == 3 and expected_shape[0] is None and array.ndim == 2:
array = array[np.newaxis, :, :]
changed = True
elif len(expected_shape) == 4 and expected_shape[0] is None and expected_shape[1] is None and array.ndim == 3:
array = array[:, np.newaxis, :, :]
changed = True
elif len(expected_shape) == 4 and expected_shape[0] is None and expected_shape[1] is None and array.ndim == 2:
array = array[np.newaxis, np.newaxis, :, :]
changed = True

if CONFIG["DEBUG"] and changed:
print(f"Shape changed: {orig_shape} -> {array.shape}")

# Replace None values in expected_shape with fill_values if provided
if fill_values is not None:
Expand Down
53 changes: 37 additions & 16 deletions ssl_simulator/math/lie.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,18 +12,21 @@
"rot_3d_matrix",
"gen_random_rotations",
"orthonormal_vector_to",
"construct_attitude_basis",
"rotation_matrix_from_vector",
"rotation_angle_from_matrix",
"so3_hat",
"so3_vee",
"so3_exp_map",
"so3_log_map",
"so3_rotate_with_step"
"so3_rotate_with_step",
]

import numpy as np
import math

from ssl_simulator.math import check_and_parse_dimensions, unit_vec

def rot_3d_matrix(roll, pitch, yaw, dec=None):
"""
Generate R ∈ SO(3) from ROLL, PITCH, YAW.
Expand Down Expand Up @@ -104,6 +107,24 @@ def orthonormal_vector_to(v):

return n

def construct_attitude_basis(heading, gravity):
"""
Construct an orthonormal basis given heading and gravity vectors.
Handles both single vector (shape (3,) or (1,3)) and batch (shape (N,3)).
Returns: (3,3) or (N,3,3) basis matrix/matrices.
"""
heading = check_and_parse_dimensions(heading, (None,3), "heading")
gravity = check_and_parse_dimensions(gravity, (None,3), "gravity", fill_values=heading.shape[0])

v1 = unit_vec(heading)
gravity_proj = gravity - np.sum(gravity * v1, axis=1, keepdims=True) * v1
v3 = -unit_vec(gravity_proj)
if v3.any() < 1e-8:
raise ValueError("Gravity is parallel to heading; cannot construct basis.")
v2 = -np.cross(v1, v3)
basis = np.stack((v1, v2, v3), axis=-1)
return basis

def rotation_matrix_from_vector(v):
"""
- Given the input vector v, build an orthonormal basis and codify into a rotation matrix R ∈ SO(3) -
Expand All @@ -118,7 +139,7 @@ def rotation_matrix_from_vector(v):
md_y = np.cross(md_z, md)

# Build the rotation matrix
R = np.array([md, md_y, md_z])
R = np.array([md, md_y, md_z]).T
return R / np.linalg.det(R)

def rotation_angle_from_matrix(R):
Expand Down Expand Up @@ -170,21 +191,13 @@ def so3_hat(omega):
def so3_vee(omega_hat):
"""
- Generate \omega vector from \omega_\hat ∈ so(3) -
Supports single matrix (3,3) or batch (N,3,3).
Supports batch (...,3,3).
"""
omega_hat = np.asarray(omega_hat)
if omega_hat.ndim == 2:
wx = omega_hat[2,1]
wy = omega_hat[0,2]
wz = omega_hat[1,0]
return np.array([wx, wy, wz])
elif omega_hat.ndim == 3:
wx = omega_hat[:,2,1]
wy = omega_hat[:,0,2]
wz = omega_hat[:,1,0]
return np.stack([wx, wy, wz], axis=-1)
else:
raise ValueError("Input must be shape (3,3) or (N,3,3)")
wx = omega_hat[...,2,1]
wy = omega_hat[...,0,2]
wz = omega_hat[...,1,0]
return np.stack([wx, wy, wz], axis=-1)

###################################################################

Expand Down Expand Up @@ -323,7 +336,15 @@ def so3_rotate_with_step(R, omega_hat, step=np.pi/6):
"""
R = np.asarray(R)
omega_hat = np.asarray(omega_hat)


# Broadcast R to match omega_hat batch size
if R.shape[:-2] != omega_hat.shape[:-2]:
if R.shape[:-2] == (1,):
# Broadcast R to match omega_hat batch shape
R = np.broadcast_to(R, omega_hat.shape)
else:
raise ValueError(f"Incompatible batch shapes: R {R.shape[:-2]} vs omega_hat {omega_hat.shape[:-2]}")

# Flatten batch if necessary
batch_shape = R.shape[:-2]
R_flat = R.reshape(-1, 3, 3)
Expand Down
4 changes: 3 additions & 1 deletion ssl_simulator/utils/path_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
import sys
from pathlib import Path

from ssl_simulator.config import CONFIG

#######################################################################################

def create_dir(directory: str, verbose: bool = True) -> None:
Expand All @@ -31,7 +33,7 @@ def create_dir(directory: str, verbose: bool = True) -> None:
if verbose:
print(f"The directory '{directory}' already exists!")

def add_src_to_path(file=None, relative_path="", deep=0, debug=False):
def add_src_to_path(file=None, relative_path="", deep=0, debug=CONFIG["DEBUG"]):
"""
Adds the "relative_path" folder to sys.path based on the notebook's location.
"""
Expand Down
6 changes: 4 additions & 2 deletions ssl_simulator/utils/processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,11 @@
from ssl_simulator import check_file_size, json_to_dict, print_dict
from ssl_simulator.components.scalar_fields import ScalarField

from ssl_simulator.config import CONFIG

#######################################################################################

def load_sim(filename, debug=False, max_size_mb=100, verbose=False):
def load_sim(filename, debug=CONFIG["DEBUG"], max_size_mb=100, verbose=CONFIG["DEBUG"]):
check_file_size(filename, max_size_mb=max_size_mb)

settings, skiprows = _load_settings_line(filename)
Expand All @@ -30,7 +32,7 @@ def load_sim(filename, debug=False, max_size_mb=100, verbose=False):
if debug:
_debug_print(settings, data_dict, verbose)

return (data_dict, settings) if settings else data_dict
return data_dict, settings

def load_class(module_name: str, class_name: str, base_class=None, **init_kwargs):
"""
Expand Down