[Feature, Example] A3C Atari Implementation for TorchRL #3001
simeetnayan81 wants to merge 17 commits into pytorch:main
Conversation
🔗 Helpful links: see artifacts and rendered test results at hud.pytorch.org/pr/pytorch/rl/3001
Note: links to docs will display an error until the docs builds have completed.
CI status as of commit 8e7f96c (merge base c764978): 17 workflows awaiting approval before CI can run; 1 new job failure. (Generated by Dr. CI.)
vmoens
left a comment
This all looks pretty good!
Could you share a learning curve (or a couple)?
Another thing to do before landing is to add it to the sota-implementations CI run:
https://github.com/pytorch/rl/blob/main/.github/unittest/linux_sota/scripts/test_sota.py
Make sure the config passed there is as barebones as possible: we just want to run the script for a couple of collection/optim iterations and make sure it runs without error (not that it trains properly).
We also need to add it to the sota-check runs
Thanks @vmoens. I'll add the required changes as well as some training curves.
@vmoens, I have added the required scripts as well. I'm not getting enough resources and time for hyperparameter tuning to generate a proper training curve.
vmoens
left a comment
LGTM, just a minor comment on the logger!
```python
logger = get_logger(
    cfg.logger.backend,
    logger_name="a3c",
    experiment_name=exp_name,
    wandb_kwargs={
        "config": dict(cfg),
        "project": cfg.logger.project_name,
        "group": cfg.logger.group_name,
    },
)
```
What I usually see is that the logger is only passed to the first worker.
Another thing: you may want to assume that the logger isn't serializable and should be instantiated locally within the worker.
Oh yeah, I did that because I thought logging any single worker should be a good representative of the global model, since the weights are copied anyway. Logging all the workers might not be very useful, but that can be done as well.
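A minimal sketch of the pattern the reviewer suggests, with a hypothetical `rank` argument identifying the worker: only the first worker builds a logger, and the import happens inside the function so the (possibly non-picklable) logger never crosses a process boundary.

```python
def make_worker_logger(rank, backend="wandb", exp_name="a3c"):
    """Instantiate the logger locally inside the worker process.

    Only the first worker (rank 0) logs; the logger object itself
    is never pickled and sent to other processes.
    """
    if rank != 0:
        return None
    # Imported here so the logger machinery is only touched inside
    # the single worker that actually logs.
    from torchrl.record.loggers import get_logger

    return get_logger(backend, logger_name="a3c", experiment_name=exp_name)
```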
```python
num_workers = cfg.multiprocessing.num_workers

if num_workers is None:
    num_workers = mp.cpu_count()
```
we should have way fewer workers - I think we need users to tell us how many.
That can be configured in config_atari. Do you want me to explicitly set it to some constant here?
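One way to reconcile both points, sketched with a hypothetical helper: require an explicit value from the config, and cap the `cpu_count()` fallback so a bare run never spawns dozens of workers.

```python
import multiprocessing as mp


def resolve_num_workers(num_workers, cap=8):
    """Resolve the worker count, falling back to a capped CPU count."""
    if num_workers is None:
        # Conservative default: never more than `cap` workers.
        return min(mp.cpu_count(), cap)
    if num_workers < 1:
        raise ValueError("num_workers must be >= 1")
    return num_workers
```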
```python
data_reshape = data.reshape(-1)
losses = []

mini_batches = data_reshape.split(self.mini_batch_size)
```
To shuffle things a bit, I usually rely on a replay buffer instance rather than just splitting the data.
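A sketch of the idea: shuffle before splitting, so consecutive transitions are decorrelated across mini-batches. In TorchRL the same effect is typically obtained by pushing the batch through a `ReplayBuffer` with a `SamplerWithoutReplacement`; the plain-tensor version below just illustrates the principle.

```python
import torch


def shuffled_minibatches(data, mini_batch_size):
    """Split `data` into mini-batches after a random permutation."""
    perm = torch.randperm(data.shape[0])
    return data[perm].split(mini_batch_size)
```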
```python
for local_param, global_param in zip(
    self.local_actor.parameters(), self.global_actor.parameters()
):
    global_param._grad = local_param.grad

for local_param, global_param in zip(
    self.local_critic.parameters(), self.global_critic.parameters()
):
    global_param._grad = local_param.grad

gn = torch.nn.utils.clip_grad_norm_(
    self.loss_module.parameters(), max_norm=max_grad_norm
)
```
Can you explain what we do here? What do we use `_grad` for?
`_grad` is used to store the gradients for each parameter.
We copy the local gradients to the global model so the global model can be updated by the optimizer.
This is a key step in A3C, where multiple workers asynchronously update a shared global model.
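To make the mechanism concrete, here is a stripped-down sketch (the function name is hypothetical): each global parameter's `.grad` is pointed at the gradient tensor computed on the worker's local replica, so a subsequent `optimizer.step()` on the global parameters consumes the worker's gradients.

```python
import torch


def push_local_grads(local_model, global_model):
    """Expose the worker's gradients to the shared optimizer.

    Assigning to `_grad` sets each global parameter's `.grad` to the
    gradient tensor computed on the local replica.
    """
    for local_param, global_param in zip(
        local_model.parameters(), global_model.parameters()
    ):
        global_param._grad = local_param.grad
```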
```python
torch.set_float32_matmul_precision("high")


class SharedAdam(torch.optim.Adam):
```
Shouldn't we move this to the utils file?
Sure, will do it.
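For reference, the usual shared-Adam pattern looks roughly like this (a sketch, not necessarily the exact code in this PR): pre-populate the optimizer state eagerly and move its moment buffers into shared memory, so every worker process steps the same optimizer state.

```python
import torch


class SharedAdam(torch.optim.Adam):
    """Adam whose moment buffers live in shared memory, so all
    worker processes update a single shared optimizer state."""

    def __init__(self, params, lr=1e-4, betas=(0.9, 0.999), eps=1e-8):
        super().__init__(params, lr=lr, betas=betas, eps=eps)
        for group in self.param_groups:
            for p in group["params"]:
                state = self.state[p]
                # Initialize state eagerly (Adam normally does this
                # lazily on the first step) and share the buffers.
                state["step"] = torch.tensor(0.0)
                state["exp_avg"] = torch.zeros_like(p.data).share_memory_()
                state["exp_avg_sq"] = torch.zeros_like(p.data).share_memory_()
```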
Force-pushed d95de87 to a6eb18d (compare).
vmoens
left a comment
I made a few edits.
Can you explain the way the params are shared and updated? I'm not sure I see the logic
There is a global model (shared across all workers) and a local model (each worker has its own copy).
```python
gn = torch.nn.utils.clip_grad_norm_(
    self.loss_module.parameters(), max_norm=max_grad_norm
)

self.optimizer.step()
```
After this line, we should sync the weights from global.
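A sketch of that sync step (hypothetical helper name): once the shared optimizer has updated the global weights, the worker refreshes its local replica before collecting the next rollout.

```python
import torch


def sync_from_global(local_model, global_model):
    """Copy the freshly updated global weights into the local replica."""
    local_model.load_state_dict(global_model.state_dict())
```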
| Prefix | Label Applied | Example |
|---|---|---|
| [BugFix] | BugFix | [BugFix] Fix memory leak in collector |
| [Feature] | Feature | [Feature] Add new optimizer |
| [Doc] or [Docs] | Documentation | [Doc] Update installation guide |
| [Refactor] | Refactoring | [Refactor] Clean up module imports |
| [CI] | CI | [CI] Fix workflow permissions |
| [Test] or [Tests] | Tests | [Tests] Add unit tests for buffer |
| [Environment] or [Environments] | Environments | [Environments] Add Gymnasium support |
| [Data] | Data | [Data] Fix replay buffer sampling |
| [Performance] or [Perf] | Performance | [Performance] Optimize tensor ops |
| [BC-Breaking] | bc breaking | [BC-Breaking] Remove deprecated API |
| [Deprecation] | Deprecation | [Deprecation] Mark old function |
| [Quality] | Quality | [Quality] Fix typos and add codespell |

Note: Common variations like singular/plural are supported (e.g., [Doc] or [Docs]).
Description
This PR adds an implementation of the Asynchronous Advantage Actor-Critic (A3C) algorithm for Atari environments in the torchrl/sota-implementations directory. The main files added are:
a3c_atari.py: Contains the A3C worker class, shared optimizer, and main training loop using multiprocessing.
utils_atari.py: Provides utility functions for environment creation, model construction, and evaluation, adapted for Atari tasks.
config_atari.yaml: Configuration file for hyperparameters, environment settings, and logging.
The implementation leverages TorchRL's collectors, objectives, and logging utilities, and is designed to be modular and extensible for research and benchmarking. Some of the utility functions are also adapted from a2c_atari.
Motivation and Context
This change is required to provide a strong, reproducible baseline for A3C on Atari environments using TorchRL. It enables researchers and practitioners to benchmark and compare reinforcement learning algorithms within the TorchRL ecosystem. The implementation follows best practices for distributed RL and is compatible with TorchRL's API.
This PR solves issue #1755.
Types of changes
What types of changes does your code introduce?

Checklist
Go over all the following points, and put an `x` in all the boxes that apply. If you are unsure about any of these, don't hesitate to ask. We are here to help!