diff --git a/deluca/agents/__init__.py b/deluca/agents/__init__.py index 227b5b7..1ffd17d 100644 --- a/deluca/agents/__init__.py +++ b/deluca/agents/__init__.py @@ -16,5 +16,9 @@ from deluca.agents._pid import PID from deluca.agents._random import Random from deluca.agents._zero import Zero +from deluca.agents._adaptive import Adaptive +from deluca.agents._gpc import GPC +from deluca.agents._lqr import LQR -__all__ = ["BangBang", "PID", "Random", "Zero"] + +__all__ = ["BangBang", "PID", "Random", "Zero", "Adaptive", "GPC", "LQR"] diff --git a/deluca/agents/_adaptive.py b/deluca/agents/_adaptive.py index f09951f..14daa40 100644 --- a/deluca/agents/_adaptive.py +++ b/deluca/agents/_adaptive.py @@ -20,7 +20,7 @@ import jax.numpy as jnp import numpy as np -from deluca.agents.core import Agent +from deluca.core import Agent def quad_loss(x: jnp.ndarray, u: jnp.ndarray) -> Real: @@ -66,13 +66,13 @@ def __init__( self.n, self.m = B.shape cost_fn = cost_fn or quad_loss + self.cost_fn = cost_fn self.base_controller = base_controller # Start From Uniform Distribution self.T = T self.weights = np.zeros(T) - self.weights[0] = 1.0 # Track current timestep self.t, self.expert_density = 0, expert_density @@ -111,10 +111,14 @@ def evolve(x, h): def __call__(self, x, A, B): play_i = np.argmax(self.weights) + if not self.alive[play_i]: + play_i = 0 + self.u = self.learners[play_i].get_action(x) # Update alive models for i in jnp.nonzero(self.alive)[0]: + i = int(i) loss_i = self.policy_loss(self.learners[i], A, B, x, self.w) self.weights[i] *= np.exp(-self.eta * loss_i) self.weights[i] = min(max(self.weights[i], self.eps), self.inf) @@ -138,13 +142,17 @@ def __call__(self, x, A, B): del self.learners[kill] self.weights[kill] = 0 - # Rescale - max_w = np.max(self.weights) - if max_w < 1: - self.weights /= max_w + # # Rescale + # max_w = np.max(self.weights) + # if max_w < 1: + # self.weights /= max_w + max_weight_i = np.argmax(self.weights) + max_w = self.weights[max_weight_i] + if self.alive[max_weight_i] and max_w < 1: + self.weights /= max_w # Get new noise (will be located at w[-1]) - self.w = self.w.at[0].set(x - self.A @ self.x + self.B @ self.u) + self.w = self.w.at[0].set(x - self.A @ self.x - self.B @ self.u) self.w = jnp.roll(self.w, -1, axis=0) # Update System diff --git a/deluca/agents/_bpc.py b/deluca/agents/_bpc.py index 3109a2e..3c1f7c5 100644 --- a/deluca/agents/_bpc.py +++ b/deluca/agents/_bpc.py @@ -20,7 +20,7 @@ import numpy.random as random from deluca.agents._lqr import LQR -from deluca.agents.core import Agent +from deluca.core import Agent def generate_uniform(shape, norm=1.00): diff --git a/deluca/agents/_deep.py b/deluca/agents/_deep.py index 1527677..08781f5 100644 --- a/deluca/agents/_deep.py +++ b/deluca/agents/_deep.py @@ -18,7 +18,7 @@ import jax import jax.numpy as jnp -from deluca.agents.core import Agent +from deluca.core import Agent from deluca.utils import Random diff --git a/deluca/agents/_drc.py b/deluca/agents/_drc.py index 0d0b5a7..93d7074 100644 --- a/deluca/agents/_drc.py +++ b/deluca/agents/_drc.py @@ -21,7 +21,7 @@ from jax import grad from jax import jit -from deluca.agents.core import Agent +from deluca.core import Agent from deluca.utils import Random diff --git a/deluca/agents/_gpc.py b/deluca/agents/_gpc.py index fc28139..c8419b3 100644 --- a/deluca/agents/_gpc.py +++ b/deluca/agents/_gpc.py @@ -23,7 +23,7 @@ from jax import jit from deluca.agents._lqr import LQR -from deluca.agents.core import Agent +from deluca.core import Agent def quad_loss(x: jnp.ndarray, u: jnp.ndarray) -> Real: diff --git a/deluca/agents/_hinf.py b/deluca/agents/_hinf.py index 1456145..6fb90fd 100644 --- a/deluca/agents/_hinf.py +++ b/deluca/agents/_hinf.py @@ -19,7 +19,7 @@ import jax.numpy as jnp from jax.numpy.linalg import inv -from deluca.agents.core import Agent +from deluca.core import Agent class Hinf(Agent): diff --git a/deluca/agents/_igpc.py b/deluca/agents/_igpc.py index 2541f0e..e82ae9d 100644 --- a/deluca/agents/_igpc.py +++ b/deluca/agents/_igpc.py @@ -15,7 +15,7 @@ import jax import jax.numpy as jnp -from deluca.agents.core import Agent +from deluca.core import Agent from deluca.utils.planning import f_at_x from deluca.utils.planning import LQR from deluca.utils.planning import rollout diff --git a/deluca/agents/_ilc.py b/deluca/agents/_ilc.py index 870e4e9..9a78fa6 100644 --- a/deluca/agents/_ilc.py +++ b/deluca/agents/_ilc.py @@ -14,7 +14,7 @@ import jax.numpy as jnp -from deluca.agents.core import Agent +from deluca.core import Agent from deluca.utils.planning import f_at_x from deluca.utils.planning import LQR from deluca.utils.planning import rollout diff --git a/deluca/agents/_ilqr.py b/deluca/agents/_ilqr.py index f970577..5b0216a 100644 --- a/deluca/agents/_ilqr.py +++ b/deluca/agents/_ilqr.py @@ -22,7 +22,7 @@ import jax.numpy as jnp -from deluca.agents.core import Agent +from deluca.core import Agent def iLQR_loop(env, U_initial, T, alpha=1.0, log=None): diff --git a/deluca/agents/_lqr.py b/deluca/agents/_lqr.py index 29eb516..3bde666 100644 --- a/deluca/agents/_lqr.py +++ b/deluca/agents/_lqr.py @@ -16,7 +16,7 @@ import jax.numpy as jnp from scipy.linalg import solve_discrete_are as dare -from deluca.agents.core import Agent +from deluca.core import Agent class LQR(Agent):