Skip to content

Conversation

@simeetnayan81
Copy link
Contributor

Description

Describe your changes in detail.
This PR adds an implementation of the Asynchronous Advantage Actor-Critic (A3C) algorithm for Atari environments in the torchrl/sota-implementations directory. The main files added are:

a3c_atari.py: Contains the A3C worker class, shared optimizer, and main training loop using multiprocessing.
utils_atari.py: Provides utility functions for environment creation, model construction, and evaluation, adapted for Atari tasks.
config_atari.yaml: Configuration file for hyperparameters, environment settings, and logging.

The implementation leverages TorchRL's collectors, objectives, and logging utilities, and is designed to be modular and extensible for research and benchmarking. Some of the utils functions are also borrowed from a2c_atari.

Motivation and Context

This change is required to provide a strong, reproducible baseline for A3C on Atari environments using TorchRL. It enables researchers and practitioners to benchmark and compare reinforcement learning algorithms within the TorchRL ecosystem. The implementation follows best practices for distributed RL and is compatible with TorchRL's API.

This PR solves the issue: #1755

Types of changes

What types of changes does your code introduce? Remove all that do not apply:

  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds core functionality)
  • Breaking change (fix or feature that would cause existing functionality to change)
  • Documentation (update in the documentation)
  • Example (update in the folder of examples)

Checklist

Go over all the following points, and put an x in all the boxes that apply.
If you are unsure about any of these, don't hesitate to ask. We are here to help!

  • I have read the CONTRIBUTION guide (required)
  • My change requires a change to the documentation.
  • I have updated the tests accordingly (required for a bug fix or a new feature).
  • I have updated the documentation accordingly.

@pytorch-bot
Copy link

pytorch-bot bot commented Jun 15, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/rl/3001

Note: Links to docs will display an error until the docs builds have been completed.

❌ 4 New Failures, 2 Cancelled Jobs, 1 Unrelated Failure

As of commit 3ebad77 with merge base fbfd8a8 (image):

NEW FAILURES - The following jobs have failed:

CANCELLED JOBS - The following jobs were cancelled. Please retry:

BROKEN TRUNK - The following job failed but were present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Jun 15, 2025
@vmoens vmoens added the new algo New algorithm request or PR label Jun 16, 2025
Copy link
Collaborator

@vmoens vmoens left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This all looks pretty good!
Could you share a (couple of) learning curve?
Another thing to do before landing is to add it to the sota-implementations CI run:
https://github.com/pytorch/rl/blob/main/.github/unittest/linux_sota/scripts/test_sota.py
Make sure the config passed there is as much barebone as we can - we just want to run the script for a couple of collection / optim iters and make sure it runs without error (not that it properly trains).
We also need to add it to the sota-check runs

@simeetnayan81
Copy link
Contributor Author

Thanks @vmoens . I'll add the required changes as well as some training curves.

@simeetnayan81
Copy link
Contributor Author

@vmoens, I have added the required scripts as well.

Not getting enough resources and time for hyperparam tuning to generate a proper training curve.

Copy link
Collaborator

@vmoens vmoens left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, just a minor comment on the logger!

Comment on lines 210 to 219
logger = get_logger(
cfg.logger.backend,
logger_name="a3c",
experiment_name=exp_name,
wandb_kwargs={
"config": dict(cfg),
"project": cfg.logger.project_name,
"group": cfg.logger.group_name,
},
)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What I usually see is that the logger is only passed to the first worker.
Another thing is that you may want to assume that the logger isn't serializable and should be instantiated locally within the worker.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh yea, I did that because I thought logging any single worker should be a good representative of the global model since anyway the weights are being copied. Logging all the worker might not be really useful but that can be done as well.

num_workers = cfg.multiprocessing.num_workers

if num_workers is None:
num_workers = mp.cpu_count()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we should have way fewer workers - I think we need users to tell us how many.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That can be configured in the config_atari. You want me to explicitly set it to some constant here?

data_reshape = data.reshape(-1)
losses = []

mini_batches = data_reshape.split(self.mini_batch_size)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To shuffle things a bit I usually rely on a replay buffer instance rather than just splitting the data

Comment on lines +90 to +135
for local_param, global_param in zip(
self.local_actor.parameters(), self.global_actor.parameters()
):
global_param._grad = local_param.grad

for local_param, global_param in zip(
self.local_critic.parameters(), self.global_critic.parameters()
):
global_param._grad = local_param.grad

gn = torch.nn.utils.clip_grad_norm_(
self.loss_module.parameters(), max_norm=max_grad_norm
)
Copy link
Collaborator

@vmoens vmoens Jul 10, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you explain what we do here? What do we use the _grad for?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

_grad is used to store the gradients for each parameter.
We copy local gradients to the global model so the global model can be updated with the optimizer.
This is a key step in A3C, where multiple workers asynchronously update a shared global model.

torch.set_float32_matmul_precision("high")


class SharedAdam(torch.optim.Adam):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

shouldn't we move this to the utils file?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, will do it

@vmoens vmoens force-pushed the a3c-implementation branch from d95de87 to a6eb18d Compare July 10, 2025 04:32
Copy link
Collaborator

@vmoens vmoens left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I made a few edits.
Can you explain the way the params are shared and updated? I'm not sure I see the logic

@simeetnayan81
Copy link
Contributor Author

I made a few edits. Can you explain the way the params are shared and updated? I'm not sure I see the logic

There is a global model (shared across all workers) and a local model (each worker has its own copy).
The global model’s parameters are placed in shared memory so all workers can access and update them.
Each worker interacts with its own environment using its local model.
After collecting experience, the worker computes gradients (via backpropagation) on its local model.
The gradients from the local model are copied to the corresponding parameters of the global model (global_param._grad = local_param.grad).
The optimizer (e.g., Adam) is called on the global model, using the gradients just copied from the local model.
This updates the global model’s parameters.
The local model can be periodically updated (synced) from the global model to keep it up-to-date.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. new algo New algorithm request or PR

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants