Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
142 changes: 142 additions & 0 deletions src/comodo/mujocoSimulator/callbacks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
from abc import ABC, abstractmethod
from typing import List
from comodo.mujocoSimulator.mjcontactinfo import MjContactInfo
import types
import logging
import mujoco
import copy

class Callback(ABC):
def __init__(self, *args, **kwargs) -> None:
self.args = args
self.kwargs = kwargs
self.simulator = None

def set_simulator(self, simulator):
self.simulator = simulator

@abstractmethod
def on_simulation_start(self) -> None:
pass

@abstractmethod
def on_simulation_step(self, t: float, data: mujoco.MjData, opts: dict = None) -> None:
pass

@abstractmethod
def on_simulation_end(self) -> None:
pass


class EarlyStoppingCallback(Callback):
def __init__(self, condition, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.score = 0
self.condition = condition

def on_simulation_start(self) -> None:
pass

def on_simulation_step(self, t: float, iter: int, data: mujoco.MjData, opts: dict = None) -> None:
if self.condition(t, iter, data, opts, *self.args, **self.kwargs):
if self.simulator is not None:
self.simulator.should_stop = True

def on_simulation_end(self) -> None:
pass


class ScoreCallback(Callback):
def __init__(self, score_function, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.score = 0
self.history = []
self.score_function = score_function

def on_simulation_start(self) -> None:
self.score = 0
self.history = []

def on_simulation_step(self, t: float, iter: int, data: mujoco.MjData, opts: dict = None) -> None:
score = self.score_function(t, iter, data, opts, *self.args, **self.kwargs)
self.score += score
self.history.append(score)

def on_simulation_end(self) -> None:
pass


class TrackerCallback(Callback):
def __init__(self, tracked_variables: list, print_values: bool | list = False, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.tracked_variables = tracked_variables

if isinstance(print_values, list):
self.print_values = print_values
elif isinstance(print_values, bool):
self.print_values = [print_values for _ in range(len(tracked_variables))]
else:
raise ValueError(f"print_values should be a boolean or a list of booleans for masking, not {type(print_values)}")

self.t = []
self.vals = {}

def on_simulation_start(self) -> None:
pass

def on_simulation_step(self, t: float, iter: int, data: mujoco.MjData, opts: dict = None) -> None:
self.t.append(t)
for var in self.tracked_variables:
if isinstance(var, str):
try:
v = eval(f"data.{var}")
if not var in self.vals:
self.vals[var] = []
self.vals[var].append(copy.deepcopy(v))
if self.print_values[self.tracked_variables.index(var)]:
print(f"{var}: {v}")
except:
print(f"Error: {self.tracked_variables} not found in data")
elif isinstance(var, types.FunctionType):
v = var(t, iter, data, opts, *self.args, **self.kwargs)
if not var.__name__ in self.vals:
self.vals[var.__name__] = []
if not isinstance(v, (int, float)):
return
self.vals[var.__name__].append(v)
if self.print_values[self.tracked_variables.index(var)]:
print(f"{var.__name__}: {v}")

def on_simulation_end(self) -> None:
pass

def get_tracked_values(self):
return self.t, self.vals


class ContactCallback(Callback):
def __init__(self, tracked_bodies: List[str] = [], logger: logging.Logger = None, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.tracked_bodies = tracked_bodies
self.logger = logger
self.last_contact = None

def on_simulation_start(self) -> None:
pass

def on_simulation_step(self, t: float, iter: int, data: mujoco.MjData, opts: dict = None) -> None:
if opts.get("contact", None) is not None:
contact_info = opts["contact"]
assert isinstance(contact_info, MjContactInfo), "Contact info is not an instance of MjContactInfo"
if contact_info.is_none():
return
self.last_contact = contact_info
if self.logger is not None:
self.logger.debug(f"Contact detected at t={t}: {contact_info}")
else:
pass
#print(f"Contact detected at t={t}: {contact_info}")


def on_simulation_end(self) -> None:
pass
39 changes: 39 additions & 0 deletions src/comodo/mujocoSimulator/mjcontactinfo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
from dataclasses import dataclass
from typing import Sequence, List

class MjContactInfo:
"""Wrapper class for mjContact struct of MuJoCo.
Accepts the struct instance individually or by unpacking the struct.
"""
def __init__(self, t: float, iter: int, *args):
self.t = t
self.iter = iter
if len(args) == 1:
self.mj_struct = args[0]
self.dist: float = self.mj_struct.dist[0] if self.mj_struct.dist.size > 0 else None
self.pos: List[float] = self.mj_struct.pos.flatten().tolist() if self.mj_struct.pos.size > 0 else None
self.frame: List[float] = self.mj_struct.frame.flatten().tolist() if self.mj_struct.frame.size > 0 else None
self.dim: int = self.mj_struct.dim[0] if self.mj_struct.dim.size > 0 else None
self.geom: List[int] = self.mj_struct.geom.flatten().tolist() if self.mj_struct.geom.size > 0 else None
self.flex: List[int] = self.mj_struct.flex.flatten().tolist() if self.mj_struct.flex.size > 0 else None
self.elem: List[int] = self.mj_struct.elem.flatten().tolist() if self.mj_struct.elem.size > 0 else None
self.vert: List[int] = self.mj_struct.vert.flatten().tolist() if self.mj_struct.vert.size > 0 else None
self.mu: float = float(self.mj_struct.mu[0]) if self.mj_struct.mu.size > 0 else None
self.H: List[float] = self.mj_struct.H.flatten().tolist() if self.mj_struct.H.size > 0 else None

self._is_none = self.pos is None
else:
raise NotImplementedError("Unpacking the struct is not implemented yet.")

def __str__(self):
return f"ContactInfo(t={self.t}, iter={self.iter}, dist={self.dist}, pos={self.pos}, frame={self.frame}, dim={self.dim}, geom={self.geom}, flex={self.flex}, elem={self.elem}, vert={self.vert}, mu={self.mu}, H={self.H})"

def __repr__(self):
return f"ContactInfo(t={self.t}, iter={self.iter}, dist={self.dist}, pos={self.pos}, frame={self.frame}, dim={self.dim}, geom={self.geom}, flex={self.flex}, elem={self.elem}, vert={self.vert}, mu={self.mu}, H={self.H})"

def is_none(self) -> bool:
return self._is_none

def get_time(self) -> float:
return self.t

Loading