Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion deluca/agents/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,5 +16,9 @@
from deluca.agents._pid import PID
from deluca.agents._random import Random
from deluca.agents._zero import Zero
from deluca.agents._adaptive import Adaptive
from deluca.agents._gpc import GPC
from deluca.agents._lqr import LQR

__all__ = ["BangBang", "PID", "Random", "Zero"]

__all__ = ["BangBang", "PID", "Random", "Zero", "Adaptive", "GPC", "LQR"]
22 changes: 15 additions & 7 deletions deluca/agents/_adaptive.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import jax.numpy as jnp
import numpy as np

from deluca.agents.core import Agent
from deluca.core import Agent


def quad_loss(x: jnp.ndarray, u: jnp.ndarray) -> Real:
Expand Down Expand Up @@ -66,13 +66,13 @@ def __init__(
self.n, self.m = B.shape

cost_fn = cost_fn or quad_loss
self.cost_fn = cost_fn

self.base_controller = base_controller

# Start with all weight on the first learner (index 0)
self.T = T
self.weights = np.zeros(T)
self.weights[0] = 1.0

# Track current timestep
self.t, self.expert_density = 0, expert_density
Expand Down Expand Up @@ -111,10 +111,14 @@ def evolve(x, h):
def __call__(self, x, A, B):

play_i = np.argmax(self.weights)
if not self.alive[play_i]:
play_i = 0

self.u = self.learners[play_i].get_action(x)

# Update alive models
for i in jnp.nonzero(self.alive)[0]:
i = int(i)
loss_i = self.policy_loss(self.learners[i], A, B, x, self.w)
self.weights[i] *= np.exp(-self.eta * loss_i)
self.weights[i] = min(max(self.weights[i], self.eps), self.inf)
Expand All @@ -138,13 +142,17 @@ def __call__(self, x, A, B):
del self.learners[kill]
self.weights[kill] = 0

# Rescale
max_w = np.max(self.weights)
if max_w < 1:
self.weights /= max_w
# # Rescale
# max_w = np.max(self.weights)
# if max_w < 1:
# self.weights /= max_w
max_weight_i = np.argmax(self.weights)
max_w = self.weights[max_weight_i]
if self.alive[max_weight_i] and max_w < 1:
self.weights /= max_w

# Get new noise (will be located at w[-1])
self.w = self.w.at[0].set(x - self.A @ self.x + self.B @ self.u)
self.w = self.w.at[0].set(x - self.A @ self.x - self.B @ self.u)
self.w = jnp.roll(self.w, -1, axis=0)

# Update System
Expand Down
2 changes: 1 addition & 1 deletion deluca/agents/_bpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import numpy.random as random

from deluca.agents._lqr import LQR
from deluca.agents.core import Agent
from deluca.core import Agent


def generate_uniform(shape, norm=1.00):
Expand Down
2 changes: 1 addition & 1 deletion deluca/agents/_deep.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import jax
import jax.numpy as jnp

from deluca.agents.core import Agent
from deluca.core import Agent
from deluca.utils import Random


Expand Down
2 changes: 1 addition & 1 deletion deluca/agents/_drc.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from jax import grad
from jax import jit

from deluca.agents.core import Agent
from deluca.core import Agent
from deluca.utils import Random


Expand Down
2 changes: 1 addition & 1 deletion deluca/agents/_gpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from jax import jit

from deluca.agents._lqr import LQR
from deluca.agents.core import Agent
from deluca.core import Agent


def quad_loss(x: jnp.ndarray, u: jnp.ndarray) -> Real:
Expand Down
2 changes: 1 addition & 1 deletion deluca/agents/_hinf.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import jax.numpy as jnp
from jax.numpy.linalg import inv

from deluca.agents.core import Agent
from deluca.core import Agent


class Hinf(Agent):
Expand Down
2 changes: 1 addition & 1 deletion deluca/agents/_igpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import jax
import jax.numpy as jnp

from deluca.agents.core import Agent
from deluca.core import Agent
from deluca.utils.planning import f_at_x
from deluca.utils.planning import LQR
from deluca.utils.planning import rollout
Expand Down
2 changes: 1 addition & 1 deletion deluca/agents/_ilc.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

import jax.numpy as jnp

from deluca.agents.core import Agent
from deluca.core import Agent
from deluca.utils.planning import f_at_x
from deluca.utils.planning import LQR
from deluca.utils.planning import rollout
Expand Down
2 changes: 1 addition & 1 deletion deluca/agents/_ilqr.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

import jax.numpy as jnp

from deluca.agents.core import Agent
from deluca.core import Agent


def iLQR_loop(env, U_initial, T, alpha=1.0, log=None):
Expand Down
2 changes: 1 addition & 1 deletion deluca/agents/_lqr.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import jax.numpy as jnp
from scipy.linalg import solve_discrete_are as dare

from deluca.agents.core import Agent
from deluca.core import Agent


class LQR(Agent):
Expand Down