BYOL Single GPU implementation #1
Conversation
|
Check out this pull request on See visual diffs & provide feedback on Jupyter Notebooks. Powered by ReviewNB |
) Summary: Pull Request resolved: facebookresearch#343 Some basic changes to make this script work within FBinfra. 1. Register Manifold in PathManager. 1. In order to do #1, create fb/extra_scripts/convert_sharded_checkpoint.y and add necessary dependencies in TARGETS 1. Replace some torch.loads using PathManager. Reviewed By: prigoyal Differential Revision: D29166520 fbshipit-source-id: a61b4eb80d74526b0a7e2d38f973eb688b311a94
|
As per the commit fc1217d addressed the following:
|
| def __init__(self, base_momentum: float, shuffle_batch: bool = True): | ||
| super().__init__() | ||
| self.base_momentum = base_momentum | ||
| self.is_distributed = False |
|
|
||
| class BYOLHook(ClassyHook): | ||
| """ | ||
| TODO: Update description |
|
|
||
| @staticmethod | ||
| def cosine_decay(training_iter, max_iters, initial_value): | ||
| # TODO: Why do we need this min statement? |
There was a problem hiding this comment.
Comment this method.
Put types in the method.
| return initial_value * cosine_decay_value | ||
|
|
||
| @staticmethod | ||
| def target_ema(training_iter, base_ema, max_iters): |
There was a problem hiding this comment.
Let's comment every method in the byol_hooks.py and byol_losses.py and make sure they all have type hints.
|
|
||
| def _build_byol_target_network(self, task: tasks.ClassyTask) -> None: | ||
| """ | ||
| Create the model replica called the target. This will slowly track |
There was a problem hiding this comment.
Improve comment. Something like: "Target network is exponential moving average of online network, ... "
| @torch.no_grad() | ||
| def on_forward(self, task: tasks.ClassyTask) -> None: | ||
| """ | ||
| - Update the target model. |
There was a problem hiding this comment.
Update comment for BYOL (this was copy/pasted from moco).
| @register_loss("byol_loss") | ||
| class BYOLLoss(ClassyLoss): | ||
| """ | ||
| TODO: change description |
| and https://github.com/facebookresearch/moco for a reference implementation, reused here | ||
|
|
||
| Config params: | ||
| embedding_dim (int): head output output dimension |
| "_BYOLLossConfig", ["embedding_dim", "momentum"] | ||
| ) | ||
|
|
||
| def regression_loss(x, y): |
There was a problem hiding this comment.
Type-hints + comments for all these functions.
| @classmethod | ||
| def from_config(cls, config: BYOLLossConfig): | ||
| """ | ||
| Instantiates BYOLLoss from configuration. |
There was a problem hiding this comment.
Put in the config options in the docstring here.
|
|
||
| def forward(self, online_network_prediction: torch.Tensor, *args, **kwargs) -> torch.Tensor: | ||
| """ | ||
| Given the encoder queries, the key and the queue of the previous queries, |
There was a problem hiding this comment.
This comment I think is copy/pasted.
| self.is_distributed = False | ||
|
|
||
| self.momentum = None | ||
| self.inv_momentum = None |
|
|
||
| self.momentum = None | ||
| self.inv_momentum = None | ||
| self.total_iters = None |
Implementation of BYOL: https://arxiv.org/abs/2006.07733 on Single GPU
Issue #190