Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 53 additions & 0 deletions deluca/_src/agents/bang_bang.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
# Copyright 2023 The Deluca Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""A bang-bang or on-off agent."""

import jax.numpy as jnp
from optax._src import base


BangBangState = base.EmptyAgentState


def bang_bang(
target: float | jnp.array,
min_action: float | jnp.array,
max_action: float | jnp.array,
) -> base.Agent:
"""An on-off agent.

NOTE: `target`, `min_action`, `max_action`, and `obs` must be the same shape.

Args:
target: The target value of the agent.
min_action: The minimum or "off" action.
max_action: The maximum or "on" action.

Returns:
A Bang Bang agent.
"""

def init_fn():
return BangBangState()

def action_fn(
state: BangBangState, obs: base.EnvironmentState
) -> tuple[BangBangState, float | jnp.array]:
return state, jnp.where(obs > target, min_action, max_action)

def update_fn(state: BangBangState) -> BangBangState:
return state

return base.Agent(init_fn, action_fn, update_fn)
45 changes: 45 additions & 0 deletions deluca/_src/agents/constant.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
# Copyright 2023 The Deluca Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""A constant controller."""

import jax.numpy as jnp
from optax._src import base


def constant(
control: float | jnp.array,
) -> base.Agent:
"""A controller that gives a constant response.

Args:
control: The control to give.

Returns:
A constant agent.
"""

def init_fn():
return None

def control_fn(
state: base.AgentState, obs: base.EnvironmentState
) -> tuple[base.AgentState, float | jnp.array]:
del obs
return state, control

def update_fn(state: base.AgentState) -> base.AgentState:
return state

return base.Agent(init_fn, control_fn, update_fn)
65 changes: 65 additions & 0 deletions deluca/_src/agents/pid.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
# Copyright 2023 The Deluca Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Vanilla PID agent."""

import jax.numpy as jnp
from optax._src import base


class PIDState(base.AgentState):
P: float
I: float
D: float


def pid(
kp: float,
ki: float,
kd: float,
dt: float,
rc: float = 0.0,
) -> base.Agent:
"""A PID agent.

Args:
kp: P constant.
ki: I constant.
kd: D constant.
dt: Time increment.
rc: Delay constant.

Returns:
A constant agent.
"""

decay = dt / (dt + rc)

def init_fn():
return PIDState()

def action_fn(
state: PIDState, obs: base.EnvironmentState
) -> tuple[PIDState, float | jnp.array]:
p = obs
i = state.I + decay * (obs - state.I)
d = state.D + decay * (obs - state.P - state.D)

action = kp * p + ki * i + kd * d
return state.replace(P=p, I=i, D=d), action

def update_fn(state: PIDState) -> PIDState:
return state

return base.Agent(init_fn, action_fn, update_fn)
59 changes: 59 additions & 0 deletions deluca/_src/agents/random.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
# Copyright 2023 The Deluca Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""A random agent."""

from typing import Callable

import jax
import jax.numpy as jnp
from optax._src import base


class RandomState(base.AgentState):
key: jnp.ndarray


def random(
key: jnp.ndarray,
shape: tuple[int, ...],
func: Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray] | None = None,
) -> base.Agent:
"""An agent that outputs random actions.

Args:
key: A `jax` key.
shape: The shape of the output actions.
func: A function that takes a key and a shape and outputs random actions.

Returns:
A Bang Bang agent.
"""

if func is None:
func = jax.random.uniform

def init_fn():
return RandomState(key=key)

def action_fn(
state: RandomState, obs: base.EnvironmentState
) -> tuple[RandomState, float | jnp.array]:
del obs
return state, func(state.key, shape)

def update_fn(state: RandomState) -> RandomState:
return state.replace(key=jax.random.fold_in(key, shape))

return base.Agent(init_fn, action_fn, update_fn)
Loading