Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion genesis_forge/managed_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,7 +295,8 @@ def step(
self.managers["action"]
):
(start, end) = self._action_ranges[i]
action_manager.step(actions[:, start:end])
processed_actions = action_manager.step(actions[:, start:end])
action_manager.send_actions_to_simulation(processed_actions)
self.scene.step()

# Update entity managers
Expand Down
59 changes: 48 additions & 11 deletions genesis_forge/managers/action/base.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import re
from torch._tensor import Tensor
import torch
import numpy as np
from gymnasium import spaces
Expand Down Expand Up @@ -150,7 +151,7 @@ def actions(self) -> torch.Tensor:
@property
def raw_actions(self) -> torch.Tensor:
"""
The actions received from the policy, before being converted.
The actions received from the policy, before being processed.
"""
if self._raw_actions is None:
return torch.zeros((self.env.num_envs, self.num_actions))
Expand Down Expand Up @@ -213,6 +214,47 @@ def get_dofs_force(self, clip_to_max_force: bool = False):
clip_to_max_force=clip_to_max_force, dofs_idx=self.dofs_idx
)

def get_actions(self) -> torch.Tensor:
    """
    Return the most recently stored actions.

    When no actions have been stored yet (``self._actions`` is None), a zero
    tensor of shape ``(self.env.num_envs, self.num_actions)`` is returned instead.
    """
    current = self._actions
    if current is not None:
        return current
    # Nothing stored yet: fall back to an all-zero placeholder.
    shape = (self.env.num_envs, self.num_actions)
    return torch.zeros(shape)

def get_actions_dict(self, env_idx: int = 0) -> dict[str, float]:
    """
    Get the latest actions for one environment as a dictionary of DOF names and values.

    Args:
        env_idx: Index of the environment (row of the action tensor) to read.

    Returns:
        A mapping from each key of ``self.dofs`` to the matching action value
        for ``env_idx``, paired in order.
    """
    # Read through get_actions() so a call made before the first step sees
    # its zero-filled fallback instead of indexing into None.
    actions = self.get_actions()
    return {
        name: value.item()
        for name, value in zip(self.dofs.keys(), actions[env_idx])
    }

def process_actions(self, actions: torch.Tensor) -> torch.Tensor:
    """
    Hook that turns incoming step actions into actuator commands.

    This default implementation is a pass-through: the input is handed back
    unchanged. Override it to customise how actions are processed.

    Args:
        actions: The actions received for the current step.

    Returns:
        The processed actions (here, the very same object that was passed in).
    """
    processed = actions
    return processed

def send_actions_to_simulation(self) -> torch.Tensor:
"""
Send the latest processed actions to the actuators in the simulation.
Override this function to define how the actions are sent to the simulation.
"""
raise NotImplementedError(
"handle_actions is not implemented for this action manager."
)

"""
Lifecycle Operations
"""
Expand All @@ -237,7 +279,7 @@ def build(self):

def step(self, actions: torch.Tensor) -> None:
"""
Handle the received actions.
Handle actions received in this step.
"""
# Action delay buffer
if self._delay_step > 0:
Expand All @@ -248,7 +290,10 @@ def step(self, actions: torch.Tensor) -> None:
self._raw_actions = actions
if self._actions is None:
self._actions = torch.empty_like(actions, device=gs.device)
self._actions[:] = self._raw_actions[:]

# Process the actions
self._actions[:] = self.process_actions(self._raw_actions[:])

return self._actions

def reset(self, envs_idx: list[int] | None):
Expand All @@ -262,11 +307,3 @@ def reset(self, envs_idx: list[int] | None):
self._action_delay_buffer.append(
torch.zeros((self.env.num_envs, self.num_actions), device=gs.device)
)

def get_actions(self) -> torch.Tensor:
    """
    Return the currently stored actions for the environments.

    Falls back to a zero tensor of shape
    ``(self.env.num_envs, self.num_actions)`` while ``self._actions`` is None.
    """
    if self._actions is not None:
        return self._actions
    placeholder = torch.zeros((self.env.num_envs, self.num_actions))
    return placeholder
40 changes: 19 additions & 21 deletions genesis_forge/managers/action/position_action_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@ class PositionActionManager(BaseActionManager):
offset: Offset factor for the action.
use_default_offset: Whether to use default joint positions configured in the articulation asset as offset. Defaults to True.
clip: Clip the action values to the range. If omitted, the action values will automatically be clipped to the joint limits.
soft_limit_scale_factor: Scales the clip range of all limits by this factor around the midpoint
of each joint's limits to establish a safety region within the limits.
Defaults to 1.0.
quiet_action_errors: Whether to quiet action errors.
delay_step: The number of steps to delay the actions for.
This is an easy way to emulate the latency in the system.
Expand Down Expand Up @@ -110,6 +113,7 @@ def __init__(
scale: float | dict[str, float] = 1.0,
offset: float | dict[str, float] = 0.0,
clip: tuple[float, float] | dict[str, tuple[float, float]] = None,
soft_limit_scale_factor: float = 1.0,
use_default_offset: bool = True,
action_handler: Callable[[torch.Tensor], None] = None,
quiet_action_errors: bool = False,
Expand All @@ -126,6 +130,7 @@ def __init__(
self._offset_cfg = ensure_dof_pattern(offset)
self._scale_cfg = ensure_dof_pattern(scale)
self._clip_cfg = ensure_dof_pattern(clip)
self._soft_limit_scale_factor = soft_limit_scale_factor
self._quiet_action_errors = quiet_action_errors
self._enabled_dof = None
self._use_default_offset = use_default_offset
Expand Down Expand Up @@ -161,6 +166,11 @@ def build(self):
self._clip_values = torch.stack([lower_limit, upper_limit], dim=1)
if self._clip_cfg is not None:
self._get_dof_value_tensor(self._clip_cfg, output=self._clip_values)
if self._soft_limit_scale_factor != 1.0:
midpoint = (self._clip_values[:, 0] + self._clip_values[:, 1]) * 0.5
half_range = (self._clip_values[:, 1] - self._clip_values[:, 0]) * 0.5 * self._soft_limit_scale_factor
self._clip_values[:, 0] = midpoint - half_range
self._clip_values[:, 1] = midpoint + half_range

# Scale
self._scale_values = None
Expand All @@ -175,31 +185,16 @@ def build(self):
offset = self._offset_cfg if self._offset_cfg is not None else 0.0
self._offset_values = self._get_dof_value_tensor(offset)

def step(self, actions: torch.Tensor) -> torch.Tensor:
def process_actions(self, actions: torch.Tensor) -> torch.Tensor:
"""
Take the incoming actions for this step and handle them.

Args:
actions: The incoming step actions to handle.
"""
if not self.enabled:
return
actions = super().step(actions)
self._actions = self.handle_actions(actions)
return self._actions

def handle_actions(self, actions: torch.Tensor) -> torch.Tensor:
"""
Converts the actions to position commands, and send them to the DOF actuators.
Override this function if you want to change the action handling logic.
Convert the actions to position commands, and clamp them to the limits.

Args:
actions: The incoming step actions to handle.

Returns:
The processed and handled actions.
The actions as position commands.
"""

# Validate actions
if not self._quiet_action_errors:
if torch.isnan(actions).any():
Expand All @@ -214,12 +209,15 @@ def handle_actions(self, actions: torch.Tensor) -> torch.Tensor:
min=self._clip_values[:, 0],
max=self._clip_values[:, 1],
)
return actions

# Set target positions
def send_actions_to_simulation(self, actions: torch.Tensor) -> torch.Tensor:
"""
Sends the actions as position commands to the actuators in the simulation.
"""
actions = self.get_actions()
self.actuator_manager.control_dofs_position(actions, self.dofs_idx)

return actions

"""
Internal methods
"""
Expand Down
20 changes: 9 additions & 11 deletions genesis_forge/managers/action/position_within_limits.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,9 @@ class PositionWithinLimitsActionManager(PositionActionManager):
actuator_manager: The actuator manager which is used to setup and control the DOF joints.
actuator_joints: Which joints of the actuator manager that this action manager will control.
These can be full names or regular expressions.
limit: (optional) A dictionary of DOF name patterns and their position limits.
limit: A dictionary of DOF name patterns and their position limits.
If omitted, the limits will be set to the limits of the actuators defined in the model.
soft_limit_scale_factor: Scales the range of all limits by this factor to establish a safety region within the limits. Defaults to 1.0.
quiet_action_errors: Whether to quiet action errors.
delay_step: The number of steps to delay the actions for.
This is an easy way to emulate the latency in the system.
Expand Down Expand Up @@ -76,6 +77,7 @@ def __init__(
actuator_joints: list[str] | str = ".*",
quiet_action_errors: bool = False,
limit: tuple[float, float] | dict[str, tuple[float, float]] = {},
soft_limit_scale_factor: float = 1.0,
delay_step: int = 0,
**kwargs,
):
Expand All @@ -88,6 +90,7 @@ def __init__(
**kwargs,
)
self._limit_cfg = ensure_dof_pattern(limit)
self._soft_limit_scale_factor = soft_limit_scale_factor

"""
Lifecycle Operations
Expand All @@ -103,26 +106,21 @@ def build(self):
lower = lower.unsqueeze(0).expand(self.env.num_envs, -1)
upper = upper.unsqueeze(0).expand(self.env.num_envs, -1)
self._offset = (upper + lower) * 0.5
self._scale = (upper - lower) * 0.5
self._scale = (upper - lower) * 0.5 * self._soft_limit_scale_factor

def handle_actions(self, actions: torch.Tensor) -> torch.Tensor:
def process_actions(self, actions: torch.Tensor) -> torch.Tensor:
"""
Converts the actions to position commands, and send them to the DOF actuators.
Override this function if you want to change the action handling logic.
Convert the actions to position commands within the limits.

Args:
actions: The incoming step actions to handle.

Returns:
The processed and handled actions.
The actions as position commands.
"""
# Convert the action from -1 to 1, to absolute position within the actuator limits
actions.clamp_(-1.0, 1.0)
actions = actions.clamp(-1.0, 1.0)
actions = actions * self._scale + self._offset

# Set target positions
self.actuator_manager.control_dofs_position(actions, self.dofs_idx)

return actions

"""
Expand Down
28 changes: 28 additions & 0 deletions genesis_forge/managers/actuator/actuator_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,34 @@ def get_dofs_force(
[lower, upper] = self._robot.get_dofs_force_range(dofs_idx or self.dofs_idx)
force = force.clamp(lower, upper)
return force

def get_dofs_control_force(
self,
noise: float = 0.0,
clip_to_max_force: bool = False,
dofs_idx: list[int] | None = None,
) -> torch.Tensor:
"""
Return the force output by the configured DOFs.
This is a wrapper for `RigidEntity.get_dofs_control_force`.

Args:
noise: The maximum amount of random noise to add to the force values returned.
clip_to_max_force: Clip the force returned to the maximum force of the actuators.
dofs_idx: The indices of the DOFs to get the force for. If None, all the DOFs of this actuator manager are used.

Returns:
force: torch.Tensor, shape (n_envs, n_dofs)
The force experienced by the enabled DOFs.
"""
dofs_idx = dofs_idx if dofs_idx is not None else self.dofs_idx
force = self._robot.get_dofs_control_force(dofs_idx)
if noise > 0.0:
force = self._add_random_noise(force, noise)
if clip_to_max_force:
[lower, upper] = self._robot.get_dofs_force_range(dofs_idx or self.dofs_idx)
force = force.clamp(lower, upper)
return force

def get_dofs_limits(
self, dofs_idx: list[int] | None = None
Expand Down
9 changes: 9 additions & 0 deletions genesis_forge/managers/config/config_item.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,15 @@ def build(self, **kwargs):
return
self._init_fn_class()

def reset(self, envs_idx: list[int]):
    """
    Forward an environment reset to the wrapped function class.

    Does nothing when the wrapped function is not a class instance
    (``self._is_class`` is falsy).

    Args:
        envs_idx: The environments being reset; passed through to ``self._fn.reset``.
    """
    if self._is_class:
        self._fn.reset(envs_idx)

def execute(self, envs_idx: list[int]):
"""
Call the function for the given environment ids.
Expand Down
4 changes: 4 additions & 0 deletions genesis_forge/managers/config/mdp_fn_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,10 @@ def build(self):
"""Called during the environment build phase and when MDP params are changed."""
pass

def reset(self, envs_idx):
    """
    Called when environments are reset.

    The default implementation is a no-op; override it to clear per-env state.

    Args:
        envs_idx: The environments being reset.
    """
    return None

def __call__(
self,
env: GenesisEnv,
Expand Down
8 changes: 4 additions & 4 deletions genesis_forge/managers/contact/kernel.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import gstaichi as ti
from genesis.utils.geom import ti_inv_transform_by_quat
import quadrants as ti
from genesis.utils.geom import qd_inv_transform_by_quat


@ti.kernel
Expand Down Expand Up @@ -73,9 +73,9 @@ def kernel_get_contact_forces(

# Transform force to local frame of target link
if is_target_b:
force_vec = ti_inv_transform_by_quat(force_vec, quat_b)
force_vec = qd_inv_transform_by_quat(force_vec, quat_b)
else:
force_vec = ti_inv_transform_by_quat(-force_vec, quat_a)
force_vec = qd_inv_transform_by_quat(-force_vec, quat_a)

# Accumulate force and position
for j in ti.static(range(3)):
Expand Down
Loading