diff --git a/batch_run_analytical.py b/batch_run_analytical.py index 9b2dced..942de6d 100644 --- a/batch_run_analytical.py +++ b/batch_run_analytical.py @@ -64,7 +64,6 @@ def get_args(): parser.add_argument('--pi_steps', type=int, default=10000, help='For memory iteration, how many steps of policy improvement do we do per iteration?') - parser.add_argument('--policy_optim_alg', type=str, default='policy_grad', help='policy improvement algorithm to use. "policy_iter" - policy iteration, "policy_grad" - policy gradient, ' '"discrep_max" - discrepancy maximization, "discrep_min" - discrepancy minimization') @@ -75,7 +74,9 @@ def get_args(): parser.add_argument('--random_policies', default=100, type=int, help='How many random policies do we use for random kitchen sinks??') parser.add_argument('--leave_out_optimal', action='store_true', - help="Do we include the optimal policy when we select the initial policy") + help="Do we include the optimal policy when we select the initial policy?") + parser.add_argument('--mem_aug_before_init_pi', action='store_true', + help="Do we augment our memory before selecting the highest LD initial policy?") parser.add_argument('--n_mem_states', default=2, type=int, help='for memory_id = 0, how many memory states do we have?') @@ -121,6 +122,13 @@ def get_kitchen_sink_policy(policies: jnp.ndarray, pomdp: POMDP, measure: Callab all_policy_measures, _, _ = batch_measures(policies, pomdp) return policies[jnp.argmax(all_policy_measures)] +def get_mem_kitchen_sink_policy(policies: jnp.ndarray, + mem_params: jnp.ndarray, + pomdp: POMDP): + mem_policies = policies.repeat(mem_params.shape[-1], axis=1) + batch_measures = jax.vmap(mem_discrep_loss, in_axes=(None, 0, None)) + all_policy_measures = batch_measures(mem_params, mem_policies, pomdp) + return policies[jnp.argmax(all_policy_measures)] def make_experiment(args): @@ -193,17 +201,20 @@ def update_pg_step(inps, i): if args.leave_out_optimal: pis_with_memoryless_optimal = pi_paramses[:-1] + # We initialize mem params here + mem_shape = (1, pomdp.action_space.n, pomdp.observation_space.n, args.n_mem_states, args.n_mem_states) + mem_params = random.normal(mem_rng, shape=mem_shape) * 0.5 + # now we get our kitchen sink policies kitchen_sinks_info = {} - ld_pi_params = get_kitchen_sink_policy(pis_with_memoryless_optimal, pomdp, discrep_loss) + if args.mem_aug_before_init_pi: + ld_pi_params = get_kitchen_sink_policy(pis_with_memoryless_optimal, pomdp, discrep_loss) + else: + ld_pi_params = get_mem_kitchen_sink_policy(pis_with_memoryless_optimal, mem_params, pomdp) pis_to_learn_mem = jnp.stack([ld_pi_params]) kitchen_sinks_info['ld'] = ld_pi_params.copy() - # We initialize 3 mem params: 1 for LD - mem_shape = (pis_to_learn_mem.shape[0], pomdp.action_space.n, pomdp.observation_space.n, args.n_mem_states, args.n_mem_states) - mem_params = random.normal(mem_rng, shape=mem_shape) * 0.5 - mem_tx_params = jax.vmap(optim.init, in_axes=0)(mem_params) info['beginning']['init_mem_params'] = mem_params.copy() diff --git a/lamb/models.py b/lamb/models.py index b56b1f8..202509b 100644 --- a/lamb/models.py +++ b/lamb/models.py @@ -539,6 +539,32 @@ def __call__(self, hidden, x): return hidden, pi, jnp.squeeze(v, axis=-1) + +class PelletPredictorNN(nn.Module): + hidden_size: int + n_outs: int + n_hidden_layers: int = 1 + + @nn.compact + def __call__(self, x): + out = nn.Dense(self.hidden_size, kernel_init=orthogonal(2), bias_init=constant(0.0))( + x + ) + out = nn.relu(out) + + for i in range(self.n_hidden_layers): + out = nn.Dense( + 
self.hidden_size, kernel_init=orthogonal(0.01), bias_init=constant(0.0) + )(x) + out = nn.relu(out) + + logits = nn.Dense( + self.n_outs, kernel_init=orthogonal(0.01), bias_init=constant(0.0) + )(out) + predictions = nn.sigmoid(logits) + return predictions, logits + + def get_network_fn(env: environment.Environment, env_params: environment.EnvParams, memoryless: bool = False): if isinstance(env, Battleship) or (hasattr(env, '_unwrapped') and isinstance(env._unwrapped, Battleship)): diff --git a/lamb/utils/data.py b/lamb/utils/data.py index ec9bfb9..73b9c65 100644 --- a/lamb/utils/data.py +++ b/lamb/utils/data.py @@ -1,5 +1,49 @@ +import functools +from typing import Callable, Optional + +import jax +from jax import numpy as jnp import numpy as np +def add_dim_to_args( + func: Callable, + axis: int = 1, + starting_arg_index: Optional[int] = 1, + ending_arg_index: Optional[int] = None, + kwargs_on_device_keys: Optional[list] = None, +): + """Adds a dimension to the specified arguments of a function. + + Args: + func (Callable): The function to wrap. + axis (int, optional): The axis to add the dimension to. Defaults to 1. + starting_arg_index (Optional[int], optional): The index of the first argument to + add the dimension to. Defaults to 1. + ending_arg_index (Optional[int], optional): The index of the last argument to + add the dimension to. Defaults to None. + kwargs_on_device_keys (Optional[list], optional): The keys of the kwargs that should + be added to. Defaults to None. + """ + + @functools.wraps(func) + def wrapper(*args, **kwargs): + if ending_arg_index is None: + end_index = len(args) + else: + end_index = ending_arg_index + + args = list(args) + args[starting_arg_index:end_index] = [ + jax.tree.map(lambda x: jnp.expand_dims(x, axis=axis), a) + for a in args[starting_arg_index:end_index] + ] + for k, v in kwargs.items(): + if kwargs_on_device_keys is None or k in kwargs_on_device_keys: + kwargs[k] = jax.tree.map(lambda x: jnp.expand_dims(x, axis=1), v) + return func(*args, **kwargs) + + return wrapper + def one_hot(x, n): return np.eye(n)[x] diff --git a/lamb/utils/file_system.py b/lamb/utils/file_system.py index 051dc6b..4829923 100644 --- a/lamb/utils/file_system.py +++ b/lamb/utils/file_system.py @@ -92,16 +92,40 @@ def load_info(results_path: Path) -> dict: return np.load(results_path, allow_pickle=True).item() -def load_train_state(key: jax.random.PRNGKey, fpath: Path): +def load_train_state(key: jax.random.PRNGKey, fpath: Path, + update_idx_to_take: int = None, + best_over_rng: bool = False): # load our params orbax_checkpointer = orbax.checkpoint.PyTreeCheckpointer() restored = orbax_checkpointer.restore(fpath) args = restored['args'] unpacked_ts = restored['out']['runner_state'][0] - + if update_idx_to_take is None: + best_idx = 0 + if best_over_rng: + # we take the max here since we just want episode returns over all seeds + # and we take the mean over axis=-1 since we do n episodes of eval. 
+ perf_across_seeds = restored['out']['final_eval_metric']['returned_discounted_episode_returns'].max(axis=-2).mean(axis=-1) + best_idx = np.squeeze(np.argmax(perf_across_seeds, axis=-1)) + + params = jax.tree_map(lambda x: x[0, 0, 0, 0, 0, 0, best_idx], unpacked_ts['params']) + else: + perf_across_seeds_expanded = restored['out']['metric']['returned_discounted_episode_returns'].squeeze().mean(axis=-1).mean(axis=-1) + all_ckpt_params = jax.tree.map(lambda x: x[0, 0, 0, 0, 0, 0], restored['out']['checkpoint']) + n_ckpt_steps = jax.tree.flatten(all_ckpt_params)[0][0].shape[1] + perf_interval = perf_across_seeds_expanded.shape[1] // n_ckpt_steps + perf_across_seeds = perf_across_seeds_expanded[:, ::perf_interval] + timestep_perf = perf_across_seeds[:, update_idx_to_take] + best_idx = np.argmax(timestep_perf) + params = jax.tree_map(lambda x: x[best_idx, update_idx_to_take], all_ckpt_params) + + + gamma = args['gamma'] + if 'config' in restored: + gamma = restored['config']['GAMMA'] env, env_params = get_gymnax_env(args['env'], key, - restored['config']['GAMMA'], + gamma=gamma, action_concat=args['action_concat']) network_fn, action_size = get_network_fn(env, env_params, memoryless=args['memoryless']) @@ -110,8 +134,9 @@ def load_train_state(key: jax.random.PRNGKey, fpath: Path): double_critic=args['double_critic'], hidden_size=args['hidden_size']) tx = optax.adam(args['lr'][0]) + ts = TrainState.create(apply_fn=network.apply, - params=jax.tree_map(lambda x: x[0, 0, 0, 0, 0, 0], unpacked_ts['params']), + params=params, tx=tx) return env, env_params, args, network, ts diff --git a/lamb/utils/replay/__init__.py b/lamb/utils/replay/__init__.py new file mode 100644 index 0000000..8e3da3f --- /dev/null +++ b/lamb/utils/replay/__init__.py @@ -0,0 +1,2 @@ +from .flat import make_flat_buffer, TransitionSample +from .trajectory import make_trajectory_buffer, TrajectoryBufferSample \ No newline at end of file diff --git a/lamb/utils/replay/flat.py b/lamb/utils/replay/flat.py new file mode 100644 index 0000000..ed78933 --- /dev/null +++ b/lamb/utils/replay/flat.py @@ -0,0 +1,182 @@ +""" +Taken from https://github.com/instadeepai/flashbax/blob/main/flashbax/buffers/flat_buffer.py +""" +import warnings +from typing import TYPE_CHECKING, Generic, Optional + +from chex import PRNGKey +from typing_extensions import NamedTuple + +if TYPE_CHECKING: # https://github.com/python/mypy/issues/6239 + from dataclasses import dataclass +else: + from chex import dataclass + +import jax + +from lamb.utils.data import add_dim_to_args + +from .trajectory import ( + Experience, + TrajectoryBuffer, + TrajectoryBufferState, + make_trajectory_buffer +) + +class ExperiencePair(NamedTuple, Generic[Experience]): + first: Experience + second: Experience + + +@dataclass(frozen=True) +class TransitionSample(Generic[Experience]): + experience: ExperiencePair[Experience] + + +def validate_sample_batch_size(sample_batch_size: int, max_length: int): + if sample_batch_size > max_length: + raise ValueError("sample_batch_size must be less than or equal to max_length") + + +def validate_min_length(min_length: int, add_batch_size: int, max_length: int): + used_min_length = min_length // add_batch_size + 1 + if used_min_length > max_length: + raise ValueError("min_length used is too large for the buffer size.") + + +def validate_max_length_add_batch_size(max_length: int, add_batch_size: int): + if max_length // add_batch_size < 2: + raise ValueError( + f"""max_length//add_batch_size must be greater than 2. 
It is currently + {max_length}//{add_batch_size} = {max_length//add_batch_size}""" + ) + + +def validate_flat_buffer_args( + max_length: int, + min_length: int, + sample_batch_size: int, + add_batch_size: int, +): + """Validates the arguments for the flat buffer.""" + + validate_sample_batch_size(sample_batch_size, max_length) + validate_min_length(min_length, add_batch_size, max_length) + validate_max_length_add_batch_size(max_length, add_batch_size) + + +def create_flat_buffer( + max_length: int, + min_length: int, + sample_batch_size: int, + add_sequences: bool, + add_batch_size: Optional[int], +) -> TrajectoryBuffer: + """Creates a trajectory buffer that acts as a flat buffer. + + Args: + max_length (int): The maximum length of the buffer. + min_length (int): The minimum length of the buffer. + sample_batch_size (int): The batch size of the samples. + add_sequences (Optional[bool], optional): Whether data is being added in sequences + to the buffer. If False, single transitions are being added each time add + is called. Defaults to False. + add_batch_size (Optional[int], optional): If adding data in batches, what is the + batch size that is being added each time. If None, single transitions or single + sequences are being added each time add is called. Defaults to None. + + Returns: + The buffer.""" + + if add_batch_size is None: + # add_batch_size being None implies that we are adding single transitions + add_batch_size = 1 + add_batches = False + else: + add_batches = True + + validate_flat_buffer_args( + max_length=max_length, + min_length=min_length, + sample_batch_size=sample_batch_size, + add_batch_size=add_batch_size, + ) + + with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", + message="Setting max_size dynamically sets the `max_length_time_axis` to " + f"be `max_size`//`add_batch_size = {max_length // add_batch_size}`." + "This allows one to control exactly how many transitions are stored in the buffer." 
+ "Note that this overrides the `max_length_time_axis` argument.", + ) + + buffer = make_trajectory_buffer( + max_length_time_axis=None, # Unused because max_size is specified + min_length_time_axis=min_length // add_batch_size + 1, + add_batch_size=add_batch_size, + sample_batch_size=sample_batch_size, + sample_sequence_length=2, + period=1, + max_size=max_length, + ) + + add_fn = buffer.add + + if not add_batches: + add_fn = add_dim_to_args( + add_fn, axis=0, starting_arg_index=1, ending_arg_index=2 + ) + + if not add_sequences: + axis = 1 - int(not add_batches) # 1 if add_batches else 0 + add_fn = add_dim_to_args( + add_fn, axis=axis, starting_arg_index=1, ending_arg_index=2 + ) + + def sample_fn(state: TrajectoryBufferState, rng_key: PRNGKey) -> TransitionSample: + """Samples a batch of transitions from the buffer.""" + sampled_batch = buffer.sample(state, rng_key).experience + first = jax.tree.map(lambda x: x[:, 0], sampled_batch) + second = jax.tree.map(lambda x: x[:, 1], sampled_batch) + return TransitionSample(experience=ExperiencePair(first=first, second=second)) + + def all_fn(state: TrajectoryBufferState) -> TransitionSample: + """Returns all transitions.""" + first = jax.tree.map(lambda x: x[0, :-1], state.experience) + second = jax.tree.map(lambda x: x[0:, 1:], state.experience) + return TransitionSample(experience=ExperiencePair(first=first, second=second)) + + return buffer.replace(add=add_fn, sample=sample_fn, all=all_fn) # type: ignore + + +def make_flat_buffer( + max_length: int, + min_length: int, + sample_batch_size: int, + add_sequences: bool = False, + add_batch_size: Optional[int] = None, +) -> TrajectoryBuffer: + """Makes a trajectory buffer act as a flat buffer. + + Args: + max_length (int): The maximum length of the buffer. + min_length (int): The minimum length of the buffer. + sample_batch_size (int): The batch size of the samples. + add_sequences (Optional[bool], optional): Whether data is being added in sequences + to the buffer. If False, single transitions are being added each time add + is called. Defaults to False. + add_batch_size (Optional[int], optional): If adding data in batches, what is the + batch size that is being added each time. If None, single transitions or single + sequences are being added each time add is called. Defaults to None. + + Returns: + The buffer.""" + + return create_flat_buffer( + max_length=max_length, + min_length=min_length, + sample_batch_size=sample_batch_size, + add_sequences=add_sequences, + add_batch_size=add_batch_size, + ) diff --git a/lamb/utils/replay/trajectory.py b/lamb/utils/replay/trajectory.py new file mode 100644 index 0000000..8a9e170 --- /dev/null +++ b/lamb/utils/replay/trajectory.py @@ -0,0 +1,618 @@ +""" +Taken from https://github.com/instadeepai/flashbax/blob/main/flashbax/buffers/trajectory_buffer.py +""" + + +""""Pure functions defining the trajectory buffer. The trajectory buffer takes batches of n-step +experience data, where n is the number of time steps within a trajectory. The trajectory buffer +concatenates consecutive batches of experience data along the time axis, retaining their ordering. +This allows for random sampling of the trajectories within the buffer. 
+""" + +import functools +import warnings +from typing import TYPE_CHECKING, Callable, Generic, Optional, TypeVar + +if TYPE_CHECKING: # https://github.com/python/mypy/issues/6239 + from dataclasses import dataclass +else: + from chex import dataclass + +import chex +import jax +import jax.numpy as jnp +from jax import Array + +Experience = TypeVar("Experience", bound=chex.ArrayTree) + + +def get_tree_shape_prefix(tree: chex.ArrayTree, n_axes: int = 1) -> chex.Shape: + """Get the shape of the leading axes (up to n_axes) of a pytree. This assumes all + leaves have a common leading axes size (e.g. a common batch size).""" + flat_tree, tree_def = jax.tree_util.tree_flatten(tree) + leaf = flat_tree[0] + leading_axis_shape = leaf.shape[0:n_axes] + chex.assert_tree_shape_prefix(tree, leading_axis_shape) + return leading_axis_shape + + +@dataclass(frozen=True) +class TrajectoryBufferState(Generic[Experience]): + """State of the trajectory replay buffer. + + Attributes: + experience: Arbitrary pytree containing the experience data, for example a single + timestep (s,a,r). These are stacked along the first axis. + current_index: Index where the next batch of experience data will be added to. + is_full: Whether the buffer state is completely full with experience (otherwise it will + have some empty padded values). + """ + + experience: Experience + current_index: Array + is_full: Array + + +@dataclass(frozen=True) +class TrajectoryBufferSample(Generic[Experience]): + """Container for samples from the buffer + + Attributes: + experience: Arbitrary pytree containing a batch of experience data. + """ + + experience: Experience + + +def init( + experience: Experience, + add_batch_size: int, + max_length_time_axis: int, +) -> TrajectoryBufferState[Experience]: + """ + Initialise the buffer state. + + Args: + experience: A single timestep (e.g. (s,a,r)) used for inferring + the structure of the experience data that will be saved in the buffer state. + add_batch_size: Batch size of experience added to the buffer's state using the `add` + function. I.e. the leading batch size of added experience should have size + `add_batch_size`. + max_length_time_axis: Maximum length of the buffer along the time axis (second axis of the + experience data). + + Returns: + state: Initial state of the replay buffer. All values are empty as no experience has + been added yet. + """ + # Set experience value to be empty. + experience = jax.tree.map(jnp.empty_like, experience) + + # Broadcast to [add_batch_size, max_length_time_axis] + experience = jax.tree.map( + lambda x: jnp.broadcast_to( + x[None, None, ...], (add_batch_size, max_length_time_axis, *x.shape) + ), + experience, + ) + + state = TrajectoryBufferState( + experience=experience, + is_full=jnp.array(False, dtype=bool), + current_index=jnp.array(0), + ) + return state + + +def add( + state: TrajectoryBufferState[Experience], + batch: Experience, +) -> TrajectoryBufferState[Experience]: + """ + Add a batch of experience to the buffer state. Assumes that this carries on from the episode + where the previous added batch of experience ended. For example, if we consider a single + trajectory within the batch; if the last timestep of the previous added trajectory's was at + time `t` then the first timestep of the current trajectory will be at time `t + 1`. + + Args: + state: The buffer state. + batch: A batch of experience. The leading axis of the pytree is the batch dimension. 
+ This must match `add_batch_size` and the structure of the experience used + during initialisation of the buffer state. This batch is added along the time axis of + the buffer state. + + + Returns: + A new buffer state with the batch of experience added. + """ + # Check that the batch has the correct shape. + chex.assert_tree_shape_prefix(batch, get_tree_shape_prefix(state.experience)) + # Check that the batch has the correct dtypes. + chex.assert_trees_all_equal_dtypes(batch, state.experience) + + # Get the length of the time axis of the buffer state. + max_length_time_axis = get_tree_shape_prefix(state.experience, n_axes=2)[1] + # Check that the sequence length is less than or equal the maximum length of the time axis. + chex.assert_axis_dimension_lteq( + jax.tree_util.tree_leaves(batch)[0], 1, max_length_time_axis + ) + + # Get the length of the sequence of the batch. + seq_len = get_tree_shape_prefix(batch, n_axes=2)[1] + + # Calculate index location in the state where we will assign the batch of experience. + indices = (jnp.arange(seq_len) + state.current_index) % max_length_time_axis + + # Update the buffer state. + experience = jax.tree.map( + lambda experience_field, batch_field: experience_field.at[:, indices].set( + batch_field + ), + state.experience, + batch, + ) + + new_index = state.current_index + seq_len + is_full = state.is_full | (new_index >= max_length_time_axis) + new_index = new_index % max_length_time_axis + + state = state.replace( # type: ignore + experience=experience, + current_index=new_index, + is_full=is_full, + ) + + return state + + +def get_invalid_indices( + state: TrajectoryBufferState[Experience], + sample_sequence_length: int, + period: int, + add_batch_size: int, + max_length_time_axis: int, +) -> Array: + """ + Get the indices of the items that will be invalid when sampling from the buffer state. This + is used to mask out the invalid items when sampling. The indices are in the format of a + flattened array and refer to items, not the actual data. To convert item indices into data + indices, we would perform the following: + + indices = item_indices * period + row_indices = indices // max_length_time_axis + time_indices = indices % max_length_time_axis + + Item indices essentially refer to a flattened array picture of the + items (i.e. subsequences that can be sampled) in the buffer state. + + + Args: + state: The buffer state. + sample_sequence_length: The length of the sequence that will be sampled from the buffer + state. + period: The period refers to the interval between sampled sequences. It serves to regulate + how much overlap there is between the trajectories that are sampled. To understand the + degree of overlap, you can calculate it as the difference between the + sample_sequence_length and the period. For instance, if you set period=1, it means that + trajectories will be sampled uniformly with the potential for any degree of overlap. On + the other hand, if period is equal to sample_sequence_length - 1, then trajectories can + be sampled in a way where only the first and last timesteps overlap with each other. + This helps you control the extent of overlap between consecutive sequences in your + sampling process. + add_batch_size: The number of trajectories that will be added to the buffer state. + max_length_time_axis: The maximum length of the time axis of the buffer state. + + Returns: + The indices of the items (with shape : [add_batch_size, num_items]) that will be invalid + when sampling from the buffer state. 
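+    Example:
+        With max_length_time_axis=10 and period=2 (illustrative numbers), item index 7
+        refers to data index 7 * 2 = 14, i.e. batch row 1, time index 4.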
+ """ + # We get the max subsequence data index as done in the add function. + max_divisible_length = max_length_time_axis - (max_length_time_axis % period) + max_subsequence_data_index = max_divisible_length - 1 + # We get the data index that is at least sample_sequence_length away from the + # current index. + previous_valid_data_index = ( + state.current_index - sample_sequence_length + ) % max_length_time_axis + # We ensure that this index is not above the maximum mappable data index of the buffer. + previous_valid_data_index = jnp.minimum( + previous_valid_data_index, max_subsequence_data_index + ) + # We then convert the data index into the item index and add one to get the index + # of the item that is broken apart. + invalid_item_starting_index = (previous_valid_data_index // period) + 1 + # We then take the modulo of the invalid item index to ensure that it is within the + # bounds of the priority array. max_length_time_axis // period is the maximum number + # of items/subsequences that can be sampled from the buffer state. + invalid_item_starting_index = invalid_item_starting_index % ( + max_length_time_axis // period + ) + + # Calculate the maximum number of items/subsequences that can start within a + # sample length of data. We add one to account for situations where the max + # number of items has been broken. Often, this will unfortunately mask an item + # that is valid however this should not be a severe issue as it would be only + # one additional item. + max_num_invalid_items = (sample_sequence_length // period) + 1 + # Get the actual indices of the items we cannot sample from. + invalid_item_indices = ( + jnp.arange(max_num_invalid_items) + invalid_item_starting_index + ) % (max_length_time_axis // period) + # Since items that are broken are broken in the same place in each row, we + # broadcast and add the total number of items to each index to reference + # the invalid items in each add_batch row. + invalid_item_indices = invalid_item_indices + jnp.arange(add_batch_size)[ + :, None + ] * (max_length_time_axis // period) + + return invalid_item_indices + + +def calculate_uniform_item_indices( + state: TrajectoryBufferState[Experience], + rng_key: chex.PRNGKey, + batch_size: int, + sample_sequence_length: int, + period: int, + add_batch_size: int, + max_length_time_axis: int, +) -> Array: + """Randomly sample a batch of item indices from the buffer state. This is done uniformly. + + Args: + state: The buffer's state. + rng_key: Random key. + batch_size: Batch size of sampled experience. + sample_sequence_length: Length of trajectory to sample. + period: The period refers to the interval between sampled sequences. It serves to regulate + how much overlap there is between the trajectories that are sampled. To understand the + degree of overlap, you can calculate it as the difference between the + sample_sequence_length and the period. For instance, if you set period=1, it means that + trajectories will be sampled uniformly with the potential for any degree of overlap. On + the other hand, if period is equal to sample_sequence_length - 1, then trajectories can + be sampled in a way where only the first and last timesteps overlap with each other. + This helps you control the extent of overlap between consecutive sequences in your + sampling process. + add_batch_size: The number of trajectories that will be added to the buffer state. + max_length_time_axis: The maximum length of the time axis of the buffer state. 
+ + Returns: + The indices of the items that will be sampled from the buffer state. + + """ + # Get the max subsequence data index to ensure we dont sample items + # that should not ever be sampled i.e. a subsequence beyond the period + # boundary. + max_divisible_length = max_length_time_axis - (max_length_time_axis % period) + max_subsequence_data_index = max_divisible_length - 1 + # Get the maximum valid time index of the data buffer based on + # whether it is full or not. + max_data_time_index = jnp.where( + state.is_full, + max_subsequence_data_index, + state.current_index - sample_sequence_length, + ) + # Convert the max time index to the maximum non-valid item index. This is the item + # index that we can sample up to (excluding). We add 1 since the max time index is the last + # valid time index that we can sample from and we want the exclusive upper bound + # or in the case of a full buffer, the size of one row of the item array. + max_item_time_index = (max_data_time_index // period) + 1 + + # Get the indices of the items that will be invalid when sampling. + invalid_item_indices = get_invalid_indices( + state=state, + sample_sequence_length=sample_sequence_length, + period=period, + add_batch_size=add_batch_size, + max_length_time_axis=max_length_time_axis, + ) + # Since all the invalid indices are repeated albeit with a batch offset, + # we can just take the first row of the invalid indices for calculation. + invalid_item_indices = invalid_item_indices[0] + + # We then get the upper bound of the item indices that we can sample from. + # When being initially populated with data, the max time index will already account + # for the items that cannot be sampled meaning that invalid indices are not needed. + # Additionally, there is separate logic that needs to be performed when the buffer is not full. + # When the buffer is full, the max time index will not account for the items that cannot be + # sampled meaning that we need to subtract the number of invalid items from the + # max item index. + num_invalid_items = jnp.where(state.is_full, invalid_item_indices.shape[0], 0) + upper_bound = max_item_time_index - num_invalid_items + + # Since the invalid item indices are always consecutive (in a circular manner), + # we can get the offset by taking the last item index and adding one. + time_offset = invalid_item_indices[-1] + 1 + + # We then sample a batch of item indices over the time axis. + sampled_item_time_indices = jax.random.randint( + rng_key, (batch_size,), 0, upper_bound + ) + # We then add the offset and modulo the indices to ensure that they are within + # the bounds of the item array (which doesnt actually exist). We modulo by the + # max item index to ensure that we loop back to the start of the item array. + sampled_item_time_indices = ( + sampled_item_time_indices + time_offset + ) % max_item_time_index + + # We then get the batch indices by sampling a batch of indices over the batch axis. + sampled_item_batch_indices = jax.random.randint( + rng_key, (batch_size,), 0, add_batch_size + ) + + # We then calculate the item indices by multiplying the batch indices by the + # number of items in each batch and adding the time indices. This gives us + # a flattened array picture of the items we will sample from. 
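+    # For example (illustrative numbers): with max_length_time_axis=20 and period=2 there
+    # are 10 items per batch row, so batch row 3 with sampled time-item index 4 maps to
+    # flat item index 3 * 10 + 4 = 34.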
+ item_indices = ( + sampled_item_batch_indices * (max_length_time_axis // period) + ) + sampled_item_time_indices + + return item_indices + + +def sample( + state: TrajectoryBufferState[Experience], + rng_key: chex.PRNGKey, + batch_size: int, + sequence_length: int, + period: int, +) -> TrajectoryBufferSample[Experience]: + """ + Sample a batch of trajectories from the buffer. + + Args: + state: The buffer's state. + rng_key: Random key. + batch_size: Batch size of sampled experience. + sequence_length: Length of trajectory to sample. + period: The period refers to the interval between sampled sequences. It serves to regulate + how much overlap there is between the trajectories that are sampled. To understand the + degree of overlap, you can calculate it as the difference between the + sample_sequence_length and the period. For instance, if you set period=1, it means that + trajectories will be sampled uniformly with the potential for any degree of overlap. On + the other hand, if period is equal to sample_sequence_length - 1, then trajectories can + be sampled in a way where only the first and last timesteps overlap with each other. + This helps you control the extent of overlap between consecutive sequences in your + sampling process. + + Returns: + A batch of experience. + """ + add_batch_size, max_length_time_axis = get_tree_shape_prefix( + state.experience, n_axes=2 + ) + # Calculate the indices of the items that will be sampled. + item_indices = calculate_uniform_item_indices( + state, + rng_key, + batch_size, + sequence_length, + period, + add_batch_size, + max_length_time_axis, + ) + + # Convert the item indices to the indices of the data buffer + flat_data_indices = item_indices * period + # Get the batch index and time index of the sampled items. + batch_data_indices = flat_data_indices // max_length_time_axis + time_data_indices = flat_data_indices % max_length_time_axis + + # The buffer is circular, so we can loop back to the start (`% max_length_time_axis`) + # if the time index is greater than the length. We then add the sequence length to get + # the end index of the sequence. + time_data_indices = ( + jnp.arange(sequence_length) + time_data_indices[:, jnp.newaxis] + ) % max_length_time_axis + + # Slice the experience in the buffer to get a batch of trajectories of length sequence_length + batch_trajectory = jax.tree.map( + lambda x: x[batch_data_indices[:, jnp.newaxis], time_data_indices], + state.experience, + ) + + return TrajectoryBufferSample(experience=batch_trajectory) + + +def can_sample( + state: TrajectoryBufferState[Experience], min_length_time_axis: int +) -> Array: + """Indicates whether the buffer has been filled above the minimum length, such that it + may be sampled from.""" + return state is not None and (state.is_full | (state.current_index >= min_length_time_axis)) + + +def all_fn(state: TrajectoryBufferState[Experience]) -> TrajectoryBufferSample[Experience]: + sampled = state.experience + if not state.is_full: + sampled = jax.tree.map(lambda x: x[:, :state.current_index], state.experience) + return TrajectoryBufferSample(experience=sampled) + + +BufferState = TypeVar("BufferState", bound=TrajectoryBufferState) +BufferSample = TypeVar("BufferSample", bound=TrajectoryBufferSample) + + +@dataclass(frozen=True) +class TrajectoryBuffer(Generic[Experience, BufferState, BufferSample]): + """Pure functions defining the trajectory buffer. This buffer assumes batches added to the + buffer are a pytree with a shape prefix of (batch_size, trajectory_length). 
Consecutive batches + are then concatenated along the second axis (i.e. the time axis). During sampling this allows + for trajectories to be sampled - by slicing consecutive sequences along the time axis. + + Attributes: + init: A pure function which may be used to initialise the buffer state using a single + timestep (e.g. (s,a,r)). + add: A pure function for adding a new batch of experience to the buffer state. + sample: A pure function for sampling a batch of data from the replay buffer, with a leading + axis of size (`sample_batch_size`, `sample_sequence_length`). Note `sample_batch_size` + and `sample_sequence_length` may be different to the batch size and sequence length of + data added to the state using the `add` function. + can_sample: Whether the buffer can be sampled from, which is determined by if the + number of trajectories added to the buffer state is greater than or equal to the + `min_length`. + + See `make_trajectory_buffer` for how this container is instantiated. + """ + + init: Callable[[Experience], BufferState] + add: Callable[ + [BufferState, Experience], + BufferState, + ] + sample: Callable[ + [BufferState, chex.PRNGKey], + BufferSample, + ] + can_sample: Callable[[BufferState], Array] + all: Callable[[BufferState], BufferSample] + + +def validate_size( + max_length_time_axis: Optional[int], max_size: Optional[int], add_batch_size: int +) -> None: + if max_size is not None and max_length_time_axis is not None: + raise ValueError( + "Cannot specify both `max_size` and `max_length_time_axis` arguments." + ) + if max_size is not None: + warnings.warn( + "Setting max_size dynamically sets the `max_length_time_axis` to " + f"be `max_size`//`add_batch_size = {max_size // add_batch_size}`." + "This allows one to control exactly how many timesteps are stored in the buffer." + "Note that this overrides the `max_length_time_axis` argument.", + stacklevel=1, + ) + + +def validate_trajectory_buffer_args( + max_length_time_axis: Optional[int], + min_length_time_axis: int, + add_batch_size: int, + sample_sequence_length: int, + period: int, + max_size: Optional[int], +) -> None: + """Validate the arguments of the trajectory buffer.""" + + validate_size(max_length_time_axis, max_size, add_batch_size) + + if max_size is not None: + max_length_time_axis = max_size // add_batch_size + + if sample_sequence_length > min_length_time_axis: + warnings.warn( + "`sample_sequence_length` greater than `min_length_time_axis`, therefore " + "overriding `min_length_time_axis`" + "to be set to `sample_sequence_length`, as we need at least `sample_sequence_length` " + "timesteps added to the buffer before we can sample.", + stacklevel=1, + ) + min_length_time_axis = sample_sequence_length + + if period > sample_sequence_length: + warnings.warn( + "Setting period greater than sample_sequence_length will result in no overlap between" + f"trajectories, however, {period-sample_sequence_length} transitions will " + "never be sampled. Setting period to be equal to sample_sequence_length will " + "also result in no overlap between trajectories, however, all transitions will " + "be sampled. 
Setting period to be `sample_sequence_length - 1` is generally " + "desired to ensure that only starting and ending transitions are shared " + "between trajectories allowing for utilising last transitions for bootstrapping.", + stacklevel=1, + ) + + if max_length_time_axis is not None: + if sample_sequence_length > max_length_time_axis: + raise ValueError( + "`sample_sequence_length` must be less than or equal to `max_length_time_axis`." + ) + + if min_length_time_axis > max_length_time_axis: + raise ValueError( + "`min_length_time_axis` must be less than or equal to `max_length_time_axis`." + ) + + +def make_trajectory_buffer( + add_batch_size: int, + sample_batch_size: int, + sample_sequence_length: int, + period: int, + min_length_time_axis: int, + max_size: Optional[int] = None, + max_length_time_axis: Optional[int] = None, +) -> TrajectoryBuffer: + """Makes a trajectory buffer. + + Args: + add_batch_size: Batch size of experience added to the buffer. Used to initialise the leading + axis of the buffer state's experience. + sample_batch_size: Batch size of experience returned from the `sample` method of the + buffer. + sample_sequence_length: Trajectory length of experience of sampled batches. Note that this + may differ from the trajectory length of experience added to the buffer. + period: The period refers to the interval between sampled sequences. It serves to regulate + how much overlap there is between the trajectories that are sampled. To understand the + degree of overlap, you can calculate it as the difference between the + sample_sequence_length and the period. For instance, if you set period=1, it means that + trajectories will be sampled uniformly with the potential for any degree of overlap. On + the other hand, if period is equal to sample_sequence_length - 1, then trajectories can + be sampled in a way where only the first and last timesteps overlap with each other. + This helps you control the extent of overlap between consecutive sequences in your + sampling process. + min_length_time_axis: Minimum length of the buffer (along the time axis) before sampling is + allowed. + max_size: Optional argument to specify the size of the buffer based on timesteps. + This sets the maximum number of timesteps that can be stored in the buffer and sets + the `max_length_time_axis` to be `max_size`//`add_batch_size`. This allows one to + control exactly how many timesteps are stored in the buffer. Note that this + overrides the `max_length_time_axis` argument. + max_length_time_axis: Optional Argument to specify the maximum length of the buffer in terms + of time steps within the 'time axis'. The second axis (the time axis) of the buffer + state's experience field will be of size `max_length_time_axis`. + + + Returns: + A trajectory buffer. 
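+    Example:
+        With add_batch_size=4 and max_size=1000 (illustrative numbers), the buffer stores
+        max_length_time_axis = 1000 // 4 = 250 timesteps per batch row.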
+ """ + validate_trajectory_buffer_args( + max_length_time_axis=max_length_time_axis, + min_length_time_axis=min_length_time_axis, + add_batch_size=add_batch_size, + sample_sequence_length=sample_sequence_length, + period=period, + max_size=max_size, + ) + + if sample_sequence_length > min_length_time_axis: + min_length_time_axis = sample_sequence_length + + if max_size is not None: + max_length_time_axis = max_size // add_batch_size + + init_fn = functools.partial( + init, + add_batch_size=add_batch_size, + max_length_time_axis=max_length_time_axis, + ) + add_fn = functools.partial( + add, + ) + sample_fn = functools.partial( + sample, + batch_size=sample_batch_size, + sequence_length=sample_sequence_length, + period=period, + ) + can_sample_fn = functools.partial( + can_sample, min_length_time_axis=min_length_time_axis + ) + + return TrajectoryBuffer( + init=init_fn, + add=add_fn, + sample=sample_fn, + can_sample=can_sample_fn, + all=all_fn, + ) diff --git a/results/pocman_pellet_probe_trajectory.zip b/results/pocman_pellet_probe_trajectory.zip new file mode 100644 index 0000000..9513a4b Binary files /dev/null and b/results/pocman_pellet_probe_trajectory.zip differ diff --git a/scripts/additional_experiments.md b/scripts/additional_experiments.md new file mode 100644 index 0000000..b79928e --- /dev/null +++ b/scripts/additional_experiments.md @@ -0,0 +1,44 @@ +# Additional experiments + +This `scripts` directory includes scripts for plotting all experimental +results in our work, as well as scripts for a few additional experiments +in the paper. + +## Parity Check experiments +The parity check closed-form optimization experiments were done with the +`batch_run_analytical.py` script, except with the option `--mem_aug_before_init_pi`. +This option augments our POMDP with a random memory function before +choosing the initial policy that maximizes the λ-discrepancy. See the +`parity_check*_30seeds.py` hyperparameter files to run these experiments. + + +## P.O. PacMan Memory Probe +To train our memory probe, we need to first collect checkpoints from +a P.O. PacMan run. We can do so with the `pocman_*ppo_best_ckpt.py` scripts +in `scripts/hyperparams`. This script will train P.O. PacMan agents with the +best swept hyperparams. + +After training, we need to run the `scripts/collect_rnn_trajectories.py` script +to collect 1M samples from each behavior policy (LD and vanilla PPO). This script +will collect RNN hidden states from two RNNs (`--rnn_path_0` and `--rnn_path_1`), +while following the `--behavior_path` RNN as the behavior policy. We collect +1M time steps with each variant as the behavior policy, for a combined dataset of +2M samples. We use the `scripts/combine_probe_datasets.py` script to combine these +datasets, resulting in a `results/combined_probe_datasets` data buffer. + +Now we train our probe with the `scripts/train_probe.py` script. Pass in the +PATH to the combined dataset above as the argument to `--dataset_path`. Use the +`--features_idx` argument to select which RNN hidden states to use for training (0 or 1). +The index and ordering of these hidden states will depend on which RNN paths were +used in `--rnn_path_0` and `--rnn_path_1`. + +Once our probe has been trained, we can collect trajectories with each trained probe +with `scripts/collect_probe_trajectories.py`. We can visualize these collected +probe trajectories with `scripts/visualization/viz_pocman_probe.py`. +We provide the collected probe trajectories in `results/pocman_pellet_probe_trajectory.zip`. 
+To generate this visualization, simply uncompress this file and pass each file in as +the argument to `scripts/visualization/viz_pocman_probe.py`. + + + + diff --git a/scripts/batch_run_ppo_epoch.py b/scripts/batch_run_ppo_epoch.py new file mode 100644 index 0000000..164c078 --- /dev/null +++ b/scripts/batch_run_ppo_epoch.py @@ -0,0 +1,486 @@ +from collections import deque +from dataclasses import replace +from functools import partial +import inspect +from typing import Literal + +from flax.training.train_state import TrainState +from flax.training import orbax_utils +import jax +import jax.numpy as jnp +import numpy as np +import optax +import orbax.checkpoint +from tap import Tap + +from lamb.agents.ppo import Transition, env_step +from lamb.envs import get_gymnax_env +from lamb.envs.jax_wrappers import LogEnvState +from lamb.models import get_network_fn, ScannedRNN +from lamb.utils.file_system import get_results_path + + +class BatchPPOHyperparams(Tap): + env: str = 'tmaze_5' + num_envs: int = 4 + default_max_steps_in_episode: int = 1000 + gamma: float = 0.99 # will be replaced if env has gamma property. + + num_steps: int = 128 + num_epochs: int = 50 + update_epochs: int = 4 + num_minibatches: int = 4 + + memoryless: bool = False + double_critic: bool = False + action_concat: bool = False + + lr: list[float] = [2.5e-4] + lambda0: list[float] = [0.95] # GAE lambda_0 + lambda1: list[float] = [0.5] # GAE lambda_1 + alpha: list[float] = [1.] # adv = alpha * adv_lambda_0 + (1 - alpha) * adv_lambda_1 + ld_weight: list[float] = [0.0] # how much to we weight the LD loss vs. value loss? only applies when optimize LD is True. + vf_coeff: list[float] = [0.5] + + hidden_size: int = 128 + total_steps: int = int(1.5e6) + entropy_coeff: float = 0.01 + clip_eps: float = 0.2 + max_grad_norm: float = 0.5 + anneal_lr: bool = True + + num_eval_envs: int = 10 + steps_log_freq: int = 1 + update_log_freq: int = 1 + save_checkpoints: bool = False # Do we save train_state along with our per timestep outputs? + save_runner_state: bool = False # Do we save the checkpoint in the end? 
+ seed: int = 2020 + n_seeds: int = 5 + platform: Literal['cpu', 'gpu'] = 'cpu' + debug: bool = False + + study_name: str = 'batch_ppo_test' + + def process_args(self) -> None: + self.vf_coeff = jnp.array(self.vf_coeff) + self.lr = jnp.array(self.lr) + self.lambda0 = jnp.array(self.lambda0) + self.lambda1 = jnp.array(self.lambda1) + self.alpha = jnp.array(self.alpha) + self.ld_weight = jnp.array(self.ld_weight) + +def filter_period_first_dim(x, n: int): + if isinstance(x, jnp.ndarray) or isinstance(x, np.ndarray): + return x[::n] + + +def make_train(args: BatchPPOHyperparams, rand_key: jax.random.PRNGKey): + num_updates = ( + args.total_steps // args.num_steps // args.num_envs + ) + args.minibatch_size = ( + args.num_envs * args.num_steps // args.num_minibatches + ) + env_key, rand_key = jax.random.split(rand_key) + env, env_params = get_gymnax_env(args.env, env_key, + gamma=args.gamma, + action_concat=args.action_concat) + + if hasattr(env, 'gamma'): + args.gamma = env.gamma + + assert hasattr(env_params, 'max_steps_in_episode') + + double_critic = args.double_critic + memoryless = args.memoryless + + network_fn, action_size = get_network_fn(env, env_params, memoryless=memoryless) + + network = network_fn(action_size, + double_critic=double_critic, + hidden_size=args.hidden_size) + + steps_filter = partial(filter_period_first_dim, n=args.steps_log_freq) + update_filter = partial(filter_period_first_dim, n=args.update_log_freq) + + # Used for vmapping over our double critic. + transition_axes_map = Transition( + None, None, 2, None, None, None, None + ) + + _env_step = partial(env_step, network=network, env=env, env_params=env_params) + + def train(vf_coeff, ld_weight, alpha, lambda1, lambda0, lr, rng): + def linear_schedule(count): + frac = ( + 1.0 + - (count // (args.num_minibatches * args.update_epochs)) + / num_updates + ) + return lr * frac + + + # INIT NETWORK + rng, _rng = jax.random.split(rng) + init_x = ( + jnp.zeros( + (1, args.num_envs, *env.observation_space(env_params).shape) + ), + jnp.zeros((1, args.num_envs)), + ) + init_hstate = ScannedRNN.initialize_carry(args.num_envs, args.hidden_size) + network_params = network.init(_rng, init_hstate, init_x) + if args.anneal_lr: + tx = optax.chain( + optax.clip_by_global_norm(args.max_grad_norm), + optax.adam(learning_rate=linear_schedule, eps=1e-5), + ) + else: + tx = optax.chain( + optax.clip_by_global_norm(args.max_grad_norm), + optax.adam(lr, eps=1e-5), + ) + train_state = TrainState.create( + apply_fn=network.apply, + params=network_params, + tx=tx, + ) + + # INIT ENV + rng, _rng = jax.random.split(rng) + reset_rng = jax.random.split(_rng, args.num_envs) + obsv, env_state = env.reset(reset_rng, env_params) + init_hstate = ScannedRNN.initialize_carry(args.num_envs, args.hidden_size) + + # We first need to populate our LogEnvState stats. 
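+        # (The full-episode rollout below runs the freshly initialised policy once; its
+        # returned-episode statistics are then copied into env_state via recursive_replace,
+        # so the first logged returns reflect the initial policy rather than zeros.)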
+ rng, _rng = jax.random.split(rng) + init_rng = jax.random.split(_rng, args.num_envs) + init_obsv, init_env_state = env.reset(init_rng, env_params) + init_init_hstate = ScannedRNN.initialize_carry(args.num_envs, args.hidden_size) + + init_runner_state = ( + train_state, + env_state, + init_obsv, + jnp.zeros(args.num_envs, dtype=bool), + init_init_hstate, + _rng, + ) + + starting_runner_state, _ = jax.lax.scan( + _env_step, init_runner_state, None, env_params.max_steps_in_episode + ) + + def recursive_replace(env_state, new_env_state, names): + if not isinstance(env_state, LogEnvState): + return replace(env_state, env_state=recursive_replace(env_state.env_state, new_env_state.env_state, names)) + new_log_vals = {name: getattr(new_env_state, name) for name in names} + return replace(env_state, **new_log_vals) + + replace_field_names = ['returned_episode_returns', 'returned_discounted_episode_returns', 'returned_episode_lengths'] + env_state = recursive_replace(env_state, starting_runner_state[1], replace_field_names) + + # TRAIN LOOP + def _update_step(runner_state, i): + # COLLECT TRAJECTORIES + initial_hstate = runner_state[-2] + runner_state, traj_batch = jax.lax.scan( + _env_step, runner_state, jnp.arange(args.num_steps), args.num_steps + ) + + # CALCULATE ADVANTAGE + train_state, env_state, last_obs, last_done, hstate, rng = runner_state + ac_in = (last_obs[np.newaxis, :], last_done[np.newaxis, :]) + _, _, last_val = network.apply(train_state.params, hstate, ac_in) + last_val = last_val.squeeze(0) + def _calculate_gae(traj_batch, last_val, last_done, gae_lambda): + def _get_advantages(carry, transition): + gae, next_value, next_done, gae_lambda = carry + done, value, reward = transition.done, transition.value, transition.reward + delta = reward + args.gamma * next_value * (1 - next_done) - value + gae = delta + args.gamma * gae_lambda * (1 - next_done) * gae + return (gae, value, done, gae_lambda), gae + _, advantages = jax.lax.scan(_get_advantages, + (jnp.zeros_like(last_val), last_val, last_done, gae_lambda), + traj_batch, reverse=True, unroll=16) + return advantages, advantages + traj_batch.value + + gae_lambda = jnp.array(lambda0) + if double_critic: + # last_val is index 1 here b/c we squeezed earlier. + _calculate_gae = jax.vmap(_calculate_gae, + in_axes=[transition_axes_map, 1, None, 0], + out_axes=2) + gae_lambda = jnp.array([lambda0, lambda1]) + advantages, targets = _calculate_gae(traj_batch, last_val, last_done, gae_lambda) + + # UPDATE NETWORK + def _update_epoch(update_state, unused): + def _update_minbatch(train_state, batch_info): + init_hstate, traj_batch, advantages, targets = batch_info + + def _loss_fn(params, init_hstate, traj_batch, gae, targets): + # RERUN NETWORK + _, pi, value = network.apply( + params, init_hstate[0], (traj_batch.obs, traj_batch.done) + ) + log_prob = pi.log_prob(traj_batch.action) + + # CALCULATE VALUE LOSS + value_pred_clipped = traj_batch.value + ( + value - traj_batch.value + ).clip(-args.clip_eps, args.clip_eps) + value_losses = jnp.square(value - targets) + value_losses_clipped = jnp.square(value_pred_clipped - targets) + value_loss = ( + jnp.maximum(value_losses, value_losses_clipped).mean() + ) + # Lambda discrepancy loss + if double_critic: + value_loss = ld_weight * (jnp.square(value[..., 0] - value[..., 1])).mean() + \ + (1 - ld_weight) * value_loss + + # CALCULATE ACTOR LOSS + ratio = jnp.exp(log_prob - traj_batch.log_prob) + + # which advantage do we use to update our policy? 
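+                        # (With the double critic, alpha mixes the lambda0- and lambda1-GAE
+                        # advantage streams; with a single critic, gae is just the lambda0
+                        # stream and is used directly.)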
+ if double_critic: + gae = (alpha * gae[..., 0] + + (1 - alpha) * gae[..., 1]) + gae = (gae - gae.mean()) / (gae.std() + 1e-8) + loss_actor1 = ratio * gae + loss_actor2 = ( + jnp.clip( + ratio, + 1.0 - args.clip_eps, + 1.0 + args.clip_eps, + ) + * gae + ) + loss_actor = -jnp.minimum(loss_actor1, loss_actor2) + loss_actor = loss_actor.mean() + entropy = pi.entropy().mean() + + total_loss = ( + loss_actor + + vf_coeff * value_loss + - args.entropy_coeff * entropy + ) + return total_loss, (value_loss, loss_actor, entropy) + + grad_fn = jax.value_and_grad(_loss_fn, has_aux=True) + total_loss, grads = grad_fn( + train_state.params, init_hstate, traj_batch, advantages, targets + ) + train_state = train_state.apply_gradients(grads=grads) + return train_state, total_loss + + ( + train_state, + init_hstate, + traj_batch, + advantages, + targets, + rng, + ) = update_state + + rng, _rng = jax.random.split(rng) + permutation = jax.random.permutation(_rng, args.num_envs) + batch = (init_hstate, traj_batch, advantages, targets) + + shuffled_batch = jax.tree.map( + lambda x: jnp.take(x, permutation, axis=1), batch + ) + + minibatches = jax.tree.map( + lambda x: jnp.swapaxes( + jnp.reshape( + x, + [x.shape[0], args.num_minibatches, -1] + + list(x.shape[2:]), + ), + 1, + 0, + ), + shuffled_batch, + ) + + train_state, total_loss = jax.lax.scan( + _update_minbatch, train_state, minibatches + ) + update_state = ( + train_state, + init_hstate, + traj_batch, + advantages, + targets, + rng, + ) + return update_state, total_loss + + init_hstate = initial_hstate[None, :] # TBH + update_state = ( + train_state, + init_hstate, + traj_batch, + advantages, + targets, + rng, + ) + update_state, loss_info = jax.lax.scan( + _update_epoch, update_state, None, args.update_epochs + ) + train_state = update_state[0] + + # save metrics only every steps_log_freq + metric = traj_batch.info + metric = jax.tree.map(steps_filter, metric) + + rng = update_state[-1] + if args.debug: + + def callback(info): + timesteps = ( + info["timestep"][info["returned_episode"]] * args.num_envs + ) + avg_return_values = jnp.mean(info["returned_episode_returns"][info["returned_episode"]]) + if len(timesteps) > 0: + print( + f"timesteps={timesteps[0]} - {timesteps[-1]}, avg episodic return={avg_return_values:.2f}" + ) + + jax.debug.callback(callback, metric) + + runner_state = (train_state, env_state, last_obs, last_done, hstate, rng) + + return runner_state, metric + + def _epoch(runner_state, _): + runner_state, metric = jax.lax.scan( + _update_step, runner_state, jnp.arange(round(num_updates / args.num_epochs)), round(num_updates / args.num_epochs) + ) + # save metrics only every update_log_freq + metric = jax.tree.map(update_filter, metric) + + res = {'metric': metric} + if args.save_checkpoints: + res['checkpoint'] = runner_state[0].params + + return runner_state, res + + rng, _rng = jax.random.split(rng) + runner_state = ( + train_state, + env_state, + obsv, + jnp.zeros((args.num_envs), dtype=bool), + init_hstate, + _rng, + ) + + runner_state, metric = jax.lax.scan( + _epoch, runner_state, jnp.arange(args.num_epochs), args.num_epochs + ) + # combine epochs with time steps + metric['metric'] = jax.tree.map(lambda x: x.reshape(-1, *x.shape[2:]), metric['metric']) + + # returned metric has an extra dimension. 
+ # runner_state, metric = jax.lax.scan( + # _update_step, runner_state, jnp.arange(num_updates), num_updates + # ) + # + # # save metrics only every update_log_freq + # metric = jax.tree.map(update_filter, metric) + + # TODO: offline eval here. + final_train_state = runner_state[0] + + reset_rng = jax.random.split(_rng, args.num_eval_envs) + eval_obsv, eval_env_state = env.reset(reset_rng, env_params) + + eval_init_hstate = ScannedRNN.initialize_carry(args.num_eval_envs, args.hidden_size) + + eval_runner_state = ( + final_train_state, + eval_env_state, + eval_obsv, + jnp.zeros((args.num_eval_envs), dtype=bool), + eval_init_hstate, + _rng, + ) + + # COLLECT EVAL TRAJECTORIES + eval_runner_state, eval_traj_batch = jax.lax.scan( + _env_step, eval_runner_state, None, env_params.max_steps_in_episode + ) + + res = {"runner_state": runner_state, "metric": metric['metric'], 'final_eval_metric': eval_traj_batch.info} + + if args.save_checkpoints: + res['checkpoint'] = metric['checkpoint'] + return res + + return train + + +if __name__ == "__main__": + # jax.disable_jit(True) + # okay some weirdness here. NUM_ENVS needs to match with NUM_MINIBATCHES + args = BatchPPOHyperparams().parse_args() + jax.config.update('jax_platform_name', args.platform) + + rng = jax.random.PRNGKey(args.seed) + make_train_rng, rng = jax.random.split(rng) + rngs = jax.random.split(rng, args.n_seeds) + train_fn = make_train(args, make_train_rng) + + train_args = list(inspect.signature(train_fn).parameters.keys()) + + vmaps_train = train_fn + swept_args = deque() + + # we need to go backwards, since JAX returns indices + # in the order in which they're vmapped. + for i, arg in reversed(list(enumerate(train_args))): + dims = [None] * len(train_args) + dims[i] = 0 + vmaps_train = jax.vmap(vmaps_train, in_axes=dims) + if arg == 'rng': + swept_args.appendleft(rngs) + else: + assert hasattr(args, arg) + swept_args.appendleft(getattr(args, arg)) + + train_jit = jax.jit(vmaps_train) + out = train_jit(*swept_args) + + # our final_eval_metric returns max_num_steps. + # we can filter that down by the max episode length amongst the runs. 
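+    # (argmax over the step axis finds each run's first `returned_episode` step; the outer
+    # max then gives the longest episode across all runs and hyperparameter settings.)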
+ final_eval = out['final_eval_metric'] + + # the +1 at the end is to include the done step + largest_episode = final_eval['returned_episode'].argmax(axis=-2).max() + 1 + + def get_first_n_filter(x): + return x[..., :largest_episode, :] + out['final_eval_metric'] = jax.tree.map(get_first_n_filter, final_eval) + + if not args.save_runner_state: + del out['runner_state'] + + results_path = get_results_path(args, return_npy=False) # returns a results directory + + all_results = { + 'argument_order': train_args, + 'out': out, + 'args': args.as_dict() + } + + # Save all results with Orbax + orbax_checkpointer = orbax.checkpoint.PyTreeCheckpointer() + save_args = orbax_utils.save_args_from_target(all_results) + + print(f"Saving results to {results_path}") + orbax_checkpointer.save(results_path, all_results, save_args=save_args) + + print("Done.") diff --git a/scripts/collect_probe_trajectories.py b/scripts/collect_probe_trajectories.py new file mode 100644 index 0000000..8a1eccd --- /dev/null +++ b/scripts/collect_probe_trajectories.py @@ -0,0 +1,150 @@ +from functools import partial +from pathlib import Path +from typing import Union, Literal + +from chex import dataclass +from jumanji.environments.routing.pac_man import State +from jumanji.environments.routing.pac_man.types import Position +import jax +import jax.numpy as jnp +import numpy as np +import orbax.checkpoint +from tap import Tap + +from porl.agents.ppo import env_step +from porl.envs.pocman import PocMan +from porl.models.actor_critic import ScannedRNN, PelletPredictorNN +from porl.utils.file_system import load_train_state, numpyify_and_save + +from definitions import ROOT_DIR + + +class PocmanProbeCollectHyperparams(Tap): + probe_path_0: Union[str, Path] + probe_path_1: Union[str, Path] + rnn_path_0: Union[str, Path] + rnn_path_1: Union[str, Path] + + behavior_policy_idx: Literal[0, 1] = 1 + seed: int = 2024 + + def configure(self) -> None: + self.add_argument('--probe_path_0', type=Path) + self.add_argument('--probe_path_1', type=Path) + self.add_argument('--rnn_path_0', type=Path) + self.add_argument('--rnn_path_1', type=Path) + + +def state_to_dict(state: State): + state_dict = {} + for k, v in state.items(): + if isinstance(v, Position): + state_dict[k] = {'x': v.x, 'y': v.y} + else: + state_dict[k] = v + return state_dict + + +def unpack_and_flatten_state(state) -> dict: + while (not isinstance(state, State)): + state = state.env_state + + flattened_state = jax.tree.map(lambda x: x[0], state) + return state_to_dict(flattened_state) + + +def load_probe(fpath: Path): + orbax_checkpointer = orbax.checkpoint.PyTreeCheckpointer() + restored = orbax_checkpointer.restore(fpath) + args = restored['args'] + unpacked_ts = restored['final_train_state'] + + # TODO: refactor this + n_pellet_predictions = unpacked_ts['params']['params']['Dense_3']['bias'].shape[0] + + network = PelletPredictorNN(hidden_size=args['hidden_size'], + n_outs=n_pellet_predictions, + n_hidden_layers=args['n_hidden_layers']) + return network, unpacked_ts + +def predictions_to_map(predictions: jnp.ndarray, env: PocMan): + predictions = predictions.squeeze() + env_generator = env._unwrapped.generator + + # we first subtract by 1, so that all walls are -1, and + # empty spaces are 0. 
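+    # (The pellet coordinates are then overwritten below with the probe's sigmoid outputs
+    # in [0, 1], giving a per-cell pellet-presence estimate.)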
+ preds_map = env_generator.numpy_maze - 1 + + preds_map = preds_map.at[env_generator.pellet_spaces[:, 1], env_generator.pellet_spaces[:, 0]].set(predictions) + return preds_map + + +if __name__ == "__main__": + # jax.disable_jit(True) + args = PocmanProbeCollectHyperparams().parse_args() + + key = jax.random.PRNGKey(args.seed) + load_key_0, load_key_1, key = jax.random.split(key, 3) + + probe_network_0, probe_ts_0 = load_probe(args.probe_path_0) + probe_network_1, probe_ts_1 = load_probe(args.probe_path_1) + + # TODO: This is kind of sketch. We've refactored this now, so change this when done retraining. + env, env_params, rnn_args0, rnn_network0, rnn_ts0 = load_train_state(load_key_0, args.rnn_path_0, + update_idx_to_take=2, + best_over_rng=True) + _, _, rnn_args1, rnn_network1, rnn_ts1 = load_train_state(load_key_1, args.rnn_path_1, + update_idx_to_take=2, + best_over_rng=True) + + predictions_to_map = jax.jit(partial(predictions_to_map, env=env)) + + networks = [rnn_network0, rnn_network1] + tses = [rnn_ts0, rnn_ts1] + ts = tses[args.behavior_policy_idx] + _env_step = jax.jit(partial(env_step, network=networks[args.behavior_policy_idx], env=env, env_params=env_params)) + + @jax.jit + def predict_probes(obs, done, hs0, hs1): + ac_in = (obs[jnp.newaxis, :], done[jnp.newaxis, :]) + hs0, _, _ = rnn_network0.apply(rnn_ts0.params, hs0, ac_in) + hs1, _, _ = rnn_network1.apply(rnn_ts1.params, hs1, ac_in) + + predictions0, _ = probe_network_0.apply(probe_ts_0['params'], hs0) + predictions1, _ = probe_network_1.apply(probe_ts_1['params'], hs1) + return (predictions0, predictions1), (hs0, hs1) + + key, reset_key = jax.random.split(key) + reset_key = reset_key[None, ...] + obsv, state = env.reset(reset_key, env_params) + states = [unpack_and_flatten_state(state)] + predictions, pred_maps = [], [] + + assert rnn_args1['hidden_size'] == rnn_args0['hidden_size'] + hstate = ScannedRNN.initialize_carry(1, rnn_args1['hidden_size']) + hs0 = ScannedRNN.initialize_carry(1, rnn_args0['hidden_size']) + hs1 = ScannedRNN.initialize_carry(1, rnn_args1['hidden_size']) + + done = jnp.array([False]) + rs = (ts, state, obsv, done, hstate, key) + while not jnp.any(done): + preds, (hs0, hs1) = predict_probes(obsv, done, hs0, hs1) + predictions.append(preds) + pred_maps.append((predictions_to_map(preds[0]), predictions_to_map(preds[1]))) + rs, transition = _env_step(rs, None) + ts, state, obsv, done, hstate, key = rs + states.append(unpack_and_flatten_state(state)) + + preds, (hs0, hs1) = predict_probes(obsv, done, hs0, hs1) + pred_maps.append((predictions_to_map(preds[0]), predictions_to_map(preds[1]))) + + res = { + 'states': states, + 'predictions': pred_maps + } + + res_path = Path(ROOT_DIR, 'results', f'pocman_pellet_probe_trajectory_bidx_{args.behavior_policy_idx}.npy') + + print(f"Saving Pocman probe trajectory to {res_path}") + numpyify_and_save(res_path, res) + print("Done.") diff --git a/scripts/collect_rnn_trajectories.py b/scripts/collect_rnn_trajectories.py new file mode 100644 index 0000000..681b8e1 --- /dev/null +++ b/scripts/collect_rnn_trajectories.py @@ -0,0 +1,197 @@ +from functools import partial +from pathlib import Path +from typing import Union, NamedTuple + +import chex +import jax +import jax.numpy as jnp +from jax_tqdm import scan_tqdm +import numpy as np +from tap import Tap +from flax.training import orbax_utils +import orbax.checkpoint + +from lamb.envs.pocman import State +from lamb.models import ScannedRNN +from lamb.utils.file_system import load_train_state, make_hash_md5 + + +class 
CollectHyperparams(Tap): + rnn_path_0: Union[str, Path] + rnn_path_1: Union[str, Path] + behavior_path: Union[str, Path] + + update_idx_to_take: int = None + + num_envs: int = 4 + n_samples: int = int(1e6) + + seed: int = 2024 + platform: str = 'cpu' + + def configure(self) -> None: + self.add_argument('--rnn_path_0', type=Path) + self.add_argument('--rnn_path_1', type=Path) + self.add_argument('--behavior_path', type=Path) + + +def ppo_pocman_step(runner_state, unused, + behavior_network, rnn_network_0, rnn_network_1, + env, env_params): + def get_pocman_state(s) -> State: + if isinstance(s, State): + return s + if hasattr(s, 'env_state'): + return get_pocman_state(s.env_state) + else: + raise TypeError('No Pocman env_state found.') + + (behavior_ts, rnn_ts_0, rnn_ts_1, env_state, last_obs, last_done, + behavior_hstate, rnn_hstate_0, rnn_hstate_1, rng) = runner_state + rng, _rng = jax.random.split(rng) + + # SELECT ACTION + ac_in = (last_obs[np.newaxis, :], last_done[np.newaxis, :]) + next_behavior_hstate, pi, value = behavior_network.apply(behavior_ts.params, behavior_hstate, ac_in) + action = pi.sample(seed=_rng) + log_prob = pi.log_prob(action) + value, action, log_prob = ( + value.squeeze(0), + action.squeeze(0), + log_prob.squeeze(0), + ) + + # get our RNN hidden states that we're sampling + next_rnn_hstate_0, _, _ = rnn_network_0.apply(rnn_ts_0.params, rnn_hstate_0, ac_in) + next_rnn_hstate_1, _, _ = rnn_network_1.apply(rnn_ts_1.params, rnn_hstate_1, ac_in) + + # STEP ENV + rng, _rng = jax.random.split(rng) + rng_step = jax.random.split(_rng, next_behavior_hstate.shape[0]) + obsv, next_env_state, reward, done, info = env.step(rng_step, env_state, action, env_params) + + # transition = Transition( + # last_done, action, value, reward, log_prob, last_obs, info + # ) + pocman_state = get_pocman_state(env_state) + possible_locations = jnp.array(env._unwrapped.generator.reachable_spaces) + + def get_single_occupancy(loc: jnp.ndarray): + return jnp.all(possible_locations == loc[None, ...], axis=-1) + + # We vmap twice, once for the batch dimension in VecEnv, + # the second time for the 4 ghosts + ghost_occupancy = jax.vmap(jax.vmap(get_single_occupancy))(pocman_state.ghost_locations).sum(axis=-2) + datum = { + 'x_0': rnn_hstate_0, + 'x_1': rnn_hstate_1, + 'pellet_occupancy': jnp.all(pocman_state.pellet_locations != 0, axis=-1), + 'ghost_occupancy': jnp.clip(ghost_occupancy, a_max=1), + # 'state': pocman_state + } + runner_state = (behavior_ts, rnn_ts_0, rnn_ts_1, next_env_state, obsv, done, + next_behavior_hstate, next_rnn_hstate_0, next_rnn_hstate_1, rng) + return runner_state, datum + + +def make_collect(args: CollectHyperparams, key: chex.PRNGKey): + steps_to_collect = args.n_samples // args.num_envs + + behavior_key, rnn_key_0, rnn_key_1, key = jax.random.split(key, 4) + + env, env_params, behavior_args, behavior_network, behavior_ts = load_train_state(behavior_key, args.behavior_path, + update_idx_to_take=args.update_idx_to_take, + best_over_rng=True) + _, _, rnn_args_0, rnn_network_0, rnn_ts_0 = load_train_state(rnn_key_0, args.rnn_path_0, + update_idx_to_take=args.update_idx_to_take, + best_over_rng=True) + _, _, rnn_args_1, rnn_network_1, rnn_ts_1 = load_train_state(rnn_key_1, args.rnn_path_1, + update_idx_to_take=args.update_idx_to_take, + best_over_rng=True) + + _env_step = partial(ppo_pocman_step, behavior_network=behavior_network, + rnn_network_0=rnn_network_0, rnn_network_1=rnn_network_1, + env=env, env_params=env_params) + _env_step = scan_tqdm(steps_to_collect)(_env_step) + 
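+    # (scan_tqdm above wraps the scanned step so a progress bar is rendered
+    # inside jax.lax.scan; it expects the loop index as the scanned-over input,
+    # which is why collect() below scans over jnp.arange(steps_to_collect)
+    # rather than passing xs=None.)
+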
+ ckpts = { + 'behavior': {'args': behavior_args, 'ts': behavior_ts, 'path': args.behavior_path}, + 'rnn_0': {'args': rnn_args_0, 'ts': rnn_ts_0, 'path': args.rnn_path_0}, + 'rnn_1': {'args': rnn_args_1, 'ts': rnn_ts_1, 'path': args.rnn_path_1} + } + + def collect(rng): + # INIT ENV + rng, _rng = jax.random.split(rng) + reset_rng = jax.random.split(_rng, args.num_envs) + obsv, env_state = env.reset(reset_rng, env_params) + + # init hidden state + init_behavior_hstate = ScannedRNN.initialize_carry(args.num_envs, behavior_args['hidden_size']) + init_rnn_hstate_0 = ScannedRNN.initialize_carry(args.num_envs, rnn_args_0['hidden_size']) + init_rnn_hstate_1 = ScannedRNN.initialize_carry(args.num_envs, rnn_args_1['hidden_size']) + init_runner_state = ( + behavior_ts, + rnn_ts_0, + rnn_ts_1, + env_state, + obsv, + jnp.zeros(args.num_envs, dtype=bool), + init_behavior_hstate, + init_rnn_hstate_0, + init_rnn_hstate_1, + _rng, + ) + + runner_state, dataset = jax.lax.scan( + _env_step, init_runner_state, jnp.arange(steps_to_collect), steps_to_collect + ) + + # Now we flatten back down + flat_dataset = jax.tree.map(lambda x: x.reshape(-1, *x.shape[2:]), dataset) + + return flat_dataset + + return collect, ckpts + + +if __name__ == "__main__": + # jax.disable_jit(True) + args = CollectHyperparams().parse_args() + jax.config.update('jax_platform_name', args.platform) + + key = jax.random.PRNGKey(args.seed) + make_key, collect_key, key = jax.random.split(key, 3) + + collect_fn, ckpt_info = make_collect(args, make_key) + collect_fn = jax.jit(collect_fn) + + dataset = collect_fn(collect_key) + + def path_to_str(d: dict): + for k, v in d.items(): + if isinstance(v, Path): + d[k] = str(v) + elif isinstance(v, dict): + path_to_str(v) + + + to_save = { + 'dataset': dataset, + 'args': args.as_dict(), + 'ckpt': ckpt_info, + } + path_to_str(to_save) + + save_path = args.behavior_path.parent / \ + f'buffer_{args.n_samples}_timestep_{args.update_idx_to_take}_seed_{args.seed}_{make_hash_md5(args.as_dict())}' + + # Save all results with Orbax + orbax_checkpointer = orbax.checkpoint.PyTreeCheckpointer() + save_args = orbax_utils.save_args_from_target(to_save) + + print(f"Saving results to {save_path}") + orbax_checkpointer.save(save_path, to_save, save_args=save_args) + + print("Done.") + diff --git a/scripts/combine_probe_datasets.py b/scripts/combine_probe_datasets.py new file mode 100644 index 0000000..42a1185 --- /dev/null +++ b/scripts/combine_probe_datasets.py @@ -0,0 +1,43 @@ +from pathlib import Path + +import jax +import jax.numpy as jnp +import numpy as np +import orbax.checkpoint + +from lamb.utils.file_system import numpyify_and_save +from definitions import ROOT_DIR + + +if __name__ == "__main__": + + jax.config.update('jax_platform_name', 'cpu') + d0_path = Path("../results/pocman_ppo_best_ckpt/buffer_1000000_timestep_2_seed_2024_390f282614e5b7398cf10e565ea811e7") + d1_path = Path("../results/pocman_LD_ppo_best_ckpt/buffer_1000000_timestep_2_seed_2024_e232a0c2fa52cb6e68f01be9dac6b7b9") + + orbax_checkpointer = orbax.checkpoint.PyTreeCheckpointer() + restored = orbax_checkpointer.restore(d0_path) + + args_0, ckpt_0, dataset_0 = restored['args'], restored['ckpt'], restored['dataset'] + dataset_0['ghost_occupancy'] = dataset_0['ghost_occupancy'].astype(np.int8) + + restored = orbax_checkpointer.restore(d1_path) + + args_1, ckpt_1, dataset_1 = restored['args'], restored['ckpt'], restored['dataset'] + dataset_1['ghost_occupancy'] = dataset_1['ghost_occupancy'].astype(np.int8) + + combined_dataset = 
jax.tree.map(lambda x, y: jnp.concatenate((x, y), axis=0), dataset_0, dataset_1) + + save_dir = Path(ROOT_DIR, 'results', 'combined_probe_datasets') + save_dir.mkdir(exist_ok=True) + save_path = save_dir / f'combined_{d0_path.stem.split("_")[-1]}.npy' + to_save = { + 'args': [args_0, args_1], + 'ckpt': [ckpt_0, ckpt_1], + 'dataset': combined_dataset + } + + print(f"Saving results to {save_path}") + numpyify_and_save(save_path, to_save) + + print("Done.") diff --git a/scripts/hyperparams/analytical_30seeds.py b/scripts/hyperparams/analytical_30seeds.py index 7b8cf1e..70661a4 100644 --- a/scripts/hyperparams/analytical_30seeds.py +++ b/scripts/hyperparams/analytical_30seeds.py @@ -5,7 +5,7 @@ hparams = { 'file_name': f'runs_{exp_name}.txt', - 'entry': '-m batch_run_kitchen_sinks_ld_only', + 'entry': '-m batch_run_analytical', 'args': [{ 'spec': [ 'tiger-alt-start', 'tmaze_5_two_thirds_up', '4x3.95', diff --git a/scripts/hyperparams/analytical_8_30seeds.py b/scripts/hyperparams/analytical_8_30seeds.py index 912db9b..20a99be 100644 --- a/scripts/hyperparams/analytical_8_30seeds.py +++ b/scripts/hyperparams/analytical_8_30seeds.py @@ -5,7 +5,7 @@ hparams = { 'file_name': f'runs_{exp_name}.txt', - 'entry': '-m batch_run_kitchen_sinks_ld_only', + 'entry': '-m batch_run_analytical', 'args': [{ 'spec': [ 'tiger-alt-start', 'tmaze_5_two_thirds_up', '4x3.95', diff --git a/scripts/hyperparams/parity_check_30seeds.py b/scripts/hyperparams/parity_check_30seeds.py new file mode 100644 index 0000000..08177b3 --- /dev/null +++ b/scripts/hyperparams/parity_check_30seeds.py @@ -0,0 +1,31 @@ +from pathlib import Path + +exp_name = Path(__file__).stem + +hparams = { + 'file_name': + f'runs_{exp_name}.txt', + 'entry': '-m batch_run_analytical', + 'args': [{ + 'spec': [ + 'parity_check' + ], + 'policy_optim_alg': 'policy_grad', + 'leave_out_optimal': True, + 'mem_aug_before_init_pi': True, + 'value_type': 'q', + 'error_type': 'l2', + 'alpha': 1., + 'mi_steps': 10000, + 'pi_steps': 10000, + 'optimizer': 'adam', + 'lr': 0.75, + 'n_mem_states': [2, 4], + 'mi_iterations': 1, + 'random_policies': 100, + 'n_seeds': 30, + 'platform': 'gpu' + }, + + ] +} diff --git a/scripts/hyperparams/parity_check_8_30seeds.py b/scripts/hyperparams/parity_check_8_30seeds.py new file mode 100644 index 0000000..b6b377f --- /dev/null +++ b/scripts/hyperparams/parity_check_8_30seeds.py @@ -0,0 +1,32 @@ +from pathlib import Path + +exp_name = Path(__file__).stem + +hparams = { + 'file_name': + f'runs_{exp_name}.txt', + 'entry': '-m batch_run_analytical', + 'args': [{ + 'spec': [ + 'parity_check' + ], + 'policy_optim_alg': 'policy_grad', + 'leave_out_optimal': True, + 'mem_aug_before_init_pi': True, + 'value_type': 'q', + 'error_type': 'l2', + 'alpha': 1., + 'mi_steps': 10000, + 'pi_steps': 10000, + 'optimizer': 'adam', + 'lr': 0.75, + 'n_mem_states': 8, + 'mi_iterations': 1, + 'random_policies': 100, + 'seed': [2024 + s for s in range(6)], + 'n_seeds': 5, + 'platform': 'gpu' + }, + + ] +} diff --git a/scripts/hyperparams/pocman_LD_ppo_best_ckpt.py b/scripts/hyperparams/pocman_LD_ppo_best_ckpt.py new file mode 100644 index 0000000..75b4d55 --- /dev/null +++ b/scripts/hyperparams/pocman_LD_ppo_best_ckpt.py @@ -0,0 +1,39 @@ +from pathlib import Path + +exp_name = Path(__file__).stem + +lrs = [2.5e-4] +lambda0s = [0.1] +lambda1s = [0.5] +alphas = [1] +ld_weights = [0.25] + +hparams = { + 'file_name': + f'runs_{exp_name}.txt', + 'entry': '-m scripts.batch_run_ppo_epoch', + 'args': [ + { + 'env': 'pocman', + 'double_critic': True, + 
'action_concat': True, + 'lr': lrs, + 'lambda0': ' '.join(map(str, lambda0s)), + 'lambda1': lambda1s, + 'alpha': ' '.join(map(str, alphas)), + 'ld_weight': ld_weights, + 'hidden_size': 512, + 'entropy_coeff': 0.05, + 'num_epochs': 25, + 'steps_log_freq': 4, + 'update_log_freq': 200, + 'total_steps': int(1e7), + 'save_checkpoints': True, + 'save_runner_state': True, + 'seed': 2036, + 'n_seeds': 5, + 'platform': 'gpu', + 'study_name': exp_name + } + ] +} diff --git a/scripts/hyperparams/pocman_ppo_best_ckpt.py b/scripts/hyperparams/pocman_ppo_best_ckpt.py new file mode 100644 index 0000000..1ba98c4 --- /dev/null +++ b/scripts/hyperparams/pocman_ppo_best_ckpt.py @@ -0,0 +1,39 @@ +from pathlib import Path + +exp_name = Path(__file__).stem + +lrs = [2.5e-5] +lambda0s = [0.5] +lambda1s = [0.95] +alphas = [1] +ld_weights = [0] + +hparams = { + 'file_name': + f'runs_{exp_name}.txt', + 'entry': '-m scripts.batch_run_ppo_epoch', + 'args': [ + { + 'env': 'pocman', + 'double_critic': False, + 'action_concat': True, + 'lr': lrs, + 'lambda0': ' '.join(map(str, lambda0s)), + 'lambda1': ' '.join(map(str, lambda1s)), + 'alpha': ' '.join(map(str, alphas)), + 'ld_weight': ' '.join(map(str, ld_weights)), + 'hidden_size': 512, + 'entropy_coeff': 0.05, + 'num_epochs': 25, + 'steps_log_freq': 4, + 'update_log_freq': 200, + 'total_steps': int(1e7), + 'save_checkpoints': True, + 'save_runner_state': True, + 'seed': 2036, + 'n_seeds': 5, + 'platform': 'gpu', + 'study_name': exp_name + } + ] +} diff --git a/scripts/train_probe.py b/scripts/train_probe.py new file mode 100644 index 0000000..6d27590 --- /dev/null +++ b/scripts/train_probe.py @@ -0,0 +1,185 @@ +from collections import namedtuple +from pathlib import Path +from typing import Union + +from flax.training import orbax_utils +from flax.training.train_state import TrainState +import jax +import jax.numpy as jnp +import numpy as np +from tap import Tap +import optax +import orbax.checkpoint + +from lamb.models import PelletPredictorNN +from lamb.utils.replay import make_flat_buffer +from lamb.utils.replay.trajectory import TrajectoryBufferState +from lamb.utils.file_system import make_hash_md5, load_info + +from definitions import ROOT_DIR + + +class PocmanProbeHyperparams(Tap): + dataset_path: Union[str, Path] + features_idx: int = 0 + prediction_target_key: str = 'pellet_occupancy' # Key in dataset that we set as target + + hidden_size: int = 512 + n_hidden_layers: int = 2 + lr: float = 1e-4 + + epochs: int = 100 + train_steps: int = int(1e6) + batch_size: int = 32 + + study_name: str = 'test' + debug: bool = False + seed: int = 2024 + platform: str = 'cpu' + + def configure(self) -> None: + self.add_argument('--dataset_path', type=Path) + + +def filter_period_first_dim(x, n: int): + if isinstance(x, jnp.ndarray) or isinstance(x, np.ndarray): + return x[::n] + + +def make_train(args: PocmanProbeHyperparams): + args.steps_per_epoch = args.train_steps // args.epochs + + if args.dataset_path.suffix == '.npy': + restored = load_info(args.dataset_path) + else: + orbax_checkpointer = orbax.checkpoint.PyTreeCheckpointer() + restored = orbax_checkpointer.restore(args.dataset_path) + + dataset_args, dataset = restored['args'], restored['dataset'] + + experience = jax.tree.map(lambda x: jnp.array(x), dataset) + if 'x' not in experience: + experience['x'] = experience[f'x_{args.features_idx}'] + + for key in list(experience.keys()): + if key.startswith('x_'): + del experience[key] + + Experience = namedtuple('Experience', list(experience.keys())) + experience = 
Experience(**experience) + + n_pellet_predictions = getattr(experience, args.prediction_target_key).shape[-1] + + network = PelletPredictorNN(hidden_size=args.hidden_size, + n_outs=n_pellet_predictions, + n_hidden_layers=args.n_hidden_layers) + + + experience_size = experience.x.shape[0] + buffer = make_flat_buffer( + max_length=experience_size, + min_length=args.batch_size, + sample_batch_size=args.batch_size, + # add_batch_size=experience_size + ) + buffer = buffer.replace( + init=jax.jit(buffer.init), + add=jax.jit(buffer.add, donate_argnums=0), + sample=jax.jit(buffer.sample), + can_sample=jax.jit(buffer.can_sample), + ) + + buffer_state = TrajectoryBufferState( + current_index=jnp.array(0, dtype=int), + is_full=jnp.array(True), + experience=jax.tree_util.tree_map(lambda x: x[None, ...], experience) + ) + + def train(rng): + params_rng, rng = jax.random.split(rng) + params = network.init(params_rng, experience.x[:1]) + tx = optax.adam(args.lr, eps=1e-5) + + train_state = TrainState.create( + apply_fn=network.apply, + params=params, + tx=tx, + ) + + def _epoch_step(runner_state, i): + # @scan_tqdm(args.train_steps, print_rate=100) + @jax.jit + def _update_step(runner_state, i): + ts, rng = runner_state + + sample_key, rng = jax.random.split(rng) + batch = buffer.sample(buffer_state, sample_key) + + target = getattr(batch.experience.first, args.prediction_target_key).astype(float) + + def _loss_fn(params: dict): + _, logits = network.apply(params, batch.experience.first.x) + loss = optax.losses.sigmoid_binary_cross_entropy(logits, target).sum(axis=-1) + return loss.mean() + + grad_fn = jax.jit(jax.value_and_grad(_loss_fn)) + loss, grads = grad_fn(ts.params) + new_ts = ts.apply_gradients(grads=grads) + return (new_ts, rng), loss + + runner_state, losses = jax.lax.scan( + _update_step, runner_state, jnp.arange(args.steps_per_epoch), args.steps_per_epoch + ) + if args.debug: + jax.debug.print("Step {step} average loss: {loss}", step=(i * args.steps_per_epoch), loss=losses.mean()) + + return runner_state, losses.mean() + + runner_state = (train_state, rng) + runner_state, epoch_losses = jax.lax.scan( + _epoch_step, runner_state, jnp.arange(args.epochs), args.epochs + ) + return { + 'final_train_state': runner_state[0], 'epoch_losses': epoch_losses, + 'ckpt': restored['ckpt'] + } + + return train + + +if __name__ == '__main__': + jax.disable_jit(True) + args = PocmanProbeHyperparams().parse_args() + jax.config.update('jax_platform_name', args.platform) + + key = jax.random.PRNGKey(args.seed) + train_key, key = jax.random.split(key) + + train_fn = make_train(args) + # train_fn = jax.jit(train_fn) + + out = train_fn(train_key) + + out['args'] = args.as_dict() + + def path_to_str(d: dict): + for k, v in d.items(): + if isinstance(v, Path): + d[k] = str(v) + elif isinstance(v, dict): + path_to_str(v) + path_to_str(out) + + # Save all results with Orbax + orbax_checkpointer = orbax.checkpoint.PyTreeCheckpointer() + save_args = orbax_utils.save_args_from_target(out) + + results_dir = Path(ROOT_DIR, 'results', args.study_name) + results_dir.mkdir(exist_ok=True) + results_path = results_dir / f"{args.prediction_target_key}_seed_{args.seed}_features_idx_{args.features_idx}_{make_hash_md5(out['args'])}" + + print(f"Saving results to {results_path}") + orbax_checkpointer.save(results_path, out, save_args=save_args) + + print("Done.") + diff --git a/scripts/visualization/viz_pocman_probes.py b/scripts/visualization/viz_pocman_probes.py new file mode 100644 index 0000000..e863c29 --- /dev/null +++ 
b/scripts/visualization/viz_pocman_probes.py @@ -0,0 +1,371 @@ +from dataclasses import dataclass +from pathlib import Path +from typing import NamedTuple, Tuple + +import numpy as np +import matplotlib.animation +import matplotlib.colors as mcolors +import matplotlib.pyplot as plt + +class Position(NamedTuple): + x: np.int32 + y: np.int32 + + def __eq__(self, other: object) -> np.ndarray: + if not isinstance(other, Position): + return NotImplemented + return (self.x == other.x) & (self.y == other.y) + + +@dataclass +class State: + """The state of the environment. + + key: random key used for auto-reset. + grid: jax array (int) of the ingame maze with walls. + pellets: int tracking the number of pellets. + frightened_state_time: jax array (int) of shape () + tracks number of steps for the scatter state. + pellet_locations: jax array (int) of pellet locations. + power_up_locations: jax array (int) of power-up locations + player_locations: current 2D position of agent. + ghost_locations: jax array (int) of current ghost positions. + initial_player_locations: starting 2D position of agent. + initial_ghost_positions: jax array (int) of initial ghost positions. + ghost_init_targets: jax array (int) of initial ghost targets. + used to direct ghosts on respawn. + old_ghost_locations: jax array (int) of shape ghost positions from last step. + used to prevent ghost backtracking. + ghost_init_steps: jax array (int) of number of initial ghost steps. + used to determine per ghost initialisation. + ghost_actions: jax array (int) of ghost action at current step. + last_direction: (int) tracking the last direction of the player. + dead: (bool) used to track player death. + visited_index: jax array (int) of visited locations. + used to prevent repeated pellet points. + ghost_starts: jax array (int) of reset positions for ghosts + used to reset ghost positions if eaten + scatter_targets: jax array (int) of scatter targets. + target locations for ghosts when scatter behavior is active. + step_count: (int32) of total steps taken from reset till current timestep. + ghost_eaten: jax array (bool) tracking if ghost has been eaten before. + score: (int32) of total points aquired. + """ + + key: np.ndarray # (2,) + grid: np.ndarray # (31,28) + pellets: np.int32 # () + frightened_state_time: np.int32 # () + pellet_locations: np.ndarray # (316,2) + power_up_locations: np.ndarray # (4,2) + player_locations: Position # Position(row, col) each of shape () + ghost_locations: np.ndarray # (4,2) + initial_player_locations: Position # Position(row, col) each of shape () + initial_ghost_positions: np.ndarray # (4,2) + ghost_init_targets: np.ndarray # (4,2) + old_ghost_locations: np.ndarray # (4,2) + ghost_init_steps: np.ndarray # (4,) + ghost_actions: np.ndarray # (4,) + last_direction: np.int32 # () + dead: bool # () + visited_index: np.ndarray # (320,2) + ghost_starts: np.ndarray # (4,2) + scatter_targets: np.ndarray # (4,2) + step_count: np.int32 # () + ghost_eaten: np.ndarray # (4,) + score: np.int32 # () + +def create_grid_image(observation: State) -> np.ndarray: + """ + Generate the observation of the current state. + + Args: + state: 'State` object corresponding to the new state of the environment. + + Returns: + rgb: A 3-dimensional array representing the RGB observation of the current state. 
+ """ + + # Make walls blue and passages black + layer_1 = (1 - observation.grid) * 0.0 + layer_2 = (1 - observation.grid) * 0.0 + layer_3 = (1 - observation.grid) * 0.6 + + player_loc = observation.player_locations + ghost_pos = observation.ghost_locations + pellets_loc = observation.power_up_locations + is_scared = observation.frightened_state_time + idx = observation.pellet_locations + n = 3 + + # Power pellet are pink + for i in range(len(pellets_loc)): + p = pellets_loc[i] + layer_1[p[1], p[0]] = 1.0 + layer_2[p[1], p[0]] = 0.8 + layer_3[p[1], p[0]] = 0.6 + + # Set player is yellow + layer_1[player_loc.x, player_loc.y] = 1 + layer_2[player_loc.x, player_loc.y] = 1 + layer_3[player_loc.x, player_loc.y] = 0 + + cr = np.array([1, 1, 0, 1]) + cg = np.array([0, 0.7, 1, 0.5]) + cb = np.array([0, 1, 1, 0.0]) + # Set ghost locations + + layers = (layer_1, layer_2, layer_3) + + def set_ghost_colours( + layers: np.ndarray, + ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + layer_1, layer_2, layer_3 = layers + for i in range(4): + y = ghost_pos[i][0] + x = ghost_pos[i][1] + + layer_1[x, y] = cr[i] + layer_2[x, y] = cg[i] + layer_3[x, y] = cb[i] + return layer_1, layer_2, layer_3 + + def set_ghost_colours_scared( + layers: np.ndarray, + ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + layer_1, layer_2, layer_3 = layers + for i in range(4): + y = ghost_pos[i][0] + x = ghost_pos[i][1] + layer_1[x, y] = 0 + layer_2[x, y] = 0 + layer_3[x, y] = 1 + return layer_1, layer_2, layer_3 + + if is_scared > 0: + layers = set_ghost_colours_scared(layers) + else: + layers = set_ghost_colours(layers) + + layer_1, layer_2, layer_3 = layers + + layer_1[0, 0] = 0 + layer_2[0, 0] = 0 + layer_3[0, 0] = 0.6 + + obs = [layer_1, layer_2, layer_3] + rgb = np.stack(obs, axis=-1) + + expand_rgb = np.kron(rgb, np.ones((n, n, 1))) + layer_1 = expand_rgb[:, :, 0] + layer_2 = expand_rgb[:, :, 1] + layer_3 = expand_rgb[:, :, 2] + + # place normal pellets + for i in range(len(idx)): + if np.array(idx[i]).sum != 0: + loc = idx[i] + c = loc[1] * n + 1 + r = loc[0] * n + 1 + layer_1[c, r] = 1.0 + layer_2[c, r] = 0.8 + layer_3[c, r] = 0.6 + + layers = (layer_1, layer_2, layer_3) + + # Draw details + def set_ghost_colours_details( + layers: np.ndarray, + ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + layer_1, layer_2, layer_3 = layers + for i in range(4): + y = ghost_pos[i][0] + x = ghost_pos[i][1] + c = x * n + 1 + r = y * n + 1 + + layer_1[c, r] = cr[i] + layer_2[c, r] = cg[i] + layer_3[c, r] = cb[i] + + # Make notch in top + layer_1[c - 1, r - 1] = 0.0 + layer_2[c - 1, r - 1] = 0.0 + layer_3[c - 1, r - 1] = 0.0 + + # Make notch in top + layer_1[c - 1, r + 1] = 0.0 + layer_2[c - 1, r + 1] = 0.0 + layer_3[c - 1, r + 1] = 0.0 + + # Eyes + layer_1[c, r + 1] = 1 + layer_2[c, r + 1] = 1 + layer_3[c, r + 1] = 1 + + layer_1[c, r - 1] = 1 + layer_2[c, r - 1] = 1 + layer_3[c, r - 1] = 1 + + return layer_1, layer_2, layer_3 + + def set_ghost_colours_scared_details( + layers: np.ndarray, + ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + layer_1, layer_2, layer_3 = layers + for i in range(4): + y = ghost_pos[i][0] + x = ghost_pos[i][1] + + c = x * n + 1 + r = y * n + 1 + + layer_1[x * n + 1, y * n + 1] = 0 + layer_2[x * n + 1, y * n + 1] = 0 + layer_3[x * n + 1, y * n + 1] = 1 + + # Make notch in top + layer_1[c - 1, r - 1] = 0.0 + layer_2[c - 1, r - 1] = 0.0 + layer_3[c - 1, r - 1] = 0.0 + + # Make notch in top + layer_1[c - 1, r + 1] = 0.0 + layer_2[c - 1, r + 1] = 0.0 + layer_3[c - 1, r + 1] = 0.0 + + # Eyes + layer_1[c, r + 1] = 1 + 
layer_2[c, r + 1] = 0.6 + layer_3[c, r + 1] = 0.2 + + layer_1[c, r - 1] = 1 + layer_2[c, r - 1] = 0.6 + layer_3[c, r - 1] = 0.2 + + return layer_1, layer_2, layer_3 + + if is_scared > 0: + layers = set_ghost_colours_scared_details(layers) + else: + layers = set_ghost_colours_details(layers) + + layer_1, layer_2, layer_3 = layers + + # Power pellet is pink + for i in range(len(pellets_loc)): + p = pellets_loc[i] + layer_1[p[1] * n + 2, p[0] * n + 1] = 1 + layer_2[p[1] * n + 1, p[0] * n + 1] = 0.8 + layer_3[p[1] * n + 1, p[0] * n + 1] = 0.6 + + # Set player is yellow + layer_1[player_loc.x * n + 1, player_loc.y * n + 1] = 1 + layer_2[player_loc.x * n + 1, player_loc.y * n + 1] = 1 + layer_3[player_loc.x * n + 1, player_loc.y * n + 1] = 0 + + obs = [layer_1, layer_2, layer_3] + rgb = np.stack(obs, axis=-1) + expand_rgb + + return rgb + + +def visualize_grid(grid, fig, ax, add_colorbar=False): + # Check if the grid has the correct dimensions + if grid.shape != (21, 19): + raise ValueError("Grid must be 21x19 in size") + + # Create a custom colormap for values between 0 and 1 + cmap = mcolors.LinearSegmentedColormap.from_list('custom_cmap', ['white', '#f39c12']) # Using a brighter blue + + # Create a masked array for the grid + masked_grid = np.ma.masked_where(grid == -1, grid) + + # Plot the grid + if fig is None: + fig, ax = plt.subplots() + cax = ax.imshow(masked_grid, cmap=cmap, vmin=0, vmax=1) + + # Add color bar + if add_colorbar: + cbar = fig.colorbar(cax, ax=ax, fraction=0.046, pad=0.04) + cbar.set_label('Value') + + # Plot the -1 values as black + for (i, j), value in np.ndenumerate(grid): + if value == -1: + ax.add_patch(plt.Rectangle((j - 0.5, i - 0.5), 1, 1, color='black')) + + # Set grid lines + ax.set_xticks(np.arange(-.5, 19, 1), minor=True) + ax.set_yticks(np.arange(-.5, 21, 1), minor=True) + ax.grid(which='minor', color='gray', linestyle='-', linewidth=0.5) + ax.tick_params(which='minor', size=0) + + if fig is None: + # Display the grid + plt.show() + + +def load_info(results_path: Path) -> dict: + return np.load(results_path, allow_pickle=True).item() + + +def pack_states(states: list[dict]) -> list[State]: + packed_states = [] + for s in states: + dict_s = {} + for k, v in list(s.items()): + if isinstance(v, dict): + assert 'x' in v and 'y' in v + dict_s[k] = Position(x=v['x'], y=v['y']) + else: + assert isinstance(v, np.ndarray) or isinstance(v, list) + dict_s[k] = v + packed_states.append(State(**dict_s)) + return packed_states + + +if __name__ == "__main__": + traj_path = Path('../../results/pocman_pellet_probe_trajectory_bidx_1.npy') + save_vod_path = Path('../../results/pocman_pellet_probe_trajectory_bidx_1.mp4') + + dataset = load_info(traj_path) + states, predictions = pack_states(dataset['states']), dataset['predictions'] + + # now we make our animation + fig, axes = plt.subplots(3, num=f"PocmanPredictionAnimation", figsize=(4, 12)) + + def make_frame(idx: int) -> None: + state_ax, p0_ax, p1_ax = axes + for a in axes: + a.clear() + + # First we make our state image + state = states[idx] + state_img = create_grid_image(state) + state_ax.set_axis_off() + state_ax.imshow(state_img) + + # Now we make our prediction images + prediction0, prediction1 = predictions[idx] + visualize_grid(prediction0, fig, p0_ax, add_colorbar=(idx==0)) + visualize_grid(prediction1, fig, p1_ax, add_colorbar=(idx==0)) + + fig.suptitle(f"PacMan Score: {int(state.score)}", size=10) + + + animation = matplotlib.animation.FuncAnimation( + fig, + make_frame, + frames=len(states), + interval=400, + ) 
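+
+    # Writing the .mp4 below goes through matplotlib's ffmpeg-based writer, so
+    # ffmpeg needs to be available on PATH. A minimal fallback sketch (an
+    # assumption, not part of the original script) that writes a GIF instead:
+    #
+    #     writer = matplotlib.animation.PillowWriter(fps=1000 // 400)
+    #     animation.save(save_vod_path.with_suffix('.gif'), writer=writer)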
+
+    # plt.show()
+    # Save the animation as an mp4.
+    animation.save(save_vod_path)
+
+    print(f"Saved animation to {save_vod_path}")
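+
+# A minimal sketch (illustrative only, not part of the original script) for
+# re-loading the saved trajectory elsewhere, using the helpers defined above:
+#
+#     data = load_info(Path('../../results/pocman_pellet_probe_trajectory_bidx_1.npy'))
+#     states = pack_states(data['states'])
+#     pred_maps = data['predictions']  # one (probe_0_map, probe_1_map) pair per step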