Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
caa7de0
Update core.py
eladhazan Jun 2, 2025
60007ea
Fixed bugs in SFC, DSC; Changed LDS class; Added disturbance class
Jun 3, 2025
4442ea8
Fixed bugs in core.py
Jun 3, 2025
790edc3
Fixed bugs in LDS
Jun 3, 2025
1442411
Fixed SimpleRandom Agent
Jun 3, 2025
9088e41
Fixed LDS class
Jun 3, 2025
a2dd0f4
Fixed LDS class
Jun 3, 2025
2373e3c
Fixed a bud in DSC
Jun 3, 2025
b107bd6
remove depreciated stuff
eladhazan Jun 4, 2025
0cd3c4f
Added default arguments to LDS class
Jun 4, 2025
7250042
Added default arguments to LDS class
Jun 4, 2025
e0ba985
Fixed Simplerandom agent
Jun 4, 2025
48165cc
Added default arguments to LDS class
Jun 4, 2025
57cdcdd
Added default arguments to LDS class
Jun 4, 2025
20beb70
Added default arguments to LDS class
Jun 4, 2025
b25ddce
Added default arguments to LDS class
Jun 4, 2025
ef7d748
Merge pull request #88 from google/fix-lds-class
eladhazan Jun 4, 2025
969c591
Update _random.py
eladhazan Jun 4, 2025
d5be431
Update _zero.py
eladhazan Jun 4, 2025
987c0bc
added sinus noise
eladhazan Jun 4, 2025
8bfea8a
added the main agents to external
eladhazan Jun 4, 2025
ae6ea68
added default params to LQG and empty update function
eladhazan Jun 4, 2025
e970728
fixed random calls in DRC
eladhazan Jun 4, 2025
e352d4f
Add files via upload
eladhazan Jun 4, 2025
8742f3c
Replace deprecated jax.tree_* functions with jax.tree.*
danielsuo Jun 9, 2025
779ae92
Replace uses of `jnp.array` in types with `jnp.ndarray`.
danielsuo Jun 9, 2025
e5a31a7
Flax is changing the `RNNCellBase` API:
danielsuo Jun 9, 2025
81346b9
some fixes to zero agent
eladhazan Jun 9, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions colabs/test_agents_on_lds.ipynb

Large diffs are not rendered by default.

6 changes: 5 additions & 1 deletion deluca/agents/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,10 @@
from deluca.agents._bang_bang import BangBang
from deluca.agents._pid import PID
from deluca.agents._random import Random
from deluca.agents._random import SimpleRandom
from deluca.agents._zero import Zero
from deluca.agents._grc import GRC
from deluca.agents._sfc import SFC
from deluca.agents._lqg import LQG

__all__ = ["BangBang", "PID", "Random", "Zero"]
__all__ = ["BangBang", "PID", "Random", "SimpleRandom", "Zero", "GRC", "SFC", "LQG"]
13 changes: 6 additions & 7 deletions deluca/agents/_drc.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@
from jax import grad
from jax import jit

from deluca.agents.core import Agent
from deluca.utils import Random
from deluca.core import Agent
#from deluca.utils import Random


def quad_loss(x: jnp.ndarray, u: jnp.ndarray) -> Real:
Expand Down Expand Up @@ -75,8 +75,9 @@ def __init__(

cost_fn = cost_fn or quad_loss

self.random = Random(seed)

# self.random = Random(seed)
key = jax.random.PRNGKey(0)

d_state, d_action = B.shape # State & Action Dimensions

C = jnp.identity(d_state) if C is None else C
Expand All @@ -102,9 +103,7 @@ def __init__(
# initial linear policy / perturbation contributions / bias
self.K = K if K is not None else jnp.zeros((d_action, d_obs))

self.M = lr_scale * jax.random.normal(
self.random.generate_key(), shape=(m, d_action, d_obs)
)
self.M = lr_scale * jax.random.normal(key, shape=(m, d_action, d_obs))

# Past m nature y's such that y_nat[0] is the most recent
self.y_nat = jnp.zeros((m, d_obs, 1))
Expand Down
4 changes: 2 additions & 2 deletions deluca/agents/_dsc.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def __init__(
self.h = h

self.t = 0
subkey_1, subkey2, subkey3, subkey4 = jax.random.split(key, 4)
subkey1, subkey2, subkey3, subkey4 = jax.random.split(key, 4)
self.M_0 = init_scale*jax.random.normal(subkey1, shape=(self.n, self.p))
self.M_1 = init_scale*jax.random.normal(subkey2, shape=(self.h, self.n, self.p))
self.M_2 = init_scale*jax.random.normal(subkey3, shape=(self.h, self.n, self.p))
Expand Down Expand Up @@ -116,7 +116,7 @@ def last_ynat():
self.last_ynat = last_ynat

def slice_window(start):
return jax.lax.dynamic_slice(self.ynat_history, (start, 0, 0), (self.m, p, 1))
return jax.lax.dynamic_slice(self.ynat_history, (start, 0, 0), (self.m, self.p, 1))

self.slice_window = slice_window

Expand Down
23 changes: 18 additions & 5 deletions deluca/agents/_lqg.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,10 @@ def __init__(
A: jnp.ndarray,
B: jnp.ndarray,
C: jnp.ndarray,
Q: jnp.ndarray,
R: jnp.ndarray,
W: jnp.ndarray,
V: jnp.ndarray
Q: jnp.ndarray = None,
R: jnp.ndarray = None,
W: jnp.ndarray = None,
V: jnp.ndarray = None
) -> None:
"""
Description: Initialize the dynamics of the model.
Expand All @@ -50,7 +50,16 @@ def __init__(

self.A, self.B, self.C = A, B, C # System Dynamics
self.d, self.n, self.p = self.A.shape[0], self.B.shape[1], self.C.shape[0] # State, Action, Observation Dimensions


if Q is None:
Q = jax.numpy.identity(self.d)
if R is None:
R = jax.numpy.identity(self.n)
if W is None:
W = jax.numpy.identity(self.d)
if V is None:
V = jax.numpy.identity(self.p)

self.P = jnp.array(solve_discrete_are(np.asarray(A), np.asarray(B), np.asarray(Q), np.asarray(R)))
self.K = jnp.linalg.inv(self.B.T @ self.P @ self.B + R) @ (self.B.T @ self.P @ self.A)
Sigma = jnp.array(solve_discrete_are(np.asarray(A).T, np.asarray(C).T, np.asarray(W), np.asarray(V)))
Expand All @@ -73,3 +82,7 @@ def __call__(self, y: jnp.ndarray) -> jnp.ndarray:
residual = y - self.C @ self.x_hat
self.x_hat = self.A @ self.x_hat + self.B @ u + self.L @ residual
return u


def update(self, y: jnp.ndarray, u: jnp.ndarray) -> None:
return None
2 changes: 1 addition & 1 deletion deluca/agents/_predestined.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
class Predestined(Agent):
time: float = deluca.field(jaxed=False)
steps: int = deluca.field(jaxed=False)
u: jnp.array = deluca.field(jaxed=False)
u: jnp.ndarray = deluca.field(jaxed=False)

def __call__(self, state, obs, *args, **kwargs):
action = jax.lax.dynamic_slice(self.u, (state.steps.astype(int),), (1,))
Expand Down
23 changes: 12 additions & 11 deletions deluca/agents/_random.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,18 +38,19 @@ class SimpleRandom(Agent):
This agent return a normally distributed action.
"""
# d_action: int = field(1, jaxed=False)
# agent_state: jnp.array = field(default_factory=lambda: jnp.array([[1.0]]), jaxed=False)
# agent_state: jnp.ndarray = field(default_factory=lambda: jnp.ndarray([[1.0]]), jaxed=False)
# key: int = field(default_factory=lambda: jax.random.key(0), jaxed=False)

def __init__(self,d_action):
self.d_action = d_action
self.agent_state = None
self.key = jax.random.key(0)
return None
def __init__(self, n: int, key: jax.random.key = jax.random.PRNGKey(0)):
self.n = n
self.key = key

def __call__(self, obs):
self.key = jax.random.split(self.key)[0]
return jax.random.normal( self.key, shape = (self.d_action,1))
def __call__(self, obs: jnp.ndarray, key: jax.random.key = None):
if key is None:
self.key, subkey = jax.random.split(self.key, 2)
return jax.random.normal(subkey, shape = (self.n, 1))
else:
return jax.random.normal(key, shape = (self.n, 1))

def update(self, state: jnp.ndarray, u: jnp.ndarray) -> None:
return None
def update(self, obs: jnp.ndarray, action: jnp.ndarray) -> None:
return None
2 changes: 1 addition & 1 deletion deluca/agents/_sfc.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def __init__(
self.h = h

self.t = 0
subkey_1, subkey2 = jax.random.split(key)
subkey1, subkey2 = jax.random.split(key)
self.M = init_scale*jax.random.normal(subkey1, shape=(self.h, self.n, self.p))
self.M_0 = init_scale*jax.random.normal(subkey2, shape=(self.n, self.p))

Expand Down
10 changes: 5 additions & 5 deletions deluca/agents/_zero.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,16 @@


class Zero(Agent):
d_action: int = field(1, jaxed=False)
agent_state: jnp.array = field(default_factory=lambda: jnp.array([[1.0]]), jaxed=False)
d_action: int = 1
# agent_state: jnp.ndarray = field(default_factory=lambda: jnp.ndarray([[1.0]]), jaxed=False)

def init(self,d_action):
self.d_action = d_action
self.agent_state = None
# self.agent_state = None
return None

def __call__(self, obs):
return jnp.zeros(self.d_action)
return jnp.zeros((self.d_action,1))

def update(self, state: jnp.ndarray, u: jnp.ndarray) -> None:
def update(self, y:jnp.array, u: jnp.ndarray) -> None:
return None
1 change: 0 additions & 1 deletion deluca/colabs/test_agents_on_lds.ipynb

This file was deleted.

16 changes: 14 additions & 2 deletions deluca/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,11 +112,11 @@ class Env(Obj):

@abstractmethod
def init(self):
"""Return an initialized state"""
"""Return an initial observation"""

@abstractmethod
def __call__(self, state, action, *args, **kwargs):
"""Return an updated state"""
"""Return an updated observation after taking input action"""


class AgentState(Obj):
Expand All @@ -134,9 +134,21 @@ def init(self):
return AgentState()


class Disturbance(Obj):

@abstractmethod
def init(self, *args):
"""Initializes disturbance class"""

@abstractmethod
def __call__(self, *args, **kwargs):
"""Returns the next disturbance"""


deluca.field = field
deluca.Obj = Obj
deluca.Env = Env
deluca.Agent = Agent
deluca.save = save
deluca.load = load
deluca.Disturbance = Disturbance
127 changes: 73 additions & 54 deletions deluca/envs/_lds.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,62 +14,76 @@

"""Linear dynamical system."""
from deluca.core import Env
from deluca.core import field
from deluca.core import Disturbance
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np


class LDS(Env):
"""LDS."""
key: int = field(default_factory=lambda: jax.random.key(0), jaxed=False)
A: jnp.array = field(default_factory=lambda: jnp.array([[1.0]]), jaxed=False)
B: jnp.array = field(default_factory=lambda: jnp.array([[1.0]]), jaxed=False)
C: jnp.array = field(default_factory=lambda: jnp.array([[1.0]]), jaxed=False)
state: jnp.array = field(default_factory=lambda: jnp.array([[1.0]]), jaxed=False)
d_hidden: int = field(1, jaxed=False)
d_in: int = field(1, jaxed=False)
d_out: int = field(1, jaxed=False)


def init(self,d_in = 1,d_hidden = 1, d_out = 1):
"""init.
initialize internal state to be random

Returns:
obs:
"""
self.d_in = d_in
self.d_hidden = d_hidden
self.d_out = d_out
A = jnp.diag( jnp.sign( jax.random.normal(self.key, shape=(self.d_hidden))) * 0.9 + jax.random.uniform(self.key, self.d_hidden ) * 0.04 )
print("the eigenvalues of our system:",jnp.diag(A))
B = jax.random.normal(self.key, shape=(self.d_hidden, self.d_in))
C = jax.random.normal(self.key, shape=(d_out, d_hidden)) # jax.numpy.identity(self.d_hidden) #
self.A = A
self.B = B
self.C = C
self.state = jax.random.normal(self.key, shape=(self.d_hidden, 1))
return self.C @ self.state
class ZeroDisturbance(Disturbance):
def init(self, d):
self.d = d

def __call__(self, t, key):
del t
del key
return jnp.zeros((self.d, 1))

def __call__(self, action):
"""__call__.

Args:
action:
class GaussianDisturbance(Disturbance):
def init(self, d, std=0.5):
self.d = d
self.std = std

Returns:
observation signal:
def __call__(self, t, key):
del t
return jax.random.normal(key, (self.d, 1)) * self.std

"""
assert action.shape[0] == self.d_in , "dimension of action is wrong"
self.state = self.A @ self.state + self.B @ action
return self.C @ self.state
class SinusDisturbance(Disturbance):
def init(self, d, amplitude=0.5):
self.d = d
self.amplitude = amplitude
key = jax.random.PRNGKey(0)
self.phases = jax.random.normal(key, (self.d, 1))

def __call__(self, t, key):
return (jax.numpy.sin( t * self.phases )) * self.amplitude


def generate_random_trajectory(self, trajectory_length = 1000):
class LDS(Env):
"""LDS."""

def init(self, d_in = 1, d_hidden = 25, d_out = 1, A = None, B = None, C = None, key = jax.random.PRNGKey(0), x0=None, disturbance=None):
key, subkey1, subkey2, subkey3 = jax.random.split(key, 4)
self.A = jnp.array(A) if A is not None else jnp.diag(jax.random.uniform(subkey1, (d_hidden,), minval=0.9, maxval=0.95))
self.B = jnp.array(B) if B is not None else jax.random.normal(subkey2, (d_hidden, d_in))
self.C = jnp.array(C) if C is not None else jax.random.normal(subkey3, (d_out, d_hidden))
self.d = self.A.shape[0]
self.n = self.B.shape[1]
self.p = self.C.shape[0]
self.x = jnp.zeros((self.d, 1)) if x0 is None else jnp.array(x0)
self.t = 0
if disturbance is None:
self.disturbance = GaussianDisturbance()
else:
self.disturbance = disturbance
self.disturbance.init(self.d)
self.key = key

def __call__(self, u, key = None):
if key is None:
self.key, subkey = jax.random.split(self.key, 2)
w_t = self.disturbance(self.t, subkey)
else:
w_t = self.disturbance(self.t, key)
self.x = self.A @ self.x + self.B @ u + w_t
y = self.C @ self.x
self.t += 1
return y


def generate_random_trajectory(self, trajectory_length = 1000, key=None):
"""generate_random_trajectory.
generates a trajectory of the environment with random actions

Expand All @@ -79,24 +93,29 @@ def generate_random_trajectory(self, trajectory_length = 1000):

"""
print("trajectory length =" , trajectory_length)
results = np.zeros(trajectory_length)

if key is None:
key = jax.random.PRNGKey(0)

results = jnp.zeros((trajectory_length,))
for i in range(trajectory_length):
# rand_action = jax.random.normal(self.key, shape = (self.d_in,1))
rand_action = np.random.normal(0,1,size=(self.d_in,1))
obs = self( rand_action )
results[i] = obs[0,0] # first coordinate of the observation
key, subkey1, subkey2 = jax.random.split(key, 3)
rand_action = jax.random.normal(subkey1, (self.n, 1))
obs = self(rand_action, subkey2)
results = results.at[i].set(obs[0, 0]) # first coordinate of the observation

return results

def show_me_the_signal(self,length = 1000):
results = self.generate_random_trajectory(length)
plt.plot(results)
def show_me_the_signal(self,length = 1000, key=None):
results = self.generate_random_trajectory(length, key)
plt.plot(np.array(results))
plt.show()

def diagnostics(self):
print("my parameters are")
print("d_in, d_hidden, d_out are:")
print(self.d_in, self.d_hidden, self.d_out)
print("d_in (n), d_hidden (d), d_out (p) are:")
print(self.n, self.d, self.p)
print("A, B, C are: ")
print(self.A, self.B, self.C)
print("state is:")
print(self.state)
print(self.x)
4 changes: 2 additions & 2 deletions deluca/lung/controllers/_deep_actor_critic.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def __call__(self, x):


class DeepACControllerState(deluca.Obj):
errs: jnp.array
errs: jnp.ndarray
key: int
time: float = float('inf')
steps: int = 0
Expand All @@ -92,7 +92,7 @@ class DeepAC(Controller):

params: list = deluca.field(jaxed=True)
model: nn.module = deluca.field(ActorCritic, jaxed=False)
featurizer: jnp.array = deluca.field(jaxed=False)
featurizer: jnp.ndarray = deluca.field(jaxed=False)
H: int = deluca.field(100, jaxed=False)
input_dim: int = deluca.field(1, jaxed=False)
history_len: int = deluca.field(10, jaxed=False)
Expand Down
Loading