From 8fd790326bcd43be79e158e0dc3d824f5c9fcfd8 Mon Sep 17 00:00:00 2001 From: Daphne Cornelisse Date: Sat, 5 Apr 2025 09:49:00 -0400 Subject: [PATCH 1/5] Implement advantage filtering - zero-out transitions below threshold --- baselines/ppo/config/ppo_base_puffer.yaml | 4 + baselines/ppo/ppo_pufferlib.py | 7 ++ gpudrive/integrations/puffer/ppo.py | 130 +++++++++++++++++++--- 3 files changed, 126 insertions(+), 15 deletions(-) diff --git a/baselines/ppo/config/ppo_base_puffer.yaml b/baselines/ppo/config/ppo_base_puffer.yaml index 9f985667a..e1cc33080 100644 --- a/baselines/ppo/config/ppo_base_puffer.yaml +++ b/baselines/ppo/config/ppo_base_puffer.yaml @@ -78,6 +78,10 @@ train: max_grad_norm: 0.5 target_kl: null log_window: 1000 + # Advantage filtering + apply_advantage_filter: true + initial_th_factor: 0.01 + beta: 0.25 # # # Network # # # network: diff --git a/baselines/ppo/ppo_pufferlib.py b/baselines/ppo/ppo_pufferlib.py index 587fb96c9..3c2b5a3ea 100644 --- a/baselines/ppo/ppo_pufferlib.py +++ b/baselines/ppo/ppo_pufferlib.py @@ -189,6 +189,10 @@ def run( minibatch_size: Annotated[Optional[int], typer.Option(help="The minibatch size for training")] = None, gamma: Annotated[Optional[float], typer.Option(help="The discount factor for rewards")] = None, vf_coef: Annotated[Optional[float], typer.Option(help="Weight for vf_loss")] = None, + # Advantage filtering + apply_advantage_filter: Annotated[Optional[int], typer.Option(help="Whether to use advantage filter; 0 or 1")] = None, + initial_th_factor: Annotated[Optional[float], typer.Option(help="Initial threshold factor for training")] = None, + beta: Annotated[Optional[float], typer.Option(help="Beta parameter for training")] = None, # Wandb logging options project: Annotated[Optional[str], typer.Option(help="WandB project name")] = None, entity: Annotated[Optional[str], typer.Option(help="WandB entity name")] = None, @@ -242,6 +246,9 @@ def run( "render": None if render is None else bool(render), "gamma": gamma, "vf_coef": vf_coef, + "apply_advantage_filter": apply_advantage_filter, + "initial_th_factor": initial_th_factor, + "beta": beta, } config.train.update( {k: v for k, v in train_config.items() if v is not None} diff --git a/gpudrive/integrations/puffer/ppo.py b/gpudrive/integrations/puffer/ppo.py index bdc65ed45..9c14bfb8e 100644 --- a/gpudrive/integrations/puffer/ppo.py +++ b/gpudrive/integrations/puffer/ppo.py @@ -33,6 +33,59 @@ from gpudrive.integrations.puffer.logging import print_dashboard, abbreviate +class AdvantageFilter: + """ + Advantage filtering class to filter transitions based on advantage magnitude. + + This implementation is based on Algorithm 1 in "Robust Autonomy Emerges from Self-Play" + (https://arxiv.org/abs/2502.03349). The key idea is to discard transitions with + low-magnitude advantages to focus training on the most informative samples. + + The filtering threshold η is set to a percentage of the maximum advantage + magnitude observed so far, making it scale-invariant to reward magnitudes. + """ + + def __init__(self, beta=0.25, initial_th_factor=0.01): + """ + Args: + beta: EWMA decay factor for tracking maximum advantage + initial_th_factor: Filter threshold as a percentage of max advantage + """ + self.beta = beta + self.threshold_factor = initial_th_factor + self.max_advantage_ewma = None + + def filter(self, advantages_np): + """ + Filter transitions based on advantage magnitude. 
+ + Args: + advantages_np: Numpy array of advantages + + Returns: + Boolean mask where True indicates transitions to keep + """ + # Get new max advantage + max_advantage = float(np.max(np.abs(advantages_np))) + + # Update the EWMA of max advantage + if self.max_advantage_ewma is None: + self.max_advantage_ewma = max_advantage + else: + self.max_advantage_ewma = ( + self.beta * max_advantage + + (1 - self.beta) * self.max_advantage_ewma + ) + + # Update filtering threshold + threshold = self.threshold_factor * self.max_advantage_ewma + + # Create mask of transitions to keep (where |advantage| >= threshold) + mask = np.abs(advantages_np) >= threshold + + return mask, threshold + + def create(config, vecenv, policy, optimizer=None, wandb=None): seed_everything(config.seed, config.torch_deterministic) profile = Profile() @@ -235,13 +288,50 @@ def train(data): losses = data.losses with profile.train_misc: + # Get the sorted indices for training data idxs = experience.sort_training_data() dones_np = experience.dones_np[idxs] values_np = experience.values_np[idxs] rewards_np = experience.rewards_np[idxs] + + # Compute GAE advantages advantages_np = compute_gae( dones_np, values_np, rewards_np, config.gamma, config.gae_lambda ) + + if config.apply_advantage_filter: + # Initialize the advantage filter if not already created + if not hasattr(data, "advantage_filter"): + data.advantage_filter = AdvantageFilter( + beta=config.beta, + initial_th_factor=config.initial_th_factor, + ) + + # Get mask of transitions to keep based on advantage magnitude + mask, threshold = data.advantage_filter.filter(advantages_np) + + # Apply weights to advantages - this zeroes out filtered transitions + # but keeps the array the same size and shape + advantages_np = advantages_np * mask.astype(np.float32) + + # Log filtering stats + num_total = len(advantages_np) + num_kept = mask.sum() + percent_kept = 100 * num_kept / num_total if num_total > 0 else 0 + data.msg = f"Advantage filtering: kept {num_kept}/{num_total} transitions ({percent_kept:.1f}%)" + + if not hasattr(data, "filtering_stats"): + data.filtering_stats = [] + + data.filtering_stats.append( + { + "threshold": threshold, + "percent_kept": percent_kept, + "max_advantage": float(np.max(np.abs(advantages_np))), + "global_step": data.global_step, + } + ) + experience.flatten_batch(advantages_np) # Optimizing the policy and value network @@ -381,22 +471,32 @@ def train(data): and data.global_step > 0 and time.perf_counter() - data.last_log_time > 3.0 ): - data.last_log_time = time.perf_counter() - data.wandb.log( - { - "performance/controlled_agent_sps": profile.controlled_agent_sps, - "performance/controlled_agent_sps_env": profile.controlled_agent_sps_env, - "performance/pad_agent_sps": profile.pad_agent_sps, - "performance/pad_agent_sps_env": profile.pad_agent_sps_env, - "global_step": data.global_step, - "performance/epoch": data.epoch, - "performance/uptime": profile.uptime, - "train/learning_rate": data.optimizer.param_groups[0]["lr"], - **{f"metrics/{k}": v for k, v in data.stats.items()}, - **{f"train/{k}": v for k, v in data.losses.items()}, - } - ) + + # Create log dictionary with existing metrics + log_dict = { + "performance/controlled_agent_sps": profile.controlled_agent_sps, + "performance/controlled_agent_sps_env": profile.controlled_agent_sps_env, + "performance/pad_agent_sps": profile.pad_agent_sps, + "performance/pad_agent_sps_env": profile.pad_agent_sps_env, + "global_step": data.global_step, + "performance/epoch": data.epoch, + 
"performance/uptime": profile.uptime, + "train/learning_rate": data.optimizer.param_groups[0]["lr"], + **{f"metrics/{k}": v for k, v in data.stats.items()}, + **{f"train/{k}": v for k, v in data.losses.items()}, + } + + # Add advantage filtering metrics if available + if hasattr(data, 'filtering_stats') and data.filtering_stats: + latest_stats = data.filtering_stats[-1] + log_dict.update({ + "advantage_filtering/threshold (η)": latest_stats['threshold'], + "advantage_filtering/percent_kept": latest_stats['percent_kept'], + "advantage_filtering/max_advantage": latest_stats['max_advantage'] + }) + + data.wandb.log(log_dict) if bool(data.stats): data.wandb.log({ From 773c6f3ed12da5d8ec02266809cfa9aa9e71ca63 Mon Sep 17 00:00:00 2001 From: Daphne Cornelisse Date: Sat, 5 Apr 2025 10:46:22 -0400 Subject: [PATCH 2/5] Store config with checkpoint --- gpudrive/integrations/puffer/ppo.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/gpudrive/integrations/puffer/ppo.py b/gpudrive/integrations/puffer/ppo.py index 9c14bfb8e..eed47775d 100644 --- a/gpudrive/integrations/puffer/ppo.py +++ b/gpudrive/integrations/puffer/ppo.py @@ -318,7 +318,9 @@ def train(data): num_total = len(advantages_np) num_kept = mask.sum() percent_kept = 100 * num_kept / num_total if num_total > 0 else 0 - data.msg = f"Advantage filtering: kept {num_kept}/{num_total} transitions ({percent_kept:.1f}%)" + data.msg = ( + f"Advantage filtering: kept {percent_kept:.1f}% of transitions" + ) if not hasattr(data, "filtering_stats"): data.filtering_stats = [] @@ -504,7 +506,6 @@ def train(data): }) # fmt: on - if data.epoch % config.checkpoint_interval == 0 or done_training: save_checkpoint(data) data.msg = f"Checkpoint saved at update {data.epoch}" @@ -816,6 +817,7 @@ def save_checkpoint(data, save_checkpoint_to_wandb=True): "action_dim": data.uncompiled_policy.action_dim, "exp_id": config.exp_id, "num_params": config.network["num_parameters"], + "config": data.config.to_dict(), } torch.save(state, model_path) From d2f67a52eef385f9115ae1d6d8e362697ac90d76 Mon Sep 17 00:00:00 2001 From: Daphne Cornelisse Date: Mon, 7 Apr 2025 08:47:45 -0400 Subject: [PATCH 3/5] Logging stats --- baselines/ppo/config/ppo_base_puffer.yaml | 7 ++-- gpudrive/integrations/puffer/ppo.py | 44 ++++++++++++++--------- 2 files changed, 31 insertions(+), 20 deletions(-) diff --git a/baselines/ppo/config/ppo_base_puffer.yaml b/baselines/ppo/config/ppo_base_puffer.yaml index e1cc33080..a9b6e9539 100644 --- a/baselines/ppo/config/ppo_base_puffer.yaml +++ b/baselines/ppo/config/ppo_base_puffer.yaml @@ -38,13 +38,13 @@ environment: # Overrides default environment configs (see pygpudrive/env/config. 
wandb: entity: "" - project: "gpudrive" - group: "test" + project: "adv_filter" + group: "testing" mode: "online" # Options: online, offline, disabled tags: ["ppo", "ff"] train: - exp_id: PPO # Set dynamically in the script if needed + exp_id: adv_filter # Set dynamically in the script if needed seed: 42 cpu_offload: false device: "cuda" # Dynamically set to cuda if available, else cpu @@ -63,6 +63,7 @@ train: torch_deterministic: false total_timesteps: 1_000_000_000 batch_size: 131_072 + num_minibatches: 16 minibatch_size: 8192 learning_rate: 3e-4 anneal_lr: false diff --git a/gpudrive/integrations/puffer/ppo.py b/gpudrive/integrations/puffer/ppo.py index eed47775d..794334d0d 100644 --- a/gpudrive/integrations/puffer/ppo.py +++ b/gpudrive/integrations/puffer/ppo.py @@ -63,7 +63,8 @@ def filter(self, advantages_np): advantages_np: Numpy array of advantages Returns: - Boolean mask where True indicates transitions to keep + mask: Boolean mask where True indicates transitions to keep + threshold: Current filtering threshold (η) """ # Get new max advantage max_advantage = float(np.max(np.abs(advantages_np))) @@ -299,7 +300,13 @@ def train(data): dones_np, values_np, rewards_np, config.gamma, config.gae_lambda ) + filter_mask = None if config.apply_advantage_filter: + if config.bptt_horizon > 1: + raise ValueError( + "Advantage filtering cannot be used with LSTM (bptt_horizon > 1)" + ) + # Initialize the advantage filter if not already created if not hasattr(data, "advantage_filter"): data.advantage_filter = AdvantageFilter( @@ -308,32 +315,35 @@ def train(data): ) # Get mask of transitions to keep based on advantage magnitude - mask, threshold = data.advantage_filter.filter(advantages_np) + filter_mask, threshold = data.advantage_filter.filter( + advantages_np + ) # Apply weights to advantages - this zeroes out filtered transitions # but keeps the array the same size and shape - advantages_np = advantages_np * mask.astype(np.float32) + advantages_np = advantages_np * filter_mask.astype(np.float32) - # Log filtering stats + # Log stats num_total = len(advantages_np) - num_kept = mask.sum() + num_kept = filter_mask.sum() percent_kept = 100 * num_kept / num_total if num_total > 0 else 0 - data.msg = ( - f"Advantage filtering: kept {percent_kept:.1f}% of transitions" - ) - if not hasattr(data, "filtering_stats"): - data.filtering_stats = [] + data.filtering_stats = { + "advantage_filtering/threshold (η)": threshold, + "advantage_filtering/percent_kept": percent_kept, + "advantage_filtering/max_advantage": float( + np.max(np.abs(advantages_np)) + ), + "advantage_filtering/num_kept": int(num_kept), + "advantage_filtering/total": int(num_total), + } - data.filtering_stats.append( - { - "threshold": threshold, - "percent_kept": percent_kept, - "max_advantage": float(np.max(np.abs(advantages_np))), - "global_step": data.global_step, - } + data.msg = ( + f"Advantage filtering: kept {num_kept}/{num_total} transitions " + f"({percent_kept:.1f}%, threshold (η)={threshold:.4f})" ) + # experience.flatten_batch(advantages_np, filter_mask) experience.flatten_batch(advantages_np) # Optimizing the policy and value network From c2c457bc7512f557e92cbb915c49a6669397da67 Mon Sep 17 00:00:00 2001 From: Daphne Cornelisse Date: Mon, 7 Apr 2025 09:02:08 -0400 Subject: [PATCH 4/5] Take in the number of minibatches instead of enforcing a fixed minibatch_size --- baselines/ppo/config/ppo_base_puffer.yaml | 1 - baselines/ppo/ppo_pufferlib.py | 6 +++--- gpudrive/integrations/puffer/ppo.py | 19 ++++++++----------- 
gpudrive/utils/generate_sbatch.py | 15 ++++++--------- 4 files changed, 17 insertions(+), 24 deletions(-) diff --git a/baselines/ppo/config/ppo_base_puffer.yaml b/baselines/ppo/config/ppo_base_puffer.yaml index a9b6e9539..8f8538847 100644 --- a/baselines/ppo/config/ppo_base_puffer.yaml +++ b/baselines/ppo/config/ppo_base_puffer.yaml @@ -64,7 +64,6 @@ train: total_timesteps: 1_000_000_000 batch_size: 131_072 num_minibatches: 16 - minibatch_size: 8192 learning_rate: 3e-4 anneal_lr: false gamma: 0.99 diff --git a/baselines/ppo/ppo_pufferlib.py b/baselines/ppo/ppo_pufferlib.py index 3c2b5a3ea..6f0e2a550 100644 --- a/baselines/ppo/ppo_pufferlib.py +++ b/baselines/ppo/ppo_pufferlib.py @@ -144,7 +144,7 @@ def sweep(args, project="PPO", sweep_name="my_sweep"): "max": 1e-1, }, "batch_size": {"values": [512, 1024, 2048]}, - "minibatch_size": {"values": [128, 256, 512]}, + "num_minibatches": {"values": [4, 8, 16]}, }, ), project=project, @@ -186,7 +186,7 @@ def run( ent_coef: Annotated[Optional[float], typer.Option(help="Entropy coefficient")] = None, update_epochs: Annotated[Optional[int], typer.Option(help="The number of epochs for updating the policy")] = None, batch_size: Annotated[Optional[int], typer.Option(help="The batch size for training")] = None, - minibatch_size: Annotated[Optional[int], typer.Option(help="The minibatch size for training")] = None, + num_minibatches: Annotated[Optional[int], typer.Option(help="The number of minibatches for training")] = None, gamma: Annotated[Optional[float], typer.Option(help="The discount factor for rewards")] = None, vf_coef: Annotated[Optional[float], typer.Option(help="Weight for vf_loss")] = None, # Advantage filtering @@ -242,7 +242,7 @@ def run( "ent_coef": ent_coef, "update_epochs": update_epochs, "batch_size": batch_size, - "minibatch_size": minibatch_size, + "num_minibatches": num_minibatches, "render": None if render is None else bool(render), "gamma": gamma, "vf_coef": vf_coef, diff --git a/gpudrive/integrations/puffer/ppo.py b/gpudrive/integrations/puffer/ppo.py index 794334d0d..2163f45a5 100644 --- a/gpudrive/integrations/puffer/ppo.py +++ b/gpudrive/integrations/puffer/ppo.py @@ -116,7 +116,7 @@ def create(config, vecenv, policy, optimizer=None, wandb=None): experience = Experience( config.batch_size, config.bptt_horizon, - config.minibatch_size, + config.num_minibatches, obs_shape, obs_dtype, atn_shape, @@ -501,12 +501,7 @@ def train(data): # Add advantage filtering metrics if available if hasattr(data, 'filtering_stats') and data.filtering_stats: - latest_stats = data.filtering_stats[-1] - log_dict.update({ - "advantage_filtering/threshold (η)": latest_stats['threshold'], - "advantage_filtering/percent_kept": latest_stats['percent_kept'], - "advantage_filtering/max_advantage": latest_stats['max_advantage'] - }) + log_dict.update(data.filtering_stats) data.wandb.log(log_dict) @@ -645,7 +640,7 @@ def __init__( self, batch_size, bptt_horizon, - minibatch_size, + num_minibatches, obs_shape, obs_dtype, atn_shape, @@ -654,8 +649,8 @@ def __init__( lstm=None, lstm_total_agents=0, ): - if minibatch_size is None: - minibatch_size = batch_size + if num_minibatches is None: + num_minibatches = 1 obs_dtype = pufferlib.pytorch.numpy_to_torch_dtype_dict[obs_dtype] pin = device == "cuda" and cpu_offload @@ -690,8 +685,8 @@ def __init__( self.lstm_h = torch.zeros(shape).to(device) self.lstm_c = torch.zeros(shape).to(device) - num_minibatches = batch_size / minibatch_size self.num_minibatches = int(num_minibatches) + minibatch_size = batch_size // 
num_minibatches if self.num_minibatches != num_minibatches: raise ValueError("batch_size must be divisible by minibatch_size") @@ -712,9 +707,11 @@ def __init__( @property def full(self): + """Check if the buffer is full.""" return self.ptr >= self.batch_size def store(self, obs, value, action, logprob, reward, done, env_id, mask): + """Store a batch of transitions in the buffer.""" # Mask learner and Ensure indices do not exceed batch size ptr = self.ptr indices = torch.where(mask)[0].cpu().numpy()[: self.batch_size - ptr] diff --git a/gpudrive/utils/generate_sbatch.py b/gpudrive/utils/generate_sbatch.py index 95850cca3..0118e5359 100644 --- a/gpudrive/utils/generate_sbatch.py +++ b/gpudrive/utils/generate_sbatch.py @@ -252,21 +252,21 @@ def save_script(filename, file_path, fields, params, param_order=None): "memory": 70, "job_name": group, } - + hyperparams = { - "group": [group], # Group name + "group": [group], # Group name "num_worlds": [800], - "resample_scenes": [1], # Yes + "resample_scenes": [1], # Yes "k_unique_scenes": [800], "resample_interval": [5_000_000], "total_timesteps": [4_000_000_000], "resample_dataset_size": [10_000], "batch_size": [524288], - "minibatch_size": [16384], + "num_minibatches": [16], "update_epochs": [4], "ent_coef": [0.001, 0.003, 0.0001], "render": [0], - #"seed": [42, 3], + # "seed": [42, 3], } save_script( @@ -285,7 +285,7 @@ def save_script(filename, file_path, fields, params, param_order=None): # "total_timesteps": [3_000_000_000], # "resample_dataset_size": [1000], # "batch_size": [262_144, 524_288], - # "minibatch_size": [16_384], + # "num_minibatches": [16_384], # "update_epochs": [2, 4, 5], # "ent_coef": [0.0001, 0.001, 0.003], # "learning_rate": [1e-4, 3e-4], @@ -299,6 +299,3 @@ def save_script(filename, file_path, fields, params, param_order=None): # fields=fields, # params=hyperparams, # ) - - - From 5a93bcaeb3a6fecb0ccddd4946a2f91517de3b1b Mon Sep 17 00:00:00 2001 From: Daphne Cornelisse Date: Mon, 7 Apr 2025 10:12:09 -0400 Subject: [PATCH 5/5] Filter out transitions associated with low advantages by making the mb size dynamic --- gpudrive/integrations/puffer/ppo.py | 126 ++++++++++++++++++++++------ 1 file changed, 99 insertions(+), 27 deletions(-) diff --git a/gpudrive/integrations/puffer/ppo.py b/gpudrive/integrations/puffer/ppo.py index 2163f45a5..9c9bb8b4e 100644 --- a/gpudrive/integrations/puffer/ppo.py +++ b/gpudrive/integrations/puffer/ppo.py @@ -12,6 +12,7 @@ import random import psutil import time +import warnings from threading import Thread from collections import defaultdict, deque @@ -319,10 +320,6 @@ def train(data): advantages_np ) - # Apply weights to advantages - this zeroes out filtered transitions - # but keeps the array the same size and shape - advantages_np = advantages_np * filter_mask.astype(np.float32) - # Log stats num_total = len(advantages_np) num_kept = filter_mask.sum() @@ -343,8 +340,8 @@ def train(data): f"({percent_kept:.1f}%, threshold (η)={threshold:.4f})" ) - # experience.flatten_batch(advantages_np, filter_mask) - experience.flatten_batch(advantages_np) + # Prepare batch of transitions for model updating + experience.flatten_batch(advantages_np, filter_mask) # Optimizing the policy and value network num_update_iters = config.update_epochs * experience.num_minibatches @@ -632,7 +629,6 @@ def make_losses(): explained_variance=0, ) - class Experience: """Flat tensor storage (buffer) and array views for faster indexing.""" @@ -751,27 +747,103 @@ def sort_training_data(self): self.step = 0 return idxs - def 
flatten_batch(self, advantages_np):
-        advantages = torch.from_numpy(advantages_np).to(self.device)
-        b_idxs, b_flat = self.b_idxs, self.b_idxs_flat
-        self.b_actions = self.actions.to(self.device, non_blocking=True)
-        self.b_logprobs = self.logprobs.to(self.device, non_blocking=True)
-        self.b_dones = self.dones.to(self.device, non_blocking=True)
-        self.b_values = self.values.to(self.device, non_blocking=True)
-        self.b_advantages = (
-            advantages.reshape(
-                self.minibatch_rows, self.num_minibatches, self.bptt_horizon
-            )
-            .transpose(0, 1)
-            .reshape(self.num_minibatches, self.minibatch_size)
-        )
-        self.returns_np = advantages_np + self.values_np
-        self.b_obs = self.obs[self.b_idxs_obs]
-        self.b_actions = self.b_actions[b_idxs].contiguous()
-        self.b_logprobs = self.b_logprobs[b_idxs]
-        self.b_dones = self.b_dones[b_idxs]
-        self.b_values = self.b_values[b_flat]
-        self.b_returns = self.b_advantages + self.b_values
+    def flatten_batch(self, advantages_np, filter_mask=None):
+        """Prepare the batch of transitions for model updating."""
+
+        if filter_mask is not None:
+
+            # Get the indices of transitions to keep
+            kept_indices = np.nonzero(filter_mask)[0]
+            total_kept = len(kept_indices)
+
+            # Determine how many transitions per minibatch (floor division)
+            transitions_per_mb = total_kept // self.num_minibatches
+
+            # Enforce a floor on the number of transitions per minibatch and
+            # warn the user when the retention rate is very low
+            if transitions_per_mb < 32:
+                warnings.warn(
+                    f"Low advantage-filtering retention rate: only kept "
+                    f"{total_kept} / {len(advantages_np)} transitions "
+                    f"({transitions_per_mb} per minibatch). Consider lowering "
+                    f"the advantage threshold factor or increasing the batch_size.",
+                    UserWarning,
+                )
+                transitions_per_mb = 64
+
+            # Calculate total transitions to use (divisible by num_minibatches)
+            transitions_to_use = transitions_per_mb * self.num_minibatches
+
+            # If fewer transitions survived the filter than we need, sample
+            # the kept indices with replacement so the reshape below is valid
+            if total_kept < transitions_to_use:
+                kept_indices = np.random.choice(
+                    kept_indices, transitions_to_use, replace=True
+                )
+                total_kept = len(kept_indices)
+
+            np.random.shuffle(kept_indices)
+            kept_indices = kept_indices[:transitions_to_use]
+            filtered_idxs = kept_indices.copy()
+
+            # Reshape to (minibatch_rows, num_minibatches, bptt_horizon)
+            minibatch_rows_filtered = transitions_to_use // (self.num_minibatches * self.bptt_horizon)
+            filtered_idxs = filtered_idxs.reshape(
+                minibatch_rows_filtered,
+                self.num_minibatches,
+                self.bptt_horizon
+            )
+            filtered_idxs = np.transpose(filtered_idxs, (1, 0, 2))
+
+            # Update minibatch indices
+            self.b_idxs_obs = torch.as_tensor(filtered_idxs).to(self.obs.device).long()
+            self.b_idxs = self.b_idxs_obs.to(self.device)
+            self.b_idxs_flat = self.b_idxs.reshape(self.num_minibatches, -1)
+
+            # Get advantages for the filtered transitions
+            advantages = torch.from_numpy(advantages_np).to(self.device)
+
+            # The rest of the processing is similar to the original code
+            b_idxs, b_flat = self.b_idxs, self.b_idxs_flat
+            self.b_actions = self.actions.to(self.device, non_blocking=True)
+            self.b_logprobs = self.logprobs.to(self.device, non_blocking=True)
+            self.b_dones = self.dones.to(self.device, non_blocking=True)
+            self.b_values = self.values.to(self.device, non_blocking=True)
+
+            # Gather advantages in the same (minibatch, transition) layout as
+            # the index tensors so they stay aligned with obs/actions/values
+            self.b_advantages = advantages[self.b_idxs_flat]
+
+            # Compute returns
+            self.returns_np = advantages_np + self.values_np
+
+            # Get observations, actions, etc. based on filtered indices
+            self.b_obs = self.obs[self.b_idxs_obs]
+            self.b_actions = self.b_actions[b_idxs].contiguous()
+            self.b_logprobs = self.b_logprobs[b_idxs]
+            self.b_dones = self.b_dones[b_idxs]
+            self.b_values = self.b_values[b_flat]
+            self.b_returns = self.b_advantages + self.b_values
+
+        else:
+            # Original implementation for when no filtering is applied
+            advantages = torch.from_numpy(advantages_np).to(self.device)
+
+            # Get the respective indices for all minibatches
+            b_idxs, b_flat = self.b_idxs, self.b_idxs_flat
+            self.b_actions = self.actions.to(self.device, non_blocking=True)
+            self.b_logprobs = self.logprobs.to(self.device, non_blocking=True)
+            self.b_dones = self.dones.to(self.device, non_blocking=True)
+            self.b_values = self.values.to(self.device, non_blocking=True)
+            self.b_advantages = (
+                advantages.reshape(
+                    self.minibatch_rows, self.num_minibatches, self.bptt_horizon
+                )
+                .transpose(0, 1)
+                .reshape(self.num_minibatches, self.minibatch_size)
+            )
+
+            # Re-order the transitions based on the sorted indices
+            self.returns_np = advantages_np + self.values_np
+            self.b_obs = self.obs[self.b_idxs_obs]
+            self.b_actions = self.b_actions[b_idxs].contiguous()
+            self.b_logprobs = self.b_logprobs[b_idxs]
+            self.b_dones = self.b_dones[b_idxs]
+            self.b_values = self.b_values[b_flat]
+            self.b_returns = self.b_advantages + self.b_values
 
 class Utilization(Thread):
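
For reference, a minimal standalone sketch of the filtering rule these patches implement: a keep-mask on |advantage| against a threshold η tracked as a fraction of an EWMA of the per-batch maximum. The function name, module-level state, and example batch below are illustrative only and do not appear in the codebase.

import numpy as np

BETA = 0.25        # weight given to the newest batch maximum in the EWMA
TH_FACTOR = 0.01   # threshold as a fraction of the tracked max advantage

max_adv_ewma = None

def advantage_keep_mask(advantages):
    """Return (keep_mask, threshold eta) for one batch of GAE advantages."""
    global max_adv_ewma
    batch_max = float(np.max(np.abs(advantages)))
    # Track the running maximum advantage magnitude with an EWMA so the
    # threshold adapts to the observed advantage scale
    if max_adv_ewma is None:
        max_adv_ewma = batch_max
    else:
        max_adv_ewma = BETA * batch_max + (1 - BETA) * max_adv_ewma
    threshold = TH_FACTOR * max_adv_ewma
    return np.abs(advantages) >= threshold, threshold

adv = np.array([0.001, -0.002, 0.5, -1.2, 0.0003, 2.0])
mask, eta = advantage_keep_mask(adv)
print(f"eta={eta:.4f}, kept {mask.sum()}/{len(adv)} transitions")

Because η scales with the observed advantage magnitude, the filter stays invariant to reward scale, as noted in the AdvantageFilter docstring.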