From 4fed0beca64a52aa718371dc3b2cf1fd979197a4 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 10 Feb 2021 14:03:24 -0800 Subject: [PATCH 01/82] Fix padding mask for new architectures (#3228) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: # Before submitting - [ ] Was this discussed/approved via a Github issue? (no need for typos, doc improvements) - [ ] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)? - [ ] Did you make sure to update the docs? - [ ] Did you write any new necessary tests? ## What does this PR do? Fixes https://github.com/pytorch/fairseq/issues/3227 All models that do **not** make use of group norm, such as - Wav2Vec 2.0 Large (LV-60)* - Wav2Vec 2.0 Large (LV-60) + Self Training * do need this fix IMO to able to correctly run batches through the model. Before this PR, the following code snippet failed: ```python import fairseq import torch # get model wav2vec_path = "data/wav2vec2_vox_960h_new.pt" model, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task( [wav2vec_path], arg_overrides={"data": "./data"} ) model = model[0] model.eval() # create single input input_wav_0 = torch.randn((1, 2000)) input_wav_1 = torch.randn((1, 3000)) # create batched input batch_input_wav = torch.zeros((2, 3000)) batch_input_wav[0, :input_wav_0.shape[-1]] = input_wav_0 batch_input_wav[1, :input_wav_1.shape[-1]] = input_wav_1 # create padding mask padding_mask = torch.zeros((2, 3000), dtype=torch.bool) padding_mask[0, input_wav_0.shape[-1]:] = True # run batch & single output = model(source=input_wav_0, padding_mask=None)["encoder_out"] batch_output = model(source=batch_input_wav, padding_mask=padding_mask)["encoder_out"] # is equal? print("Is batched forward and simple forward equal?", torch.allclose(output[:,0], batch_output[:output.shape[0], 0], atol=1e-3)) ``` Note: It is assumed that both https://dl.fbaipublicfiles.com/fairseq/wav2vec/dict.ltr.txt and https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec2_vox_960h_new.pt were downloaded and stored in the folder data. Also, see [this](https://colab.research.google.com/drive/1ASZ4lVZbKkj-dvRHDl1lo0mCcsaOERlG?usp=sharing) notebook for reproducibility. This PR should fix the behavior and make the above code snippet / notebook run succesfully. ## PR review Gently pinging alexeib for Wav2Vec2 Anyone in the community is free to review the PR once the tests have passed. If we didn't discuss your PR in Github issues there's a high chance it will not be merged. ## Did you have fun? Make sure you had fun coding � Pull Request resolved: https://github.com/pytorch/fairseq/pull/3228 Reviewed By: aconneau Differential Revision: D26373721 Pulled By: alexeib fbshipit-source-id: 3d5aca2f8136d1a8c4b5b4bc9c03cd05a69a3b52 --- fairseq/models/wav2vec/wav2vec2.py | 32 +++++++++++++++++++++++++----- 1 file changed, 27 insertions(+), 5 deletions(-) diff --git a/fairseq/models/wav2vec/wav2vec2.py b/fairseq/models/wav2vec/wav2vec2.py index 783ebcfe6b..644add7b17 100644 --- a/fairseq/models/wav2vec/wav2vec2.py +++ b/fairseq/models/wav2vec/wav2vec2.py @@ -443,6 +443,21 @@ def compute_preds(self, x, y, negatives): return logits + def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor): + """ + Computes the output length of the convolutional layers + """ + + def _conv_out_length(input_length, kernel_size, stride): + return torch.floor((input_length - kernel_size) / stride + 1) + + conv_cfg_list = eval(self.cfg.conv_feature_layers) + + for i in range(len(conv_cfg_list)): + input_lengths = _conv_out_length(input_lengths, conv_cfg_list[i][1], conv_cfg_list[i][2]) + + return input_lengths.to(torch.long) + def forward(self, source, padding_mask=None, mask=True, features_only=False): if self.feature_grad_mult > 0: @@ -460,11 +475,18 @@ def forward(self, source, padding_mask=None, mask=True, features_only=False): unmasked_features = features.clone() if padding_mask is not None: - extra = padding_mask.size(1) % features.size(1) - if extra > 0: - padding_mask = padding_mask[:, :-extra] - padding_mask = padding_mask.view(padding_mask.size(0), features.size(1), -1) - padding_mask = padding_mask.all(-1) + input_lengths = (1 - padding_mask.long()).sum(-1) + # apply conv formula to get real output_lengths + output_lengths = self._get_feat_extract_output_lengths(input_lengths) + + padding_mask = torch.zeros( + features.shape[:2], dtype=features.dtype, device=features.device + ) + + # these two operations makes sure that all values + # before the output lengths indices are attended to + padding_mask[(torch.arange(padding_mask.shape[0], device=padding_mask.device), output_lengths - 1)] = 1 + padding_mask = (1 - padding_mask.flip([-1]).cumsum(-1).flip([-1])).bool() if self.post_extract_proj is not None: features = self.post_extract_proj(features) From ac90cb3085439d15af1a33cf0a0b1a6703f07413 Mon Sep 17 00:00:00 2001 From: Ruslan Mavlyutov Date: Wed, 10 Feb 2021 14:57:17 -0800 Subject: [PATCH 02/82] Extra logging to confirm OOM source Reviewed By: myleott, chtran Differential Revision: D26348808 fbshipit-source-id: 010ef00024e02c09ec35b624f0713ce5f1f387b4 --- fairseq/trainer.py | 1 + fairseq_cli/train.py | 1 + 2 files changed, 2 insertions(+) diff --git a/fairseq/trainer.py b/fairseq/trainer.py index 49129a7fb0..24f72e2f9a 100644 --- a/fairseq/trainer.py +++ b/fairseq/trainer.py @@ -277,6 +277,7 @@ def consolidate_optimizer(self): def save_checkpoint(self, filename, extra_state): """Save all training state in a checkpoint file.""" if self.is_data_parallel_master: # only save one checkpoint + logger.info(f"Saving checkpoint to {filename}") extra_state["metrics"] = metrics.state_dict() extra_state["previous_training_time"] = self.cumulative_training_time() checkpoint_utils.save_state( diff --git a/fairseq_cli/train.py b/fairseq_cli/train.py index 9af7568a77..ec4890b9e6 100644 --- a/fairseq_cli/train.py +++ b/fairseq_cli/train.py @@ -236,6 +236,7 @@ def train( valid_subsets = cfg.dataset.valid_subset.split(",") should_stop = False num_updates = trainer.get_num_updates() + logger.info("Start iterating over samples") for i, samples in enumerate(progress): with metrics.aggregate("train_inner"), torch.autograd.profiler.record_function( "train_step-%d" % i From 7061a0ff83872ac491ba5963eb7fc04cb10d57c4 Mon Sep 17 00:00:00 2001 From: Alex Xiao Date: Wed, 10 Feb 2021 16:25:25 -0800 Subject: [PATCH 03/82] better error handling for expired handles Summary: At the start of the half there were some expired handles and it was annoying to track down which datasets were responsible when sampling data among multiple datasets and which flows were running them. Lets improve the error message to address several pain points 1. Explicitly tell the user which dataset has expired handles 2. Link to a scuba query to enable the user to find all flows that have expired handles 3. Fail job if 10k handles have expired, rather than if 10k handles in a row have expired. This can detect failures from datasets that have for example 50% expired handles 4. add logging when handles fail Reviewed By: cruvadom Differential Revision: D26187820 fbshipit-source-id: 771a359ea01de80b38932921346e98cff812f2f7 --- fairseq/data/multi_corpus_dataset.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/fairseq/data/multi_corpus_dataset.py b/fairseq/data/multi_corpus_dataset.py index 9c7f1cb976..7207174bf3 100644 --- a/fairseq/data/multi_corpus_dataset.py +++ b/fairseq/data/multi_corpus_dataset.py @@ -126,7 +126,11 @@ def __len__(self): def __getitem__(self, index): index, key = self._map_index(index) - return self.datasets[key][index] + try: + return self.datasets[key][index] + except Exception as e: + e.args = (f"Error from {key} dataset", *e.args) + raise def collater(self, samples): """ From ee48d1b95835a0e5fa2129219d205f8d9e748b76 Mon Sep 17 00:00:00 2001 From: pritam Date: Thu, 11 Feb 2021 09:41:54 -0800 Subject: [PATCH 04/82] Use torch pipe if available in fairseq. (#3149) Summary: fairscale.nn.Pipe has been ported to PyTorch: https://github.com/pytorch/pytorch/blob/master/torch/distributed/pipeline/sync/pipe.py#L138. As a result, modifying the pipeline transformer to use PyTorch pipe if available. This change depends on https://github.com/pytorch/pytorch/pull/50860. Pull Request resolved: https://github.com/pytorch/fairseq/pull/3149 Test Plan: ``` python train.py ru_en_bin/ --arch transformer_iwslt_de_en_pipeline_parallel --share-decoder-input-output-embed --optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 --lr 5e-4 --lr-scheduler inverse_sqrt --warmup-updates 4000 --dropout 0.3 --weight-decay 0.0001 --criterion label_smoothed_cross_entropy --label-smoothing 0.1 --max-tokens 4096 --eval-bleu --eval-bleu-args '{"beam": 5, "max_len_a": 1.2, "max_len_b": 10}' --eval-bleu-detok moses --eval-bleu-remove-bpe --eval-bleu-print-samples --best-checkpoint-metric bleu --maximize-best-checkpoint-metric --pipeline-model-parallel --pipeline-balance '[1,3,5,3,3,1]' --pipeline-devices '[0,1,0,2,3,0]' --pipeline-chunks 16 --distributed-world-size 1 --distributed-no-spawn --disable-validation --max-epoch 1 ``` Output with torch pipe: ``` 2021-01-20 16:13:35 | INFO | train | epoch 001 | loss 12.676 | nll_loss 12.331 | ppl 5151.97 | wps 5108 | ups 1.66 | wpb 3081.6 | bsz 131.6 | num_updates 380 | lr 4.75e-05 | gnorm 2.08 | train_wall 229 | wall 233 2021-01-20 16:13:36 | INFO | fairseq_cli.train | done training in 233.1 seconds ``` Output with fairscale pipe: ``` 2021-01-20 14:13:59 | INFO | train | epoch 001 | loss 12.677 | nll_loss 12.331 | ppl 5152.07 | wps 5198.9 | ups 1.69 | wpb 3081.6 | bsz 131.6 | num_updates 380 | lr 4.75e-05 | gnorm 2.08 | train_wall 224 | wall 228 2021-01-20 14:13:59 | INFO | fairseq_cli.train | done training in 228.0 seconds ``` Reviewed By: myleott Differential Revision: D26204633 Pulled By: shruti-bh fbshipit-source-id: 535f816e8d149b47fc6ba8385981accf67257257 --- .../pipeline_parallel_transformer/model.py | 128 +++++++++++++----- 1 file changed, 92 insertions(+), 36 deletions(-) diff --git a/fairseq/model_parallel/models/pipeline_parallel_transformer/model.py b/fairseq/model_parallel/models/pipeline_parallel_transformer/model.py index 7873611214..7f30dd98bb 100644 --- a/fairseq/model_parallel/models/pipeline_parallel_transformer/model.py +++ b/fairseq/model_parallel/models/pipeline_parallel_transformer/model.py @@ -39,15 +39,47 @@ DEFAULT_MAX_SOURCE_POSITIONS = 1024 DEFAULT_MAX_TARGET_POSITIONS = 1024 +TORCH_PIPE = False +RPC_INIT = False + +def import_pipe(): + global TORCH_PIPE + global RPC_INIT + try: + from torch.distributed.pipeline.sync import Pipe # noqa + global Pipe + from torch.distributed.pipeline.sync.utils import partition_model + global partition_model + from torch.distributed import rpc + import tempfile + TORCH_PIPE = True + # Initialize single process RPC agent since TORCH_PIPE requires + # RRef. RRef depends on RPC being initialized and as a result we initialize + # RPC with a single node. + tmpfile = tempfile.NamedTemporaryFile() + if not RPC_INIT: + rpc.init_rpc( + name="worker", + rank=0, + world_size=1, + rpc_backend_options=rpc.TensorPipeRpcBackendOptions( + init_method="file://{}".format(tmpfile.name), + ) + ) + RPC_INIT = True + logger.info('Using torch pipe') + except ImportError: + try: + from fairscale.nn import Pipe # noqa + logger.info('Using fairscale pipe') + except ImportError: + raise ImportError("Please install fairscale with: pip install fairscale") @register_model("pipeline_parallel_transformer") class PipelineParallelTransformerModel(BaseFairseqModel): def __init__(self, encoder, decoder, balance, devices, chunks, checkpoint): - try: - from fairscale.nn import Pipe - except ImportError: - raise ImportError("Please install fairscale with: pip install fairscale") + import_pipe() super().__init__() assert isinstance(encoder, FairseqEncoder) assert isinstance(decoder, FairseqDecoder) @@ -65,13 +97,20 @@ def __init__(self, encoder, decoder, balance, devices, chunks, checkpoint): self.num_decoder_modules = len(decoder_module_list) module_list = encoder_module_list + decoder_module_list self.devices = devices - self.model = Pipe( - nn.Sequential(*module_list), - balance=balance, - devices=devices, - chunks=chunks, - checkpoint=checkpoint, - ) + if TORCH_PIPE: + self.model = Pipe( + partition_model(nn.Sequential(*module_list), balance, devices), + chunks=chunks, + checkpoint=checkpoint, + ) + else: + self.model = Pipe( + nn.Sequential(*module_list), + balance=balance, + devices=devices, + chunks=chunks, + checkpoint=checkpoint, + ) self.encoder_max_positions = self.max_positions_helper( encoder.embedding_layer, "max_source_positions" ) @@ -87,7 +126,10 @@ def forward(self, src_tokens, src_lengths, prev_output_tokens): if self.training: input_lst = [src_tokens, src_lengths, prev_output_tokens] input = tuple(i.to(self.devices[0], non_blocking=True) for i in input_lst) - return self.model(input) + if TORCH_PIPE: + return self.model(input).local_value() + else: + return self.model(input) else: assert self.encoder is not None and self.decoder is not None, ( "encoder and decoder need to be initialized by " @@ -425,10 +467,7 @@ class TransformerEncoder(FairseqEncoder): def __init__(self, args, dictionary, embed_tokens, encoder_module_list=None): super().__init__(dictionary) self.register_buffer("version", torch.Tensor([3])) - try: - from fairscale.nn import Pipe - except ImportError: - raise ImportError("Please install fairscale with: pip install fairscale") + import_pipe() self.use_pipeline = encoder_module_list is not None if not self.use_pipeline: self.embedding_layer = TransformerEncoderEmbedding(args, embed_tokens) @@ -449,13 +488,20 @@ def __init__(self, args, dictionary, embed_tokens, encoder_module_list=None): f"Sum of encoder_balance={encoder_balance} is not equal " + f"to num_encoder_modules={len(encoder_module_list)}" ) - self.model = Pipe( - module=nn.Sequential(*encoder_module_list), - balance=encoder_balance, - devices=encoder_devices, - chunks=args.pipeline_chunks, - checkpoint=args.pipeline_checkpoint, - ) + if TORCH_PIPE: + self.model = Pipe( + module=partition_model(nn.Sequential(*encoder_module_list), encoder_balance, encoder_devices), + chunks=args.pipeline_chunks, + checkpoint=args.pipeline_checkpoint, + ) + else: + self.model = Pipe( + module=nn.Sequential(*encoder_module_list), + balance=encoder_balance, + devices=encoder_devices, + chunks=args.pipeline_chunks, + checkpoint=args.pipeline_checkpoint, + ) def forward(self, src_tokens, src_lengths): """ @@ -485,7 +531,10 @@ def forward(self, src_tokens, src_lengths): input_tuple = (src_tokens, src_lengths, dummy_prev_output_tokens) if self.use_pipeline: input_tuple = tuple(i.to(self.model.devices[0]) for i in input_tuple) - encoder_out = self.model(input_tuple) + if TORCH_PIPE: + encoder_out = self.model(input_tuple).local_value() + else: + encoder_out = self.model(input_tuple) else: encoder_embed_output_tuple = self.embedding_layer(input_tuple) encoder_layers_output = self.encoder_layers(encoder_embed_output_tuple) @@ -561,10 +610,7 @@ def __init__( ): super().__init__(dictionary) self.register_buffer("version", torch.Tensor([3])) - try: - from fairscale.nn import Pipe - except ImportError: - raise ImportError("Please install fairscale with: pip install fairscale") + import_pipe() self.use_pipeline = decoder_module_list is not None if not self.use_pipeline: self.embedding_layer = TransformerDecoderEmbedding(args, embed_tokens) @@ -586,13 +632,20 @@ def __init__( f"Sum of decoder_balance={decoder_balance} is not equal " + f"to num_decoder_modules={len(decoder_module_list)}" ) - self.model = Pipe( - module=nn.Sequential(*decoder_module_list), - balance=decoder_balance, - devices=decoder_devices, - chunks=args.pipeline_chunks, - checkpoint=args.pipeline_checkpoint, - ) + if TORCH_PIPE: + self.model = Pipe( + module=partition_model(nn.Sequential(*decoder_module_list), decoder_balance, decoder_devices), + chunks=args.pipeline_chunks, + checkpoint=args.pipeline_checkpoint, + ) + else: + self.model = Pipe( + module=nn.Sequential(*decoder_module_list), + balance=decoder_balance, + devices=decoder_devices, + chunks=args.pipeline_chunks, + checkpoint=args.pipeline_checkpoint, + ) def forward( self, @@ -622,7 +675,10 @@ def forward( ) if self.use_pipeline: input_tuple = tuple(i.to(self.model.devices[0]) for i in input_tuple) - return (self.model(input_tuple),) + if TORCH_PIPE: + return (self.model(input_tuple).local_value(),) + else: + return (self.model(input_tuple),) else: embed_layer_output = self.embedding_layer(input_tuple) state = self.decoder_layers(embed_layer_output) From fd7c2a8b371c2abf645f558282221eba6833f35f Mon Sep 17 00:00:00 2001 From: Mary Williamson Date: Thu, 11 Feb 2021 13:53:33 -0800 Subject: [PATCH 05/82] More informative exception when numpy version changes (#3231) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: More informative exception when numpy version changes to ask the user to recompile Cython files # Before submitting - [With myleott ] Was this discussed/approved via a Github issue? (no need for typos, doc improvements) - [x] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)? - [N/A ] Did you make sure to update the docs? - [N/A ] Did you write any new necessary tests? ## What does this PR do? Raises a more informative error to tell the user to recompile Cython files after an update to the numpy version. ## PR review Anyone in the community is free to review the PR once the tests have passed. If we didn't discuss your PR in Github issues there's a high chance it will not be merged. ## Did you have fun? Make sure you had fun coding � Pull Request resolved: https://github.com/pytorch/fairseq/pull/3231 Reviewed By: myleott Differential Revision: D26375174 Pulled By: mwillwork fbshipit-source-id: f0a93e162bc4cf84619581110d21bea907baf7fc --- fairseq/data/data_utils.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/fairseq/data/data_utils.py b/fairseq/data/data_utils.py index 1a83063542..47d8492ec9 100644 --- a/fairseq/data/data_utils.py +++ b/fairseq/data/data_utils.py @@ -307,6 +307,11 @@ def batch_by_size( "Please build Cython components with: `pip install --editable .` " "or `python setup.py build_ext --inplace`" ) + except ValueError: + raise ValueError( + "Please build (or rebuild) Cython components with: `pip install " + " --editable .` or `python setup.py build_ext --inplace`." + ) max_tokens = max_tokens if max_tokens is not None else -1 max_sentences = max_sentences if max_sentences is not None else -1 From 66e1803c60272602c719a5ba75acef1c530066ef Mon Sep 17 00:00:00 2001 From: alexeib Date: Thu, 11 Feb 2021 13:59:08 -0800 Subject: [PATCH 06/82] save task state in the checkpoint (#1562) Summary: this allows tasks to declare some properties they'd like to save in the checkpoint (such as a dictionary), which are loaded when checkpoint is restored. Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1562 Test Plan: tested by training a new wav2vec model, then finetuning it, then decoding it and making sure the dict only loaded once, during fine tuning process (and was obtained from checkpoint for decoding) Reviewed By: myleott, gwenzek Differential Revision: D25937974 Pulled By: alexeib fbshipit-source-id: b9908042f76ec8cda943f33885eb9b1f121662ae --- examples/speech_recognition/hydra/infer.py | 4 +- examples/speech_recognition/infer.py | 13 +++--- fairseq/checkpoint_utils.py | 5 +++ fairseq/distributed/utils.py | 3 ++ fairseq/tasks/audio_pretraining.py | 28 ++++++------ fairseq/tasks/fairseq_task.py | 50 ++++++++++++++++++++-- fairseq/trainer.py | 5 ++- 7 files changed, 77 insertions(+), 31 deletions(-) diff --git a/examples/speech_recognition/hydra/infer.py b/examples/speech_recognition/hydra/infer.py index 6afa066f25..b1c985bc0d 100644 --- a/examples/speech_recognition/hydra/infer.py +++ b/examples/speech_recognition/hydra/infer.py @@ -10,10 +10,9 @@ import os import shutil import sys -from argparse import Namespace from dataclasses import dataclass, field from pathlib import Path -from typing import Any, Callable, Dict, List, Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple import editdistance import torch @@ -26,7 +25,6 @@ CommonEvalConfig, DatasetConfig, DistributedTrainingConfig, FairseqDataclass, GenerationConfig) -from fairseq.dataclass.initialize import hydra_init from fairseq.logging.meters import StopwatchMeter, TimeMeter from fairseq.logging.progress_bar import BaseProgressBar from fairseq.models.fairseq_model import FairseqModel diff --git a/examples/speech_recognition/infer.py b/examples/speech_recognition/infer.py index 5a582c54af..f4efbf39c8 100644 --- a/examples/speech_recognition/infer.py +++ b/examples/speech_recognition/infer.py @@ -144,11 +144,11 @@ def process_predictions( print( "{} ({}-{})".format(tgt_words, speaker, id), file=res_files["ref.words"] ) - # only score top hypothesis - if not args.quiet: - logger.debug("HYPO:" + hyp_words) - logger.debug("TARGET:" + tgt_words) - logger.debug("___________________") + + if not args.quiet: + logger.info("HYPO:" + hyp_words) + logger.info("TARGET:" + tgt_words) + logger.info("___________________") hyp_words = hyp_words.split() tgt_words = tgt_words.split() @@ -216,7 +216,6 @@ def main(args, task=None, model_state=None): use_cuda = torch.cuda.is_available() and not args.cpu - logger.info("| decoding with criterion {}".format(args.criterion)) task = tasks.setup_task(args) @@ -227,7 +226,7 @@ def main(args, task=None, model_state=None): task.load_dataset(args.gen_subset) else: logger.info("| loading model(s) from {}".format(args.path)) - models, saved_cfg = checkpoint_utils.load_model_ensemble( + models, saved_cfg, task = checkpoint_utils.load_model_ensemble_and_task( utils.split_paths(args.path), arg_overrides=ast.literal_eval(args.model_overrides), task=task, diff --git a/fairseq/checkpoint_utils.py b/fairseq/checkpoint_utils.py index 2f209b6b39..55a546356e 100644 --- a/fairseq/checkpoint_utils.py +++ b/fairseq/checkpoint_utils.py @@ -349,6 +349,9 @@ def load_model_ensemble_and_task( if task is None: task = tasks.setup_task(cfg.task) + if "task_state" in state: + task.load_state_dict(state["task_state"]) + # build model for ensemble model = task.build_model(cfg.model) @@ -403,6 +406,7 @@ def save_state( num_updates, optim_history=None, extra_state=None, + task=None, **kwargs, ): from fairseq import utils @@ -425,6 +429,7 @@ def save_state( } ], "extra_state": extra_state, + "task_state": task.state_dict() if task is not None else {} } if utils.has_parameters(criterion): state_dict["criterion"] = criterion.state_dict() diff --git a/fairseq/distributed/utils.py b/fairseq/distributed/utils.py index e3d8e1e0d3..710ca18628 100644 --- a/fairseq/distributed/utils.py +++ b/fairseq/distributed/utils.py @@ -325,6 +325,9 @@ def distributed_main(i, main, cfg: FairseqConfig, kwargs): main(cfg, **kwargs) + if torch.distributed.is_initialized(): + torch.distributed.barrier(get_global_group()) + def call_main(cfg: FairseqConfig, main, **kwargs): if cfg.distributed_training.distributed_init_method is None: diff --git a/fairseq/tasks/audio_pretraining.py b/fairseq/tasks/audio_pretraining.py index 7c82777331..92685160d4 100644 --- a/fairseq/tasks/audio_pretraining.py +++ b/fairseq/tasks/audio_pretraining.py @@ -15,7 +15,6 @@ from omegaconf import MISSING from fairseq.data import AddTargetDataset, Dictionary, FileAudioDataset, encoders -from fairseq.data.data_utils import post_process from fairseq.dataclass import FairseqDataclass from fairseq.dataclass.configs import GenerationConfig @@ -98,16 +97,14 @@ class AudioPretrainingTask(FairseqTask): def __init__( self, cfg: AudioPretrainingConfig, - source_dictionary=None, - target_dictionary=None, ): super().__init__(cfg) - self._target_dictionary = target_dictionary - self._source_dictionary = source_dictionary if cfg.eval_wer: assert cfg.labels is not None, "eval_wer can only be set during fine-tuning" self.blank_symbol = "" + self.state.add_factory("target_dictionary", self.load_target_dictionary) + @classmethod def setup_task(cls, cfg: AudioPretrainingConfig, **kwargs): """Setup the task (e.g., load dictionaries). @@ -116,13 +113,13 @@ def setup_task(cls, cfg: AudioPretrainingConfig, **kwargs): cfg (AudioPretrainingConfig): configuration of this task """ - if cfg.labels: - dict_path = os.path.join(cfg.data, f"dict.{cfg.labels}.txt") - target_dictionary = Dictionary.load(dict_path) - else: - target_dictionary = None + return cls(cfg) - return cls(cfg, target_dictionary=target_dictionary) + def load_target_dictionary(self): + if self.cfg.labels: + dict_path = os.path.join(self.cfg.data, f"dict.{self.cfg.labels}.txt") + return Dictionary.load(dict_path) + return None def load_dataset(self, split: str, task_cfg: FairseqDataclass = None, **kwargs): data_path = self.cfg.data @@ -136,7 +133,7 @@ def load_dataset(self, split: str, task_cfg: FairseqDataclass = None, **kwargs): manifest = os.path.join(data_path, "{}.tsv".format(split)) self.datasets[split] = FileAudioDataset( manifest, - sample_rate=task_cfg.sample_rate, + sample_rate=task_cfg.get('sample_rate', self.cfg.sample_rate), max_sample_size=self.cfg.max_sample_size, min_sample_size=self.cfg.max_sample_size, min_length=self.cfg.min_sample_size, @@ -146,7 +143,6 @@ def load_dataset(self, split: str, task_cfg: FairseqDataclass = None, **kwargs): if task_cfg.labels: label_path = os.path.join(data_path, f"{split}.{task_cfg.labels}") - labels = [] with open(label_path, "r") as f: labels = [ line for i, line in enumerate(f) @@ -166,18 +162,18 @@ def load_dataset(self, split: str, task_cfg: FairseqDataclass = None, **kwargs): eos=self.target_dictionary.eos(), batch_targets=True, process_label=process_label, - add_to_input=task_cfg.autoregressive, + add_to_input=task_cfg.get('autoregressive', False), ) @property def source_dictionary(self): - return self._source_dictionary + return None @property def target_dictionary(self): """Return the :class:`~fairseq.data.Dictionary` for the language model.""" - return self._target_dictionary + return self.state.target_dictionary def max_positions(self): """Maximum input length supported by the encoder.""" diff --git a/fairseq/tasks/fairseq_task.py b/fairseq/tasks/fairseq_task.py index 34264bdc01..04025023fa 100644 --- a/fairseq/tasks/fairseq_task.py +++ b/fairseq/tasks/fairseq_task.py @@ -7,7 +7,7 @@ import os import warnings from argparse import Namespace -from typing import List +from typing import Any, Callable, Dict, List import torch from fairseq import metrics, search, tokenizer, utils @@ -20,10 +20,45 @@ logger = logging.getLogger(__name__) +class StatefulContainer(object): + + _state: Dict[str, Any] = dict() + _factories: Dict[str, Callable[[], Any]] = dict() + + def add_factory(self, name, factory: Callable[[], Any]): + self._factories[name] = factory + + def merge_state_dict(self, state_dict: Dict[str, Any]): + self._state.update(state_dict) + + @property + def state_dict(self) -> Dict[str, Any]: + return self._state + + def __getattr__(self, name): + if name not in self._state and name in self._factories: + self._state[name] = self._factories[name]() + + if name in self._state: + return self._state[name] + + raise AttributeError(f"Task state has no factory for attribute {name}") + + class FairseqTask(object): """ Tasks store dictionaries and provide helpers for loading/iterating over Datasets, initializing the Model/Criterion and calculating the loss. + + Tasks have limited statefulness. In particular, state that needs to be + saved to/loaded from checkpoints needs to be stored in the `self.state` + :class:`StatefulContainer` object. For example:: + + self.state.add_factory("dictionary", self.load_dictionary) + print(self.state.dictionary) # calls self.load_dictionary() + + This is necessary so that when loading checkpoints, we can properly + recreate the task state after initializing the task instance. """ @classmethod @@ -42,10 +77,13 @@ def logging_outputs_can_be_summed(criterion) -> bool: """ return criterion.logging_outputs_can_be_summed() + cfg: FairseqDataclass + datasets: Dict[str, FairseqDataset] = dict() + dataset_to_epoch_iter: Dict[FairseqDataset, Any] = dict() + state: StatefulContainer = StatefulContainer() + def __init__(self, cfg: FairseqDataclass, **kwargs): self.cfg = cfg - self.datasets = {} - self.dataset_to_epoch_iter = {} @classmethod def load_dictionary(cls, filename): @@ -514,6 +552,12 @@ def reduce_metrics(self, logging_outputs, criterion): criterion.__class__.reduce_metrics(logging_outputs) + def state_dict(self): + return self.state.state_dict + + def load_state_dict(self, state_dict: Dict[str, Any]): + self.state.merge_state_dict(state_dict) + def max_positions(self): """Return the max input length allowed by the task.""" return None diff --git a/fairseq/trainer.py b/fairseq/trainer.py index 24f72e2f9a..e860fb1832 100644 --- a/fairseq/trainer.py +++ b/fairseq/trainer.py @@ -288,8 +288,9 @@ def save_checkpoint(self, filename, extra_state): self.optimizer, self.lr_scheduler, self.get_num_updates(), - self._optim_history, - extra_state, + optim_history=self._optim_history, + extra_state=extra_state, + task=self.task, ) logger.info(f"Finished saving checkpoint to {filename}") From 138265ce15d198e6baceae334effed8fb384a286 Mon Sep 17 00:00:00 2001 From: Kritika Singh Date: Thu, 11 Feb 2021 15:39:32 -0800 Subject: [PATCH 07/82] Make wav2vec_asr encoder compatible with pyspeech fst decoder Summary: - I don't think there is a convention for the shapes of `encoder_out` and `encoder_padding_mask` in fairseq but `fst_external_decoder.py` expects `encoder_padding_mask` to be of shape T x B. `encoder_padding_mask` also seems unused in the fairseq [CTC criterion and w2l decoder integration](https://fburl.com/diffusion/ms1zi2px) so taking the easy way out and changing its shape. - Also checking in some changes to the pyspeech audio_pretraining task required to make decoding work Reviewed By: alexeib Differential Revision: D26382442 fbshipit-source-id: 87c8f9433026c0e011847f4e2e094beb2cd2182c --- fairseq/models/wav2vec/wav2vec2_asr.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/fairseq/models/wav2vec/wav2vec2_asr.py b/fairseq/models/wav2vec/wav2vec2_asr.py index bbd2ab9ec5..9cd17b635c 100644 --- a/fairseq/models/wav2vec/wav2vec2_asr.py +++ b/fairseq/models/wav2vec/wav2vec2_asr.py @@ -158,7 +158,7 @@ def get_normalized_probs(self, net_output, log_probs): def get_logits(self, net_output): logits = net_output["encoder_out"] - padding = net_output["encoder_padding_mask"] + padding = net_output["padding_mask"] if padding is not None and padding.any(): padding = padding.T logits[padding][...,0] = 0 @@ -359,7 +359,7 @@ def forward(self, source, padding_mask, tbc=True, **kwargs): return { "encoder_out": x, # T x B x C - "encoder_padding_mask": padding_mask, # B x T + "encoder_padding_mask": padding_mask.transpose(0, 1), # T x B "padding_mask": padding_mask, } @@ -539,7 +539,7 @@ def extract_features( x, attn, _ = layer( x, encoder_out["encoder_out"] if encoder_out is not None else None, - encoder_out["encoder_padding_mask"] + encoder_out["padding_mask"] if encoder_out is not None else None, incremental_state, From 1d5b075e3f30fd3f28af4c8851e8659285ded230 Mon Sep 17 00:00:00 2001 From: alexeib Date: Thu, 11 Feb 2021 18:12:11 -0800 Subject: [PATCH 08/82] fix fairseqlm decoder with flashlight chnages (#1617) Summary: fixes fairseqlm integration with flashlight (formerly wav2letter) decoder Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1617 Reviewed By: xuqiantong Differential Revision: D26415650 Pulled By: alexeib fbshipit-source-id: 813684ba55047e92378f508101ff1eec55754420 --- examples/speech_recognition/w2l_decoder.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/speech_recognition/w2l_decoder.py b/examples/speech_recognition/w2l_decoder.py index 706d9f1433..8b158293a0 100644 --- a/examples/speech_recognition/w2l_decoder.py +++ b/examples/speech_recognition/w2l_decoder.py @@ -431,7 +431,7 @@ def __init__(self, args, tgt_dict): self.silence, self.blank, self.unk_word, - self.asg_transitions, + [], self.unit_lm, ) else: From 506a8e0f45c1206b1306276fed9cec92c7061dd0 Mon Sep 17 00:00:00 2001 From: alexeib Date: Thu, 11 Feb 2021 21:22:42 -0800 Subject: [PATCH 09/82] seq2seq autoregressive flag check (#1618) Summary: raise an exception if trying to use wav2vec seq2seq finetuning without autoregressive flag Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1618 Reviewed By: xuqiantong Differential Revision: D26417249 Pulled By: alexeib fbshipit-source-id: 777b6d170b0f8196746e03b399e4d7c21ac0b837 --- fairseq/models/wav2vec/wav2vec2_asr.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/fairseq/models/wav2vec/wav2vec2_asr.py b/fairseq/models/wav2vec/wav2vec2_asr.py index 9cd17b635c..afa51299b6 100644 --- a/fairseq/models/wav2vec/wav2vec2_asr.py +++ b/fairseq/models/wav2vec/wav2vec2_asr.py @@ -220,6 +220,7 @@ class Wav2Vec2Seq2SeqConfig(Wav2Vec2AsrConfig): share_decoder_input_output_embed: bool = field( default=False, metadata={"help": "share decoder input and output embeddings"} ) + autoregressive: bool = II("task.autoregressive") @register_model("wav2vec_seq2seq", dataclass=Wav2Vec2Seq2SeqConfig) @@ -231,6 +232,8 @@ def __init__(self, encoder, decoder): def build_model(cls, cfg: Wav2Vec2Seq2SeqConfig, task: FairseqTask): """Build a new model instance.""" + assert cfg.autoregressive, "Please set task.autoregressive=true for seq2seq asr models" + src_dict, tgt_dict = task.source_dictionary, task.target_dictionary def build_embedding(dictionary, embed_dim): From 7ffb40d9c8e33b272e85604892a0935d8e57bb0e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Pedro=20Megid=20Carrilho?= Date: Fri, 12 Feb 2021 00:26:03 -0800 Subject: [PATCH 10/82] Fix typo Wav2Vec2 README.md (#3240) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: # Before submitting - [ ] Was this discussed/approved via a Github issue? (no need for typos, doc improvements) - [ ] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)? - [ ] Did you make sure to update the docs? - [ ] Did you write any new necessary tests? ## What does this PR do? Fixes # (issue). ## PR review Anyone in the community is free to review the PR once the tests have passed. If we didn't discuss your PR in Github issues there's a high chance it will not be merged. ## Did you have fun? Make sure you had fun coding � Pull Request resolved: https://github.com/pytorch/fairseq/pull/3240 Reviewed By: aconneau Differential Revision: D26420073 Pulled By: alexeib fbshipit-source-id: 5939535b945a64e61d655cd36dc955ae46410bfb --- examples/wav2vec/README.md | 294 ------------------------------------- 1 file changed, 294 deletions(-) diff --git a/examples/wav2vec/README.md b/examples/wav2vec/README.md index 663adf97dc..e69de29bb2 100644 --- a/examples/wav2vec/README.md +++ b/examples/wav2vec/README.md @@ -1,294 +0,0 @@ -# wav2vec 2.0 - -wav2vec 2.0 learns speech representations on unlabeled data as described in [wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations (Baevski et al., 2020)](https://arxiv.org/abs/2006.11477). - -We learned speech representations in multiple languages as well in [Unsupervised Cross-lingual Representation Learning for Speech Recognition (Conneau et al., 2020)](https://arxiv.org/abs/2006.13979). - -We also combined wav2vec 2.0 with self-training in [Self-training and Pre-training are Complementary for Speech Recognition (Xu et al., 2020)](https://arxiv.org/abs/2010.11430). - -## Pre-trained models - -Model | Finetuning split | Dataset | Model -|---|---|---|--- -Wav2Vec 2.0 Base | No finetuning | [Librispeech](http://www.openslr.org/12) | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec_small.pt) -Wav2Vec 2.0 Base | 10 minutes | [Librispeech](http://www.openslr.org/12) | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec_small_10m.pt) -Wav2Vec 2.0 Base | 100 hours | [Librispeech](http://www.openslr.org/12) | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec_small_100h.pt) -Wav2Vec 2.0 Base | 960 hours | [Librispeech](http://www.openslr.org/12) | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec_small_960h.pt) -Wav2Vec 2.0 Large | No finetuning | [Librispeech](http://www.openslr.org/12) | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/libri960_big.pt) -Wav2Vec 2.0 Large | 10 minutes | [Librispeech](http://www.openslr.org/12) | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec_big_10m.pt) -Wav2Vec 2.0 Large | 100 hours | [Librispeech](http://www.openslr.org/12) | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec_big_100h.pt) -Wav2Vec 2.0 Large | 960 hours | [Librispeech](http://www.openslr.org/12) | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec_big_960h.pt) -Wav2Vec 2.0 Large (LV-60)* | No finetuning | [Libri-Light](https://github.com/facebookresearch/libri-light) | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec_vox_new.pt) -Wav2Vec 2.0 Large (LV-60)* | 10 minutes | [Libri-Light](https://github.com/facebookresearch/libri-light) + [Librispeech](http://www.openslr.org/12) | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec_vox_10m_new.pt) -Wav2Vec 2.0 Large (LV-60)* | 100 hours | [Libri-Light](https://github.com/facebookresearch/libri-light) + [Librispeech](http://www.openslr.org/12) | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec_vox_100h_new.pt) -Wav2Vec 2.0 Large (LV-60)* | 960 hours | [Libri-Light](https://github.com/facebookresearch/libri-light) + [Librispeech](http://www.openslr.org/12) | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec2_vox_960h_new.pt) -Wav2Vec 2.0 Large (LV-60) + Self Training * | 10 minutes | [Libri-Light](https://github.com/facebookresearch/libri-light) + [Librispeech](http://www.openslr.org/12) | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec_vox_10m_pl.pt) -Wav2Vec 2.0 Large (LV-60) + Self Training * | 100 hours | [Libri-Light](https://github.com/facebookresearch/libri-light) + [Librispeech](http://www.openslr.org/12) | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec_vox_100h_pl.pt) -Wav2Vec 2.0 Large (LV-60) + Self Training * | 960 hours | [Libri-Light](https://github.com/facebookresearch/libri-light) + [Librispeech](http://www.openslr.org/12) | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec_vox_960h_pl.pt) - -\* updated (Oct. 24, 2020) - -We also release multilingual pre-trained wav2vec 2.0 (XLSR) models: - -Model | Architecture | Hours | Languages | Datasets | Model -|---|---|---|---|---|--- -XLSR-53 | Large | 56k | 53 | MLS, CommonVoice, BABEL | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/xlsr_53_56k.pt) - -The XLSR model uses the following datasets for multilingual pretraining: - -* **[MLS: Multilingual LibriSpeech](https://indico2.conference4me.psnc.pl/event/35/contributions/3585/attachments/1060/1101/Wed-2-6-10.pdf)** (8 languages, 50.7k hours): *Dutch, English, French, German, Italian, Polish, Portuguese, Spanish* - -* **[CommonVoice](https://commonvoice.mozilla.org/en/languages)** (36 languages, 3.6k hours): *Arabic, Basque, Breton, Chinese (CN), Chinese (HK), Chinese (TW), Chuvash, Dhivehi, Dutch, English, Esperanto, Estonian, French, German, Hakh-Chin, Indonesian, Interlingua, Irish, Italian, Japanese, Kabyle, Kinyarwanda, Kyrgyz, Latvian, Mongolian, Persian, Portuguese, Russian, Sakha, Slovenian, Spanish, Swedish, Tamil, Tatar, Turkish, Welsh* (see also [finetuning splits]([https://dl.fbaipublicfiles.com/cpc_audio/common_voices_splits.tar.gz]) from [this paper](https://arxiv.org/abs/2002.02848)). - -* **[Babel](https://catalog.ldc.upenn.edu/byyear)** (17 languages, 1.7k hours): *Assamese, Bengali, Cantonese, Cebuano, Georgian, Haitian, Kazakh, Kurmanji, Lao, Pashto, Swahili, Tagalog, Tamil, Tok, Turkish, Vietnamese, Zulu* - - -## Training a new model with the CLI tools - -Given a directory containing wav files to be used for pretraining (we recommend splitting each file into separate file 10 to 30 seconds in length) - -### Prepare training data manifest: - -First, install the `soundfile` library: -```shell script -pip install soundfile -``` - -Next, run: - -```shell script -$ python examples/wav2vec/wav2vec_manifest.py /path/to/waves --dest /manifest/path --ext $ext --valid-percent $valid -``` - -$ext should be set to flac, wav, or whatever format your dataset happens to use that soundfile can read. - -$valid should be set to some reasonable percentage (like 0.01) of training data to use for validation. -To use a pre-defined validation set (like dev-other from librispeech), set to it 0 and then overwrite valid.tsv with a -separately pre-processed manifest file. - -### Train a wav2vec 2.0 base model: - -This configuration was used for the base model trained on the Librispeech dataset in the wav2vec 2.0 paper - -Note that the input is expected to be single channel, sampled at 16 kHz - -```shell script -$ fairseq-hydra-train \ - task.data=/path/to/data \ - --config-dir /path/to/fairseq-py/examples/wav2vec/config/pretraining \ - --config-name wav2vec2_base_librispeech -``` - -Note: you can simulate 64 GPUs by using k GPUs and adding command line parameters (before `--config-dir`) -`distributed_training.distributed_world_size=k` `+optimization.update_freq='[x]'` where x = 64/k - -### Train a wav2vec 2.0 large model: - -This configuration was used for the large model trained on the Libri-light dataset in the wav2vec 2.0 paper - -```shell script -$ fairseq-hydra-train \ - task.data=/path/to/data \ - --config-dir /path/to/fairseq-py/examples/wav2vec/config/pretraining \ - --config-name wav2vec2_large_librivox -``` - -Note: you can simulate 128 GPUs by using k GPUs and adding command line parameters (before `--config-dir`) -`distributed_training.distributed_world_size=k` `+optimization.update_freq='[x]'` where x = 128/k - -### Fine-tune a pre-trained model with CTC: - -Fine-tuning a model requires parallel audio and labels file, as well as a vocabulary file in fairseq format. -A letter vocabulary can be downloaded [here](https://dl.fbaipublicfiles.com/fairseq/wav2vec/dict.ltr.txt). -An example [script](libri_labels.py) that generates labels for the Librispeech dataset from the tsv file produced by wav2vec_manifest.py can be used as follows: - -```shell script -split=train -$ python libri_labels.py /path/to/tsv --output-dir /output/dir --output-name $split -``` - -Fine-tuning on 100h of Librispeech with letter targets: -```shell script -$ fairseq-hydra-train \ - distributed_training.distributed_port=$PORT \ - task.data=/path/to/data \ - model.w2v_path=/path/to/model.pt \ - --config-dir /path/to/fairseq-py/examples/wav2vec/config/finetuning \ - --config-name base_100h -``` - -There are other config files in the config/finetuning directory that can be used to fine-tune on other splits. -You can specify the right config via the `--config-name` parameter. - -Note: you can simulate 24 GPUs by using k GPUs and adding command line parameters (before `--config-dir`) -`distributed_training.distributed_world_size=k` `+optimization.update_freq='[x]'` where x = 24/k - -Decoding with a language model during training requires flashlight [python bindings](https://github.com/facebookresearch/flashlight/tree/master/bindings/python) (previously called [wav2letter](https://github.com/facebookresearch/wav2letter). -If you want to use a language model, add `+criterion.wer_args='[/path/to/kenlm, /path/to/lexicon, 2, -1]'` to the command line. - -### Evaluating a CTC model: - -Evaluating a CTC model with a language model requires [flashlight python bindings](https://github.com/facebookresearch/flashlight/tree/master/bindings/python) (previously called [wav2letter](https://github.com/facebookresearch/wav2letter) to be installed. - -Fairseq transformer language model used in the wav2vec 2.0 paper can be obtained from the [wav2letter model repository](https://github.com/facebookresearch/wav2letter/tree/master/recipes/sota/2019). -Be sure to upper-case the language model vocab after downloading it. - -Letter dictionary for pre-trained models can be found [here](https://dl.fbaipublicfiles.com/fairseq/wav2vec/dict.ltr.txt). - -Next, run the evaluation command: - -```shell script -$subset=dev_other -python examples/speech_recognition/infer.py /checkpoint/abaevski/data/speech/libri/10h/wav2vec/raw --task audio_pretraining \ ---nbest 1 --path /path/to/model --gen-subset $subset --results-path /path/to/save/results/for/sclite --w2l-decoder kenlm \ ---lm-model /path/to/kenlm.bin --lm-weight 2 --word-score -1 --sil-weight 0 --criterion ctc --labels ltr --max-tokens 4000000 \ ---post-process letter -``` - -To get raw numbers, use --w2l-decoder viterbi and omit the lexicon. To use the transformer language model, use --w2l-decoder fairseqlm. - -## Use wav2vec 2.0 with 🤗Transformers: - -Wav2Vec2 is also available in the [🤗Transformers library](https://github.com/huggingface/transformers) since vesion 4.3. - -Pretrained Models can be found on the [hub](https://huggingface.co/models?filter=wav2vec2) -and documentation can be found [here](https://huggingface.co/transformers/master/model_doc/wav2vec2.html). - -Usage example: - -```python -# !pip install transformers -import soundfile as sf -import torch -from transformers import Wav2Vec2ForMaskedLM, Wav2Vec2Tokenizer - -# load pretrained model -tokenizer = Wav2Vec2Tokenizer.from_pretrained("facebook/wav2vec2-base-960h") -model = Wav2Vec2ForMaskedLM.from_pretrained("facebook/wav2vec2-base-960h") - -# load audio -audio_input, _ = sf.read("path/to/audio/file") - -# transcribe -input_values = tokenizer(audio_input, return_tensors="pt").input_values -logits = model(input_values).logits -predicted_ids = torch.argmax(logits, dim=-1) -transcription = tokenizer.batch_decode(predicted_ids)[0] -``` - -# wav2vec - -Example to train a wav2vec model as described in [wav2vec: Unsupervised Pre-training for Speech Recognition (Schneider et al., 2019)](https://arxiv.org/abs/1904.05862). - -## Pre-trained models - -Description | Dataset | Model ----|---|--- -Wav2Vec large | [Librispeech](http://www.openslr.org/12) | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec_large.pt) - -#### Example usage: -```python -import torch -import fairseq - -cp_path = '/path/to/wav2vec.pt' -model, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task([cp_path]) -model = model[0] -model.eval() - -wav_input_16khz = torch.randn(1,10000) -z = model.feature_extractor(wav_input_16khz) -c = model.feature_aggregator(z) -``` - -## Training a new model with the CLI tools - -Given a directory containing wav files to be used for pretraining (we recommend splitting each file into separate files 10 to 30 seconds in length) - -### Prepare training data manifest: - -``` -$ python examples/wav2vec/wav2vec_manifest.py /path/to/waves --dest /manifest/path --ext wav -``` - -### Train a wav2vec model: - -``` -$ python train.py /manifest/path --save-dir /model/path --num-workers 6 --fp16 --max-update 400000 --save-interval 1 --no-epoch-checkpoints \ ---arch wav2vec --task audio_pretraining --min-lr 1e-06 --stop-min-lr 1e-09 --optimizer adam --lr 0.005 --lr-scheduler cosine \ ---conv-feature-layers [(512, 10, 5), (512, 8, 4), (512, 4, 2), (512, 4, 2), (512, 4, 2), (512, 1, 1), (512, 1, 1)] \ ---conv-aggregator-layers [(512, 2, 1), (512, 3, 1), (512, 4, 1), (512, 5, 1), (512, 6, 1), (512, 7, 1), (512, 8, 1), (512, 9, 1), (512, 10, 1), (512, 11, 1), (512, 12, 1), (512, 13, 1)] \ ---skip-connections-agg --residual-scale 0.5 --log-compression --warmup-updates 500 --warmup-init-lr 1e-07 --criterion wav2vec --num-negatives 10 \ ---max-sample-size 150000 --max-tokens 1500000 --skip-invalid-size-inputs-valid-test -``` - -### Extract embeddings from the downstream task data: - -``` -$ PYTHONPATH=/path/to/fairseq python examples/wav2vec/wav2vec_featurize.py --input /path/to/task/waves --output /path/to/output \ ---model /model/path/checkpoint_best.pt --split train valid test -``` - -# vq-wav2vec - -Example to train a vq-wav2vec model as described in [vq-wav2vec: Self-Supervised Learning of Discrete Speech Representations (Baevski et al., 2019)](https://arxiv.org/abs/1910.05453). - -These models are also used in [Effectiveness of self-supervised pre-training for speech recognition (Baevski et al., 2019)](https://arxiv.org/abs/1911.03912). - -## Pre-trained models - -Description | Dataset | Model ----|---|--- -vq-wav2vec Gumbel | [Librispeech](http://www.openslr.org/12) | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/vq-wav2vec.pt) -vq-wav2vec K-means | [Librispeech](http://www.openslr.org/12) | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/vq-wav2vec_kmeans.pt) -Roberta on K-means codes | [Librispeech](http://www.openslr.org/12) | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/bert_kmeans.tar) - -#### Example usage: -```python -import torch -import fairseq - -cp = torch.load('/path/to/vq-wav2vec.pt') -model, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task([cp]) -model = model[0] -model.eval() - -wav_input_16khz = torch.randn(1,10000) -z = model.feature_extractor(wav_input_16khz) -_, idxs = model.vector_quantizer.forward_idx(z) -print(idxs.shape) # output: torch.Size([1, 60, 2]), 60 timesteps with 2 indexes corresponding to 2 groups in the model -``` - -## Training a new model with the CLI tools - -Given a directory containing wav files to be used for pretraining (we recommend splitting each file into separate file 10 to 30 seconds in length) - -### Prepare training data manifest: - -``` -$ python examples/wav2vec/wav2vec_manifest.py /path/to/waves --dest /manifest/path --ext wav -``` - -### Train a gumbel vq-wav2vec model: - -``` -$ python train.py /manifest/path --save-dir /model/path --num-workers 6 --fp16 --max-update 400000 \ ---save-interval 1 --no-epoch-checkpoints --arch wav2vec --task audio_pretraining --min-lr 1e-06 --stop-min-lr 1e-09 \ ---optimizer adam --lr 1e-05 --lr-scheduler cosine \ ---conv-feature-layers [(512, 10, 5), (512, 8, 4), (512, 4, 2), (512, 4, 2), (512, 4, 2), (512, 1, 1), (512, 1, 1), (512, 1, 1)] \ ---conv-aggregator-layers [(512, 2, 1), (512, 3, 1), (512, 4, 1), (512, 5, 1), (512, 6, 1), (512, 7, 1), (512, 8, 1), (512, 9, 1), (512, 10, 1), (512, 11, 1), (512, 12, 1), (512, 13, 1)] \ ---activation gelu --offset auto --skip-connections-agg --residual-scale 0.5 \ ---log-keys ["prob_perplexity","code_perplexity","temp"] --vq-type gumbel --vq-groups 2 --vq-depth 2 \ ---combine-groups --vq-vars 320 --vq-temp (2,0.5,0.999995) --prediction-steps 12 --warmup-updates 1000 \ ---warmup-init-lr 1e-07 --criterion wav2vec --num-negatives 10 --max-sample-size 150000 \ ---max-tokens 300000 --cross-sample-negatives 0 --update-freq 1 --seed 2 --skip-invalid-size-inputs-valid-test -``` - -for k-means training, set vq-type with "kmeans" and add --loss-weights [1] argument. Pre-trained models were trained on 16 GPUs. - -### Tokenize audio data (e.g. for BERT training): - -``` -$ PYTHONPATH=/path/to/fairseq python examples/wav2vec/vq-wav2vec_featurize.py --data-dir /manifest/path --output-dir /path/to/output \ ---checkpoint /model/path/checkpoint_best.pt --split train valid test --extension tsv -``` From f3b6f5817fbee59057ae2506f01502ea3c301b4b Mon Sep 17 00:00:00 2001 From: alexeib Date: Fri, 12 Feb 2021 11:32:52 -0800 Subject: [PATCH 11/82] Fix w2v readme (#1621) Summary: somehow merging previous pull request deleted the readme Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1621 Reviewed By: michaelauli Differential Revision: D26429893 Pulled By: alexeib fbshipit-source-id: 3e6ed1e4698e67e56e0b88d304f42907a4f6cf41 --- examples/wav2vec/README.md | 294 +++++++++++++++++++++++++++++++++++++ 1 file changed, 294 insertions(+) diff --git a/examples/wav2vec/README.md b/examples/wav2vec/README.md index e69de29bb2..e95f292b51 100644 --- a/examples/wav2vec/README.md +++ b/examples/wav2vec/README.md @@ -0,0 +1,294 @@ +# wav2vec 2.0 + +wav2vec 2.0 learns speech representations on unlabeled data as described in [wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations (Baevski et al., 2020)](https://arxiv.org/abs/2006.11477). + +We learned speech representations in multiple languages as well in [Unsupervised Cross-lingual Representation Learning for Speech Recognition (Conneau et al., 2020)](https://arxiv.org/abs/2006.13979). + +We also combined wav2vec 2.0 with self-training in [Self-training and Pre-training are Complementary for Speech Recognition (Xu et al., 2020)](https://arxiv.org/abs/2010.11430). + +## Pre-trained models + +Model | Finetuning split | Dataset | Model +|---|---|---|--- +Wav2Vec 2.0 Base | No finetuning | [Librispeech](http://www.openslr.org/12) | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec_small.pt) +Wav2Vec 2.0 Base | 10 minutes | [Librispeech](http://www.openslr.org/12) | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec_small_10m.pt) +Wav2Vec 2.0 Base | 100 hours | [Librispeech](http://www.openslr.org/12) | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec_small_100h.pt) +Wav2Vec 2.0 Base | 960 hours | [Librispeech](http://www.openslr.org/12) | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec_small_960h.pt) +Wav2Vec 2.0 Large | No finetuning | [Librispeech](http://www.openslr.org/12) | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/libri960_big.pt) +Wav2Vec 2.0 Large | 10 minutes | [Librispeech](http://www.openslr.org/12) | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec_big_10m.pt) +Wav2Vec 2.0 Large | 100 hours | [Librispeech](http://www.openslr.org/12) | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec_big_100h.pt) +Wav2Vec 2.0 Large | 960 hours | [Librispeech](http://www.openslr.org/12) | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec_big_960h.pt) +Wav2Vec 2.0 Large (LV-60)* | No finetuning | [Libri-Light](https://github.com/facebookresearch/libri-light) | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec_vox_new.pt) +Wav2Vec 2.0 Large (LV-60)* | 10 minutes | [Libri-Light](https://github.com/facebookresearch/libri-light) + [Librispeech](http://www.openslr.org/12) | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec_vox_10m_new.pt) +Wav2Vec 2.0 Large (LV-60)* | 100 hours | [Libri-Light](https://github.com/facebookresearch/libri-light) + [Librispeech](http://www.openslr.org/12) | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec_vox_100h_new.pt) +Wav2Vec 2.0 Large (LV-60)* | 960 hours | [Libri-Light](https://github.com/facebookresearch/libri-light) + [Librispeech](http://www.openslr.org/12) | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec2_vox_960h_new.pt) +Wav2Vec 2.0 Large (LV-60) + Self Training * | 10 minutes | [Libri-Light](https://github.com/facebookresearch/libri-light) + [Librispeech](http://www.openslr.org/12) | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec_vox_10m_pl.pt) +Wav2Vec 2.0 Large (LV-60) + Self Training * | 100 hours | [Libri-Light](https://github.com/facebookresearch/libri-light) + [Librispeech](http://www.openslr.org/12) | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec_vox_100h_pl.pt) +Wav2Vec 2.0 Large (LV-60) + Self Training * | 960 hours | [Libri-Light](https://github.com/facebookresearch/libri-light) + [Librispeech](http://www.openslr.org/12) | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec_vox_960h_pl.pt) + +\* updated (Oct. 24, 2020) + +We also release multilingual pre-trained wav2vec 2.0 (XLSR) models: + +Model | Architecture | Hours | Languages | Datasets | Model +|---|---|---|---|---|--- +XLSR-53 | Large | 56k | 53 | MLS, CommonVoice, BABEL | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/xlsr_53_56k.pt) + +The XLSR model uses the following datasets for multilingual pretraining: + +* **[MLS: Multilingual LibriSpeech](https://indico2.conference4me.psnc.pl/event/35/contributions/3585/attachments/1060/1101/Wed-2-6-10.pdf)** (8 languages, 50.7k hours): *Dutch, English, French, German, Italian, Polish, Portuguese, Spanish* + +* **[CommonVoice](https://commonvoice.mozilla.org/en/languages)** (36 languages, 3.6k hours): *Arabic, Basque, Breton, Chinese (CN), Chinese (HK), Chinese (TW), Chuvash, Dhivehi, Dutch, English, Esperanto, Estonian, French, German, Hakh-Chin, Indonesian, Interlingua, Irish, Italian, Japanese, Kabyle, Kinyarwanda, Kyrgyz, Latvian, Mongolian, Persian, Portuguese, Russian, Sakha, Slovenian, Spanish, Swedish, Tamil, Tatar, Turkish, Welsh* (see also [finetuning splits]([https://dl.fbaipublicfiles.com/cpc_audio/common_voices_splits.tar.gz]) from [this paper](https://arxiv.org/abs/2002.02848)). + +* **[Babel](https://catalog.ldc.upenn.edu/byyear)** (17 languages, 1.7k hours): *Assamese, Bengali, Cantonese, Cebuano, Georgian, Haitian, Kazakh, Kurmanji, Lao, Pashto, Swahili, Tagalog, Tamil, Tok, Turkish, Vietnamese, Zulu* + + +## Training a new model with the CLI tools + +Given a directory containing wav files to be used for pretraining (we recommend splitting each file into separate file 10 to 30 seconds in length) + +### Prepare training data manifest: + +First, install the `soundfile` library: +```shell script +pip install soundfile +``` + +Next, run: + +```shell script +$ python examples/wav2vec/wav2vec_manifest.py /path/to/waves --dest /manifest/path --ext $ext --valid-percent $valid +``` + +$ext should be set to flac, wav, or whatever format your dataset happens to use that soundfile can read. + +$valid should be set to some reasonable percentage (like 0.01) of training data to use for validation. +To use a pre-defined validation set (like dev-other from librispeech), set to it 0 and then overwrite valid.tsv with a +separately pre-processed manifest file. + +### Train a wav2vec 2.0 base model: + +This configuration was used for the base model trained on the Librispeech dataset in the wav2vec 2.0 paper + +Note that the input is expected to be single channel, sampled at 16 kHz + +```shell script +$ fairseq-hydra-train \ + task.data=/path/to/data \ + --config-dir /path/to/fairseq-py/examples/wav2vec/config/pretraining \ + --config-name wav2vec2_base_librispeech +``` + +Note: you can simulate 64 GPUs by using k GPUs and adding command line parameters (before `--config-dir`) +`distributed_training.distributed_world_size=k` `+optimization.update_freq='[x]'` where x = 64/k + +### Train a wav2vec 2.0 large model: + +This configuration was used for the large model trained on the Libri-light dataset in the wav2vec 2.0 paper + +```shell script +$ fairseq-hydra-train \ + task.data=/path/to/data \ + --config-dir /path/to/fairseq-py/examples/wav2vec/config/pretraining \ + --config-name wav2vec2_large_librivox +``` + +Note: you can simulate 128 GPUs by using k GPUs and adding command line parameters (before `--config-dir`) +`distributed_training.distributed_world_size=k` `+optimization.update_freq='[x]'` where x = 128/k + +### Fine-tune a pre-trained model with CTC: + +Fine-tuning a model requires parallel audio and labels file, as well as a vocabulary file in fairseq format. +A letter vocabulary can be downloaded [here](https://dl.fbaipublicfiles.com/fairseq/wav2vec/dict.ltr.txt). +An example [script](libri_labels.py) that generates labels for the Librispeech dataset from the tsv file produced by wav2vec_manifest.py can be used as follows: + +```shell script +split=train +$ python libri_labels.py /path/to/tsv --output-dir /output/dir --output-name $split +``` + +Fine-tuning on 100h of Librispeech with letter targets: +```shell script +$ fairseq-hydra-train \ + distributed_training.distributed_port=$PORT \ + task.data=/path/to/data \ + model.w2v_path=/path/to/model.pt \ + --config-dir /path/to/fairseq-py/examples/wav2vec/config/finetuning \ + --config-name base_100h +``` + +There are other config files in the config/finetuning directory that can be used to fine-tune on other splits. +You can specify the right config via the `--config-name` parameter. + +Note: you can simulate 24 GPUs by using k GPUs and adding command line parameters (before `--config-dir`) +`distributed_training.distributed_world_size=k` `+optimization.update_freq='[x]'` where x = 24/k + +Decoding with a language model during training requires flashlight [python bindings](https://github.com/facebookresearch/flashlight/tree/master/bindings/python) (previously called [wav2letter](https://github.com/facebookresearch/wav2letter). +If you want to use a language model, add `+criterion.wer_args='[/path/to/kenlm, /path/to/lexicon, 2, -1]'` to the command line. + +### Evaluating a CTC model: + +Evaluating a CTC model with a language model requires [flashlight python bindings](https://github.com/facebookresearch/flashlight/tree/master/bindings/python) (previously called [wav2letter](https://github.com/facebookresearch/wav2letter) to be installed. + +Fairseq transformer language model used in the wav2vec 2.0 paper can be obtained from the [wav2letter model repository](https://github.com/facebookresearch/wav2letter/tree/master/recipes/sota/2019). +Be sure to upper-case the language model vocab after downloading it. + +Letter dictionary for pre-trained models can be found [here](https://dl.fbaipublicfiles.com/fairseq/wav2vec/dict.ltr.txt). + +Next, run the evaluation command: + +```shell script +$subset=dev_other +python examples/speech_recognition/infer.py /checkpoint/abaevski/data/speech/libri/10h/wav2vec/raw --task audio_pretraining \ +--nbest 1 --path /path/to/model --gen-subset $subset --results-path /path/to/save/results/for/sclite --w2l-decoder kenlm \ +--lm-model /path/to/kenlm.bin --lm-weight 2 --word-score -1 --sil-weight 0 --criterion ctc --labels ltr --max-tokens 4000000 \ +--post-process letter +``` + +To get raw numbers, use --w2l-decoder viterbi and omit the lexicon. To use the transformer language model, use --w2l-decoder fairseqlm. + +## Use wav2vec 2.0 with 🤗Transformers: + +Wav2Vec2 is also available in the [🤗Transformers library](https://github.com/huggingface/transformers) since version 4.3. + +Pretrained Models can be found on the [hub](https://huggingface.co/models?filter=wav2vec2) +and documentation can be found [here](https://huggingface.co/transformers/master/model_doc/wav2vec2.html). + +Usage example: + +```python +# !pip install transformers +import soundfile as sf +import torch +from transformers import Wav2Vec2ForMaskedLM, Wav2Vec2Tokenizer + +# load pretrained model +tokenizer = Wav2Vec2Tokenizer.from_pretrained("facebook/wav2vec2-base-960h") +model = Wav2Vec2ForMaskedLM.from_pretrained("facebook/wav2vec2-base-960h") + +# load audio +audio_input, _ = sf.read("path/to/audio/file") + +# transcribe +input_values = tokenizer(audio_input, return_tensors="pt").input_values +logits = model(input_values).logits +predicted_ids = torch.argmax(logits, dim=-1) +transcription = tokenizer.batch_decode(predicted_ids)[0] +``` + +# wav2vec + +Example to train a wav2vec model as described in [wav2vec: Unsupervised Pre-training for Speech Recognition (Schneider et al., 2019)](https://arxiv.org/abs/1904.05862). + +## Pre-trained models + +Description | Dataset | Model +---|---|--- +Wav2Vec large | [Librispeech](http://www.openslr.org/12) | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec_large.pt) + +#### Example usage: +```python +import torch +import fairseq + +cp_path = '/path/to/wav2vec.pt' +model, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task([cp_path]) +model = model[0] +model.eval() + +wav_input_16khz = torch.randn(1,10000) +z = model.feature_extractor(wav_input_16khz) +c = model.feature_aggregator(z) +``` + +## Training a new model with the CLI tools + +Given a directory containing wav files to be used for pretraining (we recommend splitting each file into separate files 10 to 30 seconds in length) + +### Prepare training data manifest: + +``` +$ python examples/wav2vec/wav2vec_manifest.py /path/to/waves --dest /manifest/path --ext wav +``` + +### Train a wav2vec model: + +``` +$ python train.py /manifest/path --save-dir /model/path --num-workers 6 --fp16 --max-update 400000 --save-interval 1 --no-epoch-checkpoints \ +--arch wav2vec --task audio_pretraining --min-lr 1e-06 --stop-min-lr 1e-09 --optimizer adam --lr 0.005 --lr-scheduler cosine \ +--conv-feature-layers [(512, 10, 5), (512, 8, 4), (512, 4, 2), (512, 4, 2), (512, 4, 2), (512, 1, 1), (512, 1, 1)] \ +--conv-aggregator-layers [(512, 2, 1), (512, 3, 1), (512, 4, 1), (512, 5, 1), (512, 6, 1), (512, 7, 1), (512, 8, 1), (512, 9, 1), (512, 10, 1), (512, 11, 1), (512, 12, 1), (512, 13, 1)] \ +--skip-connections-agg --residual-scale 0.5 --log-compression --warmup-updates 500 --warmup-init-lr 1e-07 --criterion wav2vec --num-negatives 10 \ +--max-sample-size 150000 --max-tokens 1500000 --skip-invalid-size-inputs-valid-test +``` + +### Extract embeddings from the downstream task data: + +``` +$ PYTHONPATH=/path/to/fairseq python examples/wav2vec/wav2vec_featurize.py --input /path/to/task/waves --output /path/to/output \ +--model /model/path/checkpoint_best.pt --split train valid test +``` + +# vq-wav2vec + +Example to train a vq-wav2vec model as described in [vq-wav2vec: Self-Supervised Learning of Discrete Speech Representations (Baevski et al., 2019)](https://arxiv.org/abs/1910.05453). + +These models are also used in [Effectiveness of self-supervised pre-training for speech recognition (Baevski et al., 2019)](https://arxiv.org/abs/1911.03912). + +## Pre-trained models + +Description | Dataset | Model +---|---|--- +vq-wav2vec Gumbel | [Librispeech](http://www.openslr.org/12) | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/vq-wav2vec.pt) +vq-wav2vec K-means | [Librispeech](http://www.openslr.org/12) | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/vq-wav2vec_kmeans.pt) +Roberta on K-means codes | [Librispeech](http://www.openslr.org/12) | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/bert_kmeans.tar) + +#### Example usage: +```python +import torch +import fairseq + +cp = torch.load('/path/to/vq-wav2vec.pt') +model, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task([cp]) +model = model[0] +model.eval() + +wav_input_16khz = torch.randn(1,10000) +z = model.feature_extractor(wav_input_16khz) +_, idxs = model.vector_quantizer.forward_idx(z) +print(idxs.shape) # output: torch.Size([1, 60, 2]), 60 timesteps with 2 indexes corresponding to 2 groups in the model +``` + +## Training a new model with the CLI tools + +Given a directory containing wav files to be used for pretraining (we recommend splitting each file into separate file 10 to 30 seconds in length) + +### Prepare training data manifest: + +``` +$ python examples/wav2vec/wav2vec_manifest.py /path/to/waves --dest /manifest/path --ext wav +``` + +### Train a gumbel vq-wav2vec model: + +``` +$ python train.py /manifest/path --save-dir /model/path --num-workers 6 --fp16 --max-update 400000 \ +--save-interval 1 --no-epoch-checkpoints --arch wav2vec --task audio_pretraining --min-lr 1e-06 --stop-min-lr 1e-09 \ +--optimizer adam --lr 1e-05 --lr-scheduler cosine \ +--conv-feature-layers [(512, 10, 5), (512, 8, 4), (512, 4, 2), (512, 4, 2), (512, 4, 2), (512, 1, 1), (512, 1, 1), (512, 1, 1)] \ +--conv-aggregator-layers [(512, 2, 1), (512, 3, 1), (512, 4, 1), (512, 5, 1), (512, 6, 1), (512, 7, 1), (512, 8, 1), (512, 9, 1), (512, 10, 1), (512, 11, 1), (512, 12, 1), (512, 13, 1)] \ +--activation gelu --offset auto --skip-connections-agg --residual-scale 0.5 \ +--log-keys ["prob_perplexity","code_perplexity","temp"] --vq-type gumbel --vq-groups 2 --vq-depth 2 \ +--combine-groups --vq-vars 320 --vq-temp (2,0.5,0.999995) --prediction-steps 12 --warmup-updates 1000 \ +--warmup-init-lr 1e-07 --criterion wav2vec --num-negatives 10 --max-sample-size 150000 \ +--max-tokens 300000 --cross-sample-negatives 0 --update-freq 1 --seed 2 --skip-invalid-size-inputs-valid-test +``` + +for k-means training, set vq-type with "kmeans" and add --loss-weights [1] argument. Pre-trained models were trained on 16 GPUs. + +### Tokenize audio data (e.g. for BERT training): + +``` +$ PYTHONPATH=/path/to/fairseq python examples/wav2vec/vq-wav2vec_featurize.py --data-dir /manifest/path --output-dir /path/to/output \ +--checkpoint /model/path/checkpoint_best.pt --split train valid test --extension tsv +``` From 02803a1be45642b4c2f9c2970a4f4ae645a2dccf Mon Sep 17 00:00:00 2001 From: Weiyi Zheng Date: Fri, 12 Feb 2021 14:04:21 -0800 Subject: [PATCH 12/82] broadcast the whole optimizer state to each rank Summary: OSS removed the 'partition' key in their state dict to accommodate for changing partition size. This requires an update on the fairseq side to not look into the parameter partition, just broadcast everything, and let the optimizer on each rank decides which parameters are relevant. This diff also needs D26419095 to function completely, and blefaudeux has made fixes upstream in https://github.com/facebookresearch/fairscale/pull/383 Reviewed By: myleott Differential Revision: D26382917 fbshipit-source-id: 95af1022be59e88814748acaee36a1a350f7dc5b --- fairseq/optim/shard.py | 58 ++++++++---------------------------------- 1 file changed, 10 insertions(+), 48 deletions(-) diff --git a/fairseq/optim/shard.py b/fairseq/optim/shard.py index 3d025a23ca..3c1b34ae60 100644 --- a/fairseq/optim/shard.py +++ b/fairseq/optim/shard.py @@ -5,11 +5,11 @@ from typing import Any, Dict -import torch +from fairseq.distributed import utils try: - from fairscale.optim import OSS, utils + from fairscale.optim import OSS _has_fairscale = True except ImportError: @@ -38,53 +38,15 @@ def broadcast_global_state_dict( self, state_dict: Dict[str, Any] ) -> Dict[str, Any]: """ - Broadcasts the relevant parts of a global state dict from rank 0 to - all other ranks. + Broadcasts the entire state_dict to all other ranks + each rank is responsible to load their own partition of data """ - if self.rank == 0: - - # Create template state dict for all other keys not related to sharding - template_state_dict = { - key: state_dict[key] - for key in state_dict - if key not in ("param_groups", "state") - } - template_state_dict["local_state_dict"] = True - - for dst_rank in range(self.world_size): - # Get the dst_rank's param_groups shard - send_state = { - "param_groups": state_dict["param_groups"][ - state_dict["partition"][dst_rank][0] : state_dict[ - "partition" - ][dst_rank][1] - ], - "state": state_dict["state"][dst_rank], - } - send_state.update(template_state_dict) - - if dst_rank == 0: - recv_state = send_state - else: - utils.broadcast_object( - send_state, - src_rank=0, - group=self.group, - dist_device=self._device, - ) - else: - empty_buffer = torch.tensor([0], dtype=torch.uint8, device=self._device) - for dst_rank in range(1, self.world_size): - state = utils.broadcast_object( - empty_buffer, - src_rank=0, - group=self.group, - dist_device=self._device, - ) - if dst_rank == self.rank: - recv_state = state - - return recv_state + return utils.broadcast_object( + state_dict, + src_rank=0, + group=self.group, + dist_device=self._device, + ) torch_optimizer = optimizer.optimizer optim_cls = type(torch_optimizer) From 09945b45d4e2608563b1b18c3bbe289bf9351529 Mon Sep 17 00:00:00 2001 From: cordercorder <2205722269@qq.com> Date: Fri, 12 Feb 2021 14:35:55 -0800 Subject: [PATCH 13/82] Fixes bugs of evaluation with BLEU score when training with multi-gpus. (#3237) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: …ith BLEU scores # Before submitting - [no] Was this discussed/approved via a Github issue? (no need for typos, doc improvements) - [yes] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)? - [no need] Did you make sure to update the docs? - [no need] Did you write any new necessary tests? ## What does this PR do? Fixes bugs of evaluation with BLEU score when training with multi-gpus. But no error will happend if there is no distributed training. when --eval-bleu is set to be `True` (default it is `False` and the best checkpoint is selected according to loss) and training with multi-gpus (when the number of gpu which participate in distributed training is greater than 1), following error will happend. ```bash Traceback (most recent call last): Traceback (most recent call last): File "/data/cordercorder/anaconda3/envs/nmt/bin/fairseq-train", line 33, in File "/data/cordercorder/anaconda3/envs/nmt/bin/fairseq-train", line 33, in Traceback (most recent call last): File "/data/cordercorder/anaconda3/envs/nmt/bin/fairseq-train", line 33, in sys.exit(load_entry_point('fairseq', 'console_scripts', 'fairseq-train')())sys.exit(load_entry_point('fairseq', 'console_scripts', 'fairseq-train')()) File "/data1/cordercorder/fairseq/fairseq_cli/train.py", line 450, in cli_main File "/data1/cordercorder/fairseq/fairseq_cli/train.py", line 450, in cli_main sys.exit(load_entry_point('fairseq', 'console_scripts', 'fairseq-train')()) File "/data1/cordercorder/fairseq/fairseq_cli/train.py", line 450, in cli_main distributed_utils.call_main(cfg, main)distributed_utils.call_main(cfg, main) File "/data1/cordercorder/fairseq/fairseq/distributed/utils.py", line 349, in call_main File "/data1/cordercorder/fairseq/fairseq/distributed/utils.py", line 349, in call_main distributed_utils.call_main(cfg, main) File "/data1/cordercorder/fairseq/fairseq/distributed/utils.py", line 349, in call_main distributed_main(cfg.distributed_training.device_id, main, cfg, kwargs) distributed_main(cfg.distributed_training.device_id, main, cfg, kwargs) File "/data1/cordercorder/fairseq/fairseq/distributed/utils.py", line 326, in distributed_main File "/data1/cordercorder/fairseq/fairseq/distributed/utils.py", line 326, in distributed_main distributed_main(cfg.distributed_training.device_id, main, cfg, kwargs) File "/data1/cordercorder/fairseq/fairseq/distributed/utils.py", line 326, in distributed_main main(cfg, **kwargs) File "/data1/cordercorder/fairseq/fairseq_cli/train.py", line 143, in main main(cfg, **kwargs) main(cfg, **kwargs)rder/fairseq/fairseq_cli/train.py", line 143, in main File "/data1/cordercorder/fairseq/fairseq_cli/train.py", line 143, in main valid_losses, should_stop = train(cfg, trainer, task, epoch_itr) File "/data/cordercorder/anaconda3/envs/nmt/lib/python3.7/contextlib.py", line 74, in inner valid_losses, should_stop = train(cfg, trainer, task, epoch_itr) File "/data/cordercorder/anaconda3/envs/nmt/lib/python3.7/contextlib.py", line 74, in inner valid_losses, should_stop = train(cfg, trainer, task, epoch_itr) File "/data/cordercorder/anaconda3/envs/nmt/lib/python3.7/contextlib.py", line 74, in inner return func(*args, **kwds) File "/data1/cordercorder/fairseq/fairseq_cli/train.py", line 259, in train Traceback (most recent call last): File "/data/cordercorder/anaconda3/envs/nmt/bin/fairseq-train", line 33, in return func(*args, **kwds) return func(*args, **kwds) File "/data1/cordercorder/fairseq/fairseq_cli/train.py", line 259, in train File "/data1/cordercorder/fairseq/fairseq_cli/train.py", line 259, in train cfg, trainer, task, epoch_itr, valid_subsets, end_of_epoch File "/data1/cordercorder/fairseq/fairseq_cli/train.py", line 345, in validate_and_save cfg, trainer, task, epoch_itr, valid_subsets, end_of_epoch File "/data1/cordercorder/fairseq/fairseq_cli/train.py", line 345, in validate_and_save cfg, trainer, task, epoch_itr, valid_subsets, end_of_epochsys.exit(load_entry_point('fairseq', 'console_scripts', 'fairseq-train')()) File "/data1/cordercorder/fairseq/fairseq_cli/train.py", line 345, in validate_and_save File "/data1/cordercorder/fairseq/fairseq_cli/train.py", line 450, in cli_main valid_losses = validate(cfg, trainer, task, epoch_itr, valid_subsets) File "/data1/cordercorder/fairseq/fairseq_cli/train.py", line 413, in validate valid_losses = validate(cfg, trainer, task, epoch_itr, valid_subsets) File "/data1/cordercorder/fairseq/fairseq_cli/train.py", line 413, in validate valid_losses = validate(cfg, trainer, task, epoch_itr, valid_subsets) File "/data1/cordercorder/fairseq/fairseq_cli/train.py", line 413, in validate trainer.valid_step(sample) File "/data/cordercorder/anaconda3/envs/nmt/lib/python3.7/contextlib.py", line 74, in inner distributed_utils.call_main(cfg, main) File "/data1/cordercorder/fairseq/fairseq/distributed/utils.py", line 349, in call_main trainer.valid_step(sample) File "/data/cordercorder/anaconda3/envs/nmt/lib/python3.7/contextlib.py", line 74, in inner return func(*args, **kwds) File "/data1/cordercorder/fairseq/fairseq/trainer.py", line 834, in valid_step trainer.valid_step(sample) File "/data/cordercorder/anaconda3/envs/nmt/lib/python3.7/contextlib.py", line 74, in inner return func(*args, **kwds) File "/data1/cordercorder/fairseq/fairseq/trainer.py", line 834, in valid_step return func(*args, **kwds)distributed_main(cfg.distributed_training.device_id, main, cfg, kwargs) File "/data1/cordercorder/fairseq/fairseq/trainer.py", line 834, in valid_step File "/data1/cordercorder/fairseq/fairseq/distributed/utils.py", line 326, in distributed_main main(cfg, **kwargs) File "/data1/cordercorder/fairseq/fairseq_cli/train.py", line 143, in main logging_output = self._reduce_and_log_stats(logging_outputs, sample_size) File "/data1/cordercorder/fairseq/fairseq/trainer.py", line 1157, in _reduce_and_log_stats logging_output = self._reduce_and_log_stats(logging_outputs, sample_size) File "/data1/cordercorder/fairseq/fairseq/trainer.py", line 1157, in _reduce_and_log_stats valid_losses, should_stop = train(cfg, trainer, task, epoch_itr) File "/data/cordercorder/anaconda3/envs/nmt/lib/python3.7/contextlib.py", line 74, in inner logging_output = self._reduce_and_log_stats(logging_outputs, sample_size) File "/data1/cordercorder/fairseq/fairseq/trainer.py", line 1157, in _reduce_and_log_stats return func(*args, **kwds) File "/data1/cordercorder/fairseq/fairseq_cli/train.py", line 259, in train cfg, trainer, task, epoch_itr, valid_subsets, end_of_epoch File "/data1/cordercorder/fairseq/fairseq_cli/train.py", line 345, in validate_and_save self.task.reduce_metrics(logging_outputs, self.get_criterion()) File "/data1/cordercorder/fairseq/fairseq/tasks/translation.py", line 410, in reduce_metrics self.task.reduce_metrics(logging_outputs, self.get_criterion())valid_losses = validate(cfg, trainer, task, epoch_itr, valid_subsets) File "/data1/cordercorder/fairseq/fairseq/tasks/translation.py", line 410, in reduce_metrics File "/data1/cordercorder/fairseq/fairseq_cli/train.py", line 413, in validate self.task.reduce_metrics(logging_outputs, self.get_criterion()) File "/data1/cordercorder/fairseq/fairseq/tasks/translation.py", line 410, in reduce_metrics metrics.log_scalar("_bleu_counts", np.array(counts)) File "/data/cordercorder/anaconda3/envs/nmt/lib/python3.7/site-packages/torch/tensor.py", line 480, in __array__ trainer.valid_step(sample) File "/data/cordercorder/anaconda3/envs/nmt/lib/python3.7/contextlib.py", line 74, in inner metrics.log_scalar("_bleu_counts", np.array(counts)) File "/data/cordercorder/anaconda3/envs/nmt/lib/python3.7/site-packages/torch/tensor.py", line 480, in __array__ return func(*args, **kwds)metrics.log_scalar("_bleu_counts", np.array(counts)) File "/data1/cordercorder/fairseq/fairseq/trainer.py", line 834, in valid_step File "/data/cordercorder/anaconda3/envs/nmt/lib/python3.7/site-packages/torch/tensor.py", line 480, in __array__ return self.numpy() TypeError: can't convert cuda:2 device type tensor to numpy. Use Tensor.cpu() to copy the tensor to host memory first. return self.numpy() TypeError: can't convert cuda:3 device type tensor to numpy. Use Tensor.cpu() to copy the tensor to host memory first. return self.numpy() TypeError: can't convert cuda:1 device type tensor to numpy. Use Tensor.cpu() to copy the tensor to host memory first. logging_output = self._reduce_and_log_stats(logging_outputs, sample_size) File "/data1/cordercorder/fairseq/fairseq/trainer.py", line 1157, in _reduce_and_log_stats self.task.reduce_metrics(logging_outputs, self.get_criterion()) File "/data1/cordercorder/fairseq/fairseq/tasks/translation.py", line 410, in reduce_metrics metrics.log_scalar("_bleu_counts", np.array(counts)) File "/data/cordercorder/anaconda3/envs/nmt/lib/python3.7/site-packages/torch/tensor.py", line 480, in __array__ return self.numpy() TypeError: can't convert cuda:0 device type tensor to numpy. Use Tensor.cpu() to copy the tensor to host memory first. Traceback (most recent call last): File "/data/cordercorder/anaconda3/envs/nmt/lib/python3.7/runpy.py", line 193, in _run_module_as_main "__main__", mod_spec) File "/data/cordercorder/anaconda3/envs/nmt/lib/python3.7/runpy.py", line 85, in _run_code exec(code, run_globals) File "/data/cordercorder/anaconda3/envs/nmt/lib/python3.7/site-packages/torch/distributed/launch.py", line 261, in main() File "/data/cordercorder/anaconda3/envs/nmt/lib/python3.7/site-packages/torch/distributed/launch.py", line 257, in main cmd=cmd) subprocess.CalledProcessError: Command '['/data/cordercorder/anaconda3/envs/nmt/bin/python', '-u', '/data/cordercorder/anaconda3/envs/nmt/bin/fairseq-train', '--local_rank=3', 'tiny_data_bin', '--distributed-world-size', '4', '--arch', 'transformer', '--share-decoder-input-output-embed', '--optimizer', 'adam', '--adam-betas', '(0.9, 0.98)', '--clip-norm', '0.0', '--lr-scheduler', 'inverse_sqrt', '--warmup-init-lr', '1e-07', '--warmup-updates', '3000', '--lr', '0.0005', '--stop-min-lr', '1e-09', '--dropout', '0.25', '--weight-decay', '0.0001', '--criterion', 'label_smoothed_cross_entropy', '--label-smoothing', '0.1', '--max-tokens', '5000', '--batch-size', '64', '--update-freq', '4', '--max-epoch', '30', '--save-dir', 'checkpoint', '--skip-invalid-size-inputs-valid-test', '--eval-bleu', '--eval-bleu-args', '{"beam": 5}', '--eval-bleu-remove-bpe', 'sentencepiece', '--eval-bleu-print-samples', '--eval-tokenized-bleu', '--best-checkpoint-metric', 'bleu', '--maximize-best-checkpoint-metric', '--validate-interval-updates', '1']' returned non-zero exit status 1. ``` The error is cased by the fact that the numpy of version 1.20.1 does't support codes like following: ```python import torch import numpy as np a = torch.tensor(0, device="cuda:0") b = np.array([a]) ``` The above codes will lead to error: "TypeError: can't convert cuda:0 device type tensor to numpy. Use Tensor.cpu() to copy the tensor to host memory first.", but the codes run well if the numpy version is 1.18.1 or 1.17.0 (when the numpy version is below 1.20.0, it is ok, I guess). However, it seems like that the latest version of fairseq need a numpy package of version 1.20.0 or higher (issue https://github.com/pytorch/fairseq/issues/3203 ). ### Reproduce the error Download the source code of fairseq (commit ID: 7061a0ff83872ac491ba5963eb7fc04cb10d57c4) and run following code: ```bash export CUDA_VISIBLE_DEVICES=0,1,2,3 data_bin_dir=tiny_data_bin python -m torch.distributed.launch --nproc_per_node=4 \ --master_addr="127.0.0.1" \ --master_port=12345 \ $(which fairseq-train) ${data_bin_dir} \ --distributed-world-size 4 \ --arch transformer \ --share-decoder-input-output-embed \ --optimizer adam \ --adam-betas '(0.9, 0.98)' \ --clip-norm 0.0 \ --lr-scheduler inverse_sqrt \ --warmup-init-lr 1e-07 \ --warmup-updates 3000 \ --lr 0.0005 \ --stop-min-lr 1e-09 \ --dropout 0.25 \ --weight-decay 0.0001 \ --criterion label_smoothed_cross_entropy \ --label-smoothing 0.1 \ --max-tokens 5000 \ --batch-size 64 \ --update-freq 4 \ --max-epoch 30 \ --save-dir checkpoint \ --skip-invalid-size-inputs-valid-test \ --eval-bleu \ --eval-bleu-args '{"beam": 5}' \ --eval-bleu-remove-bpe sentencepiece \ --eval-bleu-print-samples \ --eval-tokenized-bleu \ --best-checkpoint-metric bleu \ --maximize-best-checkpoint-metric \ --validate-interval-updates 1 ``` ## PR review Anyone in the community is free to review the PR once the tests have passed. If we didn't discuss your PR in Github issues there's a high chance it will not be merged. ## Did you have fun? Make sure you had fun coding � Pull Request resolved: https://github.com/pytorch/fairseq/pull/3237 Reviewed By: myleott Differential Revision: D26429732 Pulled By: alexeib fbshipit-source-id: bc887ce952d28541cb07dbbdc7e80e99428a6b34 --- fairseq/tasks/translation.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/fairseq/tasks/translation.py b/fairseq/tasks/translation.py index 90635d882f..331f685495 100644 --- a/fairseq/tasks/translation.py +++ b/fairseq/tasks/translation.py @@ -394,7 +394,11 @@ def reduce_metrics(self, logging_outputs, criterion): if self.cfg.eval_bleu: def sum_logs(key): - return sum(log.get(key, 0) for log in logging_outputs) + import torch + result = sum(log.get(key, 0) for log in logging_outputs) + if torch.is_tensor(result): + result = result.cpu() + return result counts, totals = [], [] for i in range(EVAL_BLEU_ORDER): From 5ac5e8a20a7a914698f9970c2a384f14015ece3d Mon Sep 17 00:00:00 2001 From: alexeib Date: Fri, 12 Feb 2021 21:18:23 -0800 Subject: [PATCH 14/82] fix sharing objects between tasks (#1623) Summary: fixes previous change that changes state/dataset/etc to class variables instead of instance variables Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1623 Reviewed By: michaelauli Differential Revision: D26439560 Pulled By: alexeib fbshipit-source-id: ab9e75a425a47ac7ace006419259e254770e560e --- fairseq/tasks/fairseq_task.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/fairseq/tasks/fairseq_task.py b/fairseq/tasks/fairseq_task.py index 04025023fa..375b5277b9 100644 --- a/fairseq/tasks/fairseq_task.py +++ b/fairseq/tasks/fairseq_task.py @@ -78,12 +78,15 @@ def logging_outputs_can_be_summed(criterion) -> bool: return criterion.logging_outputs_can_be_summed() cfg: FairseqDataclass - datasets: Dict[str, FairseqDataset] = dict() - dataset_to_epoch_iter: Dict[FairseqDataset, Any] = dict() - state: StatefulContainer = StatefulContainer() + datasets: Dict[str, FairseqDataset] + dataset_to_epoch_iter: Dict[FairseqDataset, Any] + state: StatefulContainer = None def __init__(self, cfg: FairseqDataclass, **kwargs): self.cfg = cfg + self.datasets = dict() + self.dataset_to_epoch_iter = dict() + self.state = StatefulContainer() @classmethod def load_dictionary(cls, filename): @@ -553,10 +556,13 @@ def reduce_metrics(self, logging_outputs, criterion): criterion.__class__.reduce_metrics(logging_outputs) def state_dict(self): - return self.state.state_dict + if self.state is not None: + return self.state.state_dict + return {} def load_state_dict(self, state_dict: Dict[str, Any]): - self.state.merge_state_dict(state_dict) + if self.state is not None: + self.state.merge_state_dict(state_dict) def max_positions(self): """Return the max input length allowed by the task.""" From 43415b44781af6ac9c10adce0ae2a7d26d611bd1 Mon Sep 17 00:00:00 2001 From: Myle Ott Date: Tue, 16 Feb 2021 15:50:46 -0800 Subject: [PATCH 15/82] Prepend embedding layer when return_all_hiddens=True in TransformerEncoder (#1559) Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1559 This matches the behavior of RobertaEncoder. Test Plan: Imported from OSS Reviewed By: gwenzek Differential Revision: D25936937 Pulled By: myleott fbshipit-source-id: 795ec8d50298a41d9e9638101436faa01cdf1586 --- fairseq/models/transformer.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/fairseq/models/transformer.py b/fairseq/models/transformer.py index 362d9b28d6..78762ef924 100644 --- a/fairseq/models/transformer.py +++ b/fairseq/models/transformer.py @@ -435,14 +435,21 @@ def forward( """ x, encoder_embedding = self.forward_embedding(src_tokens, token_embeddings) - # B x T x C -> T x B x C - x = x.transpose(0, 1) - # compute padding mask encoder_padding_mask = src_tokens.eq(self.padding_idx) + # account for padding while computing the representation + if encoder_padding_mask is not None: + x = x * (1 - encoder_padding_mask.unsqueeze(-1).type_as(x)) + + # B x T x C -> T x B x C + x = x.transpose(0, 1) + encoder_states = [] + if return_all_hiddens: + encoder_states.append(x) + # encoder layers for layer in self.layers: x = layer(x, encoder_padding_mask) @@ -454,7 +461,7 @@ def forward( x = self.layer_norm(x) # The Pytorch Mobile lite interpreter does not supports returning NamedTuple in - # `foward` so we use a dictionary instead. + # `forward` so we use a dictionary instead. # TorchScript does not support mixed values so the values are all lists. # The empty list is equivalent to None. return { From 54423d3b22a3e7f536e02e9e5445cef9becbd60d Mon Sep 17 00:00:00 2001 From: Myle Ott Date: Tue, 16 Feb 2021 15:50:46 -0800 Subject: [PATCH 16/82] refactor RobertaEncoder (#1560) Summary: This is long overdue, but finally deprecating the RobertaEncoder components and just using TransformerEncoder directly. This will make it easier for some upcoming online backtranslation changes, and will eventually make migrating it to dataclasses/Hydra easier too. It also fixes some longstanding inconsistencies in layernorm placement in the model parallel roberta code. Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1560 Test Plan: - confirmed that training gives identical losses as before: https://gist.github.com/myleott/9a4d213fb88a02b00094ea074f5a2e2d - confirmed that old roberta models can be loaded and produce identical results - confirmed that old linformer models can be loaded and produce identical results (reran commands from D25938236 (https://github.com/pytorch/fairseq/commit/bf54551cafa13678c0254d2c20354cc026cc0bac)) - confirmed that old model parallel models can be loaded and produce identical results: ``` python -m fairseq_cli.validate --path checkpoint.mp1/checkpoint_last.pt --task dummy_masked_lm --criterion masked_lm --max-sentences 8 --dataset-size 100 --model-parallel-size 2 --distributed-world-size 2 before: 2021-01-19 19:04:14 | INFO | valid | | valid on 'valid' subset | loss 14.62 | ppl 25174.3 | wps 0 | wpb 53248 | bsz 104 after: 2021-01-19 19:06:59 | INFO | valid | | valid on 'valid' subset | loss 14.62 | ppl 25174.3 | wps 0 | wpb 53248 | bsz 104 ``` Reviewed By: gwenzek, ngoyal2707 Differential Revision: D25937145 Pulled By: myleott fbshipit-source-id: 1ce0bc93e28e03fb926534ea4134684a49232599 --- .../linformer_src/models/linformer_roberta.py | 71 ++------- .../modules/linformer_sentence_encoder.py | 137 ++-------------- .../linformer_sentence_encoder_layer.py | 83 ++-------- .../model_parallel/models/roberta/model.py | 148 +++++------------- fairseq/model_parallel/models/transformer.py | 7 +- fairseq/model_parallel/modules/__init__.py | 6 - .../modules/transformer_sentence_encoder.py | 59 ------- .../transformer_sentence_encoder_layer.py | 77 --------- fairseq/models/roberta/model.py | 100 ++++++++---- fairseq/models/transformer.py | 1 + 10 files changed, 161 insertions(+), 528 deletions(-) delete mode 100644 fairseq/model_parallel/modules/transformer_sentence_encoder.py delete mode 100644 fairseq/model_parallel/modules/transformer_sentence_encoder_layer.py diff --git a/examples/linformer/linformer_src/models/linformer_roberta.py b/examples/linformer/linformer_src/models/linformer_roberta.py index be5d8e85ec..18ad44f079 100644 --- a/examples/linformer/linformer_src/models/linformer_roberta.py +++ b/examples/linformer/linformer_src/models/linformer_roberta.py @@ -11,9 +11,15 @@ import torch from fairseq import utils from fairseq.models import register_model, register_model_architecture -from fairseq.models.roberta import RobertaEncoder, RobertaModel +from fairseq.models.roberta import ( + init_bert_params, + roberta_base_architecture, + roberta_large_architecture, + RobertaEncoder, + RobertaModel, +) -from ..modules.linformer_sentence_encoder import LinformerSentenceEncoder +from ..modules.linformer_sentence_encoder import LinformerTransformerEncoder logger = logging.getLogger(__name__) @@ -66,30 +72,10 @@ def __init__(self, args, dictionary): super().__init__(args, dictionary) self.register_buffer("version", torch.tensor(2)) - def build_encoder(self, args, dictionary): - return LinformerSentenceEncoder( - padding_idx=dictionary.pad(), - vocab_size=len(dictionary), - num_encoder_layers=args.encoder_layers, - embedding_dim=args.encoder_embed_dim, - ffn_embedding_dim=args.encoder_ffn_embed_dim, - num_attention_heads=args.encoder_attention_heads, - dropout=args.dropout, - attention_dropout=args.attention_dropout, - activation_dropout=args.activation_dropout, - layerdrop=args.encoder_layerdrop, - max_seq_len=args.max_positions, - num_segments=0, - encoder_normalize_before=True, - apply_bert_init=True, - activation_fn=args.activation_fn, - q_noise=args.quant_noise_pq, - qn_block_size=args.quant_noise_pq_block_size, - compressed=args.compressed, - shared_kv_compressed=args.shared_kv_compressed, - shared_layer_kv_compressed=args.shared_layer_kv_compressed, - freeze_compress=args.freeze_compress, - ) + def build_encoder(self, args, dictionary, embed_tokens): + encoder = LinformerTransformerEncoder(args, dictionary, embed_tokens) + encoder.apply(init_bert_params) + return encoder def upgrade_state_dict_named(self, state_dict, name): super().upgrade_state_dict_named(state_dict, name) @@ -115,25 +101,11 @@ def upgrade_state_dict_named(self, state_dict, name): @register_model_architecture("linformer_roberta", "linformer_roberta") def base_architecture(args): - args.encoder_layers = getattr(args, "encoder_layers", 12) - args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 768) - args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 3072) - args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 12) - - args.activation_fn = getattr(args, "activation_fn", "gelu") - args.pooler_activation_fn = getattr(args, "pooler_activation_fn", "tanh") - - args.dropout = getattr(args, "dropout", 0.1) - args.attention_dropout = getattr(args, "attention_dropout", 0.1) - args.activation_dropout = getattr(args, "activation_dropout", 0.0) - args.pooler_dropout = getattr(args, "pooler_dropout", 0.0) - args.encoder_layers_to_keep = getattr(args, "encoder_layers_to_keep", None) - args.encoder_layerdrop = getattr(args, "encoder_layerdrop", 0.0) - args.compressed = getattr(args, "compressed", 4) args.shared_kv_compressed = getattr(args, "shared_kv_compressed", 0) args.shared_layer_kv_compressed = getattr(args, "shared_layer_kv_compressed", 0) args.freeze_compress = getattr(args, "freeze_compress", 0) + roberta_base_architecture(args) @register_model_architecture("linformer_roberta", "linformer_roberta_base") @@ -143,18 +115,5 @@ def linformer_roberta_base_architecture(args): @register_model_architecture("linformer_roberta", "linformer_roberta_large") def linformer_roberta_large_architecture(args): - args.encoder_layers = getattr(args, "encoder_layers", 24) - args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 1024) - args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 4096) - args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 16) - - args.activation_fn = getattr(args, "activation_fn", "gelu") - args.pooler_activation_fn = getattr(args, "pooler_activation_fn", "tanh") - - args.dropout = getattr(args, "dropout", 0.1) - args.attention_dropout = getattr(args, "attention_dropout", 0.1) - args.activation_dropout = getattr(args, "activation_dropout", 0.0) - args.pooler_dropout = getattr(args, "pooler_dropout", 0.0) - args.compressed = getattr(args, "compressed", 4) - args.shared_kv_compressed = getattr(args, "shared_kv_compressed", 0) - args.shared_layer_kv_compressed = getattr(args, "shared_layer_kv_compressed", 0) + roberta_large_architecture(args) + base_architecture(args) diff --git a/examples/linformer/linformer_src/modules/linformer_sentence_encoder.py b/examples/linformer/linformer_src/modules/linformer_sentence_encoder.py index 3cdca01235..44f7989bd8 100644 --- a/examples/linformer/linformer_src/modules/linformer_sentence_encoder.py +++ b/examples/linformer/linformer_src/modules/linformer_sentence_encoder.py @@ -6,12 +6,12 @@ import math import torch.nn as nn -from fairseq.modules import TransformerSentenceEncoder +from fairseq.models.transformer import TransformerEncoder -from .linformer_sentence_encoder_layer import LinformerSentenceEncoderLayer +from .linformer_sentence_encoder_layer import LinformerTransformerEncoderLayer -class LinformerSentenceEncoder(TransformerSentenceEncoder): +class LinformerTransformerEncoder(TransformerEncoder): """ Implementation for a Bi-directional Linformer based Sentence Encoder used in BERT/XLM style pre-trained models. @@ -35,135 +35,20 @@ class LinformerSentenceEncoder(TransformerSentenceEncoder): in format B x C. """ - def __init__( - self, - padding_idx: int, - vocab_size: int, - num_encoder_layers: int = 6, - embedding_dim: int = 768, - ffn_embedding_dim: int = 3072, - num_attention_heads: int = 8, - dropout: float = 0.1, - attention_dropout: float = 0.1, - activation_dropout: float = 0.1, - layerdrop: float = 0.0, - max_seq_len: int = 256, - num_segments: int = 2, - use_position_embeddings: bool = True, - offset_positions_by_padding: bool = True, - encoder_normalize_before: bool = False, - apply_bert_init: bool = False, - activation_fn: str = "relu", - learned_pos_embedding: bool = True, - embed_scale: float = None, - freeze_embeddings: bool = False, - n_trans_layers_to_freeze: int = 0, - export: bool = False, - traceable: bool = False, - q_noise: float = 0.0, - qn_block_size: int = 8, - compressed: int = 4, - shared_kv_compressed: int = 0, - shared_layer_kv_compressed: int = 0, - freeze_compress: int = 0, - ) -> None: - - # Initialize linformer parameters - self.compressed = compressed - self.shared_kv_compressed = shared_kv_compressed - self.shared_layer_kv_compressed = shared_layer_kv_compressed + def __init__(self, args, dictionary, embed_tokens): self.compress_layer = None - self.freeze_compress = freeze_compress - - super().__init__( - padding_idx=padding_idx, - vocab_size=vocab_size, - num_encoder_layers=num_encoder_layers, - embedding_dim=embedding_dim, - ffn_embedding_dim=ffn_embedding_dim, - num_attention_heads=num_attention_heads, - dropout=dropout, - attention_dropout=attention_dropout, - activation_dropout=activation_dropout, - layerdrop=layerdrop, - max_seq_len=max_seq_len, - num_segments=num_segments, - use_position_embeddings=use_position_embeddings, - offset_positions_by_padding=offset_positions_by_padding, - encoder_normalize_before=encoder_normalize_before, - apply_bert_init=apply_bert_init, - activation_fn=activation_fn, - learned_pos_embedding=learned_pos_embedding, - embed_scale=embed_scale, - freeze_embeddings=freeze_embeddings, - n_trans_layers_to_freeze=n_trans_layers_to_freeze, - export=export, - traceable=traceable, - q_noise=q_noise, - qn_block_size=qn_block_size, - ) + super().__init__(args, dictionary, embed_tokens) - def build_transformer_sentence_encoder_layer( - self, - embedding_dim, - ffn_embedding_dim, - num_attention_heads, - dropout, - attention_dropout, - activation_dropout, - activation_fn, - export, - q_noise, - qn_block_size, - ): - if self.shared_layer_kv_compressed == 1 and self.compress_layer is None: + def build_encoder_layer(self, args): + if self.args.shared_layer_kv_compressed == 1 and self.compress_layer is None: compress_layer = nn.Linear( - self.max_seq_len, self.max_seq_len // self.compressed + self.args.max_positions, + self.args.max_positions // self.args.compressed, ) # intialize parameters for compressed layer nn.init.xavier_uniform_(compress_layer.weight, gain=1 / math.sqrt(2)) - if self.freeze_compress == 1: + if self.args.freeze_compress == 1: compress_layer.weight.requires_grad = False self.compress_layer = compress_layer - return LinformerSentenceEncoderLayer( - embedding_dim=embedding_dim, - ffn_embedding_dim=ffn_embedding_dim, - num_attention_heads=num_attention_heads, - dropout=dropout, - attention_dropout=attention_dropout, - activation_dropout=activation_dropout, - activation_fn=activation_fn, - export=export, - q_noise=q_noise, - qn_block_size=qn_block_size, - compressed=self.compressed, - max_seq_len=self.max_seq_len, - shared_kv_compressed=self.shared_kv_compressed, - shared_compress_layer=( - None if self.shared_layer_kv_compressed == 0 else self.compress_layer - ), - freeze_compress=self.freeze_compress, - ) - - def upgrade_state_dict_named(self, state_dict, name): - prefix = name + "." if name != "" else "" - items_to_add = {} - keys_to_remove = [] - - # update key name for shared layer in new version of code - for k in state_dict.keys(): - if k.startswith(prefix + "compress_layer"): - if self.shared_layer_kv_compressed: - for layer_idx in range(len(self.layers)): - new_k = prefix + "layers.{0}.shared_compress_layer.{1}".format( - layer_idx, - k[len(prefix + "compress_layer.") :], - ) - items_to_add[new_k] = state_dict[k] - - for k in keys_to_remove: - del state_dict[k] - - for key, value in items_to_add.items(): - state_dict[key] = value + return LinformerTransformerEncoderLayer(args, self.compress_layer) diff --git a/examples/linformer/linformer_src/modules/linformer_sentence_encoder_layer.py b/examples/linformer/linformer_src/modules/linformer_sentence_encoder_layer.py index 0b80fabefe..7e2caa0340 100644 --- a/examples/linformer/linformer_src/modules/linformer_sentence_encoder_layer.py +++ b/examples/linformer/linformer_src/modules/linformer_sentence_encoder_layer.py @@ -3,88 +3,44 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -from typing import Callable - import torch from fairseq import utils -from fairseq.modules import TransformerSentenceEncoderLayer +from fairseq.modules import TransformerEncoderLayer from .multihead_linear_attention import MultiheadLinearAttention -class LinformerSentenceEncoderLayer(TransformerSentenceEncoderLayer): +class LinformerTransformerEncoderLayer(TransformerEncoderLayer): """ Implements a Linformer Encoder Layer used in BERT/XLM style pre-trained models. """ - def __init__( - self, - embedding_dim: int = 768, - ffn_embedding_dim: int = 3072, - num_attention_heads: int = 8, - dropout: float = 0.1, - attention_dropout: float = 0.1, - activation_dropout: float = 0.1, - activation_fn: str = "relu", - export: bool = False, - q_noise: float = 0.0, - qn_block_size: int = 8, - init_fn: Callable = None, - compressed: int = 1, - max_seq_len: int = 256, - shared_kv_compressed: int = 0, - shared_compress_layer: any = None, - freeze_compress: int = 0, - ) -> None: - - # Initialize linformer parameters - self.compressed = compressed - self.max_seq_len = max_seq_len - self.shared_kv_compressed = shared_kv_compressed - self.freeze_compress = freeze_compress - + def __init__(self, args, shared_compress_layer): # wrap in a list so it's not automatically registered by PyTorch self.shared_compress_layer = [shared_compress_layer] - super().__init__( - embedding_dim=embedding_dim, - ffn_embedding_dim=ffn_embedding_dim, - num_attention_heads=num_attention_heads, - dropout=dropout, - attention_dropout=attention_dropout, - activation_dropout=activation_dropout, - activation_fn=activation_fn, - export=export, - q_noise=q_noise, - qn_block_size=qn_block_size, - ) + super().__init__(args) + self.register_buffer("version", torch.tensor(2)) - def build_self_attention( - self, - embed_dim, - num_attention_heads, - dropout, - self_attention, - q_noise, - qn_block_size, - ): + def build_self_attention(self, embed_dim, args): return MultiheadLinearAttention( embed_dim, - num_attention_heads, - dropout=dropout, + args.encoder_attention_heads, + dropout=args.dropout, self_attention=True, - q_noise=q_noise, - qn_block_size=qn_block_size, - compressed=self.compressed, - max_seq_len=self.max_seq_len, - shared_kv_compressed=self.shared_kv_compressed, + q_noise=args.quant_noise_pq, + qn_block_size=args.quant_noise_pq_block_size, + compressed=args.compressed, + max_seq_len=args.max_positions, + shared_kv_compressed=args.shared_kv_compressed, shared_compress_layer=self.shared_compress_layer[0], - freeze_compress=self.freeze_compress, + freeze_compress=args.freeze_compress, ) def upgrade_state_dict_named(self, state_dict, name): + super().upgrade_state_dict_named(state_dict, name) prefix = name + "." if name != "" else "" # some old checkpoints had weight sharing implemented incorrectly @@ -101,14 +57,7 @@ def upgrade_state_dict_named(self, state_dict, name): self.shared_compress_layer[0].weight.size(0), ) ] - self.self_attn = self.build_self_attention( - self.embedding_dim, - self.num_attention_heads, - dropout=self.attention_dropout, - self_attention=True, - q_noise=self.q_noise, - qn_block_size=self.qn_block_size, - ) + self.self_attn = self.build_self_attention(self.embed_dim, self.args) # delete shared_compress_layer, since it's already copied to # self_attn.compress_k.weight del state_dict[f"{prefix}shared_compress_layer.weight"] diff --git a/fairseq/model_parallel/models/roberta/model.py b/fairseq/model_parallel/models/roberta/model.py index 68ad88d2a5..77a80ef720 100644 --- a/fairseq/model_parallel/models/roberta/model.py +++ b/fairseq/model_parallel/models/roberta/model.py @@ -12,16 +12,15 @@ import torch.nn as nn import torch.nn.functional as F from fairseq import utils -from fairseq.model_parallel.modules import ModelParallelTransformerSentenceEncoder -from fairseq.models import FairseqEncoder, register_model, register_model_architecture +from fairseq.model_parallel.models.transformer import ModelParallelTransformerEncoder +from fairseq.models import register_model, register_model_architecture from fairseq.models.roberta import ( - RobertaClassificationHead, + roberta_base_architecture, + roberta_prenorm_architecture, RobertaEncoder, - RobertaLMHead, RobertaModel, ) -from fairseq.modules import LayerNorm, TransformerSentenceEncoder -from fairseq.modules.transformer_sentence_encoder import init_bert_params +from fairseq.modules import LayerNorm try: @@ -29,7 +28,7 @@ copy_to_model_parallel_region, gather_from_model_parallel_region, ColumnParallelLinear, - RowParallelLinear, + VocabParallelEmbedding, ) has_megatron_submodule = True @@ -48,7 +47,15 @@ def __init__(self, args, encoder): @staticmethod def add_args(parser): - super(ModelParallelRobertaModel, ModelParallelRobertaModel).add_args(parser) + RobertaModel.add_args(parser) + parser.add_argument( + "--no-final-layer-norm", + action="store_true", + help=( + "don't add final layernorm (only applicable when " + "--encoder-normalize-before=True" + ), + ) @classmethod def build_model(cls, args, task): @@ -165,121 +172,52 @@ def forward(self, features, **kwargs): return x -class ModelParallelRobertaEncoder(FairseqEncoder): - """RoBERTa encoder. - - Implements the :class:`~fairseq.models.FairseqDecoder` interface required - by :class:`~fairseq.models.FairseqLanguageModel`. - """ +class ModelParallelRobertaEncoder(RobertaEncoder): + """RoBERTa encoder.""" def __init__(self, args, dictionary): - super().__init__(dictionary) - self.args = args - - # RoBERTa is a sentence encoder model, so users will intuitively trim - # encoder layers. However, the implementation uses the fairseq decoder, - # so we fix here. - if args.encoder_layers_to_keep: - args.encoder_layers = len(args.encoder_layers_to_keep.split(",")) - args.decoder_layers_to_keep = args.encoder_layers_to_keep - args.encoder_layers_to_keep = None - - self.sentence_encoder = ModelParallelTransformerSentenceEncoder( - padding_idx=dictionary.pad(), - vocab_size=len(dictionary), - num_encoder_layers=args.encoder_layers, - embedding_dim=args.encoder_embed_dim, - ffn_embedding_dim=args.encoder_ffn_embed_dim, - num_attention_heads=args.encoder_attention_heads, - dropout=args.dropout, - attention_dropout=args.attention_dropout, - activation_dropout=args.activation_dropout, - layerdrop=args.encoder_layerdrop, - max_seq_len=args.max_positions, - num_segments=0, - encoder_normalize_before=False, - apply_bert_init=False, - activation_fn=args.activation_fn, - ) - self.lm_head = ModelParallelRobertaLMHead( - embed_dim=args.encoder_embed_dim, - output_dim=len(dictionary), - activation_fn=args.activation_fn, - weight=self.sentence_encoder.embed_tokens.weight, - ) - - def forward( - self, - src_tokens, - features_only=False, - return_all_hiddens=False, - masked_tokens=None, - **unused - ): - """ - Args: - src_tokens (LongTensor): input tokens of shape `(batch, src_len)` - features_only (bool, optional): skip LM head and just return - features. If True, the output will be of shape - `(batch, src_len, embed_dim)`. - return_all_hiddens (bool, optional): also return all of the - intermediate hidden states (default: False). - - Returns: - tuple: - - the LM output of shape `(batch, src_len, vocab)` - - a dictionary of additional data, where 'inner_states' - is a list of hidden states. Note that the hidden - states have shape `(src_len, batch, vocab)`. - """ - x, extra = self.extract_features( - src_tokens, return_all_hiddens=return_all_hiddens - ) - if not features_only: - x = self.output_layer(x, masked_tokens=masked_tokens) - return x, extra + super().__init__(args, dictionary) + assert not self.args.untie_weights_roberta - def extract_features(self, src_tokens, return_all_hiddens=False, **unused): - inner_states, _ = self.sentence_encoder( - src_tokens, - last_state_only=not return_all_hiddens, - ) - features = inner_states[-1].transpose(0, 1) # T x B x C -> B x T x C - return features, {"inner_states": inner_states if return_all_hiddens else None} + def build_embedding(self, vocab_size, embedding_dim, padding_idx): + return VocabParallelEmbedding(vocab_size, embedding_dim, padding_idx) - def output_layer(self, features, masked_tokens=None, **unused): - return self.lm_head(features, masked_tokens) + def build_encoder(self, args, dictionary, embed_tokens): + return ModelParallelTransformerEncoder(args, dictionary, embed_tokens) - def max_positions(self): - """Maximum output length supported by the encoder.""" - return self.args.max_positions + def build_lm_head(self, embed_dim, output_dim, activation_fn, weight): + return ModelParallelRobertaLMHead(embed_dim, output_dim, activation_fn, weight) @register_model_architecture("model_parallel_roberta", "model_parallel_roberta") def base_architecture(args): - args.encoder_layers = getattr(args, "encoder_layers", 12) - args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 768) - args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 3072) - args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 12) + args.no_final_layer_norm = getattr(args, "no_final_layer_norm", False) + # model parallel RoBERTa defaults to "Pre-LN" formulation + roberta_prenorm_architecture(args) - args.activation_fn = getattr(args, "activation_fn", "gelu") - args.pooler_activation_fn = getattr(args, "pooler_activation_fn", "tanh") - args.dropout = getattr(args, "dropout", 0.1) - args.attention_dropout = getattr(args, "attention_dropout", 0.1) - args.activation_dropout = getattr(args, "activation_dropout", 0.0) - args.pooler_dropout = getattr(args, "pooler_dropout", 0.0) - args.encoder_layers_to_keep = getattr(args, "encoder_layers_to_keep", None) - args.encoder_layerdrop = getattr(args, "encoder_layerdrop", 0.0) +# earlier versions of model parallel RoBERTa removed the final layer norm +@register_model_architecture("model_parallel_roberta", "model_parallel_roberta_v1") +def model_parallel_roberta_v1_architecture(args): + args.no_final_layer_norm = getattr(args, "no_final_layer_norm", True) + base_architecture(args) + + +@register_model_architecture( + "model_parallel_roberta", "model_parallel_roberta_postnorm" +) +def model_parallel_roberta_postnorm_architecture(args): + # the original BERT/RoBERTa uses the "Post-LN" formulation + roberta_base_architecture(args) @register_model_architecture("model_parallel_roberta", "model_parallel_roberta_base") -def roberta_base_architecture(args): +def model_parallel_roberta_base_architecture(args): base_architecture(args) @register_model_architecture("model_parallel_roberta", "model_parallel_roberta_large") -def roberta_large_architecture(args): +def model_parallel_roberta_large_architecture(args): args.encoder_layers = getattr(args, "encoder_layers", 24) args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 1024) args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 4096) diff --git a/fairseq/model_parallel/models/transformer.py b/fairseq/model_parallel/models/transformer.py index 4f34645226..6b330ef1b7 100644 --- a/fairseq/model_parallel/models/transformer.py +++ b/fairseq/model_parallel/models/transformer.py @@ -6,7 +6,6 @@ import logging import torch.nn as nn -import torch.nn.functional as F from fairseq.model_parallel.modules import ( ModelParallelTransformerDecoderLayer, ModelParallelTransformerEncoderLayer, @@ -86,6 +85,12 @@ class ModelParallelTransformerEncoder(TransformerEncoder): is a :class:`ModelParallelTransformerEncoderLayer`. """ + def __init__(self, args, dictionary, embed_tokens): + super().__init__(args, dictionary, embed_tokens) + + if args.no_final_layer_norm: + self.layer_norm = None + def build_encoder_layer(self, args): return ModelParallelTransformerEncoderLayer(args) diff --git a/fairseq/model_parallel/modules/__init__.py b/fairseq/model_parallel/modules/__init__.py index fb45b3c9e0..11603217a1 100644 --- a/fairseq/model_parallel/modules/__init__.py +++ b/fairseq/model_parallel/modules/__init__.py @@ -9,15 +9,9 @@ ModelParallelTransformerEncoderLayer, ModelParallelTransformerDecoderLayer, ) -from .transformer_sentence_encoder_layer import ( - ModelParallelTransformerSentenceEncoderLayer, -) -from .transformer_sentence_encoder import ModelParallelTransformerSentenceEncoder __all__ = [ "ModelParallelMultiheadAttention", "ModelParallelTransformerEncoderLayer", "ModelParallelTransformerDecoderLayer", - "ModelParallelTransformerSentenceEncoder", - "ModelParallelTransformerSentenceEncoderLayer", ] diff --git a/fairseq/model_parallel/modules/transformer_sentence_encoder.py b/fairseq/model_parallel/modules/transformer_sentence_encoder.py deleted file mode 100644 index a5d50a33c6..0000000000 --- a/fairseq/model_parallel/modules/transformer_sentence_encoder.py +++ /dev/null @@ -1,59 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. -# -# This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree. - -import random -from typing import Optional, Tuple - -import torch -import torch.nn as nn -import torch.nn.functional as F -from fairseq.model_parallel.modules import ModelParallelTransformerSentenceEncoderLayer -from fairseq.modules import ( - LayerNorm, - MultiheadAttention, - PositionalEmbedding, - TransformerSentenceEncoder, -) - - -try: - from fairseq.model_parallel.megatron.mpu import VocabParallelEmbedding - - has_megatron_submodule = True -except (ImportError, ModuleNotFoundError): - has_megatron_submodule = False - - -class ModelParallelTransformerSentenceEncoder(TransformerSentenceEncoder): - """ - Implementation for a Model Parallel Bi-directional Transformer based - Sentence Encoder used in BERT/XLM style pre-trained models. - """ - - def build_embedding(self, vocab_size, embedding_dim, padding_idx): - return VocabParallelEmbedding(vocab_size, embedding_dim, padding_idx) - - def build_transformer_sentence_encoder_layer( - self, - embedding_dim, - ffn_embedding_dim, - num_attention_heads, - dropout, - attention_dropout, - activation_dropout, - activation_fn, - export, - **unused, - ): - return ModelParallelTransformerSentenceEncoderLayer( - embedding_dim=embedding_dim, - ffn_embedding_dim=ffn_embedding_dim, - num_attention_heads=num_attention_heads, - dropout=dropout, - attention_dropout=attention_dropout, - activation_dropout=activation_dropout, - activation_fn=activation_fn, - export=export, - ) diff --git a/fairseq/model_parallel/modules/transformer_sentence_encoder_layer.py b/fairseq/model_parallel/modules/transformer_sentence_encoder_layer.py deleted file mode 100644 index e10bf52332..0000000000 --- a/fairseq/model_parallel/modules/transformer_sentence_encoder_layer.py +++ /dev/null @@ -1,77 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. -# -# This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree. - -import torch -import torch.nn.functional as F -from fairseq import utils -from fairseq.model_parallel.modules import ModelParallelMultiheadAttention -from fairseq.modules import TransformerSentenceEncoderLayer - - -try: - from fairseq.model_parallel.megatron.mpu import ( - ColumnParallelLinear, - RowParallelLinear, - ) - - has_megatron_submodule = True -except (ImportError, ModuleNotFoundError): - has_megatron_submodule = False - - -class ModelParallelTransformerSentenceEncoderLayer(TransformerSentenceEncoderLayer): - """ - Implements a Model Parallel Transformer Encoder Layer used in - BERT/XLM style pre-trained models. - """ - - def build_fc1(self, input_dim, output_dim, **unused): - return ColumnParallelLinear(input_dim, output_dim, gather_output=False) - - def build_fc2(self, input_dim, output_dim, **unused): - return RowParallelLinear(input_dim, output_dim, input_is_parallel=True) - - def build_self_attention( - self, - embed_dim, - num_attention_heads, - dropout, - **kwargs, - ): - return ModelParallelMultiheadAttention( - embed_dim, num_attention_heads, dropout=dropout, self_attention=True - ) - - def forward( - self, - x: torch.Tensor, - self_attn_mask: torch.Tensor = None, - self_attn_padding_mask: torch.Tensor = None, - ): - """ - LayerNorm is applied either before or after the self-attention/ffn - modules similar to the original Transformer imlementation. - """ - residual = x - x = self.self_attn_layer_norm(x) - x, attn = self.self_attn( - query=x, - key=x, - value=x, - key_padding_mask=self_attn_padding_mask, - need_weights=False, - attn_mask=self_attn_mask, - ) - x = self.dropout_module(x) - x = residual + x - - residual = x - x = self.final_layer_norm(x) - x = self.activation_fn(self.fc1(x)) - x = self.activation_dropout_module(x) - x = self.fc2(x) - x = self.dropout_module(x) - x = residual + x - return x, None diff --git a/fairseq/models/roberta/model.py b/fairseq/models/roberta/model.py index 00a5a5485f..a2a40ba6e2 100644 --- a/fairseq/models/roberta/model.py +++ b/fairseq/models/roberta/model.py @@ -18,7 +18,8 @@ register_model, register_model_architecture, ) -from fairseq.modules import LayerNorm, TransformerSentenceEncoder +from fairseq.models.transformer import TransformerEncoder +from fairseq.modules import LayerNorm from fairseq.modules.quant_noise import quant_noise as apply_quant_noise_ from fairseq.modules.transformer_sentence_encoder import init_bert_params @@ -87,6 +88,11 @@ def add_args(parser): action="store_true", help="apply layernorm before each encoder block", ) + parser.add_argument( + "--layernorm-embedding", + action="store_true", + help="add layernorm to embedding", + ) parser.add_argument( "--dropout", type=float, metavar="D", help="dropout probability" ) @@ -264,6 +270,13 @@ def upgrade_state_dict_named(self, state_dict, name): state_dict[new_k] = state_dict[k] del state_dict[k] + # rename emb_layer_norm -> layernorm_embedding + for k in list(state_dict.keys()): + if ".emb_layer_norm." in k: + new_k = k.replace(".emb_layer_norm.", ".layernorm_embedding.") + state_dict[new_k] = state_dict[k] + del state_dict[k] + # upgrade children modules super().upgrade_state_dict_named(state_dict, name) @@ -401,7 +414,11 @@ def __init__(self, args, dictionary): if args.encoder_layers_to_keep: args.encoder_layers = len(args.encoder_layers_to_keep.split(",")) - self.sentence_encoder = self.build_encoder(args, dictionary) + embed_tokens = self.build_embedding( + len(dictionary), args.encoder_embed_dim, dictionary.pad() + ) + + self.sentence_encoder = self.build_encoder(args, dictionary, embed_tokens) self.lm_head = self.build_lm_head( embed_dim=args.encoder_embed_dim, @@ -414,26 +431,16 @@ def __init__(self, args, dictionary): ), ) - def build_encoder(self, args, dictionary): - return TransformerSentenceEncoder( - padding_idx=dictionary.pad(), - vocab_size=len(dictionary), - num_encoder_layers=args.encoder_layers, - embedding_dim=args.encoder_embed_dim, - ffn_embedding_dim=args.encoder_ffn_embed_dim, - num_attention_heads=args.encoder_attention_heads, - dropout=args.dropout, - attention_dropout=args.attention_dropout, - activation_dropout=args.activation_dropout, - layerdrop=args.encoder_layerdrop, - max_seq_len=args.max_positions, - num_segments=0, - encoder_normalize_before=True, - apply_bert_init=True, - activation_fn=args.activation_fn, - q_noise=args.quant_noise_pq, - qn_block_size=args.quant_noise_pq_block_size, - ) + def build_embedding(self, vocab_size, embedding_dim, padding_idx): + return nn.Embedding(vocab_size, embedding_dim, padding_idx) + + def build_encoder(self, args, dictionary, embed_tokens): + encoder = TransformerEncoder(args, dictionary, embed_tokens) + encoder.apply(init_bert_params) + return encoder + + def build_lm_head(self, embed_dim, output_dim, activation_fn, weight): + return RobertaLMHead(embed_dim, output_dim, activation_fn, weight) def build_lm_head(self, embed_dim, output_dim, activation_fn, weight): return RobertaLMHead(embed_dim, output_dim, activation_fn, weight) @@ -470,13 +477,15 @@ def forward( return x, extra def extract_features(self, src_tokens, return_all_hiddens=False, **kwargs): - inner_states, _ = self.sentence_encoder( + encoder_out = self.sentence_encoder( src_tokens, - last_state_only=not return_all_hiddens, + return_all_hiddens=return_all_hiddens, token_embeddings=kwargs.get("token_embeddings", None), ) - features = inner_states[-1].transpose(0, 1) # T x B x C -> B x T x C - return features, {"inner_states": inner_states if return_all_hiddens else None} + # T x B x C -> B x T x C + features = encoder_out["encoder_out"][0].transpose(0, 1) + inner_states = encoder_out["encoder_states"] if return_all_hiddens else None + return features, {"inner_states": inner_states} def output_layer(self, features, masked_tokens=None, **unused): return self.lm_head(features, masked_tokens) @@ -493,21 +502,50 @@ def base_architecture(args): args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 3072) args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 12) - args.activation_fn = getattr(args, "activation_fn", "gelu") - args.pooler_activation_fn = getattr(args, "pooler_activation_fn", "tanh") - args.dropout = getattr(args, "dropout", 0.1) args.attention_dropout = getattr(args, "attention_dropout", 0.1) args.activation_dropout = getattr(args, "activation_dropout", 0.0) args.pooler_dropout = getattr(args, "pooler_dropout", 0.0) - args.encoder_layers_to_keep = getattr(args, "encoder_layers_to_keep", None) - args.encoder_layerdrop = getattr(args, "encoder_layerdrop", 0.0) + + args.max_source_positions = getattr(args, "max_positions", 512) + args.no_token_positional_embeddings = getattr( + args, "no_token_positional_embeddings", False + ) + + # BERT has a few structural differences compared to the original Transformer + args.encoder_learned_pos = getattr(args, "encoder_learned_pos", True) + args.layernorm_embedding = getattr(args, "layernorm_embedding", True) + args.no_scale_embedding = getattr(args, "no_scale_embedding", True) + args.activation_fn = getattr(args, "activation_fn", "gelu") + args.encoder_normalize_before = getattr(args, "encoder_normalize_before", False) + args.pooler_activation_fn = getattr(args, "pooler_activation_fn", "tanh") args.untie_weights_roberta = getattr(args, "untie_weights_roberta", False) + + # Adaptive input config + args.adaptive_input = getattr(args, "adaptive_input", False) + + # LayerDrop config + args.encoder_layerdrop = getattr(args, "encoder_layerdrop", 0.0) + args.encoder_layers_to_keep = getattr(args, "encoder_layers_to_keep", None) + + # Quantization noise config + args.quant_noise_pq = getattr(args, "quant_noise_pq", 0) + args.quant_noise_pq_block_size = getattr(args, "quant_noise_pq_block_size", 8) + args.quant_noise_scalar = getattr(args, "quant_noise_scalar", 0) + + # R4F config args.spectral_norm_classification_head = getattr( args, "spectral_norm_classification_head", False ) +@register_model_architecture("roberta", "roberta_prenorm") +def roberta_prenorm_architecture(args): + args.layernorm_embedding = getattr(args, "layernorm_embedding", False) + args.encoder_normalize_before = getattr(args, "encoder_normalize_before", True) + base_architecture(args) + + @register_model_architecture("roberta", "roberta_base") def roberta_base_architecture(args): base_architecture(args) diff --git a/fairseq/models/transformer.py b/fairseq/models/transformer.py index 78762ef924..4960fd143d 100644 --- a/fairseq/models/transformer.py +++ b/fairseq/models/transformer.py @@ -325,6 +325,7 @@ class TransformerEncoder(FairseqEncoder): """ def __init__(self, args, dictionary, embed_tokens): + self.args = args super().__init__(dictionary) self.register_buffer("version", torch.Tensor([3])) From 7096ac35870aa24735bd0cc850beefa07784a668 Mon Sep 17 00:00:00 2001 From: Myle Ott Date: Tue, 16 Feb 2021 15:50:46 -0800 Subject: [PATCH 17/82] Make validate.py work with model parallel (#1570) Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1570 Test Plan: Imported from OSS Reviewed By: gwenzek, ngoyal2707 Differential Revision: D25967675 Pulled By: myleott fbshipit-source-id: 7c7f8d25b87ef9b4f0a85331548bb3a2886a1e92 --- fairseq/logging/progress_bar.py | 2 +- fairseq_cli/validate.py | 22 ++++++++++++++++------ 2 files changed, 17 insertions(+), 7 deletions(-) diff --git a/fairseq/logging/progress_bar.py b/fairseq/logging/progress_bar.py index dc061a1821..0ae2bc006d 100644 --- a/fairseq/logging/progress_bar.py +++ b/fairseq/logging/progress_bar.py @@ -123,7 +123,7 @@ def __init__(self, iterable, epoch=None, prefix=None): if epoch is not None: self.prefix += "epoch {:03d}".format(epoch) if prefix is not None: - self.prefix += " | {}".format(prefix) + self.prefix += (" | " if self.prefix != "" else "") + prefix def __len__(self): return len(self.iterable) diff --git a/fairseq_cli/validate.py b/fairseq_cli/validate.py index c69bb94142..90d7e4c6a9 100644 --- a/fairseq_cli/validate.py +++ b/fairseq_cli/validate.py @@ -1,5 +1,4 @@ #!/usr/bin/env python3 -u -# !/usr/bin/env python3 -u # Copyright (c) Facebook, Inc. and its affiliates. # # This source code is licensed under the MIT license found in the @@ -43,6 +42,13 @@ def main(cfg: DictConfig, override_args=None): if use_cuda: torch.cuda.set_device(cfg.distributed_training.device_id) + if cfg.distributed_training.distributed_world_size > 1: + data_parallel_world_size = distributed_utils.get_data_parallel_world_size() + data_parallel_rank = distributed_utils.get_data_parallel_rank() + else: + data_parallel_world_size = 1 + data_parallel_rank = 0 + if override_args is not None: overrides = vars(override_args) overrides.update(eval(getattr(override_args, "model_overrides", "{}"))) @@ -91,8 +97,8 @@ def main(cfg: DictConfig, override_args=None): ignore_invalid_inputs=cfg.dataset.skip_invalid_size_inputs_valid_test, required_batch_size_multiple=cfg.dataset.required_batch_size_multiple, seed=cfg.common.seed, - num_shards=cfg.distributed_training.distributed_world_size, - shard_id=cfg.distributed_training.distributed_rank, + num_shards=data_parallel_world_size, + shard_id=data_parallel_rank, num_workers=cfg.dataset.num_workers, data_buffer_size=cfg.dataset.data_buffer_size, ).next_epoch_itr(shuffle=False) @@ -111,7 +117,7 @@ def main(cfg: DictConfig, override_args=None): progress.log(log_output, step=i) log_outputs.append(log_output) - if cfg.distributed_training.distributed_world_size > 1: + if data_parallel_world_size > 1: log_outputs = distributed_utils.all_gather_list( log_outputs, max_size=cfg.common.all_gather_list_size, @@ -132,9 +138,13 @@ def cli_main(): # only override args that are explicitly given on the command line override_parser = options.get_validation_parser() - override_args = options.parse_args_and_arch(override_parser, suppress_defaults=True) + override_args = options.parse_args_and_arch( + override_parser, suppress_defaults=True + ) - distributed_utils.call_main(convert_namespace_to_omegaconf(args), main, override_args=override_args) + distributed_utils.call_main( + convert_namespace_to_omegaconf(args), main, override_args=override_args + ) if __name__ == "__main__": From e0788f7007a8473a76db573985031f3c94201e79 Mon Sep 17 00:00:00 2001 From: Sam Shleifer Date: Wed, 17 Feb 2021 10:54:25 -0800 Subject: [PATCH 18/82] fix bart generation bug (#1629) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: # Before submitting - [ ] Was this discussed/approved via a Github issue? (no need for typos, doc improvements) - [ ] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)? - [ ] Did you make sure to update the docs? - [ ] Did you write any new necessary tests? ## What does this PR do? Fixes # (issue). ## PR review Anyone in the community is free to review the PR once the tests have passed. If we didn't discuss your PR in Github issues there's a high chance it will not be merged. ## Did you have fun? Make sure you had fun coding � Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1629 Reviewed By: myleott Differential Revision: D26484942 Pulled By: sshleifer fbshipit-source-id: 9dcbab5c404c14d8f35628d823102ad9ce59dffd --- fairseq/models/bart/hub_interface.py | 25 +++++++++++++++---------- 1 file changed, 15 insertions(+), 10 deletions(-) diff --git a/fairseq/models/bart/hub_interface.py b/fairseq/models/bart/hub_interface.py index 1ff170a782..2ddeb763a3 100644 --- a/fairseq/models/bart/hub_interface.py +++ b/fairseq/models/bart/hub_interface.py @@ -92,22 +92,27 @@ def generate( tokenized_sentences: List[torch.LongTensor], *args, inference_step_args=None, + skip_invalid_size_inputs=False, **kwargs ) -> List[List[Dict[str, torch.Tensor]]]: inference_step_args = inference_step_args or {} if "prefix_tokens" in inference_step_args: raise NotImplementedError("prefix generation not implemented for BART") - else: - bsz = len(tokenized_sentences) - inference_step_args["prefix_tokens"] = tokenized_sentences[0].new_full( - (bsz, 1), fill_value=self.task.source_dictionary.bos() + res = [] + for batch in self._build_batches(tokenized_sentences, skip_invalid_size_inputs): + src_tokens = batch['net_input']['src_tokens'] + inference_step_args["prefix_tokens"] =src_tokens.new_full( + (src_tokens.size(0), 1), fill_value=self.task.source_dictionary.bos() ).to(device=self.device) - return super().generate( - tokenized_sentences, - *args, - inference_step_args=inference_step_args, - **kwargs - ) + results = super().generate( + src_tokens, + *args, + inference_step_args=inference_step_args, + skip_invalid_size_inputs=skip_invalid_size_inputs, + **kwargs + ) + res.extend(results) + return res def extract_features( self, tokens: torch.LongTensor, return_all_hiddens: bool = False From 7040ce71f3e0e84730adc267df764f48dc483dac Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Onur=20=C3=87elebi?= Date: Thu, 18 Feb 2021 03:09:14 -0800 Subject: [PATCH 19/82] LASER training code (#1207) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: Integrating LASER (Language-Agnostic SEntence Representations) training code - [ ] Was this discussed/approved via a Github issue? (no need for typos, doc improvements) - [ Y] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)? - [ N/A] Did you make sure to update the docs? - [ Y] Did you write any new necessary tests? => an additional test in `test_iterators.py` ## What does this PR do? This diff introduces the training code for LASER. It includes a specific `laser` task in `laser_task.py` which reads a json configuration file describing the binarized datasets of language pairs. `multitask_data_utils.py` defines dataset wrappers and iterators used by `laser` task. ## PR review Anyone in the community is free to review the PR once the tests have passed. If we didn't discuss your PR in Github issues there's a high chance it will not be merged. ## Did you have fun? Yes. � Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1207 Reviewed By: myleott Differential Revision: D26454296 Pulled By: Celebio fbshipit-source-id: c987672aa66abf31b039ee11867b06912d3486e5 --- examples/laser/README.md | 144 +++++ examples/laser/laser_src/__init__.py | 8 + examples/laser/laser_src/laser_lstm.py | 585 ++++++++++++++++++ examples/laser/laser_src/laser_task.py | 326 ++++++++++ examples/laser/laser_src/laser_transformer.py | 354 +++++++++++ .../laser/laser_src/multitask_data_utils.py | 143 +++++ tests/test_binaries.py | 60 ++ tests/utils.py | 38 ++ 8 files changed, 1658 insertions(+) create mode 100644 examples/laser/README.md create mode 100644 examples/laser/laser_src/__init__.py create mode 100644 examples/laser/laser_src/laser_lstm.py create mode 100644 examples/laser/laser_src/laser_task.py create mode 100644 examples/laser/laser_src/laser_transformer.py create mode 100644 examples/laser/laser_src/multitask_data_utils.py diff --git a/examples/laser/README.md b/examples/laser/README.md new file mode 100644 index 0000000000..66acada04f --- /dev/null +++ b/examples/laser/README.md @@ -0,0 +1,144 @@ +# LASER Language-Agnostic SEntence Representations + +LASER is a library to calculate and use multilingual sentence embeddings. + +You can find more information about LASER and how to use it on the official [LASER repository](https://github.com/facebookresearch/LASER). + +This folder contains source code for training LASER embeddings. + + +## Prepare data and configuration file + +Binarize your data with fairseq, as described [here](https://fairseq.readthedocs.io/en/latest/getting_started.html#data-pre-processing). + +Create a json config file with this format: +``` +{ + "src_vocab": "/path/to/spm.src.cvocab", + "tgt_vocab": "/path/to/spm.tgt.cvocab", + "train": [ + { + "type": "translation", + "id": 0, + "src": "/path/to/srclang1-tgtlang0/train.srclang1", + "tgt": "/path/to/srclang1-tgtlang0/train.tgtlang0" + }, + { + "type": "translation", + "id": 1, + "src": "/path/to/srclang1-tgtlang1/train.srclang1", + "tgt": "/path/to/srclang1-tgtlang1/train.tgtlang1" + }, + { + "type": "translation", + "id": 0, + "src": "/path/to/srclang2-tgtlang0/train.srclang2", + "tgt": "/path/to/srclang2-tgtlang0/train.tgtlang0" + }, + { + "type": "translation", + "id": 1, + "src": "/path/to/srclang2-tgtlang1/train.srclang2", + "tgt": "/path/to/srclang2-tgtlang1/train.tgtlang1" + }, + ... + ], + "valid": [ + { + "type": "translation", + "id": 0, + "src": "/unused", + "tgt": "/unused" + } + ] +} +``` +where paths are paths to binarized indexed fairseq dataset files. +`id` represents the target language id. + + +## Training Command Line Example + +``` +fairseq-train \ + /path/to/configfile_described_above.json \ + --user-dir examples/laser/laser_src \ + --log-interval 100 --log-format simple \ + --task laser --arch laser_lstm \ + --save-dir . \ + --optimizer adam \ + --lr 0.001 \ + --lr-scheduler inverse_sqrt \ + --clip-norm 5 \ + --warmup-updates 90000 \ + --update-freq 2 \ + --dropout 0.0 \ + --encoder-dropout-out 0.1 \ + --max-tokens 2000 \ + --max-epoch 50 \ + --encoder-bidirectional \ + --encoder-layers 5 \ + --encoder-hidden-size 512 \ + --decoder-layers 1 \ + --decoder-hidden-size 2048 \ + --encoder-embed-dim 320 \ + --decoder-embed-dim 320 \ + --decoder-lang-embed-dim 32 \ + --warmup-init-lr 0.001 \ + --disable-validation +``` + + +## Applications + +We showcase several applications of multilingual sentence embeddings +with code to reproduce our results (in the directory "tasks"). + +* [**Cross-lingual document classification**](https://github.com/facebookresearch/LASER/tree/master/tasks/mldoc) using the + [*MLDoc*](https://github.com/facebookresearch/MLDoc) corpus [2,6] +* [**WikiMatrix**](https://github.com/facebookresearch/LASER/tree/master/tasks/WikiMatrix) + Mining 135M Parallel Sentences in 1620 Language Pairs from Wikipedia [7] +* [**Bitext mining**](https://github.com/facebookresearch/LASER/tree/master/tasks/bucc) using the + [*BUCC*](https://comparable.limsi.fr/bucc2018/bucc2018-task.html) corpus [3,5] +* [**Cross-lingual NLI**](https://github.com/facebookresearch/LASER/tree/master/tasks/xnli) + using the [*XNLI*](https://www.nyu.edu/projects/bowman/xnli/) corpus [4,5,6] +* [**Multilingual similarity search**](https://github.com/facebookresearch/LASER/tree/master/tasks/similarity) [1,6] +* [**Sentence embedding of text files**](https://github.com/facebookresearch/LASER/tree/master/tasks/embed) + example how to calculate sentence embeddings for arbitrary text files in any of the supported language. + +**For all tasks, we use exactly the same multilingual encoder, without any task specific optimization or fine-tuning.** + + + +## References + +[1] Holger Schwenk and Matthijs Douze, + [*Learning Joint Multilingual Sentence Representations with Neural Machine Translation*](https://aclanthology.info/papers/W17-2619/w17-2619), + ACL workshop on Representation Learning for NLP, 2017 + +[2] Holger Schwenk and Xian Li, + [*A Corpus for Multilingual Document Classification in Eight Languages*](http://www.lrec-conf.org/proceedings/lrec2018/pdf/658.pdf), + LREC, pages 3548-3551, 2018. + +[3] Holger Schwenk, + [*Filtering and Mining Parallel Data in a Joint Multilingual Space*](http://aclweb.org/anthology/P18-2037) + ACL, July 2018 + +[4] Alexis Conneau, Guillaume Lample, Ruty Rinott, Adina Williams, Samuel R. Bowman, Holger Schwenk and Veselin Stoyanov, + [*XNLI: Cross-lingual Sentence Understanding through Inference*](https://aclweb.org/anthology/D18-1269), + EMNLP, 2018. + +[5] Mikel Artetxe and Holger Schwenk, + [*Margin-based Parallel Corpus Mining with Multilingual Sentence Embeddings*](https://arxiv.org/abs/1811.01136) + arXiv, Nov 3 2018. + +[6] Mikel Artetxe and Holger Schwenk, + [*Massively Multilingual Sentence Embeddings for Zero-Shot Cross-Lingual Transfer and Beyond*](https://arxiv.org/abs/1812.10464) + arXiv, Dec 26 2018. + +[7] Holger Schwenk, Vishrav Chaudhary, Shuo Sun, Hongyu Gong and Paco Guzman, + [*WikiMatrix: Mining 135M Parallel Sentences in 1620 Language Pairs from Wikipedia*](https://arxiv.org/abs/1907.05791) + arXiv, July 11 2019. + +[8] Holger Schwenk, Guillaume Wenzek, Sergey Edunov, Edouard Grave and Armand Joulin + [*CCMatrix: Mining Billions of High-Quality Parallel Sentences on the WEB*](https://arxiv.org/abs/1911.04944) diff --git a/examples/laser/laser_src/__init__.py b/examples/laser/laser_src/__init__.py new file mode 100644 index 0000000000..9ffbd656d8 --- /dev/null +++ b/examples/laser/laser_src/__init__.py @@ -0,0 +1,8 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from .laser_task import * # noqa +from .laser_lstm import * # noqa +from .laser_transformer import * # noqa diff --git a/examples/laser/laser_src/laser_lstm.py b/examples/laser/laser_src/laser_lstm.py new file mode 100644 index 0000000000..10df90e002 --- /dev/null +++ b/examples/laser/laser_src/laser_lstm.py @@ -0,0 +1,585 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from fairseq import options, utils + +from fairseq.models import ( + FairseqEncoder, + FairseqIncrementalDecoder, + FairseqEncoderDecoderModel, + register_model, + register_model_architecture, +) + + +@register_model("laser_lstm") +class LSTMModel(FairseqEncoderDecoderModel): + def __init__(self, encoder, decoder): + super().__init__(encoder, decoder) + + def forward( + self, + src_tokens, + src_lengths, + prev_output_tokens=None, + tgt_tokens=None, + tgt_lengths=None, + target_language_id=None, + dataset_name="", + ): + assert target_language_id is not None + + src_encoder_out = self.encoder(src_tokens, src_lengths, dataset_name) + return self.decoder( + prev_output_tokens, src_encoder_out, lang_id=target_language_id + ) + + @staticmethod + def add_args(parser): + """Add model-specific arguments to the parser.""" + parser.add_argument( + "--dropout", + default=0.1, + type=float, + metavar="D", + help="dropout probability", + ) + parser.add_argument( + "--encoder-embed-dim", + type=int, + metavar="N", + help="encoder embedding dimension", + ) + parser.add_argument( + "--encoder-embed-path", + default=None, + type=str, + metavar="STR", + help="path to pre-trained encoder embedding", + ) + parser.add_argument( + "--encoder-hidden-size", type=int, metavar="N", help="encoder hidden size" + ) + parser.add_argument( + "--encoder-layers", type=int, metavar="N", help="number of encoder layers" + ) + parser.add_argument( + "--encoder-bidirectional", + action="store_true", + help="make all layers of encoder bidirectional", + ) + parser.add_argument( + "--decoder-embed-dim", + type=int, + metavar="N", + help="decoder embedding dimension", + ) + parser.add_argument( + "--decoder-embed-path", + default=None, + type=str, + metavar="STR", + help="path to pre-trained decoder embedding", + ) + parser.add_argument( + "--decoder-hidden-size", type=int, metavar="N", help="decoder hidden size" + ) + parser.add_argument( + "--decoder-layers", type=int, metavar="N", help="number of decoder layers" + ) + parser.add_argument( + "--decoder-out-embed-dim", + type=int, + metavar="N", + help="decoder output embedding dimension", + ) + parser.add_argument( + "--decoder-zero-init", + type=str, + metavar="BOOL", + help="initialize the decoder hidden/cell state to zero", + ) + parser.add_argument( + "--decoder-lang-embed-dim", + type=int, + metavar="N", + help="decoder language embedding dimension", + ) + parser.add_argument( + "--fixed-embeddings", + action="store_true", + help="keep embeddings fixed (ENCODER ONLY)", + ) # TODO Also apply to decoder embeddings? + + # Granular dropout settings (if not specified these default to --dropout) + parser.add_argument( + "--encoder-dropout-in", + type=float, + metavar="D", + help="dropout probability for encoder input embedding", + ) + parser.add_argument( + "--encoder-dropout-out", + type=float, + metavar="D", + help="dropout probability for encoder output", + ) + parser.add_argument( + "--decoder-dropout-in", + type=float, + metavar="D", + help="dropout probability for decoder input embedding", + ) + parser.add_argument( + "--decoder-dropout-out", + type=float, + metavar="D", + help="dropout probability for decoder output", + ) + + @classmethod + def build_model(cls, args, task): + """Build a new model instance.""" + # make sure that all args are properly defaulted (in case there are any new ones) + base_architecture(args) + + def load_pretrained_embedding_from_file(embed_path, dictionary, embed_dim): + num_embeddings = len(dictionary) + padding_idx = dictionary.pad() + embed_tokens = Embedding(num_embeddings, embed_dim, padding_idx) + embed_dict = utils.parse_embedding(embed_path) + utils.print_embed_overlap(embed_dict, dictionary) + return utils.load_embedding(embed_dict, dictionary, embed_tokens) + + pretrained_encoder_embed = None + if args.encoder_embed_path: + pretrained_encoder_embed = load_pretrained_embedding_from_file( + args.encoder_embed_path, task.source_dictionary, args.encoder_embed_dim + ) + pretrained_decoder_embed = None + if args.decoder_embed_path: + pretrained_decoder_embed = load_pretrained_embedding_from_file( + args.decoder_embed_path, task.target_dictionary, args.decoder_embed_dim + ) + + num_langs = task.num_tasks if hasattr(task, "num_tasks") else 0 + + encoder = LSTMEncoder( + dictionary=task.source_dictionary, + embed_dim=args.encoder_embed_dim, + hidden_size=args.encoder_hidden_size, + num_layers=args.encoder_layers, + dropout_in=args.encoder_dropout_in, + dropout_out=args.encoder_dropout_out, + bidirectional=args.encoder_bidirectional, + pretrained_embed=pretrained_encoder_embed, + fixed_embeddings=args.fixed_embeddings, + ) + decoder = LSTMDecoder( + dictionary=task.target_dictionary, + embed_dim=args.decoder_embed_dim, + hidden_size=args.decoder_hidden_size, + out_embed_dim=args.decoder_out_embed_dim, + num_layers=args.decoder_layers, + dropout_in=args.decoder_dropout_in, + dropout_out=args.decoder_dropout_out, + zero_init=options.eval_bool(args.decoder_zero_init), + encoder_embed_dim=args.encoder_embed_dim, + encoder_output_units=encoder.output_units, + pretrained_embed=pretrained_decoder_embed, + num_langs=num_langs, + lang_embed_dim=args.decoder_lang_embed_dim, + ) + return cls(encoder, decoder) + + +class LSTMEncoder(FairseqEncoder): + """LSTM encoder.""" + + def __init__( + self, + dictionary, + embed_dim=512, + hidden_size=512, + num_layers=1, + dropout_in=0.1, + dropout_out=0.1, + bidirectional=False, + left_pad=True, + pretrained_embed=None, + padding_value=0.0, + fixed_embeddings=False, + ): + super().__init__(dictionary) + self.num_layers = num_layers + self.dropout_in = dropout_in + self.dropout_out = dropout_out + self.bidirectional = bidirectional + self.hidden_size = hidden_size + + num_embeddings = len(dictionary) + self.padding_idx = dictionary.pad() + if pretrained_embed is None: + self.embed_tokens = Embedding(num_embeddings, embed_dim, self.padding_idx) + else: + self.embed_tokens = pretrained_embed + if fixed_embeddings: + self.embed_tokens.weight.requires_grad = False + + self.lstm = LSTM( + input_size=embed_dim, + hidden_size=hidden_size, + num_layers=num_layers, + dropout=self.dropout_out if num_layers > 1 else 0.0, + bidirectional=bidirectional, + ) + self.left_pad = left_pad + self.padding_value = padding_value + + self.output_units = hidden_size + if bidirectional: + self.output_units *= 2 + + def forward(self, src_tokens, src_lengths, dataset_name): + if self.left_pad: + # convert left-padding to right-padding + src_tokens = utils.convert_padding_direction( + src_tokens, + self.padding_idx, + left_to_right=True, + ) + + bsz, seqlen = src_tokens.size() + + # embed tokens + x = self.embed_tokens(src_tokens) + x = F.dropout(x, p=self.dropout_in, training=self.training) + + # B x T x C -> T x B x C + x = x.transpose(0, 1) + + # pack embedded source tokens into a PackedSequence + try: + packed_x = nn.utils.rnn.pack_padded_sequence(x, src_lengths.data.tolist()) + except BaseException: + raise Exception(f"Packing failed in dataset {dataset_name}") + + # apply LSTM + if self.bidirectional: + state_size = 2 * self.num_layers, bsz, self.hidden_size + else: + state_size = self.num_layers, bsz, self.hidden_size + h0 = x.data.new(*state_size).zero_() + c0 = x.data.new(*state_size).zero_() + packed_outs, (final_hiddens, final_cells) = self.lstm(packed_x, (h0, c0)) + + # unpack outputs and apply dropout + x, _ = nn.utils.rnn.pad_packed_sequence( + packed_outs, padding_value=self.padding_value + ) + x = F.dropout(x, p=self.dropout_out, training=self.training) + assert list(x.size()) == [seqlen, bsz, self.output_units] + + if self.bidirectional: + + def combine_bidir(outs): + return torch.cat( + [ + torch.cat([outs[2 * i], outs[2 * i + 1]], dim=0).view( + 1, bsz, self.output_units + ) + for i in range(self.num_layers) + ], + dim=0, + ) + + final_hiddens = combine_bidir(final_hiddens) + final_cells = combine_bidir(final_cells) + + encoder_padding_mask = src_tokens.eq(self.padding_idx).t() + + # Set padded outputs to -inf so they are not selected by max-pooling + padding_mask = src_tokens.eq(self.padding_idx).t().unsqueeze(-1) + if padding_mask.any(): + x = x.float().masked_fill_(padding_mask, float("-inf")).type_as(x) + + # Build the sentence embedding by max-pooling over the encoder outputs + sentemb = x.max(dim=0)[0] + + return { + "sentemb": sentemb, + "encoder_out": (x, final_hiddens, final_cells), + "encoder_padding_mask": encoder_padding_mask + if encoder_padding_mask.any() + else None, + } + + def reorder_encoder_out(self, encoder_out_dict, new_order): + encoder_out_dict["sentemb"] = encoder_out_dict["sentemb"].index_select( + 0, new_order + ) + encoder_out_dict["encoder_out"] = tuple( + eo.index_select(1, new_order) for eo in encoder_out_dict["encoder_out"] + ) + if encoder_out_dict["encoder_padding_mask"] is not None: + encoder_out_dict["encoder_padding_mask"] = encoder_out_dict[ + "encoder_padding_mask" + ].index_select(1, new_order) + return encoder_out_dict + + def max_positions(self): + """Maximum input length supported by the encoder.""" + return int(1e5) # an arbitrary large number + + +class LSTMDecoder(FairseqIncrementalDecoder): + """LSTM decoder.""" + + def __init__( + self, + dictionary, + embed_dim=512, + hidden_size=512, + out_embed_dim=512, + num_layers=1, + dropout_in=0.1, + dropout_out=0.1, + zero_init=False, + encoder_embed_dim=512, + encoder_output_units=512, + pretrained_embed=None, + num_langs=1, + lang_embed_dim=0, + ): + super().__init__(dictionary) + self.dropout_in = dropout_in + self.dropout_out = dropout_out + self.hidden_size = hidden_size + + num_embeddings = len(dictionary) + padding_idx = dictionary.pad() + if pretrained_embed is None: + self.embed_tokens = Embedding(num_embeddings, embed_dim, padding_idx) + else: + self.embed_tokens = pretrained_embed + + self.layers = nn.ModuleList( + [ + LSTMCell( + input_size=encoder_output_units + embed_dim + lang_embed_dim + if layer == 0 + else hidden_size, + hidden_size=hidden_size, + ) + for layer in range(num_layers) + ] + ) + if hidden_size != out_embed_dim: + self.additional_fc = Linear(hidden_size, out_embed_dim) + self.fc_out = Linear(out_embed_dim, num_embeddings, dropout=dropout_out) + + if zero_init: + self.sentemb2init = None + else: + self.sentemb2init = Linear( + encoder_output_units, 2 * num_layers * hidden_size + ) + + if lang_embed_dim == 0: + self.embed_lang = None + else: + self.embed_lang = nn.Embedding(num_langs, lang_embed_dim) + nn.init.uniform_(self.embed_lang.weight, -0.1, 0.1) + + def forward( + self, prev_output_tokens, encoder_out_dict, incremental_state=None, lang_id=0 + ): + sentemb = encoder_out_dict["sentemb"] + encoder_out = encoder_out_dict["encoder_out"] + + if incremental_state is not None: + prev_output_tokens = prev_output_tokens[:, -1:] + bsz, seqlen = prev_output_tokens.size() + + # get outputs from encoder + encoder_outs, _, _ = encoder_out[:3] + srclen = encoder_outs.size(0) + + # embed tokens + x = self.embed_tokens(prev_output_tokens) + x = F.dropout(x, p=self.dropout_in, training=self.training) + + # embed language identifier + if self.embed_lang is not None: + lang_ids = prev_output_tokens.data.new_full((bsz,), lang_id) + langemb = self.embed_lang(lang_ids) + # TODO Should we dropout here??? + + # B x T x C -> T x B x C + x = x.transpose(0, 1) + + # initialize previous states (or get from cache during incremental generation) + cached_state = utils.get_incremental_state( + self, incremental_state, "cached_state" + ) + if cached_state is not None: + prev_hiddens, prev_cells, input_feed = cached_state + else: + num_layers = len(self.layers) + if self.sentemb2init is None: + prev_hiddens = [ + x.data.new(bsz, self.hidden_size).zero_() for i in range(num_layers) + ] + prev_cells = [ + x.data.new(bsz, self.hidden_size).zero_() for i in range(num_layers) + ] + else: + init = self.sentemb2init(sentemb) + prev_hiddens = [ + init[:, (2 * i) * self.hidden_size : (2 * i + 1) * self.hidden_size] + for i in range(num_layers) + ] + prev_cells = [ + init[ + :, + (2 * i + 1) * self.hidden_size : (2 * i + 2) * self.hidden_size, + ] + for i in range(num_layers) + ] + input_feed = x.data.new(bsz, self.hidden_size).zero_() + + attn_scores = x.data.new(srclen, seqlen, bsz).zero_() + outs = [] + for j in range(seqlen): + if self.embed_lang is None: + input = torch.cat((x[j, :, :], sentemb), dim=1) + else: + input = torch.cat((x[j, :, :], sentemb, langemb), dim=1) + + for i, rnn in enumerate(self.layers): + # recurrent cell + hidden, cell = rnn(input, (prev_hiddens[i], prev_cells[i])) + + # hidden state becomes the input to the next layer + input = F.dropout(hidden, p=self.dropout_out, training=self.training) + + # save state for next time step + prev_hiddens[i] = hidden + prev_cells[i] = cell + + out = hidden + out = F.dropout(out, p=self.dropout_out, training=self.training) + + # input feeding + input_feed = out + + # save final output + outs.append(out) + + # cache previous states (no-op except during incremental generation) + utils.set_incremental_state( + self, + incremental_state, + "cached_state", + (prev_hiddens, prev_cells, input_feed), + ) + + # collect outputs across time steps + x = torch.cat(outs, dim=0).view(seqlen, bsz, self.hidden_size) + + # T x B x C -> B x T x C + x = x.transpose(1, 0) + + # srclen x tgtlen x bsz -> bsz x tgtlen x srclen + attn_scores = attn_scores.transpose(0, 2) + + # project back to size of vocabulary + if hasattr(self, "additional_fc"): + x = self.additional_fc(x) + x = F.dropout(x, p=self.dropout_out, training=self.training) + x = self.fc_out(x) + + return x, attn_scores + + def reorder_incremental_state(self, incremental_state, new_order): + super().reorder_incremental_state(incremental_state, new_order) + cached_state = utils.get_incremental_state( + self, incremental_state, "cached_state" + ) + if cached_state is None: + return + + def reorder_state(state): + if isinstance(state, list): + return [reorder_state(state_i) for state_i in state] + return state.index_select(0, new_order) + + new_state = tuple(map(reorder_state, cached_state)) + utils.set_incremental_state(self, incremental_state, "cached_state", new_state) + + def max_positions(self): + """Maximum output length supported by the decoder.""" + return int(1e5) # an arbitrary large number + + +def Embedding(num_embeddings, embedding_dim, padding_idx): + m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx) + nn.init.uniform_(m.weight, -0.1, 0.1) + nn.init.constant_(m.weight[padding_idx], 0) + return m + + +def LSTM(input_size, hidden_size, **kwargs): + m = nn.LSTM(input_size, hidden_size, **kwargs) + for name, param in m.named_parameters(): + if "weight" in name or "bias" in name: + param.data.uniform_(-0.1, 0.1) + return m + + +def LSTMCell(input_size, hidden_size, **kwargs): + m = nn.LSTMCell(input_size, hidden_size, **kwargs) + for name, param in m.named_parameters(): + if "weight" in name or "bias" in name: + param.data.uniform_(-0.1, 0.1) + return m + + +def Linear(in_features, out_features, bias=True, dropout=0): + """Weight-normalized Linear layer (input: N x T x C)""" + m = nn.Linear(in_features, out_features, bias=bias) + m.weight.data.uniform_(-0.1, 0.1) + if bias: + m.bias.data.uniform_(-0.1, 0.1) + return m + + +@register_model_architecture("laser_lstm", "laser_lstm") +def base_architecture(args): + args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512) + args.encoder_embed_path = getattr(args, "encoder_embed_path", None) + args.encoder_hidden_size = getattr( + args, "encoder_hidden_size", args.encoder_embed_dim + ) + args.encoder_layers = getattr(args, "encoder_layers", 1) + args.encoder_bidirectional = getattr(args, "encoder_bidirectional", False) + args.encoder_dropout_in = getattr(args, "encoder_dropout_in", args.dropout) + args.encoder_dropout_out = getattr(args, "encoder_dropout_out", args.dropout) + args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 512) + args.decoder_embed_path = getattr(args, "decoder_embed_path", None) + args.decoder_hidden_size = getattr( + args, "decoder_hidden_size", args.decoder_embed_dim + ) + args.decoder_layers = getattr(args, "decoder_layers", 1) + args.decoder_out_embed_dim = getattr(args, "decoder_out_embed_dim", 512) + args.decoder_dropout_in = getattr(args, "decoder_dropout_in", args.dropout) + args.decoder_dropout_out = getattr(args, "decoder_dropout_out", args.dropout) + args.decoder_zero_init = getattr(args, "decoder_zero_init", "0") + args.decoder_lang_embed_dim = getattr(args, "decoder_lang_embed_dim", 0) + args.fixed_embeddings = getattr(args, "fixed_embeddings", False) diff --git a/examples/laser/laser_src/laser_task.py b/examples/laser/laser_src/laser_task.py new file mode 100644 index 0000000000..c8ac805f54 --- /dev/null +++ b/examples/laser/laser_src/laser_task.py @@ -0,0 +1,326 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + + +from collections import OrderedDict, defaultdict +import json +import os +import logging + +from fairseq import options, models +from fairseq.data import ( + data_utils, + Dictionary, + LanguagePairDataset, + IndexedDataset, + FairseqDataset, +) +from .multitask_data_utils import ( + MultitaskDatasetWrapper, + MultidatasetEpochBatchIterator, +) + + +from fairseq.tasks import LegacyFairseqTask, register_task + +logger = logging.getLogger(__name__) + + +@register_task("laser") +class LaserTask(LegacyFairseqTask): + @staticmethod + def add_args(parser): + """Add task-specific arguments to the parser.""" + parser.add_argument( + "configfile", metavar="PATH", help="dataset configuration file in json" + ) + parser.add_argument( + "--weighting-alpha", + type=float, + default=None, + help="alpha for automatic weighting", + ) + parser.add_argument( + "--raw-text", action="store_true", help="load raw text dataset" + ) + parser.add_argument( + "--left-pad-source", + default="True", + type=str, + metavar="BOOL", + help="pad the source on the left (default: True)", + ) + parser.add_argument( + "--left-pad-target", + default="False", + type=str, + metavar="BOOL", + help="pad the target on the left (default: False)", + ) + parser.add_argument( + "--max-source-positions", + default=1024, + type=int, + metavar="N", + help="max number of tokens in the source sequence", + ) + parser.add_argument( + "--max-target-positions", + default=1024, + type=int, + metavar="N", + help="max number of tokens in the target sequence", + ) + + def __init__(self, args, config, src_dictionary, tgt_dictionary, num_tasks): + super().__init__(args) + self.config = config + self.src_dictionary = src_dictionary + self.tgt_dictionary = tgt_dictionary + self.num_tasks = num_tasks + + @classmethod + def setup_task(cls, args, **kwargs): + with open(args.configfile, "r") as f: + config = json.load(f) + num_tasks = max(dataset["id"] for dataset in config["train"]) + 1 + + args.left_pad_source = options.eval_bool(args.left_pad_source) + args.left_pad_target = options.eval_bool(args.left_pad_target) + + src_dictionary = Dictionary.load(config["src_vocab"]) + tgt_dictionary = Dictionary.load(config["tgt_vocab"]) + + logger.info( + "| src Dictionary {} : {} types".format( + config["src_vocab"], len(src_dictionary) + ) + ) + logger.info( + "| tgt Dictionary {} : {} types".format( + config["tgt_vocab"], len(tgt_dictionary) + ) + ) + + return cls(args, config, src_dictionary, tgt_dictionary, num_tasks) + + # Experimental overriding for backtranslation + def build_model(self, args): + model = models.build_model(args, self) + return model + + def dataset(self, split): + if split not in self.datasets: + raise KeyError("Dataset not loaded: " + split) + return self.datasets[split] + + def load_dataset(self, split, epoch=1, **kwargs): + """Load a dataset split.""" + + def indexed_dataset(path, dictionary): + if self.args.raw_text: + raise Exception("Unable to handle raw text.") + dataset = IndexedDataset(path, fix_lua_indexing=True) + + return dataset + + pair_datasets = OrderedDict() + + if split == "valid": + self.datasets[split] = pair_datasets + return + + if split not in self.config: + raise FileNotFoundError( + "Dataset not found in config file: {}".format(split) + ) + + size_by_corpus = defaultdict(int) + size_sum = 0 + size_sum_with_subsampling = 0 + init_pair_datasets = {} + + for dataset_config in self.config[split]: + src_path = os.path.dirname(dataset_config["src"]) + corpus_name = src_path.split("/")[-2] + language_pair_name = src_path.split("/")[-1] + pair_datasets_key = corpus_name + "-" + language_pair_name + + logger.info(f"loading... {pair_datasets_key}") + if "src" in dataset_config: + src_dataset = indexed_dataset( + dataset_config["src"], self.src_dictionary + ) + else: + src_dataset = None + + if "tgt" in dataset_config: + tgt_dataset = indexed_dataset( + dataset_config["tgt"], self.tgt_dictionary + ) + else: + tgt_dataset = None + + dataset = LanguagePairDataset( + src_dataset, + src_dataset.sizes, + self.src_dictionary, + tgt_dataset, + tgt_dataset.sizes, + self.tgt_dictionary, + left_pad_source=self.args.left_pad_source, + left_pad_target=self.args.left_pad_target, + ) + + if pair_datasets_key in init_pair_datasets: + logger.warning( + f"Ignoring already added {pair_datasets_key}. " + f"Consider using `sample` key in order to upsample." + ) + else: + init_pair_datasets[pair_datasets_key] = { + "dataset": dataset, + "sample": dataset_config.get("sample", None), + "id": dataset_config.get("id", None), + "len": len(dataset), + } + + length_sum = 0 + weighted_freqs_sum = 0 + freq_per_dataset = {} + vmax = 0 + vmin = 1 + weighted_freq_per_dataset = {} + + if self.args.weighting_alpha: + for key in init_pair_datasets: + if init_pair_datasets[key]["sample"] is None: + length_sum += len(init_pair_datasets[key]["dataset"]) + + for key in init_pair_datasets: + if init_pair_datasets[key]["sample"] is None: + val = float(init_pair_datasets[key]["len"]) / length_sum + freq_per_dataset[key] = val + weighted_freqs_sum += val ** self.args.weighting_alpha + + for key in freq_per_dataset: + val = ( + freq_per_dataset[key] ** self.args.weighting_alpha + / weighted_freqs_sum + ) + vmin = min(vmin, val) + vmax = max(vmax, val) + weighted_freq_per_dataset[key] = val + + for pair_datasets_key in init_pair_datasets: + dataset_config = init_pair_datasets[pair_datasets_key] + dataset = dataset_config["dataset"] + sample = dataset_config["sample"] + if sample is None: + sample = 1.0 + + if pair_datasets_key in weighted_freq_per_dataset: + w = vmax / weighted_freq_per_dataset[pair_datasets_key] + sample = w + + sample = round(sample) + + initial_sample = sample + initial_pair_datasets_key = pair_datasets_key + + while sample >= 1.0: + assert ( + pair_datasets_key not in pair_datasets + ), f"{pair_datasets_key} already in" + size_sum_with_subsampling += len(dataset) + pair_datasets[pair_datasets_key] = MultitaskDatasetWrapper( + dataset, dataset_config.get("id", 0), 1.0, name=pair_datasets_key + ) + size_sum += len(dataset) + sample -= 1.0 + pair_datasets_key += "-up" + + assert sample < 1e-6, f"sample remains > 0 {pair_datasets_key}" + + logger.info( + f"added pair {initial_pair_datasets_key} length {len(dataset)} new_length = {len(dataset)*initial_sample}" + ) + size_by_corpus[corpus_name] += len(dataset) + + self.datasets[split] = pair_datasets + logger.info( + f"Datasets number = {len(self.datasets[split])} size = {size_sum} size_sum_with_subsampling = {size_sum_with_subsampling}" + ) + + @property + def source_dictionary(self): + return self.src_dictionary + + @property + def target_dictionary(self): + return self.tgt_dictionary + + def get_batch_iterator( + self, + dataset, + max_tokens=None, + max_sentences=None, + max_positions=None, + ignore_invalid_inputs=False, + required_batch_size_multiple=1, + seed=1, + num_shards=1, + shard_id=0, + num_workers=0, + epoch=1, + data_buffer_size=0, + disable_iterator_cache=False, + ): + + assert isinstance(dataset, OrderedDict) + assert len(dataset) + assert isinstance(dataset[next(iter(dataset))], FairseqDataset) + + # initialize the dataset with the correct starting epoch + for _, dt in dataset.items(): + dt.set_epoch(epoch) + + indices = OrderedDict() + batch_sampler = OrderedDict() + + with data_utils.numpy_seed(seed + epoch): + for key, dt in dataset.items(): + logger.info(f"\t ordered_indices {key}") + indices[key] = dt.ordered_indices() + + # filter examples that are too large + if max_positions is not None: + for key, dt in dataset.items(): + logger.info(f"\t filter_by_size {key}") + indices[key], ignored = dt.filter_indices_by_size( + indices[key], max_positions + ) + + for key, dt in dataset.items(): + logger.info(f"\t batch_by_size {key}") + batch_sampler[key] = data_utils.batch_by_size( + indices[key], + dt.num_tokens, + max_tokens=max_tokens, + max_sentences=max_sentences, + required_batch_size_multiple=required_batch_size_multiple, + ) + + epoch_iter = MultidatasetEpochBatchIterator( + dataset=dataset, + batch_sampler=batch_sampler, + seed=seed, + num_shards=num_shards, + shard_id=shard_id, + num_workers=num_workers, + epoch=epoch, + ) + + return epoch_iter diff --git a/examples/laser/laser_src/laser_transformer.py b/examples/laser/laser_src/laser_transformer.py new file mode 100644 index 0000000000..0be030994f --- /dev/null +++ b/examples/laser/laser_src/laser_transformer.py @@ -0,0 +1,354 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import logging + +from typing import Any, Dict, List, Optional +from torch import Tensor + +import torch +import torch.nn as nn + +from fairseq.models import ( + FairseqEncoderDecoderModel, + register_model, + register_model_architecture, +) +from fairseq.models.transformer import ( + base_architecture, + Embedding, + TransformerModel, + TransformerEncoder, + TransformerDecoder, +) +from fairseq.modules import ( + TransformerDecoderLayer, +) + +logger = logging.getLogger(__name__) + + +@register_model("laser_transformer") +class LaserTransformerModel(FairseqEncoderDecoderModel): + """Train Transformer for LASER task + + Requires --task laser + """ + + def __init__(self, encoder, decoder): + super().__init__(encoder, decoder) + + def forward( + self, + src_tokens, + src_lengths, + prev_output_tokens=None, + tgt_tokens=None, + tgt_lengths=None, + target_language_id=-1, + dataset_name="", + ): + laser_encoder_out = self.encoder(src_tokens, src_lengths) + return self.decoder( + prev_output_tokens, laser_encoder_out, lang_id=target_language_id + ) + + @staticmethod + def add_args(parser): + """Add model-specific arguments to the parser.""" + TransformerModel.add_args(parser) + parser.add_argument( + "--decoder-lang-embed-dim", + type=int, + metavar="N", + help="decoder language embedding dimension", + ) + + @classmethod + def build_model(cls, args, task): + base_laser_transformer_architecture(args) + + num_langs = task.num_tasks if hasattr(task, "num_tasks") else 0 + + def load_embed_tokens(dictionary, embed_dim): + num_embeddings = len(dictionary) + padding_idx = dictionary.pad() + + return Embedding(num_embeddings, embed_dim, padding_idx) + + encoder_embed_tokens = load_embed_tokens( + task.source_dictionary, args.encoder_embed_dim + ) + decoder_embed_tokens = load_embed_tokens( + task.target_dictionary, args.decoder_embed_dim + ) + num_langs = task.num_tasks if hasattr(task, "num_tasks") else 0 + + encoder = LaserTransformerEncoder( + args, task.source_dictionary, encoder_embed_tokens + ) + + decoder = LaserTransformerDecoder( + args, + task.target_dictionary, + decoder_embed_tokens, + num_langs=num_langs, + lang_embed_dim=args.decoder_lang_embed_dim, + ) + + return cls(encoder, decoder) + + +class LaserTransformerEncoder(TransformerEncoder): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward(self, src_tokens, *args, **kwargs): + encoder_out = super().forward(src_tokens, *args, **kwargs) + + x = encoder_out["encoder_out"][0] # T x B x C + padding_mask = src_tokens.eq(self.padding_idx).t().unsqueeze(-1) + + if padding_mask.any(): + x = x.float().masked_fill_(padding_mask, float("-inf")).type_as(x) + + # Build the sentence embedding by max-pooling over the encoder outputs + sentemb = x.max(dim=0)[0] + + # The Pytorch Mobile lite interpreter does not supports returning NamedTuple in + # `foward` so we use a dictionary instead. + # TorchScript does not support mixed values so the values are all lists. + # The empty list is equivalent to None. + return {"sentemb": [sentemb]} # B x C + + @torch.jit.export + def reorder_encoder_out(self, encoder_out: Dict[str, List[Tensor]], new_order): + """ + Same as the one in transformer.py, with new_sentemb + """ + if len(encoder_out["sentemb"]) == 0: + new_sentemb = [] + else: + new_sentemb = [encoder_out["sentemb"][0].index_select(0, new_order)] + + return { + "sentemb": new_sentemb, # B x C + } + + +class LaserTransformerDecoder(TransformerDecoder): + def __init__(self, args, dictionary, *kargs, **kwargs): + self.num_langs = kwargs.get("num_langs", 1) + self.lang_embed_dim = kwargs.get("lang_embed_dim", 0) + kwargs.pop("num_langs", None) + kwargs.pop("lang_embed_dim", None) + + super().__init__(args, dictionary, *kargs, **kwargs, no_encoder_attn=True) + + if self.lang_embed_dim == 0: + self.embed_lang = None + else: + self.embed_lang = nn.Embedding(self.num_langs, self.lang_embed_dim) + nn.init.uniform_(self.embed_lang.weight, -0.1, 0.1) + + if self.output_projection is not None: + laser_output_embed_dim = ( + self.output_embed_dim + self.lang_embed_dim + args.encoder_embed_dim + ) + self.output_projection = nn.Linear( + laser_output_embed_dim, len(dictionary), bias=False + ) + nn.init.normal_( + self.output_projection.weight, + mean=0, + std=laser_output_embed_dim ** -0.5, + ) + + def build_decoder_layer(self, args, no_encoder_attn=False): + decoder_embed_dim = args.decoder_embed_dim + args.decoder_embed_dim = ( + decoder_embed_dim + self.lang_embed_dim + args.encoder_embed_dim + ) + res = TransformerDecoderLayer(args, no_encoder_attn=True) + args.decoder_embed_dim = decoder_embed_dim + + return res + + def extract_features( + self, + prev_output_tokens, + encoder_out: Optional[Dict[str, List[Tensor]]], + incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, + full_context_alignment: bool = False, + alignment_layer: Optional[int] = None, + alignment_heads: Optional[int] = None, + lang_id: Optional[int] = None, + ): + """ + Similar to *forward* but only return features. + + Includes several features from "Jointly Learning to Align and + Translate with Transformer Models" (Garg et al., EMNLP 2019). + + Args: + full_context_alignment (bool, optional): don't apply + auto-regressive mask to self-attention (default: False). + alignment_layer (int, optional): return mean alignment over + heads at this layer (default: last layer). + alignment_heads (int, optional): only average alignment over + this many heads (default: all heads). + + Returns: + tuple: + - the decoder's features of shape `(batch, tgt_len, embed_dim)` + - a dictionary with any model-specific outputs + """ + if alignment_layer is None: + alignment_layer = self.num_layers - 1 + + # embed positions + positions = ( + self.embed_positions( + prev_output_tokens, incremental_state=incremental_state + ) + if self.embed_positions is not None + else None + ) + + if incremental_state is not None: + prev_output_tokens = prev_output_tokens[:, -1:] + if positions is not None: + positions = positions[:, -1:] + + bsz, seqlen = prev_output_tokens.size() + + # embed tokens and positions + x = self.embed_scale * self.embed_tokens(prev_output_tokens) + + if self.quant_noise is not None: + x = self.quant_noise(x) + + if self.project_in_dim is not None: + x = self.project_in_dim(x) + + if positions is not None: + x += positions + + if self.layernorm_embedding is not None: + x = self.layernorm_embedding(x) + + x = self.dropout_module(x) + + # B x T x C -> T x B x C + x = x.transpose(0, 1) + + if self.embed_lang is not None: + lang_ids = prev_output_tokens.data.new_full((bsz,), lang_id) + langemb = self.embed_lang(lang_ids) + langemb = langemb.unsqueeze(0) + repeat_vals = [x.shape[0] // langemb.shape[0]] + [-1] * ( + len(langemb.shape) - 1 + ) + x = torch.cat((x, langemb.expand(*repeat_vals)), dim=-1) + + sentemb = encoder_out["sentemb"][0] + sentemb = sentemb.unsqueeze(0) + + repeat_vals = [x.shape[0] // sentemb.shape[0]] + [-1] * (len(sentemb.shape) - 1) + x = torch.cat((x, sentemb.expand(*repeat_vals)), dim=-1) + + self_attn_padding_mask: Optional[Tensor] = None + if self.cross_self_attention or prev_output_tokens.eq(self.padding_idx).any(): + self_attn_padding_mask = prev_output_tokens.eq(self.padding_idx) + + # decoder layers + attn: Optional[Tensor] = None + inner_states: List[Optional[Tensor]] = [x] + for idx, layer in enumerate(self.layers): + if incremental_state is None and not full_context_alignment: + self_attn_mask = self.buffered_future_mask(x) + else: + self_attn_mask = None + + x, layer_attn, _ = layer( + x, + None, + None, + incremental_state, + self_attn_mask=self_attn_mask, + self_attn_padding_mask=self_attn_padding_mask, + need_attn=bool((idx == alignment_layer)), + need_head_weights=bool((idx == alignment_layer)), + ) + inner_states.append(x) + if layer_attn is not None and idx == alignment_layer: + attn = layer_attn.float().to(x) + + if attn is not None: + if alignment_heads is not None: + attn = attn[:alignment_heads] + + # average probabilities over heads + attn = attn.mean(dim=0) + + if self.layer_norm is not None: + x = self.layer_norm(x) + + # T x B x C -> B x T x C + x = x.transpose(0, 1) + + if self.project_out_dim is not None: + x = self.project_out_dim(x) + + return x, {"attn": [attn], "inner_states": inner_states} + + def forward( + self, + prev_output_tokens, + encoder_out: Optional[Dict[str, List[Tensor]]] = None, + incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, + features_only: bool = False, + alignment_layer: Optional[int] = None, + alignment_heads: Optional[int] = None, + src_lengths: Optional[Any] = None, + return_all_hiddens: bool = False, + lang_id: Optional[int] = None, + ): + """ + Args: + prev_output_tokens (LongTensor): previous decoder outputs of shape + `(batch, tgt_len)`, for teacher forcing + encoder_out (optional): output from the encoder, used for + encoder-side attention + incremental_state (dict): dictionary used for storing state during + :ref:`Incremental decoding` + features_only (bool, optional): only return features without + applying output layer (default: False). + + Returns: + tuple: + - the decoder's output of shape `(batch, tgt_len, vocab)` + - a dictionary with any model-specific outputs + """ + + assert lang_id is not None + + x, extra = self.extract_features( + prev_output_tokens, + encoder_out=encoder_out, + incremental_state=incremental_state, + alignment_layer=alignment_layer, + alignment_heads=alignment_heads, + lang_id=lang_id, + ) + if not features_only: + x = self.output_layer(x) + return x, extra + + +@register_model_architecture("laser_transformer", "laser_transformer") +def base_laser_transformer_architecture(args): + base_architecture(args) + args.decoder_lang_embed_dim = getattr(args, "decoder_lang_embed_dim", 0) diff --git a/examples/laser/laser_src/multitask_data_utils.py b/examples/laser/laser_src/multitask_data_utils.py new file mode 100644 index 0000000000..b05caea267 --- /dev/null +++ b/examples/laser/laser_src/multitask_data_utils.py @@ -0,0 +1,143 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from collections import OrderedDict + +import numpy as np + +from fairseq.data import BaseWrapperDataset, FairseqDataset, iterators + + +class MultiItr(object): + def __init__(self, itr): + self.itr = itr + self._counts = [0 for x in itr] + + def __len__(self): + return sum(len(itr) for itr in self.itr) + + def __iter__(self): + return self + + def __next__(self): + ratios = [count / len(itr) for count, itr in zip(self._counts, self.itr)] + idx = ratios.index(min(ratios)) + self._counts[idx] += 1 + return next(self.itr[idx]) + + +class MultidatasetEpochBatchIterator(iterators.EpochBatchIterating): + """A wrapper around multiple epoch batch iterators.""" + + def __init__( + self, + dataset, + batch_sampler, + seed=1, + num_shards=1, + shard_id=0, + num_workers=0, + epoch=1, + ): + + assert isinstance(dataset, OrderedDict) + assert len(dataset) + assert isinstance(dataset[next(iter(dataset))], FairseqDataset) + + self.iterators = [] + + self.epoch = epoch + for key, dt in dataset.items(): + epoch_iter = iterators.EpochBatchIterator( + dataset=dt, + collate_fn=dt.collater, + batch_sampler=batch_sampler[key], + seed=seed, + num_shards=num_shards, + shard_id=shard_id, + num_workers=0, + epoch=epoch, + ) + self.iterators.append(epoch_iter) + + def __len__(self): + return sum(len(itr) for itr in self.iterators) + + def next_epoch_itr(self, shuffle=True, fix_batches_to_gpus=False): + # `self.epoch += 1` should be handled by underlying `EpochBatchIterator`s. + return MultiItr( + [ + itr.next_epoch_itr( + shuffle=shuffle, fix_batches_to_gpus=fix_batches_to_gpus + ) + for itr in self.iterators + ] + ) + + def end_of_epoch(self): + return all(itr.end_of_epoch() for itr in self.iterators) + + @property + def next_epoch_idx(self): + """Return the epoch index after *next_epoch_itr* is called.""" + + epochs = [itr.next_epoch_idx for itr in self.iterators] + self.epoch = epochs[0] + assert all(epoch == self.epoch for epoch in epochs) + + return self.epoch + + @property + def iterations_in_epoch(self): + return sum(itr.iterations_in_epoch for itr in self.iterators) + + def state_dict(self): + return { + "iterators": [it.state_dict() for it in self.iterators], + "epoch": self.epoch, + } + + def load_state_dict(self, state_dict): + self.epoch = state_dict["epoch"] + for it, d in zip(self.iterators, state_dict["iterators"]): + it.load_state_dict(d) + + +class MultitaskDatasetWrapper(BaseWrapperDataset): + """A wrapper for a multitask dataset.""" + + def __init__(self, dataset, target_language_id, sample=1.0, name=""): + super().__init__(dataset) + self.target_language_id = target_language_id + self.sample = sample + self.name = name + + def collater(self, *args, **kwargs): + ans = self.dataset.collater(*args, **kwargs) + if "net_input" in ans: + ans["net_input"]["target_language_id"] = self.target_language_id + ans["net_input"]["dataset_name"] = self.name + return ans + + def num_tokens(self, *args, **kwargs): + return self.dataset.num_tokens(*args, **kwargs) + + def ordered_indices(self, *args, **kwargs): + indices = self.dataset.ordered_indices(*args, **kwargs) + # Hacky solution for sampling + size = int(self.sample * indices.shape[0]) + + return indices.take(np.sort(np.random.permutation(indices.shape[0])[:size])) + + def size(self, index: int): + return self.dataset.size(index) + + @property + def supports_prefetch(self): + """Whether this dataset supports prefetching.""" + return getattr(self.dataset, "supports_prefetch", False) + + def prefetch(self, indices): + return self.dataset.prefetch(indices) diff --git a/tests/test_binaries.py b/tests/test_binaries.py index 981ffd49cd..3cb98897bf 100644 --- a/tests/test_binaries.py +++ b/tests/test_binaries.py @@ -22,6 +22,7 @@ preprocess_lm_data, preprocess_summarization_data, preprocess_translation_data, + create_laser_data_and_config_json, train_translation_model, ) @@ -935,6 +936,65 @@ def test_alignment(self): ) generate_main(data_dir) + def test_laser_lstm(self): + with contextlib.redirect_stdout(StringIO()): + with tempfile.TemporaryDirectory("test_laser_lstm") as data_dir: + laser_config_file = create_laser_data_and_config_json(data_dir) + train_translation_model( + laser_config_file.name, + "laser_lstm", + [ + "--user-dir", + "examples/laser/laser_src", + "--weighting-alpha", + "0.3", + "--encoder-bidirectional", + "--encoder-hidden-size", + "512", + "--encoder-layers", + "5", + "--decoder-layers", + "1", + "--encoder-embed-dim", + "320", + "--decoder-embed-dim", + "320", + "--decoder-lang-embed-dim", + "32", + "--save-dir", + data_dir, + "--disable-validation", + ], + task="laser", + lang_flags=[], + ) + + def test_laser_transformer(self): + with contextlib.redirect_stdout(StringIO()): + with tempfile.TemporaryDirectory("test_laser_transformer") as data_dir: + laser_config_file = create_laser_data_and_config_json(data_dir) + train_translation_model( + laser_config_file.name, + "laser_transformer", + [ + "--user-dir", + "examples/laser/laser_src", + "--weighting-alpha", + "0.3", + "--encoder-embed-dim", + "320", + "--decoder-embed-dim", + "320", + "--decoder-lang-embed-dim", + "32", + "--save-dir", + data_dir, + "--disable-validation", + ], + task="laser", + lang_flags=[], + ) + def test_alignment_full_context(self): with contextlib.redirect_stdout(StringIO()): with tempfile.TemporaryDirectory("test_alignment") as data_dir: diff --git a/tests/utils.py b/tests/utils.py index 178df5763e..1bf6f8d7f3 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -4,6 +4,7 @@ # LICENSE file in the root directory of this source tree. import argparse +import json import os import random import sys @@ -274,6 +275,43 @@ def preprocess_summarization_data(data_dir, extra_flags=None): preprocess.main(preprocess_args) +def create_laser_data_and_config_json(data_dir): + src_langs = ["de", "fr", "ru", "tr", "zh"] + tgt_langs = ["en", "es"] + config_json = {} + config_train_json = [] + src_vocab = None + tgt_vocab = None + + for src_lang in src_langs: + for tgt_lang in tgt_langs: + langpair_folder = f"{src_lang}-{tgt_lang}" + + langpair_path = os.path.join(data_dir, langpair_folder) + os.mkdir(langpair_path) + create_dummy_data(langpair_path) + preprocess_translation_data(langpair_path, ["--dataset-impl", "cached"]) + + src_vocab = os.path.join(langpair_path, "dict.in.txt") + tgt_vocab = os.path.join(langpair_path, "dict.out.txt") + config_train_json.append( + { + "id": 0 if tgt_lang == "en" else 1, + "src": os.path.join(langpair_path, "train.in-out.in"), + "tgt": os.path.join(langpair_path, "train.in-out.out"), + } + ) + + config_json["src_vocab"] = src_vocab + config_json["tgt_vocab"] = tgt_vocab + config_json["train"] = config_train_json + + with open(os.path.join(data_dir, "laserconfig.json"), "w") as config_file: + json.dump(config_json, config_file) + + return config_file + + def train_translation_model( data_dir, arch, From 3bc43c17d14c4b9f6b052a915f9589cd538bc8b6 Mon Sep 17 00:00:00 2001 From: Myle Ott Date: Thu, 18 Feb 2021 13:10:02 -0800 Subject: [PATCH 20/82] Fix speed regression after RobertaEncoder refactor (#1626) Summary: Add back a couple speed optimizations in the original roberta code that got lost in the refactor Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1626 Reviewed By: gwenzek Differential Revision: D26478534 Pulled By: myleott fbshipit-source-id: b945de5e9bffd51cd63630cc3aa1f0078a41cca8 --- fairseq/models/transformer.py | 9 ++++++--- fairseq/modules/transformer_layer.py | 3 ++- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/fairseq/models/transformer.py b/fairseq/models/transformer.py index 4960fd143d..605cfa65e8 100644 --- a/fairseq/models/transformer.py +++ b/fairseq/models/transformer.py @@ -434,10 +434,11 @@ def forward( hidden states of shape `(src_len, batch, embed_dim)`. Only populated if *return_all_hiddens* is True. """ - x, encoder_embedding = self.forward_embedding(src_tokens, token_embeddings) - # compute padding mask encoder_padding_mask = src_tokens.eq(self.padding_idx) + has_pads = (src_tokens.device.type == "xla" or encoder_padding_mask.any()) + + x, encoder_embedding = self.forward_embedding(src_tokens, token_embeddings) # account for padding while computing the representation if encoder_padding_mask is not None: @@ -453,7 +454,9 @@ def forward( # encoder layers for layer in self.layers: - x = layer(x, encoder_padding_mask) + x = layer( + x, encoder_padding_mask=encoder_padding_mask if has_pads else None + ) if return_all_hiddens: assert encoder_states is not None encoder_states.append(x) diff --git a/fairseq/modules/transformer_layer.py b/fairseq/modules/transformer_layer.py index 03e70f4279..f9ada37bde 100644 --- a/fairseq/modules/transformer_layer.py +++ b/fairseq/modules/transformer_layer.py @@ -103,7 +103,7 @@ def upgrade_state_dict_named(self, state_dict, name): state_dict["{}.{}.{}".format(name, new, m)] = state_dict[k] del state_dict[k] - def forward(self, x, encoder_padding_mask, attn_mask: Optional[Tensor] = None): + def forward(self, x, encoder_padding_mask: Optional[Tensor], attn_mask: Optional[Tensor] = None): """ Args: x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)` @@ -135,6 +135,7 @@ def forward(self, x, encoder_padding_mask, attn_mask: Optional[Tensor] = None): key=x, value=x, key_padding_mask=encoder_padding_mask, + need_weights=False, attn_mask=attn_mask, ) x = self.dropout_module(x) From da9eaba12d82b9bfc1442f0e2c6fc1b895f4d35d Mon Sep 17 00:00:00 2001 From: Elizabeth Salesky Date: Thu, 18 Feb 2021 13:58:56 -0800 Subject: [PATCH 21/82] Add support for multi-channel audio and example for mTEDx data (#3253) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: # Before submitting - [ ] Was this discussed/approved via a Github issue? (no need for typos, doc improvements) - [x] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)? - [x] Did you make sure to update the docs? - [x] Did you write any new necessary tests? ## What does this PR do? - updates audio_utils to handle multi-channel audio as well as mono, with no change needed for existing recipes - adds speech-to-text example for Multilingual TEDx (http://openslr.org/100) data ## PR review Anyone in the community is free to review the PR once the tests have passed. If we didn't discuss your PR in Github issues there's a high chance it will not be merged. ## Did you have fun? Make sure you had fun coding � Pull Request resolved: https://github.com/pytorch/fairseq/pull/3253 Reviewed By: yuntang Differential Revision: D26514419 Pulled By: kahne fbshipit-source-id: 699e428affda5b1347f96a8310691ab152dd6769 --- examples/speech_to_text/README.md | 2 + examples/speech_to_text/docs/mtedx_example.md | 200 +++++++++++++++ examples/speech_to_text/prep_mtedx_data.py | 235 ++++++++++++++++++ fairseq/data/audio/audio_utils.py | 10 +- .../models/speech_to_text/s2t_transformer.py | 9 + 5 files changed, 455 insertions(+), 1 deletion(-) create mode 100644 examples/speech_to_text/docs/mtedx_example.md create mode 100644 examples/speech_to_text/prep_mtedx_data.py diff --git a/examples/speech_to_text/README.md b/examples/speech_to_text/README.md index 4b6f89d105..988ed83d77 100644 --- a/examples/speech_to_text/README.md +++ b/examples/speech_to_text/README.md @@ -36,6 +36,8 @@ audio paths (one per line) as inputs. - [Speech-to-Text Translation (ST) on CoVoST 2](docs/covost_example.md) +- [Speech-to-Text Translation (ST) on Multilingual TEDx](docs/mtedx_example.md) + ## Updates - 02/04/2021: Added interactive decoding (`fairseq-interactive`) support. Examples: [ASR (LibriSpeech)](docs/librispeech_example.md#interactive-decoding) diff --git a/examples/speech_to_text/docs/mtedx_example.md b/examples/speech_to_text/docs/mtedx_example.md new file mode 100644 index 0000000000..c0e17db9a2 --- /dev/null +++ b/examples/speech_to_text/docs/mtedx_example.md @@ -0,0 +1,200 @@ +[[Back]](..) + +# S2T Example: Speech Translation (ST) on Multilingual TEDx + +[Multilingual TEDx](https://arxiv.org/abs/2102.01757) is multilingual corpus for speech recognition and +speech translation. The data is derived from TEDx talks in 8 source languages +with translations to a subset of 5 target languages. + +## Data Preparation +[Download](http://openslr.org/100/) and unpack Multilingual TEDx data to a path +`${MTEDX_ROOT}/${LANG_PAIR}`, then preprocess it with +```bash +# additional Python packages for S2T data processing/model training +pip install pandas torchaudio sentencepiece + +# Generate TSV manifests, features, vocabulary +# and configuration for each language +python examples/speech_to_text/prep_mtedx_data.py \ + --data-root ${MTEDX_ROOT} --task asr \ + --vocab-type unigram --vocab-size 1000 +python examples/speech_to_text/prep_mtedx_data.py \ + --data-root ${MTEDX_ROOT} --task st \ + --vocab-type unigram --vocab-size 1000 + +# Add vocabulary and configuration for joint data +# (based on the manifests and features generated above) +python examples/speech_to_text/prep_mtedx_data.py \ + --data-root ${MTEDX_ROOT} --task asr --joint \ + --vocab-type unigram --vocab-size 8000 +python examples/speech_to_text/prep_mtedx_data.py \ + --data-root ${MTEDX_ROOT} --task st --joint \ + --vocab-type unigram --vocab-size 8000 +``` +The generated files (manifest, features, vocabulary and data configuration) will be added to +`${MTEDX_ROOT}/${LANG_PAIR}` (per-language data) and `MTEDX_ROOT` (joint data). + + +## ASR +#### Training +Spanish as example: +```bash +fairseq-train ${MTEDX_ROOT}/es-es \ + --config-yaml config_asr.yaml --train-subset train_asr --valid-subset valid_asr \ + --save-dir ${ASR_SAVE_DIR} --num-workers 4 --max-tokens 40000 --max-epoch 200 \ + --task speech_to_text --criterion label_smoothed_cross_entropy --report-accuracy \ + --arch s2t_transformer_xs --optimizer adam --lr 2e-3 --lr-scheduler inverse_sqrt \ + --warmup-updates 10000 --clip-norm 10.0 --seed 1 --dropout 0.3 --label-smoothing 0.1 \ + --load-pretrained-encoder-from ${PRETRAINED_ENCODER} \ + --skip-invalid-size-inputs-valid-test \ + --keep-last-epochs 10 --update-freq 8 --patience 10 +``` +For joint model (using ASR data from all 8 languages): +```bash +fairseq-train ${MTEDX_ROOT} \ + --config-yaml config_asr.yaml \ + --train-subset train_es-es_asr,train_fr-fr_asr,train_pt-pt_asr,train_it-it_asr,train_ru-ru_asr,train_el-el_asr,train_ar-ar_asr,train_de-de_asr \ + --valid-subset valid_es-es_asr,valid_fr-fr_asr,valid_pt-pt_asr,valid_it-it_asr,valid_ru-ru_asr,valid_el-el_asr,valid_ar-ar_asr,valid_de-de_asr \ + --save-dir ${MULTILINGUAL_ASR_SAVE_DIR} --num-workers 4 --max-tokens 40000 --max-epoch 200 \ + --task speech_to_text --criterion label_smoothed_cross_entropy --report-accuracy \ + --arch s2t_transformer_s --optimizer adam --lr 2e-3 --lr-scheduler inverse_sqrt \ + --warmup-updates 10000 --clip-norm 10.0 --seed 1 --dropout 0.3 --label-smoothing 0.1 \ + --skip-invalid-size-inputs-valid-test \ + --keep-last-epochs 10 --update-freq 8 --patience 10 \ + --ignore-prefix-size 1 +``` +where `MULTILINGUAL_ASR_SAVE_DIR` is the checkpoint root path. We set `--update-freq 8` to simulate 8 GPUs +with 1 GPU. You may want to update it accordingly when using more than 1 GPU. +For multilingual models, we prepend target language ID token as target BOS, which should be excluded from +the training loss via `--ignore-prefix-size 1`. + +#### Inference & Evaluation +```bash +CHECKPOINT_FILENAME=avg_last_10_checkpoint.pt +python scripts/average_checkpoints.py \ + --inputs ${ASR_SAVE_DIR} --num-epoch-checkpoints 10 \ + --output "${ASR_SAVE_DIR}/${CHECKPOINT_FILENAME}" + +fairseq-generate ${MTEDX_ROOT}/es-es \ + --config-yaml config_asr.yaml --gen-subset test --task speech_to_text \ + --path ${ASR_SAVE_DIR}/${CHECKPOINT_FILENAME} --max-tokens 50000 --beam 5 \ + --skip-invalid-size-inputs-valid-test \ + --scoring wer --wer-tokenizer 13a --wer-lowercase --wer-remove-punct --remove-bpe + +# For models trained on joint data +CHECKPOINT_FILENAME=avg_last_10_checkpoint.pt +python scripts/average_checkpoints.py \ + --inputs ${MULTILINGUAL_ASR_SAVE_DIR} --num-epoch-checkpoints 10 \ + --output "${MULTILINGUAL_ASR_SAVE_DIR}/${CHECKPOINT_FILENAME}" + +for LANG in es fr pt it ru el ar de; do + fairseq-generate ${MTEDX_ROOT} \ + --config-yaml config_asr.yaml --gen-subset test_${LANG}-${LANG}_asr --task speech_to_text \ + --prefix-size 1 --path ${MULTILINGUAL_ASR_SAVE_DIR}/${CHECKPOINT_FILENAME} \ + --max-tokens 40000 --beam 5 \ + --skip-invalid-size-inputs-valid-test \ + --scoring wer --wer-tokenizer 13a --wer-lowercase --wer-remove-punct --remove-bpe +done +``` +#### Results +| Data | --arch | Params | Es | Fr | Pt | It | Ru | El | Ar | De | +|--------------|--------------------|--------|------|------|------|------|------|-------|-------|-------| +| Monolingual | s2t_transformer_xs | 10M | 46.4 | 45.6 | 54.8 | 48.0 | 74.7 | 109.5 | 104.4 | 111.1 | + + +## ST +#### Training +Es-En as example: +```bash +fairseq-train ${MTEDX_ROOT}/es-en \ + --config-yaml config_st.yaml --train-subset train_st --valid-subset valid_st \ + --save-dir ${ST_SAVE_DIR} --num-workers 4 --max-tokens 40000 --max-epoch 200 \ + --task speech_to_text --criterion label_smoothed_cross_entropy --report-accuracy \ + --arch s2t_transformer_xs --optimizer adam --lr 2e-3 --lr-scheduler inverse_sqrt \ + --warmup-updates 10000 --clip-norm 10.0 --seed 1 --dropout 0.3 --label-smoothing 0.1 \ + --load-pretrained-encoder-from ${PRETRAINED_ENCODER} \ + --skip-invalid-size-inputs-valid-test \ + --keep-last-epochs 10 --update-freq 8 --patience 10 +``` +For multilingual model (all 12 directions): +```bash +fairseq-train ${MTEDX_ROOT} \ + --config-yaml config_st.yaml \ + --train-subset train_el-en_st,train_es-en_st,train_es-fr_st,train_es-it_st,train_es-pt_st,train_fr-en_st,train_fr-es_st,train_fr-pt_st,train_it-en_st,train_it-es_st,train_pt-en_st,train_pt-es_st,train_ru-en_st \ + --valid-subset valid_el-en_st,valid_es-en_st,valid_es-fr_st,valid_es-it_st,valid_es-pt_st,valid_fr-en_st,valid_fr-es_st,valid_fr-pt_st,valid_it-en_st,valid_it-es_st,valid_pt-en_st,valid_pt-es_st,valid_ru-en_st \ + --save-dir ${MULTILINGUAL_ST_SAVE_DIR} --num-workers 4 --max-tokens 40000 --max-epoch 200 \ + --task speech_to_text --criterion label_smoothed_cross_entropy --report-accuracy \ + --arch s2t_transformer_s --optimizer adam --lr 2e-3 --lr-scheduler inverse_sqrt \ + --warmup-updates 10000 --clip-norm 10.0 --seed 1 --dropout 0.3 --label-smoothing 0.1 \ + --skip-invalid-size-inputs-valid-test \ + --keep-last-epochs 10 --update-freq 8 --patience 10 \ + --ignore-prefix-size 1 \ + --load-pretrained-encoder-from ${PRETRAINED_ENCODER} +``` +where `ST_SAVE_DIR` (`MULTILINGUAL_ST_SAVE_DIR`) is the checkpoint root path. The ST encoder is pre-trained by ASR +for faster training and better performance: `--load-pretrained-encoder-from <(JOINT_)ASR checkpoint path>`. We set +`--update-freq 8` to simulate 8 GPUs with 1 GPU. You may want to update it accordingly when using more than 1 GPU. +For multilingual models, we prepend target language ID token as target BOS, which should be excluded from +the training loss via `--ignore-prefix-size 1`. + +#### Inference & Evaluation +Average the last 10 checkpoints and evaluate on the `test` split: +```bash +CHECKPOINT_FILENAME=avg_last_10_checkpoint.pt +python scripts/average_checkpoints.py \ + --inputs ${ST_SAVE_DIR} --num-epoch-checkpoints 10 \ + --output "${ST_SAVE_DIR}/${CHECKPOINT_FILENAME}" + +fairseq-generate ${MTEDX_ROOT}/es-en \ + --config-yaml config_st.yaml --gen-subset test --task speech_to_text \ + --path ${ST_SAVE_DIR}/${CHECKPOINT_FILENAME} \ + --max-tokens 50000 --beam 5 --scoring sacrebleu --remove-bpe + +# For multilingual models +python scripts/average_checkpoints.py \ + --inputs ${MULTILINGUAL_ST_SAVE_DIR} --num-epoch-checkpoints 10 \ + --output "${MULTILINGUAL_ST_SAVE_DIR}/${CHECKPOINT_FILENAME}" + +for LANGPAIR in es-en es-fr es-pt fr-en fr-es fr-pt pt-en pt-es it-en it-es ru-en el-en; do + fairseq-generate ${MTEDX_ROOT} \ + --config-yaml config_st.yaml --gen-subset test_${LANGPAIR}_st --task speech_to_text \ + --prefix-size 1 --path ${MULTILINGUAL_ST_SAVE_DIR}/${CHECKPOINT_FILENAME} \ + --max-tokens 40000 --beam 5 \ + --skip-invalid-size-inputs-valid-test \ + --scoring sacrebleu --remove-bpe +done +``` +For multilingual models, we force decoding from the target language ID token (as BOS) via `--prefix-size 1`. + +#### Results +| Data | --arch | Params | Es-En | Es-Pt | Es-Fr | Fr-En | Fr-Es | Fr-Pt | Pt-En | Pt-Es | It-En | It-Es | Ru-En | El-En | +|--------------|--------------------|-----|-------|-------|-------|-------|-------|-------|-------|-------|-------|-------|-------|-------| +| Bilingual | s2t_transformer_xs | 10M | 7.0 | 12.2 | 1.7 | 8.9 | 10.6 | 7.9 | 8.1 | 8.7 | 6.4 | 1.0 | 0.7 | 0.6 | +| Multilingual | s2t_transformer_s | 31M | 12.3 | 17.4 | 6.1 | 12.0 | 13.6 | 13.2 | 12.0 | 13.7 | 10.7 | 13.1 | 0.6 | 0.8 | + + +## Citation +Please cite as: +``` +@misc{salesky2021mtedx, + title={Multilingual TEDx Corpus for Speech Recognition and Translation}, + author={Elizabeth Salesky and Matthew Wiesner and Jacob Bremerman and Roldano Cattoni and Matteo Negri and Marco Turchi and Douglas W. Oard and Matt Post}, + year={2021}, +} + +@inproceedings{wang2020fairseqs2t, + title = {fairseq S2T: Fast Speech-to-Text Modeling with fairseq}, + author = {Changhan Wang and Yun Tang and Xutai Ma and Anne Wu and Dmytro Okhonko and Juan Pino}, + booktitle = {Proceedings of the 2020 Conference of the Asian Chapter of the Association for Computational Linguistics (AACL): System Demonstrations}, + year = {2020}, +} + +@inproceedings{ott2019fairseq, + title = {fairseq: A Fast, Extensible Toolkit for Sequence Modeling}, + author = {Myle Ott and Sergey Edunov and Alexei Baevski and Angela Fan and Sam Gross and Nathan Ng and David Grangier and Michael Auli}, + booktitle = {Proceedings of NAACL-HLT 2019: Demonstrations}, + year = {2019}, +} +``` + +[[Back]](..) diff --git a/examples/speech_to_text/prep_mtedx_data.py b/examples/speech_to_text/prep_mtedx_data.py new file mode 100644 index 0000000000..6c37398fcc --- /dev/null +++ b/examples/speech_to_text/prep_mtedx_data.py @@ -0,0 +1,235 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import argparse +import logging +import os +from pathlib import Path +import shutil +from itertools import groupby +from tempfile import NamedTemporaryFile +from typing import Tuple + +import pandas as pd +import torchaudio +from examples.speech_to_text.data_utils import ( + create_zip, + extract_fbank_features, + filter_manifest_df, + gen_config_yaml, + gen_vocab, + get_zip_manifest, + load_df_from_tsv, + save_df_to_tsv, +) +from torch import Tensor +from torch.utils.data import Dataset +from tqdm import tqdm + + +log = logging.getLogger(__name__) + + +MANIFEST_COLUMNS = ["id", "audio", "n_frames", "tgt_text", "speaker", "tgt_lang"] + + +class mTEDx(Dataset): + """ + Create a Dataset for Multilingual TEDx. + Each item is a tuple of the form: waveform, sample_rate, source utterance, + target utterance, speaker_id, utterance_id + """ + + SPLITS = ["train", "valid", "test"] + LANGPAIRS = ["es-es", "fr-fr", "pt-pt", "it-it", "ru-ru", "el-el", "ar-ar", "de-de", + "es-en", "es-fr", "es-pt", "es-it", "fr-en", "fr-es", "fr-pt", + "pt-en", "pt-es", "it-en", "it-es", "ru-en", "el-en"] + + def __init__(self, root: str, lang: str, split: str) -> None: + assert split in self.SPLITS and lang in self.LANGPAIRS + _root = Path(root) / f"{lang}" / "data" / split + wav_root, txt_root = _root / "wav", _root / "txt" + assert _root.is_dir() and wav_root.is_dir() and txt_root.is_dir() + # Load audio segments + try: + import yaml + except ImportError: + print("Please install PyYAML to load the Multilingual TEDx YAML files") + with open(txt_root / f"{split}.yaml") as f: + segments = yaml.load(f, Loader=yaml.BaseLoader) + # Load source and target utterances + src, tgt = lang.split("-") + for _lang in [src, tgt]: + with open(txt_root / f"{split}.{_lang}") as f: + utterances = [r.strip() for r in f] + assert len(segments) == len(utterances) + for i, u in enumerate(utterances): + segments[i][_lang] = u + # Gather info + self.data = [] + for wav_filename, _seg_group in groupby(segments, lambda x: x["wav"]): + wav_filename = wav_filename.replace(".wav", ".flac") + wav_path = wav_root / wav_filename + sample_rate = torchaudio.info(wav_path.as_posix())[0].rate + seg_group = sorted(_seg_group, key=lambda x: float(x["offset"])) + for i, segment in enumerate(seg_group): + offset = int(float(segment["offset"]) * sample_rate) + n_frames = int(float(segment["duration"]) * sample_rate) + _id = f"{wav_path.stem}_{i}" + self.data.append( + ( + wav_path.as_posix(), + offset, + n_frames, + sample_rate, + segment[src], + segment[tgt], + segment["speaker_id"], + tgt, + _id, + ) + ) + + def __getitem__(self, n: int) -> Tuple[Tensor, int, str, str, str, str, str]: + wav_path, offset, n_frames, sr, src_utt, tgt_utt, spk_id, tgt_lang, utt_id = self.data[n] + waveform, _ = torchaudio.load(wav_path, offset=offset, num_frames=n_frames) + return waveform, sr, src_utt, tgt_utt, spk_id, tgt_lang, utt_id + + def __len__(self) -> int: + return len(self.data) + + +def process(args): + root = Path(args.data_root).absolute() + for lang in mTEDx.LANGPAIRS: + cur_root = root / f"{lang}" + if not cur_root.is_dir(): + print(f"{cur_root.as_posix()} does not exist. Skipped.") + continue + # Extract features + feature_root = cur_root / "fbank80" + feature_root.mkdir(exist_ok=True) + for split in mTEDx.SPLITS: + print(f"Fetching split {split}...") + dataset = mTEDx(root.as_posix(), lang, split) + print("Extracting log mel filter bank features...") + for waveform, sample_rate, _, _, _, _, utt_id in tqdm(dataset): + extract_fbank_features( + waveform, sample_rate, feature_root / f"{utt_id}.npy" + ) + # Pack features into ZIP + zip_path = cur_root / "fbank80.zip" + print("ZIPing features...") + create_zip(feature_root, zip_path) + print("Fetching ZIP manifest...") + zip_manifest = get_zip_manifest(zip_path) + # Generate TSV manifest + print("Generating manifest...") + train_text = [] + for split in mTEDx.SPLITS: + is_train_split = split.startswith("train") + manifest = {c: [] for c in MANIFEST_COLUMNS} + dataset = mTEDx(args.data_root, lang, split) + for wav, sr, src_utt, tgt_utt, speaker_id, tgt_lang, utt_id in tqdm(dataset): + manifest["id"].append(utt_id) + manifest["audio"].append(zip_manifest[utt_id]) + duration_ms = int(wav.size(1) / sr * 1000) + manifest["n_frames"].append(int(1 + (duration_ms - 25) / 10)) + manifest["tgt_text"].append(src_utt if args.task == "asr" else tgt_utt) + manifest["speaker"].append(speaker_id) + manifest["tgt_lang"].append(tgt_lang) + if is_train_split: + train_text.extend(manifest["tgt_text"]) + df = pd.DataFrame.from_dict(manifest) + df = filter_manifest_df(df, is_train_split=is_train_split) + save_df_to_tsv(df, cur_root / f"{split}_{args.task}.tsv") + # Generate vocab + v_size_str = "" if args.vocab_type == "char" else str(args.vocab_size) + spm_filename_prefix = f"spm_{args.vocab_type}{v_size_str}_{args.task}" + with NamedTemporaryFile(mode="w") as f: + for t in train_text: + f.write(t + "\n") + gen_vocab( + Path(f.name), + cur_root / spm_filename_prefix, + args.vocab_type, + args.vocab_size, + ) + # Generate config YAML + gen_config_yaml( + cur_root, + spm_filename_prefix + ".model", + yaml_filename=f"config_{args.task}.yaml", + specaugment_policy="lb", + ) + # Clean up + shutil.rmtree(feature_root) + + +def process_joint(args): + cur_root = Path(args.data_root) + assert all((cur_root / f"{lang}").is_dir() for lang in mTEDx.LANGPAIRS), \ + "do not have downloaded data available for all languages" + # Generate vocab + vocab_size_str = "" if args.vocab_type == "char" else str(args.vocab_size) + spm_filename_prefix = f"spm_{args.vocab_type}{vocab_size_str}_{args.task}" + with NamedTemporaryFile(mode="w") as f: + for lang in mTEDx.LANGPAIRS: + tsv_path = cur_root / f"{lang}" / f"train_{args.task}.tsv" + df = load_df_from_tsv(tsv_path) + for t in df["tgt_text"]: + f.write(t + "\n") + special_symbols = None + if args.joint: + # Add tgt_lang tags to dict + special_symbols = list({f'' for lang in mTEDx.LANGPAIRS}) + gen_vocab( + Path(f.name), + cur_root / spm_filename_prefix, + args.vocab_type, + args.vocab_size, + special_symbols=special_symbols + ) + # Generate config YAML + gen_config_yaml( + cur_root, + spm_filename_prefix + ".model", + yaml_filename=f"config_{args.task}.yaml", + specaugment_policy="ld", + prepend_tgt_lang_tag=(args.joint), + ) + # Make symbolic links to manifests + for lang in mTEDx.LANGPAIRS: + for split in mTEDx.SPLITS: + src_path = cur_root / f"{lang}" / f"{split}_{args.task}.tsv" + desc_path = cur_root / f"{split}_{lang}_{args.task}.tsv" + if not desc_path.is_symlink(): + os.symlink(src_path, desc_path) + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--data-root", "-d", required=True, type=str) + parser.add_argument( + "--vocab-type", + default="unigram", + required=True, + type=str, + choices=["bpe", "unigram", "char"], + ), + parser.add_argument("--vocab-size", default=8000, type=int) + parser.add_argument("--task", type=str, choices=["asr", "st"]) + parser.add_argument("--joint", action="store_true", help="") + args = parser.parse_args() + + if args.joint: + process_joint(args) + else: + process(args) + + +if __name__ == "__main__": + main() diff --git a/fairseq/data/audio/audio_utils.py b/fairseq/data/audio/audio_utils.py index de08669851..f0e75b1d65 100644 --- a/fairseq/data/audio/audio_utils.py +++ b/fairseq/data/audio/audio_utils.py @@ -56,8 +56,16 @@ def _get_torchaudio_fbank(waveform, sample_rate, n_bins=80) -> Optional[np.ndarr try: import torch import torchaudio.compliance.kaldi as ta_kaldi + import torchaudio.sox_effects as ta_sox + + waveform = torch.from_numpy(waveform) + if len(waveform.shape) == 1: + # Mono channel: D -> 1 x D + waveform = waveform.unsqueeze(0) + else: + # Merge multiple channels to one: C x D -> 1 x D + waveform, _ = ta_sox.apply_effects_tensor(waveform, sample_rate, ['channels', '1']) - waveform = torch.from_numpy(waveform).unsqueeze(0) features = ta_kaldi.fbank( waveform, num_mel_bins=n_bins, sample_frequency=sample_rate ) diff --git a/fairseq/models/speech_to_text/s2t_transformer.py b/fairseq/models/speech_to_text/s2t_transformer.py index 1f556107a2..814924ec97 100644 --- a/fairseq/models/speech_to_text/s2t_transformer.py +++ b/fairseq/models/speech_to_text/s2t_transformer.py @@ -422,6 +422,15 @@ def s2t_transformer_s(args): base_architecture(args) +@register_model_architecture("s2t_transformer", "s2t_transformer_xs") +def s2t_transformer_xs(args): + args.encoder_layers = getattr(args, "encoder_layers", 6) + args.decoder_layers = getattr(args, "decoder_layers", 3) + args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 256 * 4) + args.dropout = getattr(args, "dropout", 0.3) + s2t_transformer_s(args) + + @register_model_architecture("s2t_transformer", "s2t_transformer_sp") def s2t_transformer_sp(args): args.encoder_layers = getattr(args, "encoder_layers", 16) From 284a86a49a054dcace1e66ee4c65dfb4adb5a39f Mon Sep 17 00:00:00 2001 From: Weiyi Zheng Date: Thu, 18 Feb 2021 16:35:02 -0800 Subject: [PATCH 22/82] remove the missing _device property Summary: after D26382917 (https://github.com/pytorch/fairseq/commit/02803a1be45642b4c2f9c2970a4f4ae645a2dccf) shipped somehow the self._device was removed in optimizer, (or maybe I didn't test it the right way in the previous diff?) fortunately OSS doesn't need it any way. Reviewed By: myleott Differential Revision: D26523538 fbshipit-source-id: 637c1e344670340ae40b32635ef51f5501966b0c --- fairseq/optim/shard.py | 1 - 1 file changed, 1 deletion(-) diff --git a/fairseq/optim/shard.py b/fairseq/optim/shard.py index 3c1b34ae60..9d7f2eb9e5 100644 --- a/fairseq/optim/shard.py +++ b/fairseq/optim/shard.py @@ -45,7 +45,6 @@ def broadcast_global_state_dict( state_dict, src_rank=0, group=self.group, - dist_device=self._device, ) torch_optimizer = optimizer.optimizer From d2ee5883e774700c41b1eaddd0326e9afa6d3cd2 Mon Sep 17 00:00:00 2001 From: Xutai Ma Date: Thu, 18 Feb 2021 22:41:32 -0800 Subject: [PATCH 23/82] Simultaneous Speech Translation Model (#1607) Summary: This is the pull request for the code for the paper [SimulMT to SimulST: Adapting Simultaneous Text Translation to End-to-End Simultaneous Speech Translation](https://www.aclweb.org/anthology/2020.aacl-main.58/) The model will also be used for [IWSLT 2021 shared task on simultaneous translation ](https://iwslt.org/2021/simultaneous) This pull request includes - Convtransformer offline model - Convtransformer simultaneous translation model with fixed pre-decision module - The agent files for inference for the convtransformer simultaneous translation model jmp84 The README is still missing. Just curious where should I place it? Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1607 Test Plan: Imported from GitHub, without a `Test Plan:` line. ********** One of the failing landing integration tests ``` buck test mode/dev //multimo/fb/models/test:multimo_fb_model_test https://fburl.com/testinfra/oxq2cn5n ``` Reviewed By: jmp84 Differential Revision: D26439663 Pulled By: sravyapopuri388 fbshipit-source-id: b127cb4962756af221b65e3ccb6598a42fc75f7f --- .../models/transformer_monotonic_attention.py | 28 +- .../modules/fixed_pre_decision.py | 170 +++++++ .../modules/monotonic_multihead_attention.py | 343 ++++++------- .../modules/monotonic_transformer_layer.py | 8 + .../utils/data_utils.py | 100 ++++ examples/speech_to_text/README.md | 1 + .../docs/simulst_mustc_example.md | 52 ++ .../agents/fairseq_simul_st_agent.py | 331 +++++++++++++ .../agents/simul_trans_agent.py | 200 ++++++++ fairseq/models/speech_to_text/__init__.py | 2 + .../models/speech_to_text/convtransformer.py | 452 ++++++++++++++++++ .../convtransformer_simul_trans.py | 49 ++ 12 files changed, 1560 insertions(+), 176 deletions(-) create mode 100644 examples/simultaneous_translation/modules/fixed_pre_decision.py create mode 100644 examples/simultaneous_translation/utils/data_utils.py create mode 100644 examples/speech_to_text/docs/simulst_mustc_example.md create mode 100644 examples/speech_to_text/simultaneous_translation/agents/fairseq_simul_st_agent.py create mode 100644 examples/speech_to_text/simultaneous_translation/agents/simul_trans_agent.py create mode 100644 fairseq/models/speech_to_text/convtransformer.py create mode 100644 fairseq/models/speech_to_text/convtransformer_simul_trans.py diff --git a/examples/simultaneous_translation/models/transformer_monotonic_attention.py b/examples/simultaneous_translation/models/transformer_monotonic_attention.py index ab8adf3aab..dd3895f0eb 100644 --- a/examples/simultaneous_translation/models/transformer_monotonic_attention.py +++ b/examples/simultaneous_translation/models/transformer_monotonic_attention.py @@ -10,17 +10,20 @@ TransformerMonotonicDecoderLayer, TransformerMonotonicEncoderLayer, ) -from fairseq.models import register_model, register_model_architecture +from fairseq.models import ( + register_model, + register_model_architecture, +) from fairseq.models.transformer import ( - TransformerDecoder, - TransformerEncoder, TransformerModel, + TransformerEncoder, + TransformerDecoder, base_architecture, transformer_iwslt_de_en, transformer_vaswani_wmt_en_de_big, + transformer_vaswani_wmt_en_fr_big, ) - DEFAULT_MAX_SOURCE_POSITIONS = 1024 DEFAULT_MAX_TARGET_POSITIONS = 1024 @@ -33,7 +36,7 @@ def build_encoder(cls, args, src_dict, embed_tokens): @register_model("transformer_monotonic") -class TransformerMonotonicModel(TransformerModel): +class TransformerModelSimulTrans(TransformerModel): @classmethod def build_encoder(cls, args, src_dict, embed_tokens): return TransformerMonotonicEncoder(args, src_dict, embed_tokens) @@ -178,13 +181,18 @@ def pre_attention( if positions is not None: x += positions + x = self.dropout_module(x) # B x T x C -> T x B x C x = x.transpose(0, 1) - encoder_out = encoder_out_dict.encoder_out - encoder_padding_mask = encoder_out_dict.encoder_padding_mask + encoder_out = encoder_out_dict["encoder_out"][0] + encoder_padding_mask = ( + encoder_out_dict["encoder_padding_mask"][0] + if len(encoder_out_dict["encoder_padding_mask"]) > 0 + else None + ) return x, encoder_out, encoder_padding_mask @@ -236,7 +244,7 @@ def extract_features( attn_list.append(attn) if incremental_state is not None: - curr_steps = layer.get_steps(incremental_state) + curr_steps = layer.get_head_steps(incremental_state) step_list.append(curr_steps) if incremental_state.get("online", False): @@ -287,7 +295,7 @@ def reorder_incremental_state(self, incremental_state, new_order): @register_model_architecture("transformer_monotonic", "transformer_monotonic") -def base_monotonic_rchitecture(args): +def base_monotonic_architecture(args): base_architecture(args) args.encoder_unidirectional = getattr(args, "encoder_unidirectional", False) @@ -297,7 +305,7 @@ def base_monotonic_rchitecture(args): ) def transformer_monotonic_iwslt_de_en(args): transformer_iwslt_de_en(args) - base_monotonic_rchitecture(args) + base_monotonic_architecture(args) # parameters used in the "Attention Is All You Need" paper (Vaswani et al., 2017) diff --git a/examples/simultaneous_translation/modules/fixed_pre_decision.py b/examples/simultaneous_translation/modules/fixed_pre_decision.py new file mode 100644 index 0000000000..2cde55b35e --- /dev/null +++ b/examples/simultaneous_translation/modules/fixed_pre_decision.py @@ -0,0 +1,170 @@ +from functools import partial + +import torch +import torch.nn.functional as F + +from . import register_monotonic_attention +from .monotonic_multihead_attention import ( + MonotonicMultiheadAttentionWaitK, + MonotonicMultiheadAttentionHardAligned, + MonotonicMultiheadAttentionInfiniteLookback, +) + + +def fixed_pooling_monotonic_attention(monotonic_attention): + def create_model(monotonic_attention, klass): + class FixedStrideMonotonicAttention(monotonic_attention): + def __init__(self, args): + super().__init__(args) + self.pre_decision_type = args.fixed_pre_decision_type + self.pre_decision_ratio = args.fixed_pre_decision_ratio + self.pre_decision_pad_threshold = args.fixed_pre_decision_pad_threshold + if self.pre_decision_ratio == 1: + return + + if args.fixed_pre_decision_type == "average": + self.pooling_layer = torch.nn.AvgPool1d( + kernel_size=self.pre_decision_ratio, + stride=self.pre_decision_ratio, + ceil_mode=True, + ) + elif args.fixed_pre_decision_type == "last": + + def last(key): + if key.size(2) < self.pre_decision_ratio: + return key + else: + k = key[ + :, + :, + self.pre_decision_ratio - 1 :: self.pre_decision_ratio, + ].contiguous() + if key.size(-1) % self.pre_decision_ratio != 0: + k = torch.cat([k, key[:, :, -1:]], dim=-1).contiguous() + return k + + self.pooling_layer = last + else: + raise NotImplementedError + + @staticmethod + def add_args(parser): + super( + FixedStrideMonotonicAttention, FixedStrideMonotonicAttention + ).add_args(parser) + parser.add_argument( + "--fixed-pre-decision-ratio", + type=int, + required=True, + help=( + "Ratio for the fixed pre-decision," + "indicating how many encoder steps will start" + "simultaneous decision making process." + ), + ) + parser.add_argument( + "--fixed-pre-decision-type", + default="average", + choices=["average", "last"], + help="Pooling type", + ) + parser.add_argument( + "--fixed-pre-decision-pad-threshold", + type=float, + default=0.3, + help="If a part of the sequence has pad" + ",the threshold the pooled part is a pad.", + ) + + def insert_zeros(self, x): + bsz_num_heads, tgt_len, src_len = x.size() + stride = self.pre_decision_ratio + weight = F.pad(x.new_ones(1, 1, 1), (stride - 1, 0)) + x_upsample = F.conv_transpose1d( + x.view(-1, src_len).unsqueeze(1), + weight, + stride=stride, + padding=0, + ) + return x_upsample.squeeze(1).view(bsz_num_heads, tgt_len, -1) + + def p_choose( + self, + query, + key, + key_padding_mask=None, + incremental_state=None, + **extra_args + ): + + if self.pre_decision_ratio == 1: + return super().p_choose( + self, + query, + key, + key_padding_mask=None, + incremental_state=None, + **extra_args + ) + + key_pool = self.pooling_layer(key.transpose(0, 2)).transpose(0, 2) + + if key_padding_mask is not None: + key_padding_mask_pool = ( + self.pooling_layer(key_padding_mask.unsqueeze(0).float()) + .squeeze(0) + .gt(self.pre_decision_pad_threshold) + ) + # Make sure at least one element is not pad + key_padding_mask_pool[:, 0] = 0 + else: + key_padding_mask_pool = None + + p_choose_pooled = super().p_choose( + query, + key_pool, + key_padding_mask_pool, + incremental_state=incremental_state, + ) + + # Upsample, interpolate zeros + p_choose = self.insert_zeros(p_choose_pooled) + + # can be larger than src_len because we used ceil before + src_len = key.size(0) + p_choose = p_choose[:, :, :src_len] + p_choose[:, :, -1] = p_choose_pooled[:, :, -1] + + tgt_len = query.size(0) + batch_size = query.size(1) + + assert list(p_choose.size()) == [ + batch_size * self.num_heads, + tgt_len, + src_len, + ] + + return p_choose + + FixedStrideMonotonicAttention.__name__ = klass.__name__ + return FixedStrideMonotonicAttention + + return partial(create_model, monotonic_attention) + + +@register_monotonic_attention("waitk_fixed_pre_decision") +@fixed_pooling_monotonic_attention(MonotonicMultiheadAttentionWaitK) +class MonotonicMultiheadAttentionWaitkFixedStride: + pass + + +@register_monotonic_attention("hard_aligned_fixed_pre_decision") +@fixed_pooling_monotonic_attention(MonotonicMultiheadAttentionHardAligned) +class MonotonicMultiheadAttentionHardFixedStride: + pass + + +@register_monotonic_attention("infinite_lookback_fixed_pre_decision") +@fixed_pooling_monotonic_attention(MonotonicMultiheadAttentionInfiniteLookback) +class MonotonicMultiheadAttentionInfiniteLookbackFixedStride: + pass diff --git a/examples/simultaneous_translation/modules/monotonic_multihead_attention.py b/examples/simultaneous_translation/modules/monotonic_multihead_attention.py index c09725ac9a..5423f26c34 100644 --- a/examples/simultaneous_translation/modules/monotonic_multihead_attention.py +++ b/examples/simultaneous_translation/modules/monotonic_multihead_attention.py @@ -7,6 +7,7 @@ import torch import torch.nn as nn + import torch.nn.functional as F from examples.simultaneous_translation.utils.functions import ( exclusive_cumprod, @@ -30,6 +31,7 @@ def __init__(self, args): self.eps = args.attention_eps self.mass_preservation = args.mass_preservation + self.noise_type = args.noise_type self.noise_mean = args.noise_mean self.noise_var = args.noise_var @@ -43,23 +45,26 @@ def __init__(self, args): @staticmethod def add_args(parser): # fmt: off - parser.add_argument('--no-mass-preservation', action="store_false", dest="mass_preservation", + parser.add_argument('--no-mass-preservation', action="store_false", + dest="mass_preservation", help='Do not stay on the last token when decoding') - parser.add_argument('--mass-preservation', action="store_true", dest="mass_preservation", + parser.add_argument('--mass-preservation', action="store_true", + dest="mass_preservation", help='Stay on the last token when decoding') parser.set_defaults(mass_preservation=True) - parser.add_argument('--noise-var', type=float, default=1.0, help='Variance of discretness noise') parser.add_argument('--noise-mean', type=float, default=0.0, help='Mean of discretness noise') - parser.add_argument('--energy-bias', action="store_true", default=False, + parser.add_argument('--noise-type', type=str, default="flat", + help='Type of discretness noise') + parser.add_argument('--energy-bias', action="store_true", + default=False, help='Bias for energy') parser.add_argument('--energy-bias-init', type=float, default=-2.0, help='Initial value of the bias for energy') parser.add_argument('--attention-eps', type=float, default=1e-6, help='Epsilon when calculating expected attention') - # fmt: on def p_choose(self, *args): raise NotImplementedError @@ -67,7 +72,9 @@ def p_choose(self, *args): def input_projections(self, *args): raise NotImplementedError - def attn_energy(self, q_proj, k_proj, key_padding_mask=None): + def attn_energy( + self, q_proj, k_proj, key_padding_mask=None, attn_mask=None + ): """ Calculating monotonic energies @@ -82,7 +89,13 @@ def attn_energy(self, q_proj, k_proj, key_padding_mask=None): bsz = bsz // self.num_heads src_len = k_proj.size(1) - attn_energy = torch.bmm(q_proj, k_proj.transpose(1, 2)) + self.energy_bias + attn_energy = ( + torch.bmm(q_proj, k_proj.transpose(1, 2)) + self.energy_bias + ) + + if attn_mask is not None: + attn_mask = attn_mask.unsqueeze(0) + attn_energy += attn_mask attn_energy = attn_energy.view(bsz, self.num_heads, tgt_len, src_len) @@ -102,7 +115,7 @@ def expected_alignment_train(self, p_choose, key_padding_mask): q_ij = (1 − p_{ij−1})q_{ij−1} + a+{i−1j} a_ij = p_ij q_ij - parellel solution: + Parallel solution: ai = p_i * cumprod(1 − pi) * cumsum(a_i / cumprod(1 − pi)) ============================================================ @@ -139,21 +152,40 @@ def expected_alignment_train(self, p_choose, key_padding_mask): if self.mass_preservation: # Last token has the residual probabilities - alpha[:, :, -1] = 1 - alpha[:, :, :-1].sum(dim=-1).clamp(0.0, 1.0) - - assert not torch.isnan(alpha).any(), "NaN detected in alpha." + if key_padding_mask is not None and key_padding_mask[:, -1].any(): + # right padding + batch_size = key_padding_mask.size(0) + residuals = 1 - alpha.sum(dim=-1, keepdim=True).clamp(0.0, 1.0) + src_lens = src_len - key_padding_mask.sum(dim=1, keepdim=True) + src_lens = src_lens.expand( + batch_size, self.num_heads + ).contiguous().view(-1, 1) + src_lens = src_lens.expand(-1, tgt_len).contiguous() + # add back the last value + residuals += alpha.gather(2, src_lens.unsqueeze(-1) - 1) + alpha = alpha.scatter(2, src_lens.unsqueeze(-1) - 1, residuals) + else: + residuals = 1 - alpha[:, :, :-1].sum(dim=-1).clamp(0.0, 1.0) + alpha[:, :, -1] = residuals + + if torch.isnan(alpha).any(): + # Something is wrong + raise RuntimeError("NaN in alpha.") return alpha - def expected_alignment_infer(self, p_choose, key_padding_mask, incremental_state): + def expected_alignment_infer( + self, p_choose, encoder_padding_mask, incremental_state + ): + # TODO modify this function """ Calculating mo alignment for MMA during inference time ============================================================ Expected input size p_choose: bsz * num_heads, tgt_len, src_len - key_padding_mask: bsz * src_len incremental_state: dict + encodencoder_padding_mask: bsz * src_len """ # p_choose: bsz * self.num_heads, src_len bsz_num_heads, tgt_len, src_len = p_choose.size() @@ -166,7 +198,8 @@ def expected_alignment_infer(self, p_choose, key_padding_mask, incremental_state # prev_monotonic_step: bsz, num_heads bsz = bsz_num_heads // self.num_heads prev_monotonic_step = monotonic_cache.get( - "step", p_choose.new_zeros([bsz, self.num_heads]).long() + "head_step", + p_choose.new_zeros([bsz, self.num_heads]).long() ) bsz, num_heads = prev_monotonic_step.size() assert num_heads == self.num_heads @@ -175,8 +208,9 @@ def expected_alignment_infer(self, p_choose, key_padding_mask, incremental_state # p_choose: bsz, num_heads, src_len p_choose = p_choose.view(bsz, num_heads, src_len) - if key_padding_mask is not None: - src_lengths = src_len - key_padding_mask.sum(dim=1, keepdim=True).long() + if encoder_padding_mask is not None: + src_lengths = src_len - \ + encoder_padding_mask.sum(dim=1, keepdim=True).long() else: src_lengths = prev_monotonic_step.new_ones(bsz, 1) * src_len @@ -186,16 +220,16 @@ def expected_alignment_infer(self, p_choose, key_padding_mask, incremental_state new_monotonic_step = prev_monotonic_step step_offset = 0 - if key_padding_mask is not None: - if key_padding_mask[:, 0].any(): + if encoder_padding_mask is not None: + if encoder_padding_mask[:, 0].any(): # left_pad_source = True: - step_offset = key_padding_mask.sum(dim=-1, keepdim=True) + step_offset = encoder_padding_mask.sum(dim=-1, keepdim=True) max_steps = src_lengths - 1 if self.mass_preservation else src_lengths # finish_read: bsz, num_heads finish_read = new_monotonic_step.eq(max_steps) - + p_choose_i = 1 while finish_read.sum().item() < bsz * self.num_heads: # p_choose: bsz * self.num_heads, src_len # only choose the p at monotonic steps @@ -224,23 +258,34 @@ def expected_alignment_infer(self, p_choose, key_padding_mask, incremental_state new_monotonic_step += action finish_read = new_monotonic_step.eq(max_steps) | (action == 0) - # finish_read = (~ (finish_read.sum(dim=1, keepdim=True) < self.num_heads / 2)) | finish_read - monotonic_cache["step"] = new_monotonic_step + if p_choose_i is None: + import pdb;pdb.set_trace() + + monotonic_cache["head_step"] = new_monotonic_step + # Whether a head is looking for new input + monotonic_cache["head_read"] = ( + new_monotonic_step.eq(max_steps) & (p_choose_i < 0.5) + ) # alpha: bsz * num_heads, 1, src_len # new_monotonic_step: bsz, num_heads - alpha = p_choose.new_zeros([bsz * self.num_heads, src_len]).scatter( - 1, - (step_offset + new_monotonic_step) - .view(bsz * self.num_heads, 1) - .clamp(0, src_len - 1), - 1, + alpha = ( + p_choose + .new_zeros([bsz * self.num_heads, src_len]) + .scatter( + 1, + (step_offset + new_monotonic_step) + .view(bsz * self.num_heads, 1).clamp(0, src_len - 1), + 1 + ) ) if not self.mass_preservation: alpha = alpha.masked_fill( - (new_monotonic_step == max_steps).view(bsz * self.num_heads, 1), 0 + (new_monotonic_step == max_steps) + .view(bsz * self.num_heads, 1), + 0 ) alpha = alpha.unsqueeze(1) @@ -249,18 +294,28 @@ def expected_alignment_infer(self, p_choose, key_padding_mask, incremental_state return alpha + def _get_monotonic_buffer(self, incremental_state): + return utils.get_incremental_state( + self, + incremental_state, + 'monotonic', + ) or {} + + def _set_monotonic_buffer(self, incremental_state, buffer): + utils.set_incremental_state( + self, + incremental_state, + 'monotonic', + buffer, + ) + def v_proj_output(self, value): raise NotImplementedError def forward( - self, - query, - key, - value, - key_padding_mask=None, - incremental_state=None, - *args, - **kwargs, + self, query, key, value, + key_padding_mask=None, attn_mask=None, incremental_state=None, + need_weights=True, static_kv=False, *args, **kwargs ): tgt_len, bsz, embed_dim = query.size() @@ -268,26 +323,31 @@ def forward( # stepwise prob # p_choose: bsz * self.num_heads, tgt_len, src_len - p_choose = self.p_choose(query, key, key_padding_mask) + p_choose = self.p_choose( + query, key, key_padding_mask, incremental_state, + ) # expected alignment alpha # bsz * self.num_heads, tgt_len, src_len if incremental_state is not None: alpha = self.expected_alignment_infer( - p_choose, key_padding_mask, incremental_state - ) + p_choose, key_padding_mask, incremental_state) else: - alpha = self.expected_alignment_train(p_choose, key_padding_mask) + alpha = self.expected_alignment_train( + p_choose, key_padding_mask) # expected attention beta # bsz * self.num_heads, tgt_len, src_len beta = self.expected_attention( - alpha, query, key, value, key_padding_mask, incremental_state + alpha, query, key, value, + key_padding_mask, attn_mask, + incremental_state ) attn_weights = beta v_proj = self.v_proj_output(value) + attn = torch.bmm(attn_weights.type_as(v_proj), v_proj) attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) @@ -298,67 +358,17 @@ def forward( alpha = alpha.view(bsz, self.num_heads, tgt_len, src_len) p_choose = p_choose.view(bsz, self.num_heads, tgt_len, src_len) - return attn, {"alpha": alpha, "beta": beta, "p_choose": p_choose} - - def reorder_incremental_state(self, incremental_state, new_order): - """Reorder buffered internal state (for incremental generation).""" - super().reorder_incremental_state(incremental_state, new_order) - input_buffer = self._get_monotonic_buffer(incremental_state) - if input_buffer is not None: - for k in input_buffer.keys(): - input_buffer[k] = input_buffer[k].index_select(0, new_order) - self._set_monotonic_buffer(incremental_state, input_buffer) - - def _get_monotonic_buffer(self, incremental_state): - return ( - utils.get_incremental_state( - self, - incremental_state, - "monotonic", - ) - or {} - ) - - def _set_monotonic_buffer(self, incremental_state, buffer): - utils.set_incremental_state( - self, - incremental_state, - "monotonic", - buffer, - ) - - def get_pointer(self, incremental_state): - return ( - utils.get_incremental_state( - self, - incremental_state, - "monotonic", - ) - or {} - ) - - def get_fastest_pointer(self, incremental_state): - return self.get_pointer(incremental_state)["step"].max(0)[0] - - def set_pointer(self, incremental_state, p_choose): - curr_pointer = self.get_pointer(incremental_state) - if len(curr_pointer) == 0: - buffer = torch.zeros_like(p_choose) - else: - buffer = self.get_pointer(incremental_state)["step"] - - buffer += (p_choose < 0.5).type_as(buffer) - - utils.set_incremental_state( - self, - incremental_state, - "monotonic", - {"step": buffer}, - ) + return attn, { + "alpha": alpha, + "beta": beta, + "p_choose": p_choose, + } @register_monotonic_attention("hard_aligned") -class MonotonicMultiheadAttentionHard(MonotonicAttention, MultiheadAttention): +class MonotonicMultiheadAttentionHardAligned( + MonotonicAttention, MultiheadAttention +): def __init__(self, args): MultiheadAttention.__init__( self, @@ -392,39 +402,36 @@ def input_projections(self, query, key, value, name): bsz = query.size(1) q = self.q_in_proj[name](query) q *= self.scaling - q = ( - q.contiguous() - .view(-1, bsz * self.num_heads, self.head_dim) - .transpose(0, 1) - ) + q = q.contiguous().view( + -1, bsz * self.num_heads, self.head_dim + ).transpose(0, 1) else: q = None if key is not None: bsz = key.size(1) k = self.k_in_proj[name](key) - k = ( - k.contiguous() - .view(-1, bsz * self.num_heads, self.head_dim) - .transpose(0, 1) - ) + k = k.contiguous().view( + -1, bsz * self.num_heads, self.head_dim + ).transpose(0, 1) else: k = None if value is not None: bsz = value.size(1) v = self.v_in_proj[name](value) - v = ( - v.contiguous() - .view(-1, bsz * self.num_heads, self.head_dim) - .transpose(0, 1) - ) + v = v.contiguous().view( + -1, bsz * self.num_heads, self.head_dim + ).transpose(0, 1) else: v = None return q, k, v - def p_choose(self, query, key, key_padding_mask=None): + def p_choose( + self, query, key, key_padding_mask=None, + incremental_state=None, *extra_args + ): """ Calculating step wise prob for reading and writing 1 to read, 0 to write @@ -440,7 +447,9 @@ def p_choose(self, query, key, key_padding_mask=None): """ # prepare inputs - q_proj, k_proj, _ = self.input_projections(query, key, None, "monotonic") + q_proj, k_proj, _ = self.input_projections( + query, key, None, "monotonic" + ) # attention energy attn_energy = self.attn_energy(q_proj, k_proj, key_padding_mask) @@ -473,7 +482,9 @@ def v_proj_output(self, value): @register_monotonic_attention("infinite_lookback") -class MonotonicMultiheadAttentionInfiniteLookback(MonotonicMultiheadAttentionHard): +class MonotonicMultiheadAttentionInfiniteLookback( + MonotonicMultiheadAttentionHardAligned +): def __init__(self, args): super().__init__(args) self.init_soft_attention() @@ -498,30 +509,33 @@ def init_soft_attention(self): nn.init.xavier_uniform_(self.q_in_proj["soft"].weight) def expected_attention( - self, alpha, query, key, value, key_padding_mask, incremental_state + self, alpha, query, key, value, + key_padding_mask, attn_mask, incremental_state ): # monotonic attention, we will calculate milk here bsz_x_num_heads, tgt_len, src_len = alpha.size() bsz = int(bsz_x_num_heads / self.num_heads) q, k, _ = self.input_projections(query, key, None, "soft") - soft_energy = self.attn_energy(q, k, key_padding_mask) + soft_energy = self.attn_energy(q, k, key_padding_mask, attn_mask) - assert list(soft_energy.size()) == [bsz, self.num_heads, tgt_len, src_len] + assert list(soft_energy.size()) == \ + [bsz, self.num_heads, tgt_len, src_len] soft_energy = soft_energy.view(bsz * self.num_heads, tgt_len, src_len) if incremental_state is not None: monotonic_cache = self._get_monotonic_buffer(incremental_state) - monotonic_step = monotonic_cache["step"] + 1 + monotonic_length = monotonic_cache["head_step"] + 1 step_offset = 0 if key_padding_mask is not None: if key_padding_mask[:, 0].any(): # left_pad_source = True: step_offset = key_padding_mask.sum(dim=-1, keepdim=True) - monotonic_step += step_offset + monotonic_length += step_offset mask = lengths_to_mask( - monotonic_step.view(-1), soft_energy.size(2), 1 + monotonic_length.view(-1), + soft_energy.size(2), 1 ).unsqueeze(1) soft_energy = soft_energy.masked_fill(~mask.bool(), float("-inf")) @@ -531,84 +545,81 @@ def expected_attention( beta = exp_soft_energy / exp_soft_energy_sum.unsqueeze(2) else: - # bsz * num_heads, tgt_len, src_len soft_energy = soft_energy - soft_energy.max(dim=2, keepdim=True)[0] - exp_soft_energy = torch.exp(soft_energy) - exp_soft_energy_cumsum = torch.cumsum(exp_soft_energy, dim=2) + exp_soft_energy = torch.exp(soft_energy) + self.eps + inner_items = alpha / (torch.cumsum(exp_soft_energy, dim=2)) + + beta = ( + exp_soft_energy + * torch.cumsum(inner_items.flip(dims=[2]), dim=2) + .flip(dims=[2]) + ) + + beta = beta.view(bsz, self.num_heads, tgt_len, src_len) if key_padding_mask is not None: - if key_padding_mask.any(): - exp_soft_energy_cumsum = ( - exp_soft_energy_cumsum.view( - -1, self.num_heads, tgt_len, src_len - ) - .masked_fill( - key_padding_mask.unsqueeze(1).unsqueeze(1), self.eps - ) - .view(-1, tgt_len, src_len) - ) - - inner_items = alpha / exp_soft_energy_cumsum - - beta = exp_soft_energy * torch.cumsum( - inner_items.flip(dims=[2]), dim=2 - ).flip(dims=[2]) + beta = beta.masked_fill( + key_padding_mask.unsqueeze(1).unsqueeze(2).bool(), 0) + beta = beta / beta.sum(dim=3, keepdim=True) + beta = beta.view(bsz * self.num_heads, tgt_len, src_len) beta = self.dropout_module(beta) - assert not torch.isnan(beta).any(), "NaN detected in beta." + if torch.isnan(beta).any(): + # Something is wrong + raise RuntimeError("NaN in beta.") return beta @register_monotonic_attention("waitk") -class MonotonicMultiheadAttentionWaitk(MonotonicMultiheadAttentionInfiniteLookback): +class MonotonicMultiheadAttentionWaitK( + MonotonicMultiheadAttentionInfiniteLookback +): def __init__(self, args): super().__init__(args) self.q_in_proj["soft"] = self.q_in_proj["monotonic"] self.k_in_proj["soft"] = self.k_in_proj["monotonic"] self.waitk_lagging = args.waitk_lagging - assert ( - self.waitk_lagging > 0 - ), f"Lagging has to been larger than 0, get {self.waitk_lagging}." + assert self.waitk_lagging > 0, ( + f"Lagging has to been larger than 0, get {self.waitk_lagging}." + ) @staticmethod def add_args(parser): super( - MonotonicMultiheadAttentionWaitk, - MonotonicMultiheadAttentionWaitk, + MonotonicMultiheadAttentionWaitK, + MonotonicMultiheadAttentionWaitK, ).add_args(parser) parser.add_argument( - "--waitk-lagging", type=int, required=True, help="Wait k lagging" + "--waitk-lagging", type=int, required=True, help="Wait K lagging" ) def p_choose( - self, query, key, key_padding_mask=None, attn_mask=None, incremental_state=None + self, query, key, key_padding_mask=None, + incremental_state=None, *extra_args ): """ query: bsz, tgt_len key: bsz, src_len key_padding_mask: bsz, src_len """ - src_len, bsz, _ = key.size() - tgt_len, bsz, _ = query.size() + if incremental_state is not None: + tgt_len = int(incremental_state["steps"]["tgt"]) + src_len = int(incremental_state["steps"]["src"]) + bsz = 1 + else: + src_len, bsz, _ = key.size() + tgt_len, bsz, _ = query.size() + p_choose = query.new_ones(bsz, tgt_len, src_len) p_choose = torch.tril(p_choose, diagonal=self.waitk_lagging - 1) p_choose = torch.triu(p_choose, diagonal=self.waitk_lagging - 1) - if key_padding_mask is not None and key_padding_mask[:, 0].eq(1).any(): - # Left pad source - # add -1 to the end - p_choose = p_choose.masked_fill( - key_padding_mask.float().flip(1).unsqueeze(1).bool(), -1 - ) - p_choose = convert_padding_direction( - p_choose.view(-1, src_len).long(), padding_idx=-1, right_to_left=True - ) - p_choose = p_choose.view(bsz, tgt_len, src_len).type_as(query) - # remove -1 - p_choose[p_choose.eq(-1)] = 0 + if incremental_state is not None: + p_choose = p_choose[:, -1:] + tgt_len = 1 # Extend to each head p_choose = ( diff --git a/examples/simultaneous_translation/modules/monotonic_transformer_layer.py b/examples/simultaneous_translation/modules/monotonic_transformer_layer.py index 442b7d487d..e6e1850a18 100644 --- a/examples/simultaneous_translation/modules/monotonic_transformer_layer.py +++ b/examples/simultaneous_translation/modules/monotonic_transformer_layer.py @@ -26,11 +26,19 @@ def __init__( add_bias_kv=add_bias_kv, add_zero_attn=add_zero_attn, ) + + assert args.simul_type is not None, "A --simul-type is needed." + self.encoder_attn = build_monotonic_attention(args) self.encoder_attn_layer_norm = LayerNorm( self.embed_dim, export=getattr(args, "char_inputs", False) ) + def get_head_steps(self, incremental_state): + return self.encoder_attn._get_monotonic_buffer(incremental_state).get( + "head_step" + ) + def prune_incremental_state(self, incremental_state): def prune(module): input_buffer = module._get_input_buffer(incremental_state) diff --git a/examples/simultaneous_translation/utils/data_utils.py b/examples/simultaneous_translation/utils/data_utils.py new file mode 100644 index 0000000000..cc4729e63c --- /dev/null +++ b/examples/simultaneous_translation/utils/data_utils.py @@ -0,0 +1,100 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch + + +def calc_mean_invstddev(feature): + if len(feature.size()) != 2: + raise ValueError("We expect the input feature to be 2-D tensor") + mean = feature.mean(0) + var = feature.var(0) + # avoid division by ~zero + eps = 1e-8 + if (var < eps).any(): + return mean, 1.0 / (torch.sqrt(var) + eps) + return mean, 1.0 / torch.sqrt(var) + + +def apply_mv_norm(features): + # If there is less than 2 spectrograms, the variance cannot be computed (is NaN) + # and normalization is not possible, so return the item as it is + if features.size(0) < 2: + return features + mean, invstddev = calc_mean_invstddev(features) + res = (features - mean) * invstddev + return res + + +def lengths_to_encoder_padding_mask(lengths, batch_first=False): + """ + convert lengths (a 1-D Long/Int tensor) to 2-D binary tensor + + Args: + lengths: a (B, )-shaped tensor + + Return: + max_length: maximum length of B sequences + encoder_padding_mask: a (max_length, B) binary mask, where + [t, b] = 0 for t < lengths[b] and 1 otherwise + + TODO: + kernelize this function if benchmarking shows this function is slow + """ + max_lengths = torch.max(lengths).item() + bsz = lengths.size(0) + encoder_padding_mask = torch.arange( + max_lengths + ).to( # a (T, ) tensor with [0, ..., T-1] + lengths.device + ).view( # move to the right device + 1, max_lengths + ).expand( # reshape to (1, T)-shaped tensor + bsz, -1 + ) >= lengths.view( # expand to (B, T)-shaped tensor + bsz, 1 + ).expand( + -1, max_lengths + ) + if not batch_first: + return encoder_padding_mask.t(), max_lengths + else: + return encoder_padding_mask, max_lengths + + +def encoder_padding_mask_to_lengths( + encoder_padding_mask, max_lengths, batch_size, device +): + """ + convert encoder_padding_mask (2-D binary tensor) to a 1-D tensor + + Conventionally, encoder output contains a encoder_padding_mask, which is + a 2-D mask in a shape (T, B), whose (t, b) element indicate whether + encoder_out[t, b] is a valid output (=0) or not (=1). Occasionally, we + need to convert this mask tensor to a 1-D tensor in shape (B, ), where + [b] denotes the valid length of b-th sequence + + Args: + encoder_padding_mask: a (T, B)-shaped binary tensor or None; if None, + indicating all are valid + Return: + seq_lengths: a (B,)-shaped tensor, where its (b, )-th element is the + number of valid elements of b-th sequence + + max_lengths: maximum length of all sequence, if encoder_padding_mask is + not None, max_lengths must equal to encoder_padding_mask.size(0) + + batch_size: batch size; if encoder_padding_mask is + not None, max_lengths must equal to encoder_padding_mask.size(1) + + device: which device to put the result on + """ + if encoder_padding_mask is None: + return torch.Tensor([max_lengths] * batch_size).to(torch.int32).to(device) + + assert encoder_padding_mask.size(0) == max_lengths, "max_lengths does not match" + assert encoder_padding_mask.size(1) == batch_size, "batch_size does not match" + + return max_lengths - torch.sum(encoder_padding_mask, dim=0) diff --git a/examples/speech_to_text/README.md b/examples/speech_to_text/README.md index 988ed83d77..f639d300d3 100644 --- a/examples/speech_to_text/README.md +++ b/examples/speech_to_text/README.md @@ -37,6 +37,7 @@ audio paths (one per line) as inputs. - [Speech-to-Text Translation (ST) on CoVoST 2](docs/covost_example.md) - [Speech-to-Text Translation (ST) on Multilingual TEDx](docs/mtedx_example.md) +- [Simultaneous Speech-to-Text Translation (SimulST) on MuST-C](docs/simulst_mustc_example.md) ## Updates - 02/04/2021: Added interactive decoding (`fairseq-interactive`) support. Examples: diff --git a/examples/speech_to_text/docs/simulst_mustc_example.md b/examples/speech_to_text/docs/simulst_mustc_example.md new file mode 100644 index 0000000000..5dea0d8475 --- /dev/null +++ b/examples/speech_to_text/docs/simulst_mustc_example.md @@ -0,0 +1,52 @@ +# Simultaneous Speech Translation (SimulST) on MuST-C + +[MuST-C](https://www.aclweb.org/anthology/N19-1202) is multilingual speech-to-text translation corpus with 8-language translations on English TED talks. + +## Data Preparation & ASR +Please follow the steps in offline [speech-to-text](../mustc_example.md) translation for data preparation and ASR pretraining. + +## Training + +#### Wait-K(K=3) with fixed pre-decision module +``` + fairseq-train ${MUSTC_ROOT}/en-de \ + --config-yaml config_st.yaml --train-subset train_st --valid-subset dev_st \ + --save-dir ${ST_SAVE_DIR} --num-workers 8 \ + --optimizer adam --lr 0.0001 --lr-scheduler inverse_sqrt --clip-norm 10.0 \ + --criterion label_smoothed_cross_entropy \ + --warmup-updates 4000 --max-update 100000 --max-tokens 40000 --seed 2 \ + --load-pretrained-encoder-from ${ASR_SAVE_DIR}/${CHECKPOINT_FILENAME} \ + --task speech_to_text \ + --arch convtransformer_simul_trans_espnet \ + --simul-type waitk_fixed_pre_decision \ + --waitk-lagging 3 \ + --fixed-pre-decision-ratio 7 +``` +#### Monotonic multihead attention with fixed pre-decision module +``` + fairseq-train ${MUSTC_ROOT}/en-de \ + --config-yaml config_st.yaml --train-subset train_st --valid-subset dev_st \ + --save-dir ${ST_SAVE_DIR} --num-workers 8 \ + --optimizer adam --lr 0.0001 --lr-scheduler inverse_sqrt --clip-norm 10.0 \ + --warmup-updates 4000 --max-update 100000 --max-tokens 40000 --seed 2 \ + --load-pretrained-encoder-from ${ASR_SAVE_DIR}/${CHECKPOINT_FILENAME} \ + --task speech_to_text \ + --criterion latency_augmented_label_smoothed_cross_entropy \ + --latency-weight-avg 0.1 \ + --arch convtransformer_simul_trans_espnet \ + --simul-type infinite_lookback_fixed_pre_decision \ + --fixed-pre-decision-ratio 7 +``` +## Inference & Evaluation +[SimulEval](https://github.com/facebookresearch/SimulEval) is used for evaluation. +``` +simuleval \ + --agent ${FAIRSEQ}/examples/speech_to_text/simultaneous_translation/agents/fairseq_simul_st_agent.py + --src-file ${SRC_LIST_OF_AUDIO} + --tgt-file ${TGT_FILE} + --data-bin ${MUSTC_ROOT}/en-de \ + --model-path ${ST_SAVE_DIR}/${CHECKPOINT_FILENAME} \ + --tgt-splitter-type SentencePieceModel \ + --tgt-splitter-path ${MUSTC_ROOT}/en-de/spm.model \ + --scores +``` diff --git a/examples/speech_to_text/simultaneous_translation/agents/fairseq_simul_st_agent.py b/examples/speech_to_text/simultaneous_translation/agents/fairseq_simul_st_agent.py new file mode 100644 index 0000000000..cbe8bc4322 --- /dev/null +++ b/examples/speech_to_text/simultaneous_translation/agents/fairseq_simul_st_agent.py @@ -0,0 +1,331 @@ +import math +import os + +import numpy as np +import torch +import torchaudio.compliance.kaldi as kaldi +import yaml +from fairseq import checkpoint_utils, tasks + +try: + from simuleval import READ_ACTION, WRITE_ACTION, DEFAULT_EOS + from simuleval.agents import SpeechAgent + from simuleval.states import ListEntry +except ImportError: + print("Please install simuleval 'pip install simuleval'") + + +SHIFT_SIZE = 10 +WINDOW_SIZE = 25 +SAMPLE_RATE = 16000 +FEATURE_DIM = 80 +BOW_PREFIX = "\u2581" + + +class OnlineFeatureExtractor: + """ + Extract speech feature on the fly. + """ + + def __init__( + self, + shift_size=SHIFT_SIZE, + window_size=WINDOW_SIZE, + sample_rate=SAMPLE_RATE, + feature_dim=FEATURE_DIM, + global_cmvn=None, + ): + self.shift_size = shift_size + self.window_size = window_size + assert self.window_size >= self.shift_size + + self.sample_rate = sample_rate + self.feature_dim = feature_dim + self.num_samples_per_shift = int(SHIFT_SIZE * SAMPLE_RATE / 1000) + self.num_samples_per_window = int(WINDOW_SIZE * SAMPLE_RATE / 1000) + self.len_ms_to_samples = lambda x: x * self.sample_rate / 1000 + self.previous_residual_samples = [] + self.global_cmvn = global_cmvn + + def clear_cache(self): + self.previous_residual_samples = [] + + def __call__(self, new_samples): + samples = self.previous_residual_samples + new_samples + if len(samples) < self.num_samples_per_window: + self.previous_residual_samples = samples + return + + # num_frames is the number of frames from the new segment + num_frames = math.floor( + (len(samples) - self.len_ms_to_samples(self.window_size - self.shift_size)) + / self.num_samples_per_shift + ) + + # the number of frames used for feature extraction + # including some part of thte previous segment + effective_num_samples = int( + num_frames * self.len_ms_to_samples(self.shift_size) + + self.len_ms_to_samples(self.window_size - self.shift_size) + ) + + input_samples = samples[:effective_num_samples] + self.previous_residual_samples = samples[ + num_frames * self.num_samples_per_shift : + ] + + torch.manual_seed(1) + output = kaldi.fbank( + torch.FloatTensor(input_samples).unsqueeze(0), + num_mel_bins=self.feature_dim, + frame_length=self.window_size, + frame_shift=self.shift_size, + ).numpy() + + output = self.transform(output) + + return torch.from_numpy(output) + + def transform(self, input): + if self.global_cmvn is None: + return input + + mean = self.global_cmvn["mean"] + std = self.global_cmvn["std"] + + x = np.subtract(input, mean) + x = np.divide(x, std) + return x + + +class TensorListEntry(ListEntry): + """ + Data structure to store a list of tensor. + """ + + def append(self, value): + + if len(self.value) == 0: + self.value = value + return + + self.value = torch.cat([self.value] + [value], dim=0) + + def info(self): + return { + "type": str(self.new_value_type), + "length": self.__len__(), + "value": "" if type(self.value) is list else self.value.size(), + } + + +class FairseqSimulSTAgent(SpeechAgent): + + speech_segment_size = 40 # in ms, 4 pooling ratio * 10 ms step size + + def __init__(self, args): + super().__init__(args) + + self.eos = DEFAULT_EOS + + self.gpu = getattr(args, "gpu", False) + + self.args = args + + self.load_model_vocab(args) + + config_yaml = os.path.join(args.data_bin, "config.yaml") + with open(config_yaml, "r") as f: + config = yaml.load(f) + + if "global_cmvn" in config: + global_cmvn = np.load(config["global_cmvn"]["stats_npz_path"]) + else: + global_cmvn = None + + self.feature_extractor = OnlineFeatureExtractor(global_cmvn=global_cmvn) + + self.max_len = args.max_len + + self.force_finish = args.force_finish + + torch.set_grad_enabled(False) + + def to_device(self, tensor): + if self.gpu: + return tensor.cuda() + else: + return tensor.cpu() + + @staticmethod + def add_args(parser): + # fmt: off + parser.add_argument('--model-path', type=str, required=True, + help='path to your pretrained model.') + parser.add_argument("--data-bin", type=str, required=True, + help="Path of data binary") + parser.add_argument("--tgt-splitter-type", type=str, default="SentencePiece", + help="Subword splitter type for target text") + parser.add_argument("--tgt-splitter-path", type=str, default=None, + help="Subword splitter model path for target text") + parser.add_argument("--user-dir", type=str, default="examples/simultaneous_translation", + help="User directory for simultaneous translation") + parser.add_argument("--max-len", type=int, default=200, + help="Max length of translation") + parser.add_argument("--force-finish", default=False, action="store_true", + help="") + # fmt: on + return parser + + def load_model_vocab(self, args): + + filename = args.model_path + if not os.path.exists(filename): + raise IOError("Model file not found: {}".format(filename)) + + state = checkpoint_utils.load_checkpoint_to_cpu(filename) + + task_args = state["cfg"]["task"] + task_args.data = args.data_bin + + task = tasks.setup_task(task_args) + + # build model for ensemble + self.model = task.build_model(state["cfg"]["model"]) + self.model.load_state_dict(state["model"], strict=True) + self.model.eval() + self.model.share_memory() + + if self.gpu: + self.model.cuda() + + # Set dictionary + self.dict = {} + self.dict["tgt"] = task.target_dictionary + + def initialize_states(self, states): + self.feature_extractor.clear_cache() + states.units.source = TensorListEntry() + states.units.target = ListEntry() + states.incremental_states = dict() + + def segment_to_units(self, segment, states): + # Convert speech samples to features + features = self.feature_extractor(segment) + if features is not None: + return [features] + else: + return [] + + def units_to_segment(self, units, states): + # Merge sub word to full word. + if self.model.decoder.dictionary.eos() == units[0]: + return DEFAULT_EOS + + segment = [] + if None in units.value: + units.value.remove(None) + + for index in units: + if index is None: + units.pop() + token = self.model.decoder.dictionary.string([index]) + if token.startswith(BOW_PREFIX): + if len(segment) == 0: + segment += [token.replace(BOW_PREFIX, "")] + else: + for j in range(len(segment)): + units.pop() + + string_to_return = ["".join(segment)] + + if self.model.decoder.dictionary.eos() == units[0]: + string_to_return += [DEFAULT_EOS] + + return string_to_return + else: + segment += [token.replace(BOW_PREFIX, "")] + + if ( + len(units) > 0 + and self.model.decoder.dictionary.eos() == units[-1] + or len(states.units.target) > self.max_len + ): + tokens = [self.model.decoder.dictionary.string([unit]) for unit in units] + return ["".join(tokens).replace(BOW_PREFIX, "")] + [DEFAULT_EOS] + + return None + + def update_model_encoder(self, states): + if len(states.units.source) == 0: + return + src_indices = self.to_device(states.units.source.value.unsqueeze(0)) + src_lengths = self.to_device( + torch.LongTensor([states.units.source.value.size(0)]) + ) + print(src_lengths) + + states.encoder_states = self.model.encoder(src_indices, src_lengths) + torch.cuda.empty_cache() + + def update_states_read(self, states): + # Happens after a read action. + self.update_model_encoder(states) + + def policy(self, states): + if not getattr(states, "encoder_states", None): + return READ_ACTION + + tgt_indices = self.to_device( + torch.LongTensor( + [self.model.decoder.dictionary.eos()] + + [x for x in states.units.target.value if x is not None] + ).unsqueeze(0) + ) + + states.incremental_states["steps"] = { + "src": states.encoder_states["encoder_out"][0].size(0), + "tgt": 1 + len(states.units.target), + } + + states.incremental_states["online"] = True + + x, outputs = self.model.decoder.forward( + prev_output_tokens=tgt_indices, + encoder_out=states.encoder_states, + incremental_state=states.incremental_states, + # features_only=True, + ) + + states.decoder_out = x + + states.decoder_out_extra = outputs + + torch.cuda.empty_cache() + + if outputs["action"] == 0: + return READ_ACTION + else: + return WRITE_ACTION + + def predict(self, states): + decoder_states = states.decoder_out + + lprobs = self.model.get_normalized_probs( + [decoder_states[:, -1:]], log_probs=True + ) + + index = lprobs.argmax(dim=-1) + + torch.cuda.empty_cache() + + index = index[0, 0].item() + + if ( + self.force_finish + and index == self.model.decoder.dictionary.eos() + and not states.finish_read() + ): + index = None + + return index diff --git a/examples/speech_to_text/simultaneous_translation/agents/simul_trans_agent.py b/examples/speech_to_text/simultaneous_translation/agents/simul_trans_agent.py new file mode 100644 index 0000000000..45df5fa227 --- /dev/null +++ b/examples/speech_to_text/simultaneous_translation/agents/simul_trans_agent.py @@ -0,0 +1,200 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import json +import os + +from fairseq import checkpoint_utils, utils, tasks + +from . import DEFAULT_EOS, GET, SEND +from .agent import Agent + + +class SimulTransAgent(Agent): + def __init__(self, args): + # Load Model + self.load_model(args) + + # build word spliter + self.build_word_splitter(args) + + self.max_len = args.max_len + + self.eos = DEFAULT_EOS + + @staticmethod + def add_args(parser): + parser.add_argument( + "--model-path", + type=str, + required=True, + help="path to your pretrained model.", + ) + parser.add_argument( + "--data-bin", type=str, required=True, help="Path of data binary" + ) + parser.add_argument( + "--user-dir", + type=str, + default="example/simultaneous_translation", + help="User directory for simultaneous translation", + ) + parser.add_argument( + "--src-splitter-type", + type=str, + default=None, + help="Subword splitter type for source text", + ) + parser.add_argument( + "--tgt-splitter-type", + type=str, + default=None, + help="Subword splitter type for target text", + ) + parser.add_argument( + "--src-splitter-path", + type=str, + default=None, + help="Subword splitter model path for source text", + ) + parser.add_argument( + "--tgt-splitter-path", + type=str, + default=None, + help="Subword splitter model path for target text", + ) + parser.add_argument( + "--max-len", + type=int, + default=150, + help="Maximum length difference between source and target prediction", + ) + parser.add_argument( + "--model-overrides", + default="{}", + type=str, + metavar="DICT", + help="A dictionary used to override model args at generation " + "that were used during model training", + ) + # fmt: on + return parser + + def load_dictionary(self, task): + raise NotImplementedError + + def load_model(self, args): + args.user_dir = os.path.join(os.path.dirname(__file__), "..", "..") + utils.import_user_module(args) + filename = args.model_path + if not os.path.exists(filename): + raise IOError("Model file not found: {}".format(filename)) + + state = checkpoint_utils.load_checkpoint_to_cpu( + filename, json.loads(args.model_overrides) + ) + + saved_args = state["args"] + saved_args.data = args.data_bin + + task = tasks.setup_task(saved_args) + + # build model for ensemble + self.model = task.build_model(saved_args) + self.model.load_state_dict(state["model"], strict=True) + + # Set dictionary + self.load_dictionary(task) + + def init_states(self): + return { + "indices": {"src": [], "tgt": []}, + "tokens": {"src": [], "tgt": []}, + "segments": {"src": [], "tgt": []}, + "steps": {"src": 0, "tgt": 0}, + "finished": False, + "finish_read": False, + "model_states": {}, + } + + def update_states(self, states, new_state): + raise NotImplementedError + + def policy(self, states): + # Read and Write policy + action = None + + while action is None: + if states["finished"]: + # Finish the hypo by sending eos to server + return self.finish_action() + + # Model make decision given current states + decision = self.model.decision_from_states(states) + + if decision == 0 and not self.finish_read(states): + # READ + action = self.read_action(states) + else: + # WRITE + action = self.write_action(states) + + # None means we make decision again but not sending server anything + # This happened when read a buffered token + # Or predict a subword + return action + + def finish_read(self, states): + raise NotImplementedError + + def write_action(self, states): + token, index = self.model.predict_from_states(states) + + if ( + index == self.dict["tgt"].eos() + or len(states["tokens"]["tgt"]) > self.max_len + ): + # Finish this sentence is predict EOS + states["finished"] = True + end_idx_last_full_word = self._target_length(states) + + else: + states["tokens"]["tgt"] += [token] + end_idx_last_full_word = self.word_splitter["tgt"].end_idx_last_full_word( + states["tokens"]["tgt"] + ) + self._append_indices(states, [index], "tgt") + + if end_idx_last_full_word > states["steps"]["tgt"]: + # Only sent detokenized full words to the server + word = self.word_splitter["tgt"].merge( + states["tokens"]["tgt"][states["steps"]["tgt"] : end_idx_last_full_word] + ) + states["steps"]["tgt"] = end_idx_last_full_word + states["segments"]["tgt"] += [word] + + return {"key": SEND, "value": word} + else: + return None + + def read_action(self, states): + return {"key": GET, "value": None} + + def finish_action(self): + return {"key": SEND, "value": DEFAULT_EOS} + + def reset(self): + pass + + def finish_eval(self, states, new_state): + if len(new_state) == 0 and len(states["indices"]["src"]) == 0: + return True + return False + + def _append_indices(self, states, new_indices, key): + states["indices"][key] += new_indices + + def _target_length(self, states): + return len(states["tokens"]["tgt"]) diff --git a/fairseq/models/speech_to_text/__init__.py b/fairseq/models/speech_to_text/__init__.py index 5d7f59b3a6..28e3bb720f 100644 --- a/fairseq/models/speech_to_text/__init__.py +++ b/fairseq/models/speech_to_text/__init__.py @@ -4,4 +4,6 @@ # LICENSE file in the root directory of this source tree. from .berard import * # noqa +from .convtransformer import * # noqa +from .convtransformer_simul_trans import * # noqa from .s2t_transformer import * # noqa diff --git a/fairseq/models/speech_to_text/convtransformer.py b/fairseq/models/speech_to_text/convtransformer.py new file mode 100644 index 0000000000..512ee78be0 --- /dev/null +++ b/fairseq/models/speech_to_text/convtransformer.py @@ -0,0 +1,452 @@ +#!/usr/bin/env python3 + +import logging +import math +from typing import Dict, List, Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F +from examples.simultaneous_translation.utils.data_utils import ( + lengths_to_encoder_padding_mask, +) +from fairseq import checkpoint_utils, utils +from fairseq.models import ( + FairseqEncoder, + FairseqEncoderDecoderModel, + register_model, + register_model_architecture, +) +from fairseq.models.transformer import Embedding, TransformerDecoder +from fairseq.modules import LayerNorm, PositionalEmbedding, TransformerEncoderLayer +from torch import Tensor + +logger = logging.getLogger(__name__) + + +@register_model("convtransformer") +class ConvTransformerModel(FairseqEncoderDecoderModel): + """ + Transformer-based Speech translation model from ESPNet-ST + https://arxiv.org/abs/2004.10234 + """ + + def __init__(self, encoder, decoder): + super().__init__(encoder, decoder) + + @staticmethod + def add_args(parser): + """Add model-specific arguments to the parser.""" + parser.add_argument( + "--input-feat-per-channel", + type=int, + metavar="N", + help="encoder input dimension per input channel", + ) + parser.add_argument( + "--activation-fn", + choices=utils.get_available_activation_fns(), + help="activation function to use", + ) + parser.add_argument( + "--dropout", type=float, metavar="D", help="dropout probability" + ) + parser.add_argument( + "--attention-dropout", + type=float, + metavar="D", + help="dropout probability for attention weights", + ) + parser.add_argument( + "--activation-dropout", + "--relu-dropout", + type=float, + metavar="D", + help="dropout probability after activation in FFN.", + ) + parser.add_argument( + "--encoder-embed-dim", + type=int, + metavar="N", + help="encoder embedding dimension", + ) + parser.add_argument( + "--encoder-ffn-embed-dim", + type=int, + metavar="N", + help="encoder embedding dimension for FFN", + ) + parser.add_argument( + "--encoder-layers", type=int, metavar="N", help="num encoder layers" + ) + parser.add_argument( + "--encoder-attention-heads", + type=int, + metavar="N", + help="num encoder attention heads", + ) + parser.add_argument( + "--encoder-normalize-before", + action="store_true", + help="apply layernorm before each encoder block", + ) + parser.add_argument( + "--decoder-embed-dim", + type=int, + metavar="N", + help="decoder embedding dimension", + ) + parser.add_argument( + "--decoder-ffn-embed-dim", + type=int, + metavar="N", + help="decoder embedding dimension for FFN", + ) + parser.add_argument( + "--decoder-layers", type=int, metavar="N", help="num decoder layers" + ) + parser.add_argument( + "--decoder-attention-heads", + type=int, + metavar="N", + help="num decoder attention heads", + ) + parser.add_argument( + "--decoder-normalize-before", + action="store_true", + help="apply layernorm before each decoder block", + ) + parser.add_argument( + "--decoder-output-dim", + type=int, + metavar="N", + help="decoder output dimension (extra linear layer if different from decoder embed dim)", + ) + parser.add_argument( + "--share-decoder-input-output-embed", + action="store_true", + help="share decoder input and output embeddings", + ) + parser.add_argument( + "--layernorm-embedding", + action="store_true", + help="add layernorm to embedding", + ) + parser.add_argument( + "--no-scale-embedding", + action="store_true", + help="if True, dont scale embeddings", + ) + parser.add_argument( + "--load-pretrained-encoder-from", + type=str, + metavar="STR", + help="model to take encoder weights from (for initialization)", + ) + parser.add_argument( + "--load-pretrained-decoder-from", + type=str, + metavar="STR", + help="model to take decoder weights from (for initialization)", + ) + parser.add_argument( + "--conv-out-channels", + type=int, + metavar="INT", + help="the number of output channels of conv layer", + ) + + @classmethod + def build_encoder(cls, args): + encoder = ConvTransformerEncoder(args) + if getattr(args, "load_pretrained_encoder_from", None): + encoder = checkpoint_utils.load_pretrained_component_from_model( + component=encoder, checkpoint=args.load_pretrained_encoder_from + ) + return encoder + + @classmethod + def build_decoder(cls, args, task, embed_tokens): + decoder = TransformerDecoderNoExtra(args, task.target_dictionary, embed_tokens) + if getattr(args, "load_pretrained_decoder_from", None): + decoder = checkpoint_utils.load_pretrained_component_from_model( + component=decoder, checkpoint=args.load_pretrained_decoder_from + ) + return decoder + + @classmethod + def build_model(cls, args, task): + """Build a new model instance.""" + + # make sure all arguments are present in older models + base_architecture(args) + + def build_embedding(dictionary, embed_dim): + num_embeddings = len(dictionary) + padding_idx = dictionary.pad() + return Embedding(num_embeddings, embed_dim, padding_idx) + + decoder_embed_tokens = build_embedding( + task.target_dictionary, args.decoder_embed_dim + ) + encoder = cls.build_encoder(args) + decoder = cls.build_decoder(args, task, decoder_embed_tokens) + return cls(encoder, decoder) + + @staticmethod + @torch.jit.unused + def set_batch_first(lprobs): + lprobs.batch_first = True + + def get_normalized_probs( + self, + net_output: Tuple[Tensor, Optional[Dict[str, List[Optional[Tensor]]]]], + log_probs: bool, + sample: Optional[Dict[str, Tensor]] = None, + ): + # net_output['encoder_out'] is a (B, T, D) tensor + lprobs = self.get_normalized_probs_scriptable(net_output, log_probs, sample) + if self.training: + self.set_batch_first(lprobs) + return lprobs + + def output_layout(self): + return "BTD" + + """ + The forward method inherited from the base class has a **kwargs argument in + its input, which is not supported in torchscript. This method overrites the forward + method definition without **kwargs. + """ + + def forward(self, src_tokens, src_lengths, prev_output_tokens): + encoder_out = self.encoder(src_tokens=src_tokens, src_lengths=src_lengths) + decoder_out = self.decoder( + prev_output_tokens=prev_output_tokens, encoder_out=encoder_out + ) + return decoder_out + + +class ConvTransformerEncoder(FairseqEncoder): + """Conv + Transformer encoder""" + + def __init__(self, args): + """Construct an Encoder object.""" + super().__init__(None) + + self.dropout = args.dropout + self.embed_scale = ( + 1.0 if args.no_scale_embedding else math.sqrt(args.encoder_embed_dim) + ) + self.padding_idx = 1 + self.in_channels = 1 + self.input_dim = args.input_feat_per_channel + self.conv = torch.nn.Sequential( + torch.nn.Conv2d(1, args.conv_out_channels, 3, stride=2, padding=3 // 2), + torch.nn.ReLU(), + torch.nn.Conv2d( + args.conv_out_channels, + args.conv_out_channels, + 3, + stride=2, + padding=3 // 2, + ), + torch.nn.ReLU(), + ) + transformer_input_dim = self.infer_conv_output_dim( + self.in_channels, self.input_dim, args.conv_out_channels + ) + self.out = torch.nn.Linear(transformer_input_dim, args.encoder_embed_dim) + self.embed_positions = PositionalEmbedding( + args.max_source_positions, + args.encoder_embed_dim, + self.padding_idx, + learned=False, + ) + + self.transformer_layers = nn.ModuleList([]) + self.transformer_layers.extend( + [TransformerEncoderLayer(args) for i in range(args.encoder_layers)] + ) + if args.encoder_normalize_before: + self.layer_norm = LayerNorm(args.encoder_embed_dim) + else: + self.layer_norm = None + + def pooling_ratio(self): + return 4 + + def infer_conv_output_dim(self, in_channels, input_dim, out_channels): + sample_seq_len = 200 + sample_bsz = 10 + x = torch.randn(sample_bsz, in_channels, sample_seq_len, input_dim) + x = torch.nn.Conv2d(1, out_channels, 3, stride=2, padding=3 // 2)(x) + x = torch.nn.Conv2d(out_channels, out_channels, 3, stride=2, padding=3 // 2)(x) + x = x.transpose(1, 2) + mb, seq = x.size()[:2] + return x.contiguous().view(mb, seq, -1).size(-1) + + def forward(self, src_tokens, src_lengths): + """Encode input sequence. + :param torch.Tensor xs: input tensor + :param torch.Tensor masks: input mask + :return: position embedded tensor and mask + :rtype Tuple[torch.Tensor, torch.Tensor]: + """ + bsz, max_seq_len, _ = src_tokens.size() + x = ( + src_tokens.view(bsz, max_seq_len, self.in_channels, self.input_dim) + .transpose(1, 2) + .contiguous() + ) + x = self.conv(x) + bsz, _, output_seq_len, _ = x.size() + x = x.transpose(1, 2).transpose(0, 1).contiguous().view(output_seq_len, bsz, -1) + x = self.out(x) + x = self.embed_scale * x + + subsampling_factor = int(max_seq_len * 1.0 / output_seq_len + 0.5) + + input_lengths = min( + (src_lengths.float() / subsampling_factor).ceil().long(), + x.size(0) * src_lengths.new_ones([1]), + ) + + encoder_padding_mask, _ = lengths_to_encoder_padding_mask( + input_lengths, batch_first=True + ) + + positions = self.embed_positions(encoder_padding_mask).transpose(0, 1) + x += positions + x = F.dropout(x, p=self.dropout, training=self.training) + + for layer in self.transformer_layers: + x = layer(x, encoder_padding_mask) + + if not encoder_padding_mask.any(): + maybe_encoder_padding_mask = None + else: + maybe_encoder_padding_mask = encoder_padding_mask + + return { + "encoder_out": [x], + "encoder_padding_mask": [maybe_encoder_padding_mask] + if maybe_encoder_padding_mask is not None + else [], + "encoder_embedding": [], + "encoder_states": [], + "src_tokens": [], + "src_lengths": [], + } + + @torch.jit.export + def reorder_encoder_out(self, encoder_out: Dict[str, List[Tensor]], new_order): + """ + Reorder encoder output according to *new_order*. + + Args: + encoder_out: output from the ``forward()`` method + new_order (LongTensor): desired order + + Returns: + *encoder_out* rearranged according to *new_order* + """ + new_encoder_out = [encoder_out["encoder_out"][0].index_select(1, new_order)] + if len(encoder_out["encoder_padding_mask"]) == 0: + new_encoder_padding_mask = [] + else: + new_encoder_padding_mask = [ + (encoder_out["encoder_padding_mask"][0]).index_select(0, new_order) + ] + if len(encoder_out["encoder_embedding"]) == 0: + new_encoder_embedding = [] + else: + new_encoder_embedding = [ + (encoder_out["encoder_embedding"][0]).index_select(0, new_order) + ] + encoder_states = encoder_out["encoder_states"] + if len(encoder_states) > 0: + for idx, state in enumerate(encoder_states): + encoder_states[idx] = state.index_select(1, new_order) + + return { + "encoder_out": new_encoder_out, + "encoder_padding_mask": new_encoder_padding_mask, + "encoder_embedding": new_encoder_embedding, + "encoder_states": encoder_states, + "src_tokens": [], + "src_lengths": [], + } + + +class TransformerDecoderNoExtra(TransformerDecoder): + def extract_features( + self, + prev_output_tokens, + encoder_out: Optional[Dict[str, List[Tensor]]], + incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, + full_context_alignment: bool = False, + alignment_layer: Optional[int] = None, + alignment_heads: Optional[int] = None, + ): + # call scriptable method from parent class + x, _ = self.extract_features_scriptable( + prev_output_tokens, + encoder_out, + incremental_state, + full_context_alignment, + alignment_layer, + alignment_heads, + ) + return x, None + + +@register_model_architecture(model_name="convtransformer", arch_name="convtransformer") +def base_architecture(args): + args.input_feat_per_channel = getattr(args, "input_feat_per_channel", 80) + args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512) + args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 2048) + args.encoder_layers = getattr(args, "encoder_layers", 6) + args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 8) + args.encoder_normalize_before = getattr(args, "encoder_normalize_before", False) + args.decoder_embed_dim = getattr(args, "decoder_embed_dim", args.encoder_embed_dim) + args.decoder_ffn_embed_dim = getattr( + args, "decoder_ffn_embed_dim", args.encoder_ffn_embed_dim + ) + args.decoder_layers = getattr(args, "decoder_layers", 6) + args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 8) + args.decoder_normalize_before = getattr(args, "decoder_normalize_before", False) + args.decoder_learned_pos = getattr(args, "decoder_learned_pos", False) + args.attention_dropout = getattr(args, "attention_dropout", 0.0) + args.activation_dropout = getattr(args, "activation_dropout", 0.0) + args.activation_fn = getattr(args, "activation_fn", "relu") + args.dropout = getattr(args, "dropout", 0.1) + args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff", None) + args.adaptive_softmax_dropout = getattr(args, "adaptive_softmax_dropout", 0) + args.share_decoder_input_output_embed = getattr( + args, "share_decoder_input_output_embed", False + ) + args.no_token_positional_embeddings = getattr( + args, "no_token_positional_embeddings", False + ) + args.adaptive_input = getattr(args, "adaptive_input", False) + args.decoder_layerdrop = getattr(args, "decoder_layerdrop", 0.0) + + args.decoder_output_dim = getattr( + args, "decoder_output_dim", args.decoder_embed_dim + ) + args.decoder_input_dim = getattr(args, "decoder_input_dim", args.decoder_embed_dim) + args.no_scale_embedding = getattr(args, "no_scale_embedding", False) + args.quant_noise_pq = getattr(args, "quant_noise_pq", 0) + args.max_source_positions = getattr(args, "max_source_positions", 3000) + args.max_target_positions = getattr(args, "max_target_positions", 1024) + args.tie_adaptive_weights = getattr(args, "tie_adaptive_weights", False) + args.conv_out_channels = getattr(args, "conv_out_channels", args.encoder_embed_dim) + + +@register_model_architecture("convtransformer", "convtransformer_espnet") +def convtransformer_espnet(args): + args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 256) + args.encoder_layers = getattr(args, "encoder_layers", 12) + args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 4) + args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 4) diff --git a/fairseq/models/speech_to_text/convtransformer_simul_trans.py b/fairseq/models/speech_to_text/convtransformer_simul_trans.py new file mode 100644 index 0000000000..e5dd771e03 --- /dev/null +++ b/fairseq/models/speech_to_text/convtransformer_simul_trans.py @@ -0,0 +1,49 @@ +# Copyright (c) 2017-present, Facebook, Inc. +# All rights reserved. +# +# This source code is licensed under the license found in the LICENSE file in +# the root directory of this source tree. An additional grant of patent rights +# can be found in the PATENTS file in the same directory. + +from examples.simultaneous_translation.models.transformer_monotonic_attention import ( + TransformerMonotonicDecoder, +) +from fairseq import checkpoint_utils +from fairseq.models import ( + register_model, + register_model_architecture, +) + +from .convtransformer import ConvTransformerModel, convtransformer_espnet + + +@register_model("convtransformer_simul_trans") +class SimulConvTransformerModel(ConvTransformerModel): + @staticmethod + def add_args(parser): + super(SimulConvTransformerModel, SimulConvTransformerModel).add_args(parser) + parser.add_argument( + "--train-monotonic-only", + action="store_true", + default=False, + help="Only train monotonic attention", + ) + + @classmethod + def build_decoder(cls, args, task, embed_tokens): + tgt_dict = task.tgt_dict + + decoder = TransformerMonotonicDecoder(args, tgt_dict, embed_tokens) + + if getattr(args, "load_pretrained_decoder_from", None): + decoder = checkpoint_utils.load_pretrained_component_from_model( + component=decoder, checkpoint=args.load_pretrained_decoder_from + ) + return decoder + + +@register_model_architecture( + "convtransformer_simul_trans", "convtransformer_simul_trans_espnet" +) +def convtransformer_simul_trans_espnet(args): + convtransformer_espnet(args) From 523fe83828e6374439a6203330ed0e8c13e86b62 Mon Sep 17 00:00:00 2001 From: Sravya Popuri Date: Thu, 18 Feb 2021 22:41:32 -0800 Subject: [PATCH 24/82] Integrate Simul ST model into pyspeech Summary: This diff integrates simul ST training into pyspeech with very minor modifications to the open sourced code. Specific changes made are - In fixed_pre_decision.py remove self as argument to p_choose function as it is already called with super in line 101 - In monotonic_multihead_attention.py remove pdb.set_trace() - Move label_smoothed_cross_entropy_latency_augmented.py to fairseq/criterions folder and add missing arguments to parser - In fairseq/data/data_utils.py type cast max_tokens to int to avoid type error. - Update fairseq/convtransformer.py to pyspeech/convtransformer.py # Next steps: - Verify decoding using the model trained - Support everstore handle based decoding in simuleval and integrate it into pyspeech. Reviewed By: jmp84 Differential Revision: D26478861 fbshipit-source-id: 3b02b2aee757e5464b71dbdd7ebdba42659faee5 --- .../modules/fixed_pre_decision.py | 1 - .../modules/monotonic_multihead_attention.py | 2 -- ...moothed_cross_entropy_latency_augmented.py | 22 ++++++++++++++++++- fairseq/data/data_utils.py | 5 ++++- .../models/speech_to_text/convtransformer.py | 6 +---- 5 files changed, 26 insertions(+), 10 deletions(-) rename {examples/simultaneous_translation => fairseq}/criterions/label_smoothed_cross_entropy_latency_augmented.py (86%) diff --git a/examples/simultaneous_translation/modules/fixed_pre_decision.py b/examples/simultaneous_translation/modules/fixed_pre_decision.py index 2cde55b35e..725be1a983 100644 --- a/examples/simultaneous_translation/modules/fixed_pre_decision.py +++ b/examples/simultaneous_translation/modules/fixed_pre_decision.py @@ -99,7 +99,6 @@ def p_choose( if self.pre_decision_ratio == 1: return super().p_choose( - self, query, key, key_padding_mask=None, diff --git a/examples/simultaneous_translation/modules/monotonic_multihead_attention.py b/examples/simultaneous_translation/modules/monotonic_multihead_attention.py index 5423f26c34..3e25957cd6 100644 --- a/examples/simultaneous_translation/modules/monotonic_multihead_attention.py +++ b/examples/simultaneous_translation/modules/monotonic_multihead_attention.py @@ -259,8 +259,6 @@ def expected_alignment_infer( finish_read = new_monotonic_step.eq(max_steps) | (action == 0) - if p_choose_i is None: - import pdb;pdb.set_trace() monotonic_cache["head_step"] = new_monotonic_step # Whether a head is looking for new input diff --git a/examples/simultaneous_translation/criterions/label_smoothed_cross_entropy_latency_augmented.py b/fairseq/criterions/label_smoothed_cross_entropy_latency_augmented.py similarity index 86% rename from examples/simultaneous_translation/criterions/label_smoothed_cross_entropy_latency_augmented.py rename to fairseq/criterions/label_smoothed_cross_entropy_latency_augmented.py index 761cfe61a1..aa3dba31e2 100644 --- a/examples/simultaneous_translation/criterions/label_smoothed_cross_entropy_latency_augmented.py +++ b/fairseq/criterions/label_smoothed_cross_entropy_latency_augmented.py @@ -53,8 +53,28 @@ def add_args(parser): LatencyAugmentedLabelSmoothedCrossEntropyCriterion, LatencyAugmentedLabelSmoothedCrossEntropyCriterion, ).add_args(parser) - """Add criterion-specific arguments to the parser.""" # fmt: off + + """Add criterion-specific arguments to the parser.""" + parser.add_argument( + "--label-smoothing", + default=0.0, + type=float, + metavar="D", + help="epsilon for label smoothing, 0 means no label smoothing", + ) + parser.add_argument( + "--ignore_prefix_size", + default=0, + type=int, + help="ignore first N tokens", + ) + parser.add_argument( + "--report-accuracy", + default=False, + type=bool, + help="report accuracy metric", + ) parser.add_argument("--latency-weight-avg", default=0., type=float, metavar='D', help="Average loss weight") parser.add_argument("--latency-weight-var", default=0., type=float, metavar='D', diff --git a/fairseq/data/data_utils.py b/fairseq/data/data_utils.py index 47d8492ec9..3042358f2f 100644 --- a/fairseq/data/data_utils.py +++ b/fairseq/data/data_utils.py @@ -313,7 +313,10 @@ def batch_by_size( " --editable .` or `python setup.py build_ext --inplace`." ) - max_tokens = max_tokens if max_tokens is not None else -1 + # added int() to avoid TypeError: an integer is required + max_tokens = ( + int(max_tokens) if max_tokens is not None else -1 + ) max_sentences = max_sentences if max_sentences is not None else -1 bsz_mult = required_batch_size_multiple diff --git a/fairseq/models/speech_to_text/convtransformer.py b/fairseq/models/speech_to_text/convtransformer.py index 512ee78be0..06276e636a 100644 --- a/fairseq/models/speech_to_text/convtransformer.py +++ b/fairseq/models/speech_to_text/convtransformer.py @@ -307,11 +307,7 @@ def forward(self, src_tokens, src_lengths): subsampling_factor = int(max_seq_len * 1.0 / output_seq_len + 0.5) - input_lengths = min( - (src_lengths.float() / subsampling_factor).ceil().long(), - x.size(0) * src_lengths.new_ones([1]), - ) - + input_lengths = (src_lengths.float() / subsampling_factor).ceil().long() encoder_padding_mask, _ = lengths_to_encoder_padding_mask( input_lengths, batch_first=True ) From 675f608915a216ac32777928a0b1e8210cb66df6 Mon Sep 17 00:00:00 2001 From: Changhan Wang Date: Fri, 19 Feb 2021 08:59:37 -0800 Subject: [PATCH 25/82] Fix LibriSpeech data prep script Summary: Fix LibriSpeech data prep script * Lowercasing transcript to be consistent with the pre-trained models Reviewed By: jmp84 Differential Revision: D26538845 fbshipit-source-id: 0885f99e2c85f0e722a24f3cb83f2635ce9429bc --- examples/speech_to_text/prep_librispeech_data.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/speech_to_text/prep_librispeech_data.py b/examples/speech_to_text/prep_librispeech_data.py index 6a6f55ded4..7b08447190 100644 --- a/examples/speech_to_text/prep_librispeech_data.py +++ b/examples/speech_to_text/prep_librispeech_data.py @@ -71,7 +71,7 @@ def process(args): manifest["audio"].append(zip_manifest[sample_id]) duration_ms = int(wav.size(1) / sample_rate * 1000) manifest["n_frames"].append(int(1 + (duration_ms - 25) / 10)) - manifest["tgt_text"].append(utt) + manifest["tgt_text"].append(utt.lower()) manifest["speaker"].append(spk_id) save_df_to_tsv( pd.DataFrame.from_dict(manifest), out_root / f"{split}.tsv" From 2909ee1852cdae7ad4115a1a04520b0522265dd2 Mon Sep 17 00:00:00 2001 From: "joseph.suh" Date: Fri, 19 Feb 2021 10:07:43 -0800 Subject: [PATCH 26/82] Fix bug for issue (#3211) (#3212) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: # Before submitting - [x] Was this discussed/approved via a Github issue? (no need for typos, doc improvements) - [x] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)? - [x] Did you make sure to update the docs? - [ ] Did you write any new necessary tests? ## What does this PR do? Fixes KeyError mentioned in # (3211). ## PR review Anyone in the community is free to review the PR once the tests have passed. If we didn't discuss your PR in Github issues there's a high chance it will not be merged. ## Did you have fun? Make sure you had fun coding � Pull Request resolved: https://github.com/pytorch/fairseq/pull/3212 Reviewed By: alexeib Differential Revision: D26513255 Pulled By: myleott fbshipit-source-id: 5a11cb369c9d4202fab6998d269e7da5f3d3e534 --- fairseq/trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fairseq/trainer.py b/fairseq/trainer.py index e860fb1832..f66dc25e40 100644 --- a/fairseq/trainer.py +++ b/fairseq/trainer.py @@ -401,7 +401,7 @@ def load_checkpoint( self.lr_step(epoch) - if itr_state["version"] >= 2 and itr_state["iterations_in_epoch"] == 0: + if itr_state.get("version", 1) >= 2 and itr_state["iterations_in_epoch"] == 0: # reset meters at start of epoch reset_meters = True From 3ef18886d0a802a8c8d90b57d858df3da7e75202 Mon Sep 17 00:00:00 2001 From: Alex Gaziev Date: Fri, 19 Feb 2021 10:10:13 -0800 Subject: [PATCH 27/82] Remove extra arg min_length and fix min_sample_size behavior (#3249) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: # Before submitting - [x] Was this discussed/approved via a Github issue? (no need for typos, doc improvements) - [x] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)? - [x] Did you make sure to update the docs? - [ ] Did you write any new necessary tests? ## What does this PR do? Fixes https://github.com/pytorch/fairseq/issues/3178 (issue). ## PR review Anyone in the community is free to review the PR once the tests have passed. If we didn't discuss your PR in Github issues there's a high chance it will not be merged. ## Did you have fun? Make sure you had fun coding � (I did ;) Pull Request resolved: https://github.com/pytorch/fairseq/pull/3249 Reviewed By: alexeib Differential Revision: D26513275 Pulled By: myleott fbshipit-source-id: 2785098a945404c07eb72c079177654b1739a7a2 --- fairseq/data/audio/raw_audio_dataset.py | 10 +++------- fairseq/tasks/audio_pretraining.py | 5 ++--- 2 files changed, 5 insertions(+), 10 deletions(-) diff --git a/fairseq/data/audio/raw_audio_dataset.py b/fairseq/data/audio/raw_audio_dataset.py index ac5acd03bb..1d92e4966b 100644 --- a/fairseq/data/audio/raw_audio_dataset.py +++ b/fairseq/data/audio/raw_audio_dataset.py @@ -23,9 +23,8 @@ def __init__( self, sample_rate, max_sample_size=None, - min_sample_size=None, + min_sample_size=0, shuffle=True, - min_length=0, pad=False, normalize=False, ): @@ -37,7 +36,6 @@ def __init__( max_sample_size if max_sample_size is not None else sys.maxsize ) self.min_sample_size = min_sample_size - self.min_length = min_length self.pad = pad self.shuffle = shuffle self.normalize = normalize @@ -136,9 +134,8 @@ def __init__( manifest_path, sample_rate, max_sample_size=None, - min_sample_size=None, + min_sample_size=0, shuffle=True, - min_length=0, pad=False, normalize=False, ): @@ -147,7 +144,6 @@ def __init__( max_sample_size=max_sample_size, min_sample_size=min_sample_size, shuffle=shuffle, - min_length=min_length, pad=pad, normalize=normalize, ) @@ -162,7 +158,7 @@ def __init__( items = line.strip().split("\t") assert len(items) == 2, line sz = int(items[1]) - if min_length is not None and sz < min_length: + if min_sample_size is not None and sz < min_sample_size: skipped += 1 continue self.fnames.append(items[0]) diff --git a/fairseq/tasks/audio_pretraining.py b/fairseq/tasks/audio_pretraining.py index 92685160d4..b7b5429819 100644 --- a/fairseq/tasks/audio_pretraining.py +++ b/fairseq/tasks/audio_pretraining.py @@ -57,7 +57,7 @@ class AudioPretrainingConfig(FairseqDataclass): default=None, metadata={"help": "max sample size to crop to for batching"} ) min_sample_size: Optional[int] = field( - default=None, metadata={"help": "min sample size to crop to for batching"} + default=None, metadata={"help": "min sample size to skip small examples"} ) # Options for reporting WER metrics during validation. Only applicable to @@ -135,8 +135,7 @@ def load_dataset(self, split: str, task_cfg: FairseqDataclass = None, **kwargs): manifest, sample_rate=task_cfg.get('sample_rate', self.cfg.sample_rate), max_sample_size=self.cfg.max_sample_size, - min_sample_size=self.cfg.max_sample_size, - min_length=self.cfg.min_sample_size, + min_sample_size=self.cfg.min_sample_size, pad=task_cfg.labels is not None or task_cfg.enable_padding, normalize=task_cfg.normalize, ) From c6b5c00312dc23f473c66ba3016cc9e3decfd317 Mon Sep 17 00:00:00 2001 From: Alex Xiao Date: Fri, 19 Feb 2021 10:31:08 -0800 Subject: [PATCH 28/82] fix criterion name check when resuming from checkpoint Summary: I tried resuming a run from a checkpoint in f250883864, but ran into: AssertionError: Criterion does not match; please reset the optimizer (--reset-optimizer). DistributedTimeoutWrapper vs ContrastiveLabelsCriterion Based on this, I believe since D25836853 (https://github.com/pytorch/fairseq/commit/d68a3530dda7f8275e490864b28974ef30fe854b) we are no longer saving the actual criterion's name, but DistributedTimeoutWrapper in the checkpoint. This is kind of weird though, as I would expect more people to run into this issue. Not sure if I am doing something wrong, let me know if so, thanks! Reviewed By: myleott Differential Revision: D26478656 fbshipit-source-id: bc3c7c925f5505140d9df4438af3a73d65d4f531 --- fairseq/trainer.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/fairseq/trainer.py b/fairseq/trainer.py index f66dc25e40..891155f162 100644 --- a/fairseq/trainer.py +++ b/fairseq/trainer.py @@ -284,7 +284,7 @@ def save_checkpoint(self, filename, extra_state): filename, self.cfg, self.model.state_dict(), - self.criterion, + self.get_criterion(), self.optimizer, self.lr_scheduler, self.get_num_updates(), @@ -375,10 +375,10 @@ def load_checkpoint( last_optim = self._optim_history[-1] assert ( last_optim["criterion_name"] == self.get_criterion().__class__.__name__ - ), "Criterion does not match; please reset the optimizer (--reset-optimizer)." + ), f"Criterion does not match; please reset the optimizer (--reset-optimizer). {last_optim['criterion_name']} vs {self.get_criterion().__class__.__name__}" assert ( last_optim["optimizer_name"] == self.optimizer.__class__.__name__ - ), "Optimizer does not match; please reset the optimizer (--reset-optimizer)." + ), f"Optimizer does not match; please reset the optimizer (--reset-optimizer). {last_optim['optimizer_name']} vs {self.optimizer.__class__.__name__}" if not reset_lr_scheduler: self.lr_scheduler.load_state_dict(last_optim["lr_scheduler_state"]) From ae22da652d63bd6e05a9a035f6a9dcabb1a39c73 Mon Sep 17 00:00:00 2001 From: Xutai Ma Date: Fri, 19 Feb 2021 21:52:31 -0800 Subject: [PATCH 29/82] Correct the estimation of cnn output lengths in convtransformer (#1636) Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1636 Reviewed By: xutaima Differential Revision: D26562816 Pulled By: jmp84 fbshipit-source-id: 4e6efd0b4236d7187bd365d790f260bd5297aed5 --- fairseq/models/speech_to_text/convtransformer.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/fairseq/models/speech_to_text/convtransformer.py b/fairseq/models/speech_to_text/convtransformer.py index 06276e636a..622b5e6df8 100644 --- a/fairseq/models/speech_to_text/convtransformer.py +++ b/fairseq/models/speech_to_text/convtransformer.py @@ -30,7 +30,6 @@ class ConvTransformerModel(FairseqEncoderDecoderModel): Transformer-based Speech translation model from ESPNet-ST https://arxiv.org/abs/2004.10234 """ - def __init__(self, encoder, decoder): super().__init__(encoder, decoder) @@ -307,7 +306,11 @@ def forward(self, src_tokens, src_lengths): subsampling_factor = int(max_seq_len * 1.0 / output_seq_len + 0.5) - input_lengths = (src_lengths.float() / subsampling_factor).ceil().long() + input_lengths = torch.min( + (src_lengths.float() / subsampling_factor).ceil().long(), + x.size(0) * src_lengths.new_ones([src_lengths.size(0)]).long() + ) + encoder_padding_mask, _ = lengths_to_encoder_padding_mask( input_lengths, batch_first=True ) From 61e46bb99758e05bc990e3687c69b507a8ebf185 Mon Sep 17 00:00:00 2001 From: Frankie Robertson Date: Sat, 20 Feb 2021 06:21:45 -0800 Subject: [PATCH 30/82] Fix attempt to unlink directory copied into source package (Python 3.9) (#3235) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: # Before submitting - [ ] Was this discussed/approved via a Github issue? (no need for typos, doc improvements) - [x] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)? - [N/A] Did you make sure to update the docs? - [N/A] Did you write any new necessary tests? ## What does this PR do? Currently when installing the newest source package from PyPI I get an error like so: ``` Collecting fairseq Using cached fairseq-0.10.2.tar.gz (938 kB) Installing build dependencies ... done Getting requirements to build wheel ... error ERROR: Command errored out with exit status 1: command: /home/frankier/sources/datasets/.venv/bin/python3 /tmp/tmp_ujftsgi_in_process.py get_requires_for_build_wheel /tmp/tmpmn0eumq2 cwd: /tmp/pip-install-dg5d6q9y/fairseq Complete output (31 lines): Traceback (most recent call last): File "setup.py", line 214, in do_setup(package_data) File "setup.py", line 136, in do_setup setup( File "/tmp/pip-build-env-hag0sxvp/overlay/lib/python3.9/site-packages/setuptools/__init__.py", line 152, in setup _install_setup_requires(attrs) File "/tmp/pip-build-env-hag0sxvp/overlay/lib/python3.9/site-packages/setuptools/__init__.py", line 147, in _install_setup_requires dist.fetch_build_eggs(dist.setup_requires) File "/tmp/pip-build-env-hag0sxvp/overlay/lib/python3.9/site-packages/setuptools/build_meta.py", line 60, in fetch_build_eggs raise SetupRequirementsError(specifier_list) setuptools.build_meta.SetupRequirementsError: ['cython', 'numpy', 'setuptools>=18.0'] During handling of the above exception, another exception occurred: Traceback (most recent call last): File "/tmp/tmp_ujftsgi_in_process.py", line 280, in main() File "/tmp/tmp_ujftsgi_in_process.py", line 263, in main json_out['return_val'] = hook(**hook_input['kwargs']) File "/tmp/tmp_ujftsgi_in_process.py", line 114, in get_requires_for_build_wheel return hook(config_settings) File "/tmp/pip-build-env-hag0sxvp/overlay/lib/python3.9/site-packages/setuptools/build_meta.py", line 149, in get_requires_for_build_wheel return self._get_build_requires( File "/tmp/pip-build-env-hag0sxvp/overlay/lib/python3.9/site-packages/setuptools/build_meta.py", line 130, in _get_build_requires self.run_setup() File "/tmp/pip-build-env-hag0sxvp/overlay/lib/python3.9/site-packages/setuptools/build_meta.py", line 145, in run_setup exec(compile(code, __file__, 'exec'), locals()) File "setup.py", line 217, in os.unlink(fairseq_examples) IsADirectoryError: [Errno 21] Is a directory: 'fairseq/examples' ---------------------------------------- ERROR: Command errored out with exit status 1: /home/frankier/sources/datasets/.venv/bin/python3 /tmp/tmp_ujftsgi_in_process.py get_requires_for_build_wheel /tmp/tmpmn0eumq2 Check the logs for full command output. ``` I believe the reason for this is that the source package contains the examples directory because it was put there during package creation (it seems the symlink because a directory). Now, when setup.py is run again, it seems the setup.py attempts to unlink the directory, which is not possible because only symlinks can be unlinked. This PR therefore only attempts to unlink it if it is a symlink. I have not thoroughly tested whether my proposed cause is the true cause, but this should fix it in any case. Note that the source package is fetched because there is no wheel for Python 3.9, so most users will not see this because they will use the wheel. ## PR review Anyone in the community is free to review the PR once the tests have passed. If we didn't discuss your PR in Github issues there's a high chance it will not be merged. ## Did you have fun? Make sure you had fun coding � Pull Request resolved: https://github.com/pytorch/fairseq/pull/3235 Reviewed By: alexeib Differential Revision: D26513259 Pulled By: myleott fbshipit-source-id: 775d6c636a5867b9983bb6419829f13ee414e2fd --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index d1a976104e..3670ff3cfc 100644 --- a/setup.py +++ b/setup.py @@ -256,5 +256,5 @@ def get_files(path, relative_to="fairseq"): } do_setup(package_data) finally: - if "build_ext" not in sys.argv[1:] and os.path.exists(fairseq_examples): + if "build_ext" not in sys.argv[1:] and os.path.islink(fairseq_examples): os.unlink(fairseq_examples) From 4cf7d76114d50008cdd98a7fde250d4ef99b66fe Mon Sep 17 00:00:00 2001 From: Pierre Andrews Date: Sat, 20 Feb 2021 06:23:41 -0800 Subject: [PATCH 31/82] Hydra Integration doc should refer to non legacy task (#1619) Summary: # Before submitting - [NO] Was this discussed/approved via a Github issue? (no need for typos, doc improvements) - [YES] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)? - [YES] Did you make sure to update the docs? - [NO] Did you write any new necessary tests? ## What does this PR do? This is a typo fix to the Hydra Integration doc where the example with dataclass config should user `FairseqTask` and not `LegacyFairseqTask`. Didn't make an issue for this as it's a trivial doc change for the example to match the actual doc. Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1619 Reviewed By: huihuifan Differential Revision: D26448855 Pulled By: Mortimerp9 fbshipit-source-id: 467323101b8425370f6bd7c0532e70abb319b337 --- docs/hydra_integration.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/hydra_integration.md b/docs/hydra_integration.md index 04c797fe50..6a15298382 100644 --- a/docs/hydra_integration.md +++ b/docs/hydra_integration.md @@ -120,7 +120,7 @@ class LanguageModelingConfig(FairseqDataclass): ... @register_task("language_modeling", dataclass=LanguageModelingConfig) -class LanguageModelingTask(LegacyFairseqTask): +class LanguageModelingTask(FairseqTask): ... @classmethod def setup_task(cls, cfg: LanguageModelingConfig): From 38258a79a42f3ccfa596cc51bbf269cf13c3d799 Mon Sep 17 00:00:00 2001 From: Sravya Popuri Date: Mon, 22 Feb 2021 13:55:06 -0800 Subject: [PATCH 32/82] Update FairseqSimulSTAgent to make it generic and reusable internally Summary: This diff 1. Updates FairseqSimulSTAgent to make it generic and reusable internally [Touches OSS] 2. Adds FBFairseqSimulSTAgent inheriting FairseqSimulSTAgent 3. Add TARGETS file in examples/speech_to_text 4. Update simuleval TARGETS and add a bento kernel for easy testing Reviewed By: jmp84 Differential Revision: D26573214 fbshipit-source-id: f4b71f90693cc878cc771b46a006bcbc83a50124 --- .../agents/fairseq_simul_st_agent.py | 48 +++++++++++-------- 1 file changed, 27 insertions(+), 21 deletions(-) diff --git a/examples/speech_to_text/simultaneous_translation/agents/fairseq_simul_st_agent.py b/examples/speech_to_text/simultaneous_translation/agents/fairseq_simul_st_agent.py index cbe8bc4322..32cd0a1f61 100644 --- a/examples/speech_to_text/simultaneous_translation/agents/fairseq_simul_st_agent.py +++ b/examples/speech_to_text/simultaneous_translation/agents/fairseq_simul_st_agent.py @@ -27,25 +27,18 @@ class OnlineFeatureExtractor: Extract speech feature on the fly. """ - def __init__( - self, - shift_size=SHIFT_SIZE, - window_size=WINDOW_SIZE, - sample_rate=SAMPLE_RATE, - feature_dim=FEATURE_DIM, - global_cmvn=None, - ): - self.shift_size = shift_size - self.window_size = window_size + def __init__(self, args): + self.shift_size = args.shift_size + self.window_size = args.window_size assert self.window_size >= self.shift_size - self.sample_rate = sample_rate - self.feature_dim = feature_dim - self.num_samples_per_shift = int(SHIFT_SIZE * SAMPLE_RATE / 1000) - self.num_samples_per_window = int(WINDOW_SIZE * SAMPLE_RATE / 1000) + self.sample_rate = args.sample_rate + self.feature_dim = args.feature_dim + self.num_samples_per_shift = int(self.shift_size * self.sample_rate / 1000) + self.num_samples_per_window = int(self.window_size * self.sample_rate / 1000) self.len_ms_to_samples = lambda x: x * self.sample_rate / 1000 self.previous_residual_samples = [] - self.global_cmvn = global_cmvn + self.global_cmvn = args.global_cmvn def clear_cache(self): self.previous_residual_samples = [] @@ -134,16 +127,15 @@ def __init__(self, args): self.load_model_vocab(args) - config_yaml = os.path.join(args.data_bin, "config.yaml") - with open(config_yaml, "r") as f: + with open(args.config, "r") as f: config = yaml.load(f) if "global_cmvn" in config: - global_cmvn = np.load(config["global_cmvn"]["stats_npz_path"]) + args.global_cmvn = np.load(config["global_cmvn"]["stats_npz_path"]) else: - global_cmvn = None + args.global_cmvn = None - self.feature_extractor = OnlineFeatureExtractor(global_cmvn=global_cmvn) + self.feature_extractor = OnlineFeatureExtractor(args) self.max_len = args.max_len @@ -164,6 +156,8 @@ def add_args(parser): help='path to your pretrained model.') parser.add_argument("--data-bin", type=str, required=True, help="Path of data binary") + parser.add_argument("--config", type=str, required=True, + help="Path to config yaml file") parser.add_argument("--tgt-splitter-type", type=str, default="SentencePiece", help="Subword splitter type for target text") parser.add_argument("--tgt-splitter-path", type=str, default=None, @@ -174,9 +168,21 @@ def add_args(parser): help="Max length of translation") parser.add_argument("--force-finish", default=False, action="store_true", help="") + parser.add_argument("--shift-size", type=int, default=SHIFT_SIZE, + help="") + parser.add_argument("--window-size", type=int, default=WINDOW_SIZE, + help="") + parser.add_argument("--sample-rate", type=int, default=SAMPLE_RATE, + help="") + parser.add_argument("--feature-dim", type=int, default=FEATURE_DIM, + help="") + # fmt: on return parser + def set_up_task(self, task_args): + return tasks.setup_task(task_args) + def load_model_vocab(self, args): filename = args.model_path @@ -188,7 +194,7 @@ def load_model_vocab(self, args): task_args = state["cfg"]["task"] task_args.data = args.data_bin - task = tasks.setup_task(task_args) + task = self.set_up_task(task_args) # build model for ensemble self.model = task.build_model(state["cfg"]["model"]) From 808b751597d85c098990080d21fd450877dcb242 Mon Sep 17 00:00:00 2001 From: Miguel Del-Agua Date: Mon, 22 Feb 2021 14:21:36 -0800 Subject: [PATCH 33/82] Improve torchscript compatibility of transfomer and transformer pg (#3247) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: # Before submitting - [ ] Was this discussed/approved via a Github issue? (no need for typos, doc improvements) - [x] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)? - [x] Did you make sure to update the docs? - [x] Did you write any new necessary tests? ## What does this PR do? Fixes https://github.com/pytorch/fairseq/issues/3246 Fixes https://github.com/pytorch/fairseq/issues/3248 ## PR review Anyone in the community is free to review the PR once the tests have passed. If we didn't discuss your PR in Github issues there's a high chance it will not be merged. ## Did you have fun? Make sure you had fun coding � Pull Request resolved: https://github.com/pytorch/fairseq/pull/3247 Reviewed By: myleott Differential Revision: D26513267 Pulled By: lematt1991 fbshipit-source-id: 958de0b3a58a0dd2a56bd6c6d7fb2644a89f6746 --- .../pointer_generator_src/transformer_pg.py | 80 +++++++++++++++---- fairseq/models/fairseq_decoder.py | 13 +++ fairseq/models/transformer.py | 47 +++++++++-- tests/test_export.py | 13 +++ 4 files changed, 133 insertions(+), 20 deletions(-) diff --git a/examples/pointer_generator/pointer_generator_src/transformer_pg.py b/examples/pointer_generator/pointer_generator_src/transformer_pg.py index fb40a80836..e109a8e269 100644 --- a/examples/pointer_generator/pointer_generator_src/transformer_pg.py +++ b/examples/pointer_generator/pointer_generator_src/transformer_pg.py @@ -4,13 +4,12 @@ # LICENSE file in the root directory of this source tree. import logging -from typing import Any, Dict, Optional +from typing import Any, Dict, Optional, List, Tuple import torch import torch.nn as nn from fairseq import metrics, utils from fairseq.models import register_model, register_model_architecture -from fairseq.models.fairseq_encoder import EncoderOut from fairseq.models.transformer import ( DEFAULT_MAX_SOURCE_POSITIONS, DEFAULT_MAX_TARGET_POSITIONS, @@ -155,7 +154,13 @@ class TransformerPointerGeneratorEncoder(TransformerEncoder): to the decoder. """ - def forward(self, src_tokens, src_lengths, **kwargs): + def forward( + self, + src_tokens, + src_lengths: Optional[Tensor] = None, + return_all_hiddens: bool = False, + token_embeddings: Optional[Tensor] = None + ): """ Runs the `forward()` method of the parent Transformer class. Then adds the source tokens into the encoder output tuple. @@ -169,6 +174,10 @@ def forward(self, src_tokens, src_lengths, **kwargs): shape `(batch, src_len)` src_lengths (torch.LongTensor): lengths of each source sentence of shape `(batch)` + return_all_hiddens (bool, optional): also return all of the + intermediate hidden states (default: False). + token_embeddings (torch.Tensor, optional): precomputed embeddings + default `None` will recompute embeddings Returns: namedtuple: @@ -184,7 +193,15 @@ def forward(self, src_tokens, src_lengths, **kwargs): - **src_tokens** (Tensor): input token ids of shape `(batch, src_len)` """ - encoder_out = super().forward(src_tokens, src_lengths, **kwargs) + encoder_out = self.forward_scriptable(src_tokens, + src_lengths, + return_all_hiddens, + token_embeddings) + + # The Pytorch Mobile lite interpreter does not supports returning NamedTuple in + # `forward` so we use a dictionary instead. + # TorchScript does not support mixed values so the values are all lists. + # The empty list is equivalent to None. return { "encoder_out": encoder_out["encoder_out"], # T x B x C "encoder_padding_mask": encoder_out["encoder_padding_mask"], # B x T @@ -236,7 +253,7 @@ def __init__(self, args, dictionary, embed_tokens): def forward( self, prev_output_tokens, - encoder_out: Optional[EncoderOut] = None, + encoder_out: Optional[Dict[str, List[Tensor]]] = None, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, features_only: bool = False, alignment_layer: Optional[int] = 0, @@ -248,8 +265,8 @@ def forward( Args: prev_output_tokens (LongTensor): previous decoder outputs of shape `(batch, tgt_len)`, for teacher forcing - encoder_out (EncoderOut, optional): output from the encoder, used - for encoder-side attention + encoder_out (optional): output from the encoder, used for + encoder-side attention incremental_state (dict, optional): dictionary used for storing state during :ref:`Incremental decoding` features_only (bool, optional): only return features without @@ -284,10 +301,21 @@ def forward( predictors = torch.cat((prev_output_embed, x), 2) p_gens = self.project_p_gens(predictors) p_gens = torch.sigmoid(p_gens) - x = self.output_layer(x, extra["attn"][0], encoder_out["src_tokens"][0], p_gens) + # Torchscript complains if encoder_out or attn are None because + # `output_layer()` signature expects tensors instead + attn: Optional[Tensor] = extra["attn"][0] + assert encoder_out is not None + assert attn is not None + x = self.output_layer(x, attn, encoder_out["src_tokens"][0], p_gens) return x, extra - def output_layer(self, features, attn, src_tokens, p_gens, **kwargs): + def output_layer( + self, + features: Tensor, + attn: Tensor, + src_tokens: Tensor, + p_gens: Tensor + ) -> Tensor: """ Project features to the vocabulary size and mix with the attention distributions. @@ -296,7 +324,10 @@ def output_layer(self, features, attn, src_tokens, p_gens, **kwargs): p_gens = self.force_p_gen # project back to size of vocabulary - logits = super().output_layer(features, **kwargs) + if self.adaptive_softmax is None: + logits = self.output_projection(features) + else: + logits = features batch_size = logits.shape[0] output_length = logits.shape[1] @@ -306,7 +337,7 @@ def output_layer(self, features, attn, src_tokens, p_gens, **kwargs): # The final output distribution will be a mixture of the normal output # distribution (softmax of logits) and attention weights. - gen_dists = super().get_normalized_probs( + gen_dists = self.get_normalized_probs_scriptable( (logits, None), log_probs=False, sample=None ) gen_dists = torch.mul(gen_dists, p_gens) @@ -330,7 +361,12 @@ def output_layer(self, features, attn, src_tokens, p_gens, **kwargs): # Final distributions, [batch_size, output_length, num_types]. return gen_dists + attn_dists - def get_normalized_probs(self, net_output, log_probs, sample): + def get_normalized_probs( + self, + net_output: Tuple[Tensor, Optional[Dict[str, List[Optional[Tensor]]]]], + log_probs: bool, + sample: Optional[Dict[str, Tensor]] = None, + ): """ Get normalized probabilities (or log probs) from a net's output. Pointer-generator network output is already normalized. @@ -375,8 +411,19 @@ class Embedding(nn.Embedding): """ __constants__ = ["unk_idx"] - def __init__(self, num_embeddings, embedding_dim, padding_idx, unk_idx): - super().__init__(num_embeddings, embedding_dim, padding_idx=padding_idx) + # Torchscript: Inheriting from Embedding class produces an error when exporting to Torchscript + # -> RuntimeError: Unable to cast Python instance to C++ type (compile in debug mode for details + # It's happening because max_norm attribute from nn.Embedding is None by default and it cannot be + # cast to a C++ type + def __init__( + self, + num_embeddings: int, + embedding_dim: int, + padding_idx: Optional[int], + unk_idx: int, + max_norm: Optional[float] = float("inf"), + ): + super().__init__(num_embeddings, embedding_dim, padding_idx=padding_idx, max_norm=max_norm) self.unk_idx = unk_idx nn.init.normal_(self.weight, mean=0, std=embedding_dim ** -0.5) nn.init.constant_(self.weight[padding_idx], 0) @@ -385,7 +432,10 @@ def forward(self, input): input = torch.where( input >= self.num_embeddings, torch.ones_like(input) * self.unk_idx, input ) - return super().forward(input) + return nn.functional.embedding( + input, self.weight, self.padding_idx, self.max_norm, + self.norm_type, self.scale_grad_by_freq, self.sparse + ) @register_model_architecture( diff --git a/fairseq/models/fairseq_decoder.py b/fairseq/models/fairseq_decoder.py index 7eeb5c652f..4f1e8b52a2 100644 --- a/fairseq/models/fairseq_decoder.py +++ b/fairseq/models/fairseq_decoder.py @@ -64,6 +64,19 @@ def get_normalized_probs( sample: Optional[Dict[str, Tensor]] = None, ): """Get normalized probabilities (or log probs) from a net's output.""" + return self.get_normalized_probs_scriptable(net_output, log_probs, sample) + + # TorchScript doesn't support super() method so that the scriptable Subclass + # can't access the base class model in Torchscript. + # Current workaround is to add a helper function with different name and + # call the helper function from scriptable Subclass. + def get_normalized_probs_scriptable( + self, + net_output: Tuple[Tensor, Optional[Dict[str, List[Optional[Tensor]]]]], + log_probs: bool, + sample: Optional[Dict[str, Tensor]] = None, + ): + """Get normalized probabilities (or log probs) from a net's output.""" if hasattr(self, "adaptive_softmax") and self.adaptive_softmax is not None: if sample is not None: diff --git a/fairseq/models/transformer.py b/fairseq/models/transformer.py index 605cfa65e8..f2f36baf3e 100644 --- a/fairseq/models/transformer.py +++ b/fairseq/models/transformer.py @@ -422,6 +422,45 @@ def forward( token_embeddings (torch.Tensor, optional): precomputed embeddings default `None` will recompute embeddings + Returns: + dict: + - **encoder_out** (Tensor): the last encoder layer's output of + shape `(src_len, batch, embed_dim)` + - **encoder_padding_mask** (ByteTensor): the positions of + padding elements of shape `(batch, src_len)` + - **encoder_embedding** (Tensor): the (scaled) embedding lookup + of shape `(batch, src_len, embed_dim)` + - **encoder_states** (List[Tensor]): all intermediate + hidden states of shape `(src_len, batch, embed_dim)`. + Only populated if *return_all_hiddens* is True. + """ + return self.forward_scriptable(src_tokens, + src_lengths, + return_all_hiddens, + token_embeddings) + + # TorchScript doesn't support super() method so that the scriptable Subclass + # can't access the base class model in Torchscript. + # Current workaround is to add a helper function with different name and + # call the helper function from scriptable Subclass. + def forward_scriptable( + self, + src_tokens, + src_lengths: Optional[torch.Tensor] = None, + return_all_hiddens: bool = False, + token_embeddings: Optional[torch.Tensor] = None, + ): + """ + Args: + src_tokens (LongTensor): tokens in the source language of shape + `(batch, src_len)` + src_lengths (torch.LongTensor): lengths of each source sentence of + shape `(batch)` + return_all_hiddens (bool, optional): also return all of the + intermediate hidden states (default: False). + token_embeddings (torch.Tensor, optional): precomputed embeddings + default `None` will recompute embeddings + Returns: dict: - **encoder_out** (Tensor): the last encoder layer's output of @@ -787,13 +826,11 @@ def extract_features_scriptable( alignment_layer = self.num_layers - 1 # embed positions - positions = ( - self.embed_positions( + positions = None + if self.embed_positions is not None: + positions = self.embed_positions( prev_output_tokens, incremental_state=incremental_state ) - if self.embed_positions is not None - else None - ) if incremental_state is not None: prev_output_tokens = prev_output_tokens[:, -1:] diff --git a/tests/test_export.py b/tests/test_export.py index 87e52bd7c1..b380697b9a 100644 --- a/tests/test_export.py +++ b/tests/test_export.py @@ -103,6 +103,19 @@ def test_export_transformer(self): scripted = torch.jit.script(model) _test_save_and_load(scripted) + @unittest.skipIf( + torch.__version__ < "1.6.0", "Targeting OSS scriptability for the 1.6 release" + ) + def test_export_transformer_no_token_pos_emb(self): + task, parser = get_dummy_task_and_parser() + TransformerModel.add_args(parser) + args = parser.parse_args([]) + args.no_token_positional_embeddings = True + model = TransformerModel.build_model(args, task) + scripted = torch.jit.script(model) + _test_save_and_load(scripted) + + if __name__ == "__main__": unittest.main() From 89cd70c0f0c096bdbfcfb2ab339a9c8f23540bc0 Mon Sep 17 00:00:00 2001 From: m_fomicheva Date: Mon, 22 Feb 2021 14:55:33 -0800 Subject: [PATCH 34/82] Fixed scripts and instructions for reproducing the results. (#3264) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: # Before submitting - [N] Was this discussed/approved via a Github issue? (no need for typos, doc improvements) - [Y] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)? - [Y] Did you make sure to update the docs? - [N] Did you write any new necessary tests? ## What does this PR do? Small fixes in the script and documentation for correctly reproducing the results in the corresponding paper. ## PR review Anyone in the community is free to review the PR once the tests have passed. If we didn't discuss your PR in Github issues there's a high chance it will not be merged. ## Did you have fun? Make sure you had fun coding � Pull Request resolved: https://github.com/pytorch/fairseq/pull/3264 Reviewed By: lematt1991 Differential Revision: D26587397 Pulled By: myleott fbshipit-source-id: 3675ec4d4388cafa224d395e08b53667f142cb27 --- examples/unsupervised_quality_estimation/README.md | 6 +++--- examples/unsupervised_quality_estimation/meteor.py | 8 ++++---- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/examples/unsupervised_quality_estimation/README.md b/examples/unsupervised_quality_estimation/README.md index aeb96a14b1..e86a0d13b8 100644 --- a/examples/unsupervised_quality_estimation/README.md +++ b/examples/unsupervised_quality_estimation/README.md @@ -55,7 +55,7 @@ Translate ``` CUDA_VISIBLE_DEVICES=$GPU fairseq-generate $TMP/bin --path ${MODEL_DIR}/${SRC_LANG}-${TGT_LANG}.pt --beam 5 --source-lang $SRC_LANG --target-lang $TGT_LANG --no-progress-bar --unkpen 5 > $TMP/fairseq.out -grep ^H $TMP/fairseq.out | cut -f3- > $TMP/mt.out +grep ^H $TMP/fairseq.out | cut -d- -f2- | sort -n | cut -f3- > $TMP/mt.out ``` Post-process @@ -88,7 +88,7 @@ CUDA_VISIBLE_DEVICES=${GPU} fairseq-generate ${TMP}/bin-repeated --path ${MODEL_ --retain-dropout-modules '["TransformerModel","TransformerEncoder","TransformerDecoder","TransformerEncoderLayer"]' TransformerDecoderLayer --seed 46 > $TMP/dropout.scoring.out -grep ^H $TMP/dropout.scoring.out | cut -f2- > $TMP/dropout.scores +grep ^H $TMP/dropout.scoring.out | cut -d- -f2- | sort -n | cut -f2 > $TMP/dropout.scores ``` @@ -112,7 +112,7 @@ CUDA_VISIBLE_DEVICES=${GPU} fairseq-generate ${TMP}/bin-repeated --path ${MODEL_ --unkpen 5 --retain-dropout-modules TransformerModel TransformerEncoder TransformerDecoder TransformerEncoderLayer TransformerDecoderLayer --seed 46 > $TMP/dropout.generation.out -grep ^H $TMP/dropout.generation.out | cut -f3- > $TMP/dropout.hypotheses_ +grep ^H $TMP/dropout.generation.out | cut -d- -f2- | sort -n | cut -f3- > $TMP/dropout.hypotheses_ sed -r 's/(@@ )| (@@ ?$)//g' < $TMP/dropout.hypotheses_ | perl $MOSES_DECODER/scripts/tokenizer/detokenizer.perl -l $TGT_LANG > $TMP/dropout.hypotheses diff --git a/examples/unsupervised_quality_estimation/meteor.py b/examples/unsupervised_quality_estimation/meteor.py index 4a214e794d..2ee0448cf1 100644 --- a/examples/unsupervised_quality_estimation/meteor.py +++ b/examples/unsupervised_quality_estimation/meteor.py @@ -85,19 +85,19 @@ def read_output(meteor_output_path, n_repeats): def main(): parser = argparse.ArgumentParser() - parser.add_argument("-i", "--input") + parser.add_argument("-i", "--infile") parser.add_argument("-n", "--repeat_times", type=int) parser.add_argument("-m", "--meteor") parser.add_argument("-o", "--output") args = parser.parse_args() - translations = read_translations(args.infile, args.repetitions) + translations = read_translations(args.infile, args.repeat_times) sys.stderr.write("\nGenerating input for Meteor...") - ref_path, mt_path = generate_input(translations, args.repetitions) + ref_path, mt_path = generate_input(translations, args.repeat_times) sys.stderr.write("\nRunning Meteor...") out_path = run_meteor(ref_path, mt_path, args.meteor) sys.stderr.write("\nReading output...") - scores = read_output(out_path, args.repetitions) + scores = read_output(out_path, args.repeat_times) sys.stderr.write("\nWriting results...") with open(args.output, "w") as o: for scr in scores: From b9778da42643f5b20fa0a555834d49537ce165c0 Mon Sep 17 00:00:00 2001 From: Myle Ott Date: Mon, 22 Feb 2021 15:00:15 -0800 Subject: [PATCH 35/82] Small fixes for flow-cli usage Summary: - Use `PathManager.ls` instead of `os.listdir` - Add version.txt to fairseq TARGETS Reviewed By: vishrav Differential Revision: D26579091 fbshipit-source-id: 20d57dc19335a3006cd5fa6d1a3d5e878b105874 --- fairseq/data/data_utils.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/fairseq/data/data_utils.py b/fairseq/data/data_utils.py index 3042358f2f..6f7561afbe 100644 --- a/fairseq/data/data_utils.py +++ b/fairseq/data/data_utils.py @@ -17,6 +17,8 @@ import numpy as np import torch +from fairseq.file_io import PathManager + logger = logging.getLogger(__name__) @@ -24,7 +26,7 @@ def infer_language_pair(path): """Infer language pair from filename: .-.(...).idx""" src, dst = None, None - for filename in os.listdir(path): + for filename in PathManager.ls(path): parts = filename.split(".") if len(parts) >= 3 and len(parts[1].split("-")) == 2: return parts[1].split("-") From ab560669cd9baaa4009e1fd01c970f8ffccd1ee0 Mon Sep 17 00:00:00 2001 From: freewym Date: Mon, 22 Feb 2021 15:36:56 -0800 Subject: [PATCH 36/82] Fixes circular import as complained by python (#3257) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: # Before submitting - [ ] Was this discussed/approved via a Github issue? (no need for typos, doc improvements) - [x] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)? - [ ] Did you make sure to update the docs? - [x] Did you write any new necessary tests? ## What does this PR do? fixes circular import as complained by python ## PR review Anyone in the community is free to review the PR once the tests have passed. If we didn't discuss your PR in Github issues there's a high chance it will not be merged. ## Did you have fun? Make sure you had fun coding � Pull Request resolved: https://github.com/pytorch/fairseq/pull/3257 Reviewed By: jmp84 Differential Revision: D26587382 Pulled By: myleott fbshipit-source-id: a8a6e7bee4dcfa6baf934c257958b7d7592205c8 --- .../models/speech_to_text/convtransformer_simul_trans.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/fairseq/models/speech_to_text/convtransformer_simul_trans.py b/fairseq/models/speech_to_text/convtransformer_simul_trans.py index e5dd771e03..7e77330a0c 100644 --- a/fairseq/models/speech_to_text/convtransformer_simul_trans.py +++ b/fairseq/models/speech_to_text/convtransformer_simul_trans.py @@ -5,9 +5,6 @@ # the root directory of this source tree. An additional grant of patent rights # can be found in the PATENTS file in the same directory. -from examples.simultaneous_translation.models.transformer_monotonic_attention import ( - TransformerMonotonicDecoder, -) from fairseq import checkpoint_utils from fairseq.models import ( register_model, @@ -33,6 +30,10 @@ def add_args(parser): def build_decoder(cls, args, task, embed_tokens): tgt_dict = task.tgt_dict + from examples.simultaneous_translation.models.transformer_monotonic_attention import ( + TransformerMonotonicDecoder, + ) + decoder = TransformerMonotonicDecoder(args, tgt_dict, embed_tokens) if getattr(args, "load_pretrained_decoder_from", None): From c3d2beec96bd609f87d8da14cc2dffdbbd843b54 Mon Sep 17 00:00:00 2001 From: Alex Xiao Date: Tue, 23 Feb 2021 23:32:40 -0800 Subject: [PATCH 37/82] efficient batch level sampling Summary: Batch level sampling (each batch comes from a dataset sampled from some distribution) is useful in cases where we have a criterion that makes this assumption or a unique collator per dataset. However, the current implementation in fairseq `MultiCorpusSampledDataset` is inefficient, because it packs batches by assuming the size of item i is `max(dataset.size(i % len(dataset)) for dataset in datasets)`, which often significantly overestimates the actual sampled item's size, especially with many datasets. We can make this more efficient by modifying `MultiCorpusDataset`, which can do efficient batch sampling by: 1. Every epoch, sampling the indices/dataset to train on. 2. When creating batches, create per-dataset batches and merge them together Reviewed By: jay-mahadeokar Differential Revision: D26601515 fbshipit-source-id: a3273f88d86d7922f9ba004e7324e909ecc6ecf7 --- fairseq/data/multi_corpus_dataset.py | 49 +++++++++++++++++++++++++--- 1 file changed, 44 insertions(+), 5 deletions(-) diff --git a/fairseq/data/multi_corpus_dataset.py b/fairseq/data/multi_corpus_dataset.py index 7207174bf3..6563713489 100644 --- a/fairseq/data/multi_corpus_dataset.py +++ b/fairseq/data/multi_corpus_dataset.py @@ -35,6 +35,7 @@ class MultiCorpusDataset(FairseqDataset): corresponding dataset seed: random seed for sampling the datsets sort_indices: if true, will sort the ordered indices by size + batch_sample: if true, will ensure each batch is from a single dataset """ def __init__( @@ -43,6 +44,7 @@ def __init__( distribution: List[float], seed: int, sort_indices: bool = False, + batch_sample: bool = False, ): super().__init__() assert isinstance(datasets, OrderedDict) @@ -51,6 +53,7 @@ def __init__( self.distribution = distribution self.seed = seed self.sort_indices = sort_indices + self.batch_sample = batch_sample # Avoid repeated conversions to list later self.dataset_list = list(datasets.values()) @@ -80,6 +83,7 @@ def ordered_indices(self): ] if self.sort_indices: sampled_indices.sort(key=lambda i: self.num_tokens(i)) + return np.array(sampled_indices, dtype=np.int64) def _sample(self, indices, counters): @@ -125,22 +129,26 @@ def __len__(self): return self.total_num_instances def __getitem__(self, index): - index, key = self._map_index(index) + new_index, key = self._map_index(index) try: - return self.datasets[key][index] + item = self.datasets[key][new_index] + item["full_id"] = index + return item except Exception as e: e.args = (f"Error from {key} dataset", *e.args) raise def collater(self, samples): """ - Since we enforce all datsets to be the same, collating is just - picking the first one and doing collate. + If we are doing batch sampling, then pick the right collater to use. + + Otherwise we assume all collaters are the same. """ if len(samples) == 0: return None + _, key = self._map_index(samples[0]["full_id"]) - return list(self.datasets.values())[0].collater(samples) + return self.datasets[key].collater(samples) def num_tokens(self, index: int): index, key = self._map_index(index) @@ -168,3 +176,34 @@ def supports_fetch_outside_dataloader(self): self.datasets[key].supports_fetch_outside_dataloader for key in self.datasets ) + + def batch_by_size( + self, + indices, + max_tokens=None, + max_sentences=None, + required_batch_size_multiple=1, + ): + if not self.batch_sample: + return super().batch_by_size( + indices, max_tokens, max_sentences, required_batch_size_multiple + ) + + dataset_indices = {key: [] for key in self.datasets} + for i in indices: + _, key = self._map_index(i) + dataset_indices[key].append(i) + + batches = [] + for key in dataset_indices: + cur_batches = super().batch_by_size( + np.array(dataset_indices[key], dtype=np.int64), + max_tokens, + max_sentences, + required_batch_size_multiple, + ) + logger.info(f"Created {len(cur_batches)} batches for dataset {key}") + batches += cur_batches + + # Assume shuffling is handled in fairseq/data/iterators.py + return batches From 55e48f18fee765fc4d528650570b8af0133ac074 Mon Sep 17 00:00:00 2001 From: Sam Shleifer Date: Wed, 24 Feb 2021 11:22:27 -0800 Subject: [PATCH 38/82] downcast indices in TokenBlockDataset (#1647) Summary: ### Measurements TLDR: This saves ~8% CPU RAM for training tiny model on medium sized dataset (11GB on disk) Command below: ``` +---------------------+----------------+---------+--------+ | fname | cpu_mem_used | wps | ppl | +=====================+================+=========+========+ +---------------------+----------------+---------+--------+ | branch_nw8_2gpu.log | 25.41 | 54721 | 429.1 | +---------------------+----------------+---------+--------+ +---------------------+----------------+---------+--------+ | master_nw8_2gpu.log | 27.53 | 52833.1 | 429.1 | +---------------------+----------------+---------+--------+ ``` ### Command ``` base_cmd () { dd=$1 shift fairseq-train --fp16 $dd \ --task language_modeling \ --arch transformer_lm_gpt2_tiny \ --sample-break-mode complete --tokens-per-sample 512 \ --optimizer adam --clip-norm 0.0 --lr 0.0005 \ --batch-size 1 \ --max-update 200 --max-epoch 1 \ --log-format simple --log-interval 100 \ --restore-file x.pt --no-save \ --skip-invalid-size-inputs-valid-test --disable-validation $@ } CUDA_VISIBLE_DEVICES=0,1 base_cmd /private/home/sshleifer/data-bin/stories_mmap --num-workers 8 ``` Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1647 Reviewed By: myleott Differential Revision: D26628861 Pulled By: sshleifer fbshipit-source-id: 142afe0358d1c4cae448828ba811b211406509d7 --- fairseq/data/indexed_dataset.py | 37 +++++++++++++++++++---------- fairseq/data/token_block_dataset.py | 11 +++++---- 2 files changed, 32 insertions(+), 16 deletions(-) diff --git a/fairseq/data/indexed_dataset.py b/fairseq/data/indexed_dataset.py index a821417321..066f4dcd4f 100644 --- a/fairseq/data/indexed_dataset.py +++ b/fairseq/data/indexed_dataset.py @@ -15,12 +15,21 @@ from . import FairseqDataset +from typing import Union -def __best_fitting_dtype(vocab_size=None): - if vocab_size is not None and vocab_size < 65500: + +def best_fitting_uint_dtype( + max_int_to_represent, +) -> Union[np.uint16, np.uint32, np.uint64]: + + if max_int_to_represent is None: + return np.uint32 # Safe guess + elif max_int_to_represent < 65500: return np.uint16 + elif max_int_to_represent < 4294967295: + return np.uint32 else: - return np.int32 + return np.uint64 def get_available_dataset_impl(): @@ -48,7 +57,7 @@ def infer_dataset_impl(path): def make_builder(out_file, impl, vocab_size=None): if impl == "mmap": return MMapIndexedDatasetBuilder( - out_file, dtype=__best_fitting_dtype(vocab_size) + out_file, dtype=best_fitting_uint_dtype(vocab_size) ) elif impl == "fasta": raise NotImplementedError @@ -92,7 +101,7 @@ def write_longs(f, a): f.write(np.array(a, dtype=np.int64)) -dtypes = { +_code_to_dtype = { 1: np.uint8, 2: np.int8, 3: np.int16, @@ -101,12 +110,14 @@ def write_longs(f, a): 6: np.float, 7: np.double, 8: np.uint16, + 9: np.uint32, + 10: np.uint64, } -def code(dtype): - for k in dtypes.keys(): - if dtypes[k] == dtype: +def _dtype_header_code(dtype) -> int: + for k in _code_to_dtype.keys(): + if _code_to_dtype[k] == dtype: return k raise ValueError(dtype) @@ -141,7 +152,7 @@ def read_index(self, path): version = f.read(8) assert struct.unpack(" Date: Wed, 24 Feb 2021 11:25:41 -0800 Subject: [PATCH 39/82] make LanguageModelingTask 1% simpler (#1641) Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1641 Reviewed By: myleott Differential Revision: D26607648 Pulled By: sshleifer fbshipit-source-id: 9d7f9d7a0825e3124c181b651a126842e5de6109 --- fairseq/tasks/language_modeling.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/fairseq/tasks/language_modeling.py b/fairseq/tasks/language_modeling.py index 4a44d967b3..579bf69785 100644 --- a/fairseq/tasks/language_modeling.py +++ b/fairseq/tasks/language_modeling.py @@ -184,7 +184,9 @@ def build_model(self, args): return model - def load_dataset(self, split, epoch=1, combine=False, **kwargs): + def load_dataset( + self, split: str, epoch=1, combine=False, **kwargs + ) -> MonolingualDataset: """Load a given dataset split. Args: @@ -228,7 +230,7 @@ def load_dataset(self, split, epoch=1, combine=False, **kwargs): and self.args.sample_break_mode != "none" ) - self.datasets[split] = self._initialize_dataset( + self.datasets[split] = MonolingualDataset( dataset=dataset, sizes=dataset.sizes, src_vocab=self.dictionary, @@ -239,9 +241,6 @@ def load_dataset(self, split, epoch=1, combine=False, **kwargs): add_bos_token=self.args.add_bos_token, ) - def _initialize_dataset(self, **kwargs): - return MonolingualDataset(**kwargs) - def build_dataset_for_inference(self, src_tokens, src_lengths, **kwargs): """ Generate batches for inference. We prepend an eos token to src_tokens From 52daa1b29b35c93ffb950e56507c9c1d17aa2369 Mon Sep 17 00:00:00 2001 From: Sam Shleifer Date: Wed, 24 Feb 2021 14:21:24 -0800 Subject: [PATCH 40/82] move code to .py files, document usage (#1637) Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1637 Test Plan: ```bash python examples/bart/summarize.py --model-dir pytorch/fairseq --model-file bart.large.cnn --src $HOME/data-bin/cnn_dm/test.source --n 12 --out hub_hypo.txt python examples/bart/summarize.py \ --model-dir pytorch/fairseq \ --model-file bart.large.cnn \ --src cnn_dm/test.source \ --out cnn_dm/test.hypo --xsum-kwargs ``` Reviewed By: ngoyal2707 Differential Revision: D26581703 Pulled By: sshleifer fbshipit-source-id: 80eb28012f7770eee01ed50a1163c5a2c5cc6d37 --- examples/bart/README.md | 47 +++++------- examples/bart/README.summarization.md | 55 +++++--------- examples/bart/summarize.py | 100 ++++++++++++++++++++++++++ fairseq/sequence_generator.py | 6 +- 4 files changed, 136 insertions(+), 72 deletions(-) create mode 100644 examples/bart/summarize.py diff --git a/examples/bart/README.md b/examples/bart/README.md index 013a809be6..4050a724ee 100644 --- a/examples/bart/README.md +++ b/examples/bart/README.md @@ -179,38 +179,23 @@ with open('glue_data/MNLI/dev_matched.tsv') as fin: ``` #### Evaluating the `bart.large.cnn` model: -Follow instructions [here](https://github.com/abisee/cnn-dailymail) to download and process into data-files such that `test.source` and `test.target` has one line for each non-tokenized sample. +- Follow instructions [here](https://github.com/abisee/cnn-dailymail) to download and process into data-files such that `test.source` and `test.target` has one line for each non-tokenized sample. +- For simpler preprocessing, you can also `wget https://cdn-datasets.huggingface.co/summarization/cnn_dm_v2.tgz`, although there is no guarantee of identical scores +- `huggingface/transformers` has a simpler interface that supports [single-gpu](https://github.com/huggingface/transformers/blob/master/examples/legacy/seq2seq/run_eval.py) and [multi-gpu](https://github.com/huggingface/transformers/blob/master/examples/legacy/seq2seq/run_distributed_eval.py) beam search. + In `huggingface/transformers`, the BART models' paths are `facebook/bart-large-cnn` and `facebook/bart-large-xsum`. -```python -bart = torch.hub.load('pytorch/fairseq', 'bart.large.cnn') -bart.cuda() -bart.eval() -bart.half() -count = 1 -bsz = 32 -with open('test.source') as source, open('test.hypo', 'w') as fout: - sline = source.readline().strip() - slines = [sline] - for sline in source: - if count % bsz == 0: - with torch.no_grad(): - hypotheses_batch = bart.sample(slines, beam=4, lenpen=2.0, max_len_b=140, min_len=55, no_repeat_ngram_size=3) - - for hypothesis in hypotheses_batch: - fout.write(hypothesis + '\n') - fout.flush() - slines = [] - - slines.append(sline.strip()) - count += 1 - if slines != []: - hypotheses_batch = bart.sample(slines, beam=4, lenpen=2.0, max_len_b=140, min_len=55, no_repeat_ngram_size=3) - for hypothesis in hypotheses_batch: - fout.write(hypothesis + '\n') - fout.flush() -``` - -Install `files2rouge` from [here](https://github.com/pltrdy/files2rouge). +In `fairseq`, summaries can be generated using: + +```bash +cp data-bin/cnn_dm/dict.source.txt checkpoints/ +python examples/bart/summarize.py \ + --model-dir pytorch/fairseq \ + --model-file bart.large.cnn \ + --src cnn_dm/test.source \ + --out cnn_dm/test.hypo +``` + +For calculating rouge, install `files2rouge` from [here](https://github.com/pltrdy/files2rouge). ```bash export CLASSPATH=/path/to/stanford-corenlp-full-2016-10-31/stanford-corenlp-3.7.0.jar diff --git a/examples/bart/README.summarization.md b/examples/bart/README.summarization.md index d7fecc9ce6..8727584f2b 100644 --- a/examples/bart/README.summarization.md +++ b/examples/bart/README.summarization.md @@ -80,42 +80,23 @@ Expected training time is about `5 hours`. Training time can be reduced with dis Use TOTAL_NUM_UPDATES=15000 UPDATE_FREQ=2 for Xsum task ### Inference for CNN-DM test data using above trained checkpoint. -After training the model as mentioned in previous step, you can perform inference with checkpoints in `checkpoints/` directory using following python code snippet: +After training the model as mentioned in previous step, you can perform inference with checkpoints in `checkpoints/` directory using `eval_cnn.py`, for example -```python -import torch -from fairseq.models.bart import BARTModel - -bart = BARTModel.from_pretrained( - 'checkpoints/', - checkpoint_file='checkpoint_best.pt', - data_name_or_path='cnn_dm-bin' -) - -bart.cuda() -bart.eval() -bart.half() -count = 1 -bsz = 32 -with open('cnn_dm/test.source') as source, open('cnn_dm/test.hypo', 'w') as fout: - sline = source.readline().strip() - slines = [sline] - for sline in source: - if count % bsz == 0: - with torch.no_grad(): - hypotheses_batch = bart.sample(slines, beam=4, lenpen=2.0, max_len_b=140, min_len=55, no_repeat_ngram_size=3) - - for hypothesis in hypotheses_batch: - fout.write(hypothesis + '\n') - fout.flush() - slines = [] - - slines.append(sline.strip()) - count += 1 - if slines != []: - hypotheses_batch = bart.sample(slines, beam=4, lenpen=2.0, max_len_b=140, min_len=55, no_repeat_ngram_size=3) - for hypothesis in hypotheses_batch: - fout.write(hypothesis + '\n') - fout.flush() +```bash +cp data-bin/cnn_dm/dict.source.txt checkpoints/ +python examples/bart/summarize.py \ + --model-dir checkpoints \ + --model-file checkpoint_best.pt \ + --src cnn_dm/test.source \ + --out cnn_dm/test.hypo +``` +For XSUM, which uses beam=6, lenpen=1.0, max_len_b=60, min_len=10: +```bash +cp data-bin/cnn_dm/dict.source.txt checkpoints/ +python examples/bart/summarize.py \ + --model-dir checkpoints \ + --model-file checkpoint_best.pt \ + --src cnn_dm/test.source \ + --out cnn_dm/test.hypo \ + --xsum-kwargs ``` -Use beam=6, lenpen=1.0, max_len_b=60, min_len=10 for Xsum Generation diff --git a/examples/bart/summarize.py b/examples/bart/summarize.py new file mode 100644 index 0000000000..04435f80e3 --- /dev/null +++ b/examples/bart/summarize.py @@ -0,0 +1,100 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from fairseq.models.bart import BARTModel +import argparse + +XSUM_KWARGS = dict(beam=6, lenpen=1.0, max_len_b=60, min_len=10, no_repeat_ngram_size=3) +CNN_KWARGS = dict(beam=4, lenpen=2.0, max_len_b=140, min_len=55, no_repeat_ngram_size=3) + + +@torch.no_grad() +def generate(bart, infile, outfile="bart_hypo.txt", bsz=32, n_obs=None, **eval_kwargs): + count = 1 + + # if n_obs is not None: bsz = min(bsz, n_obs) + + with open(infile) as source, open(outfile, "w") as fout: + sline = source.readline().strip() + slines = [sline] + for sline in source: + if n_obs is not None and count > n_obs: + break + if count % bsz == 0: + hypotheses_batch = bart.sample(slines, **eval_kwargs) + for hypothesis in hypotheses_batch: + fout.write(hypothesis + "\n") + fout.flush() + slines = [] + + slines.append(sline.strip()) + count += 1 + + if slines != []: + hypotheses_batch = bart.sample(slines, **eval_kwargs) + for hypothesis in hypotheses_batch: + fout.write(hypothesis + "\n") + fout.flush() + + +def main(): + """ + Usage:: + + python examples/bart/summarize.py \ + --model-dir $HOME/bart.large.cnn \ + --model-file model.pt \ + --src $HOME/data-bin/cnn_dm/test.source + """ + parser = argparse.ArgumentParser() + parser.add_argument( + "--model-dir", + required=True, + type=str, + default="bart.large.cnn/", + help="path containing model file and src_dict.txt", + ) + parser.add_argument( + "--model-file", + default="checkpoint_best.pt", + help="where in model_dir are weights saved", + ) + parser.add_argument( + "--src", default="test.source", help="text to summarize", type=str + ) + parser.add_argument( + "--out", default="test.hypo", help="where to save summaries", type=str + ) + parser.add_argument("--bsz", default=32, help="where to save summaries", type=int) + parser.add_argument( + "--n", default=None, help="how many examples to summarize", type=int + ) + parser.add_argument( + "--xsum-kwargs", + action="store_true", + default=False, + help="if true use XSUM_KWARGS else CNN_KWARGS", + ) + args = parser.parse_args() + eval_kwargs = XSUM_KWARGS if args.xsum_kwargs else CNN_KWARGS + if args.model_dir == "pytorch/fairseq": + bart = torch.hub.load("pytorch/fairseq", args.model_file) + else: + bart = BARTModel.from_pretrained( + args.model_dir, + checkpoint_file=args.model_file, + data_name_or_path=args.model_dir, + ) + bart = bart.eval() + if torch.cuda.is_available(): + bart = bart.cuda().half() + generate( + bart, args.src, bsz=args.bsz, n_obs=args.n, outfile=args.out, **eval_kwargs + ) + + +if __name__ == "__main__": + main() diff --git a/fairseq/sequence_generator.py b/fairseq/sequence_generator.py index 117c6116fb..2574ab13f0 100644 --- a/fairseq/sequence_generator.py +++ b/fairseq/sequence_generator.py @@ -214,7 +214,7 @@ def _generate( raise Exception("expected src_tokens or source in net input") # bsz: total number of sentences in beam - # Note that src_tokens may have more than 2 dimenions (i.e. audio features) + # Note that src_tokens may have more than 2 dimensions (i.e. audio features) bsz, src_len = src_tokens.size()[:2] beam_size = self.beam_size @@ -376,9 +376,7 @@ def _generate( self.search.set_src_lengths(src_lengths) if self.repeat_ngram_blocker is not None: - lprobs = self.repeat_ngram_blocker( - tokens, lprobs, bsz, beam_size, step - ) + lprobs = self.repeat_ngram_blocker(tokens, lprobs, bsz, beam_size, step) # Shape: (batch, cand_size) cand_scores, cand_indices, cand_beams = self.search.step( From fb3fadbb159d8af6d83a5680674d20f7b7635766 Mon Sep 17 00:00:00 2001 From: Myle Ott Date: Wed, 24 Feb 2021 15:41:02 -0800 Subject: [PATCH 41/82] Set DynamicLossScaler class defaults to match CLI defaults (#1649) Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1649 Reviewed By: stephenroller Differential Revision: D26639303 Pulled By: myleott fbshipit-source-id: 7def925cd7885cfe85d542464316cbc0f2ba6d2c --- fairseq/optim/dynamic_loss_scaler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fairseq/optim/dynamic_loss_scaler.py b/fairseq/optim/dynamic_loss_scaler.py index c5da604220..43f9be37b9 100644 --- a/fairseq/optim/dynamic_loss_scaler.py +++ b/fairseq/optim/dynamic_loss_scaler.py @@ -10,7 +10,7 @@ def __init__( init_scale=2.0 ** 15, scale_factor=2.0, scale_window=2000, - tolerance=0.05, + tolerance=0.0, threshold=None, min_loss_scale=1e-4, ): From b8651bc984413e7e45f44294dffcc85692ba89c1 Mon Sep 17 00:00:00 2001 From: Weiyi Zheng Date: Wed, 24 Feb 2021 15:48:38 -0800 Subject: [PATCH 42/82] actually checking gradnorm consistency Summary: D24849271 (https://github.com/pytorch/fairseq/commit/3c5647cebf454c07b52a0fb899c920789381ebda) fixed finite check, but the 'or' condition means as long as all gradients are finite, the check will pass. This diff adds back the consistency check, the norm can't differ from each other much. Reviewed By: myleott Differential Revision: D26640459 fbshipit-source-id: 3e23e13841372aa04461dcde245b893715480c3c --- fairseq/trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fairseq/trainer.py b/fairseq/trainer.py index 891155f162..680a7ee953 100644 --- a/fairseq/trainer.py +++ b/fairseq/trainer.py @@ -1113,7 +1113,7 @@ def is_consistent(tensor): max_abs_diff = torch.max(torch.abs(tensor - tensor[0])) return ( torch.isfinite(tensor).all() - or (max_abs_diff / (tensor[0] + 1e-6) < 1e-6).all() + and (max_abs_diff / (tensor[0] + 1e-6) < 1e-6).all() ) if not is_consistent(self._grad_norm_buf): From d3890e593398c485f6593ab8512ac51d37dedc9c Mon Sep 17 00:00:00 2001 From: Sravya Popuri Date: Wed, 24 Feb 2021 22:55:37 -0800 Subject: [PATCH 43/82] Add HiveScorer to read data from hive and EverstoreAudioInstance to load audio from everstore Summary: This diff - Refactors utils/agent_finder.py to reduce the complexity of find_agent_cls function - Refactors cli.py and server.py to remove unnecessary argument parser function calls - Adds fb_hive_scorer.py with HiveScorer to read data from hive and process everstore handles - Adds fb_options.py to add necessary arguments for HiveScorer - Updates other parts of the code to include the new scorer Reviewed By: jmp84 Differential Revision: D26575148 fbshipit-source-id: ae6e12d2adf5f393f807d5238f0d78a2f64a77a3 --- .../simultaneous_translation/agents/fairseq_simul_st_agent.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/speech_to_text/simultaneous_translation/agents/fairseq_simul_st_agent.py b/examples/speech_to_text/simultaneous_translation/agents/fairseq_simul_st_agent.py index 32cd0a1f61..5793609095 100644 --- a/examples/speech_to_text/simultaneous_translation/agents/fairseq_simul_st_agent.py +++ b/examples/speech_to_text/simultaneous_translation/agents/fairseq_simul_st_agent.py @@ -128,7 +128,7 @@ def __init__(self, args): self.load_model_vocab(args) with open(args.config, "r") as f: - config = yaml.load(f) + config = yaml.load(f, Loader=yaml.BaseLoader) if "global_cmvn" in config: args.global_cmvn = np.load(config["global_cmvn"]["stats_npz_path"]) From f569c024ae6ee3e8c37c3b9dca975a3df50f7a03 Mon Sep 17 00:00:00 2001 From: Xutai Ma Date: Thu, 25 Feb 2021 22:33:48 -0800 Subject: [PATCH 44/82] Relocate simultaneous translation code (#1639) Summary: Relocate simultaneous translation code from example/simultaneous_translation to fairseq/model/simultaneous_translation, only keep the documents Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1639 Reviewed By: jmp84 Differential Revision: D26599346 Pulled By: xutaima fbshipit-source-id: 4f708d172696a430bd4e7b14871f5c8862a20489 --- examples/simultaneous_translation/__init__.py | 2 +- .../criterions/__init__.py | 15 --------------- .../models}/convtransformer_simul_trans.py | 10 +++++++++- fairseq/models/speech_to_text/__init__.py | 1 - fairseq/models/speech_to_text/convtransformer.py | 8 ++------ 5 files changed, 12 insertions(+), 24 deletions(-) delete mode 100644 examples/simultaneous_translation/criterions/__init__.py rename {fairseq/models/speech_to_text => examples/simultaneous_translation/models}/convtransformer_simul_trans.py (83%) diff --git a/examples/simultaneous_translation/__init__.py b/examples/simultaneous_translation/__init__.py index 446fc86c8a..5835316ba9 100644 --- a/examples/simultaneous_translation/__init__.py +++ b/examples/simultaneous_translation/__init__.py @@ -3,4 +3,4 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -from . import criterions, eval, models # noqa +from . import models # noqa diff --git a/examples/simultaneous_translation/criterions/__init__.py b/examples/simultaneous_translation/criterions/__init__.py deleted file mode 100644 index 08791bfff3..0000000000 --- a/examples/simultaneous_translation/criterions/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. -# -# This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree. - -import importlib -import os - - -for file in os.listdir(os.path.dirname(__file__)): - if file.endswith(".py") and not file.startswith("_"): - criterion_name = file[: file.find(".py")] - importlib.import_module( - "examples.simultaneous_translation.criterions." + criterion_name - ) diff --git a/fairseq/models/speech_to_text/convtransformer_simul_trans.py b/examples/simultaneous_translation/models/convtransformer_simul_trans.py similarity index 83% rename from fairseq/models/speech_to_text/convtransformer_simul_trans.py rename to examples/simultaneous_translation/models/convtransformer_simul_trans.py index 7e77330a0c..84ba4d0d3f 100644 --- a/fairseq/models/speech_to_text/convtransformer_simul_trans.py +++ b/examples/simultaneous_translation/models/convtransformer_simul_trans.py @@ -11,11 +11,19 @@ register_model_architecture, ) -from .convtransformer import ConvTransformerModel, convtransformer_espnet +from fairseq.models.speech_to_text import ConvTransformerModel, convtransformer_espnet @register_model("convtransformer_simul_trans") class SimulConvTransformerModel(ConvTransformerModel): + """ + Implementation of the paper: + + SimulMT to SimulST: Adapting Simultaneous Text Translation to + End-to-End Simultaneous Speech Translation + + https://www.aclweb.org/anthology/2020.aacl-main.58.pdf + """ @staticmethod def add_args(parser): super(SimulConvTransformerModel, SimulConvTransformerModel).add_args(parser) diff --git a/fairseq/models/speech_to_text/__init__.py b/fairseq/models/speech_to_text/__init__.py index 28e3bb720f..c6ae9b17ba 100644 --- a/fairseq/models/speech_to_text/__init__.py +++ b/fairseq/models/speech_to_text/__init__.py @@ -5,5 +5,4 @@ from .berard import * # noqa from .convtransformer import * # noqa -from .convtransformer_simul_trans import * # noqa from .s2t_transformer import * # noqa diff --git a/fairseq/models/speech_to_text/convtransformer.py b/fairseq/models/speech_to_text/convtransformer.py index 622b5e6df8..a4cbbcdeeb 100644 --- a/fairseq/models/speech_to_text/convtransformer.py +++ b/fairseq/models/speech_to_text/convtransformer.py @@ -7,9 +7,7 @@ import torch import torch.nn as nn import torch.nn.functional as F -from examples.simultaneous_translation.utils.data_utils import ( - lengths_to_encoder_padding_mask, -) +from fairseq.data.data_utils import lengths_to_padding_mask from fairseq import checkpoint_utils, utils from fairseq.models import ( FairseqEncoder, @@ -311,9 +309,7 @@ def forward(self, src_tokens, src_lengths): x.size(0) * src_lengths.new_ones([src_lengths.size(0)]).long() ) - encoder_padding_mask, _ = lengths_to_encoder_padding_mask( - input_lengths, batch_first=True - ) + encoder_padding_mask = lengths_to_padding_mask(input_lengths) positions = self.embed_positions(encoder_padding_mask).transpose(0, 1) x += positions From 4f881a760e1cd7e11ecce2332b6ee9a435f233a5 Mon Sep 17 00:00:00 2001 From: Sam Shleifer Date: Fri, 26 Feb 2021 20:59:22 -0800 Subject: [PATCH 45/82] TokenBlockDataset np type promotion issue (#1658) Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1658 Reviewed By: jxmsML Differential Revision: D26701840 Pulled By: sshleifer fbshipit-source-id: 90d631c3cd775ab847366fe7a05136c29d90cd63 --- fairseq/data/indexed_dataset.py | 10 ++++++---- fairseq/data/token_block_dataset.py | 16 ++++++++++------ fairseq/models/transformer_lm.py | 12 ++++++++++++ tests/test_token_block_dataset.py | 13 +++++++++++++ 4 files changed, 41 insertions(+), 10 deletions(-) diff --git a/fairseq/data/indexed_dataset.py b/fairseq/data/indexed_dataset.py index 066f4dcd4f..802e37a7ff 100644 --- a/fairseq/data/indexed_dataset.py +++ b/fairseq/data/indexed_dataset.py @@ -18,9 +18,9 @@ from typing import Union -def best_fitting_uint_dtype( +def best_fitting_int_dtype( max_int_to_represent, -) -> Union[np.uint16, np.uint32, np.uint64]: +) -> Union[np.uint16, np.uint32, np.int64]: if max_int_to_represent is None: return np.uint32 # Safe guess @@ -29,7 +29,9 @@ def best_fitting_uint_dtype( elif max_int_to_represent < 4294967295: return np.uint32 else: - return np.uint64 + return np.int64 + # we avoid np.uint64 because it doesn't save space and its type promotion behaves unexpectedly + # https://github.com/numpy/numpy/issues/5745 def get_available_dataset_impl(): @@ -57,7 +59,7 @@ def infer_dataset_impl(path): def make_builder(out_file, impl, vocab_size=None): if impl == "mmap": return MMapIndexedDatasetBuilder( - out_file, dtype=best_fitting_uint_dtype(vocab_size) + out_file, dtype=best_fitting_int_dtype(vocab_size) ) elif impl == "fasta": raise NotImplementedError diff --git a/fairseq/data/token_block_dataset.py b/fairseq/data/token_block_dataset.py index 038f1c81d7..4617466234 100644 --- a/fairseq/data/token_block_dataset.py +++ b/fairseq/data/token_block_dataset.py @@ -6,7 +6,8 @@ import numpy as np import torch from fairseq.data import FairseqDataset, plasma_utils -from fairseq.data.indexed_dataset import best_fitting_uint_dtype +from fairseq.data.indexed_dataset import best_fitting_int_dtype + class TokenBlockDataset(FairseqDataset): """Break a Dataset of tokens into blocks. @@ -95,15 +96,18 @@ def __init__( ) else: block_to_dataset_index = _get_block_to_dataset_index_fast( - sizes, - slice_indices, + sizes, slice_indices, ) size_dtype = np.uint16 if block_size < 65535 else np.uint32 - slice_indices_dtype = best_fitting_uint_dtype(slice_indices[-1].max()) + slice_indices_dtype = best_fitting_int_dtype(slice_indices[-1].max()) - self._slice_indices = plasma_utils.PlasmaArray(slice_indices.astype(slice_indices_dtype)) + self._slice_indices = plasma_utils.PlasmaArray( + slice_indices.astype(slice_indices_dtype) + ) self._sizes = plasma_utils.PlasmaArray(self._sizes.astype(size_dtype)) - self._block_to_dataset_index = plasma_utils.PlasmaArray(block_to_dataset_index.astype(slice_indices_dtype)) + self._block_to_dataset_index = plasma_utils.PlasmaArray( + block_to_dataset_index.astype(slice_indices_dtype) + ) @property def slice_indices(self): diff --git a/fairseq/models/transformer_lm.py b/fairseq/models/transformer_lm.py index edf62b12b3..f12470d033 100644 --- a/fairseq/models/transformer_lm.py +++ b/fairseq/models/transformer_lm.py @@ -394,6 +394,18 @@ def transformer_lm_gpt2_small(args): base_lm_architecture(args) +@register_model_architecture("transformer_lm", "transformer_lm_gpt2_tiny") +def transformer_lm_gpt2_tiny(args): + args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 64) + args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 64) + args.decoder_layers = getattr(args, "decoder_layers", 2) + args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 1) + args.dropout = getattr(args, "dropout", 0.1) + args.attention_dropout = getattr(args, "attention_dropout", 0.1) + args.activation_fn = getattr(args, "activation_fn", "gelu") + base_lm_architecture(args) + + @register_model_architecture("transformer_lm", "transformer_lm_gpt2_medium") def transformer_lm_gpt2_medium(args): args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 1280) diff --git a/tests/test_token_block_dataset.py b/tests/test_token_block_dataset.py index ea315b4e67..c4d7b76dcd 100644 --- a/tests/test_token_block_dataset.py +++ b/tests/test_token_block_dataset.py @@ -74,6 +74,19 @@ def test_complete_break_mode(self): self.assertEqual(ds[1].tolist(), [5, 1, 1]) self.assertEqual(ds[2].tolist(), [6, 1]) + def test_4billion_tokens(self): + """Regression test for numpy type promotion issue https://github.com/numpy/numpy/issues/5745""" + data = [torch.tensor(list(range(10000)), dtype=torch.long)] * 430000 + ds = self._build_dataset( + data, block_size=6, pad=0, eos=1, break_mode="complete" + ) + ds[-1] # __getitem__ works + start, end = ds.slice_indices[-1] + assert end > 4294967295 # data must be sufficiently large to overflow uint32 + assert not isinstance( + end + 1, float + ) # this would also raise, since np.uint64(1) + 1 => 2.0 + if __name__ == "__main__": unittest.main() From 5354aa3a6ec80092cc7bb9aecfad7077bb50b47e Mon Sep 17 00:00:00 2001 From: Sam Shleifer Date: Sun, 28 Feb 2021 12:44:23 -0800 Subject: [PATCH 46/82] github CI install pyarrow Reviewed By: myleott Differential Revision: D26643358 fbshipit-source-id: 8d7e1082c6e11f9bbab4b34de078cf05197297a5 --- .github/workflows/build.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 29e5254d33..0af8bad95d 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -39,7 +39,7 @@ jobs: - name: Install optional test requirements run: | - python -m pip install fairscale iopath transformers + python -m pip install fairscale iopath transformers pyarrow - name: Lint with flake8 run: | From e5e8b3fee1e57a7abf35ad1a3ff223a2b7190c65 Mon Sep 17 00:00:00 2001 From: Sam Shleifer Date: Sun, 28 Feb 2021 12:49:20 -0800 Subject: [PATCH 47/82] Fix nearly all unit-test warnings (#1652) Summary: 2 types of warnings fixed: ``` `np.long` is a deprecated alias for `np.compat.long`. Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated since Python 3.3, and in 3.9 it will stop working ``` Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1652 Reviewed By: myleott Differential Revision: D26643344 Pulled By: sshleifer fbshipit-source-id: 960bccc94f299bd8a8c58a87acd80694e9d5c363 --- fairseq/data/language_pair_dataset.py | 12 +++--------- fairseq/data/token_block_dataset.py | 2 +- fairseq/optim/lr_scheduler/cosine_lr_scheduler.py | 2 +- .../lr_scheduler/inverse_square_root_schedule.py | 2 +- 4 files changed, 6 insertions(+), 12 deletions(-) diff --git a/fairseq/data/language_pair_dataset.py b/fairseq/data/language_pair_dataset.py index 8858cec84e..9d36cbd4ce 100644 --- a/fairseq/data/language_pair_dataset.py +++ b/fairseq/data/language_pair_dataset.py @@ -114,10 +114,7 @@ def compute_alignment_weights(alignments): "id": id, "nsentences": len(samples), "ntokens": ntokens, - "net_input": { - "src_tokens": src_tokens, - "src_lengths": src_lengths, - }, + "net_input": {"src_tokens": src_tokens, "src_lengths": src_lengths,}, "target": target, } if prev_output_tokens is not None: @@ -289,7 +286,7 @@ def __init__( # determine bucket sizes using self.num_tokens, which will return # the padded lengths (thanks to BucketPadLengthDataset) - num_tokens = np.vectorize(self.num_tokens, otypes=[np.long]) + num_tokens = np.vectorize(self.num_tokens, otypes=[np.compat.long]) self.bucketed_num_tokens = num_tokens(np.arange(len(self.src))) self.buckets = [ (None, num_tokens) for num_tokens in np.unique(self.bucketed_num_tokens) @@ -470,8 +467,5 @@ def filter_indices_by_size(self, indices, max_sizes): list: list of removed indices """ return data_utils.filter_paired_dataset_indices_by_size( - self.src_sizes, - self.tgt_sizes, - indices, - max_sizes, + self.src_sizes, self.tgt_sizes, indices, max_sizes, ) diff --git a/fairseq/data/token_block_dataset.py b/fairseq/data/token_block_dataset.py index 4617466234..ce0a0d1114 100644 --- a/fairseq/data/token_block_dataset.py +++ b/fairseq/data/token_block_dataset.py @@ -88,7 +88,7 @@ def __init__( [ np.arange(len(sizes)), # starting index in dataset np.zeros( - len(sizes), dtype=np.long + len(sizes), dtype=np.compat.long ), # starting offset within starting index np.arange(len(sizes)), # ending index in dataset ], diff --git a/fairseq/optim/lr_scheduler/cosine_lr_scheduler.py b/fairseq/optim/lr_scheduler/cosine_lr_scheduler.py index 38b57fe54c..51f58359ed 100644 --- a/fairseq/optim/lr_scheduler/cosine_lr_scheduler.py +++ b/fairseq/optim/lr_scheduler/cosine_lr_scheduler.py @@ -4,7 +4,7 @@ # LICENSE file in the root directory of this source tree. import math -from collections import Collection +from collections.abc import Collection from dataclasses import dataclass, field from typing import List diff --git a/fairseq/optim/lr_scheduler/inverse_square_root_schedule.py b/fairseq/optim/lr_scheduler/inverse_square_root_schedule.py index d9321577bb..0f87bb5d7e 100644 --- a/fairseq/optim/lr_scheduler/inverse_square_root_schedule.py +++ b/fairseq/optim/lr_scheduler/inverse_square_root_schedule.py @@ -3,7 +3,7 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -from collections import Collection +from collections.abc import Collection from dataclasses import dataclass, field from typing import List From 39e55139ea05da36e9ab9837c4943f660b79dcbe Mon Sep 17 00:00:00 2001 From: Hiromu Yakura Date: Mon, 1 Mar 2021 12:36:18 -0800 Subject: [PATCH 48/82] Fix the order of constraints in LanguagePairDataset (#3280) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: # Before submitting - [x] Was this discussed/approved via a Github issue? (no need for typos, doc improvements) - [x] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)? - [ ] Did you make sure to update the docs? - [ ] Did you write any new necessary tests? ## What does this PR do? Fixes https://github.com/pytorch/fairseq/issues/3279. This change modifies the output of `echo -e "Ja, wer hat, wenn du willst, Götter gebildet, uns zu ihnen erhoben, sie zu uns herniedergebracht, als der Dichter?\tbard\nZu vollenden ist nicht die Sache des Schülers, es ist genug, wenn er sich übt\tstudent" | python normalize.py | python tok.py | fairseq-interactive --constraints -s de -t en --beam 10 --batch-size 2 --buffer-size 2 --bpe fastbpe --bpe-codes ../../../models/ende30k.fastbpe.code --path ../../../models/wmt19.de-en.ffn8192.pt ../../../models/` as follows. Before: ``` S-0 Ja , wer hat , wenn du will@@ st , Gö@@ tter gebildet , uns zu ihnen erhoben , sie zu uns her@@ nieder@@ gebracht , als der Dich@@ ter ? W-0 1.755 seconds C-0 student H-0 -1.1425577402114868 Yes , who , if you will , has formed go@@ ds , raised us up to them , brought them down to us , but the po@@ et student ? D-0 -1.1425577402114868 Yes , who , if you will , has formed gods , raised us up to them , brought them down to us , but the poet student ? P-0 -1.8768 -0.2214 -0.4671 -1.2521 -0.2101 -0.3053 -1.2077 -0.1496 -1.8780 -1.4195 -0.4071 -0.1347 -0.3726 -1.1306 -0.1665 -1.4588 -0.2837 -0.1722 -0.2330 -0.2840 -0.1806 -0.1432 -0.2263 -0.1395 -0.7261 -1.4593 -0.3639 -0.4030 -0.1083 -18.7577 -0.2396 -0.1837 S-1 Zu voll@@ enden ist nicht die Sache des Sch@@ ül@@ ers , es ist genug , wenn er sich übt W-1 1.755 seconds C-1 b@@ ard H-1 -1.9625756740570068 It is not up to the b@@ ard to complete , it is enough if he practi@@ ses D-1 -1.9625756740570068 It is not up to the bard to complete , it is enough if he practises P-1 -1.2630 -0.3364 -0.1634 -2.7070 -0.1734 -0.2815 -17.3978 -6.0238 -0.4888 -1.7563 -0.8708 -0.6773 -0.2027 -0.2456 -1.6366 -0.2911 -2.0235 -0.1961 -0.5538 ``` After: ``` S-0 Ja , wer hat , wenn du will@@ st , Gö@@ tter gebildet , uns zu ihnen erhoben , sie zu uns her@@ nieder@@ gebracht , als der Dich@@ ter ? W-0 1.740 seconds C-0 b@@ ard H-0 -1.2060465812683105 Yes , who , if you will , formed go@@ ds , raised us up to them , brought them down to us , but the b@@ ard ? D-0 -1.2060465812683105 Yes , who , if you will , formed gods , raised us up to them , brought them down to us , but the bard ? P-0 -1.8768 -0.2214 -0.4671 -1.2521 -0.2101 -0.3053 -1.2077 -0.1496 -2.2551 -0.5702 -0.1331 -0.3940 -1.0268 -0.1750 -1.4635 -0.2821 -0.1725 -0.2404 -0.3575 -0.1833 -0.1441 -0.2250 -0.1419 -0.7020 -1.5215 -0.3700 -16.8578 -2.7290 -0.3405 -0.2060 S-1 Zu voll@@ enden ist nicht die Sache des Sch@@ ül@@ ers , es ist genug , wenn er sich übt W-1 1.740 seconds C-1 student H-1 -0.8064212203025818 It is not up to the student to complete , it is enough if he practi@@ ses D-1 -0.8064212203025818 It is not up to the student to complete , it is enough if he practises P-1 -1.2630 -0.3364 -0.1634 -2.7070 -0.1734 -0.2815 -1.5556 -0.2831 -1.3885 -0.7310 -0.6367 -0.1824 -0.2386 -1.5320 -0.2728 -2.0003 -0.2163 -0.5536 ``` ## PR review Anyone in the community is free to review the PR once the tests have passed. If we didn't discuss your PR in Github issues there's a high chance it will not be merged. ## Did you have fun? Make sure you had fun coding � Pull Request resolved: https://github.com/pytorch/fairseq/pull/3280 Reviewed By: myleott Differential Revision: D26725013 Pulled By: lematt1991 fbshipit-source-id: 2275fcf146cb8cd9ca21f847e10a4dacdee996f9 --- fairseq/data/language_pair_dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fairseq/data/language_pair_dataset.py b/fairseq/data/language_pair_dataset.py index 9d36cbd4ce..ff3e14bf14 100644 --- a/fairseq/data/language_pair_dataset.py +++ b/fairseq/data/language_pair_dataset.py @@ -157,7 +157,7 @@ def compute_alignment_weights(alignments): constraints = torch.zeros((len(samples), max(lens))).long() for i, sample in enumerate(samples): constraints[i, 0 : lens[i]] = samples[i].get("constraints") - batch["constraints"] = constraints + batch["constraints"] = constraints.index_select(0, sort_order) return batch From 1c0439b7dabe62d39c6e7f1c8ebc86311e042b5a Mon Sep 17 00:00:00 2001 From: freewym Date: Mon, 1 Mar 2021 16:21:05 -0800 Subject: [PATCH 49/82] fixes circular imports incurred by a recent commit (#3286) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: # Before submitting - [ ] Was this discussed/approved via a Github issue? (no need for typos, doc improvements) - [x] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)? - [ ] Did you make sure to update the docs? - [x] Did you write any new necessary tests? ## What does this PR do? Fixes circular imports incurred by a recent commit ## PR review Anyone in the community is free to review the PR once the tests have passed. If we didn't discuss your PR in Github issues there's a high chance it will not be merged. ## Did you have fun? Make sure you had fun coding � Pull Request resolved: https://github.com/pytorch/fairseq/pull/3286 Reviewed By: lematt1991 Differential Revision: D26725255 Pulled By: myleott fbshipit-source-id: 5572f733b83bdfadcce3188c0789fc6d70a3bad3 --- fairseq/models/fairseq_model.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/fairseq/models/fairseq_model.py b/fairseq/models/fairseq_model.py index 244cbc0c66..186f3d2464 100644 --- a/fairseq/models/fairseq_model.py +++ b/fairseq/models/fairseq_model.py @@ -14,7 +14,6 @@ import torch.nn as nn import torch.nn.functional as F from fairseq import utils -from fairseq.checkpoint_utils import prune_state_dict from fairseq.data import Dictionary from fairseq.dataclass.utils import ( convert_namespace_to_omegaconf, @@ -111,6 +110,9 @@ def load_state_dict( model_cfg = convert_namespace_to_omegaconf(args).model self.upgrade_state_dict(state_dict) + + from fairseq.checkpoint_utils import prune_state_dict + new_state_dict = prune_state_dict(state_dict, model_cfg) return super().load_state_dict(new_state_dict, strict) @@ -450,6 +452,9 @@ def load_state_dict( model_cfg = convert_namespace_to_omegaconf(args).model self.upgrade_state_dict(state_dict) + + from fairseq.checkpoint_utils import prune_state_dict + new_state_dict = prune_state_dict(state_dict, model_cfg) return super().load_state_dict(new_state_dict, strict) From 3100d0b8e5bb5e61b4d73b9c058389aa2c06784a Mon Sep 17 00:00:00 2001 From: Eric Lou Date: Tue, 2 Mar 2021 09:24:03 -0800 Subject: [PATCH 50/82] ioPath async - opt-in Fairseq integration (#1635) Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1635 **Summary:** Integrate ioPath's async writes feature into Fairseq checkpoint writing. **Details:** - Created new checkpoint config param `--write-checkpoints-asynchronously` with default value `False`. Aliased to `--save-async`. - Added to `PathManager` class in `file_io.py` to include `PathManager.opena(...)` and `PathManager.async_close()`. These new methods use ioPath's async `PathManager`. **Usage:** ``` python train.py --save-async ``` --------- NOTE: **QUESTIONS** 1. In the current implementation, we don't save `checkpoint_best` and `checkpoint_latest` since ioPath doesn't yet have a "wait until a file is written and then copy/move it to another path" feature. Is this okay for now? 2. Should I mimic the atomic vs non-atomic save structure that synchronous Fairseq checkpoint writes have? **Note to Eric:** Keep this integration in check with D26375501. Reviewed By: myleott Differential Revision: D26467815 fbshipit-source-id: 50068ef7bf9a6d5cea4d5e0d13d672604dc4a6b0 --- fairseq/checkpoint_utils.py | 42 ++++++++++++++++++++++---------- fairseq/dataclass/configs.py | 10 ++++++++ fairseq/file_io.py | 46 ++++++++++++++++++++++++++++++++++++ fairseq_cli/train.py | 20 ++++++++++++++++ 4 files changed, 105 insertions(+), 13 deletions(-) diff --git a/fairseq/checkpoint_utils.py b/fairseq/checkpoint_utils.py index 55a546356e..d6618fbb62 100644 --- a/fairseq/checkpoint_utils.py +++ b/fairseq/checkpoint_utils.py @@ -93,9 +93,17 @@ def is_better(a, b): if len(checkpoints) > 0: trainer.save_checkpoint(checkpoints[0], extra_state) for cp in checkpoints[1:]: - assert PathManager.copy( - checkpoints[0], cp, overwrite=True - ), f"Failed to copy {checkpoints[0]} to {cp}" + if cfg.write_checkpoints_asynchronously: + # TODO[ioPath]: Need to implement a delayed asynchronous + # file copying/moving feature. + logger.warning( + f"ioPath is not copying {checkpoints[0]} to {cp} " + "since async write mode is on." + ) + else: + assert PathManager.copy( + checkpoints[0], cp, overwrite=True + ), f"Failed to copy {checkpoints[0]} to {cp}" write_timer.stop() logger.info( @@ -383,7 +391,23 @@ def checkpoint_paths(path, pattern=r"checkpoint(\d+)\.pt"): return [os.path.join(path, x[1]) for x in sorted(entries, reverse=True)] -def torch_persistent_save(obj, f): +def torch_persistent_save(cfg: CheckpointConfig, obj, filename): + if cfg.write_checkpoints_asynchronously: + with PathManager.opena(filename, "wb") as f: + _torch_persistent_save(obj, f) + else: + if PathManager.supports_rename(filename): + # do atomic save + with PathManager.open(filename + ".tmp", "wb") as f: + _torch_persistent_save(obj, f) + PathManager.rename(filename + ".tmp", filename) + else: + # fallback to non-atomic save + with PathManager.open(filename, "wb") as f: + _torch_persistent_save(obj, f) + + +def _torch_persistent_save(obj, f): if isinstance(f, str): with PathManager.open(f, "wb") as h: torch_persistent_save(obj, h) @@ -448,15 +472,7 @@ def save_state( # keep everything on CPU state_dict = utils.move_to_cpu(state_dict) - if PathManager.supports_rename(filename): - # do atomic save - with PathManager.open(filename + ".tmp", "wb") as f: - torch_persistent_save(state_dict, f) - PathManager.rename(filename + ".tmp", filename) - else: - # fallback to non-atomic save - with PathManager.open(filename, "wb") as f: - torch_persistent_save(state_dict, f) + torch_persistent_save(cfg.checkpoint, state_dict, filename) def _upgrade_state_dict(state): diff --git a/fairseq/dataclass/configs.py b/fairseq/dataclass/configs.py index f66e98fe83..39355b1caf 100644 --- a/fairseq/dataclass/configs.py +++ b/fairseq/dataclass/configs.py @@ -607,6 +607,16 @@ class CheckpointConfig(FairseqDataclass): "(default: only load on rank 0 and broadcast to other devices)" }, ) + write_checkpoints_asynchronously: bool = field( + default=False, + metadata={ + "help": ( + "Write checkpoints asynchronously in a separate " + "thread. NOTE: This feature is currently being tested." + ), + "argparse_alias": "--save-async", + }, + ) model_parallel_size: int = II("common.model_parallel_size") distributed_rank: int = II("distributed_training.distributed_rank") diff --git a/fairseq/file_io.py b/fairseq/file_io.py index 7d6c28dccd..731fef3570 100644 --- a/fairseq/file_io.py +++ b/fairseq/file_io.py @@ -32,6 +32,8 @@ except ImportError: FVCorePathManager = None +IOPathPathManager = None + class PathManager: """ @@ -148,3 +150,47 @@ def supports_rename(path: str) -> bool: @staticmethod def rename(src: str, dst: str): os.rename(src, dst) + + """ + ioPath async PathManager methods: + """ + @staticmethod + def opena( + path: str, + mode: str = "r", + buffering: int = -1, + encoding: Optional[str] = None, + errors: Optional[str] = None, + newline: Optional[str] = None, + ): + """ + Return file descriptor with asynchronous write operations. + """ + global IOPathPathManager + if not IOPathPathManager: + logging.info("ioPath is initializing PathManager.") + try: + from iopath import PathManager + IOPathPathManager = PathManager() + except Exception: + logging.exception("Failed to initialize ioPath PathManager object.") + return IOPathPathManager.opena( + path=path, + mode=mode, + buffering=buffering, + encoding=encoding, + errors=errors, + newline=newline, + ) + + @staticmethod + def async_close() -> bool: + """ + Wait for files to be written and clean up asynchronous PathManager. + NOTE: `PathManager.async_close()` must be called at the end of any + script that uses `PathManager.opena(...)`. + """ + global IOPathPathManager + if IOPathPathManager: + return IOPathPathManager.async_close() + return False diff --git a/fairseq_cli/train.py b/fairseq_cli/train.py index ec4890b9e6..80ad57acd1 100644 --- a/fairseq_cli/train.py +++ b/fairseq_cli/train.py @@ -28,6 +28,7 @@ from fairseq.dataclass.configs import FairseqConfig from fairseq.dataclass.utils import convert_namespace_to_omegaconf from fairseq.distributed_utils import is_master +from fairseq.file_io import PathManager from fairseq.logging import meters, metrics, progress_bar from fairseq.model_parallel.megatron_trainer import MegatronTrainer from fairseq.trainer import Trainer @@ -67,6 +68,16 @@ def main(cfg: FairseqConfig) -> None: # Print args logger.info(cfg) + if cfg.checkpoint.write_checkpoints_asynchronously: + try: + import iopath # noqa: F401 + except ImportError: + logging.exception( + "Asynchronous checkpoint writing is specified but iopath is " + "not installed: `pip install iopath`" + ) + return + # Setup task, e.g., translation, language modeling, etc. task = tasks.setup_task(cfg.task) # Load valid dataset (we load training data below, based on the latest checkpoint) @@ -157,6 +168,15 @@ def main(cfg: FairseqConfig) -> None: train_meter.stop() logger.info("done training in {:.1f} seconds".format(train_meter.sum)) + # ioPath implementation to wait for all asynchronous file writes to complete. + if cfg.checkpoint.write_checkpoints_asynchronously: + logger.info( + "ioPath PathManager waiting for all asynchronous checkpoint " + "writes to finish." + ) + PathManager.async_close() + logger.info("ioPath PathManager finished waiting.") + def should_stop_early(cfg: DictConfig, valid_loss: float) -> bool: # skip check if no validation was done in the current epoch From 12e21b9a6e7262fa1af2090e22c301bc0b5d1399 Mon Sep 17 00:00:00 2001 From: Xutai Ma Date: Tue, 2 Mar 2021 13:28:53 -0800 Subject: [PATCH 51/82] Add global cmvn for mustc data preparation (#1660) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: # Before submitting - [ ] Was this discussed/approved via a Github issue? (no need for typos, doc improvements) - [ ] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)? - [ ] Did you make sure to update the docs? - [ ] Did you write any new necessary tests? ## What does this PR do? Fixes # (issue). ## PR review Anyone in the community is free to review the PR once the tests have passed. If we didn't discuss your PR in Github issues there's a high chance it will not be merged. ## Did you have fun? Make sure you had fun coding � Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1660 Reviewed By: jmp84, kahne Differential Revision: D26708521 Pulled By: xutaima fbshipit-source-id: c53e9052298c559706ceffeb359dadfede2f1a09 --- examples/speech_to_text/data_utils.py | 32 +++++++++++++++++-- examples/speech_to_text/prep_mustc_data.py | 37 ++++++++++++++++++++-- 2 files changed, 64 insertions(+), 5 deletions(-) diff --git a/examples/speech_to_text/data_utils.py b/examples/speech_to_text/data_utils.py index 0d7c034419..fa0d459611 100644 --- a/examples/speech_to_text/data_utils.py +++ b/examples/speech_to_text/data_utils.py @@ -126,7 +126,9 @@ def gen_config_yaml( specaugment_policy: str = "lb", prepend_tgt_lang_tag: bool = False, sampling_alpha: float = 1.0, - audio_root: str = "" + audio_root: str = "", + cmvn_type: str = "utterance", + gcmvn_path: Optional[Path] = None, ): manifest_root = manifest_root.absolute() writer = S2TDataConfigWriter(manifest_root / yaml_filename) @@ -151,8 +153,19 @@ def gen_config_yaml( if prepend_tgt_lang_tag: writer.set_prepend_tgt_lang_tag(True) writer.set_sampling_alpha(sampling_alpha) - writer.set_feature_transforms("_train", ["utterance_cmvn", "specaugment"]) - writer.set_feature_transforms("*", ["utterance_cmvn"]) + + if cmvn_type not in ["global", "utterance"]: + raise NotImplementedError + + writer.set_feature_transforms("_train", [f"{cmvn_type}_cmvn", "specaugment"]) + writer.set_feature_transforms("*", [f"{cmvn_type}_cmvn"]) + + if cmvn_type == "global": + assert gcmvn_path is not None, ( + 'Please provide path of global cmvn file.' + ) + writer.set_global_cmvn(gcmvn_path) + if len(audio_root) > 0: writer.set_audio_root(audio_root) writer.flush() @@ -206,6 +219,16 @@ def filter_manifest_df( return df[valid] +def cal_gcmvn_stats(features_list): + features = np.concatenate(features_list) + square_sums = (features ** 2).sum(axis=0) + mean = features.mean(axis=0) + features = np.subtract(features, mean) + var = square_sums / features.shape[0] - mean ** 2 + std = np.sqrt(np.maximum(var, 1e-8)) + return {"mean": mean.astype("float32"), "std": std.astype("float32")} + + class S2TDataConfigWriter(object): DEFAULT_VOCAB_FILENAME = "dict.txt" DEFAULT_INPUT_FEAT_PER_CHANNEL = 80 @@ -297,6 +320,9 @@ def set_input_feat_per_channel(self, input_feat_per_channel: int = 80): def set_bpe_tokenizer(self, bpe_tokenizer: Dict[str, Any]): self.config["bpe_tokenizer"] = bpe_tokenizer + def set_global_cmvn(self, stats_npz_path: str): + self.config["stats_npz_path"] = stats_npz_path + def set_feature_transforms(self, split: str, transforms: List[str]): if "transforms" not in self.config: self.config["transforms"] = {} diff --git a/examples/speech_to_text/prep_mustc_data.py b/examples/speech_to_text/prep_mustc_data.py index 520968401c..4e410bcb18 100644 --- a/examples/speech_to_text/prep_mustc_data.py +++ b/examples/speech_to_text/prep_mustc_data.py @@ -13,6 +13,7 @@ from tempfile import NamedTemporaryFile from typing import Tuple +import numpy as np import pandas as pd import torchaudio from examples.speech_to_text.data_utils import ( @@ -24,6 +25,7 @@ get_zip_manifest, load_df_from_tsv, save_df_to_tsv, + cal_gcmvn_stats, ) from torch import Tensor from torch.utils.data import Dataset @@ -111,10 +113,28 @@ def process(args): print(f"Fetching split {split}...") dataset = MUSTC(root.as_posix(), lang, split) print("Extracting log mel filter bank features...") + if split == 'train' and args.cmvn_type == "global": + print("And estimating cepstral mean and variance stats...") + gcmvn_feature_list = [] + for waveform, sample_rate, _, _, _, utt_id in tqdm(dataset): - extract_fbank_features( - waveform, sample_rate, feature_root / f"{utt_id}.npy" + features = extract_fbank_features(waveform, sample_rate) + + np.save( + (feature_root / f"{utt_id}.npy").as_posix(), + features ) + + if split == 'train' and args.cmvn_type == "global": + if len(gcmvn_feature_list) < args.gcmvn_max_num: + gcmvn_feature_list.append(features) + + if split == 'train' and args.cmvn_type == "global": + # Estimate and save cmv + stats = cal_gcmvn_stats(gcmvn_feature_list) + with open(cur_root / "gcmvn.npz", "wb") as f: + np.savez(f, mean=stats["mean"], std=stats["std"]) + # Pack features into ZIP zip_path = cur_root / "fbank80.zip" print("ZIPing features...") @@ -158,6 +178,11 @@ def process(args): spm_filename_prefix + ".model", yaml_filename=f"config_{args.task}.yaml", specaugment_policy="lb", + cmvn_type=args.cmvn_type, + gcmvn_cmvn_path=( + cur_root / "gcmvn.npz" if args.cmvn_type == "global" + else None + ), ) # Clean up shutil.rmtree(feature_root) @@ -216,6 +241,14 @@ def main(): parser.add_argument("--vocab-size", default=8000, type=int) parser.add_argument("--task", type=str, choices=["asr", "st"]) parser.add_argument("--joint", action="store_true", help="") + parser.add_argument("--cmvn-type", default="utterance", + choices=["global", "utterance"], + help="The type of cepstral mean and variance normalization") + parser.add_argument("--gcmvn-max-num", default=150000, type=int, + help=( + "Maximum number of sentences to use to estimate" + "global mean and variance" + )) args = parser.parse_args() if args.joint: From c58af189957eb15b47e507473b4da3e83dfbdf2e Mon Sep 17 00:00:00 2001 From: Xutai Ma Date: Tue, 2 Mar 2021 17:08:45 -0800 Subject: [PATCH 52/82] Several update on simultaneous translation inference. (#1655) Summary: Fix some issues in some corner cases. Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1655 Reviewed By: jmp84 Differential Revision: D26651362 Pulled By: sravyapopuri388 fbshipit-source-id: 160d75be8d49f8263c14af225c90fe7997171a43 --- .../models/convtransformer_simul_trans.py | 2 +- .../models/transformer_monotonic_attention.py | 90 ++++--------------- .../modules/fixed_pre_decision.py | 38 ++++++-- .../modules/monotonic_multihead_attention.py | 7 +- .../agents/fairseq_simul_st_agent.py | 32 ++++--- 5 files changed, 76 insertions(+), 93 deletions(-) diff --git a/examples/simultaneous_translation/models/convtransformer_simul_trans.py b/examples/simultaneous_translation/models/convtransformer_simul_trans.py index 84ba4d0d3f..760a48168d 100644 --- a/examples/simultaneous_translation/models/convtransformer_simul_trans.py +++ b/examples/simultaneous_translation/models/convtransformer_simul_trans.py @@ -10,7 +10,6 @@ register_model, register_model_architecture, ) - from fairseq.models.speech_to_text import ConvTransformerModel, convtransformer_espnet @@ -24,6 +23,7 @@ class SimulConvTransformerModel(ConvTransformerModel): https://www.aclweb.org/anthology/2020.aacl-main.58.pdf """ + @staticmethod def add_args(parser): super(SimulConvTransformerModel, SimulConvTransformerModel).add_args(parser) diff --git a/examples/simultaneous_translation/models/transformer_monotonic_attention.py b/examples/simultaneous_translation/models/transformer_monotonic_attention.py index dd3895f0eb..65c12c6f5b 100644 --- a/examples/simultaneous_translation/models/transformer_monotonic_attention.py +++ b/examples/simultaneous_translation/models/transformer_monotonic_attention.py @@ -65,60 +65,6 @@ def _indices_from_states(self, states): return src_indices, None, tgt_indices - def predict_from_states(self, states): - decoder_states = self.decoder.output_layer(states["decoder_features"]) - lprobs = self.get_normalized_probs([decoder_states[:, -1:]], log_probs=True) - - index = lprobs.argmax(dim=-1) - - token = self.decoder.dictionary.string(index) - - return token, index[0, 0].item() - - def decision_from_states(self, states): - """ - This funcion take states dictionary as input, and gives the agent - a decision of whether read a token from server. Moreover, the decoder - states are also calculated here so we can directly generate a target - token without recompute every thing - """ - - self.eval() - - if len(states["tokens"]["src"]) == 0: - return 0 - - src_indices, src_lengths, tgt_indices = self._indices_from_states(states) - - # Update encoder states if needed - if ( - "encoder_states" not in states - or states["encoder_states"][0].size(1) <= states["steps"]["src"] - ): - encoder_out_dict = self.encoder(src_indices, src_lengths) - states["encoder_states"] = encoder_out_dict - else: - encoder_out_dict = states["encoder_states"] - - # online means we still need tokens to feed the model - states["model_states"]["online"] = not ( - states["finish_read"] - and len(states["tokens"]["src"]) == states["steps"]["src"] - ) - - states["model_states"]["steps"] = states["steps"] - - x, outputs = self.decoder.forward( - prev_output_tokens=tgt_indices, - encoder_out=encoder_out_dict, - incremental_state=states["model_states"], - features_only=True, - ) - - states["decoder_features"] = x - - return outputs["action"] - class TransformerMonotonicEncoder(TransformerEncoder): def __init__(self, args, dictionary, embed_tokens): @@ -208,6 +154,18 @@ def post_attention(self, x): return x + def clear_cache(self, incremental_state, end_id=None): + """ + Clear cache in the monotonic layers. + The cache is generated because of a forward pass of decode but no prediction. + end_id is the last idx of the layers + """ + if end_id is None: + end_id = len(self.layers) + + for j in range(end_id): + self.layers[j].prune_incremental_state(incremental_state) + def extract_features( self, prev_output_tokens, encoder_out, incremental_state=None, **unused ): @@ -247,9 +205,13 @@ def extract_features( curr_steps = layer.get_head_steps(incremental_state) step_list.append(curr_steps) - if incremental_state.get("online", False): + if incremental_state.get("online", True): + # Online indicates that the encoder states are still changing p_choose = ( - attn["p_choose"].squeeze(0).squeeze(1).gather(1, curr_steps.t()) + attn["p_choose"] + .squeeze(0) + .squeeze(1) + .gather(1, curr_steps.t()) ) new_steps = curr_steps + (p_choose < 0.5).t().type_as(curr_steps) @@ -258,24 +220,10 @@ def extract_features( # We need to prune the last self_attn saved_state # if model decide not to read # otherwise there will be duplicated saved_state - for j in range(i + 1): - self.layers[j].prune_incremental_state(incremental_state) + self.clear_cache(incremental_state, i + 1) return x, {"action": 0} - if incremental_state is not None and not incremental_state.get("online", False): - # Here is for fast evaluation - fastest_step = ( - torch.max(torch.cat(step_list, dim=1), dim=1, keepdim=True)[0] + 1 - ) - - if "fastest_step" in incremental_state: - incremental_state["fastest_step"] = torch.cat( - [incremental_state["fastest_step"], fastest_step], dim=1 - ) - else: - incremental_state["fastest_step"] = fastest_step - x = self.post_attention(x) return x, { diff --git a/examples/simultaneous_translation/modules/fixed_pre_decision.py b/examples/simultaneous_translation/modules/fixed_pre_decision.py index 725be1a983..cc5e7ad532 100644 --- a/examples/simultaneous_translation/modules/fixed_pre_decision.py +++ b/examples/simultaneous_translation/modules/fixed_pre_decision.py @@ -1,6 +1,7 @@ from functools import partial import torch +import math import torch.nn.functional as F from . import register_monotonic_attention @@ -96,6 +97,9 @@ def p_choose( incremental_state=None, **extra_args ): + src_len = key.size(0) + tgt_len = query.size(0) + batch_size = query.size(1) if self.pre_decision_ratio == 1: return super().p_choose( @@ -119,6 +123,16 @@ def p_choose( else: key_padding_mask_pool = None + if incremental_state is not None: + # The floor instead of ceil is used for inference + # But make sure the length key_pool at least 1 + if ( + max(1, math.floor(key.size(0) / self.pre_decision_ratio)) + ) < key_pool.size(0): + key_pool = key_pool[:-1] + if key_padding_mask_pool is not None: + key_padding_mask_pool = key_padding_mask_pool[:-1] + p_choose_pooled = super().p_choose( query, key_pool, @@ -129,13 +143,23 @@ def p_choose( # Upsample, interpolate zeros p_choose = self.insert_zeros(p_choose_pooled) - # can be larger than src_len because we used ceil before - src_len = key.size(0) - p_choose = p_choose[:, :, :src_len] - p_choose[:, :, -1] = p_choose_pooled[:, :, -1] - - tgt_len = query.size(0) - batch_size = query.size(1) + if p_choose.size(-1) < src_len: + # Append zeros if the upsampled p_choose is shorter than src_len + p_choose = torch.cat( + [ + p_choose, + p_choose.new_zeros( + p_choose.size(0), + tgt_len, + src_len - p_choose.size(-1) + ) + ], + dim=2 + ) + else: + # can be larger than src_len because we used ceil before + p_choose = p_choose[:, :, :src_len] + p_choose[:, :, -1] = p_choose_pooled[:, :, -1] assert list(p_choose.size()) == [ batch_size * self.num_heads, diff --git a/examples/simultaneous_translation/modules/monotonic_multihead_attention.py b/examples/simultaneous_translation/modules/monotonic_multihead_attention.py index 3e25957cd6..49882afcd8 100644 --- a/examples/simultaneous_translation/modules/monotonic_multihead_attention.py +++ b/examples/simultaneous_translation/modules/monotonic_multihead_attention.py @@ -604,13 +604,14 @@ def p_choose( key_padding_mask: bsz, src_len """ if incremental_state is not None: + # Retrieve target length from incremental states + # For inference the length of query is always 1 tgt_len = int(incremental_state["steps"]["tgt"]) - src_len = int(incremental_state["steps"]["src"]) - bsz = 1 else: - src_len, bsz, _ = key.size() tgt_len, bsz, _ = query.size() + src_len, bsz, _ = key.size() + p_choose = query.new_ones(bsz, tgt_len, src_len) p_choose = torch.tril(p_choose, diagonal=self.waitk_lagging - 1) p_choose = torch.triu(p_choose, diagonal=self.waitk_lagging - 1) diff --git a/examples/speech_to_text/simultaneous_translation/agents/fairseq_simul_st_agent.py b/examples/speech_to_text/simultaneous_translation/agents/fairseq_simul_st_agent.py index 5793609095..f944203785 100644 --- a/examples/speech_to_text/simultaneous_translation/agents/fairseq_simul_st_agent.py +++ b/examples/speech_to_text/simultaneous_translation/agents/fairseq_simul_st_agent.py @@ -127,6 +127,15 @@ def __init__(self, args): self.load_model_vocab(args) + if getattr( + self.model.decoder.layers[0].encoder_attn, + 'pre_decision_ratio', + None + ) is not None: + self.speech_segment_size *= ( + self.model.decoder.layers[0].encoder_attn.pre_decision_ratio + ) + with open(args.config, "r") as f: config = yaml.load(f, Loader=yaml.BaseLoader) @@ -167,15 +176,15 @@ def add_args(parser): parser.add_argument("--max-len", type=int, default=200, help="Max length of translation") parser.add_argument("--force-finish", default=False, action="store_true", - help="") + help="Force the model to finish the hypothsis if the source is not finished") parser.add_argument("--shift-size", type=int, default=SHIFT_SIZE, - help="") + help="Shift size of feature extraction window.") parser.add_argument("--window-size", type=int, default=WINDOW_SIZE, - help="") + help="Window size of feature extraction window.") parser.add_argument("--sample-rate", type=int, default=SAMPLE_RATE, - help="") + help="Sample rate") parser.add_argument("--feature-dim", type=int, default=FEATURE_DIM, - help="") + help="Acoustic feature dimension.") # fmt: on return parser @@ -265,11 +274,12 @@ def units_to_segment(self, units, states): def update_model_encoder(self, states): if len(states.units.source) == 0: return - src_indices = self.to_device(states.units.source.value.unsqueeze(0)) + src_indices = self.to_device( + states.units.source.value.unsqueeze(0) + ) src_lengths = self.to_device( torch.LongTensor([states.units.source.value.size(0)]) ) - print(src_lengths) states.encoder_states = self.model.encoder(src_indices, src_lengths) torch.cuda.empty_cache() @@ -294,13 +304,12 @@ def policy(self, states): "tgt": 1 + len(states.units.target), } - states.incremental_states["online"] = True + states.incremental_states["online"] = not states.finish_read() x, outputs = self.model.decoder.forward( prev_output_tokens=tgt_indices, encoder_out=states.encoder_states, incremental_state=states.incremental_states, - # features_only=True, ) states.decoder_out = x @@ -323,8 +332,6 @@ def predict(self, states): index = lprobs.argmax(dim=-1) - torch.cuda.empty_cache() - index = index[0, 0].item() if ( @@ -332,6 +339,9 @@ def predict(self, states): and index == self.model.decoder.dictionary.eos() and not states.finish_read() ): + # If we want to force finish the translation + # (don't stop before finish reading), return a None + # self.model.decoder.clear_cache(states.incremental_states) index = None return index From ddc483ff3d3a70f3abc33fc4d10bb29871c73d73 Mon Sep 17 00:00:00 2001 From: Xutai Ma Date: Tue, 2 Mar 2021 17:08:45 -0800 Subject: [PATCH 53/82] Streaming models for simul ST (#1552) Summary: `fairseq/models/speech_to_text/modules/emformer.py` mostly contains the code from Yangyang. I did a little modification to make it run on fairseq. `fairseq/models/speech_to_text/modules/augmented_memory_attention.py` contains code for the old streaming models `fairseq/models/speech_to_text/modules/convtransformer_simul_trans.py` contaons three convtransformer based simultaneous translation models. Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1552 Reviewed By: jmp84 Differential Revision: D26563864 Pulled By: sravyapopuri388 fbshipit-source-id: a91a6247559861977cbc22db00ba9511f6b21c69 --- .../modules/augmented_memory_attention.py | 486 +++++ .../models/speech_to_text/modules/emformer.py | 1838 +++++++++++++++++ fairseq/models/speech_to_text/utils.py | 564 +++++ 3 files changed, 2888 insertions(+) create mode 100644 fairseq/models/speech_to_text/modules/augmented_memory_attention.py create mode 100644 fairseq/models/speech_to_text/modules/emformer.py create mode 100644 fairseq/models/speech_to_text/utils.py diff --git a/fairseq/models/speech_to_text/modules/augmented_memory_attention.py b/fairseq/models/speech_to_text/modules/augmented_memory_attention.py new file mode 100644 index 0000000000..5d31524b76 --- /dev/null +++ b/fairseq/models/speech_to_text/modules/augmented_memory_attention.py @@ -0,0 +1,486 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Tuple, List + +import torch +import torch.nn.functional as F +from fairseq.models import FairseqEncoder +from fairseq.models.fairseq_encoder import EncoderOut +from fairseq.models.speech_to_text import ( + ConvTransformerEncoder, +) +from fairseq.models.speech_to_text.utils import attention_suppression +from fairseq.models.speech_to_text.utils import ( + lengths_to_encoder_padding_mask, + segments_to_sequence, + sequence_to_segments, +) +from fairseq.modules import MultiheadAttention, TransformerEncoderLayer +from torch import nn, Tensor + +# ------------------------------------------------------------------------------ +# AugmentedMemoryConvTransformerEncoder +# ------------------------------------------------------------------------------ + + +class AugmentedMemoryConvTransformerEncoder(ConvTransformerEncoder): + def __init__(self, args): + super().__init__(args) + + args.encoder_stride = self.stride() + + self.left_context = args.left_context // args.encoder_stride + + self.right_context = args.right_context // args.encoder_stride + + self.left_context_after_stride = args.left_context // args.encoder_stride + self.right_context_after_stride = args.right_context // args.encoder_stride + + self.transformer_layers = nn.ModuleList([]) + self.transformer_layers.extend( + [ + AugmentedMemoryTransformerEncoderLayer(args) + for i in range(args.encoder_layers) + ] + ) + + def stride(self): + # Hard coded here. Should infer from convs in future + stride = 4 + return stride + + def forward(self, src_tokens, src_lengths, states=None): + """Encode input sequence. + :param torch.Tensor xs: input tensor + :param torch.Tensor masks: input mask + :return: position embedded tensor and mask + :rtype Tuple[torch.Tensor, torch.Tensor]: + """ + bsz, max_seq_len, _ = src_tokens.size() + x = ( + src_tokens.view(bsz, max_seq_len, self.in_channels, self.input_dim) + .transpose(1, 2) + .contiguous() + ) + x = self.conv(x) + bsz, _, output_seq_len, _ = x.size() + x = x.transpose(1, 2).transpose(0, 1).contiguous().view(output_seq_len, bsz, -1) + x = self.out(x) + x = self.embed_scale * x + + subsampling_factor = 1.0 * max_seq_len / output_seq_len + input_lengths = (src_lengths.float() / subsampling_factor).round().long() + + encoder_padding_mask, _ = lengths_to_encoder_padding_mask( + input_lengths, batch_first=True + ) + + # TODO: fix positional embedding + positions = self.embed_positions(encoder_padding_mask).transpose(0, 1) + + x += positions + x = F.dropout(x, p=self.dropout, training=self.training) + + # State to store memory banks etc. + if states is None: + states = [ + {"memory_banks": None, "encoder_states": None} + for i in range(len(self.transformer_layers)) + ] + + for i, layer in enumerate(self.transformer_layers): + # x size: + # (self.left_size + self.segment_size + self.right_size) + # / self.stride, num_heads, dim + # TODO: Consider mask here + x = layer(x, states[i]) + states[i]["encoder_states"] = x[ + self.left_context_after_stride : -self.right_context_after_stride + ] + + lengths = ( + ( + ~encoder_padding_mask[ + :, self.left_context_after_stride : -self.right_context_after_stride + ] + ) + .sum(dim=1, keepdim=True) + .long() + ) + + return states[-1]["encoder_states"], lengths, states + + +# ------------------------------------------------------------------------------ +# AugmentedMemoryTransformerEncoderLayer +# ------------------------------------------------------------------------------ +class AugmentedMemoryTransformerEncoderLayer(TransformerEncoderLayer): + def __init__(self, args): + super().__init__(args) + + self.left_context = args.left_context // args.encoder_stride + self.right_context = args.right_context // args.encoder_stride + + def forward(self, x, state): + + length, batch_size, x_dim = x.size() + + residual = x + + if self.normalize_before: + x = self.self_attn_layer_norm(x) + + # init_state + if state.get("memory_banks", None) is None: + state["memory_banks"] = [] + + # TODO reseach new sum_query method + seg_start = self.left_context + seg_end = length - self.right_context + if seg_start < seg_end: + summarization_query = torch.mean(x[seg_start:seg_end], keepdim=True, dim=0) + else: + summarization_query = x.new_zeros(1, batch_size, x_dim) + + x = torch.cat([x, summarization_query], dim=0) + + x = self.self_attn(input_and_summary=x, state=state) + + x = self.dropout_module(x) + x = residual + x + + if not self.normalize_before: + x = self.self_attn_layer_norm(x) + + residual = x + if self.normalize_before: + x = self.final_layer_norm(x) + + x = self.activation_fn(self.fc1(x)) + x = self.activation_dropout_module(x) + x = self.fc2(x) + x = self.dropout_module(x) + x = residual + x + if not self.normalize_before: + x = self.final_layer_norm(x) + + return x + + def build_self_attention(self, embed_dim, args): + return AugmentedMemoryMultiheadAttention( + embed_dim=embed_dim, + num_heads=args.encoder_attention_heads, + dropout=args.attention_dropout, + self_attention=True, + q_noise=self.quant_noise, + qn_block_size=self.quant_noise_block_size, + tanh_on_mem=True, + max_memory_size=args.max_memory_size, + ) + + +# ------------------------------------------------------------------------------ +# AugmentedMemoryMultiheadAttention +# ------------------------------------------------------------------------------ +class AugmentedMemoryMultiheadAttention(MultiheadAttention): + """ + Augmented Memory Attention from + Streaming Transformer-based Acoustic Models + Using Self-attention with Augmented Memory + https://arxiv.org/abs/2005.08042 + """ + + def __init__( + self, + embed_dim, + num_heads, + kdim=None, + vdim=None, + dropout=0.0, + bias=True, + add_bias_kv=False, + add_zero_attn=False, + self_attention=False, + encoder_decoder_attention=False, + q_noise=0.0, + qn_block_size=8, + tanh_on_mem=False, + memory_dim=None, + std_scale=0.5, # 0.5 based on https://arxiv.org/abs/2005.09137 + max_memory_size=-1, + disable_mem_on_mem_attn=True, + ): + super().__init__( + embed_dim, + num_heads, + kdim, + vdim, + dropout, + bias, + add_bias_kv, + add_zero_attn, + self_attention, + encoder_decoder_attention, + q_noise, + qn_block_size, + ) + + self.memory_dim = memory_dim if memory_dim is not None else embed_dim + self.std_scale = std_scale + self.disable_mem_on_mem_attn = disable_mem_on_mem_attn + + # This Operator was used for factorization in PySpeech + self.v2e = lambda x: x + + if tanh_on_mem: + self.squash_mem = torch.tanh + self.nonlinear_squash_mem = True + else: + self.squash_mem = lambda x: x + self.nonlinear_squash_mem = False + + self.max_memory_size = max_memory_size + + def forward(self, input_and_summary, state): + """ + input: Encoder states of current segment with left or right context, + plus one summarization query + + """ + + length, batch_size, _ = input_and_summary.shape + length = length - 1 # not include sum_query, last index + + memory = state["memory_banks"] + # TODO: positional embedding on memory + + if self.max_memory_size > -1 and len(memory) > self.max_memory_size: + # TODO: need to fix here + if self.max_memory_size == 0: + memory = memory.new_zeros(1, memory.size(1), self.memory_dim) + else: + memory = memory[-self.max_memory_size :] + + memory_and_input = torch.cat(memory + [input_and_summary[:-1]], dim=0) + input_and_sum_query = input_and_summary + + q = self.q_proj(self.v2e(input_and_sum_query)) + k = self.k_proj(self.v2e(memory_and_input)) + v = self.v_proj(self.v2e(memory_and_input)) + + q = ( + q.contiguous() + .view(-1, batch_size * self.num_heads, self.head_dim) + .transpose(0, 1) + * self.scaling + ) + k = ( + k.contiguous() + .view(-1, batch_size * self.num_heads, self.head_dim) + .transpose(0, 1) + ) + + v = ( + v.contiguous() + .view(-1, batch_size * self.num_heads, self.head_dim) + .transpose(0, 1) + ) + + attention_weights = torch.bmm(q, k.transpose(1, 2)) + + if self.disable_mem_on_mem_attn: + attention_weights = self.suppress_mem_on_mem_attention( + batch_size, self.num_heads, len(memory), attention_weights + ) + + if self.std_scale is not None: + attention_weights = attention_suppression(attention_weights, self.std_scale) + + assert list(attention_weights.shape) == [ + batch_size * self.num_heads, + length + 1, + length + len(memory), + ] + + attention_weights = torch.nn.functional.softmax( + attention_weights.float(), dim=-1 + ).type_as(attention_weights) + + attention_probs = self.dropout_module(attention_weights) + + # [T, T, B, n_head] + [T, B, n_head, d_head] -> [T, B, n_head, d_head] + attention = torch.bmm(attention_probs, v) + + assert list(attention.shape) == [ + batch_size * self.num_heads, + length + 1, + self.head_dim, + ] + + attention = ( + attention.transpose(0, 1) + .contiguous() + .view(length + 1, batch_size, self.embed_dim) + ) + + output_and_memory = self.out_proj(attention) + + next_m = output_and_memory[-1:] + next_m = self.squash_mem(next_m) + output = output_and_memory[:-1] + + state["memory_banks"].append(next_m) + + return output + + def suppress_mem_on_mem_attention( + self, B: int, num_heads: int, mem_size: int, attention_weight: Tensor + ): + """ + Arguments: + - B: batch size + - num_heads: number of attention heads + - mem_size: size of memory bank + - attention_weight: a [B*num_heads, T + 1, T + mem_size] vector + + Return: + modified attention_weight with [B*num_heads, -1, :mem_size] = -inf + """ + attention_weight[:, -1, :mem_size] = float("-inf") + return attention_weight + + +# ------------------------------------------------------------------------------ +# SequenceEncoder +# ------------------------------------------------------------------------------ +class SequenceEncoder(FairseqEncoder): + """ + SequenceEncoder encodes sequences. + + More specifically, `src_tokens` and `src_lengths` in `forward()` should + describe a batch of "complete" sequences rather than segments. + + Segment-by-segment inference can be triggered by `segment_size`: + 1) `segment_size` is None: + SequenceEncoder treats the input sequence as one single segment. + 2) `segment_size` is not None (some int instead): + SequenceEncoder does the following: + 1. breaks the input sequence into several segments + 2. inference on each segment and collect the outputs + 3. concatanete segment outputs into the output sequence. + Note that `segment_size` here shouldn't include additional left/right + contexts needed, for example if we wish to infer with LC-BLSTM where the + middle chunk size is 100 and right context is 20, `segment_size` should be + 100. + """ + + def __init__(self, args, module): + super().__init__(None) + + self.module = module + self.input_time_axis = 1 + self.output_time_axis = 0 + self.segment_size = args.segment_size + self.left_context = args.left_context + self.right_context = args.right_context + + def forward( + self, + src_tokens: Tensor, + src_lengths: Tensor, + states=None, + ): + + seg_src_tokens_lengths = sequence_to_segments( + sequence=src_tokens, + time_axis=self.input_time_axis, + lengths=src_lengths, + segment_size=self.segment_size, + extra_left_context=self.left_context, + extra_right_context=self.right_context, + ) + + seg_encoder_states_lengths: List[Tuple[Tensor, Tensor]] = [] + + for seg_src_tokens, seg_src_lengths in seg_src_tokens_lengths: + (seg_encoder_states, seg_enc_lengths, states) = self.module( + seg_src_tokens, + seg_src_lengths, + states=states, + ) + + seg_encoder_states_lengths.append((seg_encoder_states, seg_enc_lengths)) + + encoder_out, enc_lengths = segments_to_sequence( + segments=seg_encoder_states_lengths, time_axis=self.output_time_axis + ) + + encoder_padding_mask, _ = lengths_to_encoder_padding_mask( + enc_lengths, batch_first=True + ) + + if not encoder_padding_mask.any(): + encoder_padding_mask = None + + return EncoderOut( + encoder_out=encoder_out, + encoder_padding_mask=encoder_padding_mask, + encoder_embedding=None, + encoder_states=states, + src_tokens=None, + src_lengths=None, + ) + + def incremental_encode( + self, + seg_src_tokens: Tensor, + seg_src_lengths: Tensor, + states=None, + ): + """ + Different from forward function, this function takes segmented speech + as input, and append encoder states to previous states + """ + (seg_encoder_states, seg_enc_lengths, states) = self.module( + seg_src_tokens, + seg_src_lengths, + states=states, + ) + return seg_encoder_states, seg_enc_lengths, states + + +# ------------------------------------------------------------------------------ +# Augmented memory model decorator +# ------------------------------------------------------------------------------ +def augmented_memory(klass): + class StreamSeq2SeqModel(klass): + @staticmethod + def add_args(parser): + super(StreamSeq2SeqModel, StreamSeq2SeqModel).add_args(parser) + parser.add_argument( + "--segment-size", type=int, required=True, help="Length of the segment." + ) + parser.add_argument( + "--left-context", + type=int, + default=0, + help="Left context for the segment.", + ) + parser.add_argument( + "--right-context", + type=int, + default=0, + help="Right context for the segment.", + ) + parser.add_argument( + "--max-memory-size", + type=int, + default=-1, + help="Right context for the segment.", + ) + + StreamSeq2SeqModel.__name__ = klass.__name__ + return StreamSeq2SeqModel diff --git a/fairseq/models/speech_to_text/modules/emformer.py b/fairseq/models/speech_to_text/modules/emformer.py new file mode 100644 index 0000000000..42b157b766 --- /dev/null +++ b/fairseq/models/speech_to_text/modules/emformer.py @@ -0,0 +1,1838 @@ +#!/usr/bin/env python3 +# Copyright (c) 2017-present, Facebook, Inc. +# All rights reserved. +# +# This source code is licensed under the license found in the LICENSE file in +# the root directory of this source tree. An additional grant of patent rights +# can be found in the PATENTS file in the same directory. + + +import math +import re +from functools import partial +from typing import List, Optional, Tuple + +import torch +import torch.nn as nn +from fairseq.models import ( + FairseqEncoder, +) +from fairseq.models.fairseq_encoder import EncoderOut +from fairseq.models.speech_to_text.utils import ( + NoOp, + lengths_to_padding_mask, + segments_to_sequence, +) +from fairseq.models.speech_to_text.utils import ( + attention_suppression, + layer_norm_backward_hook, +) +from torch import Tensor, device as Device +from torch.quantization.qconfig import ( + default_dynamic_qconfig, + per_channel_dynamic_qconfig, +) + + +class RelativePositionEmbedding(nn.Module): + """ + Implementation according to https://arxiv.org/abs/1803.02155 + """ + + def __init__(self, head_dim, max_position, norm_init=True): + super().__init__() + self.head_dim = head_dim + self.max_position = max_position + self.embeddings = nn.Parameter(torch.Tensor(max_position * 2 + 1, head_dim)) + if norm_init: + nn.init.xavier_normal_(self.embeddings) + else: + nn.init.xavier_uniform_(self.embeddings) + + def forward(self, input: Tensor): + output = nn.functional.embedding(input.long(), self.embeddings) + return output + + +class Fp32LayerNorm(nn.Module): + def __init__( + self, + input_dim, + clamp_grad=True, + max_grad_value=256, + eps=1e-5, + elementwise_affine=True, + ): + super().__init__() + self.torch_module = torch.nn.LayerNorm( + input_dim, eps=eps, elementwise_affine=elementwise_affine + ) + if clamp_grad: + hook = partial(layer_norm_backward_hook, clamp_value=max_grad_value) + self.torch_module.register_backward_hook(hook) + + def forward(self, input): + output = torch.nn.functional.layer_norm( + input.float(), + self.torch_module.normalized_shape, + self.torch_module.weight.float() + if self.torch_module.weight is not None + else None, + self.torch_module.bias.float() + if self.torch_module.bias is not None + else None, + self.torch_module.eps, + ).type_as(input) + return output + + +# ------------------------------------------------------------------------------ +# PositionwiseFF +# ------------------------------------------------------------------------------ + + +class PositionwiseFF(nn.Module): + """ + FFN layer in transformer. + + Args: + input_dim: input embedding dimension + ffn_dim: FFN layer inner dimension + dropout_on_fc1: dropout for first linear layer + dropout_on_fc2: dropout fr second linear layer + activation_fn: activation function used after first linear layer. \ + Only relu or gelu is supported. + + """ + + def __init__( + self, input_dim, ffn_dim, dropout_on_fc1, dropout_on_fc2, activation_fn + ): + super(PositionwiseFF, self).__init__() + + self.input_dim = input_dim + self.ffn_dim = ffn_dim + if activation_fn == "relu": + ac = nn.ReLU() + elif activation_fn == "gelu": + ac = nn.GELU() + else: + raise ValueError("Unsupported activation_fn = ({})".format(activation_fn)) + + # fc1 -> ac -> dropout -> fc2 -> dropout + self.module = nn.Sequential( + nn.Linear(input_dim, ffn_dim), + ac, + nn.Dropout(dropout_on_fc1), + nn.Linear(ffn_dim, input_dim), + nn.Dropout(dropout_on_fc2), + ) + + self.layer_norm = Fp32LayerNorm(input_dim) + + def forward(self, input): + module_out = self.module(self.layer_norm(input)) + output = module_out + input + + return output + + def quantize_(self, params=None): + if params and "per_channel" in params and params["per_channel"]: + qconfig = per_channel_dynamic_qconfig + else: + qconfig = default_dynamic_qconfig + torch.quantization.quantize_dynamic( + self, {torch.nn.Linear: qconfig}, dtype=torch.qint8, inplace=True + ) + return self + + +# ------------------------------------------------------------------------------ +# SummarizationLayer +# ------------------------------------------------------------------------------ + + +class SummarizationLayer(nn.Module): + def __init__(self, method, segment_size, embedding_dim): + super(SummarizationLayer, self).__init__() + self.segment_size = segment_size + self.embedding_dim = embedding_dim + nonlin_match = re.match(r"nonlinear\((?P[a-z]+),(?P[0-9]+)\)", method) + self.method = method + if method == "mean": + self.module = nn.AvgPool1d( + kernel_size=segment_size, + stride=segment_size, + ceil_mode=True, + ) + elif method == "max": + self.module = nn.MaxPool1d( + kernel_size=segment_size, + stride=segment_size, + ceil_mode=True, + ) + elif method == "linear": + self.module = nn.Linear(segment_size, 1) + elif nonlin_match: + nonlin_args = nonlin_match.groupdict() + act_type = nonlin_args["act"] + hid_dim = int(nonlin_args["dim"]) + if act_type == "relu": + act = nn.ReLU() + elif act_type == "gelu": + act = nn.GELU() + else: + raise ValueError("Unsupported activation_fn = ({})".format(act_type)) + self.module = nn.Sequential( + nn.Linear(segment_size, hid_dim), + act, + nn.Linear(hid_dim, 1), + ) + else: + raise ValueError("Unsupported summarization method = ({})".format(method)) + + def forward(self, input): + # T, B, D -> B, D, T + input = input.permute(1, 2, 0) + + if self.method == "mean" or self.method == "max": + output = self.module(input) + output = output.permute(2, 0, 1) + return output + + full_seg_length = input.size(2) // self.segment_size * self.segment_size + if full_seg_length > 0: + # at least one seg is full + B = input.size(0) + D = input.size(1) + input_todo = ( + input[:, :, :full_seg_length] + .contiguous() + .view(B, -1, self.segment_size) + ) + output = self.module(input_todo) + output = output.view(B, D, -1) + else: + output = input.new_zeros(input.size(0), input.size(1), 0) + left = input.size(2) - full_seg_length + if left > 0: + # when last seg is not full, use zeros as last memory placeholder + zeros = input.new_zeros(input.size(0), input.size(1), 1) + output = torch.cat([output, zeros], dim=2) + output = output.permute(2, 0, 1) + return output + + +# ------------------------------------------------------------------------------ +# NoSegAugmentedMemoryMultiheadAttentionBmm +# ------------------------------------------------------------------------------ + + +class NoSegAugmentedMemoryMultiheadAttentionBmm(nn.Module): + """ + Whole utterance augmented memory multihead attention using BMM. + + Different with previous augmented memory multihead attention where + the utterance is chunked into segments. Here we use attention mask + achieve so. The input embedding [right_context, utterance, summary] + is a concatenation of right context, utterance and summary. + + Right context block is the concatenation of all the right context for + each segments. [right_context_0, right_context_1, ..., right_context_n] + For example, if we have utterance = [v0, v1, v2, ...., v20]. segment + size 8, right_context size 4. Then the right context blocks = + [v8, v9, v10, v11, v16, v17, v18, v19, 0, 0, 0, 0], where v8, v9, v10, + and v11 are the right context for first segment. v16, v17, v18 and v19 + are the right context for second segment. 0, 0, 0 and 0 are right context + for the last segment. + + utterance is corresponding to input embedding sequence + + summary is concatenation of average of each segments. [summary_0, + summary_1, ..., ]. + + In augmented memory multihead attention, the query is [right_context, + utterance, summary], key is [memory, right_context, utterance]. Different + with AugmentedMemoryMultiheadAttentionBmm, memory here is passed from + previous attention layer. For the first attention layer, memory is average + of each segment. + + Memory is a concatenation of memory from each segments in previous attention + layer. For example, current layer is i, then memory is [m_0, m_1, ..., m_n]. + Each m_k is the output from seg_k in layer i-1. + + args: + input_dim: input embedding dimension + num_heads: number of heads in multihead self-attention + dropout: attention dropout + std_scale: if std_scale is not None. The weak attention suppression is + turned on. For std_scale = 0.5, all the attention smaller than + mean + 0.5 * std will be suppressed. + scaled_init: whether to use scaled init for linear weight + tanh_on_mem: whether to use tanh on memory output + use_mem: whether to use memory or not. When max_memory_size is 0, then + we don't have memory anymore. + layer_index: current self-attention layer index that is used in depth + initialization + max_relative_position: max relative position used in relative position + embedding + rpe_old_option: To be compatible with previous model. The previous model + was trained with attention += attention + rpe. The correct equation + should be attention = attention + rpe + + """ + + def __init__( + self, + input_dim, + num_heads, + dropout=0.0, + std_scale=None, + scaled_init=False, + tanh_on_mem=False, + use_mem=True, + mini_batches=False, + negative_inf="-inf", + layer_index=-1, + max_relative_position=0, + rpe_old_option=True, + ): + if input_dim % num_heads: + raise ValueError( + "input_dim ({}) must be divisible by num_heads ({})".format( + input_dim, num_heads + ) + ) + + super().__init__() + + embed_dim = input_dim + self.e2h_kv = torch.nn.Linear(input_dim, 2 * input_dim, bias=True) + self.e2h_q = torch.nn.Linear(input_dim, input_dim, bias=True) + self.rpe_old_option = rpe_old_option + if max_relative_position > 0: + self.use_rpe = True + self.rpe_k = RelativePositionEmbedding( + head_dim=input_dim // num_heads, + max_position=max_relative_position, + ) + self.rpe_v = RelativePositionEmbedding( + head_dim=input_dim // num_heads, + max_position=max_relative_position, + ) + else: + self.use_rpe = False + self.rpe_k = None + self.rpe_v = None + if scaled_init: + if layer_index == -1: + gain = 1.0 / math.sqrt(2) + else: + # https://arxiv.org/abs/2005.09684 depthwise initialization + # stablize the training greatly. Use depthwise initialization to + # replace incremental loss. + gain = 1.0 / math.sqrt(layer_index + 1) + torch.nn.init.xavier_uniform_(self.e2h_kv.weight, gain=gain) + torch.nn.init.xavier_uniform_(self.e2h_q.weight, gain=gain) + + self.out_proj = torch.nn.Linear(embed_dim, embed_dim, bias=True) + + self.embed_dim = embed_dim + self.num_heads = num_heads + self.dropout = dropout + + self.head_dim = embed_dim // num_heads + self.scaling = self.head_dim ** -0.5 + + self.std_scale = std_scale + self.use_mem = use_mem + self.mini_batches = mini_batches + self.negative_inf = negative_inf + + if tanh_on_mem: + self.squash_mem = torch.tanh + self.nonlinear_squash_mem = True + else: + self.squash_mem = NoOp() + self.nonlinear_squash_mem = False + + def prepare_qkv( + self, + input: Tensor, + mems: Tensor, + lengths: Tensor, + summary_length: int, + lc_length: int, + ): + # T: right_context length + utterance_length + summary_length + T, B, D = input.shape + mem_length = mems.size(0) + utterance_length = torch.max(lengths) + + right_context_blocks_length = T - utterance_length - summary_length + rc_block = input[:right_context_blocks_length, :, :] + utterance_block = input[right_context_blocks_length : T - summary_length, :, :] + + if B == 1: + padding_mask = None + else: + klengths = lengths + mem_length + right_context_blocks_length + lc_length + padding_mask = lengths_to_padding_mask(lengths=klengths) + + mem_rc_input = torch.cat([mems, rc_block, utterance_block], dim=0) + + # In training lc_length = 0 + key_length = mem_rc_input.size(0) + lc_length + rc_input_sum = input + q = self.e2h_q(rc_input_sum) + kv = self.e2h_kv(mem_rc_input) + k, v = kv.chunk(chunks=2, dim=2) + result_qkv = (q, k, v) + input_shape = (T, B, D) + result_lengths_info = ( + mem_length, + utterance_length, + right_context_blocks_length, + key_length, + ) + if padding_mask is not None: + assert padding_mask.size(0) == B + assert padding_mask.size(1) == key_length + + return result_qkv, input_shape, result_lengths_info, padding_mask + + def prepare_attention_weights( + self, + q: Tensor, + new_k: Tensor, + new_v: Tensor, + input_shape: Tuple[int, int, int], + rpe: Optional[Tensor], + ) -> Tuple[Tensor, Tensor, Tensor]: + T, B, D = input_shape + q = ( + q.contiguous().view(-1, B * self.num_heads, self.head_dim).transpose(0, 1) + * self.scaling + ) + + k = ( + new_k.contiguous() + .view(-1, B * self.num_heads, self.head_dim) + .transpose(0, 1) + ) + + v = ( + new_v.contiguous() + .view(-1, B * self.num_heads, self.head_dim) + .transpose(0, 1) + ) + + attention_weights = torch.bmm(q, k.transpose(1, 2)) + if self.use_rpe and rpe is not None and self.rpe_v is not None: + r_k = self.rpe_k(rpe) + # [q, B*h, d] * [q, k, d] -> [B*h, q, k] + attention_weights_rpe = torch.matmul( + q.transpose(0, 1), r_k.transpose(1, 2) + ).transpose(0, 1) + attention_weights = attention_weights + attention_weights_rpe + attention_weights_float = attention_weights.float() + + return attention_weights, attention_weights_float, v + + def prepare_attention_output( + self, + attention_weights: Tensor, + attention_weights_float: Tensor, + v: Tensor, + input_shape: Tuple[int, int, int], + key_length: int, + padding_mask: Optional[Tensor], + rpe: Optional[Tensor], + ) -> Tensor: + T, B, D = input_shape + if padding_mask is not None: + attention_weights_float = attention_weights_float.view( + B, self.num_heads, T, key_length + ) + attention_weights_float = attention_weights_float.masked_fill( + padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool), float("-inf") + ) + attention_weights_float = attention_weights_float.view( + B * self.num_heads, T, key_length + ) + + if self.std_scale is not None: + attention_weights_float = attention_suppression( + attention_weights_float, self.std_scale + ) + + attention_weights_float = torch.nn.functional.softmax( + attention_weights_float, dim=-1 + ) + attention_weights = attention_weights_float.type_as(attention_weights) + + attention_probs = torch.nn.functional.dropout( + attention_weights, p=self.dropout, training=self.training + ) + + # [T, key_length, B, n_head]+ [key_length, B, n_head, d_head] + # -> [T, B, n_head, d_head] + attention = torch.bmm(attention_probs, v) + if self.use_rpe and rpe is not None and self.rpe_v is not None: + r_v = self.rpe_v(rpe) + attention_rpe = torch.matmul( + attention_probs.transpose(0, 1), r_v + ).transpose(0, 1) + + if self.rpe_old_option: + attention += attention + attention_rpe + else: + attention = attention + attention_rpe + + assert list(attention.shape) == [B * self.num_heads, T, self.head_dim] + + attention = attention.transpose(0, 1).contiguous().view(T, B, self.embed_dim) + + rc_output_memory = self.out_proj(attention) + return rc_output_memory + + @torch.jit.unused + def forward( + self, + input: Tensor, + lengths: Tensor, + mems: Tensor, + attention_mask: Tensor, + pre_mems: Optional[Tensor] = None, + left_context_key: Optional[Tensor] = None, + left_context_val: Optional[Tensor] = None, + rpe: Optional[Tensor] = None, + ) -> Tuple[Tensor, Tensor, Tensor, Tensor]: + """ + forward function for NoSegAugmentedMemoryMultiheadAttentionBmm in training. + + args: + input: formed in the following way + [right_context_0, right_contex_1, ..., seg_0, seg_1, + ..., summary_0, summary_1,..] + lengths: the length of query which is [seg_0, seg_1, ....] + mems: [mem_0, mem_1, ...]. + attention_mask: attention mask for query = [right_context, query, summary] + key = [mem, right_context, query]. This is only used for traing. + + """ + if self.use_mem: + mem_length = mems.size(0) + summary_length = mem_length + 1 + if pre_mems is not None: + mems = torch.cat([pre_mems, mems], dim=0) + else: + mem_length = 0 + summary_length = 0 + + # In training, lc_length = 0 + if left_context_key is not None: + lc_length = left_context_key.size(0) + else: + lc_length = 0 + results = self.prepare_qkv( + input=input, + mems=mems, + lengths=lengths, + summary_length=summary_length, + lc_length=lc_length, + ) + result_qkv, input_shape, result_lengths_info, padding_mask = results + q, k, v = result_qkv + ( + mem_length, + utterance_length, + right_context_blocks_length, + key_length, + ) = result_lengths_info + + if left_context_key is not None: + # add the cache key and value + new_k = torch.cat( + [ + k[: mem_length + right_context_blocks_length, :, :], + left_context_key, + k[-utterance_length:, :, :], + ], + dim=0, + ) + new_v = torch.cat( + [ + v[: mem_length + right_context_blocks_length, :, :], + left_context_val, + v[-utterance_length:, :, :], + ], + dim=0, + ) + next_k = new_k[mem_length + right_context_blocks_length :, :, :] + next_v = new_v[mem_length + right_context_blocks_length :, :, :] + else: + new_k = k + new_v = v + next_k = None + next_v = None + + attention_weights, attention_weights_float, v = self.prepare_attention_weights( + q=q, + new_k=new_k, + new_v=new_v, + input_shape=input_shape, + rpe=rpe, + ) + + # mask attention + attention_mask = attention_mask.unsqueeze(0) + attention_weights_float = attention_weights_float.masked_fill( + attention_mask, float(self.negative_inf) + ) + + rc_output_memory = self.prepare_attention_output( + attention_weights=attention_weights, + attention_weights_float=attention_weights_float, + v=v, + input_shape=input_shape, + key_length=key_length, + padding_mask=padding_mask, + rpe=rpe, + ) + + if self.use_mem: + # next_m length equals to summary length - 1 + # last memory is ignored + if self.mini_batches: + next_m = rc_output_memory[-summary_length:] + else: + next_m = rc_output_memory[-summary_length:-1] + + next_m = self.squash_mem(next_m) + # rc and output + rc_output = rc_output_memory[:-summary_length] + if not self.nonlinear_squash_mem: + next_m = torch.clamp(next_m, min=-10, max=10) + else: + next_m = mems + rc_output = rc_output_memory + + return rc_output, next_m, next_k, next_v + + @torch.jit.export + def forward_jit( + self, + input: Tensor, + lengths: Tensor, + mems: Tensor, + left_context_key: Tensor, + left_context_val: Tensor, + rpe: Optional[Tensor], + ) -> Tuple[Tensor, Tensor, Tensor, Tensor]: + """ + forward function for NoSegAugmentedMemoryMultiheadAttentionBmm in decoding. + + args: + input: formed in the following way + [right_context_0, right_contex_1, ..., seg_0, seg_1, + ..., summary_0, summary_1,..] + lengths: the length of query which is [seg_0, seg_1, ....] + mems: [mem_0, mem_1, ...]. + left_context_key: left_context for key part. This is only used for online + decoding. In training, this is empty tensor + left_context_val: left_context for value part. This is only used for online + decoding. In training, this is empty tensor + + """ + lc_length = left_context_key.size(0) + + # In decoding, summary_length = 1 or 0 + if self.use_mem: + summary_length = 1 + else: + summary_length = 0 + + results = self.prepare_qkv( + input=input, + mems=mems, + lengths=lengths, + summary_length=summary_length, + lc_length=lc_length, + ) + result_qkv, input_shape, result_lengths_info, padding_mask = results + q, k, v = result_qkv + ( + mem_length, + utterance_length, + right_context_blocks_length, + key_length, + ) = result_lengths_info + + # add the cache key and value + new_k = torch.cat( + [ + k[: mem_length + right_context_blocks_length, :, :], + left_context_key, + k[-utterance_length:, :, :], + ], + dim=0, + ) + new_v = torch.cat( + [ + v[: mem_length + right_context_blocks_length, :, :], + left_context_val, + v[-utterance_length:, :, :], + ], + dim=0, + ) + next_k = new_k[mem_length + right_context_blocks_length :, :, :] + next_v = new_v[mem_length + right_context_blocks_length :, :, :] + + attention_weights, attention_weights_float, v = self.prepare_attention_weights( + q=q, + new_k=new_k, + new_v=new_v, + input_shape=input_shape, + rpe=rpe, + ) + # In online decoding, we don't have attention mask. But we still need + # to disable the attention from summary query to memory + attention_weights_float[:, -1, :mem_length] = float(self.negative_inf) + rc_output_memory = self.prepare_attention_output( + attention_weights=attention_weights, + attention_weights_float=attention_weights_float, + v=v, + input_shape=input_shape, + key_length=key_length, + padding_mask=padding_mask, + rpe=rpe, + ) + + # In decoding, summary length is 1 + if self.use_mem: + next_m = rc_output_memory[-1:] + next_m = self.squash_mem(next_m) + # rc and output + rc_output = rc_output_memory[:-1] + if not self.nonlinear_squash_mem: + next_m = torch.clamp(next_m, min=-10, max=10) + else: + rc_output = rc_output_memory + # empty tensor as input mems + next_m = mems + + return rc_output, next_m, next_k, next_v + + def quantize_(self, params=None): + if params and "per_channel" in params and params["per_channel"]: + qconfig = per_channel_dynamic_qconfig + else: + qconfig = default_dynamic_qconfig + torch.quantization.quantize_dynamic( + self, {torch.nn.Linear: qconfig}, dtype=torch.qint8, inplace=True + ) + return self + + +class NoSegAugmentedMemoryTransformer(nn.Module): + """ + Whole utterance augmented memory transformer. + + This is not pyspeech nn layer. It is used as a module in a master layer where + multiple transformers is used. + """ + + def __init__( + self, + input_dim, + num_heads, + ffn_dim, + dropout_in_attn=0.0, + dropout_on_attn=None, + dropout_on_fc1=None, + dropout_on_fc2=None, + activation_fn="relu", + tanh_on_mem=False, + std_scale=None, + scaled_init=False, + segment_size=128, + use_mem=True, + mini_batches=False, + negative_inf="-inf", + layer_index=-1, + summarization_method="mean", + max_relative_position=0, + rpe_old_option=True, + ): + super(NoSegAugmentedMemoryTransformer, self).__init__() + + self.attention = NoSegAugmentedMemoryMultiheadAttentionBmm( + input_dim=input_dim, + num_heads=num_heads, + dropout=dropout_in_attn, + scaled_init=scaled_init, + tanh_on_mem=tanh_on_mem, + std_scale=std_scale, + use_mem=use_mem, + mini_batches=mini_batches, + negative_inf=negative_inf, + layer_index=layer_index, + max_relative_position=max_relative_position, + ) + self.dropout = nn.Dropout(dropout_on_attn) + self.pos_ff = PositionwiseFF( + input_dim=input_dim, + ffn_dim=ffn_dim, + dropout_on_fc1=dropout_on_fc1, + dropout_on_fc2=dropout_on_fc2, + activation_fn=activation_fn, + ) + self.layer_norm_pre = Fp32LayerNorm(input_dim) + self.layer_norm = Fp32LayerNorm(input_dim) + self.segment_size = segment_size + self.use_mem = use_mem + + self.memory_op = SummarizationLayer( + summarization_method, segment_size, input_dim + ) + + def set_mini_batches(self, mini_batches): + self.attention.mini_batches = mini_batches + + def gen_summary_queries(self, input): + sum_input = self.memory_op(input) + return sum_input + + def pre_attention_ops(self, input, right_context_blocks): + rc_length = right_context_blocks.size(0) + input_length = input.size(0) + + rc_and_input = torch.cat([right_context_blocks, input], dim=0) + residual_input = rc_and_input + rc_and_input = self.layer_norm_pre(rc_and_input) + + query_input = rc_and_input[-input_length:, :, :] + return rc_length, input_length, residual_input, query_input, rc_and_input + + def after_attention_ops(self, attention_output, residual_input): + output = self.dropout(attention_output) + output = output + residual_input + output = self.pos_ff(output) + output = self.layer_norm(output) + return output + + @torch.jit.export + def forward_jit( + self, + input: Tensor, + lengths: Tensor, + mems: Tensor, + left_context_key: Tensor, + left_context_val: Tensor, + right_context_blocks: Tensor, + rpe: Optional[Tensor], + ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]: + + results = self.pre_attention_ops(input, right_context_blocks) + rc_length, input_length, residual_input, query_input, rc_and_input = results + + # In online decoding, the summary query size is always 1 or 0 + if self.use_mem: + summary_query = self.gen_summary_queries(query_input) + summary_query = summary_query[0:1, :, :] + rc_qu_su = torch.cat([rc_and_input, summary_query], dim=0) + else: + rc_qu_su = rc_and_input + + rc_output, next_m, next_k, next_v = self.attention.forward_jit( + input=rc_qu_su, + lengths=lengths, + mems=mems, + left_context_key=left_context_key, + left_context_val=left_context_val, + rpe=rpe, + ) + rc_output = self.after_attention_ops(rc_output, residual_input) + results = ( + rc_output[-input_length:, :, :], + next_m, + rc_output[0:rc_length, :, :], + next_k, + next_v, + ) + return results + + @torch.jit.unused + def forward( + self, + input, + lengths, + mems, + right_context_blocks, + attention_mask, + pre_mems, + left_context_key, + left_context_val, + rpe, + ): + + results = self.pre_attention_ops(input, right_context_blocks) + rc_length, input_length, residual_input, query_input, rc_and_input = results + if self.use_mem: + summary_query = self.gen_summary_queries(query_input) + rc_qu_su = torch.cat([rc_and_input, summary_query], dim=0) + else: + rc_qu_su = rc_and_input + + rc_output, next_m, next_k, next_v = self.attention( + input=rc_qu_su, + lengths=lengths, + mems=mems, + attention_mask=attention_mask, + pre_mems=pre_mems, + left_context_key=left_context_key, + left_context_val=left_context_val, + rpe=rpe, + ) + + # [TODO] Note memory did not go through pos_ff. What happen if we pass + # memory through the pos_ff as well? + rc_output = self.after_attention_ops(rc_output, residual_input) + results = ( + rc_output[-input_length:, :, :], + next_m, + rc_output[0:rc_length, :, :], + next_k, + next_v, + ) + + return results + + +class NoSegAugmentedMemoryTransformerEncoderLayer(FairseqEncoder): + """ + Whole utterance augmented memory transformer encoder layer. This is a master layer + where we can define multiple augmented memory transformers. There are two reasons + to setup the master layer. + 1. We only need to define once about the attention mask. All the layers in the master + layer share the same mask. + 2. pyspeech nn layer has special input and output format. Defining one master layer is + easier to passing memory between different layes inside the master layer + + args: + input_dim: input embedding dimension + num_heads: number of heads in multihead self-attention + ffn_dim: ffn dimension in FFN layer + num_layers: number of augmented memory transformer layers + dropout_in_attn: dropout used in multi-head self-attention + dropout_on_attn: dropout used for output from te multihead self-attention + dropout_on_fc1: dropout used in FFN layer for the first linear layer + dropout_on_fc2: dropout used in FFN layer for the second linear layer + segment_size: segment size for each segment + context_config: (left_context_size, right_context_size) defines the surround context size + for each segment + max_memory_size: maximum memory size used for each segment + scaled_init: whether use scaled init for weight initialization in attention layer + std_scale: if std_scale is not None. The weak attention suppression is + turned on. For std_scale = 0.5, all the attention smaller than + mean + 0.5 * std will be suppressed. + activation_fn: activation function used in FFN layer. [ReLU, GELU] supported + tanh_on_mem: whether use tanh on memory + mini_batches: use mini-btach training + negative_inf: the negative infinity value used in attention masking. default is "-inf". + For some situation, e.g. LM. it is better to use "-1e8" to avoid nan issue. + summarization_method: method to generate segment summrization embedding + max_relative_position: max relatie position for relative position embedding + rpe_old_option: To be compatible with previous model. The previous model + was trained with attention += attention + rpe. The correct equation + should be attention = attention + rpe + [TODO]: remove the rpe_old_option by the end of 2021 Q1. + + """ + + def __init__( + self, + input_dim, + num_heads, + ffn_dim, + num_layers=1, + dropout_in_attn=0.0, + dropout_on_attn=0.0, + dropout_on_fc1=0.0, + dropout_on_fc2=0.0, + segment_size=128, + context_config=(0, 0), + max_memory_size=0, + scaled_init=True, + std_scale=None, + activation_fn="relu", + tanh_on_mem=False, + mini_batches=False, + negative_inf="-inf", + deep_init=True, + summarization_method="mean", + max_relative_position=0, + rpe_old_option=True, + ): + super().__init__(None) + if input_dim % num_heads: + raise ValueError( + "input_dim ({}) must be divisible by num_heads ({})".format( + input_dim, num_heads + ) + ) + + # we used to support growing memory size. However, it will cause + # cross stream batching failure. Now we need to have exact max memory size + if max_memory_size < 0: + raise ValueError("max_memory_size must be >= 0") + + # Only assign right_context. In decoding, left context will be cached. + # No need to let the online decoder to re-assign the left context + self.left_context, self.right_context = context_config + self.segment_size = segment_size + self.memory_dim = input_dim + self.max_memory_size = max_memory_size + self.mini_batches = mini_batches + if self.max_memory_size != 0: + self.use_mem = True + else: + self.use_mem = False + + self.memory_op = SummarizationLayer( + summarization_method, segment_size, input_dim + ) + + self.layers = torch.nn.ModuleList() + self.num_layers = num_layers + self.max_relative_position = max_relative_position + if self.max_relative_position > 0: + self.use_rpe = True + else: + self.use_rpe = False + for i in range(self.num_layers): + if deep_init: + layer_index = i + else: + layer_index = -1 + + self.layers.append( + NoSegAugmentedMemoryTransformer( + num_heads=num_heads, + input_dim=input_dim, + ffn_dim=ffn_dim, + dropout_in_attn=dropout_in_attn, + dropout_on_attn=dropout_on_attn, + dropout_on_fc1=dropout_on_fc1, + dropout_on_fc2=dropout_on_fc2, + segment_size=segment_size, + std_scale=std_scale, + activation_fn=activation_fn, + tanh_on_mem=tanh_on_mem, + scaled_init=scaled_init, + use_mem=self.use_mem, + mini_batches=mini_batches, + negative_inf=negative_inf, + layer_index=layer_index, + summarization_method=summarization_method, + max_relative_position=max_relative_position, + rpe_old_option=rpe_old_option, + ) + ) + + def set_mini_batches(self, mini_batches): + # handy function only used for unit test + self.mini_batches = mini_batches + for layer in self.layers: + layer.set_mini_batches(mini_batches) + + def _get_relative_position( + self, + input: Tensor, + max_relative_position: int, + left_context_length: int, + past_length: int, + is_decoding: bool, + ): + # For training, we copy the right context to the start of the utterance + # First dimension in distance is corresponding to query. + # [right context, utterance, summary vector] + # Second dimension in distance is corresponding to key. + # [Memory bank, right context, utterance] + # For summary vector in query part, the distance with + # all other position is 2*max_position. For memory bank in key, + # the distance with all other positions is 0. + + T, B, D = input.shape + num_segs = math.ceil((T - self.right_context) / self.segment_size) + + # utterance + u_st = past_length * self.segment_size + u_ed = u_st + T + utterance_ranges = torch.arange(u_st, u_ed - self.right_context) + + # left context. Only in minibatch or decoding + left_context_ranges = torch.arange(u_st - left_context_length, u_st) + + # Right context block + # right context + utterance + right_context_blocks = [] + for i in range(0, num_segs - 1): + st = (i + 1) * self.segment_size + u_st + ed = st + self.right_context + assert ed < u_ed + temp = torch.arange(st, ed) + right_context_blocks.append(temp) + right_context_blocks.append(torch.arange(u_ed - self.right_context, u_ed)) + right_context_ranges = torch.cat(right_context_blocks) + + if self.use_mem: + # Memory bank + # The position for memory -n, .., -1 + if is_decoding: + memory_size = min(past_length, self.max_memory_size) + else: + memory_size = num_segs + past_length - 1 + memory_bank_ranges = torch.arange( + -max_relative_position - 1, -max_relative_position - 1 - memory_size, -1 + ) + + # summary vector + # The position for summary vector as the T+max_relative_position+1. + # After the clamping, the relative position is max_relative_position + summary_pos_st = u_ed + max_relative_position + 1 + summary_vector_ranges = torch.arange( + summary_pos_st, summary_pos_st + num_segs + ) + + key_ranges = torch.cat( + [ + memory_bank_ranges, + right_context_ranges, + left_context_ranges, + utterance_ranges, + ] + ) + + query_ranges = torch.cat( + [right_context_ranges, utterance_ranges, summary_vector_ranges] + ) + else: + key_ranges = torch.cat( + [right_context_ranges, left_context_ranges, utterance_ranges] + ) + + query_ranges = torch.cat([right_context_ranges, utterance_ranges]) + + distance = key_ranges[None, :] - query_ranges[:, None] + distance_clamp = ( + torch.clamp(distance, -max_relative_position, max_relative_position) + + max_relative_position + ) + distance_clamp = distance_clamp.to(input.device).long().detach() + return distance_clamp + + def _get_attention_mask(self, input, past_length=0, left_context_cache=0): + # attention mask for each query contains three parts: + # 1. memory part + # 2. left_context + segment + # 3. right_context_block + # so for each segment and its correspoinding right context block, + # the attention matrix is formed by 9 parts: + # [0, m, 0, 0, right_context, 0, 0, seg, 0] + # [before memory, memory, after memory, before right context, right_context, + # after right context, before seg, seg, after seg] + # + # Query is formed in the way as [right_context_blocks, utterance, summary] + # + # Note: put m and right_context before segment is convenient + # for padding_mask operation. + # Key lengths = m_length + right_context_block_length + lengths + utterance_length, batch_size, _ = input.shape + summary_length = math.ceil(utterance_length / self.segment_size) + num_segs = summary_length + rc_length = self.right_context * num_segs + rc = self.right_context + lc = self.left_context + + # using mini-batches, there is left context cache available for current + # sequence. + lcc = left_context_cache + + # max_memory_size is 0 then we don't have memory and summary + # past_length is the memory carry from previous sequence + if self.use_mem: + mem_length = num_segs - 1 + past_length + else: + mem_length = 0 + rc_mask = [] + query_mask = [] + summary_mask = [] + for j in range(0, num_segs): + ssize = min(self.segment_size, utterance_length - j * self.segment_size) + + rc_size = rc + rc_mat = [] + q_mat = [] + s_mat = [] + m_start = max(j + past_length - self.max_memory_size, 0) + + # max_memory_size is 0, then we don't use memory + if self.use_mem: + # part 0: before memory + rc_mat.append(input.new_zeros(rc_size, m_start)) + q_mat.append(input.new_zeros(ssize, m_start)) + s_mat.append(input.new_zeros(1, m_start)) + + # part 1: memory + col_1 = j + past_length - m_start + rc_mat.append(torch.ones(rc_size, col_1, device=input.device)) + q_mat.append(torch.ones(ssize, col_1, device=input.device)) + # based on D22875746, disable summary query attention + # on memeory is better for long form utterance + s_mat.append(input.new_zeros(1, col_1)) + + # part 2: after memory + col_2 = mem_length - (j + past_length) + rc_mat.append(input.new_zeros(rc_size, col_2)) + q_mat.append(input.new_zeros(ssize, col_2)) + s_mat.append(input.new_zeros(1, col_2)) + + # part 3: before right context + rc_start = j * rc + rc_mat.append(input.new_zeros(rc_size, rc_start)) + q_mat.append(input.new_zeros(ssize, rc_start)) + s_mat.append(input.new_zeros(1, rc_start)) + + # part 4: right context + rc_end = rc_start + rc + col_4 = rc + rc_mat.append(torch.ones(rc_size, col_4, device=input.device)) + q_mat.append(torch.ones(ssize, col_4, device=input.device)) + s_mat.append(torch.ones(1, col_4, device=input.device)) + + # part 5: after right context + col_5 = rc_length - rc_end + rc_mat.append(input.new_zeros(rc_size, col_5)) + q_mat.append(input.new_zeros(ssize, col_5)) + s_mat.append(input.new_zeros(1, col_5)) + + # part 6: before query segment + seg_start = max(j * self.segment_size + lcc - lc, 0) + rc_mat.append(input.new_zeros(rc_size, seg_start)) + q_mat.append(input.new_zeros(ssize, seg_start)) + s_mat.append(input.new_zeros(1, seg_start)) + + # part 7: query segment + # note: right context is put in right context block + # here we only need to consider about left context + seg_end = min((j + 1) * self.segment_size + lcc, utterance_length + lcc) + col_7 = seg_end - seg_start + rc_mat.append(torch.ones(rc_size, col_7, device=input.device)) + q_mat.append(torch.ones(ssize, col_7, device=input.device)) + s_mat.append(torch.ones(1, col_7, device=input.device)) + + # part 8: after query segment + col_8 = utterance_length + lcc - seg_end + rc_mat.append(input.new_zeros(rc_size, col_8)) + q_mat.append(input.new_zeros(ssize, col_8)) + s_mat.append(input.new_zeros(1, col_8)) + + rc_mask.append(torch.cat(rc_mat, dim=1)) + query_mask.append(torch.cat(q_mat, dim=1)) + summary_mask.append(torch.cat(s_mat, dim=1)) + + # no memory, then we don't need summary either + if self.use_mem: + attention_mask = ( + 1 + - torch.cat( + [ + torch.cat(rc_mask, dim=0), + torch.cat(query_mask, dim=0), + torch.cat(summary_mask, dim=0), + ], + dim=0, + ) + ).to(torch.bool) + else: + attention_mask = ( + 1 + - torch.cat( + [torch.cat(rc_mask, dim=0), torch.cat(query_mask, dim=0)], dim=0 + ) + ).to(torch.bool) + + return attention_mask + + @torch.jit.export + def init_state( + self, batch_size: int, device: Optional[Device] = None + ) -> List[Tensor]: + empty_memory = torch.zeros( + self.num_layers, + self.max_memory_size, + batch_size, + self.memory_dim, + device=device, + ) + left_context_key = torch.zeros( + self.num_layers, + self.left_context, + batch_size, + self.memory_dim, + device=device, + ) + left_context_val = torch.zeros( + self.num_layers, + self.left_context, + batch_size, + self.memory_dim, + device=device, + ) + past_length = torch.zeros(1, batch_size, dtype=torch.int32, device=device) + + return [empty_memory, left_context_key, left_context_val, past_length] + + @torch.jit.export + def batch_state(self, states: List[List[Tensor]]) -> List[Tensor]: + if len(states) == 0: + return [] + batched_m = [] + batched_lc_key = [] + batched_lc_val = [] + batched_past_length = [] + for state in states: + if len(state) == 0: + continue + m, lc_key, lc_val, past_length = state + batched_m.append(m) + batched_lc_key.append(lc_key) + batched_lc_val.append(lc_val) + batched_past_length.append(past_length) + + if ( + (len(batched_m) == 0) + or (len(batched_lc_key) == 0) + or (len(batched_lc_val) == 0) + or (len(batched_past_length) == 0) + ): + return [ + torch.tensor([]), + torch.tensor([]), + torch.tensor([]), + torch.tensor([]), + ] + + batched_m = torch.cat(batched_m, dim=2) + batched_lc_key = torch.cat(batched_lc_key, dim=2) + batched_lc_val = torch.cat(batched_lc_val, dim=2) + batched_past_length = torch.cat(batched_past_length, dim=1) + return [batched_m, batched_lc_key, batched_lc_val, batched_past_length] + + @torch.jit.export + def reorder_state(self, state: List[Tensor], indices: Tensor) -> List[Tensor]: + if len(state) == 0: + return [] + m, lc_key, lc_val, past_length = state + indices = indices.to(device=m.device) + reord_m = torch.index_select(m, 2, indices) + reord_lc_key = torch.index_select(lc_key, 2, indices) + reord_lc_val = torch.index_select(lc_val, 2, indices) + reord_past_length = torch.index_select(past_length, 1, indices) + return [reord_m, reord_lc_key, reord_lc_val, reord_past_length] + + @torch.jit.export + def reset_state(self, state: List[Tensor], indices: Tensor) -> List[Tensor]: + m, lc_key, lc_val, past_length = state + m = m.index_fill(dim=2, index=indices, value=0.0) + lc_key = lc_key.index_fill(dim=2, index=indices, value=0.0) + lc_val = lc_val.index_fill(dim=2, index=indices, value=0.0) + past_length = past_length.index_fill(dim=1, index=indices, value=0) + + return [m, lc_key, lc_val, past_length] + + @torch.jit.export + def state_size(self) -> int: + return 4 + + @torch.jit.export + def batch_size_in_state( + self, state: Optional[List[Tensor]], sloppy: bool = True + ) -> Optional[int]: + if state is None: + return None + return state[0].size(2) + + def gen_summary_queries(self, input): + sum_input = self.memory_op(input) + return sum_input + + def _gen_right_context_padded_input(self, input): + # This function deals with input that is already + # padded with right context (e.g. minibatch training) + right_context_blocks = [] + T, B, D = input.shape + num_segs = math.ceil((T - self.right_context) / self.segment_size) + for i in range(0, num_segs - 1): + st = (i + 1) * self.segment_size + ed = st + self.right_context + assert ed < T + temp = input[st:ed, :, :] + right_context_blocks.append(temp) + + # last segment right context is already available + right_context_blocks.append(input[T - self.right_context :, :, :]) + return torch.cat(right_context_blocks, dim=0) + + def _gen_segs_right_context(self, input, lengths): + segments = [] + T, B, D = input.size() + nT = T - self.right_context + + # assume input is right context padded + num_segs = math.ceil(nT / self.segment_size) + # pad zeros to the utterance to make sure each + # segment has the same right context. For the + for i in range(0, num_segs - 1): + st = i * self.segment_size + ed = min(T, st + self.segment_size + self.right_context) + temp = input[st:ed, :, :] + rest_lengths = torch.clamp( + lengths - self.segment_size, min=0, max=nT - (i + 1) * self.segment_size + ) + segments.append((temp, lengths - rest_lengths + self.right_context)) + lengths = rest_lengths + + last_seg = input[st + self.segment_size :, :, :] + segments.append((last_seg, rest_lengths + self.right_context)) + + return segments + + @torch.jit.unused + def forward( + self, input: Tensor, padding_masks: Tensor, state: Optional[List[Tensor]] = None + ) -> Tuple[Tensor, Tensor, List[Tensor], List[Tensor]]: + # Xutai: originally the second argument is lengths. + lengths = (~padding_masks).sum(dim=1).long() + # mini batch training. + if self.mini_batches: + return self.forward_mini_batches(input, lengths, state) + + # regular full sequence training. Note, assume the right context in provided + # in the input. + T, B, D = input.size() + right_context_blocks = self._gen_right_context_padded_input(input) + + # generate the relative positional embedding + if self.use_rpe: + rpe = self._get_relative_position( + input=input, + max_relative_position=self.max_relative_position, + left_context_length=0, + past_length=0, + is_decoding=False, + ) + else: + rpe = None + input = input[: T - self.right_context, :, :] + + attention_mask = self._get_attention_mask(input) + + # firt layer use each segment mean as memory + # ignore the last one seg average + if self.use_mem: + mems = self.gen_summary_queries(input)[:-1, :, :] + else: + mems = torch.zeros(0, input.size(1), input.size(2), device=input.device) + mems = mems.type_as(input) + + output = input + all_outputs = [] + + for layer in self.layers: + output, mems, right_context_blocks, _, _ = layer( + input=output, + lengths=lengths, + attention_mask=attention_mask, + mems=mems, + right_context_blocks=right_context_blocks, + pre_mems=None, + left_context_key=None, + left_context_val=None, + rpe=rpe, + ) + all_outputs.append(output) + return output, padding_masks, [], all_outputs + + def forward_jit_mini_batch_init( + self, + seg: Tensor, + state: Optional[List[Tensor]] = None, + is_decoding: bool = False, + ): + # Prepare state. In whole sequence training, state is ignored. + # For minibatch training, we need to prepare state + if state is None: + state = self.init_state(batch_size=seg.size(1), device=seg.device) + if seg.dtype == torch.half: + state = [state[0].half(), state[1].half(), state[2].half(), state[3]] + + if self.use_mem: + # note input average only on seg, not on right context + # first layer use each segmetn mean as memory. the last + # one segment average is used in state + full_mems = self.gen_summary_queries(seg) + if is_decoding: + mems = full_mems[0:1, :, :] + state_mems = torch.cat([state[0][0], mems], dim=0) + else: + mems = full_mems[:-1, :, :] + state_mems = torch.cat([state[0][0], full_mems], dim=0) + else: + mems = state[0][0] + state_mems = mems + + # track processed segment number or memory number + # the same batch as the same bumber of past length + past_length = state[3][0][0].item() + past_left_context = min(past_length * self.segment_size, self.left_context) + past_length = min(self.max_memory_size, past_length) + + return state, mems, state_mems, past_length, past_left_context + + def state_update_before( + self, layer: int, state: List[Tensor], past_length: int, past_left_context: int + ): + pre_mems = state[0][layer][self.max_memory_size - past_length :, :, :] + lc_key = state[1][layer][self.left_context - past_left_context :, :, :] + lc_val = state[2][layer][self.left_context - past_left_context :, :, :] + return pre_mems, lc_key, lc_val + + def state_update_after( + self, + layer: int, + state: List[Tensor], + mems: Tensor, + next_key: Tensor, + next_val: Tensor, + mems_list: List[Tensor], + lc_key_list: List[Tensor], + lc_val_list: List[Tensor], + ): + # mems is used for next layer + if layer < self.num_layers - 1: + state_mems = torch.cat([state[0][layer + 1], mems], dim=0) + mems_list.append(state_mems[-self.max_memory_size :, :, :]) + + # when mems pass to next sequence, we need the last memory. when mems + # use for the next layer, we can ignore the last memory + mems = mems[:-1, :, :] + + # note state[1][i] and state[2][i] original length equals to self.left_context + new_k = torch.cat([state[1][layer], next_key], dim=0) + new_v = torch.cat([state[2][layer], next_val], dim=0) + lc_key_list.append(new_k[-self.left_context :, :, :]) + lc_val_list.append(new_v[-self.left_context :, :, :]) + return mems_list, lc_key_list, lc_val_list, mems + + def state_update_after_loop( + self, + state: List[Tensor], + mems_list: List[Tensor], + lc_key_list: List[Tensor], + lc_val_list: List[Tensor], + update_length: int, + ): + state[0] = torch.stack(mems_list, dim=0) + state[1] = torch.stack(lc_key_list, dim=0) + state[2] = torch.stack(lc_val_list, dim=0) + state[3] = state[3] + update_length + return state + + @torch.jit.unused + def forward_mini_batches( + self, input: Tensor, lengths: Tensor, state: Optional[List[Tensor]] = None + ) -> Tuple[Tensor, Tensor, List[Tensor], List[Tensor]]: + T, B, D = input.size() + + # input without right context + seg = input[: T - self.right_context, :, :] + + # get right context blocks + right_context_blocks = self._gen_right_context_padded_input(input) + + mems_list = [] + lc_key_list = [] + lc_val_list = [] + results = self.forward_jit_mini_batch_init(seg, state, False) + state, mems, state_mems, past_length, past_left_context = results + + # relative position embedding + if self.use_rpe: + rpe = self._get_relative_position( + input=input, + max_relative_position=self.max_relative_position, + left_context_length=past_left_context, + past_length=past_length, + is_decoding=False, + ) + else: + rpe = None + + # get attention mask based on seg (not include right context) and available + # left context + attention_mask = self._get_attention_mask(seg, past_length, past_left_context) + mems_list.append(state_mems[-self.max_memory_size :, :, :]) + output = seg + i = 0 + all_outputs = [] + for layer in self.layers: + # In order to make cross stream batching work, mem, left context key + # and left context value in the state should always be the same shape. + # We use the past length to track the processed segment number. In this + # way, we take out the essential memory, left context key and left + # context val from the state. After finish the forward for current segment + # we add the new memory, left context key and left context value into the + # staate and trim out the oldest part to keep the shape consistent. + pre_mems, lc_key, lc_val = self.state_update_before( + i, state, past_length, past_left_context + ) + + output, mems, right_context_blocks, next_key, next_val = layer.forward( + input=output, + lengths=lengths, + attention_mask=attention_mask, + mems=mems, + right_context_blocks=right_context_blocks, + pre_mems=pre_mems, + left_context_key=lc_key, + left_context_val=lc_val, + rpe=rpe, + ) + all_outputs.append(output) + mems_list, lc_key_list, lc_val_list, mems = self.state_update_after( + layer=i, + state=state, + mems=mems, + next_key=next_key, + next_val=next_val, + mems_list=mems_list, + lc_key_list=lc_key_list, + lc_val_list=lc_val_list, + ) + + i += 1 + + # update state + update_length = math.ceil((T - self.right_context) / self.segment_size) + state = self.state_update_after_loop( + state=state, + mems_list=mems_list, + lc_key_list=lc_key_list, + lc_val_list=lc_val_list, + update_length=update_length, + ) + + return output, lengths, state, all_outputs + + def forward_jit_test( + self, input: Tensor, lengths: Tensor, state: Optional[List[Tensor]] = None + ) -> Tuple[Tensor, Tensor, List[Tensor]]: + """ + This one simulate sequence encoder forward jit. This is for unit test purpose. + It is not used in training or decoding. Note, extra_right_context is set in + the model. In unit test, input = [utterance, right_context], lengths = + [utterance_length]. + args: + input: input utterance + lengths: utterance input length + state: None here. input is whole utterance + """ + # [TODO] sequence_to_segment has bug in lengths. + seg_src_tokens_lengths = self._gen_segs_right_context(input, lengths) + + seg_enc_tokens_lengths: List[Tuple[Tensor, Tensor]] = [] + state: Optional[List[Tensor]] = None + for seg_src_tokens, seg_src_lengths in seg_src_tokens_lengths: + seg_enc_tokens, seg_enc_lengths, state = self.forward_jit( + input=seg_src_tokens, lengths=seg_src_lengths, state=state + ) + seg_enc_tokens_lengths.append((seg_enc_tokens, seg_enc_lengths)) + + enc_tokens, enc_lengths = segments_to_sequence( + segments=seg_enc_tokens_lengths, time_axis=0 + ) + + state = [] # returns trivial state + + return enc_tokens, enc_lengths, state + + @torch.jit.export + def forward_jit( + self, input: Tensor, lengths: Tensor, state: Optional[List[Tensor]] = None + ) -> Tuple[Tensor, Tensor, List[Tensor]]: + """ + Forward helper for online decoding. + + args: + input: [seg, right_context]. We assume in online we + always padding the right context to the preset right context size. + For the last segment, we may have short segment size, but right + context size is the same as other segments + lengths: utterance input length is the utterance segment length and + right context size + state: [memory, left_context_key, left_context_val]. To improve throughput, + in addition to memory, we also cache key and value for left_context in + multihead self-attention + """ + # In online decoding, input = [segment, right_context] + # Lengths = [segment_length, right_context_length] + # so we need strip right context in output + T, B, D = input.size() + rc_str = T - self.right_context + rc_end = T + right_context_blocks = input[rc_str:rc_end, :, :] + seg = input[:rc_str, :, :] + lengths = torch.clamp(lengths - self.right_context, min=0) + mems_list = [] + lc_key_list = [] + lc_val_list = [] + + results = self.forward_jit_mini_batch_init(seg, state, True) + state, mems, state_mems, past_length, past_left_context = results + + # relative position embedding + if self.use_rpe: + rpe = self._get_relative_position( + input=input, + max_relative_position=self.max_relative_position, + left_context_length=past_left_context, + past_length=past_length, + is_decoding=True, + ) + else: + rpe = None + + # memory for first layer. + mems_list.append(state_mems[-self.max_memory_size :, :, :]) + output = seg + i = 0 + for layer in self.layers: + # In order to make cross stream batching work, mem, left context key + # and left context value in the state should always be the same shape. + # We use the past length to track the processed segment number. In this + # way, we take out the essential memory, left context key and left + # context val from the state. After finish the forward for current segment + # we add the new memory, left context key and left context value into the + # staate and trim out the oldest part to keep the shape consistent. + true_mems, lc_key, lc_val = self.state_update_before( + layer=i, + state=state, + past_length=past_length, + past_left_context=past_left_context, + ) + + output, mems, right_context_blocks, next_key, next_val = layer.forward_jit( + input=output, + lengths=lengths, + mems=true_mems, + right_context_blocks=right_context_blocks, + left_context_key=lc_key, + left_context_val=lc_val, + rpe=rpe, + ) + # mems is used for next layer + mems_list, lc_key_list, lc_val_list, _ = self.state_update_after( + layer=i, + state=state, + mems_list=mems_list, + mems=mems, + next_key=next_key, + next_val=next_val, + lc_key_list=lc_key_list, + lc_val_list=lc_val_list, + ) + i += 1 + + # update state + state = self.state_update_after_loop( + state=state, + mems_list=mems_list, + lc_key_list=lc_key_list, + lc_val_list=lc_val_list, + update_length=1, + ) + + return output, lengths, state + + def quantize_(self, params=None): + if params and "per_channel" in params and params["per_channel"]: + qconfig = per_channel_dynamic_qconfig + else: + qconfig = default_dynamic_qconfig + torch.quantization.quantize_dynamic( + self, {torch.nn.Linear: qconfig}, dtype=torch.qint8, inplace=True + ) + return self + + +# ------------------------------------------------------------------------------ +# Emformer encoder for seq2seq model +# This is a wrapper over the original emformer +# ------------------------------------------------------------------------------ +def emformer_encoder(klass): + class SpeechEncoder(klass): + def __init__(self, args): + super().__init__(args) + stride = SpeechEncoder.conv_layer_stride(args) + trf_left_context = args.segment_left_context // stride + trf_right_context = args.segment_right_context // stride + context_config = [trf_left_context, trf_right_context] + self.transformer_layers = nn.ModuleList( + [ + NoSegAugmentedMemoryTransformerEncoderLayer( + input_dim=args.encoder_embed_dim, + num_heads=args.encoder_attention_heads, + ffn_dim=args.encoder_ffn_embed_dim, + num_layers=args.encoder_layers, + dropout_in_attn=args.dropout, + dropout_on_attn=args.dropout, + dropout_on_fc1=args.dropout, + dropout_on_fc2=args.dropout, + activation_fn=args.activation_fn, + context_config=context_config, + segment_size=args.segment_length, + max_memory_size=args.max_memory_size, + scaled_init=True, # TODO: use constant for now. + tanh_on_mem=args.amtrf_tanh_on_mem, + ) + ] + ) + + def forward(self, *args, **kwargs): + encoder_out = super().forward(*args, **kwargs) + (output, encoder_padding_masks, [], all_outputs) = encoder_out.encoder_out + + # This is because that in the original implementation + # the output didn't consider the last segment as right context. + encoder_padding_masks = encoder_padding_masks[:, : output.size(0)] + # import pdb;pdb.set_trace() + + return EncoderOut( + encoder_out=output, + encoder_padding_mask=encoder_padding_masks, + encoder_embedding=None, + encoder_states=None, + src_tokens=None, + src_lengths=None, + ) + + @staticmethod + def conv_layer_stride(args): + # TODO: make it configurable from the args + return 4 + + SpeechEncoder.__name__ = klass.__name__ + return SpeechEncoder diff --git a/fairseq/models/speech_to_text/utils.py b/fairseq/models/speech_to_text/utils.py new file mode 100644 index 0000000000..573f8537c9 --- /dev/null +++ b/fairseq/models/speech_to_text/utils.py @@ -0,0 +1,564 @@ +#!/usr/bin/env python3 +# Copyright (c) 2017-present, Facebook, Inc. +# All rights reserved. +# +# This source code is licensed under the license found in the LICENSE file in +# the root directory of this source tree. An additional grant of patent rights +# can be found in the PATENTS file in the same directory. + + +import logging +from collections.abc import Iterable +from itertools import repeat +from typing import List, Optional, Tuple + +import torch +from torch import Tensor + + +# ------------------------------------------------------------------------------ +# assert_equal() +# ------------------------------------------------------------------------------ + + +def assert_equal(value1, value2, name1=None, name2=None): + """Asserts two values are equal otherwise raise an error.""" + + str_name1 = "" if name1 is None else "{} ".format(name1) + str_name2 = "" if name2 is None else "{} ".format(name2) + if value1 != value2: + str_value1 = "{}" if name1 is None else "({})" + str_value1 = str_value1.format(value1) + str_value2 = "{}" if name2 is None else "({})" + str_value2 = str_value2.format(value2) + raise ValueError( + "Expected {}{} == {}{}".format(str_name1, str_value1, str_name2, str_value2) + ) + + +def fill_config(config, key, value): + if value is not None: + if key not in config or config[key] is None: + config[key] = value + assert_equal(value, config[key], "value", f'config["{key}"]') + + +# ------------------------------------------------------------------------------ +# check_and_return_expected() +# ------------------------------------------------------------------------------ + + +def check_and_return_expected(value, undefined_value, expected_value, name=None): + """ + Return the expected value while checking if the given value is undefined or + equal to the expected value. + """ + if (undefined_value is None and value is None) or (undefined_value == value): + return expected_value + if value != expected_value: + str_name = "" if name is None else "{} ".format(name) + str_value = "{}" if name is None else "({})" + str_value = str_value.format(value) + raise ValueError( + "Expected {}{} == {}".format(str_name, str_value, expected_value) + ) + return expected_value + + +# ------------------------------------------------------------------------------ +# get_time_axis() +# ------------------------------------------------------------------------------ + + +def get_time_axis(layout): + """ + Extract the time axis from the layout, for example for breaking sequence into + segments. + """ + if layout in ["TB", "TBD"]: + return 0 + if layout in ["BT", "BTD"]: + return 1 + if layout in ["BCTD"]: + return 2 + raise ValueError("Unsupported layout = {}".format(layout)) + + +# ------------------------------------------------------------------------------ +# get_batch_axis() +# ------------------------------------------------------------------------------ + + +def get_batch_axis(layout): + """ + Extract the batch axis from the layout + """ + if layout in ["TB", "TBD"]: + return 1 + if layout in ["BT", "BTD", "BCTD"]: + return 0 + raise ValueError("Unsupported layout = {}".format(layout)) + + +# ------------------------------------------------------------------------------ +# monotonically_increasing_and_bounded() +# ------------------------------------------------------------------------------ + + +def monotonically_increasing_and_bounded(iterable, min=None, max=None): + """ + Check if the elements in the given iterable are monotonically increasing and + bounded by upper/lower bounds. + """ + if not isinstance(iterable, Iterable): + raise TypeError( + "Expected iterable to be of type Iterable, got ({})".format( + iterable.__class__.__name__ + ) + ) + for i in range(len(iterable)): + if min is not None and iterable[i] < min: + return False + if max is not None and iterable[i] > max: + return False + if i > 0 and iterable[i] <= iterable[i - 1]: + return False + return True + + +# ------------------------------------------------------------------------------ +# to_pair() +# ------------------------------------------------------------------------------ + + +def to_pair(value, name): + """Make a pair (of type tuple) of given value.""" + if isinstance(value, Iterable): + if len(value) != 2: + raise ValueError( + "Expected `{}` to have exactly 2 elements, got: ({})".format( + name, value + ) + ) + return value + return tuple(repeat(value, 2)) + + +# ------------------------------------------------------------------------------ +# infer_conv_output_attrs() +# ------------------------------------------------------------------------------ + + +# TODO(cfyeh): figure out if we can get `output_dim` without calling the module. +def infer_conv_output_attrs( + module, input_channels, input_dim, batch_size=1, max_length=8 +): + """Get output attributes of a module with input.""" + input = torch.randn(batch_size, input_channels, max_length, input_dim) + output = module(input) + output_channels = output.shape[1] + output_dim = output.shape[-1] + return output_channels, output_dim + + +# ------------------------------------------------------------------------------ +# NoOp +# ------------------------------------------------------------------------------ + + +class NoOp(torch.nn.Module): + """ + NoOp simply passes the input as the output. + """ + + def __init__(self): + super().__init__() + + def forward(self, input: Tensor) -> Tensor: + return input + + +# ------------------------------------------------------------------------------ +# Permute: a torch.nn.Module applies permutation on the input tensor. +# ------------------------------------------------------------------------------ + + +class Permute(torch.nn.Module): + def __init__(self, dims): + super().__init__() + self.dims = dims + + def forward(self, input: Tensor) -> Tensor: + return input.permute(self.dims).contiguous() + + +# ------------------------------------------------------------------------------ +# lengths_to_padding_mask() +# ------------------------------------------------------------------------------ + + +def lengths_to_padding_mask(lengths: Tensor) -> Tensor: + """Convert lengths of shape (B, ) to padding mask.""" + batch_size = lengths.shape[0] + max_length = int(torch.max(lengths).item()) + padding_mask = torch.arange( # [0, ..., T-1] + max_length, device=lengths.device, dtype=lengths.dtype + ).expand(batch_size, max_length) >= lengths.unsqueeze(1) + + return padding_mask + + +# ------------------------------------------------------------------------------ +# lengths_to_attention_mask() +# ------------------------------------------------------------------------------ + + +def lengths_to_attention_mask( + lengths: Tensor, + left_context: Optional[int] = None, + right_context: Optional[int] = None, +) -> Optional[Tensor]: + """ + Generate attention mask based on (lengths, left_context, right_context). + left_context is None means unlimited left context. + right_context is None means unlimited right context. + """ + + if left_context is None and right_context is None: + return None + + max_length = int(torch.max(lengths).item()) + + # For example, with `max_length` == 5, + # indices = tensor([ + # [ 0, 1, 2, 3, 4, 5], + # [-1, 0, 1, 2, 3, 4], + # [-2, -1, 0, 1, 2, 3], + # [-3, -2, -1, 0, 1, 2], + # [-4, -3, -2, -1, 0, 1], + # [-5, -4, -3, -2, -1, 0], + # ]) + + # In some cases the second torch.arange is created on cpu which causes a + # failure. Adding the device option to guard against it. + indices = torch.arange( + max_length, device=lengths.device, dtype=lengths.dtype + ).expand(max_length, max_length) - torch.arange( + max_length, device=lengths.device + ).view( + max_length, -1 + ) + + # For example, with `max_length` == 5, + # bool_mask = tensor([ + # [True, True, True, True, True], + # [True, True, True, True, True], + # [True, True, True, True, True], + # [True, True, True, True, True], + # [True, True, True, True, True], + # ]) + bool_mask = ( + torch.tensor([True]).to(device=lengths.device).expand(max_length, max_length) + ) + + # For example, with `max_length` == 5, left_context == 2 + # left_mask = tensor([ + # [ True, True, True, True, True], + # [ True, True, True, True, True], + # [ True, True, True, True, True], + # [False, True, True, True, True], + # [False, False, True, True, True], + # ]) + if left_context is not None: + left_mask = indices >= -left_context + bool_mask = bool_mask & left_mask + + # For example, with `max_length` == 5, right_context == 1 + # right_mask = tensor([ + # [True, True, False, False, False], + # [True, True, True, False, False], + # [True, True, True, True, False], + # [True, True, True, True, True], + # [True, True, True, True, True], + # ]) + if right_context is not None: + right_mask = indices <= right_context + bool_mask = bool_mask & right_mask + + bool_mask = (~bool_mask).to(device=lengths.device) + return bool_mask + + +# ------------------------------------------------------------------------------ +# infer_output_norm() +# ------------------------------------------------------------------------------ + + +def infer_output_norm(module, output_norm=None): + """ + Infer the output norm (string and module) needed on the module gvien desired + output normalization. + """ + if output_norm == module.output_norm(): + # output_norm already matches module.output_norm(). + return (None, NoOp()) + + if output_norm is None and module.output_norm() is not None: + logger = logging.getLogger("infer_output_norm()") + logger.warning( + "trying to set output_norm ({}) ".format(output_norm) + + "but got module.output_norm() ({}), ".format(module.output_norm()) + + "the combined output_norm() will be ({})".format(module.output_norm()) + ) + return (None, NoOp()) + + if output_norm == "log_softmax": + if module.output_norm() is not None: + raise ValueError( + "incompatible output_norm ({}) ".format(output_norm) + + "and module.output_norm() ({})".format(module.output_norm()) + ) + else: + return ("log_softmax", torch.nn.LogSoftmax(dim=-1)) + + if output_norm == "softmax": + if module.output_norm() is not None: + raise ValueError( + "incompatible output_norm ({}) ".format(output_norm) + + "and module.output_norm() ({})".format(module.output_norm()) + ) + else: + return ("softmax", torch.nn.Softmax(dim=-1)) + + raise ValueError( + "output_norm ({}) not in ".format(output_norm) + + "supported list = [None, softmax, log_softmax]" + ) + + +# ------------------------------------------------------------------------------ +# infer_channels_from_layout() +# ------------------------------------------------------------------------------ + + +def infer_channels_from_layout(layout, channels): + """Extract the number of channels from the layout.""" + if layout in ("TBD", "BTD"): + if channels is not None and channels != 1: + raise ValueError( + "Expected channels ({}) to be 1 for layout = {}".format( + channels, layout + ) + ) + if channels is None: + return 1 + return channels + + +# ------------------------------------------------------------------------------ +# pad_sequence() +# ------------------------------------------------------------------------------ + + +@torch.jit.export +def pad_sequence( + sequence: Tensor, + time_axis: int, + extra_left_context: int = 0, + extra_right_context: int = 0, +) -> Tensor: + """Pad extra left/right contexts to the sequence.""" + + if extra_left_context == 0 and extra_right_context == 0: + return sequence + + tensors_to_concat = [] + + if extra_left_context: + size = (extra_left_context,) + fill_value = 0 + indices = torch.full( + size=size, + fill_value=fill_value, + dtype=torch.long, + device=sequence.device, + ) + left_padding = torch.index_select(sequence, time_axis, indices) + tensors_to_concat.append(left_padding) + + tensors_to_concat.append(sequence) + + # NOTE(cfyeh): for efficiency reason we pad 0 instead of the last frame for + # extra right contexts. + if extra_right_context: + size = list(sequence.shape) + size[time_axis] = extra_right_context + right_padding = torch.zeros(size, dtype=sequence.dtype, device=sequence.device) + tensors_to_concat.append(right_padding) + + padded_sequence = torch.cat(tensors_to_concat, dim=time_axis) + return padded_sequence + + +# ------------------------------------------------------------------------------ +# sequence_to_segments() +# ------------------------------------------------------------------------------ + + +@torch.jit.export +def sequence_to_segments( + sequence: Tensor, + time_axis: int, + lengths: Tensor, + segment_size: Optional[int] = None, + extra_left_context: int = 0, + extra_right_context: int = 0, +) -> List[Tuple[Tensor, Tensor]]: + """Breaks sequence into segments.""" + + sequence = pad_sequence( + sequence=sequence, + time_axis=time_axis, + extra_left_context=extra_left_context, + extra_right_context=extra_right_context, + ) + + lengths = lengths + extra_left_context + extra_right_context + + segments: List[Tuple[Tensor, Tensor]] = [] + + if segment_size is None: + segments.append((sequence, lengths)) + return segments + + offset = 0 + end = sequence.shape[time_axis] + step = segment_size + size = extra_left_context + segment_size + extra_right_context + + while offset + extra_left_context + extra_right_context < end: + clamped_size = min(size, end - offset) + segment_lengths = torch.clamp(lengths - offset, min=0, max=clamped_size) + indices = torch.arange( + start=offset, + end=(offset + clamped_size), + step=1, + dtype=torch.long, + device=sequence.device, + ) + segment_tensor = torch.index_select(sequence, time_axis, indices) + segments.append((segment_tensor, segment_lengths)) + offset = offset + step + + return segments + + +# ------------------------------------------------------------------------------ +# segments_to_sequence() +# ------------------------------------------------------------------------------ + + +@torch.jit.export +def segments_to_sequence( + segments: List[Tuple[Tensor, Tensor]], time_axis: int +) -> Tuple[Tensor, Tensor]: + """Concatenate segments into a full sequence.""" + if len(segments) == 1: + return segments[0] + + tensors_to_concat: List[Tensor] = [] + lengths_to_stack: List[Tensor] = [] + + for tensor, lengths in segments: + tensors_to_concat.append(tensor) + lengths_to_stack.append(lengths) + + sequence = torch.cat(tensors_to_concat, dim=time_axis) + lengths = torch.stack(lengths_to_stack, dim=0) + lengths = torch.sum(lengths, dim=0) + + return sequence, lengths + + +def lengths_to_encoder_padding_mask(lengths, batch_first: bool = False): + """ + convert lengths (a 1-D Long/Int tensor) to 2-D binary tensor + + Args: + lengths: a (B, )-shaped tensor + batch_first: whether to return a (B, T) tensor + + Return: + max_length: maximum length of B sequences + encoder_padding_mask: a (max_length, B) binary mask, where + [t, b] = False for t < lengths[b] and True otherwise + + TODO: + kernelize this function if benchmarking shows this function is slow + """ + max_lengths = torch.max(lengths).item() + bsz = lengths.size(0) + encoder_padding_mask = torch.arange( + max_lengths + ).to( # a (T, ) tensor with [0, ..., T-1] + lengths.device + ).view( # move to the right device + 1, max_lengths + ).expand( # reshape to (1, T)-shaped tensor + bsz, -1 + ) > lengths.view( # expand to (B, T)-shaped tensor + bsz, 1 + ).expand( + -1, max_lengths + ) + if not batch_first: + return encoder_padding_mask.t(), max_lengths + else: + return encoder_padding_mask, max_lengths + + +# ------------------------------------------------------------------------------ +# attention suppression +# ------------------------------------------------------------------------------ + + +def attention_suppression(attention_weights: Tensor, scale: float): + # B, H, qlen, klen -> B, H, qlen, 1 + attention_prob = torch.nn.functional.softmax(attention_weights.float(), dim=-1) + attention_nozeros = attention_prob.to(torch.bool) + nozeros_sum = torch.sum(attention_nozeros.to(torch.float), dim=-1, keepdim=True) + + # For very sparse situation, we need get round about 0s + key_sum = torch.sum(attention_prob, dim=-1, keepdim=True) + + # nozeros_sum should > 1 + key_mean = key_sum / (nozeros_sum + 1e-8) + + # std calculation + dis = (attention_prob - key_mean) * (attention_prob - key_mean) + + # if attention_prob[i] < threshold, then dis_masked[i] = 0; for all i + dis_masked = torch.where( + attention_nozeros, dis, attention_prob.new_zeros(attention_prob.size()) + ) + + key_var = torch.sum(dis_masked, dim=-1, keepdim=True) + key_var = key_var / (nozeros_sum - 1.0 + 1e-8) + key_std = torch.sqrt(key_var) + key_thread = key_mean - scale * key_std + + # if attention_prob[i] >= key_thread, then attention_prob[i] + # , otherwise "-inf" + inf_tensor = attention_prob.new_zeros(attention_prob.size()).detach() + inf_tensor[:] = float("-inf") + attention_weights_float = torch.where( + attention_prob < key_thread, + inf_tensor, + attention_weights.float(), + ) + + return attention_weights_float.type_as(attention_weights) + + +def layer_norm_backward_hook(module, grad_input, grad_output, clamp_value): + return tuple(torch.clamp(v, min=-clamp_value, max=clamp_value) for v in grad_input) From b8786dc2aadb56bb549f92ed542875096868bdd5 Mon Sep 17 00:00:00 2001 From: Sravya Popuri Date: Tue, 2 Mar 2021 17:08:45 -0800 Subject: [PATCH 54/82] Integrate Augmented memory transformer and emformer based augmented memory transformer into fbcode Summary: Integrate Augmented memory transformer and emformer based augmented memory transformer into fbcode. This diff - Modifies the way encoder_out_dict variable is accessed in transformer_monotonic_attention.py - Fix dimension issues in augmented_memory_attention.py - Modifies the way encoder_out is accessed in emformer.py Reviewed By: jmp84 Differential Revision: D26567899 fbshipit-source-id: 9b298ad0bdf78de00b1182001813b0513d32a119 --- .../models/convtransformer_simul_trans.py | 99 ++++++++++++++++++- .../modules/augmented_memory_attention.py | 22 +++-- .../models/speech_to_text/modules/emformer.py | 26 +++-- 3 files changed, 122 insertions(+), 25 deletions(-) diff --git a/examples/simultaneous_translation/models/convtransformer_simul_trans.py b/examples/simultaneous_translation/models/convtransformer_simul_trans.py index 760a48168d..0b15e93fea 100644 --- a/examples/simultaneous_translation/models/convtransformer_simul_trans.py +++ b/examples/simultaneous_translation/models/convtransformer_simul_trans.py @@ -10,7 +10,17 @@ register_model, register_model_architecture, ) -from fairseq.models.speech_to_text import ConvTransformerModel, convtransformer_espnet +from fairseq.models.speech_to_text import ( + ConvTransformerModel, + convtransformer_espnet, + ConvTransformerEncoder, +) +from fairseq.models.speech_to_text.modules.augmented_memory_attention import ( + augmented_memory, + SequenceEncoder, + AugmentedMemoryConvTransformerEncoder, +) +from fairseq.models.speech_to_text.modules.emformer import emformer_encoder @register_model("convtransformer_simul_trans") @@ -56,3 +66,90 @@ def build_decoder(cls, args, task, embed_tokens): ) def convtransformer_simul_trans_espnet(args): convtransformer_espnet(args) + + +@register_model("convtransformer_augmented_memory") +@augmented_memory +class AugmentedMemoryConvTransformerModel(SimulConvTransformerModel): + @classmethod + def build_encoder(cls, args): + encoder = SequenceEncoder(args, AugmentedMemoryConvTransformerEncoder(args)) + + if getattr(args, "load_pretrained_encoder_from", None) is not None: + encoder = checkpoint_utils.load_pretrained_component_from_model( + component=encoder, checkpoint=args.load_pretrained_encoder_from + ) + + return encoder + + +@register_model_architecture( + "convtransformer_augmented_memory", "convtransformer_augmented_memory" +) +def augmented_memory_convtransformer_espnet(args): + convtransformer_espnet(args) + + +# ============================================================================ # +# Convtransformer +# with monotonic attention decoder +# with emformer encoder +# ============================================================================ # + + +@emformer_encoder +class ConvTransformerEmformerEncoder(ConvTransformerEncoder): + pass + + +@register_model("convtransformer_emformer") +class ConvtransformerEmformer(SimulConvTransformerModel): + @staticmethod + def add_args(parser): + super(ConvtransformerEmformer, ConvtransformerEmformer).add_args(parser) + + parser.add_argument( + "--segment-length", + type=int, + metavar="N", + help="length of each segment (not including left context / right context)", + ) + parser.add_argument( + "--segment-left-context", + type=int, + help="length of left context in a segment", + ) + parser.add_argument( + "--segment-right-context", + type=int, + help="length of right context in a segment", + ) + parser.add_argument( + "--max-memory-size", + type=int, + default=-1, + help="Right context for the segment.", + ) + parser.add_argument( + "--amtrf-tanh-on-mem", + default=False, + action="store_true", + help="whether to use tanh on memory vector", + ) + + @classmethod + def build_encoder(cls, args): + encoder = ConvTransformerEmformerEncoder(args) + if getattr(args, "load_pretrained_encoder_from", None): + encoder = checkpoint_utils.load_pretrained_component_from_model( + component=encoder, checkpoint=args.load_pretrained_encoder_from + ) + return encoder + + +@register_model_architecture( + "convtransformer_emformer", + "convtransformer_emformer", +) +def convtransformer_emformer_base(args): + convtransformer_espnet(args) diff --git a/fairseq/models/speech_to_text/modules/augmented_memory_attention.py b/fairseq/models/speech_to_text/modules/augmented_memory_attention.py index 5d31524b76..e7465bc889 100644 --- a/fairseq/models/speech_to_text/modules/augmented_memory_attention.py +++ b/fairseq/models/speech_to_text/modules/augmented_memory_attention.py @@ -8,7 +8,6 @@ import torch import torch.nn.functional as F from fairseq.models import FairseqEncoder -from fairseq.models.fairseq_encoder import EncoderOut from fairseq.models.speech_to_text import ( ConvTransformerEncoder, ) @@ -72,7 +71,10 @@ def forward(self, src_tokens, src_lengths, states=None): x = self.embed_scale * x subsampling_factor = 1.0 * max_seq_len / output_seq_len - input_lengths = (src_lengths.float() / subsampling_factor).round().long() + input_lengths = torch.max( + (src_lengths.float() / subsampling_factor).ceil().long(), + x.size(0) * src_lengths.new_ones([src_lengths.size(0)]).long(), + ) encoder_padding_mask, _ = lengths_to_encoder_padding_mask( input_lengths, batch_first=True @@ -425,14 +427,14 @@ def forward( if not encoder_padding_mask.any(): encoder_padding_mask = None - return EncoderOut( - encoder_out=encoder_out, - encoder_padding_mask=encoder_padding_mask, - encoder_embedding=None, - encoder_states=states, - src_tokens=None, - src_lengths=None, - ) + return { + "encoder_out": [encoder_out], + "encoder_padding_mask": [encoder_padding_mask], + "encoder_embedding": [], + "encoder_states": [states], + "src_tokens": [], + "src_lengths": [], + } def incremental_encode( self, diff --git a/fairseq/models/speech_to_text/modules/emformer.py b/fairseq/models/speech_to_text/modules/emformer.py index 42b157b766..e026b86847 100644 --- a/fairseq/models/speech_to_text/modules/emformer.py +++ b/fairseq/models/speech_to_text/modules/emformer.py @@ -17,7 +17,6 @@ from fairseq.models import ( FairseqEncoder, ) -from fairseq.models.fairseq_encoder import EncoderOut from fairseq.models.speech_to_text.utils import ( NoOp, lengths_to_padding_mask, @@ -1811,23 +1810,22 @@ def __init__(self, args): ] ) - def forward(self, *args, **kwargs): - encoder_out = super().forward(*args, **kwargs) - (output, encoder_padding_masks, [], all_outputs) = encoder_out.encoder_out + def forward(self, src_tokens, src_lengths): + encoder_out = super().forward(src_tokens, src_lengths) + (output, encoder_padding_masks, [], _) = encoder_out["encoder_out"][0] # This is because that in the original implementation # the output didn't consider the last segment as right context. encoder_padding_masks = encoder_padding_masks[:, : output.size(0)] - # import pdb;pdb.set_trace() - - return EncoderOut( - encoder_out=output, - encoder_padding_mask=encoder_padding_masks, - encoder_embedding=None, - encoder_states=None, - src_tokens=None, - src_lengths=None, - ) + + return { + "encoder_out": [output], + "encoder_padding_mask": [encoder_padding_masks], + "encoder_embedding": [], + "encoder_states": [], + "src_tokens": [], + "src_lengths": [], + } @staticmethod def conv_layer_stride(args): From 0c32e251e29dc6f10755addd37c5f9d963693df9 Mon Sep 17 00:00:00 2001 From: Xutai Ma Date: Wed, 3 Mar 2021 09:59:23 -0800 Subject: [PATCH 55/82] Update Simultaneous Translation doc (#1659) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: # Before submitting - [ ] Was this discussed/approved via a Github issue? (no need for typos, doc improvements) - [ ] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)? - [ ] Did you make sure to update the docs? - [ ] Did you write any new necessary tests? ## What does this PR do? Fixes # (issue). ## PR review Anyone in the community is free to review the PR once the tests have passed. If we didn't discuss your PR in Github issues there's a high chance it will not be merged. ## Did you have fun? Make sure you had fun coding � Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1659 Reviewed By: jmp84, kahne Differential Revision: D26708524 Pulled By: xutaima fbshipit-source-id: 0f34e5e9e3bec2360e098c9c272105c793bfa7b7 --- .../simultaneous_translation/docs/baseline.md | 178 ------------------ .../docs/evaluation.md | 115 ----------- .../docs/simulst_mustc_example.md | 60 +++++- 3 files changed, 54 insertions(+), 299 deletions(-) delete mode 100644 examples/simultaneous_translation/docs/baseline.md delete mode 100644 examples/simultaneous_translation/docs/evaluation.md diff --git a/examples/simultaneous_translation/docs/baseline.md b/examples/simultaneous_translation/docs/baseline.md deleted file mode 100644 index d9bf1a1117..0000000000 --- a/examples/simultaneous_translation/docs/baseline.md +++ /dev/null @@ -1,178 +0,0 @@ -# **Baseline Simultaneous Translation** ---- - -This is an instruction of training and evaluating a *wait-k* simultanoes LSTM model on MUST-C English-Gernam Dataset. - -[STACL: Simultaneous Translation with Implicit Anticipation and Controllable Latency using Prefix-to-Prefix Framework](https://https://www.aclweb.org/anthology/P19-1289/) - - -## **Requirements** -Install fairseq (make sure to use the correct branch): -``` -git clone --branch simulastsharedtask git@github.com:pytorch/fairseq.git -cd fairseq -pip install -e . -``` - -Assuming that fairseq is installed in a directory called `FAIRSEQ`. - -Install SentencePiece. One easy way is to use anaconda: - -``` -conda install -c powerai sentencepiece -``` - -Download the MuST-C data for English-German available at https://ict.fbk.eu/must-c/. -We will assume that the data is downloaded in a directory called `DATA_ROOT`. - - -## **Text-to-text Model** ---- -### Data Preparation -Train a SentencePiece model: -```shell -for lang in en de; do - python $FAIRSEQ/examples/simultaneous_translation/data/train_spm.py \ - --data-path $DATA_ROOT/data \ - --vocab-size 10000 \ - --max-frame 3000 \ - --model-type unigram \ - --lang $lang \ - --out-path . -``` - -Process the data with the SentencePiece model: -```shell -proc_dir=proc -mkdir -p $proc_dir -for split in train dev tst-COMMON tst-HE; do - for lang in en de; do - spm_encode \ - --model unigram-$lang-10000-3000/spm.model \ - < $DATA_ROOT/data/$split/txt/$split.$lang \ - > $proc_dir/$split.spm.$lang - done -done -``` - -Binarize the data: - -```shell -proc_dir=proc -fairseq-preprocess \ - --source-lang en --target-lang de \ - --trainpref $proc_dir/train.spm \ - --validpref $proc_dir/dev.spm \ - --testpref $proc_dir/tst-COMMON.spm \ - --thresholdtgt 0 \ - --thresholdsrc 0 \ - --workers 20 \ - --destdir ./data-bin/mustc_en_de \ -``` - -### Training - - -```shell -mkdir -p checkpoints -CUDA_VISIBLE_DEVICES=1 python $FAIRSEQ/train.py data-bin/mustc_en_de \ - --save-dir checkpoints \ - --arch berard_simul_text_iwslt \ - --simul-type waitk \ - --waitk-lagging 2 \ - --optimizer adam \ - --max-epoch 100 \ - --lr 0.001 \ - --clip-norm 5.0 \ - --batch-size 128 \ - --log-format json \ - --log-interval 10 \ - --criterion cross_entropy_acc \ - --user-dir $FAIRSEQ/examples/simultaneous_translation -``` - -## **Speech-to-text Model** ---- -### Data Preparation -First, segment wav files. -```shell -python $FAIRSEQ/examples/simultaneous_translation/data/segment_wav.py \ - --datapath $DATA_ROOT -``` -Similar to text-to-text model, train a Sentencepiecemodel, but only train on German -```Shell -python $FAIRSEQ/examples/simultaneous_translation/data/train_spm.py \ - --data-path $DATA_ROOT/data \ - --vocab-size 10000 \ - --max-frame 3000 \ - --model-type unigram \ - --lang $lang \ - --out-path . -``` -## Training -```shell -mkdir -p checkpoints -CUDA_VISIBLE_DEVICES=1 python $FAIRSEQ/train.py data-bin/mustc_en_de \ - --save-dir checkpoints \ - --arch berard_simul_text_iwslt \ - --waitk-lagging 2 \ - --waitk-stride 10 \ - --input-feat-per-channel 40 \ - --encoder-hidden-size 512 \ - --output-layer-dim 128 \ - --decoder-num-layers 3 \ - --task speech_translation \ - --user-dir $FAIRSEQ/examples/simultaneous_translation - --optimizer adam \ - --max-epoch 100 \ - --lr 0.001 \ - --clip-norm 5.0 \ - --batch-size 128 \ - --log-format json \ - --log-interval 10 \ - --criterion cross_entropy_acc \ - --user-dir $FAIRSEQ/examples/simultaneous_translation -``` - -## Evaluation ---- -### Evaluation Server -For text translation models, the server is set up as follow give input file and reference file. - -``` shell -python ./eval/server.py \ - --hostname localhost \ - --port 12321 \ - --src-file $DATA_ROOT/data/dev/txt/dev.en \ - --ref-file $DATA_ROOT/data/dev/txt/dev.de -``` -For speech translation models, the input is the data direcrory. -``` shell -python ./eval/server.py \ - --hostname localhost \ - --port 12321 \ - --ref-file $DATA_ROOT \ - --data-type speech -``` - -### Decode and Evaluate with Client -Once the server is set up, run client to evaluate translation quality and latency. -```shell -# TEXT -python $fairseq_dir/examples/simultaneous_translation/evaluate.py \ - data-bin/mustc_en_de \ - --user-dir $FAIRSEQ/examples/simultaneous_translation \ - --src-spm unigram-en-10000-3000/spm.model\ - --tgt-spm unigram-de-10000-3000/spm.model\ - -s en -t de \ - --path checkpoints/checkpoint_best.pt - -# SPEECH -python $fairseq_dir/examples/simultaneous_translation/evaluate.py \ - data-bin/mustc_en_de \ - --user-dir $FAIRSEQ/examples/simultaneous_translation \ - --data-type speech \ - --tgt-spm unigram-de-10000-3000/spm.model\ - -s en -t de \ - --path checkpoints/checkpoint_best.pt -``` diff --git a/examples/simultaneous_translation/docs/evaluation.md b/examples/simultaneous_translation/docs/evaluation.md deleted file mode 100644 index c53407354e..0000000000 --- a/examples/simultaneous_translation/docs/evaluation.md +++ /dev/null @@ -1,115 +0,0 @@ -# Introduction to evaluation interface -The simultaneous translation models from sharedtask participents are evaluated under a server-client protocol. The participents are requisted to plug in their own model API in the protocol, and submit a docker file. - -## Server-Client Protocol -An server-client protocol that will be used in evaluation. For example, when a *wait-k* model (k=3) translate the English sentence "Alice and Bob are good friends" to Genman sentence "Alice und Bob sind gute Freunde." , the evaluation process is shown as following figure. - -While every time client needs to read a new state (word or speech utterence), a "GET" request is supposed to sent over to server. Whenever a new token is generated, a "SEND" request with the word predicted (untokenized word) will be sent to server immediately. The server can hence calculate both latency and BLEU score of the sentence. - -### Server -The server code is provided and can be set up directly locally for development purpose. For example, to evaluate a text simultaneous test set, - -```shell - - python fairseq/examples/simultaneous_translation/eval/server.py \ - --hostname local_host \ - --port 1234 \ - --src-file SRC_FILE \ - --ref-file REF_FILE \ - --data-type text \ -``` -The state that server sent to client is has the following format -```json -{ - 'sent_id': Int, - 'segment_id': Int, - 'segment': String -} -``` - -### Client -The client will handle the evaluation process mentioned above. It should be out-of-box as well. The client's protocol is as following table - -|Action|Content| -|:---:|:---:| -|Request new word / utterence| ```{key: "Get", value: None}```| -|Predict word "W"| ```{key: "SEND", value: "W"}```| - - - -The core of the client module is the agent, which needs to be modified to different models accordingly. The abstract class of agent is as follow, the evaluation process happens in the `decode()` function. -```python -class Agent(object): - "an agent needs to follow this pattern" - def __init__(self, *args, **kwargs): - ... - - def init_states(self): - # Initializing states - ... - - def update_states(self, states, new_state): - # Update states with given new state from server - # TODO (describe the states) - ... - - def finish_eval(self, states, new_state): - # Check if evaluation is finished - ... - - def policy(self, state: list) -> dict: - # Provide a action given current states - # The action can only be either - # {key: "GET", value: NONE} - # or - # {key: "SEND", value: W} - ... - - def reset(self): - # Reset agent - ... - - def decode(self, session): - - states = self.init_states() - self.reset() - - # Evaluataion protocol happens here - while True: - # Get action from the current states according to self.policy() - action = self.policy(states) - - if action['key'] == GET: - # Read a new state from server - new_state = session.get_src() - states = self.update_states(states, new_state) - - if self.finish_eval(states, new_state): - # End of document - break - - elif action['key'] == SEND: - # Send a new prediction to server - session.send_hypo(action['value']) - - # Clean the history, wait for next sentence - if action['value'] == DEFAULT_EOS: - states = self.init_states() - self.reset() - else: - raise NotImplementedError - - -``` -Here an implementation of agent of text [*wait-k* model](somelink). Notice that the tokenization is not considered. - -## Quality -The quality is measured by detokenized BLEU. So make sure that the predicted words sent to server are detokenized. An implementation is can be find [here](some link) - -## Latency -The latency metrics are -* Average Proportion -* Average Lagging -* Differentiable Average Lagging -Again Thery will also be evaluated on detokenized text. - diff --git a/examples/speech_to_text/docs/simulst_mustc_example.md b/examples/speech_to_text/docs/simulst_mustc_example.md index 5dea0d8475..0144fcb766 100644 --- a/examples/speech_to_text/docs/simulst_mustc_example.md +++ b/examples/speech_to_text/docs/simulst_mustc_example.md @@ -1,13 +1,46 @@ # Simultaneous Speech Translation (SimulST) on MuST-C +This is an instruction of training and evaluating a transformer *wait-k* simultaneous model on MUST-C English-Germen Dataset, from [SimulMT to SimulST: Adapting Simultaneous Text Translation to End-to-End Simultaneous Speech Translation](https://www.aclweb.org/anthology/2020.aacl-main.58.pdf). + [MuST-C](https://www.aclweb.org/anthology/N19-1202) is multilingual speech-to-text translation corpus with 8-language translations on English TED talks. -## Data Preparation & ASR -Please follow the steps in offline [speech-to-text](../mustc_example.md) translation for data preparation and ASR pretraining. +## Data Preparation +[Download](https://ict.fbk.eu/must-c) and unpack MuST-C data to a path +`${MUSTC_ROOT}/en-${TARGET_LANG_ID}`, then preprocess it with +```bash +# Additional Python packages for S2T data processing/model training +pip install pandas torchaudio sentencepiece + +# Generate TSV manifests, features, vocabulary, +# global cepstral and mean estimation, +# and configuration for each language +python examples/speech_to_text/prep_mustc_data.py \ + --data-root ${MUSTC_ROOT} --task asr \ + --vocab-type unigram --vocab-size 10000 \ + --cmvn-type global +python examples/speech_to_text/prep_mustc_data.py \ + --data-root ${MUSTC_ROOT} --task st \ + --vocab-type unigram --vocab-size 10000 + --cmvn-type global +``` + +## ASR Pretraining +We just need a pretrained offline ASR model +``` +fairseq-train ${MUSTC_ROOT}/en-de \ + --config-yaml config_asr.yaml --train-subset train_asr --valid-subset dev_asr \ + --save-dir ${ASR_SAVE_DIR} --num-workers 4 --max-tokens 40000 --max-update 100000 \ + --task speech_to_text --criterion label_smoothed_cross_entropy --report-accuracy \ + --arch convtransformer_espnet --optimizer adam --lr 0.0005 --lr-scheduler inverse_sqrt \ + --warmup-updates 10000 --clip-norm 10.0 --seed 1 --update-freq 8 +``` -## Training +## Simultaneous Speech Translation Training -#### Wait-K(K=3) with fixed pre-decision module +### Wait-K with fixed pre-decision module +Fixed pre-decision indicates that the model operate simultaneous policy on the boundaries of fixed chunks. +Here is a example of fixed pre-decision ratio 7 (the simultaneous decision is made every 7 encoder states) and +a wait-3 policy model ``` fairseq-train ${MUSTC_ROOT}/en-de \ --config-yaml config_st.yaml --train-subset train_st --valid-subset dev_st \ @@ -21,8 +54,9 @@ Please follow the steps in offline [speech-to-text](../mustc_example.md) transla --simul-type waitk_fixed_pre_decision \ --waitk-lagging 3 \ --fixed-pre-decision-ratio 7 + ``` -#### Monotonic multihead attention with fixed pre-decision module +### Monotonic multihead attention with fixed pre-decision module ``` fairseq-train ${MUSTC_ROOT}/en-de \ --config-yaml config_st.yaml --train-subset train_st --valid-subset dev_st \ @@ -39,9 +73,13 @@ Please follow the steps in offline [speech-to-text](../mustc_example.md) transla ``` ## Inference & Evaluation [SimulEval](https://github.com/facebookresearch/SimulEval) is used for evaluation. +The source file is a list of paths of audio files, +while target file is the corresponding translations. ``` +pip install simuleval + simuleval \ - --agent ${FAIRSEQ}/examples/speech_to_text/simultaneous_translation/agents/fairseq_simul_st_agent.py + --agent examples/speech_to_text/simultaneous_translation/agents/fairseq_simul_st_agent.py --src-file ${SRC_LIST_OF_AUDIO} --tgt-file ${TGT_FILE} --data-bin ${MUSTC_ROOT}/en-de \ @@ -50,3 +88,13 @@ simuleval \ --tgt-splitter-path ${MUSTC_ROOT}/en-de/spm.model \ --scores ``` + +A pretrained checkpoint can be downloaded from [here](https://dl.fbaipublicfiles.com/simultaneous_translation/convtransformer_wait5_pre7), which is a wait-5 model with a pre-decision of 280 ms. The databin (containing dictionary, gcmvn file and sentencepiece model) can be found [here](https://dl.fbaipublicfiles.com/simultaneous_translation/must_c_v1.0_en_de_databin). + +The quality is measured by detokenized BLEU. So make sure that the predicted words sent to the server are detokenized. + +The latency metrics are +* Average Proportion +* Average Lagging +* Differentiable Average Lagging +Again they will also be evaluated on detokenized text. From 7d2394b56f1cbdcdede9c7a8cf6de1df022e0a17 Mon Sep 17 00:00:00 2001 From: Eric Lou Date: Wed, 3 Mar 2021 10:48:42 -0800 Subject: [PATCH 56/82] ioPath async - Fairseq unittests (#1669) Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1669 Unit tests for async writes integration done in D26467815 (https://github.com/pytorch/fairseq/commit/3100d0b8e5bb5e61b4d73b9c058389aa2c06784a). Ongoing performance tests: https://fb.quip.com/kjM7Atb1kKbO Reviewed By: myleott Differential Revision: D26732660 fbshipit-source-id: faf8cac67b9167af4195358c1a2592804c13562c --- fairseq/file_io.py | 2 +- tests/test_checkpoint_utils.py | 15 +++++++++++++++ tests/test_file_io.py | 15 +++++++++++++++ 3 files changed, 31 insertions(+), 1 deletion(-) diff --git a/fairseq/file_io.py b/fairseq/file_io.py index 731fef3570..9a78ab505d 100644 --- a/fairseq/file_io.py +++ b/fairseq/file_io.py @@ -170,7 +170,7 @@ def opena( if not IOPathPathManager: logging.info("ioPath is initializing PathManager.") try: - from iopath import PathManager + from iopath.common.file_io import PathManager IOPathPathManager = PathManager() except Exception: logging.exception("Failed to initialize ioPath PathManager object.") diff --git a/tests/test_checkpoint_utils.py b/tests/test_checkpoint_utils.py index 617a5f7c84..3278de6b9f 100644 --- a/tests/test_checkpoint_utils.py +++ b/tests/test_checkpoint_utils.py @@ -9,8 +9,10 @@ import tempfile import unittest from io import StringIO +from unittest.mock import patch from fairseq import checkpoint_utils +from omegaconf import OmegaConf from tests.utils import ( create_dummy_data, @@ -87,6 +89,19 @@ def test_prune_state_dict(self): self.assertEqual(len(ensemble[0].encoder.layers), 2) self.assertEqual(len(ensemble[0].decoder.layers), 1) + def test_torch_persistent_save_async(self): + cfg = OmegaConf.create() + cfg.dataset = OmegaConf.create() + cfg.dataset.write_checkpoints_asynchronously = True + state_dict = {} + filename = "async_checkpoint.pt" + + with patch(f"{checkpoint_utils.__name__}.PathManager.opena") as mock_opena: + with patch(f"{checkpoint_utils.__name__}._torch_persistent_save") as mock_save: + checkpoint_utils.torch_persistent_save(cfg.dataset, state_dict, filename) + mock_opena.assert_called_with(filename, "wb") + mock_save.assert_called() + if __name__ == "__main__": unittest.main() diff --git a/tests/test_file_io.py b/tests/test_file_io.py index aef5b80d18..8ebbba4a2e 100644 --- a/tests/test_file_io.py +++ b/tests/test_file_io.py @@ -45,3 +45,18 @@ def test_file_io_oss(self): with PathManager.open(os.path.join(self._tmpdir, "test.txt"), "r") as f: s = f.read() self.assertEqual(s, self._tmpfile_contents) + + def test_file_io_async(self): + # ioPath `PathManager` is initialized after the first `opena` call. + try: + from fairseq.file_io import IOPathPathManager, PathManager + + self.assertIsNone(IOPathPathManager) + _asyncfile = os.path.join(self._tmpdir, "async.txt") + f = PathManager.opena(_asyncfile, "wb") + f.close() + + from fairseq.file_io import IOPathPathManager + self.assertIsNotNone(IOPathPathManager) + finally: + self.assertTrue(PathManager.async_close()) From 1fed7a8426e8c548196add0d65d77857ab224705 Mon Sep 17 00:00:00 2001 From: Alex Xiao Date: Wed, 3 Mar 2021 19:29:55 -0800 Subject: [PATCH 57/82] add unit test for multi_corpus_dataset Reviewed By: vimalmanohar Differential Revision: D26220694 fbshipit-source-id: ed13f8527a1b203e1a9d004fa8a86e1ad6423d60 --- tests/test_multi_corpus_dataset.py | 69 ++++++++++++++++++++++++++++++ 1 file changed, 69 insertions(+) create mode 100644 tests/test_multi_corpus_dataset.py diff --git a/tests/test_multi_corpus_dataset.py b/tests/test_multi_corpus_dataset.py new file mode 100644 index 0000000000..a1fafe489b --- /dev/null +++ b/tests/test_multi_corpus_dataset.py @@ -0,0 +1,69 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import unittest +from collections import OrderedDict + +import torch +from fairseq.data import LanguagePairDataset, TokenBlockDataset +from fairseq.data.multi_corpus_dataset import MultiCorpusDataset +from tests.test_train import mock_dict + + +class TestMultiCorpusDataset(unittest.TestCase): + def setUp(self): + d = mock_dict() + tokens_1 = torch.LongTensor([i for i in range(1, 5000, 2)]).view(1, -1) + tokens_ds1 = TokenBlockDataset( + tokens_1, + sizes=[tokens_1.size(-1)], + block_size=1, + pad=0, + eos=1, + include_targets=False, + ) + self.dataset_1 = LanguagePairDataset( + tokens_ds1, tokens_ds1.sizes, d, shuffle=False + ) + tokens_2 = torch.LongTensor([i for i in range(2, 5000, 2)]).view(1, -1) + tokens_ds2 = TokenBlockDataset( + tokens_2, + sizes=[tokens_2.size(-1)], + block_size=1, + pad=0, + eos=1, + include_targets=False, + ) + self.dataset_2 = LanguagePairDataset( + tokens_ds2, tokens_ds2.sizes, d, shuffle=False + ) + + def _test_sample_helper( + self, + distribution, + ): + m = MultiCorpusDataset( + OrderedDict({0: self.dataset_1, 1: self.dataset_2}), + distribution=distribution, + seed=0, + sort_indices=True, + ) + m.set_epoch(1) + indices = m.ordered_indices() + count_sample_from_first_dataset = 0 + for i in indices: + if m[i]["source"].item() % 2 == 1: + count_sample_from_first_dataset += 1 + sample_from_first_ds_percentage = ( + 1.0 * count_sample_from_first_dataset / len(indices) + ) + self.assertLess( + abs(sample_from_first_ds_percentage - distribution[0]), + 0.01, + ) + + def test_multi_corpus_dataset(self): + for distribution in [[0.5, 0.5], [0.1, 0.9], [0.9, 0.1]]: + self._test_sample_helper(distribution=distribution) From fc2840de58b06f381626332153203fb32588c23d Mon Sep 17 00:00:00 2001 From: Alex Xiao Date: Wed, 3 Mar 2021 19:29:55 -0800 Subject: [PATCH 58/82] optimize sampling process of multi_corpus_dataset Summary: The sampling process in multi_corpus_dataset is very inefficient. Turns out we can signficantly optimize it by sampling in batches rather than one by one. this allows: 1. fast local development and iteration with corpus sampling, as the turnaround time was long before 2. makes it take less time for our jobs can start training, enabling earlier signal if for example there is a configuration issue Reviewed By: zhengwy888 Differential Revision: D26187821 fbshipit-source-id: b4f7f6b7c187b3785499308226e2af671a6c354f --- fairseq/data/multi_corpus_dataset.py | 85 +++++++++++++++++----------- tests/test_multi_corpus_dataset.py | 14 ++++- 2 files changed, 64 insertions(+), 35 deletions(-) diff --git a/fairseq/data/multi_corpus_dataset.py b/fairseq/data/multi_corpus_dataset.py index 6563713489..00e464ed31 100644 --- a/fairseq/data/multi_corpus_dataset.py +++ b/fairseq/data/multi_corpus_dataset.py @@ -4,6 +4,7 @@ # LICENSE file in the root directory of this source tree. import logging +import time from collections import OrderedDict from typing import Dict, List @@ -12,7 +13,6 @@ from . import FairseqDataset - logger = logging.getLogger(__name__) @@ -49,6 +49,7 @@ def __init__( super().__init__() assert isinstance(datasets, OrderedDict) assert len(datasets) == len(distribution) + assert sum(distribution) == 1 self.datasets = datasets self.distribution = distribution self.seed = seed @@ -69,43 +70,61 @@ def __init__( self.total_num_instances += len(dataset) def ordered_indices(self): + start = time.time() with data_utils.numpy_seed(self.seed, self.epoch): - # Used to store the order of indices of each dataset to use - indices = [ - np.random.permutation(len(dataset)) - for dataset in self.datasets.values() - ] - # Keep track of which samples we've used for each dataset - counters = [0 for _ in self.datasets] - - sampled_indices = [ - self._sample(indices, counters) for _ in range(self.total_num_instances) - ] + sampled_indices = [] + num_selected_instances = 0 + + # For each dataset i, sample self.distribution[i] * self.total_num_instances + for i, key in enumerate(self.datasets): + + if i < len(self.datasets) - 1: + num_instances = int(self.distribution[i] * self.total_num_instances) + high = self.dataset_offsets[i + 1] + else: + num_instances = self.total_num_instances - num_selected_instances + high = self.total_num_instances + + logger.info(f"sampling {num_instances} from {key} dataset") + num_selected_instances += num_instances + + # First, add k copies of the dataset where k = num_instances // len(dataset). + # This ensures an equal distribution of the data points as much as possible. + # For the remaining entries randomly sample them + dataset_size = len(self.datasets[key]) + num_copies = num_instances // dataset_size + dataset_indices = ( + np.random.permutation(high - self.dataset_offsets[i]) + + self.dataset_offsets[i] + )[: num_instances - num_copies * dataset_size] + if num_copies > 0: + sampled_indices += list( + np.concatenate( + ( + np.repeat( + np.arange(self.dataset_offsets[i], high), num_copies + ), + dataset_indices, + ) + ) + ) + else: + sampled_indices += list(dataset_indices) + + assert ( + len(sampled_indices) == self.total_num_instances + ), f"{len(sampled_indices)} vs {self.total_num_instances}" + + np.random.shuffle(sampled_indices) if self.sort_indices: sampled_indices.sort(key=lambda i: self.num_tokens(i)) - return np.array(sampled_indices, dtype=np.int64) - - def _sample(self, indices, counters): - # First pick dataset - dataset_idx = np.random.choice(len(self.distribution), p=self.distribution) - - # Then get dataset internal index - idx = indices[dataset_idx][counters[dataset_idx]] - - # Convert to multi-datasets index - idx += self.dataset_offsets[dataset_idx] - - counters[dataset_idx] += 1 - - # Reset if we reach end - if counters[dataset_idx] == len(self.dataset_list[dataset_idx]): - counters[dataset_idx] = 0 - indices[dataset_idx] = np.random.permutation( - len(self.dataset_list[dataset_idx]) + logger.info( + "multi_corpus_dataset ordered_indices took {}s".format( + time.time() - start + ) ) - - return idx + return np.array(sampled_indices, dtype=np.int64) def _map_index(self, index: int): """ diff --git a/tests/test_multi_corpus_dataset.py b/tests/test_multi_corpus_dataset.py index a1fafe489b..5a79f4b680 100644 --- a/tests/test_multi_corpus_dataset.py +++ b/tests/test_multi_corpus_dataset.py @@ -27,7 +27,7 @@ def setUp(self): self.dataset_1 = LanguagePairDataset( tokens_ds1, tokens_ds1.sizes, d, shuffle=False ) - tokens_2 = torch.LongTensor([i for i in range(2, 5000, 2)]).view(1, -1) + tokens_2 = torch.LongTensor([i for i in range(0, 5000, 2)]).view(1, -1) tokens_ds2 = TokenBlockDataset( tokens_2, sizes=[tokens_2.size(-1)], @@ -53,9 +53,13 @@ def _test_sample_helper( m.set_epoch(1) indices = m.ordered_indices() count_sample_from_first_dataset = 0 + items = set() for i in indices: - if m[i]["source"].item() % 2 == 1: + item = m[i]["source"].item() + if item % 2 == 1: count_sample_from_first_dataset += 1 + + items.add(item) sample_from_first_ds_percentage = ( 1.0 * count_sample_from_first_dataset / len(indices) ) @@ -63,6 +67,12 @@ def _test_sample_helper( abs(sample_from_first_ds_percentage - distribution[0]), 0.01, ) + self.assertEqual( + len(items), + int(min(len(self.dataset_1), len(indices) * distribution[0]) + + min(len(self.dataset_1), len(indices) * distribution[1])) + ) + print(distribution) def test_multi_corpus_dataset(self): for distribution in [[0.5, 0.5], [0.1, 0.9], [0.9, 0.1]]: From f6d60e2fee9fe8982e3c9de1e6bb77680978e749 Mon Sep 17 00:00:00 2001 From: alexeib Date: Wed, 3 Mar 2021 21:15:01 -0800 Subject: [PATCH 59/82] minor fixes and improvements (#1671) Summary: there are a few changes here: - convert config persisted in checkpoints into a plain dict when saving and back to omegaconf config when loading: this helps avoid compatibility issues between different versions of python, omegaconf, etc - update checkpoints that have old print_alignment saved - add lr_float to composite optimizer to enable sweeping on lr with auto sweepers like ax - fixing some edge cases for config loading Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1671 Reviewed By: myleott Differential Revision: D26791583 Pulled By: alexeib fbshipit-source-id: 124dec74932052925c43b6a93130f4428803cb46 --- fairseq/checkpoint_utils.py | 54 +++++++++++++++++++++++++++++++------ fairseq/dataclass/utils.py | 16 ++++++----- fairseq/optim/composite.py | 13 ++++++--- 3 files changed, 65 insertions(+), 18 deletions(-) diff --git a/fairseq/checkpoint_utils.py b/fairseq/checkpoint_utils.py index d6618fbb62..97f22041bc 100644 --- a/fairseq/checkpoint_utils.py +++ b/fairseq/checkpoint_utils.py @@ -21,7 +21,7 @@ ) from fairseq.file_io import PathManager from fairseq.models import FairseqDecoder, FairseqEncoder -from omegaconf import DictConfig, open_dict +from omegaconf import Container, DictConfig, open_dict, OmegaConf logger = logging.getLogger(__name__) @@ -275,8 +275,22 @@ def load_checkpoint_to_cpu(path, arg_overrides=None, load_on_all_ranks=False): for arg_name, arg_val in arg_overrides.items(): setattr(args, arg_name, arg_val) - if "cfg" in state and state["cfg"] is not None and arg_overrides is not None: - overwrite_args_by_name(state["cfg"], arg_overrides) + if "cfg" in state and state["cfg"] is not None: + + # hack to be able to set Namespace in dict config. this should be removed when we update to newer + # omegaconf version that supports object flags, or when we migrate all existing models + from omegaconf import _utils + + old_primitive = _utils.is_primitive_type + _utils.is_primitive_type = lambda _: True + + state["cfg"] = OmegaConf.create(state["cfg"]) + + _utils.is_primitive_type = old_primitive + OmegaConf.set_struct(state["cfg"], True) + + if arg_overrides is not None: + overwrite_args_by_name(state["cfg"], arg_overrides) state = _upgrade_state_dict(state) return state @@ -440,7 +454,7 @@ def save_state( if extra_state is None: extra_state = {} state_dict = { - "cfg": cfg, + "cfg": OmegaConf.to_container(cfg) if OmegaConf.is_config(cfg) else cfg, "args": kwargs.get("args", None), "model": model_state_dict or {}, "optimizer_history": optim_history @@ -453,7 +467,7 @@ def save_state( } ], "extra_state": extra_state, - "task_state": task.state_dict() if task is not None else {} + "task_state": task.state_dict() if task is not None else {}, } if utils.has_parameters(criterion): state_dict["criterion"] = criterion.state_dict() @@ -568,15 +582,39 @@ def _upgrade_state_dict(state): if hasattr(state["args"], "lr") and isinstance(state["args"].lr, float): state["args"].lr = [state["args"].lr] # convert task data arg to a string instead of List[string] - if hasattr(state["args"], "data") and isinstance(state["args"].data, list) and len(state["args"].data) > 0: + if ( + hasattr(state["args"], "data") + and isinstance(state["args"].data, list) + and len(state["args"].data) > 0 + ): state["args"].data = state["args"].data[0] state["cfg"] = convert_namespace_to_omegaconf(state["args"]) if "cfg" in state and state["cfg"] is not None: - with open_dict(state["cfg"]): + cfg = state["cfg"] + with open_dict(cfg): # any upgrades for Hydra-based configs - pass + if ( + "task" in cfg + and "eval_wer_config" in cfg.task + and isinstance(cfg.task.eval_wer_config.print_alignment, bool) + ): + cfg.task.eval_wer_config.print_alignment = "hard" + if "generation" in cfg and isinstance(cfg.generation.print_alignment, bool): + cfg.generation.print_alignment = "hard" + if ( + "model" in cfg + and "w2v_args" in cfg.model + and cfg.model.w2v_args is not None + and ( + hasattr(cfg.model.w2v_args, "task") or "task" in cfg.model.w2v_args + ) + and isinstance( + cfg.model.w2v_args.task.eval_wer_config.print_alignment, bool + ) + ): + cfg.model.w2v_args.task.eval_wer_config.print_alignment = "hard" return state diff --git a/fairseq/dataclass/utils.py b/fairseq/dataclass/utils.py index a4d4a412dd..27c9006fdb 100644 --- a/fairseq/dataclass/utils.py +++ b/fairseq/dataclass/utils.py @@ -43,7 +43,9 @@ def interpret_dc_type(field_type): return str typestring = str(field_type) - if re.match(r"(typing.|^)Union\[(.*), NoneType\]$", typestring) or typestring.startswith("typing.Optional"): + if re.match( + r"(typing.|^)Union\[(.*), NoneType\]$", typestring + ) or typestring.startswith("typing.Optional"): return field_type.__args__[0] return field_type @@ -235,15 +237,17 @@ def get_default(f): and not (isinstance(val, str) and val.startswith("${")) ): # if type is int but val is float, then we will crash later - try to convert here - if hasattr(v.type, '__args__'): + if hasattr(v.type, "__args__"): t_args = v.type.__args__ - if len(t_args) == 1: + if len(t_args) == 1 and (t_args[0] is float or t_args[0] is int): val = list(map(t_args[0], val)) - elif val is not None and (field_type is int or field_type is bool or field_type is float): + elif val is not None and ( + field_type is int or field_type is bool or field_type is float + ): try: val = field_type(val) except: - pass # ignore errors here, they are often from interpolation args + pass # ignore errors here, they are often from interpolation args if val is None: overrides.append("{}.{}=null".format(sub_node, k)) @@ -430,7 +434,7 @@ def overwrite_args_by_name(cfg: DictConfig, overrides: Dict[str, any]): if k in cfg and isinstance(cfg[k], DictConfig): if k in overrides and isinstance(overrides[k], dict): for ok, ov in overrides[k].items(): - if isinstance(ov, dict): + if isinstance(ov, dict) and cfg[k][ok] is not None: overwrite_args_by_name(cfg[k][ok], ov) else: cfg[k][ok] = ov diff --git a/fairseq/optim/composite.py b/fairseq/optim/composite.py index 1a581bc010..a5366d6243 100644 --- a/fairseq/optim/composite.py +++ b/fairseq/optim/composite.py @@ -22,12 +22,13 @@ class OptimizerAndSchedulerConfig(FairseqDataclass): optimizer: Any = None lr_scheduler: Optional[Any] = None - lr: List[float] = II("optimization.lr") + lr: List = II("optimization.lr") + lr_float: Optional[float] = None # this makes it easier to sweep on learning rate with auto sweepers @dataclass class CompositeOptimizerConfig(FairseqDataclass): - groups: Dict[str, OptimizerAndSchedulerConfig] = field( + groups: Dict[str, Any] = field( default_factory=lambda: {}, metadata={ "help": "optimizer name -> optimizer OptimizerAndSchedulerConfig. " @@ -64,8 +65,12 @@ def __init__(self, cfg: CompositeOptimizerConfig, params): for group, group_params in groupped_params.items(): group_cfg = cfg.groups[group] with open_dict(group_cfg): - group_cfg.optimizer.lr = group_cfg.lr - group_cfg.lr_scheduler.lr = group_cfg.lr + if group_cfg.lr_float is not None: + group_cfg.optimizer.lr = [group_cfg.lr_float] + group_cfg.lr_scheduler.lr = [group_cfg.lr_float] + else: + group_cfg.optimizer.lr = group_cfg.lr + group_cfg.lr_scheduler.lr = group_cfg.lr self.optimizers[group] = _build_optimizer(group_cfg.optimizer, group_params) if group_cfg.lr_scheduler is not None: self.lr_schedulers[group] = build_lr_scheduler( From f1c595beb8acd2a6dc8c9fa9f7fb60ca23c61899 Mon Sep 17 00:00:00 2001 From: Kaushik Rangadurai Date: Thu, 4 Mar 2021 11:48:27 -0800 Subject: [PATCH 60/82] Ability to pass attn_mask to TransformerSentenceEncoder Summary: Provide an ability to pass attn_mask to TransformerSentenceEncoder. The default is None and hence this is backwards compatible. The attention mask can either be a 2D tensor (of shape [tgt_seq_len, src_seq_len]) or a 3D tensor of shape (bcz * num_heads, tgt_seq_len, src_seq_len). In case of self attention, tgt_seq_len = src_seq_len. Reviewed By: myleott Differential Revision: D26790767 fbshipit-source-id: 937d6c6cf08790c7d43d33fda97a30425f31ea06 --- fairseq/modules/transformer_sentence_encoder.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/fairseq/modules/transformer_sentence_encoder.py b/fairseq/modules/transformer_sentence_encoder.py index 6e9c32f467..a7fb198779 100644 --- a/fairseq/modules/transformer_sentence_encoder.py +++ b/fairseq/modules/transformer_sentence_encoder.py @@ -226,6 +226,7 @@ def forward( last_state_only: bool = False, positions: Optional[torch.Tensor] = None, token_embeddings: Optional[torch.Tensor] = None, + attn_mask: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: is_tpu = tokens.device.type == "xla" @@ -268,7 +269,7 @@ def forward( inner_states.append(x) for layer in self.layers: - x, _ = layer(x, self_attn_padding_mask=padding_mask) + x, _ = layer(x, self_attn_padding_mask=padding_mask, self_attn_mask=attn_mask) if not last_state_only: inner_states.append(x) From 6d23cc7e7c32d1a6aa1d2d4a4c94abe50c980126 Mon Sep 17 00:00:00 2001 From: Myle Ott Date: Thu, 4 Mar 2021 13:31:02 -0800 Subject: [PATCH 61/82] Move checkpoint state_dict creation into Trainer (#1666) Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1666 Context: the checkpoint saving call stack has become a bit convoluted: ``` train.py + checkpoint_utils.save_checkpoint + trainer.save_checkpoint + checkpoint_utils.save_state + checkpoint_utils.torch_persistent_save ``` This diff slightly simplifies the checkpoint saving logic by exposing a `state_dict` method inside the Trainer. This simplifies the call stack to: ``` train.py + checkpoint_utils.save_checkpoint + trainer.save_checkpoint + checkpoint_utils.torch_persistent_save ``` This new structure is important for the FullyShardedDataParallel diff (next diff in the stack), since it enables the Trainer to save multiple checkpoints for the different optimizer state shards. Test Plan: - unit tests - trained WMT En-De models; confirmed checkpoints save/load properly, resuming from a checkpoint gives identical results - `buck test fblearner/flow/projects/langtech/translation:tests` (2 failures are in trunk too): https://www.internalfb.com/intern/testinfra/testconsole/testrun/2533274840914654/ Reviewed By: zhengwy888 Differential Revision: D26771146 Pulled By: myleott fbshipit-source-id: 10f91979cd42205c1d8abcaa9ab56f63eba31e93 --- fairseq/checkpoint_utils.py | 71 ++++------------------------------ fairseq/dataclass/configs.py | 1 - fairseq/trainer.py | 67 +++++++++++++++++++++++++------- tests/test_checkpoint_utils.py | 7 ++-- tests/test_train.py | 1 + 5 files changed, 64 insertions(+), 83 deletions(-) diff --git a/fairseq/checkpoint_utils.py b/fairseq/checkpoint_utils.py index 97f22041bc..5a98dad2aa 100644 --- a/fairseq/checkpoint_utils.py +++ b/fairseq/checkpoint_utils.py @@ -31,7 +31,7 @@ def save_checkpoint(cfg: CheckpointConfig, trainer, epoch_itr, val_loss): from fairseq import meters # only one worker should attempt to create the required dir - if cfg.distributed_rank == 0: + if trainer.data_parallel_rank == 0: os.makedirs(cfg.save_dir, exist_ok=True) prev_best = getattr(save_checkpoint, "best", val_loss) @@ -44,7 +44,7 @@ def save_checkpoint(cfg: CheckpointConfig, trainer, epoch_itr, val_loss): trainer.consolidate_optimizer() - if not trainer.is_data_parallel_master: + if not trainer.should_save_checkpoint_on_current_rank: return write_timer = meters.StopwatchMeter() @@ -59,7 +59,7 @@ def save_checkpoint(cfg: CheckpointConfig, trainer, epoch_itr, val_loss): def is_better(a, b): return a >= b if cfg.maximize_best_checkpoint_metric else a <= b - suffix = cfg.checkpoint_suffix or "" + suffix = trainer.checkpoint_suffix checkpoint_conds = collections.OrderedDict() checkpoint_conds["checkpoint{}{}.pt".format(epoch, suffix)] = ( end_of_epoch and not cfg.no_epoch_checkpoints and epoch % cfg.save_interval == 0 @@ -165,7 +165,7 @@ def load_checkpoint(cfg: CheckpointConfig, trainer, **passthrough_args): " or reset_lr_scheduler or reset_meters or reset_dataloader" ) - suffix = cfg.checkpoint_suffix + suffix = trainer.checkpoint_suffix if ( cfg.restore_file == "checkpoint_last.pt" ): # default value of restore_file is 'checkpoint_last.pt' @@ -190,7 +190,7 @@ def load_checkpoint(cfg: CheckpointConfig, trainer, **passthrough_args): raise ValueError( f"--funetune-from-model {cfg.finetune_from_model} does not exist" ) - elif cfg.model_parallel_size > 1: + elif suffix is not None: checkpoint_path = cfg.restore_file.replace(".pt", suffix + ".pt") else: checkpoint_path = cfg.restore_file @@ -405,8 +405,8 @@ def checkpoint_paths(path, pattern=r"checkpoint(\d+)\.pt"): return [os.path.join(path, x[1]) for x in sorted(entries, reverse=True)] -def torch_persistent_save(cfg: CheckpointConfig, obj, filename): - if cfg.write_checkpoints_asynchronously: +def torch_persistent_save(obj, filename, async_write: bool = False): + if async_write: with PathManager.opena(filename, "wb") as f: _torch_persistent_save(obj, f) else: @@ -434,61 +434,6 @@ def _torch_persistent_save(obj, f): logger.error(traceback.format_exc()) -def save_state( - filename, - cfg: FairseqConfig, - model_state_dict, - criterion, - optimizer, - lr_scheduler, - num_updates, - optim_history=None, - extra_state=None, - task=None, - **kwargs, -): - from fairseq import utils - - if optim_history is None: - optim_history = [] - if extra_state is None: - extra_state = {} - state_dict = { - "cfg": OmegaConf.to_container(cfg) if OmegaConf.is_config(cfg) else cfg, - "args": kwargs.get("args", None), - "model": model_state_dict or {}, - "optimizer_history": optim_history - + [ - { - "criterion_name": criterion.__class__.__name__, - "optimizer_name": optimizer.__class__.__name__, - "lr_scheduler_state": lr_scheduler.state_dict(), - "num_updates": num_updates, - } - ], - "extra_state": extra_state, - "task_state": task.state_dict() if task is not None else {}, - } - if utils.has_parameters(criterion): - state_dict["criterion"] = criterion.state_dict() - - if cfg is None: - cfg = state_dict["args"] - assert cfg is not None, "must provide cfg or args" - - if isinstance(cfg, DictConfig): - no_save_optimizer_state = cfg.checkpoint.no_save_optimizer_state - else: - no_save_optimizer_state = cfg.no_save_optimizer_state - if not no_save_optimizer_state: - state_dict["last_optimizer_state"] = optimizer.state_dict() - - # keep everything on CPU - state_dict = utils.move_to_cpu(state_dict) - - torch_persistent_save(cfg.checkpoint, state_dict, filename) - - def _upgrade_state_dict(state): """Helper for upgrading old model checkpoints.""" from fairseq import models, registry, tasks @@ -529,7 +474,7 @@ def _upgrade_state_dict(state): if "num_updates" not in state["optimizer_history"][-1]: state["optimizer_history"][-1]["num_updates"] = 0 # old model checkpoints may not have separate source/target positions - if hasattr(state["args"], "max_positions") and not hasattr( + if "args" in state and hasattr(state["args"], "max_positions") and not hasattr( state["args"], "max_source_positions" ): state["args"].max_source_positions = state["args"].max_positions diff --git a/fairseq/dataclass/configs.py b/fairseq/dataclass/configs.py index 39355b1caf..4d3c60bfd6 100644 --- a/fairseq/dataclass/configs.py +++ b/fairseq/dataclass/configs.py @@ -618,7 +618,6 @@ class CheckpointConfig(FairseqDataclass): }, ) model_parallel_size: int = II("common.model_parallel_size") - distributed_rank: int = II("distributed_training.distributed_rank") @dataclass diff --git a/fairseq/trainer.py b/fairseq/trainer.py index 680a7ee953..45d9591d7c 100644 --- a/fairseq/trainer.py +++ b/fairseq/trainer.py @@ -25,6 +25,8 @@ from fairseq.nan_detector import NanDetector from fairseq.optim import lr_scheduler +from omegaconf import OmegaConf + logger = logging.getLogger(__name__) @@ -171,6 +173,16 @@ def use_distributed_wrapper(self) -> bool: and not self.cfg.optimization.use_bmuf ) + @property + def should_save_checkpoint_on_current_rank(self) -> bool: + """Indicates whether to save checkpoints on the current DDP rank.""" + return self.is_data_parallel_master + + @property + def checkpoint_suffix(self) -> str: + """Suffix to add to the checkpoint file name.""" + return self.cfg.checkpoint.checkpoint_suffix or "" + @property def criterion(self): if self._wrapped_criterion is None: @@ -274,25 +286,50 @@ def consolidate_optimizer(self): if hasattr(self.optimizer.optimizer, "consolidate_state_dict"): self.optimizer.optimizer.consolidate_state_dict() + def state_dict(self): + state_dict = { + "args": None, # legacy + "cfg": ( + OmegaConf.to_container(self.cfg) + if OmegaConf.is_config(self.cfg) else self.cfg + ), + "model": self.model.state_dict(), + "criterion": ( + self.criterion.state_dict() + if utils.has_parameters(self.criterion) else None + ), + "optimizer_history": (self._optim_history or []) + + [ + { + "criterion_name": self.get_criterion().__class__.__name__, + "optimizer_name": self.optimizer.__class__.__name__, + "lr_scheduler_state": self.lr_scheduler.state_dict(), + "num_updates": self.get_num_updates(), + } + ], + "task_state": self.task.state_dict() if self.task is not None else {}, + "extra_state": { + "metrics": metrics.state_dict(), + "previous_training_time": self.cumulative_training_time(), + } + } + if not self.cfg.checkpoint.no_save_optimizer_state: + state_dict["last_optimizer_state"] = self.optimizer.state_dict() + return state_dict + def save_checkpoint(self, filename, extra_state): """Save all training state in a checkpoint file.""" - if self.is_data_parallel_master: # only save one checkpoint - logger.info(f"Saving checkpoint to {filename}") - extra_state["metrics"] = metrics.state_dict() - extra_state["previous_training_time"] = self.cumulative_training_time() - checkpoint_utils.save_state( + logger.info(f"Saving checkpoint to {filename}") + # call state_dict on all ranks in case it needs internal communication + state_dict = utils.move_to_cpu(self.state_dict()) + state_dict["extra_state"].update(extra_state) + if self.should_save_checkpoint_on_current_rank: + checkpoint_utils.torch_persistent_save( + state_dict, filename, - self.cfg, - self.model.state_dict(), - self.get_criterion(), - self.optimizer, - self.lr_scheduler, - self.get_num_updates(), - optim_history=self._optim_history, - extra_state=extra_state, - task=self.task, + async_write=self.cfg.checkpoint.write_checkpoints_asynchronously, ) - logger.info(f"Finished saving checkpoint to {filename}") + logger.info(f"Finished saving checkpoint to {filename}") def load_checkpoint( self, diff --git a/tests/test_checkpoint_utils.py b/tests/test_checkpoint_utils.py index 3278de6b9f..0f28222633 100644 --- a/tests/test_checkpoint_utils.py +++ b/tests/test_checkpoint_utils.py @@ -90,15 +90,14 @@ def test_prune_state_dict(self): self.assertEqual(len(ensemble[0].decoder.layers), 1) def test_torch_persistent_save_async(self): - cfg = OmegaConf.create() - cfg.dataset = OmegaConf.create() - cfg.dataset.write_checkpoints_asynchronously = True state_dict = {} filename = "async_checkpoint.pt" with patch(f"{checkpoint_utils.__name__}.PathManager.opena") as mock_opena: with patch(f"{checkpoint_utils.__name__}._torch_persistent_save") as mock_save: - checkpoint_utils.torch_persistent_save(cfg.dataset, state_dict, filename) + checkpoint_utils.torch_persistent_save( + state_dict, filename, async_write=True + ) mock_opena.assert_called_with(filename, "wb") mock_save.assert_called() diff --git a/tests/test_train.py b/tests/test_train.py index 57daa194b2..65f4683bc6 100644 --- a/tests/test_train.py +++ b/tests/test_train.py @@ -68,6 +68,7 @@ def get_mock_cfg(finetune_from_model): "reset_lr_scheduler": False, "finetune_from_model": finetune_from_model, "model_parallel_size": 1, + "restore_file": "checkpoint_last.pt", }, "common": { "model_parallel_size": 1, From 656d7e5779a9ec4ccf0ad45d86a4ce589c597588 Mon Sep 17 00:00:00 2001 From: Myle Ott Date: Thu, 4 Mar 2021 13:31:02 -0800 Subject: [PATCH 62/82] Add support for FullyShardedDataParallel (--ddp-backend=fully_sharded) (#1667) Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1667 Add support for FullyShardedDataParallel (--ddp-backend=fully_sharded) This enables fully parameter + optimizer state sharding by using FullyShardedDataParallel (FSDP) from fairscale. The user just needs to provide `--ddp-backend=fully_sharded` to enable. Other common options work out-of-the-box (e.g., `--fp16`, `--memory-efficient-fp16`, `--update-freq`, etc.). This should be a drop-in replacement for the "c10d" backend. This yields pretty big speedups for small models and enables training ~13B parameter models on 8 GPUs and 175B parameter models on 128 GPUs, without model parallelism. This also adds a new option `--cpu-offload` that offloads the optimizer state and FP32 model copy to CPU, which is particularly useful when combined with `--optimizer=cpu_adam`. Note: after enabling this, each GPU will save a checkpoint file, since the optimizer state is sharded. Each checkpoint will contain a single shard of the optimizer state and the rank 0 checkpoint will contain the full model weights. Note: a known limitation of the current implementation is that you cannot resume training on a different world_size. This constraint will be relaxed in future iterations. Test Plan: Imported from OSS Reviewed By: sshleifer Differential Revision: D26771144 Pulled By: myleott fbshipit-source-id: 74c2f46f57719e24e2dcfc9d9ee7c2fc0aeedb46 --- fairseq/dataclass/configs.py | 15 +++ fairseq/dataclass/constants.py | 1 + fairseq/distributed/__init__.py | 4 + .../fully_sharded_data_parallel.py | 122 ++++++++++++++++++ fairseq/models/distributed_fairseq_model.py | 21 ++- fairseq/models/fairseq_model.py | 20 ++- fairseq/models/transformer.py | 6 + fairseq/optim/cpu_adam.py | 4 + fairseq/optim/fp16_optimizer.py | 14 +- fairseq/trainer.py | 84 ++++++++++-- fairseq_cli/train.py | 15 ++- tests/test_binaries.py | 10 +- tests/test_dataset.py | 7 + 13 files changed, 292 insertions(+), 31 deletions(-) create mode 100644 fairseq/distributed/fully_sharded_data_parallel.py diff --git a/fairseq/dataclass/configs.py b/fairseq/dataclass/configs.py index 4d3c60bfd6..5d6aee157a 100644 --- a/fairseq/dataclass/configs.py +++ b/fairseq/dataclass/configs.py @@ -355,7 +355,22 @@ class DistributedTrainingConfig(FairseqDataclass): zero_sharding: ZERO_SHARDING_CHOICES = field( default="none", metadata={"help": "ZeRO sharding"} ) + fp16: bool = II("common.fp16") + memory_efficient_fp16: bool = II("common.memory_efficient_fp16") tpu: bool = II("common.tpu") + # configuration for --ddp-backend=fully_sharded + no_reshard_after_forward: bool = field( + default=False, + metadata={"help": "don't reshard parameters after forward pass"}, + ) + fp32_reduce_scatter: bool = field( + default=False, + metadata={"help": "reduce-scatter grads in FP32"}, + ) + cpu_offload: bool = field( + default=False, + metadata={"help": "offload FP32 params to CPU"} + ) @dataclass diff --git a/fairseq/dataclass/constants.py b/fairseq/dataclass/constants.py index 93bc6d03cb..faba0862fa 100644 --- a/fairseq/dataclass/constants.py +++ b/fairseq/dataclass/constants.py @@ -37,6 +37,7 @@ def ChoiceEnum(choices: List[str]): LOG_FORMAT_CHOICES = ChoiceEnum(["json", "none", "simple", "tqdm"]) DDP_BACKEND_CHOICES = ChoiceEnum([ "c10d", # alias for pytorch_ddp + "fully_sharded", # FullyShardedDataParallel from fairscale "legacy_ddp", "no_c10d", # alias for legacy_ddp "pytorch_ddp", diff --git a/fairseq/distributed/__init__.py b/fairseq/distributed/__init__.py index 7f4016e38c..d0b96b734c 100644 --- a/fairseq/distributed/__init__.py +++ b/fairseq/distributed/__init__.py @@ -4,6 +4,7 @@ # LICENSE file in the root directory of this source tree. from .distributed_timeout_wrapper import DistributedTimeoutWrapper +from .fully_sharded_data_parallel import fsdp_enable_wrap, fsdp_wrap, FullyShardedDataParallel from .legacy_distributed_data_parallel import LegacyDistributedDataParallel from .module_proxy_wrapper import ModuleProxyWrapper from .tpu_distributed_data_parallel import TPUDistributedDataParallel @@ -11,6 +12,9 @@ __all__ = [ "DistributedTimeoutWrapper", + "fsdp_enable_wrap", + "fsdp_wrap", + "FullyShardedDataParallel", "LegacyDistributedDataParallel", "ModuleProxyWrapper", "TPUDistributedDataParallel", diff --git a/fairseq/distributed/fully_sharded_data_parallel.py b/fairseq/distributed/fully_sharded_data_parallel.py new file mode 100644 index 0000000000..9d74398325 --- /dev/null +++ b/fairseq/distributed/fully_sharded_data_parallel.py @@ -0,0 +1,122 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import contextlib +from typing import Optional + +import torch + +from fairseq.dataclass.configs import DistributedTrainingConfig +from fairseq.distributed import utils as dist_utils + + +try: + from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP + has_FSDP = True +except ImportError: + FSDP = torch.nn.Module + has_FSDP = False + + +class FullyShardedDataParallel(FSDP): + """ + A small wrapper around fairscale's FullyShardedDataParallel (FSDP) with some + fairseq-specific checkpoint saving/loading logic. + + Args: + use_sharded_state (bool): if True, then ``state_dict`` will return + ``FSDP.local_state_dict`` and ``load_state_dict`` will call + ``FSDP.load_local_state_dict``. Otherwise, ``state_dict`` will + return the full model weights on data parallel rank 0 (empty on + other ranks) and ``load_state_dict`` will broadcast model weights + from rank 0 to other ranks. + """ + + def __init__(self, *args, use_sharded_state: bool = False, **kwargs): + if not has_FSDP: + raise ImportError( + "Cannot find FullyShardedDataParallel. " + "Please install fairscale with: pip install fairscale" + ) + super().__init__(*args, **kwargs) + self.use_sharded_state = use_sharded_state + + def state_dict(self, destination=None, prefix='', keep_vars=False): + if self.use_sharded_state: + return super().local_state_dict( + destination=destination, prefix=prefix, keep_vars=keep_vars + ) + else: + if self.rank == 0: + return super().state_dict( + destination=destination, prefix=prefix, keep_vars=keep_vars + ) + else: + # We must call state_dict() due to use of communication + # primitives. But we don't use the result. + super().state_dict() + return destination or {} + + def load_state_dict(self, state_dict, strict=True, model_cfg=None): + if self.use_sharded_state: + return super().load_local_state_dict(state_dict, strict=strict) + else: + state_dict = dist_utils.broadcast_object( + state_dict, src_rank=0, group=self.process_group + ) + return super().load_state_dict(state_dict, strict=strict) + + +@contextlib.contextmanager +def fsdp_enable_wrap(cfg: DistributedTrainingConfig, use_sharded_state: bool = False): + try: + from fairscale.nn import enable_wrap + except ImportError: + raise ImportError( + "Cannot find FullyShardedDataParallel. " + "Please install fairscale with: pip install fairscale" + ) + if cfg.memory_efficient_fp16: + assert cfg.fp16 # memory_efficient_fp16 should imply fp16 + group = dist_utils.get_data_parallel_group() + if group is None and cfg.distributed_world_size == 1: + from fairscale.utils.testing import DummyProcessGroup + group = DummyProcessGroup(rank=0, size=1) + fsdp_config = { + "process_group": group, + "reshard_after_forward": not cfg.no_reshard_after_forward, + "mixed_precision": cfg.fp16 and not cfg.memory_efficient_fp16, + "fp32_reduce_scatter": cfg.fp32_reduce_scatter, + "flatten_parameters": True, + "cpu_offload": cfg.cpu_offload, + "compute_dtype": torch.float16 if cfg.fp16 else torch.float32, + "bucket_cap_mb": cfg.bucket_cap_mb, + } + with enable_wrap(use_sharded_state=use_sharded_state, **fsdp_config): + yield + + +def fsdp_wrap(module, min_num_params: Optional[int] = None, **kwargs): + """ + Helper to wrap layers/modules in FSDP. This falls back to a no-op if + fairscale is not available. + + Args: + module (nn.Module): module to (maybe) wrap + min_num_params (int, Optional): minimum number of layer params to wrap + """ + try: + from fairscale.nn import wrap + cls = FullyShardedDataParallel + if min_num_params is not None: + num_params = sum(p.numel() for p in module.parameters()) + if num_params >= min_num_params: + return wrap(module, cls=cls, **kwargs) + else: + return module + else: + return wrap(module, cls=cls, **kwargs) + except ImportError: + return module diff --git a/fairseq/models/distributed_fairseq_model.py b/fairseq/models/distributed_fairseq_model.py index ca157f06e9..3422faea74 100644 --- a/fairseq/models/distributed_fairseq_model.py +++ b/fairseq/models/distributed_fairseq_model.py @@ -105,12 +105,27 @@ def DistributedFairseqModel(args, model, process_group, device): ) # forward missing getattr and state_dict/load_state_dict to orig model wrapped_model = ModuleProxyWrapper(wrapped_model) + elif args.ddp_backend == "fully_sharded": + try: + from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP + except ImportError: + raise ImportError( + "Cannot find FullyShardedDataParallel. " + "Please install fairscale with: pip install fairscale" + ) + assert isinstance(model, FSDP), "expected model to already be wrapped in FSDP" + wrapped_model = model + if args.memory_efficient_fp16: + wrapped_model = wrapped_model.half() + if not args.cpu_offload: + wrapped_model = wrapped_model.to(device=device) else: raise ValueError("Unknown --ddp-backend: " + args.ddp_backend) # kill hung distributed jobs after a timeout - wrapped_model = DistributedTimeoutWrapper( - wrapped_model, timeout=getattr(args, "heartbeat_timeout", -1) - ) + if getattr(args, "heartbeat_timeout", -1) > 0: + wrapped_model = DistributedTimeoutWrapper( + wrapped_model, timeout=getattr(args, "heartbeat_timeout", -1) + ) return wrapped_model diff --git a/fairseq/models/fairseq_model.py b/fairseq/models/fairseq_model.py index 186f3d2464..d393c02ae6 100644 --- a/fairseq/models/fairseq_model.py +++ b/fairseq/models/fairseq_model.py @@ -27,6 +27,13 @@ logger = logging.getLogger(__name__) +def check_type(module, expected_type): + if hasattr(module, "unwrapped_module"): + assert isinstance(module.unwrapped_module, expected_type) + else: + assert isinstance(module, expected_type) + + class BaseFairseqModel(nn.Module): """Base class for fairseq models.""" @@ -284,8 +291,9 @@ def __init__(self, encoder, decoder): self.encoder = encoder self.decoder = decoder - assert isinstance(self.encoder, FairseqEncoder) - assert isinstance(self.decoder, FairseqDecoder) + + check_type(self.encoder, FairseqEncoder) + check_type(self.decoder, FairseqDecoder) def forward(self, src_tokens, src_lengths, prev_output_tokens, **kwargs): """ @@ -365,8 +373,8 @@ def __init__(self, encoders, decoders): assert encoders.keys() == decoders.keys() self.keys = list(encoders.keys()) for key in self.keys: - assert isinstance(encoders[key], FairseqEncoder) - assert isinstance(decoders[key], FairseqDecoder) + check_type(encoders[key], FairseqEncoder) + check_type(decoders[key], FairseqDecoder) self.models = nn.ModuleDict( { @@ -469,7 +477,7 @@ class FairseqLanguageModel(BaseFairseqModel): def __init__(self, decoder): super().__init__() self.decoder = decoder - assert isinstance(self.decoder, FairseqDecoder) + check_type(self.decoder, FairseqDecoder) def forward(self, src_tokens, **kwargs): """ @@ -530,7 +538,7 @@ class FairseqEncoderModel(BaseFairseqModel): def __init__(self, encoder): super().__init__() self.encoder = encoder - assert isinstance(self.encoder, FairseqEncoder) + check_type(self.encoder, FairseqEncoder) def forward(self, src_tokens, src_lengths, **kwargs): """ diff --git a/fairseq/models/transformer.py b/fairseq/models/transformer.py index f2f36baf3e..a0a0b8dcd5 100644 --- a/fairseq/models/transformer.py +++ b/fairseq/models/transformer.py @@ -9,6 +9,7 @@ import torch import torch.nn as nn from fairseq import utils +from fairseq.distributed import fsdp_wrap from fairseq.models import ( FairseqEncoder, FairseqEncoderDecoderModel, @@ -240,6 +241,9 @@ def build_model(cls, args, task): args.checkpoint_activations = True # offloading implies checkpointing encoder = cls.build_encoder(args, src_dict, encoder_embed_tokens) decoder = cls.build_decoder(args, tgt_dict, decoder_embed_tokens) + if not args.share_all_embeddings: + encoder = fsdp_wrap(encoder, min_num_params=1e8) + decoder = fsdp_wrap(decoder, min_num_params=1e8) return cls(args, encoder, decoder) @classmethod @@ -386,6 +390,7 @@ def build_encoder_layer(self, args): if getattr(args, "checkpoint_activations", False): offload_to_cpu = getattr(args, "offload_activations", False) layer = checkpoint_wrapper(layer, offload_to_cpu=offload_to_cpu) + layer = fsdp_wrap(layer, min_num_params=1e8) return layer def forward_embedding( @@ -726,6 +731,7 @@ def build_decoder_layer(self, args, no_encoder_attn=False): if getattr(args, "checkpoint_activations", False): offload_to_cpu = getattr(args, "offload_activations", False) layer = checkpoint_wrapper(layer, offload_to_cpu=offload_to_cpu) + layer = fsdp_wrap(layer, min_num_params=1e8) return layer def forward( diff --git a/fairseq/optim/cpu_adam.py b/fairseq/optim/cpu_adam.py index fad5a64ecb..5e935df1a5 100644 --- a/fairseq/optim/cpu_adam.py +++ b/fairseq/optim/cpu_adam.py @@ -107,6 +107,10 @@ def __init__( self.opt_id, lr, betas[0], betas[1], eps, weight_decay, adamw_mode ) + @property + def supports_flat_params(self): + return True + @torch.no_grad() def step(self, closure=None): loss = None diff --git a/fairseq/optim/fp16_optimizer.py b/fairseq/optim/fp16_optimizer.py index e0b069f172..00ea1bbb76 100644 --- a/fairseq/optim/fp16_optimizer.py +++ b/fairseq/optim/fp16_optimizer.py @@ -322,6 +322,10 @@ def set_lr(self, lr): def all_reduce_grads(self, module): self.fp32_optimizer.all_reduce_grads(module) + @property + def supports_flat_params(self): + return self.fp32_optimizer.supports_flat_params + class _MemoryEfficientFP16OptimizerMixin(object): def __init__(self, *args, **kwargs): @@ -442,6 +446,10 @@ def zero_grad(self): else: self._multiply_factor = 1.0 + @property + def supports_flat_params(self): + return self.wrapped_optimizer.supports_flat_params + class MemoryEfficientFP16Optimizer( _MemoryEfficientFP16OptimizerMixin, optim.FairseqOptimizer @@ -461,8 +469,10 @@ class MemoryEfficientFP16Optimizer( *supports_memory_efficient_fp16* property. """ - def __init__(self, cfg: DictConfig, params, optimizer, **kwargs): - if not optimizer.supports_memory_efficient_fp16: + def __init__( + self, cfg: DictConfig, params, optimizer, allow_unsupported=False, **kwargs + ): + if not allow_unsupported and not optimizer.supports_memory_efficient_fp16: raise ValueError( "Unsupported optimizer: {}".format(optimizer.__class__.__name__) ) diff --git a/fairseq/trainer.py b/fairseq/trainer.py index 45d9591d7c..4d47d39897 100644 --- a/fairseq/trainer.py +++ b/fairseq/trainer.py @@ -63,15 +63,31 @@ def __init__(self, cfg: FairseqConfig, task, model, criterion, quantizer=None): else: self.device = torch.device("cpu") + if self.cfg.distributed_training.ddp_backend == "fully_sharded": + if self.cfg.common.bf16: + raise ValueError( + "FullyShardedDataParallel is not compatible with --bf16 or " + "--memory-efficient-bf16" + ) + if self.cfg.distributed_training.zero_sharding != "none": + raise ValueError( + "FullyShardedDataParallel is not compatible with --zero-sharding " + "option (it's already built in)" + ) + else: + if self.cfg.distributed_training.cpu_offload: + raise ValueError("--cpu-offload requires --ddp-backend=fully_sharded") + # copy model and criterion to current device/dtype self._criterion = criterion self._model = model - if cfg.common.fp16: - self._criterion = self._criterion.half() - self._model = self._model.half() - elif cfg.common.bf16: - self._criterion = self._criterion.to(dtype=torch.bfloat16) - self._model = self._model.to(dtype=torch.bfloat16) + if cfg.distributed_training.ddp_backend != "fully_sharded": + if cfg.common.fp16: + self._criterion = self._criterion.half() + self._model = self._model.half() + elif cfg.common.bf16: + self._criterion = self._criterion.to(dtype=torch.bfloat16) + self._model = self._model.to(dtype=torch.bfloat16) if ( not cfg.distributed_training.pipeline_model_parallel # the DistributedFairseqModel wrapper will handle moving to device, @@ -171,17 +187,26 @@ def use_distributed_wrapper(self) -> bool: return ( self.data_parallel_world_size > 1 and not self.cfg.optimization.use_bmuf + ) or ( + self.cfg.distributed_training.ddp_backend == "fully_sharded" + and self.cfg.distributed_training.cpu_offload ) @property def should_save_checkpoint_on_current_rank(self) -> bool: """Indicates whether to save checkpoints on the current DDP rank.""" - return self.is_data_parallel_master + if self.cfg.distributed_training.ddp_backend == "fully_sharded": + return True + else: + return self.is_data_parallel_master @property def checkpoint_suffix(self) -> str: """Suffix to add to the checkpoint file name.""" - return self.cfg.checkpoint.checkpoint_suffix or "" + if self.cfg.distributed_training.ddp_backend == "fully_sharded": + return self.cfg.checkpoint.checkpoint_suffix + "-shard{0}".format(self.data_parallel_rank) + else: + return self.cfg.checkpoint.checkpoint_suffix or "" @property def criterion(self): @@ -234,7 +259,20 @@ def _build_optimizer(self): ) ) - if self.cfg.common.fp16 or self.cfg.common.bf16: + if ( + self.cfg.distributed_training.ddp_backend == "fully_sharded" + and self.cfg.common.fp16 + ): + # FullyShardedDataParallel always uses MemoryEfficientFP16 wrapper, + # mostly for the grad scaling. But if we don't have the + # --memory-efficient-fp16 flag set, then we're effectively doing + # regular --fp16 and can allow the use of optimizers that would + # otherwise be unsupported by MemoryEfficientFP16Optimizer. + allow_unsupported = not self.cfg.common.memory_efficient_fp16 + self._optimizer = optim.MemoryEfficientFP16Optimizer.build_optimizer( + self.cfg, params, allow_unsupported=allow_unsupported + ) + elif self.cfg.common.fp16 or self.cfg.common.bf16: if self.cuda and torch.cuda.get_device_capability(0)[0] < 7: logger.info( "NOTE: your device does NOT support faster training with --fp16, " @@ -254,6 +292,16 @@ def _build_optimizer(self): logger.info("NOTE: your device may support faster training with --fp16") self._optimizer = optim.build_optimizer(self.cfg.optimizer, params) + if self.cfg.distributed_training.ddp_backend == "fully_sharded": + assert not self.cfg.optimization.use_bmuf, \ + "--ddp-backend=fully_sharded is not compatible with BMUF" + assert self._optimizer.supports_flat_params, ( + "--ddp-backend=fully_sharded is only compatible with pointwise " + "optimizers (e.g., Adam, AdamW, Adadelta, Adamax, SGD, etc.). " + "However, the sharding will result in slightly different results when " + "using non-pointwise optimizers (e.g., Adagrad, Adafactor, LAMB)" + ) + if self.cfg.optimization.use_bmuf: self._optimizer = optim.FairseqBMUF( self.cfg.bmuf, @@ -355,6 +403,8 @@ def load_checkpoint( # TPUs don't support broadcast yet, so load checkpoints # on every worker for now or self.tpu + # FSDP requires loading checkpoint shards on all ranks + or self.cfg.distributed_training.ddp_backend == "fully_sharded" ) if load_on_all_ranks or self.data_parallel_rank == 0: @@ -965,7 +1015,21 @@ def set_num_updates(self, num_updates): metrics.log_scalar("num_updates", self._num_updates, weight=0, priority=200) def clip_grad_norm(self, clip_norm): - return self.optimizer.clip_grad_norm(clip_norm, aggregate_norm_fn=None) + + def agg_norm_fn(total_norm): + if self.cfg.distributed_training.ddp_backend == "fully_sharded": + total_norm = total_norm ** 2 + if ( + self.data_parallel_process_group is not None + or torch.distributed.is_initialized() + ): + total_norm = distributed_utils.all_reduce( + total_norm.cuda(), group=self.data_parallel_process_group + ) + total_norm = total_norm ** 0.5 + return total_norm + + return self.optimizer.clip_grad_norm(clip_norm, aggregate_norm_fn=agg_norm_fn) def cumulative_training_time(self): if self._cumulative_training_time is None: diff --git a/fairseq_cli/train.py b/fairseq_cli/train.py index 80ad57acd1..d770e4e4ec 100644 --- a/fairseq_cli/train.py +++ b/fairseq_cli/train.py @@ -18,7 +18,6 @@ import torch from fairseq import ( checkpoint_utils, - distributed_utils, options, quantization_utils, tasks, @@ -27,7 +26,7 @@ from fairseq.data import iterators from fairseq.dataclass.configs import FairseqConfig from fairseq.dataclass.utils import convert_namespace_to_omegaconf -from fairseq.distributed_utils import is_master +from fairseq.distributed import fsdp_enable_wrap, fsdp_wrap, utils as distributed_utils from fairseq.file_io import PathManager from fairseq.logging import meters, metrics, progress_bar from fairseq.model_parallel.megatron_trainer import MegatronTrainer @@ -50,7 +49,7 @@ def main(cfg: FairseqConfig) -> None: utils.import_user_module(cfg.common) - if is_master(cfg.distributed_training) and "job_logging_cfg" in cfg: + if distributed_utils.is_master(cfg.distributed_training) and "job_logging_cfg" in cfg: # make hydra logging work with ddp (see # see https://github.com/facebookresearch/hydra/issues/1126) logging.config.dictConfig(OmegaConf.to_container(cfg.job_logging_cfg)) @@ -87,7 +86,11 @@ def main(cfg: FairseqConfig) -> None: assert cfg.criterion, "Please specify criterion to train a model" # Build model and criterion - model = task.build_model(cfg.model) + if cfg.distributed_training.ddp_backend == "fully_sharded": + with fsdp_enable_wrap(cfg.distributed_training): + model = fsdp_wrap(task.build_model(cfg.model)) + else: + model = task.build_model(cfg.model) criterion = task.build_criterion(cfg.criterion) logger.info(model) logger.info("task: {}".format(task.__class__.__name__)) @@ -95,8 +98,8 @@ def main(cfg: FairseqConfig) -> None: logger.info("criterion: {}".format(criterion.__class__.__name__)) logger.info( "num. model params: {:,} (num. trained: {:,})".format( - sum(p.numel() for p in model.parameters()), - sum(p.numel() for p in model.parameters() if p.requires_grad), + sum(getattr(p, "_orig_size", p).numel() for p in model.parameters()), + sum(getattr(p, "_orig_size", p).numel() for p in model.parameters() if p.requires_grad), ) ) diff --git a/tests/test_binaries.py b/tests/test_binaries.py index 3cb98897bf..e10cc767b8 100644 --- a/tests/test_binaries.py +++ b/tests/test_binaries.py @@ -1697,8 +1697,9 @@ def test_activation_offloading_does_not_change_metrics(self): """Neither ----checkpoint-activations nor --offload-activations should change loss""" with tempfile.TemporaryDirectory("test_transformer_with_act_cpt") as data_dir: - create_dummy_data(data_dir, num_examples=20) - preprocess_translation_data(data_dir) + with self.assertLogs(): + create_dummy_data(data_dir, num_examples=20) + preprocess_translation_data(data_dir) offload_logs = self._train(data_dir, ["--offload-activations"]) baseline_logs = self._train(data_dir, []) @@ -1720,8 +1721,9 @@ def test_activation_checkpointing_does_not_change_metrics(self): """--checkpoint-activations should not change loss""" with tempfile.TemporaryDirectory("test_transformer_with_act_cpt") as data_dir: - create_dummy_data(data_dir, num_examples=20) - preprocess_translation_data(data_dir) + with self.assertLogs(): + create_dummy_data(data_dir, num_examples=20) + preprocess_translation_data(data_dir) ckpt_logs = self._train(data_dir, ["--checkpoint-activations"]) baseline_logs = self._train(data_dir, []) assert len(baseline_logs) == len(ckpt_logs) diff --git a/tests/test_dataset.py b/tests/test_dataset.py index 9fb69a5f77..a3e3970028 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -3,6 +3,7 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +import logging import unittest from typing import Sequence @@ -20,6 +21,12 @@ def sample(id: int, length: int): class TestDataset(unittest.TestCase): + def setUp(self): + logging.disable(logging.CRITICAL) + + def tearDown(self): + logging.disable(logging.NOTSET) + def test_round_robin_zip_datasets(self): long_dataset = lang_pair_dataset([10, 9, 8, 11]) short_dataset = lang_pair_dataset([11, 9]) From 73886ac228f8f0368871237f7498ec8b07444322 Mon Sep 17 00:00:00 2001 From: Ning Dong Date: Thu, 4 Mar 2021 14:20:00 -0800 Subject: [PATCH 63/82] Refactor FairseqSimulSTAgent Summary: 1. In fblearner flow we are dumping cmvn stats into json file (e.g. f253830726) Previously there's only --config option taking .npz path from a yaml file, and it's the only usage for the config. This diff adds an option --global-stats to import from json. 2. Inherit FairseqSimulSTAgent from nn.Module instead of SpeechAgent whose root class is object to prepare for scripting methods. Copy over / simplify all the necessary methods from SpeechAgent/Agent. Reviewed By: jmp84 Differential Revision: D26800957 fbshipit-source-id: 74be527f8473c13405a60bb16ce6da5a7dc0b888 --- .../agents/fairseq_simul_st_agent.py | 39 +++++++++++++------ 1 file changed, 27 insertions(+), 12 deletions(-) diff --git a/examples/speech_to_text/simultaneous_translation/agents/fairseq_simul_st_agent.py b/examples/speech_to_text/simultaneous_translation/agents/fairseq_simul_st_agent.py index f944203785..2b5fdc2d3f 100644 --- a/examples/speech_to_text/simultaneous_translation/agents/fairseq_simul_st_agent.py +++ b/examples/speech_to_text/simultaneous_translation/agents/fairseq_simul_st_agent.py @@ -1,19 +1,20 @@ import math import os - +import json import numpy as np import torch import torchaudio.compliance.kaldi as kaldi import yaml from fairseq import checkpoint_utils, tasks +from fairseq.file_io import PathManager try: from simuleval import READ_ACTION, WRITE_ACTION, DEFAULT_EOS - from simuleval.agents import SpeechAgent - from simuleval.states import ListEntry + from simuleval.states import ListEntry, SpeechStates except ImportError: print("Please install simuleval 'pip install simuleval'") +from torch import nn SHIFT_SIZE = 10 WINDOW_SIZE = 25 @@ -112,12 +113,12 @@ def info(self): } -class FairseqSimulSTAgent(SpeechAgent): +class FairseqSimulSTAgent(nn.Module): speech_segment_size = 40 # in ms, 4 pooling ratio * 10 ms step size def __init__(self, args): - super().__init__(args) + super().__init__() self.eos = DEFAULT_EOS @@ -136,13 +137,18 @@ def __init__(self, args): self.model.decoder.layers[0].encoder_attn.pre_decision_ratio ) - with open(args.config, "r") as f: - config = yaml.load(f, Loader=yaml.BaseLoader) + args.global_cmvn = None + if args.config: + with open(args.config, "r") as f: + config = yaml.load(f, Loader=yaml.BaseLoader) - if "global_cmvn" in config: - args.global_cmvn = np.load(config["global_cmvn"]["stats_npz_path"]) - else: - args.global_cmvn = None + if "global_cmvn" in config: + args.global_cmvn = np.load(config["global_cmvn"]["stats_npz_path"]) + + if args.global_stats: + with PathManager.open(args.global_stats, "r") as f: + global_cmvn = json.loads(f.read()) + self.global_cmvn = {"mean": global_cmvn["mean"], "std": global_cmvn["stddev"]} self.feature_extractor = OnlineFeatureExtractor(args) @@ -152,6 +158,13 @@ def __init__(self, args): torch.set_grad_enabled(False) + def build_states(self, args, client, sentence_id): + # Initialize states here, for example add customized entry to states + # This function will be called at beginning of every new sentence + states = SpeechStates(args, client, sentence_id, self) + self.initialize_states(states) + return states + def to_device(self, tensor): if self.gpu: return tensor.cuda() @@ -165,8 +178,10 @@ def add_args(parser): help='path to your pretrained model.') parser.add_argument("--data-bin", type=str, required=True, help="Path of data binary") - parser.add_argument("--config", type=str, required=True, + parser.add_argument("--config", type=str, default=None, help="Path to config yaml file") + parser.add_argument("--global-stats", type=str, default=None, + help="Path to json file containing cmvn stats") parser.add_argument("--tgt-splitter-type", type=str, default="SentencePiece", help="Subword splitter type for target text") parser.add_argument("--tgt-splitter-path", type=str, default=None, From 7c95746a7e5e4a087399d186590815e45ae775c8 Mon Sep 17 00:00:00 2001 From: Changhan Wang Date: Thu, 4 Mar 2021 17:17:11 -0800 Subject: [PATCH 64/82] fix bug on converting stereo audio in audio_utils.py Summary: Fix bug on converting stereo audio in audio_utils.py - Github issue: https://github.com/pytorch/fairseq/issues/3303 Reviewed By: jmp84 Differential Revision: D26825964 fbshipit-source-id: 26905e71540bc52e98d76996b199ac0fbe78357b --- fairseq/data/audio/audio_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/fairseq/data/audio/audio_utils.py b/fairseq/data/audio/audio_utils.py index f0e75b1d65..f8cc80f5e2 100644 --- a/fairseq/data/audio/audio_utils.py +++ b/fairseq/data/audio/audio_utils.py @@ -63,8 +63,8 @@ def _get_torchaudio_fbank(waveform, sample_rate, n_bins=80) -> Optional[np.ndarr # Mono channel: D -> 1 x D waveform = waveform.unsqueeze(0) else: - # Merge multiple channels to one: C x D -> 1 x D - waveform, _ = ta_sox.apply_effects_tensor(waveform, sample_rate, ['channels', '1']) + # Merge multiple channels to one: D x C -> 1 x D + waveform, _ = ta_sox.apply_effects_tensor(waveform.T, sample_rate, [['channels', '1']]) features = ta_kaldi.fbank( waveform, num_mel_bins=n_bins, sample_frequency=sample_rate From 16c1a200f87a2adb6395e353345c19bbe990d1dd Mon Sep 17 00:00:00 2001 From: sarapapi <57095209+sarapapi@users.noreply.github.com> Date: Mon, 8 Mar 2021 14:10:29 -0800 Subject: [PATCH 65/82] Fix Global CMVN path of MustC data preprocessing (#3307) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: # Before submitting - [ ] Was this discussed/approved via a Github issue? (no need for typos, doc improvements) - [x] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)? - [ ] Did you make sure to update the docs? - [ ] Did you write any new necessary tests? ## What does this PR do? Fix a typo in gcmv_path given for config yaml generation (actual: gcvmn_cvmn_path, correct: gcmvn_path) ## PR review Anyone in the community is free to review the PR once the tests have passed. If we didn't discuss your PR in Github issues there's a high chance it will not be merged. ## Did you have fun? Make sure you had fun coding � Pull Request resolved: https://github.com/pytorch/fairseq/pull/3307 Reviewed By: jmp84 Differential Revision: D26826231 Pulled By: kahne fbshipit-source-id: 6b60f2a8a8b4ba1c0c088299a08ef04fdfe870a8 --- examples/speech_to_text/prep_mustc_data.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/speech_to_text/prep_mustc_data.py b/examples/speech_to_text/prep_mustc_data.py index 4e410bcb18..45fd43533d 100644 --- a/examples/speech_to_text/prep_mustc_data.py +++ b/examples/speech_to_text/prep_mustc_data.py @@ -179,7 +179,7 @@ def process(args): yaml_filename=f"config_{args.task}.yaml", specaugment_policy="lb", cmvn_type=args.cmvn_type, - gcmvn_cmvn_path=( + gcmvn_path=( cur_root / "gcmvn.npz" if args.cmvn_type == "global" else None ), From 00d5b7adbeaf64e02c53a591d637efe4c8cad923 Mon Sep 17 00:00:00 2001 From: Myle Ott Date: Tue, 9 Mar 2021 06:28:23 -0800 Subject: [PATCH 66/82] Add README/tutorial for Fully Sharded Data Parallel (#3327) Summary: Pull Request resolved: https://github.com/pytorch/fairseq/pull/3327 Reviewed By: sshleifer Differential Revision: D26899416 Pulled By: myleott fbshipit-source-id: bbb493a5c4e0a51f3b26fe8f94e3962b6206d6f6 --- .github/workflows/build.yml | 3 +- README.md | 11 +- .../fully_sharded_data_parallel/README.md | 164 ++++++++++++++++++ .../fully_sharded_data_parallel.py | 18 +- fairseq/models/fairseq_model.py | 5 +- fairseq/models/roberta/model.py | 26 ++- fairseq/models/transformer.py | 40 ++++- fairseq/models/transformer_lm.py | 121 +++++++++++-- fairseq/trainer.py | 16 +- 9 files changed, 363 insertions(+), 41 deletions(-) create mode 100644 examples/fully_sharded_data_parallel/README.md diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 0af8bad95d..105c42a503 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -39,7 +39,8 @@ jobs: - name: Install optional test requirements run: | - python -m pip install fairscale iopath transformers pyarrow + python -m pip install iopath transformers pyarrow + python -m pip install git+https://github.com/facebookresearch/fairscale.git@master - name: Lint with flake8 run: | diff --git a/README.md b/README.md index 5fedac7eec..839dd8e1de 100644 --- a/README.md +++ b/README.md @@ -61,6 +61,9 @@ We provide reference implementations of various sequence modeling papers: ### What's New: +* March 2021 [Added full parameter and optimizer state sharding + CPU offloading](examples/fully_sharded_data_parallel/README.md) +* February 2021 [Added LASER training code](examples/laser/README.md) +* December 2020: [Added Adaptive Attention Span code](examples/adaptive_span/README.md) * December 2020: [GottBERT model and code released](examples/gottbert/README.md) * November 2020: Adopted the [Hydra](https://github.com/facebookresearch/hydra) configuration framework * [see documentation explaining how to use it for new and existing projects](docs/hydra_integration.md) @@ -68,14 +71,14 @@ We provide reference implementations of various sequence modeling papers: * October 2020: [Added R3F/R4F (Better Fine-Tuning) code](examples/rxf/README.md) * October 2020: [Deep Transformer with Latent Depth code released](examples/latent_depth/README.md) * October 2020: [Added CRISS models and code](examples/criss/README.md) + +
Previous updates

+ * September 2020: [Added Linformer code](examples/linformer/README.md) * September 2020: [Added pointer-generator networks](examples/pointer_generator/README.md) * August 2020: [Added lexically constrained decoding](examples/constrained_decoding/README.md) * August 2020: [wav2vec2 models and code released](examples/wav2vec/README.md) * July 2020: [Unsupervised Quality Estimation code released](examples/unsupervised_quality_estimation/README.md) - -

Previous updates

- * May 2020: [Follow fairseq on Twitter](https://twitter.com/fairseq) * April 2020: [Monotonic Multihead Attention code released](examples/simultaneous_translation/README.md) * April 2020: [Quant-Noise code released](examples/quant_noise/README.md) @@ -108,6 +111,8 @@ We provide reference implementations of various sequence modeling papers: * [mixed precision training](https://fairseq.readthedocs.io/en/latest/getting_started.html#training-with-half-precision-floating-point-fp16) (trains faster with less GPU memory on [NVIDIA tensor cores](https://developer.nvidia.com/tensor-cores)) * [extensible](https://fairseq.readthedocs.io/en/latest/overview.html): easily register new models, criterions, tasks, optimizers and learning rate schedulers * [flexible configuration](docs/hydra_integration.md) based on [Hydra](https://github.com/facebookresearch/hydra) allowing a combination of code, command-line and file based configuration +* [full parameter and optimizer state sharding](examples/fully_sharded_data_parallel/README.md) +* [offloading parameters to CPU](examples/fully_sharded_data_parallel/README.md) We also provide [pre-trained models for translation and language modeling](#pre-trained-models-and-examples) with a convenient `torch.hub` interface: diff --git a/examples/fully_sharded_data_parallel/README.md b/examples/fully_sharded_data_parallel/README.md new file mode 100644 index 0000000000..bc98670968 --- /dev/null +++ b/examples/fully_sharded_data_parallel/README.md @@ -0,0 +1,164 @@ +# Fully Sharded Data Parallel (FSDP) + +## Overview +Recent work by [Microsoft](https://arxiv.org/abs/1910.02054) and +[Google](https://arxiv.org/abs/2004.13336) has shown that data parallel +training can be made significantly more efficient by sharding the model +parameters and optimizer state across data parallel workers. These ideas are +encapsulated in the new **`FullyShardedDataParallel` (FSDP)** wrapper provided +by [fairscale](https://github.com/facebookresearch/fairscale/). + +Compared to PyTorch DDP: +* FSDP produces identical results as PyTorch DDP (it's still synchronous data parallel training) +* FSDP shards parameters (FP16 + FP32) and optimizer state across data parallel GPUs +* FSDP is faster than PyTorch DDP because the optimizer step is sharded, and the communication can be overlapped with the forward pass +* FSDP enables training 13B parameter models on 8 GPUs and 175B parameter models on 128 GPUs + +FSDP is fully supported in fairseq via the following new arguments: +* `--ddp-backend=fully_sharded`: enables full sharding via FSDP +* `--cpu-offload`: offloads the optimizer state and FP32 model copy to CPU (combine with `--optimizer=cpu_adam`) +* `--no-reshard-after-forward`: increases training speed for some models and is similar to ZeRO stage 2 +* other popular options (`--fp16`, `--update-freq`, `--checkpoint-activations`, `--offload-activations`, etc.) continue to work as normal + +

Limitations

+ +FSDP currently has several limitations compared to fairseq's default DDP backend (PyTorch DDP): +* while FSDP is full compatible with pointwise Optimizers (e.g., Adam, AdamW, Adadelta, Adamax, SGD, etc.), it is not currently compatible with non-pointwise Optimizers (e.g., Adagrad, Adafactor, LAMB, etc.) +* FSDP depends on flattening the parameters, so models that currently require `--fp16-no-flatten-grads` may not be supported + +See the [fairscale docs](https://fairscale.readthedocs.io/en/latest/api/nn/fsdp_tips.html) for a more detailed +explanation of these and other limitations. + +

+ +
How it works

+ +Fully Sharded Data Parallel + +See the [fairscale docs](https://fairscale.readthedocs.io/en/latest/api/nn/fsdp_tips.html) for a more detailed +explanation of how FSDP works. + +

+ +## Example usage + +The following examples illustrate how to train a very large language model with +13 billion parameters on 1 GPU by offloading parameters and optimizer states to +CPU, or on 8 GPUs by fully sharding the params and optimizer states across GPUs. + +These examples use the WikiText-103 dataset for demonstration purposes, but +in practice a much larger dataset will be needed to achieve good results. +Follow the [instructions here](https://github.com/pytorch/fairseq/blob/master/examples/roberta/README.pretraining.md#1-preprocess-the-data) +to preprocess the WikiText-103 dataset using the GPT-2/RoBERTa vocabulary. + +### 13B params on 1 V100 GPU (with CPU offloading) + +The following command trains a 13B parameter GPT-3 model on a single V100 GPU +using the `--cpu-offload` feature to offload parameters and optimizer states to +CPU. In this setting, the optimizer step (Adam) happens on CPU. We also use the +`--checkpoint-activations` feature (sometimes called [gradient checkpointing](https://pytorch.org/docs/stable/checkpoint.html)), +which further saves memory in exchange for a small increase in computation. + +Requirements: +- You'll need 32GB of GPU memory and 256GB of system memory. +- We use the CPU Adam optimizer from [DeepSpeed](https://github.com/microsoft/DeepSpeed), so you'll need to `pip install deepspeed` before running the command. + +Some notes: +- The command will take ~5 minutes to start training, during which time it will appear to be hung, since randomly initializing 13B weights can be slow. +- The `--cpu-offload` feature requires training in mixed precision (`--fp16`). +- Tune the `OMP_NUM_THREADS` env variable for best performance with CPU offloading. +- The example command below stops training after 10 steps (`--max-update 10`) and does not save checkpoints (`--no-save`). + +```bash +OMP_NUM_THREADS=20 CUDA_VISIBLE_DEVICES=0 \ + fairseq-train data-bin/wikitext-103-roberta-bpe-bin \ + --ddp-backend fully_sharded --fp16 --fp16-init-scale 4 \ + --cpu-offload --checkpoint-activations \ + --task language_modeling --tokens-per-sample 2048 --batch-size 8 \ + --arch transformer_lm_gpt3_13 \ + --optimizer cpu_adam --adam-betas "(0.9,0.98)" \ + --lr 0.0001 --lr-scheduler polynomial_decay --warmup-updates 5 --total-num-update 10 \ + --max-update 10 --no-save --log-format json --log-interval 1 + +# Example output: +# (...) +# 2021-03-08 12:29:51 | INFO | fairseq_cli.train | num. model params: 13,110,865,920 (num. trained: 13,110,865,920) +# (...) +# 2021-03-08 12:29:51 | INFO | fairseq_cli.train | training on 1 devices (GPUs/TPUs) +# 2021-03-08 12:29:51 | INFO | fairseq_cli.train | max tokens per GPU = None and batch size per GPU = 8 +# (...) +# Adam Optimizer #0 is created with AVX2 arithmetic capability. +# Config: alpha=0.000100, betas=(0.900000, 0.980000), weight_decay=0.000000, adam_w=1 +# (...) +# 2021-03-08 12:31:36 | INFO | train_inner | {"epoch": 1, "update": 0.0, "loss": "16.475", "ppl": "91120.8", "wps": "0", "ups": "0", "wpb": "16384", "bsz": "8", "num_updates": "1", "lr": "2e-05", "gnorm": "20.751", "loss_scale": "4", "train_wall": "99", "gb_free": "9.3", "wall": "105"} +# 2021-03-08 12:32:33 | INFO | train_inner | {"epoch": 1, "update": 0.0, "loss": "16.446", "ppl": "89281.6", "wps": "288.7", "ups": "0.02", "wpb": "16384", "bsz": "8", "num_updates": "2", "lr": "4e-05", "gnorm": "19.777", "loss_scale": "4", "train_wall": "57", "gb_free": "9.3", "wall": "161"} +# 2021-03-08 12:33:12 | INFO | fairseq.trainer | NOTE: gradient overflow detected, ignoring gradient, setting loss scale to: 2.0 +# 2021-03-08 12:33:51 | INFO | fairseq.trainer | NOTE: gradient overflow detected, ignoring gradient, setting loss scale to: 1.0 +# 2021-03-08 12:34:45 | INFO | train_inner | {"epoch": 1, "update": 0.001, "loss": "25.22", "ppl": "3.90691e+07", "wps": "123.4", "ups": "0.01", "wpb": "16384", "bsz": "8", "num_updates": "3", "lr": "6e-05", "gnorm": "131.281", "loss_scale": "1", "train_wall": "133", "gb_free": "9.3", "wall": "294"} +# 2021-03-08 12:35:43 | INFO | train_inner | {"epoch": 1, "update": 0.001, "loss": "18.079", "ppl": "276809", "wps": "285.5", "ups": "0.02", "wpb": "16384", "bsz": "8", "num_updates": "4", "lr": "8e-05", "gnorm": "13.776", "loss_scale": "1", "train_wall": "57", "gb_free": "9.3", "wall": "351"} +# 2021-03-08 12:36:35 | INFO | train_inner | {"epoch": 1, "update": 0.001, "loss": "23.729", "ppl": "1.39088e+07", "wps": "316.7", "ups": "0.02", "wpb": "16384", "bsz": "8", "num_updates": "5", "lr": "0.0001", "gnorm": "72.774", "loss_scale": "1", "train_wall": "52", "gb_free": "9.3", "wall": "403"} +# 2021-03-08 12:37:28 | INFO | train_inner | {"epoch": 1, "update": 0.001, "loss": "20.429", "ppl": "1.41203e+06", "wps": "307.6", "ups": "0.02", "wpb": "16384", "bsz": "8", "num_updates": "6", "lr": "8e-05", "gnorm": "60.846", "loss_scale": "1", "train_wall": "53", "gb_free": "9.3", "wall": "456"} +# 2021-03-08 12:38:27 | INFO | train_inner | {"epoch": 1, "update": 0.001, "loss": "18.965", "ppl": "511684", "wps": "279.4", "ups": "0.02", "wpb": "16384", "bsz": "8", "num_updates": "7", "lr": "6e-05", "gnorm": "22.687", "loss_scale": "1", "train_wall": "59", "gb_free": "9.3", "wall": "515"} +# 2021-03-08 12:39:18 | INFO | train_inner | {"epoch": 1, "update": 0.001, "loss": "18.345", "ppl": "332887", "wps": "319.1", "ups": "0.02", "wpb": "16384", "bsz": "8", "num_updates": "8", "lr": "4e-05", "gnorm": "8.451", "loss_scale": "1", "train_wall": "51", "gb_free": "9.3", "wall": "566"} +# 2021-03-08 12:40:11 | INFO | train_inner | {"epoch": 1, "update": 0.002, "loss": "18.262", "ppl": "314336", "wps": "305.9", "ups": "0.02", "wpb": "16384", "bsz": "8", "num_updates": "9", "lr": "2e-05", "gnorm": "6.457", "loss_scale": "1", "train_wall": "54", "gb_free": "9.3", "wall": "620"} +# 2021-03-08 12:41:04 | INFO | train_inner | {"epoch": 1, "update": 0.002, "loss": "17.556", "ppl": "192686", "wps": "311.8", "ups": "0.02", "wpb": "16384", "bsz": "8", "num_updates": "10", "lr": "0", "gnorm": "5.796", "loss_scale": "1", "train_wall": "53", "gb_free": "9.3", "wall": "673"} +# 2021-03-08 12:41:04 | INFO | fairseq_cli.train | Stopping training due to num_updates: 10 >= max_update: 10 +# 2021-03-08 12:41:04 | INFO | fairseq_cli.train | begin validation on "valid" subset +# 2021-03-08 12:43:15 | INFO | valid | {"epoch": 1, "valid_loss": "17.953", "valid_ppl": "253807", "valid_wps": "1868.4", "valid_wpb": "15400.2", "valid_bsz": "7.6", "valid_num_updates": "10"} +# 2021-03-08 12:43:15 | INFO | fairseq_cli.train | end of epoch 1 (average epoch stats below) +# 2021-03-08 12:43:15 | INFO | train | {"epoch": 1, "train_loss": "19.351", "train_ppl": "668509", "train_wps": "210.9", "train_ups": "0.01", "train_wpb": "16384", "train_bsz": "8", "train_num_updates": "10", "train_lr": "0", "train_gnorm": "36.26", "train_loss_scale": "1", "train_train_wall": "667", "train_gb_free": "9.3", "train_wall": "804"} +# 2021-03-08 12:43:15 | INFO | fairseq_cli.train | done training in 798.6 seconds +``` + +### 13B params on 8 V100 GPUs (with full parameter + optimizer state sharding) + +FSDP can also shard the parameters and optimizer states across multiple GPUs, +reducing memory requirements significantly. On 8 GPUs, sharding enables +training the same 13B parameter model *without offloading the parameters to +CPU*. However, without CPU offloading we'd only be able to fit a batch size of +1 per GPU, which would cause training speed to suffer. + +We obtain the best performance on 8 GPUs by combining full sharding and CPU +offloading. The following command trains the same 13B parameter GPT-3 model as +before on 8 GPUs; training speed increases from ~310 -> ~3200 words per second. + +```bash +OMP_NUM_THREADS=20 CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \ + fairseq-train data-bin/wikitext-103-roberta-bpe-bin \ + --ddp-backend fully_sharded --fp16 --fp16-init-scale 4 \ + --cpu-offload --checkpoint-activations \ + --task language_modeling --tokens-per-sample 2048 --batch-size 8 \ + --arch transformer_lm_gpt3_13 \ + --optimizer cpu_adam --adam-betas "(0.9,0.98)" \ + --lr 0.0001 --lr-scheduler polynomial_decay --warmup-updates 5 --total-num-update 10 \ + --max-update 10 --no-save --log-format json --log-interval 1 + +# Example output: +# (...) +# 2021-03-08 18:04:09 | INFO | fairseq_cli.train | num. model params: 13,110,865,920 (num. trained: 13,110,865,920) +# (...) +# 2021-03-08 18:04:09 | INFO | fairseq_cli.train | training on 8 devices (GPUs/TPUs) +# 2021-03-08 18:04:09 | INFO | fairseq_cli.train | max tokens per GPU = None and batch size per GPU = 8 +# (...) +# Adam Optimizer #0 is created with AVX2 arithmetic capability. +# Config: alpha=0.000100, betas=(0.900000, 0.980000), weight_decay=0.000000, adam_w=1 +# (...) +# 2021-03-08 18:05:06 | INFO | train_inner | {"epoch": 1, "update": 0.001, "loss": "16.408", "ppl": "86945.6", "wps": "0", "ups": "0", "wpb": "131072", "bsz": "64", "num_updates": "1", "lr": "2e-05", "gnorm": "18.27", "loss_scale": "4", "train_wall": "47", "gb_free": "9.3", "wall": "56"} +# 2021-03-08 18:05:45 | INFO | train_inner | {"epoch": 1, "update": 0.002, "loss": "16.352", "ppl": "83644.3", "wps": "3283.4", "ups": "0.03", "wpb": "131072", "bsz": "64", "num_updates": "2", "lr": "4e-05", "gnorm": "18.411", "loss_scale": "4", "train_wall": "40", "gb_free": "9.3", "wall": "96"} +# 2021-03-08 18:06:21 | INFO | fairseq.trainer | NOTE: gradient overflow detected, ignoring gradient, setting loss scale to: 2.0 +# 2021-03-08 18:06:56 | INFO | fairseq.trainer | NOTE: gradient overflow detected, ignoring gradient, setting loss scale to: 1.0 +# 2021-03-08 18:07:37 | INFO | train_inner | {"epoch": 1, "update": 0.006, "loss": "23.682", "ppl": "1.34537e+07", "wps": "1176.6", "ups": "0.01", "wpb": "131072", "bsz": "64", "num_updates": "3", "lr": "6e-05", "gnorm": "119.682", "loss_scale": "1", "train_wall": "111", "gb_free": "9.3", "wall": "208"} +# 2021-03-08 18:08:18 | INFO | train_inner | {"epoch": 1, "update": 0.007, "loss": "18.988", "ppl": "519921", "wps": "3189.1", "ups": "0.02", "wpb": "131072", "bsz": "64", "num_updates": "4", "lr": "8e-05", "gnorm": "14.934", "loss_scale": "1", "train_wall": "41", "gb_free": "9.3", "wall": "249"} +# 2021-03-08 18:08:59 | INFO | train_inner | {"epoch": 1, "update": 0.008, "loss": "20.08", "ppl": "1.10798e+06", "wps": "3223.1", "ups": "0.02", "wpb": "131072", "bsz": "64", "num_updates": "5", "lr": "0.0001", "gnorm": "59.92", "loss_scale": "1", "train_wall": "41", "gb_free": "9.3", "wall": "289"} +# 2021-03-08 18:09:39 | INFO | train_inner | {"epoch": 1, "update": 0.009, "loss": "18.323", "ppl": "327980", "wps": "3256.6", "ups": "0.02", "wpb": "131072", "bsz": "64", "num_updates": "6", "lr": "8e-05", "gnorm": "37.425", "loss_scale": "1", "train_wall": "40", "gb_free": "9.3", "wall": "330"} +# 2021-03-08 18:10:20 | INFO | train_inner | {"epoch": 1, "update": 0.01, "loss": "17.264", "ppl": "157354", "wps": "3188.7", "ups": "0.02", "wpb": "131072", "bsz": "64", "num_updates": "7", "lr": "6e-05", "gnorm": "10.824", "loss_scale": "1", "train_wall": "41", "gb_free": "9.3", "wall": "371"} +# 2021-03-08 18:11:01 | INFO | train_inner | {"epoch": 1, "update": 0.011, "loss": "16.794", "ppl": "113647", "wps": "3230", "ups": "0.02", "wpb": "131072", "bsz": "64", "num_updates": "8", "lr": "4e-05", "gnorm": "5.616", "loss_scale": "1", "train_wall": "41", "gb_free": "9.3", "wall": "411"} +# 2021-03-08 18:11:39 | INFO | train_inner | {"epoch": 1, "update": 0.012, "loss": "16.706", "ppl": "106938", "wps": "3384", "ups": "0.03", "wpb": "131072", "bsz": "64", "num_updates": "9", "lr": "2e-05", "gnorm": "5.318", "loss_scale": "1", "train_wall": "39", "gb_free": "9.3", "wall": "450"} +# 2021-03-08 18:12:19 | INFO | train_inner | {"epoch": 1, "update": 0.013, "loss": "16.548", "ppl": "95796.2", "wps": "3274.4", "ups": "0.02", "wpb": "131072", "bsz": "64", "num_updates": "10", "lr": "0", "gnorm": "5.22", "loss_scale": "1", "train_wall": "40", "gb_free": "9.3", "wall": "490"} +# 2021-03-08 18:12:19 | INFO | fairseq_cli.train | Stopping training due to num_updates: 10 >= max_update: 10 +# 2021-03-08 18:12:19 | INFO | fairseq_cli.train | begin validation on "valid" subset +# 2021-03-08 18:12:45 | INFO | valid | {"epoch": 1, "valid_loss": "16.624", "valid_ppl": "101000", "valid_wps": "10855.9", "valid_wpb": "123202", "valid_bsz": "60.5", "valid_num_updates": "10"} +# 2021-03-08 18:12:45 | INFO | fairseq_cli.train | end of epoch 1 (average epoch stats below) +# 2021-03-08 18:12:45 | INFO | train | {"epoch": 1, "train_loss": "18.114", "train_ppl": "283776", "train_wps": "2567.8", "train_ups": "0.02", "train_wpb": "131072", "train_bsz": "64", "train_num_updates": "10", "train_lr": "0", "train_gnorm": "29.562", "train_loss_scale": "1", "train_train_wall": "480", "train_gb_free": "9.3", "train_wall": "516"} +# 2021-03-08 18:12:45 | INFO | fairseq_cli.train | done training in 509.9 seconds +``` diff --git a/fairseq/distributed/fully_sharded_data_parallel.py b/fairseq/distributed/fully_sharded_data_parallel.py index 9d74398325..9c290b3fda 100644 --- a/fairseq/distributed/fully_sharded_data_parallel.py +++ b/fairseq/distributed/fully_sharded_data_parallel.py @@ -43,6 +43,13 @@ def __init__(self, *args, use_sharded_state: bool = False, **kwargs): super().__init__(*args, **kwargs) self.use_sharded_state = use_sharded_state + @property + def unwrapped_module(self) -> torch.nn.Module: + if self.flatten_parameters: + return self.module.module + else: + return self.module + def state_dict(self, destination=None, prefix='', keep_vars=False): if self.use_sharded_state: return super().local_state_dict( @@ -94,7 +101,11 @@ def fsdp_enable_wrap(cfg: DistributedTrainingConfig, use_sharded_state: bool = F "compute_dtype": torch.float16 if cfg.fp16 else torch.float32, "bucket_cap_mb": cfg.bucket_cap_mb, } - with enable_wrap(use_sharded_state=use_sharded_state, **fsdp_config): + with enable_wrap( + wrapper_cls=FullyShardedDataParallel, + use_sharded_state=use_sharded_state, + **fsdp_config, + ): yield @@ -109,14 +120,13 @@ def fsdp_wrap(module, min_num_params: Optional[int] = None, **kwargs): """ try: from fairscale.nn import wrap - cls = FullyShardedDataParallel if min_num_params is not None: num_params = sum(p.numel() for p in module.parameters()) if num_params >= min_num_params: - return wrap(module, cls=cls, **kwargs) + return wrap(module, **kwargs) else: return module else: - return wrap(module, cls=cls, **kwargs) + return wrap(module, **kwargs) except ImportError: return module diff --git a/fairseq/models/fairseq_model.py b/fairseq/models/fairseq_model.py index d393c02ae6..171a8a40f1 100644 --- a/fairseq/models/fairseq_model.py +++ b/fairseq/models/fairseq_model.py @@ -29,9 +29,10 @@ def check_type(module, expected_type): if hasattr(module, "unwrapped_module"): - assert isinstance(module.unwrapped_module, expected_type) + assert isinstance(module.unwrapped_module, expected_type), \ + f"{type(module.unwrapped_module)} != {expected_type}" else: - assert isinstance(module, expected_type) + assert isinstance(module, expected_type), f"{type(module)} != {expected_type}" class BaseFairseqModel(nn.Module): diff --git a/fairseq/models/roberta/model.py b/fairseq/models/roberta/model.py index a2a40ba6e2..c79d4faf79 100644 --- a/fairseq/models/roberta/model.py +++ b/fairseq/models/roberta/model.py @@ -18,7 +18,7 @@ register_model, register_model_architecture, ) -from fairseq.models.transformer import TransformerEncoder +from fairseq.models.transformer import DEFAULT_MIN_PARAMS_TO_WRAP, TransformerEncoder from fairseq.modules import LayerNorm from fairseq.modules.quant_noise import quant_noise as apply_quant_noise_ from fairseq.modules.transformer_sentence_encoder import init_bert_params @@ -122,6 +122,11 @@ def add_args(parser): action="store_true", help="(re-)register and load heads when loading checkpoints", ) + parser.add_argument( + "--untie-weights-roberta", + action="store_true", + help="Untie weights between embeddings and classifiers in RoBERTa", + ) # args for "Reducing Transformer Depth on Demand with Structured Dropout" (Fan et al., 2019) parser.add_argument( "--encoder-layerdrop", @@ -157,17 +162,26 @@ def add_args(parser): default=0, help="scalar quantization noise and scalar quantization at training time", ) - parser.add_argument( - "--untie-weights-roberta", - action="store_true", - help="Untie weights between embeddings and classifiers in RoBERTa", - ) + # args for "Better Fine-Tuning by Reducing Representational Collapse" (Aghajanyan et al. 2020) parser.add_argument( "--spectral-norm-classification-head", action="store_true", default=False, help="Apply spectral normalization on the classification head", ) + # args for Fully Sharded Data Parallel (FSDP) training + parser.add_argument( + "--min-params-to-wrap", + type=int, + metavar="D", + default=DEFAULT_MIN_PARAMS_TO_WRAP, + help=( + "minimum number of params for a layer to be wrapped with FSDP() when " + "training with --ddp-backend=fully_sharded. Smaller values will " + "improve memory efficiency, but may make torch.distributed " + "communication less efficient due to smaller input sizes." + ) + ) @classmethod def build_model(cls, args, task): diff --git a/fairseq/models/transformer.py b/fairseq/models/transformer.py index a0a0b8dcd5..d39e9ec7ed 100644 --- a/fairseq/models/transformer.py +++ b/fairseq/models/transformer.py @@ -36,6 +36,9 @@ DEFAULT_MAX_TARGET_POSITIONS = 1024 +DEFAULT_MIN_PARAMS_TO_WRAP = int(1e8) + + @register_model("transformer") class TransformerModel(FairseqEncoderDecoderModel): """ @@ -191,6 +194,16 @@ def add_args(parser): help='block size of quantization noise at training time') parser.add_argument('--quant-noise-scalar', type=float, metavar='D', default=0, help='scalar quantization noise and scalar quantization at training time') + # args for Fully Sharded Data Parallel (FSDP) training + parser.add_argument( + '--min-params-to-wrap', type=int, metavar='D', default=DEFAULT_MIN_PARAMS_TO_WRAP, + help=( + 'minimum number of params for a layer to be wrapped with FSDP() when ' + 'training with --ddp-backend=fully_sharded. Smaller values will ' + 'improve memory efficiency, but may make torch.distributed ' + 'communication less efficient due to smaller input sizes.' + ) + ) # fmt: on @classmethod @@ -242,8 +255,11 @@ def build_model(cls, args, task): encoder = cls.build_encoder(args, src_dict, encoder_embed_tokens) decoder = cls.build_decoder(args, tgt_dict, decoder_embed_tokens) if not args.share_all_embeddings: - encoder = fsdp_wrap(encoder, min_num_params=1e8) - decoder = fsdp_wrap(decoder, min_num_params=1e8) + min_params_to_wrap = getattr( + args, "min_params_to_wrap", DEFAULT_MIN_PARAMS_TO_WRAP + ) + encoder = fsdp_wrap(encoder, min_num_params=min_params_to_wrap) + decoder = fsdp_wrap(decoder, min_num_params=min_params_to_wrap) return cls(args, encoder, decoder) @classmethod @@ -387,10 +403,16 @@ def __init__(self, args, dictionary, embed_tokens): def build_encoder_layer(self, args): layer = TransformerEncoderLayer(args) - if getattr(args, "checkpoint_activations", False): + checkpoint = getattr(args, "checkpoint_activations", False) + if checkpoint: offload_to_cpu = getattr(args, "offload_activations", False) layer = checkpoint_wrapper(layer, offload_to_cpu=offload_to_cpu) - layer = fsdp_wrap(layer, min_num_params=1e8) + # checkpointing requires alignment to FSDP wrap boundaries + min_params_to_wrap = ( + getattr(args, "min_params_to_wrap", DEFAULT_MIN_PARAMS_TO_WRAP) + if not checkpoint else 0 + ) + layer = fsdp_wrap(layer, min_num_params=min_params_to_wrap) return layer def forward_embedding( @@ -728,10 +750,16 @@ def __init__(self, args, dictionary, embed_tokens, no_encoder_attn=False): def build_decoder_layer(self, args, no_encoder_attn=False): layer = TransformerDecoderLayer(args, no_encoder_attn) - if getattr(args, "checkpoint_activations", False): + checkpoint = getattr(args, "checkpoint_activations", False) + if checkpoint: offload_to_cpu = getattr(args, "offload_activations", False) layer = checkpoint_wrapper(layer, offload_to_cpu=offload_to_cpu) - layer = fsdp_wrap(layer, min_num_params=1e8) + # checkpointing requires alignment to FSDP wrap boundaries + min_params_to_wrap = ( + getattr(args, "min_params_to_wrap", DEFAULT_MIN_PARAMS_TO_WRAP) + if not checkpoint else 0 + ) + layer = fsdp_wrap(layer, min_num_params=min_params_to_wrap) return layer def forward( diff --git a/fairseq/models/transformer_lm.py b/fairseq/models/transformer_lm.py index f12470d033..09c99b96f6 100644 --- a/fairseq/models/transformer_lm.py +++ b/fairseq/models/transformer_lm.py @@ -14,7 +14,9 @@ register_model, register_model_architecture, ) -from fairseq.models.transformer import Embedding, TransformerDecoder +from fairseq.models.transformer import ( + DEFAULT_MIN_PARAMS_TO_WRAP, Embedding, TransformerDecoder +) from fairseq.modules import AdaptiveInput, CharacterTokenEmbedder from omegaconf import II @@ -126,15 +128,6 @@ class TransformerLanguageModelConfig(FairseqDataclass): default=False, metadata={"help": "use learned positional embeddings in the decoder"}, ) - decoder_layerdrop: float = field( - default=0.0, metadata={"help": "LayerDrop probability for decoder"} - ) - decoder_layers_to_keep: Optional[str] = field( - default=None, - metadata={ - "help": "which layers to *keep* when pruning as a comma-separated list" - }, - ) layernorm_embedding: bool = field( default=False, metadata={"help": "add layernorm to embedding"} ) @@ -148,6 +141,17 @@ class TransformerLanguageModelConfig(FairseqDataclass): default=False, metadata={"help": "move checkpointed activations to CPU after they are used."}, ) + # config for "Reducing Transformer Depth on Demand with Structured Dropout" (Fan et al., 2019) + decoder_layerdrop: float = field( + default=0.0, metadata={"help": "LayerDrop probability for decoder"} + ) + decoder_layers_to_keep: Optional[str] = field( + default=None, + metadata={ + "help": "which layers to *keep* when pruning as a comma-separated list" + }, + ) + # config for Training with Quantization Noise for Extreme Model Compression ({Fan*, Stock*} et al., 2020) quant_noise_pq: float = field( default=0.0, metadata={"help": "iterative PQ quantization noise at training time"}, @@ -156,13 +160,25 @@ class TransformerLanguageModelConfig(FairseqDataclass): default=8, metadata={"help": "block size of quantization noise at training time"}, ) - # TODO common var add to parent quant_noise_scalar: float = field( default=0.0, metadata={ "help": "scalar quantization noise and scalar quantization at training time" }, ) + # config for Fully Sharded Data Parallel (FSDP) training + min_params_to_wrap: int = field( + default=DEFAULT_MIN_PARAMS_TO_WRAP, + metadata={ + "help": ( + "minimum number of params for a layer to be wrapped with FSDP() when " + "training with --ddp-backend=fully_sharded. Smaller values will " + "improve memory efficiency, but may make torch.distributed " + "communication less efficient due to smaller input sizes." + ) + } + ) + # options from other parts of the config add_bos_token: bool = II("task.add_bos_token") tokens_per_sample: int = II("task.tokens_per_sample") max_target_positions: Optional[int] = II("task.max_target_positions") @@ -289,7 +305,7 @@ def base_lm_architecture(args): args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff", None) args.adaptive_softmax_dropout = getattr(args, "adaptive_softmax_dropout", 0) args.adaptive_softmax_factor = getattr(args, "adaptive_softmax_factor", 4) - args.decoder_learned_pos = getattr(args, "decoder_learned_pos", False) + args.decoder_learned_pos = getattr(args, "decoder_learned_pos", True) args.activation_fn = getattr(args, "activation_fn", "relu") args.decoder_layerdrop = getattr(args, "decoder_layerdrop", 0) @@ -428,3 +444,84 @@ def transformer_lm_gpt2_big(args): args.attention_dropout = getattr(args, "attention_dropout", 0.1) args.activation_fn = getattr(args, "activation_fn", "gelu") base_lm_architecture(args) + + +def base_gpt3_architecture(args): + args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", args.decoder_embed_dim * 4) + args.decoder_learned_pos = getattr(args, "decoder_learned_pos", True) + args.dropout = getattr(args, "dropout", 0.0) + args.attention_dropout = getattr(args, "attention_dropout", 0.0) + args.activation_fn = getattr(args, "activation_fn", "gelu") + base_lm_architecture(args) + + +@register_model_architecture("transformer_lm", "transformer_lm_gpt3_small") +def transformer_lm_gpt3_small(args): + # 125M params + args.decoder_layers = getattr(args, "decoder_layers", 12) + args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 768) + args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 12) + base_gpt3_architecture(args) + + +@register_model_architecture("transformer_lm", "transformer_lm_gpt3_medium") +def transformer_lm_gpt3_medium(args): + # 350M params + args.decoder_layers = getattr(args, "decoder_layers", 24) + args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 1024) + args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 16) + base_gpt3_architecture(args) + + +@register_model_architecture("transformer_lm", "transformer_lm_gpt3_large") +def transformer_lm_gpt3_large(args): + # 760M params + args.decoder_layers = getattr(args, "decoder_layers", 24) + args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 1536) + args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 16) + base_gpt3_architecture(args) + + +@register_model_architecture("transformer_lm", "transformer_lm_gpt3_xl") +def transformer_lm_gpt3_xl(args): + # 1.3B params + args.decoder_layers = getattr(args, "decoder_layers", 24) + args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 2048) + args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 24) + base_gpt3_architecture(args) + + +@register_model_architecture("transformer_lm", "transformer_lm_gpt3_2_7") +def transformer_lm_gpt3_2_7(args): + # 2.7B params + args.decoder_layers = getattr(args, "decoder_layers", 32) + args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 2560) + args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 32) + base_gpt3_architecture(args) + + +@register_model_architecture("transformer_lm", "transformer_lm_gpt3_6_7") +def transformer_lm_gpt3_6_7(args): + # 6.7B params + args.decoder_layers = getattr(args, "decoder_layers", 32) + args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 4096) + args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 32) + base_gpt3_architecture(args) + + +@register_model_architecture("transformer_lm", "transformer_lm_gpt3_13") +def transformer_lm_gpt3_13(args): + # 13B params + args.decoder_layers = getattr(args, "decoder_layers", 40) + args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 5120) + args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 40) + base_gpt3_architecture(args) + + +@register_model_architecture("transformer_lm", "transformer_lm_gpt3_175") +def transformer_lm_gpt3_175(args): + # 175B params + args.decoder_layers = getattr(args, "decoder_layers", 96) + args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 12288) + args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 96) + base_gpt3_architecture(args) diff --git a/fairseq/trainer.py b/fairseq/trainer.py index 4d47d39897..9435558157 100644 --- a/fairseq/trainer.py +++ b/fairseq/trainer.py @@ -1017,15 +1017,17 @@ def set_num_updates(self, num_updates): def clip_grad_norm(self, clip_norm): def agg_norm_fn(total_norm): - if self.cfg.distributed_training.ddp_backend == "fully_sharded": - total_norm = total_norm ** 2 - if ( + if ( + self.cfg.distributed_training.ddp_backend == "fully_sharded" + and ( self.data_parallel_process_group is not None or torch.distributed.is_initialized() - ): - total_norm = distributed_utils.all_reduce( - total_norm.cuda(), group=self.data_parallel_process_group - ) + ) + ): + total_norm = total_norm.cuda().float() ** 2 + total_norm = distributed_utils.all_reduce( + total_norm, group=self.data_parallel_process_group + ) total_norm = total_norm ** 0.5 return total_norm From c6006678261bf5d52e2c744508b5ddd306cafebd Mon Sep 17 00:00:00 2001 From: Myle Ott Date: Tue, 9 Mar 2021 09:38:01 -0800 Subject: [PATCH 67/82] Update README for Fully Sharded Data Parallel (#3331) Summary: Pull Request resolved: https://github.com/pytorch/fairseq/pull/3331 Reviewed By: sshleifer Differential Revision: D26912554 Pulled By: myleott fbshipit-source-id: b45a161fbd52a12da13d7e011d562d35a5b5a1a7 --- .../fully_sharded_data_parallel/README.md | 137 ++++++++++-------- fairseq/models/roberta/model.py | 4 +- fairseq/models/transformer.py | 11 +- fairseq/models/transformer_lm.py | 4 +- fairseq/trainer.py | 29 ++-- 5 files changed, 104 insertions(+), 81 deletions(-) diff --git a/examples/fully_sharded_data_parallel/README.md b/examples/fully_sharded_data_parallel/README.md index bc98670968..d620f0e4f1 100644 --- a/examples/fully_sharded_data_parallel/README.md +++ b/examples/fully_sharded_data_parallel/README.md @@ -17,7 +17,7 @@ Compared to PyTorch DDP: FSDP is fully supported in fairseq via the following new arguments: * `--ddp-backend=fully_sharded`: enables full sharding via FSDP * `--cpu-offload`: offloads the optimizer state and FP32 model copy to CPU (combine with `--optimizer=cpu_adam`) -* `--no-reshard-after-forward`: increases training speed for some models and is similar to ZeRO stage 2 +* `--no-reshard-after-forward`: increases training speed for large models (1B+ params) and is similar to ZeRO stage 2 * other popular options (`--fp16`, `--update-freq`, `--checkpoint-activations`, `--offload-activations`, etc.) continue to work as normal
Limitations

@@ -59,11 +59,13 @@ CPU. In this setting, the optimizer step (Adam) happens on CPU. We also use the `--checkpoint-activations` feature (sometimes called [gradient checkpointing](https://pytorch.org/docs/stable/checkpoint.html)), which further saves memory in exchange for a small increase in computation. -Requirements: -- You'll need 32GB of GPU memory and 256GB of system memory. +**Requirements:** +- Install the latest master version of fairscale: `pip install git+https://github.com/facebookresearch/fairscale.git@master` +- You'll need 32GB of GPU memory and ~256GB of system memory to train the 13B param model. +- If you have less system memory, the 6.7B param model can be trained with ~128GB of system memory, just set `--arch transformer_lm_gpt3_6_7` - We use the CPU Adam optimizer from [DeepSpeed](https://github.com/microsoft/DeepSpeed), so you'll need to `pip install deepspeed` before running the command. -Some notes: +**Notes:** - The command will take ~5 minutes to start training, during which time it will appear to be hung, since randomly initializing 13B weights can be slow. - The `--cpu-offload` feature requires training in mixed precision (`--fp16`). - Tune the `OMP_NUM_THREADS` env variable for best performance with CPU offloading. @@ -79,48 +81,54 @@ OMP_NUM_THREADS=20 CUDA_VISIBLE_DEVICES=0 \ --optimizer cpu_adam --adam-betas "(0.9,0.98)" \ --lr 0.0001 --lr-scheduler polynomial_decay --warmup-updates 5 --total-num-update 10 \ --max-update 10 --no-save --log-format json --log-interval 1 +``` + +

Example output

-# Example output: -# (...) -# 2021-03-08 12:29:51 | INFO | fairseq_cli.train | num. model params: 13,110,865,920 (num. trained: 13,110,865,920) -# (...) -# 2021-03-08 12:29:51 | INFO | fairseq_cli.train | training on 1 devices (GPUs/TPUs) -# 2021-03-08 12:29:51 | INFO | fairseq_cli.train | max tokens per GPU = None and batch size per GPU = 8 -# (...) -# Adam Optimizer #0 is created with AVX2 arithmetic capability. -# Config: alpha=0.000100, betas=(0.900000, 0.980000), weight_decay=0.000000, adam_w=1 -# (...) -# 2021-03-08 12:31:36 | INFO | train_inner | {"epoch": 1, "update": 0.0, "loss": "16.475", "ppl": "91120.8", "wps": "0", "ups": "0", "wpb": "16384", "bsz": "8", "num_updates": "1", "lr": "2e-05", "gnorm": "20.751", "loss_scale": "4", "train_wall": "99", "gb_free": "9.3", "wall": "105"} -# 2021-03-08 12:32:33 | INFO | train_inner | {"epoch": 1, "update": 0.0, "loss": "16.446", "ppl": "89281.6", "wps": "288.7", "ups": "0.02", "wpb": "16384", "bsz": "8", "num_updates": "2", "lr": "4e-05", "gnorm": "19.777", "loss_scale": "4", "train_wall": "57", "gb_free": "9.3", "wall": "161"} -# 2021-03-08 12:33:12 | INFO | fairseq.trainer | NOTE: gradient overflow detected, ignoring gradient, setting loss scale to: 2.0 -# 2021-03-08 12:33:51 | INFO | fairseq.trainer | NOTE: gradient overflow detected, ignoring gradient, setting loss scale to: 1.0 -# 2021-03-08 12:34:45 | INFO | train_inner | {"epoch": 1, "update": 0.001, "loss": "25.22", "ppl": "3.90691e+07", "wps": "123.4", "ups": "0.01", "wpb": "16384", "bsz": "8", "num_updates": "3", "lr": "6e-05", "gnorm": "131.281", "loss_scale": "1", "train_wall": "133", "gb_free": "9.3", "wall": "294"} -# 2021-03-08 12:35:43 | INFO | train_inner | {"epoch": 1, "update": 0.001, "loss": "18.079", "ppl": "276809", "wps": "285.5", "ups": "0.02", "wpb": "16384", "bsz": "8", "num_updates": "4", "lr": "8e-05", "gnorm": "13.776", "loss_scale": "1", "train_wall": "57", "gb_free": "9.3", "wall": "351"} -# 2021-03-08 12:36:35 | INFO | train_inner | {"epoch": 1, "update": 0.001, "loss": "23.729", "ppl": "1.39088e+07", "wps": "316.7", "ups": "0.02", "wpb": "16384", "bsz": "8", "num_updates": "5", "lr": "0.0001", "gnorm": "72.774", "loss_scale": "1", "train_wall": "52", "gb_free": "9.3", "wall": "403"} -# 2021-03-08 12:37:28 | INFO | train_inner | {"epoch": 1, "update": 0.001, "loss": "20.429", "ppl": "1.41203e+06", "wps": "307.6", "ups": "0.02", "wpb": "16384", "bsz": "8", "num_updates": "6", "lr": "8e-05", "gnorm": "60.846", "loss_scale": "1", "train_wall": "53", "gb_free": "9.3", "wall": "456"} -# 2021-03-08 12:38:27 | INFO | train_inner | {"epoch": 1, "update": 0.001, "loss": "18.965", "ppl": "511684", "wps": "279.4", "ups": "0.02", "wpb": "16384", "bsz": "8", "num_updates": "7", "lr": "6e-05", "gnorm": "22.687", "loss_scale": "1", "train_wall": "59", "gb_free": "9.3", "wall": "515"} -# 2021-03-08 12:39:18 | INFO | train_inner | {"epoch": 1, "update": 0.001, "loss": "18.345", "ppl": "332887", "wps": "319.1", "ups": "0.02", "wpb": "16384", "bsz": "8", "num_updates": "8", "lr": "4e-05", "gnorm": "8.451", "loss_scale": "1", "train_wall": "51", "gb_free": "9.3", "wall": "566"} -# 2021-03-08 12:40:11 | INFO | train_inner | {"epoch": 1, "update": 0.002, "loss": "18.262", "ppl": "314336", "wps": "305.9", "ups": "0.02", "wpb": "16384", "bsz": "8", "num_updates": "9", "lr": "2e-05", "gnorm": "6.457", "loss_scale": "1", "train_wall": "54", "gb_free": "9.3", "wall": "620"} -# 2021-03-08 12:41:04 | INFO | train_inner | {"epoch": 1, "update": 0.002, "loss": "17.556", "ppl": "192686", "wps": "311.8", "ups": "0.02", "wpb": "16384", "bsz": "8", "num_updates": "10", "lr": "0", "gnorm": "5.796", "loss_scale": "1", "train_wall": "53", "gb_free": "9.3", "wall": "673"} -# 2021-03-08 12:41:04 | INFO | fairseq_cli.train | Stopping training due to num_updates: 10 >= max_update: 10 -# 2021-03-08 12:41:04 | INFO | fairseq_cli.train | begin validation on "valid" subset -# 2021-03-08 12:43:15 | INFO | valid | {"epoch": 1, "valid_loss": "17.953", "valid_ppl": "253807", "valid_wps": "1868.4", "valid_wpb": "15400.2", "valid_bsz": "7.6", "valid_num_updates": "10"} -# 2021-03-08 12:43:15 | INFO | fairseq_cli.train | end of epoch 1 (average epoch stats below) -# 2021-03-08 12:43:15 | INFO | train | {"epoch": 1, "train_loss": "19.351", "train_ppl": "668509", "train_wps": "210.9", "train_ups": "0.01", "train_wpb": "16384", "train_bsz": "8", "train_num_updates": "10", "train_lr": "0", "train_gnorm": "36.26", "train_loss_scale": "1", "train_train_wall": "667", "train_gb_free": "9.3", "train_wall": "804"} -# 2021-03-08 12:43:15 | INFO | fairseq_cli.train | done training in 798.6 seconds ``` +(...) +2021-03-08 12:29:51 | INFO | fairseq_cli.train | num. model params: 13,110,865,920 (num. trained: 13,110,865,920) +(...) +2021-03-08 12:29:51 | INFO | fairseq_cli.train | training on 1 devices (GPUs/TPUs) +2021-03-08 12:29:51 | INFO | fairseq_cli.train | max tokens per GPU = None and batch size per GPU = 8 +(...) +Adam Optimizer #0 is created with AVX2 arithmetic capability. +Config: alpha=0.000100, betas=(0.900000, 0.980000), weight_decay=0.000000, adam_w=1 +(...) +2021-03-08 12:31:36 | INFO | train_inner | {"epoch": 1, "update": 0.0, "loss": "16.475", "ppl": "91120.8", "wps": "0", "ups": "0", "wpb": "16384", "bsz": "8", "num_updates": "1", "lr": "2e-05", "gnorm": "20.751", "loss_scale": "4", "train_wall": "99", "gb_free": "9.3", "wall": "105"} +2021-03-08 12:32:33 | INFO | train_inner | {"epoch": 1, "update": 0.0, "loss": "16.446", "ppl": "89281.6", "wps": "288.7", "ups": "0.02", "wpb": "16384", "bsz": "8", "num_updates": "2", "lr": "4e-05", "gnorm": "19.777", "loss_scale": "4", "train_wall": "57", "gb_free": "9.3", "wall": "161"} +2021-03-08 12:33:12 | INFO | fairseq.trainer | NOTE: gradient overflow detected, ignoring gradient, setting loss scale to: 2.0 +2021-03-08 12:33:51 | INFO | fairseq.trainer | NOTE: gradient overflow detected, ignoring gradient, setting loss scale to: 1.0 +2021-03-08 12:34:45 | INFO | train_inner | {"epoch": 1, "update": 0.001, "loss": "25.22", "ppl": "3.90691e+07", "wps": "123.4", "ups": "0.01", "wpb": "16384", "bsz": "8", "num_updates": "3", "lr": "6e-05", "gnorm": "131.281", "loss_scale": "1", "train_wall": "133", "gb_free": "9.3", "wall": "294"} +2021-03-08 12:35:43 | INFO | train_inner | {"epoch": 1, "update": 0.001, "loss": "18.079", "ppl": "276809", "wps": "285.5", "ups": "0.02", "wpb": "16384", "bsz": "8", "num_updates": "4", "lr": "8e-05", "gnorm": "13.776", "loss_scale": "1", "train_wall": "57", "gb_free": "9.3", "wall": "351"} +2021-03-08 12:36:35 | INFO | train_inner | {"epoch": 1, "update": 0.001, "loss": "23.729", "ppl": "1.39088e+07", "wps": "316.7", "ups": "0.02", "wpb": "16384", "bsz": "8", "num_updates": "5", "lr": "0.0001", "gnorm": "72.774", "loss_scale": "1", "train_wall": "52", "gb_free": "9.3", "wall": "403"} +2021-03-08 12:37:28 | INFO | train_inner | {"epoch": 1, "update": 0.001, "loss": "20.429", "ppl": "1.41203e+06", "wps": "307.6", "ups": "0.02", "wpb": "16384", "bsz": "8", "num_updates": "6", "lr": "8e-05", "gnorm": "60.846", "loss_scale": "1", "train_wall": "53", "gb_free": "9.3", "wall": "456"} +2021-03-08 12:38:27 | INFO | train_inner | {"epoch": 1, "update": 0.001, "loss": "18.965", "ppl": "511684", "wps": "279.4", "ups": "0.02", "wpb": "16384", "bsz": "8", "num_updates": "7", "lr": "6e-05", "gnorm": "22.687", "loss_scale": "1", "train_wall": "59", "gb_free": "9.3", "wall": "515"} +2021-03-08 12:39:18 | INFO | train_inner | {"epoch": 1, "update": 0.001, "loss": "18.345", "ppl": "332887", "wps": "319.1", "ups": "0.02", "wpb": "16384", "bsz": "8", "num_updates": "8", "lr": "4e-05", "gnorm": "8.451", "loss_scale": "1", "train_wall": "51", "gb_free": "9.3", "wall": "566"} +2021-03-08 12:40:11 | INFO | train_inner | {"epoch": 1, "update": 0.002, "loss": "18.262", "ppl": "314336", "wps": "305.9", "ups": "0.02", "wpb": "16384", "bsz": "8", "num_updates": "9", "lr": "2e-05", "gnorm": "6.457", "loss_scale": "1", "train_wall": "54", "gb_free": "9.3", "wall": "620"} +2021-03-08 12:41:04 | INFO | train_inner | {"epoch": 1, "update": 0.002, "loss": "17.556", "ppl": "192686", "wps": "311.8", "ups": "0.02", "wpb": "16384", "bsz": "8", "num_updates": "10", "lr": "0", "gnorm": "5.796", "loss_scale": "1", "train_wall": "53", "gb_free": "9.3", "wall": "673"} +2021-03-08 12:41:04 | INFO | fairseq_cli.train | Stopping training due to num_updates: 10 >= max_update: 10 +2021-03-08 12:41:04 | INFO | fairseq_cli.train | begin validation on "valid" subset +2021-03-08 12:43:15 | INFO | valid | {"epoch": 1, "valid_loss": "17.953", "valid_ppl": "253807", "valid_wps": "1868.4", "valid_wpb": "15400.2", "valid_bsz": "7.6", "valid_num_updates": "10"} +2021-03-08 12:43:15 | INFO | fairseq_cli.train | end of epoch 1 (average epoch stats below) +2021-03-08 12:43:15 | INFO | train | {"epoch": 1, "train_loss": "19.351", "train_ppl": "668509", "train_wps": "210.9", "train_ups": "0.01", "train_wpb": "16384", "train_bsz": "8", "train_num_updates": "10", "train_lr": "0", "train_gnorm": "36.26", "train_loss_scale": "1", "train_train_wall": "667", "train_gb_free": "9.3", "train_wall": "804"} +2021-03-08 12:43:15 | INFO | fairseq_cli.train | done training in 798.6 seconds +``` + +

### 13B params on 8 V100 GPUs (with full parameter + optimizer state sharding) FSDP can also shard the parameters and optimizer states across multiple GPUs, -reducing memory requirements significantly. On 8 GPUs, sharding enables +reducing memory requirements significantly. On 8 x 32GB GPUs, sharding enables training the same 13B parameter model *without offloading the parameters to CPU*. However, without CPU offloading we'd only be able to fit a batch size of 1 per GPU, which would cause training speed to suffer. We obtain the best performance on 8 GPUs by combining full sharding and CPU offloading. The following command trains the same 13B parameter GPT-3 model as -before on 8 GPUs; training speed increases from ~310 -> ~3200 words per second. +before on 8 x 32GB V100 GPUs; training speed increases superlinearly from ~310 +words per second to ~3200 words per second. ```bash OMP_NUM_THREADS=20 CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \ @@ -132,33 +140,38 @@ OMP_NUM_THREADS=20 CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \ --optimizer cpu_adam --adam-betas "(0.9,0.98)" \ --lr 0.0001 --lr-scheduler polynomial_decay --warmup-updates 5 --total-num-update 10 \ --max-update 10 --no-save --log-format json --log-interval 1 +``` + +
Example output

-# Example output: -# (...) -# 2021-03-08 18:04:09 | INFO | fairseq_cli.train | num. model params: 13,110,865,920 (num. trained: 13,110,865,920) -# (...) -# 2021-03-08 18:04:09 | INFO | fairseq_cli.train | training on 8 devices (GPUs/TPUs) -# 2021-03-08 18:04:09 | INFO | fairseq_cli.train | max tokens per GPU = None and batch size per GPU = 8 -# (...) -# Adam Optimizer #0 is created with AVX2 arithmetic capability. -# Config: alpha=0.000100, betas=(0.900000, 0.980000), weight_decay=0.000000, adam_w=1 -# (...) -# 2021-03-08 18:05:06 | INFO | train_inner | {"epoch": 1, "update": 0.001, "loss": "16.408", "ppl": "86945.6", "wps": "0", "ups": "0", "wpb": "131072", "bsz": "64", "num_updates": "1", "lr": "2e-05", "gnorm": "18.27", "loss_scale": "4", "train_wall": "47", "gb_free": "9.3", "wall": "56"} -# 2021-03-08 18:05:45 | INFO | train_inner | {"epoch": 1, "update": 0.002, "loss": "16.352", "ppl": "83644.3", "wps": "3283.4", "ups": "0.03", "wpb": "131072", "bsz": "64", "num_updates": "2", "lr": "4e-05", "gnorm": "18.411", "loss_scale": "4", "train_wall": "40", "gb_free": "9.3", "wall": "96"} -# 2021-03-08 18:06:21 | INFO | fairseq.trainer | NOTE: gradient overflow detected, ignoring gradient, setting loss scale to: 2.0 -# 2021-03-08 18:06:56 | INFO | fairseq.trainer | NOTE: gradient overflow detected, ignoring gradient, setting loss scale to: 1.0 -# 2021-03-08 18:07:37 | INFO | train_inner | {"epoch": 1, "update": 0.006, "loss": "23.682", "ppl": "1.34537e+07", "wps": "1176.6", "ups": "0.01", "wpb": "131072", "bsz": "64", "num_updates": "3", "lr": "6e-05", "gnorm": "119.682", "loss_scale": "1", "train_wall": "111", "gb_free": "9.3", "wall": "208"} -# 2021-03-08 18:08:18 | INFO | train_inner | {"epoch": 1, "update": 0.007, "loss": "18.988", "ppl": "519921", "wps": "3189.1", "ups": "0.02", "wpb": "131072", "bsz": "64", "num_updates": "4", "lr": "8e-05", "gnorm": "14.934", "loss_scale": "1", "train_wall": "41", "gb_free": "9.3", "wall": "249"} -# 2021-03-08 18:08:59 | INFO | train_inner | {"epoch": 1, "update": 0.008, "loss": "20.08", "ppl": "1.10798e+06", "wps": "3223.1", "ups": "0.02", "wpb": "131072", "bsz": "64", "num_updates": "5", "lr": "0.0001", "gnorm": "59.92", "loss_scale": "1", "train_wall": "41", "gb_free": "9.3", "wall": "289"} -# 2021-03-08 18:09:39 | INFO | train_inner | {"epoch": 1, "update": 0.009, "loss": "18.323", "ppl": "327980", "wps": "3256.6", "ups": "0.02", "wpb": "131072", "bsz": "64", "num_updates": "6", "lr": "8e-05", "gnorm": "37.425", "loss_scale": "1", "train_wall": "40", "gb_free": "9.3", "wall": "330"} -# 2021-03-08 18:10:20 | INFO | train_inner | {"epoch": 1, "update": 0.01, "loss": "17.264", "ppl": "157354", "wps": "3188.7", "ups": "0.02", "wpb": "131072", "bsz": "64", "num_updates": "7", "lr": "6e-05", "gnorm": "10.824", "loss_scale": "1", "train_wall": "41", "gb_free": "9.3", "wall": "371"} -# 2021-03-08 18:11:01 | INFO | train_inner | {"epoch": 1, "update": 0.011, "loss": "16.794", "ppl": "113647", "wps": "3230", "ups": "0.02", "wpb": "131072", "bsz": "64", "num_updates": "8", "lr": "4e-05", "gnorm": "5.616", "loss_scale": "1", "train_wall": "41", "gb_free": "9.3", "wall": "411"} -# 2021-03-08 18:11:39 | INFO | train_inner | {"epoch": 1, "update": 0.012, "loss": "16.706", "ppl": "106938", "wps": "3384", "ups": "0.03", "wpb": "131072", "bsz": "64", "num_updates": "9", "lr": "2e-05", "gnorm": "5.318", "loss_scale": "1", "train_wall": "39", "gb_free": "9.3", "wall": "450"} -# 2021-03-08 18:12:19 | INFO | train_inner | {"epoch": 1, "update": 0.013, "loss": "16.548", "ppl": "95796.2", "wps": "3274.4", "ups": "0.02", "wpb": "131072", "bsz": "64", "num_updates": "10", "lr": "0", "gnorm": "5.22", "loss_scale": "1", "train_wall": "40", "gb_free": "9.3", "wall": "490"} -# 2021-03-08 18:12:19 | INFO | fairseq_cli.train | Stopping training due to num_updates: 10 >= max_update: 10 -# 2021-03-08 18:12:19 | INFO | fairseq_cli.train | begin validation on "valid" subset -# 2021-03-08 18:12:45 | INFO | valid | {"epoch": 1, "valid_loss": "16.624", "valid_ppl": "101000", "valid_wps": "10855.9", "valid_wpb": "123202", "valid_bsz": "60.5", "valid_num_updates": "10"} -# 2021-03-08 18:12:45 | INFO | fairseq_cli.train | end of epoch 1 (average epoch stats below) -# 2021-03-08 18:12:45 | INFO | train | {"epoch": 1, "train_loss": "18.114", "train_ppl": "283776", "train_wps": "2567.8", "train_ups": "0.02", "train_wpb": "131072", "train_bsz": "64", "train_num_updates": "10", "train_lr": "0", "train_gnorm": "29.562", "train_loss_scale": "1", "train_train_wall": "480", "train_gb_free": "9.3", "train_wall": "516"} -# 2021-03-08 18:12:45 | INFO | fairseq_cli.train | done training in 509.9 seconds ``` +(...) +2021-03-08 18:04:09 | INFO | fairseq_cli.train | num. model params: 13,110,865,920 (num. trained: 13,110,865,920) +(...) +2021-03-08 18:04:09 | INFO | fairseq_cli.train | training on 8 devices (GPUs/TPUs) +2021-03-08 18:04:09 | INFO | fairseq_cli.train | max tokens per GPU = None and batch size per GPU = 8 +(...) +Adam Optimizer #0 is created with AVX2 arithmetic capability. +Config: alpha=0.000100, betas=(0.900000, 0.980000), weight_decay=0.000000, adam_w=1 +(...) +2021-03-08 18:05:06 | INFO | train_inner | {"epoch": 1, "update": 0.001, "loss": "16.408", "ppl": "86945.6", "wps": "0", "ups": "0", "wpb": "131072", "bsz": "64", "num_updates": "1", "lr": "2e-05", "gnorm": "18.27", "loss_scale": "4", "train_wall": "47", "gb_free": "9.3", "wall": "56"} +2021-03-08 18:05:45 | INFO | train_inner | {"epoch": 1, "update": 0.002, "loss": "16.352", "ppl": "83644.3", "wps": "3283.4", "ups": "0.03", "wpb": "131072", "bsz": "64", "num_updates": "2", "lr": "4e-05", "gnorm": "18.411", "loss_scale": "4", "train_wall": "40", "gb_free": "9.3", "wall": "96"} +2021-03-08 18:06:21 | INFO | fairseq.trainer | NOTE: gradient overflow detected, ignoring gradient, setting loss scale to: 2.0 +2021-03-08 18:06:56 | INFO | fairseq.trainer | NOTE: gradient overflow detected, ignoring gradient, setting loss scale to: 1.0 +2021-03-08 18:07:37 | INFO | train_inner | {"epoch": 1, "update": 0.006, "loss": "23.682", "ppl": "1.34537e+07", "wps": "1176.6", "ups": "0.01", "wpb": "131072", "bsz": "64", "num_updates": "3", "lr": "6e-05", "gnorm": "119.682", "loss_scale": "1", "train_wall": "111", "gb_free": "9.3", "wall": "208"} +2021-03-08 18:08:18 | INFO | train_inner | {"epoch": 1, "update": 0.007, "loss": "18.988", "ppl": "519921", "wps": "3189.1", "ups": "0.02", "wpb": "131072", "bsz": "64", "num_updates": "4", "lr": "8e-05", "gnorm": "14.934", "loss_scale": "1", "train_wall": "41", "gb_free": "9.3", "wall": "249"} +2021-03-08 18:08:59 | INFO | train_inner | {"epoch": 1, "update": 0.008, "loss": "20.08", "ppl": "1.10798e+06", "wps": "3223.1", "ups": "0.02", "wpb": "131072", "bsz": "64", "num_updates": "5", "lr": "0.0001", "gnorm": "59.92", "loss_scale": "1", "train_wall": "41", "gb_free": "9.3", "wall": "289"} +2021-03-08 18:09:39 | INFO | train_inner | {"epoch": 1, "update": 0.009, "loss": "18.323", "ppl": "327980", "wps": "3256.6", "ups": "0.02", "wpb": "131072", "bsz": "64", "num_updates": "6", "lr": "8e-05", "gnorm": "37.425", "loss_scale": "1", "train_wall": "40", "gb_free": "9.3", "wall": "330"} +2021-03-08 18:10:20 | INFO | train_inner | {"epoch": 1, "update": 0.01, "loss": "17.264", "ppl": "157354", "wps": "3188.7", "ups": "0.02", "wpb": "131072", "bsz": "64", "num_updates": "7", "lr": "6e-05", "gnorm": "10.824", "loss_scale": "1", "train_wall": "41", "gb_free": "9.3", "wall": "371"} +2021-03-08 18:11:01 | INFO | train_inner | {"epoch": 1, "update": 0.011, "loss": "16.794", "ppl": "113647", "wps": "3230", "ups": "0.02", "wpb": "131072", "bsz": "64", "num_updates": "8", "lr": "4e-05", "gnorm": "5.616", "loss_scale": "1", "train_wall": "41", "gb_free": "9.3", "wall": "411"} +2021-03-08 18:11:39 | INFO | train_inner | {"epoch": 1, "update": 0.012, "loss": "16.706", "ppl": "106938", "wps": "3384", "ups": "0.03", "wpb": "131072", "bsz": "64", "num_updates": "9", "lr": "2e-05", "gnorm": "5.318", "loss_scale": "1", "train_wall": "39", "gb_free": "9.3", "wall": "450"} +2021-03-08 18:12:19 | INFO | train_inner | {"epoch": 1, "update": 0.013, "loss": "16.548", "ppl": "95796.2", "wps": "3274.4", "ups": "0.02", "wpb": "131072", "bsz": "64", "num_updates": "10", "lr": "0", "gnorm": "5.22", "loss_scale": "1", "train_wall": "40", "gb_free": "9.3", "wall": "490"} +2021-03-08 18:12:19 | INFO | fairseq_cli.train | Stopping training due to num_updates: 10 >= max_update: 10 +2021-03-08 18:12:19 | INFO | fairseq_cli.train | begin validation on "valid" subset +2021-03-08 18:12:45 | INFO | valid | {"epoch": 1, "valid_loss": "16.624", "valid_ppl": "101000", "valid_wps": "10855.9", "valid_wpb": "123202", "valid_bsz": "60.5", "valid_num_updates": "10"} +2021-03-08 18:12:45 | INFO | fairseq_cli.train | end of epoch 1 (average epoch stats below) +2021-03-08 18:12:45 | INFO | train | {"epoch": 1, "train_loss": "18.114", "train_ppl": "283776", "train_wps": "2567.8", "train_ups": "0.02", "train_wpb": "131072", "train_bsz": "64", "train_num_updates": "10", "train_lr": "0", "train_gnorm": "29.562", "train_loss_scale": "1", "train_train_wall": "480", "train_gb_free": "9.3", "train_wall": "516"} +2021-03-08 18:12:45 | INFO | fairseq_cli.train | done training in 509.9 seconds +``` + +

diff --git a/fairseq/models/roberta/model.py b/fairseq/models/roberta/model.py index c79d4faf79..5d2ed4902d 100644 --- a/fairseq/models/roberta/model.py +++ b/fairseq/models/roberta/model.py @@ -179,7 +179,9 @@ def add_args(parser): "minimum number of params for a layer to be wrapped with FSDP() when " "training with --ddp-backend=fully_sharded. Smaller values will " "improve memory efficiency, but may make torch.distributed " - "communication less efficient due to smaller input sizes." + "communication less efficient due to smaller input sizes. This option " + "is set to 0 (i.e., always wrap) when --checkpoint-activations or " + "--offload-activations are passed." ) ) diff --git a/fairseq/models/transformer.py b/fairseq/models/transformer.py index d39e9ec7ed..297807c31a 100644 --- a/fairseq/models/transformer.py +++ b/fairseq/models/transformer.py @@ -201,7 +201,9 @@ def add_args(parser): 'minimum number of params for a layer to be wrapped with FSDP() when ' 'training with --ddp-backend=fully_sharded. Smaller values will ' 'improve memory efficiency, but may make torch.distributed ' - 'communication less efficient due to smaller input sizes.' + 'communication less efficient due to smaller input sizes. This option ' + 'is set to 0 (i.e., always wrap) when --checkpoint-activations or ' + '--offload-activations are passed.' ) ) # fmt: on @@ -258,6 +260,7 @@ def build_model(cls, args, task): min_params_to_wrap = getattr( args, "min_params_to_wrap", DEFAULT_MIN_PARAMS_TO_WRAP ) + # fsdp_wrap is a no-op when --ddp-backend != fully_sharded encoder = fsdp_wrap(encoder, min_num_params=min_params_to_wrap) decoder = fsdp_wrap(decoder, min_num_params=min_params_to_wrap) return cls(args, encoder, decoder) @@ -407,7 +410,8 @@ def build_encoder_layer(self, args): if checkpoint: offload_to_cpu = getattr(args, "offload_activations", False) layer = checkpoint_wrapper(layer, offload_to_cpu=offload_to_cpu) - # checkpointing requires alignment to FSDP wrap boundaries + # if we are checkpointing, enforce that FSDP always wraps the + # checkpointed layer, regardless of layer size min_params_to_wrap = ( getattr(args, "min_params_to_wrap", DEFAULT_MIN_PARAMS_TO_WRAP) if not checkpoint else 0 @@ -754,7 +758,8 @@ def build_decoder_layer(self, args, no_encoder_attn=False): if checkpoint: offload_to_cpu = getattr(args, "offload_activations", False) layer = checkpoint_wrapper(layer, offload_to_cpu=offload_to_cpu) - # checkpointing requires alignment to FSDP wrap boundaries + # if we are checkpointing, enforce that FSDP always wraps the + # checkpointed layer, regardless of layer size min_params_to_wrap = ( getattr(args, "min_params_to_wrap", DEFAULT_MIN_PARAMS_TO_WRAP) if not checkpoint else 0 diff --git a/fairseq/models/transformer_lm.py b/fairseq/models/transformer_lm.py index 09c99b96f6..fca9470e5e 100644 --- a/fairseq/models/transformer_lm.py +++ b/fairseq/models/transformer_lm.py @@ -174,7 +174,9 @@ class TransformerLanguageModelConfig(FairseqDataclass): "minimum number of params for a layer to be wrapped with FSDP() when " "training with --ddp-backend=fully_sharded. Smaller values will " "improve memory efficiency, but may make torch.distributed " - "communication less efficient due to smaller input sizes." + "communication less efficient due to smaller input sizes. This option " + "is set to 0 (i.e., always wrap) when --checkpoint-activations or " + "--offload-activations are passed." ) } ) diff --git a/fairseq/trainer.py b/fairseq/trainer.py index 9435558157..dcf5305455 100644 --- a/fairseq/trainer.py +++ b/fairseq/trainer.py @@ -1017,21 +1017,22 @@ def set_num_updates(self, num_updates): def clip_grad_norm(self, clip_norm): def agg_norm_fn(total_norm): - if ( - self.cfg.distributed_training.ddp_backend == "fully_sharded" - and ( - self.data_parallel_process_group is not None - or torch.distributed.is_initialized() - ) - ): - total_norm = total_norm.cuda().float() ** 2 - total_norm = distributed_utils.all_reduce( - total_norm, group=self.data_parallel_process_group - ) - total_norm = total_norm ** 0.5 - return total_norm + total_norm = total_norm.cuda().float() ** 2 + total_norm = distributed_utils.all_reduce( + total_norm, group=self.data_parallel_process_group + ) + return total_norm ** 0.5 - return self.optimizer.clip_grad_norm(clip_norm, aggregate_norm_fn=agg_norm_fn) + should_agg_norm = ( + self.cfg.distributed_training.ddp_backend == "fully_sharded" + and ( + self.data_parallel_process_group is not None + or torch.distributed.is_initialized() + ) + ) + return self.optimizer.clip_grad_norm( + clip_norm, aggregate_norm_fn=agg_norm_fn if should_agg_norm else None + ) def cumulative_training_time(self): if self._cumulative_training_time is None: From 05255f96410e5b1eaf3bf59b767d5b4b7e2c3a35 Mon Sep 17 00:00:00 2001 From: Changhan Wang Date: Tue, 9 Mar 2021 16:26:05 -0800 Subject: [PATCH 68/82] update audio_utils and fix mTEDx example Summary: update audio_utils and fix mTEDx example - Updated `audio_utils` - Added support for OGG Vorbis (the only supported lossy compressed format) - Added a separate `convert_to_mono()` helper function - Updated `get_waveform()` - added new arguments `frames` and `start` for reading part of audios - added new argument `mono` for auto conversion to mono-channel audio - unified returned waveform shape to channels x length (same as torchaudio default) - Updated mTEDx and MUST-C data prep scripts - Replaced `torchaudio.info()` with `soundfile.info()` (the latter is faster and the former has incompatible interface between <0.8 and the latest 0.8) - Replaced `torchaudio.load()` with `get_waveform` for auto conversion to mono channel Reviewed By: jmp84 Differential Revision: D26901114 fbshipit-source-id: fa9560c9714d51a91157d5141564574d4eee454d --- examples/speech_to_text/data_utils.py | 12 ++- examples/speech_to_text/docs/mtedx_example.md | 2 +- examples/speech_to_text/docs/mustc_example.md | 2 +- examples/speech_to_text/prep_mtedx_data.py | 13 +-- examples/speech_to_text/prep_mustc_data.py | 13 +-- fairseq/data/audio/audio_utils.py | 83 ++++++++++++++----- fairseq/data/audio/speech_to_text_dataset.py | 5 +- 7 files changed, 89 insertions(+), 41 deletions(-) diff --git a/examples/speech_to_text/data_utils.py b/examples/speech_to_text/data_utils.py index fa0d459611..3f96ffc427 100644 --- a/examples/speech_to_text/data_utils.py +++ b/examples/speech_to_text/data_utils.py @@ -14,7 +14,10 @@ import numpy as np import pandas as pd import sentencepiece as sp -from fairseq.data.audio.audio_utils import _get_kaldi_fbank, _get_torchaudio_fbank +from fairseq.data.audio.audio_utils import ( + _convert_to_mono, _get_kaldi_fbank, _get_torchaudio_fbank +) +import torch from tqdm import tqdm @@ -66,7 +69,7 @@ def gen_vocab( def extract_fbank_features( - waveform, + waveform: torch.FloatTensor, sample_rate: int, output_path: Optional[Path] = None, n_mel_bins: int = 80, @@ -75,8 +78,9 @@ def extract_fbank_features( if output_path is not None and output_path.is_file() and not overwrite: return - _waveform = waveform * (2 ** 15) # Kaldi compliance: 16-bit signed integers - _waveform = _waveform.squeeze().numpy() + _waveform = _convert_to_mono(waveform, sample_rate) + _waveform = _waveform * (2 ** 15) # Kaldi compliance: 16-bit signed integers + _waveform = _waveform.numpy() features = _get_kaldi_fbank(_waveform, sample_rate, n_mel_bins) if features is None: diff --git a/examples/speech_to_text/docs/mtedx_example.md b/examples/speech_to_text/docs/mtedx_example.md index c0e17db9a2..25b4556aff 100644 --- a/examples/speech_to_text/docs/mtedx_example.md +++ b/examples/speech_to_text/docs/mtedx_example.md @@ -11,7 +11,7 @@ with translations to a subset of 5 target languages. `${MTEDX_ROOT}/${LANG_PAIR}`, then preprocess it with ```bash # additional Python packages for S2T data processing/model training -pip install pandas torchaudio sentencepiece +pip install pandas torchaudio soundfile sentencepiece # Generate TSV manifests, features, vocabulary # and configuration for each language diff --git a/examples/speech_to_text/docs/mustc_example.md b/examples/speech_to_text/docs/mustc_example.md index 7628dc77ef..79df0aafdc 100644 --- a/examples/speech_to_text/docs/mustc_example.md +++ b/examples/speech_to_text/docs/mustc_example.md @@ -11,7 +11,7 @@ `${MUSTC_ROOT}/en-${TARGET_LANG_ID}`, then preprocess it with ```bash # additional Python packages for S2T data processing/model training -pip install pandas torchaudio sentencepiece +pip install pandas torchaudio soundfile sentencepiece # Generate TSV manifests, features, vocabulary # and configuration for each language diff --git a/examples/speech_to_text/prep_mtedx_data.py b/examples/speech_to_text/prep_mtedx_data.py index 6c37398fcc..34b1c398c8 100644 --- a/examples/speech_to_text/prep_mtedx_data.py +++ b/examples/speech_to_text/prep_mtedx_data.py @@ -14,7 +14,7 @@ from typing import Tuple import pandas as pd -import torchaudio +import soundfile as sf from examples.speech_to_text.data_utils import ( create_zip, extract_fbank_features, @@ -25,10 +25,12 @@ load_df_from_tsv, save_df_to_tsv, ) -from torch import Tensor +import torch from torch.utils.data import Dataset from tqdm import tqdm +from fairseq.data.audio.audio_utils import get_waveform + log = logging.getLogger(__name__) @@ -73,7 +75,7 @@ def __init__(self, root: str, lang: str, split: str) -> None: for wav_filename, _seg_group in groupby(segments, lambda x: x["wav"]): wav_filename = wav_filename.replace(".wav", ".flac") wav_path = wav_root / wav_filename - sample_rate = torchaudio.info(wav_path.as_posix())[0].rate + sample_rate = sf.info(wav_path.as_posix()).samplerate seg_group = sorted(_seg_group, key=lambda x: float(x["offset"])) for i, segment in enumerate(seg_group): offset = int(float(segment["offset"]) * sample_rate) @@ -93,9 +95,10 @@ def __init__(self, root: str, lang: str, split: str) -> None: ) ) - def __getitem__(self, n: int) -> Tuple[Tensor, int, str, str, str, str, str]: + def __getitem__(self, n: int) -> Tuple[torch.Tensor, int, str, str, str, str, str]: wav_path, offset, n_frames, sr, src_utt, tgt_utt, spk_id, tgt_lang, utt_id = self.data[n] - waveform, _ = torchaudio.load(wav_path, offset=offset, num_frames=n_frames) + waveform, _ = get_waveform(wav_path, frames=n_frames, start=offset) + waveform = torch.from_numpy(waveform) return waveform, sr, src_utt, tgt_utt, spk_id, tgt_lang, utt_id def __len__(self) -> int: diff --git a/examples/speech_to_text/prep_mustc_data.py b/examples/speech_to_text/prep_mustc_data.py index 45fd43533d..0ee204e651 100644 --- a/examples/speech_to_text/prep_mustc_data.py +++ b/examples/speech_to_text/prep_mustc_data.py @@ -15,7 +15,7 @@ import numpy as np import pandas as pd -import torchaudio +import soundfile as sf from examples.speech_to_text.data_utils import ( create_zip, extract_fbank_features, @@ -27,10 +27,12 @@ save_df_to_tsv, cal_gcmvn_stats, ) -from torch import Tensor +import torch from torch.utils.data import Dataset from tqdm import tqdm +from fairseq.data.audio.audio_utils import get_waveform + log = logging.getLogger(__name__) @@ -71,7 +73,7 @@ def __init__(self, root: str, lang: str, split: str) -> None: self.data = [] for wav_filename, _seg_group in groupby(segments, lambda x: x["wav"]): wav_path = wav_root / wav_filename - sample_rate = torchaudio.info(wav_path.as_posix())[0].rate + sample_rate = sf.info(wav_path.as_posix()).samplerate seg_group = sorted(_seg_group, key=lambda x: x["offset"]) for i, segment in enumerate(seg_group): offset = int(float(segment["offset"]) * sample_rate) @@ -90,9 +92,10 @@ def __init__(self, root: str, lang: str, split: str) -> None: ) ) - def __getitem__(self, n: int) -> Tuple[Tensor, int, str, str, str, str]: + def __getitem__(self, n: int) -> Tuple[torch.Tensor, int, str, str, str, str]: wav_path, offset, n_frames, sr, src_utt, tgt_utt, spk_id, utt_id = self.data[n] - waveform, _ = torchaudio.load(wav_path, offset=offset, num_frames=n_frames) + waveform, _ = get_waveform(wav_path, frames=n_frames, start=offset) + waveform = torch.from_numpy(waveform) return waveform, sr, src_utt, tgt_utt, spk_id, utt_id def __len__(self) -> int: diff --git a/fairseq/data/audio/audio_utils.py b/fairseq/data/audio/audio_utils.py index f8cc80f5e2..ddd5642c7e 100644 --- a/fairseq/data/audio/audio_utils.py +++ b/fairseq/data/audio/audio_utils.py @@ -1,35 +1,80 @@ -import os.path as op +from pathlib import Path from typing import BinaryIO, Optional, Tuple, Union import numpy as np +import torch + + +SF_AUDIO_FILE_EXTENSIONS = {".wav", ".flac", ".ogg"} + + +def _convert_to_mono( + waveform: torch.FloatTensor, sample_rate: int +) -> torch.FloatTensor: + if waveform.shape[0] > 1: + try: + import torchaudio.sox_effects as ta_sox + except ImportError: + raise ImportError( + "Please install torchaudio to convert multi-channel audios" + ) + effects = [['channels', '1']] + return ta_sox.apply_effects_tensor(waveform, sample_rate, effects)[0] + return waveform + + +def convert_to_mono(waveform: np.ndarray, sample_rate: int) -> np.ndarray: + if waveform.shape[0] > 1: + _waveform = torch.from_numpy(waveform) + return _convert_to_mono(_waveform, sample_rate).numpy() + return waveform def get_waveform( - path_or_fp: Union[str, BinaryIO], normalization=True + path_or_fp: Union[str, BinaryIO], normalization=True, mono=True, + frames=-1, start=0, always_2d=True ) -> Tuple[np.ndarray, int]: - """Get the waveform and sample rate of a 16-bit mono-channel WAV or FLAC. + """Get the waveform and sample rate of a 16-bit WAV/FLAC/OGG Vorbis audio. Args: path_or_fp (str or BinaryIO): the path or file-like object normalization (bool): Normalize values to [-1, 1] (Default: True) + mono (bool): convert multi-channel audio to mono-channel one + frames (int): the number of frames to read. (-1 for reading all) + start (int): Where to start reading. A negative value counts from the end. + always_2d (bool): always return 2D array even for mono-channel audios + Returns: + waveform (numpy.ndarray): 1D or 2D waveform (channels x length) + sample_rate (float): sample rate """ if isinstance(path_or_fp, str): - ext = op.splitext(op.basename(path_or_fp))[1] - if ext not in {".flac", ".wav"}: + ext = Path(path_or_fp).suffix + if ext not in SF_AUDIO_FILE_EXTENSIONS: raise ValueError(f"Unsupported audio format: {ext}") try: import soundfile as sf except ImportError: - raise ImportError("Please install soundfile to load WAV/FLAC file") + raise ImportError( + "Please install soundfile to load WAV/FLAC/OGG Vorbis audios" + ) - waveform, sample_rate = sf.read(path_or_fp, dtype="float32") + waveform, sample_rate = sf.read( + path_or_fp, dtype="float32", always_2d=True, frames=frames, start=start + ) + waveform = waveform.T # T x C -> C x T + if mono and waveform.shape[0] > 1: + waveform = convert_to_mono(waveform, sample_rate) if not normalization: waveform *= 2 ** 15 # denormalized to 16-bit signed integers + if not always_2d: + waveform = waveform.squeeze(axis=0) return waveform, sample_rate -def _get_kaldi_fbank(waveform, sample_rate, n_bins=80) -> Optional[np.ndarray]: +def _get_kaldi_fbank( + waveform: np.ndarray, sample_rate: int, n_bins=80 +) -> Optional[np.ndarray]: """Get mel-filter bank features via PyKaldi.""" try: from kaldi.feat.mel import MelBanksOptions @@ -45,27 +90,19 @@ def _get_kaldi_fbank(waveform, sample_rate, n_bins=80) -> Optional[np.ndarray]: opts.mel_opts = mel_opts opts.frame_opts = frame_opts fbank = Fbank(opts=opts) - features = fbank.compute(Vector(waveform), 1.0).numpy() + features = fbank.compute(Vector(waveform.squeeze()), 1.0).numpy() return features except ImportError: return None -def _get_torchaudio_fbank(waveform, sample_rate, n_bins=80) -> Optional[np.ndarray]: +def _get_torchaudio_fbank( + waveform: np.ndarray, sample_rate, n_bins=80 +) -> Optional[np.ndarray]: """Get mel-filter bank features via TorchAudio.""" try: - import torch import torchaudio.compliance.kaldi as ta_kaldi - import torchaudio.sox_effects as ta_sox - waveform = torch.from_numpy(waveform) - if len(waveform.shape) == 1: - # Mono channel: D -> 1 x D - waveform = waveform.unsqueeze(0) - else: - # Merge multiple channels to one: D x C -> 1 x D - waveform, _ = ta_sox.apply_effects_tensor(waveform.T, sample_rate, [['channels', '1']]) - features = ta_kaldi.fbank( waveform, num_mel_bins=n_bins, sample_frequency=sample_rate ) @@ -79,11 +116,11 @@ def get_fbank(path_or_fp: Union[str, BinaryIO], n_bins=80) -> np.ndarray: (faster CPP implementation) to TorchAudio (Python implementation). Note that Kaldi/TorchAudio requires 16-bit signed integers as inputs and hence the waveform should not be normalized.""" - sound, sample_rate = get_waveform(path_or_fp, normalization=False) + waveform, sample_rate = get_waveform(path_or_fp, normalization=False) - features = _get_kaldi_fbank(sound, sample_rate, n_bins) + features = _get_kaldi_fbank(waveform, sample_rate, n_bins) if features is None: - features = _get_torchaudio_fbank(sound, sample_rate, n_bins) + features = _get_torchaudio_fbank(waveform, sample_rate, n_bins) if features is None: raise ImportError( "Please install pyKaldi or torchaudio to enable " diff --git a/fairseq/data/audio/speech_to_text_dataset.py b/fairseq/data/audio/speech_to_text_dataset.py index 39d22c7a5e..c6c64db084 100644 --- a/fairseq/data/audio/speech_to_text_dataset.py +++ b/fairseq/data/audio/speech_to_text_dataset.py @@ -153,7 +153,8 @@ def get_features_or_waveform_from_uncompressed_zip( if is_npy_data(data): features_or_waveform = np.load(f) elif is_flac_or_wav_data(data): - features_or_waveform = get_waveform(f)[0] if need_waveform else get_fbank(f) + features_or_waveform = \ + get_waveform(f, always_2d=False)[0] if need_waveform else get_fbank(f) else: raise ValueError(f'Unknown file format for "{path}"') return features_or_waveform @@ -178,7 +179,7 @@ def get_features_or_waveform(path: str, need_waveform=False): if len(extra) == 0: if need_waveform: - return get_waveform(_path) + return get_waveform(_path, always_2d=False) return get_features_from_npy_or_audio(_path) elif len(extra) == 2: extra = [int(i) for i in extra] From d031611ce49cb231653cf9246667ac237cbbdaff Mon Sep 17 00:00:00 2001 From: Xutai Ma Date: Wed, 10 Mar 2021 20:32:49 -0800 Subject: [PATCH 69/82] Update simul trans doc (#1683) Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1683 Reviewed By: jmp84 Differential Revision: D26914869 Pulled By: xutaima fbshipit-source-id: a5d2efdcff1852e56304e77838840b3aad5124b0 --- .../docs/simulst_mustc_example.md | 39 +++++++++++++++---- 1 file changed, 31 insertions(+), 8 deletions(-) diff --git a/examples/speech_to_text/docs/simulst_mustc_example.md b/examples/speech_to_text/docs/simulst_mustc_example.md index 0144fcb766..3452806a1c 100644 --- a/examples/speech_to_text/docs/simulst_mustc_example.md +++ b/examples/speech_to_text/docs/simulst_mustc_example.md @@ -1,6 +1,6 @@ # Simultaneous Speech Translation (SimulST) on MuST-C -This is an instruction of training and evaluating a transformer *wait-k* simultaneous model on MUST-C English-Germen Dataset, from [SimulMT to SimulST: Adapting Simultaneous Text Translation to End-to-End Simultaneous Speech Translation](https://www.aclweb.org/anthology/2020.aacl-main.58.pdf). +This is a tutorial of training and evaluating a transformer *wait-k* simultaneous model on MUST-C English-Germen Dataset, from [SimulMT to SimulST: Adapting Simultaneous Text Translation to End-to-End Simultaneous Speech Translation](https://www.aclweb.org/anthology/2020.aacl-main.58.pdf). [MuST-C](https://www.aclweb.org/anthology/N19-1202) is multilingual speech-to-text translation corpus with 8-language translations on English TED talks. @@ -14,18 +14,21 @@ pip install pandas torchaudio sentencepiece # Generate TSV manifests, features, vocabulary, # global cepstral and mean estimation, # and configuration for each language +cd fairseq + python examples/speech_to_text/prep_mustc_data.py \ --data-root ${MUSTC_ROOT} --task asr \ --vocab-type unigram --vocab-size 10000 \ --cmvn-type global + python examples/speech_to_text/prep_mustc_data.py \ --data-root ${MUSTC_ROOT} --task st \ - --vocab-type unigram --vocab-size 10000 + --vocab-type unigram --vocab-size 10000 \ --cmvn-type global ``` ## ASR Pretraining -We just need a pretrained offline ASR model +We need a pretrained offline ASR model. Assuming the save directory of the ASR model is `${ASR_SAVE_DIR}` ``` fairseq-train ${MUSTC_ROOT}/en-de \ --config-yaml config_asr.yaml --train-subset train_asr --valid-subset dev_asr \ @@ -34,21 +37,22 @@ fairseq-train ${MUSTC_ROOT}/en-de \ --arch convtransformer_espnet --optimizer adam --lr 0.0005 --lr-scheduler inverse_sqrt \ --warmup-updates 10000 --clip-norm 10.0 --seed 1 --update-freq 8 ``` +A pretrained ASR checkpoint can be downloaded [here](https://dl.fbaipublicfiles.com/simultaneous_translation/must_c_v1_en_de_pretrained_asr) ## Simultaneous Speech Translation Training ### Wait-K with fixed pre-decision module Fixed pre-decision indicates that the model operate simultaneous policy on the boundaries of fixed chunks. Here is a example of fixed pre-decision ratio 7 (the simultaneous decision is made every 7 encoder states) and -a wait-3 policy model -``` +a wait-3 policy model. Assuming the save directory is `${ST_SAVE_DIR}` +```bash fairseq-train ${MUSTC_ROOT}/en-de \ --config-yaml config_st.yaml --train-subset train_st --valid-subset dev_st \ --save-dir ${ST_SAVE_DIR} --num-workers 8 \ --optimizer adam --lr 0.0001 --lr-scheduler inverse_sqrt --clip-norm 10.0 \ --criterion label_smoothed_cross_entropy \ --warmup-updates 4000 --max-update 100000 --max-tokens 40000 --seed 2 \ - --load-pretrained-encoder-from ${ASR_SAVE_DIR}/${CHECKPOINT_FILENAME} \ + --load-pretrained-encoder-from ${ASR_SAVE_DIR}/checkpoint_best.pt \ --task speech_to_text \ --arch convtransformer_simul_trans_espnet \ --simul-type waitk_fixed_pre_decision \ @@ -76,7 +80,9 @@ a wait-3 policy model The source file is a list of paths of audio files, while target file is the corresponding translations. ``` -pip install simuleval +git clone https://github.com/facebookresearch/SimulEval.git +cd SimulEval +pip install -e . simuleval \ --agent examples/speech_to_text/simultaneous_translation/agents/fairseq_simul_st_agent.py @@ -89,7 +95,24 @@ simuleval \ --scores ``` -A pretrained checkpoint can be downloaded from [here](https://dl.fbaipublicfiles.com/simultaneous_translation/convtransformer_wait5_pre7), which is a wait-5 model with a pre-decision of 280 ms. The databin (containing dictionary, gcmvn file and sentencepiece model) can be found [here](https://dl.fbaipublicfiles.com/simultaneous_translation/must_c_v1.0_en_de_databin). +A pretrained checkpoint can be downloaded from [here](https://dl.fbaipublicfiles.com/simultaneous_translation/convtransformer_wait5_pre7), which is a wait-5 model with a pre-decision of 280 ms. The databin (containing dictionary, gcmvn file and sentencepiece model) can be found [here](https://dl.fbaipublicfiles.com/simultaneous_translation/must_c_v1.0_en_de_databin.tgz). + +The output should be similar as follow: +```bash +{ + "Quality": { + "BLEU": 12.79214535384013 + }, + "Latency": { + "AL": 1669.5778120018108, + "AL_CA": 2077.9027656104813, + "AP": 0.7652936521983029, + "AP_CA": 0.8891561507382866, + "DAL": 2028.1566141735727, + "DAL_CA": 2497.336430059716 + } +} +``` The quality is measured by detokenized BLEU. So make sure that the predicted words sent to the server are detokenized. From 2235f86b40da5915cd801c4f2f29de4c17c9804b Mon Sep 17 00:00:00 2001 From: Sam Shleifer Date: Fri, 12 Mar 2021 12:29:40 -0800 Subject: [PATCH 70/82] PlasmaView: don't materialize array in memory (#1645) Summary: ### Changes: - `PlasmaArray` saves the underlying data to `self.array`, `PlasmaView` never does that, instead it fetches the data from `plasma_store` shared memory when it is needed. - `PlasmaArray` starts a new, ephemeral plasma_store and puts a new array in it when it is pickled. If `--use-plasma-view`, there is one server started before `spawn` and arrays are only put into it once, in `PlasmaArray.__init__` to accommodate this. - user can now pass `--plasma-path` to explicitly control where server is started. - We now make plasma keys based on `(split_path, (block_size, document_sep_len, str(break_mode), len(dataset)))`, so two jobs sharing plasma server but with different datasets, or same dataset but different clargs, will read each the other's array. ### Results [pre March 1] This saves some CPU memory (5-15%), according to both `psutil` and `psrecord`: here we run base_cmd (below) with num_workers=0,2,8, 2 GPUS and collect the logs. `branch` refers to `--use-plasma-view`, `master` uses `PlasmaArray` ``` +-------------------------+----------------+---------+-------+ | setting | cpu_mem_used | wps | ppl | +=========================+================+=========+=======+ | branch_nw0_gpu2_ddm.log | 12 | 55143.2 | 429.1 | +-------------------------+----------------+---------+-------+ | branch_nw2_gpu2_ddm.log | 13.67 | 43377.6 | 429.1 | +-------------------------+----------------+---------+-------+ | branch_nw8_gpu2_ddm.log | 18.36 | 53019.9 | 429.1 | +-------------------------+----------------+---------+-------+ | master_nw0_gpu2_ddm.log | 12.26 | 56733 | 429.1 | +-------------------------+----------------+---------+-------+ | master_nw2_gpu2_ddm.log | 14.58 | 53337.9 | 429.1 | +-------------------------+----------------+---------+-------+ | master_nw8_gpu2_ddm.log | 21.1 | 53217.2 | 429.1 | +-------------------------+----------------+---------+-------+ ``` ### Replication 1) get this branch ```bash git fetch && git checkout share-plasma-server ``` 2) Train tiny model and save logs ```bash base_cmd () { fairseq-train --fp16 /private/home/sshleifer/data-bin/stories_mmap \ --task language_modeling \ --arch transformer_lm_gpt2_tiny \ --sample-break-mode complete --tokens-per-sample 512 \ --optimizer adam --clip-norm 0.0 --lr 0.0005 \ --batch-size 1 \ --max-update 200 --max-epoch 1 \ --log-format simple --log-interval 100 \ --restore-file x.pt --no-save \ --skip-invalid-size-inputs-valid-test --disable-validation $@ } USE_LOCK=1 CUDA_VISIBLE_DEVICES=0,1 base_cmd --num-workers 0 --use-plasma-view | tee branch_nw0_gpu2_ddm.log ``` ### TODO: - [x] test larger dataset - [x] make it optional, cleanup - [x] 1 GPU - [x] unit-tests - [x] ask hashing Q on stackoverflow https://stackoverflow.com/questions/66354598/deterministic-method-to-hash-np-array-int - [ ] measure whether `PlasmaArray` disable for small array's logic helps - [ x] test with fb_sweep - [ x] measure 4 GPU savings Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1645 Test Plan: Read github PR description: https://github.com/fairinternal/fairseq-py/pull/1645 Reviewed By: myleott Differential Revision: D26630365 Pulled By: sshleifer fbshipit-source-id: b0c4163fbc97a7aefb116de70265fba11f6d7b42 --- fairseq/data/plasma_utils.py | 134 +++++++++++++++++++++++++--- fairseq/data/token_block_dataset.py | 67 +++++++++----- fairseq/dataclass/configs.py | 83 +++++++---------- fairseq/tasks/language_modeling.py | 11 ++- fairseq/trainer.py | 5 +- fairseq_cli/train.py | 9 +- tests/test_plasma_utils.py | 127 ++++++++++++++++++++++++++ 7 files changed, 343 insertions(+), 93 deletions(-) create mode 100644 tests/test_plasma_utils.py diff --git a/fairseq/data/plasma_utils.py b/fairseq/data/plasma_utils.py index f4bb6472d7..b9fab3b739 100644 --- a/fairseq/data/plasma_utils.py +++ b/fairseq/data/plasma_utils.py @@ -3,11 +3,23 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. + import subprocess +import json import tempfile +import hashlib +from typing import Hashable + +try: + import pyarrow.plasma as plasma + + PYARROW_AVAILABLE = True +except ImportError: + plasma = None + PYARROW_AVAILABLE = False -class PlasmaArray(object): +class PlasmaArray: """ Wrapper around numpy arrays that automatically moves the data to shared memory upon serialization. This is particularly helpful when passing numpy @@ -31,12 +43,7 @@ def __init__(self, array): @property def plasma(self): if self._plasma is None and not self.disable: - try: - import pyarrow.plasma as plasma - - self._plasma = plasma - except ImportError: - self._plasma = None + self._plasma = plasma return self._plasma def start_server(self): @@ -47,13 +54,7 @@ def start_server(self): self._server_tmp = tempfile.NamedTemporaryFile() self.path = self._server_tmp.name self._server = subprocess.Popen( - [ - "plasma_store", - "-m", - str(int(1.05 * self.array.nbytes)), - "-s", - self.path, - ] + ["plasma_store", "-m", str(int(1.05 * self.array.nbytes)), "-s", self.path] ) @property @@ -64,6 +65,7 @@ def client(self): return self._client def __getstate__(self): + """Called on pickle load""" if self.plasma is None: return self.__dict__ if self.object_id is None: @@ -78,6 +80,7 @@ def __getstate__(self): return state def __setstate__(self, state): + """Called on pickle save""" self.__dict__.update(state) if self.plasma is None: return @@ -89,3 +92,106 @@ def __del__(self): self._server = None self._server_tmp.close() self._server_tmp = None + + +DEFAULT_PLASMA_PATH = "/tmp/plasma" + + +class PlasmaView: + """Interface to write and read from shared memory. Whereas PlasmaArray writes to plasma on serialization, + PlasmaView writes to shared memory on instantiation.""" + + def __init__(self, array, split_path: str, hash_data: Hashable, plasma_path=None): + """ + Args: + array: numpy array to store. This can be read with ``PlasmaView().array`` + split_path: the path whence the data was read, used for hashing + hash_data: other metadata about the array that can be used to create a unique key. + as of writing, the 3 callers in ``TokenBlockDataset`` use:: + + hash_data = ((block_size, document_sep_len, str(break_mode), len(dataset)), 0|1|2) + + + """ + assert PYARROW_AVAILABLE + assert split_path is not None + if plasma_path is None: + plasma_path = DEFAULT_PLASMA_PATH + + self.path = plasma_path + self.split_path = split_path + self._client = None # Initialize lazily for pickle. plasma clients should not be deep copied or serialized. + self._n = None + + self.object_id = self.get_object_id(self.split_path, hash_data) + try: + self.client.put(array, object_id=self.object_id) + except plasma.PlasmaObjectExists: + pass + + @property + def client(self): + if self._client is None: + self._client = plasma.connect(self.path, num_retries=200) + return self._client + + @property + def array(self): + """Fetch a read only view of an np.array, stored in plasma.""" + ret = self.client.get(self.object_id) + return ret + + @staticmethod + def get_object_id(split_path: str, hash_data: Hashable): + """Returns plasma.ObjectID from hashing split_path and object_num.""" + hash = hashlib.blake2b(bytes(split_path, "utf-8"), digest_size=20) + harg = json.dumps(hash_data).encode("utf-8") + hash.update(harg) + return plasma.ObjectID(hash.digest()) + + def __getstate__(self): + """Called on pickle save""" + self.disconnect() + state = self.__dict__.copy() + assert state["_client"] is None + assert "object_id" in state + return state + + def __setstate__(self, state): + """Called on pickle load""" + self.__dict__.update(state) + + def __del__(self): + self.disconnect() + + def disconnect(self): + if self._client is not None: + self._client.disconnect() + self._client = None + + def __len__(self): + """Save reads by caching len""" + if self._n is None: + self._n = len(self.array) + return self._n + + +GB100 = (1024 ** 3) * 100 + + +class PlasmaStore: + def __init__(self, path=DEFAULT_PLASMA_PATH, nbytes: int = GB100): + + self.server = self.start(path, nbytes) + + def __del__(self): + self.server.kill() + + @staticmethod + def start(path=DEFAULT_PLASMA_PATH, nbytes: int = GB100) -> subprocess.Popen: + if not PYARROW_AVAILABLE: + raise ImportError("please run pip install pyarrow to use --use_plasma_view") + # best practice is to allocate more space than we need. The limitation seems to be the size of /dev/shm + _server = subprocess.Popen(["plasma_store", "-m", str(nbytes), "-s", path]) + plasma.connect(path, num_retries=200) # If we can't connect we fail immediately + return _server diff --git a/fairseq/data/token_block_dataset.py b/fairseq/data/token_block_dataset.py index ce0a0d1114..d2c65fd7e0 100644 --- a/fairseq/data/token_block_dataset.py +++ b/fairseq/data/token_block_dataset.py @@ -7,6 +7,7 @@ import torch from fairseq.data import FairseqDataset, plasma_utils from fairseq.data.indexed_dataset import best_fitting_int_dtype +from typing import Tuple class TokenBlockDataset(FairseqDataset): @@ -42,7 +43,46 @@ def __init__( break_mode=None, include_targets=False, document_sep_len=1, + use_plasma_view=False, + split_path=None, + plasma_path=None, ): + + super().__init__() + self.dataset = dataset + self.pad = pad + self.eos = eos + self.include_targets = include_targets + + assert len(dataset) > 0 + + assert len(dataset) == len(sizes) + _sizes, block_to_dataset_index, slice_indices = self._build_slice_indices( + sizes, break_mode, document_sep_len, block_size + ) + if use_plasma_view: + plasma_id = (block_size, document_sep_len, str(break_mode), len(dataset)) + self._slice_indices = plasma_utils.PlasmaView( + slice_indices, split_path, (plasma_id, 0), plasma_path=plasma_path + ) + self._sizes = plasma_utils.PlasmaView( + _sizes, split_path, (plasma_id, 1), plasma_path=plasma_path + ) + self._block_to_dataset_index = plasma_utils.PlasmaView( + block_to_dataset_index, split_path, (plasma_id, 2), plasma_path=plasma_path, + ) + else: + self._slice_indices = plasma_utils.PlasmaArray(slice_indices) + self._sizes = plasma_utils.PlasmaArray(_sizes) + self._block_to_dataset_index = plasma_utils.PlasmaArray( + block_to_dataset_index + ) + + @staticmethod + def _build_slice_indices( + sizes, break_mode, document_sep_len, block_size + ) -> Tuple[np.ndarray]: + """Use token_block_utils_fast to build arrays for indexing into self.dataset""" try: from fairseq.data.token_block_utils_fast import ( _get_slice_indices_fast, @@ -54,15 +94,6 @@ def __init__( "or `python setup.py build_ext --inplace`" ) - super().__init__() - self.dataset = dataset - self.pad = pad - self.eos = eos - self.include_targets = include_targets - - assert len(dataset) == len(sizes) - assert len(dataset) > 0 - if isinstance(sizes, list): sizes = np.array(sizes, dtype=np.int64) else: @@ -79,7 +110,7 @@ def __init__( slice_indices = _get_slice_indices_fast( sizes, str(break_mode), block_size, document_sep_len ) - self._sizes = slice_indices[:, 1] - slice_indices[:, 0] + _sizes = slice_indices[:, 1] - slice_indices[:, 0] # build index mapping block indices to the underlying dataset indices if break_mode == "eos": @@ -99,15 +130,12 @@ def __init__( sizes, slice_indices, ) size_dtype = np.uint16 if block_size < 65535 else np.uint32 - slice_indices_dtype = best_fitting_int_dtype(slice_indices[-1].max()) - - self._slice_indices = plasma_utils.PlasmaArray( - slice_indices.astype(slice_indices_dtype) - ) - self._sizes = plasma_utils.PlasmaArray(self._sizes.astype(size_dtype)) - self._block_to_dataset_index = plasma_utils.PlasmaArray( - block_to_dataset_index.astype(slice_indices_dtype) - ) + num_tokens = slice_indices[-1].max() + slice_indices_dtype = best_fitting_int_dtype(num_tokens) + slice_indices = slice_indices.astype(slice_indices_dtype) + _sizes = _sizes.astype(size_dtype) + block_to_dataset_index = block_to_dataset_index.astype(slice_indices_dtype) + return _sizes, block_to_dataset_index, slice_indices @property def slice_indices(self): @@ -131,7 +159,6 @@ def __getitem__(self, index): buffer = torch.cat( [self.dataset[idx] for idx in range(start_ds_idx, end_ds_idx + 1)] ) - slice_s, slice_e = self.slice_indices[index] length = slice_e - slice_s s, e = start_offset, start_offset + length diff --git a/fairseq/dataclass/configs.py b/fairseq/dataclass/configs.py index 5d6aee157a..3c29be9197 100644 --- a/fairseq/dataclass/configs.py +++ b/fairseq/dataclass/configs.py @@ -104,15 +104,10 @@ class CommonConfig(FairseqDataclass): ) wandb_project: Optional[str] = field( default=None, - metadata={ - "help": "Weights and Biases project name to use for logging" - }, + metadata={"help": "Weights and Biases project name to use for logging"}, ) azureml_logging: Optional[bool] = field( - default=False, - metadata={ - "help": "Log scalars to AzureML context" - }, + default=False, metadata={"help": "Log scalars to AzureML context"}, ) seed: int = field( default=1, metadata={"help": "pseudo random number generator seed"} @@ -192,6 +187,15 @@ class CommonConfig(FairseqDataclass): "main method can return a value (useful for sweeps)" }, ) + use_plasma_view: bool = field( + default=False, metadata={"help": "Store indices and sizes in shared memory"} + ) + plasma_path: Optional[str] = field( + default="/tmp/plasma", + metadata={ + "help": "path to run plasma_store, defaults to /tmp/plasma. Paths outside /tmp tend to fail." + }, + ) @dataclass @@ -263,7 +267,7 @@ class DistributedTrainingConfig(FairseqDataclass): metadata={ "help": "kill the job if no progress is made in N seconds; " "set to -1 to disable" - } + }, ) broadcast_buffers: bool = field( default=False, @@ -360,16 +364,13 @@ class DistributedTrainingConfig(FairseqDataclass): tpu: bool = II("common.tpu") # configuration for --ddp-backend=fully_sharded no_reshard_after_forward: bool = field( - default=False, - metadata={"help": "don't reshard parameters after forward pass"}, + default=False, metadata={"help": "don't reshard parameters after forward pass"}, ) fp32_reduce_scatter: bool = field( - default=False, - metadata={"help": "reduce-scatter grads in FP32"}, + default=False, metadata={"help": "reduce-scatter grads in FP32"}, ) cpu_offload: bool = field( - default=False, - metadata={"help": "offload FP32 params to CPU"} + default=False, metadata={"help": "offload FP32 params to CPU"} ) @@ -665,12 +666,10 @@ class FairseqBMUFConfig(FairseqDataclass): @dataclass class GenerationConfig(FairseqDataclass): beam: int = field( - default=5, - metadata={"help": "beam size"}, + default=5, metadata={"help": "beam size"}, ) nbest: int = field( - default=1, - metadata={"help": "number of hypotheses to output"}, + default=1, metadata={"help": "number of hypotheses to output"}, ) max_len_a: float = field( default=0, @@ -685,24 +684,19 @@ class GenerationConfig(FairseqDataclass): }, ) min_len: int = field( - default=1, - metadata={"help": "minimum generation length"}, + default=1, metadata={"help": "minimum generation length"}, ) match_source_len: bool = field( - default=False, - metadata={"help": "generations should match the source length"}, + default=False, metadata={"help": "generations should match the source length"}, ) unnormalized: bool = field( - default=False, - metadata={"help": "compare unnormalized hypothesis scores"}, + default=False, metadata={"help": "compare unnormalized hypothesis scores"}, ) no_early_stop: bool = field( - default=False, - metadata={"help": "deprecated"}, + default=False, metadata={"help": "deprecated"}, ) no_beamable_mm: bool = field( - default=False, - metadata={"help": "don't use BeamableMM in attention layers"}, + default=False, metadata={"help": "don't use BeamableMM in attention layers"}, ) lenpen: float = field( default=1, @@ -724,12 +718,10 @@ class GenerationConfig(FairseqDataclass): }, ) sacrebleu: bool = field( - default=False, - metadata={"help": "score with sacrebleu"}, + default=False, metadata={"help": "score with sacrebleu"}, ) score_reference: bool = field( - default=False, - metadata={"help": "just score the reference translation"}, + default=False, metadata={"help": "just score the reference translation"}, ) prefix_size: int = field( default=0, @@ -763,12 +755,10 @@ class GenerationConfig(FairseqDataclass): }, ) temperature: float = field( - default=1.0, - metadata={"help": "temperature for generation"}, + default=1.0, metadata={"help": "temperature for generation"}, ) diverse_beam_groups: int = field( - default=-1, - metadata={"help": "number of groups for Diverse Beam Search"}, + default=-1, metadata={"help": "number of groups for Diverse Beam Search"}, ) diverse_beam_strength: float = field( default=0.5, @@ -787,16 +777,13 @@ class GenerationConfig(FairseqDataclass): }, ) print_step: bool = field( - default=False, - metadata={"help": "print steps"}, + default=False, metadata={"help": "print steps"}, ) lm_path: Optional[str] = field( - default=None, - metadata={"help": "path to lm checkpoint for lm fusion"}, + default=None, metadata={"help": "path to lm checkpoint for lm fusion"}, ) lm_weight: float = field( - default=0.0, - metadata={"help": "weight for lm probs for lm fusion"}, + default=0.0, metadata={"help": "weight for lm probs for lm fusion"}, ) # arguments for iterative refinement generator @@ -805,8 +792,7 @@ class GenerationConfig(FairseqDataclass): metadata={"help": "if > 0.0, it penalized early-stopping in decoding."}, ) iter_decode_max_iter: int = field( - default=10, - metadata={"help": "maximum iterations for iterative refinement."}, + default=10, metadata={"help": "maximum iterations for iterative refinement."}, ) iter_decode_force_max_iter: bool = field( default=False, @@ -833,8 +819,7 @@ class GenerationConfig(FairseqDataclass): }, ) retain_dropout: bool = field( - default=False, - metadata={"help": "Use dropout at inference time"}, + default=False, metadata={"help": "Use dropout at inference time"}, ) # temporarily set to Any until https://github.com/facebookresearch/hydra/issues/1117 is fixed # retain_dropout_modules: Optional[List[str]] = field( @@ -859,8 +844,7 @@ class GenerationConfig(FairseqDataclass): @dataclass class CommonEvalConfig(FairseqDataclass): path: Optional[str] = field( - default=None, - metadata={"help": "path(s) to model file(s), colon separated"}, + default=None, metadata={"help": "path(s) to model file(s), colon separated"}, ) post_process: Optional[str] = field( default=None, @@ -922,8 +906,7 @@ class InteractiveConfig(FairseqDataclass): }, ) input: str = field( - default="-", - metadata={"help": "file to read from; use - for stdin"}, + default="-", metadata={"help": "file to read from; use - for stdin"}, ) diff --git a/fairseq/tasks/language_modeling.py b/fairseq/tasks/language_modeling.py index 579bf69785..a3847733a1 100644 --- a/fairseq/tasks/language_modeling.py +++ b/fairseq/tasks/language_modeling.py @@ -91,6 +91,8 @@ class LanguageModelingConfig(FairseqDataclass): ) data_buffer_size: int = II("dataset.data_buffer_size") tpu: bool = II("common.tpu") + use_plasma_view: bool = II("common.use_plasma_view") + plasma_path: str = II("common.plasma_path") @register_task("language_modeling", dataclass=LanguageModelingConfig) @@ -198,13 +200,12 @@ def load_dataset( data_path = paths[(epoch - 1) % len(paths)] split_path = os.path.join(data_path, split) + # each process has its own copy of the raw data (likely to be an np.memmap) dataset = data_utils.load_indexed_dataset( split_path, self.dictionary, self.args.dataset_impl, combine=combine ) if dataset is None: - raise FileNotFoundError( - "Dataset not found: {} ({})".format(split, split_path) - ) + raise FileNotFoundError(f"Dataset not found: {split} ({split_path})") dataset = maybe_shorten_dataset( dataset, @@ -214,7 +215,6 @@ def load_dataset( self.args.tokens_per_sample, self.args.seed, ) - dataset = TokenBlockDataset( dataset, dataset.sizes, @@ -223,6 +223,9 @@ def load_dataset( eos=self.dictionary.eos(), break_mode=self.args.sample_break_mode, include_targets=True, + use_plasma_view=self.args.use_plasma_view, + split_path=split_path, + plasma_path=self.args.plasma_path, ) add_eos_for_other_targets = ( diff --git a/fairseq/trainer.py b/fairseq/trainer.py index dcf5305455..ee29ed65a8 100644 --- a/fairseq/trainer.py +++ b/fairseq/trainer.py @@ -1165,10 +1165,7 @@ def _all_gather_list_sync( return logging_outputs, extra_stats_to_sum def _fast_stat_sync_sum( - self, - logging_outputs: List[Dict[str, Any]], - *extra_stats_to_sum, - ignore=False, + self, logging_outputs: List[Dict[str, Any]], *extra_stats_to_sum, ignore=False, ): """ Sync logging outputs across workers. fast_stat_sync_sum is diff --git a/fairseq_cli/train.py b/fairseq_cli/train.py index d770e4e4ec..d618817e46 100644 --- a/fairseq_cli/train.py +++ b/fairseq_cli/train.py @@ -24,6 +24,7 @@ utils, ) from fairseq.data import iterators +from fairseq.data.plasma_utils import PlasmaStore from fairseq.dataclass.configs import FairseqConfig from fairseq.dataclass.utils import convert_namespace_to_omegaconf from fairseq.distributed import fsdp_enable_wrap, fsdp_wrap, utils as distributed_utils @@ -118,7 +119,6 @@ def main(cfg: FairseqConfig) -> None: trainer = Trainer(cfg, task, model, criterion, quantizer) else: trainer = MegatronTrainer(cfg, task, model, criterion) - logger.info( "training on {} devices (GPUs/TPUs)".format( cfg.distributed_training.distributed_world_size @@ -465,6 +465,10 @@ def cli_main( cfg = convert_namespace_to_omegaconf(args) + if cfg.common.use_plasma_view: + server = PlasmaStore(path=cfg.common.plasma_path) + logger.info(f"Started plasma server pid {server.server.pid} {cfg.common.plasma_path}") + if args.profile: with torch.cuda.profiler.profile(): with torch.autograd.profiler.emit_nvtx(): @@ -472,6 +476,9 @@ def cli_main( else: distributed_utils.call_main(cfg, main) + # if cfg.common.use_plasma_view: + # server.server.kill() + if __name__ == "__main__": cli_main() diff --git a/tests/test_plasma_utils.py b/tests/test_plasma_utils.py new file mode 100644 index 0000000000..5737530e3d --- /dev/null +++ b/tests/test_plasma_utils.py @@ -0,0 +1,127 @@ +import contextlib +import unittest +import tempfile +from io import StringIO + +import numpy as np + +from tests.test_binaries import train_language_model +from tests.utils import create_dummy_data, preprocess_lm_data + +try: + from pyarrow import plasma + from fairseq.data.plasma_utils import PlasmaView, PlasmaStore + + PYARROW_AVAILABLE = True +except ImportError: + PYARROW_AVAILABLE = False + +dummy_path = 'dummy' + + +@unittest.skipUnless(PYARROW_AVAILABLE, "") +class TestPlasmaView(unittest.TestCase): + def setUp(self) -> None: + self.tmp_file = tempfile.NamedTemporaryFile() # noqa: P201 + self.path = self.tmp_file.name + self.server = PlasmaStore.start(path=self.path) + self.client = plasma.connect(self.path, num_retries=10) + + def tearDown(self) -> None: + self.client.disconnect() + self.tmp_file.close() + self.server.kill() + + def test_two_servers_do_not_share_object_id_space(self): + data_server_1 = np.array([0, 1]) + data_server_2 = np.array([2, 3]) + server_2_path = self.path + with tempfile.NamedTemporaryFile() as server_1_path: + server = PlasmaStore.start(path=server_1_path.name, nbytes=10000) + arr1 = PlasmaView( + data_server_1, dummy_path, 1, plasma_path=server_1_path.name + ) + assert len(arr1.client.list()) == 1 + assert (arr1.array == data_server_1).all() + arr2 = PlasmaView(data_server_2, dummy_path, 1, plasma_path=server_2_path) + assert (arr2.array == data_server_2).all() + assert (arr1.array == data_server_1).all() + server.kill() + + def test_hash_collision(self): + data_server_1 = np.array([0, 1]) + data_server_2 = np.array([2, 3]) + arr1 = PlasmaView(data_server_1, dummy_path, 1, plasma_path=self.path) + assert len(arr1.client.list()) == 1 + arr2 = PlasmaView(data_server_2, dummy_path, 1, plasma_path=self.path) + assert len(arr1.client.list()) == 1 + assert len(arr2.client.list()) == 1 + assert (arr2.array == data_server_1).all() + # New hash key based on tuples + arr3 = PlasmaView( + data_server_2, dummy_path, (1, 12312312312, None), plasma_path=self.path + ) + assert ( + len(arr2.client.list()) == 2 + ), "No new object was created by using a novel hash key" + assert ( + arr3.object_id in arr2.client.list() + ), "No new object was created by using a novel hash key" + assert ( + arr3.object_id in arr3.client.list() + ), "No new object was created by using a novel hash key" + del arr3, arr2, arr1 + + @staticmethod + def _assert_view_equal(pv1, pv2): + np.testing.assert_array_equal(pv1.array, pv2.array) + + def test_putting_same_array_twice(self): + data = np.array([4, 4, 4]) + arr1 = PlasmaView(data, dummy_path, 1, plasma_path=self.path) + assert len(self.client.list()) == 1 + arr1b = PlasmaView( + data, dummy_path, 1, plasma_path=self.path + ) # should not change contents of store + arr1c = PlasmaView( + None, dummy_path, 1, plasma_path=self.path + ) # should not change contents of store + + assert len(self.client.list()) == 1 + self._assert_view_equal(arr1, arr1b) + self._assert_view_equal(arr1, arr1c) + PlasmaView( + data, dummy_path, 2, plasma_path=self.path + ) # new object id, adds new entry + assert len(self.client.list()) == 2 + + new_client = plasma.connect(self.path) + assert len(new_client.list()) == 2 # new client can access same objects + assert isinstance(arr1.object_id, plasma.ObjectID) + del arr1b + del arr1c + + def test_plasma_store_full_raises(self): + with tempfile.NamedTemporaryFile() as new_path: + server = PlasmaStore.start(path=new_path.name, nbytes=10000) + with self.assertRaises(plasma.PlasmaStoreFull): + # 2000 floats is more than 2000 bytes + PlasmaView( + np.random.rand(10000, 1), dummy_path, 1, plasma_path=new_path.name + ) + server.kill() + + def test_object_id_overflow(self): + PlasmaView.get_object_id("", 2 ** 21) + + def test_training_lm_plasma(self): + with contextlib.redirect_stdout(StringIO()): + with tempfile.TemporaryDirectory("test_transformer_lm") as data_dir: + create_dummy_data(data_dir) + preprocess_lm_data(data_dir) + train_language_model( + data_dir, + "transformer_lm", + ["--use-plasma-view", "--plasma-path", self.path], + run_validation=True, + ) From 252d5a9ae93e68254cfb1896fb5624cf11cda15e Mon Sep 17 00:00:00 2001 From: Xutai Ma Date: Fri, 12 Mar 2021 16:45:40 -0800 Subject: [PATCH 71/82] Fix a bug that FairseqSimulSTAgent is not an agent (#1690) Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1690 Reviewed By: jmp84 Differential Revision: D27025669 Pulled By: xutaima fbshipit-source-id: 8125365adedfdc938813d08e911e1f6ebe4f584b --- .../agents/fairseq_simul_st_agent.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/examples/speech_to_text/simultaneous_translation/agents/fairseq_simul_st_agent.py b/examples/speech_to_text/simultaneous_translation/agents/fairseq_simul_st_agent.py index 2b5fdc2d3f..8b8003e1d5 100644 --- a/examples/speech_to_text/simultaneous_translation/agents/fairseq_simul_st_agent.py +++ b/examples/speech_to_text/simultaneous_translation/agents/fairseq_simul_st_agent.py @@ -10,12 +10,11 @@ try: from simuleval import READ_ACTION, WRITE_ACTION, DEFAULT_EOS + from simuleval.agents import SpeechAgent from simuleval.states import ListEntry, SpeechStates except ImportError: print("Please install simuleval 'pip install simuleval'") -from torch import nn - SHIFT_SIZE = 10 WINDOW_SIZE = 25 SAMPLE_RATE = 16000 @@ -65,7 +64,7 @@ def __call__(self, new_samples): input_samples = samples[:effective_num_samples] self.previous_residual_samples = samples[ - num_frames * self.num_samples_per_shift : + num_frames * self.num_samples_per_shift: ] torch.manual_seed(1) @@ -113,12 +112,12 @@ def info(self): } -class FairseqSimulSTAgent(nn.Module): +class FairseqSimulSTAgent(SpeechAgent): speech_segment_size = 40 # in ms, 4 pooling ratio * 10 ms step size def __init__(self, args): - super().__init__() + super().__init__(args) self.eos = DEFAULT_EOS @@ -218,6 +217,9 @@ def load_model_vocab(self, args): task_args = state["cfg"]["task"] task_args.data = args.data_bin + if args.config is not None: + task_args.config_yaml = args.config + task = self.set_up_task(task_args) # build model for ensemble From 965240c784910895b05e66d7ef7e15321050b414 Mon Sep 17 00:00:00 2001 From: Alex Xiao Date: Sun, 14 Mar 2021 20:55:46 -0700 Subject: [PATCH 72/82] optimize memory when loading large checkpoints by deleting state dict early Summary: I had some issues with loading checkpoints from 5B parameter models (60 GB checkpoint files) due to OOM. Reviewed By: myleott Differential Revision: D27027616 fbshipit-source-id: 2b816e8e46ec80f0ec721aa7a6702cee531b94eb --- fairseq/trainer.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/fairseq/trainer.py b/fairseq/trainer.py index ee29ed65a8..1c4c532dd0 100644 --- a/fairseq/trainer.py +++ b/fairseq/trainer.py @@ -442,10 +442,14 @@ def load_checkpoint( self.model.load_state_dict( state["model"], strict=True, model_cfg=self.cfg.model ) + # save memory for later steps + del state["model"] if utils.has_parameters(self.get_criterion()): self.get_criterion().load_state_dict( state["criterion"], strict=True ) + del state["criterion"] + except Exception: raise Exception( "Cannot load model parameters from checkpoint {}; " From 4f833342bf8b6b51033b4d5db61faf677b05b57c Mon Sep 17 00:00:00 2001 From: Taylan Bilal Date: Fri, 12 Feb 2021 19:23:47 +0000 Subject: [PATCH 73/82] Improve tpu related utils. - xmp.spawn 8 or 1 processes instead of always 8. - util function to get the xla metrics report.o - util functions to move stuff to/from tpu. - make utils.item a no-op for xla. This is not critical on xla, and causes big performance hit. - util function to check if a tensor is on xla device. - util function to do torch.index_put efficiently on xla. dd --- fairseq/distributed/utils.py | 6 ++++-- fairseq/logging/metrics.py | 8 ++++++++ fairseq/utils.py | 36 ++++++++++++++++++++++++++++++++++++ 3 files changed, 48 insertions(+), 2 deletions(-) diff --git a/fairseq/distributed/utils.py b/fairseq/distributed/utils.py index 710ca18628..970b784915 100644 --- a/fairseq/distributed/utils.py +++ b/fairseq/distributed/utils.py @@ -281,7 +281,6 @@ def distributed_init(cfg: FairseqConfig): cfg.distributed_training.device_id = xm.get_local_ordinal() cfg.distributed_training.distributed_rank = xm.get_ordinal() xm.rendezvous("distributed_init") # wait for all workers - xm.mark_step() if is_master(cfg.distributed_training): logging.getLogger().setLevel(logging.INFO) @@ -357,7 +356,10 @@ def call_main(cfg: FairseqConfig, main, **kwargs): xmp.spawn( fn=distributed_main, args=(main, cfg, kwargs), - nprocs=8, # use all 8 TPU cores + # tpu-comment: + # 8 devices in one TPU VM, is the max processes to be spawned. + # The rest is driven by xm.distributed.xla_dist + nprocs=min(cfg.distributed_training.distributed_world_size, 8), ) else: # single GPU main diff --git a/fairseq/logging/metrics.py b/fairseq/logging/metrics.py index 7b56e31592..2bb1da086f 100644 --- a/fairseq/logging/metrics.py +++ b/fairseq/logging/metrics.py @@ -286,3 +286,11 @@ def load_state_dict(state_dict): for name, agg_state in state_dict.items(): _aggregators[name] = MetersDict() _aggregators[name].load_state_dict(agg_state) + + +def xla_metrics_report(): + try: + import torch_xla.debug.metrics as met + print(met.metrics_report()) + except ImportError: + return diff --git a/fairseq/utils.py b/fairseq/utils.py index d4bf73648b..90bb8369f2 100644 --- a/fairseq/utils.py +++ b/fairseq/utils.py @@ -110,6 +110,7 @@ def _move_to_cuda(tensor): def move_to_cpu(sample): + def _move_to_cpu(tensor): # PyTorch has poor support for half tensors (float16) on CPU. # Move any such tensors to float32. @@ -120,6 +121,17 @@ def _move_to_cpu(tensor): return apply_to_sample(_move_to_cpu, sample) +def move_to_tpu(sample): + + import torch_xla.core.xla_model as xm + device = xm.xla_device() + + def _move_to_tpu(tensor): + return tensor.to(device) + + return apply_to_sample(_move_to_tpu, sample) + + def get_incremental_state( module: MultiheadAttention, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]], @@ -289,6 +301,9 @@ def convert_padding_direction( def item(tensor): + # tpu-comment: making this a no-op for xla devices. + if torch.is_tensor(tensor) and tensor.device.type == 'xla': + return tensor.detach() if hasattr(tensor, "item"): return tensor.item() if hasattr(tensor, "__getitem__"): @@ -679,6 +694,27 @@ def tpu_data_loader(itr): ) +def is_xla_tensor(tensor): + return torch.is_tensor(tensor) and tensor.device.type == 'xla' + + +def index_put(tensor, indices, value): + if is_xla_tensor(tensor): + for _ in range(indices.dim(), tensor.dim()): + indices = indices.unsqueeze(-1) + if indices.size(-1) < tensor.size(-1): + indices = indices.expand_as(tensor) + tensor = torch.mul(tensor, ~indices) + torch.mul(value, indices) + else: + tensor[indices] = value + return tensor + + +def xla_device_to_cpu(dat): + import torch_xla.core.xla_model as xm + return xm._maybe_convert_to_cpu(dat) + + class CudaEnvironment(object): def __init__(self): cur_device = torch.cuda.current_device() From 46773af0f2081843ef02b98e21cac2d2781b3d5f Mon Sep 17 00:00:00 2001 From: Taylan Bilal Date: Fri, 12 Feb 2021 19:31:11 +0000 Subject: [PATCH 74/82] Improve train.py and trainer.py's tpu capabilities. - add util function to mark step and send a given tensor/container to cpu. - instead of 1 transfer per tensor (N total) in `logging_outputs`, we can do 1 transfer total. - remove redundant mark_step's - remove redundant compilation check on each device. XLA metrics are global even if they come from one device. - s/GPU/device/g --- fairseq/trainer.py | 27 ++++++++++++++++----------- fairseq_cli/train.py | 5 ++++- 2 files changed, 20 insertions(+), 12 deletions(-) diff --git a/fairseq/trainer.py b/fairseq/trainer.py index 1c4c532dd0..6c5ac3654a 100644 --- a/fairseq/trainer.py +++ b/fairseq/trainer.py @@ -532,6 +532,7 @@ def get_train_iterator( epoch=epoch, combine=combine, data_selector=data_selector, + tpu=self.tpu, ) batch_iterator = self.task.get_batch_iterator( dataset=self.task.dataset(self.cfg.dataset.train_subset), @@ -684,9 +685,7 @@ def maybe_no_sync(): # before marking step can lead to OOM errors. # To handle gradient accumulation use case, we explicitly # mark step here for every forward pass without a backward pass - import torch_xla.core.xla_model as xm - - xm.mark_step() + self._xla_markstep_and_send_to_cpu() if is_dummy_batch: if torch.is_tensor(sample_size): @@ -808,7 +807,6 @@ def maybe_no_sync(): if self.tpu: # mark step on TPUs import torch_xla.core.xla_model as xm - xm.mark_step() # only log stats every log_interval steps @@ -825,7 +823,7 @@ def maybe_no_sync(): metrics.log_scalar( "gb_total", gb_total, priority=1600, round=1, weight=0 ) - + logging_outputs = self._xla_markstep_and_send_to_cpu(logging_outputs) logging_output = self._reduce_and_log_stats( logging_outputs, sample_size, grad_norm ) @@ -878,9 +876,7 @@ def valid_step(self, sample, raise_oom=False): """Do forward pass in evaluation mode.""" if self.tpu: import torch_xla.core.xla_model as xm - xm.rendezvous("valid_step") # wait for all workers - xm.mark_step() with torch.no_grad(): self.model.eval() @@ -923,6 +919,8 @@ def valid_step(self, sample, raise_oom=False): ) # log validation stats + if self.tpu: + logging_outputs = self._xla_markstep_and_send_to_cpu(logging_outputs) logging_output = self._reduce_and_log_stats(logging_outputs, sample_size) return logging_output @@ -1285,6 +1283,8 @@ def _reduce_and_log_stats(self, logging_outputs, sample_size, grad_norm=None): return logging_output def _check_xla_compilation(self): + if self.cfg.distributed_training.distributed_rank: + return import torch_xla.debug.metrics as met compile_stats = met.metric_data("CompileTime") @@ -1293,13 +1293,18 @@ def _check_xla_compilation(self): num_xla_compiles = compile_stats[0] if num_xla_compiles > self._num_xla_compiles: logger.warning( - "XLA compilation detected on device #{}; too many of these can lead " - "to slow training, but we expect a few in the beginning".format( - self.cfg.distributed_training.distributed_rank - ) + "XLA compilation detected; too many of these can lead " + "to slow training, but we expect a few in the beginning" ) self._num_xla_compiles = num_xla_compiles + def _xla_markstep_and_send_to_cpu(self, data=None): + import torch_xla.core.xla_model as xm + xm.mark_step() + if data is not None: + from fairseq.utils import xla_device_to_cpu + return xla_device_to_cpu(data) + def _catalog_shared_params(module, memo=None, prefix=""): if memo is None: diff --git a/fairseq_cli/train.py b/fairseq_cli/train.py index d618817e46..d48da7dda5 100644 --- a/fairseq_cli/train.py +++ b/fairseq_cli/train.py @@ -125,7 +125,7 @@ def main(cfg: FairseqConfig) -> None: ) ) logger.info( - "max tokens per GPU = {} and batch size per GPU = {}".format( + "max tokens per device = {} and max sentences per device = {}".format( cfg.dataset.max_tokens, cfg.dataset.batch_size, ) @@ -139,6 +139,9 @@ def main(cfg: FairseqConfig) -> None: # don't cache epoch iterators for sharded datasets disable_iterator_cache=task.has_sharded_data("train"), ) + if cfg.common.tpu: + import torch_xla.core.xla_model as xm + xm.rendezvous("load_checkpoint") # wait for all workers max_epoch = cfg.optimization.max_epoch or math.inf lr = trainer.get_lr() From dbddbf71d3b7fffc8252620532e31b2c35031e47 Mon Sep 17 00:00:00 2001 From: Taylan Bilal Date: Fri, 12 Feb 2021 19:36:22 +0000 Subject: [PATCH 75/82] Adapt necessary fairseq_dataset's to support XLA. - XLA compiles every time it sees a new graph, this includes dynamic input shapes. - This commit introduces bucketing to raw_audio_dataset. - Tweaks bucket_pad_length_dataset and data_utils.py to enable this. - Computation of mask indices in wav2vec2's `forward` is costly on XLA. - Moving it to the data preparation phase, optionally for gpus, forced for tpus. --- fairseq/data/audio/raw_audio_dataset.py | 123 +++++++++++++++++++++- fairseq/data/bucket_pad_length_dataset.py | 46 ++++---- fairseq/data/data_utils.py | 22 ++++ 3 files changed, 167 insertions(+), 24 deletions(-) diff --git a/fairseq/data/audio/raw_audio_dataset.py b/fairseq/data/audio/raw_audio_dataset.py index 1d92e4966b..d0ff604e2b 100644 --- a/fairseq/data/audio/raw_audio_dataset.py +++ b/fairseq/data/audio/raw_audio_dataset.py @@ -12,7 +12,8 @@ import torch import torch.nn.functional as F -from .. import FairseqDataset +from .. import FairseqDataset, BaseWrapperDataset +from ..data_utils import compute_mask_indices, get_buckets, get_bucketed_sizes logger = logging.getLogger(__name__) @@ -27,6 +28,8 @@ def __init__( shuffle=True, pad=False, normalize=False, + compute_mask_indices=False, + **mask_compute_kwargs, ): super().__init__() @@ -39,6 +42,14 @@ def __init__( self.pad = pad self.shuffle = shuffle self.normalize = normalize + self.compute_mask_indices = compute_mask_indices + if self.compute_mask_indices: + self.mask_compute_kwargs = mask_compute_kwargs + self._features_size_map = {} + self._C = mask_compute_kwargs['encoder_embed_dim'] + self._conv_feature_layers = eval( + mask_compute_kwargs['conv_feature_layers'] + ) def __getitem__(self, index): raise NotImplementedError() @@ -70,6 +81,45 @@ def crop_to_max_size(self, wav, target_size): end = size - diff + start return wav[start:end] + def _compute_mask_indices(self, dims, padding_mask): + B, T, C = dims + mask_indices, mask_channel_indices = None, None + if self.mask_compute_kwargs['mask_prob'] > 0: + mask_indices = compute_mask_indices( + (B, T), + padding_mask, + self.mask_compute_kwargs['mask_prob'], + self.mask_compute_kwargs['mask_length'], + self.mask_compute_kwargs['mask_selection'], + self.mask_compute_kwargs['mask_other'], + min_masks=2, + no_overlap=self.mask_compute_kwargs['no_mask_overlap'], + min_space=self.mask_compute_kwargs['mask_min_space'], + ) + mask_indices = torch.from_numpy(mask_indices) + if self.mask_compute_kwargs['mask_channel_prob'] > 0: + mask_channel_indices = compute_mask_indices( + (B, C), + None, + self.mask_compute_kwargs['mask_channel_prob'], + self.mask_compute_kwargs['mask_channel_length'], + self.mask_compute_kwargs['mask_channel_selection'], + self.mask_compute_kwargs['mask_channel_other'], + no_overlap=self.mask_compute_kwargs['no_mask_channel_overlap'], + min_space=self.mask_compute_kwargs['mask_channel_min_space'], + ) + mask_channel_indices = ( + torch.from_numpy(mask_channel_indices) + .unsqueeze(1) + .expand(-1, T, -1) + ) + + return mask_indices, mask_channel_indices + + @staticmethod + def _bucket_tensor(tensor, num_pad, value): + return F.pad(tensor, (0, num_pad), value=value) + def collater(self, samples): samples = [s for s in samples if s["source"] is not None] if len(samples) == 0: @@ -101,9 +151,55 @@ def collater(self, samples): collated_sources[i] = self.crop_to_max_size(source, target_size) input = {"source": collated_sources} + out = {"id": torch.LongTensor([s["id"] for s in samples])} if self.pad: input["padding_mask"] = padding_mask - return {"id": torch.LongTensor([s["id"] for s in samples]), "net_input": input} + + if hasattr(self, 'num_buckets') and self.num_buckets > 0: + assert self.pad, "Cannot bucket without padding first." + bucket = max(self._bucketed_sizes[s['id']] for s in samples) + num_pad = bucket - collated_sources.size(-1) + if num_pad: + input['source'] = self._bucket_tensor( + collated_sources, num_pad, 0 + ) + input['padding_mask'] = self._bucket_tensor( + padding_mask, num_pad, True + ) + + if self.compute_mask_indices: + B = input['source'].size(0) + T = self._get_mask_indices_dims(input['source'].size(-1)) + padding_mask_reshaped = input['padding_mask'].clone() + extra = padding_mask_reshaped.size(1) % T + if extra > 0: + padding_mask_reshaped = padding_mask_reshaped[:, :-extra] + padding_mask_reshaped = padding_mask_reshaped.view( + padding_mask_reshaped.size(0), T, -1 + ) + padding_mask_reshaped = padding_mask_reshaped.all(-1) + input['padding_count'] = ( + padding_mask_reshaped.sum(-1).max().item() + ) + mask_indices, mask_channel_indices = self._compute_mask_indices( + (B, T, self._C), padding_mask_reshaped, + ) + input["mask_indices"] = mask_indices + input["mask_channel_indices"] = mask_channel_indices + out['sample_size'] = mask_indices.sum().item() + + out["net_input"] = input + return out + + def _get_mask_indices_dims(self, size, padding=0, dilation=1): + if size not in self._features_size_map: + L_in = size + for (_, kernel_size, stride) in self._conv_feature_layers: + L_out = L_in + 2*padding - dilation*(kernel_size-1) - 1 + L_out = 1 + L_out // stride + L_in = L_out + self._features_size_map[size] = L_out + return self._features_size_map[size] def num_tokens(self, index): return self.size(index) @@ -138,6 +234,9 @@ def __init__( shuffle=True, pad=False, normalize=False, + num_buckets=0, + compute_mask_indices=False, + **mask_compute_kwargs, ): super().__init__( sample_rate=sample_rate, @@ -146,6 +245,8 @@ def __init__( shuffle=shuffle, pad=pad, normalize=normalize, + compute_mask_indices=compute_mask_indices, + **mask_compute_kwargs, ) self.fnames = [] @@ -164,8 +265,26 @@ def __init__( self.fnames.append(items[0]) self.line_inds.add(i) self.sizes.append(sz) + self.set_bucket_info(num_buckets) logger.info(f"loaded {len(self.fnames)}, skipped {skipped} samples") + def set_bucket_info(self, num_buckets): + self.num_buckets = num_buckets + if self.num_buckets > 0: + self._collated_sizes = np.minimum( + np.array(self.sizes), self.max_sample_size, + ) + self.buckets = get_buckets( + self._collated_sizes, self.num_buckets, + ) + self._bucketed_sizes = get_bucketed_sizes( + self._collated_sizes, self.buckets + ) + logger.info( + f"{len(self.buckets)} bucket(s) for the audio dataset: " + f"{self.buckets}" + ) + def __getitem__(self, index): import soundfile as sf diff --git a/fairseq/data/bucket_pad_length_dataset.py b/fairseq/data/bucket_pad_length_dataset.py index cda8834ac8..0f94100148 100644 --- a/fairseq/data/bucket_pad_length_dataset.py +++ b/fairseq/data/bucket_pad_length_dataset.py @@ -6,6 +6,7 @@ import numpy as np import torch.nn.functional as F from fairseq.data import BaseWrapperDataset +from fairseq.data.data_utils import get_buckets, get_bucketed_sizes class BucketPadLengthDataset(BaseWrapperDataset): @@ -29,42 +30,43 @@ def __init__( num_buckets, pad_idx, left_pad, + tensor_key=None, ): super().__init__(dataset) self.pad_idx = pad_idx self.left_pad = left_pad assert num_buckets > 0 - self.buckets = np.unique( - np.percentile( - sizes, - np.linspace(0, 100, num_buckets + 1), - interpolation="lower", - )[1:] - ) + self.buckets = get_buckets(sizes, num_buckets) + self._bucketed_sizes = get_bucketed_sizes(sizes, self.buckets) + self._tensor_key = tensor_key - def get_bucketed_sizes(orig_sizes, buckets): - sizes = np.copy(orig_sizes) - assert np.min(sizes) >= 0 - start_val = -1 - for end_val in buckets: - mask = (sizes > start_val) & (sizes <= end_val) - sizes[mask] = end_val - start_val = end_val - return sizes + def _set_tensor(self, item, val): + if self._tensor_key is None: + return val + item[self._tensor_key] = val + return item - self._bucketed_sizes = get_bucketed_sizes(sizes, self.buckets) + def _get_tensor(self, item): + if self._tensor_key is None: + return item + return item[self._tensor_key] - def __getitem__(self, index): - item = self.dataset[index] - bucket_size = self._bucketed_sizes[index] - num_pad = bucket_size - item.size(-1) + def _pad(self, tensor, bucket_size, dim=-1): + num_pad = bucket_size - tensor.size(dim) return F.pad( - item, + tensor, (num_pad if self.left_pad else 0, 0 if self.left_pad else num_pad), value=self.pad_idx, ) + def __getitem__(self, index): + item = self.dataset[index] + bucket_size = self._bucketed_sizes[index] + tensor = self._get_tensor(item) + padded = self._pad(tensor, bucket_size) + return self._set_tensor(item, padded) + @property def sizes(self): return self._bucketed_sizes diff --git a/fairseq/data/data_utils.py b/fairseq/data/data_utils.py index 6f7561afbe..01c743c3e8 100644 --- a/fairseq/data/data_utils.py +++ b/fairseq/data/data_utils.py @@ -524,3 +524,25 @@ def lengths_to_padding_mask(lens: torch.LongTensor) -> torch.BoolTensor: def lengths_to_mask(lens: torch.LongTensor) -> torch.BoolTensor: return ~lengths_to_padding_mask(lens) + + +def get_buckets(sizes, num_buckets): + buckets = np.unique( + np.percentile( + sizes, + np.linspace(0, 100, num_buckets + 1), + interpolation='lower', + )[1:] + ) + return buckets + + +def get_bucketed_sizes(orig_sizes, buckets): + sizes = np.copy(orig_sizes) + assert np.min(sizes) >= 0 + start_val = -1 + for end_val in buckets: + mask = (sizes > start_val) & (sizes <= end_val) + sizes[mask] = end_val + start_val = end_val + return sizes From 10f8605b218240df5512cad6c6ef59024500731c Mon Sep 17 00:00:00 2001 From: Taylan Bilal Date: Fri, 12 Feb 2021 20:36:39 +0000 Subject: [PATCH 76/82] Make Wav2vec2 Criterion/Task/Model work well with XLA. - Use the util functions from previous commits in order to route the XLA codepath better. - In model - Compute mask_indices only if it's not pre-computed in data prep phase. - Remove the dynamicity in model's forward caused by mask_indices. - adjust loss computation in criterion accordingly. - Adjust sampling of negatives, by integrating the padding_count that comes from data prep phase. - future work; sampling of negatives could also be taken out of model and to the data prep phase. I experimented w/ this and observed speed gains. - Copy hydra params from model to task, in order for dataset's to have the necessary mask arguments to enable mask indices creation. --- fairseq/criterions/wav2vec_criterion.py | 65 +++++++---- fairseq/models/wav2vec/wav2vec2.py | 137 +++++++++++++++--------- fairseq/tasks/audio_pretraining.py | 60 ++++++++++- 3 files changed, 192 insertions(+), 70 deletions(-) diff --git a/fairseq/criterions/wav2vec_criterion.py b/fairseq/criterions/wav2vec_criterion.py index 859177f2b6..f682508cb1 100644 --- a/fairseq/criterions/wav2vec_criterion.py +++ b/fairseq/criterions/wav2vec_criterion.py @@ -31,7 +31,7 @@ class Wav2VecCriterionConfig(FairseqDataclass): default_factory=lambda: [], metadata={"help": "output keys to log"}, ) - +from fairseq.utils import index_put, is_xla_tensor @register_criterion("wav2vec", dataclass=Wav2VecCriterionConfig) class Wav2vecCriterion(FairseqCriterion): @@ -52,7 +52,9 @@ def forward(self, model, sample, reduce=True): net_output = model(**sample["net_input"]) logits = model.get_logits(net_output).float() target = model.get_targets(sample, net_output) + self.xla = is_xla_tensor(logits) + # XXX: handle weights on xla. weights = None if hasattr(model, "get_target_weights") and not self.infonce: weights = model.get_target_weights(target, net_output) @@ -61,21 +63,31 @@ def forward(self, model, sample, reduce=True): losses = [] + reduction = "none" if ((not reduce) or self.xla) else "sum" if self.infonce: - loss = F.cross_entropy( - logits, - target, - reduction="sum" if reduce else "none", - ) + loss = F.cross_entropy(logits, target, reduction=reduction) else: loss = F.binary_cross_entropy_with_logits( - logits, - target.float(), - weights, - reduction="sum" if reduce else "none", + logits, target.float(), weights, reduction=reduction + ) + + if self.xla: + # tpu-comment: since dynamic shapes lead to recompilations on xla, + # we don't shrink tensors using mask_indices. + # Instead, we use mask indices to adjust loss. + mi = ( + sample['net_input']['mask_indices'] + .transpose(0, 1) # logits are transposed in `model.get_logits` + .reshape(logits.size(0)) ) + loss = (loss * mi).sum() if reduce else (loss * mi) - sample_size = target.numel() if self.infonce else target.long().sum().item() + if 'sample_size' in sample and self.infonce: + sample_size = sample['sample_size'] + elif 'mask_indices' in sample['net_input']: + sample_size = sample['net_input']['mask_indices'].sum() + else: + sample_size = target.numel() if self.infonce else target.long().sum().item() losses.append(loss.detach().clone()) if self.loss_weights is not None: @@ -95,7 +107,7 @@ def forward(self, model, sample, reduce=True): losses.append(p) logging_output = { - "loss": loss.item() if reduce else loss, + "loss": loss.item() if (reduce and not self.xla) else loss.detach(), "ntokens": sample_size, "nsentences": sample["id"].numel(), "sample_size": sample_size, @@ -111,11 +123,14 @@ def forward(self, model, sample, reduce=True): if not self.training: logging_output["target"] = target.cpu().numpy() elif lk in net_output: - logging_output[lk] = float(net_output[lk]) + value = net_output[lk] + if not is_xla_tensor(value): + value = float(value) + logging_output[lk] = value if len(losses) > 1: for i, l in enumerate(losses): - logging_output[f"loss_{i}"] = l.item() + logging_output[f"loss_{i}"] = l.item() if not self.xla else l.detach() if self.infonce: with torch.no_grad(): @@ -126,9 +141,15 @@ def forward(self, model, sample, reduce=True): assert logits.dim() > 1, logits.shape max = logits.argmax(-1) == 0 min = logits.argmin(-1) == 0 - both = max & min - corr = max.long().sum().item() - both.long().sum().item() - count = max.numel() + if is_xla_tensor(logits): + max, min = max * mi, min * mi + both = max & min + corr = max.long().sum() - both.long().sum() + count = mi.sum() + else: + both = max & min + corr = max.long().sum().item() - both.long().sum().item() + count = float(max.numel()) logging_output["correct"] = corr logging_output["count"] = count @@ -188,11 +209,15 @@ def reduce_metrics(logging_outputs) -> None: else: metrics.log_scalar(k, val / len(logging_outputs), round=3) - @staticmethod - def logging_outputs_can_be_summed() -> bool: + # FIXME: revert when gather based xla reduction is implemented + #@staticmethod + #def logging_outputs_can_be_summed() -> bool: + def logging_outputs_can_be_summed(self) -> bool: """ Whether the logging outputs returned by `forward` can be summed across workers prior to calling `reduce_metrics`. Setting this to True will improves distributed training speed. """ - return False + # XXX: Gather based reduction not implemented for xla yet. + # So we fall to sum based reduction for xla. + return self.xla diff --git a/fairseq/models/wav2vec/wav2vec2.py b/fairseq/models/wav2vec/wav2vec2.py index 644add7b17..6999dca2d9 100644 --- a/fairseq/models/wav2vec/wav2vec2.py +++ b/fairseq/models/wav2vec/wav2vec2.py @@ -26,7 +26,7 @@ TransposeLast, ) from fairseq.modules.transformer_sentence_encoder import init_bert_params -from fairseq.utils import buffered_arange +from fairseq.utils import buffered_arange, index_put, is_xla_tensor EXTRACTOR_MODE_CHOICES = ChoiceEnum(["default", "layer_norm"]) @@ -330,47 +330,52 @@ def build_model(cls, cfg: Wav2Vec2Config, task=None): return cls(cfg) - def apply_mask(self, x, padding_mask): + def apply_mask( + self, x, padding_mask, + mask_indices=None, mask_channel_indices=None, + ): B, T, C = x.shape if self.mask_prob > 0: - mask_indices = compute_mask_indices( - (B, T), - padding_mask, - self.mask_prob, - self.mask_length, - self.mask_selection, - self.mask_other, - min_masks=2, - no_overlap=self.no_mask_overlap, - min_space=self.mask_min_space, - ) - mask_indices = torch.from_numpy(mask_indices).to(x.device) - x[mask_indices] = self.mask_emb + if mask_indices is None: + mask_indices = compute_mask_indices( + (B, T), + padding_mask, + self.mask_prob, + self.mask_length, + self.mask_selection, + self.mask_other, + min_masks=2, + no_overlap=self.no_mask_overlap, + min_space=self.mask_min_space, + ) + mask_indices = torch.from_numpy(mask_indices).to(x.device) + x = index_put(x, mask_indices, self.mask_emb) else: mask_indices = None if self.mask_channel_prob > 0: - mask_channel_indices = compute_mask_indices( - (B, C), - None, - self.mask_channel_prob, - self.mask_channel_length, - self.mask_channel_selection, - self.mask_channel_other, - no_overlap=self.no_mask_channel_overlap, - min_space=self.mask_channel_min_space, - ) - mask_channel_indices = ( - torch.from_numpy(mask_channel_indices) - .to(x.device) - .unsqueeze(1) - .expand(-1, T, -1) - ) - x[mask_channel_indices] = 0 + if mask_channel_indices is None: + mask_channel_indices = compute_mask_indices( + (B, C), + None, + self.mask_channel_prob, + self.mask_channel_length, + self.mask_channel_selection, + self.mask_channel_other, + no_overlap=self.no_mask_channel_overlap, + min_space=self.mask_channel_min_space, + ) + mask_channel_indices = ( + torch.from_numpy(mask_channel_indices) + .to(x.device) + .unsqueeze(1) + .expand(-1, T, -1) + ) + x = index_put(x, mask_channel_indices, 0) return x, mask_indices - def sample_negatives(self, y, num): + def sample_negatives(self, y, num, padding_count=None): if self.n_negatives == 0 and self.cross_sample_negatives == 0: return y.new(0) @@ -378,8 +383,9 @@ def sample_negatives(self, y, num): bsz, tsz, fsz = y.shape y = y.view(-1, fsz) # BTC => (BxT)C + # FIXME: what happens if padding_count is specified? cross_high = tsz * bsz - high = tsz + high = tsz - (padding_count or 0) with torch.no_grad(): assert high > 1, f"{bsz,tsz,fsz}" @@ -436,10 +442,17 @@ def compute_preds(self, x, y, negatives): logits = torch.cosine_similarity(x.float(), targets.float(), dim=-1).type_as(x) - logits /= self.logit_temp + logits = logits / self.logit_temp - if neg_is_pos.any(): - logits[1:][neg_is_pos] = float("-inf") + if is_xla_tensor(logits) or neg_is_pos.any(): + fillval = -float(2**30) + if not hasattr(self, '_inftensor'): + self._inftensor = ( + torch.tensor(fillval).to(x.device) + if is_xla_tensor(logits) else + float("-inf") + ) + logits[1:] = index_put(logits[1:], neg_is_pos, self._inftensor) return logits @@ -458,7 +471,11 @@ def _conv_out_length(input_length, kernel_size, stride): return input_lengths.to(torch.long) - def forward(self, source, padding_mask=None, mask=True, features_only=False): + def forward( + self, source, padding_mask=None, mask=True, features_only=False, + mask_indices=None, mask_channel_indices=None, + padding_count=None, + ): if self.feature_grad_mult > 0: features = self.feature_extractor(source) @@ -509,8 +526,14 @@ def forward(self, source, padding_mask=None, mask=True, features_only=False): features = self.project_inp(features) if mask: - x, mask_indices = self.apply_mask(features, padding_mask) - if mask_indices is not None: + x, mask_indices = self.apply_mask( + features, padding_mask, + mask_indices=mask_indices, + mask_channel_indices=mask_channel_indices, + ) + if not is_xla_tensor(x) and mask_indices is not None: + # tpu-comment: reducing the size in a dynamic way causes + # too many recompilations on xla. y = unmasked_features[mask_indices].view( unmasked_features.size(0), -1, unmasked_features.size(-1) ) @@ -537,12 +560,18 @@ def forward(self, source, padding_mask=None, mask=True, features_only=False): y = self.project_q(y) if self.negatives_from_everywhere: - neg_cands, *_ = self.quantizer(unmasked_features, produce_targets=False) - negs, _ = self.sample_negatives(neg_cands, y.size(1)) + neg_cands = self.quantizer( + unmasked_features, produce_targets=False + )["x"] + negs, _ = self.sample_negatives( + neg_cands, y.size(1), padding_count=padding_count, + ) negs = self.project_q(negs) else: - negs, _ = self.sample_negatives(y, y.size(1)) + negs, _ = self.sample_negatives( + y, y.size(1), padding_count=padding_count, + ) if self.codebook_negatives > 0: cb_negs = self.quantizer.sample_from_codebook( @@ -557,12 +586,20 @@ def forward(self, source, padding_mask=None, mask=True, features_only=False): y = self.project_q(y) if self.negatives_from_everywhere: - negs, _ = self.sample_negatives(unmasked_features, y.size(1)) + negs, _ = self.sample_negatives( + unmasked_features, y.size(1), + padding_count=padding_count, + ) negs = self.project_q(negs) else: - negs, _ = self.sample_negatives(y, y.size(1)) + negs, _ = self.sample_negatives( + y, y.size(1), padding_count=padding_count, + ) - x = x[mask_indices].view(x.size(0), -1, x.size(-1)) + if not is_xla_tensor(x): + # tpu-comment: reducing the size in a dynamic way causes + # too many recompilations on xla. + x = x[mask_indices].view(x.size(0), -1, x.size(-1)) if self.target_glu: y = self.target_glu(y) @@ -571,7 +608,9 @@ def forward(self, source, padding_mask=None, mask=True, features_only=False): x = self.final_proj(x) x = self.compute_preds(x, y, negs) - result = {"x": x, "padding_mask": padding_mask, "features_pen": features_pen} + result = { + "x": x, "padding_mask": padding_mask, "features_pen": features_pen, + } if prob_ppl is not None: result["prob_perplexity"] = prob_ppl @@ -759,11 +798,11 @@ def forward(self, x, padding_mask=None): def extract_features(self, x, padding_mask=None): if padding_mask is not None: - x[padding_mask] = 0 + x = index_put(x, padding_mask, 0) x_conv = self.pos_conv(x.transpose(1, 2)) x_conv = x_conv.transpose(1, 2) - x += x_conv + x = x + x_conv if not self.layer_norm_first: x = self.layer_norm(x) diff --git a/fairseq/tasks/audio_pretraining.py b/fairseq/tasks/audio_pretraining.py index b7b5429819..67317a410f 100644 --- a/fairseq/tasks/audio_pretraining.py +++ b/fairseq/tasks/audio_pretraining.py @@ -12,7 +12,7 @@ from argparse import Namespace from dataclasses import dataclass, field from typing import Optional, Any -from omegaconf import MISSING +from omegaconf import MISSING, II from fairseq.data import AddTargetDataset, Dictionary, FileAudioDataset, encoders from fairseq.dataclass import FairseqDataclass @@ -86,6 +86,37 @@ class AudioPretrainingConfig(FairseqDataclass): "adds 'prev_output_tokens' to input and appends eos to target" }, ) + num_batch_buckets: int = field( + default=0, + metadata={ + "help": "number of buckets" + }, + ) + precompute_mask_indices: bool = field( + default=False, + metadata={ + "help": "flag to compute mask indices in data preparation.", + }, + ) + # The following are needed to precompute mask and mask channel indices + # before model's forward. + mask_length: Optional[int] = II("model.mask_length") + mask_prob: Optional[float] = II("model.mask_prob") + mask_selection: Optional[str] = II("model.mask_selection") + mask_other: Optional[float] = II("model.mask_other") + no_mask_overlap: Optional[bool] = II("model.no_mask_overlap") + mask_min_space: Optional[int] = II("model.mask_min_space") + mask_channel_length: Optional[int] = II("model.mask_channel_length") + mask_channel_prob: Optional[float] = II("model.mask_channel_prob") + mask_channel_selection: Optional[str] = II("model.mask_channel_selection") + mask_channel_other: Optional[float] = II("model.mask_channel_other") + no_mask_channel_overlap: Optional[bool] = II("model.no_mask_channel_overlap") + mask_channel_min_space: Optional[int] = II("model.mask_channel_min_space") + + conv_feature_layers: Optional[str] = II("model.conv_feature_layers") + encoder_embed_dim: Optional[int] = II("model.encoder_embed_dim") + + tpu: bool = II("common.tpu") @register_task("audio_pretraining", dataclass=AudioPretrainingConfig) @@ -121,6 +152,28 @@ def load_target_dictionary(self): return Dictionary.load(dict_path) return None + def _get_mask_precompute_kwargs(self, cfg): + if self.cfg.precompute_mask_indices: + args = [ + 'mask_length', + 'mask_prob', + 'mask_selection', + 'mask_other', + 'no_mask_overlap', + 'mask_min_space', + 'mask_channel_length', + 'mask_channel_prob', + 'mask_channel_selection', + 'mask_channel_other', + 'no_mask_channel_overlap', + 'mask_channel_min_space', + 'encoder_embed_dim', + 'conv_feature_layers', + ] + return {arg: cfg[arg] for arg in args} + else: + return {} + def load_dataset(self, split: str, task_cfg: FairseqDataclass = None, **kwargs): data_path = self.cfg.data task_cfg = task_cfg or self.cfg @@ -138,6 +191,11 @@ def load_dataset(self, split: str, task_cfg: FairseqDataclass = None, **kwargs): min_sample_size=self.cfg.min_sample_size, pad=task_cfg.labels is not None or task_cfg.enable_padding, normalize=task_cfg.normalize, + num_buckets=self.cfg.num_batch_buckets or int(self.cfg.tpu), + compute_mask_indices=( + self.cfg.precompute_mask_indices or self.cfg.tpu + ), + **self._get_mask_precompute_kwargs(task_cfg), ) if task_cfg.labels: From ba7ba39ad8f0ff4c1627689bc121e6c73b112d0e Mon Sep 17 00:00:00 2001 From: Taylan Bilal Date: Fri, 12 Feb 2021 20:52:00 +0000 Subject: [PATCH 77/82] Pass params to model that pretraining task tries to copy from model. Per previous commit, audio_pretraining task tries to copy mask prepatarion related arguments to pass on to fairseq_dataset. For the downstream finetuning job, fairseq uses the same task, and even though the task arguments are optional, when it tries to copy from model and can't (for a GPU built model), it errors. Maybe there's a better way to do this in hydra, by passing a kwarg to `II`? --- fairseq/models/wav2vec/wav2vec2_asr.py | 24 ++++++++++++++++++++++-- 1 file changed, 22 insertions(+), 2 deletions(-) diff --git a/fairseq/models/wav2vec/wav2vec2_asr.py b/fairseq/models/wav2vec/wav2vec2_asr.py index afa51299b6..e8a1d03eb2 100644 --- a/fairseq/models/wav2vec/wav2vec2_asr.py +++ b/fairseq/models/wav2vec/wav2vec2_asr.py @@ -13,7 +13,7 @@ import torch.nn.functional as F from dataclasses import dataclass, field from omegaconf import MISSING, II, open_dict -from typing import Any +from typing import Optional, Any from fairseq import checkpoint_utils, tasks, utils from fairseq.dataclass import FairseqDataclass @@ -127,7 +127,27 @@ class Wav2Vec2AsrConfig(FairseqDataclass): @dataclass class Wav2Vec2CtcConfig(Wav2Vec2AsrConfig): - pass + mask_min_space: Optional[int] = field( + default=1, + metadata={"help": "min space between spans (if no overlap is enabled)"}, + ) + mask_channel_min_space: Optional[int] = field( + default=1, + metadata={"help": "min space between spans (if no overlap is enabled)"}, + ) + conv_feature_layers: Optional[str] = field( + default="[(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512,2,2)] + [(512,2,2)]", + metadata={ + "help": ( + "string describing convolutional feature extraction " + "layers in form of a python list that contains " + "[(dim, kernel_size, stride), ...]" + ), + }, + ) + encoder_embed_dim: Optional[int] = field( + default=768, metadata={"help": "encoder embedding dimension"} + ) @register_model("wav2vec_ctc", dataclass=Wav2Vec2CtcConfig) From 62a96aa16a54e5188e2f252dee5eaf90f29919cb Mon Sep 17 00:00:00 2001 From: Taylan Bilal Date: Mon, 15 Mar 2021 22:51:56 +0000 Subject: [PATCH 78/82] Add warning if mask_channel_prob is 0 on TPUs. --- fairseq/tasks/audio_pretraining.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/fairseq/tasks/audio_pretraining.py b/fairseq/tasks/audio_pretraining.py index 67317a410f..d68ef5cd6a 100644 --- a/fairseq/tasks/audio_pretraining.py +++ b/fairseq/tasks/audio_pretraining.py @@ -23,6 +23,9 @@ from ..logging import metrics +logger = logging.getLogger(__name__) + + class LabelEncoder(object): def __init__(self, dictionary): self.dictionary = dictionary @@ -198,6 +201,13 @@ def load_dataset(self, split: str, task_cfg: FairseqDataclass = None, **kwargs): **self._get_mask_precompute_kwargs(task_cfg), ) + if self.cfg.tpu and task_cfg['mask_channel_prob'] == 0.0: + logger.info( + "Pretraining on TPUs may suffer convergence " + "issues when training with `mask_channel_prob` value of " + "0. You may want to set this to a low value close to 0." + ) + if task_cfg.labels: label_path = os.path.join(data_path, f"{split}.{task_cfg.labels}") with open(label_path, "r") as f: From 388f420a095f00bb3a635a7cc36a8603799b0e44 Mon Sep 17 00:00:00 2001 From: Taylan Bilal Date: Tue, 16 Mar 2021 17:41:56 +0000 Subject: [PATCH 79/82] Add missing import. --- fairseq/tasks/audio_pretraining.py | 1 + 1 file changed, 1 insertion(+) diff --git a/fairseq/tasks/audio_pretraining.py b/fairseq/tasks/audio_pretraining.py index d68ef5cd6a..4ec3ef79ce 100644 --- a/fairseq/tasks/audio_pretraining.py +++ b/fairseq/tasks/audio_pretraining.py @@ -5,6 +5,7 @@ # the root directory of this source tree. An additional grant of patent rights # can be found in the PATENTS file in the same directory. +import logging import os import sys import torch From f2baa7e67e5cf014c123c56865bfe94200961d76 Mon Sep 17 00:00:00 2001 From: Taylan Bilal Date: Wed, 17 Mar 2021 19:55:41 +0000 Subject: [PATCH 80/82] Added content to README about tpus and examples. --- examples/wav2vec/README.md | 35 +++++++++ .../wav2vec2_large_librivox_tpu.yaml | 72 +++++++++++++++++++ 2 files changed, 107 insertions(+) create mode 100644 examples/wav2vec/config/pretraining/wav2vec2_large_librivox_tpu.yaml diff --git a/examples/wav2vec/README.md b/examples/wav2vec/README.md index e95f292b51..b2fa6e65fb 100644 --- a/examples/wav2vec/README.md +++ b/examples/wav2vec/README.md @@ -222,6 +222,41 @@ $ python train.py /manifest/path --save-dir /model/path --num-workers 6 --fp16 - --max-sample-size 150000 --max-tokens 1500000 --skip-invalid-size-inputs-valid-test ``` +### Run wav2vec2 pre-training on Google Cloud TPUs: + +Wav2Vec2 is now supported on TPUs! It's currently pre-training only. + +#### Using command line arguments on a v3-8: + +``` +$ OMP_NUM_THREADS=1 python train.py /manifest/path --save-dir /model/path --num-workers 6 --fp16 --max-update 400000 --save-interval 1 --no-epoch-checkpoints \ +--arch wav2vec2 --task audio_pretraining --min-lr 1e-06 --stop-min-lr 1e-09 --optimizer adam --lr 0.005 --lr-scheduler cosine \ +--conv-feature-layers [(512, 10, 5), (512, 8, 4), (512, 4, 2), (512, 4, 2), (512, 4, 2), (512, 1, 1), (512, 1, 1)] \ +--conv-aggregator-layers [(512, 2, 1), (512, 3, 1), (512, 4, 1), (512, 5, 1), (512, 6, 1), (512, 7, 1), (512, 8, 1), (512, 9, 1), (512, 10, 1), (512, 11, 1), (512, 12, 1), (512, 13, 1)] \ +--skip-connections-agg --residual-scale 0.5 --log-compression --warmup-updates 500 --warmup-init-lr 1e-07 --criterion wav2vec --num-negatives 10 \ +--max-sample-size 150000 --max-tokens 1500000 --skip-invalid-size-inputs-valid-test \ +--tpu --distributed-world-size 8 --num-batch-buckets 3 --enable-padding \ +--encoder-layerdrop 0 --mask-channel-prob 0.1 +``` + +#### Using command line arguments on a pod slice (v3-N with N > 8): + +Make sure to fill in the variables below. + +``` +$ python -m torch_xla.distributed.xla_dist \ + --tpu ${TPUNAME} --conda-env=torch-xla-${TORCH_XLA_VERSION} --env OMP_NUM_THREADS=1 \ + -- \ +python train.py /manifest/path --save-dir /model/path --num-workers 6 --fp16 --max-update 400000 --save-interval 1 --no-epoch-checkpoints \ +--arch wav2vec2 --task audio_pretraining --min-lr 1e-06 --stop-min-lr 1e-09 --optimizer adam --lr 0.005 --lr-scheduler cosine \ +--conv-feature-layers [(512, 10, 5), (512, 8, 4), (512, 4, 2), (512, 4, 2), (512, 4, 2), (512, 1, 1), (512, 1, 1)] \ +--conv-aggregator-layers [(512, 2, 1), (512, 3, 1), (512, 4, 1), (512, 5, 1), (512, 6, 1), (512, 7, 1), (512, 8, 1), (512, 9, 1), (512, 10, 1), (512, 11, 1), (512, 12, 1), (512, 13, 1)] \ +--skip-connections-agg --residual-scale 0.5 --log-compression --warmup-updates 500 --warmup-init-lr 1e-07 --criterion wav2vec --num-negatives 10 \ +--max-sample-size 150000 --max-tokens 1500000 --skip-invalid-size-inputs-valid-test \ +--tpu --distributed-world-size ${WORLD_SIZE} --num-batch-buckets 3 --enable-padding \ +--encoder-layerdrop 0 --mask-channel-prob 0.1 +``` + ### Extract embeddings from the downstream task data: ``` diff --git a/examples/wav2vec/config/pretraining/wav2vec2_large_librivox_tpu.yaml b/examples/wav2vec/config/pretraining/wav2vec2_large_librivox_tpu.yaml new file mode 100644 index 0000000000..5d1e026078 --- /dev/null +++ b/examples/wav2vec/config/pretraining/wav2vec2_large_librivox_tpu.yaml @@ -0,0 +1,72 @@ +# @package _group_ + +common: + tpu: true + fp16: true + log_format: json + log_interval: 200 + +checkpoint: + save_interval_updates: 25000 + keep_interval_updates: 1 + no_epoch_checkpoints: true + +task: + _name: audio_pretraining + data: ??? + max_sample_size: 320000 + min_sample_size: 32000 + normalize: true + num_batch_buckets: 3 + precompute_mask_indices: true + +dataset: + num_workers: 6 + max_tokens: 1200000 + skip_invalid_size_inputs_valid_test: true + +distributed_training: + distributed_world_size: 128 + ddp_backend: legacy_ddp + +criterion: + _name: wav2vec + infonce: true + log_keys: ["prob_perplexity","code_perplexity","temp"] + loss_weights: [0.1, 0] + +optimization: + max_update: 1000000 + lr: [0.005] + +optimizer: + _name: adam + adam_betas: (0.9,0.98) + adam_eps: 1e-06 + weight_decay: 0.01 + +lr_scheduler: + _name: polynomial_decay + warmup_updates: 32000 + +model: + _name: wav2vec2 + quantize_targets: true + extractor_mode: layer_norm + layer_norm_first: true + final_dim: 768 + latent_temp: [2.0,0.1,0.999995] + encoder_layerdrop: 0.00 + dropout_input: 0.0 + dropout_features: 0.0 + dropout: 0.0 + attention_dropout: 0.0 + conv_bias: true + + encoder_layers: 24 + encoder_embed_dim: 1024 + encoder_ffn_embed_dim: 4096 + encoder_attention_heads: 16 + + feature_grad_mult: 1.0 + From 1cee791cdddb88b767005122e1d94869878a5617 Mon Sep 17 00:00:00 2001 From: Taylan Bilal Date: Wed, 17 Mar 2021 22:27:33 +0000 Subject: [PATCH 81/82] Default to mask precomputation in dataset when running on tpus. --- fairseq/tasks/audio_pretraining.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/fairseq/tasks/audio_pretraining.py b/fairseq/tasks/audio_pretraining.py index 4ec3ef79ce..df073a1814 100644 --- a/fairseq/tasks/audio_pretraining.py +++ b/fairseq/tasks/audio_pretraining.py @@ -152,12 +152,14 @@ def setup_task(cls, cfg: AudioPretrainingConfig, **kwargs): def load_target_dictionary(self): if self.cfg.labels: - dict_path = os.path.join(self.cfg.data, f"dict.{self.cfg.labels}.txt") + dict_path = os.path.join( + self.cfg.data, f"dict.{self.cfg.labels}.txt" + ) return Dictionary.load(dict_path) return None def _get_mask_precompute_kwargs(self, cfg): - if self.cfg.precompute_mask_indices: + if self.cfg.precompute_mask_indices or self.cfg.tpu: args = [ 'mask_length', 'mask_prob', @@ -178,7 +180,9 @@ def _get_mask_precompute_kwargs(self, cfg): else: return {} - def load_dataset(self, split: str, task_cfg: FairseqDataclass = None, **kwargs): + def load_dataset( + self, split: str, task_cfg: FairseqDataclass = None, **kwargs + ): data_path = self.cfg.data task_cfg = task_cfg or self.cfg From 42932be1e73a4e2bb57f3ae67952a1fa1d67711e Mon Sep 17 00:00:00 2001 From: Taylan Bilal Date: Fri, 19 Mar 2021 18:31:16 +0000 Subject: [PATCH 82/82] Add working example of hydra + config. --- examples/wav2vec/README.md | 19 ++++- .../wav2vec2_large_librivox_tpu-pod.yaml | 71 +++++++++++++++++++ .../wav2vec2_large_librivox_tpu.yaml | 17 +++-- 3 files changed, 97 insertions(+), 10 deletions(-) create mode 100644 examples/wav2vec/config/pretraining/wav2vec2_large_librivox_tpu-pod.yaml diff --git a/examples/wav2vec/README.md b/examples/wav2vec/README.md index b2fa6e65fb..bfed3913cf 100644 --- a/examples/wav2vec/README.md +++ b/examples/wav2vec/README.md @@ -226,6 +226,15 @@ $ python train.py /manifest/path --save-dir /model/path --num-workers 6 --fp16 - Wav2Vec2 is now supported on TPUs! It's currently pre-training only. +#### Using hydra on a v3-8: + +``` +$ OMP_NUM_THREADS=1 fairseq-hydra-train \ + task.data=/manifest/path \ + --config-dir /PATH/TO/FAIRSEQ/examples/wav2vec/config/pretraining \ + --config-name wav2vec2_large_librivox_tpu.yaml +``` + #### Using command line arguments on a v3-8: ``` @@ -239,9 +248,17 @@ $ OMP_NUM_THREADS=1 python train.py /manifest/path --save-dir /model/path --num- --encoder-layerdrop 0 --mask-channel-prob 0.1 ``` +#### Using hydra on a pod slice (v3-N with N > 8): + +``` +$ OMP_NUM_THREADS=1 fairseq-hydra-train \ + task.data=/manifest/path \ + --config-dir /PATH/TO/FAIRSEQ/examples/wav2vec/config/pretraining \ + --config-name wav2vec2_large_librivox_tpu-pod.yaml # edit distributed-world-size accordingly +``` + #### Using command line arguments on a pod slice (v3-N with N > 8): -Make sure to fill in the variables below. ``` $ python -m torch_xla.distributed.xla_dist \ diff --git a/examples/wav2vec/config/pretraining/wav2vec2_large_librivox_tpu-pod.yaml b/examples/wav2vec/config/pretraining/wav2vec2_large_librivox_tpu-pod.yaml new file mode 100644 index 0000000000..676c9fe339 --- /dev/null +++ b/examples/wav2vec/config/pretraining/wav2vec2_large_librivox_tpu-pod.yaml @@ -0,0 +1,71 @@ +# @package _group_ + +common: + tpu: true + fp16: false + log_format: json + log_interval: 10 + +checkpoint: + save_interval_updates: 25000 + keep_interval_updates: 1 + no_epoch_checkpoints: true + +task: + _name: audio_pretraining + data: ??? + max_sample_size: 250000 + min_sample_size: 32000 + normalize: true + num_batch_buckets: 3 + precompute_mask_indices: true + enable_padding: true + +dataset: + num_workers: 6 + max_tokens: 1200000 + skip_invalid_size_inputs_valid_test: true + +distributed_training: + distributed_world_size: 128 + ddp_backend: legacy_ddp + +criterion: + _name: wav2vec + infonce: true + log_keys: ["prob_perplexity","code_perplexity","temp"] + loss_weights: [0.1, 0] + +optimization: + max_update: 1000000 + lr: [0.005] + +optimizer: + _name: adam + adam_betas: (0.9,0.98) + adam_eps: 1e-06 + weight_decay: 0.01 + +lr_scheduler: + _name: polynomial_decay + warmup_updates: 32000 + +model: + _name: wav2vec2 + quantize_targets: true + extractor_mode: layer_norm + layer_norm_first: true + final_dim: 256 + latent_temp: [2.0,0.1,0.999995] + encoder_layerdrop: 0.00 + dropout_input: 0.0 + dropout_features: 0.0 + dropout: 0.0 + attention_dropout: 0.0 + conv_bias: true + + mask_channel_prob: 0.1 + mask_prob: 0.1 + + feature_grad_mult: 1.0 + diff --git a/examples/wav2vec/config/pretraining/wav2vec2_large_librivox_tpu.yaml b/examples/wav2vec/config/pretraining/wav2vec2_large_librivox_tpu.yaml index 5d1e026078..c45c4d9117 100644 --- a/examples/wav2vec/config/pretraining/wav2vec2_large_librivox_tpu.yaml +++ b/examples/wav2vec/config/pretraining/wav2vec2_large_librivox_tpu.yaml @@ -2,9 +2,9 @@ common: tpu: true - fp16: true + fp16: false log_format: json - log_interval: 200 + log_interval: 10 checkpoint: save_interval_updates: 25000 @@ -14,11 +14,12 @@ checkpoint: task: _name: audio_pretraining data: ??? - max_sample_size: 320000 + max_sample_size: 250000 min_sample_size: 32000 normalize: true num_batch_buckets: 3 precompute_mask_indices: true + enable_padding: true dataset: num_workers: 6 @@ -26,7 +27,7 @@ dataset: skip_invalid_size_inputs_valid_test: true distributed_training: - distributed_world_size: 128 + distributed_world_size: 8 ddp_backend: legacy_ddp criterion: @@ -54,7 +55,7 @@ model: quantize_targets: true extractor_mode: layer_norm layer_norm_first: true - final_dim: 768 + final_dim: 256 latent_temp: [2.0,0.1,0.999995] encoder_layerdrop: 0.00 dropout_input: 0.0 @@ -63,10 +64,8 @@ model: attention_dropout: 0.0 conv_bias: true - encoder_layers: 24 - encoder_embed_dim: 1024 - encoder_ffn_embed_dim: 4096 - encoder_attention_heads: 16 + mask_channel_prob: 0.1 + mask_prob: 0.1 feature_grad_mult: 1.0