25 changes: 18 additions & 7 deletions batch_run_analytical.py
@@ -64,7 +64,6 @@ def get_args():
parser.add_argument('--pi_steps', type=int, default=10000,
help='For memory iteration, how many steps of policy improvement do we do per iteration?')


parser.add_argument('--policy_optim_alg', type=str, default='policy_grad',
help='policy improvement algorithm to use. "policy_iter" - policy iteration, "policy_grad" - policy gradient, '
'"discrep_max" - discrepancy maximization, "discrep_min" - discrepancy minimization')
@@ -75,7 +74,9 @@ def get_args():
parser.add_argument('--random_policies', default=100, type=int,
help='How many random policies do we use for random kitchen sinks?')
parser.add_argument('--leave_out_optimal', action='store_true',
help="Do we include the optimal policy when we select the initial policy")
help="Do we include the optimal policy when we select the initial policy?")
parser.add_argument('--mem_aug_before_init_pi', action='store_true',
help="Do we augment our memory before selecting the highest LD initial policy?")
parser.add_argument('--n_mem_states', default=2, type=int,
help='for memory_id = 0, how many memory states do we have?')

@@ -121,6 +122,13 @@ def get_kitchen_sink_policy(policies: jnp.ndarray, pomdp: POMDP, measure: Callab
all_policy_measures, _, _ = batch_measures(policies, pomdp)
return policies[jnp.argmax(all_policy_measures)]

def get_mem_kitchen_sink_policy(policies: jnp.ndarray,
mem_params: jnp.ndarray,
pomdp: POMDP):
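"""Select the memoryless policy with the highest lambda discrepancy (LD) under a given memory.
Each candidate policy is tiled along axis 1 across the n_mem_states memory states so it can be
scored by mem_discrep_loss together with mem_params; the original, un-tiled policy with the
largest discrepancy is returned.
"""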
mem_policies = policies.repeat(mem_params.shape[-1], axis=1)
batch_measures = jax.vmap(mem_discrep_loss, in_axes=(None, 0, None))
all_policy_measures = batch_measures(mem_params, mem_policies, pomdp)
return policies[jnp.argmax(all_policy_measures)]

def make_experiment(args):

@@ -193,17 +201,20 @@ def update_pg_step(inps, i):
if args.leave_out_optimal:
pis_with_memoryless_optimal = pi_paramses[:-1]

# We initialize mem params here
mem_shape = (1, pomdp.action_space.n, pomdp.observation_space.n, args.n_mem_states, args.n_mem_states)
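# mem_shape: (1, n_actions, n_obs, n_mem_states, n_mem_states), i.e. one memory update function over (action, observation) pairs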
mem_params = random.normal(mem_rng, shape=mem_shape) * 0.5

# now we get our kitchen sink policies
kitchen_sinks_info = {}
ld_pi_params = get_kitchen_sink_policy(pis_with_memoryless_optimal, pomdp, discrep_loss)
if args.mem_aug_before_init_pi:
ld_pi_params = get_mem_kitchen_sink_policy(pis_with_memoryless_optimal, mem_params, pomdp)
else:
ld_pi_params = get_kitchen_sink_policy(pis_with_memoryless_optimal, pomdp, discrep_loss)
pis_to_learn_mem = jnp.stack([ld_pi_params])

kitchen_sinks_info['ld'] = ld_pi_params.copy()

# We initialize 3 mem params: 1 for LD
mem_shape = (pis_to_learn_mem.shape[0], pomdp.action_space.n, pomdp.observation_space.n, args.n_mem_states, args.n_mem_states)
mem_params = random.normal(mem_rng, shape=mem_shape) * 0.5

mem_tx_params = jax.vmap(optim.init, in_axes=0)(mem_params)

info['beginning']['init_mem_params'] = mem_params.copy()
26 changes: 26 additions & 0 deletions lamb/models.py
@@ -539,6 +539,32 @@ def __call__(self, hidden, x):

return hidden, pi, jnp.squeeze(v, axis=-1)


class PelletPredictorNN(nn.Module):
hidden_size: int
n_outs: int
n_hidden_layers: int = 1

@nn.compact
def __call__(self, x):
out = nn.Dense(self.hidden_size, kernel_init=orthogonal(2), bias_init=constant(0.0))(
x
)
out = nn.relu(out)

for i in range(self.n_hidden_layers):
out = nn.Dense(
self.hidden_size, kernel_init=orthogonal(0.01), bias_init=constant(0.0)
)(out)
out = nn.relu(out)

logits = nn.Dense(
self.n_outs, kernel_init=orthogonal(0.01), bias_init=constant(0.0)
)(out)
predictions = nn.sigmoid(logits)
return predictions, logits
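
A minimal usage sketch for the new predictor head, assuming standard Flax linen init/apply semantics; the batch size, feature dimension, and number of outputs below are illustrative rather than taken from this PR:

import jax
from jax import numpy as jnp

model = PelletPredictorNN(hidden_size=128, n_outs=49, n_hidden_layers=1)
x = jnp.zeros((32, 105))  # hypothetical batch of 32 flattened observations
params = model.init(jax.random.PRNGKey(0), x)
predictions, logits = model.apply(params, x)  # sigmoid probabilities and raw logits, both of shape (32, 49)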


def get_network_fn(env: environment.Environment, env_params: environment.EnvParams,
memoryless: bool = False):
if isinstance(env, Battleship) or (hasattr(env, '_unwrapped') and isinstance(env._unwrapped, Battleship)):
44 changes: 44 additions & 0 deletions lamb/utils/data.py
@@ -1,5 +1,49 @@
import functools
from typing import Callable, Optional

import jax
from jax import numpy as jnp
import numpy as np

def add_dim_to_args(
func: Callable,
axis: int = 1,
starting_arg_index: Optional[int] = 1,
ending_arg_index: Optional[int] = None,
kwargs_on_device_keys: Optional[list] = None,
):
"""Adds a dimension to the specified arguments of a function.

Args:
func (Callable): The function to wrap.
axis (int, optional): The axis to add the dimension to. Defaults to 1.
starting_arg_index (Optional[int], optional): The index of the first argument to
add the dimension to. Defaults to 1.
ending_arg_index (Optional[int], optional): The index of the last argument to
add the dimension to. Defaults to None.
kwargs_on_device_keys (Optional[list], optional): The kwargs keys that should also have
the dimension added. If None, all kwargs are expanded. Defaults to None.
"""

@functools.wraps(func)
def wrapper(*args, **kwargs):
if ending_arg_index is None:
end_index = len(args)
else:
end_index = ending_arg_index

args = list(args)
args[starting_arg_index:end_index] = [
jax.tree.map(lambda x: jnp.expand_dims(x, axis=axis), a)
for a in args[starting_arg_index:end_index]
]
for k, v in kwargs.items():
if kwargs_on_device_keys is None or k in kwargs_on_device_keys:
kwargs[k] = jax.tree.map(lambda x: jnp.expand_dims(x, axis=axis), v)
return func(*args, **kwargs)

return wrapper
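
A small illustration of the wrapper, using a toy stand-in function (add_batch below is hypothetical) and this module's own jax/jnp imports; it mirrors how the flat replay buffer lifts a single transition to a length-1 leading axis before calling the underlying add:

def add_batch(state, transition):
    # toy stand-in for a batched add: report the leaf shapes it receives
    return jax.tree.map(lambda x: x.shape, transition)

add_single = add_dim_to_args(add_batch, axis=0, starting_arg_index=1, ending_arg_index=2)
single_transition = {'obs': jnp.zeros(4), 'reward': jnp.zeros(())}
print(add_single(None, single_transition))  # {'obs': (1, 4), 'reward': (1,)}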


def one_hot(x, n):
return np.eye(n)[x]
33 changes: 29 additions & 4 deletions lamb/utils/file_system.py
@@ -92,16 +92,40 @@ def load_info(results_path: Path) -> dict:
return np.load(results_path, allow_pickle=True).item()


def load_train_state(key: jax.random.PRNGKey, fpath: Path):
def load_train_state(key: jax.random.PRNGKey, fpath: Path,
update_idx_to_take: int = None,
best_over_rng: bool = False):
# load our params
orbax_checkpointer = orbax.checkpoint.PyTreeCheckpointer()
restored = orbax_checkpointer.restore(fpath)
args = restored['args']
unpacked_ts = restored['out']['runner_state'][0]


if update_idx_to_take is None:
best_idx = 0
if best_over_rng:
# we take the max here since we just want episode returns over all seeds
# and we take the mean over axis=-1 since we do n episodes of eval.
perf_across_seeds = restored['out']['final_eval_metric']['returned_discounted_episode_returns'].max(axis=-2).mean(axis=-1)
best_idx = np.squeeze(np.argmax(perf_across_seeds, axis=-1))

params = jax.tree_map(lambda x: x[0, 0, 0, 0, 0, 0, best_idx], unpacked_ts['params'])
else:
perf_across_seeds_expanded = restored['out']['metric']['returned_discounted_episode_returns'].squeeze().mean(axis=-1).mean(axis=-1)
all_ckpt_params = jax.tree.map(lambda x: x[0, 0, 0, 0, 0, 0], restored['out']['checkpoint'])
n_ckpt_steps = jax.tree.flatten(all_ckpt_params)[0][0].shape[1]
perf_interval = perf_across_seeds_expanded.shape[1] // n_ckpt_steps
perf_across_seeds = perf_across_seeds_expanded[:, ::perf_interval]
timestep_perf = perf_across_seeds[:, update_idx_to_take]
best_idx = np.argmax(timestep_perf)
params = jax.tree_map(lambda x: x[best_idx, update_idx_to_take], all_ckpt_params)


gamma = args['gamma']
if 'config' in restored:
gamma = restored['config']['GAMMA']
env, env_params = get_gymnax_env(args['env'], key,
restored['config']['GAMMA'],
gamma=gamma,
action_concat=args['action_concat'])

network_fn, action_size = get_network_fn(env, env_params, memoryless=args['memoryless'])
@@ -110,8 +134,9 @@ def load_train_state(key: jax.random.PRNGKey, fpath: Path):
double_critic=args['double_critic'],
hidden_size=args['hidden_size'])
tx = optax.adam(args['lr'][0])

ts = TrainState.create(apply_fn=network.apply,
params=jax.tree_map(lambda x: x[0, 0, 0, 0, 0, 0], unpacked_ts['params']),
params=params,
tx=tx)

return env, env_params, args, network, ts
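
For reference, a hypothetical call that restores the best-over-seeds train state from an orbax checkpoint (the path below is illustrative):

env, env_params, args, network, ts = load_train_state(
    jax.random.PRNGKey(0), Path('results/checkpoint'), best_over_rng=True)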
2 changes: 2 additions & 0 deletions lamb/utils/replay/__init__.py
@@ -0,0 +1,2 @@
from .flat import make_flat_buffer, TransitionSample
from .trajectory import make_trajectory_buffer, TrajectoryBufferSample
182 changes: 182 additions & 0 deletions lamb/utils/replay/flat.py
@@ -0,0 +1,182 @@
"""
Taken from https://github.com/instadeepai/flashbax/blob/main/flashbax/buffers/flat_buffer.py
"""
import warnings
from typing import TYPE_CHECKING, Generic, Optional

from chex import PRNGKey
from typing_extensions import NamedTuple

if TYPE_CHECKING: # https://github.com/python/mypy/issues/6239
from dataclasses import dataclass
else:
from chex import dataclass

import jax

from lamb.utils.data import add_dim_to_args

from .trajectory import (
Experience,
TrajectoryBuffer,
TrajectoryBufferState,
make_trajectory_buffer
)

class ExperiencePair(NamedTuple, Generic[Experience]):
first: Experience
second: Experience


@dataclass(frozen=True)
class TransitionSample(Generic[Experience]):
experience: ExperiencePair[Experience]


def validate_sample_batch_size(sample_batch_size: int, max_length: int):
if sample_batch_size > max_length:
raise ValueError("sample_batch_size must be less than or equal to max_length")


def validate_min_length(min_length: int, add_batch_size: int, max_length: int):
used_min_length = min_length // add_batch_size + 1
if used_min_length > max_length:
raise ValueError("min_length used is too large for the buffer size.")


def validate_max_length_add_batch_size(max_length: int, add_batch_size: int):
if max_length // add_batch_size < 2:
raise ValueError(
f"""max_length//add_batch_size must be greater than 2. It is currently
{max_length}//{add_batch_size} = {max_length//add_batch_size}"""
)


def validate_flat_buffer_args(
max_length: int,
min_length: int,
sample_batch_size: int,
add_batch_size: int,
):
"""Validates the arguments for the flat buffer."""

validate_sample_batch_size(sample_batch_size, max_length)
validate_min_length(min_length, add_batch_size, max_length)
validate_max_length_add_batch_size(max_length, add_batch_size)


def create_flat_buffer(
max_length: int,
min_length: int,
sample_batch_size: int,
add_sequences: bool,
add_batch_size: Optional[int],
) -> TrajectoryBuffer:
"""Creates a trajectory buffer that acts as a flat buffer.

Args:
max_length (int): The maximum length of the buffer.
min_length (int): The minimum length of the buffer.
sample_batch_size (int): The batch size of the samples.
add_sequences (Optional[bool], optional): Whether data is being added in sequences
to the buffer. If False, single transitions are being added each time add
is called. Defaults to False.
add_batch_size (Optional[int], optional): If adding data in batches, what is the
batch size that is being added each time. If None, single transitions or single
sequences are being added each time add is called. Defaults to None.

Returns:
The buffer."""

if add_batch_size is None:
# add_batch_size being None implies that we are adding single transitions
add_batch_size = 1
add_batches = False
else:
add_batches = True

validate_flat_buffer_args(
max_length=max_length,
min_length=min_length,
sample_batch_size=sample_batch_size,
add_batch_size=add_batch_size,
)

with warnings.catch_warnings():
warnings.filterwarnings(
"ignore",
message="Setting max_size dynamically sets the `max_length_time_axis` to "
f"be `max_size`//`add_batch_size = {max_length // add_batch_size}`."
"This allows one to control exactly how many transitions are stored in the buffer."
"Note that this overrides the `max_length_time_axis` argument.",
)

buffer = make_trajectory_buffer(
max_length_time_axis=None, # Unused because max_size is specified
min_length_time_axis=min_length // add_batch_size + 1,
add_batch_size=add_batch_size,
sample_batch_size=sample_batch_size,
sample_sequence_length=2,
period=1,
max_size=max_length,
)

add_fn = buffer.add

if not add_batches:
add_fn = add_dim_to_args(
add_fn, axis=0, starting_arg_index=1, ending_arg_index=2
)

if not add_sequences:
axis = 1 - int(not add_batches) # 1 if add_batches else 0
add_fn = add_dim_to_args(
add_fn, axis=axis, starting_arg_index=1, ending_arg_index=2
)

def sample_fn(state: TrajectoryBufferState, rng_key: PRNGKey) -> TransitionSample:
"""Samples a batch of transitions from the buffer."""
sampled_batch = buffer.sample(state, rng_key).experience
first = jax.tree.map(lambda x: x[:, 0], sampled_batch)
second = jax.tree.map(lambda x: x[:, 1], sampled_batch)
return TransitionSample(experience=ExperiencePair(first=first, second=second))

def all_fn(state: TrajectoryBufferState) -> TransitionSample:
"""Returns all transitions."""
first = jax.tree.map(lambda x: x[0, :-1], state.experience)
second = jax.tree.map(lambda x: x[0, 1:], state.experience)
return TransitionSample(experience=ExperiencePair(first=first, second=second))

return buffer.replace(add=add_fn, sample=sample_fn, all=all_fn) # type: ignore


def make_flat_buffer(
max_length: int,
min_length: int,
sample_batch_size: int,
add_sequences: bool = False,
add_batch_size: Optional[int] = None,
) -> TrajectoryBuffer:
"""Makes a trajectory buffer act as a flat buffer.

Args:
max_length (int): The maximum length of the buffer.
min_length (int): The minimum length of the buffer.
sample_batch_size (int): The batch size of the samples.
add_sequences (Optional[bool], optional): Whether data is being added in sequences
to the buffer. If False, single transitions are being added each time add
is called. Defaults to False.
add_batch_size (Optional[int], optional): If adding data in batches, what is the
batch size that is being added each time. If None, single transitions or single
sequences are being added each time add is called. Defaults to None.

Returns:
The buffer."""

return create_flat_buffer(
max_length=max_length,
min_length=min_length,
sample_batch_size=sample_batch_size,
add_sequences=add_sequences,
add_batch_size=add_batch_size,
)
Loading